diff --git a/.gitattributes b/.gitattributes index 3a69f0fc8e91b65e42a652802e8be2a7d9b1fd20..a78ca19fbcdb139428bb9b915237cf0077adb34f 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1825,3 +1825,180 @@ illustrious_generated/e3bdc40b2136.png filter=lfs diff=lfs merge=lfs -text illustrious_generated/6ea1f2330bb2.png filter=lfs diff=lfs merge=lfs -text illustrious_generated/a08733bfce6d.png filter=lfs diff=lfs merge=lfs -text illustrious_generated/0aec7eb07d4f.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/699b6be34d1c.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/b8fe03330562.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/efb6fe0239e2.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/a00e6e3ca42d.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/4c39593e352f.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/daf79f3a5a47.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/b66afd053661.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/8a594acaada9.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/532b5ad4f3b4.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/aa2403d0dbde.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/ffa981380c46.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/60f785243e60.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/776bf5e2d2a1.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/5ab3056565d7.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/bde193a32ad5.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/58b9009a8433.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/db68dd8cb2cf.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/79a4d58eda5c.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/6c2bec98397a.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/56b23fd84d9e.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/e38a1e2e958f.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/4014bca6d153.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/754db85fe85f.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/0a4d2753b7c7.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/a0e55ddbf74c.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/c472bc9db7e5.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/07f25b42b8d7.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/a317dff0fe2e.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/1ca2c247cfef.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/92e013a48101.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/53d7fbefadd0.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/0d1561551252.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/13d5cc66d420.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/7429db9846f3.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/49222c07a03f.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/e526540fdbb8.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/0e42a3a7d296.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/29259007fbbf.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/e189b3abf6d2.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/4af1ebb4488c.png filter=lfs diff=lfs merge=lfs -text 
+illustrious_generated/1755674c7bc1.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/9a7e94351f39.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/548d5f09851f.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/8d3200b9e787.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/a3a89fac4fcf.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/e4bf1a6be7c6.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/37c414606b7d.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/c4033b9e6115.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/29c45b6176ff.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/e729db57f47c.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/96a62fa83b20.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/6a51d82ae8bc.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/4632a958366a.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/47aa86a18b5a.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/d1995ff196e3.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/e5ef86680557.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/bd545030c487.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/cef6d8bbd520.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/6b23dba7be7d.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/8d3ab41288cb.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/e31935899f08.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/bc6707dcf1f2.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/b3a626cfe029.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/3503bc2ba8ab.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/ce38aa6d180d.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/6035404a25b1.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/9b1bb21a46e8.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/a4eb52675c75.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/a7acf95bd41f.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/c636dce88f49.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/8ef3560e5a42.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/4f7953e3c61d.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/a9d16660f9be.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/5028ea7283a1.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/259e79d12cba.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/aa08fc4ab5f7.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/a048f190d616.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/64d5e838ff32.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/f8b258d2d426.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/73e5f02dca91.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/176f37a6cd7f.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/75bc8bf6da61.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/831ac5f05b4e.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/6061d4cce84a.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/1970ccc17731.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/46d7a6324536.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/593a5b377fb6.png filter=lfs 
diff=lfs merge=lfs -text +illustrious_generated/9e588e17280b.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/f086d40d76b7.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/e5b4630753fb.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/bf19c6825d9d.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/ad90d731a0fa.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/4c7a209b8887.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/b52834e10fd4.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/e0fa2a6e6b42.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/b770edd1abe0.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/bf8fcd7e3ebb.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/b2a469095584.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/4af582043c66.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/666f1a84cb80.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/c87663be02a7.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/5fa4fb280d76.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/c65bd94bf971.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/a3f618970fb7.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/c588cc386611.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/e412e3f1af16.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/acf002111806.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/3dbcf74b53f1.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/149a82673fe9.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/4536122bec42.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/25df61f584ba.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/5803e88c45f3.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/d5f24e4f2554.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/57264915c81a.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/b5a6eaf580ce.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/18f1f90221cc.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/798df3adbd88.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/617ce74ffafa.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/1dfed5dbd442.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/243fd5bd7679.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/2433395d6d43.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/96ab8498c7a6.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/4a574003bbc9.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/d209e317052f.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/fb9ed29c63d7.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/ece3750ad6a6.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/690e7719d4bf.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/639cf9531d64.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/56e66190e7ce.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/ceb937e0801e.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/ed0e6ab9592e.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/312f15890343.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/d21efb5f15d1.png filter=lfs diff=lfs merge=lfs -text 
+illustrious_generated/b01a98afee9a.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/3cbd0d9a36a7.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/6994acf3db4f.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/b5ee5bce0095.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/e150c196c1de.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/d0d5e2ddb402.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/b102a265b15d.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/b0202f93d233.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/14cde88e8c02.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/30f0d4a3b6b6.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/33aeb710e683.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/71485a12c097.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/2c3ca96d94bc.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/a630cf5674f3.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/163359d65694.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/2f054f26c97c.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/af2a10d42166.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/4d7d6abe842c.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/62fb4f26c8d0.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/8bc741c00933.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/f3021039e6df.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/7ed913ac6954.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/a9ab451bd4c9.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/49c90b238c77.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/1371e80e6f16.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/f284fb62c9bf.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/71e9866c5494.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/abf601407cb0.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/83da1d18e129.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/484afd596fd0.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/4c0f1f03dcd2.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/2008ee2e6f66.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/48fcc54ce758.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/4769b078890d.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/b716f3a23f11.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/8e740eadaf18.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/88e259bbe761.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/476617799f8d.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/1bb92008067a.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/38cdd5aae9bb.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/5e8555d59ee2.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/f3f41fd3689e.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/ca3147d568a3.png filter=lfs diff=lfs merge=lfs -text +illustrious_generated/8e375573fe30.png filter=lfs diff=lfs merge=lfs -text diff --git a/.venv/lib/python3.12/site-packages/torch/_decomp/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_decomp/__pycache__/__init__.cpython-312.pyc new file mode 100644 
index 0000000000000000000000000000000000000000..4d263edd27a37cef093e9e064b8cb1c64286ca9e Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_decomp/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_decomp/__pycache__/decompositions_for_rng.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_decomp/__pycache__/decompositions_for_rng.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf1dad6ed1b5960b960433bb4b1eb7b78ae513a1 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_decomp/__pycache__/decompositions_for_rng.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_dispatch/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_dispatch/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aad28897a5df3bbc8e9a22cdf91cf4e537a0d460 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_dispatch/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_dispatch/__pycache__/python.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_dispatch/__pycache__/python.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06ad87dd3ab3e146cf8ae1c6179ac25d46061a5b Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_dispatch/__pycache__/python.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..368a81ce9d1d6896e86564e7a43c3d6df363c3b8 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/config.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/config.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36604fa5642c4f416befa2a6e6cd256b9f47a91e Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/config.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/cudagraph_utils.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/cudagraph_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f01053c807961b2cc9c5f6edfd1938b0425a524f Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/cudagraph_utils.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/custom_graph_pass.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/custom_graph_pass.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7a8cabb67f0dd72b80bdddcfdd7b5d868840cac Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/custom_graph_pass.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/standalone_compile.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/standalone_compile.cpython-312.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..1bb6068bad891cb108cce5187b2f6f8e343c6a66 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/standalone_compile.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/test_operators.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/test_operators.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2cf25ab6cf3ef00b7614254b18e4665961ae31e Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_inductor/__pycache__/test_operators.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/__init__.py b/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingA100.py b/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingA100.py new file mode 100644 index 0000000000000000000000000000000000000000..6a8cce6f870b1bc6bcda58e267b6d50feec56717 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingA100.py @@ -0,0 +1,296 @@ +# flake8: noqa: B950 +# fmt: off +# This file was generated by AutoHeuristic. Do not modify it manually! +# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mm/ +from typing import List, Optional, Tuple + +from torch._inductor.autoheuristic.autoheuristic_utils import ( + AHContext, + AHMetadata, + Choice, +) +from torch._inductor.autoheuristic.learnedheuristic_interface import ( + LearnedHeuristicDecision, +) + + +class MMRankingA100(LearnedHeuristicDecision): + + def __init__(self) -> None: + self.choices: List[Choice] = [] + self.fill_choices() + + def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool: + return ( + metadata.name == self.get_name() + and metadata.shared_memory == 166912 + and str(metadata.device_capa) == "(8, 0)" + ) + + def get_confidence_threshold(self) -> float: + return 0.0 + + def get_choice(self, idx: int) -> Optional[str]: + if idx < len(self.choices): + return self.choices[idx] + return None + + def fill_choices(self) -> None: + self.choices.append('extern_mm') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=8') + 
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=8') + 
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4') + 
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=1') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=1') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=1_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=2') + 
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8') + 
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4') + + def get_name(self) -> str: + return 'mm' + + def get_best_choices(self, context: AHContext) -> Optional[List[tuple[float, int]]]: + if context.get_value('arith_intensity') <= 52.6245059967041: + if 
context.get_value('n') <= 34.0: + if context.get_value('n') <= 18.0: + if context.get_value('k*n') <= 312.0: + return [(0.093, 12), (0.081, 16), (0.081, 148), (0.070, 10), (0.070, 17), (0.070, 149), (0.070, 151), (0.070, 150), (0.070, 14), (0.058, 11), (0.058, 15), (0.058, 13), (0.058, 122), (0.047, 121), (0.035, 123), (0.012, 92)] + else: + if context.get_value('k') <= 40.0: + return [(0.083, 42), (0.083, 46), (0.083, 44), (0.083, 40), (0.083, 128), (0.067, 45), (0.067, 43), (0.067, 41), (0.067, 169), (0.067, 171), (0.067, 168), (0.067, 129), (0.067, 170), (0.033, 103), (0.017, 121)] + else: + return [(0.112, 137), (0.104, 136), (0.101, 0), (0.081, 1), (0.073, 135), (0.069, 67), (0.066, 187), (0.058, 41), (0.050, 71), (0.046, 68), (0.046, 70), (0.031, 44), (0.027, 43), (0.027, 170), (0.019, 189), (0.019, 188), (0.015, 169), (0.015, 171), (0.012, 115), (0.012, 168), (0.012, 69), (0.004, 103)] + else: + if context.get_value('mat1_stride_0') <= 20.0: + return [(0.069, 0), (0.059, 157), (0.059, 22), (0.059, 153), (0.059, 155), (0.059, 25), (0.059, 23), (0.059, 19), (0.044, 21), (0.044, 18), (0.044, 152), (0.044, 158), (0.044, 154), (0.044, 156), (0.044, 20), (0.044, 124), (0.044, 24), (0.030, 125), (0.029, 126), (0.015, 97), (0.015, 95), (0.015, 96), (0.010, 2), (0.010, 75)] + else: + if context.get_value('k') <= 68.0: + return [(0.087, 72), (0.087, 74), (0.087, 73), (0.086, 76), (0.077, 75), (0.067, 192), (0.058, 190), (0.048, 47), (0.048, 193), (0.048, 49), (0.048, 51), (0.048, 191), (0.038, 53), (0.019, 133), (0.019, 50), (0.019, 175), (0.019, 172), (0.019, 48), (0.019, 174), (0.010, 173), (0.010, 177), (0.010, 52), (0.010, 54), (0.010, 178), (0.010, 176)] + else: + return [(0.154, 52), (0.154, 72), (0.102, 75), (0.087, 49), (0.087, 73), (0.086, 51), (0.057, 176), (0.045, 2), (0.038, 191), (0.038, 178), (0.038, 190), (0.029, 173), (0.029, 76), (0.026, 138), (0.013, 139), (0.013, 140), (0.003, 0)] + else: + if context.get_value('k') <= 35.0: + if context.get_value('k') <= 18.0: + if context.get_value('m*n') <= 19505152.0: + return [(0.151, 159), (0.140, 160), (0.129, 164), (0.055, 127), (0.051, 29), (0.044, 161), (0.044, 147), (0.040, 146), (0.040, 31), (0.037, 145), (0.026, 28), (0.022, 90), (0.022, 93), (0.022, 94), (0.022, 100), (0.022, 125), (0.022, 158), (0.022, 157), (0.011, 87), (0.011, 88), (0.011, 89), (0.011, 91), (0.011, 95), (0.011, 96), (0.011, 98), (0.011, 99)] + else: + return [(0.069, 7), (0.069, 5), (0.067, 147), (0.066, 8), (0.061, 145), (0.058, 146), (0.052, 124), (0.049, 29), (0.049, 159), (0.046, 31), (0.043, 157), (0.041, 9), (0.041, 4), (0.040, 6), (0.035, 164), (0.035, 160), (0.026, 158), (0.017, 125), (0.017, 28), (0.017, 32), (0.017, 162), (0.017, 27), (0.017, 30), (0.017, 161), (0.009, 33), (0.009, 26), (0.009, 163), (0.006, 0)] + else: + if context.get_value('n') <= 68.0: + return [(0.101, 182), (0.101, 59), (0.088, 57), (0.076, 184), (0.076, 61), (0.076, 179), (0.076, 62), (0.076, 58), (0.063, 180), (0.063, 60), (0.051, 56), (0.050, 181), (0.025, 130), (0.025, 177), (0.025, 183), (0.013, 178), (0.013, 55)] + else: + return [(0.089, 180), (0.079, 60), (0.066, 35), (0.066, 181), (0.066, 38), (0.066, 58), (0.066, 179), (0.066, 57), (0.062, 184), (0.053, 37), (0.044, 166), (0.040, 55), (0.040, 39), (0.040, 36), (0.040, 165), (0.040, 167), (0.027, 177), (0.027, 34), (0.022, 159)] + else: + if context.get_value('m*n') <= 309760.0: + return [(0.298, 0), (0.097, 140), (0.080, 83), (0.072, 86), (0.044, 84), (0.036, 178), (0.036, 117), (0.036, 82), (0.032, 120), (0.032, 
85), (0.028, 119), (0.024, 130), (0.024, 109), (0.020, 108), (0.020, 118), (0.012, 104), (0.012, 116), (0.012, 141), (0.012, 144), (0.008, 105), (0.008, 106), (0.008, 111), (0.008, 114), (0.008, 107), (0.008, 132), (0.004, 101), (0.004, 102), (0.004, 110), (0.004, 112), (0.004, 113), (0.004, 131)] + else: + if context.get_value('n') <= 72.0: + return [(0.227, 77), (0.118, 78), (0.102, 194), (0.086, 80), (0.059, 57), (0.054, 81), (0.049, 196), (0.048, 197), (0.048, 59), (0.043, 79), (0.032, 195), (0.027, 180), (0.022, 3), (0.021, 141), (0.016, 60), (0.016, 142), (0.011, 183), (0.011, 0), (0.011, 144)] + else: + return [(0.140, 186), (0.132, 185), (0.109, 63), (0.085, 65), (0.078, 37), (0.077, 35), (0.062, 197), (0.047, 194), (0.046, 165), (0.046, 57), (0.039, 78), (0.039, 79), (0.039, 66), (0.039, 64), (0.016, 195), (0.008, 159)] + else: + if str(context.get_value('using_tf32')) != 'False': + if context.get_value('m*n') <= 815360.0: + if context.get_value('k') <= 1184.0: + return [(0.218, 140), (0.205, 0), (0.154, 144), (0.115, 141), (0.051, 185), (0.051, 104), (0.039, 78), (0.038, 116), (0.026, 165), (0.026, 130), (0.026, 178), (0.013, 57), (0.013, 195), (0.013, 167), (0.013, 186)] + else: + return [(0.901, 0), (0.030, 144), (0.030, 134), (0.016, 3), (0.006, 78), (0.006, 77), (0.002, 57), (0.002, 194), (0.002, 59), (0.002, 60), (0.002, 143)] + else: + if context.get_value('arith_intensity') <= 187.23922729492188: + if context.get_value('mat1_stride_0') <= 198.0: + return [(0.273, 63), (0.158, 37), (0.152, 35), (0.127, 57), (0.097, 165), (0.053, 185), (0.031, 0), (0.028, 64), (0.014, 60), (0.014, 78), (0.009, 55), (0.008, 134), (0.005, 34), (0.005, 167), (0.005, 179), (0.005, 65), (0.005, 66), (0.005, 186), (0.005, 194), (0.002, 166)] + else: + return [(0.296, 63), (0.235, 0), (0.132, 64), (0.074, 37), (0.069, 78), (0.051, 185), (0.051, 35), (0.030, 57), (0.020, 77), (0.016, 194), (0.008, 66), (0.007, 65), (0.003, 3), (0.003, 165), (0.003, 141), (0.001, 134), (0.001, 166)] + else: + return [(0.405, 0), (0.246, 37), (0.177, 63), (0.145, 35), (0.005, 185), (0.005, 65), (0.005, 64), (0.004, 57), (0.003, 66), (0.002, 165), (0.001, 78), (0.001, 55)] + else: + return [(0.357, 0), (0.112, 165), (0.101, 57), (0.094, 179), (0.086, 64), (0.074, 167), (0.067, 60), (0.064, 159), (0.033, 35), (0.007, 195), (0.002, 180), (0.001, 34), (0.001, 166), (0.001, 78)] diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingH100.py b/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingH100.py new file mode 100644 index 0000000000000000000000000000000000000000..e794b8e646f3af299d4f5566a0bcd2fe1f22e6d6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingH100.py @@ -0,0 +1,321 @@ +# flake8: noqa: B950 +# fmt: off +# This file was generated by AutoHeuristic. Do not modify it manually! 
+# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mm/ +from typing import List, Optional, Tuple + +from torch._inductor.autoheuristic.autoheuristic_utils import ( + AHContext, + AHMetadata, + Choice, +) +from torch._inductor.autoheuristic.learnedheuristic_interface import ( + LearnedHeuristicDecision, +) + + +class MMRankingH100(LearnedHeuristicDecision): + + def __init__(self) -> None: + self.choices: List[Choice] = [] + self.fill_choices() + + def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool: + return ( + metadata.name == self.get_name() + and metadata.shared_memory == 232448 + and str(metadata.device_capa) == "(9, 0)" + ) + + def get_confidence_threshold(self) -> float: + return 0.0 + + def get_choice(self, idx: int) -> Optional[str]: + if idx < len(self.choices): + return self.choices[idx] + return None + + def fill_choices(self) -> None: + self.choices.append('extern_mm') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8') + 
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4') + 
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=1') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=1') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=1') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=2') + 
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=1') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=1') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=1_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=16_numstages=2_numwarps=2') + 
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4') + 
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4') + + def get_name(self) -> str: + return 'mm' + + def get_best_choices(self, context: AHContext) -> Optional[List[tuple[float, int]]]: + if context.get_value('arith_intensity') <= 29.89772129058838: + if context.get_value('n') <= 34.0: + if context.get_value('n') <= 18.0: + if context.get_value('k*n') <= 432.0: + if context.get_value('arith_intensity') <= 7.8700292110443115: + return [(0.098, 128), (0.098, 129), (0.098, 127), (0.073, 14), (0.073, 16), (0.073, 12), (0.073, 154), (0.073, 156), (0.073, 157), (0.073, 155), (0.049, 10), (0.049, 94), (0.049, 95), (0.048, 96)] + else: + return [(0.091, 154), (0.073, 10), (0.073, 15), (0.073, 13), (0.073, 11), (0.073, 17), (0.073, 16), (0.073, 14), (0.073, 12), (0.055, 127), (0.054, 157), (0.054, 156), (0.054, 155), (0.036, 129), (0.036, 128), (0.018, 41), (0.018, 43)] + else: + if context.get_value('k') <= 40.0: + return [(0.070, 39), (0.069, 45), (0.069, 41), (0.069, 43), (0.069, 111), (0.069, 112), (0.056, 38), (0.056, 40), (0.056, 42), (0.056, 44), (0.056, 174), (0.056, 173), (0.056, 175), (0.056, 134), (0.056, 172), (0.056, 135), (0.014, 154), (0.014, 127)] + else: + return [(0.147, 144), (0.119, 143), (0.087, 142), (0.083, 0), (0.073, 191), (0.059, 69), (0.050, 67), (0.046, 70), (0.041, 1), (0.036, 174), (0.032, 43), (0.032, 123), (0.028, 40), (0.027, 
42), (0.027, 173), (0.023, 175), (0.018, 66), (0.014, 192), (0.014, 193), (0.014, 139), (0.014, 68), (0.014, 127)] + else: + if context.get_value('mat1_stride_0') <= 40.0: + if context.get_value('mat1_stride_0') <= 20.0: + return [(0.109, 23), (0.109, 21), (0.109, 20), (0.088, 0), (0.087, 131), (0.066, 18), (0.065, 130), (0.065, 132), (0.065, 159), (0.065, 160), (0.065, 161), (0.065, 158), (0.022, 22), (0.022, 19)] + else: + return [(0.065, 46), (0.064, 52), (0.064, 50), (0.064, 48), (0.064, 51), (0.064, 49), (0.064, 47), (0.064, 53), (0.064, 181), (0.064, 177), (0.064, 179), (0.064, 176), (0.038, 130), (0.038, 136), (0.026, 182), (0.026, 178), (0.026, 180), (0.026, 137), (0.025, 158), (0.013, 114), (0.013, 113)] + else: + if context.get_value('mat1_stride_0') <= 68.0: + return [(0.138, 140), (0.125, 195), (0.100, 71), (0.100, 74), (0.100, 196), (0.100, 194), (0.100, 197), (0.075, 75), (0.062, 72), (0.062, 73), (0.012, 180), (0.012, 51), (0.012, 182)] + else: + return [(0.124, 180), (0.124, 182), (0.114, 75), (0.103, 74), (0.093, 51), (0.093, 71), (0.072, 72), (0.062, 194), (0.052, 145), (0.052, 195), (0.021, 48), (0.021, 50), (0.021, 47), (0.020, 124), (0.010, 147), (0.010, 146), (0.010, 46)] + else: + if context.get_value('k') <= 18.0: + if context.get_value('m*k') <= 528.0: + return [(0.097, 88), (0.087, 92), (0.077, 90), (0.058, 105), (0.058, 103), (0.058, 104), (0.058, 99), (0.058, 100), (0.058, 106), (0.058, 93), (0.057, 91), (0.057, 97), (0.057, 98), (0.057, 101), (0.048, 102), (0.029, 87), (0.029, 89)] + else: + if context.get_value('n') <= 80.0: + return [(0.057, 161), (0.057, 130), (0.057, 24), (0.056, 164), (0.056, 163), (0.056, 166), (0.056, 168), (0.056, 30), (0.056, 28), (0.056, 26), (0.056, 25), (0.056, 27), (0.056, 29), (0.056, 31), (0.042, 131), (0.028, 99), (0.028, 101), (0.028, 100), (0.028, 167), (0.028, 165), (0.028, 133)] + else: + return [(0.110, 164), (0.108, 163), (0.106, 168), (0.069, 161), (0.066, 151), (0.060, 152), (0.055, 165), (0.050, 27), (0.050, 29), (0.048, 131), (0.043, 153), (0.037, 133), (0.037, 130), (0.028, 8), (0.028, 5), (0.027, 7), (0.026, 26), (0.016, 162), (0.012, 9), (0.007, 4), (0.005, 100), (0.005, 6), (0.005, 24)] + else: + if context.get_value('k') <= 36.0: + if context.get_value('n') <= 68.0: + return [(0.097, 184), (0.097, 56), (0.086, 186), (0.086, 183), (0.086, 188), (0.086, 58), (0.086, 60), (0.065, 54), (0.043, 187), (0.043, 185), (0.043, 57), (0.043, 61), (0.032, 55), (0.032, 130), (0.032, 59), (0.011, 181), (0.011, 163), (0.011, 136), (0.011, 138)] + else: + return [(0.117, 184), (0.117, 170), (0.117, 169), (0.107, 183), (0.106, 188), (0.075, 181), (0.064, 130), (0.064, 56), (0.053, 171), (0.032, 57), (0.032, 59), (0.032, 185), (0.011, 163), (0.011, 32), (0.011, 37), (0.011, 34), (0.011, 33), (0.011, 35), (0.011, 36), (0.011, 54)] + else: + if context.get_value('mat2_stride_0') <= 384.0: + return [(0.244, 0), (0.061, 76), (0.061, 79), (0.030, 3), (0.030, 183), (0.030, 189), (0.030, 187), (0.030, 64), (0.030, 190), (0.030, 62), (0.030, 198), (0.030, 201), (0.030, 77), (0.030, 200), (0.030, 80), (0.030, 199), (0.030, 78), (0.030, 184), (0.020, 86), (0.020, 84), (0.020, 120), (0.020, 81), (0.020, 121), (0.020, 85), (0.020, 122), (0.010, 83), (0.010, 118), (0.010, 119), (0.010, 82)] + else: + return [(0.274, 83), (0.171, 86), (0.152, 0), (0.071, 85), (0.061, 125), (0.050, 84), (0.020, 109), (0.020, 117), (0.020, 81), (0.020, 118), (0.020, 121), (0.020, 108), (0.020, 115), (0.020, 116), (0.010, 110), (0.010, 120), (0.010, 103), (0.010, 
107), (0.010, 119), (0.010, 122)] + else: + if context.get_value('arith_intensity') <= 56.995582580566406: + if context.get_value('n') <= 68.0: + if context.get_value('k*n') <= 4448.0: + if context.get_value('m*n') <= 29626368.0: + return [(0.107, 198), (0.107, 200), (0.107, 201), (0.107, 199), (0.106, 76), (0.106, 79), (0.064, 197), (0.063, 56), (0.043, 184), (0.043, 187), (0.042, 80), (0.042, 77), (0.042, 183), (0.021, 78)] + else: + return [(0.073, 201), (0.073, 198), (0.073, 200), (0.073, 199), (0.073, 197), (0.073, 56), (0.073, 58), (0.073, 79), (0.073, 76), (0.072, 59), (0.072, 78), (0.072, 77), (0.072, 80), (0.018, 184), (0.018, 55), (0.018, 54)] + else: + if context.get_value('k') <= 348.0: + return [(0.206, 76), (0.183, 77), (0.169, 198), (0.160, 199), (0.053, 59), (0.046, 56), (0.038, 3), (0.030, 148), (0.030, 58), (0.030, 187), (0.023, 184), (0.015, 0), (0.008, 55), (0.008, 54)] + else: + return [(0.146, 198), (0.145, 199), (0.145, 148), (0.126, 0), (0.084, 76), (0.084, 77), (0.042, 80), (0.042, 79), (0.021, 149), (0.021, 150), (0.021, 3), (0.014, 46), (0.014, 74), (0.014, 75), (0.014, 124), (0.014, 194), (0.014, 195), (0.007, 145), (0.007, 146), (0.007, 2), (0.007, 72), (0.007, 147), (0.007, 71)] + else: + if context.get_value('m') <= 3264.0: + return [(0.247, 147), (0.115, 197), (0.066, 199), (0.066, 201), (0.066, 198), (0.049, 0), (0.049, 169), (0.049, 171), (0.033, 140), (0.033, 125), (0.033, 114), (0.016, 126), (0.016, 183), (0.016, 184), (0.016, 185), (0.016, 182), (0.016, 188), (0.016, 78), (0.016, 148), (0.016, 138), (0.016, 77), (0.016, 56), (0.016, 59)] + else: + if context.get_value('k') <= 62.5: + return [(0.226, 190), (0.226, 189), (0.122, 62), (0.122, 64), (0.055, 77), (0.055, 78), (0.037, 198), (0.036, 201), (0.036, 33), (0.024, 163), (0.018, 56), (0.018, 35), (0.018, 169), (0.006, 171)] + else: + return [(0.162, 35), (0.118, 33), (0.096, 189), (0.096, 190), (0.088, 169), (0.074, 62), (0.073, 56), (0.066, 171), (0.051, 198), (0.051, 201), (0.044, 59), (0.037, 64), (0.029, 63), (0.007, 0), (0.007, 77)] + else: + if context.get_value('m*n') <= 1097728.0: + return [(0.403, 0), (0.179, 141), (0.134, 150), (0.086, 147), (0.051, 148), (0.048, 3), (0.024, 189), (0.020, 199), (0.017, 64), (0.010, 65), (0.010, 77), (0.007, 114), (0.003, 138), (0.003, 59), (0.003, 182)] + else: + if context.get_value('m*n') <= 3244032.0: + return [(0.295, 189), (0.176, 64), (0.157, 65), (0.090, 0), (0.069, 62), (0.059, 63), (0.046, 77), (0.039, 169), (0.023, 199), (0.020, 35), (0.013, 33), (0.010, 171), (0.003, 141)] + else: + if context.get_value('n') <= 136.0: + return [(0.197, 189), (0.197, 63), (0.161, 77), (0.157, 62), (0.061, 33), (0.044, 65), (0.039, 35), (0.039, 64), (0.030, 169), (0.026, 0), (0.017, 199), (0.017, 148), (0.009, 56), (0.004, 3)] + else: + return [(0.460, 0), (0.145, 62), (0.138, 63), (0.081, 35), (0.047, 33), (0.043, 189), (0.023, 64), (0.018, 77), (0.013, 169), (0.009, 65), (0.009, 56), (0.005, 32), (0.005, 59), (0.002, 183), (0.002, 163)] diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMA100.py b/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMA100.py new file mode 100644 index 0000000000000000000000000000000000000000..9a9ea693a96dcc01efac82eb3cf092db4c5f8629 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMA100.py @@ -0,0 +1,150 @@ +# flake8: noqa: B950 +# fmt: off +# This file was generated by AutoHeuristic. 
Do not modify it manually! +# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mixed_mm/ +from typing import List, Optional, Tuple + +from torch._inductor.autoheuristic.autoheuristic_utils import ( + AHContext, + AHMetadata, + Choice, +) +from torch._inductor.autoheuristic.learnedheuristic_interface import ( + LearnedHeuristicDecision, +) + + +class MixedMMA100(LearnedHeuristicDecision): + + def __init__(self) -> None: + self.choices: List[Choice] = [] + self.fill_choices() + + def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool: + return ( + metadata.name == self.get_name() + and metadata.shared_memory == 166912 + and str(metadata.device_capa) == "(8, 0)" + ) + + def get_confidence_threshold(self) -> float: + return 0.0 + + def get_choice(self, idx: int) -> Optional[str]: + if idx < len(self.choices): + return self.choices[idx] + return None + + def fill_choices(self) -> None: + self.choices.append('extern_fallback_mixed_mm') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') + + def get_name(self) -> str: + return 'mixed_mm' + + def get_best_choices(self, context: AHContext) -> Optional[List[tuple[float, int]]]: + if str(context.get_value('1LEQmLEQ16')) != 'True': + if context.get_value('m') <= 32.5: + if context.get_value('n') <= 6976.0: + if context.get_value('n') <= 3520.0: + if context.get_value('m*n') <= 37632.0: + return None + else: + return [(1.000, 13)] + else: + if context.get_value('m*k') <= 452352.0: + return [(0.590, 13), (0.256, 8), (0.103, 7), (0.051, 11)] + else: + return [(0.778, 8), (0.222, 13)] + 
else: + if context.get_value('k*n') <= 102776832.0: + if context.get_value('n') <= 14656.0: + return [(1.000, 11)] + else: + return [(0.889, 11), (0.111, 13)] + else: + return [(1.000, 11)] + else: + if context.get_value('m*n') <= 446464.0: + if context.get_value('m*n') <= 223424.0: + if context.get_value('mat1_stride_0') <= 3968.0: + return None + else: + return None + else: + if context.get_value('m*n') <= 346112.0: + return [(0.960, 16), (0.040, 7)] + else: + return [(0.750, 16), (0.136, 14), (0.114, 7)] + else: + if str(context.get_value('33LEQmLEQ64')) != 'True': + if context.get_value('n') <= 6976.0: + return [(1.000, 14)] + else: + return [(0.753, 2), (0.222, 1), (0.015, 7), (0.007, 16), (0.004, 12)] + else: + if context.get_value('n') <= 13888.0: + return [(0.710, 14), (0.275, 21), (0.014, 12)] + else: + return [(0.374, 19), (0.339, 20), (0.106, 21), (0.101, 16), (0.066, 17), (0.009, 14), (0.004, 18)] + else: + if context.get_value('n') <= 3520.0: + if context.get_value('arith_intensity') <= 3.994754433631897: + if str(context.get_value('mat2_dtype')) != 'torch.uint8': + if context.get_value('m*k') <= 18944.0: + return [(0.577, 5), (0.423, 6)] + else: + return [(0.988, 5), (0.012, 6)] + else: + if context.get_value('arith_intensity') <= 2.9899919033050537: + return None + else: + return None + else: + if context.get_value('arith_intensity') <= 7.956453561782837: + if context.get_value('k*n') <= 9244032.0: + return [(0.822, 5), (0.178, 6)] + else: + return [(0.977, 5), (0.023, 0)] + else: + if context.get_value('m*k') <= 978944.0: + return [(1.000, 5)] + else: + return [(0.971, 5), (0.029, 0)] + else: + if context.get_value('n') <= 13632.0: + if context.get_value('n') <= 6976.0: + return [(1.000, 6)] + else: + if context.get_value('k') <= 3968.0: + return [(0.617, 3), (0.111, 5), (0.099, 7), (0.086, 9), (0.062, 6), (0.025, 8)] + else: + return [(0.779, 8), (0.119, 5), (0.053, 7), (0.035, 6), (0.013, 3)] + else: + if context.get_value('k*n') <= 39518208.0: + return [(0.385, 4), (0.327, 3), (0.192, 6), (0.038, 7), (0.038, 10), (0.019, 5)] + else: + if context.get_value('n') <= 20800.0: + return [(0.821, 6), (0.121, 7), (0.029, 4), (0.014, 5), (0.007, 3), (0.007, 8)] + else: + return [(0.530, 7), (0.386, 6), (0.046, 8), (0.021, 3), (0.015, 4), (0.002, 5)] diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMH100.py b/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMH100.py new file mode 100644 index 0000000000000000000000000000000000000000..b4552c5257e73ddb6699e1040c758997aa96c062 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMH100.py @@ -0,0 +1,149 @@ +# flake8: noqa: B950 +# fmt: off +# This file was generated by AutoHeuristic. Do not modify it manually! 
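+# Note (editor, descriptive only): MixedMMH100 encodes the learned mixed_mm heuristic for H100-class GPUs (shared memory 232448, compute capability (9, 0)); get_best_choices() walks a decision tree over shape/stride features and returns (probability, choice-index) pairs that index into the Triton config strings registered in fill_choices().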
+# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mixed_mm/ +from typing import List, Optional, Tuple + +from torch._inductor.autoheuristic.autoheuristic_utils import ( + AHContext, + AHMetadata, + Choice, +) +from torch._inductor.autoheuristic.learnedheuristic_interface import ( + LearnedHeuristicDecision, +) + + +class MixedMMH100(LearnedHeuristicDecision): + + def __init__(self) -> None: + self.choices: List[Choice] = [] + self.fill_choices() + + def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool: + return ( + metadata.name == self.get_name() + and metadata.shared_memory == 232448 + and str(metadata.device_capa) == "(9, 0)" + ) + + def get_confidence_threshold(self) -> float: + return 0.0 + + def get_choice(self, idx: int) -> Optional[str]: + if idx < len(self.choices): + return self.choices[idx] + return None + + def fill_choices(self) -> None: + self.choices.append('extern_fallback_mixed_mm') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8') + + def get_name(self) -> str: + return 'mixed_mm' + + def get_best_choices(self, context: AHContext) -> Optional[List[tuple[float, int]]]: + if context.get_value('arith_intensity') <= 15.988086223602295: + if context.get_value('n') <= 25280.0: + if context.get_value('n') <= 1344.0: + if context.get_value('mat1_stride_0') <= 7808.0: + return [(0.581, 7), (0.419, 6)] + else: + if context.get_value('m*n') <= 7680.0: + return [(0.875, 0), (0.125, 6)] + else: + return [(0.833, 0), (0.167, 7)] + else: + if context.get_value('n') <= 8512.0: + if str(context.get_value('mat2_dtype')) != 'torch.int8': + return [(0.763, 6), (0.237, 7)] + else: + return [(0.725, 7), (0.275, 
6)] + else: + if str(context.get_value('mat1_dtype')) != 'torch.bfloat16': + return [(0.736, 7), (0.197, 9), (0.048, 6), (0.014, 8), (0.005, 10)] + else: + return [(0.473, 7), (0.398, 6), (0.097, 9), (0.032, 10)] + else: + if context.get_value('n') <= 42254.0: + if context.get_value('n') <= 33856.0: + if context.get_value('k*n') <= 68157440.0: + return [(0.370, 4), (0.370, 5), (0.074, 7), (0.074, 8), (0.074, 11), (0.037, 6)] + else: + return [(0.916, 8), (0.036, 7), (0.036, 9), (0.012, 4)] + else: + return [(0.659, 5), (0.341, 6)] + else: + if context.get_value('k*n') <= 326052992.0: + if context.get_value('n') <= 55232.0: + return [(0.571, 6), (0.321, 7), (0.036, 4), (0.036, 8), (0.036, 9)] + else: + return [(0.506, 6), (0.325, 8), (0.104, 7), (0.039, 5), (0.026, 9)] + else: + if context.get_value('n') <= 57024.0: + return [(0.462, 9), (0.385, 7), (0.115, 6), (0.038, 8)] + else: + return [(0.598, 8), (0.223, 9), (0.107, 6), (0.071, 7)] + else: + if context.get_value('m*n') <= 543936.0: + if str(context.get_value('17LEQmLEQ32')) != 'True': + if context.get_value('m*n') <= 262272.0: + if context.get_value('n') <= 1592.5: + return [(0.860, 0), (0.140, 9)] + else: + return None + else: + if context.get_value('m*k') <= 1294336.0: + return [(0.833, 17), (0.150, 18), (0.017, 15)] + else: + return [(0.917, 17), (0.083, 8)] + else: + if context.get_value('n') <= 12416.0: + if context.get_value('m*n') <= 43008.0: + return None + else: + return [(0.853, 14), (0.147, 9)] + else: + return [(0.625, 12), (0.375, 14)] + else: + if context.get_value('m') <= 32.5: + if context.get_value('mat2_stride_1') <= 6656.0: + if context.get_value('n') <= 69184.0: + return [(0.611, 12), (0.361, 14), (0.028, 13)] + else: + return [(1.000, 12)] + else: + if context.get_value('mat2_stride_1') <= 20864.0: + return [(1.000, 12)] + else: + return [(0.958, 12), (0.042, 9)] + else: + if context.get_value('m*n') <= 1085440.0: + if context.get_value('n') <= 9152.0: + return [(1.000, 18)] + else: + return [(0.780, 18), (0.160, 16), (0.060, 20)] + else: + if context.get_value('m') <= 67.0: + return [(0.650, 16), (0.203, 19), (0.122, 18), (0.016, 20), (0.008, 1)] + else: + return [(0.561, 3), (0.185, 16), (0.096, 20), (0.083, 19), (0.076, 2)] diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_PadMMA100.py b/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_PadMMA100.py new file mode 100644 index 0000000000000000000000000000000000000000..b61f8a9dd1e99056864a9dddc663b090f6971214 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/_PadMMA100.py @@ -0,0 +1,109 @@ +# flake8: noqa: B950 +# fmt: off +# This file was generated by AutoHeuristic. Do not modify it manually! 
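+# Note (editor, descriptive only): PadMMA100 is the learned pad_mm regression heuristic for A100 (shared memory 166912, compute capability (8, 0)); predict() estimates the expected benefit of each candidate choice (e.g. whether to pad) from shape, stride and dtype features, and the regression base class only returns a decision when the best prediction exceeds the runner-up by more than the confidence-threshold ratio.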
+# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/pad_mm/ +from torch._inductor.autoheuristic.autoheuristic_utils import AHContext, AHMetadata, Choice, CHOICE_COL +from torch._inductor.autoheuristic.learnedheuristic_interface import ( + LearnedHeuristicRegression, +) + + +class PadMMA100(LearnedHeuristicRegression): + + def __init__(self) -> None: + pass + + def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool: + return ( + metadata.name == self.get_name() + and metadata.shared_memory == 166912 + and str(metadata.device_capa) == "(8, 0)" + ) + + def get_feedback(self, context: AHContext, choice: Choice) -> float: + context.context_dict[CHOICE_COL] = choice + return self.predict(context) + + def get_confidence_threshold(self) -> float: + return 1.7025303314066 + + def get_name(self) -> str: + return 'pad_mm' + + def predict(self, context: AHContext) -> float: + if str(context.get_value('choice')) != 'pad': + if str(context.get_value('using_tf32')) != 'False': + if context.get_value('m*n') <= 4171264.0: + if context.get_value('m*k') <= 3999308.0: + return 1.8751469764071178 + else: + if str(context.get_value('n_multiple_32')) != 'True': + return 0.9117231355626345 + else: + return 1.1607689608873861 + else: + if str(context.get_value('n_multiple_2')) != 'True': + if str(context.get_value('using_tf32')) != 'True': + return 0.7430382200435992 + else: + return 0.8531269794448678 + else: + if str(context.get_value('k_multiple_2')) != 'True': + return 0.7577181972719917 + else: + return 0.8977349440424219 + else: + if context.get_value('m*n') <= 1299712.0: + return 1.1669723418995592 + else: + if context.get_value('mat2_stride_1') <= 45217.5: + if context.get_value('m*n') <= 55884158.0: + return 1.0262769936909601 + else: + return 1.0022677428470845 + else: + if context.get_value('m') <= 18478.0: + return 1.1127066261894312 + else: + return 1.0337740659894263 + else: + if str(context.get_value('mat1_dtype')) != 'torch.float32': + if str(context.get_value('n_multiple_2')) != 'False': + if str(context.get_value('k_multiple_2')) != 'True': + if context.get_value('mat1_stride_0') <= 561.0: + return 1.2900382135142956 + else: + return 1.5761737616057887 + else: + if context.get_value('num_dims_needs_padding') <= 1.5: + return 1.0472263310239422 + else: + return 1.1727673465762514 + else: + if context.get_value('k') <= 28238.5: + if context.get_value('k/(m*n)') <= 0.00026227018679492176: + return 1.6770542505397175 + else: + return 1.3974785435105923 + else: + if str(context.get_value('mat1_dtype')) != 'torch.bfloat16': + return 1.3952699800111992 + else: + return 1.5759286511628336 + else: + if str(context.get_value('using_tf32')) != 'False': + if context.get_value('m*n') <= 14119424.0: + return 0.8875772670422478 + else: + if str(context.get_value('mat2_innermost_needs_padding')) != 'True': + return 1.1467728924377265 + else: + return 1.215842963532998 + else: + if context.get_value('arith_intensity') <= 396.8774871826172: + return 0.89940161869551 + else: + if context.get_value('mat2_stride_1') <= 45217.5: + return 0.9964328169353532 + else: + return 0.9493479238294826 diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/__init__.py b/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/artifacts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git 
a/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/autoheuristic.py b/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/autoheuristic.py new file mode 100644 index 0000000000000000000000000000000000000000..5cb19bbbeaef37439fe7077b64827f1766d2497e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/autoheuristic.py @@ -0,0 +1,315 @@ +import json +import os +from functools import partial +from typing import Any, Callable, Optional + +import torch +from torch._inductor.autoheuristic.autoheuristic_utils import ( + AHContext, + AHMetadata, + AHOperation, + Choice, + CHOICE_COL, + Feedback, + FEEDBACK_COL, + get_metadata_str_from_log, +) +from torch._inductor.autoheuristic.learned_heuristic_controller import ( + LearnedHeuristicController, +) +from torch._inductor.ir import ChoiceCaller +from torch._inductor.runtime.runtime_utils import cache_dir +from torch._inductor.utils import get_gpu_shared_memory + + +class LocalFeedback: + """ + To be able to collect data for a choice, a function providing feedback given a choice has to be provided. + LocalFeedback can be used when AutoHeuristic should immediately run the function to collect feedback for each choice + (see pad_mm.py, where the autotuning happens locally, for an example). + """ + + def __init__(self, feedback_fn: Callable[[Choice], Feedback]) -> None: + self.feedback_fn = feedback_fn + + def __call__(self, choice: Choice) -> Feedback: + return self.feedback_fn(choice) + + +class InconsistentMetadata(Exception): + """ + Exception that is thrown when AutoHeuristic tries to log data to a file where the metadata stored in the file does + not match the metadata it would store if the file didn't exist. + """ + + +class AutoHeuristic: + """ + AutoHeuristic is a framework that allows one to collect data, learn a heuristic (i.e. a regression tree) and + generate the heuristic to code. This class allows one to collect data. The collected data can then be used to train + a heuristic (see torchgen/autoheuristic/). + """ + + collected_feedback: dict[Choice, Feedback] + + def __init__( + self, + fallback: Callable[[], Choice], + choices: list[Choice], + feedback: Optional[LocalFeedback], + context: AHContext, + name: str, + augment_context: Optional[list[AHOperation]] = None, + precondition: Optional[Callable[[AHMetadata, AHContext], bool]] = None, + ) -> None: + """ + Initializes an instance of the AutoHeuristic class. + + Args: + fallback: A callable that returns a Choice when the heuristic is unsure which choice to make, or + AutoHeuristic is in data collection mode. + choices: A list of possible choices the heuristic can make. + feedback: An instance of LocalFeedback that provides feedback for a given choice. + context: Context to store with each choice and feedback. + name: A string that identifies the heuristic. + augment_context: An optional list of AHOperation instances that augment the context. + precondition: A callable that returns a boolean indicating whether AutoHeuristic should run. 
+ """ + self.fallback = fallback + self.choices = choices + self.feedback = feedback + self.context = context + self.name = name + self.collected_feedback = {} + self.augment_context = augment_context + self.metadata = AHMetadata( + get_gpu_shared_memory(), + torch.cuda.get_device_capability(), + self.choices, + self.name, + ) + self.precondition = precondition + + if not self.satisfies_precondition(): + return + + if torch._inductor.config.autoheuristic_log_path == "DEFAULT": + self.log_path = self.get_default_log_path() + else: + self.log_path = torch._inductor.config.autoheuristic_log_path + + if torch._inductor.config.collect_autoheuristic(self.name): + if self.feedback is not None: + for choice in self.choices: + feedback_val = self.feedback(choice) + self.save_data(choice, feedback_val) + + def satisfies_precondition(self) -> bool: + return self.precondition is None or self.precondition( + self.metadata, self.context + ) + + def get_choice(self) -> Choice: + """ + Returns the chosen option based on the value of autoheuristic_use. + If self.name is one of the comma separated strings in autoheuristic_use, + it queries a learned heuristic to make a decision. Otherwise, it returns the fallback option. + """ + + if not self.satisfies_precondition(): + return self.fallback() + + if torch._inductor.config.use_autoheuristic(self.name): + if self.augment_context is not None: + self.context.apply_operations(self.augment_context) + controller = LearnedHeuristicController( + self.metadata, + self.context, + ) + decision = controller.get_decision() + if decision not in self.choices: + # TODO(AlnisM): We might want to allow this in the future + return self.fallback() + if decision is not None: + return decision + return self.fallback() + + def get_top_k_choices( + self, top_k: int, always_included: Optional[list[str]] = None + ) -> Optional[list[Choice]]: + if not self.satisfies_precondition(): + return None + if torch._inductor.config.use_autoheuristic(self.name): + if self.augment_context is not None: + self.context.apply_operations(self.augment_context) + controller = LearnedHeuristicController( + self.metadata, + self.context, + ) + choices = controller.get_decisions_ranked(top_k) + if choices is None: + return None + if always_included is not None: + for choice in always_included: + if choice not in choices: + choices.append(choice) + return choices + return None + + def get_collected_feedback(self, choice: Choice) -> Any: + return self.collected_feedback.get(choice, None) + + @staticmethod + def get_device_identifier() -> str: + # a heuristic might work well for one GPU, but not for another + # we store the collected data per GPU model and learn a heuristic per GPU model + + # TODO(AlnisM): just using the device name for now, but the same GPU model can have different names + device_name = torch.cuda.get_device_name().replace(" ", "_") + return device_name + + def get_default_log_path(self) -> str: + device_name = self.get_device_identifier() + path = f"{cache_dir()}/autoheuristic/{device_name}/" + os.makedirs(path, exist_ok=True) + path += f"{self.name}.txt" + return path + + def serialize_metadata(self) -> str: + metadata_dict = self.metadata.to_dict() + ( + num_features, + cat_features, + ) = self.context.get_numerical_and_categorical_features() + metadata_dict["numerical_features"] = num_features + metadata_dict["categorical_features"] = cat_features + return json.dumps(metadata_dict) + + def save_data(self, choice: Choice, feedback_val: Feedback) -> None: + 
self.collected_feedback[choice] = feedback_val + log_path = self.log_path + + lines = [] + log_exists = os.path.exists(log_path) + if log_exists: + # if log already exists, make sure it is consistent + metadata = self.serialize_metadata() + existing_metadata = get_metadata_str_from_log(self.log_path) + if existing_metadata != metadata: + raise InconsistentMetadata( + "Given metadata does not match existing metadata" + ) + else: + lines.append(self.serialize_metadata()) + feature_header = self.context.get_feature_names_csv() + header = feature_header + "," + CHOICE_COL + "," + FEEDBACK_COL + lines.append(header) + + line = "" + feature_values = self.context.get_feature_values_csv() + line += feature_values + "," + choice + "," + str(feedback_val) + lines.append(line) + + with open(log_path, "a") as f: + f.write("\n".join(lines) + "\n") + + +class AutoHeuristicSelectAlgorithm(AutoHeuristic): + """ + AutoHeuristicSelectAlgorithm is a subclass of AutoHeuristic that allows one to collect data and learn a heuristic + when one wants to use AutoHeuristic for kernel choice selection. + """ + + def __init__( + self, + fallback: Callable[[], Optional[ChoiceCaller]], + choices: list[ChoiceCaller], + input_nodes: list[Any], + context: AHContext, + name: str, + augment_context: Optional[list[AHOperation]] = None, + precondition: Optional[Callable[[AHMetadata, AHContext], bool]] = None, + ) -> None: + """ + The arguments choices, input_nodes and name have to match the ones used in the call to + autotune_select_algorithm(), e.g. if the following call is made + autotune_select_algorithm(name, choices, input_nodes, layout), the same name, choices and input_nodes + have to be used here. + """ + self.input_nodes = input_nodes + self.choicestr2choice: dict[str, ChoiceCaller] = {} + for choice in choices: + self.choicestr2choice[choice.autoheuristic_id()] = choice + choices_str = list(self.choicestr2choice.keys()) + + def fallback_str() -> str: + fallback_choice = fallback() + if fallback_choice is None: + # TODO: Find a nicer way to handle this + return "unsure" + return fallback_choice.autoheuristic_id() + + super().__init__( + fallback_str, + choices_str, + None, + context, + name, + augment_context, + precondition, + ) + + if ( + torch._inductor.config.collect_autoheuristic(self.name) + and self.satisfies_precondition() + ): + self.register_global_feedback(input_nodes, choices) + + def register_global_feedback( + self, input_nodes: list[Any], choices: list[ChoiceCaller] + ) -> None: + """ + Registers a callback in select_algorithm, which is called with the timing of each choice. 
+ """ + + from torch._inductor.select_algorithm import ( + add_feedback_saver, + create_inputs_key, + create_precompile_key, + ) + + def store_global_feedback( + ah_inputs_key: str, + ah_precompile_key: str, + timings: dict[ChoiceCaller, float], + name: str, + input_nodes: list[Any], + choices: list[ChoiceCaller], + ) -> None: + current_inputs_key = create_inputs_key(input_nodes) + if current_inputs_key != ah_inputs_key: + return + current_precompile_key = create_precompile_key( + name, current_inputs_key, choices + ) + if current_precompile_key != ah_precompile_key: + return + for choice, time in timings.items(): + self.save_data(choice.autoheuristic_id(), time) + + inputs_key = create_inputs_key(input_nodes) + precompile_key = create_precompile_key(self.name, inputs_key, choices) + feedback_saver = partial(store_global_feedback, inputs_key, precompile_key) + add_feedback_saver(feedback_saver) + + def get_choice_caller(self) -> Optional[ChoiceCaller]: + choice = self.get_choice() + return self.choicestr2choice.get(choice, None) + + def get_top_k_choices_caller( + self, top_k: int, always_included: Optional[list[str]] = None + ) -> Optional[list[ChoiceCaller]]: + choices = self.get_top_k_choices(top_k, always_included) + if choices is None: + return None + return [self.choicestr2choice[choice] for choice in choices] diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/autoheuristic_utils.py b/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/autoheuristic_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6d2ed414d43bedf8a2349d5db468e7b9b9a1a7e7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/autoheuristic_utils.py @@ -0,0 +1,339 @@ +import functools +from typing import Any, Callable + +import torch + + +Feedback = float +Choice = str +Value = Any + +CHOICE_COL = "choice" +FEEDBACK_COL = "feedback" + + +class AHFeature: + """ + The context, that AutoHeuristic stores, is a list of features. AutoHeuristic needs to know whether a feature is + categorical (i.e., not a continuous variable) to learn a machine learning model. + """ + + def __init__(self, name: str, value: Value, is_categorical: bool = False) -> None: + self.name = name + self.value = value + self.is_categorical = is_categorical + + +class AHOperation: + """ + AHOperation can be used to augment the data collected by AutoHeuristic. + One might for example store features like m, k, n, but also want to use + features like m*n, or k*n, to learn a heuristic. Instead of storing features + that can be created from the collected data, one can use AHOperation to + create new features from the collected data. + """ + + def __init__( + self, name: str, func: Callable[[Any], Value], is_categorical: bool = False + ) -> None: + self.name = name + self.func = func + self.is_categorical = is_categorical + + def apply_operation(self, data: Any) -> None: + data[self.name] = self.func(data) + + +class AHContext: + """ + This class is used to specify which information AutoHeuristic should store. For each choice, AutoHeursitic will + store the context and the collected feedback. The context could be something like the shape of a tensor, i.e., + information that will help to learn a heuristic. 
+ """ + + features: list[AHFeature] + context_dict: dict[str, Value] + + def __init__(self) -> None: + self.features = [] + self.context_dict = {} + + def add_feature( + self, name: str, value: Value, is_categorical: bool = False + ) -> None: + self.features.append(AHFeature(name, value, is_categorical=is_categorical)) + self.context_dict[name] = value + + def get_numerical_and_categorical_features(self) -> tuple[list[str], list[str]]: + numerical_features = [] + categorical_features = [] + for feature in self.features: + if feature.is_categorical: + categorical_features.append(feature.name) + else: + numerical_features.append(feature.name) + + return numerical_features, categorical_features + + def get_feature_names_csv(self) -> str: + return ",".join(feature.name for feature in self.features) + + def get_feature_values_csv(self) -> str: + return ",".join(str(feature.value) for feature in self.features) + + def get_value(self, name: str) -> Value: + return self.context_dict[name] + + def apply_operations(self, operations: list[AHOperation]) -> None: + for op in operations: + op.apply_operation(self.context_dict) + + +class AHMetadata: + def __init__( + self, + shared_memory: Any, + device_capa: tuple[int, int], + choices: list[Choice], + name: str, + ) -> None: + # use amount of shared_memory and device_capability to identify GPU + # TODO(AlnisM): there might be a better way to do this + self.shared_memory = shared_memory + self.device_capa = device_capa + self.choices = choices + self.name = name + + def to_dict(self) -> dict[str, Value]: + return { + "shared_memory": self.shared_memory, + "device_capa": self.device_capa, + "name": self.name, + } + + +def get_metadata_str_from_log(log_path: str) -> str: + with open(log_path, newline="") as file: + json_string = file.readline().strip() + return json_string + + +def check_minsize(context: AHContext, minsize: int) -> bool: + return ( + context.get_value("m") >= minsize + and context.get_value("k") >= minsize + and context.get_value("n") >= minsize + ) + + +def pad_mm_precondition(metadata: AHMetadata, context: AHContext) -> bool: + if metadata.shared_memory == 166912 and metadata.device_capa == (8, 0): + # A100 precondition + return check_minsize(context, 512) + elif metadata.shared_memory == 232448 and metadata.device_capa == (9, 0): + # H100 precondition + return check_minsize(context, 768) + return True + + +def get_mixedmm_precondition(metadata: AHMetadata, context: AHContext) -> bool: + m = context.get_value("m") + k = context.get_value("k") + n = context.get_value("n") + if m > 128 or k < 1024 or n < 1024: + return False + mat1_iscontig = context.get_value("mat1_iscontig") + mat2_iscontig = context.get_value("mat2_iscontig") + return mat1_iscontig and not mat2_iscontig + + +def get_mult_dims_ops() -> list[AHOperation]: + m_times_k_op = AHOperation("m*k", lambda data: data["m"] * data["k"]) + m_times_n_op = AHOperation("m*n", lambda data: data["m"] * data["n"]) + k_times_n_op = AHOperation("k*n", lambda data: data["k"] * data["n"]) + return [m_times_k_op, m_times_n_op, k_times_n_op] + + +def get_arith_intensity(data: Any) -> float: + m = data["m"] + k = data["k"] + n = data["n"] + if m == 0 or k == 0 or n == 0: + return 0.0 + return m * k * n / (m * k + k * n + m * n) + + +def pad_mm_operations() -> list[AHOperation]: + mult_dims_ops = get_mult_dims_ops() + k_div_m_times_n_op = AHOperation( + "k/(m*n)", lambda data: data["k"] / (data["m"] * data["n"]) + ) + + def bfloat_perf_hit(data: Any) -> bool: + m = data["m"] + k = data["k"] + n = 
data["n"] + is_bfloat = str(data["mat1_dtype"]) == "torch.bfloat16" + return k > (m * 1024) and k > (n * 1024) and is_bfloat + + bfloat_perf_hit_op = AHOperation( + "bfloat_perf_hit", bfloat_perf_hit, is_categorical=True + ) + + arith_intensity_op = AHOperation("arith_intensity", get_arith_intensity) + dims_need_padding_ops = get_dims_need_padding_ops() + dims_multiple_ops = get_dims_multiple_ops() + is_contig_ops = get_is_contig_ops() + + ah_operations = mult_dims_ops + [ + k_div_m_times_n_op, + bfloat_perf_hit_op, + arith_intensity_op, + ] + ah_operations.extend(dims_need_padding_ops) + ah_operations.extend(dims_multiple_ops) + ah_operations.extend(is_contig_ops) + return ah_operations + + +def between_op(data: Any, dim: str, lower: int, upper: int) -> bool: + return data[dim] >= lower and data[dim] <= upper + + +def between_ops() -> list[AHOperation]: + dims = ["m", "k", "n"] + limits = [(1, 16), (17, 32), (33, 64), (65, 128), (129, 256)] + ah_operations = [] + for dim in dims: + for lower, upper in limits: + between_op_fn = functools.partial( + between_op, dim=dim, lower=lower, upper=upper + ) + # using 'LEQ' instead of '<=' because '<=' cannot be exported to dot + between_op_name = f"{lower}LEQ{dim}LEQ{upper}" + ah_operations.append( + AHOperation(between_op_name, between_op_fn, is_categorical=True) + ) + return ah_operations + + +def pow2_op(data: Any, dim: str, exponent: int) -> bool: + return data[dim] == 2**exponent + + +def mm_operations() -> list[AHOperation]: + mult_dims_ops = get_mult_dims_ops() + arith_intensity_op = AHOperation("arith_intensity", get_arith_intensity) + return mult_dims_ops + [arith_intensity_op] + + +def mixed_mm_operations() -> list[AHOperation]: + return mm_operations() + between_ops() + + +def is_multiple(data: Any, dim: str, mult: int) -> bool: + return data[dim] % mult == 0 + + +def get_dims_multiple_ops() -> list[AHOperation]: + multiples = [2, 4, 8, 16, 32] + dims = ["m", "k", "n"] + dims_multiple_ops = [] + for dim in dims: + for mult in multiples: + is_multiple_fn = functools.partial(is_multiple, dim=dim, mult=mult) + dims_multiple_op = AHOperation( + f"{dim}_multiple_{mult}", is_multiple_fn, is_categorical=True + ) + dims_multiple_ops.append(dims_multiple_op) + return dims_multiple_ops + + +def get_dims_need_padding_ops() -> list[AHOperation]: + def mat1_innermost_needs_padding_fn(data: Any) -> bool: + mat1_stride_0 = data["mat1_stride_0"] + mat1_stride_1 = data["mat1_stride_1"] + m_padded_length = data["m_padded_length"] + k_padded_length = data["k_padded_length"] + mat1_innermost_needs_padding = False + if mat1_stride_0 == 1 and m_padded_length != 0: + mat1_innermost_needs_padding = True + if mat1_stride_1 == 1 and k_padded_length != 0: + mat1_innermost_needs_padding = True + return mat1_innermost_needs_padding + + mat1_innermost_op = AHOperation( + "mat1_innermost_needs_padding", + mat1_innermost_needs_padding_fn, + is_categorical=True, + ) + + def mat2_innermost_needs_padding_fn(data: Any) -> bool: + mat2_stride_0 = data["mat2_stride_0"] + mat2_stride_1 = data["mat2_stride_1"] + k_padded_length = data["k_padded_length"] + n_padded_length = data["n_padded_length"] + mat2_innermost_needs_padding = False + if mat2_stride_0 == 1 and k_padded_length != 0: + mat2_innermost_needs_padding = True + if mat2_stride_1 == 1 and n_padded_length != 0: + mat2_innermost_needs_padding = True + return mat2_innermost_needs_padding + + mat2_innermost_op = AHOperation( + "mat2_innermost_needs_padding", + mat2_innermost_needs_padding_fn, + is_categorical=True, + ) + 
+ def num_dims_needs_padding_fn(data: Any) -> int: + m_padded_length = data["m_padded_length"] + k_padded_length = data["k_padded_length"] + n_padded_length = data["n_padded_length"] + num_dims_needs_padding = 0 + if m_padded_length != 0: + num_dims_needs_padding += 1 + if k_padded_length != 0: + num_dims_needs_padding += 1 + if n_padded_length != 0: + num_dims_needs_padding += 1 + return num_dims_needs_padding + + num_dims_op = AHOperation("num_dims_needs_padding", num_dims_needs_padding_fn) + return [mat1_innermost_op, mat2_innermost_op, num_dims_op] + + +def get_is_contig_ops() -> list[AHOperation]: + def mat1_is_contig_fn(data: Any) -> bool: + stride_0 = data["mat1_stride_0"] + stride_1 = data["mat1_stride_1"] + k = data["k"] + return stride_0 == k and stride_1 == 1 + + mat1_is_contig_op = AHOperation( + "mat1_iscontig", mat1_is_contig_fn, is_categorical=True + ) + + def mat2_is_contig_fn(data: Any) -> bool: + stride_0 = data["mat2_stride_0"] + stride_1 = data["mat2_stride_1"] + n = data["n"] + return stride_0 == n and stride_1 == 1 + + mat2_is_contig_op = AHOperation( + "mat2_iscontig", mat2_is_contig_fn, is_categorical=True + ) + + return [mat1_is_contig_op, mat2_is_contig_op] + + +def context_add_strides(context: AHContext, name: str, stride: tuple[int, ...]) -> None: + for i, s in enumerate(stride): + context.add_feature(f"{name}_stride_{i}", s) + + +def context_add_using_tf32(context: AHContext, dtype: torch.dtype) -> None: + using_tf32 = "not_float_32" + if dtype == torch.float32: + using_tf32 = torch.backends.cuda.matmul.allow_tf32 + context.add_feature("using_tf32", using_tf32, is_categorical=True) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/learned_heuristic_controller.py b/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/learned_heuristic_controller.py new file mode 100644 index 0000000000000000000000000000000000000000..50c11eb9a712afafee7479987a6832e412cc393a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/learned_heuristic_controller.py @@ -0,0 +1,119 @@ +import importlib +import inspect +import pkgutil +from collections import defaultdict +from typing import Any, Optional + +from torch._inductor.autoheuristic.autoheuristic_utils import ( + AHContext, + AHMetadata, + Choice, +) +from torch._inductor.autoheuristic.learnedheuristic_interface import LearnedHeuristic + + +def find_and_instantiate_subclasses( + package_name: str, base_class: Any +) -> list[LearnedHeuristic]: + instances = [] + + package = importlib.import_module(package_name) + for _, module_name, _ in pkgutil.walk_packages( + package.__path__, package.__name__ + "." + ): + try: + module_basename = module_name.split(".")[-1] + if not module_basename.startswith("_"): + # learned heuristics start with an underscore + continue + module = importlib.import_module(module_name) + + # look for classes that are subclasses of base_class + for _name, obj in inspect.getmembers(module): + if ( + inspect.isclass(obj) + and issubclass(obj, base_class) + and obj != base_class + ): + instance = obj() + instances.append(instance) + except Exception as e: + print(f"Error processing module {module_name}: {e}") + + return instances + + +class LearnedHeuristicController: + """ + Class that finds and instantiates all learned heuristics. It also provides + a way to get the decision of a learned heuristic. 
+ """ + + existing_heuristics: dict[str, list[LearnedHeuristic]] = defaultdict(list) + """ + A dictionary that stores all the learned heuristics for each optimization. + The key is the optimization name, and the value is a list of LearnedHeuristic objects. + """ + + heuristics_initialized: bool = False + """ + A flag that indicates whether the learned heuristics have been initialized. + Set to true when the get_decision() function is called for the first time. + """ + + def __init__( + self, + metadata: AHMetadata, + context: AHContext, + ) -> None: + self.metadata = metadata + self.context = context + + def get_heuristics(self, name: str) -> list[LearnedHeuristic]: + """ + Returns a list of learned heuristics for the given optimization name. + """ + + if not LearnedHeuristicController.heuristics_initialized: + # learned heuristics are generated into the following package + learned_heuristics_package = "torch._inductor.autoheuristic.artifacts" + + # learned heuristics have to be of type LearnedHeuristic + base_class = LearnedHeuristic + found_heuristics = find_and_instantiate_subclasses( + learned_heuristics_package, base_class + ) + + for learned_heuristic in found_heuristics: + opt_name = learned_heuristic.get_name() + LearnedHeuristicController.existing_heuristics[opt_name].append( + learned_heuristic + ) + LearnedHeuristicController.heuristics_initialized = True + + return LearnedHeuristicController.existing_heuristics[name] + + def get_decision(self) -> Optional[Choice]: + """ + Returns the decision made by the learned heuristic or None if no heuristic was found or the heuristic is unsure + which choice to make. + """ + + heuristics = self.get_heuristics(self.metadata.name) + for heuristic in heuristics: + if heuristic.check_precondition(self.metadata, self.context): + return heuristic.get_decision(self.context, self.metadata.choices) + return None + + def get_decisions_ranked(self, top_k: int) -> Optional[list[Choice]]: + heuristics = self.get_heuristics(self.metadata.name) + for heuristic in heuristics: + if heuristic.check_precondition(self.metadata, self.context): + choices = heuristic.get_decisions_ranked(self.context) + if choices is None: + return None + avail_choices = [ + choice for choice in choices if choice in self.metadata.choices + ] + return avail_choices[:top_k] + return None diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/learnedheuristic_interface.py b/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/learnedheuristic_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..cb2568d8a680193c0906e714fdb7736273e25ec3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/autoheuristic/learnedheuristic_interface.py @@ -0,0 +1,95 @@ +import operator +from typing import Optional + +from torch._inductor.autoheuristic.autoheuristic_utils import ( + AHContext, + AHMetadata, + Choice, +) + + +class LearnedHeuristic: + """ + LearnedHeuristic is a base class for all learned heuristics. 
+ """ + + def __init__(self) -> None: + pass + + def check_precondition( + self, + metadata: AHMetadata, + context: AHContext, + ) -> bool: + return True + + def get_decision( + self, context: AHContext, choices: list[Choice] + ) -> Optional[Choice]: + return None + + def get_confidence_threshold(self) -> float: + return 1.0 + + def get_name(self) -> str: + return "" + + def get_decisions_ranked(self, context: AHContext) -> Optional[list[str]]: + return None + + +class LearnedHeuristicRegression(LearnedHeuristic): + def __init__(self) -> None: + super().__init__() + + def get_feedback(self, context: AHContext, choice: Choice) -> float: + return 1.0 + + def get_decision( + self, context: AHContext, choices: list[Choice] + ) -> Optional[Choice]: + choice2feedback = {} + for choice in choices: + predicted_feedback = self.get_feedback(context, choice) + choice2feedback[choice] = predicted_feedback + sorted_choices_feedback = sorted( + choice2feedback.items(), key=operator.itemgetter(1) + ) + highest_feedback = sorted_choices_feedback[-1][1] + second_highest_feedback = sorted_choices_feedback[-2][1] + if highest_feedback / second_highest_feedback > self.get_confidence_threshold(): + return sorted_choices_feedback[-1][0] + # We are not sure which choice is the best one + return None + + +class LearnedHeuristicDecision(LearnedHeuristic): + def __init__(self) -> None: + super().__init__() + + def get_choice(self, idx: int) -> Optional[str]: + return None + + def get_decision( + self, context: AHContext, choices: list[Choice] + ) -> Optional[Choice]: + best_choices = self.get_best_choices(context) + if not best_choices: + return None + (best_choice_proba, best_choice_idx) = best_choices[0] + if best_choice_proba <= self.get_confidence_threshold(): + return None + return self.get_choice(best_choice_idx) + + def get_decisions_ranked(self, context: AHContext) -> Optional[list[str]]: + feedback_idx_list = self.get_best_choices(context) + if feedback_idx_list is None: + return None + choices = [ + self.get_choice(feedback_idx[1]) for feedback_idx in feedback_idx_list + ] + choices = [choice for choice in choices if choice is not None] + return choices + + def get_best_choices(self, context: AHContext) -> Optional[list[tuple[float, int]]]: + return [] diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/__init__.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/aoti_hipify_utils.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/aoti_hipify_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..eb71d4ee7f392121a5589bae04e4c3ddb6f54025 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/aoti_hipify_utils.py @@ -0,0 +1,31 @@ +import re + +import torch +from torch.utils.hipify.hipify_python import PYTORCH_MAP, PYTORCH_TRIE + + +# It is not a good idea to directly apply hipify_torch to codegen, which will be vulnerable to cases like: +# "... +# from ..codecache import CudaKernelParamCache +# ..." 
+# In such cases, we do not need to hipify_torch the original class/file name in codegen/codecache + + +def maybe_hipify_code_wrapper(source_codes: str, force_hipify: bool = False) -> str: + if torch.version.hip is None and not force_hipify: + return source_codes + + def c2_repl(m: re.Match[str]) -> object: + return PYTORCH_MAP[m.group(0)] + + # We need to redefine RE_PYTORCH_PREPROCESSOR here since in hipify_torch, + # it will apply positive lookbehind (?<=\W) to the pattern to avoid matching + # a keyword at the beginning of a code line. However, this can happen in codegen, + # which will cause the pattern to not match. + + # Note that lookahead (?=\W) is still needed to keep hipification idempotent, for example + # we need to skip replacing "getStreamFromExternal" in "getStreamFromExternalMasqueradingAsCUDA" + RE_PYTORCH_PREPROCESSOR = re.compile(rf"({PYTORCH_TRIE.export_to_regex()})(?=\W)") + + source_codes = RE_PYTORCH_PREPROCESSOR.sub(c2_repl, source_codes) # type: ignore[arg-type] + return source_codes diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/aoti_runtime/interface.cpp b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/aoti_runtime/interface.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e3931e86dd13415cad03ecb95e18bc490c004073 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/aoti_runtime/interface.cpp @@ -0,0 +1,443 @@ +// Definition of AOTI runtime interface functions + +#include +#include + +#include +#include + +#define CONVERT_EXCEPTION_TO_ERROR_CODE(...) \ + try { \ + __VA_ARGS__ \ + } catch (const std::exception& e) { \ + std::cerr << "Error: " << e.what() << '\n'; \ + return AOTI_RUNTIME_FAILURE; \ + } catch (...) { \ + std::cerr << "Unknown exception occurred.\n"; \ + return AOTI_RUNTIME_FAILURE; \ + } \ + return AOTI_RUNTIME_SUCCESS; + +#define AOTI_VECTOR_SIZE_CHECK(actual_size, expected_size, name) \ + do { \ + AOTI_RUNTIME_CHECK( \ + actual_size == expected_size, \ + "expected " + std::string(name) + " vector size to be " + \ + std::to_string(expected_size) + ", but got " + \ + std::to_string(actual_size)); \ + } while (0) + +// AOTInductor uses at::addmm_out, which doesn't support +// arguments that require gradient. For this reason, we +// enforce no_grad context for run APIs. +// +// A RAII, thread local (!) guard that enables or disables grad mode upon +// construction, and sets it back to the original value upon destruction. +struct AOTINoGradGuard { + AOTINoGradGuard() { + aoti_torch_grad_mode_set_enabled(false); + } + AOTINoGradGuard(const AOTINoGradGuard&) = delete; + AOTINoGradGuard(AOTINoGradGuard&&) noexcept = delete; + ~AOTINoGradGuard() { + aoti_torch_grad_mode_set_enabled(prev_mode); + } + AOTINoGradGuard& operator=(const AOTINoGradGuard&) = delete; + AOTINoGradGuard& operator=(AOTINoGradGuard&&) noexcept = delete; + bool prev_mode{aoti_torch_grad_mode_is_enabled()}; +}; + +extern "C" { + +AOTIRuntimeError AOTInductorModelContainerCreate( + AOTInductorModelContainerHandle* container_handle, + size_t num_models, + bool is_cpu, + const char* cubin_dir) { + return AOTInductorModelContainerCreateWithDevice( + container_handle, + num_models, + is_cpu ? 
"cpu" : "cuda", + cubin_dir); +} + +AOTIRuntimeError AOTInductorModelContainerCreateWithDevice( + AOTInductorModelContainerHandle* container_handle, + size_t num_models, + const char* device_str, + const char* cubin_dir) { + if (num_models == 0) { + std::cerr << "Error: num_models must be positive, but got 0\n"; + return AOTI_RUNTIME_FAILURE; + } + CONVERT_EXCEPTION_TO_ERROR_CODE({ + std::optional cubin_dir_opt; + if (cubin_dir != nullptr) { + cubin_dir_opt.emplace(cubin_dir); + } + auto* container = new torch::aot_inductor::AOTInductorModelContainer( + num_models, std::string(device_str), cubin_dir_opt); + *container_handle = + reinterpret_cast(container); + }) +} + +AOTIRuntimeError AOTInductorModelContainerDelete( + AOTInductorModelContainerHandle container_handle) { + CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto* container = + reinterpret_cast( + container_handle); + delete container; + }); +} + +AOTIRuntimeError AOTInductorModelContainerRun( + AOTInductorModelContainerHandle container_handle, + AtenTensorHandle* input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + size_t num_inputs, + AtenTensorHandle* + output_handles, // array for writing output AtenTensorHandle; handles + // will be stolen by the caller; the array itself is + // borrowed + size_t num_outputs, + AOTInductorStreamHandle stream_handle, + AOTIProxyExecutorHandle proxy_executor_handle) { + auto* container = + reinterpret_cast( + container_handle); + AOTI_VECTOR_SIZE_CHECK(num_inputs, container->num_inputs(), "inputs"); + AOTI_VECTOR_SIZE_CHECK(num_outputs, container->num_outputs(), "outputs"); + + auto stream = + reinterpret_cast(stream_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + AOTINoGradGuard guard; + container->run( + input_handles, output_handles, stream, proxy_executor_handle); + }) +} + +AOTIRuntimeError AOTInductorModelContainerRunSingleThreaded( + AOTInductorModelContainerHandle container_handle, + AtenTensorHandle* input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + size_t num_inputs, + AtenTensorHandle* + output_handles, // array for writing output AtenTensorHandle; handles + // will be stolen by the caller; the array itself is + // borrowed + size_t num_outputs, + AOTInductorStreamHandle stream_handle, + AOTIProxyExecutorHandle proxy_executor_handle) { + auto* container = + reinterpret_cast( + container_handle); + AOTI_VECTOR_SIZE_CHECK(num_inputs, container->num_inputs(), "inputs"); + AOTI_VECTOR_SIZE_CHECK(num_outputs, container->num_outputs(), "outputs"); + + auto stream = + reinterpret_cast(stream_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + AOTINoGradGuard guard; + container->run_single_threaded( + input_handles, output_handles, stream, proxy_executor_handle); + }) +} + +AOTIRuntimeError AOTInductorModelContainerGetNumConstants( + AOTInductorModelContainerHandle container_handle, + size_t* num_constants) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *num_constants = container->num_constants(); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetConstantName( + AOTInductorModelContainerHandle container_handle, + size_t idx, + const char** name) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *name = container->constant_name(idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetConstantOriginalFQN( + AOTInductorModelContainerHandle container_handle, + size_t idx, + const char** 
original_fqn) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *original_fqn = container->constant_original_fqn(idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetConstantFromFolded( + AOTInductorModelContainerHandle container_handle, + size_t idx, + bool* from_folded) { + auto* container = + reinterpret_cast(container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ *from_folded = container->constant_from_folded(idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetConstantType( + AOTInductorModelContainerHandle container_handle, + size_t idx, + int32_t* type) { + auto* container = + reinterpret_cast(container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ *type = container->constant_type(idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetConstantDtype( + AOTInductorModelContainerHandle container_handle, + size_t idx, + int32_t* dtype) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *dtype = container->constant_dtype(idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetConstantDataSize( + AOTInductorModelContainerHandle container_handle, + size_t idx, + size_t* data_size) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *data_size = container->constant_data_size(idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerExtractConstantsMap( + AOTInductorModelContainerHandle container_handle, + AOTInductorConstantMapHandle constant_map_handle, + bool use_inactive) { + auto* container = + reinterpret_cast( + container_handle); + auto constants_map = reinterpret_cast*>(constant_map_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { const auto ret = container->extract_constants_map(use_inactive); + for (const auto& pair: ret) { + constants_map->emplace(pair.first, pair.second); + } + }) +} + +AOTIRuntimeError AOTInductorModelContainerUpdateUserManagedConstantBuffer( + AOTInductorModelContainerHandle container_handle, + AOTInductorConstantMapHandle constant_map_handle, + bool use_inactive, + bool validate_full_update) { + auto* container = + reinterpret_cast( + container_handle); + auto input_map = reinterpret_cast*>(constant_map_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + container->update_constant_buffer( + *input_map, use_inactive, validate_full_update, /* user_managed = */ true); + }) +} + +AOTIRuntimeError AOTInductorModelContainerUpdateConstantBuffer( + AOTInductorModelContainerHandle container_handle, + AOTInductorConstantMapHandle constant_map_handle, + bool use_inactive, + bool validate_full_update) { + auto* container = + reinterpret_cast( + container_handle); + auto input_map = reinterpret_cast*>(constant_map_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + container->update_constant_buffer( + *input_map, use_inactive, validate_full_update); + }) +} + +AOTIRuntimeError AOTInductorModelContainerUpdateInactiveConstantBuffer( + AOTInductorModelContainerHandle container_handle, + AOTInductorConstantMapHandle constant_map_handle) { + return AOTInductorModelContainerUpdateConstantBuffer(container_handle, + constant_map_handle, + /*use_inactive*/ true, + /*validate_full_update*/ true); +} + +AOTIRuntimeError AOTInductorModelContainerFreeInactiveConstantBuffer( + AOTInductorModelContainerHandle container_handle) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + container->free_inactive_constant_buffer(); + }) +} + +AOTIRuntimeError 
AOTInductorModelContainerRunConstantFolding( + AOTInductorModelContainerHandle container_handle, + bool use_inactive, + AOTInductorStreamHandle stream_handle, + AOTIProxyExecutorHandle proxy_executor_handle) { + auto* container = + reinterpret_cast( + container_handle); + auto stream = + reinterpret_cast(stream_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + AOTINoGradGuard guard; + container->run_const_fold(use_inactive, stream, proxy_executor_handle); + }) +} + +AOTIRuntimeError AOTInductorModelContainerSwapConstantBuffer( + AOTInductorModelContainerHandle container_handle) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + container->swap_constant_buffer(); + }) +} + +AOTIRuntimeError AOTInductorModelContainerGetNumInputs( + AOTInductorModelContainerHandle container_handle, + size_t* ret_num_inputs) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *ret_num_inputs = container->num_inputs(); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetInputName( + AOTInductorModelContainerHandle container_handle, + size_t input_idx, + const char** ret_input_names) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *ret_input_names = container->input_name(input_idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetNumOutputs( + AOTInductorModelContainerHandle container_handle, + size_t* ret_num_outputs) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *ret_num_outputs = container->num_outputs(); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetOutputName( + AOTInductorModelContainerHandle container_handle, + size_t output_idx, + const char** ret_output_names) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *ret_output_names = container->output_name(output_idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetCallSpec( + AOTInductorModelContainerHandle container_handle, + const char** in_spec, + const char** out_spec) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + *in_spec = container->get_in_spec(); + *out_spec = container->get_out_spec(); + }) +} + +AOTIRuntimeError AOTInductorModelCreate( + AOTInductorModelHandle* model_handle, + AOTInductorConstantMapHandle constant_map_handle){ + CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto constant_map = std::make_shared(); + auto constant_array = std::make_shared>(); + auto input_map = reinterpret_cast*>(constant_map_handle); + + auto model = new torch::aot_inductor::AOTInductorModel( + constant_map, + constant_array, + "cpu", // device_str is hardcoded, as AOTInductorModelCreate is only use for CPU models + "" + ); + + if (input_map) { + for (auto const& kv : *input_map) { + constant_map->emplace(kv.first, kv.second); + } + } else { + model->load_constants(); + } + + *model_handle = reinterpret_cast(model); + })} + +AOTIRuntimeError AOTInductorModelRun( + AOTInductorModelHandle model_handle, + AtenTensorHandle* input_handles, + AtenTensorHandle* output_handles) { + auto model = + reinterpret_cast(model_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + AOTINoGradGuard guard; + model->run_impl( + input_handles, + output_handles, + (torch::aot_inductor::DeviceStreamType) nullptr, + nullptr); + }) +} + +AOTIRuntimeError AOTInductorModelDelete(AOTInductorModelHandle model_handle){ + CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto model = 
reinterpret_cast( + model_handle); + delete model; + })} + +AOTIRuntimeError AOTInductorModelGetNumOutputs( + AOTInductorModelHandle model_handle, + size_t* ret_num_outputs) { + CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto model = reinterpret_cast(model_handle); + *ret_num_outputs = model->num_outputs(); + }) +} + +AOTIRuntimeError AOTInductorModelUpdateConstantsMap( + AOTInductorModelHandle model_handle, + AOTInductorConstantMapHandle constant_map_handle) { + auto model = + reinterpret_cast(model_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto constant_map = std::make_shared(); + auto input_map = + reinterpret_cast*>( + constant_map_handle); + + for (auto const& kv : *input_map) { + constant_map->emplace(kv.first, kv.second); + } + model->update_constants_map(std::move(constant_map)); + }) +} + +} // extern "C" diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/block_analysis.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/block_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..b99f7f786cff295591897bccd10fda34c7faff59 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/block_analysis.py @@ -0,0 +1,175 @@ +import collections +import functools +import textwrap +from typing import Optional + +import sympy +from sympy import Expr, Symbol + +from torch.utils._sympy.functions import FloorDiv, ModularIndexing + +from ..utils import sympy_dot, sympy_subs +from ..virtualized import V + + +class BlockPatternMatcher: + """ + Matches block indexing expressions. + """ + + @classmethod + def get_subexpr_involving_symbol(cls, expr: Expr, symbol: Symbol) -> Expr: + """ + Given a sympy expression, return the subexpression comprised only of terms + involving the specified symbol. + + For example, if `expr` is `x * 5 + x ** 2 + y * 2 + 5`, and `symbol` is `x`, + this returns `x * 5 + x ** 2`. + """ + expr = cls._preprocess(expr) + return sympy.S.Zero + sum( + term for term in sympy.Add.make_args(expr) if symbol in term.free_symbols + ) + + @staticmethod + def get_slice_numels(dims: list[Expr]) -> list[Expr]: + """ + Compute the cumulative size of each dimension's slice. + This proceeds from the last dim up to the second. + """ + numels = collections.deque([sympy.S.One]) + for dim in dims[:0:-1]: + numel = dim * numels[0] + numels.appendleft(numel) + return [*numels] + + @staticmethod + def _preprocess(expr: Expr) -> Expr: + # Remove any Identity nodes, e.g. expand x + (5 * y) to x + 5 * y. + return expr.expand(identity=True) + + @classmethod + def match_mod_div_block_expr( + cls, + index: Expr, + index_var: Symbol, + numel: Expr, + num_dims: int, + ) -> Optional[tuple[list[Expr], list[Expr], list[Expr]]]: + """ + Matches modular indexing expressions, converting them to implied block dimensions and strides. + See triton.py for more information. + """ + index = cls._preprocess(index) + + # Pattern match to find the strides and offset. + wild = functools.partial(sympy.Wild, exclude=[index_var]) + dims: list[Expr] = [wild(f"dim_mod{idx}") for idx in range(num_dims)] + strides: list[Expr] = [wild(f"stride_mod{idx}") for idx in range(num_dims)] + + # The first dimension's index is computed by division. + # The remaining are computed by modulo. + slice_numels = cls.get_slice_numels(dims[:num_dims]) + block_index_exprs = [FloorDiv(index_var, slice_numels[0])] + [ + ModularIndexing(index_var, numel, dim) + for dim, numel in zip(dims[1:], slice_numels[1:]) + ] + + # Calculate a linear index from block indices. 
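+        # (Illustration, not an exhaustive description: for num_dims == 2 the
+        # expression built below is
+        #     stride_mod0 * FloorDiv(index_var, dim_mod1)
+        #         + stride_mod1 * ModularIndexing(index_var, 1, dim_mod1),
+        # i.e. the flattened index of a row-major 2D block; the wildcard dims
+        # and strides are then solved for by the match further down.)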
+ match_expr = sympy_dot(strides, block_index_exprs) + + # Heuristic: if the number of dimensions is high, check that the minimum requirements + # are met before attempting an expensive full match. see triton.py:match_mod_div_block + # for more details. In short, here we check that each subexpression in sympy.Add contains + # only FloorDiv or ModularIndexing expressions. + if num_dims >= 5: + stride, denom, other = sympy.symbols("stride denominator other", cls=wild) + mod_div_pattern = stride * ModularIndexing(index_var, denom, other) + floor_div_pattern = stride * FloorDiv(index_var, denom) + first_dim_floor_div_matched = False + match_failed = False + for arg in sympy.Add.make_args(index): + if arg.match(floor_div_pattern): + # There should only be a single FloorDiv(index, denom) expression + # corresponding to the first dimension + if first_dim_floor_div_matched: + match_failed = True + break + first_dim_floor_div_matched = True + elif arg.match(mod_div_pattern): + continue + else: + match_failed = True + break + + if match_failed: + return None + + # Pattern match. + match = index.match(match_expr) + if match is None: + return None + + # Provide default values for unmatched dims and strides. + for dim in dims[1:]: + if dim not in match: + match[dim] = sympy.S.One + for stride in strides[1:]: + if stride not in match: + match[stride] = sympy.S.Zero + + sizevars = V.graph.sizevars + + def get_match(expr: Expr) -> Expr: + return sizevars.lookup_precomputed_size(match[expr]) + + # Replace wildcards with matched expressions. + dims = [dims[0]] + [get_match(dim) for dim in dims[1:]] + strides = [get_match(stride) for stride in strides] + slice_numels = cls.get_slice_numels(dims) + block_index_exprs = [sympy_subs(expr, match) for expr in block_index_exprs] + + # The leading dimension is not directly matched in our expression. + # We solve for it by dividing the range tree numel by the product of + # all other dimensions. We quit if they are not known to be divisible. + assert dims[0] not in match, "Expected not to match the leading dimension!" + if not sizevars.statically_known_multiple_of(numel, slice_numels[0]): + return None + dims[0] = numel / slice_numels[0] + + # Sanity check that we can recover the index from the matched subexpressions. + matched_index = sympy_dot(strides, block_index_exprs) + assert sizevars.statically_known_equals( + # New precomputed replacements may be generated when the `get_match` function + # above is called, but the `index` that is being matched has not been updated. + # So remove them when checking for equivalence e.g. if ps0=3*s0 and + # index=3*s0*expr, matched_index=ps0*expr, then index == matched_index + sizevars.remove_precomputed_replacements(matched_index), + sizevars.remove_precomputed_replacements(index), + ), textwrap.dedent( + f""" + Invalid match! + Index: {index} + Matched expression: {matched_index} + """ + ) + + return dims, strides, block_index_exprs + + @classmethod + def match_affine_block_expr( + cls, + index: Expr, + index_var: Symbol, + ) -> Optional[Expr]: + """ + Matches simple expressions of the form stride * index, returning the + stride. 
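+
+        As an illustration, an index of `3 * xindex` with `index_var` set to
+        `xindex` yields a stride of `3`; an index that involves the index
+        variable non-affinely (e.g. inside `ModularIndexing`) does not match
+        and `None` is returned.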
+ """ + index = cls._preprocess(index) + stride = sympy.Wild("stride", exclude=[index_var]) + m = index.match(index_var * stride) + if m is None: + return None + + return m[stride] diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/common.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/common.py new file mode 100644 index 0000000000000000000000000000000000000000..828050d6da1402be9e29fcf0c59d7c23c834267e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/common.py @@ -0,0 +1,2691 @@ +from __future__ import annotations + +import atexit +import contextlib +import dataclasses +import enum +import functools +import itertools +import logging +import math +import operator +import os +import re +import tempfile +from abc import ABC, abstractmethod +from enum import auto, Enum +from itertools import chain +from typing import ( + Any, + Callable, + cast, + ClassVar, + Generic, + NamedTuple, + Optional, + TYPE_CHECKING, + Union, +) +from typing_extensions import Self, TypeVar + +import sympy + +import torch +import torch.fx +from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND +from torch.utils import _pytree as pytree +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.numbers import int_oo +from torch.utils._sympy.printers import PythonPrinter as _PythonPrinter +from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT +from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges + +from .. import config, metrics +from ..dtype_propagation import DtypePropagationOpsHandler +from ..ops_handler import BasicMathOpsMixin, DefaultHandler +from ..utils import ( + boolean_ops, + DeferredLineBase, + generate_assert, + get_current_backend, + IndentedBuffer, + ir_dataclass, + ScopedDict, + sympy_dot, + sympy_index_symbol, + sympy_subs, + triton_type, + unique, +) +from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V + + +if TYPE_CHECKING: + from collections.abc import Iterator, MutableMapping, Sequence + + from torch.fx import GraphModule + + from ..custom_graph_pass import CustomGraphModulePass + from ..ir import Buffer, ChoiceCaller, FixedLayout, IRNode + from ..loop_body import LoopBody + from ..scheduler import BaseScheduling, Scheduler, SchedulerNode + from .wrapper import PythonWrapperCodegen + + _T = TypeVar("_T") + SchedulingConstructor = Callable[[Optional[Scheduler]], BaseScheduling] + WrapperConstructor = type[PythonWrapperCodegen] + SymbolLike = Union[str, sympy.Symbol] + + # OpVarT should really be Union[CSEVariable, str], however this + # causes typing errors in subclasses (defined in other files). + OpVarT = str + +schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") +log = logging.getLogger(__name__) + + +def data_type_logger(msg: str) -> None: + if schedule_log.isEnabledFor(logging.DEBUG): + schedule_log.debug("Data type propagation: %s", msg) + + +@dataclasses.dataclass +class FileBackedGraphModule: + """ + Output of FX wrapper codegen. Exposes the same methods as ModuleType, but these + map back to a GraphModule instead of Python source. + """ + + gm: GraphModule + compiled_fn: Callable[..., Any] + + def __post_init__(self) -> None: + # Write the code to a file for compatibility with debugging utilities. + # The file is deleted upon program termination. 
+ self.tempfile = tempfile.NamedTemporaryFile( + mode="w+", suffix=".py", delete=False + ) + atexit.register(os.remove, self.tempfile.name) + with self.tempfile as f: + f.write(self.value) + + @property + def __file__(self) -> str: + return self.tempfile.name + + def call(self, args: list[Any]) -> Any: + return self.compiled_fn(*args) + + @property + def value(self) -> str: + return self.gm.code + + +class WorkspaceZeroMode(enum.Enum): + UNINITIALIZED = 0 + ZERO_ON_CALL = 1 # kernel may leave workspace dirty + ZERO_PER_GRAPH = 2 # must be re-zeroed by kernel + + @staticmethod + def combine(a: WorkspaceZeroMode, b: WorkspaceZeroMode) -> WorkspaceZeroMode: + if a == b or b == WorkspaceZeroMode.UNINITIALIZED: + return a + if a == WorkspaceZeroMode.UNINITIALIZED: + return b + raise NotImplementedError(f"WorkspaceZeroMode.combine({a!r}, {b!r})") + + @staticmethod + def from_bool(zero_fill: bool) -> WorkspaceZeroMode: + if zero_fill: + return WorkspaceZeroMode.ZERO_ON_CALL + return WorkspaceZeroMode.UNINITIALIZED + + +class CodegenSymbol(ABC): + """ + An IR object possibly corresponding to a variable in the wrapper code. + """ + + @abstractmethod + def get_name(self) -> str: + pass + + @abstractmethod + def get_example(self) -> Union[torch.Tensor, sympy.Symbol]: + pass + + +@ir_dataclass(frozen=True) +class WorkspaceArg(CodegenSymbol): + """A temporary buffer used for a single kernel, then discarded. + + Not registered as a traditional buffer since there are no users, + so it would be dead code eliminated. + + Args: + nbytes: The size of the buffer in bytes. + zero_fill: Whether the buffer should be initialized to zero. + + """ + + count: sympy.Expr + zero_mode: WorkspaceZeroMode + device: torch.device + outer_name: str + inner_name: str = "ws_ptr" + dtype: torch.dtype = torch.uint8 + + @staticmethod + def unique_name(prefix: str = "workspace_") -> str: + return f"{prefix}{next(V.graph.workspace_id)}" + + @staticmethod + def can_join(a: WorkspaceArg, b: WorkspaceArg) -> bool: + return ( + a.inner_name == b.inner_name and a.dtype == b.dtype and a.device == b.device + ) + + @staticmethod + def join(a: WorkspaceArg, b: WorkspaceArg) -> WorkspaceArg: + return WorkspaceArg( + count=a.count + b.count, + zero_mode=WorkspaceZeroMode.combine(a.zero_mode, b.zero_mode), + dtype=a.dtype, + device=a.device, + inner_name=a.inner_name, + outer_name=a.outer_name, + ) + + @staticmethod + def maximum(a: WorkspaceArg, b: WorkspaceArg) -> WorkspaceArg: + assert ( + a.dtype == b.dtype and a.device == b.device and a.inner_name == b.inner_name + ) + return WorkspaceArg( + count=sympy.Max(a.count, b.count), + zero_mode=WorkspaceZeroMode.combine(a.zero_mode, b.zero_mode), + dtype=a.dtype, + device=a.device, + inner_name=a.inner_name, + outer_name=a.outer_name, + ) + + # These methods let WorkspaceArg pretend it is a buffer to reuse allocation code + def get_device(self) -> torch.device: + return self.device + + get_device_or_error = get_device + + def get_dtype(self) -> torch.dtype: + return self.dtype + + def get_example(self) -> Union[torch.Tensor, sympy.Symbol]: + return self.get_layout().get_example() + + def get_layout(self) -> FixedLayout: + from ..ir import FixedLayout + + return FixedLayout( + device=self.device, + dtype=self.dtype, + size=[self.count], + stride=[1], + ) + + @property + def layout(self) -> FixedLayout: + return self.get_layout() + + get_output_spec = get_layout + maybe_get_output_spec = get_layout + maybe_get_layout = get_layout + + def get_offset(self) -> sympy.Expr: + return sympy.S.Zero + + 
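+    # (Illustration: allocation code can treat a WorkspaceArg like a 1-D buffer;
+    # e.g. with count=1024 and the default uint8 dtype, get_size()/get_stride()
+    # below report a [1024]-element, stride-[1] layout.)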
def get_size(self) -> list[sympy.Expr]: + return [self.count] + + def get_stride(self) -> list[sympy.Expr]: + return [sympy.S.One] + + def get_name(self) -> str: + return self.outer_name + + def get_inputs_that_alias_output(self) -> list[str]: + return [] + + +class TritonScratchWorkspace: + def __init__(self, size: int, generate_dtype_str: Callable[..., str]): + self.size = size + self._generate_dtype_str = generate_dtype_str + + def generate_dtype_str(self) -> str: + return self._generate_dtype_str() + + +@dataclasses.dataclass +class TensorArg: + name: str + buffer: str + dtype: torch.dtype + offset: sympy.Expr = sympy.S.Zero # c++ only + alias_of: Optional[str] = None # halide only + + +@dataclasses.dataclass +class SizeArg: + name: str + expr: sympy.Expr + + @property + def alias_of(self) -> Optional[str]: + return None + + +@dataclasses.dataclass +class ConstexprArg: + name: str + + +@dataclasses.dataclass +class TMADescriptorArg: + name: str + api_type: str # "experimental" or "stable" + block_shape: Optional[list[sympy.Expr]] # only needed for "stable" + dtype: Optional[torch.dtype] # only needed for "stable" + + +@dataclasses.dataclass +class DeviceCodegen: + scheduling: SchedulingConstructor + wrapper_codegen: WrapperConstructor + cpp_wrapper_codegen: Optional[WrapperConstructor] = None + + +KernelArgType = Union[WorkspaceArg, TensorArg, SizeArg, TMADescriptorArg, ConstexprArg] + +device_codegens: dict[str, DeviceCodegen] = {} + + +class DeviceOpOverrides: + def import_get_raw_stream_as(self, name: str) -> str: + raise NotImplementedError + + def set_device(self, device_idx: int) -> str: + raise NotImplementedError + + def synchronize(self) -> str: + raise NotImplementedError + + def device_guard(self, device_idx: int) -> str: + raise NotImplementedError + + def cpp_device_guard(self) -> str: + raise NotImplementedError + + def cpp_aoti_device_guard(self) -> str: + raise NotImplementedError + + def cpp_stream_guard(self) -> str: + raise NotImplementedError + + def cpp_aoti_stream_guard(self) -> str: + raise NotImplementedError + + def cpp_getStreamFromExternal(self) -> str: + raise NotImplementedError + + def kernel_header(self) -> str: + raise NotImplementedError + + def kernel_driver(self) -> str: + raise NotImplementedError + + def cpp_stream_type(self) -> str: + raise NotImplementedError + + def aoti_get_stream(self) -> str: + raise NotImplementedError + + def cpp_kernel_type(self) -> str: + raise NotImplementedError + + def cpp_device_ptr(self) -> str: + raise NotImplementedError + + def tma_descriptor_helpers(self) -> str: + raise NotImplementedError + + def cpp_global_scratch( + self, idx: int, workspace: TritonScratchWorkspace + ) -> Optional[tuple[list[str], str]]: + # optionally return (scratch definition, arg name) + raise NotImplementedError + + +device_op_overrides_dict: dict[str, DeviceOpOverrides] = {} +custom_backend_passes: dict[str, Optional[CustomGraphModulePass]] = {} + + +# The code generated by Inductor consists of two main parts: kernel code and wrapper code. +# For any new backend looking to integrate with Inductor, customization of these two main +# parts are necessary to generate its specific code. +# +# Kernel code generation is determined by different Scheduling. Consequently, a new +# backend needs to provide a custom Scheduling for its unique kernel code generation. Currently, +# CppScheduling and TritonScheduling serve the C++/OpenMP and Triton backends, respectively. 
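+#
+# As a rough illustration of the registration API mentioned below, an
+# out-of-tree backend that already ships its own Scheduling and wrapper
+# codegen classes (the names here are hypothetical) would only need a call
+# like
+#     register_backend_for_device("my_device", MyScheduling, MyWrapperCodegen)
+# at import time for Inductor to route codegen for that device to it.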
+# +# For the Wrapper, Inductor provides a PythonWrapperCodegen class to generate the Python wrapper code +# that bridges kernels. This allows out-of-tree backends to inherit from PythonWrapperCodegen, +# and override specific member functions to create backend-specific Python wrapper code. +# +# Other classes, such as CppKernel and TritonKernel, used for code generation, typically form part +# of the logic for either Scheduling or PythonWrapperCodegen. So the Scheduling and PythonWrapperCodegen interfaces +# provide flexibility to the backend. A backend can choose to implement these classes from scratch, +# or reuse them by extending and overriding as necessary. And Inductor provides the registration API, +# register_backend_for_device, to equip a new backend at runtime. +# +# Intel has developed a new backend on top of Triton to support Intel GPUs, leveraging these interfaces. +# This backend can be used as a reference: +# https://github.com/intel/intel-extension-for-pytorch/blob/5dcc9d57e5422cf295e1a1ee97896d6b6a554a85/intel_extension_for_pytorch/_inductor/__init__.py#L9 +def register_backend_for_device( + device: str, + device_scheduling: SchedulingConstructor, + device_wrapper_codegen: WrapperConstructor, + device_cpp_wrapper_codegen: Optional[WrapperConstructor] = None, + device_custom_pass: Optional[CustomGraphModulePass] = None, +) -> None: + device_codegens[device] = DeviceCodegen( + device_scheduling, device_wrapper_codegen, device_cpp_wrapper_codegen + ) + custom_backend_passes[device] = device_custom_pass + + +class BackendFeature(Enum): + FOREACH = auto() + BUCKETIZE = auto() + INPLACE_BUFFERS = auto() + MASKED_SCATTER_WITH_INDEX = auto() + SCAN = auto() + SORT = auto() + TUPLE_REDUCTION = auto() + PREFER_STORE_LOOP_ORDER = auto() + TRITON_TEMPLATES = auto() + REDUCE_TO_SINGLE_ELEMENT = auto() + + +def get_backend_features( + device: Union[torch.device, str, None], +) -> OrderedSet[BackendFeature]: + if device is None: + return OrderedSet() + init_backend_registration() + if isinstance(device, torch.device): + device_type = device.type + else: + assert isinstance(device, str), type(device) + device_type = device + device = torch.device(device_type) + scheduling_ctor = get_scheduling_for_device(device_type) + assert scheduling_ctor + scheduling = scheduling_ctor(None) + return scheduling.get_backend_features(device) + + +def has_backend_feature( + device: Union[torch.device, str, None], feature: BackendFeature +) -> bool: + """See also V.graph.has_feature""" + assert isinstance(feature, BackendFeature) + return feature in get_backend_features(device) + + +def get_scheduling_for_device(device: str) -> Optional[SchedulingConstructor]: + return device_codegens[device].scheduling if device in device_codegens else None + + +def get_wrapper_codegen_for_device( + device: str, cpp_wrapper: bool = False +) -> Optional[WrapperConstructor]: + if device in device_codegens: + wrapper_codegen_obj: DeviceCodegen = device_codegens[device] + return ( + wrapper_codegen_obj.cpp_wrapper_codegen + if cpp_wrapper + else wrapper_codegen_obj.wrapper_codegen + ) + return None + + +def get_custom_backend_pass_for_device(device: str) -> Optional[CustomGraphModulePass]: + return custom_backend_passes[device] if device in custom_backend_passes else None + + +@functools.cache +def init_backend_registration() -> None: + from .cpp import CppScheduling + from .cpp_wrapper_cpu import CppWrapperCpu + from .cpp_wrapper_cpu_array_ref import CppWrapperCpuArrayRef + from .cpp_wrapper_gpu import CppWrapperGpu + from 
.cpp_wrapper_mps import CppWrapperMps + from .cuda_combined_scheduling import CUDACombinedScheduling + from .halide import HalideScheduling + from .mps import MetalScheduling + from .triton import TritonScheduling + from .wrapper import PythonWrapperCodegen + + if get_scheduling_for_device("cpu") is None: + cpu_backends = { + "cpp": CppScheduling, + "halide": HalideScheduling, + "triton": TritonScheduling, + } + register_backend_for_device( + "cpu", + lambda scheduling: cpu_backends[config.cpu_backend](scheduling), + PythonWrapperCodegen, + CppWrapperCpuArrayRef + if config.aot_inductor.allow_stack_allocation + else CppWrapperCpu, + ) + + if get_scheduling_for_device("cuda") is None: + # CUDACombinedScheduling combines Triton and CUDA C++ scheduling for CUDA devices via delegation + cuda_backends = { + "triton": CUDACombinedScheduling, + "halide": HalideScheduling, + } + register_backend_for_device( + "cuda", + lambda scheduling: cuda_backends[config.cuda_backend](scheduling), + PythonWrapperCodegen, + CppWrapperGpu, + ) + + if get_scheduling_for_device("xpu") is None: + register_backend_for_device( + "xpu", + TritonScheduling, + PythonWrapperCodegen, + CppWrapperGpu, + ) + + if get_scheduling_for_device("mps") is None: + register_backend_for_device( + "mps", + MetalScheduling, + PythonWrapperCodegen, + CppWrapperMps, + ) + + private_backend = torch._C._get_privateuse1_backend_name() + if ( + private_backend != "privateuseone" + and get_scheduling_for_device(private_backend) is None + ): + from torch.utils.backend_registration import _get_custom_mod_func + + try: + device_scheduling = _get_custom_mod_func("Scheduling") + wrapper_codegen = _get_custom_mod_func("PythonWrapperCodegen") + cpp_wrapper_codegen = _get_custom_mod_func("CppWrapperCodegen") + if device_scheduling and wrapper_codegen and cpp_wrapper_codegen: + register_backend_for_device( + private_backend, + device_scheduling, + wrapper_codegen, + cpp_wrapper_codegen, + ) + except RuntimeError: + pass + + +def index_prevent_reordering( + index: Sequence[sympy.Expr], + index_vars: Sequence[sympy.Expr], + sizes: Sequence[sympy.Expr], +) -> list[sympy.Expr]: + from ..ir import FlexibleLayout + + # added contiguous index prevents reordering + return [*index, sympy_dot(index_vars, FlexibleLayout.contiguous_strides(sizes))] + + +def register_device_op_overrides( + device: str, device_op_overrides: DeviceOpOverrides +) -> None: + device_op_overrides_dict[device] = device_op_overrides + + +def get_device_op_overrides(device: str) -> DeviceOpOverrides: + assert isinstance(device, str), type(device) + + if not device_op_overrides_dict: + from . 
import cpu_device_op_overrides, mps_device_op_overrides # noqa: F401 + from .cuda import device_op_overrides # noqa: F401 + from .xpu import device_op_overrides as xpu_op_overrides # noqa: F401 + + return device_op_overrides_dict[device] + + +DTYPE_TO_COMPUTATION_DTYPE: dict[torch.dtype, torch.dtype] = { + torch.bfloat16: torch.float, + torch.float16: torch.float, + **{ + dtype: dtype + for dtype in [ + torch.bool, + torch.float32, + torch.float64, + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.uint8, + torch.uint16, + torch.uint32, + torch.uint64, + ] + }, +} + + +def deduce_output_dtype_by_name( + op_name: str, + *args: Any, + **kwargs: Any, +) -> Optional[torch.dtype]: + """ + Given op name and a list of input dtypes, deduce the output dtype + """ + if op_name in boolean_ops(): + return torch.bool + elif op_name in ( + "to_dtype", + "index_expr", + ): + return kwargs["dtype"] if "dtype" in kwargs else args[-1] + elif op_name in ( + "rand", + "randn", + ): + return torch.float + elif op_name in ( + "get_index", + "randint64", + "load_seed", + ): + return torch.int64 + elif op_name == "reduction": + return kwargs["dtype"] if "dtype" in kwargs else args[1] + elif op_name == "constant": + return kwargs["dtype"] if "dtype" in kwargs else args[-1] + elif op_name in ( + "load", + "store", + "store_reduction", + ): + buf_name = args[1] + return V.graph.get_dtype(buf_name) # type: ignore[arg-type] + elif op_name == "to_dtype_bitcast": + return kwargs["dtype"] if "dtype" in kwargs else args[-2] + return None + + +def check_dtype( + buffer: IndentedBuffer, var: CSEVariableType, dtype: torch.dtype +) -> None: + backend = get_current_backend() + if config.test_configs.runtime_triton_dtype_assert and backend == "triton": + buffer.writeline(f"tl.static_assert({var}.dtype == {triton_type(dtype)})") + elif config.test_configs.static_cpp_dtype_assert and backend == "cpp": + from .cpp_utils import CppCSEVariable, DTYPE_TO_CPP + + assert isinstance(var, CppCSEVariable), type(var) + if dtype == torch.bool: + if var.is_vec: + is_same_dt = f"IsVecMaskType::value" + else: + # operator&(bool, bool) returns int and it can be used as boolean in C++ + is_same_dt = f"std::is_same_v || std::is_same_v" + else: + c_var_type = f"decltype({var})" + if var.is_vec: + c_var_type = f"typename {c_var_type}::value_type" + is_same_dt = f"std::is_same_v<{c_var_type}, {DTYPE_TO_CPP[dtype]}>" + + buffer.writeline(f"static_assert({is_same_dt});") + + +class DataTypePropagation: + def __init__(self, body: LoopBody) -> None: + self.body = body + self.graphs: dict[Union[Callable[..., Any], str], Any] = { + "root": body.root_block.graph + } + for k, v in body.subblocks.items(): + self.graphs[k] = v.graph + + def deduce_node_dtype_by_inputs(self, node: torch.fx.Node) -> Optional[torch.dtype]: + inputs = node.all_input_nodes + input_nodes = [ + n for n in inputs if isinstance(n, torch.fx.Node) and n.op != "placeholder" + ] + if len(input_nodes) == 0: + return None + + all_input_nodes_propagated = all( + OptimizationContext.key in n.meta + and n.meta[OptimizationContext.key].dtype is not None + for n in input_nodes + ) + if not all_input_nodes_propagated: + return None + + return functools.reduce( + torch.promote_types, + [n.meta[OptimizationContext.key].dtype for n in input_nodes], + ) + + def deduce_node_dtype_by_subgraph(self, node: torch.fx.Node) -> torch.dtype: + sub_graph = self.graphs[node.target] + dtype = self.propagate_graph(sub_graph) + assert dtype + return dtype + + def deduce_node_dtype(self, node: 
torch.fx.Node) -> Optional[torch.dtype]: + if node.op == "placeholder": + return None + + if node.target == "output" and len(node.args) != 1: + # we can infer output node if it only have 1 arg + return None + + if node.target == operator.getitem: + node_arg = node.args[0] + assert isinstance(node_arg, torch.fx.Node), type(node_arg) + return self.deduce_node_dtype(node_arg) + + assert isinstance(node.target, str), type(node.target) + + if node.target.startswith("masked_subblock"): + return self.deduce_node_dtype_by_subgraph(node) + + if ( + output_dtype := deduce_output_dtype_by_name( + node.target, + *node.args, + **node.kwargs, + ) + ) is not None: + return output_dtype + + return self.deduce_node_dtype_by_inputs(node) + + def propagate_graph(self, graph: torch.fx.Graph) -> Optional[torch.dtype]: + assert graph.nodes + graph_dtype: Optional[torch.dtype] = None + # For masked_subblock, we use output's dtype to represent + # the dtype of this subgraph. For other cases, graph_dtype + # might be None + for node in graph.nodes: + if OptimizationContext.key in node.meta: + opt_ctx = node.meta[OptimizationContext.key] + else: + opt_ctx = OptimizationContext() + + opt_ctx.dtype = self.deduce_node_dtype(node) + node.meta[OptimizationContext.key] = opt_ctx + if node.target == "output": + graph_dtype = opt_ctx.dtype + return graph_dtype + + def propagate(self) -> Optional[torch.dtype]: + return self.propagate_graph(self.graphs["root"]) + + @classmethod + def propagate_loopbody(cls, body: LoopBody) -> Optional[torch.dtype]: + return cls(body).propagate() + + @classmethod + def propagate_scheduler_node(cls, node: SchedulerNode) -> Optional[torch.dtype]: + from ..loop_body import LoopBody + from ..scheduler import SchedulerNode + + assert isinstance(node, SchedulerNode), type(node) + assert isinstance(node._body, LoopBody), type(node._body) + return DataTypePropagation.propagate_loopbody(node._body) + + +class PythonPrinter(_PythonPrinter): + def doprint( + self, expr: sympy.Expr, *, simplify: bool = True, p: bool = True + ) -> str: + # TODO: why are people passing strings to the printer here :think: + if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"): + expr = V.graph.sizevars.simplify(expr) + return super().doprint(expr) + + +class OpDecompositions: + """ + Decomposes inductor ops + """ + + @staticmethod + def identity(value: OpVarT) -> OpVarT: + # used to trigger cse + return value + + @staticmethod + def reciprocal(x: OpVarT) -> OpVarT: + return ops.truediv(ops.constant(1, torch.int32), x) + + @staticmethod + def square(x: OpVarT) -> OpVarT: + return ops.mul(x, x) + + @staticmethod + def erfc(x: OpVarT) -> OpVarT: + return ops.sub(ops.constant(1, torch.float32), ops.erf(x)) + + @staticmethod + def erfcx(x: OpVarT) -> OpVarT: + return ops.mul(ops.exp(ops.square(x)), ops.erfc(x)) + + @staticmethod + def expm1(x: OpVarT) -> OpVarT: + return ops.sub(ops.exp(x), ops.constant(1, torch.float32)) + + @staticmethod + def log10(x: OpVarT) -> OpVarT: + return ops.mul(ops.log(x), ops.constant(1 / math.log(10), torch.float32)) + + @staticmethod + def log2(x: OpVarT) -> OpVarT: + return ops.mul(ops.log(x), ops.constant(1 / math.log(2), torch.float32)) + + @staticmethod + def exp2(x: OpVarT) -> OpVarT: + return ops.exp(ops.mul(x, ops.constant(math.log(2), torch.float32))) + + @staticmethod + def log1p(x: OpVarT) -> OpVarT: + return ops.log(ops.add(x, ops.constant(1, torch.int32))) + + @staticmethod + def sigmoid(x: OpVarT) -> OpVarT: + one = ops.constant(1, torch.int32) + return 
ops.truediv(one, ops.add(one, ops.exp(ops.neg(x)))) + + @staticmethod + def relu(x: OpVarT) -> OpVarT: + return ops.maximum(x, ops.constant(0, torch.int32)) + + @staticmethod + def fma(x: OpVarT, y: OpVarT, z: OpVarT) -> OpVarT: + # for backends that don't override this (halide) + return ops.add(ops.mul(x, y), z) + + @staticmethod + def floor_to_int(a: OpVarT, dtype: torch.dtype) -> OpVarT: + return ops.to_dtype(ops.floor(a), dtype) + + @staticmethod + def ceil_to_int(a: OpVarT, dtype: torch.dtype) -> OpVarT: + return ops.to_dtype(ops.ceil(a), dtype) + + @staticmethod + def trunc_to_int(a: OpVarT, dtype: torch.dtype) -> OpVarT: + return ops.to_dtype(ops.trunc(a), dtype) + + @staticmethod + def remainder(a: OpVarT, b: OpVarT) -> OpVarT: + r = ops.mod(a, b) + cond = ops.and_( + ops.ne(r, ops.constant(0, torch.int32)), + ops.ne(ops.signbit(r), ops.signbit(b)), + ) + return ops.where(cond, ops.add(r, b), r) + + @staticmethod + def round_to_int(a: OpVarT, dtype: torch.dtype) -> OpVarT: + return ops.to_dtype(ops.round(a), dtype) + + +_RE_PAREN_NOT_NEEDED = re.compile(r"[a-z0-9_.]+|\([^)]*\)|", flags=re.IGNORECASE) + + +def _all_in_parens(string: str) -> bool: + if string[0] != "(" or len(string) < 2: + return False + count = 1 + for i, char in enumerate(string[1:]): + if char == "(": + count += 1 + elif char == ")": + count -= 1 + if count == 0 and i != len(string) - 2: + return False + assert count == 0 + return True + + +class OpOverrides(BasicMathOpsMixin, OpDecompositions, OpsHandler[Any]): + @staticmethod + def paren(string: OpVarT) -> OpVarT: + if ( + isinstance(string, CSEVariable) + or _RE_PAREN_NOT_NEEDED.fullmatch(string) + or _all_in_parens(string) + ): + # don't put extra parens for strings that are already wrapped in parens + return string + return f"({string})" + + @staticmethod + def constant(value: Union[bool, float, int], dtype: torch.dtype) -> OpVarT: + return repr(value) + + @staticmethod + def bitwise_not(x: OpVarT) -> OpVarT: + return f"~{OpOverrides.paren(x)}" + + @staticmethod + def logical_not(a: OpVarT) -> OpVarT: + return f"{OpOverrides.paren(a)} == 0" + + @staticmethod + def bitwise_and(x: OpVarT, y: OpVarT) -> OpVarT: + return f"{OpOverrides.paren(x)} & {OpOverrides.paren(y)}" + + @staticmethod + def bitwise_or(x: OpVarT, y: OpVarT) -> OpVarT: + return f"{OpOverrides.paren(x)} | {OpOverrides.paren(y)}" + + @staticmethod + def bitwise_xor(x: OpVarT, y: OpVarT) -> OpVarT: + return f"{OpOverrides.paren(x)} ^ {OpOverrides.paren(y)}" + + @staticmethod + def bitwise_left_shift(x: OpVarT, y: OpVarT) -> OpVarT: + return f"{OpOverrides.paren(x)} << {OpOverrides.paren(y)}" + + @staticmethod + def bitwise_right_shift(x: OpVarT, y: OpVarT) -> OpVarT: + return f"{OpOverrides.paren(x)} >> {OpOverrides.paren(y)}" + + @staticmethod + def int_truediv(a: OpVarT, b: OpVarT) -> OpVarT: + # TODO: this is wrong + # TODO: an easy bandaid is to generate runtime asserts that it's + # <= 2**53, which is when this equation is correct + return ops.truediv(a, b) + + @staticmethod + def load_seed(name: str, offset: OpVarT) -> OpVarT: + return ops.load(name, sympy.Integer(offset)) + + def indirect_indexing( + self, + var: OpVarT, + size: Union[sympy.Expr, int], + check: bool = True, + wrap_neg: bool = True, + ) -> sympy.Symbol: + return sympy_index_symbol(str(var)) + + def check_bounds( + self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool + ) -> None: + raise NotImplementedError( + f"{type(self).__name__}: check_bounds should be handled by CSEProxy" + ) + + def load(self, name: str, 
index: sympy.Expr) -> OpVarT: + raise NotImplementedError( + f"{type(self).__name__}: load should be handled by CSEProxy" + ) + + def store( + self, name: str, index: sympy.Expr, value: OpVarT, mode: StoreMode = None + ) -> None: + raise NotImplementedError( + f"{type(self).__name__}: store should be handled by CSEProxy" + ) + + def store_reduction(self, name: str, index: sympy.Expr, value: OpVarT) -> None: + raise NotImplementedError( + f"{type(self).__name__}: store_reduction should be handled by CSEProxy" + ) + + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[OpVarT, tuple[OpVarT, ...]], + ) -> Union[OpVarT, tuple[OpVarT, ...]]: + raise NotImplementedError( + f"{type(self).__name__}: reduction should be handled by CSEProxy" + ) + + def scan( + self, + dtypes: tuple[torch.dtype, ...], + combine_fn: Callable[ + [tuple[OpVarT, ...], tuple[OpVarT, ...]], + tuple[OpVarT, ...], + ], + values: tuple[OpVarT, ...], + ) -> tuple[OpVarT, ...]: + raise NotImplementedError( + f"{type(self).__name__}: scan should be handled by CSEProxy" + ) + + def sort( + self, + dtypes: tuple[torch.dtype, ...], + values: tuple[OpVarT, ...], + stable: bool, + descending: bool, + ) -> tuple[OpVarT, ...]: + raise NotImplementedError( + f"{type(self).__name__}: sort should be handled by CSEProxy" + ) + + def bucketize( + self, + values: OpVarT, + boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr], + boundary_indices: OpVarT, + indexing_dtype: torch.dtype, + right: bool, + sorter: Optional[tuple[str, sympy.Expr]] = None, + sorter_indices: Optional[OpVarT] = None, + ) -> OpVarT: + raise NotImplementedError( + f"{type(self).__name__}: bucketize should be handled by CSEProxy" + ) + + def halide_clamp(self, value: OpVarT, size: sympy.Expr, check: bool) -> OpVarT: + raise NotImplementedError( + f"{type(self).__name__}: halide_clamp only implemented for Halide backend" + ) + + def inline_asm_elementwise( + self, + *inputs: OpVarT, + asm: str, + constraints: Optional[str] = None, + dtype: torch.dtype = torch.float32, + is_pure: bool = True, + pack: int = 1, + ) -> OpVarT: + raise NotImplementedError( + f"{type(self).__name__}: inline_asm_elementwise only implemented for Triton backend" + ) + + def output(self, *args: OpVarT) -> None: + raise AssertionError( + f"{type(self).__name__}: ops.output should not appear at codegen time" + ) + + def placeholder(self, index: int) -> OpVarT: + raise AssertionError( + f"{type(self).__name__}: ops.placeholder should not appear at codegen time" + ) + + @staticmethod + def _unimplemented(name: str) -> Callable[..., OpVarT]: + def unimplemented(self: OpOverrides, *args: Any, **kwargs: Any) -> OpVarT: + raise NotImplementedError( + f"{type(self).__name__} does not implement ops.{name}" + ) + + unimplemented.__name__ = name + unimplemented.is_unimplemented = True # type: ignore[attr-defined] + return unimplemented + + @classmethod + def _is_unimplemented(cls, name: str) -> bool: + fn = getattr(cls, name, None) + default_fn = getattr(OpsHandler, name, None) + return not fn or fn == default_fn or getattr(fn, "is_unimplemented", False) + + @classmethod + def _initialize_pointwise_overrides(cls, target: str) -> None: + assert target in ("triton", "cpp", "cppvec", "halide", "mps"), target + + for funcname, data in pointwise_overrides_data.items(): + impl = getattr(data, target) + if impl is None: + if cls._is_unimplemented(funcname): + setattr(cls, funcname, cls._unimplemented(funcname)) + else: + assert funcname not in 
cls.__dict__, ( + f"multiple definitions of {funcname} on {cls.__name__}" + ) + impl.__name__ = funcname + setattr(cls, funcname, staticmethod(impl)) + + +@dataclasses.dataclass +class OverridesData: + name: str + cpp: Callable[..., str] + # None when not impl in libdevice/triton + triton: Optional[Callable[..., str]] = None + # None when not impl in aten/.../vec + cppvec: Optional[Callable[..., str]] = None + type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND = ( + ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + halide: Optional[Callable[..., str]] = None + mps: Optional[Callable[..., str]] = None + + +# NB: if you add a new special function, don't forget to update +# torch._inductor.ops_handler too +pointwise_overrides_data: dict[str, OverridesData] = dict( + airy_ai=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"airy_ai_forward({x})", + name="special_airy_ai", + ), + bessel_j0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"bessel_j0_forward({x})", + triton=lambda x: f"libdevice.j0({x})", + name="special_bessel_j0", + ), + bessel_j1=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"bessel_j1_forward({x})", + triton=lambda x: f"libdevice.j1({x})", + name="special_bessel_j1", + ), + bessel_y0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"bessel_y0_forward({x})", + triton=lambda x: f"libdevice.y0({x})", + name="special_bessel_y0", + ), + bessel_y1=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"bessel_y1_forward({x})", + triton=lambda x: f"libdevice.y1({x})", + name="special_bessel_y1", + ), + digamma=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"calc_digamma({x})", + cppvec=lambda x: f"{x}.digamma()", + name="digamma", + ), + # no cpp nor triton implementation for entr, it is defined as decomposition + # erf, erfc + erfcx=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"calc_erfcx({x})", + triton=lambda x: f"libdevice.erfcx({x})", + name="special_erfcx", + ), + fma=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y, z: f"std::fma({x}, {y}, {z})", + cppvec=lambda x, y, z: f"fmadd({x}, {y}, {z})", + triton=lambda x, y, z: f"libdevice.fma({x}, {y}, {z})", + name="fma", + ), + # erfinv, exp2, expit, gammaln + igamma=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"calc_igamma({x}, {y})", + name="igamma", + ), + igammac=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"calc_igammac({x}, {y})", + name="igammac", + ), + gammainc=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"calc_igamma({x}, {y})", + name="special_gammainc", + ), + gammaincc=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"calc_igammac({x}, {y})", + name="special_gammaincc", + ), + i0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"calc_i0({x})", + triton=lambda x: f"libdevice.cyl_bessel_i0({x})", + cppvec=lambda x: f"{x}.i0()", + name="i0", + ), + i0e=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + 
cpp=lambda x: f"calc_i0e({x})", + cppvec=lambda x: f"{x}.i0e()", + name="special_i0e", + ), + i1=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"calc_i1({x})", + triton=lambda x: f"libdevice.cyl_bessel_i1({x})", + name="special_i1", + ), + i1e=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"calc_i1e({x})", + name="special_i1e", + ), + log_ndtr=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"calc_log_ndtr({x})", + name="special_log_ndtr", + ), + # logit + modified_bessel_i0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"modified_bessel_i0_forward({x})", + triton=lambda x: f"libdevice.cyl_bessel_i0({x})", + name="special_modified_bessel_i0", + ), + modified_bessel_i1=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"modified_bessel_i1_forward({x})", + triton=lambda x: f"libdevice.cyl_bessel_i1({x})", + name="special_modified_bessel_i1", + ), + modified_bessel_k0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"modified_bessel_k0_forward({x})", + name="special_modified_bessel_k0", + ), + modified_bessel_k1=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"modified_bessel_k1_forward({x})", + name="special_modified_bessel_k1", + ), + # multigamma + ndtr=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"calc_ndtr({x})", + name="special_ndtr", + ), + ndtri=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"calc_ndtri({x})", + name="special_ndtri", + ), + polygamma=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, + y: f"{x} == 0 ? calc_digamma({y}) : ({x} == 1 ? 
trigamma({y}) : calc_polygamma({y}, {x}))", + name="polygamma", + ), + # psi - alias to digamma + # round + scaled_modified_bessel_k0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"scaled_modified_bessel_k0_forward({x})", + name="special_scaled_modified_bessel_k0", + ), + scaled_modified_bessel_k1=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"scaled_modified_bessel_k1_forward({x})", + name="special_scaled_modified_bessel_k1", + ), + # sinc + spherical_bessel_j0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"spherical_bessel_j0_forward({x})", + name="special_spherical_bessel_j0", + ), + zeta=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"zeta({x}, {y})", + name="special_zeta", + ), + chebyshev_polynomial_t=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"chebyshev_polynomial_t_forward({x}, {y})", + name="special_chebyshev_polynomial_t", + ), + chebyshev_polynomial_u=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"chebyshev_polynomial_u_forward({x}, {y})", + name="special_chebyshev_polynomial_u", + ), + chebyshev_polynomial_v=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"chebyshev_polynomial_v_forward({x}, {y})", + name="special_chebyshev_polynomial_v", + ), + chebyshev_polynomial_w=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"chebyshev_polynomial_w_forward({x}, {y})", + name="special_chebyshev_polynomial_w", + ), + legendre_polynomial_p=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"legendre_polynomial_p_forward({x}, {y})", + name="special_legendre_polynomial_p", + ), + shifted_chebyshev_polynomial_t=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"shifted_chebyshev_polynomial_t_forward({x}, {y})", + name="special_shifted_chebyshev_polynomial_t", + ), + shifted_chebyshev_polynomial_u=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"shifted_chebyshev_polynomial_u_forward({x}, {y})", + name="special_shifted_chebyshev_polynomial_u", + ), + shifted_chebyshev_polynomial_v=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"shifted_chebyshev_polynomial_v_forward({x}, {y})", + name="special_shifted_chebyshev_polynomial_v", + ), + shifted_chebyshev_polynomial_w=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"shifted_chebyshev_polynomial_w_forward({x}, {y})", + name="special_shifted_chebyshev_polynomial_w", + ), + hermite_polynomial_h=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"hermite_polynomial_h_forward({x}, {y})", + name="special_hermite_polynomial_h", + ), + hermite_polynomial_he=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"hermite_polynomial_he_forward({x}, {y})", + name="special_hermite_polynomial_he", + ), + laguerre_polynomial_l=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"laguerre_polynomial_l_forward({x}, 
{y})", + name="special_laguerre_polynomial_l", + ), +) + + +def is_buffer_removed(name: str) -> bool: + return any( + name in x + for x in ( + V.graph.removed_buffers, + V.kernel.removed_buffers, + V.graph.inplaced_to_remove, + V.kernel.inplaced_to_remove, + ) + ) + + +class DeferredLine(DeferredLineBase): + """A line that can be 'unwritten' by adding name to V.graph.removed_buffers""" + + def __init__(self, name: str, line: str): + super().__init__(line) + self.name = name + assert not isinstance(line, DeferredLineBase) + + def __call__(self) -> Optional[str]: + if not is_buffer_removed(self.name): + return self.line + return None + + def _new_line(self, line: str) -> DeferredLine: + return DeferredLine(self.name, line) + + +class BracesBuffer(IndentedBuffer): + def indent(self, offset: int = 1) -> contextlib.AbstractContextManager[None]: + @contextlib.contextmanager + def ctx() -> Iterator[None]: + for _ in range(offset): + self.writeline("{") + self._indent += 1 + for _ in range(-offset): + self._indent -= 1 + self.writeline("}") + yield + for _ in range(-offset): + self.writeline("{") + self._indent += 1 + for _ in range(offset): + self._indent -= 1 + self.writeline("}") + + return ctx() + + +class InplacedBuffer(NamedTuple): + inner_name: str + other_names: list[str] + + +@dataclasses.dataclass +class ArgName: + name: str + # is_constexpr=True is used to attach a " : tl.constexpr" into the argument list + is_constexpr: bool = False + + def full_name(self) -> str: + return f"{self.name}{' : tl.constexpr' if self.is_constexpr else ''}" + + +class RemovedArg: + def __str__(self) -> str: + return "REMOVED" + + +REMOVED = RemovedArg() + + +class KernelArgs: + @staticmethod + def _lookup( + prefix: str, + odict: Union[dict[_T, Union[str, RemovedArg]], dict[_T, str]], + name: _T, + ) -> str: + result: Union[str, RemovedArg] = odict.get(name, REMOVED) + if isinstance(result, RemovedArg): + odict[name] = new_result = f"{prefix}{len(odict)}" + return new_result + return result + + def __init__(self) -> None: + self.input_buffers: dict[str, str] = {} + self.output_buffers: dict[str, Union[str, RemovedArg]] = {} + self.inplace_buffers: dict[str, Union[InplacedBuffer, RemovedArg]] = {} + self.sizevars: dict[sympy.Expr, str] = {} + self.workspace_args: list[WorkspaceArg] = [] + + def __repr__(self) -> str: + return "KernelArgs({})".format( + ", ".join( + map( + repr, + [ + self.input_buffers, + self.output_buffers, + self.inplace_buffers, + self.sizevars, + ], + ) + ) + ) + + @staticmethod + def _buffer_is_marked_removed(name: Any) -> bool: + # this function is needed by MTIA + return isinstance(name, RemovedArg) + + def input(self, name: str) -> str: + if V.graph.scheduler: + name = V.graph.scheduler.mutation_real_name.get(name, name) + assert name not in V.graph.removed_buffers, name + if name in self.output_buffers: + return cast(str, self.output_buffers[name]) + if name in self.inplace_buffers: + return cast(InplacedBuffer, self.inplace_buffers[name]).inner_name + if name.startswith("seed"): + return self._lookup("seed", self.input_buffers, name) + return self._lookup("in_ptr", self.input_buffers, name) + + def output(self, name: str) -> str: + if V.graph.scheduler: + name = V.graph.scheduler.mutation_real_name.get(name, name) + assert name not in V.graph.removed_buffers, name + if name in self.inplace_buffers: + return cast(InplacedBuffer, self.inplace_buffers[name]).inner_name + return self._lookup("out_ptr", self.output_buffers, name) + + def make_inplace(self, input_name: str, output_name: 
str) -> None: + if input_name in V.graph.unaligned_buffers: + V.graph.unaligned_buffers.add(output_name) + assert output_name not in self.inplace_buffers, output_name + if input_name in self.inplace_buffers: + buf = self.inplace_buffers[input_name] + assert not isinstance(buf, RemovedArg) + buf.other_names.append(output_name) + self.inplace_buffers[output_name] = buf + else: + alive_buffers = [ + val + for val in self.inplace_buffers.values() + if not isinstance(val, RemovedArg) + ] + removed_buffers = [ + val + for val in self.inplace_buffers.values() + if isinstance(val, RemovedArg) + ] + inplace_buffer_idx = len(unique(alive_buffers)) + len(removed_buffers) + buf = InplacedBuffer( + f"in_out_ptr{inplace_buffer_idx}", + [input_name, output_name], + ) + self.inplace_buffers[input_name] = buf + self.inplace_buffers[output_name] = buf + + def workspace(self, nbytes: sympy.Expr, zero_fill: bool) -> tuple[str, int]: + """ + Allocate or extend a workspace buffer of nbytes bytes. + + This function manages the allocation of a workspace buffer. It either creates + a new WorkspaceArg or extends an existing one. + + Note: + - Calling this function will in-place mutate the args by adding or updating + a WorkspaceArg. + - The codegen for generating the Python argdefs and call_defs will check + this field and allocate the buffer accordingly. + - A new argument "ws_ptr" will be present in the generated code. + + Args: + nbytes (sympy.Expr): The number of bytes to allocate. + zero_fill (bool): Whether to initialize the buffer to zero. + + Returns: + Tuple[str, int]: A tuple containing: + - "ws_ptr": A string identifier for the workspace pointer. + - offset: An integer representing the byte offset in the workspace. + """ + arg = WorkspaceArg( + count=nbytes, + zero_mode=WorkspaceZeroMode.from_bool(zero_fill), + device=V.graph.get_current_device_or_throw(), + outer_name=WorkspaceArg.unique_name(), + ) + for i, existing_arg in enumerate(self.workspace_args): + if WorkspaceArg.can_join(existing_arg, arg): + offset = existing_arg.count + self.workspace_args[i] = WorkspaceArg.join(existing_arg, arg) + return existing_arg.inner_name, offset + assert ( + existing_arg.inner_name != arg.inner_name + and existing_arg.outer_name != arg.outer_name + ), existing_arg + self.workspace_args.append(arg) + return arg.inner_name, 0 + + def semaphores(self, min_size: sympy.Expr) -> str: + """ + Lazily allocate a graph-wide semaphores buffer with at least min_size. This is a single buffer shared by + all kernels and zero initialized once at graph start. Each kernel must leave the buffer zeroed on exit. + + Warning: multiple calls to this function will return the same buffer. 
+ + Args: + min_size: the number of int32 semaphores required + + Returns: + name of the semaphores buffer + """ + current_device = V.graph.get_current_device_or_throw() + arg = WorkspaceArg( + count=min_size, + zero_mode=WorkspaceZeroMode.ZERO_PER_GRAPH, + dtype=torch.uint32, + inner_name="sem_ptr", + outer_name=f"semaphores_{current_device.type}_{current_device.index}", + device=current_device, + ) + for existing_arg in self.workspace_args: + if existing_arg.inner_name == arg.inner_name: + assert arg == existing_arg, (arg, existing_arg) + self.workspace_args.append(arg) + return arg.inner_name + + def seed_offset(self, name: str, value: int) -> str: + assert isinstance(value, int), (type(value), value) + # here we are lifting a constant integer into an arg to the kernel to try to get additional cache hits + value = sympy.Integer(value) + if value in self.sizevars: + return self.sizevars[value] + if name in self.sizevars.values(): + name = ( + f"{name}{sum(1 for v in self.sizevars.values() if v.startswith(name))}" + ) + self.sizevars[value] = name + return name + + def size(self, name: sympy.Symbol) -> str: + assert isinstance(name, sympy.Symbol), (type(name), name) + if name.name == "seed": + self.sizevars[name] = "seed" # don't manage the name of seeds + return "seed" + return self._lookup("ks", self.sizevars, name) + + def call_names(self) -> Iterator[str]: + return chain( + self.input_buffers.keys(), self.output_buffers.keys(), self.sizevars.keys() + ) + + def arg_name(self, name: str) -> Optional[str]: + """ + Returns inner name of a given outer name. + """ + inplaced = self.inplace_buffers.get(name, None) + if inplaced is not None and not isinstance(inplaced, RemovedArg): + return inplaced.inner_name + output_name = self.output_buffers.get(name, None) + if output_name is not None and not isinstance(output_name, RemovedArg): + return output_name + return self.input_buffers.get(name, None) + + def wrap_ptr_arg(self, buf: str, dtype: torch.dtype) -> str: + return buf + + def wrap_size_arg(self, size: SymbolLike) -> str: + return str(size) + + def cpp_argdefs( + self, dtype_to_cpp_type: Optional[dict[torch.dtype, str]] = None + ) -> tuple[list[str], list[str], list[str]]: + from .cpp_utils import INDEX_TYPE + + if dtype_to_cpp_type is None: + from .cpp_utils import DTYPE_TO_CPP + + dtype_to_cpp_type = DTYPE_TO_CPP + + call_args = [] + arg_defs = [] + arg_types = [] + for inplaced in unique(self.inplace_buffers.values()): + if isinstance(inplaced, RemovedArg): + continue + outer = inplaced.other_names[-1] + inner = inplaced.inner_name + dtype = V.graph.get_dtype(outer) + cpp_dtype = dtype_to_cpp_type[dtype] + arg_defs.append(f"{cpp_dtype}* {inner}") + call_args.append(self.wrap_ptr_arg(outer, dtype)) + arg_types.append(f"{cpp_dtype}*") + for outer, inner in self.input_buffers.items(): + if outer in self.inplace_buffers: + continue + dtype = V.graph.get_dtype(outer) + cpp_dtype = dtype_to_cpp_type[dtype] + arg_defs.append(f"const {cpp_dtype}* {inner}") + call_args.append(self.wrap_ptr_arg(outer, dtype)) + arg_types.append(f"const {cpp_dtype}*") + for outer, maybe_inner in self.output_buffers.items(): + if outer in self.inplace_buffers or isinstance(maybe_inner, RemovedArg): + continue + dtype = V.graph.get_dtype(outer) + cpp_dtype = dtype_to_cpp_type[dtype] + arg_defs.append(f"{cpp_dtype}* {maybe_inner}") + call_args.append(self.wrap_ptr_arg(outer, dtype)) + arg_types.append(f"{cpp_dtype}*") + for outer, inner in self.sizevars.items(): + arg_defs.append(f"const {INDEX_TYPE} {inner}") + 
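+ # Size variables are passed by value as `const INDEX_TYPE` scalars, unlike the
+ # buffer arguments above, which are passed as typed pointers; the matching
+ # call-site value is appended just below.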
call_args.append(self.wrap_size_arg(outer)) + arg_types.append(f"const {INDEX_TYPE}") + if V.graph.wrapper_code: + V.graph.wrapper_code.ensure_size_computed(outer) + assert not self.workspace_args, "Workspace not supported on CPU " + return arg_defs, call_args, arg_types + + def python_argdefs( + self, + ) -> tuple[list[ArgName], list[str], list[KernelArgType], list[Any]]: + arg_defs: list[ArgName] = [] + call_args: list[str] = [] + arg_types: list[Any] = [] + precompile_args: list[KernelArgType] = [] + for inplaced in unique(self.inplace_buffers.values()): + if isinstance(inplaced, RemovedArg): + continue + arg_defs.append(ArgName(inplaced.inner_name)) + call_args.append(inplaced.other_names[-1]) + arg_types.append(V.graph.get_dtype(inplaced.other_names[-1])) + precompile_args.append( + TensorArg( + name=inplaced.inner_name, + buffer=inplaced.other_names[-1], + dtype=V.graph.get_dtype(inplaced.other_names[-1]), + ) + ) + for outer, inner in chain( + self.input_buffers.items(), self.output_buffers.items() + ): + if outer in self.inplace_buffers or isinstance(inner, RemovedArg): + continue + arg_defs.append(ArgName(inner)) + call_args.append(outer) + arg_types.append(V.graph.get_dtype(outer)) + precompile_args.append( + TensorArg( + name=inner, + buffer=outer, + dtype=V.graph.get_dtype(outer), + ) + ) + for outer, inner in self.sizevars.items(): + arg_defs.append(ArgName(inner)) + call_args.append(outer) + arg_types.append(type(outer)) + precompile_args.append(SizeArg(inner, outer)) + if V.graph.wrapper_code: + V.graph.wrapper_code.ensure_size_computed(outer) + for arg in self.workspace_args: + arg_defs.append(ArgName(arg.inner_name)) + call_args.append(arg.outer_name) + precompile_args.append(arg) + arg_types.append(arg.dtype) + return arg_defs, call_args, precompile_args, arg_types + + def aliases(self) -> Iterator[tuple[str, str]]: + for inplaced in unique(self.inplace_buffers.values()): + if isinstance(inplaced, RemovedArg): + continue + for other in inplaced.other_names: + if ( + other in V.graph.inplaced_to_remove + or other in V.kernel.inplaced_to_remove + ): + continue + if other in self.input_buffers: + yield self.input_buffers[other], inplaced.inner_name + if other in self.output_buffers: + yield cast(str, self.output_buffers[other]), inplaced.inner_name + + def is_removed(self, name: str) -> bool: + return isinstance( + self.output_buffers.get(name, REMOVED), RemovedArg + ) and isinstance(self.inplace_buffers.get(name, REMOVED), RemovedArg) + + # Includes inplace buffers, excludes removed buffers. Essentially, + # after you do a call into this kernel, which buffers actually contain + # updated data? Modeled off of python_argdefs. + def live_output_buffers(self) -> OrderedSet[str]: + live_outs: OrderedSet[str] = OrderedSet() + for inplaced in unique(self.inplace_buffers.values()): + if isinstance(inplaced, RemovedArg): + continue + live_outs.add(inplaced.other_names[-1]) + for outer, inner in self.output_buffers.items(): + if outer in self.inplace_buffers or isinstance(inner, RemovedArg): + continue + live_outs.add(outer) + return live_outs + + +class CSEVariable: + """A CSEVariable is just a name for an expression but it is useful to be able to annotate them on a backend dependent basis. 
+ To do so, the backends can simply overload `Kernel.create_cse_var` + The "CSEVariable.update_on_args" method gives you a hook for annotations + See example of TritonCSEVariable in triton.py + """ + + def __init__( + self, + name: str, + bounds: ValueRanges[Any], + dtype: Optional[torch.dtype] = None, + ): + super().__init__() + assert isinstance(bounds, ValueRanges), type(bounds) + self.name = name + self.bounds = bounds + self.use_count = 1 # track how many times this expression is used + self.dtype = dtype + + def __str__(self) -> str: + return self.name + + def __hash__(self) -> int: + return hash(self.name) + + def __eq__(self, other: object) -> bool: + return isinstance(other, CSEVariable) and other.name == self.name + + def update_on_args(self, name: str, args: Any, kwargs: Any) -> None: + pass + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.name!r})" + + +AugmentedKeyT = TypeVar("AugmentedKeyT", default=str) +CSEVariableType = TypeVar("CSEVariableType", bound=CSEVariable, default=CSEVariable) + +if TYPE_CHECKING: + ReductionCacheKey = tuple[ + torch.dtype, + ReductionType, + Union[CSEVariable, tuple[CSEVariable, ...]], + ] + + +class CSE(Generic[CSEVariableType, AugmentedKeyT]): + """Common subexpression elimination""" + + def __init__( + self, + prefix: str = "", + suffix: str = "", + name_prefix: str = "tmp", + iter_buffers: Optional[itertools.count[int]] = None, + store_cache: Optional[MutableMapping[str, CSEVariableType]] = None, + reduction_cache: Optional[ + MutableMapping[ReductionCacheKey, CSEVariableType] + ] = None, + varname_map: Optional[dict[str, CSEVariableType]] = None, + ): + self.prefix = prefix + self.suffix = suffix + self._cache: MutableMapping[AugmentedKeyT, CSEVariableType] = {} + self.name_prefix = name_prefix + self.store_cache: MutableMapping[str, CSEVariableType] = store_cache or {} + self.reduction_cache: MutableMapping[ReductionCacheKey, CSEVariableType] = ( + reduction_cache or {} + ) + self.iter_buffer_ids: itertools.count[int] = iter_buffers or itertools.count() + self.invalidated_stores: OrderedSet[str] = OrderedSet() + self.varname_map: dict[str, CSEVariableType] = varname_map or {} + + def invalidate(self, keep_vars: OrderedSet[CSEVariable]) -> None: + for name, tmp in [*self.store_cache.items()]: + if tmp not in keep_vars: + del self.store_cache[name] + self.invalidated_stores.add(name) + if keep_vars: + self._cache = {k: v for k, v in self._cache.items() if v in keep_vars} + else: + self._cache = {} + + def clone(self) -> Self: + return type(self)( + prefix=self.prefix, + suffix=self.suffix, + name_prefix=self.name_prefix, + iter_buffers=self.iter_buffer_ids, + store_cache=self.store_cache, + varname_map=self.varname_map, + reduction_cache=self.reduction_cache, + ) + + def scoped_copy(self) -> Self: + """Return a copy of using ScopedDict so changes to *_cache aren't visible in self""" + new_cse = self.clone() + new_cse._cache = ScopedDict(self._cache) + new_cse.reduction_cache = ScopedDict(self.reduction_cache) + new_cse.store_cache = ScopedDict(self.store_cache) + return new_cse + + def augment_key(self, cache_key: str) -> AugmentedKeyT: + "Override this method to augment cache key with backend specifics" + return cast(AugmentedKeyT, cache_key) + + def put(self, cache_key: str, val: CSEVariableType) -> None: + self._cache[self.augment_key(cache_key)] = val + + def contains(self, cache_key: str) -> bool: + return self.augment_key(cache_key) in self._cache + + def try_get(self, cache_key: str) -> 
Optional[CSEVariableType]: + return self._cache.get(self.augment_key(cache_key), None) + + def get(self, cache_key: str) -> CSEVariableType: + return self._cache[self.augment_key(cache_key)] + + def generate( + self, + buffer: IndentedBuffer, + expr: Union[str, CSEVariable, OpsValue, IndentedBuffer, DeferredLineBase], + *, + bounds: ValueRanges[Any] = ValueRanges.unknown(), + write: bool = True, + assignment: bool = True, + dtype: Optional[torch.dtype] = None, + ) -> CSEVariableType: + if isinstance(expr, OpsValue): + expr = expr.value + + assert write or assignment + if isinstance(expr, CSEVariable): + # If the expressions were always created with all the information, we could + # assert expr.bounds == bounds, but sometimes the expression is created + # with the loose ValueRanges.unknown(), so we need to tighten the bounds + expr.bounds = expr.bounds.tighten(bounds) + expr.use_count += 1 + return cast(CSEVariableType, expr) + elif isinstance(expr, IndentedBuffer): + cache_key = expr.getvalue() + elif isinstance(expr, DeferredLineBase): + cache_key = expr.line + else: + assert isinstance(expr, str) + cache_key = expr + var = self.try_get(cache_key) + if not var: + var = self.newvar(bounds, dtype) + self.put(cache_key, var) + if write: + if V.kernel.current_node: + V.kernel.current_node.codegen_originating_info( + buffer, only_once=True + ) + if isinstance(expr, IndentedBuffer): + if assignment: + buffer.writeline(f"{self.prefix}{var} =") + buffer.splice(expr) + buffer.writeline(self.suffix) + elif isinstance(expr, DeferredLineBase): + assert assignment + buffer.writeline( + expr._new_line(f"{self.prefix}{var} = {expr.line}{self.suffix}") + ) + else: + if assignment: + line = f"{self.prefix}{var} = {expr}{self.suffix}" + else: + line = f"{expr}{self.suffix}" + buffer.writeline(line) + + # cpp backend cannot determine is_vec at this point + if ( + assignment + and ( + config.test_configs.runtime_triton_dtype_assert + or config.test_configs.static_cpp_dtype_assert + ) + and dtype is not None + and get_current_backend() != "cpp" + ): + check_dtype(buffer, var, dtype) + + else: + var.bounds = var.bounds.tighten(bounds) + var.use_count += 1 + + return var + + def newvar( + self, + bounds: ValueRanges[Any] = ValueRanges.unknown(), + dtype: Optional[torch.dtype] = None, + ) -> CSEVariableType: + var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}" + var = V.kernel.create_cse_var(var_name, bounds, dtype) + self.varname_map[var_name] = var + return var + + def namedvar( + self, + name: str, + bounds: ValueRanges[Any] = ValueRanges.unknown(), + dtype: Optional[torch.dtype] = None, + ) -> CSEVariableType: + torch._check_value( + name not in self.varname_map, lambda: f"duplicate name: {name}" + ) + var = V.kernel.create_cse_var(name, bounds, dtype) + self.varname_map[name] = var + return var + + +class CodeGen: + def __init__(self) -> None: + super().__init__() + self.exit_stack = contextlib.ExitStack() + + def __enter__(self) -> Self: + self.exit_stack.__enter__() + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + self.exit_stack.__exit__(exc_type, exc_val, exc_tb) + + +class Kernel(CodeGen, Generic[CSEVariableType]): + newvar_prefix: str = "" + suffix: str = "" + overrides: Optional[Callable[[], OpsHandler[Any]]] = None + + def __init__( + self, args: Optional[KernelArgs] = None, increase_kernel_count: bool = True + ) -> None: + super().__init__() + if increase_kernel_count: + metrics.generated_kernel_count += 1 + self.args = args or KernelArgs() + 
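+ # Kernel code is accumulated into three sections: `loads` (reads from input
+ # buffers), `compute` (the CSE'd computation), and `stores` (writes to output
+ # buffers); swap_buffers() below temporarily redirects these sections.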
self.loads = IndentedBuffer() + self.compute = IndentedBuffer() + self.stores = IndentedBuffer() + + self.num_load = 0 + self.num_reduction = 0 + + self.cse: CSE[CSEVariableType, Any] = CSE(self.newvar_prefix, self.suffix) + self.must_keep_buffers: OrderedSet[str] = OrderedSet() + self.store_buffer_names: OrderedSet[str] = OrderedSet() + self._load_mask: Optional[str] = None + self._load_other: Union[None, int, float] = None + # OrderedSet in set_current_node + self.current_node: Optional[SchedulerNode] = None + self.node_to_bounds: Optional[dict[torch.fx.Node, ValueRanges[Any]]] = None + + self.removed_buffers: OrderedSet[str] = OrderedSet() + self.inplaced_to_remove: OrderedSet[str] = OrderedSet() + + # key: the buffer to write + # value: the buffer to read and whose memory can be reused for + # the buffer specified by key + self.inplace_update_buffers: dict[str, str] = {} + # Set minimum number of elements processed per thread. + self.min_elem_per_thread = 1 + self.kernel_name: Optional[str] = None + + @contextlib.contextmanager + def set_current_node(self, node: SchedulerNode) -> Iterator[None]: + prior = self.current_node + self.current_node = node + self.node_to_bounds = node._body.bounds().get_bounds() + try: + yield + finally: + self.current_node = prior + + @contextlib.contextmanager + def swap_buffers( + self, + lb: IndentedBuffer, + cb: Optional[IndentedBuffer] = None, + sb: Optional[IndentedBuffer] = None, + ) -> Iterator[None]: + if cb is None: + cb = lb + if disallow_stores := sb is None: + sb = IndentedBuffer() + loads = self.loads + compute = self.compute + stores = self.stores + cse = self.cse + self.loads = lb + self.compute = cb + self.stores = sb + self.cse = cse.scoped_copy() + try: + yield + finally: + self.loads = loads + self.compute = compute + self.stores = stores + self.cse = cse + if disallow_stores: + assert not sb, "unexpected store inside swap_buffers" + + def load(self, name: str, index: sympy.Expr) -> CSEVariable: + raise NotImplementedError + + def indirect_load(self, name: str, index: sympy.Expr) -> CSEVariable: + """A load the depends on an index we have read""" + prior = self.loads + try: + # put the load in the compute section as it might have deps + self.loads = self.compute + return self.load(name, index) + finally: + self.loads = prior + + def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable) -> None: + raise NotImplementedError + + def store( + self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None + ) -> None: + raise NotImplementedError + + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[CSEVariable, tuple[CSEVariable, ...]], + ) -> Union[CSEVariable, tuple[CSEVariable, ...]]: + raise NotImplementedError + + def scan( + self, + dtypes: tuple[torch.dtype, ...], + combine_fn: Callable[ + [tuple[CSEVariable, ...], tuple[CSEVariable, ...]], tuple[CSEVariable, ...] 
+ ], + values: tuple[CSEVariable, ...], + ) -> tuple[CSEVariable, ...]: + raise NotImplementedError + + def sort( + self, + dtypes: tuple[torch.dtype, ...], + values: tuple[CSEVariable, ...], + stable: bool, + descending: bool, + ) -> tuple[CSEVariable, ...]: + raise NotImplementedError + + def var_ranges(self) -> dict[sympy.Symbol, sympy.Expr]: + raise NotImplementedError + + def bucketize( + self, + values: CSEVariable, + boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr], + boundary_indices: CSEVariable, + indexing_dtype: torch.dtype, + right: bool, + sorter: Optional[tuple[str, sympy.Expr]] = None, + sorter_indices: Optional[CSEVariable] = None, + ) -> CSEVariable: + """ + See [Note: Inductor bucketize op] + """ + raise NotImplementedError + + @property + def assert_function(self) -> str: + raise NotImplementedError + + def indirect_assert( + self, + var: Union[CSEVariable, str], + lower: Optional[str], + upper: Optional[str], + mask: Optional[Union[CSEVariable, str]] = None, + ) -> str: + if isinstance(var, CSEVariable): + var = str(var) + assert isinstance(var, str), type(var) + assert lower is None or isinstance(lower, str) + assert upper is None or isinstance(upper, str) + if lower and upper: + # The conditions need to be in parens because of Python's operator precedence. + # It'd be less error-prone to use and/or/not, which is supported by triton + cond = f"({lower} <= {var}) & ({var} < {upper})" + cond_print = f"{lower} <= {var} < {upper}" + elif lower: + cond = f"{lower} <= {var}" + cond_print = cond + else: + assert upper + cond = f"{var} < {upper}" + cond_print = cond + + if mask: + cond = f"({cond}) | ~({mask})" + + return f'{self.assert_function}({cond}, "index out of bounds: {cond_print}")' + + def check_bounds( + self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool + ) -> None: + raise NotImplementedError + + def index_to_str(self, index: sympy.Expr) -> str: + raise NotImplementedError + + def __enter__(self) -> Self: + super().__enter__() + assert self.overrides + self.exit_stack.enter_context( + V.set_ops_handler(CSEProxy(self, self.overrides())) + ) + self.exit_stack.enter_context(V.set_kernel_handler(self)) + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + self.remove_kernel_local_buffers() + super().__exit__(exc_type, exc_val, exc_tb) + + def remove_kernel_local_buffers(self) -> None: + """ + Any buffers that are both created and have a last use in the + same kernel can be removed. + + Note that V.graph.scheduler can be None when codegening triton template + kernels. 
+ """ + scheduler = V.graph.scheduler + if not scheduler: + return + fused_node_names = OrderedSet( + scheduler.name_to_buf[buf].defining_op_name() + for buf in self.store_buffer_names + if buf in scheduler.name_to_buf + ) + names_to_remove: OrderedSet[str] = OrderedSet() + for name in self.store_buffer_names: + if ( + name not in self.must_keep_buffers + and name not in self.args.input_buffers + and scheduler.can_buffer_be_removed_through_fusion( + name, fused_node_names + ) + ): + names_to_remove.add(name) + + for name in names_to_remove: + if name in self.args.inplace_buffers: + buf = self.args.inplace_buffers[name] + if isinstance(buf, RemovedArg): + continue + remove = all(n in names_to_remove for n in buf.other_names) + if remove: + self.remove_inplace_buffer(name) + self.inplaced_to_remove.add(name) + else: + self.remove_buffer(name) + + def remove_buffer(self, name: str) -> None: + # Assign a special value instead of deleting the entry + # because we still rely on output_buffers's length to + # generate unique arg name. + log.debug("remove_buffer(%r)", name) + self.args.output_buffers[name] = REMOVED + self.removed_buffers.add(name) + + def remove_inplace_buffer(self, name: str) -> None: + log.debug("removing_inplace_buffer(%r)", name) + self.args.inplace_buffers[name] = REMOVED + self.removed_buffers.add(name) + + def rename_indexing( + self, index: Union[list[sympy.Expr], tuple[sympy.Expr, ...], sympy.Expr] + ) -> sympy.Expr: + # adds the necessary kernel args for index expressions + # and renames variables in index expressions to kernel arg names + if isinstance(index, (list, tuple)): + return [self.rename_indexing(x) for x in index] + index = V.graph.sizevars.simplify(index) + sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name) + replacements = { + x: self.args.size(x) + for x in sorted_symbols + if symbol_is_type( + x, + ( + SymT.UNBACKED_INT, + SymT.SIZE, + SymT.PRECOMPUTED_SIZE, + ), + ) + } + return sympy_subs(index, replacements) + + def create_cse_var(self, *args: Any, **kwargs: Any) -> CSEVariable: + return CSEVariable(*args, **kwargs) + + def arg_name(self, node: IRNode) -> Optional[str]: + """ + Returns arg name of a given input or output node. + """ + if node is None: + return None + return self.args.arg_name(node.get_name()) + + +@dataclasses.dataclass +class OptimizationContext: + key: ClassVar[str] = "opt_ctx" + + dtype: Optional[torch.dtype] = None + ops_name: str = "" + + +@functools.cache +def jinja2_env() -> Any: + try: + import jinja2 + + return jinja2.Environment( + undefined=jinja2.StrictUndefined, + ) + except ImportError: + return None + + +class KernelTemplate: + """ + Base class for defining kernel templates. 
+ + Children classes: TritonTemplate, CUDATemplate + """ + + @staticmethod + def indent_except_first( + source: str, num_indents: int, indents_spacing: int = 4 + ) -> str: + lines = source.splitlines(True) + if len(lines) > 1: + lines[1:] = [ + (" " * indents_spacing * num_indents) + line for line in lines[1:] + ] + return "".join(lines) + + @staticmethod + def _template_from_string(source: str) -> Any: + env = jinja2_env() + if env is None: + return None + env.filters["indent_except_first"] = KernelTemplate.indent_except_first + from jinja2 import TemplateSyntaxError + + try: + return env.from_string(source) + except TemplateSyntaxError as e: + + class DetailedTemplateSyntaxError(TemplateSyntaxError): + def __init__(self, original_error: TemplateSyntaxError) -> None: + super().__init__( + original_error.message, + original_error.lineno, + original_error.name, + original_error.filename, + ) + self.original_error = original_error + + def __str__(self) -> str: + error_info = f"Error in template at line {self.lineno}\n" + error_info += f"Error message: {self.message}\n" + if hasattr(self.original_error, "source"): + lines = self.original_error.source.split("\n") + error_info += "Context:\n" + start = max(0, self.lineno - 2) + end = min(len(lines), self.lineno + 2) + for i in range(start, end): + if i == self.lineno - 1: + error_info += f"{i + 1}: --> {lines[i]}\n" + if hasattr(self.original_error, "column"): + error_info += ( + " " + + " " * (self.original_error.column - 1) + + "^\n" + ) + else: + error_info += f"{i + 1}: {lines[i]}\n" + return error_info + + raise DetailedTemplateSyntaxError(e) from e + + @staticmethod + def _fake_get_dtype( + fake_outs: Union[list[Buffer], Buffer], + ) -> Callable[[str], torch.dtype]: + _get_dtype_real = V.graph.get_dtype + if isinstance(fake_outs, (list, tuple)): + lookup = {buf.get_name(): buf.get_dtype() for buf in fake_outs} + else: + lookup = {fake_outs.get_name(): fake_outs.get_dtype()} + + def get_dtype(name: str) -> torch.dtype: + result = lookup.get(name) + if result is not None: + return result + return _get_dtype_real(name) + + return get_dtype + + def __init__(self, name: str) -> None: + self.name = name + + def maybe_append_choice( + self, choices: list[Any], **kwargs: Any + ) -> Optional[NotImplementedError]: + """ + Maybe generates a new ChoiceCaller and appends it into existing choices. + Returns None if success, otherwise returns the error. + + choices: A list of ChoiceCallers. + kwargs: Additional kwargs to be passed to self.generate() to generate a new ChoiceCaller. + """ + + try: + choices.append(self.generate(**kwargs)) + return None + except NotImplementedError as e: + log.info( + "Cannot Append Choice: %s. KernelTemplate type is %s", + e, + type(self), + stack_info=log.getEffectiveLevel() < logging.INFO, + ) + return e + + def generate(self, **kwargs: Any) -> ChoiceCaller: + """ + Generates a ChoiceCaller instance from the given arguments. 
+ """ + + raise NotImplementedError + + +class CSEProxy(DefaultHandler): + name = "CSEProxy" + + def __init__(self, kernel: Kernel[Any], parent_handler: OpsHandler[Any]): + super().__init__() + from ..bounds import ValueRangeAnalysis + + self.vr_analysis = ValueRangeAnalysis() + self.kernel = kernel + self.parent_handler = parent_handler + + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + bounds = self._bound_variable(name, *args, **kwargs) + + value = getattr(self.parent_handler, name)(*args, **kwargs) + dtype_handler = DtypePropagationOpsHandler() + + backend = get_current_backend() + + output_dtype = None + if name == "masked" and backend == "triton": + output_dtype = value.dtype + elif name == "masked" and backend == "cpp": + output_dtype = V.interpreter.current_node.meta.get( + OptimizationContext.key, None + ).dtype + elif backend in ("triton", "cpp", "mps"): + dtype_op = getattr(dtype_handler, name) + output_dtype = dtype_op(*args, **kwargs) + + if backend in ("triton", "cpp"): + # maybe there are some exceptions on mps? + assert output_dtype is not None + + output_idx = 0 + + def do_cse(v: str) -> CSEVariable: + # we tree_map over the output, so we need to fetch corresponding dtype + nonlocal output_idx + var_dtype: Optional[torch.dtype] = ( + output_dtype[output_idx] + if isinstance(output_dtype, (list, tuple)) + else output_dtype + ) + output_idx += 1 + + # some cpp op implementations don't set the dtype + if backend == "cpp" and isinstance(v, CSEVariable) and v.dtype is None: + v.dtype = var_dtype + + csevar = V.kernel.cse.generate( + V.kernel.compute, + v, + bounds=bounds, + dtype=output_dtype, + ) + + csevar.update_on_args(name, args, kwargs) + + if ( + config.test_configs.runtime_triton_dtype_assert + or config.test_configs.static_cpp_dtype_assert + ): + assert var_dtype is not None + check_dtype(V.kernel.compute, csevar, var_dtype) + return csevar + + return pytree.tree_map(do_cse, value) + + def _bound_variable(self, name: str, *args: Any, **kwargs: Any) -> ValueRanges[Any]: + """ + If the variable comes from an FX node, we forward the bound we have already computed + Else, if the variable when codegen'ing another op, we try to compute its bounds + """ + from ..bounds import ValueRangeAnalysis + from ..select_algorithm import TritonTemplateKernel + from .cuda.cuda_kernel import CUDATemplateKernel + + if isinstance(V.kernel, TritonTemplateKernel): + return ValueRanges.unknown() + + if isinstance(V.kernel, CUDATemplateKernel): + return ValueRanges.unknown() + + fx_node = V.interpreter.current_node + if fx_node.target == name and self.kernel.node_to_bounds is not None: + assert isinstance(self.kernel.node_to_bounds, dict), type( + self.kernel.node_to_bounds + ) + return self.kernel.node_to_bounds.get(fx_node, ValueRanges.unknown()) + elif config.compute_all_bounds and hasattr(ValueRangeAnalysis, name): + # These create lots of inner strings. We would need to compute the bounds at the ops + # We will also likely not get much from computing VRs on these nodes + if any(s in fx_node.target for s in ("set_indirect", "reduction", "scan")): + return ValueRanges.unknown() + + # We assume that the inputs come from `ops.` and are not strings. If you want to generate + # intermediary strings, wrap them in CSE variables with properly initialised bounds. 
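+ # For example, an ops.add(a, b) call reaching this branch is bounded by
+ # evaluating ValueRangeAnalysis.add (its existence is guaranteed by the hasattr
+ # check above) on the value ranges of a and b, computed by arg_to_bound below.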
+ + # If there is no FX bound but we know how to compute one we do so + assert not kwargs + + def arg_to_bound(x: Any) -> Any: + if isinstance(x, CSEVariable): + return x.bounds + elif isinstance(x, sympy.Expr): + return bound_sympy(x) + else: + return x + + arg_bounds = list(map(arg_to_bound, args)) + return getattr(self.vr_analysis, name)(*arg_bounds) + return ValueRanges.unknown() + + def indirect_indexing( + self, + var: CSEVariable, + size: Union[sympy.Expr, int], + check: bool = True, + wrap_neg: bool = True, + ) -> sympy.Symbol: + if isinstance(size, int): + size = sympy.Integer(size) + assert isinstance(size, sympy.Expr), (type(size), size) + # Skip CSE since this doesn't return an expression + + if var.bounds.lower < 0: + if wrap_neg: + stm = ops.add(var, ops.index_expr(size, torch.long)) + # Mixed negative and non-negative + if var.bounds.upper >= 0: + lt = ops.lt(var, 0) + stm = ops.where(lt, stm, var) + else: + stm = var + + # Propagate bounds as we know how to compute them properly + new_bounds = ValueRanges.unknown() + if var.bounds != ValueRanges.unknown() and isinstance(size, sympy.Number): + # Take the negative part of the bound and add size to it + # Then take union of that and the positive part + # This is a tighter bound than that of a generic ops.where, as we have info on the cond + neg_bounds = var.bounds & ValueRanges(-int_oo, -1) + new_bounds = ValueRanges( + neg_bounds.lower + size, neg_bounds.upper + size + ) + # We don't have a good way of representing the empty range + if var.bounds.upper >= 0: + pos = var.bounds & ValueRanges(0, int_oo) + new_bounds = new_bounds | pos + + var = self.kernel.cse.generate(self.kernel.compute, stm, bounds=new_bounds) + + sympy_var = self.parent_handler.indirect_indexing(var, size, check) + if generate_assert(check): + assert_lower = not (var.bounds.lower >= 0) + # value ranges cannot x < s when x and s are symbols + assert_upper = not isinstance(size, sympy.Number) or not ( + var.bounds.upper < size + ) + self.kernel.check_bounds(sympy_var, size, assert_lower, assert_upper) + return sympy_var + + def check_bounds( + self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool + ) -> None: + return self.kernel.check_bounds(expr, size, lower, upper) + + def load(self, name: str, index: sympy.Expr) -> CSEVariable: + if name in self.kernel.cse.invalidated_stores: + # A load from an invalidated store requires us to + # keep the actual buffer around + V.kernel.must_keep_buffers.add(name) + if free_symbol_is_type(index, SymT.TMP): + return self.kernel.indirect_load(name, index) + store_cache = self.kernel.cse.store_cache + if name in store_cache: + return store_cache[name] + out = self.kernel.load(name, index) + # count load that is not in the store_cache, and also not in the + # cse cache. 
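+ # A freshly created CSEVariable starts with use_count == 1 (see CSEVariable.__init__),
+ # and CSE.generate bumps use_count on every reuse, so use_count == 1 here means
+ # this is a genuinely new load rather than a hit in the CSE cache.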
+ if out.use_count == 1: + self.kernel.num_load += 1 + return out + + def _update_store_cache(self, name: str, value: CSEVariable) -> None: + self.kernel.cse.store_cache[name] = value + if self.kernel.current_node and name in V.graph.name_to_buffer: + buf = self.kernel.current_node.get_output(name) + for other_name in buf.get_mutations(): + self.kernel.cse.store_cache[other_name] = value + + def store( + self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None + ) -> None: + self.kernel.store_buffer_names.add(name) + if mode is None: + self._update_store_cache(name, value) + if name not in V.graph.removed_buffers: + self.kernel.store(name, index, value, mode=mode) + + def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable) -> None: + self.kernel.store_buffer_names.add(name) + self._update_store_cache(name, value) + + if name not in V.graph.removed_buffers: + return self.kernel.store_reduction(name, index, value) + + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[CSEVariable, tuple[CSEVariable, ...]], + ) -> Union[CSEVariable, tuple[CSEVariable, ...]]: + self.kernel.num_reduction += 1 + return self.kernel.reduction(dtype, src_dtype, reduction_type, value) + + def scan( + self, + dtypes: tuple[torch.dtype, ...], + combine_fn: Callable[ + [tuple[CSEVariable, ...], tuple[CSEVariable, ...]], + tuple[CSEVariable, ...], + ], + values: tuple[CSEVariable, ...], + ) -> tuple[CSEVariable, ...]: + return self.kernel.scan(dtypes, combine_fn, values) + + def sort( + self, + dtypes: tuple[torch.dtype, ...], + values: tuple[CSEVariable, ...], + stable: bool, + descending: bool, + ) -> tuple[CSEVariable, ...]: + return self.kernel.sort(dtypes, values, stable, descending) + + def bucketize( + self, + values: CSEVariable, + boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr], + boundary_indices: CSEVariable, + indexing_dtype: torch.dtype, + right: bool, + sorter: Optional[tuple[str, sympy.Expr]] = None, + sorter_indices: Optional[CSEVariable] = None, + ) -> CSEVariable: + """ + [Note: Inductor bucketize op] + + Inputs: + ------- + values: the values to be bucketized. + boundaries: a tuple containing + (a) the name of the boundaries tensor (which must be sorted, unless + the sorting tensor is present), + (b) the length of the tensor in the last dimension (i.e. the length of + one set of boundaries), + (c) the number of elements in the underlying storage (i.e. the length + of the flattened tensor, ignoring striding), and + (d) the stride of the tensor in the last dimension. + boundary_indices: indices into a flattened version of the boundaries + tensor, of the same size and shape as "values". Each index points to + the first element in the set of boundaries to be used for the + corresponding value. + indexing_dtype: the dtype to use when indexing into the boundaries + tensor. This must be int64 or int32. This additionally specifies the + dtype of the return value. + right: see "Details" below. + sorter: an optional tuple containing + (a) the name of an optional sorting tensor, used to access unsorted + boundaries without reordering the boundaries tensor, and + (b) the stride of the tensor in the last dimension. + The values in the sorting tensor are used as indices into the *last* + dimension of the boundaries tensor, with all other indices matching. + The size of the sorting and boundaries tensors must be equivalent. 
+ sorter_indices: must be present if the sorting array is present; see + "boundary_indices" for the equivalent definition for the boundaries + tensor. + + Output: + ------- + The buckets each value belongs in, within a given set of boundaries. 0 + indicates a position before the first boundary, and len(boundaries_set) + represents a position after the last boundary. + + Details: + -------- + Given a value and a set of boundaries, calculate the bucket that each + value belongs to. This works differently in 1-D and N-D cases. + + for values [[-1, 0, 1, 2], [3, 4, 5, 9]], boundaries [0, 4, 4, 8], right=True + return = [[ 0, 1, 1, 1], [1, 3, 3, 4]]. + + for values [[-1, 0, 1, 2], [3, 4, 5, 9]], boundaries [[0, 4], [4, 8]], right=True + return = [[ 0, 1, 1, 1], [0, 1, 1, 2]] + + Note that in the N-D boundaries case, the shape of "values" and + "boundaries" must match in every dimension _except_ the last. + + When right == False, bucket i refers to range (boundaries[i], boundaries[i+1]]. + When right == True, bucket i refers to range [boundaries[i], boundaries[i+1]). + + Boundaries must be non-decreasing, or a sorter must be provided which + would re-index offsets in a non-decreasing order (e.g. the second output + of torch.sort(offsets)). Otherwise, the result is undefined. + """ + return self.kernel.bucketize( + values, + boundaries, + boundary_indices, + indexing_dtype, + right, + sorter, + sorter_indices, + ) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp.py new file mode 100644 index 0000000000000000000000000000000000000000..9ffaa7f8b3fa19e1b614d6fdce4a03c463087ee9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp.py @@ -0,0 +1,5566 @@ +# mypy: allow-untyped-defs +import contextlib +import dataclasses +import functools +import itertools +import math +import operator +import re +import sys +import warnings +from collections.abc import Sequence +from enum import Enum +from typing import Any, Callable, cast, Optional, Union + +import sympy + +import torch +import torch.fx +from torch._inductor import dependencies +from torch._prims_common import is_float_dtype, is_integer_dtype +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing +from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT + +from ..._dynamo.utils import counters +from .. 
import config, cpp_builder, cpu_vec_isa, ir, metrics +from ..loop_body import LoopBody +from ..scheduler import ( + BaseSchedulerNode, + BaseScheduling, + ExternKernelSchedulerNode, + ForeachKernelSchedulerNode, + FusedSchedulerNode, + Scheduler, + SchedulerNode, +) +from ..utils import ( + cache_on_self, + get_bounds_index_expr, + get_fused_kernel_name, + has_free_symbols, + is_multi_outputs_template, + is_welford_reduction, + parallel_num_threads, + Placeholder, + set_kernel_post_grad_provenance_tracing, + sympy_index_symbol, + sympy_index_symbol_with_prefix, + sympy_product, + sympy_subs, +) +from ..virtualized import NullKernelHandler, ops, OpsValue, V +from .common import ( + BackendFeature, + BracesBuffer, + CSE, + CSEVariable, + DataTypePropagation, + DeferredLine, + DTYPE_TO_COMPUTATION_DTYPE, + IndentedBuffer, + Kernel, + KernelArgs, + OpOverrides, + OptimizationContext, +) +from .cpp_utils import ( + _get_dtype_from_loopbodies, + _get_loop_body, + cexpr, + cexpr_index, + codegen_rand, + CppCSEVariable, + DTYPE_TO_CPP, + get_promote_dtype, + INDEX_TYPE, + LocalBufferContext, + may_unify_binary_op_mask_type, + promote_args, + template_fusion_with_epilogues_supported, + unify_mask_base_type, + value_to_cpp, +) + + +_IS_WINDOWS = sys.platform == "win32" + + +@functools.cache +def get_export_declaration(): + return "__declspec(dllexport)" if _IS_WINDOWS else "" + + +schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") + +NATIVE_OMP_RTYPES = OrderedSet(["+", "*", "^", "||", "min", "max"]) +RTYPE_TO_CPP = { + "sum": "+", + "prod": "*", + "xor_sum": "^", + "min": "min", + "max": "max", + "argmin": "argmin", + "argmax": "argmax", + "any": "||", + "welford_reduce": "welford", + "welford_combine": "welford", +} +VECTORIZABLE_RTYPES = OrderedSet( + [ + "max", + "min", + "sum", + "prod", + "xor_sum", + "welford_reduce", + "welford_combine", + "argmin", + "argmax", + "any", + ] +) + +PYTHON_TO_CPP = { + "Tensor": "at::Tensor", + "int": "long", + "float": "double", + "bool": "bool", + "str": "std::string", + "ScalarType": "c10::ScalarType", + "MemoryFormat": "at::MemoryFormat", + "Layout": "at::Layout", + "Device": "at::Device", + "number": "at::Scalar", +} + +CONTAINER_PYTHON_TO_CPP = { + "List": "std::vector", + "Optional": "std::optional", +} + +DTYPE_LOWP_FP = [ + torch.bfloat16, + torch.float16, +] + +VECTORIZABLE_DTYPES: list[torch.dtype] = [ + torch.float64, + torch.float, + torch.bfloat16, + torch.float16, + torch.bool, + torch.uint8, + torch.int8, + torch.int32, + torch.int64, + torch.float8_e4m3fn, + torch.float8_e5m2, +] + +MASKED_VECTORIZABLE_DTYPES: list[torch.dtype] = [ + torch.float, + torch.bfloat16, + torch.float16, + torch.uint8, + torch.int8, +] + + +def reduction_init(reduction_type, dtype): + if dtype in DTYPE_LOWP_FP: + # Since load promotes all half-precision inputs to float, the initial + # constant for reduction must be promoted as well + dtype = torch.float32 + if reduction_type in ("xor_sum", "sum", "any"): + return 0 + if reduction_type == "prod": + return 1 + if reduction_type in ("max", "argmax", "min", "argmin"): + cdtype = DTYPE_TO_CPP[dtype] + if dtype == torch.bool and reduction_type in ("argmin", "argmax"): + cdtype = DTYPE_TO_CPP[torch.float] + min_var = ( + f"-std::numeric_limits<{cdtype}>::infinity()" + if is_float_dtype(dtype) + else f"std::numeric_limits<{cdtype}>::min()" + ) + max_var = ( + f"std::numeric_limits<{cdtype}>::infinity()" + if is_float_dtype(dtype) + else f"std::numeric_limits<{cdtype}>::max()" + ) + init_var = min_var if 
reduction_type in ("max", "argmax") else max_var + return ( + init_var + if reduction_type in ("max", "min") + else f"IndexValue<{cdtype}>{{0, {init_var}}}" + ) + if is_welford_reduction(reduction_type): + return f"Welford<{DTYPE_TO_CPP[dtype]}>()" + raise AssertionError(reduction_type) + + +def reduction_acc_type(reduction_type, dtype): + scalar_type = DTYPE_TO_CPP[DTYPE_TO_COMPUTATION_DTYPE[dtype]] + if is_welford_reduction(reduction_type): + return f"Welford<{scalar_type}>" + if reduction_type in ("argmin", "argmax"): + if dtype == torch.bool: + scalar_type = DTYPE_TO_CPP[torch.float] + return f"IndexValue<{scalar_type}>" + return scalar_type + + +def reduction_combine( + reduction_type, + var, + next_value, + index: Optional[sympy.Symbol] = None, + src_dtype=None, +): + is_bool = src_dtype == torch.bool + if reduction_type == "sum": + conjunction = "|" if is_bool else "+" + return f"{var} {conjunction} {next_value}" + if reduction_type == "prod": + return f"{var} * {next_value}" + if reduction_type == "xor_sum": + return f"{var} ^ {next_value}" + if reduction_type == "any": + return f"{var} || {next_value}" + if reduction_type in ("min", "max"): + return f"{reduction_type}_propagate_nan({var}, {next_value})" + if reduction_type == "welford_reduce": + return f"welford_combine({var}, {next_value})" + if reduction_type == "welford_combine": + if isinstance(next_value, tuple): + mean, m2, weight = next_value + else: + mean, m2, weight = reduction_project(reduction_type, next_value) + return f"welford_combine({var}, {{{mean}, {m2}, {weight}}})" + if reduction_type in ("argmin", "argmax"): + if ( + hasattr(next_value, "dtype") + and next_value.dtype == torch.bool + and not next_value.is_vec + ): + if index is not None: + return f"{reduction_type}_combine({var}, static_cast({next_value}), {index})" + else: + return ( + f"{reduction_type}_combine({var}, static_cast({next_value}))" + ) + if index is not None: + return f"{reduction_type}_combine({var}, {next_value}, {index})" + else: + return f"{reduction_type}_combine({var}, {next_value})" + raise AssertionError(reduction_type) + + +def reduction_project(reduction_type, acc): + if is_welford_reduction(reduction_type): + return f"{acc}.mean", f"{acc}.m2", f"{acc}.weight" + elif reduction_type in ("argmin", "argmax"): + return f"{acc}.index" + return acc + + +def move_code_under_inner_loop( + code: IndentedBuffer, + iter_var: sympy.Expr, + new_iter_var: str, + loop_start: sympy.Expr, + loop_end: sympy.Expr, +) -> BracesBuffer: + r""" + f(iter_var) is transformed to f(new_iter_var) under the inner loop + \/ + for (new_iter_var = loop_start; new_iter_var < loop_end; new_iter_var++) { + f(new_iter_var) + } + Please be careful while using this function, + as the variable defined in f(iter_var) will be invalid outside the for loop. + For example: + auto tmp0 = in_ptr[x0]; -> + for (new_x0 = start; new_x0 < end; new_x0++){ + auto tmp0 = in_ptr[new_x0]; + } + The tmp0 is invalid outside the loop. 
+ """ + transformed_code = BracesBuffer() + with contextlib.ExitStack() as stack: + transformed_code.writeline( + f"for ({INDEX_TYPE} {new_iter_var} = {cexpr_index(loop_start)};" + + f"{new_iter_var} < {cexpr_index(loop_end)}; {new_iter_var}++)" + ) + stack.enter_context(transformed_code.indent()) + for _, line in enumerate(code._lines): + assert isinstance( + line, + ( + str, + DeferredLine, + ), + ) + deferred_name = None + if isinstance(line, DeferredLine): + deferred_name = line.name + line = line.line + new_line = re.sub(r"\b" + f"{iter_var}" + r"\b", f"{new_iter_var}", line) + if deferred_name: + new_line = DeferredLine(deferred_name, new_line) # type: ignore[assignment] + transformed_code.writeline(new_line) + return transformed_code + + +def reduction_prefix_array( + acc_var: Union[str, CSEVariable], + acc_type: str, + reduction_type: str, + dtype: torch.dtype, + len: Union[str, int], + init_fn, +): + """ + MSVC don't support dynamic array(VLA). So we use std::unique_ptr here. + Ref: https://stackoverflow.com/questions/56555406/creating-dynamic-sized-array-using-msvc-c-compiler + MSVC is the only one compiler without VLA. support. Since MSVC can't get good performance here. + We just use unique_ptr make it works on MSVC. + For other compilers, we continue to use VLA to get best performance. + """ + code_buffer = IndentedBuffer() + acc_decl = ( + f"auto {acc_var}_arr = std::make_unique<{acc_type}[]>({len});" + if cpp_builder.is_msvc_cl() + else f"{acc_type} {acc_var}_arr[{len}];" + ) + code_buffer.writeline(f"{acc_decl}") + code_buffer.writelines( + [ + f"for (int i = 0; i < {len}; i++)", + "{", + f" {acc_var}_arr[i] = {init_fn(reduction_type, dtype)};", + "}", + ], + ) + return code_buffer + + +def replace_acc_name(buffer: IndentedBuffer, name: str, new_name: str): + for i, line in enumerate(buffer._lines): + assert isinstance( + line, + ( + str, + DeferredLine, + ), + ) + if isinstance(line, DeferredLine): + line.line = re.sub(r"\b" + f"{name}" + r"\b", f"{new_name}", line.line) + else: + buffer._lines[i] = re.sub(r"\b" + f"{name}" + r"\b", f"{new_name}", line) + + +@functools.lru_cache +def stride_at(index: sympy.Expr, var: sympy.Symbol): + if not index.has(var): + # see test_torchinductor_dynamic_shapes.py::test_full_boolean_dynamic_shapes_cpu + # which has tmp0 = ops.index_expr(s0 >= 1024, torch.bool) and fails below calculation. + # in this case, there is no dependencies between index and var. + return sympy.S.Zero + replacement = {var: var + 1} + new_index = sympy_subs(index, replacement) # type: ignore[arg-type] + return sympy.simplify(new_index - index) + + +@functools.lru_cache +def simplify_index_in_vec_range(index: sympy.Expr, var: sympy.Expr, vec_length: int): + """ + Simplifies the index expression within the range of a vectorized loop. + Given a vectorized loop variable `var` in the range of a loop with `vec_length`, + this function transforms the `index` into an equivalent form. It handles + simplifications for cases where `var` can be expressed as `vec_length * a + b`, + where `b` ranges from 0 to `vec_length - 1`. The function reduces occurrences + of `FloorDiv` and `ModularIndexing` in the `index` with best-effort optimizations. + + NOTE: + The simplified index expression is intended for analysis purposes only, not + for code generation. It replaces `FloorDiv` and `ModularIndexing` with free variables + which are not dependent on the loop variable `var` in the vectorized range. 
Check + https://github.com/pytorch/pytorch/pull/117221#discussion_r1449746217 for more details. + + Examples: + 1. If `var` is `x3` and `vec_length` is 16, and `x3 = 16*a + b`, then + `FloorDiv(x3, div)` or `ModularIndexing(x3, div, mod)` becomes a free variable + when `div` is divisible by 16. + 2. `ModularIndexing(x3, 1, mod)` can be simplified to `x3 + c` where `c` is a free + variable when `mod` is divisible by 16. + """ + + div_freevar_id = 0 + mod_freevar_id = 0 + + def visit_indexing_div(divisor): + nonlocal div_freevar_id + result = FloorDiv(var, divisor) + if sympy.gcd(divisor, vec_length) == vec_length: + result = sympy.Symbol(f"{var}_div_c{div_freevar_id}") + div_freevar_id += 1 + return result + + def visit_modular_indexing(divisor, modulus): + nonlocal mod_freevar_id + result = ModularIndexing(var, divisor, modulus) + if sympy.gcd(divisor, vec_length) == vec_length: + result = sympy.Symbol(f"{var}_mod_c{mod_freevar_id}") + mod_freevar_id += 1 + elif divisor == 1 and sympy.gcd(modulus, vec_length) == vec_length: + result = var + sympy.Symbol(f"{var}_mod_c{mod_freevar_id}") + mod_freevar_id += 1 + return result + + original_index = index + + div = sympy.Wild("divisor", integer=True) + if index.has(FloorDiv): + index = index.replace(FloorDiv(var, div), visit_indexing_div) + + mod = sympy.Wild("modulus", integer=True) + if index.has(ModularIndexing): + index = index.replace(ModularIndexing(var, div, mod), visit_modular_indexing) + + index = sympy.simplify(index) + if index != original_index: + return simplify_index_in_vec_range(index, var, vec_length) + + return index + + +@functools.lru_cache +def stride_at_vec_range( + index: sympy.Expr, var: sympy.Symbol, vec_length: Optional[int] = None +): + if vec_length: + index = simplify_index_in_vec_range(index, var, vec_length) + return stride_at(index, var) + + +@dataclasses.dataclass +class ParallelDepth: + """ + A class representing parallel depth. + Includes the starting depth of parallelism and the depth of parallelism. 
+ """ + + parallel_depth: int + start_depth: int + + +class OuterLoopFusedSchedulerNode(FusedSchedulerNode): + @classmethod + def fuse( # type: ignore[override] + cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode, outer_loop_fusion_depth + ): + assert node1.scheduler is node2.scheduler + assert all( + type(node) + in ( + OuterLoopFusedSchedulerNode, + SchedulerNode, + FusedSchedulerNode, + ) + for node in (node1, node2) + ) + if any(type(node) is OuterLoopFusedSchedulerNode for node in (node1, node2)): + return cls( + node1.scheduler, + ( + list(node1.get_outer_nodes()) + if type(node1) is OuterLoopFusedSchedulerNode + else [ + node1, + ] + ) + + ( + list(node2.get_outer_nodes()) + if type(node2) is OuterLoopFusedSchedulerNode + else [ + node2, + ] + ), + outer_loop_fusion_depth, + ) + else: + return cls(node1.scheduler, [node1, node2], outer_loop_fusion_depth) # type: ignore[list-item] + + def __init__( + self, + scheduler: "Scheduler", + outer_fused_nodes: list[Union[FusedSchedulerNode, SchedulerNode]], + outer_loop_fusion_depth, + ): + self.outer_fused_nodes: list[Union[FusedSchedulerNode, SchedulerNode]] = ( + outer_fused_nodes + ) + self.outer_loop_fusion_depth = outer_loop_fusion_depth + flatten_snodes = [] + for _node in self.outer_fused_nodes: + assert isinstance(_node, (SchedulerNode, FusedSchedulerNode)) + flatten_snodes.extend(list(_node.get_nodes())) + super().__init__(scheduler, flatten_snodes) # type: ignore[arg-type] + + def get_outer_nodes(self): + return self.outer_fused_nodes + + def check_outer_fusion_loop_level_attr( + self, cpp_kernel_proxy_list, outer_loop_fusion_depth + ): + # This function ensures that the same tiling split is applied at each loop level within the outer loop fusion depth. + # In the fusion stage, we only examine nodes with same vars and reduce. + # However, for nodes with same vars and reduce, the loops may still have different tile splits. + # For example (test_expr_vec_non_contiguous in test_cpu_repro.py): + # * buf0 tiling along the 2nd loop level, buf1 tiling along the 3rd loop level. + # If the check failed, we should fall back to standard loop codegen. 
+ def _inner( + left_loop_nest: LoopNest, + right_loop_nest: LoopNest, + loop_fusion_depth: int, + current_checking_depth: int, + ) -> bool: + assert left_loop_nest.loops + assert right_loop_nest.loops + left_loop_level = left_loop_nest.loops[current_checking_depth] + right_loop_level = right_loop_nest.loops[current_checking_depth] + # Check if same loop level attr + outer_loops_attr_compare_list = [ + "var", + "size", + "offset", + "steps", + ] + if not ( + all( + getattr(left_loop_level, attr_compare) + == getattr(right_loop_level, attr_compare) + for attr_compare in outer_loops_attr_compare_list + ) + ): + return False + + assert loop_fusion_depth >= 1 + if (loop_fusion_depth := loop_fusion_depth - 1) > 0: + # Check next loop level attr + current_checking_depth = current_checking_depth + 1 + assert current_checking_depth < len(left_loop_nest.loops) + assert current_checking_depth < len(right_loop_nest.loops) + if not _inner( + left_loop_nest, + right_loop_nest, + loop_fusion_depth, + current_checking_depth, + ): + return False + + return True + + for idx in range(len(cpp_kernel_proxy_list) - 1): + left_loop_nest = cpp_kernel_proxy_list[idx].loop_nest + right_loop_nest = cpp_kernel_proxy_list[idx + 1].loop_nest + if not _inner( + left_loop_nest, + right_loop_nest, + outer_loop_fusion_depth, + 0, + ): + return False + + for cpp_kernel_proxy in cpp_kernel_proxy_list: + outer_ranges = functools.reduce( + operator.mul, + cpp_kernel_proxy.ranges[:outer_loop_fusion_depth], + ) + # When the range of the first inner loop is much larger than the range of + # all outer loops, do not fuse outer loop and fallback to standard loop codegen, + # so that the inner loops with larger range have a chance to be parallelized. + # We set a conservative threshold here: + # First inner loop range / all outer loops range > 300. 
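# Worked example of the threshold above (hypothetical shapes, not taken from
# the patch): with outer_loop_fusion_depth = 2 and ranges = [4, 2, 4096], the
# fused outer ranges multiply to 4 * 2 = 8 while the first inner loop has
# range 4096.
outer = 4 * 2            # product of the fused outer loop ranges
first_inner = 4096       # range of the first inner loop
assert outer * 300 < first_inner  # 2400 < 4096 -> do not fuse, fall back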
+ if ( + len(cpp_kernel_proxy.ranges) > outer_loop_fusion_depth + and isinstance(outer_ranges, sympy.Integer) + and isinstance( + cpp_kernel_proxy.ranges[outer_loop_fusion_depth], + sympy.Integer, + ) + and outer_ranges * 300 + < cpp_kernel_proxy.ranges[outer_loop_fusion_depth] + ): + return False + + return True + + def merge_outer_fusion_kernels( + self, + cpp_kernel_proxy_list, + ): + kernel_group = cpp_kernel_proxy_list[0].kernel_group + outer_loop_fused_kernel = OuterLoopFusedKernel(kernel_group) + outer_loop_fused_kernel.inner = [ + proxy.loop_nest.from_loop_level(self.outer_loop_fusion_depth) + for proxy in cpp_kernel_proxy_list + ] + outer_fused_proxy = cpp_kernel_proxy_list[0] + outer_fused_proxy.loop_nest.kernel = outer_loop_fused_kernel + outer_fused_proxy.loop_nest.loops = outer_fused_proxy.loop_nest.loops[ + : self.outer_loop_fusion_depth + ] + return outer_fused_proxy + + +class RecordOptimizationContext: + def __init__(self, func_name: str = ""): + self.func_name = func_name + self.current_node: Optional[torch.fx.Node] = None + self.opt_ctx: Optional[OptimizationContext] = None + + def __enter__(self): + assert V.interpreter + assert V.interpreter.current_node + + self.current_node = V.interpreter.current_node + assert self.current_node is not None + if OptimizationContext.key in self.current_node.meta: + self.opt_ctx = self.current_node.meta[OptimizationContext.key] + else: + self.opt_ctx = OptimizationContext() + assert self.opt_ctx is not None + self.opt_ctx.ops_name = self.func_name + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + assert self.current_node + assert self.opt_ctx + self.current_node.meta[OptimizationContext.key] = self.opt_ctx + + def get_opt_ctx(self): + return self.opt_ctx + + def get_fx_node(self): + assert self.current_node + return self.current_node + + +def decltype_promoted(*args): + assert not any(isinstance(arg, CppCSEVariable) and arg.is_vec for arg in args), ( + "Promotion of vector types is not supported" + ) + + if (dt := get_promote_dtype(args)) is not None: + return DTYPE_TO_CPP[dt] + else: + return f"decltype({args[0]})" + + +class CppOverrides(OpOverrides): + """Map element-wise ops to C++""" + + @staticmethod + def add(a, b): + return f"{decltype_promoted(a, b)}({a} + {b})" + + @staticmethod + def sub(a, b): + return f"{decltype_promoted(a, b)}({a} - {b})" + + @staticmethod + def mul(a, b): + return f"{decltype_promoted(a, b)}({a} * {b})" + + @staticmethod + def to_dtype(x, dtype, src_dtype=None, use_compute_types=True): + assert isinstance(x, CppCSEVariable) + if src_dtype is None: + src_dtype = x.dtype + expr = V.kernel.get_to_dtype_expr(x, dtype, src_dtype) + csevar = V.kernel.cse.generate(V.kernel.compute, expr) + csevar.update_on_args("to_dtype", (x, dtype), {"src_dtype": src_dtype}) + if dtype in DTYPE_LOWP_FP and src_dtype == torch.float: + """ + https://github.com/pytorch/pytorch/issues/115260 + For FusedSchedulerNode[node1, node2], the node2 loads what node1 stores and the buffer is + in low-precision floating point data type. When the output of node1 also serves as the output of the + kernel, the result of nodes would be different from the case when output of node1 is not the output + of the kernel (where we don't need to insert `to_dtype` for legalization). To address the problem, on + storing the lowp node1 output, we also add the inverse dtype conversion to high precision data type + to the cse cache. + + Example (pseudo code): + node1_output = ... 
+ node1_output_lowp = to_dtype(node1_output, dtype=torch.bfloat16) + store(buf, node1_output_lowp) + node2_input_lowp = load(buf) + node2_input = to_dtype(node2_input_lowp, dtype=torch.float) + + Without cse cache trick: + node1_output = ... + node1_output_lowp = to_dtype(node1_output, dtype=torch.bfloat16) + store(buf, node1_output_lowp) + node2_input_lowp = node_output_lowp # hit store cache + node2_input = to_dtype(node2_input_lowp, dtype=torch.float) + + With cse cache trick: + node1_output = ... + node1_output_lowp = to_dtype(node1_output, dtype=torch.bfloat16) + # also add `to_dtype(node1_input_lowp, dtype=torch.float)` -> `node1_output` to cse cache + store(buf, node1_output_lowp) + node2_input_lowp = node_output_lowp # hit store cache + node2_input = node1_output # hit cse cache + """ + V.kernel.cache_dtype_convert(x, src_dtype, csevar, dtype) + return csevar + + @staticmethod + def to_dtype_bitcast(x, dtype, src_dtype): + assert dtype in DTYPE_TO_CPP, f"{dtype} missing from {__name__}.DTYPE_TO_CPP" + return f"c10::bit_cast<{DTYPE_TO_CPP[dtype]}>({x})" + + @staticmethod + def abs(x): + return f"std::abs({x})" + + @staticmethod + def sin(x): + return f"std::sin({x})" + + @staticmethod + def cos(x): + return f"std::cos({x})" + + @staticmethod + def neg(x): + return f"decltype({x})(-{x})" + + @staticmethod + def exp(x): + # return f"Sleef_expf_u10({x})" + return f"std::exp({x})" + + @staticmethod + def exp2(x): + return f"std::exp2({x})" + + @staticmethod + def expm1(x): + return f"std::expm1({x})" + + @staticmethod + def erf(x): + return f"std::erf({x})" + + @staticmethod + def erfc(x): + return f"std::erfc({x})" + + @staticmethod + def erfinv(x): + return f"calc_erfinv({x})" + + @staticmethod + def sqrt(x): + return f"std::sqrt({x})" + + @staticmethod + def rsqrt(x): + return f"1 / std::sqrt({x})" + + @staticmethod + def log1p(x): + bug = config.cpp.inject_log1p_bug_TESTING_ONLY + if bug == "accuracy": + return f"{x} + decltype({x})(1)" + elif bug is None: + return f"std::log1p({x})" + else: + raise AssertionError( + f"unrecognized config cpp.inject_log1p_bug_TESTING_ONLY = {bug!r}" + ) + + @staticmethod + def tan(x): + return f"std::tan({x})" + + @staticmethod + def tanh(x): + return f"std::tanh({x})" + + @staticmethod + def signbit(x): + """ + On windows std::signbit only support float type. + Ref: https://learn.microsoft.com/en-us/cpp/c-runtime-library/reference/signbit?view=msvc-170 + """ + return ( + f"std::signbit(static_cast({x}))" + if _IS_WINDOWS + else f"std::signbit({x})" + ) + + @staticmethod + def pow(a, b): + return f"std::pow({a}, {b})" + + @staticmethod + def log(x): + return f"std::log({x})" + + @staticmethod + def round(x): + return f"std::nearbyint({x})" + + @staticmethod + def floor(x): + return f"std::floor({x})" + + @staticmethod + def floordiv(a, b): + # a and b are integer type + quot = f"{a} / {b}" + rem = f"{a} % {b}" + return f"(({a} < 0) != ({b} < 0) ? ({rem} != 0 ? 
{quot} - 1 : {quot}) : {quot})" + + @staticmethod + def ceil(x): + return f"std::ceil({x})" + + @staticmethod + def trunc(x): + return f"std::trunc({x})" + + @staticmethod + def truncdiv(a, b): + # a and b are integer type + return f"{a} / {b}" + + @staticmethod + def fmod(a, b): + return f"std::fmod({a}, {b})" + + @staticmethod + def isinf(x): + return f"std::isinf({x})" + + @staticmethod + def isnan(x): + return f"std::isnan({x})" + + @staticmethod + def lgamma(x): + return f"std::lgamma({x})" + + @staticmethod + def acos(x): + return f"std::acos({x})" + + @staticmethod + def acosh(x): + return f"std::acosh({x})" + + @staticmethod + def cosh(x): + return f"std::cosh({x})" + + @staticmethod + def sinh(x): + return f"std::sinh({x})" + + @staticmethod + def asin(x): + return f"std::asin({x})" + + @staticmethod + def asinh(x): + return f"std::asinh({x})" + + @staticmethod + def atan2(x, y): + return f"std::atan2({x}, {y})" + + @staticmethod + def atan(x): + return f"std::atan({x})" + + @staticmethod + def atanh(x): + return f"std::atanh({x})" + + @staticmethod + def copysign(x, y): + return f"std::copysign({x}, {y})" + + @staticmethod + def frexp(x): + cache_keys = f"frexp({x})[0]", f"frexp({x})[1]" + if all(V.kernel.cse.try_get(cache_key) is not None for cache_key in cache_keys): + return tuple(V.kernel.cse.try_get(cache_key) for cache_key in cache_keys) + + code = BracesBuffer() + exponent = V.kernel.cse.newvar(dtype=torch.int32) + mantissa = V.kernel.cse.newvar(dtype=x.dtype) + code.writeline(f"int32_t {exponent};") + code.writeline(f"auto {mantissa} = std::frexp({x}, &{exponent});") + V.kernel.compute.splice(code) + cse_vars = (mantissa, exponent) + for cache_key, cse_var in zip(cache_keys, cse_vars): + V.kernel.cse.put(cache_key, cse_var) + return mantissa, exponent + + @staticmethod + def hypot(x, y): + return f"std::hypot({x}, {y})" + + @staticmethod + def log10(x): + return f"std::log10({x})" + + @staticmethod + def log2(x): + return f"std::log2({x})" + + @staticmethod + def nextafter(x, y): + return f"std::nextafter({x}, {y})" + + @staticmethod + def relu(x): + bug = config.cpp.inject_relu_bug_TESTING_ONLY + if bug == "compile_error": + return "compile error!" + elif bug == "runtime_error": + return f"{x}; throw 1" + elif bug == "accuracy": + return f"{x} + decltype({x})(1)" + elif bug is None: + return f"std::max({x}, decltype({x})(0))" + else: + raise AssertionError( + f"unrecognized config cpp.inject_relu_bug_TESTING_ONLY = {bug!r}" + ) + + @staticmethod + def minimum(a, b): + return f"min_propagate_nan({a}, {b})" + + @staticmethod + def maximum(a, b): + return f"max_propagate_nan({a}, {b})" + + @staticmethod + def where(a, b, c): + return f"{a} ? 
{b} : {c}" + + @staticmethod + def mod(a, b): + return f"mod({a}, {b})" + + @staticmethod + def constant(val, dtype): + return value_to_cpp(val, DTYPE_TO_CPP[dtype]) + + @staticmethod + def index_expr(expr, dtype): + idx_str = cexpr(V.kernel.rename_indexing(expr)) + var = V.kernel.cse.generate( + V.kernel.compute, idx_str, bounds=get_bounds_index_expr(expr) + ) + return ops.to_dtype(var, dtype) + + @staticmethod + def masked(mask, body, other): + code = BracesBuffer() + + # Write masked operation into a lambda + body_var = V.kernel.cse.newvar() + code.writeline(f"auto {body_var} = [&]") + with V.kernel.swap_buffers(code), code.indent(): + result = body() + code.writeline(f"return {result};") + code.writeline(";") + V.kernel.compute.splice(code) + + # Use the lambda's return type as the type of other + other_code = value_to_cpp(other, f"decltype({body_var}())") + return f"{mask} ? {body_var}() : {other_code}" + + @staticmethod + def logical_and(a, b): + return f"{a} && {b}" + + @staticmethod + def logical_not(a): + return f"!{a}" + + @staticmethod + def logical_or(a, b): + return f"{a} || {b}" + + @staticmethod + def logical_xor(a, b): + return f"{a} != {b}" + + @staticmethod + def bitwise_and(a, b): + return f"decltype({a})({a} & {b})" + + @staticmethod + def bitwise_not(a): + return f"decltype({a})(~{a})" + + @staticmethod + def bitwise_or(a, b): + return f"decltype({a})({a} | {b})" + + @staticmethod + def bitwise_xor(a, b): + return f"decltype({a})({a} ^ {b})" + + @staticmethod + def bitwise_left_shift(a, b): + code = BracesBuffer() + code.writeline("[&]()") + with code.indent(): + scalar_t = DTYPE_TO_CPP[a.dtype] + code.writeline( + f"constexpr decltype({b}) max_shift = sizeof({scalar_t}) * CHAR_BIT;" + ) + code.writeline( + f"if ((static_cast>({b}) < 0) || ({b} >= max_shift))" + ) + with code.indent(): + code.writeline(f"return decltype({a})(0);") + code.writeline( + f"return decltype({a})(static_cast>({a}) << {b});" + ) + code.writeline("()") + return code + + @staticmethod + def bitwise_right_shift(a, b): + code = BracesBuffer() + code.writeline("[&]()") + with code.indent(): + scalar_t = DTYPE_TO_CPP[a.dtype] + code.writeline( + f"constexpr decltype({b}) max_shift = sizeof({scalar_t}) * CHAR_BIT - std::is_signed_v<{scalar_t}>;" + ) + code.writeline( + f"if ((static_cast>({b}) < 0) || ({b} >= max_shift))" + ) + with code.indent(): + code.writeline(f"return decltype({a})({a} >> max_shift);") + code.writeline(f"return decltype({a})({a} >> {b});") + code.writeline("()") + return code + + @staticmethod + def rand(seed: sympy.Expr, offset: sympy.Expr): + return f"normalized_rand_cpu({seed}, {offset})" + + @staticmethod + def randn(seed: sympy.Expr, offset: sympy.Expr): + return f"randn_cpu({seed}, {offset})" + + @staticmethod + def randint64(seed: sympy.Expr, offset: sympy.Expr, low, high): + return f"randint64_cpu({seed}, {offset}, {low}, {high})" + + @staticmethod + def sigmoid(x): + return f"decltype({x})(1) / (decltype({x})(1) + std::exp(-{x}))" + + @staticmethod + def sign(x): + code = BracesBuffer() + scalar_zero = f"decltype({x})(0)" + scalar_one = f"decltype({x})(1)" + code.writeline("[&]()") + with code.indent(): + code.writeline(f"auto left = {x} > 0 ? {scalar_one} : {scalar_zero};") + code.writeline(f"auto right = {x} < 0 ? 
{scalar_one} : {scalar_zero};") + code.writeline("return left - right;") + code.writeline("()") + return code + + +CppOverrides._initialize_pointwise_overrides("cpp") + + +class CppVecOverrides(CppOverrides): + """Map element-wise ops to aten vectorization C++""" + + def __new__(cls, *args, **kargs): + self = super().__new__(cls) + + def wrap(func): + # `CppVecKernel` generates both scalar ops and vector ops according to + # whether the inputs are scalars or vectors while all ops in `CppVecOverrides` + # (except for some ops explained below) assume the inputs are vectors. We wrap the ops in + # `CppVecOverrides` to broadcast scalar inputs to vectors if needed or fallback to + # `CppOverrides` when all inputs are scalars. + # + # Notes on ops handled separately in their own functions: + # `ops.masked`: + # needs recursive handling of masked body. + # `ops.index_expr`: + # needs to further analyze the dependency of the index expression on + # the tiling itervar. + def wrapper(*args, **kwargs): + scalars = [ + arg + for arg in args + if isinstance(arg, (int, sympy.Expr)) + or (isinstance(arg, CppCSEVariable) and not arg.is_vec) + ] + vectors = [ + arg + for arg in args + if isinstance(arg, CppCSEVariable) and arg.is_vec + ] + new_args = list(args) + if scalars and vectors: + new_args = [] + for arg in args: + if isinstance(arg, (int, sympy.Expr)): + if isinstance(arg, sympy.Expr) and not arg.is_number: + arg = ops.index_expr(arg, torch.int64) + else: + arg = ops.constant(arg, torch.int64) + arg = arg.value if isinstance(arg, OpsValue) else arg + new_args.append(arg) + + # DType Promotion + if vectors: + # We have saw several data type mismatch issues related with index_expr in + # the lowering phase of torch.int8. torch.int32, torch.int64. + # 1. int32 and int64 in test_torchinductor.py::test_max_pool2d_with_indices_backward3_cpu + # 2. int8 and int32 in test_torchinductor.py::test_max_pool2d5_cpu + # 3. 
int32 and fp32 in test_torchinductor_dynamic_shapes.py::test_avg_pool2d8_dynamic_shapes_cpu + if len(new_args) == 2: + new_args = promote_args(new_args) + elif func == CppVecOverrides.where: + new_args[1:] = promote_args(new_args[1:]) + + # Broadcast scalar args to vector + if scalars and vectors: + assert isinstance(V.kernel, CppVecKernel) + new_args = [ + ( + V.kernel.broadcast(new_arg) + if ( + isinstance(new_arg, CppCSEVariable) + and not new_arg.is_vec + and func + not in [ + CppVecOverrides.rand, + CppVecOverrides.randn, + CppVecOverrides.randint64, + ] + ) + else new_arg + ) + for new_arg in new_args + ] + + if vectors: + return func(*new_args, **kwargs) + else: + # fallback to scalar ops + scalar_ops = super(CppVecOverrides, self) + scalar_func = getattr(scalar_ops, func.__name__) + assert scalar_func is not None + return scalar_func(*args, **kwargs) + + return wrapper + + for name, method in vars(CppVecOverrides).items(): + if getattr(method, "__class__", None) == staticmethod and name not in [ + "masked", + "index_expr", + ]: + setattr(self, name, wrap(method.__func__)) + + return self + + @staticmethod + def add(a, b): + return f"{a} + {b}" + + @staticmethod + def sub(a, b): + return f"{a} - {b}" + + @staticmethod + def mul(a, b): + return f"{a} * {b}" + + @staticmethod + def truediv(a, b): + return f"{a} / {b}" + + @staticmethod + def abs(x): + return f"{x}.abs()" + + @staticmethod + def sin(x): + return f"{x}.sin()" + + @staticmethod + def cos(x): + return f"{x}.cos()" + + @staticmethod + def exp(x): + return f"{x}.exp()" + + @staticmethod + def exp2(x): + return f"{x}.exp2()" + + @staticmethod + def expm1(x): + # decompose for a better performance + vec_one = f"decltype({x})(1)" + return f"{x}.exp() - {vec_one}" + + @staticmethod + def erf(x): + return f"{x}.erf()" + + @staticmethod + def erfc(x): + return f"{x}.erfc()" + + @staticmethod + def erfinv(x): + return f"{x}.erfinv()" + + @staticmethod + def sqrt(x): + return f"{x}.sqrt()" + + @staticmethod + def eq(x, y): + assert isinstance(V.kernel, CppVecKernel) + assert isinstance(x, CppCSEVariable) + assert x.dtype is not None + return f"{V.kernel._get_mask_type(x.dtype)}({x} == {y})" + + @staticmethod + def ne(x, y): + assert isinstance(V.kernel, CppVecKernel) + assert isinstance(x, CppCSEVariable) + if x.dtype == torch.bool: + assert y.dtype == torch.bool + x_cast, y_cast = unify_mask_base_type(V.kernel.compute, (x, y)) + return f"{x_cast} != {y_cast}" + else: + assert x.dtype is not None + return f"{V.kernel._get_mask_type(x.dtype)}({x} != {y})" + + @staticmethod + def lt(x, y): + assert isinstance(V.kernel, CppVecKernel) + assert isinstance(x, CppCSEVariable) + assert x.dtype is not None + return f"{V.kernel._get_mask_type(x.dtype)}({x} < {y})" + + @staticmethod + def gt(x, y): + assert isinstance(V.kernel, CppVecKernel) + assert isinstance(x, CppCSEVariable) + assert x.dtype is not None + return f"{V.kernel._get_mask_type(x.dtype)}({x} > {y})" + + @staticmethod + def le(x, y): + assert isinstance(V.kernel, CppVecKernel) + assert isinstance(x, CppCSEVariable) + assert x.dtype is not None + return f"{V.kernel._get_mask_type(x.dtype)}({x} <= {y})" + + @staticmethod + def ge(x, y): + assert isinstance(V.kernel, CppVecKernel) + assert isinstance(x, CppCSEVariable) + assert x.dtype is not None + return f"{V.kernel._get_mask_type(x.dtype)}({x} >= {y})" + + @staticmethod + def and_(x, y): + return f"{x} & {y}" + + @staticmethod + def rsqrt(x): + return f"{x}.rsqrt()" + + @staticmethod + def pow(a, b): + return f"{a}.pow({b})" + 
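# A self-contained sketch of the wrapping behaviour described in __new__ above,
# with hypothetical stand-ins for CppCSEVariable and the kernel broadcast (it
# omits the dtype-promotion and index_expr handling): when scalar and vector
# arguments are mixed, scalars are broadcast before calling the vector op; when
# everything is scalar, the call falls back to the scalar override.
def _wrap_vec_op(vec_fn, scalar_fn, broadcast):
    def wrapper(*args, **kwargs):
        vectors = [a for a in args if getattr(a, "is_vec", False)]
        if not vectors:
            return scalar_fn(*args, **kwargs)  # all-scalar: scalar codegen path
        new_args = [
            a if getattr(a, "is_vec", False) else broadcast(a) for a in args
        ]
        return vec_fn(*new_args, **kwargs)     # broadcast scalars, then vectorize
    return wrapper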
+ @staticmethod + def log(x): + return f"{x}.log()" + + @staticmethod + def round(x): + return f"{x}.round()" + + @staticmethod + def floor(x): + return f"{x}.floor()" + + @staticmethod + def ceil(x): + return f"{x}.ceil()" + + @staticmethod + def trunc(x): + return f"{x}.trunc()" + + @staticmethod + def fmod(a, b): + return f"{a}.fmod({b})" + + @staticmethod + def lgamma(x): + return f"{x}.lgamma()" + + @staticmethod + def logical_and(a, b): + a, b = may_unify_binary_op_mask_type(a, b) + return f"{a} & {b}" + + @staticmethod + def logical_not(a): + return f"~{a}" + + @staticmethod + def logical_or(a, b): + a, b = may_unify_binary_op_mask_type(a, b) + return f"{a} | {b}" + + @staticmethod + def logical_xor(a, b): + a, b = may_unify_binary_op_mask_type(a, b) + return f"{a} ^ {b}" + + @staticmethod + def bitwise_and(a, b): + a, b = may_unify_binary_op_mask_type(a, b) + return f"{a} & {b}" + + @staticmethod + def bitwise_not(a): + return f"~{a}" + + @staticmethod + def bitwise_or(a, b): + a, b = may_unify_binary_op_mask_type(a, b) + return f"{a} | {b}" + + @staticmethod + def bitwise_xor(a, b): + a, b = may_unify_binary_op_mask_type(a, b) + return f"{a} ^ {b}" + + @staticmethod + def bitwise_left_shift(a, b): + return f"{a} << {b}" + + @staticmethod + def bitwise_right_shift(a, b): + return f"{a} >> {b}" + + @staticmethod + def load_seed(name, offset): + assert isinstance(V.kernel, CppVecKernel) + return f"{V.kernel.load(name, offset)}" + + @staticmethod + def rand(seed, offset): + assert isinstance(V.kernel, CppVecKernel) + code = BracesBuffer() + rand_function = ( + f"result[offset_idx] = normalized_rand_cpu({seed}, offset[offset_idx]);" + ) + return codegen_rand(offset, code, rand_function) + + @staticmethod + def randn(seed, offset): + assert isinstance(V.kernel, CppVecKernel) + code = BracesBuffer() + rand_function = f"result[offset_idx] = randn_cpu({seed}, offset[offset_idx]);" + return codegen_rand(offset, code, rand_function) + + @staticmethod + def randint64(seed, offset, low, high): + assert isinstance(V.kernel, CppVecKernel) + code = BracesBuffer() + rand_function = f"result[offset_idx] = randint64_cpu({seed}, offset[offset_idx], {low}, {high});" + return codegen_rand(offset, code, rand_function, torch.int64) + + @staticmethod + def remainder(a, b): + assert a.dtype == b.dtype, ( + "remainder vec implementation expect the same inputs' dtype." 
+ ) + return f"{a} - ({CppVecOverrides.floordiv(a, b)}) * {b}" + + @staticmethod + def tan(a): + return f"{a}.tan()" + + @staticmethod + def tanh(a): + if config.cpp.use_decompose_tanh: + vec_one = f"decltype({a})(1)" + vec_two = f"decltype({a})(2)" + vec_minus_two = f"decltype({a})(-2)" + return ( + f"{vec_two} / ({vec_one} + ({vec_minus_two} * {a}).exp()) - {vec_one}" + ) + else: + return f"{a}.tanh()" + + @staticmethod + def reciprocal(a): + return f"{a}.reciprocal()" + + @staticmethod + def atan(x): + return f"{x}.atan()" + + @staticmethod + def acos(x): + return f"{x}.acos()" + + @staticmethod + def asin(x): + return f"{x}.asin()" + + @staticmethod + def cosh(x): + return f"{x}.cosh()" + + @staticmethod + def sinh(x): + return f"{x}.sinh()" + + @staticmethod + def log10(x): + return f"{x}.log10()" + + @staticmethod + def log2(x): + return f"{x}.log2()" + + @staticmethod + def nextafter(x, y): + return f"{x}.nextafter({y})" + + @staticmethod + def copysign(a, b): + return f"{a}.copysign({b})" + + @staticmethod + def atan2(a, b): + return f"{a}.atan2({b})" + + @staticmethod + def hypot(a, b): + return f"{a}.hypot({b})" + + @staticmethod + def atanh(x): + # For real x, atanh(x) = 1/2 * log((1+x)/(1-x)) + vec_one = f"decltype({x})(1)" + vec_one_half = f"decltype({x})(0.5)" + return f"{vec_one_half} * (({vec_one} + {x})/({vec_one} - {x})).log()" + + @staticmethod + def asinh(x): + return f"{x}.asinh()" + + @staticmethod + def acosh(x): + return f"{x}.acosh()" + + @staticmethod + def relu(x): + bug = config.cpp.inject_relu_bug_TESTING_ONLY + if bug == "compile_error": + return "compile error!" + elif bug == "runtime_error": + return f"{x}; throw 1" + elif bug == "accuracy": + return f"{x} + decltype({x})(1)" + elif bug is None: + return f"at::vec::clamp_min({x}, decltype({x})(0))" + else: + raise AssertionError( + f"unrecognized config cpp.inject_relu_bug_TESTING_ONLY = {bug!r}" + ) + + # TODO: this seems to be dead + @staticmethod + def sigmoid(x): + return f"decltype({x})(1)/(decltype({x})(1) + {x}.neg().exp())" + + @staticmethod + def neg(x): + return f"{x}.neg()" + + @staticmethod + def floordiv(a, b): + if is_float_dtype(a.dtype): + assert a.dtype == b.dtype, ( + "div_floor_floating_vec implementation expect the same inputs' dtype." 
+ ) + return f"div_floor_floating_vec({a}, {b})" + else: + assert all(is_integer_dtype(item.dtype) for item in [a, b]) + # a and b are integer type + _t = f"decltype({a})" + if V.kernel._get_raw_num_vectors(b.dtype) < 1: + # Doing blend to set the remaining bits of b to non-zero + b = f"{_t}::blend<{(1 << V.kernel.tiling_factor) - 1}>({_t}(1), {b})" + quot = f"{a} / {b}" + has_rem = f"({a} % {b} != {_t}(0))" + is_neg = f"(({a} < {_t}(0)) != ({b} < {_t}(0)))" + return f"{_t}::blendv({quot}, {quot} - {_t}(1), {has_rem} & {is_neg})" + + @staticmethod + def truncdiv(a, b): + # a and b are integer type + if V.kernel._get_raw_num_vectors(b.dtype) < 1: + # Doing blend to set the remaining bits of b to non-zero + _t = f"decltype({b})" + b = f"{_t}::blend<{(1 << V.kernel.tiling_factor) - 1}>({_t}(1), {b})" + return f"{a} / {b}" + + @staticmethod + def minimum(a, b): + if a.dtype == torch.bool: + assert b.dtype == torch.bool + a_cast, b_cast = unify_mask_base_type(V.kernel.compute, (a, b)) + return f"{a_cast} & {b_cast}" + else: + return f"at::vec::minimum({a}, {b})" + + @staticmethod + def maximum(a, b): + if a.dtype == torch.bool: + assert b.dtype == torch.bool + a_cast, b_cast = unify_mask_base_type(V.kernel.compute, (a, b)) + return f"{a_cast} | {b_cast}" + else: + return f"at::vec::maximum({a}, {b})" + + @staticmethod + def square(a): + return f"{a} * {a}" + + @staticmethod + def where(a, b, c): + assert isinstance(V.kernel, CppVecKernel) + if b.dtype == torch.bool: + assert c.dtype == torch.bool + blendv_a, blendv_b, blendv_c = unify_mask_base_type( + V.kernel.compute, (a, b, c) + ) + return f"decltype({blendv_b})::blendv({blendv_c}, {blendv_b}, {blendv_a})" + else: + return f"decltype({b})::blendv({c}, {b}, {V.kernel._get_mask_cast(a, b.dtype)})" + + @staticmethod + def sign(x): + code = BracesBuffer() + vec_zero = f"decltype({x})(0)" + vec_one = f"decltype({x})(1)" + blendv_l = f"decltype({x})::blendv({vec_zero}, {vec_one}, {vec_zero} < {x})" + blendv_r = f"decltype({x})::blendv({vec_zero}, {vec_one}, {x} < {vec_zero})" + code.writeline("[&]()") + with code.indent(): + code.writeline(f"auto left = {blendv_l};") + code.writeline(f"auto right = {blendv_r};") + code.writeline("return left - right;") + code.writeline("()") + return code + + @staticmethod + def to_dtype(x, dtype, src_dtype=None, use_compute_dtypes=True): + assert dtype in [ + torch.bool, + torch.float64, + torch.float, + torch.bfloat16, + torch.float16, + torch.uint8, + torch.int8, + torch.int32, + torch.int64, + torch.float8_e4m3fn, + torch.float8_e5m2, + ], f"{__name__} does not support {dtype}" + assert isinstance(x, CppCSEVariable) + src_dtype = x.dtype + expr = V.kernel.get_to_dtype_expr(x, dtype, src_dtype) + csevar = V.kernel.cse.generate(V.kernel.compute, expr) + csevar.update_on_args("to_dtype", (x, dtype), {"src_dtype": src_dtype}) + if dtype in DTYPE_LOWP_FP and src_dtype == torch.float: + V.kernel.cache_dtype_convert(x, src_dtype, csevar, dtype) + return csevar + + @staticmethod + def log1p(x): + bug = config.cpp.inject_log1p_bug_TESTING_ONLY + if bug == "accuracy": + return f"{x} + decltype({x})(1)" + elif bug is None: + return f"{x}.log1p()" + else: + raise AssertionError( + f"unrecognized config cpp.inject_log1p_bug_TESTING_ONLY = {bug!r}" + ) + + @staticmethod + def masked(mask, body, other): + assert isinstance(V.kernel, CppVecKernel) + code = BracesBuffer() + var = V.kernel.cse.newvar() + with V.kernel.masked(mask) as new_mask: + code.writeline(f"auto {var} = [&]") + with V.kernel.swap_buffers(code), 
code.indent(): + result = body() + code.writeline(f"return {result};") + code.writeline(";") + V.kernel.compute.splice(code) + + dtype = result.dtype + body_code = f"{var}()" + + def maskify_or_vecify(code): + return ( + f"{V.kernel._get_mask_type()}::from({code})" + if dtype == torch.bool + else f"{V.kernel._get_vec_type(dtype)}({code})" + ) + + if result.is_vec: + body_code_vec = body_code + else: + body_code_vec = maskify_or_vecify(body_code) + other_code = value_to_cpp(other, DTYPE_TO_CPP[dtype]) + # loading bool as VecMask + other_code_vec = maskify_or_vecify(other_code) + assert isinstance(new_mask, CppCSEVariable), new_mask + if new_mask.is_vec: + code = BracesBuffer() + code.writeline("[&]") + with V.kernel.swap_buffers(code), code.indent(): + code.writeline(f"if ({new_mask}.all_zero())") + with code.indent(): + code.writeline(f"return {other_code_vec};") + code.writeline("else") + with code.indent(): + # Create cse variable to reuse kernel.overrides.where + body_vec_var = V.kernel.cse.generate( + V.kernel.compute, + body_code_vec, + ) + other_vec_var = V.kernel.cse.generate( + V.kernel.compute, + other_code_vec, + ) + assert isinstance(body_vec_var, CppCSEVariable), body_vec_var + assert isinstance(other_vec_var, CppCSEVariable), other_vec_var + body_vec_var.dtype = dtype + other_vec_var.dtype = dtype + overrides: type[Union[CppOverrides, CppVecOverrides]] = ( + V.kernel.overrides + ) # type: ignore[has-type] + code.writeline( + f"return {overrides.where(new_mask, body_vec_var, other_vec_var)};" + ) + code.writeline("()") + csevar = V.kernel.cse.generate( + V.kernel.compute, + code, + ) + elif result.is_vec: + csevar = V.kernel.cse.generate( + V.kernel.compute, f"{mask} ? {body_code_vec} : {other_code_vec}" + ) + else: + csevar = V.kernel.cse.generate( + V.kernel.compute, f"{mask} ? {body_code} : {other_code}" + ) + # `result` is explicitly added to the args for correct propagation + # of relevant itervars and vectorization status. 
+ csevar.update_on_args("masked", (mask, body, other, result), {}) + return csevar + + @staticmethod + def index_expr(expr, dtype): + assert isinstance(V.kernel, CppVecKernel) + index = V.kernel.rename_indexing(expr) + tiling_var = V.kernel.itervars[V.kernel.tiling_idx] + stride = V.kernel._try_get_const_stride(index, tiling_var) + if stride == 0: + return CppOverrides.index_expr(expr, dtype) + elif stride is not None: + idx = V.kernel.cse.generate( + V.kernel.compute, cexpr(index), bounds=get_bounds_index_expr(expr) + ) + value = ops.to_dtype(idx, dtype) + if isinstance(value, OpsValue): + value = value.value + csevar = V.kernel.arange(value, stride) + else: + csevar = V.kernel._load_or_store_non_contiguous( # type: ignore[assignment] + None, index, dtype, V.kernel.compute + ) + csevar.update_on_args("index_expr", (expr, dtype), {}) + return csevar + + @staticmethod + def frexp(x): + cache_keys = f"frexp({x})[0]", f"frexp({x})[1]" + if all(V.kernel.cse.try_get(cache_key) is not None for cache_key in cache_keys): + return tuple(V.kernel.cse.try_get(cache_key) for cache_key in cache_keys) + + cdtype = DTYPE_TO_CPP[x.dtype] + size = V.kernel.tail_size if V.kernel.tail_size else V.kernel.tiling_factor + code = BracesBuffer() + exponent = V.kernel.cse.newvar(dtype=torch.int32) + mantissa = V.kernel.cse.newvar(dtype=x.dtype) + exponent.update_on_args("frexp", (x,), kwargs={}) + mantissa.update_on_args("frexp", (x,), kwargs={}) + n_vec = V.kernel._get_num_vectors(x.dtype) + mantissa_t = ( + f"at::vec::Vectorized<{cdtype}>" + if n_vec == 1 + else f"at::vec::VectorizedN<{cdtype}, {n_vec}>" + ) + code.writeline( + f"at::vec::Vectorized {exponent};" + if n_vec == 1 + else f"at::vec::VectorizedN {exponent};" + ) + code.writeline(f"{mantissa_t} {mantissa};") + code.writeline("[&]()") + with code.indent(): + code.writeline( + f"__at_align__ std::array<{cdtype}, {V.kernel.tiling_factor}> tmpbuf;" + ) + code.writeline(f"{x}.store(tmpbuf.data(), {cexpr_index(size)});") + code.writeline( + f"__at_align__ std::array tmpbuf_exponent;" + ) + code.writeline( + f"__at_align__ std::array<{cdtype}, {V.kernel.tiling_factor}> tmpbuf_mantissa;" + ) + code.writeline(f"for (int i = 0; i < {cexpr_index(size)}; i++)") + with code.indent(): + code.writeline( + "tmpbuf_mantissa[i] = std::frexp(tmpbuf[i], &tmpbuf_exponent[i]);" + ) + code.writeline( + f"{exponent} = at::vec::Vectorized::loadu(tmpbuf_exponent.data(), {cexpr_index(size)});" + if n_vec == 1 + else f"{exponent} = at::vec::VectorizedN::loadu(tmpbuf_exponent.data(), {cexpr_index(size)});" + ) + code.writeline( + f"{mantissa} = {mantissa_t}::loadu(tmpbuf_mantissa.data(), {cexpr_index(size)});" + ) + code.writeline("();") + V.kernel.compute.splice(code) + cse_vars = (mantissa, exponent) + for cache_key, cse_var in zip(cache_keys, cse_vars): + V.kernel.cse.put(cache_key, cse_var) + return mantissa, exponent + + @classmethod + def _scalarize(cls, scalar_func): + def inner(*args, **kwargs): + assert not kwargs + kernel = V.kernel + assert isinstance(kernel, CppVecKernel) + code = BracesBuffer() + code.writeline("[&]()") + vec_dtype = args[0].dtype + n_vec = kernel._get_num_vectors(vec_dtype) + size = kernel.tail_size if kernel.tail_size else kernel.tiling_factor + scalar_args = [] + cdtype = DTYPE_TO_CPP[vec_dtype] + output_mask = scalar_func.__name__ in ( + "isinf", + "isnan", + "signbit", + ) + octype = "bool" if output_mask else cdtype + octype = ( + DTYPE_TO_CPP[args[-2]] + if (scalar_func.__name__ == "to_dtype_bitcast") + else octype + ) + with code.indent(): + 
for argidx, arg in enumerate(args): + if isinstance(arg, CppCSEVariable): + assert arg.is_vec + assert arg.dtype == vec_dtype + code.writeline( + f"__at_align__ std::array<{cdtype}, {kernel.tiling_factor}> tmpbuf{argidx};" + ) + code.writeline( + f"{arg}.store(tmpbuf{argidx}.data(), {cexpr_index(size)});" + ) + scalar_args.append(f"tmpbuf{argidx}[i]") + else: + scalar_args.append(arg) + code.writeline( + f"__at_align__ std::array<{octype}, {kernel.tiling_factor}> tmpbuf_out;" + ) + res = scalar_func(*scalar_args) + code.writeline(f"for (int i = 0; i < {cexpr_index(size)}; i++)") + with code.indent(): + code.writeline(f"tmpbuf_out[i] = {res};") + if output_mask: + assert not kernel.tail_size + load_args = "tmpbuf_out.data()" + load_fn = f"at::vec::VecMask<{cdtype},{n_vec}>::from" + else: + load_args = f"tmpbuf_out.data(), {cexpr_index(size)}" + if n_vec == 1: + load_fn = f"at::vec::Vectorized<{octype}>::loadu" + else: + load_fn = f" at::vec::VectorizedN<{octype}, {n_vec}>::loadu" + code.writeline(f"return {load_fn}({load_args});") + code.writeline("()") + return code + + return inner + + @classmethod + def _initialize_scalarize(cls): + vec_vars = vars(CppVecOverrides) + for name, method in vars(CppOverrides).items(): + if isinstance(method, staticmethod) and name not in vec_vars: + func = cls._scalarize(method.__func__) + func.__name__ = name + setattr(cls, name, staticmethod(func)) + + +CppVecOverrides._initialize_pointwise_overrides("cppvec") +CppVecOverrides._initialize_scalarize() + + +class CppTile2DOverrides(CppVecOverrides): + @staticmethod + def index_expr(expr, dtype): + assert isinstance(V.kernel, CppTile2DKernel) + expr = V.kernel.transform_indexing(expr) + return CppVecOverrides.index_expr(expr, dtype) + + +class CppKernel(Kernel): + overrides = CppOverrides # type: ignore[assignment] + sexpr = cexpr + newvar_prefix = "auto " + suffix = ";" + + def __init__(self, args, num_threads): + super().__init__(args) + # Indicate when this kernel is active, for example + # {x0, {24, 26}} -> this kernel is active when x0 >= 24 and x0 < 26 + self.active_ranges: dict[sympy.Expr, tuple[sympy.Expr, ...]] = {} + # Indicate this kernel will be moved under the inner for-loop + # See move_code_under_inner_loop + self.inner_itervars: list[sympy.Symbol] = [] + self.call_ranges: Optional[tuple[sympy.Expr, ...]] = None + self.ranges: list[sympy.Expr] = [] + self.itervars: list[sympy.Symbol] = [] + self.reduction_depth = None + self.reduction_prefix = IndentedBuffer() + # We need this because when we run "reduction" nodes here, we lack + # "loop" information to decide whether we need a scalar init or an array init + # in the reduction prefix. Meanwhile, we have other information like + # reduction types and dtype to generate the reduction prefix. We record the information + # with a callable lambda function, and when we have enough information to finalize + # the reduction prefix, we can invoke the functions here with additional information. 
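# A small sketch of the deferred reduction-prefix idea described above
# (illustrative names, not the real kernel API): each reduction appends a
# callable now, and only once the enclosing loop structure is known is each
# callable invoked, emitting either a scalar accumulator (size is None) or an
# accumulator array.
def _make_prefix_gen(acc, acc_type, init_expr):
    def gen(size=None):
        if size is None:
            return f"{acc_type} {acc} = {init_expr};"
        return f"{acc_type} {acc}_arr[{size}]; // plus an init loop over {size}"
    return gen

prefix_generators = [_make_prefix_gen("tmp_acc0", "float", "0")]
# Later, when finalizing the reduction prefix:
finalized = "\n".join(gen(size=None) for gen in prefix_generators)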
+ self.reduction_prefix_generators: list[Callable] = [] # type: ignore[type-arg] + self.reduction_suffix = IndentedBuffer() + self.parallel_reduction_prefix = IndentedBuffer() + self.parallel_reduction_suffix = IndentedBuffer() + self.local_reduction_init = IndentedBuffer() + self.local_reduction_stores = IndentedBuffer() + self.is_reduction = False + self.non_parallel_reduction_prefix = IndentedBuffer() + self.non_parallel_reduction_suffix = IndentedBuffer() + self.reduction_cse = CSE(self.newvar_prefix, self.suffix, name_prefix="tmp_acc") + self.welford_helper_cse = CSE( + self.newvar_prefix, self.suffix, name_prefix="welford_helper" + ) + self.preloads = IndentedBuffer() + self.poststores = IndentedBuffer() + self.num_threads = num_threads # num_threads the kernel specialized for + self.reduction_omp_dec: dict[tuple[str, str], str] = {} + self.reduction_var_names: list[str] = [] + + def _gen_parallel_reduction_buffers( + self, + acc, + acc_type, + reduction_type, + dtype, + reduction_combine_fn=reduction_combine, + reduction_init_fn=reduction_init, + ): + if config.cpp.dynamic_threads and not self.parallel_reduction_prefix: + self.parallel_reduction_prefix.writeline( + "int max_threads = omp_get_max_threads();" + ) + acc_local = f"{acc}_local" + num_threads = ( + "max_threads" if config.cpp.dynamic_threads else parallel_num_threads() + ) + acc_local_in_array = f"{acc}_arr[tid]" + self.local_reduction_init.writeline( + f"{acc_type} {acc_local} = {reduction_init_fn(reduction_type, dtype)};" + ) + self.parallel_reduction_prefix.splice( + reduction_prefix_array( + acc, + acc_type, + reduction_type, + dtype, + num_threads, + reduction_init_fn, + ) + ) + self.local_reduction_stores.writeline(f"{acc_local_in_array} = {acc_local};") + self.parallel_reduction_suffix.writelines( + [ + f"for (int tid = 0; tid < {num_threads}; tid++)", + "{", + f" {acc} = {reduction_combine_fn(reduction_type, acc, acc_local_in_array, src_dtype=dtype)};", + "}", + ], + ) + + def update_stores_with_parallel_reduction(self): + for var_name in self.reduction_var_names: + replace_acc_name(self.stores, var_name, f"{var_name}_local") + + def gen_body(self, code: Optional[BracesBuffer] = None): + assert code is None + code = BracesBuffer() + with contextlib.ExitStack() as stack: + if hasattr(self, "codegen_inner_loops"): + code.splice(self.preloads) + self.codegen_inner_loops(code) + stack.enter_context(code.indent()) + code.splice(self.loads) + code.splice(self.compute) + code.splice(self.stores) + if hasattr(self, "codegen_inner_loops"): + code.splice(self.poststores) + + if self.inner_itervars: + for idx in self.inner_itervars: + start, end = self.active_ranges[idx] + code = move_code_under_inner_loop(code, idx, f"{idx}_tail", start, end) + return code + + @contextlib.contextmanager + def masked(self, mask): + """Context manager to add an additional mask to loads and stores.""" + prior = self._load_mask + if prior: + mask = ops.and_(mask, prior) + if isinstance(mask, OpsValue): + mask = mask.value + assert isinstance(mask, CppCSEVariable) + # see NOTE [dtype of CppCSEVariable] + # mask's dtype should be bool + mask.dtype = torch.bool + + self._load_mask = mask + try: + yield mask + finally: + self._load_mask = prior + + def scale_index_with_offset( + self, index: sympy.Expr, scale=1, itervar_idx=-1, offset=0 + ): + var = self.itervars[itervar_idx] + replacement = {var: var * scale + offset} + new_index = sympy_subs(index, replacement) + return new_index + + def index_to_str(self, index: sympy.Expr) -> str: + """ + 
Convert an index expr to a string that can be used in cpp code. + e.g. a sympy expression "s2" may actually appear as "ks1" in the cpp kernel. + """ + return cexpr(self.rename_indexing(index)) + + def index_indirect_depends_on(self, index: sympy.Expr, itervar: sympy.Symbol): + """ + Check if an index has free symbol CppCSEVariable that depends on `itervar`. + """ + return any( + self.cse.varname_map[s.name].depends_on(itervar) # type: ignore[attr-defined] + for s in index.free_symbols + if s.name in self.cse.varname_map # type: ignore[attr-defined] + and isinstance(self.cse.varname_map[s.name], CppCSEVariable) # type: ignore[attr-defined] + ) + + def index_depends_on(self, index: sympy.Expr, itervar: sympy.Symbol): + return itervar in index.free_symbols or self.index_indirect_depends_on( + index, itervar + ) + + def var_ranges(self): + return dict(zip(self.itervars, self.ranges)) + + def check_bounds( + self, + expr: sympy.Expr, + size: sympy.Expr, + lower: bool, + upper: bool, + ): + if not (lower or upper): + return + + indirect = free_symbol_is_type(expr, SymT.TMP) + if indirect: + # indexing in compute + csevar = ops.index_expr(expr, torch.int64).value + buffer = V.kernel.compute + else: + # indexing in loads + prior_compute = V.kernel.compute + try: + V.kernel.compute = self.loads + csevar = ops.index_expr(expr, torch.int64).value + finally: + V.kernel.compute = prior_compute + buffer = self.loads + + size_str = V.kernel.sexpr(self.rename_indexing(size)) if upper else None + + line = self.indirect_assert( + csevar, "0" if lower else None, size_str, self._load_mask + ) + self.cse.generate(buffer, line, assignment=False) + + def load(self, name: str, index: sympy.Expr): + var = self.args.input(name) + index = self.rename_indexing(index) + line = f"{var}[{cexpr_index(index)}]" + csevar = self.cse.generate(self.loads, line, dtype=V.graph.get_dtype(name)) + csevar.update_on_args("load", (self, name, index), {}) + return csevar + + def store(self, name, index, value, mode=None): + assert "buf" in name + var = self.args.output(name) + index = self.rename_indexing(index) + if mode is None: + line = f"{var}[{cexpr_index(index)}] = {value};" + elif mode == "atomic_add": + if not config.cpp.dynamic_threads and self.num_threads == 1: + line = f"{var}[{cexpr_index(index)}] += {value};" + else: + dtype = V.graph.get_dtype(name) + # mirroring static_cast(...) 
in load: + value = f"static_cast<{DTYPE_TO_CPP[dtype]}>({value})" + line = f"atomic_add(&{var}[{cexpr_index(index)}], {value});" + else: + raise NotImplementedError(f"store mode={mode}") + self.stores.writeline(DeferredLine(name, line)) + + def _gen_reduction_prefix( + self, + acc: Union[CSEVariable, str], + acc_type: str, + rtype: str, + dtype: torch.dtype, + init_fn, + ): + # Generate reduction prefix + # If size is None, we will define and initialize a single reduction variable + # => float tmp_acc0 = 0; + # Otherwise, we will define and initialize a reduction array + # => float tmp_acc0_arr[size]; + # => for (int i = 0; i < size; i++) tmp_acc0_arr[i] = 0; + def inner(size: Optional[int] = None): + if size is None: + return f"{acc_type} {acc} = {init_fn(rtype, dtype)};" + else: + return reduction_prefix_array( + acc, + acc_type, + rtype, + dtype, + size, + init_fn, + ) + + return inner + + def finalize_reduction_prefix(self, size: Optional[int] = None): + for gen_fn in self.reduction_prefix_generators: + self.reduction_prefix.splice(gen_fn(size)) + + def reduction(self, dtype, src_dtype, reduction_type, value): + argmax_or_argmin = reduction_type in ("argmax", "argmin") + reduction_key = src_dtype, reduction_type, value + if reduction_key in self.reduction_cse.reduction_cache: + return self.reduction_cse.reduction_cache[reduction_key] + + acc = self.reduction_cse.generate( + self.loads, f"reduction {reduction_key}", write=False + ) + self.reduction_var_names.append(f"{acc}") + self.is_reduction = True + init_dtype = src_dtype if argmax_or_argmin else dtype + acc_type = reduction_acc_type(reduction_type, init_dtype) + self.reduction_prefix_generators.append( + self._gen_reduction_prefix( + acc, acc_type, reduction_type, init_dtype, reduction_init + ) + ) + assert self.reduction_depth is not None + index = self.itervars[self.reduction_depth] + for i in range(self.reduction_depth + 1, len(self.itervars)): + index = index * self.ranges[i] + self.itervars[i] + self.stores.writeline( + f"{acc} = {reduction_combine(reduction_type, acc, value, index)};" + ) + + self._gen_parallel_reduction_buffers(acc, acc_type, reduction_type, init_dtype) + result = reduction_project(reduction_type, acc) + self.reduction_cse.reduction_cache[reduction_key] = result + return result + + def store_reduction(self, name, index, value): + index = self.rename_indexing(index) + var = self.args.output(name) + self.reduction_suffix.writeline( + DeferredLine(name, f"{var}[{cexpr_index(index)}] = {value};") + ) + + def set_ranges(self, lengths, reduction_lengths): + if self.call_ranges: + assert self.call_ranges == tuple(lengths) + tuple(reduction_lengths), ( + f"{self.call_ranges} == {tuple(lengths)} + {tuple(reduction_lengths)}" + ) + assert self.reduction_depth == len(lengths) + else: + self.call_ranges = tuple(lengths) + tuple(reduction_lengths) + self.ranges = [self.rename_indexing(x) for x in self.call_ranges] + self.itervars = [ + sympy_index_symbol_with_prefix(SymT.XBLOCK, n) + for n in range(len(self.ranges)) + ] + self.reduction_depth = len(lengths) + return ( + self.itervars[: self.reduction_depth], + self.itervars[self.reduction_depth :], + ) + + def size_hint(self): + assert self.call_ranges is not None + return V.graph.sizevars.size_hint( + sympy_product(self.call_ranges), fallback=8192 + ) + + def codegen_loops_impl(self, loop_nest, code, worksharing): + assert isinstance(self, CppKernelProxy) + threads = parallel_num_threads() + assert self.call_ranges is not None + if isinstance(loop_nest.kernel, 
OuterLoopFusedKernel): + par_depth = loop_nest.kernel.decide_parallel_depth( + loop_nest.max_parallel_depth(), threads + ) + else: + par_depth = self.decide_parallel_depth( + loop_nest.max_parallel_depth(), threads + ) + + is_reduction_loop = ( + loop_nest.loops is not None + and loop_nest.loops[par_depth.start_depth].is_reduction + ) + with contextlib.ExitStack() as stack: + if par_depth.parallel_depth: + if is_reduction_loop: + # need to close the worksharing scope to define reduction vars outside it + worksharing.close() + else: + worksharing.parallel(threads) + loop_nest.mark_parallel(par_depth) + elif threads > 1: + if worksharing.single(): + stack.enter_context(code.indent()) + + def gen_kernel(_loop_nest: LoopNest): + def is_parallel_reduction(): + assert _loop_nest.loops + root = _loop_nest.loops[par_depth.start_depth] + return root.is_reduction and root.parallel + + kernel = _loop_nest.get_kernel() + if isinstance(kernel, OuterLoopFusedKernel): + for _loop_nest in kernel.inner: + gen_loop_nest(_loop_nest) + else: + assert isinstance(kernel, CppKernelProxy) + if _loop_nest.loops is not None and is_parallel_reduction(): + kernel.update_stores_with_parallel_reduction() + with contextlib.ExitStack() as stack: + stack.enter_context(code.indent()) + kernel.gen_body(code) + + def get_reduction_prefix_suffix(kernel, parallel=False, is_suffix=False): + if is_suffix: + suffix = kernel.reduction_suffix + if parallel: + suffix = kernel.parallel_reduction_suffix + suffix + else: + suffix = kernel.non_parallel_reduction_suffix + suffix + return suffix + else: + prefix = kernel.reduction_prefix + if parallel: + prefix = prefix + kernel.parallel_reduction_prefix + else: + prefix = prefix + kernel.non_parallel_reduction_prefix + return prefix + + def gen_loop_with_reduction( + _loop_nest: LoopNest, depth: int = 0, in_reduction=False + ): + kernel = _loop_nest.get_kernel() + assert _loop_nest.loops + loop = _loop_nest.loops[depth] + with contextlib.ExitStack() as stack_outer: + if loop.is_reduction and not in_reduction: + reduction_prefix = get_reduction_prefix_suffix( + kernel, loop.parallel, is_suffix=False + ) + if reduction_prefix: + stack_outer.enter_context(code.indent()) + code.splice(reduction_prefix) + if is_reduction_loop and loop.parallel: + worksharing.parallel(threads) + if kernel.local_reduction_init: + assert kernel.local_reduction_stores + code.splice(kernel.local_reduction_init) + + gen_loop_at(_loop_nest, depth) + + if is_reduction_loop and loop.parallel: + if kernel.local_reduction_stores: + code.splice(kernel.local_reduction_stores) + worksharing.close() + if loop.is_reduction and not in_reduction: + code.splice( + get_reduction_prefix_suffix( + kernel, loop.parallel, is_suffix=True + ) + ) + + def gen_loop_at(_loop_nest: LoopNest, depth: int = 0): + with contextlib.ExitStack() as stack: + assert _loop_nest.loops + loop = _loop_nest.loops[depth] + loop_lines = loop.lines() + if loop_lines is None: + return + code.writelines(loop_lines) + stack.enter_context(code.indent()) + gen_loop_nest(_loop_nest, depth + 1, loop.is_reduction) + + def gen_loop_nest( + _loop_nest: LoopNest, + depth: int = 0, + in_reduction: bool = False, + ): + if _loop_nest.loops is None or depth == len(_loop_nest.loops): # type: ignore[arg-type] + gen_kernel(_loop_nest) + else: + gen_loop_with_reduction(_loop_nest, depth, in_reduction) + + stack.enter_context(code.indent()) + + if ( + isinstance(loop_nest.kernel, OuterLoopFusedKernel) + and isinstance(V.local_buffer_context, LocalBufferContext) + and 
V.local_buffer_context.local_buffers + ): + # Allocate local buffer + local_buffers = V.local_buffer_context.local_buffers + for local_buffer in local_buffers.values(): + # For dynamic size, rename s to ks + local_buf_size = sympy_product( + [ + self.rename_indexing(size_val) + for size_val in local_buffer.get_layout().size + ] + ) + local_buf_dtype = DTYPE_TO_CPP[local_buffer.get_layout().dtype] + allocate = f"std::make_unique<{local_buf_dtype} []>({cexpr(local_buf_size)})" + local_buffer_name = local_buffer.get_name() + code.splice( + f"std::unique_ptr<{local_buf_dtype} []> buf_{local_buffer_name} = {allocate};" + ) + code.splice( + f"{local_buf_dtype}* {local_buffer_name} = buf_{local_buffer_name}.get();" + ) + gen_loop_nest(loop_nest) + + def codegen_loops(self, code, worksharing): + loop_nest = LoopNest.build(self) + self.codegen_loops_impl(loop_nest, code, worksharing) + + @property + def assert_function(self) -> str: + if V.graph.aot_mode: + return "AOTI_TORCH_CHECK" + else: + return "TORCH_CHECK" + + def decide_parallel_depth(self, max_parallel_depth, threads): + assert self.call_ranges is not None + ranges = self.call_ranges[ + max_parallel_depth.start_depth : ( + max_parallel_depth.start_depth + max_parallel_depth.parallel_depth + ) + ] + seq = self.size_hint() + par = 1 + depth = 0 + for expr in ranges: + hint = V.graph.sizevars.size_hint(expr, fallback=8192) + if par >= 2 * threads or par == threads: + break + if seq // threads < config.cpp.min_chunk_size: + # not enough work + break + depth += 1 + par *= hint + seq /= hint + # if we assume thread number is dynamic, make sure we + # have at least one parallel scope and let OMP runtime + # to manage the serial vs. parallel. + if config.cpp.dynamic_threads and depth == 0 and len(ranges) > 0: + depth = 1 + return ParallelDepth( + parallel_depth=depth, start_depth=max_parallel_depth.start_depth + ) + + @contextlib.contextmanager + def write_to_suffix(self): + prior = (self.loads, self.compute, self.stores, self.cse) + self.loads = IndentedBuffer() + self.compute = IndentedBuffer() + self.stores = IndentedBuffer() + self.cse = self.cse.clone() + yield + self.reduction_suffix.splice(self.loads) + self.reduction_suffix.splice(self.compute) + self.reduction_suffix.splice(self.stores) + (self.loads, self.compute, self.stores, self.cse) = prior + + def create_cse_var(self, *args, **kwargs): + return CppCSEVariable(*args, **kwargs) + + def get_to_dtype_expr(self, src, dtype, src_dtype): + return f"c10::convert<{DTYPE_TO_CPP[dtype]}>({src})" + + def cache_dtype_convert(self, dst, dst_dtype, src, src_dtype): + expr = self.get_to_dtype_expr(src, dst_dtype, src_dtype) + self.cse.put(expr, dst) + + def codegen_conditions( + self, + code: BracesBuffer, + prefix: Optional[str] = None, + var: Optional[sympy.Symbol] = None, + ): + if prefix is None: + prefix = "" + if not self.active_ranges: + return True + conditions = [] + + def gen(start, end, var): + if start == end: + return False + var_id = None + for i, _var in enumerate(self.itervars): + if var == _var: + var_id = i + break + if ( + type(self) == CppKernel + and var_id + and start == 0 + and end == self.ranges[var_id] + ): + end = 1 + conditions.append(f"{var} >= {cexpr_index(start)}") + conditions.append(f"{var} < {cexpr_index(end)}") + return True + + if var is not None: + assert var in self.active_ranges + start, end = self.active_ranges[var] + if not gen(start, end, var): + return False + else: + for _var, _range in self.active_ranges.items(): + start, end = _range + if not gen(start, 
end, _var): + return False + joined_conditions = " && ".join(conditions) + if joined_conditions: + code.writeline(f"if({prefix}({joined_conditions}))") + return True + else: + return False + + +class CppVecKernel(CppKernel): + overrides = CppVecOverrides # type: ignore[assignment] + + def __init__( + self, + args, + num_threads, + tiling_factor, + tiling_idx, + tail_size=None, + ): + super().__init__(args, num_threads) + self.vec_isa = cpu_vec_isa.pick_vec_isa() + assert self.vec_isa + assert tiling_factor > 0, "Expect pass in Non-Zero tiling_factor explicitly" + self.tiling_factor = tiling_factor + self.tiling_idx = tiling_idx + self.tail_size = tail_size + self.num_elems = tail_size if tail_size else tiling_factor + + def _try_get_const_stride(self, index: sympy.Expr, itervar: sympy.Symbol): + if self.index_indirect_depends_on(index, itervar): + return None + for indirect_var in ( + self.cse.varname_map[s.name] # type: ignore[attr-defined] + for s in index.free_symbols + if symbol_is_type(s, SymT.TMP) + ): + assert isinstance(indirect_var, CppCSEVariable) + if indirect_var.is_vec: + return None + stride = stride_at_vec_range(index, itervar, self.tiling_factor) + return stride if stride.is_number else None + + def _get_num_vectors(self, dtype: torch.dtype) -> int: + num_vectors = math.ceil( + self.tiling_factor * dtype.itemsize * 8 / self.vec_isa.bit_width() + ) + assert num_vectors >= 1 + return num_vectors + + def _get_raw_num_vectors(self, dtype: torch.dtype) -> float: + # This utility function is used to check if the vector lanes has been + # fully utilized. For example, uint8 will only use 1/4 of the vector lanes. + return self.tiling_factor * dtype.itemsize * 8 / self.vec_isa.bit_width() + + def _get_vec_type(self, dtype: torch.dtype) -> str: + num_vectors = self._get_num_vectors(dtype) + if num_vectors == 1: + return f"at::vec::Vectorized<{DTYPE_TO_CPP[dtype]}>" + else: + return f"at::vec::VectorizedN<{DTYPE_TO_CPP[dtype]},{num_vectors}>" + + def _get_mask_type(self, dtype: torch.dtype = torch.float) -> str: + if dtype == torch.bool: + return "" + num_vectors = self._get_num_vectors(dtype) + return f"at::vec::VecMask<{DTYPE_TO_CPP[dtype]},{num_vectors}>" + + def _get_mask_cast(self, mask: CppCSEVariable, dtype: torch.dtype) -> str: + assert mask.dtype == torch.bool, repr(mask) + num_vectors = self._get_num_vectors(dtype) + return f"{mask}.template cast<{DTYPE_TO_CPP[dtype]},{num_vectors}>()" + + def _get_vec_load_line( + self, + var: str, + index: sympy.Expr, + dtype: torch.dtype, + load_mask: Optional[CppCSEVariable] = None, + ): + """ + Get a load line str that loads a vector from `var` at `index` of type `dtype`. + If `load_mask` is not None, we do a masked load accordingly. + Notes on the `dtype`: + 1. We always load `self.tiling_factor` number of elements regardless of the `dtype`. + It means we load half of the vector lanes for 16-bit data types and quarter of the + vector lanes for 8-bit data types. + 2. `torch.bool` and `torch.uint8` could mean masks and we load them as float mask vectors. + """ + cpp_type = DTYPE_TO_CPP[dtype] + num_vectors = self._get_num_vectors(dtype) + load_mask_str = None + if load_mask: + if not load_mask.is_vec: + # TODO: avoid hard-code torch.float + load_mask_str = f"{self._get_mask_type(torch.float)}::from({load_mask})" + else: + load_mask_str = f"{self._get_mask_cast(load_mask, torch.float)}" + loadbuf = f"{var} + {cexpr_index(index)}" if index != 0 else var + if dtype == torch.bool: + # TODO: should we consider load mask here? 
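A minimal sketch, assuming a 256-bit ISA (e.g. AVX2) and a tiling factor of 16 (example values, not taken from the patch), of the register-count arithmetic behind `_get_num_vectors` and `_get_raw_num_vectors`:

import math

def num_vectors(tiling_factor: int, itemsize_bytes: int, isa_bit_width: int) -> int:
    # Same ceil-division as _get_num_vectors: how many ISA-wide registers are
    # needed to hold `tiling_factor` elements of the given dtype.
    return math.ceil(tiling_factor * itemsize_bytes * 8 / isa_bit_width)

for name, itemsize in [("float32", 4), ("bfloat16", 2), ("uint8", 1)]:
    n = num_vectors(16, itemsize, 256)
    raw = 16 * itemsize * 8 / 256  # _get_raw_num_vectors: < 1.0 means partially used lanes
    print(name, n, raw)
# float32 -> 2 registers (VectorizedN<float,2>), bfloat16 -> 1, uint8 -> 1 with only half the lanes used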
+ line = f"{self._get_mask_type()}::from({loadbuf})" + else: + line = ( + f"{load_mask_str}.template loadu<{cpp_type},{num_vectors}>({loadbuf})" + if load_mask_str + else f"{self._get_vec_type(dtype)}::loadu({loadbuf}, {cexpr_index(self.num_elems)})" + ) + return line + + def _load_or_store_non_contiguous( + self, + var: Optional[str], + index: sympy.Expr, + dtype: torch.dtype, + buffer: Optional[IndentedBuffer] = None, + store_value: Optional[Union[str, CppCSEVariable]] = None, + accu_store: bool = False, + ) -> Optional[CppCSEVariable]: + """ + Load or store a vector in a non-contiguous way. The vector is initialized from an array that is + filled in an inner loop over the tiling factor. + :param var: buffer to load from or store to, i.e. `var[transformed(index)]`. If None, we load the index + as index expression, i.e. `transformed(index)`. + :param index: index into the `var` or the index expression by its own if `var` is None. + The `index` could contain indirect indexing or the tiling itervar. When used in + the inner loop, the index is transformed as follows: + 1. the index is linearized along the tiling dim. + 2. the indirect indexing vector variables are transformed into arrays over the tiling dim. + :param dtype: data type of `var` or `index` if `var` is None. + :param buffer: the code buffer to write the generated code to. If None, we write to `self.loads`. + :param store_value: the value to store. If None, we load the vector. + :param accu_store: whether accumulate the store_value to store_ptr. If True, a store_value should be provided + :return: a CppCSEVariable that represents the loaded vector or None if it is a store. + """ + assert not store_value or var is not None, "store var must be provided" + if accu_store: + assert store_value + if buffer is None: + buffer = self.loads + + def get_result_size(dtype: torch.dtype) -> int: + if dtype.itemsize < 4: + return self.num_elems * (4 // dtype.itemsize) + else: + return self.num_elems + + def get_tiling_size(dtype: torch.dtype) -> int: + if dtype.itemsize < 4: + return self.tiling_factor * (4 // dtype.itemsize) + else: + return self.tiling_factor + + def vec_to_array(vec_var: CppCSEVariable) -> CppCSEVariable: + assert vec_var.is_vec + code = BracesBuffer() + code.writeline("[&]") + with code.indent(): + vec_dtype = vec_var.dtype + assert vec_dtype is not None + if vec_dtype == torch.bool: + vec_dtype = torch.float + result_size = get_result_size(vec_dtype) + tiling_size = get_tiling_size(vec_dtype) + code.writeline( + f"__at_align__ std::array<{DTYPE_TO_CPP[vec_dtype]}, {tiling_size}> tmpbuf;" + ) + line = f"{vec_var}.store(tmpbuf.data(), {cexpr_index(result_size)});" + code.writeline(line) + code.writeline("return tmpbuf;") + code.writeline("()") + csevar = self.cse.generate(buffer, code) + assert isinstance(csevar, CppCSEVariable) + return csevar + + code = BracesBuffer() + code.writeline("[&]") + with code.indent(): + result_size = get_result_size(dtype) + tiling_size = get_tiling_size(dtype) + result_declare = ( + f"__at_align__ std::array<{DTYPE_TO_CPP[dtype]}, {tiling_size}> tmpbuf;" + ) + code.writeline(result_declare) + if store_value: + code.writeline( + f"{store_value}.store(tmpbuf.data(), {cexpr_index(result_size)});" + ) + itervar_inner = sympy_index_symbol( + f"{self.itervars[self.tiling_idx]}_inner" + ) + replacements = {} + for indirect_var in ( + self.cse.varname_map[s.name] # type: ignore[attr-defined] + for s in index.free_symbols + if symbol_is_type(s, SymT.TMP) + ): + assert isinstance(indirect_var, 
CppCSEVariable) + if indirect_var.is_vec: + array_var = vec_to_array(indirect_var) + replacements[indirect_var] = f"{array_var}[{itervar_inner}]" + index = self.scale_index_with_offset( + index, itervar_idx=self.tiling_idx, offset=itervar_inner + ) + load_mask = None + if self._load_mask is not None: + assert not store_value, "unexpected store with load mask" + assert isinstance(self._load_mask, CppCSEVariable), self._load_mask + if self._load_mask.is_vec: + load_mask = f"{self._load_mask}.is_masked({itervar_inner})" + else: + load_mask = f"{self._load_mask} != 0" + if cpp_builder.is_gcc(): + code.writeline(f"#pragma GCC unroll {self.tiling_factor}") + else: + code.writeline(f"#pragma unroll {self.tiling_factor}") + code.writeline( + f"for (long {itervar_inner} = 0; " + + f"{itervar_inner} < {cexpr_index(self.num_elems)}; " + + f"{itervar_inner}++)" + ) + with code.indent(), contextlib.ExitStack() as stack: + index_c = cexpr_index(index) + for indirect_var in replacements: + index_c = re.sub( + r"\b" + f"{indirect_var}" + r"\b", + replacements[indirect_var], + index_c, + ) + rhs = f"{var}[{index_c}]" if var is not None else f"{index_c}" + if load_mask: + code.writeline(f"if ({load_mask})") + stack.enter_context(code.indent()) + if store_value: + conjunction = "+=" if accu_store else "=" + code.writeline(f"{rhs} {conjunction} tmpbuf[{itervar_inner}];") + else: + code.writeline(f"tmpbuf[{itervar_inner}] = {rhs};") + if not store_value: + load_line = self._get_vec_load_line("tmpbuf.data()", 0, dtype) # type: ignore[arg-type] + code.writeline(f"return {load_line};") + code.writeline("()") + if store_value: + code.writeline(";") + buffer.splice(code) + return None + else: + csevar = self.cse.generate(buffer, code, dtype=dtype) + assert isinstance(csevar, CppCSEVariable) + csevar.is_vec = True + return csevar + + def load(self, name: str, index: sympy.Expr): + var = self.args.input(name) + index = self.rename_indexing(index) + dtype = V.graph.get_dtype(name) + tiling_var = self.itervars[self.tiling_idx] + stride = self._try_get_const_stride(index, tiling_var) + if stride == 0: + # load scalar and lazily broadcast it on demand + return super().load(name, index) + elif stride == 1: + # load contiguously + line = self._get_vec_load_line(var, index, dtype, self._load_mask) # type: ignore[arg-type] + csevar = self.cse.generate(self.loads, line, dtype=dtype) # type: ignore[assignment] + else: + csevar = self._load_or_store_non_contiguous(var, index, dtype) # type: ignore[assignment] + assert isinstance(csevar, CppCSEVariable) + csevar.update_on_args("load", (self, name, index), {}) + csevar.is_vec = True + return csevar + + def _get_store_line( + self, + value: Union[str, CppCSEVariable], + var: str, + index: sympy.Expr, + dtype: torch.dtype, + accu_store: bool = False, + ): + """ + Get a store line buffer that stores `value` into `var` at `index` of `dtype`. It handles + both contiguous and non-contiguous store cases. + :param value: Vectorized type templaterized on `dtype`. + :param var: buffer to store into. + :index: index into the `var`. 
+ """ + # when value's type is str (e.g., welford reduction), caller should make sure + # it is a vector + assert isinstance(value, str) or ( + isinstance(value, CppCSEVariable) and value.is_vec + ), value + tiling_var = self.itervars[self.tiling_idx] + var_expr = f"{var} + {cexpr_index(index)}" + stride = self._try_get_const_stride(index, tiling_var) + code = IndentedBuffer() + if stride == 1: + if accu_store: + load = ( + f"{self._get_vec_type(dtype)}::loadu({var_expr})" + if dtype == torch.float and self.tail_size is None + else f"{self._get_vec_type(dtype)}::loadu({var_expr}, {cexpr_index(self.num_elems)})" + ) + value = f"({value} + {load})" + if dtype == torch.float and self.tail_size is None: + code.writeline(f"{value}.store({var_expr});") + else: + code.writeline( + f"{value}.store({var_expr}, {cexpr_index(self.num_elems)});" + ) + else: + self._load_or_store_non_contiguous( + var, index, dtype, buffer=code, store_value=value, accu_store=accu_store + ) + return code + + def store(self, name, index, value, mode=None): + assert "buf" in name + assert isinstance(value, CppCSEVariable), value + if not value.is_vec: + # this happens when we store a scalar into a vectorized buffer like "fill" + value = self.broadcast(value) + var = self.args.output(name) + index = self.rename_indexing(index) + dtype = V.graph.get_dtype(name) + if mode is None: + code = self._get_store_line(value, var, index, dtype) + self.stores.splice(code.map(lambda x: DeferredLine(name, x))) + elif mode == "atomic_add": + if not config.cpp.dynamic_threads and self.num_threads == 1: + code = self._get_store_line( + f"{value}", + var, + index, + dtype, + accu_store=True, + ) + self.stores.splice(code.map(lambda x: DeferredLine(name, x))) + else: + n_src = self._get_num_vectors(dtype) + n_idx = self._get_num_vectors(torch.int64) + cdtype = DTYPE_TO_CPP[dtype] + index = ops.index_expr(index, torch.int64).value + assert isinstance(index, CppCSEVariable) and index.is_vec + line = f"atomic_add_vec<{cdtype}, {n_idx}, {n_src}>({var}, {index}, {value});" + self.stores.writeline(DeferredLine(name, line)) + else: + raise NotImplementedError(f"store mode={mode}") + + def reduction(self, dtype, src_dtype, reduction_type, value): + # Note: For argmax and argmin on bool type, we always convert bool to float. 
+ # Fix issue: https://github.com/pytorch/pytorch/issues/143568 + assert reduction_type in VECTORIZABLE_RTYPES + argmax_or_argmin = reduction_type in ("argmax", "argmin") + horizontal_reduction = self.tiling_idx >= self.reduction_depth + init_dtype = src_dtype if argmax_or_argmin else dtype + assert isinstance(value, CppCSEVariable), value + + if not value.is_vec: + value = self.broadcast(value) + + reduction_key = src_dtype, reduction_type, value + if reduction_key in self.reduction_cse.reduction_cache: + return self.reduction_cse.reduction_cache[reduction_key] + + vec_ns = "at::vec" + vec = f"{vec_ns}::Vectorized<{DTYPE_TO_CPP[dtype]}>" + acc_type = reduction_acc_type(reduction_type, init_dtype) + acc_type_vec = self.reduction_acc_type_vec(reduction_type, init_dtype) + + acc = self.reduction_cse.generate( + self.loads, f"reduction {reduction_key}", write=False + ) + assert isinstance(acc, CppCSEVariable) + acc_vec = f"{acc}_vec" + masked_acc_vec = f"masked_{acc_vec}" + self.reduction_var_names += [f"{acc}", acc_vec, masked_acc_vec] + self.is_reduction = True + self.reduction_prefix_generators.append( + self._gen_reduction_prefix( + acc, acc_type, reduction_type, init_dtype, reduction_init + ) + ) + self.reduction_prefix_generators.append( + self._gen_reduction_prefix( + acc_vec, + acc_type_vec, + reduction_type, + init_dtype, + self.reduction_init_vec, + ) + ) + if reduction_type == "welford_reduce": + # use masked acc_vec for tail vec kernel + self.reduction_prefix_generators.append( + self._gen_reduction_prefix( + masked_acc_vec, + acc_type_vec, + reduction_type, + dtype, + self.reduction_init_vec, + ) + ) + + # use welford_helper for vec kernel + assert self.reduction_depth is not None + reduction_size = functools.reduce( + operator.mul, self.ranges[self.reduction_depth :] + ) + welford_helper_val = self.welford_helper_cse.generate( + self.compute, f"reduction {reduction_key}", write=False + ) + masked_welford_helper_val = f"masked_{welford_helper_val}" + welford_helper_vec_range = ( + ( + FloorDiv(reduction_size, self.ranges[self.tiling_idx]) + * FloorDiv(self.ranges[self.tiling_idx], self.tiling_factor) + if self.tiling_idx >= self.reduction_depth + else reduction_size + ) + if FloorDiv(self.ranges[self.tiling_idx], self.tiling_factor) + else sympy.Integer(0) + ) + masked_welford_helper_vec_range = ( + ( + FloorDiv(reduction_size, self.ranges[self.tiling_idx]) + if self.tiling_idx >= self.reduction_depth + else reduction_size + ) + if self.ranges[self.tiling_idx] % self.tiling_factor + else sympy.Integer(0) + ) + self._use_welford_helper( + acc_vec, welford_helper_val, welford_helper_vec_range, dtype + ) + self._use_welford_helper( + masked_acc_vec, + masked_welford_helper_val, + masked_welford_helper_vec_range, + dtype, + ) + + # use masked acc_vec for tail vec kernel + acc_vec_ = masked_acc_vec if self.tail_size else acc_vec + welford_helper_val_ = ( + masked_welford_helper_val if self.tail_size else welford_helper_val + ) + self.stores.writeline( + f"{acc_vec_} = {self.reduction_combine_vec(reduction_type, acc_vec_, value, welford_helper_val_)};" + ) + else: + assert self.reduction_depth is not None + index = self.itervars[self.reduction_depth] + for i in range(self.reduction_depth + 1, len(self.itervars)): + index = index * self.ranges[i] + self.itervars[i] + kwargs = { + "next_value": value, + "index": index, + "horizontal_reduction": horizontal_reduction, + "src_dtype": src_dtype, + } + self.stores.writeline( + f"{acc_vec} = {self.reduction_combine_vec(reduction_type, acc_vec, 
**kwargs)};" + ) + self._gen_parallel_reduction_buffers( + acc_vec, + acc_type_vec, + reduction_type, + init_dtype, + reduction_combine_fn=self.reduction_combine_vec, + reduction_init_fn=self.reduction_init_vec, + ) + self._gen_parallel_reduction_buffers( + acc, + acc_type, + reduction_type, + init_dtype, + reduction_combine_fn=reduction_combine, + reduction_init_fn=reduction_init, + ) + if reduction_type == "welford_reduce": + # use masked acc_vec for tail vec kernel + self._gen_parallel_reduction_buffers( + masked_acc_vec, + acc_type_vec, + reduction_type, + dtype, + reduction_combine_fn=self.reduction_combine_vec, + reduction_init_fn=self.reduction_init_vec, + ) + tmpvar: Union[str, CSEVariable] + is_bool = dtype == torch.bool + if horizontal_reduction: + # Horizontal reduction + if is_welford_reduction(reduction_type): + assert self._get_num_vectors(dtype) in [ + 1, + 2, + ], "Welford reduction does not support VectorizedN (N>2)" + next_value = f"welford_vec_reduce_all({acc_vec})" + masked_next_value = f"welford_vec_reduce_all({masked_acc_vec})" + self.reduction_suffix.writeline( + f"{acc} = {reduction_combine(reduction_type, acc, masked_next_value)};" + ) + elif argmax_or_argmin: + next_value = f"{reduction_type}_vec_reduce_all({acc_vec})" + elif is_bool: + if reduction_type in ( + "any", + "sum", + "max", + ): + next_value = f"!{acc_vec}.all_zero()" + else: + assert reduction_type == "min" + next_value = f"{acc_vec}.all_masked()" + else: + reduce_all_body = ( + "{ return " + + self.reduction_combine_vec(reduction_type, "x", "y") + + "; }" + ) + is_bool = dtype == torch.bool + # we are using at::vec::VecMask for bool + vec_dtype = torch.float if is_bool else dtype + vec = f"at::vec::Vectorized<{DTYPE_TO_CPP[vec_dtype]}>" + vec_reduce_all_func = f"at::vec::vec_reduce_all<{DTYPE_TO_CPP[vec_dtype]}, {self._get_num_vectors(vec_dtype)}>" + next_value = f"{vec_reduce_all_func}([]({vec}& x, {vec}& y) {reduce_all_body}, {acc_vec})" + + self.reduction_suffix.writeline( + f"{acc} = {reduction_combine(reduction_type, acc, next_value, src_dtype=src_dtype)};" + ) + tmpvar = acc + else: + tmpvar = acc_vec + if is_welford_reduction(reduction_type): + masked_tmpvar = f"masked_{tmpvar}" + self.reduction_suffix.writeline( + f"{tmpvar} = {reduction_combine(reduction_type, tmpvar, masked_tmpvar)};" + ) + + result = reduction_project(reduction_type, tmpvar) + self.reduction_cse.reduction_cache[reduction_key] = result + return result + + def store_reduction(self, name, index, value): + index = self.rename_indexing(index) + var = self.args.output(name) + out_dtype = V.graph.get_dtype(name) + dtype = ( + (out_dtype if out_dtype == torch.double else torch.float) + if out_dtype.is_floating_point + else torch.int64 + ) + out_num_vectors = V.kernel._get_num_vectors(out_dtype) + src_num_vectors = V.kernel._get_num_vectors(dtype) + code = IndentedBuffer() + if self.tiling_idx >= self.reduction_depth: + # Horizontal reduction + code.writeline( + f"{var}[{cexpr_index(index)}] = static_cast<{DTYPE_TO_CPP[out_dtype]}>({value});" + ) + else: + # Vertical reduction + if out_dtype != dtype: + converted_value = ( + f"{DTYPE_TO_CPP[out_dtype].replace('::', '_')}_{value}" + ) + if out_dtype == torch.bool: + convert = f"{value}.template cast()" + else: + if src_num_vectors == out_num_vectors == 1: + convert = ( + f"at::vec::convert<{DTYPE_TO_CPP[out_dtype]}>({value})" + ) + else: + convert = ( + f"at::vec::convert<{DTYPE_TO_CPP[out_dtype]}," + f"{out_num_vectors},{DTYPE_TO_CPP[dtype]},{src_num_vectors}>({value})" + ) + 
code.writeline(f"auto {converted_value} = {convert};") + value = converted_value + code.splice(self._get_store_line(value, var, index, out_dtype)) + self.reduction_suffix.splice(code.map(lambda x: DeferredLine(name, x))) + + def broadcast(self, scalar_var: CppCSEVariable) -> CppCSEVariable: + assert not scalar_var.is_vec + if scalar_var.dtype == torch.bool: + vec_var = self.cse.generate( + self.compute, f"{self._get_mask_type()}::from({scalar_var.name})" + ) + else: + assert scalar_var.dtype is not None + vec_var = self.cse.generate( + self.compute, + f"{self._get_vec_type(scalar_var.dtype)}({scalar_var.name})", + ) + assert isinstance(vec_var, CppCSEVariable) + vec_var.dtype = scalar_var.dtype + vec_var.dependent_itervars = scalar_var.dependent_itervars + vec_var.is_vec = True + return vec_var + + def arange(self, index: CppCSEVariable, stride: sympy.Symbol) -> CppCSEVariable: + assert not index.is_vec + assert index.dtype is not None + csevar = self.cse.generate( + self.compute, + f"{self._get_vec_type(index.dtype)}::arange({index}, {stride})", + ) + assert isinstance(csevar, CppCSEVariable) + csevar.dtype = index.dtype + csevar.is_vec = True + return csevar + + def reduction_init_vec(self, reduction_type, dtype): + scalar_type = DTYPE_TO_COMPUTATION_DTYPE[dtype] + vec_type = self._get_vec_type(scalar_type) + + if is_welford_reduction(reduction_type): + return f"Welford<{vec_type}>()" + + if reduction_type in ("argmin", "argmax"): + cdtype = DTYPE_TO_CPP[scalar_type] + acc_type = self.reduction_acc_type_vec(reduction_type, dtype) + if reduction_type == "argmin": + val = ( + f"std::numeric_limits<{cdtype}>::infinity()" + if is_float_dtype(dtype) + else f"std::numeric_limits<{cdtype}>::max()" + ) + else: + val = ( + f"-std::numeric_limits<{cdtype}>::infinity()" + if is_float_dtype(dtype) + else f"std::numeric_limits<{cdtype}>::min()" + ) + return f"{acc_type}({val})" + + if reduction_type == "any": + return f"{self._get_mask_type()}::from(0)" + + scalar_init = reduction_init(reduction_type, dtype) + vec_init = f"{vec_type}({scalar_init})" + if dtype == torch.bool: + assert reduction_type in ("min", "max", "sum") + return f"{self._get_mask_type()}::from({scalar_init})" + return vec_init + + def reduction_acc_type_vec(self, reduction_type, dtype): + scalar_type = DTYPE_TO_COMPUTATION_DTYPE[dtype] + vec_type = self._get_vec_type(scalar_type) + if is_welford_reduction(reduction_type): + return f"Welford<{vec_type}>" + if reduction_type in ("argmin", "argmax"): + n_src = self._get_num_vectors(scalar_type) + n_idx = self._get_num_vectors(torch.int64) + if dtype == torch.bool: + return f"IndexValueVec<{DTYPE_TO_CPP[torch.float]}, {n_src}, {n_idx}>" + return f"IndexValueVec<{DTYPE_TO_CPP[scalar_type]}, {n_src}, {n_idx}>" + if dtype == torch.bool: + assert reduction_type in ("min", "max", "any", "sum") + return f"{self._get_mask_type()}" + return vec_type + + def _welford_helper_init( + self, welford_helper_val, welford_helper_vec_range, dtype, num_threads=None + ): + vec_num_range_thread = ( + CeilDiv(welford_helper_vec_range, num_threads) + if num_threads + else welford_helper_vec_range + ) + vec_num_range_thread_expr = cexpr_index(vec_num_range_thread) + chunk_size = 4096 + num_chunks = CeilDiv(vec_num_range_thread, chunk_size) + welford_helper_init_line = ( + f"WelfordHelper<{self._get_vec_type(dtype)}, {chunk_size}> {welford_helper_val}" + f"(" + f"{vec_num_range_thread_expr}" + f");" + ) + if isinstance(num_chunks, sympy.Integer) and num_chunks <= 1: + # When the number of chunks <= 1, there 
is no need to use cascade summation to improve + # reduction accuracy. We can initialize a static WelfordHelper to improve performance. + return f"static {welford_helper_init_line}" + else: + return welford_helper_init_line + + def _use_welford_helper( + self, acc_vec, welford_helper_val, welford_helper_vec_range, dtype + ): + num_threads = ( + "max_threads" if config.cpp.dynamic_threads else parallel_num_threads() + ) + self.non_parallel_reduction_prefix.writeline( + self._welford_helper_init( + welford_helper_val, welford_helper_vec_range, dtype + ) + ) + self.local_reduction_init.writeline( + self._welford_helper_init( + welford_helper_val, welford_helper_vec_range, dtype, num_threads + ) + ) + self.non_parallel_reduction_suffix.writeline( + f"{acc_vec} = welford_combine({acc_vec}, &{welford_helper_val});" + ) + self.local_reduction_stores.writeline( + f"{acc_vec}_local = welford_combine({acc_vec}_local, &{welford_helper_val});" + ) + + def reduction_combine_vec( + self, + reduction_type, + var, + next_value, + welford_helper_val=None, + index: Optional[sympy.Symbol] = None, + horizontal_reduction: Optional[bool] = None, + src_dtype: Optional[torch.dtype] = torch.float32, + ): + is_bool = src_dtype == torch.bool + if reduction_type == "max": + if self.tail_size: + return f"max_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})" + else: + return ( + f"{var} | {next_value}" + if is_bool + else f"at::vec::maximum({var}, {next_value})" + ) + elif reduction_type == "min": + if self.tail_size: + return f"min_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})" + else: + return ( + f"{var} & {next_value}" + if is_bool + else f"at::vec::minimum({var}, {next_value})" + ) + elif reduction_type == "sum": + if self.tail_size: + return f"sum_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})" + else: + conjunction = "|" if is_bool else "+" + return f"{var} {conjunction} {next_value}" + elif reduction_type == "prod": + if self.tail_size: + return f"prod_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})" + else: + return f"{var} * {next_value}" + elif reduction_type == "xor_sum": + if self.tail_size: + return f"xor_sum_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})" + else: + return f"{var} ^ {next_value}" + elif reduction_type == "welford_reduce": + if welford_helper_val: + if self.tail_size: + return f"welford_combine({var}, {next_value}, {cexpr_index(self.tail_size)}, &{welford_helper_val})" + else: + return ( + f"welford_combine({var}, {next_value}, &{welford_helper_val})" + ) + else: + if self.tail_size: + return f"welford_combine({var}, {next_value}, {cexpr_index(self.tail_size)})" + else: + return f"welford_combine({var}, {next_value})" + elif reduction_type == "welford_combine": + if isinstance(next_value, tuple): + # When reading a value from Inductor IR we have a tuple of variable names + mean, m2, weight = next_value + else: + # When combining intermediate accumulators we have a Welford struct + mean, m2, weight = reduction_project(reduction_type, next_value) + if self.tail_size: + return f"welford_combine({var}, {{{mean}, {m2}, {weight}}}, {cexpr_index(self.tail_size)})" + else: + return f"welford_combine({var}, {{{mean}, {m2}, {weight}}})" + elif reduction_type in ("argmin", "argmax"): + assert src_dtype is not None + cdtype = DTYPE_TO_CPP[src_dtype] + if src_dtype == torch.bool: + cdtype = DTYPE_TO_CPP[torch.float] + n_src = self._get_num_vectors(src_dtype) + n_idx = self._get_num_vectors(torch.int64) + 
t_extra = "" + arg_extra = "" + if index is not None: + assert horizontal_reduction is not None + t_extra = f", {str(horizontal_reduction).lower()}" + arg_extra = f", {index}" + if self.tail_size: + return ( + f"{reduction_type}_combine_vec<{cdtype}, {n_src}, {n_idx}{t_extra}>" + f"({var}, {next_value}{arg_extra}, {cexpr_index(self.tail_size)})" + ) + else: + return f"{reduction_type}_combine_vec<{cdtype}, {n_src}, {n_idx}{t_extra}>({var}, {next_value}{arg_extra})" + elif reduction_type == "any": + if isinstance(next_value, CppCSEVariable): + assert next_value.dtype == torch.bool + (next_value,) = unify_mask_base_type(V.kernel.compute, (next_value,)) + return f"{var} | {next_value}" + else: + raise NotImplementedError + + def indirect_assert(self, var, lower, upper, mask=None): + assert isinstance(var, CppCSEVariable) + assert var.dtype is not None + if not var.is_vec: + if isinstance(mask, CppCSEVariable) and mask.is_vec: + mask = f"({mask}).all_masked()" + return super().indirect_assert(var, lower, upper, mask) + lower_scalar = lower + upper_scalar = upper + if lower: + lower = f"{self._get_vec_type(var.dtype)}({lower})" + if upper: + upper = f"{self._get_vec_type(var.dtype)}({upper})" + if lower and upper: + cond = f"({lower} <= {var}) & ({var} < {upper})" + cond_print = f"{lower_scalar} <= {var} < {upper_scalar}" + elif lower: + cond = f"{lower} <= {var}" + cond_print = f"{lower_scalar} <= {var}" + else: + assert upper + cond = f"{var} < {upper}" + cond_print = f"{var} < {upper_scalar}" + cond = f"{self._get_mask_type(var.dtype)}({cond})" + if mask: + if not mask.is_vec: + mask = f"{self._get_mask_type(var.dtype)}({mask})" + # We need not check when the mask is False + cond = f"({cond}) | ~({mask})" + if self.tail_size: + cond = ( + f"{self._get_mask_type(var.dtype)}::set({self._get_mask_type(var.dtype)}::from(1)" + f", ({cond}), {cexpr_index(self.tail_size)})" + ) + cond = f"({cond}).all_masked()" + return f'{self.assert_function}({cond}, "index out of bounds: {cond_print}")' + + def get_to_dtype_expr(self, src, dtype, src_dtype): + assert isinstance(src, CppCSEVariable) + if not src.is_vec: + return super().get_to_dtype_expr(src, dtype, src_dtype) + src_cpp_type = DTYPE_TO_CPP[src_dtype] + src_num_vectors = self._get_num_vectors(src_dtype) + dst_cpp_type = DTYPE_TO_CPP[dtype] + dst_num_vectors = self._get_num_vectors(dtype) + expr = f"({src})" + if src_dtype != torch.bool and dtype == torch.bool: + expr = f"{self._get_mask_type(src_dtype)}::from<{src_cpp_type},{src_num_vectors}>({src})" + elif src_dtype == torch.bool and dtype != torch.bool: + expr = f"{src}.to<{dst_cpp_type},{dst_num_vectors}>()" + elif src_dtype != dtype: + if src_num_vectors == dst_num_vectors == 1: + expr = f"at::vec::convert<{dst_cpp_type}>({src})" + else: + expr = f"at::vec::convert<{dst_cpp_type},{dst_num_vectors},{src_cpp_type},{src_num_vectors}>({src})" + return expr + + +class CppTile2DKernel(CppVecKernel): + """ + A vector kernel that handles the 2d tiles with the tile size defined in `tiling_factor` on + the inner-most loop level and one of the outer loop level (`outer_tiling_idx`). When the data + tile is accessed in a contiguous way from the outer loop axis, a transposition is applied on the + tile to make the access contiguous from the inner-most loop axis. Then, the same vectorization + logic from its parent `CppVecKernel` is leveraged for load/store/compute. The transposed tile load + and store are generated into kernel.preloads and kernel.poststores buffers. 
+ + The loop structure looks like below: + for ... + for i_outer ... + for ... + for inner_most ... + // generated by CppTile2DKernel + float tmp0[16*16]; at::vec::transpose_mxn<...>(tmp0, in_ptr0 + ..., ...); // into kernel.preloads + float tmp1[16*16]; // into kernel.preloads + for i_inner ... { // the kernel inner loop + vectorized loads/compute/stores (e.g., load tmp0, store tmp1) // into kernel.loads/compute/stores + } + at::vec::transpose_mxn(out_ptr0 + ..., tmp1, ...) // into kernel.poststores + for inner_most ... (tail) + // generated by CppVecKernel + ... + for i_outer ... (tail) + for ... + for ... + // generated by CppKernel + ... + """ + + overrides = CppTile2DOverrides # type: ignore[assignment] + + def __init__( + self, + args, + num_threads, + tiling_factor, + tiling_indices, + inner_tail_size=None, + outer_tail_size=None, + ): + super().__init__( + args, + num_threads, + tiling_factor, + tiling_indices[1], + inner_tail_size, + ) + self.tiling_indices = tiling_indices + self.inner_tail_size = inner_tail_size + self.outer_tail_size = outer_tail_size + self.inner_num_elems = inner_tail_size if inner_tail_size else tiling_factor + self.outer_num_elems = outer_tail_size if outer_tail_size else tiling_factor + self.inner_is_tiling_idx = True + + def inner_itervar(self): + return sympy_index_symbol(f"{self.itervars[self.outer_idx]}_inner") + + def need_vec_transpose(self, index): + outer_var = self.itervars[self.outer_idx] + inner_var = self.itervars[self.tiling_idx] + outer_stride = stride_at_vec_range(index, outer_var, self.tiling_factor) + inner_stride = stride_at_vec_range(index, inner_var, self.tiling_factor) + return ( + self._load_mask is None # TODO: support transposition with mask + and outer_stride == 1 + and index.has(inner_var) + and not inner_stride.has(inner_var) + and not inner_stride.has(outer_var) + ) + + def gen_transposed_tile_load_store( + self, name, var, index, is_store, store_mode=None + ): + # transposed tile load/store outside the kernel inner loop + dtype = V.graph.get_dtype(name) + factor = self.tiling_factor + src = f"{var} + {cexpr_index(index)}" + dst = "__place_holder__" + ld_src = f"{cexpr_index(stride_at_vec_range(index, self.itervars[self.tiling_idx], self.tiling_factor))}" + ld_dst = f"{cexpr_index(self.num_elems)}" + if is_store: + src, dst = dst, src + ld_src, ld_dst = ld_dst, ld_src + + need_define = True + if self.inner_is_tiling_idx ^ is_store: + M, N = self.inner_num_elems, self.outer_num_elems + else: + M, N = ( + self.outer_num_elems, + self.inner_num_elems, + ) + atomic_add = "true" if (is_store and (store_mode == "atomic_add")) else "false" + if (isinstance(M, sympy.Expr) and not M.is_number) or ( + isinstance(N, sympy.Expr) and not N.is_number + ): + load_or_store = ( + f"transpose_mxn<{DTYPE_TO_CPP[dtype]},{atomic_add}>" + f"({src}, {ld_src}, {dst}, {ld_dst}, {cexpr_index(M)}, {cexpr_index(N)});" + ) + else: + load_or_store = ( + f"transpose_mxn<{DTYPE_TO_CPP[dtype]},{cexpr_index(M)},{cexpr_index(N)},{atomic_add}>" + f"({src}, {ld_src}, {dst}, {ld_dst});" + ) + if is_store: + tile_var = self.cse.newvar() + elif not self.cse.contains(load_or_store): + tile_var = self.cse.generate(self.preloads, load_or_store, write=False) + else: + need_define = False + tile_var = self.cse.get(load_or_store) + + if need_define: + cpp_dtype = DTYPE_TO_CPP[dtype] + # tiling_factor might be smaller than the alignment of cpp_dtype, such as + # with a vector that only holds 4 elements due to NEON 128-bit vectors and + # cpp_dtype being a 64-bit integer. 
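A toy sketch of the stride condition that `need_vec_transpose` checks, using plain sympy in place of `stride_at_vec_range` and an assumed 2D index expression:

import sympy

i_outer, i_inner = sympy.symbols("i_outer i_inner", integer=True)
# Assumed access pattern: the outer itervar is contiguous in memory while the
# vectorized (tiling) itervar has stride 64.
index = 64 * i_inner + i_outer

outer_stride = sympy.diff(index, i_outer)  # 1
inner_stride = sympy.diff(index, i_inner)  # 64
# Simplified form of the check: contiguous along the outer axis but strided
# along the tiling axis, so loading through a transposed local tile pays off.
needs_transpose = outer_stride == 1 and inner_stride != 1
print(needs_transpose)  # True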
+ alignas = f"alignas(std::max(std::size_t({factor}), alignof({cpp_dtype})))" + define_line = f"{alignas} {cpp_dtype} {tile_var}[{factor}*{factor}];" + self.preloads.writeline(define_line) + + load_or_store = load_or_store.replace("__place_holder__", str(tile_var)) + if is_store: + self.poststores.writeline(DeferredLine(name, load_or_store)) + else: + self.preloads.writeline(load_or_store) + + return tile_var + + def load(self, name: str, index: sympy.Expr): + var = self.args.input(name) + index = self.rename_indexing(index) + + inner = self.inner_itervar() + if self.need_vec_transpose(index): + tile_var = self.gen_transposed_tile_load_store( + name, var, index, is_store=False + ) + # vector load inside the kernel inner loop + loadbuf = f"{tile_var} + {cexpr_index(inner * self.num_elems)}" + dtype = V.graph.get_dtype(name) + line = self._get_vec_load_line(loadbuf, 0, dtype) # type: ignore[arg-type] + csevar = self.cse.generate(self.loads, line, dtype=dtype) + csevar.update_on_args("load", (self, name, index), {}) + assert isinstance(csevar, CppCSEVariable) + csevar.is_vec = True + return csevar + else: + new_index = self.transform_indexing(index) + return super().load(name, new_index) + + def store(self, name, index, value, mode=None): + assert "buf" in name + assert isinstance(value, CppCSEVariable), value + if not value.is_vec: + # this happens when we store a scalar into a vectorized buffer like "fill" + value = self.broadcast(value) + + var = self.args.output(name) + + inner = self.inner_itervar() + index = self.rename_indexing(index) + if self.need_vec_transpose(index): + tile_var = self.gen_transposed_tile_load_store( + name, var, index, is_store=True, store_mode=mode + ) + # vector store inside the kernel inner loop + storebuf = f"{tile_var} + {cexpr_index(inner * self.num_elems)}" + if self.tail_size or V.graph.get_dtype(name) in DTYPE_LOWP_FP + [ + torch.uint8, + torch.int8, + ]: + line = f"{value}.store({storebuf}, {cexpr_index(self.num_elems)});" + else: + line = f"{value}.store({storebuf});" + self.stores.writeline(DeferredLine(name, line)) + else: + new_index = self.transform_indexing(index) + super().store(name, new_index, value, mode) + + def codegen_inner_loops(self, code): + inner = self.inner_itervar() + if self.inner_is_tiling_idx: + code.writeline( + f"for (long {inner} = 0; {inner} < {cexpr_index(self.outer_num_elems)}; {inner}++)" + ) + else: + code.writeline( + f"for (long {inner} = 0; {inner} < {cexpr_index(self.inner_num_elems)}; {inner}++)" + ) + + def set_ranges(self, group, reduction_group): + vars = super().set_ranges(group, reduction_group) + # do vertical reduction as the tail loop + self.outer_idx, self.tiling_idx = ( + self.tiling_indices + if self.tiling_indices[1] < self.reduction_depth + else reversed(self.tiling_indices) + ) + if self.tiling_idx == self.tiling_indices[0]: + self.tail_size = self.outer_tail_size + self.num_elems = self.outer_num_elems + self.inner_is_tiling_idx = False + else: + self.tail_size = self.inner_tail_size + self.num_elems = self.inner_num_elems + self.inner_is_tiling_idx = True + return vars + + def transform_indexing(self, index: sympy.Expr) -> sympy.Expr: + return self.scale_index_with_offset( + index, + itervar_idx=self.outer_idx, + offset=self.inner_itervar(), + ) + + +def get_loop_body_lowp_fp(_body: LoopBody) -> tuple[Optional[torch.dtype], bool]: + """ + Returns the low precision data type (torch.float16/torch.bfloat16) contained in the nodes + and if all the nodes can codegen with this data type without converting to 
float. + Otherwise returns None and True. + """ + sub_blocks = [_body.root_block] + list(_body.subblocks.values()) + + _lowp_fp_type: Optional[torch.dtype] = None + _use_fp32 = False + for sub_block in sub_blocks: + for _node in sub_block.graph.nodes: + if _node.op == "placeholder" or _node.target in ( + "get_index", + "index_expr", + ): + continue + + # Fast path if all operations can support bf16/fp16 without converting to fp32 + if _node.target not in [ + "load", + "store", + "abs", + "neg", + "output", + ]: + _use_fp32 = True + + if hasattr(_node, "meta") and _node.meta: + assert OptimizationContext.key in _node.meta + opt_ctx: OptimizationContext = _node.meta[OptimizationContext.key] + if not opt_ctx.dtype or opt_ctx.dtype not in DTYPE_LOWP_FP: + _use_fp32 = True + elif _lowp_fp_type is not None: + if _lowp_fp_type != opt_ctx.dtype: + warnings.warn("bf16 and fp16 are mixed in the scheduler node.") + else: + _lowp_fp_type = opt_ctx.dtype + else: + _use_fp32 = True + + return _lowp_fp_type, _use_fp32 + + +class TilingSelect: + """ + Implement the heuristic to select the tiling factors and tiling indices. + In the future, we can implement advanced heuristic in a subclass. + """ + + def __init__(self): + super().__init__() + + def select_tiling( + self, + fn_list, + var_sizes_list, + ) -> tuple[list[int], list[int]]: + # TODO(jgong5): support alternative tiling factors and data types + loop_bodies = _get_loop_body(fn_list) + all_dtypes = _get_dtype_from_loopbodies(loop_bodies) + assert all_dtypes + if any(dtype not in VECTORIZABLE_DTYPES for dtype in all_dtypes): + return [], [] + dtype = torch.float + _lowp_fp_dtype = get_loop_body_lowp_fp(loop_bodies[0])[0] + if _lowp_fp_dtype and all( + (get_loop_body_lowp_fp(loop_body)[0] == _lowp_fp_dtype) + for loop_body in loop_bodies[1:] + ): + dtype = _lowp_fp_dtype + + tiling_factor = cpu_vec_isa.pick_vec_isa().nelements(dtype=dtype) + tiling_indices = self._select_tiling_indices( + fn_list, var_sizes_list, tiling_factor + ) + + if tiling_indices: + group, reduction_group = max( + var_sizes_list, key=lambda sizes: len(sizes[1]) + ) + call_ranges = tuple(group) + tuple(reduction_group) + + if config.cpp.enable_tiling_heuristics: + + def _try_get_stride( + index, + itervars, + tiling_factor, + tiling_indices, + ): + itervar = itervars[tiling_indices[0]] + stride = stride_at_vec_range(index, itervar, tiling_factor) + return stride if stride.is_number else None + + def _update_negative_op_count( + node_name, non_contig_indexing_op_counter + ): + if node_name not in non_contig_indexing_op_counter: + non_contig_indexing_op_counter[node_name] = 1 + else: + non_contig_indexing_op_counter[node_name] += 1 + + def _is_valid_indices( + itervars, + tiling_indices, + ): + return ( + len(tiling_indices) == 1 + and len(itervars) > 0 + and ( + tiling_indices[0] + if tiling_indices[0] >= 0 + else tiling_indices[0] + len(itervars) + ) + < len(itervars) + ) + + itervars = [ + sympy_index_symbol_with_prefix(SymT.XBLOCK, n) + for n in range(len(call_ranges)) + ] + reduction_depth = len(group) + vars, reduction_vars = ( + itervars[:reduction_depth], + itervars[reduction_depth:], + ) + op_counter: dict[str, int] = {} + # ops may cause overhead with vectorization, like non-contiguous + # index_expr, load, store + non_contig_indexing_op_counter: dict[str, int] = {} + for _body in loop_bodies: + sub_blocks = [_body.root_block] + list(_body.subblocks.values()) + for sub_block in sub_blocks: + for _node in sub_block.graph.nodes: + if _node.target in ["index_expr", "load", 
"store"]: + # get the index and replace prefix from z to x + arg_idx = 1 if _node.target == "index_expr" else 2 + index = sub_block.body.indexing_from_args( + (vars, reduction_vars) + )[_node.args[arg_idx].args[0]] + if _is_valid_indices(itervars, tiling_indices): + stride = _try_get_stride( + index, itervars, tiling_factor, tiling_indices + ) + if ( + stride is None + if _node.target == "index_expr" + else stride not in [0, 1] + ): + _update_negative_op_count( + _node.target, non_contig_indexing_op_counter + ) + if isinstance(_node.target, str) and not ( + _node.target.startswith("masked_subblock") + or _node.target + in ["ops", "output", "constant", "get_index"] + ): + if _node.target not in op_counter: + op_counter[_node.target] = 1 + else: + op_counter[_node.target] += 1 + + op_num = sum(op_counter.values()) + non_contig_indexing_op_num = sum( + non_contig_indexing_op_counter.values() + ) + ratio_threshold = 0.12 + quantity_threshold = 35 + if non_contig_indexing_op_num >= quantity_threshold or ( + op_num > 0 + and non_contig_indexing_op_num / op_num >= ratio_threshold + ): + # Too many non-contiguous load/store/index_expr which hurts the + # vectorization performance. Disable vectorization when exceeding + # the thresholds. + return [], [] + + if ( + not reduction_group + and group + and len(tiling_indices) == 1 + and not has_free_symbols( + [ + group[tiling_indices[0]], + ] + ) + and group[tiling_indices[0]] < tiling_factor / 4 + and op_num < 10 + ): + # We found that when the number of elements in the inner loop range is + # relatively small(< tiling_factor / 4) and the number of operations is + # not large(< 10), vectorization is not efficient. + # And found that `#pragma GCC ivdep` has better performance than + # `#pragma omp simd simdlen(8)` for these cases. 
+ return [], [] + + if dtype in DTYPE_LOWP_FP: + # For lower precision data type, if the call_range is not long enough, + # use tiling_factor // 2 for better performance + factor_lowp = cpu_vec_isa.pick_vec_isa().nelements(dtype=dtype) + for tiling_indice in tiling_indices: + if tiling_indice < 0: + tiling_indice = tiling_indice + len(call_ranges) + if tiling_indice < 0 or tiling_indice >= len(call_ranges): + continue + if has_free_symbols(call_ranges): + call_range = V.graph.sizevars.size_hint( + call_ranges[tiling_indice], fallback=0 + ) + if call_range < factor_lowp: + V.graph.sizevars.guard_lt(call_range, factor_lowp) # type: ignore[arg-type] + tiling_factor = factor_lowp // 2 + break + elif call_ranges[tiling_indice] < factor_lowp: + tiling_factor = factor_lowp // 2 + break + + if len(tiling_indices) == 1: + return [tiling_factor], tiling_indices + if len(tiling_indices) == 2: + return [tiling_factor, tiling_factor], tiling_indices + return [], [] + + def _select_tiling_indices( + self, + fn_list, + var_sizes_list, + tiling_factor, + ): + all_index = [] + for fn, var_sizes in zip(fn_list, var_sizes_list): + rw = dependencies.extract_read_writes(fn, *var_sizes) + all_index += [dep.index for dep in itertools.chain(rw.reads, rw.writes)] + contig_vars = OrderedSet[int]() + contig_vars_list = [] + non_contig_stride_const = OrderedSet[int]() + non_contig_stride_other = OrderedSet[int]() + for index in all_index: + for var in index.free_symbols: + if not re.search(r"^d\d+$", var.name): + continue + stride = stride_at_vec_range(index, var, tiling_factor) + if stride == 0: + continue + elif stride == 1: + contig_vars.add(int(var.name[1:])) + contig_vars_list.append(int(var.name[1:])) + elif all(symbol_is_type(s, SymT.SIZE) for s in stride.free_symbols): + non_contig_stride_const.add(int(var.name[1:])) + else: + non_contig_stride_other.add(int(var.name[1:])) + contig_only = contig_vars - non_contig_stride_const - non_contig_stride_other + group, reduction_group = max(var_sizes_list, key=lambda sizes: len(sizes[1])) + num_itervars = len(group) + len(reduction_group) + if len(contig_vars) == 0: + # no contiguous vars + return [num_itervars - 1] + if contig_only: + return sorted(contig_only)[-1:] + contig_and_const_stride = ( + contig_vars & non_contig_stride_const + ) - non_contig_stride_other + contig_vars_sorted = sorted(contig_vars) + if ( + len(contig_vars_sorted) == 2 + and contig_vars_sorted[-1] in contig_and_const_stride + and contig_vars_sorted[-1] == num_itervars - 1 + ): + return contig_vars_sorted + return sorted(contig_vars_sorted, key=contig_vars_list.count)[-1:] + + +class CppKernelProxy(CppKernel): + # Subclass CppKernel, CppVecKernel, etc., to customize code generation. + # Override CppOverrides or CppVecOverrides to emit custom ops. + # Earlier, this meant copying codegen_functions() to use your subclasses. + # Now, use kernel_cls and vec_kernel_cls class attributes instead. + # This lets CppKernelProxy subclasses inject custom behavior cleanly. + # No need to duplicate codegen_functions() just to swap kernel classes. 
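A minimal sketch of the extension point described in the comment above; `MyCppVecKernel` and `MyKernelProxy` are hypothetical names, and the import assumes these classes remain importable from torch/_inductor/codegen/cpp.py as in this diff:

from torch._inductor.codegen.cpp import (  # assumed import path for this module
    CppKernel,
    CppKernelProxy,
    CppTile2DKernel,
    CppVecKernel,
)

class MyCppVecKernel(CppVecKernel):
    # Hypothetical subclass; real code would override codegen hooks or the
    # `overrides` table here.
    pass

class MyKernelProxy(CppKernelProxy):
    # Swap kernel classes via class attributes instead of copying codegen_functions().
    kernel_cls = CppKernel
    vec_kernel_cls = MyCppVecKernel
    tile2d_kernel_cls = CppTile2DKernel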
+ kernel_cls: type[CppKernel] = CppKernel + vec_kernel_cls: type[CppVecKernel] = CppVecKernel + tile2d_kernel_cls: type[CppTile2DKernel] = CppTile2DKernel + + def __init__(self, kernel_group): + super().__init__(kernel_group.args, kernel_group.ws.num_threads) + self.kernel_group = kernel_group + self.loop_nest = None + self.call_ranges = None + self.picked_vec_isa: cpu_vec_isa.VecISA = cpu_vec_isa.pick_vec_isa() + self.kernels: list[CppKernel] = [] + + def data_type_propagation(self, nodes): + for _node in nodes: + assert isinstance(_node, SchedulerNode) + DataTypePropagation.propagate_scheduler_node(_node) + + # Check if all the nodes of a given fx graph can support BF16/FP16 + def is_lowp_fp_scheduler(self, scheduler_node: SchedulerNode): + if not isinstance(scheduler_node._body, LoopBody): + return True + # Propagate the dtype to check if all the fx node is bf16/fp16 + DataTypePropagation.propagate_scheduler_node(scheduler_node) + return ( + get_loop_body_lowp_fp(scheduler_node._body)[0] is not None + and not get_loop_body_lowp_fp(scheduler_node._body)[1] + ) + + def legalize_lowp_fp_dtype_loopbody(self, loop_body: LoopBody): + def add_to_dtype(sub_graph: torch.fx.Graph): + def get_input_dtype(node: torch.fx.Node) -> Optional[torch.dtype]: + """Get input dtype for nodes that may consumes lowp fp dt""" + if node.target == "store": + return V.graph.get_dtype(node.args[1]) # type: ignore[arg-type] + elif node.target == "to_dtype_bitcast": + return node.args[-1] # type: ignore[return-value] + elif node.target == "to_dtype": + if len(node.args) > 3: + return node.args[3] # type: ignore[return-value] + else: + return node.kwargs.get("src_dtype", None) # type: ignore[return-value] + else: + return None + + def get_output_dtype(node: torch.fx.Node) -> Optional[torch.dtype]: + """Get output dtype for nodes that may produce lowp fp dt""" + if node.target == "load": + assert len(node.args) == 3 + return V.graph.get_dtype(node.args[1]) # type: ignore[arg-type] + elif node.target in ["to_dtype", "constant", "index_expr"]: + return node.args[-1] # type: ignore[return-value] + elif node.target == "to_dtype_bitcast": + return node.args[2] # type: ignore[return-value] + else: + return None + + def is_lowp_fp_source(node: torch.fx.Node, dt: torch.dtype): + """Check if the given node produces output with expected low precision floating point data type.""" + assert dt in DTYPE_LOWP_FP + return get_output_dtype(node) == dt + + def is_lowp_fp_sink(node: torch.fx.Node, dt: torch.dtype): + """Check if the given node accept input with expected low precision floating point data type.""" + assert dt in DTYPE_LOWP_FP + if input_dtype := get_input_dtype(node): + return input_dtype == dt + elif node.target == "to_dtype": + # The `src_dtype` of a `to_dtype` node might miss, in which case the node accept any input dtype. 
+ return True + else: + return False + + def is_lowp_fp_source_no_promote(node: torch.fx.Node, dt: torch.dtype): + """Check if the node is a lowp fp sources which are all directly fed to ops that accepts lowp fp input + thus no need to promote to float + """ + return is_lowp_fp_source(node, dt) and all( + is_lowp_fp_sink(user, dt) for user in node.users + ) + + sub_graph_nodes = list(sub_graph.nodes) + to_lowp_fp_legalized_nodes = [] + for _node in sub_graph_nodes: + if ( + _node.target in ["load", "index_expr"] + and (dt := get_output_dtype(_node)) in DTYPE_LOWP_FP + ): + # No need to promote to float if all users are ops that accepts lowp fp input + if all(is_lowp_fp_sink(user, dt) for user in _node.users): + continue + ops = _node.args[0] + with sub_graph.inserting_after(_node): + to_type_node = sub_graph.call_method( + "to_dtype", args=(ops, _node, torch.float) + ) + _node.replace_all_uses_with( + to_type_node, lambda n: n is not to_type_node + ) + metrics.cpp_to_dtype_count += 1 + elif ( + _node.target == "store" + and (dt := get_input_dtype(_node)) in DTYPE_LOWP_FP + ): + ops, name, _, value_var, _ = _node.args + if is_lowp_fp_source_no_promote(value_var, dt): + continue + dtype = V.graph.get_dtype(name) + with sub_graph.inserting_before(_node): + to_type_node = sub_graph.call_method( + "to_dtype", args=(ops, value_var, dtype) + ) + _node.replace_input_with(value_var, to_type_node) + metrics.cpp_to_dtype_count += 1 + elif _node.target == "reduction": + ( + ops, + dtype, + src_dtype, + reduction_type, + value, + ) = _node.args + if src_dtype in DTYPE_LOWP_FP: + # Since we always convert the load/store value to float if the tensor is bfloat16/float16. + # Therefore, the reduction should never work with bfloat16/float16 value. Hence, we update + # the bfloat16/float16 reduction by + # 1) updating the src_dtype to float + # and 2) updating the dtype to float if it is bfloat16/float16. + assert dtype in [ + torch.float, + torch.bfloat16, + torch.float16, + torch.int64, + ] + _node.args = ( + ops, + torch.float if dtype in DTYPE_LOWP_FP else dtype, + torch.float, + reduction_type, + value, + ) + elif _node.target == "constant" and _node.args[-1] in DTYPE_LOWP_FP: + # No need to promote to float if all users are ops that accepts lowp fp input + (ops, value, dt) = _node.args + if all(is_lowp_fp_sink(user, dt) for user in _node.users): # type: ignore[arg-type] + continue + _node.args = (ops, value, torch.float) + elif _node.target == "to_dtype" and _node.args[-1] in DTYPE_LOWP_FP: + # No need to promote to float if all users are ops that accepts lowp fp input + (ops, x, dt) = _node.args + if all(is_lowp_fp_sink(user, dt) for user in _node.users): # type: ignore[arg-type] + continue + # The legalization always loads the BF16/FP16 tensor as FP32 for computation + # and converts back to BF16/FP16 after the computation. + # Hence, there should be no computation w/ BF16/FP16. + # Therefore, we update the to_dtype by replacing the bf16/fp16 dtype with fp32. + # Save the legalized to_dtype node for the elimination(eliminate_to_dtype step): + # 1) Eliminate the redundant to_dtype node if we have a pattern as follows: + # graph(): + # %lowp_fp_legalized = call_method[target=to_dtype](args = (%ops, %input, torch.float)) + # %to_dtype2 = call_method[target=to_dtype](args = (%ops, %lowp_fp_legalized, torch.bfloat16/float16)) + # Regarding the first to_dtype, it is redundant because + # the second to_type also converts to the torch.bfloat16/torch.float16. + # Hence, we remove the first to_type. 
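A plain-tensor illustration (made-up values, not the fx-graph rewrite itself) of the promote-to-float discipline that this legalization enforces around bf16/fp16 loads and stores:

import torch

x = torch.randn(16, dtype=torch.bfloat16)
# add_to_dtype: a bf16 "load" is promoted to fp32 before any computation...
t = x.to(torch.float32)
t = torch.sin(t) + 1.0
# ...and cast back to bf16 only at the "store".
y = t.to(torch.bfloat16)
assert y.dtype == torch.bfloat16
# eliminate_to_dtype then drops conversions made redundant by a later cast,
# e.g. the first hop in to_dtype(to_dtype(v, float32), bfloat16).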
+ to_lowp_fp_legalized_nodes.append(_node) + _node.args = (ops, x, torch.float) + elif _node.target == "to_dtype_bitcast": + (ops, value_var, dtype, src_dtype) = _node.args + + # to_dtype_bitcast act as a lowp fp sink: + # c10::bit_cast requires the source and target have the same bitwidth. Because the input tensor's + # dtype could be promoted, e.g. from float16 to float, we have to cast the tensor to its original + # source dtype before invoking bit_cast. + if src_dtype in DTYPE_LOWP_FP: + # No need to promote to float if it is a user of a lowp fp sources + # which are all directly fed to ops that accepts lowp fp input + if not is_lowp_fp_source_no_promote(value_var, src_dtype): + with sub_graph.inserting_before(_node): + to_type_node = sub_graph.call_method( + "to_dtype", args=(ops, value_var, src_dtype) + ) + _node.replace_input_with(value_var, to_type_node) + metrics.cpp_to_dtype_count += 1 + + # to_dtype_bitcast act as a lowp fp source: + # We also need to convert the bit-casted tensor back to float to make sure we keep using higher + # precision values for the rest of the computation. + if dtype in DTYPE_LOWP_FP: + # No need to promote to float if all users are ops that accepts lowp fp input + if not ( + all(is_lowp_fp_sink(user, dtype) for user in _node.users) + ): + ops = _node.args[0] + with sub_graph.inserting_after(_node): + to_type_node = sub_graph.call_method( + "to_dtype", args=(ops, _node, torch.float) + ) + _node.replace_all_uses_with( + to_type_node, lambda n: n is not to_type_node + ) + metrics.cpp_to_dtype_count += 1 + else: + pass + + def eliminate_to_dtype(sub_graph: torch.fx.Graph): + def _eliminate_duplicate_to_node(sub_graph: torch.fx.Graph): + # Eliminate the redundant to_dtype node. Let's consider a pattern as follows: + # graph(): + # %to_dtype1 = call_method[target=to_dtype](args = (%ops, %input, torch.float), kwargs = {}) + # %to_dtype2 = call_method[target=to_dtype](args = (%ops, %to_dtype1, torch.float), kwargs = {}) + # Regarding the first to_dtype, it is redundant because the second to_type also converts to the + # torch.float. Hence, we remove the first to_type + def _used_by_to(to_node: torch.fx.Node): + return all(usr.target == "to_dtype" for usr in to_node.users) + + all_to_nodes = [ + node for node in sub_graph.nodes if node.target == "to_dtype" + ] + all_to_nodes_and_users = [ + {node: node.users} for node in all_to_nodes if _used_by_to(node) + ] + for node_users in all_to_nodes_and_users: + for node, users in node_users.items(): + if node in sub_graph.nodes and ( + all(usr.args[-1] == node.args[-1] for usr in users) + or ( + node in to_lowp_fp_legalized_nodes + and all( + usr.args[-1] in DTYPE_LOWP_FP for usr in users + ) + ) + ): + val_node = node.all_input_nodes[-1] + node.replace_all_uses_with(val_node) + sub_graph.erase_node(node) + + # For debug mode, the graph of LoopBody will attach a new GraphModule as + # owning_module for debugging while the release mode will not. The lint will + # check whether the graph has owning_module to decide if it needs to check + # call_module. LoopBody might contain get_index as a module call. But it + # is just a function. Hence, it cannot pass the lint check for debug mode. + # We bypass the check if the owning_module is None. Eventually, we should call + # get_index via call_function but not call_module. 
+ if sub_graph.owning_module is None: + sub_graph.lint() + + _eliminate_duplicate_to_node(sub_graph) + + eliminate_to_dtype(sub_graph) + + sub_blocks = [loop_body.root_block] + list(loop_body.subblocks.values()) + for sub_block in sub_blocks: + add_to_dtype(sub_block.graph) + + def legalize_lowp_fp_dtype(self, nodes): + if all( + isinstance(_node, SchedulerNode) and self.is_lowp_fp_scheduler(_node) + for _node in nodes + ): + # Mark the load node to load bf16/fp16 + for _node in nodes: + sub_blocks = [_node._body.root_block] + list( + _node._body.subblocks.values() + ) + for sub_block in sub_blocks: + for fx_node in sub_block.graph.nodes: + if fx_node.target in ["load", "store"]: + assert fx_node.meta + assert OptimizationContext.key in fx_node.meta + opt_ctx: OptimizationContext = fx_node.meta[ + OptimizationContext.key + ] + assert opt_ctx.dtype in DTYPE_LOWP_FP + + # Bypass the legalization as the kernel can run with bf16/fp16 directly + return + + for _node in nodes: + assert isinstance(_node, SchedulerNode) + assert isinstance(_node._body, LoopBody) + body: LoopBody = _node._body + if not body.is_memory_copy(): + self.legalize_lowp_fp_dtype_loopbody(body) + + def codegen_functions(self, fn_list, var_sizes_list): + assert len(fn_list) == len(var_sizes_list) + kernel_group = self.kernel_group + group, reduction_group = max(var_sizes_list, key=lambda sizes: len(sizes[1])) + + self.set_ranges(group, reduction_group) + + def codegen_kernel(cls, *args): + with kernel_group.new_kernel(cls, *args) as kernel: + # Ugly hack to maintain the metrics kernel count since + # we only count in CppKernelProxy, not those contained in it + metrics.generated_kernel_count -= 1 + + run(kernel) + return kernel + + def run(kernel): + vars, reduction_vars = kernel.set_ranges(group, reduction_group) + in_suffix = False + for fn, var_sizes in zip(fn_list, var_sizes_list): + if var_sizes in [ + (group, reduction_group), + (tuple(itertools.chain(group, reduction_group)), ()), + ]: + assert not in_suffix + fn(vars, reduction_vars) + else: + in_suffix = True + assert var_sizes == ( + group, + (), + ), f"unexpected group: {var_sizes} != {group}, {reduction_group}" + # we can fuse in some extra pointwise into the suffix + with kernel.write_to_suffix(): + fn(vars, ()) + + scalar_kernel = codegen_kernel(self.kernel_cls) + V.graph.removed_buffers |= scalar_kernel.removed_buffers + V.graph.inplaced_to_remove |= scalar_kernel.inplaced_to_remove + self.loop_nest = LoopNest.build(scalar_kernel) + + if not self.picked_vec_isa or not self.itervars: + self.kernels = [scalar_kernel] + self.aggregate_reduction_buffers(False, None) + self.loop_nest.set_kernel(self) + return + + # Kernels share the same global contexts like V.graph.wrapper_code, V.kernel.args. + # But the generated scalar kernel has updated these global contexts. Hence, the other kernels + # should not do this again to avoid context conflict. By now, we only control the + # config.inplace_buffers. In the future, we could maintain more contexts. + with torch._inductor.config.patch(inplace_buffers=False): + tiling_select = TilingSelect() + tiling_factors, tiling_indices = tiling_select.select_tiling( + fn_list, var_sizes_list + ) + assert len(tiling_factors) == len(tiling_indices) + # This should be removed after full support for vectorization is implemented. 
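The tiling performed below splits each tiled loop into a vectorized main range and a tail range; a small numeric sketch with an assumed loop size of 70 and tiling factor of 16:

size, factor = 70, 16                     # assumed example values
tiled_size = (size // factor) * factor    # 64: covered by the vector kernel
main_range = (0, tiled_size)              # becomes the vec kernel's active range
tail_range = (tiled_size, size)           # becomes the tail kernel's active range
tail_size = size - tiled_size             # 6: a masked vec kernel if enabled,
                                          # otherwise the scalar kernel runs it
print(main_range, tail_range, tail_size)  # (0, 64) (64, 70) 6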
+ could_masked_vec = True + all_dtypes = _get_dtype_from_loopbodies(_get_loop_body(fn_list)) + if any(dtype not in MASKED_VECTORIZABLE_DTYPES for dtype in all_dtypes): + # can be removed after masked vectorizable dtype are same with vectorizable dtype + could_masked_vec = False + + _inner_loop_reduction_outer_not = False + _outer_loop = None + if tiling_indices: + inner_loop_reduction = False + outer_loop_level = tiling_indices[0] + inner_loop_level = outer_loop_level + 1 + if len(self.loop_nest.loops) > inner_loop_level: + inner_loop_reduction = self.loop_nest.loops[ + inner_loop_level + ].is_reduction + outer_loop_reduction = self.loop_nest.loops[ + outer_loop_level + ].is_reduction + _inner_loop_reduction_outer_not = ( + inner_loop_reduction and not outer_loop_reduction + ) + + if len(tiling_indices) == 1: + metrics.generated_cpp_vec_kernel_count += 1 + loop = self.loop_nest.tile(tiling_indices[0], factor=tiling_factors[0]) + vec_kernel = codegen_kernel( + self.vec_kernel_cls, tiling_factors[0], tiling_indices[0] + ) + tail_size = loop.size - loop.tiled_size + vec_kernel.active_ranges = {loop.var: (0, loop.tiled_size)} + if config.cpp.enable_loop_tail_vec and could_masked_vec: + tail_kernel = codegen_kernel( + self.vec_kernel_cls, + tiling_factors[0], + tiling_indices[0], + tail_size, + ) + else: + tail_kernel = scalar_kernel + scalar_kernel.inner_itervars = [loop.var] + tail_kernel.active_ranges = {loop.var: (loop.tiled_size, loop.size)} + self.kernels = [vec_kernel, tail_kernel] + _outer_loop = loop + elif len(tiling_indices) == 2: + assert ( + tiling_indices[1] == len(self.itervars) - 1 + and tiling_factors[0] == tiling_factors[1] + ) + + metrics.generated_cpp_vec_kernel_count += 2 + outer_loop = self.loop_nest.tile( + tiling_indices[0], factor=tiling_factors[0] + ) + outer_ranges = { + "main": (0, outer_loop.tiled_size), + "tail": (outer_loop.tiled_size, outer_loop.size), + } + outer_tail_size = outer_loop.size - outer_loop.tiled_size + inner_loop = self.loop_nest.tile( + tiling_indices[1], factor=tiling_factors[0] + ) + inner_ranges = { + "main": (0, inner_loop.tiled_size), + "tail": (inner_loop.tiled_size, inner_loop.size), + } + inner_tail_size = inner_loop.size - inner_loop.tiled_size + tile2d_kernel = codegen_kernel( + self.tile2d_kernel_cls, + tiling_factors[0], + tiling_indices, + ) + tile2d_kernel.active_ranges = { + outer_loop.var: outer_ranges["main"], + inner_loop.var: inner_ranges["main"], + } + tail_kernel = [] + if config.cpp.enable_loop_tail_vec and could_masked_vec: + for outer_r, inner_r in ( + ("main", "tail"), + ("tail", "main"), + ("tail", "tail"), + ): + _inner_tail_size = ( + inner_tail_size if inner_r == "tail" else None + ) + _outer_tail_size = ( + outer_tail_size if outer_r == "tail" else None + ) + kernel = codegen_kernel( + self.tile2d_kernel_cls, + tiling_factors[0], + tiling_indices, + _inner_tail_size, + _outer_tail_size, + ) + kernel.active_ranges = { + outer_loop.var: outer_ranges[outer_r], + inner_loop.var: inner_ranges[inner_r], + } + tail_kernel.append(kernel) + else: + vec_kernel = codegen_kernel( + self.vec_kernel_cls, tiling_factors[0], tiling_indices[0] + ) + vec_kernel.active_ranges = { + outer_loop.var: outer_ranges["main"], + inner_loop.var: inner_ranges["tail"], + } + vec_kernel.inner_itervars = [inner_loop.var] + tail_kernel.append(vec_kernel) + scalar_kernel.active_ranges = { + outer_loop.var: outer_ranges["tail"], + inner_loop.var: (0, inner_loop.size), + } + scalar_kernel.inner_itervars = [inner_loop.var, outer_loop.var] + 
tail_kernel.append(scalar_kernel) + self.kernels = [tile2d_kernel] + tail_kernel + _outer_loop = outer_loop + else: + self.kernels = [scalar_kernel] + self.aggregate_reduction_buffers( + _inner_loop_reduction_outer_not, _outer_loop + ) + self.loop_nest.set_kernel(self) + + def codegen_loop_bodies(self, loop_bodies, var_sizes_list): + for body in loop_bodies: + self.legalize_lowp_fp_dtype_loopbody(body) + DataTypePropagation.propagate_loopbody(body) + self.codegen_functions(loop_bodies, var_sizes_list) + + def codegen_nodes(self, nodes: list[SchedulerNode]): + # Legalize BF16 node by adding to_dtype explicitly + self.legalize_lowp_fp_dtype(nodes) + self.data_type_propagation(nodes) + assert len(nodes) >= 1 + + def fn(node, *index_vars): + node.decide_inplace_update() + node.mark_run() + if isinstance(V.kernel, NullKernelHandler): + return node._body(*index_vars) + else: + return node.codegen(index_vars) + + fn_list = [functools.partial(fn, node) for node in nodes] + + if ( + isinstance(V.local_buffer_context, LocalBufferContext) + and V.local_buffer_context.local_buffers + ): + + def wrap_fn(fn): + wrapped_fn = V.local_buffer_context.localize_function( + fn, + ) + wrapped_fn.original_fn = fn + return wrapped_fn + + fn_list = [wrap_fn(fn) for fn in fn_list] + + var_sizes_list = [node.group[1] for node in nodes] + self.codegen_functions(fn_list, var_sizes_list) + + def codegen_loops(self, code, worksharing): + self.codegen_loops_impl(self.loop_nest, code, worksharing) + + def update_stores_with_parallel_reduction(self): + for kernel in self.kernels: + kernel.update_stores_with_parallel_reduction() + + def gen_body(self, code: Optional[BracesBuffer] = None): + assert code is not None + if_prefix = "C10_LIKELY" + for kernel in self.kernels: + with contextlib.ExitStack() as stack: + if kernel.codegen_conditions(code, if_prefix): + if_prefix = "C10_UNLIKELY" + stack.enter_context(code.indent()) + code.splice(kernel.gen_body()) + + def aggregate_reduction_buffers( + self, inner_loop_reduction_outer_not: bool, outer_loop: Optional["LoopLevel"] + ): + # CppKernel/CppVecKernel/CppTile2dKernel have reduction buffers themselves. + # Here, we decide how to aggregate them together and place new reduction buffers + # under CppKernelProxy. 
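+ # For example, when the inner reduction loop is vectorized with a tiling factor
+ # of 16 and the outer (non-reduction) loop keeps a scalar tail kernel, the main
+ # kernel holds a vector accumulator while the tail kernel's tmp_acc is widened to
+ # tmp_acc_arr[16] (see aggregate_reduction_prefix_suffix below), so both kernels
+ # can accumulate under the same outer loop before the combined reduction suffix runs.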
+ def aggregate_reduction_prefix_suffix(outer_loop: "LoopLevel"): + assert len(self.kernels) >= 2 + main_loop_kernel = self.kernels[0] + tail_loop_kernel = self.kernels[-1] + assert isinstance(main_loop_kernel, self.vec_kernel_cls) + + # Prefix + if type(tail_loop_kernel) == self.kernel_cls: + # if tail loop kernel is a scalar kernel, we need to extend tmp_acc -> tmp_acc_arr[] to + # hold the temporary inner loop acc result for outer tail loop + tail_loop_kernel.finalize_reduction_prefix( + main_loop_kernel.tiling_factor + ) + main_loop_kernel.finalize_reduction_prefix() + self.reduction_prefix.splice( + tail_loop_kernel.reduction_prefix + + main_loop_kernel.reduction_prefix + ) + else: + main_loop_kernel.finalize_reduction_prefix() + self.reduction_prefix.splice(main_loop_kernel.reduction_prefix) + + # Suffix + suffix_buf = BracesBuffer() + with contextlib.ExitStack() as stack: + if main_loop_kernel.codegen_conditions( + suffix_buf, "C10_LIKELY", outer_loop.var + ): + stack.enter_context(suffix_buf.indent()) + suffix_buf.splice(main_loop_kernel.reduction_suffix) + with contextlib.ExitStack() as stack: + if tail_loop_kernel.codegen_conditions( + suffix_buf, "C10_UNLIKELY", outer_loop.var + ): + stack.enter_context(suffix_buf.indent()) + if type(tail_loop_kernel) == self.kernel_cls: + reduction_vars = tail_loop_kernel.reduction_var_names + for name in reduction_vars: + new_name = f"{name}_arr[{outer_loop.var}_tail - {cexpr_index(outer_loop.tiled_size)}]" + replace_acc_name(tail_loop_kernel.stores, name, new_name) + replace_acc_name( + tail_loop_kernel.reduction_suffix, name, new_name + ) + suffix_buf.splice( + move_code_under_inner_loop( + tail_loop_kernel.reduction_suffix, + outer_loop.var, + f"{outer_loop.var}_tail", + outer_loop.tiled_size, + outer_loop.size, + ) + ) + else: + suffix_buf.splice(tail_loop_kernel.reduction_suffix) + self.reduction_suffix = suffix_buf + + main_kernel = self.kernels[0] + if inner_loop_reduction_outer_not: + assert outer_loop + aggregate_reduction_prefix_suffix(outer_loop) + else: + main_kernel.finalize_reduction_prefix() + self.reduction_prefix.splice(main_kernel.reduction_prefix) + self.reduction_suffix.splice(main_kernel.reduction_suffix) + self.parallel_reduction_prefix.splice(main_kernel.parallel_reduction_prefix) + self.parallel_reduction_suffix.splice(main_kernel.parallel_reduction_suffix) + self.local_reduction_init.splice(main_kernel.local_reduction_init) + self.local_reduction_stores.splice(main_kernel.local_reduction_stores) + self.non_parallel_reduction_prefix.splice( + main_kernel.non_parallel_reduction_prefix + ) + self.non_parallel_reduction_suffix.splice( + main_kernel.non_parallel_reduction_suffix + ) + + +class OuterLoopFusedKernel(CppKernel): + def __init__(self, kernel_group): + super().__init__(kernel_group.args, kernel_group.ws.num_threads) + self.inner: list[LoopNest] = [] + + def decide_parallel_depth(self, max_parallel_depth, threads): + kernels_parallel_depth = [] + nested_kernels: list[CppKernel] = [ + loop_nest.get_kernel() for loop_nest in self.inner + ] + # TODO(leslie-fang-intel): only enable parallel within all outer loop levels. 
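+ # The fused kernel parallelizes up to the deepest level reported by any of the
+ # nested loop nests, capped at max_parallel_depth so the omp parallel region stays
+ # within the outer loop levels shared by all fused kernels.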
+ for kernel in nested_kernels: + # For any ScalarKernel, VecKernel, or Tile2DKernel, + # they should all have the same call_ranges + call_ranges = kernel.call_ranges + assert call_ranges is not None + kernels_parallel_depth.append( + kernel.decide_parallel_depth( + ParallelDepth( + parallel_depth=( + len(call_ranges) - max_parallel_depth.start_depth + ), + start_depth=max_parallel_depth.start_depth, + ), + threads, + ).parallel_depth + ) + return ParallelDepth( + parallel_depth=min( + max_parallel_depth.parallel_depth, max(kernels_parallel_depth) + ), + start_depth=max_parallel_depth.start_depth, + ) + + +class ReasonFusedNodes(Enum): + SAME_VARS_REDUCE = "same_vars_reduce" + COMPATIBLE_REDUCTION = "compatible_reduction" + COMPATIBLE_RANGES_NO_REDUCTION = "compatible_ranges_no_reduction" + + +class CppScheduling(BaseScheduling): + # Subclass CppKernelProxy to customize codegen without copying codegen_node(). + # Use kernel_proxy_cls to inject custom proxies in CppScheduling subclasses. + # Avoid duplicating codegen_node() just to swap in a custom kernel proxy class. + kernel_proxy_cls: type[CppKernelProxy] = CppKernelProxy + # ctypes limits the number of args to 1024, refer to: + # https://github.com/python/cpython/commit/a285af7e626d1b81cf09f8b2bf7656f100bc1237 + # We set a conservative threshold here. + MAX_FUSED_KERNEL_ARGS_NUM = 500 + backend_features = OrderedSet( + [ + BackendFeature.INPLACE_BUFFERS, + BackendFeature.REDUCE_TO_SINGLE_ELEMENT, + ] + ) + + @classmethod + def get_backend_features(cls, device: torch.device) -> OrderedSet[BackendFeature]: + return cls.backend_features + + def __init__(self, scheduler): + super().__init__(scheduler) + if scheduler: + self.reset_kernel_group() + self._ready_to_flush = False + + def _set_flush_status(self, status: bool): + self._ready_to_flush = status + + def group_fn(self, sizes): + return tuple(tuple(map(V.graph.sizevars.simplify, s)) for s in sizes) + + def reset_kernel_group(self): + self.kernel_group = KernelGroup() + + def fuse(self, node1, node2): + if node1.is_foreach() or node2.is_foreach(): + return ForeachKernelSchedulerNode.fuse(node1, node2) + elif node1.is_template(): + assert not node2.is_template() + return FusedSchedulerNode.fuse(node1, node2) + else: + if ( + self._why_fuse_nodes(node1, node2) + == ReasonFusedNodes.COMPATIBLE_RANGES_NO_REDUCTION + ): + assert isinstance(node1, (SchedulerNode, FusedSchedulerNode)) + assert isinstance(node2, (SchedulerNode, FusedSchedulerNode)) + + _, (vars1, reduce1) = node1.group + _, (vars2, reduce2) = node2.group + assert reduce1 == () and reduce2 == (), (reduce1, reduce2) + + def get_indexing_ranges_exprs(node): + if isinstance(node, FusedSchedulerNode): + assert len(node.snodes) > 0, node.snodes + var_ranges = None + indexing_exprs = OrderedSet[Any]() + for snode in node.snodes: + v, exprs = get_indexing_ranges_exprs(snode) + if var_ranges is None: + var_ranges = v + assert var_ranges == v, (var_ranges, v, node.snodes) + indexing_exprs.update(exprs) + return var_ranges, list(indexing_exprs) + else: + assert isinstance(node, SchedulerNode) + comp_buffer = node.node + assert isinstance(comp_buffer, ir.ComputedBuffer) + _, body, _ = comp_buffer.get_default_sizes_body() + return body.var_ranges, list(body.indexing_exprs.values()) + + node_to_recomp = node1 if len(vars1) < len(vars2) else node2 + assert isinstance(node_to_recomp, SchedulerNode) + + ref_node = node2 if len(vars1) < len(vars2) else node1 + + ref_indexing_constraints = get_indexing_ranges_exprs(ref_node) + + 
node_to_recomp.recompute_size_and_body( + extra_indexing_constraints=ref_indexing_constraints + ) + + _, (vars1, _) = node1.group + _, (vars2, _) = node2.group + + if vars1 == vars2: + return FusedSchedulerNode.fuse(node1, node2) + + # recompute ref_node if its ranges are also changed + node_to_recomp_indexing_constraints = get_indexing_ranges_exprs( + node_to_recomp + ) + if isinstance(ref_node, SchedulerNode): + ref_node.recompute_size_and_body( + extra_indexing_constraints=node_to_recomp_indexing_constraints + ) + else: + assert isinstance(ref_node, FusedSchedulerNode) + for snode in ref_node.snodes: + assert isinstance(snode, SchedulerNode) + snode.recompute_size_and_body( + extra_indexing_constraints=node_to_recomp_indexing_constraints + ) + ref_node = FusedSchedulerNode(ref_node.scheduler, ref_node.snodes) + + _, (vars1, _) = node1.group + _, (vars2, _) = node2.group + assert vars1 == vars2, (vars1, vars2) + return FusedSchedulerNode.fuse(node1, node2) + elif self.can_fuse_vertical_outer_loop(node1, node2): + return OuterLoopFusedSchedulerNode.fuse( + node1, node2, self._get_outer_loop_fusion_depth(node1, node2) + ) + else: + return FusedSchedulerNode.fuse(node1, node2) + + def _why_fuse_nodes(self, node1, node2) -> Optional[ReasonFusedNodes]: + _, (vars1, reduce1) = node1.group + _, (vars2, reduce2) = node2.group + + if vars1 == vars2 and reduce1 == reduce2: + return ReasonFusedNodes.SAME_VARS_REDUCE + if reduce1 == () and vars1 == vars2 + reduce2: + return ReasonFusedNodes.COMPATIBLE_REDUCTION + if self._can_fuse_nodes_with_compatible_ranges(node1, node2): + return ReasonFusedNodes.COMPATIBLE_RANGES_NO_REDUCTION + # TODO(jansel): allow fusion pointwise (vars1, ()) suffix? + return None + + def _can_fuse_nodes_with_compatible_ranges(self, node1, node2): + # Here we try to fuse SchedulerNode/FusedSchedulerNode with compatible ranges + # e.g. 
(s0, s1, s2) and (s0 * s1 * s2) + _, (vars1, reduce1) = node1.group + _, (vars2, reduce2) = node2.group + + c1 = reduce1 == () and reduce2 == () + c2 = math.prod(vars1) == math.prod(vars2) + c3 = len(vars1) == 1 or len(vars2) == 1 + if not (c1 and c2 and c3): + return False + + node_to_recomp = node1 if len(vars1) < len(vars2) else node2 + ref_node = node2 if len(vars1) < len(vars2) else node1 + + # We can not recompute sizes and body for nodes other than SchedulerNode + # TODO: we can extend fusion support with compatible ranges for FusedSchedulerNode + if isinstance(node_to_recomp, FusedSchedulerNode): + return False + + # It may happen that node1 and node2 compatible number of elements + # but different original ranges, for example: + # {d0: s0, d1: s1, d2: s2} vs {d0: s0*s1*s2} + # See https://github.com/pytorch/pytorch/pull/120077/files#r1500427848 for more details + # TODO: we can fix if it allows us to CSE at least one of the variables + + assert isinstance(node_to_recomp, SchedulerNode) + if isinstance(node_to_recomp.node, ir.TemplateBuffer): + return False + assert isinstance(node_to_recomp.node, ir.ComputedBuffer) + # node.data.get_size() is a cheaper version of node.get_read_writes().var_ranges + # but without variable name + ranges2 = node_to_recomp.node.data.get_size() + ranges1 = None + if isinstance(ref_node, FusedSchedulerNode): + ranges_set = OrderedSet[tuple[Any, ...]]() + for snode in ref_node.snodes: + if isinstance(snode.node, ir.TemplateBuffer): + break + assert isinstance(snode.node, ir.ComputedBuffer) + ranges_set.add(tuple(snode.node.data.get_size())) + + if len(ranges_set) != 1: + return False + + ranges1 = list(next(iter(ranges_set))) + else: + assert isinstance(ref_node, SchedulerNode) + assert isinstance(ref_node.node, ir.ComputedBuffer) + ranges1 = ref_node.node.data.get_size() # type: ignore[assignment] + + if ranges1 != ranges2: + return False + + return True + + def _can_fuse_horizontal_impl(self, node1, node2): + assert isinstance(node1, (FusedSchedulerNode, SchedulerNode)) + assert isinstance(node2, (FusedSchedulerNode, SchedulerNode)) + if any( + isinstance(node, OuterLoopFusedSchedulerNode) for node in (node1, node2) + ): + return False + return self._why_fuse_nodes(node1, node2) is not None + + def can_fuse_horizontal(self, node1, node2): + if node1.is_template() or node2.is_template(): + return False + if ( + len(node1.get_nodes()) + len(node2.get_nodes()) + > config.cpp.max_horizontal_fusion_size + ): + return False + + return self._can_fuse_horizontal_impl(node1, node2) + + def can_fuse_multi_outputs_template( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + if template_buf := node1.get_template_node(): + return ( + isinstance(template_buf.layout, ir.MultiOutputLayout) + and isinstance(node2.node, ir.MultiOutput) + and len(node2.node.inputs) == 1 + and node2.node.inputs[0].get_name() == template_buf.name + ) + return False + + def _get_outer_loop_fusion_depth(self, node1, node2): + DISABLE_OUTER_LOOP_FUSION = 0 + if not all( + type(node) + in (OuterLoopFusedSchedulerNode, FusedSchedulerNode, SchedulerNode) + for node in (node1, node2) + ): + return DISABLE_OUTER_LOOP_FUSION + + _node1 = ( + node1.get_outer_nodes()[-1] + if isinstance(node1, OuterLoopFusedSchedulerNode) + else node1 + ) + assert isinstance(_node1, (FusedSchedulerNode, SchedulerNode)) + _node2 = ( + node2.get_outer_nodes()[0] + if isinstance(node2, OuterLoopFusedSchedulerNode) + else node2 + ) + assert isinstance(_node2, (FusedSchedulerNode, SchedulerNode)) + 
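+ # Illustrative example: if _node1 iterates over vars (s0, s1, s2) (possibly with a
+ # trailing reduction) and _node2 iterates over (s0, s1), the longest shared prefix
+ # (s0, s1) yields an outer loop fusion depth of 2, i.e. both nodes can be generated
+ # under the same two outermost loops.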
+ _, (vars1, reduce1) = _node1.group + _, (vars2, reduce2) = _node2.group + if vars1 == () and vars2 == () and reduce1 != () and reduce2 != (): + # Reduction only + return DISABLE_OUTER_LOOP_FUSION + if all(type(node) is OuterLoopFusedSchedulerNode for node in (node1, node2)): + return ( + node1.outer_loop_fusion_depth + if node1.outer_loop_fusion_depth == node2.outer_loop_fusion_depth + else DISABLE_OUTER_LOOP_FUSION + ) + outer_loop_fusion_depth = min(len(vars1), len(vars2)) + if ( + outer_loop_fusion_depth >= 1 + and vars1[:outer_loop_fusion_depth] == vars2[:outer_loop_fusion_depth] + ): + if any( + type(node) is OuterLoopFusedSchedulerNode for node in (node1, node2) + ): + _compare_node = ( + node1 if type(node1) is OuterLoopFusedSchedulerNode else node2 + ) + if _compare_node.outer_loop_fusion_depth == outer_loop_fusion_depth: + # Same outer loop fusion depth as prev nodes in OuterLoopFusedSchedulerNode + return outer_loop_fusion_depth + else: + return DISABLE_OUTER_LOOP_FUSION + else: + # First 2 nodes to generate OuterLoopFusedSchedulerNode + return outer_loop_fusion_depth + return DISABLE_OUTER_LOOP_FUSION + + def can_fuse_vertical_outer_loop(self, node1, node2): + return ( + not node1.is_template() + and not node2.is_template() + and node1.get_operation_names() & node2.ancestors + and not ( + self._can_fuse_horizontal_impl(node1, node2) + and not node1.is_reduction() + ) + and self._get_outer_loop_fusion_depth(node1, node2) >= 1 + ) + + def get_fusion_pair_priority(self, node1, node2): + if self.can_fuse_vertical_outer_loop(node1, node2): + # Outer loop fusion with lower priority + return 1 + else: + return 0 + + def can_fuse_vertical(self, node1, node2): + if node2.is_template(): + # TODO(jgong5): support pre-op fusion with template + return False + if node1.is_template(): + template_fusion_supported, _ = template_fusion_with_epilogues_supported( + node1, [node2] + ) + return not node2.is_reduction() and template_fusion_supported + return ( + self._can_fuse_horizontal_impl(node1, node2) and not node1.is_reduction() + ) or self.can_fuse_vertical_outer_loop(node1, node2) + + def try_loop_split(self, nodes: list[SchedulerNode]): + """ + Apply loop split optimization. + When one of the indexing_exprs contains a division, we eliminate the division by splitting the loop + to avoid non-contiguous loads, subject to the following conditions: + 1. No reduction and no mudular index for all nodes. + 2. The indexing_exprs of all nodes contain only one (or more, but all the same) division, + where the divisor is an integer and not too small (the divisor > 8), the dividend is + one of the iter_vars, and this var, i.e. the dimension that needs to be split, is + contiguous in all other indexing_exprs. + + For example, if the node's var_ranges: {z0: 2, z1: 9216, z2: 960} and indexing_exprs: + {'index0': 8847360*z0 + 960*z1 + z2, 'index1': 32*z0 + (z2//30), 'index2': z2}, + we will split z2 -> 30*z2 + z3, then the node's var_ranges will be changed to + {z0: 2, z1: 9216, z2: 32, z3: 30} and indexing_exprs will be changed to + {'index0': 8847360*z0 + 960*z1 + 30*z2 + z3, 'index1': 32*z0 + z2, 'index2': 30*z2 + z3}. 
+ """ + + # No reduction and no mudular + if any( + len(node.group[1][1]) != 0 + or any( + expr.has(ModularIndexing) for expr in node._body.indexing_exprs.values() + ) + for node in nodes + ): + return nodes + + split_var = None + split_number = None + num_div = 0 + div_expr_ = None + match_div = False + matched_node = None + + for node in nodes: + assert isinstance(node.node, ir.ComputedBuffer) + _, original_body, _ = node.node.get_default_sizes_body() + for name, expr in original_body.indexing_exprs.items(): + if not isinstance(expr, sympy.Expr): + continue + for div_expr in expr.find(FloorDiv): + if ( + any(div_expr.has(var) for var in original_body.iter_vars) + and div_expr != div_expr_ + ): + div_expr_ = div_expr + num_div += 1 + if num_div > 1: + return nodes + if ( + isinstance(div_expr.args[1], sympy.core.numbers.Integer) + and div_expr.args[0] in original_body.iter_vars + and name is not None + and all( + stride_at_vec_range(expr_, div_expr.args[0]) in (0, 1) + for name_, expr_ in original_body.indexing_exprs.items() + if name_ != name + ) + and div_expr.args[1] > 8 + ): + split_var = div_expr.args[0] + split_number = div_expr.args[1] + match_div = True + matched_node = node + + # Only one node contains a division, and the split dimension is contiguous in all other indexing_exprs. + if not match_div: + return nodes + + extra_indexing_constraints = None + + def loop_split(sizes, body, vars): + index_size, reduce_size = sizes + index_vars, reduce_vars = vars + split_idx = index_vars.index(split_var) + new_index_size = index_size.copy() + new_index_size[split_idx] = index_size[split_idx] // split_number + new_index_size.insert(split_idx + 1, split_number) + (new_index_vars, _), var_ranges = dependencies.index_vars_no_squeeze( + new_index_size, reduce_size, prefix="y" + ) + iter_vars = new_index_vars.copy() + divisor_var = iter_vars.pop(split_idx + 1) + iter_vars[split_idx] = split_number * iter_vars[split_idx] + divisor_var + body = ir.LoopBody( + body, [iter_vars, reduce_vars], var_ranges, new_index_vars, reduce_vars + ) + nonlocal extra_indexing_constraints + if not extra_indexing_constraints: + extra_indexing_constraints = ( + body.var_ranges, + list(body.indexing_exprs.values()), + ) + return ( + (new_index_size, reduce_size), + body, + (new_index_vars, reduce_vars), + ) + + # Here decide the final loop order + for node in nodes: + if node == matched_node: + node.recompute_size_and_body(recompute_sizes_body_func=loop_split) + for node in nodes: + if node != matched_node: + node.recompute_size_and_body( + extra_indexing_constraints=extra_indexing_constraints, + recompute_sizes_body_func=loop_split, + ) + + return nodes + + def codegen_outer_loop_node( + self, + node: OuterLoopFusedSchedulerNode, + ): + """ + Generate the code for the outer loop fused scheduler node. + 1. Codegen with fused outer loop: depends on the analysis of + the outer loop fused scheduler node, with or without the local buffer. + 2. If failed, fallback to standard codegen. + """ + kernel_group = self.kernel_group + generated_cpp_vec_kernel_count = metrics.generated_cpp_vec_kernel_count + cpp_kernel_proxy_list: list[self.kernel_proxy_cls] = [] # type: ignore[name-defined] + nodes_list: list[list[SchedulerNode]] = [] + assert isinstance(node, OuterLoopFusedSchedulerNode) + + def try_outer_loop_fusion_with_local_buf(node: OuterLoopFusedSchedulerNode): + """ + Codegen code with fused outer loop and local Buffer. 
+ """ + assert isinstance(node, OuterLoopFusedSchedulerNode) + cpp_kernel_proxy_list.clear() + nodes_list.clear() + + def get_call_ranges(node: BaseSchedulerNode): + assert isinstance(node, (SchedulerNode, FusedSchedulerNode)) + nodes: list[SchedulerNode] = node.get_nodes() # type: ignore[assignment] + _, (group, reduction_group) = max( + nodes, key=lambda x: int(x.is_reduction()) + ).group + call_ranges = tuple(group) + tuple(reduction_group) + return call_ranges + + local_buffers: list[ir.Buffer] = [] + # Map local buffer name to a list of global buffers + local_to_global_buffers: dict[str, list[ir.Buffer]] = {} + if all( + len(get_call_ranges(_node)) == node.outer_loop_fusion_depth + 1 + for _node in node.get_outer_nodes() + ): + # Ref to the typical case of local buffer in + # https://github.com/pytorch/pytorch/blob/1115a25c36340554442f28f9570abd42f0aface2/aten/src/ATen/native/cpu/SoftMaxKernel.cpp#L159 # noqa: B950 + # where the buffer is with size of last dim and contiguous. + # Only support this typical case at first. + visited_scheduler_nodes: OrderedSet[str] = OrderedSet() + for scheduler_node in node.get_nodes(): + # all users inside same OuterLoopFusedSchedulerNode + assert isinstance(scheduler_node, SchedulerNode) + visited_scheduler_nodes.add(scheduler_node.get_name()) + if ( + scheduler_node.is_reduction() + or len(scheduler_node.get_outputs()) != 1 + ): + continue + + scheduler_buffer = scheduler_node.get_outputs()[0] + if all( + user.node in node.get_nodes() for user in scheduler_buffer.users + ): + global_buffer = scheduler_buffer.node + assert isinstance(global_buffer, ir.ComputedBuffer) + global_buffer_layout = global_buffer.get_layout() + size_offset = node.outer_loop_fusion_depth - len( + get_call_ranges(scheduler_node) + ) + + def is_all_write_read_contiguous(): + contiguous_index_expr = 0 + stride = 1 + for var, range in reversed( + scheduler_node._body.var_ranges.items() + ): + contiguous_index_expr += stride * var + stride *= range + write_index_expr = scheduler_node._body.get_write_expr( + scheduler_buffer.get_name() + ) + + def is_contiguous_index(x): + return x == contiguous_index_expr + + return is_contiguous_index(write_index_expr) and all( + isinstance(user.node, SchedulerNode) + and is_contiguous_index( + user.node._body.get_read_expr( + scheduler_buffer.get_name() + ), + ) + for user in scheduler_buffer.users + ) + + if not ( + global_buffer_layout.is_contiguous() + and is_all_write_read_contiguous() + ): + continue + # Local Buffer is a view of global buffer + local_buffer_layout = ir.FixedLayout( + global_buffer_layout.device, + global_buffer_layout.dtype, + global_buffer_layout.size[size_offset:], + global_buffer_layout.stride[size_offset:], + ) + + def try_share_local_buffer(local_buffer_layout, local_buffers): + for local_buf in local_buffers: + if local_buffer_layout == local_buf.layout and all( + all( + user.node.get_name() in visited_scheduler_nodes + for user in V.graph.scheduler.name_to_buf[ + global_buffer.name + ].users + ) + for global_buffer in local_to_global_buffers[ + local_buf.name + ] + if global_buffer.name is not None + ): + return local_buf + return None + + local_buf_prefix = "local_buffer_data" + # Share existing local buffer + local_buffer_used = try_share_local_buffer( + local_buffer_layout, local_buffers + ) + if not local_buffer_used: + # Create new local buffer + local_buffer_used = ir.Buffer( + name=f"{local_buf_prefix}_{len(local_buffers)}", + layout=local_buffer_layout, + ) + local_buffers.append(local_buffer_used) + 
local_to_global_buffers[local_buffer_used.name] = [] # type: ignore[index] + local_to_global_buffers[local_buffer_used.name].append( + global_buffer, + ) + + with LocalBufferContext(kernel_group.args) as scope: + if len(local_buffers) > 0: + for local_buffer in local_buffers: + assert local_buffer.name is not None + scope.add_local_buffer( + local_buffer, local_to_global_buffers[local_buffer.name] + ) + for _node in node.get_outer_nodes(): + assert isinstance(_node, (FusedSchedulerNode, SchedulerNode)) + cpp_kernel_proxy = self.kernel_proxy_cls(kernel_group) + cpp_kernel_proxy.codegen_nodes(_node.get_nodes()) # type: ignore[arg-type] + cpp_kernel_proxy_list.append(cpp_kernel_proxy) + nodes_list.append(_node.get_nodes()) # type: ignore[arg-type] + + if not node.check_outer_fusion_loop_level_attr( + cpp_kernel_proxy_list, node.outer_loop_fusion_depth + ): + for removed_buffer in scope.removed_buffers: + # Restore the removed buffers by this context before + # fallback to codegen without using Local Buffer + V.graph.removed_buffers.remove(removed_buffer) + return False + metrics.cpp_outer_loop_fused_inner_counts.append( + metrics.CppOuterLoopFusedCount( + len(cpp_kernel_proxy_list), + local_buffer_number=len(scope.local_buffers), + ) + ) + outer_fusion_cpp_kernel_proxy = node.merge_outer_fusion_kernels( + cpp_kernel_proxy_list, + ) + kernel_group.finalize_kernel( + outer_fusion_cpp_kernel_proxy, + [*itertools.chain.from_iterable(nodes_list)], + ) + + return True + + if not try_outer_loop_fusion_with_local_buf(node): + # Reset generated_cpp_vec_kernel_count to codegen again + metrics.generated_cpp_vec_kernel_count = generated_cpp_vec_kernel_count + cpp_kernel_proxy_list.clear() + nodes_list.clear() + # Similar as comment in + # https://github.com/pytorch/pytorch/blob/469383755fe416eb1c41fa724762ad3eaecdff07/torch/_inductor/codegen/cpp.py#L3269-L3272 + # Kernels share the same global contexts like V.graph.wrapper_code, V.kernel.args. + with torch._inductor.config.patch(inplace_buffers=False): + for _node in node.get_outer_nodes(): + assert isinstance(_node, (FusedSchedulerNode, SchedulerNode)) + _nodes: list[SchedulerNode] = _node.get_nodes() # type: ignore[assignment] + cpp_kernel_proxy = self.kernel_proxy_cls(kernel_group) + cpp_kernel_proxy.codegen_nodes(_nodes) + kernel_group.finalize_kernel(cpp_kernel_proxy, _nodes) + + def codegen_node( + self, + node: Union[OuterLoopFusedSchedulerNode, FusedSchedulerNode, SchedulerNode], + ): + """ + Turn an set of pre-fused nodes into a C++ kernel. 
+ """ + kernel_group = self.kernel_group + + if isinstance(node, OuterLoopFusedSchedulerNode): + self.codegen_outer_loop_node(node) + else: + nodes: list[SchedulerNode] = node.get_nodes() # type: ignore[assignment] + nodes = self.try_loop_split(nodes) + cpp_kernel_proxy = self.kernel_proxy_cls(kernel_group) + cpp_kernel_proxy.codegen_nodes(nodes) + kernel_group.finalize_kernel(cpp_kernel_proxy, nodes) + + args_num = self._get_scheduled_num_args() + if args_num > CppScheduling.MAX_FUSED_KERNEL_ARGS_NUM: + self._set_flush_status(True) + + def is_cpp_template(self, node: BaseSchedulerNode) -> bool: + return isinstance(node, SchedulerNode) and isinstance( + node.node, ir.CppTemplateBuffer + ) + + def codegen_template( + self, + template_node: BaseSchedulerNode, + epilogue_nodes: Sequence[BaseSchedulerNode], + prologue_nodes: Sequence[BaseSchedulerNode], + ): + """ + Codegen a CPP template, possibly with fused epilogues + """ + assert not prologue_nodes + + # remove MultiOutput from epilogue_nodes + epilogue_nodes = [ + epilogue_node + for epilogue_node in epilogue_nodes + if isinstance(epilogue_node, (SchedulerNode, FusedSchedulerNode)) + ] + # The counter cpp_templated_kernel_counter is used for verifying if a + # a templated kernel was successfully compiled in a UT + counters["inductor"]["cpp_templated_kernel_counter"] += 1 + counters["inductor"]["cpp_epilogue_fusion_counter"] += len(epilogue_nodes) + assert self.is_cpp_template(template_node), ( + "Template node passed to CppScheduler.codegen_template must be a SchedulerNode that wraps a CppTemplateBuffer" + ) + template_node = cast(SchedulerNode, template_node) + _, (_, rnumel) = template_node.group + assert rnumel == () + ctb: ir.CppTemplateBuffer = cast(ir.CppTemplateBuffer, template_node.node) + epilogue_ir_nodes: list[Optional[ir.Operation]] = [ + n.node for n in epilogue_nodes + ] + assert all(isinstance(n, ir.ComputedBuffer) for n in epilogue_ir_nodes), ( + "Epilogue nodes must all be instances of ir.ComputedBuffer" + ) + + def template_buffer_has_other_users( + template_buffer, outputs_by_name, epilogue_nodes + ): + if not epilogue_nodes: + return False + + assert template_buffer.get_name() in outputs_by_name + users = outputs_by_name[template_buffer.get_name()].users + return not all( + isinstance(user.node, BaseSchedulerNode) + and user.node.node in epilogue_nodes + for user in users + ) + + flag_template_buffer_has_other_users = template_buffer_has_other_users( + ctb, template_node.outputs_by_name, epilogue_ir_nodes + ) + kernel, render = ctb.make_kernel_render( + ctb, + flag_template_buffer_has_other_users=flag_template_buffer_has_other_users, + epilogue_nodes=epilogue_ir_nodes, + ) + with kernel: + if not is_multi_outputs_template(template_node.node): + template_node.mark_run() # type: ignore[attr-defined] + for node in epilogue_nodes: + node.mark_run() # type: ignore[attr-defined] + src_code = render() + + with V.set_kernel_handler(kernel): + node_schedule = [template_node, *epilogue_nodes] + kernel_name = self.define_kernel(src_code, node_schedule, kernel.args) + + if is_multi_outputs_template(template_node.node): + # For multi outputs template, allocate buffers for each output after the epilogue + # codegen to which determines if the buffer has been removed. 
+ assert len(template_node.outputs) == 1, ( + "Multi outputs template should be with 1 output template buffer of MultiOutputLayout" + ) + for user in template_node.outputs[0].users: + assert isinstance(user.node, ExternKernelSchedulerNode), ( + "Multi outputs template should be with ExternKernelSchedulerNode" + ) + assert isinstance(user.node.node, ir.MultiOutput), ( + "Multi outputs template has multi users with MultiOutput" + ) + user.node.mark_run() + + kernel.call_kernel(kernel_name, ctb) + V.graph.removed_buffers |= kernel.removed_buffers + self.free_buffers_in_scheduler() + + def _get_scheduled_num_args(self): + return self.kernel_group.get_num_args() + + def ready_to_flush(self): + return self._ready_to_flush + + def codegen_sync(self): + pass + + def define_kernel(self, src_code, nodes, kernel_args=None): + wrapper = V.graph.wrapper_code + fused_name = ( + get_fused_kernel_name(nodes, config.cpp.descriptive_names) + if config.cpp.descriptive_names + else "" + ) + kernel_name = "_".join(["cpp", fused_name, wrapper.next_kernel_suffix()]) + # below add provenance tracing info for cpu CppKernel types + if config.trace.enabled: + set_kernel_post_grad_provenance_tracing(nodes, kernel_name) + + kernel_decl_name = kernel_name if V.graph.cpp_wrapper else "kernel" + src_code = src_code.replace(str(Placeholder.KERNEL_NAME), kernel_decl_name) + src_code = src_code.replace(str(Placeholder.DESCRIPTIVE_NAME), kernel_name) + # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does + # not use BracesBuffer, so we have no good indicator of a C++ buffer atm. + src_code = src_code.replace("#pragma CMT", "//") + + # Get the lines in the source code representing the function definition, + # excluding the the first line including cpp_prefix.h. 
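+ # The slice from 'extern "C"' through the closing ')' of the argument list becomes
+ # the forward declaration passed below as cpp_definition; for a fused elementwise
+ # kernel it might look something like (illustrative only):
+ #   extern "C" void cpp_fused_mul_add_0(const float* in_ptr0, const float* in_ptr1, float* out_ptr0);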
+ first_char = src_code.rfind('extern "C"') + last_char = src_code.find(")", first_char) + if _IS_WINDOWS: + # get_export_declaration introduced one more ')' in Windows + last_char = src_code.find(")", last_char + 1) + kernel_definition = f"{src_code[first_char : last_char + 1]};\n" + + compile_wrapper = IndentedBuffer() + args = self.kernel_group.args if kernel_args is None else kernel_args + _, _, arg_types = args.cpp_argdefs() + if not V.graph.cpp_wrapper: + compile_wrapper.writeline(f"async_compile.cpp_pybinding({arg_types!r}, '''") + compile_wrapper.splice(src_code, strip=True) + if not V.graph.cpp_wrapper: + compile_wrapper.writeline("''')") + wrapper.define_kernel( + kernel_name, + compile_wrapper.getvalue(), + gpu=False, + cpp_definition=kernel_definition, + ) + return kernel_name + + def flush(self): + src_code = self.kernel_group.codegen_group() + if src_code: + kernel_name = self.define_kernel( + src_code, self.kernel_group.scheduled_nodes + ) + self.kernel_group.call_kernel(V.graph.wrapper_code, kernel_name) + self.reset_kernel_group() + self._set_flush_status(False) + + +class KernelGroup: + def __init__(self): + super().__init__() + self.args = KernelArgs() + self.loops_code = BracesBuffer() + self.ws = WorkSharing(self.loops_code) + self.stack = contextlib.ExitStack() + self.stack.enter_context(self.ws) + self.scheduled_nodes = [] + + def new_kernel(self, cls, *args): + return cls(self.args, parallel_num_threads(), *args) + + def finalize_kernel(self, new_kernel, nodes): + self.scheduled_nodes += nodes + code = self.loops_code + ws = self.ws + new_kernel.codegen_loops(code, ws) + + def get_num_args(self): + arg_defs, _call_args, _arg_types = self.args.cpp_argdefs() + args_num = len(arg_defs) + return args_num + + def codegen_group(self, name=None) -> str: + self.stack.close() + if not self.scheduled_nodes: + return "" + code = BracesBuffer() + # 1. Include header files + # TODO: support kernel profile on other platforms + enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [ + "linux", + "win32", + ] + if enable_kernel_profile: + code.writelines(["#include "]) + code.writeline("#include ") + + # 2. Function definition + kernel_decl_name = str(Placeholder.KERNEL_NAME) if name is None else name + kernel_name = str(Placeholder.DESCRIPTIVE_NAME) if name is None else name + arg_defs, _, _ = self.args.cpp_argdefs() + arg_defs = ",\n".ljust(25).join(arg_defs) + func_export_decl = get_export_declaration() + code.writeline( + f'extern "C" {func_export_decl} void {kernel_decl_name}({arg_defs})' + ) + + # 3. 
Function body + with code.indent(): + if enable_kernel_profile: + graph_id = V.graph.graph_id + prefix = "graph_" + str(graph_id) + "_" if graph_id is not None else "" + code.writelines( + [ + f'RECORD_FUNCTION("{prefix + kernel_name}", c10::ArrayRef({{}}));' + ] + ) + for old, new in self.args.aliases(): + code.writeline(f"auto {old} = {new};") + code.splice(self.loops_code) + return code.getvalue() + + def call_kernel(self, wrapper, kernel_name): + _, call_args, arg_types = self.args.cpp_argdefs() + wrapper.generate_kernel_call( + kernel_name, call_args, triton=False, arg_types=arg_types + ) + + +class WorkSharing: + def __init__(self, code): + self.code = code + self.in_parallel = False + self.num_threads = None + self.stack = contextlib.ExitStack() + + def parallel(self, threads): + if self.in_parallel and threads != self.num_threads: + # wrong number of threads + self.close() + if not self.in_parallel: + self.num_threads = threads + self.in_parallel = True + if config.cpp.dynamic_threads: + self.code.writeline("#pragma omp parallel") + else: + self.code.writeline(f"#pragma omp parallel num_threads({threads})") + self.stack.enter_context(self.code.indent()) + self.code.writeline( + "int tid = omp_get_thread_num();", + ) + + def single(self): + if self.in_parallel: + self.code.writeline("#pragma omp single") + return self.in_parallel + + def close(self): + self.stack.close() + self.in_parallel = False + + def __enter__(self): + self.stack.__enter__() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.stack.__exit__(exc_type, exc_val, exc_tb) + + +@dataclasses.dataclass +class LoopLevel: + var: Optional[sympy.Expr] = None + size: Optional[sympy.Expr] = None + offset: sympy.Expr = sympy.S.Zero + # Note [tiled_size] + # We may do loop-tiling at this loop level. + # When var is in [offset, tiled_size), we will perform the vectorization kernel. + # When var is in [tiled_size, size), we will perform the scalar or masked vectorization kernel. + # for (var = offset; var < size; var += steps) { + # if (var >= offset && var < tiled_size) vec_loop_body(); + # if (var >= tiled_size && var < size) scalar_or_maskvec_loop_body(); + # } + tiled_size: sympy.Expr = sympy.S.Zero + steps: sympy.Expr = sympy.S.One + parallel: int = 0 + simd_omp: bool = False + simd_vec: bool = False + collapsed: bool = False + is_reduction: bool = False + + def __post_init__(self): + # Regarding the C++/OpenMP backend, `cpu_vec_isa.pick_vec_isa()` to check + # vectorization ISA is a time-consuming and one-shot operation. It leads + # to taking a longer time to import `codegen.cpp` package because the + # `LoopLevel` of the package is decorated by `@dataclasses.dataclass` while + # the decorator will invoke `cpu_vec_isa.pick_vec_isa()` to initialize the + # `simd_nelements` of the `LoopLevel`. It might introduce additional compilation + # overhead to the Triton backend. 
Therefore, we moved the `simd_nelements` to + # `__post_init__` + picked_vec_isa: cpu_vec_isa.VecISA = cpu_vec_isa.pick_vec_isa() + self.simd_nelements: int = picked_vec_isa.nelements() if picked_vec_isa else 0 + + def tile(self, factor): + sympy_factor = sympy.Integer(factor) + loop = LoopLevel(self.var, self.size) + loop.steps = sympy_factor + loop.simd_vec = True + loop.tiled_size = FloorDiv(loop.size, sympy_factor) * sympy_factor + loop.parallel = self.parallel + loop.collapsed = False + loop.is_reduction = self.is_reduction + return loop + + def lines(self): + offset_expr = cexpr_index(self.offset) + size_expr = cexpr_index(self.size) + if config.cpp.no_redundant_loops and offset_expr == size_expr: + return None + simd = ( + f"simd simdlen({self.simd_nelements}) " + if self.simd_omp and self.simd_nelements > 1 + else "" + ) + if self.parallel: + # TODO(jansel): look into chunk size and other schedules + line1 = "#pragma omp for" + if self.parallel > 1: + line1 += f" collapse({self.parallel})" + if self.simd_omp: + line1 = line1.replace(" for ", f" for {simd}") + elif self.simd_vec: + line1 = "" + elif self.simd_omp: + line1 = f"#pragma omp {simd}" + elif not self.is_reduction and cpp_builder.is_gcc(): + line1 = "#pragma GCC ivdep" + else: + line1 = "" + offset_str = f"{INDEX_TYPE} {self.var}={offset_expr}" + size_str = f"{self.var}<{size_expr}" + if self.steps.is_number: + steps_str = f"{self.var}+={cexpr_index(self.steps)}" + else: + # If the step size is 0, change it to 1 because a step size of 0 + # will cause floating point exception (core dump) during parallelization. + steps_str = ( + f"{self.var}+=({cexpr_index(self.steps)} == 0 ? " + f"1 : {cexpr_index(self.steps)})" + ) + line2 = f"for({offset_str}; {size_str}; {steps_str})" + if self.collapsed or not line1: + return [line2] + return [line1, line2] + + +@dataclasses.dataclass +class LoopNest: + """ + A loop-nest-like structure. It is built with the `build` method + as a loop nest and then will perform loop-tiling at some depth. + + A typical case is for vectorization, where we typically do loop-tiling + at the innermost loop level. A more complicated case is when we do + 2D tiling at both the innermost and outer levels. + """ + + loops: Optional[list[LoopLevel]] = None + kernel: Optional[CppKernel] = None + + @staticmethod + def build(kernel: CppKernel): + """Build a LoopNest with the given `kernel` as the leaf""" + itervars = kernel.itervars + ranges = kernel.ranges + reduction_depth = kernel.reduction_depth + assert reduction_depth is not None + + loops: Optional[list[LoopLevel]] = None + for loop_idx, (var, size) in enumerate(zip(itervars, ranges)): + loop = LoopLevel(var, size) + if not loops: + loops = [loop] + else: + loops.append(loop) + if loop_idx >= reduction_depth: + loop.is_reduction = kernel.is_reduction + + loop_nest = LoopNest(loops) + return loop_nest + + def __bool__(self): + return bool(self.loops) + + @cache_on_self + def max_parallel_depth(self): + """ + Maximal allowed depth for parallelism: All reduction or non-reduction levels. + When the range of the first inner loop beyond the maximum parallel depth is much + larger than the range of all outer loops within the maximum parallel depth, + change the starting depth of parallelism to the first inner loop and recalculate + the maximum parallel depth. 
+ """ + if self.loops is None: + return ParallelDepth(parallel_depth=0, start_depth=0) + + start_depth = 0 + max_depth = 0 + is_reduction = self.loops[0].is_reduction + num_steps = sympy.Integer(1) + for loop in self.loops: + if loop.is_reduction != is_reduction: + break + num_steps = num_steps * FloorDiv(loop.size, loop.steps) + max_depth += 1 + + def get_simd_vec_depth(loops): + # Return the first loop level which is simd_vec + for i, loop in enumerate(loops): + if loop.simd_vec: + return i + return None + + simd_vec_depth = get_simd_vec_depth(self.loops) + + # When the number of steps of the first inner loop is much larger than the number of steps of + # all outer loops, change `start_depth` to the first inner loop and recalculate `max_depth`. + if ( + max_depth < len(self.loops) + and isinstance(num_steps, sympy.Integer) + and isinstance(self.loops[max_depth].size, sympy.Integer) + and num_steps * 300 + < FloorDiv(self.loops[max_depth].size, self.loops[max_depth].steps) + and not ( + # Disable parallel reduction under the vec loop + simd_vec_depth is not None + and max_depth > simd_vec_depth + and self.loops[max_depth].is_reduction + ) + ): + start_depth = max_depth + max_depth = 0 + is_reduction = self.loops[start_depth].is_reduction + for i in range(start_depth, len(self.loops)): + if self.loops[i].is_reduction != is_reduction: + break + max_depth += 1 + return ParallelDepth(parallel_depth=max_depth, start_depth=start_depth) + + def mark_parallel(self, par_depth): + assert par_depth.parallel_depth <= self.max_parallel_depth().parallel_depth, ( + "Parallel depth cannot exceed the maximal allowed parallel depth" + ) + assert self.loops is not None + assert len(self.loops) >= par_depth.parallel_depth + loop = self.loops[par_depth.start_depth] + loop.parallel = par_depth.parallel_depth + if loop.is_reduction: + metrics.parallel_reduction_count += 1 + for i in range(par_depth.start_depth + 1, par_depth.parallel_depth): + self.loops[i].collapsed = True + + def tile(self, depth, factor): + """ + Do loop-tiling at the `depth` level with `factor`. + for (x0 = 0; x0 < x0_end; x0++) + -> + for (x0 = 0; x0 < x0_end; x0 += factor) + See details in Note [tiled_size]. + """ + assert self.loops + self.loops[depth] = self.loops[depth].tile(factor) + return self.loops[depth] + + def get_kernel(self) -> CppKernel: + assert self.kernel + return self.kernel + + def set_kernel(self, kernel): + self.kernel = kernel + + def from_loop_level(self, level: int): + assert self.loops + assert len(self.loops) >= level + loops = None if level == len(self.loops) else self.loops[level:] + return LoopNest(loops, self.kernel) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_bmm_template.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_bmm_template.py new file mode 100644 index 0000000000000000000000000000000000000000..cbb0ee97d6c6229a88353589a1a8e06817f3a2a0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_bmm_template.py @@ -0,0 +1,262 @@ +# mypy: allow-untyped-defs +import contextlib +import itertools +from typing import Any, Callable, Optional +from unittest.mock import patch + +import sympy + +from .. 
import ir +from ..select_algorithm import PartialRender +from ..virtualized import V +from .common import ArgName +from .cpp_gemm_template import CppGemmTemplate, GEMM_TEMPLATE +from .cpp_micro_gemm import LayoutType +from .cpp_template_kernel import CppTemplateKernel +from .cpp_utils import DTYPE_TO_CPP, GemmBlocking + + +# We pass all sizevars present in BY to the GEMM templates so variables are not renamed in the BMM definition +GEMM_SINGLE_THREAD_MM_STUB = r""" +{{kernel.def_kernel( + inputs={"X": X, "W": W}, + outputs={"Y": Y_2d}, + aliases=aliases, + function_name=kernel_name+"_single_thread_mm", + extra_sizevars=BY_sizevars + [b_index], + placeholder="")}}""" + +GEMM_THREADED_MM_STUB = r""" +{{kernel.def_kernel( + inputs={"X": X, "W": W}, + outputs={"Y": Y_2d}, + aliases=aliases, + function_name=kernel_name+"_threaded_mm", + extra_sizevars=BY_sizevars + [b_index], + placeholder="")}}""" + +BMM_TEMPLATE = r""" +{{ template.codegen_microkernel_def() }} +{{ template.codegen_single_thread_gemm() }} +{{ template.codegen_multi_thread_gemm() }} + +extern "C" +{{kernel.def_kernel(inputs={"X": BX, "W": BW}, outputs={"Y": BY}, aliases=aliases)}} +{ + const int64_t B = {{kernel.size(BY_2d, 0)}}; + {%- if num_threads > 1 %} + constexpr int64_t num_threads = {{num_threads}}; + int64_t B_single_thread_block = (B / num_threads) * num_threads; + + #pragma omp parallel for num_threads({{num_threads}}) + {%- else %} + int64_t B_single_thread_block = B; + {%- endif %} + for (int64_t b_start = 0; b_start < B_single_thread_block; ++b_start) { + {{template.get_gemm_function_call( + kernel, + kernel_name+"_single_thread_mm", + "", + b_index="b_start", + )}} + } + for (int64_t b_start = B_single_thread_block; b_start < B; ++b_start) { + {{template.get_gemm_function_call( + kernel, + kernel_name+"_threaded_mm", + "", + b_index="b_start", + )}} + } +} +""" + + +class CppBmmTemplate(CppGemmTemplate): + def __init__( + self, + input_nodes, + layout: ir.Layout, + num_threads: int, + register_blocking: GemmBlocking, + beta=1, + alpha=1, + has_bias=False, + epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None, + should_block_weights: bool = False, + name="bmm", + ): + """ + In order to simplify the implementation and increase code reuse, the BMM template implements + two versions of the GEMM kernel: a single-threaded version and a multi-threaded version. + GEMM kernels are called in a loop over the batch dimension, with single-threaded GEMM calls + for all but the last (B % num_threads), which are handled by the multi-threaded GEMM kernel. + + We use an extra sizevar `b_index` to index the batch dimension, which we pass into the GEMM + template as a sympy.Symbol. This allows us to slice the 3D batch tensors in the GEMM template + without any changes to the GEMM template itself. 
+ """ + super().__init__( + input_nodes, + layout, + num_threads, + register_blocking, + beta=beta, + alpha=alpha, + has_bias=has_bias, + epilogue_creator=epilogue_creator, + should_block_weights=should_block_weights, + name=name, + ) + self.b_index = sympy.Symbol("s_b_index", integer=True, nonnegative=True) + + @staticmethod + def get_padded_size(n, block_n, k, should_block_weight): + if should_block_weight: + # Tensor is constant or not contiguous, so we will pad and block + new_size, padded_n = CppGemmTemplate.get_padded_size( + n, block_n, k, should_block_weight + ) + # Add the new batch dimension + new_size.insert(0, -1) + return new_size, padded_n + else: + new_size = [-1, k, n] + return new_size, n + + @staticmethod + def check_if_block_weight(W, micro_gemm): + assert isinstance(W, ir.IRNode) + _, n = W.get_size()[-2:] + result = ( + not W.get_layout().is_contiguous() + or W.get_name() in V.graph.constants + or ( + n % micro_gemm.register_blocking.block_n != 0 + and micro_gemm.get_b_layout != LayoutType.NORMAL + ) + ) + return result + + def get_gemm_function_call( + self, + kernel: CppTemplateKernel, + function_name: str, + placeholder: str, + b_index: str, + ) -> str: + """ + Similar to 'def_kernel' in cpp_template_kernel, but instead of generating a function definition, + generate a function call for the GEMM kernel. + Args: + placeholder: The string to replace the function call with + b_index: The index for slicing the 3D batch tensors + """ + + def hook(): + arg_defs, call_args, _, _ = kernel.args.python_argdefs() + for i, buf in enumerate(call_args): + if buf == self.b_index: + arg_defs[i] = ArgName(b_index) + call = f"{function_name}({', '.join(x.full_name() for x in arg_defs)});" + return call + + assert placeholder not in kernel.render_hooks + kernel.render_hooks[placeholder] = hook + return placeholder + + def get_default_reindexers(self, epilogue_nodes): + def reindexer(args): + # if epilogue nodes exist, they have 3D ranges but args are 2D, so add 0 index + return [self.b_index] + args + + return [reindexer] * len(epilogue_nodes) + + def get_options( + self, + kernel: CppTemplateKernel, + template_buffer_node: Optional[ir.CppTemplateBuffer] = None, + flag_template_buffer_has_other_users: Optional[bool] = None, + epilogue_nodes: Optional[list[ir.IRNode]] = None, + **kwargs, + ) -> dict[str, Any]: + options = super().get_options( + kernel=kernel, + template_buffer_node=template_buffer_node, + flag_template_buffer_has_other_users=flag_template_buffer_has_other_users, + epilogue_nodes=epilogue_nodes, + **kwargs, + ) + + BX, BW, BY = options["X"], options["W"], options["Y"] + options["BX"], options["BW"], options["BY"] = BX, BW, BY + options["BY_2d"] = options["Y_2d"] + for kword in ["X", "W", "GemmOut", "Y_2d"]: + options[kword] = kernel.select(options[kword], 0, self.b_index) + for kword in ["X", "W", "Y_2d"]: + options[kword + "_dtype"] = DTYPE_TO_CPP[options[kword].dtype] + options["b_index"] = self.b_index + options["BY_sizevars"] = [ + s + for sym in itertools.chain(BY.get_size(), BY.get_stride()) + if isinstance(sym, sympy.Expr) + for s in sym.free_symbols + ] + options["kernel_name"] = kernel.kernel_name + + return options + + def render( # type: ignore[override, return] + self, + kernel: CppTemplateKernel, + template_buffer_node: Optional[ir.CppTemplateBuffer] = None, + flag_template_buffer_has_other_users: Optional[bool] = None, + epilogue_nodes: Optional[list[ir.IRNode]] = None, + **kwargs, + ) -> str: + options = self.get_options( + kernel=kernel, + 
template_buffer_node=template_buffer_node, + flag_template_buffer_has_other_users=flag_template_buffer_has_other_users, + epilogue_nodes=epilogue_nodes, + **kwargs, + ) + self.render_options = options + + with contextlib.ExitStack() as stack: + for buf in options["fake_buffers"]: + stack.enter_context( + patch.object(V.graph, "get_dtype", self._fake_get_dtype(buf)) + ) + result = self._template_from_string(BMM_TEMPLATE).render(**options) + + # Finalize the function definitions for the gemm routines + sub_mm_hooks = { + name: hook + for name, hook in kernel.render_hooks.items() + if "FOR_BMM" in name + } + result = PartialRender(result, sub_mm_hooks).finalize_all() + for name in sub_mm_hooks: + del kernel.render_hooks[name] + del kernel.args.sizevars[options["b_index"]] + return result + + def codegen_single_thread_gemm(self): + stub = self._template_from_string(GEMM_SINGLE_THREAD_MM_STUB).render( + self.render_options + ) + return stub + self._template_from_string(GEMM_TEMPLATE).render( + {**self.render_options, "num_threads": 1} + ) + + def codegen_multi_thread_gemm(self): + stub = self._template_from_string(GEMM_THREADED_MM_STUB).render( + self.render_options + ) + return stub + self._template_from_string(GEMM_TEMPLATE).render( + self.render_options + ) + + def codegen_gemm_stub_def(self): + return "" diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_flex_attention_template.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_flex_attention_template.py new file mode 100644 index 0000000000000000000000000000000000000000..5081e2ad9f61dc04e5552bed0cde2e8f3e3d6530 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_flex_attention_template.py @@ -0,0 +1,1081 @@ +# mypy: allow-untyped-defs +import contextlib +import logging +import re +from typing import Optional +from unittest.mock import patch + +import sympy + +import torch +import torch.utils + +from ...utils._ordered_set import OrderedSet +from .. 
import ir +from ..ir import TensorBox +from ..select_algorithm import DataProcessorTemplateWrapper +from ..utils import parallel_num_threads +from ..virtualized import V +from .cpp_template import CppTemplate +from .cpp_utils import GemmBlocking + + +log = logging.getLogger(__name__) + +# TODO: reuse cpp codegen to generate below pointwise/reduction kernels +SOFTMAX_FUSIONS = r""" +// 1) out = exp(a - val) +// 2) val = sum(out) +template +inline void {{kernel_name}}_exp_reduce_sum_fusion_kernel( + T1* a, + const int& size, + T2* out, + T1& val) { + auto vec_size = at::vec::Vectorized::size(); + auto vec_max = at::vec::Vectorized(val); + T1 tmp_sum = 0; + auto vec_tmp_sum = at::vec::Vectorized(tmp_sum); + for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(a + i); + auto tmp1 = tmp0 - vec_max; + auto tmp2 = tmp1.exp_u20(); + vec_tmp_sum += tmp2; + at::native::_store(out + i, tmp2); + } + tmp_sum = at::vec::vec_reduce_all( + [](at::vec::Vectorized& x, at::vec::Vectorized& y) { + return x + y; + }, + vec_tmp_sum); + for (long i = vec_size * (size / vec_size); i < size; i++) { + auto tmp0 = a[i]; + auto tmp1 = tmp0 - val; + auto tmp2 = exp(tmp1); + tmp_sum += tmp2; + out[i] = tmp2; + } + val = tmp_sum; +} + +// 1) out = a * scale +// 2) max = max(out) +template +inline void {{kernel_name}}_mul_reduce_max_fusion_kernel( + const scalar_t* a, + const scalar_t& scale, + const int& size, + scalar_t* out, + scalar_t& max) { + auto vec_size = at::vec::Vectorized::size(); + auto vec_scale = at::vec::Vectorized(scale); + scalar_t tmp_max = -std::numeric_limits::infinity(); + auto vec_tmp_max = at::vec::Vectorized(tmp_max); + for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(a + i); + auto tmp1 = tmp0 * vec_scale; + vec_tmp_max = at::vec::maximum(vec_tmp_max, tmp1); + at::native::_store(out + i, tmp1); + } + for (long i = vec_size * (size / vec_size); i < size; i++) { + auto tmp0 = a[i]; + auto tmp1 = tmp0 * scale; + tmp_max = std::max(tmp_max, tmp1); + out[i] = tmp1; + } + max = std::max( + tmp_max, + at::vec::vec_reduce_all( + [](at::vec::Vectorized& x, at::vec::Vectorized& y) { + return at::vec::maximum(x, y); + }, + vec_tmp_max)); +} + +template +static inline scalar_t* {{kernel_name}}_conditional_data_ptr(scalar_t* ptr, scalar_t* ptr2) { + TORCH_CHECK(ptr2 == nullptr); + return ptr; +} + +template , int> = 0> +static inline scalar_t* {{kernel_name}}_conditional_data_ptr(float* ptr, scalar_t* ptr2) { + return ptr2; +} + +template +inline void {{kernel_name}}_fill_stub(scalar_t* data, scalar_t val, int64_t size) { + using Vec = at::vec::Vectorized; + Vec data_vec = Vec(val); + int64_t d = 0; + for (; d < size - (size % Vec::size()); d += Vec::size()) { + data_vec.store(data + d); + } + #if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE) + # pragma unroll + #endif + for (; d < size; d++) { + data[d] = val; + } +} + +// out = a * scale +template +inline void {{kernel_name}}_mul_scale_kernel( + scalar_t* a, + scalar_t scale, + int64_t size) { + auto vec_size = at::vec::Vectorized::size(); + auto vec_scale = at::vec::Vectorized(scale); + for (int64_t i = 0; i < vec_size * (size / vec_size); i += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(a + i); + auto tmp1 = tmp0 * vec_scale; + at::native::_store(a + i, tmp1); + } + for (int64_t i = vec_size * (size / vec_size); i < size; i++) { + auto tmp0 = a[i]; + auto tmp1 = tmp0 * scale; + a[i] = tmp1; + } +} + +""" + 
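+# Rough reference for the fused softmax helpers above (illustration only): together,
+# {{kernel_name}}_mul_reduce_max_fusion_kernel and {{kernel_name}}_exp_reduce_sum_fusion_kernel
+# implement the two passes of a numerically stable softmax over a row of scores a:
+#   m   = max_j(a_j * scale)
+#   p_j = exp(a_j * scale - m),  s = sum_j p_j
+# The helpers themselves do not normalize; the accumulated sum s is consumed later by
+# the surrounding attention template.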
+BRGEMM_PACK_FUNCTIONS = r""" +template +inline void {{kernel_name}}_copy_value_with_pad( + const scalar_t* value_ptr, + scalar_t* dst_ptr, + int64_t rows, + int64_t cols, + int64_t prows, + int64_t pcols, + int64_t ldi) { + auto vec_size = at::vec::Vectorized::size(); + int64_t i = 0; + for (; i < rows; i++) { + int64_t j = 0; + for (; j < cols - (cols % vec_size); j += vec_size) { + auto vec_v = + at::vec::Vectorized::loadu(value_ptr + i * ldi + j); + vec_v.store(dst_ptr + i * pcols + j); + } + + if (j < cols) { + auto vec_v = at::vec::Vectorized::loadu( + value_ptr + i * ldi + j, cols - j); + vec_v.store(dst_ptr + i * pcols + j, cols - j); + } + + // col padding + auto psize = pcols - cols; + if (psize > 0) { + auto zero_vec = at::vec::Vectorized(0); + int64_t pj = 0; + for (; pj < psize - (psize % vec_size); pj += vec_size) { + zero_vec.store(dst_ptr + i * pcols + cols + pj); + } + if (pj < psize) { + zero_vec.store(dst_ptr + i * pcols + cols + pj, psize - pj); + } + } + } + // row padding + for (; i < prows; i++) { + auto zero_vec = at::vec::Vectorized(0); + int64_t j = 0; + for (; j < pcols - (pcols % vec_size); j += vec_size) { + zero_vec.store(dst_ptr + i * pcols + j); + } + if (j < pcols) { + zero_vec.store(dst_ptr + i * pcols + j, pcols - j); + } + + } +} +""" + +MICRO_GEMM_TEMPLATE = r""" +GEMM_DEFINE +""" + +ALLOCATE_BUFFER = r""" + int64_t {{buffer_name}}_dtype_itemsize = c10::is_reduced_floating_point_v<{{buffer_dtype}}> ? 2 : 4; + auto& {{buffer_name}}_allocator = *at::getCPUAllocator(); + auto {{buffer_name}}_work_data = {{buffer_name}}_allocator.allocate({{buffer_size}}*{{buffer_name}}_dtype_itemsize); + void* {{buffer_name}}_data_ptr = {{buffer_name}}_work_data.get(); + {{buffer_dtype}}* {{buffer_name}} = ({{buffer_dtype}}*){{buffer_name}}_data_ptr; +""" + +FLEX_ATTENTION_TEMPLATE = r""" +{{template.header().getvalue()}} +#include +#include +#include +{{template.codegen_micro_gemm(kernel.kernel_name)}} +{{template.codegen_softmax_fusion(kernel.kernel_name)}} +{{template.codegen_brgemm_pack_function(kernel.kernel_name)}} +{%- set kernel_args = {"query": query, "key": key, "value": value, + "kv_num_blocks": kv_num_blocks, "kv_indices": kv_indices, + "full_kv_num_blocks": full_kv_num_blocks, "full_kv_indices": full_kv_indices } %} +{%- set kernel_args = template.update_kernel_args(kernel_args) %} + +extern "C" +{{kernel.def_kernel(inputs=kernel_args, outputs={"output": output}, extra_sizevars=template.extra_sizevars)}} +{ + {{ kernel.maybe_codegen_profile() }} + int64_t qBlockSize = {{qBlockSize}}; + int64_t kvBlockSize = {{kvBlockSize}}; + int64_t num_thread = {{num_thread}}; + + // dtypes of kernel and internal buffers + using scalar_t = {{kernel.dtype(query)}}; + constexpr bool is_reduced_type = c10::is_reduced_floating_point_v; + using accum_t = at::opmath_type<{{kernel.dtype(query)}}>; + using Vec = at::vec::Vectorized; + accum_t scaling_factor = {{scale}}; + int64_t batchSize = {{kernel.size(query, 0)}}; + int64_t qSize = {{kernel.size(query, 1)}}; + int64_t num_head = {{kernel.size(query, 2)}}; + int64_t headSize = {{kernel.size(query, 3)}}; + int64_t batchSize_k = {{kernel.size(key, 0)}}; + int64_t num_head_k = {{kernel.size(key, 2)}}; + int64_t headSize_v = {{kernel.size(value, 3)}}; + bool is_broadcast_bs_kv = batchSize != batchSize_k; + bool is_broadcast_head_kv = num_head != num_head_k; + int64_t gqa_shards = num_head / num_head_k; + int64_t bs_shards = batchSize / batchSize_k; + + int64_t batchSize_kvi = {{kernel.size(kv_indices, 0)}}; + int64_t num_head_kvi = 
{{kernel.size(kv_indices, 1)}}; + int64_t block_num_kvi = {{kernel.size(kv_indices, 3)}}; + bool is_broadcast_bs_kvi = batchSize != batchSize_kvi; + bool is_broadcast_head_kvi = num_head != num_head_kvi; + int64_t gqa_shards_kvi = num_head / num_head_kvi; + int64_t bs_shards_kvi = batchSize / batchSize_kvi; + + int64_t kviStrideB = {{kernel.stride(kv_indices, 0)}}; + int64_t kviStrideH = {{kernel.stride(kv_indices, 1)}}; + int64_t kviStrideQ = {{kernel.stride(kv_indices, 2)}}; + + int64_t num_kviStrideB = {{kernel.stride(kv_num_blocks, 0)}}; + int64_t num_kviStrideH = {{kernel.stride(kv_num_blocks, 1)}}; + +{%- if has_full_kv_block %} + int64_t full_kviStrideB = {{kernel.stride(full_kv_indices, 0)}}; + int64_t full_kviStrideH = {{kernel.stride(full_kv_indices, 1)}}; + int64_t full_kviStrideQ = {{kernel.stride(full_kv_indices, 2)}}; + + int64_t full_num_kviStrideB = {{kernel.stride(full_kv_num_blocks, 0)}}; + int64_t full_num_kviStrideH = {{kernel.stride(full_kv_num_blocks, 1)}}; + auto full_kv_indices_data = full_kv_indices; + auto full_kv_num_blocks_data = full_kv_num_blocks; +{%- endif %} + + auto kv_num_blocks_data = kv_num_blocks; + auto kv_indices_data = kv_indices; + + // Strides + int64_t qStrideB = {{kernel.stride(query, 0)}}; + int64_t qStrideM = {{kernel.stride(query, 1)}}; + int64_t qStrideH = {{kernel.stride(query, 2)}}; + int64_t kStrideB = {{kernel.stride(key, 0)}}; + int64_t kStrideN = {{kernel.stride(key, 1)}}; + int64_t kStrideH = {{kernel.stride(key, 2)}}; + int64_t vStrideB = {{kernel.stride(value, 0)}}; + int64_t vStrideN = {{kernel.stride(value, 1)}}; + int64_t vStrideH = {{kernel.stride(value, 2)}}; + int64_t oStrideB = {{kernel.stride(output, 0)}}; + int64_t oStrideM = {{kernel.stride(output, 2)}}; + int64_t oStrideH = {{kernel.stride(output, 1)}}; + + int64_t kvSize = {{kernel.size(key, 1)}}; + + int64_t qSplitSize = qBlockSize; + int64_t kvSplitSize = kvBlockSize; + + + qSplitSize = qSplitSize > qSize ? qSize : qSplitSize; + kvSplitSize = kvSplitSize > kvSize ? kvSize : kvSplitSize; + int64_t qSlice = (qSize + qSplitSize - 1) / qSplitSize; + int64_t kvSlice = (kvSize + kvSplitSize - 1) / kvSplitSize; + int64_t kvTail = (kvSize - 1) % kvSplitSize + 1; + + bool need_pack = false; + // Whether pack is needed for BFloat16/Half + if (is_reduced_type) { + // check platform ability + need_pack = std::is_same_v ? at::native::cpublas::could_pack(at::kBFloat16) + : at::native::cpublas::could_pack(at::kHalf); + } + if (need_pack) { + // When the number of gemm is greater than the number of pack, + // the pack overhead can be overlapped. + int64_t thresh_size = 64; + need_pack = kvSize >= thresh_size && qSize >= thresh_size; + if (need_pack) { + double pack_size = batchSize * num_head * kvSize * headSize; + double qs_per_thread = (batchSize * num_head * qSlice + num_thread - 1) / num_thread; + double gemm_size_per_thread = qs_per_thread * qSplitSize * kvSize * headSize; + need_pack = gemm_size_per_thread / pack_size >= 4; + } + } + // Pad is needed for packing when K is not even + bool headSize_even = headSize % 2 == 0; + int64_t eheadSize = need_pack && !headSize_even ? headSize + 1: headSize; + int64_t ekvSplitSize = need_pack && (kvSplitSize % 2 != 0) ? kvSplitSize + 1 : kvSplitSize; + int64_t ekvTail = need_pack && (kvTail % 2 != 0) ? 
kvTail + 1 : kvTail; + int64_t kv_padding_size = (kvSize - 1) / kvSplitSize * ekvSplitSize + ekvTail; + + // Allocate per thread temp buf (accumulate type) + int64_t _size_per_thread = + /* qk */ qSplitSize * kvSplitSize + + /* qk_max */ qSplitSize + + /* qk_sum */ qSplitSize + + /* dst */ qSplitSize * headSize_v; + + // Inputs/outputs buffers + const scalar_t* q_data = query; + const scalar_t* k_data = key; + const scalar_t* v_data = value; + scalar_t* out_data = output; + + // Buffers to store accum results, padding query and transpose/packing key/value + {{template.codegen_allocate_buffer("buf_data", "accum_t", "num_thread*_size_per_thread")}} + {{template.codegen_allocate_buffer("buf_reduced_data", "scalar_t", "num_thread*qSplitSize*ekvSplitSize")}} + {{template.codegen_allocate_buffer("key_reorder_ptr", "scalar_t", "batchSize_k*num_head_k*eheadSize*kvSize")}} + {{template.codegen_allocate_buffer("value_reorder_ptr", "scalar_t", "batchSize_k*num_head_k*kv_padding_size*headSize_v")}} + {{template.codegen_allocate_buffer("transpose_buffer_ptr", "scalar_t", "num_thread*kvSplitSize*headSize")}} + {{template.codegen_allocate_buffer("query_padding_ptr", "scalar_t", "num_thread*qSplitSize*eheadSize")}} + if (need_pack) { + // Pack K, V + at::parallel_for(0, batchSize_k * num_head_k * kvSlice, 1, [&](int64_t begin, int64_t end) { + int ompIdx = at::get_thread_num(); + int64_t i = 0, j = 0, l = 0, n = 0; + scalar_t* transpose_ptr = need_pack? transpose_buffer_ptr + ompIdx * kvSplitSize * headSize : nullptr; + at::native::data_index_init(begin, i, batchSize_k, j, num_head_k, l, kvSlice); + for ([[maybe_unused]] auto z : c10::irange(begin, end)) { + n = l * kvSplitSize; + int64_t cur_kvSplitSize = std::min(kvSplitSize, kvSize - n); + auto k_addr = + k_data + i * kStrideB + j * kStrideH + n * kStrideN; + auto v_addr = + v_data + i * vStrideB + j * vStrideH + n * vStrideN; + // transpose [cur_kvSplitSize, headSize] -> [headSize, cur_kvSplitSize] + at::native::utils::transpose( + cur_kvSplitSize, + headSize, + /* src_ptr */ + reinterpret_cast(k_addr), + /* ld_src */ kStrideN, + /* dst */ reinterpret_cast(transpose_ptr), + /* ld_dst */ cur_kvSplitSize); + + // Pack [headSize, cur_kvSplitSize] + at::vec::pack_vnni2( + /* src */ reinterpret_cast(transpose_ptr), + /* dst */ reinterpret_cast(key_reorder_ptr + i * num_head_k * eheadSize * kvSize + + j * eheadSize * kvSize + n * eheadSize), + /* ld_src */ cur_kvSplitSize, + /* K */ headSize, + /* N */ cur_kvSplitSize); + + // Pack [cur_kvSplitSize, headSize_v] + at::vec::pack_vnni2( + /* src */ reinterpret_cast(v_addr), + /* dst */ reinterpret_cast(value_reorder_ptr + + i * num_head_k * kv_padding_size * headSize_v + + j * kv_padding_size * headSize_v + n * headSize_v), + /* ld_src */ vStrideN, + /* K */ cur_kvSplitSize, + /* N */ headSize_v); + // Move to the next query + at::native::data_index_step(i, batchSize_k, j, num_head_k, l, kvSlice); + } + }); + } + // Attention loop below + at::parallel_for(0, batchSize * num_head * qSlice, 1, [&](int64_t begin, int64_t end) { + int64_t i = 0, j = 0, k = 0; + at::native::data_index_init(begin, i, batchSize, j, num_head, k, qSlice); + int ompIdx = at::get_thread_num(); + accum_t* buf_ptr = buf_data + ompIdx * _size_per_thread; + accum_t* qk_data = buf_ptr; + accum_t* qk_max_data = qk_data + qSplitSize * kvSplitSize; + accum_t* qk_sum_data = qk_max_data + qSplitSize; + accum_t* dst_data = qk_sum_data + qSplitSize; + scalar_t *qk_reduced_data = + is_reduced_type + ? 
buf_reduced_data + ompIdx * qSplitSize * ekvSplitSize + : nullptr; + scalar_t* query_t_padding_ptr = (!headSize_even && need_pack) + ? query_padding_ptr + ompIdx * qSplitSize * eheadSize + : nullptr; + + for ([[maybe_unused]] auto z : c10::irange(begin, end)) { + auto i_kvi = is_broadcast_bs_kvi ? i/bs_shards_kvi : i; + auto j_kvi = is_broadcast_head_kvi ? j/gqa_shards_kvi : j; + auto kv_logical_num_data = kv_num_blocks_data + i_kvi * num_kviStrideB + + j_kvi * num_kviStrideH + k; + int kv_indice_num = *kv_logical_num_data; + std::vector kv_indice_list(kv_indice_num); + for(int kv_i = 0; kv_i < kv_indice_num; kv_i++){ + auto kv_logical_data = kv_indices_data + i_kvi * kviStrideB + + j_kvi * kviStrideH + k*kviStrideQ + kv_i; + kv_indice_list[kv_i] = *kv_logical_data; + } + bool is_skip_kv = kv_indice_num > 0 ? false : true; +{%- if has_full_kv_block %} + auto full_kv_logical_num_data = full_kv_num_blocks_data + i_kvi * num_kviStrideB + + j_kvi * num_kviStrideH + k; + int full_kv_indice_num = *full_kv_logical_num_data; + std::vector full_kv_indice_list(full_kv_indice_num); + for(int kv_i = 0; kv_i < full_kv_indice_num; kv_i++){ + auto full_kv_logical_data = full_kv_indices_data + i_kvi * full_kviStrideB + + j_kvi * full_kviStrideH + k*full_kviStrideQ + kv_i; + full_kv_indice_list[kv_i] = *full_kv_logical_data; + } + is_skip_kv = kv_indice_num + full_kv_indice_num > 0 ? false : true; +{%- endif %} + int64_t m = k * qSplitSize; + int64_t cur_qSplitSize = std::min(qSplitSize, qSize - m); + if (!is_skip_kv){ + // Initialize max and sum + {{kernel.kernel_name}}_fill_stub(qk_max_data, + -std::numeric_limits::infinity(), cur_qSplitSize); + {{kernel.kernel_name}}_fill_stub(qk_sum_data, + static_cast(0), cur_qSplitSize); + + if (!headSize_even && need_pack) { + // Pad query if headSize is not even + {{kernel.kernel_name}}_copy_value_with_pad( + q_data + i * qStrideB + j * qStrideH + m * qStrideM, + query_t_padding_ptr, + cur_qSplitSize, + headSize, + cur_qSplitSize, + eheadSize, + qStrideM + ); + } + } + +{%- if has_full_kv_block %} + for (int64_t n_idx = 0; n_idx < kv_indice_num + full_kv_indice_num ; n_idx += 1) { + auto n = n_idx < kv_indice_num ? kv_indice_list[n_idx]*kvSplitSize : full_kv_indice_list[n_idx - kv_indice_num]*kvSplitSize; +{%- else %} + for (int64_t n_idx = 0; n_idx < kv_indice_num ; n_idx += 1) { + auto n = kv_indice_list[n_idx]*kvSplitSize; +{%- endif %} + + auto cur_n = n/kvSplitSize; + int64_t cur_kvSplitSize = std::min(kvSplitSize, kvSize - n); + int64_t cur_ekvSplitSize = (need_pack && cur_kvSplitSize % 2 != 0) ? cur_kvSplitSize + 1 : cur_kvSplitSize; + + // Calculate scale * q @ k.T + auto i_kv = is_broadcast_bs_kv ? i/bs_shards : i; + auto j_kv = is_broadcast_head_kv ? j/gqa_shards : j; + + if (!need_pack) { + auto k_addr = + k_data + i_kv * kStrideB + j_kv * kStrideH + n * kStrideN; + + {{kernel.kernel_name}}_kernel_micro_gemm_transpose_b(false)>( + q_data + i * qStrideB + j * qStrideH + + m * qStrideM, + k_addr, + qk_data, + cur_qSplitSize, + cur_kvSplitSize, + headSize, + qStrideM, + kStrideN, + cur_kvSplitSize); + + } else { + at::native::cpublas::brgemm( + cur_qSplitSize, + cur_kvSplitSize, + eheadSize, + headSize_even ? qStrideM : eheadSize, + cur_kvSplitSize, + cur_kvSplitSize, + false, + !headSize_even + ? 
query_t_padding_ptr + : q_data + i * qStrideB + j * qStrideH + m * qStrideM, + key_reorder_ptr + i_kv * num_head_k * eheadSize * kvSize + + j_kv * eheadSize * kvSize + n * eheadSize, + qk_data, + need_pack); + } + + {{kernel.kernel_name}}_mul_scale_kernel(qk_data, scaling_factor, cur_qSplitSize*cur_kvSplitSize); + +{%- if score_mod and mask_mod %} + // TODO: reduce the number of calls of q_idx and kv_idx initialization + std::vector q_idx(cur_qSplitSize); + for (int64_t i = 0; i < cur_qSplitSize; ++i) { + q_idx[i] = m + i; + } + + std::vector kv_idx(cur_kvSplitSize); + for (int64_t i = 0; i < cur_kvSplitSize; ++i) { + kv_idx[i] = n + i; + } + + std::vector b_idx = {i}; + std::vector h_idx = {j}; + + accum_t* in_ptr0 = qk_data; + + auto in_ptr1 = b_idx.data(); + auto in_ptr2 = h_idx.data(); + auto in_ptr3 = q_idx.data(); + auto in_ptr4 = kv_idx.data(); + + // apply score mod function + { + {{ template.generate_other_buffer("score_others", 0, "len_score_other", kernel.args) }} + accum_t* out_ptr{{score_buf_idx}} = in_ptr0; + {{ template.modification(score_mod, score_buf_name, score_buf_idx)|indent(12, false) }} + } + + if ((std::find(kv_indice_list.begin(), kv_indice_list.end(), cur_n) != kv_indice_list.end()) ){ + // Apply block mask, fill unused with -inf + { + {{ template.generate_other_buffer("mask_others", -1, "len_mask_other", kernel.args) }} + accum_t* out_ptr{{mask_buf_idx}} = in_ptr0; + {{ template.modification(mask_mod, mask_buf_name, mask_buf_idx)|indent(12, false) }} + } + } + +{%- endif %} + // Update coefficients with Softmax + accum_t tmp_max = 0, tmp_sum = 0, exp_tmp = 0; + for (int64_t row = 0; row < cur_qSplitSize; ++row) { + // apply scaling factor and max per row in fusion + {{kernel.kernel_name}}_mul_reduce_max_fusion_kernel( + qk_data + row * cur_kvSplitSize, + static_cast(1), + cur_kvSplitSize, + qk_data + row * cur_kvSplitSize, + tmp_max); + tmp_max = qk_max_data[row] > tmp_max ? 
qk_max_data[row] : tmp_max; + if (tmp_max == -std::numeric_limits::infinity()) { + // to avoid `nan = exp2f(-inf - (-inf))` + {{kernel.kernel_name}}_fill_stub( + {{kernel.kernel_name}}_conditional_data_ptr(qk_data, qk_reduced_data) + row * cur_ekvSplitSize, + static_cast(0), cur_kvSplitSize); + } else { + tmp_sum = tmp_max; + // qk <- exp(qk - max) and sum per row + {{kernel.kernel_name}}_exp_reduce_sum_fusion_kernel( + qk_data + row * cur_kvSplitSize, cur_kvSplitSize, + {{kernel.kernel_name}}_conditional_data_ptr(qk_data, qk_reduced_data) + row * cur_ekvSplitSize, + tmp_sum); + // exp_tmp <- exp(max[row] - max) + exp_tmp = std::exp(qk_max_data[row] - tmp_max); + // sum[row] <- sum + exp_tmp * sum[row] + qk_sum_data[row] = tmp_sum + exp_tmp * qk_sum_data[row]; + // max[row] <- max + qk_max_data[row] = tmp_max; + // dst <- dst * exp_tmp + if (n_idx > 0) { + at::vec::map( + [exp_tmp](Vec x) { return x * Vec(exp_tmp); }, + dst_data + row * headSize_v, + dst_data + row * headSize_v, + headSize_v); + } + } + if (need_pack && cur_kvSplitSize % 2 != 0) { + // Pad: [qSplitSize, cur_kvSplitSize] -> [qSplitSize, cur_kvSplitSize + 1] + *(qk_reduced_data + row * (1 + cur_kvSplitSize) + cur_kvSplitSize) = scalar_t(0); + } + } + // Calculate Softmax(q @ k.T) @ v + if (!need_pack) { + auto v_addr = + v_data + i_kv * vStrideB + j_kv * vStrideH + n * vStrideN; + // Fallback Half brgemm is slower than micro gemm + if (!std::is_same_v) { + at::native::cpublas::brgemm( + cur_qSplitSize, + headSize_v, + cur_ekvSplitSize, + cur_ekvSplitSize, + vStrideN, + headSize_v, + n_idx > 0, + {{kernel.kernel_name}}_conditional_data_ptr(qk_data, qk_reduced_data), + v_addr, + dst_data, + need_pack); + } else { + if (n_idx > 0) { + {{kernel.kernel_name}}_kernel_micro_gemm(true)>( + {{kernel.kernel_name}}_conditional_data_ptr(qk_data, qk_reduced_data), + v_addr, + dst_data, + cur_qSplitSize, + headSize_v, + cur_ekvSplitSize, + cur_ekvSplitSize, + vStrideN, + headSize_v); + } else { + {{kernel.kernel_name}}_kernel_micro_gemm(false)>( + {{kernel.kernel_name}}_conditional_data_ptr(qk_data, qk_reduced_data), + v_addr, + dst_data, + cur_qSplitSize, + headSize_v, + cur_ekvSplitSize, + cur_ekvSplitSize, + vStrideN, + headSize_v); + } + } + } else { + int64_t psize = n / kvSplitSize * ekvSplitSize; + at::native::cpublas::brgemm( + cur_qSplitSize, + headSize_v, + cur_ekvSplitSize, + cur_ekvSplitSize, + headSize_v, + headSize_v, + n_idx > 0, + qk_reduced_data, + value_reorder_ptr + + i_kv * num_head_k * kv_padding_size * headSize_v + + j_kv * kv_padding_size * headSize_v + psize * headSize_v, + dst_data, + need_pack); + } + } + + // dst <- dst / sum[row] + // reorder MHA output with strides + for (int64_t row = 0; row < cur_qSplitSize; ++row) { + // Row sums for full masked out rows are 0, we set them to 1 + // in order to avoid NaNs in the output and instead set fully + // masked out rows to 0 + qk_max_data[row] = qk_max_data[row] == -std::numeric_limits::infinity() ? 0 : qk_max_data[row]; + qk_sum_data[row] = qk_sum_data[row] == 0 ? 1 : qk_sum_data[row]; + accum_t sum_reciprocal = 1 / qk_sum_data[row]; + at::vec::map( + [sum_reciprocal, is_skip_kv](Vec x) { return is_skip_kv ? 
Vec(0.0) : x * Vec(sum_reciprocal); }, + out_data + i * oStrideB + j * oStrideH + m * oStrideM + row * oStrideM, + dst_data + row * headSize_v, + headSize_v); + } + + // Move to the next query + at::native::data_index_step(i, batchSize, j, num_head, k, qSlice); + } + + at::native::cpublas::brgemm_release(need_pack); + + }); +} +""" + + +class CppFlexAttentionTemplate(CppTemplate): + def __init__( + self, + input_nodes, + layout: ir.Layout, + scale, + score_mod, + mask_mod, + kv_block_size, + q_block_size, + has_other_buffer, + no_full_kv_block, + fake_buffers, + len_score_other, + len_mask_other, + kernel_input_name_to_buffer, + block_vars, + ) -> None: + assert layout.dtype in [torch.float, torch.bfloat16, torch.float16] + super().__init__("flex_attention", input_nodes, layout, parallel_num_threads()) + self.scale = scale + self.score_mod = score_mod + self.mask_mod = mask_mod + self.score_buf_name = ( + V.graph.register_buffer(self.score_mod) if self.score_mod else None + ) + self.mask_buf_name = ( + V.graph.register_buffer(self.mask_mod) if self.mask_mod else None + ) + + def get_idx(buf_name): + match = re.search(r"\d+", buf_name) + assert match, f"incorrect score buf name: {buf_name}" + return match.group() + + self.score_buf_idx = ( + get_idx(self.score_buf_name) if self.score_buf_name else None + ) + self.mask_buf_idx = get_idx(self.mask_buf_name) if self.mask_buf_name else None + self.kv_block_size = kv_block_size + self.q_block_size = q_block_size + self.has_other_buffer = has_other_buffer + self.no_full_kv_block = no_full_kv_block + self.other_buffer_input_offset = 2 + if self.no_full_kv_block: + self.other_buffer_input_offset = 0 + self.fake_buffers = fake_buffers + self.len_score_other = len_score_other + self.len_mask_other = len_mask_other + self.kernel_input_name_to_buffer = kernel_input_name_to_buffer + self.block_vars = block_vars + self.extra_sizevars = list( + OrderedSet( + val + for val in self.kernel_input_name_to_buffer.values() + if isinstance(val, sympy.Symbol) + ) + ) + self.other_buf_start_idx = 5 + self.score_mod_other_buffers = ( + self.input_nodes[ + self.other_buf_start_idx + + self.other_buffer_input_offset : self.other_buf_start_idx + + self.other_buffer_input_offset + + self.len_score_other + ] + if self.has_other_buffer + else None + ) + self.mask_mod_other_buffers = ( + self.input_nodes[ + self.other_buf_start_idx + + self.other_buffer_input_offset + + self.len_score_other : + ] + if self.has_other_buffer + else None + ) + self.other_ptr_data = {} # type: ignore[var-annotated] + + def update_kernel_args(self, kernel_args): + kernel_args.update( + { + key: value + for key, value in self.kernel_input_name_to_buffer.items() + if not isinstance(value, sympy.Symbol) + } + ) + return kernel_args + + def generate_other_buffer(self, buf_list, start_offset, len_attr, kernel_args): + kernel_input_name_to_buffer_name = { + key: value if isinstance(value, sympy.Symbol) else value.get_name() + for key, value in self.kernel_input_name_to_buffer.items() + } + + def get_arg(name): + return kernel_input_name_to_buffer_name.get(name) + + def get_arg_name(name): + if isinstance(get_arg(name), sympy.Symbol): + return kernel_args.sizevars.get(get_arg(name)) + return kernel_args.input_buffers.get(get_arg(name)) + + if not self.has_other_buffer: + return "" + + if start_offset == -1: + start_offset = getattr(self, len_attr) + + length = getattr(self, len_attr) + for i in range(length): + pointer = f"in_ptr{self.other_buf_start_idx + start_offset + i}" + buffer_key = 
f"{buf_list}_{i}" + if pointer not in self.other_ptr_data: + self.other_ptr_data[pointer] = ( + get_arg_name(buffer_key), + get_arg(buffer_key), + ) + + return "\n".join( + f"auto {ptr} = {name};" for ptr, (name, _) in self.other_ptr_data.items() + ) + + def modification(self, subgraph_buffer, output_name, output_idx): + assert isinstance(subgraph_buffer, ir.ComputedBuffer) + subgraph_buffer_data = subgraph_buffer.data + from ..loop_body import LoopBody + from ..utils import sympy_index_symbol_with_prefix, SymT + from ..virtualized import V + from .cpp import CppKernelProxy, KernelGroup + + kernel_group = KernelGroup() + kernel_input_args = { + "score": "in_ptr0", + "b": "in_ptr1", + "h": "in_ptr2", + "q_idx": "in_ptr3", + "kv_idx": "in_ptr4", + } + if self.has_other_buffer: + kernel_input_args.update( + {arg: ptr for ptr, (_, arg) in self.other_ptr_data.items()} + ) + + kernel_output_args = {output_name: f"out_ptr{output_idx}"} + + args = kernel_group.args + for name, inp in kernel_input_args.items(): + args.input_buffers[name] = inp + + for name, inp in kernel_output_args.items(): + args.output_buffers[name] = inp + + for name in self.extra_sizevars: + args.sizevars[name] = f"k{name}" + + kernel_group.args = args + + cpp_kernel_proxy = CppKernelProxy(kernel_group) + bodies = [] + var_sizes_list = [] + var_sizes = tuple(subgraph_buffer.get_size()) + var_ranges = { + sympy_index_symbol_with_prefix(SymT.INDEX, i): sz + for i, sz in enumerate(var_sizes) + } + + dst_layout = subgraph_buffer.get_layout() + output_index = dst_layout.make_indexer()([*var_ranges.keys()]) + + def fn(*args): + V.ops.store( + output_name, + output_index, + subgraph_buffer_data.make_loader()(args).value, + ) + + body = LoopBody( + fn, + (list(var_ranges.keys())), + var_ranges, + list(var_ranges.keys()), + tuple(), + ) + + from ..loop_body import MemoryUsageType + + assert all( + mem.buffer_name in kernel_group.args.input_buffers + for mem in body.memory_usage[MemoryUsageType.LOAD] + ), ( + "All the buffers in the score and mask subgraph should be in kernel_group.args.input_buffers" + ) + + bodies.append(body) + var_sizes_list.append((var_sizes, ())) + + cpp_kernel_proxy.codegen_loop_bodies(bodies, var_sizes_list) + kernel_group.finalize_kernel(cpp_kernel_proxy, []) + output_code = kernel_group.loops_code.getvalue() + + var_q_symbol, var_kv_symbol = self.block_vars + # See [Note] Handle the case where the split sizes are not statically known. + # We don't know the value of qBlockSize and rkvBlockSize during compilation time + # thus we've represented them by symbols. + # We change the symbol strings back to "cur_qSplitSize" and "cur_kvSplitSize" + # in the generated code thus they'll be filled with the real value during runtime. 
+ if var_q_symbol in kernel_group.args.sizevars: + output_code = output_code.replace( + kernel_group.args.sizevars[var_q_symbol], "cur_qSplitSize" + ) + if var_kv_symbol in kernel_group.args.sizevars: + output_code = output_code.replace( + kernel_group.args.sizevars[var_kv_symbol], "cur_kvSplitSize" + ) + + return output_code + + @staticmethod + def add_choices( + choices, + input_nodes, + layout, + scale, + score_mod, + mask_mod, + kv_block_size, + q_block_size, + has_other_buffer, + no_full_kv_block, + fake_buffers, + len_score_other, + len_mask_other, + kernel_input_name_to_buffer, + block_vars, + ): + def preprocessor(input_nodes, layout): + return input_nodes, layout + + def postprocessor(output): + return output + + template = DataProcessorTemplateWrapper( + CppFlexAttentionTemplate, + preprocessor, + postprocessor, + input_nodes=input_nodes, + layout=layout, + scale=scale, + score_mod=score_mod, + mask_mod=mask_mod, + kv_block_size=kv_block_size, + q_block_size=q_block_size, + has_other_buffer=has_other_buffer, + no_full_kv_block=no_full_kv_block, + fake_buffers=fake_buffers, + len_score_other=len_score_other, + len_mask_other=len_mask_other, + kernel_input_name_to_buffer=kernel_input_name_to_buffer, + block_vars=block_vars, + ) + template.maybe_append_choice(choices) + return template + + def apply_score_mod(self, score, b, h, q_idx, kv_idx): + return self.score_mod.graph_module(score, b, h, q_idx, kv_idx).item() + + def render( # type: ignore[override,return] + self, + kernel, + template_buffer_node: Optional[ir.CppTemplateBuffer] = None, + epilogue_nodes: Optional[list[ir.IRNode]] = None, + **kwargs, + ) -> str: + if epilogue_nodes is not None and epilogue_nodes != []: + raise NotImplementedError( + "Unsupported for `epilogue_nodes` in CppFlexAttentionTemplate." 
+ ) + # Query (Batch x Num_heads x Q_seq_len x Dim_per_head) + # -> (Batch x Q_seq_len x Num_heads x Dim_per_head) + # Key (Batch x Num_heads x KV_seq_len x Dim_per_head) + # -> (Batch x KV_seq_len x Num_heads x Dim_per_head) + # Value (Batch x Num_heads x KV_seq_len x Dim_per_head) + # -> (Batch x KV_seq_len x Num_heads x Dim_per_head) + + query = kernel.permute(self.input_nodes[0], [0, 2, 1, 3]) + key = kernel.permute(self.input_nodes[1], [0, 2, 1, 3]) + value = kernel.permute(self.input_nodes[2], [0, 2, 1, 3]) + self.accumulate_dtype = torch.float + self.input_dtype = query.layout.dtype + + num_threads = parallel_num_threads() + buf_out = TensorBox.create(self.output_node) + if template_buffer_node is not None: + buf_out = template_buffer_node + options = dict( + query=query, + key=key, + value=value, + kv_num_blocks=self.input_nodes[3], + kv_indices=self.input_nodes[4], + full_kv_num_blocks=self.input_nodes[5] + if not self.no_full_kv_block + else None, + full_kv_indices=self.input_nodes[6] if not self.no_full_kv_block else None, + score_mod_other_buffers=self.score_mod_other_buffers, + mask_mod_other_buffers=self.mask_mod_other_buffers, + scale=self.scale, + has_full_kv_block=not self.no_full_kv_block, + accumulate_dtype=self.accumulate_dtype, + query_dtype=self.input_dtype, + kvBlockSize=self.kv_block_size, + qBlockSize=self.q_block_size, + template=self, + output=buf_out, + kernel=kernel, + num_thread=num_threads, + score_mod=self.score_mod, + mask_mod=self.mask_mod, + score_buf_name=self.score_buf_name, + mask_buf_name=self.mask_buf_name, + score_buf_idx=self.score_buf_idx, + mask_buf_idx=self.mask_buf_idx, + ) + with contextlib.ExitStack() as stack: + for buf in self.fake_buffers: + stack.enter_context( + patch.object(V.graph, "get_dtype", self._fake_get_dtype(buf)) + ) + return self._template_from_string(FLEX_ATTENTION_TEMPLATE).render(**options) + + def codegen_softmax_fusion(self, kernel_name: str): + # TODO: use inductor IR to rewrite those fusions + return self._template_from_string(SOFTMAX_FUSIONS).render( + dict(kernel_name=kernel_name) + ) + + def codegen_brgemm_pack_function(self, kernel_name: str): + # TODO: make them general for common bmm templates + return self._template_from_string(BRGEMM_PACK_FUNCTIONS).render( + dict(kernel_name=kernel_name) + ) + + def codegen_allocate_buffer(self, buffer_name: str, buffer_dtype, buffer_size): + return self._template_from_string(ALLOCATE_BUFFER).render( + dict( + buffer_name=buffer_name, + buffer_dtype=buffer_dtype, + buffer_size=buffer_size, + ) + ) + + def micro_gemm_define(self, kernel_name: str): + from torch._inductor.codegen.cpp_gemm_template import ( + CppTemplateKernel, + parallel_num_threads, + ) + from torch._inductor.codegen.cpp_micro_gemm import CppMicroGemmFP32Vec + from torch._inductor.virtualized import V + + micro_gemm_trans = CppMicroGemmFP32Vec( + kernel_name + "_kernel_micro_gemm_transpose_b", + self.input_dtype, + self.input_dtype, + self.accumulate_dtype, + self.accumulate_dtype, + GemmBlocking(1, 16, 1), + 1, + True, + True, + ) + + micro_gemm = CppMicroGemmFP32Vec( + kernel_name + "_kernel_micro_gemm", + self.input_dtype, + self.input_dtype, + self.accumulate_dtype, + self.accumulate_dtype, + GemmBlocking(1, 16, 1), + 1, + True, + False, + ) + + with V.set_graph_handler(V.graph): + kernel = CppTemplateKernel("cpp_micro_gemm", parallel_num_threads()) + code_trans = micro_gemm_trans.codegen_define(kernel) + code = micro_gemm.codegen_define(kernel) + return code + code_trans + + def codegen_micro_gemm(self, 
kernel_name: str): + micro_gemm = self.micro_gemm_define(kernel_name) + GEMM_SOURCE_CODE = MICRO_GEMM_TEMPLATE.replace("GEMM_DEFINE", micro_gemm) + return self._template_from_string(GEMM_SOURCE_CODE).render() diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_gemm_template.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_gemm_template.py new file mode 100644 index 0000000000000000000000000000000000000000..d2b6832263eb64e4d915a67f182ab51c56da1559 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_gemm_template.py @@ -0,0 +1,1777 @@ +# mypy: allow-untyped-defs +import contextlib +import logging +import math +from functools import lru_cache +from typing import Any, Callable, cast, Optional, TypeVar, Union +from unittest.mock import patch + +import torch +import torch.utils +from torch.utils._ordered_set import OrderedSet + +from ..._dynamo.utils import counters +from .. import config, ir, lowering as L +from ..kernel.mm_common import mm_args +from ..select_algorithm import DataProcessorTemplateWrapper +from ..utils import ( + has_free_symbols, + is_same_mkldnn_tensor, + is_same_tensor, + parallel_num_threads, +) +from ..virtualized import ops, V +from .cpp import get_export_declaration +from .cpp_micro_gemm import ( + CppMicroBrgemm, + CppMicroGemm, + CppMicroGemmAMX, + CppMicroGemmFP32Vec, + create_micro_gemm, + is_int8_woq_gemm_small_m_dim_corner_case, + LayoutType, +) +from .cpp_template import CppTemplate +from .cpp_template_kernel import CppTemplateKernel +from .cpp_utils import ( + create_epilogue_with_attr, + DTYPE_TO_CPP, + GemmBlocking, + get_gemm_template_output_and_compute_dtype, +) + + +log = logging.getLogger(__name__) + +GEMM_TEMPLATE_INIT_BLOCKING_BASIC_BLOCK = r""" + constexpr int64_t num_threads = {{num_threads}}; + constexpr int64_t N = {{N}}; + constexpr int64_t K = {{K}}; + constexpr int64_t Mr = {{micro_gemm.register_blocking.block_m}}; + constexpr int64_t Nr = {{micro_gemm.register_blocking.block_n}}; + constexpr int64_t Kr = {{micro_gemm.register_blocking.block_k}}; + constexpr int64_t Nr_blocks = (N + Nr - 1) / Nr; + constexpr int64_t Kr_blocks = (K + Kr - 1) / Kr; +{%- if is_dynamic_M %} + const int64_t M = {{kernel.size(GemmOut, 0)}}; + const int64_t Mr_blocks = (M + Mr - 1) / Mr; +{%- else %} + constexpr int64_t M = {{kernel.size(GemmOut, 0)}}; + constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr; +{%- endif %} +""" + +GEMM_TEMPLATE_INIT_BLOCKING_EXTENDED = r""" +{%- if is_dynamic_M %} + {%- if num_threads > 1 %} + int64_t Mt_blocks, Nt_blocks, Kt_blocks; + mm_get_thread_blocking(num_threads, {{config.cpp.gemm_max_k_slices}}, M, N, K, Mr, Nr, Kr, Mt_blocks, Nt_blocks, Kt_blocks); + {%- else %} + const auto Mt_blocks = Mr_blocks; + const auto Nt_blocks = Nr_blocks; + const auto Kt_blocks = Kr_blocks; + {%- endif %} + int64_t Mc_blocks, Nc_blocks, Kc_blocks; + uint32_t L1_cache_size = {{L1_cache_size}}; + uint32_t L2_cache_size = {{L2_cache_size}}; + mm_get_cache_blocking<{{kernel.dtype(X)}}, {{kernel.dtype(W)}}>( + num_threads, + M, + N, + K, + Mr, + Nr, + Kr, + Mt_blocks, + Nt_blocks, + Kt_blocks, + Mc_blocks, + Nc_blocks, + Kc_blocks, + L1_cache_size, + L2_cache_size + ); + const int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks; + const int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks; + const int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks; + const int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks; + const int64_t num_Kt_blocks = 
(Kr_blocks + Kt_blocks - 1) / Kt_blocks; +{%- else %} + constexpr int64_t Mt_blocks = {{template.thread_blocking(num_threads).block_m}}; + constexpr int64_t Nt_blocks = {{template.thread_blocking(num_threads).block_n}}; + constexpr int64_t Kt_blocks = {{template.thread_blocking(num_threads).block_k}}; + constexpr int64_t Mc_blocks = {{template.cache_blocking(num_threads).block_m}}; + constexpr int64_t Nc_blocks = {{template.cache_blocking(num_threads).block_n}}; + constexpr int64_t Kc_blocks = {{template.cache_blocking(num_threads).block_k}}; + constexpr int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks; + constexpr int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks; + constexpr int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks; + constexpr int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks; + constexpr int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks; +{%- endif %} +{%- if is_woq_int4 %} + int64_t group_size = *q_group_size; +{%- endif %} + + // make sure all partitions are assigned + {{kernel.assert_function}}( + Mt_blocks * Nt_blocks * Kt_blocks * {{num_threads}} >= Mr_blocks * Nr_blocks * Kr_blocks, + "Not all partitions are assigned." + ); +""" + +GEMM_TEMPLATE_MULTI_THREADS_PARAMS = r""" +const int tid = omp_get_thread_num(); +const int64_t k_group_id = tid / num_Kt_blocks; +const int64_t k_slice_id = tid % num_Kt_blocks; +const int64_t n_group_id = k_group_id / num_Nt_blocks; +const int64_t n_slice_id = k_group_id % num_Nt_blocks; +const int64_t k_block_start = k_slice_id * Kt_blocks; +const int64_t k_block_end = std::min(k_block_start + Kt_blocks, Kr_blocks); +const int64_t n_block_start = n_slice_id * Nt_blocks; +const int64_t n_block_end = std::min(n_block_start + Nt_blocks, Nr_blocks); +const int64_t m_block_start = std::min(n_group_id * Mt_blocks, Mr_blocks); +const int64_t m_block_end = std::min(m_block_start + Mt_blocks, Mr_blocks); +const int64_t num_Mc_blocks_per_thread = (m_block_end - m_block_start + Mc_blocks - 1) / Mc_blocks; +""" + +GEMM_TEMPLATE_SINGLE_THREAD_PARAMS = r""" +constexpr int tid = 0; +constexpr int64_t k_group_id = 0; +constexpr int64_t k_slice_id = 0; +constexpr int64_t n_group_id = 0; +constexpr int64_t n_slice_id = 0; +constexpr int64_t m_block_start = 0; +constexpr int64_t n_block_start = 0; +constexpr int64_t n_block_end = Nr_blocks; +constexpr int64_t k_block_start = 0; +constexpr int64_t k_block_end = Kr_blocks; +{%- if is_dynamic_M %} +const int64_t num_Mc_blocks_per_thread = num_Mc_blocks; +const int64_t m_block_end = Mr_blocks; +{%- else %} +constexpr int64_t num_Mc_blocks_per_thread = num_Mc_blocks; +constexpr int64_t m_block_end = Mr_blocks; +{%- endif %} +""" + +GEMM_TEMPLATE_M_LOOP_PARAMS = r""" +const int64_t my_mc_block_id = (mc_block_id + n_slice_id) % num_Mc_blocks_per_thread; +const int64_t mc = m_block_start + my_mc_block_id * Mc_blocks; +const int64_t m_start = mc * Mr; +const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M); +const int64_t m_size = m_end - m_start; +""" + +GEMM_TEMPLATE_N_LOOP_PARAMS = r""" +const int64_t n_start = nc * Nr; +const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N); +const int64_t n_size = n_end - n_start; +// NB: assume we pad N, nc_block_end won't exceed padded N here. 
+const int64_t nc_block_end = std::min(nc + Nc_blocks, n_block_end); +""" + +GEMM_TEMPLATE_MICROKERNEL_DEF = r""" +{{template.header().getvalue()}} + +{{micro_gemm.codegen_define(kernel)}} +""" + +GEMM_TEMPLATE_STUB_DEF = r""" +{%- if x_scale is not none %} + {%- set kernel_args = {"X": X, "W": W, "inp": inp, "x_scale": x_scale, "x_zp": x_zp, "w_scale": w_scale, "w_zp": w_zp,} %} +{%- elif is_woq_int4 %} + {%- set kernel_args = {"X": X, "W": W, "q_group_size": q_group_size, "qscale_and_zeros": qscale_and_zeros} %} +{%- else %} + {%- set kernel_args = {"X": X, "W": W, "inp": inp} %} +{%- endif %} + +extern "C" {{export_declaration}} +{{kernel.def_kernel(inputs=kernel_args, outputs={"Y": Y}, aliases=aliases)}} +""" + +GEMM_TEMPLATE = r""" +{{ template.codegen_gemm_stub_def() }} +{ + {{ kernel.maybe_codegen_profile() }} + {{ template.codegen_blocks( + num_threads, N, K, micro_gemm, is_dynamic_M, kernel, GemmOut, config, L1_cache_size, L2_cache_size, X, W + ) }} + +{%- if maybe_k_slicing %} + std::unique_ptr[]> local_buf_ptrs; + if (num_Kt_blocks > 1) { + local_buf_ptrs.reset(new std::unique_ptr<{{DTYPE_TO_CPP[acc_buf_dtype]}}[]>[num_Mc_blocks * num_Nc_blocks * num_Kt_blocks]); + } +{%- endif %} + +{%- if num_threads > 1 %} + #pragma omp parallel num_threads({{num_threads}}) + { + {{ template.codegen_multi_threads_params()|indent(8, false) }} +{%- else %} + { + {{ template.codegen_single_thread_params(is_dynamic_M)|indent(8, false) }} +{%- endif %} + {{ micro_gemm.codegen_init(kernel) }} +{%- if use_local_acc %} + {%- set acc_buf_name = "local_acc_buf" %} + {{ kernel.define_buffer(acc_buf_name, ["Mc_blocks*Mr", "Nc_blocks*Nr"], acc_buf_dtype) }} +{%- endif %} + for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) { + {{ template.codegen_m_loop_params()|indent(12, false) }} + for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) { + {{ template.codegen_n_loop_params()|indent(16, false) }} +{%- if use_local_acc %} + {%- set acc = kernel.local_buffers[acc_buf_name] %} + {{ kernel.reinit_buffer_if_null(acc_buf_name) }} +{%- else %} + {%- set acc = kernel.slice_nd(GemmOut, [("m_start", "m_end"), ("n_start", "n_end")]) %} +{%- endif %} + for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) { + int64_t k_start = kc * Kr; + int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K); +{%- set tile_X = kernel.slice_nd(X, [("m_start", "m_end"), ("k_start", "k_end")]) %} + for (int64_t nci = nc; nci < nc_block_end; nci++) { +{%- set acc_slice = kernel.slice_nd(acc, [("0", "m_end - m_start"), ("(nci - nc)*Nr", "(nci - nc + 1)*Nr")]) %} +{%- if template.should_block_weights and not is_woq_int4 %} +{%- set tile_W_3d = kernel.slice_nd(W, [("nci", "nci + 1"), ("k_start", "k_end"), ()]) %} +{%- set tile_W = kernel.view(tile_W_3d, ["k_end - k_start", micro_gemm.register_blocking.block_n]) %} +{%- else %} + {%- if is_woq_int4 %} + {%- set tile_W = kernel.slice_nd(W, [("nci * Nr", "(nci + 1) * Nr"), ("k_start * Nr / 2", "k_end * Nr / 2")]) %} + {%- set tile_qparam = kernel.slice_nd( + qscale_and_zeros, [("k_start // group_size", "k_end // group_size"), ("nci * Nr", "(nci + 1) * Nr"), ()]) %} + {%- else %} + {%- set tile_W = kernel.slice_nd(W, [("k_start", "k_end"), ("n_start", "n_start + n_size")]) %} + {%- set tile_qparam = None %} + {%- endif %} +{%- endif %} + if (kc == k_block_start) { + {{ micro_gemm.codegen_call(kernel, + tile_X, + tile_W, + acc_slice, + accum=False, + qscale_and_zeros=tile_qparam)|indent(28, false) + }} + } else { + 
{{ micro_gemm.codegen_call(kernel, + tile_X, + tile_W, + acc_slice, + accum=True, + qscale_and_zeros=tile_qparam)|indent(28, false) + }} + } + } + } +{%- if maybe_k_slicing %} + if (num_Kt_blocks > 1) { + const int64_t mxn_cache_block_id = (mc / Mc_blocks) * num_Nc_blocks + nc; + local_buf_ptrs[mxn_cache_block_id * num_Kt_blocks + k_slice_id].reset( + {{ kernel.release_buffer(acc_buf_name) }}); + } else +{%- endif %} + { +{%- set tile_Y = kernel.slice_nd(Y_2d, [("m_start", "m_end"), ("n_start", "n_end")]) %} +{%- set tile_acc = kernel.slice_nd(acc, [("0", "m_end - m_start"), ("0", "n_end - n_start")]) %} + {{ kernel.store_output( + tile_Y, tile_acc, GemmOut, epilogue_nodes, offsets=("m_start", "n_start"), reindexers=reindexers + )|indent(20, false) + }} + } + } + } +{%- if maybe_k_slicing %} + if (num_Kt_blocks > 1) { + #pragma omp barrier + for (int64_t mc = m_block_start; mc < m_block_end; mc += Mc_blocks) { + // We slice M-dim and each thread in the k-slicing group works on a slice + const int64_t m_start_unsliced = mc * Mr; + const int64_t m_end_unsliced = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M); + const int64_t m_size_unsliced = m_end_unsliced - m_start_unsliced; + const int64_t m_slice_size = (m_size_unsliced + num_Kt_blocks - 1) / num_Kt_blocks; + const int64_t m_start = std::min(m_start_unsliced + m_slice_size * k_slice_id, m_end_unsliced); + const int64_t m_end = std::min(m_start_unsliced + m_slice_size * (k_slice_id + 1), m_end_unsliced); + const int64_t m_size = m_end - m_start; + const int64_t m_offset = m_start - m_start_unsliced; + for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) { + const int64_t n_start = nc * Nr; + const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N); + const int64_t n_size = n_end - n_start; + const int64_t mxn_cache_block_id = (mc / Mc_blocks) * num_Nc_blocks + nc; + auto {{acc_buf_name}} = local_buf_ptrs[mxn_cache_block_id * num_Kt_blocks].get(); + for (int64_t other_slice = 1; other_slice < num_Kt_blocks; other_slice++) { + auto other_acc = local_buf_ptrs[mxn_cache_block_id * num_Kt_blocks + other_slice].get(); + for (int64_t m = m_offset; m < m_offset + m_size; m++) { + #pragma omp simd + for (int64_t n = 0; n < n_size; n++) { + {{acc_buf_name}}[m*Nr + n] += other_acc[m*Nr + n]; + } + } + } + {%- set tile_acc_m_slice = kernel.slice_nd(tile_acc, [("m_offset", "m_offset + m_end - m_start"), ()]) %} + {{ kernel.store_output( + tile_Y, tile_acc_m_slice, GemmOut, epilogue_nodes, offsets=("m_start", "n_start"), reindexers=reindexers + )|indent(20, false) + }} + } + } + } +{%- endif %} + {{ micro_gemm.codegen_finalize(kernel) }} + } +} +""" + +SMALL_M_GEMM_TEMPLATE = r""" +{{ template.codegen_gemm_stub_def() }} +{ + {{ kernel.maybe_codegen_profile() }} + {{ template.codegen_blocks( + num_threads, N, K, micro_gemm, is_dynamic_M, kernel, GemmOut, config, L1_cache_size, L2_cache_size, X, W + ) }} + # pragma omp parallel + { + #pragma omp for nowait + for (int64_t nr_block_id = 0; nr_block_id < Nr_blocks; nr_block_id++) { + // Handle one output M * Nr block in each thread + int64_t n_start = nr_block_id * Nr; + int64_t n_end = (nr_block_id + 1) * Nr; +{%- if use_local_acc %} + {%- set acc_buf_name = "local_acc_buf" %} + {{ kernel.define_stack_allocated_buffer(acc_buf_name, ["M", "Nr"], acc_buf_dtype) }} + {%- set acc = kernel.local_buffers[acc_buf_name] %} +{%- else %} + {%- set acc = kernel.slice_nd(GemmOut, [(0, "M"), ("n_start", "n_end")]) %} +{%- endif %} + for (int64_t kr_block_id = 0; kr_block_id 
< Kr_blocks; kr_block_id++) { + // this loop is not parallelized + int64_t k_start = kr_block_id * Kr; + int64_t k_end = std::min((kr_block_id + 1) * Kr, K); +{%- set tile_X = kernel.slice_nd(X, [(0, "M"), ("k_start", "k_end")]) %} +{%- set tile_W_3d = kernel.slice_nd(W, [("nr_block_id", "nr_block_id + 1"), ("k_start", "k_end"), ()]) %} +{%- set tile_W = kernel.view(tile_W_3d, ["k_end - k_start", micro_gemm.register_blocking.block_n]) %} + if C10_UNLIKELY(kr_block_id == 0) { + {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc, accum=False, prefetch=True)|indent(20, false) }} + } else if C10_UNLIKELY(k_end == K) { + {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc, accum=True, prefetch=False)|indent(20, false) }} + } else { + {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc, accum=True, prefetch=True)|indent(20, false) }} + } + } +{%- set tile_Y = kernel.slice_nd(Y_2d, [("0", "M"), ("n_start", "n_end")]) %} +{%- set tile_acc = kernel.slice_nd(acc, [("0", "M"), ("0", "n_end - n_start")]) %} + {{ kernel.store_output( + tile_Y, tile_acc, GemmOut, epilogue_nodes, offsets=("0", "n_start"), reindexers=reindexers + )|indent(20, false) }} + } + } +} +""" + + +def _is_int8_gemm(inputs): + return ( + isinstance(inputs[0], ir.IRNode) + and inputs[0].get_dtype() in [torch.uint8, torch.int8] + ) or ( + isinstance(inputs[0], torch.Tensor) + and inputs[0].dtype in [torch.uint8, torch.int8] + ) + + +def get_padded_n(n, block_n): + return (n + block_n - 1) // block_n * block_n + + +_T = TypeVar("_T", ir.IRNode, torch.Tensor) + + +def transpose_w(W: _T, trans_w: bool) -> _T: + """ + Transpose W based on the trans_w flag. + """ + if isinstance(W, ir.IRNode): + if trans_w: + if not isinstance(W, ir.TensorBox): + W = ir.TensorBox(W) + W = L.permute(W, [1, 0]) + else: + if trans_w: + assert isinstance(W, torch.Tensor) + W = W.transpose(0, 1) + return W + + +def expand_bias(B: Optional[_T], X: _T) -> Optional[_T]: + """ + Expand Bias to the same size of X. + """ + if B is not None: + if isinstance(B, ir.IRNode): + if not isinstance(B, ir.TensorBox): + B = ir.TensorBox(B) + assert hasattr(X, "get_size") + B = L.expand(B, (X.get_size()[0], B.get_size()[-1])) + else: + assert isinstance(B, torch.Tensor) + assert isinstance(X, torch.Tensor) + B = B.expand(X.shape[0], B.shape[-1]) + return B + + +def prune_tensors(input_nodes: list[ir.IRNode], new_input_nodes: list[ir.IRNode]): + """ + Prune unused tensors from `V.graph` since the GEMM Template use new packed weight. + """ + + def share_storage(base_tensor: torch.Tensor, comp_tensor: torch.Tensor): + return base_tensor.is_mkldnn == comp_tensor.is_mkldnn and ( + is_same_tensor(base_tensor, comp_tensor) + or is_same_mkldnn_tensor(base_tensor, comp_tensor) + ) + + def get_candidates(input_nodes, new_input_nodes): + # Only Constant Buffer like weight and bias might be changed in GEMM Template. + # The Inductor IR Node may changed, but still share the storage. 
For example: + # bias in bfloat16 case which only do the expand + return [ + node + for node in input_nodes + if ( + node not in new_input_nodes + and isinstance(node, (ir.TensorBox, ir.StorageBox)) + and node.get_name() in V.graph.constants + and not any( + ( + isinstance(new_node, (ir.TensorBox, ir.StorageBox)) + and new_node.get_name() in V.graph.constants + and share_storage( + V.graph.constants[node.get_name()], + V.graph.constants[new_node.get_name()], + ) + ) + for new_node in new_input_nodes + ) + ) + ] + + for candidate_node in get_candidates(input_nodes, new_input_nodes): + # By using the new packed weight for the GEMM template, we can prune the + # old weight if it has no other users. This saves memory but makes the FX graph + # non-retraceable. To support retracing, we can add a repack node to the + # FX graph. For example: + # mkldnn._linear_pointwise <- repack_linear_wgt <- packed_wgt_for_template + candidate_tensor_users = 0 + candidate_tensor = V.graph.constants[candidate_node.get_name()] + for node in reversed(V.graph.graph.nodes): + # Case may happen when the candidate tensor is used by more than 1 get_attr node + # https://github.com/pytorch/pytorch/issues/134998 + if node.op == "get_attr" and hasattr( + V.graph.module, node.target + ): # candidate tensor might already be deleted + comp_tensor = getattr(V.graph.module, node.target) + if isinstance(comp_tensor, torch.Tensor) and share_storage( + candidate_tensor, comp_tensor + ): + candidate_tensor_users += 1 + + for node in reversed(V.graph.graph.nodes): + # The get_attr node has only 1 user fx node + # The candidate tensor has been used by only 1 get_attr node + if ( + node.op == "get_attr" + and node.target == candidate_node.get_name() + and len(node.users) == 1 + and candidate_tensor_users == 1 + ): + del V.graph.constants[node.target] + delattr(V.graph.module, node.target) + delattr(V.graph.graph.owning_module, node.target) + counters["inductor"]["select_algorithm_weight_prune"] += 1 + + +def gen_2d_view_of_epilogue_buf( + Y: ir.Buffer, + template_buffer: ir.Buffer, + epilogue_nodes: list[ir.IRNode], + reindexers: list[Optional[Callable[[list[Any]], list[Any]]]], + default_reindexers: list[Optional[Callable[[list[Any]], list[Any]]]], +) -> tuple[ + Union[ir.Buffer, ir.ReinterpretView], + list[Optional[Callable[[list[Any]], list[Any]]]], +]: + """ + The dimension and the indexing could be different between the GEMM output, i.e. `template_buffer`, which is + 2D with MxN) and the output from the template after epilogues, i.e. `Y`. In the GEMM template code, + we are not aware of the dimension and the indexing of the epilogues and always work on 2D tiles according to + the indexing of the GEMM output. + In this function, we return a 2D buffer (`Y_2d`) according to GEMM output (reinterpreted from `Y` if needed) and + build a reindexer that converts the indexing of `Y` into `Y_2d`. 
+ """ + Y_2d: Union[ir.Buffer, ir.ReinterpretView] = Y + if ( + Y.get_size() == template_buffer.get_size() + and Y.get_stride() == template_buffer.get_stride() + ): + reindexers.extend(default_reindexers) + Y_2d = Y + else: + + def get_reindexer(epilogue_node, default_reindexer=None): + # From template_buffer to epilogue_node_ordered (ordered by stride decreasingly, in dense format), for example: + # template_buffer: + # size (324, 512), stride (512, 1) + # epilogue_node_ordered (ordered by stride decreasingly, in dense format): + # size (1, 18, 18, 512), stride (165888, 9216, 512, 1) + stride_order = list( + ir.get_stride_order( + V.graph.sizevars.size_hints(epilogue_node.get_stride()) + ) + ) + fill_order = ir.stride_order2fill_order(stride_order) + reversed_fill_order = list(reversed(fill_order)) + size_with_stride_ordered_decreasingly = [ + epilogue_node.get_size()[i] for i in reversed_fill_order + ] + reshape_reindex = ir.View.dynamic_reshape_indexer( + size_with_stride_ordered_decreasingly, + template_buffer.get_size(), + ) + if default_reindexer: + reshape_reindex = ir.fuse_reindexing(reshape_reindex, default_reindexer) + + # From epilogue_node_ordered (ordered by stride decreasingly, in dense format) to epilogue_node, for example: + # epilogue_node_ordered (ordered by stride decreasingly, in dense format): + # size (1, 18, 18, 512), stride (165888, 9216, 512, 1) + # epilogue_node: + # size (1, 18, 18, 512), stride (165888, 1, 9216, 512) + from_stride_ordered_decreasingly_to_epilogue_node_order = [ + (len(stride_order) - 1) - stride_order[i] + for i in range(len(stride_order)) + ] + stride_reindex = ir.same_reorder( + from_stride_ordered_decreasingly_to_epilogue_node_order + ) + + reindexer = ir.fuse_reindexing(stride_reindex, reshape_reindex) # type: ignore[var-annotated] + return reindexer + + if default_reindexers is None: + default_reindexers = [None] * len(epilogue_nodes) + new_reindexers = [ + get_reindexer(epilogue_node, default_reindexer) + for epilogue_node, default_reindexer in zip( + epilogue_nodes, default_reindexers + ) + ] + reindexers.extend(new_reindexers) + if isinstance(Y, ir.BaseView): + storage = ir.StorageBox(Y.unwrap_view()) + else: + assert isinstance(Y, ir.Buffer) + storage = ir.StorageBox(Y) + Y_2d = ir.ReinterpretView(data=storage, layout=template_buffer.get_layout()) + return Y_2d, reindexers + + +class CppGemmTemplate(CppTemplate): + """ + GEMM Template for Inductor CPP Backend. 
+ """ + + def __init__( + self, + input_nodes, + layout: ir.Layout, + num_threads: int, + register_blocking: GemmBlocking, + beta=1, + alpha=1, + has_bias=False, + epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None, + should_block_weights: bool = True, + name="packed_gemm", + ) -> None: + assert layout.dtype in [torch.float, torch.bfloat16, torch.half, torch.uint8] + super().__init__( + name, + input_nodes, + layout, + num_threads, + epilogue_creator=epilogue_creator, + ) + self.beta = beta + self.alpha = alpha + self.has_bias = has_bias + self.register_blocking = register_blocking + m, n = layout.size[-2:] + k = input_nodes[0].get_size()[-1] + self.m, self.n, self.k = m, n, k + self.padded_n = get_padded_n(n, self.register_blocking.block_n) + self.is_dynamic_M = has_free_symbols((m,)) + self.should_block_weights = should_block_weights + self.thread_blocking = self.make_thread_blocking_cache() + self.cache_blocking = self.make_cache_blocking_cache() + + def make_thread_blocking_cache(self): + cache = lru_cache()(self._thread_blocking) + + def thread_blocking(num_threads: int) -> GemmBlocking: + return cache(num_threads) + + return thread_blocking + + def _thread_blocking(self, num_threads: int) -> GemmBlocking: + """ + NOTE [Thread blocking in Cpp GEMM] + We use simple heuristics to decide the thread blocking: + 1. Make sure all threads are occupied as much as possible. + 2. For (m, n) blocks, favor more square-sized thread blocks for better data reuse. + 3. If (m, n) blocks cannot occupy all the threads, we consider k-slicing. + TODO(jgong5): allow tuning various blocking options + """ + + def get_factors(number): + factors = [] + for i in range(int(number**0.5), 0, -1): + if number % i == 0: + factors.append(number // i) + factors.append(i) + return factors + + def get_blocking(m_factor, n_factor, k_factor, m_blocks, n_blocks, k_blocks): + thread_block_k = math.ceil(k_blocks / k_factor) + thread_block_n = math.ceil(n_blocks / n_factor) + thread_block_m = math.ceil(m_blocks / m_factor) + return GemmBlocking(thread_block_m, thread_block_n, thread_block_k) + + assert not self.is_dynamic_M, ( + "Unable to determine thread blocking for dynamic M." 
+ ) + register_blocking = self.register_blocking + m_blocks = math.ceil(self.m / register_blocking.block_m) + n_blocks = math.ceil(self.n / register_blocking.block_n) + k_blocks = math.ceil(self.k / register_blocking.block_k) + factors = get_factors(num_threads) + assert len(factors) > 0 + + if config.cpp.gemm_thread_factors is not None: + factors = [int(i) for i in config.cpp.gemm_thread_factors.split(",")] + assert len(factors) == 3 + assert math.prod(factors) == self.num_threads + return get_blocking( + factors[0], factors[1], factors[2], m_blocks, n_blocks, k_blocks + ) + + # we favor square-sized thread blocks for good data reuse + def get_better_blocking(blocking, best_blocking): + if best_blocking is None: + best_blocking = blocking + else: + block_m_size = blocking.block_m * register_blocking.block_m + block_n_size = blocking.block_n * register_blocking.block_n + best_block_m_size = best_blocking.block_m * register_blocking.block_m + best_block_n_size = best_blocking.block_n * register_blocking.block_n + if blocking.block_k > best_blocking.block_k: + best_blocking = blocking + elif ( + blocking.block_k == best_blocking.block_k + and block_m_size + block_n_size + < best_block_m_size + best_block_n_size + ): + best_blocking = blocking + return best_blocking + + best_blocking = None + # check if we can have a thread-blocking to occupy all threads without k-slicing + for n_factor in factors: + m_factor = num_threads // n_factor + if n_blocks >= n_factor and m_blocks >= m_factor: + blocking = get_blocking( + m_factor, n_factor, 1, m_blocks, n_blocks, k_blocks + ) + best_blocking = get_better_blocking(blocking, best_blocking) + + if best_blocking is None: + for k_factor in factors: + if k_blocks >= k_factor and ( + config.cpp.gemm_max_k_slices == 0 + or k_factor <= config.cpp.gemm_max_k_slices + ): + n_factors = get_factors(num_threads // k_factor) + for n_factor in n_factors: + m_factor = (num_threads // k_factor) // n_factor + if n_blocks >= n_factor and m_blocks >= m_factor: + blocking = get_blocking( + m_factor, + n_factor, + k_factor, + m_blocks, + n_blocks, + k_blocks, + ) + best_blocking = get_better_blocking(blocking, best_blocking) + + if best_blocking is None: + for n_factor in factors: + m_factor = num_threads // n_factor + if n_blocks >= n_factor or m_blocks >= m_factor: + blocking = get_blocking( + m_factor, n_factor, 1, m_blocks, n_blocks, k_blocks + ) + best_blocking = get_better_blocking(blocking, best_blocking) + + assert best_blocking is not None + return best_blocking + + def make_cache_blocking_cache(self): + cache = lru_cache()(self._cache_blocking) + + def cache_blocking(num_threads: int) -> GemmBlocking: + return cache(num_threads) + + return cache_blocking + + def _cache_blocking(self, num_threads: int) -> GemmBlocking: + def get_cache_blocking(register_blocking, thread_blocking): + Mr = register_blocking.block_m + Nr = register_blocking.block_n + Kr = register_blocking.block_k + + Mt_blocks = thread_blocking.block_m + Nt_blocks = thread_blocking.block_n + Kt_blocks = thread_blocking.block_k + + if config.cpp.gemm_cache_blocking is not None: + blockings = [int(i) for i in config.cpp.gemm_cache_blocking.split(",")] + assert len(blockings) == 3 + Mc_blocks, Nc_blocks, Kc_blocks = blockings + return ( + min(Mc_blocks, Mt_blocks), + min(Nc_blocks, Nt_blocks), + min(Kc_blocks, Kt_blocks), + ) + + # The ratios below are empirically determined to decide + # the effective sizes of L1 and L2. 
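+            # For example (illustrative numbers): with a 48 KiB per-core L1d and
+            # L1_limit_factor = 0.8, roughly 38 KiB is budgeted for the Kc x Nr
+            # cache block of B, i.e. Kc_blocks ~= 38 KiB / (Kr * Nr * num_byte_B)
+            # (see Step 1 below).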
+ # TODO: tune the factor here + L1_limit_factor = 0.8 + L2_limit_factor = 0.5 + + L1_cache_size = ( + torch._C._cpu._L1d_cache_size() + ) # per core cache size in Bytes + assert L1_cache_size > 0, ( + f"Expect L1_cache_size > 0 but got {L1_cache_size}" + ) + L1 = L1_cache_size * L1_limit_factor + + L2_cache_size = ( + torch._C._cpu._L2_cache_size() + ) # per core cache size in Bytes + assert L2_cache_size > 0, ( + f"Expect L2_cache_size > 0 but got {L2_cache_size}" + ) + L2 = L2_cache_size * L2_limit_factor + + def get_num_byte(dtype): + return torch.tensor([], dtype=dtype).element_size() + + dtype_A = self.input_nodes[0].get_dtype() + dtype_B = self.input_nodes[1].get_dtype() + num_byte_A = get_num_byte(dtype_A) + num_byte_B = get_num_byte(dtype_B) + if dtype_A is torch.bfloat16 and dtype_B is torch.int8 and Kr != 1: + # We will cache dequantized weights (BF16) in L1D for AMX micro-kernel. + # In this case, the choice of the micro-kernel being used can't be decoupled from + # the cache blocking. + # TODO: Decouple the choice of micro-kernel from cache blocking + num_byte_B *= num_byte_A + + # NOTE [CPP GEMM Cache Blocking Algorithm] + # Our overall strategy is to + # 1) Make cache blocks of B L1-reside and reused by multiple rows of A, i.e. Mc. + # Here, B is Kc x Nr where Nr is a single register block. We use L1 size to + # decide Kc. We want to make Mc large enough to better reuse B. + # 2) Make cache blocks of A L2-reside, which would limit Mc. We want to reuse A + # along N, where we have two sub-strategies (see notes below) to decide Mc and Nc. + + # Step 1: Decide Kc assuming B block is L1-reside. + size_cache_B = Kr * Kt_blocks * Nr * num_byte_B + + Kc_blocks = Kt_blocks + if size_cache_B > L1: + Kc_blocks = math.floor(L1 / (Kr * Nr * num_byte_B)) + + if ( + config.cpp.use_small_dequant_buffer + and dtype_A is torch.bfloat16 + and dtype_B is torch.uint8 + and Mt_blocks == 1 + ): + # Make a small dequant_B buffer for woq int4 [q_group_size, Nr] + # Since when Mt_blocks == 1, L1-reside B block can't be reused by A. + if Kc_blocks * Kr >= self.q_group_size(): + Kc_blocks = self.q_group_size() // Kr + + # Step 2: Decide Mc assuming A block is L2-reside. + min_Mc_ratio = 2 # TODO(jgong5): something to tune? + min_Mc_blocks = math.ceil(min_Mc_ratio * Mr / Nr) + assert min_Mc_blocks >= 1 + Kt_bytes = Kt_blocks * Kr * num_byte_A + if min_Mc_blocks * Mr * Kt_bytes < L2: + # Strategy 1: A (Mc x Kt) resides in L2 and reused by all Nt + # when Nc_blocks is kept 1. Mc should be large enough (>= min_Mc_blocks) + # to reuse B (Kc x Nr) in L1. This makes C (Mc x Nr) small enough to reside + # in L1. + Mc_blocks = min(Mt_blocks, math.floor(L2 / (Mr * Kt_bytes))) + Nc_blocks = 1 + else: + # Strategy 2: Kt is too large to hold A (Mc x Kt) in L2, we reuse + # A (Mc x Kc) in L2 by B (Kc x Nc). C (Mc x Nc) resides in L2. + Mc_blocks = Mt_blocks + Nc_blocks = min(math.ceil(Mc_blocks * Mr / Nr), Nt_blocks) + Nc_bytes = Nc_blocks * Nr * 4 # assume C or acc is float32/int32 + Kc_bytes = Kc_blocks * Kr * num_byte_A + if Mc_blocks * Mr * (Kc_bytes + Nc_bytes) > L2: + # The following is the solution for 4*Mc*Nc + Mc*Kc_bytes = L2, + # assuming Mc == Nc for good data reuse. + M_max = (math.sqrt(Kc_bytes * Kc_bytes + 16 * L2) - Kc_bytes) / 8 + if M_max < Mc_blocks * Mr: + Mc_blocks = math.floor(M_max / Mr) + Nc_blocks = min(math.ceil(Mc_blocks * Mr / Nr), Nt_blocks) + + return Mc_blocks, Nc_blocks, Kc_blocks + + assert not self.is_dynamic_M, ( + "Unable to determine cache blocking for dynamic M." 
+ ) + register_blocking = self.register_blocking + thread_blocking = self.thread_blocking(num_threads) + + return GemmBlocking(*get_cache_blocking(register_blocking, thread_blocking)) + + def log_blockings(self): + log.debug(f"Register blocking: {self.register_blocking}") # noqa: G004 + if self.is_dynamic_M: + # thread and cache blockings are determined at runtime for dynamic shapes + return + log.debug( + f"Cache blocking: {self.cache_blocking(self.num_threads)}" # noqa: G004 + ) + thread_blocking = self.thread_blocking(self.num_threads) + log.debug(f"Thread blocking: {thread_blocking}") # noqa: G004 + + def get_occupancy(): + m_blocks = math.ceil(self.m / self.register_blocking.block_m) + n_blocks = math.ceil(self.n / self.register_blocking.block_n) + k_blocks = math.ceil(self.k / self.register_blocking.block_k) + m = math.ceil(m_blocks / thread_blocking.block_m) + n = math.ceil(n_blocks / thread_blocking.block_n) + k = math.ceil(k_blocks / thread_blocking.block_k) + return (m, n, k) + + log.debug( + f"Number of threads: {self.num_threads}, occupancy: {get_occupancy()}" # noqa: G004 + ) + + def maybe_k_slicing(self): + if self.num_threads == 1: + return False + if self.is_dynamic_M: + # TODO(jgong5): perhaps use size hint to decide? + return True + register_blocking = self.register_blocking + k_blocks = math.ceil(self.k / register_blocking.block_k) + thread_blocking = self.thread_blocking(self.num_threads) + return k_blocks > thread_blocking.block_k + + @classmethod + def add_choices( + cls, + choices, + layout, + input_nodes, + beta=1, + alpha=1, + has_bias=False, + trans_w=False, + input_indices=None, + epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None, + act_mapping: Optional[dict[int, ir.IRNode]] = None, + ): + """ + Add choices for the GEMM template. + """ + # Fast path to save the epilogue calculation when x_scale/x_zp/w_scale are constant + use_int8_fast_compensation_path = _is_int8_gemm(input_nodes) and all( + ( + isinstance(input_nodes[idx], ir.TensorBox) + and isinstance(input_nodes[idx].data.data, ir.ConstantBuffer) + ) + for idx in [1, 2, 4] + ) + + if input_indices is None: + input_indices = list(range(len(input_nodes))) + only_one_input = ( + input_nodes[0] == input_nodes[1] if len(input_nodes) > 1 else False + ) + + def reorder_and_filter(inputs, layout_or_out): + if has_bias: + assert len(input_indices) >= 3 + # Assume the input order is [inp, x, w] and we reorder it to [x, w, inp] + inp_idx = input_indices[0] + x_idx = input_indices[1] + w_idx = input_indices[2] + return [ + inputs[x_idx], + inputs[w_idx], + inputs[inp_idx], + *[inputs[idx] for idx in input_indices[3:]], + ], layout_or_out + elif len(inputs) >= len(input_indices): + assert len(input_indices) >= 2 + return [inputs[idx] for idx in input_indices], layout_or_out + else: + # For when input is used for x and w, i.e. X@X.T or similar + # Assumes the first input is the only input + assert len(inputs) == 1 + return [inputs[0]] * len(input_indices), layout_or_out + + new_inputs, new_layout = reorder_and_filter(input_nodes, layout) + is_mkldnn_wgt = ( + new_inputs[1].get_name() in V.graph.constants + and V.graph.constants[new_inputs[1].get_name()].is_mkldnn + ) + if is_mkldnn_wgt: + # It shouldn't happen as viewing an mkldnn tensor, we can extend the + # implementation if it does. 
+ assert not isinstance(new_inputs[1], ir.BaseView) + # Note that the layout of MKLDNN Tensor is with the wrong stride + view_size = new_inputs[1].layout.size + view_stride = new_inputs[1].layout.stride + view_offset = new_inputs[1].layout.offset + + def maybe_to_dense(inputs, layout_or_out): + new_inputs = list(inputs) + if isinstance(inputs[1], torch.Tensor): + W = inputs[1] + new_inputs[1] = W.to_dense() if W.is_mkldnn else W + return new_inputs, layout_or_out + + def normalize_shapes(inputs, layout_or_out): + new_inputs = list(inputs) + if not is_mkldnn_wgt and isinstance(new_inputs[1], torch.Tensor): + if has_free_symbols(view_size): + # If batch size B is dynamic, we need to set the batch size and possibly stride + assert not has_free_symbols(view_size[1:]) + view_size[:] = V.graph.sizevars.size_hints(view_size) + view_stride[:] = V.graph.sizevars.size_hints(view_stride) + # With the assumptation that W is the storage of unwrap view + # thus view it back here + new_inputs[1] = new_inputs[1].as_strided( + view_size, view_stride, view_offset + ) + + if not trans_w: + return new_inputs, layout_or_out + X = new_inputs[0] + W = new_inputs[1] + B = new_inputs[2] if has_bias else None + W = transpose_w(W, trans_w) + B = expand_bias(B, X) # type:ignore[arg-type] + new_inputs[1] = W + if B is not None: + new_inputs[2] = B + return new_inputs, layout_or_out + + # TODO(jgong5): decide proper number of threads per problem size + num_threads = parallel_num_threads() + new_inputs, _ = normalize_shapes(*maybe_to_dense(new_inputs, new_layout)) + m, n, k, *_ = mm_args( + new_inputs[0], + new_inputs[1], + mat2_transposed=cls.is_woq_int4(), + use_4x2_dim=cls.is_woq_int4(), + ) + output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype( + new_inputs[0].get_dtype() + ) + micro_gemm = create_micro_gemm( + "micro_gemm", + m, + n, + k, + input_dtype=new_inputs[0].get_dtype(), + input2_dtype=new_inputs[1].get_dtype(), + output_dtype=output_dtype, + compute_dtype=compute_dtype, + alpha=alpha, + num_threads=num_threads, + use_ref=not cls.is_woq_int4(), + q_group_size=cls.q_group_size(), + ) + assert micro_gemm is not None + pre_block_weights = cls.check_if_block_weight(new_inputs[1], micro_gemm) + micro_gemm.use_local_vnni_blocking(not pre_block_weights) + + def preprocessor(inputs, layout): + new_inputs, new_layout = normalize_shapes( + *maybe_to_dense(*reorder_and_filter(inputs, layout)) + ) + if only_one_input and isinstance(new_inputs[0], torch.Tensor): + return new_inputs[1:], new_layout + return cls.prep_weight( + new_inputs, + new_layout, + micro_gemm, + pre_block_weights, + use_int8_fast_compensation_path, + ) + + def postprocessor(output): + if isinstance(output, ir.TensorBox): + # prepack the weight as input to the template buffer + template_buffer = ir.InputsKernel.unwrap_storage_for_input(output) + assert isinstance(template_buffer, ir.CppTemplateBuffer) + new_input_nodes, _ = reorder_and_filter(input_nodes, layout) + + W_node = new_input_nodes[1] + if W_node.get_name() not in V.graph.constants: + return output + W = V.graph.constants[W_node.get_name()] + new_input_nodes[1] = W + new_input_nodes, new_layout = normalize_shapes( + *maybe_to_dense(new_input_nodes, layout) + ) + new_input_nodes, _ = cls.prep_weight( + new_input_nodes, + new_layout, + micro_gemm, + pre_block_weights, + use_int8_fast_compensation_path, + skip_int8_compensation=True, + ) + W_packed = new_input_nodes[1] + W_packed_constant = V.graph.add_tensor_constant(W_packed) + new_input_nodes[1] = W_packed_constant + + 
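+ # At this point new_input_nodes[1] holds the constant weight already blocked
+ # (and VNNI-packed where the micro-kernel requires it) by prep_weight and
+ # registered as a new graph constant, so the original unpacked weight
+ # constant is no longer referenced and can be pruned below.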
# Prune unused tensors + prune_tensors(input_nodes, new_input_nodes) + + template_buffer.inputs[1] = ir.InputsKernel.unwrap_storage_for_input( + W_packed_constant + ) + return output + + template = DataProcessorTemplateWrapper( + cls, + preprocessor, + postprocessor, + input_nodes=input_nodes, + layout=layout, + num_threads=num_threads, + register_blocking=micro_gemm.register_blocking, + beta=beta, + alpha=alpha, + has_bias=has_bias, + epilogue_creator=epilogue_creator, + should_block_weights=pre_block_weights, + name=micro_gemm.__class__.__name__, + ) + template.maybe_append_choice(choices) + return template + + @staticmethod + def get_padded_size(n, block_n, k, should_block_weight): + padded_n = get_padded_n(n, block_n) + # We assume that all GEMM weight tensors should be blocked and padded + new_size = [padded_n // block_n, k, block_n] + return new_size, padded_n + + @classmethod + def prep_weight( + cls, + inputs, + layout: ir.Layout, + micro_gemm: CppMicroGemm, + should_block_weight: bool, + use_int8_fast_compensation_path: bool = False, + skip_int8_compensation: bool = False, + ): + """ + NOTE Weight prep consists of 2 separate steps: + 1. Blocking the weight tensor into a 3D shape: [n//block_n, k, block_n] + This is always done if the weight tensor is constant, i.e. for all GEMM and some BMM. + For BMM, we also block non-contiguous weight tensors, since they would be reshaped anyway. + This assumes that blocked, contiguous weights will be more efficient for the GEMM kernel, + and is worth the overhead of reshape and blocking. + + This blocking includes additional padding, when n is not a multiple of block_n. + This padding allows a more efficient microkernel implementation. For BMM, this is only done + if reshape would happen anyway, i.e. if the weight tensor is constant, is not contiguous, + or is using AMX VNNI layout. + 2. Packing the weight tensor into a VNNI-friendly shape. For constant input, + this is done at the same time as the weight blocking. + + At compile time, the constant weight tensors are blocked and packed. For non-constant tensors (e.g. BMM) + which will be blocked (non-contiguous or VNNI-layout tensors), the weight tensor is blocked and packed at runtime. + + CppBmmTemplate overrides the methods get_padded_size, and block_weight in order to accommodate + an additional dimension for the batch size and to determine if the weight tensor should be blocked. + """ + W = inputs[1] + new_inputs = list(inputs) + if cls.is_woq_int4(): + assert ( + len(W.get_size()) == 2 + if isinstance(W, ir.IRNode) + else len(W.shape) == 2 + ) + n, k = W.get_size() if isinstance(W, ir.IRNode) else W.shape + else: + k, n = W.get_size()[-2:] if isinstance(W, ir.IRNode) else W.shape[-2:] + _, block_n, _ = micro_gemm.register_blocking + new_size, padded_n = cls.get_padded_size(n, block_n, k, should_block_weight) + padding = padded_n - n + + if should_block_weight and not cls.is_woq_int4(): + blocked_w = cls.block_weight(W, new_size, padding) + new_inputs[1] = cls.pack_vnni_weight(blocked_w, micro_gemm, new_size) + elif should_block_weight: + assert cls.is_woq_int4() + new_inputs[1] = cls.block_weight(W, new_size, padding) + elif isinstance(W, ir.IRNode): + # Require W layout to be fixed & contiguous, happens inplace. 
+ ir.ExternKernel.require_contiguous(W) + + if not skip_int8_compensation and _is_int8_gemm(new_inputs): + BCompensate = None + x_w_scale = None + + def _get_compensation_node(W, use_int8_fast_compensation_path): + BCompensate = V.graph.add_tensor_constant( + V.graph.constants[W.get_name() + "_BMatrixCompens"], + W.get_name() + "_BMatrixCompens", + ) + x_w_scale = None + if use_int8_fast_compensation_path: + x_w_scale = V.graph.add_tensor_constant( + V.graph.constants[W.get_name() + "_x_w_compens"], + W.get_name() + "_x_w_compens", + ) + return BCompensate, x_w_scale + + if use_int8_fast_compensation_path: + # new_inputs has been reordered: [x, w, optional[bias], x_scale, x_zp, w_scale, w_zp] + x_scale = new_inputs[-4] + x_zp = new_inputs[-3] + w_scale = new_inputs[-2] + if isinstance(W, ir.IRNode): + BCompensate, x_w_scale = _get_compensation_node( + W, use_int8_fast_compensation_path + ) + else: + # Use the original W, not the blocked_w in new_inputs[1] to calculate BCompensate + BCompensate = torch.sum(W.to_dense().to(torch.float), dim=0) # type: ignore[assignment] + assert all( + isinstance(item, torch.Tensor) + for item in (x_scale, x_zp, w_scale) + ) + BCompensate = BCompensate * x_scale * w_scale * x_zp + x_w_scale = x_scale * w_scale + new_inputs.append(BCompensate) + new_inputs.append(x_w_scale) + else: + if isinstance(W, ir.IRNode): + BCompensate, _ = _get_compensation_node( + W, use_int8_fast_compensation_path + ) + else: + # Use the original W, not the blocked_w in new_inputs[1] to calculate BCompensate + BCompensate = torch.sum(W.to_dense().to(torch.float), dim=0) # type: ignore[assignment] + new_inputs.append(BCompensate) + return new_inputs, layout + + @staticmethod + def check_if_block_weight(W, micro_gemm): + return True + + @classmethod + def block_weight(cls, W, new_size, padding): + # These are separated into two methods to allow subclasses to override them separately + if isinstance(W, ir.IRNode): + if W.get_name() in V.graph.constants: + # Create a new buffer, representing the constant blocked tensor + blocked_w = ir.Buffer( + name=W.get_name(), # Borrow the registered buffer name + layout=ir.FixedLayout( + W.get_device_or_error(), + W.get_dtype(), + new_size, + ir.FlexibleLayout.contiguous_strides(new_size), + 0, + ), + ) + else: + if not isinstance(W, ir.TensorBox): + W = ir.TensorBox(W) + permute_dims = list(range(len(new_size))) + permute_dims[-2], permute_dims[-3] = permute_dims[-3], permute_dims[-2] + permute_size = list(new_size) + permute_size[-2], permute_size[-3] = permute_size[-3], permute_size[-2] + blocked_w = L.constant_pad_nd(W, (0, padding)) + blocked_w = L.permute( + L.view(blocked_w, permute_size), + permute_dims, + ) + else: + assert isinstance(W, torch.Tensor) + # Pad the weight tensor and reshape it into a 3D blocked shape + blocked_size = list(new_size) + blocked_size[-2], blocked_size[-3] = blocked_size[-3], blocked_size[-2] + blocked_w = ( + torch.nn.functional.pad(W, (0, padding)) # type: ignore[assignment] + .reshape(*blocked_size) + .transpose(-3, -2) + .contiguous() + ) + return blocked_w + + @classmethod + def pack_vnni_weight(cls, W, micro_gemm, new_size): + # WOQ INT4 weights are reordered in microkernel so do not pack them here + should_pack = ( + micro_gemm.get_b_layout() != LayoutType.NORMAL + and not micro_gemm.is_woq_int4() + ) + + # These are separated into two methods to allow subclasses to override them separately + if isinstance(W, ir.IRNode): + if isinstance(W, ir.Buffer) and W.get_name() in V.graph.constants: + return W + k = 
new_size[-2] + if not isinstance(W, ir.TensorBox): + W = ir.TensorBox(W) + if should_pack: + permute_dims = list(range(len(new_size) + 1)) + permute_dims[-1], permute_dims[-2] = permute_dims[-2], permute_dims[-1] + vnni_size = 4 if micro_gemm.get_b_layout() == LayoutType.VNNI4 else 2 + vnni_view_size = list(new_size) + vnni_view_size[-2] = k // vnni_size + vnni_view_size.insert(-1, vnni_size) + W = L.view( + L.permute(L.view(W, vnni_view_size), permute_dims), + new_size, + ) + W = ir.ExternKernel.realize_input(W) + W = ir.ExternKernel.require_contiguous(W) + return W + else: + k = new_size[-2] + # Apply VNNI packing to the weight tensor + if should_pack: + # TODO: Move VNNI weight packing for non-constant tensors into the template, + # to improve cache locality and avoid full-tensor copy. + layout_str = ( + "VNNI4" + if micro_gemm.get_b_layout() == LayoutType.VNNI4 + else "VNNI2" + ) + assert micro_gemm.get_b_layout() in [ + LayoutType.VNNI2, + LayoutType.VNNI4, + ], f"We only support {layout_str} for now" + vnni_size = 4 if micro_gemm.get_b_layout() == LayoutType.VNNI4 else 2 + assert k % vnni_size == 0, ( + f"k should be divisible by vnni_size for {layout_str} layout" + ) + vnni_view_size = list(new_size) + vnni_view_size[-2] = k // vnni_size + vnni_view_size.insert(-1, vnni_size) + W = W.view(vnni_view_size).transpose(-1, -2).contiguous().view(new_size) + # normalize stride to be "contiguous_strides" per size + # this avoids the problems in L.view during template codegen + new_stride = [1] + for sz in reversed(W.shape[1:]): + new_stride.insert(0, new_stride[0] * sz) + W = W.as_strided(W.shape, new_stride) + return W + + def get_default_reindexers(self, epilogue_nodes): + return [None] * len(epilogue_nodes) + + def get_options( + self, + kernel: CppTemplateKernel, + template_buffer_node: Optional[ir.CppTemplateBuffer] = None, + flag_template_buffer_has_other_users: Optional[bool] = None, + epilogue_nodes: Optional[list[ir.IRNode]] = None, + ) -> dict[str, Any]: + assert len(self.input_nodes) >= 2 + + int8_gemm = self.input_nodes[0].get_dtype() in [torch.uint8, torch.int8] + x_scale = None + x_zp = None + w_scale = None + w_zp = None + inp = None + q_group_size_node = None + qscale_and_zeros = None + if int8_gemm: + X, W = self.input_nodes[0], self.input_nodes[1] + bias_idx = 2 if self.has_bias else 1 + inp = self.input_nodes[bias_idx] if self.has_bias else None + x_scale = self.input_nodes[bias_idx + 1] + x_zp = self.input_nodes[bias_idx + 2] + w_scale = self.input_nodes[bias_idx + 3] + w_zp = self.input_nodes[bias_idx + 4] + Y = self.output_node + elif self.is_woq_int4(): + X, W = self.input_nodes[0], self.input_nodes[1] + Y = self.output_node + q_group_size_node = self.input_nodes[2] + qscale_and_zeros = self.input_nodes[3] + else: + X, W = self.input_nodes[0], self.input_nodes[1] + Y = self.output_node + inp = self.input_nodes[2] if self.has_bias else None + + template_buffer_has_other_users = None + + if template_buffer_node is not None: + # Use the updated prepacked weight buffer + W = template_buffer_node.inputs[1] + Y = template_buffer_node + + assert flag_template_buffer_has_other_users is not None + template_buffer_has_other_users = flag_template_buffer_has_other_users + + template_buffer = Y + gemm_output_buffer = template_buffer + + epilogues: list[ir.IRNode] = [] + reindexers: list[Optional[Callable[[list[Any]], list[Any]]]] = [] + epilogue_creators: list[Callable[[ir.Buffer], ir.Pointwise]] = [] + fake_buffers: list[ir.Buffer] = [] + Y_aliases: OrderedSet[str] = OrderedSet() 
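+ # A thread-local accumulation buffer is used below whenever the fp32/int32
+ # accumulator cannot be written straight into the output buffer: non-fp32
+ # output dtype, int8 GEMM, padded N, k-slicing, epilogues producing a
+ # different dtype, or other users of the template buffer.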
+ + use_local_acc = ( + self.layout.dtype != torch.float + or template_buffer_has_other_users + or int8_gemm + or self.padded_n != self.n + or self.maybe_k_slicing() + or (epilogue_nodes and epilogue_nodes[-1].get_dtype() != self.layout.dtype) + ) + + # TODO(jgong5): for int8 gemm, bias-add is handled outside of gemm template, + # but we'd better move it here to align with fp. + if inp is not None and self.beta != 0 and not int8_gemm: + # add an epilogue for bias add + def _bias_add_epilogue(buf): + return create_epilogue_with_attr( + buf, "bias_add", other=inp, beta=self.beta, dtype=self.layout.dtype + ) + + epilogue_creators.append(_bias_add_epilogue) + + if self.epilogue_creator is not None: + epilogue_creators.append(self.epilogue_creator) + + # When the GEMM output buffer is localized but it has users other than the epilogue nodes, + # we need to copy the value in the GEMM output local buffer to a global buffer. + def need_copy_from_local_to_global_buffer_epilogue( + use_local_acc, template_buffer_has_other_users, epilogue_creators + ): + # The GEMM output buffer is a global buffer, thus copy is not needed. + if not use_local_acc: + return False + + # The possible value of template_buffer_has_other_users is (None, False, True) + # It is None when generating the gemm template during autotune and it will have value during scheduler codegen. + # extra copy_from_local_to_global_buffer_epilogue is not needed in either of the below two cases: + # 1. template_buffer_has_other_users is None (i.e. when doing the codegen during autotune) + # 2. template_buffer_has_other_users is False, which means it's safe to keep the value in the + # GEMM output buffer in local buffer only (no users outside of the epilogues will use its value). + if not template_buffer_has_other_users: + return False + + # When bias is not None or self.epilogue_creator is not None, + # there will be epilogue_creators after the GEMM. + # The GEMM output buffer is localized while + # the output buffer of the epilogue_creators is a global buffer. 
+ if epilogue_creators: + return False + + return True + + if need_copy_from_local_to_global_buffer_epilogue( + use_local_acc, template_buffer_has_other_users, epilogue_creators + ): + + def copy_from_local_to_global_buffer_epilogue(input_buffer: ir.Buffer): + dtype = self.layout.dtype + input_loader = input_buffer.make_loader() + + def copy_inner(index): + input = input_loader(index) + result = ops.to_dtype(input, dtype) + return result + + return ir.Pointwise( + device=input_buffer.get_device_or_error(), + dtype=self.layout.dtype, + inner_fn=copy_inner, + ranges=input_buffer.get_size(), + ) + + epilogue_creators.append(copy_from_local_to_global_buffer_epilogue) + + # NOTE [How CPP GEMM template epilogues are organized] + # gemm_output_buffer + # --> zero or more in-template epilogues (created by `epilogue_creators`) --> + # template_buffer + # --> zero or more out-of-template epilogues (`epilogue_nodes`) --> + # Y + if epilogue_creators: + assert isinstance(template_buffer, ir.IRNode) + gemm_output_name = f"{template_buffer.get_name()}_GemmOut" + gemm_output_buffer = ir.Buffer( + name=gemm_output_name, layout=template_buffer.layout + ) + current_input_buffer = gemm_output_buffer + for i, creator in enumerate(epilogue_creators): + if i == len(epilogue_creators) - 1: + buffer_name = template_buffer.get_name() + else: + buffer_name = f"{gemm_output_name}_epilogue_{i}" + epilogues.append( + ir.ComputedBuffer( + name=buffer_name, + layout=template_buffer.layout, + data=creator(current_input_buffer), + ) + ) + fake_buffers.append(current_input_buffer) + Y_aliases.add(current_input_buffer.get_name()) + reindexers.append(None) + if i < len(epilogue_creators) - 1: + current_input_buffer = ir.Buffer( + name=buffer_name, layout=template_buffer.layout + ) + + assert isinstance(Y, (ir.Buffer, ir.ReinterpretView)) + Y_2d: Union[ir.Buffer, ir.ReinterpretView] = Y + + if epilogue_nodes: + if not template_buffer_has_other_users: + assert isinstance(template_buffer, ir.IRNode) + Y_aliases.add(template_buffer.get_name()) + epilogues.extend(epilogue_nodes) + assert Y.get_numel() == epilogues[-1].get_numel() + Y = cast(ir.Buffer, epilogues[-1]) + assert isinstance(template_buffer, ir.Buffer) + Y_2d, reindexers = gen_2d_view_of_epilogue_buf( + Y, + template_buffer, + epilogue_nodes, + reindexers, + default_reindexers=self.get_default_reindexers(epilogue_nodes), + ) + + output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype( + X.get_dtype() + ) + micro_gemm = create_micro_gemm( + f"{kernel.kernel_name}_micro_gemm", + self.m, + self.n, + self.k, + input_dtype=X.get_dtype(), + input2_dtype=W.get_dtype(), + output_dtype=output_dtype, + compute_dtype=compute_dtype, + alpha=self.alpha, + num_threads=self.num_threads, + use_ref=not self.is_woq_int4(), + q_group_size=self.q_group_size(), + ) + assert micro_gemm is not None + micro_gemm.use_local_vnni_blocking(not self.should_block_weights) + assert self.register_blocking == micro_gemm.register_blocking + self.log_blockings() + if isinstance(micro_gemm, CppMicroGemmAMX): + counters["inductor"]["cpp_micro_gemm_amx_counter"] += 1 + if isinstance(micro_gemm, CppMicroBrgemm): + counters["inductor"]["cpp_micro_brgemm_counter"] += 1 + + L1_cache_size = torch._C._cpu._L1d_cache_size() # per core cache size in Bytes + assert L1_cache_size > 0, f"Expect L1_cache_size > 0 but got {L1_cache_size}" + + L2_cache_size = torch._C._cpu._L2_cache_size() # per core cache size in Bytes + assert L2_cache_size > 0, f"Expect L2_cache_size > 0 but got {L2_cache_size}" + + 
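+ # For illustration (hypothetical fusion): with a bias and one user-supplied
+ # relu epilogue node, the chain from the NOTE above becomes
+ # <template>_GemmOut --(bias_add, in-template)--> <template buffer>
+ # --(relu, out-of-template)--> Y, and the intermediate buffer names are
+ # recorded in Y_aliases so they resolve to Y's storage in the generated kernel.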
options = dict( + X=X, + W=W, + inp=inp, + Y=Y, + N=self.n, + K=self.k, + PADDED_N=self.padded_n, + GemmOut=gemm_output_buffer, + aliases={alias: Y.get_name() for alias in Y_aliases}, + beta=self.beta, + alpha=self.alpha, + num_threads=self.num_threads, + micro_gemm=micro_gemm, + is_dynamic_M=self.is_dynamic_M, + template=self, + kernel=kernel, + export_declaration=get_export_declaration(), + epilogue_nodes=epilogues, + reindexers=reindexers, + Y_2d=Y_2d, + use_local_acc=use_local_acc, + maybe_k_slicing=self.maybe_k_slicing(), + x_scale=x_scale, + x_zp=x_zp, + w_scale=w_scale, + w_zp=w_zp, + acc_buf_dtype=torch.int32 if int8_gemm else torch.float, + DTYPE_TO_CPP=DTYPE_TO_CPP, + L1_cache_size=L1_cache_size, + L2_cache_size=L2_cache_size, + config=config, + fake_buffers=fake_buffers, + is_woq_int4=self.is_woq_int4(), + q_group_size=q_group_size_node, + qscale_and_zeros=qscale_and_zeros, + ) + return options + + def is_int8_woq_gemm_small_m_dim( + self, + X: ir.ReinterpretView, + W: ir.ReinterpretView, + N, + K, + micro_gemm, + ): + """Use SMALL_M_GEMM_TEMPLATE""" + return ( + isinstance(micro_gemm, CppMicroGemmFP32Vec) + and is_int8_woq_gemm_small_m_dim_corner_case( + micro_gemm, X.get_size()[0], N, K + ) + and X.get_dtype() is torch.bfloat16 + and W.get_dtype() is torch.int8 + ) + + def render( # type: ignore[override, return] + self, + kernel: CppTemplateKernel, + template_buffer_node: Optional[ir.CppTemplateBuffer] = None, + flag_template_buffer_has_other_users: Optional[bool] = None, + epilogue_nodes: Optional[list[ir.IRNode]] = None, + **kwargs, + ) -> str: + options = self.get_options( + kernel=kernel, + template_buffer_node=template_buffer_node, + flag_template_buffer_has_other_users=flag_template_buffer_has_other_users, + epilogue_nodes=epilogue_nodes, + ) + self.render_options = options + + with contextlib.ExitStack() as stack: + for buf in options["fake_buffers"]: + stack.enter_context( + patch.object(V.graph, "get_dtype", self._fake_get_dtype(buf)) + ) + if not options["is_dynamic_M"] and self.is_int8_woq_gemm_small_m_dim( + options["X"], + options["W"], + options["N"], + options["K"], + options["micro_gemm"], + ): + template_str = SMALL_M_GEMM_TEMPLATE + else: + template_str = GEMM_TEMPLATE + return self._template_from_string(template_str).render(**options) + + def codegen_blocks( + self, + num_threads, + N, + K, + micro_gemm, + is_dynamic_M, + kernel, + GemmOut, + config, + L1_cache_size, + L2_cache_size, + X, + W, + ): + options = dict( + num_threads=num_threads, + N=N, + K=K, + micro_gemm=micro_gemm, + is_dynamic_M=is_dynamic_M, + kernel=kernel, + GemmOut=GemmOut, + config=config, + L1_cache_size=L1_cache_size, + L2_cache_size=L2_cache_size, + template=self, + X=X, + W=W, + is_woq_int4=self.is_woq_int4(), + ) + template_str = GEMM_TEMPLATE_INIT_BLOCKING_BASIC_BLOCK + if not ( + not is_dynamic_M + and self.is_int8_woq_gemm_small_m_dim(X, W, N, K, micro_gemm) + ): + template_str += GEMM_TEMPLATE_INIT_BLOCKING_EXTENDED + return self._template_from_string(template_str).render(options) + + def codegen_microkernel_def(self): + return self._template_from_string(GEMM_TEMPLATE_MICROKERNEL_DEF).render( + self.render_options + ) + + def codegen_gemm_stub_def(self): + microkernel = self.codegen_microkernel_def() + return microkernel + self._template_from_string(GEMM_TEMPLATE_STUB_DEF).render( + self.render_options + ) + + def codegen_multi_threads_params(self): + return self._template_from_string(GEMM_TEMPLATE_MULTI_THREADS_PARAMS).render() + + def codegen_single_thread_params(self, 
is_dynamic_M): + options = dict( + is_dynamic_M=is_dynamic_M, + ) + return self._template_from_string(GEMM_TEMPLATE_SINGLE_THREAD_PARAMS).render( + options + ) + + def codegen_m_loop_params(self): + return self._template_from_string(GEMM_TEMPLATE_M_LOOP_PARAMS).render() + + def codegen_n_loop_params(self): + return self._template_from_string(GEMM_TEMPLATE_N_LOOP_PARAMS).render() + + @classmethod + def is_woq_int4(cls): + return False + + @classmethod + def q_group_size(cls): + return None + + +class CppWoqInt4GemmTemplateMeta(type): + def __getitem__(cls, q_group_size): + class CppWoqInt4GemmTemplateInstance(CppGemmTemplate): + def __init__( + self, + *args, + **kwargs, + ) -> None: + super().__init__( + *args, + **kwargs, + ) + + @classmethod + def is_woq_int4(cls): + return True + + @classmethod + def q_group_size(cls): + return q_group_size + + @staticmethod + def check_if_block_weight(W, micro_gemm): + # For WOQ INT4, weight is already packed + # However, for AMX microkernel, we want to change the blocking of weight + from .cpp_micro_gemm import CppMicroGemmWoQInt4Amx + + return isinstance(micro_gemm, CppMicroGemmWoQInt4Amx) + + @classmethod + def block_weight(cls, W, new_size, padding): + # This method is called only if AMX microkernels are used. + # In this case, we unpack and repack weight so that block_n=32 + # the format of packed weight is described here: + # https://github.com/pytorch/pytorch/blob/32eee8ed225d9f10fbbcb38c24b8b44c24c0c97c/aten/src/ATen/native/cpu/int4mm_kernel.cpp#L583 + if isinstance(W, ir.IRNode): + # in this case, we do nothing + ir.ExternKernel.require_contiguous(W) + blocked_w = W + else: + # in this case, we unpack and repack weight + assert isinstance(W, torch.Tensor) + assert W.dim() == 2 + N = W.size(0) + K = W.size(-1) * 2 + G = cls.q_group_size() + # x and qscales_and_zeros are in bfloat16 instead of float to use the optimized kernel + # so that the unpacking process is faster + x = torch.eye(K).bfloat16() + # Here we use scale=1 and qzero=8 because we want to unpack weight + # without dequantizing it. The qzero here is 8 instead of 0 because + # int4 values are converted to [-7, 8] in the _weight_int4pack_mm_for_cpu kernel: + # https://github.com/pytorch/pytorch/blob/32eee8ed225d9f10fbbcb38c24b8b44c24c0c97c/aten/src/ATen/native/cpu/int4mm_kernel.cpp#L95 + qscales_and_zeros = ( + torch.tensor([1.0, 8.0]) + .bfloat16() + .expand(K // G, N, 2) + .contiguous() + ) + # shape: [K, N] + unpacked_w = torch.ops.aten._weight_int4pack_mm_for_cpu( + x, + W, + G, + qscales_and_zeros, + ).to(torch.uint8) + block_n = 32 + # shape: [N // block_n, K, block_n] + w_blocked = ( + unpacked_w.view(K, N // block_n, block_n) + .permute(1, 0, 2) + .contiguous() + ) + # pack 2 int4 -> 1 int8 + # block_n: [a0, a1, ..., a15, b0, b1, ..., b15] + # -> [(a0 & 0xf) | (b0 << 4), (a1 & 0xf) | (b1 << 4), ...] 
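+ # e.g. with hypothetical nibble values a0 = 0x3 and b0 = 0xA, the packed
+ # byte is (0x3 & 0xf) | (0xA << 4) = 0xa3 (low nibble a0, high nibble b0)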
+ # shape: [N // block_n, K, 2, block_n // 2] + w_blocked = w_blocked.view(N // block_n, K, 2, block_n // 2) + # shape: [N // block_n, K, block_n // 2] + w_blocked_packed = (w_blocked[:, :, 0, :] & 0xF) | ( + w_blocked[:, :, 1, :] << 4 + ) + # shape: [N, K // 2] + blocked_w = w_blocked_packed.view(N, K // 2) + + return blocked_w + + return CppWoqInt4GemmTemplateInstance + + +class CppWoqInt4GemmTemplate(metaclass=CppWoqInt4GemmTemplateMeta): + pass diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_grouped_gemm_template.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_grouped_gemm_template.py new file mode 100644 index 0000000000000000000000000000000000000000..4b9735222275b801451a06aaef8d0ace71d4db09 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_grouped_gemm_template.py @@ -0,0 +1,500 @@ +import contextlib +import logging +from typing import Any, Callable, cast, Optional, TypeVar +from unittest.mock import patch + +import torch +import torch.utils +from torch.utils._ordered_set import OrderedSet + +from ..._dynamo.utils import counters +from .. import config, ir +from ..kernel.mm_common import mm_args +from ..select_algorithm import ChoiceCaller, DataProcessorTemplateWrapper +from ..utils import parallel_num_threads +from ..virtualized import V +from .cpp import get_export_declaration +from .cpp_gemm_template import ( + CppGemmTemplate, + expand_bias, + gen_2d_view_of_epilogue_buf, + prune_tensors, + transpose_w, +) +from .cpp_micro_gemm import CppMicroGemmAMX, create_micro_gemm +from .cpp_template_kernel import CppTemplateKernel +from .cpp_utils import ( + create_epilogue_with_attr, + DTYPE_TO_CPP, + GemmBlocking, + get_gemm_template_output_and_compute_dtype, +) + + +log = logging.getLogger(__name__) + +GEMM_TEMPLATE = r""" +{{template.header().getvalue()}} +{{micro_gemm.codegen_define(kernel)}} + +extern "C" {{export_declaration}} +{{kernel.def_kernel(inputs=kernel_args, outputs=Y_list, aliases=aliases)}} +{ + {{kernel.maybe_codegen_profile()}} + {{ template.codegen_blocks( + num_threads, N, K, micro_gemm, is_dynamic_M, kernel, GemmOuts[0], config, L1_cache_size, L2_cache_size, X_list[0], W_list[0] + ) }} +{%- if num_threads > 1 %} + #pragma omp parallel num_threads({{num_threads}}) + { + {{ template.codegen_multi_threads_params()|indent(8, false) }} +{%- else %} + { + {{ template.codegen_single_thread_params(is_dynamic_M)|indent(8, false) }} +{%- endif %} + {{ micro_gemm.codegen_init(kernel) }} +{%- set acc_buf_name_list=[] %} +{%- set acc_buf_name_prefix = "local_acc_buf_" %} +{%- for gemm_idx in range(0, gemm_grouped_num, 1) %} + {%- set acc_buf_name = acc_buf_name_prefix + gemm_idx|string %} + {{ kernel.define_buffer(acc_buf_name, ["Mc_blocks*Mr", "Nc_blocks*Nr"], acc_buf_dtype) }} + {%- set acc_buf_name_list=acc_buf_name_list.append(acc_buf_name) %} +{%- endfor %} + for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) { + {{ template.codegen_m_loop_params()|indent(12, false) }} + for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) { + {{ template.codegen_n_loop_params()|indent(16, false) }} +{%- set acc_list=[] %} +{%- for gemm_idx in range(0, gemm_grouped_num, 1) %} + {%- set acc_list = acc_list.append( kernel.local_buffers[acc_buf_name_list[gemm_idx]] ) %} + {{ kernel.reinit_buffer_if_null(acc_buf_name_list[gemm_idx]) }} +{%- endfor %} + for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) { + int64_t k_start = kc * Kr; + int64_t k_end = 
std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K); +{%- set tile_X_list=[] %} +{%- for gemm_idx in range(0, gemm_grouped_num, 1) %} + {%- set tile_X_list = tile_X_list.append( kernel.slice_nd(X_list[gemm_idx], [("m_start", "m_end"), ("k_start", "k_end")]) ) %} +{%- endfor %} + for (int64_t nci = nc; nci < nc_block_end; nci++) { +{%- set tile_W_3d_list=[] %} +{%- set tile_W_list=[] %} +{%- set acc_slice_list=[] %} +{%- for gemm_idx in range(0, gemm_grouped_num, 1) %} + {%- set acc_slice_list = acc_slice_list.append( + kernel.slice_nd(acc_list[gemm_idx], [("0", "m_end - m_start"), ("(nci - nc)*Nr", "(nci - nc + 1)*Nr")]) + ) %} + {%- set tile_W_3d_list = tile_W_3d_list.append( + kernel.slice_nd(W_list[gemm_idx], [("nci", "nci + 1"), ("k_start", "k_end"), ()]) + ) %} +{%- endfor %} +{%- for gemm_idx in range(0, gemm_grouped_num, 1) %} + {%- set tile_W_list = tile_W_list.append( + kernel.view(tile_W_3d_list[gemm_idx], ["k_end - k_start", micro_gemm.register_blocking.block_n]) + ) %} +{%- endfor %} + if (kc == k_block_start) { + {%- for gemm_idx in range(0, gemm_grouped_num, 1) %} + {{ micro_gemm.codegen_call( + kernel, tile_X_list[gemm_idx], tile_W_list[gemm_idx], acc_slice_list[gemm_idx], accum=False + )|indent(28, false) }} + {%- endfor %} + } else { + {%- for gemm_idx in range(0, gemm_grouped_num, 1) %} + {{ micro_gemm.codegen_call( + kernel, tile_X_list[gemm_idx], tile_W_list[gemm_idx], acc_slice_list[gemm_idx], accum=True + )|indent(28, false) }} + {%- endfor %} + } + } + } + { +{%- set tile_acc_list = [] %} +{%- set tile_Y_list = [] %} +{%- for gemm_idx in range(0, gemm_grouped_num, 1) %} + {%- set tile_acc_list = tile_acc_list.append( + kernel.slice_nd(acc_list[gemm_idx], [("0", "m_end - m_start"), ("0", "n_end - n_start")]) + ) %} + {%- set tile_Y_list = tile_Y_list.append( + kernel.slice_nd(Y_2d_list[gemm_idx], [("m_start", "m_end"), ("n_start", "n_end")]) + ) %} +{%- endfor %} + {{ kernel.store_outputs( + tile_Y_list, + tile_acc_list, + GemmOuts, + epilogue_nodes, + offsets=("m_start", "n_start"), + reindexers=reindexers, + multi_output_buffers=multi_output_buffers + )|indent(20, false) + }} + } + } + } + {{ micro_gemm.codegen_finalize(kernel) }} + } +} +""" + + +def get_deduplicated_act(act_mapping: dict[int, ir.IRNode]) -> list[ir.IRNode]: + act_deduplicated = [] + act_deduplicated_name: OrderedSet[str] = OrderedSet() + for act_idx in range(len(act_mapping.values())): + act = act_mapping[act_idx] + if act.get_name() not in act_deduplicated_name: + act_deduplicated.append(act) + act_deduplicated_name.add(act.get_name()) + return act_deduplicated + + +class CppGroupedGemmTemplate(CppGemmTemplate): + def __init__( + self, + input_nodes: list[ir.IRNode], + layout: ir.Layout, + num_threads: int, + register_blocking: GemmBlocking, + beta: int = 1, + alpha: int = 1, + has_bias: bool = False, + epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None, + act_mapping: Optional[dict[int, ir.IRNode]] = None, + gemm_grouped_num: int = 1, + ) -> None: + """ + Template for Group of GEMMs: + * Each GEMM has the same dimensions (m, n, k) and the same leading dimensions (lda, ldb, ldc) + for their A, B, and C matrices. + * Each GEMM has distinct or shared activations, has distinct weight, has unique bias or no bias, has distinct epilogues. + * In the current implementation, the outputs of all GEMMs are accumulated using pointwise epilogues. + This behavior can be extended in the future if needed. 
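+ * For illustration (a hypothetical grouping): two linear layers that consume the
+ same activation x of shape [m, k] and each produce n output features can be
+ grouped with gemm_grouped_num=2 and act_mapping={0: x, 1: x}; the template then
+ emits one kernel that computes both outputs (plus optional biases) within a
+ single traversal of the blocked m/n/k loops.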
+ """ + super().__init__( + input_nodes, + layout, + num_threads, + register_blocking, + beta, + alpha, + has_bias, + epilogue_creator, + ) + self.act_mapping = act_mapping + self.gemm_grouped_num = gemm_grouped_num + self.output_node: list[ir.Buffer] = [ + ir.Buffer(name="buf_out" + str(idx), layout=layout) + for idx in range(gemm_grouped_num) + ] + + @classmethod + def add_choices( + cls, + choices: list[ChoiceCaller], + layout: ir.Layout, + input_nodes: list[ir.IRNode], + beta: int = 1, + alpha: int = 1, + has_bias: tuple[bool, ...] = (False, False), + trans_w: bool = False, + input_indices: Optional[list[int]] = None, + epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None, + act_mapping: Optional[dict[int, ir.IRNode]] = None, # gemm idx to its act buf + ) -> DataProcessorTemplateWrapper: + # Input nodes order: x, optional[x1], ... w0, w1, ... optional[b0], optional[b1], ... + gemm_grouped_num = len(has_bias) + assert act_mapping + act_deduplicated = get_deduplicated_act(act_mapping) + wgt_start_idx = len(act_deduplicated) + bias_start_idx = wgt_start_idx + gemm_grouped_num + input_indices = list(range(len(input_nodes))) + + _T = TypeVar("_T", ir.IRNode, torch.Tensor) + _U = TypeVar("_U", ir.Layout, torch.Tensor) + + def reorder_and_filter( + inputs: list[_T], + layout_or_out: _U, + ) -> tuple[list[_T], _U]: + assert input_indices is not None, "input_indices must be set" + return [inputs[idx] for idx in input_indices], layout_or_out + + new_inputs, new_layout = reorder_and_filter(input_nodes, layout) + + def maybe_to_dense( + inputs: list[_T], + layout_or_out: _U, + ) -> tuple[list[_T], _U]: + new_inputs = list(inputs) + for idx in range(wgt_start_idx, wgt_start_idx + gemm_grouped_num): + if isinstance(inputs[idx], torch.Tensor): + W = inputs[idx] + assert isinstance(W, torch.Tensor), "W must be a torch.Tensor" + new_inputs[idx] = W.to_dense() if W.is_mkldnn else W + return new_inputs, layout_or_out + + def normalize_shapes( + inputs: list[_T], + layout_or_out: _U, + ) -> tuple[list[_T], _U]: + new_inputs: list[_T] = list(inputs) + if not trans_w: + return new_inputs, layout_or_out + X = new_inputs[0] + for wgt_idx in range(wgt_start_idx, wgt_start_idx + gemm_grouped_num): + new_input = new_inputs[wgt_idx] + new_inputs[wgt_idx] = transpose_w(new_input, trans_w) + for bias_idx in range(bias_start_idx, len(new_inputs)): + new_bias = expand_bias(new_inputs[bias_idx], X) + assert new_bias is not None + new_inputs[bias_idx] = new_bias + return new_inputs, layout_or_out + + num_threads = parallel_num_threads() + new_inputs, _ = normalize_shapes(*maybe_to_dense(new_inputs, new_layout)) + m, n, k, *_ = mm_args(new_inputs[0], new_inputs[wgt_start_idx]) + output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype( + new_inputs[0].get_dtype() + ) + micro_gemm = create_micro_gemm( + "micro_gemm", + m, + n, + k, + input_dtype=new_inputs[0].get_dtype(), + input2_dtype=new_inputs[wgt_start_idx].get_dtype(), + output_dtype=output_dtype, + compute_dtype=compute_dtype, + alpha=alpha, + num_threads=num_threads, + ) + assert micro_gemm is not None + _, block_n, _ = micro_gemm.register_blocking + new_size, padded_n = cls.get_padded_size( + n, block_n, k, should_block_weight=True + ) + padding = padded_n - n + + def pack_weight( + inputs: list[_T], + layout_or_out: _U, + ) -> tuple[list[_T], _U]: + new_W_list = [] + new_inputs = list(inputs) + W_list = new_inputs[wgt_start_idx : wgt_start_idx + gemm_grouped_num] + for W in W_list: + blocked_w = cls.block_weight(W, new_size, 
padding) + new_W_list.append(cls.pack_vnni_weight(blocked_w, micro_gemm, new_size)) + new_inputs[wgt_start_idx : wgt_start_idx + gemm_grouped_num] = new_W_list + return new_inputs, layout_or_out + + def preprocessor( + inputs: list[_T], + layout: _U, + ) -> tuple[list[_T], _U]: + return pack_weight( + *normalize_shapes(*maybe_to_dense(*reorder_and_filter(inputs, layout))) + ) + + def postprocessor(output: _T) -> _T: + if isinstance(output, ir.TensorBox): + template_buffer = ir.InputsKernel.unwrap_storage_for_input(output) + assert isinstance(template_buffer, ir.CppTemplateBuffer) + new_input_nodes, _ = reorder_and_filter(input_nodes, layout) + W_nodes = new_input_nodes[ + wgt_start_idx : wgt_start_idx + gemm_grouped_num + ] + W_tensor = [] + for W_node in W_nodes: + assert W_node.get_name() in V.graph.constants + W_tensor.append(V.graph.constants[W_node.get_name()]) + new_input_nodes[wgt_start_idx : wgt_start_idx + gemm_grouped_num] = ( + W_tensor # type: ignore[assignment] + ) + new_input_nodes, _ = pack_weight( + *normalize_shapes(*maybe_to_dense(new_input_nodes, layout)) + ) + # Prune unused tensors + prune_tensors(input_nodes, new_input_nodes) + for idx in range(wgt_start_idx, wgt_start_idx + gemm_grouped_num): + W_packed = new_input_nodes[idx] + assert isinstance(W_packed, torch.Tensor) + W_packed_constant = V.graph.add_tensor_constant(W_packed) + template_buffer.inputs[idx] = ( + ir.InputsKernel.unwrap_storage_for_input(W_packed_constant) + ) + return output + + template = DataProcessorTemplateWrapper( + CppGroupedGemmTemplate, + preprocessor, + postprocessor, + input_nodes=input_nodes, + layout=layout, + num_threads=num_threads, + register_blocking=micro_gemm.register_blocking, + beta=beta, + alpha=alpha, + has_bias=has_bias, + epilogue_creator=epilogue_creator, + act_mapping=act_mapping, + gemm_grouped_num=gemm_grouped_num, + ) + template.maybe_append_choice(choices) + return template + + def render( # type: ignore[override,return,no-untyped-def] + self, + kernel: CppTemplateKernel, + template_buffer_node: Optional[ir.CppTemplateBuffer] = None, + flag_template_buffer_has_other_users: Optional[bool] = None, + epilogue_nodes: Optional[list[ir.IRNode]] = None, + **kwargs, + ) -> str: + assert self.act_mapping + act_deduplicated = get_deduplicated_act(self.act_mapping) + wgt_start_idx = len(act_deduplicated) + bias_start_idx = wgt_start_idx + self.gemm_grouped_num + X_list = list(self.act_mapping.values()) + W_list = self.input_nodes[wgt_start_idx : wgt_start_idx + self.gemm_grouped_num] + inp_list = [] + cur_idx = bias_start_idx + for inp_idx in range(self.gemm_grouped_num): + inp = None + if self.has_bias[inp_idx]: + inp = self.input_nodes[cur_idx] + cur_idx += 1 + inp_list.append(inp) + + Y_list = self.output_node + multi_output_buffers = None + if template_buffer_node is not None: + W_list = template_buffer_node.inputs[ + wgt_start_idx : wgt_start_idx + self.gemm_grouped_num + ] + assert isinstance(template_buffer_node.outputs, list) + Y_list = template_buffer_node.outputs + counters["inductor"]["cpp_grouped_gemm_template"] += 1 + multi_output_buffers = template_buffer_node.outputs + + template_buffer = Y_list[0] + fake_buffers: list[ir.Buffer] = [] + Y_2d_list = Y_list + output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype( + X_list[0].get_dtype() + ) + micro_gemm = create_micro_gemm( + f"{kernel.kernel_name}_micro_gemm", + self.m, + self.n, + self.k, + input_dtype=X_list[0].get_dtype(), + input2_dtype=W_list[0].get_dtype(), + output_dtype=output_dtype, + 
compute_dtype=compute_dtype, + alpha=self.alpha, + num_threads=self.num_threads, + ) + assert micro_gemm is not None + assert self.register_blocking == micro_gemm.register_blocking + self.log_blockings() + if isinstance(micro_gemm, CppMicroGemmAMX): + counters["inductor"]["cpp_micro_gemm_amx_counter"] += 1 + + L1_cache_size = torch._C._cpu._L1d_cache_size() # per core cache size in Bytes + assert L1_cache_size > 0, f"Expect L1_cache_size > 0 but got {L1_cache_size}" + + L2_cache_size = torch._C._cpu._L2_cache_size() # per core cache size in Bytes + assert L2_cache_size > 0, f"Expect L2_cache_size > 0 but got {L2_cache_size}" + + epilogues: list[ir.IRNode] = [] + reindexers: list[Optional[Callable[[list[Any]], list[Any]]]] = [] + gemm_output_buffers: list[ir.Buffer] = [] + for out_buf_idx in range(self.gemm_grouped_num): + gemm_output_name = f"{template_buffer.get_name()}_GemmOut" + str( + out_buf_idx + ) + gemm_output_buffers.append( + ir.Buffer(name=gemm_output_name, layout=template_buffer.layout) + ) + + assert not self.epilogue_creator, ( + "epilogue_creator is not supported yet in Grouped GEMM Template" + ) + + kernel_args: dict[str, Optional[ir.IRNode]] = {} + for x_idx in range(wgt_start_idx): + kernel_args["X" + str(x_idx)] = act_deduplicated[x_idx] + for w_idx in range(self.gemm_grouped_num): + kernel_args["W" + str(w_idx)] = W_list[w_idx] + for inp_idx in range(self.gemm_grouped_num): + kernel_args["inp" + str(inp_idx)] = inp_list[inp_idx] + + def _bias_add_epilogue(buf: ir.IRNode, inp: ir.IRNode) -> ir.Pointwise: + return create_epilogue_with_attr( + buf, "bias_add", other=inp, beta=self.beta, dtype=self.layout.dtype + ) + + for gemm_idx, inp in enumerate(inp_list): + if inp: + buffer_name = Y_list[gemm_idx].get_name() + epilogues.append( + ir.ComputedBuffer( + name=buffer_name, + layout=template_buffer.layout, + data=_bias_add_epilogue(gemm_output_buffers[gemm_idx], inp), + ) + ) + reindexers.append(None) + + if epilogue_nodes: + epilogues.extend(epilogue_nodes) + for epilogue_node in epilogue_nodes: + Y = cast(ir.Buffer, epilogue_node) + _, reindexers = gen_2d_view_of_epilogue_buf( + Y, + template_buffer, + [ + epilogue_node, + ], + reindexers, + default_reindexers=[ + None, + ], + ) + + options = dict( + N=self.n, + K=self.k, + PADDED_N=self.padded_n, + aliases={}, + beta=self.beta, + alpha=self.alpha, + num_threads=self.num_threads, + micro_gemm=micro_gemm, + is_dynamic_M=self.is_dynamic_M, + template=self, + kernel=kernel, + export_declaration=get_export_declaration(), + acc_buf_dtype=torch.float, + DTYPE_TO_CPP=DTYPE_TO_CPP, + L1_cache_size=L1_cache_size, + L2_cache_size=L2_cache_size, + config=config, + epilogue_nodes=epilogues, + GemmOuts=gemm_output_buffers, + reindexers=reindexers, + kernel_args=kernel_args, + X_list=X_list, + W_list=W_list, + gemm_grouped_num=self.gemm_grouped_num, + Y_list={"Y" + str(idx): Y for idx, Y in enumerate(Y_list)}, + Y_2d_list=Y_2d_list, + multi_output_buffers=multi_output_buffers, + ) + with contextlib.ExitStack() as stack: + stack.enter_context( + patch.object(V.graph, "get_dtype", self._fake_get_dtype(fake_buffers)) + ) + return self._template_from_string(GEMM_TEMPLATE).render(**options) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_micro_gemm.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_micro_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..c9c54553756fdf9ef0e4554d4b718a4aa5009cca --- /dev/null +++ 
b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_micro_gemm.py @@ -0,0 +1,2011 @@ +# mypy: allow-untyped-defs +import dataclasses +import operator +import sys +from enum import Enum +from typing import Callable, Optional + +import torch + +from .. import cpp_builder, ir +from ..cpu_vec_isa import ( + pick_vec_isa, + VecAMX, + VecAVX2, + VecAVX512, + VecISA, + VecNEON, + VecSVE256, +) +from ..utils import IndentedBuffer, parallel_num_threads +from ..virtualized import V +from .common import KernelTemplate +from .cpp_template_kernel import CppTemplateKernel +from .cpp_utils import DTYPE_TO_CPP, GemmBlocking, value_to_cpp + + +class LayoutType(Enum): + NORMAL = 0 + VNNI2 = 1 + VNNI4 = 2 + + +_IS_WINDOWS = sys.platform == "win32" + + +def get_restrict_keyword() -> str: + if _IS_WINDOWS: + # https://learn.microsoft.com/en-us/cpp/cpp/extension-restrict?view=msvc-170 + return "__restrict" + else: + return "__restrict__" + + +class CppMicroGemm: + """ + A class that codegens a kernel that computes small-sized matrix multiplication. + + A micro GEMM kernel is responsible for register blocking, instruction selection, + and other CPU architecture-specific optimizations. + + The subclasses need to override `codegen_define` to define the kernel function + that is called by the code generated by `codegen_call`. + """ + + # TODO(jgong5): support constant shapes and lds as template args. + DECLARE_KERNEL = r""" +template +inline void {{kernel_name}}( +{%- if kernel_extra_args_declare %} + {{kernel_extra_args_declare}} +{%- endif %} + const {{input_t}}* {{restrict_keyword}} A, + const {{input2_t}}* {{restrict_keyword}} B, + {{output_t}}* {{restrict_keyword}} C, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc +) +""" + + def __init__( + self, + name, + input_dtype, + input2_dtype, + output_dtype, + compute_dtype, + register_blocking, + alpha=1, + ) -> None: + self.name = name + self.input_dtype = input_dtype + assert input2_dtype is not None + self.input2_dtype = input2_dtype + self.output_dtype = output_dtype + self.compute_dtype = compute_dtype + self.register_blocking = register_blocking + self.alpha = alpha + self.pack_vnni_B_locally = False + + def get_common_options(self): + if self.input_dtype in [torch.uint8, torch.int8]: + assert self.compute_dtype == torch.int32 + assert self.output_dtype == torch.int32 + assert self.input2_dtype == torch.int8 + return { + "torch": torch, + "kernel_name": self.name, + "input_dtype": self.input_dtype, + "input2_dtype": self.input2_dtype, + "output_dtype": self.output_dtype, + "compute_dtype": self.compute_dtype, + "input_t": DTYPE_TO_CPP[self.input_dtype], + "input2_t": DTYPE_TO_CPP[self.input2_dtype], + "output_t": DTYPE_TO_CPP[self.output_dtype], + "compute_t": DTYPE_TO_CPP[self.compute_dtype], + "alpha": self.alpha, + "kernel_extra_args_declare": self.get_kernel_extra_args_declare(), + "int8_gemm": self.input_dtype in [torch.uint8, torch.int8], + "vnni_size": 4 if self.input_dtype in [torch.uint8, torch.int8] else 2, + "restrict_keyword": get_restrict_keyword(), + "pack_vnni_B_locally": self.pack_vnni_B_locally, + "template": self, + "is_woq_int4": self.is_woq_int4(), + } + + def get_kernel_declaration(self): + options = self.get_common_options() + return KernelTemplate._template_from_string(self.DECLARE_KERNEL).render(options) + + def get_kernel_extra_args_declare(self) -> str: + return "" + + def get_kernel_extra_args(self, **kwargs) -> list[str]: + return [] + + def codegen_define(self, kernel: 
CppTemplateKernel) -> str: + raise NotImplementedError + + def codegen_call( + self, + kernel: CppTemplateKernel, + A: ir.Buffer, + B: ir.Buffer, + C: ir.Buffer, + accum: bool, + prefetch: bool = False, + **kwargs_for_extra_args, + ) -> str: + """ + Generate the code for calling the templated kernel that computes + `C += alpha * A @ B` if `accum` is True, or `C = alpha * A @ B` otherwise. + """ + A_ptr = f"&({kernel.index(A, [0, 0])})" + B_ptr = f"&({kernel.index(B, [0, 0])})" + C_ptr = f"&({kernel.index(C, [0, 0])})" + M = kernel.size(C, 0) + N = kernel.size(C, 1) + K = kernel.size(A, 1) + lda = kernel.stride(A, 0) + ldb = kernel.stride(B, 0) + ldc = kernel.stride(C, 0) + res = IndentedBuffer() + res.writeline( + f"{self.name}<{value_to_cpp(accum, 'bool')}, {value_to_cpp(prefetch, 'bool')}>(" + ) + with res.indent(): + kwargs_for_extra_args.update({"kernel": kernel}) + extra_args = self.get_kernel_extra_args(**kwargs_for_extra_args) + for arg in extra_args: + res.writeline(arg) + res.writeline(f"{A_ptr},") + res.writeline(f"{B_ptr},") + res.writeline(f"{C_ptr},") + res.writeline(f"{M},") + res.writeline(f"{N},") + res.writeline(f"{K},") + res.writeline(f"{lda},") + res.writeline(f"{ldb},") + res.writeline(f"{ldc}") + res.writeline(");") + return res.getvalue() + + def use_local_vnni_blocking(self, should_block_weight: bool): + self.pack_vnni_B_locally = should_block_weight + + def codegen_init( + self, + kernel: CppTemplateKernel, + ) -> str: + return "" + + def codegen_finalize( + self, + kernel: CppTemplateKernel, + ) -> str: + return "" + + def get_b_layout(self) -> LayoutType: + return LayoutType.NORMAL + + ALLOCATE_WEIGHT_BUFFER = r""" + {%- if is_msvc_compiler %} + // MSVC doesn't support stack-allocated dynamic-sized arrays, so using heap memory here. + std::unique_ptr<{{buffer_dtype}}[]> heap_deq_b_buf_ptr(new {{buffer_dtype}}[{{buffer_size}}]); + {{buffer_dtype}}* {{buffer_name}} = heap_deq_b_buf_ptr.get(); + {%- else %} + // It's safe to use a stack-allocated array since the blocking strategy would + // require us to allocate an array that's smaller than the size of L1D cache, + // and the default per thread max stack size on Linux is quite higher, + // so we need not worry about stack overflow. 
+ alignas(4096) {{buffer_dtype}} {{buffer_name}}[{{buffer_size}}]; + {%- endif %} +""" + + def codegen_allocate_weight_buffer( + self, buffer_name: str, buffer_dtype: str, *size_args + ) -> str: + buffer_size = " * ".join(map(str, size_args)) + return KernelTemplate._template_from_string(self.ALLOCATE_WEIGHT_BUFFER).render( + dict( + buffer_name=buffer_name, + buffer_dtype=buffer_dtype, + buffer_size=buffer_size, + is_msvc_compiler=cpp_builder.is_msvc_cl(), + ) + ) + + def is_woq_int4(self): + return False + + +@dataclasses.dataclass +class CppMicroGemmConfig: + input_dtype: torch.dtype + input2_dtype: torch.dtype + output_dtype: torch.dtype + compute_dtype: torch.dtype + vec_isa_cls: type[VecISA] + register_blocking: GemmBlocking + extra_check: Optional[Callable[..., bool]] = None + + +micro_gemm_configs: dict[type[CppMicroGemm], list[CppMicroGemmConfig]] = {} + + +def register_micro_gemm(*configs): + def inner(cls): + assert cls not in micro_gemm_configs, ( + f"Duplicate micro_gemm registration for {cls}" + ) + assert len(configs) > 0, f"No micro_gemm configs provided for {cls}" + micro_gemm_configs[cls] = list(configs) + return cls + + return inner + + +def generate_gemm_config( + vec_isa_cls, + register_blockings, + input_dtype=torch.float, + input2_dtype=None, + output_dtype=None, + compute_dtype=None, + extra_check=None, +): + if output_dtype is None: + output_dtype = input_dtype + if compute_dtype is None: + compute_dtype = output_dtype + if input2_dtype is None: + input2_dtype = input_dtype + return [ + CppMicroGemmConfig( + input_dtype, + input2_dtype, + output_dtype, + compute_dtype, + vec_isa_cls, + GemmBlocking(*blocking), + extra_check, + ) + for blocking in register_blockings + ] + + +class CppMicroGemmRef(CppMicroGemm): + """ + A reference implementation of the CppMicroGemm class with naive C++ code. + It is used for correctness debugging. + """ + + TEMPLATE_ENTRY = r""" +{{declare_kernel}} { + for (int64_t m = 0; m < M; ++m) { + for (int64_t n = 0; n < N; ++n) { + {{compute_t}} result = accum ? 
C[m * ldc + n] : 0; + for (int64_t k = 0; k < K; ++k) { + result += ({{compute_t}})A[m * lda + k] * ({{compute_t}})B[k * ldb + n] * {{alpha}}; + } + C[m * ldc + n] = result; + } + } +} +""" + + def __init__( + self, name, input_dtype, input2_dtype, output_dtype, compute_dtype, alpha + ) -> None: + super().__init__( + name, + input_dtype, + input2_dtype, + output_dtype, + compute_dtype, + GemmBlocking(1, 1, 1), + alpha, + ) + + def codegen_define(self, kernel: CppTemplateKernel) -> str: + options = { + "declare_kernel": self.get_kernel_declaration(), + **self.get_common_options(), + } + return KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render(options) + + +def is_int8_woq_gemm_small_m_dim_corner_case(config, m, n, k): + return ( + k % config.register_blocking.block_k == 0 + and n % config.register_blocking.block_n == 0 + and m < 16 + ) + + +# extra check for small M dimension for int8 WoQ case +def check_int8_woq_small_m_dim(config, m, n, k, alpha, num_threads, **kwargs): + return is_int8_woq_gemm_small_m_dim_corner_case(config, m, n, k) and not kwargs.get( + "dynamic_M", False + ) + + +# For int8 WoQ GEMM with small M, we use different blockings that shouldn't be used otherwise +def do_not_use_with_small_m_for_int8_woq(config, m, n, k, alpha, num_threads, **kwargs): + return not check_int8_woq_small_m_dim(config, m, n, k, alpha, num_threads, **kwargs) + + +@register_micro_gemm( + *generate_gemm_config( + VecAVX512, + [(8, 48, 1), (8, 32, 1), (16, 16, 1)], + input_dtype=torch.float, + ), + *generate_gemm_config( + VecAVX512, + [(8, 48, 1), (8, 32, 1), (16, 16, 1)], + input_dtype=torch.bfloat16, + output_dtype=torch.float, + ), + *generate_gemm_config( + VecAVX512, + [(8, 48, 1), (8, 32, 1), (16, 16, 1)], + input_dtype=torch.half, + output_dtype=torch.float, + ), + *generate_gemm_config( + VecAVX512, + [(8, 48, 1), (8, 32, 1), (16, 16, 1)], + input_dtype=torch.bfloat16, + input2_dtype=torch.int8, + output_dtype=torch.float, + compute_dtype=torch.float, + extra_check=do_not_use_with_small_m_for_int8_woq, + ), + *generate_gemm_config( + VecAVX512, + [ + (4, 32, 64), + (8, 32, 64), + ], + input_dtype=torch.bfloat16, + input2_dtype=torch.int8, + output_dtype=torch.float, + compute_dtype=torch.float, + extra_check=check_int8_woq_small_m_dim, + ), + *generate_gemm_config( + VecAVX2, + [(4, 24, 1), (4, 16, 1), (8, 8, 1)], + input_dtype=torch.float, + ), + *generate_gemm_config( + VecAVX2, + [(4, 24, 1), (4, 16, 1), (8, 8, 1)], + input_dtype=torch.bfloat16, + output_dtype=torch.float, + ), + *generate_gemm_config( + VecAVX2, + [(4, 24, 1), (4, 16, 1), (8, 8, 1)], + input_dtype=torch.half, + output_dtype=torch.float, + ), + *generate_gemm_config( + VecAVX2, + [(4, 24, 1), (4, 16, 1), (8, 8, 1)], + input_dtype=torch.bfloat16, + input2_dtype=torch.int8, + output_dtype=torch.float, + compute_dtype=torch.float, + extra_check=do_not_use_with_small_m_for_int8_woq, + ), + *generate_gemm_config( + VecAVX2, + [ + (2, 16, 64), + (4, 16, 64), + ], + input_dtype=torch.bfloat16, + input2_dtype=torch.int8, + output_dtype=torch.float, + compute_dtype=torch.float, + extra_check=check_int8_woq_small_m_dim, + ), + *generate_gemm_config( + VecNEON, + [(4, 24, 1), (4, 16, 1), (8, 8, 1)], + input_dtype=torch.float, + input2_dtype=torch.float, + output_dtype=torch.float, + compute_dtype=torch.float, + ), + *generate_gemm_config( + VecSVE256, + [(4, 24, 1), (4, 16, 1), (8, 8, 1)], + input_dtype=torch.float, + input2_dtype=torch.float, + output_dtype=torch.float, + compute_dtype=torch.float, + ), +) +class 
CppMicroGemmFP32Vec(CppMicroGemm): + """ + This class generates the code for micro gemm using fp32 vec instructions for compute. + It supports input types of torch.float, torch.bfloat16, and torch.half with fp32 output. + The output of the microkernel is in FP32, but it would be converted to BF16/FP16 in the template, + if the desired output is BF16/FP16. + """ + + TEMPLATE_ENTRY = r""" +{{declare_kernel}} { + using Vectorized = at::vec::Vectorized<{{compute_t}}>; + constexpr auto VLEN = Vectorized::size(); + {{kernel.assert_function}}({{block_n}} % VLEN == 0, "block_n dimension must be multiple of Vector size"); + {{kernel.assert_function}}(K % {{block_k}} == 0, "K dimension must be multiple of {{block_k}}"); + // TODO(jgong5): loop unroll for M and N + for (int64_t m = 0; m < M; m += {{block_m}}) { + int64_t block_m = std::min(M - m, {{block_m}}); + for (int64_t n = 0; n < N; n += {{block_n}}) { + int64_t block_n = std::min(N - n, {{block_n}}); + if (block_m == {{block_m}} && block_n == {{block_n}}) { +{%- if not trans_b %} + {{kernel_name}}_kernel<{{block_m}}, {{block_n}}, accum, prefetch>( +{%- else %} + {{kernel_name}}_transpose_b_kernel<{{block_m}}, {{block_n}}, accum, prefetch>( +{%- endif %} + A + m * lda, +{%- if not trans_b %} + B + n, +{%- else %} + B + n * ldb, +{%- endif %} + C + m * ldc + n, + K, + lda, + ldb, + ldc + ); +{%- if tail_n %} + } else if (block_n == {{block_n}}){ +{%- else %} + } else { +{%- endif %} + switch (block_m) { +{%- for b in range(block_m - 1, 0, -1) %} + case {{b}}: + {%- if not trans_b %} + {{kernel_name}}_kernel<{{b}}, {{block_n}}, accum, prefetch>( + {%- else %} + {{kernel_name}}_transpose_b_kernel<{{b}}, {{block_n}}, accum, prefetch>( + {%- endif %} + A + m * lda, + {%- if not trans_b %} + B + n, + {%- else %} + B + n * ldb, + {%- endif %} + C + m * ldc + n, + K, + lda, + ldb, + ldc + ); + break; +{%- endfor %} + default: + {{kernel.assert_function}}(false, "Unsupported block_m: {{block_m}}"); + } + +{%- if tail_n %} + } else { + switch (block_m) { + {%- for b in range(block_m, 0, -1) %} + case {{b}}: + {%- if not trans_b %} + {{kernel_name}}_ntail_kernel<{{b}}, {{block_n}}, accum, prefetch>( + {%- else %} + {{kernel_name}}_ntail_transpose_b_kernel<{{b}}, {{block_n}}, accum, prefetch>( + {%- endif %} + A + m * lda, + {%- if not trans_b %} + B + n, + {%- else %} + B + n * ldb, + {%- endif %} + C + m * ldc + n, + block_n, + K, + lda, + ldb, + ldc + ); + break; + {%- endfor %} + default: + {{kernel.assert_function}}(false, "Unsupported block_m: {{block_m}}"); + } + } +{%- else %} + } +{%- endif %} + } + } +} +""" + + TEMPLATE_KERNEL = r""" + +template +{%- if not trans_b %} + {%- if tail_n %} +inline void {{kernel_name}}_ntail_kernel( + {%- else %} +inline void {{kernel_name}}_kernel( + {%- endif %} +{%- else %} + {%- if tail_n %} +inline void {{kernel_name}}_ntail_transpose_b_kernel( + {%- else %} +inline void {{kernel_name}}_transpose_b_kernel( + {%- endif %} +{%- endif %} + const {{input_t}}* {{restrict_keyword}} A, + const {{input2_t}}* {{restrict_keyword}} B, + {{output_t}}* {{restrict_keyword}} C, +{%- if tail_n %} + int64_t N, +{%- endif %} + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc +) { + using Vectorized = at::vec::Vectorized<{{compute_t}}>; +{%- if input2_dtype in [torch.bfloat16, torch.float16] %} + using VectorizedIn = at::vec::Vectorized<{{input_t}}>; +{%- endif %} + +{%- if not trans_b %} + constexpr auto VLEN = Vectorized::size(); + constexpr auto ROWS = BLOCK_M; + constexpr auto COLS = BLOCK_N / VLEN; + + Vectorized va; + 
at::vec::VectorizedN<{{compute_t}}, COLS> vb; + at::vec::VectorizedN<{{compute_t}}, ROWS*COLS> vc; + + {%- if tail_n %} + int64_t rCOLS = (N + VLEN - 1) / VLEN; + int ntail = N % VLEN; + {%- endif %} + auto loadc = [&](auto i) { + if constexpr (accum) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + {%- if tail_n %} + int load_size = (col == rCOLS - 1 && ntail != 0) ? ntail : VLEN; + if (col < rCOLS) { + vc[i] = Vectorized::loadu(C + row * ldc + col * VLEN, load_size); + } + {%- else %} + vc[i] = Vectorized::loadu(C + row * ldc + col * VLEN); + {%- endif %} + } else { + vc[i] = Vectorized(0.0f); + } + }; + c10::ForcedUnroll{}(loadc); + + auto compute = [&, COLS](auto i, int k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + {%- if tail_n %} + int load_size = (col == rCOLS - 1 && ntail != 0) ? ntail : VLEN; + {%- endif %} + if constexpr (col == 0) { + {%- if alpha != 1 %} + va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + k]) * {{alpha}}); + {%- else %} + va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + k])); + {%- endif %} + } + + if constexpr (row == 0) { + {%- if tail_n %} + if (col < rCOLS) { + {%- if input2_dtype in [torch.bfloat16, torch.float16] %} + auto b = VectorizedIn::loadu(B + k * ldb + col * VLEN, load_size); + vb[col] = at::vec::convert<{{compute_t}}>(b); + {%- elif input2_dtype == torch.int8 %} + // Convert VLEN int8 elements to int32, and then fp32 + auto b32 = at::vec::convert_to_int32(B + k * ldb + col * VLEN, load_size); + vb[col] = at::vec::convert(b32); + {%- else %} + vb[col] = Vectorized::loadu(B + k * ldb + col * VLEN, load_size); + {%- endif %} + } else { + vb[col] = Vectorized(0.0f); + } + + {%- else %} + + {%- if input2_dtype in [torch.bfloat16, torch.float16] %} + auto b = VectorizedIn::loadu(B + k * ldb + col * VLEN, VLEN); + vb[col] = at::vec::convert<{{compute_t}}>(b); + {%- elif input2_dtype == torch.int8 %} + // Convert VLEN int8 elements to int32, and then fp32 + auto b32 = at::vec::convert_to_int32(B + k * ldb + col * VLEN); + if constexpr (prefetch) { + _mm_prefetch(B + (k + {{block_k}}) * ldb + col * VLEN, _MM_HINT_T0); + } + vb[col] = at::vec::convert(b32); + {%- else %} + vb[col] = Vectorized::loadu(B + k * ldb + col * VLEN); + {%- endif %} + {%- endif %} + + } + + constexpr int idx = row * COLS + col; + {%- if tail_n %} + if (col < rCOLS) { + vc[idx] = at::vec::fmadd(va, vb[col], vc[idx]); + } + {%- else %} + vc[idx] = at::vec::fmadd(va, vb[col], vc[idx]); + {%- endif %} + }; + + for (int k = 0; k < K; ++k) { + c10::ForcedUnroll{}(compute, k); + } + + // store to C + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + {%- if tail_n %} + int store_size = (col == rCOLS - 1 && ntail != 0) ? ntail : VLEN; + if (col < rCOLS) { + vc[i].store(C + row * ldc + col * VLEN, store_size); + } + {%- else %} + vc[i].store(C + row * ldc + col * VLEN); + {%- endif %} + }; + c10::ForcedUnroll{}(storec); + +{%- else %} + // Use 2 implementations for the transposed B: + // First implementation: + // Transpose first and then perform outer product calculation in sub-blocks, + // which introduces an additional transpose overhead of [K, N] compared to the non-transpose version. + // Second implementation: + // Directly perform inner product calculation in sub-blocks, + // which introduces an additional vector reduction of [M, N] compared to the non-tranpose version. + // Therefore, when M * N / (K * N) is large, the first implementation has better performance. 
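+    // Concretely, the dispatch below takes the first (transpose-then-outer-product) path only when K is a multiple of the vector length (and N is too, in the tail-of-N variant) and 24 * BLOCK_M > K, i.e. when K is small relative to the M block; otherwise it falls back to the inner-product path.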
+ {%- if tail_n %} + if (K % Vectorized::size() == 0 && N % Vectorized::size() == 0 && 24 * BLOCK_M > K) { + {%- else %} + if (K % Vectorized::size() == 0 && 24 * BLOCK_M > K) { + {%- endif %} + // First implementation: + constexpr auto VLEN = Vectorized::size(); + constexpr auto ROWS = BLOCK_M; + constexpr auto COLS = BLOCK_N / VLEN; + int _K = K / VLEN; + Vectorized va; + at::vec::VectorizedN<{{compute_t}}, VLEN> vb; + at::vec::VectorizedN<{{compute_t}}, ROWS*COLS> vc; + auto loadc = [&](auto i) { + if constexpr (accum) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + vc[i] = Vectorized::loadu(C + row * ldc + col * VLEN); + } else { + vc[i] = Vectorized(0.0f); + } + }; + c10::ForcedUnroll{}(loadc); + auto unroll_loadB = [&](auto i, const {{input2_t}}* {{restrict_keyword}} src_ptr) { + {%- if input2_dtype in [torch.bfloat16, torch.float16] %} + auto b = VectorizedIn::loadu(src_ptr + i * ldb, VLEN); + vb[i] = at::vec::convert<{{compute_t}}>(b); + {%- elif input2_dtype == torch.int8 %} + auto b32 = at::vec::convert_to_int32(src_ptr + i * ldb, VLEN); + vb[i] = at::vec::convert(b32); + {%- else %} + vb[i] = Vectorized::loadu(src_ptr + i * ldb, VLEN); + {%- endif %} + }; + auto compute_trans = [&, COLS](auto i, int k) { + constexpr int row = i % ROWS; + constexpr int col = i / ROWS; + constexpr int e_col = col * VLEN; + int idk = k * VLEN; + if constexpr (row == 0) { + c10::ForcedUnroll{}(unroll_loadB, B + e_col * ldb + idk); + at::vec::transpose_block(vb); + } + constexpr int idx = row * COLS + col; + {{kernel.unroll_pragma(16)}} + for (int j = 0; j < VLEN; j++) { + {%- if alpha != 1 %} + va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + idk + j]) * {{alpha}}); + {%- else %} + va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + idk + j])); + {%- endif %} + vc[idx] = at::vec::fmadd(va, vb[j], vc[idx]); + } + }; + for (int k = 0; k < _K; ++k) { + c10::ForcedUnroll{}(compute_trans, k); + } + // store to C + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + vc[i].store(C + row * ldc + col * VLEN); + }; + c10::ForcedUnroll{}(storec); + } else { + // Second implementation + {%- if input2_dtype in [torch.bfloat16, torch.float16] %} + constexpr auto VLEN = VectorizedIn::size(); + {%- else %} + constexpr auto VLEN = Vectorized::size(); + {%- endif %} + int _K = (K + VLEN - 1) / VLEN; + // sub-block size of BLOCK_N and BLOCK_M + constexpr int sM = {{sub_block_m}}; + constexpr int sN = {{sub_block_n}}; + {%- if tail_n %} + int bN = (N + sN - 1) / sN; + {%- else %} + constexpr int bN = (BLOCK_N + sN - 1) / sN; + {%- endif %} + constexpr int bM = (BLOCK_M + sM - 1) / sM; + + {%- if input2_dtype in [torch.bfloat16, torch.float16] %} + at::vec::VectorizedN<{{compute_t}}, 2> va; + at::vec::VectorizedN<{{compute_t}}, 2 * sN> vb; + {%- else %} + at::vec::Vectorized<{{compute_t}}> va; + at::vec::VectorizedN<{{compute_t}}, sN> vb; + {%- endif %} + at::vec::VectorizedN<{{compute_t}}, sN * sM> vmid; + + {%- if tail_n %} + int ntail = N % sN; + {%- else %} + constexpr int ntail = BLOCK_N % sN; + {%- endif %} + constexpr int mtail = BLOCK_M % sM; + int ktail = K % VLEN; + + auto compute_trans = [&](int m, int n, int k) { + {%- if tail_n %} + int e_n = (n == bN - 1 && ntail != 0) ? (N - n * sN) : sN; + {%- else %} + int e_n = (n == bN - 1 && ntail != 0) ? (BLOCK_N - n * sN) : sN; + {%- endif %} + int e_m = (m == bM - 1 && mtail != 0) ? (BLOCK_M - m * sM) : sM; + int e_k = (k == _K - 1 && ktail != 0) ? 
(K - k * VLEN) : VLEN; + {{kernel.unroll_pragma(sub_block_n)}} + for (int i = 0; i < e_n; i++) { + {%- if input2_dtype in [torch.bfloat16, torch.float16] %} + auto b = VectorizedIn::loadu(B + (sN * n + i) * ldb + k * VLEN, e_k); + std::tie(vb[2 * i], vb[2 * i + 1]) = at::vec::convert_to_float<{{input_t}}>(b); + {%- elif input2_dtype == torch.int8 %} + auto b32 = at::vec::convert_to_int32(B + (sN * n + i) * ldb + k * VLEN, e_k); + vb[i] = at::vec::convert(b32); + {%- else %} + vb[i] = Vectorized::loadu(B + (sN * n + i) * ldb + k * VLEN, e_k); + {%- endif %} + } + + {{kernel.unroll_pragma(sub_block_m)}} + for (int s = 0; s < e_m; s++) { + {%- if input2_dtype in [torch.bfloat16, torch.float16] %} + auto a = VectorizedIn::loadu(A + (sM * m + s) * lda + k * VLEN, e_k); + std::tie(va[0], va[1]) = at::vec::convert_to_float<{{input_t}}>(a); + {%- elif input2_dtype == torch.int8 %} + auto a32 = at::vec::convert_to_int32(A + (sM * m + s) * lda + k * VLEN, e_k); + va = at::vec::convert(a32); + {%- else %} + va = Vectorized::loadu(A + (sM * m + s) * lda + k * VLEN, e_k); + {%- endif %} + + {%- if alpha != 1 %} + va = va * Vectorized({{alpha}}); + {%- endif %} + if (k == 0) { + {{kernel.unroll_pragma(sub_block_n)}} + for (int i = 0; i < e_n; i++) { + {%- if input2_dtype in [torch.bfloat16, torch.float16] %} + vmid[sN * s + i] = at::vec::fmadd(va[0], vb[2 * i], Vectorized(0.0f)); + vmid[sN * s + i] = at::vec::fmadd(va[1], vb[2 * i + 1], vmid[sN * s + i]); + {%- else %} + vmid[sN * s + i] = at::vec::fmadd(va, vb[i], Vectorized(0.0f)); + {%- endif %} + } + } else { + {{kernel.unroll_pragma(sub_block_n)}} + for (int i = 0; i < e_n; i++) { + {%- if input2_dtype in [torch.bfloat16, torch.float16] %} + vmid[sN * s + i] = at::vec::fmadd(va[0], vb[2 * i], vmid[sN * s + i]); + vmid[sN * s + i] = at::vec::fmadd(va[1], vb[2 * i + 1], vmid[sN * s + i]); + {%- else %} + vmid[sN * s + i] = at::vec::fmadd(va, vb[i], vmid[sN * s + i]); + {%- endif %} + } + } + } + + // store to C + if (k == _K - 1) { + {{kernel.unroll_pragma(sub_block_m)}} + for (int s = 0; s < e_m; s++) { + {{kernel.unroll_pragma(sub_block_n)}} + for (int i = 0; i < e_n; i++) { + auto v = at::vec::vec_reduce_all([](Vectorized& x, Vectorized& y) { return x + y; }, vmid[sN * s + i]); + if constexpr (accum) { + auto c = *(C + (sM * m + s) * ldc + sN * n + i); + *(C + (sM * m + s) * ldc + sN * n + i) = c + v; + } else { + *(C + (sM * m + s) * ldc + sN * n + i) = v; + } + } + } + } + }; + + for (int n = 0; n < bN; ++n) { + for (int m = 0; m < bM; ++m) { + for (int k = 0; k < _K; ++k) { + compute_trans(m, n, k); + } + } + } + } +{%- endif %} +} +""" + + # set trans_b to generate gemm that supports transposed B matrix + # set tail_n to support the tail of N + # TODO add trans_b support for other micro gemms + # and move setting of trans_b to the init of CppMicroGemm + def __init__( + self, + name, + input_dtype, + input2_dtype, + output_dtype, + compute_dtype, + register_blocking, + alpha=1, + tail_n=False, + trans_b=False, + ) -> None: + super().__init__( + name, + input_dtype, + input2_dtype, + output_dtype, + compute_dtype, + register_blocking, + alpha, + ) + self.tail_n = tail_n + # trans_b is only supported on platforms that + # support avx512 or avx2 since transpose_block is + # only implemented on these platforms + if trans_b: + vec_isa = pick_vec_isa() + assert issubclass(vec_isa.__class__, VecAVX512) or issubclass( + vec_isa.__class__, VecAVX2 + ) + self.trans_b = trans_b + + def codegen_define(self, kernel: CppTemplateKernel) -> str: + options = { 
+ "declare_kernel": self.get_kernel_declaration(), + "kernel": kernel, + "block_m": self.register_blocking.block_m, + "block_n": self.register_blocking.block_n, + "block_k": self.register_blocking.block_k, + "trans_b": False, + "tail_n": False, + "restrict_keyword": get_restrict_keyword(), + **self.get_common_options(), + } + if self.trans_b: + # TODO supports tuning of sub_block_m/sub_block_n + # to get better performance for specific shapes + sub_block_m = min(1, self.register_blocking.block_m) + sub_block_n = min(4, self.register_blocking.block_n) + # update options to generate kernel with trans_b and sub-block size + options.update( + { + "trans_b": self.trans_b, + "sub_block_m": sub_block_m, + "sub_block_n": sub_block_n, + } + ) + result = KernelTemplate._template_from_string(self.TEMPLATE_KERNEL).render( + options + ) + # update options to generate the kernel for the tail of N + if self.tail_n: + options.update( + { + "tail_n": self.tail_n, + } + ) + result += KernelTemplate._template_from_string(self.TEMPLATE_KERNEL).render( + options + ) + result += KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render( + options + ) + return result + + +# extra check for CppMicroGemmAMX +def check_amx_extra(config, m, n, k, alpha, num_threads, **kwargs): + vnni_size = 4 if config.input_dtype in [torch.uint8, torch.int8] else 2 + return k % vnni_size == 0 and alpha == 1 + + +@register_micro_gemm( + *generate_gemm_config( + VecAMX, + [(32, 32, 64), (48, 16, 64)], + input_dtype=torch.int8, + input2_dtype=torch.int8, + output_dtype=torch.int32, + compute_dtype=torch.int32, + extra_check=check_amx_extra, + ), + *generate_gemm_config( + VecAMX, + [(32, 32, 32), (48, 16, 32), (16, 48, 32)], + input_dtype=torch.bfloat16, + input2_dtype=torch.int8, + output_dtype=torch.float, + compute_dtype=torch.float, + extra_check=check_amx_extra, + ), + *generate_gemm_config( + VecAMX, + [(32, 32, 32), (48, 16, 32), (16, 48, 32)], + input_dtype=torch.bfloat16, + output_dtype=torch.float, + extra_check=check_amx_extra, + ), + *generate_gemm_config( + VecAMX, + [(32, 32, 64), (48, 16, 64)], + input_dtype=torch.uint8, + input2_dtype=torch.int8, + output_dtype=torch.int32, + compute_dtype=torch.int32, + extra_check=check_amx_extra, + ), +) +class CppMicroGemmAMX(CppMicroGemm): + """ + This class generates the code for micro gemm using Advanced Matrix extension (AMX) + instructions available in 4th generation Intel Xeon for compute. + It supports input types of torch.bfloat16 with fp32 output. + """ + + TEMPLATE_ENTRY = r""" +{{declare_kernel}} { + {{kernel.assert_function}}(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}"); + {{kernel.assert_function}}(K % 2 == 0, "K dimension must be multiple of 2"); +{%- if pack_vnni_B_locally %} + {{template.codegen_allocate_weight_buffer("packed_B_buf", input2_t, "K", block_n)}} +{%- endif %} +{%- if use_cached_dequantized_B %} + // Create a stack-allocated buffer for tiles of B. + // Except maybe for the tail-case, an AMX tile of B has 16x32 BF16 elements. + // we cache K * {{block_n}} elements of dequantized B + {{template.codegen_allocate_weight_buffer("dequantized_B_buf", input_t, "K", block_n)}} + const auto buf_size = K * {{block_n}}; + auto load_dequantized_B = [&](int base_idx) { + // Load a tile of B & cache it in L1D. 
+ {{input2_t}}* base_addr = const_cast<{{input2_t}}*>(B) + base_idx; + for (int idx_dq = 0, idx_q = 0; idx_dq < buf_size; idx_q += ldb, idx_dq += {{block_n}}) { + {%- for vec_idx in range(0, block_n, 32) %} + {%- if (block_n - vec_idx) >= 32 %} + auto b_int8_idx_{{vec_idx}} = at::vec::Vectorized::loadu( + base_addr + idx_q + {{vec_idx}} , + static_cast(32) + ); + auto b_bf16_idx_{{vec_idx}} = at::vec::convert<{{input_t}}>(b_int8_idx_{{vec_idx}}); + b_bf16_idx_{{vec_idx}}.store(dequantized_B_buf + idx_dq + {{vec_idx}}); + {%- else %} + auto b_int8_tail = at::vec::Vectorized::loadu( + base_addr + idx_q + {{block_n - (block_n % 32)}}, + static_cast({{block_n % 32}}) + ); + auto b_bf16_tail = at::vec::convert<{{input_t}}>(b_int8_tail); + b_bf16_tail.store( + dequantized_B_buf + idx_dq + {{block_n - (block_n % 32)}}, + static_cast({{block_n % 32}}) + ); + {%- endif %} + {%- endfor %} + } + }; +{%- endif %} +// The ldb would not be block_n if N != block_n +{%- if use_cached_dequantized_B or pack_vnni_B_locally %} + const int64_t updated_ldb = {{block_n}}; +{%- else %} + const int64_t updated_ldb = ldb; +{%- endif %} + // TODO(jgong5): loop unroll for M and N + for (int64_t n = 0; n < N; n += {{block_n}}) { +{%- if pack_vnni_B_locally %} + // Pack non-constant weights into VNNI interleaved format in packed_B_buf + at::vec::pack_vnni2(B + n, packed_B_buf, ldb, K, {{block_n}}); +{%- elif use_cached_dequantized_B %} + // Dequantize K * block_n int8 B elements into BF16 + load_dequantized_B(n); +{%- endif %} + for (int64_t m = 0; m < M; m += {{block_m}}) { + int64_t block_m = std::min(M - m, {{block_m}}); + int64_t m_tail = m; +{%- for num_rows in range(block_m, 0, -16) %} + {%- if num_rows != block_m %} + else + {%- endif %} + if (block_m >= {{num_rows}}) { + {{kernel_name}}_amx_kernel_{{num_rows}}_{{num_columns}}( + amx_state, + A + m * lda, +{%- if use_cached_dequantized_B %} + dequantized_B_buf, +{%- elif pack_vnni_B_locally %} + packed_B_buf, +{%- else %} + B + n, +{%- endif %} + C + m * ldc + n, + K, + lda, + updated_ldb, + ldc, + 16 + ); + block_m -= {{num_rows}}; + m_tail += {{num_rows}}; + } +{%- endfor %} + if (block_m > 0) { + {{kernel_name}}_amx_kernel_16_{{num_columns}}( + amx_state, + A + m_tail * lda, +{%- if use_cached_dequantized_B %} + dequantized_B_buf, +{%- elif pack_vnni_B_locally %} + packed_B_buf, +{%- else %} + B + n, +{%- endif %} + C + m_tail * ldc + n, + K, + lda, + updated_ldb, + ldc, + block_m + ); + } + } + } +} +""" + + TEMPLATE_KERNEL = r""" + +template +inline void {{kernel_name}}_amx_kernel_{{num_rows}}_{{num_columns}}( + AMXState& amx_state, + const {{input_t}}* {{restrict_keyword}} A, +{%- if use_cached_dequantized_B %} + const {{input_t}}* {{restrict_keyword}} B, +{%- else %} + const {{input2_t}}* {{restrict_keyword}} B, +{%- endif %} + {{output_t}}* {{restrict_keyword}} C, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + uint8_t tilecfg_rows +) { + // TODO(jgong5): add prefetch hint for A, B, C + auto loadconfig = [](const amx_tilecfg& cfg) { + _tile_loadconfig(&cfg); + }; + const auto last_k_offset = K / {{block_k}} * {{block_k}}; + const auto tail_k_size = K - last_k_offset; + if C10_LIKELY (last_k_offset > 0) { + amx_state.configure(tilecfg_rows, 64, {{num_rows}} / 16, {{num_columns}}, loadconfig); + } else { + amx_state.configure(tilecfg_rows, tail_k_size * sizeof({{input_t}}), {{num_rows}} / 16, {{num_columns}}, loadconfig); + } + auto load_c = [&]() { +{%- for tile_row in range(num_rows // 16) %} + {%- for tile_col in range(num_columns) %} + {%- 
set tile_idx = tile_row * num_columns + tile_col %} + _tile_loadd({{tile_idx}}, C + {{tile_row * 16}} * ldc + {{tile_col * 16}}, ldc * sizeof({{output_t}})); + {%- endfor %} +{%- endfor %} + }; + auto zero_c = [&]() { +{%- for tile_row in range(num_rows // 16) %} + {%- for tile_col in range(num_columns) %} + {%- set tile_idx = tile_row * num_columns + tile_col %} + _tile_zero({{tile_idx}}); + {%- endfor %} +{%- endfor %} + }; + + if constexpr (accum) { + load_c(); + } else { + zero_c(); + } + + auto compute = [&](int k) { +{%- set tile_offset_a = num_rows // 16 * num_columns %} +{%- set tile_offset_b = tile_offset_a + num_rows // 16 %} +{%- for tile_row in range(num_rows // 16) %} + {%- for tile_col in range(num_columns) %} + {%- set tile_idx_a = tile_offset_a + tile_row %} + {%- set tile_idx_b = tile_offset_b + tile_col %} + {%- set tile_idx_c = tile_row * num_columns + tile_col %} + {%- if tile_col == 0 %} + _tile_stream_loadd({{tile_idx_a}}, A + {{tile_row * 16}} * lda + k, lda * sizeof({{input_t}})); + {%- endif %} + {%- if tile_row == 0 %} + _tile_loadd({{tile_idx_b}}, B + k * ldb + {{tile_col * 16 * vnni_size}}, ldb * {{vnni_size}} * sizeof({{input_t}})); + {%- endif %} + {%- if int8_gemm %} + {%- if input_dtype == torch.int8 %} + _tile_dpbssd({{tile_idx_c}}, {{tile_idx_a}}, {{tile_idx_b}}); + {%- else %} + _tile_dpbusd({{tile_idx_c}}, {{tile_idx_a}}, {{tile_idx_b}}); + {%- endif %} + {%- else %} + _tile_dpbf16ps({{tile_idx_c}}, {{tile_idx_a}}, {{tile_idx_b}}); + {%- endif %} + {%- endfor %} +{%- endfor %} + }; + + {{kernel.unroll_pragma(4)}} + for (int k = 0; k < last_k_offset; k += {{block_k}}) { + compute(k); + } + + auto store_c = [&]() { + // store to C +{%- for tile_row in range(num_rows // 16) %} + {%- for tile_col in range(num_columns) %} + {%- set tile_idx = tile_row * num_columns + tile_col %} + _tile_stored({{tile_idx}}, C + {{tile_row * 16}} * ldc + {{tile_col * 16}}, ldc * sizeof({{output_t}})); + {%- endfor %} +{%- endfor %} + }; + + // TODO(jgong5): move tail k computation to separate loopnest to save tile configuration overhead + if C10_UNLIKELY (tail_k_size > 0) { + if C10_LIKELY (last_k_offset > 0) { + store_c(); + amx_state.configure(tilecfg_rows, tail_k_size * sizeof({{input_t}}), {{num_rows}} / 16, {{num_columns}}, loadconfig); + load_c(); + } + compute(last_k_offset); + } + + store_c(); +} +""" + + def codegen_define(self, kernel: CppTemplateKernel) -> str: + block_m, block_n, block_k = self.register_blocking + assert block_m % 16 == 0, "Only support block_m % 16 == 0 for AMX" + assert block_n % 16 == 0, "Only support block_n % 16 == 0 for AMX" + if self.input_dtype in [torch.uint8, torch.int8]: + assert block_k == 64, "Only support block_k = 64 for AMX INT8" + else: + assert block_k == 32, "Only support block_k = 32 for AMX Bfloat16/Float16" + num_columns = block_n // 16 + options = { + "declare_kernel": self.get_kernel_declaration(), + "use_cached_dequantized_B": ( + self.input_dtype == torch.bfloat16 + and self.input2_dtype in [torch.int8, torch.uint8] + ), + "kernel": kernel, + "block_m": block_m, + "block_n": block_n, + "block_k": block_k, + "num_columns": num_columns, + "restrict_keyword": get_restrict_keyword(), + **self.get_common_options(), + } + result = "" + for num_rows in range(block_m, 0, -16): + amx_kernel_options = {**options, "num_rows": num_rows} + result += KernelTemplate._template_from_string(self.TEMPLATE_KERNEL).render( + amx_kernel_options + ) + result += KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render( + options + ) + 
return result + + def codegen_init( + self, + kernel: CppTemplateKernel, + ) -> str: + return "AMXState amx_state;" + + def codegen_finalize( + self, + kernel: CppTemplateKernel, + ) -> str: + return "amx_state.release([]() { _tile_release(); });" + + def get_kernel_extra_args_declare(self) -> str: + return "AMXState& amx_state," + + def get_kernel_extra_args(self, **kwargs) -> list[str]: + return ["amx_state,"] + + def get_b_layout(self): + if self.input_dtype in [torch.uint8, torch.int8]: + return LayoutType.VNNI4 + else: + return LayoutType.VNNI2 + + +# extra check for CppMicroBrgemm +def check_brgemm_extra(config, m, n, k, alpha, num_threads, **kwargs): + assert config.input_dtype == torch.half and config.output_dtype == torch.float + vnni_size = 2 + # use brgemm for Half when amx_fp16 is supported + return torch.cpu._is_amx_fp16_supported() and k % vnni_size == 0 and alpha == 1 + + +@register_micro_gemm( + *generate_gemm_config( + VecAMX, + [(32, 32, 32), (48, 16, 32), (16, 48, 32)], + input_dtype=torch.half, + output_dtype=torch.float, + extra_check=check_brgemm_extra, + ), +) +class CppMicroBrgemm(CppMicroGemm): + """ + This class generates the code for micro gemm using oneDNN brgemm. + It supports input types of torch.half. + """ + + TEMPLATE_ENTRY = r""" +#include +{{declare_kernel}} { +{%- if pack_vnni_B_locally %} + {{template.codegen_allocate_weight_buffer("packed_B_buf", input2_t, "K * N")}} + at::vec::pack_vnni2(B, packed_B_buf, ldb, K, N); +{%- endif %} + at::native::cpublas::brgemm( + M, N, K, + {%- if pack_vnni_B_locally %} + lda, N, ldc, + {%- else %} + lda, ldb, ldc, + {%- endif %} + accum, + A, + {%- if pack_vnni_B_locally %} + packed_B_buf, + {%- else %} + B, + {%- endif %} + C); +} +""" + + def codegen_define(self, kernel: CppTemplateKernel) -> str: + options = { + "declare_kernel": self.get_kernel_declaration(), + "kernel": kernel, + "block_m": self.register_blocking.block_m, + "block_n": self.register_blocking.block_n, + "block_k": self.register_blocking.block_k, + "restrict_keyword": get_restrict_keyword(), + **self.get_common_options(), + } + result = "" + result += KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render( + options + ) + return result + + def codegen_finalize( + self, + kernel: CppTemplateKernel, + ) -> str: + return "at::native::cpublas::brgemm_release();" + + def get_b_layout(self): + assert self.input_dtype == torch.half and torch.cpu._is_amx_fp16_supported() + return LayoutType.VNNI2 + + +def check_woq_int4_extra(config, m, n, k, alpha, num_threads, **kwargs): + if alpha != 1: + return False + q_group_size = kwargs.get("q_group_size", None) + assert q_group_size is not None + if ( + q_group_size < 32 + or k % q_group_size != 0 + or config.register_blocking.block_k > q_group_size + ): + return False + return k % config.register_blocking.block_k == 0 and n % 64 == 0 + + +@register_micro_gemm( + # TODO: support float/half input + *generate_gemm_config( + VecAVX512, + [(4, 64, 32), (4, 64, 64), (4, 64, 128)], + input_dtype=torch.bfloat16, + input2_dtype=torch.uint8, + output_dtype=torch.float, + compute_dtype=torch.float, + extra_check=check_woq_int4_extra, + ), +) +class CppMicroGemmWoQInt4Avx512(CppMicroGemmFP32Vec): + """ + This class generates the code for WoQ int4 micro gemm using AVX512 intrinsics. + It is based on the corresponding ATen kernel. 
+ Shape of packed weight = [N // 64, K, 32], viewed as [N, K // 2] + Shape of packed ScalesAndZeros = [K // group_size, N, 2] + """ + + TEMPLATE_ENTRY = r""" +{{declare_kernel}} { + {{kernel.assert_function}}(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}"); + {{kernel.assert_function}}(K % {{block_k}} == 0, "K dimension must be multiple of {{block_k}}"); + auto group_size = q_group_size; + for (int64_t m = 0; m < M; m += {{block_m}}) { + int64_t block_m = std::min(M - m, {{block_m}}); + for (int64_t n = 0; n < N; n += {{block_n}}) { + if (block_m == {{block_m}}) { + {{kernel_name}}_kernel<{{block_m}}, {{block_n}}, accum>( + A + m * lda, + reinterpret_cast(B) + n * ldb, + C + m * ldc + n, + K, + lda, + /* ldb */ {{block_n}} / 2, + ldc, + group_size, + ScaleAndZeros + n * 2, + lds, + k_start + ); + } else { + switch (block_m) { + {%- for b in range(block_m - 1, 0, -1) %} + case {{b}}: + {{kernel_name}}_kernel<{{b}}, {{block_n}}, accum>( + A + m * lda, + reinterpret_cast(B) + n * ldb, + C + m * ldc + n, + K, + lda, + /* ldb */ {{block_n}} / 2, + ldc, + group_size, + ScaleAndZeros + n * 2, + lds, + k_start + ); + break; + {%- endfor %} + default: + {{kernel.assert_function}}(false, "Unsupported block_m: ", block_m); + } + } + } + } +} +""" + + TEMPLATE_KERNEL = r""" +inline bool {{kernel_name}}_is_block_start(int index, int k_start, int group_size) { + return (k_start + index) % group_size == 0; +} + +inline __m128i {{kernel_name}}_convert_int4_to_int8(const uint8_t* data) { + __m128i tmp = _mm_loadu_si64((const __m128i*)data); + __m128i bytes = _mm_cvtepu8_epi16(tmp); + const __m128i lowMask = _mm_set1_epi8(0xF); + __m128i high = _mm_andnot_si128(lowMask, bytes); + __m128i low = _mm_and_si128(lowMask, bytes); + high = _mm_slli_epi16(high, 4); + bytes = _mm_or_si128(low, high); + return bytes; +} + +template +inline void {{kernel_name}}_kernel( + const {{input_t}}* {{restrict_keyword}} A, + const uint8_t* {{restrict_keyword}} B, + {{output_t}}* {{restrict_keyword}} C, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + int64_t q_group_size, + const at::BFloat16* {{restrict_keyword}} ScaleAndZeros, + int64_t lds, // leading dimension of ScaleAndZeros + int64_t k_start) { + constexpr int BLOCK_K = {{block_k}}; + constexpr int ROWS = BLOCK_M; + constexpr int COLS = BLOCK_N / 16; + + const int PREFETCH_SIZE_K = 16 * 4; + const int PREFETCH_SIZE_KB = (PREFETCH_SIZE_K + BLOCK_K - 1) / BLOCK_K; + + // number of blocks on K + const int KB = K / BLOCK_K; + + __m512 va; + __m512 vb[COLS]; + __m512 vc[ROWS * COLS]; + __m512 scale[COLS]; + __m512 zero[COLS]; + + // Lookup table to de-quantize int4 values to bf16. 
+ // Values are dequantized as truly int4 [-8, 7] range; + // + // dequant = (bf16(int4_value) * bf16_scale) + bf16_zero + // + static const __m512 lut = _mm512_set_ps( + 7.0f, 6.0f, 5.0f, 4.0f, + 3.0f, 2.0f, 1.0f, 0.0f, + -1.0f, -2.0f, -3.0f, -4.0f, + -5.0f, -6.0f, -7.0f, -8.0f); + + // index for transpose + static const __m512i idx1 = _mm512_set_epi32( + 30, 28, 26, 24, 22, 20, 18, 16, + 14, 12, 10, 8, 6, 4, 2, 0); + static const __m512i idx2 = _mm512_set_epi32( + 31, 29, 27, 25, 23, 21, 19, 17, + 15, 13, 11, 9, 7, 5, 3, 1); + + // load scale and zero point + auto load_scale_and_zeros = [&](int i, int _kb) { + // load 2x bfloat16 vector + __m512i t = _mm512_loadu_si512((__m512i*)(ScaleAndZeros + _kb * lds + 32 * i)); + if (_kb + PREFETCH_SIZE_KB < KB) { + _mm_prefetch(ScaleAndZeros + (_kb + PREFETCH_SIZE_KB) * lds + 32 * i, _MM_HINT_T0); + } + + // convert to 2x f32 vector + __m512 a, b; + at::vec::cvtbf16_fp32(t, a, b); + + // transpose scale_and_zero from {16, 2} to {2, 16} + // inputs: + // a: {s0, z0, s1, z1, ..., s7, z7} + // b: {s8, z8, s9, z9, ..., s15, z15} + // output: + // scale: {s0, s1, s2, ..., s15} + // zero: {z0, z1, z2, ..., z15} + scale[i] = _mm512_mask_permutex2var_ps(a, 0xffff, idx1, b); + zero[i] = _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b); + }; + + auto loadc = [&](auto i) { + if constexpr (accum) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + vc[i] = _mm512_loadu_ps(C + row * ldc + col * 16); + } else { + vc[i] = _mm512_setzero_ps(); + } + }; + c10::ForcedUnroll{}(loadc); + + auto compute = [&, COLS](auto i, int k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + float aa = static_cast(A[row * lda + k]); + if (k + PREFETCH_SIZE_K < K) { + _mm_prefetch(A + row * lda + k + PREFETCH_SIZE_K, _MM_HINT_T0); + } + va = _mm512_set1_ps(aa); + } + + if constexpr (row == 0) { + if constexpr (COLS == 4) { + // when BLOCK_N = 64, handle each row at a time + // to reduce de-quantize overhead. 
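+          // A single 256-bit load below covers 64 packed int4 values (two per byte); the low and high nibbles are decoded separately through the lookup table and then scaled and offset with fused multiply-adds.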
+ if constexpr (col == 0) { + __m256i b4 = _mm256_loadu_si256((__m256i*)(B + k * ldb)); + if (k + PREFETCH_SIZE_K < K) { + _mm_prefetch(B + (k + PREFETCH_SIZE_K) * ldb, _MM_HINT_T0); + } + + __m512i b32 = _mm512_cvtepu8_epi32(_mm256_castsi256_si128(b4)); + vb[0] = _mm512_permutexvar_ps(b32, lut); + vb[0] = _mm512_fmadd_ps(vb[0], scale[0], zero[0]); + vb[2] = _mm512_permutexvar_ps(_mm512_srli_epi32(b32, 4), lut); + vb[2] = _mm512_fmadd_ps(vb[2], scale[2], zero[2]); + + b32 = _mm512_cvtepu8_epi32(_mm256_extracti128_si256(b4, 1)); + vb[1] = _mm512_permutexvar_ps(b32, lut); + vb[1] = _mm512_fmadd_ps(vb[1], scale[1], zero[1]); + vb[3] = _mm512_permutexvar_ps(_mm512_srli_epi32(b32, 4), lut); + vb[3] = _mm512_fmadd_ps(vb[3], scale[3], zero[3]); + } + } else { + __m128i b8 = {{kernel_name}}_convert_int4_to_int8(B + k * ldb + col * 8); + __m512i b32 = _mm512_cvtepu8_epi32(b8); + vb[col] = _mm512_permutexvar_ps(b32, lut); + vb[col] = _mm512_fmadd_ps(vb[col], scale[col], zero[col]); + } + } + + constexpr int idx = row * COLS + col; + vc[idx] = _mm512_fmadd_ps(va, vb[col], vc[idx]); + }; + + for (int k = 0, kb = 0; k < K; ++k) { + if ({{kernel_name}}_is_block_start(k, k_start, q_group_size)) { + c10::ForcedUnroll{}(load_scale_and_zeros, kb++); + } + c10::ForcedUnroll{}(compute, k); + } + + //store to C + auto storec = [&, COLS](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + _mm512_storeu_ps(C + row * ldc + col * 16, vc[i]); + }; + c10::ForcedUnroll{}(storec); +} +""" + + def get_kernel_extra_args_declare(self) -> str: + return ( + "const int64_t q_group_size,\n" + " const at::BFloat16* __restrict__ ScaleAndZeros,\n" + " const int64_t lds,\n" + " int64_t k_start," + ) + + def get_kernel_extra_args(self, **kwargs) -> list[str]: + assert "kernel" in kwargs + assert "qscale_and_zeros" in kwargs + kernel = kwargs["kernel"] + qscale_and_zeros = kwargs["qscale_and_zeros"] + return [ + "group_size,", + f"&({kernel.index(qscale_and_zeros, [0, 0, 0])}),", + "N * 2,", # lds + "k_start,", + ] + + def is_woq_int4(self): + return True + + +@register_micro_gemm( + *generate_gemm_config( + VecAMX, + [ # (block_m, block_n, block_k) + (16, 32, 32), + (32, 32, 32), + ], + input_dtype=torch.bfloat16, + input2_dtype=torch.uint8, + output_dtype=torch.float, + compute_dtype=torch.float, + extra_check=check_amx_extra, + ), +) +class CppMicroGemmWoQInt4Amx(CppMicroGemmAMX): + """ + This class generates the code for WoQ int4 micro gemm using AMX intrinsics, + which are available on 4th and newer generations of Intel Xeon. + Shape of packed weight = [N // 32, K, 16], viewed as [N, K // 2] + Shape of packed ScalesAndZeros = [K // group_size, N, 2] + Reuse TEMPLATE_KERNEL of CppMicroGemmAMX. + """ + + TEMPLATE_ENTRY = r""" +inline bool {{kernel_name}}_is_block_start(int index, int k_start, int group_size) { + return (k_start + index) % group_size == 0; +} + +{{declare_kernel}} { + {{kernel.assert_function}}(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}"); + {{kernel.assert_function}}(K % 2 == 0, "K dimension must be multiple of 2"); + {{kernel.assert_function}}({{block_n}} == 32, "block_n must be 32 for WOQ int4"); + + // Create a stack-allocated buffer for tiles of B. + // Except maybe for the tail-case, an AMX tile of B has 16x32 BF16 elements. 
+ // we cache K * {{block_n}} elements of dequantized B + {{template.codegen_allocate_weight_buffer("dequantized_B_buf", input_t, "K", block_n)}} + + constexpr int BLOCK_K = {{block_k}}; + constexpr int64_t BLOCK_N = {{block_n}}; + constexpr int COLS = BLOCK_N / 16; + const int PREFETCH_SIZE_K = 16 * 4; + const int PREFETCH_SIZE_KB = (PREFETCH_SIZE_K + BLOCK_K - 1) / BLOCK_K; + const int KB = K / BLOCK_K; + + __m512i b32[COLS * 2]; + __m512 vb[COLS * 2]; + __m512 scale[COLS]; + __m512 zero[COLS]; + + // Lookup table to de-quantize int4 values to bf16. + // Values are dequantized as truly int4 [-8, 7] range; + // + // dequant = (bf16(int4_value) * bf16_scale) + bf16_zero + // + static const __m512 lut = _mm512_set_ps( + 7.0f, 6.0f, 5.0f, 4.0f, + 3.0f, 2.0f, 1.0f, 0.0f, + -1.0f, -2.0f, -3.0f, -4.0f, + -5.0f, -6.0f, -7.0f, -8.0f); + + // index for transpose + static const __m512i idx1 = _mm512_set_epi32( + 30, 28, 26, 24, 22, 20, 18, 16, + 14, 12, 10, 8, 6, 4, 2, 0); + static const __m512i idx2 = _mm512_set_epi32( + 31, 29, 27, 25, 23, 21, 19, 17, + 15, 13, 11, 9, 7, 5, 3, 1); + + // Indices for VNNI layout conversion + __m512i idx_low = _mm512_set_epi32( + 0x17, + 0x07, + 0x16, + 0x06, + 0x15, + 0x05, + 0x14, + 0x04, + 0x13, + 0x03, + 0x12, + 0x02, + 0x11, + 0x01, + 0x10, + 0x00); + __m512i idx_high = _mm512_set_epi32( + 0x1f, + 0x0f, + 0x1e, + 0x0e, + 0x1d, + 0x0d, + 0x1c, + 0x0c, + 0x1b, + 0x0b, + 0x1a, + 0x0a, + 0x19, + 0x09, + 0x18, + 0x08); + + // load scale and zero point + auto load_scale_and_zeros = [&](int i, int _kb) { + // load 2x bfloat16 vector + __m512i t = _mm512_loadu_si512((__m512i*)(ScaleAndZeros + _kb * lds + 32 * i)); + if (_kb + PREFETCH_SIZE_KB < KB) { + _mm_prefetch(ScaleAndZeros + (_kb + PREFETCH_SIZE_KB) * lds + 32 * i, _MM_HINT_T0); + } + + // convert to 2x f32 vector + __m512 a, b; + at::vec::cvtbf16_fp32(t, a, b); + + // transpose scale_and_zero from {16, 2} to {2, 16} + // inputs: + // a: {s0, z0, s1, z1, ..., s7, z7} + // b: {s8, z8, s9, z9, ..., s15, z15} + // output: + // scale: {s0, s1, s2, ..., s15} + // zero: {z0, z1, z2, ..., z15} + scale[i] = _mm512_mask_permutex2var_ps(a, 0xffff, idx1, b); + zero[i] = _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b); + }; + + // Dequantize a B block of 2 * block_n into bf16 + // So, it handles k and k+1 at the same time + auto dequantize_B = [&](int n) { + constexpr int64_t ldb_int4 = BLOCK_N / 2; // 16 + for (int k = 0, kb = 0; k < K; k += 2) { + // Since block_k must be 32 for AMX microkernels, k_start may not be + // a multiple of q_group_size. 
In that case, we need to load scales + // and zero points immediately when k == 0 here + if ({{kernel_name}}_is_block_start(k, k_start, q_group_size) || k == 0) { + c10::ForcedUnroll{}(load_scale_and_zeros, kb++); + } + + // load 256 bits = 64 elements in int4 + if (k + PREFETCH_SIZE_K < K) { + _mm_prefetch(B + (k + PREFETCH_SIZE_K) * ldb_int4, _MM_HINT_T0); + } + + __m128i b4 = _mm_loadu_si128((__m128i*)(B + n / 2 * K + k * ldb_int4)); + b32[0] = _mm512_cvtepu8_epi32(b4); + b32[1] = _mm512_srli_epi32(b32[0], 4); + vb[0] = _mm512_permutexvar_ps(b32[0] , lut); + vb[0] = _mm512_fmadd_ps(vb[0], scale[0], zero[0]); + vb[1] = _mm512_permutexvar_ps(b32[1], lut); + vb[1] = _mm512_fmadd_ps(vb[1], scale[1], zero[1]); + + b4 = _mm_loadu_si128((__m128i*)(B + n / 2 * K + (k + 1) * ldb_int4)); + b32[0 + COLS] = _mm512_cvtepu8_epi32(b4); + b32[1 + COLS] = _mm512_srli_epi32(b32[0 + COLS], 4); + vb[0 + COLS] = _mm512_permutexvar_ps(b32[0 + COLS] , lut); + vb[0 + COLS] = _mm512_fmadd_ps(vb[0 + COLS], scale[0], zero[0]); + vb[1 + COLS] = _mm512_permutexvar_ps(b32[1 + COLS], lut); + vb[1 + COLS] = _mm512_fmadd_ps(vb[1 + COLS], scale[1], zero[1]); + + for (int i = 0; i < COLS; i++) { + // convert to VNNI + auto low = _mm512_permutex2var_ps(vb[i], idx_low, vb[i + COLS]); + auto high = _mm512_permutex2var_ps(vb[i], idx_high, vb[i + COLS]); + // convert lower 16 float32 values to bfloat16 + auto v0_bf16 = reinterpret_cast<__m256i>(_mm512_cvtneps_pbh(low)); + // convert higher 16 float32 values to bfloat16 + auto v1_bf16 = reinterpret_cast<__m256i>(_mm512_cvtneps_pbh(high)); + // combine the lower 16 and higher 16 bfloat16 values + auto v = _mm512_castsi256_si512(v0_bf16); + v = _mm512_inserti64x4(v, v1_bf16, 1); + // store the VNNI format bfloat16 values + {{input_t}}* addr = dequantized_B_buf + k * 32 + (i % 2) * 32; + _mm512_storeu_si512(addr, v); + } + } + }; + + for (int64_t n = 0; n < N; n += {{block_n}}) { + // Dequantize K * block_n int8 B elements into BF16 + dequantize_B(n); + for (int64_t m = 0; m < M; m += {{block_m}}) { + int64_t block_m = std::min(M - m, {{block_m}}); + int64_t m_tail = m; + {%- for num_rows in range(block_m, 0, -16) %} + {%- if num_rows != block_m %} + else + {%- endif %} + if (block_m >= {{num_rows}}) { + {{kernel_name}}_amx_kernel_{{num_rows}}_{{num_columns}}( + amx_state, + A + m * lda, + dequantized_B_buf + n * K, + C + m * ldc + n, + K, + lda, + {{block_n}}, + ldc, + 16 + ); + block_m -= {{num_rows}}; + m_tail += {{num_rows}}; + } + {%- endfor %} + if (block_m > 0) { + {{kernel_name}}_amx_kernel_16_{{num_columns}}( + amx_state, + A + m_tail * lda, + dequantized_B_buf + n * K, + C + m_tail * ldc + n, + K, + lda, + {{block_n}}, + ldc, + block_m + ); + } + } // for m + } // for n +} +""" + + def get_kernel_extra_args_declare(self) -> str: + return ( + "AMXState& amx_state,\n" + " const int64_t q_group_size,\n" + " const c10::BFloat16* __restrict__ ScaleAndZeros,\n" + " const int64_t lds,\n" + " int64_t k_start," + ) + + def get_kernel_extra_args(self, **kwargs) -> list[str]: + assert "kernel" in kwargs + assert "qscale_and_zeros" in kwargs + kernel = kwargs["kernel"] + qscale_and_zeros = kwargs["qscale_and_zeros"] + return [ + "amx_state,", + "group_size,", + f"&({kernel.index(qscale_and_zeros, [0, 0, 0])}),", + "N * 2,", # lds + "k_start,", + ] + + def is_woq_int4(self): + return True + + +def create_micro_gemm( + name, + m, + n, + k, + input_dtype, + input2_dtype, + output_dtype=None, + compute_dtype=None, + alpha=1, + num_threads=-1, + use_ref=True, + q_group_size=None, +) -> 
Optional[CppMicroGemm]: + """ + Based on the provided info, try to find the config of the micro-kernel that would + deliver the best performance in terms of lower latency for this case. + """ + + def create_from_config(cls, config: CppMicroGemmConfig): + return cls( + name, + config.input_dtype, + config.input2_dtype, + config.output_dtype, + config.compute_dtype, + config.register_blocking, + alpha, + ) + + def skip_amx_kernel_for_woq(config, dynamic_M, micro_gemm_cls): + # For WoQ GEMM, AMX micro-kernel may not perform well if m is small. + # Exception: for dynamic shapes, we consider using the AMX micro-kernel. + if ( + dynamic_M + or input_dtype != torch.bfloat16 + or input2_dtype not in [torch.int8, torch.uint8] + ): + return False + # For WOQ INT8, use AMX for m >= block_m + # For WOQ INT4, use AMX for m >= 5 + block_m, *_ = config.register_blocking + is_woq_int4 = micro_gemm_cls == CppMicroGemmWoQInt4Amx + m_threshold = 5 if is_woq_int4 else block_m + return m < m_threshold + + assert isinstance(n, int) or n.is_number, n + assert isinstance(k, int) or k.is_number, k + from ..utils import has_free_symbols + + dynamic_M = has_free_symbols((m,)) + m = V.graph.sizevars.size_hint(m, fallback=1) if dynamic_M else m + assert isinstance(m, int) or m.is_number, m + if output_dtype is None: + output_dtype = input_dtype + if compute_dtype is None: + compute_dtype = output_dtype + if num_threads < 0: + num_threads = parallel_num_threads() + vec_isa = pick_vec_isa() + matched_configs = [] + for cls, configs in micro_gemm_configs.items(): + for config in configs: + if not issubclass(vec_isa.__class__, config.vec_isa_cls): + continue + if ( + config.input_dtype == input_dtype + and config.compute_dtype == compute_dtype + and config.input2_dtype == input2_dtype + and config.output_dtype == output_dtype + # The output_dtype here is the output dtype of the micro-kernel. + # In some cases, the actual output dtype of the op for which the micro-kernel + # is being created would be same as that of the activation, but the micro-kernels + # compute output in Float/int32, which is converted in the GEMM template. This is + # subject to change in the future. + ): + if config.extra_check is not None and not config.extra_check( + config, + m, + n, + k, + alpha, + num_threads, + dynamic_M=dynamic_M, + q_group_size=q_group_size, + ): + continue + block_m, block_n, block_k = config.register_blocking + if config.vec_isa_cls == VecAMX and skip_amx_kernel_for_woq( + config, dynamic_M, cls + ): + continue + # Criteria on the ranking of configurations + # 1. ISA: AMX > VEC + # 2. Dividable by block sizes (block_m, block_n, block_k) + # 3. Number of mxn blocks is large enough to occupy all the threads + # 4. 
Register blocks are larger + isa_score = 0 + if config.vec_isa_cls == VecAMX: + isa_score += 1 + dividable_score = 0 + if m % block_m == 0: + dividable_score += 1 + if n % block_n == 0: + dividable_score += 1 + if k % block_k == 0: + dividable_score += 1 + occupancy_score = 0 + n_blocks = (n + block_n - 1) // block_n + total_mxn_blocks = n_blocks * ((m + block_m - 1) // block_m) + if n_blocks >= num_threads: + occupancy_score += 1 + if total_mxn_blocks >= num_threads: + occupancy_score += 1 + register_bytes = ( + block_m * block_n * config.compute_dtype.itemsize + + (block_m * block_k + block_k * block_n) + * config.input_dtype.itemsize + ) + matched_configs.append( + ( + (isa_score, dividable_score, occupancy_score, register_bytes), + cls, + config, + ) + ) + if len(matched_configs) == 0: + if use_ref: + return CppMicroGemmRef( + name, input_dtype, input2_dtype, output_dtype, compute_dtype, alpha + ) + else: + return None + # TODO(jgong5): allow autotuning on choices of configs + return create_from_config(*max(matched_configs, key=operator.itemgetter(0))[1:]) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_template.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_template.py new file mode 100644 index 0000000000000000000000000000000000000000..09ee0b184892560ece96f02fbdab682dc7208d70 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_template.py @@ -0,0 +1,138 @@ +# mypy: allow-untyped-defs +import ctypes +import functools +import itertools +import logging +import sys +from collections.abc import Iterable +from typing import Callable, Optional, Union +from unittest.mock import patch + +import sympy + +from .. import config, ir +from ..autotune_process import CppBenchmarkRequest, TensorMeta +from ..utils import IndentedBuffer, Placeholder, unique +from ..virtualized import V +from .common import KernelTemplate +from .cpp_template_kernel import CppTemplateCaller, CppTemplateKernel + + +log = logging.getLogger(__name__) + + +class CppTemplate(KernelTemplate): + index_counter = itertools.count() + + def __init__( + self, + name: str, + input_nodes, + layout: ir.Layout, + num_threads: int, + epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None, + ) -> None: + super().__init__(name) + self.input_nodes = input_nodes + self.index = next(self.index_counter) + self.output_node: Union[ir.Buffer, list[ir.Buffer]] = ir.Buffer( + name=f"buf_out{self.index}", layout=layout + ) + self.layout = layout + self.num_threads = num_threads + self.epilogue_creator = epilogue_creator + + def generate(self, **kwargs): + kernel_name = f"cpp_{self.name}" + with ( + patch.object(V.graph, "get_dtype", self._fake_get_dtype(self.output_node)), + patch.object(ir.FlexibleLayout, "allow_indexing", True), + V.graph.set_current_device(self.layout.device), + CppTemplateKernel( + kernel_name=kernel_name, num_threads=self.num_threads + ) as kernel, + ): + code = kernel.render(self, **kwargs) + _, call_args, _, _ = kernel.args.python_argdefs() + log.debug("Generated Code:\n%s", code) + log.debug( + "Args: cpp_argdefs: %s, python_argdefs: %s", + kernel.args.cpp_argdefs(), + kernel.args.python_argdefs(), + ) + + expected_args = list( + unique(input_node.get_name() for input_node in self.input_nodes) + ) + if isinstance(self.output_node, Iterable): + expected_args.extend([node.get_name() for node in self.output_node]) + else: + expected_args.extend([self.output_node.get_name()]) + assert list(call_args)[: len(expected_args)] == 
expected_args, ( + call_args, + expected_args, + ) + extra_args = V.graph.sizevars.size_hints( + map(sympy.expand, call_args[len(expected_args) :]) + ) + # Cast the size hint from int to ctypes.c_ulonglong explicitly + # since in cpp kernel, we bind it to C long + extra_args = tuple(ctypes.c_ulonglong(x) for x in extra_args) + + kernel_hash_name = f"cpp_{self.name}_{self.index}" + + # Create the BenchmarkRequest for CPP + bmreq = CppBenchmarkRequest( + kernel_name=kernel_name, + input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes), + output_tensor_meta=TensorMeta.from_irnodes(self.output_node), + extra_args=extra_args, + source_code=code, + ) + + def make_kernel_render( + template_node: ir.CppTemplateBuffer, + flag_template_buffer_has_other_users: bool, + epilogue_nodes: Optional[list[ir.IRNode]] = None, + ): + kernel = CppTemplateKernel( + kernel_name=str(Placeholder.KERNEL_NAME), num_threads=self.num_threads + ) + render = functools.partial( + kernel.render, + self, + template_buffer_node=template_node, + flag_template_buffer_has_other_users=flag_template_buffer_has_other_users, + epilogue_nodes=epilogue_nodes, + **kwargs, + ) + return kernel, render + + return CppTemplateCaller( + kernel_hash_name, + self.name, + self.input_nodes, + self.output_node[0].get_layout() + if isinstance(self.output_node, Iterable) + else self.output_node.get_layout(), + make_kernel_render, + bmreq, + self, + ) + + def header(self) -> IndentedBuffer: + res = IndentedBuffer() + res.writeline("#include ") + # TODO: add c10::ForcedUnroll test to test_aoti_abi_check + res.splice("""#include """) + res.splice("""#include """) + enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [ + "linux", + "win32", + ] + if enable_kernel_profile: + res.writelines(["#include "]) + return res + + def render(self, **kwargs) -> str: + raise NotImplementedError diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_template_kernel.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_template_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..4694a4d0d097c4bc3c0f10fd27f07b3bf81fb1e2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_template_kernel.py @@ -0,0 +1,597 @@ +# mypy: allow-untyped-defs +import itertools +from collections.abc import Iterable +from typing import Any, Callable, Optional, Union + +import sympy +from sympy.parsing.sympy_parser import parse_expr + +import torch +from torch._inductor.utils import do_bench_using_profiling +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.symbol import SymT + +from .. 
import config, cpp_builder, ir, lowering as L +from ..autotune_process import CppBenchmarkRequest +from ..loop_body import LoopBody +from ..select_algorithm import PartialRender +from ..utils import sympy_index_symbol, sympy_index_symbol_with_prefix +from ..virtualized import V +from .common import REMOVED +from .cpp import CppKernel, CppKernelProxy, KernelGroup +from .cpp_utils import cexpr_index, DTYPE_TO_CPP, LocalBufferContext + + +def parse_expr_with_index_symbols(expr): + if isinstance(expr, sympy.Expr): + return expr + elif isinstance(expr, (list, tuple)): + return [parse_expr_with_index_symbols(e) for e in expr] + else: + expr = parse_expr(str(expr)) + int_symbols = {sym: sympy_index_symbol(sym.name) for sym in expr.free_symbols} + return expr.subs(int_symbols) + + +def wrap_with_tensorbox(node) -> ir.TensorBox: + return ( + ir.TensorBox.create(node) if isinstance(node, ir.Buffer) else ir.TensorBox(node) + ) + + +class CppTemplateKernel(CppKernel): + def __init__(self, kernel_name, num_threads): + super().__init__(None, num_threads) + self.kernel_name = kernel_name + self.render_hooks = {} + self.local_buffers = {} + + def render(self, template, **kwargs): + return PartialRender( + template.render(kernel=self, **kwargs), self.render_hooks + ).finalize_all() + + def def_kernel( + self, + inputs: dict[str, ir.Buffer], + outputs: dict[str, ir.Buffer], + aliases: Optional[dict[str, str]] = None, + function_name: str = "", + extra_sizevars: Optional[list[sympy.Expr]] = None, + placeholder: str = "", + ) -> str: + if len(function_name) == 0: + function_name = str(self.kernel_name) + for name, inp in inputs.items(): + if inp is not None: + self.args.input_buffers[inp.get_name()] = name + for name, out in outputs.items(): + self.args.output_buffers[out.get_name()] = name + if aliases is not None: + for alias, orig in aliases.items(): + if orig in self.args.input_buffers: + self.args.input_buffers[alias] = self.args.input_buffers[orig] + if orig in self.args.output_buffers: + self.args.output_buffers[alias] = self.args.output_buffers[orig] + + unique_sizevars = OrderedSet( + s + for input in inputs.values() + if input is not None + for sym in itertools.chain(input.get_size(), input.get_stride()) + if isinstance(sym, sympy.Expr) + for s in sym.free_symbols + ) + unique_sizevars.update( + s + for sym in extra_sizevars or [] + if isinstance(sym, sympy.Expr) + for s in sym.free_symbols + ) + unique_sizevars.update( + s + for output in outputs.values() + for sym in itertools.chain(output.get_size(), output.get_stride()) + if isinstance(sym, sympy.Expr) + for s in sym.free_symbols + ) + sizevars = sorted(unique_sizevars, key=str) + for sizevar in sizevars: + self.args.sizevars[sizevar] = f"k{sizevar}" + + def hook(): + # remove all aliases before generate function definition + if aliases is not None: + for alias in aliases: + if alias in self.args.input_buffers: + raise AssertionError( + f"input_buffers cannot be removed: {alias}" + ) + if alias in self.args.output_buffers: + self.args.output_buffers[alias] = REMOVED + cpp_argdefs, _, _ = self.args.cpp_argdefs() + return f"void {function_name}({', '.join(cpp_argdefs)})" + + assert placeholder not in self.render_hooks + self.render_hooks[placeholder] = hook + return placeholder + + def call_kernel(self, name: str, node: ir.CppTemplateBuffer): + wrapper = V.graph.wrapper_code + _, call_args, arg_types = self.args.cpp_argdefs() + wrapper.generate_kernel_call(name, call_args, triton=False, arg_types=arg_types) + + def dtype(self, node: ir.Buffer) -> 
str: + return DTYPE_TO_CPP[node.get_dtype()] + + def acc_dtype(self, node: ir.Buffer) -> str: + if node.get_dtype() in [torch.float32, torch.bfloat16, torch.half]: + return "float" + else: + raise NotImplementedError(f"Unsupported dtype: {node.get_dtype()}") + + def size(self, node: ir.Buffer, dim: int) -> str: + return cexpr_index(self.rename_indexing(node.get_size()[dim])) + + def stride(self, node: ir.Buffer, dim: int) -> str: + return cexpr_index(self.rename_indexing(node.get_stride()[dim])) + + def index(self, node: ir.Buffer, indices: list[Any]) -> str: + indexer = node.get_layout().as_fixed().make_indexer() + index = indexer(parse_expr_with_index_symbols(indices)) + index = self.rename_indexing(index) + outer_name = node.get_name() + inner_name = ( + outer_name + if outer_name in self.local_buffers + else self.args.input(node.get_name()) + ) + return f"{inner_name}[{cexpr_index(index)}]" + + def slice_nd(self, node, ranges: list[tuple[Any, Any]]) -> ir.ReinterpretView: + """ + Slice the given node with a list of ranges (start and end) corresponding to its dims. + The dim is not sliced if the corresponding range is empty. + """ + assert len(ranges) == len(node.get_size()), f"{ranges=}, {node=}" + sliced = wrap_with_tensorbox(node) + for dim, _range in enumerate(ranges): + if len(_range) == 0: + continue + assert len(_range) == 2 + start, end = parse_expr_with_index_symbols(_range) + sliced = L.slice_(sliced, dim, start, end, clamp=False) + assert isinstance(sliced.data, ir.ReinterpretView), sliced.data + return sliced.data + + def select(self, node, dim: int, idx: int) -> ir.ReinterpretView: + # We avoid using L.select here because we need clamp=False so the dim after slicing + # is 1 instead of a sympy expression of symbol - dim_size. + node = wrap_with_tensorbox(node) + idx = ir.View.handle_negative_index(idx, node.get_size()[dim]) + sliced = L.squeeze(L.slice_(node, dim, idx, idx + 1, clamp=False), dim) + assert isinstance(sliced.data, ir.ReinterpretView), sliced.data + return sliced.data + + def view(self, node, sizes: list[Any]) -> ir.View: + node = wrap_with_tensorbox(node) + sizes = parse_expr_with_index_symbols(sizes) + return L.view(node, sizes).data + + def permute(self, node, dims): + node = wrap_with_tensorbox(node) + permuted = L.permute(node, dims).data + assert isinstance(permuted, ir.ReinterpretView) + return permuted + + def maybe_codegen_profile(self) -> str: + if config.cpp.enable_kernel_profile: + graph_id = V.graph.graph_id + prefix = "graph_" + str(graph_id) + "_" if graph_id is not None else "" + return f'RECORD_FUNCTION("{prefix}{self.kernel_name}", c10::ArrayRef({{}}));' + else: + return "" + + def unroll_pragma(self, unroll): + if cpp_builder.is_gcc(): + return f"#pragma GCC unroll {unroll}" + else: + return f"#pragma unroll {unroll}" + + def define_buffer(self, name, sizes: list[Any], dtype=torch.float) -> str: + """Define kernel local buffer""" + sizes = parse_expr_with_index_symbols(sizes) + buf = ir.Buffer( + name=name, layout=ir.FixedLayout(torch.device("cpu"), dtype, sizes) + ) + self.local_buffers[name] = buf + ctype = f"{DTYPE_TO_CPP[dtype]}" + numel = f"{cexpr_index(buf.get_numel())}" + return f"auto _{name} = std::make_unique<{ctype}[]>({numel}); auto {name} = _{name}.get();" + + def define_stack_allocated_buffer( + self, name, sizes: list[Any], dtype=torch.float + ) -> str: + """Define stack-allocated buffer""" + sizes = parse_expr_with_index_symbols(sizes) + buf = ir.Buffer( + name=name, layout=ir.FixedLayout(torch.device("cpu"), dtype, sizes) + 
) + self.local_buffers[name] = buf + ctype = f"{DTYPE_TO_CPP[dtype]}" + numel = f"{cexpr_index(buf.get_numel())}" + return f"alignas(64) {ctype} _{name}[{numel}]; {ctype}* {name} = _{name};" + + def reinit_buffer_if_null(self, name): + """Reinit the previously defined local buffer if it is null""" + assert name in self.local_buffers + buf = self.local_buffers[name] + ctype = f"{DTYPE_TO_CPP[buf.layout.dtype]}" + numel = f"{cexpr_index(buf.get_numel())}" + return f"if (_{name} == nullptr) {{ _{name} = std::make_unique<{ctype}[]>({numel}); {name} = _{name}.get(); }}" + + def release_buffer(self, name): + """Codegen the code to release the ownership of a local buffer to others""" + assert name in self.local_buffers + return f"_{name}.release()" + + def store_pointwise_nodes( + self, + dst: ir.Buffer, + nodes: list[ir.IRNode], + offsets: Optional[list[sympy.Expr]] = None, + reindexers: Optional[list[Optional[Callable[[list[Any]], list[Any]]]]] = None, + ) -> str: + var_sizes = (tuple(dst.get_size()), ()) + var_ranges = { + sympy_index_symbol_with_prefix(SymT.INDEX, i): sz + for i, sz in enumerate(var_sizes[0]) + } + if not offsets: + offsets = [sympy.S.Zero] * len(var_sizes[0]) + if not reindexers: + reindexers = [None] * len(nodes) + assert len(offsets) == len(var_sizes[0]) + output_index = dst.get_layout().make_indexer()([*var_ranges.keys()]) + kernel_group = KernelGroup() + kernel_group.args = self.args + cpp_kernel_proxy = CppKernelProxy(kernel_group) + bodies = [] + var_sizes_list = [] + for i, node in enumerate(nodes): + output_name = node.get_name() if i < len(nodes) - 1 else dst.get_name() + node = node.data if isinstance(node, ir.ComputedBuffer) else node + assert isinstance(node, ir.Pointwise), node + + def fn(*args): + assert len(args) == 2 + assert len(args[0]) == len(var_sizes[0]) + assert len(args[1]) == 0 + new_args = [arg + offset for arg, offset in zip(args[0], offsets)] # type: ignore[arg-type] + if reindexers[i] is not None: + new_args = reindexers[i](new_args) # type: ignore[misc] + V.ops.store( + output_name, + output_index, + node.make_loader()(new_args).value, + ) + + body = LoopBody( + fn, + (list(var_ranges.keys()), ()), + var_ranges, + list(var_ranges.keys()), + tuple(), + ) + bodies.append(body) + var_sizes_list.append(var_sizes) + + cpp_kernel_proxy.codegen_loop_bodies(bodies, var_sizes_list) + kernel_group.finalize_kernel(cpp_kernel_proxy, []) + return kernel_group.loops_code.getvalue() + + def store_grouped_gemm_pointwise_nodes( + self, + dst: tuple[ir.Buffer], + nodes: list[ir.IRNode], + offsets: list[sympy.Expr], + reindexers: list[Optional[Callable[[list[Any]], list[Any]]]], + output_names: list[str], + ) -> str: + ref_dst = dst[0] + var_sizes = (tuple(ref_dst.get_size()), ()) + var_ranges = { + sympy_index_symbol_with_prefix(SymT.INDEX, i): sz + for i, sz in enumerate(var_sizes[0]) + } + assert offsets, "offsets should be set outside" + assert all(len(offset) == len(var_sizes[0]) for offset in offsets) + output_index = ref_dst.get_layout().make_indexer()([*var_ranges.keys()]) + kernel_group = KernelGroup() + kernel_group.args = self.args + cpp_kernel_proxy = CppKernelProxy(kernel_group) + bodies = [] + var_sizes_list = [] + for i, node in enumerate(nodes): + output_name = output_names[i] + node = node.data if isinstance(node, ir.ComputedBuffer) else node + assert isinstance(node, ir.Pointwise), node + + def fn(*args): + assert len(args) == 2 + assert len(args[0]) == len(var_sizes[0]) + assert len(args[1]) == 0 + new_args = [arg + offset for arg, offset in 
zip(args[0], offsets[i])] # type: ignore[arg-type] + if reindexers[i] is not None: + new_args = reindexers[i](new_args) # type: ignore[misc] + V.ops.store( + output_name, + output_index, + node.make_loader()(new_args).value, + ) + + body = LoopBody( + fn, + (list(var_ranges.keys()), ()), + var_ranges, + list(var_ranges.keys()), + tuple(), + ) + bodies.append(body) + var_sizes_list.append(var_sizes) + + cpp_kernel_proxy.codegen_loop_bodies(bodies, var_sizes_list) + kernel_group.finalize_kernel(cpp_kernel_proxy, []) + return kernel_group.loops_code.getvalue() + + def store_output( + self, + dst: ir.Buffer, + src: ir.Buffer, + orig_src: Optional[ir.Buffer] = None, + epilogue_nodes: Optional[list[ir.IRNode]] = None, + offsets: Optional[list[Any]] = None, + reindexers: Optional[list[Optional[Callable[[list[Any]], list[Any]]]]] = None, + ): + """ + Store the `src` buffer to the `dst` buffer. The size of `src` and `dst` should match. + If `epilogue_nodes` is provided, the `src` buffer is firstly computed with the epilogues + before stored to `dst`. The `epilogues_nodes` are all pointwise. + + Notes: + 1. `src` and `dst` buffer could be the same buffer in which case we are doing in-place compute + and stores. In case `epilogue_nodes` are not provided, we do nothing. + 2. The `epilogue_nodes`, if exist, have computations on `src` before storing to `dst` but since + they come form the original Inductor IR, they might need to be adjusted before working with + `src` and `dst` as outlined below: + a) `src` or `dst` buffer could be a sub-slice of the ranges the `epilogue_nodes`work on. + In this case, the `offsets` could be provided to adjust the indices passed to + `epilogue_nodes` during codegen and the data ranges are also configured according to + the sizes of `src` and `dst`. + b) `dst` might be indexed in a different way as the `epilogue_nodes`, hence a `reindexer` is + needed on the indices to `epilogue_nodes` to match the indexing of `dst`. + c) If `src` is local, we need to add a local buffer for it and localize the `orig_src` buffer + in `epilogue_nodes` with `src`. 
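+            Illustrative example (hypothetical names and shapes): if the template computes a
+            local [32, 64] tile `src` corresponding to the slice of a larger `dst` that starts
+            at (m, n), the kernel iterates over the tile and `offsets=[m, n]` shifts those
+            tile-local indices back into the global index space the `epilogue_nodes` were
+            written against; a `reindexer` can additionally reorder the indices when `dst` is
+            indexed differently from the epilogue's iteration order.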
+ """ + assert isinstance(dst, (ir.Buffer, ir.ReinterpretView)) + assert dst.get_size() == src.get_size(), f"{dst=}, {src=}" + if offsets: + offsets = parse_expr_with_index_symbols(offsets) + if epilogue_nodes: + with LocalBufferContext(self.args) as scope: + assert orig_src is not None + if orig_src.get_name() != src.get_name(): + scope.add_local_buffer( + src, + [ + orig_src, + ], + ) + epilogue_nodes = scope.localize_nodes(epilogue_nodes) + return self.store_pointwise_nodes( + dst, + epilogue_nodes, # type: ignore[arg-type] + offsets, + reindexers, + ) + else: + if dst.get_name() != src.get_name(): + # src is local + copy = L.copy(dst, src).data.data + with LocalBufferContext(self.args) as scope: + scope.add_local_buffer(src) + return self.store_pointwise_nodes(dst, [copy]) + else: + assert dst.layout == src.layout, f"{dst=}, {src=}" + return "" + + def store_outputs( + self, + dst: tuple[ir.Buffer], + src: tuple[ir.IRNode], + orig_src: Optional[tuple[ir.IRNode]] = None, + epilogue_nodes: Optional[list[ir.IRNode]] = None, + offsets: Optional[list[Any]] = None, + reindexers: Optional[list[Optional[Callable[[list[Any]], list[Any]]]]] = None, + multi_output_buffers: Optional[tuple[ir.MultiOutput]] = None, + ): + assert isinstance(dst, Iterable) + assert all(_dst.get_size() == _src.get_size() for _src, _dst in zip(src, dst)) + if offsets: + offsets = parse_expr_with_index_symbols(offsets) + gemm_num = len(src) + final_offsets = [] + output_names = [] + if epilogue_nodes: + if not reindexers: + reindexers = [None] * len(epilogue_nodes) + with LocalBufferContext(self.args) as scope: + assert orig_src is not None + localize_epilogue_nodes = [] + all_read_names = [] + for epilogue in epilogue_nodes: + all_read_names.extend(list(epilogue.get_read_names())) + localize_epilogue_nodes.extend(scope.localize_nodes(epilogue_nodes)) + final_offsets.extend([offsets] * len(localize_epilogue_nodes)) + output_names.extend( + [node.get_name() for node in localize_epilogue_nodes] + ) + for gemm_idx in range(gemm_num): + if orig_src[gemm_idx].get_name() != src[gemm_idx].get_name(): + if orig_src[gemm_idx].get_name() in all_read_names or ( + multi_output_buffers + and multi_output_buffers[gemm_idx].get_name() + in all_read_names + ): + # If any of the Epilogue nodes use this GEMM output, let's localize the GEMM output + global_buffers = [orig_src[gemm_idx]] + if ( + multi_output_buffers + and multi_output_buffers[gemm_idx].get_name() + in all_read_names + and orig_src[gemm_idx].get_name() not in all_read_names + ): + # Epilogue might directly read the MultiOutput, Locallize MultiOutput to the local Buffer + # if this MultiOutput has not been stored by in-template epilogue + # otherwise, use the cse store cache if it will be stored before used + global_buffers.append(multi_output_buffers[gemm_idx]) + scope.add_local_buffer( + src[gemm_idx], + global_buffers, + ) + else: + scope.add_local_buffer(src[gemm_idx]) + localize_epilogue_nodes.extend( + [L.copy(dst[gemm_idx], src[gemm_idx]).data.data] + ) + reindexers.append(None) + output_names.append(dst[gemm_idx].get_name()) + final_offsets.append( + [sympy.S.Zero] * len(dst[gemm_idx].get_size()) + ) + res = self.store_grouped_gemm_pointwise_nodes( + dst, + localize_epilogue_nodes, + final_offsets, + reindexers, + output_names=output_names, + ) + for gemm_idx in range(gemm_num): + if ( + multi_output_buffers + and multi_output_buffers[gemm_idx].get_name() in all_read_names + ): + # If the MultiOutput is used in the Epilogue, let's remove it from args + 
multi_output_name = multi_output_buffers[gemm_idx].get_name() + if ( + multi_output_name in self.args.output_buffers + and self.args.output_buffers[multi_output_name] + is not REMOVED + ): + self.remove_buffer(multi_output_name) + return res + else: + if dst[0].get_name() != src[0].get_name(): + copy_list = [] + with LocalBufferContext(self.args) as scope: + for _src, _dst in zip(src, dst): + copy_list.extend([L.copy(_dst, _src).data.data]) + scope.add_local_buffer(_src) + output_names.append(_dst.get_name()) + final_offsets.append([sympy.S.Zero] * len(_dst.get_size())) + reindexers = [None] * len(copy_list) + return self.store_grouped_gemm_pointwise_nodes( + dst, + nodes=copy_list, + offsets=final_offsets, + reindexers=reindexers, + output_names=output_names, + ) + else: + assert all( + _src.get_name() == _dst.get_name() for _src, _dst in zip(src, dst) + ) + assert all( + _src.get_layout() == _dst.get_layout() + for _src, _dst in zip(src, dst) + ) + return "" + + def check_bounds(self, expr, size, lower, upper): + # CppTemplateKernel does not need codegen related operations + return + + +class CppTemplateCaller(ir.ChoiceCaller): + """ + CppTemplateCaller + + This class represents a caller for CPP template kernels. It is a subclass of ir.ChoiceCaller. + Attributes: + name (str): The name of the caller. + category (str): The category of the caller. + bmreq (CppBenchmarkRequest): The benchmark request for the caller. + template_buffer (ir.CppTemplateBuffer): The template buffer for the caller. + """ + + def __init__( + self, + name: str, + category: str, + input_nodes: list[ir.Buffer], + layout: ir.Layout, + make_kernel_render: Callable[ + [ + ir.CppTemplateBuffer, + bool, + Optional[list[ir.IRNode]], + ], + str, + ], + bmreq: CppBenchmarkRequest, + template: "CppTemplate", # type: ignore[name-defined] # noqa: F821 + info_kwargs: Optional[ + dict[str, Union[ir.PrimitiveInfoType, list[ir.PrimitiveInfoType]]] + ] = None, + ): + super().__init__(name, input_nodes, layout, description="") + self.category = category + self.make_kernel_render = make_kernel_render + self.bmreq = bmreq + self.template = template + self.info_kwargs = info_kwargs + + def precompile(self) -> None: + assert self.bmreq is not None + self.bmreq.precompile() + + def benchmark(self, *args, out) -> float: + assert self.bmreq is not None + if config.profile_bandwidth_with_do_bench_using_profiling: + algo = self.bmreq.make_run_fn(*args, out=out) + return do_bench_using_profiling(algo) + return self.bmreq.benchmark(*args, out=out) + + def hash_key(self) -> str: + return "-".join( + [ + self.category, + self.bmreq.hash_key, + ] + ) + + def info_dict( + self, + ) -> dict[str, Union[ir.PrimitiveInfoType, list[ir.PrimitiveInfoType]]]: + return {"backend": "CPP", "op_type": "unknown"} + + def output_node(self) -> ir.TensorBox: + return ir.TensorBox.create( + ir.CppTemplateBuffer( + layout=self.layout, + inputs=self.input_nodes, + make_kernel_render=self.make_kernel_render, + template=self.template, + choice=self, + ) + ) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_utils.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bc0316bd3ff8a5ba83397ec4c71d1e7d7f5b865f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_utils.py @@ -0,0 +1,776 @@ +# mypy: allow-untyped-defs +import contextlib +import dataclasses +import functools +import math +import sys +from collections import 
namedtuple +from collections.abc import Sequence +from typing import Any, Callable, Optional +from unittest.mock import patch + +import sympy + +import torch +from torch._prims_common import is_integer_dtype +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.printers import CppPrinter as _CppPrinter +from torch.utils._sympy.symbol import symbol_is_type, SymT +from torch.utils._sympy.value_ranges import ValueRanges + +from .. import ir +from ..dependencies import Dep +from ..loop_body import LoopBody +from ..scheduler import BaseSchedulerNode, SchedulerBuffer +from ..utils import IndentedBuffer, sympy_index_symbol_with_prefix, sympy_subs +from ..virtualized import ops, OpsValue, V +from .common import CSEVariable, Kernel, KernelArgs, OptimizationContext + + +DTYPE_TO_CPP = { + torch.float32: "float", + torch.float64: "double", + torch.float16: "at::Half", + torch.int64: "int64_t", + torch.int32: "int32_t", + torch.int16: "int16_t", + torch.int8: "int8_t", + torch.uint64: "uint64_t", + torch.uint32: "uint32_t", + torch.uint16: "uint16_t", + torch.uint8: "uint8_t", + torch.bool: "bool", + torch.bfloat16: "at::BFloat16", + torch.complex32: "at::complex<at::Half>", + torch.complex64: "at::complex<float>", + torch.complex128: "at::complex<double>", + torch.float8_e4m3fn: "at::Float8_e4m3fn", + torch.float8_e5m2: "at::Float8_e5m2", + torch.float8_e4m3fnuz: "at::Float8_e4m3fnuz", + torch.float8_e5m2fnuz: "at::Float8_e5m2fnuz", +} + +DTYPE_TO_ATEN = { + torch.float32: "at::kFloat", + torch.float64: "at::kDouble", + torch.float16: "at::kHalf", + torch.int64: "at::kLong", + torch.int32: "at::kInt", + torch.int16: "at::kShort", + torch.int8: "at::kChar", + torch.uint64: "at::kUInt64", + torch.uint32: "at::kUInt32", + torch.uint16: "at::kUInt16", + torch.uint8: "at::kByte", + torch.bool: "at::kBool", + torch.bfloat16: "at::kBFloat16", + torch.complex32: "at::kComplexHalf", + torch.complex64: "at::kComplexFloat", + torch.complex128: "at::kComplexDouble", + torch.float8_e4m3fn: "at::kFloat8_e4m3fn", + torch.float8_e5m2: "at::kFloat8_e5m2", + torch.float8_e4m3fnuz: "at::kFloat8_e4m3fnuz", + torch.float8_e5m2fnuz: "at::kFloat8_e5m2fnuz", +} + +DEVICE_TO_ATEN = { + "meta": "at::kMeta", + "cpu": "at::kCPU", + "cuda": "at::kCUDA", + "xpu": "at::kXPU", + "mps": "at::kMPS", +} + +LAYOUT_TO_ATEN = { + torch.strided: "at::kStrided", + torch._mkldnn: "at::kMkldnn", # type: ignore[attr-defined] +} + +# matches c10/core/DeviceType.h +DEVICE_TO_INT = {"cpu": 0, "cuda": 1} + +_IS_WINDOWS = sys.platform == "win32" + +INDEX_TYPE = "int64_t" + +GemmBlocking = namedtuple("GemmBlocking", ["block_m", "block_n", "block_k"]) + + +def get_promote_dtype(args): + return ( + functools.reduce( + torch.promote_types, # type: ignore[arg-type] + [n.dtype for n in args if isinstance(n, CppCSEVariable)], + ) + if all(n.dtype is not None for n in args if isinstance(n, CppCSEVariable)) + else None # not enough info to calculate the promote dtype + ) + + +def promote_args(new_args): + def promote_arg(arg, promote_type): + if ( + isinstance(arg, CppCSEVariable) + and arg.dtype + and promote_type + and arg.dtype != promote_type + ): + arg = ops.to_dtype(arg, promote_type) + arg = arg.value if isinstance(arg, OpsValue) else arg + arg.dtype = promote_type + return arg + + promote_type = get_promote_dtype(new_args) + promote_fn = functools.partial( + promote_arg, + promote_type=promote_type, + ) + if ( + all( + new_arg.dtype is not None + for new_arg in new_args + if 
isinstance(new_arg, CppCSEVariable) + ) + and promote_type + ): + new_args = list(map(promote_fn, new_args)) + return new_args + + +class CppCSEVariable(CSEVariable): + def __init__( + self, + name, + bounds: ValueRanges[Any], + dtype: Optional[torch.dtype] = None, + ) -> None: + super().__init__(name, bounds, dtype) + self.is_vec = False + self.dependent_itervars = OrderedSet[sympy.Symbol]() + + def __repr__(self) -> str: + return ( + f"CppCSEVariable(name: {self.name}, bounds: {self.bounds}, is_vec: {self.is_vec}, dtype: {self.dtype}, " + f"dependent_itervars: {self.dependent_itervars})" + ) + + def update_on_args(self, name, args, kwargs): + if name == "load": + # args[2] is index + self._set_dependent_itervars(args[2]) + else: + # propagate relevant itervars and is_vec from args + self.dependent_itervars.update( + *[ + arg.dependent_itervars + for arg in args + if isinstance(arg, CppCSEVariable) + ] + ) + if name == "index_expr": + self._set_dependent_itervars(args[0]) + if any(arg.is_vec for arg in args if isinstance(arg, CppCSEVariable)): + self.is_vec = True + + def _set_dependent_itervars(self, index: sympy.Expr): + """ + Set the relevant itervars for this variable based on the `index` expression. + This includes the itervars directly used in the `index` as well as relevant itervars + of other cse variables used in the `index`. + """ + for s in index.free_symbols: + if s in V.kernel.itervars: + self.dependent_itervars.add(s) # type: ignore[arg-type] + elif s.name in V.kernel.cse.varname_map: # type: ignore[attr-defined] + self.dependent_itervars.update( + V.kernel.cse.varname_map[s.name].dependent_itervars # type: ignore[attr-defined] + ) + + def depends_on(self, itervar: sympy.Symbol): + return itervar in self.dependent_itervars + + +class CppPrinter(_CppPrinter): + def doprint(self, expr, *, simplify: bool = True, p=True): + # TODO: why are people passing strings to the printer here :think: + if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"): + expr = V.graph.sizevars.simplify(expr) + return super().doprint(expr) + + +# A function to print, useful for printing sympy symbols. 
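+# Illustrative usage (hypothetical expression; exact output formatting may differ):
+#   cexpr(sympy.Symbol("x0") * 2 + 1)  ->  "2*x0 + 1"
+# cexpr_index below additionally wraps the printed expression in a static_cast<int64_t>(...).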
+cexpr = CppPrinter().doprint + + +def cexpr_index(index): + return f"static_cast<{INDEX_TYPE}>({cexpr(index)})" + + +def value_to_cpp(value, cpp_type): + if value == float("-inf"): + return f"-std::numeric_limits<{cpp_type}>::infinity()" + elif value == float("inf"): + return f"std::numeric_limits<{cpp_type}>::infinity()" + elif isinstance(value, bool): + return f"static_cast<{cpp_type}>({str(value).lower()})" + elif math.isnan(value): + return f"std::numeric_limits<{cpp_type}>::quiet_NaN()" + else: + return f"static_cast<{cpp_type}>({repr(value)})" + + +def rewrite_index_for_function( + localize_buffer_handler: "LocalizeBufferHandler", + index: sympy.Expr, + global_buf_name: str, +): + # Local buffer at the inner dimensions + snode = V.graph.scheduler.name_to_buf[global_buf_name].defining_op + assert snode is not None + local_buf = localize_buffer_handler.global_to_local[global_buf_name] + scheduler_nodes = snode.get_nodes() + _, (group, reduction_group) = max( + scheduler_nodes, key=lambda x: int(x.is_reduction()) + ).group + call_ranges = tuple(group) + tuple(reduction_group) + indices_to_keep = [ + f"x{len(call_ranges) - (idx + 1)}" + for idx in range(len(local_buf.get_layout().size)) + ] + sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name) # type: ignore[attr-defined] + replacements = {} + for x in sorted_symbols: + if x.name.startswith("x") and x.name not in indices_to_keep: # type: ignore[attr-defined] + # Only keep index used by local buffer + replacements[x] = sympy.core.numbers.Zero() + index = sympy_subs(index, replacements) # type: ignore[arg-type] + return index + + +def rewrite_index_for_nodes( + localize_buffer_handler: "LocalizeBufferHandler", + index: sympy.Expr, + global_buf_name: str, +): + used_vars = OrderedSet( + s for s in index.free_symbols if symbol_is_type(s, SymT.INDEX) + ) + index_vars = [] + local_buf = localize_buffer_handler.global_to_local[global_buf_name] + for i in range(len(local_buf.get_size())): + var = sympy_index_symbol_with_prefix(SymT.INDEX, i) + index_vars.append(var if var in used_vars else 0) + index = local_buf.get_layout().make_indexer()(index_vars) + return index + + +class LocalizeBufferHandler(V.WrapperHandler): # type: ignore[name-defined] + def __init__( + self, + inner, + global_to_local: dict[str, ir.Buffer], + rewrite_index: Callable[["LocalizeBufferHandler", sympy.Expr, str], sympy.Expr], + ) -> None: + super().__init__(inner) + self.global_to_local = global_to_local + self.rewrite_index = rewrite_index + + def localize(self, name: str, index: sympy.Expr): + if self.global_to_local and name in self.global_to_local: + assert self.rewrite_index is not None + index = self.rewrite_index(self, index, name) + name = self.global_to_local[name].get_name() + return name, index + + def load(self, name: str, index: sympy.Expr): + return self._inner.load(*self.localize(name, index)) + + def store(self, name, index, value, mode=None): + local_buffer_name, local_buffer_index = self.localize(name, index) + res = self._inner.store(local_buffer_name, local_buffer_index, value, mode) + if ( + self.global_to_local + and name in self.global_to_local + and isinstance(V.kernel, Kernel) + ): + # Remove name of local buffer from Kernel.store_buffer_names + # local_buffer_name is added to Kernel.store_buffer_names in Kernel.CSEProxy.store. 
+ V.kernel.store_buffer_names.discard(local_buffer_name) + return res + + def store_reduction(self, name, index, value): + return self._inner.store_reduction(*self.localize(name, index), value) + + +class LocalBufferContext: + """ + This class creates a context that helps to generate code involving Inductor IR with + function local buffers. These buffers are constructed during the codegen process and + are used to store intermediate results such as local accumulators. We do not want to + add them to `V.graph` since they are not global and we do not want to add them as + function arguments either. So we patch the codegen processes under this scope to support + these buffers without exposure to the outside world. + """ + + def __init__(self, kernel_args: KernelArgs) -> None: + self.kernel_args = kernel_args + self.exit_stack = contextlib.ExitStack() + # map local buffer name to local buffer + self.local_buffers: dict[str, ir.Buffer] = {} + # map global buffer name to global buffer + self.global_buffers: dict[str, ir.Buffer] = {} + # map global buffer name to local buffer + self.global_to_local: dict[str, ir.Buffer] = {} + # record the global buffers that are removed by this LocalBufferContext + self.removed_buffers: OrderedSet[str] = OrderedSet() + + def __enter__(self): + self.exit_stack.__enter__() + original_get_dtype = V.graph.get_dtype + + def get_dtype(name): + if name in self.local_buffers: + return self.local_buffers[name].get_dtype() + return original_get_dtype(name) + + self.exit_stack.enter_context(patch.object(V.graph, "get_dtype", get_dtype)) + + original_input = self.kernel_args.input + + def input(name): + if name in self.local_buffers: + return name + return original_input(name) + + self.exit_stack.enter_context(patch.object(self.kernel_args, "input", input)) + + original_output = self.kernel_args.output + + def output(name): + if name in self.local_buffers: + return name + return original_output(name) + + self.exit_stack.enter_context(patch.object(self.kernel_args, "output", output)) + + # Set current LocalBufferContext into V + self.exit_stack.enter_context(V.set_local_buffer_context(self)) + + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.local_buffers.clear() + self.exit_stack.__exit__(exc_type, exc_val, exc_tb) + + def add_local_buffer( + self, local_buffer: ir.Buffer, global_buffers: Optional[list[ir.Buffer]] = None + ): + assert local_buffer.get_name() not in self.local_buffers + self.local_buffers[local_buffer.get_name()] = local_buffer + if global_buffers: + for global_buffer in global_buffers: + global_buffer_name = global_buffer.get_name() + assert ( + global_buffer_name not in self.global_buffers + and global_buffer_name not in self.global_to_local + ) + self.global_buffers[global_buffer_name] = global_buffer + self.global_to_local[global_buffer_name] = local_buffer + if global_buffer_name not in V.graph.removed_buffers: + # Record the global buffers that are removed by this LocalBufferContext + # since which may need to restore. 
Refer to issue: + # https://github.com/pytorch/pytorch/issues/144186 + self.removed_buffers.add(global_buffer_name) + V.graph.removed_buffers.add(global_buffer_name) + + def localize_function( + self, + fn: Callable[..., Any], + rewrite_index: Callable[ + ["LocalizeBufferHandler", sympy.Expr, str], sympy.Expr + ] = rewrite_index_for_function, + ): + def inner(*args, **kwargs): + with V.set_ops_handler( + LocalizeBufferHandler( + V.get_ops_handler(), + global_to_local=self.global_to_local, + rewrite_index=rewrite_index, + ) + ): + return fn(*args, **kwargs) + + return inner + + def localize_nodes( + self, + nodes: list[ir.IRNode], + rewrite_index: Callable[ + ["LocalizeBufferHandler", sympy.Expr, str], sympy.Expr + ] = rewrite_index_for_nodes, + ) -> list[ir.IRNode]: + """ + Given `local_buf` and `global_buf` registered in current `LocalBufferContext` + though the method of `add_local_buffer`, localizes the `global_buf` to `local_buf` + for the given `nodes` and returns a new list of IR nodes that work on `local_buf` + instead of `global_buf`, i.e., all the loads and stores are redirected to + `local_buf`. This helps the fused loops to work on smaller-sized local buffers + for better data locality. + + The the data access of `local_buf` is assumed to be contiguous with the + same order as the `global_buf`. + """ + assert len(nodes) > 0 + + def wrap_inner_fn_for_node(node: ir.IRNode): + loops = node.data if isinstance(node, ir.ComputedBuffer) else node + assert isinstance(loops, ir.Loops) + new_inner_fn = self.localize_function( + loops.inner_fn, + rewrite_index, + ) + + new_loops = dataclasses.replace(loops, inner_fn=new_inner_fn) + if isinstance(node, ir.ComputedBuffer): + new_node = ir.ComputedBuffer( + name=node.get_name(), layout=node.get_layout(), data=new_loops + ) + else: + new_node = new_loops # type: ignore[assignment] + + return new_node + + return [wrap_inner_fn_for_node(node) for node in nodes] + + +def unify_mask_base_type( + buffer: IndentedBuffer, + vars: tuple[CSEVariable, ...], + dtype=torch.float, +): + """ + Given list of cse variables, + Cast each to new mask base dtype and return casted cse variable. + """ + new_vars = ( + V.kernel.cse.generate( + buffer, + f"{V.kernel._get_mask_cast(var, dtype)}", + ) + for var in vars + ) + return new_vars + + +def may_unify_binary_op_mask_type(a, b): + """ + Given two cse variables, when dtype is bool, unify them to the same mask dtype and return casted cse variable. 
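+    For illustration (hypothetical case): when `a` and `b` are boolean masks produced by two
+    vectorized comparisons, both are cast to the int32 mask base type chosen below so the
+    binary op is generated against a single mask representation.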
+ """ + if a.dtype == torch.bool: + assert b.dtype == torch.bool + mask_dtype = torch.int32 + return unify_mask_base_type(V.kernel.compute, (a, b), mask_dtype) + return a, b + + +def codegen_rand(offset, code, rand_function, dst_dtype=torch.float32): + assert is_integer_dtype(offset.dtype) + code.writeline("[&]()") + with code.indent(): + code.writeline( + f"{DTYPE_TO_CPP[offset.dtype]} offset[{V.kernel.tiling_factor}];" + ) + code.writeline(f"{DTYPE_TO_CPP[dst_dtype]} result[{V.kernel.tiling_factor}];") + code.writeline(f"{offset}.store(offset);") + code.writeline( + f"for( {DTYPE_TO_CPP[offset.dtype]} offset_idx = 0; offset_idx < {V.kernel.tiling_factor}; offset_idx++ )" + ) + with code.indent(): + code.writeline(rand_function) + num_vectors = V.kernel._get_num_vectors(dtype=dst_dtype) + if num_vectors == 1: + code.writeline( + f"return at::vec::Vectorized<{DTYPE_TO_CPP[dst_dtype]}>::loadu(result);" + ) + else: + code.writeline( + f"return at::vec::VectorizedN<{DTYPE_TO_CPP[dst_dtype]}, {num_vectors}>::loadu(result);" + ) + code.writeline("()") + return code + + +def get_gemm_template_output_and_compute_dtype(input_dtype): + if input_dtype in [torch.uint8, torch.int8]: + return (torch.int32, torch.int32) + else: + return (torch.float32, torch.float32) + + +def create_epilogue_with_attr(input_buffer, attr, **kwargs): + input_loader = input_buffer.make_loader() + dtype = input_buffer.get_dtype() + if attr == "relu": + + def inner_fn(index): + input = input_loader(index) + zero = ops.constant(0, dtype) + return ops.maximum(input, zero) + + elif attr == "gelu": + assert "algorithm" in kwargs + if kwargs["algorithm"] == "none": + + def inner_fn(index): + input = input_loader(index) + if dtype != torch.float: + input = ops.to_dtype(input, torch.float) + half = ops.constant(0.5, torch.float) + one = ops.constant(1.0, torch.float) + const = ops.constant(0.7071067811865476, torch.float) + result = input * half * (ops.erf(input * const) + one) + if dtype != torch.float: + result = ops.to_dtype(result, dtype) + return result + + else: + assert kwargs["algorithm"] == "tanh" + + def inner_fn(index): + input = input_loader(index) + if dtype != torch.float: + input = ops.to_dtype(input, torch.float) + half = ops.constant(0.5, torch.float) + one = ops.constant(1.0, torch.float) + const1 = ops.constant(0.7978845608028654, torch.float) + const2 = ops.constant(0.044715, torch.float) + result = ( + half + * input + * ( + one + + ops.tanh(const1 * (input + const2 * input * input * input)) + ) + ) + if dtype != torch.float: + result = ops.to_dtype(result, dtype) + return result + + elif attr == "swish": + + def inner_fn(index): + input = input_loader(index) + result = input * ops.sigmoid(input) + return result + + elif attr == "sigmoid": + + def inner_fn(index): + return ops.sigmoid(input_loader(index)) + + elif attr == "tanh": + + def inner_fn(index): + return ops.tanh(input_loader(index)) + + elif attr == "hardswish" or attr == "hardsigmoid": + + def hardsigmoid_float(input): + zero = ops.constant(0, torch.float) + six = ops.constant(6, torch.float) + three = ops.constant(3, torch.float) + one_over_six = ops.constant(0.16666666666666666, torch.float) + max = ops.maximum(input + three, zero) + min = ops.minimum(max, six) + return min * one_over_six + + def inner_fn(index): + input = input_loader(index) + if dtype != torch.float: + input = ops.to_dtype(input, torch.float) + result = hardsigmoid_float(input) + if attr == "hardswish": + result = input * result + if dtype != torch.float: + result = 
ops.to_dtype(result, dtype) + return result + + elif attr == "leaky_relu": + assert "scalars" in kwargs + assert len(kwargs["scalars"]) == 1 + negative_slope = kwargs["scalars"][0] + + def inner_fn(index): + input = input_loader(index) + if dtype != torch.float: + input = ops.to_dtype(input, torch.float) + zero = ops.constant(0, torch.float) + result = ops.where( + input > zero, input, input * ops.constant(negative_slope, torch.float) + ) + if dtype != torch.float: + result = ops.to_dtype(result, dtype) + return result + + elif attr == "hardtanh": + assert "scalars" in kwargs + assert len(kwargs["scalars"]) == 2 + min_value = kwargs["scalars"][0] + max_value = kwargs["scalars"][1] + + def inner_fn(index): + input = input_loader(index) + if dtype != torch.float: + input = ops.to_dtype(input, torch.float) + result = ops.minimum( + ops.maximum(input, ops.constant(min_value, torch.float)), + ops.constant(max_value, torch.float), + ) + if dtype != torch.float: + result = ops.to_dtype(result, dtype) + return result + + elif attr in ["add", "sub", "mul"]: + assert "other" in kwargs + other = kwargs["other"] + num_input_dims = len(input_buffer.get_size()) + num_other_dims = len(other.get_size()) + dims_diff = num_input_dims - num_other_dims + other_loader = other.make_loader() + + def inner_fn(index): + op = getattr(ops, attr) + if dims_diff != 0: + return op(input_loader(index), other_loader(index[dims_diff:])) + else: + return op(input_loader(index), other_loader(index)) + + elif attr == "bias_add": + assert "other" in kwargs + assert "beta" in kwargs + assert "dtype" in kwargs + beta = kwargs["beta"] + other = kwargs["other"] + dtype = kwargs["dtype"] + bias_loader = other.make_loader() + + def inner_fn(index): + bias = bias_loader(index) + input = input_loader(index) + if beta != 1: + result = ops.constant(beta, torch.float) * bias + input + else: + result = bias + input + return result + + else: + raise ValueError(f"Unsupported epilogue attribute: {attr}") + return ir.Pointwise( + device=input_buffer.get_device(), + dtype=dtype, + inner_fn=inner_fn, + ranges=input_buffer.get_size(), + ) + + +def _get_loop_body(fn_list): + if all(isinstance(fn, LoopBody) for fn in fn_list): + loop_bodies = fn_list + else: + if hasattr(fn_list[0], "original_fn"): + # For the case of local buffer, we wrap the fn with localize_function + assert all(hasattr(fn, "original_fn") for fn in fn_list) + assert all( + isinstance(fn.original_fn.args[0]._body, LoopBody) for fn in fn_list + ) + loop_bodies = [fn.original_fn.args[0]._body for fn in fn_list] + else: + assert all(isinstance(fn, functools.partial) for fn in fn_list) + assert all(isinstance(fn.args[0]._body, LoopBody) for fn in fn_list) + loop_bodies = [fn.args[0]._body for fn in fn_list] + assert loop_bodies is not None + return loop_bodies + + +def _get_dtype_from_loopbodies(loop_bodies): + dtypes = OrderedSet[torch.dtype]() + for loop_body in loop_bodies: + graphs = [loop_body.root_block.graph] + [ + body.graph for body in list(loop_body.subblocks.values()) + ] + for graph in graphs: + for node in graph.nodes: + if node.op != "call_method": + continue + dtypes.add(node.meta[OptimizationContext.key].dtype) + return dtypes + + +def template_fusion_with_epilogues_supported( + template: BaseSchedulerNode, epilogues: list[BaseSchedulerNode] +) -> tuple[bool, bool]: + def _get_indexes_of_template_buf_read( + epilogue_node: ir.Operation, template_buf_names: list[str] + ) -> list[sympy.Expr]: + return [ + read.index + for read in epilogue_node.get_reads() + if 
read.name in template_buf_names + ] + + def _check_supported_and_same_indexes( + index_of_template_buf_read: Sequence[sympy.Expr], + epilogue_writes: OrderedSet[Dep], + ) -> tuple[bool, bool]: + num_indexes = len(OrderedSet(index_of_template_buf_read)) + + if num_indexes > 1: + same_index = False + supported = False # Different read indexes not supported + elif num_indexes == 0: + same_index = True + supported = True # No reads, automatically supported + elif num_indexes == 1: + iotbr = index_of_template_buf_read[0] + same_index = all(write.index == iotbr for write in epilogue_writes) + # TODO: Add support of fusion when the read of template buffer and the write of epilogue output + # in the epilogue node don't have the same index and change supported to True + supported = same_index + else: + raise AssertionError("Should not reach here") + + return supported, same_index + + def _template_fusion_supported( + template_outputs: Sequence[SchedulerBuffer], epilogue_nodes: list[ir.Operation] + ) -> tuple[bool, bool]: + template_buf_names = [x.get_name() for x in template_outputs] + indexes_of_template_buf_reads = [ + _get_indexes_of_template_buf_read(epilogue_node, template_buf_names) + for epilogue_node in epilogue_nodes + ] + epilogue_nodes_writes = [ + epilogue_node.get_read_writes().writes for epilogue_node in epilogue_nodes + ] + + results = [ + _check_supported_and_same_indexes(reads, writes) + for reads, writes in zip( + indexes_of_template_buf_reads, epilogue_nodes_writes + ) + ] + supported, same_indexes = zip(*results) + return all(supported), all(same_indexes) + + assert template.is_template() + template_outputs = template.get_outputs() + + epilogue_nodes = [ + n.node + for epilogue in epilogues + for n in epilogue.get_nodes() + if n.node is not None + ] + return _template_fusion_supported(template_outputs, epilogue_nodes) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu.py new file mode 100644 index 0000000000000000000000000000000000000000..08c6a586e98676ee05ef1116af8715fa77a53a83 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -0,0 +1,2747 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import ctypes +import functools +import math +import os +import sys +import textwrap +from itertools import chain, count +from typing import Any, Callable, Optional, Protocol, TYPE_CHECKING, Union + +import sympy + +import torch +import torch._higher_order_ops.torchbind +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +import torch._ops +from torch._inductor.runtime.runtime_utils import dynamo_timed +from torch.fx.experimental.symbolic_shapes import ConvertIntKey, DivideByKey, SymTypes +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.symbol import symbol_is_type, SymT + +from .. 
import config, ir +from ..utils import _align, DeferredLineBase, LineContext, normalize_name +from ..virtualized import V +from .aoti_hipify_utils import maybe_hipify_code_wrapper +from .common import get_device_op_overrides, IndentedBuffer, Kernel +from .cpp_utils import cexpr, DEVICE_TO_ATEN, DEVICE_TO_INT, DTYPE_TO_ATEN, DTYPE_TO_CPP +from .wrapper import ( + EnterSubgraphLine, + ExitSubgraphLine, + PythonWrapperCodegen, + SymbolicCallArg, +) + + +if TYPE_CHECKING: + from collections.abc import Sequence + + from ..graph import GraphLowering + + # At most, the list nesting can go one layer deep. + _OUTPUT_ARGS_TYPE = list[Union[Optional[str], list[Optional[str]]]] + + +class HasWriteLine(Protocol): + def writeline(self, line: Union[LineContext, DeferredLineBase, str]) -> None: ... + + +class CppWrapperCpu(PythonWrapperCodegen): + """ + Generates cpp wrapper for running on CPU and calls cpp kernels + """ + + def __init__(self): + if not hasattr(self, "device"): + self.device = "cpu" + # must be initialized prior to calling super().__init__() + self.included_devices: OrderedSet[str] = OrderedSet() + super().__init__() + self.declare = "auto " + self.declare_maybe_reference = "decltype(auto) " + self.ending = ";" + self.comment = "//" + self.none_str = "nullptr" + self.supports_intermediate_hooks = False + self.kernel_callsite_id = count() + self.int_array_id = count() # for int array local variable declarations + self.declared_int_array_vars: OrderedSet[str] = OrderedSet() + self.tmp_tensor_id = count() # for tmp tensor local variable declarations + self.arg_var_id = count() + self.used_cached_devices: OrderedSet[str] = OrderedSet() + self.used_cached_dtypes: OrderedSet[str] = OrderedSet() + self.used_cached_layouts: OrderedSet[str] = OrderedSet() + self.used_cached_memory_formats: OrderedSet[str] = OrderedSet() + self.used_cond_predicate: OrderedSet[str] = OrderedSet() + self.cached_output_id = count() + self.scalar_to_tensor_id = count() + self.custom_op_wrapper_loaded = False + # For GEMM kernels that must be initialized and are resolved at linking. + self.initialized_kernels: dict[str, Kernel] = {} + self.device_codegen = get_device_op_overrides(self.device) + # only need to include each header once + self.include_extra_header = functools.lru_cache(None)( # type: ignore[method-assign] + self._include_extra_header + ) + + @staticmethod + def create( + is_subgraph: bool, + subgraph_name: Optional[str], + parent_wrapper: Optional[PythonWrapperCodegen], + partition_signatures: Optional[ir.GraphPartitionSignature] = None, + ): + # TODO - support subgraph codegen by lifting functions. Check the + # comment at CppWrapperCpu `codegen_subgraph` function. + return CppWrapperCpu() + + @staticmethod + def _generate_temporary_array_pointer( + c_type: str, elements: Sequence[str], *, force_mutable: bool = False + ) -> str: + """Get a pointer to an array that only exists for the duration of the C++ + statement it's used in.""" + # If the c_type is already a pointer, return a mutable pointer to the array. + # Otherwise, return a const pointer. In the C-shim API, pointer types are only + # const-qualified with respect to the underlying value, not any nested pointers. + # e.g. const double** is possible, but not const double* const*. This means + # that an array containing pointers must _already_ be properly const-qualified + # by the c_type, and not add additional const-ness. 
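+        # Illustrative examples (hypothetical arguments):
+        #   _generate_temporary_array_pointer("int64_t", ["4", "8"])
+        #       -> "std::array<int64_t, 2>{4, 8}.cbegin()"
+        #   _generate_temporary_array_pointer("const double*", ["p0", "p1"])
+        #       -> "std::array<const double*, 2>{p0, p1}.data()"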
+ ptr_call = "data()" if force_mutable or c_type.endswith("*") else "cbegin()" + return ( + f"std::array<{c_type}, {len(elements)}>{{{', '.join(elements)}}}.{ptr_call}" + ) + + def _generate_kernel_call_helper( + self, + kernel_name: str, + call_args, + *, + device=None, + triton=True, + arg_types=None, + raw_keys=None, + raw_args=None, + triton_meta=None, + graph_name="", + original_fxnode_name=None, + ): + """ + Generates kernel call code. + + triton: Defines whether the GPU backend uses Triton for codegen. + Otherwise it uses the CUDA language for codegen. + Only valid when cuda == True. + """ + assert arg_types is not None and len(call_args) == len(arg_types), ( + "Mismatch call_args and arg_types in generate_kernel_call:\n" + f"call_args: {call_args}\n" + f"arg_types: {arg_types}" + ) + new_args = [] + for idx, arg in enumerate(call_args): + if "*" in arg_types[idx]: + new_args.append(f"({arg_types[idx]})({arg}.data_ptr())") + else: + # arg is a scalar + new_args.append(arg) + # debug printer related logic for cpp kernel type. + debug_printer_manager = V.graph.wrapper_code.debug_printer + debug_printer_manager.set_printer_args( + call_args, + kernel_name, + None, + None, + "cpp", + ) + with debug_printer_manager: + self.writeline(self.wrap_kernel_call(kernel_name, new_args)) + + def write_constant(self, name, hashed): + # include a hash so our code cache gives different constants different files + self.header.writeline(f"// {name} {hashed}") + + @staticmethod + def get_device_include_path(device: str) -> str: + if V.graph.aot_mode: + return f"#include " + return f"#include " + + def add_device_include(self, device: str) -> None: + if device in self.included_devices: + return + + self.included_devices.add(device) + + # Add the default header for this device, plus any C-shim extensions that are + # present. + self.header.splice(self.get_device_include_path(device)) + extend_aoti_c_shim_include = ( + f"torch/csrc/inductor/aoti_torch/generated/extend/c_shim_{self.device}.h" + ) + extend_aoti_c_shim_path = os.path.join( + os.path.dirname(torch.__file__), + "include", + extend_aoti_c_shim_include, + ) + if os.path.exists(extend_aoti_c_shim_path): + self.header.splice(f"#include <{extend_aoti_c_shim_include}>") + + def write_header(self): + if V.graph.is_const_graph: + # We do not write header for constant graph, it will be written by main module. + return + + if not V.graph.aot_mode: + self.header.splice( + """ + import torch + from torch._inductor.codecache import CppWrapperCodeCache + + cpp_wrapper_src = ( + r''' + """ + ) + + self.add_device_include(self.device) + + if V.graph.aot_mode: + with open( + os.path.join(os.path.dirname(__file__), "aoti_runtime", "interface.cpp") + ) as f: + self.header.splice(f.read()) + self.header.splice("\n") + + enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [ + "linux", + "win32", + ] + if config.profiler_mark_wrapper_call or enable_kernel_profile: + # No C shim for profiling APIs, assuming profiling is a debugging feature which + # does not provide any ABI compatibility promise. 
+ self.header.splice("#include ") + + def _include_extra_header(self, header: str): + # This is needed for cpp to python dtype conversion + self.header.splice(f"#include <{header}>") + + def mark_output_type(self): + # mark output type to unwrap tensor back to python scalar + from ..ir import ShapeAsConstantBuffer + + output_is_tensor = {} + for idx, x in enumerate(V.graph.graph_outputs): + if isinstance(x, ShapeAsConstantBuffer): + output_is_tensor[idx] = False + else: + output_is_tensor[idx] = True + + self.output_is_tensor = output_is_tensor + + def write_prefix(self): + if V.graph.is_const_graph: + # We do not write prefix for constant graph, it will be written by main module. + return + if config.aot_inductor.custom_ops_to_c_shims: + # custom_ops_to_c_shims contains declaration of custom ops with C shim. + # TODO: this could be auto-generated from a passed-in custom op schema + custom_c_shims = list( + chain(*config.aot_inductor.custom_ops_to_c_shims.values()) + ) + declarations = "\n".join( + [f"extern {textwrap.dedent(shim)};" for shim in custom_c_shims] + ) + self.prefix.splice( + f""" + extern "C" {{ + {declarations} + }} + """ + ) + if V.graph.aot_mode: + self.prefix.writeline("namespace torch::aot_inductor {") + + def write_input_output_info( + self, + info_kind: str, + idx: int, + name: str, + ): + self.prefix.writeline(f"""{info_kind}[{idx}].name = "{name}";""") + + def codegen_input_symbol_assignment( + self, + name: str, + value: ir.TensorBox, + bound_vars: OrderedSet[sympy.Symbol], + ): + code = self.prefix + + @functools.cache + def sizeof(name): + self.codegen_input_size_var_decl(code, name) + return f"{name}_size" + + @functools.cache + def strideof(name): + self.codegen_input_stride_var_decl(code, name) + return f"{name}_stride" + + def codegen_symbol( + sym_or_exp: Union[sympy.Symbol, sympy.Expr], + base_name: str, + name_fn: Callable[[str], str], + dim: int, + ): + if isinstance(sym_or_exp, sympy.Symbol): + if sym_or_exp in bound_vars: + return + code.writeline(f"int64_t {sym_or_exp} = {name_fn(base_name)}[{dim}];") + bound_vars.add(sym_or_exp) + elif isinstance(sym_or_exp, sympy.Expr): + undefined_symbols = [ + sym for sym in sym_or_exp.free_symbols if sym not in bound_vars + ] + if len(undefined_symbols) != 1: + # Skip if expression contains no symbols or if multiple + # symbols exists since we assume each base symbol is defined + # by other codegen_symbol calls. 
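+                    # Illustrative example (hypothetical input "arg0_1"): a dim of symbolic
+                    # size s0 + 1 is handled below by emitting
+                    #   int64_t arg0_1_size_0 = arg0_1_size[0];
+                    # and solving Eq(s0 + 1, arg0_1_size_0) for s0, whereas a dim of size
+                    # s0*s1 with both symbols still unknown is skipped here.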
+ return + + from torch.utils._sympy.solve import try_solve + + free_symbol = undefined_symbols.pop() + base_name = name_fn(base_name) + # Use a size symbol to solve the free symbol + size_symbol = sympy.Symbol(f"{base_name}_{dim}", integer=True) + code.writeline(f"int64_t {size_symbol} = {base_name}[{dim}];") + solution = try_solve(sympy.Eq(sym_or_exp, size_symbol), free_symbol) + if solution is not None: + code.writeline(f"int64_t {free_symbol} = {cexpr(solution[1])};") + bound_vars.add(free_symbol) + else: + raise AssertionError( + str(sympy.Eq(sym_or_exp, size_symbol)) + " is not solvable" + ) + + if isinstance(value, sympy.Expr): + if not isinstance(value, sympy.Symbol) or value in bound_vars: + return + if value.is_integer: + decl = "int64_t" + elif value.is_float: + decl = "double" + else: + raise AssertionError("Unexpected symbol type") + code.writeline(f"{decl} {value} = {name};") + bound_vars.add(value) + elif isinstance(value, ir.TensorBox): + for dim, size in enumerate(value.get_size()): + codegen_symbol(size, name, sizeof, dim) + for dim, stride in enumerate(value.get_stride()): + codegen_symbol(stride, name, strideof, dim) + elif isinstance(value, ir.TorchBindObject): + # torchbind objects are loaded in proxy executor + pass + else: + raise AssertionError(f"Unknown value type: {type(value)}") + + def generate_input_output_runtime_checks(self): + """ + In debug_compile mode, we generate checks to ensure the dtype/shape/stride/device of each + real input/output tensor match ones provided at compile time via sample + input/output. + """ + + def gen_check(handle_kind, idx, name, tensor): + # Wrap AtenTensorHandle with ConstantHandle for cleaner utility function access + self.prefix.writeline( + f"ConstantHandle {name} = ConstantHandle({handle_kind}[{idx}]);" + ) + self.codegen_tensor_dtype_var_decl(self.prefix, name) + expected_dtype_name = DTYPE_TO_ATEN[tensor.dtype] + dtype_str = str(tensor.dtype).split(".")[-1] + self.prefix.splice( + f""" + int32_t {name}_expected_dtype = aoti_torch_dtype_{dtype_str}(); + if ({name}_expected_dtype != {name}_dtype) {{ + std::stringstream ss; + ss << "{handle_kind}[{idx}]: unmatched dtype, " + << "expected: " << {name}_expected_dtype << "({expected_dtype_name}), " + << "but got: " << {name}_dtype << "\\n"; + throw std::runtime_error(ss.str()); + }} + """ + ) + self.codegen_input_size_var_decl(self.prefix, name) + for dim_idx, d in enumerate(tensor.get_size()): + if isinstance(d, (int, sympy.Integer)): + self.prefix.splice( + f""" + if ({d} != {name}_size[{dim_idx}]) {{ + std::stringstream ss; + ss << "{handle_kind}[{idx}]: unmatched dim value at {dim_idx}, " + << "expected: {d}, " << "but got: " << {name}_size[{dim_idx}] + << "\\n"; + throw std::runtime_error(ss.str()); + }} + """ + ) + else: + from torch.utils._sympy.value_ranges import bound_sympy + + sym_range = bound_sympy(d, V.graph.sizevars.shape_env.var_to_range) + if not math.isinf(sym_range.lower): + self.prefix.splice( + f""" + if ({name}_size[{dim_idx}] < {sym_range.lower}) {{ + std::stringstream ss; + ss << "{handle_kind}[{idx}]: dim value is too small at {dim_idx}, " + << "expected it to be >= {sym_range.lower}, " << "but got: " + << {name}_size[{dim_idx}] << "\\n"; + throw std::runtime_error(ss.str()); + }} + """ + ) + if not math.isinf(sym_range.upper): + # Limit upper bound to max C long long value (2^63 - 1) + max_long_long = ctypes.c_longlong(2**63 - 1).value + upper_bound = min(sym_range.upper, max_long_long) + self.prefix.splice( + f""" + if ({name}_size[{dim_idx}] > 
{upper_bound}) {{ + std::stringstream ss; + ss << "{handle_kind}[{idx}]: dim value is too large at {dim_idx}, " + << "expected to be <= {upper_bound}, " << "but got: " + << {name}_size[{dim_idx}] << "\\n"; + throw std::runtime_error(ss.str()); + }} + """ + ) + + self.codegen_input_stride_var_decl(self.prefix, name) + for stride_idx, s in enumerate(tensor.get_stride()): + if not isinstance(s, (int, sympy.Integer)): + continue + self.prefix.splice( + f""" + if ({s} != {name}_stride[{stride_idx}]) {{ + std::stringstream ss; + ss << "{handle_kind}[{idx}]: unmatched stride value at {stride_idx}, " + << "expected: {s}, " << "but got: " << {name}_stride[{stride_idx}] + << "\\n"; + throw std::runtime_error(ss.str()); + }} + """ + ) + + # check input device type + if isinstance(tensor, ir.TensorBox): + tensor_device = tensor.get_device() + if tensor_device is not None: + expected_device_type = DEVICE_TO_INT.get(tensor_device.type) + if expected_device_type is not None: + self.codegen_input_device_type_var_decl(self.prefix, name) + device_type_str = str(tensor_device.type) + self.prefix.splice( + f""" + int32_t {name}_expected_device_type = {expected_device_type}; + if ({name}_expected_device_type != {name}_device_type) {{ + std::stringstream ss; + ss << "{handle_kind}[{idx}]: unmatched device type, " + << "expected: " << {name}_expected_device_type << "{expected_device_type}({device_type_str}), " + << "but got: " << {name}_device_type << "\\n"; + throw std::runtime_error(ss.str()); + }} + """ + ) + + # Create a separate function for each input check to avoid "too big to optimize" error + for idx, (name, tensor) in enumerate(V.graph.graph_inputs.items()): + self.prefix.splice( + f""" + AOTI_NOINLINE static void check_input_{idx}( + AtenTensorHandle* input_handles + ) {{ + """ + ) + with self.prefix.indent(): + gen_check("input_handles", idx, name, tensor) + self.prefix.writeline("}") + + # force noinline to avoid any potential compilation slowdown due to aggressive + # inline done by the host compiler + self.prefix.splice( + """ + static bool _check_aoti_runtime_check_inputs_env() { + const static char* env_var_value = getenv("AOTI_RUNTIME_CHECK_INPUTS"); + const static bool result = env_var_value != nullptr && env_var_value[0] != '0'; + return result; + } + + AOTI_NOINLINE static void __check_inputs_outputs( + AtenTensorHandle* input_handles, + AtenTensorHandle* output_handles) { + if (!_check_aoti_runtime_check_inputs_env()){ + return; + } + """ + ) + with self.prefix.indent(): + for idx in range(len(V.graph.graph_inputs)): + self.prefix.writeline(f"check_input_{idx}(input_handles);") + self.prefix.writeline("}") + + def write_wrapper_decl(self): + inputs_len = len(V.graph.graph_inputs.keys()) + if V.graph.aot_mode: + if V.graph.const_module: + self.header.splice(V.graph.const_module.wrapper_code.header) + + assert V.graph.const_wrapper_code is not None + self.prefix.splice(V.graph.const_wrapper_code) + + assert V.graph.const_kernel_code is not None + self.kernel_declarations.splice(V.graph.const_kernel_code) + + if V.graph.is_const_graph: + self.prefix.splice( + """ + void AOTInductorModel::_const_run_impl( + std::vector& output_handles, + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor + ) { + """ + ) + else: + if not config.aot_inductor.use_runtime_constant_folding: + # If we do not split the constant graph, we'll just create + # an empty implementation when wrapping the main module. 
+ self.prefix.splice( + """ + void AOTInductorModel::_const_run_impl( + std::vector& output_handles, + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor + ) {} + + """ + ) + + run_impl_proto = """ + void AOTInductorModel::run_impl( + AtenTensorHandle* + input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + AtenTensorHandle* + output_handles, // array for writing output AtenTensorHandle; handles + // will be stolen by the caller; the array itself is + // borrowed + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor + ) { + __check_inputs_outputs(input_handles, output_handles); + """ + + self.generate_input_output_runtime_checks() + self.prefix.splice(run_impl_proto) + else: + # cpp entry function for JIT with cpp wrapper + self.prefix.splice( + """ + void inductor_entry_impl( + AtenTensorHandle* + input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + AtenTensorHandle* + output_handles // array for writing output AtenTensorHandle; handles + // will be stolen by the caller; the array itself is + // borrowed) + ) { + """ + ) + with self.prefix.indent(): + # assign inputs and outputs in both cases so the later codegen can be simplified + if not V.graph.is_const_graph: + if V.graph.aot_mode: + num_args = len(V.graph.graph_inputs) + else: + # Weights are promoted in the JIT mode + num_args = len(V.graph.graph_inputs) + len(V.graph.constants) + # release GIL to support multiple instances inference (in different threads of the same process) + self.prefix.splice("py::gil_scoped_release release;") + + self.prefix.splice( + f""" + auto inputs = steal_from_raw_handles_to_raii_handles(input_handles, {num_args}); + """ + ) + + if inputs_len != 0: + for idx, input_key in enumerate(V.graph.graph_inputs.keys()): + # unwrap input tensor back to scalar + if isinstance(V.graph.graph_inputs[input_key], sympy.Expr): + from ..graph import may_get_constant_buffer_dtype + + dtype = may_get_constant_buffer_dtype( + V.graph.graph_inputs[input_key] # type: ignore[arg-type] + ) + assert dtype is not None, ( + "Fails to get the dtype of the sympy.Expr" + ) + self.codegen_tensor_item( + dtype, f"inputs[{idx}]", input_key, self.prefix + ) + else: + self.prefix.writeline( + f"auto {input_key} = std::move(inputs[{idx}]);" + ) + # debug printing for all input args to AOTI model + debug_printer_manager = V.graph.wrapper_code.debug_printer + debug_printer_manager.codegen_model_inputs_value_print( + input_args_to_print=[ + input_key + for input_key in V.graph.graph_inputs.keys() + if input_key.startswith("arg") + ] + ) + + assert all( + isinstance(v, torch.Tensor) for v in list(V.graph.constants.values()) + ), "Expect all constants to be Tensor" + for idx, constants_key in enumerate(V.graph.constants.keys()): + if V.graph.aot_mode: + # Weights are stored in constants_ and owned by RAIIAtenTensorHandle there. + # Don't call std::move here because it will cause constants_ to lose the ownership. 
+ self.prefix.writeline( + f"""[[maybe_unused]] auto {constants_key} = constants_->at({idx});""" + ) + else: + # Append constants as inputs to the graph + constants_idx = inputs_len + idx + self.prefix.writeline( + f"[[maybe_unused]] auto {constants_key} = std::move(inputs[{constants_idx}]);" + ) + + self.codegen_inputs() + + if V.graph.aot_mode: + if not V.graph.is_const_graph: + self.prefix.writeline("inputs.clear();") + self.prefix.writeline( + "[[maybe_unused]] auto& kernels = static_cast(*this->kernels_.get());" + ) + + def codegen_tensor_dtype_var_decl(self, code: IndentedBuffer, name): + code.writeline(f"int32_t {name}_dtype;") + code.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype({name}, &{name}_dtype));" + ) + + def codegen_input_size_var_decl(self, code: IndentedBuffer, name): + code.writeline(f"auto {name}_size = {name}.sizes();") + + def codegen_input_stride_var_decl(self, code: IndentedBuffer, name): + code.writeline(f"auto {name}_stride = {name}.strides();") + + def codegen_input_device_type_var_decl(self, code: IndentedBuffer, name): + code.writeline(f"int32_t {name}_device_type;") + code.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type({name}, &{name}_device_type));" + ) + + def codegen_model_kernels(self): + self.prefix.writeline("namespace {") + + # Tell compiler we need to link with the non-mangled symbols + for kernel in self.initialized_kernels.values(): + assert hasattr(kernel, "get_signature"), ( + f"{kernel} must have get_signature implemented" + ) + signature = kernel.get_signature() + self.prefix.writeline(f'extern "C" {signature};') + + self.prefix.writeline( + "class AOTInductorModelKernels : public AOTInductorModelKernelsBase {" + ) + self.prefix.writeline(" public:") + declare_kernel = OrderedSet(self.src_to_kernel.values()) - OrderedSet( + self.initialized_kernels.keys() + ) + declare_kernel.update( + entry[0] for entry in self.user_defined_kernel_cache.values() + ) + if V.graph.const_module: + declare_kernel.update( + V.graph.const_module.wrapper_code.src_to_kernel.values() + ) + for kernel in sorted(declare_kernel): + self.prefix.writeline( + maybe_hipify_code_wrapper( + f" {self.device_codegen.cpp_kernel_type()} {kernel}{{nullptr}};" + ) + ) + for name, kernel in self.initialized_kernels.items(): + assert hasattr(kernel, "get_signature"), ( + f"{kernel} must have get_signature implemented" + ) + kernel_ptr = f"(*{name})" + signature = kernel.get_signature().replace(name, kernel_ptr) + self.prefix.writeline(f" {signature} = torch::aot_inductor::{name};") + self.prefix.writeline("};") + self.prefix.writeline("} // namespace\n\n") + + if config.aot_inductor.embed_kernel_binary: + self.prefix.writeline('extern "C" {') + for name in sorted(declare_kernel): + self.prefix.writeline( + f" extern const unsigned char __{name}_start[];" + ) + if torch.xpu.is_available(): + self.prefix.writeline( + f" extern const unsigned char __{name}_end[];" + ) + self.prefix.writeline("}") + + def codegen_model_constructor(self): + """ + // Generated code example + AOTInductorModel::AOTInductorModel() + : AOTInductorModelBase(4, 1) { + inputs_info_[0].name = "input0"; + inputs_info_[0].dtype = "torch.float16"; + ... + constants_info_[0].name = "L__self___weight"; + constants_info_[0].dtype = at::kFloat; + constants_info_[0].offset = 0; + constants_info_[0].data_size = 8192; + constants_info_[0].shape = {64, 32}; + constants_info_[0].stride = {32, 1}; + ... 
+ outputs_info_[0].name = "output0"; + outputs_info_[0].dtype = "torch.float16"; + } + """ + + num_inputs = len(V.graph.graph_inputs) + num_outputs = len(V.graph.graph_outputs) + num_constants = len(V.graph.constants) + include_weights = ( + "true" if config.aot_inductor.package_constants_in_so else "false" + ) + self.prefix.splice( + f""" + AOTInductorModel::AOTInductorModel(std::shared_ptr constants_map, + std::shared_ptr> constants_array, + const std::string& device_str, + std::optional cubin_dir) + : AOTInductorModelBase({num_inputs}, + {num_outputs}, + {num_constants}, + device_str, + std::move(cubin_dir), + {include_weights}) {{ + """ + ) + + with self.prefix.indent(): + for idx, (name, inp) in enumerate(V.graph.graph_inputs.items()): + assert not isinstance(inp, sympy.Expr), ( + f"input {name=} cannot be symbolic" + ) + self.write_input_output_info("inputs_info_", idx, name) + + all_cuda = all( + V.graph.get_original_value_of_constant(name).is_cuda + for name in V.graph.constants.keys() + if name not in V.graph.folded_constants + ) + for idx, name in enumerate(V.graph.constants.keys()): + tensor = V.graph.get_original_value_of_constant(name) + assert isinstance(tensor, torch.Tensor) + self.prefix.writeline(f"""constants_info_[{idx}].name = "{name}";""") + self.prefix.writeline( + f"constants_info_[{idx}].dtype = static_cast({self.codegen_dtype(tensor.dtype)});" + ) + self.prefix.writeline( + f"constants_info_[{idx}].offset = {tensor.storage_offset()};" + ) + + # If constants to serialize contain cpu tensors, we always align data_size it to 64. + # When loading the constants, the valid data will depends on the size + # not the data_size so there won't be correctness issue. + data_size = ( + torch.ops.mkldnn._nbytes(tensor) + if tensor.is_mkldnn + else tensor.untyped_storage().nbytes() + ) + self.prefix.writeline( + f"constants_info_[{idx}].data_size = {data_size if all_cuda else _align(data_size)};" + ) + + from_folded = "true" if name in V.graph.folded_constants else "false" + self.prefix.writeline( + f"constants_info_[{idx}].from_folded = {from_folded};" + ) + + if name in V.graph.folded_constants: + constant_type_str = "FoldedConstant" + elif name.startswith("_tensor_constant"): + constant_type_str = "TensorConstant" + elif any( + name == normalize_name(parameter_name) + for parameter_name in V.graph.named_parameters + ): + constant_type_str = "Parameter" + elif any( + name == normalize_name(buffer_name) + for buffer_name in V.graph.named_buffers + ): + constant_type_str = "Buffer" + else: + constant_type_str = "Unknown" + self.prefix.writeline( + f"constants_info_[{idx}].type = static_cast(torch::aot_inductor::ConstantType::{constant_type_str});" + ) + + size_str = ", ".join([str(s) for s in tensor.size()]) + self.prefix.writeline(f"constants_info_[{idx}].shape = {{{size_str}}};") + + stride_str = ", ".join([str(s) for s in tensor.stride()]) + self.prefix.writeline( + f"constants_info_[{idx}].stride = {{{stride_str}}};" + ) + self.prefix.writeline( + f"constants_info_[{idx}].layout = static_cast({self.codegen_layout(tensor.layout)});" + ) + + if tensor.is_mkldnn: + opaque_metadata_tensor = torch.ops.mkldnn._get_mkldnn_serialized_md( + tensor + ) + assert opaque_metadata_tensor.dim() == 1, ( + "Expect opaque_metadata_tensor to be 1-D" + ) + + opaque_metadata_list = opaque_metadata_tensor.tolist() + opaque_metadata_str = self.codegen_shape_tuple(opaque_metadata_list) + self.prefix.writeline( + f"constants_info_[{idx}].opaque_metadata = {opaque_metadata_str};" + ) + if name in 
V.graph.dynamo_flat_name_to_original_fqn: + original_fqn = V.graph.dynamo_flat_name_to_original_fqn.get( + name, name + ) + elif name in V.graph.allocated_constant_name: + original_fqn = V.graph.allocated_constant_name[name] + else: + raise AssertionError("original_fqn must be set for constant") + self.prefix.writeline( + f"""constants_info_[{idx}].original_fqn = "{original_fqn}";""" + ) + self.prefix.writeline("update_constants_map(std::move(constants_map));") + self.prefix.writeline("update_constants_array(std::move(constants_array));") + + def escape_string(x): + return ( + x.replace("\\", "\\\\") + .replace('"', '\\"') + .replace("\n", "\\n") + .replace("\t", "\\t") + ) + + self.prefix.writeline( + f'in_spec_ = R"({config.aot_inductor.serialized_in_spec})";' + ) + self.prefix.writeline( + f'out_spec_ = R"({config.aot_inductor.serialized_out_spec})";' + ) + + for idx, output in enumerate(V.graph.graph_outputs): + assert not isinstance(output, sympy.Expr), ( + f"output {name=} cannot be symbolic" + ) + name = f"output{idx}" + self.write_input_output_info("outputs_info_", idx, name) + + self.prefix.writeline( + "this->kernels_ = std::make_unique();" + ) + + self.prefix.writeline("}") + + def codegen_const_run_driver(self): + """ + // Generated code example + std::unordered_map AOTInductorModel::const_run_impl( + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor, + bool initialization + ) { + std::unordered_map folded_constants_map; + std::vector output_handles; + // build up output_handles over here. + _const_run_impl(output_handles, stream, proxy_executor); + // build up folded_constants_map + return folded_constants_map; + } + """ + + self.prefix.splice( + """ + std::unordered_map AOTInductorModel::const_run_impl( + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor, + bool initialization + ) { + """ + ) + if not config.aot_inductor.use_runtime_constant_folding: + self.prefix.splice( + """ + if (!initialization) { + std::cerr << "[WARNING] Calling constant_folding in model, but compiled with config: " + << "aot_inductor.use_runtime_constant_folding=False\\n"; + } + return {}; + } + """ + ) + return + + with self.prefix.indent(): + # This is a mapping to the index of constant folding graph's output + const_index_mapping: list[Optional[tuple[int, str]]] = [None] * len( + V.graph.const_output_index + ) + for idx, (name, _) in enumerate(V.graph.constants.items()): + if name in V.graph.const_output_index: + const_index_mapping[V.graph.const_output_index[name]] = (idx, name) # type: ignore[call-overload] + assert None not in const_index_mapping, ( + "Not all constant gets mapped for constant folding graph." + ) + + self.prefix.writeline( + f""" + std::unordered_map folded_constants_map; + folded_constants_map.reserve({len(const_index_mapping)}); + std::vector output_handles({len(const_index_mapping)}); + """ + ) + + self.prefix.splice( + """ + // The below assignment of output_handles to constants is not used directly. + // It's only used to memo the correspondence of handle and constants. 
+ """ + ) + + for output_idx, (const_idx, _) in enumerate(const_index_mapping): # type: ignore[misc] + self.prefix.writeline( + f"output_handles[{output_idx}] = constants_->at({const_idx});" + ) + + self.prefix.writeline( + "_const_run_impl(output_handles, stream, proxy_executor);" + ) + + for output_idx, (_, const_name) in enumerate(const_index_mapping): # type: ignore[misc] + self.prefix.writeline( + f'folded_constants_map["{const_name}"] = output_handles[{output_idx}];' + ) + self.prefix.writeline("return folded_constants_map;") + + self.prefix.writeline("}") + + def generate(self, is_inference): + with dynamo_timed("CppWrapperCpu.generate", log_pt2_compile_event=True): + self.write_wrapper_decl() + return super().generate(is_inference) + + def finalize_prefix(self): + prior = self.prefix + self.prefix = aot_mode_decls = IndentedBuffer() + if V.graph.aot_mode and not V.graph.is_const_graph: + aot_mode_decls.writeline("namespace torch::aot_inductor {") + self.codegen_model_kernels() + self.codegen_model_constructor() + self.codegen_const_run_driver() + aot_mode_decls.writeline("} // namespace torch::aot_inductor") + aot_mode_decls.writeline("using namespace torch::aot_inductor;") + + self.prefix = cache_decls = IndentedBuffer() + for dtype in self.used_cached_dtypes: + cache_decls.writeline(f"CACHE_TORCH_DTYPE({dtype});") + for device in self.used_cached_devices: + cache_decls.writeline(f"CACHE_TORCH_DEVICE({device});") + for layout in self.used_cached_layouts: + cache_decls.writeline(f"CACHE_TORCH_LAYOUT({layout});") + for memory_format in self.used_cached_memory_formats: + cache_decls.writeline(f"CACHE_TORCH_MEMORY_FORMAT({memory_format});") + + self.prefix.splice(aot_mode_decls) + self.prefix.splice(prior) + + def _define_kernel_helper( + self, + kernel_name: str, + kernel_body: str, + metadata: Optional[str] = None, + gpu: bool = False, + cpp_definition: Optional[str] = None, + ): + if cpp_definition is not None: + self.header.splice(cpp_definition) + self.kernel_declarations.splice(f"\n{kernel_body}\n") + else: + self.header.splice(f"\n{kernel_body}\n") + + def codegen_scalar_to_tensor(self, output: str): + name = f"scalar_to_tensor_{next(self.scalar_to_tensor_id)}" + self.wrapper_call.writeline( + f"RAIIAtenTensorHandle {name} = scalar_to_tensor_handle({output});" + ) + return name + + def codegen_tensor_item( + self, dtype: torch.dtype, tensor: str, scalar: str, indented_buffer=None + ): + dtype_str = str(dtype).split(".")[-1] + writer = indented_buffer or self + + if dtype == torch.float16 or dtype == torch.bfloat16: + scalar_tmp = f"{scalar}_tmp" + writer.writeline(f"{DTYPE_TO_CPP[dtype]} {scalar_tmp};") + writer.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar_tmp}));" + ) + writer.writeline(f"float {scalar} = float({scalar_tmp});") + else: + writer.writeline(f"{DTYPE_TO_CPP[dtype]} {scalar};") + writer.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar}));" + ) + + def generate_return(self, output_refs: list[str]): + cst_names = V.graph.constants.keys() + output2idx: dict[str, int] = {} + + # If any output ref represents an rvalue tensor, materialize it to an lvalue + # RAIIAtenTensorHandle first. This prevents situations where the code for the + # rvalue tensor references tensor handles whose contents are modified below. 
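+ # A hypothetical rvalue output ref could look like
+ #   wrap_with_raii_handle_if_needed(reinterpret_tensor_wrapper(buf0, ...))
+ # i.e. an expression rather than a named handle; create_tmp_raii_handle_var_if_needed
+ # binds such an expression to a temporary RAIIAtenTensorHandle up front.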
+ output_refs = [ + self.create_tmp_raii_handle_var_if_needed(o, self.wrapper_call) + for o in output_refs + ] + + for idx, output in enumerate(output_refs): + if output == "nullptr": + continue + + is_constant_buffer = output in cst_names + output_buffer = V.graph.graph_outputs[idx] + if isinstance(output_buffer, ir.BaseView): + output_storage = output_buffer.unwrap_view() + if isinstance(output_storage.data, ir.ConstantBuffer): + is_constant_buffer = True + + if isinstance(output_buffer, ir.ShapeAsConstantBuffer): + # Need to wrap scalar into tensor as the main function returns a vector of tensors + output_tensor = self.codegen_scalar_to_tensor(output) + self.wrapper_call.writeline( + f"output_handles[{idx}] = {output_tensor}.release();" + ) + continue + + if is_constant_buffer: + # See NOTE(return_constant) above. + self.wrapper_call.writeline( + f"aoti_torch_clone({output}, &output_handles[{idx}]);" + ) + else: + if output in output2idx: + src_idx = output2idx[output] + self.wrapper_call.writeline( + f"output_handles[{idx}] = output_handles[{src_idx}];" + ) + else: + self.wrapper_call.writeline( + f"output_handles[{idx}] = {output}.release();" + ) + + if output not in output2idx: + output2idx[output] = idx + + def generate_before_suffix(self, result): + if not V.graph.is_const_graph: + if V.graph.aot_mode: + result.writeline("} // AOTInductorModel::run_impl") + else: + result.writeline("} // inductor_entry_impl") + + def generate_end(self, result): + """Generates the end of the code block, and any code needed to call it.""" + if V.graph.aot_mode: + if V.graph.is_const_graph: + result.writeline("} // AOTInductorModel::_const_run_impl") + else: + result.writeline("} // namespace torch::aot_inductor\n\n\n") + return + + if config.cpp_wrapper_build_separate: + # Close the wrapper code block, then write any kernel definitions. + result.splice("'''\n)") + if self.kernel_declarations: + result.splice("\nkernel_src = (\nr'''") + result.splice(self.kernel_declarations.getvalue()) + result.splice("'''\n)") + else: + result.splice( + """ + kernel_src = '' + """ + ) + else: + # Merge main code and kernel code + result.splice(self.kernel_declarations.getvalue()) + self.kernel_declarations.clear() + # Close the wrapper code block + result.splice("'''\n)") + + kernel_code = "kernel_src" if config.cpp_wrapper_build_separate else "None" + # Cpp entry function for JIT with cpp wrapper + result.splice( + f""" + inductor_entry = CppWrapperCodeCache.load_pybinding( + argtypes=["std::vector"], + main_code=cpp_wrapper_src, + device_type="{self.device}", + num_outputs={len(V.graph.graph_outputs)}, + kernel_code={kernel_code}, + ) + """ + ) + + wrapper_body = "input_tensors = [arg if isinstance(arg, torch.Tensor) else torch.tensor(arg) for arg in args]" + if V.graph.constants: + # Append constants to the input args for cpp wrapper. + # Python wrapper directly gets the value inside the wrapper call + # as a global variable passed when calling exec(code, mod.__dict__, mod.__dict__). + # For cpp wrapper, we need to pass this python value to the inductor_entry_impl function explicitly. + assert all( + isinstance(v, torch.Tensor) for v in list(V.graph.constants.values()) + ), "Expect all constants to be Tensor" + constants_str = f"[{', '.join(V.graph.constants.keys())}]" + wrapper_body += f""" + constants_tensor = {constants_str} + input_tensors.extend(constants_tensor) + """ + # Convert vector of at::Tensor to vector of AtenTensorHandle. + # If we pass at::Tensor, the compilation will be too slow. 
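+ # unsafe_alloc_void_ptrs_from_tensors hands the inputs over as raw void*
+ # handles; on the output side, alloc_tensors_by_stealing_from_void_ptrs
+ # (used a few lines below) converts the returned handles back into tensors.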
+ wrapper_body += """ + input_handles = torch._C._aoti.unsafe_alloc_void_ptrs_from_tensors(input_tensors) + """ + # Release the inputs for memory reuse. + wrapper_body += """ + args.clear() + del input_tensors + """ + + # unwrap output tensor back to python scalar + if all(x for x in self.output_is_tensor.values()): + # If no ShapeAsConstantBuffer in the output, directly return the output as tensors + outputs_str = "output_tensors" + else: + outputs = [ + ( + f"output_tensors[{i}]" + if self.output_is_tensor[i] + else f"output_tensors[{i}].item()" + ) + for i in range(len(V.graph.graph_outputs)) + ] + outputs_str = f"[{', '.join(outputs)}]" + wrapper_body += f""" + output_handles = f(input_handles) + output_tensors = torch._C._aoti.alloc_tensors_by_stealing_from_void_ptrs(output_handles) + return {outputs_str} + """ + + # Wrap the func to support setting result._boxed_call = True + result.splice( + f""" + def _wrap_func(f): + def g(args): + {wrapper_body} + return g + + call = _wrap_func(inductor_entry) + """ + ) + + @staticmethod + def get_c_shim_func_name(kernel: str, device: str) -> str: + if kernel.startswith("aoti_torch_"): + return kernel + + assert "::" in kernel, "Cpp kernel name: " + kernel + " does not contain '::'" + kernel_tokens = kernel.split("::") + kernel_suffix = kernel_tokens[-1] + if kernel_suffix == "call": + kernel_suffix = kernel_tokens[-2] + + shim_fn = f"aoti_torch_{device}_{kernel_suffix}" + return shim_fn + + def generate_c_shim_extern_kernel_call( + self, + kernel: str, + args: list[str], + device: str, + *, + debug_args: Optional[list[str]] = None, + ) -> None: + """debug_args kwarg allows CppWrapperCpuArrayRef to pass in wrapped arguments in + place of args while preserving debug printer output.""" + # We can do this unconditionally, since we cache this call. 
+ self.add_device_include(device) + + debug_printer_manager = V.graph.wrapper_code.debug_printer + debug_printer_manager.set_printer_args( + debug_args if debug_args is not None else args, kernel, None, None, "extern" + ) + enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [ + "linux", + "win32", + ] + with debug_printer_manager: + shim_fn = self.get_c_shim_func_name(kernel, device) + shim_fn_codes = ( + f"AOTI_TORCH_ERROR_CODE_CHECK({shim_fn}({', '.join(args)}));" + ) + if enable_kernel_profile: + shim_fn_codes = textwrap.dedent( + f""" + {{ + RECORD_FUNCTION("{shim_fn}", c10::ArrayRef()); + {shim_fn_codes} + }} + """ + ) + self.writeline(shim_fn_codes) + + def generate_c_shim_extern_kernel_alloc( + self, extern_kernel: ir.ExternKernelAlloc, args: list[str] + ) -> None: + # registered output buffer name + name = extern_kernel.name + output_handle_name = f"{name}_handle" + is_inplace = ( + isinstance(extern_kernel.op_overload, torch._ops.OpOverload) + and torch.Tag.inplace_view in extern_kernel.op_overload.tags + ) + + if not is_inplace: + self.writeline(f"AtenTensorHandle {output_handle_name};") + args = [*args, f"&{output_handle_name}"] + + device = d.type if (d := extern_kernel.get_device()) else self.device + self.generate_c_shim_extern_kernel_call( + extern_kernel.get_kernel_name(), args, device + ) + + if not is_inplace: + self.writeline(f"RAIIAtenTensorHandle {name}({output_handle_name});") + + def _generate_extern_kernel_alloc_helper(self, extern_kernel, args): + if getattr(extern_kernel, "outputs", None): + # ir.ExternKernelAlloc may have outputs if it returns a tuple + self.generate_c_shim_fallback_kernel(extern_kernel, args) + else: + self.generate_c_shim_extern_kernel_alloc(extern_kernel, args) + + def generate_c_shim_fallback_kernel( + self, fallback_kernel: ir.FallbackKernel, args: list[str] + ) -> None: + output_args = [] + output_raii_handles = [] + output_name_base = fallback_kernel.get_name() + for idx, output in enumerate(fallback_kernel.outputs): + if isinstance(output, ir.MultiOutput): + # TODO: handle integer output (e.g., as in attention) + name = f"{output.get_name()}" + output_handle_name = f"{name}_handle" + if output.indices: + assert output.indices[0][1] == idx, ( + f"expected {output.indices[0][1]=} == {idx=} for {output_name_base=}" + ) + self.writeline(f"AtenTensorHandle {output_handle_name};") + output_args.append(f"&{output_handle_name}") + output_raii_handles.append( + f"RAIIAtenTensorHandle {name}({output_handle_name});" + ) + elif isinstance(output, int): + output_name = f"{output_name_base}_{idx}" + self.writeline(f"int64_t {output_name} = {output};") + output_args.append(f"&{output_name}") + elif isinstance(output, sympy.Expr): + output_name = f"{output_name_base}_{idx}" + self.writeline(f"auto {output_name} = {cexpr(output)};") + output_args.append(f"&{output_name}") + elif output is None: + output_args.append("nullptr") + else: + raise NotImplementedError(f"unsupported type of {output=}") + args = args + output_args + device = d.type if (d := fallback_kernel.get_device()) else self.device + self.generate_c_shim_extern_kernel_call( + fallback_kernel.cpp_kernel_name, # type: ignore[arg-type] + args, + device, + ) + for raii_handle in output_raii_handles: + self.writeline(raii_handle) + + def _generate_extern_kernel_out_helper( + self, + kernel: str, + out: str, + out_view: Optional[str], + args: list[str], + device: str, + ) -> None: + if out_view: + out_name = f"{out}_as_strided" + self.writeline(f"auto {out_name} = 
{out_view};") + args.insert(0, out_name) + else: + args.insert(0, out) + + self.generate_c_shim_extern_kernel_call(kernel, args, device) + + def generate_scatter_fallback( + self, + output, + inputs, + cpp_kernel_name, + python_kernel_name, + src_is_tensor, + reduce, + kwargs, + ): + # call the ABI shim function instead of the ATen one + cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name, self.device) + # TODO: consider remove "_out" and add missing inplace variants to fallback_ops.py + cpp_kernel_name = cpp_kernel_name.replace("__", "_") + "_out" + inputs_wrapped = [str(x) for x in inputs] + line = f"{cpp_kernel_name}({output}, {','.join(inputs_wrapped)}" + + if python_kernel_name.startswith("aten.scatter_reduce"): + line += f", {','.join(kwargs)}" + else: + if src_is_tensor: + if reduce: + line += f", {V.graph.wrapper_code.val_to_arg_str(reduce)}" + else: + assert reduce is None, ( + "Expect reduce to be None for aten.scatter_ with scalar src" + ) + line += ");" + self.writeline(line) + + def generate_index_put_fallback(self, kernel, x, indices, values, accumulate): + # TODO: update aoti_torch_index_put_out in ir.py to use autogen out version + # See the comment in codegen_reinterpret_view about why having something like + # RAIIAtenTensorHandle(tmp_tensor_handle_2) in a tmp array can cause the corresponding + # tensor prematurely deallocated, thus the temporary array trick here. + indices_str = self._generate_temporary_array_pointer( + "AtenTensorHandle", indices + ) + args = [ + x, + indices_str, + str(len(indices)), + values, + accumulate, + ] + args.insert(0, x) # set x as the output tensor, this fallback mutates x. + self.writeline(self.wrap_kernel_call(kernel, args)) + + def add_benchmark_harness(self, output): + if V.graph.aot_mode: + return + super().add_benchmark_harness(output) + + def codegen_cpp_sizevar(self, x: sympy.Expr, *, simplify: bool = True) -> str: + return cexpr(V.graph.sizevars.simplify(x) if simplify else x) + + def codegen_sizevar(self, x: sympy.Expr) -> str: + return self.codegen_cpp_sizevar(x) + + def codegen_tuple_access(self, basename: str, name: str, index: str) -> str: + # in the abi_compatible mode, outputs are returned via arguments + return name + + def codegen_shape_tuple(self, shape: Sequence[sympy.Expr]) -> str: + parts = [*map(self.codegen_sizevar, shape)] + if len(parts) == 0: + return "{}" + if len(parts) == 1: + return f"{{{parts[0]}, }}" + return f"{{{', '.join(parts)}}}" + + def ensure_size_computed(self, sym: sympy.Symbol): + if isinstance(sym, sympy.Symbol) and symbol_is_type(sym, SymT.PRECOMPUTED_SIZE): + if sym in self.computed_sizes: + return + self.computed_sizes.add(sym) + expr = V.graph.sizevars.inv_precomputed_replacements[sym] + self.writeline(f"int64_t {sym} = {cexpr(expr)};") + + def _generate_symbolic_call_arg_helper( + self, arg: SymbolicCallArg, graph: GraphLowering + ) -> None: + if (arg.inner, graph) not in self.kernel_numel_expr: + # declare expr once in each graph (scope) + self.kernel_numel_expr.add((arg.inner, graph)) + self.writeline(f"int64_t {arg.inner} = {cexpr(arg.inner_expr)};") + else: + self.writeline(f"{arg.inner} = {cexpr(arg.inner_expr)};") + + def codegen_dynamic_scalar(self, node): + (data,) = (t.codegen_reference() for t in node.inputs) + self.codegen_tensor_item(node.inputs[0].get_dtype(), data, f"{node.sym}_raw") + + if len(node.keypath) == 0: + self.writeline(f"auto {node.sym} = {node.sym}_raw;") + elif len(node.keypath) == 1 and isinstance(node.keypath[0], ConvertIntKey): + self.writeline(f"int64_t 
{node.sym} = {node.sym}_raw ? 1 : 0;") + elif len(node.keypath) == 1 and isinstance(node.keypath[0], DivideByKey): + # TODO: assert divisibility here + self.writeline( + f"int64_t {node.sym} = {node.sym}_raw / {node.keypath[0].divisor};" + ) + else: + raise AssertionError(f"unrecognized keypath {node.keypath}") + + # record in unbacked_symbol_decls so we won't generate a declaration of the symbol again + self.unbacked_symbol_decls.add(str(node.sym)) + + def make_buffer_free(self, buffer): + return ( + "" + if isinstance(buffer.get_output_spec(), ir.MultiOutputLayout) + or isinstance(buffer, ir.TMADescriptor) + else f"{buffer.get_name()}.reset();" + ) + + def make_free_by_names(self, names_to_del: list[str]): + return " ".join(f"{name}.reset();" for name in names_to_del) + + def codegen_exact_buffer_reuse(self, old_name: str, new_name: str, del_line: str): + return f"auto {new_name} = std::move({old_name}); // reuse" + + def generate_profiler_mark_wrapper_call(self, stack): + self.wrapper_call.writeline( + 'RECORD_FUNCTION("inductor_wrapper_call", c10::ArrayRef());' + ) + + def generate_start_graph(self): + pass + + def generate_end_graph(self): + pass + + def generate_inf_and_nan_checker(self, nodes): + for buf in nodes.get_names(): + # TODO: Add buf name directly into check_inf_and_nan. + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_check_inf_and_nan({buf}));" + ) + + def codegen_device(self, device): + assert device.type in DEVICE_TO_ATEN, ( + device.type + " not found in DEVICE_TO_ATEN" + ) + device_str = DEVICE_TO_ATEN[device.type][5:].lower() # remove "at::k" + self.used_cached_devices.add(device_str) + return f"cached_torch_device_type_{device_str}, {device.index if device.index else 0}" + + def codegen_dtype(self, dtype): + dtype_str = str(dtype).split(".")[-1] + self.used_cached_dtypes.add(dtype_str) + return f"cached_torch_dtype_{dtype_str}" + + def codegen_layout(self, layout): + layout_str = str(layout).split(".")[-1] + self.used_cached_layouts.add(layout_str) + return f"cached_torch_layout_{layout_str}" + + def codegen_memory_format(self, memory_format): + memory_format_str = str(memory_format).split(".")[-1] + self.used_cached_memory_formats.add(memory_format_str) + return f"cached_torch_memory_format_{memory_format_str}" + + @functools.cache # noqa: B019 + def codegen_int_array_var( + self, + int_array: str, + writeline: Callable[..., None], + known_statically=False, + graph=None, # for per-graph caching + ): + # Used for size/stride declaration + # + # Because the memory planning is done in two passes (see the implementation + # of self.generate), the writeline behavior is different in the two passes. + # As a result, the emitted int array declarations may appear in a later + # position of the generated code, so the second pass codegen should not + # reuse int array declarations generated in the first pass. + # This is why writeline needs to explicitly passed in as a parameter. 
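+ # Illustrative sketch of the declarations this emits (variable names and
+ # values are examples only):
+ #   static constexpr int64_t int_array_0[] = {64, 32};  // statically known
+ #   const int64_t int_array_1[] = {s0, 32};  // contains a runtime symbol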
+ var = f"int_array_{next(self.int_array_id)}" + ctype = "int64_t" + if var not in self.declared_int_array_vars: + self.declared_int_array_vars.add(var) + if known_statically: + writeline(f"static constexpr {ctype} {var}[] = {int_array};") + else: + writeline(f"const {ctype} {var}[] = {int_array};") + return var + + def make_buffer_allocation(self, buffer): + return self.make_allocation( + buffer.get_name(), + buffer.get_device(), + buffer.get_dtype(), + buffer.get_size(), + buffer.get_stride(), + V.graph.get_allocation_size(buffer), + ) + + def make_allocation( + self, name, device, dtype, shape, stride, allocation_shape=None + ): + if allocation_shape is None: + allocation_shape = shape + + orig_stride = stride + device_str = self.codegen_device(device) + dtype_code = self.codegen_dtype(dtype) + size = self.codegen_shape_tuple(shape) + allocation_size = self.codegen_shape_tuple(allocation_shape) + stride = self.codegen_shape_tuple(orig_stride) + + size_array_var = self.codegen_int_array_var( + size, + self.wrapper_call.writeline, + known_statically=self.is_statically_known_list_of_ints(shape), + graph=self.get_codegened_graph(), + ) + + if allocation_size != size: + allocation_size_array_var = self.codegen_int_array_var( + allocation_size, + self.wrapper_call.writeline, + known_statically=self.is_statically_known_list_of_ints( + allocation_shape + ), + graph=self.get_codegened_graph(), + ) + else: + allocation_size_array_var = size_array_var + + stride_array_var = self.codegen_int_array_var( + stride, + self.wrapper_call.writeline, + known_statically=self.is_statically_known_list_of_ints(orig_stride), + graph=self.get_codegened_graph(), + ) + device_type, device_id = device_str.split(",") + device_idx = "this->device_idx_" if V.graph.aot_mode else device_id + + handle_name = f"{name}_handle" + args = [ + str(len(shape)), + allocation_size_array_var, + stride_array_var, + dtype_code, + device_type, + device_idx, + f"&{handle_name}", + ] + + self.wrapper_call.writeline(f"AtenTensorHandle {handle_name};") + self.wrapper_call.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided({', '.join(args)}));" + ) + + if allocation_size != size: + old_handle_name, handle_name = handle_name, f"{name}_handle_restrided" + self.wrapper_call.writeline(f"AtenTensorHandle {handle_name};") + args = [ + old_handle_name, + size_array_var, + stride_array_var, + f"&{handle_name}", + ] + self.wrapper_call.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_as_strided({', '.join(args)}));" + ) + + return f"RAIIAtenTensorHandle {name}({handle_name});" + + def codegen_alloc_from_pool(self, name, offset, dtype, shape, stride) -> str: + size = self.codegen_shape_tuple(shape) + stride = self.codegen_shape_tuple(stride) + tmp_name = f"tmp_tensor_handle_{next(self.tmp_tensor_id)}" + args = [ + name, + cexpr(offset), # bytes not numel + self.codegen_dtype(dtype), + str(len(shape)), + self.codegen_int_array_var( + size, self.wrapper_call.writeline, graph=self.get_codegened_graph() + ), + self.codegen_int_array_var( + stride, self.wrapper_call.writeline, graph=self.get_codegened_graph() + ), + f"&{tmp_name}", + ] + self.wrapper_call.writeline(f"AtenTensorHandle {tmp_name};") + self.wrapper_call.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__alloc_from_pool({', '.join(args)}));" + ) + return f"RAIIAtenTensorHandle({tmp_name})" + + def codegen_reinterpret_view( + self, + data, + size, + stride, + offset, + writeline: Callable[..., None], + dtype=None, + ) -> str: + """Returns a newly-created, temporary 
RAII tensor handle containing the + reinterpreted tensor data. Callers of this function are responsible for saving + the handle if persistent access is needed.""" + dim = str(len(size)) + original_offset = offset + offset = self.codegen_sizevar(offset) + call_strs = [] + final_tensor_str = None + + def create_reinterpret_call() -> str: + args = [ + f"{data.get_name()}", + dim, + self.codegen_int_array_var( + self.codegen_shape_tuple(size), + writeline, + known_statically=self.is_statically_known_list_of_ints(size), + graph=self.get_codegened_graph(), + ), + self.codegen_int_array_var( + self.codegen_shape_tuple(stride), + writeline, + known_statically=self.is_statically_known_list_of_ints(stride), + graph=self.get_codegened_graph(), + ), + offset, + ] + return f"wrap_with_raii_handle_if_needed(reinterpret_tensor_wrapper({', '.join(args)}))" + + def create_dtypeview_call(reinterpret_call: str) -> tuple[str, list[str]]: + tmp_AtenTensorHandle = f"tmp_{data.get_name()}_{next(self.tmp_tensor_id)}" + tmp_call_strs = [f"AtenTensorHandle {tmp_AtenTensorHandle};"] + device_name = data.layout.device.type + dtypeview_function = f"aoti_torch_{device_name}_view_dtype" + tmp_call_strs.append( + f"AOTI_TORCH_ERROR_CODE_CHECK({dtypeview_function}" + f"({reinterpret_call}, {self.codegen_dtype(dtype)}, &{tmp_AtenTensorHandle}));" + ) + return f"RAIIAtenTensorHandle({tmp_AtenTensorHandle})", tmp_call_strs + + def create_new_tensor_handle() -> tuple[str, list[str]]: + tmp_AtenTensorHandle = f"tmp_{data.get_name()}_{next(self.tmp_tensor_id)}" + tmp_call_strs = [ + f"AtenTensorHandle {tmp_AtenTensorHandle};", + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_tensor_handle({data.get_name()}, &{tmp_AtenTensorHandle}));", + ] + return f"RAIIAtenTensorHandle({tmp_AtenTensorHandle})", tmp_call_strs + + if ( + size == data.layout.size + and stride == data.layout.stride + and original_offset == data.layout.offset + ): + # pure dtypeview + if dtype is not None and dtype != data.dtype: + final_tensor_str, tmp_call_strs = create_dtypeview_call(data.get_name()) + else: + final_tensor_str, tmp_call_strs = create_new_tensor_handle() + call_strs.extend(tmp_call_strs) + else: + # firstly create reinterpretview + final_tensor_str = create_reinterpret_call() + + if dtype is not None and dtype != data.dtype: + # wrap it with dtypeview + final_tensor_str, tmp_call_strs = create_dtypeview_call( + final_tensor_str + ) + call_strs.extend(tmp_call_strs) + + for line in call_strs: + writeline(line) + + # NB, the return handle here represents a temporary tensor, which will be automatically + # released. + # Here's a sample usage in the cpp wrapper code: + # ``` + # aoti_torch_addmm_out( + # buf1, + # arg1_1, + # RAIIAtenTensorHandle(tmp_tensor_handle_0), + # buf0, + # 1L, + # 1L)); + # ``` + # RAIIAtenTensorHandle(tmp_tensor_handle_0) will be released after the call to addmm_out. + # This could be problematic when it's used in a different pattern, for example: + # ```` + # AtenTensorHandle tensor_args[] = {RAIIAtenTensorHandle(tmp_tensor_handle_2), buf5, buf6}; + # aoti_torch_proxy_executor_call_function(..., tensor_args); + # ```` + # RAIIAtenTensorHandle(tmp_tensor_handle_2) will be invalid when it's used in the latter + # kernel call. 
+ # + # This is solved by updating the proxy_executor invocation to + # ``` + # aoti_torch_proxy_executor_call_function(..., + # std::array{ + # RAIIAtenTensorHandle(tmp_tensor_handle_2), buf5, buf6 + # }.cbegin() + # ); + # ``` + return final_tensor_str + + def codegen_device_copy(self, src, dst, non_blocking: bool): + """This function is overridden by cpp_wrapper_cpu_array_ref, so we don't need to + handle cases where dst is not an AtenTensorHandle.""" + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_copy_({dst}, {src}, {non_blocking}));" + ) + + def codegen_multi_output(self, node: ir.MultiOutput): + # in the abi_compatible mode, outputs are retrieved by passing + # output pointers, so we skip its codegen here. + pass + + def codegen_subgraph_prefix(self, subgraph, outer_inputs, outer_outputs): + assert len(subgraph.graph.graph_inputs) == len(outer_inputs) + + for (inner_input, inner_input_val), outer_input in zip( + subgraph.graph.graph_inputs.items(), outer_inputs + ): + if not isinstance(inner_input_val, ir.TensorBox): + continue + + # in ABI-compatible mode, we copy the underlying at::Tensor of the conditional + # input (outer_input) into another at::Tensor to be used as a subgraph input + # (inner_input) in the nested scope. we can't std::move here, as the codegened + # outer input may be an expression / rvalue (e.g., reinterpret_view(x)), so we + # can't necessarily std::move it back to the origin (x). + self.writeline(f"AtenTensorHandle {inner_input}_handle;") + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors_out({outer_input}, &{inner_input}_handle));" + ) + self.writeline(f"RAIIAtenTensorHandle {inner_input}({inner_input}_handle);") + + def codegen_subgraph_suffix(self, subgraph, outer_inputs, outer_outputs): + for inner_output, outer_output in zip( + subgraph.graph.graph_outputs, outer_outputs + ): + src = inner_output.codegen_reference() + if not isinstance(inner_output, ir.ShapeAsConstantBuffer): + # in ABI-compatible mode, we need to std::move subgraph output (inner_output) + # to the conditional output (outer_output), as RAIIAtenTensorHandle's copy + # constructor is deleted. 
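+ # (ShapeAsConstantBuffer outputs, by contrast, are plain scalar expressions
+ # and are assigned without the std::move.)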
+ src = f"std::move({src})" + # in case the outer_output carried a value + # before (e.g., in the while_loop codegen) + self.writeline(f"{outer_output}.reset();") + self.writeline(f"{outer_output} = {src};") + + def codegen_invoke_subgraph(self, invoke_subgraph): + raise NotImplementedError( + "codegen invoke_subgraph is not implemented for cpp wrapper" + ) + + def codegen_conditional(self, conditional): + outer_inputs = [f"{buf.codegen_reference()}" for buf in conditional.operands] + outer_outputs = [] + for out in conditional.outputs: + # in ABI-compatible mode, ir.MultiOutput is not codegened, + # hence pre-declare output variables directly and separately + self.writeline(f"RAIIAtenTensorHandle {out.get_name()};") + outer_outputs.append(out.get_name()) + + if not isinstance(conditional.predicate, ir.ShapeAsConstantBuffer): + # in ABI-compatible mode, we need to use the ABI shim function + # to extract a C++ bool from the underlying scalar bool Tensor + predicate = f"{conditional.predicate.get_name()}_scalar" + if predicate not in self.used_cond_predicate: + self.codegen_tensor_item( + torch.bool, + conditional.predicate.codegen_reference(), + predicate, + ) + self.used_cond_predicate.add(predicate) + else: + # the predicate is not a Tensor: SymBool or Python bool + predicate = conditional.predicate.codegen_reference() + + self.writeline(f"if ({predicate}) {{") + self.writeline(EnterSubgraphLine(self, conditional.true_subgraph.graph)) + self.codegen_subgraph(conditional.true_subgraph, outer_inputs, outer_outputs) + self.writeline(ExitSubgraphLine(self)) + self.writeline("} else {") + self.writeline(EnterSubgraphLine(self, conditional.false_subgraph.graph)) + self.codegen_subgraph(conditional.false_subgraph, outer_inputs, outer_outputs) + self.writeline(ExitSubgraphLine(self)) + self.writeline("}") + + def codegen_subgraph(self, subgraph, outer_inputs, outer_outputs): + # TODO (desertfire) - This function is the old way of supporting + # subgraph codegen by inlining subgraphs in the output code. For python + # wrapper, we have moved to lifting subgraphs as functions, supported by + # PythonWrapperCode `codegen_subgraph` function. We should perhaps + # support lifting of subgraphs as functions for cpp wrapper as well. + try: + self.push_codegened_graph(subgraph.graph) + self.writeline(f"// subgraph: {subgraph.name}") + self.codegen_subgraph_prefix(subgraph, outer_inputs, outer_outputs) + parent_graph = V.graph + with V.set_graph_handler(subgraph.graph): + subgraph.graph.codegen_subgraph( + parent_graph=parent_graph, + ) + self.codegen_subgraph_suffix(subgraph, outer_inputs, outer_outputs) + finally: + self.pop_codegened_graph() + + def codegen_while_loop(self, while_loop): + is_bool_pred = isinstance( + while_loop.cond_subgraph.graph.graph_outputs[0], ir.ShapeAsConstantBuffer + ) + name = while_loop.get_name() + outer_carried_inputs = [ + buf.codegen_reference() for buf in while_loop.carried_inputs + ] + outer_additional_inputs = [ + buf.codegen_reference() for buf in while_loop.additional_inputs + ] + cond_result_name = f"{name}_cond_result" + if is_bool_pred: + self.writeline(f"bool {cond_result_name};") + else: + self.writeline(f"RAIIAtenTensorHandle {cond_result_name};") + + cond_outer_inputs = [] + for inp, out in zip(outer_carried_inputs, while_loop.outputs): + # in ABI-compatible mode, the carried inputs are codegened + # as buffers outside the while loop and set to the initial + # values. at the end of each while_loop iteration, they + # will be assigned the carried values. 
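+ # Roughly, for each carried input this emits (buffer names are examples):
+ #   AtenTensorHandle buf3_handle;
+ #   AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors_out(arg0_1, &buf3_handle));
+ #   RAIIAtenTensorHandle buf3(buf3_handle);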
+ out_name = out.get_name() + self.writeline(f"AtenTensorHandle {out_name}_handle;") + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors_out({inp}, &{out_name}_handle));" + ) + self.writeline(f"RAIIAtenTensorHandle {out_name}({out_name}_handle);") + cond_outer_inputs.append(out_name) + + # additional inputs will be assigned within the while_loop + # iteration directly from the corresponding outer graph buffers + cond_outer_inputs.extend(outer_additional_inputs) + + cond_outer_outputs = [cond_result_name] + body_outer_inputs = list(cond_outer_inputs) + body_outer_outputs = body_outer_inputs[: len(outer_carried_inputs)] + + self.writeline("while (1) {") + self.writeline(EnterSubgraphLine(self, while_loop.cond_subgraph.graph)) + self.codegen_subgraph( + while_loop.cond_subgraph, cond_outer_inputs, cond_outer_outputs + ) + + if is_bool_pred: + cond_result = f"{cond_result_name}" + else: + cond_result = f"{cond_result_name}_scalar" + self.codegen_tensor_item(torch.bool, cond_result_name, cond_result) + self.writeline(f"if (!{cond_result}) break;") + + self.writeline(ExitSubgraphLine(self)) + self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph)) + self.codegen_subgraph( + while_loop.body_subgraph, body_outer_inputs, body_outer_outputs + ) + self.writeline(ExitSubgraphLine(self)) + self.writeline("}") + + def generate_extern_kernel_args_decl_if_needed( + self, + op_overload: Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator], + raw_args: Sequence[Any], + output_args: _OUTPUT_ARGS_TYPE, + raw_outputs: Sequence[ir.Buffer], + ): + """ + Generates declarations for external kernel arguments if needed, based on the provided + operator and its arguments. It processes both input and output arguments, categorizing + them into tensor and integer arguments for further code generation. 
+ """ + schema = None + if isinstance(op_overload, torch._higher_order_ops.torchbind.CallTorchBind): + obj = raw_args[0] + method = raw_args[1] + schema = op_overload.schema(obj, method) + else: + assert isinstance(op_overload, torch._ops.OpOverload), type(op_overload) + schema = op_overload._schema + assert schema is not None + arg_types = [x.real_type for x in schema.arguments] + return_types = [x.type for x in schema.returns] + + new_tensor_args = [] + new_int_args = [] + + def fill_args(arg, arg_type): + static_arg_types = ( + torch.FloatType, + torch.BoolType, + torch.StringType, + torch.Type, + torch.DeviceObjType, + ) + inductor_tensor_buffers = ( + ir.Buffer, + ir.ReinterpretView, + ) + + if isinstance(arg_type, torch.TensorType): + assert isinstance(arg, inductor_tensor_buffers), f"got {type(arg)}" + new_tensor_args.append(f"{arg.codegen_reference()}") + elif isinstance(arg_type, torch.IntType): + # int + new_int_args.append(str(arg)) + elif isinstance(arg_type, torch.SymIntType): + # SymInt + expr = arg.node.expr if isinstance(arg, torch.SymInt) else arg + new_int_args.append(cexpr(expr)) + elif isinstance(arg_type, torch.NumberType): + # Scalar of type int + assert isinstance(arg, (int, float, bool)) + # Only treat int Scalar as dynamic + if isinstance(arg, int): + new_int_args.append(str(arg)) + elif isinstance(arg, ir.TorchBindObject): + # torchbind objects are loaded in proxy executor + pass + elif isinstance(arg_type, torch.ListType): + assert isinstance(arg, (list, tuple)) + + # List[Tensor] + if isinstance(arg_type.getElementType(), torch.TensorType): + new_tensor_args.extend([f"{a.codegen_reference()}" for a in arg]) + # List[Optional[Tensor]] + elif isinstance( + arg_type.getElementType(), torch.OptionalType + ) and isinstance( + arg_type.getElementType().getElementType(), torch.TensorType + ): + new_tensor_args.extend( + [f"{a.codegen_reference()}" for a in arg if a is not None] + ) + # List[int] + elif isinstance(arg_type.getElementType(), torch.IntType): + new_int_args.extend([str(a) for a in arg]) + # List[SymInt] + elif isinstance(arg_type.getElementType(), torch.SymIntType): + expressions = [ + a.node.expr if isinstance(a, torch.SymInt) else a for a in arg + ] + new_int_args.extend([cexpr(expr) for expr in expressions]) + # List[Scalar] + elif isinstance(arg_type.getElementType(), torch.NumberType): + # Only treat int Scalar as dynamic + is_int_type = [isinstance(a, int) for a in arg] + if any(is_int_type): + assert all(is_int_type), ( + "AOTInductor only supports int scalars of the same type" + ) + new_int_args.extend([str(a) for a in arg]) + else: + assert isinstance( + arg_type.getElementType(), + static_arg_types, # type: ignore[arg-type] + ), ( + f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}" + ) + else: + assert isinstance( + arg_type, + static_arg_types, # type: ignore[arg-type] + ), ( + f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}" + ) + + for arg, arg_type in zip(raw_args, arg_types): + if arg is not None: + if isinstance(arg_type, torch.OptionalType): + fill_args(arg, arg_type.getElementType()) + else: + fill_args(arg, arg_type) + + def fill_output_arg( + arg: str, return_type: torch.JitType, is_mutated_output: bool + ) -> None: + if isinstance(return_type, torch.TensorType): + if not is_mutated_output: + self.writeline(f"AtenTensorHandle {arg}_handle; // output buffer") + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&{arg}_handle));" + ) + 
self.writeline(f"RAIIAtenTensorHandle {arg}({arg}_handle);") + new_tensor_args.append(f"{arg}") + elif isinstance(return_type, torch.SymIntType): + raise NotImplementedError("NYI support for return type: SymInt") + elif isinstance(return_type, torch.ListType) and isinstance( + return_type.getElementType(), torch.SymIntType + ): + raise NotImplementedError("NYI support for return type: List[SymInt]") + else: + raise AssertionError(f"Unsupported return type found: {return_type}") + + # TODO: Only support None and tensor(s) returns for now, SymInt is not implemented yet + for return_type in return_types: + if isinstance( + return_type, (torch.TensorType, torch.NoneType, torch.IntType) + ): + pass + elif isinstance(return_type, torch.OptionalType): + assert isinstance(return_type.getElementType(), torch.TensorType) + elif isinstance(return_type, torch.ListType): + assert isinstance(return_type.getElementType(), torch.TensorType) + else: + raise NotImplementedError( + f"return type {return_type} is not yet supported." + ) + + for output_arg, raw_output_arg in zip(output_args, raw_outputs): # type: ignore[arg-type] + # None output is supported, but Optional return types are not yet supported + if output_arg is None: + continue + elif isinstance(raw_output_arg, int): + new_int_args.append(str(raw_output_arg)) + elif isinstance(output_arg, list): + for out in output_arg: + assert out is not None, out + fill_output_arg( + out, + torch.TensorType.get(), + isinstance(raw_output_arg, ir.MutationOutput), + ) + else: + fill_output_arg( + output_arg, + torch.TensorType.get(), + isinstance(raw_output_arg, ir.MutationOutput), + ) + + return new_tensor_args, new_int_args + + @staticmethod + def _compatible_with_stableivalue(op: torch._ops.OpOverload) -> bool: + """Returns true if op_overload._schema only utilizes types supported by the AOT + C-shim *internal* function to_ivalue. to_ivalue is an implementation detail, so + these types are not guaranteed to be supported long-term. When generating code + for cpp_wrapper mode, we don't have to be forward-compatible, so changing this + function's implementation in future is fine.""" + supported_types = ( + torch.BoolType, + torch.DeviceObjType, + torch.FloatType, + # ScalarTypeType, LayoutType, and MemoryFormatType are seen as IntType + # when queried via torch.JitType.type. + torch.IntType, + torch.TensorType, + ) + + def type_supported(t: torch.JitType) -> bool: + if isinstance(t, torch.OptionalType): + return type_supported(t.getElementType()) + return isinstance(t, supported_types) + + return all( + type_supported(a.type) + for a in chain(op._schema.arguments, op._schema.returns) + ) + + def generate_fallback_kernel_with_runtime_lookup( + self, + buf_name: str, + python_kernel_name: str, + get_args: Callable[[], Sequence[str]], + op_overload: Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator], + raw_args: Sequence[Any], + outputs: Sequence[ir.Buffer], + ) -> None: + """Generate a call to a kernel not contained in the C-shim. 
This results in + different code paths for AOT Inductor vs cpp_wrapper Inductor mode.""" + + def extract_output_name( + out: Optional[Union[ir.Buffer, Sequence[ir.Buffer]]], + ) -> Union[Optional[str], _OUTPUT_ARGS_TYPE]: + if out is None: + return None + if isinstance(out, (ir.MultiOutput, ir._CollectiveKernel)): + return out.get_name() + if isinstance(out, ir.MutationOutput): + mutated_buf_names = out.get_mutation_names() + assert ( + isinstance(mutated_buf_names, list) and len(mutated_buf_names) == 1 + ), "Expect only one mutated buffer in MutationOutput" + return mutated_buf_names[0] + if isinstance(out, (list, tuple)): + return [extract_output_name(o) for o in out] # type: ignore[misc] + if isinstance(out, int): + return str(out) + raise AssertionError(f"Unexpected output: {type(out)}") + + if isinstance(op_overload, torch._ops.HigherOrderOperator): + assert isinstance( + op_overload, torch._higher_order_ops.torchbind.CallTorchBind + ), type(op_overload) + assert len(raw_args) > 1 + obj = raw_args[0] + method = raw_args[1] + return_schema = op_overload.schema(obj, method).returns + else: + return_schema = op_overload._schema.returns + + # output_args has the same pytree structure as outputs + if not return_schema: + # kernel does not return a value + output_args: _OUTPUT_ARGS_TYPE = [] + elif isinstance(output_name := extract_output_name(outputs), str): + output_args = [output_name] + else: + # If the schema indicates a return value, we should have a non-None value by + # this point. + assert isinstance(output_name, list), type(output_name) + output_args = output_name + + # In AOT mode, we use a ProxyExecutor to run fallback kernels. + if V.graph.aot_mode: + self.generate_fallback_kernel_with_runtime_lookup_aot( + op_overload, + raw_args, + output_args, + outputs, + ) + return + + assert isinstance(op_overload, torch._ops.OpOverload), type(op_overload) + for output in output_args: + assert output is None or isinstance(output, str), ( + "fallback kernels with runtime lookup currently only support tensor " + "returns, not more complicated types (such as list-of-list-of-tensor)" + ) + + # In non-AOT mode, we use aoti_torch_call_dispatcher if all the inputs and + # outputs of the op can be represented with StableIValue. This avoids the + # overhead of calling back into Python, and covers most remaining fallback ops. + if self._compatible_with_stableivalue(op_overload): + self.generate_fallback_kernel_with_runtime_lookup_nopython( + get_args, + op_overload, + output_args, # type: ignore[arg-type] + outputs, + ) + return + + # Otherwise, we call back into Python, which has some extra runtime overhead, + # but handles situations like list[Tensor] (currently unrepresentable via + # StableIValue). 
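+ # To summarize the dispatch choices above: AOT mode goes through the
+ # ProxyExecutor, ops whose schema fits StableIValue go through
+ # aoti_torch_call_dispatcher, and everything else falls back to the Python
+ # custom_op_wrapper path generated below.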
+ self.generate_fallback_kernel_with_runtime_lookup_python( + buf_name, + python_kernel_name, + op_overload, + raw_args, + output_args, # type: ignore[arg-type] + outputs, + ) + + def generate_scoped_gil_acquire(self, declarations_before_scope, lines_in_scope): + scoped_lines = IndentedBuffer() + for declaration in declarations_before_scope: + scoped_lines.writeline(declaration) + + scoped_lines.writeline("{") + with scoped_lines.indent(): + scoped_lines.writeline("py::gil_scoped_acquire acquire;") + scoped_lines.writelines(lines_in_scope.split("\n")) + scoped_lines.writelines("}") + return scoped_lines._lines + + def load_custom_op_wrapper(self): + # TODO: need to support control flow + if self.custom_op_wrapper_loaded: + return + + lines = """ +RAIIPyObject codecache_module(PyImport_ImportModule("torch._inductor.codecache")); +if (!codecache_module) { + throw std::runtime_error("Failed to load torch._inductor.codecache"); +} +custom_op_wrapper = PyObject_GetAttrString(codecache_module, "custom_op_wrapper"); +if (!custom_op_wrapper) { + throw std::runtime_error("Failed to load torch._inductor.codecache.custom_op_wrapper"); +}""" + + declarations_before_scope = ["RAIIPyObject custom_op_wrapper;"] + scope_gil_acquire = self.generate_scoped_gil_acquire( + declarations_before_scope, lines + ) + self.writelines(scope_gil_acquire) + + self.custom_op_wrapper_loaded = True + + def generate_float_value(self, val): + assert isinstance(val, float) + if val == float("inf"): + return "std::numeric_limits::infinity()" + elif val == float("-inf"): + return "-std::numeric_limits::infinity()" + elif math.isnan(val): + return "std::numeric_limits::quiet_NaN()" + else: + return f"{val}" + + def generate_py_arg(self, py_args_var, idx, raw_arg, arg_type): + def generate_py_arg_inner(lines, raw_arg, arg_type): + def handle_scalar(scalar): + if isinstance(scalar, int): + return f"PyLong_FromLongLong({scalar})" + if isinstance(scalar, float): + return f"PyFloat_FromDouble({self.generate_float_value(scalar)})" + if isinstance(scalar, bool): + return f"PyBool_FromLong({1 if scalar else 0})" + if isinstance(scalar, complex): + real = self.generate_float_value(scalar.real) + imag = self.generate_float_value(scalar.imag) + return f"PyComplex_FromDoubles({real}, {imag})" + if isinstance(scalar, SymTypes): + scalar_var = cexpr(scalar.node.expr) + if isinstance(scalar, torch.SymBool): + return f"PyBool_FromLong({scalar_var})" + if isinstance(scalar, torch.SymFloat): + return f"PyFloat_FromDouble({scalar_var})" + return f"PyLong_FromLongLong({scalar_var})" + raise NotImplementedError( + f"scalar {scalar}, {type(scalar)} cannot be handled by handle_scalar" + ) + + if raw_arg is None: + # Py_None is a singleton, so we have to explicitly incref it here + lines.append("Py_INCREF(Py_None);\n") + return "Py_None" + elif isinstance(arg_type, torch.TensorType): + # In some cases, scalar arguments may be passed in place of tensors. + if not hasattr(raw_arg, "codegen_reference"): + return handle_scalar(raw_arg) + + # Store AtenTensorHandle as void*. All Python args are constructed in a + # nested scope, so this handle will self-destruct after the function + # call. 
+ base_handle = self.create_tmp_raii_handle_var_if_needed( + raw_arg.codegen_reference(), lines + ) + return f"PyCapsule_New(reinterpret_cast({base_handle}.get()), NULL, NULL)" + elif isinstance(arg_type, torch.OptionalType): + return generate_py_arg_inner(lines, raw_arg, arg_type.getElementType()) + elif isinstance(arg_type, torch.IntType): + # int + return f"PyLong_FromLongLong({raw_arg})" + elif isinstance(arg_type, torch.SymIntType): + # SymInt + expr = ( + raw_arg.node.expr if isinstance(raw_arg, torch.SymInt) else raw_arg + ) + return f"PyLong_FromLongLong({cexpr(expr)})" + elif isinstance(arg_type, torch.FloatType): + return f"PyFloat_FromDouble({self.generate_float_value(raw_arg)})" + elif isinstance(arg_type, torch.BoolType): + return f"PyBool_FromLong({1 if raw_arg else 0})" + elif isinstance(arg_type, torch.StringType): + return f'PyUnicode_FromString("{raw_arg}")' + elif isinstance(arg_type, torch.NumberType): + # Union[bool, int, float, complex] + # torch/_prims_common/__init__.py + return handle_scalar(raw_arg) + elif isinstance(raw_arg, torch.device): + device_str, device_index = self.codegen_device(raw_arg).split(", ") + return f"THPDevice_New(c10::Device(static_cast({device_str}), {device_index}))" + elif isinstance(raw_arg, torch.dtype): + return f"Py_NewRef(torch::getTHPDtype(static_cast({self.codegen_dtype(raw_arg)})))" + elif isinstance(raw_arg, torch.layout): + return f"Py_NewRef(torch::getTHPLayout(static_cast({self.codegen_layout(raw_arg)})))" + elif isinstance(raw_arg, torch.memory_format): + return ( + "Py_NewRef(torch::utils::getTHPMemoryFormat(static_cast(" + f"{self.codegen_memory_format(raw_arg)})))" + ) + else: + raise NotImplementedError( + f"arg type {arg_type} is not yet supported by custom_op_wrapper" + ) + + lines = [] + if isinstance(arg_type, torch.ListType): + assert isinstance(raw_arg, (list, tuple)), str(raw_arg) + " is not a list" + lines.append( + f"PyObject* {py_args_var}_{idx} = PyList_New({len(raw_arg)});\n" + ) + for i, elem in enumerate(raw_arg): + lines.append( + f"PyList_SetItem({py_args_var}_{idx}, {i}, {generate_py_arg_inner(lines, elem, arg_type.getElementType())});\n" + ) + lines.append( + f"PyTuple_SetItem({py_args_var}, {idx}, {py_args_var}_{idx});\n" + ) + else: + lines.append( + f"PyTuple_SetItem({py_args_var}, {idx}, {generate_py_arg_inner(lines, raw_arg, arg_type)});\n" + ) + return "".join(lines) + + def generate_fallback_kernel_with_runtime_lookup_nopython( + self, + get_args: Callable[[], Sequence[str]], + op_overload: torch._ops.OpOverload, + output_args: Sequence[Optional[str]], + raw_outputs: Sequence[ir.Buffer], + ) -> None: + """Generate fallback kernel calls with runtime (non-AOT) dispatch. This can + only be called in cpp_wrapper mode, and assumes that the input is a non-None + OpOverload. 
+ + In the future, we may switch over to directly calling c10::Dispatcher if we need + to support more datatypes.""" + if raw_outputs: + declarations_before_scope = [ + f"RAIIAtenTensorHandle {output_arg};" + for output_arg, raw_output_arg in zip(output_args, raw_outputs) # type: ignore[arg-type] + if output_arg is not None + and not isinstance(raw_output_arg, ir.MutationOutput) + ] + else: + declarations_before_scope = [ + f"RAIIAtenTensorHandle {output_arg};" + for output_arg in output_args # type: ignore[arg-type] + if output_arg is not None + ] + + dispatch_lines = IndentedBuffer() + dispatch_lines.writelines(declarations_before_scope) + dispatch_lines.writeline("{") + + with dispatch_lines.indent(): + tmp_var_number = count() + + def parse_arg(arg_type: torch.JitType, codegen_arg: str) -> str: + # Strip off any temporary references; we're in an indented context, so + # any saved-off variables will be auto-destroyed. + new_codegen_arg = codegen_arg.removeprefix("&temporary_reference(") + if new_codegen_arg != codegen_arg: + # If we removed temporary_reference, there's a good chance the + # variable ends with get() (which would retrieve an ATenTensorHandle + # from a temporary RAII handle). Strip that off too, since we're + # going to save this in a temporary RAII handle. + if codegen_arg.endswith(".get())"): + codegen_arg = new_codegen_arg.removesuffix(".get())") + else: + codegen_arg = new_codegen_arg.removesuffix(")") + + if isinstance(arg_type, torch.OptionalType): + # If we have a pointer to a variable, strip it off and let + # from handle any internal pointers. + codegen_arg = codegen_arg.removeprefix("&") + + if codegen_arg == "nullptr": + return "from(std::nullopt)" + + var_name = f"tmp_var_{next(tmp_var_number)}" + dispatch_lines.writeline( + f"std::optional {var_name}{{{parse_arg(arg_type.getElementType(), codegen_arg)}}};" + ) + return f"from({var_name})" + + raii_var = self.create_tmp_raii_handle_var_if_needed( + codegen_arg, dispatch_lines + ) + temp_handle = raii_var != codegen_arg + + if isinstance(arg_type, torch.TensorType): + if not temp_handle: + # If the RAII tensor being referenced _isn't_ a temporary, + # scoped to this fallback call, then create a new handle + # referencing it which from can steal. + var_name = f"tmp_var_{next(tmp_var_number)}" + dispatch_lines.writeline(f"AtenTensorHandle {var_name};") + dispatch_lines.writeline( + f"aoti_torch_new_tensor_handle({raii_var}, &{var_name});" + ) + return f"from({var_name})" + # If the RAII tensor _is_ a temporary scoped to this fallback call, + # simply release and steal the handle. 
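[Editor's note] The prefix/suffix stripping performed by parse_arg above is easier to see on concrete strings; below is a rough, self-contained sketch of just that string surgery. The function name strip_temporary_reference is made up for illustration and is not in the patch.

def strip_temporary_reference(codegen_arg: str) -> str:
    """Undo an `&temporary_reference(...)` wrapper around a codegen expression."""
    inner = codegen_arg.removeprefix("&temporary_reference(")
    if inner == codegen_arg:
        # No wrapper present; leave the expression alone.
        return codegen_arg
    # Drop the wrapper's closing paren, and a trailing `.get()` if the wrapped
    # expression was pulling a raw handle out of an RAII handle.
    if codegen_arg.endswith(".get())"):
        return inner.removesuffix(".get())")
    return inner.removesuffix(")")

assert strip_temporary_reference("&temporary_reference(buf0.get())") == "buf0"
assert strip_temporary_reference("&temporary_reference(buf1)") == "buf1"
assert strip_temporary_reference("buf2") == "buf2"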
+ return f"from({raii_var}.release())"
+ return f"from({codegen_arg})"
+
+ codegen_args = get_args()
+ ivalue_args = (
+ parse_arg(a.type, c)
+ for a, c in zip(op_overload._schema.arguments, codegen_args)
+ )
+ array_len = max(len(codegen_args), len(output_args))
+ dispatch_lines.writeline(
+ f"std::array<StableIValue, {array_len}> dispatch_vars{{{', '.join(ivalue_args)}}};"
+ )
+ dispatch_lines.writeline("AOTI_TORCH_ERROR_CODE_CHECK(")
+ with dispatch_lines.indent():
+ dispatch_lines.writeline(
+ f'aoti_torch_call_dispatcher("{op_overload._schema.name}", "{op_overload._schema.overload_name}", dispatch_vars.data())' # noqa: B950
+ )
+ dispatch_lines.writeline(");")
+
+ if len(output_args) == 1 and (output := output_args[0]) is not None:
+ # result is a single tensor
+ dispatch_lines.writeline(
+ f"{output} = to(dispatch_vars[0]);"
+ )
+ else:
+ # result is a tuple of tensors
+ for idx, output_arg in enumerate(output_args):
+ if output_arg is None:
+ continue
+ dispatch_lines.writeline(
+ f"{output_arg} = to(dispatch_vars[{idx}]);"
+ )
+
+ dispatch_lines.writeline("}")
+ self.writelines(dispatch_lines.getvalue().splitlines())
+
+ def generate_fallback_kernel_with_runtime_lookup_python(
+ self,
+ buf_name: str,
+ python_kernel_name: str,
+ op_overload: torch._ops.OpOverload,
+ raw_args: Sequence[Any],
+ output_args: Sequence[Optional[str]],
+ raw_outputs: Sequence[ir.Buffer],
+ ) -> None:
+ """Generate fallback kernel calls with runtime (non-AOT) dispatch. This can
+ only be called in cpp_wrapper mode, and assumes that the input is a non-None
+ OpOverload.
+
+ This function calls into Python to dispatch, which allows it to handle datatypes
+ that cannot be contained in StableIValue, at the cost of some performance."""
+ self.load_custom_op_wrapper()
+
+ num_args = len(raw_args)
+ py_args_var = f"py_args_{next(self.arg_var_id)}"
+ # First arg is always the python op name
+ lines = textwrap.dedent(
+ f"""
+ RAIIPyObject {py_args_var}(PyTuple_New({num_args + 1}));
+ if (!{py_args_var}) {{
+ throw std::runtime_error("PyTuple_New {py_args_var} failed");
+ }}
+ PyTuple_SetItem({py_args_var}, 0, PyUnicode_FromString("{python_kernel_name}"));
+ """
+ )
+
+ for idx, (raw_arg, schema_arg) in enumerate(
+ zip(raw_args, op_overload._schema.arguments)
+ ):
+ lines += self.generate_py_arg(
+ py_args_var, idx + 1, raw_arg, schema_arg.real_type
+ )
+
+ lines += textwrap.dedent(
+ f"""
+ // Call the custom op in Python
+ RAIIPyObject py_{buf_name}(PyObject_CallObject(custom_op_wrapper, {py_args_var}));
+ if (!py_{buf_name}) {{
+ if (PyErr_Occurred()) {{
+ return;
+ }}
+ throw std::runtime_error("PyObject_CallObject {python_kernel_name} failed");
+ }}
+ """
+ )
+
+ if len(output_args) == 1 and (output := output_args[0]) is not None:
+ # result is a single tensor
+ lines += f"{output} = reinterpret_cast<AtenTensorHandle>(PyCapsule_GetPointer(py_{buf_name}.get(), NULL));\n"
+ else:
+ # result is a tuple of tensors
+ for idx, output_arg in enumerate(output_args):
+ if output_arg is None:
+ continue
+ lines += f"{output_arg} = reinterpret_cast<AtenTensorHandle>(PyCapsule_GetPointer(PyList_GET_ITEM(py_{buf_name}.get(), {idx}), NULL));\n" # noqa: B950
+
+ if raw_outputs:
+ declarations_before_scope = [
+ f"RAIIAtenTensorHandle {output_arg};"
+ for output_arg, raw_output_arg in zip(output_args, raw_outputs) # type: ignore[arg-type]
+ if output_arg is not None
+ and not isinstance(raw_output_arg, ir.MutationOutput)
+ ]
+ else:
+ declarations_before_scope = [
+ f"RAIIAtenTensorHandle {output_arg};"
+ for output_arg in output_args # type: ignore[arg-type]
+ if output_arg is not
None + ] + scope_gil_acquire = self.generate_scoped_gil_acquire( + declarations_before_scope, lines + ) + self.writelines(scope_gil_acquire) + + def generate_fallback_kernel_with_runtime_lookup_aot( + self, + op_overload: Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator], + raw_args: Sequence[Any], + output_args: _OUTPUT_ARGS_TYPE, + raw_outputs: Sequence[ir.Buffer], + ) -> None: + ( + tensor_call_args, + int_call_args, + ) = self.generate_extern_kernel_args_decl_if_needed( + op_overload, + raw_args, + output_args, + raw_outputs, + ) + # force both temporary arrays to generate mutable data pointers, since the proxy + # executor signature requires that datatype + int_call_str = self._generate_temporary_array_pointer( + "int64_t", int_call_args, force_mutable=True + ) + tensor_call_str = self._generate_temporary_array_pointer( + "AtenTensorHandle", tensor_call_args, force_mutable=True + ) + + extern_kernel_node_index = len(V.graph.extern_kernel_nodes) - 1 + self.writeline( + f"aoti_torch_proxy_executor_call_function(proxy_executor, " + f"{extern_kernel_node_index}, " + f"{len(int_call_args)}, " + f"{int_call_str}, " + f"{len(tensor_call_args)}, " + f"{tensor_call_str});" + ) + + def generate_reset_kernel_saved_flags(self): + pass + + def generate_save_uncompiled_kernels(self): + pass + + def c_type_for_prim_type(self, val, type_) -> str: + if isinstance(type_, torch.OptionalType): + return f"{self.c_type_for_prim_type(val, type_.getElementType())}*" + elif isinstance(type_, torch.TensorType): + return "AtenTensorHandle" + elif isinstance(type_, (torch.IntType, torch.SymIntType)): + return "int64_t" + elif isinstance( + type_, (torch.BoolType, torch.SymBoolType, torch.EnumType) + ) or repr(type_) in ("Layout", "MemoryFormat", "ScalarType"): + return "int32_t" + elif isinstance(type_, torch.FloatType): + return "double" + elif isinstance(type_, torch.NumberType): + if isinstance(val, bool): + return "int32_t" + elif isinstance(val, (int, float)): + return "double" + elif val is None: + # This could happen when val is an optional value + return "double" + else: + raise AssertionError( + f"Unexpected type in c_type_for_prim_type: {type_=}" + ) + elif isinstance(type_, torch.StringType): + return "const char*" + else: + raise AssertionError(f"Unexpected type in c_type_for_prim_type: {type_=}") + + def val_to_arg_str_for_prim_type(self, val, type_) -> str: + # TODO: not using type_ as the first step of refactoring. Will update this later. 
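[Editor's note] The mapping that c_type_for_prim_type implements above is essentially a small lookup from schema types to C declaration types. Below is a simplified, torch-free sketch of that idea; the string type names and the helper c_type_for are purely illustrative (the real code dispatches on torch.JitType instances).

# Simplified view of the schema-type -> C-type mapping used by the wrapper.
PRIM_TYPE_TO_C = {
    "Tensor": "AtenTensorHandle",
    "int": "int64_t",
    "SymInt": "int64_t",
    "bool": "int32_t",
    "float": "double",
    "str": "const char*",
}

def c_type_for(schema_type: str, optional: bool = False) -> str:
    base = PRIM_TYPE_TO_C[schema_type]
    # Optional values are passed by pointer so that nullptr can encode "missing".
    return f"{base}*" if optional else base

assert c_type_for("int") == "int64_t"
assert c_type_for("Tensor", optional=True) == "AtenTensorHandle*"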
+ if isinstance(val, bool): + return "1" if val else "0" + elif isinstance(val, int): + # uint64_t is long on Linux, but long long on MacOS and Windows + return f"{val}LL" if sys.platform in ["darwin", "win32"] else f"{val}L" + elif isinstance(val, complex): + return f"c10::complex{{ {self.generate_float_value(val.real)}, {self.generate_float_value(val.imag)} }}" + elif isinstance(val, str): + return f'"{val}"' + elif isinstance( + val, (ir.Buffer, ir.ReinterpretView, ir.StorageBox, ir.TensorBox) + ): + return val.codegen_reference() + elif isinstance(val, torch.device): + return self.codegen_device(val) + elif isinstance(val, torch.dtype): + return self.codegen_dtype(val) + elif isinstance(val, torch.layout): + return self.codegen_layout(val) + elif isinstance(val, torch.memory_format): + return self.codegen_memory_format(val) + elif isinstance(val, float): + return self.generate_float_value(val) + elif isinstance(val, (list, tuple)): + # FIXME: This happens because type_ is not always properly set to torch.ListType + return f"{{{', '.join(self.val_to_arg_str(x, None) for x in val)}}}" + elif isinstance(val, SymTypes): + return cexpr(val.node.expr) + elif isinstance(val, sympy.Expr): + return cexpr(val) + else: + return repr(val) + + def val_to_arg_str(self, val, type_=None) -> str: + if val is None: + # None needs special care. It either represent nullopt or an empty tensor + if type_ is None or isinstance(type_, torch.OptionalType): + if type_ is not None and isinstance( + type_.getElementType(), + ( + torch.DeviceObjType, + torch.ListType, + torch.TupleType, + ), + ): + return "nullptr, 0" + return "nullptr" + + if isinstance(type_, torch.TensorType): + # create an empty tensor, the equivalent of at::Tensor() + var_name = f"var_{next(self.arg_var_id)}" + self.writeline(f"AtenTensorHandle {var_name}_handle;") + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&{var_name}_handle));" + ) + self.writeline(f"RAIIAtenTensorHandle {var_name}({var_name}_handle);") + return var_name + + raise AssertionError("Can not map None to a known data type") + + if isinstance(type_, torch.OptionalType): + element_type = type_.getElementType() + arg_str = self.val_to_arg_str(val, element_type) + # Handle optional iterables as a special case. Utilize the + # temporary_reference function to avoid saving them off and increasing + # memory usage. + if isinstance(element_type, (torch.ListType, torch.TupleType)): + main_value, aux = arg_str.rsplit(", ", maxsplit=1) + return f"&temporary_reference({main_value}), {aux}" + + # Handle optional tensors as a special case, as above. + if isinstance(element_type, torch.TensorType): + base_handle = self.val_to_arg_str(val, element_type) + return f"&temporary_reference({base_handle}.get())" + + var_name = f"var_{next(self.arg_var_id)}" + if isinstance(element_type, torch.DeviceObjType): + main_value, aux = arg_str.rsplit(", ", maxsplit=1) + self.writeline(f"auto {var_name} = {main_value};") + return f"&{var_name}, {aux}" + + self.writeline( + f"{self.c_type_for_prim_type(val, element_type)} {var_name} = {arg_str};" + ) + return f"&{var_name}" + + if isinstance(type_, (torch.ListType, torch.TupleType)): + assert isinstance(val, (list, tuple)), ( + f"{val} does not match with arg type {type_}" + ) + element_type = type_.getElementType() + + if len(val) == 0: + # Zero-size array is not supported in the C or C++ standard, so return a + # nullptr. 
+ return "nullptr, 0" + + result = [self.val_to_arg_str(x, element_type) for x in val] + if isinstance(element_type, torch.TensorType): + result = [f"{t}.get()" for t in result] + + c_type = self.c_type_for_prim_type(val[0], element_type) + # see the comment in self._generate_temporary_array_pointer for an + # explanation of why this c_type gets modified + if isinstance(element_type, torch.OptionalType) and not c_type.startswith( + "const" + ): + c_type = f"const {c_type}" + + # need to pass the array length, because we can't use the std::array member + # function + return ( + f"{self._generate_temporary_array_pointer(c_type, result)}, {len(val)}" + ) + + val_is_scalar = isinstance(val, (bool, complex, float, int, *SymTypes)) + if isinstance(type_, torch.TensorType) and val_is_scalar: + val_str = self.val_to_arg_str_for_prim_type(val, None) + return self.codegen_scalar_to_tensor(val_str) + + return self.val_to_arg_str_for_prim_type(val, type_) + + def create_tmp_raii_handle_var_if_needed( + self, handle: str, writer: Optional[Union[HasWriteLine, list[str]]] = None + ) -> str: + """If the input handle is an rvalue RAII tensor, creates an lvalue variable for + it in writer. Returns a variable name that can be used to access handle.""" + if not handle.startswith( + ( + "borrow_arrayref_tensor_as_tensor(", + "copy_arrayref_tensor_to_tensor(", + "wrap_with_raii_handle_if_needed(", + "RAIIAtenTensorHandle(", + ) + ): + return handle + + tmp_var_name = f"var_{next(self.arg_var_id)}" + call_str = f"auto {tmp_var_name} = {handle};" + + writer = writer if writer is not None else self + if isinstance(writer, list): + writer.append(call_str) + else: + writer.writeline(call_str) + + return tmp_var_name diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py new file mode 100644 index 0000000000000000000000000000000000000000..0d53db7f32c664842fcec791e58f595ffa4985cf --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py @@ -0,0 +1,878 @@ +# mypy: allow-untyped-defs +from collections.abc import Sequence +from typing import Any, Callable, Optional, Union + +import sympy + +import torch +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +import torch._ops + +from .. import config, ir +from ..utils import sympy_product +from ..virtualized import V +from .cpp_utils import DTYPE_TO_CPP +from .cpp_wrapper_cpu import CppWrapperCpu +from .wrapper import ( + BufferLike, + EnterSubgraphLine, + ExitSubgraphLine, + MemoryPlanningLine, + MemoryPlanningState, + PythonWrapperCodegen, +) + + +BufferName = str + +# Default thread stack sizes vary by platform: +# - Linux: 8 MB +# - macOS: 512 KB +# - Windows: 1 MB +# Just pick something comfortably smaller than the smallest for now. +MAX_STACK_ALLOCATION_SIZE = 1024 * 100 + + +class CppWrapperCpuArrayRef(CppWrapperCpu): + """ + Generates cpp wrapper for running on CPU and calls cpp kernels + + This class is forked from CppWrapperCpu, with a difference that tensors may be + represented as ArrayRef, see torch/csrc/inductor/aoti_runtime/arrayref_tensor.h + """ + + def __init__(self): + super().__init__() + assert self.device == "cpu", "ArrayRefTensor only supported on CPU!" 
+ self.allow_stack_allocation = config.aot_inductor.allow_stack_allocation + self.stack_allocated_buffers: dict[BufferName, BufferLike] = {} + + @staticmethod + def create( + is_subgraph: bool, + subgraph_name: Optional[str], + parent_wrapper: Optional[PythonWrapperCodegen], + partition_signatures: Optional[ir.GraphPartitionSignature] = None, + ): + # TODO - support subgraph codegen by lifting functions. Check the + # comment at CppWrapperCpu `codegen_subgraph` function. + return CppWrapperCpuArrayRef() + + @staticmethod + def get_input_cpp_type(input): + assert config.aot_inductor.use_minimal_arrayref_interface + + if isinstance(input, sympy.Expr): + from ..graph import may_get_constant_buffer_dtype + + dtype = may_get_constant_buffer_dtype(input) + assert dtype is not None, f"Failed to get the dtype of sympy.Expr: {input}" + return DTYPE_TO_CPP[dtype] + return f"ArrayRefTensor<{DTYPE_TO_CPP[input.get_dtype()]}>" + + @staticmethod + def get_device_include_path(device: str) -> str: + assert device == "cpu", "ArrayRef only supported on CPU!" + if V.graph.aot_mode: + return "#include " + return "#include " + + def codegen_input_numel_asserts(self): + for name, buf in V.graph.graph_inputs.items(): + if isinstance(buf, sympy.Expr): + continue + + # comparing strides for 0 size tensor is tricky. Ignore them for now. + if sympy_product(buf.get_size()) == 0: + continue + numel = buf.get_numel() + self.prefix.writeline(f"assert_numel({name}, {numel});") + + def generate_extern_kernel_alloc(self, *args, **kwargs): + # Disable stack allocation for extern kernels. + self.allow_stack_allocation = False + super().generate_extern_kernel_alloc(*args, **kwargs) + + def generate_extern_kernel_out(self, *args, **kwargs): + # Disable stack allocation for extern kernels. + self.allow_stack_allocation = False + super().generate_extern_kernel_out(*args, **kwargs) + + def generate_fallback_kernel(self, node: ir.FallbackKernel) -> None: + # Disable stack allocation for extern kernels. + self.allow_stack_allocation = False + super().generate_fallback_kernel(node) + + def _generate_kernel_call_helper( + self, + kernel_name: str, + call_args, + *, + device=None, + triton=True, + arg_types=None, + raw_keys=None, + raw_args=None, + triton_meta=None, + graph_name="", + original_fxnode_name=None, + ): + """ + Generates kernel call code. + + triton: Defines whether the GPU backend uses Triton for codegen. + Otherwise it uses the CUDA language for codegen. + Only valid when cuda == True. + """ + assert not triton, ( + "CppWrapperCpuArrayRef.generate_kernel_call does not support GPU" + ) + assert arg_types is not None and len(call_args) == len(arg_types), ( + "Mismatch call_args and arg_types in generate_kernel_call" + ) + new_args = [] + for idx, arg in enumerate(call_args): + if "*" in arg_types[idx]: + var_name = f"var_{next(self.arg_var_id)}" + self.writeline(f"auto* {var_name} = get_data_ptr_wrapper({arg});") + new_args.append(f"({arg_types[idx]})({var_name})") + else: + # arg is a scalar + new_args.append(arg) + # debug printer related logic for cpp kernel type. 
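[Editor's note] A rough, standalone sketch of the argument rewriting performed in _generate_kernel_call_helper above, using plain strings. The function rewrite_cpu_kernel_args is invented for illustration, and the emitted get_data_ptr_wrapper declarations are only modeled as text here.

from itertools import count

def rewrite_cpu_kernel_args(call_args, arg_types, var_ids):
    """Return (decl_lines, new_args): pointer args get unwrapped to raw data pointers."""
    decls, new_args = [], []
    for arg, arg_type in zip(call_args, arg_types):
        if "*" in arg_type:
            var_name = f"var_{next(var_ids)}"
            decls.append(f"auto* {var_name} = get_data_ptr_wrapper({arg});")
            new_args.append(f"({arg_type})({var_name})")
        else:
            new_args.append(arg)  # scalars pass through unchanged
    return decls, new_args

decls, args = rewrite_cpu_kernel_args(["buf0", "16"], ["const float*", "int64_t"], count())
assert decls == ["auto* var_0 = get_data_ptr_wrapper(buf0);"]
assert args == ["(const float*)(var_0)", "16"]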
+ debug_printer_manager = V.graph.wrapper_code.debug_printer + debug_printer_manager.set_printer_args( + call_args, + kernel_name, + None, + None, + "cpp", + ) + with debug_printer_manager: + self.writeline(self.wrap_kernel_call(kernel_name, new_args)) + + def write_wrapper_decl(self): + inputs_len = len(V.graph.graph_inputs.keys()) + if V.graph.aot_mode: + if ( + config.aot_inductor.use_minimal_arrayref_interface + and not V.graph.is_const_graph + ): + input_cpp_types = ", ".join( + f"{CppWrapperCpuArrayRef.get_input_cpp_type(x)}" + for x in V.graph.graph_inputs.values() + ) + output_arrayref_types = ", ".join( + f"ArrayRefTensor<{DTYPE_TO_CPP[x.get_dtype()]}>" + for x in V.graph.graph_outputs + ) + + self.prefix.splice( + f""" + using AOTInductorModelInputs = std::tuple<{input_cpp_types}>; + using AOTInductorModelOutputs = std::tuple<{output_arrayref_types}>; + """ + ) + + if V.graph.const_module: + self.header.splice(V.graph.const_module.wrapper_code.header) + + assert V.graph.const_wrapper_code is not None + self.prefix.splice(V.graph.const_wrapper_code) + + assert V.graph.const_kernel_code is not None + self.kernel_declarations.splice(V.graph.const_kernel_code) + + if V.graph.is_const_graph: + self.prefix.splice( + """ + void AOTInductorModel::_const_run_impl( + std::vector& output_handles, + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor + ) { + """ + ) + else: + if not config.aot_inductor.use_runtime_constant_folding: + # If we do not split the constant graph, we'll just create + # an empty implementation when wrapping the main module. + self.prefix.splice( + """ + void AOTInductorModel::_const_run_impl( + std::vector& output_handles, + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor + ) {} + + """ + ) + + run_impl_proto = """ + void AOTInductorModel::run_impl( + AtenTensorHandle* + input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + AtenTensorHandle* + output_handles, // array for writing output AtenTensorHandle; handles + // will be stolen by the caller; the array itself is + // borrowed + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor + ) { + """ + + self.generate_input_output_runtime_checks() + run_impl_proto += """ + __check_inputs_outputs(input_handles, output_handles); + """ + + if config.aot_inductor.use_minimal_arrayref_interface: + self.prefix.splice( + """ + template <> + AOTInductorModelOutputs AOTInductorModel::run_impl_minimal_arrayref_interface< + AOTInductorModelInputs, AOTInductorModelOutputs>( + const AOTInductorModelInputs& inputs, + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor + ) { + """ + ) + self.suffix.splice(run_impl_proto) + self.suffix.splice( + """ + AOTInductorModelInputs inputs; + convert_handles_to_inputs(input_handles, inputs); + auto outputs = run_impl_minimal_arrayref_interface( + inputs, stream, proxy_executor); + // NOTE: outputs is full of ArrayRef to thread_local storage. If in the future we need this + // interface to perform well for a DSO using the minimal arrayref interface, all we need + // to do is provide ThreadLocalCachedTensor for each one! 
+ convert_outputs_to_handles(outputs, output_handles); + } + """ + ) + + self.suffix.splice( + """ + extern "C" AOTIRuntimeError AOTInductorModelRunMinimalArrayrefInterface( + AOTInductorModelHandle model_handle, + const AOTInductorModelInputs& inputs, + AOTInductorModelOutputs& outputs) { + auto model = reinterpret_cast(model_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + outputs = model->run_impl_minimal_arrayref_interface( + inputs, + (torch::aot_inductor::DeviceStreamType)nullptr, + nullptr); + }) + } + """ + ) + else: + self.prefix.splice(run_impl_proto) + else: + # cpp entry function for JIT with cpp wrapper + self.prefix.splice( + """ + void inductor_entry_impl( + AtenTensorHandle* + input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + AtenTensorHandle* + output_handles // array for writing output AtenTensorHandle; handles + // will be stolen by the caller; the array itself is + // borrowed) + ) { + """ + ) + with self.prefix.indent(): + # assign inputs and outputs in both cases so the later codegen can be simplified + if not config.aot_inductor.use_minimal_arrayref_interface: + if not V.graph.is_const_graph: + if V.graph.aot_mode: + num_args = len(V.graph.graph_inputs) + else: + # Weights are promoted in the JIT mode + num_args = len(V.graph.graph_inputs) + len(V.graph.constants) + # release GIL to support multiple instances inference (in different threads of the same process) + self.prefix.splice("py::gil_scoped_release release;") + + self.prefix.splice( + f""" + auto inputs = steal_from_raw_handles_to_raii_handles(input_handles, {num_args}); + """ + ) + + if inputs_len != 0: + for idx, input_key in enumerate(V.graph.graph_inputs.keys()): + if config.aot_inductor.use_minimal_arrayref_interface: + self.prefix.writeline( + f"auto {input_key} = std::get<{idx}>(inputs);" + ) + continue + # unwrap input tensor back to scalar + if isinstance(V.graph.graph_inputs[input_key], sympy.Expr): + from ..graph import may_get_constant_buffer_dtype + + dtype = may_get_constant_buffer_dtype( + V.graph.graph_inputs[input_key] # type: ignore[arg-type] + ) + assert dtype is not None, ( + "Fails to get the dtype of the sympy.Expr" + ) + self.codegen_tensor_item( + dtype, f"inputs[{idx}]", input_key, self.prefix + ) + else: + self.prefix.writeline( + f"auto {input_key} = std::move(inputs[{idx}]);" + ) + + assert all( + isinstance(v, torch.Tensor) for v in list(V.graph.constants.values()) + ), "Expect all constants to be Tensor" + for idx, constants_key in enumerate(V.graph.constants.keys()): + if V.graph.aot_mode: + # Weights are stored in constants_ and owned by RAIIAtenTensorHandle there. + # Don't call std::move here because it will cause constants_ to lose the ownership. + self.prefix.writeline( + f"""auto {constants_key} = constants_->at({idx});""" + ) + else: + # Append constants as inputs to the graph + constants_idx = inputs_len + idx + self.prefix.writeline( + f"auto {constants_key} = std::move(inputs[{constants_idx}]);" + ) + + self.codegen_inputs() + + if V.graph.aot_mode: + if not V.graph.is_const_graph: + if config.aot_inductor.use_minimal_arrayref_interface: + # TODO: input shape checking for regular tensor interface as well? 
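[Editor's note] The input numel assertions emitted just below boil down to comparing a product of the compile-time shape against the runtime element count. The Python function here is an illustrative stand-in for the C++ assert_numel helper that the generated code calls; it is not part of the patch.

import math

def assert_numel(runtime_numel: int, expected_shape) -> None:
    """Raise if the runtime element count does not match the compiled-in shape."""
    expected = math.prod(expected_shape)
    if runtime_numel != expected:
        raise RuntimeError(f"expected {expected} elements, got {runtime_numel}")

assert_numel(6, (2, 3))   # ok
assert_numel(1, ())       # scalar tensors have numel 1
try:
    assert_numel(5, (2, 3))
except RuntimeError:
    pass                  # mismatch is reported at run time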
+ self.codegen_input_numel_asserts() + else: + self.prefix.writeline("inputs.clear();") + self.prefix.writeline( + "[[maybe_unused]] auto& kernels = static_cast(*this->kernels_.get());" + ) + + def generate_return(self, output_refs: list[str]): + cst_names = V.graph.constants.keys() + arr_iface = ( + not V.graph.is_const_graph + and config.aot_inductor.use_minimal_arrayref_interface + ) # For brevity. + + def use_thread_local_cached_output_tensor(idx, output): + cached_output_name = f"cached_output_{next(self.cached_output_id)}" + cache_type = "Array" if arr_iface else "Tensor" + self.wrapper_call.writeline( + f"thread_local ThreadLocalCachedOutput{cache_type}> " + f"{cached_output_name}({output});" + ) + if arr_iface: + self.wrapper_call.writeline( + f"{cached_output_name}.copy_data_from({output});" + ) + output_entry = f"std::get<{idx}>(output_arrayref_tensors)" + element_type = f"std::decay_t" + self.wrapper_call.writeline( + f"{output_entry} = {cached_output_name}.arrayref_tensor<{element_type}>();" + ) + else: + self.wrapper_call.writeline( + f"{cached_output_name}.copy_data_from({output});" + ) + self.wrapper_call.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&output_handles[{idx}]));" + ) + self.wrapper_call.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors({cached_output_name}.tensor(), " + f"output_handles[{idx}]));" + ) + + if arr_iface: + self.wrapper_call.writeline( + "AOTInductorModelOutputs output_arrayref_tensors;" + ) + + output2idx: dict[str, int] = {} + for idx, output in enumerate(output_refs): + if output == "nullptr": + continue + + is_constant_buffer = output in cst_names + output_buffer = V.graph.graph_outputs[idx] + if isinstance(output_buffer, ir.BaseView): + output_storage = output_buffer.unwrap_view() + if isinstance(output_storage.data, ir.ConstantBuffer): + is_constant_buffer = True + + if isinstance(output_buffer, ir.ShapeAsConstantBuffer): + # Need to wrap scalar into tensor as the main function returns a vector of tensors + output_tensor = self.codegen_scalar_to_tensor(output) + self.wrapper_call.writeline( + f"output_handles[{idx}] = {output_tensor}.release();" + ) + continue + + output_is_tensor_handle_expr = ( + f"std::is_same_v," + "RAIIAtenTensorHandle> || " + f"std::is_same_v," + "AtenTensorHandle> || " + f"std::is_same_v," + "ConstantHandle>" + ) + self.wrapper_call.writeline( + f"if constexpr ({output_is_tensor_handle_expr}) {{" + ) + with self.wrapper_call.indent(): + if arr_iface: + cached_output_name = f"cached_output_{next(self.cached_output_id)}" + self.wrapper_call.writeline( + f"thread_local RAIIAtenTensorHandle {cached_output_name};" + ) + if is_constant_buffer: + # NOTE(return_constant): In some rare cases where we return + # a constant, we have to return a copy of this constant, + # because (1) constants are not owned by the Model instance + # (2) constants remain the same cross inference runs, + # assuming they are not updated at runtime Basically, we + # cannot release or transfer the ownership of any original + # constant to the user. 
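[Editor's note] A compact sketch of the output-handling decision described in the ownership note above, reduced to the two emitted C++ statements. The strings are purely illustrative; the real generate_return also handles the arrayref interface and duplicate outputs.

def emit_output_line(idx: int, output: str, is_constant: bool) -> str:
    """Constants must be cloned (the model keeps ownership); other outputs are released."""
    if is_constant:
        return f"aoti_torch_clone({output}, &output_handles[{idx}]);"
    return f"output_handles[{idx}] = {output}.release();"

assert emit_output_line(0, "buf3", False) == "output_handles[0] = buf3.release();"
assert emit_output_line(1, "constant_w", True) == (
    "aoti_torch_clone(constant_w, &output_handles[1]);"
)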
+ self.wrapper_call.writeline( + f"AtenTensorHandle {cached_output_name}_tmp;" + ) + self.wrapper_call.writeline( + f"aoti_torch_clone({output}, &{cached_output_name}_tmp);" + ) + self.wrapper_call.writeline( + f"{cached_output_name} = {cached_output_name}_tmp;" + ) + else: + self.wrapper_call.writeline( + f"{cached_output_name} = {output}.release();" + ) + self.wrapper_call.writeline( + f"convert_handle_to_arrayref_tensor({cached_output_name}, " + f"std::get<{idx}>(output_arrayref_tensors));" + ) + else: + if is_constant_buffer: + # See NOTE(return_constant) above. + self.wrapper_call.writeline( + f"aoti_torch_clone({output}, &output_handles[{idx}]);" + ) + else: + if output in output2idx: + src_idx = output2idx[output] + self.wrapper_call.writeline( + f"output_handles[{idx}] = output_handles[{src_idx}];" + ) + else: + self.wrapper_call.writeline( + f"output_handles[{idx}] = {output}.release();" + ) + self.wrapper_call.writeline("} else {") + with self.wrapper_call.indent(): + use_thread_local_cached_output_tensor(idx, output) + self.wrapper_call.writeline("}") + + if output not in output2idx: + output2idx[output] = idx + if arr_iface: + self.wrapper_call.writeline("return output_arrayref_tensors;") + + def memory_plan(self): + from .memory_planning import MemoryPlanner + + self.lines = MemoryPlanner(self).plan(self.lines) + # TODO: integrate memory planning & stack allocation? + self.allow_stack_allocation = False + + def memory_plan_reuse(self): + out_names = V.graph.get_output_names() + + while ( + self.lines + and isinstance(self.lines[-1], MemoryPlanningLine) + # TODO: this seems legit, NullLine has no node + and self.lines[-1].node.name not in out_names # type: ignore[attr-defined] + ): + # these lines will be pointless + self.lines.pop() + + # codegen allocations in two passes + planning_states = [MemoryPlanningState()] + past_planning_states = [] + for i in range(len(self.lines)): + line = self.lines[i] + if isinstance(line, MemoryPlanningLine): + self.lines[i] = line.plan(planning_states[-1]) + elif isinstance(line, EnterSubgraphLine): + planning_states.append(MemoryPlanningState()) + elif isinstance(line, ExitSubgraphLine): + past_planning_states.append(planning_states.pop()) + past_planning_states.append(planning_states.pop()) + assert len(planning_states) == 0 + + # conservatively use the sum of all allocated buffer sizes + # in potentially nested scopes as the total allocated size + total_allocated_buffer_size = sum( + s.total_allocated_buffer_size for s in past_planning_states + ) + + self.allow_stack_allocation = ( + self.allow_stack_allocation is not False + and config.aot_inductor.allow_stack_allocation + and total_allocated_buffer_size <= MAX_STACK_ALLOCATION_SIZE + ) + + def can_stack_allocate_buffer(self, buffer): + return ( + self.allow_stack_allocation + and buffer.get_device().type == "cpu" + and self.can_prove_buffer_has_static_shape(buffer) + and ir.is_contiguous_strides_for_shape( + buffer.get_stride(), buffer.get_size() + ) + ) + + def make_buffer_free(self, buffer): + return ( + "" + if isinstance(buffer.get_output_spec(), ir.MultiOutputLayout) + or (V.graph.aot_mode and buffer.get_name() in self.stack_allocated_buffers) + or ( + config.aot_inductor.use_minimal_arrayref_interface + and V.graph.aot_mode + and buffer.get_name() in V.graph.graph_inputs + ) + else f"{buffer.get_name()}.reset();" + ) + + def make_buffer_allocation(self, buffer): + return self.make_allocation( + buffer.get_name(), + buffer.get_device(), + buffer.get_dtype(), + buffer.get_size(), + 
buffer.get_stride(), + buffer if self.can_stack_allocate_buffer(buffer) else None, + ) + + def make_allocation( + self, name, device, dtype, shape, stride, buffer_if_can_stack_allocate=None + ): + orig_stride = stride + device_str = self.codegen_device(device) + dtype_code = self.codegen_dtype(dtype) + size = self.codegen_shape_tuple(shape) + stride = self.codegen_shape_tuple(orig_stride) + size_array_var = self.codegen_int_array_var( + size, + self.wrapper_call.writeline, + known_statically=self.is_statically_known_list_of_ints(shape), + graph=self.get_codegened_graph(), + ) + stride_array_var = self.codegen_int_array_var( + stride, + self.wrapper_call.writeline, + known_statically=self.is_statically_known_list_of_ints(orig_stride), + graph=self.get_codegened_graph(), + ) + device_type, device_id = device_str.split(",") + device_idx = "this->device_idx_" if V.graph.aot_mode else device_id + if buffer_if_can_stack_allocate is not None: + self.stack_allocated_buffers[name] = buffer_if_can_stack_allocate + cpp_type = DTYPE_TO_CPP[dtype] + numel = buffer_if_can_stack_allocate.get_numel() + # Note: we don't zero storage because empty_strided doesn't zero either. + self.wrapper_call.writeline(f"{cpp_type} {name}_storage[{numel}];") + args = [ + f"{name}_storage", + size_array_var, + stride_array_var, + device_type, + device_idx, + ] + return f"ArrayRefTensor<{cpp_type}> {name}({', '.join(args)});" + + args = [ + str(len(shape)), + size_array_var, + stride_array_var, + dtype_code, + device_type, + device_idx, + f"&{name}_handle", + ] + + self.wrapper_call.writeline(f"AtenTensorHandle {name}_handle;") + self.wrapper_call.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided({', '.join(args)}));" + ) + + return f"RAIIAtenTensorHandle {name}({name}_handle);" + + def make_buffer_reuse(self, old: BufferLike, new: BufferLike, delete_old: bool): + assert old.get_dtype() == new.get_dtype() + old_name = old.get_name() + new_name = new.get_name() + del_line = ";" + if old_name not in V.graph.get_output_names() and delete_old: + del_line = f"; {self.make_buffer_free(old)}" + + if old.get_size() == new.get_size() and old.get_stride() == new.get_stride(): + if old_name in self.stack_allocated_buffers: + self.stack_allocated_buffers[new_name] = new + return self.codegen_exact_buffer_reuse(old_name, new_name, del_line) + + reinterpret_view = self.codegen_reinterpret_view( + old, new.get_size(), new.get_stride(), 0, self.wrapper_call.writeline + ) + if reinterpret_view in self.stack_allocated_buffers: + self.stack_allocated_buffers[new_name] = new + # The only way to get into this case is via an exact buffer reuse, since all + # other options result in a new tensor handle. + return self.codegen_exact_buffer_reuse(old_name, new_name, del_line) + return f"{self.declare}{new_name} = {reinterpret_view}{del_line} // reuse" + + def _assert_safe_to_use_borrow_arrayref_tensor_as_tensor(self): + # Borrowing arguments to shim functions is only safe because we know + # that the arguments can't be stack-allocated. Otherwise, to be sure + # we can't return a dangling pointer, we need to either 1) be + # certain that the shim function cannot return an alias of a + # borrowed argument, or 2) be certain that the returned Tensor from + # the shim function cannot escape. + assert self.is_safe_to_use_borrow_arrayref_tensor_as_tensor(), ( + "borrowing arguments to shim functions is unsafe with " + "stack allocation on! 
(see comment above this assertion)" + ) + + def is_safe_to_use_borrow_arrayref_tensor_as_tensor(self): + return not self.allow_stack_allocation and not self.stack_allocated_buffers + + def generate_c_shim_extern_kernel_call( + self, kernel: str, args: list[str], device: str, **_ + ) -> None: + # In the abi_compatible mode, we call fallback aten ops through a C shim layer + # Setting self.allow_stack_allocation to False because the exchange between + # ArrayRefTensor and at::Tensor is still fragile. + self.allow_stack_allocation = False + + wrapped_args = [] + for arg in args: + # We only really *need* borrow_arrayref_tensor_as_tensor for + # ArrayRefTensors. The code flowing into here uses `0` for nullptr, which + # borrow_arrayref_tensor_as_tensor would blindly coerce to int, so just + # avoid wrapping integers. Name matching is to find tensor is hacky, but + # fixing all the ArrayRefTensor issues is not a priority for now. + if isinstance(arg, str) and arg.startswith( + ("buf", "arg", "wrap_with_raii_handle_if_needed") + ): + self._assert_safe_to_use_borrow_arrayref_tensor_as_tensor() + arg = f"borrow_arrayref_tensor_as_tensor({arg})" + wrapped_args.append(arg) + + super().generate_c_shim_extern_kernel_call( + kernel, wrapped_args, device, debug_args=args + ) + + def generate_scatter_fallback( + self, + output, + inputs, + cpp_kernel_name, + python_kernel_name, + src_is_tensor, + reduce, + kwargs, + ): + # No stack allocation when there is a fallback op + self.allow_stack_allocation = False + + # call the ABI shim function instead of the ATen one + cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name, self.device) + # TODO: consider remove "_out" and add missing inplace variants to fallback_ops.py + cpp_kernel_name = cpp_kernel_name.replace("__", "_") + "_out" + self._assert_safe_to_use_borrow_arrayref_tensor_as_tensor() + inputs_wrapped = [ + (f"borrow_arrayref_tensor_as_tensor({x})" if isinstance(x, str) else str(x)) + for x in inputs + ] + line = f"{cpp_kernel_name}(borrow_arrayref_tensor_as_tensor({output}), {','.join(inputs_wrapped)}" + + if python_kernel_name.startswith("aten.scatter_reduce"): + line += f", {','.join(kwargs)}" + else: + if src_is_tensor: + if reduce: + line += f", {V.graph.wrapper_code.val_to_arg_str(reduce)}" + else: + assert reduce is None, ( + "Expect reduce to be None for aten.scatter_ with scalar src" + ) + line += ");" + self.writeline(line) + + def generate_index_put_fallback(self, kernel, x, indices, values, accumulate): + # No stack allocation when there is a fallback op + self.allow_stack_allocation = False + + self._assert_safe_to_use_borrow_arrayref_tensor_as_tensor() + # TODO: update aoti_torch_index_put_out in ir.py to use autogen out version + # See the comment in codegen_reinterpret_view about why having something like + # RAIIAtenTensorHandle(tmp_tensor_handle_2) in a tmp array can cause the corresponding + # tensor prematurely deallocated, thus the temporary array trick here. + indices_str = self._generate_temporary_array_pointer( + "AtenTensorHandle", + [f"borrow_arrayref_tensor_as_tensor({i})" for i in indices], + ) + args = [ + f"borrow_arrayref_tensor_as_tensor({x})", + indices_str, + str(len(indices)), + f"borrow_arrayref_tensor_as_tensor({values})", + accumulate, + ] + args.insert( + 0, f"borrow_arrayref_tensor_as_tensor({x})" + ) # set x as the output tensor, this fallback mutates x. 
+ self.writeline(self.wrap_kernel_call(kernel, args)) + + def generate_fallback_kernel_with_runtime_lookup( + self, + buf_name: str, + python_kernel_name: str, + get_args: Callable[[], Sequence[str]], + op_overload: Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator], + raw_args: Sequence[Any], + outputs: Sequence[ir.Buffer], + ) -> None: + # No stack allocation when there is a fallback op + self.allow_stack_allocation = False + super().generate_fallback_kernel_with_runtime_lookup( + buf_name, python_kernel_name, get_args, op_overload, raw_args, outputs + ) + + def codegen_device_copy(self, src, dst, non_blocking: bool): + # aoti_torch_tensor_copy_ takes AtenTensorHandle as input, + # while stack-allocation results in ArrayRefTensor + # so disable stack allocation here + self.allow_stack_allocation = False + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_copy_(expensive_copy_to_tensor_if_needed({dst}), {src}, {non_blocking}));" + ) + + def codegen_reinterpret_view( + self, + data, + size, + stride, + offset, + writeline: Callable[..., None], + dtype=None, + ) -> str: + """Returns a newly-created, temporary RAII tensor handle containing the + reinterpreted tensor data. Callers of this function are responsible for saving + the handle if persistent access is needed.""" + dim = str(len(size)) + + def create_reinterpret_call() -> str: + args = [ + f"{data.get_name()}", + dim, + self.codegen_int_array_var( + self.codegen_shape_tuple(size), + writeline, + known_statically=self.is_statically_known_list_of_ints(size), + graph=self.get_codegened_graph(), + ), + self.codegen_int_array_var( + self.codegen_shape_tuple(stride), + writeline, + known_statically=self.is_statically_known_list_of_ints(stride), + graph=self.get_codegened_graph(), + ), + offset, + ] + return f"wrap_with_raii_handle_if_needed(reinterpret_tensor_wrapper({', '.join(args)}))" + + def create_new_tensor_handle() -> tuple[str, list[str]]: + # Calling reset() on ArrayRefTensor does nothing, since the array is + # const-allocated on the stack. Thus, it's safe to return a reference to + # the original array. + if (name := data.get_name()) in self.stack_allocated_buffers: + return name, [] + + tmp_AtenTensorHandle = f"tmp_{name}_{next(self.tmp_tensor_id)}" + tmp_call_strs = [ + f"AtenTensorHandle {tmp_AtenTensorHandle};", + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_tensor_handle({data.get_name()}, &{tmp_AtenTensorHandle}));", + ] + return f"RAIIAtenTensorHandle({tmp_AtenTensorHandle})", tmp_call_strs + + if ( + size == data.layout.size + and stride == data.layout.stride + and offset == data.layout.offset + and (dtype is None or dtype == data.dtype) + ): + final_tensor_str, call_strs = create_new_tensor_handle() + for line in call_strs: + writeline(line) + return final_tensor_str + + return super().codegen_reinterpret_view( + data, size, stride, offset, writeline, dtype + ) + + def val_to_arg_str(self, val, type_=None) -> str: + if ( + val is not None + and isinstance(type_, torch.OptionalType) + and isinstance(type_.getElementType(), torch.TensorType) + ): + # Handle optional tensors as a special case, as in the parent class. 
+ base_handle = self.val_to_arg_str(val, torch.TensorType) + if config.aot_inductor.use_minimal_arrayref_interface: + if self.is_safe_to_use_borrow_arrayref_tensor_as_tensor(): + base_handle = f"borrow_arrayref_tensor_as_tensor({base_handle})" + else: + base_handle = f"copy_arrayref_tensor_to_tensor({base_handle})" + return f"&temporary_reference({base_handle}.get())" + + return super().val_to_arg_str(val, type_) + + def codegen_tensor_item( + self, dtype: torch.dtype, tensor: str, scalar: str, indented_buffer=None + ): + dtype_str = str(dtype).split(".")[-1] + writer = indented_buffer or self + + if dtype == torch.float16 or dtype == torch.bfloat16: + scalar_tmp = f"{scalar}_tmp" + writer.writeline(f"{DTYPE_TO_CPP[dtype]} {scalar_tmp};") + + # We know that item_ doesn't alias the input, so borrowing should be safe. + tensor = f"borrow_arrayref_tensor_as_tensor({tensor})" + + writer.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar_tmp}));" + ) + writer.writeline(f"float {scalar} = float({scalar_tmp});") + else: + writer.writeline(f"{DTYPE_TO_CPP[dtype]} {scalar};") + + # We know that item_ doesn't alias the input, so borrowing should be safe. + tensor = f"borrow_arrayref_tensor_as_tensor({tensor})" + + writer.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar}));" + ) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_gpu.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_gpu.py new file mode 100644 index 0000000000000000000000000000000000000000..430511ce4ebf08a6dfcb39c74e6db4a7841e6486 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_gpu.py @@ -0,0 +1,717 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import dataclasses +import re +from itertools import count, zip_longest +from typing import Any, Optional, Union +from typing_extensions import Self + +import sympy + +import torch +from torch import dtype as torch_dtype +from torch._inductor.codecache import get_cpp_wrapper_cubin_path_name +from torch._inductor.runtime.runtime_utils import dynamo_timed + +from .. import config +from ..codecache import CudaKernelParamCache +from ..ir import ( + GraphPartitionSignature, + TensorBox, + TMADescriptorExperimental, + TMADescriptorStable, +) +from ..utils import cache_on_self, get_gpu_type, GPU_ALIGN_BYTES, IndentedBuffer +from ..virtualized import V +from .aoti_hipify_utils import maybe_hipify_code_wrapper +from .common import get_device_op_overrides, TritonScratchWorkspace +from .cpp_utils import cexpr +from .cpp_wrapper_cpu import CppWrapperCpu +from .multi_kernel import MultiKernelCall +from .triton_utils import should_unwrap_unspec_arg +from .wrapper import PythonWrapperCodegen, SymbolicCallArg + + +_cpp_string_literal_escapes = { + "\\": "\\\\", + '"': '\\"', + "\n": "\\n", + "\t": "\\t", + "\r": "\\r", +} +_cpp_string_literal_pattern = re.compile(r'["\\\n\t\r]') + + +def cpp_string_literal(s: str) -> str: + escaped = _cpp_string_literal_pattern.sub( + lambda match: _cpp_string_literal_escapes[match.group(0)], s + ) + return f'"{escaped}"' + + +@dataclasses.dataclass +class DeferredTritonCallWrapper: + """ + When using cpp wrapper, GPU kernel load and launch needs to wait for Triton kernels + to be tuned and stored as cubin files, so use a deferred generating the final wrapper around + the triton kernel until right before the prefix is written. 
+ """ + + wrapper_name: str + kernel_name: str + kernel_name_to_body: dict[str, str] + arg_types: list[Any] + + def generate(self, wrapper: CppWrapperGpu): + """ + Generate the GPU kernel definition, as well as load and launch code. + """ + prefix = wrapper.prefix + if self.kernel_name.startswith("multi_kernel_"): + # MultiKernel will select one kernel after running the autotune block + self.kernel_name = MultiKernelCall.lookup_choice(self.kernel_name) + params = CudaKernelParamCache.get(self.kernel_name) + assert params, f"CudaKernelParamCache not populated for {self.kernel_name}" + def_args = params["def_args"] + arg_types = self.arg_types + inductor_meta = params["inductor_meta"] + + if "extra_launcher_args" in inductor_meta and len(def_args) > len(arg_types): + # extra_launcher_args should already be in def_args + assert len(def_args) == len(arg_types) - len( + inductor_meta["extra_launcher_args"] + ) + arg_types = arg_types + [SymbolicCallArg] * len( + inductor_meta["extra_launcher_args"] + ) + + if not V.graph.aot_mode: + prefix.writeline( + maybe_hipify_code_wrapper( + f"static {wrapper.device_codegen.cpp_kernel_type()} {self.kernel_name} = nullptr;" + ) + ) + kernel_var_name = self.kernel_name + else: + kernel_var_name = f"kernels_.{self.kernel_name}" + + # tensors can be RAIIAtenTensorHandle or ConstantHandle, so make them template types + template_types = [ + f"typename {name}_type_" + for name, arg_type in zip(def_args, arg_types) + if isinstance(arg_type, (torch_dtype, UnwrapUnspecArg)) + ] + if V.graph.aot_mode: + template_types.append("typename kernels_type_") + if template_types: + prefix.writeline(f"template <{', '.join(template_types)}>") + prefix.writeline(f"static inline void {self.wrapper_name}(") + with prefix.indent(): + assert len(def_args) == len(arg_types), (def_args, arg_types) + for name, arg_type in zip(def_args, arg_types): + if isinstance(arg_type, (torch_dtype, UnwrapUnspecArg)): + prefix.writeline(f"const {name}_type_& {name},") + elif issubclass(arg_type, (SymbolicCallArg, sympy.Expr, int)): + prefix.writeline(f"int64_t {name},") + elif arg_type is float: + prefix.writeline(f"float {name},") + elif arg_type is bool: + prefix.writeline(f"bool {name},") + else: + raise ValueError(f"Unexpected arg type {arg_type}") + prefix.writeline("int32_t device_idx_,") + prefix.writeline( + maybe_hipify_code_wrapper( + f"{wrapper.device_codegen.cpp_stream_type()} stream_," + ) + ) + if V.graph.aot_mode: + prefix.writeline("kernels_type_& kernels_,") + prefix.writeline( + "const std::optional& cubin_dir_ = std::nullopt" + ) + prefix.writeline("){") + with prefix.indent(): + if V.graph.aot_mode: + # Emit the original Triton kernel for debugging purposes + prefix.writeline("/*") + prefix.splice(self.kernel_name_to_body[self.kernel_name]) + prefix.writeline("*/") + self.generate_grid(prefix, inductor_meta, params) + self.generate_load_kernel(prefix, kernel_var_name, params) + self.generate_launch_kernel(prefix, wrapper, kernel_var_name, params) + prefix.writeline("}") + + if not config.aot_inductor.embed_kernel_binary: + # Ensure the cubin file is included in the package + V.graph.wrapper_code.additional_files.append( + params[get_cpp_wrapper_cubin_path_name()] + ) + + def generate_grid( + self, + prefix: IndentedBuffer, + inductor_meta: dict[str, Any], + params: dict[str, Any], + ): + from ..runtime.triton_heuristics import GridExpr + + grid = GridExpr.from_meta(inductor_meta, params["config"], mode="cpp") + for line in grid.prefix: + prefix.writeline(line) + prefix.splice( 
+ f"""\ + uint32_t grid_0 = {grid.x_grid}; + uint32_t grid_1 = {grid.y_grid}; + uint32_t grid_2 = {grid.z_grid}; + """ + ) + prefix.writeline("if (grid_0 == 0 || grid_1 == 0 || grid_2 == 0) return;") + + def generate_load_kernel(self, prefix, kernel_var_name, params): + prefix.writeline(f"if ({kernel_var_name} == nullptr) {{") + with prefix.indent(): + embed_kernel_args = [f"__{params['inductor_meta']['kernel_name']}_start"] + if torch.xpu.is_available(): + # XPU needs the end address of the kernel to calculate the size of the kernel binary. + embed_kernel_args.append( + f"__{params['inductor_meta']['kernel_name']}_end" + ) + + load_kernel_args = ( + [ + *embed_kernel_args, + cpp_string_literal(params["mangled_name"]), + str(params["shared_mem"]), + ] + if V.graph.aot_mode and config.aot_inductor.embed_kernel_binary + else [ + cpp_string_literal(params[get_cpp_wrapper_cubin_path_name()]), + cpp_string_literal(params["mangled_name"]), + str(params["shared_mem"]), + "cubin_dir_", + ] + ) + prefix.writeline( + f"{kernel_var_name} = loadKernel({', '.join(load_kernel_args)}); " + ) + prefix.writeline("}") + + def generate_launch_kernel(self, prefix, wrapper, kernel_var_name, params): + triton_meta = params["triton_meta"] + assert len(self.arg_types) == len(params["def_args"]), ( + self.arg_types, + params["def_args"], + ) + arg_type_loookup = dict(zip(params["def_args"], self.arg_types)) + # difference between Python and C++ wrapper: C++ wrapper strips out equal_to_1 constants + call_args = [ + name for name in params["call_args"] if name not in triton_meta["constants"] + ] + arg_types = [arg_type_loookup[name] for name in call_args] + arg_signatures = [triton_meta["signature"][name] for name in call_args] + call_args_str = wrapper.generate_args_decl( + prefix, + call_args, + arg_types, + arg_signatures, + workspace_size=params.get("global_scratch") or 0, + ) + prefix.writeline(f"void* kernel_args_[] = {{{call_args_str}}};") + launch_kernel_args = [ + kernel_var_name, + "grid_0", + "grid_1", + "grid_2", + str(params["num_warps"]), + str(params["shared_mem"]), + "kernel_args_", + "stream_", + ] + prefix.writeline(f"launchKernel({', '.join(launch_kernel_args)});") + + +class CppWrapperGpu(CppWrapperCpu): + """ + Generates cpp wrapper for running on GPU and calls CUDA kernels + """ + + def __init__(self) -> None: + self.device = get_gpu_type() + self.device_codegen = get_device_op_overrides(self.device) + super().__init__() + self.grid_id = count() + self._kernel_name_to_body: dict[str, str] = {} + self._triton_call_wrappers: dict[str, DeferredTritonCallWrapper] = {} + self.autotune_input_prefix = "_REAL_AUTOTUNE_INPUT" + + @staticmethod + def create( + is_subgraph: bool, + subgraph_name: Optional[str], + parent_wrapper: Optional[PythonWrapperCodegen], + partition_signatures: Optional[GraphPartitionSignature] = None, + ): + # TODO - support subgraph codegen by lifting functions. Check the + # comment at CppWrapperCpu `codegen_subgraph` function. + return CppWrapperGpu() + + def write_header(self): + if V.graph.is_const_graph: + # We do not write header for constant graph, it will be written by main module. 
+ return + + super().write_header() + self.header.splice( + maybe_hipify_code_wrapper(self.device_codegen.kernel_driver()) + ) + + @cache_on_self + def write_tma_descriptor_helpers_once(self): + self.header.splice(self.device_codegen.tma_descriptor_helpers()) + + def write_get_raw_stream(self, device_idx: int, graph_name: str) -> str: + name = f"stream{device_idx}" + self.writeline( + maybe_hipify_code_wrapper( + f"{self.device_codegen.cpp_stream_type()} {name};" + ) + ) + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK({self.device_codegen.aoti_get_stream()}({device_idx}, (void**)&{name}));" + ) + return name + + def get_autotuning_input_name(self, idx): + return f"{self.autotune_input_prefix}_{idx}" + + def codegen_inputs(self): + # See Note: [Input Alignment handling in Inductor] + # + # JIT Inductor does not guard on input alignment. It relies on copy_misaligned_inputs to + # copy misaligned inputs to aligned buffers. For AOTInductor, we need to do the same in cpp. + + if config.is_fbcode(): + # TODO: This is added because FC. Remove this once the newly added shim symbols, + # e.g. aoti_torch_clone_preserve_strides, have landed + return super().codegen_inputs() + + if V.graph.aot_mode and V.graph.inputs_to_check: + for idx in V.graph.inputs_to_check: + input_name = V.graph.graph_input_names[idx] + assert input_name in V.graph.graph_inputs, ( + f"{input_name} not found in graph inputs" + ) + value = V.graph.graph_inputs[input_name] + assert isinstance(value, TensorBox), ( + f"{input_name} is expected to be tensor but found as {type(value)}" + ) + warn_msg = ( + f"Input {idx} was compiled as {GPU_ALIGN_BYTES}-bytes aligned, " + "but it is not aligned at run time. Copying to an aligned tensor " + "to guarantee correctness, but expect a performance hit." + ) + self.prefix.splice( + f""" + if ((long({input_name}.data_ptr()) & ({GPU_ALIGN_BYTES} -1)) != 0) {{ + AOTI_TORCH_WARN("{warn_msg}"); + AtenTensorHandle {input_name}_aligned; + aoti_torch_clone_preserve_strides({input_name}, &{input_name}_aligned); + {input_name} = std::move(RAIIAtenTensorHandle({input_name}_aligned)); + }} + """ + ) + + super().codegen_inputs() + + def _define_kernel_helper( + self, + kernel_name: str, + kernel_body: str, + metadata: Optional[str] = None, + gpu: bool = True, + cpp_definition: Optional[str] = None, + ): + if gpu: + self._kernel_name_to_body[kernel_name] = kernel_body + if config.triton.autotune_at_compile_time: + # Call PythonWrapperCodegen to create the autotune code block + PythonWrapperCodegen._define_kernel_helper( + self, kernel_name, kernel_body, metadata, gpu, cpp_definition + ) + else: + return CppWrapperCpu._define_kernel_helper( + self, kernel_name, kernel_body, metadata, gpu, cpp_definition + ) + + def generate(self, is_inference): + with dynamo_timed("CppWrapperGpu.generate", log_pt2_compile_event=True): + return super().generate(is_inference) + + def finalize_prefix(self): + """Define the triton kernels now that autotuning is finished""" + old_prefix = self.prefix # new content should go at start of prefix + + # Generating triton kernel callers can modify the prefix (cached dtypes), + # so do this before running finalize_prefix(), but put the generated code + # after the finalize_prefix() code. 
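[Editor's note] The misalignment test emitted by codegen_inputs above is a standard power-of-two mask check. A small sketch of the same predicate in Python follows; the value of GPU_ALIGN_BYTES below is illustrative only (the real value comes from torch._inductor.utils and is assumed to be a power of two).

GPU_ALIGN_BYTES = 16  # illustrative stand-in

def is_misaligned(data_ptr: int, align: int = GPU_ALIGN_BYTES) -> bool:
    """True if `data_ptr` is not aligned to `align` (which must be a power of two)."""
    return (data_ptr & (align - 1)) != 0

assert not is_misaligned(0x7F00_0000)   # multiple of 16: no copy needed
assert is_misaligned(0x7F00_0004)       # off by 4 bytes: clone to an aligned tensor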
+ self.prefix = IndentedBuffer() + for kernel in self._triton_call_wrappers.values(): + self.prefix.writeline("\n") + kernel.generate(self) + triton_prefix = self.prefix + + self.prefix = IndentedBuffer() + super().finalize_prefix() + + self.prefix.splice(triton_prefix) + + self.prefix.writeline("\n") + self.prefix.splice(old_prefix) + + def generate_tma_descriptor(self, desc): + self.write_tma_descriptor_helpers_once() + + if isinstance(desc, TMADescriptorExperimental): + self._generate_experimental_tma_descriptor(desc) + else: + assert isinstance(desc, TMADescriptorStable) + self._generate_stable_tma_descriptor(desc) + + def _generate_experimental_tma_descriptor(self, desc): + # generate data pointer for the source tensor + source = self.generate_args_decl( + code=self, + call_args=[self.val_to_arg_str(desc.tensor)], + arg_types=[desc.tensor.get_dtype()], + arg_signatures=[None], + # these args are passed to initNDTMADescriptor, which is NOT a triton kernel + is_triton_kernel=False, + ) + + desc_name = desc.name + self.writeline(f"alignas(64) CUtensorMap {desc_name};") + + # `source` is in the form of `&var_x`, where `var_x` is the data pointer + # (CUdeviceptr); we dereference `source` and cast to `void*` to pass to + # the data pointer of the source tensor to the helper function + # `init{1,2}DTMADescriptor` + ptr = f"reinterpret_cast(*({source}))" + dims = ", ".join(self.val_to_arg_str(dim) for dim in desc.dims) + block_dims = ", ".join(self.val_to_arg_str(dim) for dim in desc.block_dims) + element_size = self.val_to_arg_str(desc.element_size) + fn = f"init{desc.rank}DTMADescriptor" + args = f"&{desc_name}, {ptr}, {dims}, {block_dims}, {element_size}" + self.writeline(f"{fn}({args});") + + def _generate_stable_tma_descriptor(self, desc): + source = self.generate_args_decl( + code=self, + call_args=[self.val_to_arg_str(desc.tensor)], + arg_types=[desc.tensor.get_dtype()], + arg_signatures=[None], + # these args are passed to initNDTMADescriptor, which is NOT a triton kernel + is_triton_kernel=False, + ) + + desc_name = desc.name + # Pack the relevant information into a StableTMADescriptor struct. + # See [Note: AOTI TMA Stable handling] for more details. + self.writeline(f"alignas(64) StableTMADescriptor {desc_name};") + + def fill_array(name, values): + for i, val in enumerate(values): + self.writeline(f"{name}[{i}] = {val};") + + ptr = f"reinterpret_cast(*({source}))" + rank = len(desc.tensor.get_size()) + + fill_array(f"{desc_name}.block_shape", desc.block_shape) + fill_array(f"{desc_name}.global_shape", desc.tensor.get_size()) + fill_array(f"{desc_name}.strides", desc.tensor.get_stride()) + + element_size = self.val_to_arg_str(desc.tensor.get_dtype().itemsize) + fn = "initTMADescriptor" + args = ", ".join( + str(x) + for x in [ + f"&{desc_name}.m", + ptr, + element_size, + rank, + f"{desc_name}.block_shape", + f"{desc_name}.global_shape", + f"{desc_name}.strides", + ] + ) + self.writeline(f"{fn}({args});") + + def generate_args_decl( + self, + code: Union[IndentedBuffer, Self], + call_args, + arg_types, + arg_signatures, + is_triton_kernel=True, + workspace_size=0, + ): + """ + Generates any declarations of args to pass into a kernel call, and then returns the arg names. + + In more detail: + * declarations: e.g. this function has a side effect of generating lines like `auto var_0 = ...;` + * returns: a string with the list of args, e.g. 
"var_0, var_1" + + call_args: list of call arguments + arg_types: list of argument types + arg_signatures: list with signatures of all the args + is_triton_kernel: whether these are passed into a triton kernel or not. In particular, + calls to triton kernels will have an additional global scratch space + arg injected at the front of the arg list. + """ + new_args: list[str] = [] + + # Add more cases for other types as needed + signature2dtype = { + "i32": "int32_t", + "i64": "int64_t", + "fp32": "float", + } + + def signature_is_tma_desc(sig): + if not sig: + return False + if sig == "nvTmaDesc": + return True + if sig.startswith("tensordesc<"): + return True + return False + + def process_tma_stable_arg(arg, arg_type, arg_signature, var_name): + # [Note: AOTI TMA Stable handling] + # For most args, a single arg passed to the python triton interface + # maps to a single arg in the cubin interface. However, for host-side + # TMA descriptors, a single python arg turns into 1 + 2 * N args in the + # cubin interface (where N is the rank). + # + # To do this: at TMA codegen time (for aoti), we generate a struct + # (StableTMADescriptor) containing the necessary information; and then + # when we call the function (i.e. here), we unpack the struct members. + code.writeline(f"auto {var_name} = {cexpr(arg)};") + + result = [] + result.append(f"&{var_name}.m") + + # from https://github.com/triton-lang/triton/blob/16961b79bdac1b774b42d44e52fd55a266ec2866/third_party/nvidia/backend/driver.py#L111 # noqa: B950 + match = re.match("tensordesc<([^[>]*)\\[([^]]*)\\]", arg_signature) + assert match is not None + shape = match.group(2) + ndim = shape.count(",") + 1 + + for i in range(ndim): + result.append(f"&{var_name}.block_shape[{i}]") + + for i in range(ndim): + result.append(f"&{var_name}.strides[{i}]") + + return result + + def process_args(arg, arg_type, arg_signature=None): + var_name = f"var_{next(self.arg_var_id)}" + # ignore tma descriptors, as host-side TMA descriptors need + # to be passed to the compiled Triton kernel by value + if isinstance(arg_type, UnwrapUnspecArg) and not signature_is_tma_desc( + arg_signature + ): + self.codegen_tensor_item( + arg_type.dtype, + arg, + var_name, + indented_buffer=code, + ) + new_args.append(f"&{var_name}") + elif isinstance(arg_type, torch_dtype) and not signature_is_tma_desc( + arg_signature + ): + device_ptr_type = self.device_codegen.cpp_device_ptr() + code.writeline( + maybe_hipify_code_wrapper( + f"{device_ptr_type} {var_name} = reinterpret_cast<{device_ptr_type}>({arg}.data_ptr());" + ) + ) + new_args.append(f"&{var_name}") + elif arg_type in (sympy.Integer, int): + code.writeline(f"int {var_name} = {cexpr(arg)};") + new_args.append(f"&{var_name}") + elif arg_type in (sympy.Float, float): + code.writeline(f"float {var_name} = {cexpr(arg)};") + new_args.append(f"&{var_name}") + # For symbolic call arguments, examine the arg signatures from triton meta + # to explicitly cast to the right type + # Reason: `auto` can infer unexpected type against kernel input signature. 
+ elif ( + isinstance(arg_type, type(SymbolicCallArg)) + and arg_signature is not None + and arg_signature in signature2dtype.keys() + ): + code.writeline( + f"{signature2dtype[arg_signature]} {var_name} = {cexpr(arg)};" + ) + new_args.append(f"&{var_name}") + elif arg_signature and arg_signature.startswith("tensordesc<"): + new_args.extend( + process_tma_stable_arg(arg, arg_type, arg_signature, var_name) + ) + else: + code.writeline(f"auto {var_name} = {cexpr(arg)};") + new_args.append(f"&{var_name}") + + for arg, arg_type, arg_signature in zip_longest( + call_args, arg_types, arg_signatures + ): + process_args(arg, arg_type, arg_signature) + + if ( + is_triton_kernel + and ( + global_scratch := self.device_codegen.cpp_global_scratch( + next(self.arg_var_id), + workspace=TritonScratchWorkspace( + size=workspace_size, + generate_dtype_str=(lambda: self.codegen_dtype(torch.uint8)), + ), + ) + ) + is not None + ): + global_scratch_def, global_scratch_var = global_scratch + code.writelines([maybe_hipify_code_wrapper(x) for x in global_scratch_def]) + new_args.append(f"&{global_scratch_var}") + + return ", ".join(new_args) + + def _generate_kernel_call_helper( + self, + kernel_name: str, + call_args, + *, + device=None, + triton=True, + arg_types=None, + raw_keys=None, + raw_args=None, + triton_meta=None, + graph_name="", + original_fxnode_name=None, + ): + """ + Override the default value of argument 'gpu' to True here. + generate_kernel_call can still be called with gpu=False because of + a mix of cpu kernels and gpu kernels. + """ + device = device or V.graph.get_current_device_or_throw() + if device.type == "cpu": + # Even in CppWrapperGpu, we may see cpp kernels + return CppWrapperCpu._generate_kernel_call_helper( + self, + kernel_name, + call_args, + device=device, + triton=triton, + arg_types=arg_types, + raw_keys=raw_keys, + raw_args=raw_args, + triton_meta=triton_meta, + ) + + if ( + triton + and config.triton.autotune_at_compile_time + and kernel_name not in self.kernel_autotune_names + ): + # Call PythonWrapperCodegen to create the autotune code block + PythonWrapperCodegen._generate_kernel_call_helper( + self, + kernel_name, + call_args, + device=device, + triton=triton, + arg_types=arg_types, + raw_keys=raw_keys, + raw_args=raw_args, + triton_meta=triton_meta, + original_fxnode_name=original_fxnode_name, + ) + + stream = ( + "stream" + if V.graph.aot_mode + else self.write_get_raw_stream(device.index, graph_name) + ) + + if triton: + call_args, arg_types = self.prepare_triton_wrapper_args( + call_args, arg_types + ) + wrapper_name = f"call_{kernel_name}" + if wrapper_name not in self._triton_call_wrappers: + self._triton_call_wrappers[wrapper_name] = DeferredTritonCallWrapper( + wrapper_name, + kernel_name, + self._kernel_name_to_body, + arg_types, + ) + device_idx = "this->device_idx_" if V.graph.aot_mode else str(device.index) + call_args.append(device_idx) + call_args.append(stream) + if V.graph.aot_mode: + call_args.append("kernels") + call_args.append("this->cubin_dir_") + debug_printer_manager = V.graph.wrapper_code.debug_printer + debug_printer_manager.set_printer_args( + call_args[: len(arg_types)], kernel_name, arg_types, None + ) + with debug_printer_manager: + self.writeline(f"{wrapper_name}({', '.join(call_args)});") + else: + casted = [] + for arg_type, arg in zip(arg_types, call_args): + new_arg = arg + if arg_type.endswith("*") and arg != "nullptr": + new_arg = f"{arg}.data_ptr()" + casted.append(f"({arg_type}){cexpr(new_arg)}") + call_args_str = ", ".join(casted) + 
self.writeline(f"kernels.{kernel_name}({call_args_str}, {stream});") + + @staticmethod + def prepare_triton_wrapper_args( + call_args: list[Any], arg_types: list[Any] + ) -> tuple[list[Any], list[Any]]: + assert len(call_args) == len(arg_types), (call_args, arg_types) + new_args = [] + new_args_types = [] + for arg, arg_type in zip(call_args, arg_types): + if isinstance(arg, str): + if isinstance(arg_type, torch_dtype) and should_unwrap_unspec_arg(arg): + # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar + arg_type = UnwrapUnspecArg(dtype=arg_type) + new_args.append(arg) + elif isinstance(arg, bool): + new_args.append(str(arg).lower()) + elif isinstance(arg, (int, float, SymbolicCallArg)): + new_args.append(str(arg)) + else: + new_args.append(cexpr(V.graph.sizevars.simplify(arg))) + new_args_types.append(arg_type) + return new_args, new_args_types + + def make_zero_buffer(self, name): + return f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_zero_({name}.get()));" + + +@dataclasses.dataclass +class UnwrapUnspecArg: + """Marker that we need to call .item() on the tensor""" + + dtype: torch_dtype diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_mps.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_mps.py new file mode 100644 index 0000000000000000000000000000000000000000..d0d8a1eb7578aa338056b4292991fa99764ee5ce --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpp_wrapper_mps.py @@ -0,0 +1,99 @@ +from typing import Any, Optional + +import sympy + +import torch + +from ..ir import GraphPartitionSignature +from ..virtualized import V +from .cpp_wrapper_gpu import CppWrapperGpu +from .wrapper import PythonWrapperCodegen + + +class CppWrapperMps(CppWrapperGpu): + @staticmethod + def create( + is_subgraph: bool, + subgraph_name: Optional[str], + parent_wrapper: Optional[PythonWrapperCodegen], + partition_signatures: Optional[GraphPartitionSignature] = None, + ) -> "CppWrapperMps": + return CppWrapperMps() + + def _generate_kernel_call_helper( + self, + kernel_name: str, + call_args: list[str], + arg_types: Optional[list[type]] = None, + **kwargs: dict[str, Any], + ) -> None: + """ + Generates MPS kernel call code. It should look something like: + ``` + auto mps_lib_0_func = mps_lib_0.getKernelFunction("generated_kernel"); + auto mps_lib_0_func_handle = AOTIMetalKernelFunctionHandle(mps_lib_0_func.get()); + mps_lib_0_func->runCommandBlock([&] { + mps_lib_0_func->startEncoding(); + aoti_torch_mps_set_arg(mps_lib_0_func_handle, 0, buf0); + aoti_torch_mps_set_arg(mps_lib_0_func_handle, 1, arg0_1); + ... 
+ mps_lib_0_func->dispatch(9); + }); + ``` + """ + assert arg_types is not None + + new_args = [] + for idx, (arg, arg_type) in enumerate(zip(call_args[:-2], arg_types[:-2])): + if isinstance(arg_type, torch.dtype): + new_args.append( + f"aoti_torch_mps_set_arg_tensor({kernel_name}_handle, {idx}, {arg});\n" + ) + elif arg_type in (int, sympy.core.symbol.Symbol): + new_args.append( + f"aoti_torch_mps_set_arg_int({kernel_name}_handle, {idx}, {arg});\n" + ) + else: + raise NotImplementedError( + f"Unsupported arg type {arg_type} for arg {arg} for kernel {kernel_name}" + ) + + threads, group_size = call_args[-2], call_args[-1] + if threads is None: + raise NotImplementedError("No threads or group_size provided") + elif group_size is None: + new_args.append(f"{kernel_name}->dispatch({threads});\n") + else: + new_args.append(f"{kernel_name}->dispatch({threads}, {group_size});\n") + + # debug printer related logic for cpp kernel type. + debug_printer_manager = V.graph.wrapper_code.debug_printer + debug_printer_manager.set_printer_args( + call_args[:-2], + kernel_name, + None, + None, + "cpp", + ) + with debug_printer_manager: + self.writeline(self.wrap_kernel_call(kernel_name, new_args)) + + def wrap_kernel_call(self, name: str, call_args: list[str]) -> str: + lib_name = name[: -len("_func")] + calling_args = " ".join(call_args) + return f""" + auto {name} = {lib_name}.getKernelFunction("generated_kernel"); + auto {name}_handle = AOTIMetalKernelFunctionHandle({name}.get()); + {name}->runCommandBlock([&] {{ + {name}->startEncoding(); + {calling_args} + }}); + """ + + @staticmethod + def get_device_include_path(device: str) -> str: + assert V.graph.aot_mode + return ( + "#include \n" + "#include " + ) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpu_device_op_overrides.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpu_device_op_overrides.py new file mode 100644 index 0000000000000000000000000000000000000000..1ffafa74dd68775852bd6bfda3f66d34aa12abde --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cpu_device_op_overrides.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from textwrap import dedent + +from .common import DeviceOpOverrides, register_device_op_overrides + + +class CpuDeviceOpOverrides(DeviceOpOverrides): + def import_get_raw_stream_as(self, name: str) -> str: + return dedent( + """ + def get_raw_stream(_): + return 0 + """ + ) + + def set_device(self, device_idx: int) -> str: + return "pass" + + def synchronize(self) -> str: + return "pass" + + def device_guard(self, device_idx: int) -> str: + return "pass" + + +register_device_op_overrides("cpu", CpuDeviceOpOverrides()) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/__init__.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py new file mode 100644 index 0000000000000000000000000000000000000000..889961f2523b2b86defdc90683c554ecb6f0a72c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py @@ -0,0 +1,293 @@ +# mypy: allow-untyped-defs +import hashlib +import logging +from collections.abc import Sequence +from typing import cast + +from 
torch._inductor.codegen.cuda.cutlass_python_evt import ( + CutlassEVTCodegen, + MockCutlassHandler, +) +from torch._inductor.utils import Placeholder +from torch.utils._ordered_set import OrderedSet + +from ...._dynamo.utils import counters +from ... import config +from ...codecache import code_hash, get_path +from ...ir import Buffer, ComputedBuffer, CUDATemplateBuffer, Pointwise +from ...scheduler import ( + BaseSchedulerNode, + BaseScheduling, + FusedSchedulerNode, + SchedulerNode, + WhyNoFuse, +) +from ...utils import get_fused_kernel_name, get_kernel_metadata, sympy_product +from ...virtualized import V +from ..common import BackendFeature, IndentedBuffer + + +log = logging.getLogger(__name__) + + +class WhyNoFuseNames(WhyNoFuse): + def __init__(self, name1: str, name2: str) -> None: + self.name1 = name1 + self.name2 = name2 + + +class CUDACPPScheduling(BaseScheduling): + """ + Partial Scheduling implementation for CUDA C++ Kernels. + This class is intended to be used in combination with TritonScheduling, + and delegated to by CUDACombinedScheduling. + + It handles fusion decisions and CUDA C++ specific template code generation. + """ + + @classmethod + def get_backend_features(cls, device) -> OrderedSet[BackendFeature]: + return OrderedSet() + + def group_fn(self, sizes): + return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes) + + @staticmethod + def is_cuda_cpp_template(node: BaseSchedulerNode) -> bool: + return isinstance(node, SchedulerNode) and isinstance( + node.node, CUDATemplateBuffer + ) + + def is_cuda_cpp_fused_template(self, node: BaseSchedulerNode) -> bool: + return isinstance(node, FusedSchedulerNode) and self.is_cuda_cpp_template(node) + + def can_fuse_vertical( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + if self.is_cuda_cpp_template(node1) and isinstance(node2, BaseSchedulerNode): + assert node1.node, "node1.node should not be None" + return self._can_fuse_epilogue_impl( + cast(CUDATemplateBuffer, node1.node), + [], + node2, # type: ignore[arg-type] + ) + elif self.is_cuda_cpp_fused_template(node1) and isinstance( + node2, BaseSchedulerNode + ): + assert node1.node, "node1.node should not be None" + assert node2.node, "node2.node should not be None" + fnode1 = cast(FusedSchedulerNode, node1) + return self._can_fuse_epilogue_impl( + fnode1.get_template_node(), # type: ignore[arg-type] + self._unwrap_epilogue_nodes(fnode1), + node2, # type: ignore[arg-type] + ) + + return False + + def define_kernel(self, src_code: str, node_schedule) -> str: + wrapper = V.graph.wrapper_code + if src_code in wrapper.src_to_kernel: + kernel_name = wrapper.src_to_kernel[src_code] + else: + fused_name = ( + get_fused_kernel_name(node_schedule, config.triton.descriptive_names) + if config.triton.descriptive_names + else "" + ) + + # use the original src_code as the key + kernel_hash = hashlib.sha256(src_code.encode("utf-8")).hexdigest()[:8] + if fused_name == "fused": + # no EVT kernel, use the original kernel name + kernel_name = f"cutlass_{kernel_hash}" + else: + kernel_name = f"cutlass_{fused_name}_{kernel_hash}" + wrapper.src_to_kernel[src_code] = kernel_name + src_code = src_code.replace(str(Placeholder.KERNEL_NAME), kernel_name) + + _, _, kernel_path = get_path(code_hash(src_code), "py") + + compile_wrapper = IndentedBuffer() + compile_wrapper.writeline("async_compile.cuda(r'''") + compile_wrapper.splice(src_code, strip=True) + compile_wrapper.writeline( + f"''', 'so', aot_compile={str(V.graph.aot_mode)})" + ) + + metadata_comment = f"# 
kernel path: {kernel_path}" + origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper) + metadata_comment += "\n" + origins + "\n" + detailed_origins + wrapper.define_kernel( + kernel_name, compile_wrapper.getvalue(), metadata_comment + ) + return kernel_name + + def codegen_template( + self, + template_node: BaseSchedulerNode, + epilogue_nodes: Sequence[BaseSchedulerNode], + prologue_nodes: Sequence[BaseSchedulerNode], + ): + """ + Codegen a CUDA template, possibly with fused epilogues + """ + counters["inductor"]["cuda_epilogue_fusion_counter"] += len(epilogue_nodes) + assert self.is_cuda_cpp_template(template_node), ( + "Template node passed to CUDAScheduler.codegen_template must be a SchedulerNode that wraps a CUDATemplateBuffer" + ) + template_node = cast(SchedulerNode, template_node) + _, (_numel, rnumel) = template_node.group + assert rnumel == 1 + ctb: CUDATemplateBuffer = cast(CUDATemplateBuffer, template_node.node) + epilogue_ir_nodes: list[Buffer] = [n.node for n in epilogue_nodes] # type: ignore[misc] + assert all(isinstance(n, ComputedBuffer) for n in epilogue_ir_nodes), ( + "Epilogue nodes must all be instances of ir.ComputedBuffer" + ) + kernel, render = ctb.make_kernel_render(ctb, epilogue_nodes=epilogue_nodes) + + with kernel: + for node in [template_node, *epilogue_nodes]: + node.mark_run() + + # typically there is a codegen pass which runs after mark_run + # for this kernel we've already generated the C++ code, but we still + # need to let the kernel know about loads/stores that occur in the fused + # kernel for memory planning to properly optimize allocations + ctb.emulate_store_fn() + for node in epilogue_ir_nodes: + with V.set_ops_handler(MockCutlassHandler(V.get_ops_handler())): + assert isinstance( + node, ComputedBuffer + ) # Not sure why we need to do this again + node.get_store_function()(CutlassEVTCodegen.get_index_vars(node)) + + with V.set_kernel_handler(kernel): + src_code = render() + node_schedule = [template_node, *epilogue_nodes] + kernel_name = self.define_kernel(src_code, node_schedule) + + # debug printing values of intermediate tensors + _, call_args, arg_signatures, _ = kernel.args.python_argdefs() + debug_printer_manager = V.graph.wrapper_code.debug_printer + debug_printer_manager.set_printer_args( + call_args, kernel_name, arg_signatures, kernel + ) + with debug_printer_manager: + kernel.call_kernel(kernel_name, ctb) + + V.graph.removed_buffers |= kernel.removed_buffers + self.free_buffers_in_scheduler() + + @staticmethod + def _unwrap_epilogue_nodes( + fused_node: FusedSchedulerNode, + ) -> list[BaseSchedulerNode]: + nodes = fused_node.get_nodes() + template_node = fused_node.get_template_node() + assert all(n.node is not None for n in nodes), ( + "All epilogue nodes should have an IRNode" + ) + return cast( + list[BaseSchedulerNode], [n for n in nodes if n.node is not template_node] + ) + + def _can_fuse_epilogue_impl( + self, + cuda_template_buffer: CUDATemplateBuffer, + existing_epilogue_nodes: list[BaseSchedulerNode], + node_to_fuse: BaseSchedulerNode, + ) -> bool: + """ + Check if the given node can be fused with the epilogue. At the moment, Kernels + support fusion with Pointwise operations, wrapped in (named) ComputedBuffer nodes. + + Args: + cuda_template_buffer : A CUDATemplateBuffer object representing the CUDA template and it's result buffer + existing_epilogue_nodes : List[SchedulerNode]: The list of already fused epilogue nodes. + node_to_fuse: The SchedulerNode node to be checked if it can be fused with the epilogue. 
+ Returns: + - bool: True if the given node can be fused with the epilogue, False otherwise. + + """ + why = WhyNoFuseNames(cuda_template_buffer.get_name(), node_to_fuse.get_name()) + + scheduler_nodes_to_fuse = node_to_fuse.get_nodes() + + assert isinstance(cuda_template_buffer, CUDATemplateBuffer) + + # Checks on constituent nodes + for s_node in scheduler_nodes_to_fuse: + node = s_node.node + + if not isinstance(node, ComputedBuffer): + why(f"{node} is not a ComputedBuffer") + return False + elif not isinstance(node.data, Pointwise): + why(f"{node} is not a Pointwise op") + return False + elif not node.get_computed_buffer_name(): # type: ignore[attr-defined] + why(f"{node} does not have a computed buffer name") + return False + + name = node.get_computed_buffer_name() # type: ignore[attr-defined] + # dtype can differ, and strides can differ as long as they are broadcastable + if node.get_size() != cuda_template_buffer.get_size(): + why( + f"{name}'s size: {node.get_size()} differs from {cuda_template_buffer.get_name()}'s \ +size: {cuda_template_buffer.get_size()}" + ) + return False + + assert len( + existing_epilogue_nodes + ) or cuda_template_buffer.get_name() in OrderedSet( + [rd.name for rd in node_to_fuse.read_writes.reads] + ), "First epilogue node must read from cuda template buffer" + + if node_to_fuse.has_aliasing_or_mutation(): + why(f"{node_to_fuse.get_name()} has aliasing or mutation") + return False + elif node_to_fuse.is_reduction(): + why( + f"{node_to_fuse.get_name()} is a reduction which is not yet supported by EVT" + ) + return False + elif ( + not config.cuda.cutlass_epilogue_fusion_enabled + or not config.epilogue_fusion + ): + why("cutlass epilogue fusion is not enabled") + return False + elif not cuda_template_buffer.supports_epilogue_fusion: + why("epilogue fusion is only supported for TMA-enabled gemm ops") + return False + + try: + from torch._inductor.codegen.cuda.cutlass_python_evt import ( + CutlassEVTCodegen, + ) + + CutlassEVTCodegen.ir_to_evt_python_code( + cuda_template_buffer.get_name(), + existing_epilogue_nodes + list(node_to_fuse.get_nodes()), + OrderedSet(), + ) + + except NotImplementedError as e: + not_implemented_op = str(e) + if not_implemented_op.startswith("_op_"): + not_implemented_op = not_implemented_op[4:] + why( + f"Cannot fuse epilogue node {node_to_fuse} into {cuda_template_buffer.name}, \ +likely due to unsupported operation: {not_implemented_op}" # noqa: G004, B950 + ) + return False + else: # Likely due to unsupported dtype. + why( + f"Cannot fuse epilogue node {node_to_fuse} into {cuda_template_buffer.name}. \ +Reason: {not_implemented_op}" # noqa: G004, B950 + ) + return False + + return True diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_env.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_env.py new file mode 100644 index 0000000000000000000000000000000000000000..a11462fc8a0b8c46f16c88cec8b92fcde8683e88 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_env.py @@ -0,0 +1,45 @@ +import functools +import logging +import shutil +from typing import Optional + +import torch +from torch._inductor.utils import clear_on_fresh_cache + +from ... 
import config + + +log = logging.getLogger(__name__) + + +@clear_on_fresh_cache +@functools.lru_cache(1) +def get_cuda_arch() -> Optional[str]: + try: + cuda_arch = config.cuda.arch + if cuda_arch is None: + # Get Compute Capability of the first Visible device + major, minor = torch.cuda.get_device_capability(0) + return str(major * 10 + minor) + return str(cuda_arch) + except Exception as e: + log.error("Error getting cuda arch: %s", e) + return None + + +@clear_on_fresh_cache +@functools.lru_cache(1) +def get_cuda_version() -> Optional[str]: + try: + cuda_version = config.cuda.version + if cuda_version is None: + cuda_version = torch.version.cuda + return cuda_version + except Exception as e: + log.error("Error getting cuda version: %s", e) + return None + + +@functools.cache +def nvcc_exist(nvcc_path: Optional[str] = "nvcc") -> bool: + return nvcc_path is not None and shutil.which(nvcc_path) is not None diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_kernel.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..1091e188343e42b5531fa45cc1a938e0b7a1b8c4 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_kernel.py @@ -0,0 +1,674 @@ +# mypy: allow-untyped-defs +import functools +import itertools +import logging +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, Callable, Literal, Optional, TYPE_CHECKING, Union + +from sympy import Expr, symbols + +import torch._inductor.config as config +from torch import dtype as torch_dtype +from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu +from torch._inductor.scheduler import BaseSchedulerNode +from torch._inductor.utils import do_bench_using_profiling, OrderedSet, Placeholder +from torch.utils._sympy.value_ranges import ValueRanges + +from .cutlass_utils import DTYPE_TO_CUTLASS_TYPE + + +if TYPE_CHECKING: + from .cuda_template import ArgInfo + +from ...autotune_process import CUDABenchmarkRequest +from ...ir import ( + Buffer, + ChoiceCaller, + CUDATemplateBuffer, + IRNode, + Layout, + PrimitiveInfoType, + TensorBox, +) +from ...utils import sympy_product +from ...virtualized import V +from ..common import ( + CSEVariable, + IndentedBuffer, + Kernel, + OpOverrides, + WorkspaceArg, + WorkspaceZeroMode, +) +from ..cpp_utils import CppPrinter, DTYPE_TO_CPP + + +if TYPE_CHECKING: + from torch._inductor.codegen.cuda.cuda_template import CUDATemplate + +log = logging.getLogger(__name__) + +cexpr = CppPrinter().doprint + + +def _normalize_idx(index: int, total_length: int) -> int: + return index if index >= 0 else index + total_length + + +ValidLayoutSymbols = Literal["M", "N", "K", "B", "lda", "ldb", "ldc", "ldd"] +ValidLayoutAttrs = Literal["size", "stride"] + + +@dataclass(frozen=True) +class LayoutArg: + node: IRNode + symbol: ValidLayoutSymbols + attr: ValidLayoutAttrs + dim: int + + def matches(self, node, attr, dim) -> bool: + return self.node == node and self.attr == attr and self.dim == dim + + +class CUDAKernel(Kernel): + """ + Baseclass for CUDA / Cutlass based Kernels + """ + + overrides = OpOverrides # type: ignore[assignment] + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.layout_args: dict[str, list[LayoutArg]] = defaultdict(list) + self.size_args: list[Union[Expr, int]] = [] + # Mapping from arg name to IRNode. 
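+        # Populated by CUDATemplateKernel.def_kernel() with the template argument
+        # names, e.g. "X", "W", "Bias", "Y".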
+ self.named_nodes: dict[str, IRNode] = {} + + def find_symbol( + self, node: IRNode, attr: ValidLayoutAttrs, dim: int + ) -> Optional[str]: + arg = self.find_layout_arg(node, attr, dim) + return arg.symbol if arg else None + + def find_layout_arg( + self, node: IRNode, attr: ValidLayoutAttrs, dim: int + ) -> Optional[LayoutArg]: + matches = [ + arg + for arg in itertools.chain.from_iterable(self.layout_args.values()) + if arg.matches(node, attr, dim) + ] + if len(matches) >= 1: + # Verify all matches have the same node, attribute, and dimension + # And if they come from the same node, whichever symbol we use is fine. + # if in runtime the logic changes, this would trigger guard + first_match = matches[0] + if not all( + match.node == first_match.node + and match.attr == first_match.attr + and match.dim == first_match.dim + for match in matches + ): + raise AssertionError("All matching layout args should be identical") + return first_match + return None + + def add_layout_arg( + self, symbol: ValidLayoutSymbols, node: IRNode, attr: ValidLayoutAttrs, dim: int + ): + arg = LayoutArg(node, symbol, attr, dim) + self.layout_args[symbol].append(arg) + + def init_layout_args(self) -> None: + X = self.named_nodes["X"] + W = self.named_nodes["W"] + Y = self.named_nodes["Y"] + Bias = self.named_nodes.get("Bias", None) + x_mdim = _normalize_idx(-2, len(X.get_size())) + x_kdim = _normalize_idx(-1, len(X.get_size())) + w_kdim = _normalize_idx(-2, len(W.get_size())) + w_ndim = _normalize_idx(-1, len(W.get_size())) + y_mdim = _normalize_idx(-2, len(Y.get_size())) + y_ndim = _normalize_idx(-1, len(Y.get_size())) + self.add_layout_arg("M", X, "size", x_mdim) + self.add_layout_arg("K", X, "size", x_kdim) + self.add_layout_arg("K", W, "size", w_kdim) + self.add_layout_arg("N", W, "size", w_ndim) + self.add_layout_arg("M", Y, "size", y_mdim) + self.add_layout_arg("N", Y, "size", y_ndim) + if len(X.get_size()) > 2: + self.add_layout_arg("B", X, "size", 0) + + lda_dim = self.find_ld_idx(X) + ldb_dim = self.find_ld_idx(W) + ldc_dim = self.find_ld_idx(Bias) if Bias else None + ldd_dim = self.find_ld_idx(Y) + self.add_layout_arg("lda", X, "stride", lda_dim) + self.add_layout_arg("ldb", W, "stride", ldb_dim) + if Bias is not None and ldc_dim is not None: + self.add_layout_arg("ldc", Bias, "stride", ldc_dim) + self.add_layout_arg("ldd", Y, "stride", ldd_dim) + + def get_layout_args(self) -> tuple[Union[Expr, int], ...]: + X = self.named_nodes["X"] + W = self.named_nodes["W"] + Y = self.named_nodes["Y"] + Bias = self.named_nodes.get("Bias", None) + mdim = _normalize_idx(-2, len(X.get_size())) + ndim = _normalize_idx(-1, len(W.get_size())) + kdim = _normalize_idx(-1, len(X.get_size())) + + def get_ld(node) -> Union[Expr, int]: + dim = self.find_ld_idx(node) + return node.get_stride()[dim] + + M = X.get_size()[mdim] + N = W.get_size()[ndim] + K = X.get_size()[kdim] + B = X.get_size()[0] if len(X.get_size()) > 2 else 1 + LDA = get_ld(X) + LDB = get_ld(W) + LDC = get_ld(Bias) if Bias else 0 + LDD = get_ld(Y) + return (M, N, K, B, LDA, LDB, LDC, LDD) + + def get_dynamic_shape_args(self) -> list[Union[Expr, int]]: + return [*self.get_layout_args(), *self.size_args] + + @staticmethod + def find_ld_idx(node: IRNode) -> int: + strides = node.get_stride() + # Handle 1D tensor case + if V.graph.sizevars.statically_known_equals(strides[-1], 1): + return _normalize_idx(-2, len(strides)) + + assert V.graph.sizevars.statically_known_equals(strides[-2], 1), strides[-2] + return _normalize_idx(-1, len(strides)) + + +class 
CUDATemplateKernel(CUDAKernel): + """ + Template kernels defined by CUDA / Cutlass in C++. + """ + + _EXTRA_CPP_ARGS = "size_t* workspace_size, uint8_t* workspace, cudaStream_t stream" + + def __init__( + self, + kernel_name: str, + runtime_arg_info: list["ArgInfo"], + runtime_arg_values: list[Any], + ) -> None: + """ + Initializes a new instance of the CUDATemplateKernel class. + + Args: + kernel_name (str): The name of the kernel. + """ + super().__init__() + self.kernel_name = kernel_name + self.runtime_arg_info = runtime_arg_info + self.runtime_arg_values = runtime_arg_values + + def check_not_null(self, node: IRNode) -> str: + """ + Generates code to check that a node is not null. + """ + if node is None: + return "" + + size_str = self.size(node, 0, -1) + name_str = self.arg_name(node) + if name_str is None: + return "" + + res = IndentedBuffer(initial_indent=2) + res.tabwidth = 1 + res.splice( + f""" + {{ + if (!{name_str}) {{ + int64_t {name_str}_size = {size_str}; + if ({name_str}_size > 0) {{ + throw std::runtime_error("input {name_str} is null but size is not 0!"); + }} + }} + }} + """ + ) + return res.getvalue() + + def get_signature(self) -> str: + return self.signature + + def def_kernel( + self, + inputs: list[IRNode], + outputs: list[IRNode], + names_str: str = "", + input_reorder: Optional[list[int]] = None, + ) -> str: + """ + Hook called from template code to generate function definition and + needed args. + + Args: + inputs: List of input IRNodes + outputs: List of output IRNodes + names_str: Comma separated list of input + output argument names. + input_reorder: The actual order of input nodes. + e.g. The template might have input argument defined as [X, W, Bias], + and the actual input passed into this template could be [Bias, X, W]. + In this case, the `input_reorder` would be [2, 0, 1]. 
+ additional_size_args: Additional size arguments for epilogue inputs + """ + names = [x.strip() for x in names_str.strip().split(",")] + if len(inputs) + len(outputs) != len(names): + raise RuntimeError( + f"{len(inputs) + len(outputs)=} != {len(names)=}, {inputs=}, {outputs=}, {names=}" + ) + + if input_reorder is not None: + assert len(inputs) == len(input_reorder) + else: + input_reorder = list(range(len(inputs))) + + for idx in input_reorder: + name = names[idx] + node = inputs[idx] + if node is not None: + self.named_nodes[name] = node + self.args.input_buffers[node.get_name()] = name + + free_symbols: OrderedSet[Expr] = OrderedSet() + for name, node in zip(names[len(inputs) : len(inputs) + len(outputs)], outputs): + if node is not None: + self.named_nodes[name] = node + self.args.output_buffers[node.get_name()] = name + + if name not in ( + "X", + "W", + "Bias", + "Y", + ): # we handle these symbolic shapes explicitly + for expr in itertools.chain(node.get_size(), node.get_stride()): + if isinstance(expr, Expr): + for s in expr.free_symbols: + free_symbols.add(s) # type: ignore[arg-type] + + arg_defs, *_ = self.args.cpp_argdefs(DTYPE_TO_CUTLASS_TYPE) + + self.init_layout_args() + size_vars = ["M", "N", "K", "B", "lda", "ldb", "ldc", "ldd"] + size_vars.extend(str(s) for s in free_symbols) + self.size_args.extend(free_symbols) + size_args = [f"const int {s}" for s in size_vars] + + runtime_arg_decls = ",".join( + [f"{arg.ty} {arg.name}" for arg in self.runtime_arg_info] + ) + if runtime_arg_decls: + runtime_arg_decls += ", " + + signature = f"int {self.kernel_name}({', '.join(arg_defs + size_args)}, {runtime_arg_decls}{self._EXTRA_CPP_ARGS})" + self.signature = signature + return signature + + def call_kernel( + self, + name: str, + node: "CUDATemplateBuffer", # type: ignore[name-defined] + ) -> None: + """ + Generates code to call the kernel through V.graph.wrapper_code. + used from within torch._inductor.wrapper.PythonWrapperCodegen + + name: Name of kernel function. + node: The CUDATemplateBuffer node which contains information about the kernel, it's fused epilogue nodes + as well as all required inputs and outputs. + """ + wrapper = V.graph.wrapper_code + + arg_types: list[Any] + if V.graph.cpp_wrapper: + # Make sure we initialize these kernels since they're exported as + # C-style symbol names. + assert isinstance(wrapper, CppWrapperCpu) + wrapper.initialized_kernels[name] = self + # We always originally initialize name with "KERNEL_NAME". So, we + # we replace with the real kernel name passed as an arg to this function. + self.signature = self.signature.replace(str(Placeholder.KERNEL_NAME), name) + _, call_args, arg_types = self.args.cpp_argdefs(DTYPE_TO_CUTLASS_TYPE) + else: + _, call_args, _, arg_types = self.args.python_argdefs() + + dynamic_shape_args = self.get_dynamic_shape_args() + call_args.extend(dynamic_shape_args) # type: ignore[arg-type] + for arg in self.runtime_arg_values: + call_args.append(arg) + arg_types.extend("int" for _ in dynamic_shape_args) + for arg in self.runtime_arg_info: + arg_types.append(arg.ty) + # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar + for i in range(len(call_args)): + if V.graph.is_unspec_arg(call_args[i]): + call_args[i] = call_args[i] + ".item()" + elif isinstance(arg_types[i], torch_dtype): + call_args[i] = ( + call_args[i] + if V.graph.cpp_wrapper + else f"c_void_p({call_args[i]}.data_ptr())" + ) + + # workspace_size ptr is NULL to mark this call is not intended for retrieving workspace_size. 
+ # workspace_size should have already been retrieved prior to this call. + # workspace_size is here. + call_args.append("nullptr" if V.graph.cpp_wrapper else "None") + if V.graph.cpp_wrapper: + arg_types.append("size_t*") + + if node.get_workspace_size() > 0: + ws = WorkspaceArg( + count=node.get_workspace_size(), + device=V.graph.get_current_device_or_throw(), + zero_mode=WorkspaceZeroMode.UNINITIALIZED, + outer_name=WorkspaceArg.unique_name(), + ) + wrapper.generate_workspace_allocation(ws) + workspace = str(ws.outer_name) + call_args.append( + workspace + if V.graph.cpp_wrapper + else f"c_void_p({workspace}.data_ptr())" + ) + else: + ws = None + call_args.append("nullptr" if V.graph.cpp_wrapper else "None") + if V.graph.cpp_wrapper: + arg_types.append("uint8_t*") + + wrapper.generate_kernel_call( + name, + call_args, + triton=False, + arg_types=arg_types, + ) + if ws: + wrapper.generate_workspace_deallocation(ws) + + def dtype(self, node: IRNode) -> Optional[str]: + """ + Generates code which represents dtype of a given node. + """ + + if node is None: + return "void" + return DTYPE_TO_CPP.get(node.get_layout().dtype) + + def cutlass_dtype(self, node: IRNode, default_dtype="void") -> Optional[str]: + # Helper method, called into from CUTLASSGemmTemplate + if node is None: + return default_dtype + from torch._inductor.codegen.cuda.cuda_template import CUTLASSTemplate + + return CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype] + + def max_valid_index(self, node: IRNode, default=-1): + # Helper method, called into from CUTLASSGemmTemplate + if node is None: + return default + max_valid_offset = 0 + for i in range(len(node.get_size())): + max_valid_offset += (node.get_size()[i] - 1) * node.get_stride()[i] + return max_valid_offset + + def offset(self, node: IRNode) -> str: + """ + Generates code which represents offset of a given node. + """ + + if node is None: + return "0" + return str(node.get_layout().offset) # type: ignore[union-attr] + + def ptr(self, node: IRNode) -> str: + """ + Generates code which represents pointer of a given node. + """ + + if node is None: + return "nullptr" + arg_name = self.arg_name(node) + if arg_name is None: + return "nullptr" + offset = self.offset(node) + return arg_name if offset == "0" else f"{arg_name} + {offset}" + + def size( + self, + node: IRNode, + start_index: int, + end_index: Optional[int] = None, + default_value: int = 0, + ) -> str: + """ + Hook called from template code to get the size of an arg. + Generates code which represents size of a given node in [start_index, end_index). + If node is None, returns default_value. + + TODO: Will add needed args to pass it in if it is dynamic. + """ + + if node is None: + return str(default_value) + + start_index = _normalize_idx(start_index, len(node.get_size())) + if end_index is None: + end_index = start_index + end_index = _normalize_idx(end_index, len(node.get_size())) + sizes = [ + self.find_symbol(node, "size", dim=i) or node.get_size()[i] + for i in range(start_index, end_index + 1) + ] + if len(sizes) == 0: + return str(default_value) + + sizes = [symbols(v) if isinstance(v, str) else v for v in sizes] + val = sympy_product(sizes) + return val + + def stride(self, node: IRNode, index: int, default_value: int = 0) -> str: + """ + Hook called from template code to get the stride of an arg. + Generates code which represents stride of a given node at index. + If node is None, returns default_value. + + TODO: Will add needed args to pass it in if it is dynamic. 
+ """ + + if node is None: + return str(default_value) + + index = _normalize_idx(index, len(node.get_size())) + if index < 0: + return str(default_value) + + stride = node.get_stride()[index] + if V.graph.sizevars.statically_known_leq(stride, 1): + return str(stride) + return self.find_symbol(node, "stride", dim=index) or str(stride) + + def batch_stride(self, node: IRNode, default_value: int = 0) -> str: + """ + Hook called from template code to get the batch stride of an arg. + Returns 0 if batch dim is not present. + + This method assumes that batch stride is the largest stride. + """ + + if node is None: + return str(default_value) + + if len(node.get_size()) < 3: + return str(default_value) + + batch_stride = node.get_stride()[0] + if V.graph.sizevars.statically_known_leq(batch_stride, 1): + return str(batch_stride) + + return "{}*{}".format( + self.find_symbol(node, "size", dim=1) or node.get_size()[1], + self.find_symbol(node, "size", dim=2) or node.get_size()[2], + ) + + def row_or_column_stride(self, node: IRNode, default_value: int = 0) -> str: + """ + Hook called from template code to get the row or column stride of an arg. + This is required by some CUTLASS 2.X APIs. + If the node is in row_major, it returns stride[-2]. + If the node is in column_major, it returns stride[-1]. + + TODO: Will add needed args to pass it in if it is dynamic. + """ + + if node is None or len(node.get_stride()) < 2: + return str(default_value) + + stride0 = node.get_stride()[-1] + stride1 = node.get_stride()[-2] + if stride0 == 1: + return cexpr(self.rename_indexing(stride1)) + elif stride1 == 1: + return cexpr(self.rename_indexing(stride0)) + else: + raise RuntimeError( + f"At least 1 stride should be 1. Strides: {node.get_stride()=}" + ) + + def load(self, name: str, index: Expr, mode: Any = None) -> CSEVariable: + """ + Mock load function for memory planning to optimize allocations properly. + """ + return self.create_cse_var(name, bounds=ValueRanges.unknown()) + + def store(self, name: str, index: Expr, value: Any, mode: Any = None) -> None: + """ + Mock store function for memory planning to optimize allocations properly. + """ + self.store_buffer_names.add(name) + + +class CUDATemplateCaller(ChoiceCaller): + """ + CUDATemplateCaller + + This class represents a caller for CUDA template kernels. It is a subclass of ChoiceCaller. + Attributes: + name (str): The name of the caller. + category (str): The category of the caller. + bmreq (CUDABenchmarkRequest): The benchmark request for the caller. + template_buffer (CUDATemplateBuffer): The template buffer for the caller. 
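+        supports_epilogue_fusion (bool): Whether epilogue fusion is supported for this caller's kernel.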
+ """ + + def __init__( + self, + name: str, + category: str, + input_nodes: list[Buffer], + layout: Layout, + make_kernel_render: Callable[ + [CUDATemplateBuffer, Optional[list[BaseSchedulerNode]]], + tuple[CUDATemplateKernel, functools.partial[str]], + ], + bmreq: CUDABenchmarkRequest, + supports_epilogue_fusion: bool, + template: "CUDATemplate", # type: ignore[name-defined] + info_kwargs: Optional[ + dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]] + ], # type: ignore[type-arg] + description: str, + ) -> None: + super().__init__(name, input_nodes, layout, description) + self.category = category + self.make_kernel_render = make_kernel_render + self.bmreq = bmreq + self.supports_epilogue_fusion = supports_epilogue_fusion + self.template = template + self.info_kwargs = info_kwargs + + def precompile(self) -> None: + assert self.bmreq is not None + self.bmreq.precompile() + + def benchmark(self, *args, out) -> float: + assert self.bmreq is not None + if config.profile_bandwidth_with_do_bench_using_profiling: + algo = self.bmreq.make_run_fn(*args, out=out) + return do_bench_using_profiling(algo) + return self.bmreq.benchmark(*args, out=out) + + def __str__(self) -> str: + return f"CUDATemplateCaller(source_file={self.bmreq.source_file})" + + def call_name(self) -> str: + return f"cuda_template_kernels.{self.name}" + + def kernel_hash_key(self) -> str: + """ + Return kernel hash key that does not depend on swizzle. + """ + return "-".join( + [ + self.category, + self.bmreq.hash_key, + ] + ) + + def hash_key(self) -> str: + """ + Return kernel hash key that does not depend on swizzle. + """ + return "-".join( + [ + self.category, + self.bmreq.hash_key, + str(self.info_dict().get("swizzle")), + ] + ) + + def info_dict(self) -> dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]: + """Information returned here is logged to the autotune log file when that is enabled.""" + if self.info_kwargs is not None and "op" in self.info_kwargs: + op: Any = self.info_kwargs["op"] + return { + "backend": "CUDA", + "op_type": type(op).__name__, + "op_conf_name": str(op.configuration_name()), + "op_arch": str(op.arch), + "tile_shape": str(op.tile_description.tile_shape), + "epilogue_schedule": str(op.epilogue_schedule), + "kernel_schedule": str(op.kernel_schedule), + "element_accumulator": str(op.accumulator_type()), + "op_name": str(op.procedural_name()), + "instruction_shape": str( + op.tile_description.math_instruction.instruction_shape + ), + "swizzle": str(self.info_kwargs["swizzle"]), + } + else: + return {"backend": "CUDA", "op_type": "unknown"} + + def output_node(self) -> TensorBox: + self.bmreq.update_workspace_size() + return TensorBox.create( + CUDATemplateBuffer( + layout=self.layout, + inputs=self.input_nodes, + make_kernel_render=self.make_kernel_render, + workspace_size=self.bmreq.workspace_size, + supports_epilogue_fusion=self.supports_epilogue_fusion, + template=self.template, + ) + ) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_template.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_template.py new file mode 100644 index 0000000000000000000000000000000000000000..7ed67b0daa49f65394b55a6ba87cdc9b1958931b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cuda_template.py @@ -0,0 +1,318 @@ +# mypy: allow-untyped-defs +import functools +import hashlib +import itertools +from dataclasses import dataclass +from typing import Any, Optional, TYPE_CHECKING +from typing_extensions import 
override +from unittest.mock import patch + +import sympy + +import torch +from torch._inductor.utils import Placeholder +from torch._logging import getArtifactLogger + +from ...autotune_process import CUDABenchmarkRequest, TensorMeta +from ...ir import Buffer, CUDATemplateBuffer, IRNode, Layout +from ...utils import IndentedBuffer, unique +from ...virtualized import V +from ..common import KernelTemplate +from .cuda_kernel import CUDATemplateCaller, CUDATemplateKernel +from .cutlass_utils import DTYPE_TO_CUTLASS_TYPE + + +if TYPE_CHECKING: + from ...scheduler import BaseSchedulerNode # noqa: TC004 +else: + BaseSchedulerNode = Any + +GemmOperation = Any + +autotuning_log = getArtifactLogger(__name__, "autotuning") + + +@dataclass(frozen=True) +class ArgInfo: + name: str + ty: str + + +class CUDATemplate(KernelTemplate): + index_counter = itertools.count() + + def __init__( + self, + name: str, + input_nodes: list[Buffer], + layout: Layout, + input_reorder: Optional[list[int]] = None, + ) -> None: + """ + + Baseclass for CUDA C++ Templates, derived from KernelTemplate. Not to be instantiated directly. + + Args: + name (str): The name of the CUDATemplate object. + input_nodes (List[IRNode]): A list of input IRNodes. + layout (Layout): The layout of the output buffer / tensor. + input_reorder (Optional[List[int]]): An optional list that specifies the order of the input nodes. + + """ + super().__init__(name) + self.input_nodes = input_nodes + self.output_node: Buffer = Buffer(name="buf_out", layout=layout) + self.input_reorder = input_reorder + self.layout = layout + + @staticmethod + def supports_epilogue_fusion(op: GemmOperation) -> bool: + return False + + def generate( # type: ignore[override] + self, + description, + **kwargs, + ) -> CUDATemplateCaller: + """ + Generates the CUDA template caller object for the given GEMM template and operation. This CUDATemplateCaller + may be used to call and benchmark the generated CUDA kernel in a standalone manner to enable Autotuning. + + Args: + kwargs: Additional keyword arguments. + + Returns: + A CUDATemplateCaller object representing the generated CUDA template caller. 
+ """ + kernel_name = str(Placeholder.KERNEL_NAME) + with ( + patch.object(V.graph, "get_dtype", self._fake_get_dtype(self.output_node)), + CUDATemplateKernel( + kernel_name=kernel_name, + runtime_arg_info=self.get_runtime_arg_info(), + runtime_arg_values=self.get_runtime_arg_values(**kwargs), + ) as kernel, + ): + code = self.render(kernel=kernel, **kwargs) + _, call_args, _, _ = kernel.args.python_argdefs() + autotuning_log.debug("Generated Code:\n%s", code) + autotuning_log.debug( + "Args: cpp_argdefs: %s, python_argdefs: %s", + kernel.args.cpp_argdefs(DTYPE_TO_CUTLASS_TYPE), + kernel.args.python_argdefs(), + ) + + input_reorder = ( + self.input_reorder + if self.input_reorder is not None + else list(range(len(self.input_nodes))) + ) + expected_args = list( + unique(self.input_nodes[idx].get_name() for idx in input_reorder) + ) + expected_args.extend([self.output_node.get_name()]) + assert list(call_args)[: len(expected_args)] == expected_args, ( + call_args, + expected_args, + ) + V.graph.sizevars.size_hints(map(sympy.expand, call_args[len(expected_args) :])) + size_args = V.graph.sizevars.size_hints(kernel.get_dynamic_shape_args()) + extra_args = tuple(list(size_args) + self.get_runtime_arg_values(**kwargs)) + + kernel_hash = hashlib.sha256(code.encode("utf-8")).hexdigest()[:8] + kernel_name = f"cutlass_{kernel_hash}" + code = code.replace(self.name, kernel_name) + + # create the BenchmarkRequest + bmreq = CUDABenchmarkRequest( + kernel_name=kernel_name, + input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes), + output_tensor_meta=TensorMeta.from_irnodes(self.output_node), + extra_args=extra_args, + source_code=code, + ) + + # kwargs has "op" argument in case of CUTLASSGemmTemplate + op = kwargs["op"] + if not op: + supports_epilogue_fusion = False + else: + # epilogue fusion is only supported for TMA kernels + supports_epilogue_fusion = self.supports_epilogue_fusion(op) + + def make_kernel_render( + template_node: CUDATemplateBuffer, + epilogue_nodes: Optional[list[BaseSchedulerNode]] = None, + ) -> tuple[CUDATemplateKernel, functools.partial[str]]: + assert supports_epilogue_fusion or not epilogue_nodes, ( + "epilogue fusion is not supported for this kernel" + ) + kernel = CUDATemplateKernel( + kernel_name=str(Placeholder.KERNEL_NAME), + runtime_arg_info=self.get_runtime_arg_info(), + runtime_arg_values=self.get_runtime_arg_values(**kwargs), + ) + render = functools.partial( + self.render, + kernel=kernel, + template_buffer_node=template_node, + epilogue_nodes=epilogue_nodes, + **kwargs, # includes "op" argument in case of CUTLASSGemmTemplate + ) + return kernel, render + + return CUDATemplateCaller( + kernel_name, + "cutlass_gemm", + self.input_nodes, + self.output_node.get_layout(), + make_kernel_render, + bmreq, + supports_epilogue_fusion, + self, + kwargs, + description, + ) + + def header(self) -> IndentedBuffer: + res = IndentedBuffer() + res.splice( + """ + #include + #include + #include + #include + #include + """ + ) + return res + + def globals(self) -> IndentedBuffer: + res = IndentedBuffer() + res.splice( + """ + // We compile all models with -fvisibility=hidden. Any symbols that need to be + // exposed in the final shared library must be declared with PT_EXPORT to make + // them visible. 
+ #ifdef __GNUC__ // Applies to any compiler with GNU extensions (clang and g++) + #define PT_EXPORT __attribute__((__visibility__("default"))) + #else + #ifdef _WIN32 + #define PT_EXPORT __declspec(dllexport) + #else + #define PT_EXPORT + #endif + #endif + """ + ) + return res + + def render(self, **kwargs) -> str: + raise NotImplementedError + + def get_runtime_arg_info(self) -> list[ArgInfo]: + return [] + + def get_runtime_arg_values(self, **kwargs) -> list[Any]: + return [] + + +class CUTLASSTemplate(CUDATemplate): + """ + CUTLASSTemplate is a class that provides a template for generating CUTLASS Templates. Used as a baseclass for the + CUTLASSGemmTemplate, providing functionality that might also be relevant for non-GEMM CUTLASS Kernels. + """ + + def header(self) -> IndentedBuffer: + res = super().header() + res.splice( + """ + #include "cute/tensor.hpp" + #include "cutlass/cutlass.h" + #include "cutlass/numeric_types.h" + #include "cutlass/tensor_ref.h" + #include "cutlass/util/host_tensor.h" + #include "cutlass/util/reference/host/tensor_fill.h" + #include "cutlass/util/reference/device/tensor_fill.h" + #include "cutlass/util/device_memory.h" + """ + ) + return res + + def globals(self) -> IndentedBuffer: + res = super().globals() + res.splice( + """ + using namespace cute; + #define CUTLASS_CHECK(status) \\ + { \\ + cutlass::Status error = status; \\ + if (error != cutlass::Status::kSuccess) { \\ + auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " + \\ + cutlassGetStatusString(error) + " at: " + std::to_string(__LINE__); \\ + throw std::runtime_error(msg); \\ + } \\ + } + + // Used as pass-through functor in EVT just for type casting / rounding + template + struct identity_op { + CUTLASS_HOST_DEVICE + T operator()(T val) const { return val; } + }; + + """ + ) + return res + + def cute_int(self, int_str: str, var_name: str) -> str: + res = "" + if int_str in ("1", "1L"): + res = "cute::Int<1>{}" + else: + res = int_str + + return f"{res} /* {var_name} */" + + _DTYPE_TO_CUTLASS = { + torch.float32: "float", + torch.float64: "double", + torch.float16: "cutlass::half_t", + torch.int32: "int32_t", + torch.int16: "int16_t", + torch.int8: "int8_t", + torch.uint8: "uint8_t", + torch.bool: "bool", + torch.bfloat16: "cutlass::bfloat16_t", + torch.float8_e4m3fn: "cutlass::float_e4m3_t", + } + + _DTYPE_TO_CUTLASS_SPARSE_META = { + torch.int32: "uint32_t", + torch.int16: "uint16_t", + } + + def cutlass_type_cast(self, node: IRNode, ptr: str) -> str: + if node is None: + return ptr + else: + return f"({self._DTYPE_TO_CUTLASS.get(node.get_dtype())}*)({ptr})" + + def cutlass_sparse_meta_type_cast(self, node: IRNode, ptr: str) -> str: + if node is None: + return ptr + else: + return ( + f"({self._DTYPE_TO_CUTLASS_SPARSE_META.get(node.get_dtype())}*)({ptr})" + ) + + @override + def get_runtime_arg_info(self) -> list[ArgInfo]: + return [ArgInfo("swizzle", "const uint8_t")] + + @override + def get_runtime_arg_values(self, **kwargs) -> list[Any]: + """ + Helper method to retrieve runtime args from generate kwargs + """ + return [kwargs[arg.name] for arg in self.get_runtime_arg_info()] diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_cache.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..7afdd654ea74ad11bcf03d097a6140bebe1998c4 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_cache.py @@ -0,0 +1,105 @@ +# 
mypy: allow-untyped-defs +import functools +import hashlib +import json +import logging +import os +import time +from typing import Any, Optional + +import torch._inductor.config as config +from torch._inductor.codecache import cutlass_key +from torch._inductor.codegen.cuda.cuda_env import get_cuda_arch, get_cuda_version +from torch._inductor.codegen.cuda.serialization import get_cutlass_operation_serializer +from torch._inductor.runtime.cache_dir_utils import cache_dir +from torch._inductor.utils import clear_on_fresh_cache + + +log = logging.getLogger(__name__) + + +CONFIG_PREFIX: str = "configs" + + +def get_config_request_key( + arch: str, + cuda_version: str, + instantiation_level: str, +) -> str: + """ + Return a key for the full ops, based on cutlass key, arch, cuda version, and instantiation level. + """ + hash_target = "-".join( + [ + cutlass_key().hex(), + arch, + cuda_version, + instantiation_level, + ] + ) + return hashlib.sha256(hash_target.encode("utf-8")).hexdigest()[0:8] + + +def _generate_config_filename(request_key: str) -> str: + """ + Generate a filename for the full ops. + """ + return f"{CONFIG_PREFIX}_{request_key}.json" + + +@clear_on_fresh_cache +@functools.cache +def maybe_fetch_ops() -> Optional[list[Any]]: + """ + Fetch ops from databases. + """ + if config.force_disable_caches: + return None + + # setup + arch: str = get_cuda_arch() + # get_cuda_version might return "12.4.0" or "12.4" + # but we want to use "12.4" + version: str = ".".join(get_cuda_version().split(".")[:2]) + instantiation_level: str = config.cuda.cutlass_instantiation_level + + # filename and filepath + request_key: str = get_config_request_key(arch, version, instantiation_level) + filename: str = _generate_config_filename(request_key) + filepath: str = os.path.join(cache_dir(), filename) + + # try fetch + serialized_ops: Optional[list[str]] = None + start_time = time.time() + if os.path.isfile(filepath): + # locally + try: + with open(filepath) as f: + serialized_ops = json.load(f) + + assert isinstance(serialized_ops, list), ( + f"Expected serialized ops is a list, got {type(serialized_ops)}" + ) + except Exception as e: + log.warning( + "Failed to load CUTLASS config %s from local cache: %s", + filename, + e, + ) + serialized_ops = None + elif config.is_fbcode(): + from torch._inductor.fb.cutlass_remote_cache import ( + maybe_fetch_cutlass_configs_from_remote, + ) + + # from remote + serialized_ops = maybe_fetch_cutlass_configs_from_remote(filepath) + + if serialized_ops is None: + return None + + # deserialize + serializer = get_cutlass_operation_serializer() + full_ops = [serializer.deserialize(x) for x in serialized_ops] # type: ignore[union-attr] + log.info("Loaded ops from %s cache in %.3fs", filename, time.time() - start_time) + return full_ops diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__init__.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/evt_extensions.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/evt_extensions.py new file mode 100644 index 0000000000000000000000000000000000000000..becbf1f2c5528e00d02955edccf1aeab48d3ea85 --- /dev/null +++ 
b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/evt_extensions.py @@ -0,0 +1,240 @@ +from typing import Any, Callable, Union + +from sympy import Expr + +from torch._inductor.ir import ( + ComputedBuffer, + InputBuffer, + is_contiguous_strides_for_shape, +) +from torch.utils._ordered_set import OrderedSet + +from ..cutlass_utils import torch_dtype_to_cutlass_type, try_import_cutlass + + +EpilogueFunctor = Any # EpilogueFunctor local class defined in _trace +Buffer = Union[ComputedBuffer, InputBuffer] +CutlassTupleType = Any # cutlass.backend.c_types.tuple_factory_..TupleType +CutlassVisitorType = Any # cutlass.backend.c_types.visitor_factory..VisitorType +CutlassArgType = ( + Any # Can be a CutlassTupleType, CutlassVisitorType, EmptyByte, or ctype.c_void_p +) + + +if try_import_cutlass(): + import ast + import ctypes + import textwrap + from typing import Union + + from cutlass.backend.c_types import ( # type: ignore[import-untyped, import-not-found] + EmptyByte, + ) + from cutlass.backend.epilogue import ( # type: ignore[import-untyped, import-not-found] + dtype2ctype, + ) + from cutlass.backend.evt import ( # type: ignore[import-untyped, import-not-found] + EpilogueFunctorVisitor, + ) + from cutlass.backend.evt.backend.emitter_base import ( # type: ignore[import-untyped, import-not-found] + FusionCallbacks, + ) + from cutlass.backend.evt.backend.sm90_emitter import ( # type: ignore[import-untyped, import-not-found] + CollectiveEpilogue, + ) + from cutlass.backend.evt.frontend import ( # type: ignore[import-untyped, import-not-found] + PythonASTFrontend, + ) + from cutlass.backend.evt.ir.tensor import ( # type: ignore[import-untyped, import-not-found] + Tensor as CutlassTensor, + ) + from cutlass_library import ( + DataType, + EpilogueScheduleType, + LayoutType, + TileDescription, + ) + + from torch._inductor.codegen.cuda import cuda_env + from torch._inductor.utils import IndentedBuffer + + _CUTLASS_C_DTYPES = OrderedSet(dtype2ctype.values()) # type: ignore[var-annotated] + + def create_example_tensors( + var_name_to_buffer_name: dict[str, str], + name_to_buffer: dict[str, Buffer], + size_hint_fn: Callable[[Union[Expr, int]], int], + ) -> dict[str, CutlassTensor]: + def cutlass_tensor_from_buffer(buffer: Buffer) -> CutlassTensor: + shape = buffer.get_layout().size + stride = buffer.get_layout().stride + shape = tuple(size_hint_fn(x) for x in shape) + stride = tuple(size_hint_fn(x) for x in stride) + + is_row_major = is_contiguous_strides_for_shape(stride, shape) + is_column_major = is_contiguous_strides_for_shape(stride[::-1], shape[::-1]) + + if not is_row_major and not is_column_major: + raise RuntimeError( + f"Cannot create example tensor for {buffer.get_name()} with \ +non-contiguous layout, received stride: {stride} and shape: {shape}" + ) + + return CutlassTensor( + shape=shape, + layout_tag=LayoutType.RowMajor + if is_row_major + else LayoutType.ColumnMajor, + element=torch_dtype_to_cutlass_type(buffer.get_layout().dtype), + ) + + return { + key: cutlass_tensor_from_buffer(name_to_buffer[name]) + for key, name in var_name_to_buffer_name.items() + } + + def trace( + fn_src: str, + example_tensors: dict[str, CutlassTensor], + accum_type: DataType, + output_type: DataType, + tile_description: TileDescription, + epilogue_schedule: EpilogueScheduleType, + name_to_buffer: dict[str, Buffer], + size_hint_fn: Callable[[Union[Expr, int]], int], + **kwargs: dict[str, Any], + ) -> tuple[str, str, str]: + cuda_arch = int(cuda_env.get_cuda_arch()) # 
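# --- illustrative sketch, not part of the patch ---
# create_example_tensors above decides between RowMajor and ColumnMajor by
# checking whether the strides are contiguous for the shape, or for the
# reversed shape. A simplified reimplementation of that check, with made-up
# shapes and strides, behaves like this:
def _is_contiguous(stride: tuple[int, ...], shape: tuple[int, ...]) -> bool:
    expected = 1
    for s, dim in zip(reversed(stride), reversed(shape)):
        if dim != 1 and s != expected:
            return False
        expected *= dim
    return True

shape = (4, 8)
assert _is_contiguous((8, 1), shape)              # row-major strides
assert _is_contiguous((4, 1), shape[::-1])        # column-major strides pass the reversed check
assert not _is_contiguous((16, 2), shape)         # neither layout; the patch raises RuntimeError here
# --- end sketch ---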
type: ignore[arg-type] + assert cuda_arch >= 90, "Only SM90+ is supported for EVT" + epilogue_functor = _trace(fn_src, example_tensors, cuda_arch, **kwargs) + visitor = EpilogueFunctorVisitor(cuda_arch, epilogue_functor) + fusion_callbacks = FusionCallbacks(visitor.graph, cuda_arch, emit_CD=False) + collective_epilogue = CollectiveEpilogue( + tile_description, + epilogue_schedule, + accum_type, + output_type, + fusion_callbacks, + ) + evt_name, evt_code = collective_epilogue.emit() + evt_args = _render_argument_type(epilogue_functor, name_to_buffer, size_hint_fn) + return evt_name, evt_args, evt_code + + # Based off of + # https://github.com/NVIDIA/cutlass/blob/df18f5e4f5de76bed8be1de8e4c245f2f5ec3020/python/cutlass/epilogue/epilogue.py#L117 + # This is modified to enable directly passing the source code of the epilogue vs getting it from a bona-fide python function + # The reason for this is that inspect.getsource does not work with functions defined at runtime via exec/eval + def _trace( + fn_src: str, example_tensors: dict[str, CutlassTensor], cc: int, **kwargs: Any + ) -> EpilogueFunctor: + class EpilogueFunctor(PythonASTFrontend): + def __init__(self, cc: int, **kwargs: Any): + self.source = textwrap.dedent(fn_src) + super().__init__(cc, **kwargs) + + def parse(self, example_inputs: dict[str, CutlassTensor]) -> None: + self.example_inputs = example_inputs + self.ast = ast.parse(self.source) + self.visit(self.ast) + + cc = int(cuda_env.get_cuda_arch()) + epilogue_functor = EpilogueFunctor(cc=cc, **kwargs) + epilogue_functor.trace(example_tensors) + return epilogue_functor + + def _render_argument_type( + epilogue_functor: EpilogueFunctor, + name_to_buffer: dict[str, Buffer], + size_hint_fn: Callable[[Union[Expr, int]], int], + ) -> str: + epilogue_thread_type = epilogue_functor.epilogue_thread_type + + # Fragile, but this is the only way to guarantee t is expected type because t is a local class + def is_nested_visitor_type(t: type) -> bool: + return ( + ".".join([t.__module__, t.__qualname__]) + == "cutlass.backend.c_types.visitor_factory..VisitorType" + ) + + buffer = IndentedBuffer() + with buffer.set_tabwidth(2): + + def render_argument_type(name: str, t: CutlassArgType) -> None: + if issubclass(t, ctypes.c_byte): + buffer.writeline(f"{{}}, /* {name} */") + else: + fields = [ + ( + fname, + _get_arg_from_node(ty, name_to_buffer[name], size_hint_fn), + ) + for fname, ty in t._fields_ + ] + field_strs = [ + f"/* {fname} */ {str(field)}" for fname, field in fields + ] + buffer.writeline(f"{{{', '.join(field_strs)}}}, /* {name} */") + + def render_thread_type(name: str, t: CutlassArgType) -> None: + if is_nested_visitor_type(t): + buffer.writeline(f"{{ /* {name} */") + with buffer.indent(): + for name, inner_t in t._fields_: + render_thread_type(name, inner_t) + buffer.writeline("},") + else: + render_argument_type(name, t) + + # unroll the recursion once to address special case formatting + # namely, no ending comma and no indentation for the outermost thread type + buffer.writeline("{ /* thread */") + with buffer.indent(3): + if is_nested_visitor_type(epilogue_thread_type): + with buffer.indent(): + for name, inner_t in epilogue_thread_type._fields_: + render_thread_type(name, inner_t) + else: + render_argument_type("thread", epilogue_thread_type) + buffer.writeline("}") + + return buffer.getvalue() + + def _get_arg_from_node( + arg_ty: type, node: Buffer, size_hint_fn: Callable[[Union[Expr, int]], int] + ) -> str: + from ..cuda_template import CUTLASSTemplate + + # Today, arguments 
are either a pointer to the + # node's memory, a stride tuple, the datatype + # Once again, need to check for local class type for stride tuple + if ( + str(arg_ty) + == ".TupleType'>" + ): + DEFAULT_STRIDE_LEN = 3 + assert len(node.get_layout().stride) <= DEFAULT_STRIDE_LEN + stride = [size_hint_fn(x) for x in node.get_layout().stride] + for _ in range(DEFAULT_STRIDE_LEN - len(stride)): + stride.append(0) + + def render_stride(x: int) -> str: + # Handle EBO for 0 and 1 + if x == 0: + return "_0{}" + elif x == 1: + return "_1{}" + else: + return str(x) + + return f"{{{', '.join([render_stride(x) for x in stride])}}}" + + elif issubclass(arg_ty, ctypes.c_void_p): + return f"({CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype]}*) {node.get_name()}" + elif ( + arg_ty in _CUTLASS_C_DTYPES + ): # Assumption: this is the element dtype, this holds for all cutlass ir nodes currently + return f"{CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype]}(0)" + elif issubclass(arg_ty, EmptyByte): + return "{}" + + raise NotImplementedError(f"Unsupported arg type: {arg_ty}") diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py new file mode 100644 index 0000000000000000000000000000000000000000..a2ab413ccf66e818cf01756ce4a078f40803ce37 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py @@ -0,0 +1,411 @@ +# mypy: ignore-errors +from ..cutlass_utils import try_import_cutlass + + +# copied / modified from original at +# https://github.com/NVIDIA/cutlass/blob/8783c41851cd3582490e04e69e0cd756a8c1db7f/tools/library/scripts/gemm_operation.py#L658 + +if try_import_cutlass(): + import enum + + from cutlass_library.gemm_operation import * # noqa: F401, F403 + from cutlass_library.library import * # noqa: F401, F403 + + _LOGGER = logging.getLogger(__name__) + + class EmitGemmUniversal3xInstanceWithEVT: + """Responsible for emitting a CUTLASS 3.x template definition""" + + def __init__(self, operation_suffix="", evt_name=None): + self.operation_suffix = operation_suffix + self.includes = [ + "cutlass/cutlass.h", + "cutlass/gemm/gemm.h", + "cutlass/numeric_types.h", + "cutlass/gemm/kernel/gemm_universal.hpp", + "cutlass/gemm/collective/collective_builder.hpp", + "cutlass/epilogue/collective/collective_builder.hpp", + ] + self.builtin_epilogue_functor_template = """${epilogue_functor}< + ${element_d}, + ${element_epilogue}, + ${element_c}, + ${element_epilogue} + >""" + + self.evt_name = evt_name + self.gemm_template = """ +using ${operation_name}_epilogue = +typename cutlass::epilogue::collective::CollectiveBuilder< + ${arch}, ${opcode_class_epi}, + cute::Shape, + cute::Shape<${cluster_shape_m}, ${cluster_shape_n}, ${cluster_shape_k}>, + ${epi_tile_mn}, + ${element_accumulator}, ${element_epilogue}, + ${element_c}, ${layout_c}, ${align_c}, + ${element_d}, ${layout_d}, ${align_d}, + ${epilogue_schedule}, + ${epilogue_functor} +>::CollectiveOp; + +${mixed_dtype_prepare_code} + +using ${operation_name}_mainloop = +typename cutlass::gemm::collective::CollectiveBuilder< + ${arch}, ${opcode_class_main}, + ${element_a}, ${layout_a}, ${align_a}, + ${element_b}, ${layout_b}, ${align_b}, + ${element_accumulator}, + cute::Shape, + cute::Shape<${cluster_shape_m}, ${cluster_shape_n}, ${cluster_shape_k}>, + ${stages}, + ${kernel_schedule} 
+>::CollectiveOp; + +// Gemm operator ${operation_name} +using ${operation_name}_base = cutlass::gemm::kernel::GemmUniversal< + ${problem_shape}, + ${operation_name}_mainloop, + ${operation_name}_epilogue, + ${tile_scheduler}>; + +// Define named type +struct ${operation_name} : +public ${operation_name}_base { }; + + """ + + # + def instance_template(self): + return """ +${compile_guard_start} +{ + using GemmKernel = cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>; + manifest.append( + new ${gemm_kind}("${operation_name}")); +} +${compile_guard_end} + """ + + def emit_block_scale_epilogue_functor(self, operation): + block_scaled_template = """ + ${epilogue_functor}< + ${epi_vs}, + ${element_d}, + ${element_accumulator}, + ${element_sfd}, + ${layout_sfd}, + ${element_c}, + ${element_scalar} + > + """ + block_scaled_values = { + "epi_vs": str(operation.ScaleFactorVectorSize), + "element_d": str(DataTypeTag[operation.D.element]), + "element_sfd": str(DataTypeTag[operation.ScaleFactorD.element]), + "layout_sfd": LayoutTag[operation.ScaleFactorD.layout], + "epilogue_functor": EpilogueFunctor3xTag[ + EpilogueFunctor3x.LinearCombinationBlockScaleFactor + ], + "element_accumulator": str(DataTypeTag[operation.accumulator_type()]), + "element_scalar": str(DataTypeTag[operation.accumulator_type()]), + "element_c": str(DataTypeTag[operation.C.element]), + } + return SubstituteTemplate(block_scaled_template, block_scaled_values) + + @staticmethod + def pointerize_if_grouped(operation, layout): + return layout if not is_grouped(operation.gemm_kind) else layout + "* " + + @staticmethod + def problem_shape(operation): + gemm_shape_type = "cute::Shape" + grouped_gemm_shape_type = "cute::Shape" + grouped_gemm_shape_type = ( + "cutlass::gemm::GroupProblemShape<" + grouped_gemm_shape_type + ">" + ) + + return ( + gemm_shape_type + if not is_grouped(operation.gemm_kind) + else grouped_gemm_shape_type + ) + + def emit(self, operation): + """Given a gem operation, emits a template definition of the operation""" + + opcode_class_main = operation.tile_description.math_instruction.opcode_class + opcode_class_epi = opcode_class_main + + tile_shape = operation.tile_description.tile_shape + instruction_shape = ( + operation.tile_description.math_instruction.instruction_shape + ) + cluster_m = operation.tile_description.cluster_shape[0] + cluster_n = operation.tile_description.cluster_shape[1] + + tile_shape_m, tile_shape_n, tile_shape_k = tile_shape + + # account for static/dynamic cluster shapes + cta_m = tile_shape[0] // cluster_m if cluster_m > 0 else tile_shape[0] + cta_n = tile_shape[1] // cluster_n if cluster_n > 0 else tile_shape[1] + + # Shape passed to epilogue builder + is_sm100_kernel = operation.arch == 100 + if is_sm100_kernel: + cta_m_per_mma_instruction = ( + 2 if "2sm" in operation.procedural_name() else 1 + ) + if cluster_m <= 0: + cta_m = cta_m // cta_m_per_mma_instruction + + if opcode_class_main in [ + OpcodeClass.TensorOp, + OpcodeClass.BlockScaledTensorOp, + ]: + tile_shape_m = instruction_shape[0] + tile_shape_n = instruction_shape[1] + + # stage count set to zero indicates builder automatic stage selection + if operation.tile_description.stages > 0: + stage_count_string = f"cutlass::gemm::collective::StageCount<\ +{str(operation.tile_description.stages)}>" + else: + stage_count_string = ( + f"cutlass::gemm::collective::StageCountAutoCarveout(\ +sizeof(typename {str(operation.procedural_name())}_epilogue::SharedStorage))>" + ) + + epi_tile_mn = 
"cutlass::epilogue::collective::EpilogueTileAuto" + + ( + instance_layout_A, + instance_layout_B, + instance_layout_C, + instance_layout_D, + ) = ( + operation.A.layout, + operation.B.layout, + operation.C.layout, + operation.D.layout, + ) + + # 3.0 profiler integration only supports trivial epilogues for now + epilogue_vector_length = 1 + + # Support built-in epilogue functors or user-defined functions + if isinstance(operation.epilogue_functor, enum.Enum): + values = { + "element_epilogue": str(DataTypeTag[operation.element_epilogue]), + "epilogue_functor": EpilogueFunctor3xTag[ + operation.epilogue_functor + ], + } + epilogue_functor = SubstituteTemplate( + self.builtin_epilogue_functor_template, values + ) + + if ( + is_block_scaled(operation.gemm_kind) + and operation.ScaleFactorD.element != DataType.void + ): + epilogue_functor = self.emit_block_scale_epilogue_functor(operation) + else: + epilogue_functor = self.epilogue_functor.emit_declaration() + + if ( + is_block_scaled(operation.gemm_kind) + and operation.ScaleFactorD.element != DataType.void + ): + epilogue_functor = self.emit_block_scale_epilogue_functor(operation) + + # + # Cutlass3x complex kernels' ElementA(B) is a tuple in collective mainloop builder, + # e.g. cute::tuple, Transform : cute::identity / cute::conjugate. + element_a = ( + DataTypeTag[operation.A.element] + if not operation.is_complex() + else f"cute::tuple<{str(DataTypeTag[operation.A.element])},\ +{str(ComplexTransformTag3x[operation.A.complex_transform])}>" + ) + element_b = ( + DataTypeTag[operation.B.element] + if not operation.is_complex() + else f"cute::tuple<{str(DataTypeTag[operation.B.element])},\ +{str(ComplexTransformTag3x[operation.B.complex_transform])}>" + ) + epilogue_schedule_type = EpilogueScheduleTag[operation.epilogue_schedule] + + if opcode_class_main == OpcodeClass.BlockScaledTensorOp: + is_no_smem_epilogue = operation.epilogue_schedule in [ + EpilogueScheduleType.NoSmemWarpSpecialized1Sm, + EpilogueScheduleType.NoSmemWarpSpecialized2Sm, + ] + grouped = is_grouped(operation.gemm_kind) + if cta_n == 256 and operation.kernel_schedule == to_grouped_schedule( + KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100, grouped + ): + epi_tile_mn = "cute::Shape" + if not is_no_smem_epilogue: + epilogue_schedule_type = EpilogueScheduleTag[ + to_grouped_schedule( + EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped + ) + ] + if cta_n == 256 and operation.kernel_schedule == to_grouped_schedule( + KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100, grouped + ): + epi_tile_mn = "cute::Shape" + if not is_no_smem_epilogue: + epilogue_schedule_type = EpilogueScheduleTag[ + to_grouped_schedule( + EpilogueScheduleType.TmaWarpSpecialized2Sm, grouped + ) + ] + element_a = f"cute::tuple<{str(element_a)},{str(DataTypeTag[operation.ScaleFactorA])}>" + element_b = f"cute::tuple<{str(element_b)},{str(DataTypeTag[operation.ScaleFactorB])}>" + + operation_name_str = operation.procedural_name() + layout_a_str = LayoutTag[instance_layout_A] + layout_b_str = LayoutTag[instance_layout_B] + mixed_dtype_prepare_code = "" + if operation.mixed_input_mode is not None: + A_dtype = operation.A.element + B_dtype = operation.B.element + A_dtype_bits = DataTypeSize[A_dtype] + B_dtype_bits = DataTypeSize[B_dtype] + is_A_dtype_narrow = A_dtype_bits < B_dtype_bits + if is_A_dtype_narrow: + narrow_dtype, wide_dtype = (A_dtype, B_dtype) + narrow_dtype_bits, wide_dtype_bits = (A_dtype_bits, B_dtype_bits) + else: + narrow_dtype, wide_dtype = (B_dtype, A_dtype) + narrow_dtype_bits, 
wide_dtype_bits = (B_dtype_bits, A_dtype_bits) + + narrow_tag = DataTypeTag[narrow_dtype] + wide_tag = DataTypeTag[wide_dtype] + scale_tag = DataTypeTag[wide_dtype] + zero_tag = DataTypeTag[wide_dtype] + + do_shuffle = False + value_shuffle_str = "" + if narrow_dtype_bits == 4 and wide_dtype_bits == 16: + value_shuffle_str = "cute::Layout, \ +cute::Stride>" + do_shuffle = True + if narrow_dtype_bits == 8 and wide_dtype_bits == 16: + value_shuffle_str = "cute::Layout, \ +cute::Stride>" + do_shuffle = True + do_shuffle = operation.mixed_input_shuffle and do_shuffle + + if do_shuffle: + if is_A_dtype_narrow: + stride_narrow_str = ( + f"cutlass::detail::TagToStrideA_t<{layout_a_str}>" + ) + layout_a_str = f"{operation_name_str}_LayoutNarrowReordered" + else: + stride_narrow_str = ( + f"cutlass::detail::TagToStrideB_t<{layout_b_str}>" + ) + layout_b_str = f"{operation_name_str}_LayoutNarrowReordered" + # The {operation_name_str}_ prefixs in mixed_dtype_prepare_code and + # layout_{a, b}_str are to prevent errors in Windows platform unity build + mixed_dtype_prepare_code = f""" + using {operation_name_str}_StrideNarrow = {stride_narrow_str}; + using {operation_name_str}_ValueShuffle = {value_shuffle_str}; + static constexpr int {operation_name_str}_NumShuffleAtoms = 1; + using {operation_name_str}_MmaAtomShape = \ +cute::Layout>>; + using {operation_name_str}_LayoutAtomQuant = \ +decltype(cutlass::compute_memory_reordering_atom<{wide_tag}, {operation_name_str}_MmaAtomShape, \ +{operation_name_str}_ValueShuffle>()); + using {operation_name_str}_LayoutNarrowReordered = \ +decltype(cute::tile_to_shape({operation_name_str}_LayoutAtomQuant{{}}, \ +cute::Layout, {operation_name_str}_StrideNarrow>{{}})); + """ + + mixed_input_modes_to_element = { + MixedInputMode.ConvertOnly: narrow_tag, + MixedInputMode.ScaleOnly: f"cute::tuple<{narrow_tag}, {scale_tag}>", + MixedInputMode.ScaleWithZeroPoint: f"cute::tuple<{narrow_tag}, {scale_tag}, {zero_tag}>", + } + narrow_element = mixed_input_modes_to_element.get( + operation.mixed_input_mode, narrow_tag + ) + + if narrow_dtype == DataType.s4 and ( + wide_dtype == DataType.e4m3 or wide_dtype == DataType.e5m2 + ): + narrow_element = ( + f"cute::tuple<{narrow_tag}, cutlass::Array<{scale_tag}, 8>>" + ) + + if is_A_dtype_narrow: + element_a = narrow_element + else: + element_b = narrow_element + + if self.evt_name: + epilogue_functor = self.evt_name + + values = { + "operation_name": operation_name_str, + "operation_suffix": self.operation_suffix, + "problem_shape": self.problem_shape(operation), + "element_a": element_a, + "layout_a": self.pointerize_if_grouped(operation, layout_a_str), + "element_b": element_b, + "layout_b": self.pointerize_if_grouped(operation, layout_b_str), + "element_c": DataTypeTag[operation.C.element], + "layout_c": self.pointerize_if_grouped( + operation, LayoutTag[instance_layout_C] + ), + "element_d": DataTypeTag[operation.D.element], + "layout_d": self.pointerize_if_grouped( + operation, LayoutTag[instance_layout_D] + ), + "element_accumulator": DataTypeTag[operation.accumulator_type()], + "opcode_class_main": OpcodeClassTag[opcode_class_main], + "opcode_class_epi": OpcodeClassTag[opcode_class_epi], + "arch": f"cutlass::arch::Sm{operation.arch}", + "tile_shape_m": str(tile_shape_m), + "tile_shape_n": str(tile_shape_n), + "tile_shape_k": str(tile_shape_k), + "cluster_shape_m": "cute::_" + + str(operation.tile_description.cluster_shape[0]) + if operation.tile_description.cluster_shape[0] > 0 + else "int", + "cluster_shape_n": "cute::_" + + 
str(operation.tile_description.cluster_shape[1]) + if operation.tile_description.cluster_shape[1] > 0 + else "int", + "cluster_shape_k": "cute::_" + + str(operation.tile_description.cluster_shape[2]) + if operation.tile_description.cluster_shape[2] > 0 + else "int", + "instruction_shape_m": str(instruction_shape[0]), + "instruction_shape_n": str(instruction_shape[1]), + "instruction_shape_k": str(instruction_shape[2]), + "kernel_schedule": str(KernelScheduleTag[operation.kernel_schedule]), + "epilogue_schedule": str(epilogue_schedule_type), + "epi_tile_mn": epi_tile_mn, + "epilogue_functor": epilogue_functor, + "stages": stage_count_string, + "align_a": str(operation.A.alignment), + "align_b": str(operation.B.alignment), + "align_c": str(operation.C.alignment), + "align_d": str(operation.C.alignment), + "transform_a": ComplexTransformTag[operation.A.complex_transform], + "transform_b": ComplexTransformTag[operation.B.complex_transform], + "math_operation": MathOperationTag[ + operation.tile_description.math_instruction.math_operation + ], + "epilogue_vector_length": str(epilogue_vector_length), + "element_epilogue": str(DataTypeTag[operation.element_epilogue]), + "tile_scheduler": str(TileSchedulerTag[operation.tile_scheduler]), + "mixed_dtype_prepare_code": mixed_dtype_prepare_code, + } + + return SubstituteTemplate(self.gemm_template, values) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_presets.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_presets.py new file mode 100644 index 0000000000000000000000000000000000000000..f0888e2ef29bb905102be7ed272df6015e38af6c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_presets.py @@ -0,0 +1,239 @@ +import functools +from collections import defaultdict + +import torch +from torch._inductor.codegen.cuda.cuda_env import get_cuda_arch + + +@functools.cache +def gen_cutlass_presets() -> dict[int, dict[str, list[str]]]: + """ + Generate cutlass presets for the given CUDA arch. 
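# --- illustrative sketch, not part of the patch ---
# The preset entries below are regular expressions over generated CUTLASS op
# names; matching a candidate name against a preset is one way they can be
# used to pre-filter configurations. The op name here is a made-up example,
# and whether fullmatch or search is applied is an assumption of this sketch.
import re

preset = (
    r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_"
    r".*_align.*_warpspecialized_cooperative_epi_tma"
)
op_name = (
    "cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_"
    "128x128x64_2x1x1_0_ttn_align8_warpspecialized_cooperative_epi_tma"
)
assert re.fullmatch(preset, op_name) is not None
# --- end sketch ---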
+ """ + presets: dict[int, dict[str, list[str]]] = {} + + if not torch._C._has_cuda: + return presets + + presets[0] = defaultdict(list) + arch = get_cuda_arch() + if arch == "90": + preset = presets[0] + preset["0"] = [ + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x2x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_64x256x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_256x128x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + ] + preset["1111"] = [ + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + ] + preset["2222"] = [ + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_1x1x1_0_.*_align.*_warpspecialized_pingpong_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_256x128x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x64_2x1x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + 
r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_256x128x64_1x2x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_256x128x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_1x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x64_1x1x1_0_.*_align.*_cpasync_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_1x1x1_0_.*_align.*", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_1x2x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_1x1x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_1x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_2x1x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + ] + preset["3333"] = [ + r"cutlass3x_sm90_tensorop_s64x48x16gemm_.*_64x48x64_1x4x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x64_1x1x1_0_.*_align.*_cpasync_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_1x4x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + 
r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_2x1x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_4x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x64_2x1x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_4x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_2x2x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x64_1x4x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_4x1x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x4x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_1x2x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_256x128x64_1x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x192x16gemm_.*_256x192x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_1x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_1x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_256x128x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + 
r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_1x4x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x64_2x2x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x192x16gemm_.*_256x192x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_1x1x1_0_.*_align.*_warpspecialized_epi_nosmem", + ] + preset["4444"] = [ + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_2x1x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_1x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x128_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x128_1x1x1_0_.*_align.*_cpasync_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x128_1x8x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x128_2x4x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x128_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x64_1x1x1_0_.*_align.*_cpasync_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x160x16gemm_.*_256x160x64_2x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x64_1x4x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + 
r"cutlass3x_sm90_tensorop_s64x192x16gemm_.*_64x192x64_4x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x128_1x1x1_0_.*_align.*_cpasync_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x128_1x4x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_256x128x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x160x16gemm_.*_256x160x64_1x2x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x128_2x1x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x160x16gemm_.*_256x160x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x128_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x128_2x2x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x128_4x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x128_1x2x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x128_2x4x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x128_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_1x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x128_2x4x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + ] + preset["5555"] = [ + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x128_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x128_2x4x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_1x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x32x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x64_1x1x1_0_.*_align.*_cpasync_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + 
r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_128x32x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x256_1x2x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x128_1x4x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x128_1x1x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x128_2x4x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x160x16gemm_.*_256x160x64_1x2x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_1x1x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_256x128x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x128_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_1x8x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x128_1x1x1_0_.*_align.*_cpasync_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x128_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x128_1x4x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x128_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x128_1x4x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x64_1x4x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x128_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x192x16gemm_.*_128x192x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x128_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x128_2x1x1_0_.*_align.*_warpspecialized_epi_nosmem", + 
r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_1x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x128_1x1x1_0_.*_align.*_cpasync_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x128_2x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x128_1x2x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_2x1x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x128_4x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x128_1x8x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x160x16gemm_.*_256x160x64_1x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x256_1x1x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x192x16gemm_.*_256x192x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + ] + + return presets diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_python_evt.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_python_evt.py new file mode 100644 index 0000000000000000000000000000000000000000..ca5e6031b19cd52e869fd1de853cbf21d348d2fb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_python_evt.py @@ -0,0 +1,322 @@ +import itertools +from collections.abc import Generator, Iterable, Iterator, Sequence +from contextlib import contextmanager +from os import linesep +from typing import Any, Optional + +import sympy + +import torch +import torch._inductor.virtualized as virtualized +from torch._inductor.ir import ComputedBuffer, Pointwise +from torch._inductor.ops_handler import DefaultHandler, WrapperHandler +from torch._inductor.scheduler import BaseSchedulerNode +from torch._inductor.utils import DelayReplaceLine, IndentedBuffer, OrderedSet +from torch._inductor.virtualized import OpsValue + +from ...virtualized import V + + +_ACCUMULATOR_ARG_NAME = "accum" + + +def scaled_mm_evt( + scale_A_name: str, scale_B_name: str, bias_name: Optional[str], output_name: str +) -> tuple[list[str], dict[str, Any], str]: + evt_read_names = [scale_A_name, scale_B_name] + var_name_to_buffer_name = {n: n for n in [scale_A_name, scale_B_name]} + var_name_to_buffer_name["D"] = output_name + 
var_name_to_buffer_name[_ACCUMULATOR_ARG_NAME] = output_name + expr = f"accum * {scale_A_name} * {scale_B_name}{linesep}" + if bias_name: + expr = f"({expr}) + {bias_name}" + evt_read_names.append(bias_name) + var_name_to_buffer_name[bias_name] = bias_name + + evt_py_code = f"def fn(accum, {','.join(evt_read_names)}):{linesep}\ + D = {expr}{linesep}\ + return D{linesep}" + + return evt_read_names, var_name_to_buffer_name, evt_py_code + + +class CutlassEVTOpsMixIn: + @staticmethod + def _infix_bin_op(op: str, a: str, b: str) -> str: + return f"{a} {op} {b}" + + @staticmethod + def _prefix_bin_op(op: str, a: str, b: str) -> str: + return f"{op}({a}, {b})" + + @staticmethod + def _prefix_un_op(op: str, a: str) -> str: + return f"{op}({a})" + + @staticmethod + def to_dtype( + x: str, + dtype: Any, + src_dtype: Optional[torch.dtype] = None, + use_compute_types: bool = False, + ) -> str: + return x + + @staticmethod + def constant(value: Any, dtype: Any) -> str: + raise NotImplementedError + + @staticmethod + def mul(x0: str, x1: str) -> str: + return CutlassEVTOpsMixIn._infix_bin_op("*", x0, x1) + + @staticmethod + def truediv(x0: str, x1: str) -> str: + return CutlassEVTOpsMixIn._infix_bin_op("/", x0, x1) + + @staticmethod + def ge(x0: str, x1: str) -> str: + raise NotImplementedError + + @staticmethod + def add(x0: str, x1: str) -> str: + return CutlassEVTOpsMixIn._infix_bin_op("+", x0, x1) + + @staticmethod + def relu(x0: str) -> str: + return CutlassEVTOpsMixIn._prefix_un_op("relu", x0) + + @staticmethod + def sigmoid(x0: str) -> str: + raise NotImplementedError("sigmoid is not supported in CUTLASS python evt") + + @staticmethod + def sub(x0: str, x1: str) -> str: + return CutlassEVTOpsMixIn._infix_bin_op("-", x0, x1) + + @staticmethod + def tanh(x0: str) -> str: + raise NotImplementedError("tanh is not supported in CUTLASS python evt") + + +class MockCutlassHandler(CutlassEVTOpsMixIn, WrapperHandler): + """Passthrough handler for cutlass ops, used for running epilogue nodes for memory planning""" + + +class _AssignmentFormatter(DefaultHandler): + def __init__(self, parent_handler: "CutlassEVTCodegen"): + self.parent_handler = parent_handler + + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + # Handle op dispatch here + if hasattr(self.parent_handler, name): + fn = getattr(self.parent_handler, name) + line = fn(*args, **kwargs) + if name in ("load", "store"): + return OpsValue(line) + else: + var = self.parent_handler._tmp_var() + line = DelayReplaceLine( + var, + lambda: "D" + if var == self.parent_handler.last_stored_var_name + else var, + f"{var} = {line}", + ) + self.parent_handler.body.writeline(line) + return OpsValue(var) + else: + raise NotImplementedError(name) + + +class CutlassEVTCodegen(CutlassEVTOpsMixIn): + """ + Notes: + * Used by CUTLASSGemmTemplate. + * This class should not be instantiated by users, it is intended to be used + by calling CutlassEVTCodegen.ir_to_evt_python_code(...) + which instantiates this class as an ops handler for virtualized.V.ops.[op-name] + * Extend this with more _op_ nodes to add support for new pointwise operations. + """ + + def __init__(self, accumulator_node_name: str, removed_buffers: OrderedSet[str]): + """ + + Initializes a CutlassEVTEpilogueArgumentFormatter object. Do not instantiate directly. + Use the CutlassEVTCodegen.ir_to_evt_python_code static method. 
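# --- illustrative sketch, not part of the patch ---
# The CutlassEVTOpsMixIn helpers above render pointwise ops into plain EVT
# expression strings (variable names here are hypothetical):
#
#     CutlassEVTOpsMixIn.mul("accum", "scale")  -> "accum * scale"
#     CutlassEVTOpsMixIn.add("tmp_0", "bias")   -> "tmp_0 + bias"
#     CutlassEVTOpsMixIn.relu("tmp_1")          -> "relu(tmp_1)"
#
# _AssignmentFormatter then assigns each rendered expression to a fresh
# tmp_<n> variable in the generated fn body.
# --- end sketch ---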
+ + Args: + accumulator_node_name: The name of the accumulator node which should contain + the Matmul result before fusion according to the IR graph. + epilogue_nodes: The list of scheduler nodes to be fused into the epilogue + """ + self.accumulator_node_name: str = accumulator_node_name # + self.body: IndentedBuffer = IndentedBuffer(1) # The body buffer for codegen + self.var_counter: Iterator[int] = itertools.count() + self.store_name_to_value: dict[str, OpsValue] = ( + dict() + ) # Aliases for subexpression functors + self.reads: OrderedSet[str] = OrderedSet([]) + # Used for creating example tensors + self.var_name_to_buffer_name: dict[str, str] = { + _ACCUMULATOR_ARG_NAME: accumulator_node_name + } + self.removed_buffers: OrderedSet[str] = removed_buffers + self.cur_node: Optional[ComputedBuffer] = None + self.name_to_buffer = V.graph.name_to_buffer | V.graph.graph_inputs + for name in V.graph.constants.keys(): + self.name_to_buffer[name] = V.graph.add_tensor_constant( + V.graph.constants[name], name + ) + self.is_D_assigned = False + self.D_var_name = None + + if accumulator_node_name not in removed_buffers: + # cannot return accumulator directly, so alias it + var = self._tmp_var() + self.body.writeline(f"{var} = {_ACCUMULATOR_ARG_NAME}") + self.store(accumulator_node_name, value=OpsValue(var)) + + @staticmethod + def ir_to_evt_python_code( + cuda_template_node_name: str, + epilogue_nodes: list[BaseSchedulerNode], + removed_buffers: OrderedSet[str], + ) -> tuple[list[str], list[str], dict[str, Any], str]: + codegen = CutlassEVTCodegen(cuda_template_node_name, removed_buffers) + handler = _AssignmentFormatter(codegen) + + with virtualized.V.set_ops_handler(handler): + for s_node in epilogue_nodes: + node = s_node.node + assert isinstance(node, ComputedBuffer) + with codegen.set_cur_node(node): + index_vars = CutlassEVTCodegen.get_index_vars(node) + node.get_store_function()(index_vars) + + codegen.finalize() + + return ( + codegen.get_reads(), + codegen.get_writes(), + codegen.get_renames(), + codegen.get_value(), + ) + + def get_value(self) -> str: + return linesep.join( + [ + self._render_input_signature(), + self.body.getvalue(), + self._render_return_statement(), + ] + ) + + def finalize(self) -> None: + # Rename the last store to D + # no other code references this store + # to workaround https://github.com/NVIDIA/cutlass/issues/2288 + # Note: the delayed line will automatically rewrite the last assignment to + # be to D + buffer_name = self.var_name_to_buffer_name[self.last_stored_var_name] + self.var_name_to_buffer_name.pop(self.last_stored_var_name) + self.var_name_to_buffer_name["D"] = buffer_name + self.store_name_to_value[buffer_name] = OpsValue("D") + + @contextmanager + def set_cur_node(self, node: ComputedBuffer) -> Generator[None, Any, Any]: + prev_node = self.cur_node + try: + self.cur_node = node + yield + finally: + self.cur_node = prev_node + + def get_renames(self) -> dict[str, str]: + return dict(self.var_name_to_buffer_name) + + def get_reads(self) -> list[str]: + return list(self.reads.difference(self.store_name_to_value.keys())) + + def get_writes(self) -> list[str]: + return list(self.store_name_to_value.keys()) + + def load(self, name: str, index: Any) -> str: + self._check_indexing(name, index) + if name in self.store_name_to_value: + return self.store_name_to_value[name].value + elif name == self.accumulator_node_name: + return _ACCUMULATOR_ARG_NAME + else: + self.reads.add(name) + self.var_name_to_buffer_name[name] = name + return name + + def store( + 
self, name: Any, index: Any = None, value: Any = None, mode: Any = None + ) -> None: + if name not in self.removed_buffers: + if index: + self._check_indexing(name, index) + assert value.value != _ACCUMULATOR_ARG_NAME, ( + "Cannot store accumulator arg name" + ) + self.var_name_to_buffer_name[value.value] = name + self.store_name_to_value[name] = value + self.last_stored_var_name = value.value + return None + + def _get_cur_node(self) -> ComputedBuffer: + assert self.cur_node + return self.cur_node + + @staticmethod + def get_index_vars(node: ComputedBuffer) -> Sequence[sympy.Expr]: + data = node.data + # TODO mlazos: relax this, cutlass supports reductions and other ops + assert isinstance(data, Pointwise) + return data._index(data.ranges) + + def _get_current_index_vars(self) -> Sequence[sympy.Expr]: + return self.get_index_vars(self._get_cur_node()) + + def _check_indexing(self, name: str, index: sympy.Expr) -> None: + # We only support indexing that matches the layout today because + # CUTLASS doesn't support arbitrary indexing + buffer_name = ( + self.accumulator_node_name if name == _ACCUMULATOR_ARG_NAME else name + ) + buffer = self.name_to_buffer[buffer_name] + index_strides = V.graph.sizevars.stride_vars( + index, self._get_current_index_vars() + ) + stride = buffer.get_layout().stride + if not self._stride_compatible(stride, index_strides): + raise NotImplementedError( + f"Unsupported indexing for {name} with index {index}, index strides {index_strides}, and layout stride {stride}" + ) + + def _stride_compatible( + self, left: Iterable[sympy.Expr], right: Iterable[sympy.Expr] + ) -> bool: + return all( + sympy.Eq(l, r) or sympy.Eq(l, 0) or sympy.Eq(r, 0) + for l, r in (zip(left, right)) + ) + + def _render_input_signature(self) -> str: + arguments = ", ".join( + [_ACCUMULATOR_ARG_NAME] + + [name for name in self.reads if name != self.accumulator_node_name] + ) + return f"def fn({arguments}):" + + def _render_return_statement(self) -> str: + return_vars = OrderedSet( + op_v.value for op_v in self.store_name_to_value.values() + ) + assert "D" in return_vars + return f"return {', '.join(return_vars)}" + + def _tmp_var(self) -> str: + return f"tmp_{next(self.var_counter)}" diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_utils.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a340fd59c175b740c64abd1d5e980c9ad7a36138 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/cutlass_utils.py @@ -0,0 +1,497 @@ +# mypy: allow-untyped-defs +import atexit +import functools +import logging +import os +import shutil +import sys +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Optional + +import sympy + +import torch +from torch._inductor.utils import clear_on_fresh_cache + +from ... 
import config +from ...ir import Layout +from ...runtime.runtime_utils import cache_dir +from ...virtualized import V +from ..cpp_utils import DTYPE_TO_CPP +from .cuda_env import get_cuda_arch, get_cuda_version + + +log = logging.getLogger(__name__) + +CUTLASS_OPERATION_KIND: str = "gemm" + + +@atexit.register +def move_cutlass_compiled_cache() -> None: + """Move CUTLASS compiled cache file to the cache directory if it exists.""" + if "cutlass" not in sys.modules: + return + + import cutlass # type: ignore[import-not-found] + + if not os.path.exists(cutlass.CACHE_FILE): + return + + try: + filename = os.path.basename(cutlass.CACHE_FILE) + shutil.move(cutlass.CACHE_FILE, os.path.join(cache_dir(), filename)) + log.debug("Moved CUTLASS compiled cache file to %s", cache_dir()) + except OSError as e: + log.warning("Failed to move CUTLASS compiled cache file: %s", str(e)) + + +def _rename_cutlass_import(content: str, cutlass_modules: list[str]) -> str: + for cutlass_module in cutlass_modules: + content = content.replace( + f"from {cutlass_module} import ", + f"from cutlass_library.{cutlass_module} import ", + ) + return content + + +@functools.cache +def try_import_cutlass() -> bool: + """ + We want to support three ways of passing in CUTLASS: + 1. fbcode, handled by the internal build system. + 2. pip install nvidia-cutlass, which provides the cutlass_library package + and the header files in the cutlass_library/source directory. + 3. User specifies cutlass_dir. The default is ../third_party/cutlass/, + which is the directory when developers build from source. + """ + if config.is_fbcode(): + try: + import cutlass # type: ignore[import-not-found] + import cutlass_library # type: ignore[import-not-found] + except ImportError as e: + log.warning( + "Failed to import CUTLASS packages in fbcode: %s, ignoring the CUTLASS backend.", + str(e), + ) + return False + + return True + + try: + import cutlass # type: ignore[import-not-found] # noqa: F811 + import cutlass_library # type: ignore[import-not-found] # noqa: F811 + + cutlass_minor_vesion = int(cutlass.__version__.split(".")[1]) + if cutlass_minor_vesion < 7: + log.warning("CUTLASS version < 3.7 is not recommended.") + + log.debug( + "Found cutlass_library in python search path, overriding config.cuda.cutlass_dir" + ) + cutlass_library_dir = os.path.dirname(cutlass_library.__file__) + assert os.path.isdir(cutlass_library_dir), ( + f"{cutlass_library_dir} is not a directory" + ) + config.cuda.cutlass_dir = os.path.abspath( + os.path.join( + cutlass_library_dir, + "source", + ) + ) + + return True + except ModuleNotFoundError: + log.debug( + "cutlass_library not found in sys.path, trying to import from config.cuda.cutlass_dir" + ) + + # Copy CUTLASS python scripts to a temp dir and add the temp dir to Python search path. + # This is a temporary hack to avoid CUTLASS module naming conflicts. + # TODO(ipiszy): remove this hack when CUTLASS solves Python scripts packaging structure issues. 
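As a rough illustration of the import fallback described above, a caller can gate CUTLASS-backend work on `try_import_cutlass()` and, when the `nvidia-cutlass` pip package is not installed, point `config.cuda.cutlass_dir` at a source checkout beforehand. This is a minimal sketch and not part of the diff; the checkout path is a hypothetical placeholder.

```python
# Minimal sketch (not part of this diff): gating CUTLASS-backend work on
# try_import_cutlass(). The cutlass_dir path below is a hypothetical example
# and is only needed when the nvidia-cutlass pip package is not installed.
from torch._inductor import config
from torch._inductor.codegen.cuda import cutlass_utils

config.cuda.cutlass_dir = "/path/to/cutlass"  # hypothetical source checkout
if cutlass_utils.try_import_cutlass():
    ops = cutlass_utils.gen_ops()  # CUTLASS GEMM operations, keyed by operation kind
else:
    print("CUTLASS backend unavailable; other backends will be used")
```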
+ + # TODO(mlazos): epilogue visitor tree currently lives in python/cutlass, + # but will be moved to python/cutlass_library in the future (later 2025) + def path_join(path0, path1): + return os.path.abspath(os.path.join(path0, path1)) + + # contains both cutlass and cutlass_library + # we need cutlass for eVT + cutlass_python_path = path_join(config.cuda.cutlass_dir, "python") + torch_root = os.path.abspath(os.path.dirname(torch.__file__)) + mock_src_path = os.path.join( + torch_root, + "_inductor", + "codegen", + "cuda", + "cutlass_lib_extensions", + "cutlass_mock_imports", + ) + + cutlass_library_src_path = path_join(cutlass_python_path, "cutlass_library") + cutlass_src_path = path_join(cutlass_python_path, "cutlass") + pycute_src_path = path_join(cutlass_python_path, "pycute") + + tmp_cutlass_full_path = os.path.abspath(os.path.join(cache_dir(), "torch_cutlass")) + + dst_link_library = path_join(tmp_cutlass_full_path, "cutlass_library") + dst_link_cutlass = path_join(tmp_cutlass_full_path, "cutlass") + dst_link_pycute = path_join(tmp_cutlass_full_path, "pycute") + + # mock modules to import cutlass + mock_modules = ["cuda", "scipy", "pydot"] + + if os.path.isdir(cutlass_python_path): + if tmp_cutlass_full_path not in sys.path: + + def link_and_append(dst_link, src_path, parent_dir): + if os.path.exists(dst_link): + assert os.path.islink(dst_link), ( + f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again." + ) + assert os.path.realpath(os.readlink(dst_link)) == os.path.realpath( + src_path, + ), f"Symlink at {dst_link} does not point to {src_path}" + else: + os.makedirs(parent_dir, exist_ok=True) + os.symlink(src_path, dst_link) + + if parent_dir not in sys.path: + sys.path.append(parent_dir) + + link_and_append( + dst_link_library, cutlass_library_src_path, tmp_cutlass_full_path + ) + link_and_append(dst_link_cutlass, cutlass_src_path, tmp_cutlass_full_path) + link_and_append(dst_link_pycute, pycute_src_path, tmp_cutlass_full_path) + + for module in mock_modules: + link_and_append( + path_join(tmp_cutlass_full_path, module), # dst_link + path_join(mock_src_path, module), # src_path + tmp_cutlass_full_path, # parent + ) + + try: + import cutlass # noqa: F401 + import cutlass_library.generator # noqa: F401 + import cutlass_library.library # noqa: F401 + import cutlass_library.manifest # noqa: F401 + import pycute # type: ignore[import-not-found] # noqa: F401 + + return True + except ImportError as e: + log.debug( + "Failed to import CUTLASS packages: %s, ignoring the CUTLASS backend.", + str(e), + ) + else: + log.debug( + "Failed to import CUTLASS packages: CUTLASS repo does not exist: %s", + cutlass_python_path, + ) + return False + + +@functools.lru_cache(8) +def _normalize_cuda_arch(arch: str) -> str: + if int(arch) >= 100: + log.warning( + "Detected CUDA architecture >= 100: %s. We will generate operations with " + "GenerateSM100 (if available) and GenerateSM90. Please file an " + "issue for any problems and feedback. ", + arch, + ) + + if int(arch) >= 100: + return "100" + elif int(arch) >= 90: + return "90" + elif int(arch) >= 80: + return "80" + elif int(arch) >= 75: + return "75" + elif int(arch) >= 70: + return "70" + else: + raise NotImplementedError(f"Unsupported cuda arch: {arch}") + + +@dataclass +class CUTLASSArgs: + """ + CUTLASS args used to initialize a CUTLASS Manifest. 
+ """ + + architectures: Optional[str] = None + cuda_version: Optional[str] = None + instantiation_level: Optional[str] = None + operations: Optional[str] = None + + build_dir = "" + curr_build_dir = "" + generator_target = "" + kernels = "all" + ignore_kernels = "" + exclude_kernels = "" + # TODO: these three look dead? + kernel_filter_file: None = None + selected_kernel_list: None = None + interface_dir: None = None + filter_by_cc = True + disable_full_archs_compilation = False + + def __post_init__(self): + if self.architectures is None or self.cuda_version is None: + raise RuntimeError( + f"{self.architectures=} or {self.cuda_version=} is None!" + ) + self.architectures = _normalize_cuda_arch(self.architectures) + + +@clear_on_fresh_cache +@functools.cache +def _gen_ops_cached(arch, version) -> dict[Any, Any]: + # Note: Cache needs to be specific for cuda architecture and version + + # Import cutlass python scripts. + assert try_import_cutlass() + import cutlass_library.generator as cutlass_generator + import cutlass_library.manifest as cutlass_manifest + + if arch is None or version is None: + log.error( + "Cannot detect cuda arch %s or cuda version %s. " + "Will discard all cutlass ops. " + "Please consider setting _inductor.cuda.arch and _inductor.cuda.version configs.", + arch, + version, + ) + return {} + arch = _normalize_cuda_arch(arch) + instantiation_level: str = config.cuda.cutlass_instantiation_level + args = CUTLASSArgs( + architectures=arch, + cuda_version=version, + instantiation_level=instantiation_level, + operations=CUTLASS_OPERATION_KIND, + ) + manifest = cutlass_manifest.Manifest(args) + + start_time = time.time() + if arch == "100": + if hasattr(cutlass_generator, "GenerateSM100"): + cutlass_generator.GenerateSM100(manifest, args.cuda_version) + cutlass_generator.GenerateSM90(manifest, args.cuda_version) + else: + try: + func = getattr(cutlass_generator, "GenerateSM" + arch) + func(manifest, args.cuda_version) + except AttributeError as e: + raise NotImplementedError( + "Arch " + arch + " is not supported by current cutlass lib." + ) from e + + log.info( + "CUTLASS library generated a dict of %d operation kinds in %.2f seconds", + len(manifest.operations), + time.time() - start_time, + ) + return manifest.operations + + +def gen_ops() -> dict[Any, Any]: + """ + Generates all supported CUTLASS operations. + """ + arch = get_cuda_arch() + version = get_cuda_version() + return _gen_ops_cached(arch, version) + + +DTYPE_TO_CUTLASS_TYPE = { + **DTYPE_TO_CPP, + torch.float16: "__half", + torch.bfloat16: "__nv_bfloat16", + torch.float8_e4m3fn: "__nv_fp8_e4m3", +} + + +@functools.lru_cache(32) +def torch_dtype_to_cutlass_type( + torch_dtype: torch.dtype, +) -> "cutlass_library.library.DataType": # type: ignore[name-defined] # noqa: F821 + # Import cutlass python scripts. + assert try_import_cutlass() + import cutlass_library # type: ignore[import] + + if torch_dtype == torch.float: + return cutlass_library.library.DataType.f32 + elif torch_dtype == torch.half: + return cutlass_library.library.DataType.f16 + elif torch_dtype == torch.bfloat16: + return cutlass_library.library.DataType.bf16 + else: + raise NotImplementedError(f"Unsupported data type: {torch_dtype=}") + + +@functools.lru_cache(32) +def dtype_match( + torch_dtype: Optional[torch.dtype], + cutlass_dtype: "cutlass_library.library.DataType", # type: ignore[name-defined] # noqa: F821 +) -> bool: + # Import cutlass python scripts. 
+ assert try_import_cutlass() + import cutlass_library + + if torch_dtype == torch.float: + return ( + cutlass_dtype == cutlass_library.library.DataType.f32 + or cutlass_dtype == cutlass_library.library.DataType.tf32 + ) + elif torch_dtype == torch.half: + return cutlass_dtype == cutlass_library.library.DataType.f16 + elif torch_dtype == torch.bfloat16: + return cutlass_dtype == cutlass_library.library.DataType.bf16 + elif torch_dtype == torch.int8: + return cutlass_dtype == cutlass_library.library.DataType.s8 + elif torch_dtype == torch.uint8: + return cutlass_dtype == cutlass_library.library.DataType.u8 + elif torch_dtype == torch.int32: + return cutlass_dtype == cutlass_library.library.DataType.s32 + elif torch_dtype == torch.float8_e4m3fn: + return cutlass_dtype == cutlass_library.library.DataType.e4m3 + else: + return False + + +def get_accumulator_dtype( + input_torch_dtypes: list[torch.dtype], +) -> Optional[torch.dtype]: + """ + Given a pair of input torch dtypes, returns the inferred accumulator torch dtype. + """ + + if len(input_torch_dtypes) != 2: + return None + + torch_dtype = None + if input_torch_dtypes[0] == input_torch_dtypes[1]: + torch_dtype = input_torch_dtypes[0] + else: + size0 = torch.tensor([], dtype=input_torch_dtypes[0]).element_size() + size1 = torch.tensor([], dtype=input_torch_dtypes[1]).element_size() + if size0 > size1: + dtype0, dtype1 = input_torch_dtypes + else: + dtype1, dtype0 = input_torch_dtypes + if dtype0 in [torch.half, torch.bfloat16] and dtype1 in [ + torch.int8, + torch.uint8, + ]: + torch_dtype = dtype0 + + if torch_dtype in (torch.float16, torch.bfloat16, torch.float, torch.float8_e4m3fn): + return torch.float + if torch_dtype == torch.int8: + return torch.int32 + raise NotImplementedError(f"Unsupported data types: {input_torch_dtypes=}") + + +@functools.lru_cache(32) +def get_alignments(torch_dtype: torch.dtype) -> list[int]: + """ + Returns all possible valid CUTLASS alignments in terms of the number of elements for a given dtype. + CUTLASS gemm / conv SM80 APIs support 16 bytes max alignment, and 2 bytes min alignment. + """ + + if torch_dtype in (torch.half, torch.bfloat16): + return [8, 4, 2, 1] + elif torch_dtype == torch.float: + return [4, 2, 1] + elif torch_dtype in (torch.uint8, torch.int8, torch.float8_e4m3fn): + return [16, 8, 4, 2] + elif torch_dtype == torch.int32: + return [4, 2, 1] + else: + raise NotImplementedError(f"unsupported {torch_dtype=} for alignments") + + +def get_max_alignment(inductor_layout: Layout) -> int: + """ + Returns the max alignment (in terms of number of elements) for a given Inductor Layout. 
+ """ + + dtype = inductor_layout.dtype + size = inductor_layout.size + offset = inductor_layout.offset + + def is_static_int(number): + return isinstance(number, (int, sympy.Integer)) + + def a_factor_of(x, alignment): + if is_static_int(x) and is_static_int(alignment): + return x % alignment == 0 + rem = sympy.Mod(x, alignment) + return V.graph.sizevars.evaluate_expr(sympy.Eq(rem, 0)) + + try: + contiguous_dim = inductor_layout.stride.index(1) + except ValueError: + # No dim with stride 1 found, return 1 + return 1 + alignments = get_alignments(dtype) + for alignment in alignments: + if not a_factor_of(size[contiguous_dim], alignment) or not a_factor_of( + offset, alignment + ): + continue + if all( + (dim == contiguous_dim) + or a_factor_of(inductor_layout.stride[dim], alignment) + for dim in range(len(size)) + ): + return alignment + return 1 + + +class CUDACompileSourceCapturingContext: + # Helper class for Benchmarking and Testing CUTLASS Kernels in isolation. + # Can be used to capture the sourcecode passed to CUDACodeCache.compile + + def __init__(self): + self.sources = [] + self._compile_patch = None + + def __enter__(self, *args, **kwargs): + import unittest.mock as mock + + import torch._inductor.codecache + + _compile_method_orig = torch._inductor.codecache.CUDACodeCache.compile + + def my_compile(source_code, dst_file_ext): + self.sources.append(source_code) + return _compile_method_orig(source_code, dst_file_ext) + + self._compile_patch = mock.patch( + "torch._inductor.codecache.CUDACodeCache.compile", my_compile + ) + self._compile_patch.__enter__(*args, **kwargs) # type: ignore[union-attr] + return self + + def __exit__(self, *args, **kwargs): + self._compile_patch.__exit__(*args, **kwargs) # type: ignore[union-attr] + + +def cuda_standalone_runner_compile_command(srcpath: Path, exepath: Path): + # returns command string to compile a (captured) CUDA GEMM Kernel source to a standalone executable that's ready to run + # Passes the correct preprocessor define to nvcc to ensure the standalone runner is enabled. 
+ from torch._inductor.codecache import cuda_compile_command + + extra_args = ["-DGENERATE_STANDALONE_RUNNER=1", "-DCUTLASS_DEBUG_TRACE_LEVEL=1"] + compile_command = cuda_compile_command( + [str(srcpath)], str(exepath), "exe", extra_args=extra_args + ) + return compile_command diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/device_op_overrides.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/device_op_overrides.py new file mode 100644 index 0000000000000000000000000000000000000000..0ba0677422944abeebda6fca44e78c4826f074c4 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/device_op_overrides.py @@ -0,0 +1,366 @@ +from __future__ import annotations + +from typing import Optional + +import torch + +from ...utils import triton_version_uses_attrs_dict +from ..common import ( + DeviceOpOverrides, + register_device_op_overrides, + TritonScratchWorkspace, +) + + +class CUDADeviceOpOverrides(DeviceOpOverrides): + """ + CUDA-specific codegen functions, see DeviceOpOverrides for details + """ + + def import_get_raw_stream_as(self, name: str) -> str: + return f"from torch._C import _cuda_getCurrentRawStream as {name}" + + def set_device(self, device_idx: int) -> str: + return f"torch.cuda.set_device({device_idx})" + + def synchronize(self) -> str: + return "torch.cuda.synchronize()" + + def device_guard(self, device_idx: int) -> str: + return f"torch.cuda._DeviceGuard({device_idx})" + + def cpp_device_guard(self) -> str: + return "at::cuda::CUDAGuard" + + def cpp_aoti_device_guard(self) -> str: + return "AOTICudaGuard" + + def cpp_stream_guard(self) -> str: + return "at::cuda::CUDAStreamGuard" + + def cpp_aoti_stream_guard(self) -> str: + return "AOTICudaStreamGuard" + + def cpp_getStreamFromExternal(self) -> str: + return "at::cuda::getStreamFromExternal" + + def kernel_header(self) -> str: + source_codes = """ + #include + #include + #include + """ + return source_codes + + def kernel_driver(self) -> str: + source_codes = """ + #define CUDA_DRIVER_CHECK(EXPR) \\ + do { \\ + CUresult code = EXPR; \\ + const char *msg; \\ + CUresult code_get_error = cuGetErrorString(code, &msg); \\ + if (code_get_error != CUDA_SUCCESS) { \\ + throw std::runtime_error( \\ + std::string("CUDA driver error: ") + \\ + std::string("invalid error code!")); \\ + } \\ + if (code != CUDA_SUCCESS) { \\ + throw std::runtime_error( \\ + std::string("CUDA driver error: ") + \\ + std::string(msg)); \\ + } \\ + } while (0); + + static inline CUfunction loadKernel( + std::string filePath, + const std::string &funcName, + uint32_t sharedMemBytes, + const std::optional &cubinDir = std::nullopt) { + if (cubinDir) { + std::filesystem::path p1{*cubinDir}; + std::filesystem::path p2{filePath}; + filePath = (p1 / p2.filename()).string(); + } + + CUmodule mod; + CUfunction func; + CUDA_DRIVER_CHECK(cuModuleLoad(&mod, filePath.c_str())); + CUDA_DRIVER_CHECK(cuModuleGetFunction(&func, mod, funcName.c_str())); + if (sharedMemBytes > 0) { + CUDA_DRIVER_CHECK(cuFuncSetAttribute( + func, + CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + sharedMemBytes + )) + } + return func; + } + + static inline CUfunction loadKernel(const void* start, const std::string &funcName, uint32_t sharedMemBytes) { + CUmodule mod; + CUfunction func; + CUDA_DRIVER_CHECK(cuModuleLoadData(&mod, start)); + CUDA_DRIVER_CHECK(cuModuleGetFunction(&func, mod, funcName.c_str())); + if (sharedMemBytes > 0) { + CUDA_DRIVER_CHECK(cuFuncSetAttribute( + func, + 
CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + sharedMemBytes + )) + } + return func; + } + + static inline void launchKernel( + CUfunction func, + uint32_t gridX, + uint32_t gridY, + uint32_t gridZ, + uint32_t numWarps, + uint32_t sharedMemBytes, + void* args[], + cudaStream_t stream) { + CUDA_DRIVER_CHECK(cuLaunchKernel( + func, gridX, gridY, gridZ, 32*numWarps, 1, 1, sharedMemBytes, stream, args, nullptr + )); + } + """ + if torch.version.hip is not None: + # Adjusting the warp size to GPU supported wavefront size on AMD GPU + prop = torch.cuda.get_device_properties(torch.cuda.current_device()) + source_codes = source_codes.replace( + "32*numWarps", str(prop.warp_size) + "*numWarps" + ) + return source_codes + + def tma_descriptor_helpers(self) -> str: + """ + CUDA helper functions for initializing TMA Descriptors on host side + """ + if torch.version.hip is not None: + raise RuntimeError("Host-side TMA descriptors not supported on HIP.") + + # helper functions for initializing 1D and 2D TMA descriptors in C++. borrowed from the Triton code here: + # Old APIs (fill(1|2)DTMADescriptor): + # https://github.com/triton-lang/triton/blob/6af4f88591c85de079d8a36a4d7dba67918e2b39/third_party/nvidia/backend/driver.c#L283 + # New APIs (fillTMADescriptor): + # https://github.com/triton-lang/triton/blob/main/third_party/nvidia/backend/driver.c#L283 + return """ + #if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12000 + [[maybe_unused]] static void init1DTMADescriptor( + CUtensorMap* m, + void* globalAddress, + uint64_t dim, + uint32_t blockDim, + uint32_t elementSize) { + uint64_t dims[1] = {dim}; + uint64_t globalStrides[1] = {dim * elementSize}; + uint32_t tensorDims[1] = {blockDim}; + uint32_t elementStrides[1] = {1}; + + CUtensorMapDataType type; + switch (elementSize) { + case 1: + type = CU_TENSOR_MAP_DATA_TYPE_UINT8; + break; + case 2: + type = CU_TENSOR_MAP_DATA_TYPE_UINT16; + break; + case 4: + type = CU_TENSOR_MAP_DATA_TYPE_UINT32; + break; + default: + throw std::runtime_error("elementSize must be 1, 2, or 4"); + } + + if (elementSize * blockDim < 32) { + throw std::runtime_error("block size too small"); + } + + int rank = 1; + + CUDA_DRIVER_CHECK(cuTensorMapEncodeTiled( + m, type, rank, globalAddress, dims, + globalStrides, tensorDims, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, + CU_TENSOR_MAP_SWIZZLE_NONE, CU_TENSOR_MAP_L2_PROMOTION_NONE, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); + } + + [[maybe_unused]] static void init2DTMADescriptor( + CUtensorMap* m, + void* globalAddress, + uint64_t dim1, + uint64_t dim0, + uint32_t blockDim1, + uint32_t blockDim0, + uint32_t elementSize) { + uint64_t dims[2] = {dim0, dim1}; + uint32_t tensorDims[2] = {blockDim0, blockDim1}; + uint64_t globalStrides[2] = {dims[0] * elementSize, + dims[0] * dims[1] * elementSize}; + uint32_t elementStrides[2] = {1, 1}; + + CUtensorMapDataType type; + switch (elementSize) { + case 1: + type = CU_TENSOR_MAP_DATA_TYPE_UINT8; + break; + case 2: + type = CU_TENSOR_MAP_DATA_TYPE_UINT16; + break; + case 4: + type = CU_TENSOR_MAP_DATA_TYPE_UINT32; + break; + default: + throw std::runtime_error("elementSize must be 1, 2, or 4"); + } + + int rank = 2; + + CUtensorMapSwizzle swizzle = CU_TENSOR_MAP_SWIZZLE_128B; + uint32_t contigDimSizeInByte = elementSize * tensorDims[0]; + if (contigDimSizeInByte >= 128) { + swizzle = CU_TENSOR_MAP_SWIZZLE_128B; + } else if (contigDimSizeInByte >= 64) { + swizzle = CU_TENSOR_MAP_SWIZZLE_64B; + } else if (contigDimSizeInByte >= 32) { + swizzle = 
CU_TENSOR_MAP_SWIZZLE_32B; + } else { + throw std::runtime_error("block size too small"); + } + + if (contigDimSizeInByte > 128) { + tensorDims[0] = 128 / elementSize; + } + + CUDA_DRIVER_CHECK(cuTensorMapEncodeTiled( + m, type, rank, globalAddress, dims, + globalStrides, tensorDims, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, + swizzle, CU_TENSOR_MAP_L2_PROMOTION_L2_128B, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); + } + + [[maybe_unused]] static void initTMADescriptor( + CUtensorMap* m, + void* globalAddress, + int elemSize, + int rank, + uint32_t* blockSize, + uint64_t* shape, + uint64_t* stride + ) { + uint32_t elementStrides[5] = {1, 1, 1, 1, 1}; + uint32_t blockSizeInt[5]; + uint64_t shapeInt[5]; + uint64_t stridesLL[5]; + + // Reorder blockSize (reverse the order) + for (int i = 0; i < rank; ++i) { + blockSizeInt[rank - i - 1] = blockSize[i]; + } + + // Reorder shape (reverse the order) + for (int i = 0; i < rank; ++i) { + shapeInt[rank - i - 1] = shape[i]; + } + + // Reorder and calculate strides + for (int i = 0; i + 1 < rank; ++i) { + stridesLL[rank - i - 2] = elemSize * stride[i]; + } + stridesLL[rank - 1] = + shapeInt[rank - 1] * (rank == 1 ? elemSize : stridesLL[rank - 2]); + + CUtensorMapDataType type; + // In Triton this is computed ahead of time; but for simplicity + // in the PyTorch version we copied this code from the old + // TMA API handling (i.e. init2DTMADescriptor) + switch (elemSize) { + case 1: + type = CU_TENSOR_MAP_DATA_TYPE_UINT8; + break; + case 2: + type = CU_TENSOR_MAP_DATA_TYPE_UINT16; + break; + case 4: + type = CU_TENSOR_MAP_DATA_TYPE_UINT32; + break; + default: + throw std::runtime_error("elemSize must be 1, 2, or 4"); + } + + // Calculate the size of the most contiguous dimension in bytes + CUtensorMapSwizzle swizzle = CU_TENSOR_MAP_SWIZZLE_128B; + uint32_t contigDimSizeInByte = elemSize * blockSizeInt[0]; + if (rank == 1) { + // rank 1 should not be swizzled + swizzle = CU_TENSOR_MAP_SWIZZLE_NONE; + } else if (contigDimSizeInByte >= 128) { + swizzle = CU_TENSOR_MAP_SWIZZLE_128B; + } else if (contigDimSizeInByte >= 64) { + swizzle = CU_TENSOR_MAP_SWIZZLE_64B; + } else if (contigDimSizeInByte >= 32) { + swizzle = CU_TENSOR_MAP_SWIZZLE_32B; + } else { + throw std::runtime_error("block size too small"); + } + + CUDA_DRIVER_CHECK(cuTensorMapEncodeTiled( + m, type, rank, globalAddress, + shapeInt, stridesLL, blockSizeInt, elementStrides, + CU_TENSOR_MAP_INTERLEAVE_NONE, (CUtensorMapSwizzle)swizzle, + CU_TENSOR_MAP_L2_PROMOTION_L2_128B, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); + } + + struct StableTMADescriptor { + CUtensorMap m; + uint32_t block_shape[5]; + uint64_t global_shape[5]; + uint64_t strides[5]; + }; + #endif + """ + + def cpp_stream_type(self) -> str: + return "cudaStream_t" + + def aoti_get_stream(self) -> str: + return "aoti_torch_get_current_cuda_stream" + + def cpp_kernel_type(self) -> str: + return "CUfunction" + + def cpp_device_ptr(self) -> str: + return "CUdeviceptr" + + def cpp_global_scratch( + self, idx: int, workspace: TritonScratchWorkspace + ) -> Optional[tuple[list[str], str]]: + if triton_version_uses_attrs_dict(): + var_name = f"global_scratch_{idx}" + if workspace.size > 0: + size_array = f"int64_t {var_name}_size[] = {{{workspace.size}}};" + stride_array = f"int64_t {var_name}_stride[] = {{1}};" + device_type = "cached_torch_device_type_cuda" + device_idx = "device_idx_" + + return ( + [ + f"{size_array}", + f"{stride_array}", + f"AtenTensorHandle {var_name}_handle;", + ( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(1, 
{var_name}_size, {var_name}_stride, " + f"{workspace.generate_dtype_str()}, {device_type}, {device_idx}, &{var_name}_handle));" + ), + f"RAIIAtenTensorHandle {var_name}_tensor({var_name}_handle);", + f"CUdeviceptr {var_name} = reinterpret_cast({var_name}_tensor.data_ptr());", + ], + var_name, + ) + else: + return [f"CUdeviceptr {var_name} = 0;"], var_name + return None + + +register_device_op_overrides("cuda", CUDADeviceOpOverrides()) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/gemm_template.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/gemm_template.py new file mode 100644 index 0000000000000000000000000000000000000000..a1ce96192a56d3a6cd3761ce2c530e0b8f438b2f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/gemm_template.py @@ -0,0 +1,1905 @@ +# mypy: allow-untyped-defs +import copy +import enum +import functools +import logging +import re +import time +from abc import ABC, abstractmethod +from typing import Any, Optional, Union + +import torch +import torch.utils._pytree as pytree +from torch._inductor.codegen.cuda.cutlass_cache import maybe_fetch_ops +from torch._inductor.scheduler import BaseSchedulerNode +from torch._inductor.select_algorithm import create_inputs_key +from torch._inductor.utils import clear_on_fresh_cache + +from ... import ir +from ...config import cuda as inductor_cuda_config +from ...ir import ( + Buffer, + ChoiceCaller, + CUDATemplateBuffer, + FixedLayout, + IRNode, + Layout, + ReinterpretView, +) +from ...utils import is_dynamic, Placeholder +from ...virtualized import V +from ..common import IndentedBuffer +from . import cutlass_utils +from .cuda_kernel import CUDATemplateKernel +from .cuda_template import CUTLASSTemplate +from .cutlass_presets import gen_cutlass_presets +from .cutlass_python_evt import CutlassEVTCodegen, scaled_mm_evt +from .cutlass_utils import torch_dtype_to_cutlass_type + + +GemmOperation = Any + +log = logging.getLogger(__name__) + +# Jinja template for GEMM Kernel, used by the CUTLASSGemm3xTemplate class below. +GEMM_TEMPLATE_CUTLASS_3X = r""" +{{template.header().getvalue()}} +{{template.globals().getvalue()}} +{{epilogue_visitor_tree}} +{{instance_definition}} +// When workspace_size is not a nullptr, populates requested workspace_size and returns. +// Otherwise, computes the Gemm kernel using the given workspace ptr. 
+extern "C" { +PT_EXPORT {{kernel_call_signature}} { + try { + using ElementComputeEpilogue = {{instance_type}}::ElementAccumulator; + using coord_t = cutlass::gemm::GemmCoord::Index; + static cutlass::KernelHardwareInfo hw_info; + if (hw_info.sm_count == 0) { + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0); + CUTLASS_TRACE_HOST("Query result for SM count per device: " << hw_info.sm_count); + } + {{instance_type}}::Arguments arguments; + {{template.render_gemm_arguments(argument_template, epilogue_template, should_swap_xw, + X, W, Bias, Y, alpha, beta, kernel, epilogue_args)}} + {{instance_type}} gemm_op; + if (workspace_size) { + *workspace_size = gemm_op.get_workspace_size(arguments); + return 0; + } + // check for null pointers after workspace size, since querying workspace size doesn't require valid data pointers +#ifndef CUTLASS_BACKEND_DISABLE_CHECKS + { + auto status = gemm_op.can_implement(arguments); + CUTLASS_CHECK(status); + } +#endif +#ifdef CUTLASS_DEBUG_TRACE_LEVEL +#if CUTLASS_DEBUG_TRACE_LEVEL == 1 + { + // Print the maximum number of active blocks per SM for the kernel if CUTLASS_DEBUG_TRACE_LEVEL == 1 + // we don't need a print statement, it's happening inside the function. + gemm_op.maximum_active_blocks(); + } +#endif +#endif + { + auto status = gemm_op.initialize(arguments, workspace, stream); + CUTLASS_CHECK(status); + } + { + auto status = gemm_op(stream); + CUTLASS_CHECK(status); + } + } + catch (std::exception& e) { + std::cerr << "Runtime error: " << e.what() << std::endl; + return -1; + } + catch (...) { + return -1; + } + return 0; +} +} + +// configuration name: {{op_conf_name}} +""" + +# Jinja template for Cutlass 3.x GEMM Kernel arguments, used by the CUTLASSGemmTemplate class below. +GEMM_ARGS_CUTLASS_3X = r""" + // Initialize GemmUniversal3xInstance arguments. + arguments = { + {{template.gemm_mode()}}, // GemmUniversalMode mode + { + static_cast({{M}}), + static_cast({{N}}), + static_cast(K), + static_cast(B) + }, // ProblemShape problem_shape + { + {{template.cutlass_type_cast(X, kernel.ptr(X))}}, // ElementA const* ptr_A + { + {{template.cute_int(kernel.stride(X, -2), "stride_x0")}}, + {{template.cute_int(kernel.stride(X, -1), "stride_x1")}}, + {{template.cute_int(kernel.batch_stride(X), "batch_stride_x")}} + }, // StrideA dA + {{template.cutlass_type_cast(W, kernel.ptr(W))}}, // ElementB const* ptr_B + { + {{template.cute_int(kernel.stride(W, -1), "stride_w1")}}, + {{template.cute_int(kernel.stride(W, -2), "stride_w0")}}, + {{template.cute_int(kernel.batch_stride(W), "batch_stride_w")}} + }, // StrideB dB + }, // MainloopArguments mainloop + {{epilogue_arguments}}, + hw_info + }; + arguments.scheduler.max_swizzle_size = swizzle; +""" + +# Jinja template for Cutlass 3.x GEMM Kernel arguments if epilogue fusion is applied, +# used by the CUTLASSGemmTemplate class below. 
+GEMM_ARGS_CUTLASS_3X_EPILOGUE = r""" + // see https://tinyurl.com/4rk89z48 + { + {{epilogue_args}}, // thread, typename FusionCallbacks::Arguments ( EVT ) or ThreadEpilogueOp::Params (non-EVT ) + {{template.cutlass_type_cast(Bias, kernel.ptr(Bias))}}, // ElementC const* ptr_C + { + {{template.cute_int(kernel.stride(Bias, -2, 1), "stride_bias0")}}, + {{template.cute_int(kernel.stride(Bias, -1, 1), "stride_bias1")}}, + {{template.cute_int(kernel.batch_stride(Bias), "batch_stride_bias")}} + }, // StrideC dC + {{template.cutlass_type_cast(Y, kernel.ptr(Y))}}, // ElementD const* ptr_D + { + {{template.cute_int(kernel.stride(Y, -2), "stride_y0")}}, + {{template.cute_int(kernel.stride(Y, -1), "stride_y1")}}, + {{template.cute_int(kernel.batch_stride(Y), "batch_stride_y")}} + }, // StrideD dD + }, // EpilogueArguments epilogue +""" + +# Jinja template for GEMM Kernel, used by the CUTLASS2xGemmTemplate class below. +GEMM_TEMPLATE_CUTLASS_2X = r""" +{{template.header().getvalue()}} +{{template.globals().getvalue()}} +{{instance_definition}} +// When workspace_size is not a nullptr, populates requested workspace_size and returns. +// Otherwise, computes the Gemm kernel using the given workspace ptr. +extern "C" { +PT_EXPORT {{kernel_call_signature}} { + try { + int B = {{kernel.size(Y, 0, -3, default_value=1)}}; + using ElementComputeEpilogue = {{instance_type}}::ElementAccumulator; + using coord_t = cutlass::gemm::GemmCoord::Index; + static cutlass::KernelHardwareInfo hw_info; + if (hw_info.sm_count == 0) { + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0); + CUTLASS_TRACE_HOST("Query result for SM count per device: " << hw_info.sm_count); + } + {{instance_type}}::Arguments arguments; + {{template.render_gemm_arguments(instance_type, argument_template, epilogue_template, should_swap_xw, + X, W, Bias, Meta, Y, alpha, beta, kernel, epilogue_args)}} + {{instance_type}} gemm_op; + if (workspace_size) { + *workspace_size = gemm_op.get_workspace_size(arguments); + return 0; + } + + // check for null pointers after workspace size, since querying workspace size doesn't require valid data pointers +#ifndef CUTLASS_BACKEND_DISABLE_CHECKS + { + auto status = gemm_op.can_implement(arguments); + CUTLASS_CHECK(status); + } +#endif +#ifdef CUTLASS_DEBUG_TRACE_LEVEL +#if CUTLASS_DEBUG_TRACE_LEVEL == 1 + { + // Print the maximum number of active blocks per SM for the kernel if CUTLASS_DEBUG_TRACE_LEVEL == 1 + // we don't need a print statement, it's happening inside the function. + gemm_op.maximum_active_blocks(); + } +#endif +#endif + + { + auto status = gemm_op.initialize(arguments, workspace, stream); + CUTLASS_CHECK(status); + } + { + auto status = gemm_op(stream); + CUTLASS_CHECK(status); + } + } + catch (std::exception& e) { + std::cerr << "Runtime error: " << e.what() << std::endl; + return -1; + } + catch (...) { + return -1; + } + return 0; +} +} +""" + +# Jinja template for Cutlass 2.x GEMM Kernel arguments, used by the CUTLASS2xGemmTemplate class below. 
+GEMM_ARGS_CUTLASS_2X = r""" + int64_t batch_stride_x = {{kernel.stride(X, -3)}}; + int64_t row_stride_x = {{kernel.row_or_column_stride(X)}}; + int64_t batch_stride_w = {{kernel.stride(W, -3)}}; + int64_t row_stride_w = {{kernel.row_or_column_stride(W)}}; + int64_t batch_stride_bias = {{kernel.stride(Bias, -3)}}; + int64_t row_stride_bias = {{kernel.row_or_column_stride(Bias)}}; + int64_t batch_stride_y = {{kernel.stride(Y, -3)}}; + int64_t row_stride_y = {{kernel.row_or_column_stride(Y)}}; + // Initialize GemmUniversalInstance arguments. + arguments = { + {{template.gemm_mode()}}, // GemmUniversalMode mode + { + static_cast(M), + static_cast(N), + static_cast(K) + }, // GemmCoord problem_size + {{split_k if split_k > 1 else 'B'}}, // int batch_count + {ElementComputeEpilogue({{alpha}}), ElementComputeEpilogue({{beta}})}, // typename EpilogueOutputOp::Params epilogue + {{template.cutlass_type_cast(X, kernel.ptr(X))}}, // void const * ptr_A + {{template.cutlass_type_cast(W, kernel.ptr(W))}}, // void const * ptr_B + {{template.cutlass_type_cast(Bias, kernel.ptr(Bias))}}, // void const * ptr_C + {{template.cutlass_type_cast(Y, kernel.ptr(Y))}}, // void * ptr_D + batch_stride_x, // int64_t batch_stride_A + batch_stride_w, // int64_t batch_stride_B + batch_stride_bias, // int64_t batch_stride_C + batch_stride_y, // int64_t batch_stride_D + row_stride_x, // typename LayoutA::Stride::LongIndex lda + row_stride_w, // typename LayoutB::Stride::LongIndex ldb + row_stride_bias, // typename LayoutC::Stride::LongIndex ldc + row_stride_y, // typename LayoutC::Stride::LongIndex ldd + }; +""" + +GEMM_ARGS_SPARSE_CUTLASS_2X = r""" + using TensorRefA = cutlass::TensorRef<{{instance_type}}::ElementA, + {{instance_type}}::LayoutA>; + using TensorRefB = cutlass::TensorRef<{{instance_type}}::ElementB, + {{instance_type}}::LayoutB>; + using TensorRefC = cutlass::TensorRef<{{instance_type}}::ElementC, + {{instance_type}}::LayoutC>; + using TensorRefE = cutlass::TensorRef<{{instance_type}}::ElementE, + {{instance_type}}::LayoutE>; + // Note that "X" and "W" names may be misleading here. Namely, for + // sparse GEMM, the first argument is always sparse, while typically + // weight matrix, implied by name "W" will be sparse in + // applications. Thus, just remember that here: "X" refers to first + // argument, that is sparse, and "W" to second, that is dense. + TensorRefA X_ref({{template.cutlass_type_cast(X, kernel.ptr(X))}}, {{kernel.row_or_column_stride(X)}}); + TensorRefB W_ref({{template.cutlass_type_cast(W, kernel.ptr(W))}}, {{kernel.row_or_column_stride(W)}}); + TensorRefC Y_ref({{template.cutlass_type_cast(Y, kernel.ptr(Y))}}, {{kernel.row_or_column_stride(Y)}}); + TensorRefE Meta_ref({{template.cutlass_sparse_meta_type_cast(Meta, kernel.ptr(Meta))}}, + TensorRefE::Layout::packed({ {{kernel.size(Meta, 0)}}, {{kernel.size(Meta, 1)}} })); + // Initialize GemmSparse arguments. 
+ arguments = { + { + static_cast(M), + static_cast(N), + static_cast(2 * K), + }, // GemmCoord problem_size + X_ref, // TensorRef ref_A + W_ref, // TensorRef ref_B + Y_ref, // TensorRef ref_C + Y_ref, // TensorRef ref_D + Meta_ref, // TensorRef ref_E + {ElementComputeEpilogue({{alpha}}), ElementComputeEpilogue({{beta}})}, // typename EpilogueOutputOp::Params epilogue, + }; +""" + +# Additional includes which are necessary if the standalone test / debug runner is generated as well +GEMM_STANDALONE_RUNNER_ADDITIONAL_INCLUDES = r""" +#ifdef GENERATE_STANDALONE_RUNNER +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include +#endif +""" + +# Jinja template for the standalone runner that may be generated as part of the code. +GEMM_STANDALONE_RUNNER_TEMPLATE = r""" +#ifdef GENERATE_STANDALONE_RUNNER +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::DeviceAllocation& block, + uint64_t seed, float max=1.0, float min=-1.0) { + if (block.size()<=0) return false; + Element scope_max(static_cast(max)), scope_min(static_cast(min)); + cutlass::reference::device::BlockFillRandomUniform( + (Element*)block.get(), block.size(), seed, scope_max, scope_min); + + return true; +} + +{% if Meta is defined and Meta is not none %} +template +bool initialize_block_meta( + cutlass::DeviceAllocation& block, + uint64_t seed) { + if (block.size()<=0) return false; + cutlass::reference::device::BlockFillRandomSparseMeta( + (Element*)block.get(), block.size(), seed, {{instance_type}}::kMetaSizeInBits); + return true; +} +{% endif %} + +extern "C" int run_standalone(uint64_t seed, int repetitions) { + std::cout << "Starting GEMM Standalone test run with seed " << seed << std::endl; + size_t workspace_size = 0; + size_t* workspace_size_ptr = &workspace_size; + + int M = {{kernel.get_layout_args()[0]}}; + int N = {{kernel.get_layout_args()[1]}}; + int K = {{kernel.get_layout_args()[2]}}; + int B = {{kernel.get_layout_args()[3]}}; + int lda = {{kernel.get_layout_args()[4]}}; + int ldb = {{kernel.get_layout_args()[5]}}; + int ldc = {{kernel.get_layout_args()[6]}}; + int ldd = {{kernel.get_layout_args()[7]}}; + uint8_t swizzle = {{kernel.runtime_arg_values[0]}}; + + using ElementA = {{kernel.cutlass_dtype(X)}}; + using ElementB = {{kernel.cutlass_dtype(W)}}; + using ElementC = {{kernel.cutlass_dtype(Bias, default_dtype='uint8_t')}}; // may not be void + using ElementD = {{kernel.cutlass_dtype(Y)}}; + {% if Meta is defined and Meta is not none %} + using ElementE = {{kernel.cutlass_dtype(Meta)}}; + {% endif %} + + cutlass::DeviceAllocation X_data({{kernel.max_valid_index(X)+1}}); + initialize_block(X_data, seed++); + cutlass::DeviceAllocation W_data({{kernel.max_valid_index(W)+1}}); + initialize_block(W_data, seed++); + cutlass::DeviceAllocation Bias_data({{kernel.max_valid_index(Bias)+1}}); + initialize_block(Bias_data, seed++); + cutlass::DeviceAllocation Y_data({{kernel.max_valid_index(Y)+1}}); + {% if Meta is defined and Meta is not none %} + cutlass::DeviceAllocation Meta_data({{kernel.max_valid_index(Meta)+1}}); + initialize_block_meta(Meta_data, seed++); + {% endif %} + + cutlass::DeviceAllocation workspace_data; + // Call once with workspace_size_ptr set to get workspace size + + std::cout << 
"Calling once to get workspace size" << std::endl; + {{test_call_statement}}; + // Allocate workspace if necessary + if (workspace_size > 0) { + workspace_data.reset(workspace_size); + std::cout << "Allocated workspace size of " << workspace_size << " bytes" << std::endl; + } + std::cout << "Calling Kernel as {{test_call_statement}};" << std::endl; + workspace_size_ptr = nullptr; + for (int i=0; i None: + """ + Args: + input_nodes (List[Buffer]): List of input nodes of the GEMM kernel. + layout (Layout): Layout type of the resulting output node. + alpha (float): The scaling factor for the product of the inputs in the GEMM operation. + beta (float): The scaling factor applied to the output matrix. + input_reorder (Optional[List[int]]): Specifies the reordering of the input nodes. If not provided, + no reordering is performed. Defaults to None. + """ + super().__init__( + str(Placeholder.KERNEL_NAME), input_nodes, layout, input_reorder + ) + self.alpha = alpha + self.beta = beta + self.use_fast_accum = use_fast_accum + assert 2 <= len(input_nodes) <= 5 + assert self._are_inputs_layout_compatible( + [node.get_layout() for node in input_nodes] + ) + + self.cache_key: str = create_inputs_key(self.input_nodes) + + @staticmethod + @abstractmethod + def add_cutlass_gemm_choices( + choices: list[ChoiceCaller], + layout: ir.Layout, + input_nodes: list[Buffer], + alpha: Union[float, int] = 1, + beta: Union[float, int] = 0, + input_reorder: Optional[list[int]] = None, + use_fast_accum: Optional[bool] = None, + **extra_kwargs, + ) -> None: + raise NotImplementedError + + @staticmethod + @abstractmethod + def _get_supported_ops() -> "list[cutlass_library.gemm_operation.GemmOperation]": # type: ignore[name-defined] # noqa: F821 + raise NotImplementedError + + @staticmethod + @abstractmethod + def _has_tma_epilogue(self) -> bool: + raise NotImplementedError + + @abstractmethod + def _get_template(self) -> str: + raise NotImplementedError + + @abstractmethod + def _get_template_args( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> tuple[str, Optional[str]]: + raise NotImplementedError + + @abstractmethod + def _are_inputs_layout_compatible(self, layouts: list[Layout]) -> bool: + raise NotImplementedError + + @abstractmethod + def _shape_match( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + raise NotImplementedError + + @abstractmethod + def _alignment_match( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + raise NotImplementedError + + @abstractmethod + def _set_bias_layout_and_alignment( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + raise NotImplementedError + + @abstractmethod + def _define_gemm_instance( + self, + op: GemmOperation, + evt_name: Optional[str] = None, + ) -> tuple[str, str]: + raise NotImplementedError + + @abstractmethod + def _get_extra_inputs_and_names( + self, + op: "cutlass_gemm_op.GemmOperation" = None, # type: ignore[name-defined] # noqa: F821 + ) -> tuple[Optional[Buffer], list[Optional[Buffer]], list[str]]: + raise NotImplementedError + + @abstractmethod + def _update_arg_names_for_test_call_statement( + self, + arg_names: list[str], + input_nodes: list[Buffer], + ) -> list[str]: + raise NotImplementedError + + def _add_cutlass_gemm_choices( + self, + choices: list[ChoiceCaller], + layout: ir.Layout, + input_nodes: list[Buffer], + 
alpha: Union[float, int] = 1, + beta: Union[float, int] = 0, + input_reorder: Optional[list[int]] = None, + **extra_kwargs, + ) -> None: + """ + Adds Cutlass GEMM configurations choices to the auto-tuning list. + + This function mutates the passed list of choices by appending the choices for Cutlass GEMM configs to it. + + Args: + choices (list): The list to which choices are appended. + layout (ir.Layout): The layout configuration. + input_nodes (list): The list of input nodes. + alpha (float,int): Scaling factor, defaults to 1. + beta (float,int): Offset, defaults to 0. + input_reorder (list, optional): Order of the inputs, defaults to None. + **extra_kwargs: Additional keyword arguments. + + """ + + ops = self.gen_ops() + for name, op in ops: + for swizzle in inductor_cuda_config.cutlass_max_profiling_swizzle_options: + description = f"{name} swizzle={swizzle}" + self.maybe_append_choice( + choices, description=description, op=op, swizzle=swizzle + ) + + if len(ops) == 0: + input_layouts = [node.get_layout() for node in input_nodes] + input_strides = [node.get_stride() for node in input_nodes] + output_layout = layout + warning_msg = f"No suitable Cutlass GEMM configs found, fallbacks used ( {len(ops)=}, {output_layout=}, {input_layouts=}, {input_strides=} )" # noqa: B950 + log.warning(warning_msg) + log.debug( + "Added %d Cutlass gemm configs.", + len(ops), + ) + + def header(self) -> IndentedBuffer: + """ + Returns a buffer containing CUDA C++ code for the header section of the CUTLASS GEMM template. + This section primarily includes the necessary header files. + + Returns: + IndentedBuffer: An instance of IndentedBuffer that contains the generated CUDA C++ header code. + """ + res = super().header() + res.splice( + """ + #include "cutlass/gemm/gemm.h" + #include "cutlass/gemm/device/gemm_universal.h" + #include "cutlass/gemm/device/gemm_universal_adapter.h" + #include "cutlass/gemm/kernel/gemm_universal.hpp" + #include "cutlass/gemm/device/gemm_sparse.h" + #include "cutlass/gemm/collective/collective_builder.hpp" + #include "cutlass/epilogue/collective/collective_builder.hpp" + #include "cutlass/epilogue/collective/default_epilogue.hpp" + #include "cutlass/epilogue/thread/linear_combination.h" + #include "cutlass/epilogue/thread/activation.h" + #include "cutlass/gemm/dispatch_policy.hpp" + #include "cutlass/gemm/kernel/tile_scheduler.hpp" + #include "cutlass/tensor_ref.h" + #include "cutlass/util/distribution.h" + #include "cutlass/util/packed_stride.hpp" + #include "cutlass/util/tensor_view_io.h" + """ + ) + if inductor_cuda_config.generate_test_runner and not is_dynamic( + *self.input_nodes, self.output_node + ): + res.splice(GEMM_STANDALONE_RUNNER_ADDITIONAL_INCLUDES) + return res + + @staticmethod + def cutlass_layout(torch_layout: ir.Layout) -> "Optional[cutlass_lib.LayoutType]": # type: ignore[name-defined] # noqa: F821 + """ + Converts an ir.Layout instance into the corresponding cutlass_library.LayoutType enum value + (RowMajor, ColumnMajor, or None if no matching value is found ). + + Args: + torch_layout (ir.Layout): The layout that needs to be looked up. + + Returns: + cutlass_lib.LayoutType: The converted layout corresponding to the `torch_layout` or None if no matching + value is found. 
+ """ + assert cutlass_utils.try_import_cutlass() + import cutlass_library.library as cutlass_lib + + if V.graph.sizevars.statically_known_equals(torch_layout.stride[-1], 1): + return cutlass_lib.LayoutType.RowMajor + elif V.graph.sizevars.statically_known_equals(torch_layout.stride[-2], 1): + return cutlass_lib.LayoutType.ColumnMajor + else: + return None + + @staticmethod + def flip_cutlass_layout( + cutlass_layout: "cutlass_lib.LayoutType", # type: ignore[name-defined] # noqa: F821 + ) -> "cutlass_lib.LayoutType": # type: ignore[name-defined] # noqa: F821 + """Helper method: Flips a given cutlass layout (cutlass_lib.LayoutType) from RowMajor + to ColumnMajor or vice versa""" + assert cutlass_utils.try_import_cutlass() + import cutlass_library.library as cutlass_lib + + if cutlass_layout == cutlass_lib.LayoutType.RowMajor: + return cutlass_lib.LayoutType.ColumnMajor + else: + return cutlass_lib.LayoutType.RowMajor + + @staticmethod + @functools.lru_cache(32) + def layout_match( + torch_layout: ir.Layout, + cutlass_layout: "cutlass_lib.LayoutType", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + """Helper Method: Determines whether a given torch layout matches a given Cutlass layout""" + return CUTLASSGemmTemplate.cutlass_layout(torch_layout) == cutlass_layout + + @staticmethod + def set_alignment(torch_layout, op_element) -> bool: + """ + Helper method to update the alignment of a given CUTLASS GEMM op operand's element. + + This method modifies the alignment of the given Cutlass GEMM op operand's element to match the + layout of the corresponding ir.Buffer node. + + Args: + torch_layout: The layout of the corresponding ir.Buffer node. + op_element: The Cutlass GEMM op operand's element whose alignment is to be updated. + + Returns: + bool: True if the alignment was successfully updated, False otherwise. + """ + alignment = cutlass_utils.get_max_alignment(torch_layout) + cuda_arch = cutlass_utils.get_cuda_arch() + if cuda_arch and int(cuda_arch) >= 90 and alignment < op_element.alignment: + return False + else: + op_element.alignment = alignment + return True + + @staticmethod + def should_swap_XW( + bias: IRNode, + ) -> bool: + """ + Helper method to determine whether we should do an explicit transpose by switching the order of the + matmul operands. This might be necessary when we can't otherwise arrive at the right memory + layout for the given Bias operand. + + Note: This method is a workaround for CUDA Errors that seemingly non-deterministically + occurred in practice in some CUTLASS GEMM Kernels with Linear epilogues that have a bias term. + it might make sense to check on newer Cutlass releases whether it makes sense to keep + returning True in certain cases or whether it becomes unnecessary. + """ + # If bias is row major, swap all M and N dimensions + if ( + bias is not None + and len(bias.get_stride()) >= 2 + and bias.get_stride()[-1] in (0, 1) + ): + log.debug("GEMM Layout swapped X and W -> explicit transpose") + return True + return False + + @staticmethod + def swap_XW( + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> "cutlass_library.gemm_op.GemmOperation": # type: ignore[name-defined] # noqa: F821 + """ + Swap operands X and W (aka operans A and B) of the GEMM operation. This + requires transposing the operands, which is done by swapping the strides. + Note that we don't change the apparent external layout, just the operand layout. + this is intentional. 
+ """ + new_op = copy.deepcopy(op) + new_op.A.layout = CUTLASSGemmTemplate.flip_cutlass_layout(new_op.A.layout) + new_op.B.layout = CUTLASSGemmTemplate.flip_cutlass_layout(new_op.B.layout) + new_op.A, new_op.B = new_op.B, new_op.A + new_op.C.layout = CUTLASSGemmTemplate.flip_cutlass_layout(new_op.C.layout) + new_op.D.layout = CUTLASSGemmTemplate.flip_cutlass_layout(new_op.D.layout) + return new_op + + def fix_op_layout( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + X: Buffer, + W: Buffer, + Bias: Optional[Buffer], + Y: Union[Buffer, ReinterpretView], + ) -> "cutlass_library.gemm_op.GemmOperation": # type: ignore[name-defined] # noqa: F821 + # This is a workaround to deal with cases where the input layouts have changed + # between autotuning and rendering. This happens if the inputs layout + # are FlexibleLayout instances. In this case, we need to update the + # op's input layouts. It is a hack, because now the op + # we benchmarked is not the same as the op we render, + # but there is no simple way to fix this in the autotuner, since that would + # potentially disable other optimizations. + a_layout = X.get_layout() + b_layout = W.get_layout() + c_layout = Bias.get_layout() if Bias is not None else None + + d_layout = copy.deepcopy(Y.get_layout()) + match_list = [ + CUTLASSGemmTemplate.layout_match(buf.get_layout(), op_layout) + for buf, op_layout in zip( + (X, W, Bias, Y), + (op.A.layout, op.B.layout, op.C.layout, op.D.layout), + ) + if buf is not None + ] + all_match = all(match_list) + if all_match: + return op + log.warning( + f"Cutlass GEMM Layout change: Input and/or output layouts have changed between autotuning/retuning and call to render on {self}. Applying workaround. This can lead to suboptimal performance. Match List: {match_list}" # noqa: G004, B950 + ) + new_op = copy.deepcopy(op) + + if a_layout is not None: + new_op.A.layout = CUTLASSGemmTemplate.cutlass_layout(a_layout) + if b_layout is not None: + new_op.B.layout = CUTLASSGemmTemplate.cutlass_layout(b_layout) + if c_layout is not None: + new_op.C.layout = CUTLASSGemmTemplate.cutlass_layout(c_layout) + new_op.C.element = cutlass_utils.torch_dtype_to_cutlass_type(c_layout.dtype) + if d_layout is not None: + new_op.D.layout = CUTLASSGemmTemplate.cutlass_layout(d_layout) + return new_op + + def _dtype_match( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + """ + Checking dtypes of A, B, acc, D here. + + Empirically speaking, CUTLASS2x ops have same dtype for C and D. + """ + X = self.input_nodes[0] + W = self.input_nodes[1] + + accumulator_torch_dtype = cutlass_utils.get_accumulator_dtype( + [X.get_dtype(), W.get_dtype()], + ) + if not ( + cutlass_utils.dtype_match(X.get_dtype(), op.A.element) + and cutlass_utils.dtype_match(W.get_dtype(), op.B.element) + and cutlass_utils.dtype_match( + self.output_node.get_layout().dtype, op.D.element + ) + and cutlass_utils.dtype_match( + accumulator_torch_dtype, op.accumulator_type() + ) + ): + return False + + return True + + def filter_op( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> "cutlass_library.gemm_op.GemmOperation": # type: ignore[name-defined] # noqa: F821 + """ + Helper method: + + Determines whether a given Cutlass GEMM op definition is suitable for the current + input / output of the operation that this template is supposed to implement. 
+ + Takes memory layout, dtype and support for EVT operations into account, + and filters potentially problematic ops. + + Returns None if the op is not suitable, otherwise returns the op to be used, which might + have been mutated. + """ + + assert cutlass_utils.try_import_cutlass() + import cutlass_library.library as cutlass_lib # type: ignore[import] + + # Skip simt kernels + if ( + op.tile_description.math_instruction.opcode_class + == cutlass_lib.OpcodeClass.Simt + ): + return None + + if op.gemm_kind not in self._get_supported_ops(): + return None + + X = self.input_nodes[0] + W = self.input_nodes[1] + + # Filter ops according to the shape match. + if not self._shape_match(op): + return None + + # Filter ops by dtypes. + if not self._dtype_match(op): + return None + + # Filter ops by input layouts. + if not ( + self.layout_match(X.get_layout(), op.A.layout) + and self.layout_match(W.get_layout(), op.B.layout) + ): + return None + + # Filter ops by alignment. + if not self._alignment_match(op): + log.debug( + "Skipping due to alignment mismatch. op: %s", op.configuration_name() + ) + return None + + # Update op. + op = copy.deepcopy(op) + + # Set output layout. + op.D.layout = CUTLASSGemmTemplate.cutlass_layout(self.output_node.get_layout()) + + # Filter ops by alignments and set alignments. + status = ( + self.set_alignment(X.get_layout(), op.A) + and self.set_alignment(W.get_layout(), op.B) + and self.set_alignment(self.output_node.get_layout(), op.D) + ) + if not status: + log.debug( + "Skipping due to alignment setting failure. op: %s", + op.configuration_name(), + ) + return None + + if inductor_cuda_config.cutlass_tma_only and not self._has_tma_epilogue(op): + return None + + # Set epilogue. + # TODO: update epilogue functor according to epilogues. + op.element_epilogue = op.accumulator_type() + + if self.use_fast_accum is not None: + is_op_fast_accum = "fastaccum" in op.configuration_name() + if self.use_fast_accum ^ is_op_fast_accum: + return None + + # Set bias layout and alignment. + status = self._set_bias_layout_and_alignment(op) + if not status: + log.debug( + "Skipping due to bias layout and alignment setting failure. op: %s", + op.configuration_name(), + ) + return None + + # Apply regex filters at the end when configuration name doesn't change anymore + if ( + inductor_cuda_config.cutlass_op_allowlist_regex + or inductor_cuda_config.cutlass_presets + ): + patterns = [] + if inductor_cuda_config.cutlass_op_allowlist_regex: + patterns.append(inductor_cuda_config.cutlass_op_allowlist_regex) + if inductor_cuda_config.cutlass_presets: + presets = gen_cutlass_presets() + preset_nums = [ + int(x) for x in inductor_cuda_config.cutlass_presets.split(",") + ] + for preset_num in preset_nums: + preset = presets.get(preset_num, {}).get( + inductor_cuda_config.cutlass_instantiation_level, [] + ) + + patterns.extend(preset) + + pattern = "|".join(patterns) + if pattern and not re.search(pattern, op.configuration_name()): + return None + if inductor_cuda_config.cutlass_op_denylist_regex is not None: + if re.search( + inductor_cuda_config.cutlass_op_denylist_regex, op.configuration_name() + ): + return None + + return op + + def gen_ops(self) -> "list[tuple[str, cutlass_gemm_op.GemmOperation]]": # type: ignore[name-defined] # noqa: F821 + """ + Creates a list of Cutlass GemmOperation instances that match the operation this template is designed to represent. + The matching is carried out with respect to the input and output specifications of the operation. + + No function arguments. 
+ + Returns: + List[Tuple[str, cutlass_gemm_op.GemmOperation]]: A list of (cutlass_name, GemmOperation) + tuples that are compatible with the operation requirements of this template. + """ + assert cutlass_utils.try_import_cutlass() + import cutlass_library.gemm_operation as cutlass_gemm_op + + if self.cache_key in self.filtered_ops_cache: + log.debug("Using cached ops for %s", self.cache_key) + return self.filtered_ops_cache[self.cache_key] + + maybe_ops = maybe_fetch_ops() + if maybe_ops is None: + log.debug("Cannot fetch ops from cache, generating ops from scratch") + full_ops = cutlass_utils.gen_ops() + ops = pytree.tree_flatten(full_ops)[0] + else: + log.debug("Using cached ops from cache") + ops = maybe_ops + + res: dict[str, cutlass_gemm_op.GemmOperation] = {} + start_time = time.time() + for op in ops: + # if changed, need to also change CUTLASS_OPERATION_KIND + assert isinstance(op, cutlass_gemm_op.GemmOperation) + filter_res = self.filter_op(op) + if ( + filter_res is not None + and res.get(filter_res.configuration_name(), None) is None + ): + res[filter_res.configuration_name()] = filter_res + log.info( + "Got cutlass configs: total number of ops: %d. Filtering took %.2f seconds", + len(res), + time.time() - start_time, + ) + sorted_res = sorted(res.items()) + ret_res = sorted_res[: inductor_cuda_config.cutlass_max_profiling_configs] + if len(self.filtered_ops_cache) < 50: + self.filtered_ops_cache[self.cache_key] = ret_res + else: + log.debug("Not caching ops since filtered_ops_cache has reached size 50.") + return ret_res + + def gemm_mode(self) -> str: + """ + Returns a Cutlass GEMM mode string for the current operation, dependent on whether this op implements + a batched GEMM or a simple GEMM without batch dimension. + + Returns: + str: A string indicating the Cutlass GEMM mode. If the output node has more than two dimensions, + "cutlass::gemm::GemmUniversalMode::kBatched" is returned, otherwise + "cutlass::gemm::GemmUniversalMode::kGemm" is returned. + """ + sizes = self.output_node.get_size() + if len(sizes) > 2: + return "cutlass::gemm::GemmUniversalMode::kBatched" + else: + return "cutlass::gemm::GemmUniversalMode::kGemm" + + def render( # type: ignore[override] + self, + kernel: CUDATemplateKernel, + op: "cutlass_gemm_op.GemmOperation" = None, # type: ignore[name-defined] # noqa: F821 + template_buffer_node: Optional[CUDATemplateBuffer] = None, + epilogue_nodes: Optional[list[BaseSchedulerNode]] = None, + **kwargs, + ) -> str: + """ + The primary entry point for the code rendering process used in this template. + Renders the Cutlass based CUDA C++ code for the GEMM Kernel that this template is designed to implement, + including potentially fused epilogues. + + Args: + kernel (CUDATemplateKernel): The kernel to be rendered. + op (cutlass_gemm_op.GemmOperation, optional): A GEMM operation that is required to be compatible with the + input and output definitions as well as a possible epilogue. Defaults to None. + **kwargs: Additional keyword arguments. Currently unused. + + Returns: + str: Cutlass based CUDA C++ code fragment as a string, to be used by the current + CUDATemplateKernel or autotuning code. + + Note: + All inputs and their corresponding buffer addresses and names take precedence over previously + passed inputs to the template at construction time. However, they should be layout compatible. 
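Aside: gemm_mode above keys purely off the rank of the output — anything with more than two dimensions is treated as a batched GEMM. A tiny mirror of that rule:

def gemm_mode_for(output_sizes: tuple[int, ...]) -> str:
    # Rank > 2 means a leading batch dimension is present.
    if len(output_sizes) > 2:
        return "cutlass::gemm::GemmUniversalMode::kBatched"
    return "cutlass::gemm::GemmUniversalMode::kGemm"

assert gemm_mode_for((128, 64)).endswith("kGemm")
assert gemm_mode_for((8, 128, 64)).endswith("kBatched")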
+ """ + assert cutlass_utils.try_import_cutlass() + import cutlass_library.gemm_operation as cutlass_gemm_op + import cutlass_library.library as cutlass_lib + + assert isinstance(op, cutlass_gemm_op.GemmOperation), ( + "op argument is required and has to be an instance of GemmOperation" + ) + + if epilogue_nodes and not self._has_tma_epilogue(op): + raise NotImplementedError( + "Non-TMA epilogue visitor tree is not supported in Cutlass." + ) + + assert len(self.input_nodes) >= 2 and self.output_node is not None + X, W = self.input_nodes[0], self.input_nodes[1] + for input_node in self.input_nodes: + if not isinstance(X.layout, FixedLayout): + input_node.freeze_layout() + + Y = self.output_node + if template_buffer_node is not None: + Y = template_buffer_node + + Bias, extra_inputs, extra_names = self._get_extra_inputs_and_names(op) + + # Define Kernel call signature + # Important: This step also populates Kernel name to node mapping data structures, + # which are required further below ( for example by the template renderer ) + inputs = [X, W, Bias, *extra_inputs] + names = ["X", "W", "Bias", *extra_names] + ["Y"] + names_str = ",".join(names) + if self.input_reorder is not None: + input_reorder = self.input_reorder + else: + input_reorder = None + + # The layouts might have changed between autotuning and this call if they were FlexibleLayout + # we need to adapt, which might lead to suboptimal performance. + op = self.fix_op_layout(op, X, W, Bias, Y) + + # to make op mutable without affecting others + op = copy.deepcopy(op) + is_scaled_mm = len(self.input_nodes) in (4, 5) + if Bias is not None and not is_scaled_mm: + assert Bias.get_dtype() == X.get_dtype() + # This might have been set to void during filtering, when the assumption was still that there's no C + # operand + op.C.element = op.A.element + + assert op.C.element == op.D.element, ( + f"Expect C and D to have the same dtype, found {op.C.element} and {op.D.element}" + ) + + argument_template, epilogue_template = self._get_template_args(op) + should_swap_xw: bool = False + if Bias is not None and self._has_tma_epilogue(op): + if ( + op.epilogue_schedule + != cutlass_lib.EpilogueScheduleType.EpilogueTransposed + and self.should_swap_XW(Bias) + ): + # TMA epilogue requires bias vector in column major to get best perf. + op = self.swap_XW(op) + should_swap_xw = True + + if epilogue_nodes or is_scaled_mm: + if epilogue_nodes: + ( + input_names, + output_names, + var_name_to_buffer_name, + evt_py_code, + ) = CutlassEVTCodegen.ir_to_evt_python_code( + Y.get_name(), epilogue_nodes, V.kernel.removed_buffers + ) + + D_output_name = var_name_to_buffer_name["D"] + name_to_buffer = V.graph.name_to_buffer | V.graph.graph_inputs + for name in V.graph.constants.keys(): + name_to_buffer[name] = V.graph.add_tensor_constant( + V.graph.constants[name], name + ) + D_output_buffer = name_to_buffer[D_output_name] + Y = D_output_buffer # type: ignore[assignment] + # Interestingly, I don't think the rest of the layout matters here since we + # use the properties of the Y buffer to fill in D's properties in the epilogue + # args. This is needed though because it defines types expected in the epilogue args. 
+ op.D.element = cutlass_utils.torch_dtype_to_cutlass_type( + D_output_buffer.get_dtype() + ) + + assert output_names, "There should be at least one write" + + epilogue_inputs = [name_to_buffer[name] for name in input_names] + outputs = [name_to_buffer[name] for name in output_names] + else: # Scaled MM, we read the two scale matrices (and optional bias) and write a single output + bias = None if len(self.input_nodes) < 5 else self.input_nodes[4] + bias_name = bias.get_name() if bias else None + + ( + evt_read_names, + var_name_to_buffer_name, + evt_py_code, + ) = scaled_mm_evt( + self.input_nodes[2].get_name(), # scale_A + self.input_nodes[3].get_name(), # scale_B + bias_name, + Y.get_name(), + ) + + input_names = list(evt_read_names) + output_names = [] # We only need Y + epilogue_inputs = [self.input_nodes[2], self.input_nodes[3]] + if bias: + epilogue_inputs.append(bias) + outputs = [] + + acc_dtype = cutlass_utils.get_accumulator_dtype( + [X.get_dtype(), W.get_dtype()] + ) + assert acc_dtype, "Could not determine accumulator dtype" + + evt_name, evt_args, evt_code = self._render_evt( + op, + evt_py_code, + var_name_to_buffer_name, + Y.get_dtype(), + acc_dtype, + ) + + inputs = [ + X, + W, + Bias, + *epilogue_inputs, # type: ignore[list-item] + Y, + *extra_inputs, + ] + names_str = ",".join( + ["X", "W", "Bias", *input_names, "Y", *output_names, *extra_names] + ) + else: + evt_name = None + outputs = [Y] + evt_args = f"{{ElementComputeEpilogue({self.alpha}), ElementComputeEpilogue({self.beta})}}" + evt_code = "" + + kernel_call_signature = kernel.def_kernel( + inputs=inputs, # type: ignore[arg-type] + outputs=outputs, # type: ignore[arg-type] + names_str=names_str, + input_reorder=input_reorder, + ) + + test_call_statement = self.test_call_statement(kernel, inputs, names_str) + + instance_definition, instance_type = self._define_gemm_instance(op, evt_name) + + options = dict( + alpha=self.alpha, + beta=self.beta, + X=X, + W=W, + Y=Y, + kernel_call_signature=kernel_call_signature, + Bias=Bias, + epilogue_template=epilogue_template, + argument_template=argument_template, + should_swap_xw=should_swap_xw, + template=self, + kernel=kernel, + instance_definition=instance_definition, + instance_type=instance_type, + input_reorder=self.input_reorder, + epilogue_args=evt_args, + test_call_statement=test_call_statement, + op_conf_name=op.configuration_name(), + epilogue_visitor_tree=evt_code, + ) + options.update(dict(zip(extra_names, extra_inputs))) + res = self._template_from_string(self._get_template()).render(**options) + if inductor_cuda_config.generate_test_runner and not is_dynamic(X, W, Y, Bias): + test_runner_code = self._template_from_string( + GEMM_STANDALONE_RUNNER_TEMPLATE + ).render(**options) + res += "\n\n" + test_runner_code + + # splice to remove trailing spaces in each line + buf = IndentedBuffer() + buf.splice(res) + return buf.getvalue() + + def test_call_statement( + self, + kernel, + input_nodes, + names_str: str = "", + ) -> str: + """ + Helper method to render the Cutlass CUDA C++ code required for calling the GEMM operation in the standalone + test runner that might also be generated along with the rest of the code, if the corresponding config is + enabled. + + Returns a C++ statement that calls the GEMM operation with the correct arguments. 
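Aside: test_call_statement (body below) pairs the C++ argument types from cpp_argdefs with the kernel argument names and wraps each pair in a cast against the corresponding host-side buffer. A self-contained sketch of that string assembly, with hypothetical types and names:

def build_test_call_args(arg_types: list[str], arg_names: list[str]) -> str:
    # One "((type)name_data.get())" expression per kernel argument.
    return ", ".join(
        f"(({arg_type}){arg_name}_data.get())"
        for arg_type, arg_name in zip(arg_types, arg_names)
    )

assert (
    build_test_call_args(["const half*", "half*"], ["X", "Y"])
    == "((const half*)X_data.get()), ((half*)Y_data.get())"
)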
+ """ + _, __, arg_types = kernel.args.cpp_argdefs(cutlass_utils.DTYPE_TO_CUTLASS_TYPE) + arg_names = [name.strip() for name in names_str.strip().split(",")] + arg_names = self._update_arg_names_for_test_call_statement( + arg_names, input_nodes + ) + arguments = [ + f"(({arg_type}){arg_name}_data.get())" + for arg_type, arg_name in zip(arg_types, arg_names) + ] + return f"{kernel.kernel_name}({', '.join(arguments)}, M, N, K, B, lda, ldb, ldc, ldd, swizzle, workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);" # noqa: B950 + + def _render_evt( + self, + op: GemmOperation, + evt_py_code: str, + buffer_renames: dict[str, str], + output_dtype: torch.dtype, + accumulator_dtype: torch.dtype, + ) -> tuple[str, str, str]: # type: ignore[name-defined] # noqa: F821 + raise NotImplementedError("_render_evt in CUTLASSGemmTemplate not implemented") + + +class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate): + """ + CUTLASS 3x GEMM Template, which is used to generate CUTLASS GEMM kernels + including those which allow flexible fusions with epilogues. + """ + + def __init__( + self, + input_nodes: list[Buffer], + layout: Layout, + alpha: float, + beta: float, + input_reorder: Optional[list[int]] = None, + use_fast_accum: Optional[bool] = None, + ): + super().__init__( + input_nodes, layout, alpha, beta, input_reorder, use_fast_accum + ) + + @staticmethod + def add_cutlass_gemm_choices( + choices: list[ChoiceCaller], + layout: ir.Layout, + input_nodes: list[Buffer], + alpha: Union[float, int] = 1, + beta: Union[float, int] = 0, + input_reorder: Optional[list[int]] = None, + use_fast_accum: Optional[bool] = None, + **extra_kwargs, + ) -> None: + template = CUTLASS3xGemmTemplate( + input_nodes, + layout, + alpha, + beta, + input_reorder, + use_fast_accum, + ) + template._add_cutlass_gemm_choices( + choices, layout, input_nodes, alpha, beta, input_reorder, **extra_kwargs + ) + + @staticmethod + @functools.lru_cache(1) + def _get_supported_ops() -> "list[cutlass_library.gemm_operation.GemmOperation]": # type: ignore[name-defined] # noqa: F821 + import cutlass_library.library as cutlass_lib + + return [cutlass_lib.GemmKind.Universal3x] + + def _get_template(self) -> str: + return GEMM_TEMPLATE_CUTLASS_3X + + def _get_template_args( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> tuple[str, Optional[str]]: + return (GEMM_ARGS_CUTLASS_3X, GEMM_ARGS_CUTLASS_3X_EPILOGUE) + + @staticmethod + def _has_tma_epilogue( # noqa: F821 # type: ignore[arg-type,name-defined] + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined,arg-type] # noqa: F821 + ) -> bool: # type: ignore[name-defined] + """Helper method: Determine whether a given Cutlass GEMM op has a TMA Epilogue""" + assert cutlass_utils.try_import_cutlass() + import cutlass_library.library as cutlass_lib + + result = False + if op.gemm_kind == cutlass_lib.GemmKind.Universal3x: + epilogue_schedule_str = str(op.epilogue_schedule).split(".")[-1] + result = epilogue_schedule_str.lower().startswith("tma") + return result + + @staticmethod + def supports_epilogue_fusion(op: GemmOperation) -> bool: + return CUTLASS3xGemmTemplate._has_tma_epilogue(op) + + def _are_inputs_layout_compatible(self, layouts: list[Layout]) -> bool: + """ + Evaluates whether input layouts are compatible for General Matrix Multiply (GEMM). + + This function checks compatibility of A, B, and possibly C operand layouts for + a General Matrix Multiply (GEMM) operation, expressed as 'alpha * matmul(A, B) + beta * C'. 
+ It verifies requirements such as matching data types, minimum rank, and suitability + for broadcasting, as defined by PyTorch operations like `torch.matmul`, `torch.aten.mm`, + `addmm`, `bmm`, `baddbmm`, etc. + + Args: + layouts (List[Layout]): List containing 2 or 3 Layout objects representing + the input matrices A, B, and possibly C. + + Returns: + bool: True if layouts are GEMM compatible, otherwise False. + """ + assert 2 <= len(layouts) <= 5 + # Check if A and B are compatible + A_layout, B_layout = layouts[:2] + if len(A_layout.size) < 1: + return False + if len(B_layout.size) < 1: + return False + A_size = list(V.graph.sizevars.size_hints(A_layout.size)) + B_size = list(V.graph.sizevars.size_hints(B_layout.size)) + if len(A_size) < 2: + A_size.insert(0, 1) + if len(B_size) < 2: + A_size.insert(1, 1) + # Are batch dims broadcastable? + while len(A_size) < len(B_size): + A_size.insert(0, 1) + while len(B_size) < len(A_size): + B_size.insert(0, 1) + K = max(A_size[-1], B_size[-2]) + M = A_size[-2] + N = B_size[-1] + if K != A_size[-1] and A_size[-1] != 1: + return False + if K != B_size[-2] and B_size[-1] != 1: + return False + # check batch dim broadcastable + for i in range(len(A_size) - 2): + if A_size[i] != B_size[i] and A_size[i] != 1 and B_size[i] != 1: + return False + if len(layouts) == 3: + C_layout = layouts[2] + C_size = [V.graph.sizevars.size_hint(i) for i in C_layout.size] + while len(C_size) < len(A_size): + C_size.insert(0, 1) + # check batch dims + for i in range(len(A_size) - 2): + bd = max(A_size[i], B_size[i]) + if bd != C_size[i] and C_size[i] != 1: + return False + if len(C_size) > len(A_size): + # This may happen if the last elements of C are contiguous and + # their multiplied size equals the last dim size of B + if M != C_size[len(A_size) - 2] and C_size[len(A_size) - 2] != 1: + return False + remaining_size = 1 + for i in range(len(A_size) - 1, len(C_size)): + remaining_size *= C_size[i] + if N != remaining_size and remaining_size != 1: + return False + return True + assert len(C_size) == len(A_size) + if M != C_size[-2] and C_size[-2] != 1: + return False + if N != C_size[-1] and C_size[-1] != 1: + return False + return True + + def _render_evt( + self, + op: GemmOperation, + evt_py_code: str, + var_name_to_buffer_name: dict[str, str], + output_dtype: torch.dtype, + accumulator_dtype: torch.dtype, + ) -> tuple[str, str, str]: # type: ignore[name-defined] # noqa: F821 + from .cutlass_lib_extensions.evt_extensions import create_example_tensors, trace + + name_to_buffer = V.graph.name_to_buffer | V.graph.graph_inputs + + for name in V.graph.constants.keys(): + name_to_buffer[name] = V.graph.add_tensor_constant( + V.graph.constants[name], name + ) + + # handle the fake output buffer during lowering + name_to_buffer[self.output_node.get_name()] = self.output_node # type: ignore[assignment] + + acc_dtype = torch_dtype_to_cutlass_type(accumulator_dtype) + output_dtype = torch_dtype_to_cutlass_type(output_dtype) + examples = create_example_tensors( + var_name_to_buffer_name, + name_to_buffer, # type: ignore[arg-type] + V.graph.sizevars.size_hint, + ) + evt_name, evt_args, evt_code = trace( + evt_py_code, + examples, + acc_dtype, + output_dtype, + op.tile_description, # type: ignore[attr-defined] + op.epilogue_schedule, # type: ignore[attr-defined] + {k: name_to_buffer[v] for k, v in var_name_to_buffer_name.items()}, # type: ignore[arg-type,misc] + V.graph.sizevars.size_hint, + ) + + return ( + evt_name, + evt_args, + evt_code, + ) + + def _shape_match( + self, + 
op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + return True + + def _alignment_match( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + return True + + def _set_bias_layout_and_alignment( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + import cutlass_library.library as cutlass_lib + + has_bias = len(self.input_nodes) == 3 and self.input_nodes[2] is not None + if has_bias: + Bias = self.input_nodes[2] + # bias dtype + op.C.element = cutlass_utils.torch_dtype_to_cutlass_type( + Bias.get_layout().dtype + ) + + # Bias layout + bias_layout = CUTLASSGemmTemplate.cutlass_layout(Bias.get_layout()) + op.C.layout = bias_layout + + # Bias alignment + status = self.set_alignment(Bias.get_layout(), op.C) + if not status: + return False + else: + op.C.element = cutlass_lib.DataType.void + return True + + def _define_gemm_instance( + self, + op: GemmOperation, + evt_name: Optional[str] = None, + ) -> tuple[str, str]: + """Defines and renders the Cutlass / CUDA C++ code for a given GEMM operation instance. + + This function uses the Cutlass library to generate key parts of the codegen process. General Matrix Multiply + forms a core part of a number of scientific applications, so this efficient and adaptable implementation is + crucial. + + Args: + op (cutlass_library.gemm_op.GemmOperation): This is the core GEMM operation that we are defining and rendering. + + Returns: + Tuple[str, str]: A tuple where the first part is a string that constitutes the defined GEMM operation in C++ + code (render) and the second part is the string that specifies the operation type. + """ + assert cutlass_utils.try_import_cutlass() + import cutlass_library.library as cutlass_lib + + from .cutlass_lib_extensions import gemm_operation_extensions as gemm_extensions + + emitter = gemm_extensions.EmitGemmUniversal3xInstanceWithEVT(evt_name=evt_name) # type: ignore[call-arg] + + if not hasattr(op, "epilogue_functor") or not isinstance( + op.epilogue_functor, enum.Enum + ): + op = copy.deepcopy(op) + op.epilogue_functor = cutlass_lib.EpilogueFunctor.LinearCombination + + op_def = emitter.emit(op) + pattern = re.compile(r"\s*struct\s(.*?)\s:") + decl = [line for line in op_def.split("\n") if "struct " in line][-1] + + match = pattern.match(decl) + if match is None: + raise RuntimeError("Invalid Gemm config: \n" + op_def) + op_type = match.groups()[0] + if op.gemm_kind == cutlass_lib.GemmKind.Universal3x: + op_def += f"\n using {op_type}_device_type = cutlass::gemm::device::GemmUniversalAdapter<{op_type}>;\n" + op_type = f"{op_type}_device_type" + + return op_def, op_type + + def _get_extra_inputs_and_names( + self, + op: "cutlass_gemm_op.GemmOperation" = None, # type: ignore[name-defined] # noqa: F821 + ) -> tuple[Optional[Buffer], list[Optional[Buffer]], list[str]]: + Bias = self.input_nodes[2] if len(self.input_nodes) == 3 else None + inputs: list[Optional[Buffer]] = [] + names: list[str] = [] + return (Bias, inputs, names) + + def _update_arg_names_for_test_call_statement( + self, + arg_names: list[str], + input_nodes: list[Buffer], + ) -> list[str]: + if input_nodes[2] is None: + del arg_names[2] + else: + # Reorder them as Bias, A, B + if self.input_reorder is not None: + arg_names[0 : len(self.input_reorder)] = [ + arg_names[i] for i in self.input_reorder + ] + return arg_names + + def render_gemm_arguments( + self, + argument_template: str, + 
epilogue_template: str, + should_swap_xw: bool, + X: IRNode, + W: IRNode, + Bias: IRNode, + Y: IRNode, + alpha: float, + beta: float, + kernel: CUDATemplateKernel, + epilogue_args, + ) -> str: + """ + Render the Cutlass CUDA C++ code required for passing arguments to the GEMM operation. + + Args: + argument_template (str): Template for the GEMM operation arguments. + epilogue_template (str): Template for the epilogue arguments. + should_swap_xw (bool): Determines whether X, W operands should be swapped. If True, applies an explicit + transpose operation to X and W. + X (IRNode): The X input tensor. + W (IRNode): The W input tensor. + Bias (IRNode): The bias tensor. + Y (IRNode): The output tensor. + alpha (float): Scaling factor for the product of the inputs. + beta (float): Scaling factor for the output tensor. + kernel (CUDATemplateKernel): CUDA Template kernel for the operation. + epilogue_args (any): Additional arguments for the epilogue state. + + Returns: + str: A block of CUDA C++ code as a string, ready to be used as arguments for the GEMM operation. + + Note: If `should_swap_xw` is True, a transpose operation will be applied to the X, W, Bias, and Y + tensors. This operation also implies the M and N dimensions of Bias and GEMM output to be swapped + before the function call. + """ + options = dict( + alpha=alpha, + beta=beta, + X=X, + W=W, + Y=Y, + Bias=Bias, + template=self, + kernel=kernel, + M="M", + N="N", + epilogue_args=epilogue_args, + ) + assert epilogue_template is not None + + if should_swap_xw: + # Swap + def clone_with_transposed_stride(node: IRNode) -> IRNode: + old_layout = node.get_layout() + new_stride = list(old_layout.stride) # type: ignore[union-attr] + new_stride[-2], new_stride[-1] = new_stride[-1], new_stride[-2] + assert old_layout.device is not None + new_layout = FixedLayout( + old_layout.device, + old_layout.dtype, + list(old_layout.size), # type: ignore[union-attr] + new_stride, + old_layout.offset, # type: ignore[union-attr] + ) + return Buffer(name=node.get_name(), layout=new_layout) + + new_X = clone_with_transposed_stride(X) + new_W = clone_with_transposed_stride(W) + new_Bias = clone_with_transposed_stride(Bias) + new_Y = clone_with_transposed_stride(Y) + options["X"], options["W"], options["Bias"], options["Y"] = ( + new_W, + new_X, + new_Bias, + new_Y, + ) + options["M"], options["N"] = "N", "M" + + epilogue_arguments = self._template_from_string(epilogue_template).render( + **options + ) + arguments = self._template_from_string(argument_template).render( + epilogue_arguments=epilogue_arguments, **options + ) + + return arguments + + +class CUTLASS2xGemmTemplate(CUTLASSGemmTemplate): + def __init__( + self, + input_nodes: list[Buffer], + layout: Layout, + alpha: float, + beta: float, + input_reorder: Optional[list[int]] = None, + ): + super().__init__(input_nodes, layout, alpha, beta, input_reorder) + + @staticmethod + def add_cutlass_gemm_choices( + choices: list[ChoiceCaller], + layout: ir.Layout, + input_nodes: list[Buffer], + alpha: Union[float, int] = 1, + beta: Union[float, int] = 0, + input_reorder: Optional[list[int]] = None, + use_fast_accum: Optional[bool] = False, + **extra_kwargs, + ) -> None: + template = CUTLASS2xGemmTemplate( + input_nodes, layout, alpha, beta, input_reorder + ) + template._add_cutlass_gemm_choices( + choices, layout, input_nodes, alpha, beta, input_reorder, **extra_kwargs + ) + + @staticmethod + def _get_supported_ops() -> "list[cutlass_library.gemm_operation.GemmOperation]": # type: ignore[name-defined] # noqa: 
F821 + import cutlass_library.library as cutlass_lib + + return [cutlass_lib.GemmKind.Universal, cutlass_lib.GemmKind.Sparse] + + @staticmethod + def _has_tma_epilogue(self) -> bool: + return False + + def _get_template(self) -> str: + return GEMM_TEMPLATE_CUTLASS_2X + + def _get_template_args( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> tuple[str, Optional[str]]: + import cutlass_library.library as cutlass_lib + + if op.gemm_kind == cutlass_lib.GemmKind.Sparse: + return (GEMM_ARGS_SPARSE_CUTLASS_2X, None) + + return (GEMM_ARGS_CUTLASS_2X, None) + + def _are_inputs_layout_compatible(self, layouts: list[Layout]) -> bool: + """ + Evaluates whether input layouts are compatible for set of operations supported by this class. + + Args: + layouts (List[Layout]): List containing Layout objects representing + the input matrices. + + Returns: + bool: True if layouts are GEMM compatible, otherwise False. + """ + assert len(layouts) == 2 or len(layouts) == 3 + # Check if A and B are compatible + A_layout, B_layout = layouts[:2] + if len(A_layout.size) != 2: + return False + if len(B_layout.size) != 2: + return False + A_size = [int(i) for i in A_layout.size] + B_size = [int(i) for i in B_layout.size] + K = max(A_size[1], B_size[0]) + return (K == A_size[1] or K == 2 * A_size[1]) and K == B_size[0] + + def _shape_match( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + import cutlass_library.library as cutlass_lib + + X, W = self.input_nodes[0], self.input_nodes[1] + + if op.gemm_kind == cutlass_lib.GemmKind.Sparse: + return X.get_size()[1] * 2 == W.get_size()[0] + + return X.get_size()[1] == W.get_size()[0] + + def _alignment_match( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + import cutlass_library.library as cutlass_lib + + if op.gemm_kind != cutlass_lib.GemmKind.Sparse: + return True + + # SparseGemm in CUTLASS has specific alignment check that for + # small k could make some of the choices throw kMisalignedOperand + # CUTLASS error when run, see: + # https://github.com/NVIDIA/cutlass/blob/e01b9b5029b7caca5a43c29f7d2714d7cf1dcae8/include/cutlass/gemm/kernel/sparse_gemm.h#L198-L200 # noqa: B950 + # So, let's skip these choices if that would be the case. + X = self.input_nodes[0] + return (X.get_size()[1] * 2) % op.tile_description.tile_shape[2] == 0 + + def _set_bias_layout_and_alignment( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + import cutlass_library.library as cutlass_lib + + if op.gemm_kind == cutlass_lib.GemmKind.Sparse: + op.C.layout = op.D.layout + return True + + if len(self.input_nodes) >= 3 and self.input_nodes[2] is not None: + Bias = self.input_nodes[2] + bias_layout = CUTLASSGemmTemplate.cutlass_layout(Bias.get_layout()) + if bias_layout != op.D.layout: + # For cutlass2, bias and output layout must match + return False + if not self.set_alignment(Bias.get_layout(), op.C): + return False + else: + op.C.layout = op.D.layout + return True + + def _define_gemm_instance( + self, + op: GemmOperation, + evt_name: Optional[str] = None, + ) -> tuple[str, str]: + """Defines and renders the Cutlass / CUDA C++ code for a given GEMM operation instance. + + This function uses the Cutlass library to generate key parts of the codegen process. 
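Aside: the CUTLASS 2x sparse path above encodes the 2:4 structured-sparsity convention — A is stored compressed along K, so a candidate op only fits when W provides twice as many rows as the compressed A has columns. A small mirror of the _shape_match rule for GemmKind.Sparse (hypothetical sizes):

def sparse_shape_match(x_size: tuple[int, int], w_size: tuple[int, int]) -> bool:
    # Compressed A keeps K/2 columns, so the dense K of W must be 2 * K_compressed.
    return x_size[1] * 2 == w_size[0]

assert sparse_shape_match((128, 32), (64, 256))      # 32 * 2 == 64
assert not sparse_shape_match((128, 64), (64, 256))  # 64 * 2 != 64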
General Matrix Multiply + forms a core part of a number of scientific applications, so this efficient and adaptable implementation is + crucial. + + Args: + op (cutlass_library.gemm_op.GemmOperation): This is the core GEMM operation that we are defining and rendering. + + Returns: + Tuple[str, str]: A tuple where the first part is a string that constitutes the defined GEMM operation in C++ + code (render) and the second part is the string that specifies the operation type. + """ + assert cutlass_utils.try_import_cutlass() + import cutlass_library.gemm_operation as cutlass_gemm_op + import cutlass_library.library as cutlass_lib + + if op.gemm_kind == cutlass_lib.GemmKind.Sparse: + emitter = cutlass_gemm_op.EmitSparseGemmInstance() + else: + emitter = cutlass_gemm_op.EmitGemmInstance() + op_def = emitter.emit(op) + op_def = op_def.replace( + "cutlass::gemm::device::Gemm", "cutlass::gemm::device::GemmUniversal" + ) + if op.gemm_kind != cutlass_lib.GemmKind.Sparse: + op_def = op_def.replace("false,", "") + pattern = re.compile(r"\s*using\s(.*?)\s=") + decl = op_def.split("\n")[2] + + match = pattern.match(decl) + if match is None: + raise RuntimeError("Invalid Gemm config: \n" + op_def) + op_type = match.groups()[0] + return op_def, op_type + + def _get_extra_inputs_and_names( + self, + op: "cutlass_gemm_op.GemmOperation" = None, # type: ignore[name-defined] # noqa: F821 + ) -> tuple[Optional[Buffer], list[Optional[Buffer]], list[str]]: + import cutlass_library.library as cutlass_lib + + if op.gemm_kind == cutlass_lib.GemmKind.Sparse: + Bias = None + Meta = self.input_nodes[2] + else: + Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] + Meta = None + inputs = [Meta] + names = ["Meta"] + return (Bias, inputs, names) + + def _update_arg_names_for_test_call_statement( + self, + arg_names: list[str], + input_nodes: list[Buffer], + ) -> list[str]: + if input_nodes[3] is None: + del arg_names[3] + if input_nodes[2] is None: + del arg_names[2] + return arg_names + + def render_gemm_arguments( + self, + instance_type: str, + argument_template: str, + epilogue_template: str, + should_swap_xw: bool, + X: IRNode, + W: IRNode, + Bias: IRNode, + Meta: IRNode, + Y: IRNode, + alpha: float, + beta: float, + kernel: CUDATemplateKernel, + epilogue_args, + ) -> str: + """ + Render the Cutlass CUDA C++ code required for passing arguments to the GEMM operation. + + Args: + instance_type (str): GEMM instance type. + argument_template (str): Template for the GEMM operation arguments. + epilogue_template (str): Template for the epilogue arguments. + should_swap_xw (bool): Determines whether X, W operands should be swapped. If True, applies an explicit + transpose operation to X and W. + X (IRNode): The X input tensor. + W (IRNode): The W input tensor. + Bias (IRNode): The bias tensor. + Meta (IRNode): The meta tensor. + Y (IRNode): The output tensor. + alpha (float): Scaling factor for the product of the inputs. + beta (float): Scaling factor for the output tensor. + kernel (CUDATemplateKernel): CUDA Template kernel for the operation. + epilogue_args (any): Additional arguments for the epilogue state. + + Returns: + str: A block of CUDA C++ code as a string, ready to be used as arguments for the GEMM operation. + + Note: If `should_swap_xw` is True, a transpose operation will be applied to the X, W, Bias, and Y + tensors. This operation also implies the M and N dimensions of Bias and GEMM output to be swapped + before the function call. 
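Aside: _define_gemm_instance above recovers the instance type name by regexing the "using <name> =" declaration out of the emitted C++. A standalone illustration with a hypothetical emitted snippet (the real emitter output is longer, and the library code indexes line 2 of it):

import re

# Hypothetical shape of an emitted declaration; only the "using ... =" line matters here.
op_def = (
    "// Gemm operator\n"
    "using cutlass_tensorop_s1688gemm_256x128_32x3_tt_align8 =\n"
    "    cutlass::gemm::device::GemmUniversal<...>;\n"
)

decl = op_def.split("\n")[1]
match = re.compile(r"\s*using\s(.*?)\s=").match(decl)
assert match is not None
assert match.group(1) == "cutlass_tensorop_s1688gemm_256x128_32x3_tt_align8"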
+ """ + options = dict( + instance_type=instance_type, + alpha=alpha, + beta=beta, + X=X, + W=W, + Y=Y, + Bias=Bias, + Meta=Meta, + template=self, + kernel=kernel, + M="M", + N="N", + epilogue_args=epilogue_args, + ) + + if epilogue_template is None: + arguments = self._template_from_string(argument_template).render( + split_k=1, **options + ) + return arguments + + epilogue_arguments = self._template_from_string(epilogue_template).render( + **options + ) + arguments = self._template_from_string(argument_template).render( + epilogue_arguments=epilogue_arguments, **options + ) + + return arguments diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/serialization.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/serialization.py new file mode 100644 index 0000000000000000000000000000000000000000..82fe188c09e843fd9fbbe6d76e440fe21ee94409 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda/serialization.py @@ -0,0 +1,466 @@ +# mypy: allow-untyped-defs +import enum +import functools +import json +from enum import Enum +from typing import Optional + +from torch._inductor.codegen.cuda.cutlass_utils import try_import_cutlass + + +class CUTLASSOperationSerializer: + """Serializes and deserializes CUTLASS GEMM operations to/from JSON. + + Handles GemmOperation objects and their nested components (TileDescription, TensorDescription). + """ + + # not used, but keeping in case we want to generalize the serializer + _SUPPORTED_CLASSES: list[str] = [ + "GemmOperation", + "GemmKind", + "TileDescription", + "TensorDescription", + "DataType", + "EpilogueFunctor", + "EpilogueFunctor3x", + "SwizzlingFunctor", + "KernelScheduleType", + "EpilogueScheduleType", + "TileSchedulerType", + ] + + @classmethod + def serialize(cls, operation: "GemmOperation"): # type: ignore[name-defined] # noqa: F821 + """Serialize a GEMM operation to JSON string. + + Args: + operation: GemmOperation object + indent: JSON indentation spaces + + Returns: + str: JSON representation of the operation + """ + assert operation.__class__.__qualname__ == "GemmOperation", ( + "Only GemmOperation objects are supported via the main API" + ) + return json.dumps(cls._gemm_operation_to_json(operation)) + + @classmethod + def deserialize(cls, json_str: str) -> "GemmOperation": # type: ignore[name-defined] # noqa: F821 + """Deserialize JSON string to a GEMM operation. + + Args: + json_str: JSON string of a GEMM operation + + Returns: + GemmOperation: Reconstructed operation + """ + json_dict = json.loads(json_str) + return cls._json_to_gemm_operation(json_dict) + + @classmethod + def _gemm_operation_to_json(cls, operation): + """Convert GemmOperation to JSON-serializable dict. 
+ + Args: + operation: GemmOperation object + + Returns: + dict: Dictionary representation + """ + from cutlass_library.library import TensorDescription + + # Create the main dictionary with required and optional parameters + result = { + # Required parameters + "gemm_kind": cls._enum_to_json(operation.gemm_kind), + "arch": operation.arch, + "tile_description": cls._tile_description_to_json( + operation.tile_description + ), + "A": cls._tensor_description_to_json(operation.A), + "B": cls._tensor_description_to_json(operation.B), + "C": cls._tensor_description_to_json(operation.C), + "element_epilogue": cls._enum_to_json(operation.element_epilogue), + # Optional parameters + "epilogue_functor": cls._enum_to_json(operation.epilogue_functor), + "swizzling_functor": cls._enum_to_json(operation.swizzling_functor), + "D": cls._tensor_description_to_json(operation.D) if operation.D else None, + "kernel_schedule": cls._enum_to_json(operation.kernel_schedule), + "epilogue_schedule": cls._enum_to_json(operation.epilogue_schedule), + "tile_scheduler": cls._enum_to_json(operation.tile_scheduler), + } + + # Process optional attributes + optional_attrs = [ + "mixed_input_mode", + "mixed_input_shuffle", + "ScaleFactorA", + "ScaleFactorB", + "ScaleFactorD", + "ScaleFactorMVecSize", + "ScaleFactorNVecSize", + "ScaleFactorKVecSize", + "ScaleFactorVectorSize", + "is_3x", + ] + + for attr in optional_attrs: + if not hasattr(operation, attr): + continue + + value = getattr(operation, attr) + + if isinstance(value, TensorDescription): + result[attr] = cls._tensor_description_to_json(value) + elif isinstance(value, Enum): + result[attr] = cls._enum_to_json(value) + else: + result[attr] = value + + return result + + @classmethod + def _json_to_gemm_operation(cls, json_dict): + """Convert JSON dict to GemmOperation object. 
+ + Args: + json_dict: Dictionary representation + + Returns: + GemmOperation: Reconstructed object + """ + from cutlass_library import DataType + from cutlass_library.gemm_operation import GemmKind, GemmOperation + from cutlass_library.library import ( + EpilogueFunctor, + EpilogueFunctor3x, + EpilogueScheduleType, + KernelScheduleType, + MixedInputMode, + SwizzlingFunctor, + TileSchedulerType, + ) + + # Extract constructor parameters from the JSON dictionary + gemm_kind = cls._json_to_enum(json_dict["gemm_kind"], GemmKind) + arch = json_dict["arch"] + tile_description = cls._json_to_tile_description(json_dict["tile_description"]) + A = cls._json_to_tensor_description(json_dict.get("A")) + B = cls._json_to_tensor_description(json_dict.get("B")) + C = cls._json_to_tensor_description(json_dict.get("C")) + element_epilogue = cls._json_to_enum(json_dict["element_epilogue"], DataType) + + # Get optional parameters with defaults + epilogue_functor = cls._json_to_enum( + json_dict.get("epilogue_functor"), + EpilogueFunctor3x if json_dict.get("is_3x") else EpilogueFunctor, + ) + swizzling_functor = cls._json_to_enum( + json_dict.get("swizzling_functor"), SwizzlingFunctor + ) + D = cls._json_to_tensor_description(json_dict.get("D")) + kernel_schedule = cls._json_to_enum( + json_dict.get("kernel_schedule"), KernelScheduleType + ) + epilogue_schedule = cls._json_to_enum( + json_dict.get("epilogue_schedule"), EpilogueScheduleType + ) + tile_scheduler = cls._json_to_enum( + json_dict.get("tile_scheduler"), TileSchedulerType + ) + + mixed_input_mode = cls._json_to_enum( + json_dict.get("mixed_input_mode"), MixedInputMode + ) + mixed_input_shuffle = json_dict.get("mixed_input_shuffle", False) + + # Scale factors + ScaleFactorA = cls._json_to_enum(json_dict.get("ScaleFactorA"), DataType) + ScaleFactorB = cls._json_to_enum(json_dict.get("ScaleFactorB"), DataType) + + ScaleFactorD = None + if "ScaleFactorD" in json_dict and "ScaleFactorVectorSize" in json_dict: + ScaleFactorD = { + "tensor": cls._json_to_tensor_description( + json_dict.get("ScaleFactorD") + ), + "vector_size": json_dict.get("ScaleFactorVectorSize"), + } + + ScaleFactorMVecSize = json_dict.get("ScaleFactorMVecSize") + ScaleFactorNVecSize = json_dict.get("ScaleFactorNVecSize") + ScaleFactorKVecSize = json_dict.get("ScaleFactorKVecSize") + + # Create the GemmOperation with the extracted parameters + operation = GemmOperation( + gemm_kind=gemm_kind, + arch=arch, + tile_description=tile_description, + A=A, + B=B, + C=C, + element_epilogue=element_epilogue, + epilogue_functor=epilogue_functor, + swizzling_functor=swizzling_functor, + D=D, + kernel_schedule=kernel_schedule, + epilogue_schedule=epilogue_schedule, + tile_scheduler=tile_scheduler, + mixed_input_mode=mixed_input_mode, + mixed_input_shuffle=mixed_input_shuffle, + ScaleFactorA=ScaleFactorA, + ScaleFactorB=ScaleFactorB, + ScaleFactorD=ScaleFactorD, + ScaleFactorMVecSize=ScaleFactorMVecSize, + ScaleFactorNVecSize=ScaleFactorNVecSize, + ScaleFactorKVecSize=ScaleFactorKVecSize, + ) + + return operation + + @classmethod + def _tile_description_to_json(cls, tile_desc): + """ + Convert TileDescription to JSON dict. 
+ + Args: + tile_desc: TileDescription object + + Returns: + dict: Dictionary representation + """ + if tile_desc is None: + return None + + # Create a dictionary for math_instruction if it exists + math_instruction_dict = None + if ( + hasattr(tile_desc, "math_instruction") + and tile_desc.math_instruction is not None + ): + math_instruction = tile_desc.math_instruction + math_instruction_dict = { + "instruction_shape": math_instruction.instruction_shape, + "element_a": cls._enum_to_json(math_instruction.element_a), + "element_b": cls._enum_to_json(math_instruction.element_b), + "element_accumulator": cls._enum_to_json( + math_instruction.element_accumulator + ), + "opcode_class": cls._enum_to_json(math_instruction.opcode_class), + "math_operation": cls._enum_to_json(math_instruction.math_operation), + } + + # Add element_scale_factor if it exists + if ( + hasattr(math_instruction, "element_scale_factor") + and math_instruction.element_scale_factor is not None + ): + math_instruction_dict["element_scale_factor"] = cls._enum_to_json( + math_instruction.element_scale_factor + ) + + # Create the main dictionary with field names matching TileDescription constructor parameters + result = { + "threadblock_shape": tile_desc.threadblock_shape, + "stages": tile_desc.stages, + "warp_count": tile_desc.warp_count, + "math_instruction": math_instruction_dict, + "min_compute": tile_desc.minimum_compute_capability, # Store as min_compute for constructor + "max_compute": tile_desc.maximum_compute_capability, # Store as max_compute for constructor + "cluster_shape": tile_desc.cluster_shape, + "explicit_vector_sizes": tile_desc.explicit_vector_sizes, + } + + # Add tile_shape if it exists and differs from threadblock_shape + if ( + hasattr(tile_desc, "tile_shape") + and tile_desc.tile_shape != tile_desc.threadblock_shape + ): + result["tile_shape"] = tile_desc.tile_shape + + return result + + @classmethod + def _json_to_tile_description(cls, json_dict): + """ + Convert JSON dict to TileDescription object. 
+ + Args: + json_dict: Dictionary representation + + Returns: + TileDescription: Reconstructed object + """ + if json_dict is None: + return None + + from cutlass_library import DataType + from cutlass_library.library import ( + MathInstruction, + MathOperation, + OpcodeClass, + TileDescription, + ) + + # First, reconstruct the math_instruction if it exists + math_instruction_obj = None + if ( + "math_instruction" in json_dict + and json_dict["math_instruction"] is not None + ): + mi_dict = json_dict["math_instruction"] + + # Convert string enum names back to enum values + element_a = cls._json_to_enum(mi_dict["element_a"], DataType) + element_b = cls._json_to_enum(mi_dict["element_b"], DataType) + element_acc = cls._json_to_enum(mi_dict["element_accumulator"], DataType) + + # Get the opcode_class enum + opcode_class = cls._json_to_enum(mi_dict["opcode_class"], OpcodeClass) + + # Get the math_operation enum + math_op = cls._json_to_enum(mi_dict["math_operation"], MathOperation) + + # Create the MathInstruction object + math_instruction_obj = MathInstruction( + instruction_shape=mi_dict["instruction_shape"], + element_a=element_a, + element_b=element_b, + element_accumulator=element_acc, + opcode_class=opcode_class, + math_operation=math_op, + ) + + # Add element_scale_factor if it exists + if ( + "element_scale_factor" in mi_dict + and mi_dict["element_scale_factor"] is not None + ): + math_instruction_obj.element_scale_factor = cls._json_to_enum( + mi_dict["element_scale_factor"], DataType + ) + + # Get compute capability values, checking both naming conventions + min_compute = json_dict.get( + "min_compute", json_dict.get("minimum_compute_capability") + ) + max_compute = json_dict.get( + "max_compute", json_dict.get("maximum_compute_capability") + ) + + # Get cluster shape with default value + cluster_shape = json_dict.get("cluster_shape", [1, 1, 1]) + + # Create the TileDescription object + tile_desc = TileDescription( + threadblock_shape=json_dict["threadblock_shape"], + stages=json_dict["stages"], + warp_count=json_dict["warp_count"], + math_instruction=math_instruction_obj, + min_compute=min_compute, + max_compute=max_compute, + cluster_shape=cluster_shape, + explicit_vector_sizes=json_dict.get("explicit_vector_sizes"), + ) + + # Set tile_shape if it exists and differs from threadblock_shape + if ( + "tile_shape" in json_dict + and json_dict["tile_shape"] != json_dict["threadblock_shape"] + ): + tile_desc.tile_shape = json_dict["tile_shape"] + + return tile_desc + + @classmethod + def _tensor_description_to_json(cls, tensor_desc): + """Convert TensorDescription to JSON dict. + + Args: + tensor_desc: TensorDescription object + + Returns: + dict: Dictionary representation + """ + if tensor_desc is None: + return None + + return { + "element": cls._enum_to_json(tensor_desc.element), + "layout": cls._enum_to_json(tensor_desc.layout), + "alignment": tensor_desc.alignment, + "complex_transform": cls._enum_to_json(tensor_desc.complex_transform), + } + + @classmethod + def _json_to_tensor_description(cls, tensor_json): + """Convert JSON dict to TensorDescription object. 
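Aside: _json_to_tile_description above tolerates both the constructor-style keys written by the serializer ("min_compute" / "max_compute") and the attribute-style keys ("minimum_compute_capability" / "maximum_compute_capability"). The lookup pattern, in isolation:

def get_compute_capability(d: dict, constructor_key: str, attribute_key: str):
    # Prefer the constructor-style key, fall back to the attribute-style key.
    return d.get(constructor_key, d.get(attribute_key))

assert get_compute_capability({"min_compute": 80}, "min_compute", "minimum_compute_capability") == 80
assert get_compute_capability({"minimum_compute_capability": 90}, "min_compute", "minimum_compute_capability") == 90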
+ + Args: + tensor_json: Dictionary representation + + Returns: + TensorDescription: Reconstructed object + """ + if tensor_json is None: + return None + + from cutlass_library import DataType + from cutlass_library.library import ( + ComplexTransform, + LayoutType, + TensorDescription, + ) + + element = cls._json_to_enum(tensor_json["element"], DataType) + layout = cls._json_to_enum(tensor_json["layout"], LayoutType) + alignment = tensor_json["alignment"] + complex_transform = cls._json_to_enum( + tensor_json["complex_transform"], ComplexTransform + ) + + return TensorDescription(element, layout, alignment, complex_transform) + + @classmethod + def _enum_to_json(cls, enum_value): + """Convert enum value to JSON dict. + + Args: + enum_value: Enum value + + Returns: + dict: Dictionary representation + """ + if enum_value is None: + return None + + assert isinstance(enum_value, enum.Enum) + return { + "type": enum_value.__class__.__name__, + "name": enum_value.name, + } + + @classmethod + def _json_to_enum(cls, json_dict, enum_class): + """Convert JSON dict to enum value. + + Format: {name: "EnumName", value: 1} + + Args: + json_dict: Dictionary representation + enum_class: Target enum class + + Returns: + Reconstructed enum value + """ + if json_dict is None or json_dict.get("name", "None") == "None": + return None + return enum_class[json_dict["name"]] + + +@functools.lru_cache(1) +def get_cutlass_operation_serializer() -> Optional[CUTLASSOperationSerializer]: + if not try_import_cutlass(): + return None + return CUTLASSOperationSerializer() diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda_combined_scheduling.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda_combined_scheduling.py new file mode 100644 index 0000000000000000000000000000000000000000..9dacb0fdaad3919a896d911999727314c73229c8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/cuda_combined_scheduling.py @@ -0,0 +1,136 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +from typing import Any, Optional, TYPE_CHECKING, Union + +from ..scheduler import ( + BaseSchedulerNode, + BaseScheduling, + FusedSchedulerNode, + Scheduler, + SchedulerNode, +) +from .cuda.cuda_cpp_scheduling import CUDACPPScheduling +from .rocm.rocm_cpp_scheduling import ROCmCPPScheduling +from .triton import TritonScheduling + + +if TYPE_CHECKING: + from collections.abc import Sequence + from typing_extensions import TypeAlias + + from sympy import Expr + + import torch + from torch.utils._ordered_set import OrderedSet + + from .common import BackendFeature + + _IntLike: TypeAlias = Union[int, Expr] + + +class CUDACombinedScheduling(BaseScheduling): + """ + Scheduler for CUDA Kernels, which delegates calls as appropriate + to the CUDA-C++ and Triton Schedulers, which both work for CUDA devices + and use a unified-wrapper for codegen. + + If Scheduling code needs to be specialized for the case of mixed Triton / CUDA C++ code, + this would also be the place to do it. 
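Aside: CUDACombinedScheduling (this class) is essentially a delegator — CUTLASS template nodes go to the CUDA C++ scheduler, ROCm template nodes to the ROCm scheduler, and everything else falls through to Triton. A toy sketch of that dispatch shape (hypothetical stand-ins, not the real scheduler classes):

class TritonLike:
    def codegen(self, node: str) -> str:
        return f"triton({node})"

class CudaCppLike:
    def handles(self, node: str) -> bool:
        return node.startswith("cutlass_")

    def codegen(self, node: str) -> str:
        return f"cuda_cpp({node})"

class CombinedLike:
    def __init__(self) -> None:
        self.cuda_cpp = CudaCppLike()
        self.triton = TritonLike()

    def codegen(self, node: str) -> str:
        # Delegate to the specialized backend when it claims the node, else Triton.
        backend = self.cuda_cpp if self.cuda_cpp.handles(node) else self.triton
        return backend.codegen(node)

assert CombinedLike().codegen("cutlass_gemm_0") == "cuda_cpp(cutlass_gemm_0)"
assert CombinedLike().codegen("pointwise_1") == "triton(pointwise_1)"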
+ """ + + def __init__(self, scheduler: Optional[Scheduler]) -> None: + super().__init__(scheduler) + self._triton_scheduling = TritonScheduling(scheduler) + self._cuda_cpp_scheduling = CUDACPPScheduling(scheduler) + self._rocm_cpp_scheduling = ROCmCPPScheduling(scheduler) + + def get_backend_features(self, device: torch.device) -> OrderedSet[BackendFeature]: + return self._triton_scheduling.get_backend_features(device) + + def choose_node_backend(self, node: BaseSchedulerNode) -> BaseScheduling: + if self._cuda_cpp_scheduling.is_cuda_cpp_template(node): + return self._cuda_cpp_scheduling + if self._rocm_cpp_scheduling.is_rocm_cpp_template(node): + return self._rocm_cpp_scheduling + return self._triton_scheduling + + def can_fuse_vertical( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + if self._cuda_cpp_scheduling.can_fuse_vertical(node1, node2): + return True + elif self._cuda_cpp_scheduling.is_cuda_cpp_template( + node1 + ) or self._cuda_cpp_scheduling.is_cuda_cpp_template(node2): + return False + return self._triton_scheduling.can_fuse_vertical(node1, node2) + + def can_fuse_horizontal( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + for node in (node1, node2): + if self._cuda_cpp_scheduling.is_cuda_cpp_template(node): + return self._cuda_cpp_scheduling.can_fuse_horizontal( + node1, node2 + ) # always False at the moment + return self._triton_scheduling.can_fuse_horizontal(node1, node2) + + def group_fn( + self, sizes: Sequence[Sequence[_IntLike]] + ) -> tuple[tuple[_IntLike, ...], ...]: + return self._triton_scheduling.group_fn(sizes) + + def codegen_template( + self, + template_node: BaseSchedulerNode, + epilogue_nodes: Sequence[BaseSchedulerNode], + prologue_nodes: Sequence[BaseSchedulerNode], + ) -> Optional[str]: + if self._cuda_cpp_scheduling.is_cuda_cpp_template(template_node): + assert not prologue_nodes + return self._cuda_cpp_scheduling.codegen_template( + template_node, epilogue_nodes, prologue_nodes + ) + elif self._rocm_cpp_scheduling.is_rocm_cpp_template(template_node): + assert not epilogue_nodes + assert not prologue_nodes + return self._rocm_cpp_scheduling.codegen_template( + template_node, epilogue_nodes, prologue_nodes + ) + else: + return self._triton_scheduling.codegen_template( + template_node, epilogue_nodes, prologue_nodes + ) + + def codegen_node(self, node: Union[FusedSchedulerNode, SchedulerNode]) -> None: + return self._triton_scheduling.codegen_node(node) + + def codegen_sync(self) -> None: + return self._triton_scheduling.codegen_sync() + + def flush(self) -> None: + return self._triton_scheduling.flush() + + def codegen_combo_kernel(self, *args: Any, **kwargs: Any) -> None: + return self._triton_scheduling.codegen_combo_kernel(*args, **kwargs) + + def benchmark_fused_nodes( + self, nodes: Sequence[BaseSchedulerNode] + ) -> tuple[float, str]: + return self._triton_scheduling.benchmark_fused_nodes(nodes) + + def benchmark_codegened_module(self, module): + return self._triton_scheduling.benchmark_codegened_module(module) + + def generate_kernel_code_from_nodes( + self, nodes: Sequence[Any], benchmark_kernel: bool = False + ) -> str: + return self._triton_scheduling.generate_kernel_code_from_nodes( + nodes, benchmark_kernel + ) + + def benchmark_combo_kernel( + self, node_list: Sequence[BaseSchedulerNode] + ) -> tuple[float, float, list[Optional[str]]]: + return self._triton_scheduling.benchmark_combo_kernel(node_list) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/debug_utils.py 
b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/debug_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4292c0d24097b193180e9072cc1a8263fe89cc6
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/debug_utils.py
@@ -0,0 +1,284 @@
+# mypy: allow-untyped-defs
+from __future__ import annotations
+
+import functools
+import logging
+import os
+from enum import Enum
+from typing import Callable, Optional
+
+import torch
+from torch import dtype as torch_dtype
+
+from .. import config
+from ..virtualized import V
+from .multi_kernel import MultiKernel
+
+
+log = logging.getLogger(__name__)
+
+
+def _print_debugging_tensor_value_info(msg, arg):
+    # helper for printing debugging stats for intermediate tensor values
+    # at jit inductor level codegen
+    max_numel_to_print = 64
+    print(msg)
+    if not isinstance(arg, torch.Tensor):
+        print("Value: ", arg)
+        return
+    numel = arg.float().numel()
+    # print the debug printing stats
+    if numel <= max_numel_to_print:
+        print(arg)
+    print("Number of elements: ", numel)
+    print("Size: ", arg.float().size())
+    print("Dtype: ", arg.dtype)
+    print("Mean: ", arg.float().mean().item())
+    print("Min: ", arg.float().min().item())
+    print("Max: ", arg.float().max().item())
+    print("Std: ", arg.float().std().item())
+
+
+# AOTI debug printing related configs
+class IntermediateValueDebuggingLevel(Enum):
+    # OFF: No intermediate tensor value debug info will be printed or saved.
+    OFF = "0"
+    # LEVEL 1: Save all intermediate tensor values to individual `.pt` files. No debug printing will be displayed.
+    SAVE_ONLY = "1"
+    # LEVEL 2: Print all intermediate tensor values by default to the console. No debug saving will be performed.
+    PRINT_ONLY = "2"
+    # LEVEL 3: Print all kernel names to the console only. No debug saving/printing for input tensor value info will be performed.
+    # This mode can be helpful when you just want to pinpoint which kernel is running into a CUDA IMA issue, etc.
+ PRINT_KERNEL_NAMES_ONLY = "3" + + +class DebugPrinterManager: + def __init__( + self, + debug_printer_level, + use_array_ref: bool, + writeline: Optional[Callable[..., None]] = None, + args_to_print_or_save: Optional[list[str]] = None, + kernel_name: str = "", + kernel=None, + arg_signatures: Optional[list[type]] = None, + kernel_type=None, + ): + self.debug_printer_level = IntermediateValueDebuggingLevel(debug_printer_level) + self.use_array_ref = use_array_ref + if args_to_print_or_save is None: + args_to_print_or_save = [] + self.args_to_print_or_save = args_to_print_or_save + self.kernel_name = kernel_name + self.arg_signatures: Optional[list[type]] = None + self.kernel = kernel + self.filtered_kernel_names_to_print = self._get_debug_filtered_kernel_names() + self.kernel_type = None + + def __enter__(self): + self._perform_debug_print_or_save_helper( + self.args_to_print_or_save, + self.kernel_name, + before_launch=True, + arg_signatures=self.arg_signatures, + ) + + def __exit__(self, args_to_print_or_save, kernel_name, arg_signatures): + self._perform_debug_print_or_save_helper( + args_to_print_or_save, + kernel_name, + before_launch=False, + arg_signatures=arg_signatures, + ) + + def _perform_debug_print_or_save_helper( + self, + args_to_print_or_save, + kernel_name, + before_launch, + arg_signatures: Optional[list[type]] = None, + ): + if self.debug_printer_level == IntermediateValueDebuggingLevel.OFF: + return + if self.debug_printer_level == IntermediateValueDebuggingLevel.SAVE_ONLY: + # by default save all the tensor values before launch + self.codegen_intermediate_tensor_value_save( + self.args_to_print_or_save, + self.kernel_name, + before_launch, + arg_signatures=self.arg_signatures, + ) + if self.debug_printer_level == IntermediateValueDebuggingLevel.PRINT_ONLY: + # by default print all the tensor values before launch + self.codegen_intermediate_tensor_value_print( + self.args_to_print_or_save, + self.kernel_name, + before_launch, + arg_signatures=self.arg_signatures, + ) + if ( + self.debug_printer_level + == IntermediateValueDebuggingLevel.PRINT_KERNEL_NAMES_ONLY + ): + # Print all kernel names to the console only + self.codegen_intermediate_tensor_value_print( + [], + self.kernel_name, + before_launch, + ) + + @functools.lru_cache # noqa: B019 + def _get_debug_filtered_kernel_names(self) -> list[str]: + if config.aot_inductor.filtered_kernel_names is None: + return [] + return [ + x.strip() + for x in config.aot_inductor.filtered_kernel_names.lower().split(",") + ] + + def set_printer_args( + self, + args_to_print_or_save: list[str], + kernel_name: str, + arg_signatures: Optional[list[type]], + kernel, + kernel_type=None, + ): + # Note: MultiKernel debug printing is not supported for now + if isinstance(kernel, MultiKernel): + log.info( + "MultiKernel type is not supported in AOTI debug printer tool yet." 
+ ) + self.debug_printer_level = IntermediateValueDebuggingLevel.OFF + + self.kernel_type = kernel_type + # Note: if the kernel type is an extern kernel (or cpp kernel), we do a special handling to + # get the list of args_to_print_or_save + # TODO: Find a more reliable way to detect kernel args types to print for extern kernel calls + if kernel_type == "extern": + args_to_print_or_save_extern = [ + arg for arg in args_to_print_or_save if arg.startswith(("buf", "arg")) + ] + self.args_to_print_or_save = args_to_print_or_save_extern + elif kernel_type == "cpp": + self.args_to_print_or_save = [ + ( + f"copy_arrayref_tensor_to_tensor({arg})" + if self.use_array_ref + else arg + ) + for arg in args_to_print_or_save + if arg.startswith(("buf", "arg")) + ] + else: + self.args_to_print_or_save = args_to_print_or_save + self.kernel_name = kernel_name + self.arg_signatures = arg_signatures + self.kernel = kernel + + def codegen_model_inputs_value_print(self, input_args_to_print: list[str]) -> None: + if self.debug_printer_level != IntermediateValueDebuggingLevel.PRINT_ONLY: + return + for arg in input_args_to_print: + if V.graph.cpp_wrapper: + V.graph.wrapper_code.prefix.writeline( + f'aoti_torch_print_tensor_handle({arg}, "aoti_model_inputs - {arg}");' + ) + + def codegen_intermediate_tensor_value_save( + self, + args_to_save, + kernel_name, + before_launch=True, + arg_signatures: Optional[list[type]] = None, + ) -> None: + for i, arg in enumerate(args_to_save): + if arg_signatures is not None and not isinstance( + arg_signatures[i], torch_dtype + ): + # infer from the arg data type (has torch.dtype) to see if it is a tensor type + continue + launch_prefix = "before_launch" if before_launch else "after_launch" + if V.graph.cpp_wrapper: + V.graph.wrapper_code.writeline( + f'aoti_torch_save_tensor_handle({arg}, "{arg}", "{launch_prefix}", "{kernel_name}");' + ) + else: + cwd = os.getcwd() + saved_dir = cwd + "/tmp/jit_inductor/" + if not os.path.exists(saved_dir): + log.info( + "Creating directory to save inductor intermediate tensor values." + ) + os.makedirs(saved_dir) + # Save the model to the directory + saved_path = saved_dir + f"{launch_prefix}_{kernel_name}_{arg}.pt" + log.info( + "Saved intermediate tensor %s for %s to %s", + arg, + kernel_name, + saved_path, + ) + line = f"torch.save({arg}, '{saved_path}')" + V.graph.wrapper_code.writeline(line) + + def codegen_intermediate_tensor_value_print( + self, + args_to_print, + kernel_name, + before_launch=True, + arg_signatures: Optional[list[type]] = None, + ) -> None: + launch_prefix = "before_launch" if before_launch else "after_launch" + + # if the debug printing level is PRINT_KERNEL_NAMES_ONLY + # we only print the kernel name to the console + if ( + self.debug_printer_level + == IntermediateValueDebuggingLevel.PRINT_KERNEL_NAMES_ONLY + ): + if V.graph.cpp_wrapper: + V.graph.wrapper_code.writeline( + f'printf("[ {launch_prefix}: {kernel_name} ]\\n");' + ) + return + + if self.debug_printer_level != IntermediateValueDebuggingLevel.PRINT_ONLY: + return + for i, arg in enumerate(args_to_print): + # when debug printing is enabled i.e. 
IntermediateValueDebuggingLevel.PRINT_ONLY, + # check if filtered kernel name list is provided + if ( + len(self.filtered_kernel_names_to_print) > 0 + and kernel_name.lower() not in self.filtered_kernel_names_to_print + ): + continue + if V.graph.cpp_wrapper: + if arg_signatures is not None and isinstance( + arg_signatures[i], torch_dtype + ): + # infer from the arg data type (has torch.dtype) to see if it is a tensor type + V.graph.wrapper_code.writeline( + f'aoti_torch_print_tensor_handle({arg}, "{launch_prefix} - {kernel_name} - {arg}");' + ) + elif arg_signatures is not None and isinstance( + arg_signatures[i], + ( + type(torch._inductor.codegen.wrapper.SymbolicCallArg), + type(int), + type(float), + type(bool), + ), + ): + V.graph.wrapper_code.writeline( + f'printf("[ {launch_prefix} - {kernel_name} - {arg}: %ld ]", {arg}); printf("\\\\n");' + ) + else: + if arg_signatures is None and self.kernel_type == "cpp" or "extern": + V.graph.wrapper_code.writeline( + f'aoti_torch_print_tensor_handle({arg}, "{launch_prefix} - {kernel_name} - {arg}");' + ) + else: + V.graph.wrapper_code.writeline( + f'_print_debugging_tensor_value_info("inductor: {launch_prefix} - {kernel_name} - {arg}", {arg})' + ) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/halide.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/halide.py new file mode 100644 index 0000000000000000000000000000000000000000..1749db7576edbde36145c24ffdc7ceb6c585c83a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/halide.py @@ -0,0 +1,1699 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import dataclasses +import functools +import itertools +import logging +import re +from collections import defaultdict +from math import inf +from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union + +import sympy + +import torch +import torch._logging + +from ..._prims_common import is_integer_dtype +from ...utils._ordered_set import OrderedSet +from ...utils._sympy.functions import FloorDiv, ModularIndexing +from ...utils._sympy.symbol import symbol_is_type, SymT +from ...utils._sympy.value_ranges import ValueRanges +from .. 
import config, ir +from ..codecache import HalideCodeCache +from ..ir import get_reduction_combine_fn +from ..metrics import is_metric_table_enabled, log_kernel_metadata +from ..ops_handler import AddParenHandler +from ..runtime.hints import HalideInputSpec, HalideMeta +from ..utils import ( + get_bounds_index_expr, + get_kernel_metadata, + parallel_num_threads, + sympy_index_symbol, + sympy_subs, +) +from ..virtualized import _ops as ops, V +from .common import ( + BackendFeature, + CSEVariable, + DeferredLine, + IndentedBuffer, + KernelArgType, + OpOverrides, + PythonPrinter, + SizeArg, + TensorArg, +) +from .cpp import DTYPE_TO_CPP +from .cpp_utils import cexpr +from .simd import constant_repr, SIMDKernel, SIMDScheduling + + +if TYPE_CHECKING: + from collections.abc import Sequence + + from ..ops_handler import ReductionType, StoreMode + +log = logging.getLogger(__name__) + + +def halide_constant(val): + if isinstance(val, int) and not (-2147483648 <= val <= 2147483647): + info = torch.iinfo(torch.int64) + if val == info.min: + return "hl.Int(64).min()" + if val == info.max: + return "hl.Int(64).max()" + return f"hl.i64({val!r})" + if isinstance(val, float): + return f"hl.f64({constant_repr(val)})" + return repr(val) + + +class Unsupported(RuntimeError): + def __init__(self, thing) -> None: + super().__init__(f"halide backend does not support: {thing}") + + +class HalidePrinter(PythonPrinter): + @staticmethod + def cast_index(expr): + return f"hl.cast({V.kernel.index_dtype}, {expr})" + + @staticmethod + def cast_float(expr): + return f"hl.cast(hl.Float(32), {expr})" + + def _print_Float(self, expr): + return f"hl.f32({expr})" + + def _print_ToFloat(self, expr): + assert len(expr.args) == 1 + return f"hl.f32({self._print(expr.args[0])})" + + def _print_floor(self, expr): + assert len(expr.args) == 1 + return self.cast_index(f"hl.floor({self._print(expr.args[0])})") + + _print_FloorToInt = _print_floor + + def _print_Trunc(self, expr): + assert len(expr.args) == 1 + return self.cast_index(f"hl.trunc({self._print(expr.args[0])})") + + _print_TruncToInt = _print_Trunc + + def _print_ceiling(self, expr): + assert len(expr.args) == 1 + return self.cast_index(f"hl.ceil({self._print(expr.args[0])})") + + def _helper_sqrt(self, expr): + return f"hl.sqrt({self.cast_float(self._print(expr))})" + + def _print_Where(self, expr): + c = self.doprint(expr.args[0]) + p = self.doprint(expr.args[1]) + q = self.doprint(expr.args[2]) + return f"hl.select({c}, {p}, {q})" + + def _print_Min(self, expr): + if len(expr.args) == 1: + return self._print(expr.args[0]) + + mid = len(expr.args) // 2 + a = self._print(sympy.Min(*expr.args[:mid])) + b = self._print(sympy.Min(*expr.args[mid:])) + return f"hl.min({a}, {b})" + + def _print_Max(self, expr): + if len(expr.args) == 1: + return self._print(expr.args[0]) + + mid = len(expr.args) // 2 + a = self._print(sympy.Max(*expr.args[:mid])) + b = self._print(sympy.Max(*expr.args[mid:])) + + return f"hl.max({a}, {b})" + + def _print_Abs(self, expr): + assert len(expr.args) == 1 + return self.cast_index(f"hl.abs({self._print(expr.args[0])})") + + def _print_OpaqueUnaryFn_cos(self, expr): + assert len(expr.args) == 1 + return f"hl.cos({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_cosh(self, expr): + assert len(expr.args) == 1 + return f"hl.cosh({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_acos(self, expr): + assert len(expr.args) == 1 + return f"hl.acos({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_sin(self, expr): + assert 
len(expr.args) == 1 + return f"hl.sin({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_sinh(self, expr): + assert len(expr.args) == 1 + return f"hl.sinh({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_asin(self, expr): + assert len(expr.args) == 1 + return f"hl.asin({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_tan(self, expr): + assert len(expr.args) == 1 + return f"hl.tan({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_tanh(self, expr): + assert len(expr.args) == 1 + return f"hl.tanh({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_atan(self, expr): + assert len(expr.args) == 1 + return f"hl.atan({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_log2(self, expr): + raise NotImplementedError("log2") + + def _print_FloorDiv(self, expr): + if expr.is_integer: + return super()._print_FloorDiv(expr) + + x, div = expr.args + x = self.cast_float(self.doprint(x)) + div = self.cast_float(self.doprint(div)) + return self.cast_index(f"hl.floor({x} / {div})") + + def _print_Round(self, expr): + assert len(expr.args) == 1 + return self.cast_index(f"hl.round({self._print(expr.args[0])})") + + _print_RoundToInt = _print_Round + + def _print_IntTrueDiv(self, expr): + a, b = expr.args + # force a cast to float + return f"({a}) / ({b}+hl.f32(0))" + + def _print_RoundDecimal(self, expr): + val, n = expr.args + val = self._print(val) + n = int(n) + return f"hl.f32({10.0 ** (-n)!r})*hl.round(({val})*hl.f32({10.0**n!r}))" + + +texpr = HalidePrinter().doprint +pexpr = PythonPrinter().doprint + + +_halide_type = { + torch.bool: "hl.Bool()", + torch.bfloat16: "hl.BFloat(16)", + torch.float16: "hl.Float(16)", + torch.float32: "hl.Float(32)", + torch.float64: "hl.Float(64)", + torch.int8: "hl.Int(8)", + torch.int16: "hl.Int(16)", + torch.int32: "hl.Int(32)", + torch.int64: "hl.Int(64)", + torch.uint8: "hl.UInt(8)", + torch.uint16: "hl.UInt(16)", + torch.uint32: "hl.UInt(32)", + torch.uint64: "hl.UInt(64)", +} + + +def halide_type(dtype): + return _halide_type[dtype] + + +def halide_acc_type(dtype): + if is_integer_dtype(dtype) and dtype.is_signed and dtype != torch.int64: + dtype = torch.int32 + if dtype in (torch.float16, torch.bfloat16): + dtype = torch.float32 + return halide_type(dtype) + + +class HalideOverrides(OpOverrides): + @staticmethod + def to_dtype( + x, + dtype: torch.dtype, + src_dtype: Optional[torch.dtype] = None, + use_compute_types=True, + ): + if dtype == torch.bool: + return f"({x} != 0)" + return f"hl.cast({halide_type(dtype)}, {x})" + + @staticmethod + def to_dtype_bitcast(x, dtype: torch.dtype, src_dtype: torch.dtype): + if src_dtype in (torch.float16, torch.bfloat16): + x = f"hl.cast({halide_type(src_dtype)}, {x})" # body compute is upcast to fp32 + line = f"hl.reinterpret({halide_type(dtype)}, {x})" + if dtype in (torch.float16, torch.bfloat16): + line = f"hl.cast(hl.Float(32), {line})" + return line + + @classmethod + def constant(cls, value, dtype): + return cls.to_dtype(halide_constant(value), dtype) + + @staticmethod + def abs(x): + return f"hl.abs({x})" + + @staticmethod + def exp(x): + if not hasattr(x, "name"): + return f"hl.exp({x})" + return f"hl.fast_exp(hl.cast(hl.Float(32), {x})) if {x.name}.type().bits() <= 32 else hl.exp({x})" + + @staticmethod + def sqrt(x): + return f"hl.sqrt({x})" + + @staticmethod + def minimum(a, b): + # return f"hl.min({a}, {b})" <== handles nan wrong + if not hasattr(a, "name"): + return f"hl.min({a}, {b})" + b = f"hl.cast({a.name}.type(), {b})" + return 
f"hl.select(({a}<{b})|hl.is_nan({a}), {a}, {b}) if {a.name}.type().is_float() else hl.min({a}, {b})" + + @staticmethod + def maximum(a, b): + # return f"hl.max({a}, {b})" <== handles nan wrong + if not hasattr(a, "name"): + return f"hl.max({a}, {b})" + b = f"hl.cast({a.name}.type(), {b})" + return f"hl.select(({a}>{b})|hl.is_nan({a}), {a}, {b}) if {a.name}.type().is_float() else hl.max({a}, {b})" + + @staticmethod + def where(a, b, c): + if hasattr(b, "name"): + c = f"hl.cast({b.name}.type(), {c})" + return f"hl.select({a}, {b}, {c})" + + @staticmethod + def cos(x): + return f"hl.cos({x})" + + @staticmethod + def sin(x): + return f"hl.sin({x})" + + @staticmethod + def lgamma(x): + raise Unsupported("lgamma") + + @staticmethod + def erf(x): + return f"hl.erf({x})" + + @staticmethod + def cosh(x): + return f"hl.cosh({x})" + + @staticmethod + def sinh(x): + return f"hl.sinh({x})" + + @staticmethod + def acos(x): + return f"hl.acos({x})" + + @staticmethod + def acosh(x): + return f"hl.acosh({x})" + + @staticmethod + def asin(x): + return f"hl.asin({x})" + + @staticmethod + def asinh(x): + return f"hl.asinh({x})" + + @staticmethod + def atan2(x, y): + return f"hl.atan2({x}, {y})" + + @staticmethod + def atan(x): + return f"hl.atan({x})" + + @staticmethod + def atanh(x): + return f"hl.atanh({x})" + + @staticmethod + def copysign(x, y): + raise Unsupported("copysign") + + @staticmethod + def erfinv(x): + raise Unsupported("erfinv") + + @staticmethod + def hypot(x, y): + return f"hl.hypot({x}, {y})" + + @staticmethod + def nextafter(x, y): + raise Unsupported("nextafter") + + @staticmethod + def logical_and(a, b): + return f"{a} & {b}" + + @staticmethod + def logical_not(a): + return f"{a} == 0" + + @staticmethod + def logical_or(a, b): + return f"{a} | {b}" + + @staticmethod + def logical_xor(a, b): + return f"({a} ^ {b})" + + @staticmethod + def bitwise_and(a, b): + return f"{a} & {b}" + + @staticmethod + def bitwise_not(a): + return f"~{a}" + + @staticmethod + def bitwise_or(a, b): + return f"{a} | {b}" + + @staticmethod + def bitwise_xor(a, b): + return f"{a} ^ {b}" + + @staticmethod + def bitwise_left_shift(a, b): + return f"{a} << {b}" + + @staticmethod + def bitwise_right_shift(a, b): + return f"{a} >> {b}" + + @staticmethod + def rand(seed, offset): + return f"halide_helpers.rand({seed}, {offset})" + + @staticmethod + def randn(seed, offset): + return f"halide_helpers.randn({seed}, {offset})" + + @staticmethod + def randint64(seed, offset, low, high): + return f"halide_helpers.randint64({seed}, {offset}, {low}, {high})" + + @staticmethod + def load_seed(name, offset): + return f"{ops.load(name, 0)} + {V.kernel.args.seed_offset('load_seed_offset', offset)}" + + @staticmethod + def rsqrt(x): + # return f"hl.fast_inverse_sqrt({x})" <== accuracy issues + return f"1./hl.sqrt({x})" + + @staticmethod + def tan(x): + return f"hl.tan({x})" + + @staticmethod + def tanh(x): + return f"hl.tanh({x})" + + @staticmethod + def signbit(x): + return f"(hl.reinterpret(hl.UInt(32), hl.cast(hl.Float(32), {x})) >> 31) != 0" + + @staticmethod + def fmod(a, b): + # TODO(jansel): find a better way to do this, builtin % has wrong sign + return f"{a} - hl.trunc({a}/{b})*{b}" + + @staticmethod + def pow(a, b): + return f"hl.pow({a}, {b})" # hl.fast_pow fails accuracy + + @staticmethod + def log(x): + return f"hl.log({x})" # hl.fast_log fails accuracy + + @staticmethod + def log2(x): + raise NotImplementedError("log2") + + @staticmethod + def isinf(x): + # workaround https://github.com/halide/Halide/issues/8309 + 
return f"hl.is_inf(hl.cast(hl.Float(32), {x}))" + + @staticmethod + def isnan(x): + # workaround https://github.com/halide/Halide/issues/8309 + return f"hl.is_nan(hl.cast(hl.Float(32), {x}))" + + @staticmethod + def round(x): + return f"hl.round({x})" + + @staticmethod + def floor(x): + return f"hl.floor({x})" + + @staticmethod + def int_truediv(a, b): + return f"({a}) / ({b} + hl.f32(0))" + + @staticmethod + def floordiv(a, b): + # TODO(jansel): find a better ways to do this, the select-based trick from triton.py didn't work + return ( + f"hl.floor(hl.cast(hl.Float(max(32, {a.name}.type().bits())), {a}) / {b})" + ) + + @classmethod + def sign(cls, x): + left = ops.to_dtype(ops.lt("0", x), torch.int8) + right = ops.to_dtype(ops.lt(x, "0"), torch.int8) + sub = ops.sub(left, right) + return f"hl.cast({x.name}.type(), {sub})" + + @staticmethod + def trunc(x): + return f"hl.trunc({x})" + + @staticmethod + def truncdiv(a, b): + # this causes crashes with floating point exception, see test_div_zero_dim_cpu + # return f"hl.div_round_to_zero({a}, {b})" + return ( + f"hl.trunc(hl.cast(hl.Float(max(32, {a.name}.type().bits())), {a}) / {b})" + ) + + @staticmethod + def ceil(x): + return f"hl.ceil({x})" + + @staticmethod + def relu(x): + return f"hl.max({x}, 0)" + + @classmethod + def index_expr(cls, expr, dtype): + index = V.kernel.prepare_indexing(expr) + var = V.kernel.genfunc( + V.kernel.index_to_str(index), + V.kernel.used_dims_from_index(index), + bounds=get_bounds_index_expr(expr), + ) + if dtype not in (torch.int32, torch.int64): + return ops.to_dtype(var, dtype) + return var + + @classmethod + def indirect_indexing(cls, index_var, size, check=True, wrap_neg=True): + # TODO(jansel): Halide only supports 32-bit indexing, we should error on overflow + index_var = ops.to_dtype(index_var, torch.int32) + index_var = ops.halide_clamp(index_var, size, check) + index_var.indirect_indexing_size = size + return sympy_index_symbol(str(index_var)) + + @classmethod + def halide_clamp(cls, value, size, check): + end = V.kernel.kexpr(V.kernel.rename_indexing(size) - 1) + if not isinstance(size, (int, sympy.Integer)): + end = f"hl.cast({value.name}.type(), {end})" + # Skip unsafe_promise_clamped to workaround: https://github.com/halide/Halide/issues/8261#issuecomment-2148835692 + # return f"hl.unsafe_promise_clamped({value}, 0, {end})" + return f"hl.clamp({value}, 0, {end})" + + @staticmethod + def masked(mask, body, other): + with V.kernel.mask_loads(mask, other) as new_mask: + result = body() + + if result.bounds.is_bool: + other = bool(other) + + # Take dtype from result to prevent accidental promotion + other = V.kernel.genfunc( + f"hl.cast({result.name}.type(), {halide_constant(other)})", + [], + bounds=ValueRanges.wrap(other), + ) + # TODO(jansel): look into removing the where in the same places triton does + return ops.where(new_mask, result, other) + + @staticmethod + def frexp(x): + raise NotImplementedError("frexp") + + +HalideOverrides._initialize_pointwise_overrides("halide") + + +class HalideCSEVariable(CSEVariable): + undefined_re = re.compile(r"\b(tmp\d+)\[\?\]") + + def __init__( + self, + name, + bounds: ValueRanges[Any], + dtype: Optional[torch.dtype] = None, + ) -> None: + super().__init__(name, bounds, dtype) + self.used_dims: Optional[list[sympy.Symbol]] = None + + def update_on_args(self, name, args, kwargs): + used = OrderedSet(self.used_dims or ()) + for arg in itertools.chain(args, kwargs.values()): + if isinstance(arg, HalideCSEVariable): + assert arg.used_dims is not None, (name, 
arg, args) + used.update(arg.used_dims) + self.used_dims = V.kernel.sort_used_dims(used) + + def index_str(self, dims): + if len(dims) == 0: + return f"{self.name}[()]" + # Reversed since Halide is column major + return f"{self.name}[{', '.join(map(str, dims))}]" + + def __str__(self) -> str: + if self.used_dims is None: + # This will get recomputed and replaced in codegen_kernel() + return f"{self.name}[?]" + return self.index_str(self.used_dims) + + def subs_str(self, replacements): + assert self.used_dims is not None and all( + isinstance(x, sympy.Expr) for x in self.used_dims + ) + return self.index_str([replacements.get(n, n) for n in self.used_dims]) + + +@dataclasses.dataclass +class DimensionInfo: + expr: Optional[sympy.Expr] + size: sympy.Expr + stride: sympy.Expr + + def __init__(self, expr, size, stride) -> None: + super().__init__() + if V.graph.sizevars.statically_known_lt(stride, 0): + stride = -stride + expr = -expr + self.expr = expr + self.size = size + self.stride = stride + + def index_str(self, replacements=None, zero_vars=False): + assert self.expr is not None + expr = self.expr + if zero_vars and expr == 0: + return "hl.Var()" + if replacements: + replacements = {**replacements} + for sym in expr.free_symbols: + if symbol_is_type(sym, SymT.TMP): + assert isinstance(sym, sympy.Symbol) + var = V.kernel.lookup_cse_var(sym.name) + assert isinstance(var, HalideCSEVariable) + replacements[sym] = sympy_index_symbol(var.subs_str(replacements)) + expr = sympy_subs(expr, replacements) + return V.kernel.index_to_str(expr) + + +def eq(left, right): + if V.graph.sizevars.statically_known_equals(left, right): + return True + try: + a = V.graph.sizevars.size_hint(left) + b = V.graph.sizevars.size_hint(right) + except TypeError: # unbacked symints + return False + if a == b: + V.graph.sizevars.guard_equals(left, right) + return a == b + + +def lt(left, right): + if V.graph.sizevars.statically_known_lt(left, right): + return True + try: + a = V.graph.sizevars.size_hint(left) + b = V.graph.sizevars.size_hint(right) + except TypeError: # unbacked symints + gcd = sympy.gcd(left, right) + if gcd == left: + return left != right + return False + if a < b: + V.graph.sizevars.guard_lt(left, right) + return a < b + + +class HalideKernel(SIMDKernel): + overrides = HalideOverrides # type: ignore[assignment] + kexpr: Callable[[sympy.Expr], str] = texpr + + def __init__( + self, + tiling: dict[str, sympy.Expr], + **kwargs, + ) -> None: + super().__init__(tiling, **kwargs) + # For halide, we just write directly to the body + self.compute = self.body + self.loads = self.body + self.stores = self.body + self.indexing_code_dom = IndentedBuffer() + self.needs_dom_indexing = self.inside_reduction + self.has_reduction = self.inside_reduction + self.buffer_dimensions: dict[str, list[DimensionInfo]] = {} + self.buffer_offsets: dict[str, sympy.Expr] = {} + # {h0: size1, h1: size2, ...} + self.halide_vars: dict[sympy.Symbol, sympy.Expr] = {} + # {x0: h0, x1: h1+10*h2, ...} + self.index_replacements: dict[sympy.Expr, sympy.Expr] = {} + # {h1: hr1, ...} + self.reduction_renames: dict[sympy.Symbol, sympy.Symbol] = {} + # {"i": {h0: hi0}, "o": ...} + self.dom_renames: dict[str, dict[sympy.Symbol, sympy.Symbol]] = {} + # {"in_ptr0": ["in_ptr0_view0"], ...} + self.buffer_aliases: dict[str, list[str]] = defaultdict(list) + self.has_indirect_indexing = False + + def dtype_to_str(self, dtype: torch.dtype) -> str: + return halide_type(dtype) + + def create_cse_var(self, name, bounds=None, dtype=None): + 
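+ # Each CSE temporary becomes its own hl.Func; its used_dims are not known yet and
+ # are filled in later (genfunc/newfunc/update_on_args) before codegen_kernel()
+ # substitutes the remaining "[?]" placeholder indices.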
self.body.writeline(f"{name} = hl.Func({name!r})") + return HalideCSEVariable(name, bounds, dtype) + + def finalize_indexing(self, indices: Sequence[sympy.Expr]): + """ + Hook called right before codegen with every index that will be + used in the fused kernel. + + This populates self.halide_vars/index_replacements/reduction_renames which is an alternate indexing + scheme that avoids using divide and modulus. Instead of xindex/yindex/rindex + we base indexing on a larger number of vars whose product combines to those. + + This function populates self.halide_vars, self.index_replacements, and self.reduction_renames + """ + assert not ( + self.index_replacements or self.halide_vars or self.reduction_renames + ) + size_hint = functools.partial(V.graph.sizevars.size_hint, fallback=inf) # type: ignore[arg-type] + indices = dict.fromkeys(map(super().prepare_indexing, indices)) + all_used_symbols = OrderedSet[Any]() + sym_to_node = { + n.symbol(): n + for n in itertools.chain.from_iterable( + [tree.nodes.values() for tree in self.range_trees] + ) + } + + def simplify(expr): + return sympy.simplify( + V.graph.sizevars.remove_precomputed_replacements(expr) + ) + + def visit_modular_indexing(base, divisor, modulus): + if base in sym_to_node: + node = sym_to_node[base] + all_used_symbols.add( + node.root.lookup( + node.divisor * divisor, + V.graph.sizevars.evaluate_min( + modulus, FloorDiv(node.length, divisor) + ), + ).symbol() + ) + + def visit_floor_div(base, divisor): + if base in sym_to_node: + node = sym_to_node[base] + all_used_symbols.add( + node.root.lookup( + node.divisor * divisor, + FloorDiv(node.length, divisor), + ).symbol() + ) + + # first figure out all_used_symbols to do dead symbol elimination + for index in indices: + if index.has(ModularIndexing): + index.replace( + ModularIndexing( + sympy.Wild("base"), + sympy.Wild("divisor"), + sympy.Wild("modulus"), + ), + visit_modular_indexing, + ) + if index.has(FloorDiv): + index.replace( + FloorDiv( + sympy.Wild("base"), + sympy.Wild("divisor"), + ), + visit_floor_div, + ) + all_used_symbols.update(super().prepare_indexing(index).free_symbols) + + self.has_indirect_indexing = any( + symbol_is_type(sym, SymT.INDIRECT) for sym in all_used_symbols + ) + + had_fallback = False + for tree in reversed(self.range_trees): + nodes = [n for n in tree.nodes.values() if n.symbol() in all_used_symbols] + nodes.sort(key=lambda n: size_hint(n.divisor)) + if not nodes: + nodes.append(tree.lookup(1, tree.numel)) + handled_count = 0 + divisor = sympy.S.One + added_sym_size = [] + # decide on a minimal set of symbols and put them in self.halide_vars + while handled_count < len(nodes) and not eq(tree.numel, divisor): + sizes_to_add = [ + simplify(n.length) for n in nodes if eq(n.divisor, divisor) + ] + handled_count += len(sizes_to_add) + assert sizes_to_add, nodes + end = divisor * functools.reduce( + V.graph.sizevars.evaluate_max, sizes_to_add + ) + sizes_to_add.extend( + [ + simplify(n.divisor / divisor) + for n in nodes + if lt(divisor, n.divisor) and lt(n.divisor, end) + ] + ) + while sizes_to_add: + next_size = functools.reduce(sympy.gcd, sizes_to_add) + if eq(next_size, 1): + # sizes share no common factors, e.g [2, 21, 42, 441, 889056] + # TODO(jansel): we should just prevent fusion in cases that hit this + next_size = simplify(tree.numel / divisor) + assert not eq(next_size, 1) + sizes_to_add = [] + handled_count = len(nodes) + had_fallback = True + sym = sympy_index_symbol(f"h{len(self.halide_vars)}") + if tree.is_reduction: + 
self.reduction_renames[sym] = sympy_index_symbol( + f"hr{len(self.halide_vars)}" + ) + self.halide_vars[sym] = next_size + added_sym_size.append((sym, next_size)) + divisor *= next_size + new_sizes = [n.length for n in nodes if eq(n.divisor, divisor)] + handled_count += len(new_sizes) + prior_len = len(sizes_to_add) + sizes_to_add = [ + sympy.simplify(s / next_size) + for s in sizes_to_add + if not eq(s, next_size) + ] + assert len(sizes_to_add) < prior_len or prior_len == 0 + sizes_to_add.extend(new_sizes) + + # create a mapping to the new set of symbols in self.index_replacements + for node in nodes: + try: + idx = 0 + divisor = 1 + while not eq(node.divisor, divisor): + sym, size = added_sym_size[idx] + idx += 1 + divisor *= size + length = 1 + expr = sympy.S.Zero + while not eq(node.length, length): + sym, size = added_sym_size[idx] + idx += 1 + expr += length * sym + length *= size + self.index_replacements[node.symbol()] = expr + except IndexError: + assert had_fallback + full_index = sympy.S.Zero + stride = sympy.S.One + for sym, size in added_sym_size: + full_index += stride * sym + stride *= size + self.index_replacements[node.symbol()] = ( + V.graph.sizevars.simplify_with_ranges( + ModularIndexing(full_index, node.divisor, node.length), + self.halide_vars, # type: ignore[arg-type] + ) + ) + + # codegen the variable definitions + for sym in self.halide_vars: + self.indexing_code.writeline(f"{sym} = hl.Var({sym.name!r})") + if self.reduction_renames: + self.codegen_rdom( + "rdom", + {rv: self.halide_vars[v] for v, rv in self.reduction_renames.items()}, + ) + + def setup_dom_indexing(self): + """RDom based indexing uses explicit iteration ranges for Func updates""" + prefix = "i" if self.inside_reduction else "o" + if prefix in self.dom_renames: + return self.dom_renames[prefix] + + renames = {} + for var in self.halide_vars.keys(): + if not self.inside_reduction and var in self.reduction_renames: + continue + m = re.match(r"^h(\d+)$", var.name) + assert m + renames[var] = sympy_index_symbol(f"h{prefix}{m.group(1)}") + + self.codegen_rdom( + f"{prefix}dom", {rv: self.halide_vars[v] for v, rv in renames.items()} + ) + + self.dom_renames[prefix] = renames + return renames + + def codegen_rdom(self, name, vars): + rsizes = [ + f"hl.Range(0, {self.kexpr(self.rename_indexing(size))})" + for size in vars.values() + ] + self.indexing_code.writeline(f"{name} = hl.RDom([{', '.join(rsizes)}])") + for i, rsym in enumerate(vars.keys()): + self.indexing_code.writeline(f"{rsym} = {name}[{i}]") + + def prepare_indexing( + self, + index: sympy.Expr, + ): + index = super().prepare_indexing(index) + index = sympy_subs(index, self.index_replacements) + return V.graph.sizevars.simplify_with_ranges(index, self.halide_vars) # type: ignore[arg-type] + + def sym_size(self, sym): + """The size of an index symbol""" + if symbol_is_type(sym, SymT.TMP): + return self.lookup_cse_var(sym.name).indirect_indexing_size + return self.halide_vars[sym] + + def indexing_to_dimensions(self, var: str, index: sympy.Expr, is_store: bool): + """Convert address-based indexing into dimensions using self.halide_vars""" + symbols = [] + for sym in sorted(index.free_symbols, key=lambda x: x.name): # type: ignore[attr-defined] + if symbol_is_type(sym, (SymT.HALIDE, SymT.TMP)): + symbols.append(sym) + else: + assert symbol_is_type( + sym, + ( + SymT.UNBACKED_INT, + SymT.SIZE, + SymT.PRECOMPUTED_SIZE, + ), + ), sym + + # group the expression by variables used + offset = sympy.S.Zero + split_expr = dict.fromkeys(symbols, 
sympy.S.Zero) + split_failed: list[tuple[list[sympy.Symbol], sympy.Expr]] = [] + index = sympy.expand(self.rename_indexing(index)) + for part in index.args if isinstance(index, sympy.Add) else [index]: + part_vars = [v for v in part.free_symbols if v in split_expr] + if len(part_vars) == 0: + offset += part + elif len(part_vars) == 1: + split_expr[part_vars[0]] += part + else: + new_split_failed = [] + for i in range(len(split_failed)): + assert split_failed[i] is not None + other_vars, other_part = split_failed[i] + if OrderedSet(other_vars) & OrderedSet(part_vars): + part_vars.extend([v for v in other_vars if v not in part_vars]) + part += other_part + else: + new_split_failed.append((other_vars, other_part)) + split_failed = [*new_split_failed, (part_vars, part)] + + def expr_to_dimension(expr, syms): + expr = sympy.factor(expr) + if len(syms) == 1: + stride_wild = sympy.Wild("wild", exclude=symbols) + m = expr.match(stride_wild * syms[0]) + if m: + return DimensionInfo( + syms[0], self.sym_size(syms[0]), m[stride_wild] + ) + assert not is_store, expr + length = sympy.simplify( + sympy_subs(expr, {sym: self.sym_size(sym) - 1 for sym in syms}) + 1 + ) + stride = sympy.S.One + if isinstance(expr, sympy.Mul): + for term in expr.args: + if isinstance(term, sympy.Integer): + stride *= term + expr = sympy.simplify(expr / term) + length = sympy.simplify(sympy.ceiling(length / term)) + return DimensionInfo(expr, length, stride) + + # try to turn each group into a strided access + dims = [] + for syms, expr in split_failed: + for v in syms: + expr += split_expr.pop(v) + dims.append(expr_to_dimension(expr, syms)) + for sym, expr in split_expr.items(): + dims.append(expr_to_dimension(expr, [sym])) + dims.sort(key=lambda d: V.graph.sizevars.size_hint(d.stride, fallback=inf)) # type: ignore[arg-type] + + if not dims: # scalar load/store + if self.has_indirect_indexing: + # workaround https://github.com/halide/Halide/issues/8338 + dims.append(DimensionInfo(sympy.S.Zero, 1, 1)) + elif not V.graph.sizevars.statically_known_equals(dims[0].stride, 1): + # Halide assumes dimension 0 is stride == 1, so add a dummy dimension + dims.insert( + 0, DimensionInfo(sympy.S.Zero, 1 if is_store else dims[0].stride, 1) + ) + + if dims and not is_store: + if var in self.buffer_offsets and V.graph.sizevars.statically_known_geq( + offset, self.buffer_offsets[var] + ): + # reuse the existing offset to avoid needing an input alias + self.apply_offset_to_dimension(dims, offset - self.buffer_offsets[var]) + offset = self.buffer_offsets[var] + elif V.graph.sizevars.statically_known_gt( + offset, 0 + ): # TODO(jansel): negative offsets + # roll the offset into the dimensions for cleaner indexing + self.apply_offset_to_dimension(dims, offset) + offset = 0 + + orig_var = var + for i in itertools.count(): + if self.install_dims(var, dims, offset, is_store): + return var, dims + assert not is_store + var = f"{orig_var}_view{i}" + if var not in self.buffer_aliases[orig_var]: + self.buffer_aliases[orig_var].append(var) + + def install_dims(self, var, dims, offset, is_store): + """Try to set self.buffer_dimensions[var], return True on success""" + if var not in self.buffer_dimensions: + self.buffer_dimensions[var] = dims + self.buffer_offsets[var] = offset + return True + if self.buffer_offsets[var] != offset or len( + self.buffer_dimensions[var] + ) != len(dims): + return False + if is_store: + return self.buffer_dimensions[var] == dims + for old, new in zip(self.buffer_dimensions[var], dims): + if old.stride != new.stride: + 
return False + if old.size != new.size or old.expr != new.expr: + old.size = V.graph.sizevars.evaluate_max(old.size, new.size) + old.expr = None + return True + + def apply_offset_to_dimension(self, dims, offset): + if offset == 0: + return + for i in reversed(range(len(dims))): + if dims[i].stride == 1 or V.graph.sizevars.statically_known_geq( + offset, dims[i].stride + ): + part = FloorDiv(offset, dims[i].stride) + offset -= part * dims[i].stride + dims[i].expr += part + assert offset == 0 + + def used_dims_from_index(self, index: sympy.Expr): + """Detect which range trees are used to populate HalideCSEVariable.used_dims""" + used_dims = OrderedSet[sympy.Symbol]() + for sym in index.free_symbols: + assert isinstance(sym, sympy.Symbol) + if symbol_is_type(sym, SymT.TMP): + # indirect indexing + cse_var = self.lookup_cse_var(sym.name) + assert ( + isinstance(cse_var, HalideCSEVariable) + and cse_var.used_dims is not None + ) + used_dims.update(cse_var.used_dims) + elif symbol_is_type(sym, SymT.HALIDE): + used_dims.add(sym) + elif symbol_is_type( + sym, (SymT.UNBACKED_INT, SymT.SIZE, SymT.PRECOMPUTED_SIZE, SymT.INDEX) + ): + pass + else: + raise NotImplementedError(f"unhandled symbol {sym}") + return self.sort_used_dims(used_dims) + + def sort_used_dims(self, used_dims): + assert all(isinstance(x, sympy.Expr) for x in used_dims) + ordered = [ + sym + for sym in itertools.chain( + self.halide_vars, self.reduction_renames.values() + ) + if sym in used_dims + ] + assert len(ordered) == len(used_dims) + return ordered + + def make_index_str(self, dims, replacements=None, zero_vars=False): + index_str = ", ".join(d.index_str(replacements, zero_vars) for d in dims) + if len(dims) == 0: + index_str = "()" + elif len(dims) == 1: + # workaround for https://github.com/halide/Halide/issues/8299 + index_str = f"{index_str}," + return index_str + + def load(self, name: str, index: sympy.Expr): + """Codegen a load from an InputBuffer""" + var = self.args.input(name) + index = self.prepare_indexing(index) + var, dims = self.indexing_to_dimensions(var, index, False) + line = f"{var}[{self.make_index_str(dims)}]" + dtype = V.graph.get_dtype(name) + if dtype in (torch.float16, torch.bfloat16): + dtype = torch.float32 + line = f"hl.cast(hl.Float(32), {line})" + + if self._load_mask: + assert ( + isinstance(self._load_mask, HalideCSEVariable) + and self._load_mask.used_dims is not None + ) + used_dims = OrderedSet( + (*self.used_dims_from_index(index), *self._load_mask.used_dims) + ) + result = self.newfunc(self.sort_used_dims(used_dims)) + if result.used_dims: + self.body.writeline(f"{result.name}_mask = hl.RDom([hl.Range(0, 1)])") + self.body.writeline(f"{result.name}_mask.where({self._load_mask})") + other = self.kexpr(self._load_other or 0) # type: ignore[arg-type] + self.body.writeline( + f"{result} = hl.cast({halide_type(dtype)}, {other})" + ) + self.body.writeline( + f"{result} = {line} + hl.cast({halide_type(dtype)}, {result.name}_mask)" + ) + else: + # scalar case + self.body.writeline( + f"{result} = hl.select({self._load_mask}, {line}, hl.cast({halide_type(dtype)}, 0))" + ) + return result + else: + return self.genfunc(line, self.used_dims_from_index(index)) + + def lookup_cse_var(self, name: str): + return self.cse.varname_map[re.sub(r"\[.*", "", name)] + + def store( + self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None + ) -> None: + """Codegen a store to an OutputBuffer""" + assert isinstance(value, HalideCSEVariable) + var = self.args.output(name) + index = 
self.prepare_indexing(index) + var, dims = self.indexing_to_dimensions(var, index, True) + if self.is_indirect_indexing(index) or mode is not None: + replacements = self.setup_dom_indexing() + index_str = self.make_index_str(dims, replacements) + value_str = value.subs_str(replacements) + undef_dims = (", ".join(["hl.Var()"] * len(dims))) or "()" + self.body.writeline( + DeferredLine(name, f"{var}[{undef_dims}] = hl.undef({var}.type())") + ) + else: + index_str = self.make_index_str(dims, zero_vars=True) + value_str = str(value) + + dtype = V.graph.get_dtype(name) + if mode is None: + line = f"{var}[{index_str}] = hl.cast({halide_type(dtype)}, {value_str})" + elif mode == "atomic_add": + line = f"{var}[{index_str}] += hl.cast({halide_type(dtype)}, {value_str})" + else: + raise NotImplementedError(f"store mode={mode}") + self.body.writeline(DeferredLine(name, line)) + + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[CSEVariable, tuple[CSEVariable, ...]], + ) -> Union[CSEVariable, tuple[CSEVariable, ...]]: + """Codegen a reduction operation""" + assert self.inside_reduction + assert not self._load_mask + cache_key = (src_dtype, reduction_type, value) + if cache_key in self.cse.reduction_cache: + return self.cse.reduction_cache[cache_key] + + if isinstance(value, tuple): + assert reduction_type == "welford_combine" + self.cse.reduction_cache[cache_key] = result_tuple = ( + self.welford_combine_impl(*value) + ) + return result_tuple + + assert isinstance(value, HalideCSEVariable) and value.used_dims is not None + reduction_vars = OrderedSet(self.reduction_renames) + result_var = self.newfunc( + [v for v in value.used_dims if v not in reduction_vars] + ) + if reduction_vars - OrderedSet(value.used_dims): + value = self.genfunc( + f"{value}", + self.sort_used_dims(OrderedSet((*value.used_dims, *reduction_vars))), + ) + value_str = value.subs_str(self.reduction_renames) + default = ir.Reduction.default_accumulator(reduction_type, src_dtype) + acc_type = halide_acc_type(dtype) + + if reduction_type in ("argmax", "argmin"): + index = f"{result_var.name}_{reduction_type}" + self.body.writeline(f"{index} = hl.{reduction_type}(rdom, {value_str})") + # turn the N-D argmax index into a 1-D one + parts = [] + stride = 1 + for i, sym in enumerate(self.reduction_renames): + parts.append(f"{index}[{i}]") + if stride != 1: + parts[-1] += f"*{stride}" + stride *= self.halide_vars[sym] + self.body.writeline(f"{result_var} = {' + '.join(parts)}") + elif reduction_type == "welford_reduce": + # TODO(jansel): implement welford_reduce without fallback + result_var = self.welford_reduce_fallback(dtype, value) + else: + combine_fn = get_reduction_combine_fn(reduction_type, acc_type) + with V.set_ops_handler(AddParenHandler(HalideOverrides())): + combine_str = combine_fn(result_var, value_str) # type: ignore[arg-type] + default_str = f"hl.cast({acc_type}, {halide_constant(default)})" + self.body.writeline(f"{result_var} = {default_str}") + self.body.writeline(f"{result_var} = {combine_str}") + + self.cse.reduction_cache[cache_key] = result_var + return result_var + + def welford_combine_impl(self, mean, m2, weight): + assert isinstance(mean, HalideCSEVariable) and mean.used_dims is not None + assert isinstance(m2, HalideCSEVariable) and m2.used_dims is not None + assert isinstance(weight, HalideCSEVariable) and weight.used_dims is not None + used_dims = OrderedSet( + (*mean.used_dims, *m2.used_dims, *weight.used_dims) or self.halide_vars + ) + 
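+ # For reference, the update computed below is the standard parallel Welford
+ # combination of two partitions (mean_1, m2_1, weight_1) and (mean_2, m2_2, weight_2):
+ #   delta  = mean_2 - mean_1
+ #   weight = weight_1 + weight_2
+ #   mean   = mean_1 + delta * weight_2 / weight
+ #   m2     = m2_1 + m2_2 + delta**2 * weight_1 * weight_2 / weight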
used_dims -= OrderedSet(self.reduction_renames) + result_var = self.newfunc(self.sort_used_dims(used_dims)) + default = [f"hl.cast({x.name}.type(), 0)" for x in (mean, m2, weight)] + pfx = result_var.name + self.body.writeline(f"{result_var} = hl.Tuple([{', '.join(default)}])") + self.body.writeline(f"{pfx}_mean_1 = {result_var}[0]") + self.body.writeline(f"{pfx}_m2_1 = {result_var}[1]") + self.body.writeline(f"{pfx}_weight_1 = {result_var}[2]") + self.body.writeline(f"{pfx}_mean_2 = {mean.subs_str(self.reduction_renames)}") + self.body.writeline(f"{pfx}_m2_2 = {m2.subs_str(self.reduction_renames)}") + self.body.writeline( + f"{pfx}_weight_2 = {weight.subs_str(self.reduction_renames)}" + ) + self.body.writeline(f"{pfx}_delta = {pfx}_mean_2 - {pfx}_mean_1") + self.body.writeline(f"{pfx}_new_weight = {pfx}_weight_1 + {pfx}_weight_2") + self.body.writeline( + f"{pfx}_w2_over_w = hl.select({pfx}_new_weight == 0.0, 0.0, {pfx}_weight_2 / {pfx}_new_weight)" + ) + update = [ + f"{pfx}_mean_1 + {pfx}_delta * {pfx}_w2_over_w", + f"{pfx}_m2_1 + {pfx}_m2_2 + {pfx}_delta * {pfx}_delta * {pfx}_weight_1 * {pfx}_w2_over_w", + f"{pfx}_new_weight", + ] + self.body.writeline(f"{result_var} = hl.Tuple([{', '.join(update)}])") + + unpacked = [] + for i in range(3): + unpacked.append(self.newfunc(result_var.used_dims)) + self.body.writeline(f"{unpacked[-1]} = {result_var}[{i}]") + return tuple(unpacked) + + def scan( + self, + dtypes: tuple[torch.dtype, ...], + combine_fn: Callable[ + [tuple[CSEVariable, ...], tuple[CSEVariable, ...]], tuple[CSEVariable, ...] + ], + values_orig: tuple[CSEVariable, ...], + ) -> tuple[CSEVariable, ...]: + assert self.inside_reduction + assert len(dtypes) == len(values_orig) + values: list[HalideCSEVariable] = [] + all_used_dims = OrderedSet[sympy.Symbol]() + + for value in values_orig: + assert isinstance(value, HalideCSEVariable) and value.used_dims is not None + if OrderedSet(value.used_dims) & OrderedSet(self.reduction_renames): + values.append(value) + else: + values.append( + self.genfunc( + f"{value}", [*value.used_dims, [*self.reduction_renames][:1]] + ) + ) + all_used_dims.update(value.used_dims) + result_var = self.newfunc(self.sort_used_dims(all_used_dims)) + assert result_var.used_dims and OrderedSet(result_var.used_dims) & OrderedSet( + self.reduction_renames + ) + initial = [ + f"hl.cast({halide_acc_type(dtype)}, {value})" + for dtype, value in zip(dtypes, values) + ] + + length = self.kexpr(self.rename_indexing(self.range_trees[-1].numel)) + scan_dom = f"{result_var.name}_rdom" + scan = f"{scan_dom}.x" + self.body.writeline(f"{scan_dom} = hl.RDom([hl.Range(1, {length})])") + + assert len(self.reduction_renames) == 1, ( + "multi-dimensional scan not implemented" + ) + (scan_var,) = [*self.reduction_renames] # type: ignore[misc] + scan_renames_cur = {scan_var: sympy_index_symbol(scan)} + scan_renames_pri = {scan_var: sympy_index_symbol(scan) - 1} + + if len(values) == 1: + + def maybe_tuple(x): + return x[0] + + read_left = [result_var.subs_str(scan_renames_pri)] + read_right = [result_var.subs_str(scan_renames_cur)] + else: + + def maybe_tuple(x): + return f"hl.Tuple([{', '.join(x)}])" + + read_left = [ + result_var.subs_str(scan_renames_pri) + f"[{i}]" + for i in range(len(values)) + ] + read_right = [ + result_var.subs_str(scan_renames_cur) + f"[{i}]" + for i in range(len(values)) + ] + + self.body.writeline(f"{result_var} = {maybe_tuple(initial)}") + + # Disable CSE for update fn + with V.set_ops_handler(AddParenHandler(HalideOverrides())): + combine_str = 
combine_fn(read_left, read_right) # type: ignore[arg-type] + self.body.writeline( + f"{result_var.subs_str(scan_renames_cur)} = {maybe_tuple(combine_str)}" + ) + + if len(values) == 1: + return (result_var,) + + unpack_vars = [self.newfunc(self.sort_used_dims(all_used_dims)) for _ in values] + for i, v in enumerate(unpack_vars): + self.body.writeline(f"{v} = {result_var}[{i}]") + return tuple(unpack_vars) + + def genfunc( + self, line, used_dims, *, bounds=ValueRanges.unknown() + ) -> HalideCSEVariable: + var = self.cse.generate(self.body, line, bounds=bounds) + assert isinstance(var, HalideCSEVariable) + var.used_dims = used_dims + return var + + def newfunc(self, used_dims) -> HalideCSEVariable: + var = self.cse.newvar() + assert isinstance(var, HalideCSEVariable) + var.used_dims = used_dims + return var + + def halide_buffer_numel(self, name: str): + """ + We map all tensors to 1D buffers in Halide since Halide has trouble representing some strides that PyTorch + supports. If there are gaps in the underlying layout the numel we pass to Halide includes the gaps while + PyTorch's numel excludes them. + """ + return V.graph.get_buffer(name).get_layout().storage_size() + + def halide_argdefs(self): + """ + Halide requires scalar inputs before outputs, so need to reorder args. + """ + + def arg_order(arg_tuple): + _call_str, arg = arg_tuple + if isinstance(arg, SizeArg): + return 1 # this would normally be at the end, move it to middle + elif "out_ptr" in arg.name: + return 2 + else: + assert "in_ptr" in arg.name + return 0 + + result: list[tuple[Optional[str], KernelArgType]] = [] + _, a, b, _ = self.args.python_argdefs() + for call_str, arg in sorted(zip(a, b), key=arg_order): + result.append((call_str, arg)) + if isinstance(arg, TensorArg): + assert arg.offset == 0 and arg.alias_of is None + result.extend( + ( + None, + TensorArg( + alias, + arg.buffer, + arg.dtype, + arg.offset, + alias_of=arg.name, + ), + ) + for alias in self.buffer_aliases.get(arg.name, ()) + ) + return result + + def halide_kernel_meta(self) -> HalideMeta: + """Compute metadata required by codecache.py""" + argtypes = [] + for _, arg in self.halide_argdefs(): + if isinstance(arg, SizeArg): + shape = None + stride = None + offset = None + dtype = "long" + else: + shape = [ + cexpr(self.rename_indexing(x.size)) + for x in self.buffer_dimensions[arg.name] + ] + stride = [ + cexpr(self.rename_indexing(x.stride)) + for x in self.buffer_dimensions[arg.name] + ] + assert len(shape) == len(stride) + offset = cexpr(self.buffer_offsets[arg.name]) + dtype = f"{DTYPE_TO_CPP[arg.dtype]}*" + argtypes.append( + HalideInputSpec( + dtype, + arg.name, + shape=shape, + stride=stride, + offset=offset, + alias_of=arg.alias_of, + ) + ) + + current_device = V.graph.get_current_device_or_throw() + if current_device.type == "cpu": + target = [config.halide.cpu_target] + scheduler = config.halide.scheduler_cpu + scheduler_flags = { + "parallelism": parallel_num_threads(), + } + cuda_device = None + else: + assert current_device.type == "cuda", "only cpu/cuda supported" + assert current_device.index <= 0, "only default device supported" + target = [config.halide.gpu_target] + scheduler = config.halide.scheduler_cuda + capability = torch.cuda.get_device_properties(current_device) + if "cuda_capability" not in target[0]: + for major, minor in [(8, 6), (8, 0), (7, 5), (7, 0), (6, 1)]: + if capability.major >= major and capability.minor >= minor: + target.append(f"cuda_capability_{major}{minor}") + break + target.append("user_context") + 
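+ # For example (assuming gpu_target is "host-cuda" and the device reports SM 8.6),
+ # the features collected here and below join into a target string like
+ # "host-cuda-cuda_capability_86-user_context-strict_float-no_runtime", with
+ # "no_asserts"/"debug"/"large_buffers" appended depending on config.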
scheduler_flags = { + "parallelism": capability.multi_processor_count, + # TODO(jansel): explore other flags, see: + # grep parser.parse ~/Halide/src/autoschedulers/anderson2021/AutoSchedule.cpp + } + cuda_device = max(0, current_device.index) + + # strict_float is requires for correctness + target.append("strict_float") + + # without this we will initialize cuda once per kernel and hit errors + target.append("no_runtime") + + if not config.halide.asserts: + target.append("no_asserts") + + if config.halide.debug: + target.append("debug") + + if "64" in self.index_dtype: + # TODO(jansel): it is unclear if this does anything, since input sizes are still int32 + target.append("large_buffers") + + return HalideMeta( + argtypes, + target="-".join(target), + scheduler=scheduler, + scheduler_flags=scheduler_flags, # type: ignore[arg-type] + cuda_device=cuda_device, + ) + + def codegen_kernel(self, name=None): + """Called at the end to generate a final kernel string""" + if self.args.inplace_buffers: + raise Unsupported("inplace_buffers") + meta = self.halide_kernel_meta() # ensure needed args are added early + code = IndentedBuffer() + code.splice( + """ + import halide as hl + from torch._inductor.runtime import halide_helpers + from math import inf, nan + + @hl.generator(name="kernel") + class Kernel: + """, + strip=True, + ) + code.do_indent() + for _, arg in self.halide_argdefs(): + if isinstance(arg, SizeArg): + code.writeline(f"{arg.name} = hl.InputScalar({self.index_dtype})") + else: + assert arg.buffer, arg + argcls = "hl.OutputBuffer" if "out" in arg.name else "hl.InputBuffer" + argtype = halide_type(arg.dtype) + ndim = len(self.buffer_dimensions[arg.name]) + code.writeline(f"{arg.name} = {argcls}({argtype}, {ndim})") + code.splice( + """ + def generate(g): + """ + ) + code.do_indent() + for _, arg in self.halide_argdefs(): + code.writeline(f"{arg.name} = g.{arg.name}") + for old, new in self.args.aliases(): + code.writeline(f"{old} = {new}") + code.splice(self.indexing_code) + + def update_index(m): + var = cast(HalideCSEVariable, self.cse.varname_map[m.group(1)]) + assert var.used_dims is not None, var + return str(var) + + for line in self.body._lines: + if isinstance(line, str): + # fill in missing indices + line = HalideCSEVariable.undefined_re.sub(update_index, line) + code.writeline(line) + code.writeline("") + code.writeline("assert g.using_autoscheduler()") + + for _, arg in self.halide_argdefs(): + # fallback=1 below because halide requires buffers to be at least as large as the estimates + # This causes crashes if our estimate is greater than the vector length + # https://github.com/halide/Halide/issues/3103 + if isinstance(arg, SizeArg): + hint = V.graph.sizevars.size_hint(arg.expr, fallback=1) + code.writeline(f"{arg.name}.set_estimate({hint})") + else: + dims = self.buffer_dimensions[arg.name] + range_hints = [] + for i, dim in enumerate(dims): + hint = self._autoscheduler_workarounds( + V.graph.sizevars.size_hint(dim.size, fallback=1), dims + ) + range_hints.append(f"hl.Range(0, {hint})") + if "out" not in arg.name: + code.writeline(f"{arg.name}.dim({i}).set_min(0)") + try: + code.writeline( + f"{arg.name}.dim({i}).set_stride({int(dim.stride)})" + ) + except TypeError: + pass # not integer + try: + code.writeline( + f"{arg.name}.dim({i}).set_extent({int(dim.size)})" + ) + except TypeError: + pass # not integer + code.writeline(f"{arg.name}.set_estimates([{', '.join(range_hints)}])") + + code.do_unindent(2) + code.splice( + """ + if __name__ == "__main__": + hl.main() + 
""".rstrip(), + ) + if meta.scheduler: + code.splice( + f""" + else: + hl.load_plugin({HalideCodeCache.find_libautoschedule(meta.scheduler)!r}) + target = hl.Target({meta.target!r}) + autoscheduler = hl.AutoschedulerParams({meta.scheduler!r}, {meta.scheduler_flags!r}) + with hl.GeneratorContext(target, autoscheduler): + gen = Kernel() + pipeline = gen._build_pipeline() + # gen.compile_to_callable() does not run the autoscheduler + pipeline.apply_autoscheduler(target, autoscheduler) + kernel = pipeline.compile_to_callable([ + gen._get_input_parameter(a.name)._to_argument() + for a in gen._get_arginfos() + if a.dir == hl.ArgInfoDirection.Input + ], target) + """, + strip=True, + ) + else: + code.splice( + f""" + else: + with hl.GeneratorContext(hl.Target({meta.target!r})): + kernel = Kernel().compile_to_callable() + """, + strip=True, + ) + return code.getvalue() + + @staticmethod + def _autoscheduler_workarounds(n, dims): + if ( + len(dims) == 1 + and config.halide.scheduler_cuda == "Anderson2021" + and V.graph.get_current_device_or_throw().type == "cuda" + ): + # workaround https://github.com/halide/Halide/issues/8246 + n = max(2, n) + return n + + def call_kernel(self, name: str, node=None): + """Codegen a call to this kernel""" + wrapper = V.graph.wrapper_code + call_args = [f"{n}" for n, arg in self.halide_argdefs() if arg.alias_of is None] + current_device = V.graph.get_current_device_or_throw() + if current_device.type == "cuda": + stream_name = wrapper.write_get_raw_stream( + current_device.index, V.graph.name + ) + call_args.append(stream_name) + wrapper.generate_kernel_call( + name, + call_args, + device=current_device, + triton=False, + ) + + def generate_assert(self, check): + return False # TODO(jansel): support asserts + + def check_bounds( + self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool + ): + pass # TODO(jansel): support asserts + + +class HalideScheduling(SIMDScheduling): + kernel_type = HalideKernel # type: ignore[arg-type,assignment] + + @classmethod + def get_backend_features(cls, device: torch.device) -> OrderedSet[BackendFeature]: + result = OrderedSet( + [ + BackendFeature.TUPLE_REDUCTION, + BackendFeature.PREFER_STORE_LOOP_ORDER, + BackendFeature.REDUCE_TO_SINGLE_ELEMENT, + ] + ) + if config.halide.scan_kernels: + result.add(BackendFeature.SCAN) + return result + + def define_kernel(self, src_code, node_schedule, kernel): + """Codegen kernel definition to go in output wrapper code""" + wrapper = V.graph.wrapper_code + if src_code in wrapper.src_to_kernel: + kernel_name = wrapper.src_to_kernel[src_code] + else: + kernel_name = f"halide_kernel_{wrapper.next_kernel_suffix()}" + wrapper.src_to_kernel[src_code] = kernel_name + wrapper.add_import_once( + "from torch._inductor.runtime.hints import HalideMeta, HalideInputSpec" + ) + + compile_wrapper = IndentedBuffer() + compile_wrapper.writeline( + f"async_compile.halide({kernel.halide_kernel_meta()!r}, '''" + ) + compile_wrapper.splice(src_code, strip=True) + compile_wrapper.writeline("''')") + + origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper) + metadata_comment = f"{origins}\n{detailed_origins}" + wrapper.define_kernel( + kernel_name, compile_wrapper.getvalue(), metadata_comment + ) + if is_metric_table_enabled("kernel_metadata"): + log_kernel_metadata(kernel_name, "", src_code) + + return kernel_name diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/memory_planning.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/memory_planning.py new file 
mode 100644 index 0000000000000000000000000000000000000000..8efec7eeca9f8b752593d00c3b6c0cdda461559b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/memory_planning.py @@ -0,0 +1,775 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import collections +import dataclasses +import itertools +import pprint +from typing import Any, Optional, Protocol, TYPE_CHECKING + +import sympy + +import torch +from torch.utils._ordered_set import OrderedSet + +from .. import config +from ..utils import _align, align, cache_on_self, CachedMethod, IndentedBuffer +from ..virtualized import V +from .wrapper import ( + AllocateLine, + BufferLike, + FreeIfNotReusedLine, + MemoryPlanningLine, + NullLine, + ReuseLine, +) + + +if TYPE_CHECKING: + from collections.abc import Iterable + + +@dataclasses.dataclass +class LiveRange: + """ + A range where a given tensor is live. Begin and end are both counters + representing points in the program of grouped memory operations. + Begin is inclusive, end is exclusive. + + Invariant: begin <= end + """ + + begin: float # int | +/-inf + end: float # int | +/-inf + + def contains(self, other: LiveRange): + """Is other entirely within self""" + return self.begin <= other.begin and other.end <= self.end + + def join(self, other: LiveRange): + """Combine two ranges using a union operation""" + return LiveRange(min(self.begin, other.begin), max(self.end, other.end)) + + def __len__(self): + return self.end - self.begin + + +class LiveRanges: + """ + A collection of LiveRange regions, allowing for non-contiguous + live regions. + + Invariant: LiveRanges.ranges is in sorted order and non-overlapping + """ + + def __init__(self, ranges: Iterable[LiveRange]): + ranges = [*sorted(ranges, key=lambda x: x.begin)] + self.ranges = ranges[:1] + for r in ranges[1:]: + assert self.ranges[-1].begin <= r.begin + if self.ranges[-1].end >= r.begin: + self.ranges[-1] = LiveRange.join(self.ranges[-1], r) + else: + self.ranges.append(r) + + def overlaps(self, other: LiveRanges): + """Check if any pair of ranges in self and other overlap""" + left = collections.deque(self.ranges) + right = collections.deque(other.ranges) + while left and right: + if left[0].begin > right[0].begin: + left, right = right, left + assert left[0].begin <= right[0].begin + if left[0].end > right[0].begin: + return True + left.popleft() + return False + + @property + def begin(self): + return self.ranges[0].begin + + @property + def end(self): + return self.ranges[-1].end + + def __repr__(self): + return f"{self.__class__.__name__}([{', '.join(map(repr, self.ranges))}])" + + +class AllocationTreeNode: + """ + Abstract base class for nodes in allocation pool. + """ + + def allocate(self, block: Allocation, is_last: bool) -> bool: + """ + Try to assign block to a memory location in this bool. Return True if + an assignment was made. 
+ """ + return False + + def get_live_ranges(self) -> LiveRanges: + """Aggregate LiveRanges for all objects below this in tree""" + raise NotImplementedError + + def get_size_hint(self) -> int: + """Number of bytes used for example inputs""" + raise NotImplementedError + + def get_symbolic_size(self) -> sympy.Expr: + """Number of bytes needed at runtime""" + raise NotImplementedError + + def finalize(self, pool, offset) -> AllocationTreeNode: + """Called after all allocations have been made""" + return self + + def is_empty(self): + return False + + +@dataclasses.dataclass +class Allocation(AllocationTreeNode): + """ + Represents memory allocated to a given node in the allocation pool. + """ + + node: BufferLike + live_range: LiveRange + size_hint: int + symbolic_size: sympy.Expr + allocated: bool = False + pool: Optional[AllocationPool] = None + offset: Optional[sympy.Expr] = None + + @property + def device(self): + return self.node.get_device() + + def get_live_ranges(self): + return LiveRanges([self.live_range]) + + def get_size_hint(self): + return self.size_hint + + def get_symbolic_size(self): + return self.symbolic_size + + def mark_allocated(self): + assert not self.allocated + self.allocated = True + + def finalize(self, pool, offset): + assert self.pool is None and self.offset is None + self.pool = pool + self.offset = offset + return self + + def codegen_alloc_from_pool(self, wrapper): + assert self.pool + node = self.node + shape = tuple(node.get_size()) + stride = tuple(node.get_stride()) + return wrapper.codegen_alloc_from_pool( + self.pool.name, self.offset, node.get_dtype(), shape, stride + ) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f"node={self.node.get_name()}, " + f"live_range={self.live_range}, " + f"size_hint={self.size_hint}, " + f"symbolic_size={self.symbolic_size}, " + f"pool={self.pool.name if self.pool else None}, " + f"offset={self.offset})" + ) + + +@dataclasses.dataclass +class Empty(AllocationTreeNode): + """ + Placeholder to represent empty space in the allocation pool. + Only exists to get the size_hint correct in parent nodes. + """ + + size_hint: int + + def get_live_ranges(self): + return LiveRanges([]) + + def get_size_hint(self): + return self.size_hint + + def get_symbolic_size(self): + return 0 + + def is_empty(self): + return True + + +class MemorySplitProtocol(Protocol): + get_live_ranges: CachedMethod[[], LiveRanges] + get_size_hint: CachedMethod[[], int] + get_symbolic_size: CachedMethod[[], sympy.Expr] + + def _allocate(self, block: Allocation, is_last: bool) -> bool: ... + + +class ClearCacheOnAllocateMixin(MemorySplitProtocol): + """ + Helper to assist in caching get_live_ranges, get_size_hint, and + get_symbolic_size. + """ + + def allocate(self, block: Allocation, is_last: bool): + is_allocated = self._allocate(block, is_last) + if is_allocated: + self.clear_cache() + return is_allocated + + def clear_cache(self): + self.get_live_ranges.clear_cache(self) + self.get_size_hint.clear_cache(self) + self.get_symbolic_size.clear_cache(self) + + +@dataclasses.dataclass +class TemporalSplit(ClearCacheOnAllocateMixin, AllocationTreeNode): + """ + Contains a list of allocations not overlapping in LiveRanges. 
+ + Invariant: no pair (a,b) in self.allocations will have: + a.get_live_ranges().overlaps(b.get_live_ranges()) + """ + + allocations: list[AllocationTreeNode] + + def _allocate(self, block: Allocation, is_last: bool): + slot_size = self.get_size_hint() + block_size = block.get_size_hint() + if not is_last and block_size > slot_size: + return False # doesn't fit + + block_live = block.get_live_ranges() + overlapping = [ + s for s in self.allocations if s.get_live_ranges().overlaps(block_live) + ] + if len(overlapping) > 1: + # TODO(jansel): we could try harder here by merging overlapping in space + return False + elif len(overlapping) == 1: + return overlapping[0].allocate(block, is_last) + else: + block.mark_allocated() + + if len(self.allocations) == 1 and isinstance(self.allocations[-1], Empty): + self.allocations.pop() + + if slot_size == block_size: + # perfect fit + self.allocations.append(block) + elif slot_size > block_size: + self.allocations.append( + SpatialSplit.create(block, slot_size - block_size) + ) + else: # grow this allocation + assert is_last + self.allocations = [ + *( + SpatialSplit.create(a, block_size - slot_size) + for a in self.allocations + ), + block, + ] + return True + + @cache_on_self + def get_live_ranges(self) -> LiveRanges: + return LiveRanges( + itertools.chain.from_iterable( + x.get_live_ranges().ranges for x in self.allocations + ) + ) + + @cache_on_self + def get_size_hint(self) -> int: + if not self.allocations: + return 0 + return max(x.get_size_hint() for x in self.allocations) + + @cache_on_self + def get_symbolic_size(self) -> sympy.Expr: + if not self.allocations: + return 0 # type: ignore[return-value] + return sympy.Max(*[x.get_symbolic_size() for x in self.allocations]) + + def is_empty(self): + return len(self.allocations) == 1 and self.allocations[0].is_empty() + + def finalize(self, pool, offset): + self.allocations = [block.finalize(pool, offset) for block in self.allocations] + self.clear_cache() + if len(self.allocations) == 1: + return self.allocations[0] + return self + + +@dataclasses.dataclass +class SpatialSplit(ClearCacheOnAllocateMixin, AllocationTreeNode): + """ + Contains two allocations, left and right, that do not overlap in space. + Right will be allocated immediately after left in memory. 
+ """ + + left: TemporalSplit + right: TemporalSplit + + @staticmethod + def create(left, extra_space): + assert isinstance(left, AllocationTreeNode) + assert isinstance(extra_space, int) and extra_space >= 1 + return SpatialSplit(TemporalSplit([left]), TemporalSplit([Empty(extra_space)])) + + def _allocate(self, block: Allocation, is_last: bool): + return self.left.allocate(block, False) or self.right.allocate(block, is_last) + + @cache_on_self + def get_live_ranges(self): + return LiveRanges( + itertools.chain( + self.left.get_live_ranges().ranges, self.right.get_live_ranges().ranges + ) + ) + + @cache_on_self + def get_size_hint(self) -> int: + return _align(self.left.get_size_hint()) + self.right.get_size_hint() + + @cache_on_self + def get_symbolic_size(self) -> sympy.Expr: + return align(self.left.get_symbolic_size()) + self.right.get_symbolic_size() + + def finalize(self, pool, offset): + self.left = self.left.finalize(pool, offset) + self.right = self.right.finalize( + pool, offset + align(self.left.get_symbolic_size()) + ) + self.clear_cache() + if self.right.is_empty(): + return self.left + return self + + +@dataclasses.dataclass +class AllocationPool: + """ + Represents a pool of allocations that will be generated by a single + call to torch.empty. + """ + + device: torch.device + root: TemporalSplit + can_expand: bool = True + restrict_live_range: Optional[LiveRange] = None + name: Optional[str] = None + names_to_del: list[str] = dataclasses.field(default_factory=list) + creation_cache: dict[str, str] = dataclasses.field(default_factory=dict) + + def allocate(self, block: Allocation, is_last: bool): + if self.restrict_live_range and not self.restrict_live_range.contains( + block.live_range + ): + return False + + is_last = self.can_expand and is_last + if self.root.allocate(block, is_last): + return True + + if is_last: + return self.allocate_at_end(block) + + return False + + def allocate_at_end(self, block): + block.mark_allocated() + self.root = TemporalSplit([SpatialSplit(self.root, TemporalSplit([block]))]) + return True + + def finalize(self, name): + assert not self.name + self.name = name + self.names_to_del.append(name) + self.root.finalize(self, 0) + + def codegen_create(self, wrapper, code: IndentedBuffer): + assert self.name + nbytes = self.root.get_symbolic_size() + for block in self.root.allocations: + if isinstance(block, Allocation) and nbytes == block.get_symbolic_size(): + # optimization: fuse first allocation and pool creation + node = block.node + code.writeline( + wrapper.make_allocation( + self.name, + device=self.device, + dtype=node.get_dtype(), + shape=tuple(node.get_size()), + stride=tuple(node.get_stride()), + ) + ) + self.creation_cache[block.codegen_alloc_from_pool(wrapper)] = self.name + return + else: + code.writeline( + wrapper.make_allocation( + self.name, + device=self.device, + dtype=torch.uint8, + shape=(nbytes,), + stride=(1,), + ) + ) + + def codegen_destroy(self, wrapper, code: IndentedBuffer): + code.writeline(wrapper.make_free_by_names(self.names_to_del)) + + def __eq__(self, other): + return self is other + + def __hash__(self): + return id(self) + + +@dataclasses.dataclass +class AllocationPools: + """ + Collection of many AllocationPool objects grouped by device. 
+ """ + + device_to_pools: dict[torch.device, list[AllocationPool]] = dataclasses.field( + default_factory=dict + ) + + def get_pools(self, block): + if block.device not in self.device_to_pools: + self.device_to_pools[block.device] = [] + return self.device_to_pools[block.device] + + def allocate(self, block: Allocation): + pools = self.get_pools(block) + + for pool in pools: + if pool.allocate(block, is_last=pool is pools[-1]): + return + + # everything is full, make a new pool + pools.append( + AllocationPool( + block.device, + TemporalSplit([block]), + can_expand=config.memory_pool != "none", + ) + ) + block.mark_allocated() + + def allocate_output(self, block: Allocation): + """Outputs get different pools so memory gets freed properly""" + pools = self.get_pools(block) + if pools and config.memory_pool in ("outputs", "combined"): + pools[-1].allocate_at_end(block) + else: + # create a new pool + block.mark_allocated() + pools.append( + AllocationPool( + block.device, + TemporalSplit([block]), + can_expand=config.memory_pool == "combined", + ) + ) + + def finalize(self): + """Called at the end of allocation process""" + for i, pool in enumerate( + itertools.chain.from_iterable(self.device_to_pools.values()) + ): + pool.finalize(f"pool{i}") + + def pprint(self): + for pool in itertools.chain.from_iterable(self.device_to_pools.values()): + print() + print(pool.name) + print(pool.root.get_live_ranges()) + pprint.pprint(pool.root) + + +class BufferGroup: + """ + Due to inplace reuse an allocated buffer can have many names. + This tracks these collections of buffers sharing underlying memory. + """ + + def __init__(self, node: BufferLike): + self.node = node + self.names = [node.get_name()] + self.is_output = False + self.allocation: Optional[Allocation] = None + self.live_range = LiveRange(float("inf"), -float("inf")) + + def update_usage(self, timestep: int): + """Expand self.live_range to include timestep""" + self.live_range = LiveRange( + min(timestep, self.live_range.begin), + max(timestep, self.live_range.end), + ) + + def sym_nbytes(self): + return self.node.get_layout().storage_size() * self.node.get_dtype().itemsize + + def make_allocation(self): + assert not self.allocation, "multiple allocations" + assert isinstance(self.live_range.begin, int), "live ranges not computed" + nbytes = self.sym_nbytes() + # For now, fallback value will be used if we encounter an unbacked SymInt. The longer-term plan is to have + # size_hint() use better heuristics for unbackeds, at which point the fallback value will be ignored. 
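+ # Example (illustrative): for a contiguous float32 buffer whose storage size is
+ # static, say 128 * 256 elements, nbytes is 128 * 256 * 4 = 131072 and size_hint()
+ # returns exactly that value; if nbytes involves an unbacked SymInt (e.g. 4*u0),
+ # the fallback of 64 bytes below is used instead.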
+ size_hint = V.graph.sizevars.size_hint(nbytes, fallback=64) + self.allocation = Allocation( + self.node, + self.live_range, + size_hint=size_hint, + symbolic_size=nbytes, + ) + + def __repr__(self): + return ( + f"{self.__class__.__name__}({self.names!r}, is_output={self.is_output}, " + f"live_range={self.live_range}" + ) + + +@dataclasses.dataclass +class PoolMemoryPlanningLine(MemoryPlanningLine): + """Abstract base class for {Alloc,Dealloc}FromPoolLine""" + + group: BufferGroup + timestep: Optional[int] = None + + @property + def node(self): + return self.group.node + + +@dataclasses.dataclass +class AllocFromPoolLine(PoolMemoryPlanningLine): + """Similar to AllocationLine, but takes memory from a pool""" + + is_first_pool_usage: bool = False + + def codegen(self, code: IndentedBuffer): + allocation = self.group.allocation + assert allocation and allocation.pool + pool = allocation.pool + name = self.node.get_name() + + if self.is_first_pool_usage: + pool.codegen_create(self.wrapper, code) + + pool.names_to_del.extend(self.group.names) + alloc_from_pool = allocation.codegen_alloc_from_pool(self.wrapper) + if alloc_from_pool in pool.creation_cache: + code.writeline( + self.wrapper.make_tensor_alias( + name, pool.creation_cache[alloc_from_pool], "alloc" + ) + ) + else: + pool.creation_cache[alloc_from_pool] = name + code.writeline( + f"{self.wrapper.declare}{name} = {alloc_from_pool}{self.wrapper.ending}" + ) + + +@dataclasses.dataclass +class DeallocFromPoolLine(PoolMemoryPlanningLine): + """Similar to FreeIfNotReusedLine, but takes memory from a pool""" + + is_last_pool_usage: bool = False + + def codegen(self, code: IndentedBuffer): + if self.is_last_pool_usage: + assert self.group.allocation and self.group.allocation.pool + self.group.allocation.pool.codegen_destroy(self.wrapper, code) + + +@dataclasses.dataclass +class MemoryPlanner: + """ + Coordination object to run memory planning passes during wrapper + codegen. + """ + + wrapper: Any + pools: AllocationPools = dataclasses.field(default_factory=AllocationPools) + buffer_groups: Optional[list[BufferGroup]] = None + + def plan(self, lines: list[Any]) -> list[Any]: + """Call all the memory planning passes in sequence""" + lines = [*lines] + self.drop_removed_buffers(lines) + self.convert_to_pool_lines(lines) + self.compute_live_ranges(lines) + self.allocate_groups() + self.mark_first_last_usage(lines) + return lines + + def drop_removed_buffers(self, lines): + """ + Replace any memory planning lines in V.graph.removed_buffers with NullLine + """ + # drop any removed buffers + for i, line in enumerate(lines): + if isinstance(line, (AllocateLine, FreeIfNotReusedLine, ReuseLine)): + if line.node.get_name() in V.graph.removed_buffers: + lines[i] = NullLine(self.wrapper) + + def compute_buffer_groups(self, lines): + """ + Populates self.buffer_groups with BufferGroup objects that join + allocations with common storage (due to inplace reuse) into a + single object. 
+ """ + name_to_group = {} + for line in lines: + if isinstance(line, AllocateLine): + name = line.node.get_name() + assert name not in name_to_group + name_to_group[name] = BufferGroup(line.node) + elif isinstance(line, ReuseLine): + old_name = line.node.get_name() + new_name = line.reused_as.get_name() + assert new_name not in name_to_group + # TODO(jansel): we should support reusing buffers created via ExternKernelAlloc + if old_name in name_to_group: + name_to_group[old_name].names.append(new_name) + name_to_group[new_name] = name_to_group[old_name] + + outputs = OrderedSet(V.graph.get_output_names()) + unique_groups = [*{id(g): g for g in name_to_group.values()}.values()] + for group in unique_groups: + group.is_output = any(x in outputs for x in group.names) + + assert self.buffer_groups is None + self.buffer_groups = unique_groups + return name_to_group + + def convert_to_pool_lines(self, lines): + """ + Convert AllocateLine/FreeIfNotReusedLine/ReuseLine into their + pool-based counterparts. + """ + name_to_group = self.compute_buffer_groups(lines) + for i, line in enumerate(lines): + if isinstance(line, AllocateLine): + if line.node.get_name() in name_to_group: + lines[i] = AllocFromPoolLine( + self.wrapper, name_to_group[line.node.get_name()] + ) + elif isinstance(line, FreeIfNotReusedLine): + assert not line.is_reused + if line.node.get_name() in name_to_group: + lines[i] = DeallocFromPoolLine( + self.wrapper, name_to_group[line.node.get_name()] + ) + elif isinstance(line, ReuseLine): + if line.node.get_name() in name_to_group: + line.delete_old = False + + def compute_live_ranges(self, lines): + """Populate every BufferGroup.live_ranges field based on first/last usage""" + timestep = 0 + worklist = collections.deque(lines) + while worklist: + if isinstance(worklist[0], MemoryPlanningLine): + timestep += 1 + while worklist and isinstance(worklist[0], MemoryPlanningLine): + line = worklist.popleft() + if isinstance(line, PoolMemoryPlanningLine): + line.group.update_usage(timestep) + line.timestep = timestep + else: + worklist.popleft() + + timestep += 1 + assert self.buffer_groups is not None + for group in self.buffer_groups: + if group.is_output: + group.update_usage(timestep) + + def allocate_groups(self): + """ + Assign every allocation to a specific location in a specific AllocationPool. + """ + assert config.memory_pool in ("none", "intermediates", "outputs", "combined") + assert self.buffer_groups is not None + + for group in self.buffer_groups: + group.make_allocation() + + outputs: list[Allocation] = [] + intermediates: list[Allocation] = [] + for group in self.buffer_groups: + assert group.allocation + if group.is_output and config.memory_pool != "combined": + outputs.append(group.allocation) + else: + intermediates.append(group.allocation) + + for block in sorted( + outputs, + key=lambda x: ( + x.size_hint, + -len(x.live_range), + ), + ): + self.pools.allocate_output(block) + + for block in sorted( + intermediates, + key=lambda x: ( + -x.size_hint, + -len(x.live_range), + ), + ): + self.pools.allocate(block) + + self.pools.finalize() + + def mark_first_last_usage(self, lines): + """ + Populate the AllocFromPoolLine.is_first_pool_usage and + DeallocFromPoolLine.is_last_pool_usage fields so that pools + are created/destroyed. 
+ """ + seen = OrderedSet[AllocationPool]() + for line in lines: + if isinstance(line, AllocFromPoolLine): + assert line.group.allocation + pool = line.group.allocation.pool + assert pool is not None + if pool not in seen: + line.is_first_pool_usage = True + seen.add(pool) + + seen = OrderedSet[AllocationPool]() + for line in reversed(lines): + if isinstance(line, DeallocFromPoolLine): + assert line.group.allocation + pool = line.group.allocation.pool + assert pool is not None + if pool not in seen: + line.is_last_pool_usage = ( + pool.root.get_live_ranges().end <= line.timestep + ) + seen.add(pool) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/mps.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/mps.py new file mode 100644 index 0000000000000000000000000000000000000000..d0cb221dd4abc7cd9cfc2dc97769c90307328cb0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/mps.py @@ -0,0 +1,988 @@ +# This is not a feature-complete compiler backend +# Just an early prototype that shows that one can compile elementwise ops into a Metal shader +from __future__ import annotations + +import functools +import itertools +import logging +import math +from pathlib import Path +from typing import Any, Optional, TYPE_CHECKING + +import sympy +from sympy.printing.precedence import PRECEDENCE + +import torch +from torch.utils._cpp_embed_headers import _embed_headers +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.printers import CppPrinter, ExprPrinter as ExprPrinter_ +from torch.utils._sympy.value_ranges import ValueRanges + +from ..utils import ceildiv, get_bounds_index_expr, get_kernel_metadata +from ..virtualized import ops, OpsWrapper, V +from .common import ( + CSEVariable, + DeferredLine, + DTYPE_TO_COMPUTATION_DTYPE, + IndentedBuffer, + OpOverrides, + PythonPrinter, +) +from .simd import IterationRangesEntry, SIMDKernel, SIMDScheduling + + +if TYPE_CHECKING: + from typing import Union + + from ..ops_handler import ReductionType, StoreMode + from ..scheduler import Scheduler, SchedulerNode + from .common import OpVarT + +log = logging.getLogger(__name__) + +DTYPE_TO_METAL = { + torch.bool: "bool", + torch.int8: "char", + torch.int16: "short", + torch.int32: "int", + torch.int64: "long", + torch.uint8: "uchar", + torch.float: "float", + torch.half: "half", + torch.bfloat16: "bfloat", +} + + +def value_to_metal(val: Union[float, int, bool, str, CSEVariable]) -> str: + if isinstance(val, float): + if val == torch.inf: + return "HUGE_VALF" + elif val == -torch.inf: + return "-HUGE_VALF" + elif val != val: # Only float that not equal to self is nan + return "NAN" + return str(val) + elif isinstance(val, bool): + return "true" if val else "false" + return str(val) + + +class MetalExprPrinter(ExprPrinter_): + """Converts sympy expression to Metal code snippet""" + + def _print_FloorDiv(self, expr: sympy.Expr) -> str: + x, div = expr.args + x = self.doprint(x) + div = self.doprint(div) + if expr.is_integer: + return f"c10::metal::floor_divide({x}, {div})" + return f"metal::floor({x}) / ({div})" + + def _print_ModularIndexing(self, expr: sympy.Expr) -> str: + x, div, mod = expr.args + x = self.doprint(x) + if div != 1: + div = self.doprint(div) + if expr.is_integer: + x = f"({x}) / ({div})" + else: + x = f"metal::floor({x}) / ({div})" + mod = self.doprint(mod) + return f"({x}) % ({mod})" + + def _print_Min(self, expr: sympy.Expr) -> str: + if len(expr.args) != 2: + raise RuntimeError("metal::min only supported for 2 args") 
+ a, b = map(self._print, expr.args) + typecast_a = f"static_cast({a})" + typecast_b = f"static_cast({b})" + return f"metal::min({typecast_a}, {typecast_b})" + + def _print_Max(self, expr: sympy.Expr) -> str: + if len(expr.args) != 2: + raise RuntimeError("metal::max only supported for 2 args") + a, b = map(self._print, expr.args) + typecast_a = f"static_cast({a})" + typecast_b = f"static_cast({b})" + return f"metal::max({typecast_a}, {typecast_b})" + + def _print_Abs(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"metal::abs({self._print(expr.args[0])})" + + def _print_RoundToInt(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"static_cast(metal::rint({self._print(expr.args[0])}))" + + def _print_RoundDecimal(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 2 + number, ndigits = expr.args + if number.is_integer: + # ndigits < 0 should have been filtered by the sympy function + assert ndigits < 0 + raise ValueError( + f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}." + ) + number_str = self.parenthesize(number, PRECEDENCE["Mul"]) + return f"static_cast(metal::rint(1e{ndigits} * {number_str}) * 1e{-ndigits})" + + def _print_IntTrueDiv(self, expr: sympy.Expr) -> str: + lhs, rhs = expr.args + # TODO: This is only accurate up to 2**23 + return f"static_cast({self._print(lhs)}) / static_cast({self._print(rhs)})" + + def _print_PowByNatural(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 2 + x, y = map(self.doprint, expr.args) + return f"metal::pow(static_cast({x}), static_cast({y}))" + + def _print_ToFloat(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + x = self.doprint(expr.args[0]) + return f"static_cast({x})" + + def _print_FloorToInt(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + x = self.doprint(expr.args[0]) + return f"static_cast(metal::floor(static_cast({x})))" + + _print_floor = _print_FloorToInt + + def _print_TruncToInt(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + x = self.doprint(expr.args[0]) + return f"static_cast(metal::trunc({x}))" + + def _print_OpaqueUnaryFn_log2(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + x = self.doprint(expr.args[0]) + return f"metal::log2({x})" + + +class MetalOverrides(OpOverrides): + """Implements Metal-specific overrides for ops. 
Base class emits Python-friendly overrides.""" + + @staticmethod + def to_dtype( + x: CSEVariable, + dtype: torch.dtype, + src_dtype: Optional[torch.dtype] = None, + use_compute_types: bool = True, + ) -> str: + if dtype == torch.double: + log.warning( + "float64 cast requested, probably from tensorify_python_scalars" + ) + return f"static_cast({x})" + return f"static_cast<{DTYPE_TO_METAL[dtype]}>({x})" + + @staticmethod + def to_dtype_bitcast( + x: CSEVariable, dtype: torch.dtype, src_dtype: torch.dtype + ) -> str: + return f"as_type<{DTYPE_TO_METAL[dtype]}>(static_cast<{DTYPE_TO_METAL[src_dtype]}>({x}))" + + @staticmethod + def constant(val: Union[bool, float, int], dtype: torch.dtype) -> str: + return value_to_metal(val) + + @staticmethod + def index_expr(expr: sympy.Expr, dtype: torch.dtype) -> str: + idx_str = V.kernel.index_to_str(V.kernel.prepare_indexing(expr)) + var = V.kernel.cse.generate( + V.kernel.compute, idx_str, bounds=get_bounds_index_expr(expr) + ) + return ops.to_dtype(var, dtype) + + @staticmethod + def masked(mask: CSEVariable, body: sympy.Expr, other: CSEVariable) -> str: + # TODO: Type annotation for other is wrong, it's often float or int + with V.kernel.mask_loads(mask, other) as new_mask: + result = body() + + if result.bounds.is_bool: + other = bool(other) # type: ignore[assignment] + + return ops.where(new_mask, result, other) + + @staticmethod + def where(a: OpVarT, b: OpVarT, c: OpVarT) -> str: + return f"{a} ? {b} : {value_to_metal(c)}" + + @staticmethod + def remainder(a: OpVarT, b: OpVarT) -> str: + return f"c10::metal::remainder({a}, {b})" + + @staticmethod + def maximum(a: CSEVariable, b: CSEVariable) -> str: + typecast_a = f"static_cast({a})" + typecast_b = f"static_cast({b})" + return f"c10::metal::max({typecast_a}, {typecast_b})" + + @staticmethod + def minimum(a: CSEVariable, b: CSEVariable) -> str: + typecast_a = f"static_cast({a})" + typecast_b = f"static_cast({b})" + return f"c10::metal::min({typecast_a}, {typecast_b})" + + @staticmethod + def logical_or(a: CSEVariable, b: CSEVariable) -> str: + return f"{a} || {b}" + + @staticmethod + def logical_and(a: CSEVariable, b: CSEVariable) -> str: + return f"{a} && {b}" + + @staticmethod + def isnan(x: CSEVariable) -> str: + return f"metal::isnan({x})" + + @staticmethod + def isinf(x: CSEVariable) -> str: + return f"metal::isinf({x})" + + @staticmethod + def log(x: CSEVariable) -> str: + return f"metal::log({x})" + + @staticmethod + def exp(x: CSEVariable) -> str: + return f"metal::exp({x})" + + @staticmethod + def abs(x: CSEVariable) -> str: + return f"metal::abs({x})" + + @staticmethod + def signbit(x: CSEVariable) -> str: + return f"metal::signbit({x})" + + @staticmethod + def sin(x: CSEVariable) -> str: + return f"metal::precise::sin({x})" + + @staticmethod + def sinc(x: CSEVariable) -> str: + return f"c10::metal::sinc({x})" + + @staticmethod + def cos(x: CSEVariable) -> str: + return f"metal::precise::cos({x})" + + @staticmethod + def tan(x: CSEVariable) -> str: + return f"metal::tan({x})" + + @staticmethod + def asin(x: CSEVariable) -> str: + return f"metal::asin({x})" + + @staticmethod + def acos(x: CSEVariable) -> str: + return f"metal::acos({x})" + + @staticmethod + def atan(x: CSEVariable) -> str: + return f"metal::atan({x})" + + @staticmethod + def atan2(x: CSEVariable, y: CSEVariable) -> str: + return f"::metal::atan2({x}, {y})" + + @staticmethod + def sqrt(x: CSEVariable) -> str: + return f"metal::sqrt({x})" + + @staticmethod + def neg(x: CSEVariable) -> str: + # TODO: Does it rely on 
undefined behavior? + # If so, add special logic for unsigned types + return f"static_cast(-{x})" + + @staticmethod + def rsqrt(x: CSEVariable) -> str: + return f"metal::rsqrt({x})" + + @staticmethod + def tanh(x: CSEVariable) -> str: + return f"metal::tanh({x})" + + @staticmethod + def atanh(x: CSEVariable) -> str: + return f"metal::atanh({x})" + + @staticmethod + def floordiv(a: CSEVariable, b: CSEVariable) -> str: + # a and b must be of integer type + return f"c10::metal::floor_divide({a}, {b})" + + @staticmethod + def floor(x: CSEVariable) -> str: + return f"metal::floor({x})" + + @staticmethod + def sign(x: CSEVariable) -> str: + return f"metal::sign({x})" + + @staticmethod + def fmod(a: CSEVariable, b: CSEVariable) -> str: + typecast_a = f"static_cast({a})" + typecast_b = f"static_cast({b})" + return f"metal::fmod({typecast_a}, {typecast_b})" + + @staticmethod + def trunc(x: CSEVariable) -> str: + return f"metal::trunc({x})" + + @staticmethod + def truncdiv(a: CSEVariable, b: CSEVariable) -> str: + quot = f"{a} / {b}" + if (a.dtype is not None and a.dtype.is_floating_point) or ( + b.dtype is not None and b.dtype.is_floating_point + ): + return f"metal::trunc({quot})" + return quot + + @staticmethod + def ceil(x: CSEVariable) -> str: + return f"metal::ceil({x})" + + @staticmethod + def rand(seed: CSEVariable, offset: CSEVariable) -> str: + V.kernel.headers.add("random") + return f"c10::metal::rand({seed}, {offset})" + + @staticmethod + def randn(seed: CSEVariable, offset: CSEVariable) -> str: + V.kernel.headers.add("random") + return f"c10::metal::randn({seed}, {offset})" + + @staticmethod + def randint64( + seed: CSEVariable, offset: CSEVariable, low: CSEVariable, high: CSEVariable + ) -> str: + V.kernel.headers.add("random") + return f"c10::metal::randint64({seed}, {offset}, {low}, {high})" + + @staticmethod + def round(x: CSEVariable) -> str: + return f"metal::round({x})" + + @staticmethod + def pow(a: CSEVariable, b: CSEVariable) -> str: + cast_a = f"static_cast({a})" + cast_b = f"static_cast({b})" + return f"metal::pow({cast_a}, {cast_b})" + + def _special_unary(self, a: CSEVariable, name: str) -> str: + V.kernel.headers.add("special_math") + return f"c10::metal::{name}({a})" + + def _special_binary(self, a: CSEVariable, b: CSEVariable, name: str) -> str: + V.kernel.headers.add("special_math") + return f"c10::metal::{name}({a}, {b})" + + @classmethod + def _initialize_special_ops(cls) -> None: + # Unary special ops + for name in [ + "erf", + "erfinv", + "i0", + "i0e", + "i1", + "i1e", + "digamma", + "spherical_bessel_j0", + ]: + setattr(cls, name, functools.partialmethod(cls._special_unary, name=name)) + + cls.lgamma = functools.partialmethod(cls._special_unary, name="log_gamma") # type: ignore[assignment] + + # Unary special ops with forward in method name + for name in [ + "bessel_j0", + "bessel_j1", + "bessel_y0", + "bessel_y1", + "modified_bessel_i0", + "modified_bessel_i1", + "modified_bessel_k0", + "modified_bessel_k1", + "scaled_modified_bessel_k0", + "scaled_modified_bessel_k1", + ]: + setattr( + cls, + name, + functools.partialmethod(cls._special_unary, name=name + "_forward"), + ) + + # Binary special ops + for name in [ + "polygamma", + "zeta", + ]: + setattr(cls, name, functools.partialmethod(cls._special_binary, name=name)) + + # Binary special ops with forward in method name + for name in [ + "chebyshev_polynomial_t", + "chebyshev_polynomial_u", + "chebyshev_polynomial_v", + "chebyshev_polynomial_w", + "hermite_polynomial_h", + "hermite_polynomial_he", + ]: + 
setattr( + cls, + name, + functools.partialmethod(cls._special_binary, name=name + "_forward"), + ) + + +MetalOverrides._initialize_pointwise_overrides("mps") +MetalOverrides._initialize_special_ops() + + +class MetalKernel(SIMDKernel): + """Implement Metal codegen based on the SIMDKernel abstraction""" + + overrides = MetalOverrides # type: ignore[assignment] + suffix = ";" + newvar_prefix = "auto " + max_threadgroup_size = 1024 + simd_group_size = 32 + pexpr = PythonPrinter().doprint + cexpr = CppPrinter().doprint + sexpr = MetalExprPrinter().doprint + kexpr = sexpr + headers: OrderedSet[str] = OrderedSet(["utils"]) + multistage_reduction_entry: list[IterationRangesEntry] = [] + + def __init__( + self, + tiling: dict[str, sympy.Expr], + **kwargs: Any, + ) -> None: + super().__init__(tiling, **kwargs) + self.acc_var_ids = itertools.count() + + def dtype_to_str(self, dtype: torch.dtype) -> str: + return DTYPE_TO_METAL[dtype] + + def load(self, name: str, index: sympy.Expr) -> CSEVariable: + """Codegen a load from an InputBuffer""" + var = self.args.input(name) + index = self.prepare_indexing(index) + dtype = V.graph.get_dtype(name) + line = f"{var}[{self.index_to_str(index)}]" + if dtype in [torch.float16, torch.bfloat16]: + # TODO(NS): Figure out the right balance between optype casts + # op_math_t for half-precision floats should be float32 + # Otherwise it can lead to a correctness issues with eager + line = f"static_cast({line})" + dtype = torch.float32 + return self.cse.generate(self.loads, line, dtype=dtype) + + def store( + self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None + ) -> None: + var = self.args.output(name) + index = self.prepare_indexing(index) + dtype_str = self.dtype_to_str(V.graph.get_dtype(name)) + cast_val = f"static_cast<{dtype_str}>({value})" + if mode is None: + line = f"{var}[{self.index_to_str(index)}] = {cast_val};" + elif mode == "atomic_add": + self.headers.add("atomic") + atomic_type = f"c10::metal::AtomicType<{dtype_str}>" + cast_var = f"reinterpret_cast({var})" + line = f"{atomic_type}::atomic_add({cast_var}, {self.index_to_str(index)}, {cast_val});" + else: + raise RuntimeError(f"Unimplemented store mode {mode}") + if self.inside_reduction: + self.compute.writeline(DeferredLine(name, line)) + else: + self.stores.writeline(DeferredLine(name, line)) + + def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable) -> None: + var = self.args.output(name) + index = self.prepare_indexing(index) + dtype_str = self.dtype_to_str(V.graph.get_dtype(name)) + reduction_dim = next(t for t in self.range_trees if t.is_reduction) + # Only one thread in the reduction group needs to store the results + line = f"{var}[{self.index_to_str(index)}] = static_cast<{dtype_str}>({value});" + line = f"if ({reduction_dim.name} == 0) {line}" + self.stores.writeline(DeferredLine(name, line)) + + def _new_idxvar( + self, + dtype: Union[str | torch.dtype], + elem_count: Optional[int] = None, + default_value: Optional[Any] = None, + is_threadgroup: bool = True, + bounds: ValueRanges[Any] = ValueRanges.unknown(), + ) -> CSEVariable: + if isinstance(dtype, torch.dtype): + dtype = self.dtype_to_str(dtype) + var_name = f"tmp_acc_{next(self.acc_var_ids)}" + var = V.kernel.create_cse_var(var_name, bounds, dtype) + var_def = "threadgroup " if is_threadgroup else "" + var_def += f"{dtype} {var_name}" + if elem_count: + var_def += f"[{elem_count}]" + if default_value is not None: + assert not is_threadgroup, "Thread group var can not have default value" + 
var_def += f" = {default_value}" + self.indexing_code.writeline(var_def + self.suffix) + return var + + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[CSEVariable, tuple[CSEVariable, ...]], + ) -> Union[CSEVariable, tuple[CSEVariable, ...]]: + "Caching wrapper around _reduction_nocache" + cache_key = (src_dtype, reduction_type, value) + # Return cached reduction + if cache_key in self.cse.reduction_cache: + return self.cse.reduction_cache[cache_key] + result = self._reduction_nocache(dtype, src_dtype, reduction_type, value) + self.cse.reduction_cache[cache_key] = result # type: ignore[assignment] + return result + + def _reduction_nocache( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[CSEVariable, tuple[CSEVariable, ...]], + ) -> Union[CSEVariable, tuple[CSEVariable, ...]]: + """Codegen a reduction operation. + Only sum and prod operations are somewhat reasonable optimized""" + assert self.inside_reduction + assert not self._load_mask + + def _unwrap_helper(res3: CSEVariable) -> tuple[CSEVariable, ...]: + # Uwraps vec3 dtype into individual components + return OpsWrapper._unwrap( + [CSEVariable(f"{res3}.{t}", res3.bounds, res3.dtype) for t in "xyz"] + ) + + # Establish reduction buffer size and index expression + reduction_idx = "" + acc_buf_size = 1 + for rd in self.range_trees: + if not rd.is_reduction: + continue + if reduction_idx: + reduction_idx += " + " + reduction_idx += f"{rd.name} * {acc_buf_size}" + acc_buf_size *= rd.numel + acc_buf_size = min(acc_buf_size, self.max_threadgroup_size) + + if reduction_type == "any": + acc = self._new_idxvar(dtype) + self.indexing_code.writeline(f"{acc} = false;") + self.indexing_code.writeline( + "threadgroup_barrier(metal::mem_flags::mem_threadgroup);" + ) + self.compute.splice( + f""" + if ({value}) {{ + {acc} = true; + }} + """ + ) + self.stores.writeline( + "threadgroup_barrier(metal::mem_flags::mem_threadgroup);" + ) + return acc + + self.headers.add("reduction_utils") + + if reduction_type in ["prod", "sum"]: + acc_dtype = DTYPE_TO_COMPUTATION_DTYPE[src_dtype] + acc_buf = self._new_idxvar( + acc_dtype, ceildiv(acc_buf_size, self.simd_group_size) + ) + if not self.multistage_reduction_entry: + val = value + else: + default_val, reduction_op = ( + (0, "+") if reduction_type == "sum" else (1, "*") + ) + val = self._new_idxvar( + acc_dtype, default_value=default_val, is_threadgroup=False + ) + self.compute.splice(f"{val} {reduction_op}= {value};") + return self.cse.generate( + self.stores, + f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {val}, {reduction_idx}, {acc_buf_size})", + dtype=DTYPE_TO_COMPUTATION_DTYPE[dtype], + ) + if reduction_type in ["max", "min", "argmin", "argmax"]: + acc_buf = self._new_idxvar(src_dtype, acc_buf_size) + acc_thread_var = f"{acc_buf}[{reduction_idx}]" + src_metal_type = DTYPE_TO_METAL[src_dtype] + if not self.multistage_reduction_entry: + self.compute.splice( + f"{acc_thread_var} = static_cast<{src_metal_type}>({value});" + ) + return self.cse.generate( + self.stores, + f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})", + dtype=dtype, + ) + lim_fn = "lowest" if reduction_type.endswith("max") else "max" + self.indexing_code.writeline( + f"{acc_thread_var} = ::metal::numeric_limits<{src_metal_type}>::{lim_fn}();" + ) + if reduction_type.startswith("arg"): + idx_var = next( + t for t in self.range_tree_nodes.values() if t.is_reduction + ) + idx_acc_buf = 
self._new_idxvar(torch.long, acc_buf_size) + cmp_op = ">" if reduction_type == "argmax" else "<" + idx_thread_var = f"{idx_acc_buf}[{reduction_idx}]" + self.indexing_code.splice(f"{idx_thread_var} = -1;") + self.compute.splice(f""" + if ({value} {cmp_op} {acc_thread_var}) {{ + {acc_thread_var} = {value}; + {idx_thread_var} = {idx_var.name}; + }} + """) + return self.cse.generate( + self.stores, + f"{idx_acc_buf}[c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})]", + dtype=dtype, + ) + self.compute.writeline( + f"{acc_thread_var} = ::c10::metal::{reduction_type}({acc_thread_var}, {value});" + ) + return self.cse.generate( + self.stores, + f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})", + dtype=dtype, + ) + if reduction_type == "welford_reduce": + if not self.multistage_reduction_entry: + acc_buf = self._new_idxvar(src_dtype, acc_buf_size) + self.compute.splice(f"{acc_buf}[{reduction_idx}] = {value};") + wf_res = self.cse.generate( + self.compute, + f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})", + dtype=torch.float32, + ) + return _unwrap_helper(wf_res) + acc_buf = self._new_idxvar("float3", acc_buf_size) + acc_thread_var = f"{acc_buf}[{reduction_idx}]" + self.indexing_code.splice(f"{acc_thread_var} = 0.0;") + self.compute.writeline( + f"{acc_thread_var} = ::c10::metal::welford_combine({acc_thread_var}, float3({value}, 0.0, 1.0));" + ) + wf_res = self.cse.generate( + self.stores, + f"c10::metal::threadgroup_welford_combine({acc_buf}, {acc_buf_size})", + dtype=torch.float32, + ) + return _unwrap_helper(wf_res) + if reduction_type == "welford_combine": + assert isinstance(value, tuple), "Input to welford combine must be tuple" + acc_buf = self._new_idxvar("float3", acc_buf_size) + acc_thread_var = f"{acc_buf}[{reduction_idx}]" + inp_value = f"float3({value[0]}, {value[1]}, {value[2]})" + self.indexing_code.splice(f"{acc_thread_var} = 0.0;") + if self.multistage_reduction_entry: + self.indexing_code.splice(f"{acc_thread_var} = 0.0;") + self.compute.writeline( + f"{acc_thread_var} = ::c10::metal::welford_combine({acc_thread_var}, {inp_value});" + ) + else: + self.compute.writeline(f"{acc_thread_var} = {inp_value};") + wf_res = self.cse.generate( + self.stores if self.multistage_reduction_entry else self.compute, + f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})", + dtype=torch.float32, + ) + return _unwrap_helper(wf_res) + raise NotImplementedError(reduction_type) + + def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry) -> None: + index_expr = self.rename_indexing(entry.expr) + index_str = self.sexpr(index_expr) # type: ignore[misc] + + if not entry.is_reduction or entry.root.numel <= self.max_threadgroup_size: + self.indexing_code.writeline( + f"{self.index_dtype} {entry.name} = {index_str};" + ) + return + self.multistage_reduction_entry.append(entry) + # When reducing the tensor whose size exceeds max threadgroup size + # loop over extra indices per reduction thread and perform part of the operation + # using values in the shared memory + loop_size = ( + entry.root.numel + self.max_threadgroup_size - 1 + ) // self.max_threadgroup_size + self.body.writeline( + f"for(auto {entry.name}_cnt = 0; {entry.name}_cnt < {loop_size}; ++{entry.name}_cnt) {{" + ) + with self.body.indent(): + self.body.writeline( + f"{self.index_dtype} {entry.name} = {loop_size} * {index_str} + {entry.name}_cnt;" + ) + # Check that reduction is performed only within tensor boundary + if loop_size * 
self.max_threadgroup_size != entry.root.numel: + self.body.writeline(f"if ({entry.name} >= {entry.root.numel}) break;") + + def codegen_body(self) -> None: + """ + Concat output code from index_code, loads, compute, stores, + suffix into self.body. + + For pointwise kernels, this is called just once at the end. + + For reduction kernels, this generates a loop over the reduction + axis. + """ + if self.multistage_reduction_entry: + with self.body.indent(): + self.body.splice(self.loads) + self.body.splice(self.compute) + self.body.writeline("}" * len(self.multistage_reduction_entry)) + # Invalidate variables instantiated inside loop + # But results of reduction alive. Reduction cache values can be + # either CSEVariable or tuple of CSEVariables, in which case all + # variables in the tuple must be preserved + self.cse.invalidate( + OrderedSet( + v + for item in self.cse.reduction_cache.values() + for v in (item if isinstance(item, tuple) else (item,)) + ) + ) + # And loop codegen + while self.multistage_reduction_entry: + self.multistage_reduction_entry.pop().cache_clear() + else: + self.body.splice(self.loads) + self.body.splice(self.compute) + self.body.splice(self.stores) + self.loads.clear() + self.compute.clear() + self.stores.clear() + + def codegen_kernel(self, name: Optional[str] = None) -> str: + """Called at the end to generate a final kernel string""" + self.codegen_body() + code = IndentedBuffer() + + if V.graph.cpp_wrapper: + code.writeline('(R"MTL(') + else: + code.writeline("compile_mps_shader('''") + + idx_vars = self.active_range_trees() + with code.indent(): + if not V.graph.cpp_wrapper: + for header in self.headers: + code.writeline(f"#include ") + else: + headers = [ + f"#include " for header in self.headers + ] + header_contents = _embed_headers( + headers, + [Path(__file__).parent.parent.parent / "include"], + OrderedSet(), # type: ignore[arg-type] + ) + code.writeline(header_contents) + + if self.inside_reduction: + total_reduction_size = math.prod( + t.numel for t in self.range_trees if t.is_reduction + ) + threadgroup_size = min(total_reduction_size, self.max_threadgroup_size) + code.writeline( + f"[[max_total_threads_per_threadgroup({threadgroup_size})]]" + ) + code.writeline("kernel void generated_kernel(") + with code.indent(): + for outer, inner in self.args.output_buffers.items(): + if outer in self.removed_buffers: + continue + dtype_str = self.dtype_to_str(V.graph.get_dtype(outer)) + code.writeline(f"device {dtype_str}* {inner},") + for outer, inner in self.args.input_buffers.items(): + dtype = V.graph.get_dtype(outer) + # MPS does not support float64, but scalar inputs are fine + if dtype == torch.float64: + outer_buf = V.graph.try_get_buffer(outer) + if outer_buf is None or outer_buf.get_size() != []: + raise RuntimeError("float64 is not supported by MPS") + dtype_str = "float" + else: + dtype_str = self.dtype_to_str(dtype) + code.writeline(f"constant {dtype_str}* {inner},") + for outer, inner in self.args.sizevars.items(): + code.writeline(f"constant long& {inner},") + assert len(idx_vars) < 4, "Up to 3 index variables are supported" + thread_pos_dtype = ( + f"uint{len(idx_vars)}" if len(idx_vars) > 1 else "uint" + ) + thread_pos_var_name = ( + idx_vars[0].name if len(idx_vars) == 1 else "thread_pos" + ) + thread_pos_suffix = "," if self.inside_reduction else "" + code.writeline( + f"{thread_pos_dtype} {thread_pos_var_name} [[thread_position_in_grid]]{thread_pos_suffix}" + ) + if self.inside_reduction: + code.writeline( + f"{thread_pos_dtype} group_pos 
[[thread_position_in_threadgroup]]" + ) + code.writeline(") {") + with code.indent(): + if len(idx_vars) > 1: + for idx, var in enumerate(idx_vars): + code.writeline( + f"auto {var.name} = thread_pos.{chr(120 + idx)};" + ) + code.splice(self.indexing_code) + code.splice(self.body) + code.writeline("}") + + if V.graph.cpp_wrapper: + code.writeline(')MTL");') + else: + code.writeline("''')") + + return code.getvalue() + + def call_kernel(self, name: str, node: Any = None) -> None: + """Codegen a call to this kernel""" + wrapper = V.graph.wrapper_code + # Make sure sizevars has been computed + for v in self.args.sizevars.keys(): + wrapper.ensure_size_computed(v) + + _, call_args, _, arg_types = self.args.python_argdefs() + arg_name_to_type = { + str(call_arg): arg_type for call_arg, arg_type in zip(call_args, arg_types) + } + + args = [*self.args.output_buffers.keys(), *self.args.input_buffers.keys()] + args = [arg for arg in args if arg not in self.removed_buffers] + args += [str(v) for v in self.args.sizevars.keys()] + + arg_types = [arg_name_to_type[arg] for arg in args] + expr_printer = self.cexpr if V.graph.cpp_wrapper else self.pexpr + + def format_threads(threads: list[str], kwarg: str) -> str: + if V.graph.cpp_wrapper: + threads = [f"static_cast({t})" for t in threads] + return f"{{{', '.join(threads)}}}" + else: + return f"{kwarg}=[{', '.join(threads)}]" + + # For reduction kernels, limit the maximum size over reduction dimensions to + # a maximum threadgroup size + if len(self.active_range_trees()) > 0: + threads = [ + expr_printer( + sympy.Min(v.numel, self.max_threadgroup_size) # type: ignore[misc] + if v.is_reduction + else v.numel + ) + for v in self.active_range_trees() + ] + + args.append(format_threads(threads, "threads")) + arg_types.append(list) + else: + if V.graph.cpp_wrapper: + raise RuntimeError("We should always have threads?") + + if self.inside_reduction: + threads = [ + expr_printer(sympy.Min(v.numel, self.max_threadgroup_size)) # type: ignore[misc] + if v.is_reduction + else "1" + for v in self.active_range_trees() + ] + args.append(format_threads(threads, "group_size")) + arg_types.append(list) + else: + if V.graph.cpp_wrapper: + # Add a None so that we always have a group_size in the + # arguments. We won't use it if the value is None. + args += [None] # type: ignore[list-item] + arg_types.append(None) + + wrapper.generate_kernel_call( + name, + args, + device=torch.device("cpu"), # TODO: Fix me, MPS does not expose streams now + triton=False, + arg_types=arg_types, + ) + + def check_bounds( + self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool + ) -> None: + if not (lower or upper): + return + # TODO(malfet): support asserts + # See https://github.com/pytorch/pytorch/issues/144634 + expr_str = self.index_to_str(expr) + lower_expr = f"{expr_str} < 0" if lower else "" + # TODO(malfet): Is upper bound inclusive or exclusive? 
+ upper_expr = f"{expr_str} > {self.index_to_str(size)}" if upper else "" + if lower and upper: + line = f"if (({lower_expr}) && ({upper_expr})) return" + else: + line = f"if ({lower_expr}{upper_expr}) return" + self.cse.generate(self.compute, line, assignment=False) + + +class MetalScheduling(SIMDScheduling): + kernel_type = MetalKernel # type: ignore[assignment] + + def __init__(self, scheduler: Optional[Scheduler]) -> None: + super().__init__(scheduler) + wrapper = V.graph.wrapper_code + if wrapper is not None: + if not V.graph.cpp_wrapper: + wrapper.header.splice( + "from torch._inductor.runtime.runtime_utils import compile_mps_shader" + ) + + def define_kernel( + self, src_code: str, node_schedule: list[SchedulerNode], kernel: MetalKernel + ) -> str: + wrapper = V.graph.wrapper_code + if src_code in wrapper.src_to_kernel: + kernel_name = wrapper.src_to_kernel[src_code] + else: + # TODO: Merge multiple kernels into a single library + # Either using MultiKernel concept or overriding SIMDScheduling.codegen_node_scheduling + mps_lib_name = f"mps_lib_{wrapper.next_kernel_suffix()}" + + if V.graph.cpp_wrapper: + src_code = ( + f"at::native::mps::DynamicMetalShaderLibrary {mps_lib_name}" + + src_code + ) + kernel_name = f"{mps_lib_name}_func" + else: + kernel_name = f"{mps_lib_name}.generated_kernel" + + wrapper.src_to_kernel[src_code] = kernel_name + origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper) + metadata_comment = f"{origins}\n{detailed_origins}" + wrapper.define_kernel(mps_lib_name, src_code, metadata_comment, gpu=False) + + return kernel_name diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/mps_device_op_overrides.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/mps_device_op_overrides.py new file mode 100644 index 0000000000000000000000000000000000000000..8b4ddb163ef4f9957e1a64a4ab25ec865e8206b5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/mps_device_op_overrides.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from .common import DeviceOpOverrides, register_device_op_overrides + + +class MPSDeviceOpOverrides(DeviceOpOverrides): + def device_guard(self, device_idx: int) -> str: + assert device_idx == 0 + return "torch._ops.contextlib.nullcontext()" + + def set_device(self, device_idx: int) -> str: + assert device_idx == 0 + return "pass # MPS set device" + + def kernel_driver(self) -> str: + return """ + #include + """ + + def cpp_kernel_type(self) -> str: + return "MTLFunction_t" + + +register_device_op_overrides("mps", MPSDeviceOpOverrides()) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/multi_kernel.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/multi_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..7178cf7cc195d20455c788465ce55341f65420ef --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/multi_kernel.py @@ -0,0 +1,379 @@ +# mypy: allow-untyped-defs +import functools +import logging +import os +import pathlib + +from torch._inductor.metrics import get_metric_table, is_metric_table_enabled +from torch.utils._ordered_set import OrderedSet + +from .. 
import config +from ..codecache import code_hash, CodeCacheFuture, get_path, write_atomic +from ..runtime.benchmarking import benchmarker +from ..utils import cache_on_self, IndentedBuffer +from ..virtualized import V +from .common import TensorArg, WorkspaceArg + + +log = logging.getLogger(__name__) + + +class MultiKernelState: + """ + Maintain state of multi-kernel compilation so we don't define duplicated + multi-kernel for the same set of sub-kernels. + + V.graph.wrapper_code has a reference to MultiKernelState instance. + """ + + def __init__(self): + self.subkernel_to_kernel_name = {} + self.kernel_defs = IndentedBuffer() + + def define_kernel(self, kernels): + """ + Previously we name the multi kernel as "multi_kernel_{kernel_names[0]}". + This has some minor issue. + + E.g. for persistent reduction https://gist.github.com/shunting314/39e7c00ff8bb2055942ed5a3255d61ca , + there are 2 flavors of non-persistent reduction: + https://gist.github.com/shunting314/056d43d35907e87efb883970b35c17d4 + and + https://gist.github.com/shunting314/02ee753b65c513c54e695626afe682bd + + The only different is cache eviction policy. + + We should name the multi-kernel differently in these 2 cases. + """ + kernel_names = tuple(k.kernel_name for k in kernels) + if kernel_names in self.subkernel_to_kernel_name: + return self.subkernel_to_kernel_name[kernel_names] + + # name the multi kernel based on the first kernel + multi_kernel_name = f"multi_kernel_{len(self.subkernel_to_kernel_name)}" + self.subkernel_to_kernel_name[kernel_names] = multi_kernel_name + + if V.graph.cpp_wrapper and not config.triton.autotune_at_compile_time: + # we should not generate any python code for multi-kernel during + # the second pass of cpp-wrapper. + return multi_kernel_name + + buf = self.kernel_defs + buf.writeline("") + buf.writeline( + f"{multi_kernel_name} = async_compile.multi_kernel({multi_kernel_name!r}, [" + ) + with buf.indent(): + for name in kernel_names: + buf.writeline(f"{name},") + buf.writeline("])") + + if config.triton.autotune_at_compile_time: + V.graph.wrapper_code.src_to_kernel["\n".join(kernel_names)] = ( + multi_kernel_name + ) + + return multi_kernel_name + + +class MultiKernel: + """ + This class maintains the compile time state for multi kernels. + + Assume we do codegen for a MultiKernel encapsulating kernel1 and kernel2. + The generated definition for the multi-kernel will looks like: + ``` + multi_kernel_kernel1 = MultiKernelCall( + [kernel1, kernel2], multi_kernel_definition_code + ) + ``` + + Here is a concrete example: https://gist.github.com/shunting314/d9f3fb6bc6cee3dbae005825ca196d39 + """ + + def __init__(self, kernels): + assert len(kernels) >= 2 + + self.kernels = kernels + self.kernel_name = V.graph.wrapper_code.multi_kernel_state.define_kernel( + kernels + ) + + # need this since some code in inductor check if the kernel object has an args + # attribute to decide if it's a non-null kernel. 
+ self.args = object() + + @staticmethod + def _merge_workspace_args(left: list[WorkspaceArg], right: list[WorkspaceArg]): + if left == right: + return left + result = {x.inner_name: x for x in left} + for arg in right: + if arg.inner_name in result: + result[arg.inner_name] = WorkspaceArg.maximum( + result[arg.inner_name], arg + ) + else: + result[arg.inner_name] = arg + return [*result.values()] + + @staticmethod + def merge_workspaces_inplace(kernels): + if len(kernels) < 2: + return + # All kernels must share the same workspace + workspace_args = functools.reduce( + MultiKernel._merge_workspace_args, + [kernel.args.workspace_args for kernel in kernels], + ) + for kernel in kernels: + kernel.args.workspace_args = workspace_args + return workspace_args + + def call_kernel(self, kernel_name): + """ + Collect the union of arguments from all subkernels as the arguments + for the multi-kernel. + """ + assert kernel_name == self.kernel_name + V.graph.wrapper_code.write_triton_header_once() + _, call_args, _, arg_types = self.kernels[0].args.python_argdefs() + for kernel in self.kernels[1:]: + _, other_call_args, _, other_arg_types = kernel.args.python_argdefs() + assert call_args == other_call_args, (call_args, other_call_args) + assert arg_types == other_arg_types + + if V.graph.cpp_wrapper and not config.triton.autotune_at_compile_time: + # for the second pass of cpp-wrapper codegen, we should call + # the fast kernel directly + kernel_name = MultiKernelCall.lookup_choice(self.kernel_name) + + # numels for all subkernels should be the same. Use kernels[0] here + self.kernels[0].add_numel_to_call_args(kernel_name, call_args, arg_types) + + for ws in self.kernels[0].args.workspace_args: + V.graph.wrapper_code.generate_workspace_allocation(ws) + + V.graph.wrapper_code.generate_kernel_call( + kernel_name, + call_args, + arg_types=arg_types, + ) + + for ws in reversed(self.kernels[0].args.workspace_args): + V.graph.wrapper_code.generate_workspace_deallocation(ws) + + def codegen_nan_check(self): + wrapper = V.graph.wrapper_code + seen: OrderedSet[str] = OrderedSet() + for k in self.kernels: + _, call_args, precompile_args, _ = k.args.python_argdefs() + for arg, precompile_arg in zip(call_args, precompile_args): + if arg in seen: + continue + seen.add(arg) + if isinstance(precompile_arg, TensorArg): + line = f"assert not {arg}.isnan().any().item()" + wrapper.writeline(line) + line = f"assert not {arg}.isinf().any().item()" + wrapper.writeline(line) + + @property + def removed_buffers(self): + return OrderedSet.intersection(*[k.removed_buffers for k in self.kernels]) + + @property + def inplaced_to_remove(self): + return OrderedSet.intersection(*[k.inplaced_to_remove for k in self.kernels]) + + @property + @cache_on_self + def inplace_update_buffers(self): + """ + Make sure all kernels have the same inplace update mappings. 
+ """ + for k in self.kernels[1:]: + assert k.inplace_update_buffers == self.kernels[0].inplace_update_buffers + return self.kernels[0].inplace_update_buffers + + def warn_mix_layout(self, kernel_name: str): + pass + + +class MultiKernelCall: + """ + This class is called at run time to actually run the kernel + """ + + def __init__(self, multi_kernel_name, kernels): + assert len(kernels) >= 2 + self._kernels = kernels + self.multi_kernel_name = multi_kernel_name + + self.disable_cache = os.environ.get( + "TORCHINDUCTOR_DISABLE_MULTI_KERNEL_CACHE" + ) == "1" or is_metric_table_enabled("persistent_red_perf") + + self.picked_kernel = None + if config.triton.multi_kernel > 1: + # manually force a subkernel to ease perf testing + picked_by_config = config.triton.multi_kernel - 2 + assert picked_by_config < len(self._kernels) + self.picked_kernel = picked_by_config + elif not self.disable_cache: + self.load_cache() + + self._recorded = False + + def cache_file_path(self): + key = code_hash( + ",".join( + [ + f"{k.fn.cache_key}{k.size_hints!r}{k.triton_meta!r}" + for k in self.kernels + ] + ) + ) + _, _, path = get_path(key, "picked_kernel") + return pathlib.Path(path) + + def load_cache(self): + assert self.picked_kernel is None + path = self.cache_file_path() + if path.exists(): + with path.open() as fd: + self.picked_kernel = int(fd.read()) + assert self.picked_kernel >= 0 and self.picked_kernel < len( + self._kernels + ) + log.debug( + "Load picked kernel %d from cache file %s", self.picked_kernel, path + ) + + def store_cache(self): + assert self.picked_kernel is not None + path = self.cache_file_path() + path.parent.mkdir(parents=True, exist_ok=True) + + write_atomic(path, str(self.picked_kernel)) + log.debug("Store picked kernel %d to cache file %s", self.picked_kernel, path) + + @property + def kernels(self): + """ + Read results from future. + + This should be called after parallel compilation is done. + In case you call this before compilation is done, + it may slow down the parallel compilation. + """ + for i, kernel in enumerate(self._kernels): + if isinstance(kernel, CodeCacheFuture): + self._kernels[i] = kernel.result() + + return self._kernels + + def benchmark_sub_kernels(self, *args, **kwargs): + """ + Benchmark all the sub kernels and return the execution time + (in milliseconds) for each of time. + + Unit test may mock this method to force a specific kernel to + be picked. + """ + + def wrap_fn(kernel): + def inner(): + args_clone, kwargs_clone = kernel.clone_args(*args, **kwargs) + return kernel.run(*args_clone, **kwargs_clone) + + return inner + + return [ + benchmarker.benchmark_gpu(wrap_fn(kernel), rep=40) + for kernel in self.kernels + ] + + # record_choice and lookup_choice are helper functions for cpp-wrapper + # codegen. The first pass use record_choice to keep the choice and + # the second pass do lookup by calling lookup_choice. + # + # An alternative that reused the multi-kernel cache does not work well + # since during codegen of the second pass, it's very hard to know the + # path for the cache file. Also reading the cache file need do some IO + # which can be slower. + @staticmethod + def record_choice(multi_kernel_name: str, picked_kernel_name: str): + """ + Record the multi-kernel choice for cpp-wrapper after autotuning + + We should do nothing if this function is not called during codegen. 
+ """ + from torch._inductor.graph import GraphLowering + + if not isinstance(V.graph, GraphLowering): + return + + if not V.graph.record_multi_kernel_choice: + return + + V.graph.multi_kernel_to_choice[multi_kernel_name] = picked_kernel_name + + @staticmethod + def lookup_choice(multi_kernel_name: str) -> str: + # this should always been done during cpp-wrapper codegen + assert ( + V.graph.record_multi_kernel_choice + and multi_kernel_name in V.graph.multi_kernel_to_choice + ) + # there should be no miss + return V.graph.multi_kernel_to_choice[multi_kernel_name] + + def run(self, *args, **kwargs): + if self.picked_kernel is None: + timings = self.benchmark_sub_kernels(*args, **kwargs) + self.picked_kernel = timings.index(min(timings)) + k0 = self.kernels[0] + log.debug( + "pick %dth sub-kernel in %s. Size hints %s. Reduction hint %s. Timings %s", + self.picked_kernel, + [k.inductor_meta.get("kernel_name") for k in self.kernels], + k0.size_hints, + k0.inductor_meta.get("reduction_hint"), + timings, + ) + get_metric_table("persistent_red_perf").add_row( + functools.partial(self._metrics_table_row, timings) + ) + if not self.disable_cache: + self.store_cache() + + if not self._recorded: + self._recorded = True + picked_kernel_name = self.kernels[self.picked_kernel].inductor_meta.get( + "kernel_name" + ) + assert picked_kernel_name is not None + self.record_choice(self.multi_kernel_name, picked_kernel_name) + self.run = self.kernels[self.picked_kernel].run # type: ignore[method-assign] + self.run(*args, **kwargs) + + def _metrics_table_row(self, timings): + def get_kernel_path(k): + return k.fn.fn.__code__.co_filename + + k0 = self.kernels[0] + row = { + "size_hints": k0.size_hints, + "reduction_hint": k0.inductor_meta.get("reduction_hint"), + } + max_kernels = 4 + assert len(timings) <= max_kernels + for i in range(max_kernels): + if i < len(self.kernels): + row[f"kernel{i}_path"] = get_kernel_path(self.kernels[i]) + row[f"kernel{i}_latency"] = timings[i] + else: + row[f"kernel{i}_path"] = "" + row[f"kernel{i}_latency"] = "" + return row diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__init__.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/ck_conv_template.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/ck_conv_template.py new file mode 100644 index 0000000000000000000000000000000000000000..c8b9d6382f588ad6847d8e73b2f24c8b9eadbafb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/ck_conv_template.py @@ -0,0 +1,608 @@ +# mypy: allow-untyped-defs +import copy +import logging +import random + +from torch._inductor.virtualized import V + + +try: + import ck4inductor # type: ignore[import] +except ImportError: + ck4inductor = None + +if ck4inductor is not None: + from ck4inductor.grouped_conv_fwd.gen_instances import ( # type: ignore[import] + gen_conv_ops_library, + ) + from ck4inductor.grouped_conv_fwd.op import ( # type: ignore[import] # noqa: TCH002 + CKGroupedConvFwdOp, + ) +else: + + def gen_conv_ops_library(): + return [] + + +from torch._inductor import config +from torch._inductor.codegen.rocm.ck_template import CKTemplate +from torch._inductor.codegen.rocm.rocm_kernel import ROCmTemplateKernel +from torch._inductor.utils import IndentedBuffer + + +log = logging.getLogger(__name__) + + 
+def torch_layout_to_ck_layouts(torch_layout): + # logically, torch tensors are always NCHW, + # and channels-last memory layout is visible in the strides + if V.graph.sizevars.statically_known_equals(torch_layout.stride[-1], 1): + # when input or output is NCHW + # NB: torch.conv2d result is always NCHW + return ["NGCHW", "GKCYX", "NGKHW"] + elif V.graph.sizevars.statically_known_equals(torch_layout.stride[-3], 1): + # when input or output or weight is channels-last + return ["NHWGC", "GKYXC", "NHWGK"] + else: + return None + + +def torch_layout_to_ck_input_layout(torch_layout): + if V.graph.sizevars.statically_known_equals(torch_layout.stride[-1], 1): + return "NGCHW" + elif V.graph.sizevars.statically_known_equals(torch_layout.stride[-3], 1): + return "NHWGC" + else: + return None + + +def torch_layout_to_ck_weight_layout(torch_layout): + if V.graph.sizevars.statically_known_equals(torch_layout.stride[-1], 1): + return "GKCYX" + elif V.graph.sizevars.statically_known_equals(torch_layout.stride[-3], 1): + return "GKYXC" + else: + return None + + +def torch_layout_to_ck_output_layout(torch_layout): + if V.graph.sizevars.statically_known_equals(torch_layout.stride[-1], 1): + return "NGKHW" + elif V.graph.sizevars.statically_known_equals(torch_layout.stride[-3], 1): + return "NHWGK" + else: + return None + + +class CKGroupedConvFwdTemplate(CKTemplate): + conv_template = r""" + {{headers}} + {{globals}} + {{instance_definition}} + extern "C" { + PT_EXPORT {{kernel_definition}} { + auto conv = {{instance_type}} {}; + auto invoker = conv.MakeInvoker(); + + using ck::index_t; + + constexpr index_t NumDTensor = {{n_d_tensors}}; + constexpr index_t NDimSpatial = {{n_dim_spatial}}; + const std::vector FilterSize = { FilterSize_0, FilterSize_1 }; + const std::vector InputSize = { InputSize_0, InputSize_1 }; + const std::vector ConvolutionStrides = { ConvolutionStrides_0, ConvolutionStrides_1 }; + const std::vector Dilations = { Dilations_0, Dilations_1 }; + const std::vector LeftPads = { LeftPads_0, LeftPads_1 }; + const std::vector RightPads = { RightPads_0, RightPads_1 }; + + + auto conv_param = ck::utils::conv::ConvParam { + NDimSpatial, + GroupCount, + NBatch, + NOutChannels, + NInChannels, + FilterSize, + InputSize, + ConvolutionStrides, + Dilations, + LeftPads, + RightPads, + }; + + using InLayout = ck::tensor_layout::convolution::{{input_layout}}; + using WeiLayout = ck::tensor_layout::convolution::{{weight_layout}}; + using OutLayout = ck::tensor_layout::convolution::{{output_layout}}; + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + const void* p_a = input; + const void* p_b = weight; + const std::array p_ds; + void* p_e = output; + std::array a_g_n_c_wis_lengths; + std::array a_g_n_c_wis_strides; + std::array b_g_k_c_xs_lengths; + std::array b_g_k_c_xs_strides; + std::array, NumDTensor> ds_g_n_k_wos_lengths; + std::array, NumDTensor> ds_g_n_k_wos_strides; + std::array e_g_n_k_wos_lengths; + std::array e_g_n_k_wos_strides; + std::array conv_filter_strides; + std::array conv_filter_dilations; + std::array input_left_pads; + std::array input_right_pads; + const auto a_element_op = PassThrough {}; + const auto b_element_op = PassThrough {}; + const auto cde_element_op = PassThrough {}; + + 
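+        // The host tensor descriptors above expose lengths/strides as std::vector,
+        // while MakeArgument expects fixed-size std::array parameters; the copy
+        // lambda below transfers each range element-wise before building the argument.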
auto copy = [](auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + auto argument = conv.MakeArgument( + p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op + ); + if (!conv.IsSupportedArgument(argument)) { + // we do our best to statically avoid this case in `filter_op` + std::cerr << "invalid argument for conv instance " << conv.GetTypeString() << std::endl; + argument.Print(); + return -23; + } + if (workspace_size) { + *workspace_size = conv.GetWorkSpaceSize(&argument); + return 0; + } + + if (p_a == nullptr) { + std::cerr << "p_a is nullptr" << std::endl; + return -1; + } + if (p_b == nullptr) { + std::cerr << "p_b is nullptr" << std::endl; + return -1; + } + if (p_e == nullptr) { + std::cerr << "p_e is nullptr" << std::endl; + return -1; + } + + // when debugging, do time kernel to serialize launches + auto stream_config = StreamConfig{stream, /* time kernel */ false, /* log level */ 0}; + + if (workspace != nullptr) { + conv.SetWorkSpacePointer(&argument, workspace, stream_config); + } + + // run the kernel + float elapsed_time = invoker.Run(argument, stream_config); + return 0; + } // kernel definition + } // extern C + + #ifdef GENERATE_CK_STANDALONE_RUNNER + int main(int argc, char** argv) { + (void) argc; + (void) argv; + return 0; + } + #endif // GENERATE_CK_STANDALONE_RUNNER +""" + + def globals(self) -> IndentedBuffer: + res = super().globals() + res.splice( + """ + // CK conv globals + + using NWC = ck::tensor_layout::convolution::NWC; + using NHWC = ck::tensor_layout::convolution::NHWC; + using NDHWC = ck::tensor_layout::convolution::NDHWC; + + using KXC = ck::tensor_layout::convolution::KXC; + using KYXC = ck::tensor_layout::convolution::KYXC; + using KZYXC = ck::tensor_layout::convolution::KZYXC; + + using NWK = ck::tensor_layout::convolution::NWK; + using NHWK = ck::tensor_layout::convolution::NHWK; + using NDHWK = ck::tensor_layout::convolution::NDHWK; + + using GNWC = ck::tensor_layout::convolution::GNWC; + using GNHWC = ck::tensor_layout::convolution::GNHWC; + using GNDHWC = ck::tensor_layout::convolution::GNDHWC; + + using GKXC = ck::tensor_layout::convolution::GKXC; + using GKYXC = ck::tensor_layout::convolution::GKYXC; + using GKZYXC = ck::tensor_layout::convolution::GKZYXC; + + using GKCX = ck::tensor_layout::convolution::GKCX; + using GKCYX = ck::tensor_layout::convolution::GKCYX; + using GKCZYX = ck::tensor_layout::convolution::GKCZYX; + + using GNWK = ck::tensor_layout::convolution::GNWK; + using GNHWK = ck::tensor_layout::convolution::GNHWK; + using GNDHWK = ck::tensor_layout::convolution::GNDHWK; + + using NGKW = ck::tensor_layout::convolution::NGKW; + using NGKHW = 
ck::tensor_layout::convolution::NGKHW; + using NGKDHW = ck::tensor_layout::convolution::NGKDHW; + + using NWGC = ck::tensor_layout::convolution::NWGC; + using NHWGC = ck::tensor_layout::convolution::NHWGC; + using NDHWGC = ck::tensor_layout::convolution::NDHWGC; + + using KXGC = ck::tensor_layout::convolution::KXGC; + using KYXGC = ck::tensor_layout::convolution::KYXGC; + using KZYXGC = ck::tensor_layout::convolution::KZYXGC; + + using NWGK = ck::tensor_layout::convolution::NWGK; + using NHWGK = ck::tensor_layout::convolution::NHWGK; + using NDHWGK = ck::tensor_layout::convolution::NDHWGK; + + using NGCW = ck::tensor_layout::convolution::NGCW; + using NGCHW = ck::tensor_layout::convolution::NGCHW; + using NGCDHW = ck::tensor_layout::convolution::NGCDHW; + + using G_K = ck::tensor_layout::convolution::G_K; + + using BlockGemmPipelineScheduler = ck::BlockGemmPipelineScheduler; + using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization; + using BlockGemmPipelineVersion = ck::BlockGemmPipelineVersion; + + using ConvolutionForwardSpecialization = ck::tensor_operation::device::ConvolutionForwardSpecialization; + + namespace ck { + namespace utils { + namespace conv { + + ConvParam::ConvParam(ck::index_t n_dim, + ck::index_t group_count, + ck::index_t n_batch, + ck::index_t n_out_channels, + ck::index_t n_in_channels, + const std::vector& filters_len, + const std::vector& input_len, + const std::vector& strides, + const std::vector& dilations, + const std::vector& left_pads, + const std::vector& right_pads) + : num_dim_spatial_(static_cast(n_dim)), + G_(static_cast(group_count)), + N_(static_cast(n_batch)), + K_(static_cast(n_out_channels)), + C_(static_cast(n_in_channels)), + filter_spatial_lengths_(num_dim_spatial_), + input_spatial_lengths_(num_dim_spatial_), + output_spatial_lengths_(num_dim_spatial_), + conv_filter_strides_(num_dim_spatial_), + conv_filter_dilations_(num_dim_spatial_), + input_left_pads_(num_dim_spatial_), + input_right_pads_(num_dim_spatial_) + { + if(static_cast(filter_spatial_lengths_.size()) != num_dim_spatial_ || + static_cast(input_spatial_lengths_.size()) != num_dim_spatial_ || + static_cast(conv_filter_strides_.size()) != num_dim_spatial_ || + static_cast(conv_filter_dilations_.size()) != num_dim_spatial_ || + static_cast(input_left_pads_.size()) != num_dim_spatial_ || + static_cast(input_right_pads_.size()) != num_dim_spatial_) + { + throw( + std::runtime_error("ConvParam::ConvParam: " + "parameter size is different from number of declared dimensions!")); + } + + for(ck::index_t i = 0; i < num_dim_spatial_; ++i) + { + filter_spatial_lengths_[i] = static_cast(filters_len[i]); + input_spatial_lengths_[i] = static_cast(input_len[i]); + conv_filter_strides_[i] = static_cast(strides[i]); + conv_filter_dilations_[i] = static_cast(dilations[i]); + input_left_pads_[i] = static_cast(left_pads[i]); + input_right_pads_[i] = static_cast(right_pads[i]); + + // XEff = (X - 1) * conv_dilation_w + 1; + // Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + const ck::long_index_t x_eff = + (filter_spatial_lengths_[i] - 1) * conv_filter_dilations_[i] + 1; + + output_spatial_lengths_[i] = + (input_spatial_lengths_[i] + input_left_pads_[i] + input_right_pads_[i] - x_eff) / + conv_filter_strides_[i] + + 1; + } + } + + } // namespace conv + } // namespace utils + } // namespace ck + + const std::vector& HostTensorDescriptor::GetLengths() const { return mLens; } + const std::vector& HostTensorDescriptor::GetStrides() const { return mStrides; } + 
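+            // CalculateStrides() below assumes a packed tensor and fills mStrides
+            // right-to-left, e.g. lengths {2, 3, 4, 5} yield strides {60, 20, 5, 1}.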
std::size_t HostTensorDescriptor::GetNumOfDimension() const { return mLens.size(); } + void HostTensorDescriptor::CalculateStrides() { + mStrides.clear(); + mStrides.resize(mLens.size(), 0); + if(mStrides.empty()) + return; + + mStrides.back() = 1; + std::partial_sum( + mLens.rbegin(), mLens.rend() - 1, mStrides.rbegin() + 1, std::multiplies()); + } + """ + ) + return res + + def header(self) -> IndentedBuffer: + res = super().header() + res.splice( + """ + // CK conv headers + + #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp" + #include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" + #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + + #include "ck/library/utility/convolution_parameter.hpp" + #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + """ + ) + return res + + @staticmethod + def add_ck_conv_choices( + choices, + layout, + input_nodes, + *, + stride, + padding, + dilation, + groups, + n_spatial_dimensions, + ): + template = CKGroupedConvFwdTemplate( + input_nodes, + layout, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + n_spatial_dimensions=n_spatial_dimensions, + ) + ops = template.gen_ops() + for op in ops: + template.maybe_append_choice( + choices, + op=op, + ) + + def __init__( + self, + input_nodes, + layout, + *, + stride, + padding, + dilation, + groups, + n_spatial_dimensions, + ): + super().__init__( + "ck_conv_template", + input_nodes, + layout, + ) + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + self.n_spatial_dimensions = n_spatial_dimensions + + def filter_op(self, op: "CKGroupedConvFwdOp"): # type: ignore[name-defined] + metas = [ + T.get_layout() + for T in [*self.input_nodes, self.output_node] + if T is not None + ] + X_meta = metas[0] + W_meta = metas[1] + Y_meta = metas[-1] + # disable the instance if dtypes don't match + if op.a_element_dtype != self._TORCH_DTYPE_TO_CK[X_meta.dtype]: + return None + if op.b_element_dtype != self._TORCH_DTYPE_TO_CK[W_meta.dtype]: + return None + if op.e_element_dtype != self._TORCH_DTYPE_TO_CK[Y_meta.dtype]: + return None + # disable the instance if layouts don't match + if op.a_layout != torch_layout_to_ck_input_layout(X_meta): + return None + if op.b_layout != torch_layout_to_ck_weight_layout(W_meta): + return None + if op.e_layout != torch_layout_to_ck_output_layout(Y_meta): + return None + # disable the instance if number of spatial dimensions doesn't match + if op.n_dim_spatial != self.n_spatial_dimensions: + return None + # disable 1x1 and odd-channels conv specializations for now + if "Default" not in op.conv_forward_specialization: + return None + return op + + def gen_ops(self): + unfiltered_instances = gen_conv_ops_library() + + filtered_instances = list( + filter(lambda op: self.filter_op(op), unfiltered_instances) + ) + # NB: when using a fixed list order, most likely we will pick the subset of instances + # which are very similar to each other. Randomizing the choice seems to solve this. 
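+        # Seeding right before sampling keeps the chosen subset deterministic for a
+        # given filtered instance list, so repeated codegen runs profile the same ops.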
+ random.seed(-11) + chosen_instances = ( + random.sample( + filtered_instances, + min(len(filtered_instances), config.rocm.ck_max_profiling_configs), + ) + if config.rocm.ck_max_profiling_configs + else filtered_instances + ) + log.debug( + "generated %d ck instances after filter: %s", + len(chosen_instances), + chosen_instances, + ) + return chosen_instances + + def emit_ck_instance(self, op: "CKGroupedConvFwdOp") -> tuple[str, str]: # type: ignore[name-defined] + # The Jinja template for generating a C++ type alias *definition* for a Universal GEMM instance + template_definition = r""" + // Gemm operator {{operation_name}} + using Operation_{{operation_name}} = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< + {{template_params}}>; + +""" + # The Jinja template for generating a C++ type alias *usage* for a Universal GEMM instance + template_type = r""" + Operation_{{operation_name}} +""" + template_params = [] + for field_name, field_value in op.dict_items(): + if isinstance(field_value, tuple): + tuple_elements = ", ".join(map(str, iter(field_value))) + if "ds" in field_name: # element type and layout for bias + arg = f"/* {field_name} */ Tuple<{tuple_elements}>" + else: # tile shape + arg = f"/* {field_name} */ S<{tuple_elements}>" + template_params.append(arg) + else: + if field_value is not None: + template_params.append(f"/* {field_name} */ {field_value}") + return self._template_from_string(template_definition).render( + operation_name=op.name(), + template_params=(",\n" + 12 * " ").join(template_params), + ), self._template_from_string(template_type).render(operation_name=op.name()) + + def render( # type: ignore[override] + self, + kernel: ROCmTemplateKernel, + op: "CKGroupedConvFwdOp", # type: ignore[name-defined] + **kwargs, + ) -> str: + template_buffer_node = kwargs.get("template_buffer_node", None) + if template_buffer_node is not None: + self.output_node = template_buffer_node + X, W = self.input_nodes[0], self.input_nodes[1] + Y = self.output_node + Bias = self.input_nodes[2] if 3 == len(self.input_nodes) else None + + op = copy.deepcopy(op) + + instance_definition, instance_type = self.emit_ck_instance(op) + + size_arg_strs = [ + "GroupCount", + "NBatch", + "NOutChannels", + "NInChannels", + "FilterSize_0", + "FilterSize_1", + "InputSize_0", + "InputSize_1", + "ConvolutionStrides_0", + "ConvolutionStrides_1", + "Dilations_0", + "Dilations_1", + "LeftPads_0", + "LeftPads_1", + "RightPads_0", + "RightPads_1", + ] + + return self._template_from_string(self.conv_template).render( + headers=self.header().getvalue(), + globals=self.globals().getvalue(), + instance_definition=instance_definition, + instance_type=instance_type, + kernel_definition=kernel.def_kernel( + inputs=[X, W, Bias] if Bias is not None else [X, W], + outputs=[Y], + names_str="input, weight, bias, output" + if Bias is not None + else "input, weight, output", + size_args=[f"int32_t {arg}" for arg in size_arg_strs], + ), + n_d_tensors=1 if Bias is not None else 0, + n_dim_spatial=self.n_spatial_dimensions, + input_layout=op.a_layout, + weight_layout=op.b_layout, + output_layout=op.e_layout, + ) + + def size_args(self): + x, w = self.input_nodes[0], self.input_nodes[1] + y = self.output_node + + group_count = self.groups + n_batch = x.shape[0] # type: ignore[index] + n_out_channels = y.shape[1] # type: ignore[index] + n_in_channels = x.shape[1] # type: ignore[index] + + filter_size_0, filter_size_1 = w.shape[2:4] # type: ignore[index] + input_size_0, input_size_1 = x.shape[2:4] 
# type: ignore[index] + convolution_strides_0, convolution_strides_1 = self.stride + dilations_0, dilations_1 = self.dilation + left_pads_0, left_pads_1 = self.padding + right_pads_0, right_pads_1 = self.padding + + return ( + group_count, + n_batch, + n_out_channels, + n_in_channels, + filter_size_0, + filter_size_1, + input_size_0, + input_size_1, + convolution_strides_0, + convolution_strides_1, + dilations_0, + dilations_1, + left_pads_0, + left_pads_1, + right_pads_0, + right_pads_1, + ) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/ck_template.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/ck_template.py new file mode 100644 index 0000000000000000000000000000000000000000..a4f137aa60c0951277f5ae59eddf60524474ca15 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/ck_template.py @@ -0,0 +1,108 @@ +from typing import Any +from typing_extensions import override + +import torch +from torch._inductor.codegen.rocm.rocm_template import ROCmTemplate +from torch._inductor.ir import IRNode +from torch._inductor.utils import IndentedBuffer + +from .rocm_template import ArgInfo + + +class CKTemplate(ROCmTemplate): + """ + Base class for generating CK templates, has common, i.e. non-gemm-specific, code generation logic + """ + + _TORCH_DTYPE_TO_CK = { + torch.float32: "F32", + torch.float64: "F64", + torch.float16: "F16", + torch.bfloat16: "BF16", + torch.int32: "I32", + torch.int8: "I8", + torch.float8_e4m3fnuz: "F8", + torch.float8_e5m2fnuz: "BF8", + } + + def header(self) -> IndentedBuffer: + res = super().header() + res.splice( + """ + // CK headers + + #ifdef DEBUG_LOG + #define DEBUG_LOG_TMP DEBUG_LOG + #undef DEBUG_LOG + #else + #define DEBUG_LOG_TMP 0 + #endif + #include "ck/ck.hpp" + #undef DEBUG_LOG + #define DEBUG_LOG DEBUG_LOG_TMP + + #include "ck/utility/data_type.hpp" + #include "ck/library/utility/check_err.hpp" + #include "ck/library/utility/device_memory.hpp" + #include "ck/library/utility/fill.hpp" + #include "ck/library/utility/host_tensor.hpp" + #include "ck/library/utility/host_tensor_generator.hpp" + #include "ck/library/utility/literals.hpp" + """ + ) + return res + + def globals(self) -> IndentedBuffer: + res = super().globals() + res.splice( + """ + // CK globals + + template + using S = ck::Sequence; + + template + using Tuple = ck::Tuple; + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using Bilinear = ck::tensor_operation::element_wise::Bilinear; + using Scale = ck::tensor_operation::element_wise::Scale; + using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd; + using MultiplyMultiply = ck::tensor_operation::element_wise::MultiplyMultiply; + + // see "composable_kernel/include/ck/utility/data_type.hpp" + using F8 = ck::f8_t; + using BF8 = ck::bf8_t; + using F16 = ck::half_t; + using F32 = float; + // using F64 = double; + using BF16 = ck::bhalf_t; + // using I32 = int32_t; + // using I8 = int8_t; + // using I4 = ck::int4_t; + + #if DEBUG_LOG + static constexpr auto kDEBUG_LOG = 1; + #else + static constexpr auto kDEBUG_LOG = 0; + #endif + """ + ) + return res + + def torch_type_to_ck(self, node: IRNode, ptr: str) -> str: + if node is None: + return ptr + else: + return f"({self._TORCH_DTYPE_TO_CK.get(node.get_dtype())}*)({ptr})" + + @override + def get_runtime_arg_info(self) -> list[ArgInfo]: + return [ArgInfo("kBatch", "int32_t")] + + @override + def get_runtime_arg_values(self, **kwargs: Any) -> list[Any]: + """ + Helper method to retrieve runtime 
args from generate kwargs + """ + return [kwargs[arg.name] for arg in self.get_runtime_arg_info()] diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/ck_tile_template.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/ck_tile_template.py new file mode 100644 index 0000000000000000000000000000000000000000..65ac33444bdb485a8dc2cd143b6396de3fb10f34 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/ck_tile_template.py @@ -0,0 +1,56 @@ +import torch +from torch._inductor.codegen.rocm.rocm_template import ROCmTemplate +from torch._inductor.ir import IRNode +from torch._inductor.utils import IndentedBuffer + + +class CKTileTemplate(ROCmTemplate): + """ + Base class for generating CK templates, has common, i.e. non-gemm-specific, code generation logic + """ + + _TORCH_DTYPE_TO_CK = { + torch.float32: "F32", + torch.float64: "F64", + torch.float16: "F16", + torch.bfloat16: "BF16", + torch.int32: "I32", + torch.int8: "I8", + torch.float8_e4m3fnuz: "F8", + torch.float8_e5m2fnuz: "BF8", + } + + ck_dtype_to_size = { + "FP16": 2, + "BF16": 2, + } + + def header(self) -> IndentedBuffer: + res = super().header() + res.splice( + """ + // CK headers + #include "ck_tile/core.hpp" + + """ + ) + return res + + def globals(self) -> IndentedBuffer: + res = super().globals() + res.splice( + """ + using F8 = ck_tile::fp8_t; + using BF8 = ck_tile::bf8_t; + using F16 = ck_tile::half_t; + using F32 = float; + using BF16 = ck_tile::bfloat16_t; + """ + ) + return res + + def torch_type_to_ck(self, node: IRNode, ptr: str) -> str: + if node is None: + return ptr + else: + return f"({self._TORCH_DTYPE_TO_CK.get(node.get_dtype())}*)({ptr})" diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/ck_tile_universal_gemm_template.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/ck_tile_universal_gemm_template.py new file mode 100644 index 0000000000000000000000000000000000000000..5862534ce6cc19a8f31ef5836b2bc502224e9b11 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/ck_tile_universal_gemm_template.py @@ -0,0 +1,967 @@ +# mypy: allow-untyped-defs, disable-error-code="attr-defined, valid-type" +import functools +import logging +import random +from dataclasses import asdict, dataclass +from typing import Any + +import torch +from torch._inductor import config +from torch._inductor.codegen.rocm.ck_tile_template import CKTileTemplate +from torch._inductor.codegen.rocm.rocm_kernel import ROCmTemplateKernel +from torch._inductor.codegen.rocm.rocm_template import ArgInfo +from torch._inductor.ir import Buffer, Layout +from torch.utils._ordered_set import OrderedSet + +from ...utils import IndentedBuffer + + +log = logging.getLogger(__name__) + + +def is_static_int(number): + import sympy + + return isinstance(number, (int, sympy.Integer)) + + +def torch_layout_to_ck_layout(torch_layout): + if torch_layout.stride[-1] == 1: + return "Row" + elif torch_layout.stride[-2] == 1: + return "Col" + else: + return None + + +@dataclass +class CKTileGemmOperation: + layout_a: str + layout_b: str + layout_c: str + + datatype_a: str + datatype_b: str + datatype_c: str + + tile_m: int + tile_n: int + tile_k: int + + warp_m: int + warp_n: int + warp_k: int + + warp_tile_m: int + warp_tile_n: int + warp_tile_k: int + + m_is_padded: str + n_is_padded: str + k_is_padded: str + + pipeline: str + scheduler: str + epilogue: str + + def layout_repr(self): + return 
f"{self.layout_a[0]}{self.layout_b[0]}{self.layout_c[0]}" + + def dtype_repr(self): + return f"{self.datatype_a}{self.datatype_b}{self.datatype_c}" + + def tile_sizes(self): + return "_".join( + [ + f"{self.tile_m}{self.tile_n}{self.tile_k}", + f"{self.warp_m}{self.warp_n}{self.warp_k}", + f"{self.warp_tile_m}{self.warp_tile_n}{self.warp_tile_k}", + ] + ) + + def name(self): + return "ck_tile_gemm_universal_" + "_".join( + [ + f"{self.layout_repr()}", + f"{self.dtype_repr()}", + f"{self.tile_sizes()}", + f"{self.pipeline}", + f"{self.scheduler}", + f"{self.epilogue}", + ] + ) + + def dict_items(self): + return asdict(self).items() + + +@functools.cache +def ops(): + """ + Generate the supported instance dataclasses + """ + import itertools + + compute_v3_instances = [ + CKTileGemmOperation( + layout_a=layout_a, + layout_b=layout_b, + layout_c=layout_c, + datatype_a=datatype_a, + datatype_b=datatype_b, + datatype_c=datatype_c, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + warp_m=warp_m, + warp_n=warp_n, + warp_k=warp_k, + warp_tile_m=warp_tile_m, + warp_tile_n=warp_tile_n, + warp_tile_k=warp_tile_k, + m_is_padded=m_is_padded, + n_is_padded=n_is_padded, + k_is_padded=k_is_padded, + pipeline="CompV3", + scheduler="Intrawave", + epilogue=epilogue, + ) + for (layout_a, layout_b, layout_c) in [ + ("Row", "Row", "Row"), + ("Row", "Col", "Row"), + ] + for (datatype_a, datatype_b, datatype_c) in [("FP16",) * 3, ("BF16",) * 3] + for (tile_m, tile_n, tile_k) in [(256, 256, 32), (256, 256, 64)] + for (warp_m, warp_n, warp_k) in [(2, 2, 1)] + for (warp_tile_m, warp_tile_n, warp_tile_k) in [(32, 32, 16)] + for m_is_padded in ["true", "false"] + for n_is_padded in ["true", "false"] + for k_is_padded in ["true", "false"] + for epilogue in ["Default", "CShuffle"] + ] + + compute_v4_instances = [ + CKTileGemmOperation( + layout_a=layout_a, + layout_b=layout_b, + layout_c=layout_c, + datatype_a=datatype_a, + datatype_b=datatype_b, + datatype_c=datatype_c, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + warp_m=warp_m, + warp_n=warp_n, + warp_k=warp_k, + warp_tile_m=warp_tile_m, + warp_tile_n=warp_tile_n, + warp_tile_k=warp_tile_k, + m_is_padded=m_is_padded, + n_is_padded=n_is_padded, + k_is_padded=k_is_padded, + pipeline="CompV4", + scheduler="Intrawave", + epilogue=epilogue, + ) + for (layout_a, layout_b, layout_c) in [ + ("Row", "Row", "Row"), + ("Row", "Col", "Row"), + ] + for (datatype_a, datatype_b, datatype_c) in [("FP16",) * 3, ("BF16",) * 3] + for (tile_m, tile_n, tile_k) in [ + (256, 256, 32) + ] # half the tile size since it has double buffering + for (warp_m, warp_n, warp_k) in [(2, 2, 1)] + for (warp_tile_m, warp_tile_n, warp_tile_k) in [(32, 32, 16)] + for m_is_padded in ["true", "false"] + for n_is_padded in ["true", "false"] + for k_is_padded in ["true", "false"] + for epilogue in ["Default", "CShuffle"] + ] + + mem_instances = [ + CKTileGemmOperation( + layout_a=layout_a, + layout_b=layout_b, + layout_c=layout_c, + datatype_a=datatype_a, + datatype_b=datatype_b, + datatype_c=datatype_c, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + warp_m=warp_m, + warp_n=warp_n, + warp_k=warp_k, + warp_tile_m=warp_tile_m, + warp_tile_n=warp_tile_n, + warp_tile_k=warp_tile_k, + m_is_padded=m_is_padded, + n_is_padded=n_is_padded, + k_is_padded=k_is_padded, + pipeline="Mem", + scheduler=scheduler, + epilogue=epilogue, + ) + for (layout_a, layout_b, layout_c) in [ + ("Row", "Row", "Row"), + ("Row", "Col", "Row"), + ] + for (datatype_a, datatype_b, datatype_c) in [("FP16",) * 3, ("BF16",) * 3] + 
for (tile_m, tile_n, tile_k) in [(256, 256, 32), (256, 256, 64)] + for (warp_m, warp_n, warp_k) in [(2, 2, 1)] + for (warp_tile_m, warp_tile_n, warp_tile_k) in [(32, 32, 16)] + for m_is_padded in ["true", "false"] + for n_is_padded in ["true", "false"] + for k_is_padded in ["true", "false"] + for scheduler in ["Intrawave", "Interwave"] + for epilogue in ["Default", "CShuffle"] + ] + + return list( + itertools.chain(compute_v3_instances, compute_v4_instances, mem_instances) + ) + + +class CKTileGemmTemplate(CKTileTemplate): + """ + This class is used for rendering CK-Tile Universal GEMM kernels + """ + + gemm_template = r"""{{version_comment}} + {{headers}} + {{globals}} + {{instance_definition}} + extern "C" { + PT_EXPORT {{kernel_definition}} { + + using {{instance_namespace}}::BaseGemmPipeline; + using {{instance_namespace}}::TilePartitioner; + + constexpr auto TileK = {{instance_namespace}}::TileK; + constexpr auto kPrefetchStages = BaseGemmPipeline::PrefetchStages; + + auto kargs = ck_tile::GemmKernelArgs { + X, + W, + Y, + M, + N, + K, + LDA, + LDB, + LDC, + kBatch + }; + + if (workspace_size) { + *workspace_size = 0; + return 0; + } + + // run the kernel + const auto dispatch = [&](const auto has_hot_loop_, const auto tail_number_) constexpr { + using Kernel = {{instance_namespace}}::Kernel; + + if (!Kernel::IsSupportedArgument(kargs)) { + // we do our best to statically avoid this case in `filter_op` + throw std::runtime_error("invalid argument"); + } + auto stream_config = ck_tile::stream_config{stream}; + auto grid_size = Kernel::GridSize(M, N, kBatch); + constexpr auto block_size = Kernel::BlockSize(); + constexpr auto lds_bytes = 0; + constexpr auto kBlockPerCU = 1; + auto gemm = ck_tile::make_kernel(Kernel{}, grid_size, block_size, lds_bytes, kargs); + float elapsed_time = ck_tile::launch_kernel(stream_config, gemm); + }; + + const ck_tile::index_t k_grain = kBatch * TileK; + const ck_tile::index_t K_split = (K + k_grain - 1) / k_grain * TileK; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + {{rendered_dispatch}} + + return 0; + } // kernel definition + } // extern C + """ + + def __init__( + self, + input_nodes: list[Buffer], + layout: Layout, + ) -> None: + super().__init__( + "ck_tile_gemm_template", + input_nodes=input_nodes, + layout=layout, + ) + + def header(self) -> IndentedBuffer: + res = super().header() + res.splice( + """ + // CK GEMM header(s) + + #include "ck_tile/ops/gemm.hpp" + #include "ck_tile/ops/epilogue.hpp" + """ + ) + return res + + def globals(self) -> IndentedBuffer: + res = super().globals() + res.splice( + """ + // CK GEMM globals + + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + + template + void dispatch_memory_pipeline_hot_loop(const ck_tile::TailNumber tail_num, Dispatcher dispatch) + { + if(tail_num == ck_tile::TailNumber::One) + { + dispatch(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Full) + { + dispatch(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + + if constexpr(PrefetchStages > 2) + { + if(tail_num == ck_tile::TailNumber::Two) + { + dispatch(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(PrefetchStages > 3) + { + if(tail_num == ck_tile::TailNumber::Three) + { + 
dispatch(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(PrefetchStages > 4) + { + if(tail_num == ck_tile::TailNumber::Four) + { + dispatch(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(PrefetchStages > 5) + { + if(tail_num == ck_tile::TailNumber::Five) + { + dispatch(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(PrefetchStages > 6) + { + if(tail_num == ck_tile::TailNumber::Six) + { + dispatch(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(PrefetchStages > 7) + { + if(tail_num == ck_tile::TailNumber::Seven) + { + dispatch(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + } + """ + ) + return res + + def check_dtypes(self, op: "CKTileGemmOperation"): + X_dtype, W_dtype, out_dtype = [ + T.get_layout().dtype for T in [*self.input_nodes, self.output_node] + ] + if op.datatype_a != self._TORCH_DTYPE_TO_CK[X_dtype]: + return False + if op.datatype_b != self._TORCH_DTYPE_TO_CK[W_dtype]: + return False + if op.datatype_c != self._TORCH_DTYPE_TO_CK[out_dtype]: + return False + return True + + def check_layouts(self, op: "CKTileGemmOperation"): + X_layout, W_layout, out_layout = [ + torch_layout_to_ck_layout(T.get_layout()) + for T in [*self.input_nodes, self.output_node] + ] + if op.layout_a != X_layout: + return False + if op.layout_b != W_layout: + return False + if op.layout_c != out_layout: + return False + return True + + def get_gemm_problem_size(self): + X_size, W_size = [T.get_layout().size for T in [*self.input_nodes]] + + M, K = X_size + _, N = W_size + + return M, N, K + + def check_block_tiles(self, op: "CKTileGemmOperation"): + """ + The contiguous dimension of a tensor must be divisible by the block tile size + This helper function enforces it for the inputs and the output. + """ + M, N, K = self.get_gemm_problem_size() + + def check(dim_size, tile_size, is_padded): + if ( + is_static_int(dim_size) + and dim_size % tile_size != 0 + and is_padded == "false" + ): + return False + return True + + if op.layout_a == "Row": + # handle in kBatch check + return True + elif op.layout_a == "Col": + if not check(M, op.tile_m, op.m_is_padded): + return False + else: + raise AssertionError(f"Invalid layout {op.layout_a=}") + + if op.layout_b == "Row": + if not check(N, op.tile_n, op.n_is_padded): + return False + elif op.layout_b == "Col": + # handle in kBatch check + return True + else: + raise AssertionError(f"Invalid {op.layout_b=}") + + if op.layout_c == "Row": + if not check(N, op.tile_n, op.n_is_padded): + return False + elif op.layout_c == "Col": + if not check(M, op.tile_m, op.m_is_padded): + return False + else: + raise AssertionError(f"Invalid layout {op.layout_c=}") + + return True + + def check_alignments(self, op: "CKTileGemmOperation"): + """ + The contiguous dimension of a tensor must be divisible by the vector load size. 
+ """ + M, N, K = self.get_gemm_problem_size() + + def max_alignment(contiguous_elements_per_tile, elements_per_thread, ck_dtype): + for vector_load_bytes in (16, 8, 4, 2, 1): + alignment = vector_load_bytes // self.ck_dtype_to_size[ck_dtype] + if ( + alignment > 0 + and contiguous_elements_per_tile % alignment == 0 + and elements_per_thread % alignment == 0 + ): + return alignment + + threads_per_block = ( + op.warp_m * op.warp_n * op.warp_k * self.gfx9_threads_per_warp + ) + a_elements_per_thread = op.tile_m * op.tile_k / threads_per_block + b_elements_per_thread = op.tile_n * op.tile_k / threads_per_block + + if op.layout_a == "Row": + # K is contiguous tensor dimension + a_max_vector_size = max_alignment( + op.tile_k, a_elements_per_thread, op.datatype_a + ) + if is_static_int(K) and K % a_max_vector_size != 0: + return False + elif op.layout_a == "Col": + # M is contiguous tensor dimension + a_max_vector_size = max_alignment( + op.tile_m, a_elements_per_thread, op.datatype_a + ) + if is_static_int(M) and M % a_max_vector_size != 0: + return False + else: + raise AssertionError(f"Invalid layout {op.layout_a=}") + + if op.layout_b == "Row": + # N is contiguous tensor dimension + b_max_vector_size = max_alignment( + op.tile_n, b_elements_per_thread, op.datatype_b + ) + if is_static_int(N) and N % b_max_vector_size != 0: + return False + elif op.layout_b == "Col": + # K is contiguous tensor dimension + b_max_vector_size = max_alignment( + op.tile_k, b_elements_per_thread, op.datatype_b + ) + if is_static_int(K) and K % b_max_vector_size != 0: + return False + else: + raise AssertionError(f"Invalid layout {op.layout_b=}") + + # the `default` epilogue writes C to memory by 1 tensor element + # (divisibility check not necessary) + # the `cshuffle` epilogue writes C to memory by 16 bytes + # (so the contiguous C dimension size must be divisible by the number of tensor elements in 16 bytes) + if op.epilogue == "CShuffle": + if ( + op.layout_c == "Row" + and is_static_int(N) + and N % (16 / self.ck_dtype_to_size[op.datatype_c]) != 0 + ): + return False + + return True + + def check_warp_tiles(self, op: "CKTileGemmOperation"): + if op.tile_m % (op.warp_m * op.warp_tile_m) != 0: + return False + if op.tile_n % (op.warp_n * op.warp_tile_n) != 0: + return False + if op.tile_k % (op.warp_k * op.warp_tile_k) != 0: + return False + return True + + def check_block_tile_size(self, op: "CKTileGemmOperation"): + # assuming LDS size is 64KB + if op.pipeline == "CompV4": + max_block_tile_size = 2**15 + else: + max_block_tile_size = 2**16 + + block_tile_size = ( + self.ck_dtype_to_size[op.datatype_a] * op.tile_m * op.tile_k + + self.ck_dtype_to_size[op.datatype_b] * op.tile_n * op.tile_k + ) + if block_tile_size > max_block_tile_size: + return False + return True + + def filter_op(self, op: "CKTileGemmOperation"): + """ + Determines whether a given op definition is suitable for the current + input / output of the operation that this template implements. + + Filter is based on inputs' dtype, layout and statically inferred size. + + Returns None if the op is not suitable, otherwise returns the op to be used. 
+ """ + if not self.check_dtypes(op): + return None + if not self.check_layouts(op): + return None + if not self.check_block_tiles(op): + return None + if not self.check_alignments(op): + return None + + return op + + def emit_ck_instance(self, op: "CKTileGemmOperation"): + """ + This method is used to generate code which defines the type alias for the generated kernel class + """ + template_definition = r""" + // Gemm operator {{operation_name}} + + namespace {{operation_name}} { + // block tile + constexpr int32_t TileM = {{tile_m}}; + constexpr int32_t TileN = {{tile_n}}; + constexpr int32_t TileK = {{tile_k}}; + // warps per block + constexpr int32_t WarpM = {{warp_m}}; + constexpr int32_t WarpN = {{warp_n}}; + constexpr int32_t WarpK = {{warp_k}}; + // xdl tile + constexpr int32_t WarpTileM = {{warp_tile_m}}; + constexpr int32_t WarpTileN = {{warp_tile_n}}; + constexpr int32_t WarpTileK = {{warp_tile_k}}; + + constexpr bool kPadM = {{m_is_padded}}; + constexpr bool kPadN = {{n_is_padded}}; + constexpr bool kPadK = {{k_is_padded}}; + + using ALayout = {{layout_a}}; + using BLayout = {{layout_b}}; + using CLayout = {{layout_c}}; + + using ADataType = {{datatype_a}}; + using BDataType = {{datatype_b}}; + using CDataType = {{datatype_c}}; + using AccDataType = F32; + + constexpr bool permuteA = false; + constexpr bool permuteB = false; + constexpr bool DoubleSmemBuffer = {{has_double_smem_buffer}}; + constexpr bool TransposeC = false; + + constexpr int kBlockPerCu = 1; + constexpr ck_tile::index_t TilePartitionerGroupNum = 8; + constexpr ck_tile::index_t TilePartitionerM01 = 4; + + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence, + permuteA, + permuteB>; + + using TilePartitioner = + ck_tile::GemmSpatiallyLocalTilePartitioner; + + using Traits = + ck_tile::TileGemmTraits; + + using GemmUniversalTraits = + ck_tile::TileGemmUniversalTraits; + + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + {{rendered_scheduler}} + + template + using UniversalGemmProblem = + ck_tile::UniversalGemmPipelineProblem; + + {{rendered_pipeline}} + + {{rendered_epilogue}} + + template + using Kernel = ck_tile::GemmKernel, GemmEpilogue>; + } + +""" + + def render_epilogue(epilogue_type): + if epilogue_type == "Default": + return r""" + using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem; + using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue; + """ + elif epilogue_type == "CShuffle": + return r""" + constexpr auto kMemoryOperation = ck_tile::memory_operation_enum::set; + using EpilogueProblem = ck_tile::CShuffleEpilogueProblem; + + using GemmEpilogue = ck_tile::CShuffleEpilogue; + """ + else: + raise AssertionError("Epilogue must be set") + + def render_pipeline(pipeline_type): + return rf""" + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCr{pipeline_type}; + + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCr{pipeline_type}>; + """ + + def render_scheduler(scheduler_type): + return rf""" + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::{scheduler_type}; + """ + + rendered_definition = self._template_from_string(template_definition).render( + operation_name=op.name(), + **asdict(op), + rendered_scheduler=render_scheduler(op.scheduler), + rendered_pipeline=render_pipeline(op.pipeline), + rendered_epilogue=render_epilogue(op.epilogue), + has_double_smem_buffer=("true" if op.pipeline == "CompV4" else "false"), + ) + return rendered_definition + + def render( # type: ignore[override] + self, kernel: ROCmTemplateKernel, op: 
"CKTileGemmOperation", **kwargs + ) -> str: + """ + The primary entry point for the code rendering process used in this template. + """ + epilogue_nodes = kwargs.get("epilogue_nodes", None) + assert epilogue_nodes is None or 0 == len(epilogue_nodes) + template_buffer_node = kwargs.get("template_buffer_node", None) + if template_buffer_node is not None: + self.output_node = template_buffer_node + assert 2 == len(self.input_nodes) + X, W = self.input_nodes + Y = self.output_node + + instance_definition = self.emit_ck_instance(op) + + version_comment = rf"""/** +* Generated code for CK inductor backend +* See {type(self).__module__}.{type(self).__qualname__} +* +* Template instance {op} +* +* {torch.__version__=} +* torch.version.git_version={getattr(torch.version, "git_version", "None")} +*/ +""" + + def render_dispatch(pipeline_type, op_name): + switch_tailnum_template = r""" + switch (tail_num) { + {% for tail_num in valid_tailnums %} + case ck_tile::TailNumber::{{tail_num}}: + dispatch({{has_hot_loop}}, + ck_tile::integral_constant{}); + break; + {% endfor %} + default: + std::ostringstream err; + err << "Unsupported dispatch: " + << "Pipeline: " << "{{pipeline}}" + << "Prefetch stages: " << kPrefetchStages + << "Tail num: " << tail_num; + throw std::runtime_error(err.str()); + } // switch tail_num + """ + dispatch_template = r""" + if (has_hot_loop) { + {{rendered_with_hot_loop}} + } + else { // has_hot_loop == false + {{rendered_without_hot_loop}} + } // if has_hot_loop + """ + if pipeline_type == "CompV3": + return self._template_from_string(dispatch_template).render( + rendered_with_hot_loop=self._template_from_string( + switch_tailnum_template + ).render( + has_hot_loop="ck_tile::integral_constant{}", + valid_tailnums=("Full", "Odd", "Even"), + pipeline=pipeline_type, + ), + rendered_without_hot_loop=self._template_from_string( + switch_tailnum_template + ).render( + has_hot_loop="ck_tile::integral_constant{}", + valid_tailnums=("Full", "Odd", "Even"), + pipeline=pipeline_type, + ), + ) + elif pipeline_type == "Mem": + return self._template_from_string(dispatch_template).render( + rendered_with_hot_loop="dispatch_memory_pipeline_hot_loop(tail_num, dispatch);", + rendered_without_hot_loop=self._template_from_string( + switch_tailnum_template + ).render( + has_hot_loop="ck_tile::integral_constant{}", + valid_tailnums=("Full", "Odd", "Even"), + pipeline=pipeline_type, + ), + ) + elif pipeline_type == "CompV4": + return self._template_from_string(dispatch_template).render( + rendered_with_hot_loop=self._template_from_string( + switch_tailnum_template + ).render( + has_hot_loop="ck_tile::integral_constant{}", + valid_tailnums=("Two", "Three"), + pipeline=pipeline_type, + ), + rendered_without_hot_loop=self._template_from_string( + switch_tailnum_template + ).render( + has_hot_loop="ck_tile::integral_constant{}", + valid_tailnums=("Full", "Odd", "Even"), + pipeline=pipeline_type, + ), + ) + else: + raise AssertionError(f"Pipeline {pipeline_type} is not supported") + + return self._template_from_string(self.gemm_template).render( + headers=self.header().getvalue(), + globals=self.globals().getvalue(), + instance_definition=instance_definition, + kernel_definition=kernel.def_kernel( + inputs=[X, W], # type: ignore[list-item] + outputs=[Y], + names_str="X, W, Y", + size_args=[ + f"int32_t {arg}" for arg in ["M", "N", "K", "LDA", "LDB", "LDC"] + ], + ), + instance_namespace=op.name(), + version_comment=version_comment, + rendered_dispatch=render_dispatch(op.pipeline, op.name()), + ) + + def 
gen_ops(self): + """ + Creates a list of `CKTileGemmOperation` instances that match the GEMM operation this template represents. + The instances are guaranteed to have the correct layout, dtype and dimension padding for the GEMM input arguments. + + An instance may invalidate the GEMM configuration at runtime. + Such instances will be assigned +inf runtime by the autotune process. + """ + instances = ops() + if not instances: + raise AssertionError( + "No Composable Kernel Universal GEMM instances found. " + "Please check if the library is installed." + ) + filtered_instances = list(filter(self.filter_op, instances)) + # NB: when using a fixed list order, most likely we will pick the subset of instances + # which are very similar to each other. Randomizing the choice seems to solve this. + random.seed(-11) + chosen_instances = ( + random.sample( + filtered_instances, + min(len(filtered_instances), config.rocm.ck_tile_max_profiling_configs), + ) + if config.rocm.ck_tile_max_profiling_configs + else filtered_instances + ) + log.debug( + "generated %d ck instances after sample: %s", + len(chosen_instances), + chosen_instances, + ) + return chosen_instances + + @staticmethod + def add_choices( + choices, + layout, + input_nodes, + ): + """ + Add Composable Kernel Universal GEMM instance choices to the auto-tuning list. + """ + template = CKTileGemmTemplate( + input_nodes, + layout, + ) + ops = template.gen_ops() + for op in ops: + for k_batch in template.k_batch_choices(op): + template.maybe_append_choice( + choices, + op=op, + kBatch=k_batch, + ) + + def k_batch_choices(self, op: "CKTileGemmOperation") -> tuple[int, ...]: + """ + Returns a list of k_batch choices for the template. + """ + default_choices = (1, 2, 4, 8, 16, 32) + + def check(dim_size, tile_size, is_padded): + if ( + is_static_int(dim_size) + and dim_size % tile_size != 0 + and is_padded == "false" + ): + return False + return True + + _, _, K, _, _, _ = self.size_args() + if op.layout_a == "Row" or op.layout_b == "Col": + choices = tuple( + filter( + lambda k_batch: check(K, op.tile_k * k_batch, op.k_is_padded), + default_choices, + ) + ) + else: + choices = default_choices + + if op.epilogue == "Default": + choices = (1,) + + return choices + + def size_args(self): + """ + Sizes and strides to be used for the kernel call + """ + X = self.input_nodes[0] + W = self.input_nodes[1] + Y = self.output_node + + M = X.get_size()[0] + K = X.get_size()[1] + N = W.get_size()[1] + LDA = X.get_stride()[0 if X.get_stride()[1] == 1 else 1] + LDB = W.get_stride()[0 if W.get_stride()[1] == 1 else 1] + LDC = Y.get_stride()[0 if Y.get_stride()[1] == 1 else 1] + + return M, N, K, LDA, LDB, LDC + + def get_runtime_arg_info(self) -> list[ArgInfo]: + return [ArgInfo("kBatch", "int32_t")] + + def get_runtime_arg_values(self, **kwargs: Any) -> list[Any]: + # maybe_append_choice kwarg for k_batch must match the name of the argument + arg_names = OrderedSet([arg.name for arg in self.get_runtime_arg_info()]) + if not arg_names.issubset(kwargs): + raise ValueError( + "Missing runtime arguments: " + ", ".join(arg_names - kwargs.keys()) + ) + return [kwargs[k] for k in arg_names] diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/ck_universal_gemm_template.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/ck_universal_gemm_template.py new file mode 100644 index 0000000000000000000000000000000000000000..bc0f75b919bbf2806f990f77013f29c9322161c9 --- /dev/null +++ 
b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/ck_universal_gemm_template.py @@ -0,0 +1,1016 @@ +# mypy: allow-untyped-defs, disable-error-code="attr-defined, valid-type" +import copy +import logging +import math +import random +from collections import namedtuple +from typing import Optional + +import sympy + +import torch +from torch._inductor import config +from torch._inductor.codegen.cpp_utils import DTYPE_TO_CPP +from torch._inductor.codegen.rocm.ck_template import CKTemplate +from torch._inductor.codegen.rocm.compile_command import rocm_compile_command +from torch._inductor.codegen.rocm.rocm_kernel import ROCmTemplateKernel +from torch._inductor.ir import Buffer, Layout +from torch._inductor.runtime.runtime_utils import next_power_of_2 + +from ...utils import IndentedBuffer, is_dynamic, try_import_ck_lib + + +_, gen_ops_library, gen_ops_preselected, CKGemmOperation = try_import_ck_lib() + + +log = logging.getLogger(__name__) + +# lightweight collection of information about a single op +InductorROCmOp = namedtuple("InductorROCmOp", ["op", "kBatch"]) + +padding_lookup = { + "M": { + "GemmSpecialization::MPadding": True, + "GemmSpecialization::MNPadding": True, + "GemmSpecialization::MKPadding": True, + "GemmSpecialization::MNKPadding": True, + }, + "N": { + "GemmSpecialization::NPadding": True, + "GemmSpecialization::MNPadding": True, + "GemmSpecialization::NKPadding": True, + "GemmSpecialization::MNKPadding": True, + }, + "K": { + "GemmSpecialization::KPadding": True, + "GemmSpecialization::MKPadding": True, + "GemmSpecialization::NKPadding": True, + "GemmSpecialization::MNKPadding": True, + }, +} + + +def is_static_int(number): + return isinstance(number, (int, sympy.Integer)) + + +def torch_layout_to_ck_layout(torch_layout): + if torch_layout.stride[-1] == 1: + return "Row" + elif torch_layout.stride[-2] == 1: + return "Col" + else: + return None + + +class CKGemmTemplate(CKTemplate): + # the JINJA template for rendering CK Universal GEMMs + gemm_template = r"""{{version_comment}} + {{headers}} + {{globals}} + {{instance_definition}} + extern "C" { + PT_EXPORT {{kernel_definition}} { + auto gemm = {{instance_type}} {}; + auto invoker = gemm.MakeInvoker(); + {% if is_batched %} + auto argument = gemm.MakeArgument( + reinterpret_cast(X), + reinterpret_cast(W), + std::array{ {{ds_names}} }, + reinterpret_cast<{{c_element_dtype}}*>(Y), + M, + N, + K, + B, + LDA, + LDB, + std::array{ {{ds_strides}} }, + LDC, + M * K, // batch_stride_A + N * K, // batch_stride_B + std::array{ {{ds_batch_strides}} }, + M * N, // batch_stride_C + {{a_elementwise_op}}, + {{b_elementwise_op}}, + {{epilogue}} // c_elementwise_op + ); + {% else %} + auto argument = gemm.MakeArgument( + reinterpret_cast(X), + reinterpret_cast(W), + std::array{ {{ds_names}} }, + reinterpret_cast<{{c_element_dtype}}*>(Y), + M, + N, + K, + LDA, + LDB, + std::array{ {{ds_strides}} }, + LDC, + kBatch, // kBatch + {{a_elementwise_op}}, + {{b_elementwise_op}}, + {{epilogue}} // c_elementwise_op + ); + {% endif %} + if (!gemm.IsSupportedArgument(argument)) { + // we do our best to statically avoid this case in `filter_op` + std::cerr << "invalid argument for gemm instance " << gemm.GetTypeString() << std::endl; + argument.Print(); + return -23; + } + if (workspace_size) { + *workspace_size = gemm.GetWorkSpaceSize(&argument); + return 0; + } + // run the kernel + #ifdef GENERATE_CK_STANDALONE_RUNNER + const auto stream_config = StreamConfig{ + stream, + /* time kernel */ 1, + /* log level */ 1, + /* n_cold_iter */ 100, 
+ /* n_hot_iter */ 100, + /* flush_l2_cache */ 1, + /* rotate_count */ 5}; + #else + const auto stream_config = StreamConfig{stream, /* time kernel */ false, /* log level */ 0}; + #endif + + const float elapsed_time = invoker.Run(argument, stream_config); + + #ifdef GENERATE_CK_STANDALONE_RUNNER + std::cout << "elapsed time: " << elapsed_time << " ms" << std::endl; + #else + (void)elapsed_time; + #endif + return 0; + } // kernel definition + } // extern C + """ + + standalone_runner_template = r""" + #ifdef GENERATE_CK_STANDALONE_RUNNER + // standalone runner for the generated CK GEMM kernel + + {{inline_utils}} + + extern "C" { + int run_main(int argc, char** argv) { + {% if is_batched %} + const int32_t B = {{B}}; + {% endif %} + const int32_t M = {{M}}; + const int32_t N = {{N}}; + const int32_t K = {{K}}; + const int32_t LDA = {{LDA}}; + const int32_t LDB = {{LDB}}; + const int32_t LDC = {{LDC}}; + const int32_t LDD = {{LDD}}; + const int32_t kBatch = {{kBatch}}; + + using AElementType = {{a_ck_dtype}}; + using BElementType = {{b_ck_dtype}}; + using CElementType = {{c_ck_dtype}}; + {% if has_bias %} + using BiasElementType = {{bias_ck_dtype}}; + {% endif %} + {% if has_scale %} + using ScaleAElementType = {{scale_a_ck_dtype}}; + using ScaleBElementType = {{scale_b_ck_dtype}}; + {% endif %} + + using AArgType = {{a_torch_dtype}}; + using BArgType = {{b_torch_dtype}}; + using CArgType = {{c_torch_dtype}}; + {% if has_bias %} + using BiasArgType = {{bias_torch_dtype}}; + {% endif %} + {% if has_scale %} + using ScaleAArgType = {{scale_a_torch_dtype}}; + using ScaleBArgType = {{scale_b_torch_dtype}}; + {% endif %} + + using ALayout = {{a_layout}}; + using BLayout = {{b_layout}}; + using CLayout = {{c_layout}}; + {% if has_bias %} + using BiasLayout = {{bias_layout}}; + {% endif %} + + {% if is_batched %} + using strides_t = std::array; + auto get_strides = [](int32_t batch_stride, int32_t leading_dimension, auto layout) constexpr -> strides_t { + if constexpr (std::is_same_v) { + return {batch_stride, leading_dimension, 1}; + } + return {batch_stride, 1, leading_dimension}; + }; + auto a_size = strides_t{B, M, K}; + auto a_stride = get_strides(M * K, LDA, ALayout{}); + auto b_size = strides_t{B, N, K}; + auto b_stride = get_strides(N * K, LDB, BLayout{}); + auto c_size = strides_t{B, M, N}; + auto c_stride = get_strides(M * N, LDC, CLayout{}); + {% else %} + using strides_t = std::array; + auto get_strides = [](int32_t leading_dimension, auto layout) constexpr -> strides_t { + if constexpr (std::is_same_v) { + return {leading_dimension, 1}; + } + return {1, leading_dimension}; + }; + auto a_size = strides_t{M, K}; + auto a_stride = get_strides(LDA, ALayout{}); + auto b_size = strides_t{N, K}; + auto b_stride = get_strides(LDB, BLayout{}); + auto c_size = strides_t{M, N}; + auto c_stride = get_strides(LDC, CLayout{}); + {% endif %} + + Tensor a_m_k ( HostTensorDescriptor ( a_size, a_stride ) ); + Tensor b_k_n ( HostTensorDescriptor ( b_size, b_stride ) ); + {% if has_bias %} + Tensor d_m_n ( HostTensorDescriptor ( c_size, get_strides(LDD, BiasLayout{}) ) ); + {% endif %} + {% if has_scale %} + // NB: these are hardcoded + Tensor s_a_m_n ( HostTensorDescriptor ( strides_t{M, N}, get_strides(0, Row{}) )); + Tensor s_b_m_n ( HostTensorDescriptor ( strides_t{M, N}, get_strides(0, Col{}) )); + {% endif %} + + Tensor c_m_n_host ( HostTensorDescriptor ( c_size, c_stride ) ); + Tensor c_m_n_device ( HostTensorDescriptor ( c_size, c_stride ) ); + + a_m_k.GenerateTensorValue(GeneratorTensor_2()); + 
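+    // The remaining host tensors are filled the same way, staged into DeviceMem
+    // buffers, and then passed to the generated kernel below.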
b_k_n.GenerateTensorValue(GeneratorTensor_2()); + {% if has_bias %} + d_m_n.GenerateTensorValue(GeneratorTensor_2()); + {% endif %} + {% if has_scale %} + s_a_m_n.GenerateTensorValue(GeneratorTensor_2()); + s_b_m_n.GenerateTensorValue(GeneratorTensor_2()); + {% endif %} + DeviceMem a_m_k_device_buf(sizeof(AElementType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BElementType) * b_k_n.mDesc.GetElementSpaceSize()); + {% if has_bias %} + DeviceMem d_m_n_device_buf(sizeof(BiasElementType) * d_m_n.mDesc.GetElementSpaceSize()); + {% endif %} + {% if has_scale %} + DeviceMem s_a_m_n_device_buf(sizeof(ScaleAElementType) * s_a_m_n.mDesc.GetElementSpaceSize()); + DeviceMem s_b_m_n_device_buf(sizeof(ScaleBElementType) * s_b_m_n.mDesc.GetElementSpaceSize()); + {% endif %} + DeviceMem c_m_n_device_buf(sizeof(CElementType) * c_m_n_device.mDesc.GetElementSpaceSize()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + {% if has_bias %} + d_m_n_device_buf.ToDevice(d_m_n.mData.data()); + {% endif %} + {% if has_scale %} + s_a_m_n_device_buf.ToDevice(s_a_m_n.mData.data()); + s_b_m_n_device_buf.ToDevice(s_b_m_n.mData.data()); + {% endif %} + + {{kernel_name}}( + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + {% if has_scale %} + static_cast(s_a_m_n_device_buf.GetDeviceBuffer()), + static_cast(s_b_m_n_device_buf.GetDeviceBuffer()), + {% endif %} + {% if has_bias %} + static_cast(d_m_n_device_buf.GetDeviceBuffer()), + {% endif %} + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + {% if is_batched %} + B, + {% endif %} + M, + N, + K, + LDA, + LDB, + LDC, + LDD, + nullptr, // workspace_size + nullptr, // workspace + nullptr); // stream + + hip_check_error(hipDeviceSynchronize()); + + return 0; + } // run_main + } // extern C + + int main(int argc, char** argv) { + return run_main(argc, argv); + } + // compile with: {{compile_cmd}} + #endif // GENERATE_CK_STANDALONE_RUNNER + """ + + def __init__( + self, + input_nodes: list[Buffer], + layout: Layout, + alpha: float, + beta: float, + input_reorder: Optional[list[int]] = None, + ) -> None: + is_batched = len(layout.size) == 3 + name = "ck_batched_gemm_template" if is_batched else "ck_gemm_template" + super().__init__( + name=name, + input_nodes=input_nodes, + layout=layout, + input_reorder=input_reorder, + ) + self.alpha = alpha + self.beta = beta + self.is_batched = is_batched + + def header(self) -> IndentedBuffer: + res = super().header() + if self.is_batched: + res.splice( + """ + // CK GEMM header(s) + + #include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp" + """ + ) + else: + res.splice( + """ + // CK GEMM header(s) + + #include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp" + """ + ) + return res + + def globals(self) -> IndentedBuffer: + res = super().globals() + res.splice( + """ + // CK GEMM globals + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + using BlockGemmPipelineScheduler = ck::BlockGemmPipelineScheduler; + using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization; + using BlockGemmPipelineVersion = ck::BlockGemmPipelineVersion; + + struct MultiplyMultiplyAdd { + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const { + e = ck::type_convert( + ck::type_convert(c) + * ck::type_convert(d0) + * 
ck::type_convert(d1) + + ck::type_convert(d2) + ); + } + }; + """ + ) + return res + + def inline_utils(self): + res = IndentedBuffer() + res.splice( + """ + #include "host_tensor.cpp" + #include "device_memory.cpp" + """ + ) + return res + + def _has_padding(self, dimension, gemm_specialization): + # Get the relevant padding map for the given dimension + dimension_padding = padding_lookup.get(dimension, {}) + + # Check if the specialization is in the dimension's padding map + return dimension_padding.get(gemm_specialization, False) + + def filter_op(self, op_info: InductorROCmOp): + """ + Determines whether a given op definition is suitable for the current + input / output of the operation that this template implements. + + Filter is based on inputs' dtype, layout and statically inferred size. + + Returns None if the op is not suitable, otherwise returns the op to be used. + """ + op, kBatch = op_info.op, op_info.kBatch + metas = [T.get_layout() for T in [*self.input_nodes, self.output_node]] + X_meta = metas[0] + W_meta = metas[1] + Y_meta = metas[-1] + # disable the instance if dtypes don't match + if op.a_element_dtype != self._TORCH_DTYPE_TO_CK[X_meta.dtype]: + return None + if op.b_element_dtype != self._TORCH_DTYPE_TO_CK[W_meta.dtype]: + return None + if op.c_element_dtype != self._TORCH_DTYPE_TO_CK[Y_meta.dtype]: + return None + # disable the instance if layouts don't match + if op.a_layout != torch_layout_to_ck_layout(X_meta): + return None + if op.b_layout != torch_layout_to_ck_layout(W_meta): + return None + if op.c_layout != torch_layout_to_ck_layout(Y_meta): + return None + # try to avoid launching the instance with invalid problem size + # see GridwiseGemm_xdl_cshuffle_v3::CheckValidity + + M = X_meta.size[-2] + K = X_meta.size[-1] + N = W_meta.size[-1] + + if is_static_int(M): + if not self._has_padding("M", op.gemm_specialization): + if M % op.m_per_block != 0: + return None + if is_static_int(N): + if not self._has_padding("N", op.gemm_specialization): + if N % op.n_per_block != 0: + return None + if is_static_int(K): + if not self._has_padding("K", op.gemm_specialization): + if K % op.k_per_block != 0: + return None + K_t = kBatch * op.k_per_block + if K % K_t != 0: + return None + else: + # need another kBatch check here + lcm = abs(op.a_k1 * op.b_k1) // math.gcd(op.a_k1, op.b_k1) + K_t = kBatch * lcm + k_read_pad_splited = math.ceil(K / K_t) * lcm + if (k_read_pad_splited * (kBatch - 1)) >= K: + return None + + a_contig_size = ( + K if op.a_layout == "Row" else M if op.a_layout == "Col" else None + ) + if ( + is_static_int(a_contig_size) + and a_contig_size % op.a_block_transfer_src_scalar_per_vector != 0 + ): + return None + b_contig_size = ( + N if op.b_layout == "Row" else K if op.b_layout == "Col" else None + ) + if ( + is_static_int(b_contig_size) + and b_contig_size % op.b_block_transfer_src_scalar_per_vector != 0 + ): + return None + c_contig_size = ( + N if op.c_layout == "Row" else M if op.c_layout == "Col" else None + ) + c_shuffle_block_transfer_scalar_per_vector_n_per_block = ( + op.c_shuffle_block_transfer_scalar_per_vector_n_per_block[0] + if isinstance( + op.c_shuffle_block_transfer_scalar_per_vector_n_per_block, tuple + ) + else op.c_shuffle_block_transfer_scalar_per_vector_n_per_block + ) + if ( + is_static_int(c_contig_size) + and c_contig_size % c_shuffle_block_transfer_scalar_per_vector_n_per_block + != 0 + ): + return None + if not self._check_num_k_loops(op, kBatch): + return None + # TBD disable instances with invalid number of pipeline prefetch 
stages + # It will avoid compiling a small percentage of unrunnable instances which fail the gemm argument check + + return op + + def _check_num_k_loops(self, op, kBatch): + # Additional splitK scenario check + metas = [T.get_layout() for T in [*self.input_nodes]] + X_meta = metas[0] + W_meta = metas[1] + K = X_meta.size[-1] + if kBatch > 1: + if op.block_gemm_pipeline_version != "BlockGemmPipelineVersion::v1": + try: + prefetch_stages = self._prefetch_stages( + op, + torch.empty((), dtype=X_meta.dtype).element_size(), + torch.empty((), dtype=W_meta.dtype).element_size(), + torch.cuda.get_device_properties(X_meta.device).warp_size, + ) + except Exception as e: + log.debug( + "Failed to prefetch_stages for %s with exception %s", op.name, e + ) + # be conservative here and disable the op + return False + + K_t = op.k_per_block * kBatch + ak0 = (K + K_t - 1) // K_t * (op.k_per_block // op.a_k1) + num_k_loop = ak0 // (op.k_per_block // op.a_k1) + if num_k_loop <= prefetch_stages: + log.debug( + "Op %s is not compatible due to invalid number of pipeline prefetch stages. " + "Parameters: kBatch=%s, block_gemm_pipeline_version=%s, prefetch_stages=%s, num_k_loop=%s", + op.name(), + kBatch, + op.block_gemm_pipeline_version, + prefetch_stages, + num_k_loop, + ) + return False + + return True + + # small helper to figure out the prefetch stages on AMD + def _prefetch_stages(self, op, a_dtype_size, b_dtype_size, warp_size: int = 64): + version_str = op.block_gemm_pipeline_version.split("::")[-1] + try: + version = int(version_str[1:]) # Assuming the format is always 'vX' + except ValueError as e: + raise ValueError(f"Invalid version string: {version_str}") from e + if version not in [1, 2, 3, 4, 5]: + raise ValueError( + f"unknown prefetch stages for {op.block_gemm_pipeline_version}" + ) + # Define the mapping of versions to stages + version_to_stages = {1: 1, 3: 2, 4: 4, 5: 3} + # Get the stages for the given version + stages = version_to_stages.get(version, None) + if stages is None: + # This means we're at stage 2, and this requires computation + # See github.com/ROCm/composable_kernel/blob/d6a4605/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp#L143 # noqa: B950 + wgp_per_cu = max(4 * warp_size // op.block_size, 1) + full_mem_band_prefetch_stages = math.ceil( + 32768 + / wgp_per_cu + / ( + (op.m_per_block * a_dtype_size + op.n_per_block * b_dtype_size) + * op.k_per_block + ) + ) + stages = min(max(full_mem_band_prefetch_stages, 2), 8) + + return stages + + def emit_ck_instance(self, op: "CKGemmOperation"): + # The Jinja template for generating a C++ type alias *definition* for a Universal GEMM instance + struct_name = ( + "DeviceBatchedGemmMultiD_Xdl_CShuffle_V3" + if self.is_batched + else "DeviceGemmMultiD_Xdl_CShuffle_V3" + ) + template_definition = r""" + // Gemm operator {{operation_name}} + using Operation_{{operation_name}} = + ck::tensor_operation::device::{{struct_name}}< + {{template_params}}>; + +""" + # The Jinja template for generating a C++ type alias *usage* for a Universal GEMM instance + template_type = r""" + Operation_{{operation_name}} +""" + template_params = [] + for field_name, field_value in op.dict_items(): + if isinstance(field_value, tuple): + tuple_elements = ", ".join(map(str, iter(field_value))) + if "ds" in field_name: # element type and layout for bias + arg = f"/* {field_name} */ Tuple<{tuple_elements}>" + else: # tile shape + arg = f"/* {field_name} */ S<{tuple_elements}>" + template_params.append(arg) + else: + if field_value is not 
None: + template_params.append(f"/* {field_name} */ {field_value}") + operation_name = op.name().replace("(", "").replace(",", "").replace(")", "") + return self._template_from_string(template_definition).render( + operation_name=operation_name, + template_params=(",\n" + 12 * " ").join(template_params), + struct_name=struct_name, + ), self._template_from_string(template_type).render( + operation_name=operation_name + ) + + def render( # type: ignore[override] + self, + kernel: ROCmTemplateKernel, + op: "CKGemmOperation", + **kwargs, + ) -> str: + """ + The primary entry point for the code rendering process used in this template. + """ + epilogue_nodes = kwargs.get("epilogue_nodes", None) + assert epilogue_nodes is None or 0 == len(epilogue_nodes) + template_buffer_node = kwargs.get("template_buffer_node", None) + if template_buffer_node is not None: + self.output_node = template_buffer_node + # input nodes: + # * X, W for matmul + # * X, W, Bias for addmm + # * X, W, inv_scale_x, inv_scale_w for scaled_mm + # * X, W, inv_scale_x, inv_scale_w, Bias for scaled_mm with bias + X, W = self.input_nodes[0], self.input_nodes[1] + Y = self.output_node + Bias = ( + self.input_nodes[2] + if 3 == len(self.input_nodes) + else self.input_nodes[4] + if 5 == len(self.input_nodes) + else None + ) + has_bias = Bias is not None + has_scale = len(self.input_nodes) in (4, 5) + op = copy.deepcopy(op) + + # This parameter is converted into tuple because of change + # from DeviceGemm_Xdl_CShuffleV3 to DeviceGemmMultiD_Xdl_CShuffle_V3. + # The first tuple element corresponds to matmul result... + if not isinstance( + op.c_shuffle_block_transfer_scalar_per_vector_n_per_block, tuple + ): + op.c_shuffle_block_transfer_scalar_per_vector_n_per_block = ( + op.c_shuffle_block_transfer_scalar_per_vector_n_per_block, + ) + + if has_scale: + scale_x = self.input_nodes[2] + scale_w = self.input_nodes[3] + if 1 == scale_x.get_numel() and 1 == scale_w.get_numel(): + # tensorwise scale for both X, W + if has_bias: + op.c_elementwise_op = "ScaleAdd" + else: + op.c_elementwise_op = "Scale" + else: + # rowwise scale for both X, W + if has_bias: + op.c_elementwise_op = "MultiplyMultiplyAdd" + else: + op.c_elementwise_op = "MultiplyMultiply" + op.c_shuffle_dtype = "F32" + op.ds_layouts = ( + torch_layout_to_ck_layout(scale_x.get_layout()), + torch_layout_to_ck_layout(scale_w.get_layout()), + ) + op.ds_element_dtypes = ( + self._TORCH_DTYPE_TO_CK[scale_x.get_layout().dtype], + self._TORCH_DTYPE_TO_CK[scale_w.get_layout().dtype], + ) + op.c_shuffle_block_transfer_scalar_per_vector_n_per_block += (1, 1) + else: + scale_x = None + scale_w = None + + bias_dtype = "" + if Bias is not None: + bias_layout = torch_layout_to_ck_layout(Bias.get_layout()) + bias_dtype = self._TORCH_DTYPE_TO_CK[Bias.get_layout().dtype] + op.ds_layouts += (bias_layout,) + op.ds_element_dtypes += (bias_dtype,) + if not has_scale: + op.c_elementwise_op = "Bilinear" + # c_shuffle_dtype is also used for adding bias to matmul result + # before converting down to the result dtype + op.c_shuffle_dtype = op.acc_dtype + # this parameter needs to be set accordingly to bias stride for correct accumulation + if bias_layout == "Row": + # bias has (N, ) shape + bias_shuffle_block_transfer_scalar_per_vector_n_per_block = ( + op.c_shuffle_block_transfer_scalar_per_vector_n_per_block + ) + elif bias_layout == "Col": + # bias has (M, 1) shape + bias_shuffle_block_transfer_scalar_per_vector_n_per_block = (1,) + else: + raise AssertionError( + "Bias layout is neither row-major nor 
column-major" + ) + # ...and the second tuple element corresponds to the bias + op.c_shuffle_block_transfer_scalar_per_vector_n_per_block += ( + bias_shuffle_block_transfer_scalar_per_vector_n_per_block + ) + + instance_definition, instance_type = self.emit_ck_instance(op) + + version_comment = rf"""/** +* Generated code for CK inductor backend +* See {type(self).__module__}.{type(self).__qualname__} +* +* Template instance {op} +* +* {torch.__version__=} +* torch.version.git_version={getattr(torch.version, "git_version", "None")} +*/ +""" + epilogue = None + + if op.c_elementwise_op == "Bilinear" and scale_w is None: + epilogue = f"Bilinear {{ {self.alpha}, {self.beta} }}" + + elif op.c_elementwise_op == "Scale": + epilogue = "Scale { (inv_scale_w && inv_scale_x) ? (*inv_scale_w * *inv_scale_x) : 1.0f }" + + elif op.c_elementwise_op == "ScaleAdd": + epilogue = "ScaleAdd { (inv_scale_w && inv_scale_x) ? (*inv_scale_w * *inv_scale_x) : 1.0f }" + + elif op.c_elementwise_op == "MultiplyMultiply": + epilogue = "MultiplyMultiply {}" + + elif op.c_elementwise_op == "MultiplyMultiplyAdd": + epilogue = "MultiplyMultiplyAdd {}" + + elif op.c_elementwise_op == "PassThrough": + epilogue = "PassThrough {}" + + assert epilogue is not None, "CK GEMM epilogue is not set" + + size_arg_strs = ["M", "N", "K", "LDA", "LDB", "LDC", "LDD"] + if self.is_batched: + size_arg_strs.insert(0, "B") + + res = self._template_from_string(self.gemm_template).render( + inline_utils=self.inline_utils(), + headers=self.header().getvalue(), + globals=self.globals().getvalue(), + instance_definition=instance_definition, + kernel_definition=kernel.def_kernel( + inputs=[X, W, scale_x, scale_w, Bias], # type: ignore[list-item] + outputs=[Y], + names_str="X, W, inv_scale_x, inv_scale_w, Bias, Y", + input_reorder=self.input_reorder, + size_args=[f"int32_t {arg}" for arg in size_arg_strs], + ), + instance_type=instance_type, + a_element_dtype=op.a_element_dtype, + b_element_dtype=op.b_element_dtype, + c_element_dtype=op.c_element_dtype, + bias_element_dtype=bias_dtype, + alpha=self.alpha, + beta=self.beta, + a_elementwise_op="PassThrough {}", + b_elementwise_op="PassThrough {}", + epilogue=epilogue, + has_bias=has_bias, + ds_size=1 + if op.c_elementwise_op in ("Bilinear", "ScaleAdd") + else 2 + if op.c_elementwise_op == "MultiplyMultiply" + else 3 + if op.c_elementwise_op == "MultiplyMultiplyAdd" + else 0, + ds_names=", ".join( + ["Bias"] + if op.c_elementwise_op in ("Bilinear", "ScaleAdd") + else ["inv_scale_x", "inv_scale_w"] + if op.c_elementwise_op == "MultiplyMultiply" + else ["inv_scale_x", "inv_scale_w", "Bias"] + if op.c_elementwise_op == "MultiplyMultiplyAdd" + else [] + ), + ds_strides=", ".join( + ["LDD"] + if op.c_elementwise_op in ("Bilinear", "ScaleAdd") + else ["0", "0"] + if op.c_elementwise_op == "MultiplyMultiply" + else ["0", "0", "LDD"] + if op.c_elementwise_op == "MultiplyMultiplyAdd" + else [] + ), + version_comment=version_comment, + is_batched=self.is_batched, + ds_batch_strides=", ".join([]), # FIXME when supporting baddbmm + ) + + if config.rocm.generate_test_runner: + is_static_problem = all(is_static_int(arg) for arg in self.size_args()) + # NOTE: size_arg_strs is defined above + size_arg_vals = ( + self.size_args() + if is_static_problem + else ( + f"std::stoi(argv[{k}])" for k, _ in enumerate(self.size_args(), 1) + ) + ) + size_args = dict(zip(size_arg_strs, size_arg_vals, strict=True)) + runtime_args = dict( + zip( + [a.name for a in self.get_runtime_arg_info()], + self.get_runtime_arg_values(), + ) + 
) + runner_code = self._template_from_string( + self.standalone_runner_template + ).render( + inline_utils=self.inline_utils().getvalue(), + kernel_name=kernel.kernel_name, + has_bias=has_bias, + has_scale=has_scale, + is_batched=self.is_batched, + a_ck_dtype=op.a_element_dtype, + b_ck_dtype=op.b_element_dtype, + c_ck_dtype=op.c_element_dtype, + bias_ck_dtype=op.ds_element_dtypes[0] if has_bias else "", + scale_a_ck_dtype=op.ds_element_dtypes[0] + if has_scale and 2 == len(op.ds_element_dtypes) + else "BF16", + scale_b_ck_dtype=op.ds_element_dtypes[1] + if has_scale and 2 == len(op.ds_element_dtypes) + else "BF16", + a_torch_dtype=DTYPE_TO_CPP[X.get_layout().dtype], + b_torch_dtype=DTYPE_TO_CPP[W.get_layout().dtype], + c_torch_dtype=DTYPE_TO_CPP[Y.get_layout().dtype], + bias_torch_dtype=DTYPE_TO_CPP[Bias.get_layout().dtype] + if Bias is not None + else "", + scale_a_torch_dtype=DTYPE_TO_CPP[scale_x.get_layout().dtype] + if scale_x is not None + else "", + scale_b_torch_dtype=DTYPE_TO_CPP[scale_w.get_layout().dtype] + if scale_w is not None + else "", + a_layout=torch_layout_to_ck_layout(X.get_layout()), + b_layout=torch_layout_to_ck_layout(W.get_layout()), + c_layout=torch_layout_to_ck_layout(Y.get_layout()), + bias_layout=torch_layout_to_ck_layout(Bias.get_layout()) + if Bias is not None + else "", + compile_cmd=rocm_compile_command( + [""], "", "exe" + ), + **size_args, + **runtime_args, + ) + res += runner_code + + return res + + def _is_rcr_f16(self): + X_meta, W_meta, Y_meta = ( + T.get_layout() for T in [*self.input_nodes, self.output_node] + ) + X_dtype, W_dtype, Y_dtype = ( + self._TORCH_DTYPE_TO_CK[m.dtype] for m in (X_meta, W_meta, Y_meta) + ) + X_layout, W_layout, Y_layout = ( + torch_layout_to_ck_layout(m) for m in (X_meta, W_meta, Y_meta) + ) + + return ( + X_dtype == "F16" + and W_dtype == "F16" + and Y_dtype == "F16" + and X_layout == "Row" + and W_layout == "Col" + and Y_layout == "Row" + ) + + # helper to calculate a potentially optimal kBatch(es) for a problem + def _get_kBatch(self, op): + # we only set a higher kBatch if K > 16 * the larger of M and N + # this is a hand-tuned heuristic to start + metas = [T.get_layout() for T in [*self.input_nodes]] + X_meta = metas[0] + W_meta = metas[1] + M = X_meta.size[-2] + K = X_meta.size[-1] + N = W_meta.size[-1] + if is_dynamic(*self.input_nodes): + return [1] + if K // max(M, N) < config.rocm.split_k_threshold: + return [1] + # if the user is telling us which kBatches to sweep, just use those + if config.rocm.kBatch_sweep is not None: + return config.rocm.kBatch_sweep + # Calculate the number of blocks needed for each dimension + total_k_blocks = math.ceil(K / op.k_per_block) + # we want to calculate how many blocks we need to fit per CU + cus = torch.cuda.get_device_properties(X_meta.device).multi_processor_count + # again, manual heuristics as much larger kBatch are significantly worse in + # initial testing + kBatch = min(max(next_power_of_2(total_k_blocks // cus), 1), 128) + return [kBatch] + + def gen_ops(self) -> list[InductorROCmOp]: + """ + Creates a list of `CKGemmOperation` instances that match the GEMM operation this template represents. + The instances are guaranteed to have the correct layout, dtype and dimension padding for the GEMM input arguments. + + An instance may invalidate the GEMM configuration at runtime. + Such instances will be assigned +inf runtime by the autotune process. 
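# A worked example of the split-K heuristic in _get_kBatch() above, using
# hypothetical problem sizes and CU count (the real code reads them from the
# input layouts and torch.cuda.get_device_properties). _next_pow2 approximates
# torch._inductor.runtime.runtime_utils.next_power_of_2.
import math

def _next_pow2(n: int) -> int:
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

def pick_kbatch(K: int, k_per_block: int, num_cus: int) -> int:
    total_k_blocks = math.ceil(K / k_per_block)   # K-blocks needed to cover the K dimension
    return min(max(_next_pow2(total_k_blocks // num_cus), 1), 128)

# K=16384 with k_per_block=64 needs 256 K-blocks; on a 104-CU GPU that is
# 256 // 104 == 2 blocks per CU, so the heuristic proposes kBatch == 2.
assert pick_kbatch(16384, 64, 104) == 2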
+ """ + try: + from ck4inductor.batched_universal_gemm.gen_instances import ( # type: ignore[import] + gen_ops_library as gen_batched_gemm_ops_library, + ) + from ck4inductor.universal_gemm.gen_instances import ( # type: ignore[import] + gen_ops_library as gen_gemm_ops_library, + gen_ops_preselected as gen_gemm_ops_preselected, + ) + except ImportError: + return [] + + generator = None + if self.is_batched: + generator = gen_batched_gemm_ops_library + else: + generator = gen_gemm_ops_library + if config.rocm.use_preselected_instances and self._is_rcr_f16(): + generator = gen_gemm_ops_preselected + + assert generator is not None + + rops = generator() + ops = [] + for o in rops: + kBatches = self._get_kBatch(o) + for kBatch in kBatches: + ops.append(InductorROCmOp(op=o, kBatch=kBatch)) + + filtered_instances = list(filter(lambda op: self.filter_op(op), ops)) + + # NB: when using a fixed list order, most likely we will pick the subset of instances + # which are very similar to each other. Randomizing the choice seems to solve this. + random.seed(-11) + chosen_instances = ( + random.sample( + filtered_instances, + min(len(filtered_instances), config.rocm.ck_max_profiling_configs), + ) + if config.rocm.ck_max_profiling_configs + else filtered_instances + ) + log.debug( + "generated %d ck instances after filter: %s", + len(chosen_instances), + chosen_instances, + ) + return chosen_instances + + @staticmethod + def add_ck_gemm_choices( + choices, + layout, + input_nodes, + alpha=1, + beta=0, + input_reorder=None, + ): + """ + Add Composable Kernel Universal GEMM instance choices to the auto-tuning list. + """ + template = CKGemmTemplate( + input_nodes, + layout, + alpha=alpha, + beta=beta, + input_reorder=input_reorder, + ) + ops = template.gen_ops() + for op in ops: + template.maybe_append_choice( + choices, + op=op.op, + kBatch=op.kBatch, + ) + + def size_args(self): + X = self.input_nodes[0] + W = self.input_nodes[1] + Bias = ( + self.input_nodes[2] + if len(self.input_nodes) == 3 + else self.input_nodes[4] + if len(self.input_nodes) == 5 + else None + ) + Y = self.output_node + + M = X.get_size()[-2] + K = X.get_size()[-1] + N = W.get_size()[-1] + LDA = X.get_stride()[-2 if X.get_stride()[-1] == 1 else -1] + LDB = W.get_stride()[-2 if W.get_stride()[-1] == 1 else -1] + LDC = Y.get_stride()[-2 if Y.get_stride()[-1] == 1 else -1] + LDD = ( + 0 + if (Bias is None or len(Bias.get_size()) == 1) + else Bias.get_stride()[-2 if Bias.get_stride()[-1] == 1 else -1] + ) + if self.is_batched: + B = X.get_size()[0] + return B, M, N, K, LDA, LDB, LDC, LDD + else: + return M, N, K, LDA, LDB, LDC, LDD diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/compile_command.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/compile_command.py new file mode 100644 index 0000000000000000000000000000000000000000..b9cae55102b61b4cd2611055bb22269f618f1773 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/compile_command.py @@ -0,0 +1,148 @@ +# mypy: allow-untyped-defs +import logging +import os +from typing import Optional + +from torch._inductor import config +from torch._inductor.utils import is_linux + + +log = logging.getLogger(__name__) + + +def _rocm_include_paths(dst_file_ext: str) -> list[str]: + from torch.utils import cpp_extension + + rocm_include = ( + os.path.join(config.rocm.rocm_home, "include") + if config.rocm.rocm_home + else cpp_extension._join_rocm_home("include") + ) + if not config.rocm.ck_dir: + log.warning("Unspecified 
Composable Kernel include dir") + + if config.is_fbcode(): + from libfb.py import parutil + + ck_path = parutil.get_dir_path("composable-kernel-headers") + else: + ck_path = config.rocm.ck_dir or cpp_extension._join_rocm_home( + "composable_kernel" + ) + + ck_include = os.path.join(ck_path, "include") + ck_library_include = os.path.join(ck_path, "library", "include") + + # CK has to take priority over ROCm include paths + # Since CK is potentially more up-to-date + paths = [ + os.path.realpath(p) for p in (ck_include, ck_library_include, rocm_include) + ] + if dst_file_ext == "exe": + ck_utility_include = os.path.join(ck_path, "library", "src", "utility") + paths.append(os.path.realpath(ck_utility_include)) + return paths + + +def _rocm_lib_options(dst_file_ext: str) -> list[str]: + from torch.utils import cpp_extension + + rocm_lib_dir = ( + os.path.join(config.rocm.rocm_home, "lib") + if config.rocm.rocm_home + else cpp_extension._join_rocm_home("lib") + ) + hip_lib_dir = ( + os.path.join(config.rocm.rocm_home, "hip", "lib") + if config.rocm.rocm_home + else cpp_extension._join_rocm_home("hip", "lib") + ) + + opts = [ + "-include __clang_hip_runtime_wrapper.h", + f"-L{os.path.realpath(rocm_lib_dir)}", + f"-L{os.path.realpath(hip_lib_dir)}", + "-lamdhip64", + ] + if dst_file_ext == "exe": + opts += ["-lpthread", "-lstdc++"] + return opts + + +def _rocm_compiler_options() -> list[str]: + arch_list = config.rocm.arch or ["native"] + gpu_arch_flags = [f"--offload-arch={arch}" for arch in arch_list] + opts = [ + config.rocm.compile_opt_level, + "-x", + "hip", + "-std=c++17", + *gpu_arch_flags, + "-fno-gpu-rdc", + "-fPIC", + "-fvisibility=hidden", + "-mllvm", + "-amdgpu-early-inline-all=true", + "-mllvm", + "-amdgpu-function-calls=false", + "-mllvm", + "-enable-post-misched=0", + ] + if config.rocm.is_debug: + opts += ["-DDEBUG_LOG=1", "-g"] + if config.rocm.save_temps: + opts += ["--save-temps=obj"] + if config.rocm.print_kernel_resource_usage: + opts += ["-Rpass-analysis=kernel-resource-usage"] + if config.rocm.flush_denormals: + opts += ["-fgpu-flush-denormals-to-zero"] + if config.rocm.use_fast_math: + opts += ["-ffast-math"] + return opts + + +def rocm_compiler() -> Optional[str]: + if is_linux(): + if config.rocm.rocm_home: + return os.path.realpath( + os.path.join(config.rocm.rocm_home, "llvm", "bin", "clang") + ) + try: + from torch.utils import cpp_extension + + return os.path.realpath( + cpp_extension._join_rocm_home("llvm", "bin", "clang") + ) + except OSError: + # neither config.rocm.rocm_home nor env variable ROCM_HOME are set + return "clang" + return None + + +def rocm_compile_command( + src_files: list[str], + dst_file: str, + dst_file_ext: str, + extra_args: Optional[list[str]] = None, +) -> str: + include_paths = _rocm_include_paths(dst_file_ext) + lib_options = _rocm_lib_options(dst_file_ext) + compiler_options = _rocm_compiler_options() + compiler = rocm_compiler() + options = ( + compiler_options + + (extra_args or []) + + [f"-I{path}" for path in include_paths] + + lib_options + ) + src_file = " ".join(src_files) + # supported extensions: .o, .so, .exe + if dst_file_ext == "o": + options.append("-c") + elif dst_file_ext == "so": + options.append("-shared") + elif dst_file_ext == "exe": + options.append("-DGENERATE_CK_STANDALONE_RUNNER") + else: + raise NotImplementedError(f"Unsupported output file suffix {dst_file_ext}!") + return f"{compiler} {' '.join(options)} -o {dst_file} {src_file}" diff --git 
a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_benchmark_request.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_benchmark_request.py new file mode 100644 index 0000000000000000000000000000000000000000..4a08773433c3a295e3e7090ab6e3f5fc2b1c0cf1 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_benchmark_request.py @@ -0,0 +1,143 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import functools +import logging +from ctypes import byref, c_int, c_size_t, c_void_p +from typing import Any, Callable, Optional, TYPE_CHECKING, Union + +import torch +from torch._inductor import config +from torch._inductor.autotune_process import ( + BenchmarkRequest, + GPUDeviceBenchmarkMixin, + TensorMeta, +) +from torch._inductor.codecache import DLLWrapper, ROCmCodeCache + + +if TYPE_CHECKING: + from collections.abc import Iterable + + +log = logging.getLogger(__name__) + + +class ROCmBenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest): + # Important: Instances of this class have to be serializable + # across process boundaries. Do not put CUDA Tensors in here! + + def __init__( + self, + kernel_name: str, + input_tensor_meta: Union[TensorMeta, list[TensorMeta]], + output_tensor_meta: Union[TensorMeta, list[TensorMeta]], + extra_args: Iterable[Any], + source_code: str, + ) -> None: + super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args) + self.source_code = source_code + self.workspace_size: int = 0 + self.workspace: Optional[torch.Tensor] = None + self.DLL: Optional[DLLWrapper] = None + self._workspace_size_updated = False + self.hash_key: str = "" + self.source_file: str = "" + self.hash_key, self.source_file = ROCmCodeCache.write(self.source_code, "so") + + def precompile(self): + # Prepopulate code cache + # may happen in separate Threadpool + log.debug("Precompiling %s", self) + ROCmCodeCache.compile(self.source_code, "so") + if config.rocm.generate_test_runner: + ROCmCodeCache.compile(self.source_code, "exe") + log.debug("Done precompiling %s", self) + + def make_run_fn( + self, *input_tensors: torch.Tensor, out: torch.Tensor + ) -> Callable[[], None]: + self.ensure_dll_loaded() + self.update_workspace_size() + args = [c_void_p(tensor.data_ptr()) for tensor in list(input_tensors) + [out]] + size_args = [c_int(arg) for arg in self.extra_args] + log.debug( + "make_run_fn: self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s", + self.kernel_name, + self.source_file, + self.hash_key, + self.DLL, + args, + self.extra_args, + ) + stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream) + run_method = getattr(self.DLL, self.kernel_name) + workspace_ptr = c_void_p(0) + if self.workspace_size > 0: + self.workspace = torch.zeros( + (self.workspace_size + 7) // 8, + dtype=torch.float64, + device=out.device, + ) + workspace_ptr = c_void_p(self.workspace.data_ptr()) + + # Generate partial function. 
+ return functools.partial( + run_method, + *args, + *size_args, + None, # null workspace size ptr + workspace_ptr, # set workspace ptr, + stream_ptr, + ) + + def update_workspace_size(self) -> None: + if self._workspace_size_updated: + return + self.ensure_dll_loaded() + unique_input_count = len( + {meta.name for meta in self.input_tensor_meta} # noqa: set_linter + ) + args = [c_void_p(None) for _ in range(unique_input_count + 1)] + stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream) + + run_method = getattr(self.DLL, self.kernel_name) + # Retrieve workspace_size and initialize workspace. + c_workspace_size = c_size_t() + size_args = [c_int(arg) for arg in self.extra_args] + run_method( + *args, # input ptrs and output ptrs + *size_args, + byref( + c_workspace_size + ), # set workspace size ptr to retrieve workspace size + None, # null workspace ptr + stream_ptr, + ) + torch.cuda.synchronize() # shake out any CUDA errors + self.workspace_size = c_workspace_size.value + log.debug( + "update_workspace_size called: new workspace size=%d, self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s", # noqa: B950 + self.workspace_size, + self.kernel_name, + self.source_file, + self.hash_key, + self.DLL, + args, + self.extra_args, + ) + self._workspace_size_updated = True + + def ensure_dll_loaded(self): + if self.DLL is None: + self.DLL, self.hash_key, self.source_file = ROCmCodeCache.load( + self.source_code, "so" + ) + + def cleanup_run_fn(self) -> None: + if self.DLL is not None: + self.DLL.close() + self.workspace = None + + def __str__(self) -> str: + return f"{self.kernel_name=}, {self.source_file=}, {self.hash_key=}" diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_cpp_scheduling.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_cpp_scheduling.py new file mode 100644 index 0000000000000000000000000000000000000000..720509f282660be62cfeeea71edfcc1584db277c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_cpp_scheduling.py @@ -0,0 +1,99 @@ +# mypy: allow-untyped-defs +import logging +from collections.abc import Sequence +from typing import cast + +from ... import config +from ...codecache import code_hash, get_path +from ...scheduler import BaseSchedulerNode, BaseScheduling, SchedulerNode +from ...utils import get_fused_kernel_name, get_kernel_metadata, sympy_product +from ...virtualized import V +from ..common import IndentedBuffer +from .rocm_template_buffer import ROCmTemplateBuffer + + +log = logging.getLogger(__name__) + + +class ROCmCPPScheduling(BaseScheduling): + """ + Partial Scheduling implementation for ROCm C++ Kernels. + This class is intended to be used in combination with TritonScheduling, + and delegated to by CUDACombinedScheduling. + + It handles fusion decisions and ROCm C++ specific template code generation. 
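# A minimal sketch of the two-phase workspace protocol used by
# update_workspace_size() and make_run_fn() above. A plain Python callable
# stands in for the compiled kernel symbol, and ctypes.pointer() is used here
# instead of byref() only so the stand-in can write through the out-parameter;
# all values are hypothetical.
import ctypes
from ctypes import c_size_t, c_void_p

def fake_kernel(x_ptr, w_ptr, y_ptr, workspace_size_ptr, workspace_ptr, stream_ptr):
    if workspace_size_ptr:                       # phase 1: only report the required bytes
        workspace_size_ptr.contents.value = 4096
        return 0
    return 0                                     # phase 2: real launch, using workspace_ptr

ws = c_size_t()
fake_kernel(c_void_p(None), c_void_p(None), c_void_p(None),
            ctypes.pointer(ws), None, c_void_p(0))
# make_run_fn would now allocate ceil(ws.value / 8) float64 elements as the
# workspace and re-invoke with a real workspace pointer and a null size ptr.
assert ws.value == 4096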
+ """ + + def group_fn(self, sizes): + return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes) + + @staticmethod + def is_rocm_cpp_template(node: BaseSchedulerNode) -> bool: + return isinstance(node, SchedulerNode) and isinstance( + node.node, ROCmTemplateBuffer + ) + + def can_fuse_vertical( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + return False + + def define_kernel(self, src_code: str, node_schedule) -> str: + wrapper = V.graph.wrapper_code + if src_code in wrapper.src_to_kernel: + kernel_name = wrapper.src_to_kernel[src_code] + else: + fused_name = ( + get_fused_kernel_name(node_schedule, config.triton.descriptive_names) + if config.triton.descriptive_names + else "" + ) + kernel_name = "_".join(["rocm", fused_name, wrapper.next_kernel_suffix()]) + # use the original src_code as the key + wrapper.src_to_kernel[src_code] = kernel_name + src_code = src_code.replace("KERNEL_NAME", kernel_name) + + _, _, kernel_path = get_path(code_hash(src_code), "py") + + compile_wrapper = IndentedBuffer() + compile_wrapper.writeline("async_compile.rocm(r'''") + compile_wrapper.splice(src_code, strip=True) + compile_wrapper.writeline( + f"''', 'so', aot_compile={str(V.graph.aot_mode)})" + ) + + metadata_comment = f"# kernel path: {kernel_path}" + origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper) + metadata_comment += "\n" + origins + "\n" + detailed_origins + wrapper.define_kernel( + kernel_name, compile_wrapper.getvalue(), metadata_comment + ) + return kernel_name + + def codegen_template( + self, + template_node: BaseSchedulerNode, + epilogue_nodes: Sequence[BaseSchedulerNode], + prologue_nodes: Sequence[BaseSchedulerNode], + ): + """ + Codegen a ROCm template, possibly with fused epilogues + """ + assert self.is_rocm_cpp_template(template_node), ( + "Template node passed to ROCmScheduler.codegen_template must be a SchedulerNode that wraps a ROCmTemplateBuffer" + ) + template_node = cast(SchedulerNode, template_node) + _, (_numel, rnumel) = template_node.group + assert rnumel == 1 + ctb: ROCmTemplateBuffer = cast(ROCmTemplateBuffer, template_node.node) + kernel, render = ctb.make_kernel_render(ctb) + with kernel: + template_node.mark_run() + src_code = render() + + with V.set_kernel_handler(kernel): + node_schedule = [template_node] + kernel_name = self.define_kernel(src_code, node_schedule) + kernel.call_kernel(kernel_name, ctb) + V.graph.removed_buffers |= kernel.removed_buffers + self.free_buffers_in_scheduler() diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_kernel.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..2e51a96d327921669a37360cb6c4421ee0057828 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_kernel.py @@ -0,0 +1,289 @@ +# mypy: allow-untyped-defs +import logging +from collections.abc import Sequence +from typing import Any, Callable, Optional, TYPE_CHECKING, Union + +import torch._inductor.config as config +from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu +from torch._inductor.utils import do_bench_using_profiling + +from ...ir import Buffer, ChoiceCaller, IRNode, Layout, PrimitiveInfoType, TensorBox +from ...virtualized import V +from ..common import Kernel, OpOverrides, WorkspaceArg, WorkspaceZeroMode +from ..cpp_utils import CppPrinter +from .rocm_benchmark_request import ROCmBenchmarkRequest +from .rocm_template_buffer 
import ROCmTemplateBuffer +from .rocm_utils import DTYPE_TO_ROCM_TYPE + + +if TYPE_CHECKING: + from torch._inductor.codegen.rocm.rocm_template import ArgInfo, ROCmTemplate + +log = logging.getLogger(__name__) + +cexpr = CppPrinter().doprint + + +def _normalize_idx(index: int, total_length: int) -> int: + return index if index >= 0 else index + total_length + + +class ROCmKernel(Kernel): + """ + Baseclass for ROCm based Kernels + """ + + overrides = OpOverrides # type: ignore[assignment] + + +class ROCmTemplateKernel(ROCmKernel): + """ + Template kernels defined by ROCm in C++. + """ + + _EXTRA_CPP_ARGS = "size_t* workspace_size, uint8_t* workspace, hipStream_t stream" + + def __init__( + self, + kernel_name: str, + runtime_arg_info: list["ArgInfo"], + runtime_arg_values: list[Any], + ) -> None: + """ + Initializes a new instance of the ROCmTemplateKernel class. + + Args: + kernel_name (str): The name of the kernel. + """ + super().__init__() + self.kernel_name = kernel_name + # Mapping from arg name to IRNode. + self.named_nodes: dict[str, IRNode] = {} + self.runtime_arg_info = runtime_arg_info + self.runtime_arg_values = runtime_arg_values + + def get_signature(self): + return self.signature + + def def_kernel( + self, + inputs: list[IRNode], + outputs: list[IRNode], + size_args: list[str], + names_str: str = "", + input_reorder: Optional[list[int]] = None, + ) -> str: + """ + Hook called from template code to generate function definition and + needed args. + + Args: + inputs: List of input IRNodes + outputs: List of output IRNodes + names_str: Comma separated list of input + output argument names. + input_reorder: The actual order of input nodes. + e.g. The template might have input argument defined as [X, W, Bias], + and the actual input passed into this template could be [Bias, X, W]. + In this case, the `input_reorder` would be [2, 0, 1]. + """ + names = [x.strip() for x in names_str.strip().split(",")] + if len(inputs) + len(outputs) != len(names): + raise RuntimeError( + f"{len(inputs) + len(outputs)=} != {len(names)=}, {inputs=}, {outputs=}, {names=}" + ) + + if input_reorder == [2, 0, 1]: + input_reorder = [4, 0, 1, 2, 3] + + if input_reorder is not None: + assert len(inputs) == len(input_reorder) + else: + input_reorder = list(range(len(inputs))) + + for idx in input_reorder: + name = names[idx] + node = inputs[idx] + if node is not None: + self.named_nodes[name] = node + self.args.input_buffers[node.get_name()] = name + + for name, node in zip(names[len(inputs) : len(inputs) + len(outputs)], outputs): + if node is not None: + self.named_nodes[name] = node + self.args.output_buffers[node.get_name()] = name + + arg_defs, *_ = self.args.cpp_argdefs(DTYPE_TO_ROCM_TYPE) + + runtime_arg_defs = [f"{arg.ty} {arg.name}" for arg in self.runtime_arg_info] + + signature = f"int {self.kernel_name}({', '.join(arg_defs + size_args + runtime_arg_defs)},{self._EXTRA_CPP_ARGS})" + self.signature = signature + return signature + + def call_kernel( + self, + name: str, + node: "ROCmTemplateBuffer", # type: ignore[name-defined] + ) -> None: + """ + Generates code to call the kernel through V.graph.wrapper_code. + used from within torch._inductor.wrapper.PythonWrapperCodegen + + name: Name of kernel function. + node: The ROCmTemplateBuffer node which contains information about the kernel, it's fused epilogue nodes + as well as all required inputs and outputs. 
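# A small sketch of how def_kernel() above assembles the exported C signature:
# buffer arguments first, then size arguments, then any runtime arguments,
# then the fixed workspace/stream tail from _EXTRA_CPP_ARGS. The argument
# names, constness and dtypes below are hypothetical (fp16 buffers map to
# uint16_t via DTYPE_TO_ROCM_TYPE).
arg_defs = ["const uint16_t* X", "const uint16_t* W", "uint16_t* Y"]
size_args = [f"int32_t {s}" for s in ("M", "N", "K", "LDA", "LDB", "LDC", "LDD")]
runtime_arg_defs = ["const int32_t kBatch"]
extra = "size_t* workspace_size, uint8_t* workspace, hipStream_t stream"

signature = f"int ck_gemm_kernel({', '.join(arg_defs + size_args + runtime_arg_defs)},{extra})"
assert signature.startswith("int ck_gemm_kernel(const uint16_t* X")
assert signature.endswith("hipStream_t stream)")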
+ """ + wrapper = V.graph.wrapper_code + + arg_types: list[Any] + if V.graph.cpp_wrapper: + # Make sure we initialize these kernels since they're exported as + # C-style symbol names. + assert isinstance(wrapper, CppWrapperCpu) + wrapper.initialized_kernels[name] = self + # Kinda hacky because we always originally initialize name with "KERNEL_NAME" + # So, we replace with the real kernel name passed as an arg to this function. + self.signature = self.signature.replace("KERNEL_NAME", name) + _, call_args, arg_types = self.args.cpp_argdefs(DTYPE_TO_ROCM_TYPE) + else: + _, call_args, _, arg_types = self.args.python_argdefs() + + kernel_args = [] + for arg in call_args: + # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar + if V.graph.is_unspec_arg(arg): + arg = arg + ".item()" + else: + if not V.graph.cpp_wrapper: + arg = f"c_void_p({arg}.data_ptr())" + kernel_args.append(arg) + + # add size args + size_args = [ + f"{V.graph.sizevars.simplify(sarg)}" for sarg in node.template.size_args() + ] + + if V.graph.cpp_wrapper: + kernel_args.extend(size_args) + else: + kernel_args.extend(f"c_int({sarg})" for sarg in size_args) + + if V.graph.cpp_wrapper: + arg_types.extend(["int"] * len(node.template.size_args())) + + # the runtime args come right after the size args + kernel_args.extend(self.runtime_arg_values) + for arg in self.runtime_arg_info: + arg_types.append(arg.ty) + + # workspace_size ptr is NULL to mark this call is not intended for retrieving workspace_size. + # workspace_size should have already been retrieved prior to this call. + kernel_args.append("nullptr" if V.graph.cpp_wrapper else "None") + if V.graph.cpp_wrapper: + arg_types.append("size_t*") + + if node.get_workspace_size() > 0: + ws = WorkspaceArg( + count=node.get_workspace_size(), + device=V.graph.get_current_device_or_throw(), + zero_mode=WorkspaceZeroMode.UNINITIALIZED, + outer_name=WorkspaceArg.unique_name(), + ) + wrapper.generate_workspace_allocation(ws) + data_ptr = f"{ws.outer_name}.data_ptr()" + kernel_args.append( + data_ptr if V.graph.cpp_wrapper else f"c_void_p({data_ptr})" + ) + else: + ws = None + kernel_args.append("nullptr" if V.graph.cpp_wrapper else "None") + if V.graph.cpp_wrapper: + arg_types.append("uint8_t*") + wrapper.generate_kernel_call( + name, + kernel_args, + triton=False, + arg_types=arg_types, + ) + if ws: + wrapper.generate_workspace_deallocation(ws) + + +class ROCmTemplateCaller(ChoiceCaller): + """ + ROCmTemplateCaller + + This class represents a caller for ROCm template kernels. It is a subclass of ChoiceCaller. + Attributes: + name (str): The name of the caller. + category (str): The category of the caller. + bmreq (ROCmBenchmarkRequest): The benchmark request for the caller. + template_buffer (ROCmTemplateBuffer): The template buffer for the caller. 
+ """ + + def __init__( + self, + name: str, + category: str, + input_nodes: list[Buffer], + layout: Layout, + make_kernel_render: Callable[ + [ROCmTemplateBuffer, Optional[Sequence[IRNode]]], str + ], + bmreq: ROCmBenchmarkRequest, + template: "ROCmTemplate", # type: ignore[name-defined] + info_kwargs: Optional[ + dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]] + ], # type: ignore[type-arg] + ) -> None: + super().__init__(name, input_nodes, layout, description="") + self.category = category + self.make_kernel_render = make_kernel_render + self.bmreq = bmreq + self.template = template + self.info_kwargs = info_kwargs + + def precompile(self) -> None: + assert self.bmreq is not None + self.bmreq.precompile() + + def benchmark(self, *args, out) -> float: + assert self.bmreq is not None + if config.profile_bandwidth_with_do_bench_using_profiling: + algo = self.bmreq.make_run_fn(*args, out=out) + return do_bench_using_profiling(algo) + return self.bmreq.benchmark(*args, out=out) + + def __str__(self) -> str: + return f"ROCmTemplateCaller(source_file={self.bmreq.source_file}, {self.info_dict()})" + + def call_name(self) -> str: + return f"rocm_template_kernels.{self.name}" + + def hash_key(self) -> str: + return "-".join( + [ + self.category, + self.bmreq.hash_key, + ] + ) + + def info_dict(self) -> dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]: + """Information returned here is logged to the autotune log file when that is enabled.""" + return { + "backend": "ROCm", + "name": self.name, + **dict(self.info_kwargs["op"].dict_items()), # type: ignore[union-attr, index] + } + + def output_node(self) -> TensorBox: + self.bmreq.update_workspace_size() + return TensorBox.create( + ROCmTemplateBuffer( + layout=self.layout, + inputs=self.input_nodes, + make_kernel_render=self.make_kernel_render, + workspace_size=self.bmreq.workspace_size, + template=self.template, + ) + ) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_template.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_template.py new file mode 100644 index 0000000000000000000000000000000000000000..bfeb03eabc72d7cf9bce701f535e612644a806c3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_template.py @@ -0,0 +1,192 @@ +# mypy: allow-untyped-defs +import functools +import itertools +import logging +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Any, Optional +from unittest.mock import patch + +from ...autotune_process import TensorMeta +from ...ir import Buffer, IRNode, Layout +from ...utils import IndentedBuffer, unique +from ...virtualized import V +from ..common import KernelTemplate +from .rocm_benchmark_request import ROCmBenchmarkRequest +from .rocm_kernel import ROCmTemplateCaller, ROCmTemplateKernel +from .rocm_template_buffer import ROCmTemplateBuffer +from .rocm_utils import DTYPE_TO_ROCM_TYPE + + +log = logging.getLogger(__name__) + + +# FIXME: unify with the CUDA version +@dataclass(frozen=True) +class ArgInfo: + name: str + ty: str + + +class ROCmTemplate(KernelTemplate): + index_counter = itertools.count() + gfx9_threads_per_warp = 64 + + def __init__( + self, + name: str, + input_nodes: list[Buffer], + layout: Layout, + input_reorder: Optional[list[int]] = None, + ) -> None: + """ + + Baseclass for ROCm C++ Templates, derived from KernelTemplate. Not to be instantiated directly. + + Args: + name (str): The name of the ROCmTemplate object. 
+ input_nodes (List[IRNode]): A list of input IRNodes. + layout (Layout): The layout of the output buffer / tensor. + input_reorder (Optional[List[int]]): An optional list that specifies the order of the input nodes. + + """ + super().__init__(name) + self.input_nodes = input_nodes + self.output_node: Buffer = Buffer(name="buf_out", layout=layout) + self.input_reorder = input_reorder + self.layout = layout + + def generate( # type: ignore[override] + self, + **kwargs, + ) -> ROCmTemplateCaller: + """ + Generates the ROCm template caller object for the given GEMM template and operation. This ROCmTemplateCaller + may be used to call and benchmark the generated ROCm kernel in a standalone manner to enable Autotuning. + + Args: + kwargs: Additional keyword arguments. + + Returns: + A ROCmTemplateCaller object representing the generated ROCm template caller. + """ + kernel_name = f"rocm_{self.name}" + kernel_hash_name = f"rocm_{self.name}_{next(self.index_counter)}" + with ( + patch.object(V.graph, "get_dtype", self._fake_get_dtype(self.output_node)), + ROCmTemplateKernel( + kernel_name=kernel_name, + runtime_arg_info=self.get_runtime_arg_info(), + runtime_arg_values=self.get_runtime_arg_values(**kwargs), + ) as kernel, + ): + code = self.render(kernel=kernel, **kwargs) + _, call_args, _, _ = kernel.args.python_argdefs() + log.debug("Autotune key: %s, Generated Code:\n%s", kernel_hash_name, code) + log.debug( + "Args: cpp_argdefs: %s, python_argdefs: %s", + kernel.args.cpp_argdefs(DTYPE_TO_ROCM_TYPE), + kernel.args.python_argdefs(), + ) + + input_reorder = ( + self.input_reorder + if self.input_reorder is not None + else list(range(len(self.input_nodes))) + ) + expected_args = list( + unique(self.input_nodes[idx].get_name() for idx in input_reorder) + ) + expected_args.extend([self.output_node.get_name()]) + assert list(call_args)[: len(expected_args)] == expected_args, ( + call_args, + expected_args, + ) + + size_args = ( + self.size_args() if hasattr(self, "size_args") else () + ) # subclass should define def size_args() + size_args_ints = [ + V.graph.sizevars.size_hint(arg) for arg in size_args + ] # resolve to ints for benchmarking + # The runtime args come right after the size args + runtime_args = self.get_runtime_arg_values(**kwargs) + extra_args = size_args_ints + runtime_args + bmreq = ROCmBenchmarkRequest( + kernel_name=kernel_name, + input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes), + output_tensor_meta=TensorMeta.from_irnodes(self.output_node), + extra_args=extra_args, + source_code=code, + ) + + def make_kernel_render( + template_node: ROCmTemplateBuffer, + epilogue_nodes: Optional[Sequence[IRNode]] = None, + ): + kernel = ROCmTemplateKernel( + kernel_name="KERNEL_NAME", + runtime_arg_info=self.get_runtime_arg_info(), + runtime_arg_values=self.get_runtime_arg_values(**kwargs), + ) + render = functools.partial( + self.render, + kernel=kernel, + template_buffer_node=template_node, + epilogue_nodes=epilogue_nodes, + **kwargs, # includes "op" argument in case of CUTLASSGemmTemplate + ) + return kernel, render + + return ROCmTemplateCaller( + kernel_hash_name, + self.name, + self.input_nodes, + self.output_node.get_layout(), + make_kernel_render, + bmreq, + self, + kwargs, + ) + + def header(self) -> IndentedBuffer: + res = IndentedBuffer() + res.splice( + """ + #include + #include + #include + #include + #include + """ + ) + return res + + def globals(self) -> IndentedBuffer: + res = IndentedBuffer() + res.splice( + """ + // We compile all models with -fvisibility=hidden. 
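# A hedged sketch of how a concrete template can surface an extra runtime
# argument (for example a split-K factor) through the get_runtime_arg_info /
# get_runtime_arg_values hooks consumed by generate() above; the subclass and
# the argument name are hypothetical. ArgInfo is re-declared locally so the
# snippet is self-contained, mirroring the dataclass defined in this file.
from dataclasses import dataclass
from typing import Any

@dataclass(frozen=True)
class _ArgInfo:
    name: str
    ty: str

class _SplitKTemplateSketch:
    def get_runtime_arg_info(self) -> list[_ArgInfo]:
        return [_ArgInfo("kBatch", "const int32_t")]

    def get_runtime_arg_values(self, **kwargs) -> list[Any]:
        # generate(**kwargs) forwards the same kwargs, so values are looked up by name
        return [kwargs[arg.name] for arg in self.get_runtime_arg_info()]

assert _SplitKTemplateSketch().get_runtime_arg_values(kBatch=4) == [4]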
Any symbols that need to be + // exposed in the final shared library must be declared with PT_EXPORT to make + // them visible. + #ifdef __GNUC__ // Applies to any compiler with GNU extensions (clang and g++) + #define PT_EXPORT __attribute__((__visibility__("default"))) + #else + #ifdef _WIN32 + #define PT_EXPORT __declspec(dllexport) + #else + #define PT_EXPORT + #endif + #endif + """ + ) + return res + + def render(self, **kwargs) -> str: + raise NotImplementedError + + def get_runtime_arg_info(self) -> list[ArgInfo]: + return [] + + def get_runtime_arg_values(self, **kwargs) -> list[Any]: + return [] diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_template_buffer.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_template_buffer.py new file mode 100644 index 0000000000000000000000000000000000000000..67b929556211ddfa33bac9f9da1761098cce7cb9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_template_buffer.py @@ -0,0 +1,27 @@ +from collections.abc import Sequence +from typing import Callable, TypeVar +from typing_extensions import ParamSpec + +from ...ir import Buffer, Layout, TemplateBuffer + + +_P = ParamSpec("_P") +_T = TypeVar("_T") + + +class ROCmTemplateBuffer(TemplateBuffer): + def __init__( + self, + layout: Layout, + inputs: Sequence[Buffer], + make_kernel_render: Callable[_P, _T], + workspace_size: int, + template: "ROCmTemplate", # type: ignore[name-defined] # noqa: F821 + ) -> None: + super().__init__(layout, inputs, make_kernel_render) + # Global memory (in bytes) needed for this template. + self.workspace_size = workspace_size + self.template = template + + def get_workspace_size(self) -> int: + return self.workspace_size if self.workspace_size is not None else 0 diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_utils.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4fce90fa760dd80363fefde8489548890424abe3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/rocm/rocm_utils.py @@ -0,0 +1,15 @@ +# mypy: allow-untyped-defs + + +import torch + +from ..cpp_utils import DTYPE_TO_CPP + + +DTYPE_TO_ROCM_TYPE = { + **DTYPE_TO_CPP, + torch.float16: "uint16_t", + torch.float8_e4m3fnuz: "uint8_t", + torch.float8_e5m2fnuz: "uint8_t", + torch.bfloat16: "uint16_t", +} diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/simd.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/simd.py new file mode 100644 index 0000000000000000000000000000000000000000..9425eaa97ff79d137d7446401561cbbbc2e7c0ec --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/simd.py @@ -0,0 +1,2462 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import collections +import contextlib +import dataclasses +import functools +import itertools +import logging +import math +import operator +import textwrap +from collections import Counter +from typing import Any, Callable, Generic, no_type_check, Optional, TYPE_CHECKING, Union +from typing_extensions import TypeVar + +import sympy + +import torch +import torch._logging +from torch._inductor.tiling_utils import analyze_memory_coalescing +from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols +from torch.fx.immutable_collections import immutable_dict +from torch.utils._ordered_set import OrderedSet +from 
torch.utils._sympy.functions import FloorDiv, Identity, ModularIndexing +from torch.utils._sympy.symbol import ( + free_symbol_is_type, + prefix_str, + symbol_is_type, + SymT, +) + +from ..._dynamo.utils import counters +from .. import config, ir, scheduler +from ..analyze_preserves_zero_mask import prologue_preserves_zero_mask +from ..codecache import code_hash +from ..dependencies import MemoryDep, StarDep, WeakDep + + +if TYPE_CHECKING: + from ..ir import IRNode + +from ..optimize_indexing import indexing_dtype_strength_reduction +from ..runtime.runtime_utils import green_text, yellow_text +from ..scheduler import BaseSchedulerNode, BaseScheduling, WhyNoFuse +from ..utils import ( + cache_on_self, + expr_fits_within_32bit, + get_dtype_size, + IndentedBuffer, + Placeholder, + prefix_is_reduction, + set_kernel_post_grad_provenance_tracing, + sympy_index_symbol, + sympy_product, + sympy_subs, + unique, +) +from ..virtualized import ops, OpsWrapper, V +from .block_analysis import BlockPatternMatcher +from .common import CSEVariable, index_prevent_reordering, Kernel, PythonPrinter +from .multi_kernel import MultiKernel +from .simd_kernel_features import ( + DisableReduction, + EnableReduction, + NodeScheduleEntry, + NodeScheduleMarker, + SIMDKernelFeatures, +) + + +if TYPE_CHECKING: + from collections.abc import Iterable, Iterator, Sequence + + from torch._inductor.tiling_utils import CoalesceVarAnalysis + + +log = logging.getLogger(__name__) +perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints") +schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") +fusion_log = torch._logging.getArtifactLogger(__name__, "fusion") + + +pexpr = PythonPrinter().doprint + +all_prefixes = OrderedSet(["z", "y", "x", "r0_", "r1_"]) + + +def get_max_tiles(default: int = 2) -> int: + max_tiles = torch._inductor.config.triton.max_tiles + return max_tiles if max_tiles is not None else default + + +@dataclasses.dataclass +class IterationRanges: + """ + Each range tree represents multiple sets of iteration indexing + in a single tiled dimension in the output kernel. + + If you have two loops ranges one (4, 3, 2) and another (4, 6), + then the range tree will be: + 4 (i0) + 3 (i1) 6 (i3) + 2 (i2) + Where i0 is shared between both loops, but then the split into + different indexing vars. All loop ranges must iterate over + the same number of elements. + """ + + def __init__( + self, + name: str, + var_list: list[sympy.Symbol], + var_ranges: dict[sympy.Symbol, sympy.Expr], + numel: sympy.Expr, + prefix: str, + *, + kernel: SIMDKernel, + divisor=sympy.S.One, + length=sympy.S.One, + root: IterationRangesRoot, + ) -> None: + super().__init__() + self.name = name + self.var_list = var_list + self.var_ranges = var_ranges + self.numel = numel + self.prefix = prefix + self.divisor = divisor + self.length = length + self.kernel = kernel + self.root = root + + @property + @cache_on_self + @no_type_check # https://github.com/python/mypy/issues/17184 + def is_reduction(self) -> bool: + return prefix_is_reduction(self.prefix) + + def symbol(self) -> sympy.Symbol: + return sympy_index_symbol(self.name) + + @property + @cache_on_self + @no_type_check + def symt(self) -> SymT: + prefix_to_symt = {prefix: symt for symt, prefix in prefix_str.items()} + return prefix_to_symt[self.prefix] + + +class IterationRangesRoot(IterationRanges): + """ + Root of a iteration range tree that represents a single + tiled dimension in the output kernel. 
It contains multiple + sets of iteration represented with IterationRangesEntry. + """ + + def __init__( + self, + name: str, + numel: sympy.Expr, + prefix: str, + index: int, + kernel: SIMDKernel, + pid_cache: Optional[dict[str, str]] = None, + *, + is_loop: bool, + tensor_dim: Optional[int], + grid_dim: Optional[int], + has_zdim: bool, + ) -> None: + if pid_cache is None: + pid_cache = {} + super().__init__( + name=name, + var_list=[], + var_ranges={}, + numel=numel, + prefix=prefix, + kernel=kernel, + root=self, + ) + self.index = index + # Store all the nodes in one flat list + self.nodes: dict[sympy.Expr, IterationRangesEntry] = {} + # This is for re-ordering program ID in triton mm template + # pid_cache["tl.program_id(0)"] = pid_m + self.pid_cache: dict[str, str] = pid_cache + + # True if the dimension is implemented as a single program looping over + # the full dimension (currently only used for non-persistent reduction) + assert not is_loop or (self.is_reduction and grid_dim is None) + self.is_loop = is_loop + # Index of corresponding dimension on triton tensors + self.tensor_dim = tensor_dim + # Index of corresponding dimension in the triton grid + self.grid_dim = grid_dim + self.has_zdim = has_zdim + + def __repr__(self) -> str: + return f"IterationRangesRoot({self.name!r}, {self.numel}, ...)" + + def cache_clear(self) -> None: + for node in self.nodes.values(): + node.cache_clear() + + def index_sym(self) -> sympy.Symbol: + return sympy_index_symbol(f"{self.prefix}index") + + def lookup(self, divisor: sympy.Expr, length: sympy.Expr) -> IterationRangesEntry: + """ + Lookup a given RangeTreeEntry, creating it if needed + """ + if V.graph.sizevars.statically_known_equals(divisor * length, self.numel): + expr = FloorDiv(self.index_sym(), divisor) + else: + expr = ModularIndexing(self.index_sym(), divisor, length) + + if expr not in self.nodes: + node = IterationRangesEntry( + f"{self.prefix}{next(V.kernel.iter_vars_count)}", + divisor, + length, + expr, + self, + ) + V.kernel.range_tree_nodes[node.symbol()] = node + self.var_list.append(node.symbol()) + self.var_ranges[node.symbol()] = length + self.nodes[expr] = node + return self.nodes[expr] + + def construct_entries( + self, lengths: list[sympy.Expr] + ) -> list[IterationRangesEntry]: + divisor = sympy.S.One + itervars = [] + for length in reversed(lengths): + itervars.append(self.lookup(divisor, length)) + divisor = divisor * length + return [*reversed(itervars)] + + def construct(self, lengths: list[sympy.Expr]) -> list[sympy.Symbol]: + return [e.symbol() for e in self.construct_entries(lengths)] + + def vars_and_sizes( + self, index: sympy.Expr + ) -> tuple[list[sympy.Symbol], list[sympy.Expr]]: + """Figure out vars from this tree used in index""" + + def get_sort_key(x: IterationRangesEntry) -> tuple[int, bool]: + """ + Gets the key for sorting nodes. When two nodes have the + same divisor, the node with length as 1 should be handled + first so the current divisor is not changed after multiplied + node.length. Returns `not length_is_one_hint` for ascending + sort. 
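+            For instance (illustrative), two entries sharing a divisor with
+            lengths 1 and 8 map to sort keys (divisor, False) and (divisor, True),
+            so the length-1 entry is visited first by the ascending sort.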
+ """ + divisor_hint = V.graph.sizevars.size_hint( + x.divisor, fallback=config.unbacked_symint_fallback + ) + length_is_one_hint = ( + V.graph.sizevars.size_hint( + x.length, fallback=config.unbacked_symint_fallback + ) + == 1 + ) + return (divisor_hint, not length_is_one_hint) + + nodes = [V.kernel.range_tree_nodes.get(s) for s in index.free_symbols] + nodes = [n for n in nodes if n and n.prefix == self.prefix] + nodes.sort(key=lambda x: get_sort_key(x)) + divisor = sympy.S.One + index_vars = [] + sizes = [] + + def add(node): + nonlocal divisor + index_vars.append(node.symbol()) + sizes.append(node.length) + divisor = divisor * node.length + + for node in nodes: + if not V.graph.sizevars.statically_known_equals(node.divisor, divisor): + # fill in unused index var + add(self.lookup(divisor, FloorDiv(node.divisor, divisor))) + divisor = node.divisor + add(node) + if not V.graph.sizevars.statically_known_equals(self.numel, divisor): + # fill in unused index var + add(self.lookup(divisor, FloorDiv(self.numel, divisor))) + + return [*reversed(index_vars)], [*reversed(sizes)] + + +class IterationRangesEntry(IterationRanges): + def __init__( + self, + name: str, + divisor: sympy.Expr, + length: sympy.Expr, + expr: sympy.Expr, + parent: IterationRanges, + ) -> None: + super().__init__( + name=name, + numel=parent.numel / length, + var_list=parent.var_list, + var_ranges=parent.var_ranges, + prefix=parent.prefix, + divisor=divisor, + length=length, + kernel=parent.kernel, + root=parent.root, + ) + self.parent = parent + self.codegen = functools.lru_cache(None)(self._codegen) + self.expr = expr + + def __repr__(self) -> str: + return f"IterationRangesEntry({self.name}, {self.divisor}, {self.length}, {self.expr}, {self.var_ranges})" + + def set_name(self, name: str) -> None: + self.codegen = lambda: name # type: ignore[assignment] + self.codegen.cache_clear = lambda: None # type: ignore[method-assign] + self.name = name + + def cache_clear(self) -> None: + self.codegen.cache_clear() + + def _codegen(self) -> str: + V.kernel.codegen_iteration_ranges_entry(self) + return self.name + + def precomputed_args(self) -> list[sympy.Expr]: + # for dynamic shapes, find parts of indexing expressions that have to be precomputed + precomputed_args: list[sympy.Expr] = [] + if isinstance(self.expr, sympy.Symbol): + return precomputed_args + assert isinstance(self.expr, (FloorDiv, ModularIndexing)), type(self.expr) + for arg in self.expr.args[1:]: + if not isinstance(arg, (sympy.Integer, sympy.Symbol)): + symbols = arg.free_symbols + if len(symbols) > 0 and all( + symbol_is_type(s, SymT.SIZE) for s in symbols + ): + precomputed_args.append(arg) + return precomputed_args + + def __hash__(self) -> int: + return hash(self.name) + + def __eq__(self, other: object) -> bool: + assert isinstance(other, IterationRangesEntry) + return self.name == other.name + + +def constant_repr(value: Union[int, float]) -> str: + if value == float("inf"): + return 'float("inf")' + elif value == float("-inf"): + return 'float("-inf")' + elif math.isnan(value): + return 'float("nan")' + return repr(value) + + +CSEVariableType = TypeVar("CSEVariableType", bound=CSEVariable, default=CSEVariable) + + +class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]): + """ + Common base class for Triton/Halide codegen which both use flattened indexing rather than loop nests. 
+ """ + + sexpr: Callable[[sympy.Expr], str] = pexpr + kexpr: Callable[[sympy.Expr], str] + allow_block_ptr: bool = False + kernel_name: str + + def __init__( + self, + tiling: dict[str, sympy.Expr], + features: SIMDKernelFeatures, + pid_cache: Optional[dict[str, str]] = None, + override_persistent_reduction: Optional[bool] = None, + override_cooperative_reduction: Optional[bool] = None, + tiling_scores: Optional[dict[str, sympy.Expr]] = None, + ) -> None: + if pid_cache is None: + pid_cache = {} + super().__init__() + self.features = features + self.mutations = features.get_mutations() + self.body = IndentedBuffer() + self.indexing_code = IndentedBuffer() + self.numels = { + prefix: V.graph.sizevars.simplify(val) for prefix, val in tiling.items() + } + self.range_trees: list[IterationRangesRoot] = [] + self.range_tree_nodes: dict[sympy.Symbol, IterationRangesEntry] = {} + self.iter_vars_count = itertools.count() + self.inside_reduction = features.is_reduction() + self.cooperative_reduction: bool = ( + override_cooperative_reduction + if override_cooperative_reduction is not None + else self.should_use_cooperative_reduction() + ) + self.tiling_scores: Optional[dict[str, sympy.Expr]] = tiling_scores + self.persistent_reduction: bool = ( + override_persistent_reduction + if override_persistent_reduction is not None + else self.should_use_persistent_reduction() + ) + self.no_x_dim = self.want_no_x_dim() + self.code_hash: Optional[str] = None + + # define this in a closure to make cache local to object + @functools.cache + def simplify_indexing(index: sympy.Expr): + index = V.graph.sizevars.simplify_with_ranges(index, self.var_ranges()) + for tree in self.range_trees: + index = self.combine_contiguous_dims(index, tree) + + return self.combine_modular_indexing_pairs(index) + + self.simplify_indexing = simplify_indexing + self.initialize_range_tree(pid_cache) + + @property + @cache_on_self + @no_type_check # https://github.com/python/mypy/issues/17184 + def num_reduction_dims(self) -> int: + return sum(prefix_is_reduction(prefix) for prefix in self.numels) + + def dtype_to_str(self, dtype: torch.dtype) -> str: + raise NotImplementedError + + def get_index_dtype_as_torch_dtype(self) -> torch.dtype: + return self.features.select_index_dtype() + + @property + def index_dtype(self) -> str: + return self.dtype_to_str(self.get_index_dtype_as_torch_dtype()) + + def want_no_x_dim(self) -> bool: + return False + + def construct_range_trees( + self, + pid_cache: Optional[dict[str, str]], + inside_reduction: bool, + is_reduction: bool, + numels: dict[str, sympy.Expr], + no_x_dim: bool, + ) -> list[IterationRangesRoot]: + active_prefixes = OrderedSet( + prefix for prefix in all_prefixes if prefix in numels + ) + no_r_dim = not inside_reduction or not is_reduction + + def filtered_index_map(seq, mask) -> dict[Any, int]: + return { + val: idx for idx, val in enumerate(val for val in seq if val in mask) + } + + grid_dims = ["x", "y", "z"] + pointwise_tensor_dims = list(reversed(grid_dims)) + reduction_dims = ["r0_", "r1_"] + if no_x_dim: + tensor_dims = reduction_dims + elif no_r_dim: + tensor_dims = pointwise_tensor_dims + else: + tensor_dims = pointwise_tensor_dims + reduction_dims + + # Filter out unused tensor dims. + # Convert to dicts for O(1) index lookup. 
+ tensor_dim_map = filtered_index_map(tensor_dims, active_prefixes) + grid_dim_map = filtered_index_map(grid_dims, all_prefixes) + + range_trees = [] + for i, prefix in enumerate(active_prefixes): + is_reduction = prefix_is_reduction(prefix) + tensor_dim = tensor_dim_map.get(prefix) + grid_dim = grid_dim_map.get(prefix) + index = i if grid_dim is None else grid_dim + range_trees.append( + IterationRangesRoot( + f"{prefix}index", + numels[prefix], + prefix, + index, + self, # type: ignore[arg-type] + pid_cache=pid_cache, + is_loop=is_reduction and not self.persistent_reduction, + tensor_dim=tensor_dim, + grid_dim=grid_dim, + has_zdim="z" in numels, + ) + ) + return range_trees + + def initialize_range_tree(self, pid_cache: dict[str, str]) -> None: + range_trees = self.construct_range_trees( + pid_cache, + self.inside_reduction, + self.features.is_reduction(), + self.numels, + self.no_x_dim, + ) + self.range_trees.extend(range_trees) + + def finalize_indexing(self, indices: Sequence[sympy.Expr]) -> None: + """ + Hook called right before codegen with every index that will be + used in the fused kernel. + """ + + def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable) -> None: + prior = self.inside_reduction + self.inside_reduction = False + try: + return self.store(name, index, value) + finally: + self.inside_reduction = prior + + def should_use_cooperative_reduction(self) -> bool: + return False # defined in subclass + + def should_use_persistent_reduction(self) -> bool: + return False # defined in subclass + + def var_ranges(self) -> dict[sympy.Symbol, sympy.Expr]: + return dict( + itertools.chain.from_iterable( + tree.var_ranges.items() for tree in self.range_trees + ) + ) + + def triton_tensor_ndim(self) -> int: + return sum(int(tree.tensor_dim is not None) for tree in self.range_trees) + + def indexing_size_str(self, i: int) -> str: + sizes = ["None"] * self.triton_tensor_ndim() + sizes[i] = ":" + return f"[{', '.join(sizes)}]" + + def dense_size_list(self) -> list[str]: + sizes = ["1"] * self.triton_tensor_ndim() + for tree in self.range_trees: + if tree.tensor_dim is None: + continue + + if not tree.is_reduction or self.inside_reduction: + sizes[tree.tensor_dim] = f"{tree.prefix.upper()}BLOCK" + return sizes + + def dense_size_str(self) -> str: + sizes = self.dense_size_list() + return f"[{', '.join(sizes)}]" + + def combine_modular_indexing_pairs(self, index: sympy.Expr) -> sympy.Expr: + if not isinstance(index, ModularIndexing): + return index + x = index.args[0] + if (tree_node := self.range_tree_nodes.get(x)) is None: + return index + new_index = sympy_subs(index, {x: tree_node.expr}) + new_index = V.graph.sizevars.combine_modular_indexing_pairs(new_index) + # the index now contains xindex/etc, which is nonstandard, fix it up + return sympy_subs( + new_index, + { + tree_node.root.index_sym(): tree_node.root.lookup( + sympy.S.One, tree_node.root.numel + ).symbol() + }, + ) + + def combine_contiguous_dims( + self, index: sympy.Expr, tree: IterationRangesRoot + ) -> sympy.Expr: + if expand_res := V.graph.sizevars.expand_floor_div(index): + new_index, denominator = expand_res # type: ignore[misc] + return FloorDiv(self._combine_contiguous_dims(new_index, tree), denominator) + else: + return self._combine_contiguous_dims(index, tree) + + def _combine_contiguous_dims( + self, index: sympy.Expr, tree: IterationRangesRoot + ) -> sympy.Expr: + """ + More aggressive simplification to merge contiguous dims + """ + if isinstance(index, (sympy.Integer, sympy.Symbol)): + 
return index + index_vars, sizes = tree.vars_and_sizes(index) + if len(sizes) <= 1: + return index + new_sizes, reindex, _prune = V.graph.sizevars._simplify_loops( + index_vars, sizes, index_prevent_reordering([index], index_vars, sizes) + ) + if new_sizes == sizes: + return index + new_index_vars = tree.construct(new_sizes) + new_index = sympy_subs(index, dict(zip(index_vars, reindex(new_index_vars)))) + return new_index + + def disable_reduction(self) -> contextlib.AbstractContextManager[None]: + should_flush = self.range_trees[-1].is_loop or self.cooperative_reduction + + @contextlib.contextmanager + def ctx(): + if not self.features.is_reduction(): + assert not self.inside_reduction + yield + return + if should_flush: + # calling codegen_body() will flush all the pending buffers + # and write out a reduction loop + self.codegen_body() + self.inside_reduction = False + try: + yield + if should_flush: + # flush out any code before opening the next loop + self.codegen_body() + finally: + self.inside_reduction = True + + return ctx() + + def set_ranges(self, *lengths: sympy.Expr) -> list[sympy.Symbol]: + assert len(lengths) == len(self.range_trees) + return [ + ranges.construct(length) + for length, ranges in zip(lengths, self.range_trees) + ] + + @staticmethod + def _split_iteration_ranges( + groups: Iterable[sympy.Expr], lengths: Sequence[Sequence[sympy.Expr]] + ) -> tuple[ + list[list[sympy.Expr]], list[list[Callable[[list[sympy.Expr]], sympy.Expr]]] + ]: + # Special case: if a node's sizes are ([], []), there's nothing to split. + if all(len(length) == 0 for length in lengths): + return [[] for group in groups], [] + + sv = V.graph.sizevars + new_ranges: list[list[sympy.Expr]] = [[] for _ in groups] + remaining = [sv.simplify(g) for g in groups] + var_count = itertools.count() + + def add_range(i: int, expr: sympy.Expr) -> int: + expr = sv.simplify(expr) + if not sv.statically_known_multiple_of(remaining[i], expr): + raise CantSplit + # guard on the last item out + remaining[i] = FloorDiv(remaining[i], expr) + new_ranges[i].append(expr) + return next(var_count) + + def make_combined( + size: sympy.Expr, idx1: int, idx2: int + ) -> Callable[[list[sympy.Expr]], sympy.Expr]: + def getter(flat_vars: list[sympy.Expr]) -> sympy.Expr: + return size * flat_vars[idx1] + flat_vars[idx2] + + return getter + + return_getters_groups = [] + current_group = 0 + for length_group in lengths: + return_getters = [] + for size in length_group: + if sv.statically_known_equals(size, 1): # type: ignore[arg-type] + return_getters.append(lambda _: sympy.S.Zero) + continue + + while current_group < len(remaining) and sv.statically_known_equals( + remaining[current_group], + 1, # type: ignore[arg-type] + ): + # scroll to next group with remaining elements + current_group += 1 + + if current_group + 1 < len(remaining) and sv.statically_known_gt( + size, remaining[current_group] + ): + # need to break size in two + if not sv.statically_known_multiple_of( + size, remaining[current_group] + ): + raise CantSplit + + size1 = remaining[current_group] + size2 = FloorDiv(size, remaining[current_group]) + return_getters.append( + make_combined( + size2, + add_range(current_group, size1), + add_range(current_group + 1, size2), + ) + ) + else: + if current_group < len(remaining): + return_getters.append( + operator.itemgetter(add_range(current_group, size)) + ) + return_getters_groups.append(return_getters) + + assert all(V.graph.sizevars.size_hint(s) == 1 for s in remaining), ( + f"failed to set ranges {remaining} 
{lengths}" + ) + + return new_ranges, return_getters_groups + + @classmethod + def prepare_split_iteration_lengths( + cls, + groups: Iterable[sympy.Expr], + lengths: Sequence[Sequence[sympy.Expr]], + reduction_numel: sympy.Expr = sympy.S.One, + ) -> Sequence[Sequence[sympy.Expr]]: + "Fill in the reduction numel of lengths if missing" + sizevars = V.graph.sizevars + if len(lengths[1]) == 0 and ( + not sizevars.statically_known_equals(reduction_numel, sympy.S.One) + and sizevars.statically_known_equals( + sympy_product(groups), + sympy_product(lengths[0]) * reduction_numel, + ) + ): + return (lengths[0], [reduction_numel]) + + return lengths + + @classmethod + def is_compatible( + cls, + groups: Iterable[sympy.Expr], + lengths: Sequence[Sequence[sympy.Expr]], + reduction_numel: sympy.Expr = sympy.S.One, + ) -> bool: + lengths = cls.prepare_split_iteration_lengths(groups, lengths, reduction_numel) + + try: + cls._split_iteration_ranges(groups, lengths) + return True + except CantSplit: + return False + + def split_and_set_ranges( + self, lengths: Sequence[Sequence[sympy.Expr]] + ) -> list[list[sympy.Expr]]: + tiling = {rt.prefix: rt.numel for rt in self.range_trees} + if not self.inside_reduction: + for prefix in tiling: + if prefix_is_reduction(prefix): + tiling[prefix] = sympy.S.One + + groups = [*tiling.values()] + return self.map_kernel_groups_to_node_sizes(groups, lengths, self.set_ranges) + + @classmethod + def map_kernel_groups_to_node_sizes( + cls, + groups: Sequence[sympy.Expr], + lengths: Sequence[Sequence[sympy.Expr]], + set_ranges, + ) -> list[list[sympy.Expr]]: + """ + We may want to fuse `for i0 in s0*s1` into a tiled kernel with groups (s0, s1). + + To do this we need to split up the iteration space of i0 into something like: + for i1 in s0: + for i2 in s1: + i0 = i1*s1 + i2 + .... + + This function matches and resplits lengths to the groups of + this kernel to enable tiled + non-tiled fusions. + """ + if len(lengths) == len(groups) and all( + V.graph.sizevars.simplify(sympy_product(x) - g) == 0 + for x, g in zip(lengths, groups) + ): + return set_ranges(*lengths) + + new_ranges, return_getters_groups = cls._split_iteration_ranges(groups, lengths) + itervars = [*itertools.chain.from_iterable(set_ranges(*new_ranges))] + return [[fn(itervars) for fn in fns] for fns in return_getters_groups] + + def is_indirect_indexing(self, index: sympy.Expr) -> bool: + # tmpX means indirect indexing + return free_symbol_is_type(index, SymT.TMP) + + def is_broadcasted(self, index: sympy.Expr) -> bool: + # Note. This may not be correct when there is indirect indexing + if self.is_indirect_indexing(index): + return False + + index_numels = [1] * len(self.numels) + for symbol in index.free_symbols: + if symbol not in self.range_tree_nodes: + # Non-iterated variables, e.g. strides + continue + entry = self.range_tree_nodes[symbol] # type: ignore[index] + assert isinstance(entry.parent, IterationRangesRoot) + index_numels[entry.parent.index] *= entry.length + + # If the index variables only iterate over a subset of the kernel + # numels, then it must be broadcasted. + simplify = V.graph.sizevars.simplify + return any( + simplify(idx_range) != simplify(iter_range) # type: ignore[arg-type] + for idx_range, iter_range in zip(index_numels, self.numels.values()) + ) + + def index_to_str(self, index: sympy.Expr) -> str: + """ + Convert an index expr to a string that can be used in output code. + e.g. a sympy expression "s2" may actually appear as "ks1" in the generated kernel. 
+ + Index expressions often need to be passed in as arguments to the triton kernel. + Rename_indexing and codegen_indexing keep track of the needed indices and add + new parameters to the function signature. + """ + if isinstance(index, list): + return f"[{', '.join(map(self.index_to_str, index))}]" + return self.kexpr(self.rename_indexing(index)) # type: ignore[call-arg] + + def prepare_indexing( + self, + index: sympy.Expr, + ) -> sympy.Expr: + index = self.simplify_indexing(index) + index = sympy_subs(index, V.graph.sizevars.precomputed_replacements) + # if simple replacements didn't get rid of floor/ceil, try full subs + if len(index.atoms(sympy.floor)) or len(index.atoms(sympy.ceiling)): + index = index.subs(V.graph.sizevars.precomputed_replacements) + # last resort, if no range vars are in the expr, hoist it + # TODO instead of trying to blindly find complicated exprs, we should hoist the + # inputs/outputs sizes and strides, but at the time indexing is generated + # kernel inputs and outputs are not set yet, we'd need a deeper refactor + # to do it this way + + if len(index.atoms(sympy.ceiling)): + for a in index.atoms(sympy.ceiling): + # for nested exprs, atoms yields top level first (?) + # so if everything goes fine, lower level replacements will come up empty + symbols = a.free_symbols + if len(symbols) > 0 and all( + symbol_is_type(s, (SymT.SIZE, SymT.PRECOMPUTED_SIZE)) + for s in symbols + ): + replacements = {a: V.graph.sizevars.lookup_precomputed_size(a)} + index = sympy_subs(index, replacements) + + simp_index = self.simplify_indexing(index) + + # Now that we are done simplifying we can unwrap Identity so that downstream handling + # for its contained expression will work. previously, tl.full wrapping of sympy.Integer + # would not occur + simp_index = ( + simp_index if not isinstance(simp_index, Identity) else simp_index.args[0] + ) + + return self.codegen_indexing(simp_index) + + def active_range_trees(self) -> list[IterationRangesRoot]: + return [ + t for t in self.range_trees if not t.is_reduction or self.inside_reduction + ] + + def codegen_indexing(self, expr: sympy.Expr) -> sympy.Expr: + expr = V.graph.sizevars.simplify_with_ranges(expr, self.var_ranges()) + for sym in sorted(expr.free_symbols, key=str): + if sym in self.range_tree_nodes: + # if indexing expression is complicated, we precompute it on the host side + # and send the result as a kernel argument + replacements = {} + for ps in self.range_tree_nodes[sym].precomputed_args(): # type: ignore[index] + replacements[ps] = V.graph.sizevars.lookup_precomputed_size(ps) + if len(replacements) > 0: + self.range_tree_nodes[sym].expr = sympy_subs( # type: ignore[index] + self.range_tree_nodes[sym].expr, + replacements, # type: ignore[index] + ) + self.range_tree_nodes[sym].codegen() # type: ignore[index] + return expr + + def codegen_nan_check(self) -> None: + raise NotImplementedError("NYI: codegen_nan_check") + + def call_kernel(self, name: str, node: Optional[IRNode] = None) -> None: + raise NotImplementedError("NYI: call_kernel") + + @contextlib.contextmanager + def mask_loads( + self, mask: Union[str, OpsWrapper], value: Union[int, float] + ) -> Iterator[str]: + """Context manager to add an additional mask to tl.load/store""" + prior = self._load_mask + prior_val = self._load_other + if prior: + mask = ops.logical_and(mask, prior) + + mask = OpsWrapper._unwrap(mask) + self._load_mask = mask + self._load_other = value + try: + # TODO(jansel): do we need a reshape here? 
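+            # Usage sketch (illustrative only, `tmp_mask` is hypothetical):
+            #   with self.mask_loads(tmp_mask, value=0.0) as mask:
+            #       ... loads emitted here are additionally guarded by `mask`
+            #       ... and masked-off lanes fall back to `value`
+            # Nested uses are AND-ed with any prior mask via ops.logical_and above.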
+ yield mask + finally: + self._load_mask = prior + self._load_other = prior_val + + def get_strides_of_load(self, index: sympy.Expr) -> dict[sympy.Symbol, sympy.Expr]: + """ + This gets the stride of the index for each of the tiling variables + (technically, it does it at index 0) + + For example, if + xindex = x0 + 512*x1 + 1024*r0 + x0 = (xindex//512) + x1 = (xindex % 512) + r0 = rindex // 1024 + + this function would return + {xindex: 512, rindex: 1024} + """ + index_to_tile_indexes = {k: v.expr for k, v in self.range_tree_nodes.items()} + index_in_tile_vars = sympy_subs(index, index_to_tile_indexes) # type: ignore[arg-type] + strides = {} + for range_tree in self.range_trees: + s = sympy_index_symbol(range_tree.name) + strides[s] = sympy_subs(index_in_tile_vars, {s: 1}) - sympy_subs( + index_in_tile_vars, {s: 0} + ) + return strides + + @staticmethod + def _map_tuple_or_scalar(fn, value): + if isinstance(value, tuple): + return tuple(map(fn, value)) + return fn(value) + + def estimate_kernel_num_bytes(self): + """ + Try the best to estimate the total size (in bytes) of the + kernel's inputs and outputs, which is used for estimating the memory + throughput of this kernel. This information is used for checking how + far we are from the peak memory bandwidth. It's important that + we want to avoid overestimating the sizes of the inputs and outputs, + because it can wrongfully give us a very large memory traffic value, + which may be even larger than the theoretical bandwidth and thus + become very misleading. This is particularly problematic for cases + where we slice some inputs. In those cases, we should only count + the size of the "slices" instead of the original inputs, because + only the slices contribute to the real memory traffic. + """ + nbytes = [] + ninplace_args = len(unique(self.args.inplace_buffers.values())) + _, call_args, _, _ = self.args.python_argdefs() + buf_accesses = self.features.buf_accesses() + + # For pointwise and reduction kernels, this is the upper-bound numels + # for the output buffer. + # FIXME: This is not exactly right for cases like below: + # def foo(tensor0, tensor1): + # x0 = narrow(tensor0) + # return cat(x0, tensor1) + # For this example, we will end up overestimate the size for the + # slice s0. Potentially, we could have precise inputs information + # if we maintained the original inputs of the Pointwise kernel created + # for the "cat". However, I think it might be a bit overwhelming that + # we add such complexity only for handling some particular cases for + # benchmarking. + out_numel = V.graph.sizevars.size_hint(sympy_product(self.numels.values())) + for i, arg in enumerate(call_args): + # "buf" may be narrowed. In this case, the number of memory accesses + # should be estimated based on the reinterpreted layout. + # On the other hand, buf may be broadcasted. In this case, + # counting the size of the underline storage would give us + # a better estimation in terms of memory accesses. + if arg not in buf_accesses: + nbytes.append(0) + continue + arg_numel = V.graph.get_numel(arg) + buf_size = V.graph.sizevars.size_hint(arg_numel) + if buf_size > out_numel: + # This arg points to a buf that has been sliced. + # We need to count each individual slice to have + # a better estimation. 
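+                # Worked example (hypothetical numbers): if out_numel is 1024 and
+                # this buffer is read through 2 distinct slice indices, we count
+                # 2 * 1024 elements rather than the buffer's full storage size.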
+ indices = OrderedSet[Any]() + no_index_dep_count = 0 + for dep in buf_accesses[arg]: + if isinstance(dep, (StarDep, WeakDep)): + indices.add(f"no_index_dep_{no_index_dep_count}") + no_index_dep_count += 1 + else: + indices.add(dep.index) + numel = len(indices) * out_numel + else: + numel = buf_size + dtype = V.graph.get_dtype(arg) + dtype_size = get_dtype_size(dtype) + nbytes.append(numel * dtype_size * (1 + int(i < ninplace_args))) + return sum(nbytes) + + def warn_mix_layout(self, kernel_name): + """ + Print message if the kernel have mixed layout inputs. + Only care about 4D tensor for now. + """ + if ( + len(self.args.input_buffers) == 1 + and len(self.args.output_buffers) == 1 + and len(self.args.inplace_buffers) == 0 + ): + # even if input buffer and output buffer have different layout, + # this can be a layout conversion kernel. No need to warn for + # the mix layouts. + return + + argdefs, call_args, _signature, _ = self.args.python_argdefs() + uniform_stride_order = None + for arg_name in call_args: + buf = V.graph.try_get_buffer(arg_name) + if not buf: + continue + layout = buf.get_layout() + if len(layout.size) == 4: + # ignore the tensor if only 1 dimension is non-zero + if len([x for x in layout.size if x == 1]) == 3: + continue + stride_order = ir.get_stride_order(layout.stride) + if uniform_stride_order is None: + uniform_stride_order = stride_order + elif uniform_stride_order != stride_order: + msg = yellow_text( + f"Expected stride order {uniform_stride_order}, but found stride order" + + f" {stride_order} for kernel {kernel_name}" + ) + log.warning(msg) + + stride_order_list = [ + ir.get_stride_order( + V.graph.get_buffer(name).get_layout().stride + ) + if V.graph.try_get_buffer(name) + else None + for name in call_args + ] + size_list = [ + V.graph.get_buffer(name).get_layout().size + if V.graph.try_get_buffer(name) + else None + for name in call_args + ] + source_list = [ + "GraphInput" + if name in V.graph.graph_inputs + else "IntermediateBuffer" + if name in V.graph.name_to_buffer + else None + for name in call_args + ] + + argdef_names = [x.name for x in argdefs] + msg = yellow_text( + f" param names {argdef_names}\n buf names {call_args}\n strides {stride_order_list}" + + f"\n sizes {size_list}\n sources {source_list}\n" + ) + log.warning(msg) + return + msg = green_text( + f"All the inputs for the triton kernel {kernel_name} have uniform layout" + ) + log.warning(msg) + + def welford_reduce_fallback(self, dtype, value): + sum_ = ops.reduction(dtype, dtype, "sum", value) + self.inside_reduction = False + rnumel = ops.index_expr(self.features.reduction_numel, dtype) + mean = ops.truediv(sum_, rnumel) + + self.inside_reduction = True + dx = ops.sub(value, mean) + dx2 = ops.mul(dx, dx) + m2 = ops.reduction(dtype, dtype, "sum", dx2) + return OpsWrapper._unwrap((mean, m2, rnumel)) + + def prepare_softmax_twopass_fallback(self, dtype, value): + vmax = ops.reduction(dtype, dtype, "max", value) + sub = ops.sub(value, vmax) + exp = ops.exp(sub) + vsum = ops.reduction(dtype, dtype, "sum", exp) + return OpsWrapper._unwrap((vmax, vsum)) + + def codegen_kernel(self): + raise NotImplementedError + + def codegen_body(self): + pass + + def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry): + pass + + +class SIMDScheduling(BaseScheduling): + """ + Single Instruction Multiple Data parent class used for fusion across + multiple different backends. 
+ """ + + kernel_type: type[Any] = SIMDKernel # override in subclass + + def group_fn(self, sizes): + return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes) + + def can_fuse(self, node1, node2): + """ + Hook called by Scheduler to determine if the Triton backend + can fuse node1 and node2. These nodes might already be + FusedSchedulerNodes. + """ + if isinstance(node1, scheduler.ForeachKernelSchedulerNode) or isinstance( + node2, scheduler.ForeachKernelSchedulerNode + ): + return scheduler.ForeachKernelSchedulerNode.can_fuse(node1, node2) + + _, (numel1, rnumel1) = node1.group + _, (numel2, rnumel2) = node2.group + why = WhyNoFuse(node1, node2) + + if node1.is_split_scan() and not node2.is_split_scan(): + if node2.is_reduction(): + why("Split scan cannot fuse with reductions") + elif node2.is_split_scan() and not node1.is_split_scan(): + if node1.is_reduction(): + why("Split scan cannot fuse with reductions") + + if node1.is_reduction() and node2.is_reduction(): + reduction_can_fuse = numel1 == numel2 and rnumel1 == rnumel2 + if not reduction_can_fuse: + why( + "numel/rnumel mismatch (reduce) (%s, %s), (%s, %s)", + numel1, + numel2, + rnumel1, + rnumel2, + ) + return reduction_can_fuse + + if not node1.is_reduction() and not node2.is_reduction(): + if not (numel1 == numel2 and rnumel1 == rnumel2): + if not node2.is_template(): + why( + "numel/rnumel mismatch (non-reduce) (%s, %s), (%s, %s)", + numel1, + numel2, + rnumel1, + rnumel2, + ) + return False + else: + # prologue fusion input sizes differ from output group + # fuse so long as this node matches the group of existing prologue nodes + for node in node2.get_nodes(): + # dont need to check epilogue nodes for prologue fusion, break after template + if node.is_template(): + break + # we would have already restricted prologue from fusing if it had multiple + # uses, so it must be fusing into this node + if not node.used_buffer_names() & node1.get_buffer_names(): + continue + _, (pro_numel, pro_rnumel) = node.group + if not (numel1 == pro_numel and rnumel1 == pro_rnumel): + why( + "numel/rnumel mismatch prologue mismatch (%s, %s), (%s, %s)", + numel1, + pro_numel, + rnumel1, + pro_rnumel, + ) + return False + + for n in (node1, node2): + if n.is_template(): + return True + + # check for a bad combined tiling + tiling1 = self.select_tiling(node1.get_nodes(), numel1, rnumel1) + tiling2 = self.select_tiling(node2.get_nodes(), numel1, rnumel1) + tiling3 = self.select_tiling( + node1.get_nodes() + node2.get_nodes(), numel1, rnumel1 + ) + if config.triton.tiling_prevents_pointwise_fusion: + cond = True + if len(tiling1) > 2: + if len(tiling2) > 2: + cond = tiling1 == tiling2 == tiling3 + else: + cond = tiling1 == tiling3 + elif len(tiling2) > 2: + cond = tiling2 == tiling3 + if not cond: + why( + "tiling mismatch (%s, %s, %s)", + tiling1, + tiling2, + tiling3, + ) + return False + + return True + + if not node1.is_reduction() and node2.is_reduction(): + assert rnumel1 == 1 and rnumel2 != 1 + if numel1 == numel2 * rnumel2: + if not all( + SIMDKernel.is_compatible((numel2, rnumel2), n.get_ranges()) + for n in node1.get_nodes() + ): + why("nodes numel/rnumel incompatibility") + return False + if ( + config.triton.tiling_prevents_reduction_fusion + and not node1.is_template() + ): + is_reduction_tiling_valid = tuple( + self.select_tiling(node1.get_nodes(), numel1).values() + ) in ( + (numel1, 1), + (numel2, rnumel2, 1), + ) + if not is_reduction_tiling_valid: + why("invalid tiling for reduction") + return is_reduction_tiling_valid 
+ return True + + if numel1 != numel2: + why("nodes numel incompatibility") + return numel1 == numel2 + + assert node1.is_reduction() and not node2.is_reduction() + # swap args to hit the case above + return self.can_fuse_horizontal(node2, node1) + + can_fuse_vertical = can_fuse + can_fuse_horizontal = can_fuse + + def generate_node_schedule(self, nodes, numel, rnumel): + node_schedule: list[Any] = [] + done = OrderedSet[scheduler.BaseSchedulerNode]() + # Writes with a reduced shape, meaning they are only present once the + # reduction loop has ended + not_ready_yet_nodes: OrderedSet[str] = OrderedSet() + current_loop_buffer_usage: OrderedSet[str] = OrderedSet() + maybe_split_index: Optional[int] = None + + def fits_in_main_body(n): + _, (node_numel, node_rnumel) = n.group + return (node_numel == numel and node_rnumel == rnumel) or ( + node_numel == numel * rnumel and node_rnumel == 1 + ) + + def fits_outside_reduction(n): + _, (node_numel, node_rnumel) = n.group + return node_numel == numel and node_rnumel == 1 and rnumel != 1 + + def expect_improved_memory_usage(n): + for read in n.read_writes.reads: + if read.name in current_loop_buffer_usage: + return True + return False + + def schedule_node_in_loop(n): + done.add(n) + node_schedule.append(n) + current_loop_buffer_usage.update([x.name for x in n.read_writes.reads]) + + # A scan is modelled as a reduction in the scheduler but has a + # full sized output that can be used inside the loop body + if ( + n.is_reduction() + and isinstance(n, scheduler.SchedulerNode) + and isinstance(n.node, ir.ComputedBuffer) + and not isinstance(n.node.data, ir.Scan) + ): + not_ready_yet_nodes.add(n.get_name()) + else: # this node is available within the loop + current_loop_buffer_usage.update([x.name for x in n.read_writes.writes]) + + @contextlib.contextmanager + def end_current_reduction_loop(): + nonlocal maybe_split_index + if node_schedule and node_schedule[-1] is EnableReduction: + node_schedule.pop() + else: + node_schedule.append(DisableReduction) + if maybe_split_index: + node_schedule.insert(maybe_split_index, DisableReduction) + node_schedule.insert(maybe_split_index + 1, EnableReduction) + maybe_split_index = None + yield + node_schedule.append(EnableReduction) + not_ready_yet_nodes.clear() + current_loop_buffer_usage.clear() + + def requires_closing_previous_reduction(node, node_schedule): + if rnumel == 1: + return False + if not not_ready_yet_nodes & node.ancestors: + return False + assert node_schedule and not isinstance( + node_schedule[-1], (EnableReduction, DisableReduction) + ) + return bool(not_ready_yet_nodes) + + for node in nodes: + if node in done: + continue + done.add(node) + + if fits_in_main_body(node): + if requires_closing_previous_reduction(node, node_schedule): + with end_current_reduction_loop(): + pass # need to start a new reduction loop + + if current_loop_buffer_usage and not expect_improved_memory_usage(node): + # If we don't improve memory usage, then it is better to split into two loops + maybe_split_index = maybe_split_index or len(node_schedule) + else: + # Memory usage got improved, cancel the loop split + maybe_split_index = None + + schedule_node_in_loop(node) + elif fits_outside_reduction(node): + with end_current_reduction_loop(): + node_schedule.append(node) + else: + raise NotImplementedError( + f"unexpected group: ({numel}, {rnumel}) != {node.group[1]}" + ) + + return node_schedule + + def codegen_node( + self, node: Union[scheduler.FusedSchedulerNode, scheduler.SchedulerNode] + ): + """ + Given a set of 
pre-fused nodes, generate a Triton kernel. + """ + + nodes: list[scheduler.SchedulerNode] = node.get_nodes() # type: ignore[assignment] + + if torch._inductor.config.triton.coalesce_tiling_analysis: + coalesce_analysis = analyze_memory_coalescing(node) + else: + coalesce_analysis = None + _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group + + node_schedule = self.generate_node_schedule(nodes, numel, rnumel) + schedule_log.debug("Schedule:\n %s", node_schedule) + + return self.codegen_node_schedule( + SIMDKernelFeatures(node_schedule, numel, rnumel, coalesce_analysis) + ) + + @staticmethod + def can_use_32bit_indexing( + numel: sympy.Expr, + buffers: Iterable[ + Union[ir.Buffer, ir.TensorBox, ir.TorchBindObject, ir.IRNode] + ], + ) -> bool: + int_max = torch.iinfo(torch.int32).max + + if not expr_fits_within_32bit(numel): + return False + + # Any use of a MultiOutputLayout will create a buffer with a + # Layout whose sizes are accounted for + buf_sizes = [ + buf.get_layout().storage_size() + for buf in buffers + if buf.has_tensor_output() + ] + + if not all(expr_fits_within_32bit(size) for size in buf_sizes): + return False + + # Only install guards for 32-bit indexing as there is no correctness + # issue with using 64-bit for everything + V.graph.sizevars.guard_leq(numel, int_max) # type: ignore[arg-type] + for size in buf_sizes: + V.graph.sizevars.guard_leq(size, int_max) # type: ignore[arg-type] + return True + + def codegen_node_schedule(self, kernel_features: SIMDKernelFeatures): + node_schedule = kernel_features.node_schedule + + tiling, tiling_score = self.get_tiling_and_scores( + node_schedule, + kernel_features.numel, + kernel_features.reduction_numel, + kernel_features.coalesce_analysis, + ) + kernels = self.create_kernel_choices( + kernel_features, + [tiling], + {"features": kernel_features, "tiling_scores": tiling_score}, + ) + for kernel in kernels: + self.codegen_node_schedule_with_kernel(node_schedule, kernel) + MultiKernel.merge_workspaces_inplace(kernels) + for kernel in kernels: + with V.set_kernel_handler(kernel): + src_code = kernel.codegen_kernel() + kernel_name = self.define_kernel(src_code, node_schedule, kernel) + if config.trace.enabled: + set_kernel_post_grad_provenance_tracing( + node_schedule, # type: ignore[arg-type] + kernel_name, + ) + log.debug("Generating kernel code with kernel_name: %s", kernel_name) + kernel.kernel_name = kernel_name + kernel.code_hash = code_hash(src_code) + del kernel + + final_kernel: Union[SIMDKernel, MultiKernel] + if len(kernels) > 1: + final_kernel = MultiKernel(kernels) + else: + (final_kernel,) = kernels + + with V.set_kernel_handler(final_kernel): + for node in kernel_features.scheduler_nodes(): + node.mark_run() + + self.codegen_comment(node_schedule) + final_kernel.call_kernel(final_kernel.kernel_name) + + if config.nan_asserts: + final_kernel.codegen_nan_check() + if config.warn_mix_layout: + final_kernel.warn_mix_layout(kernels[0].kernel_name) + + V.graph.removed_buffers |= final_kernel.removed_buffers + V.graph.inplaced_to_remove |= final_kernel.inplaced_to_remove + + if ( + V.graph.wrapper_code.supports_intermediate_hooks + and config.generate_intermediate_hooks + ): + # Not every node in the schedule will actually be live on output; + # we can't check dead buffers. 
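+            # For every scheduled buffer that is still a live kernel output, emit a
+            # run_intermediate_hooks(...) call keyed by its origin FX node's name.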
+ live_outs = kernels[0].args.live_output_buffers() + for node in kernel_features.scheduler_nodes(): + name = node.get_name() + if name not in live_outs: + continue + assert node.node is not None + origin_node = node.node.get_origin_node() + if origin_node is not None: + counters["inductor"]["intermediate_hooks"] += 1 + V.graph.wrapper_code.writeline( + f"run_intermediate_hooks({origin_node.name!r}, {name})" + ) + + self.free_buffers_in_scheduler() + + def create_kernel_choices( + self, kernel_features: SIMDKernelFeatures, kernel_args, kernel_kwargs + ) -> list[SIMDKernel]: + return [ + self.kernel_type( + *kernel_args, + **kernel_kwargs, + ) + ] + + def codegen_node_schedule_with_kernel(self, node_schedule, kernel): + with kernel: + stack = contextlib.ExitStack() + all_indexing = {} + + # First pass to collect indexing and decide inplace updates + for node in node_schedule: + if node is DisableReduction: + stack.enter_context(kernel.disable_reduction()) + elif node is EnableReduction: + stack.close() + else: + node.decide_inplace_update() + index_vars = kernel.split_and_set_ranges(node.get_ranges()) + all_indexing.update( + dict.fromkeys( + node._body.indexing_from_args(index_vars).values() + ) + ) + + kernel.finalize_indexing(all_indexing.keys()) + + # Second pass to do codegen + for node in node_schedule: + if node is DisableReduction: + stack.enter_context(kernel.disable_reduction()) + elif node is EnableReduction: + stack.close() + else: + # TODO - use split ranges ? + indexing_dtype_strength_reduction(node._body) + index_vars = kernel.split_and_set_ranges(node.get_ranges()) + node.codegen(index_vars) + + def codegen_template( + self, template_node, epilogue_nodes, prologue_nodes, *, only_gen_src_code=False + ) -> Optional[str]: + """ + Codegen a triton template + + If `only_gen_src_code` the src code will be returned instead of codegen'd into the wrapper + """ + _, (_numel, rnumel) = template_node.group + assert rnumel == 1 + kernel, render = template_node.node.make_kernel_render(template_node.node) + + buf_name_to_prologue_group = {} + template_reads = template_node.used_buffer_names() + prologue_group = [] + for prologue in prologue_nodes: + names = prologue.get_buffer_names() + prologue_group.append(prologue) + # this must be the end of a prologue group + if names & template_reads: + assert len(names) == 1 + buf_name_to_prologue_group[next(iter(names))] = prologue_group + kernel.prologue_fused_inputs.add(next(iter(names))) + prologue_group = [] + + # all prologue groups should have finalized with use in template + assert len(prologue_group) == 0 + + with kernel: + if not only_gen_src_code: + # prologue nodes can only be fused if their only use is in the template, + # so they are necessarily not allocated + for node in [template_node, *epilogue_nodes]: + node.mark_run() + + partial_code = render() + + with kernel.set_subgraph_body(""): + for node in epilogue_nodes: + node.codegen(kernel.split_and_set_ranges(node.get_ranges())) + kernel.cse.invalidate(OrderedSet()) + + for input_name, buffer in kernel.named_input_nodes.items(): + subgraph_name = f"" + if prologue_group := buf_name_to_prologue_group.get( + buffer.get_name(), [] + ): + can_codegen_without_upcast = all( + p_n.can_codegen_without_upcasts() for p_n in prologue_group + ) + + # TODO - this doesn't work with libdevice calls, potentially other bugs + # upcasting to fp32 and downcasting gives large slowdown + with config.patch( + "triton.codegen_upcast_to_fp32", not can_codegen_without_upcast + ): + with 
kernel.set_subgraph_body(subgraph_name): + for prologue_node in prologue_group: + if ( + len(prologue_node.get_buffer_names()) == 1 + and len(prologue_group) == 1 + ): + if prologue_preserves_zero_mask(prologue_node): + kernel.prologue_fused_inputs_preserve_zero |= ( + prologue_node.get_buffer_names() + ) + + prologue_node.codegen( + kernel.split_and_set_ranges( + prologue_node.get_ranges() + ) + ) + kernel.cse.invalidate(OrderedSet()) + + if not isinstance(partial_code, str): + partial_code.finalize_hook("") + partial_code.finalize_hook("", strict=False) + # finalize must be called after adding epilogue above + + with V.set_kernel_handler(kernel): + # TODO: Maybe unify CUDATemplateKernel to also use PartialRender for flexible epilogue fusion. + + for input_name in kernel.named_input_nodes.keys(): + subgraph_name = f"" + partial_code.finalize_hook(subgraph_name, strict=False) + + with kernel.set_subgraph_body(""): + if isinstance(partial_code, str): + src_code = partial_code + else: + partial_code.finalize_hook("") + src_code = partial_code.code + node_schedule = [*prologue_nodes, template_node, *epilogue_nodes] + + if config.benchmark_kernel: + num_gb = kernel.estimate_kernel_num_bytes() / 1e9 + src_code = ( + f"{kernel.imports_for_benchmark_kernel()}\n" + f"{src_code}\n" + f"{kernel.codegen_kernel_benchmark(num_gb).getvalue()}" + ) + + if only_gen_src_code: + return src_code + + kernel_name = self.define_kernel(src_code, node_schedule, kernel) + + if config.trace.enabled: + set_kernel_post_grad_provenance_tracing(node_schedule, kernel_name) + + self.codegen_comment(node_schedule) + kernel.call_kernel(kernel_name, template_node.node) + + V.graph.removed_buffers |= kernel.removed_buffers + V.graph.inplaced_to_remove |= kernel.inplaced_to_remove + self.free_buffers_in_scheduler() + return None + + def codegen_sync(self): + V.graph.wrapper_code.writeline(V.graph.device_ops.synchronize()) + + def generate_combo_kernel_code( + self, + subkernel_nodes: list[BaseSchedulerNode], + custom_part_algorithm: bool, + enable_autotune: bool, + mixed_sizes: bool, + only_gen_src_code: bool = False, + ) -> list[tuple[str, Any, Any]]: + from .triton_combo_kernel import ComboKernel + + fused_node_lists = [node.get_nodes() for node in subkernel_nodes] + subkernel_map, node_schedule_map = {}, {} + for pn, nodes in zip(subkernel_nodes, fused_node_lists): + _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group + node_schedule = self.generate_node_schedule(nodes, numel, rnumel) + tiling = self.select_tiling(node_schedule, numel, rnumel) + node_schedule_map[pn] = node_schedule, tiling, numel, rnumel + subkernel_map[pn] = ComboKernel.create_triton_kernel( + tiling, + features=SIMDKernelFeatures(node_schedule, numel, rnumel), + optimize_mask=not mixed_sizes, + ) + + partitions = ComboKernel.horizontal_partition( + nodes=subkernel_nodes, + triton_scheduling=self, + custom_algorithm=custom_part_algorithm, + kernel_map=subkernel_map, + node_info_map=node_schedule_map, + ) + log.debug( + "ComboKernels: %d nodes partitioned into %s groups", + len(subkernel_nodes), + [len(p) for p in partitions], + ) + kernel_code_list = [] + for node_group in partitions: + fused_node_lists = [node.get_nodes() for node in node_group] + kernel = ComboKernel( + enable_autotune=enable_autotune, + mixed_sizes=mixed_sizes, + ) + + for pn, nodes in zip(node_group, fused_node_lists): + self.codegen_node_schedule_with_kernel( + node_schedule_map[pn][0], + kernel.create_sub_kernel(subkernel_map[pn]), + ) + subkernel = 
subkernel_map[pn] + node_schedule = node_schedule_map[pn][0] + if not only_gen_src_code: + with V.set_kernel_handler(subkernel): # type: ignore[call-arg] + for node in NodeScheduleMarker.only_nodes(node_schedule): + node.mark_run() + V.graph.removed_buffers |= subkernel.removed_buffers + V.graph.inplaced_to_remove |= subkernel.inplaced_to_remove + + src_code = kernel.codegen_kernel() + kernel_code_list.append((src_code, kernel, node_group)) + return kernel_code_list + + def codegen_combo_kernel(self, combo_kernel_node): + subkernel_nodes = combo_kernel_node.get_subkernel_nodes() + custom_part_algorithm = combo_kernel_node.use_custom_partition_algo + enable_autotune = combo_kernel_node.enable_autotune + mixed_sizes = config.combo_kernel_allow_mixed_sizes > 1 or ( + config.combo_kernel_allow_mixed_sizes == 1 and custom_part_algorithm + ) + + kernel_code_list = self.generate_combo_kernel_code( + subkernel_nodes, custom_part_algorithm, enable_autotune, mixed_sizes + ) + + for src_code, kernel, _ in kernel_code_list: + kernel_name = self.define_kernel(src_code, [combo_kernel_node], kernel) + # dump provenance node info for ComboKernelNode/ForeachKernel type + if config.trace.enabled: + set_kernel_post_grad_provenance_tracing( + combo_kernel_node.snodes, kernel_name + ) + self.codegen_comment([combo_kernel_node]) + log.debug("ComboKernels: generated kernel %s.", kernel_name) + kernel.call_kernel(V.graph.wrapper_code, kernel_name) + + self.free_buffers_in_scheduler() + + @classmethod + @functools.lru_cache(32) + def candidate_tilings(cls, node, numel, reduction_numel) -> list[CandidateTiling]: + is_pointwise = reduction_numel == 1 + + def tile_ranges(is_pointwise: bool, ranges, rw) -> list[CandidateTiling]: + """ + Compute tiling candidates by dividing up the iteration ranges. + """ + assert len(rw.range_vars) == len(ranges), f"{rw.range_vars=} {ranges=}" + + # isinstance(dep, MemoryDep): this filters out StarDeps. StarDeps refer to reads + # that need to access the entire tensor; they don't contribute read indexing + # information (and practically, they don't have dep.index so they can't be used + # for stride_hints below + dep_sources = [rw.reads, rw.writes] + assert all( + isinstance(dep, (MemoryDep, StarDep)) + for dep in itertools.chain.from_iterable(dep_sources) + ) + deps = [ + dep + for dep in itertools.chain.from_iterable(dep_sources) + if dep.name not in V.graph.removed_buffers + and isinstance(dep, MemoryDep) + ] + write_names = OrderedSet([dep.name for dep in rw.writes]) + + def collapse_ranges(ranges: Sequence[sympy.Expr]) -> sympy.Expr: + return V.graph.sizevars.simplify(sympy_product(ranges)) + + # Default to no tiling. + tilings = [ + CandidateTiling( + tiling=cls.create_partial_tiling( + [collapse_ranges(ranges)], is_pointwise + ), + name="none", + score=0, + ) + ] + + # Find non-trivial tiling candidates. 
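+            # Worked example (hypothetical shapes, not upstream code): with
+            # ranges (128, 64) and a dep indexed as i0 + 128*i1, stride_hints
+            # returns (1, 128); split becomes 1 and the candidate tiling collapses
+            # the ranges into the 2D groups (128, 64).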
+ for dep in deps: + strides = V.graph.sizevars.stride_hints(dep.index, rw.range_vars) + assert len(strides) == len(ranges) + try: + split = strides.index(1) + 1 + if split == len(ranges): + continue + if all(s == 0 for s in strides[split:]): + # if this is a broadcasted tensor and all dimensions after split are broadcast, + # this is not a real split + continue + + except ValueError: + continue + + tiled_groups = ( + collapse_ranges(ranges[:split]), + collapse_ranges(ranges[split:]), + ) + + # score by number of elements + score = V.graph.sizevars.size_hint( + sympy_product( + size for size, stride in zip(ranges, strides) if stride != 0 + ) + ) + if dep.name in write_names: + # ngimel said contiguous writes is more important than reads + score *= 2 + if CandidateTiling.is_good_size(tiled_groups[0]): + score *= 2 + if CandidateTiling.is_good_size(tiled_groups[1]): + score *= 2 + + if ( + V.graph.sizevars.size_hint( + score - sympy_product(itertools.chain(ranges, reduction_ranges)) + ) + >= 0 + ): + tilings.append( + CandidateTiling( + tiling=cls.create_partial_tiling( + [ + collapse_ranges(ranges[:split]), + collapse_ranges(ranges[split:]), + ], + reduction_numel, + ), + score=score, + name=dep.name, + ) + ) + + return tilings + + pointwise_ranges, reduction_ranges = node.get_ranges() + if ( + len(pointwise_ranges) <= 1 + and len(reduction_ranges) <= 1 + or free_unbacked_symbols(pointwise_ranges + reduction_ranges) + ): + return [] + + # Tile either pointwise or reduction dims. + pointwise_ranges, reduction_ranges = node.get_ranges() + partial_tilings = tile_ranges( + is_pointwise, + pointwise_ranges if is_pointwise else reduction_ranges, + node.pointwise_or_reduction_read_writes(is_pointwise), + ) + + # Fill in the missing ranges. + full_tilings = [ + CandidateTiling( + tiling=cls.complete_partial_tiling( + tiling.tiling, numel, reduction_numel + ), + score=tiling.score, + name=tiling.name, + ) + for tiling in partial_tilings + ] + + return full_tilings + + @classmethod + def create_tiling( + cls, pw_tiling: Sequence[sympy.Expr], reduction_tiling: Sequence[sympy.Expr] + ) -> dict[str, sympy.Expr]: + """ + Create a tiling dict from pointwise and reduction splits. + """ + pw_prefixes = ["z", "y", "x"][-len(pw_tiling) :] + reduction_prefixes = ["r0_", "r1_"][: len(reduction_tiling)] + return immutable_dict( + [*zip(pw_prefixes, pw_tiling), *zip(reduction_prefixes, reduction_tiling)] + ) + + @classmethod + def create_partial_tiling( + cls, + tiling: Sequence[sympy.Expr], + is_pointwise: bool, + ) -> dict[str, sympy.Expr]: + return cls.create_tiling( + tiling if is_pointwise else [], + tiling if not is_pointwise else [], + ) + + @classmethod + def complete_partial_tiling( + cls, + tiling: dict[str, sympy.Expr], + numel: sympy.Expr, + reduction_numel: sympy.Expr, + ) -> dict[str, sympy.Expr]: + """ + Given a tiling for only pointwise or reduction dimensions, adds the missing one. + """ + splits = list(tiling.values()) + is_pointwise = "x" in tiling + + total_numel = numel * reduction_numel + missing_tiling = [total_numel / sympy_product(splits)] + + tiling_args = ( + (splits, missing_tiling) if is_pointwise else (missing_tiling, splits) + ) + return cls.create_tiling(*tiling_args) + + @classmethod + def get_nd_tilings( + cls, + node_schedule, + pointwise_numel, + reduction_numel, + ) -> list[dict[str, tuple[sympy.Expr]]]: + """ + Creates N-dimensional tiling candidates, attempting to simplify loads/stores + by tiling the kernel into higher dimensions. 
+ + Returns a list of tilings ranked by dimensionality. + """ + is_pointwise = reduction_numel == 1 + tilings = OrderedSet[dict[str, sympy.Expr]]() + for node in EnableReduction.filter(node_schedule): + if not isinstance(node, scheduler.SchedulerNode): + continue + + # If this is a reduction schedule, skip nodes which are missing their + # reduction ranges. + node_ranges = node.get_ranges() + if not is_pointwise and len(node_ranges[1]) == 0: + continue + + # Use the node ranges as the default tiling candidate. + ranges_to_tile = node_ranges[0 if is_pointwise else 1] + node_tilings = [ranges_to_tile] + + # Search the indexing expressions for more candidates. + # If we see modular indexing, try to subdivide ranges into their implied + # block shape. + memory_deps = [ + dep + for dep in node.read_writes.reads_and_writes() + if isinstance(dep, MemoryDep) and len(dep.ranges) > 0 + ] + for dep in memory_deps: + # Attempt to partition variable ranges into pointwise and reduction groups. + # To achieve this, merge the leading ranges until we reach the pointwise numel. + all_var_ranges = [*dep.ranges.items()] + pointwise_vars_numel = sympy.S.One + sizevars = V.graph.sizevars + for pointwise_end_idx, (var, numel) in enumerate(all_var_ranges): + pointwise_vars_numel *= numel + if sizevars.statically_known_geq( + pointwise_vars_numel, pointwise_numel + ): + break + + # Reject the split if it does not match the total pointwise numel. + if not sizevars.statically_known_equals( + pointwise_vars_numel, pointwise_numel + ): + continue + + # Partition var ranges into pointwise and reduction splits. + reduction_start_idx = pointwise_end_idx + 1 + var_ranges = ( + all_var_ranges[:reduction_start_idx] + if is_pointwise + else all_var_ranges[reduction_start_idx:] + ) + + # Pattern match the subexpression pertaining to each index variable. + index_tiling = [] + for var, numel in var_ranges: + index = BlockPatternMatcher.get_subexpr_involving_symbol( + dep.index, var + ) + + # Heuristic to bound the maximum dimensionality of the block. + num_dims = max( + 2, + index.count(FloorDiv) + index.count(ModularIndexing), + len(ranges_to_tile), + ) + + # Attempt to pattern match the index expr. + # Failed matches default to the full range. + match_result = BlockPatternMatcher.match_mod_div_block_expr( + index, var, numel, num_dims + ) + dims = match_result[0] if match_result is not None else [numel] + index_tiling.extend(dims) + + # Prune dimensions of size 1. + index_tiling = [ + dim + for dim in index_tiling + if not V.graph.sizevars.statically_known_equals(dim, sympy.S.One) + ] + + if len(index_tiling) > 0: + node_tilings.append(index_tiling) + + # Flatten leading dimensions, assigning labels to each dim. + for node_tiling in node_tilings: + num_leading_dims = max(0, len(node_tiling) - get_max_tiles(2)) + first_trailing_dim = num_leading_dims + 1 + collapsed_leading_dim = sympy_product(node_tiling[:first_trailing_dim]) + collapsed_splits = (collapsed_leading_dim,) + tuple( + node_tiling[first_trailing_dim:] + ) + tilings.add( + cls.complete_partial_tiling( + cls.create_partial_tiling(collapsed_splits, is_pointwise), + pointwise_numel, + reduction_numel, + ) + ) + + # Rank tilings by the number of dimensions. E.g., prefer 2D to 1D. + # Since this is a stable sort, ties are broken by schedule order. 
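+        # E.g. a three-way {"y", "x", "r0_"} tiling ranks ahead of a two-way
+        # {"x", "r0_"} one; equal lengths keep their schedule order.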
+ ranked_tilings = sorted( + tilings, + key=len, + reverse=True, + ) + + return ranked_tilings + + @classmethod + def compute_tiling_strategy( + cls, + node_schedule: list[NodeScheduleEntry], + pointwise_numel: sympy.Expr, + reduction_numel: sympy.Expr, + coalesce_analysis: CoalesceVarAnalysis, + ) -> tuple[dict[str, sympy.Expr], Optional[dict[str, sympy.Expr]]]: + """ + Generates a tiling, and a score of each tile according to each tile's coalesced memory accesses. + """ + tiling_var: Optional[sympy.Expr] = ( + None + if not coalesce_analysis.suggested_split + else coalesce_analysis.suggested_split.var + ) + + all_iter_vars = coalesce_analysis.norm_read_writes.index_vars + all_red_vars = coalesce_analysis.norm_read_writes.reduce_vars + ranges = coalesce_analysis.norm_read_writes.var_ranges + + pw_ranges = [ranges[v] for v in all_iter_vars] + red_ranges = [ranges[v] for v in all_red_vars] + + torch._check( + sympy_product(pw_ranges) == pointwise_numel, + lambda: f"{pw_ranges}, {pointwise_numel}, {node_schedule}", + ) + torch._check( + sympy_product(red_ranges) == reduction_numel, + lambda: f"{red_ranges}, {reduction_numel}, {node_schedule}", + ) + + # score of a pointwise or reduction split + scored_sub_split: dict[Any, tuple[list[int], list[int]]] = {} + + score_split: list[ + tuple[tuple[list[int], list[int]], tuple[list[int], list[int]]] + ] = [] + + def process_node_vars( + vars_to_use: tuple[sympy.Expr, ...] = (), + use_split_var: bool = False, + is_pointwise: bool = False, + ) -> tuple[list[int], list[int]]: + """ + Generate a tiling, and a tiling score, given vars to use as splits. + """ + + ranges = pw_ranges if is_pointwise else red_ranges + target_numel = pointwise_numel if is_pointwise else reduction_numel + # Some kernels have no reduction ranges, and a reduction numel of 1 + if not ranges: + if target_numel: + return ([target_numel], []) + else: + return ([], []) + + key = (repr(vars_to_use), use_split_var, is_pointwise) + if out := scored_sub_split.get(key, None): + return out + + splitting_vars = all_iter_vars if is_pointwise else all_red_vars + + splits = [] + split_scores = [] + prod = 1 + prev_var_coalesced_score = 0 + + # iterate from non-dense to dense + for v, v_range in zip(splitting_vars, ranges): + if v not in vars_to_use: + prod *= v_range + prev_var_coalesced_score = coalesce_analysis.coalesced_by_var.get( + v, 0 + ) + continue + + if use_split_var and v == tiling_var: + var_tiling = coalesce_analysis.suggested_split + assert var_tiling is not None + + tile = var_tiling.tiling_factor + remainder = FloorDiv(v_range, var_tiling.tiling_factor) + + splits.append(prod * remainder) + split_scores.append(var_tiling.score) + + splits.append(tile) + split_scores.append(coalesce_analysis.coalesced_by_var.get(v, 0)) + + prod = 1 + prev_var_coalesced_score = 0 + + continue + + prod *= v_range + splits.append(prod) + split_scores.append(coalesce_analysis.coalesced_by_var.get(v, 0)) + prod = 1 + + if prod != 1 or (is_pointwise and len(splits) == 0): + splits.append(prod) + split_scores.append(prev_var_coalesced_score) + + # penalize splits that leave small blocks + # where we can't fully utilize full memory transaction + # TODO: incorporate exact bitwidth, and read/write + # coalesced write is 2x more important + for i in range(len(splits)): + s = V.graph.sizevars.size_hint(splits[i], fallback=32) + s = min(s, 8) + split_scores[i] = int(split_scores[i] * s / 8) + + scored_sub_split[key] = (splits, split_scores) + return (splits, split_scores) + + # add the default tiling + 
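+ # with no vars_to_use, each side collapses into a single flattened split covering its full numel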
score_split.append( + ( + process_node_vars(is_pointwise=True), + process_node_vars(is_pointwise=False), + ) + ) + + if tiling_var: + score_split.append( + ( + process_node_vars( + (tiling_var,), use_split_var=True, is_pointwise=True + ), + process_node_vars(is_pointwise=False), + ) + ) + + # TODO, add tests, reduction splits if config.triton.tile_reductions + # TODO: we should ignore tiny increases in score for extra splits + overlapping_iter_vars = ( + all_iter_vars & coalesce_analysis.coalesced_by_var.keys() + ) + for v in overlapping_iter_vars: + score_split.append( + ( + process_node_vars((v,), is_pointwise=True), + process_node_vars(is_pointwise=False), + ) + ) + + if get_max_tiles(default=3) == 3 and reduction_numel == 1: + for vars_to_use in itertools.combinations(overlapping_iter_vars, 2): + score_split.append( + ( + process_node_vars(vars_to_use, is_pointwise=True), + process_node_vars(is_pointwise=False), + ) + ) + + tilings: list[tuple[CandidateTiling, dict[str, sympy.Expr]]] = [] + for (pw_split, pw_score), (red_split, red_score) in score_split: + candidate = CandidateTiling( + cls.create_tiling(pw_split, red_split), + score=sum(pw_score) + sum(red_score), + ) + tiling_score = cls.create_tiling(pw_score, red_score) + tilings.append((candidate, tiling_score)) + + default_tiling = cls.create_tiling([pointwise_numel], [reduction_numel]) + + # add a slight penalty for longer tilings that dont increase score much, + # and are poor sizes + bad_size_additional_tiling_penalty = 1.025 + good_size_tiling_penalty = 1.005 + + def score_mod(t): + score_factor = 1.0 + for tile_size in t[0].tiling.values(): + if not CandidateTiling.is_good_size(tile_size): + score_factor = score_factor / bad_size_additional_tiling_penalty + else: + score_factor = score_factor / good_size_tiling_penalty + + return -t[0].score * score_factor + + # apply penalty for longer tilings that dont increase score much + for cand, tiling_score in sorted(tilings, key=score_mod): + if cls.tiling_is_compatible( + node_schedule, pointwise_numel, reduction_numel, cand.tiling + ): + # we always include default reduction numel == 1, dont include + tiling_len = len(cand.tiling) - (1 if reduction_numel == 1 else 0) + if tiling_len > get_max_tiles(default=3): + perf_hint_log.info( + "Found optimal tiling with %s tiles but torch._inductor.config.triton.max_tiles " + "set to %s. 
Consider increasing", + tiling_len, + torch._inductor.config.triton.max_tiles, + ) + continue + + return cand.tiling, tiling_score + + # surprisingly, the default tiling is not always read as compatible by `tiling_is_compatible` + # TODO - look into, occurs with dynamic shapes often + if cand.tiling == default_tiling: + return cand.tiling, tiling_score + + return default_tiling, None + + @classmethod + def tiling_is_compatible( + cls, + node_schedule: list[NodeScheduleEntry], + numel: sympy.Expr, + reduction_numel: sympy.Expr, + tiling: dict[str, sympy.Expr], + ): + assert isinstance(tiling, dict) + return all( + SIMDKernel.is_compatible( + tiling.values(), node.get_ranges(), reduction_numel=reduction_numel + ) + for node in node_schedule + if isinstance(node, scheduler.SchedulerNode) + ) + + @classmethod + def get_first_compatible_tiling( + cls, + node_schedule: list[NodeScheduleEntry], + numel: sympy.Expr, + reduction_numel: sympy.Expr, + ranked_tilings: list[dict[str, sympy.Expr]], + ): + for tiling in ranked_tilings: + if cls.tiling_is_compatible(node_schedule, numel, reduction_numel, tiling): + return tiling + + return None + + @classmethod + def select_tiling( + cls, + node_schedule, + numel, + reduction_numel=sympy.S.One, + coalesce_analysis: Optional[CoalesceVarAnalysis] = None, + ) -> dict[str, sympy.Expr]: + return cls.get_tiling_and_scores( + node_schedule, numel, reduction_numel, coalesce_analysis + )[0] + + @classmethod + def get_tiling_and_scores( + cls, + node_schedule, + numel, + reduction_numel=sympy.S.One, + coalesce_analysis: Optional[CoalesceVarAnalysis] = None, + ) -> tuple[dict[str, sympy.Expr], Optional[dict[str, sympy.Expr]]]: + """ + Heuristics to decide how to tile kernels. + Currently, we tile based on stride-1 dimensions. + + Returns: + `(tile1, tile2, reduction_numel)` s.t. `tile1 * tile2 == numel` + + """ + # If this is a reduction, only tile reduction dims. + is_pointwise = reduction_numel == 1 + + # Tiled reductions are gated by a config flag. + default_tiling = cls.create_tiling([numel], [reduction_numel]) + + # # TODO: enable by default + if ( + torch._inductor.config.triton.coalesce_tiling_analysis + and coalesce_analysis + and not config.triton.prefer_nd_tiling + ): + return cls.compute_tiling_strategy( + node_schedule, numel, reduction_numel, coalesce_analysis + ) + + if (not is_pointwise and not config.triton.tile_reductions) or get_max_tiles( + default=2 + ) <= 1: + # Emit a perf hint in case we miss an opportunity to tile a reduction. + if perf_hint_log.level <= logging.WARNING: + for node in EnableReduction.filter(node_schedule): + if ( + not config.triton.tile_reductions + and len(cls.candidate_tilings(node, numel, reduction_numel)) > 0 + ): + perf_hint_log.info( + textwrap.dedent( + """ + Reduction over non-contiguous dims. + Consider setting config.triton.tile_reductions to True. 
+ """ + ) + ) + break + + return default_tiling, None + + seen_names: OrderedSet[str] = OrderedSet() + candidate_tiles: Counter[CandidateTiling] = collections.Counter() + for node in EnableReduction.filter(node_schedule): + for candidate_tiling in cls.candidate_tilings(node, numel, reduction_numel): + if candidate_tiling.name in seen_names: + continue + elif candidate_tiling.name is not None: + seen_names.add(candidate_tiling.name) + candidate_tiles[candidate_tiling] += candidate_tiling.score + + ranked_tilings: list[dict[str, sympy.Expr]] = [ + candidate_tiling.tiling + for candidate_tiling, score in candidate_tiles.most_common() + ] + + if get_max_tiles(default=2) >= 3 and is_pointwise: + # Consider adding a third dimension of tiling, but only + # when a1 is a multiple of b1; otherwise, you have a lot + # of stragglers which is annoying to generate code for. + # + # NB: More than three max tiles is not enabled by default. + + def convert_tiling_to_3d( + tiling0: dict[str, sympy.Expr], tiling1: dict[str, sympy.Expr] + ) -> Optional[dict[str, sympy.Expr]]: + a0, a1 = tiling0["x"], tiling0.get("y", 1) + b0, b1 = tiling1["x"], tiling1.get("y", 1) + + if ( + free_unbacked_symbols([a1, b1]) + or V.graph.sizevars.size_hint(a1 - b1) == 0 + ): + return None + if V.graph.sizevars.size_hint(a1 - b1) < 0: + # swap so a0 is bigger + (a0, a1), (b0, b1) = (b0, b1), (a0, a1) + + assert V.graph.sizevars.size_hint(a1 - b1) > 0 + if not V.graph.sizevars.statically_known_multiple_of(a1, b1): + return None + + new_tiling = { + "z": a0, + "y": FloorDiv(a1, b1), + "x": b1, + "r0_": tiling0["r0_"], + } + + return new_tiling + + for i in range(1, len(ranked_tilings)): + new_3d_tiling = convert_tiling_to_3d( + ranked_tilings[0], ranked_tilings[i] + ) + if new_3d_tiling is not None: + ranked_tilings = [new_3d_tiling] + ranked_tilings + break # only 1 choice for now + + if len(ranked_tilings) > 1: + perf_hint_log.info("possibly bad tiling: %s", ranked_tilings) + + # Optionally, prefer tiling into as many dimensions as possible. 
+ if config.triton.prefer_nd_tiling: + ranked_tilings = ( + cls.get_nd_tilings(node_schedule, numel, reduction_numel) + + ranked_tilings + ) + + if tiling := cls.get_first_compatible_tiling( + node_schedule, numel, reduction_numel, ranked_tilings + ): + return tiling, None + + return default_tiling, None + + def flush(self): + pass + + def ready_to_flush(self) -> bool: + return False + + def generate_kernel_code_from_nodes(self, nodes, benchmark_kernel=False): + if not any(n.is_template() for n in nodes): + _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group + node_schedule = self.generate_node_schedule(nodes, numel, rnumel) + tiling = self.select_tiling(node_schedule, numel, rnumel) + kernel = self.kernel_type( + tiling, + features=SIMDKernelFeatures(node_schedule, numel, rnumel), + ) + self.codegen_node_schedule_with_kernel(node_schedule, kernel) + with ( + config.patch("benchmark_kernel", benchmark_kernel), + V.set_kernel_handler(kernel), + ): + src_code = kernel.codegen_kernel() + else: + prologue, template, epilogue = nodes[0].get_prologue_template_epilogue( + nodes + ) + with config.patch("benchmark_kernel", benchmark_kernel): + src_code = self.codegen_template( + template, + epilogue, + prologue, + only_gen_src_code=True, + ) + + src_code = src_code.replace(str(Placeholder.KERNEL_NAME), "triton_") + return src_code + + def codegen_comment(self, node_schedule): + pass + + def define_kernel(self, src_code, node_schedule, kernel): + raise NotImplementedError + + +@dataclasses.dataclass(frozen=True) +class CandidateTiling: + tiling: dict[str, sympy.Expr] + score: int # higher is better + name: Optional[str] = None + + @staticmethod + def is_good_size(s): + """Somewhat arbitrary heuristic used to boost scores for some sizes""" + s = V.graph.sizevars.size_hint(s) + return s >= 32 and (s % 32 == 0) + + +class CantSplit(Exception): + pass diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/simd_kernel_features.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/simd_kernel_features.py new file mode 100644 index 0000000000000000000000000000000000000000..77e9dba34eddade21c5f8e18059146b930eac650 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/simd_kernel_features.py @@ -0,0 +1,618 @@ +from __future__ import annotations + +import collections +import dataclasses +import functools +import itertools +import typing +from typing import Any, Optional, Union + +import sympy + +import torch + +from ...utils._ordered_set import OrderedSet +from ...utils._sympy.functions import FloorDiv, ModularIndexing +from ...utils._sympy.symbol import make_symbol, SymT +from ..dependencies import Dep, extract_loop_body_with_args, MemoryDep +from ..runtime.hints import ReductionHint +from ..scheduler import SchedulerNode +from ..utils import cache_on_self +from ..virtualized import V + + +if typing.TYPE_CHECKING: + from collections.abc import Iterable, Sequence + + from torch._inductor.tiling_utils import CoalesceVarAnalysis + + +class NodeScheduleMarker: + @staticmethod + def only_nodes(it: Iterable[NodeScheduleEntry]) -> Iterable[SchedulerNode]: + for item in it: + if not (item is DisableReduction or item is EnableReduction): + yield item # type: ignore[misc] + + @staticmethod + def is_reduction() -> bool: + return False + + +NodeScheduleEntry = Union[SchedulerNode, type[NodeScheduleMarker]] + + +class DisableReduction(NodeScheduleMarker): + """ + Marker to invoke `kernel.disable_reduction()`. 
This closes a + reduction loop and allows for pointwise ops to occur on the output + of a reduction. + """ + + +class EnableReduction(NodeScheduleMarker): + """ + Marker to end a DisableReduction block. + """ + + @staticmethod + def filter(node_schedule: list[NodeScheduleEntry]) -> Iterable[SchedulerNode]: + """ + Get the nodes from node_schedule skipping those in a + DisableReduction block. + """ + disabled = False + for node in node_schedule: + if node in (EnableReduction, DisableReduction): + # Don't tile stuff outside the main reduction loop + disabled = node is DisableReduction + elif disabled: + pass + else: + yield node # type: ignore[misc] + + +class SIMDKernelFeatures: + """ + An ordered schedule of nodes that will become a single kernel. + """ + + def __init__( + self, + node_schedule: list[NodeScheduleEntry], + numel: sympy.Expr, + reduction_numel: sympy.Expr = sympy.S.One, + coalesce_analysis: Optional[CoalesceVarAnalysis] = None, + ): + self.node_schedule = node_schedule + # numel excludes reduction_numel + self.numel: sympy.Expr = V.graph.sizevars.simplify(numel) + self.reduction_numel: sympy.Expr = V.graph.sizevars.simplify(reduction_numel) + self._stats_cache: dict[tuple[sympy.Expr, ...], MemoryStats] = {} + self.coalesce_analysis = coalesce_analysis + + @cache_on_self + def is_reduction(self) -> bool: + return self.reduction_numel != 1 + + @cache_on_self + def scheduler_nodes(self) -> Iterable[SchedulerNode]: + return tuple(NodeScheduleMarker.only_nodes(self.node_schedule)) + + def reduction_nodes(self) -> list[SchedulerNode]: + return [n for n in self.scheduler_nodes() if n.is_reduction()] + + @cache_on_self + def buf_accesses(self) -> dict[str, list[Dep]]: + """only needed for config.benchmark_kernel""" + buf_accesses = collections.defaultdict(list) + for node in self.scheduler_nodes(): + for access in node.read_writes.reads | node.read_writes.writes: + buf_accesses[access.name].append(access) + return buf_accesses + + @cache_on_self + def op_counts(self) -> collections.Counter[str]: + counts: collections.Counter[str] = collections.Counter() + for node in self.scheduler_nodes(): + counts.update(node._body.op_counts) + return counts + + def contains_op(self, op_name: str) -> bool: + """True if V.ops.{op_name} is used in node_schedule""" + return bool(self.op_counts().get(op_name)) + + def get_mutations(self) -> OrderedSet[str]: + mutations: OrderedSet[str] = OrderedSet() + for node in self.scheduler_nodes(): + for buf in node.get_outputs(): + mutations.update(buf.get_mutations()) + return mutations + + @cache_on_self + def select_index_dtype(self) -> torch.dtype: + # Gather all used buffer names + buffer_names: OrderedSet[str] = OrderedSet() + for node in self.scheduler_nodes(): + buffer_names.update(node.get_buffer_names()) + buffer_names.update(node.used_buffer_names()) + buffers = [V.graph.get_buffer(name) for name in buffer_names] + + # In theory we can separately check xnumel and rnumel are <= int_max + # but some indexers do use the full linear index so we need to be + # conservative here. 
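+ # e.g. xnumel and rnumel may each fit in int32 while their product overflows it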
+ total_numel = self.numel * self.reduction_numel + + from .simd import SIMDScheduling + + if SIMDScheduling.can_use_32bit_indexing(total_numel, buffers): + return torch.int32 + return torch.int64 + + @cache_on_self + def get_reduction_hint(self) -> ReductionHint: + reductions = self.reduction_nodes() + if len(reductions) > 0: + hints = [self.reduction_hint(n) for n in reductions] + if hints.count(hints[0]) == len(hints): + reduction_hint_val = hints[0] + else: + reduction_hint_val = ReductionHint.DEFAULT + + if ( + reduction_hint_val == ReductionHint.INNER + and self.has_non_contiguous_pw_in_reduction_kernel() + ): + reduction_hint_val = ReductionHint.DEFAULT + else: + reduction_hint_val = ReductionHint.DEFAULT + return reduction_hint_val + + @cache_on_self + def buffer_read_counts(self) -> dict[str, int]: + """Counts how many times each buffer is read within the kernel""" + read_counts: dict[str, int] = collections.defaultdict(int) + + for node in self.scheduler_nodes(): + # node.read_writes.reads contains MemoryDep objects for each read + for read_dep in node.read_writes.reads: + read_counts[read_dep.name] += 1 + + return dict(read_counts) # Convert defaultdict to regular dict + + def has_non_contiguous_pw_in_reduction_kernel(self) -> bool: + pointwise_nodes = [ + n + for n in self.scheduler_nodes() + if not n.is_reduction() + and n.group[1][0] == self.numel * self.reduction_numel + ] + for node in pointwise_nodes: + # An index can be an integer when loading a random seed. + if not all( + not isinstance(dep, MemoryDep) + or dep.is_contiguous() + or isinstance(dep.index, (sympy.Integer, int)) + or dep.stride1_for_last_dim() + for dep in itertools.chain( + node.read_writes.reads, node.read_writes.writes + ) + ): + return True + return False + + @staticmethod + def reduction_hint(node: Any) -> ReductionHint: + assert node.is_reduction() + if node.node.data.reduction_hint != ReductionHint.INNER and all( + dep.is_contiguous() + for dep in itertools.chain(node.read_writes.reads, node.read_writes.writes) + ): + return ReductionHint.INNER + else: + return node.node.data.reduction_hint + + def memory_stats( + self, groups_dict: Optional[dict[str, sympy.Expr]] = None + ) -> MemoryStats: + """Analysis to generate features that can be used in heuristics""" + if groups_dict is None: + groups = (self.numel, self.reduction_numel) + elif groups_dict.keys() == OrderedSet(["x", "r0_"]): + groups = (groups_dict["x"], groups_dict["r0_"]) + else: + raise NotImplementedError(f"groups_dict={groups_dict!r}") + result = self._stats_cache.get(groups) + if result is None: + self._stats_cache[groups] = result = MemoryStats.compute( + MemoryEstimator(self, groups) + ) + return result + + +class MemoryEstimator: + """ + Estimate various properties of the kernel for use in heuristics. + We simulate the memory effects of CSE/buffer elimination in codegen. + """ + + kernel_sizes: tuple[sympy.Expr, ...] 
+ outside_loop: MemoryEstimate + loops: list[MemoryEstimate] + persistent: MemoryEstimate + symbols: list[sympy.Symbol] + + def __init__(self, features: SIMDKernelFeatures, groups: Sequence[sympy.Expr]): + self.features = features + self.inside_reduction = features.is_reduction() + self.store_buffer_names: OrderedSet[str] = OrderedSet() + self.must_keep_buffers: OrderedSet[str] = OrderedSet() + self.num_reductions_dims = 1 + self.groups = groups + self.symbols = [make_symbol(SymT.INDEX, i) for i in range(len(groups))] + # We are doing two estimates simultaneously: + # 1) the first is a for a non-persistent (aka looped) reduction, using self.outside_loop/self.loops + # we add an item to loops each corresponding to each reduction loop in the kernel + # outside_loop is only used for broadcasting or point-wise ops that don't use the reduction dimension + # 2) the second is for a persistent kernel, using self.persistent + # persistent kernels don't have loops, so we only have one MemoryEstimate() + # for point-wise ops the two estimates will be the same, they matter for reductions only + self.outside_loop = MemoryEstimate() + self.loops = [MemoryEstimate()] + self.persistent = MemoryEstimate() + self.simulate_codegen() + self.remove_kernel_local() + + def simulate_codegen(self) -> None: + from .simd import SIMDKernel + + kernel_size_outside_loop = (*self.groups[:-1], sympy.S.One) + kernel_size_inside_loop = tuple(self.groups) + self.kernel_sizes = kernel_size_inside_loop + + for node in self.features.node_schedule: + if node is DisableReduction: + self.inside_reduction = False + self.kernel_sizes = kernel_size_outside_loop + continue + elif node is EnableReduction: + self.inside_reduction = True + self.kernel_sizes = kernel_size_inside_loop + self.loops.append(MemoryEstimate()) + continue + assert isinstance(node, SchedulerNode) + rw = extract_loop_body_with_args( + node._body, + SIMDKernel.map_kernel_groups_to_node_sizes( + self.kernel_sizes, node.get_ranges(), self.set_ranges + ), + dict(zip(self.symbols, self.kernel_sizes)), + ) + + for dep in rw._reads: + assert isinstance(dep, MemoryDep) + dep = dep.simplify_with_ranges() + if not self.persistent.writes.get(dep.name): # cache miss? 
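+ # a buffer already written by this kernel is reused directly, so it does not count as a global read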
+ self.persistent.reads[dep.name].add(dep) + # the cache behavior of looped kernels is more complex than the persistent case above + # some operations are lifted outside the loop (if they don't use the reduction dimension) + # other operations are inside the loop, and can only be reused within the same loop + if not ( + self.outside_loop.writes.get(dep.name) + or self.loops[-1].writes.get(dep.name) + ): + self.scope(dep).reads[dep.name].add(dep) + if dep.name in self.store_buffer_names and self.loops[-1].reads.get( + dep.name + ): + self.must_keep_buffers.add(dep.name) + + for dep in rw._writes: + assert isinstance(dep, MemoryDep) + dep = dep.simplify_with_ranges() + self.store_buffer_names.add(dep.name) + self.persistent.writes[dep.name].add(dep) + self.scope(dep).writes[dep.name].add(dep) + + def remove_kernel_local(self) -> None: + # Remove any kernel-local buffers + fused_node_names = OrderedSet( + [n.get_name() for n in self.features.scheduler_nodes()] + ) + for name in self.store_buffer_names: + if not self.persistent.reads.get( + name + ) and V.graph.scheduler.can_buffer_be_removed_through_fusion( + name, fused_node_names + ): + self.persistent.remove(name) + if name not in self.must_keep_buffers: + # we can also remove this from the looped kernel + self.outside_loop.remove(name) + for loop in self.loops: + loop.remove(name) + + if not self.loops[-1]: + self.loops.pop() # for pointwise ops + + def scope(self, dep: MemoryDep) -> MemoryEstimate: + """Determine how a read/write should be categorized""" + if self.inside_reduction and ( + self.has_reduction_var(dep.index) or dep.is_indirect() + ): + return self.loops[-1] + return self.outside_loop + + def has_reduction_var(self, index: sympy.Expr) -> bool: + for sym in self.symbols[-self.num_reductions_dims :]: + if isinstance(sym, sympy.Symbol) and sym in index.free_symbols: + return True + return False + + def set_ranges(self, *lengths: list[list[sympy.Expr]]) -> list[list[sympy.Expr]]: + assert len(self.kernel_sizes) == len(lengths) + return [ + self.make_flat_range(sym, numel, length) + for sym, numel, length in zip(self.symbols, self.kernel_sizes, lengths) + ] + + @staticmethod + def make_flat_range( + sym: sympy.Symbol, numel: sympy.Expr, lengths: list[sympy.Expr] + ) -> list[sympy.Expr]: + if len(lengths) == 1 and numel == lengths[0]: + return [sym] + divisor = sympy.S.One + itervars = [] + for length in reversed(lengths): + if V.graph.sizevars.statically_known_equals(divisor * length, numel): + expr = FloorDiv(sym, divisor) + else: + expr = ModularIndexing(sym, divisor, length) + itervars.append(expr) + divisor = divisor * length + return [*reversed(itervars)] + + +@dataclasses.dataclass +class MemoryEstimate: + """Tracks the memory usage of a single loop in the generated kernel""" + + reads: dict[str, OrderedSet[MemoryDep]] = dataclasses.field( + default_factory=functools.partial(collections.defaultdict, OrderedSet) + ) + writes: dict[str, OrderedSet[MemoryDep]] = dataclasses.field( + default_factory=functools.partial(collections.defaultdict, OrderedSet) + ) + + def remove(self, name: str) -> None: + self.reads.pop(name, None) + self.writes.pop(name, None) + + def __bool__(self) -> bool: + return bool(self.reads or self.writes) + + def __repr__(self) -> str: + return f"""MemoryEstimate( + reads={[*itertools.chain.from_iterable(self.reads.values())]!r}, + writes={[*itertools.chain.from_iterable(self.writes.values())]!r} + )""" + + +@dataclasses.dataclass +class StatsForDim: + """Memory usage stats for a block dimension in the 
generated kernel (different from user dimensions)""" + + # the number of load/store ops + count_per_thread_contiguous: int = 0 + count_per_thread_broadcast: int = 0 + count_per_thread_non_contiguous: int = 0 # excludes broadcast + + # total bytes in each load/store op for a single element + bytes_per_thread_contiguous: int = 0 + bytes_per_thread_broadcast: int = 0 + bytes_per_thread_non_contiguous: int = 0 # excludes broadcast + + # total bytes read by entire kernel + bytes_contiguous_or_broadcast: sympy.Expr = sympy.S.Zero + bytes_non_contiguous: sympy.Expr = sympy.S.Zero + + def __add__(self, other: typing.Self) -> StatsForDim: + return StatsForDim( + count_per_thread_contiguous=self.count_per_thread_contiguous + + other.count_per_thread_contiguous, + count_per_thread_broadcast=self.count_per_thread_broadcast + + other.count_per_thread_broadcast, + count_per_thread_non_contiguous=self.count_per_thread_non_contiguous + + other.count_per_thread_non_contiguous, + bytes_per_thread_contiguous=self.bytes_per_thread_contiguous + + other.bytes_per_thread_contiguous, + bytes_per_thread_broadcast=self.bytes_per_thread_broadcast + + other.bytes_per_thread_broadcast, + bytes_per_thread_non_contiguous=self.bytes_per_thread_non_contiguous + + other.bytes_per_thread_non_contiguous, + bytes_contiguous_or_broadcast=self.bytes_contiguous_or_broadcast + + other.bytes_contiguous_or_broadcast, + bytes_non_contiguous=self.bytes_non_contiguous + other.bytes_non_contiguous, + ) + + @property + def count_per_thread(self) -> int: + return ( + self.count_per_thread_contiguous + + self.count_per_thread_broadcast + + self.count_per_thread_non_contiguous + ) + + @property + def bytes_per_thread(self) -> int: + return ( + self.bytes_per_thread_contiguous + + self.bytes_per_thread_broadcast + + self.bytes_per_thread_non_contiguous + ) + + @property + def bytes(self) -> sympy.Expr: + return self.bytes_contiguous_or_broadcast + self.bytes_non_contiguous + + @property + def contiguous_score(self) -> float: + return 1.0 - self.count_per_thread_non_contiguous / max( + self.count_per_thread, 1 + ) + + +@dataclasses.dataclass +class StatsForLoop: + """Memory usage stats for single loop in the generated kernel""" + + # load/store ops + count_per_thread: int = 0 + bytes_per_thread: int = 0 + + def __add__(self, other: typing.Self) -> StatsForLoop: + return StatsForLoop( + count_per_thread=self.count_per_thread + other.count_per_thread, + bytes_per_thread=self.bytes_per_thread + other.bytes_per_thread, + ) + + +@dataclasses.dataclass +class StatsForReadsOrWrites: + """Memory usage stats that are collected for reads/writes/both""" + + dim: list[StatsForDim] + loop: list[StatsForLoop] + # total bytes contiguous in any dimension + bytes_contiguous_or_broadcast: sympy.Expr = sympy.S.Zero + bytes_non_contiguous: sympy.Expr = sympy.S.Zero + + def __add__(self, other: typing.Self) -> StatsForReadsOrWrites: + assert len(self.dim) == len(other.dim) + assert len(self.loop) == len(other.loop) + return StatsForReadsOrWrites( + dim=[a + b for a, b in zip(self.dim, other.dim)], + loop=[a + b for a, b in zip(self.loop, other.loop)], + bytes_contiguous_or_broadcast=self.bytes_contiguous_or_broadcast + + other.bytes_contiguous_or_broadcast, + bytes_non_contiguous=self.bytes_non_contiguous + other.bytes_non_contiguous, + ) + + @property + def count_per_thread(self) -> int: + return self.dim[0].count_per_thread + + @property + def bytes_per_thread(self) -> int: + return self.dim[0].bytes_per_thread + + @property + def bytes(self) -> sympy.Expr: + 
return self.bytes_contiguous_or_broadcast + self.bytes_non_contiguous + + @classmethod + def compute( + cls, + loop_deps: list[dict[str, OrderedSet[MemoryDep]]], + index_symbols: list[sympy.Symbol], + ) -> typing.Self: + ndim = len(index_symbols) + result = cls(dim := [StatsForDim() for _ in range(ndim)], []) + for dep_group in loop_deps: + result.loop.append(loop_stats := StatsForLoop()) + for name, deps in dep_group.items(): + assert deps + contiguous_or_broadcast = [True] * ndim + numel = sympy.S.Zero + itemsize = V.graph.get_dtype(name).itemsize + loop_stats.count_per_thread += len(deps) + loop_stats.bytes_per_thread += itemsize * len(deps) + for dep in deps: + strides: list[sympy.Expr] = V.graph.sizevars.stride_vars( + dep.index, index_symbols + ) + for i in range(ndim): + if V.graph.sizevars.statically_known_equals(strides[i], 1): + dim[i].count_per_thread_contiguous += 1 + dim[i].bytes_per_thread_contiguous += itemsize + elif ( + V.graph.sizevars.statically_known_equals(strides[i], 0) + and not dep.is_indirect() + ): + dim[i].count_per_thread_broadcast += 1 + dim[i].bytes_per_thread_broadcast += itemsize + else: + dim[i].count_per_thread_non_contiguous += 1 + dim[i].bytes_per_thread_non_contiguous += itemsize + contiguous_or_broadcast[i] = False + numel += dep.get_numel() + if len(deps) > 1: + # can't read more elements than exist in the buffer + numel = sympy.Min(numel, V.graph.get_numel(name)) + nbytes = numel * itemsize + for i in range(ndim): + if contiguous_or_broadcast[i]: + dim[i].bytes_contiguous_or_broadcast += nbytes + else: + dim[i].bytes_non_contiguous += nbytes + if any(contiguous_or_broadcast): + result.bytes_contiguous_or_broadcast += nbytes + else: + result.bytes_non_contiguous += nbytes + if len(result.loop) > 1: + # the first loop represent the "outside of the loop" compute which could be long lived + result.loop = [result.loop[0] + x for x in result.loop[1:]] + return result + + +@dataclasses.dataclass +class StatsForKernelType: + """Memory usage stats that are collected for both persistent and looped kernels""" + + reads: StatsForReadsOrWrites + writes: StatsForReadsOrWrites + memory: StatsForReadsOrWrites + + @classmethod + def compute( + cls, loops: list[MemoryEstimate], estimator: MemoryEstimator + ) -> typing.Self: + reads = StatsForReadsOrWrites.compute( + [loop.reads for loop in loops], estimator.symbols + ) + writes = StatsForReadsOrWrites.compute( + [loop.writes for loop in loops], estimator.symbols + ) + return cls( + reads=reads, + writes=writes, + memory=reads + writes, + ) + + +@dataclasses.dataclass +class MemoryStats: + """Memory usage stats collected for each generated kernel""" + + persistent: StatsForKernelType + looped: StatsForKernelType + + def get(self, persistent: bool) -> StatsForKernelType: + return self.persistent if persistent else self.looped + + @classmethod + def compute(cls, estimator: MemoryEstimator) -> typing.Self: + persistent = StatsForKernelType.compute([estimator.persistent], estimator) + if len(estimator.loops) == 1 and not ( + estimator.outside_loop and estimator.loops[0] + ): + looped = persistent # loops/persistent is the same in this common case + else: + looped = StatsForKernelType.compute( + [estimator.outside_loop, *estimator.loops], estimator + ) + return cls( + persistent=persistent, + looped=looped, + ) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/subgraph.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/subgraph.py new file mode 100644 index 
0000000000000000000000000000000000000000..b89966947e56e4282bfdc75051035cd4e96dcbc8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/subgraph.py @@ -0,0 +1,208 @@ +import itertools +import logging +from typing import Any, Callable + +import torch +import torch._inductor.config as config +from torch._inductor import ir +from torch._inductor.codegen.common import KernelTemplate +from torch._inductor.ir import ( + Buffer, + get_free_symbols, + get_symbolic_inputs, + gm_original_output_strides, + ir_node_to_tensor, + Layout, +) +from torch._inductor.runtime.benchmarking import benchmarker +from torch._inductor.utils import do_bench_using_profiling +from torch._inductor.virtualized import V + + +log = logging.getLogger(__name__) + + +class SubgraphChoiceCaller(ir.ChoiceCaller): + """ + Represents a Subgraph Autotuning choice, and the subgraph can be any arbitrary + GraphModule. Compiles the Subgraph down to a module for benchmarking. + """ + + def __init__( + self, + name: str, + input_nodes: list[Buffer], + layout: Layout, + description: str, + make_fx_graph: Callable[..., Any], + ) -> None: + super().__init__(name, input_nodes, layout, description) + + self.example_inputs = [] + with V.fake_mode: + for inp in self.input_nodes: + # Here there will be no unbacked symbols, as SubgraphBuffer does not support them + assert len(get_free_symbols(inp.get_size(), unbacked_only=True)) == 0 + assert len(get_free_symbols(inp.get_stride(), unbacked_only=True)) == 0 + + inp.data.freeze_layout() # type: ignore[attr-defined] + self.example_inputs.append(ir_node_to_tensor(inp)) + + self.gm = make_fx_graph(*self.example_inputs) + gm_original_output_strides(self.gm) + + self.sym_inputs = get_symbolic_inputs(self.input_nodes) + + def __str__(self) -> str: + return f"SubgraphCaller({self.name})" + + def benchmark(self, *args: list[Any], out: torch.Tensor) -> float: + # Codegen Subgraph for benchmarking + # Need GraphLowering instead of SubgraphLowering to generate + # fully callable module + import torch._inductor.config as inductor_config + from torch._inductor.graph import GraphLowering + + bm_graph_lowering = GraphLowering( + gm=self.gm, + example_inputs=self.example_inputs, + shape_env=V.graph._shape_env, + cpp_wrapper=V.graph.cpp_wrapper, + aot_mode=V.graph.aot_mode, + extern_node_serializer=V.graph.extern_node_serializer, + is_inference=V.graph.is_inference, + is_backward=V.graph.is_backward, + name=f"benchmark_{self.name}", + ) + + for sym_inp in self.sym_inputs: + bm_graph_lowering.graph_inputs[sym_inp.name] = sym_inp + bm_graph_lowering.graph_input_names.append(sym_inp.name) + + sym_inputs = [ + int(V.graph.sizevars.shape_env.size_hint(sym_var)) + for sym_var in self.sym_inputs + ] + + if len(sym_inputs) == 0: + # Sanity check that args are same layout as example inputs + # Only do it if there are no symbolic inputs, otherwise + # the dynamic dim will be realized to the same size as args + for ar, example_inp in zip(args, self.example_inputs): + # Sanity check that args are same layout as example inputs + if isinstance(ar, torch.Tensor): + assert isinstance(example_inp, torch.Tensor) + assert ar.shape == example_inp.shape + assert ar.stride() == example_inp.stride() + + if len(sym_inputs) == 0: + # Sanity check that args are same layout as example inputs + # Only do it if there are no symbolic inputs, otherwise + # the dynamic dim will be realized to the same size as args + for ar, example_inp in zip(args, self.example_inputs): + # Sanity check that args are same layout as 
example inputs + if isinstance(ar, torch.Tensor): + assert isinstance(example_inp, torch.Tensor) + assert ar.shape == example_inp.shape + assert ar.stride() == example_inp.stride() + + with V.set_graph_handler(bm_graph_lowering): + # Don't bother autotuning on Triton here + with inductor_config.patch( + max_autotune=False, + max_autotune_gemm=False, + max_autotune_gemm_backends="ATEN", + ): + bm_graph_lowering.run(*self.example_inputs) + mod = bm_graph_lowering.compile_to_module() + bm_func = mod.call + + bm_func([*sym_inputs, *args]) + if config.profile_bandwidth_with_do_bench_using_profiling: + return do_bench_using_profiling(lambda: bm_func([*sym_inputs, *args])) + return benchmarker.benchmark_gpu(lambda: bm_func([*sym_inputs, *args])) + + def hash_key(self) -> str: + return "-".join( + [ + self.name.rsplit("_", 1)[0], + *[str(inp.get_size()) for inp in self.input_nodes], + *[str(inp.get_stride()) for inp in self.input_nodes], + str(self.gm.graph), + ] + ) + + def output_node(self) -> ir.TensorBox: + return ir.TensorBox.create( + ir.SubgraphBuffer( + layout=self.layout, + input_nodes=self.input_nodes, + gm=self.gm, + example_inputs=self.example_inputs, + subgraph_name=self.name, + ) + ) + + def info_dict(self) -> dict[str, Any]: + """Information returned here is logged to the autotune log file when that is enabled.""" + return { + "backend": "subgraph", + "kernel_name": self.name, + } + + def autoheuristic_id(self) -> str: + return f"subgraph_{self.name}" + + +class SubgraphTemplate(KernelTemplate): + """ + A template for subgraph evaluation to be used in autotuning. + + This class allows creating customized subgraphs that can be appended + as choices during the autotuning process, enabling the selection of + optimal implementations for complex operations. + """ + + index_counter = itertools.count() + + def __init__( + self, + name: str, + make_fx_graph: Callable[..., Any], + ): + """ + Initialize a subgraph template. + + Args: + name: The name of this template + graph: The FX graph + """ + self.name = f"{name}_{next(SubgraphTemplate.index_counter)}" + self.make_fx_graph = make_fx_graph + + def generate( # type: ignore[override] + self, + input_nodes: list[Buffer], + layout: Layout, + **kwargs: Any, + ) -> SubgraphChoiceCaller: + """ + Generate a SubgraphChoiceCaller instance for autotuning. 
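+ The returned choice traces make_fx_graph over example inputs derived from the input nodes and can benchmark the compiled subgraph during autotuning.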
+ + Args: + input_nodes: List of input nodes to the subgraph + layout: Memory layout information for the output + example_inputs: Example tensor inputs used to trace and benchmark the subgraph + **kwargs: Additional keyword arguments + + Returns: + SubgraphChoiceCaller: A callable object that can be used for autotuning + """ + + return SubgraphChoiceCaller( + name=self.name, + input_nodes=input_nodes, + layout=layout, + description="", + make_fx_graph=self.make_fx_graph, + ) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/triton.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/triton.py new file mode 100644 index 0000000000000000000000000000000000000000..a404abc136f52a7bb5207d3f3177c033115daf39 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/triton.py @@ -0,0 +1,4526 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import collections +import contextlib +import dataclasses +import functools +import itertools +import logging +import math +import operator +import os +import textwrap +from collections.abc import Iterable, Sequence +from functools import lru_cache +from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union + +import sympy +from sympy.printing.precedence import PRECEDENCE + +import torch +import torch._logging +import torch.utils._pytree as pytree +from torch._dynamo.device_interface import get_interface_for_device +from torch._dynamo.utils import identity, preserve_rng_state +from torch._prims_common import is_integer_dtype +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing +from torch.utils._triton import has_triton_package + +from ...utils._sympy.symbol import free_symbol_is_type, prefix_str, symbol_is_type, SymT +from ...utils._sympy.value_ranges import ValueRanges +from .. 
import config, ir, metrics +from ..async_compile import AsyncCompile +from ..codecache import code_hash, get_path, PyCodeCache, write_atomic +from ..ops_handler import DefaultHandler +from ..runtime import triton_heuristics +from ..runtime.benchmarking import benchmarker +from ..runtime.hints import ( + AutotuneHint, + DeviceProperties, + TRITON_MAX_BLOCK, + TRITON_MAX_RSPLIT, +) +from ..runtime.runtime_utils import get_max_y_grid, next_power_of_2 +from ..scheduler import BaseSchedulerNode, FusedSchedulerNode, Scheduler, SchedulerNode +from ..utils import ( + cache_on_self, + DelayReplaceLine, + get_bounds_index_expr, + get_fused_kernel_name, + get_kernel_metadata, + is_welford_reduction, + Placeholder, + prefix_is_reduction, + sympy_dot, + sympy_product, + sympy_subs, + triton_type, + triton_version_uses_attrs_dict, + upcast_compute_type, +) +from ..virtualized import _ops as ops, ReductionType, StoreMode, V +from ..wrapper_benchmark import get_kernel_category_by_source_code +from .block_analysis import BlockPatternMatcher +from .common import ( + ArgName, + BackendFeature, + ConstexprArg, + CSE, + CSEVariable, + DeferredLine, + IndentedBuffer, + InplacedBuffer, + OpOverrides, + PythonPrinter, + RemovedArg, + SizeArg, + TensorArg, + WorkspaceArg, + WorkspaceZeroMode, +) +from .simd import ( + constant_repr, + IterationRanges, + IterationRangesEntry, + IterationRangesRoot, + SIMDKernel, + SIMDScheduling, +) +from .triton_utils import ( + config_of, + equal_1_arg_indices, + non_constexpr_signature, + should_unwrap_unspec_arg, + signature_to_meta, +) +from .wrapper import SymbolicCallArg + + +if TYPE_CHECKING: + from types import ModuleType + from typing import TypeVar + + from torch._inductor.dtype_propagation import DtypePropagationOpsHandler + + from ..ir import IRNode + from .simd_kernel_features import SIMDKernelFeatures + + _T = TypeVar("_T") + +log = logging.getLogger(__name__) +perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints") +schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") +fusion_log = torch._logging.getArtifactLogger(__name__, "fusion") +async_compile = AsyncCompile() + + +class OpDtypeSupport: + """ + Some Triton ops such as libdevice and tl.math only support float32 and float64. + This class records which dtypes are supported by specific IR ops. + """ + + supported_dtypes: dict[str, OrderedSet[torch.dtype]] = {} + convert_outputs: dict[str, bool] = {} + + @classmethod + def register_upcast(cls, func: Callable[..., str], convert_output: bool) -> None: + op_name = func.__name__ + cls.supported_dtypes[op_name] = OrderedSet([torch.float32, torch.float64]) + cls.convert_outputs[op_name] = convert_output + + +@lru_cache(None) +def gen_attr_descriptor_import() -> str: + """ + import AttrsDescriptor if the triton version is new enough to have this + class defined. + """ + if not has_triton_package(): + return "" + + import triton.compiler.compiler + + # Note: this works because triton.compiler.compiler imports AttrsDescriptor from triton.backends.compiler + # When support for the legacy AttrsDescriptor is removed then this import path should be changed. 
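+ # on newer Triton versions without AttrsDescriptor this returns an empty string and no import is emitted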
+ if hasattr(triton.compiler.compiler, "AttrsDescriptor"): + return "from triton.compiler.compiler import AttrsDescriptor" + else: + return "" + + +@lru_cache(None) +def gen_common_triton_imports() -> str: + imports = IndentedBuffer() + imports.splice( + """ + import triton + import triton.language as tl + """ + ) + if attr_desc := gen_attr_descriptor_import(): + imports.writeline(attr_desc) + + imports.splice( + """ + from torch._inductor.runtime import triton_helpers, triton_heuristics + from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math + from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + """ + ) + return imports.getvalue() + + +class TritonSymbols: + """ + Stores sympy.Symbol instances and constants associated with triton codegen. + """ + + reduction_types = OrderedSet([SymT.R0_INDEX, SymT.R1_INDEX]) + block_types = OrderedSet([SymT.XBLOCK, SymT.YBLOCK, SymT.ZBLOCK, *reduction_types]) + + block_offsets = { + symt: sympy.Symbol(f"{prefix_str[symt]}offset", integer=True, nonnegative=True) + for symt in block_types + } + + block_sizes = { + symt: sympy.Symbol( + f"{prefix_str[symt].upper()}BLOCK", integer=True, positive=True + ) + for symt in block_types + } + + @classmethod + def get_block_size(cls, tree: IterationRanges) -> sympy.Symbol: + return cls.block_sizes[tree.symt] + + @classmethod + def get_block_offset(cls, tree: IterationRanges) -> sympy.Symbol: + return cls.block_offsets[tree.symt] + + +@dataclasses.dataclass +class IndexingOptions: + index_str: str + mask_vars: OrderedSet[str] + expand_str: Optional[str] + _has_rindex: bool + index: sympy.Expr + + def has_mask(self) -> bool: + return bool(self.mask_vars) + + def has_indirect(self) -> bool: + return free_symbol_is_type(self.index, SymT.TMP) + + def has_rindex(self) -> bool: + return self._has_rindex + + def has_tmpmask(self) -> bool: + return any(str(mask).startswith("tmp") for mask in self.mask_vars) + + def has_rmask(self) -> bool: + return any(str(mask).startswith("r") for mask in self.mask_vars) + + @property + def mask_str(self) -> str: + # The sorted call is added to make sure the order is still + # deterministic if self.mask_vars contains mix of string + # and TritonCSEVariable + return ( + " & ".join(sorted(map(str, self.mask_vars))) if self.mask_vars else "None" + ) + + +@dataclasses.dataclass +class BlockPtrOptions: + params: BlockParameters + constant_offset: sympy.Expr + order: list[int] + mask_vars: OrderedSet[str] + broadcast_shape: Sequence[sympy.Expr] + broadcasting_dims: list[bool] + final_shape: Sequence[sympy.Expr] + _boundary_check: Optional[list[int]] = None + + @property + def shape(self) -> list[sympy.Expr]: + return self.params.shape + + @property + def block_shape(self) -> list[sympy.Expr]: + return self.params.block_shape + + @property + def strides(self) -> list[sympy.Expr]: + return self.params.strides + + @property + def offsets(self) -> list[sympy.Expr]: + return self.params.offsets + + def codegen_broadcast_and_reshape( + self, + value: str, + initial_shape: Sequence[sympy.Expr], + final_shape: Sequence[sympy.Expr], + allow_implicit: bool, + ) -> str: + """ + Generate a broadcast and a reshape for the block pointer. + This restores stride-0 dimensions which were removed from the block pointer. + """ + + # Reshape to add singletons. 
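+ # e.g. broadcast_shape [YBLOCK, XBLOCK] with a broadcasting Y dim becomes [1, XBLOCK] here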
+ pre_broadcast_shape = [ + sympy.S.One if is_broadcasting else dim + for dim, is_broadcasting in zip( + self.broadcast_shape, self.broadcasting_dims + ) + ] + value = triton_reshape(value, initial_shape, pre_broadcast_shape) + + # Broadcast singletons. + # For loads, we can often implicitly broadcast singleton dimensions. + # We need an explicit broadcast for stores, or if the final reshape does more + # than add singletons. + sizevars = V.graph.sizevars + supports_implicit_broadcast = allow_implicit and ( + len(pre_broadcast_shape) == len(final_shape) + and all( + sizevars.statically_known_equals(pre_dim, 1) + or sizevars.statically_known_equals(pre_dim, post_dim) + for pre_dim, post_dim in zip(pre_broadcast_shape, final_shape) + ) + ) + + if any(self.broadcasting_dims) and not supports_implicit_broadcast: + value = f"tl.broadcast_to({value}, {V.kernel.index_to_str(self.broadcast_shape)})" + + # Reshape to the final shape. + value = triton_reshape(value, self.broadcast_shape, final_shape) + + return value + + @staticmethod + def create( + *, + params: BlockParameters, + constant_offset: sympy.Expr, + range_trees: list[IterationRangesRoot], + mask_vars: OrderedSet[str], + get_max_block: Callable[[str], int], + ) -> BlockPtrOptions: + """Helper to create a BlockPtrOptions instance""" + + sizevars = V.graph.sizevars + + def lookup_size(exprs: Iterable[sympy.Expr]) -> list[sympy.Expr]: + return [sizevars.lookup_precomputed_size(expr) for expr in exprs] + + # Look up precomputed sizes + params.shape = lookup_size(params.shape) + params.strides = lookup_size(params.strides) + + # Strip out dimensions of stride 0. + # These will be restored with tl.broadcast_to. + broadcasting_dims = [ + sizevars.statically_known_equals(stride, 0) for stride in params.strides + ] + + # Strip out dimensions of size 1. + # These will be restored by tl.reshape. + singleton_dims = [ + sizevars.statically_known_equals(dim, 1) for dim in params.block_shape + ] + if all(singleton_dims): + # Handle a pure singletons, e.g. [1, 1] + singleton_dims[-1] = False + + # Record the post-broadcast shape before broadcasting dims are removed. + # The pre-broadcast shape is identical to this, except broadcasting dims are + # replaced with 1. + broadcast_shape = [ + dim + for dim, is_singleton in zip(params.block_shape, singleton_dims) + if not is_singleton + ] + + # Combine all removable dims. + removable_dims = [any(dims) for dims in zip(singleton_dims, broadcasting_dims)] + + def remove_dims(it): + """Removes any broadcasting or singleton dims from a given sequence""" + return [ + item + for item, is_removable in zip(it, removable_dims) + if not is_removable + ] + + # Drop removable dimensions from the input. + params = BlockParameters( + **{key: remove_dims(val) for key, val in dataclasses.asdict(params).items()} + ) + + # Compute the final shape, adjusting for special kernel types. 
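+ # typically one block size per range tree, e.g. [XBLOCK, R0_BLOCK]; the X dim is dropped for no_x_dim kernels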
+ final_shape = [TritonSymbols.get_block_size(tree) for tree in range_trees] + if V.kernel.no_x_dim: + assert range_trees[0].prefix == "x" + final_shape.pop(0) + + reduction_ndim = V.kernel.num_reduction_dims + if ( + not V.kernel.inside_reduction + and len(params.strides) == len(V.kernel.numels) - reduction_ndim + and V.kernel.features.is_reduction() + ): + # Need to expand rank to match the rank used inside the reduction loop + final_shape += [sympy.S.One] * reduction_ndim + + result = BlockPtrOptions( + params=params, + constant_offset=V.graph.sizevars.lookup_precomputed_size(constant_offset), + order=list(reversed(range(len(params.shape)))), + mask_vars=mask_vars, + final_shape=final_shape, + broadcast_shape=broadcast_shape, + broadcasting_dims=broadcasting_dims, + ) + result.compute_boundary_check(get_max_block, range_trees) + return result + + def replace_offset( + self, expr: sympy.Expr, replacement: sympy.Expr, symt: SymT + ) -> sympy.Expr: + """ + Replaces instances of {symt}_offset with the new expression. + """ + roffset = TritonSymbols.block_offsets[symt] + return sympy_subs(expr, {roffset: replacement}) + + def format(self, name: str, roffset=True) -> str: + """ + Codegen a call to tl.make_block_ptr() + + Args: + name: variable name for pointer + roffset: should rn_offset be included in offsets=..., for use with tl.advance() + + Returns: + "tl.make_block_ptr(...)" + """ + + def remove_roffsets(expr: sympy.Expr) -> sympy.Expr: + for symt in TritonSymbols.reduction_types: + expr = self.replace_offset(expr, sympy.Integer(0), symt) + return expr + + f = V.kernel.index_to_str + offsets = [*self.offsets] + if not roffset: + offsets = [remove_roffsets(offset) for offset in offsets] + args = [ + ( + f"{name} + ({f(self.constant_offset)})" + if self.constant_offset != 0 + else name + ), + f"shape={f(self.shape)}", + f"strides={f(self.strides)}", + f"block_shape={f(self.block_shape)}", + f"order={f(self.order)}", + f"offsets={f(offsets)}", + ] + return f"tl.make_block_ptr({', '.join(args)})" + + def compute_boundary_check( + self, + get_max_block: Callable[[str], int], + range_trees: list[IterationRangesRoot], + ) -> None: + """List of indices to pass to tl.load(boundary_check=...)""" + sizevars = V.graph.sizevars + + # Substitute maximum block sizes in shape expressions. + # This works in multiple_of checks because block sizes are powers of 2. + block_to_max: dict[sympy.Expr, Any] = { + TritonSymbols.block_sizes[t.symt]: get_max_block(prefix_str[t.symt]) + for t in range_trees + } + + # Also see Note: Constant mask optimisation + # if ynumel / YBLOCK > max_ygrid, then the z dimension is used to handle + # the remaining programs that cannot fit into the y dimension. This means + # it's possible that more than the required number of programs are launched, + # possibly leading to out-of-bounds accesses. So even if ynumel divides YBLOCK, + # boundary checking is required in the dimensions that are based on YBLOCK + # e.g. for [YBLOCK // 16, YBLOCK, XBLOCK] dimensions 0 and 1 need boundary + # checks when max_ygrid is exceeded. 
+ needs_overflow_grid = any(map(V.kernel.needs_yz_grid_overflow, range_trees)) + self._boundary_check = [ + idx + for idx in range(len(self.shape)) + if ( + not sizevars.statically_known_equals(self.strides[idx], sympy.S.Zero) + and ( + ( + needs_overflow_grid + and TritonSymbols.block_sizes[SymT.YBLOCK] + in self.block_shape[idx].free_symbols + ) + or ( + not sizevars.statically_known_multiple_of( + self.shape[idx], self.block_shape[idx] + ) + and not sizevars.statically_known_multiple_of( + self.shape[idx], + sympy_subs(self.block_shape[idx], block_to_max), + ) + ) + ) + and not ( + V.kernel.no_x_dim + and self.block_shape[idx] == TritonSymbols.block_sizes[SymT.XBLOCK] + ) + ) + ] + + def boundary_check(self) -> list[int]: + assert self._boundary_check is not None + return self._boundary_check + + def advance_roffset(self, symt: SymT) -> sympy.Expr: + """ + Codegen string to pass to tl.advance(name, ...). + + Advance is the difference between offsets in each loop iteration. + To compute it, we replace rN_offset with multiples of RN_BLOCK. + Since we expect rN_offset to vary in range(0, rN_numel, RN_BLOCK), the first + iteration has rN_offset=0, while the second has rN_offset=RN_BLOCK. + """ + rblock = TritonSymbols.block_sizes[symt] + advance = [ + ( + self.replace_offset(offset, rblock, symt) + - self.replace_offset(offset, sympy.S.Zero, symt) + ) + for offset in self.offsets + ] + return advance + + def has_indirect(self) -> bool: + return False # block_ptr can't do indirect indexing + + def has_rindex(self) -> bool: + return any( + free_symbol_is_type(expr, TritonSymbols.reduction_types) + for expr in self.block_shape + ) + + def has_rmask(self) -> bool: + return self.has_rindex() + + def has_tmpmask(self) -> bool: + return False # block_ptr can't do indirect indexing + + def has_mask(self) -> bool: + return bool(self.boundary_check()) + + +def triton_reshape( + value: str, old_shape: Sequence[sympy.Expr], new_shape: Sequence[sympy.Expr] +) -> str: + """Workaround https://github.com/triton-lang/triton/issues/2836""" + assert isinstance(old_shape, list) and isinstance(new_shape, list) + + old_shape_str = [V.kernel.index_to_str(shape) for shape in old_shape] + new_shape_str = [V.kernel.index_to_str(shape) for shape in new_shape] + + if old_shape_str == new_shape_str: + return value + if [s for s in new_shape_str if s != "1"] != old_shape_str: + return f"tl.reshape({value}, [{', '.join(new_shape_str)}])" + # rewrite to [:, None] syntax, which is less buggy + idx = 0 + expand = [] + for size in new_shape_str: + if idx < len(old_shape_str) and size == old_shape_str[idx]: + expand.append(":") + idx += 1 + else: + assert size == "1" + expand.append("None") + assert idx == len(old_shape_str) + return f"{value}[{', '.join(expand)}]" + + +# NB: Inheriting from PythonPrinter is somewhat dangerous, because there are a +# number of operators which Triton "implements", but in a way that is +# inconsistent with Python semantics (and consistent with C semantics). 
We +# must override all of these, or it is potential silent correctness problem +class TritonPrinter(PythonPrinter): + def _print_TruncToInt(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return ( + f"libdevice.trunc({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + ) + + def _print_Float(self, expr: sympy.Expr) -> str: + if config.is_fbcode() and torch.version.hip: + ret = f"{expr}" + else: + ret = f"tl.full([], {expr}, tl.float64)" + return ret + + def _print_ToFloat(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + s = self.parenthesize(expr.args[0], PRECEDENCE["Atom"] - 0.5) + return f"{s}.to(tl.float64)" + + def _print_PythonMod(self, expr: sympy.Expr) -> str: + quot, div = expr.args + if quot.is_nonnegative and div.is_nonnegative: + return self.stringify(expr.args, " % ", PRECEDENCE["Atom"] - 0.5) + quot_s = self._print(quot) + div_s = self._print(div) + return f"triton_helpers.remainder_integer({quot_s}, {div_s})" + + def _print_FloorDiv(self, expr: sympy.Expr) -> str: + assert expr.is_integer + quot, div = expr.args + if quot.is_nonnegative and div.is_nonnegative: + return self.stringify(expr.args, " // ", PRECEDENCE["Atom"] - 0.5) + quot_s = self._print(quot) + div_s = self._print(div) + return f"triton_helpers.div_floor_integer({quot_s}, {div_s})" + + # TODO: This is wrong, when lhs, rhs > 2**53, Python does a higher + # precision algorithm, which we would need to replicate here + def _print_IntTrueDiv(self, expr: sympy.Expr) -> str: + return self.stringify(expr.args, " / ", PRECEDENCE["Atom"] - 0.5) + + # NB: sympy.floor/ceiling produce integers, so we have to do the + # conversion to index dtype + def _print_floor(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return ( + f"libdevice.floor({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + ) + + def _print_FloorToInt(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return ( + f"libdevice.floor({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + ) + + def _print_ceiling(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"libdevice.ceil({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + + def _print_CeilToInt(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"libdevice.ceil({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + + def _helper_sqrt(self, expr: sympy.Expr) -> str: + return f"libdevice.sqrt(({self._print(expr)}).to(tl.float32))" + + def _print_FloatPow(self, expr: sympy.Expr) -> str: + return ( + f"libdevice.pow({self._print(expr.args[0])}, {self._print(expr.args[1])})" + ) + + def _print_PowByNatural(self, expr: sympy.Expr) -> str: + if expr.args[0].is_Integer: + return f"libdevice.pow({float(expr.args[0])}, {self._print(expr.args[1])})" + return ( + f"libdevice.pow({self._print(expr.args[0])}, {self._print(expr.args[1])})" + ) + + def _print_Where(self, expr: sympy.Expr) -> str: + c = self.doprint(expr.args[0]) + p = self.doprint(expr.args[1]) + q = self.doprint(expr.args[2]) + return f"tl.where({c}, {p}, {q})" + + def _print_min_max_helper(self, expr: sympy.Expr, cmp: str) -> str: + """ + Helper for max/min code generation. + cmp: > or < + """ + if len(expr.args) == 1: + return self._print(expr.args[0]) + + mid = len(expr.args) // 2 + cls = type(expr) + a = self._print(cls(*expr.args[:mid])) + b = self._print(cls(*expr.args[mid:])) + + # Use a macro so we can propagate constexprs. 
+ # https://github.com/triton-lang/triton/issues/3815 + a, b = tuple(f"({x})" for x in (a, b)) + assert cmp in (">", "<"), f"Unexpected comparator: '{cmp}'" + return f"({a} * ({a} {cmp}= {b}) + {b} * ({b} {cmp} {a}))" + + def _print_Min(self, expr: sympy.Expr) -> str: + return self._print_min_max_helper(expr, "<") + + def _print_Max(self, expr: sympy.Expr) -> str: + return self._print_min_max_helper(expr, ">") + + def _print_Abs(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"tl_math.abs({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_cos(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"libdevice.cos(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_OpaqueUnaryFn_cosh(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"libdevice.cosh(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_OpaqueUnaryFn_acos(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"libdevice.acos(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_OpaqueUnaryFn_sin(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"libdevice.sin(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_OpaqueUnaryFn_sinh(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"libdevice.sinh(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_OpaqueUnaryFn_asin(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"libdevice.asin(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_OpaqueUnaryFn_tan(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"libdevice.tan(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_OpaqueUnaryFn_tanh(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"libdevice.tanh(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_OpaqueUnaryFn_atan(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"libdevice.atan(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_OpaqueUnaryFn_log2(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"libdevice.log2(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_RoundToInt(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return ( + f"libdevice.llrint({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + ) + + def _print_RoundDecimal(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 2 + number, ndigits = expr.args + if number.is_integer: + # ndigits < 0 should have been filtered by the sympy function + assert ndigits < 0 + raise ValueError( + f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}." 
+ ) + + number_str = self.parenthesize(number, PRECEDENCE["Mul"]) + return f"libdevice.nearbyint(1e{ndigits} * {number_str}) * 1e{-ndigits}" + + +texpr = TritonPrinter().doprint + + +def triton_compute_type(dtype: torch.dtype) -> str: + """Convert torch.dtype to triton type and upcast [b]float16 to float32""" + return triton_type(upcast_compute_type(dtype)) + + +def triton_store_type(dtype: torch.dtype) -> str: + """Convert torch.dtype to triton type, with fix for storing tl.bool""" + if dtype == torch.bool: + dtype = torch.int8 + return triton_type(dtype) + + +def upcast_acc_dtype(dtype: torch.dtype) -> torch.dtype: + """Implicit upcasts used for Triton reduction types""" + if is_integer_dtype(dtype) and dtype.is_signed and dtype.itemsize <= 4: + return torch.int32 + return upcast_compute_type(dtype) + + +def triton_acc_type(dtype: torch.dtype) -> str: + """Convert torch.dtype to triton type, with reduction upcasts""" + return triton_compute_type(upcast_acc_dtype(dtype)) + + +def low_precision_fp(dtype: torch.dtype) -> bool: + return dtype.itemsize <= 2 and dtype.is_floating_point + + +def low_precision_fp_var(var: Union[CSEVariable, Any]) -> bool: + if not isinstance(var, CSEVariable): + return False + + dtype = var.dtype + return low_precision_fp(dtype) if isinstance(dtype, torch.dtype) else False + + +class TritonCSEVariable(CSEVariable): + def __init__(self, name, bounds: ValueRanges[Any], dtype: torch.dtype) -> None: + super().__init__(name, bounds, dtype) + # We'll use this to track which masks the variable needs when used for indirect indexing + self.mask_vars: OrderedSet[str] = OrderedSet() + assert dtype is not None, "TritonCSEVariable must have dtype" + + def update_on_args(self, name, args, kwargs): + for arg in args: + if isinstance(arg, TritonCSEVariable): + self.mask_vars.update(arg.mask_vars) + elif isinstance(arg, sympy.Symbol): + # most of the time index vars don't need masks associated with them + # however, when index vars are used to compute indices for indirect reads + # those reads should subsequently be masked, + for symt in TritonSymbols.block_types: + if symbol_is_type(arg, symt): + self.mask_vars.update([f"{prefix_str[symt]}mask"]) + break + + +def get_dtype_handler() -> DtypePropagationOpsHandler: + from torch._inductor.dtype_propagation import DtypePropagationOpsHandler + + return DtypePropagationOpsHandler() + + +def maybe_upcast_float32(convert_output: bool = True) -> Callable[[_T], _T]: + """ + Codegen helper to upcast arguments to float32, depending on the config and dtype. + This decorates tl.math/libdevice codegen functions. + """ + + def needs_upcast(var) -> bool: + return ( + not config.triton.codegen_upcast_to_fp32 + and isinstance(var, CSEVariable) + and var.dtype in (torch.float16, torch.bfloat16) + ) + + def maybe_upcast_arg(var) -> str: + upcast_string = ".to(tl.float32)" if needs_upcast(var) else "" + return f"{var}{upcast_string}" + + def decorator(func: Callable[..., Any]) -> Callable[..., Any]: + # Record that this function only supports float32 and float64. + OpDtypeSupport.register_upcast(func, convert_output) + + def wrapped(*args, **kwargs) -> str: + # Optionally upcast args to float32. + upcast_args = [maybe_upcast_arg(arg) for arg in args] + upcast_kwargs = {key: maybe_upcast_arg(val) for key, val in kwargs.items()} + + # Call the decorated function, optionally downcasting the result. 
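+            # Illustrative end-to-end example, assuming codegen_upcast_to_fp32 is off
+            # and "tmp3" is a hypothetical fp16 CSE variable: ops.sqrt(tmp3) comes out
+            # roughly as libdevice.sqrt(tmp3.to(tl.float32)).to(tl.float16), i.e. the
+            # args are upcast above and the result is downcast here when its propagated
+            # dtype is not already float32.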
+ result = func(*upcast_args, **upcast_kwargs) + any_needs_upcast = convert_output and any( + needs_upcast(var) for var in itertools.chain(args, kwargs.values()) + ) + result_dtype = ( + None + if not any_needs_upcast + else getattr(get_dtype_handler(), func.__name__)(*args, **kwargs) + ) + needs_downcast = result_dtype not in (torch.float32, None) + downcast_string = ( + f".to({triton_type(result_dtype)})" + if needs_downcast and result_dtype is not None + else "" + ) + return f"{result}{downcast_string}" + + return wrapped + + return decorator # type: ignore[return-value] + + +class TritonOverrides(OpOverrides): + """Map element-wise ops to Triton""" + + _LOG_2_E = math.log2(math.e) + + @staticmethod + def to_dtype( + x, + dtype: torch.dtype, + src_dtype: Optional[torch.dtype] = None, + use_compute_types=True, + ): + def _get_min_elements_per_thread( + src_dtype: torch.dtype, dst_dtype: torch.dtype + ) -> int: + if src_dtype == dst_dtype: + # No data type conversion is needed. No requirements on min_elem_per_thread. + return 0 + + # fp8 data type conversions has min_elem_per_thread requirements. + # Refer to Triton implementations here: + # https://github.com/triton-lang/triton/blob/10f59d8ce04052521c1bc0cb3a3f8b98918fc7e3/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp#L10. + fp8_dtypes = ( + torch.float8_e4m3fn, + torch.float8_e5m2, + ) + # Triton doesn't support type conversions between fp8_e4m3 and fp8_e5m2. + assert not ( + src_dtype in fp8_dtypes + and dst_dtype in fp8_dtypes + and src_dtype != dst_dtype + ), "Conversions between float8_e5m2 and float8_e4m3fn is not supported!" + if src_dtype == torch.float8_e5m2 or dst_dtype == torch.float8_e5m2: + return 4 + if src_dtype == torch.float8_e4m3fn or dst_dtype == torch.float8_e4m3fn: + return 2 + # No requirements on min_elem_per_thread. + return 0 + + if src_dtype is not None: + # Both dtype and src_dtype are set. This is used by torch to(dtype=dtype). + # It takes the maximum min_elem_per_thread if there are multiple fp8 conversions + # in the same kernel. + V.kernel.min_elem_per_thread = max( + _get_min_elements_per_thread(src_dtype, dtype), + V.kernel.min_elem_per_thread, + ) + + if dtype == torch.bool: + return f"({x} != 0)" + elif dtype == torch.uint8: + # to work around llvm uint conversion semantics + # that produces 0's for negative values + return f"{x}.to(tl.int8).to(tl.uint8)" + + if use_compute_types: + out_dtype = triton_compute_type(dtype) + else: + out_dtype = triton_store_type(dtype) + + return f"{x}.to({out_dtype})" + + @staticmethod + def to_dtype_bitcast(x, dtype: torch.dtype, src_dtype: torch.dtype): + assert src_dtype.itemsize == dtype.itemsize + # We may promote float16 or bfloat16 to float32 and cause the + # bitwidth of dtype to be different from the input tensor (i.e. float32). + # In such as case, we will have to convert the input tensor to + # its src_type, perform bitcast, and then convert the bit-casted + # tensor back to float to ensure we use values with the right precision. 
+ if x.dtype != src_dtype: + x = f"{x}.to({triton_type(src_dtype)})" + + out = f"{x}.to({triton_type(dtype)}, bitcast=True)" + if upcast_compute_type(dtype) != dtype: + out = f"{out}.to({triton_type(upcast_compute_type(dtype))})" + + return out + + @staticmethod + def _shaped_constant(value, dtype, shape): + type_ = torch._prims_common.dtype_to_type(dtype) + triton_val = constant_repr(type_(value)) + triton_type = triton_compute_type(dtype) + + if triton_type == "tl.float32": + # Float constants are always f32 in triton + return triton_val + + # NOTE: We use a tensor here in order to get the expected type. + # Otherwise, e.g. float64 constants would be truncated to float32. + if value < 0 and not dtype.is_signed: + triton_signed_type = f"tl.{triton_type[4:]}" + return f"tl.full({shape}, {triton_val}, {triton_signed_type}).to({triton_type})" + else: + return f"tl.full({shape}, {triton_val}, {triton_type})" + + @classmethod + def constant(cls, value, dtype): + return cls._shaped_constant(value, dtype, shape=[]) + + @staticmethod + @maybe_upcast_float32() + def abs(x): + return f"tl_math.abs({x})" + + # TODO - register these ops as having divergent dtype + # output if doing graph pass to remove consecutive casts + + @staticmethod + def truediv(x, y): + out = f"({x} / {y})" + if low_precision_fp_var(x) or low_precision_fp_var(y): + out_dtype = get_dtype_handler().truediv(x, y) + if out_dtype in (torch.float16, torch.float32): + out = f"{out}.to({triton_type(out_dtype)})" + + return out + + @staticmethod + def mod(x, y): + out = f"({x} % {y})" + if low_precision_fp_var(x) or low_precision_fp_var(y): + out_dtype = get_dtype_handler().mod(x, y) + if out_dtype in (torch.float16, torch.float32): + out = f"{out}.to({triton_type(out_dtype)})" + return out + + @staticmethod + @maybe_upcast_float32() + def exp(x): + """ + When use_fast_math, use the ftz (flushing to zero) variant + of exponent computation. + + Check https://github.com/triton-lang/triton/issues/5735 for + more details. + """ + if config.use_fast_math: + return f"libdevice.exp2({x} * {TritonOverrides._LOG_2_E})" + else: + return f"tl_math.exp({x})" + + @staticmethod + @maybe_upcast_float32() + def exp2(x): + return f"libdevice.exp2({x})" + + @staticmethod + @maybe_upcast_float32() + def expm1(x): + return f"libdevice.expm1({x})" + + @staticmethod + @maybe_upcast_float32() + def sqrt(x): + return f"libdevice.sqrt({x})" + + @staticmethod + def relu(x): + bug = config.triton.inject_relu_bug_TESTING_ONLY + if bug == "compile_error": + return "compile error!" 
+ elif bug == "runtime_error": + # NB: this only triggers runtime error as long as input + # is not all zero + return f'triton_helpers.device_assert_then({x} == 0, "injected assert fail", {x})' + elif bug == "accuracy": + return f"{x} + 1" + elif bug is None: + return ops.maximum(ops.constant(0, torch.int32), x) + else: + raise AssertionError( + f"unrecognized config triton.inject_relu_bug_TESTING_ONLY = {bug!r}" + ) + + @staticmethod + def minimum(a, b): + return f"triton_helpers.minimum({a}, {b})" + + @staticmethod + def maximum(a, b): + return f"triton_helpers.maximum({a}, {b})" + + @staticmethod + def where(a, b, c): + return f"tl.where({a}, {b}, {c})" + + @staticmethod + def inline_asm_elementwise( + *inputs, asm, constraints=None, dtype=torch.float32, is_pure=True, pack=1 + ): + triton_type = triton_compute_type(dtype) + input_refs = ", ".join([str(i) for i in inputs]) + if constraints is None: + constraints = ", ".join(["=r"] + ["r" for _ in inputs]) + return f"tl.inline_asm_elementwise('{asm}', '{constraints}', [{input_refs}], dtype={triton_type}, is_pure={is_pure}, pack={pack})" # noqa: B950 + + @staticmethod + @maybe_upcast_float32() + def cos(x): + return f"tl_math.cos({x})" + + @staticmethod + @maybe_upcast_float32() + def sin(x): + return f"tl_math.sin({x})" + + @classmethod + def index_expr(cls, expr, dtype): + raise NotImplementedError("ops.index_expr not implemented outside a kernel") + + @staticmethod + def masked(mask, body, other): + raise NotImplementedError("ops.masked not implemented outside a kernel") + + @staticmethod + @maybe_upcast_float32() + def lgamma(x): + return f"libdevice.lgamma({x})" + + @staticmethod + @maybe_upcast_float32() + def erf(x): + return f"libdevice.erf({x})" + + @staticmethod + @maybe_upcast_float32() + def cosh(x): + return f"libdevice.cosh({x})" + + @staticmethod + @maybe_upcast_float32() + def sinh(x): + return f"libdevice.sinh({x})" + + @staticmethod + @maybe_upcast_float32() + def acos(x): + return f"libdevice.acos({x})" + + @staticmethod + @maybe_upcast_float32() + def acosh(x): + return f"libdevice.acosh({x})" + + @staticmethod + @maybe_upcast_float32() + def asin(x): + return f"libdevice.asin({x})" + + @staticmethod + @maybe_upcast_float32() + def asinh(x): + return f"libdevice.asinh({x})" + + @staticmethod + @maybe_upcast_float32() + def atan2(x, y): + return f"libdevice.atan2({x}, {y})" + + @staticmethod + @maybe_upcast_float32() + def atan(x): + return f"libdevice.atan({x})" + + @staticmethod + @maybe_upcast_float32() + def atanh(x): + return f"libdevice.atanh({x})" + + @staticmethod + @maybe_upcast_float32() + def copysign(x, y): + return f"libdevice.copysign({x}, {y})" + + @staticmethod + @maybe_upcast_float32() + def erfc(x): + return f"libdevice.erfc({x})" + + @staticmethod + @maybe_upcast_float32() + def erfinv(x): + return f"libdevice.erfinv({x})" + + @staticmethod + @maybe_upcast_float32() + def hypot(x, y): + return f"libdevice.hypot({x}, {y})" + + @staticmethod + @maybe_upcast_float32() + def log10(x): + return f"libdevice.log10({x})" + + @staticmethod + @maybe_upcast_float32() + def log2(x): + return f"libdevice.log2({x})" + + @staticmethod + @maybe_upcast_float32() + def nextafter(x, y): + return f"libdevice.nextafter({x}, {y})" + + @staticmethod + def logical_and(a, b): + return f"{a} & {b}" + + @staticmethod + def logical_not(a): + return f"{a} == 0" + + @staticmethod + def logical_or(a, b): + return f"{a} | {b}" + + @staticmethod + def logical_xor(a, b): + return f"({a} ^ {b})" + + @staticmethod + def bitwise_and(a, 
b): + return f"{a} & {b}" + + @staticmethod + def bitwise_not(a): + return f"~{a}" + + @staticmethod + def bitwise_or(a, b): + return f"{a} | {b}" + + @staticmethod + def bitwise_xor(a, b): + return f"{a} ^ {b}" + + @staticmethod + def bitwise_left_shift(a, b): + return f"{a} << {b}" + + @staticmethod + def bitwise_right_shift(a, b): + return f"{a} >> {b}" + + @staticmethod + def rand(seed, offset): + offset = f"({offset}).to(tl.uint32)" + return f"tl.rand({seed}, {offset})" + + @staticmethod + def randn(seed, offset): + offset = f"({offset}).to(tl.uint32)" + return f"tl.randn({seed}, {offset})" + + @staticmethod + def randint64(seed, offset, low, high): + offset = f"({offset}).to(tl.uint32)" + return f"triton_helpers.randint64({seed}, {offset}, {low}, {high})" + + @staticmethod + def load_seed(name, offset): + raise NotImplementedError("ops.load_seed not implemented outside a kernel") + + @staticmethod + @maybe_upcast_float32() + def rsqrt(x): + return f"libdevice.rsqrt({x})" + + @staticmethod + @maybe_upcast_float32() + def log1p(x): + return f"libdevice.log1p({x})" + + @staticmethod + @maybe_upcast_float32() + def tan(x): + return f"libdevice.tan({x})" + + @staticmethod + @maybe_upcast_float32() + def tanh(x): + return f"libdevice.tanh({x})" + + @staticmethod + @maybe_upcast_float32() + def sigmoid(x): + return f"tl.sigmoid({x})" + + @staticmethod + def signbit(x): + # XX: This is wrong for the value -0.0 in floating point + return ( + f"(libdevice.signbit({x}) != 0) if ({x}).dtype is tl.float32 else {x} < 0" + ) + + @staticmethod + @maybe_upcast_float32() + def fmod(a, b): + return f"libdevice.fmod({a}, {b})" + + @staticmethod + @maybe_upcast_float32() + def pow(a, b): + return f"libdevice.pow({a}, {b})" + + @staticmethod + @maybe_upcast_float32() + def log(x): + return f"tl_math.log({x})" + + @staticmethod + @maybe_upcast_float32(convert_output=False) + def isinf(x): + return f"libdevice.isinf({x}).to(tl.int1)" + + @staticmethod + @maybe_upcast_float32(convert_output=False) + def isnan(x): + return f"libdevice.isnan({x}).to(tl.int1)" + + @staticmethod + @maybe_upcast_float32() + def round(x): + return f"libdevice.nearbyint({x})" + + @staticmethod + @maybe_upcast_float32() + def floor(x): + return f"libdevice.floor({x})" + + @staticmethod + def floordiv(a, b): + # See the comment in lowering.div_mode. a and b are integer type. + # Similar to div_floor_kernel_cuda in pytorch core. + # Notice that // in triton behaves as truncdiv instead of floordiv + quot = f"{a} // {b}" + rem = f"{a} % {b}" + return f"tl.where(({a} < 0) != ({b} < 0), tl.where({rem} != 0, {quot} - 1, {quot}), {quot})" + + @staticmethod + def sign(x): + z = ops.constant(0, torch.int32) + left = ops.to_dtype((ops.lt(z, x)), torch.int8) + right = ops.to_dtype((ops.lt(x, z)), torch.int8) + sub = ops.sub(left, right) + return f"{sub}.to({x}.dtype)" + + @staticmethod + @maybe_upcast_float32() + def trunc(x): + return f"libdevice.trunc({x})" + + @staticmethod + def truncdiv(a, b): + # See the comment in lowering.div_mode. a and b are integer type. 
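+        # (In Triton, -7 // 2 evaluates to -3, truncating toward zero; ops.floordiv
+        # above adds a tl.where correction so the same inputs yield -4.)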
+ # Notice that // in triton behaves as truncdiv instead of floordiv + return f"{a} // {b}" + + @staticmethod + @maybe_upcast_float32() + def ceil(x): + return f"libdevice.ceil({x})" + + +TritonOverrides._initialize_pointwise_overrides("triton") + + +class TritonKernelOverrides(TritonOverrides): + """Map element-wise ops to Triton within a TritonKernel + + Unlike TritonOverrides, these assume the code is going to be inserted into + the body of the main triton kernel and so it may use indexing and mask + variables which are assumed to already be defined in the current scope. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # happens in __init__ unlike _initialize_pointwise_overrides + # because the libdevice registrations are populated during lowerings + self._setup_libdevice_routing() + + @classmethod + @functools.cache + def _setup_libdevice_routing(cls): + """Set up routing to libdevice implementations for fp64 inputs.""" + + from torch._inductor.codegen.common import OpDecompositions + + for fn_name in torch._inductor.utils.op_requires_libdevice_fp64: + assert hasattr(cls, fn_name) + original_impl = getattr(cls, fn_name) + + def decomposition_router(x, _original_impl, _fn_name): + if x.dtype != torch.float64: + return _original_impl(x) + else: + return getattr(OpDecompositions, _fn_name)(x).value + + if fn_name == "sigmoid": + assert hasattr(OpDecompositions, "sigmoid") + fn = functools.partial( + decomposition_router, _original_impl=original_impl, _fn_name=fn_name + ) + fn.__name__ = fn_name # type: ignore[attr-defined] + setattr(cls, fn_name, staticmethod(fn)) + continue + + def dtype_router(x, _original_impl, _fn_name): + if x.dtype == torch.float64: + return f"libdevice.{_fn_name}({x})" + else: + return _original_impl(x) + + fn = functools.partial( + dtype_router, _original_impl=original_impl, _fn_name=fn_name + ) + fn.__name__ = fn_name # type: ignore[attr-defined] + setattr(cls, fn_name, staticmethod(fn)) + + @classmethod + def constant(cls, value, dtype): + # NOTE: Cannot use shape=[] as it's not supported by triton-rocm + # We could use shape=[1] instead but starting with the correct + # ndim avoids extra `tt.expand_dim` ops appearing in the triton IR. + ndim = V.kernel.triton_tensor_ndim() + shape = [1] * ndim + return cls._shaped_constant(value, dtype, shape=shape) + + @classmethod + def index_expr(cls, expr, dtype): + indexing = V.kernel.indexing(expr, block_ptr=False) + assert isinstance(indexing, IndexingOptions) + + # Our sympy expr printing casts to the current kernel index dtype. + # we only respect non int32-int64 dtypes and otherwise use current kernel indexing dtype + index_dtype = V.kernel.get_index_dtype_as_torch_dtype() + dtype = dtype if dtype not in (torch.int32, torch.int64) else index_dtype + + # after we emit this var we cast it to the correct dtype + orig = config.test_configs.runtime_triton_dtype_assert + try: + config.test_configs.runtime_triton_dtype_assert = False + var = V.kernel.cse.generate( + V.kernel.compute, + indexing.index_str, + bounds=get_bounds_index_expr(expr), + dtype=dtype, + ) + finally: + config.test_configs.runtime_triton_dtype_assert = orig + + if dtype not in (torch.int32, torch.int64): + var = V.kernel.cse.generate( + V.kernel.compute, + cls.to_dtype(var, dtype), + dtype=upcast_compute_type(dtype), + ) + else: + # TODO: we are not always consistent in enforcing that the output of the index expr printing + # results in the indexing dtype. 
So if we detect that we have an input which might type promote + # to a dtype other than indexing dtype, add a cast. + # Trying to avoid + dtype = index_dtype + for index_var in expr.free_symbols: + if symbol_is_type(index_var, SymT.TMP): + dtype = torch.promote_types( + dtype, V.kernel.cse.varname_map[index_var.name].dtype + ) + + if dtype != index_dtype: + var = V.kernel.cse.generate( + V.kernel.compute, + cls.to_dtype(var, index_dtype), + dtype=index_dtype, + ) + + var.mask_vars = indexing.mask_vars + return var + + @staticmethod + def masked(mask, body, other): + if mask is not None and torch.version.hip is not None: + mask = V.kernel.cse.generate( + V.kernel.compute, + f"{mask}.to(tl.int1)", + dtype=torch.bool, + ) + + nodes = body.graph.find_nodes(op="output") + assert nodes, "graph for body does not contain an output" + + need_where = False + # If we have a tl.load with a masking operator and no other value + # we can add the mask here and the other value to the tl.load + # operator to save the branching cost. + for node in nodes: + for arg in node.args: + if arg.target != "load" or should_unwrap_unspec_arg(arg.args[1]): + need_where = True + break + + value = None if need_where else other + + with V.kernel.mask_loads(mask, value=value) as new_mask: + result = body() + + if need_where: + # Remove once CSEVariables track the dtype + if result.bounds.is_bool: + other = bool(other) + # Take dtype from result to prevent accidental promotion + other = V.kernel.cse.generate( + V.kernel.compute, + f"tl.full({result}.shape, {constant_repr(other)}, {result}.dtype)", + bounds=ValueRanges.wrap(other), + dtype=result.dtype, + ) + ret = ops.where(new_mask, result, other) + else: + ret = result + + ret.mask_vars.discard(new_mask) + return ret + + @staticmethod + def load_seed(name, offset): + var = V.kernel.args.input(name) + return ( + f"tl.load({var} + {V.kernel.args.seed_offset('load_seed_offset', offset)})" + ) + + @staticmethod + def frexp(x): + cache_key = f"frexp({x})" + if cse_val := V.kernel.cse.try_get(cache_key): + return cse_val + + mantissa = V.kernel.cse.newvar(dtype=x.dtype) + exponent = V.kernel.cse.newvar(dtype=torch.int32) + V.kernel.compute.writeline( + f"{mantissa}, {exponent} = triton_helpers.frexp({x})" + ) + V.kernel.cse.put(cache_key, (mantissa, exponent)) + return (mantissa, exponent) + + +class HelperFunctions: + """An ordered set of helper functions.""" + + _templates_seen: dict[str, str] # Template code to function name + finalized_helpers: list[str] + + def __init__(self) -> None: + self._templates_seen = {} + self.finalized_helpers = [] + + def add(self, template_code: str, *, base_name="_triton_helper_fn") -> str: + """This accepts a function definition with the function name + left as a format specifier e.g. + + @triton.jit + def {name}(arg0, arg1): + return arg0 + arg1 + + We add the templated code to the function set and return the name + assigned to that function. + + """ + existing_name = self._templates_seen.get(template_code) + if existing_name is not None: + # Don't duplicate existing helpers + return existing_name + + name = f"{base_name}{len(self.finalized_helpers)}" + self._templates_seen[template_code] = name + self.finalized_helpers.append(template_code.format(name=name)) + return name + + def __iter__(self): + return iter(self.finalized_helpers) + + def __getitem__(self, idx): + return self.finalized_helpers[idx] + + +@dataclasses.dataclass +class BlockParameters: + """ + Class representing ND block dimensions, for block pointer analysis. 
+ """ + + shape: list[sympy.Expr] = dataclasses.field(default_factory=list) + block_shape: list[sympy.Expr] = dataclasses.field(default_factory=list) + strides: list[sympy.Expr] = dataclasses.field(default_factory=list) + offsets: list[sympy.Expr] = dataclasses.field(default_factory=list) + + def __add__(self, other: BlockParameters) -> BlockParameters: + """ + Concatenates block parameters. + """ + cls = type(self) + a, b = tuple(dataclasses.asdict(x) for x in (self, other)) + return cls(**{key: a[key] + b[key] for key in a}) + + +class CooperativeReductionWorkspaceCache: + """ + The scratch space used for cooperative reductions can be reused + after two reduction loops. This keeps track of what can be reused. + """ + + def __init__(self, args): + self.args = args + self.current_loop = [] + self.prior_loop = [] + self.ready_for_reuse = collections.defaultdict(collections.deque) + self.loop_count = 0 + self.store_count = 0 + + def allocate(self, nbytes: sympy.Expr): + cached = self.ready_for_reuse.get(nbytes) + if cached: + return cached.popleft() + ws_name, ws_offset = self.args.workspace(nbytes, False) + self.current_loop.append((nbytes, ws_name, ws_offset)) + return (ws_name, ws_offset) + + def on_loop_end(self): + # Buffers can be reused after 2 loop ends + for nbytes, ws_name, ws_offset in self.prior_loop: + self.ready_for_reuse[nbytes].append((ws_name, ws_offset)) + self.prior_loop = self.current_loop + self.current_loop = [] + self.loop_count += 1 + + def increment_store_count(self): + prior = self.store_count + self.store_count += 1 + return prior + + +@dataclasses.dataclass +class FixedTritonConfig: + config: dict[str, int] + + def __getitem__(self, item): + return self.config[item] + + def __contains__(self, item): + return item in self.config + + +class TritonCSE(CSE[TritonCSEVariable, Union[str, tuple[str, str]]]): + """ + Subclasses CSE to apply the current load mask to the cache key to avoid CSEing + variables across separate masked blocks. 
+ """ + + def augment_key(self, cache_key: str) -> Union[str, tuple[str, str]]: + if mask := V.kernel._load_mask: + return (cache_key, mask.name) + else: + return cache_key + + +class TritonKernel(SIMDKernel[TritonCSEVariable]): + overrides = TritonKernelOverrides # type: ignore[assignment] + helper_functions: HelperFunctions + kexpr: Callable[[sympy.Expr], str] = texpr + allow_block_ptr = True + + def __init__( + self, + tiling: dict[str, sympy.Expr], + min_elem_per_thread=0, + optimize_mask=True, + fixed_config: Optional[FixedTritonConfig] = None, + **kwargs, + ) -> None: + self.optimize_mask: bool = optimize_mask + self.fixed_config = fixed_config + super().__init__(tiling, **kwargs) + self.cse = TritonCSE(self.newvar_prefix, self.suffix) + self.post_loop_combine: IndentedBuffer = IndentedBuffer() + self.post_loop_store: IndentedBuffer = IndentedBuffer() + self.outside_loop_vars = OrderedSet[Any]() + self.min_elem_per_thread = min_elem_per_thread + self.block_ptr_id = itertools.count() + self.block_ptr_to_buffer = dict[str, str]() + self.helper_functions = HelperFunctions() + self.pointer_advancements: dict[SymT, dict[str, list[sympy.Expr]]] = ( + collections.defaultdict(dict) + ) + self._load_counts: collections.Counter[str] = collections.Counter() + + # A set of autotuning hints to pass as part of triton_meta + self.autotune_hints = OrderedSet[AutotuneHint]() + self.triton_meta: Optional[dict[str, Any]] = None + + if self.inside_reduction: + self.codegen_reduction_numels(self.body) + + if self.cooperative_reduction: + self.init_cooperative_reduction() + + self.codegen_range_tree() + + if self.cooperative_reduction: + self.init_cooperative_reduction_mask() + + def dtype_to_str(self, dtype: torch.dtype) -> str: + return triton_type(dtype) + + def should_use_cooperative_reduction(self) -> bool: + return self.inside_reduction and V.choices.should_use_cooperative_reduction( + self.features + ) + + def init_cooperative_reduction(self): + """One time setup code for cooperative reductions.""" + assert self.cooperative_reduction + + # shift all the grids over since tl.program_id(0) is for rsplit + for tree in self.range_trees: + if tree.grid_dim is not None: + tree.grid_dim += 1 + + sem_count = self.numels["x"] + if self.fixed_config: + sem_count = CeilDiv(sem_count, self.fixed_config["XBLOCK"]) + self.semaphores_name = self.args.semaphores(sem_count) + self.cooperative_reduction_workspace_cache = CooperativeReductionWorkspaceCache( + self.args + ) + self.body.splice( + """\ + RSPLIT_NEXT_POWER_OF_2: tl.constexpr = triton_helpers.constexpr_next_power_of_2(RSPLIT) + RSPLIT_IS_POWER_OF_2: tl.constexpr = RSPLIT == RSPLIT_NEXT_POWER_OF_2 + HAS_RSPLIT: tl.constexpr = RSPLIT > 1 + rsplit_id = tl.program_id(0) + num_rblocks = (rnumel + RBLOCK - 1) // RBLOCK + rsplit_chunk = (num_rblocks + RSPLIT - 1) // RSPLIT * RBLOCK + rsplit_start = rsplit_chunk * rsplit_id + rsplit_end = rsplit_chunk * (rsplit_id + 1) + """, + ) + if any( + not self._has_constant_mask(tree) + for tree in self.range_trees + if tree.is_reduction + ): + self.body.writeline( + "rsplit_end = tl.where(rsplit_end < rnumel, rsplit_end, rnumel)" + ) + + def init_cooperative_reduction_mask(self): + rsplit_arange = "tl.arange(0, RSPLIT_NEXT_POWER_OF_2)" + if not self.no_x_dim: + rsplit_arange = f"{rsplit_arange}[None, :]" + self.body.writeline(f"rsplit_arange = {rsplit_arange}") + + if self._has_constant_xmask(): + self.body.splice( + """\ + if RSPLIT_IS_POWER_OF_2: + rsplit_mask: tl.constexpr = None + else: + rsplit_mask = rsplit_arange < 
RSPLIT + """ + ) + else: + assert not self.no_x_dim + self.body.writeline( + "rsplit_mask = xmask if RSPLIT_IS_POWER_OF_2 else ((rsplit_arange < RSPLIT) & xmask)" + ) + + def codegen_range_tree(self): + for tree in self.range_trees: + # reduction indexing goes inside a loop + if not tree.is_loop: + self.iteration_ranges_codegen_header(tree, self.body) + elif self.inside_reduction: + # workaround for this issue: + # https://gist.github.com/jansel/6527126f781559095c5531f98a4235a7 + self.body.writeline( + f"{tree.prefix}base = {self.iteration_ranges_ranges_code(tree)}" + ) + + if self.inside_reduction: + if any(tree.is_loop for tree in self.range_trees): + # If the kernel contains loops, compute rbase. + rn_bases = self._get_reduction_symbols( + "base", integer=True, nonnegative=True + ) + rbase = self._flatten_reduction_indices(rn_bases) + self.body.splice(f"rbase = {self.index_to_str(rbase)}") + else: + # For looped reductions, indexing is deferred to the innermost loop. + self.codegen_reduction_indices(self.body) + + def need_numel_args(self): + """ + Indicate whether we need provide numel as arguments for the generated + kernel calls in the benchmark. + + Should be true for pointwise/reduction kernels but false for triton + matmul kernels. + """ + return True + + def should_use_persistent_reduction(self) -> bool: + return self.inside_reduction and V.choices.should_use_persistent_reduction( + self.features, self.cooperative_reduction + ) + + def want_no_x_dim(self): + if ( + self.persistent_reduction + and len(self.numels) == self.num_reduction_dims + 1 + ): + if self.fixed_config: + return self.fixed_config["XBLOCK"] == 1 + return V.choices.want_no_x_dim(self.features) + return False + + @property + def assert_function(self) -> str: + return "tl.device_assert" + + def indexing( + self, + index: sympy.Expr, + *, + copy_shape=None, + dense_indexing=False, + override_mask=None, + block_ptr=False, + ): + """ + Compute the index and mask to pass to tl.load() or tl.store() + """ + index = self.prepare_indexing(index) + index_vars = index.free_symbols + has_rindex = False + + mask_vars: OrderedSet[str] = OrderedSet() + for var in sorted(index_vars, key=operator.attrgetter("name")): + assert isinstance(var, sympy.Symbol) + has_rindex = has_rindex or symbol_is_type( + var, TritonSymbols.reduction_types + ) + if override_mask: + pass + elif symbol_is_type(var, SymT.TMP): + # indirect indexing + cse_var = self.cse.varname_map[var.name] + mask_vars.update(cse_var.mask_vars) + elif symbol_is_type( + var, + ( + SymT.UNBACKED_INT, + SymT.SIZE, + SymT.PRECOMPUTED_SIZE, + SymT.INDEX, + SymT.FLOAT, + SymT.UNBACKED_FLOAT, + ), + ): + pass + else: + # var is one of xN, yN, r0_N or r1_N + prefix_matches = [ + prefix_str[symt] + for symt in TritonSymbols.block_types + if symbol_is_type(var, symt) + ] + assert len(prefix_matches) == 1, f"Ambiguous type: {var.name}" + mask_vars.add(f"{prefix_matches[0]}mask") + + need_dense = ( + config.triton.dense_indexing + or dense_indexing + or self._load_mask is not None + ) and index != 0 + + have_dense = True + have_loop_vars = False + dense_mask_vars: OrderedSet[str] = OrderedSet() + + for tree in self.active_range_trees(): + if index_vars.intersection(tree.var_list): + have_loop_vars = True + else: + have_dense = False + dense_mask_vars.add(f"{tree.prefix}mask") + + if ( + block_ptr + and self.allow_block_ptr + and config.triton.use_block_ptr + and not override_mask + and not self._load_mask + and len(mask_vars - dense_mask_vars) == 0 + and not 
self.is_indirect_indexing(index) + and have_loop_vars + # workaround https://github.com/triton-lang/triton/issues/2821 + and self.index_dtype == "tl.int32" + ): + + def match_affine_block( + index: sympy.Expr, range_tree: IterationRangesRoot + ) -> Optional[BlockParameters]: + """ + Matches expressions of the form: + idx = s * xindex + + This implies stride (s,), and shape (XBLOCK,). + """ + stride = BlockPatternMatcher.match_affine_block_expr( + index, range_tree.symbol() + ) + if stride is None: + return None + + return BlockParameters( + shape=[range_tree.numel], + block_shape=[TritonSymbols.get_block_size(range_tree)], + strides=[stride], + offsets=[TritonSymbols.get_block_offset(range_tree)], + ) + + def match_mod_div_block( + index: sympy.Expr, range_tree: IterationRangesRoot + ) -> Optional[BlockParameters]: + """ + Matches higher-dimensional blocks coming from FloorDiv and ModularIndexing. + + Example expression to match: + sN * ((rindex//(d1 * ... * d(N-1)))) + + s1 * ModularIndexing(rindex, 1, d1) + + ... + + s(N-1) * ModularIndexing(rindex, d1 * ... * d(N-2), d(N-1)) + + This iterates over a block of shape (dN, ..., d1) and stride + (sN, ..., s1). (d1,...,d(N-1)) and (s1,...,sN) are + wildcards that we match. + + Note that dN does not appear in the expression, but we solve for it + using range tree numels and the other dims. + """ + + index_var = range_tree.symbol() + + # Bound the possible number of dims. We use the following heuristics: + # - At least one dim for each range tree node. + # - At least one dim for every FloorDiv or ModularIndexing op. + # - At least 2 dims to pattern match. + denom, modulo = sympy.symbols( + "denom modulo", + cls=functools.partial(sympy.Wild, exclude=[index_var]), + ) + num_dims = max( + 2, + len(self.range_tree_nodes), + ( + index.count(FloorDiv(index_var, denom)) + + index.count(ModularIndexing(index_var, denom, modulo)) + ), + ) + + match_result = BlockPatternMatcher.match_mod_div_block_expr( + index, index_var, range_tree.numel, num_dims + ) + if match_result is None: + return None + + ( + dims, + strides, + block_index_exprs, + ) = match_result + slice_numels = BlockPatternMatcher.get_slice_numels(dims) + + # Check for applicable iteration range sizes. + # When mapping a 1D block into an ND one, we need to know that + # the number of elements is not changed. This means the slice numels of + # the ND iteration range must evenly divide the length of the 1D block. + # There are two cases where we can guarantee this: + # 1. Numels are powers of 2. If numel == 2 ** n, and we know XBLOCK == 2 ** m, + # with n and m integers, then either numel is a multiple of XBLOCK, or numel + # is less than XBLOCK. (If numel is less than XBLOCK, we round up to 1 below.) + # 2. Numels are multiples of the maximum possible block size. + sizevars = V.graph.sizevars + max_block = self.max_block(range_tree.prefix) + if any( + not sizevars.statically_known_multiple_of(numel, max_block) + and not sizevars.statically_known_power_of_2(numel) + for numel in slice_numels + ): + return None + + # Compute the ND block shape from the linear block size. + # Use CielDiv to round leading dimensions up to 1. + # Non-leading dimensions are clamped to the size of the iteration range, + # while the leading dimension can exceed this to accommodate a larger + # block size. 
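+                # Worked example (illustrative): for matched dims (d0, d1) = (_, 8), a 1D
+                # block of XBLOCK = 32 elements maps to an ND block shape of (4, 8),
+                # while XBLOCK = 4 maps to (1, 4): the leading dim rounds up to 1 and
+                # trailing dims are clamped to the size of their iteration range.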
+ linear_block_size = TritonSymbols.get_block_size(range_tree) + block_shape: list[sympy.Expr] = [ + CeilDiv(linear_block_size, slice_numels[0]) + ] + [ + sympy.Min(CeilDiv(linear_block_size, numel), dim) + for numel, dim in zip(slice_numels[1:], dims[1:]) + ] + + # Compute block offsets from {xyzr}offset and the matched expressions. + block_offsets: list[sympy.Expr] = [ + sympy_subs( + expr, {index_var: TritonSymbols.get_block_offset(range_tree)} + ) + for expr in block_index_exprs + ] + + return BlockParameters( + shape=dims, + block_shape=block_shape, + strides=strides, + offsets=block_offsets, + ) + + def match_block_pointer_subexpr( + expr: sympy.Expr, range_tree: IterationRangesRoot + ) -> Optional[BlockParameters]: + """ + Match a block indexing subexpression involving a single range tree. + """ + for match_func in ( + match_affine_block, + match_mod_div_block, + ): + match = match_func(expr, range_tree) + if match is not None: + return match + + return None + + def match_block_pointer() -> Optional[BlockPtrOptions]: + index_relative_to_xyr_index = sympy_subs( + index, {v: t.expr for v, t in self.range_tree_nodes.items()} + ) + range_trees = self.active_range_trees() + + # Partition the index into subexpressions pertaining to each range tree. + # For example xindex * 5 + r0_index * 3 is partitioned to + # (xindex * 5, r0_index * 3). + index_subexprs = [ + BlockPatternMatcher.get_subexpr_involving_symbol( + index_relative_to_xyr_index, tree.symbol() + ) + for tree in range_trees + ] + + # Match each range tree's subexpression separately. + range_symbols = OrderedSet(tree.symbol() for tree in range_trees) + block_params = BlockParameters() + for tree, subexpr in zip(range_trees, index_subexprs): + # Reject mixed terms, e.g. xindex * r0_index. + # NB: the zero expression is allowed, for broadcasting. + if len(range_symbols.intersection(subexpr.free_symbols)) > 1: + return None + + # Match the subexpression for this range tree. + params = match_block_pointer_subexpr(subexpr, tree) + if params is None: + return None + block_params += params + + # Collect leftover terms as a constant offset. + offset = index_relative_to_xyr_index - sum(index_subexprs) + + # Form the block pointer. + self.filter_masks(mask_vars) + return BlockPtrOptions.create( + params=block_params, + constant_offset=offset, + range_trees=range_trees, + mask_vars=mask_vars, + get_max_block=self.max_block, + ) + + # Return a block pointer, if indexing matches the pattern. 
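+            # E.g. an index such as xindex * 5 + r0_index * 3 splits into one affine
+            # subexpression per range tree and yields a block pointer with strides
+            # (5, 3); a mixed term like xindex * r0_index fails to match and falls
+            # through to the ordinary indexing path below (illustrative example).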
+ options = match_block_pointer() + if options is not None: + return options + + expand_str = None + index_str = self.index_to_str(index) + if isinstance(index, sympy.Integer): + expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str() + index_str = f"tl.full({expand_str}, {index_str}, tl.int32)" + if self.fixed_config and not self._has_constant_xmask(): + mask_vars = OrderedSet(["xmask"]) + else: + mask_vars = OrderedSet() + if self._load_mask: + mask_vars.add(self._load_mask) + return IndexingOptions(index_str, mask_vars, expand_str, has_rindex, index) + + if need_dense and not have_dense: + expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str() + index_str = f"tl.broadcast_to({index_str}, {expand_str})" + mask_vars = dense_mask_vars + elif not have_loop_vars and copy_shape: + index_str = f"tl.broadcast_to({index_str}, {copy_shape}.shape)" + mask_vars = dense_mask_vars + + if override_mask: + mask_vars = OrderedSet([override_mask]) + + if self._load_mask: + mask_vars.add(self._load_mask) + + self.filter_masks(mask_vars) + + return IndexingOptions(index_str, mask_vars, expand_str, has_rindex, index) + + def codegen_block_ptr( + self, name: str, var: str, indexing: BlockPtrOptions, other="" + ) -> tuple[str, str]: + check = indexing.boundary_check() + if not check: + # workaround https://github.com/triton-lang/triton/issues/2813 + other = "" + elif other: + assert other == ", other=0.0" + other = f", boundary_check={check!r}, padding_option='zero'" + else: + other = f", boundary_check={check!r}" + if ( + self.inside_reduction + and self.range_trees[-1].is_loop + and indexing.has_rindex() + ): + block_ptr = f"block_ptr{next(self.block_ptr_id)}" + self.body.writeline( + DeferredLine( + name, f"{block_ptr} = {indexing.format(var, roffset=False)}" + ) + ) + # Store for later use. If the buffer is removed the below advancements + # are no longer necessary + self.block_ptr_to_buffer[block_ptr] = name + + # Generate block pointer advancements, for later use. + for symt in TritonSymbols.reduction_types: + advance_offsets = indexing.advance_roffset(symt) + + # Ignore identity advancements. + if all( + V.graph.sizevars.statically_known_equals(offset, sympy.Integer(0)) + for offset in advance_offsets + ): + continue + + advancements = self.pointer_advancements[symt] + assert block_ptr not in advancements, ( + "duplicate advancement for pointer '{block_ptr}' at type '{symt}'" + ) + advancements[block_ptr] = advance_offsets + else: + block_ptr = indexing.format(var) + return block_ptr, other + + def codegen_block_ptr_store_line(self, name, indexing, block_ptr, value, other=""): + # Stores require an explicit broadcast. We do this in two phases: + # 1. Broadcast the operand to the final shape of the range trees, e.g. [ZBLOCK, + # YBLOCK, XBLOCK]. This protects against implicit broadcasting from loads. + # 2. In case the block pointer has different dimensionality, broadcast/reshape the + # result to the shape of the pointer. + value = f"tl.broadcast_to({value}, {indexing.final_shape})" + + # These dims no longer need broadcasting. 
+ for idx, (dim, broadcast_dim) in enumerate( + zip(indexing.final_shape, indexing.broadcast_shape) + ): + if V.graph.sizevars.statically_known_equals(dim, broadcast_dim): + indexing.broadcasting_dims[idx] = False + + value = indexing.codegen_broadcast_and_reshape( + value, indexing.final_shape, indexing.block_shape, False + ) + + # workaround https://github.com/triton-lang/triton/issues/2814 + value = f"{value}.to({triton_store_type(V.graph.get_dtype(name))})" + return f"tl.store({block_ptr}, {value}{other})" + + def check_bounds( + self, + expr: sympy.Expr, + size: sympy.Expr, + lower: bool, + upper: bool, + ): + if not (lower or upper): + return + + assert isinstance(expr, sympy.Expr) + indexing = self.indexing(expr, block_ptr=False) + assert isinstance(indexing, IndexingOptions) + + index_str = indexing.index_str + mask_str = indexing.mask_str if indexing.has_mask() else None + size_str = texpr(self.rename_indexing(size)) if upper else None + + # expr is already wrapped + line = self.indirect_assert( + index_str, "0" if lower else None, size_str, mask_str + ) + + buffer = self.get_load_buffer(indexing) + self.cse.generate(buffer, line, assignment=False, dtype=torch.int32) + + def get_load_buffer(self, indexing): + if indexing.has_indirect() or indexing.has_tmpmask(): + # Masked loads must come after the mask is computed + return self.compute + elif ( + self.inside_reduction + and self.range_trees[-1].is_loop + and not indexing.has_rindex() + ): + # can lift a common load outside of reduction loop + # One exception is when this is an indirect_load. + return self.body + else: + return self.loads + + def load(self, name: str, index: sympy.Expr): + var = self.args.input(name) + load_counts = self._load_counts + load_counts[name] += 1 + make_line: Callable[[str], Union[str, DelayReplaceLine]] = identity + indirect_indexing = self.is_indirect_indexing(index) + original_index = index + indexing = self.indexing(index, block_ptr=True) + has_rindex = indexing.has_rindex() + has_tmpmask = indexing.has_tmpmask() + + # Keep the variable in cache if were going to reuse it. Equiv., if any of the following hold + # 1) We are doing broadcasting + # 2) It is a non-coalesced load. The intuition is that if it's + # non-coalesced, we will likely load each element multiple times in + # practice. + # 3) It will be used later and it won't be CSE'd. 
Equiv., if all the following hold + # 3.1) We are in a reduction loop + # 3.2) Its not its last use + # 3.3) This load will not be lifted to the body + # + is_coalesced = any( + i == 1 for i in self.get_strides_of_load(original_index).values() + ) + if self.is_broadcasted(original_index): + ep = ", eviction_policy='evict_last'" + elif not is_coalesced: + ep = ", eviction_policy='evict_last'" + elif self.inside_reduction and self.range_trees[-1].is_loop: + + def decide_later(): + if load_counts[name] > expected_count and ( + has_rindex or indirect_indexing + ): + return "evict_last" + return "evict_first" + + expected_count = load_counts[name] + ep = ", eviction_policy=''" + make_line = functools.partial(DelayReplaceLine, "", decide_later) + else: + ep = "" + + if (has_tmpmask or has_rindex) and indexing.has_mask(): + if self._load_other: + other = f", other={constant_repr(self._load_other)}" + else: + other = ", other=0.0" + else: + other = "" + + """Check if the buffer we're about to load, has + more than one read dependency + NOTE: enabled with env variable TORCHINDUCTOR_SKIP_L1 + """ + has_read_deps = True + if config.triton.skip_l1_cache: + buffer_read_counts = self.features.buffer_read_counts() + has_read_deps = buffer_read_counts[name] > 1 + """Skip L1 cache if we're (pretty?) sure the data is used only once + """ + skip_l1_cache = ( + not self.is_broadcasted(original_index) + and not self.inside_reduction + and not has_read_deps + and is_coalesced # for indirect loads is_coalesced is False? + ) + cachemod = "" + if skip_l1_cache: + cachemod = ", cache_modifier='.cg'" + + append_broadcast = None + dtype = V.graph.get_dtype(name) + + if should_unwrap_unspec_arg(name): + line = var + # unwrapped bf16/fp16 0d tensors are passed in as float32 scalars + # see triton_utils.py:signature_of + if dtype in (torch.float16, torch.bfloat16): + dtype = torch.float32 + + else: + if isinstance(indexing, BlockPtrOptions): + block_ptr, other = self.codegen_block_ptr(name, var, indexing, other) + line = f"tl.load({block_ptr}{other}{ep}{cachemod})" + line = indexing.codegen_broadcast_and_reshape( + line, indexing.block_shape, indexing.final_shape, True + ) + elif isinstance(original_index, sympy.Integer): + line = f"tl.load({var} + ({original_index}))" + append_broadcast = indexing.expand_str + else: + line = f"tl.load({var} + ({indexing.index_str}), {indexing.mask_str}{ep}{other}{cachemod})" + + if ( + dtype in (torch.float16, torch.bfloat16) + and config.triton.codegen_upcast_to_fp32 + ): + line += ".to(tl.float32)" + dtype = torch.float32 + if dtype == torch.bool and torch.version.hip is None: + # Workaround for https://github.com/triton-lang/triton/issues/2151 + # tl.load returns int8 when loading from pointer to int1 + # NOTE: Currently causes hangs on bool UTs for ROCm + line += ".to(tl.int1)" + dtype = torch.bool + + load_buffer = self.get_load_buffer(indexing) + result_var = self.cse.generate(load_buffer, make_line(line), dtype=dtype) + if result_var.use_count > 1: + load_counts[name] -= 1 # don't double count cache hit + assert isinstance(result_var, TritonCSEVariable) + result_var.mask_vars = indexing.mask_vars # type: ignore[assignment] + + if append_broadcast: + line = f"tl.broadcast_to({result_var}, {append_broadcast})" + result_var = self.cse.generate(load_buffer, line, dtype=dtype) + if indexing.mask_vars: + if dtype.is_floating_point: + zero = "0.0" + elif dtype == torch.bool: + zero = "True" + else: + zero = "0" + other_val = ( + constant_repr(self._load_other) if self._load_other else 
zero + ) + line = f"tl.where({indexing.mask_str}, {result_var}, {other_val})" + result_var = self.cse.generate(load_buffer, line, dtype=dtype) + + if not self.inside_reduction or (not indexing.has_rmask() and not has_rindex): + self.outside_loop_vars.add(result_var) + + return result_var + + def store( + self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None + ) -> None: + var = self.args.output(name) + original_index = index + indexing = self.indexing(index, dense_indexing=True, block_ptr=mode is None) + + # Guard against write-after-read corruption in triton. + # See # https://github.com/triton-lang/triton/issues/1615 + # This triton bug means that a load which is broadcasted over multiple + # warps may see the result of a store that happens later in the triton + # program. The workaround is to add a barrier before storing, which + # enforces that all warps have already read the data. + is_inplace = name in self.args.inplace_buffers + is_broadcasted = self.is_broadcasted(original_index) + if is_inplace and is_broadcasted: + self.stores.writeline(DeferredLine(name, "tl.debug_barrier()")) + + if isinstance(indexing, BlockPtrOptions): + block_ptr, other = self.codegen_block_ptr(name, var, indexing) + # block_ptr stores don't do implicit casting + line = self.codegen_block_ptr_store_line( + name, indexing, block_ptr, value, other + ) + elif mode is None: + line = f"tl.store({var} + ({indexing.index_str}), {value}, {indexing.mask_str})" + elif mode == "atomic_add": + line = f"tl.atomic_add({var} + ({indexing.index_str}), {value}, {indexing.mask_str}, sem='relaxed')" + else: + raise NotImplementedError(f"store mode={mode}") + + exit_stack = contextlib.ExitStack() + if not self.inside_reduction and self.cooperative_reduction: + exit_stack.enter_context(self.guard_cooperative_store(name, self.stores)) + + self.stores.writeline(DeferredLine(name, line)) + + if not self.inside_reduction: + self.outside_loop_vars.add(value) + + exit_stack.close() + + def guard_cooperative_store(self, name, buffer): + """ + For cooperative reductions only one thread block should write out the result. + We rotate which thread block does each write for better parallelism + """ + idx = self.cooperative_reduction_workspace_cache.increment_store_count() + buffer.writeline(DeferredLine(name, f"if rsplit_id == ({idx} % RSPLIT):")) + return buffer.indent() + + def bucketize( + self, + values: CSEVariable, + boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr], + boundary_indices: CSEVariable, + indexing_dtype: torch.dtype, + right: bool, + sorter: Optional[tuple[str, sympy.Expr]] = None, + sorter_indices: Optional[CSEVariable] = None, + ) -> CSEVariable: + """ + See [Note: Inductor bucketize op] + """ + + # Triton performance for bucketize_binary_search is much better when the number + # of threads equals the number of elements. + # If we're trying to use a bucketize kernel, we should make sure that an + # autotuning config with num_elements_per_warp=(warp_size) exists. 
+ self.autotune_hints.add(AutotuneHint.ONE_ELEMENT_PER_THREAD) + + boundaries_ptr = self.args.input(boundaries[0]) + boundary_size = self.index_to_str(boundaries[1]) + boundaries_underlying_numel = self.index_to_str(boundaries[2]) + boundary_stride = self.index_to_str(boundaries[3]) + sorter_ptr = self.args.input(sorter[0]) if sorter else "None" + sorter_stride = self.index_to_str(sorter[1]) if sorter else "None" + + if indexing_dtype == torch.int32: + triton_dtype = "tl.int32" + elif indexing_dtype == torch.int64: + triton_dtype = "tl.int64" + else: + raise NotImplementedError( + "Bucketize only supports indexing with int32 and int64" + ) + + result = self.cse.generate( + self.compute, + f"triton_helpers.bucketize_binary_search({values}, " + f"{boundaries_ptr}, {boundary_size}, {boundaries_underlying_numel}, {boundary_stride}, " + f"{boundary_indices}, " + f"{triton_dtype}, " + f"{right}, " + f"{sorter_ptr}, {sorter_stride}, " + f"{sorter_indices}, " + ")", + dtype=indexing_dtype, # type: ignore[attr-defined] + ) + + return result + + def reduction_resize(self, value) -> str: + ndims = self.triton_tensor_ndim() + if ndims == 1: + return f"triton_helpers.promote_to_tensor({value})" + + nreduce = self.num_reduction_dims + sizes = [":"] * (ndims - nreduce) + ["None"] * nreduce + return f"{value}[{', '.join(sizes)}]" + + def reduction_collapse_dims(self, buffer, value: str, dtype: torch.dtype) -> str: + """ + Reshape to RBLOCK, collapsing all reduction dims. + """ + # This is not needed for 1D reductions. + if self.num_reduction_dims == 1: + return value + + target_ndim = self.triton_tensor_ndim() - self.num_reduction_dims + initial_shape = self.dense_size_list() + target_shape = initial_shape[:target_ndim] + ["RBLOCK"] + return str( + self.cse.generate( + buffer, triton_reshape(value, initial_shape, target_shape), dtype=dtype + ) + ) + + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[CSEVariable, tuple[CSEVariable, ...]], + ) -> Union[CSEVariable, tuple[CSEVariable, ...]]: + def maybe_upcast(value: CSEVariable) -> CSEVariable: + # Math reductions in FP16/BF16 are less accurate because the Triton compiler does not + # automatically promote to FP32 for accumulation. Additionally, max/min reductions + # do not support FP16/BF16. We manually promote to FP32 here. + return ( + ops.to_dtype(value, torch.float32) + if value.dtype + in [ + torch.float16, + torch.bfloat16, + ] + else value + ) + + original_dtypes = [val.dtype for val in pytree.tree_leaves(value)] + value = pytree.tree_map(maybe_upcast, value) + if any(x in [torch.float16, torch.bfloat16] for x in original_dtypes): + # Only promote FB16/BF16; do not promote other integer/boolean dtypes + src_dtype = torch.promote_types(src_dtype, torch.float32) + dtype = torch.promote_types(dtype, torch.float32) + + assert self.inside_reduction + masks = OrderedSet(f"{tree.prefix}mask" for tree in self.range_trees) + self.filter_masks(masks) + masks = sorted(masks) + if self._load_mask: + masks.append(self._load_mask) + reduction_range_prefix = self.range_trees[-1].prefix[0] + + # Say we have + # tmp0 = ops.constant(1, torch.int64) + # tmp1 = ops.reduction(torch.int64, torch.int64, "sum", tmp0) + # tmp0 in the triton code is either a scalar, or single-element tensor + # so if we emit tl.sum directly, it will only give 1 instead of RBLOCK * 1 + # To avoid this, we broadcast to the expected shape first. 
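+        # Concretely, the emitted code looks roughly like
+        # tl.broadcast_to(tmp0, [XBLOCK, RBLOCK]) for a 2D kernel, so the later
+        # tl.sum reduces over RBLOCK real elements (illustrative variable and shape).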
+ dense_size_str = self.dense_size_str() + value = self._map_tuple_or_scalar( + lambda v: self.cse.generate( + self.compute, + f"tl.broadcast_to({v}, {dense_size_str})", + dtype=v.dtype, + ), + value, + ) + + dim = self.triton_tensor_ndim() - self.num_reduction_dims + root_op: str + + def final_reduction( + buffer, + value: str, + result_type: Optional[str], + ) -> str: + """ + Helper to generate a reduction call, e.g. tl.sum. + """ + use_helper = reduction_type in ("any", "max", "min", "prod") + module = "triton_helpers" if use_helper else "tl" + + value = self.reduction_collapse_dims(buffer, value, dtype) + if reduction_type in ("max", "min"): + value = self.reduction_resize( + f"{module}.{reduction_type}2({value}, {dim})" + ) + else: + value = self.reduction_resize( + f"{module}.{reduction_type}({value}, {dim})" + ) + + if result_type is not None: + value = f"{value}.to({result_type})" + + return value + + def final_reduction_define( + buffer, + result_var: str, + value: str, + result_type: Optional[str], + ) -> None: + """ + Generate a reduction and assign it to an existing variable. + """ + value = final_reduction(buffer, value, result_type) + buffer.splice(f"{result_var} = {value}") + + def final_argreduce(buffer, result_var, value, index): + value = self.reduction_collapse_dims(buffer, value, dtype) + index = self.reduction_collapse_dims(buffer, index, dtype) + buffer.splice( + f"""\ + {result_var}_val, {result_var}_idx = triton_helpers.{root_op}_with_index({value}, {index}, {dim}) + {result_var} = {self.reduction_resize(f"{result_var}_idx")} + """ + ) + + cache_key = (src_dtype, reduction_type, value) + if cache_key in self.cse.reduction_cache: + return self.cse.reduction_cache[cache_key] + + acc_type = triton_acc_type(src_dtype) + torch_acc_type = upcast_acc_dtype(src_dtype) + result_var: Any = self.cse.newvar(dtype=torch_acc_type) + result_var.mask_vars = OrderedSet( + var for var in masks if not prefix_is_reduction(var[0]) + ) + cond = " & ".join(masks) + + def where_cond(tval, fval): + if not cond: + return tval + return TritonKernelOverrides.where(cond, tval, fval) + + if self.persistent_reduction: + default = ir.Reduction.default_value(reduction_type, src_dtype) + default = self._map_tuple_or_scalar(constant_repr, default) + + def _mask_value(value, default) -> CSEVariable: + return self.cse.generate( + self.compute, where_cond(value, default), dtype=value.dtype + ) + + masked_value: Union[CSEVariable, Sequence[CSEVariable]] + if reduction_type == "online_softmax_reduce": + # Don't generate mask value for online_softmax since we + # will fallback below + pass + elif isinstance(value, tuple): + masked_value = [_mask_value(v, d) for v, d in zip(value, default)] + else: + masked_value = _mask_value(value, default) + + if reduction_type in ("argmax", "argmin"): + accumulator_dtype = V.kernel.get_index_dtype_as_torch_dtype() + accumulator_index = str( + self.cse.generate( + self.compute, + f"tl.broadcast_to({reduction_range_prefix}index, {masked_value}.shape)", + dtype=accumulator_dtype, + ) + ) + root_op = {"argmax": "max", "argmin": "min"}[reduction_type] + final_argreduce( + self.compute, result_var, masked_value, accumulator_index + ) + result_var.dtype = accumulator_dtype + elif reduction_type == "welford_reduce": + if self.cooperative_reduction: + # cooperative reductions require full welford for correctness + result_var = self.welford_reduce( + result_var, reduction_type, value, where_cond, acc_type, dtype + ) + else: + # For persistent reductions, don't bother with + # 
welford's algorithm since it uses more registers, and + # taking two reductions doesn't increase memory usage. + result_var = self.welford_reduce_fallback(dtype, value) + elif reduction_type == "welford_combine": + assert isinstance(masked_value, Sequence) + (mean, m2, weight) = masked_value + result_var = tuple( + self.cse.generate(self.compute, value, dtype=dtype) + for value in self._welford( + self.compute, mean, m2, weight, dim, dtype + ) + ) + elif reduction_type == "online_softmax_reduce": + # All data is loaded to register anyway, no need to do + # online softmax + result_var = self.prepare_softmax_twopass_fallback(dtype, value) + else: + assert isinstance(masked_value, CSEVariable) + result_var = self.cse.generate( + self.compute, + final_reduction(self.compute, str(masked_value), None), + dtype=masked_value.dtype, + ) + else: + accumulator = self.cse.namedvar(f"_{result_var}", dtype=torch_acc_type) + default = ir.Reduction.default_accumulator(reduction_type, src_dtype) + default = self._map_tuple_or_scalar(constant_repr, default) + if not isinstance(default, tuple): + self.body.writeline( + f"{accumulator} = tl.full({self.dense_size_str()}, {default}, {acc_type})" + ) + + if reduction_type in ("argmax", "argmin"): + accumulator_index = f"_{result_var}_index" + index_dtype = self.features.select_index_dtype() + self.body.writeline( + f"{accumulator_index} = tl.full({self.dense_size_str()}, " + f"{torch.iinfo(index_dtype).max}, {self.dtype_to_str(index_dtype)})" + ) + root_op = {"argmax": "max", "argmin": "min"}[reduction_type] + + self.compute.splice( + f"""\ + {accumulator}_next, {accumulator_index}_next = triton_helpers.{root_op}imum_with_index( + {accumulator}, {accumulator_index}, {value}, {reduction_range_prefix}index + ) + {accumulator} = {where_cond(f"{accumulator}_next", accumulator)} + {accumulator_index} = {where_cond(f"{accumulator_index}_next", accumulator_index)} + """ + ) + final_argreduce( + self.post_loop_combine, result_var, accumulator, accumulator_index + ) + elif is_welford_reduction(reduction_type): + result_var = self.welford_reduce( + result_var, reduction_type, value, where_cond, acc_type, dtype + ) + elif reduction_type == "online_softmax_reduce": + accumulator_max = f"_{result_var}_max" + accumulator_sum = f"_{result_var}_sum" + + # setup accumulator + self.body.writeline( + f"{accumulator_max} = tl.full({self.dense_size_str()}, float('-inf'), {acc_type})" + ) + self.body.writeline( + f"{accumulator_sum} = tl.zeros({self.dense_size_str()}, {acc_type})" + ) + + # combine + # Note, we pass config.use_fast_math to the JITFunction + # since a triton kernel can not access a config. + self.compute.splice( + f""" + {accumulator_max}_next, {accumulator_sum}_next = triton_helpers.online_softmax_combine( + {accumulator_max}, {accumulator_sum}, {value}, {config.use_fast_math} + ) + """ + ) + + # mask + self.compute.splice( + f""" + {accumulator_max} = {where_cond(f"{accumulator_max}_next", accumulator_max)} + {accumulator_sum} = {where_cond(f"{accumulator_sum}_next", accumulator_sum)} + """ + ) + + # reduce. 
Similar to the final reduction for coopereative + # reduction + result_max = result_var + result_sum = self.cse.newvar(dtype=dtype) + + result_var = self.online_softmax_reduce_final_reduction( + self.post_loop_combine, + result_max, + result_sum, + accumulator_max, + accumulator_sum, + dim, + dtype, + ) + else: + combine_fn = ir.get_reduction_combine_fn(reduction_type, src_dtype) + updated = combine_fn(accumulator, value) + self.compute.writeline( + f"{accumulator} = {where_cond(updated, accumulator)}" + ) + + if src_dtype == torch.bool: + # This is only really used for aten.any. It changes the + # final reduction of a non-persistent reduction from + # tmp5 = triton_helpers.max(_tmp5, 1)[:, None] + # to + # tmp5 = triton_helpers.max(_tmp5.to(tl.int8), 1)[:, None].to(tl.int1) + # which is needed because tl.reduce doesn't support tl.int1 + accumulator_casted_str = f"{accumulator}.to(tl.int8)" + result_type = triton_compute_type(dtype) + final_reduction_define( + self.post_loop_combine, + str(result_var), + accumulator_casted_str, + result_type, + ) + else: + final_reduction_define( + self.post_loop_combine, str(result_var), str(accumulator), None + ) + + if self.cooperative_reduction: + default = ir.Reduction.default_accumulator(reduction_type, src_dtype) + exit_stack = contextlib.ExitStack() + for buf in (self.post_loop_combine, self.post_loop_store): + # only do cooperative reduction combines if we have more than one thread block + buf.writeline("if HAS_RSPLIT:") + exit_stack.enter_context(buf.indent()) + + if reduction_type in ("argmax", "argmin"): + self.post_loop_combine.writeline( + f"{result_var}_bval = {self.reduction_resize(f'{result_var}_val')}" + ) + peer_val = self.codegen_cooperative_reduction_peer_combine( + f"{result_var}_bval", src_dtype, default + ) + index_dtype = self.features.select_index_dtype() + peer_idx = self.codegen_cooperative_reduction_peer_combine( + result_var, index_dtype, torch.iinfo(index_dtype).max + ) + final_argreduce(self.post_loop_store, result_var, peer_val, peer_idx) + elif is_welford_reduction(reduction_type): + assert reduction_type == "welford_reduce" + result_mean, result_m2, result_weight = result_var + peer_mean = self.codegen_cooperative_reduction_peer_combine( + result_mean, + upcast_acc_dtype(src_dtype), + default[0], # type: ignore[index] + ) + peer_m2 = self.codegen_cooperative_reduction_peer_combine( + result_m2, + upcast_acc_dtype(src_dtype), + default[1], # type: ignore[index] + ) + peer_weight = self.codegen_cooperative_reduction_peer_combine( + result_weight, + upcast_acc_dtype(src_dtype), + default[2], # type: ignore[index] + ) + self.welford_reduce_final_reduction( + self.post_loop_store, + result_mean, + result_m2, + result_weight, + peer_mean, + peer_m2, + peer_weight, + dim, + dtype, + ) + elif reduction_type == "online_softmax_reduce": + result_max, result_sum = result_var + peer_max = self.codegen_cooperative_reduction_peer_combine( + result_max, upcast_acc_dtype(src_dtype), default[0] + ) + peer_sum = self.codegen_cooperative_reduction_peer_combine( + result_sum, upcast_acc_dtype(src_dtype), default[1] + ) + self.online_softmax_reduce_final_reduction( + self.post_loop_store, + result_max, + result_sum, + peer_max, + peer_sum, + dim, + dtype, + ) + else: + peers = self.codegen_cooperative_reduction_peer_combine( + result_var, upcast_acc_dtype(src_dtype), default + ) + final_reduction_define( + self.post_loop_store, str(result_var), peers, None + ) + exit_stack.close() + + self.cse.reduction_cache[cache_key] = result_var + + if 
isinstance(result_var, tuple): + assert all(isinstance(x, TritonCSEVariable) for x in result_var) + self.outside_loop_vars.update(result_var) + + # Match output dtype with input dtype + if reduction_type in ("welford_reduce", "online_softmax_reduce"): + assert len(original_dtypes) == 1 + original_dtypes = len(result_var) * original_dtypes + + assert len(result_var) == len(original_dtypes) + for var, orig_dtype in zip(result_var, original_dtypes): + assert orig_dtype is not None + if var.dtype != orig_dtype: + self.post_loop_combine.writeline( + f"{var} = {var}.to({triton_compute_type(orig_dtype)})" + ) + else: + assert isinstance(result_var, TritonCSEVariable) + self.outside_loop_vars.add(result_var) + + # Match output dtype with input dtype + if result_var.dtype != original_dtypes[0]: + assert original_dtypes[0] is not None + self.post_loop_combine.writeline( + f"{result_var} = {result_var}.to({triton_compute_type(original_dtypes[0])})" + ) + + return result_var + + def _online_softmax_reduce( + self, buffer, accumulator_max, accumulator_sum, dim, dtype: torch.dtype + ): + accumulator_max = self.reduction_collapse_dims(buffer, accumulator_max, dtype) + accumulator_sum = self.reduction_collapse_dims(buffer, accumulator_sum, dtype) + result_max, result_sum = [str(self.cse.newvar(dtype=dtype)) for _ in range(2)] + buffer.splice( + f""" + {result_max}, {result_sum} = triton_helpers.online_softmax_reduce( + {accumulator_max}, {accumulator_sum}, {dim}, {config.use_fast_math}) + {result_max} = {self.reduction_resize(f"{result_max}")} + {result_sum} = {self.reduction_resize(f"{result_sum}")} + """ + ) + + return result_max, result_sum + + def _welford(self, buffer, mean, m2, weight, dim, dtype: torch.dtype): + """ + Helper to codegen triton_helpers.welford. 
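+ Collapses the reduction dims of (mean, m2, weight), reduces them along
+ `dim` with a single triton_helpers.welford call, and returns the three
+ results with the reduction dims restored as size-1 axes.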
+ """ + mean, m2, weight = ( + self.reduction_collapse_dims(buffer, value, dtype) + for value in (mean, m2, weight) + ) + welford = f"triton_helpers.welford({mean}, {m2}, {weight}, {dim})" + welford_results = [str(self.cse.newvar(dtype=dtype)) for _ in range(3)] + buffer.writeline(f"{', '.join(welford_results)} = {welford}") + + result_values = tuple(self.reduction_resize(value) for value in welford_results) + return result_values + + def welford_reduce( + self, result_var, reduction_type, value, where_cond, acc_type, dtype + ): + """Helper to codegen a welford reduction""" + dim = self.triton_tensor_ndim() - self.num_reduction_dims + + accumulator = f"{result_var}_mean" + accumulator_m2 = f"{result_var}_m2" + accumulator_weight = f"{result_var}_weight" + self.body.writeline( + f"{accumulator} = tl.zeros({self.dense_size_str()}, {acc_type})" + ) + self.body.writeline( + f"{accumulator_m2} = tl.zeros({self.dense_size_str()}, {acc_type})" + ) + self.body.writeline( + f"{accumulator_weight} = tl.zeros({self.dense_size_str()}, {acc_type})" + ) + if reduction_type == "welford_combine": + mean, m2, weight = value + self.compute.splice( + f"""\ + {accumulator}_next, {accumulator_m2}_next, {accumulator_weight}_next = triton_helpers.welford_combine( + {accumulator}, {accumulator_m2}, {accumulator_weight}, + {mean}, {m2}, {weight} + ) + """ + ) + else: + assert reduction_type == "welford_reduce" + self.compute.splice( + f"""\ + {accumulator}_next, {accumulator_m2}_next, {accumulator_weight}_next = triton_helpers.welford_reduce( + {value}, {accumulator}, {accumulator_m2}, {accumulator_weight}, roffset == 0 + ) + """ + ) + self.compute.splice( + f"""\ + {accumulator} = {where_cond(f"{accumulator}_next", accumulator)} + {accumulator_m2} = {where_cond(f"{accumulator_m2}_next", accumulator_m2)} + {accumulator_weight} = {where_cond(f"{accumulator_weight}_next", accumulator_weight)} + """ + ) + result_mean = result_var + result_m2 = self.cse.newvar(dtype=dtype) + result_weight = self.cse.newvar(dtype=dtype) + return self.welford_reduce_final_reduction( + self.post_loop_combine, + result_mean, + result_m2, + result_weight, + accumulator, + accumulator_m2, + accumulator_weight, + dim, + dtype, + ) + + def welford_reduce_final_reduction( + self, + buffer, + result_mean, + result_m2, + result_weight, + mean, + m2, + weight, + dim, + dtype, + ): + """Helper to codegen call to triton_helpers.welford""" + values = self._welford(buffer, mean, m2, weight, dim, dtype) + result_exprs = [result_mean, result_m2, result_weight] + for result_expr, value in zip(result_exprs, values): + buffer.splice(f"{result_expr} = {value}") + + return result_mean, result_m2, result_weight + + def online_softmax_reduce_final_reduction( + self, buffer, result_max, result_sum, peer_max, peer_sum, dim, dtype + ): + values = self._online_softmax_reduce(buffer, peer_max, peer_sum, dim, dtype) + result_exprs = [result_max, result_sum] + for result_expr, value in zip(result_exprs, values): + buffer.splice(f"{result_expr} = {value}") + + return result_max, result_sum + + def max_rsplit(self): + if self.fixed_config: + return self.fixed_config["RSPLIT"] + return TRITON_MAX_RSPLIT + + def codegen_cooperative_reduction_peer_combine( + self, result_var, dtype, default_val + ): + """ + Generate code to save a [XBLOCK, RSPLIT] temporary workspace, where each thread block writes a different + column. After the barrier, every thread block loads the completed value so that it can compute the final + value independently. 
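+ The workspace is addressed as xindex * RSPLIT + rsplit_id on the store side
+ and xindex * RSPLIT + rsplit_arange on the load side, i.e. each x element
+ owns one contiguous row of RSPLIT peer values.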
+ """ + xnumel = self.numels["x"] + mask = "xindex < xnumel" if not self._has_constant_xmask() else None + + nbytes = xnumel * dtype.itemsize * self.max_rsplit() + ws_name, ws_offset = self.cooperative_reduction_workspace_cache.allocate(nbytes) + + self.post_loop_combine.splice( + f""" + {result_var}_ws = ({ws_name} + {self.index_to_str(ws_offset)}).to(tl.pointer_type({triton_type(dtype)})) + tl.store({result_var}_ws + (xindex * RSPLIT + rsplit_id), {result_var}, {mask}) + """, + strip=True, + ) + self.post_loop_store.writeline( + f"{result_var}_peers = tl.load({result_var}_ws + (xindex * RSPLIT + rsplit_arange), " + f"rsplit_mask, eviction_policy='evict_first', other=triton_helpers.if_mask(rsplit_mask, {constant_repr(default_val)}))" + ) + return f"{result_var}_peers" + + def store_reduction( + self, + name: str, + index: sympy.Expr, + value: Union[CSEVariable, tuple[CSEVariable, ...]], + ): + assert self.inside_reduction + self.inside_reduction = False + indexing = self.indexing(index, block_ptr=True) + self.inside_reduction = True + var = self.args.output(name) + + exit_stack = contextlib.ExitStack() + if self.cooperative_reduction: + exit_stack.enter_context( + self.guard_cooperative_store(name, self.post_loop_store) + ) + + if isinstance(indexing, BlockPtrOptions): + self.post_loop_store.writeline( + DeferredLine( + name, + self.codegen_block_ptr_store_line( + name, + indexing, + indexing.format(var), + value, + f", boundary_check={indexing.boundary_check()!r}", + ), + ) + ) + else: + assert isinstance(indexing, IndexingOptions) + self.post_loop_store.writeline( + DeferredLine( + name, + f"tl.store({var} + ({indexing.index_str}), {value}, {indexing.mask_str})", + ) + ) + + exit_stack.close() + + def _lift_helper(self, fn, num_args, dtypes: tuple[torch.dtype, ...]) -> str: + # Lift IR function for scan operations into a triton function + # in the global namespace + helper = IndentedBuffer() + helper.writeline("@triton.jit") + cse = CSE() + + args = [ + tuple(cse.namedvar(f"arg{i}_{n}", dtype=dtypes[n]) for n in range(num_args)) + for i in range(2) + ] + signature = ", ".join(str(x) for x in itertools.chain.from_iterable(args)) + helper.writeline(f"def {{name}}({signature}):") + + overrides = TritonOverrides() + + # Build a name that changes depending on fn to workaround a triton bug + # where the combine_fn to reduce and scan is not hashed, and so different + # scan ops may collide in the triton cache. + # This is fixed with the latest triton pin, but not the triton-rocm pin. + helper_name = "_triton_helper_fn" + + from torch._inductor.dtype_propagation import DtypePropagationOpsHandler + + dtype_handler = DtypePropagationOpsHandler() + + class CSEProxy(DefaultHandler): + def _default( + self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any] + ) -> Any: + nonlocal helper_name + helper_name += f"_{name}" + + output_dtype = getattr( + dtype_handler, + name, + )(*args, **kwargs) + + return cse.generate( + helper, + getattr(overrides, name)(*args, **kwargs), + dtype=output_dtype, + ) + + with helper.indent(), V.set_ops_handler(CSEProxy()): + outputs = fn(*args) + outputs = ", ".join(str(output) for output in outputs) + helper.writeline(f"return {outputs}") + + return self.helper_functions.add(helper.getvalue(), base_name=helper_name) + + def scan( + self, + dtypes: tuple[torch.dtype, ...], + combine_fn: Callable[ + [tuple[CSEVariable, ...], tuple[CSEVariable, ...]], tuple[CSEVariable, ...] 
+ ], + values: tuple[CSEVariable, ...], + ) -> tuple[CSEVariable, ...]: + assert self.inside_reduction + assert not self.cooperative_reduction, "TODO" + masks = OrderedSet(f"{tree.prefix}mask" for tree in self.range_trees) + self.filter_masks(masks) + masks = sorted(masks) + assert not self._load_mask, "ops.scan not supported inside ops.masked" + + broadcasted_values = [] + accumulators = [] + + dtypes = tuple(upcast_compute_type(dtype) for dtype in dtypes) + cse_compute = functools.partial(self.cse.generate, self.compute) + combine_helper_fn = self._lift_helper(combine_fn, len(values), dtypes) + dim = self.triton_tensor_ndim() - self.num_reduction_dims + + for value, dtype in zip(values, dtypes): + value_dtype = self.cse.generate( + self.compute, + f"{value}.to({triton_compute_type(dtype)})", + dtype=dtype, + ) + value = self.cse.generate( + self.compute, + f"tl.broadcast_to({value_dtype}, {self.dense_size_str()})", + dtype=dtype, + ) + broadcasted_values.append(value) + + acc_type = triton_acc_type(dtype) + + if not self.persistent_reduction: + accumulator = self.cse.newvar(dtype=dtype) + reduced_size = self.dense_size_list() + reduced_size[-1] = "1" + reduced_size = f"[{', '.join(reduced_size)}]" + + default = "float('nan')" if dtype.is_floating_point else "-1" + self.body.writeline( + f"{accumulator} = tl.full({reduced_size}, {default}, {acc_type})" + ) + + accumulators.append(accumulator) + + def csv(values): + return " ".join(f"{value}," for value in values) + + def cse_multiple(line, values, masks, dtypes): + n = len(values) + cache_keys = [f"{line}, {i}, {masks}" for i in range(n)] + if all(self.cse.contains(cache_key) for cache_key in cache_keys): + return [self.cse.get(cache_key) for cache_key in cache_keys] + result_vars = [self.cse.newvar(dtype=_dtype) for _dtype in dtypes] + self.compute.writeline( + f"{csv(result_vars)} = {line}", + ) + for result_var, cache_key in zip(result_vars, cache_keys): + if masks: + result_var.mask_vars = masks # type: ignore[attr-defined] + self.cse.put(cache_key, result_var) + return tuple(result_vars) + + partial_scan_vars = cse_multiple( + f"tl.associative_scan(({csv(broadcasted_values)}), {dim}, {combine_helper_fn})", + values, + masks, + dtypes, + ) + + if not self.persistent_reduction: + # tl.reduce doesn't work for non-commutative operators, so instead + # of repeating the scan op as a reduction, we use sum to select the + # last scan value + partial_reduce_vars = [ + cse_compute( + f"triton_helpers.select_one(({partial_scan_var}), rbase == (RBLOCK - 1), dim=-1, keep_dims=True)", + dtype=upcast_compute_type(partial_scan_var.dtype), + ) + for partial_scan_var in partial_scan_vars + ] + accs_next = combine_fn(tuple(accumulators), tuple(partial_reduce_vars)) + full_scan_vars = combine_fn(tuple(accumulators), partial_scan_vars) + result_vars = [ + cse_compute( + f"tl.where(roffset > 0, {full_scan}, {partial_scan})", + dtype=partial_scan.dtype, + ) + for full_scan, partial_scan in zip(full_scan_vars, partial_scan_vars) + ] + for acc_next, accumulator, partial_reduce in zip( + accs_next, accumulators, partial_reduce_vars + ): + self.compute.writeline( + f"{accumulator} = tl.where(roffset > 0, {acc_next}, {partial_reduce})" + ) + else: + result_vars = partial_scan_vars + + for result_var in result_vars: + assert isinstance(result_var, TritonCSEVariable) + result_var.mask_vars = OrderedSet(masks) + + return tuple(result_vars) + + def sort( + self, + dtypes: tuple[torch.dtype, ...], + values: tuple[CSEVariable, ...], + stable: bool, + descending: 
bool, + ) -> tuple[CSEVariable, ...]: + assert self.inside_reduction + assert not self.cooperative_reduction, "TODO" + masks = OrderedSet(f"{tree.prefix}mask" for tree in self.range_trees) + self.filter_masks(masks) + masks = sorted(masks) + assert not self._load_mask, "ops.sort not supported inside ops.masked" + assert self.persistent_reduction, ( + "ops.sort is only supported in persistent reductions" + ) + + cse_compute = functools.partial(self.cse.generate, self.compute) + dim = self.triton_tensor_ndim() - self.num_reduction_dims + + dtypes = tuple(upcast_compute_type(dtype) for dtype in dtypes) + assert len(dtypes) == len(values) + broadcasted_values = [ + cse_compute( + f"tl.broadcast_to({value}, {self.dense_size_str()})", dtype=dtypes[i] + ) + for i, value in enumerate(values) + ] + + def csv(values): + return " ".join(f"{value}," for value in values) + + def cse_multiple(line, n, masks, dtypes): + cache_keys = [f"{line}, {i}, {masks}" for i in range(n)] + if all(self.cse.contains(cache_key) for cache_key in cache_keys): + return [self.cse.get(cache_key) for cache_key in cache_keys] + result_vars = [self.cse.newvar(dtype=dtypes[i]) for i in range(n)] # type: ignore[attr-defined] + self.compute.writeline( + f"{csv(result_vars)} = {line}", + ) + for result_var, cache_key in zip(result_vars, cache_keys): + if masks: + result_var.mask_vars = masks # type: ignore[attr-defined] + self.cse.put(cache_key, result_var) + return tuple(result_vars) + + assert self.range_trees[-1].is_reduction + rnumel = "None" if self._has_constant_mask(self.range_trees[-1]) else "rnumel" + + if len(values) == 2: + line = ( + f"triton_helpers.sort_with_index({broadcasted_values[0]}, {broadcasted_values[1]}," + f" {rnumel}, {dim}, stable={stable}, descending={descending})" + ) + result_vars = cse_multiple(line, len(values), masks, dtypes) + else: + raise AssertionError("Unhandled sort") + + for result_var, input_var in zip(result_vars, values): + result_var.mask_vars = masks # type: ignore[attr-defined] + result_var.bounds = input_var.bounds + + return tuple(result_vars) + + def codegen_body(self): + """ + Concat output code from index_code, loads, compute, stores, + suffix into self.body. + + For pointwise kernels, this is called just once at the end. + + For reduction kernels, this generates a loop over the reduction + axis. + """ + if not ( + self.indexing_code + or self.loads + or self.stores + or self.compute + or self.post_loop_combine + or self.post_loop_store + ): + return + + loop_trees = [tree for tree in self.range_trees if tree.is_loop] + if self.inside_reduction and len(loop_trees) > 0: + # Write the loop headers. + for level, tree in enumerate(loop_trees): + with self.body.indent(offset=level): + prefix = tree.prefix + loop_start = "rsplit_start" if self.cooperative_reduction else "0" + loop_end = ( + "rsplit_end" if self.cooperative_reduction else f"{prefix}numel" + ) + self.body.writeline( + f"for {prefix}offset in range({loop_start}, {loop_end}, {prefix.upper()}BLOCK):" + ) + with self.body.indent(offset=level + 1): + self.iteration_ranges_codegen_header(tree, self.body) + + # The innermost loop performs the reduction. + with self.body.indent(offset=len(loop_trees)): + self.codegen_reduction_indices(self.body) + self.body.splice(self.indexing_code) + self.body.splice(self.loads) + self.body.splice(self.compute) + self.body.splice(self.stores) + + # Write loop suffixes. 
+ for level, tree in reversed([*enumerate(loop_trees)]): + with self.body.indent(offset=level + 1): + # Advance pointers at the end of each loop. + for block_ptr, advancement in self.pointer_advancements[ + tree.symt + ].items(): + # Subtract any advancements made in the previous loop level. + if level < len(loop_trees) - 1: + prev_tree = loop_trees[level + 1] + prev_advancement = self.pointer_advancements[ + prev_tree.symt + ][block_ptr] + prev_block = TritonSymbols.get_block_size(prev_tree) + prev_num_iter = CeilDiv(prev_tree.numel, prev_block) + advancement = [ + cur - prev * prev_num_iter + for cur, prev in zip(advancement, prev_advancement) + ] + + self.body.writeline( + DeferredLine( + self.block_ptr_to_buffer[block_ptr], + f"{block_ptr} = tl.advance({block_ptr}, {V.kernel.index_to_str(advancement)})", + ) + ) + + # Invalidate any cache entries that came from inside the loop. + self.cse.invalidate(self.outside_loop_vars) + tree.cache_clear() + else: + self.body.splice(self.indexing_code) + self.body.splice(self.loads) + self.body.splice(self.compute) + self.body.splice(self.stores) + self.body.splice(self.post_loop_combine) + if self.cooperative_reduction and ( + self.post_loop_combine or self.post_loop_store + ): + sem_ptr = f"{self.semaphores_name} + tl.program_id(1)" + self.body.splice( + f""" + if HAS_RSPLIT: + triton_helpers.x_grid_barrier({sem_ptr}) + """, + strip=True, + ) + self.cooperative_reduction_workspace_cache.on_loop_end() + self.body.splice(self.post_loop_store) + self.indexing_code.clear() + self.loads.clear() + self.compute.clear() + self.stores.clear() + self.post_loop_combine.clear() + self.post_loop_store.clear() + + def kernel_benchmark_extra_args(self) -> list[str]: + args = [] + if self.need_numel_args(): + numel_args: list[sympy.Expr] = [] + self.add_numel_to_call_args("", numel_args, []) + for arg in numel_args: + if isinstance(arg, int): + args.append(str(arg)) + elif isinstance(arg, SymbolicCallArg): + args.append(str(V.graph.sizevars.size_hint(arg.inner_expr))) + elif isinstance(arg, sympy.Expr): + args.append(str(V.graph.sizevars.size_hint(arg))) + else: + raise ValueError(f"Unsupported numel argument type: {type(arg)}") + return args + + def codegen_kernel_benchmark(self, num_gb): + result = IndentedBuffer() + _argdefs, call_args, signature, _ = self.args.python_argdefs() + + result.writelines(["", "", "def get_args():"]) + with result.indent(): + name_cnt = itertools.count() + var_names = [] + for arg_name, arg_sig in zip(call_args, signature): + var_name = f"arg_{next(name_cnt)}" + buf = V.graph.try_get_buffer(arg_name) + if buf: + result.writeline( + f"{var_name} = rand_strided({V.graph.sizevars.size_hints(buf.get_size())}, {V.graph.sizevars.size_hints(buf.get_stride())}, device='{buf.get_device()}', dtype={buf.get_dtype()})" # noqa: B950 line too long + ) + elif arg_name in V.graph.constants: + # note that random seed is put in V.graph.constants + const_tensor = V.graph.constants[arg_name] + result.writeline( + f"{var_name} = rand_strided({V.graph.sizevars.size_hints(const_tensor.size())}, {V.graph.sizevars.size_hints(const_tensor.stride())}, device='{const_tensor.device}', dtype={const_tensor.dtype})" # type: ignore[arg-type] # noqa: B950 line too long + ) + elif isinstance(arg_sig, SizeArg): + symval_hint = V.graph.sizevars.size_hint(arg_sig.expr) + + # Force the seed_offset to be 0 so calls to the same kernel + # using different seed offset will have the same benchmark harness. + # We can dedup kernel definitions in this case. 
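+ # (The seed offset only affects which random values an RNG op produces,
+ # not how long the kernel takes to run, so pinning it to 0 is harmless
+ # for benchmarking.)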
+ if "seed_offset" in arg_sig.name: + symval_hint = 0 + result.writeline(f"{var_name} = {symval_hint}") + elif isinstance(arg_sig, WorkspaceArg): + device = V.graph.get_current_device_or_throw() + count = V.graph.sizevars.size_hint(arg_sig.count) + result.writeline( + f"{var_name} = torch.zeros({count}, device='{device}', dtype={arg_sig.dtype})" + ) + else: + raise KeyError( + f"Don't find the buffer or const tensor for {arg_name}" + ) + var_names.append(var_name) + var_names.extend(self.kernel_benchmark_extra_args()) + result.writeline(f"return {', '.join(var_names)},") + + result.writelines(["\n", "\n", "def call(args):"]) + current_device = V.graph.get_current_device_or_throw() + index = current_device.index + with result.indent(): + result.writeline(f"with {V.graph.device_ops.device_guard(index)}:") + with result.indent(): + result.writeline( + V.graph.device_ops.set_device(index) + ) # no-op to ensure context + stream_name = f"stream{index}" + result.writeline(f"{stream_name} = get_raw_stream({index})") + result.writeline( + f"{str(Placeholder.KERNEL_NAME)}.run(*args, stream={stream_name})" + ) + + # benchmark all configs + result.writelines(["\n", "\n", "def benchmark_all_configs(args):"]) + with result.indent(): + result.writeline(f"with {V.graph.device_ops.device_guard(index)}:") + with result.indent(): + result.writeline( + V.graph.device_ops.set_device(index) + ) # no-op to ensure context + result.writeline( + f"return {str(Placeholder.KERNEL_NAME)}.benchmark_all_configs(*args)" + ) + + result.writelines(["\n", "\n", "if __name__ == '__main__':"]) + with result.indent(): + result.writeline( + "from torch._inductor.runtime.benchmarking import benchmarker" + ) + result.writeline("") + + result.writeline("args = get_args()") + result.writeline( + "ms = benchmarker.benchmark_gpu(lambda: call(args), rep=40)" + ) + result.writeline(f"num_gb = {num_gb}") + result.writeline("gb_per_s = num_gb / (ms / 1e3)") + result.writeline( + 'print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")' + ) + + return result + + def imports_for_benchmark_kernel(self): + return textwrap.dedent( + """ + from torch._dynamo.testing import rand_strided + {} + import torch + """.format(V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")) + ) + + def _get_heuristic(self): + if self.fixed_config: + return "fixed_config" + elif self.cooperative_reduction: + return "cooperative_reduction" + elif self.persistent_reduction: + assert self.inside_reduction + return "persistent_reduction" + elif self.inside_reduction: + return "reduction" + return "pointwise" + + @staticmethod + def inductor_meta_common(): + inductor_meta = { + "backend_hash": torch.utils._triton.triton_hash_with_backend(), + "are_deterministic_algorithms_enabled": torch.are_deterministic_algorithms_enabled(), + "assert_indirect_indexing": config.assert_indirect_indexing, + "autotune_local_cache": config.autotune_local_cache, + "autotune_pointwise": config.triton.autotune_pointwise, + "autotune_remote_cache": config.autotune_remote_cache, + "force_disable_caches": config.force_disable_caches, + "dynamic_scale_rblock": config.dynamic_scale_rblock, + "max_autotune": config.max_autotune, + "max_autotune_pointwise": config.max_autotune_pointwise, + "min_split_scan_rblock": config.triton.min_split_scan_rblock, + "spill_threshold": config.triton.spill_threshold, + "store_cubin": config.triton.store_cubin, + } + if torch.version.hip is not None: + inductor_meta["is_hip"] = True + if config.is_fbcode(): + inductor_meta["is_fbcode"] = True + if 
config.profile_bandwidth: + inductor_meta["profile_bandwidth"] = config.profile_bandwidth + inductor_meta["profile_bandwidth_regex"] = config.profile_bandwidth_regex + inductor_meta["profile_bandwidth_output"] = config.profile_bandwidth_output + inductor_meta["profile_bandwidth_with_do_bench_using_profiling"] = ( + config.profile_bandwidth_with_do_bench_using_profiling + ) + if config.coordinate_descent_tuning: + inductor_meta["coordinate_descent_tuning"] = ( + config.coordinate_descent_tuning + ) + inductor_meta["coordinate_descent_search_radius"] = ( + config.coordinate_descent_search_radius + ) + inductor_meta["coordinate_descent_check_all_directions"] = ( + config.coordinate_descent_check_all_directions + ) + return inductor_meta + + def codegen_kernel(self, name=None): + code = IndentedBuffer() + + size_hints = {} + for prefix, numel in self.numels.items(): + if prefix_is_reduction(prefix) and not self.inside_reduction: + continue + + numel_hint = V.graph.sizevars.symbolic_hint(numel) + if not isinstance(numel_hint, (int, sympy.Integer)): + # This default heuristic hint was picked carefully: it is + # large, to ensure that we don't shrink the block size (since + # if you don't have many elements, it'd be wasteful to pick a + # large block size). Since we don't know how many elements we + # might have, we should be OK with some inefficiency to make + # sure we handle the large case well. 8192 is the largest + # block size we support, so we pick that. + # + # If we have a better hint for unbacked SymInts (e.g., because + # a user told us, or we are tracking upper bounds) we could + # use that here. + size_hint = 8192 + else: + size_hint = next_power_of_2(int(numel_hint)) + size_hints[prefix] = size_hint + + if name is None: + code.splice(gen_common_triton_imports()) + device_type = V.graph.get_current_device_or_throw().type + if device_type == "cpu": + code.splice("triton_helpers.set_driver_to_cpu()") + else: + code.splice("triton_helpers.set_driver_to_gpu()") + + if config.benchmark_kernel: + code.splice(self.imports_for_benchmark_kernel()) + + argdefs, _, signature, _ = self.args.python_argdefs() + # maps actual expression to SizeArg if it is in sizevars replacements + for i, arg in enumerate(signature): + if isinstance(arg, SizeArg): + # mypy is unhappy about the sympy.Expr + # type for the key of the dict below + symbol = cast(sympy.Symbol, arg.expr) + if symbol in V.graph.sizevars.inv_precomputed_replacements: + signature[i] = SizeArg( + arg.name, V.graph.sizevars.inv_precomputed_replacements[symbol] + ) + + mutated_args: OrderedSet[str] = OrderedSet() + for mutation in self.mutations: + if mutation in self.args.input_buffers: + mutated_args.add(self.args.input_buffers[mutation]) + if ( + mutation in self.args.inplace_buffers + and mutation not in V.graph.removed_buffers + and mutation not in self.removed_buffers + ): + mutated_args.add( + cast(InplacedBuffer, self.args.inplace_buffers[mutation]).inner_name + ) + if mutation in self.args.output_buffers: + mutation_arg = self.args.output_buffers[mutation] + assert not isinstance(mutation_arg, RemovedArg) + mutated_args.add(mutation_arg) + + # Note: [Workspace Mutation] + # workspace arguments are mutated, but are not marked as mutations in self.mutations + # because their buffers are added during codegen, and aren't tracked during + # lowering/scheduling. So we add them as mutated_args explicitly below. 
+ # + # In the logic below, we only mark the workspaces a mutated if they are marked with + # zero_fill: that's because, if we don't expect the buffer to be pre-filled with + # zeros, then, although we still mutate the data, we don't care about those + # mutations because we don't make any assumptions about the contents of the + # workspace buffer. Similarly, ZERO_PER_GRAPH requires the kernel to return + # the buffer back to its original state. + for argname, arg in zip(argdefs, signature): + if ( + isinstance(arg, WorkspaceArg) + and arg.zero_mode == WorkspaceZeroMode.ZERO_ON_CALL + ): + mutated_args.add(argname.name) + + mutated_args = sorted(mutated_args) + + for tree in self.active_range_trees(): + sizearg = SizeArg(f"{tree.prefix}numel", tree.numel) + signature.append(sizearg) + argdefs.append(ArgName(sizearg.name)) + # constexpr version causes issues, see + # https://github.com/pytorch/torchdynamo/pull/1362 + # triton_meta["constants"][len(argdefs)] = V.graph.sizevars.size_hint( + # tree.numel + # ) + # argdefs.append(f"{tree.prefix}numel: tl.constexpr") + + def add_constexpr_arg(arg_name): + # new versions (but not old versions) of Triton need constexprs included in the signature + if triton_version_uses_attrs_dict(): + signature.append(ConstexprArg(arg_name)) + argdefs.append(ArgName(arg_name, is_constexpr=True)) + + for tree in self.range_trees: + if tree.is_reduction and self.persistent_reduction: + # Rn_BLOCK for persistent_reduction is defined in codegen_static_numels + continue + if tree.tensor_dim is None: + continue + + add_constexpr_arg(f"{tree.prefix.upper()}BLOCK") + + if self.cooperative_reduction: + add_constexpr_arg("RSPLIT") + + triton_meta_signature = signature_to_meta( + signature, size_dtype=self.index_dtype, argdefs=argdefs + ) + triton_meta: dict[str, Any] = { + "signature": triton_meta_signature, + "device": DeviceProperties.create(V.graph.get_current_device_or_throw()), + "constants": {}, + } + + # Skip memory optimization for forward of the training loop where we expect + # every new node will increase the peak memory and our greedy approach would + # introduce a lot of unnecessary cpu copies. + optimize_mem = V.graph.is_inference or V.graph.is_backward + + inductor_meta = { + # Triton will not accept an OrderedSet for autotune_hints + "grid_type": self._get_grid_type().__name__, + "autotune_hints": set(self.autotune_hints), # noqa: set_linter + "kernel_name": str(Placeholder.DESCRIPTIVE_NAME), + "mutated_arg_names": mutated_args, + "optimize_mem": optimize_mem, + "no_x_dim": self.no_x_dim, + "num_load": self.num_load, + "num_reduction": self.num_reduction, + **self.inductor_meta_common(), + } + if self.tiling_scores: + inductor_meta["tiling_scores"] = self.tiling_scores + + if self.cooperative_reduction: + inductor_meta["persistent_reduction"] = self.persistent_reduction + + num_gb = None + if config.benchmark_kernel or config.profile_bandwidth: + num_gb = self.estimate_kernel_num_bytes() / 1e9 + inductor_meta["kernel_num_gb"] = num_gb + + triton_meta["configs"] = [config_of(signature)] + + # Triton compiler includes equal_to_1 args into constants even + # when they are not constexpr. otherwise there may be a segfault + # during launching the Inductor-compiled Triton kernel. 
+ # https://github.com/pytorch/pytorch/issues/120478#issuecomment-1962822307 + # https://github.com/triton-lang/triton/blob/231efe9ed2d200be0f69a07c298e4342b08efe3d/python/triton/runtime/jit.py#L384 + for arg_num in equal_1_arg_indices(signature): # type: ignore[index] + triton_meta["constants"][signature[arg_num].name] = 1 # type: ignore[index,union-attr] + + self.triton_meta = triton_meta + + self.codegen_body() + + for helper in self.helper_functions: + code.writeline("") + code.splice(helper) + + if self.fixed_config: + heuristics_line = f""" + @triton_heuristics.{self._get_heuristic()}( + config={self.fixed_config.config!r}, + filename=__file__, + triton_meta={triton_meta!r}, + inductor_meta={inductor_meta!r} + ) + @triton.jit + """ + elif self.inside_reduction: + reduction_hint = self.features.get_reduction_hint() + heuristics_line = f""" + @triton_heuristics.{self._get_heuristic()}( + size_hints={size_hints!r}, + reduction_hint={reduction_hint}, + filename=__file__, + triton_meta={triton_meta!r}, + inductor_meta={inductor_meta!r} + ) + @triton.jit + """ + else: + tile_hint = "" + if len(size_hints) == 2: + if ( + len(non_constexpr_signature(signature)) == 4 + ): # input, output and 2 args + tile_hint = "tile_hint=TileHint.SQUARE," + else: + tile_hint = "tile_hint=TileHint.DEFAULT," + heuristics_line = f""" + @triton_heuristics.{self._get_heuristic()}( + size_hints={size_hints!r}, {tile_hint} + filename=__file__, + triton_meta={triton_meta!r}, + inductor_meta={inductor_meta!r}, + min_elem_per_thread={self.min_elem_per_thread} + ) + @triton.jit + """ + code.splice(heuristics_line) + code.writeline( + f"def {name or str(Placeholder.KERNEL_NAME)}({', '.join(x.full_name() for x in argdefs)}):" + ) + with code.indent(): + self.codegen_static_numels(code) + for old, new in self.args.aliases(): + code.writeline(f"{old} = {new}") + code.splice(self.body) + + if config.benchmark_kernel: + code.splice(self.codegen_kernel_benchmark(num_gb)) + + return code.getvalue() + + @staticmethod + def _get_persistent_RBLOCK(rnumel): + rnumel = V.graph.sizevars.simplify(rnumel) + if isinstance(rnumel, (sympy.Integer, int)): + val = int(rnumel) + val = next_power_of_2(val) + else: + val = 128 + while not V.graph.sizevars.statically_known_leq(rnumel, val): + if val > 16 * 1024: + raise ValueError(f"Failed to find static RBLOCK for {rnumel}") + val *= 2 + return val + + @staticmethod + def has_persistent_RBLOCK(rnumel): + try: + TritonKernel._get_persistent_RBLOCK(rnumel) + return True + except ValueError: + return False + + def codegen_static_numels(self, code): + """ + We get a small speedup from hard coding numels if they are static. + + This code stomps on the passed-in values by writing an constant to the top of the kernel. + + In a kernel like: + def KERNEL_NAME(in_ptr0, in_ptr1, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + + We would add + xnumel = 4096 + r0_numel = 768 + + After the signature, before the kernel code, if we decided to make these static. As its hardcoded, it becomes + a better signal to triton on how to unroll and do some static indexing. So, it's not so much that downstream + knows that its a static numel, as that you just plop a constant into the kernel. 
+ """ + + def is_static_integer(expr: sympy.Expr) -> bool: + return isinstance(expr, (sympy.Integer, int)) + + for tree in self.range_trees: + if not tree.is_reduction or self.inside_reduction: + simplified_tree_numel = V.graph.sizevars.simplify(tree.numel) + if is_static_integer(simplified_tree_numel): + code.writeline(f"{tree.prefix}numel = {int(simplified_tree_numel)}") + + if tree.is_reduction and self.persistent_reduction: + if self.cooperative_reduction: + numel = self.kexpr(self.rename_indexing(tree.numel)) + val = f"triton_helpers.constexpr_next_power_of_2(({numel} + RSPLIT - 1) // RSPLIT)" + else: + val = self._get_persistent_RBLOCK(tree.numel) + code.writeline(f"{tree.prefix.upper()}BLOCK: tl.constexpr = {val}") + + if tree.prefix == "x" and self.no_x_dim: + code.writeline("XBLOCK: tl.constexpr = 1") + + def _get_grid_type(self) -> type[triton_heuristics.GridExpr]: + n = sum([int(not tree.is_reduction) for tree in self.range_trees]) + if self.cooperative_reduction: + assert n == 1 + return triton_heuristics.CooperativeReductionGrid + elif n == 1: + return triton_heuristics.Grid1D + elif n == 2: + if any(map(self.needs_yz_grid_overflow, self.range_trees)): + return triton_heuristics.Grid2DWithYZOverflow + return triton_heuristics.Grid2D + elif n == 3: + return triton_heuristics.Grid3D + raise ValueError(f"Unsupported number of dimensions: {n}") + + def add_numel_to_call_args(self, name, call_args, arg_types): + # TODO(jansel): if there are constants, we shouldn't bother passing them as args + for tree in self.range_trees: + if isinstance(tree.numel, (sympy.Integer, sympy.Symbol)): + expr = tree.numel + else: + expr = V.graph.wrapper_code.generate_numel_expr(name, tree) + + if not tree.is_reduction or self.inside_reduction: + call_args.append(expr) + arg_types.append(type(expr)) + + def call_kernel(self, name: str, node: Optional[IRNode] = None): + wrapper = V.graph.wrapper_code + wrapper.write_triton_header_once() + _, call_args, _, arg_types = self.args.python_argdefs() + self.add_numel_to_call_args(name, call_args, arg_types) + + for ws in self.args.workspace_args: + wrapper.generate_workspace_allocation(ws) + + wrapper.generate_kernel_call( + name, + call_args, + triton=True, + arg_types=arg_types, + triton_meta=self.triton_meta, + ) + + for ws in reversed(self.args.workspace_args): + wrapper.generate_workspace_deallocation(ws) + + def codegen_nan_check(self) -> None: + wrapper = V.graph.wrapper_code + _, call_args, arg_signatures, _ = self.args.python_argdefs() + for arg, arg_signature in zip(call_args, arg_signatures): + if isinstance(arg_signature, TensorArg): + if V.graph.cpp_wrapper: + wrapper.writeline( + f'AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_check_inf_and_nan("{arg}", {arg}));' + ) + else: + line = f"assert not {arg}.isnan().any().item()" + wrapper.writeline(line) + line = f"assert not {arg}.isinf().any().item()" + wrapper.writeline(line) + + def create_cse_var(self, *args, **kwargs) -> TritonCSEVariable: + return TritonCSEVariable(*args, **kwargs) + + def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry): + line = f"{entry.name} = {self.kexpr(self.rename_indexing(entry.expr))}" + if entry.root.is_loop: + self.indexing_code.writeline(line) + else: + # lift non-reduction stores outside loop + self.body.writeline(line) + + def iteration_ranges_ranges_code(self, entry: IterationRangesRoot) -> str: + assert entry.tensor_dim is not None + size = self.indexing_size_str(entry.tensor_dim) + index_dtype = self.index_dtype + suffix = f".to({index_dtype})" if 
index_dtype != "tl.int32" else "" + if ( + self.cooperative_reduction + and self.persistent_reduction + and entry.is_reduction + ): + suffix = f"{suffix} + rsplit_start" + return f"tl.arange(0, {entry.prefix.upper()}BLOCK){size}{suffix}" + + def iteration_ranges_scalar_code( + self, entry: IterationRangesRoot, value: Any + ) -> str: + index_dtype = self.index_dtype + ndim = self.triton_tensor_ndim() + size = [1] * ndim + return f"tl.full({size}, {value}, {index_dtype})" + + def iteration_ranges_get_pid(self, entry: IterationRangesRoot) -> str: + assert entry.grid_dim is not None + key = f"tl.program_id({entry.grid_dim})" + # y_grid has a limit, so express it in terms of y and z in case of overflow. + # z grid is only exercised when max_tiles == 3 (off by default). + if self.needs_yz_grid_overflow(entry): + # For ynumel larger than max_ygrid, we need to use zdim. + # For each z dimension, there are tl.num_programs(1) yblocks which is passed by grad(x,y,z). + # So, we need to add tl.program_id(z) * tl.num_programs(y) *YBLOCK to get the correct yoffset. + key = f"({key} + tl.program_id({entry.grid_dim + 1}) * tl.num_programs({entry.grid_dim}))" + pid = entry.pid_cache.get(key, key) + if self.index_dtype != "tl.int32": + return f"{pid}.to({self.index_dtype})" + return pid + + def needs_yz_grid_overflow(self, entry: IterationRangesRoot) -> bool: + return ( + entry.grid_dim == 1 + and not entry.has_zdim + and not self.cooperative_reduction + and not V.graph.sizevars.statically_known_leq(entry.numel, get_max_y_grid()) + ) + + def max_block(self, prefix: str) -> int: + if self.fixed_config: + return self.fixed_config[f"{prefix.upper()}BLOCK"] + return TRITON_MAX_BLOCK[prefix.upper()] + + def _has_constant_mask(self, tree: IterationRangesRoot) -> bool: + if not self.optimize_mask: + return False + + if self.fixed_config and f"{tree.prefix.upper()}BLOCK" in self.fixed_config: + if self.fixed_config[f"{tree.prefix.upper()}BLOCK"] == 1: + return True + else: + if V.graph.sizevars.statically_known_equals(tree.numel, 1): + return True + + # Masks are superfluous if numel is a multiple of BLOCK + # (We use the fact that BLOCK is required by triton to be a power of 2) + if tree.is_reduction and self.persistent_reduction: + max_block = self._get_persistent_RBLOCK(tree.numel) + elif tree.prefix == "x" and self.no_x_dim: + max_block = 1 + else: + max_block = self.max_block(tree.prefix) + + if tree.is_reduction and self.cooperative_reduction: + max_block = max_block * self.max_rsplit() + + # [Note: Constant mask optimisation] + # Optional optimization: if block divides numel exactly, we will + # never need to do a masked load to handle stragglers at the end. + # If this tree is for the y dimension, we should only use a constant + # mask if it can be guaranteed that: + # 1. (ynumel / YBLOCK) < max_ygrid or + # 2. (ynumel / YBLOCK) % max_ygrid == 0 + # Because YBLOCK is not constant, use a conservative heuristic: + # only use a constant mask if ynumel < max_ygrid. + # It's faster to avoid masking at all. But it is sound to always + # mask. 
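+ # For example, if xnumel is statically known to be a multiple of the largest
+ # allowed XBLOCK, then (XBLOCK being a power of two) it is also a multiple of
+ # any smaller XBLOCK the autotuner may choose, so "xindex < xnumel" is always
+ # true and the mask can be replaced with a constant all-true mask.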
+ if V.graph.sizevars.statically_known_multiple_of(tree.numel, max_block): + return ( + tree.grid_dim != 1 + or tree.has_zdim + or V.graph.sizevars.statically_known_leq(tree.numel, get_max_y_grid()) + ) + + return False + + def _has_constant_xmask(self) -> bool: + xtree = self.range_trees[0] + assert xtree.prefix == "x" + return self._has_constant_mask(xtree) + + def filter_masks(self, mask_vars: OrderedSet[str]) -> None: + for tree in self.range_trees: + if self._has_constant_mask(tree): + mask_vars.discard(f"{tree.prefix}mask") + + # can be added as an override_mask + mask_vars.discard("None") + + @cache_on_self + def get_reduction_prefixes(self) -> list[str]: + return [ + prefix_str[symt] + for symt in list(TritonSymbols.reduction_types)[: self.num_reduction_dims] + ] + + def codegen_reduction_numels(self, buffer: IndentedBuffer) -> None: + """ + Generates code that flattens ND reduction numels, block sizes, etc. into 1D. + """ + # rnumel = r0_numel * ... * r(n-1)_numel + reduction_trees = [tree for tree in self.range_trees if tree.is_reduction] + rnumel = " * ".join(sorted(f"{tree.prefix}numel" for tree in reduction_trees)) + buffer.splice(f"rnumel = {self.kexpr(rnumel)}") + + # RBLOCK = R0_BLOCK * ... * R(N-1)_BLOCK + rn_blocks = [ + TritonSymbols.block_sizes[tree.symt] + for tree in self.range_trees + if tree.is_reduction + ] + rblock = sympy_product(rn_blocks) + buffer.splice(f"RBLOCK: tl.constexpr = {self.kexpr(rblock)}") + + def _get_reduction_symbols(self, suffix: str, **kwargs) -> list[sympy.Symbol]: + """ + Helper to initialize symbols like rn_numel, rn_base, etc. + """ + rn_prefixes = self.get_reduction_prefixes() + return [sympy.Symbol(f"{prefix}{suffix}", **kwargs) for prefix in rn_prefixes] + + @cache_on_self + def _get_reduction_index_coeffs(self) -> list[sympy.Expr]: + """ + Compute coefficients to convert ND reduction indices to linear indices. + For example: + rindex = r0_index * r1_numel * ... * rn_numel + ... + rn_index. + """ + rn_prefixes = self.get_reduction_prefixes() + rn_numels = self._get_reduction_symbols("numel", integer=True, positive=True) + return [ + sympy_product(rn_numels[idx + 1 :]) for idx in range(len(rn_prefixes) - 1) + ] + [sympy.Integer(1)] + + def _flatten_reduction_indices(self, multi_inds: list[sympy.Expr]) -> sympy.Expr: + """ + Compute linear reduction indices from N dimensional ones. + """ + coeffs = self._get_reduction_index_coeffs() + return sympy_dot(coeffs, multi_inds) + + def codegen_reduction_indices(self, buffer: IndentedBuffer) -> None: + """ + Generates code that converts ND reduction indices into linear indices. + """ + # Gather relevant numels, indices, etc. + rn_offsets = self._get_reduction_symbols( + "offset", integer=True, nonnegative=True + ) + rn_inds = self._get_reduction_symbols("index", integer=True, nonnegative=True) + + # Compute roffset and rindex. 
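+ # e.g. with two reduction dims the flattened indices are
+ #   roffset = r0_offset * r1_numel + r1_offset
+ #   rindex  = r0_index * r1_numel + r1_index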
+ roffset = self._flatten_reduction_indices(rn_offsets) + buffer.splice(f"roffset = {self.index_to_str(roffset)}") + rindex = self._flatten_reduction_indices(rn_inds) + buffer.splice(f"rindex = {self.index_to_str(rindex)}") + + def iteration_ranges_codegen_header( + self, entry: IterationRangesRoot, code: IndentedBuffer + ) -> None: + x = entry.prefix + if entry.is_loop: + code.writeline(f"{entry.name} = {x}offset + {x}base") + elif entry.grid_dim is None: + # no need to "{x}offset = " + code.writeline(f"{entry.name} = {self.iteration_ranges_ranges_code(entry)}") + code.writeline(f"{x}offset = 0") + else: + if entry.tensor_dim is not None: + line = f"{x}offset + {self.iteration_ranges_ranges_code(entry)}" + else: + line = self.iteration_ranges_scalar_code(entry, f"{x}offset") + code.writelines( + [ + f"{x}offset = {self.iteration_ranges_get_pid(entry)} * {x.upper()}BLOCK", + f"{entry.name} = {line}", + ] + ) + + if self._has_constant_mask(entry): + sizes = self.dense_size_str() + code.writeline(f"{x}mask = tl.full({sizes}, True, tl.int1)") + else: + code.writeline(f"{x}mask = {entry.name} < {x}numel") + + +class TritonScheduling(SIMDScheduling): + kernel_type: type[Any] = TritonKernel + backend_features = OrderedSet( + [ + BackendFeature.FOREACH, + BackendFeature.BUCKETIZE, + BackendFeature.INPLACE_BUFFERS, + BackendFeature.MASKED_SCATTER_WITH_INDEX, + BackendFeature.SCAN, + BackendFeature.SORT, + BackendFeature.TRITON_TEMPLATES, + BackendFeature.TUPLE_REDUCTION, + ] + ) + + def __init__(self, scheduler: Optional[Scheduler]) -> None: + super().__init__(scheduler) + if scheduler is None or not hasattr(scheduler, "nodes"): + return + for node in scheduler.nodes: + if isinstance(node, (SchedulerNode, FusedSchedulerNode)): + node.debug_device_str = debug_triton_code + + @classmethod + def get_backend_features(cls, device: torch.device): + if ( + config.triton.cooperative_reductions + or config.triton.force_cooperative_reductions + ): + return OrderedSet( + [*cls.backend_features, BackendFeature.REDUCE_TO_SINGLE_ELEMENT] + ) + return cls.backend_features + + def codegen_comment(self, node_schedule): + wrapper = V.graph.wrapper_code + origins, _detailed_origins = get_kernel_metadata(node_schedule, wrapper) + if origins: + wrapper.make_comment(origins) + + if config.debug_fusion: + from torch._inductor.scheduler import ( + BaseSchedulerNode, + ForeachKernelSchedulerNode, + ) + + if not any( + isinstance(n, ForeachKernelSchedulerNode) for n in node_schedule + ): + # We probably should look what are the nodes inside a foreach + # schedule node + node_names = [ + n.get_name() + for n in node_schedule + if isinstance(n, BaseSchedulerNode) + ] + wrapper.make_comment( + f"{wrapper.comment} Fused node name list: {', '.join(node_names)}" + ) + + def define_kernel(self, src_code, node_schedule, kernel): + wrapper = V.graph.wrapper_code + if src_code in wrapper.src_to_kernel: + kernel_name = wrapper.src_to_kernel[src_code] + else: + fused_name = ( + get_fused_kernel_name(node_schedule, config.triton.descriptive_names) + if config.triton.descriptive_names + else "" + ) + kernel_category = get_kernel_category_by_source_code(src_code)[:3] + kernel_name = "_".join( + ["triton", kernel_category, fused_name, wrapper.next_kernel_suffix()] + ) + # use the original src_code as the key + wrapper.src_to_kernel[src_code] = kernel_name + subs_name = kernel_name if config.triton.unique_kernel_names else "triton_" + + # DESCRIPTIVE_NAME is used for profiling purposes; it shows the full kernel name + # even when 
unique_kernel_names is turned off. Meanwhile, KERNEL_NAME is sometimes set + # to "triton_" to maximize caching opportunities (when unique_kernel_names = False). + src_code = src_code.replace(str(Placeholder.DESCRIPTIVE_NAME), kernel_name) + src_code = src_code.replace(str(Placeholder.KERNEL_NAME), subs_name) + + # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does + # not use BracesBuffer, so we have no good indicator of a C++ buffer atm. + src_code = src_code.replace("#pragma CMT", "#") + + _basename, _, kernel_path = get_path(code_hash(src_code.strip()), "py") + compile_wrapper = IndentedBuffer() + + if async_compile.use_process_pool(): + # The process pool is warm, we can shell out to workers right away. This + # allows us to save the result in async_compile.CompiledTritonKernels, + # so that the second time we call async_compile.triton, we do no work. + async_compile.triton(subs_name, src_code) + + compile_wrapper.writeline(f"async_compile.triton({subs_name!r}, '''") + + compile_wrapper.splice(src_code, strip=True) + current_device = V.graph.get_current_device_or_throw() + compile_wrapper.writeline(f"''', device_str='{current_device.type}')") + + metadata_comment = f"# kernel path: {kernel_path}" + origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper) + metadata_comment += "\n" + origins + "\n" + detailed_origins + wrapper.define_kernel( + kernel_name, compile_wrapper.getvalue(), metadata_comment + ) + + # log kernel metadata for offline analysis. + # E.g. one can find all unaligned inner reduction and check if + # padding helps with the perf kernel by kernel. + if metrics.is_metric_table_enabled("kernel_metadata"): + metrics.log_kernel_metadata(kernel_name, kernel_path, src_code) + + return kernel_name + + def benchmark_fused_nodes(self, nodes, n_spills_threshold=8) -> tuple[float, str]: + """ + Benchmark fused list of nodes and return the execution time + in milliseconds on randomly generated inputs. 
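+ Results are cached on disk next to the generated module in a ".kernel_perf"
+ file, so benchmarking the same kernel source a second time is free.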
+ """ + src_code = self.generate_kernel_code_from_nodes(nodes, benchmark_kernel=True) + mod = PyCodeCache.load(src_code) + return self.benchmark_codegened_module( + mod, n_spills_threshold, node_names=OrderedSet(n.get_name() for n in nodes) + ) + + def benchmark_codegened_module( + self, mod, n_spills_threshold=8, node_names: Optional[OrderedSet[str]] = None + ) -> tuple[float, str]: + """Benchmark an already compiled module""" + device_interface = get_interface_for_device(V.graph.device_type) + with ( + preserve_rng_state(), + device_interface.device(V.graph.get_current_device_or_throw()), # type: ignore[attr-defined] + ): + ms = None + + def cache_file_path(): + assert mod.__file__ is not None + return os.path.splitext(mod.__file__)[0] + ".kernel_perf" + + def store_cache(): + path = cache_file_path() + write_atomic(path, str(ms)) + + def load_cache(): + path = cache_file_path() + if os.path.exists(path): + with open(path) as fd: + return float(fd.read()) + return None + + node_names = ( + node_names if node_names is not None else OrderedSet(["unknown"]) + ) + log.debug( + "kernel src code for %s written to: %s", + node_names, + mod.__file__, + ) + ms = load_cache() + if ms is not None: + return ms, mod.__file__ + + args = mod.get_args() + call = mod.call + wrapped_jit_function = mod.triton_ + # call once to trigger the compilation + try: + call(wrapped_jit_function.clone_args(*args)[0]) + except Exception as e: + if config.triton.disallow_failing_autotune_kernels_TESTING_ONLY: + raise + log.debug( + "Exception (%s) in compiling fused nodes %s", + e, + node_names, + ) + ms = float("inf") + store_cache() + return ms, mod.__file__ + + launchers = wrapped_jit_function.launchers + assert len(launchers) == 1 + # n_spills does not necessarily mean it's not profitable to fuse, + # and sometimes it can be inaccurate + if launchers[0].n_spills > n_spills_threshold: + # skip benchmarking the kernel if there are register spills + ms = float("inf") + else: + # We have to clone the inplace updated arguments to avoid earlier calls + # generating out of range indices for later calls. 
+ ms = benchmarker.benchmark_gpu( + lambda: call(wrapped_jit_function.clone_args(*args)[0]) + ) + # overhead of cloning args gives bias for fusing the kernel + # in the case of mutating/in-placeable second fusion + # TODO - would be better as a hook in triton do_bench that reset + # the input values between benchmarking + if len(wrapped_jit_function.mutated_arg_names) > 0: + ms = ms - benchmarker.benchmark_gpu( + lambda: wrapped_jit_function.clone_args(*args) + ) + + log.debug( + "The fused kernel for %s took %.3f ms to run", + node_names, + ms, + ) + store_cache() + return ms, mod.__file__ + + def create_kernel_choices( # type: ignore[override] + self, + kernel_features: SIMDKernelFeatures, + kernel_args: list[Any], + kernel_kwargs: dict[str, Any], + ) -> list[TritonKernel]: + is_scan = kernel_features.contains_op("scan") + is_split_scan = is_scan and any( + node.is_split_scan() for node in kernel_features.scheduler_nodes() + ) + kernel_type: type[TritonKernel] = self.kernel_type + if is_split_scan: + from .triton_split_scan import TritonSplitScanKernel + + kernel_type = TritonSplitScanKernel + + if is_scan: + # TODO(jansel): scan does not yet work with cooperative reductions + kernel_kwargs["override_cooperative_reduction"] = False + + # ops.sort only works with persistent reduction, and is not bandwidth bound anyway + # so taking the hit of non-coalesced loads is okay + if kernel_features.contains_op("sort"): + kernel_kwargs["override_persistent_reduction"] = True + kernel_kwargs["override_cooperative_reduction"] = False + + if not TritonKernel.has_persistent_RBLOCK(kernel_features.reduction_numel): + # Cannot use persistent reduction with unknown dynamic rnumel + assert not kernel_kwargs.get("override_persistent_reduction") + kernel_kwargs["override_persistent_reduction"] = False + + kernel_kwargs = V.choices.triton_kernel_kwargs( + kernel_type, kernel_features, kernel_args, kernel_kwargs + ) + kernel = kernel_type(*kernel_args, **kernel_kwargs) + return self.add_multi_kernel_choices(kernel, kernel_args, kernel_kwargs) + + def add_multi_kernel_choices( + self, + kernel: TritonKernel, + kernel_args: list[Any], + kernel_kwargs: dict[str, Any], + ) -> list[TritonKernel]: + kernels: list[TritonKernel] = [kernel] + if not config.triton.multi_kernel: + return kernels + + optional_persistent = kernel.persistent_reduction and not kernel_kwargs.get( + "override_persistent_reduction" + ) + optional_cooperative = kernel.cooperative_reduction and not kernel_kwargs.get( + "override_cooperative_reduction" + ) + if optional_persistent: + kernels.append( + self.kernel_type( + *kernel_args, + **kernel_kwargs, + override_persistent_reduction=False, + ) + ) + if optional_cooperative: + rnumel = kernel.features.reduction_numel + # for larger sizes non-cooperative gets very slow + if V.graph.sizevars.statically_known_leq(rnumel, 65536): + kernels.append( + other := self.kernel_type( + *kernel_args, + **kernel_kwargs, + override_cooperative_reduction=False, + ) + ) + if optional_persistent and other.persistent_reduction: + kernels.append( + self.kernel_type( + *kernel_args, + **kernel_kwargs, + override_cooperative_reduction=False, + override_persistent_reduction=False, + ) + ) + + if len(kernels) > 1: + for kernel2 in kernels[1:]: + # Keep buffers needed by the non-persistent reduction so both kernels have the same arguments + kernel2.must_keep_buffers = kernel.must_keep_buffers + # persistent kernels must be generated last so must_keep_buffers works right + kernels.sort(key=lambda k: 
k.persistent_reduction) + return kernels + + def benchmark_combo_kernel(self, node_list): + mod: ModuleType + ms: float + ms_clone: float + + def cache_file_path(): + assert mod.__file__ is not None + return os.path.splitext(mod.__file__)[0] + ".kernel_perf" + + def load_cache(): + path = cache_file_path() + if os.path.exists(path): + with open(path) as fd: + return tuple(float(e) for e in fd.read().split()) + return (None, None) + + def store_cache(): + path = cache_file_path() + write_atomic(path, str(ms) + " " + str(ms_clone)) + + total_ms, file_list = 0, [] + total_clone_ms: float = 0.0 + removed_buffers_orig = V.graph.removed_buffers + V.graph.removed_buffers = OrderedSet(removed_buffers_orig) + inplaced_to_remove_orig = V.graph.inplaced_to_remove + V.graph.inplaced_to_remove = OrderedSet(inplaced_to_remove_orig) + enable_autotune = config.combo_kernels_autotune > 0 + mixed_sizes = config.combo_kernel_allow_mixed_sizes > 0 + kernel_code_list = self.generate_combo_kernel_code( + subkernel_nodes=node_list, + custom_part_algorithm=True, + enable_autotune=enable_autotune, + mixed_sizes=mixed_sizes, + only_gen_src_code=True, + ) + + for src_code, _, node_group in kernel_code_list: + fused_node_lists = [node.get_nodes() for node in node_group] + names = [n.get_name() for nodes in fused_node_lists for n in nodes] + + src_code = src_code.replace(str(Placeholder.KERNEL_NAME), "triton_") + mod = PyCodeCache.load(src_code) + + log.debug( + "kernel src code for %s written to: %s", + names, + mod.__file__, + ) + ms, ms_clone = load_cache() + if ms is not None: + total_ms += ms # type: ignore[assignment] + total_clone_ms += ms_clone + file_list.append(mod.__file__) + continue + + args = mod.get_args() + call = mod.call + wrapped_jit_function = mod.triton_ + + # call once to trigger the compilation + call(wrapped_jit_function.clone_args(*args)[0]) + + launchers = wrapped_jit_function.launchers + assert len(launchers) == 1 + if launchers[0].n_spills > 0: + # skip benchmarking the kernel if there are register spills + ms = ms_clone = float("inf") + else: + # We have to clone the inplace updated arguments to avoid earlier calls + # generating out of range indices for later calls. 
+ ms = benchmarker.benchmark_gpu( + lambda: call(wrapped_jit_function.clone_args(*args)[0]) + ) + ms_clone = benchmarker.benchmark_gpu( + lambda: wrapped_jit_function.clone_args(*args)[0] + ) + + log.debug( + "The fused kernel for %s took %.3f ms to run, %.3f ms to clone inputs", + OrderedSet(n.get_name() for n in node_group), + ms, + ms_clone, + ) + store_cache() + total_ms += ms + total_clone_ms += ms_clone + file_list.append(mod.__file__) + V.graph.removed_buffers = removed_buffers_orig + V.graph.inplaced_to_remove = inplaced_to_remove_orig + return total_ms, total_clone_ms, file_list + + +def debug_triton_code(node: BaseSchedulerNode) -> list[str]: + lines = [] + multi_template = node.get_template_node() + assert multi_template is None or isinstance(multi_template, ir.MultiTemplateBuffer) + if multi_template and multi_template.make_kernel_render is None: + lines.append(f"{node.get_name()} Unfinalized multi template buffer") + else: + from torch._inductor.codegen.cuda_combined_scheduling import ( + CUDACombinedScheduling, + ) + + device = node.get_device() + assert device is not None + backend = node.scheduler.get_backend(device) + assert isinstance(backend, (SIMDScheduling, CUDACombinedScheduling)), ( + f"Scheduling backend should be SIMD or CUDACombined when generating debug Triton strings, got: {type(backend)}" + ) + + with V.graph.set_current_device(device): + # Don't increment kernel count when generating debug string. + # This will confuse some unit tests that check the number of + # generated kernels. + old_generated_kernel_count = metrics.generated_kernel_count + triton_code = backend.generate_kernel_code_from_nodes( + node.get_nodes() + ).strip() + metrics.generated_kernel_count = old_generated_kernel_count + + lines.append(f"{node.get_name()} Triton code:") + lines.append(textwrap.indent(triton_code, " ")) + return lines diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/triton_combo_kernel.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/triton_combo_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..dc2392119cc5118e8d50508d69119087c3aa27e5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/triton_combo_kernel.py @@ -0,0 +1,978 @@ +import itertools +import logging +import textwrap +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, Callable, cast, Optional, Union + +import sympy +from sympy import Integer, Symbol + +from torch.utils._ordered_set import OrderedSet + +from .. 
import config, metrics +from ..runtime.hints import DeviceProperties +from ..runtime.runtime_utils import next_power_of_2 +from ..runtime.triton_heuristics import ( + RoundRobinComboKernelGrid, + SequentialComboKernelGrid, +) +from ..scheduler import BaseSchedulerNode +from ..utils import Placeholder, triton_version_uses_attrs_dict +from ..virtualized import V +from .common import ( + ArgName, + ConstexprArg, + DeferredLine, + IndentedBuffer, + InplacedBuffer, + Kernel, + PythonPrinter, + RemovedArg, + SizeArg, + WorkspaceArg, +) +from .simd import prefix_is_reduction, SIMDScheduling +from .simd_kernel_features import SIMDKernelFeatures +from .triton import gen_common_triton_imports, TritonKernel +from .triton_utils import config_of, signature_to_meta + + +log = logging.getLogger(__name__) +pexpr = PythonPrinter().doprint +LARGE_NUMELS = 512e5 +BLOCK_UTILIZATION = 0.8 + + +def _default_custom_combo_kernel_horizontal_partition( + nodes: list[BaseSchedulerNode], + triton_scheduling: SIMDScheduling, + kernel_map: dict[BaseSchedulerNode, TritonKernel], + node_info_map: dict[BaseSchedulerNode, tuple[Any, Any, Any, Any]], +) -> list[list[BaseSchedulerNode]]: + """Horizontally partition the given list of nodes into a list of list of nodes where each sublist + represents a partition. Nodes in different partitions are implemented in different combo kernels. + Nodes in the same partition are likely to be implemented + in the same combo kernel, but subject to subsequent restrictions like CUDA limits for number of args. + + Input arguments: + nodes: a list of fused scheduler nodes to partition. + triton_scheduling: TritonScheduling instance. + kernel_map: a map from node to its kernel. + node_info_map: a map from node to (node_schedule, tiled_groups, numel, rnumel). + Output: + a list of list of nodes with each sublist representing a partition. + + The default algorithm is to partition nodes based on the following rules: + 1) nodes with the same number of block dimensions are grouped together. + 2) large pointwise nodes (numels greater than LARGE_NUMELS) are separated from other nodes. + 3) large reduce nodes are separated from other nodes. 
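+ Here "large" pointwise means an x-numel size hint above LARGE_NUMELS (each such node gets its own partition), and a "long" reduction means an rnumel size hint above 2048 (kept separate from short reductions).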
+ """ + + assert len(nodes) >= 1 + + # first partition nodes based on number of block dimensions + tilings = [node_info_map[n][1] for n in nodes] + + max_dims = max(len(t) for t in tilings) + nodes_per_ndim: list[list[BaseSchedulerNode]] = [] + for i in range(2, max_dims + 1): + group_per_dim = [n for n, t in zip(nodes, tilings) if len(t) == i] + reduction = [ + n + for n in group_per_dim + if kernel_map[n].inside_reduction + and not (kernel_map[n].persistent_reduction and kernel_map[n].no_x_dim) + ] + not_reduction = [n for n in group_per_dim if n not in reduction] + # rnumel > 2048 usually has long execution time + # BaseSchedulerNode.group[-1][-1] is rnumel for reduction nodes + long_reduction = [ + n + for n in reduction + if V.graph.sizevars.size_hint(n.group[-1][-1]) > 2048 # type: ignore[arg-type] + ] + short_reduction = [n for n in reduction if n not in long_reduction] + if long_reduction: + log.warning( + "ComboKernels: %d long reduction nodes are separated", + len(long_reduction), + ) + large_pointwise = [ + n + for n in not_reduction + if not kernel_map[n].inside_reduction + and len(kernel_map[n].numels) == 2 + and V.graph.sizevars.size_hint(kernel_map[n].numels["x"]) > LARGE_NUMELS + ] + if large_pointwise: + # TODO benchmark the performance when large pointwise nodes combining with others + log.warning( + "ComboKernels: %d large pointwise nodes are separated", + len(large_pointwise), + ) + not_reduction = [n for n in not_reduction if n not in large_pointwise] + nodes_per_ndim.extend([node] for node in large_pointwise) + + nodes_per_ndim.extend( + g for g in (not_reduction, short_reduction, long_reduction) if g + ) + + assert sum(len(p) for p in nodes_per_ndim) == len(nodes) + return nodes_per_ndim + + +_custom_combo_kernel_horizontal_partition_algorithm: Callable[ + [ + list[BaseSchedulerNode], + SIMDScheduling, + dict[BaseSchedulerNode, TritonKernel], + dict[BaseSchedulerNode, tuple[Any, Any, Any, Any]], + ], + list[list[BaseSchedulerNode]], +] = _default_custom_combo_kernel_horizontal_partition + + +def set_custom_combo_kernel_horizontal_partition( + algorithm: Callable[ + [ + list[BaseSchedulerNode], + SIMDScheduling, + dict[BaseSchedulerNode, TritonKernel], + dict[BaseSchedulerNode, tuple[Any, Any, Any, Any]], + ], + list[list[BaseSchedulerNode]], + ], +) -> None: + """Sets the algorithm used to partition nodes into horizontal partitions. Nodes in different partitions + are implemented in different combo kernels. Nodes in the same partition are likely to be implemented + in the same combo kernel, but subject to subsequent restricts like CUDA limits for number of args. + + The algorithm should take a list of nodes and return a list of list of nodes. + + The default algorithm is to partition nodes based on number of block dimensions. 
+ """ + global _custom_combo_kernel_horizontal_partition_algorithm + _custom_combo_kernel_horizontal_partition_algorithm = algorithm + + +@dataclass +class PartitionState: + partitions: list[list[BaseSchedulerNode]] + cur_partition: list[BaseSchedulerNode] + cur_count: int + + def finalize(self) -> None: + if self.cur_partition: + self.partitions.append(self.cur_partition) + + +class ComboKernel(Kernel): + MAX_NUM_ARGS = 250 # number where I would no longer get triton errors + + @staticmethod + def _update_partition( + partition_state: PartitionState, + node_rw_count: int, + node_info: BaseSchedulerNode, + ) -> None: + if partition_state.cur_count + node_rw_count > ComboKernel.MAX_NUM_ARGS: + partition_state.partitions.append(partition_state.cur_partition) + partition_state.cur_partition = [node_info] + partition_state.cur_count = node_rw_count + else: + partition_state.cur_count += node_rw_count + partition_state.cur_partition.append(node_info) + + @staticmethod + def _base_horizontal_partition( + subkernel_nodes: list[BaseSchedulerNode], + triton_scheduling: SIMDScheduling, + node_info_map: dict[BaseSchedulerNode, tuple[Any, Any, Any, Any]], + custom_algorithm: bool, + ) -> list[list[BaseSchedulerNode]]: + """Generates a list of lists of node info tuples which consist of (fused_nodes, tiling, numel, rnumel) + for each subkernel node where each sublist is guaranteed to not exceed CUDA limits for number of args + (read/writes) and to have the same 2D or 1D blocking strategy.""" + # TODO support combination of kernels with different block dimensions + assert len(subkernel_nodes) >= 1 + mixed_sizes = config.combo_kernel_allow_mixed_sizes > 1 or ( + config.combo_kernel_allow_mixed_sizes == 1 and custom_algorithm + ) + + ndim_to_partition_state: dict[int, PartitionState] = defaultdict( + lambda: PartitionState([], [], 0) + ) + yelem_to_partition_state: dict[int, PartitionState] = defaultdict( + lambda: PartitionState([], [], 0) + ) + + for node in subkernel_nodes: + _node_schedule, tiled_groups, _numel, _rnumel = node_info_map[node] + node_info = node + + read_writes = node.read_writes + read_write_count = len(read_writes.reads) + len(read_writes.writes) + + ndim = len(tiled_groups) + assert ndim >= 2, f"Combokernel not support tile {tiled_groups}" + if not mixed_sizes and ndim == 3: + y_elem = tiled_groups["y"] + partition_state = yelem_to_partition_state[y_elem] + ComboKernel._update_partition( + partition_state, read_write_count, node_info + ) + else: + assert mixed_sizes or ndim <= 3, f"No mixed sizes: tile {tiled_groups}" + partition_state = ndim_to_partition_state[ndim] + ComboKernel._update_partition( + partition_state, read_write_count, node_info + ) + + all_partitions = [] + for partition_state in ndim_to_partition_state.values(): + partition_state.finalize() + all_partitions.extend(partition_state.partitions) + for partition_state in yelem_to_partition_state.values(): + partition_state.finalize() + all_partitions.extend(partition_state.partitions) + + return all_partitions + + @staticmethod + def horizontal_partition( + nodes: list[BaseSchedulerNode], + triton_scheduling: SIMDScheduling, + kernel_map: dict[BaseSchedulerNode, TritonKernel], + node_info_map: dict[BaseSchedulerNode, tuple[Any, Any, Any, Any]], + custom_algorithm: bool = False, + ) -> list[list[BaseSchedulerNode]]: + """Generates a list of lists of node info tuples which consist of (fused_nodes, tiling, numel, rnum) + for each subkernel node where each sublist forms a ComboKernel. 
It horizontally partitions nodes into + sublists in the following way: + 1) call _custom_combo_kernel_horizontal_partition_algorithm() if custom_algorithm is True + 2) then, call _base_horizontal_partition() to partition nodes into sublists, each sublist is + guaranteed to not exceed CUDA limits for number of args (read/writes) and to have the same + 2D or 1D blocking strategy. + """ + if custom_algorithm: + raw_partitions = _custom_combo_kernel_horizontal_partition_algorithm( + nodes, triton_scheduling, kernel_map, node_info_map + ) + else: + raw_partitions = [nodes] + + """Generates a list of lists of node info tuples which consist of (fused_nodes, tiling, numel, rnumel) + for each subkernel node where each sublist is guaranteed to not exceed CUDA limits for number of args + (read/writes) and to have the same 2D or 1D blocking strategy.""" + all_partitions = [] + for raw_partition in raw_partitions: + all_partitions.extend( + ComboKernel._base_horizontal_partition( + raw_partition, triton_scheduling, node_info_map, custom_algorithm + ) + ) + return all_partitions + + class SequentialDispatch: + """ + The dispatcher which dispatches the subkernels in a sequential manner: + the blocks are first dispatched to the 1st subkernel (until it is filled), + then to the 2nd subkernel, and so on. + The class defines the methods specific to the dispatch algorithm. + Methods: + codegen_pid_range(...): codegen the pid range for each subkernel. + grid(...): codegen the grid size for launching the combo kernel. + """ + + grid_expr = SequentialComboKernelGrid + + @classmethod + def codegen_pid_range( + cls, kernel: "ComboKernel", num: int, code: IndentedBuffer + ) -> None: + if num == 0: + cls._calculate_xblocks(kernel, code) + code.splice(f"if pid < num_xblocks_{num}:") + with code.indent(): + code.splice("pid_offset = pid") + else: + code.splice(f"elif pid < num_xblocks_{num}:") + with code.indent(): + code.splice(f"pid_offset = pid - num_xblocks_{num - 1}") + + @classmethod + def _calculate_xblocks( + cls, kernel: "ComboKernel", code: IndentedBuffer + ) -> None: + x_numels_list = kernel.x_numels_list + for i in range(len(x_numels_list)): + xnumels, no_x_dim = ( + (x_numels_list[i], False) + if isinstance(x_numels_list[i], str) + and cast(str, x_numels_list[i])[0] != "-" + or ( + isinstance(x_numels_list[i], int) + and cast(int, x_numels_list[i]) > 0 + ) + else (kernel.min_x_blocks_list[i], True) + ) + xblock_str = ( + f"tl.cdiv({xnumels}, XBLOCK)" if not no_x_dim else f"{xnumels}" + ) + if i == 0: + code.splice(f"num_xblocks_{i} = {xblock_str}") + else: + code.splice(f"num_xblocks_{i} = num_xblocks_{i - 1} + {xblock_str}") + + class RoundRobinDispatch: + """ + The dispatcher which dispatches the subkernels in a round robin manner: + the blocks are interleavedly dispatched to each subkernel to execute them + in parallel. + The class defines the methods specific to the dispatch algorithm. + Methods: + codegen_pid_range(...): codegen the pid range for each subkernel. + grid(...): codegen the grid size for launching the combo kernel. 
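+ select_dispatch_strategy only picks this dispatcher when every sub-kernel has a static x numel and the estimated masked-block overhead stays below 1 - BLOCK_UTILIZATION; otherwise SequentialDispatch is used.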
+ """ + + grid_expr = RoundRobinComboKernelGrid + + @classmethod + def codegen_pid_range( + cls, kernel: "ComboKernel", num: int, code: IndentedBuffer + ) -> None: + num_kernels = len(kernel.sub_kernels) + if num == 0: + cond = "if" + else: + cond = "elif" + code.splice(f"{cond} pid % {num_kernels} == {num}:") + with code.indent(): + code.splice(f"pid_offset = pid // {num_kernels}") + + def __init__( + self, enable_autotune: bool = False, mixed_sizes: bool = False + ) -> None: + super().__init__() + self.sub_kernels: list[TritonKernel] = [] + self.iter_vars_count = itertools.count() + self.grids: list[list[int]] = [] + self.min_x_blocks_list: list[Union[int, str]] = [] + self.x_numels_list: list[Union[int, str]] = [] + self.enable_autotune = enable_autotune + self.mixed_sizes = mixed_sizes + self.dispatch_class: Optional[ + type[Union[ComboKernel.SequentialDispatch, ComboKernel.RoundRobinDispatch]] + ] = None + self.block_args: list[str] = [] + # there following are used when autotuning is disabled + self.block_size_1d = 1024 # Try tuning this value + self.block_size_2d = 32 + self.num_warps = 8 + self.block_size_reduce = 256 + self.dynamic_shape_args: list[str] = [] + + def create_sub_kernel(self, triton_kernel: TritonKernel) -> TritonKernel: + sub_kernel = triton_kernel + metrics.generated_kernel_count -= 1 + sub_kernel.args = self.args + sub_kernel.iter_vars_count = self.iter_vars_count + sub_kernel.cse.iter_buffer_ids = self.cse.iter_buffer_ids + self.sub_kernels.append(sub_kernel) + return sub_kernel + + @staticmethod + def create_triton_kernel( + tiling: dict[str, sympy.Expr], + features: SIMDKernelFeatures, + optimize_mask: bool, + ) -> TritonKernel: + """ + Only allow optimize_mask=True when 1) sequential dispatch is used, + 2) numels except x dimension are the same for each sub kernel. + """ + return TritonKernel( + tiling, + features=features, + pid_cache={"tl.program_id(0)": "pid_offset"}, + optimize_mask=optimize_mask, + # foreach kernels don't work with cooperative reductions + override_cooperative_reduction=False, + ) + + def codegen_static_numels_sub_kernel( + self, code: IndentedBuffer, sub_kernel: TritonKernel, num: int + ) -> list[str]: + """ + We get a small speedup from hard coding numels if they are static. + + This code stomps on the passed-in values by writing an constant to the top of the kernel. + + In a kernel like: + def KERNEL_NAME(in_ptr0, in_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + + We would add + xnumel = 4096 + rnumel = 768 + + After the signature, before the kernel code, if we decided to make these static. As its hardcoded, it becomes + a better signal to triton on how to unroll and do some static indexing. So, it's not so much that downstream + knows that its a static numel, as that you just plop a constant into the kernel. 
+ """ + grid = [] + uniquify_block_sizes = [] + for tree in sub_kernel.range_trees: + simplified_tree_numel = V.graph.sizevars.simplify(tree.numel) + if isinstance(simplified_tree_numel, (Integer, int)): + code.writeline(f"{tree.prefix}numel = {int(simplified_tree_numel)}") + else: + assert f"{tree.prefix}numel_{num}" in self.dynamic_shape_args + uniquify_block_sizes.append(f"{tree.prefix}numel") + + if not tree.is_reduction: + if isinstance(simplified_tree_numel, (Integer, int)): + grid.append(int(simplified_tree_numel)) + else: + grid.append(f"{tree.prefix}numel_{num}") + + if tree.is_reduction and sub_kernel.persistent_reduction: + if isinstance(simplified_tree_numel, (Integer, int)): + val = int(simplified_tree_numel) + else: + raise RuntimeError( + "Dynamic shape on reduction dimension is not supported" + ) + val = next_power_of_2(val) + code.writeline(f"RBLOCK_{num}: tl.constexpr = {val}") + code.writeline(f"R0_BLOCK_{num}: tl.constexpr = {val}") + uniquify_block_sizes.append("R0_BLOCK") + + if tree.prefix == "x" and sub_kernel.no_x_dim: + code.writeline(f"XBLOCK_{num}: tl.constexpr = 1") + uniquify_block_sizes.append("XBLOCK") + self.grids.append(grid) + return uniquify_block_sizes + + def min_x_blocks_sub_kernel(self, sub_kernel: TritonKernel, num: int) -> None: + """ + Kernels with no_x_dim being true has no tunable XBLOCK. They have a fixed number of X blocks. + Grid calculation needs to make sure that they are assigned with enough number of blocks. + """ + min_x_blocks: Union[int, str] = 0 + x_numels: Union[int, str] = 0 + for tree in sub_kernel.range_trees: + simplified_tree_numel = V.graph.sizevars.simplify(tree.numel) + if tree.prefix == "x": + if isinstance(simplified_tree_numel, (Integer, int)): + x_numels = int(simplified_tree_numel) + else: + x_numels = f"{tree.prefix}numel_{num}" + if sub_kernel.no_x_dim: + min_x_blocks = x_numels + x_numels = ( + -min_x_blocks + if isinstance(x_numels, int) + else "-" + cast(str, x_numels) + ) + else: + if isinstance(simplified_tree_numel, (Integer, int)): + x_numels = int(simplified_tree_numel) + else: + x_numels = f"{tree.prefix}numel_{num}" + self.min_x_blocks_list.append(min_x_blocks) + self.x_numels_list.append(x_numels) + + def select_heuristics(self, sub_kernel: TritonKernel) -> tuple[str, dict[str, int]]: + size_hints = { + prefix: next_power_of_2(V.graph.sizevars.size_hint(numel)) + for prefix, numel in sub_kernel.numels.items() + if not prefix_is_reduction(prefix) or sub_kernel.inside_reduction + } + if sub_kernel.persistent_reduction: + assert sub_kernel.inside_reduction + heuristics = "persistent_reduction" + elif sub_kernel.inside_reduction: + heuristics = "reduction" + else: + heuristics = "pointwise" + return heuristics, size_hints + + def select_combo_heuristics( + self, heuristics_list: list[str], size_hints_list: list[dict[str, int]] + ) -> tuple[str, dict[str, int], TritonKernel]: + if not self.enable_autotune: + return "foreach", size_hints_list[0], self.sub_kernels[0] + if "reduction" in heuristics_list: + i, _ = max( + enumerate(size_hints_list), + key=lambda x: x[1]["x"] if heuristics_list[x[0]] == "reduction" else 0, + ) + return heuristics_list[i], size_hints_list[i], self.sub_kernels[i] + elif "pointwise" in heuristics_list: + i, _ = max( + enumerate(size_hints_list), + key=lambda x: x[1]["x"] if heuristics_list[x[0]] == "pointwise" else 0, + ) + # modify size_hint to avoid oom check fail (may be a false alarm) + num_pointwise = len([e for e in heuristics_list if e == "pointwise"]) + num_reduction = len([e for 
e in heuristics_list if e == "reduction"]) + num_persistent_reduction = len( + [e for e in heuristics_list if e == "persistent_reduction"] + ) + assert num_reduction == 0, ( + "combining pointwise and reduction are not supported yet." + ) + heuristics = ( + "pointwise_with_reduction" + if num_persistent_reduction > 0 + else "pointwise" + ) + if len(heuristics_list) - num_pointwise >= 4: + size_hints = size_hints_list[i] + size_hints["x"] = min(128, size_hints["x"]) + return heuristics, size_hints_list[i], self.sub_kernels[i] + else: + return heuristics_list[0], size_hints_list[0], self.sub_kernels[0] + + def get_mutated_args_sub_kernels(self) -> list[str]: + mutated_args: OrderedSet[str] = OrderedSet() + for sub_kernel in self.sub_kernels: + for mutation in sub_kernel.mutations: + if mutation in sub_kernel.args.input_buffers: + mutated_args.add(sub_kernel.args.input_buffers[mutation]) + if ( + mutation in sub_kernel.args.inplace_buffers + and mutation not in V.graph.removed_buffers + and mutation not in sub_kernel.removed_buffers + ): + mutated_args.add( + cast( + InplacedBuffer, sub_kernel.args.inplace_buffers[mutation] + ).inner_name + ) + if mutation in sub_kernel.args.output_buffers: + arg = sub_kernel.args.output_buffers[mutation] + assert not isinstance(arg, RemovedArg) + mutated_args.add(arg) + return sorted(mutated_args) + + def select_dispatch_strategy(self) -> None: + if self.dispatch_class is not None: + return + # mixed_sizes is used for optimize_mask, so it only allows sequential dispatch + # Not mixed sizes on y dim technically is ok to use round robin as wells. + if not self.mixed_sizes or any(isinstance(e, str) for e in self.x_numels_list): + # str in x_numels_list means a dynamic shape + self.dispatch_class = ComboKernel.SequentialDispatch + return + # A negative x_blocks_list element means the kernel is not tunable, + # i.e., no_x_dim = True + x_numels_list = [abs(cast(int, e)) for e in self.x_numels_list] + total = max(x_numels_list) * len(x_numels_list) + needed = sum(x_numels_list) + if needed / total > BLOCK_UTILIZATION: + # Introduced overhead (masked blocks) is less than 20% + self.dispatch_class = ComboKernel.RoundRobinDispatch + else: + self.dispatch_class = ComboKernel.SequentialDispatch + + def jit_line( + self, + heuristics: str, + size_hints: dict[str, int], + selected_kernel: TritonKernel, + signature: list[Any], + argdefs: list[ArgName], + pointwise_with_reduce: bool = False, + ) -> str: + can_use_32bit = all(k.index_dtype == "tl.int32" for k in self.sub_kernels) + size_dtype = "tl.int32" if can_use_32bit else "tl.int64" + for i, sub in enumerate(self.sub_kernels): + self.min_x_blocks_sub_kernel(sub, i) + self.select_dispatch_strategy() + triton_meta = { + "signature": signature_to_meta( + signature, size_dtype=size_dtype, argdefs=argdefs + ), + "device": DeviceProperties.create(V.graph.get_current_device_or_throw()), + "constants": {}, + } + triton_meta["configs"] = [config_of(signature)] + mutated_args = self.get_mutated_args_sub_kernels() + dispatch = self.dispatch_class + assert dispatch is not None + inductor_meta = { + "grid_type": dispatch.grid_expr.__name__, + "combo_grid_meta": self.combo_grid_meta(), + "kernel_name": str(Placeholder.DESCRIPTIVE_NAME), + "mutated_arg_names": mutated_args, + **TritonKernel.inductor_meta_common(), + } + + sub_kernel = selected_kernel + if heuristics == "foreach": + heuristics_line = f""" + @triton_heuristics.foreach( + num_warps={self.num_warps}, + triton_meta={triton_meta!r}, + inductor_meta={inductor_meta!r}, + ) + 
@triton.jit + """ + elif sub_kernel.inside_reduction: + reduction_hint = sub_kernel.features.get_reduction_hint() + heuristics_line = f""" + @triton_heuristics.{heuristics}( + size_hints={size_hints!r}, + reduction_hint={reduction_hint}, + filename=__file__, + triton_meta={triton_meta!r}, + inductor_meta={inductor_meta!r} + ) + @triton.jit + """ + else: + tile_hint = "" + if len(size_hints) == 2: + tile_hint = "tile_hint=TileHint.SQUARE," + else: + tile_hint = "tile_hint=TileHint.DEFAULT," + heuristics_line = f""" + @triton_heuristics.{heuristics}( + size_hints={size_hints!r}, {tile_hint} + filename=__file__, + triton_meta={triton_meta!r}, + inductor_meta={inductor_meta!r} + ) + @triton.jit + """ + + return heuristics_line + + def codegen_blocks(self, code: IndentedBuffer) -> None: + for block in self.block_args: + assert block in ( + "XBLOCK", + "YBLOCK", + "R0_BLOCK", + ), f"{block} is not supported without autotuning" + if "YBLOCK" in self.block_args: + code.splice(f"XBLOCK: tl.constexpr = {self.block_size_2d}") + code.splice(f"YBLOCK: tl.constexpr = {self.block_size_2d}") + else: + code.splice(f"XBLOCK: tl.constexpr = {self.block_size_1d}") + if "R0_BLOCK" in self.block_args: + code.splice(f"R0_BLOCK: tl.constexpr = {self.block_size_reduce}") + code.splice(f"RBLOCK: tl.constexpr = {self.block_size_reduce}") + + def get_block_args(self) -> list[ConstexprArg]: + """ + Calculate blocks from sub_kernels and range_trees. + **Update self.block_args** + Return the block args + """ + block_names = {} + for sub_kernel in self.sub_kernels: + # TODO: we assume all sub_kernels have the same block size + for tree in sub_kernel.range_trees: + if tree.is_reduction and ( + not sub_kernel.inside_reduction or sub_kernel.persistent_reduction + ): + continue + if tree.prefix == "x" and sub_kernel.no_x_dim: + continue + block_names[f"{tree.prefix.upper()}BLOCK"] = tree.prefix + self.block_args = list(block_names.keys()) + + return [ConstexprArg(x) for x in block_names.keys()] + + def add_numel_to_args( + self, argdefs: list[ArgName], signature: list[Any] + ) -> list[ArgName]: + for num, sub_kernel in enumerate(self.sub_kernels): + for tree in sub_kernel.active_range_trees(): + if not isinstance(tree.numel, (Integer, int)): + # only if it is a dynamic shape + sizearg = SizeArg(f"{tree.prefix}numel_{num}", tree.numel) + signature.append(sizearg) + argdefs.append(ArgName(f"{tree.prefix}numel_{num}")) + self.dynamic_shape_args.append(f"{tree.prefix}numel_{num}") + return argdefs + + def add_numel_to_call_args( + self, name: str, call_args: list[Any], arg_types: list[Any] + ) -> None: + for num, sub_kernel in enumerate(self.sub_kernels): + for i, tree in enumerate(sub_kernel.range_trees): + numel_name = f"{tree.prefix}numel_{num}" + if numel_name not in self.dynamic_shape_args: + continue + if isinstance(tree.numel, (Integer, Symbol)): + expr = tree.numel + else: + expr = V.graph.wrapper_code.generate_numel_expr( + name, tree, suffix=str(num) + ) + if not tree.is_reduction or sub_kernel.inside_reduction: + call_args.append(expr) + arg_types.append(type(expr)) + + def kernel_benchmark_extra_args(self) -> list[str]: + extra_args = [] + for num, sub_kernel in enumerate(self.sub_kernels): + for i, tree in enumerate(sub_kernel.range_trees): + numel_name = f"{tree.prefix}numel_{num}" + if numel_name not in self.dynamic_shape_args: + continue + if not tree.is_reduction or sub_kernel.inside_reduction: + extra_args.append(str(V.graph.sizevars.size_hint(tree.numel))) + return extra_args + + def codegen_kernel(self, name: 
Optional[str] = None) -> str: + # TODO: is it correct to use the first sub kernel's heuristics? + heuristics_list, size_hints_list = [], [] + for subkernel in self.sub_kernels: + h, s = self.select_heuristics(subkernel) + heuristics_list.append(h) + size_hints_list.append(s) + heuristics, size_hints, selected_kernel = self.select_combo_heuristics( + heuristics_list, size_hints_list + ) + pointwise_with_reduction, heuristics = ( + (True, "pointwise") + if heuristics == "pointwise_with_reduction" + else (False, heuristics) + ) + code = IndentedBuffer() + + code.splice(gen_common_triton_imports()) + if config.benchmark_combo_kernel: + code.splice(self.imports_for_benchmark_kernel()) + + argdefs, _, signature, _ = self.args.python_argdefs() + argdefs = self.add_numel_to_args(argdefs, signature) + block_args = self.get_block_args() + if self.enable_autotune: + argdefs.extend([ArgName(x.name, is_constexpr=True) for x in block_args]) + if triton_version_uses_attrs_dict(): + signature.extend(block_args) + + code.splice( + self.jit_line( + heuristics, + size_hints, + selected_kernel, + pointwise_with_reduce=pointwise_with_reduction, + signature=signature, + argdefs=argdefs, + ) + ) + code.writeline( + f"def {name or str(Placeholder.KERNEL_NAME)}({', '.join(x.full_name() for x in argdefs)}):" + ) + + with code.indent(): + code.splice("pid = tl.program_id(0)") + if not self.enable_autotune: + self.codegen_blocks(code) + + for num, sub_kernel in enumerate(self.sub_kernels): + assert self.dispatch_class is not None + self.dispatch_class.codegen_pid_range(self, num, code) + with code.indent(): + uniquify = self.codegen_static_numels_sub_kernel( + code, sub_kernel, num + ) + sub_kernel.codegen_body() + uniquified_body = self.uniquify_block_sizes( + sub_kernel.body, num, uniquify + ) + code.splice(uniquified_body) + + code.splice("else:") + with code.indent(): + code.splice("pass") + + if config.benchmark_combo_kernel: + code.splice(self.codegen_kernel_benchmark(num_gb=0)) + + return code.getvalue() + + def codegen_kernel_benchmark(self, num_gb: float) -> IndentedBuffer: + result = IndentedBuffer() + _argdefs, call_args, signature, _ = self.args.python_argdefs() + result.writelines(["", "", "def get_args():"]) + with result.indent(): + name_cnt = itertools.count() + var_names = [] + for arg_name, arg_sig in zip(call_args, signature): + var_name = f"arg_{next(name_cnt)}" + buf = V.graph.try_get_buffer(arg_name) + if buf: + result.writeline( + f"{var_name} = rand_strided({V.graph.sizevars.size_hints(buf.get_size())}, {V.graph.sizevars.size_hints(buf.get_stride())}, device='{buf.get_device()}', dtype={buf.get_dtype()})" # noqa: B950 line too long + ) + elif arg_name in V.graph.constants: + # note that random seed is put in V.graph.constants + const_tensor = V.graph.constants[arg_name] + result.writeline( + f"{var_name} = rand_strided({V.graph.sizevars.size_hints(const_tensor.size())}, {V.graph.sizevars.size_hints(const_tensor.stride())}, device='{const_tensor.device}', dtype={const_tensor.dtype})" # type: ignore[arg-type] # noqa: B950 line too long + ) + elif isinstance(arg_sig, SizeArg): + symval_hint = V.graph.sizevars.size_hint(arg_sig.expr) + + # Force the seed_offset to be 0 so calls to the same kernel + # using different seed offset will have the same benchmark harness. + # We can dedup kernel definitions in this case. 
+ if "seed_offset" in arg_sig.name: + symval_hint = 0 + result.writeline(f"{var_name} = {symval_hint}") + elif isinstance(arg_sig, WorkspaceArg): + device = V.graph.get_current_device_or_throw() + count = V.graph.sizevars.size_hint(arg_sig.count) + # for benchmark harness, we ignore arg_sig.zero_mode and always zero it + result.writeline( + f"{var_name} = torch.zeros({count}, device='{device}', dtype={arg_sig.dtype})" + ) + else: + raise KeyError( + f"Don't find the buffer or const tensor for {arg_name}" + ) + var_names.append(var_name) + if self.dynamic_shape_args: + var_names.extend(self.kernel_benchmark_extra_args()) + result.writeline(f"return {', '.join(var_names)},") + + result.writelines(["\n", "\n", "def call(args):"]) + index = V.graph.get_current_device_or_throw().index + with result.indent(): + result.writeline(f"with {V.graph.device_ops.device_guard(index)}:") + with result.indent(): + result.writeline( + V.graph.device_ops.set_device(index) + ) # no-op to ensure context + stream_name = f"stream{index}" + result.writeline(f"{stream_name} = get_raw_stream({index})") + result.writeline( + f"{str(Placeholder.KERNEL_NAME)}.run(*args, stream={stream_name})" + ) + + # benchmark all configs + result.writelines(["\n", "\n", "def benchmark_all_configs(args):"]) + with result.indent(): + result.writeline(f"with {V.graph.device_ops.device_guard(index)}:") + with result.indent(): + result.writeline( + V.graph.device_ops.set_device(index) + ) # no-op to ensure context + result.writeline( + f"return {str(Placeholder.KERNEL_NAME)}.benchmark_all_configs(*args)" + ) + + result.writelines(["\n", "\n", "if __name__ == '__main__':"]) + with result.indent(): + result.writeline( + "from torch._inductor.runtime.benchmarking import benchmarker" + ) + result.writeline("") + + result.writeline("args = get_args()") + result.writeline( + "ms = benchmarker.benchmark_gpu(lambda: call(args), rep=40)" + ) + result.writeline(f"num_gb = {num_gb}") + result.writeline("gb_per_s = num_gb / (ms / 1e3)") + result.writeline( + 'print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")' + ) + + return result + + def imports_for_benchmark_kernel(self) -> str: + return textwrap.dedent( + """ + from torch._dynamo.testing import rand_strided + {} + import torch + """.format(V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")) + ) + + def uniquify_block_sizes( + self, code: IndentedBuffer, num_kernel: int, uniquify: list[str] + ) -> IndentedBuffer: + if not uniquify: + return code + modified = IndentedBuffer(initial_indent=code._indent) + for line in code._lines: + if isinstance(line, str) and (blocks := [e for e in uniquify if e in line]): + modified_line = line + for block in blocks: + modified_line = modified_line.replace( + block, f"{block}_{num_kernel}" + ) + modified.writeline(modified_line) + elif isinstance(line, DeferredLine) and ( + blocks := [e for e in uniquify if e in line.line] + ): + modified_line = line.line + for block in blocks: + modified_line = modified_line.replace( + block, f"{block}_{num_kernel}" + ) + new_line = DeferredLine(line.name, modified_line) + modified.writeline(new_line) + else: + modified.writeline(line) + return modified + + def call_kernel(self, code: IndentedBuffer, name: str) -> None: + _, call_args, _, arg_types = self.args.python_argdefs() + + wrapper = V.graph.wrapper_code + assert self.dispatch_class is not None + if self.dynamic_shape_args: + self.add_numel_to_call_args(name, call_args, arg_types) + + wrapper.generate_kernel_call( + name, + call_args, + triton=True, + 
arg_types=arg_types, + ) + + def combo_grid_meta(self) -> dict[str, Any]: + dynamic_shape = bool(self.dynamic_shape_args) + num_kernels = len(self.sub_kernels) + min_blocks = ( + max(self.min_x_blocks_list) * num_kernels if not dynamic_shape else None + ) + + if not self.enable_autotune: + if "YBLOCK" in self.block_args: + default_config = { + "XBLOCK": self.block_size_2d, + "YBLOCK": self.block_size_2d, + } + else: + default_config = {"XBLOCK": self.block_size_1d} + else: + default_config = None + + meta = { + "num_kernels": num_kernels, + "min_blocks": min_blocks, + "default_config": default_config, + } + + for num, sub_kernel in enumerate(self.sub_kernels): + meta[f"no_x_dim_{num}"] = sub_kernel.no_x_dim + for i, tree in enumerate(sub_kernel.range_trees): + if not tree.is_reduction: + numel_name = f"{tree.prefix}numel_{num}" + if numel_name in self.dynamic_shape_args: + meta[numel_name] = None + else: + meta[numel_name] = int(V.graph.sizevars.simplify(tree.numel)) + + return meta diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/triton_split_scan.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/triton_split_scan.py new file mode 100644 index 0000000000000000000000000000000000000000..23ee1e38d18b23f0834ce3c7cd89663fd3e531ab --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/triton_split_scan.py @@ -0,0 +1,207 @@ +# mypy: allow-untyped-defs +import functools +from typing import Union + +import sympy + +from torch._inductor import config +from torch._inductor.codegen.simd import IterationRangesRoot, prefix_is_reduction +from torch._inductor.codegen.triton import ( + triton_compute_type, + TritonCSEVariable, + TritonKernel, +) +from torch._inductor.runtime.triton_heuristics import SplitScanGrid +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.functions import CeilDiv + +from ..utils import sympy_product + + +class TritonSplitScanKernel(TritonKernel): + """Generates a triton kernel that supports ops.scan calls while also splitting + the reduction dimension over multiple triton programs. + + For this kernel, loop numels will always take the form ``(xdim, rdim)`` + and the grid has the shape ``(CeilDiv(rdim, RBLOCK), xdim)``. Communication + between blocks occurs within a global memory workspace buffer, which + must be zero-filled before launching the kernel. + + Note that generation for ``ops.reduction`` is not supported. 
+ + For details of the communication strategy, see + https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back + + """ + + def __init__( + self, + tiling: dict[str, sympy.Expr], + pid_cache=None, + fixed_config=None, + **kwargs, + ) -> None: + assert pid_cache is None, "not supported" + assert fixed_config is None, "not supported" + super().__init__( + tiling, + **kwargs, + ) + self.no_x_dim = True + + def should_use_persistent_reduction(self) -> bool: + return False + + def should_use_cooperative_reduction(self) -> bool: + return False + + def initialize_range_tree(self, pid_cache): + prefixes = ["y", "x", "r0_"] + assert len(self.numels) <= len(prefixes), ( + "z dimension not supported for split scan" + ) + active_prefixes = prefixes[len(prefixes) - len(self.numels) :] + + grid_dims = {"r0_": 0, "x": 1, "y": 2} + for prefix in active_prefixes: + numel = self.numels[prefix] + tensor_dim = 0 if prefix_is_reduction(prefix) else None + grid_dim = grid_dims[prefix] + self.range_trees.append( + IterationRangesRoot( + f"{prefix}index", + numel, + prefix, + grid_dim, + self, # type: ignore[arg-type] + pid_cache=pid_cache, + is_loop=False, + tensor_dim=tensor_dim, + grid_dim=grid_dim, + has_zdim=False, + ) + ) + + def reduction(self, dtype, src_dtype, reduction_type, value): + raise NotImplementedError("NYI TritonSplitDimKernel reductions") + + def scan(self, dtypes, combine_fn, values): + import triton.language as tl + + (dtype,) = dtypes + (value,) = values + + compute_type = triton_compute_type(dtype) + compute_type_triton = getattr(tl, compute_type[3:]) + + element_nbits = compute_type_triton.primitive_bitwidth + + scratch_type = "tl.uint32" if element_nbits <= 16 else "tl.uint64" + scratch_type_triton = getattr(tl, scratch_type[3:]) + scratch_elems_per_block = 3 if element_nbits == 64 else 1 + scratch_nbytes_per_block = scratch_elems_per_block * ( + scratch_type_triton.primitive_bitwidth // 8 + ) + + cse_load = functools.partial(self.cse.generate, self.loads, dtype=dtype) + cse_compute = functools.partial(self.cse.generate, self.compute) + + assert len(self.numels) == 2, "Unexpected tiling" + min_rblock = config.triton.min_split_scan_rblock + reduction_numel = sympy_product( + numel + for prefix, numel in self.numels.items() + if prefix_is_reduction(prefix) + ) + pointwise_numel = sympy_product( + numel + for prefix, numel in self.numels.items() + if not prefix_is_reduction(prefix) + ) + max_blocks = pointwise_numel * CeilDiv(reduction_numel, min_rblock) + nbytes = scratch_nbytes_per_block * max_blocks + scratch_base: Union[str, TritonCSEVariable] + scratch_base, offset = self.args.workspace(nbytes=nbytes, zero_fill=True) + if offset != 0: + scratch_base = cse_load(f"{scratch_base} + {self.index_to_str(offset)}") + runtime_rblocks = cse_load(f"tl.num_programs({self.range_trees[-1].index})") + scratch_base = cse_load( + f"{scratch_base}.to(tl.pointer_type({scratch_type})) + xoffset * " + f"{scratch_elems_per_block} * {runtime_rblocks}" + ) + + masks = OrderedSet(f"{tree.prefix}mask" for tree in self.range_trees) + self.filter_masks(masks) + assert not self._load_mask, "ops.scan not supported inside ops.masked" + + value = cse_compute( + f"{value}.to({compute_type})", + dtype=dtype, + ) + value = cse_compute( + f"tl.broadcast_to({value}, {self.dense_size_str()})", + dtype=dtype, + ) + + combine_helper_fn = self._lift_helper(combine_fn, 1, (dtype,)) + dim = self.triton_tensor_ndim() - 1 + assert dim == 0, "" + + block_sum = cse_compute( + 
f"tl.reduce({value}, {dim}, {combine_helper_fn})", + dtype=dtype, + ) + exclusive_prefix = self.cse.newvar( + dtype=dtype, + ) + if element_nbits == 64: + self.compute.splice( + f""" + {exclusive_prefix} = triton_helpers.exclusive_scan_decoupled_lookback_64( + {scratch_base}, + {block_sum}, + {self.iteration_ranges_get_pid(self.range_trees[-1])}, + {combine_helper_fn}, + ) + """, + strip=True, + ) + + else: + assert element_nbits <= 32 + value_as_uint_dtype = f"tl.uint{element_nbits}" + + self.compute.splice( + f""" + {exclusive_prefix} = triton_helpers.exclusive_scan_decoupled_lookback( + {scratch_base}, + {block_sum}, + {self.iteration_ranges_get_pid(self.range_trees[-1])}, + {combine_helper_fn}, + DTYPE_VALUE_AS_UINT={value_as_uint_dtype}, + DTYPE_PACK={scratch_type}, + ) + """, + strip=True, + ) + # Compute final cumsum + block_scan = cse_compute( + f"tl.associative_scan({value}, {dim}, {combine_helper_fn})", + dtype=dtype, + ) + combined_result = cse_compute( + f"{combine_helper_fn}({exclusive_prefix}, {block_scan})", + dtype=dtype, + ) + return ( + cse_compute( + f"tl.where(roffset == 0, {block_scan}, {combined_result})", + dtype=dtype, + ), + ) + + def _get_heuristic(self): + return "split_scan" + + def _get_grid_type(self) -> type[SplitScanGrid]: + return SplitScanGrid diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/triton_utils.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/triton_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fd9019fa6b62ffcf55951022260bd3d73cd22acb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/triton_utils.py @@ -0,0 +1,257 @@ +# mypy: allow-untyped-defs +from typing import Any, Optional + +import sympy + +import torch + +from .. import config +from ..runtime.hints import AttrsDescriptorWrapper +from ..utils import _type_of, expr_fits_within_32bit, triton_version_uses_attrs_dict +from ..virtualized import V +from .common import ( + ArgName, + ConstexprArg, + KernelArgType, + SizeArg, + TensorArg, + TMADescriptorArg, + WorkspaceArg, +) + + +def should_unwrap_unspec_arg(name: str): + if V.graph.is_unspec_arg(name): + # Unwrap on all devices except CPU + if V.graph.get_current_device_or_throw().type != "cpu": + return True + # Only unwrap on CPU if the input is not used as an output + if name not in V.graph.mutated_buffers: + return True + return False + + +def signature_of(arg: KernelArgType, *, size_dtype: Optional[str]) -> str: + if isinstance(arg, TensorArg): + # TODO: Remove fp8 special handling when Triton supports PyTorch fp8 dtypes. + # Related PR: https://github.com/triton-lang/triton/pull/2279/ + if arg.dtype == torch.float8_e4m3fn: + typ = "*fp8e4nv" + elif arg.dtype == torch.float8_e5m2: + typ = "*fp8e5" + elif arg.dtype == torch.float8_e4m3fnuz: + typ = "*fp8e4b8" + elif arg.dtype == torch.float8_e5m2fnuz: + typ = "*fp8e5b16" + else: + typ = _type_of(arg.dtype) + if should_unwrap_unspec_arg(arg.buffer): + # had unwrapped 0d tensor as scalar + new_typ = typ.lstrip("*") + if new_typ in ["fp16", "bf16"]: + return "fp32" + else: + return new_typ + else: + return typ + if isinstance(arg, SizeArg): + if arg.expr is None: + if triton_version_uses_attrs_dict(): + # In newer versions of Triton, the signature includes "None" args + # and their type is marked as "constexpr" + return "constexpr" + else: + # In older versions of Triton... + # From triton/runtime/jit.py + # `None` is nullptr. Implicitly convert to *i8. 
+ return "*i8" + elif _arg_equals_1(arg) and triton_version_uses_attrs_dict(): + # In new versions of Triton, if we have an equal-to-1 arg that's marked as a constant, + # it should be marked as "constexpr" in the signature. + return "constexpr" + elif isinstance(arg.expr, (float, sympy.Float)): + return "fp32" + + # if this is a integer + if size_dtype == "tl.int32": + return "i32" + elif size_dtype == "tl.int64": + return "i64" + elif size_dtype is None: + # no hint: we'll see if we know that this is a 32-bit int, and guard if possible. + int_max = torch.iinfo(torch.int32).max + if expr_fits_within_32bit(arg.expr): + V.graph.sizevars.guard_leq(arg.expr, int_max) + return "i32" + else: + return "i64" + else: + raise NotImplementedError(f"unhandled size_dtype {size_dtype}") + if isinstance(arg, WorkspaceArg): + return _type_of(arg.dtype) + if isinstance(arg, TMADescriptorArg): + if arg.api_type == "experimental": + return "nvTmaDesc" + else: + # https://github.com/triton-lang/triton/blob/9695baed9b46cf957e08b157bb4133f4a4b331c5/python/triton/runtime/jit.py#L360-L363 + assert arg.api_type == "stable" + assert arg.block_shape is not None + assert arg.dtype is not None + inner = _type_of(arg.dtype)[1:] # strip the `*`: *fp32 -> fp32 + return f"tensordesc<{inner}{list(arg.block_shape)}>" + if isinstance(arg, ConstexprArg): + return "constexpr" + raise NotImplementedError(f"unhandled {type(arg)}: {arg}") + + +def non_constexpr_signature(signature): + new_signature = [] + for arg in signature: + if not isinstance(arg, ConstexprArg): + new_signature.append(arg) + + return new_signature + + +def signature_to_meta( + signature: list[KernelArgType], + *, + size_dtype: Optional[str], + argdefs: list[ArgName], + indices: Optional[list[int]] = None, + is_template: bool = False, +) -> dict[str, str]: + if indices is None: + indices = list(range(len(signature))) + + def _decide_tl_dtype(arg): + # Even if the ks0 symbol itself is within tl.int32 range, it's + # risky to use tl.int32 dtype since we may have ks0*ks1 later + # for kernels like torch.mean when dynamic shape is enabled. + # + # Check config.triton.use_block_ptr, since Triton block pointer + # does not support 64bit indexing: + # https://gist.github.com/shunting314/6a41c776171720ce4561f202dcde0ad6 + # + # If the triton metadata is for a template, don't use tl.int64 index. + # Templates like flex attention/decoding uses block pointers which + # does not support 64 bit indexing. + if ( + not config.triton.use_block_ptr + and not is_template + and isinstance(arg, SizeArg) + and arg.name.startswith("ks") + ): + return "tl.int64" + return size_dtype + + return { + argdefs[i].name: signature_of(arg, size_dtype=_decide_tl_dtype(arg)) + for i, arg in zip(indices, signature) + } + + +def is_unaligned_buffer(arg: TensorArg): + buf_name = arg.buffer + if buf_name in V.graph.unaligned_buffers: + return True + + if buf_name in V.graph.graph_inputs: + # See Note: [Input Alignment handling in Inductor] + # For graph inputs that is not recorded in V.graph.unaligned_buffers, + # we know for sure the tensor is aligned. 
+ return False + + if buf_name in V.graph.constants: + # all constants are assumed to be aligned + return False + + if V.graph.scheduler: + layout = V.graph.scheduler.get_buffer_layout(buf_name) + else: + buffer = V.graph.try_get_buffer(buf_name) + # output arg + if not buffer: + assert buf_name == V.kernel.output_node.name + layout = V.kernel.output_node.layout + else: + layout = buffer.get_layout() + + if isinstance(layout, torch._inductor.ir.NonOwningLayout): + return not layout.maybe_guard_aligned() + else: + return False + + +def _arg_equals_1(arg: KernelArgType) -> bool: + return ( + isinstance(arg, SizeArg) + and isinstance(arg.expr, (int, sympy.Integer)) + and V.graph.sizevars.statically_known_equals(arg.expr, 1) # type: ignore[arg-type] + ) + + +def equal_1_arg_indices( + args: list[KernelArgType], + *, + indices: Optional[list[int]] = None, +) -> tuple[int, ...]: + if indices is None: + indices = list(range(len(args))) + + equal_to_1 = tuple(i for i, arg in zip(indices, args) if _arg_equals_1(arg)) + + return equal_to_1 + + +def config_of( + args: list[KernelArgType], + *, + indices: Optional[list[int]] = None, +) -> Any: + if indices is None: + indices = list(range(len(args))) + + def is_aligned(x: KernelArgType, alignment: int, include_tensor: bool) -> bool: + """ + Roughly follow triton code here: + https://github.com/triton-lang/triton/blob/5282ed890d453e10b9ee30076ef89115dd197761/python/triton/runtime/jit.py#L208-L222 + """ + if isinstance(x, TensorArg): + if include_tensor: + offset_aligned = V.graph.sizevars.statically_known_multiple_of( + x.offset * x.dtype.itemsize, + alignment, # type: ignore[arg-type] + ) + return offset_aligned and not is_unaligned_buffer(x) + else: + return False + if isinstance(x, SizeArg): + # TODO(voz): These are kinda redundant, if we can solve out statically_known_multiple_of with + # _maybe_evaluate_static... 
+ if x.name.startswith("load_seed_offset"): + return False + if x.expr is None: + return False + if isinstance(x.expr, float): + return False + return V.graph.sizevars.statically_known_multiple_of(x.expr, alignment) # type: ignore[arg-type] + if isinstance(x, WorkspaceArg): + # We allocate the workspace ourselves, so it is always aligned + return True + if isinstance(x, (TMADescriptorArg, ConstexprArg)): + return False + raise NotImplementedError(f"unhandled {type(x)}: {x}") + + if config.triton.divisible_by_16: + divisible_by_16 = tuple( + i + for i, arg in zip(indices, args) + if is_aligned(arg, alignment=16, include_tensor=True) + ) + else: + divisible_by_16 = () + + equal_to_1 = equal_1_arg_indices(args, indices=indices) + + return AttrsDescriptorWrapper(divisible_by_16, equal_to_1) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/wrapper.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..8ace6845f46fef558e88f6cdba2b93b1b86a0faa --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/wrapper.py @@ -0,0 +1,3413 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import collections +import contextlib +import dataclasses +import dis +import functools +import inspect +import logging +import operator +import random +import re +import tempfile +from itertools import chain, count +from typing import Any, Callable, Optional, TYPE_CHECKING, Union + +import sympy +from sympy import Expr + +import torch +import torch._ops +import torch.utils._pytree as pytree +from torch import dtype as torch_dtype +from torch._dynamo.utils import counters, dynamo_timed +from torch._inductor.codegen.debug_utils import DebugPrinterManager +from torch._inductor.codegen.multi_kernel import MultiKernelState +from torch._inductor.runtime.runtime_utils import cache_dir +from torch.fx.experimental.symbolic_shapes import ( + CallMethodKey, + ConvertIntKey, + DivideByKey, + resolve_unbacked_bindings, + SymTypes, +) +from torch.fx.node import _get_qualified_name +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.singleton_int import SingletonInt +from torch.utils._sympy.symbol import symbol_is_type, SymT + +from .. 
import async_compile, config, ir +from ..codecache import output_code_log +from ..ir import IRNode, ReinterpretView +from ..runtime import triton_heuristics +from ..runtime.hints import DeviceProperties +from ..utils import ( + cache_on_self, + DelayReplaceLine, + get_benchmark_name, + IndentedBuffer, + is_codegen_graph_partition_subgraph, + LineContext, + set_kernel_post_grad_provenance_tracing, + sympy_product, + sympy_str, + sympy_subs, + triton_version_uses_attrs_dict, +) +from ..virtualized import V +from .common import ( + ArgName, + CodeGen, + DeferredLine, + PythonPrinter, + WorkspaceArg, + WorkspaceZeroMode, +) +from .cpp_utils import cexpr +from .triton_utils import config_of, should_unwrap_unspec_arg, signature_to_meta + + +if TYPE_CHECKING: + from collections.abc import Iterator, Sequence + + import triton + + from ..graph import GraphLowering + from .wrapper_fxir import FxConverter + + +log = logging.getLogger(__name__) + +pexpr = PythonPrinter().doprint + + +ReuseKey = tuple[torch.device, torch.dtype, str, bool] +BufferLike = Union[ir.Buffer, WorkspaceArg] +FxConversionFunc = Callable[["WrapperLine"], None] + + +def buffer_reuse_key(node: BufferLike) -> ReuseKey: + storage_size = V.graph.get_allocation_storage_size(node) + alignment = node.get_name() not in V.graph.unaligned_buffers + return ( + node.get_device_or_error(), + node.get_dtype(), + # NB: this is symbolic so that we don't try to reuse a buffer + # for s0 for s1, just because they happen to share the same + # size hint + sympy_str(V.graph.sizevars.simplify(storage_size)), + alignment, + ) + + +def can_match_buffer_size(input_buf: BufferLike, output_buf: BufferLike): + # Return True if input_buf can be re-inplaced for output_buf. + # This differs from `buffer_reuse_key` for general buffer reuse. 
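+ # Sizes must either match exactly (compared symbolically) or the output size must be statically known to lie within [0.95 * input_size, input_size].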
+ if input_buf.get_device_or_error() != output_buf.get_device_or_error(): + return False + + if input_buf.get_dtype() != output_buf.get_dtype(): + return False + + input_size = V.graph.sizevars.simplify( + V.graph.get_allocation_storage_size(input_buf) + ) + output_size = V.graph.sizevars.simplify( + V.graph.get_allocation_storage_size(output_buf) + ) + + if ( + # NB: this is symbolic so that we don't try to reuse a buffer + # for s0 for s1, just because they happen to share the same + # size hint + sympy_str(input_size) == sympy_str(output_size) + ) or ( + # statically known that 0.95 * input_size <= output_size <= input_size + V.graph.sizevars.statically_known_geq(output_size, 0.95 * input_size) + and V.graph.sizevars.statically_known_leq(output_size, input_size) + ): + return True + + return False + + +# TODO: Move to a well known place +TritonMetaParams = dict[str, int] +TritonGrid = Union[ + tuple[Union[int, sympy.Expr], ...], Callable[[TritonMetaParams], tuple[int, ...]] +] + + +def user_defined_kernel_grid_fn_code( + name: str, + configs: list[triton.Config], # type: ignore[name-defined] + grids: list[TritonGrid], + wrapper: Optional[PythonWrapperCodegen] = None, + original_fxnode_name: Optional[str] = None, +) -> tuple[str, str]: + output = IndentedBuffer() + + def _convert_to_sympy_expr(item: Union[int, sympy.Expr]) -> sympy.Expr: + return item if isinstance(item, sympy.Expr) else sympy.Integer(item) + + def determine_grid( + grid: TritonGrid, + example_grid: Optional[TritonGrid] = None, + ): + """ + This function return a tuple of two values: the first one is for the real grid + which is used in the generated code; the second one is an example grid with + concreate values which is used in the autotune block to run the generated + kernels at compile time. 
+ """ + if wrapper is None or callable(grid): + # return as-is when used in eager mode or when grid is callable + return grid, grid + # Grid contains ints/Expr, so utilize wrapper's expr printer for codegen + sympy_grid = tuple(_convert_to_sympy_expr(g) for g in grid) + if not example_grid: + example_grid = sympy_grid + return ( + wrapper.codegen_python_shape_tuple(sympy_grid), + ( + wrapper.codegen_python_shape_tuple( + tuple( + wrapper.generate_example_arg_value(g, type(g)) + for g in example_grid # type: ignore[union-attr] + ) + ) + if config.triton.autotune_at_compile_time + else None + ), + ) + + def writeline(line: str, example_grid: Optional[str] = None): + output.writeline(line) + if ( + wrapper + and config.triton.autotune_at_compile_time + and name not in wrapper.kernel_autotune_names + ): + wrapper.kernel_autotune_calls.writeline(example_grid or line) + + fn_name = f"grid_wrapper_for_{name}" + writeline(f"def {fn_name}(meta):") + kernel_autotune_calls_indent = ( + wrapper.kernel_autotune_calls.indent() + if wrapper and config.triton.autotune_at_compile_time + else contextlib.nullcontext() + ) + with output.indent(), kernel_autotune_calls_indent: + if ( + config.triton.autotune_at_compile_time + and original_fxnode_name + and V.graph.autotuning_grids + and original_fxnode_name in V.graph.autotuning_grids + ): + example_grids = V.graph.autotuning_grids[original_fxnode_name] + else: + example_grids = [None] * len(grids) + if len(grids) == 1: + grid, example_grid = determine_grid(grids[0], example_grids[0]) + writeline(f"return {grid}", f"return {example_grid}") + else: + assert len(grids) > 1 + assert len(grids) == len(configs) + seen: OrderedSet[str] = OrderedSet() + # sort the configs from the largest # of kwargs to the smallest to + # emit the grids in the order of (approximately) decreasing specificity + # TODO(aakhundov): the sorting below is generally not sufficient, so + # maybe we'll need to restrict the supported cases to identical kwarg + # names in all autotuning configs. 
+ for grid, c, example_grid in sorted( + zip(grids, configs, example_grids), + key=lambda x: len(x[1].kwargs), + reverse=True, + ): + if c.kwargs: + guards = [ + f"meta['{name}'] == {val}" for name, val in c.kwargs.items() + ] + guards = " and ".join(guards) + else: + guards = "True" # for configs with empty kwargs + grid, example_grid = determine_grid(grid, example_grid) + statement = f"if {guards}: return {grid}" + if statement in seen: + continue + seen.add(statement) + writeline(statement, f"if {guards}: return {example_grid}") + + return fn_name, output.getvalue() + + +def user_defined_triton_kernel_transitive_closure_source_code(kernel) -> str: + """ + Given a triton kernel function pointer collect the transitive closure of + its dependencies + """ + compile_wrapper = IndentedBuffer() + compile_wrapper.splice(kernel.src, strip=True) + + # Also include any possible kernel being called indirectly + from triton import JITFunction # type: ignore[name-defined, attr-defined] + from triton.language import constexpr # type: ignore[name-defined] + + # global constexpr vars handled above + symbols_included = OrderedSet([kernel.__name__]) + + def traverse(cur_kernel): + # here we extract the unqualified names (i.e., not attributes and + # without prepended module name) loaded in the kernel code, which + # are matched with the co_names and __globals__ below to codegen + # the respective imports necessary for the kernel compilation + unqualified_loads = OrderedSet( + inst.argval + for inst in dis.Bytecode(cur_kernel.fn) + if inst.opname == "LOAD_GLOBAL" + ) + global_annotations = cur_kernel.fn.__globals__.get("__annotations__", {}) + for symbol_name in cur_kernel.fn.__code__.co_names: + if symbol_name in symbols_included: + continue + if symbol_name in cur_kernel.fn.__globals__: + symbol = cur_kernel.fn.__globals__[symbol_name] + if isinstance(symbol, JITFunction): + compile_wrapper.newline() + compile_wrapper.writeline("@triton.jit") + compile_wrapper.splice(symbol.src, strip=True) + symbols_included.add(symbol_name) + traverse(symbol) + elif isinstance(symbol, (int, str, bool, constexpr)): + compile_wrapper.newline() + if isinstance(symbol, constexpr): + symbol_str = f"tl.constexpr({symbol.value!r})" + else: + symbol_str = f"{symbol!r}" + if annotation := global_annotations.get(symbol_name): + if isinstance(annotation, type): + annotation_code = ( + f": {annotation.__module__}.{annotation.__name__}" + ) + else: + annotation_code = f": {annotation!r}" + compile_wrapper.writeline( + f"{symbol_name}{annotation_code} = {symbol_str}" + ) + else: + compile_wrapper.writeline(f"{symbol_name} = {symbol_str}") + symbols_included.add(symbol_name) + elif ( + symbol_name in unqualified_loads + and symbol_name != "tl" # already imported + and hasattr(symbol, "__module__") + # only codegen imports from triton; JITFunctions + # imported from other modules will be codegened + # in the separate branch above + and symbol.__module__.startswith("triton") + ): + # a global symbol imported from triton is referenced + # without module qualification (i.e., `store` instead + # of `tl.store`): need to codegen an import + compile_wrapper.writeline( + f"from {symbol.__module__} import {symbol.__name__} as {symbol_name}" + ) + symbols_included.add(symbol_name) + + traverse(kernel) + return compile_wrapper.getvalue() + + +@dataclasses.dataclass +class SymbolicCallArg: + inner: str + # the original symbolic expression represented by inner + inner_expr: sympy.Expr + + def __str__(self): + return str(self.inner) + + +class 
MemoryPlanningState: + def __init__(self): + super().__init__() + self.reuse_pool: dict[ReuseKey, list[FreeIfNotReusedLine]] = ( + collections.defaultdict(list) + ) + self.total_allocated_buffer_size: int = 0 + + def __contains__(self, key: ReuseKey) -> bool: + return bool(self.reuse_pool.get(key, None)) + + def pop(self, key: ReuseKey) -> FreeIfNotReusedLine: + item = self.reuse_pool[key].pop() + assert not item.is_reused + return item + + def push(self, key: ReuseKey, item: FreeIfNotReusedLine) -> None: + assert not item.is_reused + self.reuse_pool[key].append(item) + + +class WrapperLine: + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + raise NotImplementedError("FX codegen not yet supported for type {type(self)}") + + +@dataclasses.dataclass +class EnterSubgraphLine(WrapperLine): + wrapper: PythonWrapperCodegen + graph: GraphLowering + + def __post_init__(self) -> None: + self.wrapper.push_computed_sizes(self.wrapper.computed_sizes) + + def codegen(self, code: IndentedBuffer) -> None: + self.wrapper.push_codegened_graph(self.graph) + code.do_indent() + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_enter_subgraph + + +@dataclasses.dataclass +class CommentLine(WrapperLine): + line: LineContext + + def codegen(self, code: IndentedBuffer) -> None: + code.writeline(self.line) + + @staticmethod + def codegen_fx(converter: FxConverter) -> FxConversionFunc: + return converter._generate_comment + + +@dataclasses.dataclass +class ExitSubgraphLine(WrapperLine): + wrapper: PythonWrapperCodegen + + def __post_init__(self) -> None: + self.wrapper.computed_sizes = self.wrapper.pop_computed_sizes() + + def codegen(self, code: IndentedBuffer) -> None: + self.wrapper.pop_codegened_graph() + code.do_unindent() + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_exit_subgraph + + +@dataclasses.dataclass +class EnterDeviceContextManagerLine(WrapperLine): + device_idx: int + last_seen_device_guard_index: Optional[int] + + def codegen(self, code: IndentedBuffer) -> None: + if V.graph.cpp_wrapper: + code.writeline("\n") + if V.graph.aot_mode: + # In AOT mode, we have a stream provided as a param. A stream is + # associated with a device, so we never expect the device to change. + # CUDAStreamGuard sets the stream and the device. 
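# When cpp_wrapper is off, the else branch below instead emits a Python block in the
# generated module along these lines (device index 0 assumed for illustration);
# torch.cuda._DeviceGuard takes a plain integer index:
import torch

def example_device_guarded_section():
    if torch.cuda.is_available():
        with torch.cuda._DeviceGuard(0):
            torch.cuda.set_device(0)
            # ... kernel calls for device 0 are emitted at this indentation ...
            pass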
+ if self.last_seen_device_guard_index is None: + code.writeline( + f"{V.graph.device_ops.cpp_aoti_stream_guard()} stream_guard(stream, this->device_idx_);" + ) + else: + assert self.last_seen_device_guard_index == self.device_idx, ( + "AOTInductor only supports running on one CUDA device" + ) + else: + if self.last_seen_device_guard_index is None: + code.writeline( + f"{V.graph.device_ops.cpp_aoti_device_guard()} device_guard({self.device_idx});" + ) + else: + code.writeline(f"device_guard.set_index({self.device_idx});") + else: + # Note _DeviceGuard has less overhead than device, but only accepts + # integers + code.writeline(f"with {V.graph.device_ops.device_guard(self.device_idx)}:") + code.do_indent() + code.writeline(V.graph.device_ops.set_device(self.device_idx)) + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_enter_device_context_manager + + +class ExitDeviceContextManagerLine(WrapperLine): + def codegen(self, code: IndentedBuffer) -> None: + if not V.graph.cpp_wrapper: + code.do_unindent() + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_exit_device_context_manager + + +@dataclasses.dataclass +class ExternKernelAllocLine(WrapperLine): + wrapper: PythonWrapperCodegen + node: ir.ExternKernelAlloc + + def codegen(self, code: IndentedBuffer) -> None: + node = self.node + args = [*node.codegen_args(), *node.codegen_kwargs()] + self.wrapper._generate_extern_kernel_alloc_helper(self.node, args) + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_extern_kernel_alloc + + +@dataclasses.dataclass +class ExternKernelOutLine(WrapperLine): + wrapper: PythonWrapperCodegen + node: ir.ExternKernelOut + + def codegen(self, code: IndentedBuffer) -> None: + node = self.node + args = [*node.codegen_args(), *node.codegen_kwargs(skip_out=True)] + kernel_name = node.get_kernel_name() + if ( + V.graph.cpp_wrapper + and node.cpp_kernel_name == "torch::inductor::_mm_plus_mm" + ): + # For https://github.com/pytorch/pytorch/issues/128474 + kernel_name = "aoti_torch__mm_plus_mm_out" + else: + kernel_name = node.get_kernel_name() + device = d.type if (d := node.get_device()) else V.graph.device_type + # set provenance tracing kernel mapping for ExternKernel types + if config.trace.enabled: + set_kernel_post_grad_provenance_tracing(node, kernel_name, is_extern=True) + self.wrapper._generate_extern_kernel_out_helper( + kernel_name, + node.codegen_reference(), + node.output_view.codegen_reference() if node.output_view else None, + args, + device, + ) + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_extern_kernel_out + + +@dataclasses.dataclass +class FreeLine(WrapperLine): + wrapper: PythonWrapperCodegen + node: Union[BufferLike, ir.TorchBindObject] + + def codegen(self, code: IndentedBuffer) -> None: + assert self.node.get_name() not in V.graph.removed_buffers + code.writeline(self.wrapper.make_buffer_free(self.node)) + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_free + + +@dataclasses.dataclass +class KernelCallLine(WrapperLine): + wrapper: PythonWrapperCodegen + kernel_name: str + call_args: tuple[Any, ...] + raw_keys: tuple[Any, ...] + raw_args: tuple[Any, ...] 
+ arg_types: list[str] + triton: bool + triton_meta: dict[str, Any] + device: torch.device + graph_name: str + original_fxnode_name: str + + def codegen(self, code: IndentedBuffer) -> None: + self.wrapper._generate_kernel_call_helper( + self.kernel_name, + self.call_args, + triton=self.triton, + arg_types=self.arg_types, + raw_keys=self.raw_keys, + raw_args=self.raw_args, + triton_meta=self.triton_meta, + device=self.device, + graph_name=self.graph_name, + original_fxnode_name=self.original_fxnode_name, + ) + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_kernel_call + + +@dataclasses.dataclass +class KernelDefinitionLine(WrapperLine): + wrapper: PythonWrapperCodegen + kernel_name: str + kernel_body: str + metadata: Optional[str] = None + gpu: bool = True + cpp_definition: Optional[str] = None + + def codegen(self, code: IndentedBuffer) -> None: + self.wrapper._define_kernel_helper( + self.kernel_name, + self.kernel_body, + metadata=self.metadata, + gpu=self.gpu, + cpp_definition=self.cpp_definition, + ) + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_kernel_definition + + +@dataclasses.dataclass +class MemoryPlanningLine(WrapperLine): + wrapper: PythonWrapperCodegen + + def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine: + """First pass to find reuse""" + return self + + def codegen(self, code: IndentedBuffer) -> None: + """Second pass to output code""" + + def __str__(self) -> str: + """ + Emits a string representation that fits on one line. + """ + args: list[str] = [] + for field in dataclasses.fields(self): + if field.name == "wrapper": + continue + val = getattr(self, field.name) + args.append( + f"{field.name}={val.get_name() if field.type is ir.Buffer else val}" + ) + return f"{type(self).__name__}({', '.join(args)})" + + +@dataclasses.dataclass +class AllocateLine(MemoryPlanningLine): + node: BufferLike + + def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine: + if self.node.get_name() in V.graph.removed_buffers: + return NullLine(self.wrapper) + + # try to reuse a recently freed buffer + key = buffer_reuse_key(self.node) + if config.allow_buffer_reuse and key in state: + free_line = state.pop(key) + free_line.is_reused = True + return ReuseLine(self.wrapper, free_line.node, self.node) + + if self.node.get_device_or_error().type == "cpu": + static_shape = self.wrapper.static_shape_for_buffer_or_none(self.node) + if static_shape is not None: + state.total_allocated_buffer_size += int( + functools.reduce(operator.mul, static_shape, 1) + ) + + return self + + def codegen(self, code: IndentedBuffer) -> None: + assert self.node.get_name() not in V.graph.removed_buffers + line = self.wrapper.make_buffer_allocation(self.node) + code.writeline(line) + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_allocate + + +@dataclasses.dataclass +class FreeIfNotReusedLine(MemoryPlanningLine): + node: BufferLike + is_reused: bool = False + + def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine: + if len(self.node.get_inputs_that_alias_output()) > 0: + return self + if isinstance(self.node.layout, ir.MultiOutputLayout): + return self + assert not self.is_reused + if self.node.get_name() in V.graph.removed_buffers: + return NullLine(self.wrapper) + if config.allow_buffer_reuse: + state.push(buffer_reuse_key(self.node), self) + return self + + def codegen(self, code: IndentedBuffer) -> None: + assert 
self.node.get_name() not in V.graph.removed_buffers + if not self.is_reused: + code.writeline(self.wrapper.make_buffer_free(self.node)) + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_free_if_not_reused + + +@dataclasses.dataclass +class ReinterpretLine(MemoryPlanningLine): + node: BufferLike + reused_as: BufferLike + layout: ir.Layout + + def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine: + return self + + def codegen(self, code: IndentedBuffer) -> None: + assert isinstance(self.layout, ir.NonOwningLayout) + assert isinstance(self.layout.view, ir.ReinterpretView) + self.wrapper.codegen_deferred_allocation( + self.reused_as.get_name(), self.layout.view + ) + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_reinterpret + + +@dataclasses.dataclass +class ReuseLine(MemoryPlanningLine): + node: BufferLike + reused_as: BufferLike + delete_old: bool = True + + def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine: + if self.node.get_name() in V.graph.removed_buffers: + assert self.reused_as.get_name() in V.graph.removed_buffers + return NullLine(self.wrapper) + assert self.reused_as.get_name() not in V.graph.removed_buffers + return self + + def codegen(self, code: IndentedBuffer) -> None: + assert self.node.get_name() not in V.graph.removed_buffers + assert self.reused_as.get_name() not in V.graph.removed_buffers + code.writeline( + self.wrapper.make_buffer_reuse(self.node, self.reused_as, self.delete_old) + ) + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_reuse + + +class NullLine(MemoryPlanningLine): + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_null + + +@dataclasses.dataclass +class CommBufferLine(WrapperLine): + wrapper: PythonWrapperCodegen # type: ignore[name-defined] # noqa: F821 + node: ir.Buffer + + @property + def size(self) -> int: + from torch._inductor.utils import is_symbolic + + numel = self.node.get_numel() + dtype = self.node.get_dtype() + if is_symbolic(numel): + raise AssertionError( + f"The size of a comm buffer can't be symbolic: {self.node}" + ) + return int(numel) * dtype.itemsize + + @property + def comm_buffer_type(self) -> ir.CommBufferType: + layout = self.node.get_output_spec() + assert isinstance(layout, ir.CommBufferLayout) + return layout.comm_buffer_type + + @property + def group_name(self) -> str: + layout = self.node.get_output_spec() + assert isinstance(layout, ir.CommBufferLayout) + return layout.group_name + + +@dataclasses.dataclass +class CommBufferAllocateLine(CommBufferLine): + def codegen(self, code: IndentedBuffer) -> None: + assert self.node.get_name() not in V.graph.removed_buffers + name = self.node.get_name() + device = self.node.get_device() + dtype = self.node.get_dtype() + shape = tuple(self.node.get_size()) + stride = tuple(self.node.get_stride()) + code.writeline( + self.make_allocation_line( + self.comm_buffer_type, + self.group_name, + self.wrapper, + name, + device, + dtype, + shape, + stride, + ) + ) + + @staticmethod + def make_allocation_line( + comm_buffer_type, group_name, wrapper, name, device, dtype, shape, stride + ): + if comm_buffer_type == ir.CommBufferType.SYMM_MEM: + return ( + f"{name} = empty_strided_p2p(" + f"{wrapper.codegen_shape_tuple(shape)}, " + f"{wrapper.codegen_shape_tuple(stride)}, " + f"{dtype}, " + f'torch.device("cuda:{device.index}"), ' + f'group_name="{group_name}", ' + 
f"alloc_id={random.randint(0, 2**64 - 1)})" + ) + else: + raise NotImplementedError( + f"Unsupported comm buffer type: {comm_buffer_type}" + ) + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_comm_buffer_allocate + + +@dataclasses.dataclass +class CommBufferFreeLine(CommBufferLine): + def codegen(self, code: IndentedBuffer) -> None: + line = self.wrapper.make_buffer_free(self.node) + code.writeline(f"{line} # {self.comm_buffer_type.value} buffer free") + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_comm_buffer_free + + +@dataclasses.dataclass +class MultiOutputLine(WrapperLine): + """ + Given a MultiOutputLayout buffer, indexes actual buffer(s) from the result. + """ + + wrapper: PythonWrapperCodegen + result_name: str + arg_name: str + indices: Sequence[Any] + + def codegen(self, code: IndentedBuffer) -> None: + def codegen_list_tuple_access(basename, indices): # type: ignore[no-untyped-def] + if len(indices) > 0: + itype, i = indices[0] + if issubclass(itype, list): + return codegen_list_tuple_access(f"{basename}[{i}]", indices[1:]) + elif issubclass(itype, tuple): + # cpp wrapper code needs to use std::get<> to access a tuple + tuple_access = self.wrapper.codegen_tuple_access( + basename, self.result_name, str(i) + ) + return codegen_list_tuple_access(tuple_access, indices[1:]) + elif issubclass(itype, dict): + return codegen_list_tuple_access(f"{basename}['{i}']", indices[1:]) + else: + raise AssertionError("non supported index type: ", itype) + else: + return basename + + value = codegen_list_tuple_access(self.arg_name, self.indices) + code.writeline( + f"{self.wrapper.declare}{self.result_name} = {value}{self.wrapper.ending}" + ) + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_multi_output + + +@dataclasses.dataclass +class SymbolicCallArgLine(WrapperLine): + wrapper: PythonWrapperCodegen + arg: SymbolicCallArg + graph: GraphLowering + + def codegen(self, code: IndentedBuffer) -> None: + self.wrapper._generate_symbolic_call_arg_helper(self.arg, self.graph) + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_symbolic_call_arg + + +BufferName = str +Line = Union[MemoryPlanningLine, LineContext] + + +class PythonWrapperCodegen(CodeGen): + """ + Generate outer wrapper in Python that calls the kernels. + """ + + supports_caching = True # Whether the output code is cacheable. + + def __init__(self): + super().__init__() + self._names_iter: Iterator[int] = count() + self.args_to_buffers: dict[ + str, Union[None, ir.TensorBox, ir.Buffer, ir.TorchBindObject] + ] = {} + self.imports = IndentedBuffer() + self.header = IndentedBuffer() + self.prefix = IndentedBuffer() + self.suffix = IndentedBuffer() + self.kernel_declarations = IndentedBuffer() + self.wrapper_call = IndentedBuffer() + self.kernel_autotune_defs = IndentedBuffer() + self.kernel_autotune_calls = IndentedBuffer() + self.subgraph_definitions = IndentedBuffer() + self.kernel_autotune_names: OrderedSet[str] = OrderedSet() + # Map key is the kernel argument name; value is a tuple of the resulting example + # tensor name with the kernel where that tensor was most recently used. 
+ self.kernel_autotune_example_args: dict[str, tuple[str, str]] = {} + self.kernel_autotune_tmp_arg_idx: int = 0 + # If the generated source code is exactly the same, reuse the + # pre-existing kernel for it + self.src_to_kernel: dict[str, str] = {} + self.kernel_numel_expr: OrderedSet[tuple[str, GraphLowering]] = OrderedSet() + self.lines: list[Line] = [] + self.declare = "" + self.declare_maybe_reference = "" + self.ending = "" + self.comment = "#" + self.none_str = "None" + self.move_begin = "std::move(" if V.graph.cpp_wrapper else "" + self.move_end = ")" if V.graph.cpp_wrapper else "" + self.last_seen_device_guard_index: Optional[int] = None + self.supports_intermediate_hooks = True + self.user_defined_kernel_cache: dict[tuple[Any, ...], tuple[str, Any]] = {} + self.unbacked_symbol_decls: OrderedSet[str] = ( + OrderedSet() + ) # str of sympy.Symbol + self.computed_sizes: OrderedSet[sympy.Symbol] = OrderedSet() + self.launcher_fn_name = None + # This function can be overridden to change the launcher name + self.set_launcher_fn_name() + + # this is used for tracking which GraphLowering instance---parent graph + # or (nested) subgraph---is currently codegened; the primary use case is + # including the graph instance into a cache key to avoid cross-graph + # caching during lowering of nested subgraphs + self.codegened_graph_stack = [] + self.computed_sizes_stack = [] + + self.write_header() + + if not is_codegen_graph_partition_subgraph(self): + # See [Note: Removed Graph Partition Arguments] + self.write_prefix() + + self.write_kernel_autotune_defs_header() + + if not V.graph.aot_mode: + for name, hashed in V.graph.constant_reprs.items(): + # include a hash so our code cache puts different constants into different files + self.write_constant(name, hashed) + + self.allocated = OrderedSet[BufferName]() + self.freed = OrderedSet[BufferName]() + + # maps from reusing buffer to reused buffer + self.reuses: dict[BufferName, BufferName] = {} + + self.write_get_raw_stream = functools.lru_cache(None)( # type: ignore[assignment] + self.write_get_raw_stream + ) + + @functools.cache + def add_import_once(line: str) -> None: + self.imports.writeline(line) + if config.triton.autotune_at_compile_time: + self.kernel_autotune_calls.writeline(line) + + self.add_import_once = add_import_once + self._metas: dict[str, str] = {} + self._meta_vars: OrderedSet[str] = OrderedSet() + self.multi_kernel_state = MultiKernelState() + self.already_codegened_subgraphs: OrderedSet[str] = OrderedSet() + self.allocated_workspaces: dict[str, Any] = {} + + # intermediate tensor value printing utility + self.debug_printer = DebugPrinterManager( + debug_printer_level=config.aot_inductor.debug_intermediate_value_printer, + use_array_ref=config.aot_inductor.allow_stack_allocation, + ) + + # Additional files that are dependent to the wrapper (ex. 
cubin files) + self.additional_files = [] + + @staticmethod + def create( + is_subgraph: bool, + subgraph_name: Optional[str], + parent_wrapper: Optional[PythonWrapperCodegen], + partition_signatures: Optional[ir.GraphPartitionSignature] = None, + ): + if is_subgraph: + assert subgraph_name is not None + assert parent_wrapper is not None + return SubgraphPythonWrapperCodegen( + subgraph_name, parent_wrapper, partition_signatures + ) + return PythonWrapperCodegen() + + def set_launcher_fn_name(self) -> None: + self.launcher_fn_name = "call" + + def write_constant(self, name: str, hashed: str) -> None: + self.header.writeline(f"{name} = None # {hashed}") + + def write_header(self) -> None: + context = torch._guards.TracingContext.try_get() + aot_config_comment = "" + if context is not None and context.aot_graph_name is not None: + aot_config_comment = f"# AOT ID: {context.aot_graph_name}" + aot_inductor_debug_utils = "" + if int(config.aot_inductor.debug_intermediate_value_printer) > 0: + aot_inductor_debug_utils = "from torch._inductor.codegen.debug_utils import _print_debugging_tensor_value_info" + self.imports.splice( + f""" + {aot_config_comment} + from ctypes import c_void_p, c_long, c_int + import torch + import math + import random + import os + import tempfile + from math import inf, nan + from cmath import nanj + from torch._inductor.hooks import run_intermediate_hooks + from torch._inductor.utils import maybe_profile + from torch._inductor.codegen.memory_planning import _align as align + from torch import device, empty_strided + from {async_compile.__name__} import AsyncCompile + from torch._inductor.select_algorithm import extern_kernels + {aot_inductor_debug_utils} + """, + strip=True, + ) + self.header.splice( + """ + aten = torch.ops.aten + inductor_ops = torch.ops.inductor + _quantized = torch.ops._quantized + assert_size_stride = torch._C._dynamo.guards.assert_size_stride + assert_alignment = torch._C._dynamo.guards.assert_alignment + empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu + empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda + empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu + reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor + alloc_from_pool = torch.ops.inductor._alloc_from_pool + async_compile = AsyncCompile() + """, + strip=True, + ) + try: + # Only add empty_strided_p2p() if distributed and SymmetricMemory + # is available + from torch._C._distributed_c10d import _SymmetricMemory # noqa: F401 + + self.header.splice( + """ + empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + """, + strip=True, + ) + except (AttributeError, ImportError): + pass + if config.annotate_training: + self.header.writeline("from torch.cuda import nvtx") + + def include_extra_header(self, header: str): + pass + + def write_kernel_autotune_defs_header(self) -> None: + self.kernel_autotune_defs.splice( + f""" + import torch + from torch._dynamo.testing import rand_strided + from torch._dynamo.utils import preserve_rng_state + from torch._inductor.select_algorithm import AlgorithmSelectorCache + from {async_compile.__name__} import AsyncCompile + + async_compile = AsyncCompile() + generate_example_value = AlgorithmSelectorCache.generate_example_value + empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda + empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu + """ + ) + + @cache_on_self + def write_triton_header_once(self) -> None: + import_str = f""" + import triton + import 
triton.language as tl + from {triton_heuristics.__name__} import start_graph, end_graph + """ + if config.triton.autotune_at_compile_time: + self.kernel_autotune_calls.splice(import_str) + self.kernel_autotune_calls.writeline( + V.graph.device_ops.import_get_raw_stream_as("get_raw_stream") + ) + if not V.graph.cpp_wrapper: + self.imports.splice(import_str, strip=True) + self.imports.writeline( + V.graph.device_ops.import_get_raw_stream_as("get_raw_stream") + ) + + def write_get_raw_stream_header(self) -> None: + if config.triton.autotune_at_compile_time: + self.kernel_autotune_calls.writeline( + V.graph.device_ops.import_get_raw_stream_as("get_raw_stream") + ) + if not V.graph.cpp_wrapper: + self.imports.writeline( + V.graph.device_ops.import_get_raw_stream_as("get_raw_stream") + ) + + @cache_on_self + def write_get_raw_stream_header_once(self) -> None: + self.write_get_raw_stream_header() + + def add_meta_once(self, meta: TritonMetaParams) -> str: + meta = repr(meta) + if meta not in self._metas: + var = f"meta{len(self._metas)}" + self._metas[meta] = var + self.header.writeline(f"{var} = {meta}") + if config.triton.autotune_at_compile_time: + self.kernel_autotune_calls.writeline(f"{var} = {meta}") + self._meta_vars.add(var) + return self._metas[meta] + + @cache_on_self + def get_output_refs(self) -> list[str]: + return [ + x.codegen_reference(self.wrapper_call) for x in self.get_graph_outputs() + ] + + def mark_output_type(self) -> None: + return + + def get_graph_inputs( + self, + ) -> dict[str, Union[ir.TensorBox, ir.TorchBindObject, sympy.Expr]]: + return V.graph.graph_inputs + + def get_graph_outputs(self) -> list[IRNode]: + return V.graph.graph_outputs + + def codegen_input_size_asserts(self) -> None: + for name, buf in self.get_graph_inputs().items(): + if isinstance(buf, (sympy.Expr, ir.TorchBindObject)): + continue + + # a graph partition may take an IRNode output from a previous partition + if name not in V.graph.graph_input_names or isinstance( + buf, ir.GeneratorState + ): + continue + + # comparing strides for 0 size tensor is tricky. Ignore them for now. 
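# The per-input asserts written into the prefix look like the call below; a standalone
# check against a real tensor (shape and stride values are arbitrary):
import torch

assert_size_stride = torch._C._dynamo.guards.assert_size_stride
example_input = torch.empty_strided((8, 16), (16, 1))
assert_size_stride(example_input, (8, 16), (16, 1))  # raises if size or stride disagree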
+ if sympy_product(buf.get_size()) == 0: + continue + size = self.codegen_python_shape_tuple(buf.get_size()) + stride = self.codegen_python_shape_tuple(buf.get_stride()) + self.prefix.writeline(f"assert_size_stride({name}, {size}, {stride})") + + def codegen_input_nan_asserts(self) -> None: + self.prefix.writeline("# make sure graph inputs are not nan/inf") + for name, buf in self.get_graph_inputs().items(): + if isinstance(buf, (sympy.Expr, ir.TorchBindObject)): + continue + + line = f"assert not {name}.isnan().any().item()" + self.prefix.writeline(line) + line = f"assert not {name}.isinf().any().item()" + self.prefix.writeline(line) + + def write_async_compile_wait(self) -> None: + self.prefix.splice( + """ + + async_compile.wait(globals()) + del async_compile + """ + ) + + def write_args(self, input_names: list[str]): + lhs = ", ".join(input_names) + if len(input_names) == 1: + lhs += "," + self.prefix.writeline(f"{lhs} = args") + self.prefix.writeline("args.clear()") + + def write_launcher_fn_call_get_indent(self) -> int: + if config.graph_partition: + self.prefix.splice( + """ + class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + """ + ) + prefix_indent = 2 + else: + self.prefix.splice( + f""" + def {self.launcher_fn_name}(args): + """ + ) + prefix_indent = 1 + + return prefix_indent + + def get_graph_input_names(self) -> list[str]: + return V.graph.graph_input_names + + def write_prefix(self) -> None: + assert self.launcher_fn_name is not None + self.write_async_compile_wait() + prefix_indent = self.write_launcher_fn_call_get_indent() + + with self.prefix.indent(prefix_indent): + if config.triton.debug_sync_graph: + self.prefix.writeline(V.graph.device_ops.synchronize()) + phase = V.graph.get_training_phase() + if config.annotate_training: + self.prefix.writeline( + f"training_annotation = nvtx._device_range_start('{phase}')" + ) + + if graph_input_names := self.get_graph_input_names(): + self.write_args(graph_input_names) + + self.codegen_inputs() + self.codegen_input_size_and_nan_asserts() + + def codegen_input_size_and_nan_asserts(self) -> None: + if config.size_asserts: + self.codegen_input_size_asserts() + if config.nan_asserts: + self.codegen_input_nan_asserts() + + # this function (and below) takes the graph name as input so + # that stream caching happens per graph instance. this + # is important for nested subgraph codegening. 
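# On CUDA devices, the import written by import_get_raw_stream_as and the per-device
# stream line produced by write_get_raw_stream (defined next) typically look like this
# in the generated code (device index 0 assumed):
import torch

if torch.cuda.is_available():
    from torch._C import _cuda_getCurrentRawStream as get_raw_stream

    stream0 = get_raw_stream(0)  # raw stream handle later passed to kernel launches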
+ def write_get_raw_stream(self, device_idx: int, graph_name: str) -> str: + self.write_get_raw_stream_header_once() + name = f"stream{device_idx}" + if config.triton.autotune_at_compile_time: + self.kernel_autotune_calls.writeline( + f"{name} = get_raw_stream({device_idx})" + ) + if V.graph.cpp_wrapper: + # For cpp wrapper, no need to continue codegen for the main body + return name + self.writeline(f"{name} = get_raw_stream({device_idx})") + return name + + def get_codegened_graph(self): + return self.codegened_graph_stack[-1] + + def push_codegened_graph(self, graph): + self.codegened_graph_stack.append(graph) + + def pop_codegened_graph(self): + return self.codegened_graph_stack.pop() + + def push_computed_sizes(self, computed_sizes): + from copy import deepcopy + + return self.computed_sizes_stack.append(deepcopy(computed_sizes)) + + def pop_computed_sizes(self): + return self.computed_sizes_stack.pop() + + def next_kernel_suffix(self) -> str: + return f"{next(self._names_iter)}" + + def codegen_device_guard_enter(self, device_idx: int) -> None: + self.writeline( + EnterDeviceContextManagerLine(device_idx, self.last_seen_device_guard_index) + ) + if config.triton.autotune_at_compile_time: + # mimic logic of EnterDeviceContextManagerLine.codegen for the autotune code block + self.write_triton_header_once() + self.kernel_autotune_calls.writeline( + f"with {V.graph.device_ops.device_guard(device_idx)}:" + ) + self.kernel_autotune_calls.do_indent() + self.kernel_autotune_calls.writeline( + V.graph.device_ops.set_device(device_idx) + ) + if is_codegen_graph_partition_subgraph(self): + # Need get_raw_stream for subgraph + self.write_get_raw_stream_header() + self.kernel_autotune_calls.writeline( + f"stream{device_idx} = get_raw_stream({device_idx})" + ) + self.last_seen_device_guard_index = device_idx + + def codegen_device_guard_exit(self) -> None: + self.writeline(ExitDeviceContextManagerLine()) + if config.triton.autotune_at_compile_time: + self.kernel_autotune_calls.do_unindent() + + def generate_return(self, output_refs: list[str]) -> None: + if output_refs: + if config.nan_asserts: + self.wrapper_call.writeline( + "return_vars = (" + ", ".join(output_refs) + ", )" + ) + self.wrapper_call.writeline("for var in return_vars:") + self.wrapper_call.do_indent() + self.wrapper_call.writeline("if isinstance(var, torch.Tensor):") + self.wrapper_call.do_indent() + self.wrapper_call.writeline("assert not var.isnan().any().item()") + self.wrapper_call.writeline("assert not var.isinf().any().item()") + self.wrapper_call.do_unindent(2) + + self.wrapper_call.writeline("return (" + ", ".join(output_refs) + ", )") + else: + self.wrapper_call.writeline("return ()") + + def generate_before_suffix(self, result: IndentedBuffer) -> None: + return + + def generate_after_suffix(self, result: IndentedBuffer) -> None: + if config.graph_partition: + all_partition_name_list = ", ".join(self.all_partition_names) + ( + "," if len(self.all_partition_names) == 1 else "" + ) + + result.splice( + f""" + runner = Runner(partitions=[{all_partition_name_list}]) + call = runner.call + recursively_apply_fns = runner.recursively_apply_fns + """ + ) + + def generate_end(self, result: IndentedBuffer) -> None: + return + + def generate_fallback_kernel(self, node: ir.FallbackKernel) -> None: + self.writeline(ExternKernelAllocLine(self, node)) + + def generate_extern_kernel_alloc(self, node: ir.ExternKernelAlloc): + node.codegen_comment(self) + self.writeline(ExternKernelAllocLine(self, node)) + if isinstance(node.layout, 
ir.Layout): + node.codegen_size_asserts(self) + + def _generate_extern_kernel_alloc_helper(self, extern_kernel, args): + # If it's a NoneLayout then the extern_kernel should essentially be + # treated as if it doesn't return anything + no_return = isinstance(extern_kernel.layout, ir.NoneLayout) + output_name = extern_kernel.get_name() + origin_node = extern_kernel.get_origin_node() + kernel_name = extern_kernel.get_kernel_name() + ending = self.ending + if config.memory_planning and "view_as_complex" in kernel_name: + # view operation fallbacks cause issues since inductor + # doesn't know the memory is still needed and might reuse it. + ending = f".clone(){ending}" + + if no_return: + self.writeline(f"{self.declare}{kernel_name}({', '.join(args)}){ending}") + else: + self.writeline( + f"{self.declare}{output_name} = {kernel_name}({', '.join(args)}){ending}" + ) + if ( + self.supports_intermediate_hooks + and config.generate_intermediate_hooks + and origin_node is not None + ): + counters["inductor"]["intermediate_hooks"] += 1 + self.writeline( + f"run_intermediate_hooks({origin_node.name!r}, {output_name})" + ) + + def generate_extern_kernel_out( + self, + node: ir.ExternKernelOut, + ) -> None: + node.codegen_comment(self) + self.writeline(ExternKernelOutLine(self, node)) + + def _generate_extern_kernel_out_helper( + self, + kernel: str, + out: str, + out_view: Optional[str], + args: list[str], + device: str, + ) -> None: + # add debug printer code for triton kernel calls at (jit) inductor level + debug_printer_manager = V.graph.wrapper_code.debug_printer + debug_printer_manager.set_printer_args(args, kernel, None, None, "extern") + args.append(f"out={out_view if out_view else out}") + with debug_printer_manager: + self.writeline(f"{kernel}({', '.join(args)})") + + def _generate_tma_descriptor_call_experimental(self, desc, apply_size_hints=False): + dims = desc.dims + block_dims = desc.block_dims + if apply_size_hints: + dims = tuple(V.graph.sizevars.atomically_apply_size_hint(d) for d in dims) + block_dims = tuple( + V.graph.sizevars.atomically_apply_size_hint(d) for d in block_dims + ) + + ptr = f"{desc.tensor.codegen_reference()}.data_ptr()" + # Explicitly call the Python version of val_to_arg_str + dims = ", ".join(PythonWrapperCodegen.val_to_arg_str(self, dim) for dim in dims) + block_dims = ", ".join( + PythonWrapperCodegen.val_to_arg_str(self, dim) for dim in block_dims + ) + element_size = PythonWrapperCodegen.val_to_arg_str(self, desc.element_size) + prefix = "triton.tools.experimental_descriptor" + fn = f"{prefix}.create_{desc.rank}d_tma_descriptor" + args = f"{ptr}, {dims}, {block_dims}, {element_size}" + call = f"{fn}({args})" + return call + + def _generate_tma_descriptor_call_stable(self, desc, apply_size_hints=False): + block_shape = desc.block_shape + if apply_size_hints: + block_shape = tuple( + V.graph.sizevars.atomically_apply_size_hint(d) for d in block_shape + ) + + prefix = "triton.tools.tensor_descriptor.TensorDescriptor" + fn = f"{prefix}.from_tensor" + args = f"{desc.tensor.codegen_reference()}, {block_shape}" + call = f"{fn}({args})" + return call + + def _generate_tma_descriptor_call(self, desc, apply_size_hints=False): + if isinstance(desc, ir.TMADescriptorExperimental): + return self._generate_tma_descriptor_call_experimental( + desc, apply_size_hints + ) + else: + assert isinstance(desc, ir.TMADescriptorStable) + return self._generate_tma_descriptor_call_stable(desc, apply_size_hints) + + def generate_tma_descriptor(self, desc): + call = 
self._generate_tma_descriptor_call(desc) + line = f"{desc.name} = {call}{self.ending}" + self.writeline(line) + + def generate_scatter_fallback( + self, + output, + inputs, + cpp_kernel_name, + python_kernel_name, + src_is_tensor, + reduce, + kwargs, + ): + line = f"{python_kernel_name}({','.join(map(str, inputs))}" + if python_kernel_name.startswith("aten.scatter_reduce"): + line += ", ".join([""] + kwargs) + else: + if reduce: + line += f", reduce={repr(reduce)}" + line += ")" + self.writeline(line) + + def generate_index_put_fallback(self, kernel, x, indices, values, accumulate): + indices_str = f"[{', '.join(indices)}]" + args = [x, indices_str, values, accumulate] + self.writeline(self.wrap_kernel_call(kernel, args)) + + def generate_fallback_kernel_with_runtime_lookup( + self, + buf_name: str, + python_kernel_name: str, + get_args: Callable[[], Sequence[str]], + op_overload: Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator], + raw_args: Sequence[Any], + outputs: Sequence[ir.Buffer], + ) -> None: + self.writeline(f"{buf_name} = {python_kernel_name}({', '.join(get_args())})") + + def generate(self, is_inference): + with dynamo_timed("PythonWrapperCodegen.generate"): + return self._generate(is_inference) + + def get_wrapper_call_indent(self) -> int: + if config.graph_partition: + return 2 + else: + return 1 + + @contextlib.contextmanager + def set_writeline(self, new: Callable[..., None]) -> Iterator[Callable[..., None]]: + old = self.writeline + try: + self.writeline = new # type: ignore[method-assign] + yield new + finally: + self.writeline = old # type: ignore[method-assign] + + def _write_multi_kernel_defs(self) -> None: + kernel_defs = self.multi_kernel_state.kernel_defs + if config.triton.autotune_at_compile_time: + self.kernel_autotune_defs.splice(kernel_defs) + else: + self.header.splice(kernel_defs) + + def _generate(self, is_inference): + if config.profile_bandwidth: + self.write_triton_header_once() + + with contextlib.ExitStack() as stack: + stack.enter_context(self.wrapper_call.indent()) + if config.profiler_mark_wrapper_call: + self.generate_profiler_mark_wrapper_call(stack) + if config.profile_bandwidth: + self.generate_start_graph() + + self.run_wrapper_ir_passes(is_inference) + + if config.triton.store_cubin and not config.triton.autotune_at_compile_time: + self.generate_reset_kernel_saved_flags() + + # At this point, we shouldn't generate any new memory planning lines. + # Override writeline to point at the wrapper call, in case it gets called. + with self.set_writeline(self.wrapper_call.writeline): + for line in self.lines: + if isinstance(line, WrapperLine): + line.codegen(self.wrapper_call) + else: + self.wrapper_call.writeline(line) + + self._write_multi_kernel_defs() + + output_refs = self.get_output_refs() + self.mark_output_type() + if config.triton.debug_sync_graph: + self.wrapper_call.writeline(V.graph.device_ops.synchronize()) + + if config.profile_bandwidth: + self.generate_end_graph() + + if config.triton.store_cubin and not config.triton.autotune_at_compile_time: + self.generate_save_uncompiled_kernels() + + if config.triton.autotune_at_compile_time: + self.generate_and_run_autotune_block() + + # cpp_wrapper currently doesn't support nvtx + if config.annotate_training and not config.cpp_wrapper: + self.wrapper_call.writeline( + "nvtx._device_range_end(training_annotation)" + ) + self.generate_return(output_refs) + + # Assemble the final code from sections. 
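# For orientation, the module assembled from these sections typically has the following
# skeleton (kernel, buffer, and argument names here are hypothetical):
EXAMPLE_GENERATED_MODULE = '''
# imports section
import torch
from torch._inductor.async_compile import AsyncCompile

# header: globals and the async compiler
async_compile = AsyncCompile()

# kernel definitions registered for parallel compilation
triton_poi_fused_add_0 = async_compile.triton("triton_poi_fused_add_0", """...""")

# prefix: wait for compilation, then the launcher
async_compile.wait(globals())
del async_compile


def call(args):
    arg0_1, = args
    args.clear()
    # wrapper_call body: allocations, kernel launches, frees
    return (buf0, )
'''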
+ result = IndentedBuffer() + result.splice(self.imports) + result.writeline("") + result.splice(self.header) + # We do not want the cpp header for intermediate const graph. Headers would be + # rendered by the main module instead. + if V.graph.aot_mode and V.graph.cpp_wrapper and V.graph.is_const_graph: + result = IndentedBuffer() + + # Add subgraph definitions to the result + result.splice(self.subgraph_definitions) + self.finalize_prefix() + result.splice(self.prefix) + + wrapper_call_indent = self.get_wrapper_call_indent() + + with result.indent(wrapper_call_indent): + result.splice(self.wrapper_call) + + self.generate_before_suffix(result) + result.splice(self.suffix) + self.generate_after_suffix(result) + + self.generate_end(result) + + self.add_benchmark_harness(result) + + return ( + result.getvaluewithlinemap(), + self.kernel_declarations.getvaluewithlinemap(), + ) + + def generate_and_run_autotune_block(self): + """ + Compose self.kernel_autotune_defs and self.kernel_autotune_calls into a single block of + code and execute it to trigger Triton kernel compilation and auto-tuning + """ + self.kernel_autotune_defs.splice( + """ + async_compile.wait(globals()) + del async_compile + """ + ) + scope = {} # type: ignore[var-annotated] + if config.triton.autotune_at_compile_time and V.graph.autotuning_inputs: + scope = { + self.get_autotuning_input_name(idx): v # type: ignore[attr-defined] + for idx, v in enumerate(V.graph.autotuning_inputs) + } + tuning_code = ( + self.kernel_autotune_defs.getvalue() + + "\n" + + self.kernel_autotune_calls.getvalue() + ) + if output_code_log.level == logging.DEBUG: + # Save the autotuning code block into a file + # Create a temporary file + with tempfile.NamedTemporaryFile( + dir=cache_dir(), suffix=".py", delete=False + ) as f: + f.write(tuning_code.encode("utf-8")) + file_path = f.name + output_code_log.debug( + "Auto-tuning code written to %s", + file_path, + ) + # Execute the code to autotune kernels + try: + exec(tuning_code, scope) + except Exception as e: + raise RuntimeError(f"Failed to run autotuning code block: {e}") from e + + def memory_plan(self): + from .memory_planning import MemoryPlanner + + self.lines = MemoryPlanner(self).plan(self.lines) + + def memory_plan_reuse(self): + out_names = V.graph.get_output_names() + + while ( + self.lines + and isinstance(self.lines[-1], MemoryPlanningLine) + # TODO: this seems legit, NullLine has no node + and self.lines[-1].node.name not in out_names # type: ignore[attr-defined] + ): + # these lines will be pointless + self.lines.pop() + + # codegen allocations in two passes + planning_states = [MemoryPlanningState()] + past_planning_states = [] + for i in range(len(self.lines)): + line = self.lines[i] + if isinstance(line, MemoryPlanningLine): + self.lines[i] = line.plan(planning_states[-1]) + elif isinstance(line, EnterSubgraphLine): + planning_states.append(MemoryPlanningState()) + elif isinstance(line, ExitSubgraphLine): + past_planning_states.append(planning_states.pop()) + past_planning_states.append(planning_states.pop()) + assert len(planning_states) == 0 + + # conservatively use the sum of all allocated buffer sizes + # in potentially nested scopes as the total allocated size + # FIXME(rec): not used + _total_allocated_buffer_size = sum( + s.total_allocated_buffer_size for s in past_planning_states + ) + + def run_wrapper_ir_passes(self, is_inference: bool): + # We disable planning during training because it presently increases peak memory consumption. 
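# memory_plan_reuse above is two-phase: plan() may rewrite a line in place (an
# AllocateLine whose reuse key matches a just-freed buffer becomes a ReuseLine, and the
# matching FreeIfNotReusedLine is marked reused so it emits no free), and codegen() runs
# only afterwards on the rewritten list. A stripped-down standalone model of that
# rewrite-then-emit flow (strings stand in for the real WrapperLine objects):
planned, free_pool = [], []
for ln in ["alloc buf0", "free buf0", "alloc buf1"]:
    if ln.startswith("free"):
        free_pool.append(ln.split()[1])
        planned.append(ln)
    elif ln.startswith("alloc") and free_pool:
        planned.append(f"reuse {free_pool.pop()} -> {ln.split()[1]}")
    else:
        planned.append(ln)
assert planned == ["alloc buf0", "free buf0", "reuse buf0 -> buf1"]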
+ if is_inference and config.memory_planning: + self.memory_plan() + else: + self.memory_plan_reuse() + + def codegen_input_symbol_assignment( + self, + name: str, + value: ir.TensorBox, + bound_vars: OrderedSet[sympy.Symbol], + ): + code = self.prefix + + @functools.cache + def sizeof(name): + code.writeline(f"{name}_size = {name}.size()") + return f"{name}_size" + + @functools.cache + def strideof(name): + code.writeline(f"{name}_stride = {name}.stride()") + return f"{name}_stride" + + if isinstance(value, sympy.Expr): + if not isinstance(value, sympy.Symbol) or value in bound_vars: + return + code.writeline(f"{value} = {name}") + bound_vars.add(value) + elif isinstance(value, ir.TensorBox): + for dim, size in enumerate(value.get_size()): + if isinstance(size, sympy.Symbol) and size not in bound_vars: + code.writeline(f"{size} = {sizeof(name)}[{dim}]") + bound_vars.add(size) + for dim, stride in enumerate(value.get_stride()): + if isinstance(stride, sympy.Symbol) and stride not in bound_vars: + code.writeline(f"{stride} = {strideof(name)}[{dim}]") + bound_vars.add(stride) + elif isinstance(value, ir.TorchBindObject): + return + elif isinstance(value, ir.GeneratorState): + return + else: + if torch._inductor.config.graph_partition: + pass + else: + raise AssertionError(f"Unknown value type: {type(value)}") + + def codegen_inputs(self): + """Assign all symbolic shapes to locals""" + bound_vars = OrderedSet[sympy.Symbol]() + # There is a subtle case in the cpp wrapper codegen which requires generating + # symbol inputs first followed by non-symbol ones. + # + # When a dynamic size constraint specified at the Export time is an expression, + # we need to solve that expression to proper define a symbol in cpp. Thus we + # are enforcing this iterating order here to make sure all plain size symbols + # are defined first. + graph_inputs = self.get_graph_inputs() + inputs = [ + (k, v) for k, v in graph_inputs.items() if isinstance(v, sympy.Symbol) + ] + [(k, v) for k, v in graph_inputs.items() if not isinstance(v, sympy.Symbol)] + for name, value in inputs: + self.codegen_input_symbol_assignment(name, value, bound_vars) + + def _verify_input_symbol_assignment( + value: ir.TensorBox, + bound_vars: OrderedSet[sympy.Symbol], + ): + for expr in chain.from_iterable([value.get_size(), value.get_stride()]): + if not isinstance(expr, Expr) or isinstance(expr, sympy.Symbol): + continue + + undefined_symbols = [ + sym for sym in expr.free_symbols if sym not in bound_vars + ] + if len(undefined_symbols) > 0: + raise AssertionError( + f"For {expr}, expected {undefined_symbols} to have been codegen-ed." + ) + + # For inputs with size/strides which contain sympy expressions, we can + # encounter symbols that weren't defined yet. Now, let's check each + # symbol is defined. 
+ for _, value in inputs: + if not isinstance(value, ir.TensorBox): + continue + _verify_input_symbol_assignment(value, bound_vars) + + def ensure_size_computed(self, sym: sympy.Symbol): + if isinstance(sym, sympy.Symbol) and symbol_is_type(sym, SymT.PRECOMPUTED_SIZE): + if sym in self.computed_sizes: + return + self.computed_sizes.add(sym) + expr = V.graph.sizevars.inv_precomputed_replacements[sym] + self.writeline(f"{sym} = {pexpr(expr)}") + + def finalize_prefix(self): + pass + + def codegen_cpp_sizevar(self, x: Expr, *, simplify: bool = True) -> str: + raise RuntimeError("codegen_cpp_sizevar is only implemented for cpp_wrapper!") + + def codegen_python_sizevar(self, x: Expr, *, simplify: bool = True) -> str: + return pexpr(x, simplify=simplify) + + def codegen_sizevar(self, x: Expr) -> str: + return self.codegen_python_sizevar(x) + + def codegen_tuple_access(self, basename: str, name: str, index: str) -> str: + return f"{basename}[{index}]" + + def codegen_python_shape_tuple(self, shape: Sequence[Expr]) -> str: + parts = [*map(self.codegen_python_sizevar, shape)] + if len(parts) == 0: + return "()" + if len(parts) == 1: + return f"({parts[0]}, )" + return f"({', '.join(parts)})" + + def codegen_shape_tuple(self, shape: Sequence[Expr]) -> str: + return self.codegen_python_shape_tuple(shape) + + def codegen_alloc_from_pool(self, name, offset, dtype, shape, stride) -> str: + return "alloc_from_pool({})".format( + ", ".join( + [ + name, + pexpr(offset), # bytes not numel + str(dtype), + self.codegen_python_shape_tuple(shape), + self.codegen_python_shape_tuple(stride), + ] + ) + ) + + def codegen_reinterpret_view( + self, + data, + size, + stride, + offset, + writeline: Callable[..., None], + dtype=None, + ) -> str: + if ( + size == data.layout.size + and stride == data.layout.stride + and offset == data.layout.offset + ): + if dtype is not None and dtype != data.dtype: + return f"aten.view.dtype({data.get_name()}, {dtype})" + else: + return f"{data.get_name()}" + else: + size = self.codegen_python_shape_tuple(size) + stride = self.codegen_python_shape_tuple(stride) + offset = self.codegen_sizevar(offset) + if dtype is not None and dtype != data.dtype: + return f"aten.view.dtype(reinterpret_tensor({data.get_name()}, {size}, {stride}, {offset}), {dtype})" + else: + return ( + f"reinterpret_tensor({data.get_name()}, {size}, {stride}, {offset})" + ) + + def codegen_device_copy(self, src, dst, non_blocking: bool): + self.writeline(f"{dst}.copy_({src}, {non_blocking})") + + def codegen_multi_output(self, node: ir.MultiOutput): + result_name = node.get_name() + arg_name = node.inputs[0].get_name() + self.writeline(MultiOutputLine(self, result_name, arg_name, node.indices)) + + def codegen_dynamic_scalar(self, node): + (data,) = (t.codegen_reference() for t in node.inputs) + if len(node.keypath) == 0: + self.writeline(f"{node.sym} = {data}.item()") + elif len(node.keypath) == 1 and isinstance(node.keypath[0], ConvertIntKey): + self.writeline(f"{node.sym} = 1 if {data}.item() else 0") + elif len(node.keypath) == 1 and isinstance(node.keypath[0], DivideByKey): + self.writeline(f"{node.sym}_undivided = {data}.item()") + self.writeline( + f"assert {node.sym}_undivided % {node.keypath[0].divisor} == 0, " + f"f'{{{node.sym}_undivided}} not divisible by {node.keypath[0].divisor}'" + ) + self.writeline( + f"{node.sym} = {node.sym}_undivided // {node.keypath[0].divisor}" + ) + else: + raise AssertionError(f"unrecognized keypath {node.keypath}") + # No one should ever use this buffer, but for uniformity + # 
define the variable and assign it None + self.writeline(f"{node.get_name()} = None") + + def benchmark_compiled_module(self, output): + def add_fake_input(name, shape, stride, device, dtype): + output.writeline( + f"{name} = rand_strided(" + f"{self.codegen_python_shape_tuple(shape)}, " + f"{self.codegen_python_shape_tuple(stride)}, " + f"device='{device}', dtype={dtype})" + ) + + def add_expr_input(name, val): + output.writeline(f"{name} = {val}") + + def add_torchbind_input(name, value): + import pickle + + assert isinstance(value, torch.ScriptObject) + + output.writeline(f"{name} = pickle.loads({pickle.dumps(value)!r})") + + output.writelines( + ["", "", "def benchmark_compiled_module(times=10, repeat=10):"] + ) + with output.indent(): + output.splice( + """ + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + """, + strip=True, + ) + + for name, value in V.graph.constants.items(): + # all the constants are global variables, that's why we need + # these 'global var_name' lines + output.writeline(f"global {name}") + add_fake_input( + name, value.size(), value.stride(), value.device, value.dtype + ) + + if len(V.graph.torchbind_constants) > 0: + output.writeline("import pickle") + for name, torchbind_obj in V.graph.torchbind_constants.items(): + # all the constants are global variables, that's why we need + # these 'global var_name' lines + output.writeline(f"global {name}") + add_torchbind_input(name, torchbind_obj) + + for name, value in V.graph.graph_inputs.items(): + if isinstance(value, sympy.Symbol) and isinstance( + V.graph.sizevars.var_to_val.get(value, None), SingletonInt + ): + # Inductor should only work with dense -> dense graph, and + # SingletonInts belong to metadata that should only live on + # the subclass. + continue + if isinstance(value, ir.TorchBindObject): + if len(V.graph.torchbind_constants) == 0: + # otherwise we have already imported the pickle package + output.writeline("import pickle") + output.writeline(f"global {name}") + add_torchbind_input(name, value.get_real_obj()) + elif isinstance(value, sympy.Expr): # Don't need to add symbolic + # TODO: this fallback and those below actually will generate possibly + # invalid benchmark code, because it's not guaranteed 42 + # is actually a valid value for the kernel in question. 
+ # See https://github.com/pytorch/pytorch/issues/124686 + add_expr_input(name, V.graph.sizevars.size_hint(value, fallback=42)) + elif isinstance(value, ir.GeneratorState): + add_expr_input( + name, + f"torch.cuda.default_generators[{value.device.index}].graphsafe_get_state()", + ) + else: + shape = [ + V.graph.sizevars.size_hint(x, fallback=42) + for x in value.get_size() + ] + stride = [ + V.graph.sizevars.size_hint(x, fallback=42) + for x in value.get_stride() + ] + add_fake_input( + name, + shape, + stride, + value.get_device(), + value.get_dtype(), + ) + + call_str = f"call([{', '.join(V.graph.graph_inputs.keys())}])" + output.writeline(f"fn = lambda: {call_str}") + output.writeline("return print_performance(fn, times=times, repeat=repeat)") + + def add_benchmark_harness(self, output): + """ + Append a benchmark harness to generated code for debugging + """ + if not config.benchmark_harness: + return + + self.benchmark_compiled_module(output) + + output.writelines(["", "", 'if __name__ == "__main__":']) + with output.indent(): + output.writelines( + [ + "from torch._inductor.wrapper_benchmark import compiled_module_main", + f"compiled_module_main('{get_benchmark_name()}', benchmark_compiled_module)", + ] + ) + + def define_kernel( + self, + kernel_name: str, + kernel_body: str, + metadata: Optional[str] = None, + gpu: bool = True, + cpp_definition: Optional[str] = None, + ): + self.writeline( + KernelDefinitionLine( + self, + kernel_name, + kernel_body, + metadata=metadata, + gpu=gpu, + cpp_definition=cpp_definition, + ) + ) + + @staticmethod + def _format_kernel_definition( + kernel_name: str, kernel_body: str, metadata: Optional[str] = None + ): + metadata_comment = f"{metadata}\n" if metadata else "" + body = f"\n\n{metadata_comment}{kernel_name} = {kernel_body}" + return body + + def _define_kernel_helper( + self, + kernel_name: str, + kernel_body: str, + metadata: Optional[str] = None, + gpu: bool = True, + cpp_definition: Optional[str] = None, + ): + if config.triton.autotune_at_compile_time: + # Skip inserting comments for the autotune block as they may contain cpp style comments + body = self._format_kernel_definition( + kernel_name, kernel_body, metadata=None + ) + self.kernel_autotune_defs.splice(body) + if V.graph.cpp_wrapper: + # For cpp wrapper, no need to continue codegen for the main body + return + + body = self._format_kernel_definition( + kernel_name, kernel_body, metadata=metadata + ) + self.header.splice(body) + + def define_subgraph_launcher_fn(self, fn_code: str): + self.subgraph_definitions.splice(fn_code) + + def define_user_defined_triton_kernel( + self, + kernel, + configs, + kwargs, + restore_value_args, + reset_to_zero_args, + grids: list[list[Union[int, sympy.Expr]]], + ): + from ..runtime.triton_heuristics import ( + config_to_dict, + FixedGrid, + PrecomputedGrid, + ) + from .common import ( + ConstexprArg, + KernelArgType, + SizeArg, + TensorArg, + TMADescriptorArg, + ) + from .triton import gen_common_triton_imports, TritonKernel + + original_name = kernel.__name__ + signature: list[KernelArgType] = [] + constants: dict[str, Any] = {} + arg_indices: list[int] = [] + equal_to_1_args: list[str] = [] + + def add_to_signature(idx, arg): + signature.append(arg) + arg_indices.append(idx) + + def add_arg(idx, arg, is_constexpr=False, equals_1=False, equals_none=False): + if is_constexpr: + if triton_version_uses_attrs_dict(): + # tl.constexpr args appear in the signature in new versions of triton, + # but not in old versions of triton. 
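# For orientation, a hypothetical user-defined kernel of the kind handled here: BLOCK is
# a tl.constexpr argument (it appears in kernel.constexprs and goes through the
# is_constexpr branch above), n is a plain size arg, and x_ptr becomes a TensorArg.
import triton
import triton.language as tl

@triton.jit
def example_user_kernel(x_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    val = tl.load(x_ptr + offs, mask=mask)
    tl.store(x_ptr + offs, val + 1.0, mask=mask)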
+ add_to_signature(idx, arg) + + if arg.name in kwargs: + # the arg may not appear in kwargs if it is an autotuned arg. + # in this case, it will be added in triton_heuristics after autotuning. + constants[arg.name] = kwargs[arg.name] + + else: + # the only case where arg name isn't in kwargs, should be + # when the arg is a constexpr. + assert arg.name in kwargs + + if equals_1: + if triton_version_uses_attrs_dict(): + # new versions of triton: add the equal-to-1 arg in the signature (labeled as "constexpr"), + # and add the arg as a constant. + # new versions of triton: add the equal-to-1 arg in the signature (labeled as, e.g., "i32"), + # and add the arg as a constant. + add_to_signature(idx, ConstexprArg(name=arg.name)) + else: + add_to_signature(idx, arg) + constants[arg.name] = 1 + elif equals_none: + if triton_version_uses_attrs_dict(): + # new versions of triton: add the none arg in the signature (as a constexpr arg) and as a constant + # old versions of triton: include the none arg as a constant (but not in the signature) + add_to_signature(idx, ConstexprArg(name=arg.name)) + constants[arg.name] = None + else: + add_to_signature(idx, arg) + + for idx, key in enumerate(kernel.arg_names): + if idx in kernel.constexprs: + add_arg(idx, ConstexprArg(name=key), is_constexpr=True) + continue + + if key not in kwargs: + continue + + arg = kwargs[key] + + if kwargs[key] is None: + add_arg(idx, ConstexprArg(name=key), equals_none=True) + else: + if isinstance(arg, ir.TMADescriptor): + api_type, block_shape, dtype = ( + ("stable", arg.block_shape, arg.tensor.get_dtype()) + if isinstance(arg, ir.TMADescriptorStable) + else ("experimental", None, None) + ) + add_arg( + idx, + TMADescriptorArg( + name=key, + api_type=api_type, + block_shape=block_shape, + dtype=dtype, + ), + ) + elif isinstance(arg, ir.Buffer): + add_arg( + idx, + TensorArg( + name=key, + buffer=arg.get_name(), + dtype=arg.get_dtype(), + ), + ) + elif isinstance(arg, ir.ReinterpretView): + # for ReinterpretView we use the underlying + # buffer name and note the (possibly non-zero) + # offset relative to the underlying buffer + add_arg( + idx, + TensorArg( + name=key, + buffer=arg.data.get_name(), + dtype=arg.get_dtype(), + offset=arg.layout.offset, + ), + ) + else: + equals_1 = isinstance( + arg, (int, sympy.Integer) + ) and V.graph.sizevars.statically_known_equals( + arg, + 1, # type: ignore[arg-type] + ) + add_arg(idx, SizeArg(key, arg), equals_1=equals_1) + + triton_signature = signature_to_meta( + signature, + size_dtype=None, # try to infer based on symints + indices=arg_indices, + argdefs=[ArgName(x) for x in kernel.arg_names], + ) + triton_meta: dict[str, Any] = { + "signature": triton_signature, + "device": DeviceProperties.create(V.graph.get_current_device_or_throw()), + # Triton compiler includes equal_to_1 args into constants even + # when they are not constexpr. otherwise there may be a segfault + # during launching the Inductor-compiled Triton kernel. + # TODO(aakhundov): add None args to constants, too. currently, this + # causes CUDA errors in test_aot_inductor.test_triton_kernel_with_none_input. 
+ # https://github.com/pytorch/pytorch/issues/120478#issuecomment-1962822307 + # https://github.com/triton-lang/triton/blob/231efe9ed2d200be0f69a07c298e4342b08efe3d/python/triton/runtime/jit.py#L384 + "constants": { + **constants, + **dict.fromkeys(equal_to_1_args, 1), + }, + "configs": [ + config_of( + signature, + indices=arg_indices, + ) + ], + } + + if restore_value_args: + triton_meta["restore_value"] = tuple(restore_value_args) + + if reset_to_zero_args: + triton_meta["reset_to_zero"] = tuple(reset_to_zero_args) + + if len(grids) == 1: + # compute the grid in the wrapper and pass it in as an arg + inductor_meta: dict[str, Any] = FixedGrid.setup_grid_as_args() + extra_launcher_call_args = [*map(sympy.sympify, grids[0])] + else: + + def rename_sizes_for_launcher(expr: Union[int, sympy.Expr]) -> sympy.Expr: + if isinstance(expr, sympy.Expr): + symbols = [*expr.free_symbols] + if not symbols: + return expr + symbols.sort(key=str) + for sym in symbols: + if sym in extra_launcher_args: + continue + extra_launcher_args[sym] = sympy.Symbol( + f"_launcher_s{len(extra_launcher_args)}" + ) + return sympy_subs(expr, extra_launcher_args) + assert isinstance(expr, int) + return sympy.Integer(expr) + + extra_launcher_args: dict[sympy.Symbol, sympy.Symbol] = {} + grids = [[*map(rename_sizes_for_launcher, grid)] for grid in grids] + + assert grids and len(grids) == len(configs) + precomputed_grids = [] + for grid, cfg in sorted( + zip(grids, configs), key=lambda x: len(x[1].kwargs), reverse=True + ): + precomputed_grids.append( + { + "config": config_to_dict(cfg), + "python": [*map(pexpr, grid)], + "cpp": [*map(cexpr, grid)], + } + ) + inductor_meta = { + "grid_type": PrecomputedGrid.__name__, + "precomputed_grids": precomputed_grids, + "extra_launcher_args": [*map(str, extra_launcher_args.values())], + } + extra_launcher_call_args = [*extra_launcher_args.keys()] + + # Distinguish between different functions using function id + cache_key: Any = [id(kernel.fn)] + if len(configs) > 0: + for arg in kwargs.values(): + # We need to key on non tensor arg only in autotune mode + if not isinstance(arg, (ir.Buffer, ir.ReinterpretView)): + cache_key.append(arg) + cache_key.append(str(triton_meta)) + cache_key.extend(str(inductor_meta)) + cache_key = tuple(cache_key) + if cache_key in self.user_defined_kernel_cache: + return ( + *self.user_defined_kernel_cache[cache_key], + extra_launcher_call_args, + ) + + name = f"{original_name}_{len(self.user_defined_kernel_cache)}" + + compile_wrapper = IndentedBuffer() + if config.triton.unique_user_kernel_names: + compile_wrapper.writeline(f"async_compile.triton({name!r}, '''") + else: + compile_wrapper.writeline(f"async_compile.triton({original_name!r}, '''") + + inductor_meta["kernel_name"] = name + inductor_meta.update(TritonKernel.inductor_meta_common()) + + compile_wrapper.splice(gen_common_triton_imports()) + compile_wrapper.splice( + f""" + @triton_heuristics.user_autotune( + configs={[*map(config_to_dict, configs)]!r}, + inductor_meta={inductor_meta!r}, + triton_meta={triton_meta!r}, + filename=__file__, + custom_kernel=True, + ) + @triton.jit + """ + ) + kernel_src = user_defined_triton_kernel_transitive_closure_source_code(kernel) + if config.triton.unique_user_kernel_names: + # We replace the original_name with the unique name. 
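            # A hypothetical example of the rename performed below: the second cached
            # definition of a user kernel
            #     @triton.jit
            #     def my_add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr): ...
            # is emitted as
            #     def my_add_kernel_1(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr): ...
            # so repeated definitions of the same user-defined kernel get unique names.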
+ kernel_src = kernel_src.replace(f"def {original_name}(", f"def {name}(") + kernel_src = kernel_src.replace("'''", "\\'\\'\\'") + compile_wrapper.splice(kernel_src) + + current_device = V.graph.get_current_device_or_throw() + compile_wrapper.writeline(f"''', device_str='{current_device.type}')") + _, lineno = inspect.getsourcelines(kernel.fn) + srcfile = inspect.getsourcefile(kernel.fn) + metadata = f"# Original path: {srcfile}:{lineno}" + self.define_kernel( + name, + compile_wrapper.getvalue(), + metadata, + ) + # Add to the cache for the next use + self.user_defined_kernel_cache[cache_key] = (name, triton_meta) + return name, triton_meta, extra_launcher_call_args + + def generate_numel_expr(self, kernel_name: str, tree, suffix: Optional[str] = None): + expr = f"{kernel_name}_{tree.prefix}numel" + if suffix is not None: + expr += f"_{suffix}" + + # We can get symbolic expressions here, like s0*64 + # It is fine to have them here, but we need to handle them correctly as their own type + # This is tricky to do, so we wrap in a custom type, distinct from scalars, but also from sympy* + # scalars as well. + # This is handled in `generate_args_decl` which has a correct comment of: TODO: only works for + # constant now, need type info. I agree, this needs type info, and while this is not true type info + # it suffices as a type hint for the purposes of producing the correct code for this type. + arg = SymbolicCallArg(expr, tree.numel) + self.writeline(SymbolicCallArgLine(self, arg, V.graph)) + + return arg + + def _generate_symbolic_call_arg_helper( + self, arg: SymbolicCallArg, graph: GraphLowering + ) -> None: + self.writeline(f"{arg.inner} = {pexpr(arg.inner_expr)}") + + def generate_workspace_allocation(self, ws: WorkspaceArg): + name = ws.get_name() + line = AllocateLine(self, ws) + if ws.zero_mode == WorkspaceZeroMode.UNINITIALIZED: + self.writeline(line) + elif ws.zero_mode == WorkspaceZeroMode.ZERO_ON_CALL: + self.writeline(line) + self.writeline(self.make_zero_buffer(name)) + elif ws.zero_mode == WorkspaceZeroMode.ZERO_PER_GRAPH: + prior = self.allocated_workspaces.get(name) + if prior: + assert isinstance(prior, AllocateLine) and isinstance( + prior.node, WorkspaceArg + ) + # expand existing allocation + prior.node = WorkspaceArg.maximum(prior.node, ws) + else: + self.writeline(line) + self.writeline(self.make_zero_buffer(name)) + self.allocated_workspaces[name] = line + else: + raise AssertionError(ws.zero_mode) + + if config.triton.autotune_at_compile_time: + self.kernel_autotune_calls.writeline( + PythonWrapperCodegen.make_allocation( + self, + name, + ws.device, + ws.dtype, + shape=(V.graph.sizevars.size_hint(ws.count),), + stride=(1,), + ) + ) + if ws.zero_mode != WorkspaceZeroMode.UNINITIALIZED: + self.kernel_autotune_calls.writeline( + PythonWrapperCodegen.make_zero_buffer(self, name) + ) + + def generate_workspace_deallocation(self, ws: WorkspaceArg): + if ws.zero_mode != WorkspaceZeroMode.ZERO_PER_GRAPH: + self.writeline(FreeIfNotReusedLine(self, ws)) + + def make_zero_buffer(self, name): + return f"{name}.zero_(){self.ending}" + + def wrap_kernel_call(self, name, call_args): + return f"{name}({', '.join(call_args)}){self.ending}" + + def generate_profiler_mark_wrapper_call(self, stack): + self.wrapper_call.writeline("from torch.profiler import record_function") + self.wrapper_call.writeline( + f"with record_function('graph_{V.graph.graph_id}_inductor_wrapper_call'):" + ) + stack.enter_context(self.wrapper_call.indent()) + + def generate_start_graph(self): + 
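        # Bracketing helper: together with generate_end_graph() below, this wraps the
        # generated wrapper body so profiling can attribute the kernel calls to this
        # graph. The emitted code is roughly (illustrative):
        #     start_graph()
        #     ... kernel launches ...
        #     end_graph(<config.profile_bandwidth_output>)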
self.wrapper_call.writeline("start_graph()") + + def generate_end_graph(self): + self.wrapper_call.writeline(f"end_graph({config.profile_bandwidth_output!r})") + + def generate_reset_kernel_saved_flags(self): + self.wrapper_call.splice( + f""" + for kernel in globals().values(): + if isinstance(kernel, {triton_heuristics.__name__}.CachingAutotuner): + kernel.cuda_kernel_saved = False + """ + ) + + def generate_save_uncompiled_kernels(self): + """ + Precompile and save the CUBINs of the Triton kernels that haven't + been precompiled and saved as a side effect of running the generated + JIT model (Python wrapper). This can happen when the model contains + control flow: only one pass through the control flow operators covers + the kernels that are saved, the remaining kernels are not launched, + hence not saved. The main purpose of this codegen is to compile and + save the Triton kernels outside the active control flow path for + subsequent AOTInductor code generation and compilation. + """ + self.wrapper_call.splice( + f""" + for kernel in globals().values(): + if isinstance(kernel, {triton_heuristics.__name__}.CachingAutotuner): + if not kernel.cuda_kernel_saved: + if len(kernel.launchers) == 0: + kernel.precompile() + kernel.save_gpu_kernel( + grid=(0, 0, 0), # use dummy grid + stream="stream", # use dummy stream + launcher=kernel.launchers[0], + ) + """ + ) + + def prepare_triton_kernel_call(self, call_args): + def wrap_arg(arg): + if isinstance(arg, str): + # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar + return arg + ".item()" if should_unwrap_unspec_arg(arg) else arg + elif isinstance(arg, (int, float, bool, SymbolicCallArg)): + return str(arg) + else: + return pexpr(V.graph.sizevars.simplify(arg)) + + return [wrap_arg(arg) for arg in call_args] + + def generate_example_arg_value(self, arg, arg_type, raw_arg=None): + if isinstance(arg_type, torch_dtype): + if isinstance(raw_arg, ir.TMADescriptor): + # first we generate the underlying buffer + buf_name = raw_arg.get_tensor().get_name() + buf = self.args_to_buffers[arg] + elif self.args_to_buffers.get(arg): + buf_name = arg + buf = self.args_to_buffers[arg] + else: + assert raw_arg is not None, ( + "V.graph.get_buffer(arg) and raw_arg can't be None at the same time" + ) + buf_name = f"tmp_arg_{self.kernel_autotune_tmp_arg_idx}" + buf = raw_arg + self.kernel_autotune_tmp_arg_idx += 1 + + assert buf is not None, f"Failed to find a buffer for arg {arg}" + size = tuple( + V.graph.sizevars.atomically_apply_size_hint( + e, + fallback=config.unbacked_symint_fallback, + ) + for e in buf.get_size() + ) + allocation_size = tuple( + V.graph.sizevars.atomically_apply_size_hint( + e, + fallback=config.unbacked_symint_fallback, + ) + for e in V.graph.get_allocation_size(buf) + ) + stride = tuple( + V.graph.sizevars.atomically_apply_size_hint( + e, + fallback=config.unbacked_symint_fallback, + ) + for e in buf.get_stride() + ) + device = buf.get_device() + dtype = buf.get_dtype() + offset = V.graph.sizevars.size_hint( + buf.get_layout().offset, + fallback=config.unbacked_symint_fallback, + ) + value = f"generate_example_value({size}, {stride}, '{device}', {dtype}, {offset}, {allocation_size})" + self.kernel_autotune_calls.writeline(f"{buf_name} = {value}") + + if isinstance(raw_arg, ir.TMADescriptor): + # generate another line initializing a host-side TMA + # descriptor from the underlying buffer created above + value = self._generate_tma_descriptor_call( + desc=raw_arg, + apply_size_hints=True, + ) + buf_name = arg + 
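                # bind the freshly created descriptor to the kernel argument's name so
                # the autotune call emitted later can reference it directly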
self.kernel_autotune_calls.writeline(f"{buf_name} = {value}") + + return buf_name + elif issubclass(arg_type, sympy.Basic) or isinstance(arg, SymbolicCallArg): + # arg is a symbol or symbolic expression + if isinstance(arg, str): + if arg in self._meta_vars: + return arg + if raw_arg is None: + return "None" + arg = raw_arg + if isinstance(arg, SymbolicCallArg): + arg = arg.inner_expr + if arg in V.graph.sizevars.inv_precomputed_replacements: + arg = V.graph.sizevars.inv_precomputed_replacements[arg] + + return str( + V.graph.sizevars.atomically_apply_size_hint( + arg, fallback=config.unbacked_symint_fallback + ) + ) + + elif isinstance(arg, (str, int, float, bool)): + return str(arg) + elif isinstance(arg, list): + return f"[{', '.join(self.generate_example_arg_value(a, type(a)) for a in arg)}]" + else: + raise NotImplementedError(f"Unsupported type {type(arg)}") + + def _grid_dim_str(self, grid_per_dim): + if isinstance(grid_per_dim, list): + return ( + "[" + ", ".join(self._grid_dim_str(item) for item in grid_per_dim) + "]" + ) + else: + return pexpr(grid_per_dim) + + def generate_kernel_call( + self, + kernel_name: str, + call_args, + *, + device=None, + triton=True, + arg_types=None, + raw_keys=None, + raw_args=None, + triton_meta=None, + original_fxnode_name=None, + ): + """ + Generates kernel call code. + + triton: Defines whether the backend uses Triton for codegen. Otherwise it uses the CUDA language when gpu=True, + and C++ when gpu=False. + """ + + # Store buffers corresponding to each call arg. + # This is used to generate example args for autotuning later on. + self.args_to_buffers.update( + { + arg: V.graph.try_get_buffer(arg) + for arg in call_args + if isinstance(arg, str) + } + ) + + device = device or V.graph.get_current_device_or_throw() + self.writeline( + KernelCallLine( + self, + kernel_name=kernel_name, + call_args=call_args, + raw_keys=raw_keys, + raw_args=raw_args, + arg_types=arg_types, + triton=triton, + triton_meta=triton_meta, + device=device, + graph_name=V.graph.name, + original_fxnode_name=original_fxnode_name, + ) + ) + + def _generate_kernel_call_helper( + self, + kernel_name: str, + call_args, + *, + device=None, + triton=True, + arg_types=None, + raw_keys=None, + raw_args=None, + triton_meta=None, + graph_name="", + original_fxnode_name=None, + ): + device = device or V.graph.get_current_device_or_throw() + if not (triton or device.type != "cpu"): + self.writeline(self.wrap_kernel_call(kernel_name, call_args)) + return + + call_args_str = self.prepare_triton_kernel_call(call_args) + call_args_str = ", ".join(call_args_str) + stream_name = PythonWrapperCodegen.write_get_raw_stream( + self, device.index, graph_name + ) + if not triton: + stream_ptr = f"c_void_p({stream_name})" + self.writeline( + f"{kernel_name}.{kernel_name}({call_args_str}, {stream_ptr})" + ) + return + + self.write_triton_header_once() + + if ( + config.triton.autotune_at_compile_time + and kernel_name not in self.kernel_autotune_names + ): + # Create example args for autotune in a separate epilogue + assert arg_types is not None and len(call_args) == len(arg_types), ( + "call_args and arg_types do not match" + ) + + autotune_args = None + if original_fxnode_name and V.graph.autotuning_mapping: + autotune_args = V.graph.autotuning_mapping.get( + original_fxnode_name, None + ) + + def get_autotune_deletion_call() -> str: + """After all the autotune kernel calls have been written (i.e. 
+ self.kernel_autotune_example_args is complete), returns a deletion call + for all autotune example tensors that are unnecessary after kernel_name + is called.""" + tensors_to_delete = [ + tensor + for tensor, kn in self.kernel_autotune_example_args.values() + if kn == kernel_name + ] + if tensors_to_delete: + return f"del {', '.join(tensors_to_delete)}\n" + return "" + + def infer_arg_by_inputs(raw_keys, raw_args, idx, reused_args): + """We try to infer raw_arg (i.e. raw_args[idx]) from remaining raw_args. + This is particularly useful for jagged cases, where the dimension is often + being passed in as an input.""" + + target_arg = raw_args[idx] + if target_arg in reused_args: + return True + + for i, (raw_key, raw_arg) in enumerate(zip(raw_keys, raw_args)): + if i == idx or not isinstance(raw_arg, IRNode): + continue + + triton_input = "" + if autotune_args and raw_key in autotune_args: + triton_input = self.get_autotuning_input_name( # type: ignore[attr-defined] + autotune_args[raw_key] + ) + if triton_input == "": + continue + + try: + layout = raw_arg.get_layout() + for dim, s in enumerate(layout.size): + if s == target_arg: + reused_args[target_arg] = f"{triton_input}.shape[{dim}]" + return True + except NotImplementedError: + # If layout for this IRNode is not implemented, we could just skip. + # Only raise for other Error cases. + continue + return False + + all_args = [] + if raw_args is None: + # create a dummy raw_args for uniform behavior in the following loop + assert raw_keys is None, "keys are not None but args are" + raw_keys = [None] * len(call_args) + raw_args = [None] * len(call_args) + else: + assert len(raw_args) == len(call_args), ( + "call_args and raw_args do not match" + ) + + reused_args = {} + for i, (arg, arg_type, raw_key, raw_arg) in enumerate( + zip(call_args, arg_types, raw_keys, raw_args) + ): + key = None + if isinstance(arg, str) and "=" in str(arg): + # arg may be passed in a kwarg style, and then we need to extract its value + key, arg = arg.split("=") + + triton_input: Optional[str] = None + if autotune_args and raw_key in autotune_args: + triton_input = self.get_autotuning_input_name( # type: ignore[attr-defined] + autotune_args[raw_key] + ) + + if triton_input: + arg_str = triton_input + if not isinstance(arg_type, torch_dtype) and ( + issubclass(arg_type, sympy.Basic) + or isinstance(arg, SymbolicCallArg) + ): + reused_args[raw_arg] = arg_str + elif raw_key == "" and infer_arg_by_inputs( + raw_keys, raw_args, i, reused_args + ): + # Empty raw_key means this is a arg that's not native to the triton kernel, + # and is being added by inductor. + arg_str = reused_args[raw_arg] + elif isinstance(arg_type, torch_dtype): + # workspace allocation is already generated by `generate_workspace_allocation()` + # in `TritonKernel.call_kernel()`. 
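                    # The autotune epilogue assembled here ends up looking roughly like
                    # (buffer, kernel, and stream names are illustrative):
                    #     tmp_arg_0 = generate_example_value((8, 16), (16, 1), 'cuda:0', torch.float32, 0, (8, 16))
                    #     triton_poi_fused_add_0.run(tmp_arg_0, 128, stream=stream0)
                    #     del tmp_arg_0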
+ if re.match(r"^(workspace|semaphore)", arg): + arg_str = arg + elif arg not in self.kernel_autotune_example_args: + arg_str = self.generate_example_arg_value( + arg, arg_type, raw_arg + ) + else: + arg_str = self.kernel_autotune_example_args[arg][0] + self.kernel_autotune_example_args[arg] = (arg_str, kernel_name) + else: + arg_str = self.generate_example_arg_value(arg, arg_type, raw_arg) + all_args.append(arg_str if key is None else f"{key}={arg_str}") + + self.kernel_autotune_calls.writeline( + f"{kernel_name}.run({', '.join(all_args)}, stream={stream_name})" + ) + self.kernel_autotune_calls.writeline( + DelayReplaceLine("", get_autotune_deletion_call, "") + ) + self.kernel_autotune_names.add(kernel_name) + if V.graph.cpp_wrapper: + # For cpp wrapper, no need to continue codegen for the main body + return + + # add debug printer code for triton kernel calls at (jit) inductor level + debug_printer_manager = V.graph.wrapper_code.debug_printer + debug_printer_manager.set_printer_args(call_args, kernel_name, arg_types, None) + with debug_printer_manager: + self.writeline(f"{kernel_name}.run({call_args_str}, stream={stream_name})") + self.write_triton_header_once() + + def writeline(self, line): + self.lines.append(line) + + def writelines(self, lines): + for line in lines: + self.writeline(line) + + def enter_context(self, ctx): + self.lines.append(LineContext(ctx)) + + def val_to_arg_str(self, s, type_=None): + from torch.utils._triton import has_triton_package + + if has_triton_package(): + import triton + + if isinstance(s, SymTypes): + return pexpr(s.node.expr) + elif isinstance(s, sympy.Expr): + return pexpr(s) + elif isinstance(s, (tuple, list)): + + @dataclasses.dataclass + class Shim: + ref: Any + + def __repr__(self): + return self.ref + + # Explicitly call the Python version of val_to_arg_str + return repr( + type(s)(Shim(PythonWrapperCodegen.val_to_arg_str(self, a)) for a in s) + ) + elif isinstance(s, torch._ops.OpOverload): + return _get_qualified_name(s) + elif isinstance(s, (ir.Buffer, ir.MutableBox, ReinterpretView)): + return s.codegen_reference() + elif has_triton_package() and isinstance(s, triton.language.dtype): # type: ignore[possibly-undefined] + return repr(s) + elif isinstance(s, ir.GeneratorState): + return s.codegen_reference() + else: + return repr(s) + + # The following methods are for memory management + def make_buffer_allocation(self, buffer: BufferLike): + device = buffer.get_device() + dtype = buffer.get_dtype() + shape = tuple(buffer.get_size()) + allocation_shape = tuple(V.graph.get_allocation_size(buffer)) + stride = tuple(buffer.get_stride()) + return self.make_allocation( + buffer.get_name(), device, dtype, shape, stride, allocation_shape + ) + + def make_allocation( + self, name, device, dtype, shape, stride, allocation_shape=None + ): + if allocation_shape is None: + allocation_shape = shape + + codegen_shape_tuple = self.codegen_python_shape_tuple(shape) + codegen_allocation_shape_tuple = self.codegen_python_shape_tuple( + allocation_shape + ) + codegen_stride_tuple = self.codegen_python_shape_tuple(stride) + if device.type in ("cpu", "cuda", "xpu"): + # optimized path for faster allocations, saving ~2us versus the stuff below + out = ( + f"{name} = empty_strided_{device.type}(" + f"{codegen_allocation_shape_tuple}, " + f"{codegen_stride_tuple}, " + f"{dtype})" + ) + # all other devices: + else: + out = ( + f"{name} = empty_strided(" + f"{codegen_allocation_shape_tuple}, " + f"{codegen_stride_tuple}, " + f"device='{device.type}', dtype={dtype})" + 
) + if codegen_shape_tuple != codegen_allocation_shape_tuple: + # need an extra as_strided call + out = out + f".as_strided({codegen_shape_tuple}, {codegen_stride_tuple})" + return out + + def make_comment(self, line): + self.writeline(CommentLine(line)) + + def make_tensor_alias(self, new_name, old_name, comment=""): + return f"{self.declare}{new_name} = {old_name}{self.ending} {self.comment} {comment}" + + def make_buffer_free(self, buffer: Union[BufferLike, ir.TorchBindObject]): + return f"del {buffer.get_name()}" + + def make_free_by_names(self, names_to_del: list[str]): + return f"del {', '.join(name for name in names_to_del)}" + + def codegen_exact_buffer_reuse(self, old_name: str, new_name: str, del_line: str): + return f"{self.declare_maybe_reference}{new_name} = {old_name}{del_line}{self.ending} {self.comment} reuse" + + def make_buffer_reuse(self, old: BufferLike, new: BufferLike, delete_old: bool): + assert old.get_dtype() == new.get_dtype() + old_name = old.get_name() + new_name = new.get_name() + del_line = ";" + if old_name not in V.graph.get_output_names() and delete_old: + del_line = f"; {self.make_buffer_free(old)}" + + if old.get_size() == new.get_size() and old.get_stride() == new.get_stride(): + return self.codegen_exact_buffer_reuse(old_name, new_name, del_line) + + reinterpret_view = self.codegen_reinterpret_view( + old, new.get_size(), new.get_stride(), 0, self.wrapper_call.writeline + ) + return f"{self.declare}{new_name} = {reinterpret_view}{del_line} {self.comment} reuse" + + def codegen_deferred_allocation(self, name: str, view: ir.ReinterpretView) -> None: + self.writeline( + DeferredLine( + name, + f"{self.declare}{name} = {view.codegen_reference()}{self.ending} {self.comment} alias", + ) + ) + + def codegen_allocation(self, buffer: ir.Buffer): + name = buffer.get_name() + + if ( + name in V.graph.removed_buffers + or name in self.allocated + or isinstance(buffer, (ir.DonatedBuffer, ir.SubgraphBuffer)) + ): + return + self.allocated.add(name) + if ( + isinstance( + buffer.get_defining_op(), + (ir.ExternKernelAlloc, ir.MultiOutput), + ) + and not buffer.should_allocate() + ): + return + + layout = buffer.get_output_spec() + if isinstance(layout, ir.MutationLayoutSHOULDREMOVE): + return + if isinstance(layout, ir.NoneLayout): + return + if isinstance(layout, ir.NonOwningLayout): + assert isinstance(layout.view, ir.ReinterpretView), ( + f"unexpected {type(layout.view)}: {layout.view}" + ) + box = layout.view.data + assert isinstance(box, ir.StorageBox), type(box) + input_buffer = box.data + assert isinstance(input_buffer, ir.Buffer), type(box) + self.codegen_allocation(input_buffer) + self.writeline(ReinterpretLine(self, input_buffer, buffer, layout)) + return + + if isinstance(layout, ir.CommBufferLayout): + self.writeline(CommBufferAllocateLine(self, buffer)) + return + + self.writeline(AllocateLine(self, buffer)) + + def codegen_free(self, buffer): + name = buffer.get_name() + + # can be freed but not reused + if isinstance(buffer, (ir.InputBuffer, ir.TorchBindObject)): + self.writeline(FreeLine(self, buffer)) + return + + if isinstance(buffer.get_output_spec(), ir.CommBufferLayout): + # Comm buffers are not eligible for in-place reuse. Their reuse is + # achieved exclusively via buffer planning. 
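            # Comm buffers always get the dedicated CommBufferFreeLine below; ordinary
            # buffers fall through to FreeIfNotReusedLine, which prints roughly
            #     del buf0    # (illustrative) only emitted if buf0 was not reused in place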
+ self.writeline(CommBufferFreeLine(self, buffer)) + return + + if not self.can_reuse(buffer): + return + self.freed.add(name) + + self.writeline(FreeIfNotReusedLine(self, buffer)) + + def can_reuse(self, input_buffer, output_buffer=None): + name = input_buffer.get_name() + return not ( + name in V.graph.removed_buffers + or ( + name in V.graph.graph_inputs + and not isinstance( + V.graph.graph_inputs_original[name], ir.DonatedBuffer + ) + ) + or name in V.graph.constants + or name in V.graph.torchbind_constants + or name in V.graph.never_reuse_buffers + or name in self.freed + ) + + def did_reuse(self, buffer, reused_buffer): + # Check whether a given buffer was reused by a possible reuser in the wrapper codegen + # Can be consulted from inside ir codegen, e.g. to determine whether a copy is needed + return ( + buffer.get_name() in self.reuses + and self.reuses[buffer.get_name()] == reused_buffer.get_name() + ) + + def codegen_inplace_reuse(self, input_buffer: ir.Buffer, output_buffer: ir.Buffer): + assert can_match_buffer_size(input_buffer, output_buffer) + self.codegen_allocation(input_buffer) + self.freed.add(input_buffer.get_name()) + self.allocated.add(output_buffer.get_name()) + self.reuses[output_buffer.get_name()] = input_buffer.get_name() + self.writeline(ReuseLine(self, input_buffer, output_buffer)) + + def codegen_unbacked_symbol_decl(self, symbol): + name = str(symbol) + if name in self.unbacked_symbol_decls: + return name + else: + # When in CppWrapperCpu, we should only generate the declaration once + self.unbacked_symbol_decls.add(name) + return self.declare + name + + def codegen_unbacked_symbol_defs_for_outputs( + self, + output_name: str, + outputs: Any, + unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]], + ) -> None: + unbacked_bindings = resolve_unbacked_bindings( + V.graph.sizevars.shape_env, unbacked_bindings + ) + + if not unbacked_bindings: + return + + # This code is designed to generate code expressions from symbolic paths (keypaths) + # associated with certain symbols (unbacked bindings). These keypaths describe how + # to access the unbacked symbol in a structured way. + # For example, we might want to generate "u0 = outs[0].stride(1)"", where s = u0, and the keypath + # describes the structure of "outs[0].stride(1)", like [SequenceKey(0), CallMethodKey("stride"), SequenceKey[1]]. + for s, keypath in unbacked_bindings.items(): + # `go` recursively constructs a code expression by processing each element of + # the keypath and construct the expression incrementally. 
+ # For example, given output name outs and keypath [SequenceKey(0), CallMethodKey("stride", 1)], + # it generates "outs[0]" based on SequenceKey(0), then recursively go("outs[0]", [CallMethodKey("stride"), ...]) + def go(expr: str, keypath: pytree.KeyPath): + if keypath == (): + return expr + + if ( + len(keypath) >= 2 + and isinstance(keypath[0], CallMethodKey) + and isinstance(keypath[1], pytree.SequenceKey) + ): + return go( + f"{expr}.{keypath[0].name}({keypath[1].idx})", keypath[2:] + ) + elif isinstance(keypath[0], CallMethodKey): + return go(f"{expr}.{keypath[0].name}()", keypath[1:]) + elif isinstance(keypath[0], pytree.SequenceKey): + return ( + go(f"std::get<{keypath[0].idx}>({expr})", keypath[1:]) + if V.graph.cpp_wrapper + else go(f"{expr}[{keypath[0].idx}]", keypath[1:]) + ) + elif isinstance(keypath[0], DivideByKey): + # TODO: need to assert divisibility + # TODO: this is invalid C++ codegen + return go(f"{expr}.__floordiv__({keypath[0].divisor})", keypath[1:]) + else: + raise AssertionError(f"unrecognized keypath {keypath}") + + # `go_outer` manages the top-level logic for generating the final expression. + # It handles special cases for C++ code generation and adjusts + # the keypath based on the context (e.g., single vs. multiple outputs). + def go_outer(): # type: ignore[no-untyped-def] + if V.graph.cpp_wrapper: + # Special handling for the top level buffer access, + # because self.get_name() is actually never bound; the + # individual output arguments are bound by + # generate_c_shim_fallback_kernel + if len(outputs) == 1: + out = outputs[0] + # When fallback kernel returns a list consisting of a single tensor, + # the output is represented as a MultiOutput with non empty indices. + # In this case, we strip the first key path away. + return go( + outputs[0].get_name(), + keypath[1:] + if isinstance(out, ir.MultiOutput) and len(out.indices) != 0 + else keypath, + ) + else: + assert isinstance(keypath[0], pytree.SequenceKey) + return go(outputs[keypath[0].idx].get_name(), keypath[1:]) + else: + return go(output_name, keypath) + + self.writeline( + f"{self.codegen_unbacked_symbol_decl(s)} = {go_outer()}{self.ending}" + ) + + def codegen_subgraph_by_inlining(self, subgraph, outer_inputs, outer_outputs): + # TODO (desertfire) - This function is the old way of supporting + # subgraph codegen by inlining subgraphs in the output code. For python + # wrapper, we have moved to lifting subgraphs as functions, supported by + # `codegen_subgraph` function. + # + # However this does not work with cpp wrapper. With cpp wrapper, we make + # two passes and the kernels are shared from the first pass to the next. + # Therefore, both the Python and CppWrapper need to share the some + # codegen infra. For now, CppWrapperCpu has not been updated to lift the + # subgraph as functions. Therefore for cpp_wrapper first pass with + # PythonWrapper, we still fallback to the old way of inlining subgraphs + # in the output code. Once we update CppWrapperCpu, we can remove this + # function. 
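        # Sketch of the inlined output (buffer and graph names are illustrative):
        #     # subgraph: true_graph_0
        #     arg0_1 = buf3            # prefix: bind outer inputs to subgraph input names
        #     ... subgraph body codegen ...
        #     buf7 = buf5              # suffix: copy subgraph outputs back to outer names
        # whereas codegen_subgraph instead lifts the body into a separate function and calls it.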
+ def _codegen_subgraph_prefix(): + assert len(subgraph.graph.graph_inputs) == len(outer_inputs) + for inner_input, outer_input in zip( + subgraph.graph.graph_inputs, outer_inputs + ): + self.writeline( + f"{self.declare}{inner_input} = {outer_input}{self.ending}" + ) + + def _codegen_subgraph_suffix(): + assert len(subgraph.graph.graph_outputs) == len(outer_outputs) + for inner_output, outer_output in zip( + subgraph.graph.graph_outputs, outer_outputs + ): + self.writeline( + f"{outer_output} = {inner_output.codegen_reference()}{self.ending}" + ) + + try: + self.push_codegened_graph(subgraph.graph) + self.writeline(f"{self.comment} subgraph: {subgraph.name}") + _codegen_subgraph_prefix() + parent_graph = V.graph + with V.set_graph_handler(subgraph.graph): + subgraph.graph.codegen_subgraph( + parent_graph=parent_graph, + ) + _codegen_subgraph_suffix() + finally: + self.pop_codegened_graph() + + def codegen_partition_call( + self, + partition_id: int, + partition_signatures: ir.GraphPartitionSignature, + ): + """Generate code to call a graph partition""" + input_deallocation = partition_signatures.input_deallocation + output_nodes = partition_signatures.output_nodes + + input_names = list(input_deallocation.keys()) + [ + symbol_input.name for symbol_input in partition_signatures.symbol_inputs + ] + + inputs = ", ".join(input_names) + ("," if len(input_names) == 1 else "") + + output_names = [node.get_name() for node in output_nodes] + outputs = ", ".join(output_names) + ("," if len(output_nodes) == 1 else "") + + # Create a list of inputs for the subgraph call + self.writeline(f"partition{partition_id}_args = [{inputs}]") + + names_to_del = [ + name for name, deallocate in input_deallocation.items() if deallocate + ] + if names_to_del: + self.writeline(f"del {', '.join(names_to_del)}") + + # Call the subgraph launcher function + self.writeline( + f"({outputs}) = self.partitions[{partition_id}](partition{partition_id}_args)" + ) + self.writeline(f"del partition{partition_id}_args") + + def set_all_partition_names(self, num_partitions: int): + self.all_partition_names = [f"partition_{idx}" for idx in range(num_partitions)] + + def codegen_subgraph_call_with_flattened_outputs( + self, subgraph, outer_inputs, outer_flattened_outputs + ): + # Get the input and output names of the subgraph + outer_output_names = ", ".join(outer_flattened_outputs) + ( + "," if len(outer_flattened_outputs) == 1 else "" + ) + outer_input_names = ", ".join(outer_inputs) + ( + "," if len(outer_inputs) == 1 else "" + ) + + self.writeline(f"{subgraph.graph.name}_args = [{outer_input_names}]") + + # Call the subgraph launcher function + self.writeline( + f"({outer_output_names}) = {subgraph.graph.name}({subgraph.graph.name}_args)" + ) + + def codegen_subgraph_call(self, subgraph, outer_inputs, outer_buffer_name): + # Get the input and output names of the subgraph + outer_input_names = ", ".join(outer_inputs) + ( + "," if len(outer_inputs) == 1 else "" + ) + + self.writeline(f"{subgraph.graph.name}_args = [{outer_input_names}]") + + # Since the buffers are already put into the args list, we can free the + # buffers here. 
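        # The resulting wrapper code looks roughly like (names illustrative):
        #     true_graph_0_args = [buf0, buf1]
        #     del buf0, buf1          # may be freed by the scheduler call below
        #     buf2 = true_graph_0(true_graph_0_args)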
+ V.graph.scheduler.free_buffers() + + # Call the subgraph launcher function + self.writeline( + f"{outer_buffer_name} = {subgraph.graph.name}({subgraph.graph.name}_args)" + ) + + def codegen_subgraph_common(self, subgraph): + self.push_codegened_graph(subgraph.graph) + self.writeline("") + self.writeline(f"{self.comment} subgraph: {subgraph.name}") + + parent_graph = V.graph + subgraph.graph.cpp_wrapper = parent_graph.cpp_wrapper + + if subgraph.graph.name not in self.already_codegened_subgraphs: + # If it is already codegened, the parent wrapper already has + # subgraph fn by name subgraph.graph.name + with V.set_graph_handler(subgraph.graph): + # do not graph partition for subgraph + with config.patch("graph_partition", False): + # Call the codegen of subgraph recursively + subgraph_code, _ = subgraph.graph.codegen() + self.already_codegened_subgraphs.add(subgraph.graph.name) + self.define_subgraph_launcher_fn(subgraph_code.value) + + def codegen_subgraph_with_flattened_outputs( + self, subgraph, outer_inputs, outer_flattened_outputs + ): + self.codegen_subgraph_common(subgraph) + self.codegen_subgraph_call_with_flattened_outputs( + subgraph, outer_inputs, outer_flattened_outputs + ) + + def codegen_subgraph(self, subgraph, outer_inputs, outer_buffer_name): + # Codegen subgraph by recursively calling the codegen for the subgraph. + # This lifts the subgraph as a function in the output code. + self.codegen_subgraph_common(subgraph) + self.codegen_subgraph_call(subgraph, outer_inputs, outer_buffer_name) + + def codegen_invoke_subgraph(self, invoke_subgraph): + name = invoke_subgraph.get_name() + + self.writeline(f"{name} = [None] * {len(invoke_subgraph.outputs)}") + outer_inputs = [buf.codegen_reference() for buf in invoke_subgraph.inputs] + + if V.graph.aot_mode: + outer_outputs = [ + f"{name}[{i}]" for i in range(len(invoke_subgraph.outputs)) + ] + self.codegen_subgraph_by_inlining( + invoke_subgraph.subgraph, outer_inputs, outer_outputs + ) + else: + self.codegen_subgraph(invoke_subgraph.subgraph, outer_inputs, name) + + def codegen_conditional(self, conditional): + name = conditional.get_name() + + outer_inputs = [buf.codegen_reference() for buf in conditional.operands] + + predicate = conditional.predicate.codegen_reference() + if not isinstance(conditional.predicate, ir.ShapeAsConstantBuffer): + # move the Tensor predicate to host + predicate = f"{predicate}.item()" + + self.writeline(f"{name} = [None] * {len(conditional.outputs)}") + self.writeline(f"if {predicate}:") + self.writeline(EnterSubgraphLine(self, conditional.true_subgraph.graph)) + if V.graph.aot_mode: + outer_outputs = [f"{name}[{i}]" for i in range(len(conditional.outputs))] + self.codegen_subgraph_by_inlining( + conditional.true_subgraph, outer_inputs, outer_outputs + ) + else: + self.codegen_subgraph(conditional.true_subgraph, outer_inputs, name) + + self.writeline(ExitSubgraphLine(self)) + self.writeline("else:") + self.writeline(EnterSubgraphLine(self, conditional.false_subgraph.graph)) + if V.graph.aot_mode: + outer_outputs = [f"{name}[{i}]" for i in range(len(conditional.outputs))] + self.codegen_subgraph_by_inlining( + conditional.false_subgraph, outer_inputs, outer_outputs + ) + else: + self.codegen_subgraph(conditional.false_subgraph, outer_inputs, name) + self.writeline(ExitSubgraphLine(self)) + + def codegen_while_loop(self, while_loop): + name = while_loop.get_name() + outer_carried_inputs = [ + buf.codegen_reference() for buf in while_loop.carried_inputs + ] + outer_additional_inputs = [ + 
buf.codegen_reference() for buf in while_loop.additional_inputs + ] + + self.writeline(f"{name} = [None] * {len(outer_carried_inputs)}") + for i, inp in enumerate(outer_carried_inputs): + # set the initial state before the loop + self.writeline(f"{name}[{i}] = {inp}") + + cond_outer_inputs = [ + *[f"{name}[{i}]" for i in range(len(outer_carried_inputs))], + *outer_additional_inputs, + ] + cond_outer_outputs = [f"{name}_cond_result"] + body_outer_inputs = list( + cond_outer_inputs + ) # same inputs for cond_fn and body_fn + # Carry over the state from body_fn. Note: We only carry over + # the carried_inputs part of the inputs, the additional ones + # are passed in as they're before. + body_outer_outputs = body_outer_inputs[: len(outer_carried_inputs)] + + self.writeline("while True:") + self.writeline(EnterSubgraphLine(self, while_loop.cond_subgraph.graph)) + + if V.graph.aot_mode: + self.codegen_subgraph_by_inlining( + while_loop.cond_subgraph, cond_outer_inputs, cond_outer_outputs + ) + else: + self.codegen_subgraph_with_flattened_outputs( + while_loop.cond_subgraph, cond_outer_inputs, cond_outer_outputs + ) + self.writeline( + f"if not {cond_outer_outputs[0]}: break" + ) # condition doesn't hold + self.writeline(ExitSubgraphLine(self)) + self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph)) + if V.graph.aot_mode: + self.codegen_subgraph_by_inlining( + while_loop.body_subgraph, body_outer_inputs, body_outer_outputs + ) + else: + self.codegen_subgraph_with_flattened_outputs( + while_loop.body_subgraph, body_outer_inputs, body_outer_outputs + ) + self.writeline(ExitSubgraphLine(self)) + + @staticmethod + def statically_known_int_or_none(x): + try: + if getattr(x, "free_symbols", None): + # _maybe_evaluate_static will return (s0 // (2 // s0)) as 2, but + # the actual codegen will still generate the full expression here. + return None + if isinstance(x, int): + return x + val = V.graph._shape_env._maybe_evaluate_static(x) + if val is None: + return val + return int(val) # type: ignore[call-overload] + except Exception: + return None + + @staticmethod + def statically_known_list_of_ints_or_none(lst): + result = [] + for x in lst: + num = PythonWrapperCodegen.statically_known_int_or_none(x) + if num is None: + return None + result.append(num) + return result + + @staticmethod + def is_statically_known_list_of_ints(lst): + return ( + PythonWrapperCodegen.statically_known_list_of_ints_or_none(lst) is not None + ) + + @staticmethod + def static_shape_for_buffer_or_none(buffer): + return PythonWrapperCodegen.statically_known_list_of_ints_or_none( + buffer.get_size() + ) + + @staticmethod + def can_prove_buffer_has_static_shape(buffer): + return PythonWrapperCodegen.static_shape_for_buffer_or_none(buffer) is not None + + +class SubgraphPythonWrapperCodegen(PythonWrapperCodegen): + """ + A wrapper codegen that generates code for a subgraph. For most of the + methods, we rely on the implementation in the PythonWrapperCodegen. 
But we + override a few functions to produce cleaner code (like avoiding writing + imports twice in the output code) + """ + + def __init__( + self, + subgraph_name: str, + parent_wrapper: PythonWrapperCodegen, + partition_signatures: Optional[ir.GraphPartitionSignature] = None, + ): + # It is necessary to set the subgraph_name before calling super __init__ + # because __init__ calls set_launcher_fn_name + self.subgraph_name = subgraph_name + self.parent_wrapper = parent_wrapper + self.partition_signatures = partition_signatures + + super().__init__() + + def set_launcher_fn_name(self) -> None: + # This sets up the name of the function containing the launcher code of + # the subgraph. + self.launcher_fn_name = self.subgraph_name + + def write_header(self) -> None: + pass + + def add_benchmark_harness(self, output): + pass + + def benchmark_compiled_module(self, output): + pass + + def write_async_compile_wait(self): + pass + + def next_kernel_suffix(self) -> str: + # Ensures that subgraphs kernels do not clash with each other + return self.parent_wrapper.next_kernel_suffix() + + def generate_after_suffix(self, result: IndentedBuffer) -> None: + return + + def write_launcher_fn_call_get_indent(self) -> int: + self.prefix.splice( + f""" + def {self.launcher_fn_name}(args): + """ + ) + prefix_indent = 1 + return prefix_indent + + def get_wrapper_call_indent(self) -> int: + return 1 + + def get_graph_inputs( + self, + ) -> dict[str, Union[ir.TensorBox, ir.TorchBindObject, sympy.Expr]]: + if signature := self.partition_signatures: + inputs = signature.input_nodes | { + str(s): s for s in signature.symbol_inputs + } + else: + inputs = V.graph.graph_inputs + return inputs + + def get_graph_input_names(self) -> list[str]: + if signature := self.partition_signatures: + names = list(signature.input_nodes.keys()) + [ + symbol_input.name for symbol_input in signature.symbol_inputs + ] + else: + names = V.graph.graph_input_names + return names + + def get_graph_outputs(self) -> list[IRNode]: + if signature := self.partition_signatures: + outputs = signature.output_nodes + else: + outputs = V.graph.graph_outputs + return outputs + + def codegen_allocation(self, buffer: ir.Buffer): + name = buffer.get_name() + if (signature := self.partition_signatures) and name in signature.input_nodes: + # skip allocation if buffer is a subgraph input. + # This allows reusing an input buffer in graph partition, + # although this is not allowed in general. + return + + super().codegen_allocation(buffer) + + @cache_on_self + def write_triton_header_once(self) -> None: + # TODO: Uncomment in future. This will be needed to support subgraph + # codegen for cpp wrapper. + # if config.triton.autotune_at_compile_time: + # import_str = self.triton_header_str() + # self.kernel_autotune_calls.splice(import_str) + self.parent_wrapper.write_triton_header_once() + + @cache_on_self + def write_get_raw_stream_header_once(self) -> None: + # TODO: Uncomment in future. This will be needed to support subgraph + # codegen for cpp wrapper. 
+ # if config.triton.autotune_at_compile_time: + # self.kernel_autotune_calls.writeline( + # V.graph.device_ops.import_get_raw_stream_as("get_raw_stream") + # ) + self.parent_wrapper.write_get_raw_stream_header_once() diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/wrapper_fxir.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/wrapper_fxir.py new file mode 100644 index 0000000000000000000000000000000000000000..12cf3859c2f4a7fea68943e8e6481970d1234cb8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/wrapper_fxir.py @@ -0,0 +1,693 @@ +import dataclasses +import functools +import logging +import operator +import textwrap +from collections import Counter +from typing import Any, Callable, Optional, Union + +import sympy + +import torch +from torch._higher_order_ops.triton_kernel_wrap import ( + TraceableTritonKernelWrapper, + tracing_triton_hopifier_singleton, + triton_kernel_wrapper_mutation, +) +from torch._inductor.codecache import PyCodeCache +from torch._inductor.runtime.triton_heuristics import CachingAutotuner +from torch._inductor.select_algorithm import extern_kernels # noqa: F401 +from torch._inductor.utils import sympy_product +from torch._inductor.virtualized import V +from torch._library.triton import wrap_triton +from torch.fx import GraphModule +from torch.utils import _pytree as pytree +from torch.utils._sympy.functions import FloorDiv + +from .. import config, ir +from ..utils import convert_shape_to_symint, convert_to_symint, LineContext +from .common import ( + CodegenSymbol, + FileBackedGraphModule, + WorkspaceArg, + WorkspaceZeroMode, +) +from .wrapper import ( + AllocateLine, + BufferLike, + CommBufferAllocateLine, + CommBufferFreeLine, + CommentLine, + EnterDeviceContextManagerLine, + EnterSubgraphLine, + ExitDeviceContextManagerLine, + ExitSubgraphLine, + ExternKernelAllocLine, + ExternKernelOutLine, + FreeIfNotReusedLine, + FreeLine, + KernelCallLine, + KernelDefinitionLine, + Line, + MultiOutputLine, + NullLine, + PythonWrapperCodegen, + ReinterpretLine, + ReuseLine, + SymbolicCallArg, + SymbolicCallArgLine, + WrapperLine, +) + + +aten = torch.ops.aten +log = logging.getLogger(__name__) + + +@dataclasses.dataclass +class SymbolBuffer(CodegenSymbol): + """ + Represents a sympy.Symbol graph input. + """ + + symbol: sympy.Symbol + + def get_name(self) -> str: + return str(self.symbol) + + def get_example(self) -> Union[torch.Tensor, sympy.Symbol]: + return self.symbol + + +CodegenBuffer = Union[BufferLike, SymbolBuffer] + + +@dataclasses.dataclass +class TritonKernel: + """ + Stores metadata about Triton kernels for use in FX. + """ + + tuner: CachingAutotuner + wrapped: TraceableTritonKernelWrapper + + +class WrapperFxCodegen(PythonWrapperCodegen): + """ + Backend to generate wrapper code as an FX IR graph. + """ + + supports_caching = False + + def _generate(self, is_inference: bool) -> tuple[FileBackedGraphModule, None]: + self.run_wrapper_ir_passes(is_inference) + + prologue = "\n".join( + [ + self.imports.getvalue(), + self.header.getvalue(), + ] + ) + gm = FxConverter(lines=self.lines, prologue=prologue).generate() + compiled_fn = self.compile_graph(gm) + + return FileBackedGraphModule(gm, compiled_fn), None + + def compile_graph(self, gm: GraphModule) -> Callable[..., Any]: + """ + Converts the graph module into a runnable function. The default implementation + is simply an interpreter calling kernels in eager mode. Derived backends can + override this to do further compilation. 
+ """ + return gm.forward + + @classmethod + def create( + cls, + is_subgraph: bool, + subgraph_name: Optional[str], + parent_wrapper: Optional[PythonWrapperCodegen], + partition_signatures: Optional[ir.GraphPartitionSignature] = None, + ) -> "WrapperFxCodegen": + if is_subgraph: + raise NotImplementedError( + "Subgraphs are not yet supported by FX conversion" + ) + + # For derived backends, this could be a subclass. + return cls() + + +@dataclasses.dataclass +class FxConverter: + """ + Generates FX IR from Wrapper IR. As each instance is only meant to be used once, the + input and output code are stored as attributes. + """ + + lines: list[Line] + prologue: str = "" + + def __post_init__(self) -> None: + graph = torch.fx.Graph() + self.gm = GraphModule({}, graph) # Wrapper FX IR. + self.buffer_to_node: dict[ + Optional[str], torch.fx.Node + ] = {} # Symbol table for codegen. + self.kernels: dict[str, TritonKernel] = {} # Table to store Triton kernels. + self._unique_symbol_ids: Counter[str] = Counter() + + def _import_kernel(self, code: str, kernel_name: str) -> CachingAutotuner: + """ + Imports a kernel from source, possibly autotuning block parameters. + """ + module_code = "\n".join([self.prologue, code]) + mod = PyCodeCache.load(module_code) + kernel = getattr(mod, kernel_name) + + if not isinstance(kernel, CachingAutotuner): + raise NotImplementedError( + textwrap.dedent(f""" + Unsupported type for kernel {kernel_name}: {type(kernel)}. + FX conversion only supports Triton kernels. + """) + ) + + return kernel + + def _fake_tensor( + self, + size: tuple[Any, ...], + stride: tuple[Any, ...], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ) -> torch.Tensor: + with V.fake_mode: + return torch.empty_strided( + convert_shape_to_symint(size), + convert_shape_to_symint(stride), + dtype=dtype, + device=device, + ) + + def _create_meta_from_buffer( + self, node: torch.fx.Node, buffer: CodegenBuffer + ) -> None: + name = buffer.get_name() + assert name + node.name = name + node.meta["val"] = buffer.get_example() + + def _create_as_strided( + self, + input_node: torch.fx.Node, + size: tuple[Any, ...], + stride: tuple[Any, ...], + offset: Union[int, sympy.Expr], + ) -> torch.fx.Node: + return self.gm.graph.call_function( + torch.as_strided, + args=( + input_node, + convert_shape_to_symint(size), + convert_shape_to_symint(stride), + convert_to_symint(offset), + ), + ) + + def _record_allocation(self, buffer: CodegenBuffer, node: torch.fx.Node) -> None: + """ + Updates the symbol table to record that an Inductor buffer maps to the result of + an FX node. + """ + assert node not in self.buffer_to_node + self.buffer_to_node[buffer.get_name()] = node + + def _free(self, buffer: Union[CodegenBuffer, ir.TorchBindObject]) -> None: + """ + Removes the buffer from the symbol table. + """ + name = buffer.get_name() + del self.buffer_to_node[name] + + def _lookup_args(self, args: tuple[Any, ...]) -> tuple[Any, ...]: + """ + Maps call args back to FX nodes. + """ + return tuple( + self.buffer_to_node[arg] + if isinstance(arg, str) + else arg.inner_expr + if isinstance(arg, SymbolicCallArg) + else arg + for arg in args + ) + + def _get_buffer(self, node: ir.IRNode) -> CodegenBuffer: + """ + Extract buffer data from an IR node. 
+ """ + if isinstance(node, (ir.Buffer, WorkspaceArg)): + return node + elif isinstance(node, (ir.BaseView, ir.MutableBox)): + return self._get_buffer(node.data) + elif isinstance(node, sympy.Symbol): + return SymbolBuffer(node) + else: + raise NotImplementedError(f"Unable to extract buffer from node: {node}") + + def _generate_graph_inputs(self) -> None: + """ + Converts graph inputs to FX placeholders. + """ + for name, ir_node in V.graph.graph_inputs.items(): + # Introduce a new symbol for constant inputs. + buffer = ( + SymbolBuffer(sympy.Symbol(name, is_integer=True)) + if isinstance(ir_node, (int, float, sympy.Integer, sympy.Float)) + else self._get_buffer(ir_node) + ) + node = self.gm.graph.placeholder(buffer.get_name()) + self._create_meta_from_buffer(node, buffer) + self._record_allocation(buffer, node) + + def _generate_buffer(self, node: ir.IRNode) -> Optional[torch.fx.Node]: + """ + Generates FX IR for transformations on a buffer, such as ReinterpretView. + Does nothing if no such transformations are present. + """ + + def generate_to_buffer(node: ir.IRNode) -> Optional[BufferLike]: + if isinstance(node, (ir.Buffer, WorkspaceArg)): + return node + elif isinstance(node, ir.NoneAsConstantBuffer): + return None + elif isinstance(node, ir.StorageBox): + return generate_to_buffer(node.data) + elif isinstance(node, ir.ReinterpretView): + # We need to introduce a new symbol if the output is a ReinterpretView. + # Use a WorkspaceArg for this. + buffer = self._get_buffer(node.data) + assert isinstance(buffer, (ir.Buffer, WorkspaceArg)) + unique_name = self.gm.graph._graph_namespace.create_name( + f"{buffer.get_name()}_view", None + ) + device = buffer.get_device() + assert device + reused_as = WorkspaceArg( + count=buffer.get_size(), + zero_mode=WorkspaceZeroMode.UNINITIALIZED, + device=device, + outer_name=unique_name, + dtype=buffer.get_dtype(), + ) + + # Generate FX IR for the view. + self._generate_reinterpret_helper(buffer, reused_as, node.layout) + + return reused_as + else: + raise NotImplementedError(f"Unrecognized buffer/view node: {node}") + + buffer = generate_to_buffer(node) + return self.buffer_to_node[buffer.get_name()] if buffer is not None else None + + def _generate_output(self) -> None: + """ + Generate FX IR for graph outputs. + """ + output_nodes = [ + self._generate_buffer(node) + for idx, node in enumerate(V.graph.graph_outputs) + ] + + # Single return elements don't use a tuple. + output_value = output_nodes[0] if len(output_nodes) == 1 else output_nodes + + self.gm.graph.output(output_value) + + def generate(self) -> torch.fx.GraphModule: + """ + Main entrypoint for FX codegen. + """ + self._generate_graph_inputs() + + # Generate FX IR from Wrapper IR lines. + for line in self.lines: + if isinstance(line, WrapperLine): + line.codegen_fx(self)(line) + elif isinstance(line, LineContext): + # Ignore line context in FX IR. + pass + else: + raise NotImplementedError( + textwrap.dedent( + f""" + Found line of unrecognized type '{type(line)}': + '{line}' + + FX conversion only supports Wrapper IR lines. 
+ """ + ) + ) + + self._generate_output() + self.gm.recompile() + return self.gm + + def _generate_allocate(self, line: WrapperLine) -> None: + assert isinstance(line, AllocateLine) + buffer = line.node + name = buffer.get_name() + assert name not in V.graph.removed_buffers + + device = buffer.get_device() + dtype = buffer.get_dtype() + shape = convert_shape_to_symint(buffer.get_size()) + stride = convert_shape_to_symint(buffer.get_stride()) + + node = self.gm.graph.call_function( + torch.empty_strided, + args=(shape, stride), + kwargs={"dtype": dtype, "device": device}, + ) + assert name + node.name = name + self._create_meta_from_buffer(node, buffer) + self._record_allocation(buffer, node) + + def _generate_comment(self, line: WrapperLine) -> None: + assert isinstance(line, CommentLine) + # We ignore comments in FX IR. + + def _generate_enter_device_context_manager(self, line: WrapperLine) -> None: + assert isinstance(line, EnterDeviceContextManagerLine) + # We ignore the device context in FX IR. + + def _generate_exit_device_context_manager(self, line: WrapperLine) -> None: + assert isinstance(line, ExitDeviceContextManagerLine) + # We ignore the device context in FX IR. + + def _generate_enter_subgraph(self, line: WrapperLine) -> None: + assert isinstance(line, EnterSubgraphLine) + raise NotImplementedError("Subgraphs are not yet supported by FX conversion") + + def _generate_exit_subgraph(self, line: WrapperLine) -> None: + assert isinstance(line, ExitSubgraphLine) + raise NotImplementedError("Subgraphs are not yet supported by FX conversion") + + def _generate_free(self, line: WrapperLine) -> None: + assert isinstance(line, FreeLine) + + buf = line.node + + # No need to free placeholders. + if self.buffer_to_node[buf.get_name()].op == "placeholder": + return + + self._free(buf) + + def _generate_free_if_not_reused(self, line: WrapperLine) -> None: + assert isinstance(line, FreeIfNotReusedLine) + buf = line.node + assert buf.get_name() not in V.graph.removed_buffers + if not line.is_reused: + self._free(buf) + + def _generate_line_context(self, line: WrapperLine) -> None: + assert isinstance(line, LineContext) + # We ignore line context in FX IR. + + def _generate_reinterpret(self, line: WrapperLine) -> None: + assert isinstance(line, ReinterpretLine) + self._generate_reinterpret_helper(line.node, line.reused_as, line.layout) + + def _generate_reinterpret_helper( + self, input_buffer: BufferLike, result_buffer: BufferLike, layout: ir.Layout + ) -> None: + input_node = self.buffer_to_node[input_buffer.get_name()] + + # Look up output metadata. + name = result_buffer.get_name() + assert name + size = tuple(layout.size) + stride = tuple(layout.stride) + if isinstance(layout, ir.NonOwningLayout): + # Look up the view's layout. + view = layout.view + assert isinstance(view, ir.ReinterpretView), ( + f"unexpected type: {type(view)}" + ) + layout = view.layout + offset = input_buffer.get_offset() + layout.offset + + # Map ReinterpretView to as_strided. 
+ result_node = self._create_as_strided(input_node, size, stride, offset) + result_node.name = name + result_node.meta["val"] = layout.get_example() + self._record_allocation(result_buffer, result_node) + + def _generate_reuse(self, line: WrapperLine) -> None: + assert isinstance(line, ReuseLine) + old = line.node + new = line.reused_as + assert not any(buf.get_name() in V.graph.removed_buffers for buf in (old, new)) + assert old.get_dtype() == new.get_dtype() + + old_node = self.buffer_to_node[old.get_name()] + result_node = old_node + + # Change shape and stride. + size = tuple(new.get_size()) + stride = tuple(new.get_stride()) + offset = new.get_offset() + if ( + tuple(old.get_size()) != size + or tuple(old.get_stride()) != stride + or old.get_offset() != offset + ): + result_node = self._create_as_strided(old_node, size, stride, offset) + self._create_meta_from_buffer(result_node, new) + + self._record_allocation(new, result_node) + + # Free the old buffer, if we allocated a new tensor. + if ( + old.get_name() not in V.graph.get_output_names() + and line.delete_old + and result_node is not old_node + ): + self._free(old) + + def _generate_multi_output(self, line: WrapperLine) -> None: + assert isinstance(line, MultiOutputLine) + + # Extract the index for tuple access. + inds = line.indices[0][1:] + assert len(inds) == 1, f"Cannot convert {inds} to an index." + idx = inds[0] + + arg_node = self.buffer_to_node[line.arg_name] + node = self.gm.graph.call_function(operator.getitem, args=(arg_node, idx)) + node.meta["val"] = arg_node.meta["val"][idx] + node.name = line.result_name + self.buffer_to_node[line.result_name] = node + + def _generate_null(self, line: WrapperLine) -> None: + assert isinstance(line, NullLine) + # Does nothing. + + def _generate_comm_buffer_allocate(self, line: WrapperLine) -> None: + assert isinstance(line, CommBufferAllocateLine) + raise NotImplementedError("Comm buffer allocation is not yet supported") + + def _generate_comm_buffer_free(self, line: WrapperLine) -> None: + assert isinstance(line, CommBufferFreeLine) + self._free(line.node) + + def _generate_triton_call(self, line: WrapperLine) -> None: + assert isinstance(line, KernelCallLine) + + # Collect all kwargs, including autotuned block sizes. + call_args = self._lookup_args(line.call_args) + kernel = self.kernels[line.kernel_name] + tuner = kernel.tuner + + # Optionally autotune the kernels. + # The FX backend currently only supports compile-time tuning. + kernel_name = tuner.fn.__name__ + if config.triton.autotune_at_compile_time: + from triton.runtime import driver + + log.info("Autotuning Triton kernel %s at compile time.", kernel_name) + device = driver.active.get_current_device() + stream = driver.active.get_current_stream(device) + + def node_to_tuning_arg(arg: Any) -> Any: + """ + Create real tensors for autotuning arguments, substituting size hints + for dynamic shapes. + """ + to_size_hint = functools.partial( + pytree.tree_map, V.graph.sizevars.size_hint + ) + if not isinstance(arg, torch.fx.Node): + return to_size_hint(arg) + + fake = arg.meta["val"] + return torch.empty_strided( + to_size_hint(fake.shape), + to_size_hint(fake.stride()), + device=device, + ).zero_() + + arg_values = [node_to_tuning_arg(arg) for arg in call_args] + tuner.run(*arg_values, stream=stream) + else: + log.info( + "Skipping autotuning for kernel %s. 
Set config.triton.autotune_at_compile_time = True to enable.", + kernel_name, + ) + + kernel_config = tuner.compile_results[0].config + call_args, grid = tuner._interpret_args_grid(call_args, kernel_config) + call_kwargs = dict(zip(tuner.triton_meta["signature"], call_args)) + call_kwargs.update(kernel_config.kwargs) + + def replace_floor_div(expr: sympy.Expr) -> sympy.Expr: + """ + Converts floor(x / c) to x // c. + """ + if isinstance(expr, sympy.core.mul.Mul) and isinstance( + expr.args[0], sympy.Rational + ): + # Only the first argument of a Mul can be a Rational. + frac = expr.args[0] + numerator = sympy_product(expr.args[1:]) * frac.numerator + denominator = frac.denominator + + # Sanity check the results. + new_expr = numerator / denominator + assert V.graph.sizevars.statically_known_equals(new_expr, expr), ( + f"Unsound replacement: '{new_expr}' != '{expr}'" + ) + + return FloorDiv(numerator, denominator) + else: + return sympy.floor(expr) + + def expr_to_symint(expr: Union[int, sympy.Expr]) -> Union[int, sympy.Expr]: + return ( + convert_to_symint(expr.replace(sympy.floor, replace_floor_div)) + if isinstance(expr, sympy.Expr) + else expr + ) + + # Convert sympy expressions to symints. + # Use FloorDiv over sympy.floor, so we can get nicer Python code from FX. + wrapper_grid = [tuple(expr_to_symint(dim) for dim in grid)] + call_kwargs = {name: expr_to_symint(val) for name, val in call_kwargs.items()} + + # Store non-graphable kwargs in the side table. + ( + call_kwargs, + constant_args_idx, + ) = tracing_triton_hopifier_singleton.store_non_graphable_args(call_kwargs) + + self.gm.graph.call_function( + triton_kernel_wrapper_mutation, + kwargs={ + "kernel_idx": kernel.wrapped.kernel_idx, + "constant_args_idx": constant_args_idx, + "grid": wrapper_grid, + "tma_descriptor_metadata": {}, + "kwargs": call_kwargs, + }, + ) + + def _generate_extern_kernel_alloc(self, line: WrapperLine) -> None: + assert isinstance(line, ExternKernelAllocLine) + node = line.node + self._generate_extern_kernel_common(node, node) + + def _generate_extern_kernel_out( + self, + line: WrapperLine, + ) -> None: + assert isinstance(line, ExternKernelOutLine) + node = line.node + out_node = node.output_view if node.output_view else node + self._generate_extern_kernel_common(node, out_node) + + def _generate_extern_kernel_common( + self, kernel: ir.ExternKernel, out_ir_node: ir.IRNode + ) -> None: + """ + Generates FX IR from either ExternKernelAlloc or ExternKernelOut. + """ + + # Get FX nodes corresponding to the call args. + tensor_nodes = tuple(self._generate_buffer(arg) for arg in kernel.inputs) + args = tensor_nodes + tuple(kernel.constant_args) + + # Get the result buffer. + # Some kernels write to a pre-existing output tensor via the "out" kwarg. + kwargs = kernel.kwargs.copy() + result_buffer: Optional[str] = None + if isinstance(kernel, ir.ExternKernelOut): + kwargs["out"] = self.buffer_to_node[out_ir_node.codegen_reference()] + elif isinstance(kernel.layout, (ir.Layout, ir.MultiOutputLayout)): + result_buffer = kernel.get_name() + elif isinstance(kernel.layout, ir.NoneLayout): + pass + else: + raise NotImplementedError(f"Unrecognized output layout: {kernel.layout}") + + # Look up the kernel function from its name. + kernel_name = kernel.get_kernel_name() + module_name, kernel_name = kernel_name.split(".", 1) + op = globals()[module_name] # E.g. extern_kernels, aten, etc. + for subname in kernel_name.split("."): + op = getattr(op, subname) # E.g. 
extern_kernels.addmm + + fx_node = self.gm.graph.call_function(op, args=args, kwargs=kwargs) + + # Assign the result to the given name. + if result_buffer: + assert "out" not in kwargs, ( + f"Extern kernel '{kernel}' has both result and out kwarg. Expected only one." + ) + fx_node.name = result_buffer + self.buffer_to_node[result_buffer] = fx_node + + arg_tensors = [ + arg.meta["val"] if isinstance(arg, torch.fx.Node) else arg + for arg in args + ] + + # Run the operation to propagate metadata. + fx_node.meta["val"] = op(*arg_tensors, **kwargs) + + def _generate_kernel_call(self, line: WrapperLine) -> None: + assert isinstance(line, KernelCallLine) + if not line.triton: + raise NotImplementedError("FX conversion only supports Triton kernels.") + + self._generate_triton_call(line) + + def _generate_kernel_definition(self, line: WrapperLine) -> None: + assert isinstance(line, KernelDefinitionLine) + + # Generate code for the kernel. + kernel_code = PythonWrapperCodegen._format_kernel_definition( + line.kernel_name, line.kernel_body, metadata=line.metadata + ) + + # Import the module and store the JIT kernel. + tuner = self._import_kernel(kernel_code, line.kernel_name) + wrapped = wrap_triton(tuner.fn) + self.kernels[line.kernel_name] = TritonKernel(tuner, wrapped) + + def _generate_symbolic_call_arg(self, line: WrapperLine) -> None: + assert isinstance(line, SymbolicCallArgLine) + # No need for an FX node, as we will pass the arg to kernels via a SymInt. diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/xpu/__init__.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/xpu/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/xpu/device_op_overrides.py b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/xpu/device_op_overrides.py new file mode 100644 index 0000000000000000000000000000000000000000..632cfd29f174fada935cd14a521ff3fd3471a233 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/codegen/xpu/device_op_overrides.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from typing import Optional + +from ..common import ( + DeviceOpOverrides, + register_device_op_overrides, + TritonScratchWorkspace, +) + + +class XPUDeviceOpOverrides(DeviceOpOverrides): + def import_get_raw_stream_as(self, name: str) -> str: + return f"from torch._C import _xpu_getCurrentRawStream as {name}" + + def set_device(self, device_idx: int) -> str: + return f"torch.xpu.set_device({device_idx})" + + def synchronize(self) -> str: + return "torch.xpu.synchronize()" + + def device_guard(self, device_idx: int) -> str: + return f"torch.xpu._DeviceGuard({device_idx})" + + def cpp_device_guard(self) -> str: + return "at::DeviceGuard" + + def cpp_aoti_device_guard(self) -> str: + return "AOTIXpuGuard" + + def cpp_stream_guard(self) -> str: + return "at::xpu::XPUStreamGuard" + + def cpp_aoti_stream_guard(self) -> str: + return "AOTIXpuStreamGuard" + + def cpp_getStreamFromExternal(self) -> str: + return "at::xpu::getStreamFromExternal" + + def kernel_header(self) -> str: + source_codes = """ + #include + """ + return source_codes + + def kernel_driver(self) -> str: + return "" + + def cpp_stream_type(self) -> str: + return "sycl::queue*" + + def aoti_get_stream(self) -> str: + return "aoti_torch_get_current_xpu_stream" + + def cpp_kernel_type(self) -> str: + return "std::unique_ptr" + + def cpp_device_ptr(self) -> str: + 
return "void *" + + def cpp_global_scratch( + self, idx: int, workspace: TritonScratchWorkspace + ) -> Optional[tuple[list[str], str]]: + return None + + +register_device_op_overrides("xpu", XPUDeviceOpOverrides()) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/__init__.py b/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/__main__.py b/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..6ca0f1e5a4fb2a6aeb1224285d76e78a05a0f499 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/__main__.py @@ -0,0 +1,80 @@ +# mypy: allow-untyped-defs +import argparse +import base64 +import functools +import importlib +import logging +import os +import sys +from typing import TypeVar + +from torch._inductor.async_compile import pre_fork_setup +from torch._inductor.codecache import torch_key +from torch._inductor.compile_worker.subproc_pool import ( + SubprocKind, + SubprocMain, + SubprocPickler, +) +from torch._inductor.compile_worker.utils import _async_compile_initializer +from torch._inductor.runtime.compile_tasks import _set_triton_ptxas_path + + +_T = TypeVar("_T") + + +log = logging.getLogger(__name__) + +_set_triton_ptxas_path() + +try: + import triton + + assert triton is not None # preload in parent +except ImportError: + pass + + +def _lookup_and_create_type(base: type[_T], qname: str) -> _T: + """ + Given a base type and qualified name: import & lookup that name, check + that it's of the given type and then instantiate it. 
+ """ + pkg, name = qname.rsplit(".", 1) + mod = importlib.import_module(pkg) + ty = getattr(mod, name) + if not issubclass(ty, base): + raise TypeError(f"Type {ty} is not a subtype of {base}") + return ty() + + +def main(): + try: + parser = argparse.ArgumentParser() + parser.add_argument( + "--pickler", type=functools.partial(_lookup_and_create_type, SubprocPickler) + ) + parser.add_argument("--kind", type=SubprocKind) + parser.add_argument("--workers", type=int) + parser.add_argument("--parent", type=int) + parser.add_argument("--read-fd", type=int) + parser.add_argument("--write-fd", type=int) + parser.add_argument("--torch-key", type=str) + args = parser.parse_args() + if os.getppid() != args.parent: + sys.exit(0) + read_fd = os.fdopen(args.read_fd, "rb") + write_fd = os.fdopen(args.write_fd, "wb") + + pre_fork_setup() + + torch_key.set(base64.b64decode(args.torch_key.encode("utf-8"))) # type: ignore[attr-defined] + + _async_compile_initializer(args.parent) + + SubprocMain(args.pickler, args.kind, args.workers, read_fd, write_fd).main() + except Exception: + log.exception("Uncaught exception in compile_worker subprocess") + + +if __name__ == "__main__": + main() diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/subproc_pool.py b/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/subproc_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..bf213b37e8540fd78e5c21a802af31c9f8456bd0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/subproc_pool.py @@ -0,0 +1,379 @@ +import base64 +import functools +import itertools +import logging +import multiprocessing +import os +import pickle +import struct +import subprocess +import sys +import threading +import traceback +import typing +from concurrent.futures import Future, ProcessPoolExecutor +from concurrent.futures.process import BrokenProcessPool +from enum import Enum +from typing import Any, Callable, IO, Optional, TypeVar +from typing_extensions import Never, ParamSpec + +# _thread_safe_fork is needed because the subprocesses in the pool can read +# justknobs, e.g., in the Triton compiler. For internal, the import installs +# functionality to destroy singletons before forking and re-enable them after. +import torch._thread_safe_fork # noqa: F401 +from torch._inductor import config +from torch._inductor.codecache import torch_key +from torch._inductor.compile_worker.tracked_process_pool import ( + TrackedProcessPoolExecutor, +) +from torch._inductor.compile_worker.utils import _async_compile_initializer +from torch._inductor.utils import get_ld_library_path + + +log = logging.getLogger(__name__) + +_P = ParamSpec("_P") +_T = TypeVar("_T") + + +def _pack_msg(job_id: int, length: int) -> bytes: + return struct.pack("nn", job_id, length) + + +def _unpack_msg(data: bytes) -> tuple[int, int]: + if not data: + return -1, -1 + return struct.unpack("nn", data) + + +msg_bytes = len(_pack_msg(0, 0)) + + +def _send_msg(write_pipe: IO[bytes], job_id: int, job_data: bytes = b"") -> None: + length = len(job_data) + write_pipe.write(_pack_msg(job_id, length)) + if length > 0: + write_pipe.write(job_data) + write_pipe.flush() + + +def _recv_msg(read_pipe: IO[bytes]) -> tuple[int, bytes]: + job_id, length = _unpack_msg(read_pipe.read(msg_bytes)) + data = read_pipe.read(length) if length > 0 else b"" + return job_id, data + + +class _SubprocExceptionInfo: + """ + Carries exception info from subprocesses across the wire. 
traceback + objects are not pickleable, so we store the trace as a string and + use it for the message in the exception thrown in the main process. + """ + + def __init__(self, details: str) -> None: + self.details = details + + +class SubprocException(Exception): + """ + Thrown when a job in a subprocess raises an Exception. + """ + + def __init__(self, details: str) -> None: + super().__init__(f"An exception occurred in a subprocess:\n\n{details}") + + +class SubprocPickler: + """ + Allows a caller to provide a custom pickler for passing data with the + subprocess. + """ + + def dumps(self, obj: object) -> bytes: + return pickle.dumps(obj, pickle.HIGHEST_PROTOCOL) + + def loads(self, data: bytes) -> object: + return pickle.loads(data) + + +class SubprocKind(Enum): + FORK = "fork" + SPAWN = "spawn" + + +class SubprocPool: + """ + Mimic a concurrent.futures.ProcessPoolExecutor, but wrap it in + a subprocess.Popen() to try to avoid issues with forking/spawning + """ + + def __init__( + self, + nprocs: int, + pickler: Optional[SubprocPickler] = None, + kind: SubprocKind = SubprocKind.FORK, + ) -> None: + entry = os.path.join(os.path.dirname(__file__), "__main__.py") + self.pickler = pickler or SubprocPickler() + self.kind = kind + + subproc_read_fd, write_fd = os.pipe() + read_fd, subproc_write_fd = os.pipe() + self.write_pipe = os.fdopen(write_fd, "wb") + self.read_pipe = os.fdopen(read_fd, "rb") + torch_key_str = base64.b64encode(torch_key()).decode("utf-8") + + cmd = [ + sys.executable, + entry, + f"--pickler={self.pickler.__class__.__module__}.{self.pickler.__class__.__name__}", + f"--kind={self.kind.value}", + f"--workers={nprocs}", + f"--parent={os.getpid()}", + f"--read-fd={str(subproc_read_fd)}", + f"--write-fd={str(subproc_write_fd)}", + f"--torch-key={torch_key_str}", + ] + local = False + if config.worker_suppress_logging: + log.info("Suppressing compile worker output due to config") + local = True + + self.process = subprocess.Popen( + cmd, + env={ + **os.environ, + # We need to set the PYTHONPATH so the subprocess can find torch. + "PYTHONPATH": os.environ.get( + "TORCH_CUSTOM_PYTHONPATH", os.pathsep.join(sys.path) + ), + # We don't want to re-warm the pool when the subprocess imports + # torch._inductor.codecache since the warming process is what + # creates the SubprocPool in the first place. + "TORCH_WARM_POOL": "0", + # Some internal usages need a modified LD_LIBRARY_PATH. + "LD_LIBRARY_PATH": get_ld_library_path(), + }, + pass_fds=(subproc_read_fd, subproc_write_fd), + stdout=subprocess.DEVNULL if local else None, + stderr=subprocess.DEVNULL if local else None, + ) + self.write_lock = threading.Lock() + self.read_thread = threading.Thread(target=self._read_thread, daemon=True) + + self.futures_lock = threading.Lock() + self.pending_futures: dict[int, Future[Any]] = {} + self.job_id_count = itertools.count() + + self.running = True + + # Start thread last to ensure all member variables are initialized + # before any access. 
+ self.read_thread.start() + + def submit( + self, job_fn: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs + ) -> Future[_T]: + if args or kwargs: + job_fn = functools.partial(job_fn, *args, **kwargs) + job_data = self.pickler.dumps(job_fn) + future: Future[_T] + with self.futures_lock: + job_id = next(self.job_id_count) + self.pending_futures[job_id] = future = Future() + future.set_running_or_notify_cancel() + with self.write_lock: + if not self.running: + raise RuntimeError("submit() on closed pool") + _send_msg(self.write_pipe, job_id, job_data) + return future + + def _read_thread(self) -> None: + while True: + data = b"" + try: + job_id, data = _recv_msg(self.read_pipe) + except Exception: + # Something went wrong during the read. There's no way we have a + # valid job_id. + log.exception("failure in subproc_pool._recv_msg") + job_id = -1 + + if job_id < 0: + # read_pipe returned None or got exception + if self.running: + log.warning("SubprocPool unclean exit") + self.running = False + self.read_pipe.close() + # Cancel all the pending futures. + self.shutdown() + return + + try: + result = self.pickler.loads(data) + except Exception as e: + # Something went wrong unpickling. We have a job_id so just + # notify that particular future and continue on. + log.exception("unpickle failure in SubprocPool._read_thread") + result = e + + with self.futures_lock: + if not self.running: + return + if isinstance(result, _SubprocExceptionInfo): + # An exception occurred in the submitted job + self.pending_futures[job_id].set_exception( + SubprocException(result.details) + ) + elif isinstance(result, Exception): + # An exception occurred in some of our subprocess machinery. + self.pending_futures[job_id].set_exception(result) + else: + self.pending_futures[job_id].set_result(result) + del self.pending_futures[job_id] + + def shutdown(self) -> None: + try: + with self.write_lock: + if not self.running: + return + self.running = False + _send_msg(self.write_pipe, -1) + self.write_pipe.close() + self.process.wait(300) + except OSError as e: + log.warning("Ignored OSError in pool shutdown: %s", e) + finally: + with self.futures_lock: + for future in self.pending_futures.values(): + if not future.cancel(): + future.set_exception(RuntimeError("SubprocPool closed")) + self.pending_futures.clear() + + +class SubprocMain: + """Communicates with a SubprocPool in the parent process, called by __main__.py""" + + def __init__( + self, + pickler: SubprocPickler, + kind: SubprocKind, + nprocs: int, + read_pipe: IO[bytes], + write_pipe: IO[bytes], + ) -> None: + self.pickler = pickler + self.kind = kind + self.read_pipe = read_pipe + self.write_pipe = write_pipe + self.write_lock = threading.Lock() + self.nprocs = nprocs + self.pool = self._new_pool(nprocs, True) + self.running = True + + def _new_pool(self, nprocs: int, warm: bool) -> ProcessPoolExecutor: + pool = TrackedProcessPoolExecutor( + nprocs, + mp_context=multiprocessing.get_context(self.kind.value), + initializer=functools.partial(_async_compile_initializer, os.getpid()), + ) + multiprocessing.util.Finalize(None, pool.shutdown, exitpriority=sys.maxsize) + if warm: + _warm_process_pool(pool, nprocs) + return pool + + def main(self) -> None: + while True: + job_id, data = _recv_msg(self.read_pipe) + if job_id < 0: + return self._shutdown() + self.submit(job_id, data) + + def _shutdown(self) -> None: + with self.write_lock: + self.running = False + try: + _send_msg(self.write_pipe, -1) + self.write_pipe.close() + except BrokenPipeError: + pass # 
parent process already shutdown + self.read_pipe.close() + self.pool.shutdown() + + def submit(self, job_id: int, data: bytes) -> None: + while self.running: + try: + self._submit_inner(job_id, data) + return + except BrokenProcessPool: + # If any subprocess in the pool crashes, we get a BrokenProcessPool + # exception and the whole pool becomes unusable. Handle crashes by + # recreating the pool and resubmitting. + self.pool = self._new_pool(self.nprocs, False) + + def _submit_inner(self, job_id: int, data: bytes) -> None: + def callback(fut: Future[Any]) -> None: + if not self.running: + return + try: + result = fut.result() + except Exception as e: + log.exception("Error in subprocess") + result = self.pickler.dumps(e) + assert isinstance(result, bytes) + with self.write_lock: + if self.running: + _send_msg(self.write_pipe, job_id, result) + return + + future = self.pool.submit( + functools.partial(SubprocMain.do_job, self.pickler, data) + ) + future.add_done_callback(callback) + + @staticmethod + def do_job(pickler: SubprocPickler, data: bytes) -> bytes: + # do the pickle/unpickle in the sub-subproc + job = typing.cast(Callable[[], object], pickler.loads(data)) + + try: + result = job() + except Exception: + result = _SubprocExceptionInfo(traceback.format_exc()) + return pickler.dumps(result) + + +AnyPool = typing.Union[ProcessPoolExecutor, SubprocPool] + + +def _warm_process_pool(pool: ProcessPoolExecutor, n: int) -> None: + # We have to fork processes for compiler workers, but the more memory and other resources that are loaded, the + # slower the os.fork time is, quite drastically. It also holds the GIL so we can't put it on another thread. + + # Examples: + # A simple x + x + x script: 10ms seconds in the middle of the program, 2ms at startup + # tf_efficientnet_b0 benchmark: 50ms! in the middle of the program , 3ms at startup + + # So we want to start the workers early when it is still cheap, and also to allow the workers to get + # ready before we have work for them. + + # ProcessPoolExecutor also does not launch the workers until it finds a point when all the workers are idle. + # But if we waited until then fork time will be long and we will be waiting for the processes to initialize. + + # We force them to start here with some YOLOing of the internal methods. 
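+ # The two branches below appear to target different CPython versions of
+ # ProcessPoolExecutor: older releases expose _start_queue_management_thread,
+ # while newer ones spawn workers via _adjust_process_count and start the
+ # executor manager thread separately.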
+ + if hasattr(pool, "_start_queue_management_thread"): + pool._start_queue_management_thread() + else: + for _ in range(n): + pool._adjust_process_count() + if hasattr(pool, "_start_executor_manager_thread"): + pool._start_executor_manager_thread() + + +class TestException(RuntimeError): + pass + + +def raise_testexc() -> Never: + raise TestException diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/tracked_process_pool.py b/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/tracked_process_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..36df56b963d69f955b831b508cccc2dcb08417f3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/tracked_process_pool.py @@ -0,0 +1,111 @@ +import atexit +import concurrent +import dataclasses +import logging +import threading +from concurrent.futures import Future, ProcessPoolExecutor +from dataclasses import dataclass +from multiprocessing.context import BaseContext +from time import time +from typing import Any, Callable, Optional, TypeVar +from typing_extensions import ParamSpec + +# _thread_safe_fork is needed because the subprocesses in the pool can read +# justknobs, e.g., in the Triton compiler. For internal, the import installs +# functionality to destroy singletons before forking and re-enable them after. +import torch._thread_safe_fork # noqa: F401 + + +_P = ParamSpec("_P") +_R = TypeVar("_R") + + +log = logging.getLogger(__name__) + + +@dataclass +class _QueueStats: + # Mapping from id(future) -> start time + pending: dict[int, float] = dataclasses.field(default_factory=dict) + timing: list[float] = dataclasses.field(default_factory=list) + enqueue_count: int = 0 + dequeue_count: int = 0 + max_queue_depth: int = 0 + pool_count: int = 0 + + +# The queue statistics tracked by TrackedProcessPoolExecutor. Always grab +# _queue_stats_lock before touching. +_queue_stats = _QueueStats() +_queue_stats_lock = threading.Lock() + + +class TrackedProcessPoolExecutor(ProcessPoolExecutor): + def __init__( + self, + max_workers: Optional[int] = None, + mp_context: Optional[BaseContext] = None, + initializer: Optional[Callable[[], object]] = None, + ) -> None: + with _queue_stats_lock: + _queue_stats.pool_count += 1 + super().__init__(max_workers, mp_context, initializer) + + def _record_dequeue(self, f: Future[Any]) -> None: + now = time() + with _queue_stats_lock: + stats = _queue_stats + if (start_time := stats.pending.pop(id(f), None)) is None: + return + stats.dequeue_count += 1 + duration = now - start_time + stats.timing.append(duration) + + def _record_enqueue(self, f: Future[Any]) -> None: + # Monkeypatch the set_running_or_notify_cancel so we can track when the Future moves out of PENDING. 
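+ # The executor machinery calls set_running_or_notify_cancel() when it takes the
+ # work item off its pending queue to dispatch it to a worker, so wrapping it
+ # here records the moment the Future leaves the PENDING state.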
+ saved_running_or_notify_cancel = f.set_running_or_notify_cancel + + def set_running_or_notify_cancel() -> Any: + self._record_dequeue(f) + return saved_running_or_notify_cancel() + + now = time() + with _queue_stats_lock: + stats = _queue_stats + stats.pending[id(f)] = now + stats.enqueue_count += 1 + stats.max_queue_depth = max(stats.max_queue_depth, len(stats.pending)) + f.set_running_or_notify_cancel = set_running_or_notify_cancel # type: ignore[method-assign] + + if f._state != concurrent.futures._base.PENDING: + self._record_dequeue(f) + + def submit( + self, fn: Callable[_P, _R], /, *args: _P.args, **kwargs: _P.kwargs + ) -> Future[_R]: + f = super().submit(fn, *args, **kwargs) + self._record_enqueue(f) + return f + + +@atexit.register +def _queue_stats_report() -> None: + stats = _queue_stats + if stats.pool_count == 0: + return + + timing = stats.timing + timing.sort() + + log.info("AsyncCompile Metrics:") + log.info(" Pools %s", stats.pool_count) + log.info( + " Items %d enqueued / %d dequeued", stats.enqueue_count, stats.dequeue_count + ) + log.info(" Max Queue Depth: %d", stats.max_queue_depth) + n = len(timing) + if n > 0: + log.info(" Longest queue time: %0.2fs", timing[-1]) + log.info(" P50: %0.2fs", timing[n // 2]) + if n >= 20: + log.info(" P95: %0.2fs", timing[n * 95 // 100]) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/utils.py b/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..864dcf9c9682d9b7b83fc77c26c28716522c7eab --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/compile_worker/utils.py @@ -0,0 +1,49 @@ +import os +import signal +from threading import Thread +from time import sleep +from typing import Optional + + +_IN_TOPLEVEL_PROCESS = True + + +def in_toplevel_process() -> bool: + global _IN_TOPLEVEL_PROCESS + return _IN_TOPLEVEL_PROCESS + + +# If this process dies abnormally (e.g. segfault) +# it will not shut down the workers. Instead, +# the workers will have their parent reassigned to the +# init process. This launches a separate thread to +# watch for the worker getting reassigned, +# and cleans it up in this case. +# +# This function cannot be an inner function since otherwise mp_context="spawn" would +# not work for ProcessPoolExecutor since inner functions cannot be pickled. +def _async_compile_initializer(orig_ppid: int) -> None: + def run() -> None: + while True: + sleep(1) + if orig_ppid != os.getppid(): + os.kill(os.getpid(), signal.SIGKILL) + + global _watchdog_thread, _original_parent + _original_parent = orig_ppid + _watchdog_thread = Thread(target=run, daemon=True) + _watchdog_thread.start() + # Ignore Ctrl-C (i.e. SIGINT) sent to pool workers to avoid meaningless log spam. + signal.signal(signal.SIGINT, signal.SIG_IGN) + + # Set a bit to distinguish async_compile subprocesses from the toplevel process. 
+ global _IN_TOPLEVEL_PROCESS + _IN_TOPLEVEL_PROCESS = False + + +_watchdog_thread: Optional[Thread] = None +_original_parent: Optional[int] = None + + +def has_parent_changed() -> bool: + return _original_parent != os.getppid() diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__init__.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/b2b_gemm.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/b2b_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..483099b6aca4c20ab87a51db84435fafab717f6f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/b2b_gemm.py @@ -0,0 +1,760 @@ +# mypy: allow-untyped-defs +import functools +from collections import deque + +import torch +from torch.utils._ordered_set import OrderedSet +from torch.utils._pytree import tree_map + +from ..._dynamo.utils import counters +from ..ir import ( + ComputedBuffer, + FixedLayout, + FlexibleLayout, + InputBuffer, + StorageBox, + Subgraph, + TensorBox, +) +from ..lowering import lowerings +from ..pattern_matcher import ( + Arg, + CallFunction, + Match, + PatternMatcherPass, + register_graph_pattern, +) +from ..select_algorithm import ( + autotune_select_algorithm, + ExternKernelChoice, + SymbolicGridFn, + TritonTemplate, + TritonTemplateCaller, +) +from ..utils import ceildiv + + +B2B_GEMM_PASS = PatternMatcherPass( + pass_name="b2b_gemm_pass", +) + + +@SymbolicGridFn +def b2b_gemm_grid(M, P, meta, *, cdiv): + return (cdiv(M, meta["BLOCK_SIZE_M"]) * cdiv(P, meta["BLOCK_SIZE_P"]), 1, 1) + + +b2b_gemm_left_template = TritonTemplate( + name="b2b_gemm_left", + grid=b2b_gemm_grid, + debug=False, + source=r""" +{{def_kernel("A", "B", "C")}} + + + # B2B_GEMM_LEFT_TRITON_ENTRANCE + + # dynamic shapes + M = {{size("A", 0)}} + N = {{size("A", 1)}} + O = {{size("C", 0)}} + P = {{size("C", 1)}} + + # dynamic strides + stride_am = {{stride("A", 0)}} + stride_an = {{stride("A", 1)}} + stride_bn = {{stride("B", 0)}} + stride_bo = {{stride("B", 1)}} + stride_co = {{stride("C", 0)}} + stride_cp = {{stride("C", 1)}} + + # output block counts + num_m_block = tl.cdiv(M, BLOCK_SIZE_M) + num_p_block = tl.cdiv(P, BLOCK_SIZE_P) + + # internal block counts + num_n_block = tl.cdiv(N, BLOCK_SIZE_N) + num_o_block = tl.cdiv(O, BLOCK_SIZE_O) + + # output block ids + pid = tl.program_id(axis=0) + m_block_id = pid // num_p_block + p_block_id = pid % num_p_block + + # accumulator + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_P), dtype=tl.float32) + + # main loop + offs_m = (m_block_id * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) + offs_p = (p_block_id * BLOCK_SIZE_P + tl.arange(0, BLOCK_SIZE_P)) + # (subgraph(A @ B) @ C) + offs_o = tl.arange(0, BLOCK_SIZE_O) + for _ in range(num_o_block): + c_mask = (offs_o[:, None] < O) & (offs_p[None, :] < P) + c_ptrs = C + (offs_o[:, None] * stride_co + offs_p[None, :] * stride_cp) + c = tl.load(c_ptrs, mask=c_mask, other=0.0).to(tl.float32) # BLOCK_SIZE_O * BLOCK_SIZE_P + acc_ab = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_O), dtype=tl.float32) + offs_n = tl.arange(0, BLOCK_SIZE_N) + for __ in range(num_n_block): + a_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + a_ptrs = A + (offs_m[:, None] * stride_am + offs_n[None, :] * stride_an) + a = tl.load(a_ptrs, mask=a_mask, other=0.0).to(tl.float32) # BLOCK_SIZE_M * BLOCK_SIZE_N + 
b_mask = (offs_n[:, None] < N) & (offs_o[None, :] < O) + b_ptrs = B + (offs_n[:, None] * stride_bn + offs_o[None, :] * stride_bo) + b = tl.load(b_ptrs, mask=b_mask, other=0.0).to(tl.float32) # BLOCK_SIZE_N * BLOCK_SIZE_O + acc_ab += tl.dot(a, b, out_dtype=tl.float32) + offs_n += BLOCK_SIZE_N + # apply the subgraph + {{ modification( + subgraph_number=0, + output_name="post_subgraph_acc_ab", + inner_mm="acc_ab" + ) | indent_except_first(2) }} + acc += tl.dot(post_subgraph_acc_ab, c, out_dtype=tl.float32) + offs_o += BLOCK_SIZE_O + + # type conversion + acc = acc.to(tl.float16) + + # store preparation + idx_m = offs_m[:, None] + idx_p = offs_p[None, :] + out_mask = (idx_m < M) & (idx_p < P) + + {{store_output(("idx_m", "idx_p"), "acc", "out_mask")}} +""", +) + + +b2b_gemm_right_template = TritonTemplate( + name="b2b_gemm_right", + grid=b2b_gemm_grid, + debug=False, + source=r""" +{{def_kernel("A", "B", "C")}} + + + # B2B_GEMM_RIGHT_TRITON_ENTRANCE + + # dynamic shapes + M = {{size("A", 0)}} + N = {{size("A", 1)}} + O = {{size("C", 0)}} + P = {{size("C", 1)}} + + # dynamic strides + stride_am = {{stride("A", 0)}} + stride_an = {{stride("A", 1)}} + stride_bn = {{stride("B", 0)}} + stride_bo = {{stride("B", 1)}} + stride_co = {{stride("C", 0)}} + stride_cp = {{stride("C", 1)}} + + # output block counts + num_m_block = tl.cdiv(M, BLOCK_SIZE_M) + num_p_block = tl.cdiv(P, BLOCK_SIZE_P) + + # internal block counts + num_n_block = tl.cdiv(N, BLOCK_SIZE_N) + num_o_block = tl.cdiv(O, BLOCK_SIZE_O) + + # output block ids + pid = tl.program_id(axis=0) + m_block_id = pid // num_p_block + p_block_id = pid % num_p_block + + # accumulator + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_P), dtype=tl.float32) + + # main loop (two cases) + offs_m = (m_block_id * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) + offs_p = (p_block_id * BLOCK_SIZE_P + tl.arange(0, BLOCK_SIZE_P)) + # (A @ subgraph(B @ C)) + offs_n = tl.arange(0, BLOCK_SIZE_N) + for _ in range(num_n_block): + a_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + a_ptrs = A + (offs_m[:, None] * stride_am + offs_n[None, :] * stride_an) + a = tl.load(a_ptrs, mask=a_mask, other=0.0).to(tl.float32) # BLOCK_SIZE_M * BLOCK_SIZE_N + acc_bc = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_P), dtype=tl.float32) + offs_o = tl.arange(0, BLOCK_SIZE_O) + for __ in range(num_o_block): + b_mask = (offs_n[:, None] < N) & (offs_o[None, :] < O) + b_ptrs = B + (offs_n[:, None] * stride_bn + offs_o[None, :] * stride_bo) + b = tl.load(b_ptrs, mask=b_mask, other=0.0).to(tl.float32) # BLOCK_SIZE_N * BLOCK_SIZE_O + c_mask = (offs_o[:, None] < O) & (offs_p[None, :] < P) + c_ptrs = C + (offs_o[:, None] * stride_co + offs_p[None, :] * stride_cp) + c = tl.load(c_ptrs, mask=c_mask, other=0.0).to(tl.float32) # BLOCK_SIZE_O * BLOCK_SIZE_P + acc_bc += tl.dot(b, c, out_dtype=tl.float32) + offs_o += BLOCK_SIZE_O + # apply the subgraph + {{ modification( + subgraph_number=0, + output_name="post_subgraph_acc_bc", + inner_mm="acc_bc" + ) | indent_except_first(2) }} + acc += tl.dot(a, post_subgraph_acc_bc, out_dtype=tl.float32) + offs_n += BLOCK_SIZE_N + + # type conversion + acc = acc.to(tl.float16) + + # store preparation + idx_m = offs_m[:, None] + idx_p = offs_p[None, :] + out_mask = (idx_m < M) & (idx_p < P) + + {{store_output(("idx_m", "idx_p"), "acc", "out_mask")}} +""", +) + + +# Note: load_ratio_left and load_ratio_right are only calculating numbers +# in the trivial subgraph case; i.e. 
(A @ (B @ C)) or ((A @ B) @ C) + + +def load_ratio_left( + M: int, N: int, O: int, P: int, m: int, n: int, o: int, p: int +) -> float: + """ + compute the ratio of estimated numbers of loads in baseline and b2bgemm + M, N, O, P are matrix sizes + m, n, o, p are block sizes + | | baseline (lower bound) | b2bgemm + | load | M * N + N * O + M * O + O * P | M / m * P / p * O / o * (o * p + N / n * (m * n + n * o)) + | store | M * O + M * P | M * P + b2bgemm is always better on stores, but for loads we need to find out beneficial cases using this function + """ + base = M * N + N * O + M * O + O * P + gemm = ( + ceildiv(M, m) + * ceildiv(P, p) + * ceildiv(O, o) + * (o * p + ceildiv(N, n) * (m * n + n * o)) + ) + return base / gemm + + +def load_ratio_right( + M: int, N: int, O: int, P: int, m: int, n: int, o: int, p: int +) -> float: + """ + compute the ratio of estimated numbers of loads in baseline and b2bgemm + M, N, O, P are matrix sizes + m, n, o, p are block sizes + | | baseline (lower bound) | b2bgemm + | load | N * O + O * P + M * N + N * P | M / m * P / p * N / n * (m * n + O / o * (n * o + o * p)) + | store | N * P + M * P | M * P + b2bgemm is always better on stores, but for loads we need to find out beneficial cases using this function + """ + base = N * O + O * P + M * N + N * P + gemm = ( + ceildiv(M, m) + * ceildiv(P, p) + * ceildiv(N, n) + * (m * n + ceildiv(O, o) * (n * o + o * p)) + ) + return base / gemm + + +# the block sizes are limited by hardware (the shared memory) +# intuitively, the optimization works when the intermediate matrix is large +# and we assign large block sizes to large dimensions +b2b_gemm_configs = [ + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_O": 16, + "BLOCK_SIZE_P": 16, + "num_stages": 4, + "num_warps": 8, + }, + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_O": 32, + "BLOCK_SIZE_P": 32, + "num_stages": 2, + "num_warps": 4, + }, + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_O": 64, + "BLOCK_SIZE_P": 64, + "num_stages": 2, + "num_warps": 4, + }, + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_O": 128, + "BLOCK_SIZE_P": 16, + "num_stages": 4, + "num_warps": 8, + }, + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_O": 128, + "BLOCK_SIZE_P": 32, + "num_stages": 2, + "num_warps": 4, + }, + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_O": 128, + "BLOCK_SIZE_P": 64, + "num_stages": 2, + "num_warps": 4, + }, + { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_O": 16, + "BLOCK_SIZE_P": 128, + "num_stages": 4, + "num_warps": 8, + }, + { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_O": 32, + "BLOCK_SIZE_P": 128, + "num_stages": 2, + "num_warps": 4, + }, + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_O": 64, + "BLOCK_SIZE_P": 128, + "num_stages": 2, + "num_warps": 4, + }, + { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_O": 16, + "BLOCK_SIZE_P": 128, + "num_stages": 4, + "num_warps": 8, + }, + { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_O": 32, + "BLOCK_SIZE_P": 128, + "num_stages": 2, + "num_warps": 4, + }, + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_O": 64, + "BLOCK_SIZE_P": 128, + "num_stages": 2, + "num_warps": 4, + }, +] + + +def is_b2b_gemm_good_on( + is_left_assoc: bool, + A_node: torch.fx.Node, + B_node: torch.fx.Node, + C_node: torch.fx.Node, +) -> bool: + """ + checks whether the sizes are good for b2b_gemm + """ + # basic checks + if not all(["val" in A_node.meta, "val" in 
B_node.meta, "val" in C_node.meta]): + return False + fake_tensors = ( + A_node.meta["val"], + B_node.meta["val"], + C_node.meta["val"], + ) # torch._subclasses.fake_tensor.FakeTensor + + A, B, C = fake_tensors + + def check_all_attr_true(objects, attr): + return all(hasattr(obj, attr) and getattr(obj, attr) for obj in objects) + + if not check_all_attr_true(fake_tensors, "is_cuda") and not check_all_attr_true( + fake_tensors, "is_xpu" + ): + return False + if not all([len(A.shape) == 2, len(B.shape) == 2, len(C.shape) == 2]): + return False + if not ((A.shape[1] == B.shape[0]) and (B.shape[1] == C.shape[0])): + return False + # size checks: we only dispatch to B2B-GEMM when the average load ratio is > 1 + M, N = A.shape + O, P = C.shape + ratios = [] + if is_left_assoc: + for config in b2b_gemm_configs: + ratio = load_ratio_left( + M, + N, + O, + P, + config["BLOCK_SIZE_M"], + config["BLOCK_SIZE_N"], + config["BLOCK_SIZE_O"], + config["BLOCK_SIZE_P"], + ) + ratios.append(ratio) + else: + for config in b2b_gemm_configs: + ratio = load_ratio_right( + M, + N, + O, + P, + config["BLOCK_SIZE_M"], + config["BLOCK_SIZE_N"], + config["BLOCK_SIZE_O"], + config["BLOCK_SIZE_P"], + ) + ratios.append(ratio) + ratios.sort(reverse=True) + average_ratio = 1.0 + for r in ratios[:3]: # top 3 choices + average_ratio *= r + average_ratio = average_ratio ** (1 / 3) + return ( + average_ratio > 1 + ) # even if average_ratio is close to 1, the number of stores is always better + + +def unoptimized_b2b_gemm( + is_left_assoc: bool, + subgraph: Subgraph, + A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + *, + out: torch.Tensor, +) -> torch.Tensor: + """ + The unoptimized version is used as a fallback when the b2b_gemm kernel is not beneficial. + """ + if is_left_assoc: + torch.mm(subgraph.graph_module(torch.mm(A, B)), C, out=out) + else: + torch.mm(A, subgraph.graph_module(torch.mm(B, C)), out=out) + return out + + +unoptimized_choice = ExternKernelChoice(unoptimized_b2b_gemm) + + +def build_subgraph_buffer( + args: list[TensorBox], + subgraph: Subgraph, +): + """ + This function is adapted from ../kernel/flex_attention.py. 
+ The goal is to take in the required args and produce the subgraph buffer + The subgraph buffer is a ComputedBuffer that will be inlined into the triton template + + Args: + args: The args that are passed into the subgraph + subgraph: The Subgraph ir for which to produce the output node + """ + cnt = 0 + env = {} + for node in subgraph.graph_module.graph.nodes: + if node.op == "placeholder": + env[node] = args[cnt] + cnt += 1 + elif node.op == "call_function": + # For call_function we use the default lowerings and pass in the + # already created TensorBoxes as args + args, kwargs = tree_map( + lambda x: env[x] if x in env else x, (node.args, node.kwargs) + ) + env[node] = lowerings[node.target](*args, **kwargs) + elif node.op == "output": + + def convert_output_node_to_buffer(output): + if output is None: + return None + output_node = output + output_buffer = env[output_node] + assert isinstance(output_buffer, TensorBox), ( + "The output node for B2B-GEMM's subgraph must be a TensorBox, but got: ", + type(output_buffer), + ) + assert isinstance(output_buffer.data, StorageBox), ( + "The output node for B2B-GEMM's subgraph must be a StorageBox, but got: ", + type(output_buffer), + ) + subgraph_buffer = ComputedBuffer( + name=None, + layout=FlexibleLayout( + device=output_buffer.data.get_device(), + dtype=output_buffer.data.get_dtype(), + size=output_buffer.data.get_size(), + ), + data=output_buffer.data.data, # type: ignore[arg-type] + ) + return subgraph_buffer + + # node.args[0] should be a single element representing the output of the subgraph + return tree_map(convert_output_node_to_buffer, node.args[0]) + + raise ValueError("B2B-GEMM was passed a subgraph with no output node!") + + +def create_placeholder( + name: str, dtype: torch.dtype, device: torch.device +) -> TensorBox: + """ + Creates a placeholder input buffers for producing subgraph_output + """ + input_buffer = InputBuffer(name=name, layout=FixedLayout(device, dtype, [], [])) + return TensorBox.create(input_buffer) + + +def tuned_b2b_gemm( + is_left_assoc: bool, + subgraph: Subgraph, + A: torch._inductor.ir.TensorBox, + B: torch._inductor.ir.TensorBox, + C: torch._inductor.ir.TensorBox, + *, + layout=None, +) -> torch._inductor.ir.TensorBox: + # call .realize() to get rid of Pointwise + A.realize() + B.realize() + C.realize() + layout = FixedLayout( + A.get_device_or_error(), + A.get_dtype(), + [A.shape[0], C.shape[1]], # type: ignore[index] + ) + subgraph_buffer = build_subgraph_buffer( + [create_placeholder("inner_mm", A.get_dtype(), A.get_device_or_error())], + subgraph, + ) + choices: list[TritonTemplateCaller] = [] + for config in b2b_gemm_configs: + if is_left_assoc: + b2b_gemm_left_template.maybe_append_choice( + choices, + input_nodes=(A, B, C), + layout=layout, + subgraphs=[subgraph_buffer], + **config, + ) + else: + b2b_gemm_right_template.maybe_append_choice( + choices, + input_nodes=(A, B, C), + layout=layout, + subgraphs=[subgraph_buffer], + **config, + ) + # add the unoptimized choice to mitigate performance degradation + choices.append( + unoptimized_choice.bind( + (A, B, C), layout, is_left_assoc=is_left_assoc, subgraph=subgraph + ) + ) + # autotune + return autotune_select_algorithm("b2b_gemm", choices, [A, B, C], layout) + + +# match the inner mm of a potential b2b_gemm +@register_graph_pattern( + CallFunction(torch.ops.aten.mm, Arg(), Arg()), + pass_dict=B2B_GEMM_PASS, +) +def b2b_gemm_handler(match: Match, mat1: torch.fx.Node, mat2: torch.fx.Node) -> None: + # match.args: list[torch.fx.Node] + + def 
is_pointwise_node(node: torch.fx.Node) -> bool: + return ( + node.op == "call_function" + and isinstance(node.target, torch._ops.OpOverload) + and (torch.Tag.pointwise in node.target.tags) + ) + + def is_mm(node: torch.fx.Node) -> bool: + return node.target == torch.ops.aten.mm.default + + # the inner MM + inner_mm = match.nodes[-1] + + # find the (candidate) outer MM, which will be re-checked below to ensure every path reaches it + # In a real (A @ f(B @ C)), every path starting from (B @ C) must reach (A @ _). + outer_mm = None + node = inner_mm + while len(node.users) > 0: + node = next(iter(node.users)) + if is_mm(node): + outer_mm = node + break + elif is_pointwise_node(node): + continue + else: + break + if not outer_mm: + return + + # find the unique input node for outer_mm representing f(B @ C) in (A @ f(B @ C)) + # we call it the "f_node" + # when the pattern is simply (A @ (B @ C)), f_node is just inner_mm + f_node = inner_mm + while next(iter(f_node.users)) is not outer_mm: + f_node = next(iter(f_node.users)) + + def all_reach_via_pointwise_with_no_other_inputs( + src: torch.fx.Node, + dst: torch.fx.Node, + ) -> tuple[bool, OrderedSet[torch.fx.Node]]: + """ + check whether every user path from src reaches dst via pointwise nodes, + with no other input nodes for the intermediates and dst; + return + (1) the Boolean value + (2) the subgraph node set including src and dst (which only makes sense when the Boolean value is True) + """ + visited = OrderedSet[torch.fx.Node]() + input_counter: dict[torch.fx.Node, int] = {} + + all_reachable = True + queue = deque([src]) + while queue: + node = queue.popleft() + if node not in visited: + if node is dst: + visited.add(node) + elif (node is src) or is_pointwise_node(node): + for user in node.users.keys(): + # for nodes other than dst, bookkeep their users' input counts + if user not in input_counter: + input_counter[user] = len(user.all_input_nodes) + input_counter[user] -= 1 + # continue BFS + queue.append(user) + visited.add(node) + else: + all_reachable = False + break + + return ( + all_reachable and all(count == 0 for count in input_counter.values()), + visited, + ) + + # check inner_mm reaches f_node on every user path via pointwise nodes with no outside input_nodes + ok, subgraph_node_set = all_reach_via_pointwise_with_no_other_inputs( + inner_mm, f_node + ) + if not ok: + return + + # check inner_mm's inputs and f_node's outputs + if not (len(inner_mm.all_input_nodes) == 2 and len(f_node.users) == 1): + return + + # at this point, the nodes between inner_mm and f_node (both included) + # are all used internally inside (A @ subgraph(B @ C)) + # i.e. 
they neither have other users nor have other inputs + + # original graph and module + graph, module = inner_mm.graph, inner_mm.graph.owning_module + + # construct the new (sub)graph + subgraph_node_list: list[ + torch.fx.Node + ] = [] # ordered list of nodes used for node removal later + new_graph: torch.fx.Graph = torch.fx.Graph() + node_remapping: dict[torch.fx.Node, torch.fx.Node] = {} + new_input_anchor: torch.fx.Node # inner_mm, to be changed to an input node + new_output_anchor: torch.fx.Node # f_node, to be used to construct an output node + new_input_node: torch.fx.Node + new_output_node: torch.fx.Node + for node in graph.nodes: # preserve the order of nodes + if node in subgraph_node_set: + subgraph_node_list.append(node) + new_node = new_graph.node_copy( + node, lambda x: node_remapping[x] if x in node_remapping else x + ) + node_remapping[node] = new_node + if node is inner_mm: + new_input_anchor = new_node + if node is f_node: + new_output_anchor = new_node + if new_input_anchor is not new_output_anchor: # subgraph is non-trivial + # update the input node + with new_graph.inserting_before(new_input_anchor): + new_input_node = new_graph.placeholder(name="subgraph_input") + new_input_node.meta.update(new_input_anchor.meta) + new_input_anchor.replace_all_uses_with(new_input_node) + new_graph.erase_node(new_input_anchor) + # add the output node + new_output_node = new_graph.output(new_output_anchor) + new_output_node.meta.update(new_output_anchor.meta) + else: # subgraph is trivial, e.g. (A @ (B @ C)) + # update the input node + with new_graph.inserting_before(new_input_anchor): + new_input_node = new_graph.placeholder(name="subgraph_input") + new_input_node.meta.update(new_input_anchor.meta) + new_input_anchor.replace_all_uses_with(new_input_node) + new_graph.erase_node(new_input_anchor) + # update the output node (don't use new_output_anchor since it has been erased) + new_output_node = new_graph.output(new_input_node) + new_output_node.meta.update(new_input_node.meta) + new_graph.lint() + + # construct the subgraph + subgraph = Subgraph( + name="subgraph", graph_module=torch.fx.GraphModule(module, new_graph) + ) + + # two cases + # (1) (subgraph(A @ B) @ C), called "left_assoc" + # (2) (A @ subgraph(B @ C)), called "right_assoc" + is_left_assoc = outer_mm.args[0] is f_node + + # find the nodes A, B, C and check the sizes + A: torch.fx.Node + B: torch.fx.Node + C: torch.fx.Node + if is_left_assoc: + A = inner_mm.args[0] # type: ignore[assignment] + B = inner_mm.args[1] # type: ignore[assignment] + C = outer_mm.args[1] # type: ignore[assignment] + else: + A = outer_mm.args[0] # type: ignore[assignment] + B = inner_mm.args[0] # type: ignore[assignment] + C = inner_mm.args[1] # type: ignore[assignment] + if not is_b2b_gemm_good_on(is_left_assoc, A, B, C): + return + + # finally update the original graph + counters["inductor"]["b2b_gemm"] += 1 + graph = match.graph + with graph.inserting_before(outer_mm): + function = functools.partial(tuned_b2b_gemm, is_left_assoc, subgraph) + function.__name__ = tuned_b2b_gemm.__name__ # type: ignore[attr-defined] + function._inductor_lowering_function = True # type: ignore[attr-defined] + replacement: torch.fx.Node = graph.call_function( + function, + (A, B, C), + match.kwargs, + ) + replacement.meta.update(outer_mm.meta) + outer_mm.replace_all_uses_with(replacement) + # erase unnecessary nodes + graph.erase_node(outer_mm) + for node in reversed(subgraph_node_list): + graph.erase_node(node) + graph.lint() diff --git 
a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/binary_folding.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/binary_folding.py new file mode 100644 index 0000000000000000000000000000000000000000..d2ad3e1c8f91903a65b6dc5c17fa42750badb101 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/binary_folding.py @@ -0,0 +1,503 @@ +# mypy: allow-untyped-defs +import functools +import itertools + +import torch + +from ..._dynamo.utils import counters +from .. import config +from ..pattern_matcher import Arg, CallFunction, KeywordArg +from .freezing_patterns import register_binary_folding_pattern + + +aten = torch.ops.aten +prims = torch.ops.prims + + +def mark_mixed_dtype(computation_node): + computation_node_dtype = computation_node.meta["val"].dtype + if computation_node_dtype not in (torch.float16, torch.bfloat16): + return + + if not len(computation_node.users) == 1: + return + + computation_node_user = next(iter(computation_node.users.keys())) + if not isinstance(computation_node_user.meta["val"], torch.Tensor): + return + + if not computation_node_user.meta["val"].dtype == torch.float32: + return + + while computation_node_user.target in _binary_ops: + if not len(computation_node_user.users) == 1: + return + + computation_node_user = next(iter(computation_node_user.users.keys())) + + if computation_node_user.target != prims.convert_element_type.default: + return + + computation_node.meta["_allow_mixed_dtype_folding"] = computation_node_dtype + + +def mark_mixed_dtype_allowed_computation_ops(gm): + """ + Mark convolutions/linear which we will binary fold even with mixed precision constants. We constant fold in the higher precision + for better accuracy and then recover the original precision after. + """ + for target in [aten.convolution.default, aten.addmm.default, aten.mm.default]: + for node in gm.graph.find_nodes(op="call_function", target=target): + mark_mixed_dtype(node) + + +def recover_original_precision_folded_computation_ops(gm): + """ + After binary folding conv/linear weights and biases to a higher dtype, recover the original precision they were in. 
+ """ + graph = gm.graph + for target, idx in ( + (aten.convolution.default, (1, 2)), + (aten.addmm.default, (0, 2)), + (aten.mm.default, (1,)), + ): + for node in graph.find_nodes(op="call_function", target=target): + orig_dtype = node.meta.get("_allow_mixed_dtype_folding", None) + if orig_dtype is None: + continue + + with graph.inserting_before(node): + for i in idx: + old_input = node.args[i] + if old_input is None: + continue + + new_input = graph.create_node( + "call_function", + prims.convert_element_type.default, + (old_input, orig_dtype), + ) + node.replace_input_with(old_input, new_input) + + +_binary_ops = [aten.add.Tensor, aten.sub.Tensor, aten.mul.Tensor, aten.div.Tensor] + + +@functools.cache +def binary_folding_init(): + _conv_args = [Arg() for _ in range(9)] + _addmm_args = [Arg() for _ in range(3)] + _mm_args = [Arg() for _ in range(2)] + _computation_ops = [aten.convolution.default, aten.addmm.default, aten.mm.default] + _computation_calls = [ + CallFunction(aten.convolution.default, *_conv_args, _users=1), + CallFunction(aten.addmm.default, *_addmm_args, _users=1), + CallFunction( + aten.reshape.default, + CallFunction(aten.addmm.default, *_addmm_args, _users=1), + Arg(), + _users=1, + ), + CallFunction(aten.mm.default, *_mm_args, _users=1), + CallFunction( + aten.reshape.default, + CallFunction(aten.mm.default, *_mm_args, _users=1), + Arg(), + _users=1, + ), + ] + + """ + In order to fuse add/sub/mul/div with conv/linear, the dimensions of its + constant tensor must satisfy the following: + - with resizing, broadcast to w/ weight/bias tensor shape + - broadcast to the conv/linear output shape + It needs to have a shape that can resize to weight/bias + tensor shape because we need to run the op with the conv/linear + weights/bias without changing their sizes. + It needs to broadcast to the conv/linear output shape so that we do + accidentally change the shape of op output by pre-fusing it + compared to eager. + The only dimension value shared by weight, bias, and conv/linear output + is they all contain a dim with value = channels-out. In the + conv/linear output tensor, this is in the second dimension, + so the pointwise op tensor may have a second dimension of + value == channels-out, but all the other dimensions have to be 1 + """ + + def _op_not_broadcasting_with_conv(weight_tensor, other_tensor): + # According to opDoesNotBroadCastWithConv of frozen_conv_folding.cpp + weight_shape = weight_tensor.shape + other_shape = other_tensor.shape + if len(weight_shape) < len(other_shape): + return False + if len(weight_shape) == len(other_shape) + 1: + # weight shape is [o, i, *], other_shape is [o, 1...]. 
+ for i in reversed(range(len(other_shape))): + if i == 0 and weight_shape[0] == other_shape[i]: + continue + if other_shape[i] != 1: + return False + else: + # weight shape is [o, i, *], other_shape is [1, i, *] + for i in reversed(range(len(other_shape))): + if i == 1 and weight_shape[0] == other_shape[i]: + continue + if other_shape[i] != 1: + return False + return True + + def _op_not_broadcasting_with_linear(weight_tensor, other_tensor, has_reshape): + weight_shape = weight_tensor.shape + other_shape = other_tensor.shape + other_shapes = [ + torch.Size( + [ + weight_shape[1], + ] + ), + torch.Size([1, weight_shape[1]]), + torch.Size( + [ + 1, + ] + ), + torch.Size([1, 1]), + ] + if has_reshape: + other_shapes.extend( + [ + torch.Size([1, 1, weight_shape[1]]), + torch.Size([1, 1, 1]), + ] + ) + return other_shape in other_shapes + + def _check_conv_and_broadcast_op(conv_node, other): + # According to checkConvAndBroadcastingOpPreConditions of frozen_conv_folding.cpp. + # conv.weight + if conv_node.args[1].op != "get_attr": + return False + # conv.bias + if conv_node.args[1] is not None and conv_node.args[1].op != "get_attr": + return False + if ( + not isinstance(other, int) + and not isinstance(other, float) + and other.op != "get_attr" + ): + return False + + if not len(conv_node.args[1].users) == 1: + return False + + weight_meta_value = conv_node.args[1].meta.get("val") + if weight_meta_value is None: + return False + # Avoid fusing op that causes type promotion + # restricting to float avoids int/float difficulties with scalar overload + if not weight_meta_value.is_floating_point(): + return False + if isinstance(other, torch.fx.Node) and other.op == "get_attr": + other_meta_value = other.meta.get("val") + if not other_meta_value.is_floating_point(): # type: ignore[union-attr] + return False + if ( + torch.promote_types(other_meta_value.dtype, weight_meta_value.dtype) # type: ignore[union-attr] + != weight_meta_value.dtype + ): + if not conv_node.meta.get("_allow_mixed_dtype_folding", False): + return False + + if ( + other_meta_value.dtype != torch.float # type: ignore[union-attr] + and weight_meta_value.dtype not in (torch.float16, torch.bfloat16) + ): + return False + + if not _op_not_broadcasting_with_conv(weight_meta_value, other_meta_value): + return False + elif not isinstance(other, float): + return False + + return True + + def _check_linear_and_broadcast_op(linear_node, other, has_reshape): + weight_node = ( + linear_node.args[2] + if linear_node.target is aten.addmm.default + else linear_node.args[1] + ) + bias_node = ( + linear_node.args[0] if linear_node.target is aten.addmm.default else None + ) + if weight_node.op != "get_attr": + return False + if bias_node is not None and bias_node.op != "get_attr": + return False + if ( + not isinstance(other, int) + and not isinstance(other, float) + and other.op != "get_attr" + ): + return False + + if not len(weight_node.users) == 1: + return False + + weight_meta_value = weight_node.meta.get("val") + if weight_meta_value is None: + return False + # Avoid fusing op that causes type promotion + # restricting to float avoids int/float difficulties with scalar overload + if not weight_meta_value.is_floating_point(): + return False + if isinstance(other, torch.fx.Node) and other.op == "get_attr": + other_meta_value = other.meta.get("val") + if not other_meta_value.is_floating_point(): # type: ignore[union-attr] + return False + if ( + torch.promote_types(other_meta_value.dtype, weight_meta_value.dtype) # type: ignore[union-attr] + 
!= weight_meta_value.dtype + ): + if not linear_node.meta.get("_allow_mixed_dtype_folding", False): + return False + + if ( + other_meta_value.dtype != torch.float # type: ignore[union-attr] + and weight_meta_value.dtype not in (torch.float16, torch.bfloat16) + ): + return False + + if not _op_not_broadcasting_with_linear( + weight_meta_value, other_meta_value, has_reshape + ): + return False + elif not isinstance(other, float): + return False + + return True + + def _is_foldable_pattern(match): + binary_node = match.output_node() + has_reshape = False + if binary_node.args[0].target in _computation_ops: + computation_node = binary_node.args[0] + other = binary_node.args[1] + elif binary_node.args[0].target == aten.reshape.default: + computation_node = binary_node.args[0].args[0] + other = binary_node.args[1] + has_reshape = True + elif binary_node.args[1].target in _computation_ops: + computation_node = binary_node.args[1] + other = binary_node.args[0] + else: + computation_node = binary_node.args[1].args[0] + other = binary_node.args[0] + has_reshape = False + if computation_node.target == aten.convolution.default: + return _check_conv_and_broadcast_op(computation_node, other) + elif computation_node.target in [aten.addmm.default, aten.mm.default]: + return ( + config.enable_linear_binary_folding + and _check_linear_and_broadcast_op(computation_node, other, has_reshape) + ) + + return False + + def resize_scalar_or_tensor_to_shape(graph, other, shape, weight): + if isinstance(other, float): + with torch.utils._python_dispatch._disable_current_modes(): + other_tensor = torch.tensor( + other, dtype=weight.dtype, device=weight.device + ) + graph.owning_module.register_buffer("other_tensor", other_tensor) + res = graph.create_node("get_attr", "other_tensor") + res = graph.create_node( + "call_function", + aten.reshape.default, + (res, (1,)), + ) + res = graph.create_node( + "call_function", + aten.expand.default, + (res, shape), + ) + elif other.meta.get("val").numel() == 1: + # expand errors if the shape input has less # dims than the tensor input + res = graph.create_node( + "call_function", + aten.reshape.default, + (other, (1,)), + ) + res = graph.create_node( + "call_function", + aten.expand.default, + (res, shape), + ) + else: + res = graph.create_node( + "call_function", + aten.reshape.default, + (other, shape), + ) + return res + + def _create_new_conv_node(graph, conv_node, binary_node, other): + assert conv_node.target == aten.convolution.default + conv_args = list(conv_node.args) + weight_meta_value = conv_node.args[1].meta.get("val") + bias = conv_args[2] + if binary_node.target in [aten.add.Tensor, aten.sub.Tensor]: + other_reshape = resize_scalar_or_tensor_to_shape( + graph, + other, + (weight_meta_value.size(0),), + weight_meta_value, + ) + new_bias = graph.create_node( + "call_function", + binary_node.target, + (0 if bias is None else bias, other_reshape), + ) + conv_args[2] = new_bias + else: + assert binary_node.target in [aten.mul.Tensor, aten.div.Tensor] + weight_broadcast_shape = [1 for _ in range(len(weight_meta_value.shape))] + weight_broadcast_shape[0] = weight_meta_value.size(0) + other_reshape1 = resize_scalar_or_tensor_to_shape( + graph, + other, + tuple(weight_broadcast_shape), + weight_meta_value, + ) + new_weight = graph.create_node( + "call_function", binary_node.target, (conv_args[1], other_reshape1) + ) + new_weight.meta.update(conv_args[1].meta) + conv_args[1] = new_weight + if bias is not None: + other_reshape = resize_scalar_or_tensor_to_shape( + graph, + 
other, + (weight_meta_value.size(0),), + weight_meta_value, + ) + new_bias = graph.create_node( + "call_function", binary_node.target, (bias, other_reshape) + ) + new_bias.meta.update(bias.meta) + conv_args[2] = new_bias + return graph.create_node("call_function", conv_node.target, tuple(conv_args)) + + def _create_new_linear_node(graph, linear_node, binary_node, other): + assert linear_node.target in [aten.addmm.default, aten.mm.default] + input_node = ( + linear_node.args[1] + if linear_node.target is aten.addmm.default + else linear_node.args[0] + ) + weight_node = ( + linear_node.args[2] + if linear_node.target is aten.addmm.default + else linear_node.args[1] + ) + bias_node = ( + linear_node.args[0] if linear_node.target is aten.addmm.default else None + ) + weight_meta_value = weight_node.meta.get("val") + if binary_node.target in [aten.add.Tensor, aten.sub.Tensor]: + other_reshape = resize_scalar_or_tensor_to_shape( + graph, + other, + (weight_meta_value.size(1),), + weight_meta_value, + ) + new_bias_node = graph.create_node( + "call_function", + binary_node.target, + (0 if bias_node is None else bias_node, other_reshape), + ) + return graph.create_node( + "call_function", + aten.addmm.default, + (new_bias_node, input_node, weight_node), + ) + else: + assert binary_node.target in [aten.mul.Tensor, aten.div.Tensor] + weight_broadcast_shape = [1, weight_meta_value.size(1)] + other_reshape1 = resize_scalar_or_tensor_to_shape( + graph, + other, + tuple(weight_broadcast_shape), + weight_meta_value, + ) + new_weight_node = graph.create_node( + "call_function", binary_node.target, (weight_node, other_reshape1) + ) + new_weight_node.meta.update(weight_node.meta) + if bias_node is not None: + other_reshape = resize_scalar_or_tensor_to_shape( + graph, + other, + (weight_meta_value.size(1),), + weight_meta_value, + ) + new_bias_node = graph.create_node( + "call_function", binary_node.target, (bias_node, other_reshape) + ) + new_bias_node.meta.update(bias_node.meta) + return graph.create_node( + "call_function", + linear_node.target, + (new_bias_node, input_node, new_weight_node), + ) + else: + return graph.create_node( + "call_function", linear_node.target, (input_node, new_weight_node) + ) + + for _computation_call, binary_op in itertools.product( + _computation_calls, _binary_ops + ): + + @register_binary_folding_pattern( + CallFunction(binary_op, _computation_call, KeywordArg("other")), + extra_check=_is_foldable_pattern, + ) + def folded_op(match, *args, **kwargs): + counters["inductor"]["binary_folding"] += 1 + other = kwargs.get("other") + binary_node = match.output_node() + reshape_node = None + if binary_node.args[0].target in _computation_ops: + computation_node = binary_node.args[0] + elif binary_node.args[0].target == aten.reshape.default: + computation_node = binary_node.args[0].args[0] + reshape_node = binary_node.args[0] + elif binary_node.args[1].target in _computation_ops: + computation_node = binary_node.args[1] + else: + computation_node = binary_node.args[1].args[0] + reshape_node = binary_node.args[1] + graph = match.graph + with graph.inserting_before(reshape_node if reshape_node else binary_node): + assert computation_node.target in _computation_ops + if computation_node.target == aten.convolution.default: + counters["inductor"]["binary_folding_conv"] += 1 + new_computation_node = _create_new_conv_node( + graph, computation_node, binary_node, other + ) + else: + new_computation_node = _create_new_linear_node( + graph, computation_node, binary_node, other + ) + 
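The mul/div branch of `_create_new_conv_node` relies on a per-output-channel scale distributing over a convolution's weight and bias. A minimal standalone sketch of that identity follows; the tensor shapes, the scale, and the tolerances are illustrative only and not taken from the pass itself:

```python
# Sketch of the identity behind mul/div binary folding into a convolution:
# scaling each output channel of conv(x, w, b) equals running the conv with
# per-channel-scaled weight and bias.
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8, 8)
w = torch.randn(4, 3, 3, 3)
b = torch.randn(4)
scale = torch.randn(4)                      # one scalar per output channel

eager = F.conv2d(x, w, b) * scale.view(1, -1, 1, 1)

folded_w = w * scale.view(-1, 1, 1, 1)      # scale each output-channel filter
folded_b = b * scale                        # and the bias
folded = F.conv2d(x, folded_w, folded_b)

torch.testing.assert_close(eager, folded, rtol=1e-4, atol=1e-4)
```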
new_computation_node.meta.update(computation_node.meta) + if reshape_node: + assert reshape_node.target == aten.reshape.default + computation_node.replace_all_uses_with(new_computation_node) + binary_node.replace_all_uses_with(reshape_node) + else: + binary_node.replace_all_uses_with(new_computation_node) + graph.erase_node(binary_node) + graph.erase_node(computation_node) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/ddp_fusion.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/ddp_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..ccea7d7e70af503893c0b3ce289f34e1ffae2778 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/ddp_fusion.py @@ -0,0 +1,586 @@ +# Owner(s): ["oncall: distributed"] +import collections +import inspect +import logging +import math +import operator +from collections.abc import Generator +from dataclasses import dataclass +from functools import partial +from typing import Any, Callable, cast, Optional, Union + +import torch +import torch.fx as fx +from torch._dynamo.utils import counters +from torch.fx.passes.graph_transform_observer import GraphTransformObserver +from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata +from torch.utils._ordered_set import OrderedSet +from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten + +from ..fx_utils import get_fake_args_kwargs +from ..virtualized import V + + +aten = torch.ops.aten +logger: logging.Logger = logging.getLogger("comm_fusion") + + +def move_block_after(block: list[fx.Node], target_node: fx.Node) -> None: + for node in block: + target_node.append(node) + target_node = node + + +def move_block_before(block: list[fx.Node], target_node: fx.Node) -> None: + for node in block: + target_node.prepend(node) + target_node = node + + +def call_function( + graph: fx.Graph, + target: Union[str, Callable[..., Any]], + args: Optional[tuple[fx.node.Argument, ...]] = None, + kwargs: Optional[dict[str, fx.node.Argument]] = None, +) -> fx.Node: + # We accept target as a str to avoid typing error as the type of + # a node.target is Union[str, Callable[..., Any]]. + # This also allows us to avoid writing check for every call. + if isinstance(target, str): + raise RuntimeError(f"Call function should not get a str target {target=}") + node = graph.call_function(target, args, kwargs) + _, args, kwargs = get_fake_args_kwargs(node) + with V.fake_mode: + node.meta["val"] = target(*args, **kwargs) + # node.meta["val"] may be a container. So we use tree_map here + # to recursively extract the tensor metadata. + node.meta["tensor_meta"] = tree_map( + _extract_tensor_metadata, (node.meta["val"],) + )[0] + return node + + +@dataclass(unsafe_hash=True) +class CommBlock: + shape: Union[torch.Size, list[torch.Size]] + node_list: list[fx.Node] + inputs: list[fx.Node] + wait_nodes: list[fx.Node] + comm_node: fx.Node + outputs: OrderedSet[fx.Node] + + +def get_comm_block(comm_node: fx.Node) -> Optional[CommBlock]: + """ + Given a collective node (e.g., allreduce), find out all the nodes belong to + this communication. + + Args: + comm_node(fx.Node): The target communication/collective node. + Returns: + The CommBlock that encapsulates the related nodes (e.g., wait_node) of + the given comm_node. 
+ """ + node_list = [] + wait_nodes = [] + inputs, _ = tree_flatten((comm_node.args, comm_node.kwargs)) + input_nodes = [inp for inp in inputs if isinstance(inp, fx.Node)] + # If the users of the wait node are following items, we consinder them + # to be a part of the output. + intermediate_outputs = ("split", "reshape", "getitem", "detach", "alias") + + first_user = next(iter(comm_node.users)) + if ( + len(comm_node.users) == 1 + and first_user.target == torch.ops._c10d_functional.wait_tensor.default + ): + # Collective with only one output + node_list = [comm_node, first_user] + wait_nodes.append(first_user) + elif len(comm_node.users) > 1 and first_user.target == operator.getitem: + # Collective with only more than one output + node_list.append(comm_node) + for user in comm_node.users: + if user.target != operator.getitem: + return None + if len(user.users) != 1: + return None + wait_node = next(iter(user.users)) + if wait_node.target != torch.ops._c10d_functional.wait_tensor.default: + return None + wait_nodes.append(wait_node) + node_list.append(user) + node_list.extend(wait_nodes) + else: + return None + + # Identify all the outputs of this collective block. + outputs = OrderedSet[fx.Node]() + nodes = collections.deque(wait_nodes) + while nodes: + node = nodes.popleft() + for user in node.users: + if isinstance(user, fx.Node) and user.name.startswith(intermediate_outputs): + nodes.append(user) + node_list.append(user) + else: + outputs.add(node) + break + + tensor_meta = input_nodes[0].meta["tensor_meta"] + shape: Union[torch.Size, list[torch.Size]] + if isinstance(tensor_meta, TensorMetadata): + shape = tensor_meta.shape + elif isinstance(tensor_meta, (list, tuple)): + shape = [tm.shape for tm in tensor_meta] + else: + logger.warning("Unexpected type of tensor_meta %s", type(tensor_meta)) + return None + + return CommBlock( + shape=shape, + node_list=node_list, + wait_nodes=wait_nodes, + comm_node=comm_node, + inputs=input_nodes, + outputs=outputs, + ) + + +def get_all_comm_blocks( + graph: fx.Graph, + comm_ops: tuple[torch._ops.OpOverload, ...], + comm_filter: Optional[Callable[..., bool]] = None, +) -> list[CommBlock]: + if comm_filter is None: + + def always_true(comm_block: CommBlock) -> bool: + return True + + comm_filter = always_true + + blocks = [] + for node in graph.nodes: + if node.target not in comm_ops: + continue + comm_block = get_comm_block(node) + if comm_block is not None and comm_filter(comm_block): + blocks.append(comm_block) + return blocks + + +def _fuse_allreduce_by_concat( + graph: fx.Graph, + last_input_node: fx.Node, + all_input_nodes: list[fx.Node], + last_comm_block: CommBlock, +) -> CommBlock: + """Given a list of inputs in order, create a fused allreduce using concat.""" + # Flatten all the inputs to the all_reduce nodes. + with graph.inserting_after(last_input_node): + cat_inputs = [] + for input_node in all_input_nodes: + assert isinstance(input_node.args[0], fx.Node) + input_node = input_node.args[0] + cat_inputs.append( + call_function(graph, aten.flatten.using_ints, (input_node,)) + ) + + # Concat all the flattened nodes. + with graph.inserting_after(cat_inputs[0]): + cat_node = call_function(graph, aten.cat, (cat_inputs,)) + + # Insert the fused div node and remove the input div nodes. + # This is an optimization and is not mandatory for fusion. 
+ divisors = [div.args[1] for div in all_input_nodes] + assert all(divisor == divisors[0] for divisor in divisors) + with graph.inserting_after(cat_node): + div_node = call_function(graph, last_input_node.target, (cat_node, divisors[0])) + + # Create a new Comm/all_reduce node. + last_comm_node = last_comm_block.comm_node + last_wait_node = last_comm_block.wait_nodes[0] + with graph.inserting_after(div_node): + flatten_args, spec = tree_flatten((last_comm_node.args, last_comm_node.kwargs)) + flatten_args[0] = div_node + args, kwargs = tree_unflatten(flatten_args, spec) + fused_comm_node = call_function(graph, last_comm_node.target, args, kwargs) + + # Create a new Wait node. + with graph.inserting_after(fused_comm_node): + flatten_args, spec = tree_flatten((last_wait_node.args, last_wait_node.kwargs)) + flatten_args[0] = fused_comm_node + args, kwargs = tree_unflatten(flatten_args, spec) + fused_wait_node = call_function(graph, last_wait_node.target, args, kwargs) + + # Move the fused all_reduce and its args to right after the input node + nodes_to_move = cat_inputs + [cat_node, div_node, fused_comm_node, fused_wait_node] + move_block_after(nodes_to_move, last_input_node) + + return CommBlock( + shape=cast(TensorMetadata, cat_node.meta.get("tensor_meta")).shape, + node_list=[fused_comm_node, fused_wait_node], + wait_nodes=[fused_wait_node], + comm_node=fused_comm_node, + inputs=[div_node], + outputs=OrderedSet([fused_wait_node]), + ) + + +def _fuse_with_coalesced_op( + graph: fx.Graph, + last_input_node: fx.Node, + all_input_nodes: list[fx.Node], + last_comm_block: CommBlock, +) -> CommBlock: + """Given a list of inputs in order, create a fused allreduce by coalesced.""" + last_comm_node = last_comm_block.comm_node + last_wait_node = last_comm_block.wait_nodes[0] + + # Insert the fused div node and remove the input div nodes. + # This is an optimization and is not mandatory for fusion. + dividends = [div.args[0] for div in all_input_nodes] + divisors = [div.args[1] for div in all_input_nodes] + assert all(divisor == divisors[0] for divisor in divisors) + with graph.inserting_before(last_input_node): + last_input_node = call_function( + graph, aten._foreach_div.Scalar, (dividends, divisors[0]) + ) + input_node = last_input_node + + # Create a new Comm/all_reduce_coalesced node. + with graph.inserting_after(last_comm_node): + flatten_args, spec = tree_flatten((last_comm_node.args, last_comm_node.kwargs)) + flatten_args[0] = input_node + args, kwargs = tree_unflatten(flatten_args, spec) + fused_comm_node = call_function( + graph, torch.ops._c10d_functional.all_reduce_coalesced.default, args, kwargs + ) + + # Create a new wait node. 
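The pre-division inserted here uses `aten._foreach_div` so that every gradient in the bucket is scaled in one horizontally fused call instead of one `div` per tensor. A rough standalone illustration via the corresponding `torch._foreach_div` helper; `world_size` is a made-up constant here, whereas the real divisor comes from the matched div nodes:

```python
# Horizontally fused division of a list of gradients by a shared scalar.
import torch

world_size = 8
grads = [torch.randn(10), torch.randn(3, 3)]

fused = torch._foreach_div(grads, world_size)   # one foreach call
reference = [g / world_size for g in grads]

for f, r in zip(fused, reference):
    torch.testing.assert_close(f, r)
```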
+ getitem_nodes = [] + wait_nodes = [] + flatten_args, spec = tree_flatten((last_wait_node.args, last_wait_node.kwargs)) + for idx in range(len(all_input_nodes)): + with graph.inserting_after(fused_comm_node): + gi_node = call_function(graph, operator.getitem, (fused_comm_node, idx)) + getitem_nodes.append(gi_node) + flatten_args[0] = gi_node + args, kwargs = tree_unflatten(flatten_args, spec) + with graph.inserting_after(gi_node): + wait_nodes.append(call_function(graph, last_wait_node.target, args, kwargs)) + + # Move the new all_reduce_coalesced and its args to right after the input node + nodes_to_move = [fused_comm_node] + getitem_nodes + wait_nodes + move_block_after(nodes_to_move, last_input_node) + + return CommBlock( + shape=[ + tm.shape + for tm in cast( + list[TensorMetadata], fused_comm_node.meta.get("tensor_meta") + ) + ], + node_list=[fused_comm_node] + getitem_nodes + wait_nodes, + wait_nodes=wait_nodes, + comm_node=fused_comm_node, + inputs=[input_node], + outputs=OrderedSet(wait_nodes), + ) + + +def _scatter_fused_allreduce_waits( + graph: fx.Graph, + fused_comm_block: CommBlock, + orig_comm_blocks: list[CommBlock], + node_indices: dict[fx.Node, int], + split_and_reshape: bool = True, +) -> None: + """ + Scatters the result of the fused communication node to the original users. + If the fused method is concat splitting the output and reshape will be inserted, + before inserting getitem. Otherwise getitem will be used as the users of the + wait node. + """ + + # Before we mass up the order, we need to get the index of the last wait node + # in orig_comm_blocks. This index will be later used to determine what users + # nodes need to be move to maintain a correct topological sort order. + last_wait_node_idx = 0 + for node in graph.nodes: + last_wait_node_idx = max( + node_indices.get(node, last_wait_node_idx), last_wait_node_idx + ) + if node == orig_comm_blocks[-1].wait_nodes[0]: + break + + if split_and_reshape: + fused_wait_node = fused_comm_block.wait_nodes[0] + with graph.inserting_after(fused_wait_node): + split_node = call_function( + graph, + aten.split, + ( + fused_wait_node, + [math.prod(cast(list[int], cb.shape)) for cb in orig_comm_blocks], + ), + ) + with graph.inserting_after(split_node): + fused_outputs = [] + for idx, comm_block in enumerate(orig_comm_blocks): + split_idx_node = call_function( + graph, operator.getitem, (split_node, idx) + ) + with graph.inserting_after(split_idx_node): + fused_outputs.append( + call_function( + graph, aten.reshape, (split_idx_node, comm_block.shape) + ) + ) + else: + fused_outputs = fused_comm_block.wait_nodes + + # Scatter the fused outputs. + incorrect_order_nodes = [] + for comm_block, fused_output in zip(orig_comm_blocks, fused_outputs): + # Some descendant users of the orig_comm_blocks may be scheduled before + # the fused all_reduce. For example, the user nodes of the very first + # all_reduce may be scheduled before the second all_reduce. Since the + # fused all_reduce is inserted right after the last all_reudce, the + # order can be wrong. + # `incorrect_order_nodes` records these nodes. 
+ + orig_wait = comm_block.wait_nodes[0] + nodes = collections.deque(list(orig_wait.users)) + while nodes: + user_node = nodes.popleft() + if not isinstance(user_node, fx.Node): + continue + if node_indices[user_node] < last_wait_node_idx: + incorrect_order_nodes.append(user_node) + nodes.extend(list(user_node.users)) + + orig_wait.replace_all_uses_with(fused_output) + + last_fused_result = fused_outputs[0] + fused_outputs_set = OrderedSet(fused_outputs) + for node in graph.nodes: + if node in fused_outputs_set: + last_fused_result = node + + # Move the incorrect_order_nodes to right after the last fused_result. + incorrect_order_nodes = sorted( + incorrect_order_nodes, key=lambda node: node_indices[node] + ) + move_block_after(incorrect_order_nodes, last_fused_result) + + +def _fuse_allreduce( + graph: fx.Graph, + comm_blocks: list[CommBlock], + node_indices: dict[fx.Node, int], + use_concat: bool, +) -> CommBlock: + """Given a list of allreduce CommBlock, fuse the CommBlocks into one CommBlock.""" + + if len(comm_blocks) == 1: + return comm_blocks[0] + + # Find the last input node of all the CommBlocks. This node will be served + # as the inserting point of the new collective op. + last_input_node = comm_blocks[0].inputs[0] + last_input_index = -1 + all_input_nodes = [] + for comm_block in comm_blocks: + input_node = comm_block.inputs[0] + all_input_nodes.append(input_node) + index = node_indices[input_node] + if index >= last_input_index: + assert index != last_input_index + last_input_node = input_node + last_input_index = index + + if use_concat: + fused_comm_block = _fuse_allreduce_by_concat( + graph, last_input_node, all_input_nodes, comm_blocks[-1] + ) + else: + fused_comm_block = _fuse_with_coalesced_op( + graph, last_input_node, all_input_nodes, comm_blocks[-1] + ) + + _scatter_fused_allreduce_waits( + graph, fused_comm_block, comm_blocks, node_indices, split_and_reshape=use_concat + ) + + for comm_block in comm_blocks: + for wait in comm_block.wait_nodes: + graph.erase_node(wait) + graph.erase_node(comm_block.comm_node) + graph.eliminate_dead_code() + + return fused_comm_block + + +def _bucket_size_fusion( + graph: fx.Graph, comm_blocks: list[CommBlock], bucket_size_mb: int +) -> Generator[list[CommBlock], None, None]: + MB = 1024**2 + bucket_size = 1 * MB + bucket_cap_size = bucket_size_mb * MB + curr_size = 0 + curr_blocks = [] + + count = 0 + fuse_count = 0 + for i, block in enumerate(comm_blocks): + curr_blocks.append(block) + itemsize = block.comm_node.meta["tensor_meta"].dtype.itemsize + curr_size += cast(torch.Size, block.shape).numel() * itemsize + count += 1 + if curr_size < bucket_size and i != len(comm_blocks) - 1: + continue + + fuse_count += 1 + if torch.distributed.get_rank() == 0: + logger.info( + "DDP bucketing: block%d, count=%d, curr_size=%d, bucket_size=%d", + fuse_count, + count, + curr_size, + bucket_size, + ) + + # Set the debug counters + counters["inductor"]["ddp_buckets"] = fuse_count + yield curr_blocks + + bucket_size = bucket_cap_size + curr_blocks = [] + curr_size = 0 + count = 0 + + +def _fuse_ddp_communication( + graph: fx.Graph, algorithm_fn: Callable[..., Any], fusion_fn: Callable[..., Any] +) -> None: + for output in reversed(graph.nodes): + if output.op == "output": + break + + def ddp_reducer_filter(block: CommBlock) -> bool: + if ( + not isinstance(block.comm_node.args[0], fx.Node) + or block.comm_node.args[0].target != aten.div.Tensor + ): + return False + + if len(block.wait_nodes[0].users) != 1: + # gradient/wait node should only be used 
by one user + return False + + # Two cases: + # 1. gradient/wait node should be directly used by the output + # if gradient is None before bwd. + # 2. gradient/wait node should be directly used by copy_. + if ( + output not in block.wait_nodes[0].users + and next(iter(block.wait_nodes[0].users)).target != aten.copy_.default + ): + return False + + return True + + ops = ( + torch.ops._c10d_functional.all_reduce_.default, + torch.ops._c10d_functional.all_reduce.default, + ) + comm_blocks = get_all_comm_blocks(graph, ops, comm_filter=ddp_reducer_filter) + node_indices = {node: i for i, node in enumerate(graph.nodes)} + + for block in algorithm_fn(graph, comm_blocks): + fusion_fn(graph, block, node_indices) + + +def fuse_ddp_with_coalesced_op(graph: fx.Graph, bucket_size_mb: int) -> None: + _fuse_ddp_communication( + graph, + partial(_bucket_size_fusion, bucket_size_mb=bucket_size_mb), + partial(_fuse_allreduce, use_concat=False), + ) + + +def fuse_ddp_with_concat_op(graph: fx.Graph, bucket_size_mb: int) -> None: + _fuse_ddp_communication( + graph, + partial(_bucket_size_fusion, bucket_size_mb=bucket_size_mb), + partial(_fuse_allreduce, use_concat=True), + ) + + +def schedule_comm_wait(graph: fx.Graph) -> None: + """ + Delay the execution of wait tensors of allreduce until its first user. + + This algorithm considers the intermediate users, like split, getitem, + of the wait node and schedule those intermediate users as well. + This will result in a better overlapping result. + """ + ops = ( + torch.ops._c10d_functional.all_reduce_.default, + torch.ops._c10d_functional.all_reduce.default, + torch.ops._c10d_functional.all_reduce_coalesced.default, + torch.ops._c10d_functional.all_reduce_coalesced_.default, + ) + comm_blocks = get_all_comm_blocks(graph, ops) + if not comm_blocks: + return + + # Find all the end users. + allreduce_users = OrderedSet[fx.Node]() + for allreduce in comm_blocks: + for output in allreduce.outputs: + allreduce_users.update(output.users) + + node_indices = {node: i for i, node in enumerate(graph.nodes)} + for allreduce in comm_blocks: + # Find the earliest/first user -- target_node. + assert len(allreduce.outputs) >= 1, ( + f"Found a allreduce that has zero outputs/users -- {allreduce}." + ) + # Initialize the target node to avoid typing issues. + target_node = next(iter(next(iter(allreduce.outputs)).users)) + target_node_index = 2**31 + for user in (user for output in allreduce.outputs for user in output.users): + index = node_indices[user] + if index < target_node_index: + target_node = user + target_node_index = index + + # Move wait nodes and all the subsequent nodes in the comm_block to + # before the first user -- target_node. 
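`move_block_before` is built on fx's `Node.prepend`, which relocates an already-existing node so it sits directly before a target node; `schedule_comm_wait` uses it to delay each wait (and its follow-up nodes) until just before the collective's first real user. A toy, self-contained illustration on a traced function; the function itself is invented for the example:

```python
# Node.prepend() moves an existing node to directly before the target node.
import torch
import torch.fx as fx

def f(x):
    a = x + 1
    b = x * 2
    return a + b

gm = fx.symbolic_trace(f)
nodes = list(gm.graph.nodes)   # [placeholder, add, mul, add_1, output]
add_node = nodes[1]            # a = x + 1
mul_node = nodes[2]            # b = x * 2

add_node.prepend(mul_node)     # the mul now runs before the add
gm.recompile()
print(gm.code)
```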
+ wait_idx = -1 + for wait_idx, node in enumerate(allreduce.node_list): + if node == allreduce.wait_nodes[0]: + break + assert wait_idx >= 0 + move_block_before(allreduce.node_list[wait_idx:], target_node) + + +def fuse_ddp_communication( + graph: fx.Graph, passes: list[Union[Callable[..., None], str]], bucket_size_mb: int +) -> None: + for i, pa in enumerate(passes): + with GraphTransformObserver( + graph.owning_module, f"fuse_ddp_communication_pass_{i}" + ): + if isinstance(pa, str): + func = globals()[pa] + else: + func = pa + if "bucket_size_mb" in OrderedSet( + v.name for v in inspect.signature(func).parameters.values() + ): + func(graph, bucket_size_mb=bucket_size_mb) + else: + func(graph) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/decompose_mem_bound_mm.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/decompose_mem_bound_mm.py new file mode 100644 index 0000000000000000000000000000000000000000..e6757c3ad9e31e89246d19859b449f0643ee9be3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/decompose_mem_bound_mm.py @@ -0,0 +1,156 @@ +# mypy: allow-untyped-defs +import logging + +import torch +from torch import Tensor +from torch._dynamo.utils import counters, is_node_meta_valid +from torch.fx.experimental.symbolic_shapes import statically_known_true + +from .. import config +from ..pattern_matcher import Arg, CallFunction, Match, register_graph_pattern +from .split_cat import construct_pattern_matcher_pass + + +aten = torch.ops.aten +log = logging.getLogger(__name__) + +# TODO: need a better strategy for decomposing mm +MIN_FIRST_DIMENSION_DECOMPOSITION = 10240 +MAX_OTHER_DIMENSION_DECOMPOSITION = 32 + +min_first_dimension_decomposition = MIN_FIRST_DIMENSION_DECOMPOSITION +max_other_dimension_decomposition = MAX_OTHER_DIMENSION_DECOMPOSITION +if "decompose_mm_pass" in config.post_grad_fusion_options: + min_first_dimension_decomposition = config.post_grad_fusion_options[ + "decompose_mm_pass" + ].get("min_first_dimension_decomposition", MIN_FIRST_DIMENSION_DECOMPOSITION) + max_other_dimension_decomposition = config.post_grad_fusion_options[ + "decompose_mm_pass" + ].get("max_other_dimension_decomposition", MAX_OTHER_DIMENSION_DECOMPOSITION) + + +def check_device(a: Tensor, b: Tensor, device="cuda") -> bool: + return (a.device.type == b.device.type) and (b.device.type == device) + + +def realize_inputs(inputs: list[torch.fx.Node]): + for inp in inputs: + if isinstance(inp, torch.fx.node.Node): + inp.meta["inductor_realize_to_strides"] = True + + +def should_decompose_bmm(mat1, mat2) -> bool: + if is_node_meta_valid(mat1) and is_node_meta_valid(mat2): + mat1 = mat1.meta["val"] + mat2 = mat2.meta["val"] + else: + return False + if len(mat1.shape) != 3 or len(mat2.shape) != 3: + return False + if check_device(mat1, mat2, device="cuda"): + if mat1.shape[0] < min_first_dimension_decomposition: + return False + # 2 of m, n, k must be <= MAX_OTHER_DIMENSION_DECOMPOSITION + if (mat1.shape[1] < max_other_dimension_decomposition) + ( + mat1.shape[2] < max_other_dimension_decomposition + ) + (mat2.shape[2] < max_other_dimension_decomposition) < 2: + return False + return True + elif check_device(mat1, mat2, device="cpu"): + if mat1.shape[0] == 1 and mat2.shape[0] == 1: + return True + return False + + +def should_decompose_mm(mat1, mat2) -> bool: + if is_node_meta_valid(mat1) and is_node_meta_valid(mat2): + mat1 = mat1.meta["val"] + mat2 = mat2.meta["val"] + else: + return False + if len(mat1.shape) != 2 or len(mat2.shape) 
!= 2: + return False + return ( + check_device(mat1, mat2, device="cuda") + and statically_known_true(mat1.shape[0] >= min_first_dimension_decomposition) + and statically_known_true(mat2.shape[0] < max_other_dimension_decomposition) + and statically_known_true(mat2.shape[1] < max_other_dimension_decomposition) + ) or ( + check_device(mat1, mat2, device="cpu") + and statically_known_true(mat1.shape[0] == 1) + and statically_known_true(mat2.shape[0] <= 128) + and statically_known_true(mat2.shape[1] <= 512) + ) + + +def print_decompose_pattern(match: Match, inputs: list[torch.fx.Node]): + node = match.nodes[-1] + log.debug( + "Decompose %s with input shape: %s", + node.target, + ", ".join( + str(input.meta["val"].shape) if "val" in input.meta else "None" + for input in inputs + ), + ) + + +@register_graph_pattern( + CallFunction(aten.bmm, Arg(), Arg()), + pass_dict=construct_pattern_matcher_pass("decompose_mm_pass"), +) +def decompose_bmm(match: Match, mat1: torch.fx.Node, mat2: torch.fx.Node): + def repl(mat1, mat2): + return torch.sum(mat1[:, :, :, None] * mat2[:, None, :, :], dim=-2).to( + mat1.dtype + ) + + if should_decompose_bmm(mat1, mat2): + counters["inductor"]["decompose_bmm"] += 1 + match.replace_by_example(repl, [mat1, mat2]) + print_decompose_pattern(match, [mat1, mat2]) + realize_inputs([mat1, mat2]) + return + + +@register_graph_pattern( + CallFunction(aten.addmm, Arg(), Arg(), Arg()), + pass_dict=construct_pattern_matcher_pass("decompose_mm_pass"), +) +def decompose_addmm( + match: Match, + mat1: torch.fx.Node, + mat2: torch.fx.Node, + mat3: torch.fx.Node, +): + def repl(mat1, mat2, mat3): + return ( + torch.sum(mat2[:, :, None] * mat3[None, :, :], dim=-2).to(mat2.dtype) + mat1 + ) + + if should_decompose_mm(mat2, mat3): + counters["inductor"]["decompose_addmm"] += 1 + match.replace_by_example(repl, [mat1, mat2, mat3]) + print_decompose_pattern(match, [mat1, mat2, mat3]) + realize_inputs([mat1, mat2, mat3]) + return + + +@register_graph_pattern( + CallFunction(aten.mm, Arg(), Arg()), + pass_dict=construct_pattern_matcher_pass("decompose_mm_pass"), +) +def decompose_mm( + match: Match, + mat1: torch.fx.Node, + mat2: torch.fx.Node, +): + def repl(mat1, mat2): + return torch.sum(mat1[:, :, None] * mat2[None, :, :], dim=-2).to(mat1.dtype) + + if should_decompose_mm(mat1, mat2): + counters["inductor"]["decompose_mm"] += 1 + match.replace_by_example(repl, [mat1, mat2]) + print_decompose_pattern(match, [mat1, mat2]) + realize_inputs([mat1, mat2]) + return diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/dedupe_symint_uses.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/dedupe_symint_uses.py new file mode 100644 index 0000000000000000000000000000000000000000..713ed27aaa84ad24999fa9455cb97326c933a3a1 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/dedupe_symint_uses.py @@ -0,0 +1,81 @@ +# mypy: allow-untyped-defs +from dataclasses import dataclass +from typing import Any, Union + +import torch +from torch import SymBool, SymFloat, SymInt +from torch.types import py_sym_types +from torch.utils._ordered_set import OrderedSet + + +@dataclass +class _SymExprHash: + """ + Hash for a py_sym_types that will use the underlying sympy expression + """ + + sym_obj: Union[SymInt, SymFloat, SymBool] + + def __hash__(self) -> int: + return hash((type(self.sym_obj), self.sym_obj.node.expr)) + + def __eq__(self, value) -> bool: + if not isinstance(value, _SymExprHash): + return False + return self.sym_obj.node.expr == 
value.sym_obj.node.expr + + +class _SymHashingDict: + """ + Wrapper around a dictionary that will convert sym types to hash with _SymExprHash and reuse + existing sym proxies. + + SymPy hash is not always reliable so optimistically hash sympy expression, and if those fail, + fallback to symnodes. + """ + + def __init__(self): + self.sym_hash_dict = {} + + def __setitem__(self, key, value): + self.sym_hash_dict.__setitem__(self._wrap_to_sym_expr_hash(key), value) + + def __getitem__(self, key): + return self.sym_hash_dict[self._wrap_to_sym_expr_hash(key)] + + def __contains__(self, key): + return self._wrap_to_sym_expr_hash(key) in self.sym_hash_dict + + def get(self, key, default=None): + return self.sym_hash_dict.get(self._wrap_to_sym_expr_hash(key), default) + + def _wrap_to_sym_expr_hash(self, key): + return _SymExprHash(key) if isinstance(key, py_sym_types) else key + + +def dedupe_symints(graph: torch.fx.Graph): + """ + Dedupes sym ints in the graph to nodes are resolvable to symint graph inputs. + + We only dedupe from graph inputs to avoid adding a potential dependency in the forward + from the backward. + + """ + + sym_dict = _SymHashingDict() + resolvable_from_input_symints = OrderedSet[Any]() + + for node in graph.nodes: + val = node.meta.get("val", None) + if val is None or not isinstance(val, py_sym_types): + continue + + if node.op == "placeholder": + resolvable_from_input_symints.add(node) + sym_dict[val] = node + elif existing_node := sym_dict.get(val): + node.replace_all_uses_with(existing_node) + graph.erase_node(node) + elif all(n in resolvable_from_input_symints for n in node.all_input_nodes): + sym_dict[val] = node + resolvable_from_input_symints.add(node) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/efficient_conv_bn_eval.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/efficient_conv_bn_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..0e647e37cd346ff742f246760bdb3348e4842851 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/efficient_conv_bn_eval.py @@ -0,0 +1,409 @@ +# mypy: allow-untyped-defs +import torch +import torch.nn as nn +from torch._dynamo.utils import counters +from torch._inductor import config as inductor_config +from torch.func import functional_call + +from ..pattern_matcher import ( + CallFunctionVarArgs, + CallModuleVarArgs, + Match, + register_graph_pattern, +) +from .pre_grad import efficient_conv_bn_eval_pass + + +def efficient_conv_bn_eval( + bn: nn.modules.batchnorm._BatchNorm, conv: nn.modules.conv._ConvNd, x: torch.Tensor +): + """ + Implementation based on https://arxiv.org/abs/2305.11624 + "Efficient ConvBN Blocks for Transfer Learning and Beyond" + It leverages the associative law between convolution and affine transform, + i.e., normalize (weight conv feature) = (normalize weight) conv feature. + It works for Eval mode of ConvBN blocks during validation, and can be used + for **training** as well, but only if one sets `bn.training=False`. It + reduces memory footprint and computation cost, at the cost of slightly + reduced numerical stability. + Args: + bn (nn.modules.batchnorm._BatchNorm): a BatchNorm module. + conv (nn.modules.conv._ConvNd): a conv module + x (torch.Tensor): Input feature map. 
+ """ + + assert bn.running_var is not None + assert bn.running_mean is not None + + # These lines of code are designed to deal with various cases + # like bn without affine transform, and conv without bias + weight_on_the_fly = conv.weight + if conv.bias is not None: + bias_on_the_fly = conv.bias + else: + bias_on_the_fly = torch.zeros_like(bn.running_var) + + if bn.weight is not None: + bn_weight = bn.weight + else: + bn_weight = torch.ones_like(bn.running_var) + + if bn.bias is not None: + bn_bias = bn.bias + else: + bn_bias = torch.zeros_like(bn.running_var) + + # shape of [C_out, 1, 1, 1] in Conv2d + target_shape = [-1] + [1] * (conv.weight.ndim - 1) + if isinstance(conv, nn.modules.conv._ConvTransposeNd): + # for transposed conv, the C_out dimension should at index 1. + target_shape[:2] = [target_shape[1], target_shape[0]] + weight_coeff = torch.rsqrt(bn.running_var + bn.eps).reshape(target_shape) + # shape of [C_out, 1, 1, 1] in Conv2d + coefff_on_the_fly = bn_weight.view_as(weight_coeff) * weight_coeff + + # shape of [C_out, C_in, k, k] in Conv2d + weight_on_the_fly = weight_on_the_fly * coefff_on_the_fly + # shape of [C_out] in Conv2d + bias_on_the_fly = bn_bias + coefff_on_the_fly.flatten() * ( + bias_on_the_fly - bn.running_mean + ) + + input = x + params = {"weight": weight_on_the_fly, "bias": bias_on_the_fly} + output = functional_call(conv, params, input) + return output + + +def efficient_conv_bn_eval_decomposed( + bn_weight, + bn_bias, + bn_running_mean, + bn_running_var, + bn_eps, + conv: torch._ops.OpOverload, + conv_weight, + conv_bias, + x, + conv_remainging_args, +): + """ + Implementation based on https://arxiv.org/abs/2305.11624 + "Efficient ConvBN Blocks for Transfer Learning and Beyond" + It leverages the associative law between convolution and affine transform, + i.e., normalize (weight conv feature) = (normalize weight) conv feature. + It works for Eval mode of ConvBN blocks during validation, and can be used + for **training** as well, but only if one sets `bn.training=False`. It + reduces memory footprint and computation cost, at the cost of slightly + reduced numerical stability. + Args: + """ + assert bn_running_var is not None + + # These lines of code are designed to deal with various cases + # like bn without affine transform, and conv without bias + weight_on_the_fly = conv_weight + if conv_bias is not None: + bias_on_the_fly = conv_bias + else: + bias_on_the_fly = torch.zeros_like(bn_running_var) + + if bn_weight is not None: + bn_weight = bn_weight + else: + bn_weight = torch.ones_like(bn_running_var) + + if bn_bias is not None: + bn_bias = bn_bias + else: + bn_bias = torch.zeros_like(bn_running_var) + + # shape of [C_out, 1, 1, 1] in Conv2d + target_shape = [-1] + [1] * (conv_weight.ndim - 1) + if "conv_transpose" in conv.__str__(): + # for transposed conv, the C_out dimension should at index 1. 
+ target_shape[:2] = [target_shape[1], target_shape[0]] + weight_coeff = torch.rsqrt(bn_running_var + bn_eps).reshape(target_shape) + # shape of [C_out, 1, 1, 1] in Conv2d + coefff_on_the_fly = bn_weight.view_as(weight_coeff) * weight_coeff + + # shape of [C_out, C_in, k, k] in Conv2d + weight_on_the_fly = weight_on_the_fly * coefff_on_the_fly + # shape of [C_out] in Conv2d + bias_on_the_fly = bn_bias + coefff_on_the_fly.flatten() * ( + bias_on_the_fly - bn_running_mean + ) + + input = x + return conv(*((input, weight_on_the_fly, bias_on_the_fly) + conv_remainging_args)) + + +@register_graph_pattern( + CallFunctionVarArgs( + [ + torch.nn.functional.batch_norm, + ] + ), + pass_dict=efficient_conv_bn_eval_pass, + extra_check=lambda match: not inductor_config.freezing + and inductor_config.efficient_conv_bn_eval_fx_passes, +) +def efficient_conv_bn_eval_graph_transform_inlined(match: Match, *args, **kwargs): + bn_node = match.nodes[0] + graph = match.graph + assert len(bn_node.args) == 8 + + # We can only use efficient conv-bn for eval mode with track_running_stats + # bn_node.args is `training` + if bn_node.args[-3]: + return + + # Check if the input is Conv + input_node = bn_node.args[0] + + if input_node.op != "call_function": # type: ignore[union-attr] + return + + input_fn = input_node.target # type: ignore[arg-type, union-attr] + supported_convs = [ + torch._C._nn.linear, + torch.conv1d, + torch.conv2d, + torch.conv3d, + torch.conv_transpose1d, + torch.conv_transpose2d, + torch.conv_transpose3d, + ] + + if not any(input_fn is cls for cls in supported_convs): + return + + conv_node = input_node + # Output of conv is used by other nodes, cannot optimize + if len(conv_node.users) > 1: # type: ignore[union-attr] + return + + counters["inductor"]["efficient_conv_bn_eval"] += 1 + + with graph.inserting_before(bn_node): + # prepare args for the fused function + bn_running_mean = bn_node.args[1] + bn_running_var = bn_node.args[2] + bn_weight = bn_node.args[3] + bn_bias = bn_node.args[4] + bn_eps = bn_node.args[7] + assert len(conv_node.args) >= 2 # type: ignore[union-attr] + conv_input = conv_node.args[0] # type: ignore[union-attr] + conv_weight = conv_node.args[1] # type: ignore[union-attr] + conv_bias = conv_node.args[2] if len(conv_node.args) >= 3 else None # type: ignore[union-attr] + conv_remainging_args = conv_node.args[3:] # type: ignore[union-attr] + args = ( + bn_weight, + bn_bias, + bn_running_mean, + bn_running_var, + bn_eps, + conv_node.target, # type: ignore[union-attr] + conv_weight, + conv_bias, + conv_input, + conv_remainging_args, + ) + + # create a new node + new_node = graph.create_node( + op="call_function", + target=efficient_conv_bn_eval_decomposed, + args=args, # type: ignore[arg-type] + name="efficient_conv_bn_eval", + ) + + # this node replaces the original conv + bn, and therefore + # should replace the uses of bn_node + bn_node.replace_all_uses_with(new_node) + # take care of the deletion order: + # delete bn_node first, and then conv_node + graph.erase_node(bn_node) + graph.erase_node(conv_node) # type: ignore[arg-type] + + return + + +@register_graph_pattern( + CallFunctionVarArgs( + [ + torch.ops.aten.batch_norm.default, + ] + ), + pass_dict=efficient_conv_bn_eval_pass, + extra_check=lambda match: not inductor_config.freezing + and inductor_config.efficient_conv_bn_eval_fx_passes, +) +def efficient_conv_bn_eval_graph_transform_decomposed(match: Match, *args, **kwargs): + bn_node = match.nodes[0] + graph = match.graph + assert len(bn_node.args) == 9 + + # We can 
only use efficient conv-bn for eval mode with track_running_stats + # bn_node.args is `training` + if bn_node.args[-4]: + return + + # Check if the input is Conv + input_node = bn_node.args[0] + + if input_node.op != "call_function": # type: ignore[union-attr] + return + + input_fn = input_node.target # type: ignore[arg-type, union-attr] + supported_convs = [ + torch.ops.aten.linear.default, + torch.ops.aten.conv1d.default, + torch.ops.aten.conv2d.default, + torch.ops.aten.conv3d.default, + torch.ops.aten.conv_transpose1d.default, + torch.ops.aten.conv_transpose2d.input, + torch.ops.aten.conv_transpose3d.input, + ] + + if not any(input_fn is cls for cls in supported_convs): + return + + conv_node = input_node + # Output of conv is used by other nodes, cannot optimize + if len(conv_node.users) > 1: # type: ignore[union-attr] + return + + counters["inductor"]["efficient_conv_bn_eval"] += 1 + + with graph.inserting_before(bn_node): + # prepare args for the fused function + bn_weight = bn_node.args[1] + bn_bias = bn_node.args[2] + bn_running_mean = bn_node.args[3] + bn_running_var = bn_node.args[4] + bn_eps = bn_node.args[7] + assert len(conv_node.args) >= 2 # type: ignore[union-attr] + conv_input = conv_node.args[0] # type: ignore[union-attr] + conv_weight = conv_node.args[1] # type: ignore[union-attr] + conv_bias = conv_node.args[2] if len(conv_node.args) >= 3 else None # type: ignore[union-attr] + conv_remainging_args = conv_node.args[3:] # type: ignore[union-attr] + args = ( + bn_weight, + bn_bias, + bn_running_mean, + bn_running_var, + bn_eps, + conv_node.target, # type: ignore[union-attr] + conv_weight, + conv_bias, + conv_input, + conv_remainging_args, + ) + + # create a new node + new_node = graph.create_node( + op="call_function", + target=efficient_conv_bn_eval_decomposed, + args=args, # type: ignore[arg-type] + name="efficient_conv_bn_eval", + ) + + # this node replaces the original conv + bn, and therefore + # should replace the uses of bn_node + bn_node.replace_all_uses_with(new_node) + # take care of the deletion order: + # delete bn_node first, and then conv_node + graph.erase_node(bn_node) + graph.erase_node(conv_node) # type: ignore[arg-type] + + return + + +@register_graph_pattern( + CallModuleVarArgs( + [ + nn.modules.batchnorm._BatchNorm, + nn.BatchNorm1d, + nn.BatchNorm2d, + nn.BatchNorm3d, + nn.SyncBatchNorm, + ], + ), + pass_dict=efficient_conv_bn_eval_pass, + extra_check=lambda match: not inductor_config.freezing + and inductor_config.efficient_conv_bn_eval_fx_passes, +) +def efficient_conv_bn_eval_graph_transform(match: Match, *args, **kwargs): + # We matched a BN node + bn_node = match.nodes[0] + graph = match.graph + gm = graph.owning_module + bn_mod = getattr(gm, bn_node.target) # type: ignore[arg-type] + + # We can only use efficient conv-bn for eval mode with track_running_stats + if not bn_mod.track_running_stats or bn_mod.training: + return + + # Check if the input is Conv + if bn_node.args: + input_node = bn_node.args[0] + else: + input_node = bn_node.kwargs["input"] + if input_node.op != "call_module": # type: ignore[union-attr] + return + if not hasattr(gm, input_node.target): # type: ignore[arg-type, union-attr] + return + input_mod = getattr(gm, input_node.target) # type: ignore[arg-type, union-attr] + supported_convs = [ + nn.Linear, + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nn.ConvTranspose1d, + nn.ConvTranspose2d, + nn.ConvTranspose3d, + ] + if not any(isinstance(input_mod, cls) for cls in supported_convs): + return + conv_node = input_node + # Output of 
conv is used by other nodes, cannot optimize + if len(conv_node.users) > 1: # type: ignore[union-attr] + return + + # Find a pair of conv and bn computation nodes to optimize. + counters["inductor"]["efficient_conv_bn_eval"] += 1 + + with graph.inserting_before(conv_node): # type: ignore[arg-type] + # create `get_attr` node to access modules + # note that we directly call `create_node` to fill the `name` + # argument. `graph.get_attr` and + # `graph.call_function` does not allow the `name` argument. + conv_get_node = graph.create_node( + op="get_attr", + target=conv_node.target, # type: ignore[union-attr] + name="get_conv", + ) + bn_get_node = graph.create_node( + op="get_attr", target=bn_node.target, name="get_bn" + ) + if conv_node.args: # type: ignore[union-attr] + conv_input = conv_node.args[0] # type: ignore[union-attr] + else: + conv_input = conv_node.kwargs["input"] # type: ignore[union-attr] + # prepare args for the fused function + args = (bn_get_node, conv_get_node, conv_input) + # create a new node + new_node = graph.create_node( + op="call_function", + target=efficient_conv_bn_eval, + args=args, + name="efficient_conv_bn_eval", + ) + # this node replaces the original conv + bn, and therefore + # should replace the uses of bn_node + bn_node.replace_all_uses_with(new_node) + # take care of the deletion order: + # delete bn_node first, and then conv_node + graph.erase_node(bn_node) + graph.erase_node(conv_node) # type: ignore[arg-type] diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/freezing_patterns.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/freezing_patterns.py new file mode 100644 index 0000000000000000000000000000000000000000..f05048a85e0e722b9af68e21fea28f58cdc8e917 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/freezing_patterns.py @@ -0,0 +1,297 @@ +# mypy: allow-untyped-defs +import functools + +import torch +from torch._inductor.compile_fx import fake_tensor_prop +from torch._inductor.utils import GPU_TYPES + +from ..._dynamo.utils import counters +from .. import config +from ..pattern_matcher import ( + _return_true, + CallFunction, + fwd_only, + Ignored, + init_once_fakemode, + KeywordArg, + Match, + PatternMatcherPass, + register_graph_pattern, + register_replacement, + stable_topological_sort, +) + + +aten = torch.ops.aten + +# First pass_patterns[0] are applied, then [1], then [2] +pass_patterns = [ + PatternMatcherPass(), + PatternMatcherPass(), + PatternMatcherPass(), +] + +binary_folding_pass = PatternMatcherPass() + + +def freezing_passes(gm: torch.fx.GraphModule, aot_example_inputs): + """ + Passes that are applied to the graph to freeze pass. + """ + + from ..freezing import constant_fold + + lazy_init() + # We need a few rounds of binary folding to get rid of all the + # unnecessary nodes, but may need a good method to chose the rounds number. + # works like: conv+binary+binary. + binary_folding = counters["inductor"]["binary_folding"] + fake_tensor_prop(gm, aot_example_inputs, True) + + torch._inductor.fx_passes.binary_folding.mark_mixed_dtype_allowed_computation_ops( + gm + ) + for _ in range(4): + constant_fold(gm) + # Make sure meta['val'] is properly set for all nodes + fake_tensor_prop(gm, aot_example_inputs, True) + binary_folding_pass.apply(gm.graph) # type: ignore[arg-type] + # If we don't have binary folding, we don't need to run the pass again. + # TODO: remove the need to run fake_tensor_prop on the whole model. 
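The `addmm_patterns_init` registrations later in this file rewrite several matmuls that share an input into a single matmul over concatenated weights followed by a chunk. The equivalence they rely on, checked in isolation with arbitrary sizes:

```python
# Concat-linear fusion: three matmuls sharing `inp` == one matmul + chunk.
import torch

inp = torch.randn(5, 16)
w1, w2, w3 = (torch.randn(16, 10) for _ in range(3))

separate = (inp @ w1, inp @ w2, inp @ w3)
fused = (inp @ torch.cat((w1, w2, w3), dim=1)).chunk(3, dim=1)

for s, f in zip(separate, fused):
    torch.testing.assert_close(s, f, rtol=1e-4, atol=1e-4)
```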
+ if counters["inductor"]["binary_folding"] == binary_folding: + break + binary_folding = counters["inductor"]["binary_folding"] + + torch._inductor.fx_passes.binary_folding.recover_original_precision_folded_computation_ops( + gm + ) + + constant_fold(gm) + fake_tensor_prop(gm, aot_example_inputs, True) + + for pattern in pass_patterns: + pattern.apply(gm.graph) # type: ignore[arg-type] + + # The CPU weight packing always assume the conv's weight is channels last, + # So make sure the layout_optimization is on when doing it. + if ( + torch._C._has_mkldnn + and config.cpp.weight_prepack + and config.layout_optimization + ): + from .mkldnn_fusion import _eliminate_duplicate_packed_nodes + + _eliminate_duplicate_packed_nodes(gm) + + stable_topological_sort(gm.graph) + gm.recompile() + gm.graph.lint() + + +@init_once_fakemode +def lazy_init(): + if torch._C._has_mkldnn and config.cpp.weight_prepack: + from .mkldnn_fusion import _mkldnn_weight_pack_init + + _mkldnn_weight_pack_init() + + from .binary_folding import binary_folding_init + + addmm_patterns_init() + binary_folding_init() + + +def register_freezing_graph_pattern(pattern, extra_check=_return_true, pass_number=0): + while pass_number > len(pass_patterns) - 1: + pass_patterns.append(PatternMatcherPass()) + return register_graph_pattern( + pattern, + extra_check=extra_check, + pass_dict=pass_patterns[pass_number], + ) + + +def register_binary_folding_pattern(pattern, extra_check=_return_true): + return register_graph_pattern( + pattern, + extra_check=extra_check, + pass_dict=binary_folding_pass, + ) + + +@functools.cache +def addmm_patterns_init(): + """ + addmm related patterns. + To avoid duplication, also includes int8 WoQ GEMM pattern without bias. + """ + device = next( + (gpu for gpu in GPU_TYPES if getattr(torch, gpu).is_available()), "cpu" + ) + val = functools.partial(torch.empty, (10, 10), device=device, requires_grad=False) + scale = functools.partial(torch.empty, (10,), device=device, requires_grad=False) + + def check_int8_woq_concat_linear_weights(match): + is_cpu = match.kwargs["inp"].meta["val"].is_cpu + if not is_cpu or not config.cpp.enable_concat_linear: + # Currently, this pattern is only supported on CPU + return False + + weight_inputs = ["w1", "w2"] + if "w3" in match.kwargs: + weight_inputs.append("w3") + + if not all( + match.kwargs[wgt].target == torch.ops.prims.convert_element_type.default + for wgt in weight_inputs + ): + return False + + if not all( + next(iter(match.kwargs[wgt]._input_nodes.keys())).meta["val"].dtype + is torch.int8 + for wgt in weight_inputs + ): + return False + + if not all( + match.kwargs[wgt].meta["val"].dtype is torch.bfloat16 + for wgt in weight_inputs + ): + return False + + equal_shape_inputs = [weight_inputs] + for equal_shape_group in equal_shape_inputs: + inps = [match.kwargs[name] for name in equal_shape_group] + if not all( + inp.meta["val"].shape == inps[0].meta["val"].shape for inp in inps + ): + return False + return True + + def check_concat_weights(match): + is_cpu = match.kwargs["inp"].meta["val"].is_cpu + if is_cpu and not config.cpp.enable_concat_linear: + return False + + weight_inputs = ["w1", "w2"] + if "w3" in match.kwargs: + weight_inputs.append("w3") + + equal_shape_inputs = [weight_inputs] + + if "b1" in match.kwargs: + bias_inputs = ["b1", "b2"] + if "b3" in match.kwargs: + bias_inputs.append("b3") + + equal_shape_inputs.append(bias_inputs) + + for equal_shape_group in equal_shape_inputs: + inps = [match.kwargs[name] for name in equal_shape_group] + + if not 
all( + inp.op == "get_attr" + and inp.meta["val"].shape == inps[0].meta["val"].shape + for inp in inps + ): + return False + return True + + def int8_woq_fusion_pattern(inp, w1, w2, w3, s1, s2, s3): + return ((inp @ w1) * s1, (inp @ w2) * s2, (inp @ w3) * s3) + + def int8_woq_fusion_replacement(inp, w1, w2, w3, s1, s2, s3): + cat_w = torch.cat((w1, w2, w3), dim=1) + cat_s = torch.cat((s1, s2, s3), dim=0) + mm = (inp @ cat_w).mul(cat_s) + return mm.chunk(3, dim=1) + + register_replacement( + int8_woq_fusion_pattern, + int8_woq_fusion_replacement, + [val(), val(), val(), val(), scale(), scale(), scale()], + fwd_only, + pass_patterns[0], + extra_check=check_int8_woq_concat_linear_weights, + exclusive_arg_names=("w1", "w2", "w3", "s1", "s2", "s3"), + ) + + def matmul_fuse_pattern(inp, w1, w2, w3): + return (inp @ w1, inp @ w2, inp @ w3) + + def matmul_replacement(inp, w1, w2, w3): + cat_t = torch.cat((w1, w2, w3), dim=1) + mm = inp @ cat_t + return mm.chunk(3, dim=1) + + register_replacement( + matmul_fuse_pattern, + matmul_replacement, + [val(), val(), val(), val()], + fwd_only, + pass_patterns[0], + extra_check=check_concat_weights, + exclusive_arg_names=("w1", "w2", "w3"), + ) + + def matmul_fuse_pattern_two(inp, w1, w2): + return (inp @ w1, inp @ w2) + + def matmul_replacement_two(inp, w1, w2): + cat_t = torch.cat((w1, w2), dim=1) + mm = inp @ cat_t + return mm.chunk(2, dim=1) + + register_replacement( + matmul_fuse_pattern_two, + matmul_replacement_two, + [val(), val(), val()], + fwd_only, + pass_patterns[0], + extra_check=check_concat_weights, + exclusive_arg_names=("w1", "w2"), + ) + + def addmm_fuse_pattern_second(inp, w1, w2, w3, b1, b2, b3): + return ( + aten.addmm(b1, inp, w1), + aten.addmm(b2, inp, w2), + aten.addmm(b3, inp, w3), + ) + + def addmm_fuse_replacement_second(inp, w1, w2, w3, b1, b2, b3): + cat_w = torch.cat((w1, w2, w3), dim=1) + cat_b = torch.cat((b1, b2, b3)) + return aten.addmm(cat_b, inp, cat_w).chunk(3, dim=1) + + register_replacement( + addmm_fuse_pattern_second, + addmm_fuse_replacement_second, + [val() for _ in range(7)], + fwd_only, + pass_patterns[0], + extra_check=check_concat_weights, + exclusive_arg_names=("w1", "w2", "w3", "b1", "b2", "b3"), + ) + + +def same_dtype(match): + return match.output_node().args[0].meta["val"].dtype == match.kwargs["dtype"] + + +@register_graph_pattern( + CallFunction( + torch.ops.prims.convert_element_type.default, + Ignored(), + KeywordArg("dtype"), + ), + pass_dict=pass_patterns[0], + extra_check=same_dtype, +) +def unnecessary_dtype_convert(match: Match, **kwargs): + """Remove unnecessary dtype conversion op, probably left as a result of Conv-Bn folding""" + graph = match.graph + node = match.output_node() + node.replace_all_uses_with(node.args[0]) # type: ignore[arg-type] + graph.erase_node(node) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/fuse_attention.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/fuse_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..4ed950afe9a186f306a047f84d7148f27178a849 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/fuse_attention.py @@ -0,0 +1,1089 @@ +# mypy: allow-untyped-defs +import functools +import inspect +import logging +import math + +import torch + +from ..._dynamo.utils import counters +from ..pattern_matcher import ( + filter_nodes, + fwd_only, + gen_register_replacement, + joint_fwd_bwd, +) + + +log = logging.getLogger(__name__) +aten = torch.ops.aten + + 
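The first pattern/replacement pair that follows (`_sfdp_pattern_1` / `_sfdp_replacement_1`) rewrites an explicit softmax-attention expression into a single `scaled_dot_product_attention` call. A minimal standalone check of that rewrite; the shapes are illustrative and the CPU math backend is assumed, so differences stay within loose float32 tolerances:

```python
# Manual softmax attention vs. fused scaled_dot_product_attention.
import math
import torch
import torch.nn.functional as F

q, k, v = (torch.randn(2, 4, 16, 8) for _ in range(3))
inv_scale = math.sqrt(q.size(-1))

manual = (q @ k.transpose(-2, -1)).div(inv_scale).softmax(dim=-1) @ v
fused = F.scaled_dot_product_attention(q, k, v, scale=1.0 / inv_scale)

torch.testing.assert_close(manual, fused, rtol=1e-4, atol=1e-4)
```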
+_scaled_dot_product_attention = aten.scaled_dot_product_attention + + +def _sfdp_pattern_1(query, key, value, inv_scale): + return ( + torch.matmul(query, key.transpose(-2, -1)) + .div(inv_scale) + .softmax(dim=-1) + .matmul(value) + ) + + +def _sfdp_replacement_1(query, key, value, inv_scale): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + scale=1.0 / inv_scale, + ) + + +def _sfdp_pattern_2(query, key, value, scale_factor): + return ( + torch.matmul(query, key.transpose(-2, -1)) + .mul(scale_factor) + .softmax(dim=-1) + .matmul(value) + ) + + +def _sfdp_replacement_2(query, key, value, scale_factor): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + scale=scale_factor, + ) + + +def _sfdp_pattern_3(query, key, value, inv_scale_factor, dropout_p): + return torch.nn.functional.dropout( + torch.matmul(query, key.transpose(-2, -1)) + .div(inv_scale_factor) + .softmax(dim=-1), + p=dropout_p, + ).matmul(value) + + +def _sfdp_replacement_3(query, key, value, inv_scale_factor, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query, + key, + value, + attn_mask=None, + dropout_p=dropout_p, + is_causal=False, + scale=1.0 / inv_scale_factor, + ) + + +def _sfdp_pattern_4(query, key, value, scale_factor, dropout_p): + return torch.nn.functional.dropout( + torch.matmul(query, key.transpose(-2, -1)).mul(scale_factor).softmax(dim=-1), + p=dropout_p, + ).matmul(value) + + +def _sfdp_replacement_4(query, key, value, scale_factor, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query, + key, + value, + attn_mask=None, + dropout_p=dropout_p, + is_causal=False, + scale=scale_factor, + ) + + +def _sfdp_pattern_5(query, key, value, attn_mask): + attn_weight = torch.softmax( + (query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))) + attn_mask, dim=-1 + ) + # attn_weight = torch.dropout(attn_weight, dropout_p) + return attn_weight @ value + + +def _sfdp_replacement_5(query, key, value, attn_mask): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_mask.to(dtype=query.dtype), + dropout_p=0.0, + is_causal=False, + ) + + +def _sfdp_pattern_6(query, key, value, attn_mask, dropout_p): + attn_weight = torch.softmax( + (query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))) + attn_mask, dim=-1 + ) + attn_weight = torch.dropout(attn_weight, dropout_p, True) + return attn_weight @ value + + +def _sfdp_replacement_6(query, key, value, attn_mask, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_mask.to(dtype=query.dtype), + dropout_p=dropout_p, + is_causal=False, + ) + + +def _sfdp_pattern_7(query, key, value, dropout_p): + # in real workloads inputs to matmul are permuted + # causing matmul to expand to a series of expand and clone calls + # we want the same to happen during pattern tracing + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1)) + div = div.to(torch.float32) + attn_weight = torch.softmax(div, dim=-1) + attn_weight = torch.dropout(attn_weight, dropout_p, True) + attn_weight = attn_weight.to(torch.float16) + 
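+    # NOTE (editorial): the explicit fp32 upcast before softmax and the fp16
+    # downcast after it mirror the dtype-conversion nodes that half-precision
+    # graphs contain (see the dtype comment in _get_sfdp_patterns below), so
+    # this traced pattern still matches such graphs.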
return attn_weight @ v + + +def _sfdp_replacement_7(query, key, value, dropout_p): + # sdpa prefers inputs in permuted format + # it makes a copy to put them in this format + # if they aren't already + # to make replacement efficient ensure that inputs to sdpa + # are in required order + counters["inductor"]["fuse_attention"] += 1 + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + return _scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, # attn_mask, + dropout_p=dropout_p, + is_causal=False, + ) + + +def _sfdp_pattern_8(query, key, value): + # no dropout version of pattern 7 + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1)) + div = div.to(torch.float32) + attn_weight = torch.softmax(div, dim=-1) + attn_weight = attn_weight.to(torch.float16) + return attn_weight @ v + + +def _sfdp_replacement_8(query, key, value): + counters["inductor"]["fuse_attention"] += 1 + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + return _scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, # attn_mask, + dropout_p=0.0, + is_causal=False, + ) + + +def _sfdp_pattern_9(query, key, value, dropout_p): + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + q = q / math.sqrt(q.size(-1)) + div = q @ k.transpose(-2, -1) + div = div.to(torch.float32) + attn_weight = torch.softmax(div, dim=-1) + attn_weight = torch.dropout(attn_weight, dropout_p, True) + attn_weight = attn_weight.to(torch.float16) + return attn_weight @ v + + +def _sfdp_replacement_9(query, key, value, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + return _scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, # attn_mask, + dropout_p=dropout_p, + is_causal=False, + ) + + +def _sfdp_pattern_10(query, key, value): + # no dropout version of 9 + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + q = q / math.sqrt(q.size(-1)) + div = q @ k.transpose(-2, -1) + div = div.to(torch.float32) + attn_weight = torch.softmax(div, dim=-1) + attn_weight = attn_weight.to(torch.float16) + return attn_weight @ v + + +def _sfdp_replacement_10(query, key, value): + counters["inductor"]["fuse_attention"] += 1 + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + return _scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, # attn_mask, + dropout_p=0.0, + is_causal=False, + ) + + +def _sfdp_pattern_11(query, key, value, inv_scale): + # Mainly for huggingface models + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + return torch.matmul(q, k.transpose(-2, -1)).div(inv_scale).softmax(dim=-1).matmul(v) + + +def _sfdp_replacement_11(query, key, value, inv_scale): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + attn_mask=None, + dropout_p=0.0, + is_causal=False, + scale=1.0 / inv_scale, + ) + + +def _sfdp_pattern_12(query, key, value, inv_scale_factor, dropout_p): + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + return torch.nn.functional.dropout( + torch.matmul(q, k.transpose(-2, -1)).div(inv_scale_factor).softmax(dim=-1), + 
p=dropout_p, + ).matmul(v) + + +def _sfdp_replacement_12(query, key, value, inv_scale_factor, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + attn_mask=None, + dropout_p=dropout_p, + is_causal=False, + scale=1.0 / inv_scale_factor, + ) + + +def _sfdp_pattern_13(query, key, value, dropout_p): + attn_weight = torch.bmm(query, key.transpose(1, 2)).softmax(dim=-1) + attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p) + return torch.bmm(attn_weight, value) + + +def _sfdp_replacement_13(query, key, value, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query.unsqueeze(0), + key.unsqueeze(0), + value.unsqueeze(0), + dropout_p=dropout_p, + scale=1.0, + ).squeeze(0) + + +def _sfdp_pattern_14(query, key, value, attn_mask, inv_scale): + # for BertLarge + # Permutations are needed to create clones in graph. + q = query.permute([0, 2, 1, 3]) + k = key.permute([0, 2, 1, 3]) + v = value.permute([0, 2, 1, 3]) + return ( + (torch.matmul(q, k.transpose(-2, -1)).div(inv_scale) + attn_mask) + .softmax(dim=-1) + .matmul(v) + ) + + +def _sfdp_replacement_14(query, key, value, attn_mask, inv_scale): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + attn_mask=attn_mask.to(dtype=query.dtype), + dropout_p=0.0, + is_causal=False, + scale=1.0 / inv_scale, + ) + + +def _sfdp_pattern_15(query, key, value, attn_mask, inv_scale): + # for DistilBert + # Permutations are needed to create clones in graph. + # Ref: https://github.com/pytorch/pytorch/issues/119911 + q = query.permute([0, 2, 1, 3]) + k = key.permute([0, 2, 1, 3]) + v = value.permute([0, 2, 1, 3]) + bs = q.size(0) + k_len = k.size(-2) + scores = q @ k.transpose(-2, -1) + scores = scores.div(inv_scale) + fill_value = torch.full((), -float("inf"), dtype=query.dtype, device=query.device) + attn_mask = (attn_mask == 0).view((bs, 1, 1, k_len)).expand_as(scores) + return torch.softmax(scores.masked_fill(attn_mask, fill_value), dim=-1) @ v + + +def _sfdp_replacement_15(query, key, value, attn_mask, inv_scale): + counters["inductor"]["fuse_attention"] += 1 + bs = query.size(0) + n_head = query.size(2) + q_len = query.size(1) + k_len = key.size(1) + # do attn_mask->logical_not() in _scaled_dot_product_attention + attn_mask = ( + (attn_mask == 1).view((bs, 1, 1, k_len)).expand((bs, n_head, q_len, k_len)) + ) + return _scaled_dot_product_attention( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + attn_mask=attn_mask.to(dtype=torch.bool), + dropout_p=0.0, + is_causal=False, + scale=1.0 / inv_scale, + ) + + +def _sfdp_pattern_16(query, key, value, attn_mask, inv_scale, dropout_p): + # for BertLarge with dropout + q = query.permute([0, 2, 1, 3]) + k = key.permute([0, 2, 1, 3]) + v = value.permute([0, 2, 1, 3]) + return ( + torch.nn.functional.dropout( + (torch.matmul(q, k.transpose(-2, -1)).div(inv_scale) + attn_mask).softmax( + dim=-1 + ), + dropout_p, + ) + .to(dtype=query.dtype) + .matmul(v) + ) + + +def _sfdp_replacement_16(query, key, value, attn_mask, inv_scale, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + attn_mask=attn_mask.to(dtype=query.dtype), + dropout_p=dropout_p, + is_causal=False, + scale=1.0 / inv_scale, + ) + 
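+
+# NOTE (editorial, illustrative sketch only - not part of the upstream file):
+# every `_sfdp_replacement_*` in this module increments
+# counters["inductor"]["fuse_attention"], so one way to confirm that a model's
+# attention was rewritten is roughly the following (`mha` and `x` stand in for
+# a user's attention module and its example input):
+#
+#     from torch._dynamo.utils import counters
+#     counters.clear()
+#     out = torch.compile(mha)(x)
+#     assert counters["inductor"]["fuse_attention"] > 0
+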
+ +def _sfdp_pattern_17(query, key, value, attn_mask, inv_scale, dropout_p): + # for DistilBert with dropout + q = query.permute([0, 2, 1, 3]) + k = key.permute([0, 2, 1, 3]) + v = value.permute([0, 2, 1, 3]) + bs = q.size(0) + k_len = k.size(-2) + scores = q @ k.transpose(-2, -1) + scores = scores.div(inv_scale) + fill_value = torch.full((), -float("inf"), dtype=query.dtype, device=query.device) + attn_mask = (attn_mask == 0).view((bs, 1, 1, k_len)).expand_as(scores) + return ( + torch.nn.functional.dropout( + torch.softmax(scores.masked_fill(attn_mask, fill_value), dim=-1), dropout_p + ) + @ v + ) + + +def _sfdp_replacement_17(query, key, value, attn_mask, inv_scale, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + bs = query.size(0) + n_head = query.size(2) + q_len = query.size(1) + k_len = key.size(1) + # do attn_mask->logical_not() in _scaled_dot_product_attention + attn_mask = ( + (attn_mask == 1).view((bs, 1, 1, k_len)).expand((bs, n_head, q_len, k_len)) + ) + return _scaled_dot_product_attention( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + attn_mask=attn_mask.to(dtype=torch.bool), + dropout_p=dropout_p, + is_causal=False, + scale=1.0 / inv_scale, + ) + + +def _sfdp_pattern_18(query, key, value, causal_mask, dropout_p): + # for hf_GPT2 with dropout (introduces clone node) for inference + # it also returns permuted key & value + query = query.permute([0, 2, 1, 3]) + key = key.permute([0, 2, 1, 3]) + value = value.permute([0, 2, 1, 3]) + attn_weights = torch.matmul(query, key.permute(0, 1, 3, 2)) + inv_scale = torch.full( + [], + value.size(-1) ** 0.5, + dtype=attn_weights.dtype, + device=attn_weights.device, + ) + attn_weights = attn_weights.div(inv_scale) + causal_mask_value = torch.full( + (), torch.finfo(query.dtype).min, dtype=query.dtype, device=query.device + ) + attn_weights = torch.where(causal_mask, attn_weights, causal_mask_value) + return ( + ( + torch.nn.functional.dropout(attn_weights.softmax(dim=-1), dropout_p).matmul( + value + ) + ), + key, + value, + ) + + +def _sfdp_replacement_18(query, key, value, causal_mask, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + permuted_key = key.transpose(1, 2) + permuted_value = value.transpose(1, 2) + return ( + _scaled_dot_product_attention( + query.transpose(1, 2), + permuted_key, + permuted_value, + attn_mask=causal_mask, + dropout_p=dropout_p, + is_causal=False, + scale=1.0 / math.sqrt(value.size(-1)), + ), + permuted_key, + permuted_value, + ) + + +def _sfdp_pattern_19(query, key, value, causal_mask, attn_mask, dropout_p): + # for token-classification+gpt2 / text-generation+gpt2 + attn_weights = torch.matmul(query, key.permute(0, 1, 3, 2)) + inv_scale = torch.full( + [], + value.size(-1) ** 0.5, + dtype=attn_weights.dtype, + device=attn_weights.device, + ) + attn_weights = attn_weights.div(inv_scale) + causal_mask_value = torch.full( + (), torch.finfo(query.dtype).min, dtype=query.dtype, device=query.device + ) + attn_weights = torch.where(causal_mask, attn_weights, causal_mask_value) + attn_weights = attn_weights + attn_mask + attn_weights = attn_weights.softmax(dim=-1).type(value.dtype) + return torch.nn.functional.dropout(attn_weights, dropout_p).matmul(value) + + +def _sfdp_replacement_19(query, key, value, causal_mask, attn_mask, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + fill_value = torch.full((), -float("inf"), dtype=query.dtype, device=query.device) + attn_mask = torch.where(causal_mask, attn_mask, fill_value) + return 
_scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=False, + scale=1.0 / math.sqrt(value.size(-1)), + ) + + +def _sfdp_pattern_20(query, key, value, attn_mask, dropout_p): + # for DistilBert with dropout transformers==4.44.2 + q = query.permute([0, 2, 1, 3]) + k = key.permute([0, 2, 1, 3]) + v = value.permute([0, 2, 1, 3]) + bs = q.size(0) + k_len = k.size(-2) + q = q.div(math.sqrt(q.size(-1))) + scores = q @ k.transpose(-2, -1) + fill_value = torch.full((), -float("inf"), dtype=query.dtype, device=query.device) + attn_mask = (attn_mask == 0).view((bs, 1, 1, k_len)).expand_as(scores) + return ( + torch.nn.functional.dropout( + torch.softmax(scores.masked_fill(attn_mask, fill_value), dim=-1), dropout_p + ) + @ v + ) + + +def _sfdp_replacement_20(query, key, value, attn_mask, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + bs = query.size(0) + n_head = query.size(2) + q_len = query.size(1) + k_len = key.size(1) + # do attn_mask->logical_not() in _scaled_dot_product_attention + attn_mask = ( + (attn_mask == 1).view((bs, 1, 1, k_len)).expand((bs, n_head, q_len, k_len)) + ) + return _scaled_dot_product_attention( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + attn_mask=attn_mask.to(dtype=torch.bool), + dropout_p=dropout_p, + is_causal=False, + scale=1.0 / math.sqrt(query.size(-1)), + ) + + +def _sfdp_pattern_21(query, key, value, attn_mask): + # for T5 with inplace add + query = query.permute([0, 2, 1, 3]) + key = key.permute([0, 2, 1, 3]) + value = value.permute([0, 2, 1, 3]) + score = torch.matmul(query, key.permute(0, 1, 3, 2)) + masked_score = score + attn_mask + score = masked_score.type_as(query) + viewd_score1 = score.view( + score.size(0) * score.size(1), score.size(2), score.size(3) + ) + viewd_score2 = viewd_score1.view( + score.size(0), score.size(1), score.size(2), score.size(3) + ) + return viewd_score2.float().softmax(dim=-1).type_as(query).matmul(value) + + +def _sfdp_replacement_21(query, key, value, attn_mask): + counters["inductor"]["fuse_attention"] += 1 + query = query.permute(0, 2, 1, 3) + key = key.permute(0, 2, 1, 3) + value = value.permute(0, 2, 1, 3) + return _scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_mask, + is_causal=False, + scale=1.0, + ) + + +def _sfdp_pattern_22(query, key, value, attn_mask): + # for T5 with inplace add and return key and value + query = query.permute([0, 2, 1, 3]) + key = key.permute([0, 2, 1, 3]) + value = value.permute([0, 2, 1, 3]) + score = torch.matmul(query, key.permute(0, 1, 3, 2)) + masked_score = score + attn_mask + score = masked_score.type_as(query) + viewd_score1 = score.view( + score.size(0) * score.size(1), score.size(2), score.size(3) + ) + viewd_score2 = viewd_score1.view( + score.size(0), score.size(1), score.size(2), score.size(3) + ) + return viewd_score2.float().softmax(dim=-1).type_as(query).matmul(value), key, value + + +def _sfdp_replacement_22(query, key, value, attn_mask): + counters["inductor"]["fuse_attention"] += 1 + query = query.permute(0, 2, 1, 3) + key = key.permute(0, 2, 1, 3) + value = value.permute(0, 2, 1, 3) + return ( + _scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_mask, + is_causal=False, + scale=1.0, + ), + key, + value, + ) + + +def _sfdp_pattern_23(query, key, value): + # for T5 with inplace add and + # return key and value and + # attn_mask is generated by atem.full(..., 0) + query = query.permute([0, 2, 1, 3]) + key = key.permute([0, 2, 1, 3]) + 
value = value.permute([0, 2, 1, 3]) + score = torch.matmul(query, key.permute(0, 1, 3, 2)) + fp32_score = score.float() + score = fp32_score.type_as(query) + viewd_score1 = score.view( + score.size(0) * score.size(1), score.size(2), score.size(3) + ) + viewd_score2 = viewd_score1.view( + score.size(0), score.size(1), score.size(2), score.size(3) + ) + return viewd_score2.float().softmax(dim=-1).type_as(query).matmul(value), key, value + + +def _sfdp_replacement_23(query, key, value): + counters["inductor"]["fuse_attention"] += 1 + query = query.permute(0, 2, 1, 3) + key = key.permute(0, 2, 1, 3) + value = value.permute(0, 2, 1, 3) + return ( + _scaled_dot_product_attention( + query, + key, + value, + attn_mask=None, + is_causal=False, + scale=1.0, + ), + key, + value, + ) + + +def _sfdp_params_check(match): + assert all(k in match.kwargs for k in ("query", "key", "value")) + query = match.kwargs["query"].meta["val"] + key = match.kwargs["key"].meta["val"] + value = match.kwargs["value"].meta["val"] + if not (query.dtype == key.dtype == value.dtype) or not ( + query.device == key.device == value.device + ): + return False + add_mask_node = filter_nodes(match.nodes, aten.add.Tensor) + # Has attn_mask add. + if len(add_mask_node) > 0: + attn_mask_node = add_mask_node[0].args[1] + # attn_mask_node may be a float/int number. + if not hasattr(attn_mask_node, "meta"): + return False + attn_mask = attn_mask_node.meta["val"] # type: ignore[union-attr] + # Make sure attn_mask.dtype == query.dtype or attn_mask.dtype == torch.bool + # attn_mask.dtype == torch.float for models like albert. + if ( + not isinstance(attn_mask, torch.Tensor) + or not ( + attn_mask.dtype == query.dtype + or attn_mask.dtype == torch.bool + or attn_mask.dtype == torch.float + ) + or query.device != attn_mask.device + # When we tensorify floats we end up turning floats + # into 0d scalar tensors. It doesn't make any sense + # to have a 0d scalar tensor attention mask so + # conveniently we can insert this check to get + # tests that erroneously passing in a float + # attention mask to fail as expected. + or attn_mask.dim() == 0 + ): + return False + return True + + +def _sfdp_extra_check(scale_factor_op=None, disable_cuda=False): + def fn(match): + if ( + disable_cuda + and "query" in match.kwargs + and "cuda" in str(match.kwargs["query"].meta["val"].device) + ): + return False + if scale_factor_op is not None: + scale_factor_node = filter_nodes(match.nodes, scale_factor_op)[0] + # Note: args[1] of the scale_factor_node is always the scale_factor for the current patterns. + scale_factor = scale_factor_node.args[1] + # make sure the scale_factor a float/int. SymInt? 
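+            # NOTE (editorial): anything that is not a plain Python number here
+            # (e.g. an fx.Node for a tensor scale, or a SymInt/SymFloat under
+            # dynamic shapes) fails the isinstance check below, and the match
+            # is rejected.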
+ if not isinstance(scale_factor, (float, int)): + return False + return _sfdp_params_check(match) + + return fn + + +def partialize_and_update_signature(func, **kwargs): + """ + Equivalent to functools.partial but also updates the signature on returned function + """ + original_sig = inspect.signature(func) + parameters = original_sig.parameters + + new_parameters = { + key: value for key, value in parameters.items() if key not in kwargs + } + new_sig = inspect.Signature(parameters=list(new_parameters.values())) + + partial_func = functools.partial(func, **kwargs) + + def wrapper(*args, **kwargs): + return partial_func(*args, **kwargs) + + wrapper.__signature__ = new_sig # type: ignore[attr-defined] + wrapper.__name__ = func.__name__ + + return wrapper + + +def _get_sfdp_patterns(): + from .joint_graph import patterns + + if torch.cuda.is_available(): + # workaround https://github.com/pytorch/pytorch/issues/97894 + device = "cuda" + else: + device = "cpu" + + # sizes/values don't actually matter for initial trace + # once we get a possible match we re-trace with the actual values and verify the match still holds + g_inp = functools.partial( + torch.empty, (2, 4, 8, 16), device=device, requires_grad=True + ) + # attn_mask + b_inp = functools.partial(torch.empty, (1, 1, 8, 8), device=device) + m_inp = functools.partial(torch.empty, (2, 1, 1, 4), device=device) + # need 2d attn_mask to generate patterns with view op + m_inp_2d = functools.partial(torch.empty, (2, 4), device=device) + # inv_scale + c_inp = functools.partial(torch.tensor, 2.0, device=device) + # workaround https://github.com/pytorch/pytorch/issues/97894 + # 0.113377 is a "magic" value that lets us recover the lost input arg relationship + d = {"dropout_p": 0.113377} + + # we could also generate all these patterns in 3d.. TODO + g_3d_inp = functools.partial( + torch.empty, (1024, 128, 128), device=device, requires_grad=True + ) + + # reshape in matmul decomposition generates a clone when batch_size>1 due to the memory layout change. + # however when batch_size=1, reshape does not change the memory layout, so clone would not be generated. + # here we need to trace with input of batch_size=1 to generate a pattern graph without clone. 
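+    # (editorial note: this is what the *_bs1 candidates further down are for,
+    # e.g. patterns 16 and 18 are registered a second time with (1, 4, 8, 16)
+    # inputs so that the clone-free graph is matched as well.)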
+ g_bs1_inp = functools.partial( + torch.empty, (1, 4, 8, 16), device=device, requires_grad=True + ) + m_bs1_inp = functools.partial(torch.empty, (1, 1, 1, 4), device=device) + + # softmax will generate a dtype conversion on inputs if they are in half, + # but will not in float, so we generate a pattern for both + for dtype in [torch.float, torch.half]: + g = functools.partial(g_inp, dtype=dtype) + b = functools.partial(b_inp, dtype=dtype) + b_float = functools.partial(b_inp, dtype=torch.float) + b_bool = functools.partial(b_inp, dtype=torch.bool) + m = functools.partial(m_inp, dtype=dtype) + m_float = functools.partial(m_inp, dtype=torch.float) + m_bool = functools.partial(m_inp, dtype=torch.bool) + m_2d = functools.partial(m_inp_2d, dtype=dtype) + c = functools.partial(c_inp, dtype=dtype) + g_3d = functools.partial(g_3d_inp, dtype=dtype) + g_bs1 = functools.partial(g_bs1_inp, dtype=dtype) + m_bs1 = functools.partial(m_bs1_inp, dtype=dtype) + m_bs1_float = functools.partial(m_bs1_inp, dtype=torch.float) + m_bs1_bool = functools.partial(m_bs1_inp, dtype=torch.bool) + + candidates = [ + ( + _sfdp_pattern_1, + _sfdp_replacement_1, + [g(), g(), g(), c()], + {}, + _sfdp_extra_check(aten.div.Tensor), + ), + ( + _sfdp_pattern_2, + _sfdp_replacement_2, + [g(), g(), g(), c()], + {}, + _sfdp_extra_check(aten.mul.Tensor), + ), + ( + _sfdp_pattern_3, + _sfdp_replacement_3, + [g(), g(), g(), c()], + d, + _sfdp_extra_check(aten.div.Tensor), + ), + ( + _sfdp_pattern_4, + _sfdp_replacement_4, + [g(), g(), g(), c()], + d, + _sfdp_extra_check(aten.mul.Tensor), + ), + ( + _sfdp_pattern_5, + _sfdp_replacement_5, + [g(), g(), g(), b()], + {}, + _sfdp_params_check, + ), + ( + _sfdp_pattern_6, + _sfdp_replacement_6, + [g(), g(), g(), b()], + d, + _sfdp_params_check, + ), + ( + _sfdp_pattern_7, + _sfdp_replacement_7, + [g(), g(), g()], + d, + _sfdp_params_check, + ), + ( + _sfdp_pattern_8, + _sfdp_replacement_8, + [g(), g(), g()], + {}, + _sfdp_params_check, + ), + ( + _sfdp_pattern_9, + _sfdp_replacement_9, + [g(), g(), g()], + d, + _sfdp_params_check, + ), + ( + _sfdp_pattern_10, + _sfdp_replacement_10, + [g(), g(), g()], + {}, + _sfdp_params_check, + ), + ( + _sfdp_pattern_11, + _sfdp_replacement_11, + [g(), g(), g(), c()], + {}, + _sfdp_extra_check(aten.div.Tensor), + ), + ( + _sfdp_pattern_12, + _sfdp_replacement_12, + [g(), g(), g(), c()], + d, + _sfdp_extra_check(aten.div.Tensor), + ), + ( + _sfdp_pattern_13, + _sfdp_replacement_13, + [g_3d(), g_3d(), g_3d()], + d, + _sfdp_params_check, + ), + ( + _sfdp_pattern_14, + _sfdp_replacement_14, + [g(), g(), g(), m(), c()], + {}, + _sfdp_extra_check(aten.div.Tensor), + ), + ( + _sfdp_pattern_15, + _sfdp_replacement_15, + [g(), g(), g(), m_2d(), c()], + {}, + _sfdp_extra_check(aten.div.Tensor), + ), + # TODO: Enable CUDA after solving Bert accuracy issue of calling efficient attention + ( + _sfdp_pattern_16, + _sfdp_replacement_16, + [g(), g(), g(), m(), c()], + d, + _sfdp_extra_check(aten.div.Tensor, disable_cuda=True), + ), + ( + _sfdp_pattern_16, + _sfdp_replacement_16, + [g_bs1(), g_bs1(), g_bs1(), m_bs1(), c()], + d, + _sfdp_extra_check(aten.div.Tensor, disable_cuda=True), + ), + ( + _sfdp_pattern_17, + _sfdp_replacement_17, + [g(), g(), g(), m_2d(), c()], + d, + _sfdp_extra_check(aten.div.Tensor), + ), + ( + _sfdp_pattern_18, + _sfdp_replacement_18, + [g(), g(), g(), m_bool()], + d, + _sfdp_params_check, + ), + ( + _sfdp_pattern_18, + _sfdp_replacement_18, + [g_bs1(), g_bs1(), g_bs1(), m_bs1_bool()], + d, + _sfdp_params_check, + ), + ( + _sfdp_pattern_19, + 
_sfdp_replacement_19, + [g(), g(), g(), b_bool(), b_float()], + d, + _sfdp_params_check, + ), + ( + _sfdp_pattern_20, + _sfdp_replacement_20, + [g(), g(), g(), m_2d()], + d, + _sfdp_extra_check(aten.div.Tensor), + ), + ( + _sfdp_pattern_21, + _sfdp_replacement_21, + [g(), g(), g(), m_float()], + {}, + _sfdp_params_check, + ), + ( + _sfdp_pattern_22, + _sfdp_replacement_22, + [g(), g(), g(), m_float()], + {}, + _sfdp_params_check, + ), + ( + _sfdp_pattern_23, + _sfdp_replacement_23, + [g(), g(), g()], + {}, + _sfdp_params_check, + ), + ] + mask_fp32_patterns = ["pattern_16"] + if dtype == torch.half: + # Add inputs of bf16 q/k/v and fp32 mask, for models like albert. + candidates.append( + ( + _sfdp_pattern_16, + _sfdp_replacement_16, + [g(), g(), g(), m_float(), c()], + d, + _sfdp_extra_check(aten.div.Tensor, disable_cuda=True), + ) + ) + candidates.append( + ( + _sfdp_pattern_16, + _sfdp_replacement_16, + [g_bs1(), g_bs1(), g_bs1(), m_bs1_float(), c()], + d, + _sfdp_extra_check(aten.div.Tensor, disable_cuda=True), + ) + ) + + for pattern, replacement, args, workaround, extra_check in candidates: + # XXX: when adding a new pattern, re-run `gen_attention_patterns` so the pattern + # gets serialized to a python file and does not require tracing at runtime. + assert isinstance(workaround, dict) + name = pattern.__name__ + + if dtype != torch.float: + name += "_half" + if ( + any(p in name for p in mask_fp32_patterns) + and args[3].dtype == torch.float32 + ): + name += "_mask_fp32" + if args[0].size(0) == 1: + name += "_bs1" + + training_name = name + "_training" + yield ( + training_name, + { + "search_fn": pattern, + "replace_fn": replacement, + "example_inputs": args, + "trace_fn": joint_fwd_bwd, + "pass_dicts": patterns, + "extra_check": extra_check, + "scalar_workaround": workaround, + }, + ) + + if workaround: + assert len(workaround) == 1 and "dropout_p" in workaround + # functools.partial insufficient because we look at signature downstream + pattern = partialize_and_update_signature(pattern, dropout_p=0.0) + replacement = partialize_and_update_signature( + replacement, dropout_p=0.0 + ) + workaround = {} + + inference_name = name + "_inference" + yield ( + inference_name, + { + "search_fn": pattern, + "replace_fn": replacement, + "example_inputs": args, + "trace_fn": fwd_only, + "pass_dicts": patterns, + "extra_check": extra_check, + "scalar_workaround": workaround, + # with dropout turned into clone, we end up with a number of + # semantically identical graphs + "skip_duplicates": True, + }, + ) + + +@functools.cache +def _sfdp_init(): + for key, register_replacement_kwargs in _get_sfdp_patterns(): + gen_register_replacement(key, **register_replacement_kwargs) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/group_batch_fusion.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/group_batch_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..357a9d66cdad744d99df3ce129236985cb2b7ddb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/group_batch_fusion.py @@ -0,0 +1,1426 @@ +# mypy: allow-untyped-defs +import collections +import logging +import operator +from collections import OrderedDict +from collections.abc import Iterable, Iterator +from typing import Any, Optional + +import torch +from torch._dynamo.utils import counters, is_node_meta_valid +from torch._logging import trace_structured +from torch.fx.passes.graph_transform_observer import GraphTransformObserver +from 
torch.utils._ordered_set import OrderedSet + +from .. import config +from ..pattern_matcher import ( + CallFunctionVarArgs, + get_arg_value, + stable_topological_sort, +) +from ..utils import OPTIMUS_EXCLUDE_POST_GRAD + + +try: + # importing this will register fbgemm lowerings for inductor + import deeplearning.fbgemm.fbgemm_gpu.fb.inductor_lowerings # noqa: F401 + + has_fbgemm = True +except Exception: + has_fbgemm = False + +aten = torch.ops.aten + +log = logging.getLogger(__name__) + +DEFAULT_BETA = 1 +DEFAULT_ALPHA = 1 + +MIN_FUSE_SET_SIZE = 5 +MAX_FUSE_SET_SIZE = 300 +MAX_FUSE_SEARCH_DEPTH = 5 +# The maximum tensor size that can go into the fusion group +MAX_FUSE_TENSOR_SIZE_GROUP_LINEAR = 4096 +# Whether we only fuse nodes with same parent node +FUSE_NODES_WITH_SAME_PARENT = False +# Whether we enable the add broadcast in batch linear +SHAPE_BROADCAST_BATCH_LINEAR = False +# Whether we enable the fuse nodes with same users +Fuse_NODES_WITH_SAME_USERS = False + +# exclude these nodes from BFS +# excluding get item improves optimizer compilation time by 60s +SEARCH_EXCLUSIONS = OrderedSet([operator.getitem]) + + +default_graph_search_options = { + "min_fuse_set_size": MIN_FUSE_SET_SIZE, + "max_fuse_set_size": MAX_FUSE_SET_SIZE, + "max_fuse_search_depth": MAX_FUSE_SEARCH_DEPTH, + "max_fuse_tensor_size_group_linear": MAX_FUSE_TENSOR_SIZE_GROUP_LINEAR, + "fuse_nodes_with_same_parent": FUSE_NODES_WITH_SAME_PARENT, + "shape_broadcast_batch_linear": SHAPE_BROADCAST_BATCH_LINEAR, + "fuse_nodes_with_same_users": Fuse_NODES_WITH_SAME_USERS, +} + +graph_search_options = default_graph_search_options + + +def update_stack_example_value(node, metadata, dim=0, op=torch.stack): + """ + Update the example value of the node in the graph to enable followup split cat opt. + """ + if node is not None and hasattr(node, "meta"): + if op == torch.stack: + example_value = torch.stack(metadata, dim=dim) + elif op == torch.unbind: + example_value = torch.unbind(metadata, dim=dim) # type: ignore[assignment] + else: + return + node.meta["example_value"] = example_value + + +def update_pointwise_example_value(pointwise_node, input, other, op): + """ + Update the example value of the add node in the graph to enable followup split cat opt. 
+ """ + if pointwise_node is not None and hasattr(pointwise_node, "meta"): + if op == torch.add: + example_value = torch.add(input, other) + elif op == torch.mul: + example_value = torch.mul(input, other) + else: + return + pointwise_node.meta["example_value"] = example_value + + +class GroupBatchFusionBase: + def __init__(self, **kwargs) -> None: + self.graph_search_options = kwargs.pop( + "graph_search_options", default_graph_search_options + ) + + def match(self, node): + raise NotImplementedError("match called on base") + + def fuse(self, graph, subset): + raise NotImplementedError("fuse called on base") + + +PRE_GRAD_FUSIONS: dict[str, GroupBatchFusionBase] = {} +POST_GRAD_FUSIONS: dict[str, GroupBatchFusionBase] = {} + + +def register_fusion(name: str, pre_grad=True): + def decorator(fusion_cls: GroupBatchFusionBase): + if pre_grad: + PRE_GRAD_FUSIONS[name] = fusion_cls + else: + POST_GRAD_FUSIONS[name] = fusion_cls + return fusion_cls + + return decorator + + +def list_group_batch_fusions(pre_grad=True) -> list[str]: + if pre_grad: + return list(PRE_GRAD_FUSIONS.keys()) + else: + return list(POST_GRAD_FUSIONS.keys()) + + +def decompose_stack(graph: torch.fx.GraphModule, input_tensors: list[Any]) -> Any: + unsqueezed_inputs = [] + unsqueezed_inputs_meta = [] + for input_tensor in input_tensors: + unsqueezed_input = graph.call_function( # type: ignore[operator] + aten.unsqueeze, args=(input_tensor,), kwargs={"dim": 0} + ) + unsqueezed_inputs.append(unsqueezed_input) + unsqueezed_input.meta["val"] = aten.unsqueeze(input_tensor.meta["val"], dim=0) # type: ignore[assignment] + unsqueezed_inputs_meta.append(unsqueezed_input.meta["val"]) + stacked_inputs = graph.call_function( # type: ignore[operator] + aten.cat, args=(unsqueezed_inputs,), kwargs={"dim": 0} + ) + stacked_inputs.meta["val"] = aten.cat(unsqueezed_inputs_meta, dim=0) # type: ignore[assignment] + return stacked_inputs + + +class GroupFusion(GroupBatchFusionBase): + """ + Fuse ops in a group way, e.g, fuse mm/addmm of arbitrary input shapes with fbgemm.gmm. + """ + + +class BatchFusion(GroupBatchFusionBase): + """ + Fuse ops in a batch way, e.g, fuse mm/addmm of same input shapes with bmm. + """ + + +class BatchPointwiseOpsFusionFactory(BatchFusion): + def __init__(self, op, **kwargs) -> None: + super().__init__(**kwargs) + self.op = op + + +@register_fusion("batch_linear_post_grad", pre_grad=False) +class PostGradBatchLinearFusion(BatchFusion): + """ + Fuse ops in a batch way in post grad (aten level). 
+ """ + + def _addmm_node_can_be_fused(self, node: torch.fx.Node) -> bool: + # pyre-fixme[7]: Incompatible return type + return ( + node.kwargs.get("beta", DEFAULT_BETA) == DEFAULT_BETA + and node.kwargs.get("alpha", DEFAULT_ALPHA) == DEFAULT_ALPHA # type: ignore[return-value] + ) + + def _is_input_2d(self, input: torch.fx.Node) -> bool: + input_shapes = input.meta["val"].shape + return ( + len(input_shapes) == 2 + and isinstance(input_shapes[0], int) + and isinstance(input_shapes[1], int) + ) + + def match( + self, node: torch.fx.Node + ) -> Optional[tuple[str, int, int, int, bool, str]]: + if CallFunctionVarArgs(aten.mm).match(node): + input_m, weight_m = node.args + bias_m = None + + elif CallFunctionVarArgs(aten.addmm.default).match( + node + ) and self._addmm_node_can_be_fused(node): + bias_m, input_m, weight_m = node.args + else: + return None + # get the user of the node + if self.graph_search_options.get("fuse_nodes_with_same_users", False): + users = [user.target for user in node.users.keys()] + else: + users = "" # type: ignore[assignment] + # only handle the cases where inputs are 2D tensors + if not self._is_input_2d(input_m) or not self._is_input_2d(weight_m): # type: ignore[arg-type] + return None + m, k = input_m.meta["val"].shape # type: ignore[union-attr] + n = weight_m.meta["val"].shape[1] # type: ignore[union-attr] + batch_key = ("batch_linear_post_grad", m, k, n, bias_m is not None, str(users)) + return batch_key + + def fuse(self, graph: torch.fx.GraphModule, subset: list[torch.fx.Node]): + batch_inputs = [] + batch_weights = [] + batch_biases = [] + batch_nodes = [] + batch_inputs_meta = [] + batch_weights_meta = [] + batch_biases_meta = [] + + for node in subset: + if CallFunctionVarArgs(aten.addmm.default).match(node): + bias, input, weight = node.args + elif CallFunctionVarArgs(aten.mm.default).match(node): + input, weight = node.args + bias = None + batch_nodes.append(node) + batch_inputs.append(input) # type: ignore[possibly-undefined] + batch_weights.append(weight) # type: ignore[possibly-undefined] + batch_biases.append(bias) # type: ignore[possibly-undefined] + batch_inputs_meta.append(input.meta) # type: ignore[possibly-undefined, union-attr] + batch_weights_meta.append(weight.meta) # type: ignore[possibly-undefined, union-attr] + if bias is not None: # type: ignore[possibly-undefined] + batch_biases_meta.append(bias.meta) # type: ignore[possibly-undefined, union-attr] + else: + batch_biases_meta.append(None) + + with graph.inserting_before(subset[-1]): # type: ignore[operator] + fused_inputs = decompose_stack(graph, batch_inputs) + fused_weights = decompose_stack(graph, batch_weights) + fused_inputs_meta_val = torch.stack( + [input["val"] for input in batch_inputs_meta] + ) + fused_weights_meta_val = torch.stack( + [weight["val"] for weight in batch_weights_meta] + ) + fused_bmm = graph.call_function( # type: ignore[operator] + aten.bmm, + args=(fused_inputs, fused_weights), + ) + fused_bmm.meta["val"] = aten.bmm( + fused_inputs_meta_val, fused_weights_meta_val + ) + for i, original_mm in enumerate(batch_nodes): + has_bias = False + with graph.inserting_after(fused_bmm): # type: ignore[operator] + new_mm = graph.call_function(aten.select, args=((fused_bmm, 0, i))) # type: ignore[operator] + new_mm.meta["val"] = aten.select(fused_bmm.meta["val"], 0, i) + if batch_biases[i]: + has_bias = True + # broadcast the bias to the same shape as the mm output + if self.graph_search_options.get( + "shape_broadcast_batch_linear", False + ): + broadcast_shape = 
torch.broadcast_shapes( + batch_biases_meta[i]["val"].shape, new_mm.meta["val"].shape + ) + broadcast_bias = graph.call_function( # type: ignore[operator] + aten.broadcast_to.default, + args=(batch_biases[i],), + kwargs={"size": broadcast_shape}, + ) + broadcast_bias.meta["val"] = aten.broadcast_to( + batch_biases_meta[i]["val"], broadcast_shape + ) # type: ignore[assignment] + new_bias_add = graph.call_function( # type: ignore[operator] + aten.add.Tensor, args=((broadcast_bias, new_mm)) + ) + new_bias_add.meta["val"] = aten.add.Tensor( + broadcast_bias.meta["val"], new_mm.meta["val"] + ) + else: + new_bias_add = graph.call_function( # type: ignore[operator] + aten.add, args=((batch_biases[i], new_mm)) + ) + new_bias_add.meta["val"] = aten.add.Tensor( + batch_biases_meta[i]["val"], new_mm.meta["val"] + ) + new_mm_cont = new_bias_add if has_bias else new_mm # type: ignore[possibly-undefined] + original_mm.replace_all_uses_with(new_mm_cont) + new_mm_cont.meta.update(original_mm.meta) + graph.erase_node(original_mm) # type: ignore[operator] + counters["inductor"]["batch_linear_post_grad"] += 1 + + +@register_fusion("group_linear", pre_grad=False) +class GroupLinearFusion(GroupFusion): + def _addmm_node_can_be_fused(self, node: torch.fx.Node): + input_shape = node.args[1].meta["val"].shape # type: ignore[union-attr] + weight_shape = node.args[2].meta["val"].shape # type: ignore[union-attr] + return ( + node.kwargs.get("beta", DEFAULT_BETA) == DEFAULT_BETA + and node.kwargs.get("alpha", DEFAULT_ALPHA) == DEFAULT_ALPHA + and len(input_shape) == 2 + and len(weight_shape) == 2 + and all(x % 2 == 0 for x in input_shape + weight_shape) + and all( + shape <= self.graph_search_options["max_fuse_tensor_size_group_linear"] + for shape in input_shape + weight_shape + ) + ) + + def _mm_node_can_be_fused(self, node: torch.fx.Node): + input_shape = node.args[0].meta["val"].shape # type: ignore[union-attr] + weight_shape = node.args[1].meta["val"].shape # type: ignore[union-attr] + return ( + len(input_shape) == 2 + and len(weight_shape) == 2 + and all(x % 2 == 0 for x in input_shape + weight_shape) + and all( + shape <= self.graph_search_options["max_fuse_tensor_size_group_linear"] + for shape in input_shape + weight_shape + ) + ) + + def match(self, node: torch.fx.Node) -> Optional[tuple[str, bool]]: + if CallFunctionVarArgs(aten.mm.default).match( + node + ) and self._mm_node_can_be_fused(node): + group_key = ("group_linear", True) + elif CallFunctionVarArgs(aten.addmm.default).match( + node + ) and self._addmm_node_can_be_fused(node): + bias = node.args[0] + group_key = ("group_linear", bias is None) + else: + group_key = None + return group_key + + def fuse(self, graph: torch.fx.GraphModule, subset: list[torch.fx.Node]): + group_inputs = [] + group_weights = [] + group_biases = [] + group_nodes = [] + for node in subset: + if CallFunctionVarArgs(aten.addmm.default).match(node): + bias, input, weight = node.args + else: + assert CallFunctionVarArgs(aten.mm.default).match(node) + input, weight = node.args + bias = None + + group_nodes.append(node) + group_inputs.append(input) + group_weights.append(weight) + group_biases.append(bias) + + if all(bias is None for bias in group_biases): + group_biases = None # type: ignore[assignment] + + with graph.inserting_before(subset[0]): # type: ignore[operator] + fused_mm = graph.call_function( # type: ignore[operator] + torch.ops.fbgemm.gmm.default, + args=(group_inputs, group_weights, group_biases), + kwargs={"smart_fused": True}, + ) + + for i, original_mm in 
enumerate(group_nodes): + with graph.inserting_after(fused_mm): # type: ignore[operator] + new_mm = graph.call_function(operator.getitem, args=(fused_mm, i)) # type: ignore[operator] + original_mm.replace_all_uses_with(new_mm) + new_mm.meta.update(original_mm.meta) + graph.erase_node(original_mm) # type: ignore[operator] + counters["inductor"]["group_linear"] += 1 + + +class BatchPointwiseMathOpsPostGradFusion(BatchPointwiseOpsFusionFactory): + """ + Batch pointwise math operator (e.g., add, mul) in post grad pass. + """ + + def __init__(self, op, **kwargs) -> None: + super().__init__(op, **kwargs) + self.op = op + + def _pointwise_node_can_be_fused(self, node: torch.fx.Node): + # note: we only consider the case where the inputs are tensors + # for mixed precision training, we need to make sure the inputs + # of the aten.cat when do the stack should be the same dtype + # otherwise, the output of the aten.cat may be not the same as + # its inputs, and cause dtype not same error in mm or addmm + input, other = node.args + return ( + input.meta["val"].shape == other.meta["val"].shape # type: ignore[union-attr] + # input and other can be scalars, where they have no attribute 'meta' + if hasattr(input, "meta") + and hasattr(other, "meta") + and is_node_meta_valid(input) # type: ignore[arg-type, union-attr] + and is_node_meta_valid(other) # type: ignore[arg-type, union-attr] + # torch.SymInt or torch.SymFloat object has no attribute 'shape' + and isinstance(input.meta["val"], torch.Tensor) # type: ignore[union-attr] + and isinstance(other.meta["val"], torch.Tensor) # type: ignore[union-attr] + else False + ) + + def match(self, node: torch.fx.Node): + if CallFunctionVarArgs(self.op).match( + node + ) and self._pointwise_node_can_be_fused(node): + alpha = node.kwargs.get("alpha", DEFAULT_ALPHA) + rounding_mode = node.kwargs.get("rounding_mode", None) + input, other = node.args + shape = list(input.meta["val"].shape) # type: ignore[union-attr] + if self.graph_search_options.get("fuse_nodes_with_same_parent", False): + # only consider the linear case so far + # pyre-fixme[16] + if input.target == aten.select or other.target == aten.select: # type: ignore[union-attr] + parent = ( + # pyre-fixme[16] + input.args[0] # type: ignore[union-attr] + # pyre-fixme[16] + if input.target == aten.select # type: ignore[union-attr] + else other.args[0] # type: ignore[union-attr] + ) + else: + parent = "" + else: + parent = "" + group_key = ( + "batch_aten_" + self.op.__name__.lower().split(".")[0], + str(shape), + str(input.meta["val"].dtype), # type: ignore[union-attr] + str(other.meta["val"].dtype), # type: ignore[union-attr] + str(alpha), + str(rounding_mode), + str(parent), + ) + else: + group_key = None + return group_key + + def fuse(self, graph: torch.fx.GraphModule, subset: list[torch.fx.Node]): + batch_inputs, batch_others = [], [] + alpha = subset[0].kwargs.get("alpha", DEFAULT_ALPHA) + batch_inputs_meta, batch_others_meta = [], [] + + for node in subset: + input, other = node.args + batch_inputs.append(input) + batch_others.append(other) + batch_inputs_meta.append(input.meta) # type: ignore[possibly-undefined, union-attr] + batch_others_meta.append(other.meta) # type: ignore[possibly-undefined, union-attr] + + with graph.inserting_before(subset[0]): # type: ignore[operator] + stack_inputs = decompose_stack(graph, batch_inputs) + stack_others = decompose_stack(graph, batch_others) + stack_inputs_meta = torch.stack( + [input["val"] for input in batch_inputs_meta] + ) + stack_others_meta = torch.stack( + 
[other["val"] for other in batch_others_meta] + ) + + batch_op = graph.call_function( # type: ignore[operator] + self.op, + args=(stack_inputs, stack_others), + kwargs={"alpha": alpha} if self.op == aten.add.Tensor else {}, + ) + batch_op.meta["val"] = self.op(stack_inputs_meta, stack_others_meta) + for i, original_add in enumerate(subset): + with graph.inserting_after(batch_op): # type: ignore[operator] + new_add = graph.call_function( # type: ignore[operator] + torch.ops.aten.select, args=((batch_op, 0, i)) + ) + original_add.replace_all_uses_with(new_add) + new_add.meta.update(original_add.meta) + graph.erase_node(original_add) # type: ignore[operator] + counters["inductor"][ + "batch_aten_" + self.op.__name__.lower().split(".")[0] + ] += 1 + + +@register_fusion("batch_linear_lhs") +class BatchLinearLHSFusion(BatchFusion): + """ + Batch linear left-hand side fusion. This pass tries to fuse the following patterns: + + torch.nn.functional.linear(x, w1), linear(x, w2),... * linear(x, wn) + -> torch.mm(x, torch.cat([w1, w2,... * wn]).transpose(0, 1)) + + We have a separate pass to eliminate contiguous transpose in a generic way. + """ + + def match(self, node: torch.fx.Node) -> Optional[tuple[str, bool, Any]]: + if CallFunctionVarArgs(torch.nn.functional.linear).match( + node + ) and is_linear_node_can_be_fused(node): + input = get_arg_value(node, 0, "input") + bias = get_arg_value(node, 2, "bias") + group_key = ("batch_linear_lhs", bias is None, input) + else: + group_key = None + return group_key + + def fuse(self, graph: torch.fx.GraphModule, subset: list[torch.fx.Node]): + batch_nodes = [] + batch_input = None + batch_weights, batch_weights_meta = [], [] + batch_biases, batch_biases_meta = [], [] + split_sections = [] + for node in subset: + input = get_arg_value(node, 0, "input") + weight = get_arg_value(node, 1, "weight") + bias = get_arg_value(node, 2, "bias") + batch_nodes.append(node) + if batch_input is None: + batch_input = input + else: + assert batch_input is input + batch_weights.append(weight) + batch_weights_meta.append(weight.meta["example_value"]) + if bias: + batch_biases.append(bias) + batch_biases_meta.append(bias.meta["example_value"]) + split_sections.append(weight.meta["example_value"].shape[0]) + + with graph.inserting_before(subset[0]): # type: ignore[operator] + cat_weights = graph.call_function( # type: ignore[operator] + torch.cat, args=(batch_weights,), kwargs={"dim": 0} + ) + cat_weights.meta["example_value"] = torch.cat(batch_weights_meta, dim=0) + transposed_weights = graph.call_function( # type: ignore[operator] + torch.transpose, args=(cat_weights, 0, 1) + ) + transposed_weights.meta["example_value"] = torch.transpose( + cat_weights.meta["example_value"], 0, 1 + ) + if len(batch_biases) > 0: + cat_biases = graph.call_function( # type: ignore[operator] + torch.cat, args=(batch_biases,), kwargs={"dim": 0} + ) + cat_biases.meta["example_value"] = torch.cat(batch_biases_meta, dim=0) + fused_lhs = graph.call_function( # type: ignore[operator] + torch.addmm, + args=(cat_biases, batch_input, transposed_weights), + ) + fused_lhs.meta["example_value"] = torch.addmm( + cat_biases.meta["example_value"], + batch_input.meta["example_value"], # type: ignore[union-attr] + transposed_weights.meta["example_value"], + ) + else: + fused_lhs = graph.call_function( # type: ignore[operator] + torch.mm, + args=(batch_input, transposed_weights), + ) + fused_lhs.meta["example_value"] = torch.mm( + batch_input.meta["example_value"], # type: ignore[union-attr] + 
transposed_weights.meta["example_value"], + ) + fused_lhs_list = graph.call_function( # type: ignore[operator] + torch.split, args=(fused_lhs, split_sections), kwargs={"dim": 1} + ) + + for i, node in enumerate(batch_nodes): + with graph.inserting_after(fused_lhs_list): # type: ignore[operator] + new_node = graph.call_function( # type: ignore[operator] + operator.getitem, args=(fused_lhs_list, i) + ) + node.replace_all_uses_with(new_node) + new_node.meta.update(node.meta) + graph.erase_node(node) # type: ignore[operator] + counters["inductor"]["batch_linear_lhs"] += 1 + + +# Poor person's check for if a node in the graph mutates its input. +# (the graph is torch IR, so we will see torch fns and python operators) +def _is_mutable_node(tgt): + if str(tgt).endswith("_"): + # e.g. torch.mul_, torch.Tensor.mul_ + return True + if ( + hasattr(tgt, "__module__") + and tgt.__module__ == "_operator" + and tgt.__name__.startswith("i") + ): + # e.g. operator.iand, operator.imul + return True + return False + + +def is_linear_node_can_be_fused(node: torch.fx.Node): + input = get_arg_value(node, 0, "input") + weight = get_arg_value(node, 1, "weight") + return ( + is_node_meta_valid(node) + and is_node_meta_valid(input) + and is_node_meta_valid(weight) + and len(input.meta["example_value"].shape) == 2 + and len(weight.meta["example_value"].shape) == 2 + # the mm -> bmm transform adds an unbind() op, + # which is not safe for autograd when the output of the mm is mutated. + # don't pattern match if any users of the mm mutate the input. + and not any(_is_mutable_node(user.target) for user in node.users) + ) + + +@register_fusion("batch_linear") +class PreGradBatchLinearFusion(BatchFusion): + """ + Batch linear fusion in pre grad pass. + Fuse linear with same size with torch.baddmm + """ + + def _getitem_args(self, getitem_node: torch.fx.Node): + if getitem_node.target != operator.__getitem__ or ( + getitem_node.op != "call_function" + ): + return None + return getitem_node.args[0] + + def match(self, node: torch.fx.Node): + if CallFunctionVarArgs(torch.nn.functional.linear).match( + node + ) and is_linear_node_can_be_fused(node): + input = get_arg_value(node, 0, "input") + weight = get_arg_value(node, 1, "weight") + bias = get_arg_value(node, 2, "bias") + if self.graph_search_options.get("fuse_nodes_with_same_users", False): + users = [user.target for user in node.users.keys()] + else: + users = "" # type: ignore[assignment] + group_key = ( + "batch_linear", + self._getitem_args(input), + str(input.meta["example_value"].shape), + str(weight.meta["example_value"].shape), + bias is None, + str(users), + ) + else: + group_key = None + return group_key + + def fuse(self, graph: torch.fx.GraphModule, subset: list[torch.fx.Node]): + batch_nodes = [] + batch_inputs = [] + batch_weights = [] + batch_biases = [] + batch_inputs_metadata = [] + batch_weights_metadata = [] + batch_biases_metadata = [] + for node in subset: + batch_nodes.append(node) + input = get_arg_value(node, 0, "input") + batch_inputs.append(input) + batch_inputs_metadata.append(input.meta["example_value"]) + weight = get_arg_value(node, 1, "weight") + batch_weights.append(weight) + batch_weights_metadata.append(weight.meta["example_value"]) + bias = get_arg_value(node, 2, "bias") + batch_biases.append(bias) + if bias is not None and hasattr(bias, "meta"): + batch_biases_metadata.append(bias.meta["example_value"]) + + with graph.inserting_before(subset[0]): # type: ignore[operator] + stack_inputs = graph.call_function( # type: ignore[operator] + 
torch.stack, args=(batch_inputs,), kwargs={"dim": 0} + ) + update_stack_example_value(stack_inputs, batch_inputs_metadata) + stack_weights = graph.call_function( # type: ignore[operator] + torch.stack, args=(batch_weights,), kwargs={"dim": 0} + ) + update_stack_example_value(stack_weights, batch_weights_metadata) + transpose_weight = graph.call_function( # type: ignore[operator] + torch.transpose, args=(stack_weights, 1, 2) + ) + transpose_weight.meta["example_value"] = torch.transpose( + stack_weights.meta["example_value"], 1, 2 + ) + if all(bias is None for bias in batch_biases): + bmm = graph.call_function( # type: ignore[operator] + torch.bmm, + args=(stack_inputs, transpose_weight), + ) + bmm.meta["example_value"] = torch.bmm( + stack_inputs.meta["example_value"], + transpose_weight.meta["example_value"], + ) + bmm_meta = bmm.meta["example_value"] + else: + stack_biases = graph.call_function( # type: ignore[operator] + torch.stack, args=(batch_biases,), kwargs={"dim": 0} + ) + update_stack_example_value(stack_biases, batch_biases_metadata) + unsqueeze_biases = graph.call_function( # type: ignore[operator] + torch.unsqueeze, args=(stack_biases, 1) + ) + unsqueeze_biases.meta["example_value"] = torch.unsqueeze( + stack_biases.meta["example_value"], 1 + ) + bmm = graph.call_function( # type: ignore[operator] + torch.baddbmm, + args=(unsqueeze_biases, stack_inputs, transpose_weight), + ) + try: + # it will have runtime error to broadcast when it has dynamic shape included + # in the meta data, so we need to skip the update meta data + bmm.meta["example_value"] = torch.baddbmm( + unsqueeze_biases.meta["example_value"], + stack_inputs.meta["example_value"], + transpose_weight.meta["example_value"], + ) + bmm_meta = bmm.meta["example_value"] + except Exception as e: + log.debug( + f" exception when update bmm meta data with stack error tracekey {e}" # noqa: G004 + ) + bmm_meta = None + + bmm = graph.call_function(torch.unbind, args=(bmm,), kwargs={"dim": 0}) # type: ignore[operator] + if bmm_meta is not None: + bmm.meta["example_value"] = torch.unbind(bmm_meta, dim=0) + for i, linear in enumerate(batch_nodes): + with graph.inserting_after(bmm): # type: ignore[operator] + getitem = graph.call_function(operator.getitem, args=(bmm, i)) # type: ignore[operator] + linear.replace_all_uses_with(getitem) + getitem.meta.update(linear.meta) + graph.erase_node(linear) # type: ignore[operator] + counters["inductor"]["batch_linear"] += 1 + + +@register_fusion("batch_layernorm") +class BatchLayernormFusion(BatchFusion): + """ + Batch layer norm fusion in pre grad pass + """ + + def match(self, node: torch.fx.Node): + if CallFunctionVarArgs(torch.nn.functional.layer_norm).match(node): + input = get_arg_value(node, 0, "input") + weight = get_arg_value(node, 2, "weight") + bias = get_arg_value(node, 3, "bias") + if self.graph_search_options.get("fuse_nodes_with_same_users", False): + users = [user.target for user in node.users.keys()] + else: + users = "" # type: ignore[assignment] + group_key = ( + ( + "batch_layernorm", + str(input.meta["example_value"].shape), + str(weight.meta["example_value"].shape) + if weight is not None + else "", + str(bias.meta["example_value"].shape) if bias is not None else "", + str(get_arg_value(node, 1, "normalized_shape")), + str(get_arg_value(node, 4, "eps")), + str(users), + ) + if "example_value" in input.meta + and is_node_meta_valid(weight) + and is_node_meta_valid(bias) + else None + ) + else: + group_key = None + return group_key + + def fuse(self, graph: 
torch.fx.GraphModule, subset: list[torch.fx.Node]): + group_inputs = [] + group_shapes = [] + group_weights = [] + group_biases = [] + group_epss = [] + group_nodes = [] + group_inputs_metadata = [] + group_biases_metadata = [] + group_weights_metadata = [] + for node in subset: + group_nodes.append(node) + input = get_arg_value(node, 0, "input") + group_inputs.append(input) + group_inputs_metadata.append(input.meta["example_value"]) + group_shapes.append(get_arg_value(node, 1, "normalized_shape")) + weight = get_arg_value(node, 2, "weight") + group_weights.append(weight) + if weight is not None and hasattr(weight, "meta"): + group_weights_metadata.append(weight.meta["example_value"]) + bias = get_arg_value(node, 3, "bias") + group_biases.append(bias) + if bias is not None and hasattr(bias, "meta"): + group_biases_metadata.append(bias.meta["example_value"]) + eps = get_arg_value(node, 4, "eps") + if eps is None: + eps = 1e-5 + group_epss.append(eps) + stack_dim = -1 - len(group_shapes[-1]) + + if all(bias is None for bias in group_biases): + group_biases = None # type: ignore[assignment] + if all(weight is None for weight in group_weights): + group_weights = None # type: ignore[assignment] + assert all(eps == group_epss[0] for eps in group_epss), ( + "all epsilon values must be equal" + ) + + with graph.inserting_before(subset[0]): # type: ignore[operator] + stack_input = graph.call_function( # type: ignore[operator] + torch.stack, args=(group_inputs,), kwargs={"dim": stack_dim} + ) + update_stack_example_value(stack_input, group_inputs_metadata, stack_dim) + if group_weights is not None: + stack_weight = graph.call_function( # type: ignore[operator] + torch.stack, args=(group_weights,), kwargs={"dim": 0} + ) + update_stack_example_value(stack_weight, group_weights_metadata) + else: + stack_weight = None + if group_biases is not None: + stack_bias = graph.call_function( # type: ignore[operator] + torch.stack, args=(group_biases,), kwargs={"dim": 0} + ) + update_stack_example_value(stack_bias, group_biases_metadata) + else: + stack_bias = None + + batch_layer_norm = graph.call_function( # type: ignore[operator] + torch.nn.functional.layer_norm, + args=(stack_input, group_shapes[-1]), + kwargs={"eps": group_epss[-1]}, + ) + batch_layer_norm.meta["example_value"] = stack_input.meta["example_value"] + + if group_weights is not None and group_biases is not None: + previous_batch_layer_norm_meta = batch_layer_norm.meta["example_value"] + batch_layer_norm = graph.call_function( # type: ignore[operator] + torch.mul, args=(stack_weight, batch_layer_norm) + ) + update_pointwise_example_value( + batch_layer_norm, + stack_weight.meta["example_value"], + previous_batch_layer_norm_meta, + torch.mul, + ) + previous_batch_layer_norm_meta = batch_layer_norm.meta["example_value"] + batch_layer_norm = graph.call_function( # type: ignore[operator] + torch.add, args=(stack_bias, batch_layer_norm) + ) + update_pointwise_example_value( + batch_layer_norm, + stack_bias.meta["example_value"], + previous_batch_layer_norm_meta, + torch.add, + ) + elif group_weights is not None and group_biases is None: + previous_batch_layer_norm_meta = batch_layer_norm.meta["example_value"] + batch_layer_norm = graph.call_function( + torch.mul, args=(stack_weight, batch_layer_norm) + ) + update_pointwise_example_value( + batch_layer_norm, + stack_weight.meta["example_value"], + previous_batch_layer_norm_meta, + torch.mul, + ) + elif group_weights is None and group_biases is not None: + previous_batch_layer_norm_meta = 
batch_layer_norm.meta["example_value"] + batch_layer_norm = graph.call_function( + torch.add, args=(stack_bias, batch_layer_norm) + ) + update_pointwise_example_value( + batch_layer_norm, + stack_bias.meta["example_value"], + previous_batch_layer_norm_meta, + torch.add, + ) + + batch_layer_norm_unbind = graph.call_function( # type: ignore[operator] + torch.unbind, + args=(batch_layer_norm,), + kwargs={"dim": stack_dim}, + ) + update_stack_example_value( + batch_layer_norm_unbind, + batch_layer_norm.meta["example_value"], + op=torch.unbind, + dim=stack_dim, + ) + + for i, node in enumerate(group_nodes): + with graph.inserting_after(batch_layer_norm_unbind): # type: ignore[operator] + new_node = graph.call_function( # type: ignore[operator] + operator.getitem, args=(batch_layer_norm_unbind, i) + ) + node.replace_all_uses_with(new_node) + new_node.meta.update(node.meta) + graph.erase_node(node) # type: ignore[operator] + counters["inductor"]["batch_layernorm"] += 1 + + +class BatchPointwiseOpsPreGradFusion(BatchPointwiseOpsFusionFactory): + """ + Batch pointwise ops (e.g., sigmoid, relu, tanh) fusion in pre grad pass. + We fuse it in random place, and the introduced stack node may be merged in split cat. + """ + + def __init__(self, op, **kwargs) -> None: + super().__init__(op, **kwargs) + self.op = op + + def match(self, node: torch.fx.Node): + input = get_arg_value(node, 0, "input") + if CallFunctionVarArgs(self.op).match(node) and is_node_meta_valid(node): + if self.graph_search_options.get("fuse_nodes_with_same_parent", False): + # pyre-fixme[16] + parent = node.args[0] + parent = parent.target if parent is not None else "" # type: ignore[union-attr] + else: + parent = "" + # for relu op, we also use the inplace to construct the key + group_key = ( + "batch_" + self.op.__name__.lower().split(".")[0], + str(input.meta["example_value"].shape), + str(node.kwargs.get("inplace", False)), + str(parent), + ) + else: + group_key = None + return group_key + + def fuse(self, graph: torch.fx.GraphModule, subset: list[torch.fx.Node]): + batch_nodes = [] + batch_inputs = [] + batch_inputs_metadata = [] + + for node in subset: + batch_nodes.append(node) + input = get_arg_value(node, 0, "input") + batch_inputs.append(input) + batch_inputs_metadata.append(input.meta["example_value"]) + + with graph.inserting_before(subset[0]): # type: ignore[operator] + stack_inputs = graph.call_function( # type: ignore[operator] + torch.stack, args=(batch_inputs,), kwargs={"dim": 0} + ) + update_stack_example_value(stack_inputs, batch_inputs_metadata) + if self.op == torch.nn.functional.relu: + batch_op = graph.call_function( # type: ignore[operator] + self.op, + args=(stack_inputs,), + kwargs={"inplace": subset[0].kwargs.get("inplace", False)}, + ) + batch_op.meta["example_value"] = self.op( + stack_inputs.meta["example_value"], + inplace=subset[0].kwargs.get("inplace", False), + ) + else: + batch_op = graph.call_function( # type: ignore[operator] + self.op, + args=(stack_inputs,), + ) + batch_op.meta["example_value"] = self.op( + stack_inputs.meta["example_value"] + ) + unbind_op = graph.call_function( # type: ignore[operator] + torch.unbind, args=(batch_op,), kwargs={"dim": 0} + ) + unbind_op.meta["example_value"] = torch.unbind( + batch_op.meta["example_value"], dim=0 + ) + for i, node in enumerate(batch_nodes): + with graph.inserting_after(unbind_op): # type: ignore[operator] + getitem = graph.call_function(operator.getitem, args=(unbind_op, i)) # type: ignore[operator] + node.replace_all_uses_with(getitem) + 
getitem.meta.update(node.meta) + graph.erase_node(node) # type: ignore[operator] + counters["inductor"]["batch_" + self.op.__name__.lower().split(".")[0]] += 1 + + +class BatchPointwiseOpsPostGradFusion(BatchPointwiseOpsFusionFactory): + """ + Batch pointwise ops (e.g., sigmoid, relu, tanh) fusion in post grad pass. + The introduced stack node may be merged in split cat. + """ + + def __init__(self, op, **kwargs) -> None: + super().__init__(op, **kwargs) + self.op = op + + def match(self, node: torch.fx.Node): + input = get_arg_value(node, 0, "input") + if CallFunctionVarArgs(self.op).match(node) and is_node_meta_valid(node): + # for relu op, we also use the inplace to construct the key + # we batch the ops with same parent to enable followup split cat + parent = node.args[0] + parent = ( + parent.target # type: ignore[union-attr] + if self.graph_search_options.get("fuse_nodes_with_same_parent", False) + else "" + ) + group_key = ( + "batch_aten_" + self.op.__name__.lower().split(".")[0], + str(input.meta["val"].shape), + str(node.kwargs.get("inplace", False)), + # pyre-fixme[16] + str(parent), + ) + else: + group_key = None + return group_key + + def fuse(self, graph: torch.fx.GraphModule, subset: list[torch.fx.Node]): + batch_nodes = [] + batch_inputs = [] + batch_inputs_metadata = [] + + for node in subset: + batch_nodes.append(node) + input = get_arg_value(node, 0, "input") + batch_inputs.append(input) + batch_inputs_metadata.append(input.meta["val"]) + + with graph.inserting_before(subset[0]): # type: ignore[operator] + stack_inputs = decompose_stack(graph, batch_inputs) + update_stack_example_value(stack_inputs, batch_inputs_metadata) + batch_op = graph.call_function( # type: ignore[operator] + self.op, + args=(stack_inputs,), + ) + for i, node in enumerate(batch_nodes): + with graph.inserting_after(batch_op): # type: ignore[operator] + getitem = graph.call_function(aten.select, args=(batch_op, 0, i)) # type: ignore[operator] + node.replace_all_uses_with(getitem) + getitem.meta.update(node.meta) + graph.erase_node(node) # type: ignore[operator] + counters["inductor"][ + "batch_aten_" + self.op.__name__.lower().split(".")[0] + ] += 1 + + +class BatchMathOpsPreGradFusion(BatchPointwiseOpsFusionFactory): + """ + Batch simple match related ops such as nan_to_num in pre grad pass. + """ + + def __init__(self, op, **kwargs): + super().__init__(op, **kwargs) + self.op = op + + def match(self, node: torch.fx.Node): + input = get_arg_value(node, 0, "input") + if CallFunctionVarArgs(self.op).match(node) and is_node_meta_valid(node): + # check the input has the same shape and its users have the same target + # check all clamp operators have the same min and max values, and + # nan_to_num operators use the same default value. 
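+ # Illustrative example (not from the original source): two torch.clamp(x, min=0.0, max=1.0)
+ # nodes whose inputs both have example_value shape [8, 128], identical kwargs, and a child
+ # node with the same target share the group_key built below and can be batched together.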
+ child = next(iter(node.users.keys())) + group_key = ( + str(input.meta["example_value"].shape) + + str(node.kwargs) + + str(child.target) + ) + else: + group_key = None + return group_key + + def fuse(self, graph: torch.fx.GraphModule, subset: list[torch.fx.Node]): + batch_nodes = [] + batch_inputs = [] + batch_inputs_metadata = [] + kwargs = subset[0].kwargs + + for node in subset: + batch_nodes.append(node) + input = get_arg_value(node, 0, "input") + batch_inputs.append(input) + batch_inputs_metadata.append(input.meta["example_value"]) + + with graph.inserting_before(subset[0]): # type: ignore[operator] + stack_inputs = graph.call_function( # type: ignore[operator] + torch.stack, args=(batch_inputs,), kwargs={"dim": 0} + ) + update_stack_example_value(stack_inputs, batch_inputs_metadata) + batch_op = graph.call_function( # type: ignore[operator] + self.op, + args=(stack_inputs,), + kwargs=kwargs, + ) + batch_op.meta["example_value"] = self.op( + stack_inputs.meta["example_value"], **kwargs + ) + unbind_op = graph.call_function( # type: ignore[operator] + torch.unbind, args=(batch_op,), kwargs={"dim": 0} + ) + unbind_op.meta["example_value"] = torch.unbind( + batch_op.meta["example_value"], dim=0 + ) + for i, node in enumerate(batch_nodes): + with graph.inserting_after(unbind_op): # type: ignore[operator] + getitem = graph.call_function(operator.getitem, args=(unbind_op, i)) # type: ignore[operator] + node.replace_all_uses_with(getitem) + getitem.meta.update(node.meta) + graph.erase_node(node) # type: ignore[operator] + counters["inductor"]["batch_" + self.op.__name__.lower().split(".")[0]] += 1 + + +@register_fusion("batch_tanh") +class BatchTanhPreGradFusion(BatchPointwiseOpsPreGradFusion): + def __init__(self, **kwargs) -> None: + super().__init__(torch.tanh, **kwargs) + + +@register_fusion("batch_sigmoid") +class BatchSigmoidPreGradFusion(BatchPointwiseOpsPreGradFusion): + def __init__(self, **kwargs) -> None: + super().__init__(torch.sigmoid, **kwargs) + + +@register_fusion("batch_relu") +class BatchReLuPreGradFusion(BatchPointwiseOpsPreGradFusion): + def __init__(self, **kwargs) -> None: + super().__init__(torch.nn.functional.relu, **kwargs) + + +@register_fusion("batch_detach") +class BatchDetachPreGradFusion(BatchMathOpsPreGradFusion): + def __init__(self, **kwargs): + super().__init__(torch.detach, **kwargs) + + +@register_fusion("batch_nan_to_num") +class BatchNanToNumPreGradFusion(BatchMathOpsPreGradFusion): + def __init__(self, **kwargs): + super().__init__(torch.nan_to_num, **kwargs) + + +@register_fusion("batch_clamp") +class BatchClampPreGradFusion(BatchMathOpsPreGradFusion): + def __init__(self, **kwargs): + super().__init__(torch.clamp, **kwargs) + + +@register_fusion("batch_aten_tanh", pre_grad=False) +class BatchTanhPostGradFusion(BatchPointwiseOpsPostGradFusion): + def __init__(self, **kwargs) -> None: + super().__init__(aten.tanh.default, **kwargs) + + +@register_fusion("batch_aten_sigmoid", pre_grad=False) +class BatchSigmoidPostGradFusion(BatchPointwiseOpsPostGradFusion): + def __init__(self, **kwargs) -> None: + super().__init__(aten.sigmoid.default, **kwargs) + + +@register_fusion("batch_aten_relu", pre_grad=False) +class BatchReLuPostGradFusion(BatchPointwiseOpsPostGradFusion): + def __init__(self, **kwargs) -> None: + super().__init__(aten.relu.default, **kwargs) + + +@register_fusion("batch_aten_add", pre_grad=False) +class BatchAddPostGradFusion(BatchPointwiseMathOpsPostGradFusion): + def __init__(self, **kwargs) -> None: + super().__init__(aten.add.Tensor, 
**kwargs) + + +@register_fusion("batch_aten_sub", pre_grad=False) +class BatchSubPostGradFusion(BatchPointwiseMathOpsPostGradFusion): + def __init__(self, **kwargs) -> None: + super().__init__(aten.sub.Tensor, **kwargs) + + +@register_fusion("batch_aten_div", pre_grad=False) +class BatchDivPostGradFusion(BatchPointwiseMathOpsPostGradFusion): + def __init__(self, **kwargs) -> None: + super().__init__(aten.div.Tensor, **kwargs) + + +@register_fusion("batch_aten_mul", pre_grad=False) +class BatchMulPostGradFusion(BatchPointwiseMathOpsPostGradFusion): + def __init__(self, **kwargs) -> None: + super().__init__(aten.mul.Tensor, **kwargs) + + +class _OrderedSet: + def __init__(self, param=None) -> None: + if param: + self.rep = OrderedDict(dict.fromkeys(param)) + else: + self.rep = OrderedDict() + + def __contains__(self, o) -> bool: + return o in self.rep + + def __len__(self) -> int: + return self.rep.__len__() + + def append(self, o): + self.rep[o] = None + + def __iter__(self): + return self.rep.keys().__iter__() + + +def find_independent_subset_greedy( + node_list: Iterable[torch.fx.Node], + graph_search_options: dict[str, Any], +) -> Iterator[Iterable[torch.fx.Node]]: + """ + Yields a list of subsets of `node_list` where no element in the subset + depends on any other element in the subset. This results in a set of + independent nodes which can be fused together. + + The order of `node_list` is preserved within each subset so we can benefit + from split-cat elimination in later passes. + + During iteration it is only safe to mutate the graph by changing the nodes + that have been returned. + + graph_search_options: + - min_fuse_set_size: Minimum size of the subset to consider. Subsets below + this size will be ignored. + - max_fuse_set_size: Maximum size of the subset to consider. Subsets will + be broken to be at most this size. + """ + + # Compute all the children of `node` which are members of + # `interesting_nodes`. + def find_dependent_nodes(node, interesting_nodes): + visited_node_set = OrderedSet[torch.fx.Node]() + dep_set = OrderedSet[torch.fx.Node]() + + work = [node] + while work: + node = work.pop() + for input_node in node.all_input_nodes: + if input_node in interesting_nodes: + dep_set.add(input_node) + + if input_node not in visited_node_set: + visited_node_set.add(input_node) + work.append(input_node) + + return dep_set + + min_fuse_set_size = graph_search_options["min_fuse_set_size"] + max_fuse_set_size = graph_search_options["max_fuse_set_size"] + + # node_list needs to be a set because we only track the nodes that are left + # in it (and we want to do the `in` on a set, not a list). But we want to + # keep the correct order. 
+ node_list = _OrderedSet(node_list) + + cache: dict[torch.fx.Node, OrderedSet[torch.fx.Node]] = {} + while node_list: + subset: list[torch.fx.Node] = [] + subset_deps = OrderedSet[torch.fx.Node]() + + next_round_node_list = _OrderedSet() + for node in node_list: + if len(subset) >= max_fuse_set_size or node in subset_deps: + next_round_node_list.append(node) + continue + + dep_set = cache.pop(node, None) + if dep_set is None: + dep_set = find_dependent_nodes(node, node_list) + + if not dep_set.intersection(subset): + subset.append(node) + subset_deps.update(dep_set) + else: + next_round_node_list.append(node) + cache[node] = dep_set + + if len(subset) >= min_fuse_set_size: + # Careful here - the caller uses the subsets to fuse nodes together + # so we need to clear any cache entry that contains one of the + # returned nodes because the dependency list could be different + # (larger) after the merge. + cache = {k: v for k, v in cache.items() if v.isdisjoint(subset)} + yield subset + + node_list = next_round_node_list + + +def get_fusion_candidates( + rule: GroupBatchFusionBase, + root_node: torch.fx.Node, + fused_set: OrderedSet[torch.fx.Node], +) -> collections.defaultdict[Any, list[torch.fx.Node]]: + """ + Search fusion candidates for a specific rule using BFS starting from the root node. + We only search the subgraph within graph_search_options["max_fuse_search_depth"]. + """ + q: collections.deque[tuple[int, torch.fx.Node]] = collections.deque() + + candidate_dict: collections.defaultdict[Any, list[torch.fx.Node]] = ( + collections.defaultdict(list) + ) + + if root_node.target in SEARCH_EXCLUSIONS: + return candidate_dict + + visited_set = OrderedSet[torch.fx.Node]() + + for next_node in root_node.all_input_nodes: + q.append((1, next_node)) + visited_set.add(next_node) + + while len(q) > 0: + depth, node = q.popleft() + + if node in fused_set: + continue + + key = rule.match(node) + if key is not None: + candidate_nodes = candidate_dict[key] + if node not in candidate_nodes: + candidate_nodes.append(node) + else: + if depth < rule.graph_search_options["max_fuse_search_depth"]: + for next_node in node.all_input_nodes: + if next_node not in visited_set: + visited_set.add(next_node) + q.append((depth + 1, next_node)) + + return candidate_dict + + +def apply_group_batch_fusion(graph: torch.fx.GraphModule, rule: GroupBatchFusionBase): + stable_topological_sort(graph) # type: ignore[arg-type] + fused_set = OrderedSet[torch.fx.Node]() + log_to_scuba = False + + for node in reversed(graph.nodes): # type: ignore[arg-type] + candidates = get_fusion_candidates(rule, node, fused_set) + + for key, candidate_nodes in candidates.items(): + if len(candidate_nodes) < rule.graph_search_options["min_fuse_set_size"]: + continue + + for subset in find_independent_subset_greedy( + candidate_nodes, rule.graph_search_options + ): + rule.fuse(graph, subset) + fused_set.update(subset) + log.debug( + f"{rule.__class__.__name__}: key = {key}; subset size = {len(list(subset))}" # noqa: G004 + ) + log_to_scuba = True + if log_to_scuba: + from torch.fx._lazy_graph_module import _LazyGraphModule + + # Force graph to re-compile otherwise the output python code may be broken + gm = graph._owning_module + if isinstance(gm, _LazyGraphModule): + _LazyGraphModule.recompile() + else: + assert isinstance(gm, torch.fx.GraphModule) + gm.recompile() + graph_str = gm.print_readable( + print_output=False, include_stride=True, include_device=True + ) + + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": 
f"optimus_{str(rule.__class__.__name__)}", + "encoding": "string", + }, + payload_fn=lambda: graph_str, + ) + + +def generate_fusion_from_config(config_options: dict[str, Any], pre_grad=True): + fusions: list[GroupBatchFusionBase] = [] + for name, options in config_options.items(): + # we skip all patterns from pattern_matcher passes (e.g., split_cat) + if name not in PRE_GRAD_FUSIONS and name not in POST_GRAD_FUSIONS: + continue + fusion_cls = PRE_GRAD_FUSIONS[name] if pre_grad else POST_GRAD_FUSIONS[name] + _options = graph_search_options.copy() + _options.update(options) + fusions.append(fusion_cls(graph_search_options=_options)) # type: ignore[operator] + return fusions + + +def group_batch_fusion_passes(graph: torch.fx.Graph, pre_grad=True): + fusions: list[GroupBatchFusionBase] = [] + # we keep all current pre grad fusions to keep + # current implementation, will remove this later + if pre_grad: + fusions += generate_fusion_from_config( + config.pre_grad_fusion_options, pre_grad=True + ) + else: + fbgemm_fusion_keys = [ + x + for x in config.post_grad_fusion_options + if ( + x not in OPTIMUS_EXCLUDE_POST_GRAD + and config.post_grad_fusion_options[x].get("require_fbgemm", False) + ) + ] + fbgemm_fusions = { + fusion: config.post_grad_fusion_options[fusion] + for fusion in fbgemm_fusion_keys + } + non_fbgemm_fusions = { + fusion: config.post_grad_fusion_options[fusion] + for fusion in config.post_grad_fusion_options.keys() + if fusion not in fbgemm_fusion_keys + } + fusions += generate_fusion_from_config(non_fbgemm_fusions, pre_grad=False) + if has_fbgemm: + fusions += generate_fusion_from_config(fbgemm_fusions, pre_grad=False) + + for i, rule in enumerate(fusions): + with GraphTransformObserver( + graph.owning_module, + f"group_batch_fusion_{i}", + ): + apply_group_batch_fusion(graph, rule) # type: ignore[arg-type] diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/joint_graph.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/joint_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..c9d7187de0d9bcccc20196976dbe2795168a9357 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/joint_graph.py @@ -0,0 +1,943 @@ +# mypy: allow-untyped-defs +import functools +import itertools +import logging +import operator +import typing +from collections import Counter +from collections.abc import Sequence +from typing import Any, Union + +import torch +import torch._guards +import torch.utils._pytree as pytree +from torch._dynamo.utils import counters +from torch._inductor.constant_folding import ConstantFolder +from torch._inductor.fx_passes.dedupe_symint_uses import _SymHashingDict +from torch._inductor.utils import get_gpu_type +from torch.fx.experimental.symbolic_shapes import ( + guard_or_false, + guard_or_true, + statically_known_true, +) +from torch.multiprocessing.reductions import StorageWeakRef +from torch.utils._ordered_set import OrderedSet + +from .. 
import config +from ..pattern_matcher import ( + Arg, + CallFunction, + init_once_fakemode, + KeywordArg, + Match, + MULTIPLE, + PatternMatcherPass, + register_graph_pattern, + stable_topological_sort, +) +from .decompose_mem_bound_mm import check_device +from .replace_random import replace_random_passes + + +log = logging.getLogger(__name__) +patterns = PatternMatcherPass() +aten = torch.ops.aten +prims = torch.ops.prims + +pass_patterns = [ + patterns, + PatternMatcherPass(), +] + + +@init_once_fakemode +def lazy_init(): + from .fuse_attention import _sfdp_init + from .misc_patterns import _misc_patterns_init + from .pad_mm import _pad_mm_init + + _pad_mm_init() + _sfdp_init() + _misc_patterns_init() + + +def remove_no_ops( + gm: torch.fx.GraphModule, + zeros: OrderedSet[torch.fx.Node], + ones: OrderedSet[torch.fx.Node], +): + with torch.utils._python_dispatch._disable_current_modes(): + "Removes no-ops: (+ 0, - 0, * 1, / 1)" + graph = gm.graph + + def fake_tensors_eq(t1, t2, fields=("shape", "dtype", "device")): + if any(not isinstance(t, torch.Tensor) for t in (t1, t2)): + return False + for field in fields: + if getattr(t1, field) != getattr(t2, field): + return False + return True + + def replace_no_op(node, replace_input_index): + replacement = node.args[replace_input_index] + + # https://github.com/pytorch/pytorch/issues/86128 causes + # non-Tensor inputs even for ops with only Tensor inputs. + # TODO - decompose/type promote to avoid this + if not all(isinstance(arg, torch.fx.Node) for arg in node.args): + return + + if not fake_tensors_eq(node.meta["val"], replacement.meta["val"]): + if fake_tensors_eq( + node.meta["val"], + replacement.meta["val"], + ("shape", "device"), + ): + with graph.inserting_after(node): + replacement = graph.call_function( + torch.ops.prims.convert_element_type.default, + args=(replacement, node.meta["val"].dtype), + ) + else: + return + + node.replace_all_uses_with(replacement) + replacement.meta.update(node.meta) + graph.erase_node(node) + + for node in graph.find_nodes(op="call_function", target=aten.add.Tensor): + # TODO handle Tensor-Scalar adds, it's a different schema + if len(node.args) == 2: + if ( + not any(e in zeros for e in node.args) + or node.kwargs.get("alpha", 1) != 1 + ): + continue + + replace_index = 1 if node.args[0] in zeros else 0 + replace_no_op(node, replace_index) + + for node in graph.find_nodes(op="call_function", target=aten.sub.Tensor): + if len(node.args) == 2: + if node.args[1] not in zeros or node.kwargs.get("alpha", 1) != 1: + continue + + replace_no_op(node, 0) + + for node in graph.find_nodes(op="call_function", target=aten.mul.Tensor): + if len(node.args) == 2: + if not any(e in ones for e in node.args): + continue + + replace_input_index = 1 if node.args[0] in ones else 0 + replace_no_op(node, replace_input_index) + + for node in graph.find_nodes(op="call_function", target=aten.div.Tensor): + if len(node.args) == 2 and node.args[1] in ones: + replace_no_op(node, 0) + + # meta tensors returned from the graph have no data and can be replaced with empty_strided + for output_node in graph.find_nodes(op="output"): + had_meta_return = False + + def visit(n): + nonlocal had_meta_return + val = n.meta.get("val") + if isinstance(val, torch.Tensor) and val.device.type == "meta": + with graph.inserting_before(output_node): + n.replace_all_uses_with( + graph.call_function( + torch.ops.aten.empty_strided.default, + args=(val.size(), val.stride()), + kwargs={"dtype": val.dtype, "device": val.device}, + ) + ) + had_meta_return = 
True + + torch.fx.map_arg(output_node.args, visit) + if had_meta_return: + graph.eliminate_dead_code() + + +def remove_redundant_views(gm: torch.fx.GraphModule): + """ + Removes redundant views by reusing existing ones. + """ + with torch.utils._python_dispatch._disable_current_modes(): + # A dictionary mapping a tensor to all aliased views. + views: dict[torch.fx.Node, dict[torch.dtype, torch.fx.Node]] = {} + graph = gm.graph + + for node in graph.find_nodes( + op="call_function", target=torch.ops.aten.view.dtype + ): + src = node.args[0] + to_type = node.args[1] + existing_views = views.get(src) + is_needed = True + + if existing_views: + # Replace the view with the an existing view if available. + alias = existing_views.get(to_type) + if alias: + is_needed = False + node.replace_all_uses_with(alias) + alias.meta.update(node.meta) + graph.erase_node(node) + else: + from_type = src.meta["val"].dtype + existing_views = {from_type: src} + views[src] = existing_views + + if is_needed: + # Save the new alias but do not replace existing one. + existing_views.setdefault(to_type, node) + views[node] = existing_views + + # Clean up unused views. + while True: + unused_views = [alias for alias in views if not alias.users] + if len(unused_views) == 0: + break + for unused in unused_views: + views.pop(unused) + graph.erase_node(unused) + + +class UniformValueConstantFolder(ConstantFolder): + """ + Runs constant folding and replaces tensors that have a uniform value + with a tensor constructor call: aten.full([shape], value, ...) + """ + + def __init__(self, gm, skip_constructors=False) -> None: + super().__init__(gm, skip_constructors) + self.node_storages_ptrs: dict[torch.fx.Node, int] = {} + self.constant_data_ptrs: dict[torch.fx.Node, StorageWeakRef] = {} + # we may constant fold a tensor which in the graph has a sym size + # see: [constant folding refining of symints] + self.node_replacements_shapes: dict[torch.fx.Node, list[int]] = {} + + # initialize symint -> node mapping so that we can + # use symint nodes in full constructors + self.symint_nodes = _SymHashingDict() + for n in self.module.graph.nodes: # type: ignore[union-attr] + if "val" in n.meta and isinstance(n.meta["val"], torch.SymInt): + self.symint_nodes[n.meta["val"]] = n + + # reference from torch/_funtorch/partitioners.py:get_default_op_list + self.view_op_packets = [ + aten.squeeze, + aten.unsqueeze, + aten.alias, + aten.view, + aten.slice, + aten.t, + prims.broadcast_in_dim, + aten.expand, + aten.as_strided, + aten.permute, + ] + + self.indexing_op_packets = OrderedSet( + [ + aten.slice, + ] + ) + + self._add_peephole_patterns() + + def _add_peephole_patterns(self) -> None: + """ + Add peephole patterns for nodes where we can infer constant value even if some inputs + of the node are unknown. 
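+ For example, a node like aten.mul.Tensor(x, 0) is known to produce zeros regardless of x,
+ so it can be folded to a single-element constant even though x itself is unknown
+ (illustrative description added here; the concrete handling follows below).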
+ """ + for op in itertools.chain( + self.module.graph.find_nodes( # type: ignore[operator, union-attr] + op="call_function", target=torch.ops.aten.mul.Tensor + ), + self.module.graph.find_nodes( # type: ignore[operator, union-attr] + op="call_function", target=torch.ops.aten.mul.Scalar + ), + ): + tensor_val = op.meta.get("val", None) + if not isinstance(tensor_val, torch.Tensor): + continue + + def is_zero_int(arg: Any) -> bool: + return isinstance(arg, int) and arg == 0 + + if not any(is_zero_int(a) for a in op.args): + continue + + t = torch.full( + [1], # shape + 0, # value + dtype=tensor_val.dtype, + device=tensor_val.device, + pin_memory=False, + ) + self.add_node_replacement(op, t) + + def _support_dynamic_shape(self): + return True + + def insertable_tensor_check(self, t: torch.Tensor) -> bool: + return True + + def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None: + self.node_replacements[node] = tensor.flatten()[0].item() + self.node_replacements_shapes[node] = node.meta["val"].shape + self.constant_data_ptrs[node] = StorageWeakRef(tensor.untyped_storage()) + + def insert_placerholder_values(self, env: dict[torch.fx.Node, Any]) -> None: + for n in self.module.graph.find_nodes(op="placeholder"): # type: ignore[operator, union-attr] + if "val" in n.meta and isinstance(n.meta["val"], torch.SymInt): + env[n] = n.meta["val"] + else: + env[n] = self.unknown_value + + def _deduce_value(self, node: torch.fx.Node): + # deduce value for full-like nodes + # 1. for constructors, substitute value is a tensor of size [1] + # 2. for view ops/indexing, substitute value is the same as the input + # 3. for pointwise ops, run node to get the substitute value + # 4. deal with some special ops + # otherwise, stop deduce value and return unknown value + + # TODO: cat, more indexing + # TODO - do on cpu to avoid syncs + + # single-elem attrs + if node.op == "get_attr" or ( + node.op == "call_function" + and node.target == torch.ops.aten.lift_fresh_copy.default + ): + out = super(ConstantFolder, self).run_node(node) + if isinstance(out, torch.Tensor) and out.numel() == 1: + return out + + # handle device_put op + if node.target == prims.device_put.default: + return super(ConstantFolder, self).run_node(node) + + # constructors ops + if ( + node.op == "call_function" + and node.target == aten.full.default + and len(node.args) == 2 + ): + args, kwargs = self.fetch_args_kwargs_from_env(node) + value = args[1] + # Don't specialize symbolic value. 
+ if not isinstance(value, (torch.SymInt, torch.SymFloat, torch.SymBool)): + new_args = [[1], value] + return aten.full.default(*new_args, **node.kwargs) + + # handle before view ops because this changes value + if node.target == aten.view.dtype: + return super(ConstantFolder, self).run_node(node) + + # view ops, return input tensor, the first argument + if hasattr(node.target, "overloadpacket") and ( + node.target.overloadpacket in self.view_op_packets + or node.target.overloadpacket in self.indexing_op_packets + ): + assert isinstance(node.args[0], torch.fx.Node) + return self.env[node.args[0]] + + # we don't want to return unknown value for symints so that we can + # still constant fold through their use in constructors or views + # if we see them in a pointwise node (e.g., tensor * symint) + # we will bail + if "val" in node.meta and isinstance(node.meta["val"], torch.SymInt): + return node.meta["val"] + + # pointwise ops + if isinstance(node.target, torch._ops.OpOverload) and ( + torch.Tag.pointwise in node.target.tags + or node.target is torch.ops.aten.scalar_tensor.default + ): + args, kwargs = self.fetch_args_kwargs_from_env(node) + flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs) + + if any(isinstance(inp, torch.SymInt) for inp in flattened_inputs): + return self.unknown_value + + # we run the ops with dim 1, so remove memory_format to avoid error + kwargs = dict(kwargs) + kwargs.pop("memory_format", None) + + return node.target(*args, **kwargs) + + return self.unknown_value + + +def constant_fold_uniform_value(gm: torch.fx.GraphModule): + with torch.utils._python_dispatch._disable_current_modes(): + "Runs constant folding and replaces constants which can be constructed with a single `full` call. Calls into remove_no_ops." + aten = torch.ops.aten + + # Constant folding can leak memory, especially with repeated compilation, so we are only going to + # remove constants which can be replaced with a constructor. + cf = UniformValueConstantFolder(gm) + cf.run() + + node_replacements = cf.node_replacements + + # note: [constant folding refining of symints] + # constant folding will partially evaluate a graph such that values which have dependencies which + # are entirely known at compile time may also become compile time constants. in some cases, + # this will include symints which we had not yet previously deduced are guaranteed a + # constant value and is then deduced in constant folding. 
an example is: + # unbacked_symint_eq_11 = torch.full((), 11).item() + # torch.full((unbacked_symint_eq_11,), 0) + node_replacements_shapes = cf.node_replacements_shapes + + graph = gm.graph + + zeros = OrderedSet[Any]() + ones = OrderedSet[Any]() + + # Got failures in `test_is_set_to_cuda` if we change aliasing on constants, + # so just constant-ify if a Tensor is unaliased + constant_data_ptr_count: typing.Counter[StorageWeakRef] = Counter() + + for node in cf.node_replacements: + constant_data_ptr_count[cf.constant_data_ptrs[node]] += 1 + + for node, value in node_replacements.items(): + # we dont have a functional way right now of instantiating a non-contiguous tensor with full/zeros/ones right now + # hasn't shown up to be important yet + if "val" not in node.meta: + # This can only happen in AOTI + continue + + fake_tensor = node.meta["val"] + if not fake_tensor.is_contiguous(memory_format=torch.contiguous_format): + continue + + # TODO - not sure about lossy uint->python value->uint conversions + if fake_tensor.dtype in ( + torch.uint8, + torch.uint16, + torch.uint32, + torch.uint64, + ): + continue + + if constant_data_ptr_count[cf.constant_data_ptrs[node]] > 1: + continue + + with graph.inserting_after(node): + # the conversion from tensor and back to value can be lossy, just use the original full ctor value + if ( + node.op == "call_function" + and node.target == aten.full.default + and len(node.args) == 2 + ): + value = node.args[1] + + # refines symints, see [constant folding refining of symints] above + for runtime_size, compile_time_size in zip( + node_replacements_shapes[node], fake_tensor.shape + ): + torch._check(runtime_size == compile_time_size) + + # replace SymInt as Node before creating a new full node + # e.g. (1, s0) -> (1, arg0_1) + node_shape = node_replacements_shapes[node] + if not all( + not isinstance(s, torch.SymInt) or s in cf.symint_nodes + for s in node_shape + ): + continue + + shapes = [ + cf.symint_nodes[s] if isinstance(s, torch.SymInt) else s + for s in node_replacements_shapes[node] + ] + + # zeros and ones just get traced into full, so we insert those + new_node = graph.call_function( + aten.full.default, + args=(shapes, value), + kwargs={ + "dtype": fake_tensor.dtype, + "layout": torch.strided, + "device": fake_tensor.device, + "pin_memory": False, + }, + ) + + new_node.meta.update(node.meta) + node.replace_all_uses_with(new_node) + graph.erase_node(node) + + if value == 0: + zeros.add(new_node) + elif value == 1: + ones.add(new_node) + + remove_no_ops(gm, zeros, ones) + remove_redundant_views(gm) + + +def canonicalize_quant_mapping(gm: torch.fx.GraphModule): + """ + + + torch.ops.higher_order.invoke_quant_packed(repeated_subgraph0, 'quant_invoke_0_0', (arg0_1, arg1_1)); + -> + torch.ops.higher_order.invoke_quant(repeated_subgraph0, arg0_1, arg1_1, scheme = 'nf4'); + """ + graph = gm.graph + invoke_quant_invocations = graph.find_nodes( + op="call_function", target=torch.ops.higher_order.invoke_quant_packed + ) + for invoke_quant in invoke_quant_invocations: + kwargs = dict(invoke_quant.kwargs) + + quant_options_node = kwargs.pop("quant_options", None) + if quant_options_node is not None: + assert isinstance(quant_options_node, torch.fx.Node) + quant_options = torch._higher_order_ops.InvokeQuant( + *invoke_quant.kwargs["quant_options"].args, + **invoke_quant.kwargs["quant_options"].kwargs, + ) + else: + quant_options = torch._higher_order_ops.InvokeQuant() + + subgraph, *args = invoke_quant.args + with gm.graph.inserting_before(invoke_quant): + 
invoke_quant_replacement = graph.call_function( + torch._higher_order_ops.invoke_quant, + (subgraph, *args), + kwargs, + ) + invoke_quant_replacement.meta.update(subgraph.meta) + invoke_quant_replacement.meta["quant_options"] = quant_options + + invoke_quant.replace_all_uses_with(invoke_quant_replacement) + graph.erase_node(invoke_quant) + + if quant_options_node and len(quant_options_node.users) == 0: + graph.erase_node(quant_options_node) + + first_user = next(iter(invoke_quant_replacement.users)) + + if ( + len(invoke_quant_replacement.users) == 1 + and len(subgraph.users) == 1 + and first_user.target == operator.getitem + and first_user.args[1] == 0 + ): + subgraph_graph = getattr(gm, subgraph.target) + output_node = torch._inductor.utils.output_node(subgraph_graph) + assert ( + isinstance(output_node.args[0], (list, tuple)) + and len(output_node.args[0]) == 1 + ) + + unpacked_output = output_node.args[0][0] + output_node.args = (unpacked_output,) + if "val" in output_node.meta: + output_node.meta["val"] = output_node.meta["val"][0] + subgraph_graph.recompile() + + invoke_quant_replacement.meta.update(first_user.meta) + first_user.replace_all_uses_with(invoke_quant_replacement) + graph.erase_node(first_user) + + +def canonicalize_aten_ir_passes(gm: torch.fx.GraphModule): + """ + Canonicalization passes that will run immediately after aot autograd + tracing. Thsis must be run before all other graph passes. + """ + canonicalize_quant_mapping(gm) + + +def joint_graph_passes(graph: torch.fx.GraphModule): + """ + Run FX transformations on the joint forwards+backwards graph. + """ + GraphTransformObserver = functools.partial( + torch.fx.passes.graph_transform_observer.GraphTransformObserver, + subsystem="joint_graph_passes", + ) + + lazy_init() + count = 0 + + # must occur before other passes + canonicalize_aten_ir_passes(graph) + + if config.joint_custom_pre_pass is not None: + GraphTransformObserver(graph, "joint_custom_pre_pass").apply_graph_pass( + config.joint_custom_pre_pass + ) + count += 1 + + from .post_grad import remove_noop_ops + + GraphTransformObserver(graph, "remove_noop_ops").apply_graph_pass(remove_noop_ops) + + if config.joint_graph_constant_folding: + GraphTransformObserver(graph, "constant_fold_uniform_value").apply_gm_pass( + constant_fold_uniform_value + ) + + if config.joint_custom_pre_pass is not None: + GraphTransformObserver(graph, "joint_custom_pre_pass").apply_graph_pass( + config.joint_custom_pre_pass + ) + count += 1 + + if config.pattern_matcher: + for i, patterns in enumerate(pass_patterns): + maybe_count = GraphTransformObserver( + graph, f"pass_pattern_{i}" + ).apply_graph_pass(patterns.apply) + count += maybe_count if maybe_count is not None else 0 + + if not config.fallback_random: + # not trying into the bisector because decomps may have already affected rng reproducibility + # we'll instead explicitly turn off the config + count += replace_random_passes(graph) + + if config.joint_custom_post_pass is not None: + GraphTransformObserver(graph, "joint_custom_post_pass").apply_graph_pass( + config.joint_custom_post_pass + ) + count += 1 + + if count: + stable_topological_sort(graph.graph) + graph.graph.lint() + graph.recompile() + return graph + + +@register_graph_pattern( + CallFunction( + torch.ops.prims.iota.default, + KeywordArg("length"), + start=KeywordArg("start"), + step=KeywordArg("step"), + dtype=KeywordArg("dtype"), + device=KeywordArg("device"), + requires_grad=KeywordArg("requires_grad"), + ), + pass_dict=patterns, +) +def fix_iota_device(match: 
Match, length, start, step, dtype, device, requires_grad): + """ + Eager supports: + + aten.index(cuda_tensor, torch.arange(..., device="cpu")) + + But this results in an implicit host-device-copy and breaks cudagraphs. + Rewrite the arange to use CUDA. + """ + (node,) = match.nodes + user_devices = OrderedSet[torch.device]() + for user in node.users: + if ( + user.op == "call_function" + and user.target in (aten.index.Tensor, aten.index_put.default) + and hasattr(user.meta.get("val"), "device") + ): + user_devices.add(user.meta["val"].device) # type: ignore[union-attr] + else: + return # bail out + + if len(user_devices) == 1 and "val" in node.meta: + (user_device,) = user_devices + if device.type != user_device.type: + repl = match.graph.call_function( + torch.ops.prims.iota.default, + (length,), + { + "start": start, + "step": step, + "dtype": dtype, + "device": user_device, + "requires_grad": requires_grad, + }, + ) + repl.meta.update(node.meta) + repl.meta["val"] = repl.meta["val"].to(user_device) + node.replace_all_uses_with(repl) + match.erase_nodes() + + +@register_graph_pattern( + CallFunction( + torch.ops.prims.convert_element_type.default, + CallFunction( + torch.ops.prims.convert_element_type.default, + KeywordArg("arg"), + KeywordArg("dtype1"), + ), + KeywordArg("dtype2"), + ), + pass_dict=patterns, +) +def pointless_convert(match: Match, arg, dtype1: torch.dtype, dtype2: torch.dtype): + """Remove chain of dtype conversions often created by AMP""" + graph = match.graph + node = match.output_node() + allowed = torch.float16, torch.bfloat16, torch.float32, torch.float64 + if dtype1 in allowed and dtype2 in allowed: + repl = graph.call_function( + torch.ops.prims.convert_element_type.default, (arg, dtype2) + ) + repl.meta.update(node.meta) + node.replace_all_uses_with(repl) + match.erase_nodes() + + +def definitely_equal( + old_sizes: Sequence[Union[torch.SymInt, int]], + new_sizes: Sequence[Union[torch.SymInt, torch.fx.Node, int]], +) -> bool: + """ + Leverage guard_or_true/false to compare if two lists of int/symint are equal. + Useful to compare sizes, strides etc. + + Can handle -1 in new_sizes which happens in the size arguments of a + view op. old_sizes is supposed to be the tensor shape and should not + contain -1. + + new_sizes can contains fx.Node when dynamic shape is enabled. In that + case new_sizes[i].meta['val'] contains the real torch.SymInt. + """ + + num_neg1 = 0 + + if len(old_sizes) != len(new_sizes): + return False + + for lhs_item, rhs_item in zip(old_sizes, new_sizes): + if isinstance(rhs_item, torch.fx.Node): + rhs_item = rhs_item.meta["val"] + + assert isinstance(lhs_item, (int, torch.SymInt)), type(lhs_item) + assert isinstance(rhs_item, (int, torch.SymInt)), type(rhs_item) + + # It still makes sense to call guard_or_true/false since lhs_item + # rhs_item are torch.SymInt rather than sympy expressions when + # dynamic shape is enabled. 
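+ # Illustrative example (not in the original source): old_sizes = (8, 16, 4) and
+ # new_sizes = (8, -1, 4) compare equal here (the single -1 acts as a wildcard),
+ # whereas a second -1 in new_sizes would make the function return False.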
+ if guard_or_false(lhs_item == rhs_item): + continue + + if guard_or_true(rhs_item != -1): + return False + + num_neg1 += 1 + + if num_neg1 > 1: + return False + return True + + +@register_graph_pattern( + CallFunction(torch.ops.aten.view.default, KeywordArg("arg"), KeywordArg("size")), + pass_dict=patterns, +) +def pointless_view(match: Match, arg, size): + """Remove no-op view""" + node = match.output_node() + arg_size = list(node.args[0].meta["val"].shape) # type: ignore[union-attr] + if definitely_equal(arg_size, size): + node.replace_all_uses_with(node.args[0]) # type: ignore[arg-type] + match.erase_nodes() + + +@register_graph_pattern( + CallFunction( + aten.view.default, + CallFunction(aten.view.default, KeywordArg("arg"), KeywordArg("size1")), + KeywordArg("size2"), + ), + pass_dict=patterns, +) +def pointless_view_pair(match: Match, arg, size1, size2): + """ + Remove a pair of views that are pointless. + """ + node = match.output_node() + arg_size = list(arg.meta["val"].shape) + if definitely_equal(arg_size, size2): + node.replace_all_uses_with(arg) + match.erase_nodes() + counters["inductor"]["removed_pointless_view_pair"] += 1 + + +@register_graph_pattern( + CallFunction( + aten.permute.default, + CallFunction(aten.permute.default, KeywordArg("arg"), KeywordArg("perm1")), + KeywordArg("perm2"), + ), + pass_dict=patterns, +) +def pointless_permute_pair(match: Match, arg, perm1, perm2): + rank = len(perm1) + assert len(perm2) == rank + + for i in range(rank): + if perm1[perm2[i]] != i: + return # bail out + node = match.output_node() + node.replace_all_uses_with(arg) + match.erase_nodes() + + +@register_graph_pattern( + CallFunction( + aten.bmm, + Arg(), + Arg(), + ), + pass_dict=patterns, +) +def bmm_to_mm(match: Match, mat1: torch.fx.Node, mat2: torch.fx.Node): + """Convert bmm to mm when batch size is 1""" + + def repl(a, b): + return torch.mm(a.squeeze(0), b.squeeze(0)).unsqueeze(0) + + if ( + check_device(mat1.meta["val"], mat2.meta["val"], get_gpu_type()) + and statically_known_true(mat1.meta["val"].shape[0] == 1) + and statically_known_true(mat2.meta["val"].shape[0] == 1) + ): + match.replace_by_example(repl, [mat1, mat2]) + + +# When softmax is used with temperature or other scaling, we get the pattern +# +# scale(x) - scale(x).amax(dim, keepdim=True) +# +# which is expected to be at most zero, but we may end up with numerical +# discrepancies # between the recomputed values of scale(x) inside and out +# of the reduction, # depending on compiler optimizations, e.g. use of fma +# instructions. +# +# Here we replace it with the mathematically equivalent, +# +# scale(x - x.amax(dim, keepdim=True)) +# +# which is more stable as we only compute the scaling once. +# +# NOTE: This pattern must come after fused attention matching! 
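+ # Illustrative example (not from the original source): for x = torch.randn(8) and a
+ # positive scalar s, (x * s) - (x * s).amax() equals ((x - x.amax()) * s) mathematically,
+ # but the left-hand side recomputes x * s inside and outside the reduction and can come
+ # out slightly positive under fma reordering; the rewrite applies the scale only once.
+ # Negative scales are handled by the sign logic in the repl closures below.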
+ + +def _partial_softmax_pattern(linear_func, reverse=False, to_dtype=False): + # Allow matching inp * other and other * input + if reverse: + scaled = CallFunction( + linear_func, KeywordArg("other"), KeywordArg("inp"), _users=MULTIPLE + ) + else: + scaled = CallFunction( + linear_func, KeywordArg("inp"), KeywordArg("other"), _users=MULTIPLE + ) + if to_dtype: + scaled = CallFunction( + prims.convert_element_type, scaled, KeywordArg("dtype"), _users=MULTIPLE + ) + amax = CallFunction( + aten.amax.default, scaled, KeywordArg("dim"), KeywordArg("keepdim") + ) + return CallFunction(aten.sub.Tensor, scaled, amax) + + +def _other_is_broadcasted_in_dim(match): + # Check that the scaling factor is constant across the reduction dim, + # so scaling doesn't change which index corresponds to the maximum value + other = match.kwargs["other"] + if isinstance(other, (int, float)): + return True + + inp = match.kwargs["inp"] + if not all(isinstance(x, torch.fx.Node) for x in (inp, other)): + return False + + inp_example = inp.meta["val"] + other_example = other.meta["val"] + if isinstance(other_example, (torch.SymInt, torch.SymFloat)): + return True + + if not all(isinstance(x, torch.Tensor) for x in (inp_example, other_example)): + return False + + inp_ndim = inp_example.ndim + other_shape = other_example.shape + if inp_ndim < len(other_shape): + return False + + # Pad other_shape to the same ndim as inp + other_shape = [1] * (inp_ndim - len(other_shape)) + list(other_shape) + + dim = match.kwargs["dim"] + if isinstance(dim, int): + dim = (dim,) + + return all(statically_known_true(other_shape[d] == 1) for d in dim) + + +def mul_softmax_pattern(match: Match, *, inp, other, dim, keepdim, dtype=None): + def repl(inp, other): + if dtype is not None: + inp = inp.to(dtype) + + sign: Union[int, float, torch.Tensor] + if isinstance(other, (int, float, torch.SymInt, torch.SymFloat)): + sign = 1 if other >= 0 else -1 + else: + one = torch.scalar_tensor(1, dtype=inp.dtype, device=inp.device) + sign = torch.where(other >= 0, one, -one) + + inp = inp * sign + max_ = torch.amax(inp, dim=dim, keepdim=keepdim) + return (inp - max_) * (sign * other) + + match.replace_by_example(repl, [inp, other]) + + +for reverse, to_dtype in itertools.product((False, True), repeat=2): + register_graph_pattern( + _partial_softmax_pattern(aten.mul.Tensor, reverse=reverse, to_dtype=to_dtype), + pass_dict=pass_patterns[1], + extra_check=_other_is_broadcasted_in_dim, + )(mul_softmax_pattern) + + +def div_softmax_pattern(match: Match, *, inp, other, dim, keepdim, dtype=None): + def repl(inp, other): + if dtype is not None: + inp = inp.to(dtype) + + sign: Union[int, float, torch.Tensor] + if isinstance(other, (int, float, torch.SymInt, torch.SymFloat)): + sign = 1 if other >= 0 else -1 + else: + one = torch.scalar_tensor(1, dtype=inp.dtype, device=inp.device) + sign = torch.where(other >= 0, one, -one) + + inp = inp * sign + max_ = torch.amax(inp, dim=dim, keepdim=keepdim) + return (inp - max_) / (sign * other) + + match.replace_by_example(repl, [inp, other]) + + +for to_dtype in (False, True): + register_graph_pattern( + _partial_softmax_pattern(aten.div.Tensor, to_dtype=to_dtype), + pass_dict=pass_patterns[1], + extra_check=_other_is_broadcasted_in_dim, + )(div_softmax_pattern) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/micro_pipeline_tp.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/micro_pipeline_tp.py new file mode 100644 index 
0000000000000000000000000000000000000000..c4d935a4f8bb47588f198b4ab563e6cd35cc77d0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/micro_pipeline_tp.py @@ -0,0 +1,1079 @@ +# mypy: allow-untyped-defs +import logging +import operator +from collections import defaultdict +from dataclasses import dataclass, field +from math import prod +from typing import Any, cast, Optional + +import torch +from torch.utils._ordered_set import OrderedSet + +from .. import config, inductor_prims +from ..pattern_matcher import ( + CallFunction, + Ignored, + KeywordArg, + ListOf, + Match, + MULTIPLE, + PatternExpr, + PatternMatcherPass, +) + + +log = logging.getLogger(__name__) +aten = torch.ops.aten +patterns = PatternMatcherPass() + + +def _is_backward(graph: torch.fx.Graph) -> bool: + placeholders = [] + for node in graph.nodes: + if node.op != "placeholder": + break + placeholders.append(node) + return not all(node.name.startswith("primal") for node in placeholders) + + +def _compute_mm_arithmetic_intensity(M: int, N: int, K: int) -> float: + return M * N * K / (M * K + N * K + M * N) + + +def _filter_nodes_by_target(nodes: list[torch.fx.Node], target) -> list[torch.fx.Node]: + return [x for x in nodes if x.target == target] + + +def _find_ancestors(node: torch.fx.Node) -> OrderedSet[torch.fx.Node]: + ancestors = OrderedSet[torch.fx.Node]() + ancestors.add(node) + cur_nodes = [node] + while len(cur_nodes) > 0: + new_nodes = [] + for node in cur_nodes: + for inp in node.all_input_nodes: + if inp not in ancestors: + ancestors.add(inp) + new_nodes.append(inp) + cur_nodes = new_nodes + return OrderedSet(node for node in ancestors if node.op != "placeholder") + + +def _get_tensor(node: torch.fx.Node) -> torch.Tensor: + val = node.meta["val"] + assert isinstance(val, torch.Tensor) + return val + + +@dataclass +class _AllGatherMatch: + match: Match + shard_node: torch.fx.Node + ag_node: torch.fx.Node + res_node: torch.fx.Node + gather_dim: int + group_name: str + + def replace_with(self, new_node: torch.fx.Node) -> None: + self.res_node.replace_all_uses_with(new_node) + + def erase(self) -> None: + for node in reversed(self.match.nodes): + if len(node.users) == 0: + node.graph.erase_node(node) + + +def find_all_gather_patterns(graph: torch.fx.Graph): + c10d = torch.ops._c10d_functional + + def make_zero_dim_all_gather_pattern(shard): + return CallFunction( + c10d.wait_tensor.default, + CallFunction( + c10d.all_gather_into_tensor.default, + shard, + Ignored(), + KeywordArg("group_name"), + ), + ) + + # Matches funcol.all_gather_tensor with gather_dim == 0 + zero_dim_all_gather_pattern = make_zero_dim_all_gather_pattern(KeywordArg("shard")) + + def make_all_gather_split_pattern(shard): + return CallFunction( + operator.getitem, + CallFunction( + aten.split.Tensor, + make_zero_dim_all_gather_pattern(shard), + Ignored(), + _users=MULTIPLE, + ), + Ignored(), + ) + + def make_cat_pattern(splits): + return CallFunction( + aten.cat.default, + ListOf(splits), + KeywordArg("gather_dim"), + ) + + # Matches funcol.all_gather_tensor with gather_dim > 0 + non_zero_dim_all_gather_pattern = make_cat_pattern( + make_all_gather_split_pattern(KeywordArg("shard")), + ) + + # Match a zero-dim all-gather in which the data is transferred as uint8 and + # viewed back as the original dtype. 
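+ # Roughly (illustrative sketch, not exact funcol output):
+ #   res = wait_tensor(all_gather_into_tensor(shard, group_size, group_name)).view(orig_dtype)
+ # where `shard` has already been reinterpreted as uint8 before the collective.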
+ zero_dim_type_erased_all_gather_pattern = CallFunction( + aten.view.dtype, + make_zero_dim_all_gather_pattern( + KeywordArg("shard"), + ), + Ignored(), + ) + + # Match a non-zero dim all-gather in which the data is transferred as uint8 + # and viewed back as the original dtype. + non_zero_dim_type_erased_all_gather_pattern = CallFunction( + aten.view.dtype, + make_cat_pattern( + CallFunction( + aten.view.dtype, + make_all_gather_split_pattern( + KeywordArg("shard"), + ), + Ignored(), + ), + ), + Ignored(), + ) + + # If two patterns with the same res_node_target have the same suffix, the + # longer pattern should appear first in the list. + # e.g. supposed we have (1) A -> B -> C -> D and (2) B -> C -> D, (1) + # should appear before (2) in the list. + res_node_target_to_patterns = { + aten.cat.default: [ + (non_zero_dim_all_gather_pattern, 0), + ], + aten.view.dtype: [ + (non_zero_dim_type_erased_all_gather_pattern, 0), + (zero_dim_type_erased_all_gather_pattern, 0), + ], + c10d.wait_tensor.default: [ + (zero_dim_all_gather_pattern, 0), + ], + } + + # Match in reverse to ensure longer patterns is prioritized + all_gathers = [] + visited_ag_nodes = OrderedSet[torch.fx.Node]() + for node in reversed(graph.nodes): + for target, patterns in res_node_target_to_patterns.items(): + if node.target != target: + continue + for pattern, ag_node_idx in patterns: + match = pattern.match(node) + if not match: + continue + + assert isinstance(match, Match) + ag_node = match.nodes[ag_node_idx] + assert ag_node.target == c10d.all_gather_into_tensor.default + + if ag_node in visited_ag_nodes: + continue + visited_ag_nodes.add(ag_node) + + ag_match = _AllGatherMatch( + match=match, + shard_node=match.kwargs["shard"], + ag_node=ag_node, + res_node=node, + gather_dim=match.kwargs.get("gather_dim", 0), + group_name=match.kwargs["group_name"], + ) + all_gathers.append(ag_match) + + return list(reversed(all_gathers)) + + +@dataclass +class _ReduceScatterMatch: + match: Match + input_node: torch.fx.Node + reduce_scatter_node: torch.fx.Node + wait_tensor_node: torch.fx.Node + reduce_op: str + scatter_dim: int + group_name: str + + def replace_with(self, new_node: torch.fx.Node) -> None: + # Replace all uses of the result node (wait_tensor) with the fused node. + self.wait_tensor_node.replace_all_uses_with(new_node) + + # If the reduce-scatter result is saved for backward, save the fused node for backward instead. + self._update_save_for_backward(new_node) + + def _update_save_for_backward(self, new_node: torch.fx.Node) -> None: + """ + If the output node is a user of the reduce_scatter node (indicating the reduce_scatter + result is saved for backward), this method will update the output node to use the fused node instead. + """ + output_node = None + for user in self.reduce_scatter_node.users: + if user.target == "output": + output_node = user + break + if output_node is not None: + output_node.replace_input_with(self.reduce_scatter_node, new_node) + + # Assert that now the reduce scatter node has only one user (the wait_tensor) and it's not + # saved for backward anymore. 
+ assert len(self.reduce_scatter_node.users) == 1, ( + "Reduce scatter node has multiple users, this is not expected" + ) + + def erase(self) -> None: + for node in reversed(self.match.nodes): + if len(node.users) == 0: + node.graph.erase_node(node) + + +def find_reduce_scatter_patterns(graph: torch.fx.Graph): + c10d = torch.ops._c10d_functional + + def reduce_scatter_template(inp: PatternExpr, users: int): + return CallFunction( + c10d.wait_tensor.default, + CallFunction( + c10d.reduce_scatter_tensor.default, + inp, + KeywordArg("reduce_op"), + Ignored(), + KeywordArg("group_name"), + _users=users, + ), + ) + + # Matches funcol.reduce_scatter_tensor with scatter_dim == 0 + zero_dim_reduce_scatter_pattern_single_user = reduce_scatter_template( + KeywordArg("input"), users=1 + ) + + # Two users will occur when the reduce-scatter result is saved for backward + zero_dim_reduce_scatter_pattern_multi_user = reduce_scatter_template( + KeywordArg("input"), users=2 + ) + + # Matches funcol.reduce_scatter_tensor with scatter_dim > 0 + non_zero_dim_reduce_scatter_pattern_single_user = reduce_scatter_template( + CallFunction( + aten.cat.default, + ListOf( + CallFunction( + operator.getitem, + CallFunction( + aten.split.Tensor, + KeywordArg("input"), + Ignored(), + KeywordArg("scatter_dim"), + _users=MULTIPLE, + ), + Ignored(), + ) + ), + ), + users=1, + ) + + # Two users will occur when the reduce-scatter result is saved for backward + non_zero_dim_reduce_scatter_pattern_multi_user = reduce_scatter_template( + CallFunction( + aten.cat.default, + ListOf( + CallFunction( + operator.getitem, + CallFunction( + aten.split.Tensor, + KeywordArg("input"), + Ignored(), + KeywordArg("scatter_dim"), + _users=MULTIPLE, + ), + Ignored(), + ) + ), + ), + users=2, + ) + + reduce_scatters = [] + for node in reversed(graph.nodes): + if node.target == c10d.wait_tensor.default: + if match := non_zero_dim_reduce_scatter_pattern_single_user.match(node): + assert isinstance(match, Match) + reduce_scatters.append( + _ReduceScatterMatch( + match=match, + input_node=match.kwargs["input"], + reduce_scatter_node=match.nodes[-2], + wait_tensor_node=node, + reduce_op=match.kwargs["reduce_op"], + scatter_dim=match.kwargs["scatter_dim"], + group_name=match.kwargs["group_name"], + ) + ) + elif match := zero_dim_reduce_scatter_pattern_single_user.match(node): + assert isinstance(match, Match) + reduce_scatters.append( + _ReduceScatterMatch( + match=match, + input_node=match.kwargs["input"], + reduce_scatter_node=match.nodes[0], + wait_tensor_node=node, + reduce_op=match.kwargs["reduce_op"], + scatter_dim=0, + group_name=match.kwargs["group_name"], + ) + ) + elif match := non_zero_dim_reduce_scatter_pattern_multi_user.match(node): + assert isinstance(match, Match) + reduce_scatters.append( + _ReduceScatterMatch( + match=match, + input_node=match.kwargs["input"], + reduce_scatter_node=match.nodes[-2], + wait_tensor_node=node, + reduce_op=match.kwargs["reduce_op"], + scatter_dim=match.kwargs["scatter_dim"], + group_name=match.kwargs["group_name"], + ) + ) + elif match := zero_dim_reduce_scatter_pattern_multi_user.match(node): + assert isinstance(match, Match) + reduce_scatters.append( + _ReduceScatterMatch( + match=match, + input_node=match.kwargs["input"], + reduce_scatter_node=match.nodes[0], + wait_tensor_node=node, + reduce_op=match.kwargs["reduce_op"], + scatter_dim=0, + group_name=match.kwargs["group_name"], + ) + ) + return list(reversed(reduce_scatters)) + + +@dataclass +class _Matmul: + nodes: list[torch.fx.Node] + 
arg_ancestor_nodes: OrderedSet[torch.fx.Node] = field(init=False) + A_node: torch.fx.Node + B_node: torch.fx.Node + pre_mm_reshape: Optional[torch.fx.Node] + post_mm_reshape: Optional[torch.fx.Node] + + def __post_init__(self): + assert len(self.nodes) in (1, 3) + if len(self.nodes) == 1: + assert self.nodes[0].target in (aten.mm.default, aten._scaled_mm.default) + else: + assert self.nodes[0].target == aten.reshape.default + assert self.nodes[1].target in (aten.mm.default, aten._scaled_mm.default) + assert self.nodes[2].target == aten.reshape.default + self.arg_ancestor_nodes = _find_ancestors(self.B_node) + + def replace_with(self, new_node: torch.fx.Node) -> None: + """ + Replace the matmul with the new node. + """ + graph = new_node.graph + + # For 2D-matmuls, we simply replace the mm node with `new_node`. + if len(self.nodes) == 1: + mm_node = self.nodes[0] + assert mm_node.target in (aten.mm.default, aten._scaled_mm.default) + mm_node.replace_all_uses_with(new_node) + graph.erase_node(mm_node) + return + + # An ND-matmul is reshape -> mm -> reshape sequence. We first replace + # the second reshape node with `new_node`. Then, we ensure that the + # original mm node in the sequence ends up with zero users by replacing + # it with a reverse reshape of `new_node`. + graph = new_node.graph + assert len(self.nodes) == 3 + mm_node = self.nodes[1] + output_reshape_node = self.nodes[2] + + assert mm_node.target in (aten.mm.default, aten._scaled_mm.default) + assert output_reshape_node.target == aten.reshape.default + + output_reshape_node.replace_all_uses_with(new_node) + if len(mm_node.users) > 1: + with graph.inserting_after(new_node): + new_mm_node = graph.call_function( + aten.reshape.default, + args=(new_node, list(_get_tensor(mm_node).shape)), + ) + mm_node.replace_all_uses_with(new_mm_node) + + def erase(self) -> None: + for node in reversed(self.nodes): + if len(node.users) == 0: + node.graph.erase_node(node) + + @classmethod + def from_match(cls, match: list[torch.fx.Node]) -> "_Matmul": + assert len(match) in (1, 3) + assert match[0].target in ( + aten.mm.default, + aten.reshape.default, + ) + mm_node = match[0] if len(match) == 1 else match[1] + return _Matmul( + nodes=match, + A_node=cast("torch.fx.Node", match[0].args[0]), + B_node=cast("torch.fx.Node", mm_node.args[1]), + # _Matmul handles reshapes via custom graph manipulation logic, see `replace_with()` method. + # TODO: explore unifying the _Matmul and _ScaledMatmul approaches to handling reshapes. + pre_mm_reshape=None, + post_mm_reshape=None, + ) + + +@dataclass +class _ScaledMatmul(_Matmul): + A_scale_node: torch.fx.Node + B_scale_node: torch.fx.Node + bias_node: Optional[torch.fx.Node] + result_scale_node: Optional[torch.fx.Node] + out_dtype: Optional[torch.dtype] + use_fast_accum: bool + pre_mm_reshape: Optional[torch.fx.Node] + post_mm_reshape: Optional[torch.fx.Node] + + def __post_init__(self): + super().__post_init__() + self.arg_ancestor_nodes |= _find_ancestors(self.A_scale_node) + self.arg_ancestor_nodes |= _find_ancestors(self.B_scale_node) + + @classmethod + def from_match(cls, match: list[torch.fx.Node]) -> "_ScaledMatmul": + assert len(match) in (1, 3) + assert match[0].target in ( + aten._scaled_mm.default, + aten.reshape.default, + ) + + def get_arg(node: torch.fx.Node, idx: int, default: Any) -> Any: + if idx >= len(node.args): + return default + return node.args[idx] + + # Use mm_node with 2D args for both A and B, even if this is a "reshape -> mm -> reshape" pattern. 
+ # We will store the reshapes in pre_mm_reshape and post_mm_reshape, to be referenced later to + # produce the correct output shapes, reduce-scatter along the correct dimensions, etc. + is_reshape_mm_reshape_pattern = match[0].target == aten.reshape.default + mm_node = match[1] if is_reshape_mm_reshape_pattern else match[0] + pre_mm_reshape = match[0] if is_reshape_mm_reshape_pattern else None + post_mm_reshape = match[-1] if is_reshape_mm_reshape_pattern else None + A_node = cast("torch.fx.Node", mm_node.args[0]) + B_node = cast("torch.fx.Node", mm_node.args[1]) + A_scale_node = cast("torch.fx.Node", mm_node.args[2]) + B_scale_node = cast("torch.fx.Node", mm_node.args[3]) + + return _ScaledMatmul( + nodes=match, + A_node=A_node, + B_node=B_node, + A_scale_node=A_scale_node, + B_scale_node=B_scale_node, + bias_node=get_arg(mm_node, 4, None), + result_scale_node=get_arg(mm_node, 5, None), + out_dtype=get_arg(mm_node, 6, None), + use_fast_accum=get_arg(mm_node, 7, False), + pre_mm_reshape=pre_mm_reshape, + post_mm_reshape=post_mm_reshape, + ) + + +def _find_reshape_mm_reshape(node: torch.fx.Node) -> list[_Matmul]: + if node.target != aten.reshape.default: + return [] + + matches = [] + for mm_node in node.users: + if mm_node.target not in (aten.mm.default, aten._scaled_mm.default): + continue + for reshape_node in mm_node.users: + if reshape_node.target != aten.reshape.default: + continue + + # Since the reshape -> mm -> reshape pattern would be subsumed into + # the fused op, we only match the patterns where the shape of the + # second reshape is matches the mm result produced by the fused op. + matmul_input_node = cast("torch.fx.Node", node.args[0]) + B_node = cast("torch.fx.Node", mm_node.args[1]) + matmul_out_shape = torch.Size( + [ + *_get_tensor(matmul_input_node).shape[:-1], + _get_tensor(B_node).shape[-1], + ] + ) + if _get_tensor(reshape_node).shape != matmul_out_shape: + continue + matches.append([node, mm_node, reshape_node]) + # If for some rare reason mm_node is being reshaped by two + # different reshape nodes, we only include mm_node once in the + # parsing result. + break + + matmuls = [] + for match in matches: + mm_node = match[1] + if mm_node.target == aten.mm.default: + matmul = _Matmul.from_match(match) + matmuls.append(matmul) + elif mm_node.target == aten._scaled_mm.default: + matmul = _ScaledMatmul.from_match(match) + matmuls.append(matmul) + else: + raise AssertionError( + "Expect the node's target to be either aten.mm.default or " + f"aten._scaled_mm.default. Got {mm_node.target}." + ) + return matmuls + + +def _find_consumer_matmuls(node: torch.fx.Node) -> list[_Matmul]: + """ + Find the matmuls that use `node` as the lhs argument. 
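+
+    Both plain 2D matmuls (aten.mm / aten._scaled_mm) and ND matmuls expressed as a
+    reshape -> mm -> reshape sequence are returned.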
+ """ + matmuls = [] + for user in node.users: + # ND matmuls + if user.target == aten.reshape.default: + matmuls.extend(_find_reshape_mm_reshape(user)) + # 2D matmuls + elif user.target == aten.mm.default: + matmul = _Matmul.from_match(match=[user]) + matmuls.append(matmul) + elif user.target == aten._scaled_mm.default: + matmul = _ScaledMatmul.from_match([user]) + matmuls.append(matmul) + return matmuls + + +def _insert_fused_all_gather_matmul( + graph: torch.fx.Graph, + matmuls: list[_Matmul], + shard_node: torch.fx.Node, + gather_dim: int, + group_name: str, +) -> torch.fx.Node: + mm_types = OrderedSet(map(type, matmuls)) + assert len(mm_types) == 1 + mm_type = next(iter(mm_types)) + if mm_type == _Matmul: + B_nodes = [matmul.B_node for matmul in matmuls] + return graph.call_function( + torch.ops.symm_mem.fused_all_gather_matmul.default, + args=(shard_node, B_nodes, gather_dim, group_name), + kwargs={"return_A": True}, + ) + elif mm_type == _ScaledMatmul: + scaled_matmuls = cast("list[_ScaledMatmul]", matmuls) + return graph.call_function( + torch.ops.symm_mem.fused_all_gather_scaled_matmul.default, + args=( + shard_node, + [matmul.B_node for matmul in scaled_matmuls], + scaled_matmuls[0].A_scale_node, + [matmul.B_scale_node for matmul in scaled_matmuls], + gather_dim, + group_name, + [matmul.bias_node for matmul in scaled_matmuls], + [matmul.result_scale_node for matmul in scaled_matmuls], + [matmul.out_dtype for matmul in scaled_matmuls], + [matmul.use_fast_accum for matmul in scaled_matmuls], + ), + ) + else: + raise AssertionError(f"Unexpected matmul match type: {mm_type}") + + +def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None: + """ + Fused the pattern + + A = all_gather_tensor(A_shard, gather_dim, group_name) + C_0 = torch.matmul(A, B_0) + C_1 = torch.matmul(A, B_1) + C_2 = torch.matmul(A, B_2) + ... 
+ + into + + A, Cs = torch.ops.symm_mem.fused_all_gather_matmul( + A_shard, [B_0, B_1, B_2, ...], gather_dim, group_name, + ) + """ + if ( + not torch.distributed.is_available() + or not torch.distributed.is_nccl_available() + ): + return + + from torch.distributed._symmetric_memory import ( + is_symm_mem_enabled_for_group, + restride_A_shard_for_fused_all_gather_matmul, + ) + + shard_node, ag_node, ag_res_node, gather_dim, group_name = ( + all_gather.shard_node, + all_gather.ag_node, + all_gather.res_node, + all_gather.gather_dim, + all_gather.group_name, + ) + + if not is_symm_mem_enabled_for_group(group_name): + return + + if gather_dim >= len(_get_tensor(shard_node).shape) - 1: + # Decomposing the matmul on the K dimension is not supported + return + + # Find consumer matmuls + matmuls = _find_consumer_matmuls(ag_res_node) + + # The matmuls are only fusible if non-A args don't depend on the all-gather + # result node + matmuls = [ + matmul + for matmul in matmuls + if all_gather.res_node not in matmul.arg_ancestor_nodes + ] + + if len(matmuls) == 0 or len(OrderedSet(map(type, matmuls))) != 1: + return + + # Fuse the all_gather_tensor with the eligible matmuls + graph = ag_node.graph + with graph.inserting_before(ag_node): + if "val" in shard_node.meta: + restrided = restride_A_shard_for_fused_all_gather_matmul( + _get_tensor(shard_node), + gather_dim, + ) + shard_node = graph.call_function( + inductor_prims.force_stride_order, + args=(shard_node, restrided.stride()), + ) + + fused_node = _insert_fused_all_gather_matmul( + graph, matmuls, shard_node, gather_dim, group_name + ) + new_ag_node = graph.call_function( + operator.getitem, + args=(fused_node, 0), + ) + new_out_nodes = graph.call_function( + operator.getitem, + args=(fused_node, 1), + ) + for idx, matmul in enumerate(matmuls): + new_out_node = graph.call_function( + operator.getitem, + args=(new_out_nodes, idx), + ) + matmul.replace_with(new_out_node) + matmul.erase() + all_gather.replace_with(new_ag_node) + all_gather.erase() + + # If the new_ag_node has no users, we tell the fused op to not return + # it. This creates more optimization opportunities. + if len(new_ag_node.users) == 0: + graph.erase_node(new_ag_node) + kwargs = dict(fused_node.kwargs) + if "return_A" in kwargs: + kwargs["return_A"] = False + fused_node.kwargs = kwargs + + # Raise ancestors of non-A args that are topologically ordered between + # ag_res_node and the matmul above fused_node. + order = {node: idx for idx, node in enumerate(graph.nodes)} + nodes_to_raise = sorted( + OrderedSet(x for matmul in matmuls for x in matmul.arg_ancestor_nodes), + key=lambda x: order[x], + ) + for node in nodes_to_raise: + if order[node] > order[fused_node]: + fused_node.prepend(node) + + +def _scatter_dim_after_reshape( + reshape_node: torch.fx.Node, orig_scatter_dim: int +) -> int: + """ + Given a reshape node and the original scatter dim for the target tensor, + returns the new scatter dim for the reshaped tensor. + """ + # if there was no pre-mm reshape, scatter dim will not change. 
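+    # Illustrative example (dim names are not from the original source): reshaping
+    # (a, b, c) -> (a * b, c) collapses the leading dims, so an original scatter dim of 1
+    # maps to 0; reshaping (a, b, c) -> (a, b * c) collapses the trailing dims, so an
+    # original scatter dim of 1 maps to 1.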
+ if not reshape_node: + return orig_scatter_dim + + reshape_op_output_tensor = _get_tensor(reshape_node) + assert reshape_op_output_tensor.ndim == 2, ( + "reshape must produce 2D tensor for scaled_mm" + ) + + assert len(reshape_node.args) >= 1, "reshape node must have at least 1 arg" + input_tensor_node = cast(torch.fx.Node, reshape_node.args[0]) + reshape_op_input_tensor = _get_tensor(input_tensor_node) + assert reshape_op_input_tensor.ndim > reshape_op_output_tensor.ndim, ( + "reshape must be from 3D+ to 2D" + ) + + # Note: for a N-D tensor to be reshaped into 2D, either the leading dims or ending dims must + # be collapsed to a single dim. First determine which of these happened. + input_shape = reshape_op_input_tensor.shape + output_shape = reshape_op_output_tensor.shape + leading_dims_collapsed = output_shape[0] == prod(input_shape[:-1]) + + # Case 1: scatter dim 0 always maps to 0 after any reshape from 3D+ to 2D, regardless if + # leading dims or ending dims were collapsed. + if orig_scatter_dim == 0: + return 0 + + # Case 2: scatter dim "ndim-1" always maps to 1 after any reshape from 3D+ to 2D, regardless if + # leading dims or ending dims were collapsed. + if orig_scatter_dim == reshape_op_input_tensor.ndim - 1: + return 1 + + # Case 3: scatter dim was one of the middle dims (between 0 and ndim-1). + # if the leading dims were collapsed, the new scatter dim will be 0. + # if the ending dims were collapsed, the new scatter dim will be 1. + return 0 if leading_dims_collapsed else 1 + + +def _find_producer_matmul(node: torch.fx.Node) -> Optional[_Matmul]: + """ + Returns producer matmul node if found, otherwise returns None. + """ + if node.target == aten.mm.default: + return _Matmul.from_match(match=[node]) + elif node.target == aten._scaled_mm.default: + return _ScaledMatmul.from_match(match=[node]) + elif node.target == aten.reshape.default: + reshape_node_1 = node + + mm_node = reshape_node_1.args[0] + assert isinstance(mm_node, torch.fx.Node) + if mm_node.target not in (aten.mm.default, aten._scaled_mm.default): + return None + + reshape_node_0 = mm_node.args[0] + assert isinstance(reshape_node_0, torch.fx.Node) + if reshape_node_0.target != aten.reshape.default: + return None + + if mm_node.target == aten.mm.default: + return _Matmul.from_match(match=[reshape_node_0, mm_node, reshape_node_1]) + elif mm_node.target == aten._scaled_mm.default: + return _ScaledMatmul.from_match( + match=[reshape_node_0, mm_node, reshape_node_1] + ) + return None + + +def _insert_fused_matmul_reduce_scatter( + graph: torch.fx.Graph, + matmul: _Matmul, + reduce_op: str, + orig_scatter_dim: int, + group_name: str, + scatter_dim_after_reshape: int, # only used for reshape -> scaled_mm -> reshape pattern + output_shape: list[int], # only used for reshape -> scaled_mm -> reshape pattern +) -> torch.fx.Node: + if type(matmul) == _Matmul: + return graph.call_function( + torch.ops.symm_mem.fused_matmul_reduce_scatter.default, + args=( + matmul.A_node, + matmul.B_node, + reduce_op, + orig_scatter_dim, + group_name, + ), + ) + elif type(matmul) == _ScaledMatmul: + return graph.call_function( + torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default, + args=( + matmul.A_node, + matmul.B_node, + matmul.A_scale_node, + matmul.B_scale_node, + reduce_op, + orig_scatter_dim, + scatter_dim_after_reshape, + group_name, + output_shape, + matmul.bias_node, + matmul.result_scale_node, + matmul.out_dtype, + matmul.use_fast_accum, + ), + ) + else: + raise AssertionError(f"Unexpected matmul match type: 
{type(matmul)}")
+
+
+def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:
+    """
+    Fuses the pattern
+
+        reduce_scatter_tensor(A @ B, scatter_dim, group_name)
+
+    into
+
+        torch.ops.symm_mem.fused_matmul_reduce_scatter(
+            A, B, scatter_dim, group_name,
+        )
+
+    Returns without modifying the graph if the fusion cannot be applied.
+    """
+    if (
+        not torch.distributed.is_available()
+        or not torch.distributed.is_nccl_available()
+    ):
+        return
+
+    from torch.distributed._symmetric_memory import (
+        is_symm_mem_enabled_for_group,
+        restride_A_for_fused_matmul_reduce_scatter,
+    )
+
+    (
+        input_node,
+        _reduce_scatter_node,
+        rs_wait_tensor_node,
+        reduce_op,
+        orig_scatter_dim,
+        group_name,
+    ) = (
+        reduce_scatter.input_node,
+        reduce_scatter.reduce_scatter_node,
+        reduce_scatter.wait_tensor_node,
+        reduce_scatter.reduce_op,
+        reduce_scatter.scatter_dim,
+        reduce_scatter.group_name,
+    )
+
+    if not is_symm_mem_enabled_for_group(group_name):
+        return
+
+    # Currently fused_matmul_reduce_scatter doesn't return the matmul result,
+    # so we can't apply the fusion if the matmul result is used by multiple
+    # users. This is not a fundamental limitation of the fused op and can be
+    # addressed if needed.
+    if len(input_node.users) != 1:
+        log.warning(
+            "matmul result has more than one user, skipping fused_matmul_reduce_scatter fusion."
+        )
+        return
+
+    matmul = _find_producer_matmul(input_node)
+    if matmul is None:
+        log.warning(
+            "no producer matmul found for reduce scatter, skipping fuse_matmul_reduce_scatter fusion"
+        )
+        return
+
+    if rs_wait_tensor_node in matmul.arg_ancestor_nodes:
+        log.warning(
+            "reduce-scatter result node is an ancestor of matmul, skipping fuse_matmul_reduce_scatter fusion"
+        )
+        return
+
+    # We need to track 3 values for the fused scaled mm reduce scatter implementation:
+    # 1. The scatter dim before the reshape, which was assigned using the original (a,b,c) @ (c,d) = (a,b,d) dims.
+    # 2. The scatter dim after the reshape, to use when we are doing the 2D (a*b,c) @ (c,d) = (a*b,d) scaled mm op.
+    # 3. The expected, potentially 3D+, mm output shape, so we can reshape the 2D mm output to the intended
+    #    3D+ shape before applying reduce-scatter, and to prevent shape errors with subsequent ops.
+
+    # If 'A' was reshaped from 3D+ -> 2D for the mm, we need to determine the new scatter dim after the reshape
+    # for the fused matmul reduce scatter implementation to use.
+    if matmul.pre_mm_reshape:
+        scatter_dim_after_maybe_reshape = _scatter_dim_after_reshape(
+            matmul.pre_mm_reshape, orig_scatter_dim
+        )
+    else:
+        scatter_dim_after_maybe_reshape = orig_scatter_dim
+
+    # If the 2D mm output was reshaped from 2D -> 3D+, we need to store the intended output shape for the
+    # fused matmul reduce scatter implementation to use.
+ if matmul.post_mm_reshape: + output_shape = list(_get_tensor(matmul.post_mm_reshape).shape) + else: + A_orig_shape = list(_get_tensor(matmul.A_node).shape) + B_shape = list(_get_tensor(matmul.B_node).shape) + output_shape = [*A_orig_shape[:-1], B_shape[-1]] + + graph = rs_wait_tensor_node.graph + with graph.inserting_before(rs_wait_tensor_node): + # Restride A tensor before fused op, for optimal perf in fused matmul reduce scatter + if "val" in matmul.A_node.meta: + restrided = restride_A_for_fused_matmul_reduce_scatter( + _get_tensor(matmul.A_node), + scatter_dim_after_maybe_reshape, + ) + matmul.A_node = graph.call_function( + inductor_prims.force_stride_order, + args=(matmul.A_node, restrided.stride()), + ) + + # Replace matched subgraph with fused matmul reduce scatter node + fused_node = _insert_fused_matmul_reduce_scatter( + graph, + matmul, + reduce_op, + orig_scatter_dim, + group_name, + scatter_dim_after_maybe_reshape, + output_shape, + ) + reduce_scatter.replace_with(fused_node) + reduce_scatter.erase() + matmul.erase() + + order = {node: idx for idx, node in enumerate(graph.nodes)} + nodes_to_raise = sorted( + matmul.arg_ancestor_nodes, + key=lambda x: order[x], + ) + for node in nodes_to_raise: + if order[node] > order[fused_node]: + fused_node.prepend(node) + + log.debug("successfully fused matmul reduce scatter") + + +def _get_node_to_ancestors( + graph: torch.fx.Graph, +) -> dict[torch.fx.Node, OrderedSet[torch.fx.Node]]: + """ + Compute the ancestors for all nodes in a graph. + """ + node_to_ancestors = defaultdict(OrderedSet[torch.fx.Node]) # type: ignore[var-annotated] + for node in graph.nodes: + node_to_ancestors[node] = OrderedSet(node.all_input_nodes) + for dep in node.all_input_nodes: + node_to_ancestors[node] |= node_to_ancestors[dep] + + return node_to_ancestors + + +def _get_collective_to_overlappable_nodes( + graph: torch.fx.Graph, +) -> dict[torch.fx.Node, list[torch.fx.Node]]: + """ + For each collective in the graph, find nodes that are neither ancestors nor + descendants of the collective. + """ + + def is_collective(node) -> bool: + # Only consider all-gather and reduce-scatter in the context of + # micro-pipeline TP. + return node.target in [ + torch.ops._c10d_functional.all_gather_into_tensor.default, + torch.ops._c10d_functional.reduce_scatter_tensor.default, + ] + + node_to_ancestors = _get_node_to_ancestors(graph) + collective_to_overlappable_nodes = defaultdict(list) + for node in graph.nodes: + if not is_collective(node): + continue + for x in graph.nodes: + if ( + node not in node_to_ancestors[x] + and x not in node_to_ancestors[node] + and x.op == "call_function" + ): + collective_to_overlappable_nodes[node].append(x) + + return collective_to_overlappable_nodes + + +def _get_unexposed_collectives(graph: torch.fx.Graph) -> list[torch.fx.Node]: + """ + Find all unexposed collectives in the graph. + + Because we don't have the runtime estimate, this function is a rough + estimation using the following strong/hand-wavy assumptions: + + - Only a predefined set of "compute intensive" operation can hide a collective. + - Any "compute intensive" operation can hide exactly one collective. 
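+
+    (In the greedy matching below, each compute-intensive node is assigned to at most
+    one collective, so a collective counts as unexposed only while an unclaimed
+    candidate remains.)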
+ """ + + def _is_compute_intensive(node: torch.fx.Node) -> bool: + return node.target in [torch.ops.aten.mm.default] + + collective_to_overlapping_candidates = defaultdict(list) + available_nodes = OrderedSet[torch.fx.Node]() + collective_to_overlappable_nodes = _get_collective_to_overlappable_nodes(graph) + for collective, overlappable_nodes in collective_to_overlappable_nodes.items(): + candidates = [x for x in overlappable_nodes if _is_compute_intensive(x)] + collective_to_overlapping_candidates[collective] = candidates + available_nodes.update(candidates) + + unexposed_collectives = [] + for ( + collective, + overlapping_candidates, + ) in collective_to_overlapping_candidates.items(): + # Each collective consumes exactly one overlapping candidate + for x in overlapping_candidates: + if x in available_nodes: + unexposed_collectives.append(collective) + available_nodes.remove(x) + break + return unexposed_collectives + + +def micro_pipeline_tp_pass(graph: torch.fx.Graph): + all_gathers = find_all_gather_patterns(graph) + reduce_scatters = find_reduce_scatter_patterns(graph) + + # When a collective can be hidden through either simple overlapping or + # micro-pipeline TP, we prefer simple overlapping to avoid the overhead + # associated with decomposition. If reorder_for_compute_comm_overlap is + # enabled, we identify collectives that can be hidden through simple + # overlapping and exclude them from micro-pipeline TP candidates. + if config.reorder_for_compute_comm_overlap: + unexposed_collectives = _get_unexposed_collectives(graph) + all_gathers = [x for x in all_gathers if x.ag_node not in unexposed_collectives] + reduce_scatters = [ + x + for x in reduce_scatters + if x.reduce_scatter_node not in unexposed_collectives + ] + + if not all_gathers and not reduce_scatters: + log.warning( + "async TP found no matching all-gather/reduce-scatter patterns for fusion" + ) + + for all_gather in all_gathers: + fuse_all_gather_matmul(all_gather) + + for reduce_scatter in reduce_scatters: + fuse_matmul_reduce_scatter(reduce_scatter) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/misc_patterns.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/misc_patterns.py new file mode 100644 index 0000000000000000000000000000000000000000..d2c8068f130c8b5e5b3149eb9c318c248db2d2a4 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/misc_patterns.py @@ -0,0 +1,131 @@ +# mypy: allow-untyped-defs +import functools + +import torch +from torch._dynamo.utils import counters +from torch._ops import OpOverload, OpOverloadPacket +from torch.utils._ordered_set import OrderedSet + +from ..pattern_matcher import fwd_only, register_replacement + + +aten = torch.ops.aten + + +@functools.cache +def _misc_patterns_init(): + from .joint_graph import patterns as joint_graph_patterns + from .post_grad import pass_patterns as post_grad_patterns_all + + post_grad_patterns = post_grad_patterns_all[1] # medium priority + + if torch.cuda.is_available(): + # workaround https://github.com/pytorch/pytorch/issues/97894 + device = "cuda" + else: + device = "cpu" + + # These patterns do 2 things + # 1. Since we know that index is completely unique, we can codegen it using + # stores instead of atomic adds, which is quite a bit faster. + # 2. 
Also, since we are guaranteed that they are completely within bounds, + # we can use unsafe indexing and skip debug asserts + def randperm_index_add_pattern(x, y): + index = torch.randperm(x.shape[0], device=x.device)[: y.shape[0]] + return torch.index_add(x, dim=0, source=y, index=index), index + + def randperm_index_add_replacement(x, y): + index = torch.randperm(x.shape[0], device=x.device)[: y.shape[0]] + return ( + torch.ops.aten._unsafe_index_put( + x, (index,), aten._unsafe_index(x, (index,)) + y, accumulate=False + ), + index, + ) + + register_replacement( + randperm_index_add_pattern, + randperm_index_add_replacement, + [torch.empty(4, 8, device=device), torch.empty(2, 8, device=device)], + fwd_only, + [post_grad_patterns, joint_graph_patterns], + ) + + def randperm_index_pattern(x, slice_shape): + index = torch.randperm(x.shape[0], device=x.device)[:slice_shape] + return torch.ops.aten.index(x, (index,)), index + + def randperm_index_replacement(x, slice_shape): + index = torch.randperm(x.shape[0], device=x.device)[:slice_shape] + return torch.ops.aten._unsafe_index(x, (index,)), index + + register_replacement( + randperm_index_pattern, + randperm_index_replacement, + [torch.empty(4, 8, device=device)], + fwd_only, + [post_grad_patterns, joint_graph_patterns], + scalar_workaround={"slice_shape": 42}, + ) + + +class NumpyCompatNormalization: + numpy_compat: dict[str, tuple[str, ...]] = { + "dim": ("axis",), + "keepdim": ("keepdims",), + "input": ("x", "a", "x1"), + "other": ("x2",), + } + inverse_mapping: dict[str, str] + cache: dict["torch.fx.graph.Target", OrderedSet[str]] + + def __init__(self) -> None: + self.cache = {} # callable -> tuple of replaceable args e.g. ["axis"] + self.inverse_mapping = {} + for actual_kwarg, numpy_kwargs in self.numpy_compat.items(): + for numpy_kwarg in numpy_kwargs: + assert numpy_kwarg not in self.inverse_mapping + self.inverse_mapping[numpy_kwarg] = actual_kwarg + + def __call__(self, graph: torch.fx.Graph): + for node in graph.nodes: + if node.op != "call_function": + continue + if isinstance(node.target, (OpOverload, OpOverloadPacket)): + # only applies to torch ops; e.g. torch.stack(axis=1) works, torch.ops.aten.stack(axis=1) doesn't. 
+ continue + kwargs = node.kwargs + + if node.target in self.cache: + replaceable_kwargs = self.cache[node.target] + else: + signatures = torch.fx.operator_schemas.get_signature_for_torch_op( + node.target + ) + signatures = () if signatures is None else signatures + replaceable_kwargs = OrderedSet() + for sig in signatures: + for param_name in sig.parameters.keys(): + if param_name in self.numpy_compat: + replaceable_kwargs.update(self.numpy_compat[param_name]) + + self.cache[node.target] = replaceable_kwargs + + if not replaceable_kwargs: + continue + + new_kwargs = {} + kwargs_changed = False + for k, v in kwargs.items(): + if k in replaceable_kwargs: + kwargs_changed = True + new_kwargs[self.inverse_mapping[k]] = v + else: + new_kwargs[k] = v + + if kwargs_changed: + node.kwargs = torch.fx.immutable_collections.immutable_dict(new_kwargs) + counters["inductor"]["numpy_compat_normalization"] += 1 + + +numpy_compat_normalization = NumpyCompatNormalization() diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/mkldnn_fusion.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/mkldnn_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..42e57bf9cebd6dc2eb7d89f9259072fbdf68dc5b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/mkldnn_fusion.py @@ -0,0 +1,1527 @@ +# mypy: allow-untyped-defs +import functools +import operator +from functools import reduce +from typing import Any, Callable + +import torch +from torch._dynamo.utils import counters +from torch.fx.experimental.symbolic_shapes import has_free_symbols +from torch.utils._ordered_set import OrderedSet + +from .. import ir +from ..lowering import lowerings as L +from ..pattern_matcher import ( + Arg, + CallFunction, + filter_nodes, + get_arg_value, + KeywordArg, + MULTIPLE, +) +from ..utils import ( + is_mkldnn_bf16_supported, + is_mkldnn_fp16_supported, + SUPPORTED_MKLDNN_DEVICES, +) +from ..virtualized import ops, V +from .freezing_patterns import register_freezing_graph_pattern +from .post_grad import register_lowering_pattern +from .quantization import ( + _register_int8_woq_concat_linear_pattern, + _register_quantization_lowerings, + _register_quantization_weight_pack_pass, + _register_woq_lowerings, +) + + +if torch._C._has_mkldnn: + aten = torch.ops.aten + mkldnn = torch.ops.mkldnn + prims = torch.ops.prims + + _conv_args = [Arg() for _ in range(10)] + _linear_args = [Arg() for _ in range(6)] + _conv_transpose_args = [Arg() for _ in range(11)] + + class MkldnnDeviceOpBase: + def get_linear_transpose_weight(self, weight_node): + raise NotImplementedError + + def pack_conv_weight( + self, + graph, + is_transposed, + weight, + constant_args, + input_size, + ): + raise NotImplementedError + + def pack_linear_weight( + self, graph, is_lp_weight, transpose_weight_node, batch_size + ): + raise NotImplementedError + + def pack_linear( + self, graph, is_lp_weight, batch_size, input, packed_weight_node, bias + ): + raise NotImplementedError + + class CpuMkldnnDeviceOp(MkldnnDeviceOpBase): + def get_linear_transpose_weight(self, weight_node): + packed_weight_node = weight_node + assert packed_weight_node.target == mkldnn._reorder_linear_weight + transpose_weight_node = packed_weight_node.args[0] + assert transpose_weight_node.target == aten.permute.default + return transpose_weight_node + + def pack_conv_weight( + self, + graph, + is_transposed, + weight, + constant_args, + input_size, + ): + packed_weight_op = mkldnn._reorder_convolution_weight + if 
is_transposed: + packed_weight_op = mkldnn._reorder_convolution_transpose_weight + + # mkldnn_reorder_conv_weight(self, padding, stride, dilation, groups, input_size) + packed_weight_inputs = (weight,) + tuple(constant_args) + (input_size,) + return graph.create_node( + "call_function", packed_weight_op, args=packed_weight_inputs + ) + + def pack_linear_weight( + self, graph, is_lp_weight, transpose_weight_node, batch_size + ): + # For bfloat16 dynamic shape path, using input size hint to pack weight for a better performance. + packed_weight_inputs = ( + transpose_weight_node, + batch_size.node.shape_env.size_hint(batch_size.node.expr) + if has_free_symbols(batch_size) + else batch_size, + ) + + # MKL packed matrix can't be copied to a different address because the internal implementation + # depends on the alignment of internally-stored metadata. + # In aot mode, we need to firstly save the packed weight, when loading it, + # it will be in a different address which doesn't work. + # Disable MKL prepack linear in AOT mode + packed_weight_op = ( + mkldnn._reorder_linear_weight + if ( + is_lp_weight + or mkldnn._is_mkldnn_acl_supported() + or V.aot_compilation + ) + else torch.ops.mkl._mkl_reorder_linear_weight + ) + return graph.create_node( + "call_function", packed_weight_op, args=packed_weight_inputs + ) + + def pack_linear( + self, graph, is_lp_weight, batch_size, input, packed_weight_node, bias + ): + packed_linear_inputs: tuple[Any, ...] = (input, packed_weight_node) + transpose_weight_node = packed_weight_node.args[0] + if is_lp_weight or mkldnn._is_mkldnn_acl_supported() or V.aot_compilation: + packed_linear_inputs += (bias, "none", [], "") + packed_linear_op: Callable[..., Any] = mkldnn._linear_pointwise.default + else: + packed_linear_inputs += (transpose_weight_node, bias, batch_size) + packed_linear_op = torch.ops.mkl._mkl_linear + + return graph.create_node( + "call_function", packed_linear_op, packed_linear_inputs + ) + + class XpuMkldnnDeviceOp(MkldnnDeviceOpBase): + def pack_conv_weight( + self, + graph, + is_transposed, + weight, + constant_args, + input_size, + ): + assert not is_transposed, ( + "'mkldnn::_convolution_transpose_pointwise' is not currently implemented for the XPU device." + ) + return weight + + def _get_mkldnn_device_op(device_type: str) -> MkldnnDeviceOpBase: + """ + Returns the MKLDNN device operation class based on the current device type. + """ + if device_type == "cpu": + return CpuMkldnnDeviceOp() + elif device_type == "xpu": + return XpuMkldnnDeviceOp() + else: + raise RuntimeError(f"MKLDNN is not supported on {device_type} device.") + + def _is_valid_grouped_gemm_fusion(computation_nodes): + """ + Here we check: + 1. More than 1 GEMM nodes has been found. + 2. All the GEMM nodes share the same activation. + 3. All the GEMM nodes have same weight size but different wgt node. + """ + computation_op = mkldnn._linear_pointwise.default + act = computation_nodes[0].args[0] + wgt = computation_nodes[0].args[1] + wgt_size = wgt.meta.get("val").size() # type: ignore[union-attr] + return len(computation_nodes) >= 2 and all( + ( + node.target == computation_op + and node.args[0] == act + and (node.args[1].meta.get("val").size() == wgt_size) + and (node.args[1] != wgt or gemm_idx == 0) + ) + for gemm_idx, node in enumerate(computation_nodes) + ) + + def grouped_gemm_pass(graph: torch.fx.Graph): + """ + Group GEMM has multi output nodes which is complicated to define a Pattern. + Use below way to connect the pattern to the lowering. 
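+        In short: every group of eligible mkldnn._linear_pointwise nodes that share the
+        same activation is replaced by a single grouped_gemm_lowering node, and each
+        original output is recovered via operator.getitem.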
+ TODO: Use MultiOutputPattern, current limitation is the pattern requires + fixed number of output nodes. Extend to support Group GEMM for pattern matcher. + """ + computation_op = mkldnn._linear_pointwise.default + from ..mkldnn_lowerings import grouped_gemm_lowering + + for node in graph.find_nodes(op="call_function", target=computation_op): + if ( + not node._erased + and isinstance(node.meta.get("val"), torch.Tensor) + and node.meta["val"].device.type == "cpu" + ): + act = node.args[0] + users = list(act.users) + if _is_valid_grouped_gemm_fusion(users): + with graph.inserting_before(node): + grouped_gemm_node = graph.create_node( + "call_function", + grouped_gemm_lowering, + ( + act, + [user.args[1] for user in users], + [user.args[2] for user in users], + ), + ) + grouped_gemm_node.meta["val"] = [ + user.meta["val"] for user in users + ] + with graph.inserting_after(grouped_gemm_node): + for gemm_idx, user in enumerate(users): + assert user.target == computation_op + get_item = graph.create_node( + "call_function", + operator.getitem, + ( + grouped_gemm_node, + gemm_idx, + ), + ) + user.replace_all_uses_with(get_item) + graph.erase_node(user) + return + + def _conv_call(users=1): + return CallFunction( + mkldnn._convolution_pointwise.default, *_conv_args, _users=users + ) + + def _linear_call(users=1): + return CallFunction( + mkldnn._linear_pointwise.default, *_linear_args, _users=users + ) + + def _conv_transpose_call(users=1): + return CallFunction( + mkldnn._convolution_transpose_pointwise.default, + *_conv_transpose_args, + _users=users, + ) + + def _to_float(input_call, users=1): + return CallFunction( + prims.convert_element_type.default, + input_call, + KeywordArg("to_float"), + _users=users, + ) + + def _to_bf16(input_call): + return CallFunction( + prims.convert_element_type.default, + input_call, + KeywordArg("to_bf16"), + _users=1, + ) + + def _to_fp16(input_call): + return CallFunction( + prims.convert_element_type.default, + input_call, + KeywordArg("to_fp16"), + _users=1, + ) + + def _unary_fusion_pattern(unary_fusion, call_fn, users, lowp_dtype): + # only insert to_dtype if lowp_dtype is True + computation_call = ( + _to_float(call_fn(), users=users) if lowp_dtype else call_fn(users=users) + ) + out = unary_fusion(computation_call) + if lowp_dtype == torch.bfloat16: + return _to_bf16(out) + elif lowp_dtype == torch.float16: + return _to_fp16(out) + else: + return out + + def _gelu_fusion_1(computation_call): + return CallFunction( + aten.mul, + CallFunction(aten.mul, computation_call, 0.5), + CallFunction( + aten.add, + CallFunction( + aten.erf, + CallFunction(aten.mul, computation_call, 0.7071067811865476), + ), + 1, + ), + ) + + def _gelu_fusion_2(computation_call): + return CallFunction( + aten.mul, + CallFunction(aten.mul, computation_call, 0.5), + CallFunction( + aten.add, + CallFunction( + aten.tanh, + CallFunction( + aten.mul, + CallFunction( + aten.add, + computation_call, + CallFunction( + aten.mul, + CallFunction( + aten.mul, + CallFunction( + aten.mul, computation_call, computation_call + ), + computation_call, + ), + 0.044715, + ), + ), + 0.7978845608028654, + ), + ), + 1, + ), + ) + + def _hardswish_fusion(computation_call): + return CallFunction( + aten.div, + CallFunction( + aten.mul, + computation_call, + CallFunction( + aten.clamp_max, + CallFunction( + aten.clamp_min, CallFunction(aten.add, computation_call, 3), 0 + ), + 6, + ), + ), + 6, + ) + + def _silu_fusion(computation_call): + return CallFunction( + aten.mul, computation_call, 
CallFunction(aten.sigmoid, computation_call) + ) + + def _hardsigmoid_fusion(computation_call): + return CallFunction( + aten.div, + CallFunction( + aten.clamp_max, + CallFunction( + aten.clamp_min, CallFunction(aten.add, computation_call, 3), 0 + ), + 6, + ), + 6, + ) + + def _leaky_relu_fusion(computation_call): + return CallFunction( + aten.where, + CallFunction(aten.gt, computation_call, 0), + computation_call, + CallFunction(aten.mul, computation_call, KeywordArg("negative_slope")), + ) + + def _hardtanh_fusion(computation_call): + return CallFunction( + aten.clamp_max, + CallFunction(aten.clamp_min, computation_call, KeywordArg("min_value")), + KeywordArg("max_value"), + ) + + def _combined_fusion(computation_call, elementwise_op): + return CallFunction(elementwise_op, computation_call) + + # binary_op(other, computation_op) + def _binary_fusion_v1(computation_call, binary_fn): + return CallFunction(binary_fn, KeywordArg("other"), computation_call) + + # binary_op(computation_op, other) + def _binary_fusion_v2(computation_call, binary_fn): + return CallFunction(binary_fn, computation_call, KeywordArg("other")) + + def _is_single_computation_op(computation_op, lowp_dtype=None): + def fn(match): + computation_nodes = filter_nodes(match.nodes, computation_op) + + if lowp_dtype: + output_node_meta = match.output_node().meta.get("val") + if output_node_meta.dtype != lowp_dtype: + return False + + if len(computation_nodes) < 1: + return False + if any(n.args[-3] != "none" for n in computation_nodes): + return False + return True + + return fn + + def _is_valid_computation_unary_fusion(computation_op, lowp_dtype=None): + def fn(match): + matched = _is_single_computation_op(computation_op, lowp_dtype)(match) + computation_node = filter_nodes(match.nodes, computation_op)[0] + if lowp_dtype: + conversion_dtype_nodes = filter_nodes( + match.nodes, prims.convert_element_type.default + ) + if len(conversion_dtype_nodes) != 2: + return False + # fusion pattern is always in the form of computation_op + to_float32 + unary_op + to_bfloat16 + if computation_node == conversion_dtype_nodes[0].args[0]: + to_float = conversion_dtype_nodes[0].args[1] + to_lp = conversion_dtype_nodes[1].args[1] + else: + to_float = conversion_dtype_nodes[1].args[1] + to_lp = conversion_dtype_nodes[0].args[1] + matched = matched and to_float == torch.float and to_lp == lowp_dtype + return matched + + return fn + + def _register_unary_fusion_lowering( + pattern, unary_attr, computation_op, lowp_dtype=None + ): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_computation_unary_fusion(computation_op, lowp_dtype), + ) + def fn(match, *args, **kwargs): + computation_args = list(args)[:-3] + [ + unary_attr.op_name, + unary_attr.scalars_attr, + unary_attr.algorithm_attr, + ] + counters["inductor"]["mkldnn_unary_fusion_matcher_count"] += 1 + counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"] += len( + match.nodes + ) + return L[computation_op](*computation_args) + + return fn + + def _register_leaky_relu_fusion_lowering(pattern, computation_op, lowp_dtype=None): + @register_lowering_pattern( + pattern, extra_check=_is_single_computation_op(computation_op, lowp_dtype) + ) + def fn(match, *args, **kwargs): + negative_slope = kwargs.get("negative_slope") + if isinstance(negative_slope, ir.TensorBox): + matched = False + else: # inp is a Number + matched = True + if lowp_dtype: + dtype1 = kwargs.get("to_float") + dtype2 = ( + kwargs.get("to_bf16") + if lowp_dtype == torch.bfloat16 + else kwargs.get("to_fp16") + ) 
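+                # The matched pattern is computation_op -> to(fp32) -> leaky_relu
+                # decomposition -> to(lowp_dtype); only fuse when the dtype round-trip
+                # is exactly fp32 -> lowp_dtype.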
+ matched = matched and dtype1 == torch.float and dtype2 == lowp_dtype + computation_args = list(args) + counters["inductor"]["mkldnn_unary_fusion_matcher_count"] += 1 + counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"] += len( + match.nodes + ) + if matched: + computation_args = computation_args[:-3] + [ + "leaky_relu", + [negative_slope], + "", + ] + return L[computation_op](*computation_args) + else: + # computation_args += ["none", [], ""] + out = L[computation_op](*computation_args) + if lowp_dtype: + out = L[prims.convert_element_type.default](out, dtype=torch.float) + out = L[aten.where]( + L[aten.gt](out, 0), + out, + L[aten.mul](out, negative_slope), + ) + if lowp_dtype: + out = L[prims.convert_element_type.default](out, dtype=dtype2) # type: ignore[possibly-undefined] + return out + + return fn + + def _register_hardtanh_fusion_lowering(pattern, computation_op, lowp_dtype=None): + @register_lowering_pattern( + pattern, extra_check=_is_single_computation_op(computation_op, lowp_dtype) + ) + def fn(match, *args, **kwargs): + min_value = kwargs.get("min_value") + max_value = kwargs.get("max_value") + if isinstance(min_value, ir.TensorBox) or isinstance( + max_value, ir.TensorBox + ): + matched = False + else: # inp is a Number + assert max_value is not None + matched = min_value <= max_value + if lowp_dtype: + dtype1 = kwargs.get("to_float") + dtype2 = ( + kwargs.get("to_bf16") + if lowp_dtype == torch.bfloat16 + else kwargs.get("to_fp16") + ) + matched = matched and dtype1 == torch.float and dtype2 == lowp_dtype + computation_args = list(args) + counters["inductor"]["mkldnn_unary_fusion_matcher_count"] += 1 + counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"] += len( + match.nodes + ) + if matched: + computation_args = computation_args[:-3] + [ + "hardtanh", + [min_value, max_value], + "", + ] + return L[computation_op](*computation_args) + else: + out = L[computation_op](*computation_args) + if lowp_dtype: + out = L[prims.convert_element_type.default](out, dtype=torch.float) + out = L[aten.clamp_max](L[aten.clamp_min](out, min_value), max_value) + if lowp_dtype: + out = L[prims.convert_element_type.default](out, dtype=dtype2) # type: ignore[possibly-undefined] + return out + + return fn + + _binary_attr = { + aten.add: "add", + ops.add: "add", + aten.sub: "sub", + ops.sub: "sub", + } + + def _is_valid_binary(match, computation_op, binary_op): + binary_nodes = filter_nodes(match.nodes, binary_op) + if len(binary_nodes) < 1: + return False + + def get_meta_value(argument: torch.fx.node.Argument): + # Only torch.fx.Node is expected to have meta. + if isinstance(argument, torch.fx.Node): + return argument.meta.get("val", None) + return None + + if any( + not isinstance(get_meta_value(n.args[0]), torch.Tensor) + or not isinstance(get_meta_value(n.args[1]), torch.Tensor) + for n in binary_nodes + ): + return False + # check alpha is one. + if any( + get_arg_value(n, 2, kwarg_name="alpha") != 1.0 + and get_arg_value(n, 2, kwarg_name="alpha") is not None + for n in binary_nodes + ): + return False + + def _check_input_sizes(n, computation_op): + # Check if the tensor shape of the 'other' node is the same as or + # can be broadcasted to the tensor shape of the computation node. 
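+            # For example (illustrative shapes): for a linear output of shape (B, T, C),
+            # 'other' may be (B, T, C) or the broadcastable (1, 1, C); for a convolution
+            # output of shape (N, C, H, W), the shapes (N, C, 1, 1), (1, C, 1, 1) and
+            # all-ones are accepted as well.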
+ computation_node = ( + n.args[0] if n.args[1] is match.kwargs["other"] else n.args[1] + ) + assert computation_node.target == computation_op + computation_node_size = get_meta_value(computation_node).size() + if computation_op is mkldnn._linear_pointwise.default: + broadcast_sizes = [] + if len(computation_node_size) >= 2: + broadcast_sizes = [ + torch.Size( + [1 for _ in range(len(computation_node_size) - 1)] + + [computation_node_size[-1]] + ), + ] + else: + assert len(computation_node_size) > 2 + broadcast_sizes = [ + torch.Size( + [computation_node_size[0], computation_node_size[1]] + + [1 for _ in range(len(computation_node_size) - 2)] + ), + torch.Size( + [1, computation_node_size[1]] + + [1 for _ in range(len(computation_node_size) - 2)] + ), + torch.Size([1 for _ in range(len(computation_node_size))]), + ] + return ( + get_meta_value(match.kwargs["other"]).size() + in [ + computation_node_size, + ] + + broadcast_sizes + ) + + if any( + not _check_input_sizes(n, computation_op) + or get_meta_value(n.args[0]).device != get_meta_value(n.args[1]).device + or get_meta_value(n.args[0]).dtype != get_meta_value(n.args[1]).dtype + for n in binary_nodes + ): + return False + # check args[0] and args[1] is not same + if any(n.args[0] == n.args[1] for n in binary_nodes): + return False + return True + + def _is_valid_computation_binary(computation_op, binary_op, other_index=None): + def fn(match): + if not _is_single_computation_op(computation_op)(match): + return False + if not _is_valid_binary(match, computation_op, binary_op): + return False + return True + + return fn + + def _get_remaining_users(extra_input_node, compute_node): + # Think about this pattern: + # ReLU + # / \ + # Conv1 + # / \ + # Conv2 + # \ / + # Add + # Although, the extra input node (ReLU) has more than 1 users: Conv1 and Add. + # The Conv1 is the ancestor node of the current compute node (Conv2). + # This indicates that the buffer of ReLU has completed all its usage, + # So we can safely make changes to it now by doing Conv2->Add inplace fusion. + # Take above case as example: + # * extra_input_node: ReLU + # * compute_node: Conv2 + # _get_remaining_users will return the users of extra_input_node which are not + # ancestor node of compute_node. + def _is_ancestor_node(_current_node, _ancestor_node): + # Check whether _ancestor_node is the ancestor node of _current_node + _node_list = [_current_node] + _visited_nodes = OrderedSet[torch.fx.Node]() + while len(_node_list) != 0: + _current_node = _node_list.pop(0) + if _current_node not in _visited_nodes: + _visited_nodes.add(_current_node) + if _current_node == _ancestor_node: + return True + elif isinstance( + _current_node, torch.fx.Node + ) and _current_node.op not in ["placeholder", "output", "get_attr"]: + for input in _current_node.all_input_nodes: + _node_list.append(input) # noqa: PERF402 + return False + + return [ + user + for user in list(extra_input_node.users) + if not _is_ancestor_node(compute_node, user) + ] + + def _is_valid_computation_binary_inplace(computation_op, binary_op, other_index): + def fn(match): + if not _is_valid_computation_binary(computation_op, binary_op)(match): + return False + binary_nodes = filter_nodes(match.nodes, binary_op) + + def _get_compute_node(_binary_node, _other_index): + assert len(_binary_node.all_input_nodes) == 2, ( + "Binary node should have 2 input nodes." 
+ ) + _compute_index = 1 if (_other_index == 0) else 0 + return _binary_node.args[_compute_index] + + def _other_input_not_inplaceable(_binary_node, _other_index): + _compute_node = _get_compute_node(_binary_node, _other_index) + return ( + len( + _get_remaining_users( + _binary_node.args[_other_index], _compute_node + ) + ) + > 1 + or _binary_node.args[_other_index] == _compute_node.args[0] + ) + + if any(_other_input_not_inplaceable(n, other_index) for n in binary_nodes): + return False + if any( + n.args[other_index].op in ["placeholder", "output"] + for n in binary_nodes + ): + return False + return True + + return fn + + def _register_binary_unary_fusion_lowering( + pattern, + computation_op, + binary_op, + fusion_op, + unary_attr=None, + ): + @register_lowering_pattern( + pattern, extra_check=_is_valid_computation_binary(computation_op, binary_op) + ) + def fn(match, *args, **kwargs): + other = kwargs.get("other") + assert isinstance(other, ir.TensorBox) + binary_attr = _binary_attr[binary_op] + args_list = list(args) + computation_args = [args_list[0], other] + args_list[1:-3] + [binary_attr] + if len(args_list) > 6: + if unary_attr is not None: + computation_args += [ + 1.0, + unary_attr.op_name, + unary_attr.scalars_attr, + unary_attr.algorithm_attr, + ] + else: + computation_args += [1.0, None, [], None] + counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_count"] += 1 + counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_nodes"] += ( + len(match.nodes) + ) + return L[fusion_op](*computation_args) + + return fn + + def _can_be_inplace(_other): + return not ( + isinstance(_other.data, ir.BaseView) + or len(_other.get_inputs_that_alias_output()) > 0 + ) + + def _register_binary_unary_maybe_inplace_fusion_lowering( + pattern, + computation_op, + binary_op, + inplace_fusion_op, + outplace_fusion_op, + unary_attr=None, + other_index=None, + ): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_computation_binary_inplace( + computation_op, binary_op, other_index + ), + ) + def fn(match, *args, **kwargs): + other = kwargs.get("other") + assert isinstance(other, ir.TensorBox) + binary_attr = _binary_attr[binary_op] + args_list = list(args) + computation_args = [args_list[0], other] + args_list[1:-3] + [binary_attr] + if len(args_list) > 6: + if unary_attr is not None: + computation_args += [ + 1.0, + unary_attr.op_name, + unary_attr.scalars_attr, + unary_attr.algorithm_attr, + ] + else: + computation_args += [1.0, None, [], None] + counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_count"] += 1 + counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_nodes"] += ( + len(match.nodes) + ) + # Make sure the other is not an alias or mutation(fx side doesn't has such info). 
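+            # realize() materializes `other` into a buffer so that the aliasing / shape
+            # checks below (and a potential in-place write by the fused op) are valid.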
+ other.realize() + if not _can_be_inplace(other) or other.data.shape != list( + match.nodes[0].meta["val"].size() + ): + return L[outplace_fusion_op](*computation_args) + return L[inplace_fusion_op](*computation_args) + + return fn + + computation_ops = [ + mkldnn._convolution_pointwise.default, + mkldnn._linear_pointwise.default, + mkldnn._convolution_transpose_pointwise.default, + ] + + class UnaryAttr: + def __init__( + self, op_name: str, scalars_attr=None, algorithm_attr=None + ) -> None: + self.op_name = op_name + self.scalars_attr = scalars_attr if scalars_attr else [] + self.algorithm_attr = algorithm_attr if algorithm_attr else "" + + def _register_unary_fusion(): + computation_call_fns = [_conv_call, _linear_call, _conv_transpose_call] + + def _unary_fusion_patterns(lowp_dtype): + replacement_unary_fusion_patterns = { + UnaryAttr("gelu", algorithm_attr="tanh"): [ + _unary_fusion_pattern(_gelu_fusion_2, call_fn, 4, lowp_dtype) + for call_fn in computation_call_fns + ], + UnaryAttr("gelu", algorithm_attr="none"): [ + _unary_fusion_pattern(_gelu_fusion_1, call_fn, 2, lowp_dtype) + for call_fn in computation_call_fns + ], + UnaryAttr("hardswish"): [ + _unary_fusion_pattern(_hardswish_fusion, call_fn, 2, lowp_dtype) + for call_fn in computation_call_fns + ], + UnaryAttr("hardsigmoid"): [ + _unary_fusion_pattern(_hardsigmoid_fusion, call_fn, 1, lowp_dtype) + for call_fn in computation_call_fns + ], + UnaryAttr("swish"): [ + _unary_fusion_pattern(_silu_fusion, call_fn, 2, lowp_dtype) + for call_fn in computation_call_fns + ], + } + if not lowp_dtype: + call_user1 = [call_fn(users=1) for call_fn in computation_call_fns] + replacement_unary_fusion_patterns.update( + { + UnaryAttr("relu"): [ + _combined_fusion(u, aten.relu) for u in call_user1 + ], + UnaryAttr("sigmoid"): [ + _combined_fusion(u, aten.sigmoid) for u in call_user1 + ], + UnaryAttr("tanh"): [ + _combined_fusion(u, aten.tanh) for u in call_user1 + ], + } + ) + + return replacement_unary_fusion_patterns + + for lowp_dtype in [torch.bfloat16, torch.float16, None]: + replace_patterns = _unary_fusion_patterns(lowp_dtype) + for unary_attr, patterns in replace_patterns.items(): + _register_unary_fusion_lowering( + patterns[0], unary_attr, computation_ops[0], lowp_dtype + ) + _register_unary_fusion_lowering( + patterns[1], unary_attr, computation_ops[1], lowp_dtype + ) + _register_unary_fusion_lowering( + patterns[2], unary_attr, computation_ops[2], lowp_dtype + ) + _leaky_relu_patterns = [ + _unary_fusion_pattern(_leaky_relu_fusion, call_fn, 3, lowp_dtype) + for call_fn in computation_call_fns + ] + for pattern, computation_op in zip(_leaky_relu_patterns, computation_ops): + _register_leaky_relu_fusion_lowering( + pattern, computation_op, lowp_dtype + ) + hardtanh_patterns = [ + _unary_fusion_pattern(_hardtanh_fusion, call_fn, 1, lowp_dtype) + for call_fn in computation_call_fns + ] + for pattern, computation_op in zip(hardtanh_patterns, computation_ops): + _register_hardtanh_fusion_lowering(pattern, computation_op, lowp_dtype) + + def _register_inplace_fusion(): + binary_ops = [aten.add, ops.add] + inplace_fusion_op = mkldnn._convolution_pointwise_.binary + outplace_fusion_op = mkldnn._convolution_pointwise.binary + conv_call = _conv_call(users=1) + conv_op = computation_ops[0] + for binary_op in binary_ops: + binary_v1 = _binary_fusion_v1(conv_call, binary_op) + binary_unary_v1 = _combined_fusion(binary_v1, aten.relu) + _register_binary_unary_maybe_inplace_fusion_lowering( + binary_unary_v1, + conv_op, + binary_op, + 
inplace_fusion_op, + outplace_fusion_op, + other_index=0, + unary_attr=UnaryAttr("relu"), + ) + _register_binary_unary_maybe_inplace_fusion_lowering( + binary_v1, + conv_op, + binary_op, + inplace_fusion_op, + outplace_fusion_op, + other_index=0, + ) + binary_v2 = _binary_fusion_v2(conv_call, binary_op) + binary_unary_v2 = _combined_fusion(binary_v2, aten.relu) + _register_binary_unary_maybe_inplace_fusion_lowering( + binary_unary_v2, + conv_op, + binary_op, + inplace_fusion_op, + outplace_fusion_op, + other_index=1, + unary_attr=UnaryAttr("relu"), + ) + _register_binary_unary_maybe_inplace_fusion_lowering( + binary_v2, + conv_op, + binary_op, + inplace_fusion_op, + outplace_fusion_op, + other_index=1, + ) + + def _register_binary_fusion(): + binary_ops = [aten.add, ops.add, aten.sub, ops.sub] + fusion_ops = [ + mkldnn._convolution_pointwise.binary, + mkldnn._linear_pointwise.binary, + ] + _computation_user_1 = [_conv_call(users=1), _linear_call(users=1)] + for computation_call, computation_op, fusion_op in zip( + _computation_user_1, computation_ops[:-1], fusion_ops + ): + for binary_op in binary_ops: + pattern = _binary_fusion_v2(computation_call, binary_op) + _register_binary_unary_fusion_lowering( + pattern, computation_op, binary_op, fusion_op + ) + + for binary_op in [aten.add, ops.add]: + pattern = _binary_fusion_v1(computation_call, binary_op) + _register_binary_unary_fusion_lowering( + pattern, computation_op, binary_op, fusion_op + ) + + def _register_binary_unary_fusion(): + binary_ops = [aten.add, ops.add, aten.sub, ops.sub] + fusion_ops = [mkldnn._convolution_pointwise.binary] + _computation_user_1 = [_conv_call(users=1)] + for computation_call, computation_op, fusion_op in zip( + _computation_user_1, computation_ops[:-1], fusion_ops + ): + for binary_op in binary_ops: + pattern_v1 = _combined_fusion( + _binary_fusion_v2(computation_call, binary_op), aten.relu + ) + _register_binary_unary_fusion_lowering( + pattern_v1, + computation_op, + binary_op, + fusion_op, + unary_attr=UnaryAttr("relu"), + ) + for binary_op in [aten.add, ops.add]: + pattern_v2 = _combined_fusion( + _binary_fusion_v1(computation_call, binary_op), aten.relu + ) + _register_binary_unary_fusion_lowering( + pattern_v2, + computation_op, + binary_op, + fusion_op, + unary_attr=UnaryAttr("relu"), + ) + + def _recover_linear(): + # convert reshape+linear+reshape to a single linear for applying fusion path. 
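+        # e.g. (illustrative shapes) x.reshape(B * T, K) -> mkldnn._linear_pointwise -> .reshape(B, T, N)
+        # is rewritten to a single mkldnn._linear_pointwise on the original 3D input, so that
+        # later pointwise fusions can match it.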
+ # concat_linear (pass_number=0) -> mkldnn_linear_pack (pass_numer=1) -> _recover_linear(pass_number=2) + @register_freezing_graph_pattern( + CallFunction( + aten.reshape.default, + CallFunction( + mkldnn._linear_pointwise.default, + CallFunction( + aten.reshape.default, + Arg(), + KeywordArg("reshape_1"), + _users=MULTIPLE, + ), + Arg(), + Arg(), + Arg(), + Arg(), + Arg(), + ), + KeywordArg("reshape_2"), + ), + pass_number=2, + ) + def reshape_linear_reshape_pattern(match, *args, **kwargs): + def get_val(val): + return val if isinstance(val, int) else val.meta.get("val") + + reshape_1 = kwargs.get("reshape_1") + reshape_2 = kwargs.get("reshape_2") + assert isinstance(reshape_1, list) + assert isinstance(reshape_2, list) + assert len(reshape_1) == 2 + + graph = match.graph + reshape_2_node = match.output_node() + linear_input_node = reshape_2_node.args[0].args[0].args[0] + # check linear's input's shape[:-1] == reshape_2[:-1] + # and check product(reshape_2[:-1]) == reshape_1[0] + can_remove_reshape = linear_input_node.meta.get("val").shape[ + :-1 + ] == torch.Size([get_val(val) for val in reshape_2[:-1]]) + can_remove_reshape = can_remove_reshape and ( + reduce( + operator.mul, + [get_val(val) for val in reshape_2[:-1]], + ) + == get_val(reshape_1[0]) + ) + + if can_remove_reshape: + repl = graph.call_function(mkldnn._linear_pointwise.default, args) + repl.meta.update(reshape_2_node.meta) + reshape_2_node.replace_all_uses_with(repl) + old_linear_node = reshape_2_node.args[0] + reshape_1_node = old_linear_node.args[0] + graph.erase_node(reshape_2_node) + graph.erase_node(old_linear_node) + if len(reshape_1_node.users) == 0: + graph.erase_node(reshape_1_node) + counters["inductor"]["mkldnn_reshape_linear_reshape_matcher_count"] += 1 + counters["inductor"]["mkldnn_reshape_linear_reshape_matcher_nodes"] += len( + match.nodes + ) + + def is_linear_add_bias(match): + add_node = match.output_node() + linear_node = add_node.args[0] + device_type = add_node.meta.get("val").device.type + mkldnn_device_op = _get_mkldnn_device_op(device_type) + transpose_weight_node = mkldnn_device_op.get_linear_transpose_weight( + linear_node.args[1] + ) + weight_meta = transpose_weight_node.args[0].meta.get("val") + bias_node = add_node.args[1] + if isinstance(bias_node, int): + # we only folding bias if it is a constant + return False + bias_meta = add_node.args[1].meta.get("val") + if weight_meta is None or bias_meta is None: + return False + + if bias_meta.dtype != weight_meta.dtype: + return False + return ( + linear_node.args[2] is None + and bias_meta.dim() == 1 + and bias_meta.size(0) == weight_meta.size(1) + ) + + # convert linear+bias to a single linear for applying fusion path. 
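+        # e.g. (illustrative) aten.add(mkldnn._linear_pointwise(x, W, None, ...), bias) becomes
+        # mkldnn._linear_pointwise(x, W, bias, ...) when `bias` is a 1D tensor whose size matches
+        # the weight's output features (see is_linear_add_bias above).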
+ @register_freezing_graph_pattern( + CallFunction( + aten.add.Tensor, + CallFunction(mkldnn._linear_pointwise.default, *_linear_args), + Arg(), + ), + pass_number=2, + extra_check=is_linear_add_bias, + ) + def linear_bias_pattern(match, *args): + graph = match.graph + add_node = match.output_node() + linear_node = add_node.args[0] + new_args = list(linear_node.args) + new_args[2] = add_node.args[1] + repl = graph.call_function( + mkldnn._linear_pointwise.default, tuple(new_args) + ) + repl.meta.update(add_node.meta) + add_node.replace_all_uses_with(repl) + match.erase_nodes() + counters["inductor"]["mkldnn_linear_bias_matcher_count"] += 1 + counters["inductor"]["mkldnn_linear_bias_matcher_nodes"] += len(match.nodes) + + def _is_packable_mkldnn_rnn_layer(match): + lstm_node = match.output_node() + POS_WEIGHTS = [1, 2] + POS_INPUTS = [0, 5, 6] + POS_ARGS = POS_WEIGHTS + POS_INPUTS + # Weights should be Constant + if any( + lstm_node.args[POS_WEIGHT].op != "get_attr" for POS_WEIGHT in POS_WEIGHTS + ): + return False + + # Meta info for weights and inputs should be available + if any(lstm_node.args[POS_ARG].meta.get("val") is None for POS_ARG in POS_ARGS): + return False + + # Check device + if any( + lstm_node.args[POS_ARG].meta.get("val").device.type != "cpu" + for POS_ARG in POS_ARGS + ): + return False + + # Check dtype + if any( + lstm_node.args[POS_ARG].meta.get("val").dtype == torch.bfloat16 + and not is_mkldnn_bf16_supported("cpu") + for POS_ARG in POS_ARGS + ): + return False + if any( + lstm_node.args[POS_ARG].meta.get("val").dtype == torch.float16 + and not is_mkldnn_fp16_supported("cpu") + for POS_ARG in POS_ARGS + ): + return False + + return True + + def _is_packable_convolution(match): + """ + Check if the node is supported for MKLDNN convolution. + """ + conv_node = match.output_node() + device_type = conv_node.meta.get("val").device.type + # The operator 'mkldnn::_convolution_transpose_pointwise' is not currently implemented for the XPU device. + if match.kwargs["is_transposed"] and device_type == "xpu": + return False + + input_meta_value = conv_node.args[0].meta.get("val") + weight_meta_value = conv_node.args[1].meta.get("val") + if input_meta_value is None or weight_meta_value is None: + return False + input_size = input_meta_value.shape + if conv_node.args[1].op != "get_attr": + return False + for meta_value in [input_meta_value, weight_meta_value]: + if ( + meta_value is None + or meta_value.device.type not in SUPPORTED_MKLDNN_DEVICES + or (meta_value.dim() != 4 and meta_value.dim() != 5) + ): + return False + + if ( + input_meta_value.dtype == torch.bfloat16 + or weight_meta_value.dtype == torch.bfloat16 + ): + if not is_mkldnn_bf16_supported(device_type): + return False + if ( + input_meta_value.dtype == torch.float16 + or weight_meta_value.dtype == torch.float16 + ): + if not is_mkldnn_fp16_supported(device_type): + return False + is_transposed = conv_node.args[-3] + if is_transposed: + # TODO: Support dynamic shape case for MKLDNN conv transpose. + if has_free_symbols(input_size): + return False + groups = conv_node.args[-1] + in_channels = weight_meta_value.size(0) + # doesn't support group_depthwise_conv_transpose. 
+ if groups > 1 and groups == in_channels: + return False + # Port from: aten/src/ATen/native/Convolution.cpp:is_output_padding_big + output_paddings = conv_node.args[-2] + strides = conv_node.args[3] + if any( + output_padding >= stride + for output_padding, stride in zip(output_paddings, strides) + ): + return False + return True + + def _is_packable_linear(match): + """ + Check if the node is supported for MKLDNN linear. + """ + + def is_const_or_cat_by_const(weight): + if weight.op == "get_attr": + return True + if weight.target != aten.cat.default: + return False + return all(arg.op == "get_attr" for arg in weight.args[0]) + + linear_node = match.output_node() + # mkldnn linear only supports beta=1or0 and alpha=1 + if linear_node.target == aten.addmm.default: + alpha = linear_node.kwargs.get("alpha", 1.0) + beta = linear_node.kwargs.get("beta", 1.0) + if (beta != 0.0 and beta != 1.0) or alpha != 1.0: + return False + # weight_idx is 1 for aten.mm and is 2 for aten.addmm + weight_idx = 2 if linear_node.target == aten.addmm.default else 1 + if not is_const_or_cat_by_const(linear_node.args[weight_idx]): + return False + input_meta_value = linear_node.args[weight_idx - 1].meta.get("val") + weight_meta_value = linear_node.args[weight_idx].meta.get("val") + if input_meta_value is None or weight_meta_value is None: + return False + batch_size = input_meta_value.shape[0] + if ( + input_meta_value.dtype == torch.float64 + or weight_meta_value.dtype == torch.float64 + ): + return False + is_lp_weight = weight_meta_value.dtype in ( + torch.bfloat16, + torch.float16, + ) + # on x86, for fp32, mkl should be enabled and batch_size should not be a free symbol. + # on aarch64, use mkldnn op for fp32 as well if acl is enabled + if ( + not is_lp_weight + and not mkldnn._is_mkldnn_acl_supported() + and ((not torch._C.has_mkl) or has_free_symbols(batch_size)) + ): + return False + for meta_value in [input_meta_value, weight_meta_value]: + if ( + meta_value is None + or meta_value.device.type != "cpu" + or meta_value.dim() != 2 + ): + return False + if weight_idx == 2: + bias_meta_value = linear_node.args[0].meta.get("val") + if ( + bias_meta_value is None + or meta_value.device.type != "cpu" + or bias_meta_value.dim() != 1 + or bias_meta_value.size(0) != weight_meta_value.size(1) + ): + return False + + device_type = input_meta_value.device.type + if ( + input_meta_value.dtype == torch.bfloat16 + or weight_meta_value.dtype == torch.bfloat16 + ): + if not is_mkldnn_bf16_supported(device_type): + return False + if ( + input_meta_value.dtype == torch.float16 + or weight_meta_value.dtype == torch.float16 + ): + if not is_mkldnn_fp16_supported(device_type): + return False + return True + + _aten_conv_args = ( + Arg(), + Arg(), + Arg(), + Arg(), + Arg(), + Arg(), + KeywordArg("is_transposed"), + Arg(), + Arg(), + ) + + _aten_mkldnn_rnn_layer_args = ( + Arg(), # input + Arg(), # weight0 + Arg(), # weight1 + Arg(), # weight2 + Arg(), # weight3 + Arg(), # hx_ + Arg(), # cx_ + KeywordArg("reverse"), # reverse + Arg(), # batch_sizes + Arg(), # mode + Arg(), # hidden_size + Arg(), # num_layers + Arg(), # has_biases + Arg(), # bidirectional + Arg(), # batch_first + Arg(), # train + ) + + def _register_weight_pack_pass(): + @register_freezing_graph_pattern( + CallFunction(aten.convolution.default, *_aten_conv_args), + extra_check=_is_packable_convolution, + ) + def convolution(match, *args, **kwargs): + is_transposed = kwargs.get("is_transposed") + assert isinstance(is_transposed, bool) + graph = match.graph + 
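        # (Editorial annotation, not part of the upstream file.) The block below rewrites the
        # matched aten.convolution node onto an MKLDNN pointwise op with a prepacked weight,
        # roughly:
        #
        #   aten.convolution(x, w, b, stride, padding, dilation, is_transposed, output_padding, groups)
        #       -> mkldnn._convolution_pointwise(x, packed_w, b, padding, stride, dilation, groups, "none", [], "")
        #
        # The argument order is read off the constant_args list constructed below; the transposed
        # variant additionally inserts output_padding and targets
        # mkldnn._convolution_transpose_pointwise instead.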
conv_node = match.output_node() + device_type = conv_node.args[0].meta.get("val").device.type + mkldnn_device_op = _get_mkldnn_device_op(device_type) + input_size = conv_node.args[0].meta.get("val").shape + with graph.inserting_before(conv_node): + constant_args = [args[4], args[3], args[5], args[-1]] + packed_conv_op = mkldnn._convolution_pointwise.default + if is_transposed: + constant_args.insert(1, args[-2]) # output_padding + packed_conv_op = mkldnn._convolution_transpose_pointwise.default + + if not has_free_symbols(input_size): + packed_weight_node = mkldnn_device_op.pack_conv_weight( + graph, + is_transposed, + args[1], + constant_args, + input_size, + ) + else: + assert not is_transposed + # For dynamic shape case, we need to pack weight in runtime. + packed_weight_node = args[1] + + packed_conv_inputs = ( + (args[0], packed_weight_node, args[2]) + + tuple(constant_args) + + ("none", [], "") + ) + packed_conv_node = graph.create_node( + "call_function", packed_conv_op, tuple(packed_conv_inputs) + ) + conv_node.replace_all_uses_with(packed_conv_node) + packed_conv_node.meta.update(conv_node.meta) + graph.erase_node(conv_node) + counters["inductor"]["mkldnn_conv_weight_pack_matcher_count"] += 1 + counters["inductor"]["mkldnn_conv_weight_pack_matcher_nodes"] += len( + match.nodes + ) + + @register_freezing_graph_pattern( + CallFunction(aten.mkldnn_rnn_layer.default, *_aten_mkldnn_rnn_layer_args), + extra_check=_is_packable_mkldnn_rnn_layer, + ) + def mkldnn_rnn_layer(match, *args, **kwargs): + def get_item(graph, node, index): + return graph.call_function(operator.getitem, (node, index)) + + graph = match.graph + lstm_node = match.output_node() + weight0, weight1 = args[1:3] + reverse = kwargs.get("reverse") + packed_lstm_op = aten.mkldnn_rnn_layer.default + hidden_size = args[9] + has_biases = args[11] + batch_first = args[13] + with graph.inserting_before(lstm_node): + packed_weight_op = mkldnn._reorder_mkldnn_rnn_layer_weight.default + packed_weight_inputs = ( + weight0, + weight1, + hidden_size, + reverse, + has_biases, + batch_first, + ) + packed_weight_node = graph.create_node( + "call_function", packed_weight_op, packed_weight_inputs, {}, "name" + ) + packed_weight_items = [ + get_item(graph, packed_weight_node, i) for i in range(2) + ] + pack_lstm_inputs = ( + args[0], + *packed_weight_items, + args[3], + args[4], + args[5], + args[6], + reverse, + *args[7:], + ) + + packed_lstm_node = graph.create_node( + "call_function", packed_lstm_op, args=pack_lstm_inputs + ) + lstm_node.replace_all_uses_with(packed_lstm_node) + packed_lstm_node.meta.update(lstm_node.meta) + graph.erase_node(lstm_node) + counters["inductor"]["mkldnn_rnn_weight_pack_matcher_count"] += 1 + counters["inductor"]["mkldnn_rnn_weight_pack_matcher_nodes"] += len( + match.nodes + ) + + @register_freezing_graph_pattern( + CallFunction( + aten.addmm.default, + Arg(), + Arg(), + Arg(), + beta=KeywordArg("beta"), + alpha=KeywordArg("alpha"), + ), + extra_check=_is_packable_linear, + pass_number=1, + ) + @register_freezing_graph_pattern( + CallFunction(aten.mm.default, Arg(), Arg()), + extra_check=_is_packable_linear, + pass_number=1, + ) + def linear(match, *args, **kwargs): + graph = match.graph + linear_node = match.output_node() + input = args[0] if linear_node.target == aten.mm.default else args[1] + bias = ( + None + if linear_node.target == aten.mm.default + or ( + linear_node.target == aten.addmm.default + and linear_node.kwargs.get("beta", 1.0) == 0.0 + ) + else args[0] + ) + weight = args[1] if 
linear_node.target == aten.mm.default else args[2] + device_type = input.meta.get("val").device.type + mkldnn_device_op = _get_mkldnn_device_op(device_type) + with graph.inserting_before(linear_node): + transpose_weight_node = graph.create_node( + "call_function", aten.permute.default, (weight, (1, 0)) + ) + weight_dtype = weight.meta.get("val").dtype + is_lp_weight = weight_dtype in ( + torch.bfloat16, + torch.float16, + ) + batch_size = input.meta.get("val").shape[0] + if has_free_symbols(batch_size): + assert is_lp_weight or mkldnn._is_mkldnn_acl_supported(), ( + f"only bf16/fp16 weight prepacking supports dynamic shape inputs but got {weight_dtype}" + ) + packed_weight_node = mkldnn_device_op.pack_linear_weight( + graph, is_lp_weight, transpose_weight_node, batch_size + ) + packed_linear_node = mkldnn_device_op.pack_linear( + graph, is_lp_weight, batch_size, input, packed_weight_node, bias + ) + + linear_node.replace_all_uses_with(packed_linear_node) + packed_linear_node.meta.update(linear_node.meta) + graph.erase_node(linear_node) + counters["inductor"]["mkldnn_linear_weight_pack_matcher_count"] += 1 + counters["inductor"]["mkldnn_linear_weight_pack_matcher_nodes"] += len( + match.nodes + ) + + def _eliminate_duplicate_packed_nodes(gm): + """ + Combine packed weight nodes with the same inputs to reduce memory usage. + for example: + class Model(nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = nn.Linear(32, 32, bias=True) + + def forward(self, x): + return self.linear(self.linear(x)) + + the above's packed weight nodes are duplicate if two linear calls have same input size. + """ + if not (torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available()): + return gm + + packed_weight_ops = [ + torch._C._nn.mkldnn_reorder_conv2d_weight, + torch._C._nn.mkldnn_reorder_conv3d_weight, + mkldnn._reorder_convolution_transpose_weight, + mkldnn._reorder_linear_weight, + mkldnn._reorder_mkldnn_rnn_layer_weight, + ] + if torch._C.has_mkl: + packed_weight_ops.append(torch.ops.mkl._mkl_reorder_linear_weight) + + for node in gm.graph.nodes: + if node.target in packed_weight_ops and len(node.args[0].users) > 1: + for user_node in list(node.args[0].users.keys()): + if ( + user_node.target == node.target + and user_node != node + and user_node.args == node.args + ): + user_node.replace_all_uses_with(node) + gm.graph.erase_node(user_node) + + @functools.cache + def _mkldnn_fusion_init(): + # TODO: aarch64: enable op fusion for acl once it supports fused operators. Disabling it for now. 
+ # Otherwise even the matmul or innerproduct can not be accelerated with acl + if ( + torch.backends.mkldnn.enabled + and torch.backends.mkldnn.is_available() + and not torch.ops.mkldnn._is_mkldnn_acl_supported() + ): + _register_unary_fusion() + _register_inplace_fusion() + _register_binary_unary_fusion() + _register_binary_fusion() + _register_quantization_lowerings() + _register_woq_lowerings() + + @functools.cache + def _mkldnn_weight_pack_init(): + if torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available(): + _register_weight_pack_pass() + _recover_linear() + _register_quantization_weight_pack_pass() + _register_int8_woq_concat_linear_pattern() diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/numeric_utils.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/numeric_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d5b140b49d20d41b39ece2b68f2c9d4a4622c793 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/numeric_utils.py @@ -0,0 +1,213 @@ +# mypy: allow-untyped-defs +import gc +import logging +import os +import random +import traceback + +import numpy + +import torch +import torch.optim as optim +from torch.utils._ordered_set import OrderedSet + +from .. import config + + +logger: logging.Logger = logging.getLogger(__name__) + +MAIN_RANDOM_SEED = 1337 + +# Set the CUBLAS_WORKSPACE_CONFIG environment variable +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + + +# If the two forward functions involve any non-deterministic operations, +# such as certain types of parallelism or asynchronous execution, +# this can also lead to different outputs. +def set_deterministic() -> None: + """Make torch manual seed deterministic.""" + + torch.manual_seed(MAIN_RANDOM_SEED) + random.seed(MAIN_RANDOM_SEED) + numpy.random.seed(MAIN_RANDOM_SEED) + torch.use_deterministic_algorithms(True) + + +def clean_memory() -> None: + """Clean memory to avoid OOM.""" + gc.collect() + torch.cuda.empty_cache() + + +# We compare the numerical results before and after pre/post grad fx passes +# transformation to make sure the numerical results are the same. +def compare_dict_tensors(dict_base, dict_control, precision): + if len(OrderedSet(dict_base.keys())) != len(OrderedSet(dict_control.keys())): + logger.warning("Mismatch keys found before and after pre/post grad fx passes.") + logger.debug("keys before pre/post grad fx passes %s", dict_base.keys()) + logger.debug("keys after pre/post grad fx passes %s", dict_control.keys()) + return False + is_allclose = True + for key in dict_base.keys(): + if key not in dict_control: + logger.warning( + "Mismatch parameter name %s does not exist after pre/post grad fx passes", + key, + ) + # Some parameters have `None`, and not every param has a valid .grad field, we skip them + if dict_base[key] is None or dict_control[key] is None: + continue + if not torch.allclose( + dict_base[key], + dict_control[key], + rtol=precision, + atol=precision, + equal_nan=True, + ): + logger.warning( + "Mismatch parameter values found before and after pre/post grad fx passes." + ) + logger.debug("value before pre/post grad fx passes %s", dict_base[key]) + logger.debug("value after pre/post grad fx passes %s", dict_control[key]) + is_allclose = False + return is_allclose + + +def compare_tuple_tensors(tuple_base, tuple_control, precision): + if len(tuple_base) != len(tuple_control): + logger.warning( + "Mismatch fw output length. 
before transformation: %s, after transformation: %s", + len(tuple_base), + len(tuple_control), + ) + return False + is_allclose = True + for i in range(len(tuple_base)): + # Some parameters have `None`, we skip them + if tuple_base[i] is None or tuple_control[i] is None: + continue + if not torch.allclose( + tuple_base[i], + tuple_control[i], + rtol=precision, + atol=precision, + equal_nan=True, + ): + logger.debug( + "forward output before pre/post grad fx passes %s", tuple_base[i] + ) + logger.debug( + "forward output after pre/post grad fx passes %s", tuple_control[i] + ) + is_allclose = False + return is_allclose + + +def compare_parameters(model_base, model_control, precision): + return compare_dict_tensors( + dict(model_base.named_parameters()), + dict(model_control.named_parameters()), + precision, + ) + + +def compare_forward_output(pred_base, pred_control, precision): + return compare_tuple_tensors( + pred_base, + pred_control, + precision, + ) + + +def compare_gradients(model_base, model_control, precision): + grad_base = {key: param.grad for key, param in model_base.named_parameters()} + grad_pt2 = {key: param.grad for key, param in model_control.named_parameters()} + return compare_dict_tensors( + grad_base, + grad_pt2, + precision, + ) + + +def run_model( + model_base, model_control, model_input, num_iterations=10, precision=1e-4 +): + clean_memory() + for i in range(num_iterations): + logger.info("start %s iteration", i) + set_deterministic() + pred_base = model_base(*model_input) + set_deterministic() + pred_control = model_control(*model_input) + + res = compare_parameters(model_base, model_control, precision) + logger.info("compare parameters. Numerical result : %s", res) + + res = compare_forward_output(pred_base, pred_control, precision) + logger.info("compare loss/predict. Numerical result : %s", res) + # tensor may not have a grad_fn + try: + _ = pred_base[0].sum().backward(retain_graph=True) + _ = pred_control[0].sum().backward(retain_graph=True) + res = compare_gradients(model_base, model_control, precision) + logger.info("compare param grad. Numerical result : %s", res) + except Exception: + logger.exception("Exception when comparing gradients") + traceback.print_exc() + + if config.fx_passes_numeric_check["requires_optimizer"]: + try: + optimizer_base = optim.SGD( + [param for name, param in model_base.named_parameters()], lr=0.01 + ) + optimizer_base.step() + + optimizer_control = optim.SGD( + [param for name, param in model_control.named_parameters()], lr=0.01 + ) + optimizer_control.step() + + res = compare_parameters(model_base, model_control, precision) + logger.info( + "compare parameters with optimizer added. 
Numerical result : %s", + res, + ) + except Exception: + logger.exception( + "Exception when optimizer is added to check parameter names" + ) + traceback.print_exc() + else: + logger.warning( + "no parameter with optimizer to compare with length %s before transformation" + " and the length %s after transformation", + len(dict(model_base.named_parameters())), + len(dict(model_control.named_parameters())), + ) + + +def numeric_check_if_enabled( + gm_before_fx_passes, + gm_after_fx_passes, + example_inputs, + num_iterations, + precision, +): + # need to topo-sort graphmodule before we run the model, + # otherwise it may fail as refer before def + # fail silently in order not to block the model run + try: + with torch.autograd.set_detect_anomaly(True): + run_model( + gm_before_fx_passes, + gm_after_fx_passes, + example_inputs, + num_iterations=num_iterations, + precision=precision, + ) + except Exception as e: + logger.warning( + "Runtime numeric check failed in pre grad fx passes with error: %s", e + ) + traceback.print_exc() diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/pad_mm.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/pad_mm.py new file mode 100644 index 0000000000000000000000000000000000000000..d2dfc3d9e4d0d93574d3a5eaeaac492c18abf8b3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/pad_mm.py @@ -0,0 +1,925 @@ +import functools +import itertools +import operator +import typing +from collections.abc import Sequence +from typing import Any, Callable, Optional, Union + +import torch +import torch._inductor.runtime.runtime_utils +from torch import Tensor +from torch._dynamo.utils import counters, dynamo_timed +from torch._inductor import utils +from torch._inductor.autoheuristic.autoheuristic import ( + AHContext, + AutoHeuristic, + LocalFeedback, +) +from torch._inductor.autoheuristic.autoheuristic_utils import ( + context_add_strides, + context_add_using_tf32, + pad_mm_operations, + pad_mm_precondition, +) +from torch._subclasses.fake_tensor import FakeTensor +from torch.utils._mode_utils import no_dispatch + +from ...utils._triton import has_triton +from ..pattern_matcher import ( + fwd_only, + gen_register_replacement, + joint_fwd_bwd, + Match, + ReplaceFn, + SearchFn, +) + + +aten = torch.ops.aten + + +# This flag is only used for testing purpose. +# Changing it to True will ignore comparing do_bench times +# between original pattern and padded one. 
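Editorial aside (not part of the patch): pad_mm pads the M/K/N dimensions of mm/bmm/addmm operands up to a hardware-friendly multiple (8 for fp16/bf16, 4 for fp32) and then benchmarks the padded kernel against the original to decide whether the extra memory traffic pays off. A minimal sketch of the padding arithmetic that get_alignment_size_dtype and get_padded_length implement further down in this file; the concrete sizes are examples only:

    def padded_length(x: int, alignment: int) -> int:
        # mirrors get_padded_length for plain ints (symbolic sizes are never padded)
        if alignment == 0 or x == 1 or x % alignment == 0:
            return 0
        return (x // alignment + 1) * alignment - x

    assert padded_length(30, 8) == 2   # a 30-wide fp16/bf16 dim is padded to 32
    assert padded_length(30, 4) == 2   # a 30-wide fp32 dim is padded to 32
    assert padded_length(32, 8) == 0   # already aligned, nothing to do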
+_skip_do_bench_times = False + + +def fetch_fake_tensors(match: Match, kwarg_names: Sequence[str]) -> list[Tensor]: + kwargs = match.kwargs + return [kwargs[name].meta["val"] for name in kwarg_names] + + +def unwrap_fake_args( + *arg_names: str, +) -> Callable[[Callable[..., Any]], Callable[[Match], Any]]: + def decorator(func: Callable[..., Any]) -> Callable[[Match], Any]: + def wrapper(match: Match) -> Any: + fake_tensors = fetch_fake_tensors(match, arg_names) + return func(*fake_tensors) + + return wrapper + + return decorator + + +def get_alignment_size(x: Tensor) -> int: + return get_alignment_size_dtype(x.dtype) + + +def get_alignment_size_dtype(dtype: torch.dtype) -> int: + if dtype == torch.float16 or dtype == torch.half or dtype == torch.bfloat16: + return 8 + elif dtype == torch.float32 or dtype == torch.float: + return 4 + else: + return 0 + + +def check_device(a: Tensor, b: Tensor) -> bool: + return a.is_cuda and b.is_cuda + + +def check_dtype(a: Tensor, b: Tensor) -> bool: + return a.is_floating_point() and b.is_floating_point() + + +def should_pad_common( + mat1: Tensor, mat2: Tensor, input: Optional[Tensor] = None +) -> bool: + # It's fine we have symbolic shapes or strides as long as they + # have hints. Later, we will make sure we only pad non-symbolic dimensions. + def valid_shape_and_stride(t: Optional[Tensor]) -> bool: + if t is None: + return True + + symbolic_cnt = 0 + for x in t.size(): + if isinstance(x, int): + continue + elif utils.is_symbolic(x): + if not x.node.has_hint(): + return False + symbolic_cnt += 1 + else: + return False + # filter out cases where all dimensions are symbolic + if symbolic_cnt == len(t.size()): + return False + return all( + isinstance(x, int) or (utils.is_symbolic(x) and x.node.has_hint()) + for x in t.stride() + ) + + return ( + torch._inductor.config.shape_padding + and check_device(mat1, mat2) + and check_dtype(mat1, mat2) + and all(valid_shape_and_stride(t) for t in (mat1, mat2, input)) + ) + + +def get_padded_length(x: Union[int, torch.SymInt], alignment_size: int) -> int: + # we don't pad x if it is symbolic + if isinstance(x, torch.SymInt) or alignment_size == 0 or x % alignment_size == 0: + return 0 + + # ignore dim that can be squeezed away + if x == 1: + return 0 + + return int((x // alignment_size + 1) * alignment_size) - x + + +def pad_dim(x: Tensor, padded_length: int, dim: int) -> Tensor: + if padded_length == 0: + return x + pad = x.new_zeros(*x.shape[:dim], padded_length, *x.shape[dim + 1 :]) + return torch.cat([x, pad], dim=dim) + + +def addmm_pattern( + input: Tensor, mat1: Tensor, mat2: Tensor, beta: float, alpha: float +) -> Tensor: + return aten.addmm(input, mat1, mat2, beta=beta, alpha=alpha) + + +def should_pad_addmm(match: Match) -> bool: + mat1, mat2, input = fetch_fake_tensors(match, ("mat1", "mat2", "input")) + return should_pad_common(mat1, mat2, input) and should_pad_bench( + match, mat1, mat2, torch.ops.aten.addmm, input=input + ) + + +def pad_addmm( + input: Optional[Tensor], + mat1: Tensor, + mat2: Tensor, + m_padded_length: int, + k_padded_length: int, + n_padded_length: int, + beta: float = 1.0, + alpha: float = 1.0, + mat1_pre_padded: bool = False, + mat2_pre_padded: bool = False, +) -> Tensor: + # for paddings, dim order is reversed for some reasons + # and for every dim, we need to specify left and right padding + if not mat1_pre_padded: + mat1 = pad_mat1( + mat1, m_padded_length=m_padded_length, k_padded_length=k_padded_length + ) + if not mat2_pre_padded: + mat2 = pad_mat2( + mat2, 
k_padded_length=k_padded_length, n_padded_length=n_padded_length + ) + + # the add broadcasts, so we only pad if the dimension != 1 + if input is not None: + if n_padded_length != 0: + if input.dim() == 2 and input.shape[1] != 1: + input = pad_dim(input, n_padded_length, 1) + elif input.dim() == 1 and input.shape[0] != 1: + input = pad_dim(input, n_padded_length, 0) + if m_padded_length != 0 and input.dim() == 2 and input.shape[0] != 1: + input = pad_dim(input, m_padded_length, 0) + + res = aten.addmm(input, mat1, mat2, beta=beta, alpha=alpha) + + if m_padded_length != 0: + res = res[:-m_padded_length, :] + if n_padded_length != 0: + res = res[:, :-n_padded_length] + return res + + +def addmm_replace( + input: Optional[Tensor], + mat1: Tensor, + mat2: Tensor, + beta: float = 1.0, + alpha: float = 1.0, +) -> Tensor: + k_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1)) + n_padded_length = get_padded_length(mat2.shape[1], get_alignment_size(mat2)) + m_padded_length = get_padded_length(mat1.shape[0], get_alignment_size(mat1)) + return pad_addmm( + input, + mat1, + mat2, + m_padded_length, + k_padded_length, + n_padded_length, + beta, + alpha, + ) + + +def is_mm_compute_bound(M: int, K: int, N: int, dtype: torch.dtype) -> bool: + denominator = M * K + N * K + M * N + if denominator == 0: + return False + arithmetic_intensity = (M * N * K) / denominator + + # we have experienced some large perf hits in this case, even in bandwidth bound regimes + if ( + dtype is torch.bfloat16 + and K > M + and K > N + and torch.cuda.get_device_capability() < (9, 0) + ): # doesn't repro on h100s: + return True + + # Fails with AMD + try: + machine_balance = ( + 1000 * utils.get_device_tflops(dtype) + ) / utils.get_gpu_dram_gbps() + except Exception: + return True + + # dram_gbps might be underestimating bandwidth because of cache. + # if we estimate machine balance too low we might miss some speedups, + # if we estimate too high there will be unnecessary compilation time increase. + # TODO - finetune coefficient here. 
As a reference point, Triton mm model assumes + # 80% of reads are in cache and cache is 4x faster than dram_gbps + machine_balance = machine_balance * 0.5 + + return arithmetic_intensity > machine_balance + + +@functools.cache +def get_pad_cache() -> torch._inductor.codecache.LocalCache: + return torch._inductor.codecache.LocalCache() + + +def get_cached_should_pad(key: str) -> bool: + return get_pad_cache().lookup(key) # type: ignore[return-value] + + +def set_cached_should_pad(key: str, value: bool) -> None: + return get_pad_cache().set_value(key, value=value) + + +def get_cached_base_mm_benchmark_time(key: str) -> float: + return get_pad_cache().lookup(key) # type: ignore[return-value] + + +def set_cached_base_mm_benchmark_time(key: str, value: float) -> None: + return get_pad_cache().set_value(key, value=value) + + +def should_pad_bench_key( + match: Match, + mat1: Tensor, + mat2: Tensor, + op: torch._ops.OpOverloadPacket, + input: Optional[Tensor] = None, + is_base_time_key: bool = False, +) -> str: + def tensor_key(t: Tensor) -> tuple[torch.Size, tuple[int, ...], torch.dtype]: + return (t.shape, t.stride(), t.dtype) + + tf32_key = ( + None if mat1.dtype != torch.float32 else torch.backends.cuda.matmul.allow_tf32 + ) + + def fmt_pad(name: str) -> Optional[str]: + if is_base_time_key: + return None + return f"exclude_pad:{should_exclude_padding_time(match, name)}" + + key = ( + tensor_key(mat1), + tensor_key(mat2), + fmt_pad("mat1"), + fmt_pad("mat2"), + op, + input if input is None else tensor_key(input), + tf32_key, + ) + + key = str(key) + if is_base_time_key: + key = f"base mm time: {key}" + return key + + +def get_non_view_def(node: torch.fx.Node) -> torch.fx.Node: + if node.op == operator.getitem: + return get_non_view_def(node.args[0]) # type: ignore[arg-type] + + if ( + node.op == "call_function" + and isinstance(node.target, torch._ops.OpOverload) + and utils.is_view(node.target) + ): + return get_non_view_def(node.all_input_nodes[0]) + + return node + + +def should_exclude_padding_time(match: Match, arg_name: str) -> bool: + node_def = get_non_view_def(match.kwargs[arg_name]) + + # constant padding converts tensors to contiguous so even if the input tensor + # can be planned layout transform is not free. TODO - way to pad and preserve layout ? + if not fetch_fake_tensors(match, (arg_name,))[0].is_contiguous(): + return False + + # TODO - see issue https://github.com/pytorch/pytorch/issues/128889 + # We would only able to completely plan these out if we were only doing + # first dimension padding. non-first we would still need a copy + # because these outputs are fixed dense. + cannot_plan_output = [ + aten.mm.default, + aten.convolution.default, + aten.convolution_backward.default, + aten.bmm.default, + aten.addmm.default, + aten._scaled_dot_product_flash_attention.default, + aten._scaled_dot_product_efficient_attention.default, + ] + + if node_def.target in cannot_plan_output: + return False + + if ( + node_def.target == aten.cat.default + and len(node_def.all_input_nodes) + > torch._inductor.config.max_pointwise_cat_inputs + ): + return False + + # optimistically assume we should be able to memory plan away + # all non inputs + return node_def.op != "placeholder" + + +def should_pad(key: str, ori_time: float, pad_time: float) -> bool: + multiplier = 1.1 + # Shape padding introduces additional memory ops. 
Based on microbenchmarks, 1.1x represents a reasonable + # tradeoff between performance improvement from shape padding and overhead from additional memory ops + # TODO: Build a learned model which would be better than this heuristic + if "shape_padding_multiplier" in torch._inductor.config.post_grad_fusion_options: + multiplier = torch._inductor.config.post_grad_fusion_options[ + "shape_padding_multiplier" + ].get("value", 1.1) + counters["inductor"]["shape_padding_multiplier"] += 1 + should_pad = _skip_do_bench_times or ori_time > pad_time * multiplier + set_cached_should_pad(key, should_pad) + return should_pad + + +def should_pad_mm_bf16(dtype: torch.dtype, M: int, N: int, K: int) -> bool: + # always force pad for mm with bf16 when the following are satisfied to avoid perf regression + large_k_threshold_to_pad = torch._inductor.config.post_grad_fusion_options[ + "pad_aten_mm_pass" + ].get("k_threshold_to_pad", 8388608) + if ( + dtype is torch.bfloat16 + and K > M + and K > N + and N % 2 == 1 + and K >= large_k_threshold_to_pad + and torch.cuda.get_device_capability() < (9, 0) + ): # doesn't repro on h100s: + return True + return False + + +def should_pad_bench(*args: Any, **kwargs: Any) -> bool: + with dynamo_timed( + "pad_mm_benchmark", + log_pt2_compile_event=False, + dynamo_compile_column_us="compile_time_autotune_time_us", + ): + return _should_pad_bench(*args, **kwargs) + + +def get_do_bench() -> Callable[[Callable[[], Any]], float]: + with dynamo_timed("pad_mm_benchmark_get_do_bench"): + return functools.partial( + torch._inductor.runtime.benchmarking.benchmarker.benchmark_gpu, + warmup=5, + ) + + +def _should_pad_bench( + match: Match, + mat1: Tensor, + mat2: Tensor, + op: torch._ops.OpOverloadPacket, + input: Optional[Tensor] = None, +) -> bool: + do_bench = get_do_bench() + + m_padded_length = 0 + n_padded_length = 0 + with no_dispatch(): + if op is torch.ops.aten.mm or op is torch.ops.aten.addmm: + m = mat1.shape[0] + k = mat1.shape[1] + n = mat2.shape[1] + k_padded_length = get_padded_length(k, get_alignment_size(mat1)) + n_padded_length = get_padded_length(n, get_alignment_size(mat2)) + m_padded_length = get_padded_length(m, get_alignment_size(mat1)) + elif op is torch.ops.aten.bmm: + m = mat1.shape[1] + k = mat1.shape[2] + n = mat2.shape[2] + k_padded_length = get_padded_length(k, get_alignment_size(mat1)) + m_padded_length = get_padded_length(m, get_alignment_size(mat1)) + n_padded_length = get_padded_length(n, get_alignment_size(mat2)) + else: + return False + + if m_padded_length == k_padded_length == n_padded_length == 0: + return False + + def realize_symbols( + ds: Union[torch.Size, tuple[torch.SymInt, ...]], + ) -> list[int]: + return [d if isinstance(d, int) else d.node.hint for d in ds] + + if any( + dim == 0 + for dim in itertools.chain( + realize_symbols(mat1.shape), realize_symbols(mat2.shape) + ) + ): + return False + + if torch._inductor.config.force_shape_pad: + return True + + if ( + "pad_aten_mm_pass" in torch._inductor.config.post_grad_fusion_options + and should_pad_mm_bf16(mat1.dtype, m, n, k) + ): + return True + + if not has_triton(): + return False + + if not is_mm_compute_bound(m, k, n, mat1.dtype): + return False + + # We don't want to look up the cache for cases that are trivially false + # since it does file io + key = should_pad_bench_key(match, mat1, mat2, op, input) + + cached_pad = get_cached_should_pad(key) + if cached_pad is not None: + return cached_pad + + def realize_tensor(t): + if isinstance(t, FakeTensor): + size_hints = 
realize_symbols(t.size()) + stride_hint = realize_symbols(t.stride()) + real_size = ( + sum((d - 1) * s for d, s in zip(size_hints, stride_hint)) + 1 + ) + real_t = torch.randn(real_size, dtype=t.dtype, device=t.device) + return torch.as_strided(real_t, size_hints, stride_hint) + else: + return torch.randn_like(t) + + mat1 = realize_tensor(mat1) + mat2 = realize_tensor(mat2) + + # since we key on whether or not the inputs can be memory planned, set cache for the + # original time which is unaffected by whether or not the input can be planned + ori_time_key = should_pad_bench_key( + match, mat1, mat2, op, input, is_base_time_key=True + ) + ori_time = get_cached_base_mm_benchmark_time(ori_time_key) + if ori_time is None and op is torch.ops.aten.addmm and input is not None: + # realize bias for addmm + input = realize_tensor(input) + + mat1_pad = mat1 + mat2_pad = mat2 + + is_bmm = op is torch.ops.aten.bmm + + mat1_pre_padded = should_exclude_padding_time(match, "mat1") + fns = [] + if mat1_pre_padded and (m_padded_length or k_padded_length): + mat1_pad = pad_mat1( + mat1_pad, + m_padded_length=m_padded_length, + k_padded_length=k_padded_length, + is_bmm=is_bmm, + ) + + def write_pad(): + if is_bmm: + mat1_pad[:, -m_padded_length:, -k_padded_length:].fill_(0) + else: + mat1_pad[-m_padded_length:, -k_padded_length:].fill_(0) + + fns.append(write_pad) + + mat2_pre_padded = should_exclude_padding_time(match, "mat2") + if mat2_pre_padded and (k_padded_length or n_padded_length): + mat2_pad = pad_mat2( + mat2_pad, + k_padded_length=k_padded_length, + n_padded_length=n_padded_length, + is_bmm=is_bmm, + ) + + def write_pad(): + if is_bmm: + mat2_pad[:, -k_padded_length:, -n_padded_length:].fill_(0) + else: + mat2_pad[-k_padded_length:, -n_padded_length:].fill_(0) + + fns.append(write_pad) + + if op is torch.ops.aten.addmm: + input_pad = None + if input is not None and input.is_cuda: + input_pad = torch.randn_like(input) + fns.append( + lambda: pad_addmm( + input_pad, + mat1_pad, + mat2_pad, + m_padded_length, + k_padded_length, + n_padded_length, + mat1_pre_padded=mat1_pre_padded, + mat2_pre_padded=mat2_pre_padded, + ) + ) + elif op is torch.ops.aten.mm: + fns.append( + lambda: pad_mm( + mat1_pad, + mat2_pad, + m_padded_length, + k_padded_length, + n_padded_length, + mat1_pre_padded=mat1_pre_padded, + mat2_pre_padded=mat2_pre_padded, + ) + ) + else: + fns.append( + lambda: pad_bmm( + mat1_pad, + mat2_pad, + m_padded_length, + k_padded_length, + n_padded_length, + mat1_pre_padded=mat1_pre_padded, + mat2_pre_padded=mat2_pre_padded, + ) + ) + + def orig_bench_fn(): + if op is torch.ops.aten.bmm or op is torch.ops.aten.mm: + op(mat1, mat2) + else: + op(input, mat1, mat2) + + def pad_bench_fn(): + for fn in fns: + fn() + + if ( + torch._inductor.config.run_autoheuristic("pad_mm") + and op is torch.ops.aten.mm + ): + ah_should_pad = run_autoheuristic( + mat1, + mat2, + orig_bench_fn, + pad_bench_fn, + m_padded_length, + k_padded_length, + n_padded_length, + do_bench, + mat1_pre_padded, + mat2_pre_padded, + ori_time, + ori_time_key, + key, + ) + if ah_should_pad is not None: + return ah_should_pad + + if ori_time is None: + ori_time = do_bench(orig_bench_fn) + set_cached_base_mm_benchmark_time(ori_time_key, ori_time) + + pad_time = do_bench(pad_bench_fn) + return should_pad(key, ori_time, pad_time) + + +def get_context( + mat1: Tensor, + mat2: Tensor, + mat1_pre_padded: bool, + mat2_pre_padded: bool, + m_padded_length: int, + k_padded_length: int, + n_padded_length: int, +) -> AHContext: + context = 
AHContext() + + context.add_feature("m", mat1.shape[0]) + context.add_feature("k", mat1.shape[1]) + context.add_feature("n", mat2.shape[1]) + + context_add_strides(context, "mat1", mat1.stride()) + context_add_strides(context, "mat2", mat2.stride()) + + context.add_feature("m_padded_length", m_padded_length) + context.add_feature("k_padded_length", k_padded_length) + context.add_feature("n_padded_length", n_padded_length) + + context.add_feature("mat1_align_size", get_alignment_size(mat1)) + context.add_feature("mat2_align_size", get_alignment_size(mat2)) + + context.add_feature("mat1_dtype", mat1.dtype, is_categorical=True) + context.add_feature("mat2_dtype", mat2.dtype, is_categorical=True) + + context.add_feature("prepadded_mat1", mat1_pre_padded, is_categorical=True) + context.add_feature("prepadded_mat2", mat2_pre_padded, is_categorical=True) + + context_add_using_tf32(context, mat1.dtype) + return context + + +def run_autoheuristic( + mat1: Tensor, + mat2: Tensor, + orig_bench_fn: Callable[[], None], + pad_bench_fn: Callable[[], None], + m_padded_length: int, + k_padded_length: int, + n_padded_length: int, + do_bench: Callable[[Callable[[], Any]], float], + mat1_pre_padded: bool, + mat2_pre_padded: bool, + ori_time: float, + ori_time_key: str, + key: str, +) -> Optional[bool]: + def feedback_fn( + choice: str, + ) -> Optional[float]: + if choice == orig_choice: + return do_bench(orig_bench_fn) + elif choice == pad_choice: + return do_bench(pad_bench_fn) + return None + + def fallback() -> str: + return "autotune" + + orig_choice = "orig" + pad_choice = "pad" + choices = [orig_choice, pad_choice] + feedback = LocalFeedback(feedback_fn) # type: ignore[arg-type] + context = get_context( + mat1, + mat2, + mat1_pre_padded, + mat2_pre_padded, + m_padded_length, + k_padded_length, + n_padded_length, + ) + name = "pad_mm" + autoheuristic = AutoHeuristic( + fallback=fallback, + choices=choices, + feedback=feedback, + context=context, + name=name, + augment_context=pad_mm_operations(), + precondition=pad_mm_precondition, + ) + choice = autoheuristic.get_choice() + choice2should_pad = {orig_choice: False, pad_choice: True, "autotune": None} + ah_should_pad = choice2should_pad.get(choice, None) + + if torch._inductor.config.collect_autoheuristic(name): + ah_ori_time = autoheuristic.get_collected_feedback(orig_choice) + ah_pad_time = autoheuristic.get_collected_feedback(pad_choice) + + # if precondition is not satisfied, autoheuristic does not collect data + if ah_ori_time is not None and ah_pad_time is not None: + if ori_time is None: + set_cached_base_mm_benchmark_time(ori_time_key, ah_ori_time) + return should_pad(key, ah_ori_time, ah_pad_time) + if ah_should_pad is not None: + set_cached_should_pad(key, ah_should_pad) + return ah_should_pad + + +def mm_pattern(mat1: Tensor, mat2: Tensor) -> Tensor: + return aten.mm(mat1, mat2) + + +def should_pad_mm(match: Match) -> bool: + mat1, mat2 = fetch_fake_tensors(match, ("mat1", "mat2")) + return should_pad_common(mat1, mat2) and should_pad_bench( + match, mat1, mat2, torch.ops.aten.mm + ) + + +def pad_mat1( + mat1: Tensor, *, m_padded_length: int, k_padded_length: int, is_bmm: bool = False +) -> Tensor: + if k_padded_length != 0 or m_padded_length != 0: + # dim order is reversed for constant_pad_nd, for every dim we specify right and left padding + pad_arg = [0, k_padded_length, 0, m_padded_length] + if is_bmm: + pad_arg.extend((0, 0)) + return aten.constant_pad_nd(mat1, pad_arg) + else: + return mat1 + + +def pad_mat2( + mat2: Tensor, *, 
k_padded_length: int, n_padded_length: int, is_bmm: bool = False +) -> Tensor: + if k_padded_length != 0 or n_padded_length != 0: + # dim order is reversed for constant_pad_nd, for every dim we specify right and left padding + pad_arg = [0, n_padded_length, 0, k_padded_length] + if is_bmm: + pad_arg.extend((0, 0)) + return aten.constant_pad_nd(mat2, pad_arg) + else: + return mat2 + + +def pad_mm( + mat1: Tensor, + mat2: Tensor, + m_padded_length: int, + k_padded_length: int, + n_padded_length: int, + mat1_pre_padded: bool = False, + mat2_pre_padded: bool = False, +) -> Tensor: + if not mat1_pre_padded: + mat1 = pad_mat1( + mat1, m_padded_length=m_padded_length, k_padded_length=k_padded_length + ) + if not mat2_pre_padded: + mat2 = pad_mat2( + mat2, k_padded_length=k_padded_length, n_padded_length=n_padded_length + ) + res = aten.mm(mat1, mat2) + if m_padded_length != 0: + res = res[:-m_padded_length, :] + if n_padded_length != 0: + res = res[:, :-n_padded_length] + return res + + +def mm_replace(mat1: Tensor, mat2: Tensor) -> Tensor: + k_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1)) + m_padded_length = get_padded_length(mat1.shape[0], get_alignment_size(mat1)) + n_padded_length = get_padded_length(mat2.shape[1], get_alignment_size(mat2)) + return pad_mm( + mat1, + mat2, + m_padded_length, + k_padded_length, + n_padded_length, + ) + + +def bmm_pattern(mat1: Tensor, mat2: Tensor) -> Tensor: + return aten.bmm(mat1, mat2) + + +def should_pad_bmm(match: Match) -> bool: + mat1, mat2 = fetch_fake_tensors(match, ("mat1", "mat2")) + return should_pad_common(mat1, mat2) and should_pad_bench( + match, mat1, mat2, torch.ops.aten.bmm + ) + + +def pad_bmm( + mat1: Tensor, + mat2: Tensor, + m_padded_length: int, + k_padded_length: int, + n_padded_length: int, + mat1_pre_padded: bool = False, + mat2_pre_padded: bool = False, +) -> Tensor: + if not mat1_pre_padded: + mat1 = pad_mat1( + mat1, + m_padded_length=m_padded_length, + k_padded_length=k_padded_length, + is_bmm=True, + ) + if not mat2_pre_padded: + mat2 = pad_mat2( + mat2, + k_padded_length=k_padded_length, + n_padded_length=n_padded_length, + is_bmm=True, + ) + res = aten.bmm(mat1, mat2) + if m_padded_length != 0: + res = res[:, :-m_padded_length, :] + if n_padded_length != 0: + res = res[:, :, :-n_padded_length] + return res + + +def bmm_replace(mat1: Tensor, mat2: Tensor) -> Tensor: + k_padded_length = get_padded_length(mat1.shape[2], get_alignment_size(mat1)) + n_padded_length = get_padded_length(mat2.shape[2], get_alignment_size(mat2)) + m_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1)) + return pad_bmm( + mat1, + mat2, + m_padded_length, + k_padded_length, + n_padded_length, + ) + + +@functools.cache +def _pad_mm_init() -> None: + from .joint_graph import patterns + + if torch.cuda.is_available(): + # workaround https://github.com/pytorch/pytorch/issues/97894 + device = "cuda" + else: + device = "cpu" + + # sizes/values dont actually matter for initial trace + # once we get a possible match we re-trace with the actual values and verify the match still holds + + dim2a = functools.partial(torch.empty, (4, 4), device=device, requires_grad=True) + dim2b = functools.partial(torch.empty, (4, 4), device=device, requires_grad=True) + + dim3a = functools.partial(torch.empty, (4, 4, 4), device=device, requires_grad=True) + dim3b = functools.partial(torch.empty, (4, 4, 4), device=device, requires_grad=True) + + dim1a = functools.partial(torch.empty, (4), device=device, requires_grad=True) + + # 
workaround https://github.com/pytorch/pytorch/issues/97894 + # 0.113377 is a "magic" value that lets us recover the lost input arg relationship + rep = {"beta": 0.213377, "alpha": 0.113377} + + for pattern, replacement, args, workaround, extra_check in [ + ( + typing.cast(SearchFn, mm_pattern), + typing.cast(ReplaceFn, mm_replace), + [dim2a(), dim2b()], + {}, + should_pad_mm, + ), + ( + typing.cast(SearchFn, bmm_pattern), + typing.cast(ReplaceFn, bmm_replace), + [dim3a(), dim3b()], + {}, + should_pad_bmm, + ), + ( + typing.cast(SearchFn, addmm_pattern), + typing.cast(ReplaceFn, addmm_replace), + [dim1a(), dim2a(), dim2b()], + rep, + should_pad_addmm, + ), + ]: + assert isinstance(workaround, dict) # mypy is unable to infer the type properly + name = pattern.__name__ + + gen_register_replacement( + f"{name}_training", + pattern, + replacement, + args, + joint_fwd_bwd, + patterns, + extra_check=extra_check, + scalar_workaround=workaround, + ) + + gen_register_replacement( + f"{name}_inference", + pattern, + replacement, + args, + fwd_only, + patterns, + extra_check=extra_check, + scalar_workaround=workaround, + ) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/post_grad.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/post_grad.py new file mode 100644 index 0000000000000000000000000000000000000000..2f651a8d0e487b1a4821b160c4149a80d454bd23 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/post_grad.py @@ -0,0 +1,1793 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import functools +import itertools +import logging +import operator +from collections import Counter, defaultdict +from typing import Any, Callable, Optional, TypeVar, Union +from typing_extensions import ParamSpec + +import torch +import torch._inductor as inductor +import torch.utils._pytree as pytree +from torch import fx +from torch._decomp import register_decomposition +from torch._dynamo.utils import counters +from torch._inductor import comms +from torch._inductor.virtualized import ops +from torch._logging import trace_structured +from torch._prims_common import is_boolean_dtype, is_expandable_to, is_integer_dtype +from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq +from torch.utils._ordered_set import OrderedSet + +from .. 
import config, ir, pattern_matcher +from ..codegen.common import custom_backend_passes +from ..comms import remove_fsdp2_unsharded_param_graph_input_usage +from ..fx_utils import FakeTensorUpdater, get_fake_args_kwargs, get_node_storage +from ..lowering import lowerings as L +from ..pattern_matcher import ( + _return_true, + Arg, + CallFunction, + CallFunctionVarArgs, + filter_nodes, + fwd_only, + get_arg_value, + get_mutation_region_id, + Ignored, + init_once_fakemode, + KeywordArg, + ListOf, + Match, + MultiOutputPattern, + MULTIPLE, + PatternMatcherPass, + register_graph_pattern, + register_replacement, + stable_topological_sort, +) +from ..utils import ( + decode_device, + get_all_devices, + get_gpu_type, + is_gpu, + is_pointwise_use, + OPTIMUS_EXCLUDE_POST_GRAD, +) +from ..virtualized import V +from .b2b_gemm import B2B_GEMM_PASS +from .ddp_fusion import fuse_ddp_communication +from .group_batch_fusion import group_batch_fusion_passes, POST_GRAD_FUSIONS +from .micro_pipeline_tp import micro_pipeline_tp_pass +from .pre_grad import is_same_dict, save_inductor_dict +from .reinplace import reinplace_inplaceable_ops +from .split_cat import POST_GRAD_PATTERNS + + +_T = TypeVar("_T") +_P = ParamSpec("_P") + +log = logging.getLogger(__name__) +aten = torch.ops.aten +prims = torch.ops.prims + +# First pass_patterns[0] are applied, then [1], then [2] +pass_patterns = [ + PatternMatcherPass(), + PatternMatcherPass(), + PatternMatcherPass(), +] + + +def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool): + """ + Passes that run on after grad. This is called once on the forwards + graph and once on the backwards graph. + + The IR here has been normalized and functionalized. + """ + GraphTransformObserver = functools.partial( + torch.fx.passes.graph_transform_observer.GraphTransformObserver, + subsystem="post_grad_passes", + ) + + if not torch._dynamo.config.skip_fsdp_hooks: + remove_fsdp2_unsharded_param_graph_input_usage(gm.graph) + + if config.dce: + # has some issues with mutation in inference mode + gm.graph.eliminate_dead_code() + + if is_inference and config.reorder_for_locality: + GraphTransformObserver(gm, "reorder_for_locality").apply_graph_pass( + reorder_for_locality + ) + + fake_tensor_updater = FakeTensorUpdater(gm.graph) + + if post_grad_custom_pre_pass := config.post_grad_custom_pre_pass: + GraphTransformObserver(gm, "post_grad_custom_pre_pass").apply_graph_pass( + post_grad_custom_pre_pass + ) + + if torch._C._has_mkldnn: + if ( + config.cpp.enable_grouped_gemm_template + and config.max_autotune + and "CPP" in config.max_autotune_gemm_backends + ): + from .mkldnn_fusion import grouped_gemm_pass + + grouped_gemm_pass(gm.graph) + + if config.cpp.enable_concat_linear: + from .quantization import concat_linear_woq_int4 + + # Concat linear optimization for WOQ int4 + concat_linear_woq_int4(gm) + + if config.pattern_matcher: + lazy_init() + GraphTransformObserver(gm, "post_grad_custom_pre_pass").apply_graph_pass( + functools.partial(group_batch_fusion_passes, pre_grad=False) + ) + GraphTransformObserver(gm, "remove_noop_ops").apply_graph_pass(remove_noop_ops) + GraphTransformObserver(gm, "remove_assert_ops").apply_graph_pass( + remove_assert_ops + ) + for i, patterns in enumerate(pass_patterns): + GraphTransformObserver(gm, f"pass_pattern_{i}").apply_graph_pass( + patterns.apply + ) + for pass_name in config.post_grad_fusion_options: + # skip all patterns for group batch fusions or quantization patterns + if pass_name in POST_GRAD_FUSIONS or pass_name in 
OPTIMUS_EXCLUDE_POST_GRAD: + continue + pattern_matcher_pass = POST_GRAD_PATTERNS[pass_name] + inductor_before_change = save_inductor_dict( + [pattern_matcher_pass.pass_name] + ) + GraphTransformObserver(gm, pass_name).apply_graph_pass( + pattern_matcher_pass.apply + ) + if not is_same_dict(counters["inductor"], inductor_before_change): + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": f"{pattern_matcher_pass.pass_name}_post_grad", + "encoding": "string", + }, + payload_fn=lambda: gm.print_readable( + print_output=False, include_stride=True, include_device=True + ), + ) + if config.b2b_gemm_pass: + B2B_GEMM_PASS.apply(gm.graph) # type: ignore[arg-type] + + if config._micro_pipeline_tp: + micro_pipeline_tp_pass(gm.graph) + + if config._fuse_ddp_communication: + GraphTransformObserver(gm, "fuse_ddp_communication").apply_graph_pass( + lambda graph: fuse_ddp_communication( + graph, + config._fuse_ddp_communication_passes, + config._fuse_ddp_bucket_size, + ) + ) + + if post_grad_custom_post_pass := config.post_grad_custom_post_pass: + GraphTransformObserver(gm, "post_grad_custom_post_pass").apply_graph_pass( + post_grad_custom_post_pass + ) + + GraphTransformObserver(gm, "stable_sort").apply_graph_pass(stable_topological_sort) + + GraphTransformObserver(gm, "move_constructors_to_cuda").apply_graph_pass( + move_constructors_to_gpu + ) + + fake_tensor_updater.incremental_update() + + for device, custom_backend_pass in custom_backend_passes.items(): + if custom_backend_pass is not None: + gm_devices = [d.type for d in get_all_devices(gm)] + if device in gm_devices: + pass_name = "custom_backend_passes_" + device + GraphTransformObserver(gm, pass_name).apply_gm_pass(custom_backend_pass) + + # Keep these last, since they introduces mutation. Look at + # ./fx_passes/README.md for a discussion of mutation invariants. + GraphTransformObserver(gm, "reinplace_inplaceable_ops").apply_graph_pass( + reinplace_inplaceable_ops + ) + GraphTransformObserver( + gm, "decompose_triton_kernel_wrapper_functional" + ).apply_graph_pass(decompose_triton_kernel_wrapper_functional) + GraphTransformObserver(gm, "decompose_auto_functionalized").apply_graph_pass( + decompose_auto_functionalized + ) + if not torch._dynamo.config.skip_fsdp_hooks: + GraphTransformObserver(gm, "reinplace_fsdp_all_gather").apply_graph_pass( + comms.reinplace_fsdp_all_gather + ) + GraphTransformObserver(gm, "decompose_scan_to_while_loop").apply_gm_pass( + decompose_scan_to_while_loop + ) + GraphTransformObserver(gm, "decompose_map_to_while_loop").apply_gm_pass( + decompose_map_to_while_loop + ) + + gm.recompile() + gm.graph.lint() + + +def prepare_softmax_pattern(x, dim): + xmax = x.amax(dim=dim, keepdim=True) + xsub = x - xmax + xexp = xsub.exp() + xsum = xexp.sum(dim=dim, keepdim=True) + return xmax, xsum, xsub, xexp + + +def prepare_softmax_replacement(x, dim): + """ + Return xsub since otherwise log-softmax can not be matched + due to a use of this intermediate node. Same reason to return + xsub.exp() for softmax. + """ + from torch._inductor.inductor_prims import prepare_softmax_online + + xmax, xsum = prepare_softmax_online(x, dim) + xsub = x - xmax + return xmax, xsum, xsub, xsub.exp() + + +def prepare_softmax_extra_check(match): + """ + We only have triton online softmax kernels currently. 
+ """ + return ( + config.online_softmax + and match.kwargs["x"].meta["val"].device.type == "cuda" + and config.cuda_backend == "triton" + ) + + +def decompose_map_to_while_loop(gm: torch.fx.GraphModule): + """This is similar to decompose_scan_to_while_loop.""" + graph_pass = PatternMatcherPass() + + @register_graph_pattern( + CallFunctionVarArgs(torch.ops.higher_order.map_impl), + pass_dict=graph_pass, + ) + def _(match: Match, *args, **kwargs): + assert len(kwargs) == 0, ( + "kwargs of map are not merged into args before entering decompose_map_to_while_loop_pass" + ) + subgraph, fx_xs, fx_additional_inputs = args + sub_gm: torch.fx.GraphModule = getattr(gm, subgraph.target) + cur_node = match.nodes[0] + mapped_outputs = cur_node.meta["val"] + + def lower_to_while_loop(*args, **kwargs): + assert len(kwargs) == 0 + xs, additional_inputs = pytree.tree_unflatten(args, tree_spec) + assert isinstance(xs, (tuple, list)) and isinstance( + additional_inputs, (tuple, list) + ), (xs, additional_inputs) + map_length = xs[0].size(0) + loop_idx = torch.zeros([], dtype=torch.int64, device=torch.device("cpu")) + + # Similar to NOTE [Pre-allocate scan's output buffer] + bound_symbols = { + arg.node.expr: arg + for arg in pytree.tree_leaves((args, map_length)) + if isinstance(arg, torch.SymInt) + } + out_buffers = [ + torch.empty_strided( + resolve_shape_to_proxy(out.size(), bound_symbols), + resolve_shape_to_proxy(out.stride(), bound_symbols), + device=out.device, + dtype=out.dtype, + layout=out.layout, + requires_grad=out.requires_grad, + ) + for out in mapped_outputs + ] + + while_loop_operands = (loop_idx, out_buffers, xs) + while_loop_flat_operands, operands_spec = pytree.tree_flatten( + while_loop_operands + ) + while_loop_additional_inputs = additional_inputs + _, operands_and_additional_inputs_spec = pytree.tree_flatten( + (*while_loop_operands, additional_inputs) + ) + + def cond_fn(*flat_args): + loop_idx, _, _, _ = pytree.tree_unflatten( + flat_args, + operands_and_additional_inputs_spec, + ) + return loop_idx < map_length + + def body_fn(*flat_args): + loop_idx, out_bufs, xs, additional_inputs = pytree.tree_unflatten( + flat_args, + operands_and_additional_inputs_spec, + ) + + idx_int = loop_idx.item() + torch.ops.aten._assert_scalar.default(idx_int >= 0, "") + torch.ops.aten._assert_scalar.default(idx_int < map_length, "") + sub_xs = [torch.ops.aten.select.int(x, 0, idx_int) for x in xs] + outs = sub_gm(*sub_xs, *additional_inputs) + + for out, buffer in zip(outs, out_bufs): + buffer_slice = torch.ops.aten.select.int(buffer, 0, idx_int) + buffer_slice.copy_(out) + return loop_idx + 1, *out_bufs, *xs + + _, final_out, _ = pytree.tree_unflatten( + torch.ops.higher_order.while_loop( + cond_fn, + body_fn, + tuple(while_loop_flat_operands), + tuple(while_loop_additional_inputs), + ), + operands_spec, + ) + return (final_out,) + + lower_to_while_loop_args, tree_spec = pytree.tree_flatten( + (fx_xs, fx_additional_inputs) + ) + match.replace_by_example( + lower_to_while_loop, lower_to_while_loop_args, run_functional_passes=False + ) + + graph_pass.apply(gm) + + for node in gm.graph.find_nodes( + op="call_function", target=torch.ops.higher_order.map_impl + ): + raise AssertionError("map is not lowered to while_loop") + + +def resolve_shape_to_proxy( + shape: list[Union[int, torch.SymInt]], bound_symbols: dict[Any, Any] +): + """ + Given a list of symints/ints, this function returns a calculated expression of bound_symbols' values. 
+ When we trace this function, we'll get a graph with call_function nodes that describes how the shape expr is + computed from bound_symbols' values. + + Suppose shape = (s1*s2, s1+s2) and bound_symbols = {s1: arg0, s2: arg1}, the result will be + (arg0 * arg1, arg0 + arg1). + """ + from torch.utils._sympy.interp import sympy_interp + from torch.utils._sympy.reference import PythonReferenceAnalysis + + ret = [] + for s in shape: + if isinstance(s, torch.SymInt): + ret.append( + sympy_interp( + PythonReferenceAnalysis, + bound_symbols, + s.node.expr, + ), + ) + else: + assert isinstance(s, int) + ret.append(s) + return ret + + +def decompose_scan_to_while_loop(gm: torch.fx.GraphModule): + """ + NOTE [decompose scan to while_loop] + This pass decomposes `scan` to `while_loop` by replacing the scan fx_node with a while_loop hop. + + Suppose we have a function f: + + def f(): + init = torch.zeros([]) + xs = torch.arange(4) + ys = [] + for i in range(xs.size(0)): + init = xs[i] + init + ys.append(init) + + # Return the final carry and stack the intermediates + return init, torch.stack(ys) + + We could rewrite it with a scan with the benefits of reducing compilation time/binary size, reducing + memory usage, supporting loops over unbacked shapes and cudagraph etc. + + def g(): + def step_fn(init: torch.Tensor, x: torch.Tensor): + next_init = x + init + return next_init, next_init + + init = torch.zeros([]) + xs = torch.arange(4) + final_carry, ys = torch._higher_order.scan(step_fn, init, xs) + return final_carry, ys + + This pass will rewrite scan into: + + def k(): + init = torch.zeros([]) + xs = torch.arange(4) + + # we create a loop_idx and loop through xs.shape[0] + loop_idx = torch.zeros([]) + ys = torch.empty_strided(_shape_stride_of_ys) + def cond_fn(loop_idx, ys, init, xs): + return loop_idx < xs.shape[0] + + # we pre-allocate the output buffer ys and inplace + # copy the y of each intermediate into a slice. + # NOTE [Pre-allocate scan's output buffer]. + def body_fn(loop_idx, ys, init, xs): + int_idx = loop_idx.item() + next_init, y = step_fn(init, xs[int_idx]) + ys[int_idx].copy_(y) + return loop_idx + 1, ys, next_init, xs + + final_carry, _, _, ys = torch._higher_order.while_loop(cond_fn, body_fn, (loop_idx, ys, init, xs)) + return final_carry, ys + """ + + graph_pass = PatternMatcherPass() + + @register_graph_pattern( + CallFunctionVarArgs(torch.ops.higher_order.scan), + pass_dict=graph_pass, + ) + def _(match: Match, *args, **kwargs): + from torch._higher_order_ops.scan import _extract_carry_and_out + + assert len(kwargs) == 0, ( + "kwargs of scan are not merged into args before entering decompose_scan_to_while_loop_pass" + ) + + combine_subgraph, fx_init, fx_xs, fx_additional_inputs = args + assert combine_subgraph.op == "get_attr", "first arg is not combine_subgraph" + sub_gm: torch.fx.GraphModule = getattr(gm, combine_subgraph.target) + cur_node = match.nodes[0] + num_init_leaves = len(fx_init) + _, ys_outputs = _extract_carry_and_out(cur_node.meta["val"], num_init_leaves) + + def lower_to_while_loop(*args, **kwargs): + """ + The traced graph of this function will be used to replace the original scan fx_node. + """ + assert len(kwargs) == 0 + + # Step 1: construct necessary inputs to while_loop based on scan's input. 
+ ( + init, + xs, + additional_inputs, + ) = pytree.tree_unflatten(args, tree_spec) + scan_length = xs[0].size(0) + loop_idx = torch.zeros([], dtype=torch.int64, device=torch.device("cpu")) + + # NOTE [Pre-allocate scan's output buffer] + # In order to pre-allocate the output buffer for ys, we rely on the meta of scan's fx_node. + # However, the meta consists of concrete symints, we need to bind those symints with + # proxies in order to trace the torch.empyt_strided call correctly. + # + # Also note that basic free symbols of tensor's shapes are guaranteed to be lifted as subgraph inputs + # in dynamo so we can always re-construct the sym expression from placeholders. + # See Note [Auto lift basic free symbols when create_graph_input] for how this is done. + bound_symbols = { + arg.node.expr: arg + for arg in pytree.tree_leaves((args, scan_length)) + if isinstance(arg, torch.SymInt) + } + ys_outs = [ + torch.empty_strided( + resolve_shape_to_proxy(ys_out.size(), bound_symbols), + resolve_shape_to_proxy(ys_out.stride(), bound_symbols), + device=ys_out.device, + dtype=ys_out.dtype, + layout=ys_out.layout, + requires_grad=ys_out.requires_grad, + ) + for ys_out in ys_outputs + ] + + while_loop_operands = (loop_idx, ys_outs, init, xs) + flat_operands, operands_spec = pytree.tree_flatten(while_loop_operands) + _, operands_and_additional_inputs_spec = pytree.tree_flatten( + (*while_loop_operands, additional_inputs) + ) + + # Step 2: create the cond_fn and body_fn for while_loop + def cond_fn(*flat_args): + loop_idx, _, _, _, _ = pytree.tree_unflatten( + flat_args, operands_and_additional_inputs_spec + ) # type: ignore[has-type] + return loop_idx < scan_length # type: ignore[has-type] + + def body_fn(*flat_args): + loop_idx, ys_outs, carry, xs, additional_inputs = pytree.tree_unflatten( + flat_args, + operands_and_additional_inputs_spec, # type: ignore[has-type] + ) + + idx_int = loop_idx.item() + torch.ops.aten._assert_scalar.default(idx_int >= 0, "") + torch.ops.aten._assert_scalar.default(idx_int < scan_length, "") + sub_xs = [torch.ops.aten.select.int(x, 0, idx_int) for x in xs] + next_carry, ys = _extract_carry_and_out( + sub_gm(*(list(carry) + sub_xs + list(additional_inputs))), + num_init_leaves, + ) + for y, y_out in zip(ys, ys_outs): + y_out_slice = torch.ops.aten.select.int(y_out, 0, idx_int) + y_out_slice.copy_(y) + return loop_idx + 1, *ys_outs, *next_carry, *xs + + # Step 3: call the while_loop operator + _, ys_outs, last_carry, _ = pytree.tree_unflatten( + torch.ops.higher_order.while_loop( + cond_fn, + body_fn, + tuple(flat_operands), + tuple(additional_inputs), + ), + operands_spec, + ) + return list(last_carry) + list(ys_outs) + + lower_to_while_loop_args, tree_spec = pytree.tree_flatten( + ( + fx_init, + fx_xs, + fx_additional_inputs, + ) + ) + match.replace_by_example( + lower_to_while_loop, + lower_to_while_loop_args, + run_functional_passes=False, + ) + + graph_pass.apply(gm) + + for node in gm.graph.find_nodes( + op="call_function", target=torch.ops.higher_order.scan + ): + raise AssertionError("scan is not lowered to while_loop") + + +@init_once_fakemode +def lazy_init(): + if torch._C._has_mkldnn: + from . 
import decompose_mem_bound_mm # noqa: F401 + from .mkldnn_fusion import _mkldnn_fusion_init + + _mkldnn_fusion_init() + + # Put this patterns in post-grad pass rather than joint-graph + # pass since otherwise there will be perf/peak-memory regression: + # https://github.com/pytorch/pytorch/issues/148141 + register_replacement( + prepare_softmax_pattern, + prepare_softmax_replacement, + [torch.empty(4, 8)], + scalar_workaround=dict(dim=-1), + trace_fn=fwd_only, + pass_dicts=pass_patterns[1], + extra_check=prepare_softmax_extra_check, + ) + + +def reorder_for_locality(graph: torch.fx.Graph): + if torch.distributed.is_available(): + + def check(): + # This is a wait node, and `other_node`` is some collective node. + # Eager semantics allow waits to be issued in a different order than + # the collectives. Reordering this wait node might reorder collectives + # which cause hangs. Once we have SPMD mode, we can safely reorder them. + # However, increasing the locality between a collective and its wait node + # is generally worse for performance. + return node.target != torch.ops._c10d_functional.wait_tensor.default + else: + + def check(): + return True + + def visit(other_node): + if ( + other_node.op == "call_function" + and other_node.target != operator.getitem + and all((n in seen_nodes) for n in other_node.users) + and get_mutation_region_id(graph, node) + == get_mutation_region_id(graph, other_node) + and check() + ): + # move node's producers right before it + node.prepend(other_node) + + seen_nodes = OrderedSet[torch.fx.Node]() + + # only reorder nodes before the first copy_ in the graph. + # copy_ will appear at the end of functionalized graphs when there is mutation on inputs, + # and this reordering doesn't work well with mutation + first_copy = next( + iter(graph.find_nodes(op="call_function", target=torch.ops.aten.copy_.default)), + None, + ) + past_mutating_epilogue = True if first_copy is None else False + + for node in reversed(graph.nodes): + seen_nodes.add(node) + if not past_mutating_epilogue: + past_mutating_epilogue = node is first_copy + continue + + torch.fx.map_arg((node.args, node.kwargs), visit) + + +def register_lowering_pattern( + pattern, extra_check=_return_true, pass_number=1 +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + """ + Register an aten to inductor IR replacement pattern + """ + return pattern_matcher.register_lowering_pattern( + pattern, extra_check, pass_dict=pass_patterns[pass_number] + ) + + +################################################################################ +# Actual patterns below this point. 
+# Priority of patterns is: +# - later output nodes first +# - order patterns are defined in +################################################################################ + + +def is_valid_mm_plus_mm(match: Match): + if not (config.max_autotune or config.max_autotune_gemm): + return False + + *_b1, m1, k1 = match.kwargs["mat1"].meta.get("tensor_meta").shape + *_b2, k2, n1 = match.kwargs["mat2"].meta.get("tensor_meta").shape + if k1 != k2: + return False + + *_b1, m2, k3 = match.kwargs["mat3"].meta.get("tensor_meta").shape + *_b2, k4, n2 = match.kwargs["mat4"].meta.get("tensor_meta").shape + if k3 != k4: + return False + + if m1 != m2 or n1 != n2: + return False + + return True + + +def scatter_upon_const_tensor_extra_check(m): + if not config.optimize_scatter_upon_const_tensor: + return False + full_shape = m.kwargs["shape"] + selector = m.kwargs["selector"] + dim = m.kwargs["dim"] + if dim < 0: + dim += len(full_shape) + + selector_ft = selector.meta["val"] + assert selector_ft.dim() == len(full_shape) + + for idx, select_sz, full_sz in zip( + itertools.count(), selector_ft.shape, full_shape + ): + if idx == dim: + continue + + # TODO: the pattern can be updated to support the case that index tensor + # is shorter. But that will need a more complex condition expression + # especially for multi-dimensional tensors. + # Skip it for now. + if isinstance(full_sz, fx.Node): + full_sz = full_sz.meta["val"] + if select_sz < full_sz: + return False + + # Actually we can support small size larger than 1. It would be a bit + # tedius. E.g., we load all the index values (not many) and compare + # them with the position in tensor to decide what value to return. + return selector_ft.size(dim) == 1 + + +@register_lowering_pattern( + CallFunction( + aten.scatter.value, + CallFunction( + aten.full, + KeywordArg("shape"), + KeywordArg("background_val"), + dtype=KeywordArg("dtype"), + ), + KeywordArg("dim"), + KeywordArg("selector"), + KeywordArg("val"), # scalar value + ), + extra_check=scatter_upon_const_tensor_extra_check, +) +def scatter_upon_const_tensor( + match: Match, shape, background_val, dtype, dim, selector, val +): + """ + Match the pattern of full+scatter into a pointwise. + + TODO: Right now the scatter value must be a scalar. But we could support it + when it is a tensor as well. 
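+
+    Illustrative example (a sketch of what the matched graph computes, not taken
+    from a real workload): with background_val=0.0, shape=(2, 3), dim=1,
+    selector=[[0], [2]] and val=5.0,
+
+        torch.full((2, 3), 0.0).scatter(1, torch.tensor([[0], [2]]), 5.0)
+        # -> [[5., 0., 0.],
+        #     [0., 0., 5.]]
+
+    i.e. out[i, j] = val if selector[i, 0] == j else background_val, which is the
+    pointwise form produced below.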
+ """ + from torch._inductor import metrics + + metrics.num_matches_for_scatter_upon_const_tensor += 1 + + selector_loader = selector.make_loader() + + def inner_fn(idx): + selector_idx = list(idx) + selector_idx[dim] = 0 + + selector = selector_loader(selector_idx) + return ops.where( + selector == ops.index_expr(idx[dim], torch.int64), + ops.constant(val, dtype), + ops.constant(background_val, dtype), + ) + + return ir.Pointwise.create( + device=selector.get_device(), + dtype=dtype, + inner_fn=inner_fn, + ranges=shape, + ) + + +@register_lowering_pattern( + CallFunction( + aten.add, + CallFunction(aten.mm, KeywordArg("mat1"), KeywordArg("mat2")), + CallFunction(aten.mm, KeywordArg("mat3"), KeywordArg("mat4")), + ), + extra_check=is_valid_mm_plus_mm, +) +def mm_plus_mm(match: Match, mat1, mat2, mat3, mat4): + return inductor.kernel.mm_plus_mm.tuned_mm_plus_mm(mat1, mat2, mat3, mat4) + + +@register_graph_pattern( + CallFunction( + aten.cumsum.default, + CallFunction( + torch.ops.aten.full.default, + KeywordArg("shape"), + KeywordArg("fill_value"), + dtype=KeywordArg("dtype"), + layout=Ignored(), + device=KeywordArg("device"), + pin_memory=False, + _users=MULTIPLE, + ), + KeywordArg("dim"), + _users=MULTIPLE, + ), + pass_dict=pass_patterns[1], +) +def pointless_cumsum_replacement(match: Match, shape, fill_value, device, dtype, dim): + """Based on a pattern in OPTForCausalLM""" + + if is_integer_dtype(dtype) or is_boolean_dtype(dtype): + # cumsum promotes all integral types to int64 + dtype = torch.int64 + + def repl(*shape): + dim_size = shape[dim] + idx = torch.arange(1, dim_size + 1, device=device, dtype=dtype) + + inter_shape = [1] * len(shape) + inter_shape[dim] = dim_size + return (idx * fill_value).view(inter_shape).expand(shape) + + # only replace the output node, not all nodes + match.nodes = [match.output_node()] + match.replace_by_example(repl, list(shape)) + + +_cat_1 = CallFunction(aten.cat, Arg(), 1, _users=2) + + +@register_lowering_pattern( + CallFunction( + aten.cat, + [ + _cat_1, + CallFunction( + aten.slice, + _cat_1, + 1, + 0, + KeywordArg("size"), + ), + ], + 1, + ) +) +def cat_slice_cat(match, cat_input, size, dim=1): + """ + This is an example of a more complex pattern where cat_1 is used + multiple times inside the pattern. We fold 2 calls to cat into one. + + Matches: + cat_1: f32[1024, 4077] = torch.ops.aten.cat.default([add_26, primals_217], 1) + slice_1: f32[1024, 4077] = torch.ops.aten.slice.Tensor(cat_1, 0, 0, 9223372036854775807) + slice_2: f32[1024, 19] = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 19) + cat_2: f32[1024, 4096] = torch.ops.aten.cat.default([cat_1, slice_2], 1) + + + Rewrite to: + slice_2 = torch.ops.aten.slice.Tensor(add_26, 1, 0, 19) + cat_2 = torch.ops.aten.cat.default([add_26, primals_217, slice2], 1) + """ + first, *rest = cat_input + # Optimization is optional, because we can just not fold the cat + # size should be within first.get_size()[dim] such that the optimization is valid. + # For negative `end`, we currently fallback to not optimizing. 
+ if size >= 0 and V.graph.sizevars.statically_known_leq(size, first.get_size()[dim]): + # fold 2 cats into 1 cat + return L[aten.cat]( + [ + first, + *rest, + L[aten.slice](first, dim, 0, size), + ], + dim, + ) + else: + # don't expect to hit this case, just fall back + tmp = L[aten.cat](cat_input, dim) + return L[aten.cat]( + [ + tmp, + L[aten.slice](tmp, dim, 0, size), + ], + dim, + ) + + +def is_valid_splitwithsizes_cat(match): + split_nodes = filter_nodes(match.nodes, aten.split_with_sizes) + cat_nodes = filter_nodes(match.nodes, aten.cat) + get_item_nodes = filter_nodes(match.nodes, operator.getitem) + if len(split_nodes) != 1 or len(cat_nodes) != 1: + return False + split_node, cat_node = split_nodes[0], cat_nodes[0] + # The dim of split and cat should match for passthrough + if get_arg_value(split_node, 2, "dim") != get_arg_value(cat_node, 1, "dim"): + return False + get_item_args = OrderedSet( + get_arg_value(get_item_node, 1) for get_item_node in get_item_nodes + ) + assert None not in get_item_args + split_sizes = get_arg_value(split_node, 1, "split_sizes") + # All parts of split should be included in the cat + if get_item_args != OrderedSet(range(len(split_sizes))): + return False + # The order of get_item_args should same with cat_node used. + # For example, if the split_node like split_with_sizes(input, [2, 2, 3], 1), + # the cat node should be like cat([get_item(0), get_item(1), get_item(2)], 1). + cat_items_args_order = [ + get_arg_value(item_node, 1) for item_node in get_arg_value(cat_node, 0) + ] + if cat_items_args_order != list(range(len(split_sizes))): + return False + + return True + + +def same_meta(node1: torch.fx.Node, node2: torch.fx.Node): + """True if two nodes have the same metadata""" + val1 = node1.meta.get("val") + val2 = node2.meta.get("val") + return ( + val1 is not None + and val2 is not None + and statically_known_true(sym_eq(val1.size(), val2.size())) + and val1.layout == val2.layout + and val1.dtype == val2.dtype + and val1.device == val2.device + and ( + val1.layout != torch.strided + or statically_known_true(sym_eq(val1.stride(), val2.stride())) + ) + ) + + +noop_registry: dict[Any, Any] = {} + + +def register_noop_decomp(targets, nop_arg=0): + def register_fun(cond): + register_decomposition(targets, registry=noop_registry, unsafe=True)( + (cond, nop_arg) # type: ignore[arg-type] + ) + return cond + + return register_fun + + +@register_noop_decomp(aten.slice) +def slice_noop(self, dim=0, start=None, end=None, step=1): + if start is None or end is None: + return False + + slice_dim_size = self.shape[dim] + if ( + statically_known_true(sym_eq(start, 0)) + and ( + statically_known_true(end >= 2**63 - 1) + or statically_known_true(end >= slice_dim_size) + ) + and statically_known_true(sym_eq(step, 1)) + ): + return True + return False + + +@register_noop_decomp(aten.slice_scatter, 1) +def slice_scatter_noop(self, src, dim=0, start=None, end=None, step=1): + if start is None: + start = 0 + if end is None: + end = 2**63 - 1 + slice_scatter_dim_size = self.shape[dim] + if ( + self.shape == src.shape + and start == 0 + and ( + statically_known_true(end >= 2**63 - 1) + or statically_known_true(end >= slice_scatter_dim_size) + ) + and step == 1 + ): + return True + return False + + +@register_noop_decomp(aten.repeat) +def repeat_noop(self, repeats): + return all(r == 1 for r in repeats) + + +@register_noop_decomp(aten.constant_pad_nd) +def constant_pad_nd(x, padding, fill_value=0): + return all(p == 0 for p in padding) + + 
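+# A small eager-mode illustration (not part of this module) of the kind of no-op the
+# decomps above detect: a full-range slice is just a view of its input.
+#     import torch
+#     x = torch.arange(6).reshape(2, 3)
+#     y = torch.ops.aten.slice.Tensor(x, 0, 0, 9223372036854775807, 1)
+#     assert y.shape == x.shape and y.data_ptr() == x.data_ptr()
+# remove_noop_ops below erases such nodes once the registered condition and the
+# same_meta check both hold.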
+@register_noop_decomp(torch.ops.prims.convert_element_type) +def convert_element_type_noop(x, dtype: torch.dtype): + return x.dtype == dtype + + +@register_noop_decomp(torch.ops.prims.device_put) +def device_put_noop(x, device, non_blocking=True): + return x.device == decode_device(device) + + +@register_noop_decomp([aten.ceil, aten.floor, aten.round, aten.trunc]) +def int_noop(x): + return is_integer_dtype(x.dtype) + + +@register_noop_decomp([aten.pow]) +def pow_noop(a, b): + return isinstance(b, int) and b == 1 + + +@register_noop_decomp([aten.cat], lambda args: args[0][0]) +def cat_noop(inputs, dim=0): + return len(inputs) == 1 + + +@register_noop_decomp(aten.view.default) +def view_default_noop(arg, size): + return statically_known_true(sym_eq(arg.shape, tuple(size))) + + +@register_noop_decomp(aten.view.dtype) +def view_dtype_noop(arg, dtype): + return arg.dtype == dtype + + +# Note, we also always have a check for identical metadata, which is why these +# are safe +@register_noop_decomp([aten.copy], nop_arg=1) +@register_noop_decomp([aten.alias, aten.clone]) +def true_noop(*args, **kwargs): + return True + + +def remove_noop_ops(graph: torch.fx.Graph): + """ + Removes both operations that are essentially aten.clone and operations that are essentially aten.alias from the graph. + """ + inputs = OrderedSet[torch.fx.Node]() + input_storages = OrderedSet[Union[int, None]]() + output_storages = OrderedSet[Union[int, None]]() + + for node in graph.find_nodes(op="placeholder"): + inputs.add(node) + input_storages.add(get_node_storage(node)) + + output_node = next(iter(reversed(graph.nodes))) + assert output_node.op == "output" + outputs = output_node.args[0] + if not isinstance(outputs, (list, tuple)): + # nested subgraphs can have singleton outputs + outputs = (outputs,) + for out in outputs: + if isinstance(out, torch.fx.Node): + output_storages.add(get_node_storage(out)) + + for node in graph.nodes: + if node.target in noop_registry: + cond, src_index = noop_registry[node.target] + if isinstance(src_index, int): + src = node.args[src_index] + else: + src = src_index(node.args) + if not isinstance(src, torch.fx.Node): + continue + # Don't introduce new aliasing between inputs and outputs. + # See fx_passes/README.md for a discussion of why this is + # necessary. + node_storage = get_node_storage(node) + src_storage = get_node_storage(src) + node_is_view = node_storage == src_storage + if ( + not node_is_view + and node_storage in output_storages + and (src_storage in input_storages or src_storage in output_storages) + ): + continue + + # Even if input and outputs are expected to alias, + # don't make "node is src" True + if ( + node_is_view + and node in output_node.args + and (src in inputs or src in output_node.args) + ): + continue + + is_valid, args, kwargs = get_fake_args_kwargs(node) + if not is_valid: + continue + if same_meta(node, src) and cond(*args, **kwargs): + node.replace_all_uses_with(src) + graph.erase_node(node) + + +def remove_assert_ops(graph: torch.fx.Graph): + """ + Removes aten._assert_tensor_metadata.default op because + 1) it will be lowered to a no-op in inductor + 2) it can block fusion, such as unfuse_bias_add_to_pointwise fusion. + + This op could come from aten.to functionalization in export. 
+ + For example, if we have a graph like below + + %addmm = aten.addmm.default(%linear_bias, %arg3_1, %permute) + %_assert_tensor_metadata = aten._assert_tensor_metadata.default(%addmm, None, None, torch.float16) + %convert_element_type_3 = prims.convert_element_type.default(%addmm, torch.float32) + %pow_1 = aten.pow.Tensor_Scalar(%convert_element_type_3, 2) + + We still want to fuse add from addmm with pow, instead of fusing add with mm, according to unfuse_bias_add_to_pointwise fusion. + + However, aten._assert_tensor_metadata.default is not a pointwise op, and would fail the should_prefer_unfused_addmm check. + + We remove this op so it doesn't block fusion decisions. It's safe because this op is lowered to a no-op with @register_lowering. + + """ + for node in graph.find_nodes( + op="call_function", target=torch.ops.aten._assert_tensor_metadata.default + ): + graph.erase_node(node) + + +def decompose_triton_kernel_wrapper_functional(graph): + """Decomposes triton_kernel_wrapper_functional nodes into clones and the underlying + mutation node. + + We assume that the reinplacing pass runs before this; the reinplacing pass + tells us (via rewriting the arguments or .meta to those nodes) which + Tensors we should clone and which Tensors are safe to reinplace. + """ + graph_pass = PatternMatcherPass() + + @register_graph_pattern( + CallFunctionVarArgs(torch.ops.higher_order.triton_kernel_wrapper_functional), + pass_dict=graph_pass, + ) + def _(match: Match, *args, **kwargs): + from torch._higher_order_ops.triton_kernel_wrap import ( + triton_kernel_wrapper_functional_dense, + ) + + flat_args, spec = pytree.tree_flatten((args, kwargs)) + + # NB: we combine (args, kwargs) into flat args for replacing. + # This is replace_by_example uses make_fx which does not support + # tracing a function with kwargs. + def decomp(*flat_args): + args, kwargs = pytree.tree_unflatten(flat_args, spec) + return (triton_kernel_wrapper_functional_dense(*args, **kwargs),) + + match.replace_by_example(decomp, flat_args, run_functional_passes=False) + + graph_pass.apply(graph) + + for node in graph.find_nodes( + op="call_function", + target=torch.ops.higher_order.triton_kernel_wrapper_functional, + ): + raise AssertionError("triton_kernel_wrapper_functional was not removed") + + +def decompose_auto_functionalized(graph): + """Decomposes auto_functionalized nodes into clones and the underlying + mutation node. + + We assume that the reinplacing pass runs before this; the reinplacing pass + tells us (via rewriting the arguments or .meta to those nodes) which + Tensors we should clone and which Tensors are safe to reinplace. + """ + graph_pass = PatternMatcherPass() + + @register_graph_pattern( + CallFunctionVarArgs(torch.ops.higher_order.auto_functionalized), + pass_dict=graph_pass, + ) + def _(match: Match, *args, **kwargs): + from torch._higher_order_ops.auto_functionalize import auto_functionalized_dense + + only_clone_these_tensors = tuple( + match.nodes[0].meta.get("only_clone_these_tensors", []) + ) + + flat_args, spec = pytree.tree_flatten((args, kwargs)) + + # NB: we combine (args, kwargs) into flat args for replacing. + # This is replace_by_example uses make_fx which does not support + # tracing a function with kwargs. 
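+        # Rough sketch of what the dense decomposition amounts to (`mylib.inplace_op_`
+        # is a hypothetical mutating custom op, and this is a simplification of
+        # auto_functionalized_dense, not its exact code):
+        #     x_arg = x.clone() if "x" in only_clone_these_tensors else x
+        #     torch.ops.mylib.inplace_op_.default(x=x_arg)
+        #     return *op_outputs, x_arg
+        # i.e. only the tensors the reinplacing pass flagged are cloned; everything
+        # else is mutated in place.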
+ def decomp(*flat_args): + args, kwargs = pytree.tree_unflatten(flat_args, spec) + assert len(args) == 1 + mode = args[0] + return auto_functionalized_dense(mode, only_clone_these_tensors, **kwargs) + + match.replace_by_example(decomp, flat_args, run_functional_passes=False) + + @register_graph_pattern( + CallFunctionVarArgs(torch.ops.higher_order.auto_functionalized_v2), + pass_dict=graph_pass, + ) + def _(match: Match, *args, **kwargs): + from torch._higher_order_ops.auto_functionalize import ( + auto_functionalized_v2_dense, + ) + + only_clone_these_bases = tuple( + match.nodes[0].meta.get("only_clone_these_tensors", []) + ) + + flat_args, spec = pytree.tree_flatten((args, kwargs)) + + def _maybe_resolve_constant_get_attr(node): + # Resolve getattr node to its value because they don't always have meta["val"] + if ( + isinstance(node, torch.fx.Node) + and node.op == "get_attr" + and "val" not in node.meta + ): + const_attr = getattr(graph.owning_module, node.target) # type: ignore[arg-type] + assert isinstance( + const_attr, (torch.fx.GraphModule, pytree.TreeSpec) + ), (type(const_attr), const_attr) + return const_attr + return node + + flat_args = [_maybe_resolve_constant_get_attr(arg) for arg in flat_args] + + # NB: we combine (args, kwargs) into flat args for replacing. + # This is replace_by_example uses make_fx which does not support + # tracing a function with kwargs. + def decomp(*flat_args): + args, kwargs = pytree.tree_unflatten(flat_args, spec) + assert len(args) == 1 + mutable_op = args[0] + return auto_functionalized_v2_dense( + mutable_op, only_clone_these_bases, **kwargs + ) + + match.replace_by_example(decomp, flat_args, run_functional_passes=False) + + graph_pass.apply(graph) + + # We need to remove the get_attr registered for _constant_schema and the + # auto_functioanlized's graph module (it's replaced with original ) when auto_functionalize a hop. 
+ _to_remove = [] + for node in graph.nodes: + if node.op == "get_attr" and len(node.users) == 0: + _to_remove.append(node) + if hasattr(graph.owning_module, node.target) and isinstance( + getattr(graph.owning_module, node.target), torch.fx.GraphModule + ): + delattr(graph.owning_module, node.target) + for node in _to_remove: + graph.erase_node(node) + + graph.lint() + + for _ in graph.find_nodes( + op="call_function", target=torch.ops.higher_order.auto_functionalized + ): + raise AssertionError("auto_functionalized was not removed") + + for _ in graph.find_nodes( + op="call_function", target=torch.ops.higher_order.auto_functionalized_v2 + ): + raise AssertionError("auto_functionalized_v2 was not removed") + + +@register_lowering_pattern( + CallFunction( + aten.cat, + ListOf( + CallFunction( + operator.getitem, + CallFunction( + aten.split_with_sizes, + KeywordArg("input_"), + Ignored(), + Ignored(), + _users=MULTIPLE, + ), + Ignored(), + ), + ), + Ignored(), + ), + pass_number=2, + extra_check=is_valid_splitwithsizes_cat, +) +def splitwithsizes_cat_replace(match, input_): + return input_ + + +def is_valid_cat_splitwithsizes(match): + cat_nodes = filter_nodes(match.nodes, aten.cat) + split_nodes = filter_nodes(match.nodes, aten.split_with_sizes) + if len(split_nodes) != 1 or len(cat_nodes) != 1: + return False + split_node, cat_node = split_nodes[0], cat_nodes[0] + + # the cat node has other users: can't eliminate + if len(cat_node.users) > 1: + return False + + # the dim of the cat and split should match + dim = get_arg_value(split_node, 2, "dim") + if dim != get_arg_value(cat_node, 1, "dim"): + return False + + cat_inputs = list(get_arg_value(cat_node, 0)) + split_sizes = get_arg_value(split_node, 1, "split_sizes") + # the number of input tensors in cat and the + # length of the split sizes should match + if len(cat_inputs) != len(split_sizes): + return False + + for cat_input, split_size in zip(cat_inputs, split_sizes): + # each cat input tensor's size along dim + # should match the corresponding split size + if "val" not in cat_input.meta: + return False + cat_input_size = cat_input.meta["val"].size(dim) + if cat_input_size != split_size: + return False + + return True + + +@register_lowering_pattern( + CallFunction( + aten.split_with_sizes, + CallFunction( + aten.cat, + KeywordArg("input_"), + Ignored(), + _users=MULTIPLE, + ), + Ignored(), + Ignored(), + ), + pass_number=2, + extra_check=is_valid_cat_splitwithsizes, +) +def cat_splitwithsizes_replace(match, input_): + return input_ + + +def view_to_reshape(gm): + """ + Replace view ops in the GraphModule to reshape ops. 
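+
+    A reminder of why this is value-preserving (illustrative, general PyTorch
+    semantics): reshape behaves like view whenever a view is possible and copies
+    otherwise, e.g.
+
+        t = torch.arange(6).reshape(2, 3).t()   # non-contiguous
+        t.reshape(6)   # ok, materializes a copy
+        t.view(6)      # RuntimeError (a view with these strides is impossible)
+
+    so swapping the target only relaxes the striding requirement.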
+ """ + subgraph_names: OrderedSet[str] = OrderedSet( + x.target for x in gm.graph.find_nodes(op="get_attr") + ) + + for child_name, child_mod in gm.named_children(): + if child_name in subgraph_names and isinstance(child_mod, torch.fx.GraphModule): + view_to_reshape(child_mod) + + for nd in gm.graph.find_nodes( + op="call_function", target=torch.ops.aten.view.default + ): + nd.target = torch.ops.aten.reshape.default + + +def should_prefer_unfused_addmm(match): + inp = match.kwargs["inp"] + if not is_gpu(inp.meta["val"].device.type): + return False + + output = match.output_node() + return all(is_pointwise_use(use) for use in output.users) + + +@register_graph_pattern( + CallFunction(aten.addmm, KeywordArg("inp"), Arg(), Arg()), + pass_dict=pass_patterns[2], + extra_check=should_prefer_unfused_addmm, +) +def unfuse_bias_add_to_pointwise(match: Match, mat1, mat2, *, inp): + def repl(inp, x1, x2): + return x1 @ x2 + inp + + match.replace_by_example(repl, [inp, mat1, mat2]) + + +def is_valid_addmm_fusion(match): + mat1, mat2 = match.args + inp = match.kwargs["inp"] + + if not ( + isinstance(inp, torch.fx.Node) and isinstance(inp.meta["val"], torch.Tensor) + ): + return False # Input is a number + + in_shape = inp.meta["val"].shape + mm_shape = mat1.meta["val"].shape[0], mat2.meta["val"].shape[1] + matched = is_expandable_to(in_shape, mm_shape) + if not matched: + return False # Shape mismatch + + return not should_prefer_unfused_addmm(match) + + +@register_graph_pattern( + CallFunction( + aten.add, + CallFunction(aten.mm, Arg(), Arg()), + KeywordArg("inp"), + ), + pass_dict=pass_patterns[2], + extra_check=is_valid_addmm_fusion, +) +@register_graph_pattern( + CallFunction( + aten.add, + KeywordArg("inp"), + CallFunction(aten.mm, Arg(), Arg()), + ), + pass_dict=pass_patterns[2], + extra_check=is_valid_addmm_fusion, +) +def addmm(match, mat1, mat2, *, inp): + def repl(inp, mat1, mat2): + return aten.addmm(inp, mat1, mat2) + + match.replace_by_example(repl, [inp, mat1, mat2]) + + +def register_partial_reduction_pattern(): + "Reuse partial reductions in complete reductions" + + # post grad equivalents + equiv_red = { + aten.amax.default: aten.max.default, + aten.amin.default: aten.min.default, + } + + # TODO: to support other reductions like sum, would need to skip + # lower precision reductions since partial output would need to be kept at fp32. 
+ for red_op in (aten.amax.default, aten.amin.default): + inp = KeywordArg("input") + partial_reduc = CallFunction( + red_op, inp, KeywordArg("reduced_dims"), KeywordArg("keepdim") + ) + full_reduc = CallFunction([red_op, equiv_red[red_op]], inp) + + @register_graph_pattern( + MultiOutputPattern([partial_reduc, full_reduc]), pass_dict=pass_patterns[2] + ) + def reuse_partial(match, input, reduced_dims, keepdim): + partial_red, full_red = match.output_nodes() + + # if they're small, reuse not worth it + if not statically_known_true(input.meta["val"].numel() >= 4096): + return True + + def replacement(inp: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + partial = partial_red.target(inp, reduced_dims, keepdim) + complete = full_red.target(partial) + return (partial, complete) + + counters["inductor"]["partial_reduction_reuse"] += 1 + match.replace_by_example(replacement, [input]) + + +register_partial_reduction_pattern() + + +def check_shape_cuda_and_fused_int_mm_mul_enabled(match): + return ( + config.force_fuse_int_mm_with_mul + and len(getattr(match.args[2].meta.get("val"), "shape", [])) == 2 + and getattr(match.args[2].meta.get("val"), "is_cuda", False) + ) + + +def is_index_put_and_requires_h2d_sync_for_gpu_value(node): + from torch.fx.operator_schemas import normalize_function + + if node.target not in [ + torch.ops.aten.index_put.default, + torch.ops.aten.index_put_.default, + ]: + return False + # Inductor falls back to aten.index_put_. + # index_put_ will will call nonzero() and perform a H2D sync if + # any of its indices are bool/byte tensors + # However, it will short-circuit this H2D sync and run mask_fill_ + # if the value we are putting is a cpu scalar. + # Therefore, when inductor sees an index_put_ with byte tensor indices, + # it should *not* convert the cpu scalar value into a gpu tensor. + args_, _kwargs = normalize_function(node.target, node.args, node.kwargs) # type: ignore[misc] + any_byte_bool_indices = False + indices = args_[1] + for i in indices: + if i is not None and i.meta["val"].dtype in [torch.bool, torch.int8]: + any_byte_bool_indices = True + + val = args_[2].meta["val"] + val_is_cpu_scalar = val.device.type == "cpu" and val.numel() == 1 + # If both these conditions hold, then converting the val + # to a gpu tensor will incur a H2D sync when inductor calls aten.index_put_ + return any_byte_bool_indices and val_is_cpu_scalar + + +class ConstructorMoverPass: + def __init__( + self, target: str, allow_outputs: bool = False, allow_inputs: bool = False + ) -> None: + """ + Move constructors from cpu to the target_device. + + Sweeps through the module, looking for constructor nodes that can be moved + to the target_device. + + A constructor node can be moved to the target_device iff all of its users + can also be moved (tested by cannot_be_moved). Otherwise, all dependent + constructor nodes won't be moved. + + - target: target device type + - allow_outputs: allow outputs to be moved + - allow_inputs: allow inputs to be moved + """ + + self.target = target + self.allow_inputs = allow_inputs + self.allow_outputs = allow_outputs + + assert isinstance(target, str), ( + "target should be a string representing the device type. " + f"Got: {type(target).__name__}" + ) + + def allow_cpu_device(self, node: fx.Node) -> bool: + """ + Returns whether a node that returns a tensor on the target device may have + cpu tensors as input. 
+ """ + return node.target in ( + torch.ops.aten.index.Tensor, + torch.ops.aten.index_put.default, + torch.ops.aten.index_put_.default, + torch.ops.aten.copy.default, + torch.ops.aten.copy_.default, + torch.ops.aten.slice_scatter.default, + ) + + def is_on_target_device(self, node: fx.Node) -> bool: + """ + Returns whether a node is on the target device. + """ + node_device = self.get_node_device(node) + return node_device is not None and node_device.type == self.target + + def is_cpu_scalar_tensor(self, node: fx.Node) -> bool: + """ + Returns whether a node is a cpu scalar tensor. + """ + device = self.get_node_device(node) + is_cpu = device is not None and device.type == "cpu" + ten = node.meta.get("val") + is_scalar = isinstance(ten, torch.Tensor) and len(ten.size()) == 0 + return is_cpu and is_scalar + + def all_inputs_are_cpu_scalar_or_on_target_device(self, node: fx.Node) -> bool: + """ + Returns whether a node's inputs are either cpu scalar tensors or + on the target device. + """ + inputs = ( + inp + for inp in itertools.chain(node.args, node.kwargs.values()) + if isinstance(inp, fx.Node) + ) + return all( + self.is_cpu_scalar_tensor(inp) or self.is_on_target_device(inp) + for inp in inputs + ) + + def cannot_be_moved(self, node: fx.Node) -> bool: + """ + Returns whether a node can be moved to the target device. + + If this function returns False, it means that this node and all of its users + won't be moved into the target device. + """ + if node.target == "output": + return not self.allow_outputs + + if not ( + isinstance(node.target, torch._ops.OpOverload) + and node.target.namespace in ("prims", "aten") + ): + return True + + if is_index_put_and_requires_h2d_sync_for_gpu_value(node): + return True + + return False + + def get_node_device(self, node: fx.Node) -> Optional[torch.device]: + """ + Get the device of a node. 
+ """ + ten = node.meta.get("val") + return None if not isinstance(ten, torch.Tensor) else ten.device + + def get_cpu_indeg_count(self, graph: fx.Graph) -> dict[fx.Node, int]: + """ + Get the number of cpu inputs to a node + """ + cpu_indeg: dict[fx.Node, int] = Counter() + + for node in graph.nodes: + cpu_count = 0 + + def add_cpu_inp(node): + nonlocal cpu_count + device = self.get_node_device(node) + cpu_count += device is not None and device.type == "cpu" + + pytree.tree_map_only(fx.Node, add_cpu_inp, (node.args, node.kwargs)) + + if cpu_count: + cpu_indeg[node] = cpu_count + + return cpu_indeg + + def __call__(self, graph: fx.Graph) -> None: + target_devices = OrderedSet[torch.device]() + constructors = [] + cpu_placeholders: OrderedSet[fx.Node] = OrderedSet() + + for node in graph.nodes: + device = self.get_node_device(node) + if device and device.type == self.target: + target_devices.add(device) + + if ( + self.allow_inputs + and node.op == "placeholder" + and self.is_cpu_scalar_tensor(node) + ): + cpu_placeholders.add(node) + constructors.append(node) + continue + + if not ( + isinstance(node.target, torch._ops.OpOverload) + and node.target.namespace in ("prims", "aten") + ): + continue + + if not torch._subclasses.fake_tensor._is_tensor_constructor(node.target): + continue + + if not node.kwargs.get("device") == torch.device("cpu"): + continue + + constructors.append(node) + + # not handling multiple target devices initially + if not constructors or len(target_devices) != 1: + return + + movable_constructors = self.find_movable_constructors(graph, constructors) + + target_device = next(iter(target_devices)) + for node in movable_constructors: + if node in cpu_placeholders: + with graph.inserting_after(node): + gpu_node = graph.call_function( + torch.ops.prims.device_put.default, (node, target_device) + ) + node.replace_all_uses_with( + gpu_node, + lambda x: x != gpu_node + and x.target != torch.ops.aten.copy_.default, + ) + + # noop elimination if there are other device_put for gpu_node to + # target device. Alternatively, we could just move the other device_put + # earlier in the graph, but that is not supported in fx graph yet. + noop_device_puts = [ + user + for user in gpu_node.users + if user.target == torch.ops.prims.device_put.default + and user.args[1] == target_device + ] + for noop in noop_device_puts: + noop.replace_all_uses_with(gpu_node) + graph.erase_node(noop) + else: + kwargs = node.kwargs.copy() + kwargs["device"] = target_device + node.kwargs = kwargs + + def find_movable_constructors( + self, graph: fx.Graph, constructors: list[fx.Node] + ) -> OrderedSet[fx.Node]: + """ + Starting from the cpu constructors, iterate through the graph and test that all of their + downstream uses can safely be moved to cpu. + """ + cpu_indeg: dict[fx.Node, int] = self.get_cpu_indeg_count(graph) + + # which constructors cannot be moved to gpu + cannot_move_to_gpu = OrderedSet[fx.Node]() + + # For any node in the graph, which constructors does it have a dependency on + constructor_dependencies: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict( + OrderedSet + ) + + # if a cpu node has a dependency on two different cpu constructors, + # then if either constructor cannot be moved to gpu, the other cannot as well. 
+ # In this case any node with a dependency on one will have a dependency on the other + equal_constructor_sets: dict[fx.Node, OrderedSet[fx.Node]] = { + c: OrderedSet([c]) for c in constructors + } + + def make_dependencies_equivalent( + set1: OrderedSet[fx.Node], set2: OrderedSet[fx.Node] + ) -> OrderedSet[fx.Node]: + # could use union find but not worth complexity here + set1.update(set2) + for obj in set1: + equal_constructor_sets[obj] = set1 + return set1 + + queue: list[fx.Node] = list(constructors) + + for c in queue: + constructor_dependencies[c].add(c) + + while queue: + node = queue.pop() + dependencies = constructor_dependencies[node] + + for user in node.users: + if self.cannot_be_moved(user): + cannot_move_to_gpu.update(dependencies) + break + + # this node was used on a op which takes in multiple devices and output a gpu + # tensor. we can convert its cpu input to gpu without making further changes + if self.allow_cpu_device(user) and self.is_on_target_device(user): + del cpu_indeg[user] + elif ( + self.allow_inputs + and self.all_inputs_are_cpu_scalar_or_on_target_device(user) + ): + # this node takes only cpu scalar tensors or gpu tensors as inputs + # and outputs a gpu tensor. we can convert its cpu scalar inputs to gpu + # without making further changes + del cpu_indeg[user] + else: + # otherwise, we should continue look at its downstream uses + cpu_indeg[user] -= 1 + if cpu_indeg[user] == 0: + del cpu_indeg[user] + queue.append(user) + + unioned_set = make_dependencies_equivalent( + dependencies, constructor_dependencies[user] + ) + constructor_dependencies[user] = unioned_set + + for node in cpu_indeg: + if constructor_dependencies[node]: + cannot_move_to_gpu.update(constructor_dependencies[node]) + + all_cannot_move_to_gpu = cannot_move_to_gpu.copy() + for constructor in cannot_move_to_gpu: + all_cannot_move_to_gpu.update(equal_constructor_sets[constructor]) + + return OrderedSet(constructors) - all_cannot_move_to_gpu + + +def move_constructors_to_gpu(graph: fx.Graph) -> None: + """ + Moves intermediary tensors which are constructed on the cpu to gpu when safe + """ + + # cudagraph does not support cpu tensors. In this pass, we update the graph + # by explicitly moving cpu scalar tensors to gpu when profitable, relying on + # graph partition to split off this data copy, and cudagraphifying + # the remaining gpu ops. 
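+    # Sketch of the effect on a movable constructor (assuming a CUDA target device):
+    # its `device` kwarg is simply rewritten, e.g.
+    #     idx = torch.arange(4, device="cpu")     # before
+    #     idx = torch.arange(4, device="cuda")    # after
+    # whereas cpu scalar placeholders (graph inputs) are instead routed through
+    # prims.device_put inside ConstructorMoverPass.__call__ above.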
+ allow_inputs_outputs = ( + True + if ( + torch._inductor.config.triton.cudagraphs + and torch._inductor.config.graph_partition + ) + else False + ) + ConstructorMoverPass( + get_gpu_type(), + allow_inputs=allow_inputs_outputs, + allow_outputs=allow_inputs_outputs, + )(graph) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/pre_grad.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/pre_grad.py new file mode 100644 index 0000000000000000000000000000000000000000..2d1709962e64bb7f68bf265320ee1a6c35b8bc10 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/pre_grad.py @@ -0,0 +1,862 @@ +# mypy: allow-untyped-defs +import copy +import itertools +import logging +import types +from collections.abc import Sequence +from typing import Optional + +import torch +import torch.nn as nn +from torch._dynamo.utils import counters, detect_fake_mode +from torch._logging import trace_structured +from torch.fx.experimental.optimization import ( + matches_module_pattern, + replace_node_module, +) +from torch.fx.passes.graph_transform_observer import GraphTransformObserver +from torch.fx.passes.shape_prop import ShapeProp +from torch.nn import functional as F +from torch.nn.utils.fusion import fuse_conv_bn_eval, fuse_conv_bn_weights + +from .. import config +from ..fx_utils import matches_module_function_pattern +from ..pattern_matcher import ( + init_once_fakemode, + PatternMatcherPass, + stable_topological_sort, +) +from ..utils import is_cpu_device, pass_execution_and_save +from .group_batch_fusion import group_batch_fusion_passes, PRE_GRAD_FUSIONS +from .misc_patterns import numpy_compat_normalization +from .split_cat import PRE_GRAD_PATTERNS + + +log = logging.getLogger(__name__) + +efficient_conv_bn_eval_pass = PatternMatcherPass( + pass_name="efficient_conv_bn_eval_pass" +) + +fuse_split_linear_add_pass = PatternMatcherPass( + pass_name="fuse_split_linear_add_pass", +) +fuse_chunk_squeeze_cat_pass = PatternMatcherPass( + pass_name="fuse_chunk_squeeze_cat_pass", +) +remove_reshape_pass = PatternMatcherPass( + pass_name="remove_reshape_pass", +) + +# based on predispatch aten IR +normalization_pass_aten = PatternMatcherPass(pass_name="normalization_pass_aten") +merge_splits_pass_aten = PatternMatcherPass(pass_name="merge_splits_pass_aten") +split_cat_pass_aten = PatternMatcherPass(pass_name="split_cat_pass_aten") +unbind_stack_pass_aten = PatternMatcherPass(pass_name="unbind_stack_pass_aten") +merge_getitem_cat_pass_aten = PatternMatcherPass( + pass_name="merge_getitem_cat_pass_aten" +) +merge_stack_tahn_unbind_pass_aten = PatternMatcherPass( + pass_name="merge_stack_tahn_unbind_pass_aten" +) +mutate_cat_pass_aten = PatternMatcherPass(pass_name="mutate_cat_pass_aten") +remove_split_with_size_one_pass_aten = PatternMatcherPass( + pass_name="remove_split_with_size_one_pass_aten" +) + + +def save_inductor_dict(pass_to_compare=None): + if not pass_to_compare: + pass_to_compare = list(config.pre_grad_fusion_options.keys()) + list( + config.post_grad_fusion_options.keys() + ) + return {p: dict(counters["inductor"]).get(p, 0) for p in pass_to_compare} + + +def is_same_dict(inductor_dict, optimus_dict): + for pass_name, count in optimus_dict.items(): + if count != dict(inductor_dict).get(pass_name, 0): + return False + return True + + +def shape_prop(mod) -> None: + return None + + +def normalize_node_kwargs_pass(graph): + return None + + +def fuse_parallel_linear_pass(graph): + return None + + +def remove_split_ops(graph, shape_prop): + return 
None + + +def remove_split_ops_pass(graph): + remove_split_ops(graph.owning_module, shape_prop) + + +def fuse_chunk_reshape_unsqueeze_concat_pass(graph): + return None + + +def fuse_chunk_reshape_concat_pass(graph): + return None + + +def remove_noop_pass(graph): + return None + + +def stack_to_unsqueeze_pass(graph): + return None + + +def merge_concats_pass(graph): + return None + + +def relu_nan_to_num(graph): + return None + + +def fuse_split_getitem_squeeze_cat(graph): + return None + + +def use_triton_dot_compress(graph): + return None + + +def use_triton_lce_replace_simple_LCE_helper(gm, shape_prop): + return None + + +def use_triton_lce_replace_simple_LCE(graph): + return use_triton_lce_replace_simple_LCE_helper(graph.owning_module, shape_prop) + + +def use_triton_lce_replace_normal_LCE_helper(gm, shape_prop): + return None + + +def use_triton_lce_replace_normal_LCE(graph): + return use_triton_lce_replace_simple_LCE_helper(graph.owning_module, shape_prop) + + +def use_matmul_lce_replace_normal_LCE(graph): + return None + + +def use_matmul_fuse_lce_replace_first_LCE(graph): + return None + + +@init_once_fakemode +def lazy_init(): + from . import efficient_conv_bn_eval, split_cat # noqa: F401 + + if config.is_fbcode(): + from . import fb # type: ignore[attr-defined] # noqa: F401 + + +def _get_pass_name_func(p): + if isinstance(p, PatternMatcherPass): + pass_name = p.pass_name + pass_func = p.apply + elif isinstance(p, types.FunctionType): + pass_name = p.__name__.lstrip("_") + pass_func = p + else: + pass_name = None + pass_func = None + + return pass_name, pass_func + + +def _run_pre_dispatch_passes( + gm: torch.fx.GraphModule, + example_inputs: Sequence[object] = (), + add_passes: Optional[str] = None, + remove_passes: Optional[str] = None, +) -> None: + # order matters + default_pass_list = [ + # normalize passes, must be called as the first passes + normalization_pass_aten, + normalize_node_kwargs_pass, + remove_noop_pass, + relu_nan_to_num, + fuse_chunk_reshape_concat_pass, + group_batch_fusion_passes, + normalize_node_kwargs_pass, + fuse_chunk_squeeze_cat_pass, + merge_concats_pass, + fuse_split_linear_add_pass, + remove_reshape_pass, + fuse_parallel_linear_pass, + remove_split_ops_pass, + stack_to_unsqueeze_pass, # run before fuse_chunk_reshape_unsqueeze_concat_pass + fuse_chunk_reshape_unsqueeze_concat_pass, + ] + + full_pass_list = default_pass_list + [ + fuse_split_getitem_squeeze_cat, + use_triton_dot_compress, + use_triton_lce_replace_simple_LCE, + use_triton_lce_replace_normal_LCE, + use_matmul_fuse_lce_replace_first_LCE, + use_matmul_lce_replace_normal_LCE, + ] + + log.info( + f"pre_grad_passes: add_passes: {add_passes}, remove_pass: {remove_passes}" # noqa: G004 + ) + add_passes_list = [] + remove_passes_list = [] + if add_passes: + add_passes_list = add_passes.split(",") + if remove_passes: + remove_passes_list = remove_passes.split(",") + + shape_prop = lambda mod: ShapeProp( # noqa: E731 + gm=mod, + # pyre-fixme[16]: Module `torch._dynamo.utils` has no attribute `detect_fake_mode` + fake_mode=detect_fake_mode(example_inputs), + ).propagate(*tuple(example_inputs)) + + for p in default_pass_list: + pass_name, pass_func = _get_pass_name_func(p) + # should not happen + if pass_name is None or pass_func is None: + continue + if pass_name in remove_passes_list: + continue + pass_execution_and_save( + pass_func, + gm, + example_inputs, + f"[Pre grad(predispatch IR)] Apply {pass_name} pass", + ) + + for p in full_pass_list: + pass_name, pass_func = _get_pass_name_func(p) + 
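+        # Illustrative call (pass names taken from the lists above, inputs hypothetical):
+        #     _run_pre_dispatch_passes(
+        #         gm, example_inputs,
+        #         add_passes="use_triton_dot_compress,use_triton_lce_replace_simple_LCE",
+        #         remove_passes="remove_reshape_pass",
+        #     )
+        # names in add_passes opt into entries of full_pass_list; names in
+        # remove_passes skip entries of default_pass_list.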
if pass_name is None or pass_func is None: + continue + if pass_name in add_passes_list: + pass_execution_and_save( + pass_func, + gm, + example_inputs, + f"[Pre grad(predispatch IR)] Apply {pass_name} pass", + ) + + # Remove noops at the end, which may be generated other passes. + pass_execution_and_save( + remove_noop_pass, + gm, + example_inputs, + "[Pre grad(predispatch IR)]Apply remove_noop pass", + ) + shape_prop(gm) + + +def pre_grad_passes( + gm: torch.fx.GraphModule, + example_inputs: Sequence[object] = (), + add_passes: Optional[str] = None, + remove_passes: Optional[str] = None, +) -> torch.fx.GraphModule: + """ + Apply passes on the input FX graph using Torch IR. + + WARNING: + The IR before grad is not functional or normalized, so it is harder + to write passes on this IR. Passes must be safe with respect to + aliasing and mutation and need to handle all possible arg schemas. + + Consider adding a new pass to post_grad.py or joint_graph.py which + are after functionalization and normalization. + """ + if config.pattern_matcher: + lazy_init() + if hasattr( + config, "fx_passes_numeric_check" + ) and config.fx_passes_numeric_check.get("pre_grad", False): + gm_before_fx_passes = gm.__copy__() + # explicitly run with predispatch atenIR based passes + if config.is_predispatch: + _run_pre_dispatch_passes(gm, example_inputs, add_passes, remove_passes) + else: + # We only log the graph with changes to avoid the excessive compilation time + # https://fb.workplace.com/groups/257735836456307/permalink/633533465543207/ + if example_inputs is not None: + gm = fuse_fx(gm, example_inputs) + numpy_compat_normalization(gm.graph) + # We should always do the normalization_pass first + if "normalization_pass" in config.pre_grad_fusion_options: + pattern_matcher_pass = PRE_GRAD_PATTERNS["normalization_pass"] + pattern_matcher_pass.apply(gm.graph) # type: ignore[arg-type] + group_batch_fusion_passes(gm.graph, pre_grad=True) + for pass_name in config.pre_grad_fusion_options: + # skip all patterns for group batch fusions + if pass_name in PRE_GRAD_FUSIONS or pass_name == "normalization_pass": + continue + pattern_matcher_pass = PRE_GRAD_PATTERNS[pass_name] + inductor_before_change = save_inductor_dict( + [pattern_matcher_pass.pass_name] + ) + # we support run same pattern multiple times, the default is to run only once + counter = config.pre_grad_fusion_options[pass_name].get("counter", 1) + for _ in range(counter): + pattern_matcher_pass.apply(gm.graph) # type: ignore[arg-type] + if not is_same_dict(counters["inductor"], inductor_before_change): + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": f"{pattern_matcher_pass.pass_name}_pre_grad", + "encoding": "string", + }, + payload_fn=lambda: gm.print_readable( + print_output=False, include_stride=True, include_device=True + ), + ) + # TODO: move efficient_conv_bn_eval_pass to the fusions dict too. 
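+        # Illustrative config driving the counter loop above (the second key is
+        # hypothetical; real keys come from PRE_GRAD_PATTERNS):
+        #     torch._inductor.config.pre_grad_fusion_options = {
+        #         "normalization_pass": {},             # applied once (default counter=1)
+        #         "some_fusion_pass": {"counter": 2},   # hypothetical pass, applied twice
+        #     }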
+ efficient_conv_bn_eval_pass.apply(gm.graph) # type: ignore[arg-type] + + if config.pre_grad_custom_pass is not None: + with GraphTransformObserver(gm, "pre_grad_custom_pass"): + config.pre_grad_custom_pass(gm.graph) + stable_topological_sort(gm.graph) + + from .quantization import quant_lift_up + + quant_lift_up(gm) + + gm.graph.lint() + gm.recompile() + + if ( + config.pattern_matcher + and hasattr(config, "fx_passes_numeric_check") + and config.fx_passes_numeric_check.get("pre_grad", False) + and example_inputs is not None + ): + from .numeric_utils import numeric_check_if_enabled + + gm_after_fx_passes = gm.__copy__() + numeric_check_if_enabled( + gm_before_fx_passes, # type: ignore[possibly-undefined] + gm_after_fx_passes, + example_inputs, + config.fx_passes_numeric_check.get("num_iterations", 1), + config.fx_passes_numeric_check.get("precision", 1e-4), + ) + + return gm + + +def fuse_fx(gm: torch.fx.GraphModule, example_inputs) -> torch.fx.GraphModule: + is_cpu = is_cpu_device(example_inputs) + # pyre-fixme[16]: Module `torch._dynamo.utils` has no attribute `detect_fake_mode` + fake_mode = detect_fake_mode(example_inputs) + + gm = sink_cat_after_pointwise(gm) + if config.permute_fusion and not is_cpu: + # For linear permute fusion, we need to check input info to identify + # and perform proper permutation/transpose + ShapeProp(gm, fake_mode=fake_mode).propagate(*example_inputs) + with GraphTransformObserver(gm, "linear_permute_fusion"): + gm = linear_permute_fusion(gm) + with GraphTransformObserver(gm, "permute_linear_fusion"): + gm = permute_linear_fusion(gm) + with GraphTransformObserver(gm, "permute_matmul_fusion"): + gm = permute_matmul_fusion(gm) + + # make sure the autograd is disabled. + if torch.is_grad_enabled() or not is_cpu: + return gm + if config.freezing: + with GraphTransformObserver(gm, "remove_identity"): + gm = remove_identity(gm) + with GraphTransformObserver(gm, "fuse_conv_bn"): + gm = fuse_conv_bn(gm) + return gm + + +def fetch_attr(target: str, mod): + target_atoms = target.split(".") + attr_itr = mod + for i, atom in enumerate(target_atoms): + if not hasattr(attr_itr, atom): + raise RuntimeError( + f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}" + ) + attr_itr = getattr(attr_itr, atom) + return attr_itr + + +def remove_identity(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + """ + Removes all identity layers from the module. + """ + + class IdentityRemover(torch.fx.Transformer): + def call_module(self, target, args, kwargs): + if isinstance(self.submodules[target], nn.Identity): + assert len(args) == 1 + return args[0] + else: + return super().call_module(target, args, kwargs) + + return IdentityRemover(gm).transform() + + +def fuse_conv_bn(gm: torch.fx.GraphModule, inplace=False) -> torch.fx.GraphModule: + """ + Fuses Convolution/BN layers for inference purposes. 
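+
+    The folding relies on the standard eval-mode identity (sketch): with BN running
+    stats (mu, var), affine params (gamma, beta) and eps,
+
+        scale   = gamma / sqrt(var + eps)
+        W_fused = W_conv * scale          # broadcast over the output-channel dim
+        b_fused = (b_conv - mu) * scale + beta
+
+    so conv(x, W_fused, b_fused) == batch_norm(conv(x, W_conv, b_conv)) in eval mode.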
+ """ + modules_patterns = [ + (torch.nn.Conv1d, torch.nn.BatchNorm1d), + (torch.nn.Conv2d, torch.nn.BatchNorm2d), + (torch.nn.Conv3d, torch.nn.BatchNorm3d), + ] + module_function_patterns = [ + (torch.nn.Conv1d, F.batch_norm), + (torch.nn.Conv2d, F.batch_norm), + (torch.nn.Conv3d, F.batch_norm), + ] + modules = dict(gm.named_modules()) + + class ConvBNFusion: + def __init__( + self, + bn_node, + conv_module, + bn_module=None, # For BN Module + bn_running_mean=None, # For Functional BN + bn_running_var=None, + bn_eps=None, + bn_weight=None, + bn_bias=None, + ) -> None: + self.bn_nodes = [ + bn_node, + ] + self.conv_module = conv_module + self.bn_module = bn_module + self.bn_running_mean = bn_running_mean + self.bn_running_var = bn_running_var + self.bn_eps = bn_eps + self.bn_weight = bn_weight + self.bn_bias = bn_bias + self.fusion_enabled = True + + def add_bn_node(self, bn_node): + self.bn_nodes.append(bn_node) + + def disable_fusion(self): + self.fusion_enabled = False + + def is_fusion_enabled(self): + return self.fusion_enabled + + conv_bn_to_fuse: dict[int, ConvBNFusion] = {} + for pattern in modules_patterns: + conv_bn_to_fuse.clear() + for node in gm.graph.nodes: + if matches_module_pattern(pattern, node, modules): + if len(node.args[0].users) > 1: # Output of conv is used by other nodes + continue + conv = modules[node.args[0].target] + bn = modules[node.target] + eval_mode = all(not n.training for n in [conv, bn]) + if not eval_mode: + continue + if not bn.track_running_stats: + continue + + # Do hash based on the module name of conv + hash_id = hash(node.args[0].target) + if hash_id not in conv_bn_to_fuse: + conv_bn_to_fuse[hash_id] = ConvBNFusion(node, conv, bn) + else: + if bn == conv_bn_to_fuse[hash_id].bn_module: + # Do fusion if same bn module + conv_bn_to_fuse[hash_id].add_bn_node(node) + else: + # Disable the conv bn folding if conv shared by different bn + conv_bn_to_fuse[hash_id].disable_fusion() + + for conv_bn_fusion in conv_bn_to_fuse.values(): + if conv_bn_fusion.is_fusion_enabled(): + bn_nodes = conv_bn_fusion.bn_nodes + conv = conv_bn_fusion.conv_module + bn = conv_bn_fusion.bn_module + + fused_conv = fuse_conv_bn_eval(conv, bn) + for bn_node in bn_nodes: + replace_node_module(bn_node.args[0], modules, fused_conv) + bn_node.replace_all_uses_with(bn_node.args[0]) + gm.graph.erase_node(bn_node) + + gm.graph.lint() + for pattern in module_function_patterns: + conv_bn_to_fuse.clear() + for node in gm.graph.nodes: + if matches_module_function_pattern(pattern, node, modules): + # TODO: support kwargs. 
+ if len(node.args) != 8: + continue + conv = modules[node.args[0].target] + bn_training = node.args[5] + bn_eps = node.args[7] + if conv.training or bn_training: + continue + if type(bn_eps) is not float: + continue + + def _used_by_same_conv_module(users): + conv_module_name = users[0].args[0].target + return all( + conv_module_name == user.args[0].target for user in users + ) + + bn_args_is_constant = all( + n.op == "get_attr" + and (len(n.users) == 1 or _used_by_same_conv_module(list(n.users))) + for n in node.args[1:5] + ) + if not bn_args_is_constant: + continue + bn_running_mean = fetch_attr(node.args[1].target, gm) + bn_running_var = fetch_attr(node.args[2].target, gm) + bn_weight = fetch_attr(node.args[3].target, gm) + bn_bias = fetch_attr(node.args[4].target, gm) + if bn_running_mean is None or bn_running_var is None: + continue + + # Do hash based on the module name of conv + hash_id = hash(node.args[0].target) + if hash_id not in conv_bn_to_fuse: + conv_bn_to_fuse[hash_id] = ConvBNFusion( + node, + conv, + bn_running_mean=bn_running_mean, + bn_running_var=bn_running_var, + bn_eps=bn_eps, + bn_weight=bn_weight, + bn_bias=bn_bias, + ) + else: + if ( + hash(bn_running_mean) + == hash(conv_bn_to_fuse[hash_id].bn_running_mean) + and hash(bn_running_var) + == hash(conv_bn_to_fuse[hash_id].bn_running_var) + and torch.allclose( + torch.tensor(bn_eps), + torch.tensor(conv_bn_to_fuse[hash_id].bn_eps), + ) + and hash(bn_weight) == hash(conv_bn_to_fuse[hash_id].bn_weight) + and hash(bn_bias) == hash(conv_bn_to_fuse[hash_id].bn_bias) + ): + # Do fusion if same functional bn + conv_bn_to_fuse[hash_id].add_bn_node(node) + else: + # Disable the conv bn folding if conv shared by different bn + conv_bn_to_fuse[hash_id].disable_fusion() + + for conv_bn_fusion in conv_bn_to_fuse.values(): + if conv_bn_fusion.is_fusion_enabled(): + bn_nodes = conv_bn_fusion.bn_nodes + conv = conv_bn_fusion.conv_module + bn_running_mean = conv_bn_fusion.bn_running_mean + bn_running_var = conv_bn_fusion.bn_running_var + bn_eps = conv_bn_fusion.bn_eps + bn_weight = conv_bn_fusion.bn_weight + bn_bias = conv_bn_fusion.bn_bias + + fused_conv = copy.deepcopy(conv) + fused_conv.weight, fused_conv.bias = fuse_conv_bn_weights( + fused_conv.weight, + fused_conv.bias, + bn_running_mean, + bn_running_var, + bn_eps, + bn_weight, + bn_bias, + ) + for bn_node in bn_nodes: + replace_node_module(bn_node.args[0], modules, fused_conv) + bn_node.replace_all_uses_with(bn_node.args[0]) + gm.graph.erase_node(bn_node) + gm.graph.lint() + gm.recompile() + + return gm + + +class NormalizedLinearNode: + def __init__(self, node: torch.fx.Node) -> None: + assert node.op == "call_function" + assert node.target in [torch.nn.functional.linear] + self.node: torch.fx.Node = node + + def get_input(self) -> torch.fx.Node: + if len(self.node.args) > 0: + return self.node.args[0] # type: ignore[return-value] + else: + return self.node.kwargs["input"] # type: ignore[return-value] + + def get_weight(self) -> torch.fx.Node: + if len(self.node.args) > 1: + return self.node.args[1] # type: ignore[return-value] + else: + return self.node.kwargs["weight"] # type: ignore[return-value] + + def get_bias(self) -> torch.fx.Node: + if len(self.node.args) > 2: + return self.node.args[2] # type: ignore[return-value] + else: + return self.node.kwargs["bias"] if "bias" in self.node.kwargs else None # type: ignore[return-value] + + +class NormalizedMatmulNode: + def __init__(self, node: torch.fx.Node) -> None: + assert node.op == "call_function" + assert node.target in 
[torch.bmm, torch.matmul] + self.node: torch.fx.Node = node + + def get_input(self) -> torch.fx.Node: + if len(self.node.args) > 0: + return self.node.args[0] # type: ignore[return-value] + else: + return self.node.kwargs["input"] # type: ignore[return-value] + + def get_other(self) -> torch.fx.Node: + if len(self.node.args) > 1: + return self.node.args[1] # type: ignore[return-value] + else: + return self.node.kwargs["other"] # type: ignore[return-value] + + +def check_permute(node: torch.fx.Node) -> bool: + ranks = len(node.meta["tensor_meta"].shape) + if len(node.args) > 3: + permutation = [node.args[i] % ranks for i in range(1, ranks + 1)] # type: ignore[operator] + elif ( + "permutation" in node.kwargs + and node.kwargs["permutation"] is not None + and len(node.kwargs["permutation"]) > 2 # type: ignore[arg-type] + ): + permutation = [i % ranks for i in node.kwargs["permutation"]] # type: ignore[operator, union-attr] + else: + return False + allowed_permutation = list(range(ranks)) + allowed_permutation[-1] = ranks - 2 + allowed_permutation[-2] = ranks - 1 + return permutation == allowed_permutation + + +def sink_cat_after_pointwise(module: torch.fx.GraphModule) -> torch.fx.GraphModule: + def one_user(node): + users = list(node.users) + return users[0] if len(users) == 1 else None + + def is_view(node): + return node.op == "call_method" and node.target == "view" + + def is_pointwise_unary(node): + ops = "call_function", "call_method" + pointwise = torch.relu, torch.tanh, "relu", "tanh" + return node.op in ops and node.target in pointwise + + g = module.graph + for node in g.nodes: + if node.op != "call_function" or node.target != torch.cat: + continue + + cat_or_view = node + while True: + user = one_user(cat_or_view) + if not user or not is_view(user): + break + cat_or_view = user + + if user and is_pointwise_unary(user): + with g.inserting_before(node): + + def cat_args(tensors, dim=0): + return tensors, dim + + tensors, dim = cat_args(*node.args, **node.kwargs) + new_kwargs = { + name: val for name, val in user.kwargs.items() if name != "input" + } + new_tensors = [ + g.create_node(user.op, user.target, args=(arg,), kwargs=new_kwargs) + for arg in tensors + ] + new_cat = g.create_node( + "call_function", torch.cat, args=(new_tensors, dim) + ) + user.replace_all_uses_with(cat_or_view) + node.replace_all_uses_with(new_cat) + g.erase_node(user) + g.erase_node(node) + g.lint() + module.recompile() + return module + + +def linear_permute_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule: + for node in module.graph.find_nodes(op="call_method", target="permute"): + if check_permute(node): + if len(node.args) > 0: + input_node = node.args[0] + else: + input_node = node.kwargs["input"] + if ( + input_node.op == "call_function" + and input_node.target == torch.nn.functional.linear + ): + normalized = NormalizedLinearNode(input_node) + input = normalized.get_input() + weight = normalized.get_weight() + bias = normalized.get_bias() + with module.graph.inserting_before(node): + fused_node = module.graph.call_function( + linear_transpose, args=(input, weight, bias) + ) + node.replace_all_uses_with(fused_node) + module.graph.erase_node(node) + if len(input_node.users) == 0: + module.graph.erase_node(input_node) + + module.graph.lint() + module.recompile() + return module + + +# Y1 = X * W^T + bias +# Y2 = Y1.permute(0, 2, 1) +# ----> +# Y2 = (W * X^T + bias.unsqueeze(-1))^T +def linear_transpose( + input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] +) -> 
torch.Tensor: + if bias is None: + return torch.matmul(weight, input.transpose(-1, -2)) + return torch.matmul(weight, input.transpose(-1, -2)) + bias.unsqueeze(-1) + + +def permute_linear_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule: + for node in module.graph.find_nodes( + op="call_function", target=torch.nn.functional.linear + ): + if len(node.args) > 0: + input_node = node.args[0] + else: + input_node = node.kwargs["input"] + if ( + input_node.op == "call_method" + and input_node.target == "permute" + and check_permute(input_node) + ): + normalized = NormalizedLinearNode(node) + if len(input_node.args) > 0: + input = input_node.args[0] + else: + input = input_node.kwargs["input"] + weight = normalized.get_weight() + bias = normalized.get_bias() + with module.graph.inserting_before(node): + fused_node = module.graph.call_function( + transpose_linear, args=(input, weight, bias) + ) + node.replace_all_uses_with(fused_node) + module.graph.erase_node(node) + if len(input_node.users) == 0: + module.graph.erase_node(input_node) + + module.graph.lint() + module.recompile() + return module + + +def permute_matmul_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule: + for node in itertools.chain( + module.graph.find_nodes(op="call_function", target=torch.bmm), + module.graph.find_nodes(op="call_function", target=torch.matmul), + ): + normalized = NormalizedMatmulNode(node) + input_A_node = normalized.get_input() + input_B_node = normalized.get_other() + input_A = input_A_node + input_B = input_B_node + Atrans = Btrans = False + if ( + input_A_node.op == "call_method" + and input_A_node.target == "permute" + and check_permute(input_A_node) + ): + Atrans = True + if len(input_A_node.args) > 0: + input_A = input_A_node.args[0] # type: ignore[assignment] + else: + input_A = input_A_node.kwargs["input"] # type: ignore[assignment] + + if ( + input_B_node.op == "call_method" + and input_B_node.target == "permute" + and check_permute(input_B_node) + ): + Btrans = True + if len(input_B_node.args) > 0: + input_B = input_B_node.args[0] # type: ignore[assignment] + else: + input_B = input_B_node.kwargs["input"] # type: ignore[assignment] + + if Atrans or Btrans: + with module.graph.inserting_before(node): + fused_node = module.graph.call_function( + transpose_matmul, + args=(input_A, input_B, Atrans, Btrans), + ) + node.replace_all_uses_with(fused_node) + module.graph.erase_node(node) + if Atrans and len(input_A_node.users) == 0: + module.graph.erase_node(input_A_node) + if Btrans and len(input_B_node.users) == 0: + module.graph.erase_node(input_B_node) + + module.graph.lint() + module.recompile() + return module + + +# X1 = X.permute(0, 2, 1) +# Y1 = X1 * W1^T + bias1 +# ----> +# Y2 = X1.transpose(-1, -2) * W1^T + bias1 +def transpose_linear( + input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] +) -> torch.Tensor: + if bias is None: + return torch.matmul(input.transpose(-1, -2), weight.t()) + return torch.matmul(input.transpose(-1, -2), weight.t()) + bias + + +def transpose_matmul( + A: torch.Tensor, B: torch.Tensor, Atrans: bool, Btrans: bool +) -> torch.Tensor: + if Atrans: + A = A.transpose(-1, -2) + if Btrans: + B = B.transpose(-1, -2) + return torch.matmul(A, B) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/quantization.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/quantization.py new file mode 100644 index 0000000000000000000000000000000000000000..862df99a41e50d6217012cebab4fb66da365ec50 --- /dev/null +++ 
b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/quantization.py @@ -0,0 +1,3891 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import copy +import functools +import itertools +import math +import operator +from typing import Any + +import torch +from torch._dynamo.utils import counters +from torch.fx.experimental.symbolic_shapes import has_free_symbols +from torch.fx.node import map_arg + +from .. import config +from ..lowering import lowerings as L, require_channels_last +from ..pattern_matcher import ( + Arg, + CallFunction, + filter_nodes, + KeywordArg, + ListOf, + Match, + stable_topological_sort, +) +from ..utils import pad_listlike +from .freezing_patterns import register_freezing_graph_pattern +from .post_grad import register_lowering_pattern + + +aten = torch.ops.aten +prims = torch.ops.prims +quantized_decomposed = torch.ops.quantized_decomposed +quantized = torch.ops.quantized + +# Only for per tensor quant since permute may changes the channel idx +_PER_TENSOR_QUANTIZE_OPS = [ + quantized_decomposed.quantize_per_tensor.default, + quantized_decomposed.quantize_per_tensor.tensor, +] + +_VIEW_OPS = [ + aten.transpose.int, + aten.permute.default, + aten.view.default, +] + +""" +The quantization.py file primarily incorporates passes related to quantization fusion +in inductor, includes: +1. Dequant Promotion; +2. Conv/GEMM weight prepack with oneDNN Library; +3. Conv/GEMM quantization fusion with output quant node (if have); +4. Other pointwise operators' quantization fusion like: qmaxpool2d, qcat and more; + +It also involves int8-mixed-fp32 and int8-mixed-bf16 quantization. The main difference +of patterns for int8-mixed-bf16, comparing with int8-mixed-fp32, is +1. There is to(dtype=torch.bfloat16) node at the inputs of activation and weight for Conv/GEMM. +2. There is to(dtype=torch.float32) node at the outputs of Conv/GEMM before inputs to next quant node. +Refer to: https://github.com/pytorch/pytorch/issues/111640 for detail design of int8-mixed-bf16 +quantization. +""" + + +def _get_pattern_output_dtype(match: Match): + """ + Get the pattern's output dtype from node's meta + Assume only 1 output node in this matched pattern. 
+ """ + pattern_output_nodes = match.output_nodes() + assert len(pattern_output_nodes) == 1 + output_node = pattern_output_nodes[0] + assert isinstance(output_node, torch.fx.Node) + output_dtype = output_node.meta["val"].dtype + assert output_dtype in [torch.int8, torch.uint8, torch.float32, torch.bfloat16] + return output_dtype + + +def _may_generate_pattern_with_dtype_convert( + pattern, dtype=Arg(), with_dtype_convert=True, users=1 +): + if with_dtype_convert: + return CallFunction( + prims.convert_element_type.default, + pattern, + dtype, + _users=users, + ) + else: + return pattern + + +def _may_generate_pattern_with_reshape(pattern, reshape_size=Arg(), with_reshape=True): + if with_reshape: + return CallFunction( + torch.ops.aten.reshape.default, + pattern, + reshape_size, + ) + else: + return pattern + + +def _generate_linear_t_pattern( + _dequant_per_channel_pattern, + dtype, +): + assert dtype in [torch.float32, torch.bfloat16] + t_pattern = CallFunction( + aten.permute.default, + _may_generate_pattern_with_dtype_convert( + _dequant_per_channel_pattern, + KeywordArg("autocast_wgt_dtype"), + dtype == torch.bfloat16, + ), + KeywordArg("permute_axes"), + ) + return t_pattern + + +def _unary_fusion_pattern(unary_fusion, call_fn, users, is_bf16): + # only insert to_dtype if is_bf16 is True + computation_call = _may_generate_pattern_with_dtype_convert( + call_fn, dtype=KeywordArg("to_float"), with_dtype_convert=is_bf16, users=users + ) + return unary_fusion(computation_call) + + +def get_dequantize_per_tensor_activation_pattern(is_tensor_overload=False): + dequantize_per_tensor_activation_pattern = CallFunction( + quantized_decomposed.dequantize_per_tensor.tensor + if is_tensor_overload + else quantized_decomposed.dequantize_per_tensor.default, + KeywordArg("x"), + KeywordArg("x_scale"), + KeywordArg("x_zp"), + KeywordArg("x_quant_min"), + KeywordArg("x_quant_max"), + KeywordArg("x_dq_dtype"), + ) + return dequantize_per_tensor_activation_pattern + + +dequantize_per_channel_weight_pattern = CallFunction( + quantized_decomposed.dequantize_per_channel.default, + KeywordArg("q_weight"), + KeywordArg("w_scale"), + KeywordArg("w_zp"), + KeywordArg("w_axis"), + KeywordArg("w_quant_min"), + KeywordArg("w_quant_max"), + KeywordArg("w_dtype"), +) + +dequantize_per_channel_to_bf16_weight_pattern = ( + _may_generate_pattern_with_dtype_convert( + dequantize_per_channel_weight_pattern, + KeywordArg("autocast_wgt_dtype"), + ) +) + +dequantize_per_channel_clone_weight_pattern = CallFunction( + aten.clone.default, + dequantize_per_channel_weight_pattern, + memory_format=KeywordArg("memory_format"), +) + +dequantize_per_channel_to_bf16_clone_weight_pattern = CallFunction( + aten.clone.default, + dequantize_per_channel_to_bf16_weight_pattern, + memory_format=KeywordArg("memory_format"), +) + + +def get_qconv_pt2e_pattern(users=1): + return CallFunction( + torch.ops.onednn.qconv_pointwise.default, + KeywordArg("x"), + KeywordArg("x_scale"), + KeywordArg("x_zp"), + KeywordArg("packed_weight"), + KeywordArg("w_scale"), + KeywordArg("w_zp"), + KeywordArg("b"), + KeywordArg("stride"), + KeywordArg("padding"), + KeywordArg("dilation"), + KeywordArg("groups"), + KeywordArg("output_scale"), + KeywordArg("output_zero_point"), + KeywordArg("output_dtype"), + KeywordArg("postop_name"), + KeywordArg("postop_args"), + KeywordArg("postop_algorithm"), + _users=users, + ) + + +def get_qconv2d_binary_pt2e_pattern(users=1): + return CallFunction( + torch.ops.onednn.qconv2d_pointwise.binary, + KeywordArg("x"), + 
KeywordArg("x_scale"), + KeywordArg("x_zp"), + KeywordArg("packed_weight"), + KeywordArg("w_scale"), + KeywordArg("w_zp"), + KeywordArg("accum"), + KeywordArg("b"), + KeywordArg("stride"), + KeywordArg("padding"), + KeywordArg("dilation"), + KeywordArg("groups"), + KeywordArg("output_scale"), + KeywordArg("output_zero_point"), + KeywordArg("output_dtype"), + KeywordArg("accum_scale"), + KeywordArg("accum_zero_point"), + KeywordArg("binary_op_name"), + KeywordArg("alpha"), + KeywordArg("unary_op_name"), + KeywordArg("unary_op_args"), + KeywordArg("unary_op_algorithm"), + _users=users, + ) + + +def get_qlinear_pt2e_pattern(x_scale_zp_are_tensors, users=1): + qlinear_op = ( + torch.ops.onednn.qlinear_pointwise.tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qlinear_pointwise.default + ) + return CallFunction( + qlinear_op, + KeywordArg("x"), + KeywordArg("x_scale"), + KeywordArg("x_zp"), + KeywordArg("packed_weight"), + KeywordArg("w_scale"), + KeywordArg("w_zp"), + KeywordArg("b"), + KeywordArg("output_scale"), + KeywordArg("output_zero_point"), + KeywordArg("output_dtype"), + KeywordArg("postop_name"), + KeywordArg("postop_args"), + KeywordArg("postop_algorithm"), + _users=users, + ) + + +def get_qlinear_binary_pt2e_pattern(x_scale_zp_are_tensors, users=1): + qlinear_op = ( + torch.ops.onednn.qlinear_pointwise.binary_tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qlinear_pointwise.binary + ) + return CallFunction( + qlinear_op, + KeywordArg("x"), + KeywordArg("x_scale"), + KeywordArg("x_zp"), + KeywordArg("packed_weight"), + KeywordArg("w_scale"), + KeywordArg("w_zp"), + KeywordArg("x_2"), + KeywordArg("b"), + KeywordArg("output_scale"), + KeywordArg("output_zero_point"), + KeywordArg("output_dtype"), + KeywordArg("x2_scale"), + KeywordArg("x2_zp"), + KeywordArg("binary_op_name"), + KeywordArg("alpha"), + KeywordArg("unary_op_name"), + KeywordArg("unary_op_args"), + KeywordArg("unary_op_algorithm"), + _users=users, + ) + + +dequantize_accum_pattern = CallFunction( + quantized_decomposed.dequantize_per_tensor.default, + KeywordArg("accum"), + KeywordArg("accum_scale"), + KeywordArg("accum_zp"), + Arg(), + Arg(), + KeywordArg("accum_dq_dtype"), +) + + +def generate_pattern_with_binary( + binary_post_op, + computation_call, + extra_input_pattern, + dtype_convert=False, + swap_inputs=False, +): + binary_pattern = ( + CallFunction( + binary_post_op, + extra_input_pattern, + computation_call, + ) + if swap_inputs + else CallFunction( + binary_post_op, + computation_call, + extra_input_pattern, + ) + ) + return _may_generate_pattern_with_dtype_convert( + binary_pattern, + KeywordArg("convert_dtype_after_inplace_add"), + dtype_convert, + ) + + +def generate_pattern_with_unary(computation_call, unary_post_op): + if unary_post_op is not None: + return CallFunction( + unary_post_op, + computation_call, + ) + return computation_call + + +def generate_pattern_with_output_quant(computation_call, with_dtype_convert=False): + quantized_op_output_pattern_pt2e = CallFunction( + quantized_decomposed.quantize_per_tensor.default, + _may_generate_pattern_with_dtype_convert( + computation_call, + Arg(), + with_dtype_convert, + ), + KeywordArg("o_inv_scale"), + KeywordArg("o_zp"), + KeywordArg("o_qmin"), + KeywordArg("o_qmax"), + KeywordArg("o_dtype"), + ) + return quantized_op_output_pattern_pt2e + + +def _check_node_kwarg_arg_value(check_node, kwarg_name, args_index, expected_value): + if kwarg_name in check_node.kwargs: + actual_value = check_node.kwargs[kwarg_name] + return actual_value 
== expected_value + else: + assert len(check_node.args) >= (args_index + 1) + actual_value = check_node.args[args_index] + return actual_value == expected_value + + +def _is_valid_quantized_conv_optimization_pattern(): + def fn(match): + output_dtype = _get_pattern_output_dtype(match) + if output_dtype in [torch.float32, torch.bfloat16]: + # Only keep matched pattern with same output_dtype + qconv_node_after_weight_prepack = filter_nodes( + match.nodes, torch.ops.onednn.qconv_pointwise + )[0] + return _check_node_kwarg_arg_value( + qconv_node_after_weight_prepack, "output_dtype", 13, output_dtype + ) + return True + + return fn + + +def _is_valid_qconv_post_op_fusion_pattern(has_binary_post_op=False): + return ( + _is_valid_qconv_binary_optimization_pattern() + if has_binary_post_op + else _is_valid_quantized_conv_optimization_pattern() + ) + + +def _is_valid_qconv_lowering_pattern(): + def fn(match): + if len(match.nodes) != 1: + return False + return match.nodes[0].target in ( + torch.ops.onednn.qconv_pointwise.default, + torch.ops.onednn.qconv_pointwise.tensor, + torch.ops.onednn.qconv2d_pointwise.binary, + torch.ops.onednn.qconv2d_pointwise.binary_tensor, + ) + + return fn + + +def _register_quantized_conv_lowering( + pattern, + pass_number, + computation_op, +): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_qconv_lowering_pattern(), + pass_number=pass_number, + ) + def qconv(match: Match, *args, **kwargs): + # Activation QParams + x, x_scale, x_zp = ( + kwargs["x"], + kwargs["x_scale"], + kwargs["x_zp"], + ) + # Weight QParams + packed_weight, w_scale, w_zp = ( + kwargs["packed_weight"], + kwargs["w_scale"], + kwargs["w_zp"], + ) + # Conv Params + b, stride, padding, dilation, groups = ( + kwargs["b"], + kwargs["stride"], + kwargs["padding"], + kwargs["dilation"], + kwargs["groups"], + ) + output_dtype = _get_pattern_output_dtype(match) + assert output_dtype in [torch.int8, torch.uint8, torch.float32, torch.bfloat16] + # Output QParams + o_inv_scale = kwargs["output_scale"] + o_zero_point = kwargs["output_zero_point"] + output_dtype = kwargs["output_dtype"] + # post op + postop_name = kwargs["postop_name"] + postop_args = kwargs["postop_args"] + postop_algorithm = kwargs["postop_algorithm"] + + computation_args = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + b, + stride, + padding, + dilation, + groups, + o_inv_scale, + o_zero_point, + output_dtype, + postop_name, + postop_args, + postop_algorithm, + ) + counters["inductor"]["qconv_unary_lower_count"] += 1 + counters["inductor"]["qconv_unary_lower_nodes"] += len(match.nodes) + return L[computation_op](*computation_args) + + return qconv + + +def _is_valid_quantized_linear_optimization_pattern(): + def fn(match): + output_dtype = _get_pattern_output_dtype(match) + if output_dtype in [torch.float32, torch.bfloat16]: + # Only keep matched pattern with same output_dtype + qlinear_node_after_weight_prepack = filter_nodes( + match.nodes, torch.ops.onednn.qlinear_pointwise + )[0] + return _check_node_kwarg_arg_value( + qlinear_node_after_weight_prepack, "output_dtype", 9, output_dtype + ) + return True + + return fn + + +def _is_valid_qlinear_post_op_fusion_pattern(has_binary_post_op=False): + return ( + _is_valid_qlinear_binary_optimization_pattern() + if has_binary_post_op + else _is_valid_quantized_linear_optimization_pattern() + ) + + +def _is_valid_qlinear_lowering_pattern(): + def fn(match): + if len(match.nodes) != 1: + return False + return match.nodes[0].target in ( + 
torch.ops.onednn.qlinear_pointwise.default, + torch.ops.onednn.qlinear_pointwise.tensor, + torch.ops.onednn.qlinear_pointwise.binary, + torch.ops.onednn.qlinear_pointwise.binary_tensor, + ) + + return fn + + +def _register_quantized_linear_unary_lowering( + pattern, + pass_number, + computation_op, +): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_qlinear_lowering_pattern(), + pass_number=pass_number, + ) + def qlinear(match: Match, *args, **kwargs): + output_dtype = _get_pattern_output_dtype(match) + # Activation QParams + x, x_scale, x_zp = ( + kwargs["x"], + kwargs["x_scale"], + kwargs["x_zp"], + ) + # Weight QParams + packed_weight, w_scale, w_zp = ( + kwargs["packed_weight"], + kwargs["w_scale"], + kwargs["w_zp"], + ) + + # bias + b = kwargs["b"] if "b" in kwargs else None + + # Output QParams + o_inv_scale = kwargs["output_scale"] + o_zero_point = kwargs["output_zero_point"] + + # post op + postop_name = kwargs["postop_name"] + postop_args = kwargs["postop_args"] + postop_algorithm = kwargs["postop_algorithm"] + + computation_args = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + b, + o_inv_scale, + o_zero_point, + output_dtype, + postop_name, + postop_args, + postop_algorithm, + ) + counters["inductor"]["qlinear_unary_lower_count"] += 1 + counters["inductor"]["qlinear_unary_lower_nodes"] += len(match.nodes) + return L[computation_op](*computation_args) + + return qlinear + + +def _register_quantized_linear_binary_lowering( + pattern, + pass_number, + computation_op, +): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_qlinear_lowering_pattern(), + pass_number=pass_number, + ) + def qlinear_binary(match: Match, *args, **kwargs): + output_dtype = _get_pattern_output_dtype(match) + assert output_dtype is not None + # Activation QParams + x, x_scale, x_zp = ( + kwargs["x"], + kwargs["x_scale"], + kwargs["x_zp"], + ) + x2 = kwargs["x_2"] + x2_scale = kwargs["x2_scale"] + x2_zp = kwargs["x2_zp"] + # Weight QParams + packed_weight, w_scale, w_zp = ( + kwargs["packed_weight"], + kwargs["w_scale"], + kwargs["w_zp"], + ) + # bias + b = kwargs["b"] if "b" in kwargs else None + # Output QParams + o_inv_scale = kwargs["output_scale"] + o_zero_point = kwargs["output_zero_point"] + + x2.realize() + from .mkldnn_fusion import _can_be_inplace + + binary_op_name = kwargs["binary_op_name"] + alpha = kwargs["alpha"] + unary_op_name = kwargs["unary_op_name"] + unary_op_args = kwargs["unary_op_args"] + unary_op_algorithm = kwargs["unary_op_algorithm"] + + if binary_op_name == "sum" and not _can_be_inplace(x2): + # When we enable the GEMM Template, the output of QLinear + # will be reshaped from 2D back to 3D if the input is 3D. + # This causes _can_be_inplace(x2) to return False if x2 happens + # to be the output of QLinear in this scenario. + # Change the post op from sum to binary add for this case. 
+ # Refer to test case: + # test_mkldnn_pattern_matcher.py::test_qlinear_dequant_promotion_cpu_input_dim_exceeds_2 + binary_op_name = "add" + + computation_args = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + x2, + b, + o_inv_scale, + o_zero_point, + output_dtype, + x2_scale, + x2_zp, + binary_op_name, + alpha, + unary_op_name, + unary_op_args, + unary_op_algorithm, + ) + counters["inductor"]["qlinear_binary_lower_count"] += 1 + counters["inductor"]["qlinear_binary_lower_nodes"] += len(match.nodes) + return L[computation_op](*computation_args) + + return qlinear_binary + + +def _is_valid_qconv_binary_optimization_pattern(): + return _is_valid_quantized_op_binary_optimization_pattern( + torch.ops.onednn.qconv_pointwise + ) + + +def _is_valid_qlinear_binary_optimization_pattern(): + return _is_valid_quantized_op_binary_optimization_pattern( + torch.ops.onednn.qlinear_pointwise, + # we don't insert q-dq for extra input due to accuracy issues + extra_input_from_dequant=False, + ) + + +def _is_valid_quantized_op_binary_optimization_pattern( + qop, extra_input_from_dequant=True +): + # Check if it's a valid Binary Pattern for qconv2d and qlinear: + # * qop_pointwise should only has one users + # * If extra_input_from_dequant is True, extra input of binary node should come from dequant pattern + # * the two inputs of binary node should have attribute "meta" and should be tensors + # * the two inputs of binary node should have the same shape + # * All users of the extra input in this pattern should be + # ancestor nodes of the compute node, except for the binary node + # connected to the compute node. + def fn(match): + output_dtype = _get_pattern_output_dtype(match) + compute_node = filter_nodes(match.nodes, qop)[0] + # qop_pointwise should only have one user + if len(compute_node.users) != 1: + return False + binary_node_inputs = next(iter(compute_node.users)).args + assert len(binary_node_inputs) == 2, "Expects binary node with 2 inputs" + if output_dtype in [torch.float32, torch.bfloat16]: + extra_input_of_binary_node = None + for arg in binary_node_inputs: + if arg != compute_node: + extra_input_of_binary_node = arg + break + assert extra_input_of_binary_node is not None + # Extra input of binary node comes from dequant pattern + if extra_input_from_dequant and ( + (not isinstance(extra_input_of_binary_node, torch.fx.Node)) + or ( + extra_input_of_binary_node.target + != quantized_decomposed.dequantize_per_tensor.default + ) + ): + return False + + # the two inputs of binary node should have attribute "meta" and should be tensors + if not ( + hasattr(binary_node_inputs[0], "meta") + and isinstance(binary_node_inputs[0].meta.get("val", None), torch.Tensor) # type: ignore[union-attr] + ) or not ( + hasattr(binary_node_inputs[1], "meta") + and isinstance(binary_node_inputs[1].meta.get("val", None), torch.Tensor) # type: ignore[union-attr] + ): + return False + # the two inputs of binary node should have the same shape + if ( + binary_node_inputs[0].meta["val"].size() # type: ignore[union-attr] + != binary_node_inputs[1].meta["val"].size() # type: ignore[union-attr] + ): + return False + + # All users of the extra input in this pattern should be + # ancestor nodes of the compute node, except for the binary node + # connected to the compute node. 
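+ # Rough illustration: for a residual-style graph `out = qop(x) + shortcut`,
+ # the extra input `shortcut` may only feed ancestors of the qop node; if it
+ # were also read after the fusion point, the fused op (which may accumulate
+ # into the extra input in place) could clobber a tensor that is still live.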
+ + from .mkldnn_fusion import _get_remaining_users + + extra_input_of_pattern = ( + match.kwargs["other"] + if "other" in match.kwargs + else ( + match.kwargs["accum"] + if (output_dtype in [torch.uint8, torch.int8]) + or (not extra_input_from_dequant) + else match.kwargs["accum_after_dequant"] + ) + ) + if ( + len(_get_remaining_users(extra_input_of_pattern, compute_node)) > 1 + or extra_input_of_pattern == compute_node.args[0] + ): + return False + return True + + return fn + + +def _register_quantized_conv_binary_lowering( + pattern, + pass_number, + computation_op, +): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_qconv_lowering_pattern(), + pass_number=pass_number, + ) + def qconv_binary(match: Match, *args, **kwargs): + output_dtype = _get_pattern_output_dtype(match) + assert output_dtype is not None + x, x_scale, x_zp = kwargs["x"], kwargs["x_scale"], kwargs["x_zp"] + accum = kwargs["accum"] + accum_scale = kwargs["accum_scale"] + accum_zp = kwargs["accum_zero_point"] + packed_weight, w_scale, w_zp = ( + kwargs["packed_weight"], + kwargs["w_scale"], + kwargs["w_zp"], + ) + b, stride, padding, dilation, groups = ( + kwargs["b"], + kwargs["stride"], + kwargs["padding"], + kwargs["dilation"], + kwargs["groups"], + ) + # Output QParams + output_scale = kwargs["output_scale"] + output_zero_point = kwargs["output_zero_point"] + + # post ops + binary_op_name = kwargs["binary_op_name"] + alpha = kwargs["alpha"] + unary_op_name = kwargs["unary_op_name"] + unary_op_args = kwargs["unary_op_args"] + unary_op_algorithm = kwargs["unary_op_algorithm"] + + accum.realize() + from .mkldnn_fusion import _can_be_inplace + + assert _can_be_inplace(accum), ( + "QConv Binary Inplace Fusion requires accum is not an alias or mutation." + ) + + computation_args = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + accum, + b, + stride, + padding, + dilation, + groups, + output_scale, + output_zero_point, + output_dtype, + accum_scale, + accum_zp, + binary_op_name, + alpha, + unary_op_name, + unary_op_args, + unary_op_algorithm, + ) + counters["inductor"]["qconv2d_binary_lower_count"] += 1 + counters["inductor"]["qconv2d_binary_lower_nodes"] += len(match.nodes) + return L[computation_op](*computation_args) + + return qconv_binary + + +def _register_quantization_unary_lowering(): + # QConv2d + for users in [1, 2]: + qconv_pattern = get_qconv_pt2e_pattern(users) + _register_quantized_conv_lowering( + qconv_pattern, + 2, # pass_number + torch.ops.onednn.qconv_pointwise.default, # computation_op + ) + + # QLinear + for x_scale_zp_are_tensors in (False, True): + qlinear_pattern = get_qlinear_pt2e_pattern(x_scale_zp_are_tensors) + computation_op = ( + torch.ops.onednn.qlinear_pointwise.tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qlinear_pointwise.default + ) + _register_quantized_linear_unary_lowering( + qlinear_pattern, + 2, # pass_number + computation_op, + ) + + +def _register_quantization_binary_lowering(): + # QConv2d + for users in (1, 2): + qconv_pattern = get_qconv2d_binary_pt2e_pattern(users) + _register_quantized_conv_binary_lowering( + qconv_pattern, + 2, # pass_number + torch.ops.onednn.qconv2d_pointwise.binary, # computation_op + ) + + # QLinear + for x_scale_zp_are_tensors in (False, True): + qlinear_pattern = get_qlinear_binary_pt2e_pattern(x_scale_zp_are_tensors) + computation_op = ( + torch.ops.onednn.qlinear_pointwise.binary_tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qlinear_pointwise.binary + ) + 
_register_quantized_linear_binary_lowering( + qlinear_pattern, + 2, # pass_number + computation_op, + ) + + +def _is_valid_quantized_maxpool2d_optimization_pattern(): + def fn(match): + # Only match the pattern which max_pool2d_with_indices returns value + # instead of indices. + get_item_node = filter_nodes(match.nodes, operator.getitem)[0] + return get_item_node.args[1] == 0 + + return fn + + +def _register_quantized_maxpool2d_lowering( + pattern, + computation_op, +): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_quantized_maxpool2d_optimization_pattern(), + ) + def qmaxpool2d(match: Match, *args, **kwargs): + x = kwargs["x"] + kernel_size = kwargs["kernel_size"] + stride = kwargs["stride"] if ("stride" in kwargs) else None + padding = kwargs["padding"] if ("padding" in kwargs) else 0 + dilation = kwargs["dilation"] if ("dilation" in kwargs) else 1 + ceil_mode = kwargs["ceil_mode"] if ("ceil_mode" in kwargs) else False + + if padding == 0: + padding = [0, 0] + if dilation == 1: + dilation = [1, 1] + if not stride: + stride = kernel_size + kernel_size = pad_listlike(kernel_size, 2) + stride = pad_listlike(stride, 2) + padding = pad_listlike(padding, 2) + dilation = pad_listlike(dilation, 2) + + assert len(kernel_size) == 2 + assert len(stride) == 2 + assert len(padding) == 2 + assert len(dilation) == 2 + + computation_args = ( + x, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + ) + computation_args, _ = require_channels_last(computation_op, *computation_args) + counters["inductor"]["qmaxpool2d_matcher_count"] += 1 + counters["inductor"]["qmaxpool2d_matcher_nodes"] += len(match.nodes) + return L[computation_op](*computation_args) + + return qmaxpool2d + + +def _register_quantization_maxpool2d(): + # Currently, the default parameters are not in FX Graph generated by Dynamo export. + # So, if user defines nn.MaxPool2d with different assignment of default parameter, + # it will generate graph with different number of input nodes and hence + # different pattern to be matched. 
+ # Refer to the issue: https://github.com/pytorch/pytorch/issues/105901 + max_pool2d_args_list = [ + [ + KeywordArg("stride"), + ], + [ + KeywordArg("stride"), + KeywordArg("padding"), + ], + [ + KeywordArg("stride"), + KeywordArg("padding"), + KeywordArg("dilation"), + ], + [ + KeywordArg("stride"), + KeywordArg("padding"), + KeywordArg("dilation"), + KeywordArg("ceil_mode"), + ], + ] + for max_pool2d_args in max_pool2d_args_list: + dequantize_maxpool2d_pattern = CallFunction( + aten.max_pool2d_with_indices.default, + get_dequantize_per_tensor_activation_pattern(), + KeywordArg("kernel_size"), + *max_pool2d_args, + ) + dequantize_lowmem_maxpool2d_pattern = CallFunction( + prims._low_memory_max_pool_with_offsets.default, + get_dequantize_per_tensor_activation_pattern(), + KeywordArg("kernel_size"), + *max_pool2d_args, + KeywordArg("offset_dtype"), + ) + dequantize_maxpool2d_get_item_pattern = CallFunction( + operator.getitem, + dequantize_maxpool2d_pattern, + Arg(), + ) + dequantize_lowmem_maxpool2d_get_item_pattern = CallFunction( + operator.getitem, + dequantize_lowmem_maxpool2d_pattern, + Arg(), + ) + _register_quantized_maxpool2d_lowering( + generate_pattern_with_output_quant(dequantize_maxpool2d_get_item_pattern), + quantized.max_pool2d.default, + ) + _register_quantized_maxpool2d_lowering( + generate_pattern_with_output_quant( + dequantize_lowmem_maxpool2d_get_item_pattern + ), + quantized.max_pool2d.default, + ) + + +def _is_input_output_same_scale_zp(check_node): + def fn(match): + # Ensure all the inputs and output has same scale and zero point + # Step 1: Check inputs/output zero point + # Get dequant nodes at input + dequant_nodes = filter_nodes( + match.nodes, quantized_decomposed.dequantize_per_tensor.default + ) + zero_points = [node.args[2] for node in dequant_nodes] + # Get quant nodes at output + quant_nodes = filter_nodes( + match.nodes, quantized_decomposed.quantize_per_tensor.default + ) + assert len(quant_nodes) == 1, "expect only 1 add node at output quant pattern" + zero_points.append(quant_nodes[0].args[2]) + if not all(zero_point == zero_points[0] for zero_point in zero_points): + return False + + # Step 2: Check inputs/output scale + scales = [node.args[1] for node in dequant_nodes] + scales.append(quant_nodes[0].args[1]) + if not all(math.isclose(scale, scales[0], rel_tol=1e-5) for scale in scales): # type: ignore[arg-type] + return False + + return True + + return fn + + +def _register_quantized_cat_lowering( + pattern, + computation_op, +): + @register_lowering_pattern( + pattern, + extra_check=_is_input_output_same_scale_zp(aten.cat.default), + ) + def qcat(match: Match, inputs, dim, **kwargs): + # inputs is with format: [[x1, x1_dq_dtype, x1_zp, x1_scale], ...] 
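+ # Only element 0 (the still-quantized tensor) of each matched input is kept,
+ # so the concat runs directly on int8/uint8 data; dropping the surrounding
+ # dequant/quant pair is valid because the extra check above guarantees the
+ # same scale and zero_point for every input and for the output.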
+ uint8_inputs = [input[0] for input in inputs] + counters["inductor"]["qcat_matcher_count"] += 1 + counters["inductor"]["qcat_matcher_nodes"] += len(match.nodes) + return L[computation_op](uint8_inputs, dim) + + return qcat + + +_raw_dequantize_per_tensor_activation_pattern = CallFunction( + quantized_decomposed.dequantize_per_tensor.default, + Arg(), + Arg(), + Arg(), + Arg(), + Arg(), + Arg(), +) + + +def _register_quantization_cat(): + dequantize_cat_pattern = CallFunction( + aten.cat.default, + ListOf(_raw_dequantize_per_tensor_activation_pattern), + KeywordArg("dim"), + ) + _register_quantized_cat_lowering( + generate_pattern_with_output_quant(dequantize_cat_pattern), + aten.cat, + ) + + +def _register_quantized_reshape_lowering( + pattern, + computation_op, +): + @register_lowering_pattern( + pattern, + extra_check=_is_input_output_same_scale_zp(aten.reshape.default), + ) + def qreshape(match: Match, *args, **kwargs): + qx = kwargs["x"] + shape = kwargs["shape"] + counters["inductor"]["qreshape_matcher_count"] += 1 + counters["inductor"]["qreshape_matcher_nodes"] += len(match.nodes) + return L[computation_op](qx, shape) + + return qreshape + + +def _register_quantization_reshape(): + dequantize_reshape_pattern = CallFunction( + torch.ops.aten.reshape.default, + get_dequantize_per_tensor_activation_pattern(), + KeywordArg("shape"), + ) + _register_quantized_reshape_lowering( + generate_pattern_with_output_quant(dequantize_reshape_pattern), + aten.reshape, + ) + + +def _is_valid_concat_linear_int8_woq_optimization_pattern(): + def fn(match): + if not config.cpp.enable_concat_linear: + return False + assert all(k in match.kwargs for k in ("x", "w1", "w2", "w3", "scales")) + if not all( + hasattr(match.kwargs[key], "meta") + for key in ["x", "w1", "w2", "w3", "scales"] + ): + return False + x = match.kwargs["x"].meta["val"] + w1 = match.kwargs["w1"].meta["val"] + w2 = match.kwargs["w2"].meta["val"] + w3 = match.kwargs["w3"].meta["val"] + scales = match.kwargs["scales"].meta["val"] + if len(match.kwargs["scales"].meta["val"].size()) > 1: + return False + num_scales = match.kwargs["scales"].meta["val"].numel() + w1_cols = match.kwargs["w1"].meta["val"].size()[0] + w2_cols = match.kwargs["w2"].meta["val"].size()[0] + w3_cols = match.kwargs["w3"].meta["val"].size()[0] + # Technically, the shapes of the three weights need not be equal. + # But currently, we only enable replacement in this case. 
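+ # Typical target (rough example): three equally shaped projections such as
+ # q/k/v, each with an int8 weight of shape [N, K] and per-output-channel
+ # bf16 scales, so the concatenated weight is [3*N, K] and `scales` holds
+ # 3*N elements, which is what the two size checks below enforce.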
+ if w1_cols != w2_cols or w2_cols != w3_cols: + return False + if 3 * w1_cols != num_scales: + return False + return ( + # For now, we only support woq mm kernels + # with x.type=bfloat16 and w.type=int8 + x.dtype == torch.bfloat16 + and w1.dtype == torch.int8 + and w2.dtype == torch.int8 + and w3.dtype == torch.int8 + and scales.dtype == torch.bfloat16 + # _weight_int8pack_mm kernel only supports cpu now + # TODO: add cuda kernel support instead of calling mul+sum + and x.device.type == "cpu" + and x.device == w1.device + and w1.device == w2.device + and w2.device == w3.device + and x.device == scales.device + ) + + return fn + + +def _is_valid_woq_optimization_pattern(): + def fn(match): + assert all(k in match.kwargs for k in ("x", "weight", "scales")) + if not all( + hasattr(match.kwargs[key], "meta") for key in ["x", "weight", "scales"] + ): + return False + x = match.kwargs["x"].meta["val"] + weight = match.kwargs["weight"].meta["val"] + scales = match.kwargs["scales"].meta["val"] + return ( + # For now, we only support woq mm kernels + # with x.type=bfloat16 and w.type=int8 + x.dtype == torch.bfloat16 + and weight.dtype == torch.int8 + and scales.dtype == torch.bfloat16 + # _weight_int8pack_mm kernel only supports cpu now + # TODO: add cuda kernel support instead of calling mul+sum + and x.device.type == "cpu" + and x.device == weight.device + and x.device == scales.device + ) + + return fn + + +def _register_concat_linear_int8_woq_lowering( + pattern, computation_woq, computation_reshape +): + @register_freezing_graph_pattern( + pattern, + extra_check=_is_valid_concat_linear_int8_woq_optimization_pattern(), + pass_number=4, + ) + def woq(match: Match, *args, **kwargs): + x = kwargs["x"] + w1 = kwargs["w1"] + w2 = kwargs["w2"] + w3 = kwargs["w3"] + scales = kwargs["scales"] + counters["inductor"]["woq_matcher_count"] += 1 + counters["inductor"]["woq_matcher_nodes"] += len(match.nodes) + out_features = ( + w1.meta["val"].size()[0] + + w2.meta["val"].size()[0] + + w3.meta["val"].size()[0] + ) + origin_x_size = tuple(x.meta["val"].size()) + x_shape = [-1, origin_x_size[-1]] + out_shape = list(origin_x_size[:-1] + (out_features,)) + mm_node_of_x = None + for candidate in iter(x.users.keys()): + if ( + candidate.target == aten.mm.default + and list(candidate._input_nodes)[1].target == aten.cat.default + ): + mm_node_of_x = candidate + break + assert mm_node_of_x is not None, "unable to find mm node" + _, cat_wgt_node = mm_node_of_x._input_nodes + scaling_node = next(iter(mm_node_of_x.users.keys())) + user_of_scaling_node = next(iter(scaling_node.users.keys())) + # Some other pass is making some changes that entails + # adding a node before it's used, but it can only be found when + # lint is run. stable_topological_sort() is being run before lint, + # so that error was not being being discovered. + # We call stable_topological_sort here as a workaround. 
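+ # Sketch of the rewrite performed below: the matched expression, roughly
+ #   mm(x_2d, cat([to_bf16(w1.t()), to_bf16(w2.t()), to_bf16(w3.t())], 1)) * scales
+ # is re-expressed as
+ #   reshape(_weight_int8pack_mm(reshape(x, [-1, K]), cat([w1, w2, w3], 0), scales), out_shape)
+ # so the weights stay in int8 and the per-channel scales are applied inside the kernel.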
+ stable_topological_sort(match.graph) + with match.graph.inserting_before(user_of_scaling_node): + new_cat_node = match.graph.call_function( + aten.cat.default, + args=([w1, w2, w3], 0), + ) + x_reshape_node = match.graph.call_function( + computation_reshape, args=(x, x_shape) + ) + new_woq_node = match.graph.call_function( + computation_woq, + args=(x_reshape_node, new_cat_node, scales), + ) + new_woq_node.meta = copy.copy(x.meta) + output_reshape_node = match.graph.call_function( + computation_reshape, args=(new_woq_node, out_shape) + ) + scaling_node.replace_all_uses_with(output_reshape_node) + match.graph.erase_node(scaling_node) + match.graph.erase_node(mm_node_of_x) + match.graph.erase_node(cat_wgt_node) + match.graph.lint() + + return woq + + +def _register_woq_lowering(pattern, computation_woq, computation_reshape): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_woq_optimization_pattern(), + ) + def woq(match: Match, *args, **kwargs): + x = kwargs["x"] + weight = kwargs["weight"] + scales = kwargs["scales"] + counters["inductor"]["woq_matcher_count"] += 1 + counters["inductor"]["woq_matcher_nodes"] += len(match.nodes) + out_features = weight.get_size()[0] + origin_x_size = x.get_size() + x_shape = [-1, origin_x_size[-1]] + out_shape = origin_x_size[:-1] + [ + out_features, + ] + func1 = L[computation_reshape](x, x_shape) + func2 = L[computation_woq](func1, weight, scales) + return L[computation_reshape](func2, out_shape) + + return woq + + +def _register_woq_mm_int8_pattern1(): + # F.linear(x, weight.to(dtype=x.dtype)) * scales + # case of dispatching to mm, with x reshape + _woq_pattern = CallFunction( + aten.mul.Tensor, + CallFunction( + aten.reshape.default, + CallFunction( + aten.mm.default, + CallFunction(aten.reshape.default, KeywordArg("x"), Arg()), + CallFunction( + aten.permute.default, + CallFunction( + prims.convert_element_type.default, KeywordArg("weight"), Arg() + ), + Arg(), + ), + ), + Arg(), + ), + KeywordArg("scales"), + ) + _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape) + + +def _register_woq_mm_int8_pattern2(): + # F.linear(x, weight.to(dtype=x.dtype)) * scales + # case of dispatching to mm, w/o x reshape + _woq_pattern = CallFunction( + aten.mul.Tensor, + CallFunction( + aten.reshape.default, + CallFunction( + aten.mm.default, + KeywordArg("x"), + CallFunction( + aten.permute.default, + CallFunction( + prims.convert_element_type.default, KeywordArg("weight"), Arg() + ), + Arg(), + ), + ), + Arg(), + ), + KeywordArg("scales"), + ) + _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape) + + +def _register_woq_mm_int8_pattern3(): + # F.linear(x, weight.to(dtype=x.dtype)) * scales + # case of dispatching to bmm + _woq_pattern = CallFunction( + aten.mul.Tensor, + CallFunction( + aten.bmm.default, + CallFunction(aten.expand.default, KeywordArg("x"), Arg()), + CallFunction( + aten.expand.default, + CallFunction( + aten.permute.default, + CallFunction( + prims.convert_element_type.default, KeywordArg("weight"), Arg() + ), + Arg(), + ), + Arg(), + ), + ), + KeywordArg("scales"), + ) + _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape) + + +def _register_woq_mm_int8_pattern4(): + _woq_pattern = CallFunction( + aten.mul.Tensor, + CallFunction( + aten.mm.default, + KeywordArg("x"), + CallFunction( + prims.convert_element_type.default, + CallFunction( + aten.permute.default, + KeywordArg("weight"), + Arg(), + ), + Arg(), + ), + ), + KeywordArg("scales"), 
+ ) + _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape) + + +def _register_int8_woq_concat_linear_pattern(): + def _create_wgt_node(wgt_node_name: str): + return CallFunction( + prims.convert_element_type.default, + CallFunction( + aten.permute.default, + KeywordArg(wgt_node_name), + Arg(), + ), + Arg(), + ) + + cat_wgt = CallFunction( + aten.cat.default, [_create_wgt_node(wgt) for wgt in ["w1", "w2", "w3"]], 1 + ) + + _woq_pattern = CallFunction( + aten.mul.Tensor, + CallFunction(aten.mm.default, KeywordArg("x"), cat_wgt), + KeywordArg("scales"), + ) + _register_concat_linear_int8_woq_lowering( + _woq_pattern, aten._weight_int8pack_mm.default, aten.reshape + ) + + +def _register_quantization_lowerings(): + _register_quantization_unary_lowering() + _register_quantization_binary_lowering() + _register_quantization_maxpool2d() + _register_quantization_cat() + _register_quantization_reshape() + + +def _register_woq_lowerings(): + _register_woq_mm_int8_pattern1() + _register_woq_mm_int8_pattern2() + _register_woq_mm_int8_pattern3() + _register_woq_mm_int8_pattern4() + + +def _is_valid_dequant_promotion_pattern(dtype=torch.float32): + def _inner(match): + assert dtype in [torch.float32, torch.bfloat16] + dequant_pattern_end_node = match.output_node() + if dequant_pattern_end_node.target not in [ + quantized_decomposed.dequantize_per_tensor.default, + quantized_decomposed.dequantize_per_tensor.tensor, + prims.convert_element_type.default, + aten.reshape.default, + ]: + return False + + if dequant_pattern_end_node.target is aten.reshape.default: + dequant_node = ( + dequant_pattern_end_node.args[ + 0 + ] # pattern: linear <- reshape <- dequant + if dtype == torch.float32 + else dequant_pattern_end_node.args[0].args[ + 0 + ] # pattern: linear <- reshape <- to_bf16 <- dequant + ) + else: + dequant_node = ( + dequant_pattern_end_node # pattern: linear <- dequant + if dtype == torch.float32 + else dequant_pattern_end_node.args[ + 0 + ] # pattern: linear <- to_bf16 <- dequant + ) + + if ( + dequant_node.target + in [ + quantized_decomposed.dequantize_per_tensor.default, + quantized_decomposed.dequantize_per_tensor.tensor, + ] + and len(list(dequant_pattern_end_node.users)) > 1 + ): + # If dequant pattern has more than 1 users, then do dequant promoted + return True + return False + + return _inner + + +def _register_dequant_promotion_pass(pattern, pass_number, dtype=torch.float32): + @register_freezing_graph_pattern( + pattern, + extra_check=_is_valid_dequant_promotion_pattern(dtype), + pass_number=pass_number, + ) + def dequant_promotion(match: Match, *args, **kwargs): + # Dequant_promotion will transform + # graph 1: + # quant + # + - - - | - - - + + # | dequant | + # | / \ | + # | node1 node2 | + # + - | - - - | - + + # quant quant + # into: + # graph 2: + # quant + # + - - / - \ - - + + # |dequant dequant| + # | | | | + # | node1 node2 | + # + - | - - - | - + + # quant quant + # In graph 1, the dequant node is shared by node1 and node2, + # as a result, neither node1 nor node2 could form an int8 + # fusion pattern. + # After this transformation, the graph 2 could hit the int8 + # fusion pattern: dequant-node-quant, respectively for + # node1 and node2. 
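+ # Concrete example: if one dequantize_per_tensor output feeds two linears,
+ #   dq = dequantize_per_tensor(q_x, scale, zp, qmin, qmax, dtype)
+ #   y1 = linear1(dq); y2 = linear2(dq)
+ # the shared dequant chain (including any trailing to_bf16/reshape) is cloned
+ # so each consumer owns a private copy and can match its own
+ # dequant-op(-quant) fusion pattern.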
+ assert dtype in [torch.float32, torch.bfloat16] + + def clone_to_new_node(graph, source_node, user_node): + # Clone the source_node to a new node + # Replace user_node's input from source_node to new_node + assert source_node.op == "call_function", ( + "clone_to_new_node only support node.op call_function" + ) + with graph.inserting_before(user_node): + new_node = graph.call_function( + source_node.target, + args=source_node.args, + kwargs=source_node.kwargs, + ) + new_node.meta = copy.copy(source_node.meta) + user_node.replace_input_with(source_node, new_node) + return new_node + + # Find the start node and end node of a dequant pattern + # * End node should be the match.output_node() + # * Start node should be the node of dequantize_per_tensor + dequant_pattern_end_node = match.output_node() + assert dequant_pattern_end_node.target in [ + quantized_decomposed.dequantize_per_tensor.default, + quantized_decomposed.dequantize_per_tensor.tensor, + prims.convert_element_type.default, + aten.reshape.default, + ] + + # For a dequant pattern, we should expect see the node list as: + # * OPT(aten.reshape.default) + # * OPT(prims.convert_element_type.default) (to_bf16) + # * dequantize_per_tensor + def _find_first_node_in_dequant_pattern(_node): + if _node.target in [ + quantized_decomposed.dequantize_per_tensor.default, + quantized_decomposed.dequantize_per_tensor.tensor, + ]: + # For a dequant pattern, we expect the start node is a dequantize_per_tensor node + return _node + else: + assert len(_node.args) >= 1, ( + "In in dequant pattern, each node should have more than 1 arg." + ) + return _find_first_node_in_dequant_pattern(_node.args[0]) + + dequant_pattern_start_node = _find_first_node_in_dequant_pattern( + dequant_pattern_end_node + ) + + assert dequant_pattern_start_node.target in [ + quantized_decomposed.dequantize_per_tensor.default, + quantized_decomposed.dequantize_per_tensor.tensor, + ] + + # Clone the dequant pattern for each user node + graph = match.graph + user_node_list = list(dequant_pattern_end_node.users) + for user_node in user_node_list[1:]: + _source_node = dequant_pattern_end_node + _user_node = user_node + while _source_node != dequant_pattern_start_node.args[0]: + _user_node = clone_to_new_node(graph, _source_node, _user_node) + _source_node = _source_node.args[0] # type: ignore[assignment] + + counters["inductor"]["dequant_promotion_matcher_count"] += 1 + counters["inductor"]["dequant_promotion_matcher_nodes"] += len(match.nodes) + + +def _is_valid_dequant_conv_pattern(dtype): + def _inner(match): + # Here we do some further check to ensure: + # 1. It's a conv2d node with dim of 4, since we only support lowering of conv2d now. + # 2. The dequant pattern has only 1 user of conv2d node. + # If these conditions don't meet, we will not + # insert weight prepack node into the matched pattern. 
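+ # (Note: the dim check below accepts both 3-D and 4-D inputs, i.e. conv1d
+ # as well as conv2d, on cpu or xpu devices.)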
+ conv_node = match.output_node() + assert conv_node.target is aten.convolution.default + input_meta_value = conv_node.args[0].meta.get("val") + weight_meta_value = conv_node.args[1].meta.get("val") + for meta_value in [input_meta_value, weight_meta_value]: + if ( + meta_value is None + or (meta_value.device.type != "cpu" and meta_value.device.type != "xpu") + or meta_value.dim() not in [3, 4] + ): + # Only support conv1d/2d now + return False + + assert dtype in [torch.float32, torch.bfloat16] + + if dtype == torch.float32: + dequant_node = conv_node.args[0] + else: + convert_to_bf16 = conv_node.args[0] + dequant_node = convert_to_bf16.args[0] + + if len(list(dequant_node.users)) != 1: + # Ensure the dequant pattern only has 1 user + # since we will delete the dequant pattern here + return False + return True + + return _inner + + +def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float32): + @register_freezing_graph_pattern( + pattern, + extra_check=_is_valid_dequant_conv_pattern(dtype), + pass_number=pass_number, + ) + def qconv_weight_prepack(match: Match, *args, **kwargs): + """ + Match the pattern: + int8 activation + | + dequant_per_tensor + | + Conv2d <- optional(aten.clone.default) <- dequant_per_channel <- int8_weight + + Insert weight prepack node and change the pattern to: + int8 activation + | + onednn.qconv_pointwise <- onednn.qconv_prepack <- int8_weight + """ + assert dtype in [torch.float32, torch.bfloat16] + conv_node = match.output_node() + assert conv_node.target is aten.convolution.default + if dtype == torch.float32: + dequant_node = conv_node.args[0] + else: + convert_to_bf16 = conv_node.args[0] + dequant_node = convert_to_bf16.args[0] # type: ignore[union-attr] + has_clone_to_channel_last_node_in_pattern = ( + conv_node.args[1].target is aten.clone.default # type: ignore[union-attr] + ) + clone_node = ( + conv_node.args[1] if has_clone_to_channel_last_node_in_pattern else None + ) + + if dtype == torch.float32: + dequant_per_channel = ( + clone_node.args[0] # type: ignore[union-attr] + if has_clone_to_channel_last_node_in_pattern + else conv_node.args[1] + ) + else: + weight_to_bf16_node = ( + clone_node.args[0] # type: ignore[union-attr] + if has_clone_to_channel_last_node_in_pattern + else conv_node.args[1] + ) + dequant_per_channel = weight_to_bf16_node.args[0] # type: ignore[union-attr] + + assert ( + dequant_per_channel.target # type: ignore[union-attr] + is quantized_decomposed.dequantize_per_channel.default + ) + + # Activation QParams + qx, x_zp, x_scale = ( + kwargs["x"], + kwargs["x_zp"], + kwargs["x_scale"], + ) + + # Weight QParams + qw, w_scale, w_zp = ( + kwargs["q_weight"], + kwargs["w_scale"], + kwargs["w_zp"], + ) + + # Conv Params + bias, stride, padding, dilation, groups = ( + kwargs["b"], + kwargs["stride"], + kwargs["padding"], + kwargs["dilation"], + kwargs["groups"], + ) + + x_shape = qx.meta.get("tensor_meta").shape + if has_free_symbols(x_shape): + # For dynamic shape case, we can't get activation shape ahead of runtime. + x_shape = None + graph = match.graph + with graph.inserting_before(conv_node): + # Insert weight prepack node and the QConv node + packed_weight_inputs = ( + qw, + w_scale, + x_scale, + x_zp, + stride, + padding, + dilation, + groups, + x_shape, + ) + packed_weight_op = torch.ops.onednn.qconv_prepack + prepack_weight_node = graph.call_function( + packed_weight_op, args=packed_weight_inputs + ) + + new_args: tuple[Any, ...] 
= ( + qx, + x_scale, + x_zp, + prepack_weight_node, + w_scale, + w_zp, + bias, + stride, + padding, + dilation, + groups, + 1.0, # output_scale + 0, # output_zero_point + dtype, # output_dtype + "none", # attr + [], # scalars + "", # algorithm + ) + new_conv_node = graph.call_function( + torch.ops.onednn.qconv_pointwise.default, args=new_args + ) + conv_node.replace_all_uses_with(new_conv_node) + new_conv_node.meta.update(conv_node.meta) + + # Erase the original conv node + graph.erase_node(conv_node) + # Erase the dequant pattern + if dtype == torch.bfloat16: + graph.erase_node(convert_to_bf16) # type: ignore[possibly-undefined, arg-type] + graph.erase_node(dequant_node) # type: ignore[arg-type] + # Erase the dequant per channel pattern + if clone_node is not None: + graph.erase_node(clone_node) # type: ignore[arg-type] + if dtype == torch.bfloat16: + graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined, arg-type] + graph.erase_node(dequant_per_channel) # type: ignore[arg-type] + counters["inductor"]["qconv_weight_prepack_matcher_count"] += 1 + counters["inductor"]["qconv_weight_prepack_matcher_nodes"] += len( + match.nodes + ) + + +def _generate_dequant_convolution_node_pattern( + _dequant_per_channel_pattern, dtype=torch.float32 +): + assert dtype in [torch.float32, torch.bfloat16] + dequant_convolution_node_pattern = CallFunction( + aten.convolution.default, + _may_generate_pattern_with_dtype_convert( + get_dequantize_per_tensor_activation_pattern(), + KeywordArg("autocast_act_dtype"), + dtype == torch.bfloat16, + ), + _dequant_per_channel_pattern, + KeywordArg("b"), + KeywordArg("stride"), + KeywordArg("padding"), + KeywordArg("dilation"), + KeywordArg("is_transposed"), + KeywordArg("out_padding"), + KeywordArg("groups"), + ) + return dequant_convolution_node_pattern + + +def _generate_qconv_weight_prepack_patterns(dtype=torch.float32): + assert dtype in [torch.float32, torch.bfloat16] + return ( + _generate_dequant_convolution_node_pattern( + dequantize_per_channel_weight_pattern + if dtype == torch.float32 + else dequantize_per_channel_to_bf16_weight_pattern, + dtype, + ), + # There is another pattern due to the pass of convert_conv_weights_to_channels_last + # https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/_inductor/freezing.py#L338-L362. 
+ # Depend on some heuristics, it may or may not insert to(channel_last) node + # between convolution and dequant_per_channel node + _generate_dequant_convolution_node_pattern( + dequantize_per_channel_clone_weight_pattern + if dtype == torch.float32 + else dequantize_per_channel_to_bf16_clone_weight_pattern, + dtype, + ), + ) + + +def _get_linear_node(match, input_dim_exceeds_two, input_contiguous): + output_reshape_node = None + if input_dim_exceeds_two: + if input_contiguous: + output_reshape_node = match.output_node() + assert output_reshape_node.target is aten.reshape.default + linear_node = output_reshape_node.args[0] + else: + linear_nodes = filter_nodes(match.nodes, aten.bmm.default) + assert len(linear_nodes) == 1 + linear_node = linear_nodes[0] + else: + linear_node = match.output_node() + + assert linear_node.target in ( + aten.addmm.default, + aten.mm.default, + aten.bmm.default, + ) + return linear_node, output_reshape_node + + +def _get_linear_dq_node( + linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous +): + act_reshape_node = None + activation_to_bf16_node = None + act_expand_node = None + if input_dim_exceeds_two: + if input_contiguous: + act_reshape_node = linear_node.args[input_index] + assert act_reshape_node.target is aten.reshape.default + if dtype == torch.float32: + # pattern: linear -> reshape -> dequant + dequant_node = act_reshape_node.args[0] + else: + # pattern: linear -> reshape -> to_bf16 -> dequant + activation_to_bf16_node = act_reshape_node.args[0] + dequant_node = activation_to_bf16_node.args[0] + else: + # bmm pattern decomposed from linear when input dim exceeds 2 and not contiguous + act_expand_node = linear_node.args[input_index] + assert act_expand_node.target is aten.expand.default + if dtype == torch.float32: + dequant_node = act_expand_node.args[0] + else: + activation_to_bf16_node = act_expand_node.args[0] + dequant_node = activation_to_bf16_node.args[0] + else: + if dtype == torch.float32: + # pattern: linear -> dequant + dequant_node = linear_node.args[input_index] + else: + # pattern: linear -> to_bf16 -> dequant + activation_to_bf16_node = linear_node.args[input_index] + dequant_node = activation_to_bf16_node.args[0] + return dequant_node, act_reshape_node, activation_to_bf16_node, act_expand_node + + +def _is_valid_dequant_linear_pattern(dtype, input_dim_exceeds_two, input_contiguous): + def _inner(match): + # Check dequant pattern has only 1 user. 
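+ # A dequant shared by several consumers must first be duplicated by the
+ # dequant-promotion pass; only a dequant private to this linear is accepted
+ # here because the whole dequant chain is erased when the prepacked qlinear
+ # node is inserted.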
+ ( + linear_node, + _, + ) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous) + + input_index = 1 if linear_node.target is aten.addmm.default else 0 + assert dtype in [torch.float32, torch.bfloat16] + ( + dequant_node, + _, + _, + _, + ) = _get_linear_dq_node( + linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous + ) + + assert dequant_node.target in [ + quantized_decomposed.dequantize_per_tensor.default, + quantized_decomposed.dequantize_per_tensor.tensor, + ] + + if len(list(dequant_node.users)) != 1: + # Ensure the dequant pattern only has 1 user + # since we will delete the dequant pattern here + return False + + # Extra check for bmm pattern + if input_dim_exceeds_two and not input_contiguous: + # Check for act + # Act expand size should be exactly same as act size + act_expand_size = match.kwargs["act_expand_size"] + act_node = match.kwargs["x"] + if not ( + hasattr(act_node, "meta") + and isinstance(act_node.meta.get("val", None), torch.Tensor) + and (act_node.meta["val"].size() == torch.Size(act_expand_size)) + ): + return False + + # Check for wgt + # wgt permute dims should be [1, 0] + wgt_permute_dims = match.kwargs["permute_axes"] + if wgt_permute_dims != [1, 0]: + return False + + # Check below wgt size items: + # wgt before expand should with dim 2 + # Expand size should with dim 3 + # Expand size[0] should same as act size[0] + # Expand size[1] should same as wgt size[1] + # Expand size[2] should same as wgt size[0] + qweight_node = match.kwargs["q_weight"] + wgt_expand_size = match.kwargs["wgt_expand_size"] + if not ( + hasattr(qweight_node, "meta") + and isinstance(qweight_node.meta.get("val", None), torch.Tensor) + and len(qweight_node.meta["val"].size()) == 2 + and len(wgt_expand_size) == 3 + and wgt_expand_size[0] == act_node.meta["val"].size()[0] + and wgt_expand_size[1] == qweight_node.meta["val"].size()[1] + and wgt_expand_size[2] == qweight_node.meta["val"].size()[0] + ): + return False + + return True + + return _inner + + +def _register_qlinear_weight_prepack_pass( + pattern, + pass_number, + dtype=torch.float32, + input_dim_exceeds_two=False, + input_contiguous=True, +): + @register_freezing_graph_pattern( + pattern, + extra_check=_is_valid_dequant_linear_pattern( + dtype, input_dim_exceeds_two, input_contiguous + ), + pass_number=pass_number, + ) + def qlinear_weight_prepack(match: Match, *args, **kwargs): + """ + Match the pattern: + int8 activation + | + dequant_per_tensor + | + mm/addmm <- t <- dequant_per_channel <- int8_weight + + Insert weight prepack node and change the pattern to: + int8 activation + | + onednn.qlinear_pointwise <- onednn.qlinear_prepack <- int8_weight + """ + assert dtype in [torch.float32, torch.bfloat16] + ( + linear_node, + output_reshape_node, + ) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous) + input_index = 1 if linear_node.target is aten.addmm.default else 0 + weight_index = input_index + 1 + + ( + dequant_node, + act_reshape_node, + activation_to_bf16_node, + act_expand_node, + ) = _get_linear_dq_node( + linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous + ) + + if input_dim_exceeds_two and not input_contiguous: + wgt_expand_node = linear_node.args[weight_index] + assert wgt_expand_node.target is aten.expand.default + t_node = wgt_expand_node.args[0] + else: + t_node = linear_node.args[weight_index] + + if dtype == torch.float32: + dequant_per_channel = t_node.args[0] + else: + weight_to_bf16_node = t_node.args[0] + dequant_per_channel = 
weight_to_bf16_node.args[0] + assert ( + dequant_per_channel.target + is quantized_decomposed.dequantize_per_channel.default + ) + + # Activation QParams + qx, x_zp, x_scale = ( + kwargs["x"], + kwargs["x_zp"], + kwargs["x_scale"], + ) + + # Weight QParams + qw, w_scale, w_zp = ( + kwargs["q_weight"], + kwargs["w_scale"], + kwargs["w_zp"], + ) + + # Params + bias = kwargs["b"] if "b" in kwargs else None + + x_shape = qx.meta.get("tensor_meta").shape + if has_free_symbols(x_shape): + # For dynamic shape case, we can't get activation shape ahead of runtime. + x_shape = None + graph = match.graph + with graph.inserting_before(linear_node): + # Insert weight prepack node and the qlinear node + packed_weight_inputs = ( + qw, + x_shape, + ) + packed_weight_op = torch.ops.onednn.qlinear_prepack + prepack_weight_node = graph.call_function( + packed_weight_op, args=packed_weight_inputs + ) + + new_args: tuple[Any, ...] = ( + qx, + x_scale, + x_zp, + prepack_weight_node, + w_scale, + w_zp, + bias, + 1.0, # output_scale + 0, # output_zero_point + dtype, # output_dtype + "none", # post op name + [], # post op args + "", # post op algorithm + ) + Node = torch.fx.node.Node + if isinstance(x_scale, Node) and isinstance(x_zp, Node): + new_linear_node = graph.call_function( + torch.ops.onednn.qlinear_pointwise.tensor, args=new_args + ) + else: + new_linear_node = graph.call_function( + torch.ops.onednn.qlinear_pointwise.default, args=new_args + ) + if input_dim_exceeds_two: + if input_contiguous: + output_reshape_node.replace_all_uses_with(new_linear_node) + new_linear_node.meta.update(output_reshape_node.meta) + else: + if bias: + output_add_node_for_bias = match.output_node() + assert output_add_node_for_bias.target is aten.add.Tensor + output_add_node_for_bias.replace_all_uses_with(new_linear_node) + new_linear_node.meta.update(output_add_node_for_bias.meta) + else: + linear_node.replace_all_uses_with(new_linear_node) + new_linear_node.meta.update(linear_node.meta) + else: + linear_node.replace_all_uses_with(new_linear_node) + new_linear_node.meta.update(linear_node.meta) + + # Erase the original linear node + if input_dim_exceeds_two: + if input_contiguous: + graph.erase_node(output_reshape_node) + elif not input_contiguous and bias: + graph.erase_node(output_add_node_for_bias) # type: ignore[possibly-undefined] + graph.erase_node(linear_node) + if input_dim_exceeds_two: + if input_contiguous: + graph.erase_node(act_reshape_node) + else: + graph.erase_node(act_expand_node) + graph.erase_node(wgt_expand_node) # type: ignore[possibly-undefined] + if dtype == torch.bfloat16: + graph.erase_node(activation_to_bf16_node) + # Erase the dequant pattern + graph.erase_node(dequant_node) + # Erase the dequant per channel pattern + graph.erase_node(t_node) + if dtype == torch.bfloat16: + graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined] + graph.erase_node(dequant_per_channel) + + counters["inductor"]["qlinear_weight_prepack_matcher_count"] += 1 + counters["inductor"]["qlinear_weight_prepack_matcher_nodes"] += len( + match.nodes + ) + + +def _generate_dequant_linear_node_pattern( + _dequant_per_channel_pattern, + dtype=torch.float32, + input_dim_exceeds_two=False, + is_tensor_overload=False, +): + assert dtype in [torch.float32, torch.bfloat16] + t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype) + dequant_linear_bias_pattern = _may_generate_pattern_with_reshape( + CallFunction( + aten.addmm.default, + KeywordArg("b"), + _may_generate_pattern_with_reshape( + 
_may_generate_pattern_with_dtype_convert( + get_dequantize_per_tensor_activation_pattern(is_tensor_overload), + KeywordArg("autocast_act_dtype"), + dtype == torch.bfloat16, + ), + KeywordArg("act_reshape_size"), + input_dim_exceeds_two, + ), + t_pattern, + ), + KeywordArg("output_reshape_size"), + input_dim_exceeds_two, + ) + dequant_linear_no_bias_pattern = _may_generate_pattern_with_reshape( + CallFunction( + aten.mm.default, + _may_generate_pattern_with_reshape( + _may_generate_pattern_with_dtype_convert( + get_dequantize_per_tensor_activation_pattern(is_tensor_overload), + KeywordArg("autocast_act_dtype"), + dtype == torch.bfloat16, + ), + KeywordArg("act_reshape_size"), + input_dim_exceeds_two, + ), + t_pattern, + ), + KeywordArg("output_reshape_size"), + input_dim_exceeds_two, + ) + return dequant_linear_bias_pattern, dequant_linear_no_bias_pattern + + +def _generate_dequant_bmm_node_pattern( + _dequant_per_channel_pattern, + dtype=torch.float32, + with_bias=False, + is_tensor_overload=False, +): + # When activation of linear dim exceed 2 and not contiguous + t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype) + + assert dtype in [torch.float32, torch.bfloat16] + dequant_bmm_pattern = CallFunction( + aten.bmm.default, + CallFunction( + aten.expand.default, + _may_generate_pattern_with_dtype_convert( + get_dequantize_per_tensor_activation_pattern(is_tensor_overload), + KeywordArg("autocast_act_dtype"), + dtype == torch.bfloat16, + ), + KeywordArg("act_expand_size"), + ), + CallFunction( + aten.expand.default, + t_pattern, + KeywordArg("wgt_expand_size"), + ), + ) + + def _generate_pattern_with_output_add(_dequant_bmm_pattern, _with_bias): + if _with_bias: + return CallFunction( + aten.add.Tensor, + _dequant_bmm_pattern, + KeywordArg("b"), + ) + else: + return _dequant_bmm_pattern + + return _generate_pattern_with_output_add(dequant_bmm_pattern, with_bias) + + +def _generate_qlinear_weight_prepack_patterns( + dtype=torch.float32, + input_dim_exceeds_two=False, + input_contiguous=True, + with_bias=False, + is_tensor_overload=False, +): + if input_dim_exceeds_two and not input_contiguous: + return _generate_dequant_bmm_node_pattern( + dequantize_per_channel_weight_pattern, + dtype, + with_bias, + is_tensor_overload, + ) + else: + return _generate_dequant_linear_node_pattern( + dequantize_per_channel_weight_pattern, + dtype, + input_dim_exceeds_two, + is_tensor_overload, + ) + + +def _generate_linear_dynamic_fp16_pattern( + _dequant_weight_pattern, + input_dim_exceeds_two=False, + input_contiguous=True, + relu_fused=False, +): + dtype = torch.float32 + t_pattern = _generate_linear_t_pattern(_dequant_weight_pattern, dtype) + + if input_dim_exceeds_two and not input_contiguous: + # pattern is + # x -> expand -> bmm (-> add) (-> relu) + # w -> dequant -> permute -> expand / + pattern_no_bias = CallFunction( + aten.bmm.default, + CallFunction( + aten.expand.default, + KeywordArg("x"), + KeywordArg("act_expand_size"), + ), + CallFunction( + aten.expand.default, + t_pattern, + KeywordArg("wgt_expand_size"), + ), + ) + pattern_with_bias = CallFunction( + aten.add.Tensor, + pattern_no_bias, + KeywordArg("b"), + ) + if relu_fused: + pattern_with_bias = CallFunction(aten.relu.default, pattern_with_bias) + pattern_no_bias = CallFunction(aten.relu.default, pattern_no_bias) + return pattern_with_bias, pattern_no_bias + + x_pattern_with_reshape = _may_generate_pattern_with_reshape( + KeywordArg("x"), + KeywordArg("act_reshape_size"), + input_dim_exceeds_two, + ) + 
dequant_linear_bias_pattern = generate_pattern_with_unary( + _may_generate_pattern_with_reshape( + CallFunction( + aten.addmm.default, + KeywordArg("b"), + x_pattern_with_reshape, + t_pattern, + ), + KeywordArg("output_reshape_size"), + input_dim_exceeds_two, + ), + aten.relu.default if relu_fused else None, + ) + dequant_linear_no_bias_pattern = generate_pattern_with_unary( + _may_generate_pattern_with_reshape( + CallFunction( + aten.mm.default, + x_pattern_with_reshape, + t_pattern, + ), + KeywordArg("output_reshape_size"), + input_dim_exceeds_two, + ), + aten.relu.default if relu_fused else None, + ) + return dequant_linear_bias_pattern, dequant_linear_no_bias_pattern + + +def _register_dequant_promotion(): + dequant_pattern_cases = itertools.product( + [torch.float32, torch.bfloat16], [True, False], [True, False] + ) + for dtype, input_dim_exceeds_two, is_tensor_overload in dequant_pattern_cases: + # 4 dequantization patterns will be matched based on the dtype and input dimension size. + # Case 1: int8-mixed-fp32, input dim size is 2 + # Case 2: int8-mixed-fp32, input dim size exceeds 2 + # Case 3: int8-mixed-bf16, input dim size is 2 + # Case 4: int8-mixed-bf16, input dim size exceeds 2 + # quant + # + - - - - | - - - - + + # | dequant | + # | | | + # | OPT(to_bf16) | + # | | | + # | OPT(reshape) | + # | / \ | + # | node1 node2 | + # + - - | - - - | - - + + # OPT(reshape) OPT(reshape) + # + - - | - - - | - - + + # OPT(to_fp32) OPT(to_fp32) + # + - - | - - - | - - + + # quant quant + _register_dequant_promotion_pass( + _may_generate_pattern_with_reshape( + _may_generate_pattern_with_dtype_convert( + get_dequantize_per_tensor_activation_pattern( + is_tensor_overload=is_tensor_overload + ), + KeywordArg("autocast_act_dtype"), + dtype == torch.bfloat16, + ), + KeywordArg("act_reshape_size"), + with_reshape=input_dim_exceeds_two, + ), + pass_number=0, + dtype=dtype, + ) # pass_number=0 to run before weight prepack + + +def _register_qconv_weight_prepack(): + for dtype in [torch.float32, torch.bfloat16]: + weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(dtype) + for weight_prepack_pattern in weight_prepack_patterns: + # Register to pass_number 1, so we can do dequant promotion in pass_number 0. + _register_qconv_weight_prepack_pass( + weight_prepack_pattern, pass_number=1, dtype=dtype + ) + + +def _register_qlinear_weight_prepack(): + # 6 Linear related patterns will be matched based on the dtype, input dimension size and input contiguous. + # Then convert the pattern into a QLinear node with int8_fp32/bf16. 
+ # Case 1: int8-mixed-fp32, input dim size is 2 + # Case 2: int8-mixed-fp32, input dim size exceeds 2 and contiguous + # Case 3: int8-mixed-bf16, input dim size is 2 + # Case 4: int8-mixed-bf16, input dim size exceeds 2 and contiguous + + # + - - - - | - - - - - - | - - - - - + + # | dq_per_tensor dq_per_channel | + # | | | | + # | OPT(to_bf16) OPT(to_bf16) | + # | | | | + # | OPT(reshape) permute | + # | \ / | + # | addmm/mm | + # | | | + # | OPT(reshape) | + + # Case 5: int8-mixed-fp32, input dim size exceeds 2 and not contiguous + # Case 6: int8-mixed-bf16, input dim size exceeds 2 and not contiguous + + # + - - - - | - - - - - - | - - - - - + + # | dq_per_tensor dq_per_channel | + # | | | | + # | OPT(to_bf16) OPT(to_bf16) | + # | | | | + # | expand permute | + # | \ | | + # | expand | + # | / | + # | bmm | + # | | | + # | OPT(add) | + + linear_weight_prepack_cases = itertools.product( + [torch.float32, torch.bfloat16], [True, False], [True, False] + ) + + # Step 1: register patterns from mm and addmm + for dtype, input_dim_exceeds_two, is_tensor_overload in linear_weight_prepack_cases: + weight_prepack_patterns = _generate_qlinear_weight_prepack_patterns( + dtype, + input_dim_exceeds_two, + is_tensor_overload=is_tensor_overload, + ) + for weight_prepack_pattern in weight_prepack_patterns: + # Register to pass_number 1, so we can do dequant promotion in pass_number 0. + _register_qlinear_weight_prepack_pass( + weight_prepack_pattern, + pass_number=1, + dtype=dtype, + input_dim_exceeds_two=input_dim_exceeds_two, + ) + + # Step 2: register patterns from bmm + # Linear might be decomposed into bmm when input dim exceeds 2 and not contiguous + # refer to: + # https://github.com/pytorch/pytorch/blob/80c07df659362a95da7cd4f3ec367abfdace38c4/torch/_decomp/decompositions.py#L3965-L3968 + # in this case, we can convert it back to qlinear + for dtype, with_bias, is_tensor_overload in itertools.product( + [torch.float32, torch.bfloat16], [True, False], [True, False] + ): + bmm_pattern = _generate_qlinear_weight_prepack_patterns( + dtype=dtype, + input_dim_exceeds_two=True, + input_contiguous=False, + with_bias=with_bias, + is_tensor_overload=is_tensor_overload, + ) + _register_qlinear_weight_prepack_pass( + bmm_pattern, + pass_number=1 + if with_bias + else 2, # if with_bias, there is an output add, so we should try to match it firstly + dtype=dtype, + input_dim_exceeds_two=True, + input_contiguous=False, + ) + + +def _register_linear_dynamic_fp16_weight_prepack_pass( + pattern, + pass_number, + input_dim_exceeds_two=False, + input_contiguous=True, + relu_fused=False, +): + def _extra_check_fn(match: Match): + return match.kwargs["dtype_fp16"] == torch.float16 + + @register_freezing_graph_pattern( + pattern, + extra_check=_extra_check_fn, + pass_number=pass_number, + ) + def linear_dynamic_fp16_weight_prepack(match: Match, *args, **kwargs): + """ + Match the pattern: + fp32 activation + | + mm/addmm <- t <- to_fp32 <- to_fp16 <- weight + | + (reshape) <- (relu) + + OR + + fp32 activation + | + expand + | + bmm <- expand <- t <- to_fp32 <- to_fp16 <- weight + | + (add) <- (relu) + + Insert weight prepack node and change the pattern to: + fp32 activation + | + onednn.linear_dynamic_fp16 <- onednn.linear_prepack_fp16 <- weight + (or onednn.linear_relu_dynamic_fp16) + """ + # find params + x = kwargs["x"] + w = kwargs["w"] + bias = kwargs["b"] if "b" in kwargs else None + + # find linear node + nodes_to_find = [aten.addmm.default, aten.mm.default, aten.bmm.default] + linear_nodes = [] + for node in 
nodes_to_find: + linear_nodes.extend(filter_nodes(match.nodes, node)) + assert len(linear_nodes) == 1 + linear_node = linear_nodes[0] + assert isinstance(linear_node, torch.fx.node.Node) + input_index = 1 if linear_node.target is aten.addmm.default else 0 + weight_index = input_index + 1 + + # find relu node + relu_node = None + if relu_fused: + relu_node = match.output_node() + assert isinstance(relu_node, torch.fx.node.Node) + + # find reshape node, expand node and add node + ( + act_reshape_node, + output_reshape_node, + expand_x_node, + expand_w_node, + add_bias_node, + ) = (None, None, None, None, None) + t_node = None + if input_dim_exceeds_two: + if input_contiguous: + act_reshape_node = linear_node.args[input_index] + t_node = linear_node.args[weight_index] + output_reshape_node = next(iter(linear_node.users)) + assert output_reshape_node.target is aten.reshape.default + else: + expand_x_node = linear_node.args[input_index] + expand_w_node = linear_node.args[weight_index] + assert isinstance(expand_w_node, torch.fx.node.Node) + t_node = expand_w_node.args[0] + if bias: + add_bias_node = next(iter(linear_node.users)) + assert add_bias_node.target is aten.add.Tensor + else: + t_node = linear_node.args[weight_index] + assert isinstance(t_node, torch.fx.node.Node) + + w_to_fp32_node = t_node.args[0] + assert ( + isinstance(w_to_fp32_node, torch.fx.node.Node) + and w_to_fp32_node.target + is quantized_decomposed.convert_element_type.no_fuse + ) + w_to_fp16_node = w_to_fp32_node.args[0] + assert ( + isinstance(w_to_fp16_node, torch.fx.node.Node) + and w_to_fp16_node.target + is quantized_decomposed.convert_element_type.no_fuse + ) + + x_shape = x.meta.get("tensor_meta").shape + if has_free_symbols(x_shape): + # For dynamic shape case, we can't get activation shape ahead of runtime. + x_shape = None + graph = match.graph + with graph.inserting_before(linear_node): + # Insert weight prepack node and the qlinear node + packed_weight_inputs = ( + w, + x_shape, + ) + packed_weight_op = torch.ops.onednn.linear_prepack_fp16 + prepack_weight_node = graph.call_function( + packed_weight_op, args=packed_weight_inputs + ) + + # create new linear node and insert on graph + new_args: tuple[Any, ...] 
= ( + x, + prepack_weight_node, + bias, + ) + linear_op = ( + torch.ops.onednn.linear_relu_dynamic_fp16.default + if relu_fused + else torch.ops.onednn.linear_dynamic_fp16.default + ) + new_linear_node = graph.call_function(linear_op, args=new_args) + out_node = match.output_node() + out_node.replace_all_uses_with(new_linear_node) + + # Erase the original nodes in the reverse order + new_linear_node.meta.update(out_node.meta) + if relu_node is not None: + graph.erase_node(relu_node) + if output_reshape_node is not None: + graph.erase_node(output_reshape_node) + if add_bias_node is not None: + graph.erase_node(add_bias_node) + graph.erase_node(linear_node) + if act_reshape_node is not None: + assert isinstance(act_reshape_node, torch.fx.node.Node) + graph.erase_node(act_reshape_node) + if expand_x_node is not None: + assert isinstance(expand_x_node, torch.fx.node.Node) + graph.erase_node(expand_x_node) + if expand_w_node is not None: + assert isinstance(expand_w_node, torch.fx.node.Node) + graph.erase_node(expand_w_node) + graph.erase_node(t_node) + graph.erase_node(w_to_fp32_node) + graph.erase_node(w_to_fp16_node) + + counters["inductor"]["qlinear_weight_prepack_matcher_count"] += 1 + counters["inductor"]["qlinear_weight_prepack_matcher_nodes"] += len( + match.nodes + ) + + +def _register_linear_dynamic_fp16_weight_prepack(): + to_dtype_op = torch.ops.quantized_decomposed.convert_element_type.no_fuse + weight_pattern = CallFunction( + to_dtype_op, + CallFunction( + to_dtype_op, + KeywordArg("w"), + KeywordArg("dtype_fp16"), + ), + KeywordArg("dtype_fp32"), + ) + cases = itertools.product( + [False, True], # input_dim_exceeds_two + [True, False], # input_contiguous + [False, True], # relu fused + ) + for input_dim_exceeds_two, input_contiguous, relu_fused in cases: + patterns = _generate_linear_dynamic_fp16_pattern( + weight_pattern, + input_dim_exceeds_two, + input_contiguous, + relu_fused, + ) + for pattern in patterns: + _register_linear_dynamic_fp16_weight_prepack_pass( + pattern, + pass_number=0 if relu_fused else 1, + input_dim_exceeds_two=input_dim_exceeds_two, + input_contiguous=input_contiguous, + relu_fused=relu_fused, + ) + + +def _register_smooth_quant_int_mm_pattern(): + """ + The pattern is: + (no bias) reshape -> _int_mm -> convert_element_type -> (expand ->) mul -> mul -> reshape + or + (with bias) pattern_no_bias -> add (-> reshape -> reshape) + """ + + # When torch.compile'ing with dynamic=True, the expand node and the two tailing reshape nodes exist + # When torch.compile'ing with dynamic=False, they don't exist + def get_pattern_no_bias(expand_a_scale: bool, reshape_a: bool = True): + return CallFunction( + aten.mul.Tensor, + CallFunction( + aten.mul.Tensor, + CallFunction( + prims.convert_element_type.default, + CallFunction( + aten._int_mm.default, + CallFunction( + aten.reshape.default, + KeywordArg("a"), + KeywordArg("in_shape"), + ) + if reshape_a + else KeywordArg("a"), + KeywordArg("b"), + ), + KeywordArg("dtype"), + ), + ( + CallFunction( + aten.expand.default, + KeywordArg("x_scale"), + Arg(), + ) + if expand_a_scale + else KeywordArg("x_scale") + ), + ), + KeywordArg("w_scale"), + ) + + def _with_outer_reshape(pattern): + return CallFunction( + aten.reshape.default, pattern, KeywordArg("out_shape_no_bias") + ) + + # for torch.compile(dynamic=False) + pattern_no_bias_1 = _with_outer_reshape(get_pattern_no_bias(expand_a_scale=False)) + pattern_with_bias_1 = CallFunction( + aten.add.Tensor, + pattern_no_bias_1, + KeywordArg("bias"), + ) + # for 
torch.compile(dynamic=True) + pattern_no_bias_2 = _with_outer_reshape(get_pattern_no_bias(expand_a_scale=True)) + pattern_with_bias_2 = CallFunction( + aten.reshape.default, + CallFunction( + aten.reshape.default, + CallFunction( + aten.add.Tensor, + pattern_no_bias_2, + KeywordArg("bias"), + ), + Arg(), + ), + KeywordArg("out_shape_with_bias"), + ) + + # The following patterns are for torchao int8_dynamic_activation_int8_weight linear, + # when both activation and weights are symmetrically quantized. + # In practice, though, they may also match smooth-quant pattern when a 2D input shape would be used. + # Since add is not currently being used as a oneDNN post-op, but is unfused, we don't need these patterns with bias. + # Ideally, we should add mul + add post-op support in ATen int8 oneDNN linear op. + pattern1_with_no_outer_or_act_reshape = get_pattern_no_bias( + expand_a_scale=False, reshape_a=False + ) + pattern2_with_no_outer_or_act_reshape = get_pattern_no_bias( + expand_a_scale=True, reshape_a=False + ) + + def _validate_pattern(match: Match): + if len(match.nodes) not in [4, 5, 6, 7, 10]: + return False + # Make sure weight is a constant + aten_int_mm_node = filter_nodes(match.nodes, aten._int_mm.default)[0] + if not isinstance(aten_int_mm_node.args[1], torch.fx.node.Node): + return False + if aten_int_mm_node.args[1].op != "get_attr": + return False + + if len(match.nodes) == 10: + # Check the two tailing reshape nodes can be fused + if match.nodes[9].args[1] != match.nodes[6].args[1]: + return False + if len(match.nodes) == 10 or ( + len(match.nodes) == 7 and match.nodes[6].target is aten.add.Tensor + ): + bias_idx = 7 if len(match.nodes) == 10 else 6 + # Check bias shape + bias_node = match.nodes[bias_idx].args[1] + if not isinstance(bias_node, torch.fx.node.Node): + return False + if len(bias_node.meta.get("tensor_meta").shape) != 1: # type: ignore[union-attr] + return False + return True + + pattern_to_pass_number = { + pattern_no_bias_2: 0, + pattern_with_bias_2: 0, + pattern_no_bias_1: 1, + pattern_with_bias_1: 1, + pattern1_with_no_outer_or_act_reshape: 2, + pattern2_with_no_outer_or_act_reshape: 2, + } + for pattern, pass_number in pattern_to_pass_number.items(): + + @register_freezing_graph_pattern( + pattern, + extra_check=_validate_pattern, + pass_number=pass_number, + ) + def _int_mm_weight_prepack(match: Match, *args, **kwargs): + bias = kwargs.get("bias", None) + x = kwargs["a"] + weight = kwargs["b"] + dtype = kwargs["dtype"] + x_scale = kwargs["x_scale"] + w_scale = kwargs["w_scale"] + x_shape = x.meta.get("tensor_meta").shape + if has_free_symbols(x_shape): + # For dynamic shape case, we can't get activation shape ahead of runtime. 
+ x_shape = None + + out_node = match.output_node() + with match.graph.inserting_before(out_node): + transpose_node = match.graph.call_function( + aten.permute.default, args=(weight, [1, 0]) + ) + contig_node = match.graph.call_function( + aten.contiguous.default, args=(transpose_node,) + ) + packed_weight_inputs = ( + contig_node, + x_shape, + ) + packed_weight_op = torch.ops.onednn.qlinear_prepack + prepack_weight_node = match.graph.call_function( + packed_weight_op, args=packed_weight_inputs + ) + + dummy_zp = None + w_scale = match.graph.call_function( + prims.convert_element_type.default, args=(w_scale, torch.float32) + ) + + x_scale_shape = x_scale.meta.get("tensor_meta").shape + x_scale_is_scalar = False + if not has_free_symbols(x_scale_shape): + prod = 1 + for d in x_scale_shape: + prod *= d + x_scale_is_scalar = prod == 1 + + new_args: tuple[Any, ...] + if x_scale_is_scalar: + # in this case, we can call onednn.qlinear directly + new_args = ( + x, + x_scale, + dummy_zp, # x_zp + prepack_weight_node, + w_scale, + dummy_zp, # w_zp + bias, + 1.0, # output_scale + 0, # output_zero_point + dtype, # output_dtype + "none", # post op name + [], # post op args + "", # post op algorithm + ) + new_linear_node = match.graph.call_function( + torch.ops.onednn.qlinear_pointwise.tensor, args=new_args + ) + out_node.replace_all_uses_with(new_linear_node) + new_linear_node.meta.update(out_node.meta) + else: + # onednn.qlinear does not support per-channel quantization of x + # so in this case, we have to apply x scale and add bias ourselves after qlinear + in_shape = kwargs.get("in_shape", None) + if in_shape is None: + x_reshaped = x + else: + x_reshaped = match.graph.call_function( + aten.reshape.default, args=(x, in_shape) + ) + new_args = ( + x_reshaped, + 1.0, # x_scale + 0, # x_zp + prepack_weight_node, + w_scale, + dummy_zp, # w_zp + None, # bias + 1.0, # output_scale + 0, # output_zero_point + dtype, # output_dtype + "none", # post op name + [], # post op args + "", # post op algorithm + ) + new_linear_node = match.graph.call_function( + torch.ops.onednn.qlinear_pointwise, args=new_args + ) + # apply x scale + new_out_node = match.graph.call_function( + aten.mul.Tensor, args=(new_linear_node, x_scale) + ) + + # Add bias and reshape + has_outer_reshape = ( + kwargs.get("out_shape_with_bias", None) is not None + or kwargs.get("out_shape_no_bias", None) is not None + ) + + if has_outer_reshape: + out_shape = kwargs.get( + "out_shape_with_bias", kwargs["out_shape_no_bias"] + ) + if bias is not None: + new_out_node = match.graph.call_function( + aten.add.Tensor, args=(new_out_node, bias) + ) + if has_outer_reshape: + new_out_node = match.graph.call_function( + aten.reshape.default, + args=(new_out_node, out_shape), # type: ignore[possibly-undefined] + ) + else: + if has_outer_reshape: + new_out_node = match.graph.call_function( + aten.reshape.default, + args=(new_out_node, out_shape), # type: ignore[possibly-undefined] + ) + out_node.replace_all_uses_with(new_out_node) + new_out_node.meta.update(out_node.meta) + for node in reversed(match.nodes): + match.graph.erase_node(node) + counters["inductor"]["qlinear_weight_prepack_matcher_count"] += 1 + counters["inductor"]["qlinear_weight_prepack_matcher_nodes"] += len( + match.nodes + ) + + +class PostOpAttr: + def __init__( + self, + binary_op_name: str = "none", + alpha=None, + unary_op_name: str = "none", + scalars_attr=None, + algorithm_attr=None, + ) -> None: + self.binary_op_name = binary_op_name + self.alpha = alpha if alpha else 1.0 + 
self.unary_op_name = unary_op_name + self.scalars_attr = scalars_attr if scalars_attr else [] + self.algorithm_attr = algorithm_attr if algorithm_attr else "" + + +def _register_qconv_post_op_fusion_pass( + pattern, + pass_number, + computation_op, + post_op_attr, +): + has_binary_post_op = post_op_attr.binary_op_name != "none" + + @register_freezing_graph_pattern( + pattern, + extra_check=_is_valid_qconv_post_op_fusion_pattern(has_binary_post_op), + pass_number=pass_number, + ) + def qconv(match: Match, *args, **kwargs): + # Activation QParams + x, x_scale, x_zp = ( + kwargs["x"], + kwargs["x_scale"], + kwargs["x_zp"], + ) + # Weight QParams + packed_weight, w_scale, w_zp = ( + kwargs["packed_weight"], + kwargs["w_scale"], + kwargs["w_zp"], + ) + # Conv Params + b, stride, padding, dilation, groups = ( + kwargs["b"], + kwargs["stride"], + kwargs["padding"], + kwargs["dilation"], + kwargs["groups"], + ) + output_dtype = _get_pattern_output_dtype(match) + assert output_dtype in [torch.int8, torch.uint8, torch.float32, torch.bfloat16] + # Output QParams + o_inv_scale = ( + kwargs["o_inv_scale"] + if (output_dtype == torch.uint8 or output_dtype == torch.int8) + else 1.0 + ) + o_zero_point = ( + kwargs["o_zp"] + if (output_dtype == torch.uint8 or output_dtype == torch.int8) + else 0 + ) + assert ( + kwargs["postop_name"] == "none" + ) # Expected no post op fused in weight prepack phase + if post_op_attr.unary_op_name == "hardtanh": + min_value = kwargs.get("min_value") + max_value = kwargs.get("max_value") + post_op_attr.scalars_attr = [min_value, max_value] + + out_node = match.output_node() + with match.graph.inserting_before(out_node): + if not has_binary_post_op: + computation_args: tuple[Any, ...] = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + b, + stride, + padding, + dilation, + groups, + o_inv_scale, + o_zero_point, + output_dtype, + post_op_attr.unary_op_name, + post_op_attr.scalars_attr, + post_op_attr.algorithm_attr, + ) + else: + accum = ( + kwargs["accum"] + if output_dtype in [torch.uint8, torch.int8] + else kwargs["accum_after_dequant"] + ) + accum_scale = ( + kwargs["accum_scale"] + if output_dtype in [torch.uint8, torch.int8] + else 1.0 + ) + accum_zp = ( + kwargs["accum_zp"] + if output_dtype in [torch.uint8, torch.int8] + else 0 + ) + computation_args = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + accum, + b, + stride, + padding, + dilation, + groups, + o_inv_scale, + o_zero_point, + output_dtype, + accum_scale, + accum_zp, + post_op_attr.binary_op_name, + post_op_attr.alpha, + post_op_attr.unary_op_name, + post_op_attr.scalars_attr, + post_op_attr.algorithm_attr, + ) + new_conv_node = match.graph.call_function( + computation_op, args=computation_args + ) + out_node.replace_all_uses_with(new_conv_node) + new_conv_node.meta.update(out_node.meta) + for node in reversed(match.nodes): + match.graph.erase_node(node) + count_key = ( + "qconv2d_binary_matcher_count" + if has_binary_post_op + else "qconv_unary_matcher_count" + ) + nodes_key = ( + "qconv2d_binary_matcher_nodes" + if has_binary_post_op + else "qconv_unary_matcher_nodes" + ) + counters["inductor"][count_key] += 1 + counters["inductor"][nodes_key] += len(match.nodes) + + return qconv + + +def _register_qconv_unary_fusion(): + from .mkldnn_fusion import _hardswish_fusion, _hardtanh_fusion, _silu_fusion + + for original_pattern_output_dtype in [torch.float32, torch.bfloat16]: + # Priority 1 to match: QConv2d Unary pattern with int8 output + # If a pattern1 is a sub-set of pattern2, we 
should try to match pattern2 firstly. + # For example: pattern1 is qconv_fp32 -> relu, pattern2 is qconv_fp32 -> relu -> quant + is_bf16 = original_pattern_output_dtype == torch.bfloat16 + conv_unary_replace_patterns = { + PostOpAttr( + "none", None, "none", [], "" + ): generate_pattern_with_output_quant( + get_qconv_pt2e_pattern(1), + ), + PostOpAttr( + "none", None, "relu", [], "" + ): generate_pattern_with_output_quant( + generate_pattern_with_unary( + get_qconv_pt2e_pattern(1), aten.relu.default + ), + ), + PostOpAttr( + "none", None, "hardtanh", [], "" + ): generate_pattern_with_output_quant( + _unary_fusion_pattern( + _hardtanh_fusion, + get_qconv_pt2e_pattern(1), + 1, + is_bf16, + ), + with_dtype_convert=is_bf16, + ), + PostOpAttr( + "none", None, "hardswish", [], "" + ): generate_pattern_with_output_quant( + _unary_fusion_pattern( + _hardswish_fusion, + get_qconv_pt2e_pattern(1 if is_bf16 else 2), + 2, + is_bf16, + ), + with_dtype_convert=is_bf16, + ), + PostOpAttr( + "none", None, "swish", [], "" + ): generate_pattern_with_output_quant( + _unary_fusion_pattern( + _silu_fusion, + get_qconv_pt2e_pattern(1 if is_bf16 else 2), + 2, + is_bf16, + ), + with_dtype_convert=is_bf16, + ), + } + + for unary_attr, patterns in conv_unary_replace_patterns.items(): + # Register qconv2d pattern for ExternKernel Lowering + _register_qconv_post_op_fusion_pass( + patterns, + 3, # pass_number + torch.ops.onednn.qconv_pointwise.default, # computation_op + unary_attr, # unary_attr + ) + + # Priority 2 to match: QConv2d Unary pattern with fp32/bfloat16 output + conv_unary_replace_float_out_patterns = { + PostOpAttr("none", None, "relu", [], ""): generate_pattern_with_unary( + get_qconv_pt2e_pattern(1), aten.relu.default + ), + PostOpAttr( + "none", None, "hardtanh", [], "" + ): _may_generate_pattern_with_dtype_convert( + _unary_fusion_pattern( + _hardtanh_fusion, + get_qconv_pt2e_pattern(1), + 1, + is_bf16, + ), + Arg(), + is_bf16, + ), + PostOpAttr( + "none", None, "hardswish", [], "" + ): _may_generate_pattern_with_dtype_convert( + _unary_fusion_pattern( + _hardswish_fusion, + get_qconv_pt2e_pattern(1 if is_bf16 else 2), + 2, + is_bf16, + ), + Arg(), + is_bf16, + ), + PostOpAttr( + "none", None, "swish", [], "" + ): _may_generate_pattern_with_dtype_convert( + _unary_fusion_pattern( + _silu_fusion, + get_qconv_pt2e_pattern(1 if is_bf16 else 2), + 2, + is_bf16, + ), + Arg(), + is_bf16, + ), + } + + for unary_attr, patterns in conv_unary_replace_float_out_patterns.items(): + # Register qconv2d pattern for ExternKernel Lowering + _register_qconv_post_op_fusion_pass( + patterns, + 4, # pass_number + torch.ops.onednn.qconv_pointwise.default, # computation_op + unary_attr, # unary_attr + ) + + +def _register_qconv_binary_fusion(): + for int8_mixed_bf16_with_inplace_add in [False, True]: + # Priority 1 to match: QConv2d Binary or Binary-Unary pattern with int8 output + swap_binary_inputs_list = [False, True] + binary_replace_patterns = {} + for swap_inputs in swap_binary_inputs_list: + binary_replace_patterns.update( + { + PostOpAttr( + "sum", 1.0, "none", [], "" + ): generate_pattern_with_output_quant( + generate_pattern_with_binary( + aten.add.Tensor, + get_qconv_pt2e_pattern(1), + dequantize_accum_pattern, + int8_mixed_bf16_with_inplace_add, + swap_inputs=swap_inputs, + ), + ), + PostOpAttr( + "sum", 1.0, "relu", [], "" + ): generate_pattern_with_output_quant( + generate_pattern_with_unary( + generate_pattern_with_binary( + aten.add.Tensor, + get_qconv_pt2e_pattern(1), + dequantize_accum_pattern, + 
int8_mixed_bf16_with_inplace_add, + swap_inputs=swap_inputs, + ), + aten.relu.default, + ), + ), + } + ) + + for binary_unary_attr, patterns in binary_replace_patterns.items(): + _register_qconv_post_op_fusion_pass( + patterns, + 3, # pass_number + torch.ops.onednn.qconv2d_pointwise.binary, # computation_op + binary_unary_attr, # binary_unary_attr + ) + + # Priority 2 to match: QConv2d Binary-Unary pattern with fp32/bfloat16 output + binary_replace_float_out_patterns = {} + for swap_inputs in swap_binary_inputs_list: + binary_replace_float_out_patterns.update( + { + PostOpAttr("sum", 1.0, "relu", [], ""): generate_pattern_with_unary( + generate_pattern_with_binary( + aten.add.Tensor, + get_qconv_pt2e_pattern(1), + KeywordArg("accum_after_dequant"), + int8_mixed_bf16_with_inplace_add, + swap_inputs=swap_inputs, + ), + aten.relu.default, + ) + } + ) + + for ( + binary_unary_attr, + patterns, + ) in binary_replace_float_out_patterns.items(): + if int8_mixed_bf16_with_inplace_add: + _register_qconv_post_op_fusion_pass( + patterns, + 3, # pass_number + torch.ops.onednn.qconv2d_pointwise.binary, # computation_op + binary_unary_attr, # binary_unary_attr + ) + else: + _register_qconv_post_op_fusion_pass( + patterns, + 4, # pass_number + torch.ops.onednn.qconv2d_pointwise.binary, # computation_op + binary_unary_attr, # binary_unary_attr + ) + + # Priority 3: QConv2d Binary pattern with fp32/bfloat16 output + binary_replace_float_out_patterns = {} + for swap_inputs in swap_binary_inputs_list: + binary_replace_float_out_patterns.update( + { + PostOpAttr( + "sum", 1.0, "none", [], "" + ): generate_pattern_with_binary( + aten.add.Tensor, + get_qconv_pt2e_pattern(1), + KeywordArg("accum_after_dequant"), + int8_mixed_bf16_with_inplace_add, + swap_inputs=swap_inputs, + ), + } + ) + + for ( + binary_unary_attr, + patterns, + ) in binary_replace_float_out_patterns.items(): + _register_qconv_post_op_fusion_pass( + patterns, + 4 if int8_mixed_bf16_with_inplace_add else 5, # pass_number + torch.ops.onednn.qconv2d_pointwise.binary, # computation_op + binary_unary_attr, # binary_unary_attr + ) + + +def _register_qlinear_post_op_fusion_pass( + pattern, + pass_number, + computation_op, + post_op_attr, +): + has_binary_post_op = post_op_attr.binary_op_name != "none" + + @register_freezing_graph_pattern( + pattern, + extra_check=_is_valid_qlinear_post_op_fusion_pattern(has_binary_post_op), + pass_number=pass_number, + ) + def qlinear_post_op_fusion(match: Match, *args, **kwargs): + """ + Match the pattern: + qlinear - post op + """ + output_dtype = _get_pattern_output_dtype(match) + # Activation QParams + x, x_scale, x_zp = ( + kwargs["x"], + kwargs["x_scale"], + kwargs["x_zp"], + ) + # Weight QParams + packed_weight, w_scale, w_zp = ( + kwargs["packed_weight"], + kwargs["w_scale"], + kwargs["w_zp"], + ) + + # bias + b = kwargs["b"] if "b" in kwargs else None + + # Output QParams + o_inv_scale = ( + kwargs["o_inv_scale"] + if (output_dtype in [torch.uint8, torch.int8]) + else 1.0 + ) + o_zero_point = ( + kwargs["o_zp"] if (output_dtype in [torch.uint8, torch.int8]) else 0 + ) + assert ( + kwargs["postop_name"] == "none" + ) # Expected no post op fused in weight prepack phase + + out_node = match.output_node() + with match.graph.inserting_before(out_node): + if not has_binary_post_op: + computation_args: tuple[Any, ...] 
= ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + b, + o_inv_scale, + o_zero_point, + output_dtype, + post_op_attr.unary_op_name, + post_op_attr.scalars_attr, + post_op_attr.algorithm_attr, + ) + else: + other = kwargs["other"] if "other" in kwargs else kwargs["accum"] + x2_scale = 1.0 + x2_zp = 0 + computation_args = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + other, + b, + o_inv_scale, + o_zero_point, + output_dtype, + x2_scale, + x2_zp, + post_op_attr.binary_op_name, + post_op_attr.alpha, + post_op_attr.unary_op_name, + post_op_attr.scalars_attr, + post_op_attr.algorithm_attr, + ) + new_linear_node = match.graph.call_function( + computation_op, args=computation_args + ) + out_node.replace_all_uses_with(new_linear_node) + new_linear_node.meta.update(out_node.meta) + for node in reversed(match.nodes): + match.graph.erase_node(node) + count_key = ( + "qlinear_binary_matcher_count" + if has_binary_post_op + else "qlinear_unary_matcher_count" + ) + nodes_key = ( + "qlinear_binary_matcher_nodes" + if has_binary_post_op + else "qlinear_unary_matcher_nodes" + ) + counters["inductor"][count_key] += 1 + counters["inductor"][nodes_key] += len(match.nodes) + + +def _register_qlinear_unary_fusion(): + from .mkldnn_fusion import ( + _gelu_fusion_1 as _gelu_fusion_erf, + _gelu_fusion_2 as _gelu_fusion_tanh, + ) + + for original_pattern_output_dtype in [torch.float32, torch.bfloat16]: + is_bf16 = original_pattern_output_dtype == torch.bfloat16 + for x_scale_zp_are_tensors in (False, True): + qlinear_pattern = get_qlinear_pt2e_pattern(x_scale_zp_are_tensors) + computation_op = ( + torch.ops.onednn.qlinear_pointwise.tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qlinear_pointwise.default + ) + # Priority 1 to match: QLinear Unary pattern with int8 output + linear_unary_replace_patterns = { + PostOpAttr( + "none", None, "none", [], "" + ): generate_pattern_with_output_quant( + qlinear_pattern, + ), + PostOpAttr( + "none", None, "relu", [], "" + ): generate_pattern_with_output_quant( + generate_pattern_with_unary(qlinear_pattern, aten.relu.default), + ), + PostOpAttr( + "none", None, "gelu", [], "none" + ): generate_pattern_with_output_quant( + _unary_fusion_pattern( + _gelu_fusion_erf, + get_qlinear_pt2e_pattern( + x_scale_zp_are_tensors, 1 if is_bf16 else 2 + ), + 2, + is_bf16, + ), + with_dtype_convert=is_bf16, + ), + PostOpAttr( + "none", None, "gelu", [], "tanh" + ): generate_pattern_with_output_quant( + _unary_fusion_pattern( + _gelu_fusion_tanh, + get_qlinear_pt2e_pattern( + x_scale_zp_are_tensors, 1 if is_bf16 else 4 + ), + 4, + is_bf16, + ), + with_dtype_convert=is_bf16, + ), + } + + for unary_attr, patterns in linear_unary_replace_patterns.items(): + _register_qlinear_post_op_fusion_pass( + patterns, + 3, # pass_number + computation_op, + unary_attr, # unary_attr + ) + + # Priority 2 to match: QLinear Unary pattern with FP32/BF16 output + linear_unary_replace_float_out_patterns = { + PostOpAttr("none", None, "relu", [], ""): generate_pattern_with_unary( + qlinear_pattern, aten.relu.default + ), + PostOpAttr( + "none", None, "gelu", [], "none" + ): _may_generate_pattern_with_dtype_convert( + _unary_fusion_pattern( + _gelu_fusion_erf, + get_qlinear_pt2e_pattern( + x_scale_zp_are_tensors, 1 if is_bf16 else 2 + ), + 2, + is_bf16, + ), + Arg(), + is_bf16, + ), + PostOpAttr( + "none", None, "gelu", [], "tanh" + ): _may_generate_pattern_with_dtype_convert( + _unary_fusion_pattern( + _gelu_fusion_tanh, + get_qlinear_pt2e_pattern( + x_scale_zp_are_tensors, 1 
if is_bf16 else 4 + ), + 4, + is_bf16, + ), + Arg(), + is_bf16, + ), + } + + for unary_attr, patterns in linear_unary_replace_float_out_patterns.items(): + _register_qlinear_post_op_fusion_pass( + patterns, + 4, # pass_number + computation_op, + unary_attr, # unary_attr + ) + + +def _register_qlinear_binary_fusion(): + r""" + Supported linear-binary(-unary) patterns + + linear(X) extra input + \ / + Add + | + Optional(relu) + | + Y + + 1. int8-mixed-fp32 + +---+---------------+-----------+------------------------------+---------+ + | # | Add type | Quant out | Pattern | Post op | + +---+---------------+-----------+------------------------------+---------+ + | 1 | In-/out-place | Yes | linear + fp32 -> (relu) -> q | add | + +---+---------------+-----------+------------------------------+---------+ + | 2 | In-/out-place | No | linear + fp32 -> (relu) | sum | + +---+---------------+-----------+------------------------------+---------+ + + 2. int8-mixed-bf16 + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | # | X2 dtype | Add type | Quant out | Pattern | Post op | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | 1 | BF16 | In-/out-place | Yes | linear + bf16 -> (relu) -> q | add | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | 2 | BF16 | In-/out-place | No | linear + bf16 -> (relu) | sum | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | 3 | FP32 | Out-place | Yes | linear + fp32 -> (relu) -> q | add | + | | | In-place right| | | | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | 4 | FP32 | Out-place | No | linear + fp32 -> (relu) | sum | + | | | In-place right| | | | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | 5 | FP32 | In-place left | Yes | linear + fp32 -> to_bf16 -> (relu) -> q | add | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | 6 | FP32 | In-place left | No | linear + fp32 -> to_bf16 -> (relu) | add | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + + Note + (1) The positions of linear and the extra input can be swapped. + (2) we don't insert q-dq before the extra input of linear-add by recipe. But if q-dq is found at the + extra input, we don't match that pattern because we cannot match all these patterns in 3 passes. 
+ """ + for x_scale_zp_are_tensors in (False, True): + qlinear_binary_op = ( + torch.ops.onednn.qlinear_pointwise.binary_tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qlinear_pointwise.binary + ) + unary_postop_list = ["none", "relu"] + unary_postop_dict = { + "none": None, + "relu": aten.relu.default, + } + convert_dtype_after_binary_list = [False, True] + + # Priority 1 to match: QLinear Binary or Binary-Unary pattern with int8 output + # Covers case (1) of int8-mixed-fp32 and case (1)(3)(5) of int8-mixed-bf16, + # totally 3 patterns (2 are identical) + swap_binary_inputs_list = [False, True] + int8_mixed_bf16_list = [False, True] + combinations = itertools.product( + unary_postop_list, + int8_mixed_bf16_list, + swap_binary_inputs_list, + convert_dtype_after_binary_list, + ) + qlinear_binary_replace_patterns = {} + for unary_op, int8_mixed_bf16, swap_inputs, cvt_dtype_binary in combinations: + if not int8_mixed_bf16 and cvt_dtype_binary: + # No convert node after binary node if dtypes are all fp32 + continue + qlinear_binary_replace_patterns.update( + { + PostOpAttr( + "add", 1.0, unary_op, [], "" + ): generate_pattern_with_output_quant( + generate_pattern_with_unary( + generate_pattern_with_binary( + aten.add.Tensor, + get_qlinear_pt2e_pattern(x_scale_zp_are_tensors), + KeywordArg("other"), + # If fp32 extra input is inplace added to bf16 linear output, + # a to_bf16 node is inserted after binary + dtype_convert=cvt_dtype_binary, + swap_inputs=swap_inputs, + ), + unary_postop_dict[unary_op], + ), + ) + } + ) + for binary_unary_attr, patterns in qlinear_binary_replace_patterns.items(): + _register_qlinear_post_op_fusion_pass( + patterns, + 3, # pass_number + qlinear_binary_op, # computation_op + binary_unary_attr, + ) + + # Priority 2.1 to match: QLinear Binary-Unary pattern with fp32/bfloat16 output + # Covers case (2) of int8-mixed-fp32 and case (2)(4) of int8-mixed-bf16, + # totally 2 patterns (2 are identical) + binary_replace_float_out_patterns = {} + for swap_binary_inputs in swap_binary_inputs_list: + binary_replace_float_out_patterns.update( + { + PostOpAttr("sum", 1.0, "relu", [], ""): generate_pattern_with_unary( + generate_pattern_with_binary( + aten.add.Tensor, + get_qlinear_pt2e_pattern(x_scale_zp_are_tensors), + KeywordArg("accum"), + dtype_convert=False, + swap_inputs=swap_binary_inputs, + ), + aten.relu.default, + ), + } + ) + for ( + binary_unary_attr, + patterns, + ) in binary_replace_float_out_patterns.items(): + _register_qlinear_post_op_fusion_pass( + patterns, + 4, # pass_number + qlinear_binary_op, # computation_op + binary_unary_attr, + ) + # Priority 2.2 to match: QLinear Binary-Unary pattern with fp32/bfloat16 output + # Covers case (6) of int8-mixed-bf16 + binary_replace_float_out_patterns = {} + for swap_binary_inputs in swap_binary_inputs_list: + binary_replace_float_out_patterns.update( + { + PostOpAttr("add", 1.0, "relu", [], ""): generate_pattern_with_unary( + generate_pattern_with_binary( + aten.add.Tensor, + get_qlinear_pt2e_pattern(x_scale_zp_are_tensors), + KeywordArg("other"), + dtype_convert=True, + swap_inputs=swap_binary_inputs, + ), + aten.relu.default, + ), + } + ) + for ( + binary_unary_attr, + patterns, + ) in binary_replace_float_out_patterns.items(): + _register_qlinear_post_op_fusion_pass( + patterns, + 4, # pass_number + qlinear_binary_op, # computation_op + binary_unary_attr, + ) + + # Priority 3.1: QLinear Binary pattern with fp32/bfloat16 output + # Covers case (2) of int8-mixed-fp32 and case (2)(4) of int8-mixed-bf16, + # 
totally 2 patterns (2 are identical) + binary_replace_float_out_patterns = {} + for swap_binary_inputs in swap_binary_inputs_list: + binary_replace_float_out_patterns.update( + { + PostOpAttr( + "sum", 1.0, "none", [], "" + ): generate_pattern_with_binary( + aten.add.Tensor, + get_qlinear_pt2e_pattern(x_scale_zp_are_tensors), + KeywordArg("accum"), + dtype_convert=False, + swap_inputs=swap_binary_inputs, + ), + } + ) + for ( + binary_unary_attr, + patterns, + ) in binary_replace_float_out_patterns.items(): + _register_qlinear_post_op_fusion_pass( + patterns, + 5, # pass_number + qlinear_binary_op, # computation_op + binary_unary_attr, + ) + # Priority 3.2: QLinear Binary pattern with fp32/bfloat16 output + # Covers (6) of int8-mixed-bf16 + binary_replace_float_out_patterns = {} + for swap_binary_inputs in swap_binary_inputs_list: + binary_replace_float_out_patterns.update( + { + PostOpAttr( + "add", 1.0, "none", [], "" + ): generate_pattern_with_binary( + aten.add.Tensor, + get_qlinear_pt2e_pattern(x_scale_zp_are_tensors), + KeywordArg("other"), + dtype_convert=True, + swap_inputs=swap_binary_inputs, + ), + } + ) + for ( + binary_unary_attr, + patterns, + ) in binary_replace_float_out_patterns.items(): + _register_qlinear_post_op_fusion_pass( + patterns, + 5, # pass_number + qlinear_binary_op, # computation_op + binary_unary_attr, + ) + + +@functools.cache +def _register_quantization_weight_pack_pass(): + # Step 1: Dequant promotion for int8-mixed-fp32/bf16 + _register_dequant_promotion() + + # Step 2: QConv weight prepack + _register_qconv_weight_prepack() + + # Step 3: QLinear weight prepack + _register_qlinear_weight_prepack() + _register_linear_dynamic_fp16_weight_prepack() + + # Step 4: weight prepack for SmoothQuant from Torchao + _register_smooth_quant_int_mm_pattern() + + # Step 5: QLinear post op Fusion + if not torch.ops.mkldnn._is_mkldnn_acl_supported(): + # skip fusion on ARM + _register_qconv_unary_fusion() + _register_qconv_binary_fusion() + _register_qlinear_unary_fusion() + _register_qlinear_binary_fusion() + + +def _is_valid_concat_linear_woq_int4_fusion(computation_nodes): + computation_op = torch.ops.aten._weight_int4pack_mm_for_cpu.default + act = computation_nodes[0].args[0] + wgt = computation_nodes[0].args[1] + in_feature_size = wgt.meta.get("val").size(1) # type: ignore[union-attr] + group_size = computation_nodes[0].args[2] + return len(computation_nodes) >= 2 and all( + ( + node.target == computation_op + and node.args[0] == act # share same activation + and ( + node.args[1].meta.get("val").size(1) == in_feature_size + ) # same in feature size + and (node.args[1] != wgt or gemm_idx == 0) + and node.args[1].op == "get_attr" # wgt are all constants + and node.args[2] == group_size # same group size + ) + for gemm_idx, node in enumerate(computation_nodes) + ) + + +def concat_linear_woq_int4(gm: torch.fx.GraphModule): + """ + Concat Linear optimization pass for WOQ int4 + This pass fuses the original pattern: + def ... + return (woq_int4(x, w1, group_size, scale_zp1), woq_int4(x, w2, group_size, scale_zp1) ...) + into a single operation: + def ... 
+ concat_res = woq_int4(x, concat_w, group_size, concat_scale_zp) + return split(concat_res, split_size_list) + """ + + def concat_wgt(packed_wgts, scale_zps, group_size, act_dtype): + # Concat the wgts and scale_zps, and repack the wgt + unpacked_wgts = [] + for packed_wgt in packed_wgts: + # Get the unpacked weight list + # Same as https://github.com/pytorch/pytorch/pull/156174 + K = packed_wgt.size(1) * 2 + N = packed_wgt.size(0) + x = torch.eye(K).to(dtype=act_dtype) + qscales_and_zeros = ( + torch.tensor([1.0, 8.0]) + .to(dtype=act_dtype) + .expand(K // group_size, N, 2) + .contiguous() + ) + unpacked_wgts.append( + torch.ops.aten._weight_int4pack_mm_for_cpu( + x, + packed_wgt, + group_size, + qscales_and_zeros, + ) + .t() + .contiguous() + .to(torch.int32) # N, K + ) + concat_unpacked_wgt = torch.cat(unpacked_wgts, dim=0) + repack_w = torch.ops.aten._convert_weight_to_int4pack_for_cpu( + concat_unpacked_wgt, 1 + ) + concat_scale_zp = torch.cat(scale_zps, dim=1).contiguous() + return repack_w, concat_scale_zp + + graph = gm.graph + computation_op = torch.ops.aten._weight_int4pack_mm_for_cpu.default + for node in graph.find_nodes(op="call_function", target=computation_op): + if ( + not node._erased + and isinstance(node.meta.get("val"), torch.Tensor) + and node.meta["val"].device.type == "cpu" + ): + act = node.args[0] + users = list(act.users) + if _is_valid_concat_linear_woq_int4_fusion(users): + with graph.inserting_before(node): + assert all(user.args[1].op == "get_attr" for user in users) + computation_node_0 = users[0] + packed_wgts = [getattr(gm, user.args[1].target) for user in users] + group_size = computation_node_0.args[2] + scale_zps = [getattr(gm, user.args[3].target) for user in users] + out_feature_size_list = [ + packed_wgt.size(0) for packed_wgt in packed_wgts + ] + repack_w, concat_scale_zp = concat_wgt( + packed_wgts, scale_zps, group_size, act.meta.get("val").dtype + ) + repack_w_node_name = computation_node_0.args[1].target + "_concat" + concat_scale_zp_node_name = ( + computation_node_0.args[3].target + "_concat" + ) + gm.register_buffer(repack_w_node_name, repack_w) + setattr(gm, repack_w_node_name, repack_w) + gm.register_buffer(concat_scale_zp_node_name, concat_scale_zp) + setattr(gm, concat_scale_zp_node_name, concat_scale_zp) + + repack_w_node = graph.create_node( + "get_attr", repack_w_node_name, (), {} + ) + with graph.inserting_after(repack_w_node): + concat_scale_zp_node = graph.create_node( + "get_attr", concat_scale_zp_node_name, (), {} + ) + + with graph.inserting_after(concat_scale_zp_node): + concat_int4_gemm_node = graph.create_node( + "call_function", + computation_op, + ( + act, + repack_w_node, + group_size, + concat_scale_zp_node, + ), + ) + with graph.inserting_after(concat_int4_gemm_node): + split_node = graph.create_node( + "call_function", + torch.ops.aten.split_with_sizes.default, + ( + concat_int4_gemm_node, + out_feature_size_list, + 1, # split dim + ), + ) + with graph.inserting_after(split_node): + for gemm_idx, user in enumerate(users): + assert user.target == computation_op + get_item = graph.create_node( + "call_function", + operator.getitem, + ( + split_node, + gemm_idx, + ), + ) + with graph.inserting_after(get_item): + clone_node = graph.create_node( + "call_function", + torch.ops.aten.clone.default, + (get_item,), + {"memory_format": torch.contiguous_format}, + ) + user.replace_all_uses_with(clone_node) + graph.erase_node(user) + + +def quant_lift_up(graph_module: torch.fx.GraphModule): + """ + Lift up the quant node before view 
like nodes. It can benefit performance + of Attention like block. For example, we have the pattern as: + + DQ + DQ LINEAR + LINEAR VIEW + VIEW PERMUTE + PERMUTE TRANSPOSE + Q Q + DQ DQ + Matmul + DIV + ADD + SOFTMAX + + We want to lift up the the quant nodes from matmul before view like nodes + as the output of Linear node. + + DQ + DQ LINEAR + LINEAR Q + Q VIEW + VIEW PERMUTE + PERMUTE TRANSPOSE + DQ DQ + Matmul + DIV + ADD + SOFTMAX + + It produces a DQ->LINEAR->Q pattern which can be fused by backend. + """ + + def is_view_op(node): + return node.op == "call_function" and node.target in _VIEW_OPS + + for node in graph_module.graph.nodes: + # Leslie: Here we verify that the quant node has exactly + # one input FX node, with constant scalar value for scale and zero point. + # For the case input of quant node has more than one input FX nodes, + # extend the implementation to lift up all the connected nodes + # before the view nodes to keep the topological order. + if ( + node.op == "call_function" + and node.target in _PER_TENSOR_QUANTIZE_OPS + and len(node.all_input_nodes) == 1 + and is_view_op(node.all_input_nodes[0]) + ): + quant_node = node + input_node_of_quant = quant_node.args[0] + + # Check the nodes along lift up path has only 1 user node + # Propagate view like node to find where to insert the new quant node + could_lift_up = True + current_node = quant_node + input_node = current_node.args[0] + while is_view_op(input_node): + if len(input_node.users) != 1: + could_lift_up = False + break + current_node = input_node + input_node = current_node.args[0] + + # Further check the input node of the first view node has only 1 user node + if could_lift_up and len(input_node.users) == 1: + # Replace dequant's input from quant to quant's input + quant_node.replace_all_uses_with(input_node_of_quant) + # Insert the new quant node + with graph_module.graph.inserting_before(current_node): + new_quant_node = graph_module.graph.node_copy(quant_node) + input_node.replace_all_uses_with(new_quant_node) + + # Update inputs of new_quant_node + def maybe_replace_node(n: torch.fx.Node) -> torch.fx.Node: + if n == input_node_of_quant: + return input_node + else: + return n + + new_args = map_arg(new_quant_node.args, maybe_replace_node) + new_kwargs = map_arg(new_quant_node.kwargs, maybe_replace_node) + new_quant_node.args = new_args # type: ignore[assignment] + new_quant_node.kwargs = new_kwargs # type: ignore[assignment] + graph_module.graph.erase_node(quant_node) + + graph_module.graph.lint() + graph_module.recompile() diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/reinplace.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/reinplace.py new file mode 100644 index 0000000000000000000000000000000000000000..64c388d3a9d1871eef78050f71facdeae87b4822 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/reinplace.py @@ -0,0 +1,766 @@ +# mypy: allow-untyped-defs +import itertools +import logging +import operator +from collections import defaultdict +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Any, Callable, cast, Union + +import torch +import torch.fx.node +from torch._C._dynamo.guards import compute_overlapping_tensors +from torch._dispatch.python import enable_python_dispatcher +from torch._dynamo.utils import ReinplaceCounters, ReInplaceTrigger +from torch._higher_order_ops.triton_kernel_wrap import ( + kernel_side_table, + triton_kernel_wrapper_functional, +) +from torch._inductor import 
config, inductor_prims +from torch._inductor.fx_utils import get_node_storage, is_node_realized +from torch._inductor.lowering import ( + inplaceable_foreach_ops as inplaceable_foreach_ops_lowerings, +) +from torch._inductor.virtualized import V +from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode +from torch.fx.immutable_collections import immutable_dict, immutable_list +from torch.fx.passes.reinplace import _is_view_op +from torch.utils import _pytree as pytree +from torch.utils._ordered_set import OrderedSet + + +log = logging.getLogger(__name__) +aten = torch.ops.aten + + +@dataclass(frozen=True) +class InplaceableOp: + inplace_op: Callable[..., Any] + mutated_arg: int + extra_check: Callable[[torch.fx.Node], bool] = lambda node: True + + +_SCATTER_OP_TO_VIEW = { + torch.ops.aten.diagonal_scatter.default: torch.ops.aten.diagonal.default, + torch.ops.aten.select_scatter.default: torch.ops.aten.select.int, + torch.ops.aten.slice_scatter.default: torch.ops.aten.slice.Tensor, + torch.ops.aten.as_strided_scatter.default: torch.ops.aten.as_strided.default, +} +_VIEW_OP_TO_SCATTER = {v: k for k, v in _SCATTER_OP_TO_VIEW.items()} + + +def graph_call_function(graph: torch.fx.Graph, fn, *args, **kwargs): + fake_args, fake_kwargs = pytree.tree_map( + lambda node: node.meta["val"] if isinstance(node, torch.fx.Node) else node, + (args, kwargs), + ) + with V.fake_mode: + fake_result = fn(*fake_args, **fake_kwargs) + + node = graph.call_function(fn, args, kwargs) + node.meta["val"] = fake_result + return node + + +@dataclass +class ViewOp: + target: torch._ops.OpOverload + args: tuple[Any, ...] + kwargs: dict[str, Any] + + +def _inplace_generalized_scatter( + inp: torch.Tensor, src: torch.Tensor, view_ops: list[ViewOp] +) -> torch.Tensor: + tmp = inp + for view in view_ops: + fake_args, fake_kwargs = pytree.tree_map( + lambda node: node.meta["val"] if isinstance(node, torch.fx.Node) else node, + (view.args, view.kwargs), + ) + tmp = view.target(tmp, *fake_args, **fake_kwargs) + try: + tmp.copy_(src) + except RuntimeError as e: + raise RuntimeError( + f"shape error in scatter op, can not broadcast {src.shape} to {tmp.shape}" + ) from e + return inp + + +def _generalized_scatter( + inp: torch.Tensor, src: torch.Tensor, view_ops: list[ViewOp] +) -> torch.Tensor: + out = inp.clone() + return _inplace_generalized_scatter(out, src, view_ops) + + +def _decompose_scatter_functional_helper( + graph: torch.fx.Graph, + inp: torch.Tensor, + src: torch.Tensor, + view_ops: list[ViewOp], +) -> torch.fx.Node: + view_op, view_ops_tail = view_ops[0], view_ops[1:] + + if view_ops_tail: + view = graph_call_function( + graph, view_op.target, inp, *view_op.args, **view_op.kwargs + ) + src = _decompose_scatter_functional_helper(graph, view, src, view_ops[1:]) # type: ignore[assignment] + + return graph_call_function( + graph, + _VIEW_OP_TO_SCATTER[view_op.target], + inp, + src, + *view_op.args, + **view_op.kwargs, + ) + + +def _decompose_scatter_functional( + graph: torch.fx.Graph, node: torch.fx.Node +) -> torch.fx.Node: + """Decompose _generalized_scatter to a sequence of view_scatter operations + + e.g. 
_generalized_scatter(inp, src, [(aten.slice, 0, 0, 10), (aten.slice, 1, 10, -10)]) + + will become + + view = aten.slice(inp, 0, 0, 10) + view_updated = aten.slice_scatter(view, src, 1, 10, -10) + inp_updated = aten.slice_scatter(inp, view_updated, 0, 0, 10) + """ + assert node.target is _generalized_scatter + return _decompose_scatter_functional_helper(graph, *node.args) # type: ignore[arg-type] + + +def _decompose_scatter_mutating( + graph: torch.fx.Graph, node: torch.fx.Node +) -> torch.fx.Node: + """Decompose _generalized_scatter using mutations + + e.g. _generalized_scatter(inp, src, [(aten.slice, 0, 0, 10), (aten.slice, 1, 10, -10)]) + + will become + + inp_updated = aten.clone(inp) + slice1 = aten.slice(inp_updated, 0, 0, 10) + slice2 = aten.slice(slice1, 1, 10, -10) + slice2.copy_(src) + + """ + assert node.target in (_generalized_scatter, _inplace_generalized_scatter) + inp, src, view_ops = node.args + assert not node.kwargs + + if node.target is _generalized_scatter: + inp = graph_call_function(graph, aten.clone, inp) + + tmp = inp + for view in view_ops: # type: ignore[union-attr] + tmp = graph_call_function(graph, view.target, tmp, *view.args, **view.kwargs) # type: ignore[union-attr] + + graph_call_function(graph, aten.copy_.default, tmp, src) + return inp # type: ignore[return-value] + + +# View ops whose view_scatter op is lowered into mutations anyway, +# so is never a pessimisation to decompose. +_ALWAYS_MUTATING_SCATTER_OPS = OrderedSet( + [ + aten.as_strided.default, + aten.diagonal.default, + ] +) + + +def scatter_always_uses_mutation(node: torch.fx.Node) -> bool: + _, _, view_ops = node.args + view_ops = cast(Sequence[torch.fx.node.Argument], view_ops) + return any( + target in _ALWAYS_MUTATING_SCATTER_OPS + for view in view_ops + if isinstance(target := getattr(view, "target", None), torch._ops.OpOverload) + ) + + +def should_reinplace_scatter(node: torch.fx.Node) -> bool: + """Choose between mutating and functional scatter decompositions + + Reinplacing view scatter ops can be pessimising as it blocks fusion with the + input or output tensor computations. However, it is still profitable if the + input and output would have been realized anyway. 
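+
+    An illustrative (non-exhaustive) reading of the checks below: if the view
+    chain contains as_strided/diagonal, the mutating form is always chosen; if
+    both the input and the scatter result are realized anyway, or the result
+    is copied back into a graph input, mutation is also preferred; otherwise
+    the functional form is kept so fusions around the scatter are not blocked.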
+ + """ + inp, _src, _view_ops = node.args + + # Mutating scatter ops unconditionally realize input and output + if scatter_always_uses_mutation(node): + return True + + if is_node_realized(inp) and is_node_realized(node): # type: ignore[arg-type] + return True + + # If the output is copied back into the input, this forces both to be + # realized as the output is a user of the input + if inp.op in ("placeholder", "get_attr") and any( # type: ignore[union-attr] + user.target is aten.copy_.default and user.args[0] is inp for user in node.users + ): + return True + + # Otherwise, assume fusions will make functional variants profitable + return False + + +def decompose_generalized_scatter(graph: torch.fx.Graph) -> None: + """Replace _generalized_scatter with normal aten ops""" + for node in itertools.chain( + graph.find_nodes(op="call_function", target=_generalized_scatter), + graph.find_nodes(op="call_function", target=_inplace_generalized_scatter), + ): + use_mutation = ( + node.target is _inplace_generalized_scatter + or scatter_always_uses_mutation(node) + ) + + with graph.inserting_before(node): + if use_mutation: + new_node = _decompose_scatter_mutating(graph, node) + else: + new_node = _decompose_scatter_functional(graph, node) + + node.replace_all_uses_with(new_node) + graph.erase_node(node) + + +def canonicalize_view_scatter_ops(graph: torch.fx.Graph) -> None: + """ + This canonicalizes view scatter ops into a generalized form, defined as: + def scatter(inp, src, views): + tmp = inp.clone() + for view in views: + tmp = view(tmp) + tmp.copy_(src) + + We also fuse consecutive view scatter ops of the form + a = scatter(view2(self), src, [view1]) + b = scatter(self, a, [view2]) + which can be rewritten as + b = scatter(self, src, [view2, view1]) + a = view2(b) + + This is both more efficient as we only do a single scatter, and also + easier to reinplace since there is only one use of `self` + """ + + node_to_view_base: dict[torch.fx.Node, torch.fx.Node] = {} + node_to_view_op: dict[torch.fx.Node, list[ViewOp]] = defaultdict(list) + + def handle_views(node: torch.fx.Node): + inp = node.args[0] + node_to_view_base[node] = node_to_view_base.get(inp, inp) # type: ignore[arg-type, assignment] + node_to_view_op[node] = [ + *node_to_view_op[inp], # type: ignore[index] + ViewOp( + node.target, # type: ignore[arg-type] + args=node.args[1:], + kwargs=node.kwargs, + ), + ] + + def handle_view_scatter(node: torch.fx.Node): + assert len(node.args) >= 2 + inp, src = node.args[:2] + + assert isinstance(node.target, torch._ops.OpOverload) + scatter_view_op = ViewOp( + _SCATTER_OP_TO_VIEW[node.target], + args=node.args[2:], + kwargs=node.kwargs, + ) + + def can_fuse(): + if src.target is not _generalized_scatter: # type: ignore[union-attr] + return False + src_inp, _src_src, _src_scatter_view_op = src.args # type: ignore[union-attr] + + inp_base = node_to_view_base.get(inp, inp) # type: ignore[arg-type] + src_base = node_to_view_base.get(src_inp, src_inp) # type: ignore[arg-type] + return inp_base is src_base and node_to_view_op[src_inp] == [ # type: ignore[index] + *node_to_view_op[inp], # type: ignore[index] + scatter_view_op, + ] + + if not can_fuse(): + with graph.inserting_before(node): + new_node = graph_call_function( + graph, + _generalized_scatter, + inp, + src, + [scatter_view_op], + ) + node.replace_all_uses_with(new_node) + graph.erase_node(node) + return + + _src_inp, src_src, src_scatter_view_op = src.args # type: ignore[union-attr] + with graph.inserting_before(src): # type: 
ignore[arg-type] + new_node = graph_call_function( + graph, + _generalized_scatter, + inp, + src_src, + [scatter_view_op, *src_scatter_view_op], # type: ignore[misc] + ) + node.replace_all_uses_with(new_node) + graph.erase_node(node) + + if src.users: # type: ignore[union-attr] + new_src = graph_call_function( + graph, + _SCATTER_OP_TO_VIEW[node.target], + new_node, + *node.args[2:], + **node.kwargs, + ) + + handle_views(new_src) + src.replace_all_uses_with(new_src) # type: ignore[union-attr] + + graph.erase_node(src) # type: ignore[arg-type] + + for node in graph.nodes: + if _is_view_op(node.target): + handle_views(node) + elif node.target in _SCATTER_OP_TO_VIEW: + handle_view_scatter(node) + + +inplaceable_ops: dict[Callable[..., Any], InplaceableOp] = { + aten.index_put.default: InplaceableOp(aten.index_put_.default, 0), + aten._unsafe_index_put.default: InplaceableOp(inductor_prims._unsafe_index_put_, 0), + _generalized_scatter: InplaceableOp( + _inplace_generalized_scatter, + 0, + extra_check=should_reinplace_scatter, + ), +} + +try: + c10d_functional = torch.ops._c10d_functional + inplaceable_collective_ops: dict[Callable[..., Any], InplaceableOp] = { + c10d_functional.all_reduce.default: InplaceableOp( + c10d_functional.all_reduce_.default, 0 + ), + c10d_functional.all_reduce_coalesced.default: InplaceableOp( + c10d_functional.all_reduce_coalesced_.default, 0 + ), + } + inplaceable_ops.update(inplaceable_collective_ops) +except AttributeError: + # _c10d_functional ops are only available when torch + # is built with USE_DISTRIBUTED=1. + pass + +inplaceable_foreach_ops: dict[torch._ops.OpOverload, InplaceableOp] = {} +for outplace_op, inplace_op in inplaceable_foreach_ops_lowerings.items(): + inplaceable_foreach_ops[outplace_op] = InplaceableOp(inplace_op, 0) + + +inplaceable_triton_ops = OrderedSet([triton_kernel_wrapper_functional]) + + +# Operators that don't depend on the tensor data +META_ONLY_OPS = OrderedSet( + [ + aten.sym_size.int, + aten.sym_stride.int, + aten.sym_numel.default, + aten.sym_storage_offset.default, + ] +) + + +def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None: + """ + Reinplaces in-placeable operations. + If there are no uses of a view of the mutated arg after the current node, + it is possible to inplace the op. + This above algorithm could be justified by observing side effects. While + we traverse the graph in forwards direction, only latter nodes could view + side effects of the current node. If the current node is not used later as + well as no view of this node is used later in the graph, then it is safe to + inplace as there would be no way to observe the side effects. + This condition is slightly different for graph inputs where they can only + be inplaced if the above condition is true and there's a copy_ in the + epilogue that signals that the caller wants to observe the mutation. + + Unlike JIT Inductor, AOTInductor currently unlifts weights and buffers from + input args, so instead of checking mutation on placeholder, AOTInductor + checks mutation on get_attr. This is subject to change in future. + """ + + copy_args_to_copy_nodes = {} + # maps argument to the first copy_ node that mutates it. 
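+    # (Illustrative note: for an epilogue like `buf.copy_(result)` (example
+    # names) where `buf` is a graph input, copy_nodes[buf] is that copy_ node;
+    # can_inplace() later uses it to check that no alias of `buf` is read
+    # between the mutating op and this copy_ before allowing reinplacement.)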
+ copy_nodes = {} + mutated_inputs = OrderedSet[Any]() + storage_to_nodes = defaultdict(list) + node_order: dict[Any, int] = {} + for i, node in enumerate(reversed(graph.nodes)): + node_order[node] = len(graph.nodes) - i - 1 + storage_to_nodes[get_node_storage(node)].append(node) + if node.target == aten.copy_.default and node.args[0].op in ( + "placeholder", + "get_attr", + ): + dst = node.args[0] + src = node.args[1] + # If the target is a getitem and it indexes a possible clone, + # then skip over it + if src.target == operator.getitem and ( + ( + src.args[0].target == triton_kernel_wrapper_functional + and src.args[0].kwargs["kwargs"][src.args[1]] == node.args[0] + ) + or (src.args[0].target in inplaceable_foreach_ops) + or (src.args[0].target == torch.ops.higher_order.auto_functionalized) + ): + src = src.args[0] + + copy_args_to_copy_nodes[(dst, src)] = node + copy_nodes[dst] = node + + mutated_inputs.add(node.args[0]) + + def any_use_of_views_after_node(node, shared_view_nodes, *, copy_node, mutated_arg): + node_loc = node_order[node] + copy_node_loc = node_order[copy_node] if copy_node is not None else None + + def is_meta_only_user(node): + if _is_view_op(node.target): + return all(is_meta_only_user(u) for u in node.users) + return node.target in META_ONLY_OPS + + for view in shared_view_nodes: + for user in view.users: + user_loc = node_order[user] + # Skip all users before node + if user_loc <= node_loc: + continue + # Ignore uses after the copy_ epilogue node, where the input + # has already been mutated anyway + if copy_node_loc is not None and copy_node_loc <= user_loc: + continue + # Reinplacing does not change shape metadata + if is_meta_only_user(user): + continue + # If our graph looks like: + # foo(mutated_arg) + # mutated_arg.copy_(other) + # then it's safe for us to reinplace foo because mutated_arg + # will get overwritten anyways. + if ( + user.target is torch.ops.aten.copy_.default + and mutated_arg is user.args[0] + ): + continue + return True + return False + + def can_inplace(node, mutated_arg): + # ls should be a list of tensors that all shares the same storage. + def _overlap(ls) -> bool: + try: + return len(compute_overlapping_tensors(ls)) != 0 + except GuardOnDataDependentSymNode: + # If we fail with data dependent error we assume they all overlap. + return True + + if isinstance(mutated_arg, (list, tuple)): + # TODO Using _overlap here causes a several issues. + unique_storages = OrderedSet(get_node_storage(arg) for arg in mutated_arg) + if len(unique_storages) != len(mutated_arg): + # At least two Tensors in mutated_arg alias each other, so we can't reinplace it. + # We can probably do better (that is, reinplace one of them and clone the other) + # but that requires more work and mutable List[Tensor] are not that common. + return False + return all(can_inplace(node, arg) for arg in mutated_arg) + + if get_node_storage(mutated_arg) is None: + return False + + shared_view_nodes = storage_to_nodes[get_node_storage(mutated_arg)] + + # Only keep tensor that might overlap with mutated_arg. + shared_view_nodes = [ + v + for v in shared_view_nodes + if _overlap([mutated_arg.meta["val"], v.meta["val"]]) + ] + + if mutated_arg.op in ("placeholder", "get_attr"): + # Get the first copy_ node that mutates the mutated_arg. + copy_node = copy_nodes.get(mutated_arg, None) + if copy_node is None: + # There is no copy_ back to the candidate mutated_arg (which is a graph input). 
+ # Therefore the semantics of the program are that it does not mutate + # mutated_arg, so we cannot re-inplace it. + return False + if any_use_of_views_after_node( + node, shared_view_nodes, copy_node=copy_node, mutated_arg=mutated_arg + ): + return False + + return True + elif any(view.op in ("placeholder", "get_attr") for view in shared_view_nodes): + # This should never happen in auto_functionalize_v2 non-inference mode, + # since all mutated_arg are bases. + + # If mutated arg is view of any of the inputs of the graph, + # do not allow for inplacing. + # This would require more sophisticated algorithm to handle + return False + else: + return not any_use_of_views_after_node( + node, shared_view_nodes, copy_node=None, mutated_arg=mutated_arg + ) + + def log_inplace_results( + node_name, + old_tensors_to_clone, + tensors_to_clone, + missed_args, + missed_nodes, + trigger, + ): + # Total size of possibly_missed_reinplacing_opportunities for tensors with static shapes. + missed_bytes = 0 + + def bytes(node): + t = node.meta.get("val", None) + if ( + t is not None + and isinstance(t.element_size(), int) + and isinstance(t.numel(), int) + ): + return t.element_size() * t.numel() + else: + return 0 + + for node in missed_nodes: + if isinstance(node, (list, tuple)): + for n in node: + missed_bytes += bytes(n) + else: + missed_bytes += bytes(node) + + log.info( + "For node %s, attempted to reinplace %s. We were unable to reinplace %s; " + "%s (if non-empty) are possible missed reinplacing opportunities that may be bad for " + "memory usage and performance. Total size of missed opportunities with static shapes is" + " : %s bytes.", + node_name, + old_tensors_to_clone, + tensors_to_clone, + missed_args, + missed_bytes, + ) + + ReinplaceCounters.add_missed_opportunities(trigger, len(missed_args)) + ReinplaceCounters.add_missed_bytes(trigger, missed_bytes) + + replace_dict: dict[torch.fx.Node, torch.fx.Node] = {} + + def reinplace_and_refine_tensors_to_clone( + old_tensors_to_clone, kwargs, node_name, trigger + ): + tensors_to_clone: list[str] = [] + storage_of_reinplaced_args = OrderedSet[Union[int, None]]() + + # Those used to count possibly_missed_reinplacing_opportunities + missed_nodes = [] + missed_args = [] + + # TODO this logic can be made more precise using _overlap + def tensor_with_same_storage_already_reinplaced(arg): + if isinstance(arg, (list, tuple)): + return any( + get_node_storage(a) in storage_of_reinplaced_args for a in arg + ) + return get_node_storage(mutated_arg) in storage_of_reinplaced_args + + for arg in old_tensors_to_clone: + assert arg in kwargs + + mutated_arg = kwargs[arg] + + # Let's say we have: + # - op(x, y) that mutates both x and y + # - new_x, new_y = functional_op(x, y) is the functional variant + # If we are presented with functional_op(x, x), we must not reinplace + # this into op(x, x), because then it would be writing to the same Tensor. + # Instead, it's OK to reinplace one of them and to clone the other: + # >>> y = x.clone() + # >>> op(x, y) + # This also applies if we have views: functional_op(x, x[0]) + # should not reinplace into op(x, x[0]). + should_attempt_reinplace = not tensor_with_same_storage_already_reinplaced( + mutated_arg + ) + if should_attempt_reinplace and can_inplace(node, mutated_arg): + # In general, we probably do not need those optimizations. 
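+                # (Added note: the lookup below checks whether an explicit
+                # `mutated_arg.copy_(...)` epilogue was recorded for this
+                # (arg, op) pair; once the op is reinplaced that copy_ becomes
+                # redundant, so it is queued in replace_dict and erased at the
+                # end of the pass.)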
+ copy_node = copy_args_to_copy_nodes.get((mutated_arg, node)) + if copy_node is not None: + replace_dict[copy_node] = copy_node.args[0] + if not trigger == ReInplaceTrigger.AUTO_FUNC_V2: + for user in node.users: + # For auto_functionalize_v2, arg is the index of the base, where base at index i corresponds to + # output atindex size(out)+i. + # This used to compare string with integers before for auto_functionalize_v2. Not sure + # if it was needed for inplaceable_triton_ops? + if user.target == operator.getitem and user.args[1] == arg: + replace_dict[user] = mutated_arg + + if isinstance(mutated_arg, (list, tuple)): + for a in mutated_arg: + storage_of_reinplaced_args.add(get_node_storage(a)) + else: + storage_of_reinplaced_args.add(get_node_storage(mutated_arg)) + else: + if should_attempt_reinplace: + missed_args.append(arg) + missed_nodes.append(mutated_arg) + + tensors_to_clone.append(arg) + + log_inplace_results( + node_name, + old_tensors_to_clone, + tensors_to_clone, + missed_args, + missed_nodes, + trigger, + ) + return tensors_to_clone + + for node in graph.nodes: + if (inplaceable_op := inplaceable_ops.get(node.target, None)) is not None: + mutated_arg = node.args[inplaceable_op.mutated_arg] + if can_inplace(node, mutated_arg) and inplaceable_op.extra_check(node): + # TODO(yifu): this doesn't properly remove copy epilogues for + # ops that mutate multiple inputs. Need to revise the copy + # node tracking logic to support the case. + copy_node = copy_args_to_copy_nodes.get((mutated_arg, node)) + if copy_node is not None: + replace_dict[copy_node] = copy_node.args[0] + node.target = inplaceable_op.inplace_op + elif node.target == torch.ops.higher_order.auto_functionalized_v2: + _mutable_op = node.args[0] + kwargs = node.kwargs + + all_bases = kwargs["_all_bases"] + bases_to_clone = range(len(all_bases)) + base_tensors_dct = dict(enumerate(all_bases)) + new_bases_to_clone: list[int] = reinplace_and_refine_tensors_to_clone( + bases_to_clone, + base_tensors_dct, + node.target, + ReInplaceTrigger.AUTO_FUNC_V2, + ) + # Stash the metadata. There is a pass later on where we decompose + # auto_functionalized into clones + a mutable op; this metadata + # tells the decomp to only clone the following inputs + node.meta["only_clone_these_tensors"] = new_bases_to_clone + elif node.target == torch.ops.higher_order.auto_functionalized: + _mutable_op = node.args[0] + from torch._higher_order_ops.auto_functionalize import get_mutable_args + + tensors_to_clone, _ = get_mutable_args(_mutable_op) + # Don't try to reinplace Optional[Tensor] args that are None. + tensors_to_clone = [ + t for t in tensors_to_clone if node.kwargs[t] is not None + ] + tensors_to_clone = reinplace_and_refine_tensors_to_clone( + tensors_to_clone, + node.kwargs, + _mutable_op._name, + ReInplaceTrigger.AUTO_FUNC_V1, + ) + + # Stash the metadata. 
There is a pass later on where we decompose + # auto_functionalized into clones + a mutable op; this metadata + # tells the decomp to only clone the following inputs + node.meta["only_clone_these_tensors"] = tensors_to_clone + elif node.target in inplaceable_triton_ops: + kernel_idx = node.kwargs["kernel_idx"] + kernel = kernel_side_table.get_kernel(kernel_idx) + from triton.runtime.autotuner import Autotuner + from triton.runtime.jit import JITFunction + + if isinstance(kernel, JITFunction): + kernel_name = kernel.fn.__name__ + elif isinstance(kernel, Autotuner): + if config.is_fbcode(): + # Autotuner has different implementations for AMD and NV + if torch.version.hip is None: + kernel_name = kernel.base_fn.__name__ + else: + kernel_name = kernel.fn.__name__ + else: + kernel_name = kernel.base_fn.__name__ + else: + raise AssertionError("Unknown triton kernel type") + + # inplaceable_triton_ops take an additional argument called + # tensors_to_clone which contain a list of tensors to clone + # This pass iterates over them and sees which ones are safe + # to eliminate (i.e. no longer need the clones) + tensors_to_clone = reinplace_and_refine_tensors_to_clone( + node.kwargs["tensors_to_clone"], + node.kwargs["kwargs"], + kernel_name, + ReInplaceTrigger.TRITON_OPS, + ) + + kwargs = dict(node.kwargs) + kwargs["tensors_to_clone"] = tensors_to_clone + node.kwargs = immutable_dict(kwargs) + if "eager_input_vals" in node.meta: + # We changed the kwargs, so we need to update eager_input_vals + # to something sane. + args, kwargs = node.meta["eager_input_vals"] + new_kwargs = {**kwargs} + new_kwargs["tensors_to_clone"] = immutable_list(tensors_to_clone) + new_kwargs = immutable_dict(new_kwargs) + node.meta["eager_input_vals"] = (args, new_kwargs) + elif ( + inplaceable_op := inplaceable_foreach_ops.get(node.target, None) + ) is not None: + mutated_args = node.args[inplaceable_op.mutated_arg] + + if not all((arg, node) in copy_args_to_copy_nodes for arg in mutated_args): + continue + + if can_inplace(node, mutated_args): + for arg in mutated_args: + copy_node = copy_args_to_copy_nodes[(arg, node)] + replace_dict[copy_node] = copy_node.args[0] + + node.target = inplaceable_op.inplace_op + for node, replacement in replace_dict.items(): + while replacement in replace_dict: + replacement = replace_dict[replacement] + replace_dict[node] = replacement + + node.replace_all_uses_with(replacement) + graph.erase_node(node) + + +def reinplace_inplaceable_ops(graph: torch.fx.Graph) -> None: + with enable_python_dispatcher(): + canonicalize_view_scatter_ops(graph) + reinplace_inplaceable_ops_core(graph) + decompose_generalized_scatter(graph) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/replace_random.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/replace_random.py new file mode 100644 index 0000000000000000000000000000000000000000..27e97eaa553251f2b5ca197a47f7473dd65eba29 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/replace_random.py @@ -0,0 +1,143 @@ +# mypy: allow-untyped-defs +import collections +import logging + +import torch +from torch.fx.passes.graph_transform_observer import GraphTransformObserver +from torch.fx.passes.shape_prop import _extract_tensor_metadata + +from .. 
import config, inductor_prims +from ..pattern_matcher import ( + CallFunctionVarArgs, + Match, + PatternMatcherPass, + register_graph_pattern, +) +from ..virtualized import V + + +log = logging.getLogger(__name__) +patterns = PatternMatcherPass() +aten = torch.ops.aten + + +def replace_random_passes(gm: torch.fx.GraphModule): + """Modify the given FX graph to use backend-native random ops""" + if config.fallback_random: + return 0 + + count = patterns.apply(gm) + with GraphTransformObserver(gm, "fuse_seed_creation_pass"): + count += fuse_seed_creation_pass(gm.graph) + + return count + + +def fuse_seed_creation_pass(graph: torch.fx.Graph): + """ + Horizontally fuse all the seed generation on each device + + a = inductor_seed(dev) + b = inductor_seed(dev) + + Becomes: + seeds = inductor_seeds(2, dev) + a = inductor_lookup_seed(seeds, 0) + b = inductor_lookup_seed(seeds, 1) + + We do this because seed creation is entirely launch overhead bound. + """ + device_seeds = collections.defaultdict(list) + for node in graph.nodes: + if CallFunctionVarArgs(inductor_prims.seed).match(node): + device_seeds[node.args[0]].append(node) + + if not device_seeds: + return 0 + + for device, seeds in device_seeds.items(): + with graph.inserting_before(seeds[0]): + combined = graph.call_function(inductor_prims.seeds, (len(seeds), device)) + with V.fake_mode: + combined.meta["val"] = torch.empty( + [len(seeds)], device=device, dtype=torch.int64 + ) + combined.meta["tensor_meta"] = _extract_tensor_metadata( + combined.meta["val"] + ) + + for idx, seed in enumerate(seeds): + with graph.inserting_before(seed): + new_seed = graph.call_function( + inductor_prims.lookup_seed, (combined, idx) + ) + seed.replace_all_uses_with(new_seed) + new_seed.meta.update(seed.meta) + graph.erase_node(seed) + + return len(device_seeds) + + +def default_kwargs(device): + return {} + + +def get_device(device): + if device is not None: + return device + return torch.empty([]).device # default device + + +@register_graph_pattern(CallFunctionVarArgs(aten.rand.default), pass_dict=patterns) +@register_graph_pattern(CallFunctionVarArgs(aten.rand.generator), pass_dict=patterns) +@register_graph_pattern(CallFunctionVarArgs(aten.randn.default), pass_dict=patterns) +@register_graph_pattern(CallFunctionVarArgs(aten.randn.generator), pass_dict=patterns) +def replace_random( + match: Match, + size, + *, + generator=None, + dtype=None, + device=None, + layout=None, + pin_memory=None, +): + if generator is not None: + return + + def replacement(size): + result = inductor_prims.random( + size, inductor_prims.seed(device), mode, **default_kwargs(device) + ) + if dtype is not None: + result = result.to(dtype) + return result + + mode = { + aten.rand: "rand", + aten.randn: "randn", + }[ + match.output_node().target.overloadpacket # type: ignore[union-attr] + ] # type: ignore[union-attr] + device = get_device(device) + match.replace_by_example(replacement, [size]) + + +@register_graph_pattern(CallFunctionVarArgs(aten.randint.low), pass_dict=patterns) +def replace_randint( + match: Match, + low, + high, + size, + *, + dtype=torch.int64, + device=None, + layout=None, + pin_memory=None, +): + def replacement(low, high, size): + result = inductor_prims.randint(low, high, size, inductor_prims.seed(device)) + return result.to(dtype) + + device = get_device(device) + match.replace_by_example(replacement, [low, high, size]) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__init__.py 
b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py new file mode 100644 index 0000000000000000000000000000000000000000..6c8e6de3ff3cba5f0ebcec729c33061b04319d6a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py @@ -0,0 +1,174 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +div_Tensor_2 = 
CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_1_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, 
convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_1_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = 
CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_1_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py new file mode 100644 index 0000000000000000000000000000000000000000..567390838ede7dc4d4181f601f020e8066cb07b7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py @@ -0,0 +1,205 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored()) +view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2) +sum_dim_IntList_1 = 
CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +view_default_8 = CallFunction(aten.view.default, fma_default, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_10_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) 
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_10_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +view_default_8 = CallFunction(aten.view.default, convert_element_type_default_3, Ignored(), 
_users=2)
+permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
+bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
+view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
+div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored())
+permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored())
+permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
+bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
+view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
+permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
+permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
+permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
+bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
+view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
+permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
+_sfdp_pattern_10_half_training = MultiOutputPattern([view_default_5,
+  permute_default_6,
+  permute_default_9,
+  permute_default_11
+])
+
+
+permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
+div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored())
+expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored())
+clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
+view_default = CallFunction(aten.view.default, clone_default, Ignored())
+permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
+clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
+view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
+expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
+permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
+expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
+clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
+view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+_sfdp_pattern_10_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_11.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_11.py
new file mode 100644
index 0000000000000000000000000000000000000000..6aa39474c67dd677008c8e7e9266cc875a153196
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_11.py
@@ -0,0 +1,204 @@
+# mypy: ignore-errors
+
+# noqa: F401, E501
+# This is an auto-generated file. Please do not modify it by hand.
+# To re-generate, run:
+# cd ~/pytorch && python torchgen/fuse/gen_patterns.py
+
+import torch
+import torch._inductor
+import operator
+
+aten = torch.ops.aten
+prims = torch.ops.prims
+
+from torch._inductor.pattern_matcher import (
+   Arg,
+   CallFunction,
+   CallFunctionVarArgs,
+   CallMethod,
+   CallMethodVarArgs,
+   CallModule,
+   CallModuleVarArgs,
+   ExclusiveKeywordArg,
+   Ignored,
+   KeywordArg,
+   ListOf,
+   MultiOutputPattern,
+   PatternExpr,
+   RepeatedExpr,
+   _TargetArgsExpr,
+   _TargetExpr,
+   _TargetExprVarArgs,
+)
+permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
+expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
+clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
+view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
+permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
+clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
+view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2)
+amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
+expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
+permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
+expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
+clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
+view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
+neg_default = CallFunction(aten.neg.default, div_Tensor_1)
+view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
+permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
+bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
+view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
+mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2)
+sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
+fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
+div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale'))
+view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
+permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
+bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
+view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
+permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
+permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
+bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
+view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
+permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
+permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
+permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
+bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
+view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
+permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
+_sfdp_pattern_11_training = MultiOutputPattern([view_default_5,
+  permute_default_6,
+  permute_default_9,
+  permute_default_11,
+  None
+])
+
+
+permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
+expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
+clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
+view_default = CallFunction(aten.view.default, clone_default, Ignored())
+permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
+clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
+view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2)
+amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
+permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
+expand_default_3 = CallFunction(aten.expand.default, permute_default_3,
Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_11_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_2, 
_users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_11_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = 
CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_11_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py new file mode 100644 index 0000000000000000000000000000000000000000..87302d1bab3694a33eac14e263dca86c9f702c75 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py @@ -0,0 +1,220 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, 
Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale_factor')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_12_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, 
memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_12_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, 
mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale_factor')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_12_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, 
Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_12_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_13.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_13.py new file mode 100644 index 0000000000000000000000000000000000000000..d465c1cb4e22b14bfbbdd35d5ed28a43af891523 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_13.py @@ -0,0 +1,130 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, KeywordArg('query'), permute_default, _users=2) +amax_default = CallFunction(aten.amax.default, bmm_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, bmm_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, mul_Tensor_1, KeywordArg('value')) +neg_default = CallFunction(aten.neg.default, div_Tensor) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, KeywordArg('tangents_1'), permute_default_1) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, bmm_default_2, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4, _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, fma_default, permute_default_2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, fma_default) +permute_default_4 = CallFunction(aten.permute.default, bmm_default_4, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, mul_Tensor_1, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, KeywordArg('tangents_1')) +_sfdp_pattern_13_training = MultiOutputPattern([bmm_default_1, + bmm_default_3, + permute_default_4, + bmm_default_5, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +bmm_default = CallFunction(aten.bmm.default, KeywordArg('query'), permute_default, _users=2) +amax_default = CallFunction(aten.amax.default, bmm_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, bmm_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, 
exp_default, sum_dim_IntList) +_sfdp_pattern_13_inference = CallFunction(aten.bmm.default, div_Tensor, KeywordArg('value'), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, KeywordArg('query'), permute_default) +convert_element_type_default = CallFunction(prims.convert_element_type.default, bmm_default, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, mul_Tensor_1, KeywordArg('value')) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, KeywordArg('tangents_1'), permute_default_1) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, bmm_default_2, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, convert_element_type_default_5, permute_default_2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, convert_element_type_default_5) +permute_default_4 = CallFunction(aten.permute.default, bmm_default_4, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, mul_Tensor_1, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, KeywordArg('tangents_1')) +_sfdp_pattern_13_half_training = MultiOutputPattern([bmm_default_1, + bmm_default_3, + permute_default_4, + bmm_default_5, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +bmm_default = CallFunction(aten.bmm.default, KeywordArg('query'), permute_default) 
+convert_element_type_default = CallFunction(prims.convert_element_type.default, bmm_default, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +_sfdp_pattern_13_half_inference = CallFunction(aten.bmm.default, convert_element_type_default_1, KeywordArg('value'), _users=0) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_14.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_14.py new file mode 100644 index 0000000000000000000000000000000000000000..f102038e82c6d5858b8b334e956d87c7e86a9d22 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_14.py @@ -0,0 +1,210 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) 
+permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_14_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, 
view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_14_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = 
CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_14_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) 
+div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_14_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_15.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_15.py new file mode 100644 index 0000000000000000000000000000000000000000..e1cbb0df340bab14188ec9d5f04a29035dba8d84 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_15.py @@ -0,0 +1,230 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored(), _users=2) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_3 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +view_default_6 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_7 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_7, permute_default_4) 
+view_default_8 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_8, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, fma_default) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, KeywordArg('inv_scale')) +view_default_9 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_2, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_9, permute_default_5) +view_default_10 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_9) +view_default_11 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_11, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_7) +view_default_12 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_12, Ignored()) +_sfdp_pattern_15_training = MultiOutputPattern([view_default_6, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = 
CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_3 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +_sfdp_pattern_15_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored(), _users=2) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = 
CallFunction(aten.bmm.default, view_default_4, view_default_5) +view_default_6 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_7 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_7, permute_default_4) +view_default_8 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_8, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, convert_element_type_default_4) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, KeywordArg('inv_scale')) +view_default_9 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_2, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_9, permute_default_5) +view_default_10 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_9) +view_default_11 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_11, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_7) +view_default_12 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_12, Ignored()) +_sfdp_pattern_15_half_training = MultiOutputPattern([view_default_6, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored()) 
+permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +_sfdp_pattern_15_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_16.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_16.py new file mode 100644 index 0000000000000000000000000000000000000000..3a15abb9088ff5ffe8cd9af43df11ccc0d5bc143 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_16.py @@ -0,0 +1,599 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, 
gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_16_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, 
KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_16_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) 
+fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_16_bs1_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_16_bs1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, 
KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_4 = 
CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_16_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) 
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_16_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) 
+view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_16_half_bs1_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) 
+amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_16_half_bs1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, 
memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_1, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_3, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_16_half_mask_fp32_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, 
memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_16_half_mask_fp32_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, 
expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_1, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_3, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_16_half_mask_fp32_bs1_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) 
+view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_16_half_mask_fp32_bs1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_17.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_17.py new file mode 100644 index 0000000000000000000000000000000000000000..812708907b3414e2c864ed36b98f2199e63ae5d2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_17.py @@ -0,0 +1,246 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored(), _users=2) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +view_default_6 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, 
device=Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_7 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_7, permute_default_4) +view_default_8 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_8, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, fma_default) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, KeywordArg('inv_scale')) +view_default_9 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_2, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_9, permute_default_5) +view_default_10 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_9) +view_default_11 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_11, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_7) +view_default_12 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_12, Ignored()) +_sfdp_pattern_17_training = MultiOutputPattern([view_default_6, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, 
clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_3 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +_sfdp_pattern_17_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored(), _users=2) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, 
exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +view_default_6 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_7 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_7, permute_default_4) +view_default_8 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_8, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, convert_element_type_default_5) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, KeywordArg('inv_scale')) +view_default_9 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_2, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_9, permute_default_5) +view_default_10 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_9) +view_default_11 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_11, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, 
permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_7) +view_default_12 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_12, Ignored()) +_sfdp_pattern_17_half_training = MultiOutputPattern([view_default_6, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +_sfdp_pattern_17_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_18.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_18.py new 
file mode 100644 index 0000000000000000000000000000000000000000..567d898ed204257e23a2002479afc5d26cba623b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_18.py @@ -0,0 +1,453 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) 
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), fma_default, scalar_tensor_default) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_18_training = MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) 
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_18_inference = MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3 +]) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, 
gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), fma_default, scalar_tensor_default) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_18_bs1_training = MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = 
CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_18_bs1_inference = MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3 +]) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) 
+convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), convert_element_type_default_5, scalar_tensor_default) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = 
CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_18_half_training = MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = 
CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_18_half_inference = MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3 +]) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = 
CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), convert_element_type_default_5, scalar_tensor_default) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_18_half_bs1_training = MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), 
device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_18_half_bs1_inference = MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3 +]) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_19.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_19.py new file mode 100644 index 0000000000000000000000000000000000000000..5c6d316351b8595f75fdfd262a4cb2171a8a6b1e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_19.py @@ -0,0 +1,209 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) +add_Tensor = CallFunction(aten.add.Tensor, where_self, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, 
div_Tensor_1, _users=2)
+sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
+fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
+scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored())
+where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), fma_default, scalar_tensor_default)
+div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default)
+view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
+permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
+bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
+view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
+permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
+bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
+view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
+permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
+permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
+bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
+view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
+_sfdp_pattern_19_training = MultiOutputPattern([view_default_5,
+ view_default_9,
+ permute_default_4,
+ view_default_11,
+ None,
+ None,
+ None
+])
+
+
+expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
+view_default = CallFunction(aten.view.default, expand_default, Ignored())
+permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
+view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default)
+full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1)
+add_Tensor = CallFunction(aten.add.Tensor, where_self, KeywordArg('attn_mask'), _users=2)
+amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
+expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
+view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+_sfdp_pattern_19_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
+
+
+rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
+expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
+view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
+permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
+view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2)
+div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default)
+full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1)
+add_Tensor = CallFunction(aten.add.Tensor, where_self, KeywordArg('attn_mask'), _users=2)
+amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
+convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
+mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default)
+mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
+expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
+expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
+view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
+neg_default = CallFunction(aten.neg.default, div_Tensor_1)
+view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
+permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
+bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
+view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
+mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_1, Ignored())
+mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
+convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored())
+mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2)
+sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
+fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
+convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
+scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored())
+where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), convert_element_type_default_3, scalar_tensor_default)
+div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default)
+view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
+permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
+bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
+view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
+permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
+bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
+view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
+permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
+permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
+bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
+view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
+_sfdp_pattern_19_half_training = MultiOutputPattern([view_default_5,
+ view_default_9,
+ permute_default_4,
+ view_default_11,
+ None,
+ None,
+ None
+])
+
+
+expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
+view_default = CallFunction(aten.view.default, expand_default, Ignored())
+permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
+view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default)
+full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1)
+add_Tensor = CallFunction(aten.add.Tensor, where_self, KeywordArg('attn_mask'), _users=2)
+amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
+expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
+expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
+view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+_sfdp_pattern_19_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
diff --git
a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_2.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_2.py new file mode 100644 index 0000000000000000000000000000000000000000..f28da434ef0c85ca3d80095e68c052e8dc19dd2d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_2.py @@ -0,0 +1,174 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_1, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_1) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, fma_default, KeywordArg('scale_factor')) +view_default_8 = CallFunction(aten.view.default, mul_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) 
+bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_2_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_2_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, 
exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_1, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_1) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, KeywordArg('scale_factor')) +view_default_8 = CallFunction(aten.view.default, mul_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_2_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor')) +convert_element_type_default = 
CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_2_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_20.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_20.py new file mode 100644 index 0000000000000000000000000000000000000000..9185aa3b1e3305cfa28f8080be04350beb17c065 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_20.py @@ -0,0 +1,244 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored(), _users=2) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, 
memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +where_self = CallFunction(aten.where.self, expand_default, full_default, view_default_3, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +view_default_6 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_7 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_7, permute_default_4) +view_default_8 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_8, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, fma_default) +view_default_9 = CallFunction(aten.view.default, where_self_1, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_2, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_9, permute_default_5) +view_default_10 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_10, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_9) +view_default_11 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_11, 
Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_7) +view_default_12 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_12, Ignored()) +_sfdp_pattern_20_training = MultiOutputPattern([view_default_6, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +where_self = CallFunction(aten.where.self, expand_default, full_default, view_default_3, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_3 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +_sfdp_pattern_20_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored(), _users=2) 
+full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +where_self = CallFunction(aten.where.self, expand_default, full_default, view_default_3) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +view_default_6 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_7 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_7, permute_default_4) +view_default_8 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = 
CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_8, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, convert_element_type_default_5) +view_default_9 = CallFunction(aten.view.default, where_self_1, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_2, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_9, permute_default_5) +view_default_10 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_10, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_9) +view_default_11 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_11, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_7) +view_default_12 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_12, Ignored()) +_sfdp_pattern_20_half_training = MultiOutputPattern([view_default_6, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +where_self = 
CallFunction(aten.where.self, expand_default, full_default, view_default_3) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +_sfdp_pattern_20_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_21.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_21.py new file mode 100644 index 0000000000000000000000000000000000000000..ad27e6eb6bb8eba84a7c435b41cdbafae621e271 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_21.py @@ -0,0 +1,217 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_9, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +view_default_10 = CallFunction(aten.view.default, fma_default, Ignored()) +view_default_11 
= CallFunction(aten.view.default, view_default_10, Ignored()) +view_default_12 = CallFunction(aten.view.default, view_default_11, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) +view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_21_training = MultiOutputPattern([view_default_7, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +_sfdp_pattern_21_inference = 
CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_1, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_1, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_2, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_9, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) 
+convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +view_default_10 = CallFunction(aten.view.default, convert_element_type_default_4, Ignored()) +view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored()) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, view_default_11, Ignored()) +convert_element_type_default_6 = CallFunction(prims.convert_element_type.default, convert_element_type_default_5, Ignored()) +view_default_12 = CallFunction(aten.view.default, convert_element_type_default_6, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) +view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_21_half_training = MultiOutputPattern([view_default_7, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_1, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_1, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = 
CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_2, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +_sfdp_pattern_21_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_22.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_22.py new file mode 100644 index 0000000000000000000000000000000000000000..41a433e405433a1434c10700830ea8c81747a72b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_22.py @@ -0,0 +1,229 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, 
exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_9, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +view_default_10 = CallFunction(aten.view.default, fma_default, Ignored()) +view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored()) +view_default_12 = CallFunction(aten.view.default, view_default_11, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) +view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_22_training = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, 
permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_22_inference = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_1, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_1, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, 
exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_2, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_9, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +view_default_10 = CallFunction(aten.view.default, convert_element_type_default_4, Ignored()) +view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored()) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, view_default_11, Ignored()) +convert_element_type_default_6 = CallFunction(prims.convert_element_type.default, convert_element_type_default_5, Ignored()) +view_default_12 = CallFunction(aten.view.default, convert_element_type_default_6, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) +view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_22_half_training = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3, + permute_default_6, + 
permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_1, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_1, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_2, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_22_half_inference = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3 +]) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_23.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_23.py new file mode 100644 index 0000000000000000000000000000000000000000..dc6f27cd284924601b6119ae5ea245bed4188962 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_23.py @@ -0,0 +1,225 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
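Reader's note (an editorial aside, not a line of the generated file): the _sfdp_pattern_22_* graphs above serialize the decomposed form of additive-mask attention with no explicit scaling node. The permute/expand/clone/view chains reshape query, key and value for aten.bmm, aten.add.Tensor applies attn_mask, and the amax/sub/exp/sum/div sequence is a numerically stable softmax; the *_training variant additionally encodes the backward pass through tangents_1 and prims.fma. Ignoring the layout reshapes, the eager computation is roughly the sketch below; the function name and the use of torch.matmul are illustrative assumptions, not code from the file.

import torch

def masked_attention_like_pattern_22(query, key, value, attn_mask):
    # scores = q @ k^T plus an additive mask; the serialized graph has no scaling node
    scores = torch.matmul(query, key.transpose(-2, -1)) + attn_mask
    # the amax/sub/exp/sum/div chain in the graph is the stable form of this softmax
    attn = torch.softmax(scores, dim=-1)
    return torch.matmul(attn, value)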
+# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, view_default_2, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_9, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +view_default_10 = CallFunction(aten.view.default, fma_default, Ignored()) +view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored()) 
+view_default_12 = CallFunction(aten.view.default, view_default_11, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) +view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_23_training = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, view_default_2, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_23_inference = 
MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_2, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_2, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_3, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, view_default_9, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_4, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = 
CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +view_default_10 = CallFunction(aten.view.default, convert_element_type_default_5, Ignored()) +view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored()) +convert_element_type_default_6 = CallFunction(prims.convert_element_type.default, view_default_11, Ignored()) +convert_element_type_default_7 = CallFunction(prims.convert_element_type.default, convert_element_type_default_6, Ignored()) +view_default_12 = CallFunction(aten.view.default, convert_element_type_default_7, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) +view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_23_half_training = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_2, Ignored(), True) +sub_Tensor = 
CallFunction(aten.sub.Tensor, convert_element_type_default_2, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_3, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_23_half_inference = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3 +]) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py new file mode 100644 index 0000000000000000000000000000000000000000..2c7f7519ad0570d2c2f700d4081c9b7253d16657 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py @@ -0,0 +1,190 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
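Another aside, not part of the diff: _sfdp_pattern_23_* is the same decomposition with the aten.add.Tensor mask node removed, i.e. plain softmax(QKᵀ)V. A minimal sketch with an illustrative name:

import torch

def attention_like_pattern_23(query, key, value):
    # no mask and no scaling node appear in this serialized graph
    scores = torch.matmul(query, key.transpose(-2, -1))
    return torch.matmul(torch.softmax(scores, dim=-1), value)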
+# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale_factor')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = 
CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_3_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_3_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) 
+sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale_factor')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_3_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + 
view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_3_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py new file mode 100644 index 0000000000000000000000000000000000000000..dc9cfd506f950415f4f90b49edf83815432a641c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py @@ -0,0 +1,190 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
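Aside (not part of the generated file): _sfdp_pattern_3_* extends the basic decomposition in two ways: the scores are divided by the inv_scale_factor keyword, and dropout appears in decomposed form as aten.rand plus aten.gt.Scalar against dropout_p plus aten.mul. A rough eager equivalent, with illustrative names:

import torch
import torch.nn.functional as F

def attention_like_pattern_3(query, key, value, inv_scale_factor, dropout_p):
    # q @ k^T scaled by division, softmaxed, then dropped out before multiplying value
    scores = torch.matmul(query, key.transpose(-2, -1)).div(inv_scale_factor)
    attn = F.dropout(torch.softmax(scores, dim=-1), p=dropout_p)
    return torch.matmul(attn, value)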
+# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_2, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_3) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, mul_Tensor_4, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_5, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_5) +mul_Tensor_6 = CallFunction(aten.mul.Tensor, fma_default, KeywordArg('scale_factor')) +view_default_8 = CallFunction(aten.view.default, mul_Tensor_6, Ignored(), _users=2) +permute_default_2 = 
CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_4_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_4_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = 
CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_2, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_3) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_4, Ignored()) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_5, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_5) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +mul_Tensor_6 = CallFunction(aten.mul.Tensor, convert_element_type_default_5, KeywordArg('scale_factor')) +view_default_8 = CallFunction(aten.view.default, mul_Tensor_6, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_4_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + 
None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_4_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py new file mode 100644 index 0000000000000000000000000000000000000000..f211e56b17a0a19c05bcb0efc681ed2623f4edf7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py @@ -0,0 +1,178 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
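Aside: _sfdp_pattern_4_* has the same structure as pattern 3 except that the scores are multiplied by the scale_factor keyword (aten.mul.Tensor) instead of divided by inv_scale_factor; sketched roughly as:

import torch
import torch.nn.functional as F

def attention_like_pattern_4(query, key, value, scale_factor, dropout_p):
    # identical to the pattern-3 sketch apart from scaling by multiplication
    scores = torch.matmul(query, key.transpose(-2, -1)).mul(scale_factor)
    attn = F.dropout(torch.softmax(scores, dim=-1), p=dropout_p)
    return torch.matmul(attn, value)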
+# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = 
CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_5_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_5_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = 
CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_5_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, 
sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
+expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
+expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
+view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+_sfdp_pattern_5_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py
new file mode 100644
index 0000000000000000000000000000000000000000..01304bf415163909c5ec5b03064ce064697e1de9
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py
@@ -0,0 +1,194 @@
+# mypy: ignore-errors
+
+# noqa: F401, E501
+# This is an auto-generated file. Please do not modify it by hand.
+# To re-generate, run:
+# cd ~/pytorch && python torchgen/fuse/gen_patterns.py
+
+import torch
+import torch._inductor
+import operator
+
+aten = torch.ops.aten
+prims = torch.ops.prims
+
+from torch._inductor.pattern_matcher import (
+   Arg,
+   CallFunction,
+   CallFunctionVarArgs,
+   CallMethod,
+   CallMethodVarArgs,
+   CallModule,
+   CallModuleVarArgs,
+   ExclusiveKeywordArg,
+   Ignored,
+   KeywordArg,
+   ListOf,
+   MultiOutputPattern,
+   PatternExpr,
+   RepeatedExpr,
+   _TargetArgsExpr,
+   _TargetExpr,
+   _TargetExprVarArgs,
+)
+rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
+expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
+view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
+permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
+view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
+add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
+amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
+mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
+mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
+expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
+expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_6_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) 
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_6_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_4 = 
CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_6_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_6_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git 
a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py new file mode 100644 index 0000000000000000000000000000000000000000..b463c7e64a6130dd85063f5fb88c2317c392c8f2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py @@ -0,0 +1,221 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = 
CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored()) +view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_7_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) 
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_7_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) 
+permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_7_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = 
CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_7_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py new file mode 100644 index 0000000000000000000000000000000000000000..3faff67089b17ad370d4d7642539c7ce3fd5d235 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py @@ -0,0 +1,205 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored()) +view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, 
mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_8_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = 
CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_8_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_3, Ignored()) +view_default_8 = 
CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_8_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) 
+_sfdp_pattern_8_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bf77120e836a5b577ea8a335f00bd63fd27163a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py
@@ -0,0 +1,221 @@
+# mypy: ignore-errors
+
+# noqa: F401, E501
+# This is an auto-generated file. Please do not modify it by hand.
+# To re-generate, run:
+# cd ~/pytorch && python torchgen/fuse/gen_patterns.py
+
+import torch
+import torch._inductor
+import operator
+
+aten = torch.ops.aten
+prims = torch.ops.prims
+
+from torch._inductor.pattern_matcher import (
+   Arg,
+   CallFunction,
+   CallFunctionVarArgs,
+   CallMethod,
+   CallMethodVarArgs,
+   CallModule,
+   CallModuleVarArgs,
+   ExclusiveKeywordArg,
+   Ignored,
+   KeywordArg,
+   ListOf,
+   MultiOutputPattern,
+   PatternExpr,
+   RepeatedExpr,
+   _TargetArgsExpr,
+   _TargetExpr,
+   _TargetExprVarArgs,
+)
+rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
+permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
+div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored())
+expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored())
+clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
+view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
+permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
+expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
+clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
+view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2)
+amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True)
+sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default)
+exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
+mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
+mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
+convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored())
+expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
+permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
+expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
+clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
+view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored()) +view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +view_default_8 = CallFunction(aten.view.default, fma_default, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_9_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, 
memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_9_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, 
convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +view_default_8 = CallFunction(aten.view.default, convert_element_type_default_4, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_9_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) 
+view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_9_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/addmm_pattern.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/addmm_pattern.py new file mode 100644 index 0000000000000000000000000000000000000000..70d672442170905a411de63187a5b579b286bf73 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/addmm_pattern.py @@ -0,0 +1,53 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
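# Aside: the `_sfdp_pattern_9*` graphs serialized above appear to describe a
# scaled-dot-product-attention variant in which the permuted query is divided by a
# scale factor before the QK matmul, followed by a softmax written out as
# amax/sub/exp/sum/div and, in the training variants, a rand/gt dropout mask.
# A minimal eager-mode sketch of that reading follows; `inv_scale` and `dropout_p`
# are illustrative names and the head/batch permutes are omitted, so treat this as
# orientation only, not the exact source of the serialized pattern.
import torch
import torch.nn.functional as F

def sdpa_like(query, key, value, inv_scale, dropout_p=0.0, training=False):
    # (q / inv_scale) @ k^T -> softmax over the last dim -> optional dropout -> @ v
    scores = torch.matmul(query / inv_scale, key.transpose(-2, -1))
    probs = scores.softmax(dim=-1)
    if training and dropout_p > 0.0:
        probs = F.dropout(probs, p=dropout_p)
    return torch.matmul(probs, value)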
+# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +addmm_default = CallFunction(aten.addmm.default, KeywordArg('input'), KeywordArg('mat1'), KeywordArg('mat2'), beta=KeywordArg('beta'), alpha=KeywordArg('alpha')) +mul_Scalar = CallFunction(aten.mul.Scalar, KeywordArg('tangents_1'), KeywordArg('beta')) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, mul_Scalar, Ignored(), True) +view_default = CallFunction(aten.view.default, sum_dim_IntList, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('mat2'), Ignored()) +mm_default = CallFunction(aten.mm.default, KeywordArg('tangents_1'), permute_default) +mul_Scalar_1 = CallFunction(aten.mul.Scalar, mm_default, KeywordArg('alpha')) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('mat1'), Ignored()) +mm_default_1 = CallFunction(aten.mm.default, permute_default_1, KeywordArg('tangents_1')) +mul_Scalar_2 = CallFunction(aten.mul.Scalar, mm_default_1, KeywordArg('alpha')) +addmm_pattern_training = MultiOutputPattern([addmm_default, + view_default, + mul_Scalar_1, + mul_Scalar_2, + None, + None +]) + + +addmm_pattern_inference = CallFunction(aten.addmm.default, KeywordArg('input'), KeywordArg('mat1'), KeywordArg('mat2'), beta=KeywordArg('beta'), alpha=KeywordArg('alpha'), _users=0) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/bmm_pattern.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/bmm_pattern.py new file mode 100644 index 0000000000000000000000000000000000000000..7b5ac59d6f06c97523e071e9b3ea78516ff09c0e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/bmm_pattern.py @@ -0,0 +1,45 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
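# Aside: `addmm_pattern_training` above pairs the forward addmm with its manual
# gradients (the bias grad scaled by beta and reduced with keepdim, the matrix grads
# scaled by alpha). A small sketch of that correspondence, assuming a bias that is
# broadcast over the leading dimension; `grad_out` stands in for `tangents_1`:
import torch

def addmm_fwd_bwd(bias, mat1, mat2, grad_out, *, beta=1.0, alpha=1.0):
    out = torch.addmm(bias, mat1, mat2, beta=beta, alpha=alpha)
    grad_bias = (grad_out * beta).sum(dim=0, keepdim=True)  # reduce over the broadcast dim
    grad_mat1 = (grad_out @ mat2.t()) * alpha
    grad_mat2 = (mat1.t() @ grad_out) * alpha
    return out, grad_bias, grad_mat1, grad_mat2

# The bmm_pattern and mm_pattern files below follow the same forward-plus-backward
# layout, minus the beta/alpha scaling.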
+# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +bmm_default = CallFunction(aten.bmm.default, KeywordArg('mat1'), KeywordArg('mat2')) +permute_default = CallFunction(aten.permute.default, KeywordArg('mat2'), Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, KeywordArg('tangents_1'), permute_default) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('mat1'), Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, permute_default_1, KeywordArg('tangents_1')) +bmm_pattern_training = MultiOutputPattern([bmm_default, + bmm_default_1, + bmm_default_2 +]) + + +bmm_pattern_inference = CallFunction(aten.bmm.default, KeywordArg('mat1'), KeywordArg('mat2'), _users=0) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/mm_pattern.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/mm_pattern.py new file mode 100644 index 0000000000000000000000000000000000000000..058a2f881e3a52cb147cfd3fa0ef2bbd0a25945a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/serialized_patterns/mm_pattern.py @@ -0,0 +1,45 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2')) +permute_default = CallFunction(aten.permute.default, KeywordArg('mat2'), Ignored()) +mm_default_1 = CallFunction(aten.mm.default, KeywordArg('tangents_1'), permute_default) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('mat1'), Ignored()) +mm_default_2 = CallFunction(aten.mm.default, permute_default_1, KeywordArg('tangents_1')) +mm_pattern_training = MultiOutputPattern([mm_default, + mm_default_1, + mm_default_2 +]) + + +mm_pattern_inference = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'), _users=0) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/split_cat.py b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/split_cat.py new file mode 100644 index 0000000000000000000000000000000000000000..ea379b0115d6d6c050820ead26e6ccba0970a0e6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/split_cat.py @@ -0,0 +1,2960 @@ +# mypy: allow-untyped-defs +import itertools +import logging +import operator +from collections import defaultdict +from collections.abc import Sequence +from typing import Any, Callable, Optional, Union +from typing_extensions import TypeAlias + +import torch +from 
torch._dynamo.utils import counters +from torch.fx.experimental.symbolic_shapes import free_symbols +from torch.utils._ordered_set import OrderedSet + +from ..pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethodVarArgs, + FailedMatch, + get_arg_value, + Ignored, + KeywordArg, + ListOf, + Match, + MatchContext, + MULTIPLE, + PatternExpr, + PatternMatcherPass, + register_graph_pattern, + RepeatedExpr, +) +from .group_batch_fusion import is_node_meta_valid, POST_GRAD_FUSIONS, PRE_GRAD_FUSIONS + + +log = logging.getLogger(__name__) + +_Arguments: TypeAlias = tuple[torch.fx.node.Argument, ...] +_TransformParam: TypeAlias = tuple[ + Optional[_Arguments], + Optional[_Arguments], + Optional[_Arguments], + Optional[_Arguments], +] +_Range: TypeAlias = tuple[int, int] + + +PRE_GRAD_PATTERNS: dict[str, PatternMatcherPass] = {} +POST_GRAD_PATTERNS: dict[str, PatternMatcherPass] = {} + +pre_grad_pass_names = [ + "normalization_pass", + "remove_split_with_size_one_pass", + "merge_getitem_cat_pass", + "merge_stack_tahn_unbind_pass", + "merge_splits_pass", + "mutate_cat_pass", + "split_cat_pass", + "unbind_stack_pass", + "split_cat_to_slices_pass", + "unbind_cat_to_view_pass", + "split_stack_to_cats_pass", + "unbind_stack_to_slices_pass", + "move_reshape_out_of_split_stack_pass", +] + +post_grad_pass_names = [ + "normalization_aten_pass", + "decompose_mm_pass", + "unbind_stack_aten_pass", + "shape_padding_multiplier", + "pad_aten_mm_pass", + "split_cat_aten_pass", + "select_cat_aten_pass", + "move_view_after_cat_aten_pass", +] + +for pass_name in pre_grad_pass_names: + # exclude all passes from the group batch fusion + # they do not use pattern matcher + if pass_name in PRE_GRAD_FUSIONS: + continue + PRE_GRAD_PATTERNS[pass_name] = PatternMatcherPass( + pass_name=pass_name, + ) + +for pass_name in post_grad_pass_names: + # exclude all passes from the group batch fusion + # they do not use pattern matcher + if pass_name in POST_GRAD_FUSIONS: + continue + POST_GRAD_PATTERNS[pass_name] = PatternMatcherPass( + pass_name=pass_name, + ) + + +def construct_pattern_matcher_pass(pass_name: str): + """ + Return the specific pattern_matcher_pass given the pass name. + """ + if pass_name in PRE_GRAD_PATTERNS: + return PRE_GRAD_PATTERNS[pass_name] + else: + return POST_GRAD_PATTERNS[pass_name] + + +def _get_split_args_default(split_node): + input_kwarg = "tensor" + split_size_kwarg = "split_size_or_sections" + dim_kwarg = "dim" + default_dim_value = 0 + if split_node.op == "call_method": + split_size_kwarg = "split_size" + return ( + get_arg_value(split_node, 0, input_kwarg), + get_arg_value(split_node, 1, split_size_kwarg), + get_arg_value(split_node, 2, dim_kwarg) or default_dim_value, + ) + + +def _get_dim(node: Any): + assert isinstance(node, torch.fx.Node) + if "dim" in node.kwargs: + assert isinstance(node.kwargs["dim"], int) + return node.kwargs["dim"] + if node.target == torch.unbind: + if len(node.args) == 2: + assert isinstance(node.args[-1], int) + return node.args[-1] + return 0 # defaults to dim=0 + if node.target == torch.split: + if len(node.args) == 3: + assert isinstance(node.args[-1], int) + return node.args[-1] + return 0 # defaults to dim=0 + raise AssertionError( + f"Can't extract `dim` from {node.target} {node.args} {node.kwargs}" + ) + + +# noqa: W605 +# ############The pattern to be optimized is######### +# unbind (dim=0) +# / ... 
\ +# getitem getitem -> user=1 +# | | +# split split -> dim=1, user=1, split_section_size=1 +# | | +# getitem getitem -> user=1 +# \ / +# cat (dim=1) -> user=1 +# | + +# ################After transformation############# +# unbind (dim=0) +# / ... \ +# getitem getitem -> user=1 +# \ / +# cat (dim=1) -> user=1 +# | + + +def normalize_split_base( + match: Match, + _get_split_args: Callable[ + [torch.fx.Node], tuple[Optional[torch.fx.Node], Optional[Any], Optional[int]] + ], +): + """ + Normalize split with split_size into split_with_sizes, so that we only deal with one type of split in + subsequent optimizations + """ + split_node = match.nodes[0] + graph = match.graph + split_input, split_size, split_dim = _get_split_args(split_node) + if split_input is None or split_dim is None or split_size is None: + log.debug("couldn't find split args") + return + if not is_node_meta_valid(split_node): + log.debug("example value absent for node: %s", split_node) + return + assert isinstance(split_node.meta["example_value"], (list, tuple)) + split_sections = [t.size()[split_dim] for t in split_node.meta["example_value"]] + + if any(isinstance(section, torch.SymInt) for section in split_sections): + # TODO dynamic_shapes with assume_static_by_default=False fails while AOT Autograd tracing. + return + if split_dim < 0: # Normalize split dim + split_dim += split_input.meta["example_value"].dim() + + new_args = (split_input, split_sections) + new_kwargs = {"dim": split_dim} + if ( + split_node.args == new_args + and split_node.kwargs == new_kwargs + and split_node.op == "call_function" + ): + return + + with graph.inserting_after(split_node): + new_split_node = graph.call_function( + torch.split, + args=new_args, + kwargs=new_kwargs, # type: ignore[arg-type] + ) + split_node.replace_all_uses_with(new_split_node) + new_split_node.meta.update(split_node.meta) + graph.erase_node(split_node) + counters["inductor"]["normalization_pass"] += 1 + + +@register_graph_pattern( + CallFunctionVarArgs(torch.split, users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +@register_graph_pattern( + CallMethodVarArgs("split", users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +def normalize_split_default(match: Match, *args, **kwargs): + return normalize_split_base(match, _get_split_args_default) + + +@register_graph_pattern( + CallFunctionVarArgs(torch.split, users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("remove_split_with_size_one_pass"), +) +@register_graph_pattern( + CallMethodVarArgs("split", users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("remove_split_with_size_one_pass"), +) +def remove_split_with_size_one(match: Match, *args, **kwargs): + graph = match.graph + split_node = match.nodes[0] + split_input, split_size, split_dim = _get_split_args_default(split_node) + if split_input is None or split_dim is None or split_size is None: + log.debug("couldn't find split args") + return + if not is_node_meta_valid(split_node): + log.debug("example value absent for node: %s", split_node) + return + assert isinstance(split_node.meta["example_value"], (list, tuple)) + split_sections = [t.size()[split_dim] for t in split_node.meta["example_value"]] + + if any(isinstance(section, torch.SymInt) for section in split_sections): + # TODO dynamic_shapes with assume_static_by_default=False fails while AOT Autograd tracing. 
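    # Illustrative note for the size-one case handled just below (names are
    # placeholders, not part of this pass): a split that produces a single section is
    # a no-op wrapper around its input, e.g.
    #   chunks = torch.split(x, [x.size(0)], dim=0)   # len(split_sections) == 1
    #   torch.equal(chunks[0], x)                     # -> True
    # so the users of the lone getitem can read from the split input directly.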
+ return + # remove the dummy split whose split sections size is one + # theoretically nodes with no users should be removed, but we have seen the corner case + # thus we add its users check to walk around the StopIteration error. + if len(split_sections) == 1 and len(split_node.users.keys()) > 0: + # find the grand children of the split_node + next_users = find_next_users(split_node) + user = next(iter(split_node.users.keys())) + # replace the users of grand child node with the input node + for next_user in next_users: + next_user.replace_input_with(user, split_input) + # erase the split node and its child + graph.erase_node(user) + graph.erase_node(split_node) + counters["inductor"]["remove_split_with_size_one_pass"] += 1 + + +@register_graph_pattern( + CallFunctionVarArgs(torch.unbind, users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +@register_graph_pattern( + CallMethodVarArgs("unbind", users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +def normalize_unbind_default(match: Match, *args, **kwargs): + node = match.nodes[0] + graph = match.graph + input = get_arg_value(node, 0, "input") + dim = get_arg_value(node, 1, "dim") + if dim is None: + axis = node.kwargs.get("axis") + if axis is not None: + dim = axis + else: + dim = 0 + if input is None: + log.debug("couldn't find unbind args") + return + if not is_node_meta_valid(input): + log.debug("example value absent for node: %s", input) + return + ndim = input.meta["example_value"].ndim + if dim < 0: # Normalize unbind dim + dim += ndim + with graph.inserting_after(node): + new_node = graph.call_function( + torch.unbind, + args=(input,), + kwargs={"dim": dim}, + ) + node.replace_all_uses_with(new_node) + new_node.meta.update(node.meta) + graph.erase_node(node) + counters["inductor"]["normalization_pass"] += 1 + + +@register_graph_pattern( + CallFunctionVarArgs(torch.cat, users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +def normalize_cat_default(match: Match, *args, **kwargs): + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + cat_node = match.nodes[0] + graph = match.graph + tensors = get_arg_value(cat_node, 0, "tensors") + cat_dim = get_arg_value(cat_node, 1, "dim") + if cat_dim is None: + cat_axis = cat_node.kwargs.get("axis") + if cat_axis is not None: + cat_dim = cat_axis + else: + cat_dim = 0 + if tensors is None or cat_dim is None: + log.debug("couldn't find cat args") + return + assert isinstance(tensors, (list, tuple)) + for tensor in itertools.chain([cat_node], tensors): + if not is_node_meta_valid(tensor): + log.debug("example value absent for node: %s", tensor) + return + + ndim = cat_node.meta["example_value"].dim() + + def is_empty_tensor(x): + # special case where torch.cat supports cat'ing with an empty tensor + x_shape = x.meta["example_value"].shape + return len(x_shape) == 1 and guard_size_oblivious(x_shape[0] == 0) + + assert all( + ndim == x.meta["example_value"].dim() or is_empty_tensor(x) for x in tensors + ) + + if cat_dim < 0: # Normalize cat dim + cat_dim += ndim + + new_args = (tensors,) + new_kwargs = {"dim": cat_dim} + if ( + cat_node.args == new_args + and cat_node.kwargs == new_kwargs + and cat_node.op == "call_function" + ): + return + + with graph.inserting_after(cat_node): + new_cat_node = graph.call_function( + torch.cat, + args=new_args, + kwargs=new_kwargs, + ) + cat_node.replace_all_uses_with(new_cat_node) + new_cat_node.meta.update(cat_node.meta) + 
graph.erase_node(cat_node) + counters["inductor"]["normalization_pass"] += 1 + + +@register_graph_pattern( + CallFunctionVarArgs(torch.stack, users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +def normalize_stack_default(match: Match, *args, **kwargs): + node = match.nodes[0] + graph = match.graph + tensors = get_arg_value(node, 0, "tensors") + dim = get_arg_value(node, 1, "dim") or 0 + if tensors is None or dim is None: + log.debug("couldn't find stack args") + return + assert isinstance(tensors, (list, tuple)) + + # A bug in pytorch, some nodes miss the example_value metadata + for tensor in itertools.chain([node], tensors): + if not is_node_meta_valid(tensor): + log.debug("example value absent for node: %s", tensor) + return + + ndim = node.meta["example_value"].dim() + if dim < 0: # Normalize dim + dim += ndim + + with graph.inserting_after(node): + new_node = graph.call_function( + node.target, # type: ignore[arg-type] + args=(tensors,), + kwargs={"dim": dim}, + ) + node.replace_all_uses_with(new_node) + new_node.meta.update(node.meta) + graph.erase_node(node) + counters["inductor"]["normalization_pass"] += 1 + + +def find_next_users(split_node: torch.fx.Node) -> list[torch.fx.Node]: + next_users = [] + for getitem_node in split_node.users.keys(): + for getitem_user in getitem_node.users.keys(): + if getitem_user not in next_users: + next_users.append(getitem_user) + return next_users + + +@register_graph_pattern( + CallMethodVarArgs("squeeze", users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +def normalize_squeeze_default(match: Match, *args, **kwargs): + squeeze_node = match.nodes[0] + squeeze_input = get_arg_value(squeeze_node, 0) + + if "dim" in squeeze_node.kwargs: + assert len(squeeze_node.args) == 1 + dim = squeeze_node.kwargs["dim"] + elif len(squeeze_node.args) == 1: + # squeeze(Tensor) + dim = None + elif len(squeeze_node.args) == 2: + # squeeze(Tensor self, int dim) + # squeeze(Tensor self, int[] dim) + dim = squeeze_node.args[1] + else: + # squeeze(Tensor self, int[] dim) (called with varargs) + dim = squeeze_node.args[1:] + + if isinstance(dim, Sequence) and len(dim) == 1: + dim = dim[0] + + with match.graph.inserting_after(squeeze_node): + if dim is None: + new_squeeze_node = match.graph.call_function( + torch.squeeze, args=(squeeze_input,) + ) + else: + new_squeeze_node = match.graph.call_function( + torch.squeeze, args=(squeeze_input,), kwargs={"dim": dim} + ) + squeeze_node.replace_all_uses_with(new_squeeze_node) + new_squeeze_node.meta.update(squeeze_node.meta) + match.graph.erase_node(squeeze_node) + + +@register_graph_pattern( + CallMethodVarArgs("reshape", users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +def normalize_reshape_default(match: Match, *args, **kwargs): + reshape_node = match.nodes[0] + if not is_node_meta_valid(reshape_node): + log.debug("example value absent for node: %s", reshape_node) + return + reshape_input = get_arg_value(reshape_node, 0) + + if free_symbols(reshape_node.meta["example_value"].shape): + log.debug("dynamic shape not supported: %s", reshape_node) + return + + with match.graph.inserting_after(reshape_node): + new_reshape_node = match.graph.call_function( + torch.reshape, + args=(reshape_input, tuple(reshape_node.meta["example_value"].shape)), + ) + reshape_node.replace_all_uses_with(new_reshape_node) + new_reshape_node.meta.update(reshape_node.meta) + match.graph.erase_node(reshape_node) + + +@register_graph_pattern( + 
CallMethodVarArgs("clamp", users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +@register_graph_pattern( + CallFunctionVarArgs(torch.clamp, users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +def normalize_clamp_default(match: Match, *args, **kwargs): + clamp_node = match.nodes[0] + if not is_node_meta_valid(clamp_node): + log.debug("example value absent for node: %s", clamp_node) + return + + if free_symbols(clamp_node.meta["example_value"].shape): + log.debug("dynamic shape not supported: %s", clamp_node) + return + if len(clamp_node.args) > 1: + args = (get_arg_value(clamp_node, 0),) + kwargs = { + "min": get_arg_value(clamp_node, 1, kwarg_name="min"), + "max": get_arg_value(clamp_node, 2, kwarg_name="max"), + } + else: + args = clamp_node.args + kwargs = clamp_node.kwargs + with match.graph.inserting_after(clamp_node): + new_clamp_node = match.graph.call_function( + torch.clamp, + args=args, + kwargs=kwargs, + ) + clamp_node.replace_all_uses_with(new_clamp_node) + new_clamp_node.meta.update(clamp_node.meta) + match.graph.erase_node(clamp_node) + + +@register_graph_pattern( + CallMethodVarArgs("detach", users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +def normalize_detach_default(match: Match, *args, **kwargs): + detach_node = match.nodes[0] + if not is_node_meta_valid(detach_node): + log.debug("example value absent for node: %s", detach_node) + return + + if free_symbols(detach_node.meta["example_value"].shape): + log.debug("dynamic shape not supported: %s", detach_node) + return + + with match.graph.inserting_after(detach_node): + new_detach_node = match.graph.call_function( + torch.detach, + args=detach_node.args, + ) + detach_node.replace_all_uses_with(new_detach_node) + new_detach_node.meta.update(detach_node.meta) + match.graph.erase_node(detach_node) + + +class TorchSplit(CallFunction): + """ + Matches a call to torch.split if it is in a normalized form. Ensures that all users of + splits are unique getitems. + """ + + def __init__(self, arg, sizes, func=torch.split) -> None: + # using KeywordArg("dim") for `dim` checks they all match + super().__init__(func, arg, sizes, _users=MULTIPLE, dim=KeywordArg("dim")) + + def _match(self, node: torch.fx.Node, ctx: MatchContext): + m = super()._match(node, ctx) + if not m: + return m + split_sections = node.args[1] + if not isinstance(split_sections, (list, tuple)): + return FailedMatch("split not normalized") + # check users are all unique getitems + seen_idxs = OrderedSet[int]() + for user in node.users: + if not CallFunction(operator.getitem, Arg(), Arg()).match(user): + # This should ideally never happen. 
Split user should always be a getitem + return FailedMatch(f"user of split not a getitem: {user}") + if not isinstance(user.args[1], int): + return FailedMatch("only integer getitems are handled") + if user.args[1] in seen_idxs: + return FailedMatch(f"duplicate getitem {user.args[1]}") + if user.args[-1] < 0: # type: ignore[operator] + # This shouldn't ideally happen as dynamo normalizes indexes to positive + return FailedMatch("negative index") + seen_idxs.add(user.args[1]) + return m + + +@register_graph_pattern( + TorchSplit( + CallFunction( + operator.getitem, + TorchSplit( + KeywordArg("first_split_input"), + KeywordArg("first_split_sections"), + ), + Ignored(), + ), + KeywordArg("next_split_sections"), + ), + pass_dict=construct_pattern_matcher_pass("merge_splits_pass"), +) +def merge_splits( + match: Match, + first_split_input: torch.fx.Node, + first_split_sections: list[int], + next_split_sections: list[int], + # Note: dim is implicitly passed by TorchSplit, as it internally uses a pattern with dim + dim: int, +): + node = match.output_node() + # it is possible that the split has no users, + # we check the corner case and skip the pattern + if len(node.users.keys()) == 0: + return + graph = match.graph + first_split = node.args[0].args[0] # type: ignore[union-attr] + next_split_index = node.args[0].args[1] # type: ignore[union-attr] + + new_split_sections = list(first_split_sections) + new_split_sections[next_split_index : next_split_index + 1] = next_split_sections # type: ignore[operator, misc] + + first_split_dim = _get_dim(first_split) + + to_remove = [] + + with graph.inserting_before(first_split): # type: ignore[arg-type] + # Add the new split node + new_split = graph.call_function( + torch.split, + args=(first_split_input, new_split_sections), + kwargs={"dim": first_split_dim}, + ) + if is_node_meta_valid(first_split_input): + new_split.meta["example_value"] = torch.split( + first_split_input.meta["example_value"], + new_split_sections, + dim=first_split_dim, + ) + first_split_num_to_user = { + user.args[1]: user + for user in first_split.users.keys() # type: ignore[union-attr] + } + + new_split_num = 0 + for split_num in range(len(first_split_sections)): + if split_num not in first_split_num_to_user: + new_split_num += 1 + continue + old_getitem = first_split_num_to_user[split_num] + if split_num != next_split_index: + old_getitem.update_arg(0, new_split) + old_getitem.update_arg(1, new_split_num) + new_split_num += 1 + else: + next_split_num_to_user = { + user.args[1]: user for user in node.users.keys() + } + # It is not necessary all getitems from the split node are used. + for next_split_num in range(len(next_split_sections)): + with graph.inserting_after(new_split): + new_getitem = graph.call_function( + operator.getitem, args=(new_split, new_split_num) + ) + new_split_num += 1 + if next_split_num not in next_split_num_to_user: + continue + next_getitem = next_split_num_to_user[next_split_num] + new_getitem.meta.update(next_getitem.meta) + next_getitem.replace_all_uses_with(new_getitem) + to_remove.append(next_getitem) + to_remove.append(node) + to_remove.append(old_getitem) + + to_remove.append(first_split) # type: ignore[arg-type] + for node in to_remove: + graph.erase_node(node) + + counters["inductor"]["merge_splits_pass"] += 1 + + +class SplitCatSimplifier: + """ + Helper class to simplify split-cat pattern. In simple cases, both split and cat node can be removed in a "split->cat" + pattern. 
However, there are various cases where they can't and we need to simplify split/ add transforms before cat. + Some such cases are: + 1. Final node has additional args (not coming from the initial split) + 2. Shuffling of args between split/cat + 3. Some final nodes are non-(cat/stack) + 4. Split-dim != cat-dim (but equal split) + + Note that any combination of the above cases can happen. + + To deal with 1, 2, & 3 - we iterate over all users of split. And figure out common "ranges" that can be merged. + Then, we simplify the split accordingly. In the best case, split can be entirely removed. + + To deal with 4, we add some transformations (unflatten + movedim) (See `get_transform_params`). + + Finally, depending on final node being cat or stack, unsqueeze/flatten needs to be added. + + """ + + def simplify( + self, + graph: torch.fx.Graph, + split_node: torch.fx.Node, + split_sections: list[int], + ): + # Find the next users (i.e. users after the getitem) + next_users = find_next_users(split_node) + # Gather inputs of the next users. When inputs come from `split_node`, they are instead represented by + # a tuple indicating the split ranges. See `get_user_input_list` for more details + user_inputs_list = self.get_user_input_list(split_node, next_users) + # Simplify the split_sections based on user_inputs_list. In simpler cases, len(simplified_split_ranges) == 1 and + # we can simply replace the split node. Otherwise, we simplify it. + simplified_split_ranges = self.get_simplified_split_ranges( + split_sections, next_users, user_inputs_list + ) + if not simplified_split_ranges: # Simplification not possible + return + transform_params_list = self.get_transform_params( + split_node, next_users, user_inputs_list + ) + if not transform_params_list: + return + + # Start actual replacement + user_inputs_list_new = self.replace_split( + graph, split_node, split_sections, user_inputs_list, simplified_split_ranges + ) + self.replace_cat( + graph, + split_node, + next_users, + user_inputs_list_new, + transform_params_list, # type: ignore[arg-type] + ) + self.erase_old_nodes(graph, split_node, next_users) # type: ignore[arg-type] + counters["inductor"]["unbind_stack_pass"] += 1 + + def get_user_input_list( + self, split_node: torch.fx.Node, next_users: list[torch.fx.Node] + ) -> list[list[Union[torch.fx.Node, _Range]]]: + """ + Returns list of inputs to the following user nodes, in order. The outer list represents the user node. The inner + list represents the inputs to that particular node. 
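        For example (illustrative names), a user torch.cat([other, getitem_0, getitem_1], dim=1)
        whose getitems come from `split_node` is recorded as [other, (0, 1)].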
This list can either contain + - a tuple representing the ranges of get_items that should go into the cat (closed interval) + - torch.fx.Node representing "other" inputs (which are not coming from our split) + """ + user_inputs_list: list[list[Union[torch.fx.Node, _Range]]] = [] + for user in next_users: + if user.target in (torch.cat, torch.stack): + user_inputs_list.append(self.get_merged_user_inputs(split_node, user)) + else: + user_inputs_list.append(self.get_non_cat_node_input(split_node, user)) # type: ignore[arg-type] + return user_inputs_list + + def get_merged_user_inputs( + self, split_node: torch.fx.Node, cat_node: torch.fx.Node + ) -> list[Union[torch.fx.Node, _Range]]: + user_inputs = get_arg_value(cat_node, 0, "tensors") + simplified_user_inputs = [] + split_users = OrderedSet(split_node.users.keys()) + for user_input in user_inputs: + if user_input not in split_users: + simplified_user_inputs.append(user_input) + else: + # Add which "getitem" cat depends on + simplified_user_inputs.append(user_input.args[1]) + return self.merge_consecutive_inputs(simplified_user_inputs) + + def get_non_cat_node_input( + self, split_node: torch.fx.Node, node: torch.fx.Node + ) -> list[_Range]: + """ + Get input for a non cat node in the same format as `get_merged_user_inputs` + """ + node_input = [] + split_users = OrderedSet(split_node.users.keys()) + for node_arg in node.all_input_nodes: + if node_arg in split_users: + getitem_num = get_arg_value(node_arg, 1) + node_input.append((getitem_num, getitem_num)) + return node_input + + def merge_consecutive_inputs( + self, inputs: list[Union[torch.fx.Node, int]] + ) -> list[Union[torch.fx.Node, _Range]]: + """ + Merge consecutive inputs going into a user node. + + For e.g. + [arg0, 0, 1, 2, arg1] -> [arg0, (0, 2), arg1] + """ + merged_ranges = [] + cur_range = None + for input_ in inputs: + if isinstance(input_, int): + if not cur_range: + cur_range = [input_, input_] + elif input_ == cur_range[1] + 1: + cur_range[1] += 1 + else: + merged_ranges.append(tuple(cur_range)) + cur_range = [input_, input_] + else: + if cur_range: + merged_ranges.append(tuple(cur_range)) + cur_range = None + merged_ranges.append(input_) # type: ignore[arg-type] + if cur_range: + merged_ranges.append(tuple(cur_range)) + return merged_ranges # type: ignore[return-value] + + def get_simplified_split_ranges( + self, + split_sections, + next_users, + user_inputs_list: list[list[Union[torch.fx.Node, _Range]]], + ) -> Optional[list[_Range]]: + ranges = OrderedSet[Any]() + for user_inputs in user_inputs_list: + ranges.update(u for u in user_inputs if isinstance(u, tuple)) + + cumulative_sizes = [0] + torch.cumsum(torch.tensor(split_sections), 0).tolist() + split_ranges = sorted( + [(cumulative_sizes[r[0]], cumulative_sizes[r[1] + 1]) for r in ranges] + ) + + if not self.has_non_overlapping_ranges( + split_ranges, + ): # This need not be a strict condition + # However, we keep it now for simplicity. 
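            # Illustrative (made-up sizes): with split_sections=[2, 3, 5], users covering
            # getitems (0, 1) and (1, 2) map to element ranges (0, 5) and (2, 10); those
            # overlap, so the simplification is abandoned here.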
+ return None + split_ranges = self.fill_gaps(split_ranges, 0, cumulative_sizes[-1]) + if len(split_sections) == len(split_ranges): # Simplification not possible + return None + counters["inductor"]["scmerge_split_sections_removed"] = len( + split_sections + ) - len(split_ranges) + return split_ranges + + def has_non_overlapping_ranges(self, ranges: list[_Range]) -> bool: + for range_, next_range in zip(ranges, ranges[1:]): + if range_[1] > next_range[0]: + return False + return True + + def fill_gaps(self, ranges: list[_Range], min_: int, max_: int) -> list[_Range]: + cur = min_ + filled_ranges = [] + for a, b in ranges: + if cur < a: + filled_ranges.append((cur, a)) + filled_ranges.append((a, b)) + cur = b + if filled_ranges[-1][1] < max_: + filled_ranges.append((filled_ranges[-1][1], max_)) + return filled_ranges + + def get_transform_params( + self, + split_node: torch.fx.Node, + next_users: list[torch.fx.Node], + user_inputs_list: list[list[Union[torch.fx.Node, _Range]]], + ) -> Optional[list[list[_TransformParam]]]: + """ + Figure out what transforms are needed for each input to each cat node. + + We replace a split node with an unflatten followed by a movedim + """ + split_dim = _get_dim(split_node) + split_sections = split_node.args[1] + transform_params_list: list[list[_TransformParam]] = [] + + for user_node, user_inputs in zip(next_users, user_inputs_list): + if user_node.target not in (torch.cat, torch.stack): + transform_params_list.append([]) + continue + + cat_dim = get_arg_value(user_node, 1, "dim") + transform_params: list[_TransformParam] = [] + for user_input in user_inputs: + if split_dim == cat_dim and user_node.target == torch.cat: + # No transform needed + transform_params.append((None, None, None, None)) + elif isinstance(user_input, tuple): # Split being simplified + # Verify equal split + subset_split_sections = split_sections[ # type: ignore[index] + user_input[0] : user_input[1] + + 1 # type: ignore[index] + ] + # All sections should be equal + if len(OrderedSet(subset_split_sections)) != 1: # type: ignore[arg-type] + return None + + num_splits = len(subset_split_sections) # type: ignore[arg-type] + unflatten_params = (split_dim, (num_splits, -1)) + movedim_params = ( + (split_dim, cat_dim) if split_dim != cat_dim else None + ) + transform_params.append( + (unflatten_params, movedim_params, None, None) + ) + elif ( + user_node.target == torch.stack or split_dim != cat_dim + ): # We need to unsqueeze inputs not coming through split + transform_params.append((None, None, (cat_dim,), None)) + else: # Non-split inputs + transform_params.append((None, None, None, None)) + transform_params_list.append(transform_params) + return transform_params_list + + def replace_split( + self, + graph: torch.fx.Graph, + split_node: torch.fx.Node, + split_sections: list[int], + user_inputs_list: list[list[Union[torch.fx.Node, _Range]]], + split_ranges: list[_Range], + ) -> list[list[torch.fx.Node]]: + """ + Replace the split node. It can either remove the split node if len(split_ranges) == 1, or simplify it + into a split with lesser sections if len(split_ranges) > 1. + + Returns the new `user_inputs_list`, with tuples replaced with new getitems from the newer split node. 
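        For example (illustrative sizes), with split_sections=[2, 3, 5] and
        split_ranges=[(0, 5), (5, 10)], the split is rebuilt as
        torch.split(input, [5, 5], dim=split_dim), and a user range such as (0, 1)
        (old getitems 0 and 1) is replaced by the new getitem covering elements 0..5.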
+ """ + split_input = split_node.args[0] + split_dim = _get_dim(split_node) + if len(split_ranges) == 1: # We can completely eliminate the split node + split_items = [split_input] + else: + with graph.inserting_after(split_node): + new_split = graph.call_function( + torch.split, + args=( + split_input, + [r[1] - r[0] for r in split_ranges], + ), + kwargs={"dim": split_dim}, + ) + if is_node_meta_valid(split_input): # type: ignore[arg-type, union-attr] + new_split.meta["example_value"] = torch.split( + split_input.meta["example_value"], # type: ignore[union-attr] + [r[1] - r[0] for r in split_ranges], + dim=split_dim, + ) + counters["inductor"]["scmerge_split_added"] += 1 + split_items = [] + with graph.inserting_after(new_split): + for i in range(len(split_ranges)): + getitem = graph.call_function(operator.getitem, args=(new_split, i)) + if is_node_meta_valid(new_split): + getitem.meta["example_value"] = new_split.meta["example_value"][ + i + ] + split_items.append(getitem) + # Now assign the right getitem to the right input + cumulative_sizes = [0] + torch.cumsum(torch.tensor(split_sections), 0).tolist() + new_user_inputs_list = [] + for user_inputs in user_inputs_list: + new_user_inputs = [] + for user_input in user_inputs: + if isinstance(user_input, tuple): + # Find the correct new getitem (present in split_items) + new_user_inputs.append( + split_items[ + split_ranges.index( + ( + cumulative_sizes[user_input[0]], + cumulative_sizes[user_input[1] + 1], + ) + ) + ] + ) + else: + new_user_inputs.append(user_input) + new_user_inputs_list.append(new_user_inputs) + return new_user_inputs_list # type: ignore[return-value] + + def replace_cat( + self, + graph: torch.fx.Graph, + split_node: torch.fx.Node, + next_users: list[torch.fx.Node], + user_inputs_list_new, + transform_params_list: list[list[_TransformParam]], + ): + split_dim = _get_dim(split_node) + split_users = split_node.users.keys() + new_cats = [] + for user_node, user_inputs_new, transform_params in zip( + next_users, user_inputs_list_new, transform_params_list + ): + if user_node.target not in (torch.cat, torch.stack): + # Change the args and kwargs of non-cat/stack nodes. 
Replace old getitems (belonging to + # the original split node) with the newer getitems + next_cat_input = 0 + for input_node in user_node.all_input_nodes: + if input_node in split_users: + user_node.replace_input_with( + input_node, user_inputs_new[next_cat_input] + ) + next_cat_input += 1 + continue + + # Handle cat/stack user nodes + cat_dim = get_arg_value(user_node, 1, "dim") + user_inputs_new_transformed, user_inputs_new_transformed_meta = [], [] + # For `unsqueeze` transform, we will combine consecutive inputs with the same unsqueeze params, and stack them + to_stack, to_stack_meta = [], [] + stack_dim = None + with graph.inserting_before(user_node): + for user_input_new, transform_param in zip( + user_inputs_new, transform_params + ): + if not is_node_meta_valid(user_input_new): + log.debug("example value absent for node: %s", user_input_new) + return + # Apply transforms + ( + unflatten_params, + movedim_params, + unsqueeze_params, + flatten_params, + ) = transform_param + if unsqueeze_params and ( + stack_dim is None or stack_dim == unsqueeze_params[0] + ): + to_stack.append(user_input_new) + to_stack_meta.append(user_input_new.meta["example_value"]) + stack_dim = unsqueeze_params[0] + continue + elif to_stack: + stacked_input = graph.call_function( + torch.stack, args=(to_stack,), kwargs={"dim": stack_dim} + ) + stacked_input.meta["example_value"] = torch.stack( # type: ignore[arg-type] + to_stack_meta, + dim=stack_dim, # type: ignore[arg-type] + ) + to_stack, to_stack_meta = [], [] + stack_dim = None + user_inputs_new_transformed.append(stacked_input) + user_inputs_new_transformed_meta.append( + stacked_input.meta["example_value"] + ) + if unsqueeze_params: + to_stack.append(user_input_new) + stack_dim = unsqueeze_params[0] + to_stack_meta.append(user_input_new.meta["example_value"]) + continue + + if unflatten_params: + user_input_new_meta = user_input_new.meta["example_value"] + user_input_new = graph.call_function( + torch.unflatten, args=(user_input_new, *unflatten_params) + ) + user_input_new.meta["example_value"] = torch.unflatten( # type: ignore[arg-type] + user_input_new_meta, # type: ignore[arg-type] + *unflatten_params, # type: ignore[arg-type] + ) + if movedim_params: + user_input_new_meta = user_input_new.meta["example_value"] + user_input_new = graph.call_function( + torch.movedim, args=(user_input_new, *movedim_params) + ) + user_input_new.meta["example_value"] = torch.movedim( # type: ignore[arg-type] + user_input_new_meta, # type: ignore[arg-type] + *movedim_params, # type: ignore[arg-type] + ) + if flatten_params: + user_input_new_meta = user_input_new.meta["example_value"] + user_input_new = graph.call_function( + torch.flatten, args=(user_input_new, *flatten_params) + ) + user_input_new.meta["example_value"] = torch.flatten( # type: ignore[arg-type] + user_input_new_meta, + *flatten_params, # type: ignore[arg-type] + ) + user_inputs_new_transformed.append(user_input_new) + user_inputs_new_transformed_meta.append( + user_input_new.meta["example_value"] + ) + if to_stack: + stacked_input = graph.call_function( + torch.stack, args=(to_stack,), kwargs={"dim": stack_dim} + ) + stacked_input.meta["example_value"] = torch.stack( # type: ignore[arg-type] + to_stack_meta, + dim=stack_dim, # type: ignore[arg-type] + ) + user_inputs_new_transformed.append(stacked_input) + user_inputs_new_transformed_meta.append( + stacked_input.meta["example_value"] + ) + + with graph.inserting_after(user_node): + if len(user_inputs_new_transformed) > 1: + new_cat_node = 
graph.call_function( + torch.cat, + args=(user_inputs_new_transformed,), + kwargs={"dim": cat_dim}, + ) + new_cat_node.meta["example_value"] = torch.cat( + user_inputs_new_transformed_meta, + dim=cat_dim, + ) + counters["inductor"]["scmerge_cat_added"] += 1 + else: + new_cat_node = user_inputs_new_transformed[-1] + new_cat_node.meta["example_value"] = ( + user_inputs_new_transformed_meta[-1] + ) + + if ( + user_node.target == torch.cat + and split_dim != cat_dim + and split_node.target == torch.split + ): + with graph.inserting_after(new_cat_node): + new_cat_node_meta = new_cat_node.meta["example_value"] + new_cat_node = graph.call_function( + torch.flatten, args=(new_cat_node, cat_dim, cat_dim + 1) + ) + new_cat_node.meta["example_value"] = torch.flatten( + new_cat_node_meta, + cat_dim, + cat_dim + 1, + ) + user_node.replace_all_uses_with(new_cat_node) + new_cats.append(new_cat_node) + + def erase_old_nodes( + self, + graph: torch.fx.Graph, + split_node: torch.fx.Node, + next_users: list[torch.fx.Node], + ): + to_remove = [split_node] + counters["inductor"]["scmerge_split_removed"] += 1 + to_remove.extend(split_node.users.keys()) + for next_user in next_users: + if next_user.target not in (torch.cat, torch.stack): + continue + counters["inductor"]["scmerge_cat_removed"] += 1 + to_remove.append(next_user) + for node in reversed(to_remove): + if len(node.users.keys()) == 0: + graph.erase_node(node) + + +class UnbindCatRemover(SplitCatSimplifier): + """ + Helper class to merge Unbind->Cat/Stack. Many of the cases are similar to SplitCatSimplifier. + + Unbind can't be simplified like splits. So, we can only remove the unbind node. Other than this, + other cases like multiple users, additional args, dim mismatch are similar to `SplitCatSimplifier`, + hence we extend that class. + """ + + def remove_unbind( + self, + graph: torch.fx.Graph, + unbind_node: torch.fx.Node, + ): + if not is_node_meta_valid(unbind_node): + return + # we need to check if the getitem indices from unbind are consecutive and all go to the same cat node + # before we do the unbind remove, otherwise it will hit the error when we unbind part of them + getitem_indices = [ + getitem_node.args[1] for getitem_node in unbind_node.users.keys() + ] + if not is_sorted_and_consecutive(getitem_indices) or len( # type: ignore[arg-type] + getitem_indices + ) != len(unbind_node.meta["example_value"]): + return + num_unbind = len(getitem_indices) + split_sections = [1 for _ in range(num_unbind)] # type: ignore[operator, arg-type] + + super().simplify(graph, unbind_node, split_sections) + + def get_simplified_split_ranges( + self, + split_sections: list[int], + next_users: list[torch.fx.Node], + user_inputs_list: list[list[Union[torch.fx.Node, _Range]]], + ) -> Optional[list[_Range]]: + simplified_split_ranges = super().get_simplified_split_ranges( + split_sections, next_users, user_inputs_list + ) + if not simplified_split_ranges or len(simplified_split_ranges) != 1: + return None + return simplified_split_ranges + + def get_transform_params( + self, + split_node: torch.fx.Node, + next_users: list[torch.fx.Node], + user_inputs_list: list[list[Union[torch.fx.Node, _Range]]], + ) -> Optional[list[list[_TransformParam]]]: + """ + Figure out what transforms are needed for each input to each cat node. 
+ + Here is the rough transforms we apply: + + x -> unbind -> stack => x -> movedim + + x -> unbind -> cat => x -> movedim -> flatten + + When cat/stack nodes have additional args: + + addn ---| addn -> unsqueeze ---| + x -> unbind -> stack => x -> movedim -> cat + + addn ---| addn ---| + x -> unbind -> cat => x -> movedim -> flatten -> cat + + (Note application of these depends on the dims as well) + + + """ + split_dim = _get_dim(split_node) + transform_params_list: list[list[_TransformParam]] = [] + for user_node, user_inputs in zip(next_users, user_inputs_list): + cat_dim = get_arg_value(user_node, 1, "dim") or 0 + transform_params: list[_TransformParam] = [] + for user_input in user_inputs: + if isinstance(user_input, tuple): + # User input is coming from unbind + movedim_params = ( + (split_dim, cat_dim) if split_dim != cat_dim else None + ) + flatten_params = None + if user_node.target == torch.cat: + flatten_params = (cat_dim, cat_dim + 1) + transform_params.append( + (None, movedim_params, None, flatten_params) + ) + elif ( + user_node.target == torch.stack + ): # We need to unsqueeze inputs not coming through unbind into cat + transform_params.append((None, None, (cat_dim,), None)) + else: # Non-unbind inputs + transform_params.append((None, None, None, None)) + transform_params_list.append(transform_params) + return transform_params_list + + +class GetItem(CallFunction): + def __init__(self, arg, index, _users=1) -> None: + super().__init__(operator.getitem, arg, index, _users=_users) + + def find_anchor_nodes(self, ctx: MatchContext, searched: OrderedSet[torch.fx.Node]): + # We generally match GetItem with arg being an Arg(). So, we never return the anchor + # nodes as the stored node in ctx.pattern_to_node is returned. Here we override find_anchor_nodes + # to not use ctx.pattern_to_node + for pattern in self.flat_args_kwargs[0]: + if isinstance(pattern, PatternExpr): + for other_node in pattern.find_anchor_nodes(ctx, searched): + if not isinstance(other_node, torch.fx.Node): + continue + for node in other_node.users: + if node not in searched: + if self._match_fns(node): + yield node + searched.add(node) + + +@register_graph_pattern( + RepeatedExpr( + CallFunction( + torch.squeeze, + GetItem( + TorchSplit( + KeywordArg("split_input"), + KeywordArg("split_sizes"), + ), + Ignored(), + ), + KeywordArg("dim"), + _users=MULTIPLE, + ), + ), + pass_dict=construct_pattern_matcher_pass("split_cat_pass"), +) +@register_graph_pattern( + RepeatedExpr( + CallFunction( + torch.squeeze, + GetItem( + TorchSplit( + KeywordArg("split_input"), + KeywordArg("split_sizes"), + ), + Ignored(), + ), + dim=KeywordArg("dim"), + _users=MULTIPLE, + ) + ), + pass_dict=construct_pattern_matcher_pass("split_cat_pass"), +) +def merge_split_squeeze( + match: Match, split_input: torch.fx.Node, split_sizes: list[int], dim: int +): + graph = match.graph + split = next(node for node in match.nodes if node.target == torch.split) + if not all(s == 1 for s in split_sizes): + return + if isinstance(dim, Sequence): + return + next_users = find_next_users(split) + if not all(node.target == torch.squeeze for node in next_users): + return + with graph.inserting_before(match.output_node()): + unbind = graph.call_function( + torch.unbind, args=(split_input,), kwargs={"dim": dim} + ) + if is_node_meta_valid(split_input): + unbind.meta["example_value"] = torch.unbind( + split_input.meta["example_value"], dim=dim + ) + for item_index, getitem_node in sorted( + [ + (getitem_node.args[1], getitem_node) + for getitem_node in 
split.users.keys() + ] + ): + squeeze = next(iter(getitem_node.users.keys())) + new_get_item = graph.call_function( + operator.getitem, args=(unbind, item_index) + ) + squeeze.replace_all_uses_with(new_get_item) + new_get_item.meta.update(squeeze.meta) + graph.erase_node(squeeze) + graph.erase_node(getitem_node) + graph.erase_node(split) + counters["inductor"]["split_cat_pass"] += 1 + + +getitem_unbind = ListOf( + GetItem( + CallFunction( + torch.unbind, + KeywordArg("unbind_input"), + dim=KeywordArg("dim"), + _users=MULTIPLE, + ), + Ignored(), + _users=MULTIPLE, + ), + partial=True, +) + + +@register_graph_pattern( + CallFunction([torch.stack, torch.cat], getitem_unbind, Ignored(), _users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("unbind_stack_pass"), +) +@register_graph_pattern( + CallFunction( + [torch.stack, torch.cat], getitem_unbind, dim=Ignored(), _users=MULTIPLE + ), + pass_dict=construct_pattern_matcher_pass("unbind_stack_pass"), +) +@register_graph_pattern( + CallFunction( + [torch.stack, torch.cat], tensors=getitem_unbind, dim=Ignored(), _users=MULTIPLE + ), + pass_dict=construct_pattern_matcher_pass("unbind_stack_pass"), +) +def merge_unbind_stack(match: Match, unbind_input: torch.fx.Node, dim: int): + unbind_node = next(node for node in match.nodes if node.target == torch.unbind) + UnbindCatRemover().remove_unbind(match.graph, unbind_node) + + +getitem_split = ListOf( + CallFunction( + operator.getitem, + TorchSplit( + Ignored(), + KeywordArg("split_sections"), + ), + Ignored(), + _users=MULTIPLE, + ), + partial=True, +) + + +reshape_getitem_split = ListOf( + CallFunction( + torch.reshape, + CallFunction( + operator.getitem, + TorchSplit( + Ignored(), + KeywordArg("split_sections"), + ), + Ignored(), + _users=MULTIPLE, + ), + Arg(), + _users=MULTIPLE, + ), + partial=True, +) + + +@register_graph_pattern( + CallFunction( + [torch.stack, torch.cat], + tensors=getitem_split, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("split_cat_pass"), +) +@register_graph_pattern( + CallFunction( + [torch.stack, torch.cat], + getitem_split, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("split_cat_pass"), +) +@register_graph_pattern( + CallFunction( + [torch.stack, torch.cat], + getitem_split, + Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("split_cat_pass"), +) +def simplify_split_cat(match: Match, split_sections: list[int], dim: int): + if not isinstance(split_sections, (list, tuple)): # Unnormalized split + return + split_node = next(node for node in match.nodes if node.target == torch.split) + SplitCatSimplifier().simplify(match.graph, split_node, split_sections) + + +# noqa: W605 +# ############pattern to be optimized is######### + +# split_node(dim=1) +# / \ ... / \ +# getitem getitem getitem getitem -> user=1 +# \ / \ / +# cat (user=mul, dim=1) cat(user=mul, dim=1) +# | \ | \ + +# ################after transformation############# + +# split_node(dim=1) +# / ... 
\ +# getitem getitem +# | \ | \ + + +def has_same_parent_node(node: torch.fx.Node): + # the input nodes of the node should come from the same parent + prev_node = None + for getitem in node.args[0]: # type: ignore[union-attr] + if getitem.target != operator.getitem: # type: ignore[union-attr] + return False + if prev_node is None: + prev_node = getitem.args[0] # type: ignore[union-attr] + else: + if getitem.args[0] != prev_node: # type: ignore[union-attr] + return False + return True + + +def remove_zeros(split_sections: list[int]): + """ + Remove zeros from the list and get the index mapping dict from getitem + in split node to getitem in new split node + """ + new_split_sections, index_mapping = [], {} + idx = 0 + for i in range(len(split_sections)): + if split_sections[i] > 0: + new_split_sections.append(split_sections[i]) + index_mapping[i] = idx + idx += 1 + + return new_split_sections, index_mapping + + +def is_sorted_and_consecutive(arr: list[int]) -> bool: + # check if the array is sorted + if arr == sorted(arr): + # check if the differences between adjacent elements are all 1 + return all(x[1] - x[0] == 1 for x in zip(arr, arr[1:])) + else: + return False + + +def calculate_fused_tensor_size(split_node: torch.fx.Node, indices: list[int]) -> int: + """ + Calculate the fused tensor size in the indices + """ + fused_tensor_size = 0 + for i in range(len(split_node.args[1])): # type: ignore[arg-type] + if i in indices: + fused_tensor_size += split_node.args[1][i] # type: ignore[operator, assignment, index] + return fused_tensor_size + + +@register_graph_pattern( + CallFunction( + torch.cat, + getitem_split, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("merge_getitem_cat_pass"), +) +def merge_getitem_cat(match: Match, split_sections: list[int], dim: int): + if not isinstance(split_sections, (list, tuple)): # Unnormalized split + return + graph = match.graph + split_node = next(node for node in match.nodes if node.target == torch.split) + split_input, _split_size, split_dim = _get_split_args_default(split_node) + # if the cat and split have different dims, return + # Find the next users (i.e. users after the getitem) + next_users = find_next_users(split_node) + # 'immutable_list' object does not support mutation. 
Create a new copy of it + split_sections = list(split_sections) + for cat_user in next_users: + if cat_user.target == torch.cat: + cat_dim = get_arg_value(cat_user, 1, "dim") + # check the all getitems in the cat_user from the same node + # check the input of the cat has all getitem from the split + # check all getitem only has one single user + if ( + split_dim != cat_dim + or not has_same_parent_node(cat_user) + or not all(len(arg.users) == 1 for arg in cat_user.args[0]) # type: ignore[union-attr] + ): + continue + # find the index of getitems to be cated/stacked + # type: ignore[union-attr] + indices = [arg.args[1] for arg in cat_user.args[0]] # type: ignore[union-attr] + # the getitems to be merged must be consecutive, otherwise + # returned sliced tensor could be wrong + if not is_sorted_and_consecutive(indices): # type: ignore[arg-type] + continue + # update the arg of cat user, only keep the first getitem + cat_user.update_arg(0, cat_user.args[0][0]) # type: ignore[index] + # calculate the fused tensor sizes in the indices + fused_tensor_size = 0 + for i in range(len(split_node.args[1])): # type: ignore[arg-type] + if i in indices: + fused_tensor_size += split_node.args[1][i] # type: ignore[operator, assignment, index] + # update the split sections + split_sections[indices[0]] = calculate_fused_tensor_size( # type: ignore[index] + split_node, + indices, # type: ignore[arg-type] + ) + # padding others with zeros to keep the same dict size + for i in indices[1:]: + split_sections[i] = 0 # type: ignore[index] + # remove all unused indexes in the split_node + new_split_sections, index_mapping = remove_zeros(split_sections) + with graph.inserting_after(split_node): + new_split_node = graph.call_function( + torch.split, + args=(split_input, split_sections), + kwargs={"dim": split_dim}, + ) + split_node.replace_all_uses_with(new_split_node) + new_split_node.meta.update(split_node.meta) + # remove all unused getitem nodes + to_remove = [cat_user] + # dictionary keys changed during iteration + new_split_getitem_nodes = list(new_split_node.users.keys()) + for getitem_node in new_split_getitem_nodes: + if getitem_node.args[1] in indices[1:]: + to_remove.append(getitem_node) + # update meta data of getitem + elif getitem_node.args[1] == indices[0]: + cat_user.replace_all_uses_with(getitem_node) + getitem_node.meta.update(cat_user.meta) + else: + # update getitem index for new split node + getitem_node.update_arg(1, index_mapping[getitem_node.args[1]]) + graph.erase_node(split_node) + for getitem_node in to_remove: + graph.erase_node(getitem_node) + # update the split sections of new split node + new_split_node.update_arg(1, new_split_sections) + split_node = new_split_node + split_sections = new_split_sections + + counters["inductor"]["merge_getitem_cat_pass"] += 1 + + +# ############pattern to be optimized is######### + +# split_node(dim=1) -> user=multiple +# / \ ... / \ +# getitem getitem getitem getitem -> user=multiple +# \ \ / \ +# other_op /cat(user=mul, dim=1) other_op +# | + +# ################after transformation############# + +# split_node(dim=1) -> -> user=multiple +# / \ ... 
/ \ +# getitem getitem getitem getitem -> user=multiple +# \ \ / \ +# other_op + + +@register_graph_pattern( + CallFunction( + torch.cat, + getitem_split, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("mutate_cat_pass"), +) +def mutate_cat_node(match: Match, split_sections: list[int], dim: int): + if not isinstance(split_sections, (list, tuple)): # Unnormalized split + return + graph = match.graph + split_node = next(node for node in match.nodes if node.target == torch.split) + _split_input, _split_size, split_dim = _get_split_args_default(split_node) + # if the cat and split have different dims, return + # Find the next users (i.e. users after the getitem) + next_users = find_next_users(split_node) + for cat_user in next_users: + if cat_user.target == torch.cat: + cat_dim = get_arg_value(cat_user, 1, "dim") or 0 + # check that all getitems in the cat_user from the same node + # check the input of the cat has all getitem from the split + if split_dim != cat_dim or not has_same_parent_node(cat_user): + continue + # find the index of getitems to be cat + indices, idx_to_getitem = [], {} + for getitem in cat_user.args[0]: # type: ignore[union-attr] + indices.append(getitem.args[1]) # type: ignore[union-attr] + idx_to_getitem[getitem.args[1]] = getitem # type: ignore[union-attr] + # the getitems to be merged must be consecutive, otherwise + # returned sliced tensor could be wrong + if not is_sorted_and_consecutive(indices): # type: ignore[arg-type] + continue + # case 1: the cat uses all getitems from the split + if len(split_sections) == len(cat_user.args[0]): # type: ignore[arg-type] + # replace the users of the cat node to be the input of the split node + cat_user.replace_all_uses_with(split_node.args[0]) # type: ignore[arg-type] + # remove the cat node + graph.erase_node(cat_user) + counters["inductor"]["mutate_cat_pass"] += 1 + # case 2: the cat uses some getitems from the split + elif is_node_meta_valid(split_node.args[0]): # type: ignore[arg-type] + # check the split dim, and construct the slice tuple + start_fused_size = calculate_fused_tensor_size( + split_node, + list(range(indices[0])), # type: ignore[arg-type] + ) + end_fused_size = start_fused_size + calculate_fused_tensor_size( + split_node, + indices, # type: ignore[arg-type] + ) + slice_list = [] + for i in range(len(split_node.args[0].meta["example_value"].shape)): # type: ignore[union-attr] + if i != split_dim: + slice_list.append(slice(None, None, None)) + else: + slice_list.append(slice(start_fused_size, end_fused_size, None)) + with graph.inserting_after(split_node): + slice_node = graph.call_function( + operator.getitem, + args=(split_node.args[0], tuple(slice_list)), + ) + cat_user.replace_all_uses_with(slice_node) + slice_node.meta.update(cat_user.meta) + + # remove the cat node + graph.erase_node(cat_user) + counters["inductor"]["mutate_cat_pass"] += 1 + + +getitem_split_aten = ListOf( + CallFunction( + operator.getitem, + CallFunctionVarArgs([torch.ops.aten.split_with_sizes.default], users=MULTIPLE), + Ignored(), + _users=MULTIPLE, + ), + partial=True, +) + + +@register_graph_pattern( + CallFunctionVarArgs(torch.ops.aten.split.Tensor, users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_aten_pass"), +) +def normalize_split_default_aten(match: Match, *args, **kwargs): + split_node = match.nodes[0] + graph = match.graph + split_input, split_size, split_dim = _get_split_args_default(split_node) + if split_input is None or split_dim is None or split_size is None: 
+ log.debug("couldn't find split args") + return + if not is_node_meta_valid(split_node): + log.debug("val absent for node: %s", split_node) + return + assert isinstance(split_node.meta["val"], (list, tuple)) + split_sections = [t.size()[split_dim] for t in split_node.meta["val"]] + if any(isinstance(section, torch.SymInt) for section in split_sections): + # TODO dynamic_shapes with assume_static_by_default=False fails while AOT Autograd tracing. + return + if split_dim < 0: # Normalize split dim + split_dim += split_input.meta["val"].dim() + split_section_list = [split_size] * (len(split_node.meta["val"])) + new_args = (split_input, split_section_list) + new_kwargs = {"dim": split_dim} + if ( + split_node.args == new_args + and split_node.kwargs == new_kwargs + and split_node.op == "call_function" + ): + return + + with graph.inserting_after(split_node): + new_split_node = graph.call_function( + torch.ops.aten.split_with_sizes.default, + args=new_args, + kwargs=new_kwargs, # type: ignore[arg-type] + ) + split_node.replace_all_uses_with(new_split_node) + new_split_node.meta.update(split_node.meta) + graph.erase_node(split_node) + counters["inductor"]["normalization_aten_pass"] += 1 + + +@register_graph_pattern( + CallFunctionVarArgs(torch.ops.aten.split_with_sizes.default, users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_aten_pass"), +) +def normalize_split_with_size_default_aten(match: Match, *args, **kwargs): + split_node = match.nodes[0] + graph = match.graph + split_input, split_sections, split_dim = _get_split_args_default(split_node) + if split_input is None or split_dim is None or split_sections is None: + log.debug("couldn't find split args") + return + if not is_node_meta_valid(split_node): + log.debug("val absent for node: %s", split_node) + return + if any(isinstance(section, torch.SymInt) for section in split_sections): + # TODO dynamic_shapes with assume_static_by_default=False fails while AOT Autograd tracing. 
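# ---------------------------------------------------------------------------
# (Editor's sketch, standalone and illustrative only.)  The two aten split
# normalization passes in this file rewrite splits into the explicit
# split_with_sizes form; for an evenly divisible split the two spellings
# produce identical chunks:
#
#     import torch
#     x = torch.randn(4, 6)
#     a = torch.ops.aten.split.Tensor(x, 2, dim=1)                      # int size
#     b = torch.ops.aten.split_with_sizes.default(x, [2, 2, 2], dim=1)  # sections
#     assert all(torch.equal(u, v) for u, v in zip(a, b))
# ---------------------------------------------------------------------------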
+ return + if split_dim < 0: # Normalize split dim + split_dim += split_input.meta["val"].dim() + + new_args = (split_input, split_sections) + new_kwargs = {"dim": split_dim} + if ( + split_node.args == new_args + and split_node.kwargs == new_kwargs + and split_node.op == "call_function" + ): + return + + with graph.inserting_after(split_node): + new_split_node = graph.call_function( + torch.ops.aten.split_with_sizes.default, + args=new_args, + kwargs=new_kwargs, # type: ignore[arg-type] + ) + split_node.replace_all_uses_with(new_split_node) + new_split_node.meta.update(split_node.meta) + graph.erase_node(split_node) + counters["inductor"]["normalization_aten_pass"] += 1 + + +@register_graph_pattern( + CallFunction( + torch.ops.aten.cat.default, + getitem_split_aten, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("split_cat_aten_pass"), +) +def merge_split_cat_aten(match: Match, *args, **kwargs): + graph = match.graph + split_node = match.nodes[0] + threshold_to_cat = torch._inductor.config.post_grad_fusion_options[ + "split_cat_aten_pass" + ].get("threshold_to_cat", 10) + # get the getitem nodes from the split node + getitem_nodes = list(split_node.users.keys()) + for cat_node in list(getitem_nodes[0].users.keys()): + cat_dim = get_arg_value(cat_node, 1, "dim") + cat_inputs = get_arg_value(cat_node, 0, "tensors") + if len(cat_inputs) < threshold_to_cat: + continue + # check split node and cat node has same dim, and all getitem nodes have same parent node + parent_to_indices = defaultdict(list) # type: ignore[var-annotated] + parent_to_getitems = defaultdict(list) # type: ignore[var-annotated] + for cat_input in cat_inputs: + # skip all non-getitem cat input + if cat_input.target != operator.getitem: + continue + current_getitem_parent = cat_input.args[0] + split_dim = get_arg_value(current_getitem_parent, 2, "dim") + if split_dim != cat_dim: + break + getitem_idx = cat_input.args[1] + if ( + current_getitem_parent not in parent_to_indices + ) or getitem_idx != parent_to_indices[current_getitem_parent][-1][-1] + 1: + parent_to_indices[current_getitem_parent].append([getitem_idx]) + parent_to_getitems[current_getitem_parent].append([cat_input]) + else: + parent_to_getitems[current_getitem_parent][-1].append(cat_input) + parent_to_indices[current_getitem_parent][-1].append(getitem_idx) + + cat_inputs_list = list(cat_inputs) + update_cat_arg = [] + # iterate through the indices to construct the slice nodes + for parent, indices in parent_to_indices.items(): + for idx, indice in enumerate(indices): + start, end = indice[0], indice[-1] + split_sections = list(parent.args[1]) + input_of_current_getitem_parent = parent.args[0] + if len(indice) >= threshold_to_cat or len(indice) == len( + split_sections + ): + if len(indice) != len(split_sections): + # get the start and end slicing indices + slice_node = graph.call_function( + torch.ops.aten.slice.Tensor, + args=( + input_of_current_getitem_parent, + split_dim, # type: ignore[possibly-undefined] + sum(split_sections[:start]), + sum(split_sections[: end + 1]), + ), + ) + else: + slice_node = input_of_current_getitem_parent + # find the index in the cat_inputs_list given the getitem node + update_cat_arg.append( + ( + slice_node, + cat_inputs_list.index(parent_to_getitems[parent][idx][0]), + cat_inputs_list.index(parent_to_getitems[parent][idx][-1]), + ) + ) + + result = [] + i = 0 + for slice_tensor, start, end in update_cat_arg: + while i < start: + result.append(cat_inputs_list[i]) + i += 1 + 
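            # (Editor's note) the untouched inputs before this fused run were
            # copied by the loop above; substitute the single slice_tensor for
            # the whole run of getitems it replaces, then resume after the run.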
result.append(slice_tensor) + i = end + 1 + while i < len(cat_inputs_list): + result.append(cat_inputs_list[i]) + i += 1 + + cat_node.update_arg(0, result) + for getitem_node in getitem_nodes: + if len(getitem_node.users) == 0: + graph.erase_node(getitem_node) + if len(split_node.users) == 0: + graph.erase_node(split_node) + counters["inductor"]["split_cat_aten_pass"] += 1 + + +@register_graph_pattern( + CallFunction( + torch.ops.aten.cat.default, + ListOf( + CallFunctionVarArgs(torch.ops.aten.select.int, users=MULTIPLE), + partial=True, + ), + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("select_cat_aten_pass"), +) +def merge_select_cat_aten(match: Match, *args, **kwargs): + graph = match.graph + node = match.nodes[0] + node_input = get_arg_value(node, 0, "tensors") + # get the select nodes from the node + select_nodes = list(node_input.users.keys()) + for cat_node in list(node.users.keys()): + if cat_node.target == torch.ops.aten.cat.default: + cat_dim = get_arg_value(cat_node, 1, "dim") + cat_inputs = get_arg_value(cat_node, 0, "tensors") + # check all select nodes has same slice dim + if not all( + select_node.args[1] == select_nodes[0].args[1] + for select_node in select_nodes + ): + continue + # We only consider the case where selece slice dim and cat node has same dim + if select_nodes[0].args[1] != cat_dim: + continue + if not is_node_meta_valid(cat_node): + continue + # check the cat node has consecutive indices + indices = [select.args[2] for select in cat_node.args[0]] # type: ignore[union-attr] + if ( + not is_sorted_and_consecutive(indices) # type: ignore[arg-type] + or len(select_nodes) != len(cat_inputs) + ): + continue + # check all the select nodes can be merged to the cat node input + if len(indices) != select_nodes[0].args[0].meta["val"].shape[cat_dim]: # type: ignore[union-attr] + continue + # reshape the node input to be the same shape as the cat node + with graph.inserting_before(node): + view_node = graph.call_function( + torch.ops.aten.view.default, + args=(node_input, cat_node.meta["val"].shape), + ) + # replace the node input with the new node + cat_node.replace_all_uses_with(view_node) + view_node.meta.update(cat_node.meta) + # remove the cat node + graph.erase_node(cat_node) + for select_node in select_nodes: + if len(select_node.users) == 0: + graph.erase_node(select_node) + counters["inductor"]["select_cat_aten_pass"] += 1 + + +@register_graph_pattern( + CallFunctionVarArgs(torch.ops.aten.cat.default, users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_aten_pass"), +) +def normalize_cat_default_aten(match: Match, *args, **kwargs): + cat_node = match.nodes[0] + graph = match.graph + tensors = get_arg_value(cat_node, 0, "tensors") + cat_dim = get_arg_value(cat_node, 1, "dim") + if cat_dim is None: + cat_axis = cat_node.kwargs.get("axis") + if cat_axis is not None: + cat_dim = cat_axis + else: + cat_dim = 0 + if tensors is None or cat_dim is None: + log.debug("couldn't find cat args") + return + assert isinstance(tensors, (list, tuple)) + for tensor in itertools.chain([cat_node], tensors): + if "val" not in tensor.meta: + log.debug("val absent for node: %s", tensor) + return + + ndim = cat_node.meta["val"].dim() + + def is_empty_tensor(x: torch.fx.Node) -> bool: + # special case where torch.ops.aten.cat.default supports cat'ing with an empty tensor + x_shape = x.meta["val"].shape + return len(x_shape) == 1 and x_shape[0] == 0 + + assert all(ndim == x.meta["val"].dim() or is_empty_tensor(x) for x in 
tensors) + + if cat_dim < 0: # Normalize cat dim + cat_dim += ndim + + with graph.inserting_after(cat_node): + new_cat_node = graph.call_function( + torch.ops.aten.cat.default, + args=(tensors,), + kwargs={"dim": cat_dim}, + ) + cat_node.replace_all_uses_with(new_cat_node) + new_cat_node.meta.update(cat_node.meta) + graph.erase_node(cat_node) + counters["inductor"]["normalization_aten_pass"] += 1 + + +@register_graph_pattern( + CallFunction( + torch.ops.aten.cat, + ListOf(CallFunctionVarArgs(torch.ops.aten.unsqueeze)), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("unbind_stack_aten_pass"), +) +def merge_unbind_stack_aten(match: Match, *args, **kwargs): + node = match.nodes[-1] + graph = match.graph + # pyre-fixme[6] + unsqueeze_nodes = list(node.args[0]) # type: ignore[arg-type] + cat_dim = get_arg_value(node, 1, "dim") + # check the unsqueeze nodes come from the select nodes + if not all( + get_arg_value(unsqueeze_node, 0, "input").target == torch.ops.aten.select + for unsqueeze_node in unsqueeze_nodes + ): + return + select_nodes = [ + get_arg_value(unsqueeze_node, 0, "input") for unsqueeze_node in unsqueeze_nodes + ] + parent_of_select_node = get_arg_value(select_nodes[0], 0, "input") + # check the target of select_nodes are the same + if not all( + select_node.target == torch.ops.aten.select for select_node in select_nodes + ): + return + # check the select nodes come from the same parent node + if not all( + get_arg_value(select_node, 0, "input") == parent_of_select_node + for select_node in select_nodes + ): + return + if len(unsqueeze_nodes) != len(select_nodes): + return + # check the select nodes have the same dim + if not all( + get_arg_value(select_node, 1, "dim") == cat_dim for select_node in select_nodes + ): + return + # check the select nodes have consecutive indices starting from 0 + if get_arg_value(select_nodes[0], 2, "index") != 0 or not is_sorted_and_consecutive( + [get_arg_value(select_node, 2, "index") for select_node in select_nodes] + ): + return + # check the users of parent of select node only from unsqueeze nodes that go to the cat node + # we simply check the number of users of the parent of select node + if len(parent_of_select_node.users.keys()) != len(node.args[0]): # type: ignore[arg-type] + return + node.replace_all_uses_with(parent_of_select_node) + graph.erase_node(node) + for unsqueeze_node in unsqueeze_nodes: + graph.erase_node(unsqueeze_node) + for select_node in select_nodes: + if len(select_node.users) == 0: + graph.erase_node(select_node) + counters["inductor"]["unbind_stack_aten_pass"] += 1 + + +def divide_into_consecutive_sublists(indices: list[int]) -> list[list[int]]: + n = len(indices) + if n <= 1: + return [indices] + + # Initialize the list of sublists + sublists = [] + + # Iterate over the indices + i = 0 + while i < n: + # Initialize the current sublist + sublist = [indices[i]] + + # Iterate over the remaining indices + j = i + 1 + while j < n and indices[j] == indices[j - 1] + 1: + # Add the next index to the current sublist + sublist.append(indices[j]) + j += 1 + + # Add the current sublist to the list of sublists + sublists.append(sublist) + # Move to the next index + i = j + + return sublists + + +def update_args_from_split_getitem( + graph: torch.fx.Graph, + node: torch.fx.Node, + getitem_indices: list[int], + parents_seen: list[torch.fx.Node], + new_cat_args: list[torch.fx.Node], + new_cat_args_meta: list[torch.fx.Node], + idx_to_getitems: dict[int, torch.fx.Node], + threshold_to_cat: int = 2, +): + split_input, 
split_size, split_dim = _get_split_args_default(parents_seen[-1]) + # case 1: the number of getitems is the same as the split size, eliminate the split + if len(split_size) == len(getitem_indices) and is_sorted_and_consecutive( + getitem_indices + ): + # we can merge the getitems from the previous parent + new_cat_args.append(split_input) + new_cat_args_meta.append(split_input.meta["example_value"]) + else: + if len(getitem_indices) > 0: + # case 2: the number of getitems is smaller than the split size but larger than the threshold, and + # the indices of getitems are not all consecutive, we need to divide the indices into multiple groups + geitem_indices_sublist = divide_into_consecutive_sublists(getitem_indices) + for sublist in geitem_indices_sublist: + if len(sublist) >= threshold_to_cat: + # case 2: the number of getitems is smaller than the split size but larger than the threshold + # we need to slice the input of parent + start_fused_size = sum(split_size[: sublist[0]]) + end_fused_size = sum(split_size[: sublist[-1] + 1]) + slice_list = [] + for i in range(len(split_input.meta["example_value"].shape)): # type: ignore[union-attr] + if i != split_dim: + slice_list.append(slice(None, None, None)) + else: + slice_list.append( + slice(start_fused_size, end_fused_size, None) + ) + with graph.inserting_after(node): + slice_node = graph.call_function( + operator.getitem, + args=(split_input, tuple(slice_list)), + ) + slice_node.meta["example_value"] = split_input.meta[ + "example_value" + ][tuple(slice_list)] + new_cat_args.append(slice_node) + new_cat_args_meta.append(slice_node.meta["example_value"]) + else: + # case 3: the number of getitems is smaller than the threshold, no merge is done + # get the getitems based on the indexes + for i in sublist: + new_cat_args.append(idx_to_getitems[i]) + new_cat_args_meta.append( + idx_to_getitems[i].meta["example_value"] + ) + + +def reshape_cat_node( + graph: torch.fx.Graph, + cat_node: torch.fx.Node, + unbind_input: torch.fx.Node, + cat_dim: int, + unbind_dim: int, + cat_shape: torch.Size, +) -> torch.fx.Node: + if cat_dim != unbind_dim: + # construct the permute node args, which has the same shape as the slice node + # then it has the same dim as the unbind_input, i.e., shape of cat + 1 + with graph.inserting_after(cat_node): + permute_list = list(range(len(cat_shape) + 1)) + permute_list[unbind_dim], permute_list[cat_dim] = ( + permute_list[cat_dim], + permute_list[unbind_dim], + ) + permute_node = graph.call_function( + torch.permute, + args=(unbind_input, permute_list), + ) + permute_node.meta["example_value"] = torch.permute( + unbind_input.meta["example_value"], permute_list + ) # type: ignore[arg-type] + else: + permute_node = unbind_input + with graph.inserting_after(permute_node): + reshape_node = graph.call_function( + torch.reshape, args=(permute_node, tuple(cat_shape)) + ) + reshape_node.meta["example_value"] = torch.reshape( + permute_node.meta["example_value"], tuple(cat_shape) + ) # type: ignore[arg-type] + return reshape_node + + +def update_args_from_unbind_getitem( + graph: torch.fx.Graph, + node: torch.fx.Node, # cat or stack node + getitem_indices: list[int], + parents_seen: list[torch.fx.Node], + new_cat_args: list[torch.fx.Node], + new_cat_args_meta: list[torch.fx.Node], + idx_to_getitems: dict[int, torch.fx.Node], + threshold_to_cat: int = 2, +): + unbind_input = get_arg_value(parents_seen[-1], 0, "input") # split or unbind input + unbind_dim = get_arg_value(parents_seen[-1], 1, "dim") # split or unbind dim + cat_dim = 
get_arg_value(node, 1, "dim") # cat or stack dim + # case 1: the number of getitems is the same as the split size, eliminate the split + size = list(unbind_input.meta["example_value"].shape)[unbind_dim] + if size == len(getitem_indices): + cat_shape = torch.cat( + [idx_to_getitems[i].meta["example_value"] for i in getitem_indices], + dim=cat_dim, + ).shape + # we can merge the getitems from the previous parent + reshape_node = reshape_cat_node( + graph, node, unbind_input, cat_dim, unbind_dim, cat_shape + ) + new_cat_args.append(reshape_node) + new_cat_args_meta.append(reshape_node.meta["example_value"]) + elif len(getitem_indices) >= threshold_to_cat and is_sorted_and_consecutive( + getitem_indices + ): + # case 2: the number of getitems is smaller than the split size but larger than the threshold + # we need to slice the input of parent + cat_shape = torch.cat( + [idx_to_getitems[i].meta["example_value"] for i in getitem_indices], + dim=cat_dim, + ).shape + slice_list = [] + for i in range(len(cat_shape) + 1): + if i != unbind_dim: + slice_list.append(slice(None, None, None)) # start, end, step + else: + slice_list.append( + slice(getitem_indices[0], getitem_indices[-1] + 1, None) + ) + with graph.inserting_after(node): + slice_node = graph.call_function( + operator.getitem, + args=(unbind_input, tuple(slice_list)), + ) + slice_node.meta["example_value"] = torch.narrow( + unbind_input.meta["example_value"], + unbind_dim, + getitem_indices[0], + getitem_indices[-1] - getitem_indices[0] + 1, + ) + reshape_node = reshape_cat_node( + graph, node, slice_node, cat_dim, unbind_dim, cat_shape + ) + new_cat_args.append(reshape_node) + new_cat_args_meta.append(reshape_node.meta["example_value"]) + else: + # case 3: the number of getitems is smaller than the threshold, no merge is done + # get the getitems based on the indexes + for i in getitem_indices: + new_cat_args.append(idx_to_getitems[i]) + new_cat_args_meta.append(idx_to_getitems[i].meta["example_value"]) + + +def construct_cat_args( + graph: torch.fx.Graph, + cat_or_stack_node: torch.fx.Node, + inputs: list[torch.fx.Node], + split_or_unbind_node: torch.fx.Node, + threshold_to_cat: int = 2, + run_update_func: Callable = update_args_from_split_getitem, # type: ignore[type-arg] +) -> tuple[list[torch.fx.Node], list[torch.Tensor]]: + new_cat_args, parents_seen, getitem_indices, idx_to_getitems = [], [], [], {} # type: ignore[var-annotated] + new_cat_args_meta = [] # type: ignore[var-annotated] + for input in inputs: + if input.target != operator.getitem: + # update the last arg based on getitem_indices and parents_seens + if len(parents_seen) > 0: + run_update_func( # type: ignore[arg-type, union-attr] + graph, + cat_or_stack_node, + getitem_indices, + parents_seen, + new_cat_args, + new_cat_args_meta, + idx_to_getitems, # type: ignore[arg-type, union-attr] + threshold_to_cat, + ) + new_cat_args.append(input) + new_cat_args_meta.append(input.meta["example_value"]) + # reset the indices array + getitem_indices, idx_to_getitems = [], {} + else: + # get the parent node of the getitem input + parent, idx = input.args[0], input.args[1] # type: ignore[union-attr] + if parent.target != split_or_unbind_node.target: # type: ignore[union-attr] + new_cat_args.append(input) + new_cat_args_meta.append(input.meta["example_value"]) + continue + # cannot use parents_seen to check since the first item could be non getitem node + if len(parents_seen) == 0: + parents_seen.append(parent) + idx_to_getitems[idx] = input + getitem_indices.append(idx) + # case: we 
only have one getitem input, and it is in the last position + if input == inputs[-1]: + new_cat_args.append(input) + new_cat_args_meta.append(input.meta["example_value"]) + continue + # if it is the last input in the tensors, we also check if it can be optimized + if parent != parents_seen[-1] or input == inputs[-1]: + if input == inputs[-1]: + getitem_indices.append(idx) + idx_to_getitems[idx] = input + run_update_func( # type: ignore[arg-type, union-attr] + graph, + cat_or_stack_node, + getitem_indices, + parents_seen, + new_cat_args, + new_cat_args_meta, + idx_to_getitems, # type: ignore[arg-type, union-attr] + threshold_to_cat, + ) + # reset the indices array for the next parent + # remember to add the last element since it is the first + # item in this round of parent + # add the parent to the list of seen parents + parents_seen.append(parent) + getitem_indices, idx_to_getitems = [idx], {idx: input} + else: + getitem_indices.append(idx) + idx_to_getitems[idx] = input + return new_cat_args, new_cat_args_meta + + +def remove_split_unbind_children(graph: torch.fx.Graph, inputs: list[torch.fx.Node]): + nodes = OrderedSet[Any]() + for input in inputs: + if input.target == operator.getitem: + nodes.add(input.args[0]) # type: ignore[union-attr] + if len(input.users.keys()) == 0: + graph.erase_node(input) + # check the split node to remove if it has no users + for node in nodes: + if len(node.users.keys()) == 0: # type: ignore[union-attr] + graph.erase_node(node) # type: ignore[arg-type] + + +# ############pattern to be optimized is######### + +# split_node(dim=1) -> user=multiple +# / \ ... / \ +# other inputs getitem getitem getitem -> user=multiple +# \ / \ +# cat(user=mul, dim=1) other_op +# | + +# ################after transformation############# + +# split_node(dim=1) other inputs -> -> user=multiple +# / \ +# cat (user=mul, dim=1, split_node) + + +@register_graph_pattern( + CallFunction( + torch.cat, + getitem_split, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("split_cat_to_slices_pass"), +) +def split_cat_to_slices(match: Match, split_sections: list[int], dim: int): + if not isinstance(split_sections, (list, tuple)): # Unnormalized split + return + split_nodes = [node for node in match.nodes if node.target == torch.split] + if split_nodes: + split_node = next(node for node in split_nodes) + else: + # Handle the case where there are no nodes with a target of torch.split + return + split_dim = get_arg_value(split_node, 2, "dim") or 0 + graph = match.graph + threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[ + "split_cat_to_slices_pass" + ].get("threshold_to_cat", 10) + # get the cat_node and check its inputs and meta data + next_users = find_next_users(split_node) + for cat_node in next_users: + if cat_node.target != torch.cat or not is_node_meta_valid(cat_node): + continue + cat_inputs = get_arg_value(cat_node, 0, "tensors") # type: ignore[union-attr] + new_cat_args, _ = construct_cat_args( + graph, + cat_node, + cat_inputs, + split_node, + threshold_to_cat, + update_args_from_split_getitem, + ) + # At least one node would be in the returned new_cat_args + # case 1: if new cat args has length 1, we can remove the cat node + if len(new_cat_args) == 1: + cat_node.replace_all_uses_with(new_cat_args[0]) + # remove inputs of cat_node if they have no users + cat_inputs = cat_node.args[0] # type: ignore[union-attr] + graph.erase_node(cat_node) + remove_split_unbind_children(graph, cat_inputs) # type: ignore[arg-type] + 
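# ---------------------------------------------------------------------------
# (Editor's sketch, standalone and illustrative only.)  This pass rests on the
# identity that concatenating a consecutive run of split outputs along the
# split dim is the same as slicing the split input directly:
#
#     import torch
#     x = torch.randn(5, 9)
#     pieces = torch.split(x, [2, 3, 4], dim=1)
#     assert torch.equal(torch.cat(pieces[1:3], dim=1), x[:, 2:9])
#     # i.e. getitems 1..2 collapse to x[:, sum([2]) : sum([2, 3, 4])]
# ---------------------------------------------------------------------------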
counters["inductor"]["split_cat_to_slices_pass"] += 1 + continue + if len(new_cat_args) > 1 and len(new_cat_args) < len(cat_inputs): + new_args = (new_cat_args,) + with graph.inserting_after(cat_node): + new_cat_node = graph.call_function( + torch.cat, + args=new_args, + # split and cat have the same dim + kwargs={"dim": split_dim}, + ) + cat_node.replace_all_uses_with(new_cat_node) + new_cat_node.meta.update(cat_node.meta) + # remove the cat node + graph.erase_node(cat_node) + remove_split_unbind_children(graph, cat_inputs) + counters["inductor"]["split_cat_to_slices_pass"] += 1 + + +# ############pattern to be optimized is######### + +# unbind(dim=0) -> user=multiple +# / \ ... / \ +# getitem getitem getitem getitem -> user=multiple +# \ / \ +# cat(user=mul, dim=1) other_op +# | + +# ################after transformation############# + +# input_of_unbind +# | \ +# slice +# | +# view +# | + + +@register_graph_pattern( + CallFunction( + torch.cat, + getitem_unbind, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("unbind_cat_to_view_pass"), +) +def unbind_cat_to_view(match: Match, unbind_input: torch.fx.Node, dim: int): + unbind_node = next(node for node in match.nodes if node.target == torch.unbind) + graph = match.graph + # get the cat_node and check its inputs and meta data + next_users = find_next_users(unbind_node) + threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[ + "unbind_cat_to_view_pass" + ].get("threshold_to_cat", 10) + # get the cat_node and check its inputs and meta data + for cat_node in next_users: + if cat_node.target != torch.cat or not is_node_meta_valid(cat_node): + continue + inputs = get_arg_value(cat_node, 0, "tensors") # type: ignore[union-attr] + new_cat_args, new_cat_args_meta = construct_cat_args( + graph, + cat_node, + inputs, + unbind_node, + threshold_to_cat, + update_args_from_unbind_getitem, + ) + # get the view shape + # At least one node would be in the returned new_cat_args + # case 1: only one node in the new cat args, don't need to cat + if len(new_cat_args) == 1: + cat_node.replace_all_uses_with(new_cat_args[0]) + # remove inputs of cat_node if they have no users + cat_inputs = cat_node.args[0] # type: ignore[union-attr] + graph.erase_node(cat_node) + remove_split_unbind_children(graph, cat_inputs) # type: ignore[arg-type] + counters["inductor"]["unbind_cat_to_view_pass"] += 1 + continue + if len(new_cat_args) > 1 and len(new_cat_args) < len(inputs): + # get the view shape + cat_dim = get_arg_value(cat_node, 1, "dim") + with graph.inserting_after(cat_node): + new_cat_node = graph.call_function( + torch.cat, + args=(new_cat_args,), + kwargs={"dim": cat_dim}, + ) + new_cat_node.meta["example_value"] = torch.cat( + new_cat_args_meta, dim=cat_dim + ) # type: ignore[arg-type] + cat_node.replace_all_uses_with(new_cat_node) + new_cat_node.meta.update(cat_node.meta) + # remove inputs of cat_node if they have no users + cat_inputs = cat_node.args[0] # type: ignore[union-attr] + graph.erase_node(cat_node) + remove_split_unbind_children(graph, cat_inputs) # type: ignore[arg-type] + counters["inductor"]["unbind_cat_to_view_pass"] += 1 + + +def reshape_cat_node_to_stack( + graph: torch.fx.Graph, + cat_node: torch.fx.Node, + stack_node: torch.fx.Node, + split_or_unbind_dim: int, +) -> None: + # reshape the cat node to the stack node shape + stack_shape = stack_node.meta["example_value"].shape + stack_dim = _get_dim(stack_node) + if stack_dim != split_or_unbind_dim: + # case 1: the stack dim is not the same as the 
split dim + # we need to reshape the split input before we do the reshape + reshape_list = list(stack_shape) + reshape_list[stack_dim], reshape_list[split_or_unbind_dim] = ( + reshape_list[split_or_unbind_dim], + reshape_list[stack_dim], + ) + reshape_node = graph.call_function( + torch.reshape, + args=(cat_node, tuple(reshape_list)), + ) + reshape_node.meta["example_value"] = torch.reshape( + cat_node.meta["example_value"], tuple(reshape_list) + ) + permute_list = list(range(len(stack_shape))) + permute_list[stack_dim], permute_list[split_or_unbind_dim] = ( + permute_list[split_or_unbind_dim], + permute_list[stack_dim], + ) + permute_node = graph.call_function( + torch.permute, + args=(reshape_node, permute_list), + ) + permute_node.meta["example_value"] = torch.permute( + reshape_node.meta["example_value"], permute_list + ) + else: + # case 2: the stack dim is the same as the split dim + # we can directly reshape the split input + permute_node = cat_node + reshape_node = graph.call_function( + torch.Tensor.view, + args=(permute_node, *stack_shape), # type: ignore[arg-type] + ) + stack_node.replace_all_uses_with(reshape_node) + reshape_node.meta.update(stack_node.meta) + stack_inputs = stack_node.args[0] # type: ignore[union-attr] + # remove stack node + graph.erase_node(stack_node) + # check the input of stack node, and remove nodes that have no users + remove_split_unbind_children(graph, stack_inputs) # type: ignore[arg-type] + + +def convert_reshape_cat_arg_to_stack( + graph: torch.fx.Graph, + cat_node: torch.fx.Node, + stack_node: torch.fx.Node, + stack_node_shape: torch.Size, + stack_dim: int, + split_dim: int, +) -> torch.fx.Node: + # reshape the cat node to the stack node shape + cat_shape = cat_node.meta["example_value"].shape + if stack_dim != split_dim: + permute_list = list(range(len(cat_shape))) + permute_list[stack_dim], permute_list[split_dim] = ( + permute_list[split_dim], + permute_list[stack_dim], + ) + permute_node = graph.call_function( + torch.permute, + args=(cat_node, permute_list), + ) + permute_node.meta["example_value"] = torch.permute( + cat_node.meta["example_value"], permute_list + ) + else: + permute_node = cat_node + reshape_node = graph.call_function( + torch.Tensor.view, + args=(permute_node, tuple(stack_node_shape)), # type: ignore[arg-type] + ) + reshape_node.meta["example_value"] = torch.Tensor.view( + permute_node.meta["example_value"], + tuple(stack_node_shape), # type: ignore[arg-type] + ) + return reshape_node + + +# ############pattern to be optimized is######### +# | | +# split split (dim=1) +# / \ / \ +# getitem ... getitem other ops +# \ | / / +# stack(user=mul, dim=1 or 2) -> can be different dim +# | + +# ################after transformation############# + +# / \ ... 
/ \ +# getitem getitem getitem getitem -> user=multiple +# \ / +# cat(user=mul, dim=1) cat_other_opts +# \ / +# cat +# | +# view +# | + + +@register_graph_pattern( + CallFunction( + torch.stack, + getitem_split, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("split_stack_to_cats_pass"), +) +def split_stack_to_cats(match: Match, split_sections: list[int], dim: int): + if not isinstance(split_sections, (list, tuple)): # Unnormalized split + return + split_node = next(node for node in match.nodes if node.target == torch.split) + split_dim = get_arg_value(split_node, 2, "dim") or 0 + graph = match.graph + threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[ + "split_stack_to_cats_pass" + ].get("threshold_to_cat", 10) + # get the stack_node and check its inputs and meta data + next_users = find_next_users(split_node) + for stack_node in next_users: + if stack_node.target != torch.stack or not is_node_meta_valid(stack_node): + continue + inputs = get_arg_value(stack_node, 0, "tensors") # type: ignore[union-attr] + new_cat_args, new_cat_args_meta = construct_cat_args( + graph, + stack_node, + inputs, + split_node, + threshold_to_cat, + update_args_from_split_getitem, + ) + # At least one node would be in the returned new_cat_args + # case 1: only one node in the new cat args, don't need to cat + if len(new_cat_args) == 1: + reshape_cat_node_to_stack(graph, new_cat_args[0], stack_node, split_dim) + counters["inductor"]["split_stack_to_cats_pass"] += 1 + continue + if len(new_cat_args) > 1 and len(new_cat_args) < len(inputs): + with graph.inserting_after(stack_node): + cat_node = graph.call_function( + torch.cat, + args=(new_cat_args,), + kwargs={"dim": split_dim}, + ) + cat_node.meta["example_value"] = torch.cat( # type: ignore[arg-type] + new_cat_args_meta, dim=split_dim + ) + reshape_cat_node_to_stack(graph, cat_node, stack_node, split_dim) + counters["inductor"]["split_stack_to_cats_pass"] += 1 + + +# ############pattern to be optimized is######### + +# unbind(dim=1) -> user=multiple +# \ ... 
/ \ +# others getitem getitem getitem -> user=multiple +# \ \ / \ +# stack(user=mul, dim=1) other_op +# | + +# ################after transformation############# + +# input_of_unbind +# | \ +# slice +# | +# view others +# | / +# stack +# | + + +@register_graph_pattern( + CallFunction( + torch.stack, + getitem_unbind, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("unbind_stack_to_slices_pass"), +) +def unbind_stack_to_slices(match: Match, unbind_input: torch.fx.Node, dim: int): + unbind_node = next(node for node in match.nodes if node.target == torch.unbind) + graph = match.graph + # get the cat_node and check its inputs and meta data + next_users = find_next_users(unbind_node) + threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[ + "unbind_stack_to_slices_pass" + ].get("threshold_to_cat", 10) + # get the cat_node and check its inputs and meta data + for stack_node in next_users: + if stack_node.target != torch.stack or not is_node_meta_valid(stack_node): + continue + inputs = get_arg_value(stack_node, 0, "tensors") # type: ignore[union-attr] + new_cat_args, new_cat_args_meta = construct_cat_args( + graph, + stack_node, + inputs, + unbind_node, + threshold_to_cat, + update_args_from_unbind_getitem, + ) + unbind_dim = get_arg_value(unbind_node, 1, "dim") or 0 + # At least one node would be in the returned new_cat_args + # case 1: only one node in the new cat args, don't need to cat + if len(new_cat_args) == 1: + reshape_cat_node_to_stack(graph, new_cat_args[0], stack_node, unbind_dim) + counters["inductor"]["unbind_stack_to_slices_pass"] += 1 + continue + if len(new_cat_args) > 1 and len(new_cat_args) < len(inputs): + # get the view shape + cat_dim = get_arg_value(stack_node, 1, "dim") + with graph.inserting_after(stack_node): + new_cat_node = graph.call_function( + torch.cat, + args=(new_cat_args,), + kwargs={"dim": cat_dim}, + ) + new_cat_node.meta["example_value"] = torch.cat( + new_cat_args_meta, dim=cat_dim + ) + reshape_cat_node_to_stack(graph, new_cat_node, stack_node, unbind_dim) + counters["inductor"]["unbind_stack_to_slices_pass"] += 1 + + +# ############pattern to be optimized is######### +# input +# | +# split(dim=1) -> user=multiple +# \ \ +# others getitem getitem +# \ \ / +# reshape reshape reshape other_op +# \ \ / / +# stack(user=mul, dim=0) +# | + +# ################after transformation############# +# input +# | +# permute +# | +# reshape others +# | / +# cat (dim=0) +# | + + +def get_view_shape_list(cat_arg: torch.fx.Node, stack_dim: int) -> list[int]: + # cat_arg must be the split input + view_shape_list = [] + for user in cat_arg.users.keys(): + if user.target == torch.split: + for getitem in user.users.keys(): + if getitem.target == operator.getitem: + reshape_user = [ + user + for user in getitem.users.keys() + if user.target == torch.reshape + ] + if len(reshape_user) > 0: + view_shape_list = list( + reshape_user[0] + .meta["example_value"] + .unsqueeze(stack_dim) + .shape + ) + view_shape_list[stack_dim] = -1 + return view_shape_list + return view_shape_list + + +@register_graph_pattern( + CallFunction( + torch.stack, + reshape_getitem_split, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("move_reshape_out_of_split_stack_pass"), +) +def move_reshape_out_of_split_stack(match: Match, *args, **kwargs): + split_node = next(node for node in match.nodes if node.target == torch.split) + split_dim = _get_dim(split_node) + split_users = list(split_node.users.keys()) + stack_nodes = [node 
for node in match.nodes if node.target == torch.stack] + graph = match.graph + threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[ + "move_reshape_out_of_split_stack_pass" + ].get("threshold_to_cat", 10) + for stack_node in stack_nodes: + if not is_node_meta_valid(stack_node): + log.debug("example value absent for node: %s", stack_node) + continue + stack_dim = _get_dim(stack_node) + stack_inputs = get_arg_value(stack_node, 0, "tensors") # type: ignore[union-attr] + inputs = [] + for stack_input in stack_inputs: + if stack_input.target != torch.reshape: + inputs.append(stack_input) + else: + inputs.append(stack_input.args[0]) # type: ignore[union-attr] + new_cat_args, _new_cat_args_meta = construct_cat_args( + graph, + stack_node, + inputs, + split_node, + threshold_to_cat, + update_args_from_split_getitem, + ) + # At least one node would be in the returned new_cat_args + # case 1: only one node in the new cat args, don't need to cat + if len(new_cat_args) == 1: + reshape_node = convert_reshape_cat_arg_to_stack( + graph, + new_cat_args[0], + stack_node, + stack_node.meta["example_value"].shape, + stack_dim, + split_dim, + ) + stack_node.replace_all_uses_with(reshape_node) + # remove stack node + graph.erase_node(stack_node) + # check the input of stack node, and remove nodes that have no users + remove_split_unbind_children(graph, stack_inputs) # type: ignore[arg-type] + remove_split_unbind_children(graph, split_users) # type: ignore[arg-type] + counters["inductor"]["move_reshape_out_of_split_stack_pass"] += 1 + continue + if len(new_cat_args) > 1 and len(new_cat_args) < len(inputs): + # decompose the cat args into multiple stack nodes, i.e., we stack + # all the nodes exist in the stack inputs and reshape the rest followed by a cat + stack_node_input, stack_node_input_meta, cat_inputs = [], [], [] # type: ignore[var-annotated] + for cat_arg in new_cat_args: + if cat_arg not in stack_inputs: + if len(stack_node_input) > 0: + with graph.inserting_after(stack_node): + decomposed_stack_node = graph.call_function( + torch.stack, + args=(stack_node_input,), + kwargs={"dim": stack_dim}, + ) + decomposed_stack_node.meta["example_value"] = torch.stack( + stack_node_input_meta, dim=stack_dim + ) + cat_inputs.append(decomposed_stack_node) + # cat_arg must be the split input + view_shape_list = get_view_shape_list(cat_arg, stack_dim) + stack_node_shape = torch.reshape( + cat_arg.meta["example_value"], tuple(view_shape_list) + ).shape # type: ignore[union-attr] + cat_inputs.append( + convert_reshape_cat_arg_to_stack( + graph, + cat_arg, + stack_node, + stack_node_shape, + stack_dim, + split_dim, + ) + ) + stack_node_input, stack_node_input_meta = [], [] + else: + stack_node_input.append(cat_arg) + stack_node_input_meta.append(cat_arg.meta["example_value"]) + + if len(stack_node_input) > 0: + with graph.inserting_after(stack_node): + decomposed_stack_node = graph.call_function( + torch.stack, + args=(stack_node_input,), + kwargs={"dim": stack_dim}, + ) + decomposed_stack_node.meta["example_value"] = torch.stack( + stack_node_input_meta, dim=stack_dim + ) + cat_inputs.append(decomposed_stack_node) + + with graph.inserting_after(stack_node): + cat_node = graph.call_function( + torch.cat, + args=(cat_inputs,), + kwargs={"dim": stack_dim}, + ) + stack_node.replace_all_uses_with(cat_node) + cat_node.meta.update(stack_node.meta) + graph.erase_node(stack_node) + remove_split_unbind_children(graph, stack_inputs) # type: ignore[arg-type] + remove_split_unbind_children(graph, split_users) # type: 
ignore[arg-type] + counters["inductor"]["move_reshape_out_of_split_stack_pass"] += 1 + + +view_getitem_split_aten = ListOf( + CallFunction( + [torch.ops.aten.reshape.default], + CallFunction( + operator.getitem, + CallFunctionVarArgs( + torch.ops.aten.split_with_sizes.default, users=MULTIPLE + ), + Ignored(), + _users=MULTIPLE, + ), + Arg(), + _users=MULTIPLE, + ), + partial=True, +) + + +@register_graph_pattern( + CallFunction( + torch.ops.aten.cat.default, + view_getitem_split_aten, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("move_view_after_cat_aten_pass"), +) +def move_view_after_cat(match: Match, *args, **kwargs): + split_node = next( + node + for node in match.nodes + if node.target == torch.ops.aten.split_with_sizes.default + ) + split_input, split_section, split_dim = _get_split_args_default(split_node) + split_users = list(split_node.users.keys()) + getitem_indices = [ + getitem.args[1] for getitem in split_users if getitem.target == operator.getitem + ] + if not is_sorted_and_consecutive(getitem_indices): # type: ignore[arg-type] + return + cat_nodes = [ + node for node in match.nodes if node.target == torch.ops.aten.cat.default + ] + graph = match.graph + for cat_node in cat_nodes: + if not is_node_meta_valid(cat_node): + log.debug("example value absent for node: %s", cat_node) + continue + cat_dim = _get_dim(cat_node) + cat_inputs = get_arg_value(cat_node, 0, "tensors") # type: ignore[union-attr] + # we only consider the following special case + if len(cat_inputs) != len(split_section): + continue + # check if the cat inputs are all the view nodes + if not all( + view_node.target == torch.ops.aten.reshape.default + for view_node in cat_inputs + ): + continue + # check if the view nodes are all from getitem nodes + if not all( + view_node.args[0].target == operator.getitem for view_node in cat_inputs + ): + continue + view_indices = [view.args[0].args[1] for view in cat_inputs] + if not is_sorted_and_consecutive(view_indices): # type: ignore[arg-type] + continue + if cat_dim != split_dim: + # construct permute node + permute_list = list(range(len(cat_node.meta["val"].shape) + 1)) + permute_list[split_dim], permute_list[cat_dim] = ( + permute_list[cat_dim], + permute_list[split_dim], + ) + permute_node = graph.call_function( + torch.ops.aten.permute.default, + args=(split_input, permute_list), + ) + else: + permute_node = split_input + + with graph.inserting_before(cat_node): + view_node = graph.call_function( + torch.ops.aten.reshape.default, + args=(permute_node, list(cat_node.meta["val"].shape)), + ) + cat_node.replace_all_uses_with(view_node) + view_node.meta.update(cat_node.meta) + graph.erase_node(cat_node) + counters["inductor"]["move_view_after_cat_aten_pass"] += 1 diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/__init__.py b/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..32a033e1eecbf5a24f735c7f0538a564b1035eba --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/__init__.py @@ -0,0 +1 @@ +from . 
import mm, mm_common, mm_plus_mm diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/bmm.py b/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/bmm.py new file mode 100644 index 0000000000000000000000000000000000000000..be782bac3a828a0fdf280afe2ed1e7b87fb36083 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/bmm.py @@ -0,0 +1,294 @@ +# mypy: allow-untyped-defs +import logging + +import torch +from torch._dynamo.utils import counters +from torch._inductor.codegen.rocm.ck_universal_gemm_template import CKGemmTemplate + +from .. import ir, lowering as L +from ..select_algorithm import ( + autotune_select_algorithm, + ExternKernelChoice, + SymbolicGridFn, + TritonTemplate, +) +from ..utils import ( + _use_cutlass_for_op, + use_aten_gemm_kernels, + use_ck_gemm_template, + use_cpp_bmm_template, + use_cutlass_template, + use_triton_template, +) +from ..virtualized import V +from .mm_common import ( + _is_static_problem, + addmm_epilogue, + is_batch_stride_largest, + mm_args, + mm_config_kwargs, + mm_options, +) + + +log = logging.getLogger(__name__) +aten = torch.ops.aten + + +@SymbolicGridFn +def bmm_grid(b, m, n, meta, *, cdiv): + return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), b, 1) + + +def _is_large_block_for_cpu(m, n, k): + # Thresholds are experimentally determined to reduce Triton CPU compile times + if m > 128 or n > 128 or k > 128: + return True + return m * n > 2**12 + + +bmm_template = TritonTemplate( + name="bmm", + grid=bmm_grid, + source=r""" +{{def_kernel("A", "B")}} + M = {{size("A", -2)}} + N = {{size("B", -1)}} + K = {{size("A", -1)}} + + stride_aq = {{stride("A", 0)}} + stride_am = {{stride("A", 1)}} + stride_ak = {{stride("A", 2)}} + + stride_bq = {{stride("B", 0)}} + stride_bk = {{stride("B", 1)}} + stride_bn = {{stride("B", 2)}} + + # based on triton.ops.matmul + pid = tl.program_id(0) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + if (stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1): + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + else: + ram = rm % M + if (stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1): + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + else: + rbn = rn % N + + rk = tl.arange(0, BLOCK_K) + + idx_q = tl.program_id(1) # batch dimension for BMM + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak + idx_q*stride_aq) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn + idx_q*stride_bq) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for k in range(K, 0, -BLOCK_K): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + a = tl.load(A, mask=rk[None, :] < k, other=0.) + b = tl.load(B, mask=rk[:, None] < k, other=0.) 
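            # (Editor's note, illustrative) when K is not a multiple of BLOCK_K,
            # EVEN_K is False and the masked loads above zero-fill the ragged tail
            # of the K loop, so the tl.dot below only accumulates valid elements.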
+ acc += tl.dot(a, b, allow_tf32=ALLOW_TF32) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + idx_q = tl.program_id(1) # batch dimension for BMM + idx_m = rm[:, None] + idx_n = rn[None, :] + mask = (idx_m < M) & (idx_n < N) + + # inductor generates a suffix + {{store_output(("idx_q", "idx_m", "idx_n"), "acc", "mask")}} +""", + cache_codegen_enabled_for_template=True, +) + +aten_bmm = ExternKernelChoice(torch.bmm, "at::bmm_out") +aten_bmm_dtype = ExternKernelChoice( + torch.bmm, + "at::_bmm_out_dtype_cuda", + name="bmm_dtype", + op_overload=aten.bmm.dtype_out, +) +aten_baddbmm = ExternKernelChoice( + torch.baddbmm, "at::baddbmm_out", op_overload=aten.baddbmm.out +) + + +@L.register_lowering(aten.bmm) +def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None): + """ + Lowering for autotuning aten.bmm with different backends (Aten, Triton, CUTLASS, etc.) + """ + if all(x.get_device().type == "cpu" for x in [mat1, mat2]): + # decompose to small ops when memory bound + if mat1.get_size()[1] == 1 or mat2.get_size()[2] == 1: + mat1 = L.unsqueeze(mat1, -1) + mat2 = L.unsqueeze(mat2, 1) + return L.sum_(L.mul(mat1, mat2), axis=2) + + def is_valid_to_require_contiguous(t): + if not ir.is_storage_and_layout(t): + return True + _, layout = ir.as_storage_and_layout(t, freeze=False) + return isinstance(layout, ir.FlexibleLayout) + + def is_preferred_layout_as_bmm_input(sizes, strides): + # contiguous on one of the last two dims + return ( + strides[-1] == 1 and (sizes[-2] == 1 or strides[-2] >= sizes[-1]) + ) or (strides[-2] == 1 and (sizes[-1] == 1 or strides[-1] >= sizes[-2])) + + # Make the input of bmm contiguous + # if it is not contiguous on either of the last two dims, + # because bmm cpu implementation would do contiguous() if not. + # This is to avoid additional copies in bmm. 
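    # -----------------------------------------------------------------------
    # (Editor's sketch, illustrative only.)  is_preferred_layout_as_bmm_input
    # accepts an operand that is already contiguous along one of its last two
    # dims; e.g. for b = torch.randn(8, 16, 32):
    #     b            stride (512, 32, 1) -> True   (row-major)
    #     b.mT         stride (512, 1, 32) -> True   (column-major view)
    #     b[..., ::2]  stride (512, 32, 2) -> False  (gathered last dim; such an
    #                  input is routed through require_contiguous below instead)
    # -----------------------------------------------------------------------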
+ def may_require_contiguous(t, meta_t): + sizes = meta_t.meta["val"].size() + strides = meta_t.meta["val"].stride() + if not is_preferred_layout_as_bmm_input(sizes, strides): + t = ir.ExternKernel.require_contiguous(t) + return t + + if is_valid_to_require_contiguous(mat1): + meta_mat1 = V.graph.current_node.args[0] + mat1 = may_require_contiguous(mat1, meta_mat1) + if is_valid_to_require_contiguous(mat2): + meta_mat2 = V.graph.current_node.args[1] + mat2 = may_require_contiguous(mat2, meta_mat2) + + m, n, k, layout, mat1, mat2 = mm_args( + mat1, mat2, layout=layout, out_dtype=out_dtype + ) + + # below is for getting an overview logging info of inductor mms + batch_size = mat1.get_size()[0] # Extract batch dimension + counters["aten_mm_info"][f"aten.bmm_{batch_size}_{m}_{n}_{k}"] += 1 + log.info( + "Tuned aten.bmm: batch=%s, m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s", + batch_size, + m, + n, + k, + mat1.get_dtype(), + mat2.get_dtype(), + layout, + ) + + if out_dtype: + assert mat1.get_device().type == "cuda", "out_dtype is only supported for CUDA" + aten_func = aten_bmm_dtype.bind((mat1, mat2), layout, out_dtype=out_dtype) + else: + aten_func = aten_bmm.bind((mat1, mat2), layout) + + # options to tune from + choices = [aten_func] if use_aten_gemm_kernels() else [] + + device_type = ir.get_device_type(mat1) + bmm_configs = V.choices.get_base_mm_configs(device_type) + + dtype = mat1.get_dtype() + if use_triton_template(layout): + # TODO: add out_dtype support for Triton Template + assert out_dtype is None, "out_dtype is not supported for Triton" + for config in bmm_configs( + m, + n, + k, + **mm_config_kwargs(device_type, _is_large_block_for_cpu, dtype.itemsize), + ): + bmm_template.maybe_append_choice( + choices, + input_nodes=(mat1, mat2), + layout=layout, + **mm_options(config, m, n, k, layout), + ) + _, is_nonzero = _is_static_problem(layout) + batch_stride_largest = is_batch_stride_largest(mat1, mat2, layout) + if ( + batch_stride_largest + and is_nonzero + and use_cutlass_template(layout, m, n, k) + and _use_cutlass_for_op("bmm") + ): + from ..codegen.cuda.gemm_template import CUTLASS3xGemmTemplate + + CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2]) # type: ignore[arg-type] + + if use_cpp_bmm_template(layout, mat1, mat2): + from ..codegen.cpp_bmm_template import CppBmmTemplate + + CppBmmTemplate.add_choices( + choices, + layout, + [mat1, mat2], + ) + + if use_ck_gemm_template(layout, m, n, k): + CKGemmTemplate.add_ck_gemm_choices(choices, layout, [mat1, mat2]) + + return autotune_select_algorithm("bmm", choices, [mat1, mat2], layout) + + +@L.register_lowering(aten.baddbmm) +def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): + m, n, k, layout, mat1, mat2, inp = mm_args(mat1, mat2, inp, layout=layout) + + # below is for getting an overview logging info of inductor mms + batch_size = mat1.get_size()[0] + counters["aten_mm_info"][f"aten.baddbmm_{batch_size}_{m}_{n}_{k}"] += 1 + log.info( + "Tuned aten.baddbmm: batch_size=%s, m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, inp=%s, output_layout=%s", + batch_size, + m, + n, + k, + mat1.get_dtype(), + mat2.get_dtype(), + inp.get_dtype(), + layout, + ) + + # options to tune from + choices = ( + [aten_baddbmm.bind((inp, mat1, mat2), layout, alpha=alpha, beta=beta)] + if use_aten_gemm_kernels() + else [] + ) + + device_type = ir.get_device_type(mat1) + bmm_configs = V.choices.get_base_mm_configs(device_type) + + if use_triton_template(layout): + for config in bmm_configs( + 
m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu) + ): + bmm_template.maybe_append_choice( + choices, + input_nodes=(inp, mat1, mat2), + layout=layout, + **mm_options(config, m, n, k, layout), + prefix_args=1, + epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta), + epilogue_fn_hash=str(["addmm_epilogue", layout.dtype, alpha, beta]), + ) + + return autotune_select_algorithm("baddbmm", choices, [inp, mat1, mat2], layout) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/conv.py b/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..4b14989c372d118712af7620bde71f3e43233731 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/conv.py @@ -0,0 +1,696 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import logging +from typing import Optional, TYPE_CHECKING, TypedDict + +import torch +from torch._inductor.codegen.rocm.ck_conv_template import CKGroupedConvFwdTemplate + +from .. import config, ir +from ..lowering import ( + add_layout_constraint, + constrain_to_fx_strides, + lowerings as L, + register_lowering, +) +from ..select_algorithm import ( + autotune_select_algorithm, + ExternKernelChoice, + SymbolicGridFn, + TritonTemplate, +) +from ..utils import ( + is_ones, + is_zeros, + pad_listlike, + sympy_product, + use_ck_conv_template, + use_triton_template, +) +from ..virtualized import V +from .mm_common import mm_config_kwargs + + +if TYPE_CHECKING: + from collections.abc import Sequence + + from ..ir import TensorBox + +log = logging.getLogger(__name__) + + +aten = torch.ops.aten + + +@SymbolicGridFn +def conv2d_grid(n, c, h, w, meta, *, cdiv): + return ( + cdiv(n * h * w, meta["BLOCK_M"]), + cdiv(c, meta["BLOCK_N"]), + meta["GROUPS"], + ) + + +@SymbolicGridFn +def conv3d_grid(n, c, d, h, w, meta, *, cdiv): + return ( + cdiv(n * d * h * w, meta["BLOCK_M"]), + cdiv(c, meta["BLOCK_N"]), + meta["GROUPS"], + ) + + +def _is_large_block_for_cpu(m, n, k): + # Thresholds are experimentally determined to reduce Triton CPU compile times + if m > 256 or n > 256 or k > 256: + return True + return m * n * k > 2**17 + + +LOOP_BODY_2D = """ + idx_x_h = i - PADDING_H + idx_y_h * STRIDE_H + idx_x_w = j - PADDING_W + idx_y_w * STRIDE_W + idx_x_c = tl.arange(0, BLOCK_K) + k + + x_ptrs = x_base + ( + (idx_x_h * stride_xh)[:, None] + + (idx_x_w * stride_xw)[:, None] + + (idx_x_c * stride_xc)[None, :] + ) + mask_x = ( + (idx_n < BATCH)[:, None] + & (idx_x_h >= 0)[:, None] + & (idx_x_h < IN_H)[:, None] + & (idx_x_w >= 0)[:, None] + & (idx_x_w < IN_W)[:, None] + & (idx_x_c < GROUP_IN_C)[None, :] + ) + matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0) + + w_ptrs = w_base + ( + (idx_x_c * stride_wc_in)[:, None] + (i * stride_wh) + (j * stride_ww) + ) + mask_w = (idx_x_c[:, None] < GROUP_IN_C) & (idx_y_c[None, :] < GROUP_OUT_C) + matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0) + acc += tl.dot(matrix_x, matrix_w, allow_tf32=ALLOW_TF32) +""" + +""" +This is a relatively simple conv implementation that can likely be +improved. 
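(Editor's note) The template below lowers convolution as an implicit GEMM: each
program gathers a (BLOCK_M, BLOCK_K) tile of input patches and the matching
(BLOCK_K, BLOCK_N) tile of the weight, accumulating them with tl.dot in
LOOP_BODY_2D / LOOP_BODY_3D.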
Many alternate conv versions can be found here: +https://github.com/pytorch/torchdynamo/pull/971 +""" +conv2d_template = TritonTemplate( + name="convolution2d", + grid=conv2d_grid, + source=r""" +{{def_kernel("X", "W")}} + # Tensor dimensions + BATCH = {{size("X", 0)}} + IN_C = {{size("X", 1)}} + IN_H = {{size("X", 2)}} + IN_W = {{size("X", 3)}} + OUT_C = {{size(None, 1)}} + OUT_H = {{size(None, 2)}} + OUT_W = {{size(None, 3)}} + + # Strides: + stride_xn = {{stride("X", 0)}} + stride_xc = {{stride("X", 1)}} + stride_xh = {{stride("X", 2)}} + stride_xw = {{stride("X", 3)}} + stride_wc_out = {{stride("W", 0)}} + stride_wc_in = {{stride("W", 1)}} + stride_wh = {{stride("W", 2)}} + stride_ww = {{stride("W", 3)}} + + nhw = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + idx_y_w = nhw % OUT_W + nh = nhw // OUT_W + idx_y_h = nh % OUT_H + idx_n = nh // OUT_H + idx_y_c = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) + +{% if GROUPS == 1 %} + group = 0 + GROUP_IN_C = IN_C + GROUP_OUT_C = OUT_C +{% else %} + group = tl.program_id(2) + GROUP_IN_C = IN_C // GROUPS + GROUP_OUT_C = OUT_C // GROUPS +{% endif %} + + x_base = X + (group * stride_xc * GROUP_IN_C + idx_n * stride_xn)[:, None] + w_base = ( + W + (group * stride_wc_out * GROUP_OUT_C + idx_y_c * stride_wc_out)[None, :] + ) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + +{% if UNROLL %} +{% for i in range(KERNEL_H) %} +{% for j in range(KERNEL_W) %} + i = {{i}} + j = {{j}} + for k in range(0, GROUP_IN_C, BLOCK_K): + """ + + LOOP_BODY_2D + + """ +{% endfor %} +{% endfor %} +{% else %} + # Could be simplified, but slightly slower: + # for i in range(KERNEL_H): + # for j in range(KERNEL_W): + # for k in range(0, GROUP_IN_C, BLOCK_K): + BLOCK_K_COUNT = (GROUP_IN_C + BLOCK_K - 1) // BLOCK_K + for ijk in range(KERNEL_H * KERNEL_W * BLOCK_K_COUNT): + k = (ijk % BLOCK_K_COUNT) * BLOCK_K + ij = ijk // BLOCK_K_COUNT + i = ij // KERNEL_W + j = ij % KERNEL_W + """ + + LOOP_BODY_2D + + """ +{% endif %} + + mask = ( + (idx_n < BATCH)[:, None] + & (idx_y_h < OUT_H)[:, None] + & (idx_y_w < OUT_W)[:, None] + & (idx_y_c < GROUP_OUT_C)[None, :] + ) + idx_n = idx_n[:, None] + idx_c = idx_y_c[None, :] + group * GROUP_OUT_C + idx_h = idx_y_h[:, None] + idx_w = idx_y_w[:, None] + + # inductor generates a suffix + {{store_output(("idx_n", "idx_c", "idx_h", "idx_w"), "acc", "mask")}} +""", +) + +LOOP_BODY_3D = """ + idx_x_d = d - PADDING_D + idx_y_d * STRIDE_D + idx_x_h = i - PADDING_H + idx_y_h * STRIDE_H + idx_x_w = j - PADDING_W + idx_y_w * STRIDE_W + idx_x_c = tl.arange(0, BLOCK_K) + k + + x_ptrs = x_base + ( + (idx_x_d * stride_xd)[:, None] + + (idx_x_h * stride_xh)[:, None] + + (idx_x_w * stride_xw)[:, None] + + (idx_x_c * stride_xc)[None, :] + ) + mask_x = ( + (idx_n < BATCH)[:, None] + & (idx_x_d >= 0)[:, None] + & (idx_x_d < IN_D)[:, None] + & (idx_x_h >= 0)[:, None] + & (idx_x_h < IN_H)[:, None] + & (idx_x_w >= 0)[:, None] + & (idx_x_w < IN_W)[:, None] + & (idx_x_c < GROUP_IN_C)[None, :] + ) + matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0) + + w_ptrs = w_base + ( + (idx_x_c * stride_wc_in)[:, None] + + (d * stride_wd) + (i * stride_wh) + (j * stride_ww) + ) + mask_w = (idx_x_c[:, None] < GROUP_IN_C) & (idx_y_c[None, :] < GROUP_OUT_C) + matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0) + acc += tl.dot(matrix_x, matrix_w, allow_tf32=ALLOW_TF32) +""" + +conv3d_template = TritonTemplate( + name="convolution3d", + grid=conv3d_grid, + source=r""" +{{def_kernel("X", "W")}} + # Tensor dimensions + BATCH = {{size("X", 0)}} + IN_C = {{size("X", 
1)}} + IN_D = {{size("X", 2)}} + IN_H = {{size("X", 3)}} + IN_W = {{size("X", 4)}} + OUT_C = {{size(None, 1)}} + OUT_D = {{size(None, 2)}} + OUT_H = {{size(None, 3)}} + OUT_W = {{size(None, 4)}} + + # Strides: + stride_xn = {{stride("X", 0)}} + stride_xc = {{stride("X", 1)}} + stride_xd = {{stride("X", 2)}} + stride_xh = {{stride("X", 3)}} + stride_xw = {{stride("X", 4)}} + stride_wc_out = {{stride("W", 0)}} + stride_wc_in = {{stride("W", 1)}} + stride_wd = {{stride("W", 2)}} + stride_wh = {{stride("W", 3)}} + stride_ww = {{stride("W", 4)}} + + ndhw = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + idx_y_w = ndhw % OUT_W + ndh = ndhw // OUT_W + idx_y_h = ndh % OUT_H + nd = ndh // OUT_H + idx_y_d = nd % OUT_D + idx_n = nd // OUT_D + idx_y_c = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) + +{% if GROUPS == 1 %} + group = 0 + GROUP_IN_C = IN_C + GROUP_OUT_C = OUT_C +{% else %} + group = tl.program_id(2) + GROUP_IN_C = IN_C // GROUPS + GROUP_OUT_C = OUT_C // GROUPS +{% endif %} + + x_base = X + (group * stride_xc * GROUP_IN_C + idx_n * stride_xn)[:, None] + w_base = ( + W + (group * stride_wc_out * GROUP_OUT_C + idx_y_c * stride_wc_out)[None, :] + ) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + +{% if UNROLL %} +{% for d in range(KERNEL_D) %} +{% for i in range(KERNEL_H) %} +{% for j in range(KERNEL_W) %} + d = {{d}} + i = {{i}} + j = {{j}} + for k in range(0, GROUP_IN_C, BLOCK_K): + """ + + LOOP_BODY_3D + + """ +{% endfor %} +{% endfor %} +{% endfor %} +{% else %} + # Could be simplified, but slightly slower: + # for d in range(KERNEL_D): + # for i in range(KERNEL_H): + # for j in range(KERNEL_W): + # for k in range(0, GROUP_IN_C, BLOCK_K): + BLOCK_K_COUNT = (GROUP_IN_C + BLOCK_K - 1) // BLOCK_K + for dijk in range(KERNEL_D * KERNEL_H * KERNEL_W * BLOCK_K_COUNT): + k = (dijk % BLOCK_K_COUNT) * BLOCK_K + dij = dijk // BLOCK_K_COUNT + j = dij % KERNEL_W + di = dij // KERNEL_W + i = di % KERNEL_H + d = di // KERNEL_H + """ + + LOOP_BODY_3D + + """ +{% endif %} + + mask = ( + (idx_n < BATCH)[:, None] + & (idx_y_d < OUT_D)[:, None] + & (idx_y_h < OUT_H)[:, None] + & (idx_y_w < OUT_W)[:, None] + & (idx_y_c < GROUP_OUT_C)[None, :] + ) + idx_n = idx_n[:, None] + idx_c = idx_y_c[None, :] + group * GROUP_OUT_C + idx_d = idx_y_d[:, None] + idx_h = idx_y_h[:, None] + idx_w = idx_y_w[:, None] + + # inductor generates a suffix + {{store_output(("idx_n", "idx_c", "idx_d", "idx_h", "idx_w"), "acc", "mask")}} +""", +) + +aten_convolution = ExternKernelChoice( + torch.convolution, + "at::convolution", + has_out_variant=False, + op_overload=aten.convolution.default, +) + + +def conv1x1_via_mm(x, w, *, out): + w = torch.squeeze(torch.squeeze(w, -1), -1) + return torch.matmul( + x.permute(0, 2, 3, 1), w.permute(1, 0), out=out.permute(0, 2, 3, 1) + ) + + +aten_conv1x1_via_mm = ExternKernelChoice(conv1x1_via_mm, None) + + +class ConvLayoutParams(TypedDict): + stride: tuple[int, ...] + padding: tuple[int, ...] + dilation: tuple[int, ...] + transposed: bool + output_padding: tuple[int, ...] 
+ groups: int + + +def conv_layout( + x: TensorBox, + weight: TensorBox, + bias: Optional[TensorBox], + stride: Sequence[int], + padding: tuple[int, ...], + dilation: tuple[int, ...], + transposed: bool, + output_padding: tuple[int, ...], + groups: int, +) -> ir.Layout: + """Determine output layout for a convolution""" + with V.graph.fake_mode: + output = torch.ops.aten.convolution( + ir.ir_node_to_tensor(x, guard_shape=True), + ir.ir_node_to_tensor(weight, guard_shape=True), + ir.ir_node_to_tensor(bias, guard_shape=True), + V.graph.sizevars.size_hints(stride), # type: ignore[arg-type] + V.graph.sizevars.size_hints(padding), # type: ignore[arg-type] + V.graph.sizevars.size_hints(dilation), # type: ignore[arg-type] + transposed, + V.graph.sizevars.size_hints(output_padding), # type: ignore[arg-type] + groups, + ) + sizes = ir.convert_shape_to_inductor(output.size()) + stride = ir.convert_shape_to_inductor(output.stride()) # type: ignore[assignment] + + return ir.FixedLayout( + x.get_device_or_error(), + x.get_dtype(), + sizes, + stride, + ) + + +def channels_last_order(rank): + order = list(reversed(range(rank))) + order.insert(1, order.pop(-1)) + return order + + +def convert_1x1_conv_to_mm(x, weight, bias): + # special case for 1x1 convolution, which is actually just a matmul + rank = len(weight.get_size()) + for _ in range(rank - 2): + weight = L[aten.squeeze](weight, dim=-1) + weight = L[aten.permute](weight, [1, 0]) + + x = ir.ExternKernel.require_stride_order(x, channels_last_order(rank)) + x_permute = list(range(rank)) + x_permute.append(x_permute.pop(1)) + x = L[aten.permute](x, x_permute) + *sizes, in_chan = x.get_size() + x = L[aten.reshape](x, [sympy_product(sizes), in_chan]) + if bias is None: + result = L[aten.mm](x, weight) + else: + result = L[aten.addmm](bias, x, weight) + result = L[aten.reshape](result, [*sizes, -1]) + result_permute = list(range(rank)) + result_permute.insert(1, result_permute.pop(-1)) + return L[aten.permute](result, result_permute) + + +@register_lowering(aten.convolution) +def convolution( + x: TensorBox, + weight: TensorBox, + bias: Optional[TensorBox], + stride: Sequence[int], + padding: Sequence[int], + dilation: Sequence[int], + transposed: bool, + output_padding: Sequence[int], + groups: int, +): + stride = tuple(stride) + padding = tuple(padding) + dilation = tuple(dilation) + output_padding = tuple(output_padding) + if not isinstance(groups, int): + groups = V.graph.sizevars.evaluate_static_shape(groups) + assert isinstance(groups, int) + + # Need use hint for triton template since the template does not + # work with a dynamic shape. + # + # No need to evaluate_static_shape for dilation and output_padding + # since the template is only used when dilation is 1 and output_padding + # is 0. + stride = tuple(V.graph.sizevars.evaluate_static_shapes(stride)) + padding = tuple(V.graph.sizevars.evaluate_static_shapes(padding)) + + kwargs: ConvLayoutParams = { + "stride": stride, + "padding": padding, + "dilation": dilation, + "transposed": transposed, + "output_padding": output_padding, + "groups": groups, + } + + device_type = ir.get_device_type(x) + + if len(x.get_size()) == len(weight.get_size()) - 1: + # add batch dimension to simplify rest of function + return L[aten.squeeze]( + convolution(L[aten.expand](x, [1, *x.get_size()]), weight, bias, **kwargs), + dim=0, + ) + + out_chan, in_chan, *kernel_shape = V.graph.sizevars.evaluate_static_shapes( + weight.get_size() + ) + + # Always convert conv1D to 2D for Intel GPU. 
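+ # (Concretely, the branch below unsqueezes the (N, C, L) input and the (C_out, C_in // groups, K)
+ # weight to 4-D at dim 2, prepends 1/0/1/0 to stride/padding/dilation/output_padding, runs the
+ # 2-D lowering, then squeezes dim 2 back out.)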
+ # Only conv2D can be converted to channel last layout, + # which have much better performance. + if len(x.get_size()) == 3 and len(kernel_shape) == 1 and device_type == "xpu": + kwargs.update( + { + "stride": (1,) + stride, + "padding": (0,) + padding, + "dilation": (1,) + dilation, + "output_padding": (0,) + output_padding, + } + ) + # (N, C, L) -> (N, C, 1, L) + x = L[aten.unsqueeze](x, dim=2) + weight = L[aten.unsqueeze](weight, dim=2) + + return L[aten.squeeze]( + convolution(x, weight, bias, **kwargs), + dim=2, + ) + + ndim = len(kernel_shape) + stride = pad_listlike(stride, ndim) + padding = pad_listlike(padding, ndim) + dilation = pad_listlike(dilation, ndim) + output_padding = pad_listlike(output_padding, ndim) + + def channels_last_conv(): + if V.graph.layout_opt and ndim == 2: + return True + + layout = conv_layout(x, weight, None, **kwargs) + req_stride_order = ir.get_stride_order( + V.graph.sizevars.size_hints(layout.stride) + ) + return req_stride_order == ir.NHWC_STRIDE_ORDER + + autotuning_gemm = config.max_autotune or config.max_autotune_gemm + + if ( + (config.conv_1x1_as_mm or (autotuning_gemm and channels_last_conv())) + and is_ones(kernel_shape) + and is_ones(stride) + and is_zeros(padding) + and is_ones(dilation) + and not transposed + and is_zeros(output_padding) + and groups == 1 + and V.graph.sizevars.statically_known_gt(sympy_product(x.get_size()), 0) + ): + return convert_1x1_conv_to_mm(x, weight, bias) + + if bias is not None and device_type != "cpu": + # peel off the bias, cudnn is slower with it + result = convolution(x, weight, None, **kwargs) + return L[aten.add]( + result, L[aten.view](bias, [result.get_size()[1]] + ndim * [1]) + ) + + x.realize() + weight.realize() + + # ndim can be 1 for convolution in models such as demucs + # TODO: check if it's beneficial to convert Conv1d to Conv2d and then + # apply channels last. + if V.graph.layout_opt and ndim == 2: + V.graph.num_channels_last_conv += 1 + x = ir.ExternKernel.require_channels_last(x) + # TODO maybe we can convert weights to channels last just once before + # running the model. + weight = ir.ExternKernel.require_channels_last(weight) + layout = conv_layout(x, weight, None, **kwargs) + else: + layout = conv_layout(x, weight, None, **kwargs) + req_stride_order = ir.get_stride_order( + V.graph.sizevars.size_hints(layout.stride) + ) + x = ir.ExternKernel.require_stride_order(x, req_stride_order) + weight = ir.ExternKernel.require_stride_order(weight, req_stride_order) + + ordered_kwargs_for_cpp_kernel = [ + "stride", + "padding", + "dilation", + "transposed", + "output_padding", + "groups", + ] + if bias is None: + args = [x, weight] + kwargs["bias"] = None # type: ignore[typeddict-unknown-key] + ordered_kwargs_for_cpp_kernel.insert(0, "bias") + else: + args = [x, weight, bias] + bias.realize() + bias.freeze_layout() + V.graph.sizevars.evaluate_static_shapes(bias.get_size()) + + choices = [] + if torch._inductor.utils._use_conv_autotune_backend("ATEN"): + choices = [ + aten_convolution.bind( + args, + layout, + ordered_kwargs_for_cpp_kernel, + **kwargs, + ) + ] + + if ( + torch._inductor.utils._use_conv_autotune_backend("TRITON") + and use_triton_template(layout) + # templates only support these: + and is_ones(dilation) + and not transposed + and is_zeros(output_padding) + # there are some odd models where this check fails (e.g. 
shufflenet_v2_x1_0) + and V.graph.sizevars.statically_known_equals(in_chan, x.get_size()[1]) # type: ignore[arg-type] + ): + if ( + is_ones(kernel_shape) + and is_ones(stride) + and is_zeros(padding) + and groups == 1 + ): + choices.append(aten_conv1x1_via_mm.bind(args, layout)) + + conv_configs = V.choices.get_conv_configs(device_type) + + for cfg in conv_configs( + sympy_product([x.get_size()[0], *x.get_size()[2:]]), + out_chan, + in_chan, + **mm_config_kwargs(device_type, _is_large_block_for_cpu), + ): + if ndim == 2: + conv2d_template.maybe_append_choice( + choices, + input_nodes=(x, weight), + layout=layout, + KERNEL_H=kernel_shape[0], + KERNEL_W=kernel_shape[1], + STRIDE_H=stride[0], + STRIDE_W=stride[1], + PADDING_H=padding[0], + PADDING_W=padding[1], + GROUPS=groups, + # TODO(jansel): try unroll for bigger kernels once fixed: + # https://github.com/triton-lang/triton/issues/1254 + UNROLL=is_ones(kernel_shape), + ALLOW_TF32=torch.backends.cudnn.allow_tf32, + num_stages=cfg.num_stages, + num_warps=cfg.num_warps, + **cfg.kwargs, + ) + elif ndim == 3: + conv3d_template.maybe_append_choice( + choices, + input_nodes=(x, weight), + layout=layout, + KERNEL_D=kernel_shape[0], + KERNEL_H=kernel_shape[1], + KERNEL_W=kernel_shape[2], + STRIDE_D=stride[0], + STRIDE_H=stride[1], + STRIDE_W=stride[2], + PADDING_D=padding[0], + PADDING_H=padding[1], + PADDING_W=padding[2], + GROUPS=groups, + # TODO(jansel): try unroll for bigger kernels once fixed: + # https://github.com/triton-lang/triton/issues/1254 + UNROLL=is_ones(kernel_shape), + ALLOW_TF32=torch.backends.cudnn.allow_tf32, + num_stages=cfg.num_stages, + num_warps=cfg.num_warps, + **cfg.kwargs, + ) + if use_ck_conv_template(layout): + CKGroupedConvFwdTemplate.add_ck_conv_choices( + choices, + layout, + input_nodes=(x, weight) + ((bias,) if bias is not None else tuple()), + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + n_spatial_dimensions=ndim, + ) + return autotune_select_algorithm("convolution", choices, args, layout) + + +@register_lowering(aten._convolution) +def _convolution( + x, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + benchmark, + deterministic, + cudnn_enabled, + allow_tf32, +): + return convolution( + x, weight, bias, stride, padding, dilation, transposed, output_padding, groups + ) + + +def constrain_conv_to_fx_strides(fx_node, *args, **kwargs): + assert fx_node.target == torch.ops.aten.convolution.default + if V.graph.layout_opt: + return args, kwargs + else: + return constrain_to_fx_strides(fx_node, *args, **kwargs) + + +add_layout_constraint(aten.convolution, constrain_conv_to_fx_strides) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex_attention.py b/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..22751fe1004797e00d477c5690dc64ec7e41c458 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex_attention.py @@ -0,0 +1,2763 @@ +# mypy: allow-untyped-defs +"""Triton Implementation of the flex_attention Kernel""" + +import copy +import logging +import math +from collections.abc import Sequence +from dataclasses import dataclass +from enum import auto, Enum +from typing import Any, Optional, Union + +import sympy + +import torch +from torch._inductor.virtualized import V +from torch.utils._ordered_set import OrderedSet +from torch.utils._pytree import tree_map +from torch.utils._sympy.numbers import int_oo 
+from torch.utils._sympy.value_ranges import ValueRanges + +from ..ir import ( + Buffer, + ComputedBuffer, + ExternKernel, + FixedLayout, + FlexibleLayout, + get_fill_order, + InputBuffer, + IRNode, + MutationLayoutSHOULDREMOVE, + Scatter, + StorageBox, + Subgraph, + TensorBox, +) +from ..lowering import ( + _full, + check_and_broadcast_indices, + empty, + empty_strided, + expand, + index_output_size_and_inner_fn, + lowerings, + register_lowering, + to_dtype, +) +from ..select_algorithm import ( + autotune_select_algorithm, + realize_inputs, + SymbolicGridFn, + TritonTemplate, +) + + +log = logging.getLogger(__name__) +aten = torch.ops.aten +Expr = sympy.Expr + + +def construct_strides( + sizes: Sequence[int], + fill_order: Sequence[int], +) -> Sequence[int]: + """From a list of sizes and a fill order, construct the strides of the permuted tensor.""" + # Initialize strides + assert len(sizes) == len(fill_order), ( + "Length of sizes must match the length of the fill order" + ) + strides = [0] * len(sizes) + + # Start with stride 1 for the innermost dimension + current_stride = 1 + + # Iterate through the fill order populating strides + for dim in fill_order: + strides[dim] = current_stride + current_stride *= sizes[dim] + + return strides + + +def infer_dense_strides(size: Sequence[int], orig_strides: Sequence[int]): + """This is a mirror of the same function in aten/src/ATen/ExpandUtils.cpp + + Args: + size: The size of the output tensor + orig_strides: The strides of the input tensor + Returns: + List[int]: Dense non-overlapping strides that preserve the input tensor's layout permutation. + The returned strides follow the same stride propagation rules as TensorIterator. This matches + The behavior of empty_like() + """ + fill_order = get_fill_order(orig_strides, V.graph.sizevars.shape_env) + return construct_strides(size, fill_order) + + +@SymbolicGridFn +def flex_attention_grid(batch_size, q_heads, num_queries, d_model, meta, *, cdiv): + """How is this kernel parallelized? + We create a grid of (batch_size * num_heads, ceil_div(n_queries, query_block_size), 1) + Each block is responsible for iterating over blocks of keys and values calculating + the final attention output. 
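+ Note that the tuple returned below is ordered as
+ (ceil_div(num_queries, BLOCK_M), batch_size * q_heads, 1); e.g. with batch_size=2, q_heads=8,
+ num_queries=1024 and BLOCK_M=128 (illustrative numbers), the grid is (8, 16, 1).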
+ """ + return (cdiv(num_queries, meta["BLOCK_M"]), batch_size * q_heads, 1) + + +def create_placeholder( + name: str, + dtype: torch.dtype, + device: torch.device, + size: Optional[list[int]] = None, +) -> TensorBox: + """Creates a placeholder input buffers for producing subgraph_output.""" + input_buffer = InputBuffer( + name=name, + layout=FixedLayout( + device, + dtype, + size if size else [], + FlexibleLayout.contiguous_strides(size) if size else [], + ), + ) + return TensorBox.create(input_buffer) + + +def maybe_realize(args: list[Optional[IRNode]]): + """Accepts a list of optional IRNodes and returns a list of realized IRNodes""" + return tree_map( + lambda x: ( + realize_inputs(x) + if x is not None and not isinstance(x, sympy.Symbol) + else x + ), + args, + ) + + +def get_float32_precision(): + if ( + torch.get_float32_matmul_precision() == "highest" + or torch.version.hip + or torch.mtia.is_available() + ): + return "'ieee'" + else: + return "'tf32'" + + +def zeros_and_scatter_lowering(shape: list[int], indices, values): + # Always accumulate into fp32 then cast + grad = _full(0, values.get_device(), torch.float32, shape) + assert isinstance(grad, TensorBox) + grad.realize() + x_size = grad.get_size() + values = to_dtype(values, grad.get_dtype()) + indices_loaders = [i.make_loader() if i is not None else None for i in indices] + indices, tensor_indices = check_and_broadcast_indices(indices, grad.get_device()) + # We can use the first one since they are all required to be the same size + tensor_size = list(indices[tensor_indices[0]].get_size()) + indexed_size = [x_size[i] for i in range(len(indices))] + + expected_vals_size, inner_fn = index_output_size_and_inner_fn( + x_size, + indices, + tensor_indices, + tensor_size, + indices_loaders, + indexed_size, + None, + check=True, + ) + + values = expand(values, expected_vals_size) + device = grad.get_device() + assert device is not None + scatter = Scatter( + device=device, + dtype=grad.get_dtype(), + inner_fn=values.make_loader(), + ranges=expected_vals_size, # iter_ranges, + output_indexer=inner_fn, + scatter_mode="atomic_add", + ) + + buffer = ComputedBuffer( + name=grad.data.data.name, # type: ignore[attr-defined] + layout=MutationLayoutSHOULDREMOVE(grad), + data=scatter, + ) + return buffer + + +SubgraphResults = Union[list[Optional[ComputedBuffer]], Optional[ComputedBuffer]] + + +def build_subgraph_module_buffer( + args: list[TensorBox], graph_module: torch.fx.GraphModule +) -> SubgraphResults: + """This function's goal is to take in the required args and produce the subgraph buffer + The subgraph buffer is a ComputedBuffer that will be inlined into the triton template + + Args: + args: The args that are passed into the subgraph. Contains both fixed and lifted inputs. 
+ subgraph: The Subgraph ir for which to produce the output node + """ + from ..subgraph_lowering import PointwiseSubgraphLowering + + pw_subgraph = PointwiseSubgraphLowering( + graph_module, + root_graph_lowering=V.graph, + allowed_mutations=OrderedSet([torch.ops.flex_lib.zeros_and_scatter.default]), + additional_lowerings={ + torch.ops.flex_lib.zeros_and_scatter.default: zeros_and_scatter_lowering + }, + ) + with V.set_graph_handler(pw_subgraph): # type: ignore[arg-type] + pw_subgraph.run(*args) + + # Since we are allowing mutations/buffer creation, we need to register any fresh buffers + # creating during the pointwise subgraph lowering + if len(pw_subgraph.buffers) > 0: + for buffer in pw_subgraph.buffers: + V.graph.register_buffer(buffer) + + def convert_output_node_to_buffer(output_buffer) -> Optional[ComputedBuffer]: + if output_buffer is None: + return None + if isinstance(output_buffer, ComputedBuffer): + # These nodes are coming from the output of zeros_and_scatter + return output_buffer + assert isinstance(output_buffer, TensorBox), ( + "The output node for flex attention's subgraph must be a TensorBox, but got: ", + type(output_buffer), + ) + assert isinstance(output_buffer.data, StorageBox), ( + "The output node for the flex attention subgraph must be a StorageBox, but got: ", + type(output_buffer), + ) + subgraph_buffer = ComputedBuffer( + name=None, + layout=FlexibleLayout( + device=output_buffer.data.get_device(), + dtype=output_buffer.data.get_dtype(), + size=output_buffer.data.get_size(), + ), + data=output_buffer.data.data, # type: ignore[arg-type] + ) + return subgraph_buffer + + return tree_map(convert_output_node_to_buffer, pw_subgraph.graph_outputs) + + +def build_subgraph_buffer(args: list[TensorBox], subgraph: Subgraph) -> SubgraphResults: + return build_subgraph_module_buffer(args, subgraph.graph_module) + + +def get_fwd_subgraph_outputs( + subgraph_buffer: SubgraphResults, mask_graph_buffer: SubgraphResults +) -> list[Optional[ComputedBuffer]]: + subgraph_buffer = ( + subgraph_buffer if isinstance(subgraph_buffer, Sequence) else [subgraph_buffer] + ) + mask_graph_buffer = ( + mask_graph_buffer + if isinstance(mask_graph_buffer, Sequence) + else [mask_graph_buffer] + ) + return [*subgraph_buffer, *mask_graph_buffer] + + +# Inner Triton functions shared by flex_attention & split-k decoding kernels. 
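+# A sketch of how get_offset_for_next_block (defined below) walks a block-sparse mask, using assumed
+# example sizes: within a sparse block the K/V pointers advance by BLOCK each iteration; on the last
+# sub-iteration of a sparse block, the returned offset instead jumps to the next block listed in
+# col_indices, compensating for the (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK already advanced. For example,
+# SPARSE_BLOCK=128 and BLOCK=32 (so SPARSE_BLOCK_MULTIPLE=4) with cur_block=0 and next_block=3 give a
+# jump of (3 - 0) * 128 - 3 * 32 = 288, moving from position 96 inside block 0 to 384, the start of block 3.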
+compute_next_offset_func = r""" +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset +""" + +get_bounded_indices_func = r""" +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices +""" + + +load_checked_block = r""" +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") +""" + +load_checked_2d = r""" +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_DIM: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_DIM), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_DIM), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +""" + +compute_flex_attention = r""" +{{def_kernel("Q", "K", "V", "LSE", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}} + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. 
+ # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = {{stride("Q")}} + stride_kz, stride_kh, stride_kn, stride_kk = {{stride("K")}} + stride_vz, stride_vh, stride_vn, stride_vk = {{stride("V")}} + + ZQ = {{size("Q", 0)}} + HQ = {{size("Q", 1)}} + Q_LEN = {{size("Q", 2)}} + ZKV = {{size("K", 0)}} + KV_LEN = {{size("K", 2)}} + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0) + off_zq = tl.program_id(1) // HQ + off_hq = tl.program_id(1) % HQ + + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + {%- if USE_TMA %} + desc_q = tl.make_tensor_descriptor( + base=Q, + shape=[Q_LEN*HQ*ZQ, QK_HEAD_DIM], + strides=[QK_HEAD_DIM, 1], + block_shape=[BLOCK_M, QK_HEAD_DIM_ROUNDED], + ) + desc_v = tl.make_tensor_descriptor( + base=V, + shape=[KV_LEN*ZKV*HQ, V_HEAD_DIM], + strides=[V_HEAD_DIM, 1], + block_shape=[BLOCK_N, V_HEAD_DIM_ROUNDED], + ) + desc_k = tl.make_tensor_descriptor( + base=V, + shape=[KV_LEN*ZKV*HQ, V_HEAD_DIM], + strides=[V_HEAD_DIM, 1], + block_shape=[BLOCK_N, V_HEAD_DIM_ROUNDED], + ) + {%- endif %} + + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + SPARSE_Z = {{size("KV_NUM_BLKS", 0)}} + SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}} + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = {{stride("KV_NUM_BLKS", 1)}} + stride_kv_idx_h = {{stride("KV_IDX", 1)}} + stride_kv_idx_m = {{stride("KV_IDX", 2)}} + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. 
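+ # The (batch, head) pair is flattened into sparse_hz_offset below; the row of KV_NUM_BLKS / KV_IDX
+ # for this program is then selected with q_start // SPARSE_Q_MULTIPLE, since one sparse Q block covers
+ # SPARSE_Q_MULTIPLE thread blocks of size BLOCK_M.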
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + K_block_ptr = None + V_block_ptr = None + Q_block_ptr = None + + if not USE_TMA: + Q_block_ptr = tl.make_block_ptr( + base=Q , + shape=(Q_LEN, QK_HEAD_DIM), + strides=(stride_qm, stride_qk), + offsets=(q_start * BLOCK_M, 0), + block_shape=(BLOCK_M, QK_HEAD_DIM_ROUNDED), + order=(1, 0) + ) + + {%- if USE_TMA %} + q = tl.load_tensor_descriptor( + desc_q, + [(q_start * BLOCK_M).to(tl.int32), 0], + ) + {%- else %} + q = load_checked_block(Q_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM) + {%- endif %} + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + if not USE_TMA: + K_block_ptr = tl.make_block_ptr( + base=K, + shape=(QK_HEAD_DIM, KV_LEN), + strides=(stride_kk, stride_kn), + offsets=(0, kv_start), + block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N), + order=(0, 1) + ) + + V_block_ptr = tl.make_block_ptr( + base=V, + shape=(KV_LEN, V_HEAD_DIM), + strides=(stride_vn, stride_vk), + offsets=(kv_start, 0), + block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED), + order=(1, 0) + ) + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + if not USE_TMA: + K_block_ptr = tl.make_block_ptr( + base=K, + shape=(QK_HEAD_DIM, KV_LEN), + strides=(stride_kk, stride_kn), + offsets=(0, kv_start), + block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N), + order=(0, 1) + ) + V_block_ptr = tl.make_block_ptr( + base=V, + shape=(KV_LEN, V_HEAD_DIM), + strides=(stride_vn, stride_vk), + offsets=(kv_start, 0), + block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED), + order=(1, 0) + ) + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. 
+ # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1) // HQ + idx_hq = tl.program_id(1) % HQ + idx_m = offs_m[:, None] + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + {{store_output(("idx_zq", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}} + + if OUTPUT_LOGSUMEXP: + off_hz = tl.program_id(1) + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + """ + + +compute_forward_inner = r""" +@triton.jit +def forward_inner( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + {{gen_defines() | indent_except_first(1)}} + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + start_n, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + start_n, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + if not USE_TMA: + K_block_ptr = tl.advance(K_block_ptr, (0, offset)) + V_block_ptr = tl.advance(V_block_ptr, (offset, 0)) + + + return acc, l_i, m_i + +""" + + +compute_forward_block_mn = r""" +@triton.jit +def forward_block_mn( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + start_n, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) 
so we don't need to plumb them all through + {{gen_defines() | indent_except_first(1)}} + + # -- load k -- + # NB reversed order to since K is transposed + {%- if USE_TMA %} + k = tl.load_tensor_descriptor( # load in row major + desc_k, + [start_n.to(tl.int32) , kv_start], + ) + {%- else %} + k = load_checked_block(K_block_ptr, SAFE_HEAD_DIM, IS_DIVISIBLE) + {%- endif %} + + if USE_TMA: + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + {{ modification( + subgraph_number=0, + output_name="post_mod_scores", + score="qk", + b="off_z", + h="off_h", + m="m", + n="n", + out="qk" + ) | indent_except_first(1) }} + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + {{ modification( + subgraph_number=1, + output_name="mask_mod_output", + score="qk", + b="off_z", + h="off_h", + m="m", + n="n", + ) | indent_except_first(2) }} + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + {%- if USE_TMA %} + v = tl.load_tensor_descriptor( + desc_v, + [kv_start.to(tl.int32) + start_n.to(tl.int32),0], + ) + {%- else %} + v = load_checked_block(V_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM) + {%- endif %} + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +""" + + +flex_attention_template = TritonTemplate( + name="flex_attention", + grid=flex_attention_grid, + source=compute_flex_attention + + compute_forward_inner + + compute_next_offset_func + + compute_forward_block_mn + + load_checked_block + + get_bounded_indices_func, +) + + +def _use_flex_decoding(query, kv_indices, kernel_options, enable_gqa): + """Decide which kernel to use, return true if use flex decoding kernel. + Note: + Since the number of splits is calculated based of the the number of batch and head dims + we need to ensure that the batch and head dims are statically known. Otherwise we just + use the main flex_attention kernel. 
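+ In short, as implemented below: decoding is used only when FORCE_USE_FLEX_ATTENTION is not set,
+ the query length evaluates to a value in (0, 128), the batch and head dims are static, and the
+ block mask's head dimension is either 1 (required when GQA is enabled) or matches the number of
+ query heads.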
+ """ + force_flex = kernel_options.get("FORCE_USE_FLEX_ATTENTION", False) + short_query_length = V.graph.sizevars.evaluate_expr( + sympy.Lt(query.get_size()[-2], 128) + ) + non_zero_length = V.graph.sizevars.evaluate_expr(sympy.Gt(query.get_size()[-2], 0)) + static_batch = isinstance(query.get_size()[0], (int, sympy.Integer)) + static_num_heads = isinstance(query.get_size()[1], (int, sympy.Integer)) + if enable_gqa: + # in the current flex decoding triton kernel, grouped query heads for the + # same kv head are handled by the same block. So it's hard to support different + # kv num blocks for grouped query heads. We just fall back to main flex_attention + # kernel where each query head is handled by a separate block. + valid_block_mask_num_heads = V.graph.sizevars.evaluate_expr( + sympy.Eq(kv_indices.get_size()[1], 1) + ) + else: + valid_block_mask_num_heads = V.graph.sizevars.evaluate_expr( + sympy.Or( + sympy.Eq(kv_indices.get_size()[1], 1), + sympy.Eq(kv_indices.get_size()[1], query.get_size()[1]), + ) + ) + return ( + not force_flex + and short_query_length + and static_batch + and static_num_heads + and non_zero_length + and valid_block_mask_num_heads + ) + + +_h100_default_config = { + (torch.float32, 64): (128, 32, 4, 3), + (torch.float32, 128): (32, 64, 4, 3), + (torch.float32, 256): (32, 32, 4, 3), + (torch.bfloat16, 64): (128, 128, 4, 3), + (torch.bfloat16, 128): (128, 64, 8, 3), + (torch.bfloat16, 256): (64, 32, 4, 3), + (torch.float16, 64): (128, 128, 4, 3), + (torch.float16, 128): (128, 128, 8, 3), + (torch.float16, 256): (64, 32, 4, 3), +} + +_a100_default_config = { + (torch.float32, 64): (128, 32, 4, 3), + (torch.float32, 128): (128, 32, 4, 3), + (torch.float32, 256): (64, 16, 4, 3), + (torch.bfloat16, 64): (128, 64, 4, 3), + (torch.bfloat16, 128): (128, 64, 8, 3), + (torch.bfloat16, 256): (32, 64, 4, 3), + (torch.float16, 64): (128, 64, 4, 3), + (torch.float16, 128): (128, 64, 8, 3), + (torch.float16, 256): (32, 64, 4, 3), +} + +_rocm_default_config = { + (torch.float32, 64): (128, 32, 4, 1), + (torch.float32, 128): (128, 32, 4, 1), + (torch.float32, 256): (64, 16, 4, 1), + (torch.bfloat16, 64): (128, 64, 8, 1), + (torch.bfloat16, 128): (128, 64, 8, 1), + (torch.bfloat16, 256): (32, 64, 8, 1), + (torch.float16, 64): (128, 64, 8, 1), + (torch.float16, 128): (128, 64, 8, 1), + (torch.float16, 256): (32, 64, 4, 1), +} + + +class Mode(Enum): + fwd = auto() + bwd = auto() + + +def create_num_blocks_fake_generator(sparse_indices): + # The idea here is that we need to create a real tensor with real data + # that's representative for benchmarking. + # For example, returning all zeros for the `kv_num_blocks` input would mean + # that we are computing 0 blocks for each row, which would provide bogus + # autotuning results. + # + # In this case, we choose to use min(16, max_block) blocks, because I + # (Horace) think it'll probably result in pretty representative performance. + # If it's too short then prefetching won't help. If it's too long then + # autotuning will take longer for no good reason. 
+ def create_num_blocks_fake(x) -> torch.Tensor: + num_blocks_for_autotuning = V.graph.sizevars.size_hint(sparse_indices.shape[-1]) + size = [V.graph.sizevars.size_hint(i) for i in x.get_size()] + return torch.full( + size, + num_blocks_for_autotuning, + dtype=x.get_dtype(), + device=x.get_device(), + ) + + return create_num_blocks_fake + + +def create_indices_fake(x) -> torch.Tensor: + size = [V.graph.sizevars.size_hint(i) for i in x.get_size()] + indices = torch.arange(0, size[-1], dtype=x.get_dtype(), device=x.get_device()) + indices = indices.expand(size).contiguous() + return indices + + +from torch._inductor.kernel.flex_decoding import create_flex_decoding_kernel + +from ..codegen.cpp_flex_attention_template import CppFlexAttentionTemplate + + +def check_cpu_supported(): + import os + import sys + + requires_avx2_on_cpu = ( + torch.cpu._is_avx2_supported() and os.getenv("ATEN_CPU_CAPABILITY") != "default" + ) + supported = ( + requires_avx2_on_cpu + and not torch.xpu.is_available() + and not sys.platform == "darwin" + ) + return supported + + +def contiguous_last_dim(x): + """Ensure that realized IR node has a contiguous stride in the last dimension.""" + strides = x.maybe_get_stride() + if strides and strides[-1] != 1: + contiguous_stride_order = list(reversed(range(len(x.get_size())))) + return ExternKernel.require_stride_order(x, contiguous_stride_order) + return x + + +def lower_cpu( + query, + key, + value, + subgraph, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, +): + ( + _, # q_length + _, # kv_length + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + SPARSE_Q_BLOCK_SIZE, + SPARSE_KV_BLOCK_SIZE, + mask_graph, + ) = block_mask + + if kernel_options["OUTPUT_LOGSUMEXP"]: + raise NotImplementedError( + "torch.compile on CPU only supports inference and `return_lse` is not supported yet." + ) + if not check_cpu_supported(): + raise NotImplementedError( + "torch.compile on current platform is not supported for CPU." + ) + + fake_buffers: list[Buffer] = [] # noqa: F821 + + # [Note] Handle the case where the split sizes are not statically known. + # The value of cur_qSplitSize and cur_kvSplitSize are decided during runtime. + # We use symbols to represent them during the compilation here. + # They'll be replaced by the string "cur_qSplitSize" and "cur_kvSplitSize" in + # the modification function of the CppFlexAttentionTemplate class. + cur_qSplitSize = V.graph.sizevars.shape_env.create_unbacked_symint().node.expr + cur_kvSplitSize = V.graph.sizevars.shape_env.create_unbacked_symint().node.expr + shape_env = V.graph.sizevars.shape_env + + # We don't know the concrete value of cur_qSplitSize and cur_kvSplitSize during the compilation. + # Mark symbols > 1 to ensure broadcasting is always applied. + # This avoids treating them as equal when `eq(var, 1)` is evaluated in `broadcast_symbolic_shapes`. 
+ shape_env.var_to_range[cur_qSplitSize] = ValueRanges(2, int_oo) + shape_env.var_to_range[cur_kvSplitSize] = ValueRanges(2, int_oo) + + score_dtype = torch.float + placeholder_inps = [ + create_placeholder(name, dtype, query.get_device(), size) + for name, dtype, size in [ + ("score", score_dtype, [cur_qSplitSize, cur_kvSplitSize]), + ("b", torch.int64, []), + ("h", torch.int64, []), + ("q_idx", torch.int64, [cur_qSplitSize, 1]), + ("kv_idx", torch.int64, [1, cur_kvSplitSize]), + ] + ] + subgraph_buffer = build_subgraph_buffer( + placeholder_inps + list(score_mod_other_buffers), subgraph + ) + if subgraph_buffer is not None: + if isinstance(subgraph_buffer, list): + for _buf in subgraph_buffer: + if _buf is not None: + _buf.freeze_layout() + else: + subgraph_buffer.freeze_layout() + mask_graph_placeholder_inps = [ + create_placeholder(name, dtype, query.get_device(), size) + for name, dtype, size in [ + ("score", score_dtype, [cur_qSplitSize, cur_kvSplitSize]), + ("b", torch.int64, []), + ("h", torch.int64, []), + ("q_idx", torch.int64, [cur_qSplitSize, 1]), + ("kv_idx", torch.int64, [1, cur_kvSplitSize]), + ] + ] + + # The original mask_graph works on a scalar and only includes + # the logic of calculating the mask value. + # We need to add the logic of applying the mark to the qk_data tensor + # into the graph for the later codegen of this part. + # Example: + # mask_graph: + # def mask_fn(b, h, q_idx, kv_idx): + # mask = q_idx >= kv_idx + # return mask + # The converted_mask_graph should be: + # def converted_mask_fn(qk_data, b, h, q_idx, kv_idx): + # mask = q_idx >= kv_idx + # qk_data = torch.where(mask, qk_data, torch.full_like(qk_data, -float("inf"))) + # return qk_data + def convert_mask_graph_module(mask_graph): + gm = copy.deepcopy(mask_graph.graph_module) + graph = gm.graph + # Add qk_data as the first input + with graph.inserting_before(next(iter(graph.nodes))): + qk_data_node = graph.placeholder("qk_data") + + # Find the node that returns the mask + output_node = None + for node in graph.nodes: + if node.op == "output": + output_node = node + break + + # Get the mask node + assert output_node is not None + mask_node = output_node.args[0] + + size_node = [cur_qSplitSize, cur_kvSplitSize] + # Create a new node for torch.full + with graph.inserting_after(mask_node): + full_node = graph.call_function( + torch.full, + args=(size_node, -float("inf")), + kwargs={"dtype": score_dtype}, + ) + + # Create a new node for torch.where + with graph.inserting_after(full_node): + where_node = graph.call_function( + torch.ops.aten.where, args=(mask_node, qk_data_node, full_node) + ) + + # Update the output node to return the result of torch.where + output_node.args = (where_node,) + + graph.lint() + converted = torch.fx.GraphModule(gm, graph) + return converted + + converted_mask_graph_module = convert_mask_graph_module(mask_graph) + + mask_graph_buffer = build_subgraph_module_buffer( + mask_graph_placeholder_inps + list(mask_mod_other_buffers), + converted_mask_graph_module, + ) + + # Clear the pending fresh unbacked symbols that are created for cur_qSplitSize and cur_kvSplitSize in the current kernel. 
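+ # (These placeholder symbols are only stand-ins that the CppFlexAttentionTemplate later replaces with
+ # runtime variables, see the note above, so they are dropped from the pending unbacked-symbol list here.)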
+ pending = V.graph.sizevars.shape_env.pending_fresh_unbacked_symbols + V.graph.sizevars.shape_env.pending_fresh_unbacked_symbols = [ + x for x in pending if x not in (cur_qSplitSize, cur_kvSplitSize) + ] + + buffer_list = ( + placeholder_inps + + list(score_mod_other_buffers) + + mask_graph_placeholder_inps + + list(mask_mod_other_buffers) + ) + for item in buffer_list: + if isinstance(item, TensorBox): + fake_buffers.append(item.data.data) # type: ignore[attr-defined] + + # CPU kernel requires last dim to be contiguous + query, key, value = map(contiguous_last_dim, [query, key, value]) + + ( + query, + key, + value, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + ) = maybe_realize( + [ + query, + key, + value, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + ] + ) + + if len(OrderedSet([query.get_name(), key.get_name(), value.get_name()])) != 3: + raise NotImplementedError( + "Unsupported for now if query, key, value are the same buffer." + ) + if query.get_dtype() not in [torch.float, torch.bfloat16, torch.float16]: + raise NotImplementedError( + "`torch.float` , `torch.float16` and `torch.bfloat16` are supported in FlexAttention for CPU device. " + f"Found input tensors are `{query.get_dtype()}`." + ) + score_mod_other_buffers = maybe_realize(score_mod_other_buffers) + mask_mod_other_buffers = maybe_realize(mask_mod_other_buffers) + Bq, Hq, seq_len_q, qk_head_dim = query.get_size() + Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size() + B = Bq + + # Construct output layout with strides matching the query. + out_size = [B, Hq, seq_len_q, v_head_dim] + out_strides = infer_dense_strides(out_size, query.get_stride()) + + layout = FixedLayout( + query.get_device(), + query.get_dtype(), + [B, Hq, seq_len_q, v_head_dim], + stride=[sympy.sympify(s) for s in out_strides], + ) + _choices: list[Any] = [] + input_nodes = [query, key, value, kv_num_blocks, kv_indices] + if not full_kv_num_blocks: + no_full_kv_block = True + else: + no_full_kv_block = False + input_nodes += [full_kv_num_blocks] + input_nodes += [full_kv_indices] + has_other_buffer = False + kernel_input_name_to_buffer = {} + if score_mod_other_buffers or mask_mod_other_buffers: + has_other_buffer = True + + for prefix, buffers in [ + ("score_others", score_mod_other_buffers), + ("mask_others", mask_mod_other_buffers), + ]: + kernel_input_name_to_buffer.update( + {f"{prefix}_{i}": buf for i, buf in enumerate(buffers)} + ) + input_nodes += [ + value + for value in kernel_input_name_to_buffer.values() + if not isinstance(value, sympy.Symbol) + ] + + skip_mask_score = kernel_options.get("SKIP_MASK_SCORE", False) + # Mark SPARSE_KV_BLOCK_SIZE & SPARSE_Q_BLOCK_SIZE as static shapes and add guards. + SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE) + SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE) + assert V.graph.sizevars.evaluate_expr( + sympy.Le(seq_len_q, sympy.Mul(kv_indices.get_size()[-2], SPARSE_Q_BLOCK_SIZE)) + ), ( + "Q seqlen must be smaller than the block_mask size in the Q dimension, considering pass a larger block_mask." + ) + assert V.graph.sizevars.evaluate_expr( + sympy.Le(seq_len_kv, sympy.Mul(kv_indices.get_size()[-1], SPARSE_KV_BLOCK_SIZE)) + ), ( + "KV seqlen must be smaller than the block_mask size in the KV dimension, considering pass a larger block_mask." 
+ ) + CppFlexAttentionTemplate.add_choices( + choices=_choices, + input_nodes=input_nodes, + layout=layout, + scale=scale, + score_mod=None if skip_mask_score else subgraph_buffer, + mask_mod=None if skip_mask_score else mask_graph_buffer, + kv_block_size=SPARSE_KV_BLOCK_SIZE, + q_block_size=SPARSE_Q_BLOCK_SIZE, + has_other_buffer=has_other_buffer, + no_full_kv_block=no_full_kv_block, + fake_buffers=fake_buffers, + len_score_other=len(score_mod_other_buffers), + len_mask_other=len(mask_mod_other_buffers), + kernel_input_name_to_buffer=kernel_input_name_to_buffer, + block_vars=(cur_qSplitSize, cur_kvSplitSize), + ) + inputs_for_autotuning = [ + query, + key, + value, + ] + res = autotune_select_algorithm( + "flex_attention", + _choices, + inputs_for_autotuning, + layout, + ) + + # need subgraph inputs and outputs to analyze all symints used in flex attention + res.data.data.subgraph_inps = list(score_mod_other_buffers) + list( + mask_mod_other_buffers + ) + res.data.data.subgraph_outs = get_fwd_subgraph_outputs( + subgraph_buffer, mask_graph_buffer + ) + + return (res,) + + +def is_power_of_2(n): + return n != 0 and ((n & (n - 1)) == 0) + + +def next_power_of_two(n): + if n <= 0: + return 1 + return 2 ** math.ceil(math.log2(n)) + + +def set_head_dim_values( + kernel_options: dict[str, Any], qk_head_dim, v_head_dim, graph_sizevars +): + """ + Mutates kernel options, adding head dimension calculations. + + Args: + kernel_options: Dictionary to populate with options + qk_head_dim: Query/Key head dimension + v_head_dim: Value head dimension + graph_sizevars: Graph size variables object with evaluate_static_shape method + + """ + # QK dimensions + qk_head_dim_static = graph_sizevars.evaluate_static_shape(qk_head_dim) + kernel_options.setdefault("QK_HEAD_DIM", qk_head_dim_static) + kernel_options.setdefault( + "QK_HEAD_DIM_ROUNDED", next_power_of_two(qk_head_dim_static) + ) + + # V dimensions + v_head_dim_static = graph_sizevars.evaluate_static_shape(v_head_dim) + kernel_options.setdefault("V_HEAD_DIM", v_head_dim_static) + kernel_options.setdefault( + "V_HEAD_DIM_ROUNDED", next_power_of_two(v_head_dim_static) + ) + + # Safety flag + kernel_options.setdefault( + "SAFE_HEAD_DIM", + is_power_of_2(qk_head_dim_static) and is_power_of_2(v_head_dim_static), + ) + + +@register_lowering(torch.ops.higher_order.flex_attention, type_promotion_kind=None) +def flex_attention( + query, + key, + value, + subgraph, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, +): + if query.get_device().type == "cpu": + return lower_cpu( + query, + key, + value, + subgraph, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ) + + # below is cuda path if device is not cpu + # tl.dot does not support embedding size less than 16 + small_dqk = V.graph.sizevars.evaluate_expr(sympy.Lt(query.get_size()[-1], 16)) + small_dv = V.graph.sizevars.evaluate_expr(sympy.Lt(value.get_size()[-1], 16)) + if small_dqk or small_dv: + raise NotImplementedError( + f"NYI: embedding dimension of the query, key, and value must be " + f"at least 16 but got E={query.get_size()[-1]} and Ev={value.get_size()[-1]}" + ) + + ( + _, # q_length + _, # kv_length + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + SPARSE_Q_BLOCK_SIZE, + SPARSE_KV_BLOCK_SIZE, + mask_graph, + ) = block_mask + + placeholder_inps = [ + create_placeholder(name, dtype, query.get_device()) + for name, dtype 
in [ + ("score", query.get_dtype()), + ("b", torch.int32), + ("h", torch.int32), + ("m", torch.int32), + ("n", torch.int32), + ] + ] + subgraph_buffer = build_subgraph_buffer( + placeholder_inps + list(score_mod_other_buffers), subgraph + ) + + mask_graph_placeholder_inps = [ + create_placeholder(name, dtype, query.get_device()) + for name, dtype in [ + ("b", torch.int32), + ("h", torch.int32), + ("m", torch.int32), + ("n", torch.int32), + ] + ] + mask_graph_buffer = build_subgraph_buffer( + mask_graph_placeholder_inps + list(mask_mod_other_buffers), mask_graph + ) + + kernel_options = dict(kernel_options) + # Mark symbols in custom kernel options as static shapes and add guards. + kernel_options = { + k: V.graph.sizevars.evaluate_static_shape(v) + if isinstance(v, sympy.Symbol) + else v + for k, v in kernel_options.items() + } + kernel_options.setdefault("FLOAT32_PRECISION", get_float32_precision()) + enable_gqa = V.graph.sizevars.evaluate_expr( + sympy.Ne(query.get_size()[1], key.get_size()[1]), + ) + if _use_flex_decoding(query, kv_indices, kernel_options, enable_gqa): + return create_flex_decoding_kernel( + query, + key, + value, + block_mask, + scale, + kernel_options, + subgraph_buffer, + mask_graph_buffer, + score_mod_other_buffers, + mask_mod_other_buffers, + ) + + ( + query, + key, + value, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + ) = maybe_realize( + [ + query, + key, + value, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + ] + ) + + score_mod_other_buffers = maybe_realize(score_mod_other_buffers) + mask_mod_other_buffers = maybe_realize(mask_mod_other_buffers) + + Bq, Hq, seq_len_q, qk_head_dim = query.get_size() + Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size() + assert V.graph.sizevars.evaluate_expr(sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)), ( + f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}" + ) + assert V.graph.sizevars.evaluate_expr(sympy.Gt(seq_len_q, 0)), ( + "Query length must be greater than 0" + ) + assert V.graph.sizevars.evaluate_expr(sympy.Gt(seq_len_kv, 0)), ( + "Key length must be greater than 0" + ) + + B = Bq + + if seq_len_q % 128 != 0 or seq_len_kv % 128 != 0: + kernel_options.setdefault("IS_DIVISIBLE", False) + else: + kernel_options.setdefault("IS_DIVISIBLE", True) + + # NB it is okay that the v_head_dim is different + # We are using these to match fill order of the output. + q_strides = query.get_stride() + # Construct output layout with strides matching the query. + out_size = [B, Hq, seq_len_q, v_head_dim] + out_strides = infer_dense_strides(out_size, q_strides) + + layout = FixedLayout( + query.get_device(), + query.get_dtype(), + [B, Hq, seq_len_q, v_head_dim], + stride=[sympy.sympify(s) for s in out_strides], + ) + # see NOTE:[TritonTemplates with multiple outputs] + logsumexp_shape = [B, Hq, seq_len_q] + logsumexp = empty_strided( + logsumexp_shape, + None, + dtype=torch.float32, # The logsumexp is always stored in fp32 regardless of the input dtype + device=query.get_device(), + ) + kernel_options.setdefault("SM_SCALE", scale) + + # Determine GQA broadcast factor. + gqa_shared_heads = Hq // Hkv + kernel_options.setdefault("GQA_SHARED_HEADS", gqa_shared_heads) + + # Inside of Triton kernel, only apply partial masking if partial blocks are computed. 
+ # full_kv_num_blocks is None if partial blocks are not computed + has_full_blocks = full_kv_num_blocks is not None + kernel_options.setdefault("HAS_FULL_BLOCKS", has_full_blocks) + if not has_full_blocks: + full_kv_num_blocks, full_kv_indices = ( + empty(0, device=query.get_device()) for _ in range(2) + ) + + set_head_dim_values(kernel_options, qk_head_dim, v_head_dim, V.graph.sizevars) + + choices: list[Any] = [] + + dtype = query.get_dtype() + head_dim = V.graph.sizevars.evaluate_static_shape(query.get_size()[-1]) + configs = V.choices.get_flex_attention_fwd_configs(head_dim, dtype) + + # Mark SPARSE_KV_BLOCK_SIZE & SPARSE_Q_BLOCK_SIZE as static shapes and add guards. + SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE) + SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE) + + # Note, we don't need to pass in the captured buffers explicitly + # because they're implicitly added by the score_mod function + # We do need to explicitly pass it in for autotuning though. + original_kernel_options = kernel_options.copy() + # Default config for warp specialization + num_consumer_groups, num_buffers_warp_spec = 0, 0 + + for conf in configs: + if ( + SPARSE_KV_BLOCK_SIZE % conf.block_n != 0 + or SPARSE_Q_BLOCK_SIZE % conf.block_m != 0 + ): + if len(configs) == 1: + raise ValueError( + f"Q and KV block size must be divisible by BLOCK_M and BLOCK_N. We " + f"got Q_BLOCK_SIZE={SPARSE_Q_BLOCK_SIZE} and KV_BLOCK_SIZE={SPARSE_KV_BLOCK_SIZE}." + ) + continue + + cur_kernel_options = original_kernel_options.copy() + # Performance tuning + # Triton parameters + # Remove prefix for forward kernels options and delete backward kernel options. + for k in list(cur_kernel_options.keys()): + if k.startswith("fwd_"): + v = cur_kernel_options.pop(k) + cur_kernel_options[k[4:]] = v + if k.startswith("bwd_"): + cur_kernel_options.pop(k) + cur_kernel_options.setdefault("num_stages", conf.num_stages) + cur_kernel_options.setdefault("num_warps", conf.num_warps) + if cur_kernel_options.get("num_consumer_groups", False): + cur_kernel_options.setdefault("num_consumer_groups", num_consumer_groups) + cur_kernel_options.setdefault( + "num_buffers_warp_spec", num_buffers_warp_spec + ) + + # Disabling TMA by default, only explicit kernel_options supported for now + cur_kernel_options.setdefault("USE_TMA", False) + + cur_kernel_options.setdefault("BLOCK_M", conf.block_m) + cur_kernel_options.setdefault("BLOCK_N", conf.block_n) + # Blocksparse options + cur_kernel_options.setdefault("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE) + cur_kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE) + + # ROCm specific kernargs + for attrib in ["kpack", "matrix_instr_nonkdim", "waves_per_eu"]: + if hasattr(conf, attrib): + cur_kernel_options[attrib] = getattr(conf, attrib) + + error = flex_attention_template.maybe_append_choice( + choices=choices, + input_nodes=[ + query, + key, + value, + logsumexp, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + ], + layout=layout, + subgraphs=[ + subgraph_buffer, + mask_graph_buffer, + ], + mutated_inputs=[ + logsumexp, + ], + call_sizes=query.get_size(), + **cur_kernel_options, + ) + if error is not None and len(configs) == 1: + raise error + inputs_for_autotuning = ( + [ + query, + key, + value, + logsumexp, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + ] + + list(score_mod_other_buffers) + + list(mask_mod_other_buffers) + ) + input_gen_fns = { + 4: 
create_num_blocks_fake_generator(kv_indices), + 5: create_indices_fake, + 6: create_num_blocks_fake_generator(full_kv_indices), + 7: create_indices_fake, + } + + out = autotune_select_algorithm( + "flex_attention", + choices, + # Need to filter out symbols since there is an invariant + # that all input_nodes are of type IRNode + [x for x in inputs_for_autotuning if isinstance(x, torch._inductor.ir.IRNode)], + layout, + input_gen_fns=input_gen_fns, + ) + + # need subgraph inputs and outputs to analyze all symints used in flex attention + out.data.data.subgraph_inps = list(score_mod_other_buffers) + list( + mask_mod_other_buffers + ) + out.data.data.subgraph_outs = get_fwd_subgraph_outputs( + subgraph_buffer, mask_graph_buffer + ) + + return (out, logsumexp) + + +# ---------------------------- Backward HOP Implementation ---------------------------- + + +def flex_attention_backward_grid( + batch_size, q_heads, num_queries, d_model, kv_heads, num_key_value, meta +): + """How is this kernel parallelized? + Currently this is only parallelizing over batch* kv_heads, but we can, and want to + parallelize over ceil_div(q_heads//kv_heads * num_key_value, key_value_block_size). + To do this will either require atomic updates to some grad values or to have a two pass kernel design. + """ + import triton + + return ( + triton.cdiv(num_queries, meta["BLOCK_M2"]) * (q_heads // kv_heads) + + triton.cdiv(num_key_value, meta["BLOCK_N1"]), + 1, + batch_size * kv_heads, + ) + + +flex_attention_backward_template = TritonTemplate( + name="flex_attention_backward", + grid=flex_attention_backward_grid, + source=r""" +{{def_kernel("Q", "K", "V", "LSE", "DELTA", "DO", "DQ", "DV", "KV_NUM_BLKS", "KV_IDX", "Q_NUM_BLKS", "Q_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX", "FULL_Q_NUM_BLKS", "FULL_Q_IDX")}} + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. 
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = {{stride("Q")}} + stride_kz, stride_kh, stride_kn, stride_kd = {{stride("K")}} + stride_vz, stride_vh, stride_vn, stride_vd = {{stride("V")}} + stride_doz, stride_doh, stride_dom, stride_dod = {{stride("DO")}} + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = {{stride("DQ")}} + stride_dvz, stride_dvh, stride_dvm, stride_dvd = {{stride("DV")}} + + ZQ = {{size("Q", 0)}} + HQ = {{size("Q", 1)}} + HKV = {{size("K", 1)}} + Q_LEN = {{size("Q", 2)}} + ZKV = {{size("K", 0)}} + KV_LEN = {{size("K", 2)}} + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_hz = tl.program_id(2) + off_zq = off_hz // HKV # q batch idx + off_hkv = off_hz % HKV # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = {{size("KV_NUM_BLKS", 0)}} + SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}} + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = {{stride("KV_NUM_BLKS", 1)}} + stride_kv_idx_h = {{stride("KV_IDX", 1)}} + stride_kv_idx_m = {{stride("KV_IDX", 2)}} + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. 
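+    # LSE and DELTA are laid out as [Bq, Hq, Q_LEN], so their flat per-(batch, query-head)
+    # offset is (off_zq * HQ + off_hq2) * Q_LEN, computed below as off_chz2.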
+ q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + {{gen_argdefs()}}, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + {{gen_argdefs()}}, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = {{stride("Q_NUM_BLKS", 1)}} + stride_q_idx_h = {{stride("Q_IDX", 1)}} + stride_q_idx_n = {{stride("Q_IDX", 2)}} + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. 
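+    # This CTA owns a single BLOCK_N1-sized tile of K/V for the dK/dV computation and iterates
+    # over the Q blocks of every query head in its GQA group, so K and V are loaded only once.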
+ k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + {{gen_argdefs()}}, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + {{gen_argdefs()}}, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. 
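+    # dV is stored directly through dv_ptrs below, while dK is emitted via the store_output
+    # call (see the DK note in the sub-notation comment at the top of this kernel).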
+ dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + {{store_output(("off_zq", "off_hkv", "index_n", "index_k"), "dk", "mask", indent_width=8)}} + +@triton.jit +def bwd_dq_inner( + {{gen_argdefs()}}, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + {{gen_defines() | indent_except_first(1) }} + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = {{size("Q", 2)}} + KV_LEN = {{size("K", 2)}} + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + if not IS_DIVISIBLE: + if hi >= 1: + for start_n in range(0, hi - 1): + dq = bwd_dq_block_mn( + {{gen_argdefs()}}, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + dq = bwd_dq_block_mn( + {{gen_argdefs()}}, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + else: + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + {{gen_argdefs()}}, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. 
+ offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + {{gen_argdefs()}}, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, +): + {{gen_defines() | indent_except_first(1)}} + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if CHECK_BLOCK_BOUNDARY else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds prior to the last loop + m = get_bounded_indices(offs_m2[:, None], Q_LEN if (not IS_DIVISIBLE or CHECK_BLOCK_BOUNDARY) else None) + + {{ modification( + subgraph_number=0, + output_name="post_mod_scores", + score="qk", + b="off_z", + h="off_hq", + m="m", + n="n", + out="qk" + ) | indent_except_first(1) }} + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + {{ modification( + subgraph_number=2, + output_name="mask_mod_output", + score="qk", + b="off_z", + h="off_hq", + m="m", + n="n", + ) | indent_except_first(2) }} + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, False) + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. 
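+    # Concretely: dp = dO @ V^T and ds = p * (dp - Di), where Di holds the precomputed DELTA rows loaded earlier.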
+ # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + {{ modification( + subgraph_number=1, + output_name = "grad_scores", + score="pre_mod_scores", + b="off_z", + h="off_hq", + m="m", + n="n", + grad_score_mod="ds" + ) | indent_except_first(1) }} + if CHECK_BLOCK_BOUNDARY: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = offs_m2[:, None] < Q_LEN and offs_n2[None, :] < KV_LEN + {{ modification( + subgraph_number=3, + output_name=None, + mask="scatter_mask", + score="pre_mod_scores", + b="off_z", + h="off_hq", + m="m", + n="n", + grad_score_mod="ds" + ) | indent_except_first(2) }} + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, False) + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + {{gen_argdefs()}}, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + {{gen_defines() | indent_except_first(1) }} + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = {{size("Q", 2)}} + KV_LEN = {{size("K", 2)}} + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + if not IS_DIVISIBLE: + if hi >= 1: + for start_m in range(0, hi - 1): + dk, dv = bwd_dkdv_block_mn( + {{gen_argdefs()}}, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. 
+ offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + + offs_m1 += offset + + dk, dv = bwd_dkdv_block_mn( + {{gen_argdefs()}}, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + else: + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + {{gen_argdefs()}}, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + {{gen_argdefs()}}, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, +): + {{gen_defines() | indent_except_first(1) }} + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if CHECK_BLOCK_BOUNDARY else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds prior to the last loop + n = get_bounded_indices(offs_n1[:, None], KV_LEN if (not IS_DIVISIBLE or CHECK_BLOCK_BOUNDARY) else None) + + pre_mod_scores = qkT + {{ modification( + subgraph_number=0, + output_name="post_mod_scores", + score="qkT", + b="off_z", + h="off_hq", + m="m", + n="n", + out="qkT" + ) | indent_except_first(1) }} + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. 
+ post_mod_scores = tl.where(offs_n1[:, None] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + {{ modification( + subgraph_number=2, + output_name="mask_mod_output", + score="qkT", + b="off_z", + h="off_hq", + m="m", + n="n", + ) | indent_except_first(2) }} + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, False) + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + {{ modification( + subgraph_number=1, + output_name = "grad_scores", + score="pre_mod_scores", + b="off_z", + h="off_hq", + m="m", + n="n", + grad_score_mod="dsT" + ) | indent_except_first(1) }} + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = offs_m1[None, :] < Q_LEN and offs_n1[:, None] < KV_LEN + {{ modification( + subgraph_number=3, + output_name=None, + mask="scatter_mask", + score="pre_mod_scores", + b="idx_b", + h="idx_h", + m="idx_m", + n="idx_n", + grad_score_mod="dsT" + ) | indent_except_first(2) }} + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + if CHECK_BLOCK_BOUNDARY: + grad_scores = tl.where(offs_n1[:, None] < KV_LEN, grad_scores, 0.0) + + dsT = grad_scores + if not IS_FULL_BLOCKS: + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, False) + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + """ + + compute_next_offset_func + + get_bounded_indices_func + + load_checked_2d, +) + + +def validate_joint_graph(joint_graph: torch.fx.Graph): + """We do some pre lowering graph checks in order to raise nicer error messages""" + for node in joint_graph.nodes: + if ( + node.op == "call_function" + and node.target == torch.ops.flex_lib.zeros_and_scatter.default + ): + for user in node.users: + if user.op != "output": + raise NotImplementedError( + "Using multiple indexing operations on the same tensor that requires gradients " + "in a score_mod function is not currently supported. " + "This typically happens when indexing the same tensor multiple times, like:\n\n" + " def score_mod(score, b, h, q_idx, kv_idx):\n" + " return score + bias[q_idx] + bias[kv_idx] # bias used twice!\n\n" + "A valid workaround is to clone() the tensors that will be indexed multiple times. For example:\n\n" + " bias1 = bias.clone()\n" + " def score_mod(score, b, h, q_idx, kv_idx):\n" + " return score + bias[q_idx] + bias1[kv_idx]\n\n" + "Note that this solution will use additional memory." 
+ ) + return + + +@dataclass(frozen=True) +class JointOutputResult: + """Results from processing joint outputs.""" + + grad_input: ComputedBuffer + captured_grads_compute: list[ComputedBuffer] + captured_grads: list[Optional[TensorBox]] + mutated_grads: list[TensorBox] + + +def process_joint_outputs( + all_joint_outputs: SubgraphResults, num_placeholders: int +) -> JointOutputResult: + """Process joint outputs and extract various buffers needed for lowering + + Args: + all_joint_outputs: List of all the outputs from build_subgraphs + num_placeholders: The number of placeholder inputs, used to skip over unused backward compute buffers + + Returns: + JointOutputResult containing processed buffers and gradients + """ + assert isinstance(all_joint_outputs, list) + assert all_joint_outputs[0] is not None, ( + "joint_subgraph_buffer is None - this is a bug!" + ) + + joint_buffer = all_joint_outputs[0] + other_grads = all_joint_outputs[num_placeholders - 1 :] + + # outer_grads has the structure: Len(other_buffer_grads) if buffer doesn't require grad than it will be None + # We only grab the buffers that require grad for inlining into kernel + grads_compute = [buf for buf in other_grads if buf is not None] + + def get_out(buf): + if buf is None: + return None + assert isinstance(buf, ComputedBuffer) + assert buf.name is not None + return TensorBox.create(V.graph.get_buffer(buf.name)) + + grads_out = [get_out(x) for x in other_grads] + mutated_grads = [buf for buf in grads_out if buf is not None] + + return JointOutputResult( + grad_input=joint_buffer, + captured_grads_compute=grads_compute, + captured_grads=grads_out, + mutated_grads=mutated_grads, + ) + + +# TODO: We probably also need a layout constraint? +@register_lowering( + torch.ops.higher_order.flex_attention_backward, type_promotion_kind=None +) +def flex_attention_backward(*args, **kwargs): + ( + query, + key, + value, + out, + logsumexp, + grad_out, + grad_logsumexp, + fw_graph, + joint_graph, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ) = args + ( + _, # q_length + _, # kv_length + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + SPARSE_Q_BLOCK_SIZE, + SPARSE_KV_BLOCK_SIZE, + mask_graph, + ) = block_mask + + ( + query, + key, + value, + grad_out, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + ) = maybe_realize( + [ + query, + key, + value, + grad_out, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + ] + ) + + device = query.get_device() + dtype = query.get_dtype() + Bq, Hq, seq_len_q, qk_head_dim = query.get_size() + Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size() + + assert V.graph.sizevars.evaluate_expr(sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)), ( + f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}" + ) + + kernel_options = dict(kernel_options) + # Mark symbols in custom kernel options as static shapes and add guards. 
+ kernel_options = { + k: V.graph.sizevars.evaluate_static_shape(v) + if isinstance(v, sympy.Symbol) + else v + for k, v in kernel_options.items() + } + kernel_options.setdefault("FLOAT32_PRECISION", get_float32_precision()) + if seq_len_q % 128 != 0 or seq_len_kv % 128 != 0: + kernel_options.setdefault("IS_DIVISIBLE", False) + else: + kernel_options.setdefault("IS_DIVISIBLE", True) + + fwd_placeholder_inps = [ + create_placeholder(name, dtype, device) + for name, dtype in [ + ("score", dtype), + ("b", torch.int32), + ("h", torch.int32), + ("m", torch.int32), + ("n", torch.int32), + ] + ] + fw_subgraph_buffer = build_subgraph_buffer( + fwd_placeholder_inps + list(score_mod_other_buffers), fw_graph + ) + + joint_placeholder_inps = fwd_placeholder_inps + [ + create_placeholder("grad_score_mod", dtype, device) + ] + # Sometimes we have weird unused nodes here + joint_graph.graph_module.graph.eliminate_dead_code() + + # It is hard to raise nice errors for some joint graphs during subgraph lowering + # This lets us do some checks before attempting to lower + validate_joint_graph(joint_graph.graph_module.graph) + + all_joint_outputs = build_subgraph_buffer( + joint_placeholder_inps + list(score_mod_other_buffers), + joint_graph, + ) + + joint_outputs = process_joint_outputs( + all_joint_outputs, len(joint_placeholder_inps) + ) + + mask_graph_placeholder_inps = [ + create_placeholder(name, dtype, query.get_device()) + for name, dtype in [ + ("b", torch.int32), + ("h", torch.int32), + ("m", torch.int32), + ("n", torch.int32), + ] + ] + mask_graph_buffer = build_subgraph_buffer( + mask_graph_placeholder_inps + list(mask_mod_other_buffers), mask_graph + ) + + mask_graph_buffer = mask_graph_buffer + + # Construct layout with stride order matching K + key_size = [Bq, Hkv, seq_len_kv, qk_head_dim] + key_strides = infer_dense_strides(key_size, key.get_stride()) + + layout_broadcasted_k = FixedLayout( + key.get_device(), + key.get_dtype(), + key_size, + stride=[sympy.sympify(s) for s in key_strides], + ) + + # Create delta which will is needed for the bwd's kernel + grad_lse_exp2 = lowerings[aten.mul](grad_logsumexp, 1 / math.log(2)) + mul_delta = lowerings[aten.mul](out, grad_out) + delta = lowerings[aten.sum](mul_delta, axis=-1) + delta = lowerings[aten.sub](delta, grad_lse_exp2) + delta = ExternKernel.require_contiguous(delta) + + grad_lse_exp2, delta = maybe_realize([grad_lse_exp2, delta]) + + # # see NOTE:[TritonTemplates with multiple outputs] + query_size = [Bq, Hq, seq_len_q, qk_head_dim] + grad_query_strides = infer_dense_strides(query_size, query.get_stride()) + grad_query = empty_strided( + query_size, + stride=[sympy.sympify(s) for s in grad_query_strides], + dtype=query.get_dtype(), + device=query.get_device(), + ) + + # Construct output layout with stride order matching value + value_size = [Bq, Hkv, seq_len_kv, v_head_dim] + value_strides = infer_dense_strides(value_size, value.get_stride()) + + broadcasted_grad_value = empty_strided( + value_size, + stride=[sympy.sympify(s) for s in value_strides], + dtype=value.get_dtype(), + device=value.get_device(), + ) + + kernel_options.setdefault("SM_SCALE", scale) + + # Determine GQA factor + gqa_shared_heads = Hq // Hkv + kernel_options.setdefault("GQA_SHARED_HEADS", gqa_shared_heads) + + # Inside of Triton kernel, only apply partial masking if partial blocks are computed. + # full_kv_num_blocks is torch.zeros([1, 1, 1]) if partial blocks are not computed. 
+ has_full_blocks = full_kv_num_blocks is not None + kernel_options.setdefault("HAS_FULL_BLOCKS", has_full_blocks) + if not has_full_blocks: + full_kv_num_blocks, full_kv_indices, full_q_num_blocks, full_q_indices = ( + empty(0, device=query.get_device()) for _ in range(4) + ) + + set_head_dim_values(kernel_options, qk_head_dim, v_head_dim, V.graph.sizevars) + + SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE) + SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE) + + choices: list[Any] = [] + + dtype = query.get_dtype() + head_dim = V.graph.sizevars.evaluate_static_shape(query.get_size()[-1]) + configs = V.choices.get_flex_attention_bwd_configs(head_dim, dtype) + + # Default config for warp specialization + num_consumer_groups, num_buffers_warp_spec = 0, 0 + + original_kernel_options = kernel_options.copy() + for conf in configs: + if ( + SPARSE_KV_BLOCK_SIZE % conf.block_m != 0 + or SPARSE_Q_BLOCK_SIZE % conf.block_m != 0 + or SPARSE_KV_BLOCK_SIZE % conf.block_n != 0 + or SPARSE_Q_BLOCK_SIZE % conf.block_n != 0 + ): + continue + + # Performance tuning + # Triton heuristics + cur_kernel_options = original_kernel_options.copy() + # Remove prefix for backward kernels options and delete forward kernel options. + for k in list(cur_kernel_options.keys()): + if k.startswith("bwd_"): + v = cur_kernel_options.pop(k) + cur_kernel_options[k[4:]] = v + if k.startswith("fwd_"): + cur_kernel_options.pop(k) + cur_kernel_options.setdefault("num_warps", conf.num_warps) + cur_kernel_options.setdefault("num_stages", conf.num_stages) + + if cur_kernel_options.get("num_consumer_groups", False): + cur_kernel_options.setdefault("num_consumer_groups", num_consumer_groups) + cur_kernel_options.setdefault( + "num_buffers_warp_spec", num_buffers_warp_spec + ) + + cur_kernel_options.setdefault("BLOCK_M1", conf.block_m) + cur_kernel_options.setdefault("BLOCK_N1", conf.block_n) + cur_kernel_options.setdefault("BLOCK_M2", conf.block_n) + cur_kernel_options.setdefault("BLOCK_N2", conf.block_m) + + # Blocksparse options + cur_kernel_options.setdefault("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE) + cur_kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE) + + # ROCm specific kernargs + for attrib in ["kpack", "matrix_instr_nonkdim", "waves_per_eu"]: + if hasattr(conf, attrib): + cur_kernel_options[attrib] = getattr(conf, attrib) + + flex_attention_backward_template.maybe_append_choice( + choices=choices, + input_nodes=[ + query, + key, + value, + logsumexp, + delta, + grad_out, + grad_query, + broadcasted_grad_value, + kv_num_blocks, + kv_indices, + q_num_blocks, + q_indices, + full_kv_num_blocks, + full_kv_indices, + full_q_num_blocks, + full_q_indices, + ], + layout=layout_broadcasted_k, # We use store_output only for grad_key + subgraphs=[ + fw_subgraph_buffer, + joint_outputs.grad_input, + mask_graph_buffer, + joint_outputs.captured_grads_compute, + ], + mutated_inputs=[ + grad_query, + broadcasted_grad_value, + *joint_outputs.mutated_grads, + ], + call_sizes=query.get_size() + key.get_size()[1:3], + **cur_kernel_options, + ) + inputs_for_autotuning = ( + [ + query, + key, + value, + logsumexp, + delta, + grad_out, + grad_query, + broadcasted_grad_value, + kv_num_blocks, + kv_indices, + q_num_blocks, + q_indices, + full_kv_num_blocks, + full_kv_indices, + full_q_num_blocks, + full_q_indices, + ] + + list(score_mod_other_buffers) + + list(mask_mod_other_buffers) + + joint_outputs.mutated_grads + ) + input_gen_fns = { + 8: 
create_num_blocks_fake_generator(kv_indices), # kv_num_blocks + 9: create_indices_fake, + 10: create_num_blocks_fake_generator(q_indices), # q_num_blocks + 11: create_indices_fake, + 12: create_num_blocks_fake_generator(full_kv_indices), # full_kv_num_blocks + 13: create_indices_fake, + 14: create_num_blocks_fake_generator(full_q_indices), # full_q_num_blocks + 15: create_indices_fake, + } + + broadcasted_grad_key = autotune_select_algorithm( + "flex_attention_backward", + choices, + [x for x in inputs_for_autotuning if isinstance(x, torch._inductor.ir.IRNode)], + layout_broadcasted_k, + input_gen_fns=input_gen_fns, + ) # [Bq, Hkv, seq_len_kv, k_head_dim] + + # need subgraph inputs and outputs to analyze all symints used in flex attention + broadcasted_grad_key.data.data.subgraph_inps = list(score_mod_other_buffers) + list( + mask_mod_other_buffers + ) + broadcasted_grad_key.data.data.subgraph_outs = get_bwd_subgraph_outputs( + fw_subgraph_buffer, mask_graph_buffer, joint_outputs + ) + + if V.graph.sizevars.evaluate_expr(sympy.Eq(Bq, Bkv)): + grad_key = broadcasted_grad_key + grad_value = broadcasted_grad_value + else: + assert V.graph.sizevars.evaluate_expr(sympy.Gt(Bq, 1) & sympy.Eq(Bkv, 1)), ( + f"Bq and Bkv must broadcastable. " + f"Got Bq={V.graph.sizevars.evaluate_expr(Bq)} " + f"and Bkv={V.graph.sizevars.evaluate_expr(Bkv)}" + ) + grad_key = lowerings[aten.sum](broadcasted_grad_key, axis=0, keepdims=True) + grad_value = lowerings[aten.sum](broadcasted_grad_value, axis=0, keepdims=True) + + return (grad_query, grad_key, grad_value, tuple(joint_outputs.captured_grads)) + + +def get_bwd_subgraph_outputs( + subgraph_buffer: SubgraphResults, + mask_graph_buffer: SubgraphResults, + joint_outputs: JointOutputResult, +) -> list[Optional[Union[ComputedBuffer, TensorBox]]]: + subgraph_buffer = ( + subgraph_buffer if isinstance(subgraph_buffer, Sequence) else [subgraph_buffer] + ) + mask_graph_buffer = ( + mask_graph_buffer + if isinstance(mask_graph_buffer, Sequence) + else [mask_graph_buffer] + ) + joint_output_buffers = [ + joint_outputs.grad_input, + *joint_outputs.captured_grads_compute, + *joint_outputs.captured_grads, + *joint_outputs.mutated_grads, + ] + + return [*subgraph_buffer, *mask_graph_buffer, *joint_output_buffers] diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex_decoding.py b/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex_decoding.py new file mode 100644 index 0000000000000000000000000000000000000000..0c663421a036e48db9aa0d15cdf4280d3cd6e0c3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex_decoding.py @@ -0,0 +1,628 @@ +# mypy: allow-untyped-defs +"""Triton Implementation of the flex_attention Kernel for short query length (FlexDecoding)""" + +from typing import Any + +import sympy + +import torch +from torch._inductor.virtualized import V + +from .. 
import ir +from ..ir import FixedLayout, FlexibleLayout +from ..lowering import empty, empty_strided, lowerings +from ..runtime.runtime_utils import is_power_of_2, next_power_of_2 +from ..select_algorithm import autotune_select_algorithm, SymbolicGridFn, TritonTemplate +from .flex_attention import ( + compute_forward_block_mn, + compute_forward_inner, + compute_next_offset_func, + create_indices_fake, + create_num_blocks_fake_generator, + get_bounded_indices_func, + get_fwd_subgraph_outputs, + load_checked_2d, + load_checked_block, + maybe_realize, +) + + +aten = torch.ops.aten +prims = torch.ops.prims + + +@SymbolicGridFn +def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, meta): + """How is this kernel parallelized? + We create a grid of (batch_size * kv_heads, SPLIT_KV, 1) + Each block is responsible for iterating over blocks of keys and values calculating + the local output for their tile of keys and values over all full length of query. + groups of SPLIT_KV blocks then combine their output to produce the final result. + """ + + return (batch_size * kv_heads, meta["SPLIT_KV"], 1) + + +flex_decoding_template = TritonTemplate( + name="flex_decoding", + grid=flex_decoding_grid, + source=r""" + {{def_kernel("Q", "K", "V", "M", "L", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}} + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. 
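+    # M and L hold the per-split rowmax and sumexp; the SPLIT_KV partial results are combined
+    # into the final attention output in a separate cross-CTA reduction after this kernel.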
+ + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = {{stride("Q")}} + stride_kz, stride_kh, stride_kn, stride_kk = {{stride("K")}} + stride_vz, stride_vh, stride_vn, stride_vk = {{stride("V")}} + stride_mz, stride_mt, stride_mh, stride_mm = {{stride("M")}} + stride_lz, stride_lt, stride_lh, stride_lm = {{stride("L")}} + + + Z = {{size("Q", 0)}} + ZKV = {{size("K", 0)}} + HKV = {{size("Q", 1)}} + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = {{size("Q", 3)}} + KV_LEN = {{size("K", 2)}} + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0) % HKV + off_t = tl.program_id(1) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + SPARSE_Z = {{size("KV_NUM_BLKS", 0)}} + SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}} + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = {{stride("KV_NUM_BLKS")}} + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = {{stride("KV_IDX")}} + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. 
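+    # Split index off_t owns the BLOCK_N-sized blocks in [off_t * TILE_KV_MULTIPLE, (off_t + 1) * TILE_KV_MULTIPLE)
+    # within this (batch, kv-head)'s sparse KV block list.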
+ block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Apply both score_mod and mask_mod + + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + indices_idx = block_n_start // SPARSE_KV_MULTIPLE + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + K_block_ptr = tl.make_block_ptr( + base=K + k_offset, + shape=(QK_HEAD_DIM, KV_LEN), # (d, N) + strides=(stride_kk, stride_kn), + offsets=(0, off_n), + block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N), + order=(0, 1) + ) + V_block_ptr = tl.make_block_ptr( + base=V + v_offset, + shape=(KV_LEN, V_HEAD_DIM), + strides=(stride_vn, stride_vk), + offsets=(off_n, 0), + block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED), + order=(1, 0) + ) + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, None, None, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + None, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. 
+ block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = block_n_start // SPARSE_KV_MULTIPLE + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + K_block_ptr = tl.make_block_ptr( + base=K + k_offset, + shape=(QK_HEAD_DIM, KV_LEN), # (d, N) + strides=(stride_kk, stride_kn), + offsets=(0, off_n), + block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N), + order=(0, 1) + ) + V_block_ptr = tl.make_block_ptr( + base=V + v_offset, + shape=(KV_LEN, V_HEAD_DIM), + strides=(stride_vn, stride_vk), + offsets=(off_n, 0), + block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED), + order=(1, 0) + ) + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, None, None, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + None, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + {{store_output(("idx_z", "idx_t", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}} + """ + + compute_forward_inner + + get_bounded_indices_func + + load_checked_block + + load_checked_2d + + compute_next_offset_func + + compute_forward_block_mn, +) + + +def get_split_k(B: int, H: int, Mk: int) -> int: + num_SM = torch.cuda.get_device_properties("cuda").multi_processor_count + bh = max(B * H, 1) # NOTE: Handle B*h=0 case + assert isinstance(bh, (int, sympy.Integer)), "B and H must be concrete integers" + split_k = num_SM // bh * 2 # Each SM should at least get one block. + # TODO: workload evening at runtime for splits fully masked out. + # Before we have runtime workload evening, assign 2 splits per SM. 
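+    # e.g. with 108 SMs and B * H = 8, split_k = 108 // 8 * 2 = 26 (illustrative numbers only)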
+ split_k = max(split_k, 1) + + return split_k + + +def create_flex_decoding_kernel(*args, **kwargs): + from .flex_attention import set_head_dim_values + + ( + query, + key, + value, + block_mask, + scale, + kernel_options, + score_mod_subgraph, + mask_mod_subgraph, + score_mod_other_buffers, + mask_mod_other_buffers, + ) = args + ( + _, # q_length + _, # kv_length + kv_num_blocks, + kv_indices, + full_kv_num_blocks, # full_kv_num_blocks, + full_kv_indices, # full_kv_indices, + _, # q_num_blocks + _, # q_indices + _, # full_q_num_blocks, + _, # full_q_indices, + _, # SPARSE_Q_BLOCK_SIZE, + SPARSE_KV_BLOCK_SIZE, + _, + ) = block_mask + + Bq, Hq, seq_len_q, qk_head_dim = query.get_size() + Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size() + + assert V.graph.sizevars.evaluate_expr(sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)), ( + f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}" + ) + + B = Bq + kernel_options = dict(kernel_options) + # Mark symbols in custom kernel options as static shapes and add guards. + kernel_options = { + k: V.graph.sizevars.evaluate_static_shape(v) + if isinstance(v, sympy.Symbol) + else v + for k, v in kernel_options.items() + } + + # TODO: Fix flex decoding non-divisible case! + if seq_len_q % 128 != 0 or seq_len_kv % 128 != 0: + kernel_options.setdefault("IS_DIVISIBLE", False) + else: + kernel_options.setdefault("IS_DIVISIBLE", True) + + # Calculate GQA head sharing + gqa_shared_heads = Hq // Hkv + if not is_power_of_2(gqa_shared_heads): + raise ValueError( + "Number of shared query heads sharing the same KV head must be power of 2. " + ) + kernel_options.setdefault("GQA_SHARED_HEADS", gqa_shared_heads) + + # Determine if there are "full" blocks where we only need to apply score_mod, and can skip mask_mod + has_full_blocks = full_kv_num_blocks is not None + kernel_options.setdefault("HAS_FULL_BLOCKS", has_full_blocks) + if not has_full_blocks: + # Create a plackeholder full block list in case it is empty + full_kv_num_blocks, full_kv_indices = ( + empty(0, device=query.get_device()) for _ in range(2) + ) + + ( + query, + key, + value, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + ) = maybe_realize( + [ + query, + key, + value, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + ] + ) + score_mod_other_buffers = maybe_realize(score_mod_other_buffers) + mask_mod_other_buffers = maybe_realize(mask_mod_other_buffers) + + choices: list[Any] = [] + dtype = key.get_dtype() + head_dim = V.graph.sizevars.evaluate_static_shape(key.get_size()[-1]) + configs = V.choices.get_flex_decode_configs(head_dim, dtype) + + # TODO: fix autotuning. 
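+    # Each KV split writes a partial result of its own: the accumulator buffer
+    # allocated below has shape [B, SPLIT_KV, Hq, seq_len_q, v_head_dim], while
+    # buf_M (per-split row max) and buf_L (per-split sum of exponentials) drop the
+    # last dim. They are combined across the SPLIT_KV dimension in the reduction
+    # at the end of this function.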
+ + kernel_options.setdefault("SM_SCALE", scale) + kernel_options.setdefault("SPLIT_KV", get_split_k(B, Hkv, seq_len_kv)) + MAX_SPLIT_KV = kernel_options["SPLIT_KV"] + + # create config dependent intermediate buffers + buf_ACC_shape = [B, MAX_SPLIT_KV, Hq, seq_len_q, v_head_dim] + buf_ML_shape = buf_ACC_shape[:-1] + buf_M = empty_strided( + buf_ML_shape, + None, + dtype=torch.float32, # The rowmax is always stored in fp32 regardless of the input dtype + device=query.get_device(), + ) + buf_L = empty_strided( + buf_ML_shape, + None, + dtype=torch.float32, # The intermediate sumexp is always stored in fp32 regardless of the input dtype + device=query.get_device(), + ) + + layout_acc = FixedLayout( + query.get_device(), + torch.float32, + buf_ACC_shape, + FlexibleLayout.contiguous_strides(buf_ACC_shape), + ) + + set_head_dim_values(kernel_options, qk_head_dim, v_head_dim, V.graph.sizevars) + + kernel_options.setdefault( + "BLOCK_M", + ( + # m + # if V.graph.sizevars.evaluate_expr(sympy.Lt(query.get_size()[-2], 0)) + # else # Always use a BLOCK_M > 16 before Triton fix https://github.com/triton-lang/triton/pull/4061 is in pin + max( + next_power_of_2( + V.graph.sizevars.size_hint( + seq_len_q, + fallback=torch._inductor.config.unbacked_symint_fallback, # type: ignore[arg-type] + ) + * gqa_shared_heads + ), + 16, + ) + ), + ) + + query = ir.ExternKernel.realize_input(query) + stride_b, stride_hq, stride_seq_len_q, stride_qk_head_dim = query.get_stride() + + # Reshape query for GQA: [B, Hq, Mq, D] -> [B, Hkv, G, Mq, D] + gqa_query_shape = (B, Hkv, gqa_shared_heads, seq_len_q, qk_head_dim) + gqa_query_stride = ( + stride_b, + stride_hq * gqa_shared_heads, + stride_hq, + stride_seq_len_q, + stride_qk_head_dim, + ) + query = lowerings[aten.as_strided](query, gqa_query_shape, gqa_query_stride) + + V.graph.sizevars.guard_leq( + seq_len_q * gqa_shared_heads, sympy.Integer(kernel_options["BLOCK_M"]) + ) + + kernel_options.setdefault( + "SAFE_M_BOUNDARY", + ((seq_len_q * gqa_shared_heads) % kernel_options["BLOCK_M"]) == 0, + ) + # TODO: This feels sketchy + kernel_options.setdefault("SAFE_N_BOUNDARY", True) + # Mark SPARSE_KV_BLOCK_SIZE as static shapes and add guards. + SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE) + + original_kernel_options = kernel_options.copy() + # Note, we don't need to pass in the captured buffers explicitly + # because they're implicitly added by the score_mod function + # We do need to explicitly pass it in for autotuning though. + + # Default config for warp specialization + num_consumer_groups, num_buffers_warp_spec = 0, 0 + + for conf in configs: + if SPARSE_KV_BLOCK_SIZE % conf.block_n != 0: + continue + + cur_kernel_options = original_kernel_options.copy() + # Remove prefix for forward kernels options and delete backward kernel options. 
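+        # e.g. a caller-supplied "fwd_BLOCK_N" becomes "BLOCK_N" for this
+        # decode-only kernel, while "bwd_*" options apply only to the backward
+        # kernel and are dropped here.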
+ for k in list(cur_kernel_options.keys()): + if k.startswith("fwd_"): + v = cur_kernel_options.pop(k) + cur_kernel_options[k[4:]] = v + if k.startswith("bwd_"): + cur_kernel_options.pop(k) + # Performance tuning + cur_kernel_options.setdefault("BLOCK_N", conf.block_n) + cur_kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE) + cur_kernel_options.setdefault("num_warps", conf.num_warps) + cur_kernel_options.setdefault("num_stages", conf.num_stages) + + if cur_kernel_options.get("num_consumer_groups", False): + cur_kernel_options.setdefault("num_consumer_groups", num_consumer_groups) + cur_kernel_options.setdefault( + "num_buffers_warp_spec", num_buffers_warp_spec + ) + + # Set default to False + cur_kernel_options.setdefault("USE_TMA", False) + + # Add ROCm-specific parameters if they exist in the config + for attrib in ["kpack", "matrix_instr_nonkdim", "waves_per_eu"]: + if hasattr(conf, attrib): + cur_kernel_options[attrib] = getattr(conf, attrib) + + flex_decoding_template.maybe_append_choice( + choices=choices, + input_nodes=[ + query, + key, + value, + buf_M, + buf_L, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + ], + layout=layout_acc, + subgraphs=[ + score_mod_subgraph, + mask_mod_subgraph, + ], + mutated_inputs=[buf_M, buf_L], + call_sizes=query.get_size(), + **cur_kernel_options, + ) + + inputs_for_flex_decoding = ( + [ + query, + key, + value, + buf_M, + buf_L, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + ] + + list(score_mod_other_buffers) + + list(mask_mod_other_buffers) + ) + + input_gen_fns = { + 5: create_num_blocks_fake_generator(kv_indices), + 6: create_indices_fake, + 7: create_num_blocks_fake_generator(full_kv_indices), + 8: create_indices_fake, + } + + buf_ACC = autotune_select_algorithm( + "flex_decoding", + choices, + inputs_for_flex_decoding, + layout_acc, + input_gen_fns=input_gen_fns, + ) + + # need subgraph inputs and outputs to analyze all symints used in flex attention + buf_ACC.data.data.subgraph_inps = list(score_mod_other_buffers) + list( + mask_mod_other_buffers + ) + buf_ACC.data.data.subgraph_outs = get_fwd_subgraph_outputs( + score_mod_subgraph, mask_mod_subgraph + ) + + # Reduction + + g_M = lowerings[aten.max](buf_M, dim=1, keepdim=True)[0] + # See [Note] Handle fully masked out rows: + # g_M Is the global max among split kv blocks. 
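+    # Cross-split combine (descriptive sketch of the lowerings that follow): split t
+    # produced a partial accumulator acc_t, a row max M_t and a base-2 sum of
+    # exponentials L_t. With g_M = max_t M_t and alpha_t = exp2(M_t - g_M):
+    #   out       = sum_t alpha_t * acc_t / sum_t alpha_t * L_t
+    #   logsumexp = log2(sum_t alpha_t * L_t) + g_M
+    # Fully masked rows (g_M == -inf) get their exponent clamped to 0 and the
+    # summed denominator replaced by 1 so no NaNs are produced.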
+ masked_rows = lowerings[aten.eq](g_M, -float("inf")) + adj_M = lowerings[aten.sub](buf_M, g_M) + adj_M = lowerings[aten.where](masked_rows, 0, adj_M) + alpha = lowerings[aten.exp2](adj_M) + + buf_L = lowerings[aten.mul](buf_L, alpha) + g_L = lowerings[aten.sum](buf_L, axis=1) + masked_rows_squeezed = lowerings[aten.squeeze](masked_rows, dim=1) + g_L = lowerings[aten.where](masked_rows_squeezed, 1.0, g_L) + logsumexp = lowerings[aten.log2](g_L) + logsumexp = lowerings[aten.add](logsumexp, lowerings[aten.squeeze](g_M, dim=1)) + + alpha_unseq = lowerings[aten.unsqueeze](alpha, 4) + buf_ACC = lowerings[aten.mul](buf_ACC, alpha_unseq) + output = lowerings[aten.sum](buf_ACC, axis=1) + L_unseq = lowerings[aten.unsqueeze](g_L, 3) + output = lowerings[aten.div](output, L_unseq) + output = lowerings[prims.convert_element_type](output, query.get_dtype()) + + return ( + output, + logsumexp, + ) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/mm.py b/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/mm.py new file mode 100644 index 0000000000000000000000000000000000000000..b3135a7382133a2730259dd1cd8be282b5f552c6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/mm.py @@ -0,0 +1,1373 @@ +# mypy: allow-untyped-defs +import functools +import logging +from typing import Any, Optional + +import sympy + +import torch +from torch._dynamo.utils import counters +from torch._inductor.autoheuristic.autoheuristic import AutoHeuristicSelectAlgorithm +from torch._inductor.autoheuristic.autoheuristic_utils import ( + AHContext, + context_add_strides, + context_add_using_tf32, + mm_operations, +) +from torch._inductor.codegen.cpp_gemm_template import CppGemmTemplate +from torch._inductor.virtualized import V +from torch.fx.experimental.proxy_tensor import make_fx +from torch.torch_version import TorchVersion + +from .. 
import config as inductor_config, ir +from ..codegen.cuda.gemm_template import CUTLASS2xGemmTemplate, CUTLASS3xGemmTemplate +from ..codegen.rocm.ck_tile_universal_gemm_template import CKTileGemmTemplate +from ..codegen.rocm.ck_universal_gemm_template import CKGemmTemplate +from ..codegen.subgraph import SubgraphTemplate +from ..ir import FlexibleLayout, is_triton +from ..lowering import ( + add_layout_constraint, + constrain_to_fx_strides, + lowerings as L, + register_lowering, +) +from ..select_algorithm import ( + autotune_select_algorithm, + ExternKernelChoice, + realize_inputs, + TritonTemplate, +) +from ..utils import ( + _use_cutlass_for_op, + get_k_splits, + get_tma_workspace_arg, + use_aten_gemm_kernels, + use_ck_gemm_template, + use_ck_tile_gemm_template, + use_cpp_gemm_template, + use_cutlass_template, + use_decompose_k_choice, + use_triton_template, + use_triton_tma_template, +) +from .mm_common import ( + _is_static_problem, + addmm_epilogue, + mm_args, + mm_config_kwargs, + mm_grid, + mm_options, + persistent_mm_grid, + persistent_mm_options, + scale_mm_epilogue, + scaled_mm_options, +) + + +try: + import triton + + triton_version = TorchVersion(triton.__version__) + has_triton = True +except ImportError: + triton_version = TorchVersion("0.0.0") + has_triton = False + +log = logging.getLogger(__name__) +aten = torch.ops.aten +prims = torch.ops.prims + +mm_template = TritonTemplate( + name="mm", + grid=mm_grid, + source=( + r""" +{{def_kernel("A", "B")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + + # based on triton.ops.matmul + pid = tl.program_id(0) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + if ((stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1)) and M >= BLOCK_M: + offs_a_m = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + else: + offs_a_m = rm % M + if ((stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1)) and N >= BLOCK_N: + offs_b_n = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + else: + offs_b_n = rn % N + offs_k = tl.arange(0, BLOCK_K) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + + for k_idx in range(0, tl.cdiv(K, BLOCK_K)): + {% if not EVEN_K %} + a_mask = offs_k[None, :] < (K - k_idx * BLOCK_K) + b_mask = offs_k[:, None] < (K - k_idx * BLOCK_K) + {% endif %} + a_k_idx_vals = offs_k[None, :] + (k_idx * BLOCK_K) + b_k_idx_vals = offs_k[:, None] + (k_idx * BLOCK_K) + + idx_m = offs_a_m[:, None] + idx_n = a_k_idx_vals + {{load_input("A", "a", ("idx_m", "idx_n"), mask=None if EVEN_K else "a_mask", indent_width=8)}} + + idx_m = b_k_idx_vals + idx_n = offs_b_n[None, :] + {{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", indent_width=8)}} + + {% if USE_FAST_ACCUM %} + acc = tl.dot(a, b, acc, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) + {% else %} + acc += tl.dot(a, b, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) + {% 
endif %} + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + idx_m = rm[:, None] + idx_n = rn[None, :] + mask = (idx_m < M) & (idx_n < N) + + # inductor generates a suffix + {{store_output(("idx_m", "idx_n"), "acc", "mask")}} +""" + if (torch.version.hip is None) or triton_version >= "3.3.0" + # FIXME: To get around rocm failures like https://github.com/pytorch/pytorch/actions/runs/13123783322/job/36617154943 + # The only difference between the two templates is M >= BLOCK_M and N >= BLOCK_N checking. + # See more details in https://github.com/pytorch/pytorch/pull/146293 + else r""" +{{def_kernel("A", "B")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + + # based on triton.ops.matmul + pid = tl.program_id(0) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + if (stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1): + offs_a_m = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + else: + offs_a_m = rm % M + if (stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1): + offs_b_n = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + else: + offs_b_n = rn % N + offs_k = tl.arange(0, BLOCK_K) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + + for k_idx in range(0, tl.cdiv(K, BLOCK_K)): + {% if not EVEN_K %} + a_mask = offs_k[None, :] < (K - k_idx * BLOCK_K) + b_mask = offs_k[:, None] < (K - k_idx * BLOCK_K) + {% endif %} + a_k_idx_vals = offs_k[None, :] + (k_idx * BLOCK_K) + b_k_idx_vals = offs_k[:, None] + (k_idx * BLOCK_K) + + idx_m = offs_a_m[:, None] + idx_n = a_k_idx_vals + {{load_input("A", "a", ("idx_m", "idx_n"), mask=None if EVEN_K else "a_mask", indent_width=8)}} + + idx_m = b_k_idx_vals + idx_n = offs_b_n[None, :] + {{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", indent_width=8)}} + {% if USE_FAST_ACCUM %} + acc = tl.dot(a, b, acc, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) + {% else %} + acc += tl.dot(a, b, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) + {% endif %} + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + idx_m = rm[:, None] + idx_n = rn[None, :] + mask = (idx_m < M) & (idx_n < N) + + # inductor generates a suffix + {{store_output(("idx_m", "idx_n"), "acc", "mask")}} +""" + ), + cache_codegen_enabled_for_template=True, + prologue_loads_all_inputs=True, +) + +persistent_tma_mm_template = TritonTemplate( + name="mm_persistent_tma", + grid=persistent_mm_grid, + source=r""" +{{def_kernel("A", "B")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + + start_pid = tl.program_id(0) + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + k_tiles = tl.cdiv(K, BLOCK_K) + 
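+    # Persistent scheduling (sketch of the loop below): the grid launches at most
+    # NUM_SMS programs, and each program walks tile_id over several (pid_m, pid_n)
+    # output tiles, doing k_tiles inner-K steps per tile, instead of launching one
+    # program per output tile.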
num_tiles = grid_m * grid_n + tiles_per_SM = num_tiles // NUM_SMS + if start_pid < num_tiles % NUM_SMS: + tiles_per_SM += 1 + + tile_id = start_pid - NUM_SMS + ki = -1 + + width = GROUP_M * grid_n + rk_for_mask = tl.arange(0, BLOCK_K) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + + workspace_base = ws_ptr + start_pid * 2 * TMA_SIZE + a_desc_ptr = workspace_base + b_desc_ptr = workspace_base + TMA_SIZE + + {%- if TMA_EXPERIMENTAL_API %} + triton.language.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=a_desc_ptr, + global_address=A, + load_size=[BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M], + global_size=[M, K] if A_ROW_MAJOR else [K, M], + element_ty=A.dtype.element_ty, + ) + triton.language.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=b_desc_ptr, + global_address=B, + load_size=[BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K], + global_size=[K, N] if B_ROW_MAJOR else [N, K], + element_ty=B.dtype.element_ty, + ) + + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr) + + a_desc = a_desc_ptr + b_desc = b_desc_ptr + {%- else %} + a_desc = triton.language.make_tensor_descriptor( + base=A, + shape=[M, K] if A_ROW_MAJOR else [K, M], + strides=[K, 1] if A_ROW_MAJOR else [M, 1], + block_shape=[BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M], + ) + b_desc = triton.language.make_tensor_descriptor( + base=B, + shape=[K, N] if B_ROW_MAJOR else [N, K], + strides=[N, 1] if B_ROW_MAJOR else [K, 1], + block_shape=[BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K], + ) + {%- endif %} + + pid_m = 0 + pid_n = 0 + rm = 0 + rn = 0 + + for _ in range(0, k_tiles * tiles_per_SM): + ki = tl.where(ki == k_tiles - 1, 0, ki + 1) + if ki == 0: + tile_id += NUM_SMS + # re-order program ID for better L2 performance + group_id = tile_id // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (tile_id % group_size) + pid_n = (tile_id % width) // (group_size) + + rm = pid_m * BLOCK_M + rn = pid_n * BLOCK_N + + rk = ki * BLOCK_K + + {%- if TMA_EXPERIMENTAL_API %} + a = tl._experimental_descriptor_load( + a_desc, + [rm, rk] if A_ROW_MAJOR else [rk, rm], + [BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M], + A.dtype.element_ty, + ) + b = tl._experimental_descriptor_load( + b_desc, + [rk, rn] if B_ROW_MAJOR else [rn, rk], + [BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K], + B.dtype.element_ty, + ) + {%- else %} + a = tl.load_tensor_descriptor( + a_desc, + [rm, rk] if A_ROW_MAJOR else [rk, rm], + ) + b = tl.load_tensor_descriptor( + b_desc, + [rk, rn] if B_ROW_MAJOR else [rn, rk], + ) + {%- endif %} + acc += tl.dot( + a if A_ROW_MAJOR else a.T, + b if B_ROW_MAJOR else b.T, + allow_tf32=ALLOW_TF32, + ) + + if ki == k_tiles - 1: + # rematerialize rm and rn to save registers + rcm = rm + tl.arange(0, BLOCK_M) + rcn = rn + tl.arange(0, BLOCK_N) + idx_m = rcm[:, None] + idx_n = rcn[None, :] + mask = (idx_m < M) & (idx_n < N) + + # inductor generates a suffix + {{store_output(("idx_m", "idx_n"), "acc", "mask", indent_width=12)}} + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + +""", +) + +load_scales = r""" +@triton.jit +def load_scales(a_scale_ptr, b_scale_ptr, SCALING_ROWWISE: tl.constexpr): + if SCALING_ROWWISE: + # For row-wise scaling, we'll return the pointers + return a_scale_ptr, b_scale_ptr + else: + # For per-tensor scaling, we'll load the scalar values + a_scale = tl.load(a_scale_ptr) + b_scale = 
tl.load(b_scale_ptr) + return a_scale, b_scale +""" + + +apply_scaling = r""" +@triton.jit +def apply_scaling( + accumulator, + a_scale, + b_scale, + SCALING_ROWWISE: tl.constexpr, + offs_cm, + offs_cn, + M, + N, + stride_a_scale_m, + stride_b_scale_n, +): + if SCALING_ROWWISE: + # For row-wise scaling, we need to load the scales for each row/column + a_scales = tl.load( + a_scale + (offs_cm * stride_a_scale_m), + mask=offs_cm < M, + other=0.0, + ) + b_scales = tl.load( + b_scale + (offs_cn * stride_b_scale_n), + mask=offs_cn < N, + other=0.0, + ) + acc_scale = a_scales[:, None] * b_scales[None, :] + else: + # For per-tensor scaling, we can directly use the loaded scalar values + acc_scale = a_scale * b_scale + + return accumulator * acc_scale +""" + + +device_tma = r""" +{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + + stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + + if SCALING_ROWWISE: + stride_a_scale_m = 1 + stride_b_scale_n = 1 + else: + stride_a_scale_m = 0 + stride_b_scale_n = 0 + + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + k_tiles = tl.cdiv(K, BLOCK_K) + num_tiles = num_pid_m * num_pid_n + + workspace_base = ws_ptr + start_pid * 2 * TMA_SIZE + a_desc_ptr = workspace_base + b_desc_ptr = workspace_base + TMA_SIZE + + {%- if TMA_EXPERIMENTAL_API %} + triton.language.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=a_desc_ptr, + global_address=A, + load_size=[BLOCK_M, BLOCK_K], + global_size=[M, K], + element_ty=A.dtype.element_ty, + ) + triton.language.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=b_desc_ptr, + global_address=B, + load_size=[BLOCK_N, BLOCK_K], + global_size=[N, K], + element_ty=B.dtype.element_ty, + ) + + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr) + + a_desc = a_desc_ptr + b_desc = a_desc_ptr + {%- else %} + a_desc = triton.language.make_tensor_descriptor( + base=A, + shape=[M, K], + strides=[K, 1], + block_shape=[BLOCK_M, BLOCK_K], + ) + b_desc = triton.language.make_tensor_descriptor( + base=B, + shape=[N, K], + strides=[K, 1], + block_shape=[BLOCK_N, BLOCK_K], + ) + {%- endif %} + + tiles_per_SM = num_tiles // NUM_SMS + if start_pid < num_tiles % NUM_SMS: + tiles_per_SM += 1 + + tile_id = start_pid - NUM_SMS + ki = -1 + + pid_m = 0 + pid_n = 0 + offs_am = 0 + offs_bn = 0 + + num_pid_in_group = GROUP_M * num_pid_n + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + a_scale, b_scale = load_scales(A_inverse_scale, B_inverse_scale, SCALING_ROWWISE) + + for _ in range(0, k_tiles * tiles_per_SM): + ki = tl.where(ki == k_tiles - 1, 0, ki + 1) + if ki == 0: + tile_id += NUM_SMS + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_M + offs_bn = pid_n * BLOCK_N + + offs_k = ki * BLOCK_K + + {%- if TMA_EXPERIMENTAL_API %} + a = tl._experimental_descriptor_load( + a_desc_ptr, [offs_am, offs_k], [BLOCK_M, BLOCK_K], A.dtype.element_ty + ) + b = tl._experimental_descriptor_load( + b_desc_ptr, [offs_bn, offs_k], [BLOCK_N, BLOCK_K], 
B.dtype.element_ty + ) + {%- else %} + a = tl.load_tensor_descriptor(a_desc, [offs_am, offs_k]) + b = tl.load_tensor_descriptor(b_desc, [offs_bn, offs_k]) + {%- endif %} + if USE_FAST_ACCUM: + accumulator = tl.dot(a, b.T, accumulator) + else: + accumulator += tl.dot(a, b.T) + + if ki == k_tiles - 1: + # Apply inverse scaling + offs_cm = offs_am + tl.arange(0, BLOCK_M) + offs_cn = offs_bn + tl.arange(0, BLOCK_N) + # Apply scaling + accumulator = apply_scaling( + accumulator, + a_scale, + b_scale, + SCALING_ROWWISE, + offs_cm, + offs_cn, + M, + N, + stride_a_scale_m, + stride_b_scale_n, + ) + + idx_m = offs_cm[:, None] + idx_n = offs_cn[None, :] + mask = (idx_m < M) & (idx_n < N) + # inductor generates a suffix + {{store_output(("idx_m", "idx_n"), "accumulator", "mask", indent_width=12)}} + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) +""" + + +scaled_mm_device_tma_template = TritonTemplate( + name="scaled_mm_device_tma", + grid=persistent_mm_grid, + source=device_tma + load_scales + apply_scaling, +) + + +# prevent duplication registration of extern functions +@functools.cache +def lazy_register_extern_choice(fn): + return ExternKernelChoice(fn) + + +aten_mm = ExternKernelChoice(torch.mm, "at::mm_out") + +aten_addmm = ExternKernelChoice( + torch.addmm, "at::addmm_out", op_overload=aten.addmm.default +) + +aten__int_mm = ExternKernelChoice(torch._int_mm, "at::_int_mm_out") + +aten__sparse_semi_structured_mm = ExternKernelChoice( + torch._sparse_semi_structured_mm, + "at::_sparse_semi_structured_mm", + has_out_variant=False, +) + +aten__fp8_mm = ExternKernelChoice( + torch._scaled_mm, "at::_scaled_mm_out", op_overload=aten._scaled_mm.out +) + + +def _is_int8_mat(mat): + return mat.get_dtype() in (torch.int8, torch.uint8) + + +def _is_large_block_for_cpu(m, n, k): + # Thresholds are experimentally determined to reduce Triton CPU compile times + return m * n > 2**13 + + +@functools.lru_cache +def using_b200() -> bool: + """Returns true if the device is a NVIDIA B200, otherwise returns false.""" + if not torch.cuda.is_available(): + return False + # compute capability 10.0 or 10.0a is NVIDIA B200 + device_properties = torch.cuda.get_device_properties(torch.cuda.current_device()) + return device_properties.major == 10 + + +def bias_addmm(inp, mat1, mat2, *, out=None, alpha=1, beta=1): + """ + Giving torch.addmm a 1D tensor calls a different (faster) cublasLt + kernel under the hood. There are a few shapes where this is slower, + but they are rare. 
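+    For example, a broadcast bias with inp.stride(0) == 0 or inp.size(0) == 1 is
+    passed as the 1D slice inp[0] below so the fused cublasLt epilogue can be used.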
+ """ + if inp.stride(0) == 0 or inp.size(0) == 1: + return torch.addmm(inp[0], mat1, mat2, out=out, alpha=alpha, beta=beta) + return torch.addmm(inp, mat1, mat2, out=out, alpha=alpha, beta=beta) + + +def check_supported_striding(mat_a, mat_b) -> None: + def is_row_major(stride) -> bool: + return V.graph.sizevars.statically_known_equals(stride[1], 1) + + def is_col_major(stride) -> bool: + return V.graph.sizevars.statically_known_equals(stride[0], 1) + + def has_zero_dim(size) -> bool: + return bool( + V.graph.sizevars.statically_known_equals(size[0], 0) + or V.graph.sizevars.statically_known_equals(size[1], 0) + ) + + # Check mat_a (self) stride requirements + torch._check( + is_row_major(mat_a.get_stride()) or has_zero_dim(mat_a.get_size()), + lambda: f"mat_a must be row_major, got stride {mat_a.get_stride()}", + ) + + # Check mat_b stride requirements + torch._check( + is_col_major(mat_b.get_stride()) or has_zero_dim(mat_b.get_size()), + lambda: f"mat_b must be col_major, got stride {mat_b.get_stride()}", + ) + + +aten_bias_addmm = ExternKernelChoice(bias_addmm, None) + + +def decomposeK(a, b, k_splits): + m = a.shape[0] + n = b.shape[1] + k = a.shape[1] + + k_parts = k // k_splits + B = k_splits + a_reshaped = torch.permute(a.reshape(m, B, k_parts), (1, 0, 2)) + b_reshaped = b.reshape(B, k_parts, n) + result = torch.bmm(a_reshaped, b_reshaped, out_dtype=torch.float32) + reduced_buf = torch.sum(result, 0) + return reduced_buf.to(a.dtype) + + +@register_lowering(aten.mm, type_promotion_kind=None) +def tuned_mm(mat1, mat2, *, layout=None): + """ + Lowering for autotuning aten.mm with different backends (Aten, Triton, CUTLASS, etc.) + """ + m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout) + device_type = ir.get_device_type(mat1) + name = "mm" + + # below is for getting an overview logging info of inductor mms + counters["aten_mm_info"][f"aten.mm_{m}_{n}_{k}"] += 1 + log.info( + "Tuned aten.mm: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s", + m, + n, + k, + mat1.get_dtype(), + mat2.get_dtype(), + layout, + ) + + aten_layout = layout + if not (inductor_config.max_autotune or inductor_config.max_autotune_gemm): + aten_layout = FlexibleLayout( + device=layout.device, dtype=layout.dtype, size=layout.size + ) + + # options to tune from + choices = ( + [aten_mm.bind((mat1, mat2), aten_layout)] if use_aten_gemm_kernels() else [] + ) + static_shape, is_nonzero = _is_static_problem(layout) + + mm_configs = V.choices.get_base_mm_configs(device_type) + persistent_mm_configs = V.choices.get_persistent_mm_configs(device_type) + extra_mm_configs = V.choices.get_extra_mm_configs(device_type) + + dtype = mat1.get_dtype() + if is_nonzero and use_triton_template(layout): + for config in mm_configs( + m, + n, + k, + **mm_config_kwargs(device_type, _is_large_block_for_cpu, dtype.itemsize), + ): + mm_template.maybe_append_choice( + choices, + input_nodes=(mat1, mat2), + layout=layout, + **mm_options(config, m, n, k, layout), + ) + + if use_triton_tma_template(mat1, mat2): + for config in persistent_mm_configs( + m, + n, + k, + **mm_config_kwargs( + device_type, _is_large_block_for_cpu, dtype.itemsize + ), + ): + persistent_tma_mm_template.maybe_append_choice( + choices, + input_nodes=(mat1, mat2), + layout=layout, + workspace_arg=get_tma_workspace_arg( + num_tma_descriptors=2, + device=mat1.get_device(), + ), + **mm_options(config, m, n, k, layout), + **persistent_mm_options(mat1, mat2), + ) + + from torch._inductor.ir import get_free_symbols + + # Only do split-k 
optimization if K is much larger than m, n and m, n are small + # and if there aren't any unbacked symbols + unbacked_symbols = any( + len(get_free_symbols(itr, unbacked_only=True)) > 0 + for itr in ( + mat1.get_size(), + mat1.get_stride(), + mat2.get_size(), + mat2.get_stride(), + ) + ) + if use_decompose_k_choice(m, n, k) and not unbacked_symbols: + from torch._dispatch.python import enable_python_dispatcher + + from ..decomposition import select_decomp_table + + k_splits = get_k_splits(m, n, k) + for k_split in k_splits: + if not V.graph.sizevars.statically_known_true( + sympy.Eq(sympy.Mod(k, k_split), 0) + ): + continue + + with enable_python_dispatcher(): + decompositions = select_decomp_table() + + decompose_k_subgraph_template = SubgraphTemplate( + name=f"decompose_k_mm_{k_split}_split", + make_fx_graph=make_fx( + functools.partial(decomposeK, k_splits=k_split), + decompositions, + ), + ) + + decompose_k_subgraph_template.maybe_append_choice( + choices, + input_nodes=(mat1, mat2), + layout=layout, + ) + + if ( + is_nonzero + and use_cutlass_template(layout, m, n, k) + and _use_cutlass_for_op("mm") + ): + CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2]) + + if is_nonzero and use_ck_gemm_template(layout, m, n, k): + CKGemmTemplate.add_ck_gemm_choices(choices, layout, [mat1, mat2]) + if is_nonzero and use_ck_tile_gemm_template(layout, m, n, k): + CKTileGemmTemplate.add_choices(choices, layout, [mat1, mat2]) + + if use_cpp_gemm_template(layout, mat1, mat2): + CppGemmTemplate.add_choices( + choices, + layout, + [mat1, mat2], + ) + + input_nodes = [mat1, mat2] + if ( + is_nonzero + and use_triton_template(layout) + and torch._inductor.config.run_autoheuristic(name) + and is_triton(mat1) + ): + always_included = [] + if use_aten_gemm_kernels(): + always_included.append("extern_mm") + num_choices_before_extra_configs = len(choices) + for config in extra_mm_configs( + m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu) + ): + mm_template.maybe_append_choice( + choices, + input_nodes=(mat1, mat2), + layout=layout, + **mm_options(config, m, n, k, layout), + ) + + # using AutoHeuristic for ranking + ah_choices = mm_autoheuristic( + mat1, + mat2, + m, + n, + k, + choices, + name, + input_nodes, + mm_operations(), + None, + top_k=10, + always_included=always_included, + ) + if not torch._inductor.config.collect_autoheuristic(name): + # if we are collecting data, we do not want to modify choices + if ah_choices is not None and len(ah_choices) > 0: + # the order in which autoheuristic returns choices is not the same as + # as the order of choices, which affects things like epilogue fusion. 
+ # once epilogue fusion benchmarks choices in sorted order, I think we can + # just use the order returned by autoheuristic + choices = [choice for choice in choices if choice in ah_choices] + else: + choices = choices[:num_choices_before_extra_configs] + + for k in inductor_config.external_matmul: + choices.append(lazy_register_extern_choice(k).bind((mat1, mat2), layout)) + + return autotune_select_algorithm(name, choices, [mat1, mat2], layout) + + +@register_lowering(aten._int_mm, type_promotion_kind=None) +def tuned_int_mm(mat1, mat2, *, layout=None): + m, n, k, layout, mat1, mat2 = mm_args( + mat1, mat2, layout=layout, out_dtype=torch.int32 + ) + + # below is for getting an overview logging info of inductor mms + counters["aten_mm_info"][f"aten._int_mm_{m}_{n}_{k}"] += 1 + log.info( + "Tuned aten._int_mm: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s", + m, + n, + k, + mat1.get_dtype(), + mat2.get_dtype(), + layout, + ) + + device_type = ir.get_device_type(mat1) + + static_shape, is_nonzero = _is_static_problem(layout) + use_cutlass = static_shape and is_nonzero and use_cutlass_template(layout, m, n, k) + + choices = ( + [aten__int_mm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else [] + ) + + if use_cutlass and _use_cutlass_for_op("int_mm"): + CUTLASS3xGemmTemplate.add_cutlass_gemm_choices( + choices, layout, [mat1, mat2], fuseable=True, non_fuseable=True + ) + + int8_mm_configs = V.choices.get_int8_mm_configs(device_type) + + if is_nonzero and use_triton_template(layout, enable_int32=True): + for config in int8_mm_configs( + m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu) + ): + mm_template.maybe_append_choice( + choices, + input_nodes=(mat1, mat2), + layout=layout, + **mm_options(config, m, n, k, layout), + ) + + return autotune_select_algorithm("int_mm", choices, [mat1, mat2], layout) + + +@register_lowering(aten.addmm, type_promotion_kind=None) +def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): + device_type = ir.get_device_type(mat1) + m, n, k, layout, mat1, mat2, inp_expanded = mm_args(mat1, mat2, inp, layout=layout) + static_shape, is_nonzero = _is_static_problem(layout) + + # below is for getting an overview logging info of inductor mms + counters["aten_mm_info"][f"aten.addmm_{m}_{n}_{k}"] += 1 + log.info( + "Tuned aten.addmm: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s", + m, + n, + k, + mat1.get_dtype(), + mat2.get_dtype(), + layout, + ) + + if (not is_nonzero) or ( + not (inductor_config.max_autotune or inductor_config.max_autotune_gemm) + ): + # Use a FlexibleLayout if we are not autotuning. + # This allows padding strides for the output. 
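+        # (A FixedLayout pins the output strides; converting it to FlexibleLayout
+        # below lets the extern addmm kernel choose or pad them.)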
+ from torch._inductor.ir import FixedLayout, FlexibleLayout + + if isinstance(layout, FixedLayout): + layout = FlexibleLayout( + device=layout.device, dtype=layout.dtype, size=layout.size + ) + choices = ( + [ + aten_addmm.bind( + (inp, mat1, mat2), + layout, + alpha=alpha, + beta=beta, + ) + ] + if use_aten_gemm_kernels() + else [] + ) + return autotune_select_algorithm("addmm", choices, [inp, mat1, mat2], layout) + + choices = ( + [ + aten_addmm.bind( + (inp_expanded, mat1, mat2), + layout, + alpha=alpha, + beta=beta, + ) + ] + if use_aten_gemm_kernels() + else [] + ) + + if ( + use_aten_gemm_kernels() + and inp_expanded.get_stride()[0] == 0 + and inp_expanded.get_device().type == "cuda" + and inductor_config.triton.autotune_cublasLt + ): + # unexpand inp to make sure fused addmm from cublasLt is used + choices.insert( + 0, + aten_bias_addmm.bind( + (inp_expanded, mat1, mat2), layout, alpha=alpha, beta=beta + ), + ) + + mm_configs = V.choices.get_base_mm_configs(device_type) + persistent_mm_configs = V.choices.get_persistent_mm_configs(device_type) + + dtype = mat1.get_dtype() + if is_nonzero and use_triton_template(layout): + for config in mm_configs( + m, + n, + k, + **mm_config_kwargs(device_type, _is_large_block_for_cpu, dtype.itemsize), + ): + mm_template.maybe_append_choice( + choices, + input_nodes=(inp_expanded, mat1, mat2), + layout=layout, + **mm_options(config, m, n, k, layout), + prefix_args=1, + epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta), + epilogue_fn_hash=str(["addmm_epilogue", layout.dtype, alpha, beta]), + ) + + if use_triton_tma_template(mat1, mat2): + for config in persistent_mm_configs( + m, + n, + k, + **mm_config_kwargs( + device_type, _is_large_block_for_cpu, dtype.itemsize + ), + ): + persistent_tma_mm_template.maybe_append_choice( + choices, + input_nodes=(inp_expanded, mat1, mat2), + layout=layout, + workspace_arg=get_tma_workspace_arg( + num_tma_descriptors=2, + device=mat1.get_device(), + ), + **mm_options(config, m, n, k, layout), + **persistent_mm_options(mat1, mat2), + prefix_args=1, + epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta), + ) + + if ( + is_nonzero + and use_cutlass_template(layout, m, n, k) + and _use_cutlass_for_op("addmm") + ): + CUTLASS3xGemmTemplate.add_cutlass_gemm_choices( + choices, + layout, + [mat1, mat2, inp_expanded], + alpha=alpha, + beta=beta, + input_reorder=[2, 0, 1], + ) + + if is_nonzero and use_ck_gemm_template(layout, m, n, k): + CKGemmTemplate.add_ck_gemm_choices( + choices, + layout, + [mat1, mat2, inp_expanded], + alpha=alpha, + beta=beta, + input_reorder=[2, 0, 1], + ) + + if use_cpp_gemm_template(layout, mat1, mat2): + CppGemmTemplate.add_choices( + choices, + layout, + [inp_expanded, mat1, mat2], + alpha=alpha, + beta=beta, + has_bias=True, + ) + + return autotune_select_algorithm( + "addmm", choices, [inp_expanded, mat1, mat2], layout + ) + + +@register_lowering(aten._sparse_semi_structured_mm, type_promotion_kind=None) +def tuned_sparse_semi_structured_mm( + mat1, mat1_meta, mat2, *, out_dtype=None, layout=None +): + from torch._inductor.select_algorithm import realize_inputs + + mat1, mat1_meta, mat2 = realize_inputs(mat1, mat1_meta, mat2) + m1, k1 = mat1.get_size() + m2, _ = mat1_meta.get_size() + k2, n = mat2.get_size() + m = V.graph.sizevars.guard_equals(m1, m2) + k = V.graph.sizevars.guard_equals(2 * k1, k2) + + if layout is None: + from torch._inductor.ir import FixedLayout + + layout = FixedLayout( + mat2.get_device(), + out_dtype if out_dtype else mat2.get_dtype(), + [m, n], + [n, 1], + ) + 
else: + assert out_dtype is None, "out_dtype is ignored if layout is specified." + + choices = ( + [ + aten__sparse_semi_structured_mm.bind( + (mat1, mat1_meta, mat2), layout, out_dtype=out_dtype + ) + ] + if use_aten_gemm_kernels() + else [] + ) + + if ( + m * n != 0 + and use_cutlass_template(layout, m, n, k) + and _use_cutlass_for_op("sparse_semi_structured_mm") + ): + CUTLASS2xGemmTemplate.add_cutlass_gemm_choices( + choices, layout, [mat1, mat2, mat1_meta], fuseable=True, non_fuseable=True + ) + + return autotune_select_algorithm( + "sparse_semi_structured_mm", choices, [mat1, mat1_meta, mat2], layout + ) + + +add_layout_constraint(aten._scaled_mm.default, constrain_to_fx_strides) + + +@register_lowering(aten._scaled_mm.default, type_promotion_kind=None) # type: ignore[misc] +def tuned_scaled_mm( + mat_a, + mat_b, + scale_a, + scale_b, + bias=None, + scale_result=None, + out_dtype=None, + use_fast_accum=False, + layout=None, +): + """ + Performs an optimized matrix multiplication where scaling factors are applied + to the inputs and/or output. + + Args: + mat1 (Tensor): First input matrix + mat2 (Tensor): Second input matrix + scale1 (Tensor): Scale factor applied to mat1 (supports broadcasting) + scale2 (Tensor): Scale factor applied to mat2 (supports broadcasting) + bias (Tensor, optional): Optional bias tensor to add to the result + layout: Layout hint for optimization + + Returns: + Tensor: The result of the scaled matrix multiplication + """ + m, n, k, layout, mat_a, mat_b = mm_args( + mat_a, mat_b, layout=layout, out_dtype=out_dtype + ) + # below is for getting an overview logging info of inductor mms + counters["aten_mm_info"][f"aten._scaled_mm.default_{m}_{n}_{k}"] += 1 + log.info( + "Tuned aten._scaled_mm.default: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s", + m, + n, + k, + mat_a.get_dtype(), + mat_b.get_dtype(), + layout, + ) + + device_type = ir.get_device_type(mat_a) + check_supported_striding(mat_a, mat_b) + + scale_a_real, scale_b_real = realize_inputs(scale_a, scale_b) + + input_nodes: tuple[Any, ...] + + if not bias: + input_nodes = (mat_a, mat_b, scale_a_real, scale_b_real) + else: + bias_real = realize_inputs(bias) + input_nodes = (mat_a, mat_b, scale_a_real, scale_b_real, bias_real) + + aten_choice = aten__fp8_mm.bind( + input_nodes, layout, out_dtype=out_dtype, use_fast_accum=use_fast_accum + ) + + choices = [] + if use_aten_gemm_kernels(): + choices.append(aten_choice) + + # We dont have triton lowerings for the MX variants yet + if scale_a.dtype != torch.float32: + return autotune_select_algorithm("scaled_mm", choices, input_nodes, layout) + + _, is_nonzero = _is_static_problem(layout) + + scaled_mm_configs = V.choices.get_scaled_mm_configs(device_type) + scaled_persistent_mm_configs = V.choices.get_scaled_persistent_mm_configs( + device_type + ) + + if is_nonzero and use_triton_template(layout, enable_float8=True): + triton_input_nodes: tuple[Any, ...] 
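+        # Scalar scales are unsqueezed from [] to [1, 1] and a 1-D bias from [N]
+        # to [1, N] below, so they can broadcast against the [M, N] accumulator in
+        # the scaling epilogue.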
+ if bias and len(mat_b.get_size()) == len(bias.get_size()) + 1: + # Need to unsqueeze bias from [N] -> [1, N] + triton_bias = L[aten.unsqueeze](bias, 0) + else: + triton_bias = bias + + if len(scale_a.get_size()) == 0 or len(scale_b.get_size()) == 0: + assert len(scale_a.get_size()) == len(scale_b.get_size()) + # Need to unsqueeze scale from [] -> [1, 1] + triton_scale_a = L[aten.unsqueeze](L[aten.unsqueeze](scale_a, 0), 1) + triton_scale_b = L[aten.unsqueeze](L[aten.unsqueeze](scale_b, 0), 1) + else: + triton_scale_a = scale_a + triton_scale_b = scale_b + + if bias: + triton_input_nodes = ( + mat_a, + mat_b, + triton_scale_a, + triton_scale_b, + triton_bias, + ) + suffix_args = 3 + else: + triton_input_nodes = (mat_a, mat_b, triton_scale_a, triton_scale_b) + suffix_args = 2 + + # TODO (paulzhan): There is no template that exists for bias and TMA + # Don't run tma template currently if bias exists + if use_triton_tma_template(mat_a, mat_b) and not bias: + for config in scaled_persistent_mm_configs(m, n, k): + kwargs = scaled_mm_options( + config, + m, + n, + k, + layout, + scale_a, + scale_b, + use_fast_accum, + device_tma=True, + ) + scaled_mm_device_tma_template.maybe_append_choice( + choices, + input_nodes=triton_input_nodes, + layout=layout, + workspace_arg=get_tma_workspace_arg( + num_tma_descriptors=2, + device=mat_a.get_device(), + ), + **kwargs, + ) + + for config in scaled_mm_configs(m, n, k): + if V.graph.sizevars.guard_or_false(sympy.Le(k, 16)): + # Triton crashes however uncommon for real workloads + continue + + # On NVIDIA B200 GPUs, K dim must be >= 32 for tcgen05.mma.kind::f8f6f4.* PTX instruction to be valid + # source: https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-shape + if using_b200() and V.graph.sizevars.guard_or_false(sympy.Lt(k, 32)): + continue + + kwargs = scaled_mm_options( + config, m, n, k, layout, scale_a, scale_b, use_fast_accum + ) + # possibly appends a TritonTemplateCaller to choices + mm_template.maybe_append_choice( + choices, + input_nodes=triton_input_nodes, + layout=layout, + **kwargs, + suffix_args=suffix_args, + epilogue_fn=scale_mm_epilogue(), + epilogue_fn_hash="scale_mm_epilogue", + ) + + if ( + is_nonzero + and use_cutlass_template(layout, m, n, k) + and _use_cutlass_for_op("scaled_mm") + ): + CUTLASS3xGemmTemplate.add_cutlass_gemm_choices( + choices, + layout, + input_nodes, # type: ignore[arg-type] + use_fast_accum=use_fast_accum, # type: ignore[arg-type] + ) + + if is_nonzero and use_ck_gemm_template(layout, m, n, k): + CKGemmTemplate.add_ck_gemm_choices(choices, layout, input_nodes) + + return autotune_select_algorithm("scaled_mm", choices, input_nodes, layout) + + +@functools.cache +def _is_sm7x_or_older_gpu(index: Optional[int]) -> bool: + props = torch.cuda.get_device_properties(index or 0) + return props.major <= 7 + + +def dims_are_int(dims): + return all(isinstance(dim, int) for dim in dims) + + +def mm_autoheuristic( + mat1, + mat2, + m, + n, + k, + choices, + name, + input_nodes, + ops, + precondition, + top_k: Optional[int] = None, + always_included=None, +): + m, n, k = get_size_hints(mat1, mat2, m, n, k) + if not dims_are_int([m, n, k]): + return None + mat1_stride, mat2_stride = get_size_hints_strides(mat1, mat2) + + def get_context(m, k, n, mat1, mat2, mat1_stride, mat2_stride): + context = AHContext() + context.add_feature("m", m) + context.add_feature("k", k) + context.add_feature("n", n) + context.add_feature("mat1_dtype", mat1.layout.dtype, is_categorical=True) + context.add_feature("mat2_dtype", 
mat2.layout.dtype, is_categorical=True) + context_add_strides(context, "mat1", mat1_stride) + context_add_strides(context, "mat2", mat2_stride) + context.add_feature( + "mat1_iscontig", mat1.layout.is_contiguous(), is_categorical=True + ) + context.add_feature( + "mat2_iscontig", mat2.layout.is_contiguous(), is_categorical=True + ) + if name == "mm": + context_add_using_tf32(context, mat1.layout.dtype) + return context + + def fallback(): + return None + + context = get_context(m, k, n, mat1, mat2, mat1_stride, mat2_stride) + autoheuristic = AutoHeuristicSelectAlgorithm( + fallback=fallback, + choices=choices, + input_nodes=input_nodes, + context=context, + name=name, + augment_context=ops, + precondition=precondition, + ) + + if top_k is not None: + # TODO: is there a cleaner way to ensure aten.mm is always included? + return autoheuristic.get_top_k_choices_caller( + top_k, always_included=always_included + ) + + return autoheuristic.get_choice_caller() + + +def get_size_hints(mat1, mat2, m, n, k): + if not isinstance(m, int) or not isinstance(k, int): + (m, k) = V.graph.sizevars.size_hints( + mat1.get_size(), + fallback=torch._inductor.config.unbacked_symint_fallback, + ) + + if not isinstance(n, int) or not isinstance(k, int): + (k, n) = V.graph.sizevars.size_hints( + mat2.get_size(), + fallback=torch._inductor.config.unbacked_symint_fallback, + ) + return m, n, k + + +def get_size_hints_strides(mat1, mat2): + mat1_stride = mat1.layout.stride + mat2_stride = mat2.layout.stride + strides = [mat1_stride, mat2_stride] + strides_hints = [] + for stride in strides: + if not isinstance(stride, int): + stride = V.graph.sizevars.size_hints( + stride, + fallback=torch._inductor.config.unbacked_symint_fallback, + ) + strides_hints.append(stride) + return strides_hints[0], strides_hints[1] diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/mm_common.py b/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/mm_common.py new file mode 100644 index 0000000000000000000000000000000000000000..39aef23c346c0539d876117477b42660bbb6f59f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/mm_common.py @@ -0,0 +1,302 @@ +# mypy: allow-untyped-defs +import logging +from collections.abc import Sequence +from typing import Any + +import sympy + +import torch +from torch._inductor.select_algorithm import realize_inputs, SymbolicGridFn +from torch._inductor.utils import sympy_product +from torch._inductor.virtualized import V + +from .. import config as inductor_config +from ..codegen.wrapper import PythonWrapperCodegen +from ..ir import _IntLike, Layout, TensorBox +from ..utils import get_num_sms, TMA_DESCRIPTOR_SIZE + + +log = logging.getLogger(__name__) + + +@SymbolicGridFn +def mm_grid(m, n, meta, *, cdiv): + """ + The CUDA grid size for matmul triton templates. + """ + return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), 1, 1) + + +@SymbolicGridFn +def persistent_mm_grid(M: int, N: int, meta: dict[str, Any], *, cdiv, min): + """Defines the grid for persistent kernels.""" + return ( + min(meta["NUM_SMS"], cdiv(M, meta["BLOCK_M"]) * cdiv(N, meta["BLOCK_N"])), + 1, + 1, + ) + + +@SymbolicGridFn +def persistent_grouped_mm_grid(*args): + meta = args[-1] + return (meta["NUM_SMS"], 1, 1) + + +def acc_type(dtype): + if dtype in (torch.float16, torch.bfloat16): + return "tl.float32" + return f"tl.{dtype}".replace("torch.", "") + + +def mm_options(config, sym_m, sym_n, sym_k, layout): + """ + Common options to matmul triton templates. 
+ """ + even_k_symbolic = ( + # it isn't worth guarding on this + sympy.gcd(sym_k, config.kwargs["BLOCK_K"]) == config.kwargs["BLOCK_K"] + ) + allow_tf32 = torch.backends.cuda.matmul.allow_tf32 and ( + not inductor_config.force_same_precision + or ((sym_m % 16) == 0 and (sym_n % 16) == 0 and (sym_k % 8) == 0) + ) + options_dict = dict( + EVEN_K=even_k_symbolic, + ALLOW_TF32=allow_tf32, + USE_FAST_ACCUM=False, # Option for _scaled_mm + ACC_TYPE=acc_type(layout.dtype), + num_stages=config.num_stages, + num_warps=config.num_warps, + **config.kwargs, + ) + + # If GROUP_M not specified then default to 8 + if "GROUP_M" not in config.kwargs: + group_m = config.kwargs.get("GROUP_M", 8) + options_dict["GROUP_M"] = group_m + + return options_dict + + +def tma_options() -> dict[str, Any]: + from torch.utils._triton import has_triton_stable_tma_api + + return {"TMA_EXPERIMENTAL_API": not has_triton_stable_tma_api()} + + +def persistent_mm_options(mat1, mat2): + res = dict( + A_ROW_MAJOR=not mat1.layout.is_transposed(), + B_ROW_MAJOR=not mat2.layout.is_transposed(), + NUM_SMS=get_num_sms(), + TMA_SIZE=TMA_DESCRIPTOR_SIZE, + ) + res.update(tma_options()) + return res + + +def scaled_mm_options( # type: ignore[no-untyped-def] + config, # triton.Config + sym_m: sympy.core.numbers.Integer, + sym_n: sympy.core.numbers.Integer, + sym_k: sympy.core.numbers.Integer, + layout: Layout, + scale_a, + scale_b, + use_fast_accum: bool, + device_tma: bool = False, +) -> dict[str, Any]: + def are_compatible_scales(size_a, size_b) -> bool: + # Same sized scales are compatible + if len(size_a) == len(size_b): + return True + + # Both need to be scalars or len(1) tensors + if len(size_a) <= 1 and len(size_b) <= 1: + return True + + return False + + size_a, size_b = scale_a.get_size(), scale_b.get_size() + assert are_compatible_scales(size_a, size_b), ( + "Expect scale_a and scale_b to be either both scalars (including single-element tensors) " + f"or 1-dimensional tensors with the same size. Got scale_a: {len(size_a)} and scale_b: {len(size_b)}." + ) + + mm_template_options = mm_options(config, sym_m, sym_n, sym_k, layout) + + mm_template_options["ACC_TYPE"] = "tl.float32" + mm_template_options["USE_FAST_ACCUM"] = use_fast_accum + mm_template_options["SCALING_ROWWISE"] = len(size_a) == 2 + + if device_tma: + mm_template_options["TMA_SIZE"] = TMA_DESCRIPTOR_SIZE + mm_template_options["NUM_SMS"] = get_num_sms() + + mm_template_options.update(tma_options()) + + return mm_template_options + + +def mm_args( + mat1, + mat2, + *others, + layout=None, + out_dtype=None, + use_4x2_dim=False, + mat2_transposed=False, +): + """ + Common arg processing for mm,bmm,addmm,etc + """ + mat1, mat2 = realize_inputs(mat1, mat2) + *b1, m, k1 = mat1.get_size() + if mat2_transposed: + *b2, n, k2 = mat2.get_size() + else: + *b2, k2, n = mat2.get_size() + b = [V.graph.sizevars.guard_equals(a, b) for a, b in zip(b1, b2)] + if use_4x2_dim: + k2 = k2 * 2 + k = V.graph.sizevars.guard_equals(k1, k2) + if layout is None: + from torch._inductor.ir import FixedLayout + + if out_dtype is None: + out_dtype = mat1.get_dtype() + + layout = FixedLayout( + mat1.get_device(), + out_dtype, + [*b, m, n], + ) + else: + assert out_dtype is None, "out_dtype is ignored if layout is specified." 
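+    # Any extra operands (e.g. the bias passed in by addmm) are broadcast to the
+    # output layout and realized below, then returned alongside m, n, k and layout.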
+ from ..lowering import expand + + others = [realize_inputs(expand(x, layout.size)) for x in others] + + return [m, n, k, layout, mat1, mat2, *others] + + +def mm_config_kwargs(device, exclude_condition, dtype_size=None): + if device == "cpu": + return { + "scale": 0.5, + "exclude": exclude_condition, + } + + if dtype_size and inductor_config.max_autotune_gemm_search_space == "EXHAUSTIVE": + return { + "dtype_size": dtype_size, + } + return {} + + +def addmm_epilogue(dtype, alpha, beta): + def epilogue(acc, bias): + if alpha != 1: + acc = V.ops.mul(acc, V.ops.constant(alpha, dtype)) + if beta != 1: + bias = V.ops.mul(bias, V.ops.constant(beta, dtype)) + return V.ops.add(acc, bias) + + return epilogue + + +def scale_mm_epilogue(): + """ + Create an epilogue function that applies scaling to matrix multiplication result + using the given scale factors. + + Args: + dtype: The data type of the output + scale_a: Scale factor for matrix A + scale_b: Scale factor for matrix B + + Returns: + Epilogue function that takes the accumulator and applies scaling + """ + + def epilogue(acc, inv_a_scale, inv_b_scale, bias=None): + # The epilogue function receives the accumulator (result of mat1 @ mat2) + # and applies the scaling factors + # In the original scaled_mm, we use inverse scales, so we multiply by them + mul_scales = V.ops.mul(inv_a_scale, inv_b_scale) + mul_acc = V.ops.mul(acc, mul_scales) + if bias is not None: + return V.ops.add(mul_acc, bias) + else: + return mul_acc + + return epilogue + + +def _is_static_problem(layout: Layout) -> tuple[bool, bool]: + """ + Check if input tensors and output layout have static shapes and non-zero sizes. + + Args: + layout: Output layout object with a 'size' attribute. + + Returns: + Tuple[bool, bool]: (is_static, is_nonzero) + is_static: True if all shapes are statically known + is_nonzero: True if all dimensions are non-zero + """ + static_shape = True + static_size = PythonWrapperCodegen.statically_known_list_of_ints_or_none( + layout.size + ) + if static_size is None: + nonzero = True + for s in layout.size: + sz = PythonWrapperCodegen.statically_known_int_or_none(s) + if sz is not None and sz == 0: + nonzero = False + break + return False, nonzero + numel = 1 + for dim in static_size: + numel *= dim + nonzero = numel > 0 + return static_shape, nonzero + + +def check_supported_striding(mat_a: TensorBox, mat_b: TensorBox) -> None: + def is_row_major(stride: Sequence[_IntLike]) -> bool: + return stride[-1] == 1 + + def is_col_major(stride: Sequence[_IntLike]) -> bool: + return stride[-2] == 1 + + def has_zero_dim(size: Sequence[_IntLike]) -> bool: + return bool(size[0] == 0 or size[1] == 0) + + # Check mat_a (self) stride requirements + torch._check( + is_row_major(mat_a.get_stride()) or has_zero_dim(mat_a.get_size()), + lambda: f"mat_a must be row_major, got stride {mat_a.get_stride()}", + ) + + # Check mat_b stride requirements + torch._check( + is_col_major(mat_b.get_stride()) or has_zero_dim(mat_b.get_size()), + lambda: f"mat_b must be col_major, got stride {mat_b.get_stride()}", + ) + + +def is_batch_stride_largest(mat1, mat2, layout) -> bool: + """ + Checking if the batch stride is the largest in the stride. 
+ """ + sizes = [mat1.get_size(), mat2.get_size(), layout.size] + strides = [mat1.get_stride(), mat2.get_stride(), layout.stride] + for size, stride in zip(sizes, strides): + assert len(size) == len(stride) == 3, "Expect 3D tensors" + if stride[0] != sympy_product(size[1:]): + return False + + return True diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/mm_plus_mm.py b/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/mm_plus_mm.py new file mode 100644 index 0000000000000000000000000000000000000000..64249e6fb57a9e01d52267d7a2a81f9c6a8e6469 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/mm_plus_mm.py @@ -0,0 +1,166 @@ +# mypy: allow-untyped-defs + +import torch + +from .. import ir +from ..lowering import lowerings +from ..select_algorithm import ( + autotune_select_algorithm, + ExternKernelChoice, + TritonTemplate, +) +from ..utils import use_aten_gemm_kernels, use_triton_template +from ..virtualized import V +from .mm_common import mm_args, mm_grid, mm_options + + +aten = torch.ops.aten + +aten_mm_plus_mm = ExternKernelChoice( + torch.ops.inductor._mm_plus_mm, "torch::inductor::_mm_plus_mm" +) + +mm_plus_mm_template = TritonTemplate( + name="mm_plus_mm", + grid=mm_grid, + debug=False, + source=r""" +{{def_kernel("A", "B", "C", "D")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K1 = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + # K2 = {{size("C", 1)}} + stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + stride_cm = {{stride("C", 0)}} + stride_ck = {{stride("C", 1)}} + stride_dk = {{stride("D", 0)}} + stride_dn = {{stride("D", 1)}} + + # based on triton.ops.matmul + pid = tl.program_id(0) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + if (((stride_am == 1 and stride_ak == M) or (stride_am == K1 and stride_ak == 1)) + and ((stride_cm == 1 and stride_ck == M) or (stride_cm == K1 and stride_ck == 1))): + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + else: + ram = rm % M + + if (((stride_bk == 1 and stride_bn == K1) or (stride_bk == N and stride_bn == 1)) + and ((stride_dk == 1 and stride_dn == K1) or (stride_dk == N and stride_dn == 1))): + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + else: + rbn = rn % N + + rk = tl.arange(0, BLOCK_K) + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + C = C + (ram[:, None] * stride_cm + rk[None, :] * stride_ck) + D = D + (rk[:, None] * stride_dk + rbn[None, :] * stride_dn) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for k1 in range(K1, 0, -BLOCK_K): + # First matmul with A @ B + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + a = tl.load(A, mask=rk[None, :] < k1, other=0.) + b = tl.load(B, mask=rk[:, None] < k1, other=0.) 
+ acc += tl.dot(a, b, allow_tf32=ALLOW_TF32) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + + for k2 in range(K1, 0, -BLOCK_K): + + # Second matmul with C @ D + if EVEN_K: + c = tl.load(C) + d = tl.load(D) + else: + c = tl.load(C, mask=rk[None, :] < k2, other=0.) + d = tl.load(D, mask=rk[:, None] < k2, other=0.) + acc += tl.dot(c, d, allow_tf32=ALLOW_TF32) + C += BLOCK_K * stride_ck + D += BLOCK_K * stride_dk + + + idx_m = rm[:, None] + idx_n = rn[None, :] + mask = (idx_m < M) & (idx_n < N) + + # inductor generates a suffix + {{store_output(("idx_m", "idx_n"), "acc", "mask")}} +""", + cache_codegen_enabled_for_template=True, +) + + +def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None): + """ + Computes mm(mat1, mat2) + mm(mat3, mat4) + """ + m1, n1, k1, layout1, mat1, mat2 = mm_args(mat1, mat2, layout=layout) + m2, n2, _, layout2, mat3, mat4 = mm_args(mat3, mat4, layout=layout) + device_type = ir.get_device_type(mat1) + + # Optimization is optional, because we can always just not do the fusion + if ( + m1 * n1 == 0 + or m2 * n2 == 0 + or not V.graph.sizevars.statically_known_list_equals( + mat1.get_size(), mat3.get_size() + ) + or not V.graph.sizevars.statically_known_list_equals( + mat2.get_size(), mat4.get_size() + ) + ): + # TODO(jansel): support different K values when this is fixed: + # https://github.com/triton-lang/triton/issues/967 + return lowerings[aten.add]( + lowerings[aten.mm](mat1, mat2), lowerings[aten.mm](mat3, mat4) + ) + + assert layout1 == layout2 + # options to tune from + choices = ( + [aten_mm_plus_mm.bind((mat1, mat2, mat3, mat4), layout1)] + if use_aten_gemm_kernels() + else [] + ) + + mm_configs = V.choices.get_mm_plus_mm_configs(device_type) + if use_triton_template(layout1): + for config in mm_configs(): + # see https://github.com/triton-lang/triton/issues/1298 + # BLOCK_K = K causes llvm error + if V.graph.sizevars.statically_known_lt(config.kwargs["BLOCK_K"], k1): + mm_plus_mm_template.maybe_append_choice( + choices, + input_nodes=(mat1, mat2, mat3, mat4), + layout=layout1, + **mm_options(config, m1, n1, k1, layout1), + ) + + return autotune_select_algorithm( + "mm_plus_mm", choices, [mat1, mat2, mat3, mat4], layout1 + ) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/mm_scaled_grouped.py b/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/mm_scaled_grouped.py new file mode 100644 index 0000000000000000000000000000000000000000..ad34ea0210b51d7f3e0c270b2ff11d3191d6ba0e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/mm_scaled_grouped.py @@ -0,0 +1,741 @@ +# mypy: allow-untyped-defs +import logging +from dataclasses import dataclass +from typing import Any, Optional + +import torch +from torch._dynamo.utils import counters +from torch._inductor.runtime.triton_compat import tl +from torch._inductor.virtualized import V +from torch.utils._triton import has_triton + +from ..ir import ChoiceCaller, Layout, TensorBox +from ..lowering import register_lowering +from ..select_algorithm import ( + autotune_select_algorithm, + ExternKernelChoice, + realize_inputs, + TritonTemplate, +) +from ..utils import ( + get_gpu_shared_memory, + get_num_sms, + has_free_symbols, + use_aten_gemm_kernels, +) +from .mm_common import ( + _is_static_problem, + check_supported_striding, + persistent_grouped_mm_grid, +) + + +log = logging.getLogger(__name__) +aten = torch.ops.aten + + +@dataclass +class Config: + kwargs: dict[str, int] + num_stages: int + num_warps: int + + +_NV_CONFIGS = [ + Config( + { + 
"BLOCK_M": block_size_m, + "BLOCK_N": block_size_n, + "BLOCK_K": block_size_k, + "NUM_CONSUMER_GROUPS": 1, + }, + num_stages=num_stages, + num_warps=num_warps, + ) + for block_size_m in [16, 32, 64, 128] + for block_size_n in [64, 128, 256] + for block_size_k in [64, 128, 256] + for num_stages in [3, 4] + for num_warps in [4, 8] +] + + +def grouped_mm_configs(): + return _NV_CONFIGS + + +def early_config_prune(g, m, configs, named_args): + dtsize = 1 + pruned_configs = [] + for config in configs: + kw = config.kwargs + BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps, num_consumer_groups = ( + kw["BLOCK_M"], + kw["BLOCK_N"], + kw["BLOCK_K"], + config.num_stages, + config.num_warps, + getattr(config, "num_consumer_groups", 0), + ) + + # 1. Prune NV configs depending on g and m. + if not has_free_symbols((g, m)): + a_is_2d, b_is_2d = named_args["A_IS_2D"], named_args["B_IS_2D"] + m_avg = m // g if a_is_2d and not b_is_2d else m + if m_avg <= 16: + if BLOCK_M > 32: + continue + elif m_avg <= 32: + if BLOCK_M > 64: + continue + elif m_avg <= 64: + if BLOCK_M <= 16: + continue + else: + if BLOCK_M <= 32: + continue + + # 2. make sure we have enough smem + max_shared_memory = get_gpu_shared_memory() + + required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize + if required_shared_memory > max_shared_memory: + continue + + use_warp_specialization = num_consumer_groups >= 1 + + # 3. make sure we can partition for ws + if use_warp_specialization: + if num_warps != 4: + continue + + # "tritongpu-warp-spec-data-partition" + m_slice = BLOCK_M // num_consumer_groups + n_slice = BLOCK_N // num_consumer_groups + if m_slice < 64 and n_slice < 256: + continue + + pruned_configs.append(config) + + return pruned_configs + + +triton_grouped_mm_source = r""" +{%- if SCALED %} +{%- if A_IS_2D or B_IS_2D %} +{{def_kernel("a_ptr", "b_ptr", "scale_a_ptr", "scale_b_ptr", "offsets_ptr")}} +{%- else %} +{{def_kernel("a_ptr", "b_ptr", "scale_a_ptr", "scale_b_ptr")}} +{%- endif %} +{%- else %} +{%- if A_IS_2D or B_IS_2D %} +{{def_kernel("a_ptr", "b_ptr", "offsets_ptr")}} +{%- else %} +{{def_kernel("a_ptr", "b_ptr")}} +{%- endif %} +{%- endif %} + tidx = tl.program_id(0) + +{%- set M_IS_VARYING = A_IS_2D and not B_IS_2D %} +{%- set N_IS_VARYING = not A_IS_2D and B_IS_2D %} +{%- set K_IS_VARYING = A_IS_2D and B_IS_2D %} + +{%- if A_IS_2D %} +{%- if B_IS_2D %} + G = {{size("offsets_ptr", 0)}} +{%- else %} + G = {{size("b_ptr", 0)}} +{%- endif %} +{%- else %} +{%- if B_IS_2D %} + G = {{size("a_ptr", 0)}} +{%- else %} + G = {{size("a_ptr", 0)}} +{%- endif %} +{%- endif %} + + # the b_ptr tensor is given with its last two dims transposed, revert here + + M = {{size("a_ptr", -2)}} + N = {{size("b_ptr", -1)}} + K = {{size("a_ptr", -1)}} + + A_STRIDE_M = {{stride("a_ptr", -2)}} + A_STRIDE_K = {{stride("a_ptr", -1)}} +{%- if not A_IS_2D %} + A_STRIDE_G = {{stride("a_ptr", 0)}} +{%- if SCALED %} + SCALE_A_STRIDE_G = {{stride("scale_a_ptr", 0)}} +{%- endif %} +{%- endif %} + B_STRIDE_N = {{stride("b_ptr", -1)}} + B_STRIDE_K = {{stride("b_ptr", -2)}} +{%- if not B_IS_2D %} + B_STRIDE_G = {{stride("b_ptr", 0)}} +{%- if SCALED %} + SCALE_B_STRIDE_G = {{stride("scale_b_ptr", 0)}} +{%- endif %} +{%- endif %} + +{%- if USE_TMA_LOAD %} +{%- if USE_EXPERIMENTAL_MAKE_TENSOR_DESCRIPTOR %} + a_desc = tl._experimental_make_tensor_descriptor( +{%- else %} + a_desc = tl.make_tensor_descriptor( +{%- endif %} + a_ptr, +{%- if A_IS_2D %} + shape=[M, K], + # fixme: strides=[A_STRIDE_M, A_STRIDE_K], + strides=[{{stride("a_ptr", 
-2)}}, {{stride("a_ptr", -1)}}], + block_shape=[BLOCK_M, BLOCK_K], +{%- else %} + shape=[G, M, K], + # fixme: strides=[A_STRIDE_G, A_STRIDE_M, A_STRIDE_K], + strides=[{{stride("a_ptr", 0)}}, {{stride("a_ptr", -2)}}, {{stride("a_ptr", -1)}}], + block_shape=[1, BLOCK_M, BLOCK_K], +{%- endif %} + ) + +{%- if USE_EXPERIMENTAL_MAKE_TENSOR_DESCRIPTOR %} + b_desc = tl._experimental_make_tensor_descriptor( +{%- else %} + b_desc = tl.make_tensor_descriptor( +{%- endif %} + b_ptr, +{%- if B_IS_2D %} + shape=[N, K], + # fixme: strides=[B_STRIDE_N, B_STRIDE_K], + strides=[{{stride("b_ptr", -1)}}, {{stride("b_ptr", -2)}}], + block_shape=[BLOCK_N, BLOCK_K], +{%- else %} + shape=[G, N, K], + # fixme: strides=[B_STRIDE_G, B_STRIDE_N, B_STRIDE_K], + strides=[{{stride("b_ptr", 0)}}, {{stride("b_ptr", -1)}}, {{stride("b_ptr", -2)}}], + block_shape=[1, BLOCK_N, BLOCK_K], +{%- endif %} + ) +{%- endif %} + +{%- if M_IS_VARYING %} + m_end_offset = 0 +{%- endif %} +{%- if N_IS_VARYING %} + n_end_offset = 0 +{%- endif %} +{%- if K_IS_VARYING %} + k_end_offset = 0 +{%- endif %} + iterated_tiles = 0 + for g in tl.range(G): +{%- if M_IS_VARYING %} + # Move across groups + m_start_offset = m_end_offset + m_end_offset = tl.load(offsets_ptr + g) + m_size = m_end_offset - m_start_offset +{%- if SCALED %} + m_scale_start_offset = m_start_offset +{%- endif %} +{%- else %} + m_start_offset = 0 + m_size = M +{%- if SCALED %} + m_scale_start_offset = g * M +{%- endif %} +{%- endif %} + +{%- if N_IS_VARYING %} + # Move across groups + n_start_offset = n_end_offset + n_end_offset = tl.load(offsets_ptr + g) + n_size = n_end_offset - n_start_offset +{%- if SCALED %} + n_scale_start_offset = n_start_offset +{%- endif %} +{%- else %} + n_start_offset = 0 + n_size = N +{%- if SCALED %} + n_scale_start_offset = g * N +{%- endif %} +{%- endif %} + + if m_size > 0 and n_size > 0: +{%- if K_IS_VARYING %} + # Move across groups + k_start_offset = k_end_offset + k_end_offset = tl.load(offsets_ptr + g) + k_size = k_end_offset - k_start_offset +{%- else %} + k_start_offset = 0 + k_size = K +{%- endif %} + + num_m_tiles = tl.cdiv(m_size, BLOCK_M) + num_n_tiles = tl.cdiv(n_size, BLOCK_N) + num_tiles = num_m_tiles * num_n_tiles + + # Move across tiles + while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles: + gidx = tidx - iterated_tiles + # Split M first and N second. 
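How the persistent loop above distributes work is easier to follow in plain Python: each of the `NUM_SMS` programs claims every `NUM_SMS`-th tile, counted globally across groups. A small sketch with made-up group sizes:

    def tiles_for_program(pid, tiles_per_group, num_sms):
        # Mirrors the `while tidx >= iterated_tiles ...` loop in the template.
        assignments = []
        iterated = 0
        tidx = pid
        for g, num_tiles in enumerate(tiles_per_group):
            while iterated <= tidx < iterated + num_tiles:
                assignments.append((g, tidx - iterated))
                tidx += num_sms
            iterated += num_tiles
        return assignments

    print(tiles_for_program(0, tiles_per_group=[3, 2, 4], num_sms=4))
    # [(0, 0), (1, 1), (2, 3)]  -> global tiles 0, 4, 8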
+ tile_m_idx = gidx % num_m_tiles + tile_n_idx = gidx // num_m_tiles + + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + +{%- if USE_TMA_LOAD %} + m_offset = (m_start_offset + tile_m_idx * BLOCK_M).to(tl.int32) + n_offset = (n_start_offset + tile_n_idx * BLOCK_N).to(tl.int32) + + for k_offset in range(0, k_size, BLOCK_K): +{%- if A_IS_2D %} + a = a_desc.load([m_offset, k_start_offset + k_offset]) +{%- else %} + a = a_desc.load([g, m_offset, k_start_offset + k_offset]).reshape(BLOCK_M, BLOCK_K) +{%- endif %} +{%- if B_IS_2D %} + b = b_desc.load([n_offset, k_start_offset + k_offset]) +{%- else %} + b = b_desc.load([g, n_offset, k_start_offset + k_offset]).reshape(BLOCK_N, BLOCK_K) +{%- endif %} + +{%- if K_IS_VARYING %} + if k_offset + BLOCK_K > k_size: + group_offs_k = k_offset + tl.arange(0, BLOCK_K) + a = tl.where(group_offs_k < k_size, a, 0) + b = tl.where(group_offs_k < k_size, b, 0) +{%- endif %} + +{%- if USE_FAST_ACCUM %} + accumulator = tl.dot(a, b.T, accumulator) +{%- else %} + accumulator += tl.dot(a, b.T) +{%- endif %} +{%- else %} + offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M) + offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N) + offs_k = k_start_offset + tl.arange(0, BLOCK_K) + a_ptrs = ( + a_ptr +{%- if not A_IS_2D %} + + g * A_STRIDE_G +{%- endif %} + + (m_start_offset + offs_am[:, None]) * A_STRIDE_M + + offs_k[None, :] * A_STRIDE_K + ) + b_ptrs = ( + b_ptr +{%- if not B_IS_2D %} + + g * B_STRIDE_G +{%- endif %} + + (n_start_offset + offs_bn[:, None]) * B_STRIDE_N + + offs_k[None, :] * B_STRIDE_K + ) + for k_offset in range(0, k_size, BLOCK_K): + a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size) + b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size) + if k_offset + BLOCK_K > k_size: + group_offs_k = k_offset + tl.arange(0, BLOCK_K) + a = tl.where(group_offs_k < k_size, a, 0) + b = tl.where(group_offs_k < k_size, b, 0) +{%- if USE_FAST_ACCUM %} + accumulator = tl.dot(a, b.T, accumulator) +{%- else %} + accumulator += tl.dot(a, b.T) +{%- endif %} + a_ptrs += BLOCK_K + b_ptrs += BLOCK_K +{%- endif %} + + offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M) + offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N) +{%- if SCALED %} + scale_a = tl.load( + scale_a_ptr +{%- if A_IS_2D %} + + m_scale_start_offset +{%- else %} + + g * SCALE_A_STRIDE_G +{%- endif %} + + offs_am[:, None], + mask=offs_am[:, None] < m_size, + ) + scale_b = tl.load( + scale_b_ptr +{%- if B_IS_2D %} + + n_scale_start_offset +{%- else %} + + g * SCALE_B_STRIDE_G +{%- endif %} + + offs_bn[None, :], + mask=offs_bn[None, :] < n_size, + ) + c = accumulator.to(tl.float32) * scale_a * scale_b +{%- else %} + c = accumulator.to(tl.float32) +{%- endif %} + +{%- if M_IS_VARYING %} + idx_m = (m_start_offset + offs_am[:, None]) +{%- else %} + idx_m = offs_am[:, None] +{%- endif %} +{%- if N_IS_VARYING %} + idx_n = (n_start_offset + offs_bn[None, :]) +{%- else %} + idx_n = offs_bn[None, :] +{%- endif %} + mask = offs_am[:, None] < m_size and offs_bn[None, :] < n_size +{%- if M_IS_VARYING or N_IS_VARYING %} + {{store_output(("idx_m", "idx_n"), "c", "mask", indent_width=16)}} +{%- else %} + {{store_output(("g", "idx_m", "idx_n"), "c", "mask", indent_width=16)}} +{%- endif %} + tidx += NUM_SMS + + iterated_tiles += num_tiles +""" + + +triton_grouped_mm_template = TritonTemplate( + name="grouped_mm", + grid=persistent_grouped_mm_grid, + source=triton_grouped_mm_source, +) + +triton_scaled_grouped_mm_template = TritonTemplate( + name="scaled_grouped_mm", + grid=persistent_grouped_mm_grid, + 
source=triton_grouped_mm_source, +) + + +def grouped_mm_args( + mat1: TensorBox, + mat2: TensorBox, + offs: Optional[TensorBox], + layout=None, + out_dtype=None, +): + mat1, mat2 = realize_inputs(mat1, mat2) + if offs is not None: + realize_inputs(offs) + mat1_size = mat1.get_size() + mat2_size = mat2.get_size() + + m1dim, m2dim = len(mat1_size), len(mat2_size) + + assert m1dim == 2 or m1dim == 3 + assert m2dim == 2 or m2dim == 3 + + if layout is None: + from torch._inductor.ir import FixedLayout + + if out_dtype is None: + out_dtype = mat1.get_dtype() + + dims = [] + if m1dim == 2: + if m2dim == 2: + assert offs is not None + dims = [offs.get_size()[0], mat1_size[0], mat2_size[1]] + else: + dims = [mat1_size[0], mat2_size[-1]] + else: + if m2dim == 2: + dims = [mat1_size[1], mat2_size[1]] + else: + dims = [mat1_size[0], mat1_size[1], mat2_size[-1]] + layout = FixedLayout( + mat1.get_device(), + out_dtype, + dims, + ) + else: + assert out_dtype is None, "out_dtype is ignored if layout is specified." + + return (mat1_size, mat2_size, layout, mat1, mat2, offs) + + +aten__grouped_mm = ExternKernelChoice( + torch._grouped_mm, + "at::_grouped_mm", + op_overload=aten._grouped_mm, + has_out_variant=False, +) + + +aten__scaled_grouped_mm = ExternKernelChoice( + torch._scaled_grouped_mm, + "at::_scaled_grouped_mm", + op_overload=aten._scaled_grouped_mm, + has_out_variant=False, +) + + +def can_use_triton_kernel( + mat_a: TensorBox, + mat_b: TensorBox, + offs: Optional[TensorBox], + bias: Optional[TensorBox], + scale_result: Optional[TensorBox], +) -> bool: + if not ( + torch.cuda.is_available() + and torch.cuda.get_device_capability() >= (9, 0) + and not torch.version.hip + ): + return False + if not has_triton(): + return False + + # The _grouped_mm()/_scaled_grouped_mm() operator do not support + # bias nor scale_result yet. 
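For intuition about the layout combinations handled by `grouped_mm_args` above, here is a hypothetical eager reference for the 2D-A / 3D-B case, where the offsets are cumulative row ends along A's first dimension (dtype, scaling, and stride requirements are ignored in this sketch):

    import torch

    def grouped_mm_reference(a, b, offs):
        # a: [M_total, K], b: [G, K, N], offs: [G] cumulative row ends of a.
        out = torch.empty(a.shape[0], b.shape[-1], dtype=a.dtype, device=a.device)
        start = 0
        for g, end in enumerate(offs.tolist()):
            out[start:end] = a[start:end] @ b[g]
            start = end
        return out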
+ if bias is not None: + return False + if scale_result is not None: + return False + + if len(mat_a.get_size()) == 2 or len(mat_b.get_size()) == 2: + return offs is not None + else: + return offs is None + + +def create_offsets(x, m1_size, m2_size, offs_size): + m1_is_2d = len(m1_size) == 2 + m2_is_2d = len(m2_size) == 2 + if m1_is_2d: + if m2_is_2d: + k = V.graph.sizevars.size_hint(m1_size[1]) + noffs = V.graph.sizevars.size_hint(offs_size[0]) + step = k / noffs + return torch.linspace( + step, k, noffs, dtype=x.get_dtype(), device=x.get_device() + ) + + else: + m = V.graph.sizevars.size_hint(m1_size[0]) + noffs = V.graph.sizevars.size_hint(offs_size[0]) + step = m / noffs + return torch.linspace( + step, m, noffs, dtype=x.get_dtype(), device=x.get_device() + ) + else: + if m2_is_2d: + n = V.graph.sizevars.size_hint(m2_size[0]) + noffs = V.graph.sizevars.size_hint(offs_size[0]) + step = n / noffs + return torch.linspace( + step, n, noffs, dtype=x.get_dtype(), device=x.get_device() + ) + else: + return None + + +def _tuned_grouped_mm_common( + operator_name: str, + algorithm_name: str, + extern_kernel_choice: ExternKernelChoice, + kernel_template: TritonTemplate, + mat_a: TensorBox, + mat_b: TensorBox, + scale_a: Optional[TensorBox] = None, + scale_b: Optional[TensorBox] = None, + offs: Optional[TensorBox] = None, + bias: Optional[TensorBox] = None, + scale_result: Optional[TensorBox] = None, + out_dtype: Optional[torch.dtype] = None, + use_fast_accum: Optional[bool] = None, + layout: Optional[Layout] = None, +) -> TensorBox: + assert (scale_a is None) == (scale_b is None) + assert scale_result is None or scale_a is not None + + m1_size, m2_size, layout, mat_a, mat_b, offs = grouped_mm_args( + mat_a, mat_b, offs, layout=layout, out_dtype=out_dtype + ) + counters["aten_mm_info"][operator_name] += 1 + log_message = f"Tuned {operator_name}: mat1_shape=%s, mat2_shape=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s" + log.info( + log_message, + m1_size, + m2_size, + mat_a.get_dtype(), + mat_b.get_dtype(), + layout, + ) + + if scale_a is not None and scale_b is not None: + check_supported_striding(mat_a, mat_b) + + # workaround for Inductor not supporting optional tensor input arguments + input_nodes: list[Any] = [mat_a, mat_b] + if scale_a is not None: + input_nodes.append(realize_inputs(scale_a)) + if scale_b is not None: + input_nodes.append(realize_inputs(scale_b)) + if offs is not None: + input_nodes.append(realize_inputs(offs)) + + if use_fast_accum is None: + aten_choice = extern_kernel_choice.bind( + input_nodes, + layout, + out_dtype=out_dtype, + ) + else: + aten_choice = extern_kernel_choice.bind( + input_nodes, + layout, + out_dtype=out_dtype, + use_fast_accum=use_fast_accum, + ) + if use_fast_accum is None: + use_fast_accum = False + + choices: list[ChoiceCaller] = [] + if use_aten_gemm_kernels(): + choices.append(aten_choice) + + _, is_nonzero = _is_static_problem(layout) + + # Checking only for the equality of corresponding dims of + # multiplicands here, relying on meta function checks for + # everything else. 
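The synthetic offsets that `create_offsets` above builds for autotuning inputs are just evenly spaced group end-points along the varying dimension; for example (values made up):

    import torch

    k, n_groups = 128, 4
    print(torch.linspace(k / n_groups, k, n_groups, dtype=torch.int32))
    # tensor([ 32,  64,  96, 128], dtype=torch.int32)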
+ if is_nonzero and can_use_triton_kernel(mat_a, mat_b, offs, bias, scale_result): + scaled = scale_a is not None + if len(m1_size) == 2: + if len(m2_size) == 2: + m, k1 = m1_size + k2, _ = m2_size + g = offs.get_size()[0] + V.graph.sizevars.guard_equals(k1, k2) + a_is_2d, b_is_2d = True, True + else: + g1 = offs.layout.size[0] + m, k1 = m1_size + g2, k2, _ = m2_size + g = V.graph.sizevars.guard_equals(g1, g2) + V.graph.sizevars.guard_equals(k1, k2) + a_is_2d, b_is_2d = True, False + else: + if len(m2_size) == 2: + g1 = offs.layout.size[0] + g2, m, k1 = m1_size + k2, _ = m2_size + g = V.graph.sizevars.guard_equals(g1, g2) + V.graph.sizevars.guard_equals(k1, k2) + a_is_2d, b_is_2d = False, True + else: + g1, m, k1 = m1_size + g2, k2, _ = m2_size + g = V.graph.sizevars.guard_equals(g1, g2) + V.graph.sizevars.guard_equals(k1, k2) + a_is_2d, b_is_2d = False, False + + triton_has_make_tensor_descriptor = hasattr(tl, "make_tensor_descriptor") + triton_has_experimental_make_tensor_descriptor = hasattr( + tl, "_experimental_make_tensor_descriptor" + ) + use_tma_load = ( + triton_has_make_tensor_descriptor + or triton_has_experimental_make_tensor_descriptor + ) + # The make_tensor_descriptor imposes this additional limitation. + use_tma_load = use_tma_load and ( + mat_a.get_stride()[-1] == 1 and mat_b.get_stride()[-2] == 1 + ) + + kwargs = { + "SCALED": scaled, + "A_IS_2D": a_is_2d, + "B_IS_2D": b_is_2d, + "USE_FAST_ACCUM": use_fast_accum, + "NUM_SMS": get_num_sms(), + "USE_TMA_LOAD": use_tma_load, + "USE_EXPERIMENTAL_MAKE_TENSOR_DESCRIPTOR": triton_has_experimental_make_tensor_descriptor, + } + + for config in early_config_prune(g, m, grouped_mm_configs(), kwargs): + kernel_template.maybe_append_choice( + choices, + input_nodes=input_nodes, + layout=layout, + num_stages=config.num_stages, + num_warps=config.num_warps, + **kwargs, + **config.kwargs, + ) + + input_gen_fns = { + 4: lambda x: create_offsets( + x, m1_size, m2_size, offs.get_size() if offs is not None else None + ), + } + return autotune_select_algorithm( + algorithm_name, choices, input_nodes, layout, input_gen_fns=input_gen_fns + ) + + +@register_lowering(aten._grouped_mm.default, type_promotion_kind=None) +def tuned_grouped_mm( + mat_a: TensorBox, + mat_b: TensorBox, + offs: Optional[TensorBox] = None, + bias: Optional[TensorBox] = None, + out_dtype: Optional[torch.dtype] = None, + layout: Optional[Layout] = None, +) -> TensorBox: + """Auto-tuning for _grouped_mm() operator.""" + + return _tuned_grouped_mm_common( + "aten._grouped_mm.default", + "grouped_mm", + aten__grouped_mm, + triton_grouped_mm_template, + mat_a, + mat_b, + None, + None, + offs, + bias, + None, + out_dtype, + None, + layout, + ) + + +@register_lowering(aten._scaled_grouped_mm.default, type_promotion_kind=None) +def tuned_scaled_grouped_mm( + mat_a: TensorBox, + mat_b: TensorBox, + scale_a: TensorBox, + scale_b: TensorBox, + offs: Optional[TensorBox] = None, + bias: Optional[TensorBox] = None, + scale_result: Optional[TensorBox] = None, + out_dtype: Optional[torch.dtype] = None, + use_fast_accum: bool = False, + layout: Optional[Layout] = None, +) -> TensorBox: + """Auto-tuning for _scaled_grouped_mm() operator.""" + + return _tuned_grouped_mm_common( + "aten._scaled_grouped_mm.default", + "scaled_grouped_mm", + aten__scaled_grouped_mm, + triton_scaled_grouped_mm_template, + mat_a, + mat_b, + scale_a, + scale_b, + offs, + bias, + scale_result, + out_dtype, + use_fast_accum, + layout, + ) diff --git 
a/.venv/lib/python3.12/site-packages/torch/_inductor/package/__init__.py b/.venv/lib/python3.12/site-packages/torch/_inductor/package/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..15587401b723581b57f94fdcddbcbc8255f73bfe --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/package/__init__.py @@ -0,0 +1 @@ +from .package import AOTICompiledModel, load_package, package_aoti diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/package/build_package.py b/.venv/lib/python3.12/site-packages/torch/_inductor/package/build_package.py new file mode 100644 index 0000000000000000000000000000000000000000..9205b9ced254275018472108485173eba9479f11 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/package/build_package.py @@ -0,0 +1,15 @@ +build_package_contents = """ +import os +from pathlib import Path + +from torch._inductor.package.package import compile_so + +curr_dir = Path(__file__).parent +aoti_files = [ + os.path.join(root, file) + for root, dirs, files in os.walk(curr_dir) + for file in files +] + +output_so = compile_so(curr_dir, aoti_files, curr_dir) +""" diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/package/package.py b/.venv/lib/python3.12/site-packages/torch/_inductor/package/package.py new file mode 100644 index 0000000000000000000000000000000000000000..726b41d972403746da5260d3a38a819fa3ceff46 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/package/package.py @@ -0,0 +1,138 @@ +import io +import json +import logging +import os +import tempfile +from typing import IO + +import torch +from torch._inductor import config +from torch._inductor.cpp_builder import BuildOptionsBase, CppBuilder +from torch.export.pt2_archive._package import ( + AOTI_FILES, + AOTICompiledModel, + load_pt2, + package_pt2, +) +from torch.types import FileLike + + +log = logging.getLogger(__name__) + + +def compile_so(aoti_dir: str, aoti_files: list[str], so_path: str) -> str: + def get_aoti_file_with_suffix(suffix: str) -> str: + for file in aoti_files: + if file.endswith(suffix): + return file + raise RuntimeError(f"Unable to find file with suffix {suffix}") + + # Compile all the files into a .so + cpp_file = os.path.join(aoti_dir, get_aoti_file_with_suffix(".cpp")) + consts_o = os.path.join(aoti_dir, get_aoti_file_with_suffix(".o")) + + file_name = os.path.splitext(cpp_file)[0] + + # Parse compile flags and build the .o file + with open(file_name + "_compile_flags.json") as f: + compile_flags = json.load(f) + + compile_options = BuildOptionsBase( + **compile_flags, use_relative_path=config.is_fbcode() + ) + object_builder = CppBuilder( + name=file_name, + sources=cpp_file, + BuildOption=compile_options, + ) + output_o = object_builder.get_target_file_path() + object_builder.build() + + # Parse linker flags and build the .so file + with open(file_name + "_linker_flags.json") as f: + linker_flags = json.load(f) + + linker_options = BuildOptionsBase( + **linker_flags, use_relative_path=config.is_fbcode() + ) + so_builder = CppBuilder( + name=os.path.split(so_path)[-1], + sources=[output_o, consts_o], + BuildOption=linker_options, + output_dir=so_path, + ) + output_so = so_builder.get_target_file_path() + so_builder.build() + + # mmapped weights + serialized_weights_filename = file_name + "_serialized_weights.bin" + if serialized_weights_filename in aoti_files: + with open(serialized_weights_filename, "rb") as f_weights: + serialized_weights = f_weights.read() + + with open(output_so, "a+b") as 
f_so: + so_size = f_so.tell() + # Page align the weights + f_so.write(b" " * (16384 - so_size % 16384)) + f_so.write(serialized_weights) + + return output_so + + +def package_aoti( + archive_file: FileLike, + aoti_files: AOTI_FILES, +) -> FileLike: + """ + Saves the AOTInductor generated files to the PT2Archive format. + + Args: + archive_file: The file name to save the package to. + aoti_files: This can either be a singular path to a directory containing + the AOTInductor files, or a dictionary mapping the model name to the + path to its AOTInductor generated files. + """ + + return package_pt2( + archive_file, + aoti_files=aoti_files, + ) + + +def load_package( + path: FileLike, + model_name: str = "model", + run_single_threaded: bool = False, + num_runners: int = 1, + device_index: int = -1, +) -> AOTICompiledModel: # type: ignore[type-arg] + try: + pt2_contents = load_pt2( + path, + run_single_threaded=run_single_threaded, + num_runners=num_runners, + device_index=device_index, + ) + if model_name not in pt2_contents.aoti_runners: + raise RuntimeError(f"Model {model_name} not found in package") + return pt2_contents.aoti_runners[model_name] + except RuntimeError: + log.warning("Loading outdated pt2 file. Please regenerate your package.") + + if isinstance(path, (io.IOBase, IO)): + with tempfile.NamedTemporaryFile(suffix=".pt2") as f: + # TODO(angelayi): We shouldn't need to do this -- miniz should + # handle reading the buffer. This is just a temporary workaround + path.seek(0) + f.write(path.read()) + log.debug("Writing buffer to tmp file located at %s.", f.name) + loader = torch._C._aoti.AOTIModelPackageLoader( + f.name, model_name, run_single_threaded, num_runners, device_index + ) + return AOTICompiledModel(loader) + + path = os.fspath(path) # AOTIModelPackageLoader expects (str, str) + loader = torch._C._aoti.AOTIModelPackageLoader( + path, model_name, run_single_threaded, num_runners, device_index + ) + return AOTICompiledModel(loader) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__init__.py b/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6ff95faf46e16adb7e391a806d8cd7f22291ead Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/cache_dir_utils.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/cache_dir_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f0bc75753a34a172737c02a8db709ea49c372f4 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/cache_dir_utils.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/hints.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/hints.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31da285c5d618f3b170f8a0f58a1615ee69f776a Binary files /dev/null and 
b/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/hints.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/runtime_utils.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/runtime_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3adcd020cdfc6c00746826bd5abf4b7875b6fdb6 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/__pycache__/runtime_utils.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/autotune_cache.py b/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/autotune_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..01d038aab8e7bc62e4219a0b0aa71dade49e0d0f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/autotune_cache.py @@ -0,0 +1,640 @@ +""" +PyTorch Inductor Autotuning Cache System + +This module implements a caching system for autotuning configurations in PyTorch's Inductor compiler. +It provides mechanisms to store and retrieve optimal kernel configurations both locally and remotely, +which significantly speeds up compilation by reusing previously discovered optimal parameters. + +The caching system includes: +- Local filesystem caching for individual machine reuse +- Remote caching for sharing optimizations across machines +- Bundled caching to efficiently store multiple related configurations +- Cache invalidation based on PyTorch versions and backend changes +- Serialization/deserialization support for worker processes + +Key components: +- AutotuneCache: Main class for managing cache access and storage +- AutotuneCacheBundler: Bundles multiple cache entries for efficient storage +- LocalAutotuneCache: Handles filesystem-based caching +- _LocalAutotuneCacheBackend: Low-level file operations for cache storage +- AutotuneCacheArtifact: Integration with PyTorch's artifact system + +This caching system is critical for performance as it eliminates the need to re-run +expensive autotuning operations when the same kernels are compiled multiple times. +""" + +from __future__ import annotations + +import dataclasses +import hashlib +import logging +import os +import os.path +import re +from typing import Any, Optional, TYPE_CHECKING +from typing_extensions import override + +import torch +from torch._inductor.runtime.runtime_utils import cache_dir +from torch.compiler._cache import ( + CacheArtifact, + CacheArtifactFactory, + CacheArtifactManager, +) +from torch.utils._triton import has_triton + +from ..remote_cache import ( + create_cache, + JsonDataTy, + RemoteCache, + RemoteCacheBackend, + RemoteCacheJsonSerde, +) +from .triton_compat import Config, HAS_WARP_SPEC + + +if TYPE_CHECKING: + from ..remote_cache import Sample + +log = logging.getLogger(__name__) + + +_InductorMetaTy = dict[str, object] + + +def inductor_meta_from_config() -> _InductorMetaTy: + from torch._inductor import config + + backend_hash = None + if has_triton(): + try: + backend_hash = torch.utils._triton.triton_hash_with_backend() + except RuntimeError: + # This can get the error: + # RuntimeError: 0 active drivers ([]). There should only be one. 
+ pass + + is_hip = None + if torch.version.hip is not None: + is_hip = True + + return { + "autotune_local_cache": config.autotune_local_cache, + "autotune_remote_cache": config.autotune_remote_cache, + "backend_hash": backend_hash, + "bundled_autotune_remote_cache": config.bundled_autotune_remote_cache, + "coordinate_descent_tuning": config.coordinate_descent_tuning, + "is_fbcode": config.is_fbcode(), + "is_hip": is_hip, + } + + +@CacheArtifactFactory.register +class AutotuneCacheArtifact(CacheArtifact): + @override + def populate_cache(self) -> None: + autotune_cache = _LocalAutotuneCacheBackend() + key = os.path.join(cache_dir(), self.key) + autotune_cache._put(key, self.content) + + @override + @staticmethod + def type() -> str: + return "autotune" + + @override + @staticmethod + def encode(content: JsonDataTy) -> bytes: + assert not isinstance(content, bytes) + serde = RemoteCacheJsonSerde() + content_bytes = serde.encode(content) + assert isinstance(content_bytes, bytes) + return content_bytes + + +@dataclasses.dataclass +class AutotuneCache: + configs_hash: str + local_cache: Optional[tuple[RemoteCache[JsonDataTy], str]] = None + remote_cache: Optional[tuple[RemoteCache[JsonDataTy], str]] = None + + # Create a AutotuneCache. Returns None if none of the caches can be used. + @staticmethod + def create( + inductor_meta: _InductorMetaTy, filename: str, configs_hash: str + ) -> Optional[AutotuneCache]: + cache = AutotuneCache(configs_hash) + key = AutotuneCache._prepare_key(filename) + cache._setup_local_cache(inductor_meta, os.path.dirname(filename), key) + cache._setup_remote_autotune_cache(inductor_meta, key) + if cache.local_cache or cache.remote_cache: + return cache + else: + return None + + @staticmethod + def _prepare_key(filename: str) -> str: + from torch.compiler import config as cconfig + + # base of filename is already sha256 hash the source contents + key = f"{os.path.basename(filename)}:{cconfig.cache_key_tag}" + return hashlib.sha256(key.encode("utf-8")).hexdigest() + + # Read the best config options from the most local cache and return it. + def _read(self) -> Optional[dict[str, JsonDataTy]]: + if local_cache := self.local_cache: + cache, key = local_cache + if best_config := cache.get(key): + if isinstance(best_config, dict): + return best_config + + if remote_cache := self.remote_cache: + cache, key = remote_cache + if best_config := cache.get(key): + if isinstance(best_config, dict): + return best_config + + return None + + # Read the best config options from the most local cache and figure out + # which `configs` represents that option. + def read_best( + self, inductor_meta: _InductorMetaTy, configs: list[Config] + ) -> Optional[Config]: + if best := self._read(): + return _load_cached_autotuning( + best, self.configs_hash, configs, inductor_meta + ) + return None + + # Set up local filesystem caching information + def _setup_local_cache( + self, inductor_meta: _InductorMetaTy, dirname: str, cache_key: str + ) -> None: + if not inductor_meta.get("autotune_local_cache", True): + return + + from ..codecache import torch_key + + """ + [Note: torch_key in autotune cache key] + Include torch_key() in the cache key so that different versions + of torch result in cache invalidation. This is important in case + of changes to the best_config format or other code changes that + are not backward compatible w.r.t. the cache. 
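The knobs consumed by `inductor_meta_from_config()` earlier in this file map onto `torch._inductor.config`; a hedged sketch of how they might be set before compilation (defaults and exact semantics can differ across PyTorch versions):

    import torch._inductor.config as inductor_config

    inductor_config.autotune_local_cache = True          # per-machine filesystem cache
    inductor_config.autotune_remote_cache = None         # None lets inductor decide
    inductor_config.bundled_autotune_remote_cache = None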
+ """ + hasher = hashlib.sha256() + hasher.update(cache_key.encode("utf-8")) + hasher.update(torch_key()) + updated_cache_key = hasher.hexdigest() + + cache_filename = f"{dirname}/{updated_cache_key}.best_config" + local_cache = LocalAutotuneCache() + self.local_cache = (local_cache, cache_filename) + + # Set up remote caching information + def _setup_remote_autotune_cache( + self, inductor_meta: _InductorMetaTy, cache_key: str + ) -> None: + if not _should_use_remote_autotune_cache(inductor_meta): + return + + if (backend_hash := inductor_meta.get("backend_hash", None)) is None: + log.debug( + "backend_hash is not passed on the inductor_meta, unable to use autotune remote cache" + ) + return + assert isinstance(backend_hash, str) + + from ..codecache import torch_key + + is_fbcode = bool(inductor_meta.get("is_fbcode", False)) + + salt = "autotune-best-config-v2" + # re: torch_key - see [Note: torch_key in autotune cache key] + key = torch_key().hex() + backend_hash + self.configs_hash + salt + key = hashlib.sha256(key.encode("utf-8")).hexdigest() + + remote_cache = create_cache( + key, + is_fbcode, + "FbRemoteAutotuneCache", + "RemoteAutotuneCache", + ) + if not remote_cache: + return + + # Save the args passed to create_cache + # in case AutotuneCache needs to be pickled + self.remote_cache_full_key = key + self.is_fbcode = is_fbcode + self.remote_cache = (remote_cache, cache_key) + + # The AutotuneCache may be serialized/deserialized if we're using + # AsyncCompile worker processes to run triton compilation. + # This is because AutotuneCache instances are created on the worker + # process, but we need to run AutotuneCache.save on the parent process + # when actually doing autotuning. + def __getstate__(self) -> dict[str, Any]: + # The remote cache handles themselves may not be serializable + # So clear it and reconstruct it on setstate + remote_cache = getattr(self, "remote_cache", None) + return { + **self.__dict__, + # Save the cache_key portion + "remote_cache": remote_cache and remote_cache[1], + } + + def __setstate__(self, state: dict[str, Any]) -> None: + # Reconstruct the remote cache on the parent class + self.__dict__.update(state) + if self.remote_cache is not None: + assert isinstance(self.remote_cache, str) + assert hasattr(self, "remote_cache_full_key") + assert hasattr(self, "is_fbcode") + cache_key = self.remote_cache + remote_cache = create_cache( + self.remote_cache_full_key, + self.is_fbcode, + "FbRemoteAutotuneCache", + "RemoteAutotuneCache", + ) + if remote_cache is not None: + self.remote_cache = (remote_cache, cache_key) + else: + log.warning("Warning, failed to recreate remote cache after pickling") + self.remote_cache = None + + # Save the config in the caches + def save( + self, + config: Config, + time_taken_ns: int, + found_by_coordesc: bool = False, + triton_cache_hash: Optional[str] = None, + ) -> None: + data = { + **config.kwargs, + "num_warps": config.num_warps, + "num_stages": config.num_stages, + "configs_hash": self.configs_hash, + "found_by_coordesc": found_by_coordesc, + "time_taken_ms": time_taken_ns // 1000000, # Convert from NS to MS + "triton_cache_hash": triton_cache_hash, + } + if HAS_WARP_SPEC: + data.update( + { + "num_consumer_groups": getattr(config, "num_consumer_groups", 0), + "num_buffers_warp_spec": getattr( + config, "num_buffers_warp_spec", 0 + ), + } + ) + + if local_cache := self.local_cache: + cache, key = local_cache + cache.put(key, data) + AutotuneCacheBundler.put(key, data) + autotune_artifact_key = 
os.path.join(*key.split(os.sep)[-2:]) + CacheArtifactManager.record_artifact( + AutotuneCacheArtifact.type(), autotune_artifact_key, data + ) + + if log.isEnabledFor(logging.DEBUG): + type_str = "coordesc" if found_by_coordesc else "heuristic" + log.debug("Save %s tuning result to %s", type_str, key) + + if remote_cache := self.remote_cache: + cache, key = remote_cache + cache.put(key, data) + + +class _AutotuneCacheBundlerImpl: + """ + Caches a set of LocalAutotuneCacheBackend entries together in a single + cache. + """ + + _key: str + _cache: RemoteCache[JsonDataTy] + + # All known entries from LocalAutotuneCache.put() + _entries: dict[str, JsonDataTy] + + def end_compile(self) -> None: + # TODO: Do we need to compute time_taken_ms and encode that somehow? + if self._entries: + self._cache.put(self._key, self._entries) + + def put(self, basename: str, data: JsonDataTy) -> None: + # Do we need to worry about duplicates? We only have a single local fs + # entry - so probably not. + self._entries[basename] = data + + def __init__(self, key: str, cache: RemoteCache[JsonDataTy]) -> None: + self._key = key + self._cache = cache + self._entries = {} + + def sync(self) -> None: + # We don't currently use this - but we could async load starting at + # `begin_compile` and wait for the load to be finished here. + pass + + @classmethod + def _should_use_bundled_autotune_remote_cache( + cls, inductor_meta: _InductorMetaTy + ) -> bool: + # The bundled autotune cache is only available if you've also got local + # caching enabled (because we feed the bundled data to the local cache). + if not inductor_meta.get("autotune_local_cache", True): + return False + + # Check if the we're enabled via config + if ( + bundled_autotune_remote_cache := inductor_meta.get( + "bundled_autotune_remote_cache" + ) + ) is not None: + return bool(bundled_autotune_remote_cache) + + if not cls._get_is_fbcode(inductor_meta): + return False + if torch._utils_internal.is_fb_unit_test(): + return False + if inductor_meta.get("is_hip"): + return False + + try: + from torch._inductor.fb.remote_cache import REMOTE_CACHE_VERSION + except ModuleNotFoundError: + return False + + jk = torch._utils_internal.justknobs_getval_int( + "pytorch/remote_cache:bundled_autotune_remote_cache_version" + ) + return REMOTE_CACHE_VERSION >= jk + + def _load_cache(self) -> bool: + from torch._inductor import codecache + + # The single key is defined on construction of the cache. + entries = self._cache.get(self._key) + if entries is None or not isinstance(entries, dict): + # We couldn't load the cache - so mark _entries as non-None so we + # store local cache values. + return False + + # Go through the entries we got from the cache and save them locally. 
+ time_saved_ns = 0 + for basename, data in entries.items(): + # Reconstruct the final filename (see put()) + root, ext = _splitext_nodot(basename) + _, _, filename = codecache.get_path(root, ext) + if isinstance(data, dict) and (tsns := data.get("time_saved_ns")): + time_saved_ns += int(tsns) # type: ignore[arg-type] + local_cache = LocalAutotuneCache() + local_cache.put(filename, data) + + codecache.add_ephemeral_timeout_increase_for_distributed(time_saved_ns) + + return True + + @staticmethod + def _get_is_fbcode(inductor_meta: _InductorMetaTy) -> bool: + return bool(inductor_meta.get("is_fbcode", False)) + + @staticmethod + def _get_backend_hash(inductor_meta: _InductorMetaTy) -> str: + backend_hash = inductor_meta["backend_hash"] + assert isinstance(backend_hash, str) + return backend_hash + + +class AutotuneCacheBundler: + _bundler: Optional[_AutotuneCacheBundlerImpl] = None + + def __init__(self) -> None: + pass + + # Call this before we start any autotune computation for an inductor python + # file. On a cache hit it copies the individual results into the local + # autotune caches. + @classmethod + def begin_compile( + cls, + inductor_meta: _InductorMetaTy, + *, + code: Optional[str] = None, + code_hash: Optional[str] = None, + ) -> None: + assert cls._bundler is None + + if code is not None: + assert code_hash is None, "Cannot specify both code and code_hash" + code_hash = _comment_stripped_hash(code) + assert code_hash is not None + + if not _AutotuneCacheBundlerImpl._should_use_bundled_autotune_remote_cache( + inductor_meta + ): + return + + cache = create_cache( + "bundled-autotune-v1", + _AutotuneCacheBundlerImpl._get_is_fbcode(inductor_meta), + "FbRemoteBundledAutotuneCache", + "RemoteBundledAutotuneCache", + ) + if not cache: + return + + # We're starting a compilation phase. We have a cache key for the code + # we're compiling. We'll get the individual autotune bundles later (via + # self.put()). For now create the AutotuneCacheBundler and try to load + # from the cache. + + salt = "bundled-autotune-best-configs-v1" + backend_hash = _AutotuneCacheBundlerImpl._get_backend_hash(inductor_meta) + # TODO: The autotune cache includes configs_hash in the key. The problem + # is that the configs_hash includes info from the individual pointwise() + # calls (size_hints, for example) which we can't know yet. I *think* + # that info is basically present in the `code_hash` (since it's a + # parameter to the pointwise decorator) - but is there other info we + # need to include from inductor_meta? + key = code_hash + backend_hash + salt + key = hashlib.sha256(key.encode("utf-8")).hexdigest() + + bundler = _AutotuneCacheBundlerImpl(key, cache) + if not bundler._load_cache(): + # We couldn't load from the cache - so save the data so we can store + # the saved autotunes. + cls._bundler = bundler + + # If we get a cache hit don't bother saving any of the individual + # autotune results. + + # Call this after all individual autotune results are finished for a + # inductor python file. If we gathered any individual results then we bundle + # those and put it into the cache. 
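The comments above describe the bundler's lifecycle; a rough sketch of the call sequence follows (these calls are made by inductor's compile pipeline rather than user code, and `triton_source` is a hypothetical stand-in for the generated kernel file's contents):

    from torch._inductor.runtime.autotune_cache import (
        AutotuneCacheBundler,
        inductor_meta_from_config,
    )

    triton_source = "<generated triton kernel source>"  # placeholder
    AutotuneCacheBundler.begin_compile(inductor_meta_from_config(), code=triton_source)
    # ... per-kernel autotuning runs; each AutotuneCache.save() feeds the bundler
    #     via AutotuneCacheBundler.put(filename, best_config_data) ...
    AutotuneCacheBundler.end_compile()  # upload the gathered entries in one put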
+ @classmethod + def end_compile(cls) -> None: + if bundler := cls._bundler: + cls._bundler = None + bundler.end_compile() + + @classmethod + def sync(cls) -> None: + if bundler := cls._bundler: + bundler.sync() + + @classmethod + def put(cls, filename: str, data: JsonDataTy) -> None: + if bundler := cls._bundler: + # The filename comes in as something like + # "/tmp/tmp{random}/{aa}/{basename}.py" (where aa is + # basename[1:3]). Strip it down and make sure that it looks like a path + # we could reconstruct (because it's possible for the caller to + # customize the path). + basename = os.path.basename(filename) + + # TODO: check cache_dir() vs filename, then strip dirname + bundler.put(basename, data) + + +# Remove the comments from the code (which include things like run ids and file +# paths) and then hash the result. +def _comment_stripped_hash(code: str) -> str: + code = re.sub(r"#.*$", "", code, count=0, flags=re.MULTILINE) + return torch._inductor.codecache.code_hash(code) + + +def _should_use_remote_autotune_cache(inductor_meta: _InductorMetaTy) -> bool: + if (config := inductor_meta.get("autotune_remote_cache")) is not None: + return bool(config) + if not inductor_meta.get("is_fbcode"): + return False + if torch._utils_internal.is_fb_unit_test(): + return False + if inductor_meta.get("is_hip"): + return False + + try: + from torch._inductor.fb.remote_cache import REMOTE_CACHE_VERSION + except ModuleNotFoundError: + return False + + return REMOTE_CACHE_VERSION >= torch._utils_internal.justknobs_getval_int( + "pytorch/remote_cache:autotune_memcache_version" + ) + + +def _load_cached_autotuning( + best_config: dict[str, JsonDataTy], + configs_hash: str, + configs: list[Config], + inductor_meta: _InductorMetaTy, +) -> Optional[Config]: + if best_config is None: + return None + if best_config.pop("configs_hash", None) != configs_hash: + return None + + # Remove time taken for comparison + best_config.pop("time_taken_ms", None) + + best_config.pop("triton_cache_hash", None) + + if inductor_meta.get("coordinate_descent_tuning") and best_config.pop( + "found_by_coordesc", False + ): + num_warps = best_config.pop("num_warps") + num_stages = best_config.pop("num_stages") + + # Extract common arguments + config_args = { + "num_warps": num_warps, + "num_stages": num_stages, + } + + if HAS_WARP_SPEC: + config_args.update( + { + "num_consumer_groups": best_config.pop("num_consumer_groups", 0), + "num_buffers_warp_spec": best_config.pop( + "num_buffers_warp_spec", 0 + ), + } + ) + + # Create the triton_config with the appropriate arguments + triton_config = Config(best_config, **config_args) + triton_config.found_by_coordesc = True + return triton_config + + matching_configs = [ + cfg + for cfg in configs + if all(val == best_config.get(key) for key, val in cfg.kwargs.items()) + and cfg.num_warps == best_config.get("num_warps") + and cfg.num_stages == best_config.get("num_stages") + ] + if len(matching_configs) != 1: + return None + + return matching_configs[0] + + +class _LocalAutotuneCacheBackend(RemoteCacheBackend[bytes]): + @override + def _get(self, key: str) -> Optional[bytes]: + try: + with open(key, "rb") as fd: + return fd.read() + except FileNotFoundError: + return None + + @override + def _put(self, key: str, data: bytes) -> None: + os.makedirs(os.path.dirname(key), exist_ok=True) + from torch._inductor import codecache + + codecache.write_atomic(key, data) + + +class LocalAutotuneCache(RemoteCache[JsonDataTy]): + def __init__(self) -> None: + backend = _LocalAutotuneCacheBackend() + 
serde = RemoteCacheJsonSerde() + super().__init__(backend, serde) + + @override + def _get(self, key: str, sample: Optional[Sample]) -> Optional[JsonDataTy]: + AutotuneCacheBundler.sync() + result = super()._get(key, sample) + if result is not None: + assert isinstance(result, dict) + # What? Why are we doing a put() here? Imagine we have a new model + # that reuses some existing kernels that have already been + # compiled. If we didn't do a `put` here (on cache hit) then the new + # model would only bundle *newly* compiled kernels, not existing + # kernels that were already compiled and cached. + AutotuneCacheBundler.put(key, result) + autotune_artifact_key = os.path.join(*key.split(os.sep)[-2:]) + CacheArtifactManager.record_artifact( + AutotuneCacheArtifact.type(), autotune_artifact_key, result + ) + return result + + @override + def _put(self, key: str, value: JsonDataTy, sample: Optional[Sample]) -> None: + AutotuneCacheBundler.put(key, value) + super()._put(key, value, sample) + + +def _splitext_nodot(basename: str) -> tuple[str, str]: + root, ext = os.path.splitext(basename) + if ext: + ext = ext[1:] + return root, ext diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/benchmarking.py b/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/benchmarking.py new file mode 100644 index 0000000000000000000000000000000000000000..5c9cc60bef87a0f3fa908e45442d0d647f0d78d5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/benchmarking.py @@ -0,0 +1,290 @@ +import inspect +import time +from functools import cached_property, wraps +from itertools import chain +from statistics import median +from typing import Any, Callable +from typing_extensions import Concatenate, ParamSpec, Self, TypeVar + +import torch +from torch._dynamo.utils import counters, dynamo_timed +from torch._inductor.config import use_experimental_benchmarker + + +logger = torch._logging.getArtifactLogger(__name__, "benchmarking") +use_experimental_benchmarker = ( + use_experimental_benchmarker and torch.cuda.is_available() +) + + +MILLISECONDS_PER_SECOND = 1000 + +P = ParamSpec("P") +T = TypeVar("T") + + +def time_and_count( + fn: Callable[Concatenate[Any, P], T], +) -> Callable[Concatenate[Any, P], T]: + """Wraps `fn` with `dynamo_timed` context, and increments the appropriate dynamo + counters. It is expected that `fn` is a method of `Benchmarker` or one of its + subclasses; typing limitations prevent us from declaring this directly. + """ + + @wraps(fn) + def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> T: + fn_qual_name = f"{self.__class__.__name__}.{fn.__name__}" + counters["inductor"][f"benchmarking.{fn_qual_name}"] += 1 + with dynamo_timed(fn_qual_name, log_pt2_compile_event=False): + return fn(self, *args, **kwargs) + + return wrapper + + +class Benchmarker: + def __init__(self: Self) -> None: + pass + + @time_and_count + def benchmark( + self: Self, + fn: Callable[..., Any], + fn_args: tuple[Any, ...], + fn_kwargs: dict[str, Any], + **kwargs: Any, + ) -> float: + """Benchmark `fn(*fn_args, *fn_kwargs)` and return the runtime, in milliseconds (the + actual runtime calculation is dictated by the benchmarking implementation, but may be + one of [mean, median, minimum, etc.]). Functions as a convenience wrapper around + device-specific implementations, like `benchmark_cpu` and `benchmark_gpu`. 
Raises + `ValueError(...)` if we can't safely infer the device type of `fn`; for example, + if multiple device types are found in `fn_args` and `fn_kwargs`, or if no device + types are found. + + Arguments: + - fn: The function to benchmark. + - fn_args: The function's arguments. + - fn_kwargs: The function's kwargs. + + Keyword Arguments: + - **kwargs: The benchmarking implementation's kwargs. + + Returns: + - The runtime of `fn(*fn_args, **fn_kwargs)`, in milliseconds. + """ + inferred_device = None + for arg_or_kwarg in chain(fn_args, fn_kwargs.values()): + if not isinstance(arg_or_kwarg, torch.Tensor): + continue + if inferred_device is None: + inferred_device = arg_or_kwarg.device + elif arg_or_kwarg.device != inferred_device: + raise ValueError( + "Can't safely infer the device type of `fn` with multiple device types in `fn_args` and `fn_kwargs`!" + ) + if inferred_device is None: + raise ValueError( + "Can't safely infer the device type of `fn` with no device types in `fn_args` or `fn_kwargs`! You should be calling `.benchmark_cpu` or `.benchmark_gpu` directly." # noqa: B950 + ) + _callable = lambda: fn(*fn_args, **fn_kwargs) # noqa: E731 + if inferred_device == torch.device("cpu"): + return self.benchmark_cpu(_callable, **kwargs) + # TODO(nmacchioni): For non-CPU functions we default to using the GPU-specific benchmarking + # implementation which was written specifically with CUDA devices in mind, we may want to + # explore alternate implementations for other device types. + return self.benchmark_gpu(_callable, **kwargs) + + @time_and_count + def benchmark_cpu( + self: Self, _callable: Callable[[], Any], warmup: int = 20, rep: int = 100 + ) -> float: + """Benchmark the CPU callable, `_callable`, and return the median runtime, + in milliseconds. + + Arguments: + - _callable: The CPU callable to benchmark. + + Keyword Arguments: + - warmup: Optionally, the duration, in milliseconds, to run `_callable` + before benchmarking starts. + - rep: Optionally, the duration, in milliseconds, to run `_callable` + during benchmarking. + + Returns: + - The median runtime of `_callable`, in milliseconds. + """ + + def run_for(ms: int) -> list[float]: + timings = [] + run_start_t = time.perf_counter() + while True: + start_t = time.perf_counter() + _callable() + end_t = time.perf_counter() + timings.append((end_t - start_t) * MILLISECONDS_PER_SECOND) + if ((end_t - run_start_t) * MILLISECONDS_PER_SECOND) > ms: + break + return timings + + run_for(warmup) + return median(run_for(rep)) + + @time_and_count + def benchmark_gpu(self: Self, *args: Any, **kwargs: Any) -> float: + raise NotImplementedError + + +class TritonBenchmarker(Benchmarker): + @cached_property + def triton_do_bench(self: Self) -> Callable[..., Any]: + """Lazily import Triton's `do_bench`.""" + try: + from triton.testing import do_bench + except ImportError as e: + raise NotImplementedError("requires Triton") from e + return do_bench + + @time_and_count + def benchmark_gpu(self: Self, _callable: Callable[[], Any], **kwargs: Any) -> float: + """Benchmark the GPU callable, `_callable`, and return the runtime, in milliseconds. + + Arguments: + - _callable: The GPU callable to benchmark. + + Keyword Arguments: + - quantiles: Optionally, a tuple of floats denoting the requested quantiles. + - return_mode: Optionally, the requested return mode. Currently, Triton's + `do_bench` supports min, max, mean, and median return modes. + - **kwargs: Additional kwargs passed to Triton's `do_bench`. 
+ + Returns: + - The runtime of `callable`, in milliseconds. If `kwargs["quantiles"]` is specified, + this is the first requested quantile. Else, if `kwargs["return_mode"]` is specified, + this is the requested return mode. Otherwise, this is the median. + """ + do_bench_params = inspect.signature(self.triton_do_bench).parameters + for kwarg in list(kwargs.keys()): + if kwarg not in do_bench_params: + del kwargs[kwarg] + if "quantiles" in kwargs: + return self.triton_do_bench(_callable, **kwargs)[0] + elif "return_mode" in kwargs: + return self.triton_do_bench(_callable, **kwargs) + return self.triton_do_bench(_callable, **kwargs, return_mode="median") + + +class InductorBenchmarker(TritonBenchmarker): + @cached_property + def L2_cache_size(self: Self) -> int: + """Get the L2 cache size, in bytes, of the current device.""" + device = torch.cuda.current_device() + props = torch.cuda.get_device_properties(device) + return props.L2_cache_size + + def get_event_pairs( + self: Self, iters: int + ) -> list[tuple[torch.cuda.Event, torch.cuda.Event]]: + """Get `iters` pairs of CUDA events.""" + return [ + ( + torch.cuda.Event(enable_timing=True), + torch.cuda.Event(enable_timing=True), + ) + for _ in range(iters) + ] + + def get_event_pairs_min_timing( + self: Self, event_pairs: list[tuple[torch.cuda.Event, torch.cuda.Event]] + ) -> float: + """Get the minimum timing, in milliseconds, for a group of CUDA event pairs.""" + return min( + [ + start_event.elapsed_time(end_event) + for start_event, end_event in event_pairs + ] + ) + + @time_and_count + def benchmark_gpu( + self: Self, + _callable: Callable[[], Any], + estimation_iters: int = 5, + memory_warmup_iters: int = 100, + benchmark_iters: int = 100, + max_benchmark_duration: int = 25, + **kwargs: Any, + ) -> float: + """Benchmark a GPU callable using a custom benchmarking implementation. + + Arguments: + - _callable: The callable to benchmark. + + Keyword Arguments: + - estimation_iters: Optionally, the number of iterations to run `_callable` + during runtime estimation. + - memory_warmup_iters: Optionally, the number of iterations to flush the L2 + cache before starting benchmarking. + - benchmark_iters: Optionally, the number of iterations to run `_callable` + during the benchmarking. + - max_benchmark_duration: Optionally, the maximum duration of the benchmarking, + in milliseconds. An estimated duration is calculated based on the values + of `memory_warmup_iters` and `benchmark_iters`, along with the estimated + runtime of `_callable` and various other factors, and we then shrink + `benchmark_iters` to fit in the allotted maximum duration. + - **kwargs: Additional kwargs that may be passed to the fallback. + + Returns: + - The minimum runtime of `_callable`, in milliseconds. 
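A hedged usage sketch of the dispatching `benchmark` wrapper documented earlier in this file, via the `benchmarker` instance created at the bottom of the module (the workload and sizes are made up; a CPU tensor routes to `benchmark_cpu`, a CUDA tensor to `benchmark_gpu`):

    import torch
    from torch._inductor.runtime.benchmarking import benchmarker

    x = torch.randn(1024, 1024)
    ms = benchmarker.benchmark(torch.mm, (x, x), {})
    print(f"torch.mm took ~{ms:.3f} ms")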
+ """ + # we don't want any outside errors propagating into benchmarking + torch.cuda.synchronize() + + # warmup `_callable` (and catches any failures in the process) + _callable() + torch.cuda.synchronize() + + # see https://github.com/triton-lang/triton/pull/840 for why `dtype=torch.int` + buffer = torch.empty(self.L2_cache_size // 4, dtype=torch.int, device="cuda") + buffer.zero_() + + # estimate the runtime of `_callable` + event_pairs = self.get_event_pairs(estimation_iters) + for start_event, end_event in event_pairs: + buffer.zero_() + start_event.record() + _callable() + end_event.record() + torch.cuda.synchronize() + estimated_timing = self.get_event_pairs_min_timing(event_pairs) + + # adjust `benchmark_iters` to fit in the maximum benchmarking duration + benchmark_iters = max( + min(benchmark_iters, int(max_benchmark_duration // estimated_timing)), 1 + ) + + # do the memory warmup + for _ in range(memory_warmup_iters): + buffer.zero_() + + # benchmark `_callable` + event_pairs = self.get_event_pairs(benchmark_iters) + for start_event, end_event in event_pairs: + buffer.zero_() + start_event.record() + _callable() + end_event.record() + torch.cuda.synchronize() + benchmarked_timing = self.get_event_pairs_min_timing(event_pairs) + + # explicitly delete the buffer, sometimes helps memory + # footprint metrics in OSS Inductor performance benchmarks + del buffer + + # return the minimum of `estimated_timing` and `benchmarked_timing`, + # we just want the minimum timing overall so we might as well check both + return min(estimated_timing, benchmarked_timing) + + +benchmarker = ( + InductorBenchmarker() if use_experimental_benchmarker else TritonBenchmarker() +) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/cache_dir_utils.py b/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/cache_dir_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..34b84a68f6300c1709593e303ff2a07e1f50bc46 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/cache_dir_utils.py @@ -0,0 +1,54 @@ +import getpass +import os +import re +import tempfile +from collections.abc import Generator +from contextlib import contextmanager + +from torch._environment import is_fbcode + + +# Factoring out to file without torch dependencies + + +def cache_dir() -> str: + cache_dir = os.environ.get("TORCHINDUCTOR_CACHE_DIR") + if cache_dir is None: + os.environ["TORCHINDUCTOR_CACHE_DIR"] = cache_dir = default_cache_dir() + os.makedirs(cache_dir, exist_ok=True) + return cache_dir + + +def default_cache_dir() -> str: + sanitized_username = re.sub(r'[\\/:*?"<>|]', "_", getpass.getuser()) + return os.path.join( + tempfile.gettempdir() if not is_fbcode() else "/var/tmp", + "torchinductor_" + sanitized_username, + ) + + +def triton_cache_dir(device: int) -> str: + if (directory := os.getenv("TRITON_CACHE_DIR")) is not None: + return directory + return os.path.join( + cache_dir(), + "triton", + str(device), + ) + + +@contextmanager +def temporary_cache_dir(directory: str) -> Generator[None, None, None]: + from torch._inductor.utils import clear_caches + + original = os.environ.get("TORCHINDUCTOR_CACHE_DIR") + os.environ["TORCHINDUCTOR_CACHE_DIR"] = directory + try: + clear_caches() + yield + finally: + clear_caches() + if original is None: + del os.environ["TORCHINDUCTOR_CACHE_DIR"] + else: + os.environ["TORCHINDUCTOR_CACHE_DIR"] = original diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/compile_tasks.py 
b/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/compile_tasks.py new file mode 100644 index 0000000000000000000000000000000000000000..67140369faac46f6f2794073599a1eba6f5699c9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/compile_tasks.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +import functools +import linecache +import os +import sys +import time +import warnings +from pathlib import Path +from types import ModuleType +from typing import Any, Callable, TYPE_CHECKING + + +if TYPE_CHECKING: + from torch._inductor.runtime.triton_heuristics import CachingAutotuner + + +def _reload_python_module( + key: str, path: str, set_sys_modules: bool = True +) -> ModuleType: + with open(path) as f: + try: + code = compile(f.read(), path, "exec", dont_inherit=True) + except Exception as e: + raise RuntimeError( + f"Failed to import {path}\n{type(e).__name__}: {e}" + ) from None + mod = ModuleType(f"{__name__}.{key}") + mod.__file__ = path + mod.key = key # type: ignore[attr-defined] + exec(code, mod.__dict__, mod.__dict__) + if set_sys_modules: + sys.modules[mod.__name__] = mod + return mod + + +@functools.cache +def _set_triton_ptxas_path() -> None: + if os.environ.get("TRITON_PTXAS_PATH") is not None: + return + ptxas = Path(__file__).absolute().parents[1] / "bin" / "ptxas" + if not ptxas.exists(): + return + if ptxas.is_file() and os.access(ptxas, os.X_OK): + os.environ["TRITON_PTXAS_PATH"] = str(ptxas) + else: + warnings.warn(f"{ptxas} exists but is not an executable") + + +def _worker_compile_triton( + load_kernel: Callable[[], CachingAutotuner], + extra_env: dict[str, str], + extra_config: dict[str, Any], +) -> tuple[CachingAutotuner, int]: + _set_triton_ptxas_path() + os.environ.update(extra_env) + from torch._inductor import config + + with config.patch(extra_config): + start_ns = time.time_ns() + kernel = load_kernel() + kernel.precompile(warm_cache_only=True) + elapsed_ns = time.time_ns() - start_ns + kernel.prepare_for_pickle() + # We can release this memory in the compile subprocesses: + linecache.clearcache() + return kernel, elapsed_ns // 1000 diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/coordinate_descent_tuner.py b/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/coordinate_descent_tuner.py new file mode 100644 index 0000000000000000000000000000000000000000..413dfaf09d061be5fc00421bec121682430d3896 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/coordinate_descent_tuner.py @@ -0,0 +1,302 @@ +# mypy: allow-untyped-defs +import copy +import itertools +import logging +from typing import Callable, Optional, TYPE_CHECKING + +from .hints import TRITON_MAX_BLOCK +from .runtime_utils import red_text, triton_config_to_hashable + + +if TYPE_CHECKING: + from .triton_compat import triton + + +log = logging.getLogger(__name__) + + +def get_field(config, name): + if name == "num_warps": + return config.num_warps + elif name == "num_stages": + return config.num_stages + elif name == "waves_per_eu": + return config.kwargs.get(name, int(8 // config.num_warps)) + else: + return config.kwargs.get(name, None) + + +def set_field(config, name, value): + if name == "num_warps": + config.num_warps = value + elif name == "num_stages": + config.num_stages = value + else: + config.kwargs[name] = value + + +class CoordescTuner: + """ + The coordinate descent tuner. Tune one field/coordinate at a time. + + TODO will it be necessary to tune multiple fields simultaneously. 
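# A schematic sketch of how this tuner is typically driven (the kernel name, size
# hints, baseline config values, and the placeholder timing are assumptions): the
# caller supplies a function that benchmarks a single triton.Config and returns its
# runtime in milliseconds, plus a baseline config; `autotune` then walks one field
# at a time toward faster configs.
import triton
from torch._inductor.runtime.coordinate_descent_tuner import CoordescTuner

def bench_config(cfg: triton.Config) -> float:
    # A real caller would compile and time the kernel with `cfg`; elided here.
    return 1.0

tuner = CoordescTuner(is_mm=False, name="my_kernel", size_hints={"x": 4096})
baseline = triton.Config({"XBLOCK": 64}, num_warps=4, num_stages=2)
best_config = tuner.autotune(bench_config, baseline)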
+ + + TODO: what if both increasing and decreasing a field can improve perf. + i.e., there are multiple local optima.. + """ + + def __init__( + self, is_mm=False, name="unknown", size_hints=None, inductor_meta=None + ): + self.is_mm = is_mm # we will tune num_stages for mm + self.cached_benchmark_results = {} + self.name = name + self.size_hints = size_hints + self.inductor_meta = inductor_meta or {} + + def get_config_max(self, prefix: str) -> int: + max_block = TRITON_MAX_BLOCK[prefix.upper()] + size_hint = self.size_hints.get(prefix) if self.size_hints is not None else None + return min(max_block, size_hint) if size_hint is not None else max_block + + def get_warpsmax(self): + # Currently, CUDA has a maximum of 1024 threads, so 32 is the max + # number of warps. + return 1024 // 32 + + def cache_benchmark_result(self, config, timing): + self.cached_benchmark_results[triton_config_to_hashable(config)] = timing + + def lookup_in_cache(self, config): + return self.cached_benchmark_results.get(triton_config_to_hashable(config)) + + def call_func(self, func, config): + found = self.lookup_in_cache(config) + if found is not None: + log.debug(" CACHED") + return found + timing = func(config) + self.cache_benchmark_result(config, timing) + return timing + + @property + def tunable_fields(self): + out = [ + "XBLOCK", + "YBLOCK", + "ZBLOCK", + # NOTE: we should not tune R0_BLOCK for persistent reduction. + # We rely on the fact that persistent reduction's triton.Config + # does not have the R0_BLOCK field to guarantee that. + "R0_BLOCK", + "R1_BLOCK", + # the following 3 are for mm + "BLOCK_M", + "BLOCK_N", + "BLOCK_K", + "num_warps", + ] + if self.is_mm: + out.append("num_stages") + if self.inductor_meta.get("is_hip") is True: + out.append("waves_per_eu") + + return out + + def value_too_large(self, name: str, val: int) -> bool: + block_suffix = "BLOCK" + if name.endswith(block_suffix): + prefix = name.strip(block_suffix).lower() + return val > self.get_config_max(prefix) + if name == "num_warps": + return val > self.get_warpsmax() + if name == "waves_per_eu": + return val > 8 + + return False + + def get_neighbour_values(self, name, orig_val, radius=1, include_self=False): + """ + Get neighbour values in 'radius' steps. The original value is not + returned as it's own neighbour. + """ + assert radius >= 1 + + def update(cur_val, inc=True): + if name == "num_stages": + if inc: + return cur_val + 1 + else: + return cur_val - 1 + else: + if inc: + return cur_val * 2 + else: + return cur_val // 2 + + out = [] + # increment loop + cur_val = orig_val + for _ in range(radius): + cur_val = update(cur_val, True) + if self.value_too_large(name, cur_val): + break + out.append(cur_val) + + # decrement loop + cur_val = orig_val + for _ in range(radius): + cur_val = update(cur_val, False) + if cur_val <= 0: + break + out.append(cur_val) + + if include_self: + out.append(orig_val) + return out + + @staticmethod + def has_improvement(baseline, test): + threshold = 0.001 # 0.1% + return test is not None and test < baseline * (1 - threshold) + + def check_all_tuning_directions( + self, + func: Callable[["triton.Config"], float], + best_config, + best_timing, + ): + """ + Check all directions. We only do this once the regular coordinate + descent tuning find no better choices any more. + We only have a few tunable fields, so this should be fine. 
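# Concretely, for a CoordescTuner constructed with size_hints={"x": 4096} (an
# assumed setup), the neighbour enumeration behaves as follows: block-size fields
# step by powers of two, num_stages steps by +/-1, increments are listed before
# decrements, and the original value is excluded unless include_self=True.
#
#   tuner.get_neighbour_values("XBLOCK", 64)             # -> [128, 32]
#   tuner.get_neighbour_values("XBLOCK", 64, radius=2)   # -> [128, 256, 32, 16]
#   tuner.get_neighbour_values("num_stages", 3)          # -> [4, 2]
#   tuner.get_neighbour_values("num_warps", 32)          # -> [16]; 64 exceeds get_warpsmax()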
+ """ + candidate_values_list = [] + effective_fields = [] + for field in self.tunable_fields: + old_value = get_field(best_config, field) + if old_value is None: + continue + candidate_values = self.get_neighbour_values( + field, + old_value, + radius=self.inductor_meta.get("coordinate_descent_search_radius", 1), + include_self=True, + ) + candidate_values_list.append(candidate_values) + effective_fields.append(field) + + choices = itertools.product(*candidate_values_list) + improved = False + for choice in choices: + assert len(choice) == len(effective_fields) + candidate_config = copy.deepcopy(best_config) + for new_val, field in zip(choice, effective_fields): + set_field(candidate_config, field, new_val) + cmp_res, candidate_timing = self.compare_config( + func, candidate_config, best_config, best_timing + ) + if cmp_res: + improved = True + best_config = candidate_config + best_timing = candidate_timing + + return improved, best_config, best_timing + + def compare_config(self, func, candidate_config, best_config, best_timing): + """ + Check if candidate_config is better than best_config. + + Return a tuple of (compare_result, candidate_timing). + compare_result is true iff candidate_config is better. + """ + log.debug("Try config %s", candidate_config) + try: + candidate_timing = self.call_func(func, candidate_config) + except Exception as e: + log.debug("Got exception %s", e) + return False, float("inf") + + if self.has_improvement(best_timing, candidate_timing): + log.debug( + "Tune from %s %f -> %s %f", + best_config, + best_timing, + candidate_config, + candidate_timing, + ) + + return True, candidate_timing + return False, candidate_timing + + def autotune( + self, + func: Callable[["triton.Config"], float], + baseline_config: "triton.Config", + baseline_timing: Optional[float] = None, + ) -> "triton.Config": + if baseline_timing is None: + baseline_timing = self.call_func(func, baseline_config) + + log.debug("= Do coordinate descent tuning for %s =", self.name) + log.debug( + "Baseline Config %s, baseline timing %f", baseline_config, baseline_timing + ) + improved = True + best_config = baseline_config + best_timing = baseline_timing + tunable_fields = self.tunable_fields + + while improved: + improved = False + + for name in tunable_fields: + cur_val = get_field(best_config, name) + # some kernel don't have R0_BLOCK/YBLOCK/ZBLOCK. So cur_val may be None + if cur_val is None: + continue + + # It's possible that candidate_values is empty. + # E.g., if XBLOCK is 1 initially and size_hint for x is also 1. + # We would not try either larger or smaller XBLOCK in this case. + candidate_values = self.get_neighbour_values(name, cur_val) + + for next_val in candidate_values: + candidate_config = copy.deepcopy(best_config) + set_field(candidate_config, name, next_val) + + cmp_res, candidate_timing = self.compare_config( + func, candidate_config, best_config, best_timing + ) + if cmp_res: + improved = True + best_config, best_timing = candidate_config, candidate_timing + + if not improved and self.inductor_meta.get( + "coordinate_descent_check_all_directions" + ): + old_best_timing = best_timing + improved, best_config, best_timing = self.check_all_tuning_directions( + func, best_config, best_timing + ) + + if improved: + msg = red_text( + "Coordinate descend tuning found improvement of %.3fx by looking in all directions." 
+ ) + log.debug( + msg, + old_best_timing / best_timing, + ) + + log.debug( + "Improve from %s %f -> %s %f, %.3fx", + baseline_config, + baseline_timing, + best_config, + best_timing, + baseline_timing / best_timing, + ) + + return best_config diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/halide_helpers.py b/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/halide_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..f4bf70fe9d8db1cb66379df11e025ad84cc0069b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/halide_helpers.py @@ -0,0 +1,118 @@ +# mypy: allow-untyped-defs +try: + import halide as hl # type: ignore[import-untyped, import-not-found] +except ImportError: + hl = None + +PHILOX_N_ROUNDS_DEFAULT = 10 # Default number of rounds for philox + +if hl is not None: + PHILOX_KEY_A_U32 = hl.u32(0x9E3779B9) + PHILOX_KEY_B_U32 = hl.u32(0xBB67AE85) + PHILOX_ROUND_A_U32 = hl.u32(0xD2511F53) + PHILOX_ROUND_B_U32 = hl.u32(0xCD9E8D57) +else: + PHILOX_KEY_A_U32 = None + PHILOX_KEY_B_U32 = None + PHILOX_ROUND_A_U32 = None + PHILOX_ROUND_B_U32 = None + + +def _pair_uniform_to_normal(u1, u2): + """Box-Muller transform""" + u1 = hl.max(hl.f32(1.0e-7), u1) + th = hl.f32(6.283185307179586) * u2 + r = hl.sqrt(hl.f32(-2.0) * hl.log(u1)) + return r * hl.cos(th), r * hl.sin(th) + + +def _uint_to_uniform_float(x): + """ + Numerically stable function to convert a random uint into a random float uniformly sampled in [0, 1). + """ + + # TODO: + # conditions can be simplified + # scale is ((2**23 - 1) / 2**23) * 2**(N_BITS - 1) + # https://github.com/triton-lang/triton/blob/e4a0d93ff1a367c7d4eeebbcd7079ed267e6b06f/python/triton/language/random.py#L116-L132. + assert x.type() == hl.UInt(32) or x.type() == hl.Int(32) + x = hl.cast(hl.Int(32), x) + scale = hl.f64(4.6566127342e-10) + x = hl.select(x < 0, -x - 1, x) + return x * scale + + +def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds): + def umulhi(a, b): + a = hl.cast(hl.UInt(64), a) + b = hl.cast(hl.UInt(64), b) + return hl.cast(hl.UInt(32), ((a * b) >> 32) & hl.u64(0xFFFFFFFF)) + + for _ in range(n_rounds): + _c0, _c2 = c0, c2 + + c0 = umulhi(PHILOX_ROUND_B_U32, _c2) ^ c1 ^ k0 + c2 = umulhi(PHILOX_ROUND_A_U32, _c0) ^ c3 ^ k1 + c1 = PHILOX_ROUND_B_U32 * _c2 + c3 = PHILOX_ROUND_A_U32 * _c0 + # raise key + k0 = k0 + PHILOX_KEY_A_U32 + k1 = k1 + PHILOX_KEY_B_U32 + + return c0, c1, c2, c3 + + +def halide_philox(seed, c0, c1, c2, c3, n_rounds): + seed = hl.cast(hl.UInt(64), seed) + + assert c0.type().bits() == 32 + + seed_hi = hl.cast(hl.UInt(32), (seed >> 32) & hl.u64(0xFFFFFFFF)) + seed_lo = hl.cast(hl.UInt(32), seed & hl.u64(0xFFFFFFFF)) + + return philox_impl(c0, c1, c2, c3, seed_lo, seed_hi, n_rounds) + + +def randint4x(seed, offset, n_rounds): + offset = hl.cast(hl.UInt(32), offset) + _0 = hl.u32(0) + return halide_philox(seed, offset, _0, _0, _0, n_rounds) + + +def rand4x(seed, offset, n_rounds=PHILOX_N_ROUNDS_DEFAULT): + i1, i2, i3, i4 = randint4x(seed, offset, n_rounds) + u1 = _uint_to_uniform_float(i1) + u2 = _uint_to_uniform_float(i2) + u3 = _uint_to_uniform_float(i3) + u4 = _uint_to_uniform_float(i4) + return u1, u2, u3, u4 + + +def randint(seed, offset, n_rounds=PHILOX_N_ROUNDS_DEFAULT): + ret, _, _, _ = randint4x(seed, offset, n_rounds) + return ret + + +def rand(seed, offset, n_rounds=PHILOX_N_ROUNDS_DEFAULT): + source = randint(seed, offset, n_rounds) + return _uint_to_uniform_float(source) + + +def randn(seed, offset): + i1, i2, _, _ = randint4x(seed, offset, 
PHILOX_N_ROUNDS_DEFAULT) + u1 = _uint_to_uniform_float(i1) + u2 = _uint_to_uniform_float(i2) + n1, _ = _pair_uniform_to_normal(u1, u2) + return n1 + + +def randint64(seed, offset, low, high): + r0, r1, _r2, _r3 = randint4x(seed, offset, PHILOX_N_ROUNDS_DEFAULT) + r0 = hl.cast(hl.UInt(64), r0) + r1 = hl.cast(hl.UInt(64), r1) + + result = r0 | (r1 << 32) + size = high - low + result = result % hl.cast(hl.UInt(64), size) + result = hl.cast(hl.Int(64), result) + low + return result diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/hints.py b/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/hints.py new file mode 100644 index 0000000000000000000000000000000000000000..e559eaa1a31d4d718bb0c95faac01f3e18cc2549 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/hints.py @@ -0,0 +1,221 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import collections +import functools +import typing +from enum import auto, Enum +from typing import Optional, Union + +from torch.utils._triton import has_triton_package + + +# The following maximums only apply to runtime autotuning, when using FixedTritonConfig one may see larger values +# NOTE: if these fail asserts submit a PR to increase them +TRITON_MAX_BLOCK = { + "X": 4096, + "Y": 1024, + "Z": 1024, + "R0_": 4096 * 16, # * 16 is multi-kernel only + "R1_": 2048 * 16, # * 16 is multi-kernel only +} +TRITON_MAX_RSPLIT = 64 + + +class ReductionHint(Enum): + INNER = 0 + OUTER = 1 + OUTER_TINY = 2 + DEFAULT = 3 + + +class TileHint(Enum): + SQUARE = 0 + DEFAULT = 1 + + +# Define `AttrsDescriptorWrapper` function with clear conditional handling +if has_triton_package(): + import triton + import triton.backends.compiler + import triton.compiler.compiler + + if hasattr(triton.backends.compiler, "AttrsDescriptor"): + # Triton 3.2.0 - the second implementation + from triton.backends.compiler import AttrsDescriptor + + def AttrsDescriptorWrapper( + divisible_by_16=None, + equal_to_1=None, + ): + # Prepare the arguments for AttrsDescriptor + kwargs = { + "tt.divisibility": divisible_by_16, + "tt.equal_to": equal_to_1, + } + + # Instantiate AttrsDescriptor with the prepared arguments + res = AttrsDescriptor.from_dict( + {"arg_properties": kwargs, "cls": AttrsDescriptor.__name__} + ) + assert res.property_values["tt.divisibility"] == 16 + assert res.property_values["tt.equal_to"] == 1 + return res + + elif hasattr(triton.compiler.compiler, "AttrsDescriptor"): + # Triton 3.0.0 - the original implementation + from triton.compiler.compiler import AttrsDescriptor + + def AttrsDescriptorWrapper( + divisible_by_16=None, + equal_to_1=None, + ): + # Prepare the arguments for AttrsDescriptor + kwargs = { + "divisible_by_16": divisible_by_16, + "equal_to_1": equal_to_1, + } + + # Instantiate AttrsDescriptor with the prepared arguments + return AttrsDescriptor(**kwargs) + + else: + # Triton in 2025: + # note: there's also a range of triton commits not currently supported + # from ~Dec 9, 2024 to Jan 1 2025, in which AttrsDescriptors are still + # used, but the contents are different. 
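# A rough illustration of the descriptor shapes produced for the same call across
# the version branches above and below (the argument indices are assumptions):
#
#   AttrsDescriptorWrapper(divisible_by_16=(0, 1), equal_to_1=(3,))
#     Triton 3.0.x  -> AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=(3,))
#     Triton 3.2.0  -> AttrsDescriptor.from_dict({"arg_properties": {"tt.divisibility": (0, 1), "tt.equal_to": (3,)}, ...})
#     newer Triton  -> {(0,): [["tt.divisibility", 16]], (1,): [["tt.divisibility", 16]]}   # equal_to_1 is dropped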
+ + def AttrsDescriptorWrapper( + divisible_by_16=None, + equal_to_1=None, + ): + return {(x,): [["tt.divisibility", 16]] for x in divisible_by_16} + +else: + # Define a namedtuple as a fallback when AttrsDescriptor is not available + AttrsDescriptorWrapper = collections.namedtuple( # type: ignore[no-redef, name-match] + "AttrsDescriptor", + ["divisible_by_16", "equal_to_1"], + defaults=[(), ()], + ) + + +_NUM_THREADS_PER_WARP = 32 + + +class HeuristicType(Enum): + PERSISTENT_REDUCTION = auto() + POINTWISE = auto() + REDUCTION = auto() + SPLIT_SCAN = auto() + TEMPLATE = auto() + USER_AUTOTUNE = auto() + FIXED = auto() + + +class AutotuneHint(Enum): + ONE_ELEMENT_PER_THREAD = 0 + + # Triton codegen tries to codegen set of AutotuneHints. + # Enum.__repr__ looks like """ + # which isn't valid python. + # Enum.__str__ will just return "AutotuneHint.ELEMENTS_PER_WARP_32". + __repr__ = Enum.__str__ + + +class DeviceProperties(typing.NamedTuple): + """Copy device properties into a data structure not requiring torch to be imported""" + + type: str # type: ignore[assignment] + index: int # type: ignore[assignment] + multi_processor_count: int + cc: int + major: Optional[int] = None + regs_per_multiprocessor: Optional[int] = None + max_threads_per_multi_processor: Optional[int] = None + warp_size: Optional[int] = None + + @classmethod + @functools.cache + def create(cls, device) -> DeviceProperties: + import torch + from torch._dynamo.device_interface import get_interface_for_device + + device_type = device.type + + if torch.version.hip and device_type == "cuda": + device_type = "hip" + + device_interface = get_interface_for_device(device) + props = device_interface.get_device_properties(device) + try: + multi_processor_count = props.multi_processor_count + except AttributeError: + if device_type == "xpu": + multi_processor_count = props.gpu_subslice_count + elif device_type == "mps": + # TODO: Fetch the actual value from ioreg + multi_processor_count = 8 + else: + raise + return cls( + type=device_type, + index=device.index, + multi_processor_count=multi_processor_count, + cc=device_interface.get_compute_capability(device), + major=getattr(props, "major", None), + regs_per_multiprocessor=getattr(props, "regs_per_multiprocessor", None), + max_threads_per_multi_processor=getattr( + props, "max_threads_per_multi_processor", None + ), + warp_size=getattr(props, "warp_size", 32 if device_type != "cpu" else None), + ) + + +class HalideInputSpec(typing.NamedTuple): + ctype: str + name: str + shape: Optional[list[str]] = None + stride: Optional[list[str]] = None + offset: Optional[str] = None + alias_of: Optional[str] = None + + def bindings_type(self) -> str: + if self.ctype in ("at::Half*", "at::BFloat16*"): + return "uint16_t*" # half not defined + return self.ctype + + def halide_type(self) -> str: + if self.ctype == "at::Half*": + return "halide_type_t(halide_type_float, 16)" # half not defined + if self.ctype == "at::BFloat16*": + return "halide_type_t(halide_type_bfloat, 16)" # half not defined + return f"halide_type_of<{self.ctype.replace('*', '')}>()" + + def is_scalar(self) -> bool: + return self.shape is None + + def is_buffer(self) -> bool: + return self.shape is not None + + +class HalideMeta(typing.NamedTuple): + argtypes: list[HalideInputSpec] + target: str + scheduler: Optional[str] = None + scheduler_flags: Optional[dict[str, Union[int, str]]] = None + cuda_device: Optional[int] = None + + def args(self) -> list[str]: + """Command line args to pass to halide generator""" + args = 
[f"target={self.target}"] + if self.scheduler: + args.append(f"autoscheduler={self.scheduler}") + if self.scheduler_flags: + assert self.scheduler + for k, v in self.scheduler_flags.items(): + args.append(f"autoscheduler.{k}={v}") + return args + + def is_cuda(self) -> bool: + return self.cuda_device is not None diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/runtime_utils.py b/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/runtime_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..21cd5987f8f435134d707e0148b4c7b1e75a22c8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/runtime_utils.py @@ -0,0 +1,184 @@ +from __future__ import annotations + +import functools +import operator +from typing import Any, TYPE_CHECKING + +import torch + +# NOTE: other files rely on the imports below +from torch._dynamo import callback as compilation_callback # noqa: F401 +from torch._inductor.runtime.cache_dir_utils import ( # noqa: F401 + cache_dir, + default_cache_dir, + triton_cache_dir, +) + + +if TYPE_CHECKING: + from collections.abc import Hashable + + from .triton_compat import Config + + +def conditional_product(*args: int) -> int: + return functools.reduce(operator.mul, [x for x in args if x]) + + +def ceildiv(number: int, denom: int) -> int: + return -(number // -denom) + + +def is_power_of_2(n: int) -> bool: + """Returns whether n = 2 ** m for some integer m.""" + return n > 0 and n & n - 1 == 0 + + +def next_power_of_2(n: int) -> int: + """Return the smallest power of 2 greater than or equal to n""" + n -= 1 + n |= n >> 1 + n |= n >> 2 + n |= n >> 4 + n |= n >> 8 + n |= n >> 16 + n |= n >> 32 + n += 1 + return n + + +def get_num_bytes(*args: torch.Tensor, num_in_out_args: int = 0) -> int: + """ + Return the total number of bytes the arguments of tensor type takes. + + For in/out args, tensor sizes are counted twice: once for reading and + once for writing. + + The first num_in_out_args arguments are in out tensors. + """ + return sum( + arg.numel() * arg.element_size() * (1 + int(i < num_in_out_args)) + for i, arg in enumerate(args) + if isinstance(arg, torch.Tensor) + ) + + +def triton_config_to_hashable(cfg: Config) -> Hashable: + """ + Convert triton config to a tuple that can uniquely identify it. We can use + the return value as a dictionary key. + """ + items = sorted(cfg.kwargs.items()) + items.append(("num_warps", cfg.num_warps)) + items.append(("num_stages", cfg.num_stages)) + return tuple(items) + + +def validate_triton_config(cfg: Config) -> None: + # [Note: Triton pre_hook in inductor] + # pre-hook is a lambda function, which we don't attempt to serialize. + # right now, if a pre-hook is attached to the config, it will not be saved; + # and then it won't be used when the config is loaded from cache. + # So we assert - if we do get a pre_hook, it might get ignored after caching. 
+ assert getattr(cfg, "pre_hook", None) is None, ( + "triton configs with pre_hooks not supported" + ) + + +def create_bandwidth_info_str( + ms: float, + num_gb: float, + gb_per_s: float, + prefix: str = "", + suffix: str = "", + color: bool = True, +) -> str: + info_str = f"{prefix}{ms:.3f}ms \t{num_gb:.3f} GB \t {gb_per_s:7.2f}GB/s{suffix}" + slow = ms > 0.012 and gb_per_s < 650 + return red_text(info_str) if color and slow else info_str + + +def get_max_y_grid() -> int: + return 65535 + + +try: + import colorama + + HAS_COLORAMA = True +except ModuleNotFoundError: + HAS_COLORAMA = False + colorama = None # type: ignore[assignment] + + +if HAS_COLORAMA: + + def _color_text(msg: str, color: str) -> str: + return getattr(colorama.Fore, color.upper()) + msg + colorama.Fore.RESET + +else: + + def _color_text(msg: str, color: str) -> str: + return msg + + +def green_text(msg: str) -> str: + return _color_text(msg, "green") + + +def yellow_text(msg: str) -> str: + return _color_text(msg, "yellow") + + +def red_text(msg: str) -> str: + return _color_text(msg, "red") + + +def blue_text(msg: str) -> str: + return _color_text(msg, "blue") + + +def get_first_attr(obj: Any, *attrs: str) -> Any: + """ + Return the first available attribute or throw an exception if none is present. + """ + for attr in attrs: + if hasattr(obj, attr): + return getattr(obj, attr) + + raise AssertionError(f"{obj} does not has any of the attributes: {attrs}") + + +dynamo_timed = torch._dynamo.utils.dynamo_timed # type: ignore[has-type] + + +def triton_hash_to_path_key(key: str) -> str: + # In early versions of Triton, the hash is directly used in the path name. + # Later, the hash is converted to base64 before being used in the path name. + # Later, the base64 conversion was replaced to the base32 + # + # This code tries to import _base64 and falls back to _base32 if _base64 is unavailable. + # + # To handle this, try to import the to-base64-conversion function. + # If it exists, use it; otherwise, try using _base32; if both are unavailable, use the hash directly. + try: + from triton.runtime.cache import _base64 + + return _base64(key) + except Exception: + try: + from triton.runtime.cache import _base32 + + return _base32(key) + except Exception: + return key + + +def compile_mps_shader(source: str) -> Any: + """ + Compiles shader source but raise more actionable error message when needed + """ + try: + return torch.mps.compile_shader(source) + except SyntaxError as err: + raise SyntaxError(f"failed to compile {source} with {err.msg}") from err diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/static_cuda_launcher.py b/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/static_cuda_launcher.py new file mode 100644 index 0000000000000000000000000000000000000000..a52df4745f5907e6c1ed4c9b97e0e361f5fb5c27 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/static_cuda_launcher.py @@ -0,0 +1,237 @@ +import functools +import os +from typing import Any, Optional +from typing_extensions import Unpack + +from .triton_compat import ASTSource, CompiledKernel, knobs as triton_knobs + + +class StaticallyLaunchedCudaKernel: + """ + Parses the metadata of a CompiledKernel from Triton into a structure that can + launch the cuda kernel directly. Only works for triton kernels compiled to cubin. + + Doing this avoids C++ codegen and compilation during compile, since we can use a + statically compiled library to launch the kernel. 
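    For example (file path, device index, grid, and stream values are illustrative,
    and `compiled_kernel` stands for a triton CompiledKernel produced elsewhere),
    a round trip looks roughly like:

        kernel = StaticallyLaunchedCudaKernel(compiled_kernel)   # parse metadata at compile time
        kernel.write_cubin_to_file("/tmp/foo.cubin")              # now picklable / cacheable
        kernel.load_kernel(device=0)                              # after CUDA init, e.g. in the parent
        kernel.run(grid_x, grid_y, grid_z, stream, *args)         # hot-path launch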
To avoid mallocing for the arguments, + we have a launcher for different numbers of arguments up to a max. StaticCudaLauncher + only supports # of arguments up until 10 for now. + + Workflow: + Compile time: + 1. Compile a kernel with triton and get a CompiledKernel + 2. Instantiate kernel = StaticallyLaunchedCudaKernel(triton_kernel) + 3. Write to a cubin file: kernel.write_cubin_to_file(filepath) + 4. Call kernel.load_kernel() (CUDA should be initialized by this point) to load the cubin + Runtime: + 5. Call kernel.run(grid, stream, args) to launch the kernel + + Note that after step 3, StaticallyLaunchedCudaKernel is fully pickleable/serializable. + This allows it to be cached by FXGraphCache/TritonBundler, as well as sent from the worker + to the parent process in inductor. + + There are two main versions of triton that we wish to support: 3.3 and 3.2. Triton makes considerable changes + to how it handles constants in 3.3, so there's some special logic necessary to handle both versions. + """ + + def __init__(self, kernel: CompiledKernel) -> None: + self.name = kernel.src.fn.__name__ + self.cubin_raw = kernel.asm.get("cubin", None) + self.cubin_path = kernel._cubin_path + + # Used by torch.compile to filter constants in older triton versions + self.arg_names = kernel.src.fn.arg_names + + # Const exprs that are declared by the triton kernel directly + # Used to generate the kernel launcher's def args + self.declared_constexprs = kernel.src.fn.constexprs + + self.hash = kernel.hash + + if triton_knobs is None: + launch_enter = kernel.__class__.launch_enter_hook + launch_exit = kernel.__class__.launch_exit_hook + else: + launch_enter = triton_knobs.runtime.launch_enter_hook + launch_exit = triton_knobs.runtime.launch_exit_hook + + if launch_enter is not None or launch_exit is not None: + raise NotImplementedError( + "We don't support launch enter or launch exit hooks" + ) + self.num_warps = kernel.metadata.num_warps + self.shared = ( + kernel.shared if hasattr(kernel, "shared") else kernel.metadata.shared + ) + + # Newer triton versions pass an extra global scratch parameter to the compiled cuda kernel. + # Inductor never uses this field or enables it, but we still have to pass + # an extra None into the set of params if its enabled + if hasattr(kernel.metadata, "global_scratch_size"): + if kernel.metadata.global_scratch_size > 0: + raise NotImplementedError("Global scratch not yet supported") + else: + self.has_global_scratch = True + else: + self.has_global_scratch = False + + self.arg_tys = self.arg_ty_from_signature(kernel.src) + self.function: Optional[int] = ( + None # Loaded by load_kernel(on the parent process) + ) + num_ctas = 1 + if hasattr(kernel, "num_ctas"): + num_ctas = kernel.num_ctas + elif hasattr(kernel, "metadata"): + num_ctas = kernel.metadata.num_ctas + + if num_ctas != 1: + raise NotImplementedError( + "Static cuda launcher only supports num_ctas == 1" + ) + + def reload_cubin_from_raw(self, filepath: str) -> str: + """ + If the cubin file triton generated gets deleted under us, we can + reload it from the raw cubin file. 
+ """ + if self.cubin_path is None: + assert self.cubin_raw is not None + os.makedirs(os.path.dirname(filepath), exist_ok=True) + with open(filepath, "wb") as f: + f.write(self.cubin_raw) + self.cubin_path = filepath + return self.cubin_path + + def load_kernel(self, device: int) -> None: + from torch._C import _StaticCudaLauncher + + if self.function is not None: + return + + assert hasattr(self, "cubin_path") + assert self.cubin_path is not None + (self.function, self.n_regs, self.n_spills) = _StaticCudaLauncher._load_kernel( + self.cubin_path, self.name, self.shared, device + ) + # Don't need the cubin path anymore now that we've loaded + self.cubin_path = None + self.cubin_raw = None + + @staticmethod + @functools.lru_cache + def type_mappings() -> dict[str, str]: + return { + "i1": "i", + "i8": "b", + "i16": "h", + "i32": "i", + "i64": "l", + "u1": "I", + "u8": "B", + "u16": "H", + "u32": "I", + "u64": "K", + "fp16": "f", + "bf16": "f", + "fp32": "f", + "f32": "f", + "fp64": "d", + # TODO handle nvTmaDesc/CUtensormap + } + + def extract_type(self, ty: str) -> str: + """ + Takes a triton type from CompiledKernel.signature and + converts it into a single char encoding. _StaticCudaLauncher + will switch on this char to figure out what type the underlying + value should be passed to the triton kernel as. + """ + if ty[0] == "*": + return "O" + elif ty == "nvTmaDesc": + raise NotImplementedError("nvTmaDesc kernels are not yet supported") + return StaticallyLaunchedCudaKernel.type_mappings()[ty] + + def arg_ty_from_signature(self, src: ASTSource) -> str: + def index_key(i: Any) -> int: + if isinstance(i, str): + return src.fn.arg_names.index(i) + elif isinstance(i, tuple): + # In triton 3.3, src.fn.constants has tuples as a key + return i[0] + else: + return i + + signature = {index_key(key): value for key, value in src.signature.items()} + # Triton uses these as the main way to filter out constants passed to their cubin + constants = [index_key(key) for key in getattr(src, "constants", dict())] + # This value is always a superset of kernel.fn.constexprs: kernel.fn.constexprs are + # constants declared by the triton kernel directly, whereas this list can have + # constants that are unused by the triton kernel that triton figured out during + # compilation. + self.full_constexprs = constants + # Despite requiring them to be passed in, the triton CUDA launcher + # completely ignores the constexprs passed into it when generating code. + # So we can ignore them here too + params = [] + + for i in sorted(signature.keys()): + ty = signature[i] + # In newer triton versions, constants are passed in to signature with type `constexpr` + # In older triton versions, there can be constants in src.constants that are not `constexpr` in signature + # so we check both here + if ty == "constexpr" or i in constants: + pass + else: + params.append(self.extract_type(ty)) + return "".join(params) + + def __getstate__(self) -> dict[str, Any]: + # Remove objects that are no longer valid for pickling + state = self.__dict__.copy() + state["function"] = None + # Cubin paths aren't consistent across processes, so we clear + # and reload them. + state["cubin_path"] = None + return state + + def run( + self, + grid_x: int, + grid_y: int, + grid_z: int, + stream: int, + *args: Unpack[tuple[object, ...]], + ) -> None: + """Actually run the kernel at runtime. 
This function is the hot codepath.""" + from torch._C import _StaticCudaLauncher + + # Assert load_kernel() has been called and args match + assert self.function is not None + + # TODO: actually, if the args *don't* match, we probably should + # throw an exception. But if inductor is the only one calling this + # thing, it should always match. + # Get rid of constants before passing to cubin launcher + + # Add a None if triton wants an extra parameter to the cubin + if self.has_global_scratch: + arg_tys = self.arg_tys + "O" + args = (*args, None) + else: + arg_tys = self.arg_tys + assert len(args) == len(arg_tys) + + # TODO: can handle grid functions here or in C++, so + # that we don't need the grid handler above. + _StaticCudaLauncher._launch_kernel( + self.function, + grid_x, + grid_y, + grid_z, + self.num_warps, + self.shared, + arg_tys, + args, + stream, + ) diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/triton_compat.py b/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/triton_compat.py new file mode 100644 index 0000000000000000000000000000000000000000..d5aeb90d7d68455948740c190746b2d2d793a47b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/triton_compat.py @@ -0,0 +1,165 @@ +from __future__ import annotations + +import inspect +from typing import Any, Union + +import torch + + +try: + import triton +except ImportError: + triton = None + + +if triton is not None: + import triton.language as tl + from triton import Config + from triton.compiler import CompiledKernel + from triton.runtime.autotuner import OutOfResources + from triton.runtime.jit import KernelInterface + + try: + from triton.runtime.autotuner import PTXASError + except ImportError: + + class PTXASError(Exception): # type: ignore[no-redef] + pass + + try: + from triton.compiler.compiler import ASTSource + except ImportError: + ASTSource = None + + try: + from triton.backends.compiler import GPUTarget + except ImportError: + + def GPUTarget( + backend: str, + arch: Union[int, str], + warp_size: int, + ) -> Any: + if torch.version.hip: + return [backend, arch, warp_size] + return (backend, arch) + + # In the latest triton, math functions were shuffled around into different modules: + # https://github.com/triton-lang/triton/pull/3172 + try: + from triton.language.extra import libdevice + + libdevice = tl.extra.libdevice # noqa: F811 + math = tl.math + except ImportError: + if hasattr(tl.extra, "cuda") and hasattr(tl.extra.cuda, "libdevice"): + libdevice = tl.extra.cuda.libdevice + math = tl.math + elif hasattr(tl.extra, "intel") and hasattr(tl.extra.intel, "libdevice"): + libdevice = tl.extra.intel.libdevice + math = tl.math + else: + libdevice = tl.math + math = tl + + try: + from triton.language.standard import _log2 + except ImportError: + + def _log2(x: Any) -> Any: + raise NotImplementedError + + def _triton_config_has(param_name: str) -> bool: + if not hasattr(triton, "Config"): + return False + if not hasattr(triton.Config, "__init__"): + return False + return param_name in inspect.signature(triton.Config.__init__).parameters + + HAS_WARP_SPEC = ( + hasattr(tl, "async_task") + and _triton_config_has("num_consumer_groups") + and _triton_config_has("num_buffers_warp_spec") + ) + + try: + from triton import knobs + except ImportError: + knobs = None + + builtins_use_semantic_kwarg = ( + "_semantic" in inspect.signature(triton.language.core.view).parameters + ) +else: + + def _raise_error(*args: Any, **kwargs: Any) -> Any: + raise RuntimeError("triton 
package is not installed") + + class OutOfResources(Exception): # type: ignore[no-redef] + pass + + class PTXASError(Exception): # type: ignore[no-redef] + pass + + Config = object + CompiledKernel = object + KernelInterface = object + ASTSource = None + GPUTarget = None + _log2 = _raise_error + libdevice = None + math = None + knobs = None + builtins_use_semantic_kwarg = False + + class triton: # type: ignore[no-redef] + @staticmethod + def jit(*args: Any, **kwargs: Any) -> Any: + return _raise_error + + class tl: # type: ignore[no-redef] + @staticmethod + def constexpr(val: Any) -> Any: + return val + + tensor = Any + dtype = Any + + HAS_WARP_SPEC = False + + +def cc_warp_size(cc: Union[str, int]) -> int: + if torch.version.hip: + cc_str = str(cc) + if "gfx10" in cc_str or "gfx11" in cc_str: + return 32 + else: + return 64 + else: + return 32 + + +try: + autograd_profiler = torch.autograd.profiler +except AttributeError: # Compile workers only have a mock version of torch + + class autograd_profiler: # type: ignore[no-redef] + _is_profiler_enabled = False + + +__all__ = [ + "Config", + "CompiledKernel", + "OutOfResources", + "KernelInterface", + "PTXASError", + "ASTSource", + "GPUTarget", + "tl", + "_log2", + "libdevice", + "math", + "triton", + "cc_warp_size", + "knobs", +] diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/triton_helpers.py b/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/triton_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..cfd708bcf4bf81c0684af07f7f6e800e430552eb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/triton_helpers.py @@ -0,0 +1,737 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import math as pymath +import warnings +from functools import wraps +from typing import Any, Callable, TypeVar + +from .triton_compat import ( # noqa: F401 + _log2, + builtins_use_semantic_kwarg, + libdevice, + math, + tl, + triton, +) + + +_T = TypeVar("_T") +_LOG_2_E: tl.constexpr = tl.constexpr(pymath.log2(pymath.e)) + + +def set_driver_to_cpu(): + driver = triton.runtime.driver + if backend := triton.backends.backends.get("cpu", None): + if isinstance(driver.active, backend.driver): + # Don't re-initialize backend if it is already active + return + driver.set_active(backend.driver()) + return + # This can be a hard error once triton-cpu is merged into fbcode + warnings.warn( + "Could not find an active CPU backend. Generated kernels will not be executable!" + ) + + +def set_driver_to_gpu(): + driver = triton.runtime.driver + for name, backend in triton.backends.backends.items(): + if backend.driver.is_active() and name != "cpu": + # After https://github.com/triton-lang/triton/commit/b844d519bc5e86edf00fe6b3c6c2d1badcd509a4, + # `driver.active` can be of `LazyProxy` type and the sign of this - `_obj` attribute. 
+ if ( + isinstance(driver.active, backend.driver) + or hasattr(driver.active, "_obj") + and isinstance(driver.active._obj, backend.driver) + ): + # Don't re-initialize backend if it is already active + return + driver.set_active(backend.driver()) + return + raise RuntimeError("Could not find an active GPU backend") + + +def get_backend_options(): + from triton.runtime import driver + + target = driver.active.get_current_target() + backend = triton.compiler.compiler.make_backend(target) + options = backend.parse_options(dict()) + return options.__dict__ + + +@triton.jit +def promote_to_tensor(x): + # Addition promotes to tensor for us + return x + tl.zeros((1,), tl.int1) + + +@triton.jit +def div_floor_integer(a, b): + # NOTE: a // b is C division, but we want floor division + # Based on c10::div_floor_integer + quot = a // b + remainder = a % b + fixed = tl.where(remainder != 0, quot - 1, quot) + return tl.where((a < 0) != (b < 0), fixed, quot) + + +@triton.jit +def remainder_integer(a, b): + # NOTE: a % b matches C division, not floor division + remainder = a % b + return tl.where(remainder != 0 and ((a < 0) != (b < 0)), remainder + b, remainder) + + +@triton.jit +def is_floating(x): + return promote_to_tensor(x).dtype.is_floating() + + +@triton.jit +def _prod_accumulate(a, b): + return a * b + + +@triton.jit +def prod(input, axis): + return tl.reduce(input, axis, _prod_accumulate) + + +@triton.jit +def minimum(a, b): + mask = a < b + if is_floating(a): + mask |= a != a + return tl.where(mask, a, b) + + +@triton.jit +def maximum(a, b): + mask = a > b + if is_floating(a): + mask |= a != a + return tl.where(mask, a, b) + + +@triton.jit +def min2(a, dim): + return tl.reduce(a, dim, minimum) + + +@triton.jit +def max2(a, dim): + return tl.reduce(a, dim, maximum) + + +@triton.jit +def minimum_with_index(a_value, a_index, b_value, b_index): + mask = a_value < b_value + equal = a_value == b_value + if is_floating(a_value): + a_isnan = a_value != a_value + b_isnan = b_value != b_value + mask |= a_isnan and not b_isnan + # Consider NaNs as equal + equal |= a_isnan and b_isnan + + # Prefer lowest index if values are equal + mask |= equal & (a_index < b_index) + return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index) + + +@triton.jit +def maximum_with_index(a_value, a_index, b_value, b_index): + mask = a_value > b_value + equal = a_value == b_value + if is_floating(a_value): + a_isnan = a_value != a_value + b_isnan = b_value != b_value + mask |= a_isnan and not b_isnan + # Consider NaNs as equal + equal |= a_isnan and b_isnan + + # Prefer lowest index if values are equal + mask |= equal & (a_index < b_index) + return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index) + + +@triton.jit +def min_with_index(value, index, dim): + return tl.reduce((value, index), dim, minimum_with_index) + + +@triton.jit +def max_with_index(value, index, dim): + return tl.reduce((value, index), dim, maximum_with_index) + + +@triton.jit +def exp(x, use_fast_math: tl.constexpr): + if use_fast_math: + return libdevice.exp2(x * _LOG_2_E) + else: + return math.exp(x) + + +@triton.jit +def online_softmax_reduce(lhs_max, lhs_sum, dim, use_fast_math: tl.constexpr): + out_max = max2(lhs_max, dim) + out_max_keepdim = out_max[:, None] + delta = tl.where(out_max_keepdim == float("-inf"), 0, lhs_max - out_max_keepdim) + out_sum = tl.sum(lhs_sum * exp(delta, use_fast_math), dim) + return out_max, out_sum + + +@triton.jit +def online_softmax_combine(lhs_max, lhs_sum, rhs_max, use_fast_math: 
tl.constexpr): + """ + When we do combine, we assume lhs is the accumulator and rhs is the next + block of data. + Then rhs_sum is always 1. With that assumption, we can save some registers + and computation. + """ + out_max = maximum(lhs_max, rhs_max) + + lhs_scale = tl.where( + out_max == float("-inf"), 1.0, exp(lhs_max - out_max, use_fast_math) + ) + rhs_scale = tl.where( + out_max == float("-inf"), 1.0, exp(rhs_max - out_max, use_fast_math) + ) + + # Should be + # out_sum = lhs_sum * lhs_scale + rhs_sum * rhs_scale + # but since rhs_sum is all 1, we can simplify it. + out_sum = lhs_sum * lhs_scale + rhs_scale + return out_max, out_sum + + +@triton.jit +def welford_reduce(value, mean, m2, weight, first_iteration): + if first_iteration: + new_weight = tl.full(weight.shape, 1, weight.dtype) + new_mean = value + new_m2 = tl.zeros_like(m2) + else: + delta = value - mean + new_weight = weight + 1 + new_mean = mean + delta / new_weight + new_m2 = m2 + delta * (value - new_mean) + return new_mean, new_m2, new_weight + + +@triton.jit +def welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2): + delta = mean_2 - mean_1 + new_weight = weight_1 + weight_2 + w2_over_w = tl.where(new_weight == 0.0, 0.0, weight_2 / new_weight) + return ( + mean_1 + delta * w2_over_w, + m2_1 + m2_2 + delta * delta * weight_1 * w2_over_w, + new_weight, + ) + + +@triton.jit +def welford(mean, m2, weight, dim): + return tl.reduce((mean, m2, weight), dim, welford_combine) + + +@triton.jit +def device_assert_then(cond, msg, r): + tl.device_assert(cond, msg) + return r + + +@triton.jit +def randint64(seed, offset, low, high): + r0, r1, _r2, _r3 = tl.randint4x(seed, offset) + r0 = r0.to(tl.uint64) + r1 = r1.to(tl.uint64) + result = r0 | (r1 << 32) + size = high - low + result = result % size.to(tl.uint64) + result = result.to(tl.int64) + low + return result + + +@triton.jit +def _any_combine(a, b): + return a | b + + +@triton.jit +def any(a, dim): + return tl.reduce(a, dim, _any_combine) + + +@triton.jit +def bucketize_binary_search( + values: tl.tensor, + boundaries_ptr: tl.tensor, + BOUNDARIES_SIZE: int, + BOUNDARIES_UNDERLYING_NUMEL: int, + BOUNDARIES_STRIDE: int, + boundary_indices: tl.tensor, + indexing_dtype: tl.dtype, + right: "bool", # triton can't handle the unquoted bool annotation + sorter_ptr: tl.tensor, + SORTER_STRIDE: int, + sorter_indices: tl.tensor, +): + """ + See [Note: Inductor bucketize op] + + Inputs: + ------- + values: the values to bucketize. + boundaries_ptr: a pointer to the beginning of the boundaries tensor, in 1-D. + BOUNDARIES_SIZE: the length of the last dimension of the boundaries tensor (i.e. one + individual set of boundaries). + BOUNDARIES_UNDERLYING_NUMEL: the length of the boundaries tensor, in 1-D, ignoring + any striding. + BOUNDARIES_STRIDE: the stride of the last dimension of the boundaries tensor + boundary_indices: a tensor of the same size as "values"; each element is an index + into a 1-D, un-strided boundaries tensor, pointing to the first element in the set + of boundaries used for that value. + indexing_dtype: the dtype used for indexing into the boundaries tensor, and the + return dtype. + right: if true, use boundary intervals closed on the left; otherwise use intervals + closed on the right. + sorter_ptr: an optional pointer to a sorter tensor of the same shape as boundaries, + but potentially different striding. If present, this allows us to treat boundaries + as sorted even if the elements of boundaries are unsorted. 
+ SORTER_STRIDE: must be present if sorter_ptr is non-None; the stride of the last + dimension of the sorter tensor. + sorter_indices: must be present if sorter_ptr is non-None; see "boundary_indices". + BLOCK_SHAPE: the shape of the data block being processed. + """ + + low = tl.zeros(values.shape, dtype=indexing_dtype) + high = tl.full(values.shape, BOUNDARIES_SIZE, dtype=indexing_dtype) + + full_range = BOUNDARIES_SIZE + 1 + while full_range > 1: + mid = (high + low) // 2 + mask = ( + mid * BOUNDARIES_STRIDE + boundary_indices + ) < BOUNDARIES_UNDERLYING_NUMEL and mid < BOUNDARIES_SIZE + mid_indices = ( + mid + if sorter_ptr is None or SORTER_STRIDE is None + else tl.load( + sorter_ptr + sorter_indices + SORTER_STRIDE * mid, + mask=mask, + other=0, + ) + ) + + bucket_upper_bound = tl.load( + boundaries_ptr + boundary_indices + BOUNDARIES_STRIDE * mid_indices, + mask=mask, + other=0, + ) + if right: + is_above = values >= bucket_upper_bound + else: + is_above = values > bucket_upper_bound + + low = tl.where(is_above & mask, mid + 1, low) + high = tl.where(is_above, high, mid) + + full_range = (full_range + 1) // 2 + + return low + + +@triton.jit +def pack_value_flag( + value, + flag, + DTYPE_VALUE_AS_UINT: tl.constexpr, + DTYPE_PACK: tl.constexpr, +): + # Workaround for triton bug, tensor.to doesn't unwrap constexpr values + DTYPE_VALUE_AS_UINT = tl.core._unwrap_if_constexpr(DTYPE_VALUE_AS_UINT) + bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth + uv = value.to(DTYPE_VALUE_AS_UINT, bitcast=True).to(DTYPE_PACK) + return flag.to(DTYPE_PACK) | (uv << bitwidth) + + +@triton.jit +def unpack_value( + pack, + DTYPE_VALUE, + DTYPE_VALUE_AS_UINT, +): + # Workaround for triton bug, tensor.to doesn't unwrap constexpr values + DTYPE_VALUE = tl.core._unwrap_if_constexpr(DTYPE_VALUE) + DTYPE_VALUE_AS_UINT = tl.core._unwrap_if_constexpr(DTYPE_VALUE_AS_UINT) + bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth + value_uint = (pack >> bitwidth).to(DTYPE_VALUE_AS_UINT) + return value_uint.to(DTYPE_VALUE, bitcast=True) + + +@triton.jit +def unpack_flag(pack, DTYPE_FLAG): + return pack.to(DTYPE_FLAG) + + +@triton.jit +def exclusive_scan_decoupled_lookback( + scratch_base, + block_value, + index, + combine_fn, + DTYPE_VALUE_AS_UINT: tl.constexpr, + DTYPE_PACK: tl.constexpr, +): + """Compute exclusive scan of a scalar value between blocks + + Ref: https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back + + scratch_base: Pointer to scratch space in global memory + block_value: Scalar value for this block + index: Scalar index of this block relative to the current scan + combine_fn: Function ``(value, value) -> value`` which is scanned over + DTYPE_VALUE_AS_UINT: A tl.uint{n} type equal in size to ``block_value`` + DTYPE_PACK: Unsigned type twice the width of block_value + + NOTE: This function is limited to values which are 32-bits or less because + we need to pack (value, flag) into a single unsigned int. 
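    For intuition (the bit pattern is chosen purely for illustration), with a 32-bit
    value packed into a uint64 the flag occupies the low 32 bits and the raw value
    bits occupy the high 32 bits:

        value_bits = 0x40490FDB                # raw bits of float32 ~3.14159
        pack = 1 | (value_bits << 32)          # flag 1: block aggregate published
        pack = 2 | (value_bits << 32)          # flag 2: inclusive prefix published; look-back stops here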
+ """ + # Publish block sum so subsequent blocks don't get stuck waiting for us + DTYPE_VALUE = block_value.dtype + pack = pack_value_flag( + block_value, + tl.full(block_value.shape, 1, DTYPE_VALUE_AS_UINT), + DTYPE_VALUE_AS_UINT, + DTYPE_PACK, + ) + if index > 0: + tl.atomic_xchg(scratch_base + index, pack, sem="relaxed") + + # Calculate exclusive prefix scan + exclusive_prefix = tl.zeros([], DTYPE_VALUE) + prefix_valid = False + test_target = index - 1 + while test_target >= 0: + # tl.atomic_load + flag = tl.full([], 0, DTYPE_VALUE_AS_UINT) + while flag == 0: + pack = tl.atomic_add(scratch_base + test_target, 0, sem="relaxed") + flag = unpack_flag(pack, DTYPE_VALUE_AS_UINT) + + value = unpack_value(pack, DTYPE_VALUE, DTYPE_VALUE_AS_UINT) + if prefix_valid: + exclusive_prefix = combine_fn(value, exclusive_prefix) + else: + exclusive_prefix = value + prefix_valid = True + + if flag == 2: + test_target = -1 + else: + test_target = test_target - 1 + + # Make inclusive block sum visible to other blocks + if prefix_valid: + inclusive_prefix = combine_fn(exclusive_prefix, block_value) + else: + inclusive_prefix = block_value + pack = pack_value_flag( + inclusive_prefix, + tl.full([], 2, DTYPE_VALUE_AS_UINT), + DTYPE_VALUE_AS_UINT, + DTYPE_PACK, + ) + tl.atomic_xchg(scratch_base + index, pack, sem="relaxed") + return exclusive_prefix + + +@triton.jit +def exclusive_scan_decoupled_lookback_64(scratch_base, block_value, index, combine_fn): + """Compute exclusive scan of a scalar value between blocks + + Ref: https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back + + scratch_base: Pointer to scratch space in global memory + block_value: Scalar value for this block, must be 64-bits wide + index: Scalar index of this block relative to the current scan + combine_fn: Function ``(value, value) -> value`` which is scanned over + init: Scalar value equal to the identity of combine_fn + """ + # Publish block sum so subsequent blocks don't get stuck waiting for us + if index > 0: + block_value_u64 = block_value.to(tl.uint64, bitcast=True) + tl.store(scratch_base + 3 * index + 1, block_value_u64) + tl.debug_barrier() + flag_one = tl.full([], 1, tl.uint64) + tl.atomic_xchg(scratch_base + 3 * index + 0, flag_one, sem="release") + + # Calculate exclusive prefix scan + exclusive_prefix = tl.zeros([], block_value.dtype) + prefix_valid = False + test_target = index - 1 + while test_target >= 0: + flag = tl.full([], 0, tl.uint64) + while flag == 0: + flag = tl.atomic_add(scratch_base + 3 * test_target + 0, 0, sem="acquire") + + value_u64 = tl.load(scratch_base + 3 * test_target + flag.to(tl.int32)) + value = value_u64.to(block_value.dtype, bitcast=True) + if prefix_valid: + exclusive_prefix = combine_fn(value, exclusive_prefix) + else: + exclusive_prefix = value + prefix_valid = True + + if flag == 2: + test_target = -1 + else: + test_target = test_target - 1 + + # Make inclusive block sum visible to other blocks + if prefix_valid: + inclusive_prefix = combine_fn(exclusive_prefix, block_value) + else: + inclusive_prefix = block_value + inclusive_prefix_u64 = inclusive_prefix.to(tl.uint64, bitcast=True) + tl.store(scratch_base + 3 * index + 2, inclusive_prefix_u64) + tl.debug_barrier() + flag_two = tl.full([], 2, tl.uint64) + tl.atomic_xchg(scratch_base + 3 * index + 0, flag_two, sem="release") + + return exclusive_prefix + + +@triton.jit +def frexp(x): + # TODO(isuruf): use inline_asm_elementwise here + y = libdevice.ilogb(x) + 1 + exponent = tl.where(x == 0, 0, y) + 
mantissa = tl.where(x == 0, 0, libdevice.ldexp(x, -y)) + return mantissa, exponent + + +@triton.jit +def _compare_and_swap_with_index( + x, + idxs, + rnumel, + flip, + i: tl.constexpr, + n_dims: tl.constexpr, + stable: tl.constexpr, + descending: tl.constexpr, +): + n_outer: tl.constexpr = x.numel >> n_dims + shape: tl.constexpr = [n_outer * 2**i, 2, 2 ** (n_dims - i - 1)] + + idtype = tl.core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True) + + y = tl.reshape(x, shape) + iy = y.to(idtype, bitcast=True) + # slice left/right with 'stride' 2**(n_dims - i - 1) + right_mask = tl.arange(0, 2)[None, :, None].to(idtype) + left_mask = (1 - right_mask).to(idtype) + ileft = tl.broadcast_to(tl.sum(iy * left_mask, 1).to(idtype)[:, None, :], shape) + iright = tl.broadcast_to(tl.sum(iy * right_mask, 1).to(idtype)[:, None, :], shape) + ileft = tl.reshape(ileft, x.shape) + iright = tl.reshape(iright, x.shape) + left = ileft.to(x.dtype, bitcast=True) + right = iright.to(x.dtype, bitcast=True) + + # idx + y_idx = tl.reshape(idxs, shape) + left_idx = tl.broadcast_to( + tl.sum(y_idx * left_mask.to(y_idx.dtype), 1)[:, None, :], shape + ) + right_idx = tl.broadcast_to( + tl.sum(y_idx * right_mask.to(y_idx.dtype), 1)[:, None, :], shape + ) + left_idx = tl.reshape(left_idx, x.shape) + right_idx = tl.reshape(right_idx, x.shape) + + # valid + if rnumel is None: + left_valid_mask = tl.full(x.shape, True, tl.int1) + right_valid_mask = tl.full(x.shape, True, tl.int1) + else: + left_valid_mask = left_idx < rnumel + right_valid_mask = right_idx < rnumel + + # actual compare-and-swap + ix = x.to(idtype, bitcast=True) + + if descending: + cond = left < right + else: + cond = left > right + + if stable: + # When stable sorting, tie break by index + cond = cond | ((left == right) & (left_idx > right_idx)) + + cond = (right_valid_mask > left_valid_mask) | ( + (right_valid_mask == left_valid_mask) & cond + ) + cond = (cond ^ flip).to(tl.int1) + ret = ix ^ tl.where(cond, ileft ^ iright, tl.zeros_like(ix)) + new_idxs = idxs ^ tl.where(cond, left_idx ^ right_idx, tl.zeros_like(idxs)) + + return ret.to(x.dtype, bitcast=True), new_idxs + + +@triton.jit +def _bitonic_merge_with_index( + x, + idxs, + rnumel, + stage: tl.constexpr, + alternating: tl.constexpr, + n_dims: tl.constexpr, + stable: tl.constexpr, + descending: tl.constexpr, +): + n_outer: tl.constexpr = x.numel >> n_dims + tl.static_assert(stage <= n_dims) + # flip denotes whether to re-arrange sub-sequences of elements in ascending or + # descending order. + # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage + # if flip = 00110011... 
then all the elements will be re-arranged alternatingly (with + # a stride of 2) at this stage + if alternating: + shape: tl.constexpr = [n_outer * 2 ** (n_dims - 1 - stage), 2, 2**stage] + flip = tl.reshape( + tl.broadcast_to(tl.arange(0, 2)[None, :, None], shape), x.shape + ) + else: + flip = False + # perform `stage` rounds of `compare-and-swap` + for i in tl.static_range(stage): + x, idxs = _compare_and_swap_with_index( + x, idxs, rnumel, flip, i + (n_dims - stage), n_dims, stable, descending + ) + return x, idxs + + +@triton.jit +def sort_with_index( + x, # value + idxs, # index + rnumel, # number of elements + dim: tl.constexpr = None, + stable: tl.constexpr = tl.constexpr(False), + descending: tl.constexpr = tl.constexpr(False), +): + x, idxs = tl.broadcast(x, idxs) + # handle default dimension or check that it is the most minor dim + _dim: tl.constexpr = len(x.shape) - 1 if dim is None else dim + tl.static_assert( + _dim == len(x.shape) - 1, "only minor dimension is currently supported" + ) + # iteratively run bitonic merge-sort steps + n_dims: tl.constexpr = _log2(x.shape[_dim]) + + for i in tl.static_range(1, n_dims + 1): + x, idxs = _bitonic_merge_with_index( + x, + idxs, + rnumel, + i, + alternating=i < n_dims, + n_dims=n_dims, + stable=stable, + descending=descending, + ) + return x, idxs + + +@triton.jit +def select_one(x, mask, dim, keep_dims=False): + idtype = tl.core.get_int_dtype(x.dtype.primitive_bitwidth, signed=False) + ix = x.to(idtype, bitcast=True) + iy = tl.sum(ix * mask, dim, keep_dims=keep_dims) + return iy.to(x.dtype, bitcast=True) + + +@triton.jit +def x_grid_barrier(sem): + """ + Wait for all other thread blocks in grid sharing same y/z program_id + to reach this barrier before returning. + + Args: + sem: an uint32 semaphores, zero or 0x80000000 initialized. Must be unique to each y/z program ID. + """ + # ensure stores before this are visible + tl.debug_barrier() + + one_i32 = 1 + one_u32 = one_i32.to(tl.uint32) # type: ignore[attr-defined] + expected = tl.num_programs(0).to(tl.uint32) + if tl.program_id(0) == 0: + nb = 0x80000000 - (expected - one_u32) + else: + nb = one_u32 + + old_arrive = tl.atomic_add(sem, nb, sem="release") + + bar_flipped = False + while not bar_flipped: + # want a `ld.acquire.gpu.u32 $0,[$1];` but Triton doesn't have it + current_arrive = tl.atomic_add(sem, 0, sem="acquire") + # current_arrive = tl.load(sem, volatile=True) + bar_flipped = ((old_arrive ^ current_arrive) & 0x80000000) != 0 + + # TODO(jansel): is this needed? + tl.debug_barrier() + + +def triton_builtin(f: Callable[..., _T]) -> Callable[..., _T]: + """ + Decorator to mark a function as a Triton built-in function. These functions + are evaluated at compile time. + + Args: + f (function): The function to be marked as a Triton built-in. + + Returns: + function: The same function, marked as a Triton built-in. + """ + if builtins_use_semantic_kwarg: + # support Triton before and after https://github.com/triton-lang/triton/pull/7054 + @wraps(f) + def wrapper(*args, **kwargs): + kwargs["_builder"] = kwargs["_semantic"] + del kwargs["_semantic"] + return f(*args, **kwargs) + else: + wrapper = f # type: ignore[assignment] + + wrapper.__triton_builtin__ = True # type: ignore[attr-defined] + return wrapper + + +@triton_builtin +def constexpr_next_power_of_2( + n: tl.constexpr, *, _builder: object = None +) -> tl.constexpr: + """ + A version triton.next_power_of_two that can be used within a kernel on constants. 
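+
+    For example, ``constexpr_next_power_of_2(tl.constexpr(5))`` evaluates to
+    ``tl.constexpr(8)`` at compile time.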
+ """ + assert isinstance(n, tl.constexpr) + return tl.constexpr(triton.next_power_of_2(n.value)) + + +@triton_builtin +def if_mask(mask: Any, val, *, _builder: object = None) -> tl.constexpr: + """ + Work around triton compile error: `ValueError: `other` cannot be provided without `mask`` + A compile-time to check to return either `val` or `None` depending on the value of mask. + """ + if isinstance(mask, tl.constexpr) and mask.value is None: + return tl.constexpr(None) + return val diff --git a/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py b/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py new file mode 100644 index 0000000000000000000000000000000000000000..4c50768b7188c3ecf8642982d78c5753e9e45c8d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py @@ -0,0 +1,3022 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import builtins +import copy +import dataclasses +import functools +import hashlib +import inspect +import itertools +import logging +import math +import operator +import os +import os.path +import re +import sys +import threading +import time +from collections import namedtuple +from typing import ( + Any, + Callable, + Generic, + Literal, + Optional, + TYPE_CHECKING, + TypeVar, + Union, +) + +import torch +from torch._dynamo.utils import set_feature_use +from torch._prims_common import compute_required_storage_length +from torch.utils._ordered_set import OrderedSet + +from ..triton_bundler import TritonBundler +from ..utils import prefix_is_reduction, triton_version_uses_attrs_dict +from . import triton_helpers +from .autotune_cache import AutotuneCache +from .benchmarking import benchmarker +from .coordinate_descent_tuner import CoordescTuner +from .hints import ( + _NUM_THREADS_PER_WARP, + AutotuneHint, + DeviceProperties, + HeuristicType, + ReductionHint, + TileHint, + TRITON_MAX_BLOCK, + TRITON_MAX_RSPLIT, +) +from .runtime_utils import ( + ceildiv, + compilation_callback, + conditional_product, + create_bandwidth_info_str, + dynamo_timed, + get_first_attr, + get_max_y_grid, + get_num_bytes, + next_power_of_2, + triton_cache_dir, + triton_config_to_hashable, + triton_hash_to_path_key, + validate_triton_config, +) +from .static_cuda_launcher import StaticallyLaunchedCudaKernel +from .triton_compat import ( + ASTSource, + autograd_profiler, + cc_warp_size, + CompiledKernel, + Config, + GPUTarget, + HAS_WARP_SPEC, + KernelInterface, + knobs, + OutOfResources, + PTXASError, + triton, +) + + +class NoTritonConfigsError(RuntimeError): + pass + + +if TYPE_CHECKING: + from collections.abc import Container, Hashable + + from torch._guards import CompileId + + LauncherType = Any + +_KernelType = Union[CompiledKernel, StaticallyLaunchedCudaKernel] +_T = TypeVar("_T", bound=_KernelType) + +log = logging.getLogger(__name__) + + +def get_total_reduction_numel(numels: dict[str, int]) -> int: + return conditional_product( + *[numel for prefix, numel in numels.items() if prefix_is_reduction(prefix)] + ) + + +def autotune_hints_to_configs( + hints: OrderedSet[AutotuneHint], + size_hints, + block_size: int, + device_props: DeviceProperties, +) -> list[Config]: + """ + AutotuneHints can be attached to the metadata of triton kernels for providing + suggestions about what to try for autotuning. 
One reason to do this is if there are + some configs that are only useful in specific scenarios, in which case we can avoid + wasting compile time on autotuning unless we know we are in one of those scenarios. + + Based on those hints, this function will generate a list of additional autotuning + configs to try. + """ + xyz_options: tuple[tuple[int, Optional[int], Optional[int]], ...] + configs: list[Config] = [] + for hint in hints: + if hint == AutotuneHint.ONE_ELEMENT_PER_THREAD: + if len(size_hints) == 1: + xyz_options = ((block_size // 4, None, None),) + elif len(size_hints) == 2: + xyz_options = ((block_size // 4, 1, None), (1, block_size // 4, None)) + elif len(size_hints) == 3: + xyz_options = ( + (block_size // 4, 1, 1), + (1, block_size // 4, 1), + (1, 1, block_size // 4), + ) + configs.extend( + triton_config( + size_hints, + *xyz, + num_elements_per_warp=( + device_props.warp_size if device_props.warp_size else 32 + ), + ) + for xyz in xyz_options + ) + + return configs + + +def disable_pointwise_autotuning(inductor_meta): + # Autotuning can give different benchmarking results from run to run, and + # therefore we disable autotuning when use_deterministic flag is on. + if inductor_meta.get("are_deterministic_algorithms_enabled"): + return True + return not inductor_meta.get("autotune_pointwise", True) + + +def _dump_launch_params(args, kwargs, launcher, kernel_name, grid): + call_args = [] + call_kwargs = {} + for arg in args: + if isinstance(arg, (int, bool)): + call_args.append(str(arg)) + else: + call_args.append("T") + for k, v in kwargs.items(): + if isinstance(arg, (int, bool)): + call_kwargs[k] = v + else: + call_kwargs[k] = v + if not triton_version_uses_attrs_dict(): + call_kwargs.update(launcher.config.kwargs) + call_kwargs["num_warps"] = launcher.config.num_warps + call_kwargs["num_stages"] = launcher.config.num_stages + if HAS_WARP_SPEC: + call_kwargs["num_consumer_groups"] = getattr( + launcher.config, "num_consumer_groups", 0 + ) + call_kwargs["num_buffers_warp_spec"] = getattr( + launcher.config, "num_buffers_warp_spec", 0 + ) + args_str = [*call_args] + args_str.extend(f"{k}={v}" for k, v in call_kwargs.items()) + args_str = ", ".join(args_str) + abs_path = os.path.abspath(sys.argv[0]) + with open(f"{abs_path}.launch_params", "a") as f: + f.write(f"{kernel_name} | {args_str} | {grid!r}\n") + + +def check_autotune_cache( + configs: list[Config], filename: Optional[str], inductor_meta: dict[str, Any] +) -> tuple[list[Config], Optional[AutotuneCache], dict[str, Any]]: + """ + Given a list of configs, checks autotune cache and return metadata + """ + autotune_cache = None + autotune_cache_info = {} + disabled = inductor_meta.get("force_disable_caches", False) + if ( + not disabled + and filename is not None + and (len(configs) > 1 or inductor_meta.get("coordinate_descent_tuning")) + and not os.environ.get("TRITON_INTERPRET", "0") == "1" + ): + configs_hash = hash_configs(configs) + + autotune_cache = AutotuneCache.create(inductor_meta, filename, configs_hash) + if autotune_cache: + if best_config := autotune_cache.read_best(inductor_meta, configs): + configs = [best_config] + autotune_cache_info["best_config"] = triton_config_to_hashable( + best_config + ) + autotune_cache_info["autotune_cache_state"] = "hit" + + else: + autotune_cache_info["autotune_cache_state"] = "miss" + autotune_cache_info["num_configs"] = len(configs) + if inductor_meta.get("coordinate_descent_tuning"): + autotune_cache_info["coordesc_tuning"] = True + if len(configs) == 1: + # This is the 
config that coordinate descent tuning started at, which + # is not the same as the final config chosen (i.e. only_config, best_config) + autotune_cache_info["coordesc_tuning_start_config"] = ( + triton_config_to_hashable(configs[0]) + ) + else: + if len(configs) == 1: + autotune_cache_info["autotune_cache_state"] = "only 1 config" + autotune_cache_info["only_config"] = triton_config_to_hashable(configs[0]) + + if disabled: + autotune_cache_info["autotune_cache_state"] = "force_disabled" + log.debug("autotune caching is disabled by config.force_disable_caches") + + return configs, autotune_cache, autotune_cache_info + + +class CachingAutotuner(KernelInterface): + """ + Simplified version of Triton autotuner that has no invalidation + key and caches the best config to disk to improve cold start times. + Unlike the main triton Autotuner, this version can precompile all + configs, and does not rely on the Triton JIT. + """ + + def __init__( + self, + fn, + triton_meta, # passed directly to triton + configs, + save_cache_hook, + mutated_arg_names: list[str], # see [Note: clone mutated buffers] + optimize_mem, + heuristic_type, + size_hints=None, + inductor_meta=None, # metadata not relevant to triton + custom_kernel=False, # whether the kernel is inductor-generated or custom + filename: Optional[str] = None, + reset_to_zero_arg_names: Optional[list[str]] = None, + autotune_cache_info: Optional[dict[str, Any]] = None, + ): + super().__init__() + + assert len(configs) > 0, "Non-empty TritonConfig list required for compiling" + # makes sure there are no pre-hooks on any of the triton configs + for cfg in configs: + validate_triton_config(cfg) + + self.fn = fn + self.device_props: DeviceProperties = triton_meta["device"] + self.triton_meta = { + **triton_meta, + "device": self.device_props.index, + "device_type": self.device_props.type, + } + self.inductor_meta = {} if inductor_meta is None else inductor_meta + self.save_cache_hook = save_cache_hook + self.mutated_arg_names = mutated_arg_names + self.reset_to_zero_arg_names = ( + [] if reset_to_zero_arg_names is None else reset_to_zero_arg_names + ) + self.optimize_mem = optimize_mem + self.configs = configs + self.heuristic_type = heuristic_type + self.custom_kernel = custom_kernel + self.cuda_kernel_saved = False + self.autotune_cache_info = autotune_cache_info + if log.isEnabledFor(logging.DEBUG): + log.debug( + "CachingAutotuner gets %d configs for %s", + len(self.configs), + self.fn.__name__, + ) + for c in self.configs: + log.debug(c) + + self.compile_results: list[CompileResult[_KernelType]] = [] + self.launchers: list[LauncherType] = [] + self.lock = threading.Lock() + if os.getenv("TRITON_CACHE_DIR") is None: + os.environ["TRITON_CACHE_DIR"] = triton_cache_dir( + self.triton_meta.get("device", 0) + ) + log.debug("Triton cache dir: %s", os.environ["TRITON_CACHE_DIR"]) + + self.size_hints = size_hints + self.coordesc_tuner = CoordescTuner( + is_mm=False, + name=self.fn.__name__, + size_hints=size_hints, + inductor_meta=self.inductor_meta, + ) + self.filename = filename + + # used for profiling + self.kernel_hash: str = "" + + # Kernels are stored in the codecache with the filename as a hash of the code. + # We rely on this to obtain the kernel hash + if self.filename is not None: + base_name = os.path.basename(self.filename) + if ".py" in base_name: + self.kernel_hash = os.path.splitext(base_name)[0] + + self.precompile_time_taken_ns = 0 + self.autotune_time_taken_ns = 0 + # Dumps the launch configs after autotuning. 
+ self.dump_launch_params = ( + os.environ.get("TORCHINDUCTOR_DUMP_LAUNCH_PARAMS", "0") == "1" + ) + + self.triton_interpret = os.environ.get("TRITON_INTERPRET", "0") == "1" + + # Compile-time info included in runtime logginging + self.compile_id: Optional[CompileId] = None + self.is_backward = False + + def is_statically_launchable(self): + """ + Checks if every compiled kernel is statically launchable, which + allows us to efficiently cache it in FXGraphCache + """ + if not self.compile_results: + return False + return all( + isinstance(x, StaticTritonCompileResult) for x in self.compile_results + ) + + def recheck_autotune_cache( + self, reload_kernel_from_src: Callable[[], CachingAutotuner] + ) -> None: + """ + On cache load on static autotuner, we need to recheck the autotune cache, since + a best config could have been found from a previous run + """ + assert self.is_statically_launchable() + + configs = [result.config for result in self.compile_results] + + (cached_configs, _, autotune_cache_info) = check_autotune_cache( + configs, self.filename, self.inductor_meta + ) + self.autotune_cache_info = autotune_cache_info + # I.e. there was an autotune cache hit + if len(cached_configs) == 1 and len(configs) > 1: + best_config = cached_configs[0] + # Grab the best compiled config, if it's in the list of available ones + best_config_hash = triton_config_to_hashable(best_config) + + for compile_result in self.compile_results: + if triton_config_to_hashable(compile_result.config) == best_config_hash: + self.compile_results = [compile_result] + return + + # If the best config isn't in our list of compile results, + # it's likely because it was found by coordesc after the cache + # already saved + if best_config.found_by_coordesc: + with dynamo_timed("CachingAutotuner.slow_precompile_config"): + if self.fn.fn is None: + self.fn = reload_kernel_from_src().fn + self.compile_results = [self._precompile_config(best_config)] + + def set_compile_info( + self, compile_id: Optional[CompileId], is_backward: bool + ) -> None: + self.compile_id = compile_id + self.is_backward = is_backward + + def precompile( + self, + warm_cache_only=False, + reload_kernel: Optional[Callable[[], CachingAutotuner]] = None, + static_triton_bundle_key: Optional[str] = None, + ): + if warm_cache_only: + self._precompile_worker() + return + with self.lock: + # Helper function for reloading a kernel generated in a worker + # in the parent class. 
Normally we don't need to reload the kernel + # in the parent process, but in certain cases (coordesc tuning, dynamic_scale_rblock), + # we need to actually run compilation on the parent process + if reload_kernel is not None: + self._reload_kernel = reload_kernel + self._precompile_worker() + if static_triton_bundle_key is not None and self.is_statically_launchable(): + TritonBundler.put_static_autotuner(static_triton_bundle_key, self) + self._make_launchers() + self._dynamic_scale_rblock() + + def _precompile_worker(self): + if self.compile_results: + for result in self.compile_results: + TritonBundler.put( + triton_hash_to_path_key(result.kernel.hash), # type: ignore[attr-defined] + self.triton_meta.get("device", 0), + ) + return + assert not self.launchers + if not self.configs: + raise NoTritonConfigsError("No triton configs are available") + + compile_results = [] + exc = None + for c in self.configs: + try: + compile_results.append(self._precompile_config(c)) + except (OutOfResources, PTXASError) as e: + exc = e + if len(compile_results) == 0: + raise NoTritonConfigsError( + f"No valid triton configs. {type(exc).__name__}: {exc}" + ) + self.compile_results = compile_results + self.configs = None + + def _dynamic_scale_rblock(self): + # TODO(jansel): we should find a way to move this extra compile into the worker process + # Currently it relies on _make_launchers(), which requires a cuda context, to populate nreg. + device_prop = self.device_props + if ( + self.inductor_meta.get("dynamic_scale_rblock", True) + and not self.inductor_meta.get("persistent_reduction") + and self.heuristic_type == HeuristicType.REDUCTION + and self.size_hints is not None + # Disable for Intel as Triton is not ready to return n_regs for a compiled_binary. + and device_prop.type in ["cuda", "hip"] + and device_prop.major + and (device_prop.major >= 8 or torch.version.hip) + and device_prop.regs_per_multiprocessor is not None + ): + assert device_prop.regs_per_multiprocessor + assert device_prop.max_threads_per_multi_processor + assert device_prop.multi_processor_count + seen_config_hashes: Optional[OrderedSet[Hashable]] = None + warp_size = device_prop.warp_size or 32 + for result in self.compile_results: + triton_config = result.config + compiled_binary = result.kernel + assert len(self.size_hints) >= 2 + xblock = triton_config.kwargs.get("XBLOCK", 1) + reduction_kwargs = [ + kwarg for kwarg in triton_config.kwargs if kwarg.startswith("R") + ] + rblocks = [triton_config.kwargs[kwarg] for kwarg in reduction_kwargs] + total_block = (self.size_hints["x"] + xblock - 1) // xblock + nreg = getattr(compiled_binary, "n_regs", None) + if nreg is None: + continue + + # make sure rblocks are not too small + if conditional_product(*rblocks) <= 64: + continue + + # each SM of A100 has 65536 32-bit registers. To maximize + # the theoretical occupancy, we need run 2048 threads on each + # SM. So each thread should use no more than 65536 / 2048 + # = 32 registers. In cases where occupancy matters, and each + # thread uses too many registers, reduce R0_BLOCK to reduce + # the register usage. + # For kernel https://gist.github.com/shunting314/e4cccc031fe30d378b9b23c08c238cbd + # from PLBartForCausalLM, latency improve from + # 7.795ms to 4.883ms. 
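+                # As a rough worked example (hypothetical numbers): with nreg = 40,
+                # warp_size = 32 and num_warps = 8, nreg_per_block = 40 * 32 * 8 = 10240,
+                # so max_blocks_per_sm = 65536 // 10240 = 6 on an A100-like SM.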
+ # + if ( + nreg + <= device_prop.regs_per_multiprocessor + // device_prop.max_threads_per_multi_processor + ): + continue + + nreg_per_warp = nreg * warp_size + nreg_per_block = nreg_per_warp * triton_config.num_warps + + # Previously we set max_blocks_per_sm to 'max_threads_per_multi_processo / (32 * num_warps)' + # The formula below is a tighter upper bound since we have the assumption that + # nreg > device_prop.regs_per_multiprocessor // device_prop.max_threads_per_multi_processor + # due to the if condition above and: + # regs_per_multiprocessor / nreg_per_block + # = regs_per_multiprocessor / (nreg * 32 * num_warps) + # < regs_per_multiprocessor / ((regs_per_multiprocessor / max_threads_per_multi_processor) * 32 * num_warps) + # = max_threads_per_multi_processor / (32 * num_warps) + # Using a tigher upper bound can reveal more optimization opportunities. + max_blocks_per_sm = max( + device_prop.regs_per_multiprocessor // nreg_per_block, 1 + ) + + if total_block <= max_blocks_per_sm * device_prop.multi_processor_count: + # no need to improve occupancy + continue + new_config = copy.deepcopy(triton_config) + + # Reduce the largest Rn_BLOCK by a factor of 2. + largest_rkwarg: str = max( + reduction_kwargs, key=triton_config.kwargs.__getitem__ + ) + new_config.kwargs[largest_rkwarg] //= 2 + + if seen_config_hashes is None: + seen_config_hashes = OrderedSet( + [ + triton_config_to_hashable(x.config) + for x in self.compile_results + ] + ) + new_config_hash = triton_config_to_hashable(new_config) + if new_config_hash in seen_config_hashes: + continue + seen_config_hashes.add(new_config_hash) + log.debug( + "Dynamically scale down %s from TritonConfig(%s) and get a new TritonConfig(%s)", + largest_rkwarg, + triton_config, + new_config, + ) + if self.fn.fn is None: + """ + We are in the parent process, while this program was compiled in a worker + and the fn was dropped in prepare_for_pickle(). We haven't loaded the module + containing the real fn yet. + """ + assert hasattr(self, "_reload_kernel") + assert callable(self._reload_kernel) + self.fn = self._reload_kernel().fn + self.compile_results.append(self._precompile_config(new_config)) + + self._make_launchers() + + def _make_launchers(self): + if len(self.launchers) == len(self.compile_results): + return + + from torch._dynamo.device_interface import DeviceGuard + + device_interface = self.get_device_interface() + + # load binary to the correct device + with DeviceGuard(device_interface, self.triton_meta["device"]): + # need to initialize context + device_interface.synchronize(device_interface.current_device()) + launchers = [] + exc = None + for result in self.compile_results: + try: + launchers.append(result.make_launcher()) + + except (OutOfResources, PTXASError, torch.cuda.OutOfMemoryError) as e: + exc = e + if len(launchers) == 0: + raise RuntimeError(f"No valid triton configs. {type(exc).__name__}: {exc}") + self.launchers = launchers + + def prepare_for_pickle(self) -> tuple[Any, Any, Any, Any, Any]: + """Drop stuff from triton.JITFunction that does not pickle. + This must be called after precompile so that these things are no longer needed. 
+ Returns a tuple of old values + """ + old_values = ( + self.fn.fn, + self.fn.__globals__, + self.fn.used_global_vals, + self.fn.repr, + self.launchers, + ) + self.fn.fn = None + self.fn.__globals__ = None + self.fn.used_global_vals = None + self.fn.repr = _ConstRepr(self.fn.repr(self.fn)) + self.launchers = [] + return old_values + + def prepare_for_caching(self) -> None: + """ + Statically Launched CUDA Kernels have a raw cubin on them + that we don't need to store in the cache(since TritonBundler handles the collection for us) + """ + for result in self.compile_results: + if isinstance(result, StaticTritonCompileResult): + # Don't save this in the inductor cache, as it is very large + result.kernel.cubin_raw = None + + def __getstate__(self) -> dict[str, Any]: + assert not self.launchers, ( + "pickle should not be called with after make_launchers()" + ) + return { + **self.__dict__, + "lock": None, + } + + def __setstate__(self, state: dict[str, Any]) -> None: + self.__dict__.update(state) + self.lock = threading.Lock() + + def get_device_interface(self): + # this code cannot run in compile workers, because it imports from torch + from torch._dynamo.device_interface import get_interface_for_device + + return get_interface_for_device(self.device_props.type.replace("hip", "cuda")) + + def _precompile_config(self, cfg: Config) -> CompileResult[_KernelType]: + """Ahead of time compile a given autotuner config.""" + compile_meta = copy.deepcopy(self.triton_meta) + cfg_kwargs = cfg.kwargs + if self.device_props.type == "hip": + cfg_kwargs = {**cfg_kwargs} + for k in ("matrix_instr_nonkdim", "waves_per_eu", "kpack"): + if k in cfg_kwargs: + compile_meta[k] = cfg_kwargs.pop(k) + compile_meta["constants"].update(cfg_kwargs) + for i in self.fn.constexprs: + arg_name = self.fn.arg_names[i] + if arg_name not in compile_meta["constants"] and ( + arg_name == "num_warps" or arg_name == "num_stages" + ): + compile_meta["constants"][arg_name] = getattr(cfg, arg_name) + compile_meta["num_warps"] = cfg.num_warps + compile_meta["num_stages"] = cfg.num_stages + if HAS_WARP_SPEC: + compile_meta["num_consumer_groups"] = getattr(cfg, "num_consumer_groups", 0) + compile_meta["num_buffers_warp_spec"] = getattr( + cfg, "num_buffers_warp_spec", 0 + ) + compile_meta["debug"] = self.inductor_meta.get( + "assert_indirect_indexing", True + ) and not self.inductor_meta.get("is_hip", False) + + # device type will be "hip" rather than "cuda" here + compile_meta["device_type"] = self.device_props.type + compile_meta["cc"] = self.device_props.cc + + if self.device_props.type == "cpu": + triton_helpers.set_driver_to_cpu() + else: + triton_helpers.set_driver_to_gpu() + + if not ASTSource: + raise RuntimeError("Installed triton version too old, please upgrade") + + compile_args = ( + ASTSource( + self.fn, + compile_meta["signature"], + compile_meta["constants"], + compile_meta["configs"][0], + ), + ) + + if self.device_props.type == "mtia": + from mtia.host_runtime.torch_mtia.acc_flags import ( # type: ignore[import-not-found] + build_codename, + ) + + arch = build_codename() + else: + arch = compile_meta["cc"] + + target = GPUTarget( + compile_meta["device_type"], + arch, + cc_warp_size(compile_meta["cc"]), + ) + + options = { + "num_warps": compile_meta["num_warps"], + "num_stages": compile_meta["num_stages"], + "debug": compile_meta["debug"], + "sanitize_overflow": False, # turn off additional asserts added for overflow checks + } + if HAS_WARP_SPEC: + options.update( + { + "num_consumer_groups": 
compile_meta.get("num_consumer_groups", 0), + "num_buffers_warp_spec": compile_meta.get( + "num_buffers_warp_spec", 0 + ), + } + ) + if self.device_props.type == "hip": + if "waves_per_eu" in compile_meta: + options["waves_per_eu"] = compile_meta["waves_per_eu"] + if "matrix_instr_nonkdim" in compile_meta: + options["matrix_instr_nonkdim"] = compile_meta["matrix_instr_nonkdim"] + compile_kwargs = { + "target": target, + "options": options, + } + + try: + binary = triton.compile(*compile_args, **compile_kwargs) + except Exception: + log.exception( + "Triton compilation failed: %s\n%s\nmetadata: %s", + self.inductor_meta.get("kernel_name", "triton_"), + self.fn.src, + compile_meta, + ) + raise + TritonBundler.put( + triton_hash_to_path_key(binary.hash), self.triton_meta.get("device", 0) + ) + # If the binary has a cubin file to directly launch, save it on the binary + static_launcher = StaticTritonCompileResult.can_statically_launch( + binary, self.inductor_meta, self.triton_meta, self.heuristic_type + ) + + if static_launcher is not None: + result = StaticTritonCompileResult( + static_launcher, cfg, compile_meta, self.inductor_meta + ) + return result + + return TritonCompileResult(binary, cfg, compile_meta, self.inductor_meta) + + def _get_args_with_constexprs(self, args, launcher): + """ + `args` is passed in with only the non-constexpr args (because the constexpr arg values + depend on the config). However, in later triton versions, the constexpr args need to be + added into the args list. + """ + if triton_version_uses_attrs_dict(): + # first: aggregate the constexpr args in (index, val) pairs + # so we can sort them by index. + constexpr_args: list[tuple[int, Any]] = [] + for arg_name, arg_val in launcher.config.kwargs.items(): + if arg_name in self.fn.arg_names: + constexpr_args.append((self.fn.arg_names.index(arg_name), arg_val)) + + constexpr_args.sort() + new_args = [*args] + for arg_idx, arg_val in constexpr_args: + new_args.insert(arg_idx, arg_val) + + return new_args + return args + + def bench(self, launcher, *args, with_profiler=False, **kwargs): + """Measure the performance of a given launcher""" + # we don't skip configs with spilled registers when auto-tuning custom + # (user-written) Triton kernels, as (i) we don't have any knowledge or + # control over the kernel code; (ii) there is empirical evidence that + # for some (complicated) custom Triton kernels, a register-spilling + # config may yield the best latency. 
+ if not self.custom_kernel and launcher.n_spills > self.inductor_meta.get( + "spill_threshold", 16 + ): + log.debug( + "Skip config %s because of register spilling: %d", + launcher.config, + launcher.n_spills, + ) + return float("inf") + + device_interface = self.get_device_interface() + stream = device_interface.get_raw_stream(device_interface.current_device()) + + cpu_copies = self.copy_args_to_cpu_if_needed(*args, **kwargs) + + def kernel_call(): + cloned_args, cloned_kwargs = self.maybe_clone_args( + cpu_copies, *args, **kwargs + ) + # reset to zero before evaluating any config + self.reset_to_zero_args(*args, **kwargs) + args_with_constexprs = self._get_args_with_constexprs(cloned_args, launcher) + launcher( + *args_with_constexprs, + **cloned_kwargs, + stream=stream, + ) + self.restore_args_from_cpu(cpu_copies) + + if with_profiler: + from torch._inductor.utils import do_bench_using_profiling + + return do_bench_using_profiling(kernel_call, warmup=10, rep=40) + + if self.device_props.type == "cpu": + return benchmarker.benchmark_cpu(kernel_call) + + return benchmarker.benchmark_gpu(kernel_call, rep=40) + + def copy_args_to_cpu_if_needed(self, *args, **kwargs): + """ + To support benchmarking in the presence of mutated args, we need to avoid + autotuning contanminating them. We try to pass cloned args to the kernel. + If those clones would increase the peak memory usage, however, we instead + copy to cpu and restore them after each iteration. Figure out the args + to be copied and do the copying. + """ + if not self.optimize_mem: + return {} + + copies = {} + budget = torch.cuda.max_memory_allocated() - torch.cuda.memory_allocated() + + def maybe_copy(name, arg): + if name in self.mutated_arg_names and arg.is_cuda: + nonlocal budget + assert isinstance(arg, torch.Tensor) + required_storage_length = compute_required_storage_length( + arg.size(), + arg.stride(), + 0, + ) + size = required_storage_length * arg.element_size() + if size > budget: + cpu_arg = torch.empty_strided( + (required_storage_length,), + (1,), + dtype=arg.dtype, + device="cpu", + pin_memory=True, + ) + cpu_arg.copy_( + arg.as_strided((required_storage_length,), (1,)), + non_blocking=True, + ) + copies[name] = (arg, cpu_arg) + else: + budget -= size + + for name, arg in zip(self.fn.arg_names, args): + maybe_copy(name, arg) + + for name, arg in kwargs.items(): + maybe_copy(name, arg) + + return copies + + def restore_args_from_cpu(self, cpu_copies): + for pair in cpu_copies.values(): + arg, cpu_arg = pair + required_storage_length = compute_required_storage_length( + arg.size(), + arg.stride(), + 0, + ) + arg.as_strided((required_storage_length,), (1,)).copy_( + cpu_arg, non_blocking=True + ) + + def reset_to_zero_args(self, *args, **kwargs): + if not self.reset_to_zero_arg_names: + return + for i, arg in enumerate(args): + if self.fn.arg_names[i] in self.reset_to_zero_arg_names: + assert isinstance( + arg, + torch.Tensor, + ), ( + "self.reset_to_zero_arg_names should only contain valid argument names" + ) + arg.zero_() + + for name, arg in kwargs.items(): + if name in self.reset_to_zero_arg_names: + assert isinstance( + arg, + torch.Tensor, + ), ( + "self.reset_to_zero_arg_names should only contain valid argument names" + ) + arg.zero_() + + def maybe_clone_args( + self, exclude: Container[str], *args, **kwargs + ) -> tuple[list[Any], dict[str, Any]]: + """ + Prepare new args and kwargs by cloning any in-place buffers + (that are not in the provided exclusion list), to avoid autotune + contaminating them. 
Avoid cloning the other buffers because it + leads to increased memory usage. + """ + from ..compile_fx import clone_preserve_strides + + def prepare_arg(name, arg): + if name in self.mutated_arg_names and name not in exclude: + assert isinstance(arg, torch.Tensor) + return clone_preserve_strides(arg) + else: + return arg + + cloned_args = [ + prepare_arg(name, arg) + for name, arg in itertools.zip_longest(self.fn.arg_names[: len(args)], args) + ] + cloned_kwargs = {name: prepare_arg(name, arg) for name, arg in kwargs.items()} + return cloned_args, cloned_kwargs + + def clone_args(self, *args, **kwargs) -> tuple[list[Any], dict[str, Any]]: + return self.maybe_clone_args(OrderedSet(), *args, **kwargs) + + def benchmark_all_configs(self, *args, **kwargs): + with ( + dynamo_timed( + "CachingAutotuner.benchmark_all_configs", + log_pt2_compile_event=True, + metadata={"kernel_name": self.inductor_meta.get("kernel_name")}, + dynamo_compile_column_us="runtime_triton_autotune_time_us", + compile_id=self.compile_id, + is_backward=self.is_backward, + log_waitcounter=True, + waitcounter_name_override="triton_autotuner", + ), + compilation_callback.callback_handler.install_callbacks( + compilation_callback.CallbackTrigger.TRITON_AUTOTUNING, + str(self.compile_id), + ), + ): + timings = { + launcher: self.bench(launcher, *args, **kwargs) + for launcher in self.launchers + } + + for k, v in timings.items(): + self.coordesc_tuner.cache_benchmark_result(k.config, v) + + if log.isEnabledFor(logging.DEBUG): + log.debug("Benchmark all input configs for %s, get:", self.fn.__name__) + for k, v in timings.items(): + log.debug( + "%s: %f, nreg %d, nspill %d, #shared-mem %s", + k.config, + v, + k.n_regs, + k.n_spills, + k.shared, + ) + + self.reset_to_zero_args(*args, **kwargs) + return timings + + def autotune_to_one_config(self, *args, **kwargs): + """Do the actual autotuning""" + start_time = time.time_ns() + timings = self.benchmark_all_configs(*args, **kwargs) + benchmark_time_taken_ns = time.time_ns() - start_time + self.launchers = [builtins.min(timings, key=timings.get)] + self.autotune_time_taken_ns = ( + self.precompile_time_taken_ns + benchmark_time_taken_ns + ) + + # log the best config + launcher = self.launchers[0] + log.debug( + "Best config for %s: %s: %f, nreg %d, nspill %d, #shared-mem %s", + self.fn.__name__, + launcher.config, + timings[launcher], + launcher.n_regs, + launcher.n_spills, + launcher.shared, + ) + + if self.save_cache_hook: + self.save_cache_hook( + launcher.config, + self.autotune_time_taken_ns, + triton_cache_hash=launcher.cache_hash, + ) + + def save_gpu_kernel(self, stream, launcher): + key = self.inductor_meta.get("kernel_name", None) # unique kernel name + assert key is not None, "kernel_name can not be None" + params = { + "mangled_name": ( + launcher.bin.metadata.name + if hasattr(launcher.bin.metadata, "name") + else launcher.bin.metadata["name"] + ), + "num_warps": ( + launcher.bin.num_warps + if hasattr(launcher.bin, "num_warps") + else launcher.bin.metadata.num_warps + ), + "shared_mem": ( + launcher.bin.shared + if hasattr(launcher.bin, "shared") + else launcher.bin.metadata.shared + ), + "stream": stream, + # User defined triton kernels will have arbitrary kwarg names + "config": config_to_dict(launcher.config), + "inductor_meta": self.inductor_meta, + "triton_meta": self.triton_meta, + "def_args": launcher.def_args, + "call_args": launcher.call_args, + "global_scratch": launcher.global_scratch, + } + from torch._inductor.codecache import CudaKernelParamCache + + 
bin_type = {"hip": "hsaco", "xpu": "spv"}.get(self.device_props.type, "cubin") + binary = launcher.bin.asm[bin_type] + # Also store asm code which can be used for debugging and generating cpp package + asm_type = {"hip": "amdgcn", "cuda": "ptx", "xpu": "spv"}.get( + self.device_props.type, None + ) + asm = launcher.bin.asm.get(asm_type, None) + + CudaKernelParamCache.set(key, params, binary, bin_type, asm, asm_type) + self.cuda_kernel_saved = True + + def coordinate_descent_tuning(self, launcher, *args, **kwargs): + """ + Coordinate descent tuning can be run with or without max-autotune. + + The only difference between these two is the starting config for coordinate_descent tuning. + E.g., assuming regular autotune only get one config C1; while max-autotune get 4 configs C1, C2, C3, C4 + and max-autotune figure out C3 is the best. + + Then if coordinate desecnt tuning is run with max-autotune disabled, it will start from C1; + while if coordinate descent tuning is run with max-autotune enabled, it will start from C3. + """ + if ( + self.heuristic_type == HeuristicType.TEMPLATE + or self.heuristic_type == HeuristicType.USER_AUTOTUNE + ): + # skip triton template + return launcher + + config2launcher = {launcher.config: launcher} + + # TODO: should we just load the kernels ahead of time if we know we're going to call this? + if self.fn.fn is None: + """ + We are in the parent process, while this program was compiled in a worker + and the fn was dropped in prepare_for_pickle(). We haven't loaded the module + containing the real fn yet. + """ + assert hasattr(self, "_reload_kernel") + assert callable(self._reload_kernel) + self.fn = self._reload_kernel().fn + + def benchmark_one_config(config): + with self.lock: + launcher = self._precompile_config(config).make_launcher() + config2launcher[config] = launcher + + out = self.bench(launcher, *args, **kwargs) + log.debug( + "COORDESC: %s: %f, nreg %d, nspill %d, #shared-mem %d", + launcher.config, + out, + launcher.n_regs, + launcher.n_spills, + launcher.shared, + ) + return out + + assert not ( + self.heuristic_type == HeuristicType.PERSISTENT_REDUCTION + and "R0_BLOCK" in launcher.config.kwargs + ), ( + "Coordinate descent tuner relies on the assumption that persistent reduction's triton config does not have R0_BLOCK" + ) + start_time = time.time_ns() + best_config = self.coordesc_tuner.autotune( + benchmark_one_config, launcher.config, None + ) + coordesc_time_taken_ns = time.time_ns() - start_time + best_config.found_by_coordesc = True + + if self.save_cache_hook: + self.save_cache_hook( + best_config, + self.autotune_time_taken_ns + coordesc_time_taken_ns, + found_by_coordesc=True, + ) + + if best_config not in config2launcher: + # On a Coordesc cache hit, we might not have loaded the launcher + # This can happen because PyCodeCache saves CachingAutotuners in memory, + # even for separate compile IDs (which can have different inputs without changing output code) + config2launcher[best_config] = self._precompile_config( + best_config + ).make_launcher() + return config2launcher[best_config] + + def run( + self, + *args, + stream, + benchmark_run=False, + **kwargs, + ): # type:ignore[override] + if hasattr(triton, "set_allocator"): + + def alloc_fn(size: int, align: int, stream: Optional[int]): + return torch.empty( + size, dtype=torch.int8, device=self.device_props.type + ) + + triton.set_allocator(alloc_fn) + + if self.triton_interpret: + args, grid = self._interpret_args_grid(args, self.configs[0]) + return self.fn[grid]( + *args, + 
**kwargs, + **self.configs[0].kwargs, + ) + + if len(self.launchers) != 1: + if len(self.launchers) == 0: + start_time = time.time_ns() + self.precompile() + self.precompile_time_taken_ns = time.time_ns() - start_time + if len(self.launchers) > 1: + self.autotune_to_one_config(*args, **kwargs) + + if not getattr( + self.launchers[0].config, "found_by_coordesc", False + ) and self.inductor_meta.get("coordinate_descent_tuning", False): + self.launchers = [ + self.coordinate_descent_tuning(self.launchers[0], *args, **kwargs) + ] + + (launcher,) = self.launchers + if launcher.store_cubin and (not benchmark_run or not self.cuda_kernel_saved): + self.save_gpu_kernel(stream, launcher) + + args = self._get_args_with_constexprs(args, launcher) + + if self.dump_launch_params: + new_args, grid = self._interpret_args_grid(args, launcher.config) + _dump_launch_params(new_args, kwargs, launcher, self.fn.__name__, grid) + + # it is faster than entering and exiting a context manager, even if the context + # manager is a nullcontext. + if autograd_profiler._is_profiler_enabled: + kernel_kwargs_str = ",".join( + f"{k}={v}" for (k, v) in launcher.config.kwargs.items() + ) + + profiler_kwargs = { + "kernel_file": (self.filename or ""), + "kernel_hash": self.kernel_hash, + "kernel_backend": "triton", + "stream": stream, + "num_warps": launcher.config.num_warps, + "num_stages": launcher.config.num_stages, + "kernel_kwargs": kernel_kwargs_str, + } + + with torch._C._profiler._RecordFunctionFast( + self.inductor_meta.get("kernel_name", "triton kernel"), + args, + profiler_kwargs, + ): + return launcher( + *args, + **kwargs, + stream=stream, + ) + else: + return launcher( + *args, + **kwargs, + stream=stream, + ) + + def _interpret_args_grid( + self, args: tuple[Any, ...], cfg: Config + ) -> tuple[tuple[Any, ...], tuple[int, int, int]]: + grid = GridExpr.from_meta(self.inductor_meta, cfg).eval_slow( + dict( + zip( + [ + *self.triton_meta["signature"].keys(), + *self.inductor_meta.get("extra_launcher_args", ()), + ], + args, + ) + ) + ) + if self.inductor_meta.get("extra_launcher_args"): + args = args[: -len(self.inductor_meta["extra_launcher_args"])] + return args, grid + + +class _ConstRepr: + def __init__(self, value: str): + self.value = value + + def __call__(self, _=None) -> str: + return self.value + + +class CompileResult(Generic[_T]): + def __init__( + self, + kernel: _T, + config: Config, + compile_meta: dict[str, Any], + inductor_meta: dict[str, Any], + ): + self.kernel = kernel + self.config = config + self.compile_meta = compile_meta + self.inductor_meta = inductor_meta + + def make_launcher(self) -> LauncherType: ... + + def _gen_launcher_code(self, scope, def_args, runner_args) -> LauncherType: + grid = GridExpr.from_meta(self.inductor_meta, self.config) + # grid.prefix is usually empty, grid.x_grid is something like `-(xnumel//-1024)` + lines = [ + f"def launcher({', '.join(def_args)}, stream):", + *[f" {line}" for line in grid.prefix], + f" grid_0 = {grid.x_grid}", + f" grid_1 = {grid.y_grid}", + f" grid_2 = {grid.z_grid}", + f" runner({', '.join(runner_args)})", + ] + launcher_code = "\n".join(lines) + exec(launcher_code, scope) + return scope["launcher"] + + def _get_arg_lists( + self, arg_names, constexprs + ) -> tuple[list[str], list[str], OrderedSet[str]]: + """ + Return a bunch of intermediate lists of args needed for generating + launcher code. 
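+
+        Returns ``(call_args, def_args, none_args)``: the args passed on to the
+        compiled kernel, the parameters of the generated ``launcher(...)``
+        signature, and the args that were constant-folded to ``None`` and removed
+        from the signature.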
+ """ + compile_meta = self.compile_meta + cfg = self.config + known_constants = OrderedSet( + arg for i, arg in enumerate(arg_names) if i in constexprs + ) + + """ + https://github.com/pytorch/pytorch/issues/115344 + + self.fn.constexprs doesn't properly deal with None args, so when we filter out + an arg in UserDefinedTritonKernel.codegen, we need to filter it here as well. + We also don't want to modify self.fn. + + We know that we removed something from the signature if: + 1. It's in compile_meta["constants"] + 2. It isn't a constant we already know about + Note: The value of interest has already been added to compile_meta['constants'], + so we use self.fn.constexprs instead. + 3. It isn't in the compile_meta signature + """ + none_args = OrderedSet( + k + for k, v in compile_meta["constants"].items() + if v is None and k not in known_constants + ) + none_args = none_args.difference(OrderedSet(compile_meta["signature"].keys())) + + if triton_version_uses_attrs_dict(): + call_args = arg_names + def_args = arg_names + if ( + "num_warps" in compile_meta["constants"] + or "num_stages" in compile_meta["constants"] + ): + # num_warps/num_stages are special implicit args that are not in the signature + # see test_triton_kernel_special_params + def_args = [ + arg for arg in def_args if arg not in ("num_warps", "num_stages") + ] + repl = { + k: str(compile_meta["constants"].get(k)) + for k in ("num_warps", "num_stages") + } + call_args = [repl.get(arg, arg) for arg in call_args] + else: + call_args = [ + arg + for i, arg in enumerate(arg_names) + if i not in constexprs and arg not in none_args + ] + cfg_dict = config_to_dict(cfg) + def_args = [ + name + for name in arg_names + if name not in cfg_dict and name not in none_args + ] + + if "extra_launcher_args" in self.inductor_meta: + def_args = [*def_args, *self.inductor_meta["extra_launcher_args"]] + + return call_args, def_args, none_args + + +class CannotStaticallyLaunchKernel(Exception): + pass + + +class StaticTritonCompileResult(CompileResult[StaticallyLaunchedCudaKernel]): + """ + TritonCompileResult that uses StaticCudaLauncher, + which vastly simplifies the setup and metadata needed to be kept. 
+ """ + + @staticmethod + def can_statically_launch( + kernel: CompiledKernel, + inductor_meta: dict[str, Any], + triton_meta: dict[str, Any], + heuristic_type: HeuristicType, + ) -> Optional[StaticallyLaunchedCudaKernel]: + if not torch._inductor.config.use_static_cuda_launcher: + return None + + def check_can_launch() -> StaticallyLaunchedCudaKernel: + if triton_meta.get("device_type", None) != "cuda": + # Only cuda kernels + raise CannotStaticallyLaunchKernel("Non-cuda device") + + if torch._inductor.config.cpp_wrapper: + # If we're running with cpp wrapper, it doesn't + # make sense to statically compile since everything + # is codegenned anyway + raise CannotStaticallyLaunchKernel("Cpp wrapper enabled") + + if ( + heuristic_type == HeuristicType.USER_AUTOTUNE + and not torch._inductor.config.static_launch_user_defined_triton_kernels + ): + # Don't support user defined triton kernels yet + raise CannotStaticallyLaunchKernel("User defined triton kernel") + + if inductor_meta.get("store_cubin", None): + # Requires storing the entire binary + raise CannotStaticallyLaunchKernel("store_cubin is enabled") + + cubin_location = os.path.join( + triton_cache_dir(triton_meta.get("device", 0)), + triton_hash_to_path_key(kernel.hash), + f"{kernel.src.fn.__name__}.cubin", + ) + + if not os.path.exists(cubin_location): + raise CannotStaticallyLaunchKernel( + f"Cubin path not found: {cubin_location}" + ) + + else: + kernel._cubin_path = cubin_location + + try: + static_kernel = StaticallyLaunchedCudaKernel(kernel) + except NotImplementedError as e: + raise CannotStaticallyLaunchKernel(f"NotImplemented: {str(e)}") from e + + return static_kernel + + try: + result = check_can_launch() + return result + except CannotStaticallyLaunchKernel as e: + log.info("Bypassing StaticallyLaunchedCudaKernel due to %s", str(e)) + if torch._inductor.config.strict_static_cuda_launcher: + raise e + return None + + def reload_cubin_path(self): + """ + When loading from cache on disk, we want to reload cubin + files from their appropriate location on disc. + """ + cubin_location = os.path.join( + triton_cache_dir(self.compile_meta.get("device", 0)), + triton_hash_to_path_key(self.kernel.hash), + f"{self.kernel.name}.cubin", + ) + if not os.path.exists(cubin_location): + if self.kernel.cubin_raw is not None: + # We saved the raw cubin, so write it to he appropriate location + self.kernel.reload_cubin_from_raw(cubin_location) + else: + raise RuntimeError( + "Cubin file saved by TritonBundler not found at %s", cubin_location + ) + self.kernel.cubin_path = cubin_location + + def make_launcher(self) -> LauncherType: + # If at least one static make_launcher call occurs, + # we're sure static cuda launcher was used for this compile + set_feature_use("static_cuda_launcher", True) + # Load the binary on the parent + if not self.kernel.cubin_path: + self.reload_cubin_path() + device = self.compile_meta.get("device", 0) + if device is None: + device = 0 + self.kernel.load_kernel(device) + scope = { + "runner": self.kernel.run, + } + + # NOTE: Constexpr handling for triton and static cuda launcher + + # Triton kernels have two types of constexprs: *declared* ones, which are ones the user + # has explicitly declared as tl.constexpr, and *implied* ones, which are expressions triton + # deems constant while compiling/analyzing the code (i.e. unused parameters, for example) + + # Triton kernels handle constexprs slightly differently depending on which version of triton + # we care about (we support 3.2.0 and 3.3.0). 
+ + # In 3.2.0, triton kernels do not require passing any declared constexprs into the kernel + # In 3.3.0, triton kernels require all declared constexprs be passed into the kernel, where + # they are subsequently ignored. + # When statically launching, since we're launching from the triton generated cubin, we actually want to + # always get rid of all const exprs, declared or implied, since the underlying cubin file has all + # of the constants stripped away anyway. + + # But CachingAutotuner.run will pass us a different number of arguments depending on + # whether or not we're in triton 3.2.0 or later, so we grab def_args with the same logic + # as the (non static) TritonCompileResult. We then generate call_args ourselves, since we + # want only a subset of the arguments passed to triton. + # Here, arg_names is exactly fn.src.arg_names and declared_constexprs is exactly fn.src.constexprs, + # which matches behavior with regular TritonCompileResult + _, def_args, none_args = self._get_arg_lists( + self.kernel.arg_names, self.kernel.declared_constexprs + ) + + call_args = [ + arg + for i, arg in enumerate(self.kernel.arg_names) + if i not in self.kernel.full_constexprs and arg not in none_args + ] + + # StaticallyLaunchedCudaKernel.run takes in order grid_0, grid_1, grid_2, stream, and call_args + runner_args = ["grid_0", "grid_1", "grid_2", "stream", *call_args] + launcher = self._gen_launcher_code(scope, def_args, runner_args) + launcher.config = self.config # type: ignore[attr-defined] + launcher.n_regs = self.kernel.n_regs # type: ignore[attr-defined] + launcher.n_spills = self.kernel.n_spills # type: ignore[attr-defined] + launcher.shared = self.kernel.shared # type: ignore[attr-defined] + launcher.cache_hash = triton_hash_to_path_key(self.kernel.hash) # type: ignore[attr-defined] + launcher.store_cubin = False # type: ignore[attr-defined] + launcher._is_static = True # type: ignore[attr-defined] + return launcher + + +class TritonCompileResult(CompileResult[CompiledKernel]): + """ + Upstream Triton CompileKernel can not be pickled. This is a wrapper + to support serialization and generate the launcher function. + """ + + @staticmethod + @functools.lru_cache(32) + def _kernel_metadata_cls(fields: tuple[str, ...]) -> Any: + return namedtuple("KernelMetadata", sorted(fields)) + + @staticmethod + def _serialize_metadata(metadata): + """ + Triton uses a nested class called KernelMetadata to store metadata information. + Pickle does not work well with nested namedtuples, as the namedtuple doesn't appear + in the toplevel namespace of the module. So these serialization/deser functions + are used to convert the namedtuples to a dict and back. + + As for packed_metadata, depending on the triton backend, KernelMetadata can be + a namedtuple, or a regular tuple! So the serialization function branches on whether + the metadata to be serialized is a namedtuple or regular, serializable one. 
+ """ + + def is_namedtuple(obj) -> bool: + return ( + isinstance(obj, tuple) + and hasattr(obj, "_asdict") + and hasattr(obj, "_fields") + ) + + if is_namedtuple(metadata): + return metadata._asdict() + else: + return metadata + + @staticmethod + def _deserialize_metadata(metadata): + if isinstance(metadata, dict): + return TritonCompileResult._kernel_metadata_cls(tuple(metadata.keys()))( + **metadata + ) + else: + return metadata + + def __getstate__(self) -> dict[str, Any]: + kernel = self.kernel + # replace the fields that don't pickle nicely + kernel_state = { + **kernel.__dict__, + # See doc about serializing metadata above + "metadata": self._serialize_metadata(kernel.metadata), + "packed_metadata": self._serialize_metadata( + getattr(kernel, "packed_metadata", None) + ), + "module": None, # regenerated by kernel._init_handles() + "function": None, # regenerated by kernel._init_handles() + "run": None, # regenerated by kernel._init_handles() + } + return {**self.__dict__, "kernel": kernel_state} # type: ignore[dict-item] + + def __setstate__(self, state: dict[str, Any]) -> None: + # src = ASTSource.__new__(ASTSource) + # src.__setstate__(state["kernel"]["src"]) + # TODO(jansel): need to fixup src.fn which is now None + kernel = CompiledKernel.__new__(CompiledKernel) + metadata = state["kernel"]["metadata"] + packed_metadata = state["kernel"]["packed_metadata"] + kernel.__dict__.update( + { + **state["kernel"], + # "src": src, + "metadata": self._deserialize_metadata(metadata), + "packed_metadata": self._deserialize_metadata(packed_metadata), + } + ) + self.__dict__.update(state) + self.kernel = kernel + + def make_launcher(self) -> LauncherType: + """ + Launching triton kernels is performance sensitive, we compile + a custom Python function get the grid() and reorder the args to + the underlying wrapper. + """ + cfg = self.config + compile_meta = self.compile_meta + binary = self.kernel + fn = binary.src.fn + binary._init_handles() + (call_args, def_args, none_args) = self._get_arg_lists( + fn.arg_names, fn.constexprs + ) + binary_shared = ( + binary.shared if hasattr(binary, "shared") else binary.metadata.shared + ) + + if knobs is None: + launch_enter = binary.__class__.launch_enter_hook + launch_exit = binary.__class__.launch_exit_hook + else: + launch_enter = knobs.runtime.launch_enter_hook + launch_exit = knobs.runtime.launch_exit_hook + + import math as math_lib + + import torch as torch_lib + + scope = { + "grid_meta": cfg.kwargs, + "bin": binary, + "launch_enter_hook": launch_enter, + "launch_exit_hook": launch_exit, + "metadata": ( + binary.packed_metadata + if hasattr(binary, "packed_metadata") + else binary.metadata + ), + "shared": binary_shared, + "num_warps": ( + binary.num_warps + if hasattr(binary, "num_warps") + else binary.metadata.num_warps + ), + "cta_args": ( + ( + binary.num_ctas, + *get_first_attr(binary, "cluster_dims", "clusterDims"), + ) + if hasattr(binary, "num_ctas") + else ( + (binary.metadata.num_ctas, *binary.metadata.cluster_dims) + if hasattr(binary, "metadata") + else () + ) + ), + "function": get_first_attr(binary, "function", "cu_function"), + "runner": get_first_attr(binary, "run", "c_wrapper"), + "math": math_lib, + "torch": torch_lib, + } + + if not hasattr(binary, "launch_metadata"): + # launch args before CompiledKernel.launch_metadata is added. 
+ # TODO(jansel): delete this branch in mid-2025 + runner_args = [ + "grid_0", + "grid_1", + "grid_2", + "num_warps", + "*cta_args", + "shared", + "stream", + "function", + "launch_enter_hook", + "launch_exit_hook", + "metadata", + *call_args, + ] + else: # args after CompiledKernel.launch_metadata: https://github.com/triton-lang/triton/pull/3492 + # Getting the kernel launch args is extremely perf-sensitive. Evaluating + # `bin.launch_metadata` is relatively expensive, and returns None unless a + # `launch_enter_hook` is installed. So if we don't have that hook installed, + # we want to burn None in to the launch args with zero overhead. + # See https://github.com/pytorch/pytorch/issues/123597 + if launch_enter: + launch_metadata = f"bin.launch_metadata((grid_0, grid_1, grid_2), stream, {', '.join(call_args)})" + else: + launch_metadata = "None" + runner_args = [ + "grid_0", + "grid_1", + "grid_2", + "stream", + "function", + "metadata", + launch_metadata, + "launch_enter_hook", + "launch_exit_hook", + *call_args, + ] + + launcher = self._gen_launcher_code(scope, def_args, runner_args) + + launcher = scope["launcher"] + launcher.config = cfg + launcher.n_regs = getattr(binary, "n_regs", None) + launcher.n_spills = getattr(binary, "n_spills", None) + launcher.shared = binary_shared + launcher.cache_hash = triton_hash_to_path_key(binary.hash) + launcher.store_cubin = self.inductor_meta.get("store_cubin", False) + # store this global variable to avoid the high overhead of reading it when calling run + if launcher.store_cubin: + launcher.fn = fn + launcher.bin = binary + if triton_version_uses_attrs_dict(): + # arg filtering wasn't done above + cfg_dict = config_to_dict(cfg) + def_args = [x for x in def_args if x not in cfg_dict] + call_args = [ + x + for x in call_args + if compile_meta["signature"].get(x, "constexpr") != "constexpr" + and x not in none_args + ] + launcher.def_args = def_args + launcher.call_args = call_args + kernel_metadata = getattr(self.kernel, "metadata", None) + launcher.global_scratch = getattr( + kernel_metadata, "global_scratch_size", None + ) + return launcher + + +def _find_names(obj): + import gc + import inspect + + frame = inspect.currentframe() + while frame is not None: + frame.f_locals + frame = frame.f_back + obj_names = [] + for referrer in gc.get_referrers(obj): + if isinstance(referrer, dict): + for k, v in referrer.items(): + if v is obj: + obj_names.append(k) + return obj_names + + +collected_calls: list[Any] = [] + + +def start_graph(): + collected_calls.clear() + + +def end_graph(output_file): + if len(collected_calls) == 0: + return + overall_time = sum(call[0] for call in collected_calls) + overall_gb = sum(call[1] for call in collected_calls) + cur_file = inspect.stack()[1].filename + summary_str = ( + f"SUMMARY ({cur_file})\n" + f"{overall_time:.2f}ms \t {overall_gb:.2f} GB\t {overall_gb / (overall_time / 1e3):.2f}GB/s" + ) + log.info( + "%s", + summary_str, + ) + if output_file is not None: + # sort perf numbers in descending order, i.e. 
placing the + # most runtime-heavy kernels at the top of the list + sorted_calls = sorted(collected_calls, key=lambda c: float(c[0]), reverse=True) + try: + with open(output_file, "a") as file: + log.info( + "Save profile bandwidth results to %s", + output_file, + ) + file.write("====================\n") + file.write(f"TRITON KERNELS BANDWIDTH INFO ({cur_file})\n") + for ms, num_gb, gb_per_s, kernel_name in sorted_calls: + # also display the runtime percentage for each kernel + percentage = f"{ms / overall_time * 100:.2f}%" + suffix = f" \t {percentage} \t {kernel_name}" + bw_info_str = create_bandwidth_info_str( + ms, + num_gb, + gb_per_s, + suffix=suffix, + color=False, + ) + file.write(bw_info_str + "\n") + file.write(f"{summary_str}\n\n") + except Exception as e: + log.warning( + "failed to write profile bandwidth result into %s: %s", + output_file, + e, + ) + + +class DebugAutotuner(CachingAutotuner): + def __init__( + self, + *args, + regex_filter="", + with_profiler=False, + with_bandwidth_info=True, + **kwargs, + ): + self.regex_filter = regex_filter + self.with_profiler = with_profiler + self.with_bandwidth_info = with_bandwidth_info + super().__init__(*args, **kwargs) + self.cached = None + + def run(self, *args, stream, **kwargs): + if not self.with_bandwidth_info: + super().run(*args, stream=stream, **kwargs, benchmark_run=True) + return + else: + possible_names = _find_names(self) + kernel_name = f"{max(possible_names, key=len)}" + if not re.match(self.regex_filter, kernel_name): + return + + if len(self.launchers) != 1: + if len(self.launchers) == 0: + start_time = time.time_ns() + self.precompile() + self.precompile_time_taken_ns = time.time_ns() - start_time + if len(self.launchers) > 1: + self.autotune_to_one_config(*args, **kwargs) + (launcher,) = self.launchers + + if launcher.store_cubin: + self.save_gpu_kernel(stream, launcher) + + if self.cached is None: + ms = self.bench(launcher, *args, with_profiler=self.with_profiler) + num_in_out_ptrs = len( + [ + arg_name + for arg_name in self.fn.arg_names + if arg_name.startswith("in_out_ptr") + ] + ) + num_gb = self.inductor_meta.get("kernel_num_gb", None) + if num_gb is None: + num_gb = get_num_bytes(*args, num_in_out_args=num_in_out_ptrs) / 1e9 + gb_per_s = num_gb / (ms / 1e3) + self.cached = ms, num_gb, gb_per_s, kernel_name + collected_calls.append((ms, num_gb, gb_per_s, kernel_name)) + log.info( + "%s", + create_bandwidth_info_str( + ms, num_gb, gb_per_s, suffix=f" \t {kernel_name}" + ), + ) + else: + # in AOTI, we will call the kernel and its timing info has been cached already + collected_calls.append(self.cached) + + +def hash_configs(configs: list[Config]): + """ + Hash used to check for changes in configurations + """ + hasher = hashlib.sha256() + for cfg in configs: + hasher.update( + f"{sorted(cfg.kwargs.items())} {cfg.num_warps} {cfg.num_stages}\n".encode() + ) + return hasher.hexdigest() + + +def cached_autotune( + size_hints: Optional[list[int]], + configs: list[Config], + triton_meta, + heuristic_type, + filename=None, + inductor_meta=None, + custom_kernel=False, +): + """ + A copy of triton.autotune that calls our subclass. Our subclass + has additional debugging, error handling, and on-disk caching. 
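+ + A typical (hypothetical) call site looks like + cached_autotune(size_hints, configs, triton_meta=..., heuristic_type=HeuristicType.POINTWISE, filename=...), + which returns a decorator that wraps the Triton kernel in a CachingAutotuner + (or in a DebugAutotuner when profile_bandwidth is enabled in inductor_meta).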
+ """ + configs = unique_configs(configs) + assert len(configs) == 1 or filename + inductor_meta = {} if inductor_meta is None else inductor_meta + + configs, autotune_cache, autotune_cache_info = check_autotune_cache( + configs, filename, inductor_meta + ) + mutated_arg_names = inductor_meta.pop("mutated_arg_names", ()) + optimize_mem = inductor_meta.pop("optimize_mem", True) + + if "restore_value" in triton_meta: + mutated_arg_names += triton_meta.pop("restore_value") + + reset_to_zero_arg_names: list[str] = [] + if "reset_to_zero" in triton_meta: + reset_to_zero_arg_names.extend(triton_meta.pop("reset_to_zero")) + + def decorator(fn): + # Remove XBLOCK from config if it's not a function argument. + # This way, coordinate descent tuning will not try to tune it. + # + # Context: When TritonKernel.no_x_dim is True, we hardcode XBLOCK to 1. + import inspect + + if "XBLOCK" not in inspect.signature(fn.fn).parameters: + for tconfig in configs: + if "XBLOCK" in tconfig.kwargs: + assert tconfig.kwargs["XBLOCK"] == 1 + tconfig.kwargs.pop("XBLOCK") + + if inductor_meta.get("profile_bandwidth"): + return DebugAutotuner( + fn, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + regex_filter=inductor_meta["profile_bandwidth_regex"], + with_profiler=inductor_meta[ + "profile_bandwidth_with_do_bench_using_profiling" + ], + configs=configs, + save_cache_hook=autotune_cache and autotune_cache.save, + mutated_arg_names=mutated_arg_names, + reset_to_zero_arg_names=reset_to_zero_arg_names, + optimize_mem=optimize_mem, + heuristic_type=heuristic_type, + size_hints=size_hints, + custom_kernel=custom_kernel, + filename=filename, + with_bandwidth_info=True, + ) + return CachingAutotuner( + fn, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + configs=configs, + save_cache_hook=autotune_cache and autotune_cache.save, + mutated_arg_names=mutated_arg_names, + reset_to_zero_arg_names=reset_to_zero_arg_names, + optimize_mem=optimize_mem, + heuristic_type=heuristic_type, + size_hints=size_hints, + custom_kernel=custom_kernel, + filename=filename, + autotune_cache_info=autotune_cache_info, + ) + + return decorator + + +def unique_configs(configs: list[Config]): + """Remove duplicate configurations""" + seen: OrderedSet[Hashable] = OrderedSet() + pruned_configs = [] + + for cfg in configs: + key = triton_config_to_hashable(cfg) + if key not in seen: + seen.add(key) + pruned_configs.append(cfg) + return pruned_configs + + +def check_config(cfg, *, xnumel=None, ynumel=None, znumel=None): + for numel, label in zip((xnumel, ynumel, znumel), "XYZ"): + if numel is None: + continue + block = cfg[f"{label}BLOCK"] + if numel == 1: + assert block == 1, ( + f"TritonKernel.indexing assumes numel == 1 => BLOCK == 1" + f" but {label.lower()}numel=={numel} and {label}BLOCK={block} (cfg={cfg})." + ) + max_block = TRITON_MAX_BLOCK[label] + max_block_str = f'config.triton.max_block["{label}"]' + assert max_block % block == 0, ( + f"TritonKernel.indexing assumes {label}BLOCK divides {max_block_str}" + f" but {label}BLOCK={block} and {max_block_str}={max_block} (cfg={cfg})." + ) + + +def check_max_block(cfg: dict[str, int]): + """ + Check that block sizes are within the maximum allowed. + """ + for var, val in cfg.items(): + block_suffix = "BLOCK" + if block_suffix in var: + prefix = var.removesuffix(block_suffix) + max_block = TRITON_MAX_BLOCK[prefix] + assert val <= max_block, ( + f"'{var}' too large. Maximum: {max_block}. Actual: {val}." 
+ ) + + +def _num_warps(num_warps, max_num_warps=8, min_num_warps=2, register_intensive=False): + # On AMD GPU each warp has 64 lanes which is double the size on NV GPU, + # therefore using half the number of warps here correspondingly. + if torch.version.hip: + max_num_warps = (max_num_warps + 1) // 2 + min_num_warps = (min_num_warps + 1) // 2 + # persistent reduction is register intensive + if register_intensive: + max_num_warps = max_num_warps // 2 + return next_power_of_2(min(max(num_warps, min_num_warps), max_num_warps)) + + +def _check_max_grid_x(size_hints, x, num_warps): + # Check if maxGridSize is exceeded - if so then must scale XBLOCK further + max_grid_x = 2147483647 + warp_size = ( + 64 if torch.version.hip else 32 + ) # TODO: query warp size once #129663 is merged + num_blocks = (size_hints["x"] + x - 1) // x + + while (num_blocks * num_warps * warp_size) > max_grid_x and x < size_hints["x"]: + x *= 2 # Scale up XBLOCK if grid exceeds limits + num_blocks = num_blocks // 2 + if (num_blocks * num_warps * warp_size) > max_grid_x: + raise AssertionError( + "Reduction config exceeds cudaDeviceProp maxGridSize. Please raise a pytorch issue" + ) + return x, num_blocks + + +def triton_config( + size_hints, + x, + y=None, + z=None, + num_stages=1, + num_elements_per_warp=256, + min_elem_per_thread=0, +) -> Config: + """ + Construct a pointwise triton config with some adjustment heuristics + based on size_hints. Size_hints is a tuple of numels in each tile + dimension and will be rounded up to the nearest power of 2. + + num_elements_per_warp is a suggestion for controlling how many warps + the triton config should contain. e.g.: if x=16, y=8, z=4 then + num_elements = 16*8*4 = 512. Then if we set num_elements_per_warp=128, + we'll launch 512 (elem) / 128 (elem/warp) = 4 warps. Note that it's + just a suggestion, and sometimes other adjustment heuristics will + override the num_elements_per_warp. + + min_elem_per_thread controls the minimum number of elements + processed by each thread. It's always enforced. + """ + # Ideally we want to read this from some device config + + maxGridSize = [2147483647, 65535, 65535] + + target = conditional_product(x, y, z) + if conditional_product(*size_hints.values()) < target: + target //= 8 + + # shrink sizes to size hints + x = min(x, size_hints["x"]) + if y: + y = min(y, size_hints["y"]) + if z: + z = min(z, size_hints["z"]) + + # if we are below original block size, scale up where we can; + # or if the calculated grid size is larger than the limit, we bump up the corresponding dimension + while x < min(size_hints["x"], TRITON_MAX_BLOCK["X"]) and ( + x * maxGridSize[0] < size_hints["x"] or conditional_product(x, y, z) < target + ): + x *= 2 + while ( + y + and y < min(size_hints["y"], TRITON_MAX_BLOCK["Y"]) + and ( + y * maxGridSize[1] < size_hints["y"] + or conditional_product(x, y, z) < target + ) + ): + y *= 2 + while ( + z + and z < min(size_hints["z"], TRITON_MAX_BLOCK["Z"]) + and ( + z * maxGridSize[2] < size_hints["z"] + or conditional_product(x, y, z) < target + ) + ): + z *= 2 + + num_warps = _num_warps( + conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1 + ) + # we are going to arrive at 2 warps only if bs was too small due to + # numel being too small. 
However to workaround some ptx bugs we still + # want at least 4 warps if there's enough elements per thread + # given that this is a rare situation, don't expect this to affect perf + # in general + # see https://github.com/pytorch/pytorch/pull/97950 + if conditional_product(x, y, z) >= 128 and not torch.version.hip: + num_warps = max(num_warps, 4) + xnumel = size_hints["x"] + ynumel = size_hints.get("y") + znumel = size_hints.get("z") + + # Increase x to satisfy min_elem_per_thread requirements. + block_size = max( + conditional_product(x, y, z), + min_elem_per_thread * _NUM_THREADS_PER_WARP * num_warps, + ) + x *= math.ceil(block_size / conditional_product(x, y, z)) + + x, _num_blocks = _check_max_grid_x(size_hints, x, num_warps) + x = min(x, size_hints["x"]) + + cfg = {"XBLOCK": x} + if y: + cfg["YBLOCK"] = y + if z: + cfg["ZBLOCK"] = z + check_max_block(cfg) + check_config(cfg, xnumel=xnumel, ynumel=ynumel, znumel=znumel) + return Config(cfg, num_warps=num_warps, num_stages=num_stages) + + +def _get_nd_reduction_numels(r: int, size_hints: dict[str, int]) -> dict[str, int]: + """ + Converts a linear reduction numel to ND, in row major order. + This order is often desirable as it presents opportunities to coalesce memory + accesses. + For example, if r = 64 and size_hints = [32,32], this function returns [32, 2]. + This unraveling works because both r and size_hints are powers of 2. + """ + # Shrink r to size_hints. + r = min(r, get_total_reduction_numel(size_hints)) + num_reduction_dims = len( + [prefix for prefix in size_hints if prefix_is_reduction(prefix)] + ) + + remaining = r + rnumels = {} + for idx in range(num_reduction_dims - 1, -1, -1): + prefix = f"r{idx}_" + max_size = min(size_hints[prefix], TRITON_MAX_BLOCK[prefix.upper()]) + dim = min(max_size, remaining) + assert remaining % dim == 0, ( + f"Expected dimension '{dim}' to divide remaining size '{remaining}'" + ) + rnumels[prefix] = dim + remaining //= dim + + # Sanity check the results. + final_numel = conditional_product(*rnumels.values()) + assert r == final_numel, ( + f"Expected ND reduction size ({rnumels}) to have {r} elements." + ) + assert all(rnumels[prefix] <= size_hints[prefix] for prefix in rnumels), ( + f"rnumels exceed size_hints. {rnumels} > {size_hints}" + ) + + return rnumels + + +def triton_config_reduction( + size_hints, + x: int, + r: int, + num_stages=1, + num_warps=None, + register_intensive=False, +) -> Config: + """ + Construct a reduction triton config with some adjustment heuristics + based on size_hints. Size_hints is a tuple of numels in each tile + dimension and will be rounded up to the nearest power of 2. + """ + # Convert the linear reduction numel into a multi-dimensional block. 
+ rnumels = _get_nd_reduction_numels(r, size_hints) + + # shrink sizes to size hints + x = min(x, size_hints["x"]) + + def total_numel() -> int: + return conditional_product(x, *rnumels.values()) + + target = total_numel() + if conditional_product(*size_hints.values()) < target: + target //= 8 + + # if we are below original block size, scale up where we can + while x < size_hints["x"] and total_numel() < target: + x *= 2 + for prefix in sorted(rnumels): + while rnumels[prefix] < size_hints[prefix] and total_numel() < target: + rnumels[prefix] *= 2 + + if num_warps is None: + num_warps = total_numel() // 128 + num_warps = _num_warps( + num_warps, max_num_warps=16, register_intensive=register_intensive + ) + + x, _num_blocks = _check_max_grid_x(size_hints, x, num_warps) + + for prefix in sorted(rnumels): + while total_numel() > target: + if rnumels[prefix] == 1: + break + rnumels[prefix] //= 2 + + cfg = _get_config({"x": x, **rnumels}) + check_max_block(cfg) + check_config(cfg, xnumel=size_hints["x"]) + return Config(cfg, num_warps=num_warps, num_stages=num_stages) + + +def _get_config(numels: dict[str, int]) -> dict[str, int]: + """ + Convert numels ("x", "r0_", etc.) to block sizes ("XBLOCK", "R0_BLOCK"), etc. + """ + + return {prefix.upper() + "BLOCK": numel for prefix, numel in numels.items()} + + +def triton_config_tiled_reduction( + size_hints, x, y, r, num_stages=1, register_intensive=False +): + """ + Construct a tile reduction triton config with some adjustment + heuristics based on size_hints. Size_hints is a tuple of numels in + each tile dimension and will be rounded up to the nearest power of 2. + """ + # Convert the linear reduction numel into a multi-dimensional block. + rnumels = _get_nd_reduction_numels(r, size_hints) + + # shrink sizes to size hints + x = min(x, size_hints["x"]) + y = min(y, size_hints["y"]) + + def total_numel() -> int: + return conditional_product(x, y, *rnumels.values()) + + target = total_numel() + if conditional_product(*size_hints.values()) < target: + target //= 8 + + # if we are below original block size, scale up where we can + while x < size_hints["x"] and total_numel() < target: + x *= 2 + for prefix in sorted(rnumels): + while rnumels[prefix] < size_hints[prefix] and total_numel() < target: + rnumels[prefix] *= 2 + while y < size_hints["y"] and total_numel() < target: + y *= 2 + + cfg = _get_config({"x": x, "y": y, **rnumels}) + num_warps = _num_warps(total_numel() // 256, min_num_warps=1) + num_warps = _num_warps( + num_warps, max_num_warps=16, register_intensive=register_intensive + ) + check_config(cfg, xnumel=size_hints["x"], ynumel=size_hints["y"]) + check_max_block(cfg) + return Config(cfg, num_warps=num_warps, num_stages=num_stages) + + +def pointwise( + size_hints, + triton_meta, + tile_hint=None, + filename=None, + min_elem_per_thread=0, + inductor_meta=None, +): + """ + Construct @triton.heuristics() based on size_hints. 
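+ + Worked example with hypothetical sizes: for size_hints={"x": 65536}, + bs = max(256, min(65536 // 128, 1024)) = 512, so (unless pointwise autotuning is + disabled) the candidate configs are seeded with XBLOCK=512 at 256 elements per warp + and XBLOCK=256 at 64 elements per warp, plus any hinted configs.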
+ """ + inductor_meta = {} if inductor_meta is None else inductor_meta + assert not inductor_meta.get("no_x_dim") + + numel = functools.reduce(operator.mul, size_hints.values()) + bs = max(256, min(numel // 128, 1024)) + + hinted_configs = autotune_hints_to_configs( + inductor_meta.get("autotune_hints", OrderedSet()), + size_hints, + bs, + triton_meta["device"], + ) + + triton_config_with_settings = functools.partial( + triton_config, min_elem_per_thread=min_elem_per_thread + ) + + configs = None + if len(size_hints) == 1: + if disable_pointwise_autotuning(inductor_meta) and not ( + inductor_meta.get("max_autotune") + or inductor_meta.get("max_autotune_pointwise") + ): + configs = [triton_config_with_settings(size_hints, bs)] + else: + configs = [ + triton_config_with_settings(size_hints, bs, num_elements_per_warp=256), + triton_config_with_settings( + size_hints, bs // 2, num_elements_per_warp=64 + ), + *hinted_configs, + ] + if len(size_hints) == 2: + if ( + disable_pointwise_autotuning(inductor_meta) or tile_hint == TileHint.SQUARE + ) and not ( + inductor_meta.get("max_autotune") + or inductor_meta.get("max_autotune_pointwise") + ): + configs = [triton_config_with_settings(size_hints, 32, 32)] + else: + configs = [ + triton_config_with_settings(size_hints, 32, 32), + triton_config_with_settings(size_hints, 64, 64), # ~8% better for fp16 + triton_config_with_settings(size_hints, 256, 16), + triton_config_with_settings(size_hints, 16, 256), + triton_config_with_settings(size_hints, bs, 1), + triton_config_with_settings(size_hints, 1, bs), + *hinted_configs, + ] + if len(size_hints) == 3: + if disable_pointwise_autotuning(inductor_meta): + configs = [triton_config_with_settings(size_hints, 16, 16, 16)] + else: + configs = [ + triton_config_with_settings(size_hints, 16, 16, 16), + triton_config_with_settings(size_hints, 64, 8, 8), + triton_config_with_settings(size_hints, 8, 64, 8), + triton_config_with_settings(size_hints, 8, 8, 64), + triton_config_with_settings(size_hints, bs, 1, 1), + triton_config_with_settings(size_hints, 1, bs, 1), + triton_config_with_settings(size_hints, 1, 1, bs), + *hinted_configs, + ] + + if not configs: + raise NotImplementedError(f"size_hints: {size_hints}") + return cached_autotune( + size_hints, + configs, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.POINTWISE, + filename=filename, + ) + + +def _reduction_configs( + *, size_hints: dict[str, int], inductor_meta: dict[str, Any] +) -> list[Config]: + reduction_hint = inductor_meta.get("reduction_hint", None) + + # Convert reductions to 1D, to simplify heuristics. + rnumel = get_total_reduction_numel(size_hints) + + register_intensive = False + MAX_R0_BLOCK = 2048 + if ( + size_hints["x"] >= 1024 + and inductor_meta.get("num_load", 0) + inductor_meta.get("num_reduction", 0) + >= 10 + ): + # A heuristics to reduce R0_BLOCK if a kernel potentially need many registers. + # Consider load and reduction since load need move data into registers and + # reduction needs an accumulator. + # + # The magic numbers are a bit arbitrary. + # + # We cannot rely on dynamically scaling down R0_BLOCK later, since sometimes + # triton makes it to use less registers with worse perf. Check: + # https://github.com/pytorch/pytorch/issues/126463 + # + # The heuristic is a very simple one since registers can be reused. But + # hopefully it can be a good enough indicator. 
+ MAX_R0_BLOCK = 1024 + register_intensive = True + + def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False): + # For 3D case with tiling scores, create an adapted version + if "y" in size_hints: + assert "tiling_scores" in inductor_meta + return adapt_config_for_tiling( + size_hints, + inductor_meta["tiling_scores"], + x, + r, + num_warps=num_warps, + num_stages=num_stages, + register_intensive=register_intensive, + ) + else: + # For other cases, use the original function + return triton_config_reduction( + size_hints, + x, + r, + num_warps=num_warps, + num_stages=num_stages, + register_intensive=register_intensive, + ) + + contiguous_config = make_config( + 1, + min(rnumel, MAX_R0_BLOCK), + register_intensive=register_intensive, + ) + outer_config = make_config(64, 8, register_intensive=register_intensive) + tiny_config = make_config( + 2 * (256 // rnumel) if rnumel <= 256 else 1, + min(rnumel, MAX_R0_BLOCK), + register_intensive=register_intensive, + ) + # For 3d tiling, default to more autotuning initially + if "y" in size_hints: + pass + elif inductor_meta.get("max_autotune") or inductor_meta.get( + "max_autotune_pointwise" + ): + pass # skip all these cases + elif reduction_hint == ReductionHint.INNER: + return [contiguous_config] + elif reduction_hint == ReductionHint.OUTER: + return [outer_config] + elif reduction_hint == ReductionHint.OUTER_TINY: + return [tiny_config] + if disable_pointwise_autotuning(inductor_meta): + return [make_config(32, 128)] + return [ + contiguous_config, + outer_config, + tiny_config, + make_config(64, 64), + make_config(8, 512), + # halve the XBLOCK/Rn_BLOCK compared to outer_config + # TODO: this may only be beneficial when each iteration of the reduction + # is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72 + make_config(64, 4, num_warps=8), + ] + + +def match_target_block_product( + size_hints, tiling_scores, target_block_product, min_block_size=1 +): + """ + Distribute block sizes across dimensions according to tiling scores, + aiming to match a target product of block sizes. 
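+ + A small doctest-style example (hypothetical inputs, assuming size_hints are large + enough not to clamp any dimension): + + >>> match_target_block_product({"x": 1024, "y": 1024}, {"x": 3, "y": 1}, 8) # doctest: +SKIP + {'x': 4, 'y': 2}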
+ """ + total_score = sum(tiling_scores.values()) + if total_score == 0: + # just assume even score with no minimum block size + min_block_size = 1 + tiling_scores = dict.fromkeys(tiling_scores.keys(), target_block_product) + + # First, give each coalescing dimension at least min_block_size + block_sizes = {} + relative_scores = {} + curr_block_product = 1 + + for dim, score in tiling_scores.items(): + if score == 0: + block_sizes[dim] = 1 + continue + + block_sizes[dim] = min_block_size + curr_block_product *= min_block_size + relative_scores[dim] = score / total_score + + # Scale up dimensions by their relative scores until we reach the target + while curr_block_product < target_block_product and len(relative_scores): + dim, score = max(relative_scores.items(), key=lambda item: item[1]) + + # Check if we've hit the max for this dimension + if ( + block_sizes[dim] >= TRITON_MAX_BLOCK[dim.capitalize()] + or block_sizes[dim] >= size_hints[dim] + ): + del relative_scores[dim] + continue + + block_sizes[dim] *= 2 + relative_scores[dim] /= 2 + curr_block_product *= 2 + + return block_sizes + + +def adapt_config_for_tiling( + size_hints, + tiling_scores, + original_x, + original_r, + num_warps=None, + num_stages=1, + register_intensive=False, + persistent_reduction=False, +) -> Config: + """ + Create an adapted configuration based on tiling scores, + redistributing the same total block size (x * r) according to tiling scores. + """ + assert all(s in tiling_scores for s in size_hints) + target_block_product = original_x * original_r + block_sizes = match_target_block_product( + size_hints, tiling_scores, target_block_product + ) + + return triton_config_tiled_reduction( + size_hints, + block_sizes["x"], + block_sizes["y"], + block_sizes["r0_"], + num_stages=num_stages, + register_intensive=register_intensive, + ) + + +def reduction( + size_hints, + reduction_hint=False, + triton_meta=None, + filename=None, + inductor_meta=None, +): + """args to @triton.heuristics()""" + inductor_meta = {} if inductor_meta is None else inductor_meta + inductor_meta["reduction_hint"] = reduction_hint + if inductor_meta.get("no_x_dim"): + size_hints["x"] = 1 + + assert triton_meta is not None + + configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta) + return cached_autotune( + size_hints, + configs=configs, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.REDUCTION, + filename=filename, + ) + + +def cooperative_reduction( + size_hints, + reduction_hint, + triton_meta, + filename, + inductor_meta, +): + inductor_meta = {} if inductor_meta is None else inductor_meta + inductor_meta["reduction_hint"] = reduction_hint + if inductor_meta.get("no_x_dim"): + size_hints["x"] = 1 + + # Cooperative reductions currently only support a single reduction dimension. 
+ assert len(size_hints) == 2, ( + "Cooperative reductions don't support tiling reduction dims" + ) + xnumel, rnumel = size_hints["x"], size_hints["r0_"] + + # TODO(jansel): we should base target on the SM count of the local GPU + target = 64 + split = max(1, min(target // xnumel, TRITON_MAX_RSPLIT)) + assert rnumel >= split + assert split <= TRITON_MAX_RSPLIT + if inductor_meta["persistent_reduction"]: + configs = _persistent_reduction_configs( + {"x": xnumel, "r0_": rnumel // split}, reduction_hint, inductor_meta + ) + else: + configs = _reduction_configs( + size_hints={"x": xnumel, "r0_": rnumel // split}, + inductor_meta=inductor_meta, + ) + for config in configs: + config.kwargs["RSPLIT"] = split + # TODO(jansel): add more configs in max_autotune + + return cached_autotune( + size_hints, + configs=configs, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.REDUCTION, + filename=filename, + ) + + +def _persistent_reduction_configs( + size_hints, + reduction_hint=False, + inductor_meta=None, +): + xnumel = size_hints["x"] + rnumel = get_total_reduction_numel(size_hints) + + MAX_PERSISTENT_BLOCK_NUMEL = 4096 + + if "y" not in size_hints: + configs = [ + triton_config_reduction(size_hints, xblock, rnumel, register_intensive=True) + for xblock in (1, 8, 32, 128) + if xblock == 1 + or (rnumel * xblock <= MAX_PERSISTENT_BLOCK_NUMEL and xblock <= xnumel) + ] + else: + configs = [] + assert "tiling_scores" in inductor_meta + x_y_scores = {dim: inductor_meta["tiling_scores"][dim] for dim in ("x", "y")} + for target_block_size in (1, 8, 32, 64, 128): + if target_block_size * rnumel > MAX_PERSISTENT_BLOCK_NUMEL: + continue + + block_sizes = match_target_block_product( + size_hints, x_y_scores, target_block_size + ) + configs.append( + triton_config_tiled_reduction( + size_hints, block_sizes["x"], block_sizes["y"], rnumel + ) + ) + + # defer to more autotuning, initially + if "y" in size_hints: + pass + # TODO(jansel): we should be able to improve these heuristics + elif reduction_hint == ReductionHint.INNER and rnumel >= 256: + configs = configs[:1] + elif reduction_hint == ReductionHint.OUTER: + configs = configs[-1:] + elif reduction_hint == ReductionHint.OUTER_TINY: + configs = [ + triton_config_reduction( + size_hints, + 2 * (256 // rnumel) if rnumel <= 256 else 1, + rnumel, + ) + ] + for c in configs: + # we don't need Rn_BLOCK for persistent reduction + for prefix in size_hints: + if prefix_is_reduction(prefix): + c.kwargs.pop(f"{prefix.upper()}BLOCK") + + if disable_pointwise_autotuning(inductor_meta): + configs = configs[:1] + + return configs + + +def persistent_reduction( + size_hints, + reduction_hint=False, + triton_meta=None, + filename=None, + inductor_meta=None, +): + inductor_meta = {} if inductor_meta is None else inductor_meta + inductor_meta["reduction_hint"] = reduction_hint + if inductor_meta.get("no_x_dim"): + size_hints["x"] = 1 + + configs = _persistent_reduction_configs(size_hints, reduction_hint, inductor_meta) + + return cached_autotune( + size_hints, + configs, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + filename=filename, + heuristic_type=HeuristicType.PERSISTENT_REDUCTION, + ) + + +def split_scan( + size_hints, + reduction_hint=False, + triton_meta=None, + filename=None, + inductor_meta=None, +): + """Heuristic for TritonSplitScanKernel""" + inductor_meta = {} if inductor_meta is None else inductor_meta + inductor_meta["reduction_hint"] = reduction_hint + if inductor_meta.get("no_x_dim"): + size_hints["x"] = 1 + + 
assert triton_meta is not None + if len(size_hints) != 2: + raise NotImplementedError(f"size_hints: {size_hints}") + + configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta) + + # Fixup configs to enforce the minimum Rn_BLOCK size + min_rblock = inductor_meta.get("min_split_scan_rblock", 256) + for cfg in configs: + for var in list(cfg.kwargs.keys()): + if var.startswith("R") and cfg.kwargs[var] < min_rblock: + cfg.kwargs[var] = min_rblock + + return cached_autotune( + size_hints, + configs=configs, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.SPLIT_SCAN, + filename=filename, + ) + + +def template( + num_stages, + num_warps, + triton_meta, + num_consumer_groups=0, + num_buffers_warp_spec=0, + filename=None, + inductor_meta=None, +): + """ + Compile a triton template + """ + # Prepare the base configuration + config_args = { + "num_stages": num_stages, + "num_warps": num_warps, + } + + # Conditionally add arguments based on HAS_WARP_SPEC + if HAS_WARP_SPEC: + config_args.update( + { + "num_consumer_groups": num_consumer_groups, + "num_buffers_warp_spec": num_buffers_warp_spec, + } + ) + return cached_autotune( + None, + [triton.Config({}, **config_args)], + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.TEMPLATE, + filename=filename, + ) + + +def _pop_config_kwargs(config: dict[str, Any]) -> dict[str, Any]: + """Extract triton.Config options that should become kwargs""" + popped = {} + for key in ( + "num_warps", + "num_stages", + "num_ctas", + "maxnreg", + "num_consumer_groups", + "num_buffers_warp_spec", + ): + val = config.pop(key, None) + if val is not None: + popped[key] = val + return popped + + +def config_to_dict(config: Config) -> dict[str, Any]: + config_dict = { + **config.kwargs, + "num_warps": config.num_warps, + "num_stages": config.num_stages, + } + if HAS_WARP_SPEC: + config_dict.update( + { + "num_consumer_groups": getattr(config, "num_consumer_groups", 0), + "num_buffers_warp_spec": getattr(config, "num_buffers_warp_spec", 0), + } + ) + return config_dict + + +def config_from_dict(config: dict[str, Any]) -> Config: + config = {**config} + return Config(config, **_pop_config_kwargs(config)) + + +def fixed_config(config, filename, triton_meta, inductor_meta): + """ + Used when the configuration is already decided at compile time + """ + config = {**config} + return cached_autotune( + None, + [triton.Config(config, **_pop_config_kwargs(config))], + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.FIXED, + filename=filename, + ) + + +def user_autotune( + configs, triton_meta, filename=None, inductor_meta=None, custom_kernel=False +): + """ + Compile a user defined triton kernel + """ + if len(configs) == 0: + configs = [triton.Config({})] + else: + configs = [*map(config_from_dict, configs)] + return cached_autotune( + None, + configs, + triton_meta=triton_meta, + heuristic_type=HeuristicType.USER_AUTOTUNE, + filename=filename, + inductor_meta=inductor_meta, + custom_kernel=custom_kernel, + ) + + +def foreach(triton_meta, num_warps, filename=None, inductor_meta=None): + """ + Compile a triton foreach kernel + """ + return cached_autotune( + None, + [triton.Config({}, num_stages=1, num_warps=num_warps)], + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.TEMPLATE, + filename=filename, + ) + + +@dataclasses.dataclass +class GridExpr: + """Generate code for grid size expressions in launcher""" + + 
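+ # Example of the generated ceil-division (illustrative): in "python" mode, + # ceildiv("xnumel", 32) emits "-((xnumel) // -(32))", while in "cpp" mode it emits + # "((xnumel + (32 - 1)) / (32))"; when both inputs are ints the result is + # constant-folded instead.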
inductor_meta: dict[str, Any] + mode: Literal["python", "cpp"] = "python" + prefix: list[str] = dataclasses.field(default_factory=list) + x_grid: Union[str, int] = 1 + y_grid: Union[str, int] = 1 + z_grid: Union[str, int] = 1 + + def __post_init__(self) -> None: + assert self.mode in ("python", "cpp") + + def generate(self, meta: dict[str, int]) -> None: + raise NotImplementedError + + def ceildiv( + self, numel: Union[str, int], block: Union[None, int, str] + ) -> Union[str, int]: + if block is None or block == 1: + return numel + if isinstance(numel, int) and isinstance(block, int): + return ceildiv(numel, block) # constant fold + if self.mode == "python": + return f"-(({numel}) // -({block}))" + # trick above doesn't work in C++ due to rounding differences + return f"(({numel} + ({block} - 1)) / ({block}))" + + def maximum(self, seq: list[Union[int, str]]) -> Union[int, str]: + """Codegen for max function with constant folding, constants are represented as int""" + items = self._constant_fold(max, seq) + if len(items) <= 1: + return items[0] + if self.mode == "python": + return f"max({', '.join(map(str, items))})" + return functools.reduce(lambda x, y: f"std::max({x}, {y})", items) + + def summation(self, seq: list[Union[int, str]]) -> Union[int, str]: + """Codegen for sum function with constant folding, constants are represented as int""" + items = self._constant_fold(sum, seq) + if len(items) <= 1: + return items[0] + return " + ".join(map(str, items)) + + def _constant_fold( + self, fn: Callable[[list[int]], int], seq: list[Union[int, str]] + ) -> list[Union[int, str]]: + """Constant fold through a commutative fn where ints are constants""" + items: list[Union[int, str]] = [x for x in seq if not isinstance(x, int)] + const_items = [x for x in seq if isinstance(x, int)] + if const_items: + items.append(fn(const_items)) + return items + + def assign_tmp(self, name: str, expr: Union[str, int]) -> str: + # Grid functions are one per kernel, so name collisions are fine + if self.mode == "python": + return f"{name} = {expr}" + if self.mode == "cpp": + return f"uint32_t {name} = {expr};" + raise AssertionError(f"invalid mode {self.mode}") + + @staticmethod + def from_meta( + inductor_meta: dict[str, Any], + cfg: Union[Config, dict[str, int]], + mode: Literal["python", "cpp"] = "python", + ) -> GridExpr: + grid_cls = globals()[inductor_meta["grid_type"]] + assert issubclass(grid_cls, GridExpr) + grid = grid_cls(inductor_meta=inductor_meta, mode=mode) + if isinstance(cfg, Config): + cfg = config_to_dict(cfg) + grid.generate(cfg) + return grid + + def eval_slow(self, meta: dict[str, int]) -> tuple[int, int, int]: + scope = {**meta} + for line in self.prefix: + exec(line, scope) + exec(f"grid_0 = {self.x_grid}", scope) + exec(f"grid_1 = {self.y_grid}", scope) + exec(f"grid_2 = {self.z_grid}", scope) + return scope["grid_0"], scope["grid_1"], scope["grid_2"] + + +class Grid1D(GridExpr): + def generate(self, meta: dict[str, int]) -> None: + self.x_grid = self.ceildiv("xnumel", meta.get("XBLOCK")) + + +class Grid2D(GridExpr): + def generate(self, meta: dict[str, int]) -> None: + self.x_grid = self.ceildiv("xnumel", meta.get("XBLOCK")) + self.y_grid = self.ceildiv("ynumel", meta.get("YBLOCK")) + + +class Grid3D(GridExpr): + def generate(self, meta: dict[str, int]) -> None: + self.x_grid = self.ceildiv("xnumel", meta.get("XBLOCK")) + self.y_grid = self.ceildiv("ynumel", meta.get("YBLOCK")) + self.z_grid = self.ceildiv("znumel", meta.get("ZBLOCK")) + + +class Grid2DWithYZOverflow(GridExpr): + def 
generate(self, meta: dict[str, int]) -> None: + self.x_grid = self.ceildiv("xnumel", meta.get("XBLOCK")) + self.prefix.extend( + [ + self.assign_tmp( + "y_grid_raw_", self.ceildiv("ynumel", meta.get("YBLOCK")) + ), + self.assign_tmp( + "y_grid_div_", self.ceildiv("y_grid_raw_", get_max_y_grid()) + ), + ] + ) + self.y_grid = self.ceildiv("y_grid_raw_", "y_grid_div_") + self.z_grid = "y_grid_div_" + + +class CooperativeReductionGrid(GridExpr): + def generate(self, meta: dict[str, int]) -> None: + self.x_grid = str(meta["RSPLIT"]) + self.y_grid = self.ceildiv("xnumel", meta.get("XBLOCK")) + + +class SplitScanGrid(GridExpr): + def generate(self, meta: dict[str, int]) -> None: + assert meta.get("XBLOCK", 1) == 1 + self.x_grid = self.ceildiv("r0_numel", meta.get("R0_BLOCK")) + self.y_grid = "xnumel" + + +class FixedGrid(GridExpr): + @staticmethod + def setup_grid_as_args() -> dict[str, Any]: + """Inductor meta so the launcher takes three extra grid arguments""" + return { + "grid_type": FixedGrid.__name__, + "fixed_grid": ["_grid_0", "_grid_1", "_grid_2"], + "extra_launcher_args": ["_grid_0", "_grid_1", "_grid_2"], + } + + def generate(self, meta: dict[str, int]) -> None: + self.x_grid, self.y_grid, self.z_grid = self.inductor_meta["fixed_grid"] + + +class PrecomputedGrid(GridExpr): + def generate(self, meta: dict[str, int]) -> None: + for candidate in self.inductor_meta["precomputed_grids"]: + if all(meta.get(k) == v for k, v in candidate["config"].items()): + self.x_grid, self.y_grid, self.z_grid = candidate[self.mode] + return + raise AssertionError( + f"Precomputed grid not found for {meta} in {self.inductor_meta['precomputed_grids']}" + ) + + +class ComboKernelGrid(GridExpr): + def generate(self, meta: dict[str, int]): + combo_meta = self.inductor_meta["combo_grid_meta"] + if combo_meta["default_config"]: + meta = {**combo_meta["default_config"], **meta} + no_x_dims = [] + xnumels = [] + ynumels = [] + znumels = [] + for num in range(combo_meta["num_kernels"]): + assert ( + combo_meta[f"xnumel_{num}"] is None or combo_meta[f"xnumel_{num}"] > 0 + ) + no_x_dims.append(combo_meta[f"no_x_dim_{num}"]) + xnumels.append(combo_meta[f"xnumel_{num}"] or f"xnumel_{num}") + if f"ynumel_{num}" in combo_meta: + ynumels.append(combo_meta[f"ynumel_{num}"] or f"ynumel_{num}") + if f"znumel_{num}" in combo_meta: + znumels.append(combo_meta[f"znumel_{num}"] or f"znumel_{num}") + + self.x_grid = self.combo_x_grid(xnumels, no_x_dims, meta) + if combo_meta["min_blocks"]: + self.x_grid = self.maximum([self.x_grid, combo_meta["min_blocks"]]) + if ynumels: + self.y_grid = self.ceildiv(self.maximum(ynumels), meta.get("YBLOCK")) + if znumels: + self.z_grid = self.ceildiv(self.maximum(znumels), meta.get("ZBLOCK")) + + def combo_x_grid( + self, + xnumels: list[Union[int, str]], + no_x_dims: list[bool], + meta: dict[str, int], + ) -> Union[str, int]: + raise NotImplementedError + + +class SequentialComboKernelGrid(ComboKernelGrid): + def combo_x_grid( + self, + xnumels: list[Union[int, str]], + no_x_dims: list[bool], + meta: dict[str, int], + ) -> Union[str, int]: + assert len(xnumels) == len(no_x_dims) + return self.summation( + [ + self.ceildiv(x, 1 if no_x_dim else meta.get("XBLOCK")) + for x, no_x_dim in zip(xnumels, no_x_dims) + ] + ) + + +class RoundRobinComboKernelGrid(ComboKernelGrid): + def combo_x_grid( + self, + xnumels: list[Union[int, str]], + no_x_dims: list[bool], + meta: dict[str, int], + ) -> str: + assert len(xnumels) == len(no_x_dims) + num_kernels = 
self.inductor_meta["combo_grid_meta"]["num_kernels"] + exprs = [x for x, no_x_dim in zip(xnumels, no_x_dims) if no_x_dim] + xnumels_x_dim = [x for x, no_x_dim in zip(xnumels, no_x_dims) if not no_x_dim] + if xnumels_x_dim: + exprs.append(self.ceildiv(self.maximum(xnumels_x_dim), meta.get("XBLOCK"))) + return f"({self.maximum(exprs)}) * {num_kernels}" diff --git a/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..75c4ba0c5cfd631fa3559b1357c1771ee07002f3 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_binary_ufuncs_impl.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_binary_ufuncs_impl.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cdea4346eeda1a0ed0777ffa1a961d56e56d8638 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_binary_ufuncs_impl.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_casting_dicts.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_casting_dicts.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b15b1710c37c1e2ed3978d6ecead07c03b239c75 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_casting_dicts.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_dtypes.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_dtypes.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37dc231e2f2544743a62c2b3f6cd687ff04a7f89 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_dtypes.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_dtypes_impl.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_dtypes_impl.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..57bab24fe79b869b885bf4c35a1e654d17307306 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_dtypes_impl.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_funcs.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_funcs.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9056c0c927e743dd3cb30ffa27ad39973509727e Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_funcs.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_funcs_impl.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_funcs_impl.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a51f4faa0f591628fca7341e377d918ff67292ca Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_funcs_impl.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_getlimits.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_getlimits.cpython-312.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..dd3eba35c43011b2ffe9a0ec8a7a8289388acc35 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_getlimits.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_ndarray.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_ndarray.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4066069343c59e377d981e6a7c0e8d7b12cf3f24 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_ndarray.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_normalizations.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_normalizations.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aaece196ac9699327a2f821ea94a16cfd633f971 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_normalizations.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_reductions_impl.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_reductions_impl.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2cef20d4cfdf51172ab84612fe4be9d3ce7702b4 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_reductions_impl.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_ufuncs.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_ufuncs.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f27815b5da8e3dd7f757637010a60a50fd675c08 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_ufuncs.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_unary_ufuncs_impl.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_unary_ufuncs_impl.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..052908cff051b69fa1cd08d15c248e7c2b6a8196 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_unary_ufuncs_impl.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_util.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_util.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6e0044736da8335e249c3a4980d0aa7ca79250c Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/_util.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/fft.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/fft.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..23488910e7fa62d7d326a70fd2348eb6b8d12e5e Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/fft.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/linalg.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/linalg.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68500dc93b8ff021904f3a63ca1a00f93df08e18 Binary files /dev/null and 
b/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/linalg.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/random.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/random.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3d87fe4f3c5e0cf8197040f2bd474bd414e7549 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_numpy/__pycache__/random.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_numpy/testing/__init__.py b/.venv/lib/python3.12/site-packages/torch/_numpy/testing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..05e73b12e29f8e6608647a3f16fabab39fbfb582 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_numpy/testing/__init__.py @@ -0,0 +1,20 @@ +# mypy: ignore-errors + +from .utils import ( + _gen_alignment_data, + assert_, + assert_allclose, + assert_almost_equal, + assert_array_almost_equal, + assert_array_equal, + assert_array_less, + assert_equal, + assert_raises_regex, + assert_warns, + HAS_REFCOUNT, + IS_WASM, + suppress_warnings, +) + + +# from .testing import assert_allclose # FIXME diff --git a/.venv/lib/python3.12/site-packages/torch/_numpy/testing/utils.py b/.venv/lib/python3.12/site-packages/torch/_numpy/testing/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..29885b917049e3549eddd7063b9a4f839d740da6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_numpy/testing/utils.py @@ -0,0 +1,2386 @@ +# mypy: ignore-errors + +""" +Utility function to facilitate testing. + +""" +import contextlib +import gc +import operator +import os +import platform +import pprint +import re +import shutil +import sys +import warnings +from functools import wraps +from io import StringIO +from tempfile import mkdtemp, mkstemp +from warnings import WarningMessage + +import torch._numpy as np +from torch._numpy import arange, asarray as asanyarray, empty, float32, intp, ndarray + + +__all__ = [ + "assert_equal", + "assert_almost_equal", + "assert_approx_equal", + "assert_array_equal", + "assert_array_less", + "assert_string_equal", + "assert_", + "assert_array_almost_equal", + "build_err_msg", + "decorate_methods", + "print_assert_equal", + "verbose", + "assert_", + "assert_array_almost_equal_nulp", + "assert_raises_regex", + "assert_array_max_ulp", + "assert_warns", + "assert_no_warnings", + "assert_allclose", + "IgnoreException", + "clear_and_catch_warnings", + "temppath", + "tempdir", + "IS_PYPY", + "HAS_REFCOUNT", + "IS_WASM", + "suppress_warnings", + "assert_array_compare", + "assert_no_gc_cycles", + "break_cycles", + "IS_PYSTON", +] + + +verbose = 0 + +IS_WASM = platform.machine() in ["wasm32", "wasm64"] +IS_PYPY = sys.implementation.name == "pypy" +IS_PYSTON = hasattr(sys, "pyston_version_info") +HAS_REFCOUNT = getattr(sys, "getrefcount", None) is not None and not IS_PYSTON + + +def assert_(val, msg=""): + """ + Assert that works in release mode. + Accepts callable msg to allow deferring evaluation until failure. + + The Python built-in ``assert`` does not work when executing code in + optimized mode (the ``-O`` flag) - no byte-code is generated for it. + + For documentation on usage, refer to the Python documentation. 
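+ + A deferred-message example (hypothetical): + + >>> x = -1 + >>> assert_(x > 0, msg=lambda: f"expected a positive value, got {x}") # doctest: +SKIP + Traceback (most recent call last): + ... + AssertionError: expected a positive value, got -1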
+ + """ + __tracebackhide__ = True # Hide traceback for py.test + if not val: + try: + smsg = msg() + except TypeError: + smsg = msg + raise AssertionError(smsg) + + +def gisnan(x): + return np.isnan(x) + + +def gisfinite(x): + return np.isfinite(x) + + +def gisinf(x): + return np.isinf(x) + + +def build_err_msg( + arrays, + err_msg, + header="Items are not equal:", + verbose=True, + names=("ACTUAL", "DESIRED"), + precision=8, +): + msg = ["\n" + header] + if err_msg: + if err_msg.find("\n") == -1 and len(err_msg) < 79 - len(header): + msg = [msg[0] + " " + err_msg] + else: + msg.append(err_msg) + if verbose: + for i, a in enumerate(arrays): + if isinstance(a, ndarray): + # precision argument is only needed if the objects are ndarrays + # r_func = partial(array_repr, precision=precision) + r_func = ndarray.__repr__ + else: + r_func = repr + + try: + r = r_func(a) + except Exception as exc: + r = f"[repr failed for <{type(a).__name__}>: {exc}]" + if r.count("\n") > 3: + r = "\n".join(r.splitlines()[:3]) + r += "..." + msg.append(f" {names[i]}: {r}") + return "\n".join(msg) + + +def assert_equal(actual, desired, err_msg="", verbose=True): + """ + Raises an AssertionError if two objects are not equal. + + Given two objects (scalars, lists, tuples, dictionaries or numpy arrays), + check that all elements of these objects are equal. An exception is raised + at the first conflicting values. + + When one of `actual` and `desired` is a scalar and the other is array_like, + the function checks that each element of the array_like object is equal to + the scalar. + + This function handles NaN comparisons as if NaN was a "normal" number. + That is, AssertionError is not raised if both objects have NaNs in the same + positions. This is in contrast to the IEEE standard on NaNs, which says + that NaN compared to anything must return False. + + Parameters + ---------- + actual : array_like + The object to check. + desired : array_like + The expected object. + err_msg : str, optional + The error message to be printed in case of failure. + verbose : bool, optional + If True, the conflicting values are appended to the error message. + + Raises + ------ + AssertionError + If actual and desired are not equal. + + Examples + -------- + >>> np.testing.assert_equal([4,5], [4,6]) + Traceback (most recent call last): + ... + AssertionError: + Items are not equal: + item=1 + ACTUAL: 5 + DESIRED: 6 + + The following comparison does not raise an exception. There are NaNs + in the inputs, but they are in the same positions. 
+ + >>> np.testing.assert_equal(np.array([1.0, 2.0, np.nan]), [1, 2, np.nan]) + + """ + __tracebackhide__ = True # Hide traceback for py.test + + num_nones = sum([actual is None, desired is None]) + if num_nones == 1: + raise AssertionError(f"Not equal: {actual} != {desired}") + elif num_nones == 2: + return True + # else, carry on + + if isinstance(actual, np.DType) or isinstance(desired, np.DType): + result = actual == desired + if not result: + raise AssertionError(f"Not equal: {actual} != {desired}") + else: + return True + + if isinstance(desired, str) and isinstance(actual, str): + assert actual == desired + return + + if isinstance(desired, dict): + if not isinstance(actual, dict): + raise AssertionError(repr(type(actual))) + assert_equal(len(actual), len(desired), err_msg, verbose) + for k in desired.keys(): + if k not in actual: + raise AssertionError(repr(k)) + assert_equal(actual[k], desired[k], f"key={k!r}\n{err_msg}", verbose) + return + if isinstance(desired, (list, tuple)) and isinstance(actual, (list, tuple)): + assert_equal(len(actual), len(desired), err_msg, verbose) + for k in range(len(desired)): + assert_equal(actual[k], desired[k], f"item={k!r}\n{err_msg}", verbose) + return + + from torch._numpy import imag, iscomplexobj, isscalar, ndarray, real, signbit + + if isinstance(actual, ndarray) or isinstance(desired, ndarray): + return assert_array_equal(actual, desired, err_msg, verbose) + msg = build_err_msg([actual, desired], err_msg, verbose=verbose) + + # Handle complex numbers: separate into real/imag to handle + # nan/inf/negative zero correctly + # XXX: catch ValueError for subclasses of ndarray where iscomplex fail + try: + usecomplex = iscomplexobj(actual) or iscomplexobj(desired) + except (ValueError, TypeError): + usecomplex = False + + if usecomplex: + if iscomplexobj(actual): + actualr = real(actual) + actuali = imag(actual) + else: + actualr = actual + actuali = 0 + if iscomplexobj(desired): + desiredr = real(desired) + desiredi = imag(desired) + else: + desiredr = desired + desiredi = 0 + try: + assert_equal(actualr, desiredr) + assert_equal(actuali, desiredi) + except AssertionError: + raise AssertionError(msg) # noqa: B904 + + # isscalar test to check cases such as [np.nan] != np.nan + if isscalar(desired) != isscalar(actual): + raise AssertionError(msg) + + # Inf/nan/negative zero handling + try: + isdesnan = gisnan(desired) + isactnan = gisnan(actual) + if isdesnan and isactnan: + return # both nan, so equal + + if desired == 0 and actual == 0: + if not signbit(desired) == signbit(actual): + raise AssertionError(msg) + + except (TypeError, ValueError, NotImplementedError): + pass + + try: + # Explicitly use __eq__ for comparison, gh-2552 + if not (desired == actual): + raise AssertionError(msg) + + except (DeprecationWarning, FutureWarning) as e: + # this handles the case when the two types are not even comparable + if "elementwise == comparison" in e.args[0]: + raise AssertionError(msg) # noqa: B904 + else: + raise + + +def print_assert_equal(test_string, actual, desired): + """ + Test if two objects are equal, and print an error message if test fails. + + The test is performed with ``actual == desired``. + + Parameters + ---------- + test_string : str + The message supplied to AssertionError. + actual : object + The object to test for equality against `desired`. + desired : object + The expected result. 
+ + Examples + -------- + >>> np.testing.print_assert_equal('Test XYZ of func xyz', [0, 1], [0, 1]) # doctest: +SKIP + >>> np.testing.print_assert_equal('Test XYZ of func xyz', [0, 1], [0, 2]) # doctest: +SKIP + Traceback (most recent call last): + ... + AssertionError: Test XYZ of func xyz failed + ACTUAL: + [0, 1] + DESIRED: + [0, 2] + + """ + __tracebackhide__ = True # Hide traceback for py.test + import pprint + + if not (actual == desired): + msg = StringIO() + msg.write(test_string) + msg.write(" failed\nACTUAL: \n") + pprint.pprint(actual, msg) + msg.write("DESIRED: \n") + pprint.pprint(desired, msg) + raise AssertionError(msg.getvalue()) + + +def assert_almost_equal(actual, desired, decimal=7, err_msg="", verbose=True): + """ + Raises an AssertionError if two items are not equal up to desired + precision. + + .. note:: It is recommended to use one of `assert_allclose`, + `assert_array_almost_equal_nulp` or `assert_array_max_ulp` + instead of this function for more consistent floating point + comparisons. + + The test verifies that the elements of `actual` and `desired` satisfy. + + ``abs(desired-actual) < float64(1.5 * 10**(-decimal))`` + + That is a looser test than originally documented, but agrees with what the + actual implementation in `assert_array_almost_equal` did up to rounding + vagaries. An exception is raised at conflicting values. For ndarrays this + delegates to assert_array_almost_equal + + Parameters + ---------- + actual : array_like + The object to check. + desired : array_like + The expected object. + decimal : int, optional + Desired precision, default is 7. + err_msg : str, optional + The error message to be printed in case of failure. + verbose : bool, optional + If True, the conflicting values are appended to the error message. + + Raises + ------ + AssertionError + If actual and desired are not equal up to specified precision. + + See Also + -------- + assert_allclose: Compare two array_like objects for equality with desired + relative and/or absolute precision. + assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal + + Examples + -------- + >>> from torch._numpy.testing import assert_almost_equal + >>> assert_almost_equal(2.3333333333333, 2.33333334) + >>> assert_almost_equal(2.3333333333333, 2.33333334, decimal=10) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not almost equal to 10 decimals + ACTUAL: 2.3333333333333 + DESIRED: 2.33333334 + + >>> assert_almost_equal(np.array([1.0,2.3333333333333]), + ... np.array([1.0,2.33333334]), decimal=9) + Traceback (most recent call last): + ... 
+ AssertionError: + Arrays are not almost equal to 9 decimals + + Mismatched elements: 1 / 2 (50%) + Max absolute difference: 6.666699636781459e-09 + Max relative difference: 2.8571569790287484e-09 + x: torch.ndarray([1.0000, 2.3333], dtype=float64) + y: torch.ndarray([1.0000, 2.3333], dtype=float64) + + """ + __tracebackhide__ = True # Hide traceback for py.test + from torch._numpy import imag, iscomplexobj, ndarray, real + + # Handle complex numbers: separate into real/imag to handle + # nan/inf/negative zero correctly + # XXX: catch ValueError for subclasses of ndarray where iscomplex fail + try: + usecomplex = iscomplexobj(actual) or iscomplexobj(desired) + except ValueError: + usecomplex = False + + def _build_err_msg(): + header = f"Arrays are not almost equal to {decimal:d} decimals" + return build_err_msg([actual, desired], err_msg, verbose=verbose, header=header) + + if usecomplex: + if iscomplexobj(actual): + actualr = real(actual) + actuali = imag(actual) + else: + actualr = actual + actuali = 0 + if iscomplexobj(desired): + desiredr = real(desired) + desiredi = imag(desired) + else: + desiredr = desired + desiredi = 0 + try: + assert_almost_equal(actualr, desiredr, decimal=decimal) + assert_almost_equal(actuali, desiredi, decimal=decimal) + except AssertionError: + raise AssertionError(_build_err_msg()) # noqa: B904 + + if isinstance(actual, (ndarray, tuple, list)) or isinstance( + desired, (ndarray, tuple, list) + ): + return assert_array_almost_equal(actual, desired, decimal, err_msg) + try: + # If one of desired/actual is not finite, handle it specially here: + # check that both are nan if any is a nan, and test for equality + # otherwise + if not (gisfinite(desired) and gisfinite(actual)): + if gisnan(desired) or gisnan(actual): + if not (gisnan(desired) and gisnan(actual)): + raise AssertionError(_build_err_msg()) + else: + if not desired == actual: + raise AssertionError(_build_err_msg()) + return + except (NotImplementedError, TypeError): + pass + if abs(desired - actual) >= np.float64(1.5 * 10.0 ** (-decimal)): + raise AssertionError(_build_err_msg()) + + +def assert_approx_equal(actual, desired, significant=7, err_msg="", verbose=True): + """ + Raises an AssertionError if two items are not equal up to significant + digits. + + .. note:: It is recommended to use one of `assert_allclose`, + `assert_array_almost_equal_nulp` or `assert_array_max_ulp` + instead of this function for more consistent floating point + comparisons. + + Given two numbers, check that they are approximately equal. + Approximately equal is defined as the number of significant digits + that agree. + + Parameters + ---------- + actual : scalar + The object to check. + desired : scalar + The expected object. + significant : int, optional + Desired precision, default is 7. + err_msg : str, optional + The error message to be printed in case of failure. + verbose : bool, optional + If True, the conflicting values are appended to the error message. + + Raises + ------ + AssertionError + If actual and desired are not equal up to specified precision. + + See Also + -------- + assert_allclose: Compare two array_like objects for equality with desired + relative and/or absolute precision. + assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal + + Examples + -------- + >>> np.testing.assert_approx_equal(0.12345677777777e-20, 0.1234567e-20) # doctest: +SKIP + >>> np.testing.assert_approx_equal(0.12345670e-20, 0.12345671e-20, # doctest: +SKIP + ... 
significant=8) + >>> np.testing.assert_approx_equal(0.12345670e-20, 0.12345672e-20, # doctest: +SKIP + ... significant=8) + Traceback (most recent call last): + ... + AssertionError: + Items are not equal to 8 significant digits: + ACTUAL: 1.234567e-21 + DESIRED: 1.2345672e-21 + + the evaluated condition that raises the exception is + + >>> abs(0.12345670e-20/1e-21 - 0.12345672e-20/1e-21) >= 10**-(8-1) + True + + """ + __tracebackhide__ = True # Hide traceback for py.test + import numpy as np + + (actual, desired) = map(float, (actual, desired)) + if desired == actual: + return + # Normalized the numbers to be in range (-10.0,10.0) + # scale = float(pow(10,math.floor(math.log10(0.5*(abs(desired)+abs(actual)))))) + scale = 0.5 * (np.abs(desired) + np.abs(actual)) + scale = np.power(10, np.floor(np.log10(scale))) + try: + sc_desired = desired / scale + except ZeroDivisionError: + sc_desired = 0.0 + try: + sc_actual = actual / scale + except ZeroDivisionError: + sc_actual = 0.0 + msg = build_err_msg( + [actual, desired], + err_msg, + header=f"Items are not equal to {significant:d} significant digits:", + verbose=verbose, + ) + try: + # If one of desired/actual is not finite, handle it specially here: + # check that both are nan if any is a nan, and test for equality + # otherwise + if not (gisfinite(desired) and gisfinite(actual)): + if gisnan(desired) or gisnan(actual): + if not (gisnan(desired) and gisnan(actual)): + raise AssertionError(msg) + else: + if not desired == actual: + raise AssertionError(msg) + return + except (TypeError, NotImplementedError): + pass + if np.abs(sc_desired - sc_actual) >= np.power(10.0, -(significant - 1)): + raise AssertionError(msg) + + +def assert_array_compare( + comparison, + x, + y, + err_msg="", + verbose=True, + header="", + precision=6, + equal_nan=True, + equal_inf=True, + *, + strict=False, +): + __tracebackhide__ = True # Hide traceback for py.test + from torch._numpy import all, array, asarray, bool_, inf, isnan, max + + x = asarray(x) + y = asarray(y) + + def array2string(a): + return str(a) + + # original array for output formatting + ox, oy = x, y + + def func_assert_same_pos(x, y, func=isnan, hasval="nan"): + """Handling nan/inf. + + Combine results of running func on x and y, checking that they are True + at the same locations. + + """ + __tracebackhide__ = True # Hide traceback for py.test + x_id = func(x) + y_id = func(y) + # We include work-arounds here to handle three types of slightly + # pathological ndarray subclasses: + # (1) all() on `masked` array scalars can return masked arrays, so we + # use != True + # (2) __eq__ on some ndarray subclasses returns Python booleans + # instead of element-wise comparisons, so we cast to bool_() and + # use isinstance(..., bool) checks + # (3) subclasses with bare-bones __array_function__ implementations may + # not implement np.all(), so favor using the .all() method + # We are not committed to supporting such subclasses, but it's nice to + # support them if possible. + if (x_id == y_id).all().item() is not True: + msg = build_err_msg( + [x, y], + err_msg + f"\nx and y {hasval} location mismatch:", + verbose=verbose, + header=header, + names=("x", "y"), + precision=precision, + ) + raise AssertionError(msg) + # If there is a scalar, then here we know the array has the same + # flag as it everywhere, so we should return the scalar flag. 
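+            # (The mask returned here is used by the caller below to drop the
+            #  flagged nan/inf positions, ``x[~flagged]``, before the main
+            #  element-wise comparison runs.)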
+ if isinstance(x_id, bool) or x_id.ndim == 0: + return bool_(x_id) + elif isinstance(y_id, bool) or y_id.ndim == 0: + return bool_(y_id) + else: + return y_id + + try: + if strict: + cond = x.shape == y.shape and x.dtype == y.dtype + else: + cond = (x.shape == () or y.shape == ()) or x.shape == y.shape + if not cond: + if x.shape != y.shape: + reason = f"\n(shapes {x.shape}, {y.shape} mismatch)" + else: + reason = f"\n(dtypes {x.dtype}, {y.dtype} mismatch)" + msg = build_err_msg( + [x, y], + err_msg + reason, + verbose=verbose, + header=header, + names=("x", "y"), + precision=precision, + ) + raise AssertionError(msg) + + flagged = bool_(False) + + if equal_nan: + flagged = func_assert_same_pos(x, y, func=isnan, hasval="nan") + + if equal_inf: + flagged |= func_assert_same_pos( + x, y, func=lambda xy: xy == +inf, hasval="+inf" + ) + flagged |= func_assert_same_pos( + x, y, func=lambda xy: xy == -inf, hasval="-inf" + ) + + if flagged.ndim > 0: + x, y = x[~flagged], y[~flagged] + # Only do the comparison if actual values are left + if x.size == 0: + return + elif flagged: + # no sense doing comparison if everything is flagged. + return + + val = comparison(x, y) + + if isinstance(val, bool): + cond = val + reduced = array([val]) + else: + reduced = val.ravel() + cond = reduced.all() + + # The below comparison is a hack to ensure that fully masked + # results, for which val.ravel().all() returns np.ma.masked, + # do not trigger a failure (np.ma.masked != True evaluates as + # np.ma.masked, which is falsy). + if not cond: + n_mismatch = reduced.size - int(reduced.sum(dtype=intp)) + n_elements = flagged.size if flagged.ndim != 0 else reduced.size + percent_mismatch = 100 * n_mismatch / n_elements + remarks = [ + f"Mismatched elements: {n_mismatch} / {n_elements} ({percent_mismatch:.3g}%)" + ] + + # with errstate(all='ignore'): + # ignore errors for non-numeric types + with contextlib.suppress(TypeError, RuntimeError): + error = abs(x - y) + if np.issubdtype(x.dtype, np.unsignedinteger): + error2 = abs(y - x) + np.minimum(error, error2, out=error) + max_abs_error = max(error) + remarks.append( + "Max absolute difference: " + array2string(max_abs_error.item()) + ) + + # note: this definition of relative error matches that one + # used by assert_allclose (found in np.isclose) + # Filter values where the divisor would be zero + nonzero = bool_(y != 0) + if all(~nonzero): + max_rel_error = array(inf) + else: + max_rel_error = max(error[nonzero] / abs(y[nonzero])) + remarks.append( + "Max relative difference: " + array2string(max_rel_error.item()) + ) + + err_msg += "\n" + "\n".join(remarks) + msg = build_err_msg( + [ox, oy], + err_msg, + verbose=verbose, + header=header, + names=("x", "y"), + precision=precision, + ) + raise AssertionError(msg) + except ValueError: + import traceback + + efmt = traceback.format_exc() + header = f"error during assertion:\n\n{efmt}\n\n{header}" + + msg = build_err_msg( + [x, y], + err_msg, + verbose=verbose, + header=header, + names=("x", "y"), + precision=precision, + ) + raise ValueError(msg) # noqa: B904 + + +def assert_array_equal(x, y, err_msg="", verbose=True, *, strict=False): + """ + Raises an AssertionError if two array_like objects are not equal. + + Given two array_like objects, check that the shape is equal and all + elements of these objects are equal (but see the Notes for the special + handling of a scalar). An exception is raised at shape mismatch or + conflicting values. 
In contrast to the standard usage in numpy, NaNs + are compared like numbers, no assertion is raised if both objects have + NaNs in the same positions. + + The usual caution for verifying equality with floating point numbers is + advised. + + Parameters + ---------- + x : array_like + The actual object to check. + y : array_like + The desired, expected object. + err_msg : str, optional + The error message to be printed in case of failure. + verbose : bool, optional + If True, the conflicting values are appended to the error message. + strict : bool, optional + If True, raise an AssertionError when either the shape or the data + type of the array_like objects does not match. The special + handling for scalars mentioned in the Notes section is disabled. + + Raises + ------ + AssertionError + If actual and desired objects are not equal. + + See Also + -------- + assert_allclose: Compare two array_like objects for equality with desired + relative and/or absolute precision. + assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal + + Notes + ----- + When one of `x` and `y` is a scalar and the other is array_like, the + function checks that each element of the array_like object is equal to + the scalar. This behaviour can be disabled with the `strict` parameter. + + Examples + -------- + The first assert does not raise an exception: + + >>> np.testing.assert_array_equal([1.0,2.33333,np.nan], + ... [np.exp(0),2.33333, np.nan]) + + Use `assert_allclose` or one of the nulp (number of floating point values) + functions for these cases instead: + + >>> np.testing.assert_allclose([1.0,np.pi,np.nan], + ... [1, np.sqrt(np.pi)**2, np.nan], + ... rtol=1e-10, atol=0) + + As mentioned in the Notes section, `assert_array_equal` has special + handling for scalars. Here the test checks that each value in `x` is 3: + + >>> x = np.full((2, 5), fill_value=3) + >>> np.testing.assert_array_equal(x, 3) + + Use `strict` to raise an AssertionError when comparing a scalar with an + array: + + >>> np.testing.assert_array_equal(x, 3, strict=True) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not equal + + (shapes (2, 5), () mismatch) + x: torch.ndarray([[3, 3, 3, 3, 3], + [3, 3, 3, 3, 3]]) + y: torch.ndarray(3) + + The `strict` parameter also ensures that the array data types match: + + >>> x = np.array([2, 2, 2]) + >>> y = np.array([2., 2., 2.], dtype=np.float32) + >>> np.testing.assert_array_equal(x, y, strict=True) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not equal + + (dtypes dtype("int64"), dtype("float32") mismatch) + x: torch.ndarray([2, 2, 2]) + y: torch.ndarray([2., 2., 2.]) + """ + __tracebackhide__ = True # Hide traceback for py.test + assert_array_compare( + operator.__eq__, + x, + y, + err_msg=err_msg, + verbose=verbose, + header="Arrays are not equal", + strict=strict, + ) + + +def assert_array_almost_equal(x, y, decimal=6, err_msg="", verbose=True): + """ + Raises an AssertionError if two objects are not equal up to desired + precision. + + .. note:: It is recommended to use one of `assert_allclose`, + `assert_array_almost_equal_nulp` or `assert_array_max_ulp` + instead of this function for more consistent floating point + comparisons. + + The test verifies identical shapes and that the elements of ``actual`` and + ``desired`` satisfy. + + ``abs(desired-actual) < 1.5 * 10**(-decimal)`` + + That is a looser test than originally documented, but agrees with what the + actual implementation did up to rounding vagaries. 
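+    For instance, with the default ``decimal=6`` the tolerance is ``1.5e-6``,
+    so ``1.0`` and ``1.0000014`` are accepted while ``1.0`` and ``1.0000016``
+    are not.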
An exception is raised + at shape mismatch or conflicting values. In contrast to the standard usage + in numpy, NaNs are compared like numbers, no assertion is raised if both + objects have NaNs in the same positions. + + Parameters + ---------- + x : array_like + The actual object to check. + y : array_like + The desired, expected object. + decimal : int, optional + Desired precision, default is 6. + err_msg : str, optional + The error message to be printed in case of failure. + verbose : bool, optional + If True, the conflicting values are appended to the error message. + + Raises + ------ + AssertionError + If actual and desired are not equal up to specified precision. + + See Also + -------- + assert_allclose: Compare two array_like objects for equality with desired + relative and/or absolute precision. + assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal + + Examples + -------- + the first assert does not raise an exception + + >>> np.testing.assert_array_almost_equal([1.0,2.333,np.nan], + ... [1.0,2.333,np.nan]) + + >>> np.testing.assert_array_almost_equal([1.0,2.33333,np.nan], + ... [1.0,2.33339,np.nan], decimal=5) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not almost equal to 5 decimals + + Mismatched elements: 1 / 3 (33.3%) + Max absolute difference: 5.999999999994898e-05 + Max relative difference: 2.5713661239633743e-05 + x: torch.ndarray([1.0000, 2.3333, nan], dtype=float64) + y: torch.ndarray([1.0000, 2.3334, nan], dtype=float64) + + >>> np.testing.assert_array_almost_equal([1.0,2.33333,np.nan], + ... [1.0,2.33333, 5], decimal=5) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not almost equal to 5 decimals + + x and y nan location mismatch: + x: torch.ndarray([1.0000, 2.3333, nan], dtype=float64) + y: torch.ndarray([1.0000, 2.3333, 5.0000], dtype=float64) + + """ + __tracebackhide__ = True # Hide traceback for py.test + from torch._numpy import any as npany, float_, issubdtype, number, result_type + + def compare(x, y): + try: + if npany(gisinf(x)) or npany(gisinf(y)): + xinfid = gisinf(x) + yinfid = gisinf(y) + if not (xinfid == yinfid).all(): + return False + # if one item, x and y is +- inf + if x.size == y.size == 1: + return x == y + x = x[~xinfid] + y = y[~yinfid] + except (TypeError, NotImplementedError): + pass + + # make sure y is an inexact type to avoid abs(MIN_INT); will cause + # casting of x later. + dtype = result_type(y, 1.0) + y = asanyarray(y, dtype) + z = abs(x - y) + + if not issubdtype(z.dtype, number): + z = z.astype(float_) # handle object arrays + + return z < 1.5 * 10.0 ** (-decimal) + + assert_array_compare( + compare, + x, + y, + err_msg=err_msg, + verbose=verbose, + header=f"Arrays are not almost equal to {decimal:d} decimals", + precision=decimal, + ) + + +def assert_array_less(x, y, err_msg="", verbose=True): + """ + Raises an AssertionError if two array_like objects are not ordered by less + than. + + Given two array_like objects, check that the shape is equal and all + elements of the first object are strictly smaller than those of the + second object. An exception is raised at shape mismatch or incorrectly + ordered values. Shape mismatch does not raise if an object has zero + dimension. In contrast to the standard usage in numpy, NaNs are + compared, no assertion is raised if both objects have NaNs in the same + positions. + + + + Parameters + ---------- + x : array_like + The smaller object to check. + y : array_like + The larger object to compare. 
+ err_msg : string + The error message to be printed in case of failure. + verbose : bool + If True, the conflicting values are appended to the error message. + + Raises + ------ + AssertionError + If actual and desired objects are not equal. + + See Also + -------- + assert_array_equal: tests objects for equality + assert_array_almost_equal: test objects for equality up to precision + + + + Examples + -------- + >>> np.testing.assert_array_less([1.0, 1.0, np.nan], [1.1, 2.0, np.nan]) + >>> np.testing.assert_array_less([1.0, 1.0, np.nan], [1, 2.0, np.nan]) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not less-ordered + + Mismatched elements: 1 / 3 (33.3%) + Max absolute difference: 1.0 + Max relative difference: 0.5 + x: torch.ndarray([1., 1., nan], dtype=float64) + y: torch.ndarray([1., 2., nan], dtype=float64) + + >>> np.testing.assert_array_less([1.0, 4.0], 3) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not less-ordered + + Mismatched elements: 1 / 2 (50%) + Max absolute difference: 2.0 + Max relative difference: 0.6666666666666666 + x: torch.ndarray([1., 4.], dtype=float64) + y: torch.ndarray(3) + + >>> np.testing.assert_array_less([1.0, 2.0, 3.0], [4]) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not less-ordered + + (shapes (3,), (1,) mismatch) + x: torch.ndarray([1., 2., 3.], dtype=float64) + y: torch.ndarray([4]) + + """ + __tracebackhide__ = True # Hide traceback for py.test + assert_array_compare( + operator.__lt__, + x, + y, + err_msg=err_msg, + verbose=verbose, + header="Arrays are not less-ordered", + equal_inf=False, + ) + + +def assert_string_equal(actual, desired): + """ + Test if two strings are equal. + + If the given strings are equal, `assert_string_equal` does nothing. + If they are not equal, an AssertionError is raised, and the diff + between the strings is shown. + + Parameters + ---------- + actual : str + The string to test for equality against the expected string. + desired : str + The expected string. + + Examples + -------- + >>> np.testing.assert_string_equal('abc', 'abc') # doctest: +SKIP + >>> np.testing.assert_string_equal('abc', 'abcd') # doctest: +SKIP + Traceback (most recent call last): + File "", line 1, in + ... + AssertionError: Differences in strings: + - abc+ abcd? + + + """ + # delay import of difflib to reduce startup time + __tracebackhide__ = True # Hide traceback for py.test + import difflib + + if not isinstance(actual, str): + raise AssertionError(repr(type(actual))) + if not isinstance(desired, str): + raise AssertionError(repr(type(desired))) + if desired == actual: + return + + diff = list( + difflib.Differ().compare(actual.splitlines(True), desired.splitlines(True)) + ) + diff_list = [] + while diff: + d1 = diff.pop(0) + if d1.startswith(" "): + continue + if d1.startswith("- "): + l = [d1] + d2 = diff.pop(0) + if d2.startswith("? "): + l.append(d2) + d2 = diff.pop(0) + if not d2.startswith("+ "): + raise AssertionError(repr(d2)) + l.append(d2) + if diff: + d3 = diff.pop(0) + if d3.startswith("? 
"): + l.append(d3) + else: + diff.insert(0, d3) + if d2[2:] == d1[2:]: + continue + diff_list.extend(l) + continue + raise AssertionError(repr(d1)) + if not diff_list: + return + msg = f"Differences in strings:\n{''.join(diff_list).rstrip()}" + if actual != desired: + raise AssertionError(msg) + + +import unittest + + +class _Dummy(unittest.TestCase): + def nop(self): + pass + + +_d = _Dummy("nop") + + +def assert_raises_regex(exception_class, expected_regexp, *args, **kwargs): + """ + assert_raises_regex(exception_class, expected_regexp, callable, *args, + **kwargs) + assert_raises_regex(exception_class, expected_regexp) + + Fail unless an exception of class exception_class and with message that + matches expected_regexp is thrown by callable when invoked with arguments + args and keyword arguments kwargs. + + Alternatively, can be used as a context manager like `assert_raises`. + + Notes + ----- + .. versionadded:: 1.9.0 + + """ + __tracebackhide__ = True # Hide traceback for py.test + return _d.assertRaisesRegex(exception_class, expected_regexp, *args, **kwargs) + + +def decorate_methods(cls, decorator, testmatch=None): + """ + Apply a decorator to all methods in a class matching a regular expression. + + The given decorator is applied to all public methods of `cls` that are + matched by the regular expression `testmatch` + (``testmatch.search(methodname)``). Methods that are private, i.e. start + with an underscore, are ignored. + + Parameters + ---------- + cls : class + Class whose methods to decorate. + decorator : function + Decorator to apply to methods + testmatch : compiled regexp or str, optional + The regular expression. Default value is None, in which case the + nose default (``re.compile(r'(?:^|[\\b_\\.%s-])[Tt]est' % os.sep)``) + is used. + If `testmatch` is a string, it is compiled to a regular expression + first. + + """ + if testmatch is None: + testmatch = re.compile(rf"(?:^|[\\b_\\.{os.sep}-])[Tt]est") + else: + testmatch = re.compile(testmatch) + cls_attr = cls.__dict__ + + # delayed import to reduce startup time + from inspect import isfunction + + methods = [_m for _m in cls_attr.values() if isfunction(_m)] + for function in methods: + try: + if hasattr(function, "compat_func_name"): + funcname = function.compat_func_name + else: + funcname = function.__name__ + except AttributeError: + # not a function + continue + if testmatch.search(funcname) and not funcname.startswith("_"): + setattr(cls, funcname, decorator(function)) + return + + +def _assert_valid_refcount(op): + """ + Check that ufuncs don't mishandle refcount of object `1`. + Used in a few regression tests. + """ + if not HAS_REFCOUNT: + return True + + import gc + + import numpy as np + + b = np.arange(100 * 100).reshape(100, 100) + c = b + i = 1 + + gc.disable() + try: + rc = sys.getrefcount(i) + for _ in range(15): + d = op(b, c) + assert_(sys.getrefcount(i) >= rc) + finally: + gc.enable() + del d # for pyflakes + + +def assert_allclose( + actual, + desired, + rtol=1e-7, + atol=0, + equal_nan=True, + err_msg="", + verbose=True, + check_dtype=False, +): + """ + Raises an AssertionError if two objects are not equal up to desired + tolerance. + + Given two array_like objects, check that their shapes and all elements + are equal (but see the Notes for the special handling of a scalar). An + exception is raised if the shapes mismatch or any values conflict. In + contrast to the standard usage in numpy, NaNs are compared like numbers, + no assertion is raised if both objects have NaNs in the same positions. 
+ + The test is equivalent to ``allclose(actual, desired, rtol, atol)`` (note + that ``allclose`` has different default values). It compares the difference + between `actual` and `desired` to ``atol + rtol * abs(desired)``. + + .. versionadded:: 1.5.0 + + Parameters + ---------- + actual : array_like + Array obtained. + desired : array_like + Array desired. + rtol : float, optional + Relative tolerance. + atol : float, optional + Absolute tolerance. + equal_nan : bool, optional. + If True, NaNs will compare equal. + err_msg : str, optional + The error message to be printed in case of failure. + verbose : bool, optional + If True, the conflicting values are appended to the error message. + + Raises + ------ + AssertionError + If actual and desired are not equal up to specified precision. + + See Also + -------- + assert_array_almost_equal_nulp, assert_array_max_ulp + + Notes + ----- + When one of `actual` and `desired` is a scalar and the other is + array_like, the function checks that each element of the array_like + object is equal to the scalar. + + Examples + -------- + >>> x = [1e-5, 1e-3, 1e-1] + >>> y = np.arccos(np.cos(x)) + >>> np.testing.assert_allclose(x, y, rtol=1e-5, atol=0) + + """ + __tracebackhide__ = True # Hide traceback for py.test + + def compare(x, y): + return np.isclose(x, y, rtol=rtol, atol=atol, equal_nan=equal_nan) + + actual, desired = asanyarray(actual), asanyarray(desired) + header = f"Not equal to tolerance rtol={rtol:g}, atol={atol:g}" + + if check_dtype: + assert actual.dtype == desired.dtype + + assert_array_compare( + compare, + actual, + desired, + err_msg=str(err_msg), + verbose=verbose, + header=header, + equal_nan=equal_nan, + ) + + +def assert_array_almost_equal_nulp(x, y, nulp=1): + """ + Compare two arrays relatively to their spacing. + + This is a relatively robust method to compare two arrays whose amplitude + is variable. + + Parameters + ---------- + x, y : array_like + Input arrays. + nulp : int, optional + The maximum number of unit in the last place for tolerance (see Notes). + Default is 1. + + Returns + ------- + None + + Raises + ------ + AssertionError + If the spacing between `x` and `y` for one or more elements is larger + than `nulp`. + + See Also + -------- + assert_array_max_ulp : Check that all items of arrays differ in at most + N Units in the Last Place. + spacing : Return the distance between x and the nearest adjacent number. + + Notes + ----- + An assertion is raised if the following condition is not met:: + + abs(x - y) <= nulp * spacing(maximum(abs(x), abs(y))) + + Examples + -------- + >>> x = np.array([1., 1e-10, 1e-20]) + >>> eps = np.finfo(x.dtype).eps + >>> np.testing.assert_array_almost_equal_nulp(x, x*eps/2 + x) # doctest: +SKIP + + >>> np.testing.assert_array_almost_equal_nulp(x, x*eps + x) # doctest: +SKIP + Traceback (most recent call last): + ... + AssertionError: X and Y are not equal to 1 ULP (max is 2) + + """ + __tracebackhide__ = True # Hide traceback for py.test + import numpy as np + + ax = np.abs(x) + ay = np.abs(y) + ref = nulp * np.spacing(np.where(ax > ay, ax, ay)) + if not np.all(np.abs(x - y) <= ref): + if np.iscomplexobj(x) or np.iscomplexobj(y): + msg = f"X and Y are not equal to {nulp:d} ULP" + else: + max_nulp = np.max(nulp_diff(x, y)) + msg = f"X and Y are not equal to {nulp:d} ULP (max is {max_nulp:g})" + raise AssertionError(msg) + + +def assert_array_max_ulp(a, b, maxulp=1, dtype=None): + """ + Check that all items of arrays differ in at most N Units in the Last Place. 
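+    Here one ULP (unit in the last place) is, roughly, the spacing between a
+    floating point value and the next representable value of the same dtype;
+    see also ``np.spacing``.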
+ + Parameters + ---------- + a, b : array_like + Input arrays to be compared. + maxulp : int, optional + The maximum number of units in the last place that elements of `a` and + `b` can differ. Default is 1. + dtype : dtype, optional + Data-type to convert `a` and `b` to if given. Default is None. + + Returns + ------- + ret : ndarray + Array containing number of representable floating point numbers between + items in `a` and `b`. + + Raises + ------ + AssertionError + If one or more elements differ by more than `maxulp`. + + Notes + ----- + For computing the ULP difference, this API does not differentiate between + various representations of NAN (ULP difference between 0x7fc00000 and 0xffc00000 + is zero). + + See Also + -------- + assert_array_almost_equal_nulp : Compare two arrays relatively to their + spacing. + + Examples + -------- + >>> a = np.linspace(0., 1., 100) + >>> res = np.testing.assert_array_max_ulp(a, np.arcsin(np.sin(a))) # doctest: +SKIP + + """ + __tracebackhide__ = True # Hide traceback for py.test + import numpy as np + + ret = nulp_diff(a, b, dtype) + if not np.all(ret <= maxulp): + raise AssertionError( + f"Arrays are not almost equal up to {maxulp:g} " + f"ULP (max difference is {np.max(ret):g} ULP)" + ) + return ret + + +def nulp_diff(x, y, dtype=None): + """For each item in x and y, return the number of representable floating + points between them. + + Parameters + ---------- + x : array_like + first input array + y : array_like + second input array + dtype : dtype, optional + Data-type to convert `x` and `y` to if given. Default is None. + + Returns + ------- + nulp : array_like + number of representable floating point numbers between each item in x + and y. + + Notes + ----- + For computing the ULP difference, this API does not differentiate between + various representations of NAN (ULP difference between 0x7fc00000 and 0xffc00000 + is zero). 
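+    The count is obtained by reinterpreting the floating point bit patterns as
+    signed-magnitude integers (see ``integer_repr`` below) and taking the
+    absolute difference of those integers.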
+ + Examples + -------- + # By definition, epsilon is the smallest number such as 1 + eps != 1, so + # there should be exactly one ULP between 1 and 1 + eps + >>> nulp_diff(1, 1 + np.finfo(x.dtype).eps) # doctest: +SKIP + 1.0 + """ + import numpy as np + + if dtype: + x = np.asarray(x, dtype=dtype) + y = np.asarray(y, dtype=dtype) + else: + x = np.asarray(x) + y = np.asarray(y) + + t = np.common_type(x, y) + if np.iscomplexobj(x) or np.iscomplexobj(y): + raise NotImplementedError("_nulp not implemented for complex array") + + x = np.array([x], dtype=t) + y = np.array([y], dtype=t) + + x[np.isnan(x)] = np.nan + y[np.isnan(y)] = np.nan + + if not x.shape == y.shape: + raise ValueError(f"x and y do not have the same shape: {x.shape} - {y.shape}") + + def _diff(rx, ry, vdt): + diff = np.asarray(rx - ry, dtype=vdt) + return np.abs(diff) + + rx = integer_repr(x) + ry = integer_repr(y) + return _diff(rx, ry, t) + + +def _integer_repr(x, vdt, comp): + # Reinterpret binary representation of the float as sign-magnitude: + # take into account two-complement representation + # See also + # https://randomascii.wordpress.com/2012/02/25/comparing-floating-point-numbers-2012-edition/ + rx = x.view(vdt) + if not (rx.size == 1): + rx[rx < 0] = comp - rx[rx < 0] + else: + if rx < 0: + rx = comp - rx + + return rx + + +def integer_repr(x): + """Return the signed-magnitude interpretation of the binary representation + of x.""" + import numpy as np + + if x.dtype == np.float16: + return _integer_repr(x, np.int16, np.int16(-(2**15))) + elif x.dtype == np.float32: + return _integer_repr(x, np.int32, np.int32(-(2**31))) + elif x.dtype == np.float64: + return _integer_repr(x, np.int64, np.int64(-(2**63))) + else: + raise ValueError(f"Unsupported dtype {x.dtype}") + + +@contextlib.contextmanager +def _assert_warns_context(warning_class, name=None): + __tracebackhide__ = True # Hide traceback for py.test + with suppress_warnings() as sup: + l = sup.record(warning_class) + yield + if not len(l) > 0: + name_str = f" when calling {name}" if name is not None else "" + raise AssertionError("No warning raised" + name_str) + + +def assert_warns(warning_class, *args, **kwargs): + """ + Fail unless the given callable throws the specified warning. + + A warning of class warning_class should be thrown by the callable when + invoked with arguments args and keyword arguments kwargs. + If a different type of warning is thrown, it will not be caught. + + If called with all arguments other than the warning class omitted, may be + used as a context manager: + + with assert_warns(SomeWarning): + do_something() + + The ability to be used as a context manager is new in NumPy v1.11.0. + + .. versionadded:: 1.4.0 + + Parameters + ---------- + warning_class : class + The class defining the warning that `func` is expected to throw. + func : callable, optional + Callable to test + *args : Arguments + Arguments for `func`. + **kwargs : Kwargs + Keyword arguments for `func`. + + Returns + ------- + The value returned by `func`. + + Examples + -------- + >>> import warnings + >>> def deprecated_func(num): + ... warnings.warn("Please upgrade", DeprecationWarning) + ... return num*num + >>> with np.testing.assert_warns(DeprecationWarning): + ... 
assert deprecated_func(4) == 16 + >>> # or passing a func + >>> ret = np.testing.assert_warns(DeprecationWarning, deprecated_func, 4) + >>> assert ret == 16 + """ + if not args: + return _assert_warns_context(warning_class) + + func = args[0] + args = args[1:] + with _assert_warns_context(warning_class, name=func.__name__): + return func(*args, **kwargs) + + +@contextlib.contextmanager +def _assert_no_warnings_context(name=None): + __tracebackhide__ = True # Hide traceback for py.test + with warnings.catch_warnings(record=True) as l: + warnings.simplefilter("always") + yield + if len(l) > 0: + name_str = f" when calling {name}" if name is not None else "" + raise AssertionError(f"Got warnings{name_str}: {l}") + + +def assert_no_warnings(*args, **kwargs): + """ + Fail if the given callable produces any warnings. + + If called with all arguments omitted, may be used as a context manager: + + with assert_no_warnings(): + do_something() + + The ability to be used as a context manager is new in NumPy v1.11.0. + + .. versionadded:: 1.7.0 + + Parameters + ---------- + func : callable + The callable to test. + \\*args : Arguments + Arguments passed to `func`. + \\*\\*kwargs : Kwargs + Keyword arguments passed to `func`. + + Returns + ------- + The value returned by `func`. + + """ + if not args: + return _assert_no_warnings_context() + + func = args[0] + args = args[1:] + with _assert_no_warnings_context(name=func.__name__): + return func(*args, **kwargs) + + +def _gen_alignment_data(dtype=float32, type="binary", max_size=24): + """ + generator producing data with different alignment and offsets + to test simd vectorization + + Parameters + ---------- + dtype : dtype + data type to produce + type : string + 'unary': create data for unary operations, creates one input + and output array + 'binary': create data for unary operations, creates two input + and output array + max_size : integer + maximum size of data to produce + + Returns + ------- + if type is 'unary' yields one output, one input array and a message + containing information on the data + if type is 'binary' yields one output array, two input array and a message + containing information on the data + + """ + ufmt = "unary offset=(%d, %d), size=%d, dtype=%r, %s" + bfmt = "binary offset=(%d, %d, %d), size=%d, dtype=%r, %s" + for o in range(3): + for s in range(o + 2, max(o + 3, max_size)): + if type == "unary": + + def inp(): + return arange(s, dtype=dtype)[o:] + + out = empty((s,), dtype=dtype)[o:] + yield out, inp(), ufmt % (o, o, s, dtype, "out of place") + d = inp() + yield d, d, ufmt % (o, o, s, dtype, "in place") + yield out[1:], inp()[:-1], ufmt % ( + o + 1, + o, + s - 1, + dtype, + "out of place", + ) + yield out[:-1], inp()[1:], ufmt % ( + o, + o + 1, + s - 1, + dtype, + "out of place", + ) + yield inp()[:-1], inp()[1:], ufmt % (o, o + 1, s - 1, dtype, "aliased") + yield inp()[1:], inp()[:-1], ufmt % (o + 1, o, s - 1, dtype, "aliased") + if type == "binary": + + def inp1(): + return arange(s, dtype=dtype)[o:] + + inp2 = inp1 + out = empty((s,), dtype=dtype)[o:] + yield out, inp1(), inp2(), bfmt % (o, o, o, s, dtype, "out of place") + d = inp1() + yield d, d, inp2(), bfmt % (o, o, o, s, dtype, "in place1") + d = inp2() + yield d, inp1(), d, bfmt % (o, o, o, s, dtype, "in place2") + yield out[1:], inp1()[:-1], inp2()[:-1], bfmt % ( + o + 1, + o, + o, + s - 1, + dtype, + "out of place", + ) + yield out[:-1], inp1()[1:], inp2()[:-1], bfmt % ( + o, + o + 1, + o, + s - 1, + dtype, + "out of place", + ) + yield out[:-1], inp1()[:-1], 
inp2()[1:], bfmt % ( + o, + o, + o + 1, + s - 1, + dtype, + "out of place", + ) + yield inp1()[1:], inp1()[:-1], inp2()[:-1], bfmt % ( + o + 1, + o, + o, + s - 1, + dtype, + "aliased", + ) + yield inp1()[:-1], inp1()[1:], inp2()[:-1], bfmt % ( + o, + o + 1, + o, + s - 1, + dtype, + "aliased", + ) + yield inp1()[:-1], inp1()[:-1], inp2()[1:], bfmt % ( + o, + o, + o + 1, + s - 1, + dtype, + "aliased", + ) + + +class IgnoreException(Exception): + "Ignoring this exception due to disabled feature" + + +@contextlib.contextmanager +def tempdir(*args, **kwargs): + """Context manager to provide a temporary test folder. + + All arguments are passed as this to the underlying tempfile.mkdtemp + function. + + """ + tmpdir = mkdtemp(*args, **kwargs) + try: + yield tmpdir + finally: + shutil.rmtree(tmpdir) + + +@contextlib.contextmanager +def temppath(*args, **kwargs): + """Context manager for temporary files. + + Context manager that returns the path to a closed temporary file. Its + parameters are the same as for tempfile.mkstemp and are passed directly + to that function. The underlying file is removed when the context is + exited, so it should be closed at that time. + + Windows does not allow a temporary file to be opened if it is already + open, so the underlying file must be closed after opening before it + can be opened again. + + """ + fd, path = mkstemp(*args, **kwargs) + os.close(fd) + try: + yield path + finally: + os.remove(path) + + +class clear_and_catch_warnings(warnings.catch_warnings): + """Context manager that resets warning registry for catching warnings + + Warnings can be slippery, because, whenever a warning is triggered, Python + adds a ``__warningregistry__`` member to the *calling* module. This makes + it impossible to retrigger the warning in this module, whatever you put in + the warnings filters. This context manager accepts a sequence of `modules` + as a keyword argument to its constructor and: + + * stores and removes any ``__warningregistry__`` entries in given `modules` + on entry; + * resets ``__warningregistry__`` to its previous state on exit. + + This makes it possible to trigger any warning afresh inside the context + manager without disturbing the state of warnings outside. + + For compatibility with Python 3.0, please consider all arguments to be + keyword-only. + + Parameters + ---------- + record : bool, optional + Specifies whether warnings should be captured by a custom + implementation of ``warnings.showwarning()`` and be appended to a list + returned by the context manager. Otherwise None is returned by the + context manager. The objects appended to the list are arguments whose + attributes mirror the arguments to ``showwarning()``. + modules : sequence, optional + Sequence of modules for which to reset warnings registry on entry and + restore on exit. To work correctly, all 'ignore' filters should + filter by one of these modules. + + Examples + -------- + >>> import warnings + >>> with np.testing.clear_and_catch_warnings( # doctest: +SKIP + ... modules=[np.core.fromnumeric]): + ... warnings.simplefilter('always') + ... warnings.filterwarnings('ignore', module='np.core.fromnumeric') + ... # do something that raises a warning but ignore those in + ... 
# np.core.fromnumeric + """ + + class_modules = () + + def __init__(self, record=False, modules=()): + self.modules = set(modules).union(self.class_modules) + self._warnreg_copies = {} + super().__init__(record=record) + + def __enter__(self): + for mod in self.modules: + if hasattr(mod, "__warningregistry__"): + mod_reg = mod.__warningregistry__ + self._warnreg_copies[mod] = mod_reg.copy() + mod_reg.clear() + return super().__enter__() + + def __exit__(self, *exc_info): + super().__exit__(*exc_info) + for mod in self.modules: + if hasattr(mod, "__warningregistry__"): + mod.__warningregistry__.clear() + if mod in self._warnreg_copies: + mod.__warningregistry__.update(self._warnreg_copies[mod]) + + +class suppress_warnings: + """ + Context manager and decorator doing much the same as + ``warnings.catch_warnings``. + + However, it also provides a filter mechanism to work around + https://bugs.python.org/issue4180. + + This bug causes Python before 3.4 to not reliably show warnings again + after they have been ignored once (even within catch_warnings). It + means that no "ignore" filter can be used easily, since following + tests might need to see the warning. Additionally it allows easier + specificity for testing warnings and can be nested. + + Parameters + ---------- + forwarding_rule : str, optional + One of "always", "once", "module", or "location". Analogous to + the usual warnings module filter mode, it is useful to reduce + noise mostly on the outmost level. Unsuppressed and unrecorded + warnings will be forwarded based on this rule. Defaults to "always". + "location" is equivalent to the warnings "default", match by exact + location the warning warning originated from. + + Notes + ----- + Filters added inside the context manager will be discarded again + when leaving it. Upon entering all filters defined outside a + context will be applied automatically. + + When a recording filter is added, matching warnings are stored in the + ``log`` attribute as well as in the list returned by ``record``. + + If filters are added and the ``module`` keyword is given, the + warning registry of this module will additionally be cleared when + applying it, entering the context, or exiting it. This could cause + warnings to appear a second time after leaving the context if they + were configured to be printed once (default) and were already + printed before the context was entered. + + Nesting this context manager will work as expected when the + forwarding rule is "always" (default). Unfiltered and unrecorded + warnings will be passed out and be matched by the outer level. + On the outmost level they will be printed (or caught by another + warnings context). The forwarding rule argument can modify this + behaviour. + + Like ``catch_warnings`` this context manager is not threadsafe. + + Examples + -------- + + With a context manager:: + + with np.testing.suppress_warnings() as sup: + sup.filter(DeprecationWarning, "Some text") + sup.filter(module=np.ma.core) + log = sup.record(FutureWarning, "Does this occur?") + command_giving_warnings() + # The FutureWarning was given once, the filtered warnings were + # ignored. 
All other warnings abide outside settings (may be + # printed/error) + assert_(len(log) == 1) + assert_(len(sup.log) == 1) # also stored in log attribute + + Or as a decorator:: + + sup = np.testing.suppress_warnings() + sup.filter(module=np.ma.core) # module must match exactly + @sup + def some_function(): + # do something which causes a warning in np.ma.core + pass + """ + + def __init__(self, forwarding_rule="always"): + self._entered = False + + # Suppressions are either instance or defined inside one with block: + self._suppressions = [] + + if forwarding_rule not in {"always", "module", "once", "location"}: + raise ValueError("unsupported forwarding rule.") + self._forwarding_rule = forwarding_rule + + def _clear_registries(self): + if hasattr(warnings, "_filters_mutated"): + # clearing the registry should not be necessary on new pythons, + # instead the filters should be mutated. + warnings._filters_mutated() + return + # Simply clear the registry, this should normally be harmless, + # note that on new pythons it would be invalidated anyway. + for module in self._tmp_modules: + if hasattr(module, "__warningregistry__"): + module.__warningregistry__.clear() + + def _filter(self, category=Warning, message="", module=None, record=False): + if record: + record = [] # The log where to store warnings + else: + record = None + if self._entered: + if module is None: + warnings.filterwarnings("always", category=category, message=message) + else: + module_regex = module.__name__.replace(".", r"\.") + "$" + warnings.filterwarnings( + "always", category=category, message=message, module=module_regex + ) + self._tmp_modules.add(module) + self._clear_registries() + + self._tmp_suppressions.append( + (category, message, re.compile(message, re.IGNORECASE), module, record) + ) + else: + self._suppressions.append( + (category, message, re.compile(message, re.IGNORECASE), module, record) + ) + + return record + + def filter(self, category=Warning, message="", module=None): + """ + Add a new suppressing filter or apply it if the state is entered. + + Parameters + ---------- + category : class, optional + Warning class to filter + message : string, optional + Regular expression matching the warning message. + module : module, optional + Module to filter for. Note that the module (and its file) + must match exactly and cannot be a submodule. This may make + it unreliable for external modules. + + Notes + ----- + When added within a context, filters are only added inside + the context and will be forgotten when the context is exited. + """ + self._filter(category=category, message=message, module=module, record=False) + + def record(self, category=Warning, message="", module=None): + """ + Append a new recording filter or apply it if the state is entered. + + All warnings matching will be appended to the ``log`` attribute. + + Parameters + ---------- + category : class, optional + Warning class to filter + message : string, optional + Regular expression matching the warning message. + module : module, optional + Module to filter for. Note that the module (and its file) + must match exactly and cannot be a submodule. This may make + it unreliable for external modules. + + Returns + ------- + log : list + A list which will be filled with all matched warnings. + + Notes + ----- + When added within a context, filters are only added inside + the context and will be forgotten when the context is exited. 
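+
+        For illustration only (a minimal sketch, assuming the standard
+        ``warnings`` module behaviour)::
+
+            with np.testing.suppress_warnings() as sup:
+                log = sup.record(UserWarning, "some message")
+                # Warnings whose message matches the regex are appended to
+                # ``log`` (and to ``sup.log``) instead of being shown.
+                warnings.warn("some message was emitted", UserWarning)
+            assert len(log) == 1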
+ """ + return self._filter( + category=category, message=message, module=module, record=True + ) + + def __enter__(self): + if self._entered: + raise RuntimeError("cannot enter suppress_warnings twice.") + + self._orig_show = warnings.showwarning + self._filters = warnings.filters + warnings.filters = self._filters[:] + + self._entered = True + self._tmp_suppressions = [] + self._tmp_modules = set() + self._forwarded = set() + + self.log = [] # reset global log (no need to keep same list) + + for cat, mess, _, mod, log in self._suppressions: + if log is not None: + del log[:] # clear the log + if mod is None: + warnings.filterwarnings("always", category=cat, message=mess) + else: + module_regex = mod.__name__.replace(".", r"\.") + "$" + warnings.filterwarnings( + "always", category=cat, message=mess, module=module_regex + ) + self._tmp_modules.add(mod) + warnings.showwarning = self._showwarning + self._clear_registries() + + return self + + def __exit__(self, *exc_info): + warnings.showwarning = self._orig_show + warnings.filters = self._filters + self._clear_registries() + self._entered = False + del self._orig_show + del self._filters + + def _showwarning( + self, message, category, filename, lineno, *args, use_warnmsg=None, **kwargs + ): + for cat, _, pattern, mod, rec in (self._suppressions + self._tmp_suppressions)[ + ::-1 + ]: + if issubclass(category, cat) and pattern.match(message.args[0]) is not None: + if mod is None: + # Message and category match, either recorded or ignored + if rec is not None: + msg = WarningMessage( + message, category, filename, lineno, **kwargs + ) + self.log.append(msg) + rec.append(msg) + return + # Use startswith, because warnings strips the c or o from + # .pyc/.pyo files. + elif mod.__file__.startswith(filename): + # The message and module (filename) match + if rec is not None: + msg = WarningMessage( + message, category, filename, lineno, **kwargs + ) + self.log.append(msg) + rec.append(msg) + return + + # There is no filter in place, so pass to the outside handler + # unless we should only pass it once + if self._forwarding_rule == "always": + if use_warnmsg is None: + self._orig_show(message, category, filename, lineno, *args, **kwargs) + else: + self._orig_showmsg(use_warnmsg) + return + + if self._forwarding_rule == "once": + signature = (message.args, category) + elif self._forwarding_rule == "module": + signature = (message.args, category, filename) + elif self._forwarding_rule == "location": + signature = (message.args, category, filename, lineno) + + if signature in self._forwarded: + return + self._forwarded.add(signature) + if use_warnmsg is None: + self._orig_show(message, category, filename, lineno, *args, **kwargs) + else: + self._orig_showmsg(use_warnmsg) + + def __call__(self, func): + """ + Function decorator to apply certain suppressions to a whole + function. + """ + + @wraps(func) + def new_func(*args, **kwargs): + with self: + return func(*args, **kwargs) + + return new_func + + +@contextlib.contextmanager +def _assert_no_gc_cycles_context(name=None): + __tracebackhide__ = True # Hide traceback for py.test + + # not meaningful to test if there is no refcounting + if not HAS_REFCOUNT: + yield + return + + assert_(gc.isenabled()) + gc.disable() + gc_debug = gc.get_debug() + try: + for _ in range(100): + if gc.collect() == 0: + break + else: + raise RuntimeError( + "Unable to fully collect garbage - perhaps a __del__ method " + "is creating more reference cycles?" 
+ ) + + gc.set_debug(gc.DEBUG_SAVEALL) + yield + # gc.collect returns the number of unreachable objects in cycles that + # were found -- we are checking that no cycles were created in the context + n_objects_in_cycles = gc.collect() + objects_in_cycles = gc.garbage[:] + finally: + del gc.garbage[:] + gc.set_debug(gc_debug) + gc.enable() + + if n_objects_in_cycles: + name_str = f" when calling {name}" if name is not None else "" + raise AssertionError( + "Reference cycles were found{}: {} objects were collected, " + "of which {} are shown below:{}".format( + name_str, + n_objects_in_cycles, + len(objects_in_cycles), + "".join( + "\n {} object with id={}:\n {}".format( + type(o).__name__, + id(o), + pprint.pformat(o).replace("\n", "\n "), + ) + for o in objects_in_cycles + ), + ) + ) + + +def assert_no_gc_cycles(*args, **kwargs): + """ + Fail if the given callable produces any reference cycles. + + If called with all arguments omitted, may be used as a context manager: + + with assert_no_gc_cycles(): + do_something() + + .. versionadded:: 1.15.0 + + Parameters + ---------- + func : callable + The callable to test. + \\*args : Arguments + Arguments passed to `func`. + \\*\\*kwargs : Kwargs + Keyword arguments passed to `func`. + + Returns + ------- + Nothing. The result is deliberately discarded to ensure that all cycles + are found. + + """ + if not args: + return _assert_no_gc_cycles_context() + + func = args[0] + args = args[1:] + with _assert_no_gc_cycles_context(name=func.__name__): + func(*args, **kwargs) + + +def break_cycles(): + """ + Break reference cycles by calling gc.collect + Objects can call other objects' methods (for instance, another object's + __del__) inside their own __del__. On PyPy, the interpreter only runs + between calls to gc.collect, so multiple calls are needed to completely + release all cycles. + """ + + gc.collect() + if IS_PYPY: + # a few more, just to make sure all the finalizers are called + gc.collect() + gc.collect() + gc.collect() + gc.collect() + + +def requires_memory(free_bytes): + """Decorator to skip a test if not enough memory is available""" + import pytest + + def decorator(func): + @wraps(func) + def wrapper(*a, **kw): + msg = check_free_memory(free_bytes) + if msg is not None: + pytest.skip(msg) + + try: + return func(*a, **kw) + except MemoryError: + # Probably ran out of memory regardless: don't regard as failure + pytest.xfail("MemoryError raised") + + return wrapper + + return decorator + + +def check_free_memory(free_bytes): + """ + Check whether `free_bytes` amount of memory is currently free. + Returns: None if enough memory available, otherwise error message + """ + env_var = "NPY_AVAILABLE_MEM" + env_value = os.environ.get(env_var) + if env_value is not None: + try: + mem_free = _parse_size(env_value) + except ValueError as exc: + raise ValueError( # noqa: B904 + f"Invalid environment variable {env_var}: {exc}" + ) + + msg = ( + f"{free_bytes / 1e9} GB memory required, but environment variable " + f"NPY_AVAILABLE_MEM={env_value} set" + ) + else: + mem_free = _get_mem_available() + + if mem_free is None: + msg = ( + "Could not determine available memory; set NPY_AVAILABLE_MEM " + "environment variable (e.g. NPY_AVAILABLE_MEM=16GB) to run " + "the test." + ) + mem_free = -1 + else: + msg = f"{free_bytes / 1e9} GB memory required, but {mem_free / 1e9} GB available" + + return msg if mem_free < free_bytes else None + + +def _parse_size(size_str): + """Convert memory size strings ('12 GB' etc.) 
to float""" + suffixes = { + "": 1, + "b": 1, + "k": 1000, + "m": 1000**2, + "g": 1000**3, + "t": 1000**4, + "kb": 1000, + "mb": 1000**2, + "gb": 1000**3, + "tb": 1000**4, + "kib": 1024, + "mib": 1024**2, + "gib": 1024**3, + "tib": 1024**4, + } + + size_re = re.compile( + r"^\s*(\d+|\d+\.\d+)\s*({})\s*$".format("|".join(suffixes.keys())), + re.IGNORECASE, + ) + + m = size_re.match(size_str.lower()) + if not m or m.group(2) not in suffixes: + raise ValueError(f"value {size_str!r} not a valid size") + return int(float(m.group(1)) * suffixes[m.group(2)]) + + +def _get_mem_available(): + """Return available memory in bytes, or None if unknown.""" + try: + import psutil + + return psutil.virtual_memory().available + except (ImportError, AttributeError): + pass + + if sys.platform.startswith("linux"): + info = {} + with open("/proc/meminfo") as f: + for line in f: + p = line.split() + info[p[0].strip(":").lower()] = int(p[1]) * 1024 + + if "memavailable" in info: + # Linux >= 3.14 + return info["memavailable"] + else: + return info["memfree"] + info["cached"] + + return None + + +def _no_tracing(func): + """ + Decorator to temporarily turn off tracing for the duration of a test. + Needed in tests that check refcounting, otherwise the tracing itself + influences the refcounts + """ + if not hasattr(sys, "gettrace"): + return func + else: + + @wraps(func) + def wrapper(*args, **kwargs): + original_trace = sys.gettrace() + try: + sys.settrace(None) + return func(*args, **kwargs) + finally: + sys.settrace(original_trace) + + return wrapper + + +def _get_glibc_version(): + try: + ver = os.confstr("CS_GNU_LIBC_VERSION").rsplit(" ")[1] + except Exception: + ver = "0.0" + + return ver + + +_glibcver = _get_glibc_version() + + +def _glibc_older_than(x): + return _glibcver != "0.0" and _glibcver < x diff --git a/.venv/lib/python3.12/site-packages/torch/_prims/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_prims/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a976409daa0a2e9015a8f20975506e623ba42847 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_prims/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_prims/__pycache__/context.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_prims/__pycache__/context.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb3f80afed7d37f080b289dcb4b277aa3367f152 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_prims/__pycache__/context.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_prims/__pycache__/debug_prims.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_prims/__pycache__/debug_prims.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5bca8683fcd11b8b1e930783155bc9a5418b5b90 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_prims/__pycache__/debug_prims.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_prims/__pycache__/executor.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_prims/__pycache__/executor.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e75c9a12d097cce5d2f52711d01d55399a78e5c9 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_prims/__pycache__/executor.cpython-312.pyc differ diff --git 
a/.venv/lib/python3.12/site-packages/torch/_prims/__pycache__/rng_prims.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_prims/__pycache__/rng_prims.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9417d9036cc153aa601ed0e3fa1a1a086aea4b0 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_prims/__pycache__/rng_prims.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_refs/__pycache__/_conversions.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_refs/__pycache__/_conversions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..418e1efc5412fd9265bbab9a3ea404c88bb481b7 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_refs/__pycache__/_conversions.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_refs/__pycache__/fft.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_refs/__pycache__/fft.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0afd962eac9b59397eebc9f3270c26cc2c1d4bce Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_refs/__pycache__/fft.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_refs/linalg/__init__.py b/.venv/lib/python3.12/site-packages/torch/_refs/linalg/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..418691fe24aaa9e5b4263fb884fbb371c478094c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_refs/linalg/__init__.py @@ -0,0 +1,343 @@ +# mypy: allow-untyped-defs +from functools import partial +from typing import Optional, Union + +import torch +import torch._prims as prims +import torch._prims_common as utils +import torch._refs as refs +import torch._refs.linalg as linalg +from torch import Tensor +from torch._prims_common import ( + check_fp_or_complex, + check_is_matrix, + Dim, + DimsType, + ELEMENTWISE_TYPE_PROMOTION_KIND, + IntLike, + TensorLikeType, +) +from torch._prims_common.wrappers import ( + _maybe_convert_to_dtype, + elementwise_type_promotion_wrapper, + out_wrapper, +) + + +__all__ = [ + "diagonal", + "matrix_norm", + "norm", + "svd", + "svdvals", + "vector_norm", + "vecdot", + "cross", +] + + +def _check_norm_dtype(dtype: Optional[torch.dtype], x_dtype: torch.dtype, fn_name: str): + """ + Checks related to the dtype kwarg in `linalg.*norm` functions + """ + if dtype is not None: + torch._check( + utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype), + lambda: f"{fn_name}: dtype should be floating point or complex. Got {dtype}", + ) + torch._check( + utils.is_complex_dtype(dtype) == utils.is_complex_dtype(x_dtype), + lambda: "{fn_name}: dtype should be {d} for {d} inputs. 
Got {dtype}".format( + fn_name=fn_name, + d="complex" if utils.is_complex_dtype(x_dtype) else "real", + dtype=dtype, + ), + ) + torch._check( + utils.get_higher_dtype(dtype, x_dtype) == dtype, + lambda: f"{fn_name}: the dtype of the input ({x_dtype}) should be convertible " + "without narrowing to the specified dtype ({dtype})", + ) + + +import operator + +# Utilities should come BEFORE this import +from torch._decomp import register_decomposition +from torch._decomp.decompositions import pw_cast_for_opmath + + +@register_decomposition(torch._ops.ops.aten.linalg_cross) +@out_wrapper() +@pw_cast_for_opmath +def cross(a: Tensor, b: Tensor, dim: int = -1): + torch._check( + a.ndim == b.ndim, + lambda: "linalg.cross: inputs must have the same number of dimensions.", + ) + torch._check( + a.size(dim) == 3 and b.size(dim) == 3, + lambda: f"linalg.cross: inputs dim {dim} must have length 3, got {a.size(dim)} and {b.size(dim)}", + ) + a, b = torch.broadcast_tensors(a, b) + dim = utils.canonicalize_dim(a.ndim, dim) + idx = torch.arange(3, device=a.device) + return a.index_select(dim, (idx + 1) % 3) * b.index_select( + dim, (idx + 2) % 3 + ) - a.index_select(dim, (idx + 2) % 3) * b.index_select(dim, (idx + 1) % 3) + + +def diagonal( + input: TensorLikeType, + *, + offset: int = 0, + dim1: int = -2, + dim2: int = -1, +) -> TensorLikeType: + return torch.diagonal(input, offset=offset, dim1=dim1, dim2=dim2) + + +def _check_vector_norm_args( + x: TensorLikeType, ord: Union[float, int] = 2, dim: Optional[DimsType] = None +): + from torch.fx.experimental.symbolic_shapes import sym_or + + if not (ord < 0.0 or ord == float("inf")): + return + + torch._check( + sym_or( + x.numel() != 0, + not isinstance(dim, IntLike) and dim is not None and len(dim) != 0, + ), + "linalg.vector_norm cannot compute the {ord} norm on an empty tensor " + "because the operation does not have an identity", + ) + + shape = x.shape + if dim is not None and not isinstance(dim, IntLike): + for d in dim: + torch._check( + sym_or(x.numel() != 0, d < len(shape) and d >= 0 and shape[d] != 0), + "linalg.vector_norm cannot compute the {ord} norm on the " + f"dimension {d} because this dimension is empty and the " + "operation does not have an identity", + ) + + +@register_decomposition(torch._ops.ops.aten.linalg_vector_norm) +@out_wrapper(exact_dtype=True) +def vector_norm( + x: TensorLikeType, + ord: Union[float, int] = 2, + dim: Optional[DimsType] = None, + keepdim: bool = False, + *, + dtype: Optional[torch.dtype] = None, +) -> Tensor: + from torch.fx.experimental.symbolic_shapes import guard_or_false + + check_fp_or_complex(x.dtype, "linalg.vector_norm") + + if isinstance(dim, Dim): + dim = [dim] # type: ignore[assignment] + + _check_vector_norm_args(x, ord, dim) + + _check_norm_dtype(dtype, x.dtype, "linalg.vector_norm") + + computation_dtype, result_dtype = utils.reduction_dtypes( + x, utils.REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT, dtype + ) + + to_result_dtype = partial(_maybe_convert_to_dtype, dtype=result_dtype) + + # Implementation + if ord == 0.0: + return torch.sum(torch.ne(x, 0.0), dim=dim, keepdim=keepdim, dtype=result_dtype) + elif ord == float("inf"): + return to_result_dtype(torch.amax(torch.abs(x), dim=dim, keepdim=keepdim)) # type: ignore[return-value,arg-type] + elif ord == float("-inf"): + return to_result_dtype(torch.amin(torch.abs(x), dim=dim, keepdim=keepdim)) # type: ignore[return-value,arg-type] + else: + # From here on the computation dtype is important as the reduction is non-trivial + x = 
_maybe_convert_to_dtype(x, computation_dtype) # type: ignore[assignment] + reduce_sum = partial(torch.sum, dim=dim, keepdim=keepdim) + + is_ord_even = ord % 2 == 0 if isinstance(ord, IntLike) else ord % 2.0 == 0.0 + if dim == []: + dim = None + + if (dim is None and x.numel() == 1) or ( + dim is not None + and (x.ndim > 0 and all(guard_or_false(x.shape[d] == 1) for d in dim)) + ): + if x.ndim > 64: + raise RuntimeError( + f"Received a tensor with {x.ndim} dimensions, but only tensors with up to 64 dims are supported!" + ) + x = torch.abs(x) + if keepdim or x.ndim == 0: + return to_result_dtype(x).contiguous() + elif dim is None: + return x.flatten()[0] + else: + new_shape = [s for d, s in enumerate(x.shape) if d not in dim] + return to_result_dtype(x.view(new_shape)).contiguous() + + if not (is_ord_even and utils.is_float_dtype(x.dtype)): + x = torch.abs(x) + return to_result_dtype(torch.pow(reduce_sum(torch.pow(x, ord)), 1.0 / ord)) # type: ignore[return-value] + + +def _backshift_permutation(dim0, dim1, ndim): + # Auxiliary function for matrix_norm + # Computes the permutation that moves the two given dimensions to the back + ret = [i for i in range(ndim) if i != dim0 and i != dim1] + ret.extend((dim0, dim1)) + return ret + + +def _inverse_permutation(perm): + # Given a permutation, returns its inverse. It's equivalent to argsort on an array + return [i for i, j in sorted(enumerate(perm), key=operator.itemgetter(1))] + + +# CompositeImplicitAutograd +@out_wrapper(exact_dtype=True) +def matrix_norm( + A: TensorLikeType, + ord: Union[float, str] = "fro", + dim: DimsType = (-2, -1), + keepdim: bool = False, + *, + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + # shape + check_is_matrix(A, "linalg.matrix_norm") + # dim + dim = utils.canonicalize_dims(A.ndim, dim) + if isinstance(dim, Dim): + dim = (dim,) # type: ignore[assignment] + torch._check( + len(dim) == 2, lambda: "linalg.matrix_norm: dim must be a 2-tuple. Got {dim}" + ) + torch._check( + dim[0] != dim[1], + lambda: "linalg.matrix_norm: dims must be different. 
Got ({dim[0]}, {dim[1]})", + ) + # dtype arg + _check_norm_dtype(dtype, A.dtype, "linalg.matrix_norm") + + if isinstance(ord, str): + # ord + torch._check( + ord in ("fro", "nuc"), + lambda: "linalg.matrix_norm: Order {ord} not supported.", + ) + # dtype + check_fp_or_complex( + A.dtype, "linalg.matrix_norm", allow_low_precision_dtypes=ord != "nuc" + ) + + if ord == "fro": + return vector_norm(A, 2, dim, keepdim, dtype=dtype) + else: # ord == "nuc" + if dtype is not None: + A = _maybe_convert_to_dtype(A, dtype) # type: ignore[assignment] + perm = _backshift_permutation(dim[0], dim[1], A.ndim) + result = torch.sum(svdvals(prims.transpose(A, perm)), -1, keepdim) + if keepdim: + inv_perm = _inverse_permutation(perm) + result = prims.transpose(torch.unsqueeze(result, -1), inv_perm) + return result + else: + # ord + abs_ord = abs(ord) + torch._check( + abs_ord in (2, 1, float("inf")), + lambda: "linalg.matrix_norm: Order {ord} not supported.", + ) + # dtype + check_fp_or_complex( + A.dtype, "linalg.matrix_norm", allow_low_precision_dtypes=ord != 2 + ) + + max_min = partial(torch.amax if ord > 0.0 else torch.amin, keepdim=keepdim) + + if abs_ord == 2.0: + if dtype is not None: + A = _maybe_convert_to_dtype(A, dtype) # type: ignore[assignment] + perm = _backshift_permutation(dim[0], dim[1], A.ndim) + result = max_min(svdvals(prims.transpose(A, perm)), dim=-1) + if keepdim: + inv_perm = _inverse_permutation(perm) + result = prims.transpose(torch.unsqueeze(result, -1), inv_perm) + return result + else: # 1, -1, inf, -inf + dim0, dim1 = dim + if abs_ord == float("inf"): + dim0, dim1 = dim1, dim0 + if not keepdim and (dim0 < dim1): + dim1 -= 1 + return max_min( + vector_norm(A, 1.0, dim=dim0, keepdim=keepdim, dtype=dtype), dim1 + ) + + +# CompositeImplicitAutograd +@out_wrapper(exact_dtype=True) +def norm( + A: TensorLikeType, + ord: Optional[Union[float, str]] = None, + dim: Optional[DimsType] = None, + keepdim: bool = False, + *, + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + if dim is not None: + if isinstance(dim, Dim): + dim = (dim,) # type: ignore[assignment] + torch._check( + len(dim) in (1, 2), + lambda: "linalg.norm: If dim is specified, it must be of length 1 or 2. Got {dim}", + ) + elif ord is not None: + torch._check( + A.ndim in (1, 2), + lambda: "linalg.norm: If dim is not specified but ord is, the input must be 1D or 2D. 
Got {A.ndim}D", + ) + + if ord is not None and ( + (dim is not None and len(dim) == 2) or (dim is None and A.ndim == 2) + ): + if dim is None: + dim = (0, 1) + return matrix_norm(A, ord, dim, keepdim, dtype=dtype) + else: + if ord is None: + ord = 2.0 + return vector_norm(A, ord, dim, keepdim, dtype=dtype) # type: ignore[arg-type] + + +# CompositeImplicitAutograd +@out_wrapper("U", "S", "Vh", exact_dtype=True) +def svd(A: TensorLikeType, full_matrices: bool = True) -> tuple[Tensor, Tensor, Tensor]: + return prims.svd(A, full_matrices=full_matrices) + + +# CompositeImplicitAutograd +@out_wrapper(exact_dtype=True) +def svdvals(A: TensorLikeType) -> Tensor: + return svd(A, full_matrices=False)[1] + + +# CompositeImplicitAutograd +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("x", "y"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def vecdot(x: Tensor, y: Tensor, dim: int = -1) -> Tensor: + check_fp_or_complex(x.dtype, "linalg.vecdot") + return (x.conj() * y).sum(dim=dim) diff --git a/.venv/lib/python3.12/site-packages/torch/_refs/linalg/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_refs/linalg/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a1edd15a0976cc4efe6f3e718ed7b4f01726fc75 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_refs/linalg/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_refs/nn/__init__.py b/.venv/lib/python3.12/site-packages/torch/_refs/nn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c9c2ef67bd9d44a21f9d3673ba631c0840740ced --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_refs/nn/__init__.py @@ -0,0 +1 @@ +__all__: list[str] = [] diff --git a/.venv/lib/python3.12/site-packages/torch/_refs/nn/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_refs/nn/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8efafacff928232880f8511db88bc439135eeb8b Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_refs/nn/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_refs/nn/functional/__init__.py b/.venv/lib/python3.12/site-packages/torch/_refs/nn/functional/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7a54ca2c3debcb66eb42a4f7b2c92cd3d1e9eafa --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_refs/nn/functional/__init__.py @@ -0,0 +1,1289 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import math +from functools import wraps +from typing import Callable, Optional, TypeVar, Union +from typing_extensions import Concatenate, ParamSpec + +import torch +import torch._prims as prims +import torch._prims_common as utils +import torch._refs as refs +from torch._decomp import register_decomposition +from torch._prims_common import ( + ELEMENTWISE_TYPE_PROMOTION_KIND, + NumberType, + ShapeType, + TensorLike, + TensorLikeType, +) +from torch._prims_common.wrappers import ( + elementwise_type_promotion_wrapper, + elementwise_unary_scalar_wrapper, + out_wrapper, +) +from torch._refs import _make_inplace + + +__all__ = [ + "alpha_dropout", + "celu", + "celu_", + "channel_shuffle", + "dropout", + "elu", + "elu_", + "gelu", + "glu", + "group_norm", + "hardshrink", + "hardtanh", + "hinge_embedding_loss", + "huber_loss", + "l1_loss", 
+ "layer_norm", + "leaky_relu", + "log_softmax", + "margin_ranking_loss", + "mish", + "mish_", + "mse_loss", + "nll_loss", + "pairwise_distance", + "pdist", + "poisson_nll_loss", + "prelu", + "relu", + "relu6", + "selu", + "selu_", + "smooth_l1_loss", + "softmax", + "softmin", + "softplus", + "softshrink", + "tanhshrink", + "threshold", + "threshold_", + "triplet_margin_loss", +] + +_T = TypeVar("_T") +_P = ParamSpec("_P") + +Tensor = torch.Tensor +aten = torch._ops.ops.aten +DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined] + + +def _dropout_helper( + self: TensorLikeType, + val: float, +) -> TensorLikeType: + """ + Helper function for all dropout-type operators. During training, + some of the elements of the input tensor are randomly masked. + + Returns the masked tensor of the boolean values. + + """ + + return ( + refs._uniform_helper( + self.shape, low=0.0, high=1.0, dtype=torch.float32, device=self.device + ) + < val + ) + + +@register_decomposition(aten.alpha_dropout) +def alpha_dropout( + self: TensorLikeType, p: float = 0.5, training: bool = False, inplace: bool = False +) -> TensorLikeType: + if inplace: + raise NotImplementedError + + if not training: + return self + + torch._check( + p <= 1 and p >= 0, + lambda: f"dropout probability has to be between 0 and 1, but got, {p}", + ) + + if p == 1: + return torch.zeros_like(self) + + if p == 0: + return self + + dropout_mask = _dropout_helper(self, 1 - p) + + # From paper: Self-Normalizing Neural Networks (https://arxiv.org/pdf/1706.02515.pdf) + # alpha = - SELU.alpha * SELU.scale, here + # SELU.alpha = 1.6732632423543772848170429916717 and + # SELU.scale = 1.0507009873554804934193349852946 + alpha = -1.7580993408473766 + + a = 1.0 / math.sqrt((alpha * alpha * p + 1) * (1 - p)) + b = torch.logical_not(dropout_mask) + b = b * (alpha * a) + alpha * a * p + dropout_mask = a * dropout_mask + + return self * dropout_mask + b + + +def _inplace_wrapper(fn: Callable[_P, _T]) -> Callable[_P, _T]: + """ + Given a nn.functional non-linearity, implements its `inplace: bool` argument + """ + + # nb. We use the name of the first argument used in the unary references + @wraps(fn) + def _fn(*args: _P.args, **kwargs: _P.kwargs) -> _T: + a = args[0] + if "inplace" not in kwargs: + kwargs["inplace"] = False + if kwargs["inplace"]: + torch._check( + "out" not in kwargs, + lambda: "Cannot set inplace=True and pass out= at the same time", + ) + kwargs["inplace"] = False + kwargs["out"] = a + return fn(*args, **kwargs) + else: + return fn(*args, **kwargs) + + return _fn + + +# celu is implemented specially because it has an alpha argument +# celu is very similar to elu +@register_decomposition(aten.celu) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def celu( + a: TensorLikeType, alpha: Optional[NumberType] = None, inplace: bool = False +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.celu + """ + + if inplace: + raise NotImplementedError + + rhs: TensorLikeType + if alpha is not None: + python_type = utils.dtype_to_type(a.dtype) + if not utils.is_weakly_lesser_type(type(alpha), python_type): + msg = f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!" 
+ raise ValueError(msg) + rhs = alpha * torch.expm1(torch.true_divide(a, alpha)) # type: ignore[arg-type] + else: + rhs = torch.expm1(a) + + return torch.where(a > 0, a, rhs) + + +@_inplace_wrapper +@out_wrapper() +def dropout( + a: TensorLikeType, p: float = 0.5, training: bool = True, inplace: bool = False +) -> TensorLikeType: + if inplace: + raise NotImplementedError + + if not training: + return a + + torch._check( + p <= 1 and p >= 0, + lambda: f"dropout probability has to be between 0 and 1, but got, {p}", + ) + + if p == 1: + return torch.zeros_like(a) + + if p == 0: + return a + + scale = 1 / (1 - p) + dropout_mask = _dropout_helper(a, 1 - p) + + return a * dropout_mask * scale + + +@register_decomposition(aten.elu) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def elu( + a: TensorLikeType, + alpha: NumberType = 1.0, + scale: NumberType = 1.0, + input_scale: NumberType = 1.0, + inplace: bool = False, +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.elu + """ + if inplace: + raise NotImplementedError + + # nb. This should be factored out into a can_cast aux function + python_type = utils.dtype_to_type(a.dtype) + torch._check( + utils.is_weakly_lesser_type(type(input_scale), python_type), + lambda: f"input_scale argument of type {type(input_scale)} cannot be safely cast to type {python_type}!", + ) + torch._check( + utils.is_weakly_lesser_type(type(scale), python_type), + lambda: f"scale argument of type {type(scale)} cannot be safely cast to type {python_type}!", + ) + torch._check( + utils.is_weakly_lesser_type(type(alpha), python_type), + lambda: f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!", + ) + + return torch.where(a > 0, scale * a, (alpha * scale) * torch.expm1(a * input_scale)) + + +@register_decomposition(aten.relu) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def relu(a: TensorLikeType, inplace: bool = False) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.relu + """ + + if inplace: + raise NotImplementedError + + return torch.where(torch.le(a, 0), 0, a) + + +@register_decomposition(aten.channel_shuffle) +@out_wrapper() +def channel_shuffle(input: TensorLikeType, groups: int) -> TensorLikeType: + """ + Reference implementation of :func:`torch.nn.functional.channel_shuffle`. + """ + from torch._meta_registrations import device_hint + + torch._check( + input.dim() > 2, + lambda: f"channel_shuffle expects input with > 2 dims, but got input with sizes {list(input.size())}", + ) + c = input.shape[1] + torch._check( + groups > 0, + lambda: f"Number of groups to divide channels in must be positive. Value of groups:{groups}", + ) + torch._check( + (c % groups) == 0, + lambda: f"Number of channels must be divisible by groups. 
Got {c} channels and {groups} groups.", + ) + n = input.shape[0] + cg = c // groups + dhw = input.shape[2:] + + if input.numel() == 0 or ( + device_hint(input) == "cuda" and (groups == 1 or groups == c) + ): + return input.view(input.shape) + + return ( + input.reshape(n, groups, cg, *dhw) + .transpose(1, 2) + .reshape(input.shape) + .contiguous() + ) + + +def group_norm( + input: Tensor, + num_groups: int, + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + eps: float = 1e-5, +) -> Tensor: + """ + Reference implementation of :func:`torch.nn.functional.group_norm`. + """ + torch._check( + input.ndim >= 2, + lambda: f"Expected at least 2 dimensions for input tensor but received {input.ndim}", + ) + + batch_size = input.shape[0] + num_channels = input.shape[1] + torch._check( + num_channels % num_groups == 0, + lambda: "Expected number of channels in input to be divisible by num_groups, " + + f"but got input of shape {input.shape} and num_groups = {num_groups}", + ) + + # input shape is (N, C, *), so we flatten all inner dimensions except (N, C) + flattened_inner_size = 1 + for dim_length in input.shape[2:]: + flattened_inner_size *= dim_length + + return torch.native_group_norm( + input, + weight, + bias, + batch_size, + num_channels, + flattened_inner_size, + num_groups, + eps, + )[0] + + +def layer_norm( + input: Tensor, + normalized_shape: ShapeType, + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + eps: float = 1e-5, +) -> Tensor: + """ + Reference implementation of :func:`torch.nn.functional.layer_norm`. + """ + return torch.native_layer_norm(input, normalized_shape, weight, bias, eps)[0] + + +@register_decomposition(aten.leaky_relu) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def leaky_relu( + a: TensorLikeType, negative_slope: float = 0.01, inplace: bool = False +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.leaky_relu + """ + + if inplace: + raise NotImplementedError + + python_type = utils.dtype_to_type(a.dtype) + if not utils.is_weakly_lesser_type(type(negative_slope), python_type): + msg = f"negative_slope argument of type {type(negative_slope)} cannot be safely cast to type {python_type}!" 
+ raise ValueError(msg) + return torch.where(torch.gt(a, 0), a, torch.mul(a, negative_slope)) + + +@register_decomposition(aten.mish) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def mish(a: TensorLikeType, inplace: bool = False) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.mish + """ + + if inplace: + raise NotImplementedError + return a * torch.tanh(torch.nn.functional.softplus(a)) + + +@register_decomposition(aten.selu) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def selu(a: TensorLikeType, inplace: bool = False) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.selu + """ + if inplace: + raise NotImplementedError + + alpha = 1.6732632423543772848170429916717 + scale = 1.0507009873554804934193349852946 + + rhs = alpha * torch.expm1(a) + + return scale * torch.where(a > 0, a, rhs) + + +# Forwarding alias: the functional variant doesn't support the out kwarg +# CompositeImplicitAutograd - don't register decomp +def softmax( + a: TensorLikeType, + dim: Optional[int] = None, + _stacklevel: int = 3, # for compat when using TorchRefsMode(strict=True) + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + # The error is for compat with regular PyTorch, which has this behavior + # deprecated. For PrimTorch, it's fine to drop support for deprecated + # behavior because it requires explicit opt in. This error is to inform + # users how to update their calls. + torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X") + return torch.softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload] + + +# CompositeImplicitAutograd - don't register decomp +def softmin( + a: TensorLikeType, + dim: Optional[int] = None, + _stacklevel: int = 3, # for compat when using TorchRefsMode(strict=True) + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + # The error is for compat with regular PyTorch, which has this behavior + # deprecated. For PrimTorch, it's fine to drop support for deprecated + # behavior because it requires explicit opt in. This error is to inform + # users how to update their calls. + torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X") + return torch.softmax(a=-a, dim=dim, dtype=dtype) # type: ignore[call-overload] + + +# softplus is implemented specially because it has beta and threshold arguments +@register_decomposition(aten.softplus) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def softplus( + a: TensorLikeType, + beta: Optional[NumberType] = None, + threshold: NumberType = 20, + inplace: bool = False, +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.softplus + """ + + if inplace: + raise NotImplementedError + + rhs: TensorLikeType + if beta is not None: + python_type = utils.dtype_to_type(a.dtype) + if not utils.is_weakly_lesser_type(type(beta), python_type): + msg = f"beta argument of type {type(beta)} cannot be safely cast to type {python_type}!" 
+ raise ValueError(msg) + scaled_input = a * beta + rhs = torch.true_divide(torch.log1p(torch.exp(scaled_input)), beta) # type: ignore[arg-type] + + else: + scaled_input = a + rhs = torch.log1p(torch.exp(scaled_input)) + + return torch.where(scaled_input > threshold, a, rhs) + + +@aten.hardshrink.default.py_impl(DispatchKey.Autograd) +@register_decomposition(aten.hardshrink) +@out_wrapper() +def hardshrink(a: TensorLikeType, lambd: float = 0.5): + # Formula for reference, + # hardshrink(x) = x if x > lambd + # = x if x < -lambd + # = 0 otherwise + return torch.where(torch.abs(a) <= lambd, 0, a) + + +@aten.softshrink.default.py_impl(DispatchKey.Autograd) +@register_decomposition(aten.softshrink) +@out_wrapper() +def softshrink(a: TensorLikeType, lambd: float = 0.5): + # Formula for reference, + # softshrink(x) = x - lambd if x > lambd + # = x + lambd if x < -lambd + # = 0 otherwise + torch._check( + lambd >= 0, + lambda: f"lambda must be greater or equal to 0, but found to be {lambd}", + ) + # We implement this in one torch.where to generate better code in the backward + # see https://github.com/pytorch/pytorch/pull/107052#discussion_r1293748211 + # We multiply by 0 for dealing with nans + return torch.where(torch.abs(a) > lambd, a - torch.sign(a) * lambd, a * 0) + + +# Losses +def _reduction_int_to_str(reduction: int) -> str: + from torch._decomp.decompositions import Reduction + + if reduction == Reduction.NONE.value: + return "none" + elif reduction == Reduction.MEAN.value: + return "mean" + elif reduction == Reduction.SUM.value: + return "sum" + else: + raise ValueError(f"{reduction} is not a valid value for reduction") + + +def _apply_loss_reduction(loss: TensorLikeType, reduction: str) -> TensorLikeType: + if reduction == "sum": + return torch.sum(loss) + elif reduction == "mean": + return torch.mean(loss) + else: # reduction == "none" + return loss + + +def _check_reduction_value(reduction: str): + if reduction not in ("mean", "sum", "none"): + raise ValueError(f"{reduction} is not a valid value for reduction") + + +# This helper function maps depreciated arguments, "size_average" and "reduce" +# to their corresponding "reduction" string argument +def _get_string_reduction_arg( + *, size_average: Optional[bool], reduce: Optional[bool] +) -> str: + if size_average is None: + size_average = True + if reduce is None: + reduce = True + if size_average and reduce: + ret = "mean" + elif reduce: + ret = "sum" + else: + ret = "none" + return ret + + +# CompositeImplicitAutograd - don't register decomp +@elementwise_type_promotion_wrapper( + type_promoting_args=("input", "target"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT, +) +def l1_loss( + input: TensorLikeType, + target: TensorLikeType, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.l1_loss + """ + if size_average is not None or reduce is not None: + # TODO: Raise exception instead of converting value. This is only for + # primTorch since it can drop support for deprecated arguments. + # msg = "size_average and reduce args are deprecated, please use reduction argument." 
+ reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce) + _check_reduction_value(reduction) + loss = torch.abs(input - target) + return _apply_loss_reduction(loss, reduction) + + +@elementwise_type_promotion_wrapper( + type_promoting_args=("input", "target"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT, +) +def smooth_l1_loss( + input: TensorLikeType, + target: TensorLikeType, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", + beta: float = 1.0, +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.smooth_l1_loss + """ + if size_average is not None or reduce is not None: + # TODO: Raise exception instead of converting value. This is only for + # primTorch since it can drop support for deprecated arguments. + # msg = "size_average and reduce args are deprecated, please use reduction argument." + reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce) + _check_reduction_value(reduction) + + if beta == 0.0: + return torch.nn.functional.l1_loss( + input, target, size_average=size_average, reduce=reduce, reduction=reduction + ) + else: + loss = torch.abs(input - target) + loss = torch.where(loss < beta, 0.5 * loss**2 / beta, loss - 0.5 * beta) + return _apply_loss_reduction(loss, reduction) + + +# Forwarding alias: the functional variant doesn't support the out kwarg +# CompositeImplicitAutograd - don't register decomp +def log_softmax( + a: TensorLikeType, + dim: Optional[int] = None, + _stacklevel: int = 3, # for compat when using TorchRefsMode(strict=True) + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + # The error is for compat with regular PyTorch, which has this behavior + # deprecated. For PrimTorch, it's fine to drop support for deprecated + # behavior because it requires explicit opt in. This error is to inform + # users how to update their calls. + torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X") + return torch.log_softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload] + + +@register_decomposition(aten.margin_ranking_loss) +def margin_ranking_loss( + input1: TensorLikeType, + input2: TensorLikeType, + target: TensorLikeType, + margin: float = 0.0, + reduction: str = "mean", +) -> TensorLikeType: + # loss_without_reduction = max(0, -target * (input1 - input2) + margin) + if input1.ndim != input2.ndim or input1.ndim != target.ndim: + raise RuntimeError( + "margin_ranking_loss : All input tensors should have same dimension but got sizes: " + f"input1: {input1.shape}, input2: {input2.shape}, target: {target.shape} " + ) + _check_reduction_value(reduction) + loss = torch.clamp_min(-target * (input1 - input2) + margin, 0) + return _apply_loss_reduction(loss, reduction) + + +@elementwise_type_promotion_wrapper( + type_promoting_args=("input", "target"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT, +) +def mse_loss( + input: TensorLikeType, + target: TensorLikeType, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> TensorLikeType: + if size_average is not None or reduce is not None: + # TODO: Raise exception instead of converting value. This is only for + # primTorch since it can drop support for deprecated arguments. + # msg = "size_average and reduce args are deprecated, please use reduction argument." 
+ reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce) + _check_reduction_value(reduction) + loss = torch.pow(input - target, 2) + return _apply_loss_reduction(loss, reduction) + + +@register_decomposition(aten.hinge_embedding_loss) +def hinge_embedding_loss( + input: TensorLikeType, + target: TensorLikeType, + margin: float = 1.0, + reduction: str = "mean", +) -> TensorLikeType: + # loss_without_reduction = input if y == 1 + # = max(0, margin - input) if y == -1 + _check_reduction_value(reduction) + margin_clamp = torch.clamp_min(margin - input, 0) + output_margin = torch.where(target != 1, margin_clamp, 0) + output_self = torch.where(target != -1, input, 0) + loss = output_margin + output_self + return _apply_loss_reduction(loss, reduction) + + +def _nll_loss_nd( + input: TensorLikeType, + target: TensorLikeType, + weight: Optional[TensorLikeType], + reduction: str, + ignore_index: int, +) -> TensorLikeType: + torch._check( + input.ndim > 0 and input.ndim <= 3, + lambda: f"Expected input dimension to be either [1, 2, 3] but received {input.ndim}.", + ) + + torch._check( + (input.ndim == 1) or (input.shape[0] == target.shape[0]), + lambda: f"Expected input batch size {input.shape[0]} to match target batch size {target.shape[0]}.", + ) + + _check_reduction_value(reduction) + + flat_target = torch.flatten(target) + ignore_classes_mask = torch.eq(flat_target, ignore_index) + + # TODO: Enable data-dependent checks with debug mode + # TODO: This check does not work with FakeTensor inputs; See Issue #85834 + # Explicit cast for class_check to bool; See Issue #78071 + """ + from torch._subclasses.fake_tensor import FakeTensor + num_classes = input.shape[1] if input.ndim > 1 else input.shape[0] + valid_classes_mask = torch.logical_and( + (flat_target >= 0), (flat_target < num_classes) + ) + class_check = torch.all(torch.logical_or(ignore_classes_mask, valid_classes_mask)) + torch._check( + isinstance(target, FakeTensor) or bool(class_check.item()), + lambda: "A target class is out-of-bounds and not the ignore index.", + ) + """ + + ignore_class_weight = torch.scalar_tensor(0, dtype=input.dtype, device=input.device) + class_weight = ( + torch.scalar_tensor(1, dtype=input.dtype, device=input.device) + if weight is None + else weight[flat_target] + ) + current_weight = torch.where( + ignore_classes_mask, + ignore_class_weight, + class_weight, + ) + + if input.ndim == 1: + # implicit batch size = 1 + # input (1 batch size, C classes) + loss = -input[target] * current_weight + elif input.ndim == 2: + # input (N batch size, C classes) + batch_size = input.shape[0] + loss = -input[torch.arange(batch_size), target] * current_weight + else: + # 3D case (N batch size, C classe, K dimensions) + # input (N batch size, C classes, K) + batch_size = input.shape[0] + extent = input.shape[2] + numel = batch_size * extent + indices = torch.arange(numel) + bdx = indices // extent + kdx = indices % extent + loss = -input[bdx, flat_target, kdx] * current_weight + loss = torch.reshape(loss, target.shape) + + if reduction == "none": + return loss + elif reduction == "sum": + return torch.sum(loss) + else: + # calculate weighted mean of the loss function + return torch.sum(loss) / torch.sum(current_weight) + + +@register_decomposition(aten.nll_loss) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("input",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def nll_loss( + input: TensorLikeType, + target: TensorLikeType, + weight: 
Optional[TensorLikeType] = None, + size_average: Optional[bool] = None, + ignore_index: int = -100, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.nll_loss + """ + torch._check( + input.ndim > 0, + lambda: f"Expected input tensor to have 1 or more dimensions (got {input.ndim})", + ) + + # TODO: raise exception instead of converting value + # msg = "size_average and reduce args are deprecated, please use reduction argument." + # Convert these options for consistency with the eager mode + if size_average is not None or reduce is not None: + reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce) + + # The expected behavior when the target and input have zero elements: + # reduction = 'none' --- tensor([]) + # reduction = 'sum' --- tensor(0.) + # reduction = 'mean' --- tensor(nan) + # Mean reduction on empty tensors produces NaN. See the discussion in + # https://github.com/pytorch/pytorch/pull/64572#issuecomment-926504162 + if input.numel() == 0 and target.numel() == 0: + if reduction == "none": + return torch.zeros_like(target) + elif reduction == "sum": + return torch.empty_like(target) + else: + return torch.full_like(target, float("nan")) + + # The _nll_loss_nd helper function handles the most common cases. + # ndim == 1 (Single Example) + # => Batch Size: 1, Input: (C), Target: () + # ndim == 2 (k = 1) + # => Batch Size: N, Input: (N, C), Target: (N) + # ndim == 3 (k > 1) + # => Batch Size: N, Input: (N, C, K), Target: (N, K) + if input.ndim <= 3: + return _nll_loss_nd(input, target, weight, reduction, ignore_index) + + # For ndim > 3, we reshape the input and target to 3-D case. + # Input (N batch-size, C classes, k-dimensions) + # Target (N batch-size, k-dimensions) + torch._check( + input.ndim > 0 and target.ndim > 0 and target.shape[1:] == input.shape[2:], + lambda: ( + "Expected input and target to both have ndim > 0 and " + "target.shape[1:] == input.shape[2:], but got " + f"target.shape {target.shape} and input.shape {input.shape}" + ), + ) + + batch_size = input.shape[0] + num_classes = input.shape[1] + out_size = [batch_size] + list(target.shape[1:]) + + input = torch.reshape(input, [batch_size, num_classes, -1]) + target = torch.reshape(target, [batch_size, -1]) + if reduction != "none": + return _nll_loss_nd(input, target, weight, reduction, ignore_index) + else: + result = _nll_loss_nd(input, target, weight, reduction, ignore_index) + # reshape flattened inner-dim to original k-dimensions + return torch.reshape(result, out_size) + + +# TODO: This ref supports int reduction and out kwarg to be compatible with ATen: +# https://github.com/pytorch/pytorch/issues/83931 +# TODO: Could be rewritten to support complex: +# https://github.com/pytorch/pytorch/pull/85041 +@register_decomposition(aten.huber_loss) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("input", "target"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def huber_loss( + input: TensorLikeType, + target: TensorLikeType, + reduction: Union[str, int] = "mean", + delta: float = 1.0, +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.huber_loss + """ + if type(reduction) is int: + reduction = _reduction_int_to_str(reduction) + _check_reduction_value(reduction) # type: ignore[arg-type] + torch._check( + delta > 0, + lambda: "huber_loss does not support non-positive values for delta.", + ) + z = (input - target).abs() + loss = 
torch.where(z < delta, 0.5 * z * z, delta * (z - 0.5 * delta)) + return _apply_loss_reduction(loss, reduction) # type: ignore[arg-type] + + +# tanhshrink does not use _make_elementwise_unary_reference because it does not support out +@elementwise_unary_scalar_wrapper +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def tanhshrink(a: TensorLikeType) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.tanhshrink + """ + if not isinstance(a, TensorLike): + raise RuntimeError( + "Expected a tensor input for an elementwise unary operation!" + ) + return a - torch.tanh(a) + + +@register_decomposition(aten.threshold) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def threshold( + a: TensorLikeType, + threshold: NumberType, + value: Union[bool, int, float], + inplace: bool = False, +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.threshold + """ + + if inplace: + raise NotImplementedError + + return torch.where(a <= threshold, value, a) + + +# CompositeImplicitAutograd - don't register decomp +# No elementwise type promotion - core op doesn't explicitly type promote +def triplet_margin_loss( + anchor: TensorLikeType, + positive: TensorLikeType, + negative: TensorLikeType, + margin: float = 1.0, + p: float = 2, + eps: float = 1e-6, + swap: bool = False, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> TensorLikeType: + if size_average is not None or reduce is not None: + # TODO: Raise exception instead of converting value. This is only for + # primTorch since it can drop support for deprecated arguments. + # msg = "size_average and reduce args are deprecated, please use reduction argument." + reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce) + + if margin <= 0: + raise ValueError(f"margin must be greater than 0, got {margin}") + + # torch.nn.functional.triplet_margin_with_distance_loss has no ref defined + # since it's a pure Python implementation. Use this helper instead. + return _triplet_margin_with_distance_loss( + anchor=anchor, + positive=positive, + negative=negative, + distance_function=lambda x, y: torch.pairwise_distance(x, y, p, eps), + margin=margin, + swap=swap, + reduction=reduction, + ) + + +# Pure Python impl - don't register decomp and don't add a ref. Defined as a +# helper here since triplet_margin_loss can be nicely implemented with it. 
+def _triplet_margin_with_distance_loss( + anchor: TensorLikeType, + positive: TensorLikeType, + negative: TensorLikeType, + *, + distance_function: Optional[ + Callable[[TensorLikeType, TensorLikeType], TensorLikeType] + ] = None, + margin: float = 1.0, + swap: bool = False, + reduction: str = "mean", +) -> TensorLikeType: + _check_reduction_value(reduction) + + a_dim = anchor.ndim + p_dim = positive.ndim + n_dim = negative.ndim + torch._check( + a_dim == p_dim and p_dim == n_dim, + lambda: ( + f"The anchor, positive, and negative tensors are expected to have " + f"the same number of dimensions, but got: anchor {a_dim}D, " + f"positive {p_dim}D, and negative {n_dim}D inputs" + ), + ) + + if distance_function is None: + distance_function = torch.pairwise_distance + + dist_pos = distance_function(anchor, positive) + dist_neg = distance_function(anchor, negative) + # The distance swap is described in the paper "Learning shallow + # convolutional feature descriptors with triplet losses" by V. Balntas, E. + # Riba et al. If True, and if the positive example is closer to the + # negative example than the anchor is, swaps the positive example and the + # anchor in the loss computation. + if swap: + dist_swap = distance_function(positive, negative) + dist_neg = torch.minimum(dist_neg, dist_swap) + loss = torch.clamp_min(margin + dist_pos - dist_neg, 0) + return _apply_loss_reduction(loss, reduction) + + +@register_decomposition(aten.hardtanh) +@_inplace_wrapper +@out_wrapper() +@elementwise_unary_scalar_wrapper +@elementwise_type_promotion_wrapper( + type_promoting_args=("a"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def hardtanh( + a: TensorLikeType, + min_val: NumberType = -1, + max_val: NumberType = 1, + inplace: bool = False, +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.hardtanh + """ + if inplace: + raise NotImplementedError + if utils.is_boolean_dtype(a.dtype): + raise RuntimeError("Bool inputs not supported for hardtanh") + + # preserve legacy behavior of boundaries not causing type promotion + if utils.is_integer_dtype(a.dtype): + min_val = int(min_val) # type: ignore[arg-type] + max_val = int(max_val) # type: ignore[arg-type] + if not (a.dtype != torch.uint8 or (min_val >= 0 and max_val >= 0)): + raise RuntimeError( + "Cannot do hardtanh on an unsigned type with negative limits" + ) + + if min_val > max_val: # type: ignore[operator] + raise ValueError("min_val cannot be greater than max_val") + + return torch.clamp(a, min_val, max_val) # type: ignore[arg-type] + + +@register_decomposition(aten.gelu) +@out_wrapper() +@elementwise_unary_scalar_wrapper +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def gelu(a: TensorLikeType, approximate: str = "none") -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.gelu + """ + if not isinstance(a, TensorLike): + raise RuntimeError( + "Expected a tensor input for an elementwise unary operation!" 
+ ) + M_SQRT2 = 1.41421356237309504880 + M_SQRT1_2 = 0.70710678118654752440 + M_2_SQRTPI = 1.12837916709551257390 + if approximate == "tanh": + kBeta = M_SQRT2 * M_2_SQRTPI * 0.5 + kKappa = 0.044715 + a_cube = a * a * a + inner = kBeta * (a + kKappa * a_cube) + return 0.5 * a * (1 + torch.tanh(inner)) + elif approximate == "none": + kAlpha = M_SQRT1_2 + return a * 0.5 * (1 + torch.erf(a * kAlpha)) + else: + raise RuntimeError("approximate argument must be either none or tanh.") + + +# CompositeImplicitAutograd - don't register decomp +@elementwise_type_promotion_wrapper( + type_promoting_args=("input", "target"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def poisson_nll_loss( + input: TensorLikeType, + target: TensorLikeType, + log_input: bool = True, + full: bool = False, + size_average: Optional[bool] = None, + eps: float = 1e-8, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.poisson_nll_loss + """ + if size_average is not None or reduce is not None: + # TODO: Raise exception instead of converting value. This is only for + # primTorch since it can drop support for deprecated arguments. + # msg = "size_average and reduce args are deprecated, please use reduction argument." + reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce) + _check_reduction_value(reduction) + if log_input: + loss = torch.exp(input) - target * input + else: + loss = input - target * torch.log(input + eps) + + if full: + stirling_term = ( + target * torch.log(target) - target + 0.5 * torch.log(2 * torch.pi * target) + ) + # avoid inplace add + loss = loss + stirling_term.masked_fill(target <= 1, 0) + return _apply_loss_reduction(loss, reduction) + + +@register_decomposition(aten.prelu) +@elementwise_type_promotion_wrapper( + type_promoting_args=("a", "weight"), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def prelu(a: TensorLikeType, weight: TensorLikeType) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.prelu + """ + torch._check( + isinstance(a, TensorLike), + lambda: f"prelu: Expected `a` to be tensor, but got: {type(a)}", + ) + torch._check( + isinstance(weight, TensorLike), + lambda: f"prelu: Expected `weight` to be tensor, but got: {type(weight)}", + ) + + if weight.numel() != 1: + torch._check(a.ndim > 0, lambda: "Not allow zero-dim input tensor.") + channel_size = a.shape[1] if a.ndim >= 2 else 1 + torch._check( + weight.numel() == channel_size, + lambda: f"Mismatch of parameter numbers and input channel size. 
Found parameter numbers =" + f" {weight.numel()} and channel size = {channel_size}.", + ) + + torch._check( + weight.ndim == 0 or weight.ndim == 1, + lambda: f"prelu: Expected `weight` to be a scalar or 1D tensor, but got: " + f"ndim = {weight.ndim}", + ) + if a.ndim == 0: + weight = weight[0] if weight.ndim == 1 else weight + else: + weight = prims.broadcast_in_dim( + weight, a.shape, () if weight.ndim == 0 else (0 if a.ndim == 1 else 1,) + ) + + return torch.where(a > 0, a, a * weight) + + +@register_decomposition(aten.relu6) +@_inplace_wrapper +@out_wrapper() +def relu6(a: TensorLikeType, inplace: bool = False) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.relu6 + """ + if inplace: + raise NotImplementedError + + # See https://github.com/pytorch/pytorch/pull/81142#discussion_r918220126 + # It may be better to use clamp here, but we use hardtanh to replicate + # the behavior of the existing implementation + return torch.nn.functional.hardtanh(a, 0, 6) + + +@register_decomposition(aten.glu) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def glu(a: TensorLikeType, dim: int = -1) -> TensorLikeType: + dim = utils.canonicalize_dims(a.ndim, dim) + torch._check( + a.shape[dim] % 2 == 0, + lambda: f"Halving dimension must be even, but dimension {dim} is size {a.shape[dim]}", + ) + b, c = torch.tensor_split(a, 2, dim) + + return b * torch.sigmoid(c) + + +@register_decomposition(aten.pairwise_distance) +@out_wrapper() +def pairwise_distance( + x1: TensorLikeType, + x2: TensorLikeType, + p: NumberType = 2.0, + eps: NumberType = 1e-6, + keepdim=False, +) -> TensorLikeType: + return torch.linalg.vector_norm(x1 - x2 + eps, ord=p, dim=-1, keepdim=keepdim) + + +@register_decomposition(aten.pdist) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def pdist(a: TensorLikeType, p: float = 2) -> TensorLikeType: + torch._check(a.ndim == 2, lambda: f"pdist only supports 2D tensors, got: {a.ndim}D") + torch._check(p >= 0, lambda: "pdist only supports non-negative p values") + # For p == 2 we can use an efficient implementation, but other values of p + # require creating a much bigger tensor for an intermediate step + if p == 2: + aTa = torch.mm(a, a.T) + aTa_diag = torch.diag(aTa) + t = torch.sqrt(torch.clamp(aTa_diag + aTa_diag.unsqueeze(-1) - 2 * aTa, min=0)) + else: + t = torch.linalg.vector_norm(a.unsqueeze(1) - a, ord=p, dim=2) + i = torch.triu_indices(t.shape[0], t.shape[1], offset=1, device=a.device) + return t.flatten().index_select(0, i[0] * t.shape[0] + i[1]) + + +@register_decomposition(aten.pixel_shuffle) +@out_wrapper() +def pixel_shuffle(self: Tensor, upscale_factor: int): + torch._check( + self.dim() >= 3, + lambda: f"pixel_shuffle expects input to have at least 3 dimensions, but got input with {self.dim} dimension(s)", + ) + batch = self.shape[:-3] + C_out = self.shape[-3] // upscale_factor**2 + HW_out = (self.shape[-2] * upscale_factor, self.shape[-1] * upscale_factor) + n = len(batch) + B_dims = range(n) + C_dim, r1_dim, r2_dim, H_dim, W_dim = range(n, n + 5) + return ( + self.view( + *batch, + C_out, + upscale_factor, + upscale_factor, + self.shape[-2], + self.shape[-1], + ) + .permute(*B_dims, C_dim, H_dim, r1_dim, W_dim, r2_dim) + .reshape(*batch, C_out, *HW_out) + .clone(memory_format=utils.suggest_memory_format(self)) + ) + + 
+@register_decomposition(aten.pixel_unshuffle) +@out_wrapper() +def pixel_unshuffle(self: Tensor, downscale_factor: int): + torch._check( + self.dim() >= 3, + lambda: f"pixel_unshuffle expects input to have at least 3 dimensions, but got input with {self.dim} dimension(s)", + ) + batch = self.shape[:-3] + C_out = self.shape[-3] * downscale_factor**2 + HW_out = (self.shape[-2] // downscale_factor, self.shape[-1] // downscale_factor) + n = len(batch) + B_dims = range(n) + C_dim, H_dim, r1_dim, W_dim, r2_dim = range(n, n + 5) + return ( + self.view( + *batch, + self.shape[-3], + HW_out[0], + downscale_factor, + HW_out[1], + downscale_factor, + ) + .permute(*B_dims, C_dim, r1_dim, r2_dim, H_dim, W_dim) + .reshape(*batch, C_out, *HW_out) + .clone(memory_format=utils.suggest_memory_format(self)) + ) + + +# Needed as aten.{celu_,elu_...} exist (even if they don't have the in-place kwarg) +celu_ = _make_inplace(celu) +elu_ = _make_inplace(elu) +mish_ = _make_inplace(mish) +selu_ = _make_inplace(selu) +threshold_ = _make_inplace(threshold) diff --git a/.venv/lib/python3.12/site-packages/torch/_refs/nn/functional/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_refs/nn/functional/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0aa67dce08698cefe5dfe513e48bc683b58abdbd Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_refs/nn/functional/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_refs/special/__init__.py b/.venv/lib/python3.12/site-packages/torch/_refs/special/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..de11bee923c9a4510f4e6cef45949e29236bc8d4 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_refs/special/__init__.py @@ -0,0 +1,236 @@ +# mypy: allow-untyped-defs +import math +from typing import Optional, Union + +import torch +import torch._prims as prims +import torch._prims_common as utils +import torch._refs as refs +from torch import Tensor +from torch._decomp import register_decomposition +from torch._prims_common import ( + ELEMENTWISE_TYPE_PROMOTION_KIND, + Number, + NumberType, + TensorLike, + TensorLikeType, +) +from torch._prims_common.wrappers import elementwise_type_promotion_wrapper, out_wrapper +from torch._refs import ( + _make_alias, + _make_elementwise_binary_reference, + _make_elementwise_unary_reference, +) + + +__all__ = [ + "bessel_j0", + "bessel_j1", + "entr", + "erfcx", + "expit", + "i0e", + "i1", + "i1e", + "log_ndtr", + "logit", + "log_softmax", + "multigammaln", + "ndtr", + "ndtri", + "softmax", + "spherical_bessel_j0", + "xlog1py", + "zeta", +] +aten = torch._ops.ops.aten + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def bessel_j0(a: TensorLikeType) -> TensorLikeType: + return prims.bessel_j0(a) + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def bessel_j1(a: TensorLikeType) -> TensorLikeType: + return prims.bessel_j1(a) + + +@register_decomposition(aten.special_entr) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def entr(a: TensorLikeType) -> TensorLikeType: + return torch.where( + torch.isnan(a), + a, + torch.where(a > 0, -a * torch.log(a), torch.where(a == 0, 0, -torch.inf)), + ) + + +@register_decomposition(aten.special_erfcx) +@out_wrapper() 
+@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def erfcx(a: TensorLikeType) -> TensorLikeType: + return prims.erfcx(a) + + +# alias for sigmoid +expit = _make_alias(torch.sigmoid, "expit") + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def i0e(a: TensorLikeType) -> TensorLikeType: + return prims.bessel_i0e(a) + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def i1(a: TensorLikeType) -> TensorLikeType: + return prims.bessel_i1(a) + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def i1e(a: TensorLikeType) -> TensorLikeType: + return prims.bessel_i1e(a) + + +@register_decomposition(aten.special_log_ndtr) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def log_ndtr(a: TensorLikeType) -> TensorLikeType: + # Note: M_SQRT1_2 is the value of 1 / sqrt(2) + M_SQRT1_2 = 0.707106781186547524400844362104849039 + t = a * M_SQRT1_2 + return torch.where( + a < 1.0, + torch.log(torch.special.erfcx(-t) / 2) - t * t, + torch.log1p(-torch.erfc(t) / 2), + ) + + +@register_decomposition(aten.logit) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("self",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def logit(self: TensorLikeType, eps: Optional[float] = None) -> TensorLikeType: + if eps is None: + eps = -1.0 + lo = eps + hi = 1 - eps + self = torch.where(self < lo, lo, torch.where(self > hi, hi, self)) + return torch.log(torch.true_divide(self, torch.sub(1, self))) + + +@register_decomposition(aten.special_xlog1py) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a", "b"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def xlog1py(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]): + torch._check( + isinstance(a, TensorLike) or isinstance(b, TensorLike), + lambda: 'Expected either argument a or b to be a Tensor"', + ) + + # Operations like eq and log do not handle scalar values, so we convert them to scalar_tensors. 
+ if isinstance(a, TensorLike) and isinstance(b, Number): + b = refs.scalar_tensor(b, dtype=a.dtype, device=a.device) + elif isinstance(b, TensorLike) and isinstance(a, Number): + a = refs.scalar_tensor(a, dtype=b.dtype, device=b.device) + + # mypy: expected "Tensor" + assert isinstance(a, TensorLike) + assert isinstance(b, TensorLike) + rhs = torch.where(torch.eq(a, 0), 0, torch.mul(a, torch.log1p(b))) + return torch.where(torch.isnan(b), float("nan"), rhs) + + +@register_decomposition(aten.mvlgamma) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def multigammaln(a: TensorLikeType, p: int) -> TensorLikeType: + c = 0.25 * p * (p - 1) * math.log(math.pi) + b = 0.5 * torch.arange(start=(1 - p), end=1, step=1, dtype=a.dtype, device=a.device) + return torch.sum(torch.lgamma(a.unsqueeze(-1) + b), dim=-1) + c + + +@register_decomposition(aten.special_ndtr) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def ndtr(a: TensorLikeType) -> TensorLikeType: + # Note: M_SQRT1_2 is the value of 1 / sqrt(2) + M_SQRT1_2 = 0.707106781186547524400844362104849039 + a_sqrt_2 = a * M_SQRT1_2 + return (1 + torch.erf(a_sqrt_2)) * 0.5 + + +@register_decomposition(aten.special_ndtri) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def ndtri(a: TensorLikeType) -> TensorLikeType: + return prims.ndtri(a) + + +# Forwarding alias: the special variant doesn't support the out kwarg +# CompositeImplicitAutograd - don't register decomp +def log_softmax( + a: TensorLikeType, + dim: int, + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + return torch.log_softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload] + + +# Forwarding alias: the special variant doesn't support the out kwarg +# CompositeImplicitAutograd - don't register decomp +def softmax( + a: TensorLikeType, + dim: int, + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + return torch.softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload] + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def spherical_bessel_j0(a: TensorLikeType) -> TensorLikeType: + return prims.spherical_bessel_j0(a) + + +# TODO: add docstring +@_make_elementwise_binary_reference( + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def zeta(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.zeta(a, b) diff --git a/.venv/lib/python3.12/site-packages/torch/_refs/special/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_refs/special/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..393bbbd521ce7eb37beaa526a7c75ed4e4809ce3 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_refs/special/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_subclasses/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_subclasses/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b20aeea14dc51a8690373334ed9b0b6a83a01c2b Binary files /dev/null and 
b/.venv/lib/python3.12/site-packages/torch/_subclasses/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_subclasses/__pycache__/_fake_tensor_utils.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_subclasses/__pycache__/_fake_tensor_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8b36a892b25d05521dcedfc218648f9fab64dda Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_subclasses/__pycache__/_fake_tensor_utils.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_subclasses/__pycache__/fake_impls.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_subclasses/__pycache__/fake_impls.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..843ec56f72e7af2e32b945838b4bc484eb427eb9 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_subclasses/__pycache__/fake_impls.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_subclasses/__pycache__/fake_utils.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_subclasses/__pycache__/fake_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e0e3a9aa9b2309243bd0b58ae7269a18bbd16609 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_subclasses/__pycache__/fake_utils.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_subclasses/__pycache__/functional_tensor.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_subclasses/__pycache__/functional_tensor.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6e556bd475d0c2c3b9fcd5c232ba7f12418f113 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_subclasses/__pycache__/functional_tensor.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_subclasses/__pycache__/meta_utils.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_subclasses/__pycache__/meta_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b3b850a91592bf3ac731febbd0737ab8f2f0462 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_subclasses/__pycache__/meta_utils.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_subclasses/__pycache__/schema_check_mode.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_subclasses/__pycache__/schema_check_mode.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f385517e494afadf541bf6282cad85df0a1913f Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_subclasses/__pycache__/schema_check_mode.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_vendor/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_vendor/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fea40870db211af67c00f3fbcf8f674861bd1d90 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_vendor/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_vendor/packaging/__init__.py b/.venv/lib/python3.12/site-packages/torch/_vendor/packaging/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..22809cfd5dc25792d77070c269fc8d111a12eed0 --- /dev/null +++ 
b/.venv/lib/python3.12/site-packages/torch/_vendor/packaging/__init__.py @@ -0,0 +1,15 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +__title__ = "packaging" +__summary__ = "Core utilities for Python packages" +__uri__ = "https://github.com/pypa/packaging" + +__version__ = "23.2" + +__author__ = "Donald Stufft and individual contributors" +__email__ = "donald@stufft.io" + +__license__ = "BSD-2-Clause or Apache-2.0" +__copyright__ = "2014 %s" % __author__ diff --git a/.venv/lib/python3.12/site-packages/torch/_vendor/packaging/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_vendor/packaging/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2d857d1245bd07564e6036f32c1c22457ef7ad3 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_vendor/packaging/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_vendor/packaging/__pycache__/_structures.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_vendor/packaging/__pycache__/_structures.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad130f15b8da9252d70b85b130faf7c7bffbd406 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_vendor/packaging/__pycache__/_structures.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_vendor/packaging/__pycache__/version.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_vendor/packaging/__pycache__/version.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c1e7e6a73f3c42d7d03e657ff62c3d1cfb859a2 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_vendor/packaging/__pycache__/version.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_vendor/packaging/_structures.py b/.venv/lib/python3.12/site-packages/torch/_vendor/packaging/_structures.py new file mode 100644 index 0000000000000000000000000000000000000000..90a6465f9682c886363eea5327dac64bf623a6ff --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_vendor/packaging/_structures.py @@ -0,0 +1,61 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. 
+ + +class InfinityType: + def __repr__(self) -> str: + return "Infinity" + + def __hash__(self) -> int: + return hash(repr(self)) + + def __lt__(self, other: object) -> bool: + return False + + def __le__(self, other: object) -> bool: + return False + + def __eq__(self, other: object) -> bool: + return isinstance(other, self.__class__) + + def __gt__(self, other: object) -> bool: + return True + + def __ge__(self, other: object) -> bool: + return True + + def __neg__(self: object) -> "NegativeInfinityType": + return NegativeInfinity + + +Infinity = InfinityType() + + +class NegativeInfinityType: + def __repr__(self) -> str: + return "-Infinity" + + def __hash__(self) -> int: + return hash(repr(self)) + + def __lt__(self, other: object) -> bool: + return True + + def __le__(self, other: object) -> bool: + return True + + def __eq__(self, other: object) -> bool: + return isinstance(other, self.__class__) + + def __gt__(self, other: object) -> bool: + return False + + def __ge__(self, other: object) -> bool: + return False + + def __neg__(self: object) -> InfinityType: + return Infinity + + +NegativeInfinity = NegativeInfinityType() diff --git a/.venv/lib/python3.12/site-packages/torch/_vendor/packaging/version.py b/.venv/lib/python3.12/site-packages/torch/_vendor/packaging/version.py new file mode 100644 index 0000000000000000000000000000000000000000..5faab9bd0dcf28847960162b2b4f13a8a556ef20 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_vendor/packaging/version.py @@ -0,0 +1,563 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. +""" +.. testsetup:: + + from packaging.version import parse, Version +""" + +import itertools +import re +from typing import Any, Callable, NamedTuple, Optional, SupportsInt, Tuple, Union + +from ._structures import Infinity, InfinityType, NegativeInfinity, NegativeInfinityType + +__all__ = ["VERSION_PATTERN", "parse", "Version", "InvalidVersion"] + +LocalType = Tuple[Union[int, str], ...] + +CmpPrePostDevType = Union[InfinityType, NegativeInfinityType, Tuple[str, int]] +CmpLocalType = Union[ + NegativeInfinityType, + Tuple[Union[Tuple[int, str], Tuple[NegativeInfinityType, Union[int, str]]], ...], +] +CmpKey = Tuple[ + int, + Tuple[int, ...], + CmpPrePostDevType, + CmpPrePostDevType, + CmpPrePostDevType, + CmpLocalType, +] +VersionComparisonMethod = Callable[[CmpKey, CmpKey], bool] + + +class _Version(NamedTuple): + epoch: int + release: Tuple[int, ...] + dev: Optional[Tuple[str, int]] + pre: Optional[Tuple[str, int]] + post: Optional[Tuple[str, int]] + local: Optional[LocalType] + + +def parse(version: str) -> "Version": + """Parse the given version string. + + >>> parse('1.0.dev1') + + + :param version: The version string to parse. + :raises InvalidVersion: When the version string is not a valid version. + """ + return Version(version) + + +class InvalidVersion(ValueError): + """Raised when a version string is not a valid version. + + >>> Version("invalid") + Traceback (most recent call last): + ... + packaging.version.InvalidVersion: Invalid version: 'invalid' + """ + + +class _BaseVersion: + _key: Tuple[Any, ...] + + def __hash__(self) -> int: + return hash(self._key) + + # Please keep the duplicated `isinstance` check + # in the six comparisons hereunder + # unless you find a way to avoid adding overhead function calls. 
+    def __lt__(self, other: "_BaseVersion") -> bool:
+        if not isinstance(other, _BaseVersion):
+            return NotImplemented
+
+        return self._key < other._key
+
+    def __le__(self, other: "_BaseVersion") -> bool:
+        if not isinstance(other, _BaseVersion):
+            return NotImplemented
+
+        return self._key <= other._key
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, _BaseVersion):
+            return NotImplemented
+
+        return self._key == other._key
+
+    def __ge__(self, other: "_BaseVersion") -> bool:
+        if not isinstance(other, _BaseVersion):
+            return NotImplemented
+
+        return self._key >= other._key
+
+    def __gt__(self, other: "_BaseVersion") -> bool:
+        if not isinstance(other, _BaseVersion):
+            return NotImplemented
+
+        return self._key > other._key
+
+    def __ne__(self, other: object) -> bool:
+        if not isinstance(other, _BaseVersion):
+            return NotImplemented
+
+        return self._key != other._key
+
+
+# Deliberately not anchored to the start and end of the string, to make it
+# easier for 3rd party code to reuse
+_VERSION_PATTERN = r"""
+    v?
+    (?:
+        (?:(?P<epoch>[0-9]+)!)?                           # epoch
+        (?P<release>[0-9]+(?:\.[0-9]+)*)                  # release segment
+        (?P<pre>                                          # pre-release
+            [-_\.]?
+            (?P<pre_l>alpha|a|beta|b|preview|pre|c|rc)
+            [-_\.]?
+            (?P<pre_n>[0-9]+)?
+        )?
+        (?P<post>                                         # post release
+            (?:-(?P<post_n1>[0-9]+))
+            |
+            (?:
+                [-_\.]?
+                (?P<post_l>post|rev|r)
+                [-_\.]?
+                (?P<post_n2>[0-9]+)?
+            )
+        )?
+        (?P<dev>                                          # dev release
+            [-_\.]?
+            (?P<dev_l>dev)
+            [-_\.]?
+            (?P<dev_n>[0-9]+)?
+        )?
+    )
+    (?:\+(?P<local>[a-z0-9]+(?:[-_\.][a-z0-9]+)*))?       # local version
+"""
+
+VERSION_PATTERN = _VERSION_PATTERN
+"""
+A string containing the regular expression used to match a valid version.
+
+The pattern is not anchored at either end, and is intended for embedding in larger
+expressions (for example, matching a version number as part of a file name). The
+regular expression should be compiled with the ``re.VERBOSE`` and ``re.IGNORECASE``
+flags set.
+
+:meta hide-value:
+"""
+
+
+class Version(_BaseVersion):
+    """This class abstracts handling of a project's versions.
+
+    A :class:`Version` instance is comparison aware and can be compared and
+    sorted using the standard Python interfaces.
+
+    >>> v1 = Version("1.0a5")
+    >>> v2 = Version("1.0")
+    >>> v1
+    <Version('1.0a5')>
+    >>> v2
+    <Version('1.0')>
+    >>> v1 < v2
+    True
+    >>> v1 == v2
+    False
+    >>> v1 > v2
+    False
+    >>> v1 >= v2
+    False
+    >>> v1 <= v2
+    True
+    """
+
+    _regex = re.compile(r"^\s*" + VERSION_PATTERN + r"\s*$", re.VERBOSE | re.IGNORECASE)
+    _key: CmpKey
+
+    def __init__(self, version: str) -> None:
+        """Initialize a Version object.
+
+        :param version:
+            The string representation of a version which will be parsed and normalized
+            before use.
+        :raises InvalidVersion:
+            If the ``version`` does not conform to PEP 440 in any way then this
+            exception will be raised.
+        """
+
+        # Validate the version and parse it into pieces
+        match = self._regex.search(version)
+        if not match:
+            raise InvalidVersion(f"Invalid version: '{version}'")
+
+        # Store the parsed out pieces of the version
+        self._version = _Version(
+            epoch=int(match.group("epoch")) if match.group("epoch") else 0,
+            release=tuple(int(i) for i in match.group("release").split(".")),
+            pre=_parse_letter_version(match.group("pre_l"), match.group("pre_n")),
+            post=_parse_letter_version(
+                match.group("post_l"), match.group("post_n1") or match.group("post_n2")
+            ),
+            dev=_parse_letter_version(match.group("dev_l"), match.group("dev_n")),
+            local=_parse_local_version(match.group("local")),
+        )
+
+        # Generate a key which will be used for sorting
+        self._key = _cmpkey(
+            self._version.epoch,
+            self._version.release,
+            self._version.pre,
+            self._version.post,
+            self._version.dev,
+            self._version.local,
+        )
+
+    def __repr__(self) -> str:
+        """A representation of the Version that shows all internal state.
+
+        >>> Version('1.0.0')
+        <Version('1.0.0')>
+        """
+        return f"<Version('{self}')>"
+
+    def __str__(self) -> str:
+        """A string representation of the version that can be rounded-tripped.
+
+        >>> str(Version("1.0a5"))
+        '1.0a5'
+        """
+        parts = []
+
+        # Epoch
+        if self.epoch != 0:
+            parts.append(f"{self.epoch}!")
+
+        # Release segment
+        parts.append(".".join(str(x) for x in self.release))
+
+        # Pre-release
+        if self.pre is not None:
+            parts.append("".join(str(x) for x in self.pre))
+
+        # Post-release
+        if self.post is not None:
+            parts.append(f".post{self.post}")
+
+        # Development release
+        if self.dev is not None:
+            parts.append(f".dev{self.dev}")
+
+        # Local version segment
+        if self.local is not None:
+            parts.append(f"+{self.local}")
+
+        return "".join(parts)
+
+    @property
+    def epoch(self) -> int:
+        """The epoch of the version.
+
+        >>> Version("2.0.0").epoch
+        0
+        >>> Version("1!2.0.0").epoch
+        1
+        """
+        return self._version.epoch
+
+    @property
+    def release(self) -> Tuple[int, ...]:
+        """The components of the "release" segment of the version.
+
+        >>> Version("1.2.3").release
+        (1, 2, 3)
+        >>> Version("2.0.0").release
+        (2, 0, 0)
+        >>> Version("1!2.0.0.post0").release
+        (2, 0, 0)
+
+        Includes trailing zeroes but not the epoch or any pre-release / development /
+        post-release suffixes.
+        """
+        return self._version.release
+
+    @property
+    def pre(self) -> Optional[Tuple[str, int]]:
+        """The pre-release segment of the version.
+
+        >>> print(Version("1.2.3").pre)
+        None
+        >>> Version("1.2.3a1").pre
+        ('a', 1)
+        >>> Version("1.2.3b1").pre
+        ('b', 1)
+        >>> Version("1.2.3rc1").pre
+        ('rc', 1)
+        """
+        return self._version.pre
+
+    @property
+    def post(self) -> Optional[int]:
+        """The post-release number of the version.
+
+        >>> print(Version("1.2.3").post)
+        None
+        >>> Version("1.2.3.post1").post
+        1
+        """
+        return self._version.post[1] if self._version.post else None
+
+    @property
+    def dev(self) -> Optional[int]:
+        """The development number of the version.
+
+        >>> print(Version("1.2.3").dev)
+        None
+        >>> Version("1.2.3.dev1").dev
+        1
+        """
+        return self._version.dev[1] if self._version.dev else None
+
+    @property
+    def local(self) -> Optional[str]:
+        """The local version segment of the version.
+
+        >>> print(Version("1.2.3").local)
+        None
+        >>> Version("1.2.3+abc").local
+        'abc'
+        """
+        if self._version.local:
+            return ".".join(str(x) for x in self._version.local)
+        else:
+            return None
+
+    @property
+    def public(self) -> str:
+        """The public portion of the version.
+
+        >>> Version("1.2.3").public
+        '1.2.3'
+        >>> Version("1.2.3+abc").public
+        '1.2.3'
+        >>> Version("1.2.3+abc.dev1").public
+        '1.2.3'
+        """
+        return str(self).split("+", 1)[0]
+
+    @property
+    def base_version(self) -> str:
+        """The "base version" of the version.
+
+        >>> Version("1.2.3").base_version
+        '1.2.3'
+        >>> Version("1.2.3+abc").base_version
+        '1.2.3'
+        >>> Version("1!1.2.3+abc.dev1").base_version
+        '1!1.2.3'
+
+        The "base version" is the public version of the project without any pre or post
+        release markers.
+        """
+        parts = []
+
+        # Epoch
+        if self.epoch != 0:
+            parts.append(f"{self.epoch}!")
+
+        # Release segment
+        parts.append(".".join(str(x) for x in self.release))
+
+        return "".join(parts)
+
+    @property
+    def is_prerelease(self) -> bool:
+        """Whether this version is a pre-release.
+
+        >>> Version("1.2.3").is_prerelease
+        False
+        >>> Version("1.2.3a1").is_prerelease
+        True
+        >>> Version("1.2.3b1").is_prerelease
+        True
+        >>> Version("1.2.3rc1").is_prerelease
+        True
+        >>> Version("1.2.3dev1").is_prerelease
+        True
+        """
+        return self.dev is not None or self.pre is not None
+
+    @property
+    def is_postrelease(self) -> bool:
+        """Whether this version is a post-release.
+
+        >>> Version("1.2.3").is_postrelease
+        False
+        >>> Version("1.2.3.post1").is_postrelease
+        True
+        """
+        return self.post is not None
+
+    @property
+    def is_devrelease(self) -> bool:
+        """Whether this version is a development release.
+
+        >>> Version("1.2.3").is_devrelease
+        False
+        >>> Version("1.2.3.dev1").is_devrelease
+        True
+        """
+        return self.dev is not None
+
+    @property
+    def major(self) -> int:
+        """The first item of :attr:`release` or ``0`` if unavailable.
+
+        >>> Version("1.2.3").major
+        1
+        """
+        return self.release[0] if len(self.release) >= 1 else 0
+
+    @property
+    def minor(self) -> int:
+        """The second item of :attr:`release` or ``0`` if unavailable.
+
+        >>> Version("1.2.3").minor
+        2
+        >>> Version("1").minor
+        0
+        """
+        return self.release[1] if len(self.release) >= 2 else 0
+
+    @property
+    def micro(self) -> int:
+        """The third item of :attr:`release` or ``0`` if unavailable.
+
+        >>> Version("1.2.3").micro
+        3
+        >>> Version("1").micro
+        0
+        """
+        return self.release[2] if len(self.release) >= 3 else 0
+
+
+def _parse_letter_version(
+    letter: Optional[str], number: Union[str, bytes, SupportsInt, None]
+) -> Optional[Tuple[str, int]]:
+
+    if letter:
+        # We consider there to be an implicit 0 in a pre-release if there is
+        # not a numeral associated with it.
+        if number is None:
+            number = 0
+
+        # We normalize any letters to their lower case form
+        letter = letter.lower()
+
+        # We consider some words to be alternate spellings of other words and
+        # in those cases we want to normalize the spellings to our preferred
+        # spelling.
+        if letter == "alpha":
+            letter = "a"
+        elif letter == "beta":
+            letter = "b"
+        elif letter in ["c", "pre", "preview"]:
+            letter = "rc"
+        elif letter in ["rev", "r"]:
+            letter = "post"
+
+        return letter, int(number)
+    if not letter and number:
+        # We assume if we are given a number, but we are not given a letter
+        # then this is using the implicit post release syntax (e.g. 1.0-1)
+        letter = "post"
+
+        return letter, int(number)
+
+    return None
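+
+# Added illustrative note (not part of the upstream file): _parse_letter_version
+# normalizes spellings and fills in implicit numbers, e.g.
+#   _parse_letter_version("alpha", None) -> ("a", 0)
+#   _parse_letter_version("rev", "2")    -> ("post", 2)
+#   _parse_letter_version(None, "1")     -> ("post", 1)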
+
+
+_local_version_separators = re.compile(r"[\._-]")
+
+
+def _parse_local_version(local: Optional[str]) -> Optional[LocalType]:
+    """
+    Takes a string like abc.1.twelve and turns it into ("abc", 1, "twelve").
+    """
+    if local is not None:
+        return tuple(
+            part.lower() if not part.isdigit() else int(part)
+            for part in _local_version_separators.split(local)
+        )
+    return None
+
+
+def _cmpkey(
+    epoch: int,
+    release: Tuple[int, ...],
+    pre: Optional[Tuple[str, int]],
+    post: Optional[Tuple[str, int]],
+    dev: Optional[Tuple[str, int]],
+    local: Optional[LocalType],
+) -> CmpKey:
+
+    # When we compare a release version, we want to compare it with all of the
+    # trailing zeros removed. So we'll reverse the list, drop all the now-leading
+    # zeros until we come to something non-zero, then re-reverse the remainder
+    # back into the correct order, make it a tuple, and use that for our sorting
+    # key.
+    _release = tuple(
+        reversed(list(itertools.dropwhile(lambda x: x == 0, reversed(release))))
+    )
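+    # Added illustrative note: e.g. release (1, 0, 0) becomes (1,), so "1.0.0"
+    # and "1" produce the same comparison key for this segment.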
+
+    # We need to "trick" the sorting algorithm to put 1.0.dev0 before 1.0a0.
+    # We'll do this by abusing the pre segment, but we _only_ want to do this
+    # if there is not a pre or a post segment. If we have one of those then
+    # the normal sorting rules will handle this case correctly.
+    if pre is None and post is None and dev is not None:
+        _pre: CmpPrePostDevType = NegativeInfinity
+    # Versions without a pre-release (except as noted above) should sort after
+    # those with one.
+    elif pre is None:
+        _pre = Infinity
+    else:
+        _pre = pre
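+    # Added illustrative note: with this trick, "1.0.dev0" < "1.0a0" < "1.0",
+    # since NegativeInfinity < ("a", 0) < Infinity for the pre segment.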
+
+    # Versions without a post segment should sort before those with one.
+    if post is None:
+        _post: CmpPrePostDevType = NegativeInfinity
+
+    else:
+        _post = post
+
+    # Versions without a development segment should sort after those with one.
+    if dev is None:
+        _dev: CmpPrePostDevType = Infinity
+
+    else:
+        _dev = dev
+
+    if local is None:
+        # Versions without a local segment should sort before those with one.
+        _local: CmpLocalType = NegativeInfinity
+    else:
+        # Versions with a local segment need that segment parsed to implement
+        # the sorting rules in PEP440.
+        # - Alpha numeric segments sort before numeric segments
+        # - Alpha numeric segments sort lexicographically
+        # - Numeric segments sort numerically
+        # - Shorter versions sort before longer versions when the prefixes
+        #   match exactly
+        _local = tuple(
+            (i, "") if isinstance(i, int) else (NegativeInfinity, i) for i in local
+        )
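+        # Added illustrative note: e.g. local ("abc", 1) becomes
+        # ((NegativeInfinity, "abc"), (1, "")), so numeric parts sort after
+        # alphanumeric ones.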
+
+    return epoch, _release, _pre, _post, _dev, _local
diff --git a/.venv/lib/python3.12/site-packages/torch/amp/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/amp/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1358e55211fe08e58aebdd28f97193ca10c2211f
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/amp/__pycache__/__init__.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/amp/__pycache__/autocast_mode.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/amp/__pycache__/autocast_mode.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..447f61d000b18813dc399db55ac67da47cadb81f
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/amp/__pycache__/autocast_mode.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/amp/__pycache__/grad_scaler.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/amp/__pycache__/grad_scaler.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7abc9d83b5e94746262e723097d240b8b3e91de2
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/amp/__pycache__/grad_scaler.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/modules/__pycache__/linear.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/modules/__pycache__/linear.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a249b42aea78b19b6cdeb96255f03b9fcda960b7
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/modules/__pycache__/linear.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/ns/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/ns/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ee840e1c935114e3d1fc42e0d0e866730136c960
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/ns/__pycache__/__init__.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/ns/fx/__init__.py b/.venv/lib/python3.12/site-packages/torch/ao/ns/fx/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/ns/fx/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/ns/fx/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..371989e7ac669f9ff856a7419730df6dd9b42098
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/ns/fx/__pycache__/__init__.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/ns/fx/__pycache__/ns_types.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/ns/fx/__pycache__/ns_types.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..904bda8b22d89e149551be167e0c62a2eb47e00a
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/ns/fx/__pycache__/ns_types.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/ns/fx/__pycache__/utils.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/ns/fx/__pycache__/utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0338443863b10dc7c2ba058de7b1534aead6275e
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/ns/fx/__pycache__/utils.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/ns/fx/graph_matcher.py b/.venv/lib/python3.12/site-packages/torch/ao/ns/fx/graph_matcher.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f9c873971a3db16f6af6f576219baae3a865548
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/ns/fx/graph_matcher.py
@@ -0,0 +1,472 @@
+# mypy: allow-untyped-defs
+import collections
+import enum
+from typing import Any, Optional
+
+import torch
+from torch.ao.quantization import FakeQuantizeBase, ObserverBase
+from torch.ao.quantization.utils import getattr_from_fqn
+from torch.fx import GraphModule
+from torch.fx.graph import Graph, Node
+
+from .mappings import get_base_name_to_sets_of_related_ops, get_unmatchable_types_map
+from .ns_types import NSNodeTargetType, NSSubgraph
+from .pattern_utils import (
+    end_node_matches_reversed_fusion,
+    get_reversed_fusions,
+    get_type_a_related_to_b,
+)
+
+
+toq = torch.ops.quantized
+
+
+def _get_output_nodes(g: Graph) -> list[Node]:
+    return [n for n in g.nodes if n.op == "output"]
+
+
+class _NSGraphMatchableSubgraphsIterator:
+    """
+    Iterates through the graph of gm, starting with the output nodes
+    and continuing backwards.
+    1. Returns matchable subgraphs, in order. A subgraph is defined by
+       (start_node, end_node).
+    2. Skips over non-matchable subgraphs
+    """
+
+    def __init__(
+        self,
+        gm: GraphModule,
+        non_matchable_functions: set[NSNodeTargetType],
+        non_matchable_modules: set[NSNodeTargetType],
+        non_matchable_methods: set[NSNodeTargetType],
+    ):
+        self.gm: GraphModule = gm
+        self.non_matchable_functions: set[NSNodeTargetType] = non_matchable_functions
+        self.non_matchable_modules: set[NSNodeTargetType] = non_matchable_modules
+        self.non_matchable_methods: set[NSNodeTargetType] = non_matchable_methods
+        self.seen_nodes: set[Node] = set()
+        self.stack: list[Node] = []
+        for start_node in _get_output_nodes(self.gm.graph):
+            self.stack.append(start_node)
+
+    def __iter__(self):
+        return self
+
+    def __next__(self) -> NSSubgraph:
+        """
+        Returns the next matchable subgraph.
+        """
+        while len(self.stack) > 0:
+            cur_end_node = self.stack.pop()
+            if cur_end_node in self.seen_nodes:
+                continue
+
+            # for subgraphs which are single nodes, start_node == end_node
+            # for subgraphs with more than one node, start_node != end_node
+            cur_start_node = cur_end_node
+            # Subgraphs like linear-relu have the base node as the start node.
+            # Subgraphs like dequantize-linear-relu-to(torch.float16) have the
+            #   base node as the second node.
+            # The cur_base_op_node var will move to the actual node during
+            #   the fusion matching later in this code block.
+            cur_base_op_node = cur_end_node
+
+            # Check for potential fusions. For now, we are greedy
+            # and always skip all non-base nodes of a fusion.  For example,
+            # if we match linear-relu backwards, we will always skip the
+            # relu node and attempt to match the linear node.  This can
+            # be made configurable later if needed.
+            for _reverse_fusion_ops, base_op_idx in get_reversed_fusions():
+                is_match = end_node_matches_reversed_fusion(
+                    cur_end_node, _reverse_fusion_ops, self.gm, self.seen_nodes
+                )
+                if is_match:
+                    # navigate to the base node
+                    for rev_fusion_idx in range(len(_reverse_fusion_ops) - 1):
+                        self.seen_nodes.add(cur_start_node)
+                        # for now, assume that there are no other nodes
+                        # which need to be added to the stack
+                        cur_start_node = cur_start_node.args[0]  # type: ignore[assignment]
+                        # if the base op index matches the current node, set it
+                        rev_base_op_idx = len(_reverse_fusion_ops) - 2 - base_op_idx
+                        if rev_fusion_idx == rev_base_op_idx:
+                            cur_base_op_node = cur_start_node
+                    break
+
+            self.seen_nodes.add(cur_start_node)
+            # add args of previous nodes to stack
+            for arg in cur_start_node.all_input_nodes:
+                self._recursively_add_node_arg_to_stack(arg)
+
+            # skip unmatchable nodes
+            # note: this check is done on the start_node, i.e.
+            # if we are matching linear-relu in reverse, this would do the matchable
+            # check on the linear
+            if not self._is_matchable(cur_base_op_node):
+                continue
+
+            # If an observer or a fake_quant was not matched as a part of
+            # a pattern of multiple nodes, ignore it. One case where this is
+            # relevant is an observer on a graph input, which was added because
+            # it is necessary for the next node.
+            if cur_end_node.op == "call_module" and cur_start_node is cur_end_node:
+                maybe_obs = getattr_from_fqn(self.gm, cur_end_node.target)  # type: ignore[arg-type]
+                if isinstance(maybe_obs, (ObserverBase, FakeQuantizeBase)):
+                    continue
+
+            return NSSubgraph(
+                start_node=cur_start_node,
+                end_node=cur_end_node,
+                base_op_node=cur_base_op_node,
+            )
+
+        raise StopIteration
+
+    def _recursively_add_node_arg_to_stack(self, arg: Any) -> None:
+        """
+        Adds all of the nodes in this arg to the stack, properly navigating
+        through lists, dicts and tuples.
+        """
+        if isinstance(arg, Node):
+            self.stack.append(arg)
+        elif (
+            isinstance(arg, torch.fx.immutable_collections.immutable_list)
+            or type(arg) is tuple
+        ):
+            for inner_arg in arg:
+                self._recursively_add_node_arg_to_stack(inner_arg)
+        elif isinstance(arg, torch.fx.immutable_collections.immutable_dict):
+            for value in arg.values():
+                self._recursively_add_node_arg_to_stack(value)
+
+    def _is_matchable(self, node: Node) -> bool:
+        if node.op == "call_function":
+            return node.target not in self.non_matchable_functions
+        elif node.op == "call_module":
+            assert isinstance(node.target, str)
+            target_mod = getattr_from_fqn(self.gm, node.target)
+            return not any(
+                isinstance(target_mod, t)  # type: ignore[arg-type]
+                for t in self.non_matchable_modules
+            )
+        elif node.op == "call_method":
+            return node.target not in self.non_matchable_methods
+        else:
+            return False
+
+
+class GraphMatchingException(Exception):
+    """
+    Exception raised when two graphs cannot be matched.
+    """
+
+
+class SubgraphTypeRelationship(enum.Enum):
+    # same type, known
+    # example: F.linear and F.linear, or nn.Conv2d and nn.Conv2d
+    EQUAL = enum.auto()
+    # same type, but the type is not known to Numeric Suite
+    # (user defined type, etc).
+    EQUAL_BUT_UKNOWN = enum.auto()
+    # known, same subgraph_relationship set, but not the same type
+    # example: F.linear and toq.linear
+    RELATED_BUT_NOT_EQUAL = enum.auto()
+    # not related
+    NOT_RELATED = enum.auto()
+
+
+def _get_subgraph_relationship_type(
+    subgraph_a: NSSubgraph,
+    subgraph_b: NSSubgraph,
+    gm_a: GraphModule,
+    gm_b: GraphModule,
+    type_a_related_to_b: set[tuple[NSNodeTargetType, NSNodeTargetType]],
+) -> SubgraphTypeRelationship:
+    node_a = subgraph_a.base_op_node
+    node_b = subgraph_b.base_op_node
+
+    # TODO(next): make this code handle matching by what is before the base op
+    if node_a.op != node_b.op:
+        if not (
+            node_a.op in ("call_function", "call_method")
+            and node_b.op in ("call_function", "call_method")
+        ):
+            return SubgraphTypeRelationship.NOT_RELATED
+
+    if node_a.op in ("call_function", "call_method"):
+        key = (node_a.target, node_b.target)
+
+        if key not in type_a_related_to_b:
+            if node_a.target == node_b.target:
+                return SubgraphTypeRelationship.EQUAL_BUT_UKNOWN
+            else:
+                return SubgraphTypeRelationship.NOT_RELATED
+        # after this point, we are dealing with known types
+
+        if node_a.target == node_b.target:
+            node_a_has_prev = subgraph_a.base_op_node == subgraph_a.start_node
+            node_b_has_prev = subgraph_b.base_op_node == subgraph_b.start_node
+            if node_a_has_prev and (not node_b_has_prev):
+                return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
+            elif (not node_a_has_prev) and node_b_has_prev:
+                return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
+            elif (not node_a_has_prev) and (not node_b_has_prev):
+                return SubgraphTypeRelationship.EQUAL
+            else:
+                # TODO(future PR): check for matches start_op_node and base_op_node
+                return SubgraphTypeRelationship.EQUAL
+
+        if key in type_a_related_to_b:
+            return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
+        else:
+            return SubgraphTypeRelationship.NOT_RELATED
+    elif node_a.op == "call_module":
+        assert (
+            subgraph_a.base_op_node == subgraph_a.start_node
+            and subgraph_b.base_op_node == subgraph_b.start_node
+        ), (
+            "Matching call_module patterns where base_op_node != start_node is not supported yet"
+        )
+        # for call_module, we need to look up the modules to do the type check
+        assert isinstance(node_a.target, str)
+        mod_a = getattr_from_fqn(gm_a, node_a.target)
+        assert isinstance(node_b.target, str)
+        mod_b = getattr_from_fqn(gm_b, node_b.target)
+
+        key = (type(mod_a), type(mod_b))
+
+        if key not in type_a_related_to_b:
+            if type(mod_a) == type(mod_b):
+                return SubgraphTypeRelationship.EQUAL_BUT_UKNOWN
+            else:
+                return SubgraphTypeRelationship.NOT_RELATED
+        elif type(mod_a) == type(mod_b):
+            return SubgraphTypeRelationship.EQUAL
+        else:
+            return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
+
+    return SubgraphTypeRelationship.NOT_RELATED
+
+
+def _get_name_for_subgraph(
+    subgraph_a: NSSubgraph,
+    gm_a: GraphModule,
+    base_name_to_sets_of_related_ops: dict[str, set[NSNodeTargetType]],
+    existing_names: set[str],
+) -> str:
+    """
+    Returns a unique name for a subgraph. This name is based on two things:
+    1. the name of the set containing the underlying type of the base op in the
+       subgraph (i.e. 'torch.nn.functional.linear' if this is related to a linear op)
+    2. the number of previous subgraphs with related underlying type of the base op
+
+    For example, in the graph
+
+    linear0 -> relu0 -> linear1 -> relu1
+
+    The subgraphs are (linear0, relu0) and (linear1, relu1).  If we iterate
+    from the output node backwards, the name given to (linear1, relu1) will be
+    `base_op_torch.nn.functional.linear_0`, and the name given to (linear0, relu0)
+    will be `base_op_torch.nn.functional.linear_1`.
+
+    Why are we not just using the node name? Answer: because of two requirements:
+    A. fusions must be supported
+    B. some Numeric Suite APIs can be called without having all of the models in memory
+
+    For example, let's say we need to match nodes of
+
+    (1) ... -> linear0 -> relu0 -> ...
+
+    And
+
+    (2) ... -> linear_relu0 -> ...
+
+    Without being able to inspect them together. With the current naming scheme, if
+    we iterate through both of these graphs in the same order, and assuming the rest
+    of the graphs match, both of these subgraphs will get the same name without
+    (1) and (2) knowing anything about each other.
+    """
+    target_type = _get_node_target_type(subgraph_a.base_op_node, gm_a)
+    target_base_type = None
+    for base_name, sets_of_related_ops in base_name_to_sets_of_related_ops.items():
+        if target_type in sets_of_related_ops:
+            target_base_type = base_name
+    target_base_name = "base_op_" + str(target_base_type)
+    counter = 0
+    proposed_name = target_base_name + "_" + str(counter)
+    while proposed_name in existing_names:
+        counter += 1
+        proposed_name = target_base_name + "_" + str(counter)
+    existing_names.add(proposed_name)
+    return proposed_name
+
+
+def _get_node_target_type(node: Node, gm: GraphModule) -> Optional[NSNodeTargetType]:
+    if node.op in ("call_function", "call_method"):
+        return node.target
+    elif node.op == "call_module":
+        assert isinstance(node.target, str)
+        mod = getattr_from_fqn(gm, node.target)
+        return type(mod)
+    return None
+
+
+def get_matching_subgraph_pairs(
+    gm_a: GraphModule,
+    gm_b: GraphModule,
+    base_name_to_sets_of_related_ops: Optional[dict[str, set[NSNodeTargetType]]] = None,
+    unmatchable_types_map: Optional[dict[str, set[NSNodeTargetType]]] = None,
+) -> dict[str, tuple[NSSubgraph, NSSubgraph]]:
+    """
+    Matches matchable subgraphs of graph_a to graph_b.
+
+    For a node, "matchable" is defined as a node which is not an observer,
+    fake_quants, quant or dequant.
+
+    A subgraph can contain one or more nodes.  A subgraph is matchable if
+    at least one node inside of it is matchable.  Currently, all nodes in
+    a subgraph must be matchable (because we assume no observers will be
+    inserted in the middle of a fusion).
+
+    A subgraph is defined by (start_node, end_node).  We assume that only
+    start_node and end_node are linked with the surrounding graph, all other
+    nodes in a subgraph are self-contained.
+
+    A pair of nodes is "related" if both nodes represent the same mathematical
+    operation across different quantization flavors. For example,
+    `F.linear` and `torch.ops.quantized.linear` are related, and
+    `F.linear` and `torch.nn.Conv` are not related.
+
+    For each matchable pair of nodes node_a and node_b, they will match
+    if node_a and node_b are related.
+
+    For graphs A and B, they will match iff:
+    1. the number of matchable subgraphs in A and B is equivalent
+    2. when iterating through the matchable subgraphs of A and B in the same order, each
+       corresponding pair of base nodes is related.
+
+    This enables us to find the corresponding subgraphs between
+    graphs of related models.  For example, if we had two graphs such as:
+
+    graph_a: x0 -> conv_0 (type: nn.Conv2d) -> obs_0 -> x1
+             w  -/
+             b  -/
+
+    graph_b: x0 -> quant_0 -> qconv_0 (type: nnq.Conv2d) -> dequant_0 -> x1
+           packed_params_0 -/
+
+    This function will return the following result:
+    {
+        'conv_0': (  # the name of the node in graph_b
+          (conv_0, conv_0),  # (start_node_a, end_node_a)
+          (qconv_0, qconv_0),  # (start_node_b, end_node_b)
+        ),
+    }
+
+    Or, if we have a fusion pattern,
+
+    graph_a: x0 -> linear_0 -> relu_0 -> obs_0 -> x1
+             w  -/
+             b  -/
+
+    graph_b: x0 -> quant_0 -> linear_relu_0 -> dequant_0 -> x1
+           packed_params_0 -/
+
+    This function will return the following result:
+    {
+        'linear_relu_0': (  # the name of the node in graph_b
+          (linear_0, relu_0),  # (start_node_a, end_node_a)
+          (linear_relu_0, linear_relu_0),  # (start_node_b, end_node_b)
+        ),
+    }
+    """
+    if unmatchable_types_map is None:
+        unmatchable_types_map = get_unmatchable_types_map()
+    non_matchable_functions = unmatchable_types_map["funs_unmatchable"]
+    non_matchable_modules = unmatchable_types_map["mods_unmatchable"]
+    non_matchable_methods = unmatchable_types_map["meths_unmatchable"]
+
+    graph_a_iterator = _NSGraphMatchableSubgraphsIterator(
+        gm_a, non_matchable_functions, non_matchable_modules, non_matchable_methods
+    )
+    graph_b_iterator = _NSGraphMatchableSubgraphsIterator(
+        gm_b, non_matchable_functions, non_matchable_modules, non_matchable_methods
+    )
+    results = collections.OrderedDict()
+    if base_name_to_sets_of_related_ops is None:
+        base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
+    type_a_related_to_b = get_type_a_related_to_b(base_name_to_sets_of_related_ops)
+
+    existing_names_a: set[str] = set()
+    existing_names_b: set[str] = set()
+
+    while True:
+        # fetch the next subgraphs from a and b
+        cur_subgraph_a, cur_subgraph_b = None, None
+        try:
+            cur_subgraph_a = next(graph_a_iterator)
+        except StopIteration:
+            pass
+        try:
+            cur_subgraph_b = next(graph_b_iterator)
+        except StopIteration:
+            pass
+
+        # look up types of a and b for useful error messages
+        type_start_a, type_start_b = None, None
+        if cur_subgraph_a is not None:
+            type_start_a = _get_node_target_type(cur_subgraph_a.start_node, gm_a)
+        if cur_subgraph_b is not None:
+            type_start_b = _get_node_target_type(cur_subgraph_b.start_node, gm_b)
+
+        # check for results and determine what to do next
+        if cur_subgraph_a is not None and cur_subgraph_b is not None:
+            # both nodes were fetched, check for subgraph_relationship
+            # note: subgraph_relationship is checked on the start node, i.e.
+            # if a linear-relu pattern is checked, we would check for subgraph_relationship
+            # of the linear
+            subgraph_relationship = _get_subgraph_relationship_type(
+                cur_subgraph_a, cur_subgraph_b, gm_a, gm_b, type_a_related_to_b
+            )
+            if subgraph_relationship == SubgraphTypeRelationship.NOT_RELATED:
+                msg = f"""
+The subgraphs
+({cur_subgraph_a}, {type_start_a}) and
+({cur_subgraph_b}, {type_start_b})
+are not related. Please ensure that the two models you pass in have the same number
+of subgraphs, and each pair of subgraphs is related to each other."""
+                raise GraphMatchingException(msg)
+            elif subgraph_relationship == SubgraphTypeRelationship.EQUAL_BUT_UKNOWN:
+                # skip matching but unknown types
+                continue
+            key_name_a = _get_name_for_subgraph(
+                cur_subgraph_a, gm_a, base_name_to_sets_of_related_ops, existing_names_a
+            )
+            key_name_b = _get_name_for_subgraph(
+                cur_subgraph_b, gm_b, base_name_to_sets_of_related_ops, existing_names_b
+            )
+            assert key_name_a == key_name_b, (
+                f"Subgraph names {key_name_a} and {key_name_b} do not match"
+            )
+            results[key_name_a] = (cur_subgraph_a, cur_subgraph_b)
+            continue
+        elif cur_subgraph_a is None and cur_subgraph_b is None:
+            # we reached the end of both graphs
+            break
+        else:
+            # only one node was fetched, no match possible, throw error
+            msg = f"""
+Attempting to match
+({cur_subgraph_a}, {type_start_a}) and
+({cur_subgraph_b}, {type_start_b}),
+one of which is empty. Please ensure that the two models you pass in have the same number
+of subgraphs."""
+            raise GraphMatchingException(msg)
+
+    # The subgraph pairs are originally created by traversing the two graphs
+    # from the outputs to the inputs. Reverse the results to return the
+    # subgraphs in their order of execution.
+    results = collections.OrderedDict(reversed(results.items()))
+
+    return results
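+
+
+# Added usage sketch (illustrative, not part of the upstream file): given two
+# GraphModules traced from a float model and its quantized counterpart, e.g.
+#
+#     subgraph_pairs = get_matching_subgraph_pairs(gm_float, gm_quantized)
+#
+# the result is an OrderedDict keyed by subgraph name, with pairs listed in
+# execution order. The names gm_float and gm_quantized are hypothetical.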
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/ns/fx/graph_passes.py b/.venv/lib/python3.12/site-packages/torch/ao/ns/fx/graph_passes.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee7a4dfd3e5b2f63d4191679be85acf6f3314a32
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/ns/fx/graph_passes.py
@@ -0,0 +1,1133 @@
+# mypy: allow-untyped-defs
+from typing import Any, Callable, Optional, Union
+
+import torch
+from torch.ao.ns.fx.mappings import get_node_type_to_io_type_map
+from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
+from torch.ao.quantization.observer import _is_activation_post_process
+from torch.fx import GraphModule, map_arg
+from torch.fx.graph import Graph, Node
+
+from .ns_types import NSNodeTargetType, NSSingleResultValuesType, NSSubgraph
+from .utils import (
+    get_arg_indices_of_inputs_to_log,
+    get_node_first_input_and_output_type,
+    get_node_input_qparams,
+    get_normalized_nth_input,
+    get_number_of_non_param_args,
+    get_target_type_str,
+    getattr_from_fqn,
+    NodeInputOrOutputType,
+    op_type_supports_shadowing,
+    return_first_non_observer_node,
+)
+
+
+def _maybe_get_fqn(node: Node, gm: GraphModule) -> Optional[str]:
+    fqn = None
+    if hasattr(gm, "_node_name_to_scope"):
+        # fqn on observers is not present, because they do not
+        # exist when the fqns are created during tracing. If this is
+        # an observer, get the fqn of the node being observed.
+        node_to_use_for_fqn = node
+        if node.op == "call_module":
+            assert isinstance(node.target, str)
+            module = getattr_from_fqn(gm, node.target)
+            if _is_activation_post_process(module):
+                node_to_use_for_fqn = get_normalized_nth_input(node, gm, 0)
+        fqn = gm._node_name_to_scope[node_to_use_for_fqn.name][0]  # type: ignore[index]
+    return fqn  # type: ignore[return-value]
+
+
+def _insert_logger_after_node(
+    node: Node,
+    gm: GraphModule,
+    logger_cls: Callable,
+    logger_node_name_suffix: str,
+    ref_node_name: str,
+    model_name: str,
+    ref_name: str,
+    ref_node_target_type: str,
+    results_type: str,
+    index_within_arg: int,
+    index_of_arg: int,
+    fqn: Optional[str],
+) -> Node:
+    """
+    Given a starting graph of
+
+    prev_node -> node -> next_node
+
+    This function creates a new logger_cls obj and adds it
+    after node, resulting in
+
+    prev_node -> node -> logger_obj -> next_node
+    """
+    # create new name
+    logger_node_name = get_new_attr_name_with_prefix(
+        node.name + logger_node_name_suffix
+    )(gm)
+    target_type = get_target_type_str(node, gm)
+    # create the logger object
+    logger_obj = logger_cls(
+        ref_node_name,
+        node.name,
+        model_name,
+        ref_name,
+        target_type,
+        ref_node_target_type,
+        results_type,
+        index_within_arg,
+        index_of_arg,
+        fqn,
+    )
+    # attach the logger object to the parent module
+    setattr(gm, logger_node_name, logger_obj)
+    logger_node = node.graph.create_node("call_module", logger_node_name, (node,), {})
+    return logger_node
+
+
+def add_loggers_to_model(
+    gm: GraphModule,
+    node_to_instrument_inputs_to_ref_node_name: dict[Node, tuple[str, str]],
+    node_to_instrument_outputs_to_ref_node_name: dict[Node, tuple[str, str]],
+    logger_cls: Callable,
+    model_name: str,
+) -> GraphModule:
+    """
+    Takes the graph of gm, adds loggers to the output
+    of each node in nodes_to_instrument. Returns a GraphModule with the new
+    graph.
+    """
+
+    new_graph = Graph()
+    env: dict[str, Any] = {}
+
+    def load_arg(a):
+        return map_arg(a, lambda node: env[node.name])
+
+    for node in gm.graph.nodes:
+        if node.op == "output":
+            new_graph.output(map_arg(get_normalized_nth_input(node, gm, 0), load_arg))
+            continue
+
+        if (node in node_to_instrument_inputs_to_ref_node_name) or (
+            node in node_to_instrument_outputs_to_ref_node_name
+        ):
+            fqn = _maybe_get_fqn(node, gm)
+
+            if node in node_to_instrument_inputs_to_ref_node_name:
+                ref_name, ref_node_type = node_to_instrument_inputs_to_ref_node_name[
+                    node
+                ]
+                # Ops such as add and mul are special because either
+                # one or two of the first two arguments can be tensors,
+                # and if one argument is a tensor it can be first or
+                # second (x + 1 versus 1 + x).
+                arg_indices_to_log = get_arg_indices_of_inputs_to_log(node)
+                for node_arg_idx in arg_indices_to_log:
+                    node_arg = get_normalized_nth_input(node, gm, node_arg_idx)
+                    if type(node_arg) == Node:
+                        # create a single input logger
+                        prev_node = env[node_arg.name]
+                        env[node_arg.name] = _insert_logger_after_node(
+                            prev_node,
+                            gm,
+                            logger_cls,
+                            "_ns_logger_",
+                            node.name,
+                            model_name,
+                            ref_name,
+                            ref_node_type,
+                            NSSingleResultValuesType.NODE_INPUT.value,
+                            index_within_arg=0,
+                            index_of_arg=node_arg_idx,
+                            fqn=fqn,
+                        )
+                    elif (
+                        type(node_arg) == torch.fx.immutable_collections.immutable_list
+                    ):
+                        # create N input loggers, one for each node
+                        for arg_idx, arg in enumerate(node_arg):  # type: ignore[var-annotated, arg-type]
+                            prev_node = env[arg.name]
+                            env[prev_node.name] = _insert_logger_after_node(
+                                prev_node,
+                                gm,
+                                logger_cls,
+                                "_ns_logger_",
+                                node.name,
+                                model_name,
+                                ref_name,
+                                ref_node_type,
+                                NSSingleResultValuesType.NODE_INPUT.value,
+                                index_within_arg=arg_idx,
+                                index_of_arg=node_arg_idx,
+                                fqn=fqn,
+                            )
+                    else:
+                        pass
+
+            # ensure env is populated with base node
+            # Note: runs for both inputs and outputs
+            env[node.name] = new_graph.node_copy(node, load_arg)
+
+            if node in node_to_instrument_outputs_to_ref_node_name:
+                ref_name, ref_node_type = node_to_instrument_outputs_to_ref_node_name[
+                    node
+                ]
+                # add the logger after the base node
+                env[node.name] = _insert_logger_after_node(
+                    env[node.name],
+                    gm,
+                    logger_cls,
+                    "_ns_logger_",
+                    node.name,
+                    model_name,
+                    ref_name,
+                    ref_node_type,
+                    NSSingleResultValuesType.NODE_OUTPUT.value,
+                    index_within_arg=0,
+                    index_of_arg=0,
+                    fqn=fqn,
+                )
+
+        else:
+            env[node.name] = new_graph.node_copy(node, load_arg)
+
+    new_gm = GraphModule(gm, new_graph)
+    return new_gm
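+# A rough usage sketch for add_loggers_to_model (variable names are hypothetical,
+# for illustration only): `MyLogger` is assumed to be an nn.Module logger class,
+# and `inputs_map` / `outputs_map` are dicts mapping fx Nodes to
+# (ref_name, ref_node_type) tuples:
+#
+#   instrumented_gm = add_loggers_to_model(
+#       traced_gm, inputs_map, outputs_map, MyLogger, model_name="model_b"
+#   )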
+
+
+def _insert_quantize_per_tensor_node(
+    prev_node_c: Node,
+    node_a: Node,
+    gm_b: GraphModule,
+    graph_c: Graph,
+    scale: Union[torch.Tensor, float],
+    zero_point: Union[torch.Tensor, int],
+    dtype_cast_name: str,
+) -> Node:
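+    # This helper registers `scale` and `zero_point` as attributes on gm_b,
+    # creates get_attr nodes for them in graph_c, and then emits a
+    # torch.quantize_per_tensor call_function node that consumes prev_node_c.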
+    # copy scale
+    scale_node_name = get_new_attr_name_with_prefix(node_a.name + "_input_scale_")(gm_b)
+    setattr(gm_b, scale_node_name, scale)
+    scale_node = graph_c.create_node(
+        "get_attr", scale_node_name, (), {}, scale_node_name
+    )
+    # copy zero_point
+    zero_point_node_name = get_new_attr_name_with_prefix(
+        node_a.name + "_input_zero_point_"
+    )(gm_b)
+    setattr(gm_b, zero_point_node_name, zero_point)
+    zero_point_node = graph_c.create_node(
+        "get_attr", zero_point_node_name, (), {}, zero_point_node_name
+    )
+    # create the quantize_per_tensor call
+    return graph_c.create_node(
+        "call_function",
+        torch.quantize_per_tensor,
+        (prev_node_c, scale_node, zero_point_node, torch.quint8),
+        {},
+        dtype_cast_name,
+    )
+
+
+def _insert_dtype_cast_after_node(
+    node_a: Node,
+    node_c: Node,
+    prev_node_c: Union[Node, list[Node]],
+    gm_a: GraphModule,
+    gm_b: GraphModule,
+    graph_c: Graph,
+    node_name_prefix: str,
+    logger_cls: Callable,
+    node_type_to_io_type_map: dict[str, set[NSNodeTargetType]],
+) -> Union[Node, list[Node]]:
+    """
+    Given a starting graph C (derived from graph B) of
+
+    ... -> prev_node_c -> node_c -> ...
+
+    And a corresponding related node_a, inserts the correct dtype
+    cast node after prev_node_c to cast into the dtype expected
+    by node_a, resulting in:
+
+                          dtype_cast
+                        /
+    ... -> prev_node_c -> node_c -> ...
+
+    For example, if node_c is an int8 op and node_a is an fp32 op, this function
+    will insert a dequant.
+    """
+    dtype_cast_op = None
+    dtype_cast_mod_cls = None
+    dtype_cast_method = None
+    dtype_cast_method_dtype = None
+    dtype_cast_scale = None
+    dtype_cast_zero_point = None
+    node_input_type_a, _node_output_type_a = get_node_first_input_and_output_type(
+        node_a, gm_a, logger_cls, node_type_to_io_type_map
+    )
+    node_input_type_c, _node_output_type_c = get_node_first_input_and_output_type(
+        node_c, gm_b, logger_cls, node_type_to_io_type_map
+    )
+
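+    # Branch summary for the dtype cast selection below:
+    #   a=fp32, c in {int8, fp16, fp32_or_int8}  -> torch.dequantize
+    #   a == c (and known)                       -> nn.Identity (no-op cast)
+    #   a=int8, c=fp32 (qparams known)           -> torch.quantize_per_tensor
+    #   a=fp16, c=fp32                           -> .to(torch.float16)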
+    if (
+        (
+            node_input_type_a == NodeInputOrOutputType.FP32
+            and node_input_type_c == NodeInputOrOutputType.INT8
+        )
+        or (
+            node_input_type_a == NodeInputOrOutputType.FP32
+            and node_input_type_c == NodeInputOrOutputType.FP16
+        )
+        or
+        # TODO(future PR): determine the actual dtype of node_c,
+        # the current code only works because dequantize works with
+        # multiple input dtypes.
+        (
+            node_input_type_a == NodeInputOrOutputType.FP32
+            and node_input_type_c == NodeInputOrOutputType.FP32_OR_INT8
+        )
+    ):
+        dtype_cast_op = torch.dequantize
+    elif (
+        node_input_type_a == node_input_type_c
+        and node_input_type_a != NodeInputOrOutputType.UNKNOWN
+    ):
+        dtype_cast_mod_cls = torch.nn.Identity
+    elif (
+        node_input_type_a == NodeInputOrOutputType.INT8
+        and node_input_type_c == NodeInputOrOutputType.FP32
+    ):
+        # int8 shadows fp32, the dtype cast needs to quantize to int8
+        # with the right qparams.
+        node_a_input_qparams = get_node_input_qparams(
+            node_a, gm_a, node_type_to_io_type_map
+        )
+        if node_a_input_qparams is not None:
+            dtype_cast_op = torch.quantize_per_tensor  # type: ignore[assignment]
+            dtype_cast_scale, dtype_cast_zero_point = node_a_input_qparams
+    elif (
+        node_input_type_a == NodeInputOrOutputType.FP16
+        and node_input_type_c == NodeInputOrOutputType.FP32
+    ):
+        dtype_cast_method = "to"
+        dtype_cast_method_dtype = torch.float16
+    else:
+        raise AssertionError(
+            f"dtype cast from {node_input_type_c} {node_c.format_node()} to "
+            + f"{node_input_type_a} {node_a.format_node()} needs to be implemented"
+        )
+
+    if isinstance(prev_node_c, Node):
+        new_dtype_cast_name = get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
+        if dtype_cast_op:
+            if dtype_cast_scale is not None and dtype_cast_zero_point is not None:
+                return _insert_quantize_per_tensor_node(
+                    prev_node_c,
+                    node_a,
+                    gm_b,
+                    graph_c,
+                    dtype_cast_scale,
+                    dtype_cast_zero_point,
+                    new_dtype_cast_name,
+                )
+            else:
+                return graph_c.create_node(
+                    "call_function",
+                    dtype_cast_op,
+                    (prev_node_c,),
+                    {},
+                    new_dtype_cast_name,
+                )
+        elif dtype_cast_method:
+            return graph_c.create_node(
+                "call_method",
+                dtype_cast_method,
+                (prev_node_c, dtype_cast_method_dtype),
+                {},
+                new_dtype_cast_name,
+            )
+        else:
+            assert dtype_cast_mod_cls
+            dtype_cast_mod = dtype_cast_mod_cls()
+            setattr(gm_b, new_dtype_cast_name, dtype_cast_mod)
+            return graph_c.create_node(
+                "call_module",
+                new_dtype_cast_name,
+                (prev_node_c,),
+                {},
+                new_dtype_cast_name,
+            )
+    elif isinstance(prev_node_c, list):
+        results = []
+        for prev_node_c_inner in prev_node_c:
+            new_dtype_cast_name = get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
+            if dtype_cast_op:
+                # TODO(future PR): add handling for quantize_per_tensor
+                new_dtype_cast_node = graph_c.create_node(
+                    "call_function",
+                    dtype_cast_op,
+                    (prev_node_c_inner,),
+                    {},
+                    new_dtype_cast_name,
+                )
+                results.append(new_dtype_cast_node)
+            else:
+                assert dtype_cast_mod_cls
+                dtype_cast_mod = dtype_cast_mod_cls()
+                setattr(gm_b, new_dtype_cast_name, dtype_cast_mod)
+                new_dtype_cast_node = graph_c.create_node(
+                    "call_module",
+                    new_dtype_cast_name,
+                    (prev_node_c_inner,),
+                    {},
+                    new_dtype_cast_name,
+                )
+                results.append(new_dtype_cast_node)
+        return results
+    else:
+        raise AssertionError(f"type f{type(prev_node_c)} is not handled")
+
+
+# TODO(future PR): look into using copy_node API instead
+def _copy_node_from_a_to_c(
+    node_a: Node,
+    gm_a: GraphModule,
+    gm_b: GraphModule,
+    graph_c: Graph,
+) -> Node:
+    """
+    Simple copy of node_a to graph_c.
+    """
+    if node_a.op == "get_attr":
+        node_a_copy_name = get_new_attr_name_with_prefix(node_a.name + "_shadow_copy_")(
+            gm_b
+        )
+        node_a_obj = getattr_from_fqn(gm_a, node_a.target)  # type: ignore[arg-type]
+        if torch.is_tensor(node_a_obj):
+            node_a_obj = node_a_obj.detach()
+        setattr(gm_b, node_a_copy_name, node_a_obj)
+        node_a_copy = graph_c.create_node(
+            node_a.op, node_a_copy_name, (), {}, node_a_copy_name
+        )
+        return node_a_copy
+    elif node_a.op == "call_method":
+        assert node_a.target in (
+            "dequantize",
+            "to",
+        ), f"target {node_a.target} is not implemented"
+        if node_a.target == "dequantize":
+            arg_copy = _copy_node_from_a_to_c(
+                get_normalized_nth_input(node_a, gm_a, 0), gm_a, gm_b, graph_c
+            )  # type: ignore[arg-type]
+            node_a_copy_name = get_new_attr_name_with_prefix(
+                node_a.name + "_shadow_copy_"
+            )(gm_b)
+            node_a_copy = graph_c.create_node(
+                node_a.op, node_a.target, (arg_copy,), {}, node_a_copy_name
+            )
+            return node_a_copy
+        else:  # to
+            arg_copy = _copy_node_from_a_to_c(
+                get_normalized_nth_input(node_a, gm_a, 0), gm_a, gm_b, graph_c
+            )  # type: ignore[arg-type]
+            node_a_copy_name = get_new_attr_name_with_prefix(
+                node_a.name + "_shadow_copy_"
+            )(gm_b)
+            node_a_copy = graph_c.create_node(
+                node_a.op,
+                node_a.target,
+                (arg_copy, get_normalized_nth_input(node_a, gm_a, 1)),
+                {},
+                node_a_copy_name,
+            )
+            return node_a_copy
+
+    else:
+        raise AssertionError(
+            f"handling of node {node_a.format_node()} with op {node_a.op} is not implemented"
+        )
+
+
+def _can_insert_copy_of_subgraph_a(
+    subgraph_a: NSSubgraph,
+    gm_a: GraphModule,
+    num_non_param_args_node_a: int,
+) -> bool:
+    """
+    This function returns `False` if the input subgraph cannot be copied by
+    `_insert_copy_of_subgraph_a_after_input_node_c`. This usually means
+    that there is corner-case logic for which the copy is not yet implemented.
+    """
+    # populate the list of nodes we need to check
+    nodes = []
+    cur_node = subgraph_a.end_node
+    while cur_node != subgraph_a.start_node:
+        nodes.append(cur_node)
+        cur_node = get_normalized_nth_input(cur_node, gm_a, 0)  # type: ignore[assignment]
+    nodes.append(cur_node)
+    nodes.reverse()
+
+    def _can_insert(node_a_arg, gm_a):
+        if isinstance(node_a_arg, Node):
+            arg_a = return_first_non_observer_node(node_a_arg, gm_a)
+            if arg_a.op == "call_method":
+                return arg_a.target in ("dequantize", "to")
+            elif arg_a.op == "get_attr":
+                return True
+            else:
+                return False
+        elif isinstance(node_a_arg, (list, tuple)):
+            for el in node_a_arg:
+                if not isinstance(el, Node):
+                    return False
+        return True
+
+    # For each node, check if we handle the copy behavior. This follows the
+    # logic in `_insert_copy_of_subgraph_a_after_input_node_c`.
+    for node_a in nodes:
+        local_num_non_param_args_node_a = (
+            num_non_param_args_node_a if node_a is nodes[0] else 1
+        )
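+        # Only the first node of the subgraph receives the subgraph's non-param
+        # inputs from graph C; every later node takes the previous node's output
+        # as its single non-param input.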
+
+        norm_args_kwargs = node_a.normalized_arguments(
+            gm_a, normalize_to_only_use_kwargs=True
+        )
+        if norm_args_kwargs is not None:
+            norm_args, norm_kwargs = norm_args_kwargs
+        else:
+            norm_args, norm_kwargs = node_a.args, node_a.kwargs
+
+        cur_idx = 0
+
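+        # args at index 0 (and index 1 for ops with two non-param inputs, such
+        # as LSTM) are stitched in from graph C by the copy logic, so they do
+        # not need to be copyable themselves and are skipped here.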
+        while cur_idx < len(norm_args):
+            if cur_idx == 0:
+                pass
+            elif cur_idx == 1 and local_num_non_param_args_node_a == 2:
+                pass
+            else:
+                if not _can_insert(norm_args[cur_idx], gm_a):
+                    return False
+            cur_idx += 1
+
+        for kwarg_val in norm_kwargs.values():
+            # stitch the inputs from base graph
+            if cur_idx == 0:
+                pass
+            elif cur_idx == 1 and local_num_non_param_args_node_a == 2:
+                pass
+            else:
+                if not _can_insert(kwarg_val, gm_a):
+                    return False
+            cur_idx += 1
+
+    return True
+
+
+def _insert_copy_of_subgraph_a_after_input_node_c(
+    input_node_c: Union[Node, list[Node]],
+    input_node_c_2: Optional[Union[Node, list[Node]]],
+    subgraph_a: NSSubgraph,
+    gm_a: GraphModule,
+    gm_b: GraphModule,
+    node_name_prefix: str,
+) -> Node:
+    """
+    TODO(before land): real docblock
+    """
+    assert isinstance(input_node_c, (Node, list))
+
+    # create a sequential list of the subgraphs' nodes from start to end,
+    # because we need to add the nodes to graph C in non-reverse order
+    nodes_of_a = [subgraph_a.end_node]
+    cur_node = subgraph_a.end_node
+    while cur_node != subgraph_a.start_node:
+        cur_node = get_normalized_nth_input(cur_node, gm_a, 0)  # type: ignore[assignment]
+        nodes_of_a.insert(0, cur_node)
+
+    # go through nodes of a in order, and insert them into the graph of c
+    # sequentially
+    cur_node_a = nodes_of_a[0]
+    cur_node_c = _insert_copy_of_node_a_after_input_node_c(
+        input_node_c, input_node_c_2, cur_node_a, gm_a, gm_b, node_name_prefix
+    )
+    for cur_idx_a in range(1, len(nodes_of_a)):
+        cur_node_a = nodes_of_a[cur_idx_a]
+        prev_node_c = cur_node_c  # previous added node is the input to next node
+        cur_node_c = _insert_copy_of_node_a_after_input_node_c(
+            prev_node_c,
+            # TODO(future PR): enable multiple inputs for nodes which are not at start of subgraph
+            None,
+            cur_node_a,
+            gm_a,
+            gm_b,
+            node_name_prefix,
+        )
+    # return the last inserted node
+    return cur_node_c
+
+
+def _insert_copy_of_node_a_after_input_node_c(
+    input_node_c: Union[Node, list[Node]],
+    input_node_c_2: Optional[Union[Node, list[Node]]],
+    node_a: Node,
+    gm_a: GraphModule,
+    gm_b: GraphModule,
+    node_name_prefix: str,
+) -> Node:
+    """
+    Assume that node_a from graph_a has
+      args (input, (input2)?, arg1, ...), and
+      kwargs {kw0: kwarg0, ...}
+
+    Note: input2 is optional. If it is None, we assume that the op
+    has a single non-param input. If it is specified, we assume that the op
+    has two non-param inputs.
+
+    Copies the underlying values of arg1..argn and kwarg0..kwargn into gm_b,
+    and creates the corresponding nodes in graph_c. Note: observers are ignored,
+    so if an arg is an observer we navigate up until we find a non-observer parent.
+
+    If node_a is a call_module, points the module pointed to by node_a to gm_b.
+
+    Creates the copy of node_a in graph_c, with input as the first arg,
+    and all other args and kwargs pointing to the copies of the objects
+    in gm_b created above.
+
+    An example in pictures:
+
+    graph A:
+    ========
+
+    input -------------> node_a
+                         / / /
+    (input_2)?----------/ / /
+                         / /
+    weight -> weight_obs  /
+                         /
+    bias ----------------
+
+    graph C (derived from B):
+    =========================
+
+    input_node_c --> node_a_copy
+                     / / /
+    (input_node_c_2)? / /
+                     / /
+    weight_copy ----/ /
+                     /
+    bias_copy ------/
+    """
+    if isinstance(input_node_c, Node):
+        graph_c = input_node_c.graph
+    else:
+        assert isinstance(input_node_c, list)
+        graph_c = input_node_c[0].graph
+
+    norm_args_kwargs = node_a.normalized_arguments(
+        gm_a, normalize_to_only_use_kwargs=True
+    )
+    if norm_args_kwargs is not None:
+        norm_args, norm_kwargs = norm_args_kwargs
+    else:
+        norm_args, norm_kwargs = node_a.args, node_a.kwargs
+
+    new_args = []
+    new_kwargs = {}
+
+    def _copy_arg(arg):
+        # copy the other inputs from the other graph
+        if isinstance(arg, Node):
+            arg = return_first_non_observer_node(arg, gm_a)
+            arg = _copy_node_from_a_to_c(arg, gm_a, gm_b, graph_c)
+            return arg
+        elif isinstance(arg, (int, float, torch.dtype)):
+            return arg
+        elif isinstance(arg, (list, tuple)):
+            for el in arg:
+                assert not isinstance(el, Node), (
+                    "handling of Node inside list is not implemented"
+                )
+            return arg
+        else:
+            raise AssertionError(
+                f"handling for kwarg of type {type(kwarg_val)} is not implemented"
+            )
+
+    cur_idx = 0
+
+    while cur_idx < len(norm_args):
+        if cur_idx == 0:
+            new_arg = input_node_c
+        elif cur_idx == 1 and input_node_c_2 is not None:
+            new_arg = input_node_c_2
+        else:
+            new_arg = _copy_arg(norm_args[cur_idx])
+        new_args.append(new_arg)
+        cur_idx += 1
+
+    for kwarg_name, kwarg_val in norm_kwargs.items():
+        # stitch the inputs from base graph
+        if cur_idx == 0:
+            new_kwargs[kwarg_name] = input_node_c
+        elif cur_idx == 1 and input_node_c_2 is not None:
+            new_kwargs[kwarg_name] = input_node_c_2
+        else:
+            new_kwargs[kwarg_name] = _copy_arg(kwarg_val)
+        cur_idx += 1
+
+    new_args = tuple(new_args)  # type: ignore[assignment]
+
+    node_a_shadows_c_name = get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
+
+    if node_a.op == "call_module":
+        # if target is a module, we point to the module from gm_b
+        new_mod_copy_name = get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
+        # fetch the corresponding module from gm_a
+        assert isinstance(node_a.target, str)
+        mod_a = getattr_from_fqn(gm_a, node_a.target)
+        setattr(gm_b, new_mod_copy_name, mod_a)
+        node_a_shadows_c = graph_c.create_node(
+            node_a.op,
+            new_mod_copy_name,
+            new_args,  # type: ignore[arg-type]
+            new_kwargs,  # type: ignore[arg-type]
+            node_a_shadows_c_name,
+        )
+        return node_a_shadows_c
+    else:
+        assert node_a.op in ("call_function", "call_method")
+        node_a_shadows_c = graph_c.create_node(
+            node_a.op,
+            node_a.target,
+            new_args,  # type: ignore[arg-type]
+            new_kwargs,  # type: ignore[arg-type]
+            node_a_shadows_c_name,
+        )
+        return node_a_shadows_c
+
+
+def create_a_shadows_b(
+    name_a: str,
+    gm_a: GraphModule,
+    name_b: str,
+    gm_b: GraphModule,
+    matched_subgraph_pairs: dict[str, tuple[NSSubgraph, NSSubgraph]],
+    logger_cls: Callable,
+    should_log_inputs: bool,
+    node_type_to_io_type_map: Optional[dict[str, set[NSNodeTargetType]]] = None,
+) -> GraphModule:
+    """
+    Creates a new GraphModule consisting of the graph of C, with the meaningful
+    nodes of A shadowing the corresponding nodes of B.  For example,
+
+    Graph A:
+    a0 -> op0_fp32 -> a1 -> op1_fp32 -> a2
+
+    Graph B:
+    b0 -> op0_int8 -> b1 -> op1_int8 -> b2
+
+    matched_node_pairs: {'op0': (op0_fp32, op0_int8), 'op1': (op1_fp32, op1_int8)}
+
+    Graph C (A shadows B):
+
+        / dequant0 -> op0_fp32 -> logger_a_0  / dequant_1 -> op1_fp32 -> logger_a_1
+       /                                     /
+    b0 -------------> op0_int8 -> logger_b_0 --------------> op1_int8 -> logger_b_1
+
+    In a nutshell, this function does the following for each node pair:
+    * copies the necessary attributes and modules from gm_a to gm_b,
+      keeping names unique
+    * adds a dtype cast op (dequant, quant, etc)
+    * adds a copy of node_a in gm_b's graph
+    * adds loggers to the outputs of node_a and node_b
+    """
+
+    if node_type_to_io_type_map is None:
+        node_type_to_io_type_map = get_node_type_to_io_type_map()
+
+    # graph_c is the graph created from copying the nodes of graph_b and inserting
+    # the shadows with the nodes copied from graph_a
+    graph_c = Graph()
+    env_c: dict[str, Any] = {}
+
+    def load_arg(a):
+        return map_arg(a, lambda node: env_c[node.name])
+
+    start_node_b_to_matched_subgraph_a_and_name = {}
+    end_node_b_to_matched_subgraph_a_and_name = {}
+    for match_name, match in matched_subgraph_pairs.items():
+        subgraph_a, subgraph_b = match
+        ref_node_type_a = get_target_type_str(subgraph_a.base_op_node, gm_a)
+        ref_node_type_b = get_target_type_str(subgraph_b.base_op_node, gm_b)
+        start_node_b_to_matched_subgraph_a_and_name[subgraph_b.start_node] = (
+            subgraph_a,
+            match_name,
+            ref_node_type_a,
+            ref_node_type_b,
+        )
+        end_node_b_to_matched_subgraph_a_and_name[subgraph_b.end_node] = (
+            subgraph_a,
+            match_name,
+            ref_node_type_a,
+            ref_node_type_b,
+        )
+
+    for node_b in gm_b.graph.nodes:
+        if node_b.op == "output":
+            graph_c.output(map_arg(node_b.args[0], load_arg))
+            continue
+
+        # calculate the flags to determine what to do with this node
+        node_b_is_start_node = node_b in start_node_b_to_matched_subgraph_a_and_name
+        node_b_is_end_node = node_b in end_node_b_to_matched_subgraph_a_and_name
+
+        if node_b_is_start_node or node_b_is_end_node:
+            if node_b_is_start_node:
+                (
+                    subgraph_a,
+                    ref_name,
+                    ref_node_type_a,
+                    ref_node_type_b,
+                ) = start_node_b_to_matched_subgraph_a_and_name[node_b]
+            else:
+                assert node_b_is_end_node
+                (
+                    subgraph_a,
+                    ref_name,
+                    ref_node_type_a,
+                    ref_node_type_b,
+                ) = end_node_b_to_matched_subgraph_a_and_name[node_b]
+
+            all_op_types_support_shadowing = op_type_supports_shadowing(
+                subgraph_a.start_node
+            ) and op_type_supports_shadowing(node_b)
+            if not all_op_types_support_shadowing:
+                print(
+                    f"skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}"
+                    + f", start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}"
+                    + ", unsupported"
+                )
+                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
+                continue
+
+            # For both start_node and end_node verify that we know how to do
+            # the dtype cast. If we do not, skip.
+            (
+                node_input_type_a,
+                node_output_type_a,
+            ) = get_node_first_input_and_output_type(
+                subgraph_a.start_node, gm_a, logger_cls, node_type_to_io_type_map
+            )
+            (
+                node_input_type_b,
+                node_output_type_b,
+            ) = get_node_first_input_and_output_type(
+                node_b, gm_b, logger_cls, node_type_to_io_type_map
+            )
+            node_io_types_known_a_and_b = (
+                node_input_type_a != NodeInputOrOutputType.UNKNOWN
+                and node_output_type_a != NodeInputOrOutputType.UNKNOWN
+                and node_input_type_b != NodeInputOrOutputType.UNKNOWN
+                and node_output_type_b != NodeInputOrOutputType.UNKNOWN
+            )
+            if not node_io_types_known_a_and_b:
+                print(
+                    f"skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}"
+                    + f", start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}"
+                    + ", unknown dtype cast"
+                )
+                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
+                continue
+
+            # If we are shadowing from fp32 to int8, we need to insert
+            # quantize_per_tensor call with qparams from the previous node.
+            # Only do this if we are able to infer these qparams from the graph.
+            if (
+                node_input_type_a == NodeInputOrOutputType.INT8
+                and node_input_type_b == NodeInputOrOutputType.FP32
+            ):
+                node_a_input_qparams = get_node_input_qparams(
+                    subgraph_a.start_node, gm_a, node_type_to_io_type_map
+                )
+                if not node_a_input_qparams:
+                    print(
+                        f"skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}"
+                        + f", start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}"
+                        + ", unknown input qparams"
+                    )
+                    env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
+                    continue
+
+            num_non_param_args_node_a = get_number_of_non_param_args(
+                subgraph_a.start_node, gm_a
+            )
+            if not _can_insert_copy_of_subgraph_a(
+                subgraph_a, gm_a, num_non_param_args_node_a
+            ):
+                print(
+                    f"skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}"
+                    + f", start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}"
+                    + ", unhandled logic in subgraph copy"
+                )
+                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
+                continue
+
+            fqn_base_a = _maybe_get_fqn(subgraph_a.base_op_node, gm_a)
+            fqn_base_b = _maybe_get_fqn(subgraph_b.base_op_node, gm_b)  # type: ignore[possibly-undefined]
+
+            if node_b_is_start_node:
+                # if necessary, log the input of node_c
+                if should_log_inputs:
+                    prev_node_b = get_normalized_nth_input(node_b, gm_b, 0)
+                    if isinstance(prev_node_b, Node):
+                        prev_node_c = env_c[prev_node_b.name]
+                        env_c[prev_node_c.name] = _insert_logger_after_node(
+                            prev_node_c,
+                            gm_b,
+                            logger_cls,
+                            "_ns_logger_b_inp_",
+                            node_b.name,
+                            name_b,
+                            ref_name,
+                            ref_node_type_b,
+                            NSSingleResultValuesType.NODE_INPUT.value,
+                            index_within_arg=0,
+                            index_of_arg=0,
+                            fqn=fqn_base_b,
+                        )
+                    elif isinstance(prev_node_b, list):
+                        # first, save the prev_node instances, because they
+                        # will be overwritten in the env after the first logger
+                        # is added
+                        prev_node_c_list = [env_c[arg.name] for arg in prev_node_b]
+
+                        for arg_idx, arg in enumerate(prev_node_b):
+                            prev_node_c = prev_node_c_list[arg_idx]
+                            env_c[prev_node_c.name] = _insert_logger_after_node(
+                                prev_node_c,
+                                gm_b,
+                                logger_cls,
+                                "_ns_logger_b_inp_",
+                                node_b.name,
+                                name_b,
+                                ref_name,
+                                ref_node_type_b,
+                                NSSingleResultValuesType.NODE_INPUT.value,
+                                index_within_arg=arg_idx,
+                                index_of_arg=0,
+                                fqn=fqn_base_b,
+                            )
+                    else:
+                        # logging of inputs which are neither Nodes nor lists is not supported yet
+                        raise AssertionError(
+                            f"type {type(prev_node_b)} is not handled yet"
+                        )
+                # subgraph so far:
+                #
+                # (prev_node_c)+ -> (logger_c_input)?
+
+            # Note: this if statement is always True, spelling it out to clarify code
+            # intent.
+            if node_b_is_start_node or node_b_is_end_node:
+                # ensure env_c is populated with base node
+                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
+                node_c = env_c[node_b.name]
+
+                # after this point,
+                #
+                # node_a is the original node from graph_a, with parent module gm_a
+                # node_b is the original node from graph_b, with parent module gm_b
+                # node_c is the copy of node_b in graph_c
+                #
+                # subgraph so far:
+                #
+                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c
+
+            if node_b_is_start_node:
+                # cast dtype from the dtype of node_c's input to the dtype of
+                # node_a's input (dequant, etc)
+                # prev_node_c = node_c.args[0]
+                prev_node_c = get_normalized_nth_input(node_c, gm_b, 0)  # type: ignore[possibly-undefined]
+                if should_log_inputs:
+                    # skip the input logger when inserting a dtype cast
+                    if isinstance(prev_node_c, Node):
+                        prev_node_c = get_normalized_nth_input(node_c, gm_b, 0)
+                    elif isinstance(prev_node_c, list):
+                        prev_node_c = [
+                            get_normalized_nth_input(arg, gm_b, 0)
+                            for arg in prev_node_c
+                        ]
+                dtype_cast_node = _insert_dtype_cast_after_node(
+                    subgraph_a.start_node,
+                    node_c,
+                    prev_node_c,
+                    gm_a,
+                    gm_b,
+                    graph_c,
+                    node_b.name + "_dtype_cast_",
+                    logger_cls,
+                    node_type_to_io_type_map,
+                )
+                # note: not inserting to env_c because all nodes which use the dtype
+                #   casts are copied from graph_a
+                #
+                # subgraph so far:
+                #
+                #           (dtype_cast_node)+
+                #                  /
+                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c
+
+                # if input logging is enabled, log the input to the subgraph
+                if should_log_inputs:
+                    # TODO: explain this
+                    ref_node_name = ""
+                    if isinstance(dtype_cast_node, Node):
+                        dtype_cast_node = _insert_logger_after_node(
+                            dtype_cast_node,
+                            gm_b,
+                            logger_cls,
+                            "_ns_logger_a_inp_",
+                            ref_node_name,
+                            name_a,
+                            ref_name,
+                            ref_node_type_a,
+                            NSSingleResultValuesType.NODE_INPUT.value,
+                            index_within_arg=0,
+                            index_of_arg=0,
+                            fqn=fqn_base_a,
+                        )
+                        input_logger: Union[Node, list[Node]] = dtype_cast_node
+                    else:
+                        assert isinstance(dtype_cast_node, list)
+                        new_loggers = []
+                        for dtype_cast_idx, dtype_cast_node_inner in enumerate(
+                            dtype_cast_node
+                        ):
+                            dtype_cast_logger = _insert_logger_after_node(
+                                dtype_cast_node_inner,
+                                gm_b,
+                                logger_cls,
+                                "_ns_logger_a_inp_",
+                                ref_node_name,
+                                name_a,
+                                ref_name,
+                                ref_node_type_a,
+                                NSSingleResultValuesType.NODE_INPUT.value,
+                                index_within_arg=dtype_cast_idx,
+                                index_of_arg=0,
+                                fqn=fqn_base_a,
+                            )
+                            new_loggers.append(dtype_cast_logger)
+                        dtype_cast_node = new_loggers
+                        input_logger = dtype_cast_node
+                    # subgraph so far:
+                    #
+                    #       (dtype_cast_node)+ -> (logger_a_input)?
+                    #                  /
+                    # prev_node_c -> (logger_c_input)? -> node_start_c
+
+                # hook up the new mod_a copy to be in the graph, receiving the
+                # same inputs as mod_b does, with dtype cast to match a
+                # Some ops, such as LSTMs, have two non-param inputs. If we have
+                # such an op, pass the second param as well. Note: dtype casting
+                # for the second param is not implemented yet, it can be added
+                # later if there is a use case.
+                node_c_second_non_param_arg = None
+                num_non_param_args_node_a = get_number_of_non_param_args(
+                    subgraph_a.start_node, gm_a
+                )
+                if num_non_param_args_node_a == 2:
+                    # node_c_second_non_param_arg = node_c.args[1]
+                    node_c_second_non_param_arg = get_normalized_nth_input(
+                        node_c, gm_b, 1
+                    )
+                node_a_shadows_c = _insert_copy_of_subgraph_a_after_input_node_c(
+                    dtype_cast_node,
+                    node_c_second_non_param_arg,
+                    subgraph_a,
+                    gm_a,
+                    gm_b,
+                    node_c.name + "_shadow_copy_",
+                )
+                env_c[node_a_shadows_c.name] = node_a_shadows_c
+                # subgraph so far:
+                #
+                #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy(args/kwargs not shown)
+                #                  /
+                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c
+
+                if should_log_inputs:
+                    # When we created the input logger, we left the ref_node_name
+                    # as an empty string, because the subgraph copy did not exist
+                    # yet. Now that the subgraph copy exists, we modify this name
+                    # to its true value.
+                    # Note: the alternative to this is to create the input logger
+                    # after creating the subgraph, which is slightly more
+                    # complicated. This is the lesser of two evils.
+                    # input_logger = env_c[dtype_cast_node.name]
+                    # Find the first node in the subgraph
+                    cur_node = node_a_shadows_c
+                    while get_normalized_nth_input(cur_node, gm_b, 0) != input_logger:  # type: ignore[possibly-undefined]
+                        cur_node = get_normalized_nth_input(cur_node, gm_b, 0)  # type: ignore[assignment]
+                    if isinstance(input_logger, Node):
+                        input_logger_mod = getattr(gm_b, input_logger.name)
+                        input_logger_mod.ref_node_name = cur_node.name
+                    else:
+                        assert isinstance(input_logger, list)
+                        for input_logger_inner in input_logger:
+                            input_logger_mod = getattr(gm_b, input_logger_inner.name)
+                            input_logger_mod.ref_node_name = cur_node.name
+
+                # hook up a logger to the mod_a copy
+                env_c[node_a_shadows_c.name] = _insert_logger_after_node(
+                    env_c[node_a_shadows_c.name],
+                    gm_b,
+                    logger_cls,
+                    "_ns_logger_a_",
+                    node_a_shadows_c.name,
+                    name_a,
+                    ref_name,
+                    ref_node_type_a,
+                    NSSingleResultValuesType.NODE_OUTPUT.value,
+                    index_within_arg=0,
+                    index_of_arg=0,
+                    fqn=fqn_base_a,
+                )
+                # subgraph so far:
+                #
+                #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy -> logger_a
+                #                  /
+                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c
+
+            if node_b_is_end_node:
+                # hook up a logger to the mod_b copy
+                env_c[node_b.name] = _insert_logger_after_node(
+                    env_c[node_b.name],
+                    gm_b,
+                    logger_cls,
+                    "_ns_logger_b_",
+                    node_b.name,
+                    name_b,
+                    ref_name,
+                    ref_node_type_b,
+                    NSSingleResultValuesType.NODE_OUTPUT.value,
+                    index_within_arg=0,
+                    index_of_arg=0,
+                    fqn=fqn_base_b,
+                )
+                # subgraph so far:
+                #
+                #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy -> logger_a
+                #                  /
+                # (prev_node_c+) -> (logger_c_input)? -> node_start_c -> ... -> node_end_c -> logger_c
+                #
+                # Note: node_start_c may be the same node as node_end_c, or they
+                # may have nodes in between.
+
+        else:
+            env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
+
+    gm_c = GraphModule(gm_b, graph_c)
+    return gm_c
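+# A rough usage sketch for create_a_shadows_b (variable names are hypothetical,
+# for illustration only): given a traced fp32 model gm_a, an int8 model gm_b,
+# their matched subgraph pairs, and a logger nn.Module class, one might do:
+#
+#   gm_a_shadows_b = create_a_shadows_b(
+#       "fp32_model", gm_a, "int8_model", gm_b,
+#       matched_subgraph_pairs, LoggerCls, should_log_inputs=False,
+#   )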
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/ns/fx/mappings.py b/.venv/lib/python3.12/site-packages/torch/ao/ns/fx/mappings.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8ca955d22fa69608e1e3e99f11f9bbfdf6f2280
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/ns/fx/mappings.py
@@ -0,0 +1,757 @@
+import operator
+from typing import Callable, Optional
+
+import torch
+import torch.ao.nn.intrinsic as nni
+import torch.ao.nn.intrinsic.qat as nniqat
+import torch.ao.nn.intrinsic.quantized as nniq
+import torch.ao.nn.intrinsic.quantized.dynamic as nniqd
+import torch.ao.nn.qat as nnqat
+import torch.ao.nn.qat.dynamic as nnqatd
+import torch.ao.nn.quantized as nnq
+import torch.ao.nn.quantized.dynamic as nnqd
+import torch.ao.quantization.fx._lower_to_native_backend as _lower_to_native_backend
+import torch.ao.quantization.quantization_mappings as quantization_mappings
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.ao.quantization.backend_config import get_native_backend_config
+
+from .ns_types import NSNodeTargetType
+
+
+toq = torch.ops.quantized
+
+
+def get_base_name_to_sets_of_related_ops() -> dict[str, set[NSNodeTargetType]]:
+    # note: this list of sets is modified below by items from backend_config
+    sets_of_related_ops: list[set[NSNodeTargetType]] = [
+        # conv modules
+        {
+            nn.Conv1d,
+        },
+        {
+            nn.Conv2d,
+        },
+        {
+            nn.Conv3d,
+        },
+        # conv functionals
+        {
+            F.conv1d,
+        },
+        {
+            F.conv2d,
+        },
+        {
+            F.conv3d,
+        },
+        # linear modules
+        {
+            nn.Linear,
+        },
+        # linear functionals
+        {
+            F.linear,
+        },
+        # average pool
+        {
+            nn.AvgPool1d,
+            torch.avg_pool1d,
+        },
+        {
+            nn.AvgPool2d,
+            torch._C._nn.avg_pool2d,
+        },
+        {
+            nn.AvgPool3d,
+            torch._C._nn.avg_pool3d,
+        },
+        # adaptive average pool
+        {
+            nn.AdaptiveAvgPool1d,
+            F.adaptive_avg_pool1d,
+        },
+        {
+            nn.AdaptiveAvgPool2d,
+            F.adaptive_avg_pool2d,
+        },
+        {
+            nn.AdaptiveAvgPool3d,
+            F.adaptive_avg_pool3d,
+        },
+        # LSTM
+        {
+            nn.LSTM,
+        },
+        # add
+        {
+            torch.add,
+            operator.add,  # x + y
+        },
+        # cat
+        {
+            torch.cat,
+        },
+        # mul
+        {
+            torch.mul,
+            operator.mul,
+        },
+        # relu
+        {
+            F.relu,
+            nn.ReLU,
+            "relu",
+            "relu_",
+            torch.relu,
+        },
+        # maxpool
+        {
+            nn.MaxPool1d,
+            F.max_pool1d,
+        },
+        {
+            nn.MaxPool2d,
+            F.max_pool2d,
+        },
+        {
+            nn.MaxPool3d,
+            F.max_pool3d,
+        },
+        # sigmoid
+        {
+            torch.sigmoid,
+            "sigmoid",
+            "sigmoid_",
+            nn.Sigmoid,
+            F.sigmoid,
+        },
+        # BatchNorm
+        {
+            nn.BatchNorm2d,
+        },
+        {
+            nn.BatchNorm3d,
+        },
+        # ConvTranspose
+        {
+            nn.ConvTranspose1d,
+        },
+        {
+            nn.ConvTranspose2d,
+        },
+        {
+            nn.ConvTranspose3d,
+        },
+        # functional transposed conv
+        {
+            F.conv_transpose1d,
+        },
+        {
+            F.conv_transpose2d,
+        },
+        {
+            F.conv_transpose3d,
+        },
+        # ELU
+        {
+            nn.ELU,
+        },
+        # Embedding
+        {
+            nn.Embedding,
+        },
+        # EmbeddingBag
+        {
+            nn.EmbeddingBag,
+        },
+        # GroupNorm
+        {
+            nn.GroupNorm,
+        },
+        # Hardswish
+        {
+            nn.Hardswish,
+        },
+        # InstanceNorm
+        {
+            nn.InstanceNorm1d,
+        },
+        {
+            nn.InstanceNorm2d,
+        },
+        {
+            nn.InstanceNorm3d,
+        },
+        # LayerNorm
+        {
+            nn.LayerNorm,
+        },
+        # LeakyReLU
+        {
+            nn.LeakyReLU,
+        },
+        # ReLU6
+        {
+            nn.ReLU6,
+            F.relu6,
+        },
+        # F.elu
+        {
+            F.elu,
+        },
+        # F.hardswish
+        {
+            F.hardswish,
+        },
+        # F.group_norm
+        {
+            F.group_norm,
+        },
+        # F.instance_norm
+        {
+            F.instance_norm,
+        },
+        # F.layer_norm
+        {
+            F.layer_norm,
+        },
+        # F.leaky_relu
+        {
+            F.leaky_relu,
+        },
+        # F.silu
+        {
+            nn.SiLU,
+            F.silu,
+        },
+        # F.mish
+        {
+            nn.Mish,
+            F.mish,
+        },
+        # F.tanh
+        {
+            nn.Tanh,
+            F.tanh,
+            torch.tanh,
+            "tanh_",
+            "tanh",
+        },
+        # F.hardsigmoid
+        {
+            "hardsigmoid_",
+            "hardsigmoid",
+            F.hardsigmoid,
+            nn.Hardsigmoid,
+        },
+        # F.hardtanh
+        {
+            nn.Hardtanh,
+            F.hardtanh,
+            F.hardtanh_,
+        },
+        # floordiv
+        {
+            operator.floordiv,
+        },
+        # unsqueeze
+        {
+            torch.unsqueeze,
+        },
+        # stack
+        {
+            torch.stack,
+        },
+        # squeeze
+        {
+            torch.squeeze,
+        },
+        # sort
+        {
+            torch.sort,
+        },
+        # repeat_interleave
+        {
+            torch.repeat_interleave,
+        },
+        # min
+        {
+            torch.min,
+        },
+        # mean
+        {
+            torch.mean,
+        },
+        # max
+        {
+            torch.max,
+        },
+        # transpose
+        {
+            torch.transpose,
+        },
+        # flatten
+        {
+            torch.flatten,
+        },
+        # clamp
+        {
+            torch.clamp,
+        },
+        # chunk
+        {
+            torch.chunk,
+        },
+        # interpolate
+        {
+            torch.nn.functional.interpolate,
+        },
+        # dropout
+        {
+            nn.Dropout,
+        },
+        # F.dropout
+        {
+            F.dropout,
+        },
+        # matmul
+        {
+            torch.matmul,
+        },
+        # Softmax
+        {
+            nn.Softmax,
+        },
+        # PReLU
+        {
+            nn.PReLU,
+            nnq.PReLU,
+        },
+        # F.prelu
+        {
+            F.prelu,
+            toq.prelu,
+        },
+        # pixel shuffle
+        {
+            nn.PixelShuffle,
+        },
+        {
+            F.pixel_shuffle,
+        },
+        # pixel unshuffle
+        {
+            nn.PixelUnshuffle,
+        },
+        {
+            F.pixel_unshuffle,
+        },
+        # narrow
+        {
+            torch.narrow,
+        },
+    ]
+
+    # for each floating point op, add versions of the op added by
+    # backend_config
+    backend_config = get_native_backend_config()
+
+    new_connections: list[tuple[Callable, Callable]] = [
+        # technical debt edge case
+        (nn.Linear, nn.modules.linear.NonDynamicallyQuantizableLinear),
+    ]
+
+    for pattern, config in backend_config._pattern_complex_format_to_config.items():
+        # pattern format: (c, (b, a))
+        first_element = pattern
+        # look from the end, because pattern is in reverse order
+        while isinstance(first_element, (list, tuple)):
+            first_element = first_element[-1]
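+        # e.g. a fused pattern stored as (nn.ReLU, nn.Conv2d) (reverse order,
+        # meaning Conv2d followed by ReLU) resolves first_element to nn.Conv2d,
+        # the op that starts the fused sequence.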
+
+        if config.fused_module is not None:
+            # case 1: pattern fuses a pattern of ops into an op
+            # example: nn.Conv1d, nn.ReLU fused into nni.ConvReLU1d
+            new_connections.append((first_element, config.fused_module))
+
+        if config.qat_module is not None:
+            # case 2: pattern swaps a module into a QAT module
+            # example: nni.ConvReLU1d swapped into nniqat.ConvReLU1d
+            new_connections.append((first_element, config.qat_module))
+
+        if config.reference_quantized_module is not None:
+            # case 3: reference version of floating point module, such as
+            # nn.Conv2d and nnqr.Conv2d
+            new_connections.append((first_element, config.reference_quantized_module))
+
+    #
+    # Add reference module swaps from default lowering path
+    #
+
+    for source_to_target in (
+        _lower_to_native_backend.STATIC_LOWER_MODULE_MAP,
+        _lower_to_native_backend.DYNAMIC_LOWER_MODULE_MAP,
+        _lower_to_native_backend.WEIGHT_ONLY_LOWER_MODULE_MAP,
+        _lower_to_native_backend.SPECIAL_PATTERN_LOWER_MODULE_MAP,
+    ):
+        for source, target in source_to_target.items():  # type: ignore[attr-defined]
+            new_connections.append((source, target))
+
+    for source_to_double_target in (
+        _lower_to_native_backend.STATIC_LOWER_FUSED_MODULE_MAP,
+        _lower_to_native_backend.STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP,
+        _lower_to_native_backend.DYNAMIC_LOWER_FUSED_MODULE_MAP,
+    ):
+        for source, (target1, target2) in source_to_double_target.items():  # type: ignore[attr-defined]
+            new_connections.append((source, target1))
+            new_connections.append((source, target2))
+
+    #
+    # Add function swaps from default lowering path
+    #
+
+    for source, (  # type:ignore[assignment]
+        target1,
+        target2,
+    ) in _lower_to_native_backend.STATIC_LOWER_FUNCTIONAL_MAP.items():
+        new_connections.append((source, target1))
+        new_connections.append((source, target2))
+
+    for source_to_target in (
+        _lower_to_native_backend.QBIN_OP_MAPPING,
+        _lower_to_native_backend.QBIN_RELU_OP_MAPPING,
+        quantization_mappings.DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS,
+    ):
+        for source, target in source_to_target.items():  # type:ignore[assignment]
+            new_connections.append((source, target))
+
+    #
+    # Add other swaps, ideally in the future this could be removed
+    # after the lowering code stops using these.
+    #
+    for source_to_target in (
+        quantization_mappings.DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS,
+    ):
+        for source, target in source_to_target.items():  # type:ignore[assignment]
+            new_connections.append((source, target))
+
+    # add the new connections from backend_config
+    for item1, item2 in new_connections:
+        for set_of_related_ops in sets_of_related_ops:
+            if item1 in set_of_related_ops or item2 in set_of_related_ops:
+                set_of_related_ops.add(item1)
+                set_of_related_ops.add(item2)
+                break
+
+    base_name_to_sets_of_related_ops: dict[str, set[NSNodeTargetType]] = {}
+
+    for counter, set_of_related_ops in enumerate(sets_of_related_ops):
+        base_name = str(counter)
+        base_name_to_sets_of_related_ops[base_name] = set_of_related_ops
+
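+    # The keys of the returned dict are opaque string indices ("0", "1", ...);
+    # only the grouping of ops into sets is meaningful to downstream matching.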
+    return base_name_to_sets_of_related_ops
+
+
+def get_base_name_for_op(
+    base_name_to_sets_of_related_ops: dict[str, set[NSNodeTargetType]],
+    op: NSNodeTargetType,
+) -> Optional[str]:
+    for base_name, set_of_related_ops in base_name_to_sets_of_related_ops.items():
+        if op in set_of_related_ops:
+            return base_name
+    return None
+
+
+def add_op_to_sets_of_related_ops(
+    base_name_to_sets_of_related_ops: dict[str, set[NSNodeTargetType]],
+    op: NSNodeTargetType,
+    related_op: Optional[NSNodeTargetType],
+) -> None:
+    if related_op is not None:
+        for set_of_related_ops in base_name_to_sets_of_related_ops.values():
+            if related_op in set_of_related_ops:
+                set_of_related_ops.add(op)
+                return
+        # if we got here, related_op was not found
+        raise AssertionError(f"{related_op} was not found")
+    else:
+        counter = 0
+        while str(counter) in base_name_to_sets_of_related_ops:
+            counter += 1
+        base_name_to_sets_of_related_ops[str(counter)] = {op}
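+# Usage sketch (MyLinear is a hypothetical user-defined module, for illustration
+# only): to have a custom module matched against nn.Linear, one might register
+# it with:
+#   add_op_to_sets_of_related_ops(base_name_to_sets_of_related_ops, MyLinear, nn.Linear)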
+
+
+# TODO(future PR): clean this up
+def get_node_type_to_io_type_map() -> dict[str, set[NSNodeTargetType]]:
+    FUNS_IO_TYPE_FP32: set[NSNodeTargetType] = {
+        F.linear,
+        F.conv1d,
+        F.conv2d,
+        F.conv3d,
+        torch.cat,
+        F.elu,
+        F.hardswish,
+        F.instance_norm,
+        F.layer_norm,
+        F.leaky_relu,
+        F.dropout,
+        F.silu,
+        F.mish,
+        operator.add,
+        torch.add,
+        operator.mul,
+        torch.mul,
+        torch.sum,
+        F.prelu,
+    }
+
+    FUNS_IO_TYPE_FP16: set[NSNodeTargetType] = set()
+
+    FUNS_IO_TYPE_INT8: set[NSNodeTargetType] = {
+        toq.linear,
+        toq.linear_relu,
+        toq.conv1d,
+        toq.conv1d_relu,
+        toq.conv2d,
+        toq.conv2d_relu,
+        toq.conv3d,
+        toq.conv3d_relu,
+        toq.cat,
+        toq.elu,
+        toq.hardswish,
+        toq.instance_norm,
+        toq.layer_norm,
+        toq.leaky_relu,
+        toq.dropout,
+        toq.prelu,
+        # TODO(future PR): implement shadowing for binary ops and
+        # uncomment below
+        # toq.add,
+        # toq.mul,
+    }
+
+    FUNS_IO_TYPE_FP32_OR_INT8: set[NSNodeTargetType] = {
+        F.relu,
+        F.tanh,
+        torch.tanh,
+        F.sigmoid,
+        torch.sigmoid,
+        F.hardsigmoid,
+        operator.floordiv,
+        torch.adaptive_avg_pool1d,
+        F.adaptive_avg_pool2d,
+        F.adaptive_avg_pool3d,
+        F.dropout,
+        F.hardtanh,
+        F.hardtanh_,
+        F.interpolate,
+        F.max_pool1d,
+        F.max_pool2d,
+        F.max_pool3d,
+        F.relu6,
+        F.pixel_shuffle,
+        F.pixel_unshuffle,
+        torch.avg_pool1d,
+        torch._C._nn.avg_pool2d,
+        torch._C._nn.avg_pool3d,
+        torch.cat,
+        torch.chunk,
+        torch.clamp,
+        torch.flatten,
+        torch.transpose,
+        torch.max,
+        torch.mean,
+        torch.min,
+        torch.narrow,
+        torch.repeat_interleave,
+        torch.sort,
+        torch.squeeze,
+        torch.stack,
+        torch.unsqueeze,
+        operator.add,
+    }
+
+    MODS_IO_TYPE_FP32: set[NSNodeTargetType] = {
+        nn.Linear,
+        nnqat.Linear,
+        nnqatd.Linear,
+        nnqd.Linear,
+        torch.nn.modules.linear.NonDynamicallyQuantizableLinear,
+        nn.Conv1d,
+        nn.Conv2d,
+        nn.Conv3d,
+        nnqat.Conv1d,
+        nnqat.Conv2d,
+        nnqat.Conv3d,
+        nnqat.Embedding,
+        nnqat.EmbeddingBag,
+        nn.LSTM,
+        # note: nnqd.Linear is an instance of nnq.Linear, so this
+        # check has to happen before the int8 module check
+        nnqd.LSTM,
+        nn.BatchNorm2d,
+        nn.BatchNorm3d,
+        nn.Dropout,
+        nn.ConvTranspose1d,
+        nn.ConvTranspose2d,
+        nn.ConvTranspose3d,
+        nn.ELU,
+        nn.GroupNorm,
+        nn.InstanceNorm1d,
+        nn.InstanceNorm2d,
+        nn.InstanceNorm3d,
+        nn.LayerNorm,
+        nn.Hardswish,
+        nn.LeakyReLU,
+        nn.ReLU6,
+        nn.SiLU,
+        nn.Mish,
+        nn.Softmax,
+        nn.PReLU,
+        nni.BNReLU2d,
+        nni.BNReLU3d,
+        nni.ConvReLU1d,
+        nni.ConvReLU2d,
+        nni.ConvReLU3d,
+        nni.LinearReLU,
+        nni.LinearBn1d,
+        nni.ConvBn1d,
+        nni.ConvBn2d,
+        nni.ConvBn3d,
+        nniqat.ConvBn1d,
+        nniqat.ConvBn2d,
+        nniqat.ConvBn3d,
+        nniqat.ConvBnReLU1d,
+        nniqat.ConvBnReLU2d,
+        nniqat.ConvBnReLU3d,
+        nniqat.ConvReLU1d,
+        nniqat.ConvReLU2d,
+        nniqat.ConvReLU3d,
+        nniqat.LinearReLU,
+        nniqat.LinearBn1d,
+        nniqd.LinearReLU,
+        nni.LinearLeakyReLU,
+        nni.LinearTanh,
+        nni.ConvAdd2d,
+        nni.ConvAddReLU2d,
+    }
+
+    MODS_IO_TYPE_INT8: set[NSNodeTargetType] = {
+        nnq.Linear,
+        nnq.Conv1d,
+        nnq.Conv2d,
+        nnq.Conv3d,
+        nnq.BatchNorm2d,
+        nnq.BatchNorm3d,
+        nnq.Dropout,
+        nnq.ConvTranspose1d,
+        nnq.ConvTranspose2d,
+        nnq.ELU,
+        nnq.InstanceNorm1d,
+        nnq.InstanceNorm2d,
+        nnq.InstanceNorm3d,
+        nnq.LayerNorm,
+        nnq.Hardswish,
+        nnq.LeakyReLU,
+        nnq.Embedding,
+        nnq.EmbeddingBag,
+        nnq.Dropout,
+        nnq.Softmax,
+        nnq.PReLU,
+        nniq.BNReLU2d,
+        nniq.BNReLU3d,
+        nniq.ConvReLU1d,
+        nniq.ConvReLU2d,
+        nniq.ConvReLU3d,
+        nniq.LinearReLU,
+        nniq.LinearLeakyReLU,
+        nniq.LinearTanh,
+        nniq.ConvAdd2d,
+        nniq.ConvAddReLU2d,
+    }
+
+    MODS_IO_TYPE_FP32_OR_INT8: set[NSNodeTargetType] = {
+        nn.ReLU,
+        nn.Tanh,
+        nn.Sigmoid,
+        nn.Hardsigmoid,
+        nn.AdaptiveAvgPool1d,
+        nn.AdaptiveAvgPool2d,
+        nn.AdaptiveAvgPool3d,
+        nn.AvgPool1d,
+        nn.AvgPool2d,
+        nn.AvgPool3d,
+        nn.Dropout,
+        nn.Hardtanh,
+        nn.Identity,
+        nn.MaxPool1d,
+        nn.MaxPool2d,
+        nn.MaxPool3d,
+        nn.PixelShuffle,
+        nn.PixelUnshuffle,
+        nn.ReLU6,
+    }
+
+    METHS_IO_TYPE_FP32_OR_INT8: set[NSNodeTargetType] = {
+        "sigmoid_",
+        "sigmoid",
+        "tanh_",
+        "tanh",
+        "hardsigmoid_",
+        "hardsigmoid",
+        "relu_",
+        "relu",
+    }
+
+    return {
+        "funs_io_type_fp32": FUNS_IO_TYPE_FP32,
+        "funs_io_type_fp16": FUNS_IO_TYPE_FP16,
+        "funs_io_type_int8": FUNS_IO_TYPE_INT8,
+        "funs_io_type_fp32_or_int8": FUNS_IO_TYPE_FP32_OR_INT8,
+        "mods_io_type_fp32": MODS_IO_TYPE_FP32,
+        "mods_io_type_int8": MODS_IO_TYPE_INT8,
+        "mods_io_type_fp32_or_int8": MODS_IO_TYPE_FP32_OR_INT8,
+        "meths_io_type_fp32_or_int8": METHS_IO_TYPE_FP32_OR_INT8,
+    }
+
+
+def get_unmatchable_types_map() -> dict[str, set[NSNodeTargetType]]:
+    FUNS_UNMATCHABLE: set[NSNodeTargetType] = {
+        torch.quantize_per_tensor,
+        operator.getitem,
+    }
+
+    MODS_UNMATCHABLE: set[NSNodeTargetType] = {
+        nn.Identity,
+    }
+
+    METHS_UNMATCHABLE: set[NSNodeTargetType] = {
+        "to",
+        "dequantize",
+        "reshape",
+        "view",
+        "unsqueeze_",
+        "unsqueeze",
+        "transpose",
+        "squeeze_",
+        "squeeze",
+        "size",
+        "shape",
+        "resize_",
+        "repeat_interleave",
+        "repeat",
+        "permute",
+        "numel",
+        "mean",
+        "detach_",
+        "detach",
+        "contiguous",
+        "clamp",
+        "chunk",
+    }
+
+    return {
+        "funs_unmatchable": FUNS_UNMATCHABLE,
+        "mods_unmatchable": MODS_UNMATCHABLE,
+        "meths_unmatchable": METHS_UNMATCHABLE,
+    }
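+
+
+def _is_unmatchable_fun_example_sketch(node_target: NSNodeTargetType) -> bool:
+    # Minimal usage sketch (hypothetical helper, not part of the public API):
+    # check whether a function target is treated as unmatchable by the
+    # numeric suite, using the map built above.
+    maps = get_unmatchable_types_map()
+    return node_target in maps["funs_unmatchable"]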
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/ns/fx/n_shadows_utils.py b/.venv/lib/python3.12/site-packages/torch/ao/ns/fx/n_shadows_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d8b569036ff24bef4f45886c6031eba34d29b87
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/ns/fx/n_shadows_utils.py
@@ -0,0 +1,1378 @@
+# mypy: allow-untyped-defs
+import collections
+import copy
+import operator
+from typing import Any, Callable, Optional
+
+import torch
+import torch.fx
+from torch.ao.ns.fx.graph_passes import _maybe_get_fqn
+from torch.ao.ns.fx.ns_types import NSResultsType, NSSingleResultValuesType
+from torch.ao.ns.fx.utils import (  # TODO(future PR): make this work correctly for methods
+    get_normalized_nth_input,
+    get_target_type_str,
+)
+from torch.ao.quantization import QConfigMapping
+from torch.ao.quantization.fx.match_utils import _MatchResult
+from torch.ao.quantization.qconfig import QConfigAny
+from torch.ao.quantization.utils import getattr_from_fqn
+from torch.fx import Graph, GraphModule, Node
+from torch.utils._pytree import tree_map
+
+
+SHADOW_NODE_NAME_PREFIX = "shadow"
+SHADOW_WRAPPER_NODE_NAME_PREFIX = "shadow_wrapper"
+
+# TODO(future PR): reuse existing mapping instead of creating a new one
+BINARY_FUNCTIONS = {
+    torch.add,
+    torch.Tensor.add,
+    operator.add,
+    torch.mul,
+    torch.Tensor.mul,
+    operator.mul,
+}
+
+
+def _get_attr_name(subgraph_idx, subgraph_candidate_idx):
+    return f"{SHADOW_NODE_NAME_PREFIX}_{subgraph_idx}_{subgraph_candidate_idx}"
+
+
+def _get_attr_wrapper_name(subgraph_idx, subgraph_candidate_idx):
+    return f"{SHADOW_WRAPPER_NODE_NAME_PREFIX}_{subgraph_idx}_{subgraph_candidate_idx}"
+
+
+class OutputProp:
+    """
+    Output propagation (modeled after shape propagation).
+
+    Given a GraphModule and an example input, saves the output flowing
+    through each node on `node.traced_result`.
+
+    Code based on the example from
+    https://pytorch.org/docs/stable/fx.html#the-interpreter-pattern
+    """
+
+    def __init__(self, mod):
+        self.mod = mod
+        self.graph = mod.graph
+        self.modules = dict(self.mod.named_modules())
+
+    def propagate(self, *args):
+        args_iter = iter(args)
+        env: dict[str, Node] = {}
+
+        def load_arg(a):
+            return torch.fx.graph.map_arg(a, lambda n: env[n.name])
+
+        def fetch_attr(target: str):
+            target_atoms = target.split(".")
+            attr_itr = self.mod
+            for i, atom in enumerate(target_atoms):
+                if not hasattr(attr_itr, atom):
+                    raise RuntimeError(
+                        f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}"
+                    )
+                attr_itr = getattr(attr_itr, atom)
+            return attr_itr
+
+        for node in self.graph.nodes:
+            if node.op == "placeholder":
+                result = next(args_iter)
+            elif node.op == "get_attr":
+                result = fetch_attr(node.target)
+            elif node.op == "call_function":
+                result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
+            elif node.op == "call_method":
+                self_obj, *args = load_arg(node.args)
+                kwargs = load_arg(node.kwargs)
+                result = getattr(self_obj, node.target)(*args, **kwargs)
+            elif node.op == "call_module":
+                result = self.modules[node.target](
+                    *load_arg(node.args), **load_arg(node.kwargs)
+                )
+
+            if isinstance(result, torch.Tensor):  # type: ignore[possibly-undefined]
+                node.traced_result = result
+
+            env[node.name] = result
+
+        return None
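+
+
+def _output_prop_example_sketch():
+    # Minimal usage sketch (hypothetical helper, not used elsewhere in this
+    # module): run OutputProp on a toy traced module so that every node with
+    # a tensor output gets a `traced_result` attribute.
+    class _TinyModel(torch.nn.Module):
+        def forward(self, x):
+            return torch.relu(x + 1)
+
+    gm = torch.fx.symbolic_trace(_TinyModel())
+    OutputProp(gm).propagate(torch.randn(2, 3))
+    return {
+        n.name: n.traced_result
+        for n in gm.graph.nodes
+        if hasattr(n, "traced_result")
+    }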
+
+
+def _get_dedup_subgraphs(matches: dict[str, _MatchResult]) -> dict[str, list[Node]]:
+    # the original matches variable is unique by node, make it unique by subgraph
+    # instead
+    seen_nodes = set()
+    subgraphs_dedup = {}
+
+    # dict items are reversible on Python 3.8 and newer, so we can call
+    # reversed() on matches.items() directly
+    matches_items_reversed: list[tuple[str, _MatchResult]] = list(
+        reversed(matches.items())
+    )
+
+    # Note: the order is important.  `matches` currently provides the matches
+    # in reverse order.  We would like to process the matches in non-reverse
+    # order, so that we can create an intuitive naming scheme, such as
+    # naming the first op's submodules `shadow_0_0` through `shadow_0_(n-1)`
+    for name, cur_match in matches_items_reversed:  # type: ignore[call-overload]
+        was_seen = False
+        for node_or_tuple in cur_match[1]:
+            # cur_match[1] has an unusual type: it is annotated as `List[Node]`,
+            # but in practice it is not. Furthermore, the same contents can
+            # appear in the match results of multiple nodes of the same pattern.
+            #
+            # For example, for conv -> bn -> relu, we see
+            # match_results = {
+            #   'conv': (relu, [(bn, conv), relu], ...),
+            #   'bn': (relu, [(bn, conv), relu], ...),
+            #   'relu': (relu, [(bn, conv), relu], ...),
+            # }
+            #
+            # Ideally we should clean up the `find_matches` function to make
+            # this more intuitive. For the purposes of this prototype, we hack
+            # around it.
+
+            if isinstance(node_or_tuple, Node):
+                if node_or_tuple in seen_nodes:
+                    was_seen = True
+                seen_nodes.add(node_or_tuple)
+
+            else:
+                assert isinstance(node_or_tuple, tuple)
+                for node in node_or_tuple:
+                    assert isinstance(node, Node)
+                    if node in seen_nodes:
+                        was_seen = True
+                    seen_nodes.add(node)
+
+        if was_seen:
+            continue
+
+        # Start with the unusual type, convert it to [op_0, ..., op_n]
+        list_of_nodes = []
+
+        if len(cur_match[1]) == 1:
+            list_of_nodes = cur_match[1]
+        else:
+            assert len(cur_match[1]) == 2
+            # either (a, b), or ((a, b), c) or (c, (a, b))
+            # cannot make any assumptions on order, not clear what the
+            # _find_matches function is doing to populate this
+            # TODO(future PR): make this code less confusing,  see discussion
+            # in https://github.com/pytorch/pytorch/pull/80521/files#r975918836
+
+            def _order_nodes(node_a, node_b, node_c) -> list[Node]:
+                nodes = [node_a, node_b, node_c]
+                first_node = None
+                mid_node = None
+                last_node = None
+                for n in nodes:
+                    prev_n = n.args[0]
+                    next_n = next(iter(n.users))
+                    if prev_n not in nodes:
+                        first_node = n
+                    elif next_n not in nodes:
+                        last_node = n
+                    else:
+                        mid_node = n
+                assert (
+                    first_node is not None
+                    and mid_node is not None
+                    and last_node is not None
+                )
+                assert mid_node.args[0] is first_node
+                assert last_node.args[0] is mid_node
+                return [last_node, mid_node, first_node]
+
+            if isinstance(cur_match[1][0], Node) and isinstance(cur_match[1][1], Node):
+                # (a, b)
+                list_of_nodes = cur_match[1]
+            elif isinstance(cur_match[1][0], tuple):
+                # ((a, b), c)
+                node_a, node_b = cur_match[1][0]
+                node_c = cur_match[1][1]
+                list_of_nodes = _order_nodes(node_a, node_b, node_c)
+            elif isinstance(cur_match[1][1], tuple):
+                # (a, (b, c))
+                node_a, node_b = cur_match[1][1]
+                node_c = cur_match[1][0]
+                list_of_nodes = _order_nodes(node_a, node_b, node_c)
+
+        # [node_n, ..., node_0], note that the order is reversed
+        # to make it chronological for simple subgraphs
+        list_of_nodes.reverse()
+        subgraphs_dedup[name] = list_of_nodes
+
+    return subgraphs_dedup
+
+
+def _get_logger_for_subgraph(
+    model: GraphModule,
+    first_node: Node,
+    last_node: Node,
+    subgraph_idx: int,
+    subgraph_candidate_idx: int,
+    qconfig_str: str,
+    logger_cls: Callable,
+    fqn: Optional[str],
+) -> torch.nn.Module:
+    """
+    Given a model and a linear subgraph starting from `first_node` and
+    ending with `last_node`, creates a logger for the end of this
+    subgraph.
+    """
+    if fqn is None:
+        fqn = ""
+    logger_mod_orig = logger_cls(
+        first_node.name,  # ref_node_name
+        last_node.name,  # prev_node_name
+        f"subgraph_{subgraph_idx}_{subgraph_candidate_idx}",  # model_name
+        "model",  # ref_name
+        get_target_type_str(last_node, model),  # prev_node_target_type
+        get_target_type_str(first_node, model),  # ref_node_target_type
+        NSSingleResultValuesType.NODE_OUTPUT.value,  # results_type
+        0,  # index_within_arg
+        0,  # index_of_arg
+        fqn,  # fqn
+        qconfig_str,
+    )
+    # Usually we expect the user to add loggers, then calibrate, then convert,
+    # and then populate loggers.  This is why the loggers start disabled.
+    # TODO(future PR): reconsider the design to make this more intuitive.
+    logger_mod_orig.enabled = False
+    return logger_mod_orig
+
+
+def create_submodule_from_subgraph(
+    model: torch.nn.Module,
+    first_node: Node,
+    last_node: Node,
+) -> GraphModule:
+    """
+    Input: a model, and a linear subgraph within the model from first_node to
+      last_node.
+
+    Output: a new submodule containing a copy of the subgraph, with the inputs
+      to the first node becoming the inputs to the submodule, and all other
+      nodes in the subgraph being copied.
+
+    Example inputs:
+
+    `model`: a module with graph
+
+      x0 -> op1 -> x1 -> op2 -> x2
+             |
+            arg1
+
+    `first_node`: op1
+    `last_node`: op2
+
+    Example output: a new module with graph
+
+      input1 -> op1_copy -> x1 -> op2_copy -> output1
+                   |
+                  arg1
+    """
+
+    #
+    # create a blank GraphModule with an empty graph
+    #
+
+    class M(torch.nn.Module):
+        def forward(self, x):
+            pass
+
+    m = M()
+    gm = torch.fx.symbolic_trace(m)
+    g = gm.graph
+    for node in reversed(gm.graph.nodes):
+        g.erase_node(node)
+
+    #
+    # modify the graph to have a copy of our subgraph
+    #
+
+    cur_node_orig = first_node
+
+    cur_name_idx = 0
+
+    iteration_limit = 100
+    cur_iteration = 0
+
+    while True:
+        if cur_node_orig is first_node:
+            # we are at the first node, we need to set up graph inputs
+            # TODO(future): some graphs could have placeholders which are unrelated
+            # to the first node, need to handle this
+            cur_args_copy = []
+            cur_kwargs_copy = {}
+            seen_names: set[str] = set()
+            old_name_to_new_node: dict[str, Node] = {}
+
+            def _add_placeholder(
+                g: Graph, node: Node, seen_names, old_name_to_new_node
+            ):
+                # note: for graphs starting with patterns such as `y = x + x`, we
+                # need to ensure we do not add multiple placeholders with the
+                # same name
+                counter = 0
+                while node.name + "_" + str(counter) in seen_names:
+                    counter += 1
+                cur_name = node.name + "_" + str(counter)
+                seen_names.add(cur_name)
+                placeholder = g.placeholder(cur_name)
+                old_name_to_new_node[node.name] = placeholder
+                return placeholder
+
+            for arg in cur_node_orig.args:
+                if isinstance(arg, Node):
+                    p = _add_placeholder(g, arg, seen_names, old_name_to_new_node)
+                    cur_args_copy.append(p)
+                elif isinstance(arg, (list, tuple)):
+                    new_arg = []
+                    for inner_arg in arg:
+                        if isinstance(inner_arg, Node):
+                            new_arg.append(
+                                _add_placeholder(
+                                    g, inner_arg, seen_names, old_name_to_new_node
+                                )
+                            )
+                        else:
+                            new_arg.append(inner_arg)
+                    cur_args_copy.append(new_arg)
+                else:
+                    cur_args_copy.append(arg)
+
+            # TODO(future PR): handle non-normalized kwargs
+            for kwarg_name, kwarg in cur_node_orig.kwargs.items():
+                if isinstance(kwarg, Node):
+                    cur_kwargs_copy[kwarg_name] = _add_placeholder(
+                        g, kwarg, seen_names, old_name_to_new_node
+                    )
+                elif isinstance(kwarg, (list, tuple)):
+                    new_kwarg = []
+                    for inner_kwarg in kwarg:
+                        p = _add_placeholder(
+                            g,
+                            inner_kwarg,  # type: ignore[arg-type]
+                            seen_names,
+                            old_name_to_new_node,
+                        )
+                        new_kwarg.append(p)
+                    cur_kwargs_copy[kwarg_name] = new_kwarg
+                else:
+                    cur_kwargs_copy[kwarg_name] = kwarg
+
+            cur_args_copy = tuple(cur_args_copy)  # type: ignore[assignment]
+        else:
+            # we are not at first node, first arg is from the previous node,
+            # and all other args are copied
+
+            # the current implementation is simplistic and cannot handle
+            # ops with two or more arguments which need to be passed from
+            # the previous op, so we assert them out
+            assert cur_node_orig.target not in BINARY_FUNCTIONS
+
+            # at this point in the code, cur_node_copy is pointing to the copy
+            # of the previous node
+            # TODO(future PR): this is not handling complicated graphs correctly, need to
+            # look at actual relationships instead of assuming sequential graph
+            # TODO(future PR): this is ignoring kwargs, will need to support kwargs
+            # for any fusion pattern which has them for a node that is not the
+            # first node.
+            cur_args_copy = [cur_node_copy]  # type: ignore[has-type, possibly-undefined]  # noqa: F821
+
+            if len(cur_node_orig.args) > 1:
+                for arg in cur_node_orig.args[1:]:
+                    if isinstance(arg, torch.nn.Parameter):
+                        new_arg = arg.detach().clone()  # type: ignore[assignment]
+                        mod_name = f"mod_{cur_name_idx}"
+                        cur_name_idx += 1
+                        setattr(gm, mod_name, new_arg)
+                        new_arg_placeholder = gm.placeholder(mod_name)  # type: ignore[operator]
+                        cur_args_copy.append(new_arg_placeholder)
+                    elif isinstance(arg, (float, int, torch.dtype)):
+                        cur_args_copy.append(arg)
+                    else:
+                        raise AssertionError(f"arg of type {type(arg)} not handled yet")
+            cur_args_copy = tuple(cur_args_copy)  # type: ignore[assignment]
+
+        # copy the node
+        if cur_node_orig.op == "call_module":
+            orig_mod = getattr_from_fqn(model, cur_node_orig.target)  # type: ignore[arg-type]
+            orig_mod_copy = copy.deepcopy(orig_mod)
+            mod_name = f"mod_{cur_name_idx}"
+            setattr(gm, mod_name, orig_mod_copy)
+            cur_name_idx += 1
+            cur_node_copy = g.call_module(mod_name, cur_args_copy, cur_kwargs_copy)  # type: ignore[possibly-undefined,arg-type]
+
+        elif cur_node_orig.op == "call_function":
+            cur_node_copy = g.call_function(
+                cur_node_orig.target,  # type: ignore[arg-type]
+                cur_args_copy,  # type: ignore[arg-type]
+                cur_kwargs_copy,  # type: ignore[possibly-undefined]
+            )
+
+        elif cur_node_orig.op == "call_method":
+            cur_node_copy = g.call_method(
+                cur_node_orig.target,  # type: ignore[arg-type]
+                cur_args_copy,  # type: ignore[arg-type]
+                cur_kwargs_copy,  # type: ignore[possibly-undefined]
+            )
+
+        else:
+            raise AssertionError(f"{cur_node_orig.op} not supported yet")
+
+        if cur_node_orig is last_node:
+            break
+
+        # go to next node
+        assert len(cur_node_orig.users.keys()) == 1, (
+            f"{cur_node_orig} has more than 1 users, not supported yet"
+        )
+        cur_node_orig = next(iter(cur_node_orig.users.keys()))
+        cur_iteration += 1
+        if cur_iteration > iteration_limit:
+            raise AssertionError("iteration limit exceeded")
+
+    # set up outputs
+    g.output(cur_node_copy)
+
+    gm.recompile()
+    return gm
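+
+
+def _create_submodule_example_sketch() -> GraphModule:
+    # Minimal usage sketch (hypothetical helper, not used elsewhere in this
+    # module): extract the relu -> sigmoid chain of a toy traced module into
+    # its own GraphModule via create_submodule_from_subgraph.
+    class _TinyModel(torch.nn.Module):
+        def forward(self, x):
+            return torch.sigmoid(torch.relu(x))
+
+    gm = torch.fx.symbolic_trace(_TinyModel())
+    call_nodes = [n for n in gm.graph.nodes if n.op == "call_function"]
+    return create_submodule_from_subgraph(gm, call_nodes[0], call_nodes[-1])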
+
+
+def create_one_transformed_and_logged_copy_of_subgraph(
+    mt: GraphModule,
+    subgraph_idx: int,
+    subgraph_candidate_idx: int,
+    first_node: Node,
+    last_node: Node,
+    fqn: Optional[str],
+    list_of_node_name_to_qconfig: list[dict[str, QConfigAny]],
+    example_inputs: Any,
+    last_added_shadow_node_list: list[Optional[Node]],
+    custom_prepare_fn: Optional[Callable] = None,
+    custom_prepare_kwargs: Optional[dict[str, Any]] = None,
+) -> None:
+    """
+    Given a subgraph in `mt` and a subgraph candidate idx, inserts the
+    subgraph candidate copy and instruments it with loggers.
+
+    If subgraph_candidate_idx is 0, this is the baseline fp32 subgraph and we just
+    add a logger to the end.
+
+    If subgraph_candidate_idx is not 0, we create a copy of the subgraph and
+    prepare it with `prepare_fx`.
+    """
+
+    # TODO(future PR): move logger classes to utils to remove circular dependency
+    from torch.ao.ns._numeric_suite_fx import OutputComparisonLogger, OutputLogger
+
+    if subgraph_candidate_idx == 0:
+        # idx = 0 is the floating point (original) version of the subgraph
+        # We keep the subgraph as is, and add a logger at the end
+
+        qconfig_str = ""
+        logger_mod_orig = _get_logger_for_subgraph(
+            mt,
+            first_node,
+            last_node,
+            subgraph_idx,
+            subgraph_candidate_idx,
+            qconfig_str,
+            OutputLogger,
+            fqn,
+        )
+
+        attr_name = _get_attr_name(subgraph_idx, subgraph_candidate_idx)
+        assert not hasattr(mt, attr_name)
+        setattr(mt, attr_name, logger_mod_orig)
+        with mt.graph.inserting_after(last_node):
+            new_node = mt.graph.call_module(attr_name, args=(last_node,), kwargs={})
+            last_added_shadow_node_list[0] = new_node
+
+    else:
+        # idx > 0 means we have a candidate qconfig to try, so we need
+        # to make a copy of the subgraph, feed it with the right inputs,
+        # and add a logger at the end
+
+        # get the qconfig
+        # subtract one because the first candidate is the floating point
+        # version of the subgraph
+        node_name_to_qconfig = list_of_node_name_to_qconfig[subgraph_candidate_idx - 1]
+        qconfig = node_name_to_qconfig[first_node.name]
+
+        # if no quantization is requested, skip
+        # TODO(future PR): deduplicate equivalent qconfigs that come from
+        #   different qconfig mapping objects
+        if qconfig is None:
+            return
+
+        qconfig_mapping = QConfigMapping().set_global(qconfig)
+
+        # create a copy of the submodule, wrapped in a separate module
+        orig_mod_copy_wrapped = create_submodule_from_subgraph(
+            mt, first_node, last_node
+        )
+
+        # add a call to prepare_fx on the wrapper module
+        if custom_prepare_fn is None:
+            orig_mod_copy_wrapped = torch.ao.quantization.quantize_fx.prepare_fx(
+                orig_mod_copy_wrapped, qconfig_mapping, example_inputs=example_inputs
+            )
+        else:
+            if custom_prepare_kwargs is None:
+                custom_prepare_kwargs = {}
+            for kwarg_name in [
+                "example_inputs",
+                "prepare_custom_config",
+                "qconfig_mapping",
+            ]:
+                assert kwarg_name not in custom_prepare_kwargs, (
+                    f"cannot specify {kwarg_name} in custom_prepare_kwargs"
+                )
+            prepare_kwargs: dict[str, Any] = {
+                "example_inputs": example_inputs,
+                "qconfig_mapping": qconfig_mapping,
+            }
+            prepare_kwargs.update(custom_prepare_kwargs)
+            orig_mod_copy_wrapped = custom_prepare_fn(
+                orig_mod_copy_wrapped, **prepare_kwargs
+            )
+
+        # attach the wrapper to the model
+        attr_name = _get_attr_wrapper_name(subgraph_idx, subgraph_candidate_idx)
+        assert not hasattr(mt, attr_name)
+        setattr(mt, attr_name, orig_mod_copy_wrapped)
+
+        # add a call to the wrapper module from the parent graph
+        insert_after_node = last_added_shadow_node_list[0]
+        with mt.graph.inserting_after(insert_after_node):
+            # TODO(future PR): handle fusion patterns where non-first nodes
+            # need inputs
+
+            # pass in all node args and kwargs
+
+            new_args = []
+            for arg in first_node.args:
+                if isinstance(arg, Node):
+                    new_args.append(arg)
+                elif (
+                    isinstance(arg, (list, tuple))
+                    and len(arg)
+                    and isinstance(arg[0], Node)
+                ):
+                    new_args.extend(
+                        inner_arg for inner_arg in arg if isinstance(inner_arg, Node)
+                    )
+
+            new_kwargs = {}
+            for name, old_kwarg in first_node.kwargs.items():
+                if isinstance(old_kwarg, Node):
+                    new_kwargs[name] = old_kwarg
+                elif isinstance(old_kwarg, (list, tuple)) and len(old_kwarg):
+                    # TODO(future PR): clarify why we are adding kwargs to args
+                    new_args.extend(old_kwarg)  # type: ignore[arg-type]
+
+            new_args = tuple(new_args)  # type: ignore[assignment]
+
+            new_node = mt.graph.call_module(attr_name, args=new_args, kwargs=new_kwargs)  # type: ignore[arg-type]
+
+        # add a logger to parent graph to observe the shadow wrapper
+        logger_mod_orig = _get_logger_for_subgraph(
+            mt,
+            first_node,
+            last_node,
+            subgraph_idx,
+            subgraph_candidate_idx,
+            str(qconfig),
+            OutputComparisonLogger,
+            fqn,
+        )
+
+        attr_name = _get_attr_name(subgraph_idx, subgraph_candidate_idx)
+        assert not hasattr(mt, attr_name)
+        setattr(mt, attr_name, logger_mod_orig)
+        with mt.graph.inserting_after(new_node):
+            logger = mt.graph.call_module(
+                attr_name, args=(new_node, last_node), kwargs={}
+            )
+            last_added_shadow_node_list[0] = logger
+
+    mt.recompile()
+
+
+def create_n_transformed_and_logged_copies_of_subgraph(
+    mt: GraphModule,
+    subgraph_idx: int,
+    match_name: str,
+    nodes_in_this_subgraph: list[Any],
+    qconfig_mappings: list[QConfigMapping],
+    list_of_node_name_to_qconfig: list[dict[str, QConfigAny]],
+    custom_prepare_fn: Optional[Callable] = None,
+    custom_prepare_kwargs: Optional[dict[str, Any]] = None,
+) -> None:
+    """
+    Given a model `mt` and a subgraph_idx, creates the needed copies
+    of the subgraph for all qconfigs, and instruments them with loggers.
+    """
+    # for now, assume that
+    # 1. the first node has one input
+    # 2. the last node has one output
+
+    # for now, ignore all subgraphs that contain non-nodes (tuples, etc)
+    # TODO(future PR): implement this
+    if any(not isinstance(node, Node) for node in nodes_in_this_subgraph):
+        return
+
+    first_node = nodes_in_this_subgraph[0]
+    last_node = nodes_in_this_subgraph[-1]
+    # We used output propagation to populate example values on each
+    # node. Use the example values from the previous node as the input
+    # to the current node.
+    prev_node = get_normalized_nth_input(first_node, mt, 0)
+    if isinstance(prev_node, list):
+        example_inputs = [x.traced_result for x in prev_node]
+    elif isinstance(prev_node, tuple):
+        example_inputs = tuple(x.traced_result for x in prev_node)  # type: ignore[assignment]
+    else:
+        # currently some customer models do not have a traced_result in
+        # every node, so we have to guard for this case since we cannot
+        # quantize without an example input
+        # TODO(future PR): add a test case for this once we have an easy
+        # repro, see https://github.com/pytorch/pytorch/pull/80521/files#r975940489
+        # for additional context
+        if hasattr(prev_node, "traced_result"):
+            example_inputs = (prev_node.traced_result,)  # type: ignore[attr-defined, assignment]
+        else:
+            print(
+                "unable to get example input for node "
+                + f"{first_node.format_node()}, skipping"
+            )
+            return
+
+    # If there are no quantization configs for this subgraph, skip adding
+    # loggers. This reduces memory usage for models where not all layers are
+    # quantized.
+    # TODO(future): consider making this configurable
+    found_at_least_one_qconfig = False
+    for subgraph_candidate_idx in range(len(qconfig_mappings) + 1):
+        if subgraph_candidate_idx == 0:
+            # fp32 baseline does not need a qconfig
+            continue
+
+        # a. we have N shadows, so len(qconfig_mappings) is N
+        # b. we will have the fp32 layer + N shadows, so overall number of
+        #    (original_op) + (*shadows) will be N+1
+        # c. since `subgraph_candidate_idx` represents (b), we need
+        #    to subtract 1 to query from (a)
+        node_name_to_qconfig = list_of_node_name_to_qconfig[subgraph_candidate_idx - 1]
+        qconfig = node_name_to_qconfig[first_node.name]
+        if qconfig is not None:
+            found_at_least_one_qconfig = True
+            break
+    if not found_at_least_one_qconfig:
+        print(
+            "unable to find at least one qconfig for node "
+            + f"{first_node.format_node()}, skipping"
+        )
+        return
+
+    fqn = _maybe_get_fqn(first_node, mt)
+
+    # We want the results to contain the subgraphs in natural order,
+    # and the graph to also contain shadow wrappers and shadow loggers
+    # in natural order.
+    # If we just iterate in reverse, the graph will be in natural
+    # order but the eventual results will be in reverse order.
+    # So, we keep track of the last shadow logger we added and
+    # always insert after it.
+    last_added_shadow_node_list: list[Optional[Node]] = [None]
+    for subgraph_candidate_idx in range(len(qconfig_mappings) + 1):
+        create_one_transformed_and_logged_copy_of_subgraph(
+            mt,
+            subgraph_idx,
+            subgraph_candidate_idx,
+            first_node,
+            last_node,
+            fqn,
+            list_of_node_name_to_qconfig,
+            example_inputs,
+            last_added_shadow_node_list,
+            custom_prepare_fn,
+            custom_prepare_kwargs,
+        )
+
+
+def create_add_loggers_graph(
+    model: GraphModule,
+    subgraphs_dedup: dict[str, list[Node]],
+    qconfig_mapping: QConfigMapping,
+    node_name_to_qconfig: dict[str, QConfigAny],
+) -> None:
+    r"""
+    Given a model, a model graph partition (currently a set of matched
+    subgraphs) and instructions how to transform each subgraph
+    (currently quantizing it according to qconfig_mapping), modifies
+    the model graph to create an alternate path through the original graph,
+    with each of the subgraphs quantized.  This is useful to compare
+    propagation error of a transformation such as quantization.
+
+    For example, given layers op0 and op1, there are four cases when handling op1:
+    1. op0 and op1 quantized
+    2. op0 and op1 unquantized
+    3. op0 quantized, op1 unquantized
+    4. op0 unquantized, op1 quantized
+
+    Example input, case 1:
+
+    .. code::
+
+      x0_0 -> op0_0 -> x1_0 -> log -----> op1_0 -> x2_0 -> log
+       \                        \          \                 \       # noqa: W605
+         ---> op0_1 -> x1_1 ----> clog    op1_1 -> x2_1 ----> clog
+
+    Example output, case 1:
+
+    .. code::
+
+      x0_0 -> op0_0 -> x1_0 -> log -----> op1_0 -> x2_0 -> log
+       \                        \                           \        # noqa: W605
+         ---> op0_1 -> x1_1 ----> clog -> op1_1 -> x2_1 ----> clog
+
+    """
+    # TODO(future PR): move logger classes to utils to remove circular dependency
+    from torch.ao.ns._numeric_suite_fx import OutputComparisonLogger, OutputLogger
+
+    def _get_subgraph_containing_node(node, subgraphs_dedup):
+        for subgraph in subgraphs_dedup.values():
+            if node in subgraph:
+                return subgraph
+        return None
+
+    # First, we need to create shadow branches, going from
+    #
+    #   x0 -> op0 -> x1 -> ...
+    #
+    #
+    # to
+    #
+    #   x0 -> op0_0 -> x1_0 -> log -> ...
+    #    \                     \
+    #      -> op0_1 -> x1_1 -> clog
+    #
+    # Later, the outputs of each shadow will be rerouted to calculate
+    # propagation error.
+
+    # Note: we cannot iterate over matched subgraphs because some nodes
+    # may not be matched. So, we iterate over nodes in the graph, and
+    # associate them to matched subgraphs if possible.
+
+    nodes_to_skip = set()
+    # for each subgraph, save a mapping from first node of subgraph
+    # to first and last node of the shadow of this subgraph
+    orig_first_node_to_shadow_in_node = {}
+    orig_first_node_to_shadow_out_node = {}
+    # need to record original list because we will mutate the graph as we go
+    orig_nodes = list(model.graph.nodes)  # type: ignore[union-attr, arg-type]
+    cur_subgraph_idx = 0
+    for n in orig_nodes:
+        if n.op in ("placeholder", "get_attr", "output") or n in nodes_to_skip:
+            continue
+
+        maybe_subgraph = _get_subgraph_containing_node(n, subgraphs_dedup)
+        insert_submodule_copy = False
+        if maybe_subgraph is not None:
+            first_node, last_node = maybe_subgraph[0], maybe_subgraph[-1]
+            nodes_to_skip.update(maybe_subgraph)
+            qconfig = node_name_to_qconfig[first_node.name]
+            if qconfig is not None:
+                insert_submodule_copy = True
+        else:
+            first_node, last_node = n, n
+
+        if insert_submodule_copy:
+            match_name = first_node.name
+            create_n_transformed_and_logged_copies_of_subgraph(
+                model,
+                cur_subgraph_idx,
+                match_name,
+                maybe_subgraph,
+                [qconfig_mapping],
+                [node_name_to_qconfig],
+                None,
+                None,  # type: ignore[arg-type]
+            )
+            # find the created shadow module and record it so we
+            # can find it easily in step 2
+            expected_shadow_target = f"shadow_wrapper_{cur_subgraph_idx}_1"
+            new_shadow_mod = None
+            for maybe_shadow_mod in model.graph.nodes:
+                if (
+                    maybe_shadow_mod.op == "call_module"
+                    and maybe_shadow_mod.target == expected_shadow_target
+                ):
+                    new_shadow_mod = maybe_shadow_mod
+                    break
+            assert new_shadow_mod is not None
+            orig_first_node_to_shadow_in_node[first_node] = new_shadow_mod
+            orig_first_node_to_shadow_out_node[first_node] = new_shadow_mod
+
+        else:
+            # create a copy of the subgraph by only copying FX nodes
+            # but not copying any parameters, to minimize memory usage
+            subgraph_to_use = (
+                maybe_subgraph if maybe_subgraph is not None else [first_node]
+            )
+
+            # add a regular logger after last_node
+            qconfig_str = ""
+            subgraph_candidate_idx = 0
+            fqn = _maybe_get_fqn(first_node, model)
+            logger_mod_orig = _get_logger_for_subgraph(
+                model,
+                first_node,
+                last_node,
+                cur_subgraph_idx,
+                subgraph_candidate_idx,
+                qconfig_str,
+                OutputLogger,
+                fqn,
+            )
+            attr_name = _get_attr_name(cur_subgraph_idx, subgraph_candidate_idx)
+            assert not hasattr(model, attr_name)
+            setattr(model, attr_name, logger_mod_orig)
+            insertion_point = last_node
+            with model.graph.inserting_after(insertion_point):
+                logger = model.graph.call_module(
+                    attr_name, args=(last_node,), kwargs={}
+                )
+                insertion_point = logger
+
+            # create a copy of the subgraph
+            cur_node_orig = first_node
+            cur_node_copy = None
+            first_node_copy = None
+            while cur_node_orig in subgraph_to_use:
+                # TODO(future PR): make this support all possible args/kwargs
+                if cur_node_orig is first_node:
+                    new_args = cur_node_orig.args
+                    new_kwargs = cur_node_orig.kwargs
+                else:
+                    first_arg_for_copy: Optional[Node] = cur_node_copy
+                    new_args = (first_arg_for_copy, *cur_node_orig.args[1:])
+                    new_kwargs = cur_node_orig.kwargs
+                # make a copy of cur_node_orig
+                with model.graph.inserting_after(insertion_point):
+                    cur_node_copy = model.graph.create_node(
+                        cur_node_orig.op,
+                        cur_node_orig.target,
+                        new_args,
+                        new_kwargs,
+                        # cur_node_orig.name,  # TODO(future PR): set name explicitly
+                    )
+                    if first_node_copy is None:
+                        first_node_copy = cur_node_copy
+                # since now only linear subgraphs are supported, all nodes
+                # except the last one must have only one user
+                if cur_node_orig != last_node:
+                    assert len(cur_node_orig.users.keys()) == 1
+                cur_node_orig = next(iter(cur_node_orig.users.keys()))
+                assert not cur_node_orig.name.startswith(SHADOW_NODE_NAME_PREFIX)
+                insertion_point = cur_node_copy
+
+            # add a comparison logger after last_node's copy
+            subgraph_candidate_idx = 1
+            logger_mod_orig = _get_logger_for_subgraph(
+                model,
+                first_node,
+                last_node,
+                cur_subgraph_idx,
+                subgraph_candidate_idx,
+                qconfig_str,
+                OutputComparisonLogger,
+                fqn,
+            )
+            attr_name = _get_attr_name(cur_subgraph_idx, subgraph_candidate_idx)
+            assert not hasattr(model, attr_name)
+            setattr(model, attr_name, logger_mod_orig)
+            with model.graph.inserting_after(insertion_point):
+                logger = model.graph.call_module(
+                    attr_name, args=(cur_node_copy, last_node), kwargs={}
+                )
+
+            # save the final node so we can use it in step 2
+            orig_first_node_to_shadow_in_node[first_node] = first_node_copy
+            orig_first_node_to_shadow_out_node[first_node] = cur_node_copy
+
+        cur_subgraph_idx += 1
+
+    model.recompile()
+
+    # Now, we go from
+    #
+    #   x0 -> op0_0 -> x1_0 -> log -> x1 -> op1_0 -> ...
+    #    \                     \       \
+    #      -> op0_1 -> x1_1 -> clog      -> op1_1 -> ...
+    #
+    # to
+    #
+    #   x0 -> op0_0 -> x1_0 -> log --> x1_0 -> op1_0 -> ...
+    #    \                     \
+    #      -> op0_1 -> x1_1 -> clog -> x1_1 -> op1_1 -> ...
+    #
+    # sample values of key internal variables for the example above:
+    #
+    #   orig_first_node_to_shadow_in_node = {op0_0: op0_1, op1_0: op1_1}
+    #   orig_first_node_to_shadow_out_node = {op0_0: op0_1, op1_0: op1_1}
+    #
+    # note: for subgraphs with more than one node, in_node will be different
+    # compared to out_node
+
+    nodes_to_skip = set()
+    for n in orig_nodes:
+        if n.op in ("placeholder", "get_attr", "output") or n in nodes_to_skip:
+            continue
+
+        maybe_subgraph = _get_subgraph_containing_node(n, subgraphs_dedup)
+        if maybe_subgraph is not None:
+            first_node, last_node = maybe_subgraph[0], maybe_subgraph[-1]
+            nodes_to_skip.update(maybe_subgraph)
+        else:
+            first_node, last_node = n, n
+
+        def maybe_remap_node_to_shadow(node):
+            """
+            If unshadowed `node` has a shadow version, return that. If not,
+            return `node`.
+            """
+            if not isinstance(node, Node):
+                # handle scalars
+                return node
+
+            if node.op in ("placeholder", "get_attr"):
+                return node
+
+            # Find the shadowed version of this arg from the previous
+            # subgraph. For this, we need to:
+            # 1. navigate to the first node of the previous subgraph
+            # 2. get the output of the shadow wrapper which has (1) as an input
+
+            # For now, assume the arg is in matched subgraphs. In the
+            # future we may have to handle the case where this is not true.
+            prev_subgraph = _get_subgraph_containing_node(node, subgraphs_dedup)
+            if prev_subgraph is None:
+                prev_subgraph = [node]
+            prev_first_node = prev_subgraph[0]
+            prev_shadow_output = orig_first_node_to_shadow_out_node[prev_first_node]
+            return prev_shadow_output
+
+        cur_shadow_input = orig_first_node_to_shadow_in_node[first_node]
+        assert cur_shadow_input is not None
+        cur_shadow_input.args = tree_map(
+            maybe_remap_node_to_shadow, cur_shadow_input.args
+        )
+        cur_shadow_input.kwargs = tree_map(
+            maybe_remap_node_to_shadow, cur_shadow_input.kwargs
+        )
+
+        model.recompile()
+
+
+def _get_weight_info_from_shadow_wrapper(shadow_wrapper: torch.nn.Module):
+    # input: shadow wrapper module
+    # output if shadow wrapper module has a weighted op:
+    #   (quantize_fn, (quantize_fn_args))
+    # output if shadow wrapper module doesn't have a weighted op:
+    #   None
+
+    # For now, assume that the weight is the second input
+    # to the shadow module. If that changes, we can fix it later.
+    placeholders_seen = 0
+    for shadow_n in shadow_wrapper.graph.nodes:  # type: ignore[union-attr]
+        if shadow_n.op != "placeholder":
+            continue
+
+        placeholders_seen += 1
+        if placeholders_seen != 2:
+            continue
+
+        # the subgraph looks like
+        #
+        #   _input_scale_1 = self._input_scale_1
+        #   _input_zero_point_1 = self._input_zero_point_1
+        #   quantize_per_channel = torch.quantize_per_channel(
+        #       w2_0, _input_scale_1, _input_zero_point_1,
+        #       0, torch.qint8)
+        #
+        #  we have `w2_0`, and are navigating this subgraph
+        #  to get `_input_scale_1` and `_input_zero_point_1`
+
+        assert len(shadow_n.users) == 1
+        quant_node = next(iter(shadow_n.users.keys()))
+        new_args: Any = None
+        if quant_node.target == torch.quantize_per_channel:
+            _weight, scale_node, zp_node, axis, dtype = quant_node.args
+            scale_val = getattr_from_fqn(shadow_wrapper, scale_node.target)
+            zp_val = getattr_from_fqn(shadow_wrapper, zp_node.target)
+            new_args = (scale_val, zp_val, axis, dtype)
+        else:
+            assert quant_node.target == torch.quantize_per_tensor
+            _weight, scale_node, zp_node, dtype = quant_node.args
+            scale_val = getattr_from_fqn(shadow_wrapper, scale_node.target)
+            zp_val = getattr_from_fqn(shadow_wrapper, zp_node.target)
+            new_args = (scale_val, zp_val, dtype)
+        return (quant_node.target, new_args)
+
+    return None
+
+
+def extract_weight_comparison(m: GraphModule) -> NSResultsType:
+    # example graph:
+    #
+    #   w1 = self.w1
+    #   b1 = self.b1
+    #   linear = torch._C._nn.linear(x, w1, b1)
+    #   shadow_0_0 = self.shadow_0_0(linear)
+    #   shadow_wrapper_0_1 = self.shadow_wrapper_0_1(x, w1, b1)
+    #   shadow_0_1 = self.shadow_0_1(shadow_wrapper_0_1, linear)
+    #
+    # algorithm:
+    # 1. for each call_function node matching our allowlist:
+    # 2.   if corresponding shadow wrapper exists, extract the weight pair
+    #
+    # Note: this is not super robust, but that's ok because this is
+    # just for legacy customers who depend on the previous two-model version
+    # of this API. TBD if we need to make this robust.
+    # Note: modules are not supported, since existing customers only
+    # use functions.
+
+    # TODO(future PR): move this to config
+    weighted_ops = {
+        torch.nn.functional.linear,
+    }
+
+    results: NSResultsType = {"model": {NSSingleResultValuesType.WEIGHT.value: {}}}
+
+    for n in m.graph.nodes:  # type: ignore[union-attr]
+        if not (n.op == "call_function" and n.target in weighted_ops):
+            continue
+
+        # Check if we have a corresponding shadow wrapper
+        # TODO(future PR, if needed): support kwargs
+        # TODO(future PR, if needed): support multiple shadow users
+        first_arg = n.args[0]
+        shadow_wrapper_node = None
+        for user in first_arg.users:
+            # TODO(before land): fix string match
+            if user.op == "call_module" and user.target.startswith("shadow_wrapper"):
+                shadow_wrapper_node = user
+                break
+
+        if shadow_wrapper_node is None:
+            continue
+
+        shadow_wrapper = getattr_from_fqn(m, shadow_wrapper_node.target)  # type: ignore[arg-type]
+        weight_info = _get_weight_info_from_shadow_wrapper(shadow_wrapper)
+        if weight_info is None:
+            continue
+
+        # get weight
+        w_node = n.args[1]
+        w_obj = getattr_from_fqn(m, w_node.target).detach()
+
+        # get a quantized version of weight
+        quant_fn, quant_fn_args_except_first = weight_info
+        new_args = (w_obj, *quant_fn_args_except_first)
+        w_obj_q = quant_fn(*new_args)
+
+        # add a comparison
+        ref_node_name = n.name
+        prev_node_name = n.name
+        ref_node_type = get_target_type_str(n, m)
+        prev_node_type = ref_node_type
+        fqn = None
+        if hasattr(m, "_node_name_to_scope"):
+            fqn = m._node_name_to_scope[n.name][0]  # type: ignore[index]
+        comparison = torch.ao.ns.fx.utils.compute_sqnr(w_obj, w_obj_q)
+        result_fp32 = {
+            "res_type": NSSingleResultValuesType.WEIGHT.value,
+            "values": [w_obj],
+            "prev_node_name": prev_node_name,
+            "prev_node_target_type": prev_node_type,
+            "ref_node_name": ref_node_name,
+            "ref_node_target_type": ref_node_type,
+            "index_within_arg": 0,
+            "index_of_arg": 0,
+            "fqn": fqn,
+            "qconfig_str": "",
+            "comparisons": [comparison],
+            "comparison_fn_name": "sqnr",
+        }
+        result_q = {
+            "res_type": NSSingleResultValuesType.WEIGHT.value,
+            "values": [w_obj_q],
+            "prev_node_name": prev_node_name,
+            "prev_node_target_type": prev_node_type,
+            "ref_node_name": ref_node_name,
+            "ref_node_target_type": ref_node_type,
+            "index_within_arg": 0,
+            "index_of_arg": 0,
+            "fqn": fqn,
+            "qconfig_str": "",
+            "comparisons": [comparison],
+            "comparison_fn_name": "sqnr",
+        }
+
+        # go from subgraph_n_1 to subgraph_n_0
+        _1, _2, node_idx, _3 = shadow_wrapper_node.target.split("_")
+        name_fp32 = f"subgraph_{node_idx}_0"
+        name_q = f"subgraph_{node_idx}_1"
+
+        results["model"][NSSingleResultValuesType.WEIGHT.value][name_fp32] = [
+            result_fp32
+        ]
+        results["model"][NSSingleResultValuesType.WEIGHT.value][name_q] = [result_q]
+
+    return results
+
+
+# TODO(future PR): redesign this to make it easier to consume outputs
+def group_results_by_subgraph(results: NSResultsType) -> Any:
+    """
+    Groups the flat logger results by subgraph.
+
+    Input:
+
+    {
+      'model': {
+        'node_output': {
+          'subgraph_0_0': [
+            'values': [torch.tensor(...), ...], ...
+            'ref_node_name': ...,
+            'ref_node_target_type': ...,
+            'qconfig_str': ...,
+            'comparisons': [], ...
+            'comparison_fn_name': '',
+            'fqn': '...',
+          ],
+          'subgraph_0_1': [
+            'values': [torch.tensor(...), ...], ...
+            'ref_node_name': ...,
+            'ref_node_target_type': ...,
+            'qconfig_str': ...,
+            'comparisons': [torch.tensor(...), ...], ...
+            'comparison_fn_name': '...',
+            'fqn': '...',
+          ],
+          ...
+        },
+      },
+    }
+
+    Output:
+    {
+      'subgraph_0': {
+        '0': {
+          'ref_node_name': '...',
+          'ref_node_target_type': ...,
+          'values': [torch.tensor(...), ...],
+          'qconfig_str': None,
+          'comparisons': [torch.tensor(...), ...], ...
+          'comparison_fn_name': '...',
+          'fqn': '...',
+        },
+        '1': {
+          'ref_node_name': '...',
+          'ref_node_target_type': ...,
+          'values': [torch.tensor(...), ...],
+          'qconfig_str': '...',
+          'comparisons': [torch.tensor(...), ...], ...
+          'comparison_fn_name': '...',
+          'fqn': '...',
+        },
+      },
+    }
+
+    """
+    subgraph_name_to_subgraph_results: Any = collections.defaultdict(dict)
+
+    # node_output or weight
+    key_to_use = next(iter(results["model"].keys()))
+
+    for subgraph_name_with_idx, subgraph_candidate_results in results["model"][
+        key_to_use
+    ].items():
+        # convert from `subgraph_m_n` to `subgraph_m` and `n`
+        (
+            subgraph_str,
+            subgraph_idx,
+            subgraph_candidate_idx,
+        ) = subgraph_name_with_idx.split("_")
+        subgraph_name = f"{subgraph_str}_{subgraph_idx}"
+
+        subgraph_results = {
+            "ref_node_name": subgraph_candidate_results[0]["ref_node_name"],
+            "ref_node_target_type": subgraph_candidate_results[0][
+                "ref_node_target_type"
+            ],
+            "fqn": subgraph_candidate_results[0]["fqn"],
+            "values": subgraph_candidate_results[0]["values"],
+            "qconfig_str": subgraph_candidate_results[0]["qconfig_str"],
+            "comparisons": subgraph_candidate_results[0]["comparisons"],
+            "comparison_fn_name": subgraph_candidate_results[0]["comparison_fn_name"],
+        }
+
+        subgraph_name_to_subgraph_results[subgraph_name][subgraph_candidate_idx] = (
+            subgraph_results
+        )
+
+    return dict(subgraph_name_to_subgraph_results)
+
+
+# TODO(future PR): redesign this to make it easier to consume outputs
+def create_results_comparison(
+    results_grouped,
+) -> Any:
+    """
+    Input:
+
+    {
+      'subgraph_0': {
+        '0': {
+          'ref_node_name': '...',
+          'ref_node_target_type': ...,
+          'values': [torch.tensor(...), ...],
+          'qconfig_str': '',
+          'comparisons': [],
+          'comparison_fn_name': '',
+          'fqn': '...',
+        },
+        '1': {
+          'ref_node_name': '...',
+          'ref_node_target_type': ...,
+          'values': [torch.tensor(...), ...],
+          'qconfig_str': '...',
+          'comparisons': [torch.tensor(...), ...],
+          'comparison_fn_name': 'sqnr',
+          'fqn': '...',
+        },
+      },
+    }
+
+    Output:
+    {
+      'subgraph_0': {
+        'ref_node_name': '...',
+        'ref_node_target_type': '...',
+        'fqn': '...',
+        'candidates': {
+          '1': {
+            'qconfig_str': ...,
+            'comparison_fn_name': 'sqnr',
+            'cmp_raw': [..., ...],
+            'cmp_mean': ...,
+          },
+          ...,
+        },
+      },
+    }
+    """
+
+    results_comparison = {}
+
+    for subgraph_name, subgraph_results in results_grouped.items():
+        candidates = {}
+        for subgraph_inner_name, subgraph_inner_result in subgraph_results.items():
+            # skip comparing baseline to baseline
+            if subgraph_inner_name == "0":
+                continue
+
+            # we expect the comparisons to be precalculated from
+            # calibration, so we just fetch them here
+            cmp_raw = subgraph_inner_result["comparisons"]
+            cmp_raw_tensor = torch.stack(cmp_raw)
+
+            candidates[subgraph_inner_name] = {
+                "qconfig_str": subgraph_inner_result["qconfig_str"],
+                "comparison_fn_name": subgraph_inner_result["comparison_fn_name"],
+                "cmp_raw": cmp_raw_tensor,
+                "cmp_mean": torch.mean(cmp_raw_tensor),
+            }
+
+        results_comparison[subgraph_name] = {
+            "ref_node_name": subgraph_results["0"]["ref_node_name"],
+            "ref_node_target_type": subgraph_results["0"]["ref_node_target_type"],
+            "fqn": subgraph_results["0"]["fqn"],
+            "candidates": candidates,
+        }
+
+    return results_comparison
+
+
+# TODO(future PR): redesign this to make it easier to consume outputs
+def print_n_shadows_summary(
+    results_comparison,
+) -> None:
+    """
+    Input:
+
+    {
+      'subgraph_0': {
+        'ref_node_name': 'linear1',
+        'ref_node_target_type': '...',
+        'fqn': '...',
+        'candidates': {
+          '1': {
+            'qconfig_str': ...,
+            'comparison_fn_name': ...,
+            'cmp_raw': [45.0, 55.0],
+            'cmp_mean': 50.0,
+          },
+          ...,
+        },
+      },
+    }
+
+    Prints:
+
+    node_name | node_type | fqn | 0    | 1    | ...
+    linear1   | ...       | ... | 45.0 | 50.0 | ...
+    """
+
+    try:
+        from tabulate import tabulate
+    except ImportError:
+        print(
+            "`print_tabular` relies on the library `tabulate`, "
+            "which could not be found on this machine. Run `pip "
+            "install tabulate` to install the library."
+        )
+        return
+
+    results = []
+    for subgraph_data in results_comparison.values():
+        mean_all_candidates = [
+            candidate["cmp_mean"]
+            for candidate_name, candidate in subgraph_data["candidates"].items()
+        ]
+
+        data_row = [
+            subgraph_data["ref_node_name"],
+            subgraph_data["ref_node_target_type"],
+            subgraph_data["fqn"],
+            *mean_all_candidates,
+        ]
+        results.append(data_row)
+
+    # each data row is [node_name, node_type, fqn, *candidate_means], so the
+    # number of candidate columns is the row length minus the three fixed columns
+    max_candidate_idx_len = -1
+    for data_row in results:
+        max_candidate_idx_len = max(max_candidate_idx_len, len(data_row) - 3)
+    candidate_idx_headers = [str(x) for x in range(max_candidate_idx_len)]
+
+    headers = ["node_name", "node_type", "fqn", *candidate_idx_headers]
+    print(tabulate(results, headers=headers))
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/ns/fx/ns_types.py b/.venv/lib/python3.12/site-packages/torch/ao/ns/fx/ns_types.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7fcd28e364831a5159c7db06785bc3a18c9c9ab
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/ns/fx/ns_types.py
@@ -0,0 +1,65 @@
+import enum
+from typing import Any, Callable, NamedTuple, Union
+
+from torch.fx.graph import Node
+
+
+class NSSingleResultValuesType(str, enum.Enum):
+    WEIGHT = "weight"
+    NODE_OUTPUT = "node_output"
+    NODE_INPUT = "node_input"
+
+
+class NSSubgraph(NamedTuple):
+    start_node: Node
+    end_node: Node
+    base_op_node: Node
+
+
+# TODO(future PR): see if we can use typing_extensions's TypedDict instead
+# to properly type the various keys
+# {
+#   # one of NSSingleResultValuesType
+#   'type': 'weight',
+#   # the values of type specified above
+#   'values': [torch.tensor(...), ...],
+#   # name of the node directly before the logger
+#   'prev_node_name': 'linear1',
+#   # type of the underlying function or module
+#   'prev_node_target_type': torch.nn.functional.linear  # or torch.nn.Linear, etc
+#   # name of the node responsible for adding this logger
+#   # Note: this may differ from prev_node_name if we are logging inputs
+#   'ref_node_name': 'linear1',
+#   # index of this node within the arg of the input/output node
+#   # for example, in cat([x1, x2, x3], dim=0), x2 would have index_within_arg == 1
+#   'index_within_arg': 0,
+#   # index of this node within the args of the input/output node
+#   # for example, in add(x1, x2), x2 would have index_of_arg == 1
+#   'index_of_arg': 0,
+#   # precomputed comparisons of logger values to reference values
+#   'comparisons': [torch.tensor(...), ...]
+#   # name of function used for precomputed comparisons
+#   'comparison_fn_name': 'sqnr',
+#   # string representation of qconfig responsible for creating this logger
+#   'qconfig_str': 'QConfig(...)',
+# }
+NSSingleResultType = dict[str, Any]
+
+# {
+#   'layer_name_1': {  # subgraph name
+#     'node_output': {  # results type (node_output, node_input, weight)
+#       'model_name_a':  # model name
+#          [NSSingleResultType, ...],  # results, ordered by index_within_arg
+#       'model_name_b':
+#          [NSSingleResultType, ...],
+#     },
+#   },
+# }
+#
+NSResultsType = dict[str, dict[str, dict[str, list[NSSingleResultType]]]]
+
+# Defines the underlying target type of a node, for example:
+# `F.conv1d` for a `call_function` conv node
+# `nn.Conv1d` for a `call_module` node calling the forward of a `nn.Conv1d` module
+# `'sigmoid'` for a `call_method` node calling `x.sigmoid()`
+NSNodeTargetType = Union[Callable, str]
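+
+
+# Illustrative sketch added for exposition; not part of the upstream torch
+# module. A minimal NSSingleResultType entry following the key layout
+# documented above; the names and values are placeholders.
+_example_single_result: NSSingleResultType = {
+    "type": NSSingleResultValuesType.WEIGHT.value,
+    "values": [],  # would hold torch.Tensor values in real results
+    "prev_node_name": "linear1",
+    "prev_node_target_type": "torch.nn.Linear",
+    "ref_node_name": "linear1",
+    "index_within_arg": 0,
+    "index_of_arg": 0,
+    "fqn": None,
+}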
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/ns/fx/pattern_utils.py b/.venv/lib/python3.12/site-packages/torch/ao/ns/fx/pattern_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ac267417f9730d6bbf6f7b9c150eeea619b350b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/ns/fx/pattern_utils.py
@@ -0,0 +1,209 @@
+from typing import Any, Callable, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.ao.quantization import FakeQuantizeBase, ObserverBase
+from torch.ao.quantization.backend_config import get_native_backend_config
+from torch.ao.quantization.fx.quantize_handler import _get_pattern_to_quantize_handlers
+from torch.ao.quantization.utils import getattr_from_fqn
+from torch.fx import GraphModule
+from torch.fx.graph import Node
+
+from .ns_types import NSNodeTargetType
+
+
+toq = torch.ops.quantized
+
+
+def get_type_a_related_to_b(
+    base_name_to_sets_of_related_ops: dict[str, set[NSNodeTargetType]],
+) -> set[tuple[NSNodeTargetType, NSNodeTargetType]]:
+    # TODO(future PR): allow customizations
+    # TODO(future PR): reuse existing quantization mappings
+    # TODO(future PR): add the rest of modules and ops here
+    type_a_related_to_b: set[tuple[NSNodeTargetType, NSNodeTargetType]] = set()
+
+    for s in base_name_to_sets_of_related_ops.values():
+        s_list = list(s)
+        # add every bidirectional pair
+        for idx_0 in range(0, len(s_list)):
+            for idx_1 in range(idx_0, len(s_list)):
+                type_a_related_to_b.add((s_list[idx_0], s_list[idx_1]))
+                type_a_related_to_b.add((s_list[idx_1], s_list[idx_0]))
+
+    return type_a_related_to_b
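+
+
+# Illustrative sketch added for exposition; not part of the upstream torch
+# module. A tiny relatedness mapping (the set name "linear" is a placeholder)
+# and the bidirectional pairs derived from it.
+def _example_get_type_a_related_to_b() -> None:
+    base_map: dict[str, set[NSNodeTargetType]] = {
+        "linear": {F.linear, nn.Linear},
+    }
+    pairs = get_type_a_related_to_b(base_map)
+    # the pair is present in both directions, plus the self-pairs
+    assert (F.linear, nn.Linear) in pairs
+    assert (nn.Linear, F.linear) in pairs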
+
+
+NSFusionElType = Union[
+    Callable,  # call_function or call_module type, example: F.linear or nn.Conv2d
+    str,  # call_method name, example: "dequantize"
+    tuple[
+        str, Any
+    ],  # call_method name and first argument, example: ("to", torch.float16)
+]
+NSFusionType = Union[
+    tuple[NSFusionElType, NSFusionElType],
+    tuple[NSFusionElType, NSFusionElType, NSFusionElType, NSFusionElType],
+]
+
+
+def get_reversed_fusions() -> list[tuple[NSFusionType, int]]:
+    """
+    Set of potential fusions, in reverse order.  The order is reversed
+    to match how fusion patterns are defined in quantization code.
+
+    Fusion format:
+    ((fusion_op_0, fusion_op_1), base_op_idx)
+
+    Where base_op_idx is the idx of the op we should use to match other related
+    ops. Note: base_op_idx is specified in non-reverse order, i.e. a base_op_idx
+    of 0 represents the first op in regular (non-reverse) order, 1 represents the
+    second op, etc.
+    """
+    results: list[tuple[NSFusionType, int]] = []
+
+    # Possible syntaxes:
+    # * single op: torch.nn.Conv2d
+    # * multiple ops: (torch.nn.ReLU, torch.nn.Conv2d)
+    # For fusions, we only care about patterns composed of multiple ops.
+    # TODO(future PR): allow customizations from default patterns.
+    all_quant_patterns = _get_pattern_to_quantize_handlers(get_native_backend_config())
+
+    default_base_op_idx = 0
+    for quant_pattern in all_quant_patterns.keys():
+        # TODO: this is a temporary hack to flatten the patterns from quantization so
+        # that it works with the ns matcher function, maybe we should use `_is_match`
+        # in torch.ao.quantization.fx.match_utils to match the patterns
+        if (
+            isinstance(quant_pattern, tuple)
+            and len(quant_pattern) == 2
+            and isinstance(quant_pattern[1], tuple)
+            and len(quant_pattern[1]) == 2
+        ):
+            # flatten the pattern with form (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d))
+            quant_pattern = (quant_pattern[0], quant_pattern[1][0], quant_pattern[1][1])
+
+        # Only patterns of multiple ops are fusions, ignore
+        # patterns which contain a single op (they get matched
+        # without caring about fusions).
+        if isinstance(quant_pattern, tuple):
+            results.append((quant_pattern, default_base_op_idx))  # type: ignore[arg-type]
+
+        # For each pattern, add additional patterns with observers and
+        # fake quants at the end.
+        # TODO(future PR): if needed, implement matching for a node
+        #   having multiple output observers.
+        for cls in (ObserverBase, FakeQuantizeBase):
+            if isinstance(quant_pattern, tuple):
+                new_pattern = (cls, *quant_pattern)
+            else:
+                new_pattern = (cls, quant_pattern)
+            results.append((new_pattern, default_base_op_idx))  # type: ignore[arg-type]
+
+    # After this point, results contains values such as
+    # [..., ((torch.nn.ReLU, torch.nn.Conv2d), 0), ...]
+
+    # Patterns for matching fp16 emulation are not specified in the quantization
+    # fusion mappings.  For now, define them here.
+    fp16_em_base_op_idx = 1
+    patterns_to_add = [
+        # linear-relu fp16 emulation:
+        # fp16_to_fp32 -> linear -> relu -> fp32_to_fp16
+        (
+            (("to", torch.float16), F.relu, F.linear, "dequantize"),
+            fp16_em_base_op_idx,
+        ),
+        # Conv-BN fusion (this happens outside of quantization patterns,
+        # which is why it is defined separately here).
+        ((nn.BatchNorm1d, nn.Conv1d), default_base_op_idx),
+        ((nn.BatchNorm2d, nn.Conv2d), default_base_op_idx),
+        ((nn.BatchNorm3d, nn.Conv3d), default_base_op_idx),
+        ((nn.ReLU, nn.BatchNorm1d, nn.Conv1d), default_base_op_idx),
+        ((nn.ReLU, nn.BatchNorm2d, nn.Conv2d), default_base_op_idx),
+        ((nn.ReLU, nn.BatchNorm3d, nn.Conv3d), default_base_op_idx),
+    ]
+    for p in patterns_to_add:
+        results.append(p)  # type: ignore[arg-type]
+        results.append(((ObserverBase, *p[0]), p[1]))  # type: ignore[arg-type]
+        results.append(((FakeQuantizeBase, *p[0]), p[1]))  # type: ignore[arg-type]
+
+    return results
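+
+
+# Illustrative sketch added for exposition; not part of the upstream torch
+# module: the conv-bn entry defined above appears in the returned list in
+# reversed (end-of-pattern-first) order, with base_op_idx pointing at the conv.
+def _example_get_reversed_fusions() -> None:
+    fusions = get_reversed_fusions()
+    assert ((nn.BatchNorm2d, nn.Conv2d), 0) in fusions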
+
+
+def end_node_matches_reversed_fusion(
+    end_node: Node,
+    reversed_fusion: NSFusionType,
+    gm: GraphModule,
+    seen_nodes: set[Node],
+) -> bool:
+    """
+    Returns true if a pattern ending with `end_node` matches
+    the fusion pattern.
+    """
+    cur_node = end_node
+    for fusion_idx in range(len(reversed_fusion)):
+        # each node can only belong to one matched pattern
+        if cur_node in seen_nodes:
+            return False
+
+        cur_fusion_el = reversed_fusion[fusion_idx]
+
+        if cur_node.op == "call_function":
+            fusion_el_is_fun = (not isinstance(cur_fusion_el, str)) and (
+                not isinstance(cur_fusion_el, type)
+            )
+            if fusion_el_is_fun:
+                if cur_node.target != cur_fusion_el:
+                    return False
+                if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node):
+                    cur_node = cur_node.args[0]
+                else:
+                    return False
+            else:
+                return False
+
+        elif cur_node.op == "call_module":
+            fusion_el_is_mod = isinstance(cur_fusion_el, type)
+            if fusion_el_is_mod:
+                assert isinstance(cur_node.target, str)
+                target_mod = getattr_from_fqn(gm, cur_node.target)
+                if not isinstance(cur_fusion_el, type):
+                    return False
+                if not isinstance(target_mod, cur_fusion_el):
+                    return False
+                if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node):
+                    cur_node = cur_node.args[0]
+                else:
+                    return False
+            else:
+                return False
+
+        elif cur_node.op == "call_method":
+            fusion_el_is_meth_with_second_arg = (
+                isinstance(cur_fusion_el, tuple) and len(cur_fusion_el) == 2
+            )
+            fusion_el_is_meth_without_args = isinstance(cur_fusion_el, str)
+            if fusion_el_is_meth_without_args or fusion_el_is_meth_with_second_arg:
+                if fusion_el_is_meth_without_args:
+                    if cur_node.target != cur_fusion_el:
+                        return False
+                else:
+                    assert isinstance(cur_fusion_el, tuple)
+                    if cur_node.target != cur_fusion_el[0]:
+                        return False
+                    elif len(cur_node.args) < 2:
+                        return False
+                    elif cur_node.args[1] != cur_fusion_el[1]:
+                        return False
+
+                if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node):
+                    cur_node = cur_node.args[0]
+                else:
+                    return False
+            else:
+                return False
+        else:
+            return False
+
+    return True
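+
+
+# Illustrative sketch added for exposition; not part of the upstream torch
+# module: matching a reversed (relu, conv) fusion against the last node of a
+# traced conv -> relu model. The tiny module below is a placeholder.
+def _example_end_node_matches_reversed_fusion() -> None:
+    from torch.fx import symbolic_trace
+
+    class _ConvReLU(nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.conv = nn.Conv2d(1, 1, 1)
+
+        def forward(self, x):
+            return F.relu(self.conv(x))
+
+    gm = symbolic_trace(_ConvReLU())
+    relu_node = next(n for n in gm.graph.nodes if n.target is F.relu)
+    # reversed order: the pattern ends with relu and starts with conv
+    assert end_node_matches_reversed_fusion(relu_node, (F.relu, nn.Conv2d), gm, set())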
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/ns/fx/qconfig_multi_mapping.py b/.venv/lib/python3.12/site-packages/torch/ao/ns/fx/qconfig_multi_mapping.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a7865f2f14b764bb64116f1606b68cc0c40369c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/ns/fx/qconfig_multi_mapping.py
@@ -0,0 +1,249 @@
+# mypy: allow-untyped-defs
+from __future__ import annotations
+
+import copy
+from typing import Any, Callable, TYPE_CHECKING, Union
+
+import torch
+from torch.ao.quantization import QConfigMapping
+from torch.ao.quantization.qconfig_mapping import _QCONFIG_STYLE_ORDER
+
+
+if TYPE_CHECKING:
+    from torch.ao.quantization.qconfig import QConfigAny
+
+__all__ = ["QConfigMultiMapping"]
+
+_QCONFIG_STYLE_TO_METHOD: dict[str, str] = {
+    "global_qconfig": "set_global",
+    "object_type_qconfigs": "set_object_type",
+    "module_name_regex_qconfigs": "set_module_name_regex",
+    "module_name_qconfigs": "set_module_name",
+    "module_name_object_type_order_qconfigs": "set_module_name_object_type_order",
+}
+
+
+def _remove_duplicates_and_none(qconfig_list: list[QConfigAny]) -> None:
+    to_remove = []
+    for index, cur_qconfig in enumerate(qconfig_list):
+        if cur_qconfig is None:
+            to_remove.append(index)
+            break
+        for checked_qconfig in qconfig_list[:index]:
+            if torch.ao.quantization.qconfig_equals(cur_qconfig, checked_qconfig):
+                to_remove.append(index)
+                break
+    for index in to_remove[::-1]:
+        qconfig_list.pop(index)
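+
+
+# Illustrative sketch added for exposition; not part of the upstream torch
+# module: duplicate qconfigs (per qconfig_equals) and a None entry are
+# dropped from the list in place.
+def _example_remove_duplicates_and_none() -> None:
+    qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
+    qconfig_list = [qconfig, qconfig, None]
+    _remove_duplicates_and_none(qconfig_list)
+    assert qconfig_list == [qconfig]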
+
+
+class QConfigMultiMapping:
+    """
+    This class, used with the prepare_n_shadows_model API, stores a list of :class:`torch.ao.quantization.QConfigMapping`s
+    so that multiple QConfigs can be specified for each QConfig matching style.
+
+    The user can specify QConfigs using the following methods (in increasing match priority):
+
+        ``set_global`` : sets the global (default) QConfigs
+
+        ``set_object_type`` : sets the QConfigs for a given module type, function, or method name
+
+        ``set_module_name_regex`` : sets the QConfigs for modules matching the given regex string
+
+        ``set_module_name`` : sets the QConfigs for modules matching the given module name
+
+        ``set_module_name_object_type_order`` : sets the QConfigs for modules matching a combination
+        of the given module name, object type, and the index at which the module appears
+
+    Note: Usage of the set methods is the same as in QConfigMapping, except that each method takes a list of
+    QConfigs rather than a single QConfig.
+
+    Example usage::
+
+        qconfig_mapping = (
+            QConfigMultiMapping()
+            .set_global([qconfig1, qconfig2])
+            .set_object_type(torch.nn.Linear, [qconfig2, qconfig3])
+            .set_object_type(torch.nn.ReLU, [qconfig1])
+            .set_module_name_regex("foo.*bar.*conv[0-9]+", [qconfig2])
+            .set_module_name_regex("foo.*", [qconfig1, qconfig2, qconfig3])
+            .set_module_name("module1", [None])
+            .set_module_name("module2", [qconfig2])
+            .set_module_name_object_type_order("foo.bar", torch.nn.functional.linear, 0, [qconfig3])
+        )
+
+    """
+
+    def __init__(self) -> None:
+        # initialize this with 1 QConfigMapping to avoid corner cases
+        self.qconfig_mappings_list: list[QConfigMapping] = [QConfigMapping()]
+
+    def _handle_list_size_mismatch(
+        self, qconfig_list: list[QConfigAny], style: str
+    ) -> None:
+        # this method handles cases where the size of qconfig_list does not match
+        # the size of qconfig_mappings_list.
+        # Issue: Consider a user inserting global_qconfig A and B first, then inserting
+        # qconfig C as an object_type_qconfig for conv ops. If we internally store
+        # 1 QConfigMapping with A and C and another with just B, then the
+        # second QConfigMapping will match B to conv ops (which is not wanted), since B is global.
+
+        # we avoid this by maintaining the invariant that if any QConfigMapping
+        # has a qconfig style+key with a qconfig in it, all QConfigMappings must
+        # have either a qconfig or None for that same style+key. In the above
+        # example, a None qconfig would prevent the unwanted match in the
+        # second QConfigMapping
+
+        if len(qconfig_list) > len(self.qconfig_mappings_list):
+            # Case: we have more qconfigs (in qconfig_list) than QConfigMappings
+
+            # Add new QConfigMappings (initialized so we maintain the `invariant`)
+
+            new_qconfig_mapping = QConfigMapping()
+            # searches other QConfigMappings for qconfig style+keys
+            # that need to be inserted as `None` into the new QConfigMapping
+            for qconfig_mapping in self.qconfig_mappings_list:
+                # global_qconfig has None by default
+                for check_style in _QCONFIG_STYLE_ORDER[1:]:
+                    qconfigs_dict = getattr(qconfig_mapping, check_style)
+                    target_qconfigs_dict = getattr(new_qconfig_mapping, check_style)
+                    for key in qconfigs_dict:
+                        target_qconfigs_dict[key] = None
+                break
+
+            # insert copies of this new QConfigMapping until all entries
+            # in qconfig_list can fit among the QConfigMappings
+            while len(qconfig_list) > len(self.qconfig_mappings_list):
+                self.qconfig_mappings_list.append(copy.deepcopy(new_qconfig_mapping))
+        else:
+            # Case: we have fewer qconfigs in qconfig_list than QConfigMappings
+
+            # pad qconfig_list with `None` until length is same
+            while len(qconfig_list) < len(self.qconfig_mappings_list):
+                qconfig_list.append(None)
+
+    # this function applies the insertion method across each QConfigMapping
+    def _insert_qconfig_list(
+        self,
+        style: str,
+        args: list[Union[str, int, Callable]],
+        qconfig_list: list[QConfigAny],
+    ) -> None:
+        # we remove duplicates and None to make the ordering of qconfigs
+        # deterministic upon insertion.
+        _remove_duplicates_and_none(qconfig_list)
+
+        self._handle_list_size_mismatch(qconfig_list, style)
+        method_name = _QCONFIG_STYLE_TO_METHOD[style]
+        for qconfig_mapping, qconfig in zip(self.qconfig_mappings_list, qconfig_list):
+            # uses QConfigMapping set method to insert qconfig
+            set_method = getattr(qconfig_mapping, method_name)
+            set_method(*args, qconfig)
+
+    def set_global(self, global_qconfig_list: list[QConfigAny]) -> QConfigMultiMapping:
+        """
+        Set global QConfigs
+        see :func:`~torch.ao.quantization.QConfigMapping.set_global()` for more info
+        """
+        self._insert_qconfig_list("global_qconfig", [], global_qconfig_list)
+        return self
+
+    def set_object_type(
+        self, object_type: Union[Callable, str], qconfig_list: list[QConfigAny]
+    ) -> QConfigMultiMapping:
+        """
+        Set object type QConfigs
+        see :func:`~torch.ao.quantization.QConfigMapping.set_object_type()` for more info
+        """
+        self._insert_qconfig_list("object_type_qconfigs", [object_type], qconfig_list)
+        return self
+
+    def set_module_name_regex(
+        self, module_name_regex: str, qconfig_list: list[QConfigAny]
+    ) -> QConfigMultiMapping:
+        """
+        Set module_name_regex QConfigs
+        see :func:`~torch.ao.quantization.QConfigMapping.set_module_name_regex()` for more info
+        """
+        self._insert_qconfig_list(
+            "module_name_regex_qconfigs", [module_name_regex], qconfig_list
+        )
+        return self
+
+    def set_module_name(
+        self, module_name: str, qconfig_list: list[QConfigAny]
+    ) -> QConfigMultiMapping:
+        """
+        Set module_name QConfigs
+        see :func:`~torch.ao.quantization.QConfigMapping.set_module_name()` for more info
+        """
+        self._insert_qconfig_list("module_name_qconfigs", [module_name], qconfig_list)
+        return self
+
+    def set_module_name_object_type_order(
+        self,
+        module_name: str,
+        object_type: Callable,
+        index: int,
+        qconfig_list: list[QConfigAny],
+    ) -> QConfigMultiMapping:
+        """
+        Set module_name_object_type_order QConfigs
+        see :func:`~torch.ao.quantization.QConfigMapping.set_module_name_object_type_order()` for more info
+        """
+        self._insert_qconfig_list(
+            "module_name_object_type_order_qconfigs",
+            [module_name, object_type, index],
+            qconfig_list,
+        )
+        return self
+
+    def __repr__(self):
+        return (
+            self.__class__.__name__
+            + " ["
+            + "".join(
+                f"\n{qconfig_mapping.__repr__()},"
+                for qconfig_mapping in self.qconfig_mappings_list
+            )
+            + "\n]"
+        )
+
+    @classmethod
+    def from_list_qconfig_mapping(
+        cls, qconfig_mapping_list: list[QConfigMapping]
+    ) -> QConfigMultiMapping:
+        """
+        Creates a QConfigMultiMapping from a list of QConfigMappings
+        """
+        new_qconfig_multi_mapping = cls()
+
+        new_qconfig_multi_mapping.qconfig_mappings_list = copy.deepcopy(
+            qconfig_mapping_list
+        )
+
+        # we need to avoid the issue described in _handle_list_size_mismatch,
+        # so we reinsert all the qconfigs using the QConfigMultiMapping
+        # set methods
+
+        # go through all qconfig styles
+        # note: global can be ignored since it is None by default
+        for style in _QCONFIG_STYLE_ORDER[1:]:
+            # gather all key+qconfigs for current style
+            # into qconfig_dict_list
+            qconfig_dict_list: dict[Any, list[QConfigAny]] = {}
+            for qconfig_mapping in qconfig_mapping_list:
+                qconfig_dict = getattr(qconfig_mapping, style)
+                for key, qconfig in qconfig_dict.items():
+                    if key not in qconfig_dict_list:
+                        qconfig_dict_list[key] = []
+                    qconfig_dict_list[key].append(qconfig)
+
+            # reinsert all gathered key+qconfigs
+            set_method_name = _QCONFIG_STYLE_TO_METHOD[style]
+            set_method = getattr(new_qconfig_multi_mapping, set_method_name)
+            for key, qconfig_list in qconfig_dict_list.items():
+                if isinstance(key, tuple):
+                    set_method(*key, qconfig_list)
+                else:
+                    set_method(key, qconfig_list)
+
+        return new_qconfig_multi_mapping
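+
+
+# Illustrative sketch added for exposition; not part of the upstream torch
+# module: a QConfigMultiMapping can be built directly or from existing
+# QConfigMappings. The backend choices below are placeholders.
+def _example_qconfig_multi_mapping() -> QConfigMultiMapping:
+    qconfig_a = torch.ao.quantization.get_default_qconfig("fbgemm")
+    qconfig_b = torch.ao.quantization.get_default_qconfig("qnnpack")
+    direct = QConfigMultiMapping().set_global([qconfig_a, qconfig_b])
+    from_list = QConfigMultiMapping.from_list_qconfig_mapping(
+        [
+            QConfigMapping().set_global(qconfig_a),
+            QConfigMapping().set_global(qconfig_b),
+        ]
+    )
+    # both hold one internal QConfigMapping per candidate qconfig
+    assert len(direct.qconfig_mappings_list) == 2
+    assert len(from_list.qconfig_mappings_list) == 2
+    return direct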
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/ns/fx/utils.py b/.venv/lib/python3.12/site-packages/torch/ao/ns/fx/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6357120dc14143853bc7edcec8a5a2b411c077a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/ns/fx/utils.py
@@ -0,0 +1,541 @@
+# mypy: allow-untyped-decorators
+# mypy: allow-untyped-defs
+import enum
+import operator
+from typing import Callable, Optional, Union
+
+import torch
+import torch.ao.nn.intrinsic.quantized as nniq
+import torch.ao.nn.quantized as nnq
+import torch.nn as nn
+from torch.ao.quantization import FakeQuantizeBase, ObserverBase
+from torch.ao.quantization.observer import _is_activation_post_process
+from torch.ao.quantization.utils import getattr_from_fqn
+from torch.fx import GraphModule
+from torch.fx.graph import Node
+
+from .ns_types import NSNodeTargetType, NSResultsType
+
+
+toq = torch.ops.quantized
+
+
+# TODO(future PR): consider deleting this enum and using the torch types
+# directly.  This might be tricky because it is not a one to one mapping.
+class NodeInputOrOutputType(enum.Enum):
+    FP32 = enum.auto()  # torch.float
+    INT8 = enum.auto()  # torch.qint8 or torch.quint8
+    FP16 = enum.auto()  # torch.float16
+    UNKNOWN = enum.auto()  # we cannot determine input/output dtype
+    # TODO(future PR): while these functions can support multiple dtypes,
+    #   for the purposes of numerical debugging we want to get the actual
+    #   dtype used in the model. We will likely need some kind of dtype
+    #   propagation to estimate this.
+    FP32_OR_INT8 = enum.auto()  # either torch.float or torch.quint8 or torch.qint8
+    # TODO(future PRs): dynamic quant, fake quant, etc
+
+
+def get_node_first_input_and_output_type(
+    node: Node,
+    gm: GraphModule,
+    logger_cls: Callable,
+    node_type_to_io_type_map: dict[str, set[NSNodeTargetType]],
+) -> tuple[NodeInputOrOutputType, NodeInputOrOutputType]:
+    # TODO(future PR): clean this up
+    FUNS_IO_TYPE_FP32 = node_type_to_io_type_map["funs_io_type_fp32"]
+    FUNS_IO_TYPE_FP16 = node_type_to_io_type_map["funs_io_type_fp16"]
+    FUNS_IO_TYPE_INT8 = node_type_to_io_type_map["funs_io_type_int8"]
+    FUNS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["funs_io_type_fp32_or_int8"]
+    MODS_IO_TYPE_FP32 = node_type_to_io_type_map["mods_io_type_fp32"]
+    MODS_IO_TYPE_INT8 = node_type_to_io_type_map["mods_io_type_int8"]
+    MODS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["mods_io_type_fp32_or_int8"]
+    METHS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["meths_io_type_fp32_or_int8"]
+
+    if node.op == "call_function":
+        if node.target in FUNS_IO_TYPE_FP32:
+            return (NodeInputOrOutputType.FP32, NodeInputOrOutputType.FP32)
+        if node.target in FUNS_IO_TYPE_FP16:
+            return (NodeInputOrOutputType.FP16, NodeInputOrOutputType.FP16)
+        elif node.target in FUNS_IO_TYPE_INT8:
+            return (NodeInputOrOutputType.INT8, NodeInputOrOutputType.INT8)
+        elif node.target in FUNS_IO_TYPE_FP32_OR_INT8:
+            first_arg = get_normalized_nth_input(node, gm, 0)
+            assert isinstance(first_arg, Node)
+            (
+                _prev_node_input_type,
+                prev_node_output_type,
+            ) = get_node_first_input_and_output_type(
+                first_arg, gm, logger_cls, node_type_to_io_type_map
+            )
+            return (prev_node_output_type, prev_node_output_type)
+        else:
+            return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)
+
+    elif node.op == "call_module":
+        assert node.op == "call_module"
+        assert isinstance(node.target, str)
+        mod = getattr_from_fqn(gm, node.target)
+        is_known_fp32_or_int8_input_module = any(
+            isinstance(mod, target_type)  # type: ignore[arg-type]
+            for target_type in MODS_IO_TYPE_FP32_OR_INT8
+        )
+        if (
+            isinstance(mod, (logger_cls, ObserverBase, FakeQuantizeBase))  # type: ignore[arg-type]
+            or is_known_fp32_or_int8_input_module
+        ):
+            # A logger or observer's input and output type is the output
+            # type of the preceding node.
+            first_arg = get_normalized_nth_input(node, gm, 0)
+            assert isinstance(first_arg, Node)
+            (
+                _prev_node_input_type,
+                prev_node_output_type,
+            ) = get_node_first_input_and_output_type(
+                first_arg, gm, logger_cls, node_type_to_io_type_map
+            )
+            return (prev_node_output_type, prev_node_output_type)
+        is_known_fp32_input_module = any(
+            isinstance(mod, target_type)  # type: ignore[arg-type]
+            for target_type in MODS_IO_TYPE_FP32
+        )
+        is_known_int8_input_module = any(
+            isinstance(mod, target_type)  # type: ignore[arg-type]
+            for target_type in MODS_IO_TYPE_INT8
+        )
+        if is_known_fp32_input_module:
+            return (NodeInputOrOutputType.FP32, NodeInputOrOutputType.FP32)
+        elif is_known_int8_input_module:
+            return (NodeInputOrOutputType.INT8, NodeInputOrOutputType.INT8)
+        else:
+            return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)
+
+    elif node.op == "call_method":
+        if node.target == "dequantize":
+            # Dequantize is a special node because it allows multiple input types.
+            # So, we look up the output type of the previous node and return that
+            # as the input type of this node instance.
+            prev_node = get_normalized_nth_input(node, gm, 0)
+            assert isinstance(prev_node, Node)
+            (
+                _prev_node_input_type,
+                prev_node_output_type,
+            ) = get_node_first_input_and_output_type(
+                prev_node, gm, logger_cls, node_type_to_io_type_map
+            )
+            return (prev_node_output_type, NodeInputOrOutputType.FP32)
+
+        elif node.target == "to":
+            # to is a special node because it allows multiple input types.
+            # So, we look up the output type of the previous node and return that
+            # as the input type of this node instance. We also look up the target
+            # of to and return the correct output type.
+            prev_node = get_normalized_nth_input(node, gm, 0)
+            assert isinstance(prev_node, Node)
+            (
+                _prev_node_input_type,
+                prev_node_output_type,
+            ) = get_node_first_input_and_output_type(
+                prev_node, gm, logger_cls, node_type_to_io_type_map
+            )
+
+            cur_node_dtype_target = get_normalized_nth_input(node, gm, 1)
+            assert cur_node_dtype_target is torch.float16, (
+                f"{cur_node_dtype_target} handling needs to be added"
+            )
+
+            return (prev_node_output_type, NodeInputOrOutputType.FP16)
+
+        elif node.target in METHS_IO_TYPE_FP32_OR_INT8:
+            first_arg = get_normalized_nth_input(node, gm, 0)
+            assert isinstance(first_arg, Node)
+            (
+                _prev_node_input_type,
+                prev_node_output_type,
+            ) = get_node_first_input_and_output_type(
+                first_arg, gm, logger_cls, node_type_to_io_type_map
+            )
+            return (prev_node_output_type, prev_node_output_type)
+
+        return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)
+    else:
+        return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)
+
+
+def get_node_input_qparams(
+    node: Node,
+    gm: GraphModule,
+    node_type_to_io_type_map: dict[str, set[NSNodeTargetType]],
+) -> Optional[tuple[Union[torch.Tensor, float], Union[torch.Tensor, int]]]:
+    """
+    Returns the qparams (scale, zero_point) of the first input to `node`,
+    if they can be inferred from the graph.
+    """
+    prev_node = get_normalized_nth_input(node, gm, 0)
+
+    if not isinstance(prev_node, Node):
+        return None
+
+    MODS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["mods_io_type_fp32_or_int8"]
+
+    def _get_scale_zp_from_function_args(node, gm, scale_arg_idx, zp_arg_idx):
+        scale_node = get_normalized_nth_input(node, gm, scale_arg_idx)
+        zp_node = get_normalized_nth_input(node, gm, zp_arg_idx)
+        assert isinstance(scale_node, Node) and isinstance(scale_node.target, str)
+        assert isinstance(zp_node, Node) and isinstance(zp_node.target, str)
+        scale_obj = getattr_from_fqn(gm, scale_node.target)
+        zp_obj = getattr_from_fqn(gm, zp_node.target)
+        return (scale_obj, zp_obj)
+
+    if prev_node.op == "call_function":
+        # quantize - read the args directly
+        if prev_node.target == torch.quantize_per_tensor:
+            return _get_scale_zp_from_function_args(prev_node, gm, 1, 2)
+        elif prev_node.target in (toq.add, toq.add_relu, toq.mul, toq.mul_relu):
+            return _get_scale_zp_from_function_args(prev_node, gm, 2, 3)
+
+        return None
+        # TODO(future PR): handle more functionals
+        # TODO(future PR): handle functional ops which inherit qparams from input
+
+    elif prev_node.op == "call_module":
+        # get type of the module
+        assert isinstance(prev_node.target, str)
+        module_obj = getattr_from_fqn(gm, prev_node.target)
+        if isinstance(
+            module_obj,
+            (
+                nnq.Linear,
+                nnq.Conv1d,
+                nnq.Conv2d,
+                nniq.ConvReLU2d,
+                nnq.Conv3d,
+                nnq.BatchNorm2d,
+                nnq.BatchNorm3d,
+                nnq.ConvTranspose1d,
+                nnq.ConvTranspose2d,
+                nnq.ELU,
+                nnq.GroupNorm,
+                nnq.InstanceNorm1d,
+                nnq.InstanceNorm2d,
+                nnq.InstanceNorm3d,
+                nnq.LayerNorm,
+                nnq.Hardswish,
+                nnq.LeakyReLU,
+                nnq.ReLU6,
+                nniq.BNReLU2d,
+                nniq.BNReLU3d,
+                nniq.ConvReLU1d,
+                nniq.ConvReLU2d,
+                nniq.ConvReLU3d,
+                nniq.LinearReLU,
+            ),
+        ):
+            return (module_obj.scale, module_obj.zero_point)  # type: ignore[return-value]
+
+        is_known_fp32_or_int8_input_module = any(
+            isinstance(module_obj, target_type)  # type: ignore[arg-type]
+            for target_type in MODS_IO_TYPE_FP32_OR_INT8
+        )
+        if is_known_fp32_or_int8_input_module:
+            return get_node_input_qparams(prev_node, gm, node_type_to_io_type_map)
+
+    return None
+
+
+def return_first_non_observer_node(
+    node: Node,
+    gm: GraphModule,
+) -> Node:
+    """
+    If node is not an observer, returns it.  If node is an observer,
+    navigates up the graph and returns the first parent which is not an
+    observer.  For example,
+
+    graph: (node_non_obs), node = node_non_obs : returns node_non_obs
+    graph: (node_non_obs -> obs0), node = obs0 : returns node_non_obs
+    graph: (node_non_obs -> obs0 -> fq0), node = fq0 : returns node_non_obs
+    """
+    if node.op == "call_module":
+        node_obj = getattr_from_fqn(gm, node.target)  # type: ignore[arg-type]
+        if _is_activation_post_process(node_obj):
+            assert len(node.args) == 1
+            assert isinstance(node.args[0], Node)
+            node = node.args[0]
+            # code duplication intended, not worth refactoring
+            assert isinstance(node.target, str)
+            node_obj = getattr_from_fqn(gm, node.target)
+            if _is_activation_post_process(node_obj):
+                assert len(node.args) == 1
+                assert isinstance(node.args[0], Node)
+                node = node.args[0]
+    return node
+
+
+def get_number_of_non_param_args(
+    node: Node,
+    gm: GraphModule,
+) -> int:
+    """
+    Assumes that all non-param args occur first. Returns the number of
+    non-param args expected for a node.  For example, for
+
+      F.linear(x, weight, bias)
+
+    Returns 1, because x is a non-param arg and weight and bias are params.
+    For
+
+      lstm_mod(x, hid)
+
+    Returns 2, because both x and hid are non-param args.
+    """
+    if node.op == "call_module":
+        node_obj = getattr_from_fqn(gm, node.target)  # type: ignore[arg-type]
+        if isinstance(node_obj, nn.LSTM):
+            return 2
+
+    # default is 1
+    return 1
+
+
+def get_arg_indices_of_inputs_to_log(node: Node) -> list[int]:
+    """
+    Returns the indices of args of the node which we should attach
+    loggers to, if input logging is enabled.
+
+    For example,
+    * for (x + y), returns [0, 1]
+    * for (1 + y), returns [1]
+    * for (x + 1), returns [0]
+    * for (linear(x, w, b)) returns [0]
+    * by default, returns [0]
+    """
+    if len(node.args) == 0:
+        return []
+    if node.op == "call_function" and (
+        # TODO(future PR): use relationship map instead of hardcoding
+        node.target in (torch.add, torch.ops.quantized.add, operator.add)
+        or node.target in (torch.mul, torch.ops.quantized.mul, operator.mul)
+    ):
+        result = [i for i in range(2) if type(node.args[i]) == Node]
+        return result
+    return [0]
+
+
+def get_target_type_str(node: Node, gm: GraphModule) -> str:
+    """
+    Returns a string representation of the type of the function or module
+    pointed to by this node, or '' for other node types.
+    """
+    target_type = ""
+    if node.op in ("call_function", "call_method"):
+        target_type = torch.typename(node.target)
+    elif node.op == "call_module":
+        assert isinstance(node.target, str)
+        target_mod = getattr_from_fqn(gm, node.target)
+        target_type = torch.typename(target_mod)
+    return target_type
+
+
+def rekey_logger_info_on_node_name_of_model(
+    results: NSResultsType,
+    model_name: str,
+) -> NSResultsType:
+    """
+    Rekeys the layer name of a results dictionary to use node names
+    from `model_name`.
+
+    For example, transforms
+
+        {'base_op_1_0': {'node_output': {'model_a':
+          [{'ref_node_name': 'linear1', ...}]}}}
+
+    into
+
+        {'linear1': {'node_output': {'model_a':
+          [{'ref_node_name': 'linear1', ...}]}}}
+
+    Note: we cannot use these node names directly because they are not
+    guaranteed to be consistent across models. This is why we extract
+    the results first and rekey afterwards.
+    """
+    new_results = {}
+    for old_layer_name, result_type_to_results in results.items():
+        new_layer_name = None
+        for model_name_to_results in result_type_to_results.values():
+            for cur_model_name, list_of_results in model_name_to_results.items():
+                if cur_model_name == model_name:
+                    assert len(list_of_results)
+                    new_layer_name = list_of_results[0]["ref_node_name"]
+                else:
+                    continue
+        if new_layer_name is not None:
+            new_results[new_layer_name] = result_type_to_results
+        else:
+            new_results[old_layer_name] = result_type_to_results
+    return new_results
+
+
+def maybe_add_missing_fqns(results: NSResultsType) -> None:
+    """
+    If `fqn` entries are filled in for one of the models in `results`, copies
+    them over to any models which do not have them filled out.
+
+    A common use case benefitting from this is comparing a model prepared by
+    quantization to a quantized model. In this case, the model prepared by
+    quantization would have `fqn` entries, and the quantized model would not.
+    """
+
+    # Check in the first result to find any model with fqn entries defined.
+    model_name_with_fqns = None
+    for result_type_to_results in results.values():
+        for model_name_to_results in result_type_to_results.values():
+            for model_name, model_results in model_name_to_results.items():
+                if len(model_results) > 0:
+                    if model_results[0]["fqn"] is not None:
+                        model_name_with_fqns = model_name
+                        break
+            break
+        break
+
+    if model_name_with_fqns:
+        for result_type_to_results in results.values():
+            for model_name_to_results in result_type_to_results.values():
+                ref_model_results = model_name_to_results[model_name_with_fqns]
+                for model_name, model_results in model_name_to_results.items():
+                    if model_name == model_name_with_fqns:
+                        continue
+                    for i in range(len(model_results)):
+                        fqn = ref_model_results[i]["fqn"]
+                        model_results[i]["fqn"] = fqn
+
+
+def maybe_dequantize_first_two_tensor_args_and_handle_tuples(f):
+    def inner(*args, **kwargs):
+        a0, a1, *a_other = args
+
+        if (isinstance(a0, tuple) and isinstance(a1, tuple)) or (
+            isinstance(a0, list) and isinstance(a1, list)
+        ):
+            results = []
+            for el0, el1 in zip(a0, a1):
+                new_args = (el0, el1, *a_other)
+                results.append(inner(*new_args, **kwargs))
+            return results
+
+        elif isinstance(a0, torch.Tensor) and isinstance(a1, torch.Tensor):
+            if a0.is_quantized:
+                a0 = a0.dequantize()
+            if a1.is_quantized:
+                a1 = a1.dequantize()
+
+        # for the purposes of this util, only handle floats
+        if a0.dtype != torch.float or a1.dtype != torch.float:
+            return None
+
+        new_args = (a0, a1, *a_other)
+        return f(*new_args, **kwargs)
+
+    return inner
+
+
+@maybe_dequantize_first_two_tensor_args_and_handle_tuples
+def compute_sqnr(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+    """
+    Computes the SQNR between `x` and `y`.
+
+    Args:
+        x: Tensor or tuple of tensors
+        y: Tensor or tuple of tensors
+
+    Return:
+        float or tuple of floats
+    """
+    Ps = torch.norm(x)
+    Pn = torch.norm(x - y)
+    return 20 * torch.log10(Ps / Pn)
+
+
+@maybe_dequantize_first_two_tensor_args_and_handle_tuples
+def compute_normalized_l2_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+    """
+    Computes the normalized L2 error between `x` and `y`.
+
+    Args:
+        x: Tensor or tuple of tensors
+        y: Tensor or tuple of tensors
+
+    Return:
+        float or tuple of floats
+    """
+    return torch.sqrt(((x - y) ** 2).sum() / (x**2).sum())
+
+
+@maybe_dequantize_first_two_tensor_args_and_handle_tuples
+def compute_cosine_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+    """
+    Computes the cosine similarity between `x` and `y`.
+
+    Args:
+        x: Tensor or tuple of tensors
+        y: Tensor or tuple of tensors
+
+    Return:
+        float or tuple of floats
+    """
+    # For convolutions, the shape of the quantized weight has one additional
+    # dimension compared to the shape of the fp32 weight. Match the shapes
+    # to enable cosine similarity comparison.
+    x = x.reshape(1, -1)
+    y = y.reshape(1, -1)
+    return torch.nn.functional.cosine_similarity(x, y)
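+
+
+# Illustrative sketch added for exposition; not part of the upstream torch
+# module: comparing a tensor against a slightly perturbed copy with the three
+# metrics defined above.
+def _example_compare_tensors() -> None:
+    torch.manual_seed(0)
+    x = torch.randn(4, 4)
+    y = x + 0.01 * torch.randn(4, 4)
+    sqnr = compute_sqnr(x, y)  # higher means closer
+    l2_err = compute_normalized_l2_error(x, y)  # lower means closer
+    cos_sim = compute_cosine_similarity(x, y)  # closer to 1 means closer
+    assert sqnr > 0 and l2_err < 1 and float(cos_sim) > 0.9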
+
+
+def op_type_supports_shadowing(node: Node) -> bool:
+    if node.op == "call_function":
+        if node.target in (
+            torch.add,
+            torch.mul,
+            operator.add,
+            operator.mul,
+            torch.cat,
+            torch.stack,
+        ):
+            # shadowing for ops with multiple tensor inputs is not implemented yet
+            return False
+    return True
+
+
+def get_normalized_nth_input(node: Node, gm: GraphModule, idx: int) -> Node:
+    """
+    Given a node, gets the n'th input to that node, normalizing
+    args and kwargs to the best of its ability.
+    """
+    try:
+        norm_args_and_kwargs = node.normalized_arguments(
+            gm, normalize_to_only_use_kwargs=True
+        )
+        if norm_args_and_kwargs is not None:
+            norm_args, norm_kwargs = norm_args_and_kwargs
+            assert len(norm_args) + len(norm_kwargs) > idx
+            if idx < len(norm_args):
+                return norm_args[idx]
+            else:
+                # note: in Python 3.7+ dicts are ordered
+                return list(norm_kwargs.values())[idx]
+        else:
+            assert len(node.args) + len(node.kwargs) > idx
+            if idx < len(node.args):
+                return node.args[idx]  # type: ignore[return-value]
+            else:
+                kwargs_idx = idx + len(node.args)
+                return list(node.kwargs.values())[kwargs_idx]  # type: ignore[return-value]
+    except RuntimeError:
+        # this RuntimeError happens when node argument normalization
+        # requires typehints to proceed, such as for torch.add where
+        # either the first, second or both arguments could be tensors
+        assert len(node.args) + len(node.kwargs) > idx
+        if idx < len(node.args):
+            return node.args[idx]  # type: ignore[return-value]
+        else:
+            kwargs_idx = idx + len(node.args)
+            return list(node.kwargs.values())[kwargs_idx]  # type: ignore[return-value]
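+
+
+# Illustrative sketch added for exposition; not part of the upstream torch
+# module: the first input of a node is returned whether it was passed
+# positionally or as a keyword argument. The tiny module below is a placeholder.
+def _example_get_normalized_nth_input() -> None:
+    from torch.fx import symbolic_trace
+
+    class _M(nn.Module):
+        def forward(self, x):
+            return torch.nn.functional.relu(input=x)
+
+    gm = symbolic_trace(_M())
+    relu_node = next(n for n in gm.graph.nodes if n.op == "call_function")
+    first_input = get_normalized_nth_input(relu_node, gm, 0)
+    assert isinstance(first_input, Node) and first_input.op == "placeholder"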
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/ns/fx/weight_utils.py b/.venv/lib/python3.12/site-packages/torch/ao/ns/fx/weight_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..52cb13c1286ea4c799b56d7364549f695d99a858
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/ns/fx/weight_utils.py
@@ -0,0 +1,284 @@
+from typing import Callable, Optional
+
+import torch
+import torch.ao.nn.intrinsic as nni
+import torch.ao.nn.intrinsic.qat as nniqat
+import torch.ao.nn.intrinsic.quantized as nniq
+import torch.ao.nn.qat as nnqat
+import torch.ao.nn.quantized as nnq
+import torch.ao.nn.quantized.dynamic as nnqd
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.fx import GraphModule
+from torch.fx.graph import Node
+
+from .ns_types import NSSingleResultType, NSSingleResultValuesType
+from .utils import get_target_type_str, getattr_from_fqn, return_first_non_observer_node
+
+
+toq = torch.ops.quantized
+
+
+def mod_weight_detach(mod: nn.Module) -> torch.Tensor:
+    return mod.weight.detach()  # type: ignore[operator]
+
+
+def mod_0_weight_detach(mod: nn.Module) -> torch.Tensor:
+    return mod[0].weight.detach()  # type: ignore[index]
+
+
+def mod_weight_bias_0(mod: nn.Module) -> torch.Tensor:
+    return mod._weight_bias()[0]  # type: ignore[operator]
+
+
+def get_lstm_weight(mod: nn.Module) -> list[torch.Tensor]:
+    res = []
+    for idx, param_name in enumerate(mod._flat_weights_names):  # type: ignore[arg-type]
+        if "weight_ih_l" in param_name or "weight_hh_l" in param_name:
+            param_value = mod._flat_weights[idx].detach()  # type: ignore[index,union-attr]
+            res.append(param_value)
+    return res
+
+
+def get_qlstm_weight(mod: nn.Module) -> list[torch.Tensor]:
+    res = []
+    for weight_value in mod._all_weight_values:  # type: ignore[union-attr]
+        res.append(weight_value.param.__getstate__()[0][4][0].__getstate__()[0][0])
+        res.append(weight_value.param.__getstate__()[0][4][1].__getstate__()[0][0])
+    return res
+
+
+def get_conv_mod_weight(mod: nn.Module) -> torch.Tensor:
+    if isinstance(mod, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+        return mod.weight.detach()
+    elif isinstance(mod, (nni.ConvReLU1d, nni.ConvReLU2d, nni.ConvReLU3d)):
+        return mod[0].weight.detach()  # type: ignore[operator]
+    else:
+        return mod._weight_bias()[0]  # type: ignore[operator]
+
+
+def get_linear_mod_weight(mod: nn.Module) -> torch.Tensor:
+    if isinstance(mod, nn.Linear):
+        return mod.weight.detach()
+    elif isinstance(mod, nni.LinearReLU):
+        return mod[0].weight.detach()  # type: ignore[operator]
+    else:
+        return mod._weight_bias()[0]  # type: ignore[operator]
+
+
+def get_lstm_mod_weights(mod: nn.Module) -> list[torch.Tensor]:
+    # TODO(future PR): make more generic, handle everything
+    if isinstance(mod, nn.LSTM):
+        res = []
+        for idx, param_name in enumerate(mod._flat_weights_names):
+            if "weight_ih_l" in param_name or "weight_hh_l" in param_name:
+                param_value = mod._flat_weights[idx].detach()  # type: ignore[index,union-attr]
+                res.append(param_value)
+        return res
+    else:
+        assert isinstance(mod, nnqd.LSTM), f"type {type(mod)} not handled yet"
+        res = []
+        for weight_value in mod._all_weight_values:
+            res.append(
+                weight_value.param.__getstate__()[0][4][0].__getstate__()[0][0]  # type: ignore[index]
+            )
+            res.append(
+                weight_value.param.__getstate__()[0][4][1].__getstate__()[0][0]  # type: ignore[index]
+            )
+        return res
+
+
+def get_conv_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
+    # traverse backwards from the weight arg, accounting for any observers
+    weight_arg_node = node.args[1]
+    assert isinstance(weight_arg_node, Node)
+    weight_node = return_first_non_observer_node(weight_arg_node, gm)
+    assert isinstance(weight_node, Node)
+    assert weight_node.op == "get_attr"
+    weight = getattr_from_fqn(gm, weight_node.target)  # type: ignore[arg-type]
+    return weight.detach()
+
+
+def get_qconv_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
+    # qconv state is arg 1
+    qconv_state_node = node.args[1]
+    assert isinstance(qconv_state_node, Node)
+    assert qconv_state_node.op == "get_attr"
+    qconv_state_obj = getattr_from_fqn(gm, qconv_state_node.target)  # type: ignore[arg-type]
+    return qconv_state_obj.weight()
+
+
+def get_linear_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
+    # traverse backwards from the weight arg, accounting for any observers
+    # supported patterns:
+    # weight -> obs -> linear
+    # weight -> to(torch.float16) -> dequantize -> linear
+    linear_second_arg = node.args[1]
+    assert isinstance(linear_second_arg, Node)
+
+    if linear_second_arg.op == "call_module":
+        # weight -> obs -> linear
+        weight_arg_node = node.args[1]
+        assert isinstance(weight_arg_node, Node)
+        weight_node = weight_arg_node.args[0]
+        assert isinstance(weight_node, Node)
+        assert weight_node.op == "get_attr"
+        weight = getattr_from_fqn(gm, weight_node.target)  # type: ignore[arg-type]
+        return weight.detach()
+    elif linear_second_arg.op == "call_method":
+        # weight -> to(torch.float16) -> dequantize -> linear
+        assert linear_second_arg.op == "call_method"
+        dequant_node = node.args[1]
+        assert isinstance(dequant_node, Node)
+        to_fp16_node = dequant_node.args[0]
+        assert isinstance(to_fp16_node, Node)
+        # extract the dtype, so we can cast to it before returning
+        target_dtype = to_fp16_node.args[1]
+        weight_node = to_fp16_node.args[0]
+        assert isinstance(weight_node, Node)
+        assert weight_node.op == "get_attr"
+        weight = getattr_from_fqn(gm, weight_node.target)  # type: ignore[arg-type]
+        # return the weight with fp16 cast
+        return weight.detach().to(target_dtype)
+    else:
+        assert linear_second_arg.op == "get_attr"
+        weight = getattr_from_fqn(gm, linear_second_arg.target)  # type: ignore[arg-type]
+        return weight.detach()
+
+
+def get_qlinear_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
+    # packed weight is arg 1
+    packed_weight_node = node.args[1]
+    assert isinstance(packed_weight_node, Node)
+    assert packed_weight_node.op == "get_attr"
+    packed_weight = getattr_from_fqn(gm, packed_weight_node.target)  # type: ignore[arg-type]
+    # TODO(future PR): why does packed_weight.unpack() not work?
+    (weight, _bias), _name = packed_weight.__getstate__()
+    return weight
+
+
+def get_op_to_type_to_weight_extraction_fn() -> dict[str, dict[Callable, Callable]]:
+    op_to_type_to_weight_extraction_fn: dict[str, dict[Callable, Callable]] = {
+        "call_module": {
+            # Conv1d
+            nn.Conv1d: mod_weight_detach,
+            nni.ConvReLU1d: mod_0_weight_detach,
+            nnq.Conv1d: mod_weight_bias_0,
+            nnqat.Conv1d: mod_weight_detach,
+            nniqat.ConvBn1d: mod_weight_detach,
+            nniqat.ConvBnReLU1d: mod_weight_detach,
+            nniqat.ConvReLU1d: mod_weight_detach,
+            nniq.ConvReLU1d: mod_weight_bias_0,
+            # Conv2d
+            nn.Conv2d: mod_weight_detach,
+            nni.ConvReLU2d: mod_0_weight_detach,
+            nnq.Conv2d: mod_weight_bias_0,
+            nnqat.Conv2d: mod_weight_detach,
+            nniqat.ConvBn2d: mod_weight_detach,
+            nniqat.ConvBnReLU2d: mod_weight_detach,
+            nniqat.ConvReLU2d: mod_weight_detach,
+            nniq.ConvReLU2d: mod_weight_bias_0,
+            # Conv3d
+            nn.Conv3d: mod_weight_detach,
+            nni.ConvReLU3d: mod_0_weight_detach,
+            nnq.Conv3d: mod_weight_bias_0,
+            nnqat.Conv3d: mod_weight_detach,
+            nniqat.ConvBn3d: mod_weight_detach,
+            nniqat.ConvBnReLU3d: mod_weight_detach,
+            nniqat.ConvReLU3d: mod_weight_detach,
+            nniq.ConvReLU3d: mod_weight_bias_0,
+            # Linear
+            nn.Linear: mod_weight_detach,
+            nnq.Linear: mod_weight_bias_0,
+            nni.LinearReLU: mod_0_weight_detach,
+            nniq.LinearReLU: mod_weight_bias_0,
+            nnqat.Linear: mod_weight_detach,
+            nnqd.Linear: mod_weight_bias_0,
+            nniqat.LinearReLU: mod_weight_detach,
+            nniqat.LinearBn1d: mod_weight_detach,
+            nn.modules.linear.NonDynamicallyQuantizableLinear: mod_weight_detach,
+            # LSTM
+            nn.LSTM: get_lstm_weight,
+            nnqd.LSTM: get_qlstm_weight,
+        },
+        "call_function": {
+            # Conv
+            F.conv1d: get_conv_fun_weight,
+            F.conv2d: get_conv_fun_weight,
+            F.conv3d: get_conv_fun_weight,
+            toq.conv1d: get_qconv_fun_weight,
+            toq.conv2d: get_qconv_fun_weight,
+            toq.conv3d: get_qconv_fun_weight,
+            toq.conv1d_relu: get_qconv_fun_weight,
+            toq.conv2d_relu: get_qconv_fun_weight,
+            toq.conv3d_relu: get_qconv_fun_weight,
+            # Linear
+            F.linear: get_linear_fun_weight,
+            toq.linear: get_qlinear_fun_weight,
+            toq.linear_relu: get_qlinear_fun_weight,
+        },
+    }
+
+    return op_to_type_to_weight_extraction_fn
+
+
+def extract_weight_from_node(
+    node: Node,
+    gm: GraphModule,
+    op_to_type_to_weight_extraction_fn: Optional[
+        dict[str, dict[Callable, Callable]]
+    ] = None,
+) -> Optional[NSSingleResultType]:
+    res_type = NSSingleResultValuesType.WEIGHT.value
+
+    # Not all graphmodules have _node_name_to_scope, so only fill it
+    # out if it exists.
+    fqn = None
+    if hasattr(gm, "_node_name_to_scope"):
+        fqn = gm._node_name_to_scope[node.name][0]  # type: ignore[index]
+
+    if op_to_type_to_weight_extraction_fn is None:
+        op_to_type_to_weight_extraction_fn = get_op_to_type_to_weight_extraction_fn()
+
+    ref_node_type = get_target_type_str(node, gm)
+    # for extracting weights, these are always the same
+    prev_node_type = ref_node_type
+
+    if node.op == "call_function":
+        function_mapping = op_to_type_to_weight_extraction_fn["call_function"]
+        for target_fn_type, weight_extraction_fn in function_mapping.items():
+            if node.target == target_fn_type:
+                weight = weight_extraction_fn(node, gm)
+                return {
+                    "type": res_type,
+                    "values": [weight],
+                    "prev_node_name": node.name,
+                    "prev_node_target_type": prev_node_type,
+                    "ref_node_name": node.name,
+                    "ref_node_target_type": ref_node_type,
+                    "index_within_arg": 0,
+                    "index_of_arg": 0,
+                    "fqn": fqn,
+                }
+
+    elif node.op == "call_module":
+        # for call_module, we need to look up the modules to do the type check
+        assert isinstance(node.target, str)
+        mod = getattr_from_fqn(gm, node.target)
+        module_mapping = op_to_type_to_weight_extraction_fn["call_module"]
+        for target_mod_type, weight_extraction_fn in module_mapping.items():
+            if type(mod) == target_mod_type:
+                weight = weight_extraction_fn(mod)
+                return {
+                    "type": res_type,
+                    "values": [weight],
+                    "prev_node_name": node.name,
+                    "prev_node_target_type": prev_node_type,
+                    "ref_node_name": node.name,
+                    "ref_node_target_type": ref_node_type,
+                    "index_within_arg": 0,
+                    "index_of_arg": 0,
+                    "fqn": fqn,
+                }
+
+    return None
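+
+
+# Illustrative sketch added for exposition; not part of the upstream torch
+# module: extracting the weight recorded for the call_module linear node of a
+# small traced model. The module below is a placeholder.
+def _example_extract_weight_from_node() -> None:
+    from torch.fx import symbolic_trace
+
+    class _M(nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.linear = nn.Linear(2, 2)
+
+        def forward(self, x):
+            return self.linear(x)
+
+    gm = symbolic_trace(_M())
+    linear_node = next(n for n in gm.graph.nodes if n.op == "call_module")
+    result = extract_weight_from_node(linear_node, gm)
+    assert result is not None
+    assert result["type"] == NSSingleResultValuesType.WEIGHT.value
+    assert result["values"][0].shape == (2, 2)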
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/__init__.py b/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/pruning/scheduler/__init__.py b/.venv/lib/python3.12/site-packages/torch/ao/pruning/scheduler/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/pruning/scheduler/base_scheduler.py b/.venv/lib/python3.12/site-packages/torch/ao/pruning/scheduler/base_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..f602028d475ce7b60c64bb953e3794a927283c75
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/pruning/scheduler/base_scheduler.py
@@ -0,0 +1,170 @@
+# mypy: allow-untyped-defs
+
+import warnings
+import weakref
+from functools import wraps
+
+from torch.ao.pruning.sparsifier.base_sparsifier import BaseSparsifier
+
+
+__all__ = ["BaseScheduler"]
+
+
+class BaseScheduler:
+    def __init__(self, sparsifier, last_epoch=-1, verbose=False):
+        # Attach sparsifier
+        if not isinstance(sparsifier, BaseSparsifier):
+            raise TypeError(
+                f"{type(sparsifier).__name__} is not an instance of torch.ao.pruning.BaseSparsifier"
+            )
+        self.sparsifier = sparsifier
+
+        # Initialize epoch and base sparsity levels
+
+        self.base_sl = [group["sparsity_level"] for group in sparsifier.groups]
+        self.last_epoch = last_epoch
+
+        # Following https://github.com/pytorch/pytorch/issues/20124
+        # We would like to ensure that `scheduler.step()` is called after
+        # `sparsifier.step()`
+        def with_counter(method):
+            if getattr(method, "_with_counter", False):
+                # `sparsifier.step()` has already been replaced, return.
+                return method
+
+            # Keep a weak reference to the sparsifier instance to prevent
+            # cyclic references.
+            instance_ref = weakref.ref(method.__self__)
+            # Get the unbound method for the same purpose.
+            func = method.__func__
+            cls = instance_ref().__class__
+            del method
+
+            @wraps(func)
+            def wrapper(*args, **kwargs):
+                instance = instance_ref()
+                instance._step_count += 1  # type: ignore[union-attr]
+                wrapped = func.__get__(instance, cls)
+                return wrapped(*args, **kwargs)
+
+            # Note that the returned function here is no longer a bound method,
+            # so attributes like `__func__` and `__self__` no longer exist.
+            wrapper._with_counter = True  # type: ignore[attr-defined]
+            return wrapper
+
+        self.sparsifier.step = with_counter(self.sparsifier.step)  # type: ignore[assignment]
+        self.sparsifier._step_count = 0  # type: ignore[attr-defined]
+        self._step_count: int = 0
+        self.verbose = verbose
+
+        # Housekeeping
+        self._get_sl_called_within_step: bool = False
+
+        self.step()
+
+    def state_dict(self):
+        """Returns the state of the scheduler as a :class:`dict`.
+
+        It contains an entry for every variable in self.__dict__ which
+        is not the sparsifier.
+        """
+        return {
+            key: value for key, value in self.__dict__.items() if key != "sparsifier"
+        }
+
+    def load_state_dict(self, state_dict):
+        """Loads the schedulers state.
+
+        Args:
+            state_dict (dict): scheduler state. Should be an object returned
+                from a call to :meth:`state_dict`.
+        """
+        self.__dict__.update(state_dict)
+
+    def get_last_sl(self):
+        """Return last computed sparsity level by current scheduler."""
+        return self._last_sl
+
+    def get_sl(self):
+        # Compute sparsity level using chainable form of the scheduler
+        # Note: This method is not intended to be called directly, and is only
+        #       used by the ".step" method. Use .get_last_sl() instead.
+        if not self._get_sl_called_within_step:
+            warnings.warn(
+                "To get the last sparsity level computed by the scheduler, "
+                "please use `get_last_sl()`."
+            )
+        raise NotImplementedError
+
+    def print_sl(self, is_verbose, group, sl, epoch=None):
+        """Display the current sparsity level."""
+        if is_verbose:
+            if epoch is None:
+                print(f"Adjusting sparsity level of group {group} to {sl:.4e}.")
+            else:
+                print(
+                    f"Epoch {epoch:5d}: adjusting sparsity level of group {group} to {sl:.4e}."
+                )
+
+    def __repr__(self):
+        format_string = self.__class__.__name__ + " ("
+        format_string += "\n"
+        format_string += f"Sparsifier {self.sparsifier}\n"
+        format_string += f"    base_sl: {self.base_sl}\n"
+        format_string += ")"
+        return format_string
+
+    def step(self, epoch=None):
+        # Raise warning if trying to call scheduler step before the sparsifier.
+        # https://github.com/pytorch/pytorch/issues/20124
+        if self._step_count == 1:
+            if not hasattr(self.sparsifier.step, "_with_counter"):
+                warnings.warn(
+                    "Seems like `sparsifier.step()` has been overridden after sparsity scheduler "
+                    "initialization. Please, make sure to call `sparsifier.step()` before "
+                    "`scheduler.step()`.",
+                    UserWarning,
+                )
+
+            # Just check whether two scheduler.step() calls happened before the first sparsifier.step()
+            elif self.sparsifier._step_count < 1:  # type: ignore[attr-defined]
+                warnings.warn(
+                    "Detected call of `scheduler.step()` before `sparsifier.step()`. "
+                    "You have to make sure you run the sparsifier.step() BEFORE any "
+                    "calls to the scheduler.step().",
+                    UserWarning,
+                )
+        self._step_count += 1
+
+        class _enable_get_sl_call:
+            def __init__(self, o):
+                self.o = o
+
+            def __enter__(self):
+                self.o._get_sl_called_within_step = True
+                return self
+
+            def __exit__(self, type, value, traceback):
+                self.o._get_sl_called_within_step = False
+
+        with _enable_get_sl_call(self):
+            self.last_epoch += 1
+            values = self.get_sl()
+
+        for i, data in enumerate(zip(self.sparsifier.groups, values)):
+            param_group, sl = data
+            param_group["sparsity_level"] = sl
+            self.print_sl(self.verbose, i, sl, epoch)
+
+        self._last_sl = [group["sparsity_level"] for group in self.sparsifier.groups]
+        self.sparsifier.enable_mask_update = True
+
+    def _make_sure_a_list(self, var):
+        r"""Utility that extends it to the same length as the .groups, ensuring it is a list"""
+        n = len(self.sparsifier.groups)
+        if not isinstance(var, (list, tuple)):
+            return [var] * n
+        else:
+            if len(var) != n:
+                raise ValueError(f"Expected variable of length {n}, but got {len(var)}")
+            return list(var)  # We want the result to be in a list, not tuple
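A minimal sketch of how a concrete scheduler hooks into BaseScheduler, assuming a sparsifier has already been prepared elsewhere; the LinearRampSL class and num_epochs argument below are illustrative names, not part of the vendored module.

from torch.ao.pruning.scheduler.base_scheduler import BaseScheduler

class LinearRampSL(BaseScheduler):
    # Illustrative only: ramp each group linearly from 0 to its configured
    # sparsity_level over `num_epochs` scheduler steps.
    def __init__(self, sparsifier, num_epochs=10, last_epoch=-1, verbose=False):
        self.num_epochs = num_epochs
        super().__init__(sparsifier, last_epoch, verbose)

    def get_sl(self):
        # base_sl holds each group's target sparsity, captured at init time
        frac = min(1.0, max(0.0, self.last_epoch / self.num_epochs))
        return [sl * frac for sl in self.base_sl]

# Typical training loop ordering (sparsifier.step() first, then scheduler.step()):
#     for epoch in range(num_epochs):
#         train_one_epoch(...)
#         sparsifier.step()
#         scheduler.step()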
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/pruning/scheduler/cubic_scheduler.py b/.venv/lib/python3.12/site-packages/torch/ao/pruning/scheduler/cubic_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..45985a8bbc524b3b1929b439ee100fef8ea23a1c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/pruning/scheduler/cubic_scheduler.py
@@ -0,0 +1,113 @@
+# mypy: allow-untyped-defs
+import warnings
+
+from .base_scheduler import BaseScheduler
+
+
+__all__ = ["CubicSL"]
+
+
+def _clamp(x, lo, hi):
+    return max(lo, min(hi, x))
+
+
+class CubicSL(BaseScheduler):
+    r"""Sets the sparsity level of each parameter group to the final sl
+    plus a given exponential function.
+
+    .. math::
+
+        s_i = s_f + (s_0 - s_f) \cdot \left( 1 - \frac{t - t_0}{n\Delta t} \right)^3
+
+    where :math:`s_i` is the sparsity at epoch :math:`t`, :math;`s_f` is the final
+    sparsity level, :math:`f(i)` is the function to be applied to the current epoch
+    :math:`t`, initial epoch :math:`t_0`, and final epoch :math:`t_f`.
+    :math:`\Delta t` is used to control how often the update of the sparsity level
+    happens. By default,
+
+    Args:
+        sparsifier (BaseSparsifier): Wrapped sparsifier.
+        init_sl (float, list): Initial level of sparsity
+        init_t (int, list): Initial step, when pruning starts
+        delta_t (int, list): Pruning frequency
+        total_t (int, list): Total number of pruning steps
+        initially_zero (bool, list): If True, sets the level of sparsity to 0
+            before init_t (:math:`t_0`). Otherwise, the sparsity level before
+            init_t (:math:`t_0`) is set to init_sl(:math:`s_0`)
+        last_epoch (int): The index of last epoch. Default: -1.
+        verbose (bool): If ``True``, prints a message to stdout for
+            each update. Default: ``False``.
+    """
+
+    def __init__(
+        self,
+        sparsifier,
+        init_sl=0.0,
+        init_t=0,
+        delta_t=10,
+        total_t=100,
+        initially_zero=False,
+        last_epoch=-1,
+        verbose=False,
+    ):
+        self.sparsifier = sparsifier
+
+        self.init_sl = self._make_sure_a_list(init_sl)
+        self.init_t = self._make_sure_a_list(init_t)
+        self.delta_t = self._make_sure_a_list(delta_t)
+        self.total_t = self._make_sure_a_list(total_t)
+
+        self.initially_zero = self._make_sure_a_list(initially_zero)
+
+        super().__init__(sparsifier, last_epoch, verbose)
+
+    @staticmethod
+    def sparsity_compute_fn(s_0, s_f, t, t_0, dt, n, initially_zero=False):
+        r""" "Computes the current level of sparsity.
+
+        Based on https://arxiv.org/pdf/1710.01878.pdf
+
+        Args:
+            s_0: Initial level of sparsity, :math:`s_i`
+            s_f: Target level of sparsity, :math:`s_f`
+            t: Current step, :math:`t`
+            t_0: Initial step, :math:`t_0`
+            dt: Pruning frequency, :math:`\Delta T`
+            n: Pruning steps, :math:`n`
+            initially_zero: Sets the level of sparsity to 0 before t_0.
+                If False, sets to s_0
+
+        Returns:
+            The sparsity level :math:`s_t` at the current step :math:`t`
+        """
+        if initially_zero and t < t_0:
+            return 0
+        s_t = s_f + (s_0 - s_f) * (1.0 - (t - t_0) / (dt * n)) ** 3
+        s_t = _clamp(s_t, s_0, s_f)
+        return s_t
+
+    def get_sl(self):
+        if not self._get_sl_called_within_step:
+            warnings.warn(
+                "To get the last sparsity level computed by the scheduler, "
+                "please use `get_last_sl()`."
+            )
+        return [
+            self.sparsity_compute_fn(
+                s_0=initial_sparsity,
+                s_f=final_sparsity,
+                t=self.last_epoch,
+                t_0=initial_epoch,
+                dt=delta_epoch,
+                n=interval_epochs,
+                initially_zero=initially_zero,
+            )
+            for initial_sparsity, final_sparsity, initial_epoch, delta_epoch, interval_epochs, initially_zero in zip(
+                self.init_sl,
+                self.base_sl,
+                self.init_t,
+                self.delta_t,
+                self.total_t,
+                self.initially_zero,
+            )
+        ]
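As a quick numerical sanity check of the cubic formula above, the staticmethod can be evaluated directly; the argument values below are arbitrary illustrations, not taken from upstream tests.

from torch.ao.pruning.scheduler.cubic_scheduler import CubicSL

# s_0=0.0, s_f=0.8, t_0=0, dt=1, n=10: the level ramps cubically over 10 steps
for t in (0, 2, 5, 10):
    s_t = CubicSL.sparsity_compute_fn(
        s_0=0.0, s_f=0.8, t=t, t_0=0, dt=1, n=10, initially_zero=False
    )
    print(t, round(s_t, 4))
# t=0  -> 0.0      (clamped at s_0)
# t=2  -> 0.3904   (0.8 - 0.8 * 0.8**3)
# t=5  -> 0.7      (0.8 - 0.8 * 0.5**3)
# t=10 -> 0.8      (reaches s_f)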
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/pruning/scheduler/lambda_scheduler.py b/.venv/lib/python3.12/site-packages/torch/ao/pruning/scheduler/lambda_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..a25e9a1d67bc759b0eb88ae8a7b1f44455466a95
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/pruning/scheduler/lambda_scheduler.py
@@ -0,0 +1,55 @@
+# mypy: allow-untyped-defs
+import warnings
+
+from .base_scheduler import BaseScheduler
+
+
+__all__ = ["LambdaSL"]
+
+
+class LambdaSL(BaseScheduler):
+    """Sets the sparsity level of each parameter group to the final sl
+    times a given function. When last_epoch=-1, sets initial sl as zero.
+    Args:
+        sparsifier (BaseSparsifier): Wrapped sparsifier.
+        sl_lambda (function or list): A function which computes a multiplicative
+            factor given an integer parameter epoch, or a list of such
+            functions, one for each group in sparsifier.param_groups.
+        last_epoch (int): The index of last epoch. Default: -1.
+        verbose (bool): If ``True``, prints a message to stdout for
+            each update. Default: ``False``.
+    Example:
+        >>> # Assuming sparsifier has two groups.
+        >>> lambda1 = lambda epoch: epoch // 30
+        >>> lambda2 = lambda epoch: 0.95**epoch
+        >>> # xdoctest: +SKIP
+        >>> scheduler = LambdaSL(sparsifier, sl_lambda=[lambda1, lambda2])
+        >>> for epoch in range(100):
+        >>>     train(...)
+        >>>     validate(...)
+        >>>     scheduler.step()
+    """
+
+    def __init__(self, sparsifier, sl_lambda, last_epoch=-1, verbose=False):
+        self.sparsifier = sparsifier
+
+        if not isinstance(sl_lambda, list) and not isinstance(sl_lambda, tuple):
+            self.sl_lambdas = [sl_lambda] * len(sparsifier.groups)
+        else:
+            if len(sl_lambda) != len(sparsifier.groups):
+                raise ValueError(
+                    f"Expected {len(sparsifier.groups)} lr_lambdas, but got {len(sl_lambda)}"
+                )
+            self.sl_lambdas = list(sl_lambda)
+        super().__init__(sparsifier, last_epoch, verbose)
+
+    def get_sl(self):
+        if not self._get_sl_called_within_step:
+            warnings.warn(
+                "To get the last sparsity level computed by the scheduler, "
+                "please use `get_last_sl()`."
+            )
+        return [
+            base_sl * lmbda(self.last_epoch)
+            for lmbda, base_sl in zip(self.sl_lambdas, self.base_sl)
+        ]
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/__init__.py b/.venv/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/base_sparsifier.py b/.venv/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/base_sparsifier.py
new file mode 100644
index 0000000000000000000000000000000000000000..73d4c283da63267f644dd59930e9d0fff21e472a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/base_sparsifier.py
@@ -0,0 +1,353 @@
+# mypy: allow-untyped-defs
+import abc
+import copy
+from collections import defaultdict
+from typing import Any, Optional
+
+import torch
+from torch import nn
+from torch.nn.utils import parametrize
+from torch.nn.utils.parametrize import type_before_parametrizations
+
+from .utils import (
+    FakeSparsity,
+    get_arg_info_from_tensor_fqn,
+    module_contains_param,
+    module_to_fqn,
+    swap_module,
+)
+
+
+__all__ = ["BaseSparsifier"]
+
+SUPPORTED_MODULES = {nn.Linear}
+
+KEYS_NOT_IN_STATE_DICT = ["module", "module_fqn", "tensor_name"]
+
+
+# TODO update desc with new config args
+class BaseSparsifier(abc.ABC):
+    r"""Base class for all sparsifiers.
+
+    Abstract methods that need to be implemented:
+
+    - update_mask: Function to compute a new mask for all keys in the
+        `groups`.
+
+    Args:
+        - model [nn.Module]: model to configure. The model itself is not saved
+            but used for the state_dict saving / loading.
+        - config [list]: configuration elements should be a dict map that includes
+            `tensor_fqn` of tensors to sparsify
+        - defaults [dict]: default configurations will be attached to the
+            configuration. Only the keys that don't exist in the `config` will
+            be updated.
+
+    Example::
+
+        >>> # xdoctest: +SKIP("Can't instantiate abstract class BaseSparsifier with abstract method update_mask")
+        >>> config = [{'tensor_fqn': 'layer1.weight'}, {'tensor_fqn': 'linear2.weight2', 'sparsity_level': 0.5}]
+        >>> defaults = {'sparsity_level': 0.7}
+        >>> # model.layer1.weight will have `sparsity_level` = 0.7 (getting default)
+        >>> sparsifier = BaseSparsifier(config, defaults)
+    """
+
+    def __init__(self, defaults: Optional[dict[str, Any]] = None):
+        super().__init__()
+        self.defaults: dict[str, Any] = defaults or {}
+
+        self.state: dict[str, dict] = defaultdict(dict)
+        self.groups: list[dict[str, Any]] = []
+        self.enable_mask_update = True
+
+    def __getstate__(self) -> dict[str, Any]:
+        return {
+            "defaults": self.defaults,
+            "state": self.state,
+            "groups": self.groups,
+        }
+
+    def __setstate__(self, state: dict[str, dict[str, Any]]) -> None:
+        self.__dict__.update(state)
+
+    def __repr__(self):
+        format_string = self.__class__.__name__ + " ("
+        for i, sparse_args in enumerate(self.groups):
+            module = sparse_args["module"]
+            format_string += "\n"
+            format_string += f"\tGroup {i}\n"
+            format_string += f"\t    module: {module}\n"
+            for key in sorted(sparse_args.keys()):
+                if key == "module":
+                    continue
+                format_string += f"\t    {key}: {sparse_args[key]}\n"
+        format_string += ")"
+        return format_string
+
+    def state_dict(self) -> dict[str, Any]:
+        r"""Returns the state of the optimizer as a :class:`dict`.
+
+        It contains:
+        * state - current state of the sparsification.
+        * groups - a list containing all sparsity configuration groups
+            with the key 'tensor_fqn' specifying the path to the sparsified tensor within a model
+
+        TODO: Need a clean way of loading the state of the "prepared" module
+        """
+
+        groups: list[dict[str, Any]] = [
+            dict(
+                filter(
+                    lambda key_value: key_value[0] not in KEYS_NOT_IN_STATE_DICT,
+                    mg.items(),
+                )
+            )
+            for mg in self.groups
+        ]
+
+        return {
+            "state": self.state,
+            "groups": groups,
+        }
+
+    def load_state_dict(self, state_dict: dict[str, Any], strict: bool = True):
+        groups = copy.deepcopy(state_dict["groups"])
+        states = state_dict["state"]
+        for tensor_fqn, s in states.items():
+            arg_info = get_arg_info_from_tensor_fqn(self.model, tensor_fqn)
+            module = arg_info["module"]
+            tensor_name = arg_info["tensor_name"]
+            if strict and module is None:
+                raise RuntimeError(f"Error loading {tensor_fqn} into the model")
+
+            found = False
+            for p in module.parametrizations[tensor_name]:
+                if isinstance(p, FakeSparsity):
+                    found = True
+                    break
+            if not found:
+                p = FakeSparsity(torch.ones(getattr(module, tensor_name).shape))
+                parametrize.register_parametrization(module, tensor_name, p)
+            if s.get("mask", None) is not None:
+                mask = s.pop("mask")
+                p.mask = mask
+
+            for mg in groups:
+                if mg["tensor_fqn"] == tensor_fqn:
+                    mg.update(arg_info)
+        self.__setstate__({"state": states, "groups": groups})
+
+    def make_config_from_model(
+        self,
+        model: nn.Module,
+        SUPPORTED_MODULES: set[type[nn.Linear]] = SUPPORTED_MODULES,
+    ) -> None:
+        self.config = []
+        stack = [model]
+        while stack:
+            module = stack.pop()
+            for _name, child in module.named_children():
+                if type(child) in SUPPORTED_MODULES:
+                    module_fqn = module_to_fqn(model, child)
+                    assert isinstance(module_fqn, str)  # for mypy
+                    self.config.append({"tensor_fqn": module_fqn + ".weight"})
+                else:
+                    stack.append(child)
+
+    def prepare(self, model, config):
+        r"""Prepares a model, by adding the parametrizations.
+
+        Note::
+
+            The model is modified inplace. If you need to preserve the original
+            model, use copy.deepcopy.
+        """
+        self.model = model  # TODO: Need to figure out how to load without this.
+        self.config = config
+
+        # If no config -- try getting all the supported layers
+        if self.config is None:
+            self.make_config_from_model(model)
+
+        # TODO: Remove the configuration by reference ('module')
+        for module_config in self.config:
+            assert isinstance(module_config, dict), (
+                "config elements should be dicts not modules i.e.:"
+                "[{`tensor_fqn`: `foo.bar.weight`}, {`tensor_fqn`: ... }, ...]"
+            )
+
+            assert isinstance(self.defaults, dict)  # for mypy
+            local_args = copy.deepcopy(self.defaults)
+            local_args.update(module_config)
+
+            tensor_fqn = local_args.get("tensor_fqn", None)
+            assert tensor_fqn is not None, (
+                "tensor_fqn is a required argument in the sparsity config which"
+                "replaces previous `module` and [module]`fqn` arguments"
+            )
+
+            # populate all information from tensor_fqn
+            info_from_tensor_fqn = get_arg_info_from_tensor_fqn(model, tensor_fqn)
+
+            # check that whatever was put into local_args agrees with what was obtained
+            # from tensor_fqn
+            for key in info_from_tensor_fqn.keys():
+                if key in local_args:
+                    assert (
+                        info_from_tensor_fqn[key] == local_args[key]
+                        or (
+                            key == "tensor_fqn"
+                            and "." + info_from_tensor_fqn[key] == local_args[key]
+                        )
+                        # info_from_tensor_fqn will chop leading '.' from tensor_fqn so ignore that
+                    ), (
+                        f"Given both `{key}` and `tensor_fqn` in the config, it is expected them to agree!"
+                    )
+            local_args.update(info_from_tensor_fqn)
+            self.groups.append(local_args)
+        self._prepare()
+
+    def _prepare(self, *args, **kwargs):
+        r"""Adds mask parametrization to the layer weight"""
+        for config in self.groups:
+            module = config["module"]
+            tensor_name = config["tensor_name"]
+            parametrization = config.get("parametrization", FakeSparsity)
+            mask = config.get("mask", torch.ones_like(getattr(module, tensor_name)))
+            self.state[config["tensor_fqn"]]["mask"] = mask
+            parametrize.register_parametrization(
+                module, tensor_name, parametrization(mask)
+            )
+
+    def squash_mask(
+        self,
+        params_to_keep: Optional[tuple[str, ...]] = None,
+        params_to_keep_per_layer: Optional[dict[str, tuple[str, ...]]] = None,
+        *args,
+        **kwargs,
+    ):
+        r"""Squashes the sparse masks into the appropriate tensors.
+
+        If either the `params_to_keep` or `params_to_keep_per_layer` is set,
+        the module will have a `sparse_params` dict attached to it.
+
+        Args:
+            params_to_keep: List of keys to save in the module or a dict
+                            representing the modules and keys that will have
+                            sparsity parameters saved
+            params_to_keep_per_layer: Dict to specify the params that should be
+                            saved for specific layers. The keys in the dict
+                            should be the module fqn, while the values should
+                            be a list of strings with the names of the variables
+                            to save in the `sparse_params`
+
+        Examples:
+            >>> # xdoctest: +SKIP("locals are undefined")
+            >>> # Don't save any sparse params
+            >>> sparsifier.squash_mask()
+            >>> hasattr(model.submodule1, "sparse_params")
+            False
+
+            >>> # Keep sparse params per layer
+            >>> sparsifier.squash_mask(
+            ...     params_to_keep_per_layer={
+            ...         "submodule1.linear1": ("foo", "bar"),
+            ...         "submodule2.linear42": ("baz",),
+            ...     }
+            ... )
+            >>> print(model.submodule1.linear1.sparse_params)
+            {'foo': 42, 'bar': 24}
+            >>> print(model.submodule2.linear42.sparse_params)
+            {'baz': 0.1}
+
+            >>> # Keep sparse params for all layers
+            >>> sparsifier.squash_mask(params_to_keep=("foo", "bar"))
+            >>> print(model.submodule1.linear1.sparse_params)
+            {'foo': 42, 'bar': 24}
+            >>> print(model.submodule2.linear42.sparse_params)
+            {'foo': 42, 'bar': 24}
+
+            >>> # Keep some sparse params for all layers, and specific ones for
+            >>> # some other layers
+            >>> sparsifier.squash_mask(
+            ...     params_to_keep=("foo", "bar"),
+            ...     params_to_keep_per_layer={"submodule2.linear42": ("baz",)},
+            ... )
+            >>> print(model.submodule1.linear1.sparse_params)
+            {'foo': 42, 'bar': 24}
+            >>> print(model.submodule2.linear42.sparse_params)
+            {'foo': 42, 'bar': 24, 'baz': 0.1}
+        """
+        for config in self.groups:
+            module = config["module"]
+            tensor_name = config["tensor_name"]
+            parametrize.remove_parametrizations(
+                module, tensor_name, leave_parametrized=True
+            )
+            sparse_params = {}
+            if params_to_keep is not None:
+                global_params = {k: config[k] for k in params_to_keep}
+                sparse_params.update(global_params)
+            if params_to_keep_per_layer is not None:
+                params = params_to_keep_per_layer.get(config["module_fqn"], None)
+                if params is not None:
+                    per_layer_params = {k: config[k] for k in params}
+                    sparse_params.update(per_layer_params)
+            if sparse_params:
+                # TODO handle multiple tensors being sparsified on a single module; where to store sparse_params?
+                module.sparse_params = sparse_params
+
+    def convert(
+        self,
+        module: nn.Module,
+        mapping: Optional[dict[type[nn.Module], type[nn.Module]]] = None,
+        inplace: bool = False,
+        parameterization: type[nn.Module] = FakeSparsity,
+    ):
+        r"""Converts submodules in input module to a different module according to `mapping`
+        by calling `from_dense` method on the target module class
+        Args:
+            module: input module
+            mapping: a dictionary that maps from source module type to target
+                module type, can be overwritten to allow swapping user defined
+                Modules
+            inplace: carry out model transformations in-place, the original module
+                is mutated
+        """
+        if mapping is None:
+            raise NotImplementedError("Need to auto-generate the mapping")
+        if not inplace:
+            module = copy.deepcopy(module)
+
+        reassign = {}
+        for name, mod in module.named_children():
+            # leaf node
+            if (
+                module_contains_param(mod, parameterization)
+                and type_before_parametrizations(mod) in mapping
+            ):
+                reassign[name] = swap_module(mod, mapping)
+            else:
+                # recurse
+                reassign[name] = self.convert(
+                    mod,
+                    mapping=mapping,
+                    inplace=True,
+                    parameterization=parameterization,
+                )
+
+        for key, value in reassign.items():
+            module._modules[key] = value
+
+        return module
+
+    def step(self, use_path: bool = True) -> None:
+        if not self.enable_mask_update:
+            return
+        with torch.no_grad():
+            for config in self.groups:
+                self.update_mask(**config)
+
+    @abc.abstractmethod
+    def update_mask(self, module: nn.Module, tensor_name: str, **kwargs):
+        pass
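A sketch of the intended subclassing pattern, assuming a simple magnitude criterion; MagnitudeSparsifier is an illustrative name, not an upstream class, and the tiny model is only there to show the prepare/step/squash_mask flow.

import torch
from torch import nn
from torch.ao.pruning.sparsifier.base_sparsifier import BaseSparsifier

class MagnitudeSparsifier(BaseSparsifier):
    # Illustrative only: zero out roughly the `sparsity_level` fraction of
    # entries with the smallest absolute value in each configured tensor.
    def update_mask(self, module, tensor_name, sparsity_level, **kwargs):
        weight = getattr(module, tensor_name)
        mask = getattr(module.parametrizations, tensor_name)[0].mask
        k = int(sparsity_level * weight.numel())
        if k == 0:
            mask.data = torch.ones_like(mask)
            return
        threshold = weight.abs().flatten().kthvalue(k).values
        mask.data = (weight.abs() > threshold).to(mask.dtype)

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
sparsifier = MagnitudeSparsifier(defaults={"sparsity_level": 0.5})
sparsifier.prepare(model, config=[{"tensor_fqn": "0.weight"}, {"tensor_fqn": "2.weight"}])
sparsifier.step()         # computes the masks via update_mask
sparsifier.squash_mask()  # folds the masks into the weights and drops the parametrizations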
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/utils.py b/.venv/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..302f7e0b0b7c1e820ee6428189f71e82c32c7a91
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/utils.py
@@ -0,0 +1,138 @@
+# mypy: allow-untyped-defs
+from itertools import chain
+from typing import Any, Optional
+
+from torch import nn
+from torch.nn.utils.parametrize import is_parametrized, type_before_parametrizations
+
+
+__all__ = [
+    "module_contains_param",
+    "swap_module",
+    "module_to_fqn",
+    "fqn_to_module",
+    "get_arg_info_from_tensor_fqn",
+    "FakeSparsity",
+]
+
+
+def module_contains_param(module: nn.Module, parametrization: type[nn.Module]) -> bool:
+    if is_parametrized(module):
+        # see if any of the module tensors have a parametrization attached that matches the one passed in
+        return any(
+            any(isinstance(param, parametrization) for param in param_list)
+            for key, param_list in module.parametrizations.items()  # type: ignore[union-attr,operator]
+        )
+    return False
+
+
+def swap_module(
+    mod: nn.Module, mapping: dict[type[nn.Module], type[nn.Module]]
+) -> nn.Module:
+    r"""Swaps the module using from_dense according to the mapping passed in.
+    Args:
+        mod: input module
+        mapping: a dictionary that maps from nn module to sparse nn module
+    Return:
+        The corresponding sparse module of `mod` according to mapping, created using from_dense
+    """
+    if type_before_parametrizations(mod) in mapping:
+        sparse_mod = mapping[type_before_parametrizations(mod)]
+
+        # TODO Fix this typing, as Type[Module] has no attribute "from_dense"
+        new_mod = sparse_mod.from_dense(mod)  # type: ignore[attr-defined]
+
+        # Preserve module's pre forward hooks. They'll be called on quantized input
+        for pre_hook_fn in mod._forward_pre_hooks.values():
+            new_mod.register_forward_pre_hook(pre_hook_fn)
+        # Preserve module's post forward hooks except _observer_forward_hook
+        # After convert they'll work with quantized output
+        for hook_fn in mod._forward_hooks.values():
+            new_mod.register_forward_hook(hook_fn)
+
+        # respect device affinity when swapping modules
+        devices = {p.device for p in chain(mod.parameters(), mod.buffers())}
+        assert len(devices) <= 1, (
+            f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}"
+        )
+        device = next(iter(devices)) if len(devices) > 0 else None
+        if device:
+            new_mod.to(device)
+
+        return new_mod
+
+    else:
+        return mod
+
+
+def module_to_fqn(
+    model: nn.Module, module: nn.Module, prefix: str = ""
+) -> Optional[str]:
+    """
+    Returns the fqn for a module, or None if the module is not a descendant of the model.
+    """
+    if module is model:
+        return ""
+    for name, child in model.named_children():
+        fqn = module_to_fqn(child, module, ".")
+        if isinstance(fqn, str):
+            return prefix + name + fqn
+    return None
+
+
+def fqn_to_module(model: Optional[nn.Module], path: str) -> Optional[nn.Module]:
+    """
+    Given an fqn, returns the corresponding module or tensor or None if the fqn given by `path`
+    doesn't correspond to anything. Similar to model.get_submodule(path) but works for tensors.
+    """
+    if path != "":
+        for name in path.split("."):
+            model = getattr(model, name, None)
+    return model
+
+
+def get_arg_info_from_tensor_fqn(model: nn.Module, tensor_fqn: str) -> dict[str, Any]:
+    """
+    Uses tensor_fqn to obtain a dict containing module_fqn, module and tensor_name
+    """
+    # string manip to split tensor_fqn into module_fqn and tensor_name
+    # if tensor_fqn is 'weight' then module_fqn and tensor_name are '' and 'weight'
+    # if tensor_fqn is 'linear.weight' then module_fqn and tensor_name are 'linear' and 'weight'
+    tensor_name = tensor_fqn.split(".")[-1]
+    module_fqn = tensor_fqn[: -len(tensor_name) - ("." in tensor_fqn)]
+
+    module = fqn_to_module(model, module_fqn)
+
+    return {
+        "module_fqn": module_fqn,
+        "module": module,
+        "tensor_name": tensor_name,
+        "tensor_fqn": tensor_fqn,
+    }
+
+
+# Parametrizations
+class FakeSparsity(nn.Module):
+    r"""Parametrization for the weights. Should be attached to the 'weight' or
+    any other parameter that requires a mask applied to it.
+
+    Note::
+
+        Once the mask is passed in, its id should not change. The contents
+        of the mask can change, but the mask tensor itself should not be
+        replaced.
+    """
+
+    def __init__(self, mask):
+        super().__init__()
+        self.register_buffer("mask", mask)
+
+    def forward(self, x):
+        assert self.mask.shape == x.shape
+        return self.mask * x
+
+    def state_dict(self, *args, **kwargs):
+        # We don't want the parametrizations to save the mask.
+        # That way we make sure that the linear module doesn't store the masks
+        # alongside their parametrizations.
+        return {}
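For reference, the parametrization and FQN helpers above can be exercised directly; the toy model below is illustrative.

import torch
from torch import nn
from torch.nn.utils import parametrize
from torch.ao.pruning.sparsifier.utils import FakeSparsity, get_arg_info_from_tensor_fqn

model = nn.Sequential(nn.Linear(4, 4))
info = get_arg_info_from_tensor_fqn(model, "0.weight")
# info["module_fqn"] == "0", info["tensor_name"] == "weight", info["module"] is the Linear

# Attach a mask that zeroes the first output row; the parametrized weight is
# recomputed as mask * original_weight on every access.
mask = torch.ones(4, 4)
mask[0] = 0
parametrize.register_parametrization(info["module"], info["tensor_name"], FakeSparsity(mask))
assert torch.all(model[0].weight[0] == 0)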
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py b/.venv/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py
new file mode 100644
index 0000000000000000000000000000000000000000..89c707ad33e6351496121af94f8f4d280c8bab82
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py
@@ -0,0 +1,248 @@
+# mypy: allow-untyped-defs
+import operator
+from functools import reduce
+from typing import Callable, Optional, Union
+
+import torch
+import torch.nn.functional as F
+
+from .base_sparsifier import BaseSparsifier
+
+
+__all__ = ["WeightNormSparsifier"]
+
+
+def _flat_idx_to_2d(idx, shape):
+    rows = idx // shape[1]
+    cols = idx % shape[1]
+    return rows, cols
+
+
+class WeightNormSparsifier(BaseSparsifier):
+    r"""Weight-Norm Sparsifier
+
+    This sparsifier computes the norm of every sparse block and "zeroes-out" the
+    ones with the lowest norm. The level of sparsity defines how many of the
+    blocks are removed.
+
+    This sparsifier is controlled by three variables:
+    1. `sparsity_level` defines the number of *sparse blocks* that are zeroed-out
+    2. `sparse_block_shape` defines the shape of the sparse blocks. Note that
+        the sparse blocks originate at the zero-index of the tensor.
+    3. `zeros_per_block` is the number of zeros that we are expecting in each
+        sparse block. By default we assume that all elements within a block are
+        zeroed-out. However, setting this variable sets the target number of
+        zeros per block. The zeros within each block are chosen as the *smallest
+        absolute values*.
+
+    Args:
+
+        sparsity_level: The target level of sparsity
+        sparse_block_shape: The shape of a sparse block (see note below)
+        zeros_per_block: Number of zeros in a sparse block
+        norm: Norm to use. Could be either `int` or a callable.
+            If `int`, only L1 and L2 are implemented.
+
+    Note::
+        The `sparse_block_shape` is a tuple representing (block_ROWS, block_COLS),
+        irrespective of what the rows / cols mean in the data tensor. That means,
+        if you were to sparsify a weight tensor in the nn.Linear, which has a
+        weight shape `(Cout, Cin)`, the `block_ROWS` would refer to the output
+        channels, while the `block_COLS` would refer to the input channels.
+
+    Note::
+        All arguments to the WeightNormSparsifier constructor are "default"
+        arguments and could be overridden by the configuration provided in the
+        `prepare` step.
+    """
+
+    def __init__(
+        self,
+        sparsity_level: float = 0.5,
+        sparse_block_shape: tuple[int, int] = (1, 4),
+        zeros_per_block: Optional[int] = None,
+        norm: Optional[Union[Callable, int]] = None,
+    ):
+        if zeros_per_block is None:
+            zeros_per_block = reduce(operator.mul, sparse_block_shape)
+        defaults = {
+            "sparsity_level": sparsity_level,
+            "sparse_block_shape": sparse_block_shape,
+            "zeros_per_block": zeros_per_block,
+        }
+        if norm is None:
+            norm = 2
+        if callable(norm):
+            self.norm_fn = norm
+        elif norm == 1:
+            self.norm_fn = lambda T: T.abs()
+        elif norm == 2:
+            self.norm_fn = lambda T: T * T
+        else:
+            raise NotImplementedError(f"L-{norm} is not yet implemented.")
+        super().__init__(defaults=defaults)
+
+    def _scatter_fold_block_mask(
+        self,
+        output_shape,
+        dim,
+        indices,
+        block_shape,
+        mask=None,
+        input_shape=None,
+        device=None,
+    ):
+        r"""Creates patches of size `block_shape` after scattering the indices."""
+        if mask is None:
+            assert input_shape is not None
+            mask = torch.ones(input_shape, device=device)
+        mask.scatter_(dim=dim, index=indices, value=0)
+        mask.data = F.fold(
+            mask, output_size=output_shape, kernel_size=block_shape, stride=block_shape
+        )
+        return mask
+
+    def _make_tensor_mask(
+        self, data, input_shape, sparsity_level, sparse_block_shape, mask=None
+    ):
+        r"""Creates a tensor-level mask.
+
+        A tensor-level mask is a mask whose smallest granularity of sparsification is the
+        sparse_block_shape. That means that, for a given mask and sparse_block_shape, the
+        smallest "patch" of zeros/ones is of size sparse_block_shape.
+
+        In this context, `sparsity_level` describes the fraction of sparse patches.
+        """
+        h, w = data.shape[-2:]
+        block_h, block_w = sparse_block_shape
+        dh = (block_h - h % block_h) % block_h
+        dw = (block_w - w % block_w) % block_w
+
+        if mask is None:
+            mask = torch.ones(h + dh, w + dw, device=data.device)
+
+        if sparsity_level >= 1.0:
+            mask.data = torch.zeros_like(mask)
+            return mask
+        elif sparsity_level <= 0.0:
+            mask.data = torch.ones_like(mask)
+            return mask
+
+        values_per_block = reduce(operator.mul, sparse_block_shape)
+        if values_per_block > 1:
+            # Reduce the data
+            data = F.avg_pool2d(
+                data[None, None, :],
+                kernel_size=sparse_block_shape,
+                stride=sparse_block_shape,
+                ceil_mode=True,
+            )
+        data = data.flatten()
+        num_blocks = len(data)
+
+        data = data.repeat(1, values_per_block, 1)
+
+        threshold_idx = int(round(sparsity_level * num_blocks))
+        threshold_idx = max(0, min(num_blocks - 1, threshold_idx))  # Sanity check
+        _, sorted_idx = torch.topk(data, k=threshold_idx, dim=2, largest=False)
+
+        # Temp reshape for mask
+        mask_reshape = mask.reshape(data.shape)  # data might be reshaped
+        self._scatter_fold_block_mask(
+            dim=2,
+            output_shape=(h + dh, w + dw),
+            indices=sorted_idx,
+            block_shape=sparse_block_shape,
+            mask=mask_reshape,
+        )
+        mask.data = mask_reshape.squeeze().reshape(mask.shape)[:h, :w].contiguous()
+        return mask
+
+    def _make_block_mask(self, data, sparse_block_shape, zeros_per_block, mask=None):
+        r"""Creates a block-level mask.
+
+        A block-level mask is a mask whose largest granularity of sparsification is the
+        sparse_block_shape. That means that, for a given mask and sparse_block_shape, the
+        sparsity is computed only within a patch of size sparse_block_shape.
+
+        In this context the `zeros_per_block` describes the number of zeroed-out elements within a patch.
+        """
+        h, w = data.shape[-2:]
+        block_h, block_w = sparse_block_shape
+        dh = (block_h - h % block_h) % block_h
+        dw = (block_w - w % block_w) % block_w
+        values_per_block = reduce(operator.mul, sparse_block_shape)
+
+        if mask is None:
+            mask = torch.ones((h + dh, w + dw), device=data.device)
+
+        if values_per_block == zeros_per_block:
+            # Everything should be sparsified
+            mask.data = torch.zeros_like(mask)
+            return mask
+
+        # create a new padded tensor like data (to match the block_shape)
+        padded_data = torch.ones(h + dh, w + dw, dtype=data.dtype, device=data.device)
+        padded_data.fill_(torch.nan)
+        padded_data[:h, :w] = data
+        unfolded_data = F.unfold(
+            padded_data[None, None, :],
+            kernel_size=sparse_block_shape,
+            stride=sparse_block_shape,
+        )
+
+        # Temp reshape for mask
+        mask_reshape = mask.reshape(unfolded_data.shape)
+        _, sorted_idx = torch.topk(
+            unfolded_data, k=zeros_per_block, dim=1, largest=False
+        )
+
+        self._scatter_fold_block_mask(
+            dim=1,
+            indices=sorted_idx,
+            output_shape=padded_data.shape,
+            block_shape=sparse_block_shape,
+            mask=mask_reshape,
+        )
+
+        mask.data = mask_reshape.squeeze().reshape(mask.shape).contiguous()
+        return mask
+
+    def update_mask(  # type: ignore[call-override, override]
+        self,
+        module,
+        tensor_name,
+        sparsity_level,
+        sparse_block_shape,
+        zeros_per_block,
+        **kwargs,
+    ):
+        values_per_block = reduce(operator.mul, sparse_block_shape)
+        if zeros_per_block > values_per_block:
+            raise ValueError(
+                "Number of zeros per block cannot be more than the total number of elements in that block."
+            )
+        if zeros_per_block < 0:
+            raise ValueError("Number of zeros per block should be positive.")
+
+        mask = getattr(module.parametrizations, tensor_name)[0].mask
+        if sparsity_level <= 0 or zeros_per_block == 0:
+            mask.data = torch.ones_like(mask)
+        elif sparsity_level >= 1.0 and (zeros_per_block == values_per_block):
+            mask.data = torch.zeros_like(mask)
+        else:
+            ww = self.norm_fn(getattr(module, tensor_name))
+            tensor_mask = self._make_tensor_mask(
+                data=ww,
+                input_shape=ww.shape,
+                sparsity_level=sparsity_level,
+                sparse_block_shape=sparse_block_shape,
+            )
+            if values_per_block != zeros_per_block:
+                block_mask = self._make_block_mask(
+                    data=ww,
+                    sparse_block_shape=sparse_block_shape,
+                    zeros_per_block=zeros_per_block,
+                )
+                tensor_mask = torch.logical_or(tensor_mask, block_mask)
+            mask.data = tensor_mask
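End to end, the sparsifier above is typically driven through prepare/step/squash_mask; the tiny model and the 1x4 block shape below are illustrative choices.

import torch
from torch import nn
from torch.ao.pruning.sparsifier.weight_norm_sparsifier import WeightNormSparsifier

model = nn.Sequential(nn.Linear(8, 8))
# Zero out roughly half of the 1x4 blocks with the smallest L2 norm.
sparsifier = WeightNormSparsifier(sparsity_level=0.5, sparse_block_shape=(1, 4))
sparsifier.prepare(model, config=[{"tensor_fqn": "0.weight"}])
sparsifier.step()
print((model[0].weight == 0).float().mean())  # approximately 0.5
sparsifier.squash_mask()  # bake the zeros into the dense weight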
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9d3f79d45ace6ce3d4addc05acb9e57e487bb16f
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/__init__.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/fake_quantize.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/fake_quantize.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9c4b3b68a4746754128ab4ac9199c295b1262b92
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/fake_quantize.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/fuse_modules.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/fuse_modules.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d2616e98dafbf54403e18da2ab3219e0eb0b2a5c
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/fuse_modules.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/fuser_method_mappings.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/fuser_method_mappings.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3399173f4830cb5b94da9e6f5e8bed8c7614c258
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/fuser_method_mappings.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/observer.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/observer.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c1259d5115d555e1e59787609a5202686d2589aa
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/observer.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/qconfig.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/qconfig.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c7e25a6242188324edf389625be57395f858a054
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/qconfig.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/qconfig_mapping.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/qconfig_mapping.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fa7ce93a9545eb195aeca64b1975fa14feb89d4b
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/qconfig_mapping.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/quant_type.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/quant_type.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8695d1e7a103e7e51a18e439a2012f84ad498b3e
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/quant_type.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/quantization_mappings.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/quantization_mappings.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..79beea816940e4a3ed4692bdf1daefaddf0e7094
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/quantization_mappings.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/quantize.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/quantize.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d6ef335247aec929889d851609a2f0f2c3db75f3
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/quantize.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/quantize_jit.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/quantize_jit.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..328f8ee56a232d7f344c521ef4bccfe371b252e6
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/quantize_jit.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/stubs.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/stubs.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2fac8600a04560b695adf2a685c633f67e544da1
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/stubs.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/utils.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cba1f8b277a5fd61047876ccb9ff8b9a5a038461
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/quantization/__pycache__/utils.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__init__.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d452359d41c36dc719e6df932b6fb018ca6a36b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/__init__.py
@@ -0,0 +1,30 @@
+from .backend_config import (
+    BackendConfig,
+    BackendPatternConfig,
+    DTypeConfig,
+    DTypeWithConstraints,
+    ObservationType,
+)
+from .executorch import get_executorch_backend_config
+from .fbgemm import get_fbgemm_backend_config
+from .native import get_native_backend_config, get_native_backend_config_dict
+from .onednn import get_onednn_backend_config
+from .qnnpack import get_qnnpack_backend_config
+from .tensorrt import get_tensorrt_backend_config, get_tensorrt_backend_config_dict
+
+
+__all__ = [
+    "get_fbgemm_backend_config",
+    "get_native_backend_config",
+    "get_native_backend_config_dict",
+    "get_qnnpack_backend_config",
+    "get_tensorrt_backend_config",
+    "get_tensorrt_backend_config_dict",
+    "get_executorch_backend_config",
+    "BackendConfig",
+    "BackendPatternConfig",
+    "DTypeConfig",
+    "DTypeWithConstraints",
+    "ObservationType",
+    "get_onednn_backend_config",
+]
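These backend configs are consumed by the FX graph mode quantization entry points; a representative flow is sketched below, assuming a symbolically traceable model and the standard prepare_fx/convert_fx APIs.

import torch
from torch import nn
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.backend_config import get_native_backend_config
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU()).eval()
example_inputs = (torch.randn(1, 8),)
backend_config = get_native_backend_config()
prepared = prepare_fx(
    model,
    get_default_qconfig_mapping("fbgemm"),
    example_inputs,
    backend_config=backend_config,
)
prepared(*example_inputs)  # calibration pass with representative data
quantized = convert_fx(prepared, backend_config=backend_config)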
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/_common_operator_config_utils.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/_common_operator_config_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..781bfdc8b39290076f2a1d9238152813475787ff
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/_common_operator_config_utils.py
@@ -0,0 +1,782 @@
+# mypy: allow-untyped-defs
+import copy
+import operator
+from collections import namedtuple
+from typing import Callable, Union
+
+import torch
+import torch.ao.nn.intrinsic as nni
+import torch.ao.nn.intrinsic.qat as nniqat
+import torch.ao.nn.qat as nnqat
+import torch.ao.nn.quantized.reference as nnqr
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.ao.quantization.fuser_method_mappings import (
+    _sequential_wrapper2,
+    fuse_conv_bn,
+    fuse_conv_bn_relu,
+    fuse_convtranspose_bn,
+    fuse_linear_bn,
+)
+
+from .backend_config import (
+    BackendPatternConfig,
+    DTypeConfig,
+    DTypeWithConstraints,
+    ObservationType,
+)
+
+
+__all__: list[str] = []
+
+# TODO: rename to be more explicit, e.g. qat_conv_relu
+_ConvMetadata = namedtuple(
+    "_ConvMetadata",
+    [
+        "root",
+        "transpose",
+        "bn",
+        "reference",
+        "transpose_reference",
+        "fused_conv_relu",
+        "fused_conv_bn",
+        "fused_conv_bn_relu",
+        "qat",
+        "relu_qat",
+        "bn_qat",
+        "bn_relu_qat",
+        "func",
+        "func_transpose",
+    ],
+)
+_Conv1dMetadata = _ConvMetadata(
+    nn.Conv1d,
+    nn.ConvTranspose1d,
+    nn.BatchNorm1d,
+    nnqr.Conv1d,
+    nnqr.ConvTranspose1d,
+    nni.ConvReLU1d,
+    nni.ConvBn1d,
+    nni.ConvBnReLU1d,
+    nnqat.Conv1d,
+    nniqat.ConvReLU1d,
+    nniqat.ConvBn1d,
+    nniqat.ConvBnReLU1d,
+    F.conv1d,
+    F.conv_transpose1d,
+)
+_Conv2dMetadata = _ConvMetadata(
+    nn.Conv2d,
+    nn.ConvTranspose2d,
+    nn.BatchNorm2d,
+    nnqr.Conv2d,
+    nnqr.ConvTranspose2d,
+    nni.ConvReLU2d,
+    nni.ConvBn2d,
+    nni.ConvBnReLU2d,
+    nnqat.Conv2d,
+    nniqat.ConvReLU2d,
+    nniqat.ConvBn2d,
+    nniqat.ConvBnReLU2d,
+    F.conv2d,
+    F.conv_transpose2d,
+)
+_Conv3dMetadata = _ConvMetadata(
+    nn.Conv3d,
+    nn.ConvTranspose3d,
+    nn.BatchNorm3d,
+    nnqr.Conv3d,
+    nnqr.ConvTranspose3d,
+    nni.ConvReLU3d,
+    nni.ConvBn3d,
+    nni.ConvBnReLU3d,
+    nnqat.Conv3d,
+    nniqat.ConvReLU3d,
+    nniqat.ConvBn3d,
+    nniqat.ConvBnReLU3d,
+    F.conv3d,
+    F.conv_transpose3d,
+)
+
+# Add constraints for fixed qparams ops like sigmoid and tanh to ensure values
+# fall within the proper ranges, e.g. [0, 1] for sigmoid, [-1, 1] for tanh
+_FIXED_QPARAM_OP_0TO1_CONSTRAINTS = DTypeWithConstraints(
+    dtype=torch.quint8,
+    quant_min_lower_bound=0,
+    quant_max_upper_bound=255,
+    scale_exact_match=1.0 / 256.0,
+    zero_point_exact_match=0,
+)
+_FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS = DTypeWithConstraints(
+    dtype=torch.quint8,
+    quant_min_lower_bound=0,
+    quant_max_upper_bound=255,
+    scale_exact_match=2.0 / 256.0,
+    zero_point_exact_match=128,
+)
+_FIXED_QPARAMS_OP_TO_CONSTRAINTS: dict[Union[Callable, str], DTypeWithConstraints] = {
+    torch.nn.Hardsigmoid: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS,
+    torch.nn.functional.hardsigmoid: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS,
+    "hardsigmoid": _FIXED_QPARAM_OP_0TO1_CONSTRAINTS,
+    "hardsigmoid_": _FIXED_QPARAM_OP_0TO1_CONSTRAINTS,
+    torch.nn.Sigmoid: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS,
+    torch.sigmoid: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS,
+    "sigmoid": _FIXED_QPARAM_OP_0TO1_CONSTRAINTS,
+    "sigmoid_": _FIXED_QPARAM_OP_0TO1_CONSTRAINTS,
+    torch.nn.Softmax: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS,
+    torch.nn.Tanh: _FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS,
+    torch.tanh: _FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS,
+    "tanh": _FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS,
+    "tanh_": _FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS,
+}
+
+
+def _get_binary_op_configs(
+    dtype_configs: list[DTypeConfig],
+) -> list[BackendPatternConfig]:
+    binary_op_configs: list[BackendPatternConfig] = []
+    num_tensor_args_to_observation_type_mapping = {
+        # TODO: this is not used right now since we have an extra check in prepare;
+        # will need to change this to NO_OBSERVER later after we implement
+        # Tensor dtype inference properly
+        0: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
+        1: ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT,
+        2: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
+    }
+    for op_with_quantized_bop_scalar_variant in [
+        operator.add,
+        torch.add,
+        operator.mul,
+        torch.mul,
+    ]:
+        bop_patterns = [
+            (op_with_quantized_bop_scalar_variant, nn.ReLU),
+            (op_with_quantized_bop_scalar_variant, F.relu),
+            (op_with_quantized_bop_scalar_variant, torch.relu),
+            op_with_quantized_bop_scalar_variant,
+        ]
+        binary_op_configs.extend(
+            BackendPatternConfig(bop_pattern)
+            .set_dtype_configs(dtype_configs)  # noqa: E131
+            ._set_num_tensor_args_to_observation_type(
+                num_tensor_args_to_observation_type_mapping
+            )
+            for bop_pattern in bop_patterns
+        )
+    # matmul
+    binary_op_configs.append(
+        BackendPatternConfig(torch.matmul).set_dtype_configs(dtype_configs)  # noqa: E131
+    )
+    return binary_op_configs
+
+
+def _get_linear_configs(dtype_configs: list[DTypeConfig]) -> list[BackendPatternConfig]:
+    """
+    Return all configs related to linear modules and ops.
+    """
+    observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
+    linear_configs: list[BackendPatternConfig] = []
+
+    # (1) Single linear modules/functions
+    # -------------------------------------
+    # linear module
+    linear_configs.append(
+        BackendPatternConfig(torch.nn.Linear)
+        .set_observation_type(observation_type)  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+        .set_root_module(torch.nn.Linear)
+        .set_reference_quantized_module(nnqr.Linear)
+        .set_qat_module(nnqat.Linear)
+    )
+    # linear qat module
+    linear_configs.append(
+        BackendPatternConfig(nnqat.Linear)
+        .set_observation_type(observation_type)  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+        .set_root_module(torch.nn.Linear)
+        .set_reference_quantized_module(nnqr.Linear)
+    )
+    # functional linear
+    linear_configs.append(
+        BackendPatternConfig(torch.nn.functional.linear)
+        .set_observation_type(observation_type)  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+        ._set_input_type_to_index({"weight": 1, "bias": 2})
+    )
+
+    # (2) Linear + relu
+    # -------------------
+    # 2.1 linear module + relu fusion config
+    # linear relu, linear module + relu module
+    linear_configs.append(
+        BackendPatternConfig((torch.nn.Linear, torch.nn.ReLU))
+        .set_dtype_configs(dtype_configs)  # noqa: E131
+        .set_fuser_method(_sequential_wrapper2(nni.LinearReLU))
+        .set_fused_module(nni.LinearReLU)
+    )
+    # linear relu, linear module + functional relu
+    linear_configs.append(
+        BackendPatternConfig((torch.nn.Linear, torch.nn.functional.relu))
+        .set_dtype_configs(dtype_configs)  # noqa: E131
+        .set_fuser_method(_sequential_wrapper2(nni.LinearReLU))
+        .set_fused_module(nni.LinearReLU)
+    )
+
+    # 2.2 linear module + relu, fused module configs
+    # linear relu, fused module
+    linear_configs.append(
+        BackendPatternConfig(nni.LinearReLU)
+        .set_observation_type(observation_type)  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+        .set_root_module(torch.nn.Linear)
+        .set_reference_quantized_module(nnqr.Linear)
+        .set_qat_module(nniqat.LinearReLU)
+    )
+    # linear relu, qat fused module
+    linear_configs.append(
+        BackendPatternConfig(nniqat.LinearReLU)
+        .set_observation_type(observation_type)  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+        .set_root_module(torch.nn.Linear)
+        .set_reference_quantized_module(nnqr.Linear)
+    )
+    # 2.3 functional linear + relu configs
+    # linear relu, functional linear + relu module
+    linear_configs.append(
+        BackendPatternConfig((F.linear, torch.nn.ReLU))
+        .set_observation_type(observation_type)  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+    )
+    # linear relu, functional linear + functional relu
+    linear_configs.append(
+        BackendPatternConfig((F.linear, F.relu))
+        .set_observation_type(observation_type)  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+    )
+
+    # (3) Linear + batchnorm
+    # ------------------------
+    # 3.1 linear bn fusion
+    linear_configs.append(
+        BackendPatternConfig((nn.Linear, nn.BatchNorm1d))
+        .set_dtype_configs(dtype_configs)  # noqa: E131
+        .set_fuser_method(fuse_linear_bn)
+        .set_fused_module(nni.LinearBn1d)
+    )
+
+    # 3.2 linear bn fused
+    # linear bn, fused module
+    linear_configs.append(
+        BackendPatternConfig(nni.LinearBn1d)
+        .set_observation_type(observation_type)  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+        .set_root_module(torch.nn.Linear)
+        .set_reference_quantized_module(nnqr.Linear)
+        .set_qat_module(nniqat.LinearBn1d)
+    )
+    # linear bn, qat fused module
+    linear_configs.append(
+        BackendPatternConfig(nniqat.LinearBn1d)
+        .set_observation_type(observation_type)  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+        .set_root_module(torch.nn.Linear)
+        .set_reference_quantized_module(nnqr.Linear)
+    )
+    return linear_configs
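+
+
+# Illustrative sketch (not part of the original module): inspect the module
+# mappings carried by the linear configs. The helper name is hypothetical.
+def _example_linear_module_mappings(dtype_configs: list[DTypeConfig]):
+    # For each config, the root module (e.g. the nn.Linear inside nni.LinearReLU)
+    # is what gets swapped to the reference quantized module during convert.
+    return [
+        (cfg.pattern, cfg.root_module, cfg.reference_quantized_module)
+        for cfg in _get_linear_configs(dtype_configs)
+    ]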
+
+
+def _get_conv_configs(dtype_configs):
+    """
+    Return all configs related to conv modules and ops.
+    """
+    conv_configs = []
+    observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
+    for convs in [_Conv1dMetadata, _Conv2dMetadata, _Conv3dMetadata]:
+        # (1) Single conv modules/functions
+        # -----------------------------------
+        # conv module
+        conv_configs.append(
+            BackendPatternConfig(convs.root)
+            .set_observation_type(observation_type)  # noqa: E131
+            .set_dtype_configs(dtype_configs)
+            .set_root_module(convs.root)
+            .set_reference_quantized_module(convs.reference)
+            .set_qat_module(convs.qat)
+        )
+        # conv qat module
+        conv_configs.append(
+            BackendPatternConfig(convs.qat)
+            .set_observation_type(observation_type)  # noqa: E131
+            .set_dtype_configs(dtype_configs)
+            .set_root_module(convs.root)
+            .set_reference_quantized_module(convs.reference)
+        )
+        # functional conv
+        conv_configs.append(
+            BackendPatternConfig(convs.func)
+            .set_observation_type(observation_type)  # noqa: E131
+            .set_dtype_configs(dtype_configs)
+            ._set_input_type_to_index({"weight": 1, "bias": 2})
+        )
+
+        # (2) Conv + relu
+        # -----------------
+        # 2.1 conv module + relu fusion configs
+        # conv relu fusion, conv module + relu module
+        conv_configs.append(
+            BackendPatternConfig((convs.root, torch.nn.ReLU))
+            .set_dtype_configs(dtype_configs)  # noqa: E131
+            .set_fuser_method(_sequential_wrapper2(convs.fused_conv_relu))
+            .set_fused_module(convs.fused_conv_relu)
+        )
+        # conv relu fusion, conv module + functional relu
+        conv_configs.append(
+            BackendPatternConfig((convs.root, F.relu))
+            .set_dtype_configs(dtype_configs)  # noqa: E131
+            .set_fuser_method(_sequential_wrapper2(convs.fused_conv_relu))
+            .set_fused_module(convs.fused_conv_relu)
+        )
+        # 2.2 conv module + relu fused module configs
+        # conv relu, fused module
+        conv_configs.append(
+            BackendPatternConfig(convs.fused_conv_relu)
+            .set_observation_type(observation_type)  # noqa: E131
+            .set_dtype_configs(dtype_configs)
+            .set_root_module(convs.root)
+            .set_reference_quantized_module(convs.reference)
+            .set_qat_module(convs.relu_qat)
+        )
+        # conv relu, qat fused module
+        conv_configs.append(
+            BackendPatternConfig(convs.relu_qat)
+            .set_observation_type(observation_type)  # noqa: E131
+            .set_dtype_configs(dtype_configs)
+            .set_root_module(convs.root)
+            .set_reference_quantized_module(convs.reference)
+        )
+        # 2.3 functional conv + relu configs
+        # conv relu, functional conv + relu module
+        conv_configs.append(
+            BackendPatternConfig((convs.func, torch.nn.ReLU))
+            .set_observation_type(observation_type)  # noqa: E131
+            .set_dtype_configs(dtype_configs)
+        )
+        # conv relu, functional conv + functional relu
+        conv_configs.append(
+            BackendPatternConfig((convs.func, F.relu))
+            .set_observation_type(observation_type)  # noqa: E131
+            .set_dtype_configs(dtype_configs)
+        )
+
+        # fused conv relu
+        conv_configs.append(
+            BackendPatternConfig(convs.fused_conv_relu)
+            .set_dtype_configs(dtype_configs)  # noqa: E131
+            .set_qat_module(convs.relu_qat)
+        )
+
+        conv_configs.append(
+            BackendPatternConfig(convs.relu_qat)
+            .set_dtype_configs(dtype_configs)  # noqa: E131
+            .set_root_module(convs.root)
+            .set_reference_quantized_module(convs.reference)
+        )
+
+        # (3) Conv + batchnorm (+ relu)
+        # -------------------------------
+        # 3.1 conv bn fusion configs
+        # conv + bn fusion
+        conv_configs.append(
+            BackendPatternConfig((convs.root, convs.bn))
+            .set_dtype_configs(dtype_configs)  # noqa: E131
+            .set_fuser_method(fuse_conv_bn)
+            .set_fused_module(convs.fused_conv_bn)
+        )
+        # conv + bn + relu module fusion
+        conv_configs.append(
+            BackendPatternConfig((convs.root, convs.bn, nn.ReLU))
+            .set_dtype_configs(dtype_configs)  # noqa: E131
+            .set_fuser_method(fuse_conv_bn_relu)
+            .set_fused_module(convs.fused_conv_bn_relu)
+        )
+        # conv + bn + relu functional fusion
+        conv_configs.append(
+            BackendPatternConfig((convs.root, convs.bn, F.relu))
+            .set_dtype_configs(dtype_configs)  # noqa: E131
+            .set_root_module(convs.root)
+            .set_fuser_method(fuse_conv_bn_relu)
+            .set_fused_module(convs.fused_conv_bn_relu)
+        )
+        # TODO: we can add fusion for torch.relu as well
+
+        # 3.2 conv + bn (+ relu) fused module configs
+        # fused conv bn
+        conv_configs.append(
+            BackendPatternConfig(convs.fused_conv_bn)
+            .set_dtype_configs(dtype_configs)  # noqa: E131
+            .set_qat_module(convs.bn_qat)
+        )
+
+        # fused conv bn relu
+        conv_configs.append(
+            BackendPatternConfig(convs.fused_conv_bn_relu)
+            .set_dtype_configs(dtype_configs)  # noqa: E131
+            .set_qat_module(convs.bn_relu_qat)
+        )
+
+        # conv bn, qat fused module
+        conv_configs.append(
+            BackendPatternConfig(convs.bn_qat)
+            .set_observation_type(observation_type)  # noqa: E131
+            .set_dtype_configs(dtype_configs)
+            .set_root_module(convs.root)
+            .set_reference_quantized_module(convs.reference)
+        )
+        # conv bn relu, qat fused module
+        conv_configs.append(
+            BackendPatternConfig(convs.bn_relu_qat)
+            .set_observation_type(observation_type)  # noqa: E131
+            .set_dtype_configs(dtype_configs)
+            .set_root_module(convs.root)
+            .set_reference_quantized_module(convs.reference)
+        )
+
+        # (4) conv transpose and its fusion
+        # 4.1 conv transpose config
+        conv_configs.append(
+            BackendPatternConfig(convs.transpose)
+            .set_dtype_configs(dtype_configs)  # noqa: E131
+            .set_root_module(convs.transpose)
+            .set_reference_quantized_module(convs.transpose_reference)
+        )
+
+        # 4.2 conv transpose + bn fusion
+        conv_configs.append(
+            BackendPatternConfig((convs.transpose, convs.bn))
+            .set_dtype_configs(dtype_configs)  # noqa: E131
+            .set_fuser_method(fuse_convtranspose_bn)
+            .set_root_module(convs.transpose)
+            .set_reference_quantized_module(convs.transpose_reference)
+        )
+
+        # 4.3 functional conv transpose
+        conv_configs.append(
+            BackendPatternConfig(convs.func_transpose)
+            .set_dtype_configs(dtype_configs)  # noqa: E131
+            ._set_input_type_to_index({"weight": 1, "bias": 2})
+        )
+
+    return conv_configs
+
+
+def _get_cat_config(dtype_configs: list[DTypeConfig]) -> BackendPatternConfig:
+    return (
+        BackendPatternConfig(torch.cat)
+        .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT)
+        .set_dtype_configs(dtype_configs)
+    )
+
+
+def _get_ln_configs(dtype_configs: list[DTypeConfig]) -> list[BackendPatternConfig]:
+    ln_configs = []
+    ln_configs.append(
+        BackendPatternConfig(torch.nn.LayerNorm)
+        .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+    )
+    ln_configs.append(
+        BackendPatternConfig(torch.nn.functional.layer_norm)
+        .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+        ._set_input_type_to_index({"weight": 2, "bias": 3})
+    )
+    return ln_configs
+
+
+def _get_default_op_configs(
+    dtype_configs: list[DTypeConfig],
+) -> list[BackendPatternConfig]:
+    default_ops = [
+        torch.nn.ELU,
+        torch.nn.LeakyReLU,
+        torch.nn.Hardswish,
+        torch.nn.InstanceNorm1d,
+        torch.nn.InstanceNorm2d,
+        torch.nn.InstanceNorm3d,
+        torch.nn.Dropout,
+        torch.nn.PReLU,
+        torch.nn.functional.elu,
+        torch.nn.functional.hardswish,
+        torch.nn.functional.leaky_relu,
+        torch.nn.functional.dropout,
+    ]
+    configs = [
+        BackendPatternConfig(op)
+        .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+        for op in default_ops
+    ]
+
+    configs.append(
+        BackendPatternConfig(torch.nn.functional.group_norm)
+        .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+        ._set_input_type_to_index({"weight": 2, "bias": 3})
+    )
+
+    configs.append(
+        BackendPatternConfig(torch.nn.functional.instance_norm)
+        .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+        ._set_input_type_to_index({"weight": 3, "bias": 4})
+    )
+    return configs
+
+
+def _add_fixed_qparams_to_dtype_configs(
+    dtype_configs: list[DTypeConfig],
+    constraints: DTypeWithConstraints,
+) -> list[DTypeConfig]:
+    """
+    Return a copy of the list of DTypeConfigs where activations are subject to the specified
+    constraints required for fixed qparams ops.
+
+    If the data type doesn't match the one in the constraints, simply leave the corresponding
+    DTypeConfig unchanged.
+
+    If `scale_min_lower_bound` or `scale_max_upper_bound` is specified in the activations,
+    throw an exception since these settings are incompatible with fixed qparams ops.
+    """
+    new_dtype_configs = []
+    for dtype_config in dtype_configs:
+        dc = copy.deepcopy(dtype_config)
+        for orig_constraints in [
+            dc.input_dtype_with_constraints,
+            dc.output_dtype_with_constraints,
+        ]:
+            if orig_constraints.dtype != constraints.dtype:
+                continue
+            if orig_constraints.scale_min_lower_bound is not None:
+                raise ValueError(
+                    f"scale_min_lower_bound is invalid for fixed qparams ops: {dtype_config}"
+                )
+            if orig_constraints.scale_max_upper_bound is not None:
+                raise ValueError(
+                    f"scale_max_upper_bound is invalid for fixed qparams ops: {dtype_config}"
+                )
+            orig_constraints.quant_min_lower_bound = constraints.quant_min_lower_bound
+            orig_constraints.quant_max_upper_bound = constraints.quant_max_upper_bound
+            orig_constraints.scale_exact_match = constraints.scale_exact_match
+            orig_constraints.zero_point_exact_match = constraints.zero_point_exact_match
+        new_dtype_configs.append(dc)
+    return new_dtype_configs
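+
+
+# Illustrative sketch (not part of the original module): applying the helper
+# above to a plain quint8 activation config. The function name is hypothetical.
+def _example_constrained_dtype_configs():
+    act_quint8 = DTypeConfig(input_dtype=torch.quint8, output_dtype=torch.quint8)
+    # The returned copy has quant_min/quant_max bounds and exact scale/zero_point
+    # requirements stamped onto matching activation dtypes; a DTypeConfig whose
+    # activation dtype differs from the constraints passes through unchanged.
+    return _add_fixed_qparams_to_dtype_configs(
+        [act_quint8], _FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS
+    )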
+
+
+def _get_fixed_qparams_op_configs(
+    dtype_configs: list[DTypeConfig],
+) -> list[BackendPatternConfig]:
+    fixed_qparams_op_configs = []
+    for fixed_qparam_op, constraints in _FIXED_QPARAMS_OP_TO_CONSTRAINTS.items():
+        new_dtype_configs = _add_fixed_qparams_to_dtype_configs(
+            dtype_configs, constraints
+        )
+        fixed_qparams_op_configs.append(
+            BackendPatternConfig(fixed_qparam_op)
+            .set_observation_type(
+                ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
+            )  # noqa: E131
+            .set_dtype_configs(new_dtype_configs)
+        )
+    return fixed_qparams_op_configs
+
+
+def _get_share_qparams_op_configs(dtype_configs):
+    """Get the operator config for the operators that works for both float and quantized input
+    if input is quantized, the output Tensor shares the same quantization parameter
+    with input.
+    Example operator: avgpool2d, reshape, transpose, maxpool2d
+    Example observed operator:
+    observer_0 - avgpool2d - observer_0 (same observer instance as input)
+    """
+
+    def _get_share_qparams_op_backend_config(op):
+        return (
+            BackendPatternConfig(op)
+            .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT)
+            .set_dtype_configs(dtype_configs)
+        )
+
+    share_qparams_ops = [
+        torch.nn.AdaptiveAvgPool1d,
+        torch.nn.AdaptiveAvgPool2d,
+        torch.nn.AdaptiveAvgPool3d,
+        torch.nn.AvgPool1d,
+        torch.nn.AvgPool2d,
+        torch.nn.AvgPool3d,
+        torch.nn.Hardtanh,
+        torch.nn.Identity,
+        torch.nn.MaxPool1d,
+        torch.nn.MaxPool2d,
+        torch.nn.MaxPool3d,
+        torch.nn.PixelShuffle,
+        torch.nn.PixelUnshuffle,
+        torch.nn.ReLU,
+        torch.nn.ReLU6,
+        torch.adaptive_avg_pool1d,
+        torch.nn.functional.adaptive_avg_pool2d,
+        torch.nn.functional.adaptive_avg_pool3d,
+        torch.nn.functional.hardtanh,
+        torch.nn.functional.hardtanh_,
+        torch.nn.functional.interpolate,
+        torch.nn.functional.max_pool1d,
+        torch.nn.functional.max_pool2d,
+        torch.nn.functional.max_pool3d,
+        torch.nn.functional.pixel_shuffle,
+        torch.nn.functional.pixel_unshuffle,
+        torch.nn.functional.relu,
+        torch.nn.functional.relu6,
+        torch.avg_pool1d,
+        torch._C._nn.avg_pool2d,
+        torch._C._nn.avg_pool3d,
+        torch.clamp,
+        torch.flatten,
+        torch.mean,
+        torch.narrow,
+        torch.repeat_interleave,
+        torch.transpose,
+        torch.squeeze,
+        torch.stack,
+        torch.unsqueeze,
+        operator.floordiv,
+        "contiguous",
+        "clamp",
+        "detach",
+        "detach_",
+        "mean",
+        "permute",
+        "repeat",
+        "repeat_interleave",
+        "reshape",
+        "resize_",
+        "relu",
+        "relu_",
+        "squeeze",
+        "squeeze_",
+        "transpose",
+        "unsqueeze",
+        "unsqueeze_",
+        "view",
+    ]
+    return [_get_share_qparams_op_backend_config(op) for op in share_qparams_ops]
+
+
+def _get_bn_configs(dtype_configs: list[DTypeConfig]) -> list[BackendPatternConfig]:
+    """Get configs related to batchnorm."""
+    bn_configs = []
+    bn_to_fused_bn = {
+        torch.nn.BatchNorm2d: nni.BNReLU2d,
+        torch.nn.BatchNorm3d: nni.BNReLU3d,
+    }
+    for bn in bn_to_fused_bn.keys():
+        fused_bn = bn_to_fused_bn[bn]
+        # bn module + relu module fusion config
+        bn_configs.append(
+            BackendPatternConfig((bn, nn.ReLU))
+            .set_dtype_configs(dtype_configs)  # noqa: E131
+            .set_fuser_method(_sequential_wrapper2(fused_bn))
+            .set_fused_module(fused_bn)
+        )
+        # bn module + F.relu fusion config
+        bn_configs.append(
+            BackendPatternConfig((bn, F.relu))
+            .set_dtype_configs(dtype_configs)  # noqa: E131
+            .set_fuser_method(_sequential_wrapper2(fused_bn))
+            .set_fused_module(fused_bn)
+        )
+        bn_configs.append(
+            BackendPatternConfig(bn)
+            .set_observation_type(
+                ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
+            )  # noqa: E131
+            .set_dtype_configs(dtype_configs)
+        )
+
+    # fused bn configs
+    for fused_bn in bn_to_fused_bn.values():
+        bn_configs.append(
+            BackendPatternConfig(fused_bn)
+            .set_observation_type(
+                ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
+            )  # noqa: E131
+            .set_dtype_configs(dtype_configs)
+        )
+    return bn_configs
+
+
+def _get_rnn_op_configs(dtype_configs: list[DTypeConfig]) -> list[BackendPatternConfig]:
+    rnn_op_configs = []
+    for rnn_op, ref_rnn_op in [
+        (nn.GRUCell, nnqr.GRUCell),
+        (nn.LSTMCell, nnqr.LSTMCell),
+        (nn.RNNCell, nnqr.RNNCell),
+        (nn.LSTM, nnqr.LSTM),
+        (nn.GRU, nnqr.GRU),
+    ]:
+        rnn_op_configs.append(
+            BackendPatternConfig(rnn_op)
+            .set_observation_type(
+                ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
+            )  # noqa: E131
+            .set_dtype_configs(dtype_configs)
+            .set_root_module(rnn_op)
+            .set_reference_quantized_module(ref_rnn_op)
+        )
+    return rnn_op_configs
+
+
+def _get_embedding_op_configs(
+    dtype_configs: list[DTypeConfig],
+) -> list[BackendPatternConfig]:
+    embedding_op_configs = []
+    for embedding_op, qat_embedding_op, ref_embedding_op in [
+        (nn.Embedding, nnqat.Embedding, nnqr.Embedding),
+        (nn.EmbeddingBag, nnqat.EmbeddingBag, nnqr.EmbeddingBag),
+    ]:
+        embedding_op_configs.append(
+            BackendPatternConfig(embedding_op)
+            .set_observation_type(
+                ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
+            )  # noqa: E131
+            .set_dtype_configs(dtype_configs)
+            .set_qat_module(qat_embedding_op)
+            .set_root_module(embedding_op)
+            .set_reference_quantized_module(ref_embedding_op)
+        )
+
+        # config for qat op
+        embedding_op_configs.append(
+            BackendPatternConfig(qat_embedding_op)
+            .set_observation_type(
+                ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
+            )  # noqa: E131
+            .set_dtype_configs(dtype_configs)
+            .set_root_module(embedding_op)
+            .set_reference_quantized_module(ref_embedding_op)
+        )
+    return embedding_op_configs
+
+
+def _get_tensor_info_op_configs(dtype_configs):
+    """
+    These ops work on tensors of different dtypes but return non-tensors
+    containing information about the input tensor.
+    """
+
+    def _get_config(op):
+        return (
+            BackendPatternConfig(op)
+            .set_observation_type(ObservationType.INPUT_OUTPUT_NOT_OBSERVED)
+            .set_dtype_configs(dtype_configs)
+        )
+
+    return [_get_config(op) for op in ("shape", "size")]
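+
+
+# Illustrative sketch (not part of the original module): several of the helpers
+# above could be composed into a single BackendConfig. The backend name and the
+# single quint8 dtype config are assumptions for demonstration only.
+def _example_compose_backend_config():
+    from torch.ao.quantization.backend_config import BackendConfig
+
+    act_quint8 = DTypeConfig(
+        input_dtype=torch.quint8,
+        output_dtype=torch.quint8,
+        weight_dtype=torch.qint8,
+        bias_dtype=torch.float,
+    )
+    return (
+        BackendConfig("example_backend")
+        .set_backend_pattern_configs(_get_linear_configs([act_quint8]))
+        .set_backend_pattern_configs(_get_conv_configs([act_quint8]))
+        .set_backend_pattern_configs(_get_share_qparams_op_configs([act_quint8]))
+    )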
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/_qnnpack_pt2e.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/_qnnpack_pt2e.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4e67b79c370207c228fd66d33fadad03a58ed2a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/_qnnpack_pt2e.py
@@ -0,0 +1,181 @@
+# mypy: allow-untyped-defs
+import operator
+
+import torch
+from torch.ao.quantization.backend_config import (
+    BackendConfig,
+    BackendPatternConfig,
+    DTypeConfig,
+    ObservationType,
+)
+
+
+weighted_op_quint8_dtype_config = DTypeConfig(
+    input_dtype=torch.quint8,
+    output_dtype=torch.quint8,
+    weight_dtype=torch.qint8,
+    bias_dtype=torch.float,
+)
+
+
+def get_linear_configs():
+    linear_configs = []
+    observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
+    dtype_configs = [weighted_op_quint8_dtype_config]
+
+    # TODO: need to fix the way we insert observers for this pattern
+    # should be solved in the new fusion API
+    # reason that this doesn't work: the pattern is a bit complicated and we don't
+    # have a way to specify which input of the pattern we would like to observe
+    # pattern:
+    # bias input weight
+    # \     |    /
+    #  \    |   t
+    #   \   |  /
+    #    addmm
+    # we want to observe "weight" as weight, but there is no way to convey this
+    # information with current pattern language
+    #
+    # right now:
+    # original:
+    #         weight - t \
+    #         input  - addmm
+    # observed (no hack):
+    #      weight - t - observer \
+    #       input - observer - addmm
+    # target:
+    #      weight - observer - t \
+    #        input - observer - addmm
+
+    # def root_node_getter(node_pattern):
+    #     addmm, bias, act, weight = node_pattern
+    #     return addmm
+
+    # linear_configs.append(
+    #     BackendPatternConfig((torch.ops.aten.addmm.default, MatchAllNode, MatchAllNode, torch.ops.aten.t.default))
+    #     .set_observation_type(observation_type)  # noqa: E131
+    #     .set_dtype_configs(dtype_configs)
+    #     ._set_root_node_getter(root_node_getter))
+
+    linear_configs.append(
+        BackendPatternConfig(torch.ops.aten.addmm.default)
+        .set_observation_type(observation_type)  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+        ._set_input_type_to_index({"weight": 2, "bias": 0})
+    )
+    # linear is decomposed to `t - mm` if bias is not present
+    linear_configs.append(
+        BackendPatternConfig(torch.ops.aten.mm.default)
+        .set_observation_type(observation_type)  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+        ._set_input_type_to_index({"weight": 1})
+    )
+    return linear_configs
+
+
+def get_conv_configs():
+    conv_configs = []
+    observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
+    dtype_configs = [weighted_op_quint8_dtype_config]
+    conv_configs.append(
+        BackendPatternConfig(torch.ops.aten.convolution.default)
+        .set_observation_type(observation_type)  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+        ._set_input_type_to_index({"weight": 1, "bias": 2})
+    )
+    conv_configs.append(
+        BackendPatternConfig(
+            (torch.ops.aten.convolution.default, torch.ops.aten.relu.default)
+        )
+        .set_observation_type(observation_type)  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+        ._set_input_type_to_index({"weight": 1, "bias": 2})
+    )
+    # TODO: remove when functionalization is supported in PT2 mode
+    conv_configs.append(
+        BackendPatternConfig(
+            (torch.ops.aten.convolution.default, torch.ops.aten.relu_.default)
+        )
+        .set_observation_type(observation_type)  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+        ._set_input_type_to_index({"weight": 1, "bias": 2})
+    )
+    return conv_configs
+
+
+def get_pooling_configs():
+    backend_pattern_configs = []
+    observation_type = ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT
+    dtype_configs = [weighted_op_quint8_dtype_config]
+
+    def root_node_getter(node_pattern):
+        _getitem, maxpool, _index = node_pattern
+        return maxpool
+
+    backend_pattern_configs.append(
+        BackendPatternConfig()
+        ._set_pattern_complex_format(
+            (operator.getitem, torch.ops.aten.max_pool2d_with_indices.default, 0)
+        )
+        .set_observation_type(observation_type)  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+        ._set_root_node_getter(root_node_getter)
+    )
+
+    return backend_pattern_configs
+
+
+def get_relu_configs():
+    backend_pattern_configs = []
+    observation_type = ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT
+    dtype_configs = [weighted_op_quint8_dtype_config]
+    backend_pattern_configs.append(
+        BackendPatternConfig(torch.ops.aten.relu.default)
+        .set_observation_type(observation_type)  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+    )
+    return backend_pattern_configs
+
+
+def get_binary_op_configs():
+    binary_op_configs: list[BackendPatternConfig] = []
+    dtype_configs = [weighted_op_quint8_dtype_config]
+    num_tensor_args_to_observation_type_mapping = {
+        # TODO: this is not used right now since we have extra check in prepare
+        # will need to change this to NO_OBSERVER later after we implemented
+        # Tensor dtype inference properly
+        0: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
+        1: ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT,
+        2: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
+    }
+    for op_with_quantized_bop_scalar_variant in [
+        torch.ops.aten.add.Tensor,
+        torch.ops.aten.add_.Tensor,
+    ]:
+        bop_patterns = [
+            (op_with_quantized_bop_scalar_variant, torch.ops.aten.relu.default),
+            op_with_quantized_bop_scalar_variant,
+            # TODO: remove when functionalization is supported in pt2_mode
+            (op_with_quantized_bop_scalar_variant, torch.ops.aten.relu_.default),
+        ]
+        binary_op_configs.extend(
+            BackendPatternConfig(bop_pattern)
+            .set_dtype_configs(dtype_configs)  # noqa: E131
+            ._set_num_tensor_args_to_observation_type(
+                num_tensor_args_to_observation_type_mapping
+            )
+            for bop_pattern in bop_patterns
+        )
+
+    return binary_op_configs
+
+
+def get_qnnpack_pt2e_backend_config():
+    return (
+        BackendConfig("qnnpack_pytorch_2.0_export")
+        .set_backend_pattern_configs(get_linear_configs())
+        .set_backend_pattern_configs(get_binary_op_configs())
+        .set_backend_pattern_configs(get_conv_configs())
+        .set_backend_pattern_configs(get_pooling_configs())
+        .set_backend_pattern_configs(get_relu_configs())
+    )
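+
+
+# Illustrative sketch (not part of the original module): the assembled config
+# can be inspected through the public BackendConfig API. The helper name is
+# hypothetical.
+def _example_inspect_qnnpack_pt2e_config():
+    backend_config = get_qnnpack_pt2e_backend_config()
+    # `configs` returns one BackendPatternConfig per registered pattern; patterns
+    # registered via the complex format (e.g. the max_pool2d config) have
+    # `pattern` set to None.
+    return [(cfg.pattern, cfg.observation_type) for cfg in backend_config.configs]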
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/backend_config.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/backend_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..3919b84da280883577ff30af23c58c49d4209cd0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/backend_config.py
@@ -0,0 +1,749 @@
+# mypy: allow-untyped-defs
+from __future__ import annotations
+
+from dataclasses import dataclass
+from enum import Enum
+from typing import Any, Callable, Optional, TYPE_CHECKING, Union
+
+import torch
+
+
+if TYPE_CHECKING:
+    from torch.ao.quantization.utils import Pattern
+
+
+__all__ = [
+    "BackendConfig",
+    "BackendPatternConfig",
+    "DTypeConfig",
+    "DTypeWithConstraints",
+    "ObservationType",
+]
+
+
+# DTypeConfig dict keys
+INPUT_DTYPE_DICT_KEY = "input_dtype"
+OUTPUT_DTYPE_DICT_KEY = "output_dtype"
+WEIGHT_DTYPE_DICT_KEY = "weight_dtype"
+BIAS_DTYPE_DICT_KEY = "bias_dtype"
+IS_DYNAMIC_DICT_KEY = "is_dynamic"
+
+# BackendConfig dict keys
+NAME_DICT_KEY = "name"
+CONFIGS_DICT_KEY = "configs"
+
+# BackendPatternConfig dict keys
+PATTERN_DICT_KEY = "pattern"
+PATTERN_COMPLEX_FORMAT_DICT_KEY = "pattern_complex_format"
+OBSERVATION_TYPE_DICT_KEY = "observation_type"
+DTYPE_CONFIGS_DICT_KEY = "dtype_configs"
+ROOT_MODULE_DICT_KEY = "root_module"
+QAT_MODULE_DICT_KEY = "qat_module"
+REFERENCE_QUANTIZED_MODULE_DICT_KEY = "reference_quantized_module_for_root"
+FUSED_MODULE_DICT_KEY = "fused_module"
+FUSER_METHOD_DICT_KEY = "fuser_method"
+ROOT_NODE_GETTER_DICT_KEY = "root_node_getter"
+EXTRA_INPUTS_GETTER_DICT_KEY = "extra_inputs_getter"
+NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY = "num_tensor_args_to_observation_type"
+INPUT_TYPE_TO_INDEX_DICT_KEY = "input_type_to_index"
+
+
+# TODO: maybe rename this to something that's not related to observer
+# e.g. QParamsType
+class ObservationType(Enum):
+    """An enum that represents different ways of how an operator/operator pattern
+    should be observed
+    """
+
+    OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT = 0
+    """this means input and output are observed with different observers, based
+    on qconfig.activation
+    example: conv, linear, softmax
+    """
+
+    OUTPUT_SHARE_OBSERVER_WITH_INPUT = 1
+    """this means the output will use the same observer instance as input, based
+    on qconfig.activation
+    example: torch.cat, maxpool
+    """
+
+    INPUT_OUTPUT_NOT_OBSERVED = 2
+    """this means the input and output are never observed
+    example: x.shape, x.size
+    """
+
+
+@dataclass
+class DTypeWithConstraints:
+    """
+    Config for specifying additional constraints for a given dtype, such as quantization
+    value ranges, scale value ranges, and fixed quantization params, to be used in
+    :class:`~torch.ao.quantization.backend_config.DTypeConfig`.
+
+    The constraints currently supported are:
+
+    * `quant_min_lower_bound` and `quant_max_upper_bound`: Lower and upper
+      bounds for the minimum and maximum quantized values respectively. If
+      the QConfig's `quant_min` and `quant_max` fall outside this range,
+      then the QConfig will be ignored.
+
+    * `scale_min_lower_bound` and `scale_max_upper_bound`: Lower and upper
+      bounds for the minimum and maximum scale values respectively. If the
+      QConfig's minimum scale value (currently exposed as `eps`) falls below
+      the lower bound, then the QConfig will be ignored. Note that the upper
+      bound is currently not enforced.
+
+    * `scale_exact_match` and `zero_point_exact_match`: Exact match requirements
+      for scale and zero point, to be used for operators with fixed quantization
+      parameters such as sigmoid and tanh. If the observer specified in the QConfig
+      is neither `FixedQParamsObserver` nor `FixedQParamsFakeQuantize`, or if
+      the quantization parameters don't match, then the QConfig will be ignored.
+    """
+
+    dtype: Optional[torch.dtype] = None
+    quant_min_lower_bound: Union[int, float, None] = None
+    quant_max_upper_bound: Union[int, float, None] = None
+    scale_min_lower_bound: Union[int, float, None] = None
+    scale_max_upper_bound: Union[int, float, None] = None
+    scale_exact_match: Optional[float] = None
+    zero_point_exact_match: Optional[int] = None
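+
+
+# Illustrative sketch (not part of the original module): constraints of the kind
+# a backend might attach to a fixed-qparams activation such as sigmoid. The
+# helper name and the exact values are assumptions for demonstration only.
+def _example_fixed_qparams_constraints() -> DTypeWithConstraints:
+    return DTypeWithConstraints(
+        dtype=torch.quint8,
+        quant_min_lower_bound=0,
+        quant_max_upper_bound=255,
+        # QConfigs whose observer is not a fixed-qparams observer with exactly
+        # these values would be ignored for the constrained pattern.
+        scale_exact_match=1.0 / 256.0,
+        zero_point_exact_match=0,
+    )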
+
+
+@dataclass
+class DTypeConfig:
+    """
+    Config object that specifies the supported data types passed as arguments to
+    quantize ops in the reference model spec, for input and output activations,
+    weights, and biases.
+
+    For example, consider the following reference model:
+
+      quant1 - [dequant1 - fp32_linear - quant2] - dequant2
+
+    The pattern in the square brackets refers to the reference pattern of
+    statically quantized linear. Setting the input dtype as `torch.quint8`
+    in the DTypeConfig means we pass in `torch.quint8` as the dtype argument
+    to the first quantize op (quant1). Similarly, setting the output dtype as
+    `torch.quint8` means we pass in `torch.quint8` as the dtype argument to
+    the second quantize op (quant2).
+
+    Note that the dtype here does not refer to the interface dtypes of the
+    op. For example, the "input dtype" here is not the dtype of the input
+    tensor passed to the quantized linear op. Though it can still be the
+    same as the interface dtype, this is not always the case, e.g. the
+    interface dtype is fp32 in dynamic quantization but the "input dtype"
+    specified in the DTypeConfig would still be quint8. The semantics of
+    dtypes here are the same as the semantics of the dtypes specified in
+    the observers.
+
+    These dtypes are matched against the ones specified in the user's
+    QConfig. If there is a match, and the QConfig satisfies the constraints
+    specified in the DTypeConfig (if any), then we will quantize the given
+    pattern using this DTypeConfig. Otherwise, the QConfig is ignored and
+    the pattern will not be quantized.
+
+    Example usage::
+
+        >>> # xdoctest: +SKIP(failing)
+        >>> dtype_config1 = DTypeConfig(
+        ...     input_dtype=torch.quint8,
+        ...     output_dtype=torch.quint8,
+        ...     weight_dtype=torch.qint8,
+        ...     bias_dtype=torch.float)
+
+        >>> dtype_config2 = DTypeConfig(
+        ...     input_dtype=DTypeWithConstraints(
+        ...         dtype=torch.quint8,
+        ...         quant_min_lower_bound=0,
+        ...         quant_max_upper_bound=255,
+        ...     ),
+        ...     output_dtype=DTypeWithConstraints(
+        ...         dtype=torch.quint8,
+        ...         quant_min_lower_bound=0,
+        ...         quant_max_upper_bound=255,
+        ...     ),
+        ...     weight_dtype=DTypeWithConstraints(
+        ...         dtype=torch.qint8,
+        ...         quant_min_lower_bound=-128,
+        ...         quant_max_upper_bound=127,
+        ...     ),
+        ...     bias_dtype=torch.float)
+
+        >>> dtype_config1.input_dtype
+        torch.quint8
+
+        >>> dtype_config2.input_dtype
+        torch.quint8
+
+        >>> dtype_config2.input_dtype_with_constraints
+        DTypeWithConstraints(dtype=torch.quint8, quant_min_lower_bound=0, quant_max_upper_bound=255, \
+scale_min_lower_bound=None, scale_max_upper_bound=None)
+    """
+
+    input_dtype_with_constraints: DTypeWithConstraints
+    output_dtype_with_constraints: DTypeWithConstraints
+    weight_dtype_with_constraints: DTypeWithConstraints
+    bias_dtype: Optional[torch.dtype]
+    is_dynamic: Optional[bool]
+
+    def __init__(
+        self,
+        input_dtype: Union[torch.dtype, DTypeWithConstraints, None] = None,
+        output_dtype: Union[torch.dtype, DTypeWithConstraints, None] = None,
+        weight_dtype: Union[torch.dtype, DTypeWithConstraints, None] = None,
+        bias_dtype: Optional[torch.dtype] = None,
+        is_dynamic: Optional[bool] = None,
+    ):
+        if isinstance(input_dtype, DTypeWithConstraints):
+            self.input_dtype_with_constraints = input_dtype
+        else:
+            self.input_dtype_with_constraints = DTypeWithConstraints(dtype=input_dtype)
+
+        if isinstance(output_dtype, DTypeWithConstraints):
+            self.output_dtype_with_constraints = output_dtype
+        else:
+            self.output_dtype_with_constraints = DTypeWithConstraints(
+                dtype=output_dtype
+            )
+
+        if isinstance(weight_dtype, DTypeWithConstraints):
+            self.weight_dtype_with_constraints = weight_dtype
+        else:
+            self.weight_dtype_with_constraints = DTypeWithConstraints(
+                dtype=weight_dtype
+            )
+
+        self.bias_dtype = bias_dtype
+        self.is_dynamic = is_dynamic
+
+    @property
+    def input_dtype(self) -> Optional[torch.dtype]:
+        return self.input_dtype_with_constraints.dtype
+
+    @property
+    def output_dtype(self) -> Optional[torch.dtype]:
+        return self.output_dtype_with_constraints.dtype
+
+    @property
+    def weight_dtype(self) -> Optional[torch.dtype]:
+        return self.weight_dtype_with_constraints.dtype
+
+    @classmethod
+    def from_dict(cls, dtype_config_dict: dict[str, Any]) -> DTypeConfig:
+        """
+        Create a ``DTypeConfig`` from a dictionary with the following items (all optional):
+            "input_dtype": torch.dtype or ``DTypeWithConstraints``
+            "output_dtype": torch.dtype or ``DTypeWithConstraints``
+            "weight_dtype": torch.dtype or ``DTypeWithConstraints``
+            "bias_type": torch.dtype
+            "is_dynamic": bool
+        """
+        input_dtype = dtype_config_dict.get(INPUT_DTYPE_DICT_KEY, None)
+        if input_dtype is not None and not isinstance(
+            input_dtype, (torch.dtype, DTypeWithConstraints)
+        ):
+            raise ValueError(
+                "Expected input_dtype to be a torch.dtype or DTypeWithConstraints"
+            )
+        output_dtype = dtype_config_dict.get(OUTPUT_DTYPE_DICT_KEY, None)
+        if output_dtype is not None and not isinstance(
+            output_dtype, (torch.dtype, DTypeWithConstraints)
+        ):
+            raise ValueError(
+                "Expected output_dtype to be a torch.dtype or DTypeWithConstraints"
+            )
+        weight_dtype = dtype_config_dict.get(WEIGHT_DTYPE_DICT_KEY, None)
+        if weight_dtype is not None and not isinstance(
+            weight_dtype, (torch.dtype, DTypeWithConstraints)
+        ):
+            raise ValueError(
+                "Expected weight_dtype to be a torch.dtype or DTypeWithConstraints"
+            )
+        bias_dtype = dtype_config_dict.get(BIAS_DTYPE_DICT_KEY, None)
+        is_dynamic = dtype_config_dict.get(IS_DYNAMIC_DICT_KEY, None)
+        return cls(input_dtype, output_dtype, weight_dtype, bias_dtype, is_dynamic)
+
+    def to_dict(self) -> dict[str, Any]:
+        """
+        Convert this ``DTypeConfig`` to a dictionary with the items described in
+        :func:`~torch.ao.quantization.backend_config.DTypeConfig.from_dict`.
+        """
+        dtype_config_dict: dict[str, Any] = {}
+        if self.input_dtype is not None:
+            dtype_config_dict[INPUT_DTYPE_DICT_KEY] = self.input_dtype_with_constraints
+        if self.output_dtype is not None:
+            dtype_config_dict[OUTPUT_DTYPE_DICT_KEY] = (
+                self.output_dtype_with_constraints
+            )
+        if self.weight_dtype is not None:
+            dtype_config_dict[WEIGHT_DTYPE_DICT_KEY] = (
+                self.weight_dtype_with_constraints
+            )
+        if self.bias_dtype is not None:
+            dtype_config_dict[BIAS_DTYPE_DICT_KEY] = self.bias_dtype
+        if self.is_dynamic is not None:
+            dtype_config_dict[IS_DYNAMIC_DICT_KEY] = self.is_dynamic
+        return dtype_config_dict
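+
+
+# Illustrative sketch (not part of the original class): the dict form described
+# in `from_dict`/`to_dict` round-trips through the same keys. The helper name is
+# hypothetical.
+def _example_dtype_config_roundtrip() -> dict:
+    dtype_config = DTypeConfig.from_dict(
+        {
+            "input_dtype": torch.quint8,
+            "output_dtype": torch.quint8,
+            "weight_dtype": torch.qint8,
+            "bias_dtype": torch.float,
+        }
+    )
+    # Only the entries that were set come back; unset fields are omitted.
+    return dtype_config.to_dict()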
+
+
+class BackendConfig:
+    # TODO: refer to NativeBackendConfig once that is implemented
+    """Config that defines the set of patterns that can be quantized on a given backend, and how reference
+    quantized models can be produced from these patterns.
+
+    A pattern in this context refers to a module, a functional, an operator, or a directed acyclic graph
+    of the above. Each pattern supported on the target backend can be individually configured through
+    :class:`~torch.ao.quantization.backend_config.BackendPatternConfig` in terms of:
+
+    (1) The supported input/output activation, weight, and bias data types
+
+    (2) How observers and quant/dequant ops are inserted in order to construct the reference pattern, and
+
+    (3) (Optionally) Fusion, QAT, and reference module mappings.
+
+    The format of the patterns is described in:
+    https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/backend_config/README.md
+
+    Example usage::
+
+        import torch
+        from torch.ao.quantization.backend_config import (
+            BackendConfig,
+            BackendPatternConfig,
+            DTypeConfig,
+            ObservationType,
+        )
+
+        weighted_int8_dtype_config = DTypeConfig(
+            input_dtype=torch.quint8,
+            output_dtype=torch.quint8,
+            weight_dtype=torch.qint8,
+            bias_dtype=torch.float)
+
+        def fuse_conv2d_relu(is_qat, conv, relu):
+            return torch.ao.nn.intrinsic.ConvReLU2d(conv, relu)
+
+        # For quantizing Linear
+        linear_config = BackendPatternConfig(torch.nn.Linear) \
+            .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
+            .add_dtype_config(weighted_int8_dtype_config) \
+            .set_root_module(torch.nn.Linear) \
+            .set_qat_module(torch.ao.nn.qat.Linear) \
+            .set_reference_quantized_module(torch.ao.nn.quantized.reference.Linear)
+
+        # For fusing Conv2d + ReLU into ConvReLU2d
+        conv_relu_config = BackendPatternConfig((torch.nn.Conv2d, torch.nn.ReLU)) \
+            .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
+            .add_dtype_config(weighted_int8_dtype_config) \
+            .set_fused_module(torch.ao.nn.intrinsic.ConvReLU2d) \
+            .set_fuser_method(fuse_conv2d_relu)
+
+        # For quantizing ConvReLU2d
+        fused_conv_relu_config = BackendPatternConfig(torch.ao.nn.intrinsic.ConvReLU2d) \
+            .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
+            .add_dtype_config(weighted_int8_dtype_config) \
+            .set_root_module(torch.nn.Conv2d) \
+            .set_qat_module(torch.ao.nn.intrinsic.qat.ConvReLU2d) \
+            .set_reference_quantized_module(torch.ao.nn.quantized.reference.Conv2d)
+
+        backend_config = BackendConfig("my_backend") \
+            .set_backend_pattern_config(linear_config) \
+            .set_backend_pattern_config(conv_relu_config) \
+            .set_backend_pattern_config(fused_conv_relu_config)
+
+    """
+
+    def __init__(self, name: str = ""):
+        self.name = name
+        # Store all BackendPatternConfigs in a map to handle duplicates
+        # Note: the key in this map uses the complex reversed tuple format.
+        # This is intended only for internal use; users who wish to access
+        # the original patterns should go through `self.configs` instead.
+        self._pattern_complex_format_to_config: dict[Pattern, BackendPatternConfig] = {}
+
+    def __repr__(self):
+        return f"BackendConfig({self.__dict__})"
+
+    def set_name(self, name: str) -> BackendConfig:
+        """
+        Set the name of the target backend.
+        """
+        self.name = name
+        return self
+
+    def set_backend_pattern_config(self, config: BackendPatternConfig) -> BackendConfig:
+        """
+        Set the config for a pattern that can be run on the target backend.
+        This overrides any existing config for the given pattern.
+        """
+        # Avoid circular dependencies
+        pattern_complex_format = torch.ao.quantization.backend_config.utils._get_pattern_in_reversed_nested_tuple_format(
+            config
+        )  # type: ignore[attr-defined]
+        self._pattern_complex_format_to_config[pattern_complex_format] = config
+        return self
+
+    def set_backend_pattern_configs(
+        self, configs: list[BackendPatternConfig]
+    ) -> BackendConfig:
+        """
+        Set the configs for patterns that can be run on the target backend.
+        This overrides any existing config previously registered for a given pattern.
+        """
+        for conf in configs:
+            self.set_backend_pattern_config(conf)
+        return self
+
+    @property
+    def configs(self) -> list[BackendPatternConfig]:
+        """
+        Return a copy of the list of configs set in this `BackendConfig`.
+        """
+        return list(self._pattern_complex_format_to_config.values())
+
+    @classmethod
+    def from_dict(cls, backend_config_dict: dict[str, Any]) -> BackendConfig:
+        """
+        Create a ``BackendConfig`` from a dictionary with the following items:
+
+            "name": the name of the target backend
+
+            "configs": a list of dictionaries that each represents a `BackendPatternConfig`
+
+        """
+        conf = cls(backend_config_dict.get(NAME_DICT_KEY, ""))
+        for d in backend_config_dict.get(CONFIGS_DICT_KEY, []):
+            if isinstance(d, BackendPatternConfig):
+                conf.set_backend_pattern_config(d)
+            elif isinstance(d, dict):
+                conf.set_backend_pattern_config(BackendPatternConfig.from_dict(d))
+            else:
+                raise ValueError(
+                    f"Expected backend_config_dict['{CONFIGS_DICT_KEY}'] to be a dictionary"
+                )
+        return conf
+
+    def to_dict(self) -> dict[str, Any]:
+        """
+        Convert this ``BackendConfig`` to a dictionary with the items described in
+        :func:`~torch.ao.quantization.backend_config.BackendConfig.from_dict`.
+        """
+        return {
+            NAME_DICT_KEY: self.name,
+            CONFIGS_DICT_KEY: [c.to_dict() for c in self.configs],
+        }
+
+
+class BackendPatternConfig:
+    """
+    Config object that specifies quantization behavior for a given operator pattern.
+    For a detailed example usage, see :class:`~torch.ao.quantization.backend_config.BackendConfig`.
+    """
+
+    def __init__(self, pattern: Optional[Pattern] = None):
+        self.pattern: Optional[Pattern] = pattern
+        self.observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
+        self.dtype_configs: list[DTypeConfig] = []
+        self.root_module: Optional[type[torch.nn.Module]] = None
+        self.qat_module: Optional[type[torch.nn.Module]] = None
+        self.reference_quantized_module: Optional[type[torch.nn.Module]] = None
+        self.fused_module: Optional[type[torch.nn.Module]] = None
+        self.fuser_method: Optional[Callable] = None
+
+        # Temporary/internal configs
+        self._root_node_getter: Optional[Callable] = None
+        self._extra_inputs_getter: Optional[Callable] = None
+        self._num_tensor_args_to_observation_type: dict[int, ObservationType] = {}
+        self._input_type_to_index: dict[str, int] = {}
+        self._pattern_complex_format: Optional[Pattern] = None
+
+    def __repr__(self):
+        dict_nonempty = {
+            k: v
+            for k, v in self.__dict__.items()
+            if (
+                (not isinstance(v, (list, dict)) and v is not None)
+                or (isinstance(v, (list, dict)) and len(v) > 0)
+            )
+        }
+        return f"BackendPatternConfig({dict_nonempty})"
+
+    def set_pattern(self, pattern: Pattern) -> BackendPatternConfig:
+        """
+        Set the pattern to configure.
+
+        The pattern can be a float module, functional operator, pytorch operator, or a tuple
+        combination of the above. Tuple patterns are treated as sequential patterns, and
+        currently only tuples of 2 or 3 elements are supported.
+        """
+        if self._pattern_complex_format is not None:
+            raise ValueError(
+                "Only one of 'pattern' or 'pattern_complex_format' can be set"
+            )
+        self.pattern = pattern
+        return self
+
+    def set_observation_type(
+        self, observation_type: ObservationType
+    ) -> BackendPatternConfig:
+        """
+        Set how observers should be inserted in the graph for this pattern.
+
+        Observation type here refers to how observers (or quant-dequant ops) will be placed
+        in the graph. This is used to produce the desired reference patterns understood by
+        the backend. Weighted ops such as linear and conv require different observers
+        (or quantization parameters passed to quantize ops in the reference model) for the
+        input and the output.
+
+        There are two observation types:
+
+            `OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT` (default): the output observer instance
+            will be different from the input. This is the most common observation type.
+
+            `OUTPUT_SHARE_OBSERVER_WITH_INPUT`: the output observer instance will be the
+            same as the input. This is useful for operators like `cat`.
+
+        Note: This will be renamed in the near future, since we will soon insert QuantDeQuantStubs
+        with observers (and fake quantizes) attached instead of observers themselves.
+        """
+        self.observation_type = observation_type
+        return self
+
+    def add_dtype_config(self, dtype_config: DTypeConfig) -> BackendPatternConfig:
+        """
+        Add a set of supported data types passed as arguments to quantize ops in the
+        reference model spec.
+        """
+        self.dtype_configs.append(dtype_config)
+        return self
+
+    def set_dtype_configs(
+        self, dtype_configs: list[DTypeConfig]
+    ) -> BackendPatternConfig:
+        """
+        Set the supported data types passed as arguments to quantize ops in the
+        reference model spec, overriding all previously registered data types.
+        """
+        self.dtype_configs = dtype_configs
+        return self
+
+    def set_root_module(
+        self, root_module: type[torch.nn.Module]
+    ) -> BackendPatternConfig:
+        """
+        Set the module that represents the root for this pattern.
+
+        When we construct the reference quantized model during the convert phase,
+        the root modules (e.g. torch.nn.Linear for torch.ao.nn.intrinsic.LinearReLU)
+        will be swapped to the corresponding reference quantized modules (e.g.
+        torch.ao.nn.quantized.reference.Linear). This allows custom backends to
+        specify custom reference quantized module implementations to match the
+        numerics of their lowered operators. Since this is a one-to-one mapping,
+        both the root module and the reference quantized module must be specified
+        in the same BackendPatternConfig in order for the conversion to take place.
+        """
+        self.root_module = root_module
+        return self
+
+    def set_qat_module(self, qat_module: type[torch.nn.Module]) -> BackendPatternConfig:
+        """
+        Set the module that represents the QAT implementation for this pattern.
+        """
+        self.qat_module = qat_module
+        return self
+
+    def set_reference_quantized_module(
+        self, reference_quantized_module: type[torch.nn.Module]
+    ) -> BackendPatternConfig:
+        """
+        Set the module that represents the reference quantized implementation for
+        this pattern's root module.
+
+        For more detail, see :func:`~torch.ao.quantization.backend_config.BackendPatternConfig.set_root_module`.
+        """
+        self.reference_quantized_module = reference_quantized_module
+        return self
+
+    def set_fused_module(
+        self, fused_module: type[torch.nn.Module]
+    ) -> BackendPatternConfig:
+        """
+        Set the module that represents the fused implementation for this pattern.
+        """
+        self.fused_module = fused_module
+        return self
+
+    def set_fuser_method(self, fuser_method: Callable) -> BackendPatternConfig:
+        """
+        Set the function that specifies how to fuse this BackendPatternConfig's pattern.
+
+        The first argument of this function should be `is_qat`, and the rest of the arguments
+        should be the items in the tuple pattern. The return value of this function should be
+        the resulting fused module.
+
+        For example, the fuser method for the pattern `(torch.nn.Linear, torch.nn.ReLU)` can be:
+
+            def fuse_linear_relu(is_qat, linear, relu):
+                return torch.ao.nn.intrinsic.LinearReLU(linear, relu)
+
+        For a more complicated example, see https://gist.github.com/jerryzh168/8bea7180a8ba3c279f2c9b050f2a69a6.
+        """
+        self.fuser_method = fuser_method
+        return self
+
+    def _set_root_node_getter(self, root_node_getter: Callable) -> BackendPatternConfig:
+        self._root_node_getter = root_node_getter
+        return self
+
+    def _set_extra_inputs_getter(
+        self, extra_inputs_getter: Callable
+    ) -> BackendPatternConfig:
+        self._extra_inputs_getter = extra_inputs_getter
+        return self
+
+    def _set_num_tensor_args_to_observation_type(
+        self, num_tensor_args_to_observation_type: dict[int, ObservationType]
+    ) -> BackendPatternConfig:
+        self._num_tensor_args_to_observation_type = num_tensor_args_to_observation_type
+        return self
+
+    def _set_input_type_to_index(
+        self, input_type_to_index: dict[str, int]
+    ) -> BackendPatternConfig:
+        self._input_type_to_index = input_type_to_index
+        return self
+
+    def _set_pattern_complex_format(self, pattern: Pattern) -> BackendPatternConfig:
+        """
+        Set the pattern to configure, using the reversed nested tuple format.
+
+        See the BackendConfig README for more detail:
+        https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/backend_config/README.md#advanced-pattern-specification
+        """
+        if self.pattern is not None:
+            raise ValueError(
+                "Only one of 'pattern' or 'pattern_complex_format' can be set"
+            )
+        self._pattern_complex_format = pattern
+        return self
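+
+    # Illustrative note (not part of the original class): in the reversed nested
+    # tuple format, a sequential pattern such as Conv2d -> BatchNorm2d -> ReLU is
+    # written inside-out, e.g.
+    #     (torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d))
+    # whereas `set_pattern` would take the forward tuple
+    #     (torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU).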
+
+    @classmethod
+    def from_dict(
+        cls, backend_pattern_config_dict: dict[str, Any]
+    ) -> BackendPatternConfig:
+        """
+        Create a ``BackendPatternConfig`` from a dictionary with the following items:
+
+            "pattern": the pattern being configured
+            "observation_type": the :class:`~torch.ao.quantization.backend_config.ObservationType` that specifies how
+            observers should be inserted for this pattern
+            "dtype_configs": a list of dictionaries that represents :class:`~torch.ao.quantization.backend_config.DTypeConfig` s
+            "root_module": a :class:`torch.nn.Module` that represents the root for this pattern
+            "qat_module": a :class:`torch.nn.Module` that represents the QAT implementation for this pattern
+            "reference_quantized_module": a :class:`torch.nn.Module` that represents the reference quantized
+            implementation for this pattern's root module.
+            "fused_module": a :class:`torch.nn.Module` that represents the fused implementation for this pattern
+            "fuser_method": a function that specifies how to fuse the pattern for this pattern
+            "pattern_complex_format": the pattern specified in the reversed nested tuple format (deprecated)
+
+        """
+
+        def _get_dtype_config(obj: Any) -> DTypeConfig:
+            """
+            Convert the given object into a ``DTypeConfig`` if possible, else throw an exception.
+            """
+            if isinstance(obj, DTypeConfig):
+                return obj
+            if isinstance(obj, dict):
+                return DTypeConfig.from_dict(obj)
+            raise ValueError(
+                f"Expected a list of DTypeConfigs in "
+                f"backend_pattern_config_dict[\"{DTYPE_CONFIGS_DICT_KEY}\"], got '{type(obj)}'"
+            )
+
+        conf = cls()
+        if PATTERN_DICT_KEY in backend_pattern_config_dict:
+            conf.set_pattern(backend_pattern_config_dict[PATTERN_DICT_KEY])
+        if OBSERVATION_TYPE_DICT_KEY in backend_pattern_config_dict:
+            conf.set_observation_type(
+                backend_pattern_config_dict[OBSERVATION_TYPE_DICT_KEY]
+            )
+        for d in backend_pattern_config_dict.get(DTYPE_CONFIGS_DICT_KEY, []):
+            conf.add_dtype_config(_get_dtype_config(d))
+        conf.set_root_module(
+            backend_pattern_config_dict.get(ROOT_MODULE_DICT_KEY, None)  # type: ignore[arg-type]
+        )
+        conf.set_qat_module(backend_pattern_config_dict.get(QAT_MODULE_DICT_KEY, None))  # type: ignore[arg-type]
+        conf.set_reference_quantized_module(
+            backend_pattern_config_dict.get(REFERENCE_QUANTIZED_MODULE_DICT_KEY, None)  # type: ignore[arg-type]
+        )
+        conf.set_fused_module(
+            backend_pattern_config_dict.get(FUSED_MODULE_DICT_KEY, None)  # type: ignore[arg-type]
+        )
+        conf.set_fuser_method(
+            backend_pattern_config_dict.get(FUSER_METHOD_DICT_KEY, None)  # type: ignore[arg-type]
+        )
+        conf._set_root_node_getter(
+            backend_pattern_config_dict.get(ROOT_NODE_GETTER_DICT_KEY, None)  # type: ignore[arg-type]
+        )
+        conf._set_extra_inputs_getter(
+            backend_pattern_config_dict.get(EXTRA_INPUTS_GETTER_DICT_KEY, None)  # type: ignore[arg-type]
+        )
+        conf._set_num_tensor_args_to_observation_type(
+            backend_pattern_config_dict.get(
+                NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY, {}
+            )
+        )
+        conf._set_input_type_to_index(
+            backend_pattern_config_dict.get(INPUT_TYPE_TO_INDEX_DICT_KEY, {})
+        )
+        if PATTERN_COMPLEX_FORMAT_DICT_KEY in backend_pattern_config_dict:
+            conf._set_pattern_complex_format(
+                backend_pattern_config_dict[PATTERN_COMPLEX_FORMAT_DICT_KEY]
+            )
+        return conf
+
+    def to_dict(self) -> dict[str, Any]:
+        """
+        Convert this ``BackendPatternConfig`` to a dictionary with the items described in
+        :func:`~torch.ao.quantization.backend_config.BackendPatternConfig.from_dict`.
+        """
+        backend_pattern_config_dict: dict[str, Any] = {
+            OBSERVATION_TYPE_DICT_KEY: self.observation_type,
+            DTYPE_CONFIGS_DICT_KEY: [c.to_dict() for c in self.dtype_configs],
+        }
+        if self.pattern is not None:
+            backend_pattern_config_dict[PATTERN_DICT_KEY] = self.pattern
+        if self.root_module is not None:
+            backend_pattern_config_dict[ROOT_MODULE_DICT_KEY] = self.root_module
+        if self.qat_module is not None:
+            backend_pattern_config_dict[QAT_MODULE_DICT_KEY] = self.qat_module
+        if self.reference_quantized_module is not None:
+            backend_pattern_config_dict[REFERENCE_QUANTIZED_MODULE_DICT_KEY] = (
+                self.reference_quantized_module
+            )
+        if self.fused_module is not None:
+            backend_pattern_config_dict[FUSED_MODULE_DICT_KEY] = self.fused_module
+        if self.fuser_method is not None:
+            backend_pattern_config_dict[FUSER_METHOD_DICT_KEY] = self.fuser_method
+        if self._root_node_getter is not None:
+            backend_pattern_config_dict[ROOT_NODE_GETTER_DICT_KEY] = (
+                self._root_node_getter
+            )
+        if self._extra_inputs_getter is not None:
+            backend_pattern_config_dict[EXTRA_INPUTS_GETTER_DICT_KEY] = (
+                self._extra_inputs_getter
+            )
+        if len(self._num_tensor_args_to_observation_type) > 0:
+            backend_pattern_config_dict[
+                NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY
+            ] = self._num_tensor_args_to_observation_type
+        if len(self._input_type_to_index) > 0:
+            backend_pattern_config_dict[INPUT_TYPE_TO_INDEX_DICT_KEY] = (
+                self._input_type_to_index
+            )
+        if self._pattern_complex_format is not None:
+            backend_pattern_config_dict[PATTERN_COMPLEX_FORMAT_DICT_KEY] = (
+                self._pattern_complex_format
+            )
+        return backend_pattern_config_dict
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/executorch.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/executorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b9b16492821b73dba1ff3ce6e2617d844d94229
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/executorch.py
@@ -0,0 +1,498 @@
+# TODO: rename executorch to qnnpack_executorch since executorch is a general runtime
+# not a specific backend
+
+import operator
+
+import torch
+import torch.ao.nn.qat as nnqat
+import torch.ao.nn.quantized.reference as nnqr
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.ao.quantization.fuser_method_mappings import (
+    _sequential_wrapper2,
+    fuse_conv_bn,
+    fuse_conv_bn_relu,
+)
+
+from ._common_operator_config_utils import _Conv2dMetadata
+from .backend_config import (
+    BackendConfig,
+    BackendPatternConfig,
+    DTypeConfig,
+    DTypeWithConstraints,
+    ObservationType,
+)
+from .qnnpack import (
+    qnnpack_default_op_qint8_symmetric_dtype_config,
+    qnnpack_weighted_op_qint8_symmetric_dtype_config,
+)
+
+
+__all__ = [
+    "get_executorch_backend_config",
+]
+
+
+# ===================
+# |  DTYPE CONFIGS  |
+# ===================
+
+executorch_weighted_op_int8_dtype_config = DTypeConfig(
+    input_dtype=torch.quint8,
+    output_dtype=torch.quint8,
+    weight_dtype=torch.qint8,
+    bias_dtype=torch.float,
+)
+
+executorch_default_op_quint8_dtype_config = DTypeConfig(
+    input_dtype=torch.quint8,
+    output_dtype=torch.quint8,
+)
+
+executorch_default_dynamic_quint8_dtype_config = DTypeConfig(
+    input_dtype=torch.quint8,
+    output_dtype=torch.float,
+    weight_dtype=torch.qint8,
+    bias_dtype=torch.float,
+    is_dynamic=True,
+)
+
+executorch_act_qint8_scale_min_2_neg_12 = DTypeWithConstraints(
+    dtype=torch.qint8,
+    scale_min_lower_bound=2**-12,
+)
+
+executorch_weight_qint8_neg_127_to_127_scale_min_2_neg_12 = DTypeWithConstraints(
+    dtype=torch.qint8,
+    quant_min_lower_bound=-127,
+    quant_max_upper_bound=127,
+    scale_min_lower_bound=2**-12,
+)
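+
+# Note: the activation constraint above enforces a minimum scale of 2 ** -12 (~2.44e-4);
+# the weight constraint additionally clamps quantization values to [-127, 127],
+# excluding -128, matching the xnnpack-compatible constraints defined in qnnpack.py.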
+
+executorch_default_dynamic_qint8_dtype_config = DTypeConfig(
+    input_dtype=executorch_act_qint8_scale_min_2_neg_12,
+    output_dtype=torch.float,
+    weight_dtype=executorch_weight_qint8_neg_127_to_127_scale_min_2_neg_12,
+    bias_dtype=torch.float,
+    is_dynamic=True,
+)
+
+executorch_default_dynamic_float16_dtype_config = DTypeConfig(
+    input_dtype=torch.float16,
+    output_dtype=torch.float,
+    weight_dtype=torch.float16,
+    bias_dtype=torch.float,
+    is_dynamic=True,
+)
+
+executorch_weight_only_quint8_dtype_config = DTypeConfig(
+    input_dtype=torch.float,
+    output_dtype=torch.float,
+    weight_dtype=torch.quint8,
+)
+
+
+# =============================
+# |  BACKEND PATTERN CONFIGS  |
+# =============================
+
+
+def _get_linear_configs() -> list[BackendPatternConfig]:
+    """
+    Return all configs related to linear modules and ops.
+    """
+    observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
+    dtype_configs = [
+        qnnpack_weighted_op_qint8_symmetric_dtype_config,
+        executorch_weighted_op_int8_dtype_config,
+        executorch_default_dynamic_quint8_dtype_config,
+        executorch_default_dynamic_qint8_dtype_config,
+        executorch_default_dynamic_float16_dtype_config,
+    ]
+    linear_configs: list[BackendPatternConfig] = []
+    # linear module
+    linear_configs.append(
+        BackendPatternConfig(torch.nn.Linear)
+        .set_observation_type(observation_type)  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+        .set_root_module(torch.nn.Linear)
+        .set_reference_quantized_module(nnqr.Linear)
+        .set_qat_module(nnqat.Linear)
+    )
+    # linear qat module
+    linear_configs.append(
+        BackendPatternConfig(nnqat.Linear)
+        .set_observation_type(observation_type)  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+        .set_root_module(torch.nn.Linear)
+        .set_reference_quantized_module(nnqr.Linear)
+    )
+    # functional linear
+    linear_configs.append(
+        BackendPatternConfig(torch.nn.functional.linear)
+        .set_observation_type(observation_type)  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+        ._set_input_type_to_index({"weight": 1, "bias": 2})
+    )
+    return linear_configs
+
+
+def _get_conv_configs() -> list[BackendPatternConfig]:
+    """
+    Return all configs related to conv modules and ops.
+    """
+    observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
+    dtype_configs = [
+        qnnpack_weighted_op_qint8_symmetric_dtype_config,
+        executorch_weighted_op_int8_dtype_config,
+    ]
+    conv_configs = []
+    for convs in [_Conv2dMetadata]:
+        # (1) Single conv modules/functions
+        # -----------------------------------
+        # conv module
+        conv_configs.append(
+            BackendPatternConfig(convs.root)
+            .set_observation_type(observation_type)  # noqa: E131
+            .set_dtype_configs(dtype_configs)
+            .set_root_module(convs.root)
+            .set_reference_quantized_module(convs.reference)
+            .set_qat_module(convs.qat)
+        )
+        # conv qat module
+        conv_configs.append(
+            BackendPatternConfig(convs.qat)
+            .set_observation_type(observation_type)  # noqa: E131
+            .set_dtype_configs(dtype_configs)
+            .set_root_module(convs.root)
+            .set_reference_quantized_module(convs.reference)
+        )
+        # functional conv
+        conv_configs.append(
+            BackendPatternConfig(convs.func)
+            .set_observation_type(observation_type)  # noqa: E131
+            .set_dtype_configs(dtype_configs)
+            ._set_input_type_to_index({"weight": 1, "bias": 2})
+        )
+
+        # (2) Conv + relu
+        # -----------------------------------
+        # conv module + relu module
+        conv_configs.append(
+            BackendPatternConfig((convs.root, nn.ReLU))
+            .set_dtype_configs(dtype_configs)  # noqa: E131
+            .set_fuser_method(_sequential_wrapper2(convs.fused_conv_relu))
+            .set_fused_module(convs.fused_conv_relu)
+        )
+        # conv module + functional relu
+        conv_configs.append(
+            BackendPatternConfig((convs.root, F.relu))
+            .set_dtype_configs(dtype_configs)  # noqa: E131
+            .set_fuser_method(_sequential_wrapper2(convs.fused_conv_relu))
+            .set_fused_module(convs.fused_conv_relu)
+        )
+        # fused conv relu module
+        conv_configs.append(
+            BackendPatternConfig(convs.fused_conv_relu)
+            .set_observation_type(observation_type)  # noqa: E131
+            .set_dtype_configs(dtype_configs)
+            .set_root_module(convs.root)
+            .set_reference_quantized_module(convs.reference)
+            .set_qat_module(convs.relu_qat)
+        )
+        # conv relu, qat fused module
+        conv_configs.append(
+            BackendPatternConfig(convs.relu_qat)
+            .set_observation_type(observation_type)  # noqa: E131
+            .set_dtype_configs(dtype_configs)
+            .set_root_module(convs.root)
+            .set_reference_quantized_module(convs.reference)
+        )
+        # functional conv + relu module
+        conv_configs.append(
+            BackendPatternConfig((convs.func, nn.ReLU))
+            .set_observation_type(observation_type)  # noqa: E131
+            .set_dtype_configs(dtype_configs)
+        )
+        # functional conv + functional relu
+        conv_configs.append(
+            BackendPatternConfig((convs.func, F.relu))
+            .set_observation_type(observation_type)  # noqa: E131
+            .set_dtype_configs(dtype_configs)
+        )
+        # fused conv relu
+        conv_configs.append(
+            BackendPatternConfig(convs.fused_conv_relu)
+            .set_dtype_configs(dtype_configs)  # noqa: E131
+            .set_qat_module(convs.relu_qat)
+        )
+
+        conv_configs.append(
+            BackendPatternConfig(convs.relu_qat)
+            .set_dtype_configs(dtype_configs)  # noqa: E131
+            .set_root_module(convs.root)
+            .set_reference_quantized_module(convs.reference)
+        )
+
+        # (3) Conv + batchnorm (+ relu)
+        # -------------------------------
+        # conv + batchnorm (+ relu)
+        conv_configs.append(
+            BackendPatternConfig((convs.root, convs.bn))
+            .set_dtype_configs(dtype_configs)  # noqa: E131
+            .set_fuser_method(fuse_conv_bn)
+            .set_fused_module(convs.fused_conv_bn)
+        )
+        # conv + bn + relu module fusion
+        conv_configs.append(
+            BackendPatternConfig((convs.root, convs.bn, nn.ReLU))
+            .set_dtype_configs(dtype_configs)  # noqa: E131
+            .set_fuser_method(fuse_conv_bn_relu)
+            .set_fused_module(convs.fused_conv_bn_relu)
+        )
+        # conv + bn + relu functional fusion
+        conv_configs.append(
+            BackendPatternConfig((convs.root, convs.bn, F.relu))
+            .set_dtype_configs(dtype_configs)  # noqa: E131
+            .set_root_module(convs.root)
+            .set_fuser_method(fuse_conv_bn_relu)
+            .set_fused_module(convs.fused_conv_bn_relu)
+        )
+        # TODO: we can add fusion for torch.relu as well
+        # 3.2 conv + bn (+ relu) fused module configs
+        # fused conv bn
+        conv_configs.append(
+            BackendPatternConfig(convs.fused_conv_bn)
+            .set_dtype_configs(dtype_configs)  # noqa: E131
+            .set_qat_module(convs.bn_qat)
+        )
+
+        # fused conv bn relu
+        conv_configs.append(
+            BackendPatternConfig(convs.fused_conv_bn_relu)
+            .set_dtype_configs(dtype_configs)  # noqa: E131
+            .set_qat_module(convs.bn_relu_qat)
+        )
+
+        # conv bn, qat fused module
+        conv_configs.append(
+            BackendPatternConfig(convs.bn_qat)
+            .set_observation_type(observation_type)  # noqa: E131
+            .set_dtype_configs(dtype_configs)
+            .set_root_module(convs.root)
+            .set_reference_quantized_module(convs.reference)
+        )
+        # conv bn relu, qat fused module
+        conv_configs.append(
+            BackendPatternConfig(convs.bn_relu_qat)
+            .set_observation_type(observation_type)  # noqa: E131
+            .set_dtype_configs(dtype_configs)
+            .set_root_module(convs.root)
+            .set_reference_quantized_module(convs.reference)
+        )
+    return conv_configs
+
+
+def _get_binary_ops_configs() -> list[BackendPatternConfig]:
+    """
+    Return all configs related to binary ops.
+    """
+    dtype_configs = [
+        qnnpack_default_op_qint8_symmetric_dtype_config,
+        executorch_weighted_op_int8_dtype_config,
+    ]
+    num_tensor_args_to_observation_type_mapping = {
+        # TODO: this is not used right now since we have an extra check in prepare;
+        # will need to change this to NO_OBSERVER later once Tensor dtype inference
+        # is implemented properly
+        0: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
+        1: ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT,
+        2: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
+    }
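+    # The keys are the number of Tensor arguments seen at a call site: e.g.
+    # add(x, 2.0) has one Tensor argument and shares the input observer, while
+    # add(x, y) has two and gets its own output observer.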
+    binary_op_configs: list[BackendPatternConfig] = []
+    for op in [
+        operator.add,
+        torch.add,
+        operator.sub,
+        torch.sub,
+        operator.mul,
+        torch.mul,
+    ]:
+        bop_patterns = [
+            (op, torch.nn.ReLU),
+            (op, torch.nn.functional.relu),
+            (op, torch.relu),
+            op,
+        ]
+        binary_op_configs.extend(
+            BackendPatternConfig(bop_pattern)
+            .set_dtype_configs(dtype_configs)  # noqa: E131
+            ._set_num_tensor_args_to_observation_type(
+                num_tensor_args_to_observation_type_mapping
+            )
+            for bop_pattern in bop_patterns
+        )
+    return binary_op_configs
+
+
+def _get_share_qparams_ops_configs() -> list[BackendPatternConfig]:
+    """
+    Return the operator configs for operators that work on both float and quantized
+    input. If the input is quantized, the output Tensor shares the same quantization
+    parameters with the input.
+
+    Example operators: avgpool2d, reshape, transpose, maxpool2d
+    Example observed operator:
+    observer_0 - avgpool2d - observer_0 (same observer instance as input)
+    """
+    observation_type = ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT
+    dtype_configs = [
+        qnnpack_default_op_qint8_symmetric_dtype_config,
+        executorch_default_op_quint8_dtype_config,
+    ]
+    share_qparams_ops = [
+        torch.nn.Flatten,
+        F.adaptive_avg_pool2d,
+        F.elu,
+        F.hardtanh,
+        F.max_pool2d,
+        F.pad,
+        F.relu,
+        F.relu6,
+        F.leaky_relu,
+        F.leaky_relu_,
+        torch.nn.AdaptiveAvgPool2d,
+        torch.nn.ConstantPad2d,
+        torch.nn.ELU,
+        torch.nn.MaxPool2d,
+        torch.nn.ReLU6,
+        torch.nn.Hardtanh,
+        torch.nn.LeakyReLU,
+        torch.clamp,
+        torch.flatten,
+        torch.mean,
+        torch.permute,
+        torch.permute_copy,
+        torch.squeeze,
+        "clamp",
+        "mean",
+        "permute",
+        "reshape",
+        "relu",
+        "relu_",
+        "squeeze",
+        "squeeze_",
+        "leaky_relu",
+    ]
+    share_qparams_op_configs: list[BackendPatternConfig] = [
+        BackendPatternConfig(op)
+        .set_observation_type(observation_type)  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+        for op in share_qparams_ops
+    ]
+    return share_qparams_op_configs
+
+
+def _get_bn_configs() -> list[BackendPatternConfig]:
+    """
+    Return all configs related to batchnorm.
+    """
+    observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
+    dtype_configs = [
+        qnnpack_default_op_qint8_symmetric_dtype_config,
+        executorch_default_op_quint8_dtype_config,
+    ]
+    bn_configs = []
+    bn_configs.append(
+        BackendPatternConfig(nn.BatchNorm2d)
+        .set_observation_type(observation_type)  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+    )
+    return bn_configs
+
+
+def _get_cat_configs() -> list[BackendPatternConfig]:
+    dtype_configs = [
+        qnnpack_default_op_qint8_symmetric_dtype_config,
+        executorch_default_op_quint8_dtype_config,
+    ]
+    cat_configs = []
+    cat_configs.append(
+        BackendPatternConfig(torch.cat)
+        .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT)
+        .set_dtype_configs(dtype_configs)
+    )
+    cat_configs.append(
+        BackendPatternConfig(torch.concat)
+        .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT)
+        .set_dtype_configs(dtype_configs)
+    )
+    cat_configs.append(
+        BackendPatternConfig(torch.concatenate)
+        .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT)
+        .set_dtype_configs(dtype_configs)
+    )
+    return cat_configs
+
+
+def _get_embedding_op_configs() -> list[BackendPatternConfig]:
+    dtype_configs = [
+        executorch_weight_only_quint8_dtype_config,
+    ]
+    embedding_op_configs = []
+    for embedding_op, qat_embedding_op, ref_embedding_op in [
+        (nn.Embedding, nnqat.Embedding, nnqr.Embedding),
+        (nn.EmbeddingBag, nnqat.EmbeddingBag, nnqr.EmbeddingBag),
+    ]:
+        embedding_op_configs.append(
+            BackendPatternConfig(embedding_op)
+            .set_observation_type(
+                ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
+            )  # noqa: E131
+            .set_dtype_configs(dtype_configs)
+            .set_qat_module(qat_embedding_op)
+            .set_root_module(embedding_op)
+            .set_reference_quantized_module(ref_embedding_op)
+        )
+        # config for qat op
+        embedding_op_configs.append(
+            BackendPatternConfig(qat_embedding_op)
+            .set_observation_type(
+                ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
+            )  # noqa: E131
+            .set_dtype_configs(dtype_configs)
+            .set_root_module(embedding_op)
+            .set_reference_quantized_module(ref_embedding_op)
+        )
+
+        # config for functional embedding
+        embedding_op_configs.append(
+            BackendPatternConfig(torch.nn.functional.embedding)
+            .set_observation_type(
+                ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
+            )  # noqa: E131
+            .set_dtype_configs(dtype_configs)
+            ._set_input_type_to_index({"weight": 1})
+        )
+    return embedding_op_configs
+
+
+# =====================
+# |  BACKEND CONFIGS  |
+# =====================
+
+
+def get_executorch_backend_config() -> BackendConfig:
+    """
+    Return the `BackendConfig` for backends PyTorch lowers to through the Executorch stack.
+    """
+    return (
+        BackendConfig("executorch")
+        .set_backend_pattern_configs(_get_linear_configs())
+        .set_backend_pattern_configs(_get_conv_configs())
+        .set_backend_pattern_configs(_get_binary_ops_configs())
+        .set_backend_pattern_configs(_get_share_qparams_ops_configs())
+        .set_backend_pattern_configs(_get_bn_configs())
+        .set_backend_pattern_configs(_get_cat_configs())
+        .set_backend_pattern_configs(_get_embedding_op_configs())
+    )
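+
+
+# A minimal usage sketch (not part of this module's public API): the returned
+# BackendConfig is typically passed to the FX graph mode quantization entry points.
+# `model` and `example_inputs` are assumed to exist and be FX-traceable.
+#
+#   from torch.ao.quantization import get_default_qconfig_mapping
+#   from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx
+#
+#   backend_config = get_executorch_backend_config()
+#   qconfig_mapping = get_default_qconfig_mapping("qnnpack")
+#   prepared = prepare_fx(model, qconfig_mapping, example_inputs, backend_config=backend_config)
+#   quantized = convert_fx(prepared, backend_config=backend_config)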
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/fbgemm.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/fbgemm.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d665f4fd030aba47c98ee692f0d9e7eca41cbc6
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/fbgemm.py
@@ -0,0 +1,129 @@
+import torch
+
+from ._common_operator_config_utils import (
+    _get_binary_op_configs,
+    _get_bn_configs,
+    _get_cat_config,
+    _get_conv_configs,
+    _get_default_op_configs,
+    _get_embedding_op_configs,
+    _get_fixed_qparams_op_configs,
+    _get_linear_configs,
+    _get_rnn_op_configs,
+    _get_share_qparams_op_configs,
+    _get_tensor_info_op_configs,
+)
+from .backend_config import BackendConfig, DTypeConfig
+
+
+__all__ = [
+    "get_fbgemm_backend_config",
+]
+
+# ===================
+# |  DTYPE CONFIGS  |
+# ===================
+
+# TODO: For now, these DTypeConfigs are identical to the ones defined in native.py
+# In the future, once we support specifying quant_min/quant_max and scale_min/scale_max,
+# these will diverge. In particular, for FBGEMM, we will restrict the activation quantized
+# values to within [0, 127].
+
+fbgemm_weighted_op_quint8_dtype_config = DTypeConfig(
+    input_dtype=torch.quint8,
+    output_dtype=torch.quint8,
+    weight_dtype=torch.qint8,
+    bias_dtype=torch.float,
+)
+
+fbgemm_default_op_quint8_dtype_config = DTypeConfig(
+    input_dtype=torch.quint8,
+    output_dtype=torch.quint8,
+)
+
+fbgemm_default_op_fp16_dtype_config = DTypeConfig(
+    input_dtype=torch.float16,
+    output_dtype=torch.float16,
+    weight_dtype=torch.float16,
+    bias_dtype=torch.float16,
+)
+
+fbgemm_default_dynamic_int8_dtype_config = DTypeConfig(
+    input_dtype=torch.quint8,
+    output_dtype=torch.float,
+    weight_dtype=torch.qint8,
+    bias_dtype=torch.float,
+    is_dynamic=True,
+)
+
+fbgemm_default_dynamic_float16_dtype_config = DTypeConfig(
+    input_dtype=torch.float16,
+    output_dtype=torch.float,
+    weight_dtype=torch.float16,
+    bias_dtype=torch.float,
+    is_dynamic=True,
+)
+
+fbgemm_weight_only_quint8_dtype_config = DTypeConfig(
+    input_dtype=torch.float,
+    output_dtype=torch.float,
+    weight_dtype=torch.quint8,
+)
+
+fbgemm_weight_only_quint4x2_dtype_config = DTypeConfig(
+    input_dtype=torch.float,
+    output_dtype=torch.float,
+    weight_dtype=torch.quint4x2,
+)
+
+
+# =====================
+# |  BACKEND CONFIGS  |
+# =====================
+
+
+def get_fbgemm_backend_config() -> BackendConfig:
+    """
+    Return the `BackendConfig` for PyTorch's native FBGEMM backend.
+    """
+    conv_dtype_configs = [fbgemm_weighted_op_quint8_dtype_config]
+    linear_dtype_configs = [
+        fbgemm_weighted_op_quint8_dtype_config,
+        fbgemm_default_dynamic_int8_dtype_config,
+        fbgemm_default_dynamic_float16_dtype_config,
+    ]
+    binary_op_dtype_configs = [fbgemm_default_op_quint8_dtype_config]
+    default_op_dtype_configs = [fbgemm_default_op_quint8_dtype_config]
+    fixed_qparams_op_dtype_configs = [fbgemm_default_op_quint8_dtype_config]
+    share_qparams_op_dtype_configs = [fbgemm_default_op_quint8_dtype_config]
+    tensor_info_op_dtype_configs = [fbgemm_default_op_quint8_dtype_config]
+    rnn_op_dtype_configs = [
+        fbgemm_default_dynamic_int8_dtype_config,
+        fbgemm_default_dynamic_float16_dtype_config,
+    ]
+    embedding_op_dtype_configs = [
+        fbgemm_weight_only_quint8_dtype_config,
+        fbgemm_weight_only_quint4x2_dtype_config,
+    ]
+    return (
+        BackendConfig("fbgemm")
+        .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs))
+        .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs))
+        .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs))
+        .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs))
+        .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs))
+        .set_backend_pattern_configs(
+            _get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs)
+        )
+        .set_backend_pattern_configs(
+            _get_share_qparams_op_configs(share_qparams_op_dtype_configs)
+        )
+        .set_backend_pattern_configs(
+            _get_tensor_info_op_configs(tensor_info_op_dtype_configs)
+        )
+        .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs))
+        .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs))
+        .set_backend_pattern_configs(
+            _get_embedding_op_configs(embedding_op_dtype_configs)
+        )
+    )
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/native.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/native.py
new file mode 100644
index 0000000000000000000000000000000000000000..a98d1a9a3d41b43b1c0ce55a2471d3342af71a55
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/native.py
@@ -0,0 +1,231 @@
+# mypy: allow-untyped-defs
+import torch
+
+from ._common_operator_config_utils import (
+    _get_binary_op_configs,
+    _get_bn_configs,
+    _get_cat_config,
+    _get_conv_configs,
+    _get_default_op_configs,
+    _get_embedding_op_configs,
+    _get_fixed_qparams_op_configs,
+    _get_linear_configs,
+    _get_ln_configs,
+    _get_rnn_op_configs,
+    _get_share_qparams_op_configs,
+    _get_tensor_info_op_configs,
+)
+from .backend_config import BackendConfig, DTypeConfig
+
+
+__all__ = [
+    "get_test_only_legacy_native_backend_config",
+    "default_op_quint8_dtype_config",
+    "default_op_fp16_dtype_config",
+    "default_dynamic_int8_dtype_config",
+    "default_dynamic_float16_dtype_config",
+    "input_output_only_quint8_dtype_config",
+    "weight_only_quint8_dtype_config",
+    "weight_only_quint4x2_dtype_config",
+    "get_native_backend_config",
+    "get_native_backend_config_dict",
+    "get_test_only_legacy_native_backend_config_dict",
+]
+
+# ===================
+# |  DTYPE CONFIGS  |
+# ===================
+
+# weighted op int8 dtype config
+# this is the config for ops that have quantized weights, such as linear and conv
+weighted_op_quint8_dtype_config = DTypeConfig(
+    input_dtype=torch.quint8,
+    output_dtype=torch.quint8,
+    weight_dtype=torch.qint8,
+    bias_dtype=torch.float,
+)
+
+default_op_quint8_dtype_config = DTypeConfig(
+    input_dtype=torch.quint8,
+    output_dtype=torch.quint8,
+)
+
+default_op_fp16_dtype_config = DTypeConfig(
+    input_dtype=torch.float16,
+    output_dtype=torch.float16,
+    weight_dtype=torch.float16,
+    bias_dtype=torch.float16,
+)
+
+default_dynamic_int8_dtype_config = DTypeConfig(
+    input_dtype=torch.quint8,
+    output_dtype=torch.float,
+    weight_dtype=torch.qint8,
+    bias_dtype=torch.float,
+    # The dtype check is not yet enabled, so this dtype config is provided but not
+    # really used yet; the check will be enabled once everything has moved to
+    # backend_config_dict.
+    is_dynamic=True,
+)
+
+default_dynamic_float16_dtype_config = DTypeConfig(
+    input_dtype=torch.float16,
+    output_dtype=torch.float,
+    weight_dtype=torch.float16,
+    bias_dtype=torch.float,
+    # The dtype check is not yet enabled, so this dtype config is provided but not
+    # really used yet; the check will be enabled once everything has moved to
+    # backend_config_dict.
+    is_dynamic=True,
+)
+
+# Needed for LayerNorm and f.layer_norm, since currently the kernel only supports float weights
+input_output_only_quint8_dtype_config = DTypeConfig(
+    input_dtype=torch.quint8,
+    output_dtype=torch.quint8,
+    weight_dtype=torch.float,
+    bias_dtype=torch.float,
+)
+
+weight_only_quint8_dtype_config = DTypeConfig(
+    input_dtype=torch.float,
+    output_dtype=torch.float,
+    weight_dtype=torch.quint8,
+)
+
+weight_only_quint4x2_dtype_config = DTypeConfig(
+    input_dtype=torch.float,
+    output_dtype=torch.float,
+    weight_dtype=torch.quint4x2,
+)
+
+
+# =====================
+# |  BACKEND CONFIGS  |
+# =====================
+
+
+def get_test_only_legacy_native_backend_config() -> BackendConfig:
+    """
+    Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack) with various additional fp16 ops.
+    """
+    conv_dtype_configs = [weighted_op_quint8_dtype_config]
+    linear_dtype_configs = [
+        weighted_op_quint8_dtype_config,
+        default_dynamic_int8_dtype_config,
+        default_dynamic_float16_dtype_config,
+        default_op_fp16_dtype_config,
+    ]
+    binary_op_dtype_configs = [
+        default_op_quint8_dtype_config,
+        default_op_fp16_dtype_config,
+    ]
+    default_op_dtype_configs = [default_op_quint8_dtype_config]
+    fixed_qparams_op_dtype_configs = [
+        default_op_quint8_dtype_config,
+        default_op_fp16_dtype_config,
+    ]
+    share_qparams_op_dtype_configs = [
+        default_op_quint8_dtype_config,
+        default_op_fp16_dtype_config,
+    ]
+    tensor_info_op_dtype_configs = [
+        default_op_quint8_dtype_config,
+    ]
+    rnn_op_dtype_configs = [
+        default_dynamic_int8_dtype_config,
+        default_dynamic_float16_dtype_config,
+    ]
+    embedding_op_dtype_configs = [
+        weight_only_quint8_dtype_config,
+        weight_only_quint4x2_dtype_config,
+    ]
+    layer_norm_op_dtype_configs = [input_output_only_quint8_dtype_config]
+    return (
+        BackendConfig("_native_and_fp16")
+        .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs))
+        .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs))
+        .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs))
+        .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs))
+        .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs))
+        .set_backend_pattern_configs(
+            _get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs)
+        )
+        .set_backend_pattern_configs(
+            _get_share_qparams_op_configs(share_qparams_op_dtype_configs)
+        )
+        .set_backend_pattern_configs(
+            _get_tensor_info_op_configs(tensor_info_op_dtype_configs)
+        )
+        .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs))
+        .set_backend_pattern_configs(_get_ln_configs(layer_norm_op_dtype_configs))
+        .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs))
+        .set_backend_pattern_configs(
+            _get_embedding_op_configs(embedding_op_dtype_configs)
+        )
+    )
+
+
+def get_native_backend_config() -> BackendConfig:
+    """
+    Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack).
+    """
+    # TODO: express this BackendConfig as a union of the FBGEMM and QNNPACK BackendConfigs
+    conv_dtype_configs = [weighted_op_quint8_dtype_config]
+    linear_dtype_configs = [
+        weighted_op_quint8_dtype_config,
+        default_dynamic_int8_dtype_config,
+        default_dynamic_float16_dtype_config,
+    ]
+    binary_op_dtype_configs = [default_op_quint8_dtype_config]
+    default_op_dtype_configs = [default_op_quint8_dtype_config]
+    fixed_qparams_op_dtype_configs = [default_op_quint8_dtype_config]
+    share_qparams_op_dtype_configs = [default_op_quint8_dtype_config]
+    tensor_info_op_dtype_configs = [default_op_quint8_dtype_config]
+    rnn_op_dtype_configs = [
+        default_dynamic_int8_dtype_config,
+        default_dynamic_float16_dtype_config,
+    ]
+    embedding_op_dtype_configs = [
+        weight_only_quint8_dtype_config,
+        weight_only_quint4x2_dtype_config,
+    ]
+    layer_norm_op_dtype_configs = [input_output_only_quint8_dtype_config]
+    return (
+        BackendConfig("native")
+        .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs))
+        .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs))
+        .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs))
+        .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs))
+        .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs))
+        .set_backend_pattern_configs(
+            _get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs)
+        )
+        .set_backend_pattern_configs(
+            _get_share_qparams_op_configs(share_qparams_op_dtype_configs)
+        )
+        .set_backend_pattern_configs(
+            _get_tensor_info_op_configs(tensor_info_op_dtype_configs)
+        )
+        .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs))
+        .set_backend_pattern_configs(_get_ln_configs(layer_norm_op_dtype_configs))
+        .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs))
+        .set_backend_pattern_configs(
+            _get_embedding_op_configs(embedding_op_dtype_configs)
+        )
+    )
+
+
+def get_native_backend_config_dict():
+    """
+    Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack) in dictionary form.
+    """
+    return get_native_backend_config().to_dict()
+
+
+def get_test_only_legacy_native_backend_config_dict():
+    """
+    Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack) with various additional
+    fp16 ops in dictionary form.
+    """
+    return get_test_only_legacy_native_backend_config().to_dict()
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/observation_type.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/observation_type.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/onednn.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/onednn.py
new file mode 100644
index 0000000000000000000000000000000000000000..348cec62ea18a80d7153b564ef3f0343fc4d17eb
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/onednn.py
@@ -0,0 +1,640 @@
+# mypy: allow-untyped-defs
+import itertools
+import operator
+
+import torch
+import torch.ao.nn.intrinsic as nni
+import torch.ao.nn.quantized.reference as nnqr
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.ao.quantization.fuser_method_mappings import _sequential_wrapper2
+from torch.ao.quantization.utils import MatchAllNode
+
+from ._common_operator_config_utils import (
+    _get_binary_op_configs,
+    _get_bn_configs,
+    _get_cat_config,
+    _get_conv_configs,
+    _get_default_op_configs,
+    _get_embedding_op_configs,
+    _get_fixed_qparams_op_configs,
+    _get_linear_configs,
+    _get_ln_configs,
+    _get_rnn_op_configs,
+    _get_share_qparams_op_configs,
+)
+from .backend_config import (
+    BackendConfig,
+    BackendPatternConfig,
+    DTypeConfig,
+    ObservationType,
+)
+
+
+# ===================
+# |  DTYPE CONFIGS  |
+# ===================
+
+onednn_weighted_op_int8_dtype_config = DTypeConfig(
+    input_dtype=torch.quint8,
+    output_dtype=torch.quint8,
+    weight_dtype=torch.qint8,
+    bias_dtype=torch.float,
+)
+
+onednn_op_quint8_dtype_config = DTypeConfig(
+    input_dtype=torch.quint8,
+    output_dtype=torch.quint8,
+)
+
+onednn_dynamic_int8_dtype_config = DTypeConfig(
+    input_dtype=torch.quint8,
+    output_dtype=torch.float,
+    weight_dtype=torch.qint8,
+    bias_dtype=torch.float,
+    is_dynamic=True,
+)
+
+onednn_weight_only_qint8_dtype_config = DTypeConfig(
+    input_dtype=torch.float,
+    output_dtype=torch.float,
+    weight_dtype=torch.qint8,
+)
+
+onednn_input_output_only_quint8_dtype_config = DTypeConfig(
+    input_dtype=torch.quint8,
+    output_dtype=torch.quint8,
+    weight_dtype=torch.float,
+    bias_dtype=torch.float,
+)
+
+# ===================
+# |  FUSER METHODS  |
+# ===================
+
+
+def _fuse_linear_bn_leaky_relu(is_qat, linear, bn, leaky_relu):
+    r"""Given the linear, bn and leaky_relu modules, fuses them and returns the fused module
+    Args:
+        is_qat: a flag for whether we are using quantization aware training fusion
+                or post training quantization fusion
+        linear: Module instance of type Linear
+        bn: BatchNorm1d instance that needs to be fused with the linear layer
+        leaky_relu: LeakyReLU instance that needs to be fused with the linear layer
+    Examples::
+        >>> # xdoctest: +SKIP(failing)
+        >>> m1 = nn.Linear(20, 10)
+        >>> b1 = nn.BatchNorm1d(10)
+        >>> lr = nn.LeakyReLU(0.01)
+        >>> m2 = _fuse_linear_bn_leaky_relu(False, m1, b1, lr)
+    """
+    assert linear.training == bn.training and bn.training == leaky_relu.training, (
+        "Linear, BN and LeakyReLU all must be in the same mode (train or eval)."
+    )
+
+    if is_qat:
+        raise NotImplementedError(
+            f"Cannot fuse train modules: {(linear, bn, leaky_relu)}"
+        )
+    else:
+        map_to_fused_module_eval = {
+            nn.Linear: nni.LinearLeakyReLU,
+        }
+        fused_module = map_to_fused_module_eval.get(type(linear), None)
+        if fused_module is not None:
+            fused_linear = nn.utils.fusion.fuse_linear_bn_eval(linear, bn)
+            fm = fused_module(fused_linear, leaky_relu)
+            return fm
+        else:
+            raise NotImplementedError(
+                f"Cannot fuse eval modules: {(linear, bn, leaky_relu)}"
+            )
+
+
+# ======================
+# |  CONFIGS FOR CONV  |
+# ======================
+observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
+
+conv_dtype_configs = [onednn_weighted_op_int8_dtype_config]
+conv_configs = _get_conv_configs(conv_dtype_configs)
+
+# (1) Conv2d + Add
+
+# conv2d   Y
+#   \   /
+#    add
+
+# include:
+# conv2d conv2d
+#   \   /
+#    add
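+#
+# In the reversed nested tuple format used below, this corresponds to patterns such
+# as (add_op, nn.Conv2d, MatchAllNode): the add comes first, its first input is the
+# conv (the root node), and MatchAllNode stands in for the arbitrary extra input Y.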
+
+
+def _fuse_conv_add_left(is_qat, add, conv, _):
+    return nni.ConvAdd2d(conv, add)
+
+
+def _conv_add_root_node_getter_left(pattern):
+    _, conv, _ = pattern
+    return conv
+
+
+def _conv_add_extra_inputs_getter_left(pattern):
+    """get inputs pattern for extra inputs, inputs for root node
+    are assumed to be copied over from root node to the fused node
+    """
+    _, _conv, extra_input = pattern
+    return [extra_input]
+
+
+# conv2d
+#  \
+#  bn   Y
+#   \   /
+#    add
+
+
+def _fuse_conv_bn_add_left(is_qat, add, bn_conv, _):
+    bn, conv = bn_conv
+    if is_qat:
+        raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn, add)}")
+    else:
+        fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
+        return nni.ConvAdd2d(fused_conv, add)
+
+
+def _conv_bn_add_root_node_getter_left(add_pattern):
+    _, bn_conv, _ = add_pattern
+    _bn, conv = bn_conv
+    return conv
+
+
+def _conv_bn_add_extra_inputs_getter_left(add_pattern):
+    """get inputs pattern for extra inputs, inputs for root node
+    are assumed to be copied over from root node to the fused node
+    """
+    _, _bn_conv, extra_input = add_pattern
+    return [extra_input]
+
+
+conv_add_left_options = itertools.product(
+    [True, False],  # with_bn
+    [torch.add, operator.add],  # add_op
+)
+
+for with_bn, add_op in conv_add_left_options:
+    if with_bn:
+        conv_configs.append(
+            BackendPatternConfig()
+            ._set_pattern_complex_format(
+                (add_op, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode)
+            )  # noqa: E131
+            .set_observation_type(observation_type)
+            .set_dtype_configs(conv_dtype_configs)
+            .set_fuser_method(_fuse_conv_bn_add_left)
+            ._set_root_node_getter(_conv_bn_add_root_node_getter_left)
+            ._set_extra_inputs_getter(_conv_bn_add_extra_inputs_getter_left)
+            .set_fused_module(nni.ConvAdd2d)
+        )
+    else:
+        conv_configs.append(
+            BackendPatternConfig()
+            ._set_pattern_complex_format((add_op, nn.Conv2d, MatchAllNode))  # noqa: E131
+            .set_observation_type(observation_type)
+            .set_dtype_configs(conv_dtype_configs)
+            .set_fuser_method(_fuse_conv_add_left)
+            ._set_root_node_getter(_conv_add_root_node_getter_left)
+            ._set_extra_inputs_getter(_conv_add_extra_inputs_getter_left)
+            .set_fused_module(nni.ConvAdd2d)
+        )
+
+#  Y   conv2d
+#   \   /
+#    add
+
+
+def _fuse_conv_add_right(is_qat, add, _, conv):
+    return nni.ConvAdd2d(conv, add)
+
+
+def _conv_add_root_node_getter_right(pattern):
+    _add, _, conv = pattern
+    return conv
+
+
+def _conv_add_extra_inputs_getter_right(pattern):
+    """get inputs pattern for extra inputs, inputs for root node
+    are assumed to be copied over from root node to the fused node
+    """
+    _, extra_input, _conv = pattern
+    return [extra_input]
+
+
+#      conv2d
+#        /
+#  Y    bn
+#   \   /
+#    add
+
+
+def _fuse_conv_bn_add_right(is_qat, add, _, bn_conv):
+    bn, conv = bn_conv
+    if is_qat:
+        raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn, add)}")
+    else:
+        fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
+        return nni.ConvAdd2d(fused_conv, add)
+
+
+def _conv_bn_add_root_node_getter_right(pattern):
+    _add, _, bn_conv = pattern
+    _bn, conv = bn_conv
+    return conv
+
+
+def _conv_bn_add_extra_inputs_getter_right(pattern):
+    """get inputs pattern for extra inputs, inputs for root node
+    are assumed to be copied over from root node to the fused node
+    """
+    _, extra_input, _bn_conv = pattern
+    return [extra_input]
+
+
+conv_add_right_options = itertools.product(
+    [True, False],  # with_bn
+    [torch.add, operator.add],  # add_op
+)
+
+for with_bn, add_op in conv_add_right_options:
+    if with_bn:
+        conv_configs.append(
+            BackendPatternConfig()
+            ._set_pattern_complex_format(
+                (add_op, MatchAllNode, (nn.BatchNorm2d, nn.Conv2d))
+            )  # noqa: E131
+            .set_observation_type(observation_type)
+            .set_dtype_configs(conv_dtype_configs)
+            .set_fuser_method(_fuse_conv_bn_add_right)
+            ._set_root_node_getter(_conv_bn_add_root_node_getter_right)
+            ._set_extra_inputs_getter(_conv_bn_add_extra_inputs_getter_right)
+            .set_fused_module(nni.ConvAdd2d)
+        )
+    else:
+        conv_configs.append(
+            BackendPatternConfig()
+            ._set_pattern_complex_format((add_op, MatchAllNode, nn.Conv2d))  # noqa: E131
+            .set_observation_type(observation_type)
+            .set_dtype_configs(conv_dtype_configs)
+            .set_fuser_method(_fuse_conv_add_right)
+            ._set_root_node_getter(_conv_add_root_node_getter_right)
+            ._set_extra_inputs_getter(_conv_add_extra_inputs_getter_right)
+            .set_fused_module(nni.ConvAdd2d)
+        )
+
+conv_configs.append(
+    BackendPatternConfig(nni.ConvAdd2d)
+    .set_observation_type(observation_type)  # noqa: E131
+    .set_dtype_configs(conv_dtype_configs)
+    .set_root_module(nn.Conv2d)
+    .set_reference_quantized_module(nnqr.Conv2d)
+)
+
+# (2) Conv2d + Add + Relu
+
+# conv2d Y
+#   \   /
+#    add
+#     \
+#     relu
+
+
+def _fuse_conv_add_relu_left(is_qat, relu, add_pattern):
+    add, conv, _ = add_pattern
+    return nni.ConvAddReLU2d(conv, add, relu)
+
+
+def _conv_add_relu_root_node_getter_left(pattern):
+    _relu, add_pattern = pattern
+    _, conv, _ = add_pattern
+    return conv
+
+
+def _conv_add_relu_extra_inputs_getter_left(pattern):
+    """get inputs pattern for extra inputs, inputs for root node
+    are assumed to be copied over from root node to the fused node
+    """
+    _relu, add_pattern = pattern
+    _, _conv, extra_input = add_pattern
+    return [extra_input]
+
+
+# conv2d
+#  \
+#  bn   Y
+#   \   /
+#    add
+#     \
+#     relu
+
+
+def _fuse_conv_bn_add_relu_left(is_qat, relu, add_pattern):
+    add, bn_conv, _ = add_pattern
+    bn, conv = bn_conv
+    if is_qat:
+        raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn, add, relu)}")
+    else:
+        fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
+        return nni.ConvAddReLU2d(fused_conv, add, relu)
+
+
+def _conv_bn_add_relu_root_node_getter_left(pattern):
+    _relu, add_pattern = pattern
+    _, bn_conv, _ = add_pattern
+    _bn, conv = bn_conv
+    return conv
+
+
+def _conv_bn_add_relu_extra_inputs_getter_left(pattern):
+    """get inputs pattern for extra inputs, inputs for root node
+    are assumed to be copied over from root node to the fused node
+    """
+    _relu, add_pattern = pattern
+    _, _bn_conv, extra_input = add_pattern
+    return [extra_input]
+
+
+conv_add_relu_left_options = itertools.product(
+    [True, False],  # with_bn
+    [torch.add, operator.add],  # add_op
+)
+
+for with_bn, add_op in conv_add_relu_left_options:
+    if with_bn:
+        conv_configs.append(
+            BackendPatternConfig()
+            ._set_pattern_complex_format(
+                (nn.ReLU, (add_op, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode))
+            )  # noqa: E131
+            .set_observation_type(observation_type)
+            .set_dtype_configs(conv_dtype_configs)
+            .set_fuser_method(_fuse_conv_bn_add_relu_left)
+            ._set_root_node_getter(_conv_bn_add_relu_root_node_getter_left)
+            ._set_extra_inputs_getter(_conv_bn_add_relu_extra_inputs_getter_left)
+            .set_fused_module(nni.ConvAddReLU2d)
+        )
+    else:
+        conv_configs.append(
+            BackendPatternConfig()
+            ._set_pattern_complex_format((nn.ReLU, (add_op, nn.Conv2d, MatchAllNode)))  # noqa: E131
+            .set_observation_type(observation_type)
+            .set_dtype_configs(conv_dtype_configs)
+            .set_fuser_method(_fuse_conv_add_relu_left)
+            ._set_root_node_getter(_conv_add_relu_root_node_getter_left)
+            ._set_extra_inputs_getter(_conv_add_relu_extra_inputs_getter_left)
+            .set_fused_module(nni.ConvAddReLU2d)
+        )
+
+#  Y   conv2d
+#   \   /
+#    add
+#     \
+#     relu
+
+
+def _fuse_conv_add_relu_right(is_qat, relu, add_pattern):
+    add, _, conv = add_pattern
+    return nni.ConvAddReLU2d(conv, add, relu)
+
+
+def _conv_add_relu_root_node_getter_right(pattern):
+    _relu, add_pattern = pattern
+    _, _extra_input, conv = add_pattern
+    return conv
+
+
+def _conv_add_relu_extra_inputs_getter_right(pattern):
+    """get inputs pattern for extra inputs, inputs for root node
+    are assumed to be copied over from root node to the fused node
+    """
+    _relu, add_pattern = pattern
+    _, extra_input, _conv = add_pattern
+    return [extra_input]
+
+
+#      conv2d
+#        /
+#  Y    bn
+#   \   /
+#    add
+#     \
+#     relu
+
+
+def _fuse_conv_bn_add_relu_right(is_qat, relu, add_pattern):
+    add, _, bn_conv = add_pattern
+    bn, conv = bn_conv
+    if is_qat:
+        raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn, add, relu)}")
+    else:
+        fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
+        return nni.ConvAddReLU2d(fused_conv, add, relu)
+
+
+def _conv_bn_add_relu_root_node_getter_right(pattern):
+    _relu, add_pattern = pattern
+    _, _, bn_conv = add_pattern
+    _bn, conv = bn_conv
+    return conv
+
+
+def _conv_bn_add_relu_extra_inputs_getter_right(pattern):
+    """get inputs pattern for extra inputs, inputs for root node
+    are assumed to be copied over from root node to the fused node
+    """
+    _relu, add_pattern = pattern
+    _, extra_input, _bn_conv = add_pattern
+    return [extra_input]
+
+
+conv_add_relu_right_options = itertools.product(
+    [True, False],  # with_bn
+    [torch.add, operator.add],  # add_op
+)
+
+for with_bn, add_op in conv_add_relu_right_options:
+    if with_bn:
+        conv_configs.append(
+            BackendPatternConfig()
+            ._set_pattern_complex_format(
+                (nn.ReLU, (add_op, MatchAllNode, (nn.BatchNorm2d, nn.Conv2d)))
+            )  # noqa: E131
+            .set_observation_type(observation_type)
+            .set_dtype_configs(conv_dtype_configs)
+            .set_fuser_method(_fuse_conv_bn_add_relu_right)
+            ._set_root_node_getter(_conv_bn_add_relu_root_node_getter_right)
+            ._set_extra_inputs_getter(_conv_bn_add_relu_extra_inputs_getter_right)
+            .set_fused_module(nni.ConvAddReLU2d)
+        )
+    else:
+        conv_configs.append(
+            BackendPatternConfig()
+            ._set_pattern_complex_format((nn.ReLU, (add_op, MatchAllNode, nn.Conv2d)))  # noqa: E131
+            .set_observation_type(observation_type)
+            .set_dtype_configs(conv_dtype_configs)
+            .set_fuser_method(_fuse_conv_add_relu_right)
+            ._set_root_node_getter(_conv_add_relu_root_node_getter_right)
+            ._set_extra_inputs_getter(_conv_add_relu_extra_inputs_getter_right)
+            .set_fused_module(nni.ConvAddReLU2d)
+        )
+
+conv_configs.append(
+    BackendPatternConfig(nni.ConvAddReLU2d)
+    .set_observation_type(observation_type)  # noqa: E131
+    .set_dtype_configs(conv_dtype_configs)
+    .set_root_module(nn.Conv2d)
+    .set_reference_quantized_module(nnqr.Conv2d)
+)
+
+# ========================
+# |  CONFIGS FOR LINEAR  |
+# ========================
+
+linear_dtype_configs = [
+    onednn_weighted_op_int8_dtype_config,
+    onednn_dynamic_int8_dtype_config,
+]
+linear_configs = _get_linear_configs(linear_dtype_configs)
+
+
+def _add_eltwise_fusion_configs(
+    configs,
+    root_module,
+    root_op,
+    post_module,
+    post_op,
+    dtype_configs,
+    fuser_method,
+    fused_module,
+    observation_type,
+    ref_quant_module,
+):
+    # 1 base module + op module fusion config
+    configs.append(
+        BackendPatternConfig((root_module, post_module))
+        .set_dtype_configs(dtype_configs)  # noqa: E131
+        .set_fuser_method(fuser_method)
+        .set_fused_module(fused_module)
+    )
+    # base module + functional post op
+    configs.append(
+        BackendPatternConfig((root_module, post_op))
+        .set_dtype_configs(dtype_configs)  # noqa: E131
+        .set_fuser_method(fuser_method)
+        .set_fused_module(fused_module)
+    )
+
+    # 2 fused module configs
+    configs.append(
+        BackendPatternConfig(fused_module)
+        .set_observation_type(observation_type)  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+        .set_root_module(root_module)
+        .set_reference_quantized_module(ref_quant_module)
+    )
+
+    # 3 functional base op + post op configs
+    configs.append(
+        BackendPatternConfig((root_op, post_module))
+        .set_observation_type(observation_type)  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+    )
+    configs.append(
+        BackendPatternConfig((root_op, post_op))
+        .set_observation_type(observation_type)  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+    )
+
+
+# Configs for linear + leaky_relu fusion
+_add_eltwise_fusion_configs(
+    linear_configs,
+    nn.Linear,
+    F.linear,
+    nn.LeakyReLU,
+    F.leaky_relu,
+    linear_dtype_configs,
+    _sequential_wrapper2(nni.LinearLeakyReLU),
+    nni.LinearLeakyReLU,
+    observation_type,
+    nnqr.Linear,
+)
+
+# Configs for linear module + batchnorm + leaky_relu
+linear_configs.append(
+    BackendPatternConfig((nn.Linear, nn.BatchNorm1d, nn.LeakyReLU))
+    .set_dtype_configs(linear_dtype_configs)  # noqa: E131
+    .set_fuser_method(_fuse_linear_bn_leaky_relu)
+    .set_fused_module(nni.LinearLeakyReLU)
+)
+
+# Configs for linear + tanh fusion
+_add_eltwise_fusion_configs(
+    linear_configs,
+    nn.Linear,
+    F.linear,
+    nn.Tanh,
+    torch.tanh,
+    linear_dtype_configs,
+    _sequential_wrapper2(nni.LinearTanh),
+    nni.LinearTanh,
+    observation_type,
+    nnqr.Linear,
+)
+
+# ===========================
+# |  CONFIGS FOR OTHER OPS  |
+# ===========================
+
+binary_op_dtype_configs = [onednn_op_quint8_dtype_config]
+default_op_dtype_configs = [onednn_op_quint8_dtype_config]
+fixed_qparams_op_dtype_configs = [onednn_op_quint8_dtype_config]
+share_qparams_op_dtype_configs = [onednn_op_quint8_dtype_config]
+rnn_op_dtype_configs = [onednn_dynamic_int8_dtype_config]
+embedding_op_dtype_configs = [onednn_weight_only_qint8_dtype_config]
+layer_norm_op_dtype_configs = [onednn_input_output_only_quint8_dtype_config]
+
+# =====================
+# |  BACKEND CONFIGS  |
+# =====================
+
+
+def get_onednn_backend_config() -> BackendConfig:
+    """
+    Return the `BackendConfig` for PyTorch's native ONEDNN backend.
+    """
+    return (
+        BackendConfig("onednn")
+        .set_backend_pattern_configs(conv_configs)
+        .set_backend_pattern_configs(linear_configs)
+        .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs))
+        .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs))
+        .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs))
+        .set_backend_pattern_configs(
+            _get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs)
+        )
+        .set_backend_pattern_configs(
+            _get_share_qparams_op_configs(share_qparams_op_dtype_configs)
+        )
+        .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs))
+        .set_backend_pattern_configs(_get_ln_configs(layer_norm_op_dtype_configs))
+        .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs))
+        .set_backend_pattern_configs(
+            _get_embedding_op_configs(embedding_op_dtype_configs)
+        )
+    )
+
+
+__all__ = [
+    "get_onednn_backend_config",
+]
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/qnnpack.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/qnnpack.py
new file mode 100644
index 0000000000000000000000000000000000000000..841bac512a6549f39f757b9531591f1e47e72a83
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/qnnpack.py
@@ -0,0 +1,171 @@
+import torch
+
+from ._common_operator_config_utils import (
+    _get_binary_op_configs,
+    _get_bn_configs,
+    _get_cat_config,
+    _get_conv_configs,
+    _get_default_op_configs,
+    _get_embedding_op_configs,
+    _get_fixed_qparams_op_configs,
+    _get_linear_configs,
+    _get_rnn_op_configs,
+    _get_share_qparams_op_configs,
+)
+from .backend_config import BackendConfig, DTypeConfig, DTypeWithConstraints
+
+
+__all__ = [
+    "get_qnnpack_backend_config",
+]
+
+# ===================
+# |  DTYPE CONFIGS  |
+# ===================
+
+qnnpack_weighted_op_quint8_dtype_config = DTypeConfig(
+    input_dtype=torch.quint8,
+    output_dtype=torch.quint8,
+    weight_dtype=torch.qint8,
+    bias_dtype=torch.float,
+)
+
+qnnpack_default_op_quint8_dtype_config = DTypeConfig(
+    input_dtype=torch.quint8,
+    output_dtype=torch.quint8,
+)
+
+qnnpack_default_op_fp16_dtype_config = DTypeConfig(
+    input_dtype=torch.float16,
+    output_dtype=torch.float16,
+    weight_dtype=torch.float16,
+    bias_dtype=torch.float16,
+)
+
+qnnpack_default_dynamic_int8_dtype_config = DTypeConfig(
+    input_dtype=torch.quint8,
+    output_dtype=torch.float,
+    weight_dtype=torch.qint8,
+    bias_dtype=torch.float,
+    is_dynamic=True,
+)
+
+qnnpack_default_dynamic_float16_dtype_config = DTypeConfig(
+    input_dtype=torch.float16,
+    output_dtype=torch.float,
+    weight_dtype=torch.float16,
+    bias_dtype=torch.float,
+    is_dynamic=True,
+)
+
+qnnpack_weight_only_quint8_dtype_config = DTypeConfig(
+    input_dtype=torch.float,
+    output_dtype=torch.float,
+    weight_dtype=torch.quint8,
+)
+
+qnnpack_weight_only_quint4x2_dtype_config = DTypeConfig(
+    input_dtype=torch.float,
+    output_dtype=torch.float,
+    weight_dtype=torch.quint4x2,
+)
+
+# xnnpack compatible dtype configs
+
+# We restrict scale values to be 2 ** -12 to ensure the
+# requantization scale never falls below the xnnpack lower
+# threshold. Additionally, for qint8 weight, we restrict
+# the quantization values to [-127, +127], excluding -128.
+# For more detail, refer to the description of
+# `default_symmetric_qnnpack_qconfig`.
+
+# TODO: add additional restriction on qscheme to ensure it
+# is either per_tensor_symmetric or per_channel_symmetric
+
+qnnpack_act_qint8_scale_min_2_neg_12 = DTypeWithConstraints(
+    dtype=torch.qint8,
+    scale_min_lower_bound=2**-12,
+)
+
+qnnpack_weight_qint8_neg_127_to_127_scale_min_2_neg_12 = DTypeWithConstraints(
+    dtype=torch.qint8,
+    quant_min_lower_bound=-127,
+    quant_max_upper_bound=127,
+    scale_min_lower_bound=2**-12,
+)
+
+qnnpack_weighted_op_qint8_symmetric_dtype_config = DTypeConfig(
+    input_dtype=qnnpack_act_qint8_scale_min_2_neg_12,
+    output_dtype=qnnpack_act_qint8_scale_min_2_neg_12,
+    weight_dtype=qnnpack_weight_qint8_neg_127_to_127_scale_min_2_neg_12,
+    bias_dtype=torch.float,
+)
+
+qnnpack_default_op_qint8_symmetric_dtype_config = DTypeConfig(
+    input_dtype=qnnpack_act_qint8_scale_min_2_neg_12,
+    output_dtype=qnnpack_act_qint8_scale_min_2_neg_12,
+)
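+
+# A minimal numeric sketch of the constraint above (hypothetical values): the
+# requantization scale applied by the kernels is roughly
+#     requant_scale = act_scale * weight_scale / output_scale
+# so bounding act_scale and weight_scale from below (here by 2 ** -12) also bounds
+# the numerator from below, e.g. 2 ** -12 * 2 ** -12 = 2 ** -24.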
+
+
+# =====================
+# |  BACKEND CONFIGS  |
+# =====================
+
+
+def get_qnnpack_backend_config() -> BackendConfig:
+    """
+    Return the `BackendConfig` for PyTorch's native QNNPACK backend.
+    """
+    conv_dtype_configs = [
+        qnnpack_weighted_op_qint8_symmetric_dtype_config,
+        qnnpack_weighted_op_quint8_dtype_config,
+    ]
+    linear_dtype_configs = [
+        qnnpack_weighted_op_qint8_symmetric_dtype_config,
+        qnnpack_weighted_op_quint8_dtype_config,
+        qnnpack_default_dynamic_int8_dtype_config,
+        qnnpack_default_dynamic_float16_dtype_config,
+    ]
+    binary_op_dtype_configs = [
+        qnnpack_default_op_qint8_symmetric_dtype_config,
+        qnnpack_default_op_quint8_dtype_config,
+    ]
+    default_op_dtype_configs = [
+        qnnpack_default_op_qint8_symmetric_dtype_config,
+        qnnpack_default_op_quint8_dtype_config,
+    ]
+    fixed_qparams_op_dtype_configs = [
+        qnnpack_default_op_qint8_symmetric_dtype_config,
+        qnnpack_default_op_quint8_dtype_config,
+    ]
+    share_qparams_op_dtype_configs = [
+        qnnpack_default_op_qint8_symmetric_dtype_config,
+        qnnpack_default_op_quint8_dtype_config,
+    ]
+    rnn_op_dtype_configs = [
+        qnnpack_default_dynamic_int8_dtype_config,
+        qnnpack_default_dynamic_float16_dtype_config,
+    ]
+    embedding_op_dtype_configs = [
+        qnnpack_weight_only_quint8_dtype_config,
+        qnnpack_weight_only_quint4x2_dtype_config,
+    ]
+    return (
+        BackendConfig("qnnpack")
+        .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs))
+        .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs))
+        .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs))
+        .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs))
+        .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs))
+        .set_backend_pattern_configs(
+            _get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs)
+        )
+        .set_backend_pattern_configs(
+            _get_share_qparams_op_configs(share_qparams_op_dtype_configs)
+        )
+        .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs))
+        .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs))
+        .set_backend_pattern_configs(
+            _get_embedding_op_configs(embedding_op_dtype_configs)
+        )
+    )
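+
+
+# Note (sketch): running a model converted under this config typically also means
+# selecting the matching runtime engine before inference, e.g.
+#
+#     torch.backends.quantized.engine = "qnnpack"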
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/tensorrt.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/tensorrt.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0490e2071f4f2df59b4bb6eb2a1d7885b4aa036
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/tensorrt.py
@@ -0,0 +1,98 @@
+# mypy: allow-untyped-defs
+import torch
+
+from ._common_operator_config_utils import (
+    _get_binary_op_configs,
+    _get_conv_configs,
+    _get_linear_configs,
+    _get_share_qparams_op_configs,
+    _get_tensor_info_op_configs,
+)
+from .backend_config import (
+    BackendConfig,
+    BackendPatternConfig,
+    DTypeConfig,
+    ObservationType,
+)
+
+
+__all__ = [
+    "get_tensorrt_backend_config",
+    "get_tensorrt_backend_config_dict",
+]
+
+
+def get_tensorrt_backend_config() -> BackendConfig:
+    """
+    Return the `BackendConfig` for the TensorRT backend.
+    NOTE: The current API will change in the future; it exists only to unblock
+    experimentation with new backends, so please don't rely on it for now.
+    TODO: add a README when it's more stable
+    """
+    # dtype configs
+    weighted_op_qint8_dtype_config = DTypeConfig(
+        input_dtype=torch.qint8,
+        output_dtype=torch.qint8,
+        weight_dtype=torch.qint8,
+        bias_dtype=torch.float,
+    )
+    non_weighted_op_qint8_dtype_config = DTypeConfig(
+        input_dtype=torch.qint8,
+        output_dtype=torch.qint8,
+    )
+
+    addmm_config = (
+        BackendPatternConfig(torch.addmm)
+        .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)
+        .add_dtype_config(weighted_op_qint8_dtype_config)
+        ._set_input_type_to_index(
+            {
+                "bias": 0,
+                "input": 1,
+                "weight": 2,
+            }
+        )
+    )
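+    # The index mapping above follows the torch.addmm signature: addmm(input, mat1, mat2)
+    # computes beta * input + alpha * (mat1 @ mat2), so positional argument 0 is the bias
+    # term, 1 is the activation input, and 2 is the weight.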
+    cat_config = (
+        BackendPatternConfig(torch.cat)
+        .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT)
+        .add_dtype_config(non_weighted_op_qint8_dtype_config)
+    )
+    conv_dtype_configs = [
+        weighted_op_qint8_dtype_config,
+    ]
+    linear_dtype_configs = [
+        weighted_op_qint8_dtype_config,
+    ]
+    binary_op_dtype_configs = [
+        weighted_op_qint8_dtype_config,
+    ]
+    share_qparams_op_dtype_configs = [
+        non_weighted_op_qint8_dtype_config,
+    ]
+    tensor_info_op_dtype_configs = [
+        non_weighted_op_qint8_dtype_config,
+    ]
+    # Some patterns here may not be supported by fx2trt; those will error out
+    # during fx2trt conversion and can be supported after that.
+    return (
+        BackendConfig("tensorrt")
+        .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs))
+        .set_backend_pattern_config(addmm_config)
+        .set_backend_pattern_config(cat_config)
+        .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs))
+        .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs))
+        .set_backend_pattern_configs(
+            _get_share_qparams_op_configs(share_qparams_op_dtype_configs)
+        )
+        .set_backend_pattern_configs(
+            _get_tensor_info_op_configs(tensor_info_op_dtype_configs)
+        )
+    )
+
+
+def get_tensorrt_backend_config_dict():
+    """
+    Return the `BackendConfig` for the TensorRT backend in dictionary form.
+    """
+    return get_tensorrt_backend_config().to_dict()
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/utils.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..97dd6007c7fe04ae3e0959dc6cfb5da5602f1782
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/utils.py
@@ -0,0 +1,314 @@
+# mypy: allow-untyped-defs
+from typing import Any, Callable, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.ao.quantization.fuser_method_mappings import _reverse2, _reverse3
+from torch.ao.quantization.utils import Pattern
+
+from .backend_config import BackendConfig, BackendPatternConfig, DTypeConfig
+
+
+__all__ = [
+    "get_pattern_to_dtype_configs",
+    "get_qat_module_classes",
+    "get_fused_module_classes",
+    "get_pattern_to_input_type_to_index",
+    "get_root_module_to_quantized_reference_module",
+    "get_fuser_method_mapping",
+    "get_module_to_qat_module",
+    "get_fusion_pattern_to_root_node_getter",
+    "get_fusion_pattern_to_extra_inputs_getter",
+    "remove_boolean_dispatch_from_name",
+    "pattern_to_human_readable",
+    "entry_to_pretty_str",
+]
+
+
+def get_pattern_to_dtype_configs(
+    backend_config: BackendConfig,
+) -> dict[Pattern, list[DTypeConfig]]:
+    pattern_to_dtype_configs: dict[Pattern, list[DTypeConfig]] = {}
+    for pattern, config in backend_config._pattern_complex_format_to_config.items():
+        pattern_to_dtype_configs[pattern] = config.dtype_configs
+    return pattern_to_dtype_configs
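+
+
+# Illustrative sketch: the helpers in this module flatten a BackendConfig into plain
+# lookup tables, e.g. (the backend choice and key below are just examples):
+#
+#     backend_config = get_qnnpack_backend_config()
+#     dtype_table = get_pattern_to_dtype_configs(backend_config)
+#     linear_dtypes = dtype_table[torch.nn.Linear]  # list of DTypeConfig accepted for Linear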
+
+
+def get_qat_module_classes(backend_config: BackendConfig) -> tuple[type, ...]:
+    qat_module_classes = [
+        config.qat_module
+        for config in backend_config.configs
+        if config.qat_module is not None
+    ]
+    return tuple(set(qat_module_classes))
+
+
+def get_fused_module_classes(backend_config: BackendConfig) -> tuple[type, ...]:
+    fused_module_classes = [
+        config.fused_module
+        for config in backend_config.configs
+        if config.fused_module is not None
+    ]
+    return tuple(set(fused_module_classes))
+
+
+def get_pattern_to_input_type_to_index(
+    backend_config: BackendConfig,
+) -> dict[Pattern, dict[str, int]]:
+    pattern_to_input_type_to_index: dict[Pattern, dict[str, int]] = {}
+    for pattern, config in backend_config._pattern_complex_format_to_config.items():
+        pattern_to_input_type_to_index[pattern] = config._input_type_to_index
+    return pattern_to_input_type_to_index
+
+
+def get_root_module_to_quantized_reference_module(
+    backend_config: BackendConfig,
+) -> dict[type[torch.nn.Module], type[torch.nn.Module]]:
+    mapping: dict[type[torch.nn.Module], type[torch.nn.Module]] = {}
+    for config in backend_config.configs:
+        if (
+            config.root_module is not None
+            and config.reference_quantized_module is not None
+        ):
+            mapping[config.root_module] = config.reference_quantized_module
+    return mapping
+
+
+def get_fuser_method_mapping(
+    backend_config: BackendConfig,
+) -> dict[Pattern, Union[nn.Sequential, Callable]]:
+    fuser_method_mapping: dict[Pattern, Union[nn.Sequential, Callable]] = {}
+    for pattern, config in backend_config._pattern_complex_format_to_config.items():
+        if config.fuser_method is not None:
+            # Note: both the fuser method and the pattern are specified in forward order in the
+            # BackendConfig, but the internal pattern matching code uses the reversed nested tuple
+            # format, so we need to convert both to the internal format
+            fuser_method = _get_fuser_method_in_reversed_nested_tuple_format(config)
+            fuser_method_mapping[pattern] = fuser_method
+    return fuser_method_mapping
+
+
+def get_module_to_qat_module(
+    backend_config: BackendConfig,
+) -> dict[Pattern, type[torch.nn.Module]]:
+    module_to_qat_module: dict[Pattern, type[torch.nn.Module]] = {}
+    for pattern, config in backend_config._pattern_complex_format_to_config.items():
+        if config.qat_module is not None:
+            module_to_qat_module[pattern] = config.qat_module
+    return module_to_qat_module
+
+
+def get_fusion_pattern_to_root_node_getter(
+    backend_config: BackendConfig,
+) -> dict[Pattern, Callable]:
+    """Get a map from fusion pattern to a function that returns the root node
+    from the fusion pattern, e.g. the most common one is:
+    def get_root_node(node_pattern):
+        while not isinstance(node_pattern[-1], Node):
+            node_pattern = node_pattern[-1]
+        return node_pattern[-1]
+    This can work for all patterns whose root node is the "last node" in the pattern,
+    e.g. (torch.add, MatchAllNode, (torch.nn.ReLU, torch.nn.Conv2d))
+    """
+    root_node_getter_mapping: dict[Pattern, Callable] = {}
+    for pattern, config in backend_config._pattern_complex_format_to_config.items():
+        if config._root_node_getter is not None:
+            root_node_getter_mapping[pattern] = config._root_node_getter
+    return root_node_getter_mapping
+
+
+def get_fusion_pattern_to_extra_inputs_getter(
+    backend_config: BackendConfig,
+) -> dict[Pattern, Callable]:
+    """Get a map from fusion pattern to a function that returns extra input nodes
+    from the fusion pattern, in the order required by the root node. This is optional,
+    if not specified, we will not copy over any extra inputs for the root node.
+    Example:
+    # Let's say we have the pattern (torch.add, MatchAllNode, (torch.nn.BatchNorm2d, torch.nn.Conv2d))
+    # and root node is torch.nn.Conv2d, and the node in MatchAllNode would be an extra
+    # argument to the fused module, we can unpack the pattern and return the node at
+    # MatchAllNode here
+    # we can implement extra_inputs_getter as follows:
+    def extra_inputs_getter(pattern) -> List[Any]:
+        add, extra_input, conv_pattern = pattern
+        return [extra_input]
+    """
+    extra_inputs_getter_mapping: dict[Pattern, Callable] = {}
+    for pattern, config in backend_config._pattern_complex_format_to_config.items():
+        if config._extra_inputs_getter is not None:
+            extra_inputs_getter_mapping[pattern] = config._extra_inputs_getter
+    return extra_inputs_getter_mapping
+
+
+def remove_boolean_dispatch_from_name(p) -> Any:
+    """
+    Some ops (those created via `boolean_dispatch`) have an unhelpful default string
+    representation such as
+    '<function boolean_dispatch.<locals>.fn at 0x7ff1106bf280>';
+    this function replaces them with hardcoded, human readable function names.
+    """
+    if p is F.fractional_max_pool2d:
+        return "torch.nn.functional.fractional_max_pool2d"
+    elif p is F.fractional_max_pool3d:
+        return "torch.nn.functional.fractional_max_pool3d"
+    elif p is F.max_pool1d:
+        return "torch.nn.functional.max_pool1d"
+    elif p is F.max_pool2d:
+        return "torch.nn.functional.max_pool2d"
+    elif p is F.max_pool3d:
+        return "torch.nn.functional.max_pool3d"
+    elif p is F.adaptive_max_pool1d:
+        return "torch.nn.functional.adaptive_max_pool1d"
+    elif p is F.adaptive_max_pool2d:
+        return "torch.nn.functional.adaptive_max_pool2d"
+    elif p is F.adaptive_max_pool3d:
+        return "torch.nn.functional.adaptive_max_pool3d"
+    assert "boolean_dispatch" not in str(p), (
+        f"{p} does not have a human readable representation in "
+        + "quantization documentation"
+    )
+    return p
+
+
+def pattern_to_human_readable(p) -> Any:
+    if isinstance(p, tuple):
+        # nested patterns, recurse
+        return tuple(pattern_to_human_readable(inner_p) for inner_p in p)
+    elif isinstance(p, str):
+        # method names are already human readable
+        return p
+    else:
+        p = remove_boolean_dispatch_from_name(p)
+        return p
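+
+
+# Illustrative sketch: pattern_to_human_readable((torch.nn.ReLU, F.max_pool2d))
+# returns (torch.nn.ReLU, "torch.nn.functional.max_pool2d"): module classes are left
+# as is, while boolean-dispatched functionals are replaced by their names.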
+
+
+# TODO(future PR): move backend_config_dict to use dataclass and move this logic to
+# the corresponding __str__ function
+def entry_to_pretty_str(entry) -> str:
+    """
+    Given a backend_config_dict entry, returns a string with the human readable
+    representation of it.
+    """
+    s = "{\n"
+
+    # always output the pattern first
+    if "pattern" in entry:
+        pattern_str = pattern_to_human_readable(entry["pattern"])
+
+        s += f"  'pattern': {pattern_str},\n"
+
+    # custom output for dtype_configs to make it look nice
+    if "dtype_configs" in entry:
+        s += "  'dtype_configs': [\n"
+        for dtype_config in entry["dtype_configs"]:
+            s += "    {\n"
+            for k, v in dtype_config.items():
+                s += f"      '{k}': {v},\n"
+            s += "    },\n"
+        s += "  ],\n"
+
+    # custom output for num_tensor_args_to_observation_type to make it look nice
+    if "num_tensor_args_to_observation_type" in entry:
+        s += "  'num_tensor_args_to_observation_type': {\n"
+        for k, v in entry["num_tensor_args_to_observation_type"].items():
+            s += f"    {k}: {v},\n"
+        s += "  },\n"
+
+    # output all the other fields
+    custom_handled_fields = [
+        "pattern",
+        "dtype_configs",
+        "num_tensor_args_to_observation_type",
+    ]
+    for field_name in entry:
+        if field_name in custom_handled_fields:
+            continue
+        s += f"  '{field_name}': {entry[field_name]},\n"
+
+    s += "}"
+    return s
+
+
+def _get_pattern_in_reversed_nested_tuple_format(
+    config: BackendPatternConfig,
+) -> Pattern:
+    """
+    Return the pattern specified in the given config in the reversed nested tuple format
+    used internally in the quantization pattern matching code.
+
+    If the pattern is not a tuple, or the pattern is already specified in the reversed
+    nested tuple format, return the pattern as is. Otherwise:
+
+    For 2-tuples (a, b), return (b, a).
+    For 3-tuples (a, b, c), return (c, (b, a)).
+
+    For example:
+        * Given nn.Linear, return nn.Linear
+        * Given (nn.Linear, nn.ReLU), return (nn.ReLU, nn.Linear)
+        * Given (nn.Conv2d, nn.BatchNorm2d, nn.ReLU), return
+          (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d))
+
+    For context, the reason why this is needed is the user-facing BackendConfig
+    API accepts the flat 2-or-3-tuple format in forward order. While this simple
+    format handles the vast majority of use cases, it does not handle the more
+    complex ones, and so the internal pattern matching code for quantization uses
+    the following, more general reversed nested tuple format instead:
+
+        operator = module_type | functional | torch op | native op | MatchAllNode
+        Pattern = (operator, Pattern, Pattern, ...) | operator
+
+    In the future, we expect to replace the above complex format with the one used
+    by the subgraph rewriter in torch.fx, so we don't have to maintain our own
+    complex pattern matching code. Then we won't need this helper function anymore.
+    """
+    if config._pattern_complex_format is not None:
+        return config._pattern_complex_format
+    if config.pattern is None:
+        raise ValueError(
+            "Either 'pattern' or 'pattern_complex_format' must be specified"
+        )
+    if not isinstance(config.pattern, tuple):
+        return config.pattern
+
+    # Pattern is specified in the simple tuple format, need to convert
+    if len(config.pattern) == 2:
+        (a, b) = config.pattern
+        return (b, a)
+    elif len(config.pattern) == 3:
+        (a, b, c) = config.pattern
+        return (c, (b, a))
+    else:
+        raise ValueError("Expected a tuple with 2 or 3 elements, got: ", config.pattern)
+
+
+def _get_fuser_method_in_reversed_nested_tuple_format(
+    config: BackendPatternConfig,
+) -> Callable:
+    """
+    Return the fuser method specified in the given config in the reversed nested
+    tuple format used internally in the quantization pattern matching code.
+
+    If pattern is specified in the reversed nested tuple format, we assume the
+    fuser method is also specified in this format and simply return it as is.
+    Otherwise, we convert the fuser method as follows:
+
+        * Given f(is_qat, conv, relu), return f'(is_qat, relu, conv)
+        * Given f(is_qat, conv, bn, relu), return f'(is_qat, relu, bn_conv),
+          where bn_conv is a 2-tuple (bn, conv)
+
+    The first argument of a fuser method is always `is_qat` and is not affected
+    in the conversion. We currently only support functions with 3 or 4 arguments.
+    """
+    assert config.fuser_method is not None
+    if config._pattern_complex_format is not None:
+        return config.fuser_method
+    if not isinstance(config.pattern, tuple):
+        raise ValueError("Expected pattern to be a tuple, got: ", config.pattern)
+
+    # Pattern is specified in the simple tuple format, need to convert
+    if len(config.pattern) == 2:
+        return _reverse2(config.fuser_method)
+    elif len(config.pattern) == 3:
+        return _reverse3(config.fuser_method)
+    else:
+        raise ValueError("Expected a tuple with 2 or 3 elements, got: ", config.pattern)
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/x86.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/x86.py
new file mode 100644
index 0000000000000000000000000000000000000000..c64b56c981b391140f63038ac507b0708ee876f4
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/backend_config/x86.py
@@ -0,0 +1,126 @@
+import torch
+
+from ._common_operator_config_utils import (
+    _get_binary_op_configs,
+    _get_bn_configs,
+    _get_cat_config,
+    _get_conv_configs,
+    _get_default_op_configs,
+    _get_embedding_op_configs,
+    _get_fixed_qparams_op_configs,
+    _get_linear_configs,
+    _get_rnn_op_configs,
+    _get_share_qparams_op_configs,
+    _get_tensor_info_op_configs,
+)
+from .backend_config import BackendConfig, DTypeConfig
+
+
+__all__ = [
+    "get_x86_backend_config",
+]
+
+# ===================
+# |  DTYPE CONFIGS  |
+# ===================
+
+# X86 aligns with FBGEMM for now
+
+x86_weighted_op_int8_dtype_config = DTypeConfig(
+    input_dtype=torch.quint8,
+    output_dtype=torch.quint8,
+    weight_dtype=torch.qint8,
+    bias_dtype=torch.float,
+)
+
+x86_default_op_quint8_dtype_config = DTypeConfig(
+    input_dtype=torch.quint8,
+    output_dtype=torch.quint8,
+)
+
+x86_default_op_fp16_dtype_config = DTypeConfig(
+    input_dtype=torch.float16,
+    output_dtype=torch.float16,
+    weight_dtype=torch.float16,
+    bias_dtype=torch.float16,
+)
+
+x86_default_dynamic_int8_dtype_config = DTypeConfig(
+    input_dtype=torch.quint8,
+    output_dtype=torch.float,
+    weight_dtype=torch.qint8,
+    bias_dtype=torch.float,
+    is_dynamic=True,
+)
+
+x86_default_dynamic_float16_dtype_config = DTypeConfig(
+    input_dtype=torch.float16,
+    output_dtype=torch.float,
+    weight_dtype=torch.float16,
+    bias_dtype=torch.float,
+    is_dynamic=True,
+)
+
+x86_weight_only_quint8_dtype_config = DTypeConfig(
+    input_dtype=torch.float,
+    output_dtype=torch.float,
+    weight_dtype=torch.quint8,
+)
+
+x86_weight_only_quint4x2_dtype_config = DTypeConfig(
+    input_dtype=torch.float,
+    output_dtype=torch.float,
+    weight_dtype=torch.quint4x2,
+)
+
+
+# =====================
+# |  BACKEND CONFIGS  |
+# =====================
+
+
+def get_x86_backend_config() -> BackendConfig:
+    """
+    Return the `BackendConfig` for PyTorch's native x86 backend.
+    """
+    conv_dtype_configs = [x86_weighted_op_int8_dtype_config]
+    linear_dtype_configs = [
+        x86_weighted_op_int8_dtype_config,
+        x86_default_dynamic_int8_dtype_config,
+        x86_default_dynamic_float16_dtype_config,
+    ]
+    binary_op_dtype_configs = [x86_weighted_op_int8_dtype_config]
+    default_op_dtype_configs = [x86_default_op_quint8_dtype_config]
+    fixed_qparams_op_dtype_configs = [x86_weighted_op_int8_dtype_config]
+    share_qparams_op_dtype_configs = [x86_default_op_quint8_dtype_config]
+    tensor_info_op_dtype_configs = [x86_default_op_quint8_dtype_config]
+    rnn_op_dtype_configs = [
+        x86_default_dynamic_int8_dtype_config,
+        x86_default_dynamic_float16_dtype_config,
+    ]
+    embedding_op_dtype_configs = [
+        x86_weight_only_quint8_dtype_config,
+        x86_weight_only_quint4x2_dtype_config,
+    ]
+    return (
+        BackendConfig("x86")
+        .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs))
+        .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs))
+        .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs))
+        .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs))
+        .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs))
+        .set_backend_pattern_configs(
+            _get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs)
+        )
+        .set_backend_pattern_configs(
+            _get_share_qparams_op_configs(share_qparams_op_dtype_configs)
+        )
+        .set_backend_pattern_configs(
+            _get_tensor_info_op_configs(tensor_info_op_dtype_configs)
+        )
+        .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs))
+        .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs))
+        .set_backend_pattern_configs(
+            _get_embedding_op_configs(embedding_op_dtype_configs)
+        )
+    )
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__init__.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..72d624ad7d6a3926c5d34afab3b7066928f9933d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/__init__.py
@@ -0,0 +1,3 @@
+from .convert import convert
+from .fuse import fuse
+from .prepare import prepare
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_decomposed.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_decomposed.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c4517b93c7fa34883d013f0f9d2afb0def16603
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_decomposed.py
@@ -0,0 +1,1221 @@
+# mypy: allow-untyped-defs
+import math
+from typing import Optional
+
+import torch
+from torch._refs import _unsqueeze_multiple
+from torch.ao.quantization.utils import determine_qparams, validate_qmin_qmax
+from torch.library import impl, Library
+
+
+# Note: "decomposed" means decomposed quantized tensor; we use "decomposed" so that the
+# name is not too long
+quantized_decomposed_lib = Library("quantized_decomposed", "DEF")
+
+_INTEGER_DTYPES = [torch.uint8, torch.int8, torch.uint16, torch.int16, torch.int32]
+_FLOAT_DTYPES = [torch.float8_e5m2, torch.float8_e4m3fn]
+
+_DTYPE_TO_QVALUE_BOUNDS = {
+    k: (torch.iinfo(k).min, torch.iinfo(k).max) for k in _INTEGER_DTYPES
+}
+_DTYPE_TO_QVALUE_BOUNDS.update(
+    {k: (int(torch.finfo(k).min), int(torch.finfo(k).max)) for k in _FLOAT_DTYPES}
+)
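+
+# For illustration, a few of the resulting bounds: torch.uint8 -> (0, 255),
+# torch.int8 -> (-128, 127), torch.int16 -> (-32768, 32767), and for the float8
+# dtypes the finfo range, e.g. torch.float8_e4m3fn -> (-448, 448).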
+
+
+# Helper to check the passed in quant min and max are valid for the dtype
+def _quant_min_max_bounds_check(quant_min, quant_max, dtype):
+    if dtype not in _DTYPE_TO_QVALUE_BOUNDS:
+        raise ValueError(f"Unsupported dtype: {dtype}")
+    quant_min_lower_bound, quant_max_upper_bound = _DTYPE_TO_QVALUE_BOUNDS[dtype]
+
+    assert quant_min >= quant_min_lower_bound, (
+        "quant_min out of bound for dtype, "
+        f"quant_min_lower_bound: {quant_min_lower_bound} quant_min: {quant_min}"
+    )
+
+    assert quant_max <= quant_max_upper_bound, (
+        "quant_max out of bound for dtype, "
+        f"quant_max_upper_bound: {quant_max_upper_bound} quant_max: {quant_max}"
+    )
+
+
+quantized_decomposed_lib.define(
+    "quantize_per_tensor(Tensor input, float scale, int zero_point, "
+    "int quant_min, int quant_max, ScalarType dtype) -> Tensor"
+)
+
+
+@impl(quantized_decomposed_lib, "quantize_per_tensor", "CompositeExplicitAutograd")
+def quantize_per_tensor(
+    input: torch.Tensor,
+    scale: float,
+    zero_point: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    """Affine quantization for the Tensor using the same quantization parameters to map
+    from floating point to quantized values
+
+    Args:
+       input (torch.Tensor): original float32 or bfloat16 Tensor
+       scale (float): quantization parameter for affine quantization
+       zero_point (int): quantization parameter for affine quantization
+       quant_min (int): minimum quantized value for output Tensor
+       quant_max (int): maximum quantized value for output Tensor
+       dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor
+
+    Returns:
+       Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters
+       are not stored in the Tensor, we are storing them in function arguments instead
+    """
+    if input.dtype in [torch.float16, torch.bfloat16]:
+        input = input.to(torch.float32)
+    assert input.dtype == torch.float32, (
+        f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
+    )
+    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
+
+    inv_scale = 1.0 / scale
+    return torch.clamp(
+        torch.round(input * inv_scale) + zero_point, quant_min, quant_max
+    ).to(dtype)
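+
+
+# Worked example (sketch): with scale=0.05, zero_point=128, quant_min=0, quant_max=255
+# and dtype=torch.uint8, an input value of 1.0 maps to round(1.0 / 0.05) + 128 = 148,
+# which is inside [0, 255], so the stored quantized value is 148. The registered op is
+# callable as torch.ops.quantized_decomposed.quantize_per_tensor(x, 0.05, 128, 0, 255, torch.uint8).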
+
+
+@impl(quantized_decomposed_lib, "quantize_per_tensor", "Meta")
+def quantize_per_tensor_meta(
+    input: torch.Tensor,
+    scale: float,
+    zero_point: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    if input.dtype in [torch.float16, torch.bfloat16]:
+        input = input.to(torch.float32)
+    assert input.dtype == torch.float32, (
+        f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
+    )
+    return torch.empty_like(input, dtype=dtype)
+
+
+quantized_decomposed_lib.define(
+    "quantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, "
+    "int quant_min, int quant_max, ScalarType dtype) -> Tensor"
+)
+
+
+@impl(
+    quantized_decomposed_lib, "quantize_per_tensor.tensor", "CompositeExplicitAutograd"
+)
+def quantize_per_tensor_tensor(
+    input: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    """Affine quantization for the Tensor using the same quantization parameters to map
+    from floating point to quantized values
+    Same as `quantize_per_tensor` but scale and zero_point are Scalar Tensor instead of
+    scalar values
+    """
+    assert zero_point.numel() == 1, (
+        f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
+    )
+    assert scale.numel() == 1, (
+        f"Expecting scale tensor to be one element, but received : {scale.numel()}"
+    )
+    return quantize_per_tensor(
+        input,
+        scale.item(),
+        zero_point.item(),  # type: ignore[arg-type]
+        quant_min,  # type: ignore[arg-type]
+        quant_max,  # type: ignore[arg-type]
+        dtype,
+    )
+
+
+@impl(quantized_decomposed_lib, "quantize_per_tensor.tensor", "Meta")
+def quantize_per_tensor_tensor_meta(
+    input: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    if input.dtype in [torch.float16, torch.bfloat16]:
+        input = input.to(torch.float32)
+    assert zero_point.numel() == 1, (
+        f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
+    )
+    assert scale.numel() == 1, (
+        f"Expecting scale tensor to be one element, but received : {scale.numel()}"
+    )
+    assert input.dtype == torch.float32, (
+        f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
+    )
+    return torch.empty_like(input, dtype=dtype)
+
+
+# TODO: remove other variants and keep this one
+quantized_decomposed_lib.define(
+    "quantize_per_tensor.tensor2(Tensor input, Tensor scale, Tensor zero_point, "
+    "Tensor quant_min, Tensor quant_max, ScalarType dtype) -> Tensor"
+)
+
+
+@impl(
+    quantized_decomposed_lib, "quantize_per_tensor.tensor2", "CompositeExplicitAutograd"
+)
+def quantize_per_tensor_tensor2(
+    input: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    quant_min: torch.Tensor,
+    quant_max: torch.Tensor,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    """Affine quantization for the Tensor using the same quantization parameters to map
+    from floating point to quantized values
+    Same as `quantize_per_tensor` but scale and zero_point are Scalar Tensor instead of
+    scalar values
+    """
+    assert zero_point.numel() == 1, (
+        f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
+    )
+    assert scale.numel() == 1, (
+        f"Expecting scale tensor to be one element, but received : {scale.numel()}"
+    )
+    return quantize_per_tensor(
+        input,
+        scale.item(),
+        zero_point.item(),  # type: ignore[arg-type]
+        quant_min.item(),  # type: ignore[arg-type]
+        quant_max.item(),  # type: ignore[arg-type]
+        dtype,
+    )
+
+
+@impl(quantized_decomposed_lib, "quantize_per_tensor.tensor2", "Meta")
+def quantize_per_tensor_tensor2_meta(
+    input: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    quant_min: torch.Tensor,
+    quant_max: torch.Tensor,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    return quantize_per_tensor_tensor_meta(
+        input,
+        scale,
+        zero_point,  # type: ignore[arg-type]
+        quant_min,  # type: ignore[arg-type]
+        quant_max,  # type: ignore[arg-type]
+        dtype,
+    )
+
+
+# Note: quant_min/quant_max/dtype are not used in the operator, but for now they are kept
+# in the signature as metadata for the input Tensor; this might be useful for pattern
+# matching in the future.
+# We will revisit this later if we find there are no use cases for it.
+quantized_decomposed_lib.define(
+    "dequantize_per_tensor(Tensor input, float scale, int zero_point, "
+    "int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor"
+)
+
+
+@impl(quantized_decomposed_lib, "dequantize_per_tensor", "CompositeExplicitAutograd")
+def dequantize_per_tensor(
+    input: torch.Tensor,
+    scale: float,
+    zero_point: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+    *,
+    out_dtype: Optional[torch.dtype] = None,
+) -> torch.Tensor:
+    """Affine dequantization for the Tensor using the same quantization parameters to map
+    from quantized values to floating point values
+
+    Args:
+       input (torch.Tensor): Tensor with dtype matching `dtype` argument,
+       e.g. (`torch.uint8`), it is a per tensor quantized Tensor if combined with
+       quantization parameters in the argument of this function (scale/zero_point)
+
+       scale (float): quantization parameter for affine quantization
+
+       zero_point (int): quantization parameter for affine quantization
+
+       quant_min (int): minimum quantized value for input Tensor (not used in computation,
+       reserved for pattern matching)
+
+       quant_max (int): maximum quantized value for input Tensor (not used in computation,
+       reserved for pattern matching)
+
+       dtype (torch.dtype): dtype for input Tensor (not used in computation,
+       reserved for pattern matching)
+
+       out_dtype (torch.dtype?): optional dtype for output Tensor
+
+    Returns:
+       dequantized float32 Tensor
+    """
+    assert input.dtype == dtype, (
+        f"Expecting input to have dtype: {dtype}, but got {input.dtype}"
+    )
+    if out_dtype is None:
+        out_dtype = torch.float32
+    if dtype in _DTYPE_TO_QVALUE_BOUNDS:
+        # TODO: investigate why
+        # (input - zero_point).to(torch.float32) * scale
+        # failed the test
+        return (input.to(out_dtype) - zero_point) * scale
+    else:
+        raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}")
+
+
+@impl(quantized_decomposed_lib, "dequantize_per_tensor", "Meta")
+def dequantize_per_tensor_meta(
+    input: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+    *,
+    out_dtype: Optional[torch.dtype] = None,
+) -> torch.Tensor:
+    if out_dtype is None:
+        out_dtype = torch.float32
+    return torch.empty_like(input, dtype=out_dtype)
+
+
+quantized_decomposed_lib.define(
+    "dequantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, "
+    "int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor"
+)
+
+
+@impl(
+    quantized_decomposed_lib,
+    "dequantize_per_tensor.tensor",
+    "CompositeExplicitAutograd",
+)
+def dequantize_per_tensor_tensor(
+    input: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+    *,
+    out_dtype: Optional[torch.dtype] = None,
+) -> torch.Tensor:
+    """Affine dequantization for the Tensor using the same quantization parameters to map
+    from quantized values to floating point values
+    Same as `dequantize_per_tensor` but scale and zero_point are Scalar Tensor instead of
+    scalar values
+    """
+    assert zero_point.numel() == 1, (
+        f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
+    )
+    assert scale.numel() == 1, (
+        f"Expecting scale tensor to be one element, but received : {scale.numel()}"
+    )
+    return dequantize_per_tensor(
+        input,
+        scale.item(),
+        zero_point.item(),  # type: ignore[arg-type]
+        quant_min,
+        quant_max,
+        dtype,
+        out_dtype=out_dtype,
+    )
+
+
+@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor", "Meta")
+def dequantize_per_tensor_tensor_meta(
+    input: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+    *,
+    out_dtype: Optional[torch.dtype] = None,
+) -> torch.Tensor:
+    if out_dtype is None:
+        out_dtype = torch.float32
+    assert zero_point.numel() == 1, (
+        f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
+    )
+    assert scale.numel() == 1, (
+        f"Expecting scale tensor to be one element, but received : {scale.numel()}"
+    )
+    assert input.dtype == dtype, f"Expecting input to have dtype: {dtype}"
+    if dtype in _DTYPE_TO_QVALUE_BOUNDS:
+        return torch.empty_like(input, dtype=out_dtype)
+    else:
+        raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}")
+
+
+# TODO: remove other variants and keep this one
+quantized_decomposed_lib.define(
+    "dequantize_per_tensor.tensor2(Tensor input, Tensor scale, Tensor zero_point, "
+    "Tensor quant_min, Tensor quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor"
+)
+
+
+@impl(
+    quantized_decomposed_lib,
+    "dequantize_per_tensor.tensor2",
+    "CompositeExplicitAutograd",
+)
+def dequantize_per_tensor_tensor2(
+    input: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    quant_min: torch.Tensor,
+    quant_max: torch.Tensor,
+    dtype: torch.dtype,
+    *,
+    out_dtype: Optional[torch.dtype] = None,
+) -> torch.Tensor:
+    """Affine dequantization for the Tensor using the same quantization parameters to map
+    from quantized values to floating point values
+    Same as `dequantize_per_tensor` but scale and zero_point are Scalar Tensor instead of
+    scalar values
+    """
+    assert zero_point.numel() == 1, (
+        f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
+    )
+    assert scale.numel() == 1, (
+        f"Expecting scale tensor to be one element, but received : {scale.numel()}"
+    )
+    return dequantize_per_tensor(
+        input,
+        scale.item(),
+        zero_point.item(),  # type: ignore[arg-type]
+        quant_min.item(),  # type: ignore[arg-type]
+        quant_max.item(),  # type: ignore[arg-type]
+        dtype,
+        out_dtype=out_dtype,
+    )
+
+
+@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor2", "Meta")
+def dequantize_per_tensor_tensor2_meta(
+    input,
+    scale,
+    zero_point,
+    quant_min,
+    quant_max,
+    dtype,
+    *,
+    out_dtype: Optional[torch.dtype] = None,
+) -> torch.Tensor:
+    return dequantize_per_tensor_tensor_meta(
+        input, scale, zero_point, quant_min, quant_max, dtype, out_dtype=out_dtype
+    )
+
+
+quantized_decomposed_lib.define(
+    "choose_qparams.tensor(Tensor input, int quant_min, int quant_max, "
+    "float eps, ScalarType dtype) -> (Tensor, Tensor)"
+)
+
+
+@impl(quantized_decomposed_lib, "choose_qparams.tensor", "CompositeExplicitAutograd")
+def choose_qparams_tensor(
+    input: torch.Tensor, qmin: int, qmax: int, eps: float, dtype: torch.dtype
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Given an input Tensor, derive the per tensor affine quantization parameter
+    (scale and zero_point) for target quantized Tensor from the Tensor
+
+    Args:
+       input (torch.Tensor): floating point input Tensor
+       quant_min (int): minimum quantized value for target quantized Tensor
+       quant_max (int): maximum quantized value for target quantized Tensor
+       dtype (torch.dtype): dtype for target quantized Tensor
+
+    Returns:
+       scale (float): quantization parameter for the target quantized Tensor
+       zero_point (int): quantization parameter for the target quantized Tensor
+    """
+    assert input.dtype in [
+        torch.float32,
+        torch.float16,
+        torch.bfloat16,
+    ], (
+        f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}"
+    )
+    assert dtype in _DTYPE_TO_QVALUE_BOUNDS, (
+        f"Expecting target dtype to be one of {_DTYPE_TO_QVALUE_BOUNDS.keys()}, but got: {dtype}"
+    )
+    validate_qmin_qmax(qmin, qmax)
+
+    min_val, max_val = torch.aminmax(input)
+
+    return determine_qparams(
+        min_val,
+        max_val,
+        qmin,
+        qmax,
+        dtype,
+        torch.Tensor([eps]),
+        has_customized_qrange=False,
+    )
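+
+
+# Sketch of the underlying affine derivation (roughly what determine_qparams computes):
+#     scale = (max_val - min_val) / (quant_max - quant_min)
+#     zero_point = quant_min - round(min_val / scale), clamped to [quant_min, quant_max]
+# with the observed range first extended to include 0 and scale clamped to at least eps.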
+
+
+quantized_decomposed_lib.define(
+    "choose_qparams_symmetric.tensor(Tensor input, int quant_min, int quant_max, "
+    "float eps, ScalarType dtype) -> (Tensor, Tensor)"
+)
+
+
+@impl(
+    quantized_decomposed_lib,
+    "choose_qparams_symmetric.tensor",
+    "CompositeExplicitAutograd",
+)
+def choose_qparams_symmetric_tensor(
+    input: torch.Tensor, qmin: int, qmax: int, eps: float, dtype: torch.dtype
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Given an input Tensor, derive the per tensor affine quantization parameter
+    (scale and zero_point) for target quantized Tensor from the Tensor
+
+    Args:
+       input (torch.Tensor): floating point input Tensor
+       quant_min (int): minimum quantized value for target quantized Tensor
+       quant_max (int): maximum quantized value for target quantized Tensor
+       dtype (torch.dtype): dtype for target quantized Tensor
+
+    Returns:
+       scale (float): quantization parameter for the target quantized Tensor
+       zero_point (int): quantization parameter for the target quantized Tensor
+    """
+    assert input.dtype in [
+        torch.float32,
+        torch.float16,
+        torch.bfloat16,
+    ], (
+        f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}"
+    )
+    assert dtype in _DTYPE_TO_QVALUE_BOUNDS, (
+        f"Expecting target dtype to be one of {_DTYPE_TO_QVALUE_BOUNDS.keys()}, but got: {dtype}"
+    )
+    validate_qmin_qmax(qmin, qmax)
+
+    min_val, max_val = torch.aminmax(input)
+    return determine_qparams(
+        min_val,
+        max_val,
+        qmin,
+        qmax,
+        dtype,
+        torch.Tensor([eps]),
+        has_customized_qrange=False,
+        qscheme=torch.per_tensor_symmetric,
+    )
+
+
+@impl(quantized_decomposed_lib, "choose_qparams.tensor", "Meta")
+def choose_qparams_tensor_meta(
+    input: torch.Tensor, quant_min: int, quant_max: int, eps: float, dtype: torch.dtype
+) -> tuple[torch.Tensor, torch.Tensor]:
+    assert input.dtype in [
+        torch.float32,
+        torch.float16,
+        torch.bfloat16,
+    ], (
+        f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}"
+    )
+    assert quant_min < quant_max, (
+        f"Expecting quant_min to be smaller than quant_max but received min: \
+        {quant_min} max: {quant_max}"
+    )
+    return torch.empty(1, dtype=torch.double, device=input.device), torch.empty(
+        1, dtype=torch.int64, device=input.device
+    )
+
+
+@impl(quantized_decomposed_lib, "choose_qparams_symmetric.tensor", "Meta")
+def choose_qparams_symmetric_tensor_meta(
+    input: torch.Tensor, quant_min: int, quant_max: int, eps: float, dtype: torch.dtype
+) -> tuple[torch.Tensor, torch.Tensor]:
+    return torch.empty(1, dtype=torch.double, device=input.device), torch.empty(
+        1, dtype=torch.int64, device=input.device
+    )
+
+
+# Helper function used to implement per-channel quantization against any axis
+def _permute_to_axis_zero(x, axis):
+    new_axis_list = list(range(x.dim()))
+    new_axis_list[axis] = 0
+    new_axis_list[0] = axis
+    y = x.permute(tuple(new_axis_list))
+    return y, new_axis_list
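+
+
+# Illustrative sketch: for x of shape (2, 3, 4) and axis=1 the permutation list is
+# [1, 0, 2], so y has shape (3, 2, 4). Since swapping two dims is its own inverse,
+# callers reuse the returned list to permute the result back to the original layout.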
+
+
+quantized_decomposed_lib.define(
+    "quantize_per_channel(Tensor input, Tensor scales, Tensor zero_points, int axis, "
+    "int quant_min, int quant_max, ScalarType dtype) -> Tensor"
+)
+
+
+@impl(quantized_decomposed_lib, "quantize_per_channel", "CompositeExplicitAutograd")
+def quantize_per_channel(
+    input: torch.Tensor,
+    scales: torch.Tensor,
+    zero_points: torch.Tensor,
+    axis: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    """Affine per channel quantization for the Tensor using the same quantization
+    parameters for each channel/axis to map from floating point to quantized values
+
+    Args:
+       input (torch.Tensor): original float32 or bfloat16 Tensor
+       scales (torch.Tensor): scale quantization parameters for affine
+       quantization, one per channel
+       zero_points (torch.Tensor): zero_point quantization parameters for affine
+       quantization, one per channel
+       quant_min (int): minimum quantized value for output Tensor
+       quant_max (int): maximum quantized value for output Tensor
+       dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor
+
+    Returns:
+       Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters
+       are not stored in the Tensor, we are storing them in function arguments instead
+    """
+    if input.dtype in [torch.float16, torch.bfloat16]:
+        input = input.to(torch.float32)
+    assert input.dtype == torch.float32, (
+        f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
+    )
+    assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
+    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
+    input, permute_axis_list = _permute_to_axis_zero(input, axis)
+
+    new_shape = [1] * input.dim()
+    new_shape[0] = scales.shape[0]
+    scales = scales.view(new_shape)
+    zero_points = zero_points.view(new_shape)
+
+    res = torch.clamp(
+        torch.round(input * (1.0 / scales)) + zero_points, quant_min, quant_max
+    )
+    out = res.permute(tuple(permute_axis_list))
+    return out.to(dtype)
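+
+
+# Illustrative sketch: quantizing an NCHW input along axis=1 first moves the channel
+# dim to the front (shape (C, N, H, W)), then views scales/zero_points as (C, 1, 1, 1)
+# so one scale and zero_point broadcast per channel before permuting back.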
+
+
+@impl(quantized_decomposed_lib, "quantize_per_channel", "Meta")
+def quantize_per_channel_meta(
+    input: torch.Tensor,
+    scales: torch.Tensor,
+    zero_points: torch.Tensor,
+    axis: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    if input.dtype in [torch.float16, torch.bfloat16]:
+        input = input.to(torch.float32)
+    assert input.dtype == torch.float32, (
+        f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
+    )
+    assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
+    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
+    return torch.empty_like(input, dtype=dtype)
+
+
+# Note: quant_min/quant_max/dtype are not used in the operator, but for now they are kept
+# in the signature as metadata for the input Tensor; this might be useful for pattern
+# matching in the future.
+# We will revisit this later if we find there are no use cases for it.
+quantized_decomposed_lib.define(
+    "dequantize_per_channel(Tensor input, Tensor scales, Tensor? zero_points, int axis, "
+    "int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor"
+)
+
+
+@impl(quantized_decomposed_lib, "dequantize_per_channel", "CompositeExplicitAutograd")
+def dequantize_per_channel(
+    input: torch.Tensor,
+    scales: torch.Tensor,
+    zero_points: Optional[torch.Tensor],
+    axis: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+    *,
+    out_dtype: Optional[torch.dtype] = None,
+) -> torch.Tensor:
+    """Affine per channel dequantization for the Tensor using the same quantization
+    parameters for each channel/axis to map from quantized values to floating point values
+
+    Args:
+       input (torch.Tensor): Tensor with dtype matching `dtype` argument,
+       e.g. (`torch.uint8`), it is a per channel quantized Tensor if combined with
+       quantization parameter in the argument of this function (scales/zero_points/axis)
+
+       scales (torch.Tensor): scale quantization parameters for affine
+       quantization, one per channel
+
+       zero_points (torch.Tensor): zero_point quantization parameters for affine
+       quantization, one per channel
+
+       quant_min (int): minimum quantized value for output Tensor (not used in computation,
+       reserved for pattern matching)
+
+       quant_max (int): maximum quantized value for output Tensor (not used in computation,
+       reserved for pattern matching)
+
+       dtype (torch.dtype): requested dtype for output Tensor (not used in computation,
+       reserved for pattern matching)
+
+       out_dtype (torch.dtype?): optional dtype for output Tensor
+
+    Returns:
+       dequantized float32 Tensor
+    """
+    assert input.dtype == dtype, (
+        f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}"
+    )
+    if out_dtype is None:
+        out_dtype = torch.float32
+    assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
+    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
+    input, permute_axis_list = _permute_to_axis_zero(input, axis)
+
+    new_shape = [1] * input.dim()
+    new_shape[0] = scales.shape[0]
+    scales = scales.view(new_shape)
+    if zero_points is not None:
+        res = (input - zero_points.view(new_shape)) * scales
+    else:
+        res = input * scales
+
+    res = res.to(out_dtype)
+
+    out = res.permute(tuple(permute_axis_list))
+    return out
+
+
+@impl(quantized_decomposed_lib, "dequantize_per_channel", "Meta")
+def dequantize_per_channel_meta(
+    input: torch.Tensor,
+    scales: torch.Tensor,
+    zero_points: Optional[torch.Tensor],
+    axis: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+    *,
+    out_dtype: Optional[torch.dtype] = None,
+) -> torch.Tensor:
+    assert input.dtype == dtype, (
+        f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}"
+    )
+    if out_dtype is None:
+        out_dtype = torch.float32
+    assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
+    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
+    return torch.empty_like(input, dtype=out_dtype)
+
+
+quantized_decomposed_lib.define(
+    "choose_qparams_per_token(Tensor input, ScalarType dtype) -> (Tensor, Tensor)"
+)
+
+
+@impl(
+    quantized_decomposed_lib,
+    "choose_qparams_per_token",
+    "CompositeExplicitAutograd",
+)
+def choose_qparams_per_token(
+    input: torch.Tensor,
+    dtype: torch.dtype,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Choose quantization parameters for per token quantization. This means for a N dimension Tensor
+    (M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize
+    every N elements with the same quantization parameter. The dimension for scales/zero_points
+    will be (M1 * M2 ... * Mn)
+
+    Args:
+       input (torch.Tensor): original float32/float16 Tensor
+       dtype (torch.dtype): dtype (e.g. torch.uint8) for input Tensor
+
+    Returns:
+        scales and zero_points, both float32 Tensors
+    """
+
+    scales = input.abs().amax(dim=-1, keepdim=True)
+    if scales.dtype == torch.float16:
+        scales = (
+            scales.float()
+        )  # want float scales to avoid overflows for fp16, (bf16 has wide enough range)
+    if dtype == torch.int8:
+        n_bits = 8
+        quant_max = 2 ** (n_bits - 1) - 1
+    else:
+        raise Exception(  # noqa: TRY002
+            f"unsupported dtype in choose_qparams_per_token: {dtype}"
+        )
+
+    scales = scales.clamp(min=1e-5).div(quant_max)
+    zero_points = torch.zeros_like(scales)
+    return scales, zero_points
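+
+
+# Illustrative sketch: for an int8 target and an input of shape (B, T, N), the returned
+# scales and zero_points both have shape (B, T, 1); each token's scale is
+# max(|x_token|) / 127 (clamped to at least 1e-5) and every zero_point is 0.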
+
+
+@impl(
+    quantized_decomposed_lib,
+    "choose_qparams_per_token",
+    "Meta",
+)
+def choose_qparams_per_token_meta(
+    input: torch.Tensor,
+    dtype: torch.dtype,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    size = list(input.shape[:-1]) + [1]
+    return torch.empty(size, dtype=torch.double, device=input.device), torch.empty(
+        size, dtype=torch.int64, device=input.device
+    )
+
+
+quantized_decomposed_lib.define(
+    "_choose_qparams_per_token_asymmetric_impl(Tensor input, ScalarType dtype) -> (Tensor, Tensor)"
+)
+
+
+@impl(
+    quantized_decomposed_lib,
+    "_choose_qparams_per_token_asymmetric_impl",
+    "CompositeImplicitAutograd",
+)
+def _choose_qparams_per_token_asymmetric_impl(
+    input: torch.Tensor,
+    dtype: torch.dtype,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Choose quantization parameters for per token quantization. This means for a N dimension Tensor
+    (M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize
+    every N elements with the same quantization parameter. The dimension for scales/zero_points
+    will be (M1 * M2 ... * Mn)
+
+    Args:
+       input (torch.Tensor): original float32/float16 Tensor
+       dtype (torch.dtype): dtype (e.g. torch.uint8) for input Tensor
+
+    Returns:
+        scales and zero_points, both float32 Tensors
+    """
+    # Based on https://github.com/google/XNNPACK/blob/df156f0cf3db5a4576cc711123eeb54915f82ffc/src/xnnpack/quantization.h#L18
+    qmin, qmax = -128, 127
+    min_val = torch.amin(input, dim=-1, keepdim=True)
+    max_val = torch.amax(input, dim=-1, keepdim=True)
+    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
+    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
+    eps = torch.finfo(torch.float32).eps  # use xnnpack eps?
+
+    # scale
+    scale = (max_val_pos - min_val_neg) / float(qmax - qmin)
+    scale = scale.clamp(min=eps)
+
+    # zero point
+    descaled_min = min_val_neg / scale
+    descaled_max = max_val_pos / scale
+    zero_point_from_min_error = qmin + descaled_min
+    zero_point_from_max_error = qmax + descaled_max
+    zero_point = torch.where(
+        zero_point_from_min_error + zero_point_from_max_error > 0,
+        qmin - descaled_min,
+        qmax - descaled_max,
+    )
+    zero_point = torch.clamp(zero_point, qmin, qmax).round()
+
+    return scale.to(torch.float64), zero_point.to(torch.int64)
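+
+# Illustrative sketch (hypothetical values; uses the function defined above): unlike the
+# symmetric variant, the asymmetric variant picks a per-token zero_point so the observed
+# [min, max] range maps onto the full [qmin, qmax] = [-128, 127] range.
+#
+#     x = torch.rand(4, 16)                           # strictly non-negative inputs
+#     scales, zps = _choose_qparams_per_token_asymmetric_impl(x, torch.int8)
+#     # here min(x) is ~0, so each zero_point ends up at (or very near) qmin = -128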
+
+
+quantized_decomposed_lib.define(
+    "choose_qparams_per_token_asymmetric(Tensor input, ScalarType dtype) -> (Tensor, Tensor)"
+)
+
+
+@impl(
+    quantized_decomposed_lib,
+    "choose_qparams_per_token_asymmetric",
+    "CompositeExplicitAutograd",
+)
+def choose_qparams_per_token_asymmetric(
+    input: torch.Tensor,
+    dtype: torch.dtype,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    return _choose_qparams_per_token_asymmetric_impl(input, dtype)
+
+
+@impl(
+    quantized_decomposed_lib,
+    "choose_qparams_per_token_asymmetric",
+    "Meta",
+)
+def choose_qparams_per_token_asymmetric_meta(
+    input: torch.Tensor,
+    dtype: torch.dtype,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    size = list(input.shape[:-1]) + [1]
+    return torch.empty(size, dtype=torch.double, device=input.device), torch.empty(
+        size, dtype=torch.int64, device=input.device
+    )
+
+
+def _per_token_quant_qparam_dim_check(input, scales, zero_points):
+    num_tokens = math.prod(list(input.size())[:-1])
+    assert num_tokens == scales.numel(), (
+        f"num_tokens: {num_tokens} scales: {scales.size()}"
+    )
+    assert num_tokens == zero_points.numel(), (
+        f"num_tokens: {num_tokens} zero_points: {zero_points.size()}"
+    )
+
+
+quantized_decomposed_lib.define(
+    "quantize_per_token(Tensor input, Tensor scales, Tensor zero_points, "
+    "int quant_min, int quant_max, ScalarType dtype) -> Tensor"
+)
+
+
+@impl(quantized_decomposed_lib, "quantize_per_token", "CompositeExplicitAutograd")
+def quantize_per_token(
+    input: torch.Tensor,
+    scales: torch.Tensor,
+    zero_points: torch.Tensor,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+):
+    """Per token quantization for the Tensor using the quantization parameters to map
+    from floating point to quantized values. This means for a N dimension Tensor
+    (M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize
+    every N elements with the same quantization parameter. The dimension for scales/zero_points
+    will be (M1 * M2 ... * Mn)
+
+    Args:
+       input (torch.Tensor): original float32 or bfloat16 Tensor
+       scales (float32 torch.Tensor): quantization parameter for per token affine quantization
+       zero_points (int32 torch.Tensor): quantization parameter for per token affine quantization
+       quant_min (int): minimum quantized value for output Tensor
+       quant_max (int): maximum quantized value for output Tensor
+       dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor
+
+    Returns:
+       Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters
+       are not stored in the Tensor, we are storing them in function arguments instead
+    """
+    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
+    _per_token_quant_qparam_dim_check(input, scales, zero_points)
+    input = (
+        input.mul(1.0 / scales)
+        .add(zero_points)
+        .round()
+        .clamp(quant_min, quant_max)
+        .to(dtype)
+    )
+    return input
+
+
+@impl(quantized_decomposed_lib, "quantize_per_token", "Meta")
+def quantize_per_token_meta(
+    input: torch.Tensor,
+    scales: torch.Tensor,
+    zero_points: torch.Tensor,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+):
+    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
+    return torch.empty_like(input, dtype=dtype)
+
+
+quantized_decomposed_lib.define(
+    "dequantize_per_token(Tensor input, Tensor scales, Tensor zero_points, "
+    "int quant_min, int quant_max, ScalarType dtype, ScalarType output_dtype) -> Tensor"
+)
+
+
+@impl(quantized_decomposed_lib, "dequantize_per_token", "CompositeExplicitAutograd")
+def dequantize_per_token(
+    input: torch.Tensor,
+    scales: torch.Tensor,
+    zero_points: torch.Tensor,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+    output_dtype: torch.dtype = torch.float32,
+):
+    """Per token dequantization for the Tensor using the quantization parameters to map
+    from floating point to quantized values. This means for a N dimension Tensor
+    (M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize
+    every N elements with the same quantization parameter. The dimension for scales/zero_points
+    will be (M1 * M2 ... * Mn)
+
+    Args:
+       input (torch.Tensor): quantized Tensor (uint8, int8 etc.)
+       scales (float64 torch.Tensor): quantization parameter for per token affine quantization
+       zero_points (int64 torch.Tensor): quantization parameter for per token affine quantization
+       quant_min (int): minimum quantized value for input Tensor
+       quant_max (int): maximum quantized value for input Tensor
+       dtype (torch.dtype): dtype (e.g. torch.uint8) for input Tensor
+       output_dtype (torch.dtype): dtype (e.g. torch.float32) for output Tensor
+
+    Returns:
+       dequantized Tensor with dtype `output_dtype`
+    """
+    input = input - zero_points
+    input = input * scales
+    # Since scales may be of float64 type, cast the result to the requested output dtype
+    return input.to(output_dtype)
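+
+# Illustrative round-trip sketch (hypothetical values; quant_min/quant_max chosen for int8):
+#
+#     x = torch.randn(2, 3, 8)
+#     scales, zps = choose_qparams_per_token(x, torch.int8)
+#     q = quantize_per_token(x, scales, zps, -127, 127, torch.int8)
+#     x_hat = dequantize_per_token(q, scales, zps, -127, 127, torch.int8, torch.float32)
+#     # x_hat approximates x; the per-token rounding error is roughly at most scale/2 per element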
+
+
+@impl(quantized_decomposed_lib, "dequantize_per_token", "Meta")
+def dequantize_per_token_meta(
+    input: torch.Tensor,
+    scales: torch.Tensor,
+    zero_points: torch.Tensor,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+    output_dtype: torch.dtype = torch.float32,
+):
+    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
+    # TODO: support fp16
+    return torch.empty_like(input, dtype=output_dtype)
+
+
+quantized_decomposed_lib.define(
+    "quantize_per_channel_group(Tensor input, Tensor scales, Tensor zero_points, int quant_min, "
+    "int quant_max, ScalarType dtype, int group_size) -> Tensor"
+)
+
+
+# TODO: dtype is ignored for now
+@impl(
+    quantized_decomposed_lib, "quantize_per_channel_group", "CompositeExplicitAutograd"
+)
+def quantize_per_channel_group(
+    input: torch.Tensor,
+    scales: torch.Tensor,
+    zero_points: torch.Tensor,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+    group_size=128,
+):
+    assert group_size > 1
+    # needed for GPTQ single column quantize
+    if group_size > input.shape[-1] and scales.shape[-1] == 1:
+        group_size = input.shape[-1]
+
+    assert input.shape[-1] % group_size == 0
+    assert input.dim() == 2
+
+    # TODO: check for dtype, currently we can't express torch.int4 so it's omitted
+    to_quant = input.reshape(-1, group_size)
+    assert torch.isnan(to_quant).sum() == 0
+
+    scales = scales.reshape(-1, 1)
+    zero_points = zero_points.reshape(-1, 1)
+
+    input_int8 = (
+        to_quant.mul(1.0 / scales)
+        .add(zero_points)
+        .round()
+        .clamp_(quant_min, quant_max)
+        .to(dtype)
+        .reshape_as(input)
+    )
+
+    return input_int8
+
+
+@impl(quantized_decomposed_lib, "quantize_per_channel_group", "Meta")
+def quantize_per_channel_group_meta(
+    input: torch.Tensor,
+    scales: torch.Tensor,
+    zero_points: torch.Tensor,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+    group_size=128,
+):
+    """Groupwise quantization within each channel for an 2-d Tensor using the quantization parameters
+    to map from floating point to quantized values. This means for each row of a 2-d Tensor
+    (M, N), we calculate scales/zero_points for each `group_size` elements
+    and quantize every `group_size` elements with the same quantization parameter.
+    The dimension for scales/zero_points will be (M * ceil(N, group_size),)
+
+    Args:
+       input (torch.Tensor): original float32 or bfloat16 Tensor
+       scales (float32 torch.Tensor): quantization parameter for per channel group affine quantization
+       zero_points (int32 torch.Tensor): quantization parameter for per channel group affine quantization
+       quant_min (int): minimum quantized value for output Tensor
+       quant_max (int): maximum quantized value for output Tensor
+       dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor
+
+    Returns:
+       Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters
+       are not stored in the Tensor, we are storing them in function arguments instead
+    """
+    assert group_size > 1
+    # needed for GPTQ single column quantize
+    if group_size > input.shape[-1] and scales.shape[-1] == 1:
+        group_size = input.shape[-1]
+
+    assert input.shape[-1] % group_size == 0
+    assert input.dim() == 2
+    return torch.empty_like(input, dtype=dtype)
+
+
+quantized_decomposed_lib.define(
+    "dequantize_per_channel_group(Tensor input, Tensor scales, Tensor? zero_points, int quant_min, "
+    "int quant_max, ScalarType dtype, int group_size, ScalarType output_dtype) -> Tensor"
+)
+
+
+@impl(
+    quantized_decomposed_lib,
+    "dequantize_per_channel_group",
+    "CompositeExplicitAutograd",
+)
+def dequantize_per_channel_group(
+    w_int8: torch.Tensor,
+    scales: torch.Tensor,
+    zero_points: Optional[torch.Tensor],
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+    group_size: int = 128,
+    output_dtype: torch.dtype = torch.float32,
+):
+    """Groupwise dequantization within each channel for an 2-d Tensor using the quantization parameters
+    to map from floating point to quantized values. This means for each row of a 2-d Tensor
+    (M, N), we calculate scales/zero_points for each `group_size` elements
+    and quantize every `group_size` elements with the same quantization parameter.
+    The dimension for scales/zero_points will be (M * ceil(N, group_size),)
+
+    Args:
+       input (torch.Tensor): quantized Tensor (uint8/int8 etc.)
+       scales (float32 torch.Tensor): quantization parameter for per channel group affine quantization
+       zero_points (int32 torch.Tensor): quantization parameter for per channel group affine quantization
+       quant_min (int): minimum quantized value for input Tensor
+       quant_max (int): maximum quantized value for input Tensor
+       dtype (torch.dtype): dtype (e.g. torch.uint8) for input Tensor
+       output_dtype (torch.dtype): dtype (e.g. torch.float32) for output Tensor
+
+    Returns:
+       dequantized Tensor with dtype `output_dtype`
+    """
+
+    assert group_size > 1
+    # needed for GPTQ single column dequantize
+    if group_size > w_int8.shape[-1] and scales.shape[-1] == 1:
+        group_size = w_int8.shape[-1]
+    assert w_int8.shape[-1] % group_size == 0
+    assert w_int8.dim() == 2
+
+    w_int8_grouped = w_int8.reshape(-1, group_size)
+    scales = scales.reshape(-1, 1)
+    if zero_points is not None:
+        zp = zero_points.reshape(-1, 1)
+    else:
+        zp = torch.zeros([], dtype=torch.int32, device=scales.device)
+    w_dq = w_int8_grouped.sub(zp).mul(scales).reshape_as(w_int8).to(output_dtype)
+    return w_dq
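+
+# Illustrative groupwise sketch (hypothetical shapes/values): each row of an (M, N) weight is
+# split into N // group_size groups, and each group shares one scale/zero_point pair.
+#
+#     w = torch.randn(4, 256)
+#     scales = torch.full((4, 256 // 64), 0.02)       # one scale per group
+#     zps = torch.zeros_like(scales, dtype=torch.int32)
+#     w_q = quantize_per_channel_group(w, scales, zps, -8, 7, torch.int8, group_size=64)
+#     w_dq = dequantize_per_channel_group(w_q, scales, zps, -8, 7, torch.int8, 64, torch.float32)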
+
+
+quantized_decomposed_lib.define(
+    "fake_quant_per_channel(Tensor input, Tensor scales, Tensor zero_points, int axis, "
+    "int quant_min, int quant_max) -> Tensor"
+)
+
+
+class FakeQuantPerChannel(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input, scales, zero_points, axis, quant_min, quant_max):
+        if scales.dtype != torch.float32:
+            scales = scales.to(torch.float32)
+        if zero_points.dtype != torch.int32:
+            zero_points = zero_points.to(torch.int32)
+        assert input.dtype == torch.float32, (
+            f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
+        )
+        assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
+        broadcast_dims = list(range(0, axis)) + list(range(axis + 1, input.ndim))
+        unsqueeze_scales = _unsqueeze_multiple(scales, broadcast_dims)
+        unsqueeze_zero_points = _unsqueeze_multiple(zero_points, broadcast_dims)
+        temp = torch.round(input * (1.0 / unsqueeze_scales)) + unsqueeze_zero_points
+        out = (
+            torch.clamp(temp, quant_min, quant_max) - unsqueeze_zero_points
+        ) * unsqueeze_scales
+        mask = torch.logical_and((temp >= quant_min), (temp <= quant_max))
+
+        ctx.save_for_backward(mask)
+        return out
+
+    @staticmethod
+    def backward(ctx, gy):
+        (mask,) = ctx.saved_tensors
+        return gy * mask, None, None, None, None, None
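+
+# Note: the backward above is a straight-through estimator -- gradients pass through unchanged
+# for elements whose rounded value stayed inside [quant_min, quant_max] and are zeroed where the
+# forward pass clamped. Illustrative sketch (hypothetical values):
+#
+#     x = torch.randn(8, 4, requires_grad=True)
+#     scales = torch.full((4,), 0.1)
+#     zps = torch.zeros(4, dtype=torch.int32)
+#     y = FakeQuantPerChannel.apply(x, scales, zps, 1, -128, 127)
+#     y.sum().backward()                              # x.grad is 1 where not clamped, 0 elsewhere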
+
+
+@impl(quantized_decomposed_lib, "fake_quant_per_channel", "Autograd")
+def fake_quant_per_channel(
+    input: torch.Tensor,
+    scales: torch.Tensor,
+    zero_points: torch.Tensor,
+    axis: int,
+    quant_min: int,
+    quant_max: int,
+) -> torch.Tensor:
+    return FakeQuantPerChannel.apply(
+        input, scales, zero_points, axis, quant_min, quant_max
+    )
+
+
+@impl(quantized_decomposed_lib, "fake_quant_per_channel", "Meta")
+def fake_quant_per_channel_meta(
+    input: torch.Tensor,
+    scales: torch.Tensor,
+    zero_points: torch.Tensor,
+    axis: int,
+    quant_min: int,
+    quant_max: int,
+) -> torch.Tensor:
+    return torch.empty_like(input)
+
+
+quantized_decomposed_lib.define(
+    "convert_element_type.no_fuse(Tensor input, ScalarType dtype) -> Tensor"
+)
+
+
+@impl(
+    quantized_decomposed_lib,
+    "convert_element_type.no_fuse",
+    "CompositeExplicitAutograd",
+)
+def convert_element_type(input: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+    return torch.ops.prims.convert_element_type.default(input, dtype)
+
+
+@impl(quantized_decomposed_lib, "convert_element_type.no_fuse", "Meta")
+def convert_element_type_meta(input: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+    return torch.empty_like(input, dtype=dtype)
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_equalize.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_equalize.py
new file mode 100644
index 0000000000000000000000000000000000000000..822d261ffc3282b3d4299d5e8f67f007712b3df6
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_equalize.py
@@ -0,0 +1,955 @@
+# mypy: allow-untyped-defs
+import operator
+import warnings
+from collections import namedtuple
+from typing import Any, Optional
+
+import torch
+import torch.ao.nn.intrinsic as nni
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.ao.quantization.fx.graph_module import _get_observed_graph_module_attr
+from torch.ao.quantization.observer import (
+    _with_args,
+    ObserverBase,
+    PerChannelMinMaxObserver,
+)
+from torch.ao.quantization.utils import _parent_name, check_min_max_valid
+from torch.fx import GraphModule
+from torch.fx.graph import Node
+
+from .utils import (
+    get_new_attr_name_with_prefix,
+    maybe_get_next_module,
+    node_arg_is_weight,
+)
+
+
+CUSTOM_MODULE_SUPP_LIST: list[Any] = []
+
+
+def reshape_scale(scale: torch.Tensor, axis: int, input: torch.Tensor) -> torch.Tensor:
+    """Reshapes the scale so that we can multiply it to the input by the given axis."""
+    new_shape = [1] * input.ndim
+    new_shape[axis] = input.size(axis)
+    return scale.view(new_shape)
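+
+# Illustrative sketch (hypothetical shapes): reshape_scale turns a 1-d scale of length
+# input.size(axis) into a broadcastable shape, e.g. for a 4-d conv activation:
+#
+#     x = torch.randn(2, 16, 8, 8)
+#     s = torch.rand(16)
+#     reshape_scale(s, 1, x).shape                    # torch.Size([1, 16, 1, 1])
+#     x * reshape_scale(s, 1, x)                      # scales each channel independently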
+
+
+qsheme_mapping_per_tensor_to_per_channel = {
+    torch.per_tensor_affine: torch.per_channel_affine,
+    torch.per_tensor_symmetric: torch.per_channel_symmetric,
+}
+
+
+class _InputEqualizationObserver(nn.Module):
+    r"""Observer for tracking the running min/max values of input columns, and
+    computing the quantization parameters for the overall min/max input values.
+
+    Args:
+        dtype: Quantized data type
+        qscheme: Quantization scheme
+        quant_min: Minimum quantization value. If unspecified, it will
+            follow the 8-bit setup.
+        quant_max: Maximum quantization value. If unspecified, it will
+            follow the 8-bit setup.
+
+    The running minimum/maximum :math:`x_\text{min/max}` are computed in the
+    same way as :class:`~torch.ao.quantization.observer.PerChannelMinMaxObserver`,
+    with the difference that the running min/max values are stored per column.
+    This observer is intended to be used along with a WeightEqualizationObserver
+    to calculate the equalization scale.
+    """
+
+    def __init__(
+        self,
+        dtype=torch.quint8,
+        qscheme=torch.per_tensor_affine,
+        quant_min=None,
+        quant_max=None,
+        factory_kwargs=None,
+    ) -> None:
+        super().__init__()
+
+        if qscheme not in {torch.per_tensor_affine, torch.per_tensor_symmetric}:
+            raise TypeError("Input qscheme must be per-tensor")
+
+        self.dtype = dtype
+        self.qscheme = qscheme
+
+        per_channel_qscheme = qsheme_mapping_per_tensor_to_per_channel[qscheme]
+        self.input_obs = PerChannelMinMaxObserver(
+            ch_axis=1,
+            dtype=dtype,
+            qscheme=per_channel_qscheme,
+            quant_min=quant_min,
+            quant_max=quant_max,
+            factory_kwargs=factory_kwargs,
+        )
+
+        self.equalization_scale = torch.tensor(1)
+        self.equalization_shape: list[int] = []
+
+    def forward(self, x_orig):
+        if not (x_orig.ndim >= 2 and x_orig.ndim <= 5):
+            raise ValueError(
+                "InputEqualizationObserver only supports Linear and Conv layers"
+            )
+
+        # Calculate the shape needed to reshape the equalization scale later (needed for Conv layers)
+        self.equalization_shape = [1] * x_orig.ndim
+        self.equalization_shape[1] = x_orig.size(1)
+
+        return self.input_obs(x_orig)
+
+    def get_input_minmax(self):
+        return (self.input_obs.min_val, self.input_obs.max_val)
+
+    def set_equalization_scale(self, equalization_scale):
+        # Reshape the equalization scale along axis=1 so that it can be
+        # multiplied with the input along axis=1
+        if equalization_scale.nelement() == 1 and equalization_scale == torch.tensor(1):
+            return
+        self.equalization_scale = torch.reshape(
+            equalization_scale, self.equalization_shape
+        )
+
+    def calculate_scaled_minmax(self):
+        r"""Returns the scaled min/max inputs"""
+        if (
+            self.equalization_scale.nelement() == 1
+            and self.equalization_scale == torch.tensor(1)
+        ):
+            warnings.warn(
+                "Must call calculate_equalization_scale before calling calculate_scaled_minmax. "
+                + "Will not scale the next quantization observer."
+            )
+            return None, None
+
+        # Calculate qparams for the scaled min/max inputs
+        # Scale the input by the equalization scale located at the same column
+        # index
+        (min_inputs, max_inputs) = self.get_input_minmax()
+        equalization_scale_reshaped = reshape_scale(
+            self.equalization_scale, 0, min_inputs
+        )
+        min_input_scaled = torch.min(torch.mul(min_inputs, equalization_scale_reshaped))
+        max_input_scaled = torch.max(torch.mul(max_inputs, equalization_scale_reshaped))
+
+        return min_input_scaled, max_input_scaled
+
+    with_args = classmethod(_with_args)
+
+
+class _WeightEqualizationObserver(nn.Module):
+    r"""Observer for tracking the running min/max values of weight columns and
+    rows, and computing the quantization parameters for the weight rows.
+
+    Args:
+        dtype: Quantized data type
+        qscheme: Quantization scheme
+        quant_min: Minimum quantization value. If unspecified, it will
+            follow the 8-bit setup.
+        quant_max: Maximum quantization value. If unspecified, it will
+            follow the 8-bit setup.
+
+    This observer is made up of 1 PerChannelMinMaxObserver `weight_col_obs` used
+    to record the running minimum and maximum of columns of incoming weight
+    tensors. This observer is intended to be used along with an
+    InputEqualizationObserver to calculate the equalization scale.
+
+    The running minimum/maximum :math:`w_\text{min/max}` are computed in the
+    same way as :class:`~torch.ao.quantization.observer.PerChannelMinMaxObserver`.
+    """
+
+    def __init__(
+        self,
+        dtype=torch.qint8,
+        qscheme=torch.per_tensor_affine,
+        quant_min=None,
+        quant_max=None,
+        factory_kwargs=None,
+    ) -> None:
+        super().__init__()
+
+        self.dtype = dtype
+        self.qscheme = qscheme
+        self.ch_axis = 1
+
+        per_channel_qscheme = qscheme
+        if qscheme in {torch.per_tensor_affine, torch.per_tensor_symmetric}:
+            per_channel_qscheme = qsheme_mapping_per_tensor_to_per_channel[qscheme]
+        self.weight_col_obs = PerChannelMinMaxObserver(
+            ch_axis=1,
+            dtype=dtype,
+            qscheme=per_channel_qscheme,
+            quant_min=quant_min,
+            quant_max=quant_max,
+            factory_kwargs=factory_kwargs,
+        )
+
+        self.equalization_scale = torch.tensor(1)
+
+    def forward(self, w_orig):
+        if not (w_orig.ndim >= 2 and w_orig.ndim <= 5):
+            raise ValueError(
+                "InputEqualizationObserver only supports Linear and Conv layers"
+            )
+
+        return self.weight_col_obs(w_orig)
+
+    def get_weight_col_minmax(self):
+        return (self.weight_col_obs.min_val, self.weight_col_obs.max_val)
+
+    def set_equalization_scale(self, equalization_scale):
+        self.equalization_scale = equalization_scale
+
+    with_args = classmethod(_with_args)
+
+
+def calculate_equalization_scale(
+    input_obs: _InputEqualizationObserver, weight_obs: _WeightEqualizationObserver
+) -> torch.Tensor:
+    r"""Calculates the equalization scale and sets the equalization_scale value
+    in the observers.
+
+    Args:
+        input_obs: Observer that tracks the ranges for the input columns
+        weight_obs: Observer that tracks the ranges for the weight columns
+    """
+
+    (min_inputs, max_inputs) = input_obs.get_input_minmax()
+    (min_weights, max_weights) = weight_obs.get_weight_col_minmax()
+
+    if not (
+        check_min_max_valid(min_inputs, max_inputs)
+        and check_min_max_valid(min_weights, max_weights)
+    ):
+        warnings.warn(
+            "Must run observer before calling calculate_equalization_scale. "
+            + "Returning default equalization scale torch.tensor(1)."
+        )
+        return torch.tensor(1)
+
+    if not (min_inputs.shape == min_weights.shape):
+        raise ValueError(
+            "Input and Weight must have the same column dimension. "
+            + f"Found {min_inputs.shape} and {min_weights.shape} shapes instead."
+        )
+
+    equalization_scale = torch.sqrt(
+        (max_weights - min_weights) / (max_inputs - min_inputs)
+    )
+    # Replace all 'inf', 'nan', 0's with 1s to prevent errors
+    equalization_scale[equalization_scale == 0.0] = 1
+    equalization_scale = torch.nan_to_num(equalization_scale, nan=1, posinf=1, neginf=1)
+    return equalization_scale
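+
+# Note: with s_j = sqrt(weight_range_j / input_range_j) per column j, scaling the input columns
+# by s and the weight columns by 1/s leaves the linear output unchanged
+# ((x * s) @ (W / s).T == x @ W.T) while balancing the two dynamic ranges. Sketch
+# (hypothetical; assumes both observers have already seen data):
+#
+#     s = calculate_equalization_scale(input_obs, weight_obs)
+#     input_obs.set_equalization_scale(s)
+#     weight_obs.set_equalization_scale(s)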
+
+
+class EqualizationQConfig(
+    namedtuple("EqualizationQConfig", ["input_activation", "weight"])
+):
+    """
+    Describes how to quantize a layer or a part of the network specifically for
+    input-weight equalization by providing settings (observer classes) for
+    inputs, outputs, and weights.
+
+    Note that EqualizationQConfig needs to contain observer **classes** (like
+    MinMaxObserver) or a callable that returns instances on invocation, not the
+    concrete observer instances themselves.
+    The quantization function will instantiate observers multiple times, once for
+    each of the layers.
+
+    Observer classes usually have reasonable default arguments, but they can be
+    overridden with the `with_args` method (which behaves like functools.partial):
+
+    my_qconfig = EqualizationQConfig(input_activation=_InputEqualizationObserver.with_args(dtype=torch.qint8),
+                                    weight=_WeightEqualizationObserver.with_args(dtype=torch.qint8))
+    """
+
+    __slots__ = ()
+
+    def __new__(cls, input_activation=torch.nn.Identity, weight=torch.nn.Identity):
+        if isinstance(input_activation, nn.Module) or isinstance(weight, nn.Module):
+            raise ValueError(
+                "EqualizationQConfig received observer instance, please pass observer class instead. "
+                + "Use MyObserver.with_args(x=1) to override arguments to constructor if needed"
+            )
+        self = super().__new__(cls, input_activation, weight)
+        return self
+
+
+input_equalization_observer = _InputEqualizationObserver.with_args(
+    dtype=torch.quint8, qscheme=torch.per_tensor_symmetric
+)
+weight_equalization_observer = _WeightEqualizationObserver.with_args(
+    dtype=torch.qint8, qscheme=torch.per_channel_symmetric
+)
+default_equalization_qconfig = EqualizationQConfig(
+    input_activation=input_equalization_observer, weight=weight_equalization_observer
+)
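+
+# Illustrative usage sketch (hypothetical names; the prepare_fx keyword is hedged and may
+# differ by version): the default qconfig above is typically supplied as an equalization
+# config alongside the regular qconfig mapping, e.g.
+#
+#     from torch.ao.quantization.quantize_fx import prepare_fx
+#     equalization_config = {"": default_equalization_qconfig}   # equalize all supported layers
+#     prepared = prepare_fx(model, qconfig_mapping, example_inputs,
+#                           _equalization_config=equalization_config)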
+
+
+def fused_module_supports_equalization(module) -> bool:
+    """Checks if the fused node supports equalization."""
+    return type(module) in [
+        nni.LinearReLU,
+        nni.ConvReLU1d,
+        nni.ConvReLU2d,
+        nni.ConvReLU3d,
+    ]
+
+
+def nn_module_supports_equalization(module) -> bool:
+    """Checks if the torch.nn node supports equalization."""
+    return type(module) in [nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d]
+
+
+def custom_module_supports_equalization(module) -> bool:
+    """Checks if the custom node supports equalization."""
+    return type(module) in CUSTOM_MODULE_SUPP_LIST
+
+
+def node_supports_equalization(node: Node, modules) -> bool:
+    """Checks if the current node supports equalization
+    Currently we only support nn.Linear/F.Linear and nn.Conv/F.conv layers
+    """
+    if node.op == "call_module":
+        return (
+            nn_module_supports_equalization(modules[str(node.target)])
+            or fused_module_supports_equalization(modules[str(node.target)])
+            or custom_module_supports_equalization(modules[str(node.target)])
+        )
+    elif node.op == "call_function":
+        return node.target in [F.linear, F.conv1d, F.conv2d, F.conv3d]
+    return False
+
+
+def is_equalization_observer(observer: nn.Module) -> bool:
+    return isinstance(
+        observer, (_InputEqualizationObserver, _WeightEqualizationObserver)
+    )
+
+
+###############################################################################
+# Functions for equalization during convert                                   #
+###############################################################################
+
+
+def get_op_node_and_weight_eq_obs(
+    input_eq_obs_node: Node, model: GraphModule, modules: dict[str, nn.Module]
+) -> tuple[Optional[Node], Optional[_WeightEqualizationObserver]]:
+    """Gets the following weight equalization observer. There should always
+    exist a weight equalization observer after an input equalization observer.
+
+    Returns the operation node that follows the input equalization observer node
+    and the weight equalization observer
+    """
+
+    # Find the op node that comes directly after the input equalization observer
+    op_node = None
+    for user in input_eq_obs_node.users.keys():
+        if node_supports_equalization(user, modules):
+            op_node = user
+            break
+
+    assert op_node is not None
+    if op_node.op == "call_module":
+        # If the op_node is a nn.Linear layer, then it must have a
+        # WeightEqualizationObserver configuration
+        maybe_equalization_node_name_to_config = _get_observed_graph_module_attr(
+            model, "equalization_node_name_to_qconfig"
+        )
+        assert maybe_equalization_node_name_to_config is not None
+        equalization_node_name_to_qconfig: dict[str, Any] = (
+            maybe_equalization_node_name_to_config  # type: ignore[assignment]
+        )
+        assert equalization_node_name_to_qconfig.get(op_node.name, None) is not None
+        weight_eq_obs = equalization_node_name_to_qconfig.get(  # type: ignore[union-attr]
+            op_node.name, None
+        ).weight()
+
+        assert isinstance(weight_eq_obs, _WeightEqualizationObserver)
+        return op_node, weight_eq_obs
+
+    elif op_node.op == "call_function":
+        weight_node = maybe_get_weight_eq_obs_node(op_node, modules)
+        if weight_node is not None:
+            weight_eq_obs = modules[str(weight_node.target)]
+            assert isinstance(weight_eq_obs, _WeightEqualizationObserver)
+            return op_node, weight_eq_obs
+
+    return None, None
+
+
+def maybe_get_weight_eq_obs_node(
+    op_node: Node, modules: dict[str, nn.Module]
+) -> Optional[Node]:
+    """Gets the weight equalization observer node if it exists."""
+    assert op_node.op == "call_function"
+    for node_arg in op_node.args:
+        if node_arg_is_weight(op_node, node_arg):
+            assert (
+                isinstance(node_arg, Node)
+                and node_arg.op == "call_module"
+                and isinstance(
+                    modules[str(node_arg.target)], _WeightEqualizationObserver
+                )
+            )
+            return node_arg
+    return None
+
+
+def maybe_get_next_input_eq_obs(
+    node: Node, modules: dict[str, nn.Module]
+) -> Optional[_InputEqualizationObserver]:
+    """Gets the following input equalization observer if it exists.
+
+    For example, in the case of connecting linear layers:
+        x -> inp_obs1 -> eq_obs1 -> linear1 -> out_obs1 -> eq_obs2 -> linear2 -> out_obs2
+    If the node being passed in is the linear1 node, then we want to return eq_obs2,
+    the following equalization observer for linear2.
+
+    However, if there are no connecting layers:
+        x -> inp_obs1 -> eq_obs1 -> linear1 -> out_obs1 -> add
+    Then we want to return None.
+
+    In the case of an unfused linear-relu layer with a connecting linear layer:
+        linear1 -> relu -> out_obs1 -> eq_obs2 -> linear2 -> out_obs2
+    Since it is unfused, we want to skip over the relu layer and return eq_obs2,
+    the following equalization observer for linear2.
+    """
+
+    assert node_supports_equalization(node, modules)
+
+    # Locate the following nn.ReLU or F.relu node if it exists
+    maybe_relu_node = maybe_get_next_module(node, modules, nn.ReLU)
+    if maybe_relu_node is None:
+        maybe_relu_node = maybe_get_next_module(
+            node, modules, target_functional_type=F.relu
+        )
+
+    # Locate the following output observer if it exists.
+    # We will skip the relu node if it exists.
+    maybe_obs_node = (
+        maybe_get_next_module(node, modules, ObserverBase)
+        if maybe_relu_node is None
+        else maybe_get_next_module(maybe_relu_node, modules, ObserverBase)
+    )
+    if maybe_obs_node is None:
+        return None
+
+    maybe_eq_obs_node = maybe_get_next_module(
+        maybe_obs_node, modules, _InputEqualizationObserver
+    )
+    if maybe_eq_obs_node is None:
+        return None
+
+    maybe_eq_obs = modules[str(maybe_eq_obs_node)]
+    assert isinstance(maybe_eq_obs, _InputEqualizationObserver)
+    return maybe_eq_obs
+
+
+def maybe_get_next_equalization_scale(
+    node: Node, modules: dict[str, nn.Module]
+) -> Optional[torch.Tensor]:
+    """If the next next node is an InputEqualizationObserver then we want to
+    return its equalization scale, else we return 1
+
+    This is used in the case where there are two connecting linear layers:
+        linear1 -> LinearOutObs -> InputEqObs -> linear2
+    In this case, the node given is linear1 and we want to locate the InputEqObs.
+    """
+    next_inp_eq_obs = maybe_get_next_input_eq_obs(node, modules)
+    if next_inp_eq_obs:
+        if (
+            next_inp_eq_obs.equalization_scale.nelement() == 1
+            and next_inp_eq_obs.equalization_scale == torch.tensor(1)
+        ):
+            return None
+        return next_inp_eq_obs.equalization_scale
+    return None
+
+
+def scale_input_observer(node: Node, modules: dict[str, nn.Module]) -> None:
+    """Scales the following input quantization observer's min/max values by
+    updating the values with the scaled min/max values calculated by the input
+    equalization observer
+    """
+    input_eq_obs = modules[str(node.target)]
+    assert isinstance(input_eq_obs, _InputEqualizationObserver)
+
+    input_quant_obs_node = node.args[0]
+    assert isinstance(input_quant_obs_node, Node)
+
+    input_quant_obs = modules[str(input_quant_obs_node.target)]
+    if not isinstance(input_quant_obs, ObserverBase):
+        return
+
+    min_input_scaled, max_input_scaled = input_eq_obs.calculate_scaled_minmax()
+    if min_input_scaled is None and max_input_scaled is None:
+        return
+    input_quant_obs.min_val = min_input_scaled
+    input_quant_obs.max_val = max_input_scaled
+
+
+def scale_weight_node(
+    node: Node,
+    modules: dict[str, nn.Module],
+    equalization_scale: torch.Tensor,
+    next_equalization_scale: Optional[torch.Tensor],
+) -> None:
+    """Scale the weights for input-weight equalization by multiplying the
+    weight by 1/equalization_scale and next_equalization_scale
+
+    Args:
+        node: Current node whose weights we want to scale
+        equalization_scale: Current node's calculated equalization scale
+        next_equalization_scale: Next node's calculated equalization scale if
+           the following node needs to be equalized, None otherwise
+    """
+    if equalization_scale is None:
+        return
+
+    if fused_module_supports_equalization(modules[str(node.target)]):
+        op_module = modules[str(node.target)][0]  # type: ignore[index]
+    else:
+        op_module = modules[str(node.target)]
+    assert nn_module_supports_equalization(
+        op_module
+    ) or custom_module_supports_equalization(op_module)
+
+    # Scale the weights for input-weight equalization
+    # If the following layer needs to be equalized then we will multiply its scale
+    weight = op_module.weight
+    assert isinstance(weight, torch.Tensor)
+
+    # Scale the weights by the reciprocal of the equalization scale
+    # Reshape the equalization scale so that we can multiply it to the weight along axis=1
+    equalization_scale_reshaped = reshape_scale(equalization_scale, 1, weight)
+    scaled_weight = torch.mul(weight, torch.reciprocal(equalization_scale_reshaped))
+
+    if next_equalization_scale is None:
+        op_module.weight = nn.Parameter(scaled_weight)
+        return
+
+    # Multiply the weights row wise by the next equalization scale
+    # Reshape the equalization scale so that we can multiply it to the weight along axis=0
+    next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, weight)
+    scaled_weight = torch.mul(scaled_weight, next_equalization_scale_reshaped)
+
+    op_module.weight = nn.Parameter(scaled_weight)
+
+    # Multiply the bias element wise by the next equalization scale
+    bias = op_module.bias
+    if bias is None:
+        return
+    assert isinstance(bias, torch.Tensor)
+
+    # Reshape the equalization scale so that we can multiply it element-wise to the bias
+    next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, bias)
+    scaled_bias = torch.mul(bias, next_equalization_scale_reshaped)
+    op_module.bias = nn.Parameter(scaled_bias)
+
+
+def scale_weight_functional(
+    op_node: Node,
+    model: GraphModule,
+    modules: dict[str, nn.Module],
+    equalization_scale: torch.Tensor,
+    next_equalization_scale: Optional[torch.Tensor],
+) -> None:
+    """Scales the weight value for functional layers"""
+    if equalization_scale is None:
+        return
+
+    # From the given op_node, the path looks like:
+    #   get_attr(weight) -> weight_quant_obs -> weight_eq_obs -> op_node
+    # So we want to trace back from the op_node to get the equalization observer
+    # node, then the quantization observer node, and then finally the weight
+    # node which contains the weight values.
+
+    # Get the equalization observer node
+    weight_eq_obs_node = maybe_get_weight_eq_obs_node(op_node, modules)
+    if weight_eq_obs_node is None:
+        return
+
+    # Get the quantization observer node
+    weight_quant_obs_node = weight_eq_obs_node.args[0]
+    if weight_quant_obs_node is None:
+        return
+    assert isinstance(weight_quant_obs_node, Node) and isinstance(
+        modules[str(weight_quant_obs_node.target)], ObserverBase
+    )
+
+    # Get the get_attr(weight) node
+    weight_node = weight_quant_obs_node.args[0]
+    if weight_node is None:
+        return
+    assert isinstance(weight_node, Node) and weight_node.op == "get_attr"
+
+    weight_parent_name, weight_name = _parent_name(weight_node.target)
+    weight = getattr(modules[weight_parent_name], weight_name)
+
+    # Scale the weights for input-weight equalization
+    # If the following layer needs to be equalized then we will multiply its scale
+    # Reshape the equalization scale so that we can multiply it to the weight along axis=1
+    equalization_scale_reshaped = reshape_scale(equalization_scale, 1, weight)
+    scaled_weight = torch.mul(weight, torch.reciprocal(equalization_scale_reshaped))
+
+    if next_equalization_scale is None:
+        setattr(modules[weight_parent_name], weight_name, scaled_weight)
+        return
+
+    # Multiply the weights row wise by the next equalization scale
+    # Reshape the equalization scale so that we can multiply it to the weight along axis=0
+    next_equalization_scale_reshaped = reshape_scale(
+        next_equalization_scale, 0, scaled_weight
+    )
+    scaled_weight = torch.mul(scaled_weight, next_equalization_scale_reshaped)
+
+    setattr(modules[weight_parent_name], weight_name, scaled_weight)
+    assert torch.allclose(model.get_buffer(str(weight_node.target)), scaled_weight)
+
+    # Multiply the bias element wise by the next equalization scale
+    bias_node = None
+    for node in op_node.args:
+        # Find the node containing the bias values
+        if isinstance(node, Node) and node.op == "get_attr" and "bias" in node.name:
+            bias_node = node
+            break
+    if bias_node is None:
+        return
+
+    bias_parent_name, bias_name = _parent_name(bias_node.target)
+    bias = getattr(modules[bias_parent_name], bias_name)
+
+    # Reshape the equalization scale so that we can multiply it element-wise to the bias
+    next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, bias)
+    scaled_bias = torch.mul(bias, next_equalization_scale_reshaped)
+    setattr(modules[bias_parent_name], bias_name, scaled_bias)
+
+
+def clear_weight_quant_obs_node(op_node: Node, modules: dict[str, nn.Module]) -> None:
+    """Given the operation node, we want find the corresponding quantization
+    observer and reset its min/max values
+    """
+    weight_eq_obs_node = maybe_get_weight_eq_obs_node(op_node, modules)
+    if weight_eq_obs_node is None:
+        return
+
+    weight_quant_obs_node = weight_eq_obs_node.args[0]
+    if weight_quant_obs_node is None:
+        return
+    assert isinstance(weight_quant_obs_node, Node)
+
+    weight_quant_obs = modules[str(weight_quant_obs_node.target)]
+    assert isinstance(modules[str(weight_quant_obs_node.target)], ObserverBase)
+    weight_quant_obs.reset_min_max_vals()  # type: ignore[operator]
+
+
+def remove_node(model: GraphModule, node: Node, prev_node: Node):
+    """Removes the given node from the model by replacing all of its users with
+    the given previous node
+    """
+    # For all of the current node's users, replace the current node with
+    # the given previous node
+    orig_users = list(node.users.keys())
+    for user_node in orig_users:
+        user_node.replace_input_with(node, prev_node)
+
+    # Erase the node (e.g. an equalization observer node)
+    model.graph.erase_node(node)
+
+
+def update_obs_for_equalization(
+    model: GraphModule, modules: dict[str, nn.Module]
+) -> dict[str, _WeightEqualizationObserver]:
+    """Update all of the observer's equalization scale. For each
+    InputEqualizationObserver, we will find the location of the next
+    WeightEqualizationObserver, create it, and calculate the equalization scale
+    based on the two observers.
+
+    We will then return a dictionary mapping operation node names to
+    the corresponding WeightEqualizationObservers for that operation.
+    """
+    weight_eq_obs_dict = {}
+    for node in model.graph.nodes:
+        if node.op == "call_module" and isinstance(
+            modules[node.target], _InputEqualizationObserver
+        ):
+            input_eq_obs = modules[node.target]
+            assert isinstance(input_eq_obs, _InputEqualizationObserver)
+            op_node, weight_eq_obs = get_op_node_and_weight_eq_obs(node, model, modules)
+
+            if op_node is None or weight_eq_obs is None:
+                continue
+
+            if op_node.op == "call_module":
+                # Calibrate the weight equalization observer since it has just
+                # been created
+                if fused_module_supports_equalization(modules[str(op_node.target)]):
+                    module = modules[str(op_node.target)][0]  # type: ignore[index]
+                    assert nn_module_supports_equalization(module)
+                    weight_eq_obs(module.weight)
+                else:
+                    weight_eq_obs(modules[str(op_node.target)].weight)
+
+            # Calculate and set the equalization scale values
+            equalization_scale = calculate_equalization_scale(
+                input_eq_obs, weight_eq_obs
+            )
+            input_eq_obs.set_equalization_scale(equalization_scale)
+            weight_eq_obs.set_equalization_scale(equalization_scale)
+
+            weight_eq_obs_dict[op_node.name] = weight_eq_obs
+
+    return weight_eq_obs_dict
+
+
+def convert_eq_obs(
+    model: GraphModule,
+    modules: dict[str, nn.Module],
+    weight_eq_obs_dict: dict[str, _WeightEqualizationObserver],
+) -> None:
+    """Converts the equalization operations and updates the other nodes in the
+    following way:
+        - Removes the input equalization observers and inserts a mul operator
+          along with an equalization scale node wherever applicable (we do not
+          want to insert a mul operator between connecting linear layers).
+        - Updates the input quantization observers with the scaled input min/max
+          values.
+        - Scales the weights by the current and next equalization scales.
+        - Removes the weight equalization observer node if it exists.
+
+    Before (after prepare):
+                                    weight values
+                                          |
+                                    WeightQuantObs
+                                          |
+                                      WeightEqObs
+                                          |
+        x -> InpQuantObs -> InpEqObs -> linear -> OutQuantObs
+
+    After this function:
+                                              scaled weight values
+                                                      |
+       equalization scale                       WeightQuantObs
+              |                                       |
+        x -> mul -> InpQuantObs (scaled min/max) -> linear -> OutQuantObs
+
+    After convert:
+       equalization scale                 scaled weight values
+              |                                    |
+        x -> mul -> quantize_per_tensor -> quantized::linear
+
+    Note that although the equalization observer appears after the quantization
+    observer after prepare_fx, the mul node appears before the quantization node
+    after convert_fx. This is because placing the equalization observer after
+    the quantization observer in prepare_fx lets us keep the invariant that the
+    graph before the current node is not modified when the current node inserts
+    its observers.
+
+    Having the equalization observer before the quantization observer would also
+    cause some inconsistencies in the ordering of the quantization and
+    equalization observers.
+    For example, a single linear layer would look like:
+        x -> InpEqObs1 -> InpQuantObs1 -> linear1 -> OutQuantObs1
+    But between two connected linear layers, it would look like:
+        linear1 -> OutQuantObs1 -> InpEqObs2 -> linear2 -> OutQuantObs2
+    """
+    for node in model.graph.nodes:
+        if node.op == "call_module" and isinstance(
+            modules[node.target], _InputEqualizationObserver
+        ):
+            inp_quant_obs_node = node.args[0]
+            prev_node = inp_quant_obs_node.args[0]
+
+            # If the previous node is a layer that needs to be equalized, then
+            # we will remove the current node because we do not need to add any
+            # equalization nodes between two layers that need to be equalized
+
+            # Before: linear1/relu (prev_node) -> output_quant_obs1 (inp_quant_obs_node) -> input_eq_obs2 (node) -> linear2
+            # After: linear1/relu (prev_node) -> output_quant_obs1 (inp_quant_obs_node) -> linear2
+            if (
+                node_supports_equalization(prev_node, modules)
+                or "relu" in prev_node.name
+            ):
+                remove_node(model, node, inp_quant_obs_node)
+                continue
+
+            # Update the following input quantization observer's min/max values
+            scale_input_observer(node, modules)
+
+            # Remove the InputEqualization node and add a mul operator before
+            # the quantization observer node that appears before the equalization node
+            # Before: x -> input_quant_obs -> input_eq_obs -> linear
+            # After: x -> mul -> input_quant_obs -> linear
+
+            # Create a node containing the equalization scale
+            with model.graph.inserting_before(inp_quant_obs_node):
+                get_new_eq_scale_name = get_new_attr_name_with_prefix(
+                    prev_node.name + "_equalization_scale"
+                )
+                name = get_new_eq_scale_name(modules)
+                setattr(model, name, modules[node.target].equalization_scale)
+                eq_scale_node = model.graph.create_node("get_attr", name)
+
+            # Create a node multiplying the input with the equalization scale
+            with model.graph.inserting_after(eq_scale_node):
+                inputs = (prev_node, eq_scale_node)
+                mul_node = model.graph.create_node("call_function", torch.mul, inputs)
+
+            # Set the mul node to be the input_quant_obs_node's input instead of
+            # the previous node
+            inp_quant_obs_node.replace_input_with(prev_node, mul_node)
+            remove_node(model, node, inp_quant_obs_node)
+
+        elif weight_eq_obs_dict.get(node.name, None) is not None:
+            weight_eq_obs = weight_eq_obs_dict.get(node.name)
+            assert isinstance(weight_eq_obs, _WeightEqualizationObserver)
+            equalization_scale = weight_eq_obs.equalization_scale
+
+            if (
+                equalization_scale.nelement() == 1
+                and equalization_scale == torch.tensor(1)
+            ):
+                equalization_scale = None  # type: ignore[assignment]
+            maybe_next_equalization_scale = maybe_get_next_equalization_scale(
+                node, modules
+            )
+
+            # Scale the weight nodes
+            if node.op == "call_module":
+                scale_weight_node(
+                    node, modules, equalization_scale, maybe_next_equalization_scale
+                )
+            elif node.op == "call_function":
+                scale_weight_functional(
+                    node,
+                    model,
+                    modules,
+                    equalization_scale,
+                    maybe_next_equalization_scale,
+                )
+
+                weight_eq_obs_node = maybe_get_weight_eq_obs_node(node, modules)
+                if weight_eq_obs_node is None:
+                    return
+                assert isinstance(
+                    modules[str(weight_eq_obs_node.target)], _WeightEqualizationObserver
+                )
+
+                # Clear the quantization observer's min/max values so that they
+                # can get updated later based on the new scale values
+                clear_weight_quant_obs_node(node, modules)
+
+                # Erase the weight equalization observer node
+                prev_node = weight_eq_obs_node.args[0]
+                remove_node(model, weight_eq_obs_node, prev_node)  # type: ignore[arg-type]
+            else:
+                raise ValueError(
+                    "Expected operation node to be 'call_module' or 'call_function"
+                    + f"Instead got node {node.name} as '{node.op}'."
+                )
+
+
+def _convert_equalization_ref(model: GraphModule):
+    """Reference function which applies changes needed for equalization, but
+    does not quantize the nodes
+    """
+    modules = dict(model.named_modules(remove_duplicate=False))
+
+    # Calculate the equalization scale, update the observers with the scaled
+    # inputs, and scale the weight
+    weight_eq_obs_dict = update_obs_for_equalization(model, modules)
+    convert_eq_obs(model, modules, weight_eq_obs_dict)
+
+    return GraphModule(model, model.graph)
+
+
+###############################################################################
+# Functions for running the equalized model on the Numeric Suite              #
+###############################################################################
+
+
+def get_layer_sqnr_dict(
+    model_a: nn.Module, model_b: nn.Module, x: torch.Tensor
+) -> dict[str, float]:
+    """Runs the Numeric Suite on model_a and model_b and returns a dictionary
+    containing the SQNR between layers in model_a and model_b.
+
+    Note: In order to support equalized models, this function has a hacky fix in
+    which we do not match any torch.mul operators. This is because equalized
+    models contain extra mul operators to scale the input by the equalization
+    scale, but this edge case has not been resolved yet within the numeric suite code.
+
+    Args:
+        model_a: A float model
+        model_b: A quantized model
+        x: Inputs to use during calibration
+    """
+    import torch.ao.ns._numeric_suite_fx as ns
+    from torch.ao.ns.fx.mappings import get_unmatchable_types_map
+
+    unmatchable_types_map = get_unmatchable_types_map()
+    unmatchable_types_map["funs_unmatchable"].add(torch.mul)
+
+    model_a_ns, model_b_ns = ns.add_loggers(
+        "fp32",
+        model_a,
+        "int8",
+        model_b,
+        ns.OutputLogger,
+        unmatchable_types_map=unmatchable_types_map,
+    )
+
+    model_a_ns(x)
+    model_b_ns(x)
+
+    activation_comparison_dict = ns.extract_logger_info(
+        model_a_ns, model_b_ns, ns.OutputLogger, "int8"
+    )
+    ns.extend_logger_results_with_comparison(
+        activation_comparison_dict,
+        "fp32",
+        "int8",
+        torch.ao.ns.fx.utils.compute_sqnr,
+        "sqnr",
+    )
+
+    # Construct a dictionary mapping layer names to the SQNR values
+    layer_sqnr_dict = {}
+    for key in activation_comparison_dict:
+        layer = activation_comparison_dict[key]["node_output"]["int8"][0]["fqn"]
+        sqnr = activation_comparison_dict[key]["node_output"]["int8"][0]["sqnr"][0]
+        layer_sqnr_dict[layer] = sqnr
+
+    return layer_sqnr_dict
+
+
+def get_equalization_qconfig_dict(
+    layer_sqnr_dict: dict[str, float], num_layers_to_equalize: int
+) -> Any:
+    """Given the layer to SQNR dictionary, find the layers with the highest
+    quantization errors, and return an equalization_qconfig_dict
+    specifying to only equalize those top layers.
+
+    Args:
+        layer_sqnr_dict: Dictionary mapping layer names to SQNR values (found
+            when comparing an equalized model against a float model)
+        num_layers_to_equalize: Number of layers with the highest quantization
+           errors to equalize
+    """
+
+    # Sort the layer_sqnr_dictionary values and get the layers with the lowest
+    # SQNR values (aka highest quantization errors)
+    layer_sqnr_sorted = sorted(layer_sqnr_dict.items(), key=operator.itemgetter(1))
+    layers_to_equalize = layer_sqnr_sorted[:num_layers_to_equalize]
+
+    # Constructs an equalization_qconfig_dict that specifies to only equalize
+    # the layers with the highest quantization errors
+    module_to_qconfig_list = [
+        (item[0], default_equalization_qconfig) for item in layers_to_equalize
+    ]
+    equalization_qconfig_dict = {"module_name": module_to_qconfig_list}
+    return equalization_qconfig_dict
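+
+# Illustrative end-to-end sketch (hypothetical names): pick the layers with the worst SQNR
+# after quantization and equalize only those.
+#
+#     layer_sqnr = get_layer_sqnr_dict(float_model, quantized_model, calibration_batch)
+#     eq_qconfig_dict = get_equalization_qconfig_dict(layer_sqnr, num_layers_to_equalize=3)
+#     # eq_qconfig_dict == {"module_name": [(layer_name, default_equalization_qconfig), ...]}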
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_lower_to_native_backend.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_lower_to_native_backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..eeaad6b8afccc2ae43827972d2754bc1462058b2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_lower_to_native_backend.py
@@ -0,0 +1,1364 @@
+# mypy: allow-untyped-defs
+import operator
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.ao.nn.intrinsic as nni
+import torch.ao.nn.intrinsic.quantized as nniq
+import torch.ao.nn.intrinsic.quantized.dynamic as nniqd
+import torch.ao.nn.quantized as nnq
+import torch.ao.nn.quantized.dynamic as nnqd
+import torch.ao.nn.quantized.reference as nnqr
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.ao.nn.quantized.modules.utils import WeightedQuantizedModule
+from torch.ao.quantization.qconfig import QConfigAny
+from torch.ao.quantization.quantization_mappings import get_quantized_operator
+from torch.ao.quantization.utils import _parent_name
+from torch.fx import GraphModule, map_arg, Node
+from torch.fx.graph import Graph
+
+from .utils import (
+    collect_producer_nodes,
+    create_node_from_old_node_preserve_meta,
+    get_linear_prepack_op_for_dtype,
+    get_new_attr_name_with_prefix,
+    get_qconv_prepack_op,
+    graph_module_from_producer_nodes,
+)
+
+
+QOP_TO_ARG_NAMES_TO_SKIP: dict[Callable[..., Any], list[str]] = {
+    torch._ops.ops.quantized.hardswish: ["inplace"],
+    torch._ops.ops.quantized.elu: ["inplace"],
+    torch._ops.ops.quantized.dropout: ["inplace"],
+    torch._ops.ops.quantized.instance_norm: [
+        "running_mean",
+        "running_var",
+        "use_input_stats",
+        "momentum",
+    ],
+}
+
+
+def _is_node_in_list(node, modules, func_list, method_list, module_type_list):
+    is_call_function = node.op == "call_function" and node.target in func_list
+    is_call_method = node.op == "call_method" and node.target in method_list
+    is_call_module = (
+        node.op == "call_module" and type(modules[str(node.target)]) in module_type_list
+    )
+    return is_call_function, is_call_method, is_call_module
+
+
+def is_fixed_qparams_node(node, modules):
+    func_list = [
+        torch.nn.functional.hardsigmoid,
+        torch.nn.functional.sigmoid,
+        torch.sigmoid,
+        torch.tanh,
+    ]
+    method_list = [
+        "hardsigmoid",
+        "hardsigmoid_",
+        "sigmoid",
+        "sigmoid_",
+        "tanh",
+        "tanh_",
+    ]
+    module_type_list = [
+        torch.nn.Hardsigmoid,
+        torch.nn.Sigmoid,
+        torch.nn.Tanh,
+        torch.nn.Softmax,
+    ]
+    return _is_node_in_list(node, modules, func_list, method_list, module_type_list)
+
+
+def is_default_node(node, modules):
+    func_list = [
+        torch.nn.functional.elu,
+        torch.nn.functional.hardswish,
+        torch.nn.functional.instance_norm,
+        torch.nn.functional.layer_norm,
+        torch.nn.functional.leaky_relu,
+        torch.nn.functional.dropout,
+    ]
+    method_list: list[Any] = []
+    module_type_list = [
+        nnqr.ConvTranspose1d,
+        nnqr.ConvTranspose2d,
+        nnqr.ConvTranspose3d,
+        torch.nn.ELU,
+        torch.nn.LeakyReLU,
+        torch.nn.Hardswish,
+        torch.nn.InstanceNorm1d,
+        torch.nn.InstanceNorm2d,
+        torch.nn.InstanceNorm3d,
+        torch.nn.LayerNorm,
+        torch.nn.Dropout,
+        torch.nn.PReLU,
+        torch.nn.BatchNorm2d,
+        torch.nn.BatchNorm3d,
+        torch.ao.nn.intrinsic.BNReLU2d,
+        torch.ao.nn.intrinsic.BNReLU3d,
+    ]
+    return _is_node_in_list(node, modules, func_list, method_list, module_type_list)
+
+
+def is_copy_node(node, modules):
+    func_list = [
+        torch.adaptive_avg_pool1d,
+        torch.nn.functional.adaptive_avg_pool2d,
+        torch.nn.functional.adaptive_avg_pool3d,
+        torch.nn.functional.hardtanh,
+        torch.nn.functional.hardtanh_,
+        torch.nn.functional.interpolate,
+        torch.nn.functional.max_pool1d,
+        torch.nn.functional.max_pool2d,
+        torch.nn.functional.max_pool3d,
+        torch.nn.functional.relu,
+        torch.nn.functional.relu6,
+        torch.avg_pool1d,
+        torch._C._nn.avg_pool2d,
+        torch._C._nn.avg_pool3d,
+        torch.clamp,
+        torch.flatten,
+        torch.mean,
+        operator.floordiv,
+        # F.channel_shuffle and torch.channel_shuffle are essentially the same thing
+        # so we only need to put one of them here
+        torch.channel_shuffle,
+    ]
+    method_list = [
+        "clamp",
+        "mean",
+        "relu",
+        "relu_",
+    ]
+    module_type_list = [
+        torch.nn.AdaptiveAvgPool1d,
+        torch.nn.AdaptiveAvgPool2d,
+        torch.nn.AdaptiveAvgPool3d,
+        torch.nn.AvgPool1d,
+        torch.nn.AvgPool2d,
+        torch.nn.AvgPool3d,
+        torch.nn.Hardtanh,
+        torch.nn.MaxPool1d,
+        torch.nn.MaxPool2d,
+        torch.nn.MaxPool3d,
+        torch.nn.ReLU,
+        torch.nn.ReLU6,
+        torch.nn.ChannelShuffle,
+    ]
+    return _is_node_in_list(node, modules, func_list, method_list, module_type_list)
+
+
+def is_general_tensor_shape_node(node, modules):
+    func_list = [
+        torch.narrow,
+        torch.transpose,
+        torch.repeat_interleave,
+        torch.squeeze,
+        torch.stack,
+        torch.unsqueeze,
+        torch.nn.functional.pixel_shuffle,
+        torch.nn.functional.pixel_unshuffle,
+    ]
+    method_list = [
+        "contiguous",
+        "detach",
+        "detach_",
+        "permute",
+        "repeat",
+        "repeat_interleave",
+        "reshape",
+        "resize_",
+        "shape",
+        "size",
+        "squeeze",
+        "squeeze_",
+        "transpose",
+        "unsqueeze",
+        "unsqueeze_",
+        "view",
+    ]
+    module_type_list = [
+        torch.nn.Identity,
+        torch.nn.PixelShuffle,
+        torch.nn.PixelUnshuffle,
+    ]
+    return _is_node_in_list(node, modules, func_list, method_list, module_type_list)
+
+
+def is_other_node(node, modules):
+    func_list = [
+        torch.cat,
+    ]
+    method_list: list[Any] = []
+    module_type_list: list[Any] = []
+    return _is_node_in_list(node, modules, func_list, method_list, module_type_list)
+
+
+def is_special_pattern_node(node, modules):
+    res_function, res_method, res_module = False, False, False
+    for checker in [
+        is_fixed_qparams_node,
+        is_default_node,
+        is_copy_node,
+        is_general_tensor_shape_node,
+        is_other_node,
+    ]:
+        is_call_function, is_call_method, is_call_module = checker(node, modules)
+        res_function = res_function or is_call_function
+        res_method = res_method or is_call_method
+        res_module = res_module or is_call_module
+    return res_function, res_method, res_module
+
+
+def is_dequantize_node(node):
+    return (
+        isinstance(node, Node)
+        and node.op == "call_method"
+        and node.target == "dequantize"
+    )
+
+
+def is_getattr_tensor_metadata_node(node):
+    return (
+        node.op == "call_function"
+        and node.target == getattr
+        and node.args[1] in ["shape"]
+    )
+
+
+def is_get_tensor_info_node(node):
+    return node.op == "call_method" and node.target in ["shape", "size"]
+
+
+def should_skip_lowering(op: torch.fx.node.Node, qconfig_map: dict[str, QConfigAny]):
+    """
+    Return True if the op is configured with a None qconfig, False otherwise.
+    Note: maybe need to generalize this to also check for the dtype, and we
+    only lower when dtype matches, but right now fbgemm/qnnpack only support
+    a single dtype, so it is OK for now.
+    """
+    return op.name in qconfig_map and qconfig_map[op.name] is None
+
+
+# Mapping from reference module class to the replacement static quantized module class for lowering
+STATIC_LOWER_MODULE_MAP: dict[type[nn.Module], type[WeightedQuantizedModule]] = {
+    nnqr.Linear: nnq.Linear,
+    nnqr.Conv1d: nnq.Conv1d,
+    nnqr.Conv2d: nnq.Conv2d,
+    nnqr.Conv3d: nnq.Conv3d,
+}
+
+# Mapping from reference module class to the replacement dynamic quantized module class for lowering
+DYNAMIC_LOWER_MODULE_MAP: dict[type[nn.Module], type[nn.Module]] = {
+    nnqr.Linear: nnqd.Linear,
+    nnqr.GRUCell: nnqd.GRUCell,
+    nnqr.LSTMCell: nnqd.LSTMCell,
+    nnqr.RNNCell: nnqd.RNNCell,
+    nnqr.LSTM: nnqd.LSTM,
+    nnqr.GRU: nnqd.GRU,
+}
+
+# Mapping from reference module class to the replacement weight only quantized module class for lowering
+# TODO: correct the namespace for these modules
+WEIGHT_ONLY_LOWER_MODULE_MAP: dict[type[nn.Module], type[nn.Module]] = {
+    nnqr.Embedding: nnq.Embedding,
+    nnqr.EmbeddingBag: nnq.EmbeddingBag,
+}
+
+# TODO: merge with STATIC_LOWER_MODULE_MAP after we merge
+# _lower_static_weighted_ref_module and special_pattern_replacement
+SPECIAL_PATTERN_LOWER_MODULE_MAP = {
+    nn.BatchNorm2d: nnq.BatchNorm2d,
+    nn.BatchNorm3d: nnq.BatchNorm3d,
+    nnqr.ConvTranspose1d: nnq.ConvTranspose1d,
+    nnqr.ConvTranspose2d: nnq.ConvTranspose2d,
+    nnqr.ConvTranspose3d: nnq.ConvTranspose3d,
+    nn.ELU: nnq.ELU,
+    nn.LeakyReLU: nnq.LeakyReLU,
+    nn.Hardswish: nnq.Hardswish,
+    nn.InstanceNorm1d: nnq.InstanceNorm1d,
+    nn.InstanceNorm2d: nnq.InstanceNorm2d,
+    nn.InstanceNorm3d: nnq.InstanceNorm3d,
+    nn.LayerNorm: nnq.LayerNorm,
+    nn.Dropout: nnq.Dropout,
+    nn.Softmax: nnq.Softmax,
+    nn.PReLU: nnq.PReLU,
+    nni.BNReLU2d: nniq.BNReLU2d,
+    nni.BNReLU3d: nniq.BNReLU3d,
+}
+
+# Mapping from fused module class to a 2-tuple of:
+#   1) The inner reference module class
+#   2) The replacement static quantized module class for lowering
+STATIC_LOWER_FUSED_MODULE_MAP: dict[
+    type[nn.Module], tuple[type[nn.Module], type[WeightedQuantizedModule]]
+] = {
+    nni.LinearReLU: (nnqr.Linear, nniq.LinearReLU),
+    # TODO: LinearLeakyReLU is registered as global but it is only fused and
+    # lowered when onednn's backend config is used. Maybe need to separate
+    # registration and lowering functions for different backends in the future.
+    nni.LinearLeakyReLU: (nnqr.Linear, nniq.LinearLeakyReLU),
+    nni.LinearTanh: (nnqr.Linear, nniq.LinearTanh),
+    nni.ConvReLU1d: (nnqr.Conv1d, nniq.ConvReLU1d),
+    nni.ConvReLU2d: (nnqr.Conv2d, nniq.ConvReLU2d),
+    nni.ConvReLU3d: (nnqr.Conv3d, nniq.ConvReLU3d),
+}
+
+# The difference between STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP and STATIC_LOWER_FUSED_MODULE_MAP:
+# The reference node inside STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP has 2 inputs.
+# Mapping from fused module class to a 2-tuple of:
+#   1) The inner reference module class
+#   2) The replacement static quantized module class for lowering
+STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP: dict[
+    type[nn.Module], tuple[type[nn.Module], type[WeightedQuantizedModule]]
+] = {
+    nni.ConvAdd2d: (nnqr.Conv2d, nniq.ConvAdd2d),
+    nni.ConvAddReLU2d: (nnqr.Conv2d, nniq.ConvAddReLU2d),
+}
+
+# Mapping from fused module class to a 2-tuple of:
+#   1) The inner reference module class
+#   2) The replacement dynamic quantized module class for lowering
+DYNAMIC_LOWER_FUSED_MODULE_MAP: dict[
+    type[nn.Module], tuple[type[nn.Module], type[nn.Module]]
+] = {
+    nni.LinearReLU: (nnqr.Linear, nniqd.LinearReLU),
+}
+
+# Mapping from a functional to lower to a 2-tuple of
+#   1) The quantized version of the op
+#   2) The quantized version of the op fused with relu, if it exists, else None
+STATIC_LOWER_FUNCTIONAL_MAP: dict[Callable, tuple[Callable, Optional[Callable]]] = {
+    F.linear: (torch.ops.quantized.linear, torch.ops.quantized.linear_relu),
+    F.conv1d: (torch.ops.quantized.conv1d, torch.ops.quantized.conv1d_relu),
+    F.conv2d: (torch.ops.quantized.conv2d, torch.ops.quantized.conv2d_relu),
+    F.conv3d: (torch.ops.quantized.conv3d, torch.ops.quantized.conv3d_relu),
+    F.conv_transpose1d: (torch.ops.quantized.conv_transpose1d, None),
+    F.conv_transpose2d: (torch.ops.quantized.conv_transpose2d, None),
+    F.conv_transpose3d: (torch.ops.quantized.conv_transpose3d, None),
+}
+
+WEIGHT_PREPACK_OPS: set[Callable] = {
+    torch._ops.ops.quantized.linear_prepack,
+    torch._ops.ops.quantized.linear_prepack_fp16,
+    torch._ops.ops.quantized.conv1d_prepack,
+    torch._ops.ops.quantized.conv2d_prepack,
+    torch._ops.ops.quantized.conv3d_prepack,
+    torch.ops.quantized.conv_transpose1d_prepack,
+    torch.ops.quantized.conv_transpose2d_prepack,
+    torch.ops.quantized.conv_transpose3d_prepack,
+}
+
+# Mapping from a functional to a dictionary, where the key is a 2-tuple of
+# (input_activation_dtype, weight_dtype) and the value is a 2-tuple of
+#   1) The dynamically quantized version of the op
+#   2) The dynamically quantized version of the op fused with relu, if it exists, else None
+DYNAMIC_LOWER_FUNCTIONAL_MAP: dict[
+    Callable, dict[tuple[torch.dtype, torch.dtype], tuple[Callable, Optional[Callable]]]
+] = {
+    F.linear: {
+        (torch.quint8, torch.qint8): (
+            torch.ops.quantized.linear_dynamic,
+            torch.ops.quantized.linear_relu_dynamic,
+        ),
+        (torch.float16, torch.float16): (
+            torch.ops.quantized.linear_dynamic_fp16,
+            torch.ops.quantized.linear_relu_dynamic_fp16,
+        ),
+    },
+    # dynamic conv + relu is not available yet
+    F.conv1d: {
+        (torch.quint8, torch.qint8): (torch.ops.quantized.conv1d_dynamic, None),
+    },
+    F.conv2d: {
+        (torch.quint8, torch.qint8): (torch.ops.quantized.conv2d_dynamic, None),
+    },
+    F.conv3d: {
+        (torch.quint8, torch.qint8): (torch.ops.quantized.conv3d_dynamic, None),
+    },
+}
+
+CONV_FUNCTIONAL_OPS: set[Callable] = {
+    F.conv1d,
+    F.conv2d,
+    F.conv3d,
+}
+
+CONV_TRANSPOSE_FUNCTIONAL_OPS: set[Callable] = {
+    F.conv_transpose1d,
+    F.conv_transpose2d,
+    F.conv_transpose3d,
+}
+
+# TODO: add tests for lowering these ops
+QBIN_OP_MAPPING: dict[Union[Callable, str], Callable] = {
+    operator.add: torch.ops.quantized.add,
+    torch.add: torch.ops.quantized.add,
+    operator.mul: torch.ops.quantized.mul,
+    operator.matmul: torch.ops.quantized.matmul,
+    torch.mul: torch.ops.quantized.mul,
+    torch.matmul: torch.ops.quantized.matmul,
+}
+QBIN_RELU_OP_MAPPING: dict[Union[Callable, str], Callable] = {
+    operator.add: torch.ops.quantized.add_relu,
+    torch.add: torch.ops.quantized.add_relu,
+    operator.mul: torch.ops.quantized.mul_relu,
+    torch.mul: torch.ops.quantized.mul_relu,
+}
+
+ORIGINAL_WEIGHTS_LOOKUP = "original_weights_lookup"
+
+
+def _save_packed_weight(self, destination, prefix, keep_vars):
+    for attr_name in dir(self):
+        if "_packed_weight" in attr_name and isinstance(
+            getattr(self, attr_name), torch._C.ScriptObject
+        ):  # type: ignore[attr-defined]
+            packed_weight = getattr(self, attr_name)
+            destination[prefix + attr_name] = packed_weight
+
+
+def _load_packed_weight(
+    self,
+    state_dict,
+    prefix,
+    local_metadata,
+    strict,
+    missing_keys,
+    unexpected_keys,
+    error_msgs,
+):
+    attrs_to_pop = []
+    for attr_name in state_dict:
+        if attr_name.startswith("_packed_weight") and isinstance(
+            state_dict[attr_name], torch._C.ScriptObject
+        ):  # type: ignore[attr-defined] # noqa: B950
+            setattr(self, attr_name, state_dict[attr_name])
+            attrs_to_pop.append(attr_name)
+
+    # pop the packed param attributes
+    for attr_name in attrs_to_pop:
+        state_dict.pop(attr_name)
+
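+# Note (added for clarity, not in the upstream file): fold_weight below installs
+# these two helpers as state_dict save/load hooks so that prepacked weights
+# (torch._C.ScriptObject instances) survive serialization round trips.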
+
+def fold_weight(
+    quantized_model: GraphModule,
+    node_name_to_scope: dict[str, tuple[str, type]],
+    keep_original_weights: bool = False,
+) -> GraphModule:
+    """
+    Trace back from the weight node until we hit getattr, reconstruct the
+    graph module with the traced nodes, and run the graph module to pack the
+    weight. Then replace the original chain of ops with the packed weight.
+    """
+    packed_weights = {}
+    # map from folded node name to the prepacked weight name
+    folded_nodes = {}
+    original_weights_lookup: dict[str, list] = {}
+    lookup_counter = 0
+    # get packed weights
+    for node in quantized_model.graph.nodes:
+        if node.op == "call_function" and node.target in WEIGHT_PREPACK_OPS:
+            nodes_to_fold = collect_producer_nodes(node)
+            if nodes_to_fold is not None:
+                for node_to_fold in nodes_to_fold:
+                    folded_nodes[node_to_fold.name] = node
+
+                prepacking_module = graph_module_from_producer_nodes(
+                    quantized_model, nodes_to_fold
+                )
+                packed_weight = prepacking_module()
+                packed_weights[node.name] = packed_weight
+                if keep_original_weights:
+                    original_weights = list(prepacking_module.state_dict().values())
+                    original_weights_lookup[str(lookup_counter)] = sorted(
+                        original_weights, key=lambda x: x.numel(), reverse=True
+                    )
+                    if len(original_weights_lookup[str(lookup_counter)]) == 1:
+                        # bias is None
+                        original_weights_lookup[str(lookup_counter)].append(None)
+                    lookup_counter += 1
+    lookup_counter = 0
+
+    # remove folded nodes and replace the prepacking node with getattr
+    folded_graph = Graph()
+    env: dict[Any, Any] = {}
+
+    def load_arg(a):
+        return map_arg(a, lambda node: env[node.name])
+
+    for node in quantized_model.graph.nodes:
+        prepack_node = folded_nodes.get(node.name, None)
+        if prepack_node is node:
+            packed_weight = packed_weights[node.name]
+            # add a prepacked attribute to root
+            op_node = next(iter(prepack_node.users))
+            module_path, _ = node_name_to_scope[op_node.name]
+            get_new_packed_weight_name = get_new_attr_name_with_prefix(
+                module_path + "_packed_weight_"
+            )
+            packed_weight_name = get_new_packed_weight_name(quantized_model)
+            setattr(quantized_model, packed_weight_name, packed_weight)
+            # replace prepack node with a getattr node
+            env[node.name] = folded_graph.create_node(
+                "get_attr", packed_weight_name, (), {}
+            )
+            if keep_original_weights:
+                key_name = (
+                    packed_weight_name.replace(":", "_")
+                    .replace("/", "_")
+                    .replace("|", "_")
+                    .replace(" ", "")
+                    .lower()
+                )
+                original_weights_lookup[key_name] = original_weights_lookup[
+                    str(lookup_counter)
+                ]
+                del original_weights_lookup[str(lookup_counter)]
+                lookup_counter += 1
+        elif prepack_node is not None:
+            # remove the folded node
+            continue
+        else:
+            # copy other nodes
+            env[node.name] = folded_graph.node_copy(node, load_arg)
+
+    quantized_model = GraphModule(quantized_model, folded_graph)
+    quantized_model._register_state_dict_hook(_save_packed_weight)
+    quantized_model.register_load_state_dict_pre_hook(_load_packed_weight)
+
+    if keep_original_weights:
+        setattr(  # noqa: B010
+            quantized_model, ORIGINAL_WEIGHTS_LOOKUP, original_weights_lookup
+        )
+
+    return quantized_model
+
+
+def _get_module(node: Node, modules: dict[str, nn.Module]) -> Optional[nn.Module]:
+    """
+    Return the `torch.nn.Module` that corresponds to the specified node's target.
+    If no such node exists, return None.
+    """
+    if node.op == "call_module" and str(node.target) in modules:
+        return modules[str(node.target)]
+    else:
+        return None
+
+
+def _match_static_pattern(
+    node: Node,
+    modules: dict[str, nn.Module],
+    qconfig_map: dict[str, QConfigAny],
+    matching_modules_or_ops: list[Callable],
+    dequantize_node_arg_indices: list[int],
+) -> Union[tuple[Node, Node, Node], tuple[None, None, None]]:
+    """
+    Match the pattern (dequantize - ref node - quantize) against the node provided.
+
+    If there is a match, return a 3-tuple of:
+      1) q_node: the quantize node,
+      2) relu_node: a relu node wrapping the ref_node, and
+      3) ref_node: a reference module or functional node to replace with its quantized counterpart
+    Otherwise, if there is no match, return a 3-tuple of (None, None, None).
+
+    Parameters:
+      node: The `torch.fx.Node` to match against.
+      modules: A mapping from node names to modules in the model graph, used for module lookup.
+      qconfig_map: A mapping from node names to the qconfigs associated with the nodes.
+          If the corresponding qconfig for the reference node is None, then return no match.
+      matching_modules_or_ops: Either a list of functions or a list of `torch.nn.Module`s.
+          If the reference node is not in this list, then return no match.
+      dequantize_node_arg_indices: A list of indices in the reference node args where dequantize
+          nodes may be present. An empty list means skipping the check for dequantize nodes.
+    """
+    SKIP_LOWERING_VALUE = (None, None, None)
+
+    # Match quantize node
+    if node.op != "call_function" or node.target != torch.quantize_per_tensor:
+        return SKIP_LOWERING_VALUE
+    q_node = node
+    ref_node = q_node.args[0]
+    assert isinstance(ref_node, Node)
+
+    # Handle cases where the node is wrapped in a ReLU
+    if (ref_node.op == "call_function" and ref_node.target in (F.relu, torch.relu)) or (
+        ref_node.op == "call_module" and type(_get_module(ref_node, modules)) == nn.ReLU
+    ):
+        relu_node = ref_node
+        ref_node = relu_node.args[0]
+        assert isinstance(ref_node, Node)
+    else:
+        relu_node = None
+    if should_skip_lowering(ref_node, qconfig_map):
+        return SKIP_LOWERING_VALUE
+
+    # Match reference module or functional
+    if isinstance(matching_modules_or_ops[0], type) and issubclass(
+        matching_modules_or_ops[0], nn.Module
+    ):
+        expected_op = "call_module"
+        match_key = type(_get_module(ref_node, modules))
+    else:
+        expected_op = "call_function"
+        match_key = ref_node.target  # type: ignore[assignment]
+    if ref_node.op != expected_op or match_key not in matching_modules_or_ops:
+        return SKIP_LOWERING_VALUE
+
+    # Match dequantize node(s). Both of the following conditions must pass:
+    # (1) All `torch.fx.Node`s at the matching indices must be dequantize nodes
+    # (2) There must be at least one dequantize node
+    matched_dequantize = False
+    for i in dequantize_node_arg_indices:
+        assert i < len(ref_node.args), (
+            f"Dequantize index {i} exceeded reference node's arg length {len(ref_node.args)}"
+        )
+        arg = ref_node.args[i]
+        if is_dequantize_node(arg):
+            matched_dequantize = True
+        elif isinstance(arg, Node):
+            return SKIP_LOWERING_VALUE
+    if not matched_dequantize:
+        return SKIP_LOWERING_VALUE
+
+    return (q_node, relu_node, ref_node)  # type: ignore[return-value]
+
+
+def _match_static_pattern_with_two_inputs(
+    node: Node,
+    modules: dict[str, nn.Module],
+    qconfig_map: dict[str, QConfigAny],
+    matching_modules_or_ops: list[Callable],
+) -> Union[tuple[Node, Node], tuple[None, None]]:
+    """
+                      (dequantize \
+    Match the pattern (dequantize - ref node - quantize) against the node provided.
+
+    If there is a match, return a 2-tuple of:
+      1) q_node: the quantize node,
+      2) ref_node: a reference module or functional node to replace with its quantized counterpart
+    Otherwise, if there is no match, return a 2-tuple of (None, None).
+
+    Parameters:
+      node: The `torch.fx.Node` to match against.
+      modules: A mapping from node names to modules in the model graph, used for module lookup.
+      qconfig_map: A mapping from node names to the qconfigs associated with the nodes.
+          If the corresponding qconfig for the reference node is None, then return no match.
+      matching_modules_or_ops: Either a list of functions or a list of `torch.nn.Module`s.
+          If the reference node is not in this list, then return no match.
+    """
+    SKIP_LOWERING_VALUE = (None, None)
+
+    # Match quantize node
+    if node.op != "call_function" or node.target != torch.quantize_per_tensor:
+        return SKIP_LOWERING_VALUE
+    q_node = node
+    ref_node = q_node.args[0]
+    assert isinstance(ref_node, Node)
+
+    if should_skip_lowering(ref_node, qconfig_map):
+        return SKIP_LOWERING_VALUE
+
+    # Match reference module or functional
+    if isinstance(matching_modules_or_ops[0], type) and issubclass(
+        matching_modules_or_ops[0], nn.Module
+    ):
+        expected_op = "call_module"
+        match_key = type(_get_module(ref_node, modules))
+    else:
+        # This pass only supports "call_module" ops
+        return SKIP_LOWERING_VALUE
+
+    if ref_node.op != expected_op or match_key not in matching_modules_or_ops:
+        return SKIP_LOWERING_VALUE
+
+    # Check that ref_node has 2 input nodes and that both are dequantize nodes.
+    if len(ref_node.args) != 2:
+        return SKIP_LOWERING_VALUE
+    for i in range(len(ref_node.args)):
+        arg = ref_node.args[i]
+        if not is_dequantize_node(arg):
+            return SKIP_LOWERING_VALUE
+
+    return (q_node, ref_node)
+
+
+def _lower_static_weighted_ref_module(
+    model: GraphModule, qconfig_map: dict[str, QConfigAny]
+):
+    """
+    Traverse the graph and find dequantize - ref module - quantize patterns
+    and replace them with the quantized version of the ref module.
+    """
+    modules = dict(model.named_modules(remove_duplicate=False))
+    for n in model.graph.nodes:
+        # Step 0: Find nodes that match this pattern (dequantize - ref module - quantize)
+        matching_modules = list(STATIC_LOWER_MODULE_MAP.keys()) + list(
+            STATIC_LOWER_FUSED_MODULE_MAP.keys()
+        )
+        q_node, _relu_node, ref_node = _match_static_pattern(
+            n,
+            modules,
+            qconfig_map,
+            matching_modules,  # type: ignore[arg-type]
+            dequantize_node_arg_indices=[0],
+        )
+        if q_node is None:
+            continue
+        assert ref_node is not None
+        (_, scale_node, zero_point_node, _) = q_node.args
+        ref_module = _get_module(ref_node, modules)
+        ref_class = type(ref_module)
+        assert isinstance(scale_node, Node)
+        assert isinstance(zero_point_node, Node)
+        assert issubclass(ref_class, nn.Module)
+
+        # Step 1: Change this pattern to use the corresponding quantized module
+        # For fused modules, we also check whether the inner module is a reference module
+        # If so, we replace the entire fused module with the corresponding quantized module
+        if ref_class in STATIC_LOWER_FUSED_MODULE_MAP:
+            inner_ref_class, q_class = STATIC_LOWER_FUSED_MODULE_MAP[ref_class]
+            if type(ref_module[0]) != inner_ref_class:  # type: ignore[index]
+                continue
+        else:
+            q_class = STATIC_LOWER_MODULE_MAP[ref_class]
+        output_scale = getattr(model, scale_node.target)  # type: ignore[arg-type]
+        output_zero_point = getattr(model, zero_point_node.target)  # type: ignore[arg-type]
+        q_module = q_class.from_reference(ref_module, output_scale, output_zero_point)
+        # replace reference module with quantized module
+        parent_name, module_name = _parent_name(ref_node.target)
+        setattr(modules[parent_name], module_name, q_module)
+
+        # Step 2: Reroute around dq_node, and remove q_node and its args
+        assert len(ref_node.args) == 1
+        dq_node = ref_node.args[0]
+        assert isinstance(dq_node, Node)
+        ref_node.replace_input_with(dq_node, dq_node.args[0])  # type: ignore[arg-type]
+        q_node.replace_all_uses_with(ref_node)
+        model.graph.erase_node(q_node)
+        model.graph.erase_node(scale_node)
+        model.graph.erase_node(zero_point_node)
+
+
+def _lower_static_weighted_ref_module_with_two_inputs(
+    model: GraphModule, qconfig_map: dict[str, QConfigAny]
+):
+    """
+    Traverse the graph and find patterns
+    dequantize   dequantize
+       \\         //
+        ref module
+            \\
+          quantize
+    and replace them with the quantized version of the ref module.
+    """
+    modules = dict(model.named_modules(remove_duplicate=False))
+    for n in model.graph.nodes:
+        #                                            (dequantize \
+        # Step 0: Find nodes that match this pattern (dequantize - ref module - quantize)
+        matching_modules = list(STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP.keys())
+        (q_node, ref_node) = _match_static_pattern_with_two_inputs(
+            n,
+            modules,
+            qconfig_map,
+            matching_modules,  # type: ignore[arg-type]
+        )
+        if q_node is None:
+            continue
+        assert ref_node is not None
+        (_, scale_node, zero_point_node, _) = q_node.args
+        ref_module = _get_module(ref_node, modules)
+        ref_class = type(ref_module)
+        assert isinstance(scale_node, Node)
+        assert isinstance(zero_point_node, Node)
+        assert issubclass(ref_class, nn.Module)
+
+        # Step 1: Change this pattern to use the corresponding quantized module
+        # For fused modules, we also check whether the inner module is a reference module
+        # If so, we replace the entire fused module with the corresponding quantized module
+        if ref_class in STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP:
+            inner_ref_class, q_class = STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP[
+                ref_class
+            ]
+            if type(ref_module[0]) != inner_ref_class:  # type: ignore[index]
+                continue
+        else:
+            continue
+        output_scale = getattr(model, scale_node.target)  # type: ignore[arg-type]
+        output_zero_point = getattr(model, zero_point_node.target)  # type: ignore[arg-type]
+        q_module = q_class.from_reference(ref_module, output_scale, output_zero_point)
+        # replace reference module with quantized module
+        parent_name, module_name = _parent_name(ref_node.target)
+        setattr(modules[parent_name], module_name, q_module)
+
+        # Step 2: Reroute around dq_node, and remove q_node and its args
+        assert len(ref_node.args) == 2
+        for arg in ref_node.args:
+            if not is_dequantize_node(arg):
+                continue
+            dq_node = arg
+            assert isinstance(dq_node, Node)
+            ref_node.replace_input_with(dq_node, dq_node.args[0])  # type: ignore[arg-type]
+
+        q_node.replace_all_uses_with(ref_node)
+        model.graph.erase_node(q_node)
+        model.graph.erase_node(scale_node)
+        model.graph.erase_node(zero_point_node)
+
+
+def _lower_dynamic_weighted_ref_module(model: GraphModule):
+    """
+    Traverse the graph and find quantize_per_tensor_dynamic - dequantize - ref_module patterns
+    and replace them with the dynamically quantized version of the ref module.
+    """
+    named_modules = dict(model.named_modules(remove_duplicate=False))
+    for n in model.graph.nodes:
+        if n.op != "call_module" or type(named_modules[str(n.target)]) not in set(
+            DYNAMIC_LOWER_MODULE_MAP.keys()
+        ).union(set(DYNAMIC_LOWER_FUSED_MODULE_MAP.keys())):
+            continue
+        ref_node = n
+        dq_node = ref_node.args[0]
+        if dq_node.op != "call_method" or dq_node.target != "dequantize":
+            continue
+
+        input_dynamic_q_node = dq_node.args[0]
+
+        if (
+            input_dynamic_q_node.op != "call_function"
+            or input_dynamic_q_node.target != torch.quantize_per_tensor_dynamic
+        ):
+            continue
+
+        activation_dtype = input_dynamic_q_node.args[1]
+        is_fp16 = activation_dtype == torch.float16
+        is_int8 = activation_dtype in [torch.quint8, torch.qint8]
+        if not is_int8 and not is_fp16:
+            continue
+
+        ref_module = named_modules[str(ref_node.target)]
+        ref_class = type(ref_module)
+        if ref_class in DYNAMIC_LOWER_FUSED_MODULE_MAP:
+            inner_ref_class, q_class = DYNAMIC_LOWER_FUSED_MODULE_MAP[ref_class]
+            if type(ref_module[0]) != inner_ref_class:
+                continue
+        else:
+            q_class = DYNAMIC_LOWER_MODULE_MAP.get(ref_class)  # type: ignore[assignment]
+        # TODO: maybe define a WeightedDynamicallyQuantizedModule
+        q_module = q_class.from_reference(ref_module)  # type: ignore[attr-defined]
+
+        # replace reference module with dynamically quantized module
+        parent_name, module_name = _parent_name(ref_node.target)
+        setattr(named_modules[parent_name], module_name, q_module)
+        ref_node.replace_input_with(dq_node, input_dynamic_q_node.args[0])
+
+
+def _lower_weight_only_weighted_ref_module(model: GraphModule):
+    """
+    Traverse the graph and find ref_module patterns
+    and replace them with the weight only quantized version of the ref module.
+    """
+    named_modules = dict(model.named_modules(remove_duplicate=False))
+    for n in model.graph.nodes:
+        if n.op != "call_module" or type(named_modules[str(n.target)]) not in set(
+            WEIGHT_ONLY_LOWER_MODULE_MAP.keys()
+        ):
+            continue
+        ref_node = n
+        ref_module = named_modules[str(ref_node.target)]
+        ref_class = type(ref_module)
+        q_class = WEIGHT_ONLY_LOWER_MODULE_MAP.get(ref_class)
+        # TODO: WeightedQuantizedModule is currently assuming static quant apis
+        # with output_scale, output_zero_point in from_reference, we may want to
+        # relax that, or rename this
+        # TODO: maybe define a WeightedWeightOnlyQuantizedModule
+        q_module = q_class.from_reference(ref_module)  # type: ignore[union-attr]
+
+        # replace reference module with weight only quantized module
+        parent_name, module_name = _parent_name(ref_node.target)
+        setattr(named_modules[parent_name], module_name, q_module)
+
+
+def _lower_static_weighted_ref_functional(
+    model: GraphModule, qconfig_map: dict[str, QConfigAny]
+):
+    """
+    Traverse the graph and replace functional reference patterns with their quantized versions.
+    """
+    modules = dict(model.named_modules(remove_duplicate=False))
+    for n in model.graph.nodes:
+        # Step 0: Find nodes that match this pattern (dequantize - functional op - quantize)
+        matching_ops = list(STATIC_LOWER_FUNCTIONAL_MAP.keys())
+        (q_node, relu_node, func_node) = _match_static_pattern(
+            n, modules, qconfig_map, matching_ops, dequantize_node_arg_indices=[0, 1]
+        )
+        if q_node is None:
+            continue
+        assert func_node is not None
+        (_, output_scale_node, output_zp_node, _) = q_node.args
+        (input_dq_node, weight_dq_node, *remaining_func_args) = func_node.args
+        assert isinstance(output_zp_node, Node)
+        assert isinstance(input_dq_node, Node)
+        assert isinstance(weight_dq_node, Node)
+        quantized_weight = weight_dq_node.args[0]
+        assert isinstance(quantized_weight, Node)
+        if quantized_weight.op != "call_function" or quantized_weight.target not in (
+            torch.quantize_per_tensor,
+            torch.quantize_per_channel,
+        ):
+            continue
+
+        # Step 1: Replace quantized weights with packed weights, which will be folded later
+        # Use the right prepack op and prepare the corresponding args
+        # Linear prepack args: (quantized weights[, bias])
+        # Conv prepack args: (quantized weights[, bias, stride, padding, dilation, groups])
+        prepack_args = [quantized_weight] + remaining_func_args
+        if func_node.target == F.linear:
+            weight_dtype = quantized_weight.args[-1]
+            prepack_op = get_linear_prepack_op_for_dtype(weight_dtype)
+        elif func_node.target in CONV_FUNCTIONAL_OPS:
+            prepack_op = get_qconv_prepack_op(func_node.target)  # type: ignore[arg-type]
+            # For conv1d, the stride, padding, and dilation args may be ints,
+            # in which case we need to convert them to tuples
+            if func_node.target == F.conv1d:
+                for i in [2, 3, 4]:
+                    if len(prepack_args) > i and isinstance(prepack_args[i], int):
+                        prepack_args[i] = (prepack_args[i],)
+        elif func_node.target in CONV_TRANSPOSE_FUNCTIONAL_OPS:
+            prepack_op = get_qconv_prepack_op(func_node.target)  # type: ignore[arg-type]
+            # For conv_transpose1d, the stride, padding, and dilation args may be ints,
+            # in which case we need to convert them to tuples
+            if func_node.target == F.conv_transpose1d:
+                # Note prepack_args[5] is groups.
+                for i in [2, 3, 4, 6]:
+                    if len(prepack_args) > i and isinstance(prepack_args[i], int):
+                        prepack_args[i] = (prepack_args[i],)
+            # swap dilation and groups
+            # prepack op has arguments: {w, b, stride, padding, output_padding, dilation, groups}
+            # transposed conv op has arguments: {x, w, b, stride, padding, output_padding, groups, dilation}
+            if len(prepack_args) > 6:
+                prepack_args[5], prepack_args[6] = prepack_args[6], prepack_args[5]
+        else:
+            raise ValueError(f"Lowering is not supported for op '{func_node.target}'")
+        with model.graph.inserting_before(output_scale_node):  # type: ignore[arg-type]
+            # kwargs of the func node are needed for prepack op (i.e., quantized::linear_prepack)
+            # They are not needed for compute op (i.e., quantized::linear)
+            kwargs = func_node.kwargs
+            # F.linear uses 'bias' key for bias while qlinear_prepack uses 'B' for bias
+            if func_node.target == F.linear and "bias" in kwargs:
+                kwargs = kwargs.copy()
+                kwargs["B"] = kwargs["bias"]
+                del kwargs["bias"]
+            packed_weight = model.graph.create_node(
+                "call_function", prepack_op, tuple(prepack_args), kwargs
+            )
+
+        # Step 2: Replace reference pattern with the corresponding quantized op
+        (q_func, q_relu_func) = STATIC_LOWER_FUNCTIONAL_MAP[func_node.target]  # type: ignore[index]
+        # conv_transpose does not support fusion with relu yet. q_relu_func is None in such cases
+        if q_relu_func is not None:
+            func_node.target = q_relu_func if relu_node is not None else q_func
+        else:
+            func_node.target = q_func
+        func_node.args = (
+            input_dq_node.args[0],
+            packed_weight,
+            output_scale_node,
+            output_zp_node,
+        )
+        # kwargs for func_node have been moved to kwargs for prepack op
+        func_node.kwargs = {}
+        q_node.replace_all_uses_with(func_node)
+        # Move func_node after output_zp_node in the graph
+        output_zp_node.append(func_node)
+
+        # Clean up: Remove quantize node, and the relu node if it exists
+        model.graph.erase_node(q_node)
+        if relu_node is not None and q_relu_func is not None:
+            model.graph.erase_node(relu_node)
+
+
+def _lower_dynamic_weighted_ref_functional(
+    model: GraphModule, qconfig_map: dict[str, QConfigAny]
+):
+    """
+    Traverse the graph and replace functional reference patterns with their dynamically
+    quantized versions.
+    Examples:
+    quantize_per_tensor_dynamic - dequantize - functional linear --> linear_dynamic
+    to(torch.float16) - dequantize - functional linear --> linear_dynamic_fp16
+    """
+    modules = dict(model.named_modules(remove_duplicate=False))
+    # we want to search in reverse order so that we can match the larger patterns first,
+    # e.g. we want to match linear - relu before linear.
+    for n in reversed(model.graph.nodes):
+        # Step 0: Find nodes that match this pattern
+        # (quantize_per_tensor_dynamic - dequantize - dynamically quantized op)
+        # We search for the pattern backwards, starting with the quantize node
+        # Quantize node args: (func, scale, zp, dtype)
+        func_node = n
+        # Handle cases where the functional op is wrapped in a ReLU
+        if (
+            func_node.op == "call_function"
+            and func_node.target == F.relu
+            or func_node.op == "call_module"
+            and type(modules[str(func_node.target)]) == torch.nn.ReLU
+        ):
+            relu_node = func_node
+            func_node = relu_node.args[0]
+        else:
+            relu_node = None
+        if should_skip_lowering(func_node, qconfig_map):
+            continue
+        # Linear args: (dequantized inputs, dequantized weights[, bias])
+        # Conv args: (dequantized inputs, dequantized weights[, bias, stride, padding, dilation, groups])
+        if (
+            func_node.op != "call_function"
+            or func_node.target not in DYNAMIC_LOWER_FUNCTIONAL_MAP
+        ):
+            continue
+        (input_dq_node, weight_dq_node, *remaining_func_args) = func_node.args
+        if (
+            input_dq_node.op != "call_method"
+            or input_dq_node.target != "dequantize"
+            or weight_dq_node.op != "call_method"
+            or weight_dq_node.target != "dequantize"
+        ):
+            continue
+
+        input_dynamic_q_node = input_dq_node.args[0]
+
+        if (
+            input_dynamic_q_node.op != "call_function"
+            or input_dynamic_q_node.target != torch.quantize_per_tensor_dynamic
+        ):
+            continue
+
+        reduce_range_node = None
+        (pattern_input, activation_dtype, reduce_range_node) = input_dynamic_q_node.args
+        is_fp16 = activation_dtype == torch.float16
+        is_int8 = activation_dtype in [torch.quint8, torch.qint8]
+        if not is_int8 and not is_fp16:
+            continue
+
+        quantized_weight = weight_dq_node.args[0]
+        weight_dtype = quantized_weight.args[-1]
+
+        # Step 1: Try to select reference pattern with the corresponding quantized op
+        dynamic_quant_dtype_key = (activation_dtype, weight_dtype)
+        if (
+            dynamic_quant_dtype_key
+            not in DYNAMIC_LOWER_FUNCTIONAL_MAP[func_node.target]
+        ):
+            print(
+                f"Didn't find dtype combination {dynamic_quant_dtype_key} during "
+                f"dynamic quantized op lowering for {func_node.target}"
+            )
+            continue
+        (q_func, q_relu_func) = DYNAMIC_LOWER_FUNCTIONAL_MAP[func_node.target][
+            dynamic_quant_dtype_key
+        ]
+
+        if q_func is None or q_relu_func is None:
+            print(
+                "Didn't find corresponding quantized function or quantized relu function "
+                f"for {func_node.target}, {dynamic_quant_dtype_key}"
+            )
+            continue
+
+        # Step 2: Replace quantized weights with packed weights, which will be folded later
+        # Use the right prepack op and prepare the corresponding args
+        # Linear prepack args: (quantized weights[, bias])
+        # Conv prepack args: (quantized weights[, bias, stride, padding, dilation, groups])
+        prepack_args = [quantized_weight] + remaining_func_args
+        prepack_kwargs = {}
+        if func_node.target == F.linear:
+            prepack_op = get_linear_prepack_op_for_dtype(weight_dtype)
+            kwargs = func_node.kwargs.copy()
+            if "bias" in kwargs:
+                prepack_kwargs["B"] = kwargs["bias"]
+                del kwargs["bias"]
+                func_node.kwargs = kwargs
+        elif func_node.target in CONV_FUNCTIONAL_OPS:
+            prepack_op = get_qconv_prepack_op(func_node.target)
+            # For conv1d, the stride, padding, and dilation args may be ints,
+            # in which case we need to convert them to tuples
+            if func_node.target == F.conv1d:
+                for i in [2, 3, 4]:
+                    if len(prepack_args) > i and isinstance(prepack_args[i], int):
+                        prepack_args[i] = (prepack_args[i],)
+        else:
+            raise ValueError(f"Lowering is not supported for op '{func_node.target}'")
+        with model.graph.inserting_before(func_node):
+            packed_weight = model.graph.create_node(
+                "call_function", prepack_op, tuple(prepack_args), prepack_kwargs
+            )
+
+        # Step 3: Replace reference pattern with the corresponding quantized op
+        func_node.target = q_relu_func if relu_node is not None else q_func
+        if is_int8:
+            func_node.args = (pattern_input, packed_weight, reduce_range_node)
+        else:
+            func_node.args = (pattern_input, packed_weight)
+
+        if relu_node is not None:
+            relu_node.replace_all_uses_with(func_node)
+
+        # Step 4: Remove the relu node if it exists
+        if relu_node is not None:
+            model.graph.erase_node(relu_node)
+
+
+def _lower_quantized_binary_op(model: GraphModule, qconfig_map: dict[str, QConfigAny]):
+    binary_ops_to_lower: list[Callable] = [
+        operator.add,
+        torch.add,
+        operator.mul,
+        torch.mul,
+        torch.matmul,
+    ]
+    modules = dict(model.named_modules(remove_duplicate=False))
+    for n in model.graph.nodes:
+        # Step 0: Find nodes that match this pattern (dequantize - binary op - quantize)
+        (q_node, relu_node, bop_node) = _match_static_pattern(
+            n,
+            modules,
+            qconfig_map,
+            binary_ops_to_lower,
+            dequantize_node_arg_indices=[0, 1],
+        )
+        if q_node is None:
+            continue
+        assert bop_node is not None
+        (_, scale_node, zero_point_node, _) = q_node.args
+
+        # Step 1: Remove dequant nodes
+        num_dq_nodes = 0
+        for arg in bop_node.args:
+            if not is_dequantize_node(arg):
+                continue
+            dq_node = arg
+            assert isinstance(dq_node, Node)
+            dn_input = dq_node.args[0]
+            bop_node.replace_input_with(dq_node, dn_input)  # type: ignore[arg-type]
+            num_dq_nodes += 1
+        assert num_dq_nodes > 0
+
+        # Step 2: Swap binary op to quantized binary op
+        assert bop_node.target in QBIN_OP_MAPPING
+        binop_to_qbinop = QBIN_OP_MAPPING if relu_node is None else QBIN_RELU_OP_MAPPING
+        qbin_op = binop_to_qbinop[bop_node.target]
+        # prepare the args for quantized binary op
+        # (x, y)
+        qop_node_args = list(bop_node.args)
+        # (x, y, scale, zero_point)
+        # add scale and zero_point arguments for Tensor - Tensor operation
+        if num_dq_nodes == 2:
+            qop_node_args.extend([scale_node, zero_point_node])
+        # insert a call to quantized binary op and remove the original binary op
+        with model.graph.inserting_after(q_node):
+            qop_node = create_node_from_old_node_preserve_meta(
+                model.graph,
+                ("call_function", qbin_op, tuple(qop_node_args), {}),
+                bop_node,
+            )
+            q_node.replace_all_uses_with(qop_node)
+
+        # Step 3: Remove quantize node, binary op node, and relu node if any
+        model.graph.erase_node(q_node)
+        if relu_node is not None:
+            model.graph.erase_node(relu_node)
+        model.graph.erase_node(bop_node)
+
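+# Sketch of the binary-op rewrite (illustrative only):
+#   before: dequantize(a), dequantize(b) -> torch.add -> quantize_per_tensor(scale, zp)
+#   after:  torch.ops.quantized.add(a_q, b_q, scale, zp)
+# quantized.add_relu is used instead when a relu wraps the add.
+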
+
+def special_pattern_replacement(model: GraphModule):
+    modules = dict(model.named_modules(remove_duplicate=False))
+    for n in model.graph.nodes:
+        q_node = n
+        is_quantize = q_node.target == torch.quantize_per_tensor
+        is_to_fp16 = (
+            q_node.op == "call_method"
+            and q_node.target == "to"
+            and len(q_node.args) == 2
+            and q_node.args[1] == torch.float16
+        )
+        if not (is_quantize or is_to_fp16):
+            continue
+        ref_node = q_node.args[0]
+        # get output scale/zero_point/dtype from the quantize node
+        # ref_node, scale_node, zero_point_node, dtype = q_node.args
+        # TODO: add safety checks that the ref_node and dq_node each have exactly one user
+        is_call_function, is_call_method, is_call_module = is_fixed_qparams_node(
+            ref_node, modules
+        )
+        if is_to_fp16 and (is_call_function or is_call_method or is_call_module):
+            # TODO: add a warning or error out here? (bc-breaking if error out)
+            # warnings.warn(
+            #     "Only reference patterns are currently supported for {dtype} dtype with {op} op"
+            #     "".format(dtype=dtypes, op=ref_node))
+            continue
+
+        is_call_function, is_call_method, is_call_module = is_default_node(
+            ref_node, modules
+        )
+        if is_to_fp16 and (is_call_function or is_call_method or is_call_module):
+            # TODO: add a warning or error out here? (bc-breaking if error out)
+            continue
+
+        # This check includes all supported ops
+        is_call_function, is_call_method, is_call_module = is_special_pattern_node(
+            ref_node, modules
+        )
+        if not (is_call_module or is_call_function or is_call_method):
+            continue
+        assert len(ref_node.args) > 0 or len(ref_node.kwargs) > 0
+        dq_node_or_nodes = (
+            ref_node.args[0]
+            if len(ref_node.args) > 0
+            else next(iter(ref_node.kwargs.values()))
+        )
+        assert isinstance(dq_node_or_nodes, (Node, tuple, list))
+        is_dequantize = False
+        if isinstance(dq_node_or_nodes, Node):
+            is_dequantize = (
+                dq_node_or_nodes.op == "call_method"
+                and dq_node_or_nodes.target == "dequantize"
+            )
+        elif isinstance(dq_node_or_nodes, (tuple, list)):
+            is_dequantize = all(
+                x.op == "call_method" and x.target == "dequantize"
+                for x in dq_node_or_nodes
+            )
+
+        if not is_dequantize:
+            continue
+
+        # TODO: enable when we have patterns that need to swap the modules
+        if is_call_module:
+            ref_module = modules[ref_node.target]
+            if type(ref_module) in SPECIAL_PATTERN_LOWER_MODULE_MAP and is_quantize:
+                qmodule_cls = SPECIAL_PATTERN_LOWER_MODULE_MAP.get(type(ref_module))
+                scale_node = q_node.args[1]
+                zero_point_node = q_node.args[2]
+                output_scale = getattr(model, scale_node.target)
+                output_zero_point = getattr(model, zero_point_node.target)
+
+                qmodule = qmodule_cls.from_reference(  # type:ignore[union-attr]
+                    ref_module, output_scale, output_zero_point
+                )
+                # replace reference module with quantized module
+                parent_name, module_name = _parent_name(ref_node.target)
+                setattr(modules[parent_name], module_name, qmodule)
+
+        # reroute around dq node:
+        dq_nodes: list[Node] = []
+        if isinstance(dq_node_or_nodes, Node):
+            dq_nodes = [dq_node_or_nodes]
+        elif isinstance(dq_node_or_nodes, (tuple, list)):
+            dq_nodes = list(dq_node_or_nodes)
+
+        for dq_node in dq_nodes:
+            dn_input = dq_node.args[0]
+            ref_node.replace_input_with(dq_node, dn_input)
+
+        # store q node args
+        qnode_qparams = list(q_node.args)[1:]
+        # replace uses of q node with input and remove q node
+        q_node_input = q_node.args[0]
+        q_node.replace_all_uses_with(q_node_input)
+        model.graph.erase_node(q_node)
+
+        is_call_function, is_call_method, is_call_module = is_default_node(
+            ref_node, modules
+        )
+        if is_call_function:
+            # pass scale/zero_point arguments from quantize_per_tensor to the default node operator
+            # insert an op after the zero_point node so that the scale/zero_point
+            # nodes are available
+            qop = get_quantized_operator(ref_node.target)
+            args = list(ref_node.args)
+            kwargs = dict(ref_node.kwargs)
+            if qop in QOP_TO_ARG_NAMES_TO_SKIP:
+                args_to_skip = QOP_TO_ARG_NAMES_TO_SKIP[qop]
+                for arg in args_to_skip:
+                    if arg in kwargs:
+                        kwargs.pop(arg)
+            kwargs["output_scale"] = qnode_qparams[0]
+            kwargs["output_zero_point"] = qnode_qparams[1]
+            with model.graph.inserting_after(qnode_qparams[1]):
+                qop_node = create_node_from_old_node_preserve_meta(
+                    model.graph, ("call_function", qop, tuple(args), kwargs), ref_node
+                )
+                ref_node.replace_all_uses_with(qop_node)
+                model.graph.erase_node(ref_node)
+        else:
+            # remove scale/zero_point node for quantize node
+            for n in qnode_qparams:
+                if isinstance(n, Node):
+                    model.graph.erase_node(n)
+
+    return model
+
+
+def _lower_getattr_tensor_metadta_op(model: GraphModule):
+    """Modified the graph of the model inplace, to skip extra dequantize op before
+    the general tensor shape ops when possible
+    """
+    for n in model.graph.nodes:
+        if is_getattr_tensor_metadata_node(n):
+            maybe_dq = n.args[0]
+            if maybe_dq.op != "call_method" or maybe_dq.target != "dequantize":
+                continue
+            # skip the dequantize node
+            args = list(n.args)
+            args[0] = n.args[0].args[0]
+            n.args = tuple(args)
+
+
+def _lower_get_tensor_info_op(model: GraphModule):
+    """Modified the graph of the model inplace, to skip extra dequantize op before
+    the general tensor shape ops when possible
+    """
+    for n in model.graph.nodes:
+        if not is_get_tensor_info_node(n):
+            continue
+        maybe_dq = n.args[0]
+        if maybe_dq.op != "call_method" or maybe_dq.target != "dequantize":
+            continue
+        # skip the dequantize node
+        args = list(n.args)
+        args[0] = n.args[0].args[0]
+        n.args = tuple(args)
+
+
+def _lower_to_native_backend(
+    model: GraphModule,
+    qconfig_map: dict[str, QConfigAny],
+    node_name_to_scope: dict[str, tuple[str, type]],
+    keep_original_weights: bool = False,
+) -> GraphModule:
+    """Lower a quantized reference model (with reference quantized operator patterns)
+    to the native backend in PyTorch (fbgemm/qnnpack). Both backends share the same
+    operator signatures, so they can be lowered with the same function
+    """
+    _lower_static_weighted_ref_module(model, qconfig_map)
+    _lower_static_weighted_ref_module_with_two_inputs(model, qconfig_map)
+    _lower_dynamic_weighted_ref_module(model)
+    _lower_weight_only_weighted_ref_module(model)
+    _lower_static_weighted_ref_functional(model, qconfig_map)
+    _lower_dynamic_weighted_ref_functional(model, qconfig_map)
+    _lower_quantized_binary_op(model, qconfig_map)
+    _lower_getattr_tensor_metadta_op(model)
+    _lower_get_tensor_info_op(model)
+    special_pattern_replacement(model)
+    model.graph.eliminate_dead_code()
+    model = fold_weight(model, node_name_to_scope, keep_original_weights)
+    model.graph.eliminate_dead_code()
+    model.recompile()
+    model.graph.lint()
+    return model
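+
+
+# A minimal usage sketch (illustrative only, not part of the upstream file):
+# given a reference-quantized GraphModule `ref_model` produced by the FX convert
+# flow, lowering to fbgemm/qnnpack could look like
+#
+#     lowered = _lower_to_native_backend(ref_model, qconfig_map, node_name_to_scope)
+#
+# where `qconfig_map` and `node_name_to_scope` are assumed to be the bookkeeping
+# dicts built during prepare/convert, matching the signature above.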
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/__init__.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/detector.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/detector.py
new file mode 100644
index 0000000000000000000000000000000000000000..4625e287011c4f57e7669914b0da1bdf86d76fba
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/detector.py
@@ -0,0 +1,1733 @@
+# mypy: allow-untyped-defs
+from abc import ABC, abstractmethod
+from typing import Any, Callable
+
+import torch
+import torch.ao.nn.qat as nnqat
+import torch.nn as nn
+from torch.ao.quantization.fake_quantize import FakeQuantize
+from torch.ao.quantization.fx._equalize import (
+    default_equalization_qconfig,
+    EqualizationQConfig,
+)
+from torch.ao.quantization.fx._model_report.model_report_observer import (
+    ModelReportObserver,
+)
+from torch.ao.quantization.fx.graph_module import GraphModule
+from torch.ao.quantization.observer import (
+    _is_activation_post_process,
+    default_dynamic_quant_observer,
+    default_observer,
+    default_per_channel_weight_observer,
+    default_weight_observer,
+    ObserverBase,
+)
+from torch.ao.quantization.qconfig import (
+    _assert_valid_qconfig,
+    default_qconfig,
+    QConfig,
+)
+
+
+# Names for observer insert keys
+DETECTOR_TARGET_NODE_KEY = "target_node"
+DETECTOR_OBS_TO_INSERT_KEY = "observer_to_insert"
+DETECTOR_IS_POST_OBS_KEY = "is_post_observer"
+DETECTOR_OBS_ARGS_KEY = "observer_args"
+
+
+# Mapping related code
+class DetectorQConfigInfo:
+    r"""
+    This class contains the QConfig information for a single module.
+    The list of variables / values this contains can grow depending on the
+    extensibility of the qconfig mapping feature set, but this currently includes:
+    - if activation observer is dynamic
+    - if weight observer is per channel
+
+
+    Args:
+        module_fqn (str): The fully qualified name (fqn) of the module whose
+            qconfig this information is relevant to
+    """
+
+    def __init__(self, module_fqn: str):
+        super().__init__()
+        self.module_fqn = module_fqn
+
+        # populate this section with all the variables we might find important
+        # change from none if your detector is actually using this
+        self.is_activation_dynamic = False
+        self.is_weight_per_channel = False
+
+        # equalization related options
+        self.is_equalization_recommended = False
+
+    def generate_quantization_qconfig(self, module: torch.nn.Module) -> QConfig:
+        r"""
+        Args:
+            module (torch.nn.Module): The module we are generating
+            the qconfig for
+
+        Returns the generated quantization QConfig, chosen so that it is a valid configuration for the module
+        """
+        # Apply suggestions to new qconfig
+        module_qconfig = default_qconfig
+
+        # keep track of dynamic and per_channel recommendations
+        recommendations_list = []
+        # append as if a list of combinations
+        recommendations_list.append(
+            (self.is_activation_dynamic, self.is_weight_per_channel)
+        )
+        recommendations_list.append(
+            (self.is_activation_dynamic, False)
+        )  # only trying dynamic rec
+        recommendations_list.append(
+            (False, self.is_weight_per_channel)
+        )  # only trying per channel rec
+
+        # now we try each of the combinations
+        for rec in recommendations_list:
+            # rec[0] -> dynamic recommended
+            # rec[1] -> per channel recommended
+            activation = default_dynamic_quant_observer if rec[0] else default_observer
+            weight = (
+                default_per_channel_weight_observer
+                if rec[1]
+                else default_weight_observer
+            )
+            test_config = QConfig(activation, weight)
+            try:
+                _assert_valid_qconfig(test_config, module)
+                module_qconfig = test_config
+                break
+            except AssertionError:
+                # if not a valid configuration, we move on to the next one in priority
+                continue
+
+        # return the QConfig chosen
+        return module_qconfig
+
+    def generate_equalization_qconfig(self) -> EqualizationQConfig:
+        r"""
+        This returns the equalization configuration for a module.
+
+        For now, it just returns the default, but as more equalization options become
+        possible, this method can get more fleshed out with more nuanced granularity.
+
+
+        Returns the generated equalization QConfig for the module
+        """
+        # in this case, we just return default equalization config
+        # we know this is valid because only valid modules would even
+        # have this option
+        return default_equalization_qconfig
+
+
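+# Illustrative sketch (added for exposition, not part of the upstream module):
+# how a DetectorQConfigInfo translates its recommendations into a QConfig.
+# generate_quantization_qconfig tries (dynamic, per_channel), then dynamic only,
+# then per_channel only, and falls back to default_qconfig if none of the
+# combinations validate for the module. The fqn and Linear shape below are
+# hypothetical.
+def _example_qconfig_generation_sketch() -> QConfig:
+    info = DetectorQConfigInfo("block.0.linear")
+    info.is_activation_dynamic = True
+    info.is_weight_per_channel = True
+    return info.generate_quantization_qconfig(nn.Linear(8, 8))
+
+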
+# Adding base class for detectors
+class DetectorBase(ABC):
+    r"""Base Detector Module
+    Any detector class should derive from this class.
+
+    Concrete detectors should follow the same general API, which includes:
+    - A method to calculate and return observer insertion points
+        - Should return both the fqns and the Observer class to insert
+    - A method to return a report based on the detector
+        - Should return a str-based report and dict info in Tuple[str,Dict] format
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.detector_config_info = None
+
+    @abstractmethod
+    def determine_observer_insert_points(self, model) -> dict:
+        r"""
+        Args
+            model (nn.Module or subclass): model to find observer insertion points
+
+        Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict.
+            This dict maps string keys to detector specific information
+        """
+
+    @abstractmethod
+    def get_detector_name(self) -> str:
+        r"""Returns the name of the current detector"""
+
+    @abstractmethod
+    def get_qconfig_info(self, model) -> dict[str, DetectorQConfigInfo]:
+        r"""Returns the DetectorQConfigInfo for each module_fqn relevant
+        Args
+            model (nn.Module or subclass): model to find observer insertion points
+
+        Returns a Dict mapping from relevant module fqns to:
+            A DetectorQConfigInfo with the information to generate a QConfig for a specific module
+        """
+
+    def _get_targeting_node(
+        self, prepared_fx_model: GraphModule, target_fqn: str
+    ) -> torch.fx.node.Node:
+        r"""
+        Takes in a GraphModule and the target_fqn and finds the node whose target is this fqn.
+
+        If it's not found, the module is most likely inside a fused layer.
+            We go one level up in the fqn we are searching for until we find the parent node.
+            If we reach an fqn with no parent (no "."), then the target doesn't exist in the graph.
+
+        The reason for the recursion is that if the model that we are looking for got fused,
+        we will have module fqn as e.g. x.linear.0 but the graph will only have a node for the fused module,
+        which would have fqn as x.linear so they will not match.
+        To handle this, if we don't match, we then take off the last bit of the fqn e.g. x.linear.0 -> x.linear,
+        or more generally foo.bar.baz -> foo.bar and search again, this will allow us to locate the correct module
+        even in cases with fusion
+
+        Args:
+            prepared_fx_model (GraphModule):  The prepared Fx GraphModule
+            target_fqn (str): The fqn of the layer we are trying to target
+
+        Returns the node object we are trying to add observers around
+        """
+        for node in prepared_fx_model.graph.nodes:
+            # if the node's target is our target, return it
+            if node.target == target_fqn:
+                return node
+
+        # getting here means node not found
+        # if no "." we are already at base and failed
+        parent_fqn_sep_index = target_fqn.rfind(".")
+        if parent_fqn_sep_index == -1:
+            raise ValueError("passed in target_fqn not found in graph's targets.")
+        else:
+            # recursively call it with parent fqn
+            return self._get_targeting_node(
+                prepared_fx_model, target_fqn[:parent_fqn_sep_index]
+            )
+
+    @abstractmethod
+    def generate_detector_report(self, model) -> tuple[str, dict[str, Any]]:
+        r"""
+        Args
+            model (nn.Module or subclass): model to find observer insertion points
+
+        Returns a Tuple of two elements:
+            Str: string report of the suggested improvements
+            Dict: contains useful data collected by the observer pertinent to this report
+        """
+
+
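+# Minimal sketch of a concrete detector (added for exposition, not part of the
+# upstream module): it satisfies the DetectorBase API by inserting no observers
+# and producing an empty report. The class name is hypothetical.
+class _NoOpDetectorSketch(DetectorBase):
+    def determine_observer_insert_points(self, model) -> dict:
+        # no observers are needed for this trivial detector
+        return {}
+
+    def get_detector_name(self) -> str:
+        return "no_op_detector_sketch"
+
+    def get_qconfig_info(self, model) -> dict[str, DetectorQConfigInfo]:
+        # no module-specific qconfig recommendations
+        return {}
+
+    def generate_detector_report(self, model) -> tuple[str, dict[str, Any]]:
+        return ("No suggestions generated by this sketch.\n", {})
+
+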
+class PerChannelDetector(DetectorBase):
+    r"""This class is used to detect if any Linear or Conv layers in a model utilize per_channel quantization.
+    Only Linear and Conv layers can use per_channel as of now so only these two are currently checked.
+
+    per_channel quantization can lead to major benefits in the form of accuracy.
+    Therefore, if the backend used by the user supports it, it is recommended to use it.
+
+    Args:
+        backend (str, optional): the backend the user wishes to use in production
+            Default value is current torch.backends.quantized.engine
+    """
+
+    # Keys for return dictionary
+    BACKEND_KEY = "backend"
+    PER_CHAN_SUPPORTED_KEY = "per_channel_quantization_supported"
+    PER_CHAN_USED_KEY = "per_channel_quantization_used"
+
+    # Default map for representing supported per channel quantization modules for different backends
+    DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES: dict[str, set[Any]] = {
+        "fbgemm": {
+            nn.Linear,
+            nn.Conv1d,
+            nn.Conv2d,
+            nn.Conv3d,
+            nnqat.Linear,
+            nnqat.Conv1d,
+            nnqat.Conv2d,
+            nnqat.Conv3d,
+        },
+        "qnnpack": {
+            nn.Linear,
+            nn.Conv1d,
+            nn.Conv2d,
+            nn.Conv3d,
+            nnqat.Linear,
+            nnqat.Conv1d,
+            nnqat.Conv2d,
+            nnqat.Conv3d,
+        },
+        "onednn": {
+            nn.Linear,
+            nn.Conv1d,
+            nn.Conv2d,
+            nn.Conv3d,
+            nnqat.Linear,
+            nnqat.Conv1d,
+            nnqat.Conv2d,
+            nnqat.Conv3d,
+        },
+        "x86": {
+            nn.Linear,
+            nn.Conv1d,
+            nn.Conv2d,
+            nn.Conv3d,
+            nnqat.Linear,
+            nnqat.Conv1d,
+            nnqat.Conv2d,
+            nnqat.Conv3d,
+        },
+    }
+
+    def __init__(self, backend: str = torch.backends.quantized.engine):
+        super().__init__()
+
+        # store the backend information
+        self.backend_chosen = backend
+        self.supported_modules = set()
+        if self.backend_chosen in self.DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES:
+            self.supported_modules = self.DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES[
+                self.backend_chosen
+            ]
+        else:
+            raise ValueError(
+                f"Not configured to work with {self.backend_chosen}. Try a different default backend"
+            )
+
+    def get_detector_name(self) -> str:
+        r"""returns the string name of this detector"""
+        return "per_channel_detector"
+
+    def get_qconfig_info(self, model) -> dict[str, DetectorQConfigInfo]:
+        r"""Returns the DetectorQConfigInfo for each module_fqn relevant
+        Args
+            model (nn.Module or subclass): model to find observer insertion points
+
+        Returns a Dict mapping from relevant module fqns to:
+            A DetectorQConfigInfo with the information to generate a QConfig for a specific module
+        """
+        # run the helper function to populate the dictionary
+        per_channel_info = self._detect_per_channel_helper(model)
+
+        # we actually have a qconfig info object we are populating
+        module_fqn_to_detector_qconfig_info = {}
+
+        for module_fqn in per_channel_info:
+            # create a detector info instance
+            detector_qconfig_info = DetectorQConfigInfo(module_fqn)
+
+            # see if per channel quantization is supported
+            per_chan_supported: bool = per_channel_info[module_fqn][
+                self.PER_CHAN_SUPPORTED_KEY
+            ]
+            detector_qconfig_info.is_weight_per_channel = per_chan_supported
+            module_fqn_to_detector_qconfig_info[module_fqn] = detector_qconfig_info
+
+        return module_fqn_to_detector_qconfig_info
+
+    def determine_observer_insert_points(self, model: nn.Module) -> dict:
+        r"""
+        There are no observers inserted for the PerChannelDetector.
+
+        Returns an empty dictionary since no observers are added or needed
+        """
+        return {}
+
+    def _detect_per_channel_helper(self, model: nn.Module):
+        r"""
+        Determines if per_channel quantization is supported in modules and submodules.
+
+        Returns a dictionary used by the higher level report and qconfig generation functions.
+        Each entry maps the fully-qualified-name to information on whether per_channel quantization is supported and currently used.
+
+        Args:
+            model: The current module that is being checked to see if it is per_channel quantizable
+
+        Returns dictionary mapping fqns to if per_channel quantization is possible
+        """
+        # create dict we will return
+        per_channel_info: dict = {}
+
+        # get the fully qualified name and check if in list of modules to include and list of modules to ignore
+        for fqn, module in model.named_modules():
+            is_in_include_list = any(
+                isinstance(module, x) for x in self.supported_modules
+            )
+
+            # check if the module per_channel is supported
+            # based on backend
+            per_channel_supported = False
+
+            if is_in_include_list:
+                per_channel_supported = True
+
+                # assert statement for MyPy
+                q_config_file = module.qconfig
+                assert isinstance(q_config_file, QConfig)
+
+                # this object should either be fake quant or observer
+                q_or_s_obj = module.qconfig.weight.p.func()
+                assert isinstance(q_or_s_obj, (FakeQuantize, ObserverBase))
+
+                per_channel_used = False  # will be true if found in qconfig
+
+                if hasattr(
+                    q_or_s_obj, "ch_axis"
+                ):  # then we know that per_channel quantization used
+                    # all fake quants have channel axis so need to check is_per_channel
+                    if isinstance(q_or_s_obj, FakeQuantize):
+                        if (
+                            hasattr(q_or_s_obj, "is_per_channel")
+                            and q_or_s_obj.is_per_channel
+                        ):
+                            per_channel_used = True
+                    elif isinstance(q_or_s_obj, ObserverBase):
+                        # should be an observer otherwise
+                        per_channel_used = True
+                    else:
+                        raise ValueError("Should be either observer or fake quant")
+
+                per_channel_info[fqn] = {
+                    self.PER_CHAN_SUPPORTED_KEY: per_channel_supported,
+                    self.PER_CHAN_USED_KEY: per_channel_used,
+                    self.BACKEND_KEY: self.backend_chosen,
+                }
+
+        return per_channel_info
+
+    def generate_detector_report(self, model: nn.Module) -> tuple[str, dict[str, Any]]:
+        r"""Checks if any Linear or Conv layers in the model utilize per_channel quantization.
+        Only Linear and Conv layers can use per_channel as of now so only these two are currently checked.
+
+        Looks at q_config format and backend to determine if per_channel can be utilized.
+        Uses the DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES structure to determine support
+
+        Args:
+            model: The prepared and calibrated model we want to check if using per_channel
+
+        Returns a tuple with two elements:
+            String report of potential actions to improve model (if per_channel quantization is available in backend)
+            Dictionary mapping per_channel quantizable elements to:
+                whether per_channel quantization is supported by the backend
+                if it is being utilized in the current model
+        """
+
+        # run the helper function to populate the dictionary
+        per_channel_info = self._detect_per_channel_helper(model)
+
+        # String to let the user know of further optimizations
+        further_optims_str = (
+            f"Further Optimizations for backend {self.backend_chosen}: \n"
+        )
+
+        optimizations_possible = False
+        for fqn in per_channel_info:
+            fqn_dict = per_channel_info[fqn]
+            if (
+                fqn_dict[self.PER_CHAN_SUPPORTED_KEY]
+                and not fqn_dict[self.PER_CHAN_USED_KEY]
+            ):
+                optimizations_possible = True
+                further_optims_str += (
+                    f"Module {fqn} can be configured to use per_channel quantization.\n"
+                )
+
+        if optimizations_possible:
+            further_optims_str += "To use per_channel quantization, make sure the qconfig has a per_channel weight observer."
+        else:
+            further_optims_str += "No further per_channel optimizations possible."
+
+        # return the string and the dictionary form of same information
+        return (further_optims_str, per_channel_info)
+
+
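+# Usage sketch (added for exposition, not part of the upstream module): the
+# detector only reads module.qconfig, so a qconfig must be attached before
+# generating a report. The toy model and the "fbgemm" backend are assumptions.
+def _example_per_channel_report_sketch() -> str:
+    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Linear(8, 4))
+    for submodule in model.modules():
+        submodule.qconfig = default_qconfig
+    detector = PerChannelDetector(backend="fbgemm")
+    report_str, _per_channel_info = detector.generate_detector_report(model)
+    return report_str
+
+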
+class DynamicStaticDetector(DetectorBase):
+    r"""
+    Determines whether dynamic or static quantization is more appropriate for a given module.
+
+    Takes advantage of the ModelReportObserver that records range information.
+    Stationary distributions of data are strictly above the tolerance level for the comparison statistic:
+
+        S = average_batch_activation_range/epoch_activation_range
+
+    Nonstationary distributions are below or at the tolerance level for this metric.
+
+    If the distribution of data right after the module is non-stationary, recommend dynamic quantization
+        Otherwise recommend static quantization
+
+    Args:
+        tolerance (float, optional): The threshold above which the S metric is considered stationary, and non-stationary otherwise. Default: 0.5
+    """
+
+    # names for the pre and post observers that are inserted
+    DEFAULT_PRE_OBSERVER_NAME = "model_report_pre_observer"
+    DEFAULT_POST_OBSERVER_NAME = "model_report_post_observer"
+
+    # naming conventions for stationary vs non-stationary data
+    STATIONARY_STR = "stationary"
+    NON_STATIONARY_STR = "non-stationary"
+
+    # naming for activation
+    INPUT_ACTIVATION_PREFIX = "input_activation_"
+    OUTPUT_ACTIVATION_PREFIX = "output_activation_"
+
+    # naming conventions for the keys of the return module info
+    TOLERANCE_KEY = "dynamic_static_tolerance"
+    DEFAULT_DYNAMIC_REC_KEY = "dynamic_recommended"
+    PRE_OBS_COMP_STAT_KEY = INPUT_ACTIVATION_PREFIX + "dynamic_static_comp_stat"
+    POST_OBS_COMP_STAT_KEY = OUTPUT_ACTIVATION_PREFIX + "dynamic_static_comp_stat"
+    PRE_OBS_DATA_DIST_KEY = (
+        INPUT_ACTIVATION_PREFIX + "dynamic_static_data_classification"
+    )
+    POST_OBS_DATA_DIST_KEY = (
+        OUTPUT_ACTIVATION_PREFIX + "dynamic_static_data_classification"
+    )
+    IS_CURRENTLY_SUPPORTED_KEY = "is_dynamic_supported"
+
+    # modules that are supported both dynamic and static for this report function
+    DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED = {nn.Linear}
+
+    # modules that will be supported soon for both
+    DEFAULT_DYNAMIC_STATIC_FUTURE_SUPPORTED = {nn.Conv1d, nn.Conv2d, nn.Conv3d}
+
+    def __init__(self, tolerance=0.5):
+        super().__init__()
+
+        # set tolerance level and initialize a set to keep track of useful fqn locations
+        self.tolerance = tolerance
+        self.useful_observer_fqns: set[str] = set()
+
+    def determine_observer_insert_points(
+        self, prepared_fx_model: GraphModule
+    ) -> dict[str, dict[str, Any]]:
+        r"""
+        Determines where observers need to be inserted for the Dynamic vs Static detector.
+        For this detector, we want to place observers on either side of linear layers in the model.
+
+        Currently inserts observers for:
+            linear layers
+
+        Args:
+            prepared_fx_model (GraphModule):  The prepared Fx GraphModule
+
+        Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict with:
+            key "target_node" -> the node we are trying to observe with this observer (torch.fx.node.Node)
+            key "observer_to_insert" -> the observer we wish to insert (ObserverBase)
+            key "is_post_observer" -> True if this is meant to be a post-observer for target_node, False if pre-observer
+            key "observer_args" -> The arguments that are meant to be passed into the observer
+        """
+
+        # observer for this detector is ModelReportObserver
+        obs_ctr = ModelReportObserver
+
+        # return dict
+        obs_fqn_to_info: dict[str, dict[str, Any]] = {}
+
+        for fqn, module in prepared_fx_model.named_modules():
+            # make sure module is supported
+            if self._is_supported(module, insert=True):
+                # if it's a supported type, we want to get node and add observer insert locations
+                targeted_node = self._get_targeting_node(prepared_fx_model, fqn)
+
+                # add entry for pre-observer
+                pre_obs_fqn = fqn + "." + self.DEFAULT_PRE_OBSERVER_NAME
+
+                obs_fqn_to_info[pre_obs_fqn] = {
+                    DETECTOR_TARGET_NODE_KEY: targeted_node,
+                    DETECTOR_OBS_TO_INSERT_KEY: obs_ctr(),
+                    DETECTOR_IS_POST_OBS_KEY: False,
+                    DETECTOR_OBS_ARGS_KEY: targeted_node.args,
+                }
+
+                # add entry for post-observer
+                post_obs_fqn = fqn + "." + self.DEFAULT_POST_OBSERVER_NAME
+
+                obs_fqn_to_info[post_obs_fqn] = {
+                    DETECTOR_TARGET_NODE_KEY: targeted_node,
+                    DETECTOR_OBS_TO_INSERT_KEY: obs_ctr(),
+                    DETECTOR_IS_POST_OBS_KEY: True,
+                    DETECTOR_OBS_ARGS_KEY: (targeted_node,),
+                }
+
+        return obs_fqn_to_info
+
+    def get_detector_name(self) -> str:
+        r"""returns the string name of this detector"""
+        return "dynamic_vs_static_detector"
+
+    def get_qconfig_info(self, model) -> dict[str, DetectorQConfigInfo]:
+        r"""Returns the DetectorQConfigInfo for each module_fqn relevant
+        Args
+            model (nn.Module or subclass): model to find observer insertion points
+
+        Returns a Dict mapping from relevant module fqns to:
+            A DetectorQConfigInfo with the information to generate a QConfig for a specific module
+        """
+        # run the helper function to populate the dictionary
+        dynamic_static_info = self._generate_dict_info(model)
+
+        # we actually have a qconfig info object we are populating
+        module_fqn_to_detector_qconfig_info = {}
+
+        for module_fqn in dynamic_static_info:
+            # create a detector info instance
+            detector_qconfig_info = DetectorQConfigInfo(module_fqn)
+
+            # see if dynamic quantization is recommended
+            dynamic_static_recommended: bool = dynamic_static_info[module_fqn][
+                self.DEFAULT_DYNAMIC_REC_KEY
+            ]
+            detector_qconfig_info.is_activation_dynamic = dynamic_static_recommended
+            module_fqn_to_detector_qconfig_info[module_fqn] = detector_qconfig_info
+
+        return module_fqn_to_detector_qconfig_info
+
+    def _is_supported(self, module: nn.Module, insert: bool = False) -> bool:
+        r"""Returns whether the given module is supported for observers
+
+        Args
+            module: The module to check and ensure is supported
+            insert: True if this check is for observer insertion, False if it is for report generation
+
+        Returns True if the module is supported by observer, False otherwise
+        """
+        # check to see if module is of a supported type
+        is_supported_type = any(
+            isinstance(module, x) for x in self.DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED
+        )
+
+        # check if it will be supported
+        future_supported_type = any(
+            isinstance(module, x) for x in self.DEFAULT_DYNAMIC_STATIC_FUTURE_SUPPORTED
+        )
+
+        # supported
+        supported = is_supported_type or future_supported_type
+
+        # this check is for observer insertion
+        if insert:
+            return supported
+        else:
+            # this is for report gen and we also need to check if it contains observers
+            has_obs = hasattr(module, self.DEFAULT_PRE_OBSERVER_NAME) and hasattr(
+                module, self.DEFAULT_POST_OBSERVER_NAME
+            )
+            return supported and has_obs
+
+    def _generate_dict_info(self, model: GraphModule) -> dict[str, Any]:
+        r"""
+        Helper function for generate_detector_report that does the generation of the dictionary.
+        This process is done as specified in generate_detector_report documentation
+
+        Args:
+            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers
+
+        Returns a Dictionary mapping modules with ModelReportObservers around them to:
+                whether dynamic quantization is recommended
+                their S metric of input to module
+                whether input to module is stationary or non-stationary
+                their S metric of output of module
+                whether output of module is stationary or non-stationary
+                the tolerance level used to decide whether input/output is stationary or non-stationary
+                whether it is currently supported or planned for the future
+        """
+        # store modules dynamic vs static information
+        module_dynamic_static_info = {}
+
+        # This for loop goes through the modules, and extracts all relevant information into module_dynamic_static_info
+        #   This information primarily includes whether the data distributions around a supported module are stationary or not
+        #   Based on this, it is recorded whether dynamic or static quantization is recommended
+
+        # loop through all submodules, including nested ones
+        for fqn, module in model.named_modules():
+            # if module is supported and has the ModelReportObserver attached to it
+            if self._is_supported(module):
+                # get pre and post observers for the module
+                pre_obs = getattr(module, self.DEFAULT_PRE_OBSERVER_NAME)
+                post_obs = getattr(module, self.DEFAULT_POST_OBSERVER_NAME)
+
+                # get the statistics for each module
+                pre_stat = pre_obs.get_batch_to_epoch_ratio()
+                post_stat = post_obs.get_batch_to_epoch_ratio()
+
+                # record module, pre and post stat, and whether to do dynamic or static based off it
+                # true if post observer data distribution is non-stationary, false if it's stationary
+                dynamic_recommended = post_stat <= self.tolerance
+
+                # specify the classifications for whether data distributions considered stationary or non-stationary
+                pre_obs_dist_classif = (
+                    self.STATIONARY_STR
+                    if pre_stat > self.tolerance
+                    else self.NON_STATIONARY_STR
+                )
+                post_obs_dist_classif = (
+                    self.STATIONARY_STR
+                    if post_stat > self.tolerance
+                    else self.NON_STATIONARY_STR
+                )
+
+                # check if current support or future support
+                is_supported_type = any(
+                    isinstance(module, x)
+                    for x in self.DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED
+                )
+
+                # store the set of important information for this module
+                module_info = {
+                    self.TOLERANCE_KEY: self.tolerance,
+                    self.DEFAULT_DYNAMIC_REC_KEY: dynamic_recommended,
+                    self.PRE_OBS_COMP_STAT_KEY: pre_stat,
+                    self.PRE_OBS_DATA_DIST_KEY: pre_obs_dist_classif,
+                    self.POST_OBS_COMP_STAT_KEY: post_stat,
+                    self.POST_OBS_DATA_DIST_KEY: post_obs_dist_classif,
+                    self.IS_CURRENTLY_SUPPORTED_KEY: is_supported_type,
+                }
+
+                module_dynamic_static_info[fqn] = module_info
+
+        return module_dynamic_static_info
+
+    def generate_detector_report(
+        self, model: GraphModule
+    ) -> tuple[str, dict[str, Any]]:
+        r"""
+        Determines whether dynamic or static quantization is more appropriate for a given module.
+
+        Takes advantage of the ModelReportObserver that records range information.
+        Stationary distributions of data are strictly above the tolerance level for the comparison statistic:
+
+            S = average_batch_activation_range/epoch_activation_range
+
+        Nonstationary distributions are below or at the tolerance level for this metric.
+
+        If the distribution of data right after the module is non-stationary, recommend dynamic quantization
+            Otherwise recommend static quantization
+
+        This will then generate suggestions for dynamic vs static quantization focused around Linear.
+
+        Args:
+            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers
+
+        Returns a tuple with two elements:
+            String report of whether dynamic or static quantization is recommended for certain modules
+            Dictionary mapping modules with ModelReportObservers around them to:
+                whether dynamic quantization is recommended
+                their S metric of input to module
+                whether input to module is stationary or non-stationary
+                their S metric of output of module
+                whether output of module is stationary or non-stationary
+                the tolerance level used to decide whether input/output is stationary or non-stationary
+                whether it is currently supported or planned for the future
+        """
+
+        # get the dictionary of the information to format the string report
+        module_dynamic_static_info = self._generate_dict_info(model)
+
+        dynamic_vs_static_string = "Dynamic vs. Static Quantization suggestions: \n"
+
+        modules_added: bool = False  # check to make sure at least 1 module added.
+
+        dynamic_benefit = (
+            " You will get more accurate results if you use dynamic quantization"
+        )
+        static_benefit = (
+            " You can increase model efficiency if you use static quantization"
+        )
+        future_support_str = (
+            ". This layer is not yet supported for dynamic quantization"
+        )
+        # This for loop goes through the information collected in module_dynamic_static_info and:
+        #   Populates the string based report with the information from module_dynamic_static_info
+        #   Compiles the complete report by appending relevant formatted strings
+
+        for module_fqn in module_dynamic_static_info.keys():
+            # there is at least 1 module for suggestion
+            modules_added = True
+            module_info = module_dynamic_static_info[module_fqn]
+            suggestion_string_template = (
+                "For module {} it is suggested to use {} quantization because {}.\n"
+            )
+
+            # decide what string formatting values will be
+            quantization_type = ""
+            quantization_reasoning = "the distribution of data before {} is {} and the distribution after is {}."
+
+            benefit_str = ""
+
+            # strings for if dynamic quantized per tensor is needed
+            recommend_per_tensor = (
+                ". We recommend to add a {} before this module if it is static."
+            )
+            rec_lay_to_add = "dynamic quantize per tensor layer"
+            dynamic_per_tensor_string = recommend_per_tensor.format(rec_lay_to_add)
+            dynamic_per_tensor_reasoning_string = " This is because the input to this module has a non-stationary distribution"
+
+            # start composing explanation
+            if module_info[self.DEFAULT_DYNAMIC_REC_KEY]:
+                quantization_type = "dynamic"
+                # check if currently supported or future supported
+                benefit_str = dynamic_benefit
+                if not module_info[self.IS_CURRENTLY_SUPPORTED_KEY]:
+                    benefit_str += future_support_str
+            else:
+                quantization_type = "static"
+                benefit_str = static_benefit
+
+            # now set the quantization explanation string
+            quantization_reasoning = (
+                quantization_reasoning.format(
+                    module_fqn,
+                    module_info[self.PRE_OBS_DATA_DIST_KEY],
+                    module_info[self.POST_OBS_DATA_DIST_KEY],
+                )
+                + benefit_str
+            )
+
+            # if we have a non-stationary input -> linear -> stationary pattern, we suggest static
+            # however, we also want to recommend adding a dynamic quantize per tensor layer right before the module if this change is made
+            if (
+                module_info[self.PRE_OBS_DATA_DIST_KEY] == self.NON_STATIONARY_STR
+                and module_info[self.POST_OBS_DATA_DIST_KEY] == self.STATIONARY_STR
+            ):
+                quantization_reasoning = (
+                    quantization_reasoning
+                    + dynamic_per_tensor_string
+                    + dynamic_per_tensor_reasoning_string
+                )
+
+            # format the overall suggestion string with the specific inputs
+            module_suggestion_string = suggestion_string_template.format(
+                module_fqn, quantization_type, quantization_reasoning
+            )
+
+            # append to overall suggestion
+            dynamic_vs_static_string += module_suggestion_string
+
+        if not modules_added:
+            dynamic_vs_static_string += "No applicable layers for suggestions. Only linear and conv are valid.\n"
+
+        # return the string as well as the dictionary of information
+        return (dynamic_vs_static_string, module_dynamic_static_info)
+
+
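+# Numerical sketch of the comparison statistic S used above (added for
+# exposition, not part of the upstream module). The two random batches and the
+# 0.5 tolerance mirror the detector's default but are assumptions.
+def _example_dynamic_static_stat_sketch(tolerance: float = 0.5) -> bool:
+    batches = [torch.randn(16, 8), 3.0 * torch.randn(16, 8)]
+    batch_ranges = torch.stack([b.max() - b.min() for b in batches])
+    epoch_max = torch.stack([b.max() for b in batches]).max()
+    epoch_min = torch.stack([b.min() for b in batches]).min()
+    # S = average_batch_activation_range / epoch_activation_range
+    s = batch_ranges.mean() / (epoch_max - epoch_min)
+    # dynamic quantization is recommended when S is at or below tolerance
+    return bool(s <= tolerance)
+
+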
+class InputWeightEqualizationDetector(DetectorBase):
+    r"""
+    Determines whether input-weight equalization can help improve quantization for certain modules.
+
+    Specifically, this list of modules includes:
+        linear
+        conv
+
+    Determines whether input-weight equalization is recommended based on the comp stat:
+        s_c = sqrt(w_c/W)/sqrt(i_c/I)
+        where:
+            w_c is range of weight for channel c, W is range of weight over all channels
+            i_c is range of input for channel c, I is range of input over all channels
+
+        if threshold <= s_c <= 1 / threshold, input-weight equalization is recommended
+
+    Args:
+        ratio_threshold (float): The threshold for s_c to determine if input-weight equalization is suggested
+            Should be between 0 and 1 (both non-inclusive)
+        ch_axis (int, optional): The channel axis being observed to determine input weight equalization
+            Default: 1
+
+    * :attr:`ratio_threshold`: The threshold for s_c to determine if input-weight equalization is suggested
+        Should be between 0 and 1
+
+    * :attr:`ch_axis`: The channel axis being observed to determine input weight equalization
+
+    * :attr:`SUPPORTED_MODULES`: This specifies the modules that are supported for input-weight equalization
+
+    * :attr:`DEFAULT_PRE_OBSERVER_NAME`: The name of the pre-observer to be inserted for this detector
+    """
+
+    SUPPORTED_MODULES: set[Callable] = {
+        nn.Linear,
+        nn.Conv1d,
+        nn.Conv2d,
+        nn.Conv3d,
+        nnqat.Linear,
+        nnqat.Conv1d,
+        nnqat.Conv2d,
+        nnqat.Conv3d,
+    }
+
+    # names for the pre and post observers that are inserted
+    DEFAULT_PRE_OBSERVER_NAME: str = "model_report_pre_observer"
+
+    # weight / activation prefix for each of the below info
+    WEIGHT_PREFIX = "weight_"
+    ACTIVATION_PREFIX = "input_activation_"
+
+    # string names for keys of info dictionaries
+    PER_CHANNEL_MAX_KEY = "per_channel_max"
+    PER_CHANNEL_MIN_KEY = "per_channel_min"
+    GLOBAL_MAX_KEY = "global_max"
+    GLOBAL_MIN_KEY = "global_min"
+
+    # keys for return dict of recommendations
+    RECOMMENDED_KEY = "input_weight_equalization_recommended"
+    COMP_METRIC_KEY = "input_weight_channel_comparison_metrics"
+    THRESHOLD_KEY = "input_weight_threshold"
+    CHANNEL_KEY = "input_weight_channel_axis"
+
+    # default weight and info strings
+    WEIGHT_STR = "weight"
+    INPUT_STR = "input"
+
+    # default for what ratio we recommend input weight
+    DEFAULT_RECOMMEND_INPUT_WEIGHT_CHANNEL_RATIO = 0.4
+
+    def __init__(self, ratio_threshold: float, ch_axis: int = 1):
+        # ensure passed in inputs are valid
+        if ratio_threshold <= 0 or ratio_threshold >= 1:
+            raise ValueError("Make sure threshold is > 0 and < 1")
+
+        # initialize attributes based on args
+        self.ratio_threshold: float = ratio_threshold
+        self.ch_axis: int = ch_axis
+
+    def _is_supported(self, module: nn.Module, insert: bool = False) -> bool:
+        r"""Returns whether the given module is supported for observers
+
+        Args
+            module: The module to check and ensure is supported
+            insert: True if this check is for observer insertion, False if it is for report generation
+
+        Returns True if the module is supported by observer, False otherwise
+        """
+        # check to see if module is of a supported type
+        is_supported_type = any(type(module) is x for x in self.SUPPORTED_MODULES)
+
+        # this check is for observer insertion
+        if insert:
+            return is_supported_type
+        else:
+            # this is for report gen and we also need to check if it contains observers
+            has_obs = hasattr(module, self.DEFAULT_PRE_OBSERVER_NAME)
+            return is_supported_type and has_obs
+
+    def get_qconfig_info(self, model) -> dict[str, DetectorQConfigInfo]:
+        r"""Returns the DetectorQConfigInfo for each module_fqn relevant
+        Args
+            model (nn.Module or subclass): model to find observer insertion points
+
+        Returns a Dict mapping from relevant module fqns to:
+            A DetectorQConfigInfo with the information to generate a QConfig for a specific module
+        """
+        # run the helper function to populate the dictionary
+        # find the range of inputs
+        input_values: dict[str, dict] = self._extract_input_info(model)
+
+        # find the range of weights
+        weight_values: dict[str, dict] = self._extract_weight_info(model)
+
+        # calculate per_channel comparison statistic s_c
+        comp_stats: dict[str, torch.Tensor] = self._generate_comparison_values(
+            input_values, weight_values
+        )
+
+        # generate the return dictionary
+        input_weight_equalization_info: dict[str, dict] = self._generate_dict_info(
+            input_values, weight_values, comp_stats
+        )
+
+        # we actually have a qconfig info object we are populating
+        module_fqn_to_detector_qconfig_info = {}
+
+        for module_fqn in input_weight_equalization_info:
+            # create a detector info instance
+            detector_qconfig_info = DetectorQConfigInfo(module_fqn)
+
+            # see if input-weight equalization is recommended
+            input_weight_recommended: bool = input_weight_equalization_info[module_fqn][
+                self.RECOMMENDED_KEY
+            ]
+            detector_qconfig_info.is_equalization_recommended = input_weight_recommended
+            module_fqn_to_detector_qconfig_info[module_fqn] = detector_qconfig_info
+
+        return module_fqn_to_detector_qconfig_info
+
+    def determine_observer_insert_points(
+        self, prepared_fx_model: GraphModule
+    ) -> dict[str, dict[str, Any]]:
+        r"""Determines where observers need to be inserted for the Input Weight Equalization Detector.
+        For this detector, we want to place observers in front of supported layers.
+
+        Currently inserts observers for:
+            linear layers
+            conv layers
+
+        Args:
+            prepared_fx_model (GraphModule):  The prepared Fx GraphModule
+
+        Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict with:
+            key "target_node" -> the node we are trying to observe with this observer (torch.fx.node.Node)
+            key "observer_to_insert" -> the observer we wish to insert (ObserverBase)
+            key "is_post_observer" -> True if this is meant to be a post-observer for target_node, False if pre-observer
+            key "observer_args" -> The arguments that are meant to be passed into the observer
+        """
+
+        # observer for this detector is ModelReportObserver
+        obs_ctr = ModelReportObserver
+
+        # return dict
+        obs_fqn_to_info: dict[str, dict[str, Any]] = {}
+
+        for fqn, module in prepared_fx_model.named_modules():
+            # check to see if module is of a supported type
+            if self._is_supported(module, insert=True):
+                # if it's a supported type, we want to get node and add observer insert locations
+                targeted_node = self._get_targeting_node(prepared_fx_model, fqn)
+
+                # add entry for pre-observer
+                pre_obs_fqn = fqn + "." + self.DEFAULT_PRE_OBSERVER_NAME
+
+                obs_fqn_to_info[pre_obs_fqn] = {
+                    DETECTOR_TARGET_NODE_KEY: targeted_node,
+                    DETECTOR_OBS_TO_INSERT_KEY: obs_ctr(ch_axis=self.ch_axis),
+                    DETECTOR_IS_POST_OBS_KEY: False,
+                    DETECTOR_OBS_ARGS_KEY: targeted_node.args,
+                }
+
+        return obs_fqn_to_info
+
+    def get_detector_name(self) -> str:
+        r"""Returns the name of this detector"""
+        return "input_weight_equalization_detector"
+
+    def _extract_input_info(self, model: GraphModule) -> dict[str, dict]:
+        r"""
+        Takes in a calibrated GraphModule and then finds the relevant observers.
+        It then extracts the input information for each observer and returns it.
+
+        Args
+            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers
+
+        Returns a dict mapping relevant module fqns (str) to a dict with keys:
+            "input_activation_per_channel_max" : maps to the per_channel max values
+            "input_activation_per_channel_min" : maps to the per_channel min values
+            "input_activation_global_max" : maps to the global max recorded
+            "input_activation_global_min" : maps to the global min recorded
+        """
+
+        # return dictionary mapping observer fqns to desired info
+        input_info: dict[str, dict] = {}
+
+        for fqn, module in model.named_modules():
+            # if module is supported and it has a pre-observer
+            if self._is_supported(module):
+                # get pre observer for the module
+                pre_obs = getattr(module, self.DEFAULT_PRE_OBSERVER_NAME)
+
+                input_info[fqn] = {
+                    self.ACTIVATION_PREFIX + self.PER_CHANNEL_MAX_KEY: pre_obs.max_val,
+                    self.ACTIVATION_PREFIX + self.PER_CHANNEL_MIN_KEY: pre_obs.min_val,
+                    self.ACTIVATION_PREFIX + self.GLOBAL_MAX_KEY: max(pre_obs.max_val),
+                    self.ACTIVATION_PREFIX + self.GLOBAL_MIN_KEY: min(pre_obs.min_val),
+                }
+
+        return input_info
+
+    def _extract_weight_info(self, model: GraphModule) -> dict[str, dict]:
+        r"""
+        Takes in a calibrated GraphModule and then finds the relevant observers.
+        It then extracts the weight information for each layer an observer is attached to.
+
+        Args
+            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers
+
+        Returns a dict mapping module fqns (str) to a dict with keys:
+            "per_channel_max" : maps to the per_channel max values
+            "per_channel_min" : maps to the per_channel min values
+            "global_max" : maps to the global max recorded
+            "global_min" : maps to the global min recorded
+        """
+        # return dictionary mapping observer fqns to desired info
+        weight_info: dict[str, dict] = {}
+
+        for fqn, module in model.named_modules():
+            # if module is supported and it has a pre-observer
+            if self._is_supported(module):
+                # we don't need actual observer, just the module weights
+                # calculate min and max vals
+                device = module.weight.device
+                min_val: torch.Tensor = torch.tensor([float("inf")], device=device)
+                max_val: torch.Tensor = torch.tensor([float("-inf")], device=device)
+                x_copy = module.weight
+                x_dim = x_copy.size()
+
+                new_axis_list = [i for i in range(len(x_dim))]  # noqa: C416
+                new_axis_list[self.ch_axis] = 0
+                new_axis_list[0] = self.ch_axis
+                y = x_copy.permute(new_axis_list)
+
+                # Need to match dtype of min/max because the updates to buffers
+                # are done in place and types need to match for comparisons
+                y = y.to(min_val.dtype)
+                y = torch.flatten(y, start_dim=1)
+                if min_val.numel() == 0 or max_val.numel() == 0:
+                    min_val, max_val = torch.aminmax(y, dim=1)
+                else:
+                    min_val_cur, max_val_cur = torch.aminmax(y, dim=1)
+                    min_val = torch.min(min_val_cur, min_val)
+                    max_val = torch.max(max_val_cur, max_val)
+
+                weight_info[fqn] = {
+                    self.WEIGHT_PREFIX + self.PER_CHANNEL_MAX_KEY: max_val,
+                    self.WEIGHT_PREFIX + self.PER_CHANNEL_MIN_KEY: min_val,
+                    self.WEIGHT_PREFIX + self.GLOBAL_MAX_KEY: max(max_val),
+                    self.WEIGHT_PREFIX + self.GLOBAL_MIN_KEY: min(min_val),
+                }
+
+        return weight_info
+
+    def _calculate_range_ratio(
+        self, info_dict: dict, info_str: str, module_fqn: str
+    ) -> torch.Tensor:
+        r"""
+        Takes in an info dict and calculates the per-channel range ratio (per-channel range divided by global range).
+
+        Args:
+            info_dict (dict): A dictionary of either input or weight range info
+            info_str (str): A str describing whether currently looking at weight or input info
+                Either "weight" or "input"
+            module_fqn (str): The fqn of the module we are looking at
+
+        Returns a tensor of values, where each value is the range ratio for a different channel
+        """
+        # calculate the ratios of the info
+        # get the prefix str
+        prefix_str = (
+            self.ACTIVATION_PREFIX if info_str == self.INPUT_STR else self.WEIGHT_PREFIX
+        )
+
+        per_channel_range = (
+            info_dict[prefix_str + self.PER_CHANNEL_MAX_KEY]
+            - info_dict[prefix_str + self.PER_CHANNEL_MIN_KEY]
+        )
+        global_range = (
+            info_dict[prefix_str + self.GLOBAL_MAX_KEY]
+            - info_dict[prefix_str + self.GLOBAL_MIN_KEY]
+        )
+
+        if global_range == 0:
+            range_zero_explanation = "We recommend removing this channel as it doesn't provide any useful information."
+            raise ValueError(
+                f"The range of the {info_str} data for module {module_fqn} is 0, "
+                f"which means you have a constant value channel. {range_zero_explanation}"
+            )
+
+        ratio = per_channel_range / global_range
+
+        return ratio
+
+    def _generate_comparison_values(
+        self, input_info: dict, weight_info: dict
+    ) -> dict[str, torch.Tensor]:
+        r"""
+        Takes in the information on the min and max values of the inputs and weights and:
+            Calculates the comp stat for each channel: s_c = sqrt(w_c/W)/sqrt(i_c/I)
+
+        Args:
+            input_info (dict): A dict mapping each observer to input range information
+            weight_info (dict): A dict mapping each observer to weight range information
+
+        Returns a dict mapping relevant module fqns (str) to a 1-D tensor.
+            Each value is a different s_c value for a different channel
+        """
+        # create return dictionary for each observer
+        module_fqn_to_channel: dict[str, torch.Tensor] = {}
+
+        # for each module (both passed in dicts should have same keys)
+        for module_fqn in input_info:
+            # raise error if not in weight info
+            if module_fqn not in weight_info:
+                raise KeyError(
+                    f"Unable to find weight range stats for module {module_fqn}"
+                )
+
+            # calculate the ratios of the weight info and input info
+            weight_ratio = self._calculate_range_ratio(
+                weight_info[module_fqn], self.WEIGHT_STR, module_fqn
+            )
+            input_ratio = self._calculate_range_ratio(
+                input_info[module_fqn], self.INPUT_STR, module_fqn
+            )
+
+            # if mismatched size, because of grouping, we want to replicate weight enough times
+            weight_channels = len(weight_ratio)
+            input_channels = len(input_ratio)
+            if weight_channels != input_channels:
+                # we try to replicate
+                assert input_channels % weight_channels == 0, (
+                    "input channels should be divisible by weight channels."
+                )
+                # get replication factor
+                rep_factor: int = input_channels // weight_channels
+
+                # weight ratio is (n,), input ratio is (k,), we just repeat weight ratio k // n times
+                weight_ratio = weight_ratio.repeat(rep_factor)
+
+            # calculate the s metric per channel
+            s = torch.sqrt(weight_ratio) / torch.sqrt(input_ratio)
+            module_fqn_to_channel[module_fqn] = s
+
+        # return compiled observer ratios
+        return module_fqn_to_channel
+
+    def _generate_dict_info(
+        self, input_info: dict, weight_info: dict, comp_stats: dict
+    ) -> dict[str, dict]:
+        r"""
+        Helper function for generate_detector_report that does the generation of the dictionary.
+        This process is done as specified in generate_detector_report documentation
+
+        Args:
+            input_info (dict): A dict mapping each module to input range information
+            weight_info (dict): A dict mapping each module to weight range information
+            comp_stats (dict): A dict mapping each module to its corresponding comp stat
+
+        Returns a dictionary mapping each module with relevant ModelReportObservers around them to:
+            whether input weight equalization is recommended
+            their s_c metric compared to the threshold
+            the threshold used to make the recommendation
+            the channel used for recording data
+            the input channel range info
+            the weight channel range info
+        """
+        # store modules input weight equalization info
+        input_weight_equalization_info: dict[str, dict] = {}
+
+        # for each module we add separate set of suggestions
+        for module_fqn in input_info:
+            # get relevant info for this module
+            mod_input_info: dict = input_info[module_fqn]
+            mod_weight_info: dict = weight_info[module_fqn]
+            mod_comp_stat: torch.Tensor = comp_stats[module_fqn]
+
+            # decide if each channel should have input weight equalization or not
+            channel_rec_vals: list = []
+
+            for val in mod_comp_stat:
+                float_rep: float = val.item()
+
+                # decide if recommending input weight equalization
+                recommended: bool = (
+                    float_rep >= self.ratio_threshold
+                    and float_rep <= 1 / self.ratio_threshold
+                )
+                channel_rec_vals.append(recommended)
+
+            # build the return dict input
+            # also unpack input and weight dicts into it
+            input_weight_equalization_info[module_fqn] = {
+                self.RECOMMENDED_KEY: channel_rec_vals,
+                self.COMP_METRIC_KEY: mod_comp_stat,
+                self.THRESHOLD_KEY: self.ratio_threshold,
+                self.CHANNEL_KEY: self.ch_axis,
+                **mod_input_info,
+                **mod_weight_info,
+            }
+
+        # return our compiled info for each module
+        return input_weight_equalization_info
+
+    def generate_detector_report(
+        self, model: GraphModule
+    ) -> tuple[str, dict[str, Any]]:
+        r"""
+        Determines whether input weight equalization is appropriate for a given module.
+
+        Takes advantage of the ModelReportObserver which records per channel information of input range.
+        It then uses the passed in weight info in conjunction with this to compute the desired ratio.
+        Finally, it gives suggestions based on this information for each module of interest.
+
+        Args:
+            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers
+
+        Returns a tuple with two elements:
+            String report of whether input weight equalization is recommended for certain modules
+            Dictionary mapping modules of interest to:
+                whether input weight equalization is recommended
+                their s_c metric compared to the threshold
+                the threshold used to make the recommendation
+                the channel used for recording data
+                the input channel range info
+                the weight channel range info
+        """
+
+        # find the range of inputs
+        input_values: dict[str, dict] = self._extract_input_info(model)
+
+        # find the range of weights
+        weight_values: dict[str, dict] = self._extract_weight_info(model)
+
+        # calculate per_channel comparison statistic s_c
+        comp_stats: dict[str, torch.Tensor] = self._generate_comparison_values(
+            input_values, weight_values
+        )
+
+        # generate the return dictionary
+        input_weight_equalization_info: dict[str, dict] = self._generate_dict_info(
+            input_values, weight_values, comp_stats
+        )
+
+        # now we can generate report based on this information
+        input_weight_string = "Input-Weight Equalization suggestions: \n"
+
+        # some strings to be formatted depending on module we are adding
+        module_suggestion_str = "For Module {} looked at with axis {}: \n"
+        channel_suggestion_str = (
+            "\tWe suggest {} input weight equalization because {}\n"
+        )
+        use_str = "to use"
+        no_use_str = "to not use"
+        input_weight_benefit_str = "{}/{} channels would benefit and we expect significant reduction in quantization error."
+        input_weight_non_benefit_reasoning = (
+            "{}/{} channels benefitting from input-weight equalization being applied."
+        )
+        input_weight_non_benefit_str = "we don't expect much improvement from input-weight equalization based on {}"
+
+        # added module check
+        added_module: bool = False
+
+        # compile the suggestion string
+        for module_fqn in input_weight_equalization_info:
+            # we added at least 1 module
+            added_module = True
+            # add the module level description
+            input_weight_string += module_suggestion_str.format(
+                module_fqn, self.ch_axis
+            )
+
+            mod_info: dict[str, Any] = input_weight_equalization_info[module_fqn]
+
+            # gather info on how many channels would benefit from input weight equalization
+            recommendation_per_channel: list = mod_info[self.RECOMMENDED_KEY]
+            num_recs = sum(recommendation_per_channel)
+
+            if (
+                num_recs / len(recommendation_per_channel)
+                >= self.DEFAULT_RECOMMEND_INPUT_WEIGHT_CHANNEL_RATIO
+            ):
+                input_benefit_formatted = input_weight_benefit_str.format(
+                    num_recs, len(recommendation_per_channel)
+                )
+                channel_str = channel_suggestion_str.format(
+                    use_str, input_benefit_formatted
+                )
+                input_weight_string += channel_str
+            else:
+                non_benefit_reason_formatted = (
+                    input_weight_non_benefit_reasoning.format(
+                        num_recs, len(recommendation_per_channel)
+                    )
+                )
+                non_benefit_str = input_weight_non_benefit_str.format(
+                    non_benefit_reason_formatted
+                )
+                channel_str = channel_suggestion_str.format(no_use_str, non_benefit_str)
+                input_weight_string += channel_str
+
+        # if no modules looked at, amend return string
+        if not added_module:
+            input_weight_string += (
+                "No applicable layers for suggestions. Only linear and conv valid.\n"
+            )
+
+        # return a tuple with the string explanation and the compiled dict info
+        return (input_weight_string, input_weight_equalization_info)
+
+
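+# Numerical sketch of the per-channel comparison statistic s_c used above
+# (added for exposition, not part of the upstream module). The toy weight,
+# calibration activations and channel axis below are assumptions.
+def _example_input_weight_comp_stat_sketch() -> torch.Tensor:
+    weight = torch.randn(8, 4)   # hypothetical Linear weight, channels on axis 1
+    inputs = torch.randn(32, 4)  # hypothetical calibration inputs, channels on axis 1
+    w_c = weight.amax(dim=0) - weight.amin(dim=0)  # per-channel weight range
+    w_global = weight.max() - weight.min()         # global weight range W
+    i_c = inputs.amax(dim=0) - inputs.amin(dim=0)  # per-channel input range
+    i_global = inputs.max() - inputs.min()         # global input range I
+    s_c = torch.sqrt(w_c / w_global) / torch.sqrt(i_c / i_global)
+    # equalization is suggested for channels where
+    # ratio_threshold <= s_c <= 1 / ratio_threshold (threshold in (0, 1))
+    return s_c
+
+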
+class OutlierDetector(DetectorBase):
+    r"""
+    Determines whether there are significant outliers in activation data around a certain layer.
+
+    This is ideally used in conjunction with information on stationary vs. non-stationary distribution:
+        If the data is stationary, and there are significant outliers, then we want to flag them
+        We want to do this on a per channel basis for detecting outliers
+
+    Determines whether activation data is flagged as outlier based on if data is stationary and:
+        p_r = avg(100th percentile / "reference_percentile"th percentile)
+        where:
+            p_r is average percentile ratio across all batches in the epoch
+            reference_percentile is a percentile value between 0 and 1 exclusive
+
+        if p_r is above some threshold, then we consider the activations to have significant outliers
+
+    Args:
+        ratio_threshold (float, optional): The threshold for p_r to determine if there are outliers in activations
+            Should be >= 1
+            Default: 3.5
+        reference_percentile (float, optional): The denominator to find the relative scale of the 100th percentile
+            Should be between 0 and 1
+            Default: 0.975
+        fraction_batches_used_threshold (float, optional): Threshold of the fraction of batches per channel used to determine outliers
+            If the fraction is below this, we deem the number of samples used to calculate outliers insignificant and alert the user
+            to take a closer look at the channel results, regardless of whether we detected outliers in the channel
+            Should be between 0 and 1
+            Default: 0.95
+        ch_axis (int, optional): The channel axis being observed to determine outliers
+            Default: 1
+
+    * :attr:`ratio_threshold`: The threshold for p_r to determine if there are outliers in activations
+        The p_r value (average ratio of 100th percentile/reference_percentile) is compared to ratio_threshold
+        If it is significantly greater, then we consider it an outlier
+        This threshold was calculated based on the ratio of the percentiles in a normal distribution
+        The calculations behind value choice: https://drive.google.com/file/d/1N2wdtXWI-kOH8S7HH4-PYB_NmqzZil4p/view?usp=sharing
+
+    * :attr:`reference_percentile`: The denominator of the top fraction to find the relative scale of the 100th percentile
+        Should be between 0 and 1
+        The calculations behind value choice: https://drive.google.com/file/d/1N2wdtXWI-kOH8S7HH4-PYB_NmqzZil4p/view?usp=sharing
+
+    * :attr:`fraction_batches_used_threshold`: The fraction of batches to determine outliers for each channel should be above this
+        Some batches may not be counted for a channel when its percentile values are zero, so this is to ensure a good amount of the total batches are used
+        Should be between 0 and 1
+
+    * :attr:`ch_axis`: The channel axis being observed to determine outliers
+
+    * :attr:`DEFAULT_PRE_OBSERVER_NAME`: The name of the pre-observer to be inserted for this detector
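+
+    Example (an illustrative sketch, not part of the detector API: it shows how the per-channel
+    p_r ratio described above could be computed by hand for a single hypothetical batch, using the
+    default reference_percentile of 0.975; the detector itself gathers these statistics through
+    ModelReportObserver instances):
+        >>> # xdoctest: +SKIP
+        >>> import torch
+        >>> x = torch.randn(8, 4, 16)  # hypothetical (batch, channel, features) activation
+        >>> per_channel = x.transpose(0, 1).flatten(1)  # shape (channels, batch * features)
+        >>> hundredth = torch.quantile(per_channel, 1.0, dim=1)  # 100th percentile per channel
+        >>> reference = torch.quantile(per_channel, 0.975, dim=1)  # reference percentile per channel
+        >>> p_r_single_batch = hundredth / reference  # the observer averages this over batches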
+    """
+
+    # names for the pre observers that are inserted
+    DEFAULT_PRE_OBSERVER_NAME: str = "model_report_pre_observer"
+
+    # pre activation prefix
+    INPUT_ACTIVATION_PREFIX = "input_activation_"
+
+    # names for dict keys
+    OUTLIER_KEY = "outliers_detected"
+    NUM_BATCHES_KEY = "outlier_detection_batches_used"
+    IS_SUFFICIENT_BATCHES_KEY = "outlier_detection_is_sufficient_batches"
+    COMP_METRIC_KEY = "outlier_detection_percentile_ratios"
+    RATIO_THRES_KEY = "outlier_detection_ratio_threshold"
+    REF_PERCENTILE_KEY = "outlier_detection_reference_percentile"
+    CHANNEL_AXIS_KEY = "outlier_detection_channel_axis"
+    MAX_VALS_KEY = INPUT_ACTIVATION_PREFIX + "per_channel_max"
+    CONSTANT_COUNTS_KEY = "constant_batch_counts"
+
+    def __init__(
+        self,
+        ratio_threshold: float = 3.5,
+        reference_percentile: float = 0.975,
+        fraction_batches_used_threshold: float = 0.95,
+        ch_axis: int = 1,
+    ):
+        # initialize the variables of interest
+        self.ratio_threshold = ratio_threshold
+
+        # make sure passed in percentile is valid
+        assert reference_percentile >= 0 and reference_percentile <= 1
+        assert (
+            fraction_batches_used_threshold >= 0
+            and fraction_batches_used_threshold <= 1
+        )
+        self.reference_percentile = reference_percentile
+        self.fraction_batches_used_threshold = fraction_batches_used_threshold
+        self.ch_axis = ch_axis
+
+    def get_detector_name(self) -> str:
+        r"""Returns the name of this detector"""
+        return "outlier_detector"
+
+    def _supports_insertion(self, module: nn.Module) -> bool:
+        r"""Returns whether the given module is supported for observers insertion
+
+        Any module that doesn't have children and isn't an observer itself is supported
+
+        Args
+            module: The module to check and ensure is supported
+
+        Returns True if the module is supported by observer, False otherwise
+        """
+        # case for insertion of module
+        # check if the module has any children and isn't observer
+        num_children = len(list(module.children()))
+        return num_children == 0 and not _is_activation_post_process(module)
+
+    def get_qconfig_info(self, model) -> dict[str, DetectorQConfigInfo]:
+        r"""Returns the DetectorQConfigInfo for each module_fqn relevant
+        Args
+            model (nn.Module or subclass): model to find observer insertion points
+
+        Returns a Dict mapping from unique observer fqns (where we want to insert them) to:
+            A DetectorQConfigInfo with the information to generate a QConfig for a specific module
+        """
+        # currently doesn't do anything for outlier detector
+        return {}
+
+    def _supports_report_gen(self, module: nn.Module) -> bool:
+        r"""Returns whether the given module is supported for report generation
+
+        Any module that has a model report pre-observer is supported
+
+        Args
+            module: The module to check and ensure is supported
+
+        Returns True if the module is supported by observer, False otherwise
+        """
+        return hasattr(module, self.DEFAULT_PRE_OBSERVER_NAME)
+
+    def determine_observer_insert_points(
+        self, prepared_fx_model: GraphModule
+    ) -> dict[str, dict[str, Any]]:
+        r"""Determines where observers need to be inserted for the Outlier Detector.
+
+        For this detector, we want to place observers in front of supported layers.
+
+        Currently inserts observers for:
+            all layers that do not have children (leaf level layers)
+
+        Args:
+            prepared_fx_model (GraphModule):  The prepared Fx GraphModule
+
+        Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict with:
+            key "target_node" -> the node we are trying to observe with this observer (torch.fx.node.Node)
+            key "observer_to_insert" -> the observer we wish to insert (ObserverBase)
+            key "is_post_observer" -> True if this is meant to be a post-observer for target_node, False if pre-observer
+            key "observer_args" -> The arguments that are meant to be passed into the observer
+        """
+        # observer for this detector is ModelReportObserver
+        obs_ctr = ModelReportObserver
+
+        # return dict
+        obs_fqn_to_info: dict[str, dict[str, Any]] = {}
+
+        for fqn, module in prepared_fx_model.named_modules():
+            # check to see if module is of a supported type
+            if self._supports_insertion(module):
+                # if it's a supported type, we want to get node and add observer insert locations
+                targeted_node = self._get_targeting_node(prepared_fx_model, fqn)
+
+                # add entry for pre-observer
+                pre_obs_fqn = fqn + "." + self.DEFAULT_PRE_OBSERVER_NAME
+
+                obs_fqn_to_info[pre_obs_fqn] = {
+                    DETECTOR_TARGET_NODE_KEY: targeted_node,
+                    DETECTOR_OBS_TO_INSERT_KEY: obs_ctr(
+                        ch_axis=self.ch_axis, comp_percentile=self.reference_percentile
+                    ),
+                    DETECTOR_IS_POST_OBS_KEY: False,
+                    DETECTOR_OBS_ARGS_KEY: targeted_node.args,
+                }
+
+        return obs_fqn_to_info
+
+    def _calculate_outlier_info(
+        self,
+        percentile_ratios: torch.Tensor,
+        counted_batches: torch.Tensor,
+        total_batches: int,
+    ) -> dict[str, list[bool]]:
+        r"""
+        Gives info on whether the percentile ratios calculated would be considered outliers
+        Also gives information on whether the collected data is statistically significant to make this claim
+
+        Args:
+            percentile_ratios (torch.Tensor): The average percentile_ratios per channel calculated by the observer
+            counted_batches (torch.Tensor): The number of batches used for average calculation per tensor
+            total_batches (int): The total number of batches that passed through observer in this epoch
+
+        Returns a dictionary mapping:
+            "outliers_detected" : list of bools per channel that are true if it is considered an outlier
+            "is_sufficient_batches": if o_r was >= fraction_batches_used_threshold:
+                where o_r = counted_batches / total_batches
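+
+        A toy illustration of the thresholding (assuming the default ratio_threshold=3.5 and
+        fraction_batches_used_threshold=0.95): with percentile_ratios=[1.2, 5.0],
+        counted_batches=[10, 10] and total_batches=10, the result would be
+        outliers_detected=[False, True] and is_sufficient_batches=[True, True].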
+        """
+        outlier_dict: dict[str, list[bool]] = {
+            self.OUTLIER_KEY: [],
+            self.IS_SUFFICIENT_BATCHES_KEY: [],
+        }
+
+        # get both as flattened lists for easy mapping
+        ratios_list: list = percentile_ratios.tolist()
+        num_batches_list: list = counted_batches.tolist()
+
+        # calculate whether channels were statistically significant
+        significant_size = [
+            batch_size / total_batches >= self.fraction_batches_used_threshold
+            for batch_size in num_batches_list
+        ]
+        outlier_dict[self.IS_SUFFICIENT_BATCHES_KEY] = significant_size
+
+        # calculate for each channel whether it's an outlier or not based on ratio
+        outlier_detected = [ratio > self.ratio_threshold for ratio in ratios_list]
+        outlier_dict[self.OUTLIER_KEY] = outlier_detected
+
+        # return the dictionary with the two lists
+        return outlier_dict
+
+    def _generate_info_dict(self, model: GraphModule) -> dict[str, dict]:
+        r"""
+        Helper function for generate_detector_report that does the generation of the dictionary.
+        This process is done as specified in generate_detector_report documentation
+
+        Args:
+            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers
+
+        Returns a dict mapping relevant module fqns to:
+            whether there were outliers found in the activations before the module
+            the number of batches used for each channel
+            whether fraction of applicable batches used is above fraction_batches_used_threshold
+            their p_r metric compared to the threshold
+            the threshold used to make the recommendation
+            the reference_percentile used to make the recommendation
+            the channel axis used to determine individual channels
+            the constant batch counts per channel
+            the per channel max values
+        """
+        # return dictionary mapping observer fqns to desired info
+        info_dict: dict[str, dict] = {}
+
+        for fqn, module in model.named_modules():
+            # if module is supported and it has a pre-observer
+            if self._supports_report_gen(module):
+                # get pre observer for the module
+                pre_obs: ModelReportObserver = getattr(
+                    module, self.DEFAULT_PRE_OBSERVER_NAME
+                )
+
+                # get the number of batches and calculated ratio thresholds
+                num_batches: torch.Tensor = pre_obs.percentile_batches_tracked
+                average_ratios: torch.Tensor = pre_obs.average_percentile_ratio
+                channel_batch_cnts: torch.Tensor = pre_obs.constant_channels
+                total_batches: int = pre_obs.num_batches_tracked
+
+                # also get the max values
+                max_vals: torch.Tensor = pre_obs.max_val
+
+                # we have to specifically modify how we record negative ratios for pre-relu layers
+                for index, ratio_val in enumerate(average_ratios):
+                    # check if we have a negative ratio
+                    # a ratio might be negative if we have a situation where the 100th percentile is
+                    # > 0 while the nth percentile is < 0, in which case this would not be detected
+                    # as an outlier. Since we care more about magnitude, we make it positive.
+                    if ratio_val.item() < 0:
+                        # first make it positive and keep the flipped value for the check below
+                        ratio_val = -ratio_val
+                        average_ratios[index] = ratio_val
+
+                    if ratio_val.item() < 1:
+                        # if it's less than 1 we have to flip it as well, since we only care about magnitude
+                        average_ratios[index] = 1 / ratio_val
+
+                outlier_calcs = self._calculate_outlier_info(
+                    average_ratios, num_batches, total_batches
+                )
+
+                # calculate whether ratios were outliers
+                info_dict[fqn] = {
+                    self.CHANNEL_AXIS_KEY: self.ch_axis,
+                    self.REF_PERCENTILE_KEY: self.reference_percentile,
+                    self.RATIO_THRES_KEY: self.ratio_threshold,
+                    self.COMP_METRIC_KEY: average_ratios,
+                    self.NUM_BATCHES_KEY: num_batches,
+                    self.OUTLIER_KEY: outlier_calcs[self.OUTLIER_KEY],
+                    self.IS_SUFFICIENT_BATCHES_KEY: outlier_calcs[
+                        self.IS_SUFFICIENT_BATCHES_KEY
+                    ],
+                    self.CONSTANT_COUNTS_KEY: channel_batch_cnts,
+                    self.MAX_VALS_KEY: max_vals,
+                }
+
+        return info_dict
+
+    def generate_detector_report(
+        self, model: GraphModule
+    ) -> tuple[str, dict[str, Any]]:
+        r"""
+        Determines whether there are significant outliers in the activation data around the modules of interest.
+
+        Takes advantage of the ModelReportObserver which records the relevant percentile information
+
+        Args:
+            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers
+
+        Returns a tuple with two elements:
+            String report of whether there are outliers in the activations around certain modules
+            Dictionary mapping modules of interest to:
+                whether there were outliers found in the activations before the module
+                the number of batches used for each channel
+                whether fraction of applicable batches used is above fraction_batches_used_threshold
+                their p_r metric compared to the threshold
+                the threshold used to make the recommendation
+                the reference_percentile used to make the recommendation
+                the channel axis used to determine individual channels
+                the constant batch counts per channel
+                the per channel max values
+        """
+        # generate the information dictionary of outlier information
+        info_dict = self._generate_info_dict(model)
+
+        # now we can generate report based on this information
+        outlier_string = "Outlier detection report: \n"
+
+        # added module check
+        added_module: bool = False
+
+        # some strings to be formatted depending on module we are adding
+        module_suggestion_str = "For Module {} looked at with axis {}: \n"
+        channel_suggestion_str = "\tFor channel {}, we found outliers in the preceding activation data with {}.\n"
+        channel_max_value_str = "a max value across all batches of {}"
+        note_string = "Note: outlier detection is only reliable for {}. We recommend {} to ensure the most accurate results."
+        note_distribution = "stationary distributions"
+        note_rec = "running the static vs. dynamic detector to ensure activation data before modules above is stationary"
+
+        # suggestion for constant batch check since that can make it no outliers
+        constant_str = "\tFor channel {}, we found {} constant value batches. {}\n"
+        constant_suggestion = "We recommend taking a look at the dict and data to see how frequent this occurred and why."
+
+        # compile the suggestion string
+        for module_fqn in info_dict:
+            # get module specific info
+            mod_info: dict[str, Any] = info_dict[module_fqn]
+            # check to see if we already added high level model desc
+            added_model_desc = False
+            # look at each individual channel and add a suggestion
+            for index, outlier_detected in enumerate(mod_info[self.OUTLIER_KEY]):
+                if outlier_detected:
+                    # we found at least 1 outlier
+                    if not added_model_desc:
+                        # add the module level description
+                        outlier_string += module_suggestion_str.format(
+                            module_fqn, self.ch_axis
+                        )
+                        added_model_desc = True
+
+                    # we mark that we found at least one outlier
+                    added_module = True
+                    max_value_found_str = channel_max_value_str.format(
+                        mod_info[self.MAX_VALS_KEY][index]
+                    )
+                    channel_str = channel_suggestion_str.format(
+                        index, max_value_found_str
+                    )
+                    outlier_string += channel_str
+
+                # also check if we found constant batch
+                if mod_info[self.CONSTANT_COUNTS_KEY][index] != 0:
+                    # make sure we add a module level highlight.
+                    if not added_model_desc:
+                        # add the module level description
+                        outlier_string += module_suggestion_str.format(
+                            module_fqn, self.ch_axis
+                        )
+                        added_model_desc = True
+
+                    constant_values_for_channel = mod_info[self.CONSTANT_COUNTS_KEY][
+                        index
+                    ]
+                    formatted_str = constant_str.format(
+                        index, constant_values_for_channel, constant_suggestion
+                    )
+                    outlier_string += formatted_str
+                    # we also added at least one thing to description
+                    added_module = True
+
+        # if found outlier, give suggestion, else give default response
+        if added_module:
+            # compose the note string
+            note_composed = note_string.format(note_distribution, note_rec)
+            outlier_string += note_composed
+        else:
+            outlier_string += "There were no outliers found in the activations.\n"
+
+        return (outlier_string, info_dict)
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/model_report.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/model_report.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6d2ce755123ea80ac69ed41b575e420e5a5619e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/model_report.py
@@ -0,0 +1,664 @@
+# mypy: allow-untyped-defs
+from collections import OrderedDict
+from typing import Any, Callable
+
+import torch
+from torch.ao.quantization.fx._equalize import EqualizationQConfig
+from torch.ao.quantization.fx._model_report.detector import (
+    DETECTOR_IS_POST_OBS_KEY,
+    DETECTOR_OBS_ARGS_KEY,
+    DETECTOR_OBS_TO_INSERT_KEY,
+    DETECTOR_TARGET_NODE_KEY,
+    DetectorBase,
+    DetectorQConfigInfo,
+)
+from torch.ao.quantization.fx._model_report.model_report_visualizer import (
+    ModelReportVisualizer,
+)
+from torch.ao.quantization.fx.graph_module import GraphModule
+from torch.ao.quantization.observer import ObserverBase
+from torch.ao.quantization.qconfig_mapping import QConfig, QConfigMapping
+
+
+class ModelReport:
+    r"""
+    The ModelReport class aims to provide users an easy way to diagnose issues that they run into
+    with their models. The class works with all traceable GraphModules to help diagnose issues,
+    though the requirements on the type of model more-so depends on the specific report the user
+    is trying to generate. With respect to the reports, the ModelReport class is initialized with
+    a set of Detector classes, each of which generates reports on quantization configuration
+    issues a user might have.
+
+    Currently supports generating reports on:
+    - Suggestions for per-channel vs. per-tensor quantization (nn.Module)
+    - Suggestions for dynamic vs static quantization for linear layers (Graph Modules)
+    - Suggestions for input-weight equalization for linear and conv layers (Graph Modules)
+    - Suggestions for outlier detection for all layers (Graph Modules)
+
+    The ModelReport class has the primary functionality of inserting observers (primarily the ModelReportObserver)
+    where needed for each detector to gather the information it needs, and then after calibration, the ModelReport
+    class compiles the report generated by each Detector class into a single report to return to the user. It also
+    has the capability to remove all the observers it inserted as well.
+
+    * :attr:`_model` The model we wish to generate the report for. Must be a traceable GraphModule
+
+    * :attr:`_desired_report_detectors` The set of Detectors representing desired reports from the ModelReport class
+        Make sure that these are all unique types of detectors [do not have more than 1 of the same class]
+
+    * :attr:`_desired_detector_names` The set of detector names of the _desired_report_detectors.
+        This set is generated by calling the get_detector_name() of each detector
+
+    * :attr:`_detector_name_to_observer_fqns` The mapping from each detector to fqns of observers of interest
+        The purpose of this is to keep track of what observers were inserted for each detector, so that they
+        can be removed at the end if desired
+
+    * :attr:`_prepared_flag` A boolean flag that keeps track of whether we have prepared the model or not
+        This is to ensure we only insert observers once with the ModelReport instance
+
+    * :attr:`_removed_observers` A boolean to track if we have removed observers already
+        The purpose is to ensure we don't attempt to remove observers twice with the same ModelReport
+        instance. This also allows the functionality where we can generate the report multiple times
+        as long as we haven't removed the observers yet.
+
+    Note:
+        This class was initially designed to work with the Fx Graph Mode workflow in mind. However,
+        full functionality is available as long as there is a traceable GraphModule that is being used.
+        One method to get a traceable GraphModule without going through the Fx workflow is to use
+        the QuantizationTracer class.
+
+    General Flow for Fx workflow:
+    1.) Initialize ModelReport object with reports of interest by passing in initialized detector objects and model
+    2.) Prepare your model with prepare_fx
+    3.) Call model_report.prepare_detailed_calibration to add relevant observers
+    4.) Calibrate your model with data
+    5.) Call model_report.generate_model_report on your model to generate the report and optionally remove added observers
+    Optional
+        6.) Call model_report.generate_visualizer to get a ModelReportVisualizer instance
+        7.) To help in parsing report information and debugging, view report info as a:
+            - Table
+            - Histogram
+            - Line plot
+    8.) Call model_report.generate_qconfig_mapping to generate the qconfigs based on the report suggestions
+
+    Example (with QuantizationTracer):
+        >>> # xdoctest: +SKIP
+        >>> # get the necessary qconfig
+        >>> config = PrepareCustomConfig()
+        >>> skipped_module_names, skipped_module_classes = (
+        ...     get_skipped_module_name_and_classes(config, False)
+        ... )
+
+        >>> # initialize our model and get GraphModule
+        >>> model = SomeModel()
+        >>> tracer = QuantizationTracer(skipped_module_names, skipped_module_classes)
+        >>> graph_module = GraphModule(model, tracer.trace(model))
+
+        >>> # get our set of detectors and ModelReport instance
+        >>> detector_set = set(
+        ...     [
+        ...         DynamicStaticDetector(tolerance=0.5),
+        ...         InputWeightEqualizationDetector(ratio_threshold=0.7),
+        ...     ]
+        ... )
+        >>> tracer_reporter = ModelReport(graph_module, detector_set)
+
+        >>> # now we insert the observers and calibrate the model
+        >>> tracer_model_with_observers = tracer_reporter.prepare_detailed_calibration()
+        >>> for i in range(num_calibration_batches):
+        ...     example_input = get_calibration_input()
+        ...     tracer_model_with_observers(example_input)
+
+        >>> # finally we generate the reports; keep the observers so the mappings below can still be generated
+        >>> reports = tracer_reporter.generate_model_report(
+        ...     remove_inserted_observers=False
+        ... )
+
+        >>> # Optional: we can generate the qconfig mapping based on the suggestions
+        >>> qconfig_mapping = tracer_reporter.generate_qconfig_mapping()
+
+        >>> # Optional: we can generate the equalization mapping based on the suggestions
+        >>> equalization_mapping = tracer_reporter.generate_equalization_mapping()
+
+        >>> # Optional: we get a ModelReportVisualizer instance to do any visualizations desired
+        >>> model_report_visualizer = tracer_reporter.generate_visualizer()
+
+    """
+
+    def __init__(self, model: GraphModule, desired_report_detectors: set[DetectorBase]):
+        if len(desired_report_detectors) == 0:
+            raise ValueError("Should include at least 1 desired report")
+
+        # keep track of the model we wish to generate report for
+        self._model: GraphModule = model
+
+        # keep the reports private so they can't be modified
+        self._desired_report_detectors = desired_report_detectors
+        self._desired_detector_names = {
+            detector.get_detector_name() for detector in desired_report_detectors
+        }
+
+        # keep a mapping of desired reports to observers of interest
+        # this is to get the readings, and to remove them, can create a large set
+        # this set can then be used to traverse the graph and remove added observers
+        self._detector_name_to_observer_fqns: dict[str, set[str]] = {}
+
+        # initialize each report to have empty set of observers of interest
+        for desired_report in self._desired_detector_names:
+            self._detector_name_to_observer_fqns[desired_report] = set()
+
+        # flags to ensure that we can only prepare and remove observers once
+        self._prepared_flag = False
+        self._removed_observers = False
+
+        # store the reports that we generated for visualization purposes
+        # initially empty since no reports generated
+        self._generated_reports: dict[str, dict] = {}
+
+    def get_desired_reports_names(self) -> set[str]:
+        """Returns a copy of the desired reports for viewing"""
+        return self._desired_detector_names.copy()
+
+    def get_observers_of_interest(self) -> dict[str, set[str]]:
+        """Returns a copy of the observers of interest for viewing"""
+        return self._detector_name_to_observer_fqns.copy()
+
+    def prepare_detailed_calibration(self) -> GraphModule:
+        r"""
+        Takes in a graph model and inserts the following observers:
+        - ModelReportObserver
+
+        Each observer is inserted based on the desired_reports into the relevant locations
+
+        Right now, each report in self._desired_detector_names has independent insertions
+            However, if a module already has an Observer of the same type, the insertion will not occur
+            This is because Observers of the same type collect the same information, so a second insertion would be redundant
+
+        Returns the same GraphModule with the observers inserted
+        """
+
+        # if already prepared once, cannot prepare again
+        if self._prepared_flag:
+            raise ValueError(
+                "Already ran preparing detailed callibration. Run the report generation next after callibration."
+            )
+
+        # loop through each detector, find where placements should be, and keep track
+        insert_observers_fqns: dict[str, Any] = {}
+
+        for detector in self._desired_report_detectors:
+            # determine observer points for each detector
+            obs_fqn_to_info = detector.determine_observer_insert_points(self._model)
+            # map each insert point to the observer to use
+            insert_observers_fqns.update(obs_fqn_to_info)
+            # update the set of observers this report cares about
+            self._detector_name_to_observer_fqns[detector.get_detector_name()] = set(
+                obs_fqn_to_info.keys()
+            )
+
+        # now insert all the observers at their desired locations
+        for observer_fqn in insert_observers_fqns:
+            target_node = insert_observers_fqns[observer_fqn][DETECTOR_TARGET_NODE_KEY]
+            insert_obs = insert_observers_fqns[observer_fqn][DETECTOR_OBS_TO_INSERT_KEY]
+            insert_post = insert_observers_fqns[observer_fqn][DETECTOR_IS_POST_OBS_KEY]
+            observer_args = insert_observers_fqns[observer_fqn][DETECTOR_OBS_ARGS_KEY]
+            self._insert_observer_around_module(
+                observer_fqn, target_node, insert_obs, observer_args, insert_post
+            )
+
+        self._prepared_flag = True
+
+        return self._model
+
+    def _insert_observer_around_module(
+        self,
+        obs_fqn: str,
+        target_node: torch.fx.node.Node,
+        obs_to_insert: ObserverBase,
+        observer_args: tuple,
+        insert_post: bool,
+    ):
+        r"""
+        Helper function that inserts the observer into both the graph structure and the module of the model
+
+        Args
+            obs_fqn (str): The fully qualified name of the observer we want to insert
+            target_node (torch.fx.node.Node): The node in model we are inserting observers around
+            obs_to_insert (ObserverBase): The observer we are inserting around target_node
+            observer_args (Tuple): The arguments we want to pass into the observer
+            insert_post (bool): whether this is meant to be a post observer for this node
+        """
+        # if we are inserting post, then our target node is the next node
+        if insert_post:
+            target_node = target_node.next
+
+        with self._model.graph.inserting_before(target_node):
+            self._model.add_submodule(obs_fqn, obs_to_insert)
+            self._model.graph.create_node(
+                op="call_module", target=obs_fqn, args=observer_args
+            )
+
+        # recompile model after inserts are made
+        self._model.recompile()
+
+    def _get_node_from_fqn(self, node_fqn: str) -> torch.fx.node.Node:
+        r"""
+        Takes in a node fqn and returns the node based on the fqn
+
+        Args
+            node_fqn (str): The fully qualified name of the node we want to find in model
+
+        Returns the Node object of the given node_fqn otherwise returns None
+        """
+        node_to_return = None
+        for node in self._model.graph.nodes:
+            # if the target matches the fqn, it's the node we are looking for
+            if node.target == node_fqn:
+                node_to_return = node
+                break
+
+        if node_to_return is None:
+            raise ValueError("The node_fqn is was not found within the module.")
+
+        # assert for MyPy
+        assert isinstance(node_to_return, torch.fx.node.Node)
+
+        return node_to_return
+
+    def generate_model_report(
+        self, remove_inserted_observers: bool
+    ) -> dict[str, tuple[str, dict]]:
+        r"""
+        Generates all the requested reports.
+
+        Note:
+            You should have calibrated the model with relevant data before calling this
+
+        The reports generated are determined by the detectors passed in at initialization
+
+        Can optionally remove all the observers inserted by the ModelReport instance
+
+        Args:
+            remove_inserted_observers (bool): True to remove the observers inserted by this ModelReport instance
+
+        Returns a mapping of each desired report name to a tuple with:
+            The textual summary of that report information
+            A dictionary containing relevant statistics or information for that report
+
+        Note:
+            Throws exception if we try to generate report on model we already removed observers from
+            Throws exception if we try to generate report without preparing for calibration
+        """
+        # if we haven't prepped the model for calibration, then we shouldn't generate the report yet
+        if not self._prepared_flag:
+            raise Exception(  # noqa: TRY002
+                "Cannot generate report without preparing model for callibration"
+            )
+
+        # if we already removed the observers, we cannot generate report
+        if self._removed_observers:
+            raise Exception(  # noqa: TRY002
+                "Cannot generate report on model you already removed observers from"
+            )
+
+        # keep track of all the reports of interest and their outputs
+        reports_of_interest = {}
+
+        for detector in self._desired_report_detectors:
+            # generate the individual report for the detector
+            report_output = detector.generate_detector_report(self._model)
+            reports_of_interest[detector.get_detector_name()] = report_output
+
+        # if user wishes to remove inserted observers, go ahead and remove
+        if remove_inserted_observers:
+            self._removed_observers = True
+            # get the set of all Observers inserted by this instance of ModelReport
+            all_observers_of_interest: set[str] = set()
+            for desired_report in self._detector_name_to_observer_fqns:
+                observers_of_interest = self._detector_name_to_observer_fqns[
+                    desired_report
+                ]
+                all_observers_of_interest.update(observers_of_interest)
+
+            # go through all_observers_of_interest and remove them from the graph and model
+            for observer_fqn in all_observers_of_interest:
+                # remove the observer from the model
+                self._model.delete_submodule(observer_fqn)
+
+                # remove the observer from the graph structure
+                node_obj = self._get_node_from_fqn(observer_fqn)
+
+                if node_obj:
+                    self._model.graph.erase_node(node_obj)
+                else:
+                    raise ValueError("Node no longer exists in GraphModule structure")
+
+            # remember to recompile the model
+            self._model.recompile()
+
+        # save the generated reports for visualization purposes
+        saved_reports: dict[str, dict] = {
+            report_name: report_tuple[1]
+            for report_name, report_tuple in reports_of_interest.items()
+        }
+
+        self._generated_reports = saved_reports
+
+        # return the reports of interest
+        return reports_of_interest
+
+    def _is_same_info_for_same_key(self, info_dict_a: dict, info_dict_b: dict) -> bool:
+        r"""
+        Takes in two dictionaries and ensures that any common keys between the two have the same
+        values.
+
+        Args:
+            info_dict_a (Dict): First dictionary we wish to compare
+            info_dict_b (Dict): Second dictionary we wish to compare
+
+        Returns True if all shared keys have same values, false otherwise
+        """
+        # get the set of keys for both
+        dict_a_keys: set = set(info_dict_a.keys())
+        dict_b_keys: set = set(info_dict_b.keys())
+
+        # get the intersection of keys and check that both dicts hold the same value for each of them
+        intersecting_keys: set = dict_a_keys.intersection(dict_b_keys)
+
+        for key in intersecting_keys:
+            dict_a_val = info_dict_a[key]
+            dict_b_val = info_dict_b[key]
+
+            # if it's a tensor we have to handle separately
+            if type(dict_a_val) == torch.Tensor:
+                # if dict_b_val not tensor, automatically false
+                if (
+                    type(dict_b_val) != torch.Tensor
+                    or sum(dict_a_val != dict_b_val) != 0
+                ):
+                    return False
+            else:
+                # for non-tensor vals
+                if dict_a_val != dict_b_val:
+                    return False
+
+        # if no non matching shared keys found, return true
+        return True
+
+    def _reformat_reports_for_visualizer(self) -> OrderedDict:
+        r"""
+        Takes the generated reports and reformats them into the format that is desired by the
+        ModelReportVisualizer
+
+        Returns an OrderedDict mapping module_fqns to their features
+        """
+        # we want to reorder and reformat the information so modules appear in the
+        # same order as they appear in the model
+
+        # first create new dict with all modules as keys and features under respective module
+        module_fqns_to_features: dict[str, dict] = {}
+
+        for report_name in self._generated_reports:
+            # get mod -> feature dict and go through
+            module_info = self._generated_reports[report_name]
+
+            for module_fqn in module_info:
+                # check if already in our accumulation dict
+                if module_fqn in module_fqns_to_features:
+                    # we merge all the features together
+                    new_info: dict = module_info[module_fqn]
+                    present_info: dict = module_fqns_to_features[module_fqn]
+
+                    # merge them together into the new unioned dict
+                    # same features keys -> same info, so okay if override
+
+                    # do safety check to make sure shared keys have same info
+                    if self._is_same_info_for_same_key(new_info, present_info):
+                        module_fqns_to_features[module_fqn] = {
+                            **new_info,
+                            **present_info,
+                        }
+                    else:
+                        error_str = "You have the same key with different values across detectors. "
+                        error_str += "Someone incorrectly implemented a detector with conflicting keys to existing detectors."
+                        raise ValueError(error_str)
+                else:
+                    # we just set it
+                    module_fqns_to_features[module_fqn] = module_info[module_fqn]
+
+        # our ordered dict so that modules can be ordered in order of how they appear in model
+        features_by_module: OrderedDict[str, dict] = OrderedDict()
+
+        # we loop through modules in graph in order
+        for fqn, _module in self._model.named_modules():
+            # find that fqn in fqns_to_features
+            if fqn in module_fqns_to_features:
+                # add it to our ordered dict
+                features_by_module[fqn] = module_fqns_to_features[fqn]
+
+        # return the ordered dict of info we created
+        return features_by_module
+
+    def generate_visualizer(self) -> ModelReportVisualizer:
+        r"""
+        Generates a ModelReportVisualizer instance using the reports generated
+        by the generate_model_report() method.
+
+        Returns the generated ModelReportVisualizer instance initialized
+
+        Note:
+            Throws exception if attempt to get visualizers without generating report
+        """
+        # check if user has generated reports at least once
+        if len(self._generated_reports) == 0:
+            raise Exception(  # noqa: TRY002
+                "Unable to generate visualizers without first generating reports"
+            )
+
+        # get the ordered dict mapping modules to their full set of collected features / stats
+        module_fqns_to_features: OrderedDict = self._reformat_reports_for_visualizer()
+
+        # create and return ModelReportVisualizer instance
+        visualizer: ModelReportVisualizer = ModelReportVisualizer(
+            module_fqns_to_features
+        )
+
+        return visualizer
+
+    def _generate_qconfig_mapping_helper(
+        self,
+        detector_qconfig_info_combined: dict[str, DetectorQConfigInfo],
+        generation_function: Callable,
+    ) -> QConfigMapping:
+        r"""
+        This helper takes in the compiled detector qconfig info that
+        has been compiled together and merges it into a QConfigMapping
+        """
+        # keep track of the qconfigmapping
+        qconfig_mapping = QConfigMapping()
+
+        # loop through each module / fqn and attempt to create QConfigMapping
+        for fqn, module in self._model.named_modules():
+            # if we have a qconfig info for this module
+            if fqn in detector_qconfig_info_combined:
+                qconfig_info_compiled = detector_qconfig_info_combined[fqn]
+
+                # now generate the qconfig and add it to the mapping
+                generated_qconfig = generation_function(qconfig_info_compiled, module)
+
+                # add to our config
+                qconfig_mapping.set_module_name(fqn, generated_qconfig)
+
+        # return compiled mapping
+        return qconfig_mapping
+
+    def _update_detector_quantizaiton_qconfig_info(
+        self, combined_info: DetectorQConfigInfo, new_info: DetectorQConfigInfo
+    ):
+        r"""
+        Takes in the old and new information and updates the combined information.
+
+        Args:
+            combined_info (DetectorQConfigInfo): The DetectorQConfigInfo we are compiling all of the information in
+            new_info (DetectorQConfigInfo): The DetectorQConfigInfo whose information we are merging
+                into combined_info
+        """
+        combined_info.is_activation_dynamic = (
+            combined_info.is_activation_dynamic or new_info.is_activation_dynamic
+        )
+        combined_info.is_weight_per_channel = (
+            combined_info.is_weight_per_channel or new_info.is_weight_per_channel
+        )
+
+    def _update_detector_equalization_qconfig_info(
+        self, combined_info: DetectorQConfigInfo, new_info: DetectorQConfigInfo
+    ):
+        r"""
+        Takes in the old and new information and updates the combined information.
+
+        Args:
+            combined_info (DetectorQConfigInfo): The DetectorQConfigInfo we are compiling all of the information in
+            new_info (DetectorQConfigInfo): The DetectorQConfigInfo whose information we are merging
+                into combined_info
+        """
+        is_equalization_recommended = (
+            combined_info.is_equalization_recommended
+            or new_info.is_equalization_recommended
+        )
+        combined_info.is_equalization_recommended = is_equalization_recommended
+
+    def _generate_module_fqn_to_detector_info_mapping(
+        self, update_qconfig_info_function: Callable
+    ) -> dict[str, DetectorQConfigInfo]:
+        r"""
+        Generates a mapping from module fqns to DetectorQConfigInfo objects based on the
+        suggestions of the ModelReport API. The generated mapping combines the
+        different types of feedback from the different detectors
+        into one place.
+
+        These configs are based on the suggestions provided by the ModelReport API
+        and can only be generated once the reports have been generated.
+
+        Args:
+            update_qconfig_info_function (Callable) takes in a function that takes in two DetectorQConfigInfo
+            and updates the one that is being compiled
+
+        Returns a Dict mapping module_fqns to DetectorQConfigInfo objects
+
+        Note:
+            Throws exception if we try to generate mapping on model we already removed observers from
+            Throws exception if we try to generate mapping without preparing for callibration
+        """
+        # if we haven't prepped the model for calibration, then we shouldn't generate the mapping yet
+        if not self._prepared_flag:
+            raise Exception(  # noqa: TRY002
+                "Cannot generate mapping without preparing model for calibration"
+            )
+
+        # if we already removed the observers, we cannot generate the mapping
+        if self._removed_observers:
+            raise Exception(  # noqa: TRY002
+                "Cannot generate mapping on model you already removed observers from"
+            )
+
+        # keep track of qconfig info for each module across detectors
+        detector_qconfig_info_combined: dict[str, DetectorQConfigInfo] = {}
+
+        for detector in self._desired_report_detectors:
+            # get the info from the detector
+            detector_info: dict[str, DetectorQConfigInfo] = detector.get_qconfig_info(
+                self._model
+            )
+
+            # we go through the modules
+            for module_fqn in detector_info:
+                # see if we already have info on it
+                if module_fqn in detector_qconfig_info_combined:
+                    # we combine the current options with what is there
+                    current_options = detector_qconfig_info_combined[module_fqn]
+                    detector_options = detector_info[module_fqn]
+
+                    update_qconfig_info_function(current_options, detector_options)
+                else:
+                    # we just use this for now
+                    detector_qconfig_info_combined[module_fqn] = detector_info[
+                        module_fqn
+                    ]
+
+        return detector_qconfig_info_combined
+
+    def generate_qconfig_mapping(self) -> QConfigMapping:
+        r"""
+        Generates a QConfigMapping based on the suggestions of the
+        ModelReport API. The generated mapping encompasses all the
+        different types of feedback from the different detectors
+        all into one place.
+
+        These configs are based on the suggestions provided by the ModelReport API
+        and can only be generated once the reports have been generated.
+
+        Returns a QConfigMapping for the quantization configuration
+
+        Note:
+            Throws exception if we try to generate mapping on model we already removed observers from
+            Throws exception if we try to generate mapping without preparing for calibration
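+
+        Example (a minimal usage sketch; it assumes a ModelReport instance named model_report that
+        has already been prepared and calibrated and still has its observers in place, along with the
+        original model and example_inputs used for the standard FX prepare step):
+            >>> # xdoctest: +SKIP
+            >>> from torch.ao.quantization import quantize_fx
+            >>> qconfig_mapping = model_report.generate_qconfig_mapping()
+            >>> prepared = quantize_fx.prepare_fx(model, qconfig_mapping, example_inputs)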
+        """
+        # get the mapping info
+        detector_qconfig_info_combined = (
+            self._generate_module_fqn_to_detector_info_mapping(
+                self._update_detector_quantizaiton_qconfig_info
+            )
+        )
+
+        # we will do a bit of processing and remove fqns that don't have input weight recommended
+
+        # now we generate the QConfig for each of the options
+        mapping: QConfigMapping = self._generate_qconfig_mapping_helper(
+            detector_qconfig_info_combined, self._quantization_config_generator
+        )
+
+        # return the generated mapping
+        return mapping
+
+    def _quantization_config_generator(
+        self, detector_qconfig_info: DetectorQConfigInfo, module: torch.nn.Module
+    ) -> QConfig:
+        r"""
+        Returns the quantization configuration generated by the DetectorQConfigInfo object
+        """
+        return detector_qconfig_info.generate_quantization_qconfig(module)
+
+    def _equalization_config_generator(
+        self, detector_qconfig_info: DetectorQConfigInfo, module: torch.nn.Module
+    ) -> EqualizationQConfig:
+        r"""
+        We ignore the module argument here, and only focus on the detector_qconfig_info
+
+        Returns the equalization configuration generated by the DetectorQConfigInfo object
+        """
+        return detector_qconfig_info.generate_equalization_qconfig()
+
+    def generate_equalization_mapping(self) -> QConfigMapping:
+        r"""
+        Generates a QConfigMapping based on the suggestions of the
+        ModelReport API for equalization. The generated mapping encompasses all the
+        different types of feedback from the input-weight equalization detector.
+
+        These configs are based on the suggestions provided by the ModelReport API
+        and can only be generated once the reports have been generated.
+
+        Returns a QConfigMapping for the equalization configuration
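+
+        Example (a minimal usage sketch; assumes a ModelReport instance named model_report that has
+        already been prepared and calibrated and still has its observers in place; passing the result
+        to prepare_fx via the private _equalization_config argument is one possible way to apply it):
+            >>> # xdoctest: +SKIP
+            >>> from torch.ao.quantization import quantize_fx
+            >>> equalization_mapping = model_report.generate_equalization_mapping()
+            >>> qconfig_mapping = model_report.generate_qconfig_mapping()
+            >>> prepared = quantize_fx.prepare_fx(
+            ...     model, qconfig_mapping, example_inputs, _equalization_config=equalization_mapping
+            ... )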
+        """
+        # get the mapping info
+        detector_qconfig_info_combined = (
+            self._generate_module_fqn_to_detector_info_mapping(
+                self._update_detector_equalization_qconfig_info
+            )
+        )
+
+        # now we generate the QConfig for each of the options
+        mapping: QConfigMapping = self._generate_qconfig_mapping_helper(
+            detector_qconfig_info_combined, self._equalization_config_generator
+        )
+
+        # return the generated mapping
+        return mapping
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/model_report_observer.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/model_report_observer.py
new file mode 100644
index 0000000000000000000000000000000000000000..a809dc60838e574e0bd484ee9698e9d1a0a5ee47
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/model_report_observer.py
@@ -0,0 +1,285 @@
+# mypy: allow-untyped-defs
+import torch
+from torch.ao.quantization.observer import ObserverBase
+
+
+class ModelReportObserver(ObserverBase):
+    r"""This observer is used to record additional information regarding keeping track
+    of S = average_batch_activation_range/epoch_activation_range.
+
+    The purpose of this information is to prepare a report to present to users on whether
+    Dynamic or Static Quantization is more appropriate for their model given the general
+    distributions of their data.
+
+    Args:
+        ch_axis (int, optional): The channel axis for which the range and outlier stats are computed
+            Default: 1
+        comp_percentile (float, optional): The percentile to compare against 100 percentile to find outliers
+            Should be between 0 and 1 exclusive
+            Default: 0.9
+
+    * :attr:`num_batches_tracked` specifies number of batches passed through the observer
+
+    * :attr:`average_batch_activation_range` defines average across the ranges of each batch passed through
+
+    * :attr:`epoch_activation_min` defines the minimum value passed through the observer
+
+    * :attr:`epoch_activation_max` defines the maximum value passed through the observer
+
+    * :attr:`ch_axis` defines the channel being used to compute per channel min max stats
+
+    * :attr:`min_val` defines the per channel minimum values passed through
+
+    * :attr:`max_val` defines the per channel maximum values passed through
+
+    * :attr:`comp_percentile` defines comparison percentile to find outliers
+
+    * :attr:`average_percentile_ratio` defines the per channel average percentile ratios
+
+    * :attr:`percentile_batches_tracked` defines the number of percentile batches tracked for each channel
+
+    * :attr:`constant_channels` defines, for each channel, the number of batches in which that channel was constant-valued
+
+    Note: this tool is meant for FX Graph Mode Quantization
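+
+    Example (illustrative only; in normal use this observer is inserted automatically by the
+    ModelReport API rather than constructed by hand, and the input shapes here are hypothetical):
+        >>> # xdoctest: +SKIP
+        >>> import torch
+        >>> obs = ModelReportObserver(ch_axis=1, comp_percentile=0.975)
+        >>> for _ in range(4):
+        ...     _ = obs(torch.randn(8, 3, 16))  # forward just records stats and returns the input
+        >>> s = obs.get_batch_to_epoch_ratio()  # S = average batch range / epoch range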
+    """
+
+    epoch_activation_min: torch.Tensor
+    epoch_activation_max: torch.Tensor
+    min_val: torch.Tensor
+    max_val: torch.Tensor
+    comp_percentile: torch.Tensor
+    average_percentile_ratio: torch.Tensor
+    percentile_batches_tracked: torch.Tensor
+    constant_channels: torch.Tensor
+
+    def __init__(self, ch_axis: int = 1, comp_percentile: float = 0.9):
+        super().__init__(torch.qint8)
+        self.num_batches_tracked = 0
+
+        # keep track of the min and max of the range for the average batch and the epoch as a whole
+        self.average_batch_activation_range: torch.Tensor = torch.tensor(float(0))
+        self.register_buffer("epoch_activation_min", torch.tensor(float("inf")))
+        self.register_buffer("epoch_activation_max", torch.tensor(float("-inf")))
+
+        # keep track of per channel min max information using the given channel
+        self.ch_axis: int = ch_axis
+        self.register_buffer("min_val", torch.tensor([]))
+        self.register_buffer("max_val", torch.tensor([]))
+
+        # keep track of percentile ratio information per channel
+        self.register_buffer("comp_percentile", torch.tensor([comp_percentile]))
+        self.register_buffer("average_percentile_ratio", torch.tensor([]))
+        self.register_buffer("percentile_batches_tracked", torch.tensor([]))
+        self.register_buffer("constant_channels", torch.tensor([]))
+
+    def forward(self, x):
+        x_copy = x.detach()  # avoid keeping autograd tape
+        x_copy = x_copy.to(self.epoch_activation_min.dtype)
+
+        x_copy = self._calculate_range_stats(x_copy)
+        x_copy = self._calculate_min_max_stats(x_copy)
+        x_copy = self._calculate_percentile_stats(x_copy)
+
+        # return the passed-in value unchanged
+        return x
+
+    def _calculate_range_stats(self, x_copy):
+        r"""Calculates and stores range stats with forward values.
+
+        Args
+            x_copy: A copy of the forward data
+
+        Returns the passed in x_copy
+        """
+        # get the min, max values of the data
+        min_val_cur, max_val_cur = torch.aminmax(x_copy)
+
+        # calculate new epoch range values
+        epoch_min_val = torch.min(self.epoch_activation_min, min_val_cur)
+        epoch_max_val = torch.max(self.epoch_activation_max, max_val_cur)
+
+        self.epoch_activation_min.copy_(epoch_min_val)
+        self.epoch_activation_max.copy_(epoch_max_val)
+
+        # calculate the average batch activation range
+        current_batch_range = max_val_cur - min_val_cur
+        new_range = (
+            self.average_batch_activation_range * self.num_batches_tracked
+            + current_batch_range
+        ) / (self.num_batches_tracked + 1)
+
+        self.average_batch_activation_range = new_range
+        self.num_batches_tracked += 1  # new batch was processed
+
+        return x_copy
+
+    def _calculate_min_max_stats(self, x_copy):
+        r"""Calculates and stores the per_channel min, max stats with forward values.
+        Does calculation based on channel axis: self.ch_axis
+
+        Args
+            x_copy: A copy of the forward data
+
+        Returns the passed in x_copy
+        """
+        # get the current min and max vals
+        min_val = self.min_val
+        max_val = self.max_val
+        x_dim = x_copy.size()
+
+        new_axis_list = [i for i in range(len(x_dim))]  # noqa: C416
+        new_axis_list[self.ch_axis] = 0
+        new_axis_list[0] = self.ch_axis
+        y = x_copy.permute(new_axis_list)
+        # Need to match dtype of min/max because the updates to buffers
+        # are done in place and types need to match for comparisons
+        y = y.to(self.min_val.dtype)
+        y = torch.flatten(y, start_dim=1)
+        if min_val.numel() == 0 or max_val.numel() == 0:
+            min_val, max_val = torch.aminmax(y, dim=1)
+        else:
+            min_val_cur, max_val_cur = torch.aminmax(y, dim=1)
+            min_val = torch.min(min_val_cur, min_val)
+            max_val = torch.max(max_val_cur, max_val)
+
+        self.min_val.resize_(min_val.shape)
+        self.max_val.resize_(max_val.shape)
+        self.min_val.copy_(min_val)
+        self.max_val.copy_(max_val)
+
+        return x_copy
+
+    def _calculate_percentile_stats(self, x_copy):
+        r"""Calculates and stores the per_channel percentile stats with forward values.
+        Does calculation based on channel axis: self.ch_axis
+
+        Args
+            x_copy: A copy of the forward data
+
+        Returns the passed in x_copy
+        """
+        # get the dimension of the copy
+        x_dim = x_copy.size()
+
+        new_axis_list = [i for i in range(len(x_dim))]  # noqa: C416
+        new_axis_list[self.ch_axis] = 0
+        new_axis_list[0] = self.ch_axis
+        y = x_copy.permute(new_axis_list)
+        # Need to match dtype of min/max because the updates to buffers
+        # are done in place and types need to match for comparisons
+        y = y.to(self.min_val.dtype)
+        y = torch.flatten(y, start_dim=1)
+        y = y.to(dtype=self.min_val.dtype, device="cpu")
+
+        # find the percentile values along the axis
+        # we want both 100th percentile and comp_percentile
+        # we also want to find the 0th percentile (the minimum) to see if we have a constant channel
+        quantiles_list = [0, self.comp_percentile, 1.00]
+        quantiles_to_find = torch.tensor(quantiles_list, dtype=self.min_val.dtype)
+
+        # find the quantiles
+        desired_quantiles = torch.quantile(
+            y, quantiles_to_find, dim=self.ch_axis, interpolation="lower"
+        )
+        zero_quantile = desired_quantiles[0]
+        comp_quantile = desired_quantiles[1]
+        hundreth_quartile = desired_quantiles[2]
+
+        # if any of the channels have 0s, we ignore that channel for this calculation
+        any_non_zero_quantile_value: torch.Tensor = (
+            comp_quantile != torch.tensor([0])
+        ) | (hundreth_quartile != torch.tensor([0]))
+        any_non_zero_quantile_value = (
+            any_non_zero_quantile_value.int()
+        )  # transform boolean values to int values
+
+        # we also check if we have a constant channel
+        any_constant_channels: torch.Tensor = (
+            hundreth_quartile - zero_quantile
+        ) == torch.tensor([0])
+        any_constant_channels = (
+            any_constant_channels.int()
+        )  # transform boolean values to int values
+
+        # possibilities to get nan as an answer
+        #   we ignore these three cases involving 0s and do not handle them for now
+        # case (1) 0 in numerator: happens if 0 is the largest value, i.e. all other values are negative
+        # case (2) 0 in denominator: possible unless case 3 applies; we just ignore it
+        # case (3) 0 in both: not an outlier, the channel is simply constant at 0; ignore
+
+        # get the ratio and get rid of nan values
+        quantile_ratios = hundreth_quartile / comp_quantile
+        quantile_ratios = torch.nan_to_num(quantile_ratios)
+        # update averages, remembering to only update if didn't have zeros
+        ratio_if_not_zero = any_non_zero_quantile_value * quantile_ratios
+
+        # if num_batches and average_ratio are not initialized, we want to initialize them
+        if (
+            self.percentile_batches_tracked.shape[0] == 0
+            or self.average_percentile_ratio.shape[0] == 0
+        ):
+            self.percentile_batches_tracked = torch.zeros_like(
+                any_non_zero_quantile_value
+            )
+            self.average_percentile_ratio = torch.zeros_like(ratio_if_not_zero)
+
+        # also initialize the constant channel var if that is not initialized separately
+        if self.constant_channels.shape[0] == 0:
+            self.constant_channels = torch.zeros_like(any_constant_channels)
+
+        # get current num batches and average ratio
+        num_batches = self.percentile_batches_tracked
+        average_ratio = self.average_percentile_ratio
+
+        # calculate new_number of batches, new_ratios, and get rid of nans because of 0 size batches
+        new_number_of_batches: torch.Tensor = num_batches + any_non_zero_quantile_value
+        new_ratios: torch.Tensor = (
+            (average_ratio * num_batches) + ratio_if_not_zero
+        ) / new_number_of_batches
+        new_ratios = torch.nan_to_num(new_ratios)
+
+        # update the number of non-constant channels
+        new_constant_count: torch.Tensor = (
+            self.constant_channels + any_constant_channels
+        )
+
+        # update the values locally
+        self.percentile_batches_tracked.copy_(new_number_of_batches)
+        self.average_percentile_ratio.copy_(new_ratios)
+        self.constant_channels.copy_(new_constant_count)
+
+        return x_copy
+
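+    # A minimal illustrative sketch (made-up numbers) of the running-average update
+    # performed in _calculate_percentile_stats: channels whose comparison quantile
+    # was zero are masked out and contribute nothing to their average.
+    #
+    #   >>> average_ratio = torch.tensor([2.0, 3.0])
+    #   >>> num_batches = torch.tensor([4, 4])
+    #   >>> mask = torch.tensor([1, 0])            # second channel skipped this batch
+    #   >>> ratio_if_not_zero = torch.tensor([4.0, 0.0])
+    #   >>> new_batches = num_batches + mask
+    #   >>> (average_ratio * num_batches + ratio_if_not_zero) / new_batches
+    #   tensor([2.4000, 3.0000])
+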
+    @torch.jit.export
+    def get_batch_to_epoch_ratio(self):
+        epoch_activation_range = self.epoch_activation_max - self.epoch_activation_min
+
+        if epoch_activation_range == torch.tensor(float(0)):
+            raise ValueError("Range for Epoch is 0")
+        elif epoch_activation_range == torch.tensor(float("inf")):
+            raise ValueError(
+                "No data has been run through observer or infinity value present"
+            )
+        else:
+            return self.average_batch_activation_range / epoch_activation_range
+
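+    # A worked example (made-up numbers) of the ratio above: if the average
+    # per-batch activation range over an epoch is 4.0 and the full epoch's
+    # activation range is 5.0, get_batch_to_epoch_ratio() returns 4.0 / 5.0 = 0.8,
+    # i.e. individual batches cover most of the range the epoch sees.
+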
+    @torch.jit.export
+    def reset_batch_and_epoch_values(self):
+        # set all the values back to their original defaults for a new epoch
+        # keep device
+        device = self.max_val.device
+        self.num_batches_tracked = 0
+        self.average_batch_activation_range = torch.tensor(float(0), device=device)
+        self.epoch_activation_min = torch.tensor(float("inf"), device=device)
+        self.epoch_activation_max = torch.tensor(float("-inf"), device=device)
+        self.min_val = torch.tensor([], device=device)
+        self.max_val = torch.tensor([], device=device)
+        self.average_percentile_ratio = torch.tensor([], device=device)
+        self.percentile_batches_tracked = torch.tensor([], device=device)
+        self.constant_channels = torch.tensor([], device=device)
+
+    @torch.jit.export
+    def calculate_qparams(self):  # type: ignore[override]
+        raise Exception(  # noqa: TRY002
+            "calculate_qparams should not be called for ModelReportObserver"
+        )
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/model_report_visualizer.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/model_report_visualizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..20f08c9dbc0f8a432849494fdbdd70028c36f64c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/_model_report/model_report_visualizer.py
@@ -0,0 +1,710 @@
+# mypy: allow-untyped-defs
+from collections import OrderedDict, OrderedDict as OrdDict
+from typing import Any
+
+import torch
+
+
+# try to import tabulate
+got_tabulate = True
+try:
+    from tabulate import tabulate
+except ImportError:
+    got_tabulate = False
+
+
+# var to see if we could import matplotlib
+got_matplotlib = True
+try:
+    import matplotlib.pyplot as plt
+except ImportError:
+    got_matplotlib = False
+
+
+class ModelReportVisualizer:
+    r"""
+    The ModelReportVisualizer class aims to provide users a way to visualize some of the statistics
+    that were generated by the ModelReport API. However, at a higher level, the class aims to provide
+    some level of visualization of statistics to PyTorch in order to make it easier to parse data and
+    diagnose any potential issues with data or a specific model. With respect to the visualizations,
+    the ModelReportVisualizer class currently supports several methods of visualizing data.
+
+    Supported Visualization Methods Include:
+    - Table format
+    - Plot format (line graph)
+    - Histogram format
+
+    For all of the existing visualization methods, there is the option to filter data based on:
+    - A module fqn prefix
+    - Feature [required for the plot and histogram]
+
+    * :attr:`generated_reports` The reports generated by the ModelReport class in the structure below
+        Ensure that features that are the same across different reports contain the same name
+        Ensure that objects representing the same features are the same type / dimension (where applicable)
+
+    Note:
+        Currently, the ModelReportVisualizer class supports visualization of data generated by the
+        ModelReport class. However, this structure is extensible and should allow the visualization of
+        other information as long as the information is structured in the following general format:
+
+        Report Structure
+        -- module_fqn [module with attached detectors]
+            |
+            -- feature keys [not every detector extracts same information]
+                                    [same collected info has same keys, unless can be specific to detector]
+
+
+    The goal behind the class is that the generated visualizations can be used in conjunction with the generated
+    report for people to get a better understanding of issues and what the fix might be. It is also just to provide
+    a good visualization platform, since it might be hard to parse through the ModelReport returned dictionary as
+    that grows in size.
+
+    General Use Flow Expected
+    1.) Initialize ModelReport object with reports of interest by passing in initialized detector objects
+    2.) Prepare your model with prepare_fx
+    3.) Call model_report.prepare_detailed_calibration on your model to add relevant observers
+    4.) Calibrate your model with data
+    5.) Call model_report.generate_report on your model to generate report and optionally remove added observers
+    6.) Use output of model_report.generate_report to initialize ModelReportVisualizer instance
+    7.) Use instance to view different views of data as desired, applying filters as needed
+    8.) Either see the super detailed information or just the actual printed or shown table / plot / histogram
+
+    """
+
+    # keys for table dict
+    TABLE_TENSOR_KEY = "tensor_level_info"
+    TABLE_CHANNEL_KEY = "channel_level_info"
+
+    # Constants for header vals
+    NUM_NON_FEATURE_TENSOR_HEADERS = 2
+    NUM_NON_FEATURE_CHANNEL_HEADERS = 3
+
+    # Constants for row index in header
+    CHANNEL_NUM_INDEX = 2
+
+    def __init__(self, generated_reports: OrderedDict[str, Any]):
+        r"""
+        Initializes the ModelReportVisualizer instance with the necessary reports.
+
+        Args:
+            generated_reports (Dict[str, Any]): The reports generated by the ModelReport class
+                can also be a dictionary generated in another manner, as long as format is same
+        """
+        self.generated_reports = generated_reports
+
+    def get_all_unique_module_fqns(self) -> set[str]:
+        r"""
+        The purpose of this method is to provide a user the set of all module_fqns so that if
+        they wish to use some of the filtering capabilities of the ModelReportVisualizer class,
+        they don't need to manually parse the generated_reports dictionary to get this information.
+
+        Returns all the unique module fqns present in the reports the ModelReportVisualizer
+        instance was initialized with.
+        """
+        # returns the keys of the ordered dict
+        return set(self.generated_reports.keys())
+
+    def get_all_unique_feature_names(
+        self, plottable_features_only: bool = True
+    ) -> set[str]:
+        r"""
+        The purpose of this method is to provide a user the set of all feature names so that if
+        they wish to use the filtering capabilities of the generate_table_view(), or use either of
+        the generate_plot_view() or generate_histogram_view(), they don't need to manually parse
+        the generated_reports dictionary to get this information.
+
+        Args:
+            plottable_features_only (bool): True if the user is only looking for plottable features,
+                False otherwise
+                plottable features are those that are tensor values
+                Default: True (only return those feature names that are plottable)
+
+        Returns all the unique feature names present in the reports the ModelReportVisualizer
+        instance was initialized with.
+        """
+        unique_feature_names = set()
+        for module_fqn in self.generated_reports:
+            # get dict of the features
+            feature_dict: dict[str, Any] = self.generated_reports[module_fqn]
+
+            # loop through features
+            for feature_name in feature_dict:
+                # if we need plottable, ensure type of val is tensor
+                if (
+                    not plottable_features_only
+                    or type(feature_dict[feature_name]) == torch.Tensor
+                ):
+                    unique_feature_names.add(feature_name)
+
+        # return our compiled set of unique feature names
+        return unique_feature_names
+
+    def _get_filtered_data(
+        self, feature_filter: str, module_fqn_filter: str
+    ) -> OrderedDict[str, Any]:
+        r"""
+        Filters the data and returns it in the same ordered dictionary format so the relevant views can be displayed.
+
+        Args:
+            feature_filter (str): The feature filter, if we want to filter the set of data to only include
+                a certain set of features that include feature_filter
+                If feature = "", then we do not filter based on any features
+            module_fqn_filter (str): The filter on prefix for the module fqn. All modules that have fqn with
+                this prefix will be included
+                If module_fqn_filter = "" we do not filter based on module fqn, and include all modules
+
+        First, the data is filtered based on module_fqn, and then filtered based on feature
+        Returns an OrderedDict (sorted in order of model) mapping:
+            module_fqns -> feature_names -> values
+        """
+        # create return dict
+        filtered_dict: OrderedDict[str, Any] = OrdDict()
+
+        for module_fqn in self.generated_reports:
+            # first filter based on module
+            if module_fqn_filter == "" or module_fqn_filter in module_fqn:
+                # create entry for module and loop through features
+                filtered_dict[module_fqn] = {}
+                module_reports = self.generated_reports[module_fqn]
+                for feature_name in module_reports:
+                    # check if filtering on features and do so if desired
+                    if feature_filter == "" or feature_filter in feature_name:
+                        filtered_dict[module_fqn][feature_name] = module_reports[
+                            feature_name
+                        ]
+
+        # we have populated the filtered dict, and must return it
+
+        return filtered_dict
+
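+    # A minimal illustrative sketch (hypothetical module and feature names) of the
+    # report structure _get_filtered_data expects and how the two filters compose:
+    # the module filter keeps only fqns containing "block1", and the feature filter
+    # then keeps only feature names containing "min".
+    #
+    #   >>> from collections import OrderedDict
+    #   >>> reports = OrderedDict(
+    #   ...     {
+    #   ...         "block1.conv": {"per_channel_min": torch.zeros(3), "global_max": torch.tensor(1.0)},
+    #   ...         "block2.conv": {"per_channel_min": torch.ones(3)},
+    #   ...     }
+    #   ... )
+    #   >>> viz = ModelReportVisualizer(reports)
+    #   >>> list(viz._get_filtered_data("min", "block1").keys())
+    #   ['block1.conv']
+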
+    def _generate_tensor_table(
+        self,
+        filtered_data: OrderedDict[str, dict[str, Any]],
+        tensor_features: list[str],
+    ) -> tuple[list, list]:
+        r"""
+        Takes in the filtered data and features list and generates the tensor headers and table
+
+        Currently meant to generate the headers and table for the tensor information.
+
+        Args:
+            filtered_data (OrderedDict[str, Dict[str, Any]]): An OrderedDict (sorted in order of model) mapping:
+                module_fqns -> feature_names -> values
+            tensor_features (List[str]): A list of the tensor level features
+
+        Returns a tuple with:
+            A list of the headers of the tensor table
+            A list of lists containing the table information row by row
+            The 0th index row will contain the headers of the columns
+            The rest of the rows will contain data
+        """
+        # now we compose the tensor information table
+        tensor_table: list[list[Any]] = []
+        tensor_headers: list[str] = []
+
+        # append the table row to the table only if we have features
+        if len(tensor_features) > 0:
+            # now we add all the data
+            for index, module_fqn in enumerate(filtered_data):
+                # we make a new row for the tensor table
+                tensor_table_row = [index, module_fqn]
+                for feature in tensor_features:
+                    # we iterate in same order of added features
+
+                    if feature in filtered_data[module_fqn]:
+                        # add value if applicable to module
+                        feature_val = filtered_data[module_fqn][feature]
+                    else:
+                        # add that it is not applicable
+                        feature_val = "Not Applicable"
+
+                    # if it's a tensor we want to extract val
+                    if isinstance(feature_val, torch.Tensor):
+                        feature_val = feature_val.item()
+
+                    # we add to our list of values
+                    tensor_table_row.append(feature_val)
+
+                tensor_table.append(tensor_table_row)
+
+        # add the row of headers if we actually have something, otherwise leave it empty
+        if len(tensor_table) != 0:
+            tensor_headers = ["idx", "layer_fqn"] + tensor_features
+
+        return (tensor_headers, tensor_table)
+
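+    # A minimal illustrative sketch (hypothetical feature and module names) of the
+    # output produced above for a single tensor-level feature:
+    #
+    #   tensor_headers = ["idx", "layer_fqn", "global_max"]
+    #   tensor_table   = [[0, "block1.conv", 1.0], [1, "block2.conv", 0.5]]
+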
+    def _generate_channels_table(
+        self,
+        filtered_data: OrderedDict[str, Any],
+        channel_features: list[str],
+        num_channels: int,
+    ) -> tuple[list, list]:
+        r"""
+        Takes in the filtered data and features list and generates the channels headers and table
+
+        Currently meant to generate the headers and table for the channel information.
+
+        Args:
+            filtered_data (OrderedDict[str, Any]): An OrderedDict (sorted in order of model) mapping:
+                module_fqns -> feature_names -> values
+            channel_features (List[str]): A list of the channel level features
+            num_channels (int): Number of channels in the channel data
+
+        Returns a tuple with:
+            A list of the headers of the channel table
+            A list of lists containing the table information row by row
+            The 0th index row will contain the headers of the columns
+            The rest of the rows will contain data
+        """
+        # now we compose the table for the channel information table
+        channel_table: list[list[Any]] = []
+        channel_headers: list[str] = []
+
+        # counter to keep track of the number of entries in the table
+        channel_table_entry_counter: int = 0
+
+        if len(channel_features) > 0:
+            # now we add all channel data
+            for module_fqn in filtered_data:
+                # we iterate over all channels
+                for channel in range(num_channels):
+                    # we make a new row for the channel
+                    new_channel_row = [channel_table_entry_counter, module_fqn, channel]
+                    for feature in channel_features:
+                        if feature in filtered_data[module_fqn]:
+                            # add value if applicable to module
+                            feature_val = filtered_data[module_fqn][feature][channel]
+                        else:
+                            # add that it is not applicable
+                            feature_val = "Not Applicable"
+
+                        # if it's a tensor we want to extract val
+                        if type(feature_val) is torch.Tensor:
+                            feature_val = feature_val.item()
+
+                        # add value to channel specific row
+                        new_channel_row.append(feature_val)
+
+                    # add to table and increment row index counter
+                    channel_table.append(new_channel_row)
+                    channel_table_entry_counter += 1
+
+        # add the row of headers if we actually have something, otherwise leave it empty
+        if len(channel_table) != 0:
+            channel_headers = ["idx", "layer_fqn", "channel"] + channel_features
+
+        return (channel_headers, channel_table)
+
+    def generate_filtered_tables(
+        self, feature_filter: str = "", module_fqn_filter: str = ""
+    ) -> dict[str, tuple[list, list]]:
+        r"""
+        Takes in optional filter values and generates two tables with desired information.
+
+        The generated tables are presented in a list-of-lists format.
+
+        The reason for the two tables is that they handle different things:
+        1.) the first table handles all tensor level information
+        2.) the second table handles and displays all channel based information
+
+        The reasoning for this is that having all the info in one table can make it ambiguous which collected
+            statistics are global, and which are actually per-channel, so it's better to split it up into two
+            tables. This also makes the information much easier to digest given the plethora of statistics collected
+
+        Tensor table columns:
+            idx  layer_fqn  feature_1   feature_2   feature_3   .... feature_n
+            ----  ---------  ---------   ---------   ---------        ---------
+
+        Per-Channel table columns:
+            idx  layer_fqn  channel  feature_1   feature_2   feature_3   .... feature_n
+            ----  ---------  -------  ---------   ---------   ---------        ---------
+
+        Args:
+            feature_filter (str, optional): Filters the features presented to only those that
+                contain this filter substring
+                Default = "", results in all the features being printed
+            module_fqn_filter (str, optional): Only includes modules that contain this string
+                Default = "", results in all the modules in the reports to be visible in the table
+
+        Returns a dictionary with two keys:
+            (Dict[str, Tuple[List, List]]) A dict containing two keys:
+            "tensor_level_info", "channel_level_info"
+                Each key maps to a tuple with:
+                    A list of the headers of each table
+                    A list of lists containing the table information row by row
+                    The 0th index row will contain the headers of the columns
+                    The rest of the rows will contain data
+
+        Example Use:
+            >>> # xdoctest: +SKIP("undefined variables")
+            >>> mod_report_visualizer.generate_filtered_tables(
+            ...     feature_filter="per_channel_min", module_fqn_filter="block1"
+            ... )  # generates table with per_channel_min info for all modules in block 1 of the model
+        """
+        # first get the filtered data
+        filtered_data: OrderedDict[str, Any] = self._get_filtered_data(
+            feature_filter, module_fqn_filter
+        )
+
+        # now we split into tensor and per-channel data
+        tensor_features: set[str] = set()
+        channel_features: set[str] = set()
+
+        # keep track of the number of channels we have
+        num_channels: int = 0
+
+        for module_fqn in filtered_data:
+            for feature_name in filtered_data[module_fqn]:
+                # get the data for that specific feature
+                feature_data = filtered_data[module_fqn][feature_name]
+
+                # check if not zero dim tensor
+                is_tensor: bool = isinstance(feature_data, torch.Tensor)
+                is_not_zero_dim: bool = is_tensor and len(feature_data.shape) != 0
+
+                if is_not_zero_dim or isinstance(feature_data, list):
+                    # non-zero-dim tensor or list means per-channel data
+                    channel_features.add(feature_name)
+                    num_channels = len(feature_data)
+                else:
+                    # means is per-tensor
+                    tensor_features.add(feature_name)
+
+        # we make them lists for iteration purposes
+        tensor_features_list: list[str] = sorted(tensor_features)
+        channel_features_list: list[str] = sorted(channel_features)
+
+        # get the tensor info
+        tensor_headers, tensor_table = self._generate_tensor_table(
+            filtered_data, tensor_features_list
+        )
+
+        # get the channel info
+        channel_headers, channel_table = self._generate_channels_table(
+            filtered_data, channel_features_list, num_channels
+        )
+
+        # let's now create the dictionary to return
+        table_dict = {
+            self.TABLE_TENSOR_KEY: (tensor_headers, tensor_table),
+            self.TABLE_CHANNEL_KEY: (channel_headers, channel_table),
+        }
+
+        # return the two tables
+        return table_dict
+
+    def generate_table_visualization(
+        self, feature_filter: str = "", module_fqn_filter: str = ""
+    ):
+        r"""
+        Takes in optional filter values and prints out formatted tables of the information.
+
+        The reason for the two tables printed out instead of one large one is that they handle different things:
+        1.) the first table handles all tensor level information
+        2.) the second table handles and displays all channel based information
+
+        The reasoning for this is that having all the info in one table can make it ambiguous which collected
+            statistics are global, and which are actually per-channel, so it's better to split it up into two
+            tables. This also makes the information much easier to digest given the plethora of statistics collected
+
+        Tensor table columns:
+         idx  layer_fqn  feature_1   feature_2   feature_3   .... feature_n
+        ----  ---------  ---------   ---------   ---------        ---------
+
+        Per-Channel table columns:
+
+         idx  layer_fqn  channel  feature_1   feature_2   feature_3   .... feature_n
+        ----  ---------  -------  ---------   ---------   ---------        ---------
+
+        Args:
+            feature_filter (str, optional): Filters the features presented to only those that
+                contain this filter substring
+                Default = "", results in all the features being printed
+            module_fqn_filter (str, optional): Only includes modules that contain this string
+                Default = "", results in all the modules in the reports to be visible in the table
+
+        Example Use:
+            >>> # xdoctest: +SKIP("undefined variables")
+            >>> mod_report_visualizer.generate_table_visualization(
+            ...     feature_filter="per_channel_min", module_fqn_filter="block1"
+            ... )
+            >>> # prints out neatly formatted table with per_channel_min info
+            >>> # for all modules in block 1 of the model
+        """
+        # see if we got tabulate
+        if not got_tabulate:
+            print("Make sure to install tabulate and try again.")
+            return None
+
+        # get the table dict and the specific tables of interest
+        table_dict = self.generate_filtered_tables(feature_filter, module_fqn_filter)
+        tensor_headers, tensor_table = table_dict[self.TABLE_TENSOR_KEY]
+        channel_headers, channel_table = table_dict[self.TABLE_CHANNEL_KEY]
+
+        # get the table string and print it out
+        # now we have populated the tables for each one
+        # let's create the strings to be returned
+        table_str = ""
+        # the tables will have some headers columns that are non-feature
+        # ex. table index, module name, channel index, etc.
+        # we want to look at header columns for features, that come after those headers
+        if len(tensor_headers) > self.NUM_NON_FEATURE_TENSOR_HEADERS:
+            # if we have at least one tensor level feature to be added we add tensor table
+            table_str += "Tensor Level Information \n"
+            table_str += tabulate(tensor_table, headers=tensor_headers)
+        if len(channel_headers) > self.NUM_NON_FEATURE_CHANNEL_HEADERS:
+            # if we have at least one channel level feature to be added we add the channel table
+            table_str += "\n\n Channel Level Information \n"
+            table_str += tabulate(channel_table, headers=channel_headers)
+
+        # if no features at all, let user know
+        if table_str == "":
+            table_str = "No data points to generate table with."
+
+        print(table_str)
+
+    def _get_plottable_data(
+        self, feature_filter: str, module_fqn_filter: str
+    ) -> tuple[list, list[list], bool]:
+        r"""
+        Takes in the feature filters and module filters and outputs the x and y data for plotting
+
+        Args:
+            feature_filter (str): Filters the features presented to only those that
+                contain this filter substring
+            module_fqn_filter (str): Only includes modules that contain this string
+
+        Returns a tuple of three elements
+            The first is a list containing relevant x-axis data
+            The second is a list containing the corresponding y-axis data
+            The third is a bool indicating whether the data is per channel
+        """
+        # get the table dict and the specific tables of interest
+        table_dict = self.generate_filtered_tables(feature_filter, module_fqn_filter)
+        tensor_headers, tensor_table = table_dict[self.TABLE_TENSOR_KEY]
+        channel_headers, channel_table = table_dict[self.TABLE_CHANNEL_KEY]
+
+        # make sure it is only 1 feature that is being plotted
+        # get the number of features in each of these
+        tensor_info_features_count = (
+            len(tensor_headers) - ModelReportVisualizer.NUM_NON_FEATURE_TENSOR_HEADERS
+        )
+        channel_info_features_count = (
+            len(channel_headers) - ModelReportVisualizer.NUM_NON_FEATURE_CHANNEL_HEADERS
+        )
+
+        # see if valid tensor or channel plot
+        is_valid_per_tensor_plot: bool = tensor_info_features_count == 1
+        is_valid_per_channel_plot: bool = channel_info_features_count == 1
+
+        # offset should either be one of tensor or channel table or neither
+        feature_column_offset = ModelReportVisualizer.NUM_NON_FEATURE_TENSOR_HEADERS
+        table = tensor_table
+
+        # if a per_channel plot, we have different offset and table
+        if is_valid_per_channel_plot:
+            feature_column_offset = (
+                ModelReportVisualizer.NUM_NON_FEATURE_CHANNEL_HEADERS
+            )
+            table = channel_table
+
+        x_data: list = []
+        y_data: list[list] = []
+        # the feature will either be a tensor feature or channel feature
+        if is_valid_per_tensor_plot:
+            for table_row_num, row in enumerate(table):
+                # get x_value to append
+                x_val_to_append = table_row_num
+                # the index of the feature will be 0 + the number of non-feature columns
+                tensor_feature_index = feature_column_offset
+                row_value = row[tensor_feature_index]
+                if not type(row_value) == str:
+                    x_data.append(x_val_to_append)
+                    y_data.append(row_value)
+        elif is_valid_per_channel_plot:
+            # gather the x_data and multiple y_data
+            # calculate the number of channels
+            num_channels: int = max(row[self.CHANNEL_NUM_INDEX] for row in table) + 1
+
+            # separate data list per channel
+            y_data.extend([] for _ in range(num_channels))
+
+            for table_row_num, row in enumerate(table):
+                # get x_value to append
+                x_val_to_append = table_row_num
+                current_channel = row[
+                    self.CHANNEL_NUM_INDEX
+                ]  # initially chose current channel
+                new_module_index: int = table_row_num // num_channels
+                x_val_to_append = new_module_index
+
+                # the index of the feature will be 0 + the number of non-feature columns
+                tensor_feature_index = feature_column_offset
+                row_value = row[tensor_feature_index]
+                if not type(row_value) == str:
+                    # only append if this is a new index
+                    if len(x_data) == 0 or x_data[-1] != x_val_to_append:
+                        x_data.append(x_val_to_append)
+
+                    # append value for that channel
+                    y_data[current_channel].append(row_value)
+        else:
+            # more than one feature was chosen
+            error_str = "Make sure to pick only a single feature with your filter to plot a graph."
+            error_str += " We recommend calling get_all_unique_feature_names() to find unique feature names."
+            error_str += " Pick one of those features to plot."
+            raise ValueError(error_str)
+
+        # return x, y values, and if data is per-channel
+        return (x_data, y_data, is_valid_per_channel_plot)
+
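+    # A minimal illustrative sketch (hypothetical values) of the per-channel output
+    # layout above: the channel table is ordered module-by-module and
+    # channel-by-channel, so a 2-module, 2-channel filter result is grouped as
+    #
+    #   x_data = [0, 1]                            # one x value per module
+    #   y_data = [[m0_c0, m1_c0], [m0_c1, m1_c1]]  # one list per channel
+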
+    def generate_plot_visualization(
+        self, feature_filter: str, module_fqn_filter: str = ""
+    ):
+        r"""
+        Takes in a feature and optional module_filter and plots the desired data.
+
+        For per channel features, it averages the value across the channels and plots a point
+        per module. The reason for this is that for models with hundreds of channels, it can
+        be hard to differentiate one channel line from another, and so the point of generating
+        a single average point per module is to give a sense of general trends that encourage
+        further deep dives.
+
+        Note:
+            Only features in the report that have tensor value data are plottable by this class
+            When the tensor information is plotted, it will plot:
+                idx as the x val, feature value as the y_val
+            When the channel information is plotted, it will plot:
+                the first idx of each module as the x val, feature value as the y_val [for each channel]
+                The reason for this is that we want to be able to compare values across the
+                channels for same layer, and it will be hard if values are staggered by idx
+                This means each module is represented by only 1 x value
+        Args:
+            feature_filter (str): Filters the features presented to only those that
+                contain this filter substring
+            module_fqn_filter (str, optional): Only includes modules that contain this string
+                Default = "", results in all the modules in the reports to be visible in the table
+
+        Example Use:
+            >>> # xdoctest: +SKIP("undefined variables")
+            >>> mod_report_visualizer.generate_plot_visualization(
+            ...     feature_filter="per_channel_min", module_fqn_filter="block1"
+            ... )
+            >>> # outputs a line plot of per_channel_min information for all
+            >>> # modules in block1 of the model; the per-channel values are
+            >>> # averaged per module and plotted across the in-order modules on the x-axis
+        """
+        # check if we have matplotlib and let the user know to install it if we don't
+        if not got_matplotlib:
+            print("make sure to install matplotlib and try again.")
+            return None
+
+        # get the x and y data and if per channel
+        x_data, y_data, data_per_channel = self._get_plottable_data(
+            feature_filter, module_fqn_filter
+        )
+
+        # plot based on whether data is per channel or not
+        ax = plt.subplot()
+        ax.set_ylabel(feature_filter)
+        ax.set_title(feature_filter + " Plot")
+        plt.xticks(x_data)  # only show ticks for actual points
+
+        if data_per_channel:
+            ax.set_xlabel("First idx of module")
+            # set the legend as well
+            # plot a single line that is average of the channel values
+            num_modules = len(
+                y_data[0]
+            )  # all y_data have same length, so get num modules
+            num_channels = len(
+                y_data
+            )  # we want num channels to be able to calculate average later
+
+            avg_vals = [
+                sum(y_data[channel][index] for channel in range(num_channels))
+                / num_channels
+                for index in range(num_modules)
+            ]
+
+            # plot the per-module average line
+            ax.plot(
+                x_data, avg_vals, label=f"Average Value Across {num_channels} Channels"
+            )
+            ax.legend(loc="upper right")
+        else:
+            ax.set_xlabel("idx")
+            ax.plot(x_data, y_data)
+
+        # actually show the plot
+        plt.show()
+
+    def generate_histogram_visualization(
+        self, feature_filter: str, module_fqn_filter: str = "", num_bins: int = 10
+    ):
+        r"""
+        Takes in a feature and optional module_filter and plots the histogram of desired data.
+
+        Note:
+            Only features in the report that have tensor value data can be viewed as a histogram
+            If you want to plot a histogram from all the channel values of a specific feature for
+                a specific model, make sure to specify both the model and the feature properly
+                in the filters and you should be able to see a distribution of the channel data
+
+        Args:
+            feature_filter (str): Filters the features presented to only those that
+                contain this filter substring
+            module_fqn_filter (str, optional): Only includes modules that contain this string
+                Default = "", results in all the modules in the reports to be visible in the table
+            num_bins (int, optional): The number of bins to create the histogram with
+                Default = 10, the values will be split into 10 equal sized bins
+
+        Example Use:
+            >>> # xdoctest: +SKIP
+            >>> mod_report_visualizer.generate_histogram_visualization(
+            ...     feature_filter="per_channel_min", module_fqn_filter="block1"
+            ... )
+            >>> # outputs histogram of per_channel_min information for all modules in block1 of the model
+            >>> # information is gathered across all channels for all modules in block 1 for the
+            >>> # per_channel_min and is displayed in a histogram of equally sized bins
+        """
+        # check if we have matplotlib and let the user know to install it if we don't
+        if not got_matplotlib:
+            print("make sure to install matplotlib and try again.")
+            return None
+
+        # get the x and y data and if per channel
+        _x_data, y_data, data_per_channel = self._get_plottable_data(
+            feature_filter, module_fqn_filter
+        )
+
+        # for histogram, we just care about plotting the y data
+        # plot based on whether data is per channel or not
+        ax = plt.subplot()
+        ax.set_xlabel(feature_filter)
+        ax.set_ylabel("Frequency")
+        ax.set_title(feature_filter + " Histogram")
+
+        if data_per_channel:
+            # set the legend as well
+            # combine all the data
+            all_data = []
+            for channel_info in y_data:
+                all_data.extend(channel_info)
+
+            _val, bins, _ = plt.hist(
+                all_data,
+                bins=num_bins,
+                stacked=True,
+                rwidth=0.8,
+            )
+            plt.xticks(bins)
+        else:
+            _val, bins, _ = plt.hist(
+                y_data,
+                bins=num_bins,
+                stacked=False,
+                rwidth=0.8,
+            )
+            plt.xticks(bins)
+
+        plt.show()
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/convert.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/convert.py
new file mode 100644
index 0000000000000000000000000000000000000000..9513fb288850b515b6e3c75660eccc0336a257ca
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/convert.py
@@ -0,0 +1,1264 @@
+# mypy: ignore-errors
+
+import copy
+import operator
+import warnings
+from typing import Any, Callable, Optional, Union
+
+import torch
+from torch.ao.quantization import CUSTOM_KEY, NUMERIC_DEBUG_HANDLE_KEY
+from torch.ao.quantization.backend_config import (
+    BackendConfig,
+    get_native_backend_config,
+)
+from torch.ao.quantization.backend_config.utils import (
+    get_fused_module_classes,
+    get_pattern_to_dtype_configs,
+    get_qat_module_classes,
+    get_root_module_to_quantized_reference_module,
+)
+from torch.ao.quantization.observer import _is_activation_post_process
+from torch.ao.quantization.qconfig import qconfig_equals, QConfigAny
+from torch.ao.quantization.qconfig_mapping import QConfigMapping
+from torch.ao.quantization.quant_type import QuantType
+from torch.ao.quantization.quantize import _remove_qconfig
+from torch.ao.quantization.stubs import DeQuantStub
+from torch.ao.quantization.utils import (
+    _parent_name,
+    activation_is_statically_quantized,
+    get_qparam_dict,
+    get_swapped_custom_module_class,
+    is_per_channel,
+    to_underlying_dtype,
+    weight_is_quantized,
+)
+from torch.fx import GraphModule
+from torch.fx.graph import Argument, Graph, Node
+from torch.nn.utils.parametrize import type_before_parametrizations
+
+# importing the lib so that the quantized_decomposed ops are registered
+from ._decomposed import quantized_decomposed_lib  # noqa: F401
+from ._equalize import convert_eq_obs, update_obs_for_equalization
+from .custom_config import ConvertCustomConfig, PrepareCustomConfig
+from .graph_module import _is_observed_module, _is_observed_standalone_module
+from .lower_to_fbgemm import lower_to_fbgemm
+from .qconfig_mapping_utils import (
+    _compare_prepare_convert_qconfig_mappings,
+    _generate_node_name_to_qconfig,
+    _is_qconfig_supported_by_dtype_configs,
+    _update_qconfig_for_fusion,
+    _update_qconfig_for_qat,
+)
+from .utils import (
+    _get_module,
+    _is_custom_module_lstm,
+    _is_custom_module_mha,
+    assert_and_get_unique_device,
+    collect_producer_nodes,
+    create_getattr_from_value,
+    get_custom_module_class_keys,
+    graph_module_from_producer_nodes,
+    node_arg_is_weight,
+)
+
+
+__all__ = [
+    "convert",
+    "convert_custom_module",
+    "convert_standalone_module",
+    "convert_weighted_module",
+]
+
+SUPPORTED_QDTYPES = [
+    torch.quint8,
+    torch.qint8,
+    torch.qint32,
+    torch.uint8,
+    torch.int8,
+    torch.uint16,
+    torch.int16,
+    torch.int32,
+    torch.float8_e5m2,
+    torch.float8_e4m3fn,
+]
+
+_QSCHEME_TO_CHOOSE_QPARAMS_OP = {
+    torch.per_tensor_affine: torch.ops.quantized_decomposed.choose_qparams.tensor,
+    torch.per_tensor_symmetric: torch.ops.quantized_decomposed.choose_qparams_symmetric.tensor,
+}
+
+
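+# A minimal illustrative sketch (hypothetical qparams, not taken from an observer)
+# of the decomposed quantize/dequantize pair that the observer-replacement passes
+# below insert into the graph for the static per-tensor case:
+#
+#   >>> x = torch.tensor([[0.05, -0.3], [1.0, 2.5]])
+#   >>> q = torch.ops.quantized_decomposed.quantize_per_tensor.default(
+#   ...     x, 0.1, 0, -128, 127, torch.int8
+#   ... )
+#   >>> dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+#   ...     q, 0.1, 0, -128, 127, torch.int8
+#   ... )
+#   >>> torch.allclose(dq, x, atol=0.1)
+#   True
+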
+def _replace_observer_with_quantize_dequantize_node_decomposed(
+    model: torch.fx.GraphModule,
+    node: Node,
+    modules: dict[str, torch.nn.Module],
+    node_name_to_scope: dict[str, tuple[str, type]],
+    node_name_to_qconfig: dict[str, QConfigAny],
+) -> None:
+    """Replace activation_post_process module call node with quantize and
+    dequantize node working with decomposed Tensor
+
+    Before:
+    ... -> observer_0(x) -> ...
+    After:
+    ... -> torch.ops.quantized_decomposed.quantize_per_tensor(x, ...) ->
+    torch.ops.quantized_decomposed.dequantize_per_tensor() -> ...
+
+    or quantize_per_channel and dequantize_per_channel
+    """
+    graph = model.graph
+    assert modules is not None
+    assert isinstance(node.target, str)
+    module_path, prefix = _get_module_path_and_prefix(
+        node, node_name_to_scope, node_name_to_qconfig
+    )
+    activation_post_process = modules[node.target]
+    if hasattr(activation_post_process, "convert"):
+        activation_post_process.convert(model, node)
+        return
+    # skip replacing observers to quant/dequant nodes if the qconfigs of all
+    # consumers and producers of this observer are None
+    skip_replacement = all(
+        _has_none_qconfig(n, node_name_to_qconfig)
+        for n in list(node.args) + list(node.users.keys())
+    )
+    if skip_replacement or not _is_conversion_supported(activation_post_process):
+        # didn't find corresponding quantize op and info for the activation_post_process
+        # so we just remove the observer
+        with graph.inserting_before(node):
+            node.replace_all_uses_with(node.args[0])
+            graph.erase_node(node)
+        return
+
+    # otherwise, we can convert the activation_post_process module call to quantize/dequantize node
+
+    # 1. extract the information from activation_post_process module for generating
+    # the quantize and dequantize operator
+    dtype = activation_post_process.dtype  # type: ignore[attr-defined]
+
+    is_dynamic = False
+    if hasattr(activation_post_process, "is_dynamic"):
+        is_dynamic = activation_post_process.is_dynamic  # type: ignore[assignment]
+
+    def add_dequantize_op_kwargs(dequantize_op, input_node):
+        dequantize_op_kwargs = {}
+        if "val" in input_node.meta:
+            dq_out_dtype = input_node.meta["val"].dtype
+            if dq_out_dtype != torch.float32:
+                dequantize_op_kwargs = {"out_dtype": dq_out_dtype}
+        return dequantize_op_kwargs
+
+    if dtype in SUPPORTED_QDTYPES and (not is_dynamic):
+        # TODO: probably should cleanup this condition check, it's hard
+        # to reason about this if and the following elif
+
+        # uint8/int8/int32 static quantization branch
+
+        # 1. extract information for inserting q/dq node from activation_post_process
+        node_type = "call_function"
+        quantize_op: Optional[Callable] = None
+        scale, zero_point = activation_post_process.calculate_qparams()  # type: ignore[attr-defined, operator]
+        if is_per_channel(activation_post_process.qscheme):  # type: ignore[attr-defined]
+            ch_axis = int(activation_post_process.ch_axis)  # type: ignore[attr-defined, arg-type]
+            quantize_op = torch.ops.quantized_decomposed.quantize_per_channel.default
+            dequantize_op = (
+                torch.ops.quantized_decomposed.dequantize_per_channel.default
+            )
+            quant_min = activation_post_process.quant_min
+            quant_max = activation_post_process.quant_max
+            dtype_ = to_underlying_dtype(dtype)
+            qparams = {
+                "_scale_": scale,
+                "_zero_point_": zero_point,
+                "_axis_": ch_axis,
+                "_quant_min_": quant_min,
+                "_quant_max_": quant_max,
+                "_dtype_": dtype_,
+            }
+        else:
+            quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.default
+            dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default
+            scale = float(scale)
+            zero_point = int(zero_point)
+            quant_min = activation_post_process.quant_min  # type: ignore[attr-defined]
+            quant_max = activation_post_process.quant_max  # type: ignore[attr-defined]
+            dtype_ = to_underlying_dtype(dtype)
+            qparams = {
+                "_scale_": scale,
+                "_zero_point_": zero_point,
+                "_quant_min_": quant_min,
+                "_quant_max_": quant_max,
+                "_dtype_": dtype_,
+            }
+
+        # 2. replace activation_post_process node with quantize and dequantize
+        with graph.inserting_before(node):
+            input_node = node.args[0]
+            quantize_op_inputs = [input_node]
+            for key, value_or_node in qparams.items():
+                # TODO: we can add the information of whether a value needs to
+                # be registered as an attribute in qparams dict itself
+                if key in ["_scale_", "_zero_point_"] and (
+                    not isinstance(value_or_node, (float, int))
+                ):
+                    # For scale and zero_point values we register them as buffers in the root module.
+                    # However, note that when the values are not tensors, as in the case of
+                    # per_tensor quantization, they will be treated as literals.
+                    # However, registering them as a node seems to cause issue with dynamo
+                    # tracing where it may consider tensor overload as opposed to default.
+                    # With extra check of scale and zero_point being scalar, it makes
+                    # sure that the default overload can be used.
+                    # TODO: maybe need more complex attr name here
+                    qparam_node = create_getattr_from_value(
+                        model, graph, module_path + prefix + key, value_or_node
+                    )
+                    quantize_op_inputs.append(qparam_node)
+                else:
+                    # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph.
+                    quantize_op_inputs.append(value_or_node)
+
+            quantized_node = graph.create_node(
+                node_type, quantize_op, tuple(quantize_op_inputs), {}
+            )
+            # use the same qparams from quantize op
+            dq_inputs = [quantized_node] + quantize_op_inputs[1:]
+            dequantized_node = graph.call_function(
+                dequantize_op,
+                tuple(dq_inputs),
+                add_dequantize_op_kwargs(dequantize_op, input_node),
+            )
+
+            node.replace_all_uses_with(dequantized_node)
+            # propagate numeric debug handle from observer/fake_quant node to dequantize node
+            if (
+                CUSTOM_KEY in node.meta
+                and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY]
+            ):
+                if CUSTOM_KEY not in dequantized_node.meta:
+                    dequantized_node.meta[CUSTOM_KEY] = {}
+                dequantized_node.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] = node.meta[
+                    CUSTOM_KEY
+                ][NUMERIC_DEBUG_HANDLE_KEY]
+            graph.erase_node(node)
+    elif is_dynamic:
+        # uint8/int8/fp16 dynamic quantization
+
+        # 1. extract information for inserting q/dq node from activation_post_process
+        node_type = "call_function"
+        quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.tensor
+        # we only use choose_qparams for is_decomposed now,
+        # but we should probably align the non-decomposed path with this as well,
+        # and that can be done after we remove reduce_range flag
+        # 1. extract qparams from activation_post_process module
+        dtype_ = to_underlying_dtype(dtype)
+        assert dtype_ in [torch.uint8, torch.int8], (
+            "only uint8 and int8 are supported in reference flow for "
+            "dynamic quantization right now"
+        )
+        quant_min = activation_post_process.quant_min  # type: ignore[attr-defined]
+        quant_max = activation_post_process.quant_max  # type: ignore[attr-defined]
+        qscheme = getattr(activation_post_process, "qscheme", torch.per_tensor_affine)  # type: ignore[attr-defined]
+        eps = getattr(activation_post_process, "eps", torch.finfo(torch.float32).eps)  # type: ignore[attr-defined]
+        # note: scale and zero_point are missing for quantize_per_tensor op
+        # we'll need to get this from choose_qparams op, which we'll add after
+        # this step
+        qparams = {
+            "_quant_min_": quant_min,
+            "_quant_max_": quant_max,
+            "_eps_": eps,
+            "_dtype_": dtype_,
+        }
+
+        choose_qparams_op = _QSCHEME_TO_CHOOSE_QPARAMS_OP[qscheme]
+        # 2. insert choose_qparams op and update the qparams list
+        with graph.inserting_before(node):
+            input_node = node.args[0]
+            choose_qparams_op_inputs = [node.args[0]]
+            for key, value in qparams.items():
+                # we have quant_min, quant_max and dtype, all should be stored
+                # as literals
+                choose_qparams_op_inputs.append(value)
+            choose_qparams_node = graph.create_node(
+                "call_function", choose_qparams_op, tuple(choose_qparams_op_inputs), {}
+            )
+            # choose_qparams returns (scale, zero_point)
+            scale_node = graph.create_node(
+                "call_function", operator.getitem, (choose_qparams_node, 0), {}
+            )
+            zero_point_node = graph.create_node(
+                "call_function", operator.getitem, (choose_qparams_node, 1), {}
+            )
+            quant_min = qparams["_quant_min_"]
+            quant_max = qparams["_quant_max_"]
+            dtype = qparams["_dtype_"]
+            qparams = {
+                "_scale_": scale_node,
+                "_zero_point_": zero_point_node,
+                "_quant_min_": quant_min,
+                "_quant_max_": quant_max,
+                "_dtype_": dtype,
+            }
+
+        # 3. replace activation_post_process node to quantize and dequantize node
+        with graph.inserting_before(node):
+            input_node = node.args[0]
+            quantize_op_inputs = [input_node]
+            for key, value_or_node in qparams.items():
+                # TODO: we can add the information of whether a value needs to
+                # be registered as an attribute in qparams dict itself
+                if key in ["_scale_", "_zero_point_"]:
+                    # in this case we have a node in the graph since it's dynamically
+                    # computed from the input, with choose_qparams op
+                    qparam_node = value_or_node
+                    quantize_op_inputs.append(qparam_node)
+                else:
+                    # for qparams that are not scale/zero_point (like axis, dtype) we
+                    # store them as literals in the graph.
+                    quantize_op_inputs.append(value_or_node)
+
+            quantized_node = graph.create_node(
+                node_type, quantize_op, tuple(quantize_op_inputs), {}
+            )
+            # use the same qparams from quantize op
+            dq_inputs = [quantized_node] + quantize_op_inputs[1:]
+            # need to use the tensor variant of this op, since scale and zero_point
+            # from choose_qparam are Tensors, instead of float/int, this is to
+            # prevent these nodes being traced away by downstream systems
+            dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
+            dequantized_node = graph.call_function(
+                dequantize_op,
+                tuple(dq_inputs),
+                add_dequantize_op_kwargs(dequantize_op, input_node),
+            )
+
+            node.replace_all_uses_with(dequantized_node)
+            # propagate numeric debug handle from observer/fake_quant node to dequantize node
+            if NUMERIC_DEBUG_HANDLE_KEY in node.meta:
+                dequantized_node.meta[NUMERIC_DEBUG_HANDLE_KEY] = node.meta[
+                    NUMERIC_DEBUG_HANDLE_KEY
+                ]
+            graph.erase_node(node)
+    elif dtype == torch.float16:
+        # Insert to_fp16 -> to_fp32 node
+        dtype_convert_op = torch.ops.quantized_decomposed.convert_element_type.no_fuse
+        with graph.inserting_before(node):
+            input_node = node.args[0]
+            convert_fp16_node = graph.create_node(
+                "call_function", dtype_convert_op, (input_node, torch.float16), {}
+            )
+            convert_fp32_node = graph.create_node(
+                "call_function", dtype_convert_op, (convert_fp16_node, torch.float), {}
+            )
+            node.replace_all_uses_with(convert_fp32_node)
+            graph.erase_node(node)
+
+    # should not reach since we have checks in the beginning to make sure the
+    # activation_post_process is supported
+
+
+def _replace_observer_with_quantize_dequantize_node(
+    model: torch.fx.GraphModule,
+    node: Node,
+    modules: dict[str, torch.nn.Module],
+    node_name_to_scope: dict[str, tuple[str, type]],
+    node_name_to_qconfig: dict[str, QConfigAny],
+) -> None:
+    """Replace activation_post_process module call node with quantize and
+    dequantize node
+
+    Before:
+    ... -> observer_0(x) -> ...
+    After:
+    ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...
+    """
+    assert modules is not None
+    assert isinstance(node.target, str)
+    graph = model.graph
+    module_path, prefix = _get_module_path_and_prefix(
+        node, node_name_to_scope, node_name_to_qconfig
+    )
+    activation_post_process = modules[node.target]
+    # skip replacing observers to quant/dequant nodes if the qconfigs of all
+    # consumers and producers of this observer are None
+    skip_replacement = all(
+        _has_none_qconfig(n, node_name_to_qconfig)
+        for n in list(node.args) + list(node.users.keys())
+    )
+    if skip_replacement or not _is_conversion_supported(activation_post_process):
+        # didn't find corresponding quantize op and info for the activation_post_process
+        # so we just remove the observer
+        with graph.inserting_before(node):
+            node.replace_all_uses_with(node.args[0])
+            graph.erase_node(node)
+        return
+
+    # otherwise, we can convert the activation_post_process module call to quantize/dequantize node
+    dtype = activation_post_process.dtype  # type: ignore[attr-defined]
+
+    is_dynamic = False
+    if hasattr(activation_post_process, "is_dynamic"):
+        is_dynamic = activation_post_process.is_dynamic  # type: ignore[attr-defined, assignment]
+
+    if dtype in [
+        torch.quint8,
+        torch.qint8,
+        torch.qint32,
+        torch.float8_e5m2,
+        torch.float8_e4m3fn,
+    ] and (not is_dynamic):
+        # TODO: probably should cleanup this condition check, it's hard
+        # to reason about this if and the following elif
+
+        # uint8/int8/int32 static quantization branch
+
+        # 1. extract the information from activation_post_process module for generating
+        # the quantize and dequantize operator
+        node_type = "call_function"
+        quantize_op: Optional[Callable] = None
+        scale, zero_point = activation_post_process.calculate_qparams()  # type: ignore[attr-defined, operator]
+        if is_per_channel(activation_post_process.qscheme):  # type: ignore[attr-defined]
+            ch_axis = int(activation_post_process.ch_axis)  # type: ignore[attr-defined, arg-type]
+            qparams = {
+                "_scale_": scale,
+                "_zero_point_": zero_point,
+                "_axis_": ch_axis,
+                "_dtype_": dtype,
+            }
+            quantize_op = torch.quantize_per_channel
+        else:
+            scale = float(scale)
+            zero_point = int(zero_point)
+            qparams = {"_scale_": scale, "_zero_point_": zero_point, "_dtype_": dtype}
+            quantize_op = torch.quantize_per_tensor
+
+        # 2. replace activation_post_process node with quantize and dequantize
+        with graph.inserting_before(node):
+            input_node = node.args[0]
+            quantize_op_inputs = [input_node]
+            for key, value_or_node in qparams.items():
+                # TODO: we can add the information of whether a value needs to
+                # be registered as an attribute in qparams dict itself
+                if key in ["_scale_", "_zero_point_"]:
+                    # For scale and zero_point values we register them as buffers in the root module.
+                    # TODO: maybe need more complex attr name here
+                    qparam_node = create_getattr_from_value(
+                        model, graph, module_path + prefix + key, value_or_node
+                    )
+                    quantize_op_inputs.append(qparam_node)
+                else:
+                    # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph.
+                    quantize_op_inputs.append(value_or_node)
+
+            quantized_node = graph.create_node(
+                node_type, quantize_op, tuple(quantize_op_inputs), {}
+            )
+            dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
+            node.replace_all_uses_with(dequantized_node)
+            graph.erase_node(node)
+    elif is_dynamic:
+        # uint8/int8/fp16 dynamic quantization branch
+
+        node_type = "call_function"
+        quantize_op = torch.quantize_per_tensor_dynamic
+        # TODO: get reduce range from observer
+        # reduce_range = activation_post_process.reduce_range
+        reduce_range = torch.backends.quantized.engine in ("fbgemm", "x86")
+        qparams = {"_dtype_": dtype, "_reduce_range_": reduce_range}
+
+        with graph.inserting_before(node):
+            input_node = node.args[0]
+            quantize_op_inputs = [input_node]
+            for key, value in qparams.items():
+                quantize_op_inputs.append(value)
+
+            quantized_node = graph.create_node(
+                node_type, quantize_op, tuple(quantize_op_inputs), {}
+            )
+            dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
+            node.replace_all_uses_with(dequantized_node)
+            graph.erase_node(node)
+    elif dtype == torch.float16:
+        node_type = "call_method"
+        quantize_op = "to"  # type: ignore[assignment]
+        qparams = {"_dtype_": dtype}
+        with graph.inserting_before(node):
+            input_node = node.args[0]
+            quantize_op_inputs = [input_node]
+            for key, value in qparams.items():
+                # TODO: we can add the information of whether a value needs to
+                # be registered as an attribute in qparams dict itself
+                quantize_op_inputs.append(value)
+
+            quantized_node = graph.create_node(
+                node_type, quantize_op, tuple(quantize_op_inputs), {}
+            )
+            dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
+            node.replace_all_uses_with(dequantized_node)
+            graph.erase_node(node)
+
+    # we should not reach here since the checks at the beginning ensure the
+    # activation_post_process is supported
+
+
+# this is a temporary hack for custom module, we may want to implement
+# this properly after the custom module class design is finalized
+# TODO: DeQuantStubs are currently inserted only after custom module LSTM, while observers are inserted
+# after all other custom modules. In the future, we should simply insert QuantStubs before and DeQuantStubs
+# after custom modules in general, and replace these with "quantize" and "dequantize" nodes respectively.
+def _replace_observer_or_dequant_stub_with_dequantize_node(
+    node: Node, graph: Graph
+) -> None:
+    call_custom_module_node = node.args[0]
+    assert isinstance(call_custom_module_node, Node), (
+        f"Expecting the for call custom module node to be a Node, but got {call_custom_module_node}"
+    )
+    node.replace_all_uses_with(call_custom_module_node)
+    graph.erase_node(node)
+    _insert_dequantize_node(call_custom_module_node, graph)
+
+
+def _is_conversion_supported(activation_post_process: torch.nn.Module) -> bool:
+    dtype = activation_post_process.dtype  # type: ignore[attr-defined]
+
+    is_dynamic = False
+    if hasattr(activation_post_process, "is_dynamic"):
+        is_dynamic = activation_post_process.is_dynamic  # type: ignore[attr-defined, assignment]
+
+    return (
+        (dtype in SUPPORTED_QDTYPES and (not is_dynamic))
+        or is_dynamic  # type: ignore[return-value]
+        or dtype == torch.float16
+    )
+
+
+def _has_none_qconfig(
+    node: Argument, node_name_to_qconfig: dict[str, QConfigAny]
+) -> bool:
+    """Check if a node has a qconfig of None, i.e. user requested to not quantize
+    the node
+    """
+    return (
+        isinstance(node, Node)
+        and node.name in node_name_to_qconfig
+        and node_name_to_qconfig[node.name] is None
+    )
+
+
+def _run_weight_observers(observed: GraphModule, backend_config: BackendConfig) -> None:
+    """Extract the subgraph that produces the weight for dynamic quant
+    or weight only quant node and run the subgraph to observe the weight.
+    Note that the observers of dynamic quant or weight only quant ops are
+    run during the convert step.
+    """
+    for node in observed.graph.nodes:
+        if node.op != "call_function":
+            continue
+        for node_arg in node.args:
+            # node_arg is weight
+            if node_arg and node_arg_is_weight(node, node_arg):
+                weight_observer_nodes = collect_producer_nodes(node_arg)
+                if weight_observer_nodes is None:
+                    continue
+                weight_observer_module = graph_module_from_producer_nodes(
+                    observed, weight_observer_nodes
+                )
+                # run the weight observer
+                weight_observer_module()
+
+
+def _maybe_recursive_remove_dequantize(arg: Any, node: Node, graph: Graph) -> None:
+    """If the arg is a dequantize Node, or a list/tuple/dict of dequantize Node,
+    we'll recursively remove the dequantize Node
+    """
+    if isinstance(arg, Node) and arg.op == "call_method" and arg.target == "dequantize":
+        quantize_node = arg.args[0]
+        # we only replace the specific use since dequantize could be used by other nodes
+        # as well
+        node.replace_input_with(arg, quantize_node)
+    elif isinstance(arg, (list, tuple)):
+        for arg_element in arg:
+            _maybe_recursive_remove_dequantize(arg_element, node, graph)
+    elif isinstance(arg, dict):
+        for arg_element in arg.values():
+            _maybe_recursive_remove_dequantize(arg_element, node, graph)
+    else:
+        warnings.warn(
+            f"Unsupported node type in recursive remove dequantize: {type(arg)}"
+        )
+
+
+def _get_module_path_and_prefix(
+    obs_node: Node,
+    node_name_to_scope: dict[str, tuple[str, type]],
+    node_name_to_qconfig: dict[str, QConfigAny],
+) -> tuple[str, str]:
+    """Given and observer node, get the `Scope` or the fully qualified name for
+    the submodule containing the observed node, also return a prefix of "_input"
+    when the observed node is an input of a F.linear op, and not the output of another
+    quantized op.
+    TODO: this logic is hacky, we should think about how to remove it or make it more
+    general
+    """
+    observed_node = obs_node.args[0]
+    # an observer can be inserted for both input of the next operator or output of the previous
+    # operator (they can be the same)
+    # this flag identifies if the observer is inserted only because the observed node is
+    # the input of the next operator
+    assert isinstance(observed_node, Node), (
+        f"Expecting observed node to be a Node, but got {observed_node}"
+    )
+    is_input_observer_only = (
+        node_name_to_qconfig[observed_node.name] is None
+        if observed_node.name in node_name_to_qconfig
+        else None
+    )
+    if is_input_observer_only:
+        # if the quantize function is at the input of op, then we find the first user of the observer_node
+        # to get the path. If a linear call_function is in the user list, we return the first instance
+        # of linear node to get the FQN.
+        users = list(obs_node.users)
+        first_linear_use_or_first_use = users[0] if users else None
+        linear_node = None
+        for n in users:
+            if n.op == "call_function" and n.target == torch.nn.functional.linear:
+                linear_node = n
+                break
+        if linear_node:
+            first_linear_use_or_first_use = linear_node
+        prefix = "_input"
+    else:
+        # if the quantize function is at the output of the op, we use the observer input node to get the path
+        first_linear_use_or_first_use = observed_node
+        prefix = ""
+
+    if (
+        first_linear_use_or_first_use
+        and first_linear_use_or_first_use.name in node_name_to_scope
+    ):
+        module_path, _ = node_name_to_scope[first_linear_use_or_first_use.name]
+    else:
+        # TODO: it's not used, so actually we can skip quantization
+        # but this requires changing return type of quantize_node
+        # we can fix it later if needed
+        module_path = ""
+    return module_path, prefix
+
+
+def _insert_dequantize_node(node: Node, graph: Graph) -> None:
+    """Inserts dequantize node for `node` in `graph`"""
+    with graph.inserting_after(node):
+        dequantize_node = graph.call_method("dequantize", (node,))
+        for user_node in dict(node.users):
+            if user_node is not dequantize_node:
+                user_node.replace_input_with(node, dequantize_node)
+
+
+def _maybe_get_observer_for_node(
+    node: Node, modules: dict[str, torch.nn.Module]
+) -> Optional[torch.nn.Module]:
+    """
+    If the node is observed, return the observer
+    instance. Otherwise, return None.
+    """
+    for maybe_obs_node in node.users.keys():
+        if maybe_obs_node.op == "call_module":
+            maybe_obs = modules[str(maybe_obs_node.target)]
+            if _is_activation_post_process(maybe_obs):
+                return maybe_obs
+    return None
+
+
+def convert_standalone_module(
+    node: Node,
+    modules: dict[str, torch.nn.Module],
+    model: torch.fx.GraphModule,
+    is_reference: bool,
+    backend_config: Optional[BackendConfig],
+) -> None:
+    """Converts a observed standalone module to a quantized standalone module by calling
+    the fx convert api, currently using the same `is_reference` flag as parent, but we may
+    changing this behavior in the future (e.g. separating quantization and lowering for
+    standalone module as well)
+
+    Args:
+      - node: The call_module node of the observed standalone module
+      - modules: named_module of original model
+      - model: original model
+      - is_reference: a flag from parent provided by user to decide if we want to
+        produce a reference model or a fbgemm/qnnpack model
+      - backend_config: backend configuration of the target backend of quantization
+    """
+    # TODO: remove is_reference flag
+    if is_reference:
+        convert_fn = torch.ao.quantization.quantize_fx.convert_to_reference_fx
+    else:
+        convert_fn = torch.ao.quantization.quantize_fx.convert_fx  # type: ignore[attr-defined]
+    # We know that observed standalone module is a GraphModule since
+    # it's produced by us
+    observed_standalone_module: GraphModule = modules[str(node.target)]  # type: ignore[assignment]
+    sm_input_quantized_idxs = observed_standalone_module.meta[
+        "_observed_graph_module_attrs"
+    ].standalone_module_input_quantized_idxs
+    # remove the dequantize nodes for inputs
+    args = list(node.args)
+    for idx in range(len(args)):
+        if idx in sm_input_quantized_idxs:
+            arg = args[idx]
+            if arg.op == "call_method" and arg.target == "dequantize":  # type: ignore[union-attr]
+                quantize_node = arg.args[0]  # type: ignore[union-attr]
+                node.replace_input_with(arg, quantize_node)
+                if len(arg.users) == 0:  # type: ignore[union-attr]
+                    model.graph.erase_node(arg)
+    # add dequantize node for output
+    sm_output_quantized_idxs = observed_standalone_module.meta[
+        "_observed_graph_module_attrs"
+    ].standalone_module_output_quantized_idxs
+    if len(sm_output_quantized_idxs) > 0:
+        assert sm_output_quantized_idxs[0] == 0, "Currently only quantized"
+        "output idxs = [0] is supported"
+
+        # if it's non-empty, then it means the output is kept in quantized form
+        # we'll just add a dequantize node after this node
+        _insert_dequantize_node(node, model.graph)
+
+    # TODO: allow convert_custom_config to override backend_config
+    # for standalone module
+    quantized_standalone_module = convert_fn(
+        observed_standalone_module, backend_config=backend_config
+    )
+    parent_name, name = _parent_name(node.target)
+    # update the modules dict
+    setattr(modules[parent_name], name, quantized_standalone_module)
+    modules[str(node.target)] = quantized_standalone_module
+
+
+def convert_weighted_module(
+    node: Node,
+    modules: dict[str, torch.nn.Module],
+    observed_node_names: set[str],
+    node_name_to_qconfig: dict[str, QConfigAny],
+    backend_config: BackendConfig,
+    is_decomposed: bool = False,
+    is_reference: bool = False,
+) -> None:
+    """Convert a weighted module to reference quantized module in the model
+    If the QConfig of a QAT module is not set, the module will still be converted to
+    a float module.
+
+    Args:
+      - node: The call_module node of the observed weighted module
+      - modules: named_module of original model
+      - observed_node_names: names of the observed fx nodes; we can skip
+        this conversion if the node is not observed
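+
+    A rough sketch of the swap performed here (module names illustrative)::
+
+        # before: modules["linear"] is a float torch.nn.Linear with an observed weight
+        # after:  modules["linear"] is a reference quantized Linear
+        #         (torch.ao.nn.quantized.reference.Linear), produced by
+        #         ref_qmodule_cls.from_float(float_module, wq_or_wq_dict) below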
+    """
+    original_module = modules[str(node.target)]
+    qconfig: QConfigAny = original_module.qconfig  # type: ignore[assignment]
+    weight_post_process = None
+    qat_module_classes = get_qat_module_classes(backend_config)
+
+    if isinstance(original_module, qat_module_classes):
+        # When converting a qat module to a float module, we need to attach the
+        # weight fake_quant to the module. The weight fake_quant is assumed to have been
+        # run during QAT, so we don't need to run it again here
+        weight_post_process = original_module.weight_fake_quant
+        original_module = original_module.to_float()  # type: ignore[operator]
+        # change qat module to float module
+        parent_name, name = _parent_name(node.target)
+        setattr(modules[parent_name], name, original_module)
+
+    is_observed = node.name in observed_node_names
+    # If a qconfig is not defined for this node, then skip converting to a reference module
+    if (
+        qconfig is None
+        or _has_none_qconfig(node, node_name_to_qconfig)
+        or not is_observed
+    ):
+        return
+
+    # skip converting to reference quantized module if the qconfig is not supported
+    pattern_to_dtype_configs = get_pattern_to_dtype_configs(backend_config)
+    dtype_configs = pattern_to_dtype_configs.get(type(original_module), [])
+    if not _is_qconfig_supported_by_dtype_configs(qconfig, dtype_configs):
+        return
+
+    # TODO: rename weight_is_statically_quantized to weight_is_int8_quantized
+    is_weight_quantized = weight_is_quantized(qconfig)
+
+    # the condition for swapping the module to reference quantized module is:
+    # weights need to be quantized
+    if not is_weight_quantized:
+        return
+
+    fused_module = None
+    float_module = original_module
+    # extract the individual float_module and fused module
+    if isinstance(original_module, torch.ao.nn.intrinsic._FusedModule):
+        fused_module = float_module
+        float_module = fused_module[0]  # type: ignore[index]
+
+    # TODO: move this to the reference quantized module
+    # weight_qparams or weight_qparams dict
+    wq_or_wq_dict = {"is_decomposed": is_decomposed}
+    if isinstance(float_module, torch.nn.RNNCellBase):
+        weight_post_process_ih = qconfig.weight()  # type: ignore[union-attr, operator]
+        weight_post_process_hh = qconfig.weight()  # type: ignore[union-attr, operator]
+        weight_post_process_ih(float_module.weight_ih)
+        weight_post_process_hh(float_module.weight_hh)
+        weight_qparams_ih = get_qparam_dict(weight_post_process_ih)
+        weight_qparams_hh = get_qparam_dict(weight_post_process_hh)
+        wq_or_wq_dict.update(
+            {
+                "weight_ih": weight_qparams_ih,
+                "weight_hh": weight_qparams_hh,
+            }
+        )
+    elif isinstance(float_module, (torch.nn.LSTM, torch.nn.GRU)):
+        # format for wq_or_wq_dict (flattened attributes):
+        # {"weight_ih_l0_scale": ..., "weight_ih_l0_qscheme": ..., ...}
+        for wn in float_module._flat_weights_names:
+            if hasattr(float_module, wn) and wn.startswith("weight"):
+                weight = getattr(float_module, wn)
+                weight_post_process = qconfig.weight()  # type: ignore[union-attr, operator]
+                if weight_post_process.dtype == torch.qint8:  # type: ignore[union-attr]
+                    weight_post_process(weight)  # type: ignore[operator, misc]
+                wq_or_wq_dict[wn] = get_qparam_dict(weight_post_process)
+    else:
+        # if weight_post_process is None, the original module is not a QAT module,
+        # so we need to get weight_post_process from the qconfig in this case
+        is_ptq = weight_post_process is None
+        if is_ptq:
+            weight_post_process = qconfig.weight()  # type: ignore[union-attr, operator]
+            device = assert_and_get_unique_device(float_module)
+            if device:
+                weight_post_process.to(device)
+
+        # Call weight observer/fake_quant at least once to ensure the scales and zero points
+        # have the right shapes. Note: there are two cases where we don't have to do this:
+        #
+        # (1) QAT: The model's forward method already calls the weight observer/fake_quant,
+        #     and this typically happens during training, so we don't need to do it here.
+        #
+        # (2) Non-reference (lowered) case: The quantized module's from_float method already
+        #     calls the weight observer/fake_quant, so we don't have to do it here.
+        #
+        # Currently we ignore both cases and call the weight observer/fake_quant here
+        # regardless, which is technically incorrect. For (1), this is mainly to preserve BC
+        # in test code, which may not always train before convert. In the future, we should
+        # break BC for these two cases. See https://github.com/pytorch/pytorch/issues/73941.
+        #
+        # For PT2, however, we don't need to preserve BC here, so we can skip this hack
+        # for QAT. We identify this case as (is_decomposed + is_reference + is_qat).
+        # Note that we still need it for PTQ in the PT2 flow since the model's forward
+        # method doesn't call the weight observer.
+        is_qat = not is_ptq
+        if not (is_decomposed and is_reference and is_qat):
+            weight_post_process(float_module.weight)  # type: ignore[operator]
+
+        wq_or_wq_dict.update(get_qparam_dict(weight_post_process))
+
+    # We use the same reference module for all modes of quantization: static, dynamic, weight_only
+    # root_module_to_quantized_reference_module: module mapping from root (floating point) module class
+    # to quantized reference module class, e.g. nn.Conv2d to nn.quantized._reference.Conv2d
+    root_module_to_quantized_reference_module = (
+        get_root_module_to_quantized_reference_module(backend_config)
+    )
+    ref_qmodule_cls = root_module_to_quantized_reference_module.get(
+        type_before_parametrizations(float_module), None
+    )
+    assert ref_qmodule_cls is not None, (
+        f"No reference quantized module class configured for {type_before_parametrizations(float_module)}"
+    )
+    ref_qmodule = ref_qmodule_cls.from_float(float_module, wq_or_wq_dict)  # type: ignore[attr-defined]
+    if fused_module is not None:
+        fused_module[0] = ref_qmodule  # type: ignore[operator]
+    else:
+        parent_name, name = _parent_name(node.target)
+        setattr(modules[parent_name], name, ref_qmodule)
+
+
+def _remove_previous_dequantize_in_custom_module(
+    node: Node, prev_node: Node, graph: Graph
+) -> None:
+    """
+    Given a custom module `node`, if the previous node is a dequantize, reroute the custom module's input as follows:
+
+    Before: quantize - dequantize - custom_module
+    After: quantize - custom_module
+                 \\ - dequantize
+    """
+    # expecting the input node for a custom module node to be a Node
+    assert isinstance(prev_node, Node), (
+        f"Expecting the argument for custom module node to be a Node, but got {prev_node}"
+    )
+    if prev_node.op == "call_method" and prev_node.target == "dequantize":
+        node.replace_input_with(prev_node, prev_node.args[0])
+        # Remove the dequantize node if it doesn't have other users
+        if len(prev_node.users) == 0:
+            graph.erase_node(prev_node)
+
+
+def convert_custom_module(
+    node: Node,
+    graph: Graph,
+    modules: dict[str, torch.nn.Module],
+    custom_module_class_mapping: dict[QuantType, dict[type, type]],
+    statically_quantized_custom_module_nodes: set[Node],
+) -> None:
+    """Converts an observed custom module to a quantized custom module based on
+    `custom_module_class_mapping`
+    For static quantization, we'll also remove the previous `dequantize` node and
+    attach the output observer node to the module; the observer for that node
+    will later be converted to a dequantize node instead of a quantize-dequantize pair
+    in the graph. In the end we have a quantized custom module that
+    has the same interface as a default quantized module in the nn.quantized namespace,
+    i.e. quantized input and quantized output.
+
+    Args:
+      - node: The call_module node of the observed custom module
+      - graph: The graph containing the node
+      - modules: named_module of original model
+      - custom_module_class_mapping: mapping from observed custom module class to
+        quantized custom module class, used to swap custom modules
+      - statically_quantized_custom_module_nodes: we'll add the custom module node here
+        if we find it is statically quantized. This is used later when converting
+        observers to quant/dequant node pairs: if the observed node is a statically
+        quantized custom module node, we'll convert the observer to a dequantize node
+        to keep the interface the same as the default quantized module.
+        TODO: maybe we want to redesign this part to align with the reference model design
+        as well, but there has been some discussion around the interface, so we can do
+        it later.
+    """
+    observed_custom_module = modules[str(node.target)]
+    qconfig = observed_custom_module.qconfig
+    if activation_is_statically_quantized(qconfig):
+        statically_quantized_custom_module_nodes.add(node)
+        if _is_custom_module_lstm(node, modules):
+            # The inputs are tuples in the form (input, (hidden0, hidden1))
+            # Ensure all three input nodes are quantized
+            assert (
+                len(node.args) == 2
+                and isinstance(node.args[1], tuple)
+                and len(node.args[1]) == 2
+            )
+            (inputs, (hidden0, hidden1)) = node.args  # type: ignore[misc]
+            assert isinstance(inputs, Node)
+            assert isinstance(hidden0, Node)
+            assert isinstance(hidden1, Node)
+            _remove_previous_dequantize_in_custom_module(node, inputs, graph)
+            _remove_previous_dequantize_in_custom_module(node, hidden0, graph)
+            _remove_previous_dequantize_in_custom_module(node, hidden1, graph)
+        elif _is_custom_module_mha(node, modules):
+            # Inputs are in the form (query, key, value)
+            # TODO: This is the first step in enabling the full fx custom module
+            # quantization path for MultiheadAttention, and only covers the inputs
+            # to the module.
+            # Additional handling is yet to be implemented for the outputs, similar
+            # to LSTM custom module
+            assert len(node.args) == 3
+            query, key, value = node.args
+            assert isinstance(query, Node)
+            assert isinstance(key, Node)
+            assert isinstance(value, Node)
+            _remove_previous_dequantize_in_custom_module(node, query, graph)
+            _remove_previous_dequantize_in_custom_module(node, key, graph)
+            _remove_previous_dequantize_in_custom_module(node, value, graph)
+        else:
+            # remove the previous dequant node to ensure the inputs are quantized
+            arg = node.args[0]
+            assert isinstance(arg, Node)
+            _remove_previous_dequantize_in_custom_module(node, arg, graph)
+            # absorb the following observer into the module conversion
+            activation_post_process = _maybe_get_observer_for_node(node, modules)
+            assert activation_post_process is not None
+            observed_custom_module.activation_post_process = activation_post_process
+
+    # swap the observed custom module to quantized custom module
+    quantized_custom_module_class = get_swapped_custom_module_class(
+        observed_custom_module, custom_module_class_mapping, qconfig
+    )
+    quantized_custom_module = quantized_custom_module_class.from_observed(
+        observed_custom_module
+    )
+    parent_name, name = _parent_name(node.target)
+    setattr(modules[parent_name], name, quantized_custom_module)
+
+
+def convert(
+    model: GraphModule,
+    is_reference: bool = False,
+    convert_custom_config: Union[ConvertCustomConfig, dict[str, Any], None] = None,
+    is_standalone_module: bool = False,
+    _remove_qconfig_flag: bool = True,
+    qconfig_mapping: Union[QConfigMapping, dict[str, Any], None] = None,
+    backend_config: Union[BackendConfig, dict[str, Any], None] = None,
+    is_decomposed: bool = False,
+    keep_original_weights: bool = False,
+) -> GraphModule:
+    """
+    We will convert an observed model (a module with observer calls) to a reference
+    quantized model. The rule is simple:
+    1. for each observer module call in the graph, we'll convert it to calls to
+       quantize and dequantize functions based on the observer instance
+    2. for weighted operations like linear/conv, we need to convert them to reference
+       quantized modules; this requires us to know whether the dtype configured for the
+       weight is supported in the backend, which is determined in the prepare step and
+       stored in observed_node_names, so we can decide whether we need to swap the
+       module based on this set
+
+    Args:
+       * `is_standalone_module`: when this flag is True, it means we are quantizing
+       a submodule that is not inlined in the parent module, and it will be quantized
+       separately as one unit.
+
+       * `is_decomposed`: a boolean flag to indicate whether we want to use the
+        quantize operator for decomposed quantized tensor
+        (torch.ops.quantized_decomposed.quantize_per_tensor) or default/standalone
+        quantized tensor (torch.quantize_per_tensor)
+
+    Returns:
+         a quantized standalone module; whether the input/output is quantized is
+         specified by prepare_custom_config via
+         input_quantized_idxs and output_quantized_idxs, please
+         see the docs for :func:`~torch.ao.quantization.prepare_fx` for details
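+
+    Example (a minimal end-to-end sketch via the public wrappers; the model and
+    values are illustrative)::
+
+        import torch
+        from torch.ao.quantization import get_default_qconfig_mapping
+        from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx
+
+        model = torch.nn.Sequential(torch.nn.Linear(4, 4)).eval()
+        example_inputs = (torch.randn(1, 4),)
+        prepared = prepare_fx(model, get_default_qconfig_mapping("fbgemm"), example_inputs)
+        prepared(*example_inputs)  # calibration
+        quantized = convert_fx(prepared)  # ends up calling this convert()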
+    """
+    if convert_custom_config is None:
+        convert_custom_config = ConvertCustomConfig()
+
+    if isinstance(convert_custom_config, dict):
+        warnings.warn(
+            "Passing a convert_custom_config_dict to convert is deprecated and will not be supported "
+            "in a future version. Please pass in a ConvertCustomConfig instead.",
+            FutureWarning,
+            stacklevel=2,
+        )
+        convert_custom_config = ConvertCustomConfig.from_dict(convert_custom_config)
+
+    if isinstance(qconfig_mapping, dict):
+        warnings.warn(
+            "Passing a QConfig dictionary to convert is deprecated and will not be supported "
+            "in a future version. Please pass in a QConfigMapping instead.",
+            FutureWarning,
+            stacklevel=2,
+        )
+        qconfig_mapping = (
+            QConfigMapping.from_dict(qconfig_mapping) if qconfig_mapping else None
+        )
+    qconfig_mapping = copy.deepcopy(qconfig_mapping)
+    assert qconfig_mapping is None or isinstance(qconfig_mapping, QConfigMapping)
+
+    if isinstance(backend_config, dict):
+        warnings.warn(
+            "Passing a backend_config_dict to prepare is deprecated and will not be supported "
+            "in a future version. Please pass in a BackendConfig instead.",
+            FutureWarning,
+            stacklevel=2,
+        )
+        backend_config = BackendConfig.from_dict(backend_config)
+
+    if backend_config is None:
+        backend_config = get_native_backend_config()
+
+    assert _is_observed_module(model), "incoming model must be produced by prepare_fx"
+    observed_graph_module_attrs = model.meta["_observed_graph_module_attrs"]
+    node_name_to_scope: dict[str, tuple[str, type]] = (
+        observed_graph_module_attrs.node_name_to_scope
+    )
+    prepare_custom_config: PrepareCustomConfig = (
+        observed_graph_module_attrs.prepare_custom_config
+    )
+    observed_node_names: set[str] = observed_graph_module_attrs.observed_node_names
+    node_name_to_qconfig: dict[str, QConfigAny] = (
+        observed_graph_module_attrs.node_name_to_qconfig
+    )  # type: ignore[assignment]
+
+    # mapping from fully qualified module name to module instance
+    # for example,
+    # {
+    #   '': Model(...),
+    #   'linear': Linear(...),
+    #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
+    # }
+    # We use remove_duplicate=False here because torch.cat uses
+    # the same activation_post_process module instance but different names
+    modules = dict(model.named_modules(remove_duplicate=False))
+
+    # TODO refactor this code once we update the prepare logic to have additional information on
+    # which graph nodes have been observed and share that with convert to decide which observers to ignore.
+    if qconfig_mapping:
+        prepare_qconfig_mapping: QConfigMapping = (
+            observed_graph_module_attrs.qconfig_mapping
+        )  # type: ignore[assignment]
+        modules_copy = copy.deepcopy(modules)
+
+        if observed_graph_module_attrs.is_qat:
+            _update_qconfig_for_qat(qconfig_mapping, backend_config)
+        _update_qconfig_for_fusion(model, qconfig_mapping)
+
+        _compare_prepare_convert_qconfig_mappings(
+            prepare_qconfig_mapping, qconfig_mapping
+        )  # type: ignore[arg-type]
+        convert_node_name_to_qconfig = _generate_node_name_to_qconfig(
+            model, modules_copy, model.graph, qconfig_mapping, node_name_to_scope
+        )
+        # check the convert_node_name_to_qconfig generated and ensure that
+        # all the values either match what was set in prepare node_name_to_qconfig
+        # or are set to None in the convert_node_name_to_qconfig.
+        for k, v in node_name_to_qconfig.items():
+            assert k in convert_node_name_to_qconfig, (
+                f"Expected key {k} in convert node_name_to_qconfig"
+            )
+            if convert_node_name_to_qconfig[k] is not None:
+                assert qconfig_equals(v, convert_node_name_to_qconfig[k]), (
+                    f"Expected k {k} to have the same value in prepare and convert QConfigMappings, "
+                    f"but {v} was updated to {convert_node_name_to_qconfig[k]}"
+                )
+        node_name_to_qconfig = convert_node_name_to_qconfig
+
+    custom_module_classes = get_custom_module_class_keys(
+        convert_custom_config.observed_to_quantized_mapping
+    )
+    custom_module_class_mapping = convert_custom_config.observed_to_quantized_mapping
+
+    if observed_graph_module_attrs.equalization_node_name_to_qconfig is not None:
+        # If we want to do equalization then do the following:
+        # Calculate the equalization scale, update the observers with the scaled
+        # inputs, and scale the weight
+        weight_eq_obs_dict = update_obs_for_equalization(model, modules)
+        convert_eq_obs(model, modules, weight_eq_obs_dict)
+
+    # always run weight observers in the top level forward method
+    # for dynamic quant ops or weight only quant ops
+    _run_weight_observers(model, backend_config)
+
+    # additional state to override inputs to be quantized, if specified
+    # by the user
+    placeholder_node_seen_cnt = 0
+    input_quantized_idxs: list[int] = prepare_custom_config.input_quantized_indexes
+    output_quantized_idxs: list[int] = prepare_custom_config.output_quantized_indexes
+
+    root_module_to_quantized_reference_module = (
+        get_root_module_to_quantized_reference_module(backend_config)
+    )
+    # convert tuples so that it can work with isinstance(module, tuple_of_classes)
+    root_module_classes = tuple(root_module_to_quantized_reference_module.keys())
+    qat_module_classes = get_qat_module_classes(backend_config)
+    fused_module_classes = get_fused_module_classes(backend_config)
+    statically_quantized_custom_module_nodes: set[Node] = set()
+
+    for node in list(model.graph.nodes):
+        if node.op == "placeholder":
+            cur_placeholder_node_idx = placeholder_node_seen_cnt
+            placeholder_node_seen_cnt += 1
+            if cur_placeholder_node_idx in input_quantized_idxs:
+                # Inputs are assumed to be quantized if the user specified the
+                # input_quantized_idxs override.
+                # we need to dequantize the inputs since all operators take
+                # floating point inputs in reference quantized models
+                _insert_dequantize_node(node, model.graph)
+        elif node.op == "output":
+            # If the argument is empty we don't need to do anything
+            if len(output_quantized_idxs) == 0:
+                continue
+            # Results are kept quantized if the user specified the
+            # output_quantized_idxs override.
+            # Remove the dequantize operator for the node in the end if any
+            return_node = node
+            output = node.args[0]
+            # outputs can be Node, list, tuple, dict, other cases are not supported yet
+            if isinstance(output, (list, tuple)):
+                for idx in output_quantized_idxs:
+                    _maybe_recursive_remove_dequantize(
+                        output[idx], return_node, model.graph
+                    )
+            elif isinstance(output, (Node, dict)):
+                # we treat dict as a single argument currently, but it can be extended
+                # to support {"key": dtype} after we change output_quantized_idxs to
+                # dict
+                if 0 in output_quantized_idxs:
+                    _maybe_recursive_remove_dequantize(output, return_node, model.graph)
+            else:
+                warnings.warn(
+                    f"Unsupported node type for output_quantized_idxs: {type(output)}"
+                )
+        elif node.op == "call_module":
+            mod = _get_module(node, modules)
+            assert mod is not None
+            if _is_activation_post_process(mod):
+                observed_node = node.args[0]
+                if observed_node in statically_quantized_custom_module_nodes:
+                    _replace_observer_or_dequant_stub_with_dequantize_node(
+                        node, model.graph
+                    )
+                else:
+                    if is_decomposed:
+                        _replace_observer_with_quantize_dequantize_node_decomposed(
+                            model,
+                            node,
+                            modules,
+                            node_name_to_scope,
+                            node_name_to_qconfig,
+                        )
+                    else:
+                        _replace_observer_with_quantize_dequantize_node(
+                            model,
+                            node,
+                            modules,
+                            node_name_to_scope,
+                            node_name_to_qconfig,
+                        )
+            elif isinstance(mod, DeQuantStub):
+                _replace_observer_or_dequant_stub_with_dequantize_node(
+                    node, model.graph
+                )
+            elif _is_observed_standalone_module(mod):
+                convert_standalone_module(
+                    node, modules, model, is_reference, backend_config
+                )
+            # below this point `type_before_parametrizations` is used
+            # instead of `type` to handle situations with fx quant + sparsity
+            elif type_before_parametrizations(mod) in set(root_module_classes).union(
+                qat_module_classes
+            ).union(fused_module_classes):
+                # extra check for fused module classes to make sure they are fused module classes
+                # of target modules
+                if (
+                    type_before_parametrizations(mod) in fused_module_classes
+                    and type_before_parametrizations(mod[0]) not in root_module_classes
+                ):  # type: ignore[index]
+                    continue
+                convert_weighted_module(
+                    node,
+                    modules,
+                    observed_node_names,
+                    node_name_to_qconfig,
+                    backend_config,
+                    is_decomposed,
+                    is_reference,
+                )
+            elif type_before_parametrizations(mod) in custom_module_classes:
+                convert_custom_module(
+                    node,
+                    model.graph,
+                    modules,
+                    custom_module_class_mapping,
+                    statically_quantized_custom_module_nodes,
+                )
+
+    # remove deadcode after converting observers to quant/dequant ops
+    model.graph.eliminate_dead_code()
+    model = GraphModule(model, model.graph)
+
+    # TODO: maybe move this to quantize_fx.py
+    if not is_reference:
+        model = lower_to_fbgemm(
+            model, node_name_to_qconfig, node_name_to_scope, keep_original_weights
+        )
+
+    # TODO: this looks hacky, we want to check why we need this and see if we can
+    # remove this
+    # removes qconfig and activation_post_process modules
+    if _remove_qconfig_flag:
+        _remove_qconfig(model)
+    model.delete_all_unused_submodules()
+    model.meta.pop("_observed_graph_module_attrs", None)
+    return model
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/custom_config.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/custom_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..598c42ea22e3b22ad658ae505de5d430f52eb39b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/custom_config.py
@@ -0,0 +1,521 @@
+# mypy: allow-untyped-defs
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any, Optional
+
+from torch.ao.quantization import QConfigMapping
+from torch.ao.quantization.backend_config import BackendConfig
+from torch.ao.quantization.quant_type import (
+    _get_quant_type_to_str,
+    _quant_type_from_str,
+    QuantType,
+)
+
+
+__all__ = [
+    "ConvertCustomConfig",
+    "FuseCustomConfig",
+    "PrepareCustomConfig",
+    "StandaloneModuleConfigEntry",
+]
+
+
+# TODO: replace all usages with these constants
+STANDALONE_MODULE_NAME_DICT_KEY = "standalone_module_name"
+STANDALONE_MODULE_CLASS_DICT_KEY = "standalone_module_class"
+FLOAT_TO_OBSERVED_DICT_KEY = "float_to_observed_custom_module_class"
+OBSERVED_TO_QUANTIZED_DICT_KEY = "observed_to_quantized_custom_module_class"
+NON_TRACEABLE_MODULE_NAME_DICT_KEY = "non_traceable_module_name"
+NON_TRACEABLE_MODULE_CLASS_DICT_KEY = "non_traceable_module_class"
+INPUT_QUANTIZED_INDEXES_DICT_KEY = "input_quantized_idxs"
+OUTPUT_QUANTIZED_INDEXES_DICT_KEY = "output_quantized_idxs"
+PRESERVED_ATTRIBUTES_DICT_KEY = "preserved_attributes"
+
+
+@dataclass
+class StandaloneModuleConfigEntry:
+    # qconfig_mapping for the prepare function called in the submodule,
+    # None means use qconfig from parent qconfig_mapping
+    qconfig_mapping: Optional[QConfigMapping]
+    example_inputs: tuple[Any, ...]
+    prepare_custom_config: Optional[PrepareCustomConfig]
+    backend_config: Optional[BackendConfig]
+
+
+class PrepareCustomConfig:
+    """
+    Custom configuration for :func:`~torch.ao.quantization.quantize_fx.prepare_fx` and
+    :func:`~torch.ao.quantization.quantize_fx.prepare_qat_fx`.
+
+    Example usage::
+
+        prepare_custom_config = PrepareCustomConfig() \
+            .set_standalone_module_name("module1", qconfig_mapping, example_inputs, \
+                child_prepare_custom_config, backend_config) \
+            .set_standalone_module_class(MyStandaloneModule, qconfig_mapping, example_inputs, \
+                child_prepare_custom_config, backend_config) \
+            .set_float_to_observed_mapping(FloatCustomModule, ObservedCustomModule) \
+            .set_non_traceable_module_names(["module2", "module3"]) \
+            .set_non_traceable_module_classes([NonTraceableModule1, NonTraceableModule2]) \
+            .set_input_quantized_indexes([0]) \
+            .set_output_quantized_indexes([0]) \
+            .set_preserved_attributes(["attr1", "attr2"])
+    """
+
+    def __init__(self) -> None:
+        self.standalone_module_names: dict[str, StandaloneModuleConfigEntry] = {}
+        self.standalone_module_classes: dict[type, StandaloneModuleConfigEntry] = {}
+        self.float_to_observed_mapping: dict[QuantType, dict[type, type]] = {}
+        self.non_traceable_module_names: list[str] = []
+        self.non_traceable_module_classes: list[type] = []
+        self.input_quantized_indexes: list[int] = []
+        self.output_quantized_indexes: list[int] = []
+        self.preserved_attributes: list[str] = []
+
+    def __repr__(self):
+        dict_nonempty = {k: v for k, v in self.__dict__.items() if len(v) > 0}
+        return f"PrepareCustomConfig({dict_nonempty})"
+
+    def set_standalone_module_name(
+        self,
+        module_name: str,
+        qconfig_mapping: Optional[QConfigMapping],
+        example_inputs: tuple[Any, ...],
+        prepare_custom_config: Optional[PrepareCustomConfig],
+        backend_config: Optional[BackendConfig],
+    ) -> PrepareCustomConfig:
+        """
+        Set the configuration for running a standalone module identified by ``module_name``.
+
+        If ``qconfig_mapping`` is None, the parent ``qconfig_mapping`` will be used instead.
+        If ``prepare_custom_config`` is None, an empty ``PrepareCustomConfig`` will be used.
+        If ``backend_config`` is None, the parent ``backend_config`` will be used instead.
+        """
+        self.standalone_module_names[module_name] = StandaloneModuleConfigEntry(
+            qconfig_mapping, example_inputs, prepare_custom_config, backend_config
+        )
+        return self
+
+    def set_standalone_module_class(
+        self,
+        module_class: type,
+        qconfig_mapping: Optional[QConfigMapping],
+        example_inputs: tuple[Any, ...],
+        prepare_custom_config: Optional[PrepareCustomConfig],
+        backend_config: Optional[BackendConfig],
+    ) -> PrepareCustomConfig:
+        """
+        Set the configuration for running a standalone module identified by ``module_class``.
+
+        If ``qconfig_mapping`` is None, the parent ``qconfig_mapping`` will be used instead.
+        If ``prepare_custom_config`` is None, an empty ``PrepareCustomConfig`` will be used.
+        If ``backend_config`` is None, the parent ``backend_config`` will be used instead.
+        """
+        self.standalone_module_classes[module_class] = StandaloneModuleConfigEntry(
+            qconfig_mapping, example_inputs, prepare_custom_config, backend_config
+        )
+        return self
+
+    def set_float_to_observed_mapping(
+        self,
+        float_class: type,
+        observed_class: type,
+        quant_type: QuantType = QuantType.STATIC,
+    ) -> PrepareCustomConfig:
+        """
+        Set the mapping from a custom float module class to a custom observed module class.
+
+        The observed module class must have a ``from_float`` class method that converts the float module class
+        to the observed module class. This is currently only supported for static quantization.
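+
+        A minimal sketch of the expected shape of such a class (names hypothetical)::
+
+            class ObservedCustomModule(torch.nn.Module):
+                @classmethod
+                def from_float(cls, float_module):
+                    ...  # build and return the observed module from the float module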
+        """
+        if quant_type != QuantType.STATIC:
+            raise ValueError(
+                "set_float_to_observed_mapping is currently only supported for static quantization"
+            )
+        if quant_type not in self.float_to_observed_mapping:
+            self.float_to_observed_mapping[quant_type] = {}
+        self.float_to_observed_mapping[quant_type][float_class] = observed_class
+        return self
+
+    def set_non_traceable_module_names(
+        self, module_names: list[str]
+    ) -> PrepareCustomConfig:
+        """
+        Set the modules that are not symbolically traceable, identified by name.
+        """
+        self.non_traceable_module_names = module_names
+        return self
+
+    def set_non_traceable_module_classes(
+        self, module_classes: list[type]
+    ) -> PrepareCustomConfig:
+        """
+        Set the modules that are not symbolically traceable, identified by class.
+        """
+        self.non_traceable_module_classes = module_classes
+        return self
+
+    def set_input_quantized_indexes(self, indexes: list[int]) -> PrepareCustomConfig:
+        """
+        Set the indexes of the inputs of the graph that should be quantized.
+        Inputs are otherwise assumed to be in fp32 by default.
+        """
+        self.input_quantized_indexes = indexes
+        return self
+
+    def set_output_quantized_indexes(self, indexes: list[int]) -> PrepareCustomConfig:
+        """
+        Set the indexes of the outputs of the graph that should be quantized.
+        Outputs are otherwise assumed to be in fp32 by default.
+        """
+        self.output_quantized_indexes = indexes
+        return self
+
+    def set_preserved_attributes(self, attributes: list[str]) -> PrepareCustomConfig:
+        """
+        Set the names of the attributes that will persist in the graph module even if they are not used in
+        the model's ``forward`` method.
+        """
+        self.preserved_attributes = attributes
+        return self
+
+    # TODO: remove this
+    @classmethod
+    def from_dict(
+        cls, prepare_custom_config_dict: dict[str, Any]
+    ) -> PrepareCustomConfig:
+        """
+        Create a ``PrepareCustomConfig`` from a dictionary with the following items:
+
+            "standalone_module_name": a list of (module_name, qconfig_mapping, example_inputs,
+            child_prepare_custom_config, backend_config) tuples
+
+            "standalone_module_class" a list of (module_class, qconfig_mapping, example_inputs,
+            child_prepare_custom_config, backend_config) tuples
+
+            "float_to_observed_custom_module_class": a nested dictionary mapping from quantization
+            mode to an inner mapping from float module classes to observed module classes, e.g.
+            {"static": {FloatCustomModule: ObservedCustomModule}}
+
+            "non_traceable_module_name": a list of modules names that are not symbolically traceable
+            "non_traceable_module_class": a list of module classes that are not symbolically traceable
+            "input_quantized_idxs": a list of indexes of graph inputs that should be quantized
+            "output_quantized_idxs": a list of indexes of graph outputs that should be quantized
+            "preserved_attributes": a list of attributes that persist even if they are not used in ``forward``
+
+        This function is primarily for backward compatibility and may be removed in the future.
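+
+        Example (illustrative; only a subset of keys shown)::
+
+            prepare_custom_config = PrepareCustomConfig.from_dict({
+                "input_quantized_idxs": [0],
+                "preserved_attributes": ["version"],
+            })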
+        """
+
+        def _get_qconfig_mapping(obj: Any, dict_key: str) -> Optional[QConfigMapping]:
+            """
+            Convert the given object into a QConfigMapping if possible, else throw an exception.
+            """
+            if isinstance(obj, QConfigMapping) or obj is None:
+                return obj
+            if isinstance(obj, dict):
+                return QConfigMapping.from_dict(obj)
+            raise ValueError(
+                f"Expected QConfigMapping in prepare_custom_config_dict[\"{dict_key}\"], got '{type(obj)}'"
+            )
+
+        def _get_prepare_custom_config(
+            obj: Any, dict_key: str
+        ) -> Optional[PrepareCustomConfig]:
+            """
+            Convert the given object into a PrepareCustomConfig if possible, else throw an exception.
+            """
+            if isinstance(obj, PrepareCustomConfig) or obj is None:
+                return obj
+            if isinstance(obj, dict):
+                return PrepareCustomConfig.from_dict(obj)
+            raise ValueError(
+                f"Expected PrepareCustomConfig in prepare_custom_config_dict[\"{dict_key}\"], got '{type(obj)}'"
+            )
+
+        def _get_backend_config(obj: Any, dict_key: str) -> Optional[BackendConfig]:
+            """
+            Convert the given object into a BackendConfig if possible, else throw an exception.
+            """
+            if isinstance(obj, BackendConfig) or obj is None:
+                return obj
+            if isinstance(obj, dict):
+                return BackendConfig.from_dict(obj)
+            raise ValueError(
+                f"Expected BackendConfig in prepare_custom_config_dict[\"{dict_key}\"], got '{type(obj)}'"
+            )
+
+        conf = cls()
+        for (
+            module_name,
+            qconfig_dict,
+            example_inputs,
+            _prepare_custom_config_dict,
+            backend_config_dict,
+        ) in prepare_custom_config_dict.get(STANDALONE_MODULE_NAME_DICT_KEY, []):
+            qconfig_mapping = _get_qconfig_mapping(
+                qconfig_dict, STANDALONE_MODULE_NAME_DICT_KEY
+            )
+            prepare_custom_config = _get_prepare_custom_config(
+                _prepare_custom_config_dict, STANDALONE_MODULE_NAME_DICT_KEY
+            )
+            backend_config = _get_backend_config(
+                backend_config_dict, STANDALONE_MODULE_NAME_DICT_KEY
+            )
+            conf.set_standalone_module_name(
+                module_name,
+                qconfig_mapping,
+                example_inputs,
+                prepare_custom_config,
+                backend_config,
+            )
+        for (
+            module_class,
+            qconfig_dict,
+            example_inputs,
+            _prepare_custom_config_dict,
+            backend_config_dict,
+        ) in prepare_custom_config_dict.get(STANDALONE_MODULE_CLASS_DICT_KEY, []):
+            qconfig_mapping = _get_qconfig_mapping(
+                qconfig_dict, STANDALONE_MODULE_CLASS_DICT_KEY
+            )
+            prepare_custom_config = _get_prepare_custom_config(
+                _prepare_custom_config_dict, STANDALONE_MODULE_CLASS_DICT_KEY
+            )
+            backend_config = _get_backend_config(
+                backend_config_dict, STANDALONE_MODULE_CLASS_DICT_KEY
+            )
+            conf.set_standalone_module_class(
+                module_class,
+                qconfig_mapping,
+                example_inputs,
+                prepare_custom_config,
+                backend_config,
+            )
+        for quant_type_name, custom_module_mapping in prepare_custom_config_dict.get(
+            FLOAT_TO_OBSERVED_DICT_KEY, {}
+        ).items():
+            quant_type = _quant_type_from_str(quant_type_name)
+            for float_class, observed_class in custom_module_mapping.items():
+                conf.set_float_to_observed_mapping(
+                    float_class, observed_class, quant_type
+                )
+        conf.set_non_traceable_module_names(
+            prepare_custom_config_dict.get(NON_TRACEABLE_MODULE_NAME_DICT_KEY, [])
+        )
+        conf.set_non_traceable_module_classes(
+            prepare_custom_config_dict.get(NON_TRACEABLE_MODULE_CLASS_DICT_KEY, [])
+        )
+        conf.set_input_quantized_indexes(
+            prepare_custom_config_dict.get(INPUT_QUANTIZED_INDEXES_DICT_KEY, [])
+        )
+        conf.set_output_quantized_indexes(
+            prepare_custom_config_dict.get(OUTPUT_QUANTIZED_INDEXES_DICT_KEY, [])
+        )
+        conf.set_preserved_attributes(
+            prepare_custom_config_dict.get(PRESERVED_ATTRIBUTES_DICT_KEY, [])
+        )
+        return conf
+
+    def to_dict(self) -> dict[str, Any]:
+        """
+        Convert this ``PrepareCustomConfig`` to a dictionary with the items described in
+        :func:`~torch.ao.quantization.fx.custom_config.PrepareCustomConfig.from_dict`.
+        """
+
+        def _make_tuple(key: Any, e: StandaloneModuleConfigEntry):
+            qconfig_dict = e.qconfig_mapping.to_dict() if e.qconfig_mapping else None
+            prepare_custom_config_dict = (
+                e.prepare_custom_config.to_dict() if e.prepare_custom_config else None
+            )
+            return (
+                key,
+                qconfig_dict,
+                e.example_inputs,
+                prepare_custom_config_dict,
+                e.backend_config,
+            )
+
+        d: dict[str, Any] = {}
+        for module_name, sm_config_entry in self.standalone_module_names.items():
+            if STANDALONE_MODULE_NAME_DICT_KEY not in d:
+                d[STANDALONE_MODULE_NAME_DICT_KEY] = []
+            d[STANDALONE_MODULE_NAME_DICT_KEY].append(
+                _make_tuple(module_name, sm_config_entry)
+            )
+        for module_class, sm_config_entry in self.standalone_module_classes.items():
+            if STANDALONE_MODULE_CLASS_DICT_KEY not in d:
+                d[STANDALONE_MODULE_CLASS_DICT_KEY] = []
+            d[STANDALONE_MODULE_CLASS_DICT_KEY].append(
+                _make_tuple(module_class, sm_config_entry)
+            )
+        for (
+            quant_type,
+            float_to_observed_mapping,
+        ) in self.float_to_observed_mapping.items():
+            if FLOAT_TO_OBSERVED_DICT_KEY not in d:
+                d[FLOAT_TO_OBSERVED_DICT_KEY] = {}
+            d[FLOAT_TO_OBSERVED_DICT_KEY][_get_quant_type_to_str(quant_type)] = (
+                float_to_observed_mapping
+            )
+        if len(self.non_traceable_module_names) > 0:
+            d[NON_TRACEABLE_MODULE_NAME_DICT_KEY] = self.non_traceable_module_names
+        if len(self.non_traceable_module_classes) > 0:
+            d[NON_TRACEABLE_MODULE_CLASS_DICT_KEY] = self.non_traceable_module_classes
+        if len(self.input_quantized_indexes) > 0:
+            d[INPUT_QUANTIZED_INDEXES_DICT_KEY] = self.input_quantized_indexes
+        if len(self.output_quantized_indexes) > 0:
+            d[OUTPUT_QUANTIZED_INDEXES_DICT_KEY] = self.output_quantized_indexes
+        if len(self.preserved_attributes) > 0:
+            d[PRESERVED_ATTRIBUTES_DICT_KEY] = self.preserved_attributes
+        return d
+
+
+class ConvertCustomConfig:
+    """
+    Custom configuration for :func:`~torch.ao.quantization.quantize_fx.convert_fx`.
+
+    Example usage::
+
+        convert_custom_config = ConvertCustomConfig() \
+            .set_observed_to_quantized_mapping(ObservedCustomModule, QuantizedCustomModule) \
+            .set_preserved_attributes(["attr1", "attr2"])
+    """
+
+    def __init__(self) -> None:
+        self.observed_to_quantized_mapping: dict[QuantType, dict[type, type]] = {}
+        self.preserved_attributes: list[str] = []
+
+    def __repr__(self):
+        dict_nonempty = {k: v for k, v in self.__dict__.items() if len(v) > 0}
+        return f"ConvertCustomConfig({dict_nonempty})"
+
+    def set_observed_to_quantized_mapping(
+        self,
+        observed_class: type,
+        quantized_class: type,
+        quant_type: QuantType = QuantType.STATIC,
+    ) -> ConvertCustomConfig:
+        """
+        Set the mapping from a custom observed module class to a custom quantized module class.
+
+        The quantized module class must have a ``from_observed`` class method that converts the observed module class
+        to the quantized module class.
+        """
+        if quant_type not in self.observed_to_quantized_mapping:
+            self.observed_to_quantized_mapping[quant_type] = {}
+        self.observed_to_quantized_mapping[quant_type][observed_class] = quantized_class
+        return self
+
+    def set_preserved_attributes(self, attributes: list[str]) -> ConvertCustomConfig:
+        """
+        Set the names of the attributes that will persist in the graph module even if they are not used in
+        the model's ``forward`` method.
+        """
+        self.preserved_attributes = attributes
+        return self
+
+    # TODO: remove this
+    @classmethod
+    def from_dict(
+        cls, convert_custom_config_dict: dict[str, Any]
+    ) -> ConvertCustomConfig:
+        """
+        Create a ``ConvertCustomConfig`` from a dictionary with the following items:
+
+            "observed_to_quantized_custom_module_class": a nested dictionary mapping from quantization
+            mode to an inner mapping from observed module classes to quantized module classes, e.g.::
+            {
+            "static": {FloatCustomModule: ObservedCustomModule},
+            "dynamic": {FloatCustomModule: ObservedCustomModule},
+            "weight_only": {FloatCustomModule: ObservedCustomModule}
+            }
+            "preserved_attributes": a list of attributes that persist even if they are not used in ``forward``
+
+        This function is primarily for backward compatibility and may be removed in the future.
+        """
+        conf = cls()
+        for quant_type_name, custom_module_mapping in convert_custom_config_dict.get(
+            OBSERVED_TO_QUANTIZED_DICT_KEY, {}
+        ).items():
+            quant_type = _quant_type_from_str(quant_type_name)
+            for observed_class, quantized_class in custom_module_mapping.items():
+                conf.set_observed_to_quantized_mapping(
+                    observed_class, quantized_class, quant_type
+                )
+        conf.set_preserved_attributes(
+            convert_custom_config_dict.get(PRESERVED_ATTRIBUTES_DICT_KEY, [])
+        )
+        return conf
+
+    def to_dict(self) -> dict[str, Any]:
+        """
+        Convert this ``ConvertCustomConfig`` to a dictionary with the items described in
+        :func:`~torch.ao.quantization.fx.custom_config.ConvertCustomConfig.from_dict`.
+        """
+        d: dict[str, Any] = {}
+        for (
+            quant_type,
+            observed_to_quantized_mapping,
+        ) in self.observed_to_quantized_mapping.items():
+            if OBSERVED_TO_QUANTIZED_DICT_KEY not in d:
+                d[OBSERVED_TO_QUANTIZED_DICT_KEY] = {}
+            d[OBSERVED_TO_QUANTIZED_DICT_KEY][_get_quant_type_to_str(quant_type)] = (
+                observed_to_quantized_mapping
+            )
+        if len(self.preserved_attributes) > 0:
+            d[PRESERVED_ATTRIBUTES_DICT_KEY] = self.preserved_attributes
+        return d
+
+
+class FuseCustomConfig:
+    """
+    Custom configuration for :func:`~torch.ao.quantization.quantize_fx.fuse_fx`.
+
+    Example usage::
+
+        fuse_custom_config = FuseCustomConfig().set_preserved_attributes(
+            ["attr1", "attr2"]
+        )
+    """
+
+    def __init__(self) -> None:
+        self.preserved_attributes: list[str] = []
+
+    def __repr__(self):
+        dict_nonempty = {k: v for k, v in self.__dict__.items() if len(v) > 0}
+        return f"FuseCustomConfig({dict_nonempty})"
+
+    def set_preserved_attributes(self, attributes: list[str]) -> FuseCustomConfig:
+        """
+        Set the names of the attributes that will persist in the graph module even if they are not used in
+        the model's ``forward`` method.
+        """
+        self.preserved_attributes = attributes
+        return self
+
+    # TODO: remove this
+    @classmethod
+    def from_dict(cls, fuse_custom_config_dict: dict[str, Any]) -> FuseCustomConfig:
+        """
+        Create a ``FuseCustomConfig`` from a dictionary with the following items:
+
+            "preserved_attributes": a list of attributes that persist even if they are not used in ``forward``
+
+        This function is primarily for backward compatibility and may be removed in the future.
+        """
+        conf = cls()
+        conf.set_preserved_attributes(
+            fuse_custom_config_dict.get(PRESERVED_ATTRIBUTES_DICT_KEY, [])
+        )
+        return conf
+
+    def to_dict(self) -> dict[str, Any]:
+        """
+        Convert this ``FuseCustomConfig`` to a dictionary with the items described in
+        :func:`~torch.ao.quantization.fx.custom_config.FuseCustomConfig.from_dict`.
+        """
+        d: dict[str, Any] = {}
+        if len(self.preserved_attributes) > 0:
+            d[PRESERVED_ATTRIBUTES_DICT_KEY] = self.preserved_attributes
+        return d
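As a quick sanity check of the config classes added above, here is a minimal sketch (not part of the patch) of the builder-style API and the legacy dict round-trip; ``ObservedCustomModule`` and ``QuantizedCustomModule`` are hypothetical placeholder classes::

    import torch.nn as nn
    from torch.ao.quantization.fx.custom_config import ConvertCustomConfig, FuseCustomConfig

    class ObservedCustomModule(nn.Module):   # hypothetical observed custom module
        pass

    class QuantizedCustomModule(nn.Module):  # hypothetical quantized custom module
        pass

    convert_config = (
        ConvertCustomConfig()
        .set_observed_to_quantized_mapping(ObservedCustomModule, QuantizedCustomModule)
        .set_preserved_attributes(["version"])
    )
    # to_dict()/from_dict() round-trip through the deprecated dict format
    restored = ConvertCustomConfig.from_dict(convert_config.to_dict())

    fuse_config = FuseCustomConfig().set_preserved_attributes(["version"])
    restored_fuse = FuseCustomConfig.from_dict(fuse_config.to_dict())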
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/fuse.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/fuse.py
new file mode 100644
index 0000000000000000000000000000000000000000..2078ddba9f404d747ea4c64bf0c553f852012d50
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/fuse.py
@@ -0,0 +1,191 @@
+# mypy: allow-untyped-defs
+import warnings
+from typing import Any, Callable, Union
+
+from torch.ao.quantization.backend_config import (
+    BackendConfig,
+    get_native_backend_config,
+)
+from torch.ao.quantization.backend_config.utils import (
+    get_fuser_method_mapping,
+    get_fusion_pattern_to_extra_inputs_getter,
+    get_fusion_pattern_to_root_node_getter,
+)
+from torch.ao.quantization.utils import NodePattern, Pattern
+from torch.fx import GraphModule, map_arg, Node
+from torch.fx.graph import Graph
+
+from .custom_config import FuseCustomConfig
+from .fuse_handler import _get_fusion_pattern_to_fuse_handler_cls, FuseHandler
+from .match_utils import _is_match, MatchAllNode
+from .pattern_utils import _sorted_patterns_dict
+
+
+__all__ = [
+    "fuse",
+    # TODO: We should make this private in the future
+    # This is currently needed for test_public_bindings for some reason
+    "FuseHandler",
+]
+
+
+def fuse(
+    model: GraphModule,
+    is_qat: bool,
+    fuse_custom_config: Union[FuseCustomConfig, dict[str, Any], None] = None,
+    backend_config: Union[BackendConfig, dict[str, Any], None] = None,
+) -> GraphModule:
+    if fuse_custom_config is None:
+        fuse_custom_config = FuseCustomConfig()
+
+    if isinstance(fuse_custom_config, dict):
+        warnings.warn(
+            "Passing a fuse_custom_config_dict to fuse is deprecated and will not be supported "
+            "in a future version. Please pass in a FuseCustomConfig instead.",
+            FutureWarning,
+            stacklevel=2,
+        )
+        fuse_custom_config = FuseCustomConfig.from_dict(fuse_custom_config)
+
+    if isinstance(backend_config, dict):
+        warnings.warn(
+            "Passing a backend_config_dict to prepare is deprecated and will not be supported "
+            "in a future version. Please pass in a BackendConfig instead.",
+            FutureWarning,
+            stacklevel=2,
+        )
+        backend_config = BackendConfig.from_dict(backend_config)
+
+    named_modules = dict(model.named_modules())
+
+    if backend_config is None:
+        backend_config = get_native_backend_config()
+
+    fusion_pattern_to_fuse_handler_cls = _sorted_patterns_dict(
+        _get_fusion_pattern_to_fuse_handler_cls(backend_config)
+    )
+    fuser_method_mapping = get_fuser_method_mapping(backend_config)
+    fusion_pattern_to_root_node_getter = get_fusion_pattern_to_root_node_getter(
+        backend_config
+    )
+    fusion_pattern_to_extra_inputs_getter = get_fusion_pattern_to_extra_inputs_getter(
+        backend_config
+    )
+
+    # find fusion
+    fusion_pairs = _find_matches(model, model.graph, fusion_pattern_to_fuse_handler_cls)
+    # TODO: change this to in-place changes to the graph, since we no longer
+    # construct a new GraphModule
+    fused_graph = Graph()
+    env: dict[Any, Any] = {}
+
+    def load_arg(a):
+        return map_arg(a, lambda node: env[node.name])
+
+    def default_root_node_getter(node_pattern):
+        while not isinstance(node_pattern[-1], Node):
+            node_pattern = node_pattern[-1]
+        return node_pattern[-1]
+
+    for node in model.graph.nodes:
+        (
+            maybe_last_node,
+            pattern,
+            matched_node_pattern,
+            obj,
+            node_to_subpattern,
+        ) = fusion_pairs.get(node.name, (None, None, None, None, None))
+        # get the corresponding subpattern for the current node
+        if node_to_subpattern is not None:
+            node_subpattern = node_to_subpattern.get(node, None)
+        else:
+            node_subpattern = None
+        if maybe_last_node is node:
+            assert obj is not None
+            root_node_getter = fusion_pattern_to_root_node_getter.get(
+                pattern, default_root_node_getter
+            )
+            root_node = root_node_getter(matched_node_pattern)  # type: ignore[index]
+            extra_inputs_getter = fusion_pattern_to_extra_inputs_getter.get(
+                pattern, None
+            )
+            extra_inputs = []
+            if extra_inputs_getter is not None:
+                extra_inputs = extra_inputs_getter(matched_node_pattern)
+            # TODO: add validation that root_node is a module and has the same type
+            # as the root_module in the configuration
+            env[node.name] = obj.fuse(
+                load_arg,
+                named_modules,
+                fused_graph,
+                root_node,
+                extra_inputs,
+                matched_node_pattern,  # type: ignore[arg-type]
+                fuse_custom_config,
+                fuser_method_mapping,
+                is_qat,
+            )
+        elif maybe_last_node is None or node_subpattern is MatchAllNode:
+            env[node.name] = fused_graph.node_copy(node, load_arg)
+        # a node that is matched in a pattern but is not the root node is dropped here
+
+    model = GraphModule(model, fused_graph)
+    return model
+
+
+def _find_matches(
+    root: GraphModule,
+    graph: Graph,
+    pattern_to_fuse_handler_cls: dict[Pattern, Callable],
+) -> dict[str, tuple[Node, Pattern, NodePattern, FuseHandler, dict[Node, Any]]]:
+    modules = dict(root.named_modules())
+    # node name -> (root_node, match_value)
+    match_map: dict[
+        str, tuple[Node, Pattern, NodePattern, FuseHandler, dict[Node, Any]]
+    ] = {}
+    # a map from node to the matched subpattern
+    node_to_subpattern: dict[Node, Any] = {}
+
+    # TODO: dedup with quantization matching function in match_utils.py
+    def apply_match(pattern, node, match, matched_node_pattern, node_to_subpattern):
+        if isinstance(pattern, tuple):
+            s, *args = pattern
+            current_node_pattern: list[Node] = []
+            apply_match(s, node, match, current_node_pattern, node_to_subpattern)
+            for subpattern, arg in zip(args, node.args):
+                apply_match(
+                    subpattern, arg, match, current_node_pattern, node_to_subpattern
+                )
+            matched_node_pattern.append(tuple(current_node_pattern))
+        else:
+            # the first pattern that matches takes precedence
+            if node.name not in match_map:
+                matched_node_pattern.append(node)
+                # MatchAllNode here is actually MatchAllInputNode which should not
+                # be added to match_map
+                if pattern is not MatchAllNode:
+                    node_to_subpattern[node] = pattern
+                    root_node, pattern, handler = match
+                    match_map[node.name] = (
+                        root_node,
+                        pattern,
+                        matched_node_pattern,
+                        handler,
+                        node_to_subpattern,
+                    )
+
+    for node in reversed(graph.nodes):
+        if node.name not in match_map:
+            for pattern, fuse_handler_cls in pattern_to_fuse_handler_cls.items():
+                matched_node_pattern: list[Node] = []
+                if _is_match(modules, node, pattern):
+                    apply_match(
+                        pattern,
+                        node,
+                        (node, pattern, fuse_handler_cls(node)),
+                        matched_node_pattern,
+                        node_to_subpattern,
+                    )
+                    break
+
+    return match_map
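For reference, a minimal sketch (not part of the patch) of how this fusion pass is reached from the public API: ``fuse_fx`` symbolically traces an eval-mode model and then runs ``fuse`` over the traced graph::

    import torch
    import torch.nn as nn
    from torch.ao.quantization.quantize_fx import fuse_fx

    class Small(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 8, 3)
            self.bn = nn.BatchNorm2d(8)
            self.relu = nn.ReLU()

        def forward(self, x):
            return self.relu(self.bn(self.conv(x)))

    m = Small().eval()            # fuse_fx only works on models in eval mode
    fused = fuse_fx(m)
    print(fused.graph)            # the conv/bn/relu chain is replaced by one fused module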
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/fuse_handler.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/fuse_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..68a5a440a51284628951e495874784a04a799c32
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/fuse_handler.py
@@ -0,0 +1,129 @@
+# mypy: allow-untyped-defs
+from abc import ABC, abstractmethod
+from typing import Any, Callable, Union
+
+import torch
+from torch.ao.quantization.backend_config import BackendConfig
+from torch.ao.quantization.fuser_method_mappings import get_fuser_method_new
+from torch.ao.quantization.utils import _parent_name, NodePattern, Pattern
+from torch.fx.graph import Graph, Node
+from torch.nn.utils.parametrize import type_before_parametrizations
+
+from .custom_config import FuseCustomConfig
+from .match_utils import MatchAllNode
+
+
+__all__ = [
+    "DefaultFuseHandler",
+    "FuseHandler",
+]
+
+
+# ----------------------------
+# Fusion Pattern Registrations
+# ----------------------------
+
+
+# Base Pattern Handler
+class FuseHandler(ABC):
+    """Base handler class for the fusion patterns"""
+
+    @abstractmethod
+    def __init__(self, node: Node):
+        pass
+
+    @abstractmethod
+    def fuse(
+        self,
+        load_arg: Callable,
+        named_modules: dict[str, torch.nn.Module],
+        fused_graph: Graph,
+        root_node: Node,
+        extra_inputs: list[Any],
+        matched_node_pattern: NodePattern,
+        fuse_custom_config: FuseCustomConfig,
+        fuser_method_mapping: dict[Pattern, Union[torch.nn.Sequential, Callable]],
+        is_qat: bool,
+    ) -> Node:
+        pass
+
+
+class DefaultFuseHandler(FuseHandler):
+    def __init__(self, node: Node):
+        super().__init__(node)  # type:ignore[safe-super]
+
+    def fuse(
+        self,
+        load_arg: Callable,
+        named_modules: dict[str, torch.nn.Module],
+        fused_graph: Graph,
+        root_node: Node,
+        extra_inputs: list[Any],
+        matched_node_pattern: NodePattern,
+        fuse_custom_config: FuseCustomConfig,
+        fuser_method_mapping: dict[Pattern, Union[torch.nn.Sequential, Callable]],
+        is_qat: bool,
+    ) -> Node:
+        assert root_node.op == "call_module", (
+            "Expecting module node to be a call_module Node"
+        )
+        root_module = named_modules[str(root_node.target)]
+
+        def get_modules(pattern):
+            """Given a node pattern, extract the corresponding modules
+            e.g. input: (relu_node, (bn_node, conv_node))
+                 output: (relu_module, (bn_module, conv_module))
+            """
+            if isinstance(pattern, (tuple, list)):
+                n, *args = pattern
+                modules: list[torch.nn.Module] = []
+                modules.append(get_modules(n))
+                modules.extend(get_modules(a) for a in args)
+                return tuple(modules)
+            else:
+                n = pattern
+                if n.op == "call_module":
+                    return named_modules[n.target]
+                elif n.op == "call_function" and n.target == torch.nn.functional.relu:
+                    relu = torch.nn.ReLU()
+                    relu.training = root_module.training
+                    return relu
+                elif n.op == "call_function" or n.op == "call_method":
+                    return n.target
+                else:
+                    return MatchAllNode
+
+        # since relu can be used multiple times, we'll need to create a relu module for each match
+        matched_modules = get_modules(matched_node_pattern)
+
+        def get_matched_types(m):
+            if isinstance(m, tuple):
+                return tuple(map(get_matched_types, m))
+            if isinstance(m, torch.nn.Module):
+                return type_before_parametrizations(m)
+            return m
+
+        matched_module_types = get_matched_types(matched_modules)
+        module_parent_name, module_name = _parent_name(root_node.target)
+        fuser_method = get_fuser_method_new(matched_module_types, fuser_method_mapping)
+        # TODO: change the signature for fuser_method to take matched module patterns
+        # as input
+        fused_module = fuser_method(is_qat, *matched_modules)
+        setattr(named_modules[module_parent_name], module_name, fused_module)
+        extra_args = [load_arg(input) for input in extra_inputs]
+        node = fused_graph.node_copy(root_node, load_arg)
+        args = list(node.args)
+        args.extend(extra_args)
+        node.args = tuple(args)
+        return node
+
+
+def _get_fusion_pattern_to_fuse_handler_cls(
+    backend_config: BackendConfig,
+) -> dict[Pattern, Callable]:
+    fusion_pattern_to_fuse_handlers: dict[Pattern, Callable] = {}
+    for pattern, config in backend_config._pattern_complex_format_to_config.items():
+        if config.fuser_method is not None:
+            # TODO: is this logic right?
+            fusion_pattern_to_fuse_handlers[pattern] = DefaultFuseHandler
+    return fusion_pattern_to_fuse_handlers
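A small sketch (not part of the patch) showing how the helper above derives the pattern-to-handler map from a backend config; every pattern that defines a ``fuser_method`` is handled by ``DefaultFuseHandler``::

    from torch.ao.quantization.backend_config import get_native_backend_config
    from torch.ao.quantization.fx.fuse_handler import _get_fusion_pattern_to_fuse_handler_cls

    backend_config = get_native_backend_config()
    pattern_to_handler = _get_fusion_pattern_to_fuse_handler_cls(backend_config)
    # patterns are in the reversed "complex" format, e.g. (ReLU, (BatchNorm2d, Conv2d))
    for pattern, handler_cls in list(pattern_to_handler.items())[:5]:
        print(pattern, "->", handler_cls.__name__)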
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/graph_module.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/graph_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..15d8fc7852e0fb5f12ecb43180e7e8a1ef8efd9a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/graph_module.py
@@ -0,0 +1,205 @@
+# mypy: allow-untyped-defs
+import copy
+from typing import Any, Union
+
+import torch
+from torch.fx import GraphModule
+from torch.fx.graph import Graph
+
+
+__all__ = [
+    "FusedGraphModule",
+    "ObservedGraphModule",
+    "ObservedStandaloneGraphModule",
+    "QuantizedGraphModule",
+]
+
+
+class FusedGraphModule(GraphModule):
+    def __init__(
+        self,
+        root: Union[torch.nn.Module, dict[str, Any]],
+        graph: Graph,
+        preserved_attr_names: set[str],
+    ):
+        self.preserved_attr_names = preserved_attr_names
+        preserved_attrs = {
+            attr: getattr(root, attr)
+            for attr in self.preserved_attr_names
+            if hasattr(root, attr)
+        }
+        super().__init__(root, graph)
+        for attr in preserved_attrs:
+            setattr(self, attr, preserved_attrs[attr])
+
+    # GraphModule does not copy attributes which are not in the __dict__
+    # of vanilla nn.Module.  So, we override __deepcopy__ in order
+    # to copy the quantization specific attributes correctly.
+    def __deepcopy__(self, memo):
+        fake_mod = torch.nn.Module()
+        fake_mod.__dict__ = copy.deepcopy(self.__dict__)
+        return FusedGraphModule(
+            fake_mod,
+            copy.deepcopy(self.graph),
+            copy.deepcopy(self.preserved_attr_names),
+        )
+
+
+class ObservedGraphModule(GraphModule):
+    def __init__(
+        self,
+        root: Union[torch.nn.Module, dict[str, Any]],
+        graph: Graph,
+        preserved_attr_names: set[str],
+    ):
+        self.preserved_attr_names = {
+            "_activation_post_process_map",
+            "_activation_post_process_indexes",
+            "_patterns",
+            "_node_name_to_qconfig",
+            "_prepare_custom_config",
+            "_equalization_node_name_to_qconfig",
+            "_node_name_to_scope",
+            "_qconfig_mapping",
+            "_is_qat",
+            "_observed_node_names",
+        }.union(preserved_attr_names)
+        preserved_attrs = {
+            attr: getattr(root, attr)
+            for attr in self.preserved_attr_names
+            if hasattr(root, attr)
+        }
+        super().__init__(root, graph)
+        for attr in preserved_attrs:
+            setattr(self, attr, preserved_attrs[attr])
+
+    # GraphModule does not copy attributes which are not in the __dict__
+    # of vanilla nn.Module.  So, we override __deepcopy__ in order
+    # to copy the quantization specific attributes correctly.
+    def __deepcopy__(self, memo):
+        fake_mod = torch.nn.Module()
+        fake_mod.__dict__ = copy.deepcopy(self.__dict__)
+        return ObservedGraphModule(
+            fake_mod,
+            copy.deepcopy(self.graph),
+            copy.deepcopy(self.preserved_attr_names),
+        )
+
+
+def _is_observed_module(module: Any) -> bool:
+    return hasattr(module, "meta") and "_observed_graph_module_attrs" in module.meta
+
+
+def _get_observed_graph_module_attr(
+    model: Union[torch.nn.Module, GraphModule], attr_name: str
+) -> Any:
+    if hasattr(model, "meta") and "_observed_graph_module_attrs" in model.meta:  # type: ignore[operator, index]
+        return getattr(model.meta["_observed_graph_module_attrs"], attr_name)  # type: ignore[index]
+    return None
+
+
+class ObservedStandaloneGraphModule(ObservedGraphModule):
+    def __init__(
+        self,
+        root: Union[torch.nn.Module, dict[str, Any]],
+        graph: Graph,
+        preserved_attr_names: set[str],
+    ):
+        preserved_attr_names = preserved_attr_names.union(
+            {
+                "_standalone_module_input_quantized_idxs",
+                "_standalone_module_output_quantized_idxs",
+            }
+        )
+        super().__init__(root, graph, preserved_attr_names)
+
+    def __deepcopy__(self, memo):
+        fake_mod = torch.nn.Module()
+        fake_mod.__dict__ = copy.deepcopy(self.__dict__)
+        return ObservedStandaloneGraphModule(
+            fake_mod,
+            copy.deepcopy(self.graph),
+            copy.deepcopy(self.preserved_attr_names),
+        )
+
+
+def _is_observed_standalone_module(module: Any) -> bool:
+    return (
+        _is_observed_module(module)
+        and module.meta["_observed_graph_module_attrs"].is_observed_standalone_module
+    )
+
+
+def _save_packed_weight(self, destination, prefix, keep_vars):
+    for attr_name in dir(self):
+        if "_packed_weight" in attr_name and isinstance(
+            getattr(self, attr_name), torch._C.ScriptObject
+        ):  # type: ignore[attr-defined]
+            packed_weight = getattr(self, attr_name)
+            destination[prefix + attr_name] = packed_weight
+
+
+class QuantizedGraphModule(GraphModule):
+    """This class is created to make sure PackedParams
+    (e.g. LinearPackedParams, Conv2dPackedParams) to appear in state_dict
+    so that we can serialize and deserialize quantized graph module with
+    torch.save(m.state_dict()) and m.load_state_dict(state_dict)
+    """
+
+    def __init__(
+        self,
+        root: Union[torch.nn.Module, dict[str, Any]],
+        graph: Graph,
+        preserved_attr_names: set[str],
+    ):
+        self.preserved_attr_names = preserved_attr_names
+        preserved_attrs = {
+            attr: getattr(root, attr)
+            for attr in self.preserved_attr_names
+            if hasattr(root, attr)
+        }
+        super().__init__(root, graph)
+        for attr in preserved_attrs:
+            setattr(self, attr, preserved_attrs[attr])
+        self._register_state_dict_hook(_save_packed_weight)
+
+    def _load_from_state_dict(
+        self,
+        state_dict,
+        prefix,
+        local_metadata,
+        strict,
+        missing_keys,
+        unexpected_keys,
+        error_msgs,
+    ):
+        attrs_to_pop = []
+        for attr_name in state_dict:
+            if attr_name.startswith("_packed_weight") and isinstance(
+                state_dict[attr_name], torch._C.ScriptObject
+            ):  # type: ignore[attr-defined] # noqa: B950
+                setattr(self, attr_name, state_dict[attr_name])
+                attrs_to_pop.append(attr_name)
+
+        # pop the packed param attributes
+        for attr_name in attrs_to_pop:
+            state_dict.pop(attr_name)
+
+        super()._load_from_state_dict(
+            state_dict,
+            prefix,
+            local_metadata,
+            strict,
+            missing_keys,
+            unexpected_keys,
+            error_msgs,
+        )
+
+    def __deepcopy__(self, memo):
+        fake_mod = torch.nn.Module()
+        fake_mod.__dict__ = copy.deepcopy(self.__dict__)
+        return QuantizedGraphModule(
+            fake_mod,
+            copy.deepcopy(self.graph),
+            copy.deepcopy(self.preserved_attr_names),
+        )
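A minimal sketch (not part of the patch) of the preserved-attribute behaviour these wrappers implement; ``note`` is a hypothetical extra attribute that is not used in ``forward``::

    import copy
    import torch.nn as nn
    from torch.fx import symbolic_trace
    from torch.ao.quantization.fx.graph_module import FusedGraphModule

    class M(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 4)
            self.note = "keep me"  # hypothetical attribute, not referenced in forward()

        def forward(self, x):
            return self.linear(x)

    m = M()
    graph = symbolic_trace(m).graph
    fused = FusedGraphModule(m, graph, preserved_attr_names={"note"})
    copied = copy.deepcopy(fused)
    print(copied.note)  # "keep me" -- survives the overridden __deepcopy__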
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/lower_to_fbgemm.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/lower_to_fbgemm.py
new file mode 100644
index 0000000000000000000000000000000000000000..73fd3e8741b6d6c26d5a352d25d4cf6986de4d9d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/lower_to_fbgemm.py
@@ -0,0 +1,21 @@
+from torch.ao.quantization.qconfig import QConfigAny
+from torch.fx import GraphModule
+
+from ._lower_to_native_backend import _lower_to_native_backend
+
+
+__all__ = ["lower_to_fbgemm"]
+
+
+def lower_to_fbgemm(
+    model: GraphModule,
+    qconfig_map: dict[str, QConfigAny],
+    node_name_to_scope: dict[str, tuple[str, type]],
+    keep_original_weights: bool = False,
+) -> GraphModule:
+    """Lower a quantized reference model (with reference quantized operator patterns)
+    to fbgemm
+    """
+    return _lower_to_native_backend(
+        model, qconfig_map, node_name_to_scope, keep_original_weights
+    )
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/lower_to_qnnpack.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/lower_to_qnnpack.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1fa3ecf3f5a3b2b5dc67d769853f8424bae7efb
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/lower_to_qnnpack.py
@@ -0,0 +1,18 @@
+from torch.ao.quantization.qconfig import QConfigAny
+from torch.fx import GraphModule
+
+from ._lower_to_native_backend import _lower_to_native_backend
+
+
+__all__ = ["lower_to_qnnpack"]
+
+
+def lower_to_qnnpack(
+    model: GraphModule,
+    qconfig_map: dict[str, QConfigAny],
+    node_name_to_scope: dict[str, tuple[str, type]],
+) -> GraphModule:
+    """Lower a quantized reference model (with reference quantized operator patterns)
+    to qnnpack
+    """
+    return _lower_to_native_backend(model, qconfig_map, node_name_to_scope)
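Both lowering entry points above are normally reached through ``convert_fx``; a minimal sketch (not part of the patch) of the end-to-end flow::

    import torch
    import torch.nn as nn
    from torch.ao.quantization import get_default_qconfig_mapping
    from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

    model = nn.Sequential(nn.Linear(16, 16), nn.ReLU()).eval()
    example_inputs = (torch.randn(1, 16),)

    prepared = prepare_fx(model, get_default_qconfig_mapping("fbgemm"), example_inputs)
    prepared(*example_inputs)         # calibration pass to populate the observers
    quantized = convert_fx(prepared)  # lowering to the native backend happens in here
    print(quantized.graph)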
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/lstm_utils.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/lstm_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe18ba465212f99d1c953e498f2fb391a2d53eb0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/lstm_utils.py
@@ -0,0 +1,221 @@
+import copy
+import operator
+from typing import Any, Callable, Optional
+
+import torch
+from torch.ao.quantization import (
+    default_weight_fake_quant,
+    default_weight_observer,
+    FakeQuantizeBase,
+    QConfig,
+    QConfigMapping,
+)
+from torch.ao.quantization.backend_config import BackendConfig
+from torch.ao.quantization.observer import _PartialWrapper
+from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
+
+
+# TODO: move all LSTM util functions from fx/utils.py to this file
+def _get_lstm_with_individually_observed_parts(
+    float_lstm: torch.nn.LSTM,
+    example_inputs: tuple[Any, ...],
+    backend_config: Optional[BackendConfig] = None,
+    linear_output_obs_ctr: Optional[_PartialWrapper] = None,
+    sigmoid_obs_ctr: Optional[_PartialWrapper] = None,
+    tanh_obs_ctr: Optional[_PartialWrapper] = None,
+    cell_state_obs_ctr: Optional[_PartialWrapper] = None,
+    hidden_state_obs_ctr: Optional[_PartialWrapper] = None,
+    split_gates: bool = False,
+) -> torch.ao.nn.quantizable.LSTM:
+    """
+    Return an observed `torch.ao.nn.quantizable.LSTM` created from a `torch.nn.LSTM`
+    with specific observers or fake quantizes assigned to the inner ops or submodules.
+
+    In both eager and FX graph mode quantization, `torch.ao.nn.quantizable.LSTM` is
+    used as an observed custom module, which is responsible for inserting its own
+    observers. By default, all inner ops inherit the parent custom module's QConfig.
+    Users who wish to override this behavior may extend `torch.ao.nn.quantizable.LSTM`
+    and use this helper function to customize the observer insertion logic.
+
+    This is meant to be used to convert a float module to an observed module in the
+    custom module flow.
+
+    Args:
+        `float_lstm`: The float LSTM module
+        `example_inputs`: example inputs for the forward function of the LSTM module
+        `backend_config`: BackendConfig to use to observe the LSTM module
+        `linear_output_obs_ctr`: observer or fake quantize for linear outputs Wx + b,
+            where W is the weight matrix, b is the bias, and x is either the inputs
+            or the hidden state from the previous layer (if any)
+        `sigmoid_obs_ctr`: observer or fake quantize for sigmoid activations
+        `tanh_obs_ctr`: observer or fake quantize for tanh activations
+        `cell_state_obs_ctr`: observer or fake quantize for the cell state
+        `hidden_state_obs_ctr`: observer or fake quantize for the hidden state and
+            the output
+
+    Return:
+        A `torch.ao.nn.quantizable.LSTM` with the specified observers or fake quantizes
+        assigned to the inner ops.
+    """
+
+    def make_qconfig(obs_ctr: _PartialWrapper) -> QConfig:
+        """
+        Make a QConfig with fixed qparams observers or fake quantizes.
+        """
+        if isinstance(obs_ctr(), FakeQuantizeBase):
+            weight = default_weight_fake_quant
+        else:
+            weight = default_weight_observer
+        return QConfig(activation=obs_ctr, weight=weight)
+
+    quantizable_lstm = torch.ao.nn.quantizable.LSTM(
+        float_lstm.input_size,
+        float_lstm.hidden_size,
+        float_lstm.num_layers,
+        float_lstm.bias,
+        float_lstm.batch_first,
+        float_lstm.dropout,
+        float_lstm.bidirectional,
+        split_gates=split_gates,
+    )
+    quantizable_lstm.qconfig = float_lstm.qconfig
+
+    for idx in range(float_lstm.num_layers):
+        quantizable_lstm.layers[idx] = (
+            torch.ao.nn.quantizable.modules.rnn._LSTMLayer.from_float(
+                float_lstm,
+                idx,
+                float_lstm.qconfig,
+                batch_first=False,
+                split_gates=split_gates,
+            )
+        )
+
+    # Build QConfigMapping for the LSTM cell
+    # Note: FloatFunctional qconfigs will be configured separately below
+    cell_qm = QConfigMapping().set_global(float_lstm.qconfig)  # type: ignore[arg-type]
+    if sigmoid_obs_ctr is not None:
+        cell_qm.set_module_name("input_gate", make_qconfig(sigmoid_obs_ctr))
+        cell_qm.set_module_name("forget_gate", make_qconfig(sigmoid_obs_ctr))
+        cell_qm.set_module_name("output_gate", make_qconfig(sigmoid_obs_ctr))
+    if tanh_obs_ctr is not None:
+        cell_qm.set_module_name("cell_gate", make_qconfig(tanh_obs_ctr))
+
+    # Insert observers into each LSTM cell
+    # TODO: maybe make this work for layer_bw as well
+    for layer in quantizable_lstm.layers:
+        cell = layer.layer_fw.cell  # type: ignore[union-attr]
+        assert isinstance(cell, torch.nn.Module), "cell should be a nn.Module"
+        cell = prepare_fx(cell, cell_qm, example_inputs, backend_config=backend_config)
+        # HACK: Manually replace the activation_post_process following these ops.
+        # This is needed for FloatFunctional ops because there is currently no way
+        # to configure these ops in FX graph mode quantization today. This is because
+        # the FloatFunctional modules simply disappear from the graph after tracing.
+        # In the future, we should rewrite quantizable LSTM without FloatFunctionals.
+        if not split_gates:
+            op_index_to_activation_post_process_ctr = {
+                (torch.add, 0): linear_output_obs_ctr,  # gates.add
+                (torch.mul, 0): cell_state_obs_ctr,  # fgate_cx.mul
+                (torch.mul, 1): cell_state_obs_ctr,  # igate_cgate.mul
+                (torch.add, 1): cell_state_obs_ctr,  # fgate_cx_igate_cgate.add
+                (torch.mul, 2): hidden_state_obs_ctr,  # ogate_cy.mul
+            }
+        else:
+            op_index_to_activation_post_process_ctr = {
+                (torch.add, 0): linear_output_obs_ctr,  # gates.add (input)
+                (torch.add, 1): linear_output_obs_ctr,  # gates.add (forget)
+                (torch.add, 2): linear_output_obs_ctr,  # gates.add (cell)
+                (torch.add, 3): linear_output_obs_ctr,  # gates.add (output)
+                (torch.mul, 0): cell_state_obs_ctr,  # fgate_cx.mul
+                (torch.mul, 1): cell_state_obs_ctr,  # igate_cgate.mul
+                (torch.add, 4): cell_state_obs_ctr,  # fgate_cx_igate_cgate.add
+                (torch.mul, 2): hidden_state_obs_ctr,  # ogate_cy.mul
+            }
+        add_count = 0
+        mul_count = 0
+        for node in cell.graph.nodes:
+            op_index: Optional[tuple[Callable, int]] = None  # e.g. (torch.add, 1)
+            if node.target == torch.add:
+                op_index = (torch.add, add_count)
+                add_count += 1
+            elif node.target == torch.mul:
+                op_index = (torch.mul, mul_count)
+                mul_count += 1
+            else:
+                # Neither torch.add nor torch.mul
+                continue
+            if op_index not in op_index_to_activation_post_process_ctr:
+                continue
+            assert len(node.users) == 1
+            activation_post_process_name = next(iter(node.users.keys())).name
+            activation_post_process_ctr = op_index_to_activation_post_process_ctr[
+                op_index
+            ]
+            if activation_post_process_ctr is not None:
+                setattr(
+                    cell, activation_post_process_name, activation_post_process_ctr()
+                )
+        layer.layer_fw.cell = cell  # type: ignore[union-attr]
+    return quantizable_lstm
+
+
+def _get_reference_quantized_lstm_module(
+    observed_lstm: torch.ao.nn.quantizable.LSTM,
+    backend_config: Optional[BackendConfig] = None,
+) -> torch.ao.nn.quantized.LSTM:
+    """
+    Return a `torch.ao.nn.quantized.LSTM` created from a `torch.ao.nn.quantizable.LSTM`
+    with observers or fake quantizes inserted through `prepare_fx`, e.g. from
+    `_get_lstm_with_individually_observed_parts`.
+
+    This is meant to be used to convert an observed module to a quantized module in the
+    custom module flow.
+
+    Args:
+        `observed_lstm`: a `torch.ao.nn.quantizable.LSTM` observed through `prepare_fx`
+        `backend_config`: BackendConfig to use to produce the reference quantized model
+
+    Return:
+        A reference `torch.ao.nn.quantized.LSTM` module.
+    """
+    quantized_lstm = torch.ao.nn.quantized.LSTM(
+        observed_lstm.input_size,
+        observed_lstm.hidden_size,
+        observed_lstm.num_layers,
+        observed_lstm.bias,
+        observed_lstm.batch_first,
+        observed_lstm.dropout,
+        observed_lstm.bidirectional,
+    )
+
+    for i, layer in enumerate(quantized_lstm.layers):
+        cell = copy.deepcopy(observed_lstm.layers.get_submodule(str(i)).layer_fw.cell)  # type: ignore[union-attr]
+        cell = convert_to_reference_fx(cell, backend_config=backend_config)  # type: ignore[arg-type]
+        assert isinstance(cell, torch.fx.GraphModule)
+        # HACK: Manually remove input quantize nodes and output dequantize nodes,
+        # since custom modules expect quint8 inputs and outputs for now. Note that
+        # this functionality is supposedly handled through PrepareCustomConfig's
+        # `set_input_quantized_indexes` and `set_output_quantized_indexes`, but that
+        # API doesn't currently handle tuple inputs and outputs, so we have to do
+        # this manually for now. In the future we should (1) relax the restriction
+        # on custom module input/output dtypes, and (2) expand support for complex
+        # input/output structures.
+        for node in cell.graph.nodes:
+            if node.target == torch.quantize_per_tensor:
+                arg = node.args[0]
+                # Remove quantize(x), quantize(hidden[0]), and quantize(hidden[1])
+                if arg.target == "x" or (
+                    arg.target == operator.getitem and arg.args[0].target == "hidden"
+                ):
+                    with cell.graph.inserting_before(node):
+                        node.replace_all_uses_with(arg)
+                        cell.graph.erase_node(node)
+            if node.target == "output":
+                # Remove all dequantize nodes in the output tuple
+                for arg in node.args[0]:
+                    with cell.graph.inserting_before(node):
+                        node.replace_input_with(arg, arg.args[0])
+        cell.graph.eliminate_dead_code()
+        cell.recompile()
+        layer.layer_fw.cell = cell  # type: ignore[union-attr]
+    return quantized_lstm
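A rough sketch (not part of the patch) of calling the first helper directly, assuming a build where ``torch.ao.nn.quantizable.LSTM`` accepts these arguments; the fixed-qparams scales and zero points below are illustrative placeholders::

    import torch
    from torch.ao.quantization import FixedQParamsObserver, default_qconfig
    from torch.ao.quantization.fx.lstm_utils import (
        _get_lstm_with_individually_observed_parts,
    )

    float_lstm = torch.nn.LSTM(input_size=8, hidden_size=8, num_layers=1)
    float_lstm.qconfig = default_qconfig  # the helper reads qconfig off the float module
    example_inputs = (torch.randn(5, 1, 8),)

    sigmoid_obs = FixedQParamsObserver.with_args(
        scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8
    )
    tanh_obs = FixedQParamsObserver.with_args(
        scale=2.0 / 256.0, zero_point=128, dtype=torch.quint8
    )

    observed = _get_lstm_with_individually_observed_parts(
        float_lstm,
        example_inputs,
        sigmoid_obs_ctr=sigmoid_obs,
        tanh_obs_ctr=tanh_obs,
    )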
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/match_utils.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/match_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..137461d233ce5c45b4e59ae9846a155f35955df4
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/match_utils.py
@@ -0,0 +1,228 @@
+# mypy: allow-untyped-defs
+import sys
+from collections.abc import Iterable
+from typing import Any, Callable, Optional
+
+import torch
+from torch.ao.quantization.qconfig import QConfigAny
+from torch.ao.quantization.utils import MatchAllNode, Pattern
+from torch.fx.graph import Graph, Node
+from torch.nn.utils.parametrize import type_before_parametrizations
+
+from .graph_module import _is_observed_standalone_module
+from .quantize_handler import QuantizeHandler
+
+
+__all__: list[str] = []
+
+# TODO(future PR): the 1st argument is typed as `List[Node]`, but a better type
+# would be a recursive `List[Union[Node, Tuple[Union[Node, ...]]]]`
+_MatchResult = tuple[Node, list[Node], Optional[Pattern], QuantizeHandler]
+
+_MatchResultWithQConfig = tuple[
+    Node, list[Node], Optional[Pattern], QuantizeHandler, QConfigAny
+]
+
+
+# Note: The order of patterns is important! The match function takes whatever matches first, so we
+# need to register fusion patterns before single-op patterns; for example, add_relu should be registered before relu.
+# Decorators are applied in reverse order of appearance. Also, when matching the nodes in the graph against these
+# patterns, we start from the last node of the graph and traverse backwards.
+def _is_match(modules, node, pattern, max_uses=sys.maxsize):
+    """Matches a node in fx against a pattern"""
+    if isinstance(pattern, tuple):
+        self_match, *arg_matches = pattern
+        if self_match is getattr:
+            assert len(pattern) == 2, "Expecting getattr pattern to have two elements"
+            arg_matches = []
+    else:
+        self_match = pattern
+        arg_matches = []
+
+    if isinstance(self_match, type) and issubclass(self_match, MatchAllNode):
+        return True
+
+    if node == pattern:
+        return True
+
+    if not isinstance(node, Node) or len(node.users) > max_uses:
+        return False
+
+    if isinstance(self_match, type) and issubclass(self_match, torch.nn.Module):
+        if node.op != "call_module":
+            return False
+        if not type_before_parametrizations(modules[node.target]) == self_match:
+            return False
+    elif callable(self_match):
+        if node.op != "call_function" or node.target is not self_match:
+            return False
+        elif node.target is getattr:
+            if node.args[1] != pattern[1]:
+                return False
+    elif isinstance(self_match, str):
+        if node.op != "call_method" or node.target != self_match:
+            return False
+    elif node.target != self_match:
+        return False
+
+    if not arg_matches:
+        return True
+
+    if len(arg_matches) != len(node.args):
+        return False
+
+    return all(
+        _is_match(modules, node, arg_match, max_uses=1)
+        for node, arg_match in zip(node.args, arg_matches)
+    )
+
+
+def _find_matches(
+    graph: Graph,
+    modules: dict[str, torch.nn.Module],
+    patterns: dict[Pattern, QuantizeHandler],
+    root_node_getter_mapping: dict[Pattern, Callable],
+    standalone_module_names: Optional[list[str]] = None,
+    standalone_module_classes: Optional[list[type]] = None,
+    custom_module_classes: Optional[list[Any]] = None,
+) -> dict[str, _MatchResult]:
+    """
+    Matches the nodes in the input graph to quantization patterns, and
+    outputs the information needed to quantize them in future steps.
+
+    Inputs:
+      - graph: an fx.Graph object
+      - modules: a mapping of fully qualified module name to instance,
+          for example, {'foo': ModuleFoo, ...}
+      - patterns: a mapping from a tuple of nodes in reverse order to
+          uninitialized QuantizeHandler subclass.
+
+    Outputs a map of
+      node_name ->
+        (node, matched_values, matched_pattern, QuantizeHandler instance,
+         qconfig)
+
+    For example, {
+      'relu_1': (relu_1, [relu_1], torch.nn.functional.relu,
+                 <QuantizeHandler instance>, QConfig(...)),
+      ...
+    }
+    """
+    if custom_module_classes is None:
+        custom_module_classes = []
+
+    if standalone_module_classes is None:
+        standalone_module_classes = []
+
+    if standalone_module_names is None:
+        standalone_module_names = []
+
+    match_map: dict[str, _MatchResult] = {}
+    all_matched: set[str] = set()
+
+    def _recursive_record_node_in_match_map(
+        last_node, match_map, node_pattern, matched_node_pattern, pattern, match_value
+    ):
+        if isinstance(node_pattern, Node):
+            match_map[node_pattern.name] = (
+                last_node,
+                matched_node_pattern,
+                pattern,
+                match_value,
+            )
+        elif not isinstance(node_pattern, Iterable):
+            return
+        else:
+            for n in node_pattern:
+                _recursive_record_node_in_match_map(
+                    last_node, match_map, n, matched_node_pattern, pattern, match_value
+                )
+
+    # TODO: 1. merge with fuse matcher 2. document the code
+    def record_match(pattern, node, last_node, matched_node_pattern, match_map):
+        if isinstance(pattern, tuple):
+            s, *args = pattern
+            is_single_arg = len(args) == 1
+            current_node_pattern: list[Node] = []
+            record_match(s, node, last_node, matched_node_pattern, match_map)
+            if pattern[0] is not getattr:
+                for subpattern, arg in zip(args, node.args):
+                    record_match(subpattern, arg, node, current_node_pattern, match_map)
+            if len(current_node_pattern) > 1:
+                # current_node_pattern is the node pattern we get from matching
+                # the subpattern with arguments of the node
+                # we use is_single_arg to recover the original structure of the pattern
+                # if the original pattern has a single argument, we will have
+                # (original_op, (original_arg, ...))
+                # otherwise, we'll have a list of arguments
+                # (original_op, arg0, arg1, arg2, ...)
+                if is_single_arg:
+                    matched_node_pattern.append(tuple(current_node_pattern))
+                else:
+                    matched_node_pattern.extend(list(current_node_pattern))
+            else:
+                matched_node_pattern.append(current_node_pattern[0])
+        else:
+            matched_node_pattern.append(node)
+
+    for node in reversed(graph.nodes):
+        if node.name not in match_map and node.name not in all_matched:
+            for pattern, quantize_handler_cls in patterns.items():
+                root_node_getter = root_node_getter_mapping.get(pattern, None)
+                if _is_match(modules, node, pattern) and node.name not in match_map:
+                    matched_node_pattern: list[Node] = []
+                    record_match(pattern, node, node, matched_node_pattern, match_map)
+                    quantize_handler = quantize_handler_cls(  # type: ignore[operator]
+                        matched_node_pattern, modules, root_node_getter
+                    )
+                    last_node = node
+                    # record the match for all nodes in the pattern
+                    _recursive_record_node_in_match_map(
+                        last_node,
+                        match_map,
+                        # we need to record all nodes in the matched pattern in the match_map
+                        matched_node_pattern,
+                        # this is a part of the value corresponding to the node
+                        matched_node_pattern,
+                        pattern,
+                        quantize_handler,
+                    )
+                    break
+
+    # add custom module instances to the match result
+    assert modules is not None
+    for node in graph.nodes:
+        if (
+            node.op == "call_module"
+            and type(modules[node.target]) in custom_module_classes
+        ):
+            match_map[node.name] = (
+                node,
+                node,
+                None,
+                QuantizeHandler(node, modules, is_custom_module=True),
+            )
+
+    def is_standalone_module(node_target: str, modules: dict[str, torch.nn.Module]):
+        assert modules is not None
+        return (
+            node_target in standalone_module_names
+            or type(modules[node_target])  # type: ignore[operator]
+            in standalone_module_classes  # type: ignore[operator]
+        )
+
+    # add standalone modules to the match
+    for node in graph.nodes:
+        if node.op == "call_module" and (
+            is_standalone_module(node.target, modules)
+            or _is_observed_standalone_module(modules[node.target])
+        ):
+            # add node to matched nodes
+            match_map[node.name] = (
+                node,
+                node,
+                None,
+                QuantizeHandler(node, modules, is_standalone_module=True),
+            )
+
+    return match_map
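A minimal sketch (not part of the patch) of the reversed-pattern convention used by ``_is_match``: the pattern lists the last op first, so ``(nn.ReLU, nn.Conv2d)`` matches a ReLU node fed by a Conv2d module::

    import torch.nn as nn
    from torch.ao.quantization.fx.match_utils import _is_match
    from torch.fx import symbolic_trace

    class M(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 3, 1)
            self.relu = nn.ReLU()

        def forward(self, x):
            return self.relu(self.conv(x))

    traced = symbolic_trace(M())
    modules = dict(traced.named_modules())
    relu_node = next(n for n in traced.graph.nodes if n.target == "relu")
    print(_is_match(modules, relu_node, (nn.ReLU, nn.Conv2d)))  # True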
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/pattern_utils.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/pattern_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e86f95d67aba092daff6a3a14a14767f29d249a2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/pattern_utils.py
@@ -0,0 +1,112 @@
+# mypy: allow-untyped-defs
+import copy
+from collections import OrderedDict
+from typing import Any
+
+from torch.ao.quantization.fake_quantize import FixedQParamsFakeQuantize
+from torch.ao.quantization.observer import ObserverBase
+from torch.ao.quantization.utils import Pattern
+
+
+__all__ = [
+    "get_default_fusion_patterns",
+    "get_default_quant_patterns",
+    "get_default_output_activation_post_process_map",
+]
+
+# TODO(future PR): fix the typing on QuantizeHandler (currently a circular dependency)
+QuantizeHandler = Any
+
+# pattern for conv bn fusion
+_DEFAULT_FUSION_PATTERNS: dict[Pattern, QuantizeHandler] = OrderedDict()
+
+
+def _register_fusion_pattern(pattern):
+    def insert(fn):
+        _DEFAULT_FUSION_PATTERNS[pattern] = fn
+        return fn
+
+    return insert
+
+
+def get_default_fusion_patterns() -> dict[Pattern, QuantizeHandler]:
+    return copy.copy(_DEFAULT_FUSION_PATTERNS)
+
+
+_DEFAULT_QUANTIZATION_PATTERNS: dict[Pattern, QuantizeHandler] = OrderedDict()
+
+# Mapping from pattern to activation_post_process(observer/fake_quant) constructor for output activation
+# e.g. pattern: torch.sigmoid,
+#      output_activation_post_process: default_fixed_qparams_range_0to1_fake_quant
+_DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP: dict[Pattern, QuantizeHandler] = {}
+_DEFAULT_OUTPUT_OBSERVER_MAP: dict[Pattern, QuantizeHandler] = {}
+
+
+# Register pattern for both static quantization and qat
+def _register_quant_pattern(pattern, fixed_qparams_observer=None):
+    def insert(fn):
+        _DEFAULT_QUANTIZATION_PATTERNS[pattern] = fn
+        if fixed_qparams_observer is not None:
+            _DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP[pattern] = (
+                FixedQParamsFakeQuantize.with_args(observer=fixed_qparams_observer)
+            )
+            _DEFAULT_OUTPUT_OBSERVER_MAP[pattern] = fixed_qparams_observer
+        return fn
+
+    return insert
+
+
+# Get patterns for both static quantization and qat
+def get_default_quant_patterns() -> dict[Pattern, QuantizeHandler]:
+    return copy.copy(_DEFAULT_QUANTIZATION_PATTERNS)
+
+
+# a map from pattern to output activation post process constructor
+# e.g. torch.sigmoid -> default_affine_fixed_qparam_fake_quant
+def get_default_output_activation_post_process_map(
+    is_training,
+) -> dict[Pattern, ObserverBase]:
+    if is_training:
+        return copy.copy(_DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP)
+    else:
+        return copy.copy(_DEFAULT_OUTPUT_OBSERVER_MAP)
+
+
+# Example use of register pattern function:
+# @_register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
+# class ConvOrLinearBNReLUFusion():
+#     def __init__(...):
+#         ...
+#
+
+
+def _sorted_patterns_dict(
+    patterns_dict: dict[Pattern, QuantizeHandler],
+) -> dict[Pattern, QuantizeHandler]:
+    """
+    Return a sorted version of the patterns dictionary such that longer patterns are matched first,
+    e.g. match (F.relu, F.linear) before F.relu.
+    This works for current use cases, but we may need to have a more clever way to sort
+    things to address more complex patterns
+    """
+
+    def get_len(pattern):
+        """this will calculate the length of the pattern by counting all the entries
+        in the pattern.
+        this will make sure (nn.ReLU, (nn.BatchNorm, nn.Conv2d)) comes before
+        (nn.BatchNorm, nn.Conv2d) so that we can match the former first
+        """
+        len = 0
+        if isinstance(pattern, tuple):
+            for item in pattern:
+                len += get_len(item)
+        else:
+            len += 1
+        return len
+
+    return OrderedDict(
+        sorted(
+            patterns_dict.items(),
+            key=lambda kv: -get_len(kv[0]) if isinstance(kv[0], tuple) else 1,
+        )
+    )
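A small sketch (not part of the patch) of the ordering ``_sorted_patterns_dict`` produces; the string values stand in for handler classes::

    import torch.nn as nn
    from torch.ao.quantization.fx.pattern_utils import _sorted_patterns_dict

    patterns = {
        nn.Conv2d: "conv-handler",
        (nn.BatchNorm2d, nn.Conv2d): "conv-bn-handler",
        (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d)): "conv-bn-relu-handler",
    }
    for pattern in _sorted_patterns_dict(patterns):
        print(pattern)
    # (ReLU, (BatchNorm2d, Conv2d)) first, then (BatchNorm2d, Conv2d), then Conv2d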
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/prepare.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/prepare.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f77cb4992d1b3da412e193dd1341912011f5320
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/prepare.py
@@ -0,0 +1,2186 @@
+# mypy: allow-untyped-defs
+import copy
+import warnings
+from dataclasses import asdict
+from typing import Any, Optional, Union
+
+import torch
+from torch._subclasses import FakeTensor
+from torch.ao.quantization import (
+    _DerivedObserverOrFakeQuantize,
+    FixedQParamsFakeQuantize,
+    FixedQParamsObserver,
+    ObserverBase,
+    ObserverOrFakeQuantize,
+    PlaceholderObserver,
+)
+from torch.ao.quantization.backend_config import (
+    BackendConfig,
+    DTypeConfig,
+    get_native_backend_config,
+)
+from torch.ao.quantization.backend_config.utils import (
+    get_fusion_pattern_to_root_node_getter,
+    get_module_to_qat_module,
+    get_pattern_to_dtype_configs,
+)
+from torch.ao.quantization.observer import _is_activation_post_process, _PartialWrapper
+from torch.ao.quantization.qconfig import _is_reuse_input_qconfig, QConfigAny
+from torch.ao.quantization.qconfig_mapping import QConfigMapping
+from torch.ao.quantization.quantize import convert, propagate_qconfig_
+from torch.ao.quantization.quantizer import (
+    DerivedQuantizationSpec,
+    EdgeOrNode,
+    FixedQParamsQuantizationSpec,
+    QuantizationSpec,
+    QuantizationSpecBase,
+    SharedQuantizationSpec,
+)
+from torch.ao.quantization.utils import (
+    _parent_name,
+    get_qconfig_dtypes,
+    get_swapped_custom_module_class,
+    NodePattern,
+    Pattern,
+)
+from torch.fx import GraphModule
+from torch.fx.graph import Graph, Node
+from torch.fx.node import Argument
+
+from ._equalize import is_equalization_observer, node_supports_equalization
+from .custom_config import PrepareCustomConfig, StandaloneModuleConfigEntry
+from .match_utils import _find_matches, _MatchResultWithQConfig
+from .pattern_utils import _sorted_patterns_dict
+from .qconfig_mapping_utils import (
+    _generate_node_name_to_qconfig,
+    _get_flattened_qconfig_dict,
+    _update_qconfig_for_fusion,
+    _update_qconfig_for_qat,
+)
+from .quantize_handler import (
+    _default_root_node_getter,
+    _get_pattern_to_quantize_handlers,
+    QuantizeHandler,
+)
+from .utils import (
+    _insert_dequant_stubs_for_custom_module_lstm_output,
+    _is_custom_module_lstm,
+    _maybe_get_custom_module_lstm_from_node_arg,
+    _qconfig_satisfies_dtype_config_constraints,
+    all_node_args_have_no_tensors,
+    assert_and_get_unique_device,
+    get_custom_module_class_keys,
+    get_new_attr_name_with_prefix,
+    get_non_observable_arg_indexes_and_types,
+    node_arg_is_bias,
+    node_arg_is_weight,
+    NON_QUANTIZABLE_WEIGHT_OPS,
+    ObservedGraphModuleAttrs,
+)
+
+
+__all__ = [
+    "insert_observers_for_model",
+    "prepare",
+    "propagate_dtypes_for_known_nodes",
+]
+
+
+# list of dtypes to not add observers to
+_DO_NOT_OBS_DTYPE_LIST = [int, float, torch.bool, None]
+_OBS_DTYPE_LIST = [
+    torch.quint8,
+    torch.qint8,
+    torch.qint32,
+    torch.float16,
+    torch.uint8,
+    torch.int8,
+    torch.int16,
+    torch.int32,
+    torch.float8_e5m2,
+    torch.float8_e4m3fn,
+]
+
+_DEFAULT_FP32_OBS_OR_FQ_CTR = PlaceholderObserver.with_args(dtype=torch.float)
+
+# note: the following default target dtype info dicts are temporary,
+# should be moved to the new programmable API class soon
+_DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO = {
+    "input_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_fp32_placeholder_qconfig.activation,
+    "output_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_fp32_placeholder_qconfig.activation,
+}
+
+_DEFAULT_QUINT8_QCONFIG_FOR_TARGET_DTYPE_INFO = {
+    "input_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_quint8_placeholder_qconfig.activation,
+    "output_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_quint8_placeholder_qconfig.activation,
+}
+
+
+def _get_observer_kwargs(
+    quant_spec: Union[QuantizationSpec, FixedQParamsQuantizationSpec],
+):
+    kwargs_dict = asdict(quant_spec)
+    return copy.deepcopy(kwargs_dict)
+
+
+def _get_qspec_for_arg(
+    arg: Node,
+    input_qspec_map: dict[Node, QuantizationSpecBase],
+    named_modules: dict[str, torch.nn.Module],
+) -> Optional[QuantizationSpecBase]:
+    while _is_activation_post_process_node(arg, named_modules):
+        arg = arg.args[0]  # type: ignore[assignment]
+    return input_qspec_map.get(arg, None)
+
+
+def _create_obs_or_fq_from_qspec(
+    quantization_spec: Optional[QuantizationSpecBase],
+    obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
+    is_qat: bool,
+):
+    """Create observer or fake quantize objects based on quantization spec
+
+    Args:
+       quantization_spec: used to store parameters to create the observer or fake quantizer
+       obs_or_fq_map: a map from edge/output to the corresponding observer/fake_quant
+       instance; it may be reused for different edges/outputs depending on the configuration
+    """
+    if quantization_spec is None:
+        return None
+    if isinstance(quantization_spec, SharedQuantizationSpec):
+        edge_or_node = quantization_spec.edge_or_node
+        assert edge_or_node in obs_or_fq_map, (
+            "please make sure you only refer to an edge or node that has an "
+            f"observer/fake_quant inserted: '{edge_or_node}' not in\n{obs_or_fq_map.keys()}"
+        )
+        return obs_or_fq_map[edge_or_node]
+    elif isinstance(quantization_spec, DerivedQuantizationSpec):
+        # can't use asdict, so not calling _get_observer_kwargs here
+        kwargs = {
+            "dtype": quantization_spec.dtype,
+            "derive_qparams_fn": quantization_spec.derive_qparams_fn,
+            "quant_min": quantization_spec.quant_min,
+            "quant_max": quantization_spec.quant_max,
+            "qscheme": quantization_spec.qscheme,
+            "ch_axis": quantization_spec.ch_axis,
+        }
+        edge_or_nodes = quantization_spec.derived_from
+        obs_or_fqs = [obs_or_fq_map[k] for k in edge_or_nodes]
+        kwargs["obs_or_fqs"] = obs_or_fqs
+        return _DerivedObserverOrFakeQuantize.with_args(**kwargs)()
+    elif isinstance(quantization_spec, FixedQParamsQuantizationSpec):
+        kwargs = _get_observer_kwargs(quantization_spec)
+        observer_ctr = FixedQParamsObserver.with_args(**kwargs)
+        if is_qat:
+            return FixedQParamsFakeQuantize.with_args(observer=observer_ctr)()
+        else:
+            return observer_ctr()
+
+    assert isinstance(quantization_spec, QuantizationSpec)
+    observer_or_fake_quant_ctr = quantization_spec.observer_or_fake_quant_ctr
+    kwargs = _get_observer_kwargs(quantization_spec)
+    kwargs.pop("observer_or_fake_quant_ctr")
+    # note: is_dynamic from QuantizationSpec is passed through to the observer
+    # kwargs here; whether dynamic range quantization is used is determined by
+    # the observer/fake_quant constructor itself
+    obs_or_fq_class = observer_or_fake_quant_ctr
+    if isinstance(observer_or_fake_quant_ctr, _PartialWrapper):
+        obs_or_fq_class = observer_or_fake_quant_ctr.p.func  # type: ignore[union-attr, assignment]
+    if "PerChannel" not in obs_or_fq_class.__name__:  # type: ignore[operator, union-attr]
+        kwargs.pop("ch_axis")
+    return observer_or_fake_quant_ctr.with_args(**kwargs)()
+
+
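+# Illustrative sketch (hypothetical, never called): how a plain QuantizationSpec
+# from the pt2e path is turned into an observer instance by
+# `_create_obs_or_fq_from_qspec`. MinMaxObserver and the int8 range used here
+# are assumptions chosen for the example, not the only valid configuration.
+def _example_create_obs_from_qspec() -> None:
+    from torch.ao.quantization.observer import MinMaxObserver
+
+    spec = QuantizationSpec(
+        dtype=torch.int8,
+        observer_or_fake_quant_ctr=MinMaxObserver,
+        quant_min=-128,
+        quant_max=127,
+        qscheme=torch.per_tensor_affine,
+    )
+    obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize] = {}
+    obs = _create_obs_or_fq_from_qspec(spec, obs_or_fq_map, is_qat=False)
+    # the spec's fields (minus the constructor itself) become observer kwargs
+    assert isinstance(obs, MinMaxObserver)
+
+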
+def _needs_obs_or_fq(
+    prev_output_dtype: Any,
+    prev_output_is_dynamic: bool,
+    cur_target_dtype: Any,
+    cur_target_is_dynamic: bool,
+    reuse_input_obs_or_fq: bool,
+    is_zeroth_arg: bool = False,
+) -> bool:
+    """
+    Utility function that checks whether we should insert an observer or fake
+    quant node, based on the dtypes the user requested for the nodes.
+    Note: we treat "not specified" as torch.float for now.
+
+    is_zeroth_arg: we only dynamically quantize the first arg of the node right now;
+      this should be removed once we support configuring dynamic quantization
+      for a specific argument, or if we deprecate fx graph mode quantization
+
+    """
+
+    # need to insert placeholder observer for dynamic quantization so that it can
+    # be converted to choose_qparams -> q -> dq in convert step
+    if cur_target_is_dynamic:
+        assert cur_target_dtype in _OBS_DTYPE_LIST, (
+            f"Expected cur_target_dtype to be one of {_OBS_DTYPE_LIST}, but got: {cur_target_dtype}"
+        )
+        assert prev_output_dtype not in _DO_NOT_OBS_DTYPE_LIST
+        return is_zeroth_arg
+    if reuse_input_obs_or_fq:
+        return False
+    # non dynamic quantization
+    if cur_target_dtype in _OBS_DTYPE_LIST:
+        return (
+            prev_output_dtype in _OBS_DTYPE_LIST + [torch.float]
+            and cur_target_dtype != prev_output_dtype
+        )
+
+    # lots of error checking are skipped here for now
+    return False
+
+
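+# Illustrative sketch (hypothetical, never called): a few representative
+# decisions made by `_needs_obs_or_fq` for the static quantization case.
+# The dtype combinations below are assumptions chosen for the example.
+def _example_needs_obs_or_fq_decisions() -> None:
+    # fp32 producer feeding a statically quantized quint8 consumer: observe
+    assert _needs_obs_or_fq(
+        torch.float, False, torch.quint8, False, reuse_input_obs_or_fq=False
+    )
+    # producer and consumer dtypes already match: no observer needed
+    assert not _needs_obs_or_fq(
+        torch.quint8, False, torch.quint8, False, reuse_input_obs_or_fq=False
+    )
+    # reusing the input observer/fake_quant: never insert a new one
+    assert not _needs_obs_or_fq(
+        torch.float, False, torch.quint8, False, reuse_input_obs_or_fq=True
+    )
+
+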
+def _is_activation_post_process_node(
+    node: Node, named_modules: dict[str, torch.nn.Module]
+) -> bool:
+    return (
+        isinstance(node, torch.fx.Node)
+        and node.op == "call_module"
+        and _is_activation_post_process(named_modules[str(node.target)])
+    )
+
+
+def _get_dtype_and_is_dynamic(
+    obs_or_fq: Optional[ObserverOrFakeQuantize],
+) -> tuple[Optional[torch.dtype], bool]:
+    """Given an observer or fake quant module instance, returns
+    a tuple of (dtype, is_dynamic)
+    """
+    # TODO: instead of instantiating the instance, we can use inspect to get the default args
+    if obs_or_fq is None:
+        return None, False
+    else:
+        return obs_or_fq.dtype, getattr(obs_or_fq, "is_dynamic", False)  # type: ignore[return-value]
+
+
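+# Illustrative sketch (hypothetical, never called): `_get_dtype_and_is_dynamic`
+# simply reads the `dtype`/`is_dynamic` attributes off an observer instance; a
+# dynamic PlaceholderObserver is used here purely as an example.
+def _example_get_dtype_and_is_dynamic() -> None:
+    dynamic_obs = PlaceholderObserver.with_args(dtype=torch.quint8, is_dynamic=True)()
+    assert _get_dtype_and_is_dynamic(dynamic_obs) == (torch.quint8, True)
+    assert _get_dtype_and_is_dynamic(None) == (None, False)
+
+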
+def _is_input_arg_dtype_supported_by_backend(
+    arg: Argument,
+    node: Node,
+    qconfig: QConfigAny,
+    dtype_config: DTypeConfig,
+    backend_config: BackendConfig,
+) -> bool:
+    """Check if the configured qconfig for the argument
+    is supported by the backend or not
+    """
+    if isinstance(arg, (list, tuple)):
+        return all(
+            _is_input_arg_dtype_supported_by_backend(
+                a, node, qconfig, dtype_config, backend_config
+            )
+            for a in arg
+        )
+    if not isinstance(arg, Node):
+        return True
+    # TODO: support check for standalone module
+    is_weight = node_arg_is_weight(node, arg)
+    is_bias = node_arg_is_bias(node, arg)
+    is_activation = not is_weight and not is_bias
+    if is_activation:
+        input_act_obs_or_fq_ctr = node.meta["target_dtype_info"].get(
+            "input_act_obs_or_fq_ctr"
+        )
+        input_act_obs_or_fq = (
+            input_act_obs_or_fq_ctr() if input_act_obs_or_fq_ctr else None
+        )
+        qconfig_dtype, qconfig_is_dynamic = _get_dtype_and_is_dynamic(
+            input_act_obs_or_fq
+        )
+        # TODO(future PR): remove the cast to bool below after figuring
+        # out why backend_config has is_dynamic set to None in some cases.
+        return (dtype_config.input_dtype is None) or (
+            dtype_config.input_dtype == qconfig_dtype
+            and bool(dtype_config.is_dynamic) == bool(qconfig_is_dynamic)
+            and _qconfig_satisfies_dtype_config_constraints(
+                qconfig, dtype_config.input_dtype_with_constraints
+            )
+        )
+    elif is_weight:
+        # TODO: move dtype check into `_qconfig_satisfies_dtype_config_constraints` as well
+        weight_obs_or_fq_ctr = node.meta["target_dtype_info"].get(
+            "weight_obs_or_fq_ctr", None
+        )
+        weight_obs_or_fq = weight_obs_or_fq_ctr() if weight_obs_or_fq_ctr else None
+        qconfig_weight_dtype, _ = _get_dtype_and_is_dynamic(weight_obs_or_fq)
+        backend_config_weight_dtype = dtype_config.weight_dtype
+        dtype_matches = qconfig_weight_dtype == backend_config_weight_dtype
+        qconfig_satisfies_constraints = _qconfig_satisfies_dtype_config_constraints(
+            qconfig, dtype_config.weight_dtype_with_constraints, is_activation=False
+        )
+        return backend_config_weight_dtype is None or (
+            dtype_matches and qconfig_satisfies_constraints
+        )
+    else:  # bias
+        # TODO: move dtype check into `_qconfig_satisfies_dtype_config_constraints` as well
+        bias_obs_or_fq_ctr = node.meta["target_dtype_info"].get(
+            "bias_obs_or_fq_ctr", None
+        )
+        bias_obs_or_fq = bias_obs_or_fq_ctr() if bias_obs_or_fq_ctr else None
+        qconfig_bias_dtype, _ = _get_dtype_and_is_dynamic(bias_obs_or_fq)
+        backend_config_bias_dtype = dtype_config.bias_dtype
+        return (
+            backend_config_bias_dtype is None
+            or qconfig_bias_dtype == backend_config_bias_dtype
+        )
+
+
+def _is_output_dtype_supported_by_backend(
+    node: Node,
+    qconfig: QConfigAny,
+    dtype_config: DTypeConfig,
+) -> bool:
+    """Check if the configured qconfig for the output
+    is supported by the backend or not
+    """
+    # TODO: move dtype check into `_qconfig_satisfies_dtype_config_constraints` as well
+    backend_config_output_dtype = dtype_config.output_dtype
+    # TODO: we should check is_dynamic here as well, the code from _is_input_arg_dtype_supported_by_backend
+    # from input activation check can be reused here
+    qconfig_output_dtype = None
+    output_act_obs_or_fq_ctr = node.meta["target_dtype_info"].get(
+        "output_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR
+    )
+    output_act_obs_or_fq = (
+        output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None
+    )
+    qconfig_output_dtype, qconfig_output_is_dynamic = _get_dtype_and_is_dynamic(
+        output_act_obs_or_fq
+    )
+    # TODO: this is a hack because we can only specify one activation_obs_or_fq for
+    # qconfig (qconfig.activation), and we are only supporting dynamically quantized
+    # linear op which has fp32 output dtype; this should be removed if we generalize
+    # the structure of qconfig in the future
+    if qconfig_output_is_dynamic:
+        qconfig_output_dtype = torch.float32
+    dtype_matches = qconfig_output_dtype == backend_config_output_dtype
+    qconfig_satisfies_constraints = _qconfig_satisfies_dtype_config_constraints(
+        qconfig, dtype_config.output_dtype_with_constraints
+    )
+    return backend_config_output_dtype is None or (
+        dtype_matches and qconfig_satisfies_constraints
+    )
+
+
+def _is_observer_in_same_graph(
+    node: Node,
+    named_modules: dict[str, torch.nn.Module],
+    obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
+    is_qat,
+):
+    """Check whether the node's observer belongs in the same graph.
+    When the node output is not fp32 and the input is a 'placeholder',
+    the input is assumed to be quantized already, so it is observed
+    elsewhere rather than being unobserved.
+    """
+    node_output_dtype = _get_arg_target_dtype_as_output(
+        node, named_modules, obs_or_fq_map, is_qat
+    )
+    if len(node.args) > 0 and isinstance(node.args[0], Node):
+        if (
+            node_output_dtype in [torch.quint8, torch.uint8]
+            and node.args[0].op == "placeholder"
+        ):
+            return False
+    return True
+
+
+def _is_pattern_dtype_config_and_qconfig_supported_by_backend(
+    pattern: Optional[Pattern],
+    matched_node_pattern: Optional[list[Node]],
+    qconfig: QConfigAny,
+    backend_config: BackendConfig,
+) -> bool:
+    """Check if the dtype configuration of a pattern is supported by
+    the backend or not, and whether the qconfig satisfies constraints
+    specified in the corresponding dtype config.
+    """
+    if backend_config is None or pattern is None:
+        return True
+    assert matched_node_pattern is not None and len(matched_node_pattern) >= 1
+    pattern_to_dtype_configs = get_pattern_to_dtype_configs(backend_config)
+    dtype_configs: list[DTypeConfig] = pattern_to_dtype_configs.get(pattern, [])
+    pattern_to_root_node_getter = get_fusion_pattern_to_root_node_getter(backend_config)
+
+    root_node_getter = pattern_to_root_node_getter.get(
+        pattern, _default_root_node_getter
+    )
+    root_node = root_node_getter(matched_node_pattern)
+    input_node = root_node
+    output_node = matched_node_pattern[0]
+    for dtype_config in dtype_configs:
+        # check if arg dtype are supported
+        supported = True
+        for arg in list(input_node.args) + list(input_node.kwargs.values()):
+            supported = supported and _is_input_arg_dtype_supported_by_backend(
+                arg, input_node, qconfig, dtype_config, backend_config
+            )
+        # check if output dtype is supported
+        supported = supported and _is_output_dtype_supported_by_backend(
+            output_node, qconfig, dtype_config
+        )
+        if supported:
+            return True
+    return False
+
+
+def _get_standalone_module_configs(
+    node: Node,
+    named_modules: dict[str, torch.nn.Module],
+    prepare_custom_config: PrepareCustomConfig,
+    parent_qconfig: QConfigAny,
+    parent_backend_config: Optional[BackendConfig],
+) -> tuple[
+    QConfigMapping, tuple[Any, ...], PrepareCustomConfig, Optional[BackendConfig]
+]:
+    """
+    Returns the standalone module QConfigMapping, example inputs,
+    PrepareCustomConfig and BackendConfig for `node`, assuming that the
+    module pointed to by `node` is a standalone module.
+    """
+    module_name = str(node.target)
+    module_type = type(named_modules[module_name])  # type: ignore[index]
+    # name config has precedence over type config
+    config_entry = StandaloneModuleConfigEntry(None, (), None, None)
+    config_entry = prepare_custom_config.standalone_module_classes.get(
+        module_type, config_entry
+    )
+    config_entry = prepare_custom_config.standalone_module_names.get(
+        module_name, config_entry
+    )
+    # fallback to use parent module's qconfig if user didn't specify qconfig dict
+    qconfig_mapping = config_entry.qconfig_mapping or QConfigMapping().set_global(
+        parent_qconfig
+    )
+    example_inputs = config_entry.example_inputs
+    prepare_custom_config = config_entry.prepare_custom_config or PrepareCustomConfig()
+    backend_config = config_entry.backend_config or parent_backend_config
+    return (qconfig_mapping, example_inputs, prepare_custom_config, backend_config)
+
+
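+# Illustrative sketch (hypothetical, never called): the kind of configuration
+# that `_get_standalone_module_configs` resolves per node. The submodule name
+# "submod" and the "fbgemm" qconfig are assumptions made up for the example;
+# name-based entries take precedence over class-based ones.
+def _example_standalone_module_config() -> None:
+    parent_qconfig_mapping = QConfigMapping().set_global(
+        torch.ao.quantization.get_default_qconfig("fbgemm")
+    )
+    prepare_custom_config = PrepareCustomConfig().set_standalone_module_name(
+        "submod",                # prepare this submodule on its own
+        parent_qconfig_mapping,  # qconfig_mapping used inside the submodule
+        (torch.randn(1, 4),),    # example inputs for the submodule
+        None,                    # nested PrepareCustomConfig (fall back to default)
+        None,                    # BackendConfig (fall back to the parent's)
+    )
+    assert "submod" in prepare_custom_config.standalone_module_names
+
+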
+def _qat_swap_modules(
+    root: torch.nn.Module, module_to_qat_module: dict[Pattern, type[torch.nn.Module]]
+) -> None:
+    convert(root, mapping=module_to_qat_module, inplace=True, remove_qconfig=False)
+
+
+def _add_matched_node_name_to_set(matched_node_pattern: NodePattern, s: set[str]):
+    if isinstance(matched_node_pattern, Node):
+        s.add(matched_node_pattern.name)
+    elif isinstance(matched_node_pattern, (list, tuple)):
+        for maybe_node in matched_node_pattern:
+            _add_matched_node_name_to_set(maybe_node, s)
+
+
+def _insert_obs_or_fq(
+    node: Node,
+    obs_or_fq: ObserverOrFakeQuantize,
+    model: torch.nn.Module,
+    named_modules: dict[str, torch.nn.Module],
+    graph: Graph,
+) -> Node:
+    """
+    Attaches `obs_or_fq` to `model`, and creates a node which calls
+    `obs_or_fq` on the output of `node`.
+
+    obs_or_fq: an instance of Observer or FakeQuantize module
+    """
+    model_device = assert_and_get_unique_device(model)
+    if model_device:
+        obs_or_fq.to(model_device)
+    # add obs_or_fq module as attribute
+    if is_equalization_observer(obs_or_fq):
+        prefix = node.name + "_equalization_process_"
+    else:
+        prefix = "activation_post_process_"
+    get_new_obs_or_fq_name = get_new_attr_name_with_prefix(prefix)
+    obs_or_fq_name = get_new_obs_or_fq_name(model)
+    setattr(model, obs_or_fq_name, obs_or_fq)
+    named_modules[obs_or_fq_name] = obs_or_fq
+    with graph.inserting_after(node):
+        new_obs = graph.create_node("call_module", obs_or_fq_name, (node,), {})
+    return new_obs
+
+
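+# Illustrative sketch (hypothetical, never called): inserting an observer after
+# a node of a tiny traced module with `_insert_obs_or_fq`. The toy module and
+# the choice of MinMaxObserver are assumptions made for the example.
+def _example_insert_obs_or_fq() -> None:
+    from torch.ao.quantization.observer import MinMaxObserver
+
+    class _TinyModule(torch.nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.linear = torch.nn.Linear(4, 4)
+
+        def forward(self, x):
+            return self.linear(x)
+
+    gm = torch.fx.symbolic_trace(_TinyModule())
+    named_modules = dict(gm.named_modules(remove_duplicate=False))
+    linear_node = next(n for n in gm.graph.nodes if n.op == "call_module")
+    obs_node = _insert_obs_or_fq(
+        linear_node, MinMaxObserver(), gm, named_modules, gm.graph
+    )
+    # the observer was attached to `gm` and the new node calls it on linear's output
+    assert obs_node.op == "call_module" and obs_node.args == (linear_node,)
+
+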
+def _set_target_dtype_info_for_matched_node_pattern(
+    matched_node_pattern: NodePattern,
+    last_node: Node,
+    qconfig: QConfigAny,
+    qhandler: Optional[QuantizeHandler],
+    backend_config: BackendConfig,
+    named_modules: dict[str, torch.nn.Module],
+    cache_for_no_tensor_check: dict[Node, bool],
+    processed_nodes: set[Node],
+) -> None:
+    """Sets the target_dtype_info for each node in matched_node_pattern
+    Note: processed_nodes is used to ensure we only process each node once
+    """
+    if isinstance(matched_node_pattern, (list, tuple)):
+        for node_pattern in matched_node_pattern:
+            _set_target_dtype_info_for_matched_node_pattern(
+                node_pattern,
+                last_node,
+                qconfig,
+                qhandler,
+                backend_config,
+                named_modules,
+                cache_for_no_tensor_check,
+                processed_nodes,
+            )
+
+    # set target_dtype_info if matched_node_pattern is a Node
+    # other types of matched object, e.g. int, float literals, are ignored
+    elif isinstance(matched_node_pattern, Node):
+        # for pyre
+        assert isinstance(matched_node_pattern, Node)
+        node = matched_node_pattern
+        if node in processed_nodes:
+            return
+        processed_nodes.add(node)
+
+        if qconfig is None:
+            return
+        # TODO: refactor the following code in terms of apply a qconfig to a pattern
+        # e.g. for a pattern with op1 -> op2 -> op3, and qconfig = QConfig(input_act=obs0, output_act=obs1)
+        # we set the input_obs_or_fq_ctr for the arguments of op1 based on qconfig.input_act,
+        # and set output_obs_or_fq_ctr based on qconfig.output_act
+        # this also requires we extend the structure of QConfig to support more fine
+        # grained configurations
+        target_dtype_info: dict[str, Any] = _get_target_activation_dtype_for_node(
+            node,
+            qconfig,
+            qhandler,
+            named_modules,
+            backend_config,
+            cache_for_no_tensor_check,
+        )
+        node.meta["target_dtype_info"] = target_dtype_info
+
+
+def _get_target_activation_dtype_for_node(
+    node: Node,
+    qconfig: QConfigAny,
+    qhandler: Optional[QuantizeHandler],
+    named_modules: dict[str, torch.nn.Module],
+    backend_config: BackendConfig,
+    cache_for_no_tensor_check: dict[Node, bool],
+) -> dict[str, Any]:
+    """
+    For each op attribute in the op's input activation, output activation,
+    weight, bias - returns the settings of dtype and is_dynamic we expect
+    for the `quantize` call in the reference model representation, or None
+    if there is no `quantize` call needed.
+
+    For example, if we have a node corresponding to `op0` in
+
+      x0 -> op0 -> x1
+
+    And we want a reference quantized representation to be
+
+      x0 -> quant_static -> dequant -> op0 -> quant_dynamic -> dequant -> x1
+
+    Then this function will return
+
+      {
+        "input_act_obs_or_fq_ctr": MinMaxObserver.with_args(dtype=torch.quint8, is_dynamic=False),
+        "output_act_obs_or_fq_ctr": MinMaxObserver.with_args(dtype=torch.quint8, is_dynamic=False),
+      }
+
+    TODO(future PR, if needed): explicitly spell out the non-Tensor
+    dtypes.
+    """
+    args_have_no_tensors = all_node_args_have_no_tensors(
+        node, named_modules, cache_for_no_tensor_check
+    )
+    if args_have_no_tensors:
+        return {
+            "input_act_obs_or_fq_ctr": None,
+            "output_act_obs_or_fq_ctr": None,
+        }
+    # get qconfig to determine the eventual dtype of this node
+    if qconfig is not None:
+        act_dtype, weight_dtype, input_act_is_dynamic = get_qconfig_dtypes(qconfig)
+
+        # Currently `QConfig` only has one `activation` field.
+        # For static quantization, it is reused for both input
+        # and output activation. For dynamic quantization, this
+        # field is currently only used for the input activation,
+        # with the output activation being in fp32.
+        # In the future this may change as we add more fields
+        # to the `QConfig` object.
+        bias_dtype = (
+            torch.float16
+            if (
+                act_dtype == torch.float16
+                and weight_dtype == torch.float16
+                and (not input_act_is_dynamic)
+            )
+            else torch.float
+        )
+
+        is_general_tensor_value_op = (
+            qhandler is not None and qhandler.is_general_tensor_value_op()
+        )
+
+        _is_standalone_module = qhandler is not None and qhandler.is_standalone_module()
+
+        weight_index = None
+        if (
+            isinstance(node, Node)
+            and node.op == "call_function"
+            and node.target in backend_config._pattern_complex_format_to_config
+        ):
+            weight_index = backend_config._pattern_complex_format_to_config[
+                node.target
+            ]._input_type_to_index.get("weight")
+
+        bias_index = None
+        if (
+            isinstance(node, Node)
+            and node.op == "call_function"
+            and node.target in backend_config._pattern_complex_format_to_config
+        ):
+            bias_index = backend_config._pattern_complex_format_to_config[
+                node.target
+            ]._input_type_to_index.get("bias")
+
+        return {
+            "input_act_obs_or_fq_ctr": qconfig.activation,
+            "weight_obs_or_fq_ctr": qconfig.weight,
+            "bias_obs_or_fq_ctr": PlaceholderObserver.with_args(dtype=bias_dtype),
+            "weight_index": weight_index,
+            "bias_index": bias_index,
+            "output_act_obs_or_fq_ctr": qconfig.activation,
+            "reuse_input_obs_or_fq": _is_reuse_input_qconfig(qconfig),
+            "input_output_share_observers": is_general_tensor_value_op,
+            "_is_standalone_module": _is_standalone_module,
+        }
+    return copy.copy(_DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO)
+
+
+def _get_output_act_obs_or_fq(
+    arg: Node,
+    named_modules: dict[str, torch.nn.Module],
+    obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
+    is_qat: bool,
+) -> Optional[ObserverOrFakeQuantize]:
+    """Get the observer or fake quant instance for the argument `arg`,
+    viewed as the output of the previous node in the original graph,
+    skipping any inserted observers
+
+    We assume that the observers are inserted correctly, and that the dtype of
+    the argument in the quantized graph will match what is specified by the qconfig
+    """
+    assert isinstance(arg, Node)
+    if "quantization_annotation" in arg.meta:
+        return _create_obs_or_fq_from_qspec(
+            arg.meta["quantization_annotation"].output_qspec, obs_or_fq_map, is_qat
+        )
+
+    # Custom module LSTM output is a tuple that we broke down into the internal nodes in order
+    # to insert DeQuantStubs (see `_insert_dequant_stubs_for_custom_module_lstm_output`).
+    # Since we modified the graph in this case, we must trace back from the args through
+    # the specific nodes we added in order to reach the original LSTM node. Otherwise, we would
+    # not be able to accurately detect whether this node is a consumer of custom module LSTM.
+    custom_module_lstm_node = _maybe_get_custom_module_lstm_from_node_arg(
+        arg, named_modules
+    )
+    output_act_obs_or_fq_ctr = None
+    if custom_module_lstm_node is not None:
+        output_act_obs_or_fq_ctr = custom_module_lstm_node.meta["target_dtype_info"][
+            "output_act_obs_or_fq_ctr"
+        ]
+        output_act_obs_or_fq = (
+            output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None
+        )
+    elif _is_activation_post_process_node(arg, named_modules):
+        observed_arg = arg.args[0]
+        assert isinstance(observed_arg, Node), (
+            "Currently we only support observing Node"
+        )
+        if "quantization_annotation" in observed_arg.meta:
+            output_act_obs_or_fq = _create_obs_or_fq_from_qspec(
+                observed_arg.meta["quantization_annotation"].output_qspec,
+                obs_or_fq_map,
+                is_qat,
+            )
+        else:
+            assert "target_dtype_info" in observed_arg.meta
+            output_act_obs_or_fq_ctr = observed_arg.meta["target_dtype_info"][
+                "output_act_obs_or_fq_ctr"
+            ]
+            output_act_obs_or_fq = (
+                output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None
+            )
+    else:
+        if "target_dtype_info" in arg.meta:
+            output_act_obs_or_fq_ctr = arg.meta["target_dtype_info"].get(
+                "output_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR
+            )
+        else:
+            output_act_obs_or_fq_ctr = _DEFAULT_FP32_OBS_OR_FQ_CTR
+        output_act_obs_or_fq = (
+            output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None
+        )
+
+    return output_act_obs_or_fq
+
+
+def _get_arg_target_dtype_as_output(
+    arg: Node,
+    named_modules: dict[str, torch.nn.Module],
+    obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
+    is_qat: bool,
+) -> Optional[torch.dtype]:
+    arg_as_output_act_obs_or_fq = _get_output_act_obs_or_fq(
+        arg, named_modules, obs_or_fq_map, is_qat
+    )
+    arg_as_output_target_dtype, _ = _get_dtype_and_is_dynamic(
+        arg_as_output_act_obs_or_fq
+    )
+    return arg_as_output_target_dtype
+
+
+def _get_arg_as_input_act_obs_or_fq(
+    arg: Node,
+    node: Node,
+    named_modules: dict[str, torch.nn.Module],
+    obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
+    is_qat: bool,
+) -> Optional[ObserverOrFakeQuantize]:
+    """Get the observer or fake quant instance for the Argument `arg`, used as an
+    input to Node `node`
+    """
+    assert isinstance(arg, Node)
+    # "input_qspec_map" is the more general design we'll use for pt2e path
+    # it is a map from input argument node to observer or fake quant constructor, for example
+    # for the following graph:
+    # x -> conv -> output
+    #
+    # we may annotate conv node like the following:
+    # conv.meta[...] = QuantizationAnnotation("input_qspec_map": {x: MinMaxObserver.with_args(dtype=torch.qint8)}, ...)
+    #
+    if "quantization_annotation" in node.meta:
+        input_qspec_map = node.meta["quantization_annotation"].input_qspec_map
+        input_arg_qspec = _get_qspec_for_arg(arg, input_qspec_map, named_modules)
+        if input_arg_qspec is None:
+            input_arg_obs_or_fq = _DEFAULT_FP32_OBS_OR_FQ_CTR()
+        else:
+            input_arg_obs_or_fq = _create_obs_or_fq_from_qspec(
+                input_arg_qspec, obs_or_fq_map, is_qat
+            )
+        return input_arg_obs_or_fq
+
+    # we can remove the following path in the future if fx graph mode quantization is
+    # no longer used
+    is_weight = node_arg_is_weight(node, arg)
+    is_bias = node_arg_is_bias(node, arg)
+    is_activation = not is_weight and not is_bias
+    obs_or_fq_ctr = None
+    if is_activation:
+        obs_or_fq_ctr = node.meta["target_dtype_info"].get(
+            "input_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR
+        )
+    elif is_weight:
+        if node.target not in NON_QUANTIZABLE_WEIGHT_OPS:
+            obs_or_fq_ctr = node.meta["target_dtype_info"].get(
+                "weight_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR
+            )
+    else:
+        obs_or_fq_ctr = node.meta["target_dtype_info"].get(
+            "bias_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR
+        )
+    return obs_or_fq_ctr() if obs_or_fq_ctr else None
+
+
+def _maybe_insert_input_observer_for_arg_or_kwarg(
+    node: Union[Node, Any],
+    arg: Argument,
+    qconfig: QConfigAny,
+    model: torch.nn.Module,
+    named_modules: dict[str, torch.nn.Module],
+    graph: Graph,
+    qhandler: Optional[QuantizeHandler],
+    prepare_custom_config: PrepareCustomConfig,
+    obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
+    is_qat: bool,
+    backend_config: Optional[BackendConfig] = None,
+) -> Argument:
+    """
+    Given a `node` and an `arg`, inserts an input observer between
+    `node` and `arg` if necessary.
+    """
+    # for ops such as torch.cat([x0, x1]),
+    # traverse through the list
+    if isinstance(arg, (list, tuple)):
+        new_arg_to_return = []
+        for inner_arg in arg:
+            new_inner_arg = _maybe_insert_input_observer_for_arg_or_kwarg(
+                node,
+                inner_arg,
+                qconfig,
+                model,
+                named_modules,
+                graph,
+                qhandler,
+                prepare_custom_config,
+                obs_or_fq_map,
+                is_qat,
+                backend_config,
+            )
+            new_arg_to_return.append(new_inner_arg)
+        return type(arg)(new_arg_to_return)
+
+    if not isinstance(arg, Node):
+        return arg
+    assert isinstance(arg, Node)
+    # default (no observer)
+    new_arg = arg
+
+    is_standalone_module = qhandler is not None and qhandler.is_standalone_module()
+    # TODO: move this to a separate function
+    if not is_standalone_module:
+        # Note: qconfig can be None in this branch since we are getting act/fq from
+        # node.meta now
+        # regular flow for most nodes, except standalone modules
+
+        if "quantization_annotation" in node.meta:
+            reuse_input_obs_or_fq = node.meta[
+                "quantization_annotation"
+            ]._reuse_input_obs_or_fq
+        else:
+            assert "target_dtype_info" in node.meta
+            # TODO: we are assuming "target_dtype_info" exists here, maybe
+            # a default value also needs to be provided here
+            target_dtype_info = node.meta["target_dtype_info"]
+            # for nodes that don't have `reuse_input_obs_or_fq` configured,
+            # we'll default to False; this makes configuring this field optional for users
+            reuse_input_obs_or_fq = target_dtype_info.get(
+                "reuse_input_obs_or_fq", False
+            )
+        arg_as_input_act_obs_or_fq = _get_arg_as_input_act_obs_or_fq(
+            arg, node, named_modules, obs_or_fq_map, is_qat
+        )
+        (
+            arg_as_input_target_dtype,
+            arg_as_input_target_is_dynamic,
+        ) = _get_dtype_and_is_dynamic(arg_as_input_act_obs_or_fq)
+
+        arg_as_output_act_obs_or_fq = _get_output_act_obs_or_fq(
+            arg, named_modules, obs_or_fq_map, is_qat
+        )
+        (
+            arg_as_output_target_dtype,
+            arg_as_output_target_is_dynamic,
+        ) = _get_dtype_and_is_dynamic(arg_as_output_act_obs_or_fq)
+
+        needs_obs_or_fq = _needs_obs_or_fq(
+            arg_as_output_target_dtype,
+            arg_as_output_target_is_dynamic,
+            arg_as_input_target_dtype,
+            arg_as_input_target_is_dynamic,
+            reuse_input_obs_or_fq,
+            is_zeroth_arg=len(node.args) > 0 and arg is node.args[0],
+        )
+
+    else:
+        assert qconfig is not None
+        # custom flow for standalone modules
+        _, _, sm_prepare_custom_config, _ = _get_standalone_module_configs(
+            node, named_modules, prepare_custom_config, qconfig, backend_config
+        )
+        sm_input_quantized_idxs = sm_prepare_custom_config.input_quantized_indexes
+
+        # for args, this is set to the index of the current arg
+        # for kwargs, this is left at None
+        cur_input_idx = None
+        for arg_idx, arg_to_check in enumerate(node.args):
+            if arg_to_check is arg:
+                cur_input_idx = arg_idx
+                break
+
+        if cur_input_idx is None:
+            needs_obs_or_fq = False
+        else:
+            arg_as_output_target_dtype = _get_arg_target_dtype_as_output(
+                arg, named_modules, obs_or_fq_map, is_qat
+            )
+            arg_as_input_target_dtype = (
+                torch.quint8
+                if cur_input_idx in sm_input_quantized_idxs
+                else torch.float
+            )
+            needs_obs_or_fq = (
+                arg_as_output_target_dtype != arg_as_input_target_dtype
+            ) and (arg_as_input_target_dtype != torch.float)
+
+        act_post_process_ctr = qconfig.activation
+        arg_as_input_act_obs_or_fq = (
+            act_post_process_ctr() if act_post_process_ctr else None
+        )
+
+    if needs_obs_or_fq:
+        existing_obs_node = None
+
+        # Before using the new observer, check if an observer
+        # of the correct type already exists. If it does, use it.
+        # This prevents duplicate observer insertions if a node is
+        # used by multiple nodes.
+        # TODO: this looks ahead at how the value is used downstream; we should
+        # remove this
+        # removing this means we insert one observer for each use, even if they
+        # have the same dtype; we can have an extra pass that removes the extra observers
+        for maybe_obs_node in arg.users.keys():
+            if maybe_obs_node.op == "call_module":
+                maybe_obs_mod = named_modules[maybe_obs_node.target]  # type: ignore[index]
+                if (
+                    type(maybe_obs_mod) == type(arg_as_input_act_obs_or_fq)
+                    and maybe_obs_mod.dtype == arg_as_input_target_dtype  # type: ignore[possibly-undefined]
+                ):
+                    arg_as_input_act_obs_or_fq = maybe_obs_mod  # type: ignore[assignment]
+                    existing_obs_node = maybe_obs_node
+                    break
+
+        assert arg_as_input_act_obs_or_fq is not None
+        obs_or_fq_map[(arg, node)] = arg_as_input_act_obs_or_fq
+        if existing_obs_node is None:
+            new_obs_node = _insert_obs_or_fq(
+                arg, arg_as_input_act_obs_or_fq, model, named_modules, graph
+            )
+            # override this arg to be the observed arg
+            new_arg = new_obs_node
+        else:
+            new_arg = existing_obs_node
+
+    return new_arg
+
+
+def _maybe_insert_input_observers_for_node(
+    node: Node,
+    qconfig: QConfigAny,
+    model: torch.nn.Module,
+    named_modules: dict[str, torch.nn.Module],
+    graph: Graph,
+    qhandler: Optional[QuantizeHandler],
+    prepare_custom_config: PrepareCustomConfig,
+    obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
+    is_qat: bool,
+    backend_config: Optional[BackendConfig] = None,
+) -> None:
+    """
+    If needed, inserts observers to the input args and kwargs of `node`.
+    Note: modifies `node` inplace.
+
+    For example, if cur_node needs an observer after prev_node, we change from
+
+      prev_node -> cur_node
+
+    To
+
+      prev_node -> obs -> cur_node
+
+    Note: backend_config only needed for standalone_module node
+    """
+    # Look through every input arg.  If that arg's target dtype does not
+    # match the current node's target dtype, insert an observer.
+    new_args = []
+    for arg in node.args:
+        new_arg = _maybe_insert_input_observer_for_arg_or_kwarg(
+            node,
+            arg,
+            qconfig,
+            model,
+            named_modules,
+            graph,
+            qhandler,
+            prepare_custom_config,
+            obs_or_fq_map,
+            is_qat,
+            backend_config,
+        )
+        new_args.append(new_arg)
+
+    new_kwargs = {}
+    for k, kwarg in node.kwargs.items():
+        new_kwarg = _maybe_insert_input_observer_for_arg_or_kwarg(
+            node,
+            kwarg,
+            qconfig,
+            model,
+            named_modules,
+            graph,
+            qhandler,
+            prepare_custom_config,
+            obs_or_fq_map,
+            is_qat,
+            backend_config,
+        )
+        new_kwargs[k] = new_kwarg
+
+    # assign the new args and kwargs to the node, inplace
+    node.args = tuple(new_args)
+    node.kwargs = new_kwargs
+
+
+def _maybe_insert_input_equalization_observers_for_node(
+    node: Node,
+    equalization_qconfig: Any,
+    model: torch.nn.Module,
+    named_modules: dict[str, torch.nn.Module],
+    graph: Graph,
+    is_branch: bool,
+) -> None:
+    """
+    If `node` needs to be equalized, finds the input/weight observers it needs in
+    `equalization_qconfig`, creates them, and inserts them into `graph`.
+
+    If `node` does not need an equalization observer, returns None.
+    """
+    if equalization_qconfig is None or not node_supports_equalization(
+        node, named_modules
+    ):
+        return
+
+    if is_branch:
+        warnings.warn(f"Cannot equalize {node} because it is part of a branch.")
+        return
+
+    new_args = []
+    for arg in node.args:
+        if not isinstance(arg, Node) or node_arg_is_bias(node, arg):
+            new_args.append(arg)
+            continue
+
+        is_weight = node_arg_is_weight(node, arg)
+
+        act_eq_process_ctr = (
+            equalization_qconfig.weight
+            if is_weight
+            else equalization_qconfig.input_activation
+        )
+
+        new_eq_obs_mod = act_eq_process_ctr()
+        new_eq_obs_node = _insert_obs_or_fq(
+            arg, new_eq_obs_mod, model, named_modules, graph
+        )
+
+        new_args.append(new_eq_obs_node)
+
+    # assign the new args and kwargs to the node, inplace
+    node.args = tuple(new_args)
+
+
+def _maybe_insert_output_observer_for_node(
+    node: Node,
+    model: torch.nn.Module,
+    named_modules: dict[str, torch.nn.Module],
+    graph: Graph,
+    obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
+    is_qat: bool,
+) -> Optional[Node]:
+    """
+    If `node` needs an output observer, creates it, inserts it into `graph`
+    and returns it.
+
+    If `node` does not need an output observer, returns None.
+
+    Note: inserting dynamic quantization ops for output is not supported in fx graph mode
+    quantization code path right now
+    """
+    assert node.op != "output", "observer insertion for outputs is handled elsewhere"
+
+    is_standalone_module = False
+    if "quantization_annotation" in node.meta:
+        output_act_obs_or_fq = _create_obs_or_fq_from_qspec(
+            node.meta["quantization_annotation"].output_qspec, obs_or_fq_map, is_qat
+        )
+    else:
+        assert "target_dtype_info" in node.meta
+        is_standalone_module = node.meta["target_dtype_info"].get(
+            "_is_standalone_module", False
+        )
+        output_act_obs_or_fq_ctr = node.meta["target_dtype_info"].get(
+            "output_act_obs_or_fq_ctr"
+        )
+        output_act_obs_or_fq = (
+            output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None
+        )
+    target_dtype, target_is_dynamic = _get_dtype_and_is_dynamic(output_act_obs_or_fq)
+    # uncomment after we support reuse_input_obs_or_fq properly by having separate
+    # implementations for this key instead of reusing the input_output_share_observers
+    # code
+    # reuse_input_obs_or_fq = node.meta["target_dtype_info"].get("reuse_input_obs_or_fq", False)
+    # for now we set this to False since reuse_input_obs_or_fq for
+    # the output of a node is implemented in the same code path as observer sharing;
+    # we should refactor this part to make it clearer in the future
+    # so that we can read this from the config directly
+    reuse_input_obs_or_fq = False
+
+    # Note: prev_output_dtype = torch.float and prev_output_is_dynamic=False
+    # because the prev_output is the output of an fp32 op, although technically
+    # we should get the dtype of the output from node.meta["val"] in the future
+    # if we deprecate fx graph mode quantization
+    needs_obs_or_fq = _needs_obs_or_fq(
+        torch.float, False, target_dtype, target_is_dynamic, reuse_input_obs_or_fq
+    )
+    # currently the activation in QConfig(activation=...,) is for both input
+    # and output, and when the activation is configured to be dynamic quantization
+    # e.g. PlaceholderObserver(dtype=torch.quint8, is_dynamic=True, ...), it means
+    # the input should be dynamically quantized, but the output should not be quantized
+    #
+    # there is no way we can specify different observer/fq for input and output
+    # activation through QConfig today; this limitation is lifted in the
+    # quantizer/annotation API in the pytorch 2.0 export quantization code path,
+    # but since this code is reused there, annotating the output to be dynamically
+    # quantized would not work for that path either.
+    # we can change QConfig to support input/output activation if we want
+    # to remove the following check, or if we can deprecate fx graph mode quantization
+    if target_is_dynamic:
+        needs_obs_or_fq = False
+
+    # we never insert observers at the output of a standalone module; we assume
+    # that, if needed, they are inserted inside the standalone module
+    needs_obs_or_fq = needs_obs_or_fq and (not is_standalone_module)
+
+    if needs_obs_or_fq:
+        obs_or_fq_map[node] = output_act_obs_or_fq
+        return _insert_obs_or_fq(
+            node, output_act_obs_or_fq, model, named_modules, graph
+        )
+    else:
+        return None
+
+
+def _maybe_insert_observers_before_graph_output(
+    graph_output_node: Node,
+    model: torch.nn.Module,
+    named_modules: dict[str, torch.nn.Module],
+    graph: Graph,
+    obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
+    is_qat: bool,
+) -> None:
+    """
+    If the output needs to be quantized and there are any nodes
+    in the output which are not already observed, inserts observers
+    for those nodes.
+    """
+
+    def _recursive_maybe_replace_node_with_obs(
+        maybe_node: Argument,
+        model: torch.nn.Module,
+        named_modules: dict[str, torch.nn.Module],
+        graph: Graph,
+    ) -> Argument:
+        """
+        Navigate an arbitrary data structure of lists, tuples, dicts.
+        For each container type, recurse on all inputs. Once any Node
+        is found, insert an observer if needed and do not recurse further.
+
+        For example, given a structure of
+
+          {'foo1': [[bar1]], 'foo2': {'foo3': [[[bar3]]]}}
+
+        we recurse down to bar1 and bar3, observe them if necessary,
+        and if we inserted an observer then replace the original node
+        with its observer.
+
+        Returns the data structure with all nodes needing observation being
+        replaced by their observers.
+        """
+        if isinstance(maybe_node, Node):
+            # check dtype of this node
+            arg_as_output_target_dtype = _get_arg_target_dtype_as_output(
+                maybe_node, named_modules, obs_or_fq_map, is_qat
+            )
+            observer_mod = None
+            arg_as_input_target_dtype = torch.float
+            if "target_dtype_info" in maybe_node.meta:
+                observer_cls = maybe_node.meta["target_dtype_info"].get(
+                    "input_act_obs_or_fq_ctr", None
+                )
+                if observer_cls is not None:
+                    observer_mod = observer_cls()
+                    arg_as_input_target_dtype = observer_mod.dtype
+            # TODO: this does not handle dynamic quantization yet
+            need_obs = (
+                arg_as_output_target_dtype != arg_as_input_target_dtype
+                and arg_as_input_target_dtype != torch.float
+            )
+            if need_obs:
+                assert observer_mod is not None
+                # insert observer
+                observer_node = _insert_obs_or_fq(
+                    maybe_node, observer_mod, model, named_modules, graph
+                )
+                return observer_node
+            else:
+                return maybe_node
+        elif isinstance(maybe_node, (list, tuple)):
+            results = [
+                _recursive_maybe_replace_node_with_obs(
+                    inner_node, model, named_modules, graph
+                )
+                for inner_node in maybe_node
+            ]
+            if isinstance(maybe_node, list):
+                return results
+            else:
+                return tuple(results)
+        elif isinstance(maybe_node, dict):
+            results_dict = {}
+            for k, inner_v in maybe_node.items():
+                results_dict[k] = _recursive_maybe_replace_node_with_obs(
+                    inner_v, model, named_modules, graph
+                )
+            return results_dict
+        elif maybe_node is None:
+            return None
+        else:
+            raise Exception(  # noqa: TRY002
+                "Unhandled type for returned node:", maybe_node
+            )
+
+    new_args = [
+        _recursive_maybe_replace_node_with_obs(old_arg, model, named_modules, graph)
+        for old_arg in graph_output_node.args
+    ]
+
+    graph_output_node.args = tuple(new_args)  # type: ignore[assignment]
+
+
+def _maybe_propagate_dtype_for_node(
+    node: Node,
+    target_dtype: Union[torch.dtype, type],
+    node_name_to_match_result_with_qconfig: dict[str, _MatchResultWithQConfig],
+) -> None:
+    """
+    Marks `node` as having a known, non-quantizable `target_dtype` by clearing its
+    input/output observer constructors. If `node` is a general tensor shape op,
+    also calls this function recursively on the first argument, to propagate the
+    dtype to the producer.
+    """
+    node.meta["target_dtype_info"]["input_act_obs_or_fq_ctr"] = None
+    node.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"] = None
+    # if this is a copy node, propagate to first arg
+    (
+        _root_node,
+        _,
+        _pattern,
+        qhandler,
+        _qconfig,
+    ) = node_name_to_match_result_with_qconfig.get(
+        node.name, (None, None, None, None, None)
+    )
+    # TODO: probably need to remove `is_general_tensor_value_op`
+    if qhandler is not None and qhandler.is_general_tensor_value_op():
+        prev_node = node.args[0]
+        if isinstance(prev_node, Node):
+            _maybe_propagate_dtype_for_node(
+                prev_node, target_dtype, node_name_to_match_result_with_qconfig
+            )
+
+
+def propagate_dtypes_for_known_nodes(
+    graph: Graph,
+    node_name_to_match_result_with_qconfig: dict[str, _MatchResultWithQConfig],
+) -> None:
+    """
+    Currently we assume that inputs to the graph are either `torch.float` or
+    `torch.quint8`, which is not always correct. For ops such as
+    `x.masked_fill(mask, value)`, we know that the dtype of  `mask` is a
+    `BoolTensor`. Propagate this information throughout the graph.
+
+    Note: not all dtypes in the graph will be correct after this pass, but a
+    higher percentage of them will be correct. Hopefully in the future we can
+    replace this with a better way to reason about dtypes of tensors.
+    """
+    for node in graph.nodes:
+        non_observable_arg_dict = get_non_observable_arg_indexes_and_types(node)
+
+        for arg_type in non_observable_arg_dict:
+            non_observable_indices = non_observable_arg_dict[arg_type](node)
+
+            for index in non_observable_indices:
+                arg = node.args[index]
+
+                # when an argument is a tuple, it does not show up as another node so we need to go through
+                # all elements of the tuple manually
+                if isinstance(arg, (tuple, list)):
+                    arg_list = list(arg)
+                else:
+                    arg_list = [arg]
+
+                for cur_arg in arg_list:
+                    # hard coded arguments show up but aren't `Node` typed and do not need dtype propagated
+                    if isinstance(cur_arg, torch.fx.node.Node):
+                        _maybe_propagate_dtype_for_node(
+                            cur_arg, arg_type, node_name_to_match_result_with_qconfig
+                        )
+
+
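+# Illustrative sketch (hypothetical, never called): for `masked_fill`, the mask
+# argument is known to be a bool tensor, so `propagate_dtypes_for_known_nodes`
+# clears its observer constructors. The toy module below is an assumption made
+# up for the example; an empty match-result dict is passed for simplicity.
+def _example_propagate_dtypes_for_known_nodes() -> None:
+    class _MaskedFill(torch.nn.Module):
+        def forward(self, x, mask):
+            return x.masked_fill(mask, 1.0)
+
+    gm = torch.fx.symbolic_trace(_MaskedFill())
+    for n in gm.graph.nodes:
+        n.meta["target_dtype_info"] = copy.copy(
+            _DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO
+        )
+    propagate_dtypes_for_known_nodes(gm.graph, {})
+    mask_node = next(
+        n for n in gm.graph.nodes if n.op == "placeholder" and n.target == "mask"
+    )
+    # the boolean mask will not get an observer inserted for it
+    assert mask_node.meta["target_dtype_info"]["input_act_obs_or_fq_ctr"] is None
+
+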
+def _maybe_make_input_output_share_observers(
+    node: Node,
+    model: torch.nn.Module,
+    named_modules: dict[str, torch.nn.Module],
+) -> bool:
+    """
+    Ensures that we share an observer
+    for all input arguments as well as the output argument. In detail, given
+    a graph of
+
+      x0 -> obs0 -> op -> x2
+                  /
+      x1 -> obs1 /
+
+    where node obs0 points to observer instance observer0, obs1 points to
+    observer1, and obs2 (the output observer of `op`, not shown above) points to
+    observer2, we make nodes obs1 and obs2 point to observer0.
+    Returns: whether the operation succeeded or not
+    """
+    first_arg = None
+    # find the first arg that is a Node (or a list/tuple of Nodes)
+    for i in range(len(node.args)):
+        if isinstance(node.args[i], (Node, list, tuple)):
+            first_arg = node.args[i]
+            break
+
+    # if there is no such arg, return directly
+    if first_arg is None:
+        return False
+
+    if isinstance(first_arg, (list, tuple)):
+        first_arg_arg = first_arg[0]
+    elif isinstance(first_arg, Node):
+        first_arg_arg = first_arg
+    else:
+        return False
+
+    # if we have a graph such as
+    #   observed_node -> non_observed_node -> cat
+    # we need to navigate up to the first observer
+    iteration_guard = 0
+    while not _is_activation_post_process_node(first_arg_arg, named_modules):
+        if not isinstance(first_arg_arg, Node):
+            return False
+        # did not find an activation_post_process for the op
+        if first_arg_arg.op == "placeholder":
+            return False
+        # trace back the args until we find the first Node arg
+        trace_back_node = None
+        for i in range(len(first_arg_arg.args)):
+            trace_back_node = first_arg_arg.args[i]
+            if isinstance(trace_back_node, Node):
+                break
+        if trace_back_node is None:
+            return False
+        first_arg_arg = trace_back_node
+
+        iteration_guard += 1
+        if iteration_guard > 10000:
+            raise AssertionError("Unable to find observer of previous node")
+
+    assert isinstance(first_arg_arg, Node)
+    target_to_use = first_arg_arg.target
+    assert isinstance(target_to_use, str)
+    obs_mod_to_use = named_modules[target_to_use]
+
+    if isinstance(first_arg, (list, tuple)):
+        # set all other input observer nodes to use that module
+        for input_idx, input_arg in enumerate(first_arg):
+            if input_idx == 0:
+                continue
+            iteration_guard = 0
+            while not _is_activation_post_process_node(input_arg, named_modules):
+                # failed to trace back since no input arg for the current node
+                if len(input_arg.args) < 1:
+                    return False
+                input_arg = input_arg.args[0]
+                iteration_guard += 1
+                if iteration_guard > 10000:
+                    raise AssertionError("Unable to find observer of previous node")
+
+            parent_name, name = _parent_name(input_arg.target)
+            setattr(named_modules[parent_name], name, obs_mod_to_use)
+
+    # set the output observer node to use that module
+    for output_obs_node in node.users.keys():
+        assert _is_activation_post_process_node(output_obs_node, named_modules)
+        parent_name, name = _parent_name(output_obs_node.target)
+        setattr(named_modules[parent_name], name, obs_mod_to_use)
+
+    # TODO(future PR): delete the orphaned observer modules
+    return True
+
+
+def _remove_output_observer(
+    node: Node, model: torch.nn.Module, named_modules: dict[str, torch.nn.Module]
+):
+    items = list(node.users.items())
+    for output_obs_node, _ in items:
+        assert _is_activation_post_process_node(output_obs_node, named_modules)
+        output_obs_node.replace_all_uses_with(node)
+        model.graph.erase_node(output_obs_node)  # type: ignore[union-attr, operator]
+
+
+def _swap_custom_module_to_observed(
+    node: Node,
+    qconfig: QConfigAny,
+    named_modules: dict[str, torch.nn.Module],
+    prepare_custom_config: PrepareCustomConfig,
+):
+    custom_module = named_modules[node.target]  # type: ignore[index]
+    custom_module_class_mapping = prepare_custom_config.float_to_observed_mapping
+    observed_custom_module_class = get_swapped_custom_module_class(
+        custom_module, custom_module_class_mapping, qconfig
+    )
+    observed_custom_module = observed_custom_module_class.from_float(custom_module)
+    parent_name, name = _parent_name(node.target)
+    setattr(named_modules[parent_name], name, observed_custom_module)
+
+
+def insert_observers_for_model(
+    model: GraphModule,
+    node_name_to_match_result_with_qconfig: dict[str, _MatchResultWithQConfig],
+    node_name_to_qconfig: dict[str, QConfigAny],
+    prepare_custom_config: PrepareCustomConfig,
+    equalization_config_map: dict[str, Any],
+    backend_config: BackendConfig,
+    observed_node_names: set[str],
+    is_qat: bool,
+) -> Optional[Node]:
+    """
+    Inserts observers, using the following high level algorithm:
+
+    For each node in the graph:
+      1. determine the target dtype of this node in the quantized graph, and save
+           it for future steps
+      2. determine the target dtype of all args and kwargs of this node
+      3. if any arg or kwarg's target dtype does not match the current node's
+           dtype, insert an observer
+      4. if the current node needs an output observer, insert it
+
+    For example:
+
+    - starting graph:
+        x0 -> linear -> x1
+
+    - observed graph after processing x0:
+        x0(fp32)
+
+    - observed graph after processing linear:
+        x0(fp32) -> x0_obs0(int8) -> linear(int8) -> linear_obs0(int8)
+
+    - observed graph after processing x1:
+        x0(fp32) -> x0_obs0(int8) -> linear(int8) -> linear_obs0(int8) -> x1
+
+    After a node is processed, the naive observer placement is guaranteed to be
+    complete for that node and all of its predecessors. There can be future
+    passes which optimize the graph by deduplicating observers, etc.
+    """
+
+    # node.meta["target_dtype_info"] stores the target dtype information
+    # that's derived from qconfig for the Node, for example, if we have
+    # a conv2d node that has a qconfig
+    # qconfig = QConfig(activation=..., weight=...)
+    # # information for input and bias node omitted
+    # # for getattr node
+    # # weight = getattr(self, 'weight')
+    # weight.meta["target_dtype_info"] = {
+    #    'output_act_obs_or_fq_ctr': qconfig.weight,
+    # }
+    # # for conv2d node
+    # # conv2d = call_function[target=torch.nn.functional.conv2d](
+    # #            args=(input, weight, bias))
+    # conv2d.meta["target_dtype_info"] = {
+    #   'input_act_obs_or_fq_ctr': qconfig.activation
+    #   'weight_obs_or_fq_ctr': qconfig.weight,
+    #   'bias_obs_or_fq_ctr': PlaceholderObserver.with_args(dtype=torch.float32),
+    #   'output_act_obs_or_fq_ctr': qconfig.activation,
+    # }
+    #
+    cache_for_no_tensor_check: dict[Node, bool] = {}
+
+    # first, populate the dtype map based only on qconfig and qhandler
+    # this assumes:
+    # graph inputs are fp32 by default, and int8 where overridden
+    # the output dtype of other nodes is specified by the qconfig
+    named_modules = dict(model.named_modules(remove_duplicate=False))
+
+    input_quantized_idxs: list[int] = prepare_custom_config.input_quantized_indexes
+    output_quantized_idxs: list[int] = prepare_custom_config.output_quantized_indexes
+    processed_nodes: set[Node] = set()
+    # initialize target_dtype_info
+    for node in model.graph.nodes:
+        node.meta["target_dtype_info"] = copy.copy(
+            _DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO
+        )
+
+    inputs_seen_counter = 0
+    outputs_seen_counter = 0
+    placeholder_node_to_input_index: dict[Node, int] = {}
+    # TODO: we probably don't need this counter since each graph will only have
+    # one output node?
+    output_node_to_output_index: dict[Node, int] = {}
+    for node in model.graph.nodes:
+        if node.op == "placeholder":
+            placeholder_node_to_input_index[node] = inputs_seen_counter
+            inputs_seen_counter += 1
+        if node.op == "output":
+            output_node_to_output_index[node] = outputs_seen_counter
+            outputs_seen_counter += 1
+
+    # Step 1, set the observer or fake quantize module constructor for each node in the
+    # matched_node_pattern
+
+    for match_res_with_qconfig in node_name_to_match_result_with_qconfig.values():
+        (
+            last_node,
+            matched_node_pattern,
+            pattern,
+            qhandler,
+            qconfig,
+        ) = match_res_with_qconfig
+        assert qhandler is not None
+        _set_target_dtype_info_for_matched_node_pattern(
+            matched_node_pattern,
+            last_node,
+            qconfig,
+            qhandler,
+            backend_config,
+            named_modules,
+            cache_for_no_tensor_check,
+            processed_nodes,
+        )
+
+    # Step 2. Special cases for some operators, we might be able to remove them
+    # in the future if we know dtype information of each node better
+
+    # Step 2.1. some settings are not based on patterns, we need to process each node
+    # instead
+    for node in model.graph.nodes:
+        if (
+            node.op == "placeholder"
+            and placeholder_node_to_input_index[node] in input_quantized_idxs
+        ):
+            # Users are not supposed to call calculate_qparams on PlaceholderObserver;
+            # that is fine here because we only use it to encode the dtype of the input
+            # tensor. We never actually insert these observers into the graph, and we
+            # never call calculate_qparams on them.
+            node.meta["target_dtype_info"] = copy.copy(
+                _DEFAULT_QUINT8_QCONFIG_FOR_TARGET_DTYPE_INFO
+            )
+        elif node.op in ("call_module", "call_method", "call_function"):
+            args_have_no_tensors = all_node_args_have_no_tensors(
+                node, named_modules, cache_for_no_tensor_check
+            )
+            if args_have_no_tensors:
+                node.meta["target_dtype_info"] = {
+                    "input_act_obs_or_fq_ctr": None,
+                    "output_act_obs_or_fq_ctr": None,
+                }
+        elif (
+            node.op == "output"
+            and output_node_to_output_index[node] in output_quantized_idxs
+        ):
+            # TODO(future PR): update the output_quantized_idxs API to match
+            # arbitrary data structures. There is always a single output, and
+            # that output can have arbitrary nesting of values. List[int] is
+            # not the right data type for this.
+
+            # TODO(future PR): support more dtypes in model outputs, if necessary
+            node.meta["target_dtype_info"] = copy.copy(
+                _DEFAULT_QUINT8_QCONFIG_FOR_TARGET_DTYPE_INFO
+            )
+
+    # Step 2.2, for nodes with known input dtypes, propagate them throughout the
+    # graph. For example, if there is a call such as
+    #   x1 = x0.masked_fill(mask, 1)
+    # we propagate the type of mask to be torch.bool
+    propagate_dtypes_for_known_nodes(
+        model.graph, node_name_to_match_result_with_qconfig
+    )
+
+    # Step 3, check if the requested target_dtype_info is supported by backend or not
+    # if not, we'll reset the target_dtype_info to use the default (float Tensor)
+
+    # reset the counters and set of processed_nodes
+    processed_nodes: set[Node] = set()
+    for match_res_with_qconfig in node_name_to_match_result_with_qconfig.values():
+        (
+            last_node,
+            matched_node_pattern,
+            pattern,
+            qhandler,
+            qconfig,
+        ) = match_res_with_qconfig
+        is_supported_by_backend = (
+            _is_pattern_dtype_config_and_qconfig_supported_by_backend(
+                pattern, matched_node_pattern, qconfig, backend_config
+            )
+        )
+        assert qhandler is not None
+
+        # get output_act_dtype so that we don't also reset the special typed nodes
+        # TODO: we might want to handle these more uniformly with the default path
+        # this can be improved if we can use node.meta["val"]
+        output_act_or_fq_ctr = node.meta["target_dtype_info"][
+            "output_act_obs_or_fq_ctr"
+        ]
+        output_act_or_fq = output_act_or_fq_ctr() if output_act_or_fq_ctr else None
+        output_act_dtype, _ = _get_dtype_and_is_dynamic(output_act_or_fq)
+        if not is_supported_by_backend and output_act_dtype not in [
+            None,
+            int,
+            float,
+            torch.bool,
+        ]:
+            # restore target_dtype_info to default if it is not supported by backend
+            _set_target_dtype_info_for_matched_node_pattern(
+                matched_node_pattern,
+                last_node,
+                torch.ao.quantization.qconfig._default_fp32_placeholder_qconfig,
+                None,
+                backend_config,
+                named_modules,
+                cache_for_no_tensor_check,
+                processed_nodes,
+            )
+
+    # After this point, the current node and all of its arguments
+    # have a target_dtype_info assigned. Now, we insert observers for inputs
+    # of this node (if needed for this node), and the output of this node
+    # (if needed for this node).
+
+    # Since we are mutating the graph as we go, we iterate over the original
+    # nodes before observer insertion, instead of model.graph.nodes.
+    nodes_before_observation = list(model.graph.nodes)
+
+    # Avoid duplicate custom module swaps for multiple nodes with the same target.
+    custom_module_names_already_swapped: set[str] = set()
+
+    # TODO: reuse placeholder_node_to_input_index and output_node_to_output_index
+    # reset inputs/outputs counters
+    inputs_seen_counter = 0
+    outputs_seen_counter = 0
+    results_node = None
+    obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize] = {}
+
+    # TODO: change this to insert obs/fq by pattern instead of by node
+    for node in nodes_before_observation:
+        if node.op == "placeholder":
+            # if a graph input is in fp32, it does not need observation
+            # if a graph input is in int8, we assume the observation happens
+            #   outside of the graph, and no additional observation is needed
+            pass
+
+        elif node.op in ("call_module", "call_method", "call_function", "output"):
+            # check for matches
+            (
+                last_node,
+                matched_node_pattern,
+                pattern,
+                qhandler,
+                qconfig,
+            ) = node_name_to_match_result_with_qconfig.get(  # type: ignore[assignment]
+                node.name, (None, None, None, None, None)
+            )
+            equalization_qconfig = equalization_config_map.get(node.name, None)
+
+            this_node_dtype_info = node.meta["target_dtype_info"]
+            if "val" in node.meta:
+                output_is_a_tensor = this_node_dtype_info is not None and isinstance(
+                    node.meta["val"], FakeTensor
+                )
+            else:
+                output_is_a_tensor = this_node_dtype_info is not None
+
+            skip_inserting_observers = (
+                (qconfig is None) or not output_is_a_tensor
+            ) and (not node.op == "output")
+
+            # TODO: take a closer look to see if we can remove this check.
+            # Right now it is here because of `observed_node_names`: we are using
+            # it as an indicator for swapping the modules to reference modules in
+            # convert.
+            is_supported_by_backend = (
+                _is_pattern_dtype_config_and_qconfig_supported_by_backend(
+                    pattern, matched_node_pattern, qconfig, backend_config
+                )
+            )
+
+            if not skip_inserting_observers and is_supported_by_backend:
+                named_modules = dict(model.named_modules(remove_duplicate=False))
+                if node.op != "output":
+                    assert matched_node_pattern is not None
+                    # add matched nodes to the observed node name set
+                    _add_matched_node_name_to_set(
+                        matched_node_pattern, observed_node_names
+                    )
+
+                    # This is currently only used for equalization.
+                    # Checks if the current node is in a branch in which the
+                    # first two layers are both being quantized.
+                    #
+                    # ex.       conv2
+                    #         /
+                    #      x -> conv1
+                    #
+                    # If this is the case, we will not apply equalization to the
+                    # initial two layers.
+                    is_quantized_branch = False
+                    if (
+                        len(node.args) > 0
+                        and isinstance(node.args[0], Node)
+                        and len(node.args[0].users) > 1
+                    ):
+                        for user in node.args[0].users:
+                            # Checks if there exists another user being quantized
+                            is_user_quantized = node_name_to_qconfig.get(
+                                user.name, None
+                            ) is not None or (
+                                user.op == "call_module"
+                                and isinstance(
+                                    named_modules[str(user.target)], ObserverBase
+                                )
+                            )
+                            if user != node and is_user_quantized:
+                                is_quantized_branch = True
+
+                    pattern_to_root_node_getter = (
+                        get_fusion_pattern_to_root_node_getter(backend_config)
+                    )
+                    root_node_getter = pattern_to_root_node_getter.get(
+                        pattern, _default_root_node_getter
+                    )
+                    root_node = root_node_getter(matched_node_pattern)
+                    is_input_node_of_the_pattern = node is root_node
+                    if is_input_node_of_the_pattern:
+                        # this modifies node inplace
+                        _maybe_insert_input_observers_for_node(
+                            node,
+                            qconfig,
+                            model,
+                            named_modules,
+                            model.graph,
+                            qhandler,
+                            prepare_custom_config,
+                            obs_or_fq_map,
+                            is_qat,
+                            backend_config,
+                        )
+
+                        # insert equalization input observers if needed
+                        _maybe_insert_input_equalization_observers_for_node(
+                            node,
+                            equalization_qconfig,
+                            model,
+                            named_modules,
+                            model.graph,
+                            is_quantized_branch,
+                        )
+
+                    is_last_node_of_pattern = node is last_node
+                    input_output_share_observers = node.meta["target_dtype_info"].get(
+                        "input_output_share_observers", False
+                    )
+                    reuse_input_obs_or_fq = node.meta["target_dtype_info"].get(
+                        "reuse_input_obs_or_fq", False
+                    )
+
+                    if is_last_node_of_pattern:
+                        if _is_custom_module_lstm(
+                            node, named_modules, qconfig, qhandler
+                        ):
+                            # Currently custom module outputs are assumed to be already quantized,
+                            # so we need to insert a DeQuantStub after the output. For custom module
+                            # LSTM specifically, the outputs are also a nested tuple, so we must first
+                            # break down the tuple to insert DeQuantStubs after the internal nodes.
+
+                            # TODO: This currently diverges from how custom modules are handled today,
+                            # where we insert observers after the output instead of DeQuantStubs, and
+                            # replace these observers with "dequantize" nodes during convert. Conceptually,
+                            # these output observers are the same as DeQuantStubs. In the future, we
+                            # should resolve this inconsistency by inserting DeQuantStubs for all custom
+                            # modules, not just for LSTM.
+                            _insert_dequant_stubs_for_custom_module_lstm_output(
+                                node, model, named_modules, model.graph
+                            )
+                            if node.target not in custom_module_names_already_swapped:
+                                custom_module_names_already_swapped.add(node.target)
+                                _swap_custom_module_to_observed(
+                                    node, qconfig, named_modules, prepare_custom_config
+                                )
+                        else:
+                            # this returns the new observer node if it was needed
+                            maybe_output_obs_node = (
+                                _maybe_insert_output_observer_for_node(
+                                    node,
+                                    model,
+                                    named_modules,
+                                    model.graph,
+                                    obs_or_fq_map,
+                                    is_qat,
+                                )
+                            )
+
+                            if maybe_output_obs_node is not None:
+                                # Update users of original node to use the output observer
+                                # instead. For example, change
+                                #
+                                #           next_node
+                                #          /
+                                #   cur_node -> obs
+                                #
+                                # to
+                                #
+                                #                 next_node
+                                #                 /
+                                #   cur_node -> obs
+                                #
+                                # We need to save orig users before updating uses because
+                                # the list of users will change as we update uses
+                                orig_users = list(node.users.keys())
+                                for user_node in orig_users:
+                                    if user_node is maybe_output_obs_node:
+                                        continue
+                                    user_node.replace_input_with(
+                                        node, maybe_output_obs_node
+                                    )
+
+                                _is_observer_in_same_graph_ = (
+                                    _is_observer_in_same_graph(
+                                        node, named_modules, obs_or_fq_map, is_qat
+                                    )
+                                )
+
+                                # for ops whose inputs and outputs share observer/fqs, we modify the graph
+                                # to make all inputs and outputs use the first input's
+                                # observer/fq
+                                if (
+                                    input_output_share_observers
+                                    and _is_observer_in_same_graph_
+                                ) or reuse_input_obs_or_fq:
+                                    if not _maybe_make_input_output_share_observers(
+                                        node, model, named_modules
+                                    ):
+                                        _remove_output_observer(
+                                            node, model, named_modules
+                                        )
+
+                                if qhandler is not None and qhandler.is_custom_module():
+                                    if (
+                                        node.target
+                                        not in custom_module_names_already_swapped
+                                    ):
+                                        custom_module_names_already_swapped.add(
+                                            node.target
+                                        )
+                                        _swap_custom_module_to_observed(
+                                            node,
+                                            qconfig,
+                                            named_modules,
+                                            prepare_custom_config,
+                                        )
+
+                else:  # output
+                    _maybe_insert_observers_before_graph_output(
+                        node, model, named_modules, model.graph, obs_or_fq_map, is_qat
+                    )
+
+        #
+        # After this point, the current node has input and output observers
+        # that it needs for itself inserted.
+        #
+
+        # increment the counters, so future inputs and outputs are assigned
+        # correct dtypes
+        if node.op == "placeholder":
+            inputs_seen_counter += 1
+        elif node.op == "output":
+            outputs_seen_counter += 1
+            results_node = node
+
+    return results_node
+
+
+def _run_prepare_fx_on_standalone_modules(
+    model: torch.nn.Module,
+    is_qat: bool,
+    named_modules: dict[str, torch.nn.Module],
+    node_name_to_match_result_with_qconfig: Any,
+    prepare_custom_config: PrepareCustomConfig,
+    backend_config: BackendConfig,
+) -> None:
+    """
+    Runs prepare_fx on each standalone module. Note: this does
+    not modify the graph, it just replaces the unobserved modules with
+    their observed versions.
+    """
+    for (
+        root_node,
+        _,
+        _pattern,
+        qhandler,
+        qconfig,
+    ) in node_name_to_match_result_with_qconfig.values():
+        if qhandler is None:
+            continue
+        elif not qhandler.is_standalone_module():
+            continue
+
+        (
+            sm_qconfig_mapping,
+            sm_example_inputs,
+            sm_prepare_custom_config,
+            sm_backend_config,
+        ) = _get_standalone_module_configs(
+            root_node, named_modules, prepare_custom_config, qconfig, backend_config
+        )
+
+        standalone_module = named_modules[root_node.target]
+        prepare = torch.ao.quantization.quantize_fx._prepare_standalone_module_fx  # type: ignore[attr-defined]
+        observed_standalone_module = prepare(
+            standalone_module,
+            sm_qconfig_mapping,
+            is_qat,
+            example_inputs=sm_example_inputs,
+            prepare_custom_config=sm_prepare_custom_config,
+            backend_config=sm_backend_config,
+        )
+        parent_name, name = _parent_name(root_node.target)
+        setattr(named_modules[parent_name], name, observed_standalone_module)
+        named_modules[root_node.target] = observed_standalone_module
+
+
+def _save_state(
+    observed: GraphModule,
+    node_name_to_qconfig: dict[str, QConfigAny],
+    node_name_to_scope: dict[str, tuple[str, type]],
+    prepare_custom_config: PrepareCustomConfig,
+    equalization_node_name_to_qconfig: dict[str, Any],
+    qconfig_mapping: QConfigMapping,
+    is_qat: bool,
+    observed_node_names: set[str],
+) -> None:
+    observed.meta["_observed_graph_module_attrs"] = ObservedGraphModuleAttrs(
+        node_name_to_qconfig=node_name_to_qconfig,
+        node_name_to_scope=node_name_to_scope,
+        prepare_custom_config=prepare_custom_config,
+        equalization_node_name_to_qconfig=equalization_node_name_to_qconfig,
+        qconfig_mapping=qconfig_mapping,
+        is_qat=is_qat,
+        observed_node_names=observed_node_names,
+    )
+
+
+def prepare(
+    model: GraphModule,
+    qconfig_mapping: Union[QConfigMapping, dict[str, Any]],
+    is_qat: bool,
+    node_name_to_scope: dict[str, tuple[str, type]],
+    example_inputs: tuple[Any, ...],
+    prepare_custom_config: Union[PrepareCustomConfig, dict[str, Any], None] = None,
+    _equalization_config: Union[QConfigMapping, dict[str, Any], None] = None,
+    backend_config: Union[BackendConfig, dict[str, Any], None] = None,
+    is_standalone_module: bool = False,
+) -> GraphModule:
+    """standalone_module means it a submodule that is not inlined in
+    parent module, and will be quantized separately as one unit.
+
+    How the standalone module is observed is specified by `input_quantized_idxs` and
+    `output_quantized_idxs` in the prepare_custom_config for the standalone module.
+    Args:
+        node_name_to_scope: mapping from node name to the scope of the module which contains the node.
+            The scope is a tuple of the fully qualified path of the module and the type of the module.
+    Returns:
+        model(GraphModule): prepared standalone module. Attributes related to the
+        standalone module are stored in model.meta["_observed_graph_module_attrs"]:
+            is_observed_standalone_module (bool): whether the current model is an
+                observed standalone module or not
+            standalone_module_input_quantized_idxs(List[Int]): a list of
+                indexes for the graph inputs that are expected to be quantized,
+                same as the input_quantized_idxs configuration provided
+                for the standalone module
+            standalone_module_output_quantized_idxs(List[Int]): a list of
+                indexes for the graph outputs that are quantized,
+                same as the output_quantized_idxs configuration provided
+                for the standalone module
+    """
+    if prepare_custom_config is None:
+        prepare_custom_config = PrepareCustomConfig()
+    if _equalization_config is None:
+        _equalization_config = QConfigMapping()
+
+    if isinstance(qconfig_mapping, dict):
+        warnings.warn(
+            "Passing a QConfig dictionary to prepare is deprecated and will not be supported "
+            "in a future version. Please pass in a QConfigMapping instead.",
+            FutureWarning,
+            stacklevel=2,
+        )
+        qconfig_mapping = QConfigMapping.from_dict(qconfig_mapping)
+
+    if isinstance(_equalization_config, dict):
+        warnings.warn(
+            "Passing a QConfig dictionary to prepare for equalization is deprecated and will not "
+            "be supported in a future version. Please pass in a QConfigMapping instead.",
+            FutureWarning,
+            stacklevel=2,
+        )
+        _equalization_config = QConfigMapping.from_dict(_equalization_config)
+
+    if isinstance(prepare_custom_config, dict):
+        warnings.warn(
+            "Passing a prepare_custom_config_dict to prepare is deprecated and will not be supported "
+            "in a future version. Please pass in a PrepareCustomConfig instead.",
+            FutureWarning,
+            stacklevel=2,
+        )
+        prepare_custom_config = PrepareCustomConfig.from_dict(prepare_custom_config)
+
+    if isinstance(backend_config, dict):
+        warnings.warn(
+            "Passing a backend_config_dict to prepare is deprecated and will not be supported "
+            "in a future version. Please pass in a BackendConfig instead.",
+            FutureWarning,
+            stacklevel=2,
+        )
+        backend_config = BackendConfig.from_dict(backend_config)
+
+    assert isinstance(qconfig_mapping, QConfigMapping)
+    assert isinstance(_equalization_config, QConfigMapping)
+    qconfig_mapping = copy.deepcopy(qconfig_mapping)
+    _equalization_config = copy.deepcopy(_equalization_config)
+
+    # mapping from a tuple of nodes in reverse order to uninitialized
+    #   QuantizeHandler subclass. For example,
+    # {
+    #   # match a single node
+    #   (<class 'torch.nn.modules.conv.Conv2d'>:
+    #     <class 'torch.ao.quantization.fx.quantize.ConvRelu'>),
+    #   # match multiple nodes in reverse order
+    #   ((<function relu at 0x7f766a7360d0>, <class 'torch.nn.modules.conv.Conv2d'>):
+    #     <class 'torch.ao.quantization.fx.quantize.ConvRelu'>),
+    # }
+
+    pattern_to_quantize_handler: dict[Pattern, QuantizeHandler] = {}
+    if backend_config is None:
+        backend_config = get_native_backend_config()
+    pattern_to_quantize_handler = _get_pattern_to_quantize_handlers(backend_config)
+    pattern_to_quantize_handler = _sorted_patterns_dict(pattern_to_quantize_handler)
+
+    root_node_getter_mapping = get_fusion_pattern_to_root_node_getter(backend_config)
+
+    _update_qconfig_for_fusion(model, qconfig_mapping)
+    _update_qconfig_for_fusion(model, _equalization_config)
+    flattened_qconfig_dict = _get_flattened_qconfig_dict(qconfig_mapping)
+    # TODO: support regex as well
+    propagate_qconfig_(model, flattened_qconfig_dict, prepare_custom_config.to_dict())
+
+    if is_qat:
+        module_to_qat_module = get_module_to_qat_module(backend_config)
+        _qat_swap_modules(model, module_to_qat_module)
+        _update_qconfig_for_qat(qconfig_mapping, backend_config)
+
+    # mapping from fully qualified module name to module instance
+    # for example,
+    # {
+    #   '': Model(...),
+    #   'linear': Linear(...),
+    #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
+    # }
+    named_modules = dict(model.named_modules(remove_duplicate=False))
+
+    # fill node_name_to_qconfig, a map from node name to qconfig, used in _find_matches
+    equalization_node_name_to_qconfig = _generate_node_name_to_qconfig(
+        model, named_modules, model.graph, _equalization_config, node_name_to_scope
+    )
+    node_name_to_qconfig = _generate_node_name_to_qconfig(
+        model, named_modules, model.graph, qconfig_mapping, node_name_to_scope
+    )
+
+    # match the patterns that will get quantized
+    standalone_module_names = list(prepare_custom_config.standalone_module_names.keys())
+    standalone_module_classes = list(
+        prepare_custom_config.standalone_module_classes.keys()
+    )
+
+    custom_module_classes = get_custom_module_class_keys(
+        prepare_custom_config.float_to_observed_mapping
+    )
+    matches_without_qconfig = _find_matches(
+        model.graph,
+        named_modules,
+        pattern_to_quantize_handler,
+        root_node_getter_mapping,
+        standalone_module_names,
+        standalone_module_classes,
+        custom_module_classes,
+    )
+
+    # map qconfig instances to matches
+    node_name_to_match_result_with_qconfig = {}
+    for node_name, match_without_qconfig in matches_without_qconfig.items():
+        match_with_qconfig = (*match_without_qconfig, node_name_to_qconfig[node_name])
+        node_name_to_match_result_with_qconfig[node_name] = match_with_qconfig
+
+    _run_prepare_fx_on_standalone_modules(
+        model,
+        is_qat,
+        named_modules,
+        node_name_to_match_result_with_qconfig,
+        prepare_custom_config,
+        backend_config,
+    )
+
+    # record names for the set of observed node, so that in convert step
+    # we know whether we need to convert a floating point module to reference
+    # quantized module or not
+    observed_node_names: set[str] = set()
+
+    result_node = insert_observers_for_model(
+        model,
+        node_name_to_match_result_with_qconfig,
+        node_name_to_qconfig,
+        prepare_custom_config,
+        equalization_node_name_to_qconfig,
+        backend_config,
+        observed_node_names,
+        is_qat,
+    )
+    model = GraphModule(model, model.graph)
+
+    _save_state(
+        model,
+        node_name_to_qconfig,
+        node_name_to_scope,
+        prepare_custom_config,
+        equalization_node_name_to_qconfig,
+        qconfig_mapping,
+        is_qat,
+        observed_node_names,
+    )
+
+    if is_standalone_module:
+        assert result_node is not None
+        assert isinstance(result_node.args[0], Node), (
+            "standalone module only supports returning simple value currently"
+            "(not tuple, dict etc.)"
+        )
+        # these inputs are observed in parent
+        # converting List[int] to Tensor since module attribute is
+        # Union[Tensor, Module]
+        input_quantized_idxs: list[int] = prepare_custom_config.input_quantized_indexes
+        output_quantized_idxs: list[int] = (
+            prepare_custom_config.output_quantized_indexes
+        )
+        observed_graph_module_attrs = model.meta["_observed_graph_module_attrs"]
+        # inplace modification
+        observed_graph_module_attrs.is_observed_standalone_module = True
+        observed_graph_module_attrs.standalone_module_input_quantized_idxs = (
+            input_quantized_idxs
+        )
+        observed_graph_module_attrs.standalone_module_output_quantized_idxs = (
+            output_quantized_idxs
+        )
+    return model
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/qconfig_mapping_utils.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/qconfig_mapping_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..421e6d4b8eba6b893eb393d4960e85d625d3cdec
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/qconfig_mapping_utils.py
@@ -0,0 +1,400 @@
+# mypy: allow-untyped-defs
+import re
+from collections import defaultdict, OrderedDict
+from typing import Any, Callable, Union
+
+import torch
+from torch.ao.nn.intrinsic import _FusedModule
+from torch.ao.quantization import QConfig
+from torch.ao.quantization.backend_config import BackendConfig, DTypeConfig
+from torch.ao.quantization.backend_config.utils import get_module_to_qat_module
+from torch.ao.quantization.observer import _is_activation_post_process
+from torch.ao.quantization.qconfig import (
+    _add_module_to_qconfig_obs_ctr,
+    qconfig_equals,
+    QConfigAny,
+)
+from torch.ao.quantization.qconfig_mapping import (
+    _MODULE_NAME_DICT_KEY,
+    _MODULE_NAME_REGEX_DICT_KEY,
+    _OBJECT_TYPE_DICT_KEY,
+    QConfigMapping,
+)
+from torch.ao.quantization.utils import _parent_name, get_qconfig_dtypes
+from torch.fx import GraphModule
+from torch.fx.graph import Graph
+
+
+__all__: list[str] = []
+
+
+def _maybe_adjust_qconfig_for_module_name_object_type_order(
+    qconfig_mapping: QConfigMapping,
+    cur_module_path: str,
+    cur_object_type: Callable,
+    cur_object_type_idx: int,
+    fallback_qconfig: QConfigAny,
+) -> QConfigAny:
+    for (
+        module_name,
+        object_type,
+        index,
+    ), qconfig in qconfig_mapping.module_name_object_type_order_qconfigs.items():
+        if (
+            (module_name == cur_module_path)
+            and (object_type == cur_object_type)
+            and (index == cur_object_type_idx)
+        ):
+            return qconfig
+    return fallback_qconfig
+
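+# Illustrative example (comments only, hypothetical qconfig_a) of the lookup above:
+# with a mapping configured as
+#   QConfigMapping().set_module_name_object_type_order(
+#       "sub", torch.nn.functional.linear, 1, qconfig_a)
+# a call with cur_module_path="sub", cur_object_type=torch.nn.functional.linear and
+# cur_object_type_idx=1 (the second F.linear call inside "sub") returns qconfig_a;
+# any other combination falls back to `fallback_qconfig`.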
+
+def _update_qconfig_for_fusion(model: GraphModule, qconfig_mapping: QConfigMapping):
+    """
+    Update the QConfigMapping to account for fused modules such as LinearReLU.
+    This assumes the QConfigMapping's attributes have already been converted to OrderedDicts.
+    """
+    object_type_dict = qconfig_mapping.object_type_qconfigs
+    if len(object_type_dict) == 0:
+        return qconfig_mapping
+
+    modules = dict(model.named_modules())
+
+    for node in model.graph.nodes:
+        if node.op == "call_module" and node.target in modules:
+            maybe_fused_module = modules[str(node.target)]
+            if not isinstance(maybe_fused_module, _FusedModule):
+                continue
+
+            ops = list(maybe_fused_module._modules.values())
+            fused_qconfig = object_type_dict.get(type(ops[0]), None)
+
+            # Raise an error if the modules in the fused module have
+            # different qconfigs specified in the qconfig_dict
+            # TODO: currently it only works for modules,
+            # need to make this work for torch.nn.functional.relu
+            # TODO: currently it only works for object_type configurations,
+            # ideally it should work for different types of configurations,
+            # maybe we want to redesign this part
+            for op in ops[1:]:
+                if not qconfig_equals(
+                    object_type_dict.get(type(op), None), fused_qconfig
+                ):
+                    raise LookupError(
+                        "During fusion, we need to specify the same "
+                        + f"qconfigs for all module types in {type(maybe_fused_module)} "
+                        + f"offending type: {type(op)}"
+                    )
+
+            if fused_qconfig is not None:
+                object_type_dict[type(maybe_fused_module)] = fused_qconfig
+
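+# Illustrative example (comments only, hypothetical qconfig_a): if the model
+# contains a fused torch.ao.nn.intrinsic.LinearReLU module and the mapping has
+#   QConfigMapping().set_object_type(torch.nn.Linear, qconfig_a)
+#                   .set_object_type(torch.nn.ReLU, qconfig_a)
+# then after this pass type(LinearReLU) is also mapped to qconfig_a. If Linear
+# and ReLU had been given different qconfigs, a LookupError would be raised.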
+
+def _generate_node_name_to_qconfig(
+    root: torch.nn.Module,
+    modules: dict[str, torch.nn.Module],
+    input_graph: Graph,
+    qconfig_mapping: QConfigMapping,
+    node_name_to_scope: dict[str, tuple[str, type]],
+) -> dict[str, QConfigAny]:
+    global_qconfig = qconfig_mapping.global_qconfig
+    node_name_to_qconfig = {}
+
+    # example:
+    #
+    #   {'foo.bar': {F.linear: 0, F.conv2d: 1, ...}, ...}
+    #
+    # meaning in submodule 'foo.bar', we have seen 0 F.linear and
+    # 1 F.conv2d invocations so far.
+    submodule_to_object_type_to_cur_idx: dict[str, dict[Callable, int]] = defaultdict(
+        lambda: defaultdict(int)
+    )
+    for node in input_graph.nodes:
+        qconfig = None
+        if node.op == "get_attr":
+            module_name, _ = _parent_name(node.target)
+            qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
+                qconfig_mapping, type(modules[module_name]), module_name, global_qconfig
+            )
+            qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(
+                qconfig, modules.get(node.target, None)
+            )
+        elif node.op == "call_function":
+            # precedence: module_name_qconfig > function_qconfig > global_qconfig
+            # i.e. the module_name config takes precedence over the function qconfig
+            function_qconfig = _get_object_type_qconfig(
+                qconfig_mapping, node.target, global_qconfig
+            )
+            module_path, module_type = node_name_to_scope[node.name]
+            qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
+                qconfig_mapping, module_type, module_path, function_qconfig
+            )
+
+            cur_object_type_idx = submodule_to_object_type_to_cur_idx[module_path][
+                node.target
+            ]
+            submodule_to_object_type_to_cur_idx[module_path][node.target] += 1
+            qconfig = _maybe_adjust_qconfig_for_module_name_object_type_order(
+                qconfig_mapping, module_path, node.target, cur_object_type_idx, qconfig
+            )
+            qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(
+                qconfig, modules.get(node.target, None)
+            )
+
+        elif node.op == "call_method":
+            module_path, module_type = node_name_to_scope[node.name]
+            # first use node.target (string) to get the qconfig
+            # this is to support configs like
+            # "object_type": [("reshape", qconfig)]
+            qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
+                qconfig_mapping, node.target, module_path, global_qconfig
+            )
+            # if there is no special config for the method, we'll fall back to the
+            # config for the module that contains the call_method node
+            qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
+                qconfig_mapping, module_type, module_path, qconfig
+            )
+            # currently call_method does not support modifying qconfig
+            # by order, we can add this later if it is needed.
+            qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(
+                qconfig, modules.get(node.target, None)
+            )
+
+        elif node.op == "call_module":
+            # if the node is an observer, just continue - don't add it to the qconfig_map
+            if _is_activation_post_process(modules[node.target]):
+                continue
+            qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
+                qconfig_mapping, type(modules[node.target]), node.target, global_qconfig
+            )
+
+            module_path, module_type = node_name_to_scope[node.name]
+            # Note: for call_module, the module_path is the current module's name.
+            # to meaningfully count invocations, we need to count them in the parent
+            # module.
+            parent_name, _ = _parent_name(module_path)
+            cur_object_type_idx = submodule_to_object_type_to_cur_idx[parent_name][
+                module_type
+            ]
+            submodule_to_object_type_to_cur_idx[parent_name][module_type] += 1
+            qconfig = _maybe_adjust_qconfig_for_module_name_object_type_order(
+                qconfig_mapping, parent_name, module_type, cur_object_type_idx, qconfig
+            )
+            qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(
+                qconfig, modules.get(node.target, None)
+            )
+
+            # regex is not supported in eager mode propagate_qconfig_, so we
+            # need to set the qconfig explicitly here in case a regex
+            # pattern is used
+            modules[node.target].qconfig = qconfig_with_device_check
+        else:
+            qconfig_with_device_check = None
+
+        node_name_to_qconfig[node.name] = qconfig_with_device_check
+    return node_name_to_qconfig
+
+
+def _check_is_valid_config_dict(
+    config_dict: Any, allowed_keys: set[str], dict_name: str
+) -> None:
+    r"""Checks if the given config_dict has the correct keys
+
+    Args:
+      `config_dict`: dictionary whose keys we want to check
+    """
+
+    for k in config_dict.keys():
+        if k not in allowed_keys:
+            raise ValueError(
+                "Expected "
+                + dict_name
+                + " to have the following keys: "
+                + str(allowed_keys)
+                + ". But found '"
+                + k
+                + "' instead."
+            )
+
+
+def _compare_prepare_convert_qconfig_mappings(
+    prepare_qconfig_mapping: QConfigMapping, convert_qconfig_mapping: QConfigMapping
+):
+    r"""Compare the qconfig_mapping passed in convert to the one from prepare and check the values
+
+    Args:
+      `prepare_qconfig_mapping`: configuration for prepare quantization step
+      `convert_qconfig_mapping`: configuration for convert quantization step
+    """
+    assert qconfig_equals(
+        prepare_qconfig_mapping.global_qconfig, convert_qconfig_mapping.global_qconfig
+    ), (
+        "Expected global qconfigs to be the same in the prepare and convert quantization configs"
+    )
+    prepare_dicts: list[OrderedDict] = [
+        prepare_qconfig_mapping.object_type_qconfigs,
+        prepare_qconfig_mapping.module_name_qconfigs,
+        prepare_qconfig_mapping.module_name_regex_qconfigs,
+    ]
+    convert_dicts: list[OrderedDict] = [
+        convert_qconfig_mapping.object_type_qconfigs,
+        convert_qconfig_mapping.module_name_qconfigs,
+        convert_qconfig_mapping.module_name_regex_qconfigs,
+    ]
+    dict_names = [
+        _OBJECT_TYPE_DICT_KEY,
+        _MODULE_NAME_DICT_KEY,
+        _MODULE_NAME_REGEX_DICT_KEY,
+    ]
+    for i in range(len(prepare_dicts)):
+        for name in prepare_dicts[i].keys():
+            assert name in convert_dicts[i], (
+                f"Missing key {dict_names[i]} {name} in convert QConfigMapping \
+                when it was present in prepare"
+            )
+            assert convert_dicts[i][name] is None or qconfig_equals(
+                prepare_dicts[i][name], convert_dicts[i][name]
+            ), (
+                f"Expected convert QConfigMapping to have the same qconfig as prepare for key {dict_names[i]} {name}; \
+                prepare: {prepare_dicts[i][name]}; convert: {convert_dicts[i][name]}"
+            )
+
+
+def _is_qconfig_supported_by_dtype_configs(
+    qconfig: QConfig, dtype_configs: list[DTypeConfig]
+):
+    for dtype_config in dtype_configs:
+        is_dynamic = dtype_config.is_dynamic
+        if is_dynamic is None:
+            is_dynamic = False
+        input_dtype = dtype_config.input_dtype or torch.float
+        weight_dtype = dtype_config.weight_dtype or torch.float
+        bias_dtype = dtype_config.bias_dtype or torch.float
+        output_dtype = dtype_config.output_dtype or torch.float
+        (
+            qconfig_activation_dtype,
+            qconfig_weight_dtype,
+            qconfig_input_act_is_dynamic,
+        ) = get_qconfig_dtypes(qconfig)
+        qconfig_bias_dtype = (
+            torch.float16
+            if (
+                qconfig_activation_dtype == torch.float16
+                and qconfig_weight_dtype == torch.float16
+                and not is_dynamic
+            )
+            else torch.float
+        )
+
+        if is_dynamic:
+            is_match = (
+                qconfig_input_act_is_dynamic
+                and input_dtype == qconfig_activation_dtype
+                and output_dtype == torch.float
+                and weight_dtype == qconfig_weight_dtype
+            )
+        else:
+            is_match = (
+                input_dtype == qconfig_activation_dtype
+                and output_dtype == qconfig_activation_dtype
+                and weight_dtype == qconfig_weight_dtype
+                and bias_dtype == qconfig_bias_dtype
+            )
+        if is_match:
+            return True
+    return False
+
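+# Rough sketch (comments only) of the matching above for a common static int8
+# setup; the concrete DTypeConfig values are assumptions for illustration:
+#   dtype_config = DTypeConfig(input_dtype=torch.quint8, output_dtype=torch.quint8,
+#                              weight_dtype=torch.qint8, bias_dtype=torch.float)
+#   _is_qconfig_supported_by_dtype_configs(get_default_qconfig("fbgemm"), [dtype_config])
+#   # -> True: the qconfig's activation/weight dtypes and the derived bias dtype
+#   #    line up with the non-dynamic branch of the check above.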
+
+def _get_object_type_qconfig(
+    qconfig_mapping: QConfigMapping,
+    object_type: Union[Callable, str],
+    fallback_qconfig: QConfigAny,
+) -> QConfigAny:
+    return qconfig_mapping.object_type_qconfigs.get(object_type, fallback_qconfig)
+
+
+def _get_module_name_regex_qconfig(qconfig_mapping, module_name, fallback_qconfig):
+    for regex_pattern, qconfig in qconfig_mapping.module_name_regex_qconfigs.items():
+        if re.match(regex_pattern, module_name):
+            # first match wins
+            return qconfig
+    return fallback_qconfig
+
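+# Illustrative example (comments only, hypothetical qconfig_a): with
+#   QConfigMapping().set_module_name_regex(r"features\.conv[0-9]+", qconfig_a)
+# a lookup for module_name "features.conv3" returns qconfig_a (the first matching
+# regex wins), while "classifier.fc" falls through to the fallback qconfig.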
+
+def _get_module_name_qconfig(qconfig_mapping, module_name, fallback_qconfig):
+    if module_name == "":
+        # module name qconfig not found
+        return fallback_qconfig
+    if module_name in qconfig_mapping.module_name_qconfigs:
+        return qconfig_mapping.module_name_qconfigs[module_name]
+    else:
+        parent, _ = _parent_name(module_name)
+        return _get_module_name_qconfig(qconfig_mapping, parent, fallback_qconfig)
+
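+# Illustrative example (comments only) of the parent fallback above: a lookup for
+# "block1.conv" first checks "block1.conv", then "block1", then "" (at which point
+# the fallback qconfig is returned if no entry was found along the way).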
+
+def _maybe_adjust_qconfig_for_module_type_or_name(
+    qconfig_mapping, module_type, module_name, global_qconfig
+):
+    # get qconfig for module_name,
+    # fallback to module_name_regex_qconfig, module_type_qconfig,
+    # global_qconfig if necessary
+    module_type_qconfig = _get_object_type_qconfig(
+        qconfig_mapping, module_type, global_qconfig
+    )
+    module_name_regex_qconfig = _get_module_name_regex_qconfig(
+        qconfig_mapping, module_name, module_type_qconfig
+    )
+    module_name_qconfig = _get_module_name_qconfig(
+        qconfig_mapping, module_name, module_name_regex_qconfig
+    )
+    return module_name_qconfig
+
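+# Precedence implied by the chained lookups above:
+#   module_name qconfig > module_name_regex qconfig > object_type qconfig > global qconfig
+# e.g. if "sub.linear" has both a module_name entry and an nn.Linear object_type
+# entry, the module_name entry wins for that node.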
+
+def _get_flattened_qconfig_dict(
+    qconfig_mapping: QConfigMapping,
+) -> dict[Union[Callable, str], QConfigAny]:
+    """flatten the global, object_type and module_name qconfig
+    to the same qconfig_dict so that it can be used by
+    propagate_qconfig_ function.
+    "module_name_regex" is ignored for now since it's not supported
+    in propagate_qconfig_, but it can be fixed later.
+
+    For example:
+    Input: {
+      "": qconfig,
+      "object_type": [
+        (torch.add, qconfig)
+      ],
+      "module_name": [
+        ("conv", qconfig)
+      ]
+    }
+
+    Output: {
+      "": qconfig,
+      torch.add: qconfig,
+      "conv": qconfig
+    }
+    """
+    flattened: dict[Union[Callable, str], QConfigAny] = {
+        "": qconfig_mapping.global_qconfig
+    }
+    flattened.update(qconfig_mapping.object_type_qconfigs)
+    flattened.update(qconfig_mapping.module_name_qconfigs)  # type: ignore[arg-type]
+    return flattened
+
+
+def _update_qconfig_for_qat(
+    qconfig_mapping: QConfigMapping, backend_config: BackendConfig
+):
+    """
+    Update the qconfig_mapping to account for module swaps during QAT.
+    During QAT we perform a module swap on the nn.Module types to the corresponding nn.qat.modules types.
+    """
+    module_to_qat_module_class = get_module_to_qat_module(backend_config)
+    object_type_dict = qconfig_mapping.object_type_qconfigs
+    new_object_type_dict = object_type_dict.copy()
+    for k, v in new_object_type_dict.items():
+        if k in module_to_qat_module_class:
+            object_type_dict[module_to_qat_module_class[k]] = v
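+
+
+# Illustrative example (comments only, hypothetical qconfig_a): if the mapping has
+#   QConfigMapping().set_object_type(torch.nn.Linear, qconfig_a)
+# and the backend config maps torch.nn.Linear to torch.ao.nn.qat.Linear, this pass
+# also maps torch.ao.nn.qat.Linear to qconfig_a, so modules swapped for QAT keep
+# matching the same qconfig.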
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/quantize_handler.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/quantize_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..a285a58814babcbc9b6b69a052bac15d2709924c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/quantize_handler.py
@@ -0,0 +1,225 @@
+# mypy: allow-untyped-defs
+from abc import ABC
+from typing import Callable, Optional
+
+import torch
+from torch.ao.quantization.backend_config import (
+    BackendConfig,
+    DTypeConfig,
+    ObservationType,
+)
+from torch.ao.quantization.utils import NodePattern, Pattern, QuantizerCls
+from torch.fx.graph import Node
+
+from .utils import all_node_args_have_no_tensors
+
+
+__all__ = [
+    "QuantizeHandler",
+    "BinaryOpQuantizeHandler",
+    "CatQuantizeHandler",
+    "ConvReluQuantizeHandler",
+    "LinearReLUQuantizeHandler",
+    "BatchNormQuantizeHandler",
+    "EmbeddingQuantizeHandler",
+    "RNNDynamicQuantizeHandler",
+    "DefaultNodeQuantizeHandler",
+    "FixedQParamsOpQuantizeHandler",
+    "CopyNodeQuantizeHandler",
+    "GeneralTensorShapeOpQuantizeHandler",
+    "CustomModuleQuantizeHandler",
+    "StandaloneModuleQuantizeHandler",
+]
+
+
+def _default_root_node_getter(node_pattern):
+    if node_pattern is None:
+        return node_pattern
+    while not isinstance(node_pattern, Node):
+        node_pattern = node_pattern[-1]
+    return node_pattern
+
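+# Illustrative example (comments only): for a nested fusion pattern such as
+#   (relu_node, (bn_node, conv_node))
+# the loop above repeatedly takes the last element, so the root node returned is
+# conv_node, i.e. the "input" node of the fused pattern.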
+
+# Base Pattern Handler
+class QuantizeHandler(ABC):  # noqa: B024
+    """Base handler class for the quantizer patterns"""
+
+    def __init__(
+        self,
+        node_pattern: NodePattern,
+        modules: dict[str, torch.nn.Module],
+        root_node_getter: Optional[Callable] = None,
+        is_custom_module=False,
+        is_standalone_module=False,
+    ):
+        """Records pattern information in __init__, which will be used
+        in convert
+        """
+        self.node_pattern = node_pattern
+        self.modules = modules
+        if root_node_getter is None:
+            root_node_getter = _default_root_node_getter
+        self.root_node = root_node_getter(node_pattern)
+        self.is_custom_module_ = is_custom_module
+        self.is_standalone_module_ = is_standalone_module
+        self.num_tensor_args = 0
+        # determine how many of the first two args are Tensors (versus scalars)
+        # this distinguishes things like "x + y" from "x + 2" or "2 + x"
+        if isinstance(self.root_node, Node):
+            cache_for_no_tensor_check: dict[Node, bool] = {}
+            for arg_idx in range(len(self.root_node.args)):
+                arg = self.root_node.args[arg_idx]
+                if isinstance(arg, Node) and (
+                    not all_node_args_have_no_tensors(
+                        arg, self.modules, cache_for_no_tensor_check
+                    )
+                ):
+                    self.num_tensor_args += 1
+
+    def is_general_tensor_value_op(self) -> bool:
+        """
+        Returns True if the operator either works for both floating point and
+        quantized input (and does some computation based on the input Tensor),
+        or only re-arranges the Tensor values / queries some metadata about the
+        Tensor. For such ops we insert an observer/fake_quant for the output of
+        the operator (the same observer instance as the input), since the
+        distribution of values can differ between input and output Tensors
+        (e.g. for HistogramObserver) even though they share the same
+        quantization parameters.
+        Example operator: avgpool2d, reshape, transpose, maxpool2d
+        Example observed operator:
+        observer_0 - avgpool2d - observer_0 (same observer instance as input)
+        """
+        return False
+
+    def is_custom_module(self):
+        return self.is_custom_module_
+
+    def is_standalone_module(self):
+        return self.is_standalone_module_
+
+
+def _get_quantize_handler_cls(
+    observation_type: ObservationType,
+    dtype_configs: list[DTypeConfig],
+    num_tensor_args_to_observation_type: dict[int, ObservationType],
+) -> type[QuantizeHandler]:
+    """
+    Return a configurable QuantizeHandler that matches the given specifications from the backend.
+    """
+
+    class ConfigurableQuantizeHandler(QuantizeHandler):
+        def __init__(
+            self,
+            node_pattern: NodePattern,
+            modules: dict[str, torch.nn.Module],
+            root_node_getter: Optional[Callable] = None,
+        ):
+            super().__init__(node_pattern, modules, root_node_getter)
+            if num_tensor_args_to_observation_type:
+                assert self.num_tensor_args in num_tensor_args_to_observation_type, (
+                    f"Must provide observation_type config for tensor number {self.num_tensor_args}"
+                    f" in num_tensor_args_to_observation_type for {node_pattern}"
+                )
+                self.observation_type = num_tensor_args_to_observation_type[
+                    self.num_tensor_args
+                ]
+            else:
+                self.observation_type = observation_type
+            self.dtype_configs = dtype_configs
+
+        def is_general_tensor_value_op(self) -> bool:
+            return (
+                self.observation_type
+                == ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT
+            )
+
+    return ConfigurableQuantizeHandler
+
+
+def _get_pattern_to_quantize_handlers(
+    backend_config: BackendConfig,
+) -> dict[Pattern, QuantizerCls]:
+    """
+    Note: a QuantizeHandler is just a holder for some check methods
+    (e.g. should_insert_observer_for_output); it could perhaps be an enum instead.
+    We can refactor this after the fbgemm/qnnpack path is fully converted to the
+    new path. This is not exposed to backend developers.
+    """
+    pattern_to_quantize_handlers = {}
+    for pattern, config in backend_config._pattern_complex_format_to_config.items():
+        observation_type = config.observation_type
+        dtype_configs = config.dtype_configs
+        num_tensor_args_to_observation_type = (
+            config._num_tensor_args_to_observation_type
+        )
+        pattern_to_quantize_handlers[pattern] = _get_quantize_handler_cls(
+            observation_type, dtype_configs, num_tensor_args_to_observation_type
+        )
+    return pattern_to_quantize_handlers
+
+
+# TODO: remove this class, this is still exposed in torch.ao.quantization
+# but we should be able to break bc
+class BinaryOpQuantizeHandler(QuantizeHandler):
+    pass
+
+
+class CatQuantizeHandler(QuantizeHandler):
+    pass
+
+
+# TODO: remove this class
+class ConvReluQuantizeHandler(QuantizeHandler):
+    pass
+
+
+# TODO: remove this class
+class LinearReLUQuantizeHandler(QuantizeHandler):
+    pass
+
+
+# TODO: remove this class
+class BatchNormQuantizeHandler(QuantizeHandler):
+    pass
+
+
+# TODO: remove this class
+class EmbeddingQuantizeHandler(QuantizeHandler):
+    pass
+
+
+# TODO: remove this class
+class RNNDynamicQuantizeHandler(QuantizeHandler):
+    pass
+
+
+# TODO: remove this class
+class DefaultNodeQuantizeHandler(QuantizeHandler):
+    """Common quantized op, first input and first output will be quantized"""
+
+
+# TODO: remove this class
+class FixedQParamsOpQuantizeHandler(QuantizeHandler):
+    pass
+
+
+# TODO: remove
+class CopyNodeQuantizeHandler(QuantizeHandler):
+    pass
+
+
+# TODO: remove
+class GeneralTensorShapeOpQuantizeHandler(QuantizeHandler):
+    pass
+
+
+# TODO: not used, can be removed after torch.ao.quantization namespace is deprecated
+class CustomModuleQuantizeHandler(QuantizeHandler):
+    pass
+
+
+# TODO: not used, can be removed after torch.ao.quantization namespace is deprecated
+class StandaloneModuleQuantizeHandler(QuantizeHandler):
+    pass
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/tracer.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/tracer.py
new file mode 100644
index 0000000000000000000000000000000000000000..915b396e9f3368ed4fe2c2d13b86636db101e90c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/tracer.py
@@ -0,0 +1,48 @@
+from typing import Callable
+
+import torch
+from torch.ao.nn.intrinsic import _FusedModule
+from torch.fx._symbolic_trace import Tracer
+from torch.fx.proxy import Scope
+
+
+__all__ = [
+    "QuantizationTracer",
+]
+
+
+class ScopeContextManager(torch.fx.proxy.ScopeContextManager):
+    def __init__(
+        self, scope: Scope, current_module: torch.nn.Module, current_module_path: str
+    ):
+        super().__init__(scope, Scope(current_module_path, type(current_module)))
+
+
+class QuantizationTracer(Tracer):
+    def __init__(
+        self, skipped_module_names: list[str], skipped_module_classes: list[Callable]
+    ):
+        super().__init__()
+        self.skipped_module_names = skipped_module_names
+        self.skipped_module_classes = skipped_module_classes
+        # NB: initialize the module_type of the top level module to None.
+        # We assume people won't configure the model with the type of the top level
+        # module here, since they can use "" for the global config.
+        # We can change this if there is a use case that configures
+        # qconfig using the top level module type.
+        self.scope = Scope("", None)
+        self.record_stack_traces = True
+
+    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
+        return (
+            (
+                (
+                    m.__module__.startswith("torch.nn")
+                    or m.__module__.startswith("torch.ao.nn")
+                )
+                and not isinstance(m, torch.nn.Sequential)
+            )
+            or module_qualified_name in self.skipped_module_names
+            or type(m) in self.skipped_module_classes
+            or isinstance(m, _FusedModule)
+        )
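+
+
+# Illustrative behavior of is_leaf_module above (comments only):
+#   - torch.nn.Linear          -> leaf (lives in torch.nn and is not a Sequential),
+#                                 so it stays a single call_module node in the graph
+#   - torch.nn.Sequential      -> not a leaf, traced through
+#   - a user-defined nn.Module -> not a leaf unless its qualified name / class is in
+#                                 skipped_module_names / skipped_module_classes
+#   - any _FusedModule         -> leaf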
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/utils.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5166f28571686dde7f26ae3a4a40c49a25c698cc
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fx/utils.py
@@ -0,0 +1,963 @@
+# mypy: allow-untyped-defs
+import copy
+import operator
+import warnings
+from collections import namedtuple
+from dataclasses import dataclass
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.nn as nn
+from torch.ao.quantization import QConfigAny, QuantType
+from torch.ao.quantization.backend_config import DTypeWithConstraints
+from torch.ao.quantization.fake_quantize import (
+    FakeQuantizeBase,
+    FixedQParamsFakeQuantize,
+)
+from torch.ao.quantization.observer import (
+    _is_activation_post_process,
+    FixedQParamsObserver,
+    ObserverBase,
+)
+from torch.ao.quantization.qconfig import (
+    float16_dynamic_qconfig,
+    float16_static_qconfig,
+    qconfig_equals,
+)
+from torch.ao.quantization.qconfig_mapping import QConfigMapping
+from torch.ao.quantization.stubs import DeQuantStub
+from torch.ao.quantization.utils import (
+    _assert_and_get_unique_device,
+    activation_is_statically_quantized,
+)
+from torch.fx import GraphModule, map_arg
+from torch.fx.graph import Graph, Node
+
+# importing the lib so that the quantized_decomposed ops are registered
+from ._decomposed import quantized_decomposed_lib  # noqa: F401
+from .custom_config import PrepareCustomConfig
+
+
+# TODO: revisit this list. Many helper methods shouldn't be public
+__all__ = [
+    "all_node_args_except_first",
+    "all_node_args_have_no_tensors",
+    "assert_and_get_unique_device",
+    "collect_producer_nodes",
+    "create_getattr_from_value",
+    "create_node_from_old_node_preserve_meta",
+    "EMPTY_ARG_DICT",
+    "get_custom_module_class_keys",
+    "get_linear_prepack_op_for_dtype",
+    "get_new_attr_name_with_prefix",
+    "get_non_observable_arg_indexes_and_types",
+    "get_qconv_prepack_op",
+    "get_skipped_module_name_and_classes",
+    "graph_module_from_producer_nodes",
+    "maybe_get_next_module",
+    "NodeInfo",
+    "node_arg_is_bias",
+    "node_arg_is_weight",
+    "NON_OBSERVABLE_ARG_DICT",
+    "NON_QUANTIZABLE_WEIGHT_OPS",
+    "return_arg_list",
+    "ObservedGraphModuleAttrs",
+]
+
+NON_QUANTIZABLE_WEIGHT_OPS = {
+    torch.nn.functional.layer_norm,
+    torch.nn.functional.group_norm,
+    torch.nn.functional.instance_norm,
+}
+
+
+@dataclass
+class ObservedGraphModuleAttrs:
+    node_name_to_qconfig: dict[str, QConfigAny]
+    node_name_to_scope: dict[str, tuple[str, type]]
+    prepare_custom_config: PrepareCustomConfig
+    equalization_node_name_to_qconfig: dict[str, Any]
+    qconfig_mapping: QConfigMapping
+    is_qat: bool
+    observed_node_names: set[str]
+    is_observed_standalone_module: bool = False
+    standalone_module_input_quantized_idxs: Optional[list[int]] = None
+    standalone_module_output_quantized_idxs: Optional[list[int]] = None
+
+
+def node_arg_is_weight(node: Node, arg: Any) -> bool:
+    """Returns if node arg is weight"""
+    weight_index = None
+    if "target_dtype_info" in node.meta:
+        weight_index = node.meta["target_dtype_info"].get("weight_index", None)
+    if (
+        weight_index is not None
+        and weight_index < len(node.args)
+        and node.args[weight_index] is arg
+    ):
+        return True
+    return node.kwargs.get("weight") is arg
+
+
+def node_arg_is_bias(node: Node, arg: Any) -> bool:
+    """Returns if node arg is bias"""
+    bias_index = None
+    if "target_dtype_info" in node.meta:
+        bias_index = node.meta["target_dtype_info"].get("bias_index", None)
+    if (
+        bias_index is not None
+        and bias_index < len(node.args)
+        and node.args[bias_index] is arg
+    ):
+        return True
+    return node.kwargs.get("bias") is arg
+
+
+def get_custom_module_class_keys(
+    custom_module_mapping: dict[QuantType, dict[type, type]],
+) -> list[Any]:
+    r"""Get all the unique custom module keys in the custom config dict
+    e.g.
+    Input:
+    {
+        QuantType.STATIC: {
+            CustomModule1: ObservedCustomModule
+        },
+        QuantType.DYNAMIC: {
+            CustomModule2: DynamicObservedCustomModule
+        },
+        QuantType.WEIGHT_ONLY: {
+            CustomModule3: WeightOnlyObservedCustomModule
+        },
+    }
+
+    Output:
+    # extract the keys across all inner STATIC, DYNAMIC, and WEIGHT_ONLY dicts
+    [CustomModule1, CustomModule2, CustomModule3]
+    """
+    # using set to dedup
+    float_custom_module_classes: set[Any] = set()
+    for quant_mode in [QuantType.STATIC, QuantType.DYNAMIC, QuantType.WEIGHT_ONLY]:
+        quant_mode_custom_module_config = custom_module_mapping.get(quant_mode, {})
+        quant_mode_custom_module_classes = set(quant_mode_custom_module_config.keys())
+        float_custom_module_classes |= quant_mode_custom_module_classes
+    return list(float_custom_module_classes)
+
+
+def get_linear_prepack_op_for_dtype(dtype):
+    if dtype == torch.float16:
+        return torch.ops.quantized.linear_prepack_fp16
+    elif dtype == torch.qint8:
+        return torch.ops.quantized.linear_prepack
+    else:
+        raise Exception("can't get linear prepack op for dtype:", dtype)  # noqa: TRY002
+
+
+def get_qconv_prepack_op(conv_op: Callable) -> Callable:
+    prepack_ops = {
+        torch.nn.functional.conv1d: torch.ops.quantized.conv1d_prepack,
+        torch.nn.functional.conv2d: torch.ops.quantized.conv2d_prepack,
+        torch.nn.functional.conv3d: torch.ops.quantized.conv3d_prepack,
+        torch.nn.functional.conv_transpose1d: torch.ops.quantized.conv_transpose1d_prepack,
+        torch.nn.functional.conv_transpose2d: torch.ops.quantized.conv_transpose2d_prepack,
+        torch.nn.functional.conv_transpose3d: torch.ops.quantized.conv_transpose3d_prepack,
+    }
+    prepack_op = prepack_ops.get(conv_op, None)
+    assert prepack_op, f"Didn't find prepack op for {conv_op}"
+    return prepack_op
+
+
+# Returns a function that can get a new attribute name for module with given
+# prefix, for example,
+# >> get_new_observer_name = get_new_attr_name_with_prefix('_observer')
+# >> new_name = get_new_observer_name(module)
+# new_name will be an unused attribute name on module, e.g. `_observer0`
+def get_new_attr_name_with_prefix(prefix: str) -> Callable:
+    prefix = prefix.replace(".", "_")
+
+    def get_new_attr_name(module: torch.nn.Module):
+        def get_attr_name(i: int):
+            return prefix + str(i)
+
+        i = 0
+        attr_name = get_attr_name(i)
+        while hasattr(module, attr_name):
+            i += 1
+            attr_name = get_attr_name(i)
+        return attr_name
+
+    return get_new_attr_name
+
+
+def collect_producer_nodes(node: Node) -> Optional[list[Node]]:
+    r"""Starting from a target node, trace back until we hit inpu or
+    getattr node. This is used to extract the chain of operators
+    starting from getattr to the target node, for example
+    def forward(self, x):
+      observed = self.observer(self.weight)
+      return F.linear(x, observed)
+    collect_producer_nodes(observed) will either return a list of nodes that
+    produces the observed node or None if we can't extract a self contained
+    graph without free variables(inputs of the forward function).
+    """
+    nodes = [node]
+    frontier = [node]
+    while frontier:
+        node = frontier.pop()
+        all_args = list(node.args) + list(node.kwargs.values())
+        for arg in all_args:
+            if not isinstance(arg, Node):
+                continue
+            if arg.op == "placeholder":
+                # hit input, can't fold in this case
+                return None
+            nodes.append(arg)
+            if not (arg.op == "call_function" and arg.target == getattr):
+                frontier.append(arg)
+    return nodes
+
+
+def graph_module_from_producer_nodes(
+    root: GraphModule, producer_nodes: list[Node]
+) -> GraphModule:
+    r"""Construct a graph module from extracted producer nodes
+    from `collect_producer_nodes` function
+    Args:
+      root: the root module for the original graph
+      producer_nodes: a list of nodes we use to construct the graph
+    Return:
+      A graph module constructed from the producer nodes
+    """
+    assert len(producer_nodes) > 0, "list of producer nodes cannot be empty"
+    # since we traced back from node to getattr
+    producer_nodes.reverse()
+    graph = Graph()
+    env: dict[Any, Any] = {}
+
+    def load_arg(a):
+        return map_arg(a, lambda node: env[node])
+
+    for producer_node in producer_nodes:
+        env[producer_node] = graph.node_copy(producer_node, load_arg)
+    graph.output(load_arg(producer_nodes[-1]))
+    graph_module = GraphModule(root, graph)
+    return graph_module
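+
+
+def _example_producer_subgraph() -> None:
+    # Illustrative sketch (not called anywhere; the toy module below is assumed): extract the
+    # subgraph that produces a linear weight (get_attr -> mul) so that it can be evaluated on
+    # its own, e.g. for folding the computed weight into a constant.
+    class WeightScale(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.weight = torch.nn.Parameter(torch.randn(2, 2))
+
+        def forward(self, x):
+            return torch.nn.functional.linear(x, self.weight * 2.0)
+
+    traced = torch.fx.symbolic_trace(WeightScale())
+    linear_node = next(
+        n
+        for n in traced.graph.nodes
+        if n.op == "call_function" and n.target is torch.nn.functional.linear
+    )
+    producers = collect_producer_nodes(linear_node.args[1])
+    assert producers is not None  # the chain ends at a get_attr, not a placeholder
+    folded = graph_module_from_producer_nodes(traced, producers)
+    scaled_weight = folded()  # evaluates weight * 2.0 with no graph inputs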
+
+
+# TODO: delete
+def assert_and_get_unique_device(module: torch.nn.Module) -> Any:
+    """
+    Returns the unique device for a module, or None if no device is found.
+    Throws an error if multiple devices are detected.
+    """
+    return _assert_and_get_unique_device(module)
+
+
+def create_getattr_from_value(
+    module: torch.nn.Module, graph: Graph, prefix: str, value: Any
+) -> Node:
+    """
+    Given a value of any type, creates a getattr node corresponding to the value and
+    registers the value as a buffer to the module.
+    """
+    get_new_attr_name = get_new_attr_name_with_prefix(prefix)
+    attr_name = get_new_attr_name(module)
+    device = assert_and_get_unique_device(module)
+    new_value = (
+        value.detach().clone()
+        if isinstance(value, torch.Tensor)
+        else torch.tensor(value, device=device)
+    )
+    module.register_buffer(attr_name, new_value)
+    # Create get_attr with value
+    attr_node = graph.create_node("get_attr", attr_name)
+    return attr_node
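+
+
+def _example_create_getattr_from_value() -> None:
+    # Illustrative sketch (not called anywhere): register a scale value as a buffer on a
+    # module and create a get_attr node referencing it; the attribute name is generated
+    # from the prefix, e.g. "_scale0".
+    module = torch.nn.Module()
+    graph = Graph()
+    scale_node = create_getattr_from_value(module, graph, "_scale", 0.1)
+    assert scale_node.op == "get_attr"
+    assert hasattr(module, str(scale_node.target))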
+
+
+def all_node_args_have_no_tensors(
+    node: Node, modules: dict[str, torch.nn.Module], cache: dict[Node, bool]
+) -> bool:
+    """
+    If we know for sure that all of this node's args have no
+    tensors (are primitives), return True.  If we either
+    find a tensor or are not sure, return False. Note: this
+    function is not exact.
+    """
+    if cache and node in cache:
+        return cache[node]
+
+    result = False  # will be overwritten
+    if not isinstance(node, Node):
+        result = True
+    elif node.op == "placeholder":
+        result = False
+    elif node.op == "call_module":
+        assert isinstance(node.target, str)
+        if _is_activation_post_process(modules[node.target]):
+            result = all_node_args_have_no_tensors(node.args[0], modules, cache)  # type: ignore[arg-type]
+    elif node.op == "call_module":
+        result = False
+    elif node.op == "call_function" and node.target is operator.getitem:
+        result = all_node_args_have_no_tensors(node.args[0], modules, cache)  # type: ignore[arg-type]
+    elif node.op == "get_attr":
+        result = False
+    elif node.target is getattr and node.args[1] in ["ndim", "shape"]:
+        # x1 = x0.ndim
+        result = True
+    elif node.op == "call_method" and node.target == "size":
+        # x1 = x0.size(0)
+        result = True
+    else:
+        found_one_tensor = False
+        for arg in node.args:
+            if isinstance(arg, list):
+                for list_el in arg:
+                    if isinstance(list_el, Node):
+                        this_list_el_args_have_no_tensors = (
+                            all_node_args_have_no_tensors(list_el, modules, cache)
+                        )
+                        found_one_tensor = found_one_tensor or (
+                            not this_list_el_args_have_no_tensors
+                        )
+                        # If found_one_tensor is True, there is no point in
+                        # recursing further as the end result will always
+                        # be True.
+                        # TODO(future PR): remove this entire function  and
+                        # change to dtype inference without recursion.
+                        if found_one_tensor:
+                            result = not found_one_tensor
+                            if cache:
+                                cache[node] = result
+                            return result
+            elif isinstance(arg, int):
+                pass
+            else:
+                if isinstance(arg, Node):
+                    this_arg_args_have_no_tensors = all_node_args_have_no_tensors(
+                        arg, modules, cache
+                    )
+                    found_one_tensor = found_one_tensor or (
+                        not this_arg_args_have_no_tensors
+                    )
+                    # If found_one_tensor is True, there is no point in
+                    # recursing further as the end result will always
+                    # be True.
+                    # TODO(future PR): remove this entire function  and
+                    # change to dtype inference without recursion.
+                    if found_one_tensor:
+                        result = not found_one_tensor
+                        if cache:
+                            cache[node] = result
+                        return result
+                else:
+                    found_one_tensor = True
+            result = not found_one_tensor
+    if cache:
+        cache[node] = result
+    return result
+
+
+def all_node_args_except_first(node: Node) -> list[int]:
+    """
+    Returns all node arg indices after first
+    """
+    return list(range(1, len(node.args)))
+
+
+def return_arg_list(arg_indices: list[int]) -> Callable[[Node], list[int]]:
+    """
+    Constructs a function that takes a node as arg and returns the arg_indices
+    that are valid for node.args
+    """
+
+    def arg_indices_func(node: Node) -> list[int]:
+        return [i for i in arg_indices if i < len(node.args)]
+
+    return arg_indices_func
+
+
+NodeInfo = namedtuple("NodeInfo", "op target")
+
+# This dict identifies which arg indices of a node are non-tensors,
+# so that they can be propagated correctly; inserting observers
+# for them would cause errors.
+
+NON_OBSERVABLE_ARG_DICT: dict[
+    NodeInfo, dict[Union[type, torch.dtype], Callable[[Node], list[int]]]
+] = {
+    NodeInfo("call_method", "masked_fill"): {
+        torch.bool: return_arg_list([1]),
+        float: return_arg_list([2]),
+    },
+    NodeInfo("call_method", "permute"): {int: all_node_args_except_first},
+    NodeInfo("call_method", "repeat"): {int: all_node_args_except_first},
+    NodeInfo("call_method", "reshape"): {int: all_node_args_except_first},
+    NodeInfo("call_method", "size"): {int: return_arg_list([1])},
+    NodeInfo("call_method", "transpose"): {int: all_node_args_except_first},
+    NodeInfo("call_method", torch.transpose): {int: all_node_args_except_first},
+    NodeInfo("call_method", "unsqueeze"): {int: return_arg_list([1])},
+    NodeInfo("call_method", "unsqueeze_"): {int: return_arg_list([1])},
+    NodeInfo("call_method", torch.unsqueeze): {int: return_arg_list([1])},
+    NodeInfo("call_method", "view"): {int: all_node_args_except_first},
+}
+
+EMPTY_ARG_DICT: dict[Union[type, torch.dtype], Callable[[Node], list[int]]] = {}
+
+
+def get_non_observable_arg_indexes_and_types(
+    node: Node,
+) -> dict[Union[type, torch.dtype], Callable[[Node], list[int]]]:
+    """
+    Returns a dict mapping non-float-tensor argument types to functions that, given the
+    node, return the list of arg indices of that type which should not be observed
+    """
+    info = NodeInfo(node.op, node.target)
+
+    return NON_OBSERVABLE_ARG_DICT.get(info, EMPTY_ARG_DICT)
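+
+
+def _example_get_non_observable_arg_indexes_and_types() -> None:
+    # Illustrative sketch (not called anywhere): look up which args of a `masked_fill`
+    # call should not be observed (the bool mask and the float fill value).
+    g = Graph()
+    x = g.placeholder("x")
+    mask = g.placeholder("mask")
+    node = g.call_method("masked_fill", (x, mask, 0.0))
+    arg_info = get_non_observable_arg_indexes_and_types(node)
+    assert arg_info[torch.bool](node) == [1]  # the mask argument
+    assert arg_info[float](node) == [2]  # the fill value argument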
+
+
+def maybe_get_next_module(
+    node: Node,
+    modules: dict[str, nn.Module],
+    target_module_type: Optional[type[nn.Module]] = None,
+    target_functional_type: Any = None,
+) -> Optional[Node]:
+    """Gets the next module that matches what is needed in
+    is_target_module_type if it exists
+
+    Args:
+        node: The node whose users we want to look at
+        target_module_type: Module type that we want to check
+        target_functional_type: Functional type that we want to check
+    """
+
+    for user in node.users.keys():
+        if (
+            user.op == "call_module"
+            and target_module_type is not None
+            and isinstance(modules[str(user.target)], target_module_type)
+        ):
+            return user
+        elif (
+            user.op == "call_function"
+            and target_functional_type is not None
+            and user.target == target_functional_type
+        ):
+            return user
+
+    return None
+
+
+def create_node_from_old_node_preserve_meta(
+    quantized_graph: Graph,
+    create_node_args: tuple[Any, ...],
+    old_node: Node,
+) -> Node:
+    """
+    Creates `new_node` and copies the necessary metadata to it from `old_node`.
+    """
+    new_node = quantized_graph.create_node(*create_node_args)
+    new_node.stack_trace = old_node.stack_trace
+    return new_node
+
+
+def get_skipped_module_name_and_classes(
+    prepare_custom_config: PrepareCustomConfig, is_standalone_module: bool
+) -> tuple[list[str], list[type[Any]]]:
+    skipped_module_names = copy.copy(prepare_custom_config.non_traceable_module_names)
+    skipped_module_classes = copy.copy(
+        prepare_custom_config.non_traceable_module_classes
+    )
+    if not is_standalone_module:
+        # standalone module and custom module config are applied in top level module
+        skipped_module_names += list(
+            prepare_custom_config.standalone_module_names.keys()
+        )
+        skipped_module_classes += list(
+            prepare_custom_config.standalone_module_classes.keys()
+        )
+        skipped_module_classes += get_custom_module_class_keys(
+            prepare_custom_config.float_to_observed_mapping
+        )
+
+    return skipped_module_names, skipped_module_classes
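+
+
+def _example_get_skipped_module_name_and_classes() -> None:
+    # Illustrative sketch (not called anywhere; the module name is hypothetical): modules
+    # marked as non-traceable in the prepare config are returned so the QuantizationTracer
+    # can keep them as leaf modules.
+    config = PrepareCustomConfig().set_non_traceable_module_names(["submodule.rnn"])
+    names, classes = get_skipped_module_name_and_classes(config, is_standalone_module=False)
+    assert names == ["submodule.rnn"]
+    assert classes == []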
+
+
+def _is_custom_module_lstm(
+    node: Node,
+    named_modules: dict[str, torch.nn.Module],
+    qconfig: QConfigAny = None,
+    # QuantizeHandler, but we cannot include the type here due to circular imports
+    qhandler: Optional[Any] = None,
+) -> bool:
+    """
+    Return whether this refers to the custom module LSTM flow.
+    """
+    mod = _get_module(node, named_modules)
+    if qconfig is not None and qhandler is not None:
+        assert isinstance(
+            qhandler, torch.ao.quantization.fx.quantize_handler.QuantizeHandler
+        )  # type: ignore[attr-defined]
+        return (
+            isinstance(mod, torch.nn.LSTM)
+            and activation_is_statically_quantized(qconfig)
+            and qhandler.is_custom_module()
+        )
+    else:
+        return isinstance(mod, torch.ao.nn.quantizable.LSTM)
+
+
+def _is_custom_module_mha(
+    node: Node,
+    named_modules: dict[str, torch.nn.Module],
+    qconfig: QConfigAny = None,
+    # QuantizeHandler, but we cannot include the type here due to circular imports
+    qhandler: Optional[Any] = None,
+) -> bool:
+    """
+    Return whether this refers to the custom module MultiheadAttention flow.
+    """
+    mod = _get_module(node, named_modules)
+    if qconfig is not None and qhandler is not None:
+        assert isinstance(
+            qhandler, torch.ao.quantization.fx.quantize_handler.QuantizeHandler
+        )  # type: ignore[attr-defined]
+        return (
+            isinstance(mod, torch.nn.MultiheadAttention)
+            and activation_is_statically_quantized(qconfig)
+            and qhandler.is_custom_module()
+        )
+    else:
+        return isinstance(mod, torch.ao.nn.quantizable.MultiheadAttention)
+
+
+def _get_module(
+    node: Node, named_modules: dict[str, torch.nn.Module]
+) -> Optional[torch.nn.Module]:
+    """
+    If `node` refers to a call_module node, return the module, else None.
+    """
+    if node.op == "call_module" and str(node.target) in named_modules:
+        return named_modules[str(node.target)]
+    else:
+        return None
+
+
+def _insert_dequant_stub(
+    node: Node,
+    model: torch.nn.Module,
+    named_modules: dict[str, torch.nn.Module],
+    graph: Graph,
+) -> Node:
+    """
+    Attach a `DeQuantStub` to the model and create a node that calls this
+    `DeQuantStub` on the output of `node`, similar to how observers are inserted.
+    """
+    prefix = "dequant_stub_"
+    get_new_dequant_stub_name = get_new_attr_name_with_prefix(prefix)
+    dequant_stub_name = get_new_dequant_stub_name(model)
+    dequant_stub = DeQuantStub()
+    setattr(model, dequant_stub_name, dequant_stub)
+    named_modules[dequant_stub_name] = dequant_stub
+    with graph.inserting_after(node):
+        return graph.call_module(dequant_stub_name, (node,))
+
+
+def _insert_dequant_stubs_for_custom_module_lstm_output(
+    node: Node,
+    model: torch.nn.Module,
+    named_modules: dict[str, torch.nn.Module],
+    graph: Graph,
+) -> Node:
+    """
+    Insert DeQuantStubs after each internal output node of custom module LSTM.
+
+    Custom module LSTM outputs are nested tuples of the structure (output, (hidden0, hidden1)).
+    Since we cannot dequantize a tuple as a whole, we must first break down the tuple into its
+    components through `getitem`. This function transforms the graph as follows:
+
+      (1) Split the LSTM node into (output, (hidden0, hidden1))
+      (2) Insert a DeQuantStub after each internal node
+      (3) Recombine the DeQuantStubs into the same structure as before
+      (4) Reroute all consumers of the original LSTM node and its sub-nodes
+          (e.g. lstm[0])
+
+    Before:
+                   lstm_output
+                        |
+                        v
+                  original_user(s)
+    After:
+                   lstm_output
+                  /           \\
+                 /  (getitem)  \\
+                /               \\
+               v                 v
+             output            hidden
+               |               /   \\
+         (DeQuantStub)        (getitem)
+               |             /       \\
+               v            v         v
+           output_dq     hidden0    hidden1
+               |            |         |
+               |    (DeQuantStub) (DeQuantStub)
+               |            |         |
+               |            v         v
+               |      hidden0_dq  hidden1_dq
+               |            \\       /
+               |              (tuple)
+               |              \\   /
+               |               v  v
+               |             hidden_dq
+               \\               /
+                \\   (tuple)   /
+                 v            v
+                 lstm_output_dq
+                       |
+                       v
+                original_user(s)
+
+    For step (4), reroute all users of the original LSTM node(s) as follows:
+      lstm_output -> lstm_output_dq
+      lstm_output[0] -> output_dq
+      lstm_output[1] -> hidden_dq
+      lstm_output[1][0] -> hidden0_dq
+      lstm_output[1][1] -> hidden1_dq
+
+    Return the node `lstm_output_dq`.
+    """
+    # (1) Split the LSTM node into (output, (hidden0, hidden1))
+    # (2) Insert a DeQuantStub after each internal node
+    with graph.inserting_after(node):
+        output = graph.call_function(operator.getitem, (node, 0))
+        output_dq = _insert_dequant_stub(output, model, named_modules, graph)
+    with graph.inserting_after(output_dq):
+        hidden = graph.call_function(operator.getitem, (node, 1))
+    with graph.inserting_after(hidden):
+        hidden0 = graph.call_function(operator.getitem, (hidden, 0))
+        hidden0_dq = _insert_dequant_stub(hidden0, model, named_modules, graph)
+    with graph.inserting_after(hidden0_dq):
+        hidden1 = graph.call_function(operator.getitem, (hidden, 1))
+        hidden1_dq = _insert_dequant_stub(hidden1, model, named_modules, graph)
+
+    # (3) Recombine the DeQuantStubs into the same structure as before
+    with graph.inserting_after(hidden1_dq):
+        hidden_dq = graph.call_function(tuple, ([hidden0_dq, hidden1_dq],))
+    with graph.inserting_after(hidden_dq):
+        lstm_output_dq = graph.call_function(tuple, ([output_dq, hidden_dq],))
+
+    # (4) Reroute all consumers of the original LSTM node and its sub-nodes
+    for user in list(node.users.keys()):
+        if user != output and user != hidden:
+            user.replace_input_with(node, lstm_output_dq)
+    # The getitem and tuple nodes we added here may interfere with reference quantized
+    # pattern matching, so we need to redirect the consumers of internal nodes to the
+    # corresponding nodes with DeQuantStubs (e.g. lstm_output_dq[0] -> output_dq) attached,
+    # in order to preserve reference patterns like "dequantize - consumer - quantize".
+    _reroute_tuple_getitem_pattern(graph)
+    return lstm_output_dq
+
+
+def _maybe_get_custom_module_lstm_from_node_arg(
+    arg: Node,
+    named_modules: dict[str, torch.nn.Module],
+) -> Optional[Node]:
+    """
+    Given an argument of a node, if the argument refers to the path through which the node
+    is a consumer of custom module LSTM, return the custom module LSTM node, or None otherwise.
+
+    This is used to determine whether a node is a consumer of custom module LSTM, and, if so,
+    skip inserting input observers for this node. This is because custom module LSTM produces
+    quantized outputs, so inserting an input observer for the consumer of custom module LSTM
+    would unnecessarily quantize the outputs again.
+
+      lstm -> consumer
+
+    In practice, however, custom module LSTM outputs a tuple (output, (hidden0, hidden1)) with
+    DeQuantStubs attached to each internal node (see `_insert_dequant_stubs_for_custom_module_lstm_output`).
+    This tuple can be consumed in one of four ways:
+
+      lstm -> getitem -> DeQuantStub -> consumer                       # consume lstm[0]
+      lstm -> getitem -> getitem -> DeQuantStub -> tuple -> consumer   # consume lstm[1]
+      lstm -> getitem -> getitem -> DeQuantStub -> consumer            # consume lstm[1][0] or lstm[1][1]
+      lstm -> getitem -> DeQuantStub -> tuple -> consumer              # consume lstm
+
+    Thus, we must match against the above patterns instead of simply checking the parent node
+    to determine whether this node is a consumer of a custom module LSTM.
+    """
+
+    def match_dq(a):
+        return isinstance(_get_module(a, named_modules), DeQuantStub)
+
+    def match_lstm(a):
+        return _is_custom_module_lstm(a, named_modules)
+
+    def match_getitem(a):
+        return a.op == "call_function" and a.target == operator.getitem
+
+    def match_tuple(a):
+        return a.op == "call_function" and a.target == tuple
+
+    def _match_pattern(match_pattern: list[Callable]) -> Optional[Node]:
+        """
+        Traverse up the graph and match the args one by one.
+        If there is a match, return the last matched node, or None otherwise.
+        """
+        a = arg
+        for i, match in enumerate(match_pattern):
+            if not match(a):
+                return None
+            # Match next arg, for tuple the arg is a tuple of a list, e.g. ([dq_1, other_node],)
+            if i < len(match_pattern) - 1:
+                if match == match_tuple:
+                    a = a.args[0][0]  # type: ignore[assignment,index]
+                else:
+                    a = a.args[0]  # type: ignore[assignment]
+        return a
+
+    all_match_patterns = [
+        [match_dq, match_getitem, match_lstm],
+        [match_tuple, match_dq, match_getitem, match_getitem, match_lstm],
+        [match_dq, match_getitem, match_getitem, match_lstm],
+        [match_tuple, match_dq, match_getitem, match_lstm],
+    ]
+
+    for p in all_match_patterns:
+        matched_node = _match_pattern(p)
+        if matched_node is not None:
+            return matched_node
+    return None
+
+
+def _reroute_tuple_getitem_pattern(graph: Graph):
+    """
+    Search for patterns where N consecutive `tuple` call_function nodes are followed by
+    N consecutive `getitem` call_function nodes that are "reverses" of the `tuple` nodes.
+    If we find this pattern, reroute the consumers of the last `getitem` to skip these
+    N `tuple` and `getitem` nodes.
+
+    Before:
+
+        a   b     c
+        |   \\   /
+        \\   tuple
+         \\   /
+          tuple
+            |
+        getitem(1)
+            |
+        getitem(0)
+            |
+            d
+
+    After:
+
+        b
+        |
+        d
+    """
+
+    def find_patterns(
+        node: Node,
+        index_stack: list[int],
+        current_pattern: list[Node],
+        matched_patterns: list[list[Node]],
+        seen: set[tuple[Node, tuple[int, ...]]],
+    ):
+        """
+        Traverse the graph recursively to match for the N-tuple - N-getitem patterns,
+        starting at the given node.
+
+        We use a stack to keep track of the expected `getitem` indices, since these are
+        reversed from the `tuple` indices. In the above example, the stack after
+        (b -> tuple -> tuple) will be [0, 1], which will be popped by getitem(1) first
+        and then by getitem(0).
+
+        TODO: traverse upwards from the output and handle the case when tuple is not a
+        separate node, e.g. graph.call_function(operator.getitem, args=(a, (b, c)))
+        """
+        if len(index_stack) == 0 and len(current_pattern) > 0:
+            matched_patterns.append(copy.copy(current_pattern))
+            current_pattern.clear()
+
+        # Avoid duplicating work
+        state = (node, tuple(index_stack))
+        if state in seen:
+            return
+        seen.add(state)
+
+        # Iterate through users of this node to find tuple/getitem nodes to match
+        for user in node.users:
+            if user.op == "call_function" and user.target == tuple:
+                for i, user_arg in enumerate(user.args[0]):  # type: ignore[arg-type]
+                    if user_arg == node:
+                        index_stack.append(i)
+                        current_pattern.append(user)
+                        find_patterns(
+                            user, index_stack, current_pattern, matched_patterns, seen
+                        )
+            elif user.op == "call_function" and user.target == operator.getitem:
+                if len(index_stack) > 0:
+                    if user.args[1] == index_stack[-1]:
+                        index_stack.pop()
+                        current_pattern.append(user)
+                        find_patterns(
+                            user, index_stack, current_pattern, matched_patterns, seen
+                        )
+        return matched_patterns
+
+    # Collect all matched patterns
+    matched_patterns: list[list[Node]] = []
+    seen: set[tuple[Node, tuple[int, ...]]] = set()  # (node, index_stack)
+    for node in graph.nodes:
+        find_patterns(node, [], [], matched_patterns, seen)
+
+    # For each pattern, redirect all consumers of the last getitem node to the correct input
+    # of the first tuple node
+    for pattern in matched_patterns:
+        first_tuple = pattern[0]
+        last_getitem = pattern[-1]
+        assert first_tuple.op == "call_function" and first_tuple.target == tuple
+        assert (
+            last_getitem.op == "call_function"
+            and last_getitem.target == operator.getitem
+        )
+        last_getitem_index = last_getitem.args[1]
+        new_input = first_tuple.args[0][last_getitem_index]  # type: ignore[index]
+        for user in list(last_getitem.users.keys()):
+            user.replace_input_with(last_getitem, new_input)  # type: ignore[arg-type]
+
+
+def _get_observer_from_activation_post_process(
+    activation_post_process: Union[ObserverBase, FakeQuantizeBase],
+) -> ObserverBase:
+    """
+    If `activation_post_process` is an observer, return the observer.
+    If `activation_post_process` is a fake quantize, return the internal observer.
+    """
+    if isinstance(activation_post_process, ObserverBase):
+        return activation_post_process
+    else:
+        assert isinstance(activation_post_process, FakeQuantizeBase)
+        return activation_post_process.activation_post_process  # type: ignore[return-value]
+
+
+def _qconfig_satisfies_dtype_config_constraints(
+    qconfig: QConfigAny,
+    dtype_with_constraints: DTypeWithConstraints,
+    is_activation: bool = True,
+) -> bool:
+    """
+    Return whether `qconfig` satisfies the following constraints from the backend,
+    specified through the activation and weight DTypeWithConstraints.
+
+        1. QConfig specified a quantization range that falls within the backend's, if any
+        2. QConfig specified a min scale value that is >= the backend's, if any
+        3. QConfig specified a FixedQParamsObserver or FixedQParamsFakeQuantize that has
+           scale and zero point that match the backend's, if any
+
+    If `is_activation` is True, we check `qconfig.activation`, else we check `qconfig.weight`.
+    If `qconfig` or `dtype_with_constraints.dtype` is None, or the dtypes do not match, return True.
+    """
+
+    # TODO: log warnings only when the user enabled a debug flag
+    def _activation_post_process_satisfies_dtype_config_constraints(
+        activation_post_process: Union[ObserverBase, FakeQuantizeBase],
+        dtype_with_constraints: DTypeWithConstraints,
+        debug_string: str,
+    ) -> bool:
+        observer = _get_observer_from_activation_post_process(activation_post_process)
+        app_quant_min = getattr(observer, "quant_min", None)
+        app_quant_max = getattr(observer, "quant_max", None)
+        # TODO: for now, just use the existing eps value as scale_min. In the future, we should
+        # resolve the differences between the two, either by renaming eps or some other way
+        app_scale_min = getattr(observer, "eps", None)
+        backend_quant_min = dtype_with_constraints.quant_min_lower_bound
+        backend_quant_max = dtype_with_constraints.quant_max_upper_bound
+        backend_scale_min = dtype_with_constraints.scale_min_lower_bound
+        backend_scale_exact_match = dtype_with_constraints.scale_exact_match
+        backend_zero_point_exact_match = dtype_with_constraints.zero_point_exact_match
+        # check quantization ranges
+        if backend_quant_min is not None and backend_quant_max is not None:
+            if app_quant_min is None or app_quant_max is None:
+                warnings.warn(
+                    f"QConfig {debug_string} must specify 'quant_min' and 'quant_max', ignoring {qconfig}"
+                )
+                return False
+            elif app_quant_min < backend_quant_min or app_quant_max > backend_quant_max:
+                warnings.warn(
+                    f"QConfig {debug_string} quantization range must fall within the backend's:\n"
+                    f"QConfig range = ({app_quant_min}, {app_quant_max}), "
+                    f"BackendConfig range = ({backend_quant_min}, {backend_quant_max}), "
+                    f"ignoring {qconfig}"
+                )
+                return False
+        # check scale min
+        if backend_scale_min is not None:
+            if app_scale_min is None:
+                warnings.warn(
+                    f"QConfig {debug_string} must specify 'eps', ignoring {qconfig}"
+                )
+                return False
+            if app_scale_min < backend_scale_min:
+                warnings.warn(
+                    f"QConfig {debug_string} eps ({app_scale_min}) must be greater than or equal to "
+                    f"the backend's min scale value ({backend_scale_min}), ignoring {qconfig}"
+                )
+                return False
+        # check fixed scale and zero point
+        if (
+            backend_scale_exact_match is not None
+            and backend_zero_point_exact_match is not None
+        ):
+            # For tests only, accept the following qconfigs for now
+            # TODO: handle fp16 qconfigs properly
+            for accepted_qconfig in [float16_static_qconfig, float16_dynamic_qconfig]:
+                if qconfig_equals(qconfig, accepted_qconfig):
+                    return True
+            suggestion_str = (
+                "Please use torch.ao.quantization.get_default_qconfig_mapping or "
+                "torch.ao.quantization.get_default_qat_qconfig_mapping. Example:\n"
+                '    qconfig_mapping = get_default_qconfig_mapping("fbgemm")\n'
+                "    model = prepare_fx(model, qconfig_mapping, example_inputs)"
+            )
+            if not isinstance(
+                activation_post_process, FixedQParamsObserver
+            ) and not isinstance(activation_post_process, FixedQParamsFakeQuantize):
+                warnings.warn(
+                    f"QConfig must specify a FixedQParamsObserver or a FixedQParamsFakeQuantize "
+                    f"for fixed qparams ops, ignoring {qconfig}.\n{suggestion_str}"
+                )
+                return False
+            if (
+                observer.scale != backend_scale_exact_match
+                or observer.zero_point != backend_zero_point_exact_match
+            ):
+                warnings.warn(
+                    f"QConfig fixed scale ({observer.scale}) and zero point ({observer.zero_point}) "
+                    f"do not match the backend's ({backend_scale_exact_match} and {backend_zero_point_exact_match}), "
+                    f"ignoring {qconfig}.\n{suggestion_str}"
+                )
+                return False
+        return True
+
+    if qconfig is None or dtype_with_constraints.dtype is None:
+        return True
+
+    activation_post_process_ctr = (
+        qconfig.activation if is_activation else qconfig.weight
+    )
+    debug_string = "activation" if is_activation else "weight"
+    satisfies_constraints = True
+    if activation_post_process_ctr is not None:
+        activation_post_process = activation_post_process_ctr()
+        assert _is_activation_post_process(activation_post_process)
+        # If dtypes don't match, don't check the activation_post_process and return True early
+        if activation_post_process.dtype != dtype_with_constraints.dtype:
+            return True
+        satisfies_constraints = (
+            _activation_post_process_satisfies_dtype_config_constraints(
+                activation_post_process, dtype_with_constraints, debug_string
+            )
+        )
+    return satisfies_constraints
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__init__.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..871db9ba1c45685c039ed97f5b4c79c0c4ca8428
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/__init__.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/_numeric_debugger.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/_numeric_debugger.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b1d1cd5aeba58e93065073670cf2f2972cd8e6ff
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/_numeric_debugger.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/export_utils.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/export_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5ac733e419701368c3a455ca2235bff7fb0ccf73
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/export_utils.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/graph_utils.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/graph_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..66ecf54854fe6ccb453ccdaa4c731a88bbcc70f7
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/__pycache__/graph_utils.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/_affine_quantization.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/_affine_quantization.py
new file mode 100644
index 0000000000000000000000000000000000000000..d32a7267b45247961851496c1e9e2836ee7de10d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/_affine_quantization.py
@@ -0,0 +1,866 @@
+# copied from https://github.com/pytorch/ao/blob/main/torchao/quantization/observer.py
+# and https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py
+# PLEASE DON'T MODIFY THIS FILE SO THAT WE DON'T GET OUT OF SYNC
+import logging
+from abc import ABCMeta
+from typing import Any, Optional, Union
+
+import torch
+from torch.ao.quantization.observer import (
+    AffineQuantizedObserverBase,
+    get_block_size,
+    Granularity,
+    MappingType,
+    TorchAODType,
+    ZeroPointDomain,
+)
+
+
+ABC: Any = ABCMeta("ABC", (object,), {})  # compatible with Python 2 *and* 3:
+
+logger = logging.getLogger(__name__)
+
+FP8_TYPES = {
+    torch.float8_e4m3fn,
+    torch.float8_e5m2,
+    torch.float8_e4m3fnuz,
+    torch.float8_e5m2fnuz,
+}
+_SUB_BYTE_UINT_BOUNDS = {
+    torch.uint1: (0, 2**1 - 1),
+    torch.uint2: (0, 2**2 - 1),
+    torch.uint3: (0, 2**3 - 1),
+    torch.uint4: (0, 2**4 - 1),
+    torch.uint5: (0, 2**5 - 1),
+    torch.uint6: (0, 2**6 - 1),
+    torch.uint7: (0, 2**7 - 1),
+}
+
+"""
+Map from dtype to the (quant_min, quant_max) bounds of its integer values
+TODO: maybe we can replace this with a call to torch.iinfo
+"""
+_DTYPE_TO_QVALUE_BOUNDS: dict[Union[torch.dtype, TorchAODType], tuple[int, int]] = {
+    torch.uint8: (0, 255),
+    torch.int8: (-128, 127),
+    torch.int16: (-(2**15), 2**15 - 1),
+    torch.int32: (-(2**31), 2**31 - 1),
+}
+_DTYPE_TO_QVALUE_BOUNDS.update(_SUB_BYTE_UINT_BOUNDS)
+
+
+def _is_float8_type(dtype: torch.dtype) -> bool:
+    fp8_types = {
+        torch.float8_e4m3fn,
+        torch.float8_e4m3fnuz,
+        torch.float8_e5m2,
+        torch.float8_e5m2fnuz,
+    }
+    return dtype in fp8_types
+
+
+# TODO: decide on if we want to allow custom quant_min/quant_max here
+def _get_and_check_qmin_qmax(dtype, quant_min, quant_max):
+    """Get quant_min and quant_max args based on dtype and also
+    verify that they are within the range of possible quant_min/quant_max
+    for dtype
+    """
+    if dtype in FP8_TYPES:
+        quant_min_lower_bound, quant_max_upper_bound = (
+            torch.finfo(dtype).min,
+            torch.finfo(dtype).max,
+        )
+    elif dtype not in _DTYPE_TO_QVALUE_BOUNDS:
+        raise ValueError(f"Unsupported dtype: {dtype}")
+    else:
+        quant_min_lower_bound, quant_max_upper_bound = _DTYPE_TO_QVALUE_BOUNDS[dtype]
+    if quant_min is None:
+        quant_min = quant_min_lower_bound
+    if quant_max is None:
+        quant_max = quant_max_upper_bound
+
+    assert quant_min >= quant_min_lower_bound, (
+        "quant_min out of bound for dtype, "
+        f"quant_min_lower_bound: {quant_min_lower_bound} quant_min: {quant_min}"
+    )
+
+    assert quant_max <= quant_max_upper_bound, (
+        "quant_max out of bound for dtype, "
+        f"quant_max_upper_bound: {quant_max_upper_bound} quant_max: {quant_max}"
+    )
+    return quant_min, quant_max
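+
+
+def _example_get_and_check_qmin_qmax() -> None:
+    # Illustrative sketch (not called anywhere): bounds default to the dtype table above,
+    # and caller-provided bounds are accepted as long as they stay within that range.
+    assert _get_and_check_qmin_qmax(torch.int8, None, None) == (-128, 127)
+    assert _get_and_check_qmin_qmax(torch.int8, -127, 127) == (-127, 127)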
+
+
+def _get_reduction_params(block_size, input_size):
+    """Given block_size and input size find the parameters for reduction:
+
+    Output:
+        shape_for_reduction: the shape we use to `view` input to prepare it for reduction
+        reduction_dims: the dims we'll do reduction over
+
+    Example::
+        Input:
+          block_size: (3, 3, 2, 10)
+          input_size: (3, 3, 10, 10)
+
+        Output:
+          shape_for_reduction: (3, 3, 5, 2, 10)
+          reduction_dim: [0, 1, 3, 4]
+    """
+    assert len(block_size) == len(input_size)
+    shape_for_reduction = []
+    reduction_dims = []
+    cur_dim = 0
+    for i in range(len(block_size)):
+        if block_size[i] != input_size[i] and block_size[i] > 1:
+            assert input_size[i] % block_size[i] == 0, (
+                f"Expecting input size at {i} dimension: "
+                f"{input_size[i]} to be divisible by block_size at {i} dimension: {block_size[i]}"
+            )
+            shape_for_reduction.append(input_size[i] // block_size[i])
+            shape_for_reduction.append(block_size[i])
+            # reduce over the block_size[i] dim
+            reduction_dims.append(cur_dim + 1)
+            cur_dim += 2
+        else:
+            # block_size[i] == input_size[i] or block_size[i] == 1
+            shape_for_reduction.append(input_size[i])
+            # we only need to reduce over the dimension if block_size is greater than 1
+            # otherwise it's already the same as reduced dimension
+            if block_size[i] != 1:
+                reduction_dims.append(cur_dim)
+            cur_dim += 1
+    return shape_for_reduction, reduction_dims
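+
+
+def _example_get_reduction_params() -> None:
+    # Illustrative sketch (not called anywhere): reproduces the docstring example above.
+    shape_for_reduction, reduction_dims = _get_reduction_params(
+        block_size=(3, 3, 2, 10), input_size=(3, 3, 10, 10)
+    )
+    assert shape_for_reduction == [3, 3, 5, 2, 10]
+    assert reduction_dims == [0, 1, 3, 4]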
+
+
+def _register_custom_op(lib):
+    """This decorator is used to preserve some high level operators for torch.export.export
+    while still allow them to be decomposed for inductor path
+
+    requirement: make sure `fn.__name__[1:]` is the operator name you want to register
+
+    NOTE: This should be applied at the top, after all other decorators have been applied
+    NOTE: We haven't tested the case when `fn` accepts tensor subclass instance as input,
+    e.g. uint4 tensor subclass instance, and we'll probably need to figure out what would make
+    sense for downstream system (like executorch) to accept as well
+
+    Example:
+        lib = torch.library.Library("my_namespace', "FRAGMENT")
+
+        register_custom_op = _register_custom_op(lib)
+
+        @register_custom_op
+        def _the_op_that_needs_to_be_preserved(...):
+            ...
+
+        # after this, `_the_op_that_needs_to_be_preserved` will be preserved as
+        # torch.ops.my_namespace.the_op_that_needs_to_be_preserved operator after
+        # torch.export.export / torch._export.export_for_training
+
+    """
+    from torch._inductor.decomposition import register_decomposition
+
+    def decorator(fn):
+        from torch._library.infer_schema import infer_schema
+
+        # expecting fn.__name__ starts with `_` and we want to take the rest
+        # to be the name of the custom op
+        assert fn.__name__[0] == "_", (
+            f"Expecting function name starts with `_`, got {fn.__name__}"
+        )
+        assert not any(c in fn.__name__ for c in ".<>"), (
+            f"Expecting op to be defined in normal functions, not lambda or local: {fn.__name__}"
+        )
+        op_name = fn.__name__[1:]
+        schema = op_name + infer_schema(fn, mutates_args={})
+        lib.define(schema)
+        lib.impl(op_name, fn, "CompositeImplicitAutograd")
+
+        lib_namespace = lib.ns
+        op = getattr(getattr(torch.ops, lib_namespace), op_name)
+        register_decomposition([op])(fn)
+        return op
+
+    return decorator
+
+
+quant_lib = torch.library.Library("pt2e_quant", "FRAGMENT")  # noqa: TOR901
+
+register_custom_op = _register_custom_op(quant_lib)
+
+
+def choose_qparams_affine_with_min_max(
+    min_val: torch.Tensor,
+    max_val: torch.Tensor,
+    mapping_type: MappingType,
+    block_size: tuple[int, ...],
+    target_dtype: torch.dtype,
+    quant_min: Optional[int] = None,
+    quant_max: Optional[int] = None,
+    eps: Optional[float] = None,
+    scale_dtype: Optional[torch.dtype] = None,
+    zero_point_dtype: Optional[torch.dtype] = None,
+    preserve_zero: bool = True,
+    zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """A variant of :func:`~torchao.quantization.quant_primitives.choose_qparams_affine`
+    operator that pass in min_val and max_val directly instead of deriving these from a single input.
+    This is used for observers in static quantization where min_val and max_val may be obtained through
+    tracking all the data in calibration data set.
+
+    Args:
+      Mostly same as :func:`~torchao.quantization.quant_primitives.choose_qparams_affine`. with one
+      difference: instead of passing in `input` Tensor and use that to calculate min_val/max_val
+      and then scale/zero_point, we pass in min_val/max_val directly
+    """
+    return _choose_qparams_affine(
+        None,
+        mapping_type.name,
+        block_size,
+        target_dtype,
+        quant_min,
+        quant_max,
+        eps,
+        scale_dtype,
+        zero_point_dtype,
+        preserve_zero,
+        zero_point_domain.name if zero_point_domain is not None else None,
+        min_val,
+        max_val,
+    )
+
+
+@register_custom_op
+def _choose_qparams_affine(
+    input: Optional[torch.Tensor],
+    mapping_type: str,
+    block_size: list[int],
+    target_dtype: torch.dtype,
+    quant_min: Optional[Union[int, float, bool]] = None,
+    quant_max: Optional[Union[int, float, bool]] = None,
+    eps: Optional[float] = None,
+    scale_dtype: Optional[torch.dtype] = None,
+    zero_point_dtype: Optional[torch.dtype] = None,
+    preserve_zero: bool = True,
+    zero_point_domain: Optional[str] = "INT",
+    min_val: Optional[torch.Tensor] = None,
+    max_val: Optional[torch.Tensor] = None,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """op definition that has compatible signatures with custom op library
+
+    The op does the following:
+    1. figure out the dimension for reduction based on block_size
+    2. find min_val/max_val based on the dimension for reduction
+    3. calculate quantization parameters based on min_val/max_val based on args like `preserve_zero`
+       and `zero_point_domain`
+    """
+    quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
+    assert mapping_type in [
+        MappingType.SYMMETRIC.name,
+        MappingType.SYMMETRIC_NO_CLIPPING_ERR.name,
+        MappingType.ASYMMETRIC.name,
+    ], f"Unsupported mapping type: {mapping_type}"
+    if target_dtype in FP8_TYPES:
+        assert mapping_type == MappingType.SYMMETRIC.name, (
+            f"Only symmetric quantization is supported for FP8 types, got {mapping_type}"
+        )
+
+    if input is not None:
+        if scale_dtype is None:
+            scale_dtype = input.dtype
+        if zero_point_dtype is None:
+            zero_point_dtype = input.dtype
+        if eps is None:
+            eps = torch.finfo(input.dtype).eps
+
+        assert len(block_size) == input.dim(), (
+            f"Got input dim:{input.dim()}, block_size: {block_size}"
+        )
+        shape_for_reduction, reduction_dims = _get_reduction_params(
+            block_size, input.size()
+        )
+        input = input.view(shape_for_reduction)
+
+        min_val = torch.amin(input, dim=reduction_dims, keepdim=False)
+        max_val = torch.amax(input, dim=reduction_dims, keepdim=False)
+    else:
+        assert min_val is not None and max_val is not None, (
+            f"Need to provide `min_val` and `max_val` when `input` is None, got: {min_val, max_val}"
+        )
+        assert min_val.dtype == max_val.dtype, (
+            f"Expecting `min_val` and `max_val` to have the same dtype, got: {min_val.dtype, max_val.dtype}"
+        )
+
+        if scale_dtype is None:
+            scale_dtype = min_val.dtype
+        if zero_point_dtype is None:
+            zero_point_dtype = min_val.dtype
+        if eps is None:
+            eps = torch.finfo(min_val.dtype).eps
+
+    if preserve_zero:
+        min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
+        max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
+    else:
+        min_val_neg = min_val
+        max_val_pos = max_val
+
+    if (
+        mapping_type == MappingType.SYMMETRIC.name
+        or mapping_type == MappingType.SYMMETRIC_NO_CLIPPING_ERR.name
+    ):
+        # scales
+        if mapping_type == MappingType.SYMMETRIC.name:
+            max_val_pos = torch.max(-min_val_neg, max_val_pos)
+            scale = max_val_pos / (float(quant_max - quant_min) / 2)
+        else:
+            assert mapping_type == MappingType.SYMMETRIC_NO_CLIPPING_ERR.name
+            # Calculate smin and smax individually and choose the larger one. For example, if quant_min = -8 and
+            # quant_max = 7:
+            # - If smin is bigger: there would be coverage of negative values down to -8, and less rounding
+            #   error than the existing SYMMETRIC case.
+            # - If smax is bigger: it covers the positive values up to 7, and the rounding
+            #   error may be bigger than the existing SYMMETRIC case. Either way, there are no out-of-range
+            #   fp values after quantization.
+            smin = min_val_neg / float(quant_min)
+            smax = max_val_pos / float(quant_max)
+            mask = smin > smax
+            scale = torch.where(mask, smin, smax)
+        # zeros
+        if not preserve_zero:
+            raise ValueError(
+                "preserve_zero == False is not supported for symmetric quantization"
+            )
+        if (
+            zero_point_domain is not None
+            and zero_point_domain != ZeroPointDomain.INT.name
+        ):
+            raise ValueError(
+                "zero_point_domain != ZeroPointDomain.INT is not supported for symmetric quantization"
+            )
+        scale = torch.clamp(scale, min=eps)
+        zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2))
+    else:
+        assert mapping_type == MappingType.ASYMMETRIC.name
+        scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
+        scale = torch.clamp(scale, min=eps)
+        if zero_point_domain == ZeroPointDomain.NONE.name:
+            zero_point = None
+        else:
+            if preserve_zero:
+                zero_point = quant_min - torch.round(min_val_neg / scale)
+                zero_point = torch.clamp(zero_point, quant_min, quant_max)
+            else:
+                assert zero_point_domain == ZeroPointDomain.FLOAT.name, (
+                    "if not preserve_zero, zero_point must be in FLOAT domain"
+                )
+                mid_point = (quant_max + quant_min + 1) / 2
+                zero_point = min_val_neg + scale * mid_point
+
+    if zero_point is not None:
+        zero_point = zero_point.to(dtype=zero_point_dtype)
+    return scale.to(dtype=scale_dtype), zero_point
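+
+
+def _example_choose_qparams_affine_with_min_max() -> None:
+    # Illustrative sketch (not called anywhere): per-tensor asymmetric int8 qparams
+    # derived from an observed min/max range of [-1.0, 3.0].
+    min_val = torch.tensor(-1.0)
+    max_val = torch.tensor(3.0)
+    scale, zero_point = choose_qparams_affine_with_min_max(
+        min_val,
+        max_val,
+        MappingType.ASYMMETRIC,
+        block_size=(1,),
+        target_dtype=torch.int8,
+    )
+    # scale is roughly (3.0 - (-1.0)) / (127 - (-128)) and
+    # zero_point is roughly -128 - round(-1.0 / scale)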
+
+
+@torch.no_grad()
+def quantize_affine(
+    input: torch.Tensor,
+    block_size: tuple[int, ...],
+    scale: torch.Tensor,
+    zero_point: Optional[torch.Tensor],
+    output_dtype: torch.dtype,
+    quant_min: Optional[Union[int, float]] = None,
+    quant_max: Optional[Union[int, float]] = None,
+    zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
+) -> torch.Tensor:
+    """
+    Args:
+      input (torch.Tensor): original float32, float16 or bfloat16 Tensor
+      block_size: (Tuple[int, ...]): granularity of quantization,
+           i.e. the size of the block of tensor elements that share the same qparam,
+           e.g. when the block size equals the input tensor's shape, we are using per tensor quantization
+      scale (float): quantization parameter for affine quantization
+      zero_point (int): quantization parameter for affine quantization
+      output_dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor
+      quant_min (Optional[int]): minimum quantized value for output Tensor, if not specified, it will be derived from dtype
+      quant_max (Optional[int]): maximum quantized value for output Tensor, if not specified, it will be derived from dtype
+      zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float
+        if zero_point is in integer domain, zero point is added to the quantized integer value during
+        quantization
+        if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
+        value during quantization
+        default is ZeroPointDomain.INT
+
+    Note:
+      How can block_size represent different granularities?
+      let's say we have a Tensor of size: (3, 3, 10, 10), here is the table showing how block_size represents different
+      granularities:
+
+       granularity type       |     block_size
+         per_tensor           |    (3, 3, 10, 10)
+         per_axis (axis=0)    |    (1, 3, 10, 10)
+         per_axis (axis=1)    |    (3, 1, 10, 10)
+     per_group (groupsize=2)  |    (3, 3, 10, 2)
+     per_group (groupsize=2) for axis = 3 | (3, 3, 2, 10)
+
+
+    Output:
+      quantized tensor with requested dtype
+    """
+    return _quantize_affine(
+        input,
+        block_size,
+        scale,
+        zero_point,
+        output_dtype,
+        quant_min,
+        quant_max,
+        zero_point_domain.name if zero_point_domain is not None else None,
+    )
+
+
+@register_custom_op
+def _quantize_affine(
+    input: torch.Tensor,
+    block_size: list[int],
+    scale: torch.Tensor,
+    zero_point: Optional[torch.Tensor],
+    output_dtype: torch.dtype,
+    quant_min: Optional[Union[int, float, bool]] = None,
+    quant_max: Optional[Union[int, float, bool]] = None,
+    zero_point_domain: Optional[str] = ZeroPointDomain.INT.name,
+) -> torch.Tensor:
+    """op definition that has compatible signatures with custom op library
+
+    Note:
+        zero_point_domain is optional and specifies how we quantize the floating point value into quantized data:
+        INT: quantized_val = (float_val / scale) (integer) + zero_point (integer)
+        FLOAT: quantized_val = (float_val - (zero_point (float) - scale * mid_point)) / scale
+        None: quantized_val = (float_val / scale) | this is primarily used for floatx quantization,
+            where we do not want to round values to the nearest integer and instead scale and cast.
+    """
+    quant_min, quant_max = _get_and_check_qmin_qmax(output_dtype, quant_min, quant_max)
+    # workaround for uintx dtypes, since we don't have native Uintx dtype connected with
+    # torch.uintx dtypes yet
+    if output_dtype in _SUB_BYTE_UINT_BOUNDS:
+        output_dtype = torch.uint8
+    return _quantize_affine_no_dtype_cast(
+        input,
+        block_size,
+        scale,
+        zero_point,
+        quant_min,
+        quant_max,
+        zero_point_domain,
+    ).to(output_dtype)
+
+
+def _quantize_affine_no_dtype_cast(
+    input: torch.Tensor,
+    block_size: list[int],
+    scale: torch.Tensor,
+    zero_point: Optional[torch.Tensor],
+    quant_min: Union[int, float],
+    quant_max: Union[int, float],
+    zero_point_domain: Optional[str] = ZeroPointDomain.INT.name,
+) -> torch.Tensor:
+    """
+    The op does the following:
+    1. figure out the dimension for reduction based on block_size, also reshape the input to align with
+       the shape after reduction
+    2. quantize the input based on the quantization parameters scale and zero_point and args like zero_point_domain
+    3. reshape the quantized result to the original shape
+    """
+    # TODO: validations
+    # TODO: validate scale/zero_point dimensions are compatible with block_size
+    assert input.dtype in [
+        torch.float32,
+        torch.float16,
+        torch.bfloat16,
+    ], f"Unsupported input dtype: {input.dtype}"
+    assert len(block_size) == input.dim(), (
+        f"Got input dim:{input.dim()}, block_size: {block_size}"
+    )
+    shape_for_reduction, reduction_dims = _get_reduction_params(
+        block_size, input.size()
+    )
+    original_shape = input.shape
+    input = input.view(shape_for_reduction)
+    shape_after_reduction = shape_for_reduction
+    for i in reduction_dims:
+        shape_after_reduction[i] = 1
+    scale = scale.view(shape_after_reduction)
+    if zero_point is not None:
+        zero_point = zero_point.view(shape_after_reduction)
+
+    if zero_point_domain == ZeroPointDomain.INT.name:
+        quant = torch.clamp(
+            torch.round(input * (1.0 / scale)) + zero_point, quant_min, quant_max
+        )
+    elif zero_point_domain == ZeroPointDomain.NONE.name:
+        assert zero_point is None, (
+            "zero_point should be None when zero_point_domain is NONE"
+        )
+        quant = torch.clamp(torch.round(input * (1.0 / scale)), quant_min, quant_max)
+    elif zero_point_domain is None:
+        # This case handles quantization for float8; we expect no zero point and no zero point domain
+        assert zero_point is None, (
+            "zero_point should be None when zero_point_domain is None"
+        )
+        quant = torch.clamp(input * scale.reciprocal(), quant_min, quant_max)
+    else:
+        assert zero_point_domain == ZeroPointDomain.FLOAT.name
+        mid_point = (quant_max + quant_min + 1) / 2
+        min_val = zero_point - scale * mid_point
+        quant = torch.clamp(
+            torch.round((input - min_val) / scale), quant_min, quant_max
+        )
+    quant = quant.view(original_shape)
+
+    return quant
+
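+# Editor's note: a small worked example of the INT zero_point_domain path above
+# (assumed values; torch.round uses round-half-to-even):
+#
+#     x = 0.75, scale = 0.1, zero_point = 128, quant_min = 0, quant_max = 255
+#     round(0.75 / 0.1) + 128 = round(7.5) + 128 = 8 + 128 = 136, already within [0, 255]
+#
+# Dequantization (see `_dequantize_affine_no_dtype_check` below) recovers
+# (136 - 128) * 0.1 = 0.8, i.e. a quantization error of 0.05 for this value.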
+
+def dequantize_affine(
+    input: torch.Tensor,
+    block_size: tuple[int, ...],
+    scale: torch.Tensor,
+    zero_point: Optional[torch.Tensor],
+    input_dtype: torch.dtype,
+    quant_min: Optional[Union[int, float]] = None,
+    quant_max: Optional[Union[int, float]] = None,
+    zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
+    *,
+    output_dtype: torch.dtype = torch.float32,
+) -> torch.Tensor:
+    """
+    Args:
+      input (torch.Tensor): quantized tensor, its dtype should match the `input_dtype` argument
+      block_size (Tuple[int, ...]): granularity of quantization,
+        this means the size of the block of tensor elements that share the same qparam,
+        e.g. when block_size matches the input tensor shape, we are using per tensor quantization
+      scale (Tensor): quantization parameter for affine quantization
+      zero_point (Tensor): quantization parameter for affine quantization
+      input_dtype (torch.dtype): dtype (e.g. torch.uint8) of the quantized input Tensor
+      quant_min (Optional[int]): minimum quantized value for input Tensor
+      quant_max (Optional[int]): maximum quantized value for input Tensor
+      output_dtype (torch.dtype): dtype for output Tensor, default is fp32
+      zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float.
+        If zero_point is in the integer domain, zero point is added to the quantized integer value during
+        quantization.
+        If zero_point is in the floating point domain, zero point is subtracted from the floating point (unquantized)
+        value during quantization.
+        Default is ZeroPointDomain.INT.
+
+    Output:
+      dequantized Tensor, with requested dtype or fp32
+    """
+    return _dequantize_affine(
+        input,
+        block_size,
+        scale,
+        zero_point,
+        input_dtype,
+        quant_min,
+        quant_max,
+        zero_point_domain.name if zero_point_domain is not None else None,
+        output_dtype=output_dtype,
+    )
+
+
+@register_custom_op
+def _dequantize_affine(
+    input: torch.Tensor,
+    block_size: list[int],
+    scale: torch.Tensor,
+    zero_point: Optional[torch.Tensor],
+    input_dtype: torch.dtype,
+    quant_min: Optional[Union[int, float, bool]] = None,
+    quant_max: Optional[Union[int, float, bool]] = None,
+    zero_point_domain: Optional[str] = ZeroPointDomain.INT.name,
+    output_dtype: torch.dtype = torch.float32,
+) -> torch.Tensor:
+    """op definition that has compatible signatures with custom op library"""
+    # TODO: validate scale/zero_point dimensions are compatible with block_size
+    if input_dtype not in _SUB_BYTE_UINT_BOUNDS:
+        assert input.dtype == input_dtype, (
+            f"Expected: {input_dtype}, got: {input.dtype}"
+        )
+    assert output_dtype in [
+        torch.float32,
+        torch.float16,
+        torch.bfloat16,
+    ], f"Unsupported output dtype: {output_dtype}"
+    quant_min, quant_max = _get_and_check_qmin_qmax(input_dtype, quant_min, quant_max)
+    return _dequantize_affine_no_dtype_check(
+        input,
+        block_size,
+        scale,
+        zero_point,
+        quant_min,
+        quant_max,
+        zero_point_domain,
+        output_dtype,
+    )
+
+
+def _dequantize_affine_no_dtype_check(
+    input: torch.Tensor,
+    block_size: list[int],
+    scale: torch.Tensor,
+    zero_point: Optional[torch.Tensor],
+    quant_min: Union[int, float],
+    quant_max: Union[int, float],
+    zero_point_domain: Optional[str] = ZeroPointDomain.INT.name,
+    output_dtype: torch.dtype = torch.float32,
+) -> torch.Tensor:
+    """This function converts AQT tensors to their high precision floating point representation
+
+    The op does the following:
+    1. figure out the dimension for reduction based on block_size, also reshape the input to align with
+       the shape after reduction
+    2. dequantize the input based on the quantization parameters scale and zero_point and args like zero_point_domain
+    3. reshape the dequantized result to the original shape and change dtype to the output_dtype
+    """
+    assert len(block_size) == input.dim(), (
+        f"Got input dim:{input.dim()}, block_size: {block_size}"
+    )
+    shape_for_reduction, reduction_dims = _get_reduction_params(
+        block_size, input.size()
+    )
+    original_shape = input.shape
+    input = input.view(shape_for_reduction)
+    shape_after_reduction = shape_for_reduction
+    for i in reduction_dims:
+        shape_after_reduction[i] = 1
+    scale = scale.view(shape_after_reduction)
+
+    if zero_point is not None:
+        zero_point = zero_point.view(shape_after_reduction)
+
+    if zero_point_domain == ZeroPointDomain.INT.name:
+        # Force a copy to avoid input modification due
+        # to upcoming in-place operations.
+        dequant = input.to(torch.int32, copy=True)
+        if zero_point is not None:
+            dequant = dequant - zero_point.to(torch.int32)
+        dequant = dequant.to(output_dtype)
+        dequant = dequant * scale
+    elif zero_point_domain == ZeroPointDomain.NONE.name:
+        assert zero_point is None, (
+            "zero_point should be None when zero_point_domain is NONE"
+        )
+        dequant = input.to(output_dtype)
+        dequant = dequant * scale
+    elif zero_point_domain is None:
+        # This case handles dequantization for float8; we expect no zero point and no zero point domain
+        assert zero_point is None, (
+            "zero_point should be None when zero_point_domain is None"
+        )
+        assert _is_float8_type(input.dtype), (
+            f"dequantiztion with no zero point domain is only supported with FP8 types, got {input.dtype}"
+        )
+        dequant = input.to(output_dtype)
+        dequant = dequant * scale
+    else:
+        assert zero_point_domain == ZeroPointDomain.FLOAT.name, (
+            f"Unexpected zero point domain: {zero_point_domain}"
+        )
+        # TODO: this seems to be a detail for tinygemm (converting from uint to int, probably need to refactor this)
+        mid_point = (quant_max + quant_min + 1) / 2
+        # This should allocate new memory and avoid input modification
+        dequant = input - mid_point
+        dequant = dequant.to(output_dtype)
+        dequant *= scale
+        if zero_point is not None:
+            dequant += zero_point
+
+    return dequant.view(original_shape).to(output_dtype)
+
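+# Editor's note: a small worked example of the FLOAT zero_point_domain (tinygemm-style) path
+# above, with assumed uint4 bounds quant_min = 0, quant_max = 15 (so mid_point = 8):
+#
+#     scale = 0.5, zero_point = 1.0, quantized value q = 11
+#     dequant = (11 - 8) * 0.5 + 1.0 = 2.5
+#
+# which matches the quantization side in `_quantize_affine_no_dtype_cast`:
+# round((2.5 - (1.0 - 0.5 * 8)) / 0.5) = round(5.5 / 0.5) = 11.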
+
+class AffineQuantizedMinMaxObserver(AffineQuantizedObserverBase):
+    def forward(self, input: torch.Tensor):
+        if input.numel() == 0:
+            return input
+
+        input_detached = input.detach()
+        self.original_dtype = input_detached.dtype
+        assert self.granularity is not None, "granularity is None"
+        self.block_size = get_block_size(input_detached.shape, self.granularity)
+
+        shape_for_reduction, reduction_dims = _get_reduction_params(
+            self.block_size, input_detached.size()
+        )
+        input_detached = input_detached.view(shape_for_reduction)
+        min_val = torch.amin(input_detached, dim=reduction_dims, keepdim=False)
+        max_val = torch.amax(input_detached, dim=reduction_dims, keepdim=False)
+        if not hasattr(self, "min_val") or not hasattr(self, "max_val"):
+            self.min_val = min_val
+            self.max_val = max_val
+        else:
+            assert self.min_val.shape == min_val.shape, (
+                f"Can't update existing min_val - shape mismatch, self.min_val:{self.min_val.shape} != min_val:{min_val.shape}"
+            )
+            assert self.max_val.shape == max_val.shape, (
+                f"Can't update existing max_val - shape mismatch, self.max_val {self.max_val.shape} != max_val:{max_val.shape}"
+            )
+            min_val = torch.min(self.min_val, min_val)
+            max_val = torch.max(self.max_val, max_val)
+            self.min_val.copy_(min_val)
+            self.max_val.copy_(max_val)
+        # returning original input
+        return input
+
+    def calculate_qparams(self) -> tuple[torch.Tensor, torch.Tensor]:
+        assert hasattr(self, "min_val") and hasattr(self, "max_val"), (
+            "Expecting the observer has min_val and max_val, please run the observer before calling calculate_qparams"
+        )
+        return choose_qparams_affine_with_min_max(
+            self.min_val,
+            self.max_val,
+            self.mapping_type,
+            [],  # BlockSize is not needed because the min/max are already reduced
+            self.target_dtype,
+            self.quant_min,
+            self.quant_max,
+            self.eps,
+            self.scale_dtype,
+            self.zero_point_dtype,
+            self.preserve_zero,
+            self.zero_point_domain,
+        )
+
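+# Editor's note: a hedged usage sketch; MappingType and PerTensor are assumed to be the
+# mapping/granularity helpers imported at the top of this module, and the calibration data
+# is hypothetical:
+#
+#     obs = AffineQuantizedMinMaxObserver(
+#         MappingType.ASYMMETRIC, torch.uint8, granularity=PerTensor(),
+#         eps=torch.finfo(torch.float32).eps,
+#         scale_dtype=torch.float32, zero_point_dtype=torch.int32,
+#     )
+#     for batch in calibration_batches:   # hypothetical calibration data
+#         obs(batch)                      # forward() updates the running min/max
+#     scale, zero_point = obs.calculate_qparams()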
+
+class AffineQuantizedMovingAverageMinMaxObserver(AffineQuantizedObserverBase):
+    def __init__(
+        self,
+        mapping_type: MappingType,
+        target_dtype: torch.dtype,
+        granularity: Granularity,
+        averaging_constant=0.01,
+        quant_min: Optional[int] = None,
+        quant_max: Optional[int] = None,
+        eps: Optional[float] = None,
+        is_dynamic=False,
+        scale_dtype: Optional[torch.dtype] = None,
+        zero_point_dtype: Optional[torch.dtype] = None,
+        preserve_zero: bool = True,
+        zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
+        # there could be some extra args that are ignored
+        **kwargs,
+    ):
+        self.is_dynamic = is_dynamic
+        self.averaging_constant = averaging_constant
+        if is_dynamic and self.averaging_constant != 1:
+            raise NotImplementedError(
+                "MovingAverageMinMaxObserver doesn't support dynamic quantization for "
+                f"averaging constant of {self.averaging_constant}"
+            )
+
+        super().__init__(
+            mapping_type=mapping_type,
+            target_dtype=target_dtype,
+            granularity=granularity,
+            quant_min=quant_min,
+            quant_max=quant_max,
+            eps=eps,
+            scale_dtype=scale_dtype,
+            zero_point_dtype=zero_point_dtype,
+            preserve_zero=preserve_zero,
+            zero_point_domain=zero_point_domain,
+        )
+
+    def forward(self, input: torch.Tensor):
+        if input.numel() == 0:
+            return input
+
+        input_detached = input.detach()
+        self.original_dtype = input_detached.dtype
+        assert self.granularity is not None, "granularity is None"
+        self.block_size = get_block_size(input_detached.shape, self.granularity)
+
+        shape_for_reduction, reduction_dims = _get_reduction_params(
+            self.block_size, input_detached.size()
+        )
+        input_detached = input_detached.view(shape_for_reduction)
+        min_val = torch.amin(input_detached, dim=reduction_dims, keepdim=False)
+        max_val = torch.amax(input_detached, dim=reduction_dims, keepdim=False)
+        if not hasattr(self, "min_val") or not hasattr(self, "max_val"):
+            self.min_val = min_val
+            self.max_val = max_val
+        else:
+            assert self.min_val.shape == min_val.shape, (
+                f"Can't update existing min_val - shape mismatch, self.min_val:{self.min_val.shape} != min_val:{min_val.shape}"
+            )
+            assert self.max_val.shape == max_val.shape, (
+                f"Can't update existing max_val - shape mismatch, self.max_val {self.max_val.shape} != max_val:{max_val.shape}"
+            )
+            min_val = self.min_val + self.averaging_constant * (min_val - self.min_val)
+            max_val = self.max_val + self.averaging_constant * (max_val - self.max_val)
+            self.min_val.copy_(min_val)
+            self.max_val.copy_(max_val)
+
+        # returning original input
+        return input
+
+    def calculate_qparams(self) -> tuple[torch.Tensor, torch.Tensor]:
+        assert hasattr(self, "min_val") and hasattr(self, "max_val"), (
+            "Expecting the observer has min_val and max_val, please run the observer before calling calculate_qparams"
+        )
+
+        return choose_qparams_affine_with_min_max(
+            self.min_val,
+            self.max_val,
+            self.mapping_type,
+            [],  # BlockSize is not needed because the min/max are already reduced
+            self.target_dtype,
+            self.quant_min,
+            self.quant_max,
+            self.eps,
+            self.scale_dtype,
+            self.zero_point_dtype,
+            self.preserve_zero,
+            self.zero_point_domain,
+        )
+
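+# Editor's note: the update above is an exponential moving average. For example, with
+# averaging_constant = 0.01, an observed batch max of 6.0 and a running max of 4.0:
+#
+#     max_val_new = 4.0 + 0.01 * (6.0 - 4.0) = 4.02
+#
+# so the recorded range drifts slowly toward the statistics of recent batches.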
+
+class AffineQuantizedPlaceholderObserver(AffineQuantizedObserverBase):
+    def __init__(
+        self,
+        mapping_type: MappingType,
+        target_dtype: torch.dtype,
+        granularity: Granularity,
+        quant_min: Optional[int] = None,
+        quant_max: Optional[int] = None,
+        eps: Optional[float] = None,
+        is_dynamic=False,
+        scale_dtype: Optional[torch.dtype] = None,
+        zero_point_dtype: Optional[torch.dtype] = None,
+        preserve_zero: bool = True,
+        zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
+        # there could be some extra args that are ignored
+        **kwargs,
+    ):
+        self.is_dynamic = is_dynamic
+
+        super().__init__(
+            mapping_type=mapping_type,
+            target_dtype=target_dtype,
+            granularity=granularity,
+            quant_min=quant_min,
+            quant_max=quant_max,
+            eps=eps,
+            scale_dtype=scale_dtype,
+            zero_point_dtype=zero_point_dtype,
+            preserve_zero=preserve_zero,
+            zero_point_domain=zero_point_domain,
+        )
+
+    def forward(self, input):
+        self.block_size = get_block_size(input.shape, self.granularity)
+        self.original_dtype = input.dtype
+        return input
+
+    def calculate_qparams(self):
+        raise Exception(  # noqa: TRY002
+            "calculate_qparams should not be called for PlaceholderObserver"
+        )
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/_numeric_debugger.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/_numeric_debugger.py
new file mode 100644
index 0000000000000000000000000000000000000000..81c6a2060e76b42b42c2d8b771e7122da5485432
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/_numeric_debugger.py
@@ -0,0 +1,342 @@
+import copy
+import logging
+from collections.abc import Sequence
+from dataclasses import dataclass
+from typing import Callable, Optional
+
+import torch
+from torch.ao.ns.fx.utils import compute_sqnr
+from torch.ao.quantization.pt2e.graph_utils import bfs_trace_with_node_process
+from torch.export import ExportedProgram
+from torch.fx import GraphModule, Node
+from torch.nn import functional as F
+
+
+NUMERIC_DEBUG_HANDLE_KEY = "numeric_debug_handle"
+CUSTOM_KEY = "custom"
+
+log = logging.getLogger(__name__)
+
+
+def generate_numeric_debug_handle(ep: ExportedProgram) -> None:
+    """
+    Attach numeric_debug_handle_id to all nodes in the graph module of the given
+    ExportedProgram, like conv2d, squeeze, conv1d, etc., except for placeholders.
+    Note that nodes like getattr are out of scope since they are not in the graph.
+
+    The graph nodes of the input exported program are modified in place.
+
+    Here's an example of using debug handle quantize flow::
+
+        ep = export_for_training(eager_model, example_inputs)
+        generate_numeric_debug_handle(ep)
+
+        m = ep.module()
+        quantizer = XNNPACKQuantizer()
+        m = prepare_pt2e(m, quantizer)
+        m = convert_pt2e(m)
+    """
+
+    # Sanity check the input data type
+    if not isinstance(ep, ExportedProgram):
+        raise ValueError(
+            f"Expected ep to be ExportedProgram, got {type(ExportedProgram)}"
+        )
+
+    unique_id = 0
+
+    def _find_max_id(node: torch.fx.Node) -> None:
+        nonlocal unique_id
+        unique_id = max(
+            unique_id, node.meta.get(CUSTOM_KEY, {}).get(NUMERIC_DEBUG_HANDLE_KEY, 0)
+        )
+
+    def _assign_debug_handle(node: torch.fx.Node) -> None:
+        nonlocal unique_id
+        if CUSTOM_KEY not in node.meta:
+            node.meta[CUSTOM_KEY] = {}
+
+        if NUMERIC_DEBUG_HANDLE_KEY not in node.meta[CUSTOM_KEY]:
+            node.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] = unique_id
+            unique_id += 1
+
+    # Find the max ID that exists in the graph first, in case part of the graph
+    # has already been annotated. This way we guarantee there are no duplicate
+    # handle IDs.
+    bfs_trace_with_node_process(ep, _find_max_id)
+
+    unique_id += 1
+
+    # Assign debug handles to all nodes in the graph that don't have one based on the
+    # max ID found in the previous step.
+    bfs_trace_with_node_process(ep, _assign_debug_handle)
+
+
+def _detach(x: object) -> object:
+    detached: object = None
+    if isinstance(x, torch.Tensor):
+        detached = x.detach()
+    elif isinstance(x, (list, tuple)):
+        detached = type(x)([_detach(e) for e in x])
+    elif isinstance(x, dict):
+        detached = {k: _detach(e) for k, e in x.items()}
+    else:
+        detached = x
+    return detached
+
+
+def _tensor_shape_equals(x: object, y: object) -> bool:
+    if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
+        return x.shape == y.shape
+    elif isinstance(x, (list, tuple)) and isinstance(y, (list, tuple)):
+        return all(_tensor_shape_equals(e1, e2) for e1, e2 in zip(x, y))
+    elif isinstance(x, dict) and isinstance(y, dict):
+        all_equal = True
+        for k in x:
+            all_equal = all_equal and k in y and (_tensor_shape_equals(x[k], y[k]))
+        return all_equal
+    else:
+        log.debug("Comparing non Tensors: %s and %s, they must be equal", x, y)
+        return type(x) == type(y) and x == y
+
+
+def _loss_fn(
+    loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], x: object, y: object
+) -> object:
+    """The returned loss will have the same structure as `x` and `y`, e.g.
+    if both are Tensor, we'll return a Tensor
+    if both are list, we'll return a list of Tensors
+    if both are dict, we'll return a dict with the same key, and value being the loss between the
+    two Tensors
+    """
+    if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
+        return loss(x.to(torch.float32), y.to(torch.float32))
+    elif isinstance(x, (list, tuple)) and isinstance(y, (list, tuple)):
+        return type(x)([_loss_fn(loss, e1, e2) for e1, e2 in zip(x, y)])
+    elif isinstance(x, dict) and isinstance(y, dict):
+        return {k: _loss_fn(loss, e, y[k]) for k, e in x.items()}
+    else:
+        return None
+
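+# Editor's note: an illustrative call (assumed tensors), showing how the result mirrors the
+# input structure:
+#
+#     _loss_fn(F.mse_loss, {"a": torch.ones(2)}, {"a": torch.zeros(2)})
+#     # -> {"a": tensor(1.)}   since mean((1 - 0) ** 2) == 1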
+
+class OutputLogger(torch.nn.Module):
+    """
+    Base class for capturing output values for nodes in a GraphModule, it only captures
+    Tensor output currently, but we can extend it to work for other types of inputs later if needed
+    """
+
+    # Mark as impure so that calls to it will not be removed during DCE.
+    _is_impure = True
+
+    def __init__(
+        self,
+        debug_handle: int,
+        node_name: Optional[str] = None,
+        nn_module_stack: Optional[object] = None,
+    ) -> None:
+        super().__init__()
+        self.node_name = node_name
+        self.nn_module_stack = nn_module_stack
+        self.debug_handle = debug_handle
+        self.stats: list[object] = []
+
+    def forward(self, x: object) -> object:
+        self.stats.append(_detach(x))
+        return x
+
+    def __extra_repr__(self) -> str:
+        return (
+            f"debug_handle={self.debug_handle}, node_name={self.node_name}, "
+            f"nn_module_stack={self.nn_module_stack}, num_stats={len(self.stats)})"
+        )
+
+
+def _insert_logger(model: GraphModule, node: Node, debug_handle: int) -> Node:
+    """For a given node, adds an OutputLogger that observes the output of that node,
+    and all its users use the OutputLogger output instead.
+    The OutputLogger will contain the debug_handle which can be used to compare
+    graphs after transforms"""
+
+    # to avoid circular dep
+    from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
+
+    # add a logger after the node
+    with model.graph.inserting_after(node):
+        get_new_attr_name = get_new_attr_name_with_prefix(f"{node.name}_logger")
+        logger_name = get_new_attr_name(model)
+        setattr(
+            model,
+            logger_name,
+            OutputLogger(debug_handle, node.name, node.meta.get("nn_module_stack")),
+        )
+        logger_node = model.graph.call_module(logger_name, (node,), {})
+
+    orig_users = list(node.users.keys())
+    for user_node in orig_users:
+        if user_node is logger_node:
+            continue
+        user_node.replace_input_with(node, logger_node)
+
+    return logger_node
+
+
+def prepare_for_propagation_comparison(model: GraphModule) -> GraphModule:
+    """Add output loggers to node that has numeric_debug_handle
+
+    Args:
+        model (GraphModule): original model
+    Returns:
+        a model with output loggers for all nodes that have numeric_debug_handle_id
+    """
+    # don't change the original model
+    model = copy.deepcopy(model)
+    for n in model.graph.nodes:
+        if (
+            CUSTOM_KEY not in n.meta
+            or NUMERIC_DEBUG_HANDLE_KEY not in n.meta[CUSTOM_KEY]
+        ):
+            continue
+        numeric_debug_handle = n.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY]
+        _insert_logger(model, n, numeric_debug_handle)
+
+    model.recompile()
+    return model
+
+
+@dataclass(frozen=True)
+class QuantizationComparisonResult:
+    actual: torch.Tensor
+    ref: torch.Tensor
+
+    @property
+    def mse_loss(self) -> object:
+        return self.loss(F.mse_loss)
+
+    @property
+    def sqnr(self) -> object:
+        return self.loss(compute_sqnr)
+
+    def loss(
+        self, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
+    ) -> object:
+        return _loss_fn(loss_function, self.actual, self.ref)
+
+    def __repr__(self) -> str:
+        # Don't include the tensors themselves as they are quite large to print
+        # out.
+        return (
+            f"QuantizationComparisonResult(mse_loss={self.mse_loss}, sqnr={self.sqnr})"
+        )
+
+    def __post_init__(self) -> None:
+        if not isinstance(self.actual, (torch.Tensor, list, tuple, dict)):
+            raise ValueError(
+                f"`self.actual` value must be a Tensor, list, tuple or dict, got: {self.actual}"
+            )
+
+        if not isinstance(self.ref, (torch.Tensor, list, tuple, dict)):
+            raise ValueError(
+                f"`self.ref` value must be a Tensor, list, tuple or dict, got: {self.ref}"
+            )
+
+        if not _tensor_shape_equals(self.ref, self.actual):
+            raise ValueError(
+                f"Cannot compare tensors with different shapes: ref={self.ref} vs actual={self.actual}"
+            )
+
+
+@dataclass(frozen=True)
+class NodeAccuracySummary:
+    handle: int
+    actual_node_name: str
+    actual_module_stack: str
+    ref_node_name: str
+    ref_module_stack: str
+    results: Sequence[QuantizationComparisonResult]
+
+
+def _module_stack_to_str(module_stack: object) -> str:
+    """Simplifies the stack from ("mod", "mod.foo", "mod.foo.0", "mod.foo.0.linear")
+    to "mod.foo.0.linear"
+    """
+    if not isinstance(module_stack, dict):
+        return str(module_stack)
+    module_values_list = list(module_stack.values())
+    if len(module_values_list) > 0:
+        owning_module = module_values_list[-1][0]
+        return str(owning_module)
+    else:
+        return str(module_stack)
+
+
+def extract_results_from_loggers(
+    model: GraphModule,
+) -> dict[int, tuple[Optional[str], object, list[object]]]:
+    """For a given model, extract the tensors stats and related information for each debug handle.
+    The reason we have a list of object, instead of Tensor is because the output of node may not be
+    a Tensor, it could be (nested) list, tuple or dict as well.
+
+    Returns:
+        A dict is keyed by the debug_handle id and the values are a list of object recorded
+        in loggers
+
+    """
+    # Results maps debug handle to a tensor list for each model being compared.
+    handles: dict[int, tuple[Optional[str], object, list[object]]] = {}
+    for _name, module in model.named_children():
+        if isinstance(module, OutputLogger) and len(module.stats) > 0:
+            handles[module.debug_handle] = (
+                module.node_name,
+                module.nn_module_stack,
+                module.stats,
+            )
+
+    return handles
+
+
+def compare_results(
+    ref_results: dict[int, tuple[Optional[str], object, list[torch.Tensor]]],
+    actual_results: dict[int, tuple[Optional[str], object, list[torch.Tensor]]],
+) -> dict[int, NodeAccuracySummary]:
+    """Given two dict mapping from `debug_handle_id` (int) to list of tensors
+    return a map from `debug_handle_id` to `NodeAccuracySummary` that contains
+    comparison information like SQNR, MSE etc.
+
+    Args:
+        ref_results (Dict[int, Tuple[str, object, List[torch.Tensor]]]): reference results for each debug_handle_id
+        actual_results (Dict[int, Tuple[str, object, List[torch.Tensor]]]): actual results for each debug_handle_id
+
+    Returns:
+        Dict[int, NodeAccuracySummary]
+    """
+    comparisons = {}
+    for debug_handle, (ref_name, ref_stack, ref_stats) in ref_results.items():
+        if debug_handle not in actual_results:
+            log.debug(
+                "Cannot compare for handle %s because it wasn't found in the transformed model",
+                debug_handle,
+            )
+            continue
+        actual_name, actual_stack, actual_stats = actual_results[debug_handle]
+        try:
+            results = [
+                QuantizationComparisonResult(actual=a, ref=b)
+                for a, b in zip(actual_stats, ref_stats)
+            ]
+        except Exception as e:
+            # Add extra information for an exception from QuantizationComparisonResult
+            # if the shapes didn't match, to include the handle and the node names.
+            raise ValueError(
+                f"For numeric_debug_handle={debug_handle} from ref node {ref_name} and actual node {actual_name}"
+            ) from e
+
+        comparisons[debug_handle] = NodeAccuracySummary(
+            handle=debug_handle,
+            actual_node_name=actual_name or "",
+            actual_module_stack=_module_stack_to_str(actual_stack),
+            ref_node_name=ref_name or "",
+            ref_module_stack=_module_stack_to_str(ref_stack),
+            results=results,
+        )
+
+    return comparisons
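+
+
+# Editor's note: a hedged end-to-end sketch of the comparison workflow built from the helpers
+# above (eager_model, example_inputs, quantizer and the prepare/convert_pt2e calls are assumed
+# user code from the PT2E flow):
+#
+#     ep = torch.export.export_for_training(eager_model, example_inputs)
+#     generate_numeric_debug_handle(ep)
+#     ref_model = prepare_for_propagation_comparison(ep.module())
+#     ref_model(*example_inputs)
+#
+#     prepared = prepare_pt2e(ep.module(), quantizer)
+#     prepared(*example_inputs)              # calibration
+#     quant_model = prepare_for_propagation_comparison(convert_pt2e(prepared))
+#     quant_model(*example_inputs)
+#
+#     summaries = compare_results(
+#         extract_results_from_loggers(ref_model),
+#         extract_results_from_loggers(quant_model),
+#     )
+#     # each NodeAccuracySummary then exposes per-output mse_loss and sqnr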
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/duplicate_dq_pass.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/duplicate_dq_pass.py
new file mode 100644
index 0000000000000000000000000000000000000000..163184c00f1d1ddfba371ceb5a0948ff863e9183
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/duplicate_dq_pass.py
@@ -0,0 +1,82 @@
+# mypy: allow-untyped-defs
+import logging
+import operator
+
+import torch
+from torch.ao.quantization.pt2e.utils import (
+    _filter_sym_size_users,
+    _is_valid_annotation,
+)
+from torch.fx.node import map_arg
+from torch.fx.passes.infra.pass_base import PassBase, PassResult
+
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.WARNING)
+
+__all__ = ["DuplicateDQPass"]
+
+_QUANTIZE_OPS = [
+    torch.ops.quantized_decomposed.quantize_per_tensor.default,
+    torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
+    torch.ops.quantized_decomposed.quantize_per_channel.default,
+]
+
+_DEQUANTIZE_OPS = [
+    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
+    torch.ops.quantized_decomposed.dequantize_per_channel.default,
+]
+
+
+def _maybe_duplicate_dq(
+    gm: torch.fx.GraphModule, dq_node: torch.fx.Node, user: torch.fx.Node
+):
+    annotation = user.meta.get("quantization_annotation", None)
+    if not _is_valid_annotation(annotation):  # type: ignore[arg-type]
+        return
+    with gm.graph.inserting_after(dq_node):
+        new_node = gm.graph.node_copy(dq_node)
+
+        def maybe_replace_node(n: torch.fx.Node) -> torch.fx.Node:
+            if n == dq_node:
+                return new_node
+            else:
+                return n
+
+        new_args = map_arg(user.args, maybe_replace_node)
+        new_kwargs = map_arg(user.kwargs, maybe_replace_node)
+        user.args = new_args  # type: ignore[assignment]
+        user.kwargs = new_kwargs  # type: ignore[assignment]
+
+
+class DuplicateDQPass(PassBase):
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        for node in graph_module.graph.nodes:
+            if node.op == "call_function" and node.target in _DEQUANTIZE_OPS:
+                dq_users = _filter_sym_size_users(node)
+                if len(dq_users) <= 1:
+                    continue
+                # Do not duplicate dq for dynamic quantization
+                # Pattern: choose_qparam - getitem - q - dq
+                q_node = node.args[0]
+                if q_node.op == "call_function" and q_node.target in _QUANTIZE_OPS:
+                    getitem_node = q_node.args[1]
+                    if (
+                        isinstance(getitem_node, torch.fx.node.Node)
+                        and getitem_node.op == "call_function"
+                        and getitem_node.target == operator.getitem
+                    ):
+                        choose_qparam_node = getitem_node.args[0]
+                        if (
+                            isinstance(choose_qparam_node, torch.fx.node.Node)
+                            and choose_qparam_node.op == "call_function"
+                            and choose_qparam_node.target
+                            == torch.ops.quantized_decomposed.choose_qparams.tensor
+                        ):
+                            continue
+                for user in dq_users:
+                    _maybe_duplicate_dq(graph_module, node, user)
+        graph_module.graph.eliminate_dead_code()
+        graph_module.recompile()
+        return PassResult(graph_module, True)
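+
+
+# Editor's note: a usage sketch; PassBase.__call__ runs .call() and returns a PassResult
+# (graph_module here is assumed to be an annotated, quantized FX GraphModule):
+#
+#     result = DuplicateDQPass()(graph_module)
+#     graph_module = result.graph_module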
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/export_utils.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/export_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..70cca73dd00dcb4bd865dda4f2718a610496323e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/export_utils.py
@@ -0,0 +1,240 @@
+# mypy: allow-untyped-defs
+import types
+
+import torch
+import torch.nn.functional as F
+from torch.ao.quantization.utils import _assert_and_get_unique_device
+
+
+__all__ = [
+    "model_is_exported",
+]
+
+_EXPORTED_TRAINING_ATTR = "_exported_training"
+
+
+class _WrapperModule(torch.nn.Module):
+    """Class to wrap a callable in an :class:`torch.nn.Module`. Use this if you
+    are trying to export a callable.
+    """
+
+    def __init__(self, fn):
+        super().__init__()
+        self.fn = fn
+
+    def forward(self, *args, **kwargs):
+        """Simple forward that just calls the ``fn`` provided to :meth:`WrapperModule.__init__`."""
+        return self.fn(*args, **kwargs)
+
+
+def model_is_exported(m: torch.nn.Module) -> bool:
+    """
+    Return True if the `torch.nn.Module` was exported, False otherwise
+    (e.g. if the model was FX symbolically traced or not traced at all).
+    """
+    return isinstance(m, torch.fx.GraphModule) and any(
+        "val" in n.meta for n in m.graph.nodes
+    )
+
+
+def _replace_dropout(m: torch.fx.GraphModule, train_to_eval: bool):
+    """
+    Switch dropout patterns in the model between train and eval modes.
+
+    Dropout has different behavior in train vs eval mode. For exported models,
+    however, calling `model.train()` or `model.eval()` does not automatically switch
+    the dropout behavior between the two modes, so here we need to rewrite the aten
+    dropout patterns manually to achieve the same effect.
+
+    See https://github.com/pytorch/pytorch/issues/103681.
+    """
+    # Avoid circular dependencies
+    from .utils import _get_aten_graph_module_for_pattern
+
+    # Needed to ensure subgraph matches are self-contained
+    m.graph.eliminate_dead_code()
+    m.recompile()
+
+    for inplace in [False, True]:
+
+        def dropout_train(x):
+            return F.dropout(x, p=0.5, training=True, inplace=inplace)
+
+        def dropout_eval(x):
+            return F.dropout(x, p=0.5, training=False, inplace=inplace)
+
+        example_inputs = (torch.randn(1),)
+        if train_to_eval:
+            match_pattern = _get_aten_graph_module_for_pattern(
+                _WrapperModule(dropout_train),
+                example_inputs,
+            )
+            replacement_pattern = _get_aten_graph_module_for_pattern(
+                _WrapperModule(dropout_eval),
+                example_inputs,
+            )
+        else:
+            match_pattern = _get_aten_graph_module_for_pattern(
+                _WrapperModule(dropout_eval),
+                example_inputs,
+            )
+            replacement_pattern = _get_aten_graph_module_for_pattern(
+                _WrapperModule(dropout_train),
+                example_inputs,
+            )
+
+        from torch.fx.subgraph_rewriter import replace_pattern_with_filters
+
+        replace_pattern_with_filters(
+            m,
+            match_pattern,
+            replacement_pattern,
+            match_filters=[],
+            ignore_literals=True,
+        )
+        m.recompile()
+
+
+def _replace_batchnorm(m: torch.fx.GraphModule, train_to_eval: bool):
+    """
+    Switch batchnorm patterns in the model between train and eval modes.
+
+    Batchnorm has different behavior in train vs eval mode. For exported models,
+    however, calling `model.train()` or `model.eval()` does not automatically switch
+    the batchnorm behavior between the two modes, so here we need to rewrite the aten
+    batchnorm patterns manually to achieve the same effect.
+    """
+    # TODO(Leslie): This function still fails to support custom momentum and eps value.
+    # Enable this support in future updates.
+
+    # Avoid circular dependencies
+    from .utils import _get_aten_graph_module_for_pattern
+
+    # Needed to ensure subgraph matches are self-contained
+    m.graph.eliminate_dead_code()
+    m.recompile()
+
+    def bn_train(
+        x: torch.Tensor,
+        bn_weight: torch.Tensor,
+        bn_bias: torch.Tensor,
+        bn_running_mean: torch.Tensor,
+        bn_running_var: torch.Tensor,
+    ):
+        return F.batch_norm(
+            x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True
+        )
+
+    def bn_eval(
+        x: torch.Tensor,
+        bn_weight: torch.Tensor,
+        bn_bias: torch.Tensor,
+        bn_running_mean: torch.Tensor,
+        bn_running_var: torch.Tensor,
+    ):
+        return F.batch_norm(
+            x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=False
+        )
+
+    example_inputs = (
+        torch.randn(1, 1, 3, 3),  # x
+        torch.randn(1),  # bn_weight
+        torch.randn(1),  # bn_bias
+        torch.randn(1),  # bn_running_mean
+        torch.randn(1),  # bn_running_var
+    )
+
+    device = _assert_and_get_unique_device(m)
+    is_cuda = device is not None and device.type == "cuda"
+    bn_train_aten = _get_aten_graph_module_for_pattern(
+        _WrapperModule(bn_train),
+        example_inputs,
+        is_cuda,
+    )
+    bn_eval_aten = _get_aten_graph_module_for_pattern(
+        _WrapperModule(bn_eval),
+        example_inputs,
+        is_cuda,
+    )
+
+    if train_to_eval:
+        match_pattern = bn_train_aten
+        replacement_pattern = bn_eval_aten
+    else:
+        match_pattern = bn_eval_aten
+        replacement_pattern = bn_train_aten
+
+    from torch.fx.subgraph_rewriter import replace_pattern_with_filters
+
+    replace_pattern_with_filters(
+        m,
+        match_pattern,
+        replacement_pattern,
+        match_filters=[],
+        ignore_literals=True,
+    )
+    m.recompile()
+
+
+# TODO: expose these under this namespace?
+def _move_exported_model_to_eval(model: torch.fx.GraphModule):
+    """
+    Move an exported GraphModule to eval mode.
+
+    This is equivalent to model.eval() but only for certain special ops like dropout, batchnorm.
+    QAT users should call this before performing inference on the model.
+
+    This call is idempotent; if the model is already in eval mode, nothing will happen.
+    """
+    is_training = getattr(model, _EXPORTED_TRAINING_ATTR, True)
+    if not is_training:
+        return model
+    setattr(model, _EXPORTED_TRAINING_ATTR, False)
+    _replace_dropout(model, train_to_eval=True)
+    _replace_batchnorm(model, train_to_eval=True)
+    return model
+
+
+def _move_exported_model_to_train(model: torch.fx.GraphModule):
+    """
+    Move an exported GraphModule to train mode.
+
+    This is equivalent to model.train() but only for certain special ops like dropout, batchnorm.
+    QAT users should call this before performing training on the model.
+
+    This call is idempotent; if the model is already in train mode, nothing will happen.
+    """
+    is_training = getattr(model, _EXPORTED_TRAINING_ATTR, False)
+    if is_training:
+        return model
+    setattr(model, _EXPORTED_TRAINING_ATTR, True)
+    _replace_dropout(model, train_to_eval=False)
+    _replace_batchnorm(model, train_to_eval=False)
+    return model
+
+
+def _allow_exported_model_train_eval(model: torch.fx.GraphModule):
+    """
+    Allow users to call `model.train()` and `model.eval()` on an exported model,
+    but with the effect of changing behavior between the two modes limited to special
+    ops only, which are currently dropout and batchnorm.
+
+    Note: This does not achieve the same effect as what `model.train()` and `model.eval()`
+    does in eager models, but only provides an approximation. In particular, user code
+    branching on `training` flag will not function correctly in general because the branch
+    is already specialized at export time. Additionally, other ops beyond dropout and batchnorm
+    that have different train/eval behavior will also not be converted properly.
+    """
+
+    def _train(self, mode: bool = True):
+        if mode:
+            _move_exported_model_to_train(self)
+        else:
+            _move_exported_model_to_eval(self)
+
+    def _eval(self):
+        _move_exported_model_to_eval(self)
+
+    model.train = types.MethodType(_train, model)  # type: ignore[method-assign]
+    model.eval = types.MethodType(_eval, model)  # type: ignore[method-assign]
+    return model
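+
+
+# Editor's note: an illustrative sketch of how these helpers are typically used after export
+# (model and example_inputs are assumed user code):
+#
+#     em = torch.export.export_for_training(model, example_inputs).module()
+#     em = _allow_exported_model_train_eval(em)
+#     em.eval()    # rewrites aten dropout/batchnorm patterns to their eval form
+#     em.train()   # rewrites them back to the training form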
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/graph_utils.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/graph_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..80520d1ef0d0db30a35decd6752e7c2fd8e70bf9
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/graph_utils.py
@@ -0,0 +1,181 @@
+# mypy: allow-untyped-defs
+import itertools
+import operator
+from collections import OrderedDict
+from collections.abc import Sequence
+from typing import Any, Callable, Optional, Union
+
+import torch
+from torch.export import ExportedProgram
+from torch.fx import Node
+from torch.fx.passes.utils.source_matcher_utils import (
+    check_subgraphs_connected,
+    get_source_partitions,
+    SourcePartition,
+)
+
+
+__all__ = [
+    "find_sequential_partitions",
+    "get_equivalent_types",
+    "update_equivalent_types_dict",
+    "bfs_trace_with_node_process",
+]
+
+_EQUIVALENT_TYPES: list[set] = [
+    {torch.nn.Conv1d, torch.nn.functional.conv1d},
+    {torch.nn.Conv2d, torch.nn.functional.conv2d},
+    {torch.nn.AdaptiveAvgPool2d, torch.nn.functional.adaptive_avg_pool2d},
+    {torch.nn.ReLU, torch.nn.functional.relu, torch.nn.functional.relu_},
+    {torch.nn.BatchNorm2d, torch.nn.functional.batch_norm},
+    {torch.nn.Hardtanh, torch.nn.functional.hardtanh, torch.nn.functional.hardtanh_},
+    {torch.add, operator.add, operator.iadd, "add", "add_"},
+    {torch.mul, operator.mul, operator.imul, "mul", "mul_"},
+]
+
+
+def _create_equivalent_types_dict():
+    _DICT = {}
+    for values in _EQUIVALENT_TYPES:
+        for v in values:
+            _DICT[v] = list(values)
+    return _DICT
+
+
+_EQUIVALENT_TYPES_DICT = _create_equivalent_types_dict()
+
+
+def get_equivalent_types() -> list[set]:
+    return _EQUIVALENT_TYPES
+
+
+def update_equivalent_types_dict(customized_equivalent_types=None):
+    """Help function for user who wants to customize the _EQUIVALENT_TYPES and _EQUIVALENT_TYPES_DICT.
+    When customized_equivalent_types passes in,
+    re-generate _EQUIVALENT_TYPES and _EQUIVALENT_TYPES_DICT.
+    """
+    if customized_equivalent_types is None:
+        raise ValueError("customized_equivalent_types should not be None")
+    global _EQUIVALENT_TYPES
+    global _EQUIVALENT_TYPES_DICT
+    _EQUIVALENT_TYPES = customized_equivalent_types
+    _EQUIVALENT_TYPES_DICT = _create_equivalent_types_dict()
+
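+# Editor's note: an illustrative customization (adding a new equivalence set is an assumption,
+# not something this module does by default):
+#
+#     update_equivalent_types_dict(
+#         get_equivalent_types() + [{torch.nn.GELU, torch.nn.functional.gelu}]
+#     )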
+
+def _partitions_sequential(partitions: Sequence[SourcePartition]):
+    prev_partition = None
+    for partition in partitions:
+        if prev_partition is not None and not check_subgraphs_connected(
+            prev_partition, partition
+        ):
+            return False
+        prev_partition = partition
+    return True
+
+
+def _get_matching_types(partition_type):
+    matching_types = [partition_type]
+    if partition_type in _EQUIVALENT_TYPES_DICT:
+        matching_types.extend(_EQUIVALENT_TYPES_DICT[partition_type])
+    return matching_types
+
+
+def _valid_type_sequence(partition_types: list[Any]):
+    partition_types_set = set()  # type: ignore[var-annotated]
+    for partition_type in partition_types:
+        matching_types = _get_matching_types(partition_type)
+        matching_types_set = set(matching_types)
+        if len(partition_types_set & matching_types_set) > 0:
+            return False
+        partition_types_set |= matching_types_set
+    return True
+
+
+def find_sequential_partitions(
+    gm: torch.fx.GraphModule,
+    partition_types: list[Any],
+    include_functional_equivalent=True,
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+):
+    if not _valid_type_sequence(partition_types):
+        raise ValueError(
+            f"Invalid partition types: {partition_types}. Each type in the sequence must be unique"
+        )
+
+    typed_partitions: OrderedDict[Any, list[SourcePartition]] = OrderedDict()
+    for partition_type in partition_types:
+        types_to_match = _get_matching_types(partition_type)
+        partitions = get_source_partitions(gm.graph, types_to_match, filter_fn)
+        typed_partitions[partition_type] = list(
+            itertools.chain.from_iterable(partitions.values())
+        )
+
+    typed_partitions_list = list(typed_partitions.values())
+    fusion_candidates = itertools.product(*typed_partitions_list)
+    fused_partitions = [
+        candidate
+        for candidate in fusion_candidates
+        if _partitions_sequential(candidate)
+    ]
+    return fused_partitions
+
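+# Editor's note: an illustrative call; each returned candidate is a tuple of SourcePartitions
+# whose subgraphs are connected in order (gm is an assumed FX GraphModule traced from a model
+# containing Conv2d followed by ReLU):
+#
+#     fused = find_sequential_partitions(gm, [torch.nn.Conv2d, torch.nn.ReLU])
+#     for conv_partition, relu_partition in fused:
+#         ...  # e.g. annotate the pair for conv + relu fusion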
+
+def _get_submodule(
+    graph_module: torch.fx.GraphModule, node: torch.fx.Node, arg_index: int
+) -> tuple[str, torch.nn.Module, torch.fx.Node]:
+    submod_node = node.args[arg_index]
+    assert isinstance(submod_node, torch.fx.Node)
+    assert submod_node.op == "get_attr"
+    assert isinstance(submod_node.target, str)
+    submodule = graph_module.get_submodule(submod_node.target)
+    # pyre-ignore
+    return submod_node.target, submodule, node
+
+
+def _get_control_flow_submodules(
+    graph_module: torch.fx.GraphModule,
+) -> list[tuple[str, torch.nn.Module, torch.fx.Node]]:
+    """
+    Returns a list of submodules used for control flow operations
+    (torch.ops.higher_order.cond/map) that are in the given toplevel graph (does not look
+    into submodules). Specifically, the returned value is a list containing a
+    tuple of (name of the submodule that's stored in the graph module, the
+    submodule itself, and the fx node that uses this submodule).
+    """
+    control_flow_submodules = []
+    for node in graph_module.graph.nodes:
+        if node.op != "call_function":
+            continue
+
+        if node.target is torch.ops.higher_order.cond:
+            control_flow_submodules.append(_get_submodule(graph_module, node, 1))
+            control_flow_submodules.append(_get_submodule(graph_module, node, 2))
+        if node.target is torch.ops.higher_order.map_impl:
+            control_flow_submodules.append(_get_submodule(graph_module, node, 0))
+
+    return control_flow_submodules
+
+
+def bfs_trace_with_node_process(
+    model: Union[ExportedProgram, torch.fx.GraphModule], node_op: Callable
+) -> None:
+    """Traverse the graph module and apply node_op to each node."""
+
+    assert isinstance(model, (ExportedProgram, torch.fx.GraphModule)), (
+        f"Expected GraphModule or ExportedProgram, got {type(model)}"
+    )
+    gm = model.graph_module if isinstance(model, ExportedProgram) else model
+    queue = [gm]
+    while queue:
+        current_graph_module = queue.pop(0)
+        for node in current_graph_module.graph.nodes:
+            if node.op in ["output", "placeholder"]:
+                continue
+
+            node_op(node)
+
+        control_flow_submodules = [
+            submodule
+            for _, submodule, _ in _get_control_flow_submodules(current_graph_module)
+        ]
+        queue.extend(control_flow_submodules)
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/lowering.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/lowering.py
new file mode 100644
index 0000000000000000000000000000000000000000..587cee22560df9b0f0f3a38ee7eb6f1d04ffba8e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/lowering.py
@@ -0,0 +1,60 @@
+import torch
+from torch._inductor.constant_folding import constant_fold
+from torch._inductor.fx_passes.freezing_patterns import freezing_passes
+
+
+__all__ = [
+    "lower_pt2e_quantized_to_x86",
+]
+
+
+def lower_pt2e_quantized_to_x86(
+    model: torch.fx.GraphModule,
+    example_inputs: tuple[torch.Tensor, ...],
+) -> torch.fx.GraphModule:
+    """Lower a PT2E-qantized model to x86 backend.
+
+    Args:
+    * `model` (torch.fx.GraphModule): a model quantized by PT2E quantization flow.
+    * `example_inputs` (tuple[torch.Tensor, ...]): example inputs for the model.
+
+    Return:
+    A GraphModule lowered to x86 backend.
+    """
+
+    def _post_autograd_decomp_table():  # type: ignore[no-untyped-def]
+        decomp_table = torch.export.default_decompositions()
+
+        # if we are post-autograd, we shouldn't
+        # decomp prim ops.
+        for k in list(decomp_table.keys()):
+            if not torch._export.utils._is_cia_op(k):
+                del decomp_table[k]
+
+        return decomp_table
+
+    def _node_replace(m):  # type: ignore[no-untyped-def]
+        # Replace aten.t(x) with aten.permute(x, [1, 0])
+        aten = torch.ops.aten
+        g = m.graph
+        for node in g.nodes:
+            if node.target == aten.t.default:
+                with g.inserting_before(node):
+                    x = node.args[0]
+                    dims = [1, 0]
+                    perm_node = g.call_function(aten.permute.default, args=(x, dims))
+                    node.replace_all_uses_with(perm_node)
+                    g.erase_node(node)
+
+        g.lint()
+        m.recompile()
+
+    lowered_model = (
+        torch.export.export_for_training(model, example_inputs, strict=True)
+        .run_decompositions(_post_autograd_decomp_table())
+        .module()
+    )
+    _node_replace(lowered_model)
+    freezing_passes(lowered_model, example_inputs)
+    constant_fold(lowered_model)
+    return lowered_model
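+
+
+# Editor's note: a hedged usage sketch (converted_model is assumed to be the output of
+# convert_pt2e run with an x86 inductor quantizer):
+#
+#     lowered = lower_pt2e_quantized_to_x86(converted_model, example_inputs)
+#     out = lowered(*example_inputs)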
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/port_metadata_pass.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/port_metadata_pass.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c96f915306d3e316a4bb8a9c8be41dcec00900b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/port_metadata_pass.py
@@ -0,0 +1,218 @@
+# mypy: allow-untyped-defs
+import logging
+from typing import Optional
+
+import torch
+from torch._export.error import InternalError
+from torch.ao.quantization.pt2e.utils import (
+    _filter_sym_size_users,
+    _find_q_dq_node_for_user,
+    _is_valid_annotation,
+)
+from torch.ao.quantization.quantizer import QuantizationSpecBase
+from torch.fx.passes.infra.pass_base import PassBase, PassResult
+
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.ERROR)
+
+__all__ = ["PortNodeMetaForQDQ"]
+
+_METADATA_TO_PORT = [
+    "stack_trace",
+    "quantization_tag",
+]
+
+_QUANTIZE_OPS = [
+    torch.ops.quantized_decomposed.quantize_per_tensor.default,
+    torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
+    torch.ops.quantized_decomposed.quantize_per_channel.default,
+    torch.ops.pt2e_quant.quantize_affine,
+]
+
+_DEQUANTIZE_OPS = [
+    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
+    torch.ops.quantized_decomposed.dequantize_per_channel.default,
+    torch.ops.pt2e_quant.dequantize_affine,
+]
+
+_CHOOSE_QPARAMS_OPS = [
+    torch.ops.quantized_decomposed.choose_qparams.tensor,
+    torch.ops.quantized_decomposed.choose_qparams_symmetric.tensor,
+    torch.ops.pt2e_quant.choose_qparams_affine,
+]
+
+
+def _add_metadata(to_node: torch.fx.Node, from_node: torch.fx.Node) -> None:
+    from_meta = from_node.meta
+    for meta_name in _METADATA_TO_PORT:
+        if meta_name in from_meta:
+            to_node.meta[meta_name] = from_meta[meta_name]
+
+
+def _has_quant_annotation(node: torch.fx.Node) -> bool:
+    return "quantization_annotation" in node.meta
+
+
+def _find_choose_qparams_node(node: torch.fx.Node) -> Optional[torch.fx.Node]:
+    # BFS to look for choose qparams
+    from collections import deque
+
+    queue = deque(list(node.users.keys()))
+    while len(queue):
+        n = queue.popleft()
+        if n.op == "output":
+            continue
+        if n.op == "call_function" and n.target in _CHOOSE_QPARAMS_OPS:
+            return n
+        for k in n.users.keys():
+            queue.append(k)
+    return None
+
+
+def _port_metadata_for_input_quant_nodes(
+    input_node: torch.fx.Node,
+    node: torch.fx.Node,
+    qspec: Optional[QuantizationSpecBase],
+):
+    if qspec is None:
+        return
+
+    is_dynamic_quant = getattr(qspec, "is_dynamic", None)
+    if is_dynamic_quant is not None and is_dynamic_quant is True:
+        choose_qparams_node = _find_choose_qparams_node(input_node)
+        if choose_qparams_node is None:
+            raise ValueError(f"No chose qparams node found for {node}")
+        choose_qparam_users = _filter_sym_size_users(choose_qparams_node)
+        if len(choose_qparam_users) != 2:
+            raise InternalError(f"Expecting exactly two user for {choose_qparams_node}")
+        scale_node = choose_qparam_users.pop()
+        dynamic_q_node = next(iter(scale_node.users.keys()))
+        dynamic_q_node_users = _filter_sym_size_users(dynamic_q_node)
+        if len(dynamic_q_node_users) > 1:
+            raise InternalError(f"Expecting single user for {dynamic_q_node}")
+        dynamic_dq_node = dynamic_q_node_users.pop()
+        _add_metadata(choose_qparams_node, node)
+        _add_metadata(dynamic_q_node, node)
+        _add_metadata(dynamic_dq_node, node)
+    else:
+        q_node, dq_node = _find_q_dq_node_for_user(input_node, node)
+        if q_node is None or dq_node is None:
+            return
+        # add metadata for all the node between q_node and get_attr node
+        # if the q_node can be traced back to get_attr node
+        q_to_get_attr_nodes = [q_node]
+        q_node_input = q_node.args[0]
+        while (
+            isinstance(q_node_input, torch.fx.Node)
+            and q_node_input.op == "call_function"
+            and q_node_input.target
+            in [
+                torch.ops.aten.flatten.using_ints,
+                torch.ops.aten.permute.default,
+                torch.ops.aten.permute_copy.default,
+                torch.ops.aten.slice_copy.Tensor,
+                torch.ops.aten.squeeze.dim,
+                torch.ops.aten.squeeze_copy.dim,
+                torch.ops.aten.transpose.Dimname,
+                torch.ops.aten.transpose.int,
+                torch.ops.aten.transpose_,
+                torch.ops.aten.view_copy.default,
+                torch.ops.aten.view.default,
+                torch.ops.aten._mkldnn_transpose,
+            ]
+        ):
+            q_to_get_attr_nodes.append(q_node_input)
+            q_node_input = q_node_input.args[0]
+        if isinstance(q_node_input, torch.fx.Node) and q_node_input.op == "get_attr":
+            for n in q_to_get_attr_nodes:
+                _add_metadata(n, q_node_input)
+        _add_metadata(dq_node, node)
+
+
+def _port_metadata_for_output_quant_nodes(
+    node: torch.fx.Node, qspec: Optional[QuantizationSpecBase]
+):
+    if qspec is None:
+        return
+
+    node_users = _filter_sym_size_users(node)
+    if len(node.users) == 0:
+        return
+    if len(node_users) != 1:
+        logger.warning(f"Expecting {node} to have single user")  # noqa: G004
+    q_node = node_users.pop()
+    if q_node.op != "call_function" or q_node.target not in _QUANTIZE_OPS:
+        logger.warning(
+            f"Expecting the user of {node} to be a quantize op but got {q_node}"  # noqa: G004
+        )  # noqa: G004
+        return
+
+    _add_metadata(q_node, node)
+
+
+class PortNodeMetaForQDQ(PassBase):
+    """
+    Port metadata for nodes added by quantization flow.
+    For static quant these are:
+    - quantize_per_tensor.default, dequantize_per_tensor.default
+    - quantize_per_channel.default, dequantize_per_channel.default
+    For dynamic quant these are:
+    - choose_qparams.tensor
+    - quantize_per_tensor.tensor, dequantize_per_tensor.tensor
+    - quantize_per_channel.default, dequantize_per_channel.default
+
+    Rules of porting metadata:
+    - Metadata to be ported:
+      - nn_module_stack
+      - stack_trace
+      - quantization_tag
+    - Metadata to NOT be ported:
+      - Everything else
+    - Rules:
+      - Statically quantized patterns:
+        - Dequantize nodes on the inputs to be quantized inherit metadata of the consumer node.
+        - Quantize nodes on the outputs inherit metadata of the producer node.
+        - Example 1:
+          - Original: [Conv -> AvgPool -> Linear]
+          - Quantized [Q-> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> Linear -> Q -> DQ]
+          - Inner brackets specify which nodes Q/DQ inherit metadata from
+          - [Q-> [DQ -> Conv -> Q] -> [DQ -> AvgPool -> Q] -> [DQ -> Linear -> Q] -> DQ]
+          - Note first Q and last DQ do not inherit metadata from any nodes
+        - Example 2:
+          - Original: [Conv -> AvgPool -> Linear]
+          - AvgPool is not quantized
+          - Quantized [Q-> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> Linear -> Q -> DQ]
+          - Inner brackets specify which nodes Q/DQ inherit metadata from
+          - [Q-> [DQ -> Conv -> Q] -> DQ -> [AvgPool] -> Q -> [DQ -> Linear -> Q] -> DQ]
+          - Note DQ and Q nodes around AvgPool do not inherit metadata from AvgPool because
+            AvgPool was not supposed to be quantized. Metadata porting relies on the
+            quantization_annotation on the nodes (in this case the AvgPool node) to conclude
+            whether the node or pattern was supposed to be quantized, and subsequently to decide
+            whether the preceding Q, if any, should inherit metadata from AvgPool.
+      - Dynamically quantized patterns:
+        - Inputs that are dynamically quantized have choose_qparams, quantize and dequantize nodes
+        - For example, below Linear is dynamically quantized while the rest is statically quantized:
+          - Original: [Conv -> AvgPool -> Linear]
+          - Quantized [Q-> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> choose_qparams -> Q -> DQ -> Linear]
+          - Quantized [Q-> [DQ -> Conv -> Q] -> [DQ -> AvgPool -> Q] -> DQ -> [choose_qparams -> Q -> DQ -> Linear]]
+          - Note first Q does not inherit metadata from any nodes
+    NB:
+    - The best place for porting metadata is during observer conversion to q/dq. This is because that step precisely
+      knows which quantization spec is converted to q/dq and thus from where the metadata should be ported.
+      However, since the FX and PT2E quant workflows share a common code base, doing it there hurts readability quite a bit.
+      Doing it via a separate pass helps the readability of the code. Once we are able to refactor the PT2E quant
+      code, this pass should be integrated into the refactored variant of the "convert" step.
+    """
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        for node in graph_module.graph.nodes:
+            annotation = node.meta.get("quantization_annotation", None)
+            if _is_valid_annotation(annotation):
+                input_qspec_map = node.meta["quantization_annotation"].input_qspec_map
+                output_qspec = node.meta["quantization_annotation"].output_qspec
+                for input_node, qspec in input_qspec_map.items():
+                    _port_metadata_for_input_quant_nodes(input_node, node, qspec)
+                _port_metadata_for_output_quant_nodes(node, output_qspec)
+        return PassResult(graph_module, True)
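+
+
+# Illustrative sketch (editor's addition, not part of the upstream file): applying
+# this pass to an already converted GraphModule. `converted_gm` is a hypothetical
+# argument standing for the output of the PT2E convert step; PassBase.__call__
+# wraps PortNodeMetaForQDQ.call and returns a PassResult.
+def _example_run_port_node_meta(
+    converted_gm: torch.fx.GraphModule,
+) -> torch.fx.GraphModule:
+    result = PortNodeMetaForQDQ()(converted_gm)
+    return result.graph_module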
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/prepare.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/prepare.py
new file mode 100644
index 0000000000000000000000000000000000000000..15aa1136805365de39c75f2f1239fc38cd9535c1
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/prepare.py
@@ -0,0 +1,574 @@
+# mypy: allow-untyped-defs
+from typing import Any, Optional, Union
+
+import torch
+from torch._subclasses import FakeTensor
+from torch.ao.quantization import (
+    CUSTOM_KEY,
+    NUMERIC_DEBUG_HANDLE_KEY,
+    ObserverOrFakeQuantize,
+    QConfigMapping,
+)
+from torch.ao.quantization.fx.custom_config import PrepareCustomConfig
+from torch.ao.quantization.fx.prepare import (
+    _create_obs_or_fq_from_qspec,
+    _insert_obs_or_fq,
+    _is_activation_post_process_node,
+    _save_state,
+)
+from torch.ao.quantization.qconfig import QConfigAny
+from torch.ao.quantization.quantizer import (
+    EdgeOrNode,
+    QuantizationSpecBase,
+    SharedQuantizationSpec,
+)
+from torch.fx import Graph, GraphModule, Node
+from torch.fx.node import Argument
+
+
+# TODO: make pt2e folder private?
+__all__ = [
+    "prepare",
+]
+
+
+def _find_root_edge_or_node(
+    edge_or_node: EdgeOrNode, shared_with_map: dict[EdgeOrNode, EdgeOrNode]
+) -> EdgeOrNode:
+    """Find the root node for the sharing tree
+    Args:
+        edge_or_node: edge/node that we want to find the root
+        shared_with_map: each edge/node points to its parent; the root node points to itself
+
+    Returns:
+        root edge/node
+    """
+    parent = shared_with_map[edge_or_node]
+    if parent == edge_or_node:
+        return edge_or_node
+    root = _find_root_edge_or_node(parent, shared_with_map)
+    # path compression
+    shared_with_map[edge_or_node] = root
+    return root
+
+
+def _union(
+    parent: EdgeOrNode,
+    child: EdgeOrNode,
+    shared_with_map: dict[EdgeOrNode, EdgeOrNode],
+) -> None:
+    """Merge the subtree for `child` with `parent`, the order is important here"""
+    root_parent = _find_root_edge_or_node(parent, shared_with_map)
+    root_child = _find_root_edge_or_node(child, shared_with_map)
+    # union the two trees by pointing the root of child to root of parent
+    shared_with_map[root_child] = root_parent
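+
+
+# Illustrative sketch (editor's addition): how the union-find helpers above behave on
+# a toy `shared_with_map`. The keys here are plain strings rather than real EdgeOrNode
+# values, purely to show root lookup, union, and path compression.
+def _example_union_find() -> None:
+    shared_with_map = {"a": "a", "b": "b", "c": "c"}
+    _union("a", "b", shared_with_map)  # root of "b" now points to "a"
+    _union("b", "c", shared_with_map)  # root of "c" now points to "a" as well
+    assert _find_root_edge_or_node("c", shared_with_map) == "a"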
+
+
+def _update_shared_with(
+    child: EdgeOrNode,
+    qspec: QuantizationSpecBase,
+    shared_with_map: dict[EdgeOrNode, EdgeOrNode],
+):
+    """Update the `shared_with_map` based on the qspec, this applies the `SharedQuantizationSpec`
+    configuration and established the relationship between `edge_or_node` with the edge/node that it
+    is pointing to, we'll use this information in the end to get the group id
+    """
+    if isinstance(qspec, SharedQuantizationSpec):
+        parent = qspec.edge_or_node
+        # we point from edge_or_node to the node that it is sharing_with, e.g.
+        # qspec for a = SharedQuantizationSpec(b) means `a` points to `b`
+        _union(parent, child, shared_with_map)
+
+
+def _unwrap_shared_qspec(
+    qspec: QuantizationSpecBase,
+    edge_or_node_to_qspec: dict[EdgeOrNode, QuantizationSpecBase],
+    shared_with_map: dict[EdgeOrNode, EdgeOrNode],
+) -> QuantizationSpecBase:
+    """Unwraps qspec to get the final root qspec (non SharedQuantizationSpec)
+    if qspec is SharedQuantizationSpec
+       (1). tries to find the root edge or node for the node that the qspec points to
+       (2). recursively find the root qspec based on the qspec for the root node
+    """
+    if isinstance(qspec, SharedQuantizationSpec):
+        sharing_with = qspec.edge_or_node
+        root = _find_root_edge_or_node(sharing_with, shared_with_map)
+        qspec = edge_or_node_to_qspec[root]
+        return _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map)
+    return qspec
+
+
+def _has_same_attr(
+    qspec_a: QuantizationSpecBase, qspec_b: QuantizationSpecBase, attr_name: str
+):
+    return (
+        hasattr(qspec_a, attr_name)
+        and hasattr(qspec_b, attr_name)
+        and getattr(qspec_a, attr_name) == getattr(qspec_b, attr_name)
+    ) or (not hasattr(qspec_a, attr_name) and not hasattr(qspec_b, attr_name))
+
+
+def _get_edge_or_node_to_qspec(
+    model: torch.fx.GraphModule,
+) -> dict[EdgeOrNode, QuantizationSpecBase]:
+    """Get a map from EdgeOrNode to quantization spec based on annotations on the nodes"""
+    edge_or_node_to_qspec: dict[EdgeOrNode, QuantizationSpecBase] = {}
+    for n in model.graph.nodes:
+        if hasattr(n, "meta") and "quantization_annotation" in n.meta:
+            qa = n.meta["quantization_annotation"]
+            for input_to_n, qspec in qa.input_qspec_map.items():
+                input_edge = (input_to_n, n)
+                edge_or_node_to_qspec[input_edge] = qspec
+            if qa.output_qspec is not None:
+                output_node = n
+                qspec = qa.output_qspec
+                edge_or_node_to_qspec[output_node] = qspec
+    return edge_or_node_to_qspec
+
+
+def _union_input_edge_with(
+    input_edge,
+    input_edge_root_qspec,
+    edge_or_node,
+    edge_or_node_to_qspec,
+    shared_with_map,
+):
+    """Union input edge with another edge or node, used in implicit sharing to point the current input
+    edge to other user edges of the producer node, or the output of producer node since these are
+    referring to the same Tensor
+    """
+    root_qspec = None
+    if edge_or_node in edge_or_node_to_qspec:
+        qspec = edge_or_node_to_qspec[edge_or_node]
+        root_qspec = _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map)
+    # TODO: add assertions for types of root qspecs
+    if root_qspec is not None and all(
+        _has_same_attr(root_qspec, input_edge_root_qspec, attr)
+        for attr in [
+            "dtype",
+            "is_dynamic",
+            "quant_min",
+            "quant_max",
+            "qscheme",
+            "ch_axis",
+            "scale",
+            "zero_point",
+        ]
+    ):
+        # the input arg to the node should reuse the existing output observer for arg
+        # since dtype is the same (we may want to extend this to be a more strict check
+        # in the future)
+        # so we point from `input_edge` to `arg` (output of the argument)
+        _union(edge_or_node, input_edge, shared_with_map)
+
+
+def _get_edge_or_node_to_group_id(
+    edge_or_node_to_qspec: dict[EdgeOrNode, QuantizationSpecBase],
+) -> dict[EdgeOrNode, int]:
+    """Map from edge/node to the group ID, generated from quantization annotations,
+    edge/node with the same group ID should use the same observer/fake_quant instance
+
+    This is applying SharedQuantizationSpec configuration and map each edge/node to a group
+    There is another implicit sharing that's built in the quantization, when we have the following:
+       * op1 -> op2
+       * output of op1: int8_qspec
+       * (op1 -> op2) input edge: int8_qspec
+    we'll assume sharing between the output of op1 and input of (op1 -> op2) since these are the same Tensor.
+
+    Figuring out the correct group ID for all edge/node is a standard union find problem:
+    https://www.geeksforgeeks.org/introduction-to-disjoint-set-data-structure-or-union-find-algorithm/
+
+    Args:
+        edge_or_node_to_qspec: Dictionary from edge_or_node to the qspec, derived from annotations
+    Returns:
+        edge_or_node_to_group_id: Dictionary from edge_or_node to group_id (int); all edges or nodes that
+        belong to the same group should have the same id
+
+    Example:
+        op2 -> cat1 -> cat2
+           op1 /        /
+                     op3
+        edge_or_node_to_qspec: {
+            op1: int8_qspec,
+            op2: int8_qspec,
+            (op1, cat1): int8_qspec,
+            (op2, cat1): SharedQuantizationSpec((op1, cat1)),
+            cat1: SharedQuantizationSpec((op1, cat1)),
+            (op3, cat2): int8_qspec,
+            (cat1, cat2): SharedQuantizationSpec((op3, cat2)),
+            cat2: SharedQuantizationSpec((op3, cat2)),
+        }
+
+        edge_or_node_to_group_id = _get_edge_or_node_to_group_id(edge_or_node_to_qspec)
+        edge_or_node_to_group_id: {
+            op1: 1,
+            op2: 1,
+            (op1, cat1): 1,
+            (op2, cat1): 1,
+            cat1: 1,
+            (op3, cat2): 1,
+            (cat1, cat2): 1,
+            cat2: 1,
+        }
+        # everything is in the same group because (cat1) and (cat1, cat2) are implicitly shared, which
+        # connects the two sharing groups around the cat1 and cat2 ops due to transitive sharing
+    """
+    # the observer of a key should be shared with the observer of its value; by default each
+    # key is shared with itself
+    shared_with_map: dict[EdgeOrNode, EdgeOrNode] = {
+        k: k for k in edge_or_node_to_qspec.keys()
+    }
+    for edge_or_node, qspec in edge_or_node_to_qspec.items():
+        if isinstance(edge_or_node, torch.fx.Node):
+            output_node = edge_or_node
+            _update_shared_with(output_node, qspec, shared_with_map)
+        else:
+            input_edge = edge_or_node
+            input_edge_root_qspec = _unwrap_shared_qspec(
+                qspec, edge_or_node_to_qspec, shared_with_map
+            )
+
+            assert isinstance(input_edge, tuple)
+            arg, n = input_edge
+            if n.meta["quantization_annotation"].allow_implicit_sharing:
+                # NOTE: the order is important here, we first share with other users and then share with the previous
+                # output, because the reverse order could cause a circular dependency
+                # e.g. node1 -> node2
+                #          \ -> node3
+                # when processing (node1, node2), if we first point (node1, node2) to node1
+                # Step 1. shared_map = {(node1, node2): node1}
+                # Step 2. after that, we point the (node1, node2) to its other user (node1, node3) ,
+                # which means shared_map = {(node1, node2): node1, node1: (node1, node3)}
+                # because we will point the root of (node1, node2) (in this case node1) to the root of (node1, node3)
+                # Step 3. and when we process (node1, node3), it can try to point to node1 as well, then we'll
+                # have a circular dependency
+                # the following order works around this issue, but it does not allow arbitrary configurations
+                # of sharing, so it might break in a different case in the future; when it breaks,
+                # quantizer writers can check the notes here to debug the issue
+
+                # sharing with other users of the producer node
+                # (arg, user)
+                if not isinstance(arg, Node) or not isinstance(n, Node):
+                    raise Exception(  # noqa: TRY002
+                        f"Expected input_edge to have type Tuple[Node, Node], but got: {arg, n}"
+                    )
+                for user in arg.users:
+                    if user is n:
+                        continue
+                    arg_to_user_edge = (arg, user)
+                    _union_input_edge_with(
+                        input_edge,
+                        input_edge_root_qspec,
+                        arg_to_user_edge,
+                        edge_or_node_to_qspec,
+                        shared_with_map,
+                    )
+
+                # sharing with output of producer node
+                _union_input_edge_with(
+                    input_edge,
+                    input_edge_root_qspec,
+                    arg,
+                    edge_or_node_to_qspec,
+                    shared_with_map,
+                )
+
+            _update_shared_with(input_edge, qspec, shared_with_map)
+
+    # now that we have the sharing relations between all edges and nodes, we can assign group ids
+    cur_group_id = 0
+    edge_or_node_to_group_id: dict[EdgeOrNode, int] = {}
+    for edge_or_node in shared_with_map.keys():
+        root = _find_root_edge_or_node(edge_or_node, shared_with_map)
+        if root not in edge_or_node_to_group_id:
+            edge_or_node_to_group_id[root] = cur_group_id
+            cur_group_id += 1
+        edge_or_node_to_group_id[edge_or_node] = edge_or_node_to_group_id[root]
+
+    return edge_or_node_to_group_id
+
+
+def _get_obs_or_fq_map(
+    edge_or_node_to_group_id: dict[EdgeOrNode, int],
+    edge_or_node_to_qspec: dict[EdgeOrNode, QuantizationSpecBase],
+    is_qat: bool,
+) -> dict[EdgeOrNode, ObserverOrFakeQuantize]:
+    """Generates the EdgeOrNode to observer/fake_quant instances
+    Makes sure that for EdgeOrNode that has the same group_id should have the same observer or fake quant
+    instances
+    """
+    obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize] = {}
+    group_id_to_obs_or_fq: dict[int, ObserverOrFakeQuantize] = {}
+    for edge_or_node, qspec in edge_or_node_to_qspec.items():
+        group_id = edge_or_node_to_group_id[edge_or_node]
+        if group_id not in group_id_to_obs_or_fq:
+            # TODO: maybe edge_or_node_to_qspec should be edge_or_node_to_root_qspec, this will simplify
+            # the implementation for _create_obs_or_fq_from_qspec
+            group_id_to_obs_or_fq[group_id] = _create_obs_or_fq_from_qspec(
+                qspec, obs_or_fq_map, is_qat
+            )
+        obs_or_fq_map[edge_or_node] = group_id_to_obs_or_fq[group_id]
+    return obs_or_fq_map
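+
+
+# Illustrative sketch (editor's addition): the maps built inside `prepare` can also be
+# constructed directly, e.g. to inspect which edges/nodes will share an observer.
+# `annotated_gm` is a hypothetical GraphModule whose nodes already carry
+# "quantization_annotation" metadata from a quantizer.
+def _example_inspect_observer_sharing(
+    annotated_gm: GraphModule,
+) -> dict[EdgeOrNode, int]:
+    edge_or_node_to_qspec = _get_edge_or_node_to_qspec(annotated_gm)
+    return _get_edge_or_node_to_group_id(edge_or_node_to_qspec)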
+
+
+def _maybe_insert_input_observer_for_arg_or_kwarg(
+    node: Union[Node, Any],
+    arg: Argument,
+    qconfig: QConfigAny,
+    model: torch.nn.Module,
+    named_modules: dict[str, torch.nn.Module],
+    obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
+    is_qat: bool,
+) -> Argument:
+    """
+    Given a `node` and an `arg`, inserts an input observer between
+    `node` and `arg` if necessary.
+    """
+    # for ops such as torch.cat([x0, x1]),
+    # traverse through the list
+    if isinstance(arg, (list, tuple)):
+        new_arg_to_return = []
+        for inner_arg in arg:
+            new_inner_arg = _maybe_insert_input_observer_for_arg_or_kwarg(
+                node,
+                inner_arg,
+                qconfig,
+                model,
+                named_modules,
+                obs_or_fq_map,
+                is_qat,
+            )
+            new_arg_to_return.append(new_inner_arg)
+        return type(arg)(new_arg_to_return)
+
+    if not isinstance(arg, Node):
+        return arg
+    assert isinstance(arg, Node)
+    # default (no observer)
+    new_arg = arg
+
+    # find the original `arg` node feeding the current node, skipping inserted observer/fake_quant nodes
+    original_arg = arg
+    while _is_activation_post_process_node(original_arg, named_modules):
+        original_arg = original_arg.args[0]  # type: ignore[assignment]
+    assert isinstance(original_arg, Node), (
+        f"expect original argument to be a Node, but got: {type(original_arg)}"
+    )
+
+    input_edge = (original_arg, node)
+    if input_edge not in obs_or_fq_map:
+        return new_arg
+    # input_edge needs to be observed
+    input_edge_obs_or_fq = obs_or_fq_map[input_edge]
+    if input_edge_obs_or_fq is None:
+        return new_arg
+
+    arg_as_output_obs_or_fq = obs_or_fq_map.get(original_arg, None)
+    # the arg is observed as the output and is using the same instance as the input_edge
+    # we'll reuse the inserted observer/fake_quant
+    if arg_as_output_obs_or_fq is not None and id(arg_as_output_obs_or_fq) == id(
+        input_edge_obs_or_fq
+    ):
+        return new_arg
+
+    # otherwise, we'll insert a new observer/fake_quant node
+
+    # skip inserting new observers if the same observer instance is inserted before for another user
+    # Example:
+    # conv1 -> obs1 -> existing_obs -> conv2
+    #             \ -> conv3
+    #
+    # instead of inserting new observers we will have:
+    # conv1 -> obs1 -> existing_obs -> conv2
+    #                            \ -> conv3
+    for maybe_obs_node in arg.users.keys():
+        if not _is_activation_post_process_node(maybe_obs_node, named_modules):
+            continue
+        maybe_obs_mod = named_modules[maybe_obs_node.target]  # type: ignore[index]
+        if id(maybe_obs_mod) == id(input_edge_obs_or_fq):
+            return maybe_obs_node
+
+    assert isinstance(model.graph, Graph)
+    new_arg = _insert_obs_or_fq(
+        arg, input_edge_obs_or_fq, model, named_modules, model.graph
+    )
+    return new_arg
+
+
+def _maybe_insert_input_observers_for_node(
+    node: Node,
+    qconfig: QConfigAny,
+    model: torch.nn.Module,
+    named_modules: dict[str, torch.nn.Module],
+    obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
+    is_qat: bool,
+) -> None:
+    """
+    If needed, inserts observers to the input args and kwargs of `node`.
+    Note: modifies `node` inplace.
+
+    For example, if cur_node needs an observer after prev_node, we change from
+
+      prev_node -> cur_node
+
+    To
+
+      prev_node -> obs -> cur_node
+
+    """
+    # Look through every input arg.  If that arg's target dtype does not
+    # match the current node's target dtype, insert an observer.
+    new_args = []
+    for arg in node.args:
+        new_arg = _maybe_insert_input_observer_for_arg_or_kwarg(
+            node,
+            arg,
+            qconfig,
+            model,
+            named_modules,
+            obs_or_fq_map,
+            is_qat,
+        )
+        new_args.append(new_arg)
+
+    # Clone has a memory_format kwarg, zeros_like has a pin_memory kwarg, and
+    # gelu has an approximate kwarg; these persist in the exported graph.
+    # This is just a workaround for them.
+    assert (
+        node.target == torch.ops.aten.clone.default
+        or node.target == torch.ops.aten.zeros_like.default
+        or node.target == torch.ops.aten.gelu.default
+        or len(node.kwargs) == 0
+    ), " expecting kwargs for aten op IR to be empty"
+
+    # assign the new args to the node, inplace
+    node.args = tuple(new_args)
+
+
+def _maybe_insert_output_observer_for_node(
+    node: Node,
+    model: torch.nn.Module,
+    named_modules: dict[str, torch.nn.Module],
+    graph: Graph,
+    obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
+    is_qat: bool,
+) -> Optional[Node]:
+    if node in obs_or_fq_map:
+        output_act_obs_or_fq = obs_or_fq_map[node]
+        new_output = _insert_obs_or_fq(
+            node, output_act_obs_or_fq, model, named_modules, graph
+        )
+        # propagate numeric debug handle from original node to observer/fake_quant node
+        if (
+            isinstance(node, Node)
+            and isinstance(new_output, Node)
+            and CUSTOM_KEY in node.meta
+            and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY]
+        ):
+            if CUSTOM_KEY not in new_output.meta:
+                new_output.meta[CUSTOM_KEY] = {}
+            new_output.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] = node.meta[
+                CUSTOM_KEY
+            ][NUMERIC_DEBUG_HANDLE_KEY]
+        return new_output
+    return None
+
+
+def _maybe_insert_input_and_output_observers_for_node(
+    node: Node,
+    model: torch.fx.GraphModule,
+    obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
+    is_qat: bool,
+):
+    this_node_quantization_annotation = (
+        node.meta["quantization_annotation"]
+        if "quantization_annotation" in node.meta
+        else None
+    )
+    if this_node_quantization_annotation is None:
+        return
+
+    named_modules = dict(model.named_modules(remove_duplicate=False))
+    _maybe_insert_input_observers_for_node(
+        node,
+        None,  # qconfig
+        model,
+        named_modules,
+        obs_or_fq_map,
+        is_qat,
+    )
+
+    output_is_a_tensor = "val" in node.meta and isinstance(node.meta["val"], FakeTensor)
+    if not output_is_a_tensor:
+        return
+
+    # this returns the new observer node if it was needed
+    maybe_output_obs_node = _maybe_insert_output_observer_for_node(
+        node, model, named_modules, model.graph, obs_or_fq_map, is_qat
+    )
+
+    if maybe_output_obs_node is None:
+        return
+    # Update users of original node to use the output observer
+    # instead. For example, change
+    #
+    #           next_node
+    #          /
+    #   cur_node -> obs
+    #
+    # to
+    #
+    #                 next_node
+    #                 /
+    #   cur_node -> obs
+    #
+    # We need to save orig users before updating uses because
+    # the list of users will change as we update uses
+    orig_users = list(node.users.keys())
+    for user_node in orig_users:
+        if user_node is maybe_output_obs_node:
+            continue
+        user_node.replace_input_with(node, maybe_output_obs_node)
+
+
+def prepare(
+    model: GraphModule,
+    node_name_to_scope: dict[str, tuple[str, type]],
+    is_qat: bool,
+    obs_or_fq_callback=None,
+) -> GraphModule:
+    # Since we are mutating the graph as we go, we iterate over the original
+    # nodes before observer insertion, instead of model.graph.nodes.
+    nodes_before_observation = list(model.graph.nodes)
+
+    # At a high level we construct a map from EdgeOrNode to an observer_or_fake_quant instance;
+    # all edges/nodes that belong to the same group will use the same instance,
+    # and when we insert observers we just query this map to get the correct observer_or_fake_quant
+    # instance
+    edge_or_node_to_qspec = _get_edge_or_node_to_qspec(model)
+    edge_or_node_to_group_id = _get_edge_or_node_to_group_id(edge_or_node_to_qspec)
+    obs_or_fq_map = _get_obs_or_fq_map(
+        edge_or_node_to_group_id, edge_or_node_to_qspec, is_qat
+    )
+    if obs_or_fq_callback:
+        obs_or_fq_callback(model, obs_or_fq_map)
+
+    for node in nodes_before_observation:
+        # TODO: simplify logic for inserting observers
+        _maybe_insert_input_and_output_observers_for_node(
+            node, model, obs_or_fq_map, is_qat
+        )
+
+    model = GraphModule(model, model.graph)
+
+    _save_state(
+        model,
+        {},  # node_name_to_qconfig
+        node_name_to_scope,
+        PrepareCustomConfig(),
+        {},  # equalization_node_name_to_qconfig
+        QConfigMapping(),
+        is_qat,
+        set(),  # observed_node_names
+    )
+    return model
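+
+
+# Illustrative note (editor's addition): `prepare` above is the internal workhorse
+# behind the public PT2E entry points (e.g. prepare_pt2e / prepare_qat_pt2e in
+# torch.ao.quantization.quantize_pt2e), which run a Quantizer to annotate the
+# exported GraphModule before delegating observer insertion to this function.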
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/qat_utils.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/qat_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7423a5320d97d3209faa60fc966e17d4f64e1e3c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/qat_utils.py
@@ -0,0 +1,995 @@
+# mypy: allow-untyped-defs
+import copy
+import dataclasses
+import itertools
+import operator
+from typing import Any, Callable, Optional, TYPE_CHECKING
+
+import torch
+import torch.nn.functional as F
+from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
+from torch.ao.quantization.pt2e.export_utils import _WrapperModule
+from torch.ao.quantization.quantizer import (
+    DerivedQuantizationSpec,
+    EdgeOrNode,
+    QuantizationSpecBase,
+    SharedQuantizationSpec,
+)
+from torch.fx import Graph, GraphModule, Node
+from torch.fx.subgraph_rewriter import replace_pattern_with_filters, ReplacedPatterns
+
+from .utils import (
+    _get_aten_graph_module_for_pattern,
+    _is_bn_node,
+    _is_conv_or_conv_transpose_node,
+    _is_conv_transpose_fn,
+    fold_bn_weights_into_conv_node,
+)
+
+
+if TYPE_CHECKING:
+    from torch.fx.passes.utils.matcher_with_name_node_map_utils import InternalMatch
+
+__all__ = []  # type: ignore[var-annotated]
+
+
+def _get_quantized_conv_bn_example_inputs_kwargs(
+    is_per_channel: bool,
+    has_bias: bool,
+    bias_is_quantized: bool,
+    is_cuda: bool,
+) -> dict[str, Any]:
+    """
+    Optional example inputs for quantized and folded conv-bn patterns
+    used in convert, expressed as kwargs.
+    """
+    kwargs = {}
+    # Per tensor quantization uses literals to represent scale and zero
+    # point, so there is no need to include them here as kwargs
+    if is_per_channel:
+        kwargs["weight_scale"] = torch.tensor([1], dtype=torch.float)
+        kwargs["weight_zero_point"] = torch.tensor([0], dtype=torch.int)
+        if has_bias and bias_is_quantized:
+            kwargs["bias_scale"] = torch.tensor([1], dtype=torch.float)
+            kwargs["bias_zero_point"] = torch.tensor([0], dtype=torch.int)
+    if has_bias:
+        kwargs["conv_bias"] = torch.randn(1)
+    if is_cuda:
+        for k, v in kwargs.items():
+            if isinstance(v, torch.Tensor):
+                kwargs[k] = v.cuda()
+    return kwargs
+
+
+def _get_conv_bn_pattern(conv_fn: Callable) -> Callable:
+    def _conv_bn_pattern(
+        x: torch.Tensor,
+        conv_weight: torch.Tensor,
+        conv_bias: torch.Tensor,
+        bn_weight: torch.Tensor,
+        bn_bias: torch.Tensor,
+        bn_running_mean: torch.Tensor,
+        bn_running_var: torch.Tensor,
+    ) -> torch.Tensor:
+        x = conv_fn(x, conv_weight, conv_bias)
+        x = F.batch_norm(
+            x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True
+        )
+        return x
+
+    return _WrapperModule(_conv_bn_pattern)
+
+
+# TODO: merge this with the `no_conv_bias` case
+def _get_qat_conv_bn_pattern(conv_fn: Callable) -> Callable:
+    def _qat_conv_bn_pattern(
+        x: torch.Tensor,
+        conv_weight: torch.Tensor,
+        conv_bias: torch.Tensor,
+        bn_weight: torch.Tensor,
+        bn_bias: torch.Tensor,
+        bn_running_mean: torch.Tensor,
+        bn_running_var: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Approximate method to fuse conv and bn. It requires only one forward pass.
+        conv_orig = conv / scale_factor where scale_factor = bn.weight / running_std.
+        This is based on `nniqat.ConvBn2d._forward_approximate`.
+        """
+        # TODO: allow setting eps
+        bn_eps = 1e-5
+        running_std = torch.sqrt(bn_running_var + bn_eps)
+        scale_factor = bn_weight / running_std
+        weight_shape = [1] * len(conv_weight.shape)
+        weight_in_channel_axis = 1 if _is_conv_transpose_fn(conv_fn) else 0
+        weight_shape[weight_in_channel_axis] = -1
+        bias_shape = [1] * len(conv_weight.shape)
+        bias_shape[1] = -1
+        scaled_weight = conv_weight * scale_factor.reshape(weight_shape)
+        zero_bias = torch.zeros_like(conv_bias, dtype=x.dtype)
+        x = conv_fn(x, scaled_weight, zero_bias)
+        x = x / scale_factor.reshape(bias_shape)
+        x = x + conv_bias.reshape(bias_shape)
+        x = F.batch_norm(
+            x,
+            bn_running_mean,
+            bn_running_var,
+            bn_weight,
+            bn_bias,
+            training=True,
+            eps=bn_eps,
+        )
+        return x
+
+    return _WrapperModule(_qat_conv_bn_pattern)
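+
+
+# Illustrative sketch (editor's addition): a numerical check that the approximate QAT
+# pattern above matches a plain conv + batch_norm forward pass. All shapes and values
+# below are arbitrary toy values chosen only for this check.
+def _example_check_qat_conv_bn_equivalence() -> None:
+    x = torch.randn(1, 1, 3, 3)
+    conv_weight = torch.randn(1, 1, 1, 1)
+    conv_bias = torch.randn(1)
+    bn_weight = torch.randn(1)
+    bn_bias = torch.randn(1)
+    bn_running_mean = torch.zeros(1)
+    bn_running_var = torch.ones(1)
+    # reference: conv followed by batch_norm in training mode
+    reference = F.batch_norm(
+        F.conv2d(x, conv_weight, conv_bias),
+        bn_running_mean,
+        bn_running_var,
+        bn_weight,
+        bn_bias,
+        training=True,
+    )
+    # approximate: scale the weight, undo the scaling after conv, then batch_norm
+    approximate = _get_qat_conv_bn_pattern(F.conv2d)(
+        x, conv_weight, conv_bias, bn_weight, bn_bias, bn_running_mean, bn_running_var
+    )
+    assert torch.allclose(reference, approximate, atol=1e-4)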
+
+
+def _get_qat_conv_bn_pattern_no_conv_bias(conv_fn: Callable) -> Callable:
+    def _qat_conv_bn_pattern_no_conv_bias(
+        x: torch.Tensor,
+        conv_weight: torch.Tensor,
+        # Not used, only for matching convenience
+        conv_bias: torch.Tensor,
+        bn_weight: torch.Tensor,
+        bn_bias: torch.Tensor,
+        bn_running_mean: torch.Tensor,
+        bn_running_var: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Same as `_get_qat_conv_bn_pattern`, but handles the case with no conv bias.
+        """
+        # TODO: allow setting eps
+        bn_eps = 1e-5
+        running_std = torch.sqrt(bn_running_var + bn_eps)
+        scale_factor = bn_weight / running_std
+        weight_shape = [1] * len(conv_weight.shape)
+        weight_in_channel_axis = 1 if _is_conv_transpose_fn(conv_fn) else 0
+        weight_shape[weight_in_channel_axis] = -1
+        bias_shape = [1] * len(conv_weight.shape)
+        bias_shape[1] = -1
+        scaled_weight = conv_weight * scale_factor.reshape(weight_shape)
+        x = conv_fn(x, scaled_weight, None)
+        x = x / scale_factor.reshape(bias_shape)
+        x = F.batch_norm(
+            x,
+            bn_running_mean,
+            bn_running_var,
+            bn_weight,
+            bn_bias,
+            training=True,
+            eps=bn_eps,
+        )
+        return x
+
+    return _WrapperModule(_qat_conv_bn_pattern_no_conv_bias)
+
+
+def _append_qdq(x, is_per_channel, is_bias, kwargs):
+    """
+    Helper function to append q-dq ops after `x`, using dummy values for the qparams
+    and qmin/qmax. We use dummy values here because we match with `ignore_literals=True`
+    and will manually replace these values after subgraph rewriting.
+
+    Return the dq node.
+    """
+    # Dummy args to be passed into q-dq ops
+    per_channel_axis = 0
+    scale_key = "bias_scale" if is_bias else "weight_scale"
+    zp_key = "bias_zero_point" if is_bias else "weight_zero_point"
+    scale = kwargs[scale_key] if is_per_channel else 1.0
+    zp = kwargs[zp_key] if is_per_channel else 0
+    qmin = -127
+    qmax = 127
+    dtype = torch.int8
+
+    qd = torch.ops.quantized_decomposed
+    if is_per_channel:
+        x = qd.quantize_per_channel(x, scale, zp, per_channel_axis, qmin, qmax, dtype)
+        x = qd.dequantize_per_channel(x, scale, zp, per_channel_axis, qmin, qmax, dtype)
+    else:
+        x = qd.quantize_per_tensor(x, scale, zp, qmin, qmax, dtype)
+        x = qd.dequantize_per_tensor(x, scale, zp, qmin, qmax, dtype)
+    return x
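+
+
+# Illustrative sketch (editor's addition): in the per-tensor case the helper above just
+# round-trips a tensor through a quantize/dequantize pair with the dummy qparams
+# (scale 1.0, zero point 0, int8 range), so integer-valued inputs within [-127, 127]
+# come back unchanged.
+def _example_append_qdq_roundtrip() -> torch.Tensor:
+    x = torch.tensor([0.0, 1.0, -2.0, 100.0])
+    return _append_qdq(x, is_per_channel=False, is_bias=False, kwargs={})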
+
+
+def _get_quantized_qat_conv_bn_pattern(
+    is_per_channel: bool,
+    has_bias: bool,
+    bias_is_quantized: bool,
+    conv_fn: Callable,
+    bn_is_training: bool,
+) -> Callable:
+    """
+    Return the quantized version of QAT conv + BN pattern.
+    This is based on `nniqat.ConvBn2d._forward_approximate`,
+    used in QAT convert. We first match this pattern and replace
+    it with the normal [conv - bn] pattern, then fold the BN
+    weights into conv.
+    """
+    # TODO: allow setting eps
+    bn_eps = 1e-5
+
+    def _quantized_qat_conv_bn_pattern(
+        x: torch.Tensor,
+        conv_weight: torch.Tensor,
+        bn_weight: torch.Tensor,
+        bn_bias: torch.Tensor,
+        bn_running_mean: torch.Tensor,
+        bn_running_var: torch.Tensor,
+        **kwargs,
+    ) -> torch.Tensor:
+        running_std = torch.sqrt(bn_running_var + bn_eps)
+        scale_factor = bn_weight / running_std
+        weight_shape = [1] * len(conv_weight.shape)
+        weight_shape[0] = -1
+        bias_shape = [1] * len(conv_weight.shape)
+        bias_shape[1] = -1
+        scaled_weight = conv_weight * scale_factor.reshape(weight_shape)
+        scaled_weight = _append_qdq(
+            scaled_weight,
+            is_per_channel,
+            is_bias=False,
+            kwargs=kwargs,
+        )
+        if has_bias:
+            zero_bias = torch.zeros_like(kwargs["conv_bias"], dtype=x.dtype)
+            if bias_is_quantized:
+                zero_bias = _append_qdq(
+                    zero_bias,
+                    is_per_channel,
+                    is_bias=True,
+                    kwargs=kwargs,
+                )
+            x = conv_fn(x, scaled_weight, zero_bias)
+        else:
+            x = conv_fn(x, scaled_weight, None)
+        x = x / scale_factor.reshape(bias_shape)
+        if has_bias:
+            x = x + kwargs["conv_bias"].reshape(bias_shape)
+        x = F.batch_norm(
+            x,
+            bn_running_mean,
+            bn_running_var,
+            bn_weight,
+            bn_bias,
+            training=bn_is_training,
+            eps=bn_eps,
+        )
+        return x
+
+    return _WrapperModule(_quantized_qat_conv_bn_pattern)
+
+
+def _get_folded_quantized_qat_conv_bn_pattern(
+    is_per_channel: bool,
+    has_bias: bool,
+    bias_is_quantized: bool,
+    conv_fn: Callable,
+    bn_is_training: bool,
+) -> Callable:
+    """
+    Quantized QAT conv - bn pattern with bn weights being folded into conv.
+    """
+    # TODO: allow setting eps
+    bn_eps = 1e-5
+
+    def _folded_quantized_qat_conv_bn_pattern(
+        x: torch.Tensor,
+        conv_weight: torch.Tensor,
+        bn_weight: torch.Tensor,
+        bn_bias: torch.Tensor,
+        bn_running_mean: torch.Tensor,
+        bn_running_var: torch.Tensor,
+        **kwargs,
+    ) -> torch.Tensor:
+        conv_weight = _append_qdq(
+            conv_weight,
+            is_per_channel,
+            is_bias=False,
+            kwargs=kwargs,
+        )
+        if has_bias:
+            bias = kwargs["conv_bias"]
+            if bias_is_quantized:
+                bias = _append_qdq(
+                    bias,
+                    is_per_channel,
+                    is_bias=True,
+                    kwargs=kwargs,
+                )
+        else:
+            bias = None
+        x = conv_fn(x, conv_weight, bias)
+        x = F.batch_norm(
+            x,
+            bn_running_mean,
+            bn_running_var,
+            bn_weight,
+            bn_bias,
+            training=bn_is_training,
+            eps=bn_eps,
+        )
+        return x
+
+    return _WrapperModule(_folded_quantized_qat_conv_bn_pattern)
+
+
+def _has_conv_bias_filter(
+    match: "InternalMatch",
+    original_graph: Graph,
+    pattern_graph: Graph,
+) -> bool:
+    """
+    Match filter for the subgraph rewriter that returns True if the conv node in
+    the original graph has bias.
+    """
+    for n in match.nodes_map.values():
+        if _is_conv_or_conv_transpose_node(n):
+            return len(n.args) > 2 and n.args[2] is not None
+    raise ValueError("Could not find conv node in matched conv + bn pattern")
+
+
+def _no_conv_bias_filter(
+    match: "InternalMatch",
+    original_graph: Graph,
+    pattern_graph: Graph,
+) -> bool:
+    """
+    Match filter for the subgraph rewriter that returns True if the conv node in
+    the original graph does NOT have bias.
+    """
+    return not _has_conv_bias_filter(match, original_graph, pattern_graph)
+
+
+def _is_quantize(n: Node) -> bool:
+    return n.target in [
+        torch.ops.quantized_decomposed.quantize_per_tensor.default,
+        torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
+        torch.ops.quantized_decomposed.quantize_per_channel.default,
+    ]
+
+
+def _is_dequantize(n: Node) -> bool:
+    return n.target in [
+        torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+        torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
+        torch.ops.quantized_decomposed.dequantize_per_channel.default,
+    ]
+
+
+def _get_conv_bn_pattern_nodes(r: ReplacedPatterns) -> dict[str, tuple[Node, Node]]:
+    """
+    Helper function to extract the nodes in the conv-bn fusion pattern after
+    subgraph rewriting, in the form of a map:
+
+        {name: (original_node, replacement_node)}
+
+    The following names must exist in the map:
+
+        "conv", "conv_weight", "conv_input", "bn", "getitem"
+
+    The following names may exist in the map:
+
+        "conv_weight_q", "conv_weight_dq", "conv_bias",
+        "conv_bias_q", "conv_bias_dq"
+    """
+
+    def _get_nodes(nodes: list[Node]) -> tuple[Node, Node, Optional[Node]]:
+        """
+        Return a 3-tuple of (conv_node, bn_node, getitem_node).
+        This asserts that the match contains exactly one of each node.
+        """
+        conv_node, bn_node, getitem_node = None, None, None
+        for n in nodes:
+            if n.op != "call_function":
+                continue
+            if _is_conv_or_conv_transpose_node(n):
+                assert conv_node is None
+                conv_node = n
+            if _is_bn_node(n):
+                assert bn_node is None
+                bn_node = n
+            if n.target == operator.getitem:
+                assert getitem_node is None
+                getitem_node = n
+        assert conv_node is not None
+        assert bn_node is not None
+        return (conv_node, bn_node, getitem_node)
+
+    def _get_q_dq_nodes(n: Node) -> tuple[Node, Node, Node]:
+        """
+        Return a 3-tuple of (orig_node, q_node, dq_node).
+        """
+        assert _is_dequantize(n)
+        q_node = n.args[0]
+        assert isinstance(q_node, Node)
+        assert _is_quantize(q_node)
+        orig_node = q_node.args[0]
+        assert isinstance(orig_node, Node)
+        return (orig_node, q_node, n)
+
+    original_nodes = list(_filter_nodes_map(r.nodes_map).values())
+    o_conv, o_bn, o_getitem = _get_nodes(original_nodes)
+    r_conv, r_bn, r_getitem = _get_nodes(r.replacements)
+
+    # Create the mapping from original node to replacement node
+    assert o_getitem is None
+    assert r_getitem is None
+    mapping = {
+        "conv": (o_conv, r_conv),
+        "bn": (o_bn, r_bn),
+    }
+
+    # Extract conv input and weight
+    # Note: here we extract the original nodes indirectly through the pattern nodes
+    # because the args of the original nodes are no longer available after replacement
+    (p_conv, _, _) = _get_nodes(list(r.nodes_map.keys()))
+    (p_conv_input, p_conv_weight, *_) = p_conv.args
+    (r_conv_input, r_conv_weight, *_) = r_conv.args
+    assert isinstance(p_conv_input, Node)
+    assert isinstance(p_conv_weight, Node)
+    assert isinstance(r_conv_input, Node)
+    assert isinstance(r_conv_weight, Node)
+    o_conv_input = r.nodes_map[p_conv_input]
+    o_conv_weight = r.nodes_map[p_conv_weight]
+
+    # If conv weight is quantized, extract the q - dq nodes
+    if _is_dequantize(p_conv_weight):
+        p_conv_weight, p_conv_weight_q, p_conv_weight_dq = _get_q_dq_nodes(
+            p_conv_weight
+        )
+        r_conv_weight, r_conv_weight_q, r_conv_weight_dq = _get_q_dq_nodes(
+            r_conv_weight
+        )
+        o_conv_weight = r.nodes_map[p_conv_weight]
+        o_conv_weight_q = r.nodes_map[p_conv_weight_q]
+        o_conv_weight_dq = r.nodes_map[p_conv_weight_dq]
+        mapping["conv_weight_q"] = (o_conv_weight_q, r_conv_weight_q)
+        mapping["conv_weight_dq"] = (o_conv_weight_dq, r_conv_weight_dq)
+    mapping["conv_input"] = (o_conv_input, r_conv_input)
+    mapping["conv_weight"] = (o_conv_weight, r_conv_weight)
+
+    # Extract conv bias
+    if len(p_conv.args) > 2 and len(r_conv.args) > 2:
+        p_conv_bias = p_conv.args[2]
+        r_conv_bias = r_conv.args[2]
+        assert isinstance(p_conv_bias, Node)
+        assert isinstance(r_conv_bias, Node)
+        o_conv_bias = r.nodes_map[p_conv_bias]
+
+        # If conv bias is quantized, extract the q - dq nodes
+        if _is_dequantize(p_conv_bias):
+            p_conv_bias, p_conv_bias_q, p_conv_bias_dq = _get_q_dq_nodes(p_conv_bias)
+            r_conv_bias, r_conv_bias_q, r_conv_bias_dq = _get_q_dq_nodes(r_conv_bias)
+            o_conv_bias = r.nodes_map[p_conv_bias]
+            o_conv_bias_q = r.nodes_map[p_conv_bias_q]
+            o_conv_bias_dq = r.nodes_map[p_conv_bias_dq]
+            mapping["conv_bias_q"] = (o_conv_bias_q, r_conv_bias_q)
+            mapping["conv_bias_dq"] = (o_conv_bias_dq, r_conv_bias_dq)
+        mapping["conv_bias"] = (o_conv_bias, r_conv_bias)
+    return mapping
+
+
+def _filter_nodes_map(nodes_map: dict[Node, Node]) -> dict[Node, Node]:
+    """
+    Return a filtered `nodes_map` returned from the subgraph rewriter.
+    The filtered `nodes_map` will contain only nodes that are actually
+    matched in the pattern, excluding None or placeholder nodes.
+    """
+    new_nodes_map: dict[Node, Node] = {}
+    for pattern_node, graph_node in nodes_map.items():
+        # bias can be None
+        if graph_node is None:
+            continue
+        # skip pattern placeholder nodes
+        if pattern_node.op == "placeholder":
+            continue
+        new_nodes_map[pattern_node] = graph_node
+    return new_nodes_map
+
+
+# TODO: this is error prone, use the replace_literals_with_placeholders hack instead
+def _copy_over_literal_conv_args(original_node: Node, new_node: Node):
+    """
+    Copy over literal args in conv, such as stride and padding, from the matched node
+    in the original graph to its replacement in the new graph.
+
+    This is needed due to the following limitation in the subgraph rewriter when used
+    with dynamo export: literal (non-tensor) args are not supported in the match and
+    replacement patterns. This is because dynamo export automatically inlines these
+    literal args, making them dead placeholder nodes. In the future, we should check
+    if dynamo export can optionally disable this inlining, or if subgraph rewriter
+    can do the copying for us. See https://github.com/pytorch/pytorch/issues/100419.
+
+    Note: Unlike other tensor args like conv weights and biases, literal args are
+    preserved in the original nodes after replacement, so we can access them here.
+    """
+    assert _is_conv_or_conv_transpose_node(original_node)
+    assert _is_conv_or_conv_transpose_node(new_node)
+    # x, weight, bias, [stride, padding, dilation, transposed, output_padding, groups]
+    new_args = list(new_node.args)
+    if len(new_args) < 3:
+        # bias is optional, when it is not present, it means it is None
+        new_args.append(None)
+    new_node.args = tuple(new_args[:3]) + original_node.args[3:]
+
+
+def _update_conv_input_qspec_map_after_replacement(
+    original_node: Node, replacement_node: Node
+):
+    """
+    Update the `input_qspec_map` in the annotation after subgraph rewriting.
+
+    The original annotation referred to the nodes in the original graph,
+    so the keys in the `input_qspec_map` will need to be updated to reflect
+    the corresponding nodes in the replacement graph.
+    """
+    assert _is_conv_or_conv_transpose_node(original_node)
+    assert _is_conv_or_conv_transpose_node(replacement_node)
+    if "quantization_annotation" not in original_node.meta:
+        return
+    original_input_qspec_map = original_node.meta[
+        "quantization_annotation"
+    ].input_qspec_map
+    input_qspec_map = {}
+    # get the list of configs, it should be ordered as input, weight, bias
+    # note: this is really hacky, we need a better solution, hopefully
+    # in subgraph_rewriter, issue tracking the problem: https://github.com/pytorch/pytorch/issues/101820
+    all_configs = list(original_input_qspec_map.items())
+    # input activation
+    input_qspec_map[replacement_node.args[0]] = all_configs[0][1]
+    # weight
+    input_qspec_map[replacement_node.args[1]] = all_configs[1][1]
+    # bias
+    if len(replacement_node.args) > 2 and len(all_configs) > 2:
+        input_qspec_map[replacement_node.args[2]] = all_configs[2][1]
+    replacement_node.meta["quantization_annotation"].input_qspec_map = input_qspec_map
+
+
+def _update_special_qspecs_after_replacement(
+    node: Node,
+    original_to_replacement_node: dict[Node, Node],
+):
+    """
+    Update the `SharedQuantizationSpec`s and `DerivedQuantizationSpec`s
+    used in `node`'s quantization annotation after subgraph rewriting.
+
+    The original annotation referred to the nodes in the original graph,
+    so the nodes used in these special quantization specs will need to
+    be updated to the corresponding nodes in the replacement graph.
+    """
+
+    def _get_new_edge_or_node(edge_or_node: EdgeOrNode):
+        if isinstance(edge_or_node, Node):
+            _node = edge_or_node
+            return original_to_replacement_node.get(_node, _node)
+        elif (
+            isinstance(edge_or_node, tuple)
+            and len(edge_or_node) == 2
+            and all(isinstance(x, Node) for x in edge_or_node)
+        ):
+            src, dest = edge_or_node
+            return (
+                original_to_replacement_node.get(src, src),
+                original_to_replacement_node.get(dest, dest),
+            )
+        else:
+            raise ValueError("unexpected type for edge_or_node: ", type(edge_or_node))
+
+    def _get_new_qspec(qspec: QuantizationSpecBase):
+        if isinstance(qspec, SharedQuantizationSpec):
+            new_edge_or_node = _get_new_edge_or_node(qspec.edge_or_node)
+            return SharedQuantizationSpec(new_edge_or_node)
+        elif isinstance(qspec, DerivedQuantizationSpec):
+            new_derived_from = [_get_new_edge_or_node(x) for x in qspec.derived_from]
+            return dataclasses.replace(qspec, derived_from=new_derived_from)
+        else:
+            return qspec
+
+    if "quantization_annotation" not in node.meta:
+        return
+    annotation = node.meta["quantization_annotation"]
+    for input_node, qspec in annotation.input_qspec_map.items():
+        annotation.input_qspec_map[input_node] = _get_new_qspec(qspec)
+    annotation.output_qspec = _get_new_qspec(annotation.output_qspec)
+
+
+def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
+    # Example inputs for conv-bn1d patterns
+    _conv1d_bn_example_inputs = (
+        torch.randn(1, 1, 3),  # x
+        torch.randn(1, 1, 1),  # conv_weight
+        torch.randn(1),  # conv_bias
+        torch.randn(1),  # bn_weight
+        torch.randn(1),  # bn_bias
+        torch.randn(1),  # bn_running_mean
+        torch.randn(1),  # bn_running_var
+    )
+
+    # Example inputs for conv-bn2d patterns
+    _conv2d_bn_example_inputs = (
+        torch.randn(1, 1, 3, 3),  # x
+        torch.randn(1, 1, 1, 1),  # conv_weight
+        torch.randn(1),  # conv_bias
+        torch.randn(1),  # bn_weight
+        torch.randn(1),  # bn_bias
+        torch.randn(1),  # bn_running_mean
+        torch.randn(1),  # bn_running_var
+    )
+
+    has_bn = any(_is_bn_node(n) for n in m.graph.nodes)
+    if not has_bn:
+        return m
+    is_cuda_options = [True, False] if torch.cuda.is_available() else [False]
+    for is_cuda in is_cuda_options:
+        m = _fuse_conv_bn_qat_helper(
+            m, F.conv1d, _conv1d_bn_example_inputs, is_cuda=is_cuda
+        )
+        m = _fuse_conv_bn_qat_helper(
+            m, F.conv2d, _conv2d_bn_example_inputs, is_cuda=is_cuda
+        )
+        m = _fuse_conv_bn_qat_helper(
+            m, F.conv_transpose1d, _conv1d_bn_example_inputs, is_cuda=is_cuda
+        )
+        m = _fuse_conv_bn_qat_helper(
+            m, F.conv_transpose2d, _conv2d_bn_example_inputs, is_cuda=is_cuda
+        )
+    return m
+
+
+def _fuse_conv_bn_qat_helper(
+    m: GraphModule,
+    conv_fn: Callable,
+    example_inputs: tuple[Any, ...],
+    is_cuda: bool,
+) -> GraphModule:
+    """
+    Given a graph of decomposed aten ops, replace the (conv + bn) pattern with
+    the fused QAT subgraph equivalent. The input graph should already be annotated.
+    The annotations in the original nodes will be preserved in the corresponding
+    nodes in the new subgraph.
+
+    Note: This also handles the (conv + bn + relu) pattern.
+    """
+    m.graph.eliminate_dead_code()
+    m.recompile()
+
+    conv_bn_pattern = _get_conv_bn_pattern(conv_fn)
+    match_pattern = _get_aten_graph_module_for_pattern(
+        conv_bn_pattern,
+        example_inputs,
+        is_cuda,
+    )
+
+    # Step (1): Replace patterns with conv bias
+    #
+    # Here we do replacement separately for cases with and without conv bias, since
+    # the replacement patterns for these two cases are substantially different.
+    # TODO: use the public replace_pattern API once it also returns replacement nodes
+
+    qat_conv_bn_pattern = _get_qat_conv_bn_pattern(conv_fn)
+    replacement_pattern_with_conv_bias = _get_aten_graph_module_for_pattern(
+        qat_conv_bn_pattern,
+        example_inputs,
+        is_cuda,
+    )
+    replacements_with_conv_bias = replace_pattern_with_filters(
+        m,
+        match_pattern,
+        replacement_pattern_with_conv_bias,
+        match_filters=[_has_conv_bias_filter],
+        ignore_literals=True,
+    )
+    m.recompile()
+
+    # Step (2): Replace patterns without conv bias
+
+    qat_conv_bn_pattern_no_conv_bias = _get_qat_conv_bn_pattern_no_conv_bias(conv_fn)
+    replacement_pattern_no_conv_bias = _get_aten_graph_module_for_pattern(
+        qat_conv_bn_pattern_no_conv_bias,
+        example_inputs,
+        is_cuda,
+    )
+    replacements_no_conv_bias = replace_pattern_with_filters(
+        m,
+        match_pattern,
+        replacement_pattern_no_conv_bias,
+        match_filters=[_no_conv_bias_filter],
+        ignore_literals=True,
+    )
+    m.recompile()
+
+    # Step (3): Post processing
+    #
+    # Due to limited functionality in the subgraph rewriter, here we manually
+    # update the replacement graph as follows:
+    #
+    #   (a) Copy over metadata from original subgraph. This ensures the stack traces
+    #       and annotations are preserved in the new subgraph
+    #
+    #   (b) Copy over literal args for conv from the original subgraph
+    #       TODO: do this for literal args for batchnorm as well
+    #
+    #   (c) Update all references of the old nodes in the original subgraph to refer
+    #       to the corresponding nodes in the new subgraph in the annotations
+    #
+    # In the future, we should try to push as much of this functionality into the
+    # subgraph rewriter as possible, so we don't have to manually copy anything over.
+    # For more detail, see https://github.com/pytorch/pytorch/issues/100419.
+
+    all_original_to_replacement_nodes = {}
+    for r in replacements_with_conv_bias + replacements_no_conv_bias:
+        replacement_dict = _get_conv_bn_pattern_nodes(r)
+        # The original conv node's "nn_module_stack"
+        conv_nn_module = replacement_dict["conv"][0].meta.get("nn_module_stack", None)
+        for k, node_tuple in replacement_dict.items():
+            original_node, replacement_node = node_tuple
+            # Step (3a): Copy over metadata for all nodes in [conv - bn - getitem]
+            replacement_node.meta = original_node.meta
+            # If original_node is a get_attr node, it doesn't have nn_module_stack.
+            # In this case, we copy nn_module_stack from the original conv node.
+            if (
+                k in ["conv_input", "conv_weight"]
+                and conv_nn_module
+                and "nn_module_stack" not in replacement_node.meta
+            ):
+                replacement_node.meta["nn_module_stack"] = copy.deepcopy(conv_nn_module)
+            if _is_conv_or_conv_transpose_node(original_node):
+                # Step (3b): Copy over conv literal args
+                _copy_over_literal_conv_args(original_node, replacement_node)
+                # Step (3c): Update old references in the conv node's input_qspec_map
+                _update_conv_input_qspec_map_after_replacement(
+                    original_node, replacement_node
+                )
+            all_original_to_replacement_nodes[original_node] = replacement_node
+
+    # Step (3c): Update old references in the special qspecs for all nodes in the graph
+    for n in m.graph.nodes:
+        _update_special_qspecs_after_replacement(n, all_original_to_replacement_nodes)
+
+    return m
+
+
+def _duplicate_dequantize_node(m: GraphModule):
+    """
+    Helper function to duplicate every dequantize node in the graph that
+    has more than one user. For example:
+
+    Before:
+      quantize -> dequantize -> a
+                          \\--> b
+                          \\--> c
+
+    After:
+      quantize -> dequantize_1 -> a
+            \\--> dequantize_2 -> b
+            \\--> dequantize_3 -> c
+
+    This is useful for subgraph rewriting. E.g. if we wish to match the
+    pattern [dequantize - a] above, subgraph matching would fail because
+    the dequantize node has users outside the matched portion of the graph.
+    Instead, we match [dequantize_1 - a], which is safe.
+    """
+    dq_op = torch.ops.quantized_decomposed.dequantize_per_tensor
+    for n in m.graph.nodes:
+        if n.op != "call_function" or n.target != dq_op or len(n.users) == 1:
+            continue
+        for user in list(n.users):
+            with m.graph.inserting_before(n):
+                new_node = m.graph.create_node("call_function", dq_op, n.args, n.kwargs)
+            user.replace_input_with(n, new_node)
+        m.graph.erase_node(n)
+    m.recompile()
+
+
+def _remove_extra_dequantize(m: GraphModule):
+    """
+    Remove duplicate dequantize nodes in the graph: for an operator that has
+    multiple dequantize nodes as users, replace them with a single dequantize
+    node that can be shared across all the uses. This should be seen as the
+    "reverse" of `_duplicate_dequantize_node`.
+    """
+    dq_op = torch.ops.quantized_decomposed.dequantize_per_tensor
+    for n in m.graph.nodes:
+        dq_users = [
+            user
+            for user in n.users
+            if user.op == "call_function" and user.target == dq_op
+        ]
+        if len(dq_users) > 1:
+            with m.graph.inserting_after(dq_users[0]):
+                new_node = m.graph.create_node(
+                    "call_function", dq_op, dq_users[0].args, {}
+                )
+            for dq_user in dq_users:
+                dq_user.replace_all_uses_with(new_node)
+                m.graph.erase_node(dq_user)
+    m.recompile()
+
+
+def _copy_over_q_dq_args(original_node: Node, replacement_node: Node):
+    """
+    Given a pair of quantize or dequantize nodes, copy over all literal args
+    from the original node to the replacement node.
+    """
+    # For quantize_per_tensor, scale and zp are literals and need to be copied
+    # For quantize_per_channel, scale and zp are get_attr nodes and should be skipped
+    assert original_node.target == replacement_node.target
+    if original_node.target in (
+        torch.ops.quantized_decomposed.quantize_per_tensor.default,
+        torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+    ):
+        # Args: input, [scale, zp, qmin, qmax, dtype]
+        start_copy_arg_index = 1
+    elif original_node.target in (
+        torch.ops.quantized_decomposed.quantize_per_channel.default,
+        torch.ops.quantized_decomposed.dequantize_per_channel.default,
+    ):
+        # Args: input, scale, zp, [axis, qmin, qmax, dtype]
+        start_copy_arg_index = 3
+    else:
+        raise ValueError(
+            f"Expected quantize/dequantize nodes, got '{original_node.target}'"
+        )
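+    # For example (hypothetical values), for a per-tensor q/dq pair we might have
+    #   original_node.args    = (x, 0.05, 2, -128, 127, torch.int8)
+    #   replacement_node.args = (y, 1.0,  0, -128, 127, torch.int8)
+    # and copying args[1:] from the original yields
+    #   replacement_node.args = (y, 0.05, 2, -128, 127, torch.int8)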
+    replacement_node.args = (
+        replacement_node.args[:start_copy_arg_index]
+        + original_node.args[start_copy_arg_index:]
+    )
+
+
+def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
+    # Example inputs for quantized and folded conv-bn1d patterns used in convert
+    _quantized_conv1d_bn_example_inputs = (
+        torch.randn(1, 1, 3),  # x
+        torch.randn(1, 1, 1),  # conv_weight
+        torch.randn(1),  # bn_weight
+        torch.randn(1),  # bn_bias
+        torch.randn(1),  # bn_running_mean
+        torch.randn(1),  # bn_running_var
+    )
+
+    # Example inputs for quantized and folded conv-bn2d patterns used in convert
+    _quantized_conv2d_bn_example_inputs = (
+        torch.randn(1, 1, 3, 3),  # x
+        torch.randn(1, 1, 1, 1),  # conv_weight
+        torch.randn(1),  # bn_weight
+        torch.randn(1),  # bn_bias
+        torch.randn(1),  # bn_running_mean
+        torch.randn(1),  # bn_running_var
+    )
+
+    has_bn = any(_is_bn_node(n) for n in m.graph.nodes)
+    if not has_bn:
+        return m
+    is_cuda_options = [True, False] if torch.cuda.is_available() else [False]
+    for is_cuda in is_cuda_options:
+        m = _fold_conv_bn_qat_helper(
+            m, F.conv1d, _quantized_conv1d_bn_example_inputs, is_cuda=is_cuda
+        )
+        m = _fold_conv_bn_qat_helper(
+            m, F.conv2d, _quantized_conv2d_bn_example_inputs, is_cuda=is_cuda
+        )
+        m = _fold_conv_bn_qat_helper(
+            m, F.conv_transpose1d, _quantized_conv1d_bn_example_inputs, is_cuda=is_cuda
+        )
+        m = _fold_conv_bn_qat_helper(
+            m, F.conv_transpose2d, _quantized_conv2d_bn_example_inputs, is_cuda=is_cuda
+        )
+
+    # Remove the in-place add used by batch norm to track training stats
+    for node in m.graph.nodes:
+        if (
+            node.target == torch.ops.aten.add_.Tensor
+            and node.args[0].op == "get_attr"
+            and node.args[1] == 1
+            and (
+                torch.nn.modules.batchnorm.BatchNorm2d
+                in [val[1] for val in node.meta["source_fn_stack"]]
+                or torch.nn.modules.batchnorm.BatchNorm1d
+                in [val[1] for val in node.meta["source_fn_stack"]]
+            )
+        ):
+            m.graph.erase_node(node)
+
+    m.graph.eliminate_dead_code()
+    m.recompile()
+
+    return m
+
+
+def _fold_conv_bn_qat_helper(
+    m: GraphModule,
+    conv_fn: Callable,
+    example_inputs: tuple[Any, ...],
+    is_cuda: bool,
+) -> GraphModule:
+    """
+    Replace the quantized (conv + bn) pattern with conv with bn weights folded into the weights of conv.
+    """
+
+    m.graph.eliminate_dead_code()
+    m.recompile()
+    _duplicate_dequantize_node(m)
+
+    # Step (1): Replace QAT pattern with simple [conv - bn] pattern
+    replacements = []
+    replacement_options = itertools.product(
+        [True, False],  # is_per_channel
+        [True, False],  # has_bias
+        [True, False],  # bias_is_quantized
+        [True, False],  # bn_is_training
+    )
+    for (
+        is_per_channel,
+        has_bias,
+        bias_is_quantized,
+        bn_is_training,
+    ) in replacement_options:
+        # For the cases without bias, `bias_is_quantized` is irrelevant, so here we arbitrarily
+        # filter out one of the values for this flag to avoid having duplicate patterns
+        if not has_bias and bias_is_quantized:
+            continue
+        kwargs = _get_quantized_conv_bn_example_inputs_kwargs(
+            is_per_channel, has_bias, bias_is_quantized, is_cuda
+        )
+        match_pattern = _get_quantized_qat_conv_bn_pattern(
+            is_per_channel, has_bias, bias_is_quantized, conv_fn, bn_is_training
+        )
+        match_pattern = _get_aten_graph_module_for_pattern(
+            match_pattern,
+            example_inputs,
+            is_cuda,
+            **kwargs,
+        )
+        replacement_pattern = _get_folded_quantized_qat_conv_bn_pattern(
+            is_per_channel, has_bias, bias_is_quantized, conv_fn, bn_is_training
+        )
+        replacement_pattern = _get_aten_graph_module_for_pattern(
+            replacement_pattern,
+            example_inputs,
+            is_cuda,
+            **kwargs,
+        )
+        replacements.extend(
+            replace_pattern_with_filters(
+                m,
+                match_pattern,
+                replacement_pattern,
+                ignore_literals=True,
+            )
+        )
+    m.recompile()
+    _remove_extra_dequantize(m)
+
+    for r in replacements:
+        node_map = _get_conv_bn_pattern_nodes(r)
+
+        # Step (2): Copy over metadata from original subgraph
+        for original_node, replacement_node in node_map.values():
+            replacement_node.meta = original_node.meta
+
+        # Step (3): Copy over args for weight (and optionally bias) q - dq nodes
+        _copy_over_q_dq_args(*node_map["conv_weight_q"])
+        _copy_over_q_dq_args(*node_map["conv_weight_dq"])
+        if "conv_bias_q" in node_map:
+            assert "conv_bias_dq" in node_map
+            _copy_over_q_dq_args(*node_map["conv_bias_q"])
+            _copy_over_q_dq_args(*node_map["conv_bias_dq"])
+
+        # Step (4): Fold BN weights into conv
+        conv_bias = None
+        (_, conv_node) = node_map["conv"]
+        (_, bn_node) = node_map["bn"]
+        (_, conv_weight) = node_map["conv_weight"]
+        if "conv_bias" in node_map:
+            (_, conv_bias) = node_map["conv_bias"]
+        fold_bn_weights_into_conv_node(conv_node, conv_weight, conv_bias, bn_node, m)
+
+        # Copy over literal args for conv
+        for original_node in _filter_nodes_map(r.nodes_map).values():
+            if _is_conv_or_conv_transpose_node(original_node):
+                _copy_over_literal_conv_args(original_node, conv_node)
+
+    m.graph.eliminate_dead_code()
+    m.recompile()
+    return m
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/representation/__init__.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/representation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8876d439feb41929ca9b64f3f023db499eac007b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/representation/__init__.py
@@ -0,0 +1,6 @@
+from .rewrite import reference_representation_rewrite
+
+
+__all__ = [
+    "reference_representation_rewrite",
+]
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/representation/rewrite.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/representation/rewrite.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc439a7c405366b04c2556cc3d11ce549bd9595f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/representation/rewrite.py
@@ -0,0 +1,824 @@
+# mypy: allow-untyped-defs
+from dataclasses import dataclass
+from functools import partial
+from typing import Any, Callable, Optional
+
+import torch
+from torch._export.utils import _disable_aten_to_metadata_assertions
+from torch._higher_order_ops.out_dtype import out_dtype
+from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
+from torch.ao.quantization.pt2e.export_utils import _WrapperModule
+from torch.ao.quantization.pt2e.utils import (
+    _get_aten_graph_module_for_pattern,
+    _replace_literals_with_existing_placeholders,
+    _replace_literals_with_new_placeholders,
+    remove_tensor_overload_for_qdq_ops,
+)
+from torch.fx import GraphModule
+from torch.fx.subgraph_rewriter import replace_pattern
+
+
+__all__ = [
+    "reference_representation_rewrite",
+]
+
+
+def _qdq_quantized_linear(
+    x_i8,
+    x_scale,
+    x_zero_point,
+    x_quant_min,
+    x_quant_max,
+    weight_i8,
+    weight_scale,
+    weight_zero_point,
+    weight_quant_min,
+    weight_quant_max,
+    bias_fp32,
+    out_scale,
+    out_zero_point,
+    out_quant_min,
+    out_quant_max,
+):
+    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
+        x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8
+    )
+    weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
+        weight_i8,
+        weight_scale,
+        weight_zero_point,
+        weight_quant_min,
+        weight_quant_max,
+        torch.int8,
+    )
+    out_fp32 = torch.ops.aten.linear.default(x_fp32, weight_fp32, bias_fp32)
+    out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
+        out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8
+    )
+    return out_i8
+
+
+def _reference_quantized_linear(
+    x_i8,
+    x_scale,
+    x_zero_point,
+    x_quant_min,
+    x_quant_max,
+    weight_i8,
+    weight_scale,
+    weight_zero_point,
+    weight_quant_min,
+    weight_quant_max,
+    bias_fp32,
+    out_scale,
+    out_zero_point,
+    out_quant_min,
+    out_quant_max,
+):
+    # Without using quant_min/quant_max in clamp, the traced graph will not have
+    # quant_min/quant_max args, and pattern matching will fail.
+    # Therefore, we call torch.ops.aten.clamp here
+    x_i8 = torch.ops.aten.clamp(x_i8, x_quant_min, x_quant_max)
+    weight_i8 = torch.ops.aten.clamp(weight_i8, weight_quant_min, weight_quant_max)
+
+    x_i16 = x_i8.to(torch.int16)
+    weight_i16 = weight_i8.to(torch.int16)
+    # Always set bias to None so that the same representation works regardless of
+    # whether bias_scale == x_scale * weight_scale
+    acc_i32 = out_dtype(
+        torch.ops.aten.linear.default,
+        torch.int32,
+        x_i16 - x_zero_point,
+        weight_i16 - weight_zero_point,
+        None,
+    )
+    # TODO: change to mul.Scalar
+    # Note: we are quantizing the bias with these scales without a signal from the user, but it might be OK
+    bias_scale = x_scale * weight_scale
+    bias_i32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale)
+    acc_i32 = acc_i32 + bias_i32
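+    # Requantize to the output scale: acc_i32 approximates out_fp32 / (x_scale * weight_scale),
+    # so out_i8 ~= acc_i32 * (x_scale * weight_scale / out_scale) + out_zero_point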
+    # TODO: change to mul.Scalar when we make x_scale/weight_scale etc. Scalar values
+    acc_i32 = (
+        out_dtype(
+            torch.ops.aten.mul.Tensor,
+            torch.int32,
+            acc_i32,
+            x_scale * weight_scale / out_scale,
+        )
+        + out_zero_point
+    )
+    out_i8 = torch.ops.aten.clamp(acc_i32, out_quant_min, out_quant_max).to(torch.int8)
+    return out_i8
+
+
+def _qdq_dynamic_quantized_linear(
+    x_fp32,
+    x_quant_min,
+    x_quant_max,
+    x_eps,
+    weight_i8,
+    weight_scale,
+    weight_zero_point,
+    weight_quant_min,
+    weight_quant_max,
+    bias_fp32,
+):
+    x_scale, x_zero_point = torch.ops.quantized_decomposed.choose_qparams(
+        x_fp32, x_quant_min, x_quant_max, x_eps, torch.int8
+    )
+    x_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
+        x_fp32, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8
+    )
+    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
+        x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8
+    )
+    weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
+        weight_i8,
+        weight_scale,
+        weight_zero_point,
+        weight_quant_min,
+        weight_quant_max,
+        torch.int8,
+    )
+    out_fp32 = torch.ops.aten.linear.default(x_fp32, weight_fp32, bias_fp32)
+    return out_fp32
+
+
+def _reference_dynamic_quantized_linear(
+    x_fp32,
+    x_quant_min,
+    x_quant_max,
+    x_eps,
+    weight_i8,
+    weight_scale,
+    weight_zero_point,
+    weight_quant_min,
+    weight_quant_max,
+    bias_fp32,
+):
+    x_scale, x_zero_point = torch.ops.quantized_decomposed.choose_qparams(
+        x_fp32, x_quant_min, x_quant_max, x_eps, torch.int8
+    )
+    # decomposed representation for quantize_per_tensor
+    # TODO: use out_dtype(mul, ...) here when the op is ready
+    x_fp32 = x_fp32 / x_scale  # fp32
+    # Rounding modes might differ here;
+    # PyTorch rounds to even, which is also common for most backends
+    x_fp32 = torch.round(x_fp32)  # fp32
+    x_i32 = x_fp32.to(dtype=torch.int32)  # int32
+    x_i32 = x_i32 + x_zero_point  # int32
+    # clamp works for fp32, int32 and int8 dtypes
+    x_i32 = torch.clamp(x_i32, x_quant_min, x_quant_max)  # int32
+    x_i8 = x_i32.to(dtype=torch.int8)
+
+    weight_i8 = torch.ops.aten.clamp(weight_i8, weight_quant_min, weight_quant_max)
+
+    x_i16 = x_i8.to(torch.int16)
+    weight_i16 = weight_i8.to(torch.int16)
+    # Always set bias to None so that the same representation works regardless of
+    # whether bias_scale == x_scale * weight_scale
+    acc_i32 = out_dtype(
+        torch.ops.aten.linear.default,
+        torch.int32,
+        x_i16 - x_zero_point,
+        weight_i16 - weight_zero_point,
+        None,
+    )
+    bias_scale = x_scale * weight_scale
+    bias_i32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale)
+    acc_i32 = acc_i32 + bias_i32
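+    # Dequantize the accumulator: acc_i32 approximates out_fp32 / (x_scale * weight_scale),
+    # so multiplying the scales back in recovers the fp32 output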
+    out_fp32 = acc_i32 * (x_scale * weight_scale)
+    return out_fp32
+
+
+def _qdq_quantized_conv2d(
+    x_i8,
+    x_scale,
+    x_zero_point,
+    x_quant_min,
+    x_quant_max,
+    weight_i8,
+    weight_scale,
+    weight_zero_point,
+    weight_quant_min,
+    weight_quant_max,
+    bias_fp32,
+    out_scale,
+    out_zero_point,
+    out_quant_min,
+    out_quant_max,
+):
+    stride = [1, 1]
+    padding = [0, 0]
+    dilation = [1, 1]
+    transposed = False
+    output_padding = [0, 0]
+    groups = 1
+    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
+        x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8
+    )
+    weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
+        weight_i8,
+        weight_scale,
+        weight_zero_point,
+        weight_quant_min,
+        weight_quant_max,
+        torch.int8,
+    )
+    out_fp32 = torch.ops.aten.convolution.default(
+        x_fp32,
+        weight_fp32,
+        bias_fp32,
+        stride,
+        padding,
+        dilation,
+        transposed,
+        output_padding,
+        groups,
+    )
+    out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
+        out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8
+    )
+    return out_i8
+
+
+def _reference_quantized_conv2d(
+    x_i8,
+    x_scale,
+    x_zero_point,
+    x_quant_min,
+    x_quant_max,
+    weight_i8,
+    weight_scale,
+    weight_zero_point,
+    weight_quant_min,
+    weight_quant_max,
+    bias_fp32,
+    out_scale,
+    out_zero_point,
+    out_quant_min,
+    out_quant_max,
+):
+    stride = [1, 1]
+    padding = [0, 0]
+    dilation = [1, 1]
+    transposed = False
+    output_padding = [0, 0]
+    groups = 1
+    # Without using quant_min/quant_max in clamp, the traced graph will not have
+    # quant_min/quant_max args, and pattern matching will fail.
+    # Therefore, we call torch.ops.aten.clamp here
+    x_i8 = torch.ops.aten.clamp(x_i8, x_quant_min, x_quant_max)
+    weight_i8 = torch.ops.aten.clamp(weight_i8, weight_quant_min, weight_quant_max)
+
+    x_i16 = x_i8.to(torch.int16)
+    weight_i16 = weight_i8.to(torch.int16)
+    # Always set bias to None so that the same representation works regardless of
+    # whether bias_scale == x_scale * weight_scale
+    acc_i32 = out_dtype(
+        torch.ops.aten.convolution.default,
+        torch.int32,
+        x_i16 - x_zero_point,
+        weight_i16 - weight_zero_point,
+        None,
+        stride,
+        padding,
+        dilation,
+        transposed,
+        output_padding,
+        groups,
+    )
+    # Note: we are quantizing the bias with these scales without a signal from the user, but it might be OK
+    bias_scale = x_scale * weight_scale
+    # Bias quantization to int32 uses bias_scale = x_scale * weight_scale because
+    # (taking the linear computation as an example):
+    #   Out_(i, j)_fp32 = Sum_(over k)[X_(i, k)_fp32 * W_(i, k)_fp32] + bias_(i)_fp32
+    # Representing X and W through their dequantization transforms
+    #   A_fp32 = (A_q - A_zero_point) * A_scale
+    # gives
+    #   Out_(i, j)_fp32 = Sum_(over k)[(X_(i, k)_q - X_zp) * X_scale * (W_(i, k)_q - W_zp) * W_scale] + bias_(i)_fp32  # noqa: B950
+    # Factoring out X_scale and W_scale:
+    #   Out_(i, j)_fp32 = (X_scale * W_scale) * Sum_(over k)[(X_(i, k)_q - X_zp) * (W_(i, k)_q - W_zp)] + bias_(i)_fp32  # noqa: B950
+    # To move the bias addition inside the parentheses, we must write
+    #   Out_(i, j)_fp32 = (X_scale * W_scale) * (Sum_(over k)[(X_(i, k)_q - X_zp) * (W_(i, k)_q - W_zp)] + bias_(i)_fp32 / (X_scale * W_scale))  # noqa: B950
+    # i.e. bias_fp32 is divided by X_scale * W_scale = bias_scale,
+    # so bias quantization to int32 must use scale = X_scale * W_scale
+
+    bias_i32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale)
+    # Unsqueeze to match broadcast dims.
+    # Unfortunately we cannot do bias_i32.unsqueeze(0) due to literal-matching issues
+    # in graph pattern replacement
+    bias_i32 = bias_i32.unsqueeze(-1)
+    bias_i32 = bias_i32.unsqueeze(-1)
+    acc_i32 = acc_i32 + bias_i32
+    # TODO: change to mul.Scalar when we make x_scale/weight_scale etc. Scalar values
+    acc_i32 = (
+        out_dtype(
+            torch.ops.aten.mul.Tensor,
+            torch.int32,
+            acc_i32,
+            x_scale * weight_scale / out_scale,
+        )
+        + out_zero_point
+    )
+    out_i8 = torch.ops.aten.clamp(acc_i32, out_quant_min, out_quant_max).to(torch.int8)
+    return out_i8
+
+
+def _qdq_quantized_add_relu(
+    x_i8,
+    x_scale,
+    x_zero_point,
+    y_i8,
+    y_scale,
+    y_zero_point,
+    out_scale,
+    out_zero_point,
+    quant_min,
+    quant_max,
+):
+    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
+        x_i8, x_scale, x_zero_point, quant_min, quant_max, torch.int8
+    )
+    y_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
+        y_i8, y_scale, y_zero_point, quant_min, quant_max, torch.int8
+    )
+    out_fp32 = x_fp32 + y_fp32
+    out_fp32 = torch.ops.aten.relu(out_fp32)
+    out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
+        out_fp32, out_scale, out_zero_point, quant_min, quant_max, torch.int8
+    )
+    return out_i8
+
+
+def _reference_quantized_add_relu(
+    x_i8,
+    x_scale,
+    x_zero_point,
+    y_i8,
+    y_scale,
+    y_zero_point,
+    out_scale,
+    out_zero_point,
+    quant_min,
+    quant_max,
+):
+    """
+    See comments for `_reference_quantized_add` for more information on
+    how to derive the formula for out_i8 based on x_i8 and y_i8
+    """
+    x_i32 = x_i8.to(torch.int32)
+    y_i32 = y_i8.to(torch.int32)
+    # TODO: change this to mul.Scalar?
+    x_i32 = out_dtype(
+        torch.ops.aten.mul.Tensor,
+        torch.int32,
+        (x_i32 - x_zero_point),
+        (x_scale / out_scale),
+    )
+    y_i32 = out_dtype(
+        torch.ops.aten.mul.Tensor,
+        torch.int32,
+        (y_i32 - y_zero_point),
+        (y_scale / out_scale),
+    )
+    out_i32 = x_i32 + y_i32 + out_zero_point
+    # out_i32 = torch.ops.aten.clamp(out_i32, out_zero_point)
+    out_i8 = torch.ops.aten.clamp(out_i32, out_zero_point, quant_max).to(torch.int8)
+    return out_i8
+
+
+def _qdq_quantized_add(
+    x_i8,
+    x_scale,
+    x_zero_point,
+    y_i8,
+    y_scale,
+    y_zero_point,
+    out_scale,
+    out_zero_point,
+    quant_min,
+    quant_max,
+):
+    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
+        x_i8, x_scale, x_zero_point, quant_min, quant_max, torch.int8
+    )
+    y_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
+        y_i8, y_scale, y_zero_point, quant_min, quant_max, torch.int8
+    )
+    out_fp32 = x_fp32 + y_fp32
+    out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
+        out_fp32, out_scale, out_zero_point, quant_min, quant_max, torch.int8
+    )
+    return out_i8
+
+
+def _reference_quantized_add(
+    x_i8,
+    x_scale,
+    x_zero_point,
+    y_i8,
+    y_scale,
+    y_zero_point,
+    out_scale,
+    out_zero_point,
+    quant_min,
+    quant_max,
+):
+    """
+        # How to Derive the formula for out_i8 based on x_i8 and y_i8
+        # (since quantized add takes x_i8, y_i8 and their quantization parameters, and produce an out_i8)
+
+        # out_i8 is quantized output, we can write down the formula for it first:
+    out_i8 = out_f32 / out_scale + out_zero_point           (1)
+
+        # then out_fp32 is computed from x_f32 + y_f32, and the x_fp32 and y_fp32 are the dequantized x_i8 and y_i8
+        out_f32 = x_f32 + y_f32           (2)
+        x_fp32 = (x_i8 - x_zero_point) * x_scale         (3)
+        y_fp32 = (y_i8 - y_zero_point) * y_scale         (4)
+
+        # applying the above fomula to the out_i8 equation we can get the following:
+        out_i8 = out_fp32 / out_scale + out_zero_point             # (1)
+           = (x_f32 + y_f32) / out_scale + out_zero_point      # applying (2) to substitute out_fp32 with x_fp32 + y_fp32
+           = ((x_i8 - x_zero_point) * x_scale + (y_i8 - y_zero_point) * y_scale) / out_scale + out_zero_point  # apply (3) and (4)
+    """
+    x_i32 = x_i8.to(torch.int32)
+    y_i32 = y_i8.to(torch.int32)
+    # TODO: use out_dtype op
+    x_i32 = torch.round((x_scale / out_scale) * (x_i32 - x_zero_point)).to(torch.int32)
+    y_i32 = torch.round((y_scale / out_scale) * (y_i32 - y_zero_point)).to(torch.int32)
+    out_i32 = x_i32 + y_i32 + out_zero_point
+    quant_min = -128
+    quant_max = 127
+    out_i8 = torch.ops.aten.clamp(out_i32, quant_min, quant_max).to(torch.int8)
+    return out_i8
+
+
+def _qdq_quantized_max_pool2d(
+    x_i8,
+    x_scale,
+    x_zero_point,
+    x_quant_min,
+    x_quant_max,
+    out_scale,
+    out_zero_point,
+    out_quant_min,
+    out_quant_max,
+):
+    kernel_size = 1
+    stride = 1
+    padding = 0
+    dilation = 1
+    ceil_mode = False
+    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
+        x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8
+    )
+    out_fp32, _ = torch.ops.aten.max_pool2d_with_indices.default(
+        x_fp32, kernel_size, stride, padding, dilation, ceil_mode
+    )
+    out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
+        out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8
+    )
+    return out_i8
+
+
+def _reference_quantized_max_pool2d(
+    x_i8,
+    x_scale,
+    x_zero_point,
+    x_quant_min,
+    x_quant_max,
+    out_scale,
+    out_zero_point,
+    out_quant_min,
+    out_quant_max,
+):
+    kernel_size = 1
+    stride = 1
+    padding = 0
+    dilation = 1
+    ceil_mode = False
+    # to preserve x_quant_min, x_quant_max in the graph for pattern matching
+    x_i8 = torch.clamp(x_i8, x_quant_min, x_quant_max)
+    x_i32 = x_i8.to(torch.int32)
+    out_i32, _ = torch.ops.aten.max_pool2d_with_indices.default(
+        x_i32 - x_zero_point, kernel_size, stride, padding, dilation, ceil_mode
+    )
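+    # Requantize from the input scale to the output scale:
+    # out ~= (x_i8 - x_zero_point) * (x_scale / out_scale) + out_zero_point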
+    out_fp32 = out_i32 * (x_scale / out_scale) + out_zero_point
+    out_fp32 = torch.clamp(out_fp32, out_quant_min, out_quant_max)
+    out_i8 = out_fp32.to(torch.int8)
+    return out_i8
+
+
+def _quantize_per_tensor_int8(x_fp32, scale, zero_point, quant_min, quant_max):
+    x = torch.ops.quantized_decomposed.quantize_per_tensor(
+        x_fp32, scale, zero_point, quant_min, quant_max, torch.int8
+    )
+    return x
+
+
+def _reference_quantize_per_tensor_int8(
+    x_fp32, scale, zero_point, quant_min, quant_max
+):
+    # TODO: use out_dtype(mul, ...) here when the op is ready
+    x = x_fp32 / scale  # fp32
+    # Rounding modes might differ here;
+    # PyTorch rounds to even, which is also common for most backends
+    x = torch.round(x)  # fp32
+    x = x.to(dtype=torch.int32)  # int32
+    x = x + zero_point  # int32
+    # clamp works for fp32, int32 and int8 dtypes
+    x = torch.clamp(x, quant_min, quant_max)  # int32
+    x = x.to(dtype=torch.int8)
+    return x
+
+
+def _dequantize_per_tensor_int8(x_i8, scale, zero_point, quant_min, quant_max):
+    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
+        x_i8, scale, zero_point, quant_min, quant_max, torch.int8
+    )
+    return x_fp32
+
+
+def _reference_dequantize_per_tensor_int8(
+    x_i8, scale, zero_point, quant_min, quant_max
+):
+    # Without using quant_min/quant_max in clamp, the traced graph will not have
+    # quant_min/quant_max args, and pattern matching will fail.
+    # Therefore, we call torch.ops.aten.clamp here
+    x_i8 = torch.ops.aten.clamp(x_i8, quant_min, quant_max)
+    # TODO: use out_dtype op
+    # note: x_i8.to(torch.int32) does not work here
+    # TODO: debug the implementation later when torchdynamo time out issue is resolved
+    return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32)
+
+
+def _quantize_per_channel_int8(
+    x_fp32, scales, zero_points, ch_axis, quant_min, quant_max
+):
+    out_i8 = torch.ops.quantized_decomposed.quantize_per_channel(
+        x_fp32, scales, zero_points, ch_axis, quant_min, quant_max, torch.int8
+    )
+    return out_i8
+
+
+def _reference_quantize_per_channel_int8(
+    x_fp32, scales, zero_points, ch_axis, quant_min, quant_max
+):
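+    # Move the channel axis to the last dimension so the 1-D per-channel
+    # scales/zero_points broadcast correctly, then transpose back afterwards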
+    x_fp32 = torch.transpose(x_fp32, ch_axis, -1)
+    out_i32 = torch.ops.aten.clamp(
+        torch.round(x_fp32 / scales).to(torch.int32) + zero_points, quant_min, quant_max
+    )
+    out_i32 = torch.transpose(out_i32, ch_axis, -1)
+    return out_i32.to(torch.int8)
+
+
+def _dequantize_per_channel_int8(
+    x_i8, scales, zero_points, ch_axis, quant_min, quant_max
+):
+    # the following args will later be replaced with placeholders
+    out_fp32 = torch.ops.quantized_decomposed.dequantize_per_channel(
+        x_i8, scales, zero_points, ch_axis, quant_min, quant_max, torch.int8
+    )
+    return out_fp32
+
+
+def _reference_dequantize_per_channel_int8(
+    x_i8, scales, zero_points, ch_axis, quant_min, quant_max
+):
+    # the following args will later be replaced with placeholders;
+    # in order to preserve the quant_min/quant_max args for pattern matching (e.g. matching for int4 quantized ops),
+    # we call torch.ops.aten.clamp here
+    x_i8 = torch.ops.aten.clamp(x_i8, quant_min, quant_max)
+    x_i8 = torch.transpose(x_i8, ch_axis, -1)
+    x_i32 = x_i8.to(torch.int32)
+    out_fp32 = (x_i32 - zero_points).to(torch.float) * scales
+    out_fp32 = torch.transpose(out_fp32, ch_axis, -1)
+    return out_fp32
+
+
+def _replace_ph_qdq_per_channel_replacement(gm: torch.fx.GraphModule):
+    return _replace_literals_with_existing_placeholders(
+        gm, exclude_literals=[-1], literal_to_ph_idx={1: 3, -128: 4, 127: 5}
+    )
+
+
+@dataclass
+class _RewriteInfo:
+    """Data needed for rewrite, this includes example inputs, pattern and replacement functions
+    and post transformation functions for the exported pattern and replacement GraphModule
+    """
+
+    # example inputs used for exporting the pattern into GraphModule
+    example_inputs: tuple[Any, ...]
+    pattern: Callable
+    replacement: Callable
+    # post transformation on the exported pattern and replacement GraphModule
+    pattern_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None
+    replacement_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None
+
+
+def reference_representation_rewrite(model: GraphModule) -> GraphModule:
+    _QUANTIZED_LINEAR_EXAMPLE_INPUTS = (
+        torch.randint(-128, 127, (2, 5), dtype=torch.int8),
+        torch.randn(1, dtype=torch.float),
+        torch.zeros(1, dtype=torch.int),
+        torch.tensor([-128], dtype=torch.int),
+        torch.tensor([127], dtype=torch.int),
+        torch.randint(-128, 127, (5, 5), dtype=torch.int8),
+        torch.randn(1, dtype=torch.float),
+        torch.zeros(1, dtype=torch.int),
+        torch.tensor([-127], dtype=torch.int),
+        torch.tensor([127], dtype=torch.int),
+        torch.randn(1, dtype=torch.float),
+        torch.randn(1, dtype=torch.float),
+        torch.zeros(1, dtype=torch.int),
+        torch.tensor([-128], dtype=torch.int),
+        torch.tensor([127], dtype=torch.int),
+    )
+
+    _DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS = (
+        torch.randn((2, 5), dtype=torch.float),
+        -128,
+        127,
+        torch.finfo(torch.float32).eps,
+        torch.randint(-128, 127, (5, 5), dtype=torch.int8),
+        torch.randn(1, dtype=torch.float),
+        torch.zeros(1, dtype=torch.int),
+        torch.tensor([-127], dtype=torch.int),
+        torch.tensor([127], dtype=torch.int),
+        torch.randn(1, dtype=torch.float),
+    )
+
+    _QUANTIZED_CONV2D_EXAMPLE_INPUTS = (
+        torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
+        torch.randn(1, dtype=torch.float),
+        torch.zeros(1, dtype=torch.int),
+        torch.tensor([-128], dtype=torch.int),
+        torch.tensor([127], dtype=torch.int),
+        torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
+        torch.randn(1, dtype=torch.float),
+        torch.zeros(1, dtype=torch.int),
+        torch.tensor([-127], dtype=torch.int),
+        torch.tensor([127], dtype=torch.int),
+        torch.randn(1, dtype=torch.float),
+        torch.randn(1, dtype=torch.float),
+        torch.zeros(1, dtype=torch.int),
+        torch.tensor([-128], dtype=torch.int),
+        torch.tensor([127], dtype=torch.int),
+    )
+
+    _QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS = (
+        torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
+        torch.randn(1, dtype=torch.float),
+        torch.zeros(1, dtype=torch.int),
+        torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
+        torch.randn(1, dtype=torch.float),
+        torch.zeros(1, dtype=torch.int),
+        torch.randn(1, dtype=torch.float),
+        torch.zeros(1, dtype=torch.int),
+        torch.tensor([-128], dtype=torch.int),
+        torch.tensor([127], dtype=torch.int),
+    )
+
+    _QUANTIZED_MAX_POOL2D_EXAMPLE_INPUTS = (
+        torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
+        torch.randn(1, dtype=torch.float),
+        torch.zeros(1, dtype=torch.int),
+        torch.tensor([-128], dtype=torch.int),
+        torch.tensor([127], dtype=torch.int),
+        torch.randn(1, dtype=torch.float),
+        torch.zeros(1, dtype=torch.int),
+        torch.tensor([-128], dtype=torch.int),
+        torch.tensor([127], dtype=torch.int),
+    )
+
+    _QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = (
+        torch.randn(1, 3, 3, 3, dtype=torch.float),
+        torch.randn(1, dtype=torch.float),
+        torch.zeros(1, dtype=torch.int),
+        torch.tensor([-128], dtype=torch.int),
+        torch.tensor([127], dtype=torch.int),
+    )
+
+    _DEQUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = (
+        torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
+        torch.randn(1, dtype=torch.float),
+        torch.zeros(1, dtype=torch.int),
+        torch.tensor([-128], dtype=torch.int),
+        torch.tensor([127], dtype=torch.int),
+    )
+
+    _QUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = (
+        torch.randn(1, 3, 3, 3, dtype=torch.float),
+        torch.randn(3, dtype=torch.float),
+        torch.zeros(3, dtype=torch.int),
+        1,
+        -128,
+        127,
+    )
+
+    _DEQUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = (
+        torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
+        torch.randn(3, dtype=torch.float),
+        torch.zeros(3, dtype=torch.int),
+        1,
+        -128,
+        127,
+    )
+
+    _REWRITE_INFO_LIST = [
+        _RewriteInfo(
+            _DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS,
+            _WrapperModule(_qdq_dynamic_quantized_linear),
+            _WrapperModule(_reference_dynamic_quantized_linear),
+            partial(
+                _replace_literals_with_existing_placeholders,
+                literal_to_ph_idx={-128: 1, 127: 2, torch.finfo(torch.float32).eps: 3},
+            ),
+            partial(
+                _replace_literals_with_existing_placeholders,
+                literal_to_ph_idx={-128: 1, 127: 2, torch.finfo(torch.float32).eps: 3},
+            ),
+        ),
+        _RewriteInfo(
+            _QUANTIZED_LINEAR_EXAMPLE_INPUTS,
+            _WrapperModule(_qdq_quantized_linear),
+            _WrapperModule(_reference_quantized_linear),
+            _replace_literals_with_new_placeholders,
+            _replace_literals_with_new_placeholders,
+        ),
+        _RewriteInfo(
+            _QUANTIZED_CONV2D_EXAMPLE_INPUTS,
+            _WrapperModule(_qdq_quantized_conv2d),
+            _WrapperModule(_reference_quantized_conv2d),
+            partial(_replace_literals_with_new_placeholders, exclude_literals=[-1]),
+            partial(_replace_literals_with_new_placeholders, exclude_literals=[-1]),
+        ),
+        _RewriteInfo(
+            _QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS,
+            _WrapperModule(_qdq_quantized_add_relu),
+            _WrapperModule(_reference_quantized_add_relu),
+        ),
+        _RewriteInfo(
+            _QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS,
+            _WrapperModule(_qdq_quantized_add),
+            _WrapperModule(_reference_quantized_add),
+        ),
+        _RewriteInfo(
+            _QUANTIZED_MAX_POOL2D_EXAMPLE_INPUTS,
+            _WrapperModule(_qdq_quantized_max_pool2d),
+            _WrapperModule(_reference_quantized_max_pool2d),
+            _replace_literals_with_new_placeholders,
+            _replace_literals_with_new_placeholders,
+        ),
+        _RewriteInfo(
+            _QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS,
+            _WrapperModule(_quantize_per_tensor_int8),
+            _WrapperModule(_reference_quantize_per_tensor_int8),
+        ),
+        _RewriteInfo(
+            _DEQUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS,
+            _WrapperModule(_dequantize_per_tensor_int8),
+            _WrapperModule(_reference_dequantize_per_tensor_int8),
+        ),
+        _RewriteInfo(
+            _QUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS,
+            _WrapperModule(_quantize_per_channel_int8),
+            _WrapperModule(_reference_quantize_per_channel_int8),
+            _replace_ph_qdq_per_channel_replacement,
+            _replace_ph_qdq_per_channel_replacement,
+        ),
+        _RewriteInfo(
+            _DEQUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS,
+            _WrapperModule(_dequantize_per_channel_int8),
+            _WrapperModule(_reference_dequantize_per_channel_int8),
+            _replace_ph_qdq_per_channel_replacement,
+            _replace_ph_qdq_per_channel_replacement,
+        ),
+    ]
+
+    remove_tensor_overload_for_qdq_ops(model)
+
+    with _disable_aten_to_metadata_assertions():
+        for rewrite_info in _REWRITE_INFO_LIST:
+            example_inputs = rewrite_info.example_inputs
+            pattern = rewrite_info.pattern
+            replacement = rewrite_info.replacement
+            pattern_post_trans = rewrite_info.pattern_post_trans
+            replacement_post_trans = rewrite_info.replacement_post_trans
+            pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs)  # type: ignore[arg-type, assignment]
+            remove_tensor_overload_for_qdq_ops(pattern)  # type: ignore[arg-type]
+            replacement = _get_aten_graph_module_for_pattern(  # type: ignore[assignment]
+                replacement,
+                example_inputs,  # type: ignore[arg-type]
+            )
+            remove_tensor_overload_for_qdq_ops(replacement)  # type: ignore[arg-type]
+            if pattern_post_trans:
+                pattern = pattern_post_trans(pattern)
+            if replacement_post_trans:
+                replacement = replacement_post_trans(replacement)
+            pattern.recompile()  # type: ignore[attr-defined]
+            replacement.recompile()  # type: ignore[attr-defined]
+            replace_pattern(model, pattern, replacement)
+
+    return model
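+
+
+# A minimal usage sketch (assuming the standard pt2e flow; `prepare_pt2e` and
+# `convert_pt2e` live in torch.ao.quantization.quantize_pt2e):
+#
+#   prepared = prepare_pt2e(exported_model, quantizer)
+#   # ... run calibration ...
+#   converted = convert_pt2e(prepared)
+#   rewritten = reference_representation_rewrite(converted)
+#
+# `convert_pt2e` can also apply this rewrite itself when called with
+# `use_reference_representation=True`.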
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/utils.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f919c3d9dff05e12e01d40e67a94a8917afb5e40
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/pt2e/utils.py
@@ -0,0 +1,616 @@
+# mypy: allow-untyped-defs
+import operator
+import types
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.ao.quantization.pt2e._affine_quantization  # noqa: F401
+import torch.nn.functional as F
+
+# Makes sure that quantized_decomposed ops are registered
+from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
+from torch.ao.quantization.quantizer import QuantizationAnnotation
+from torch.export.unflatten import _assign_attr, _AttrKind
+from torch.fx import GraphModule, Node
+from torch.nn.utils.fusion import fuse_conv_bn_weights
+from torch.utils._pytree import LeafSpec
+
+
+__all__ = [
+    "fold_bn_weights_into_conv_node",
+    "remove_tensor_overload_for_qdq_ops",
+]
+
+_QUANTIZE_OPS = [
+    torch.ops.quantized_decomposed.quantize_per_tensor.default,
+    torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
+    torch.ops.quantized_decomposed.quantize_per_channel.default,
+]
+
+
+_DEQUANTIZE_OPS = [
+    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
+    torch.ops.quantized_decomposed.dequantize_per_channel.default,
+]
+
+
+def _is_connected(source: torch.fx.Node, dest: torch.fx.Node) -> bool:
+    """
+    Assuming `dest` is one of the ops inserted by the quant workflow, this function
+    checks whether `source` and `dest` are connected. The assumption is that only
+    ops inserted by the quant workflow exist between `source` and `dest`.
+    """
+    quant_workflow_ops = _QUANTIZE_OPS + _DEQUANTIZE_OPS
+    quant_workflow_ops.append(torch.ops.quantized_decomposed.choose_qparams.tensor)
+    while dest.target in quant_workflow_ops:
+        if not isinstance(dest.args[0], torch.fx.Node):
+            raise ValueError(
+                f"expected arg[0] of quant workflow ops to be a node but found {dest.args[0]}"
+            )
+        dest = dest.args[0]
+    return dest == source
+
+
+def _find_q_dq_node_for_user(
+    producer: torch.fx.Node, user: torch.fx.Node
+) -> tuple[Any, Any]:
+    """
+    Find the q, dq pair corresponding to [producer -> q -> dq -> user].
+    This works by finding the dq arg of `user` and ensuring it is connected to
+    `producer`
+    """
+    dq_node = None
+    for n in user.args:
+        if (
+            isinstance(n, torch.fx.Node)
+            and n.op == "call_function"
+            and n.target in _DEQUANTIZE_OPS
+        ):
+            if _is_connected(producer, n):
+                dq_node = n
+                break
+    if dq_node is None:
+        for n in user.kwargs:
+            if (
+                isinstance(n, torch.fx.Node)
+                and n.op == "call_function"
+                and n.target in _DEQUANTIZE_OPS
+            ):
+                if _is_connected(producer, n):
+                    dq_node = n
+                    break
+    if dq_node is None:
+        return (None, None)
+
+    q_node = None
+    if (
+        isinstance(arg := dq_node.args[0], torch.fx.Node)
+        and arg.op == "call_function"
+        and arg.target in _QUANTIZE_OPS
+    ):
+        q_node = arg
+    return (q_node, dq_node)
+
+
+def _is_sym_size_node(node: Node):
+    return node.op == "call_function" and node.target in (
+        torch.ops.aten.sym_size.default,
+        torch.ops.aten.sym_numel.default,
+        torch.ops.aten.sym_numel,
+        torch.ops.aten.sym_size,
+    )
+
+
+def _filter_sym_size_users(node: torch.fx.Node) -> list[torch.fx.Node]:
+    return [user for user in node.users if not _is_sym_size_node(user)]
+
+
+def _is_valid_annotation(annotation: QuantizationAnnotation) -> bool:
+    if annotation is None:
+        return False
+    input_qspec_map = annotation.input_qspec_map
+    output_qspec = annotation.output_qspec
+    if len(input_qspec_map) == 0 and output_qspec is None:
+        return False
+    return True
+
+
+def _get_tensor_constant_from_node(node, m):
+    if node is None:
+        return None
+    assert node.op == "get_attr"
+    target_atoms = node.target.split(".")
+    attr_itr = m
+    for i, atom in enumerate(target_atoms):
+        if not hasattr(attr_itr, atom):
+            raise RuntimeError(
+                f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}"
+            )
+        attr_itr = getattr(attr_itr, atom)
+    return attr_itr
+
+
+def _get_all_arguments(orig_args, orig_kwargs, args_schema):
+    all_args = []
+    for i, schema in enumerate(args_schema):
+        if schema.name in orig_kwargs:
+            all_args.append(orig_kwargs[schema.name])
+        elif not schema.kwarg_only and i < len(orig_args):
+            all_args.append(orig_args[i])
+        else:
+            all_args.append(schema.default_value)
+    return all_args
+
+
+def _is_supported_batch_norm_for_training(node: Node):
+    """
+    Return True if the given node refers to an aten batch norm op QAT supports.
+    """
+    supported_ops = [
+        torch.ops.aten.batch_norm.default,
+        torch.ops.aten._native_batch_norm_legit.default,
+        # Note: we won't need this op anymore after batch norm consolidation
+        # For now, we need to continue to support it because it gives better
+        # training numerics than `_native_batch_norm_legit`
+        torch.ops.aten.cudnn_batch_norm.default,
+        torch.ops.aten.miopen_batch_norm.default,
+    ]
+    return node.target in supported_ops
+
+
+# TODO: move this to torch/ao/quantization/utils.py
+def _is_conv_node(n: Node):
+    """
+    Return whether the node refers to an aten conv op.
+    """
+    return n.op == "call_function" and n.target in [
+        torch.ops.aten.conv1d.default,
+        torch.ops.aten.conv1d.padding,
+        torch.ops.aten.conv2d.default,
+        torch.ops.aten.conv2d.padding,
+        torch.ops.aten.conv3d.default,
+        torch.ops.aten.conv3d.padding,
+    ]
+
+
+def _is_conv_transpose_node(n: Node):
+    """
+    Return whether the node refers to an aten conv_transpose op.
+    """
+    return n.op == "call_function" and n.target in [
+        torch.ops.aten.conv_transpose1d,
+        torch.ops.aten.conv_transpose1d.default,
+        torch.ops.aten.conv_transpose2d,
+        torch.ops.aten.conv_transpose2d.input,
+    ]
+
+
+def _is_conv_or_conv_transpose_node(n: Node):
+    """
+    Return whether the node refers to an aten conv or conv transpose op.
+    """
+    return _is_conv_node(n) or _is_conv_transpose_node(n)
+
+
+def _is_conv_transpose_fn(conv_fn: Callable):
+    return conv_fn in [F.conv_transpose1d, F.conv_transpose2d]
+
+
+def _is_bn_node(n: Node):
+    return (
+        _is_supported_batch_norm_for_training(n)
+        or n.target == torch.ops.aten._native_batch_norm_legit_no_training.default
+    )
+
+
+def fold_bn_weights_into_conv_node(
+    conv_node: Node,
+    conv_weight_node: Node,
+    conv_bias_node: Optional[Node],
+    bn_node: Node,
+    m: GraphModule,
+) -> None:
+    # conv args: input, weight, bias, stride, padding, dilation, ...
+    conv_w = _get_tensor_constant_from_node(conv_weight_node, m)
+    conv_b = _get_tensor_constant_from_node(conv_bias_node, m)
+    transpose = _is_conv_transpose_node(conv_node)
+
+    # eval bn args: input, weight, bias, running mean, running var, momentum, eps
+    # train bn args: input, weight, bias, running mean, running var, training, momentum, eps
+    bn_args_schema = bn_node.target._schema.arguments  # type: ignore[union-attr]
+    bn_args = _get_all_arguments(bn_node.args, bn_node.kwargs, bn_args_schema)
+    bn_w = _get_tensor_constant_from_node(bn_args[1], m)
+    bn_b = _get_tensor_constant_from_node(bn_args[2], m)
+    bn_rm = _get_tensor_constant_from_node(bn_args[3], m)
+    bn_rv = _get_tensor_constant_from_node(bn_args[4], m)
+    if bn_node.target == torch.ops.aten._native_batch_norm_legit_no_training.default:
+        eps_arg_index = 6
+    elif _is_supported_batch_norm_for_training(bn_node):
+        eps_arg_index = 7
+    else:
+        raise ValueError("BN node target is unexpected ", bn_node.target)
+    bn_eps = bn_args[eps_arg_index]
+
+    fused_weight, fused_bias = fuse_conv_bn_weights(
+        conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=transpose
+    )
+
+    # update the weight and bias for conv
+    conv_args = list(conv_node.args)
+    # filling in the default bias argument
+    if len(conv_args) == 2:
+        conv_args.append(None)
+
+    # fused_weight and fused_bias are nn.Parameters; assign them back onto the module
+    weight_attr_name = conv_weight_node.target
+    assert isinstance(weight_attr_name, str)
+    _assign_attr(fused_weight, m, weight_attr_name, _AttrKind.PARAMETER)
+    if conv_bias_node is not None:
+        bias_attr_name = conv_bias_node.target
+        _assign_attr(fused_bias, m, str(bias_attr_name), _AttrKind.PARAMETER)
+    else:
+        bias_attr_name = weight_attr_name + "_bias"
+        _assign_attr(fused_bias, m, bias_attr_name, _AttrKind.PARAMETER)
+        with m.graph.inserting_before(conv_node):
+            get_bias_node = m.graph.get_attr(bias_attr_name)
+        # NOTE: here we assume the bias of conv is not quantized!
+        conv_args[2] = get_bias_node
+    conv_node.args = tuple(conv_args)
+
+    # native_batch_norm has 3 outputs, we expect getitem calls on the output
+    # and we want to replace the uses of getitem 0 with the output of conv
+    #
+    if bn_node.target == torch.ops.aten.batch_norm.default:
+        # With the new training ir, instead of batch_norm + getitem,
+        # we only have the batch_norm node.
+        #
+        # Before:
+        # conv -> bn -> users
+        # After:
+        # conv -> users
+        #       bn has no users now
+        bn_node.replace_all_uses_with(conv_node)
+    else:
+        # Before:
+        # conv -> bn - (first output) -> users1
+        #          \ - (second output) -> users2
+        #          \ - (third output) -> users3
+        # After:
+        # conv -> (first output) -> users1
+        #       bn -
+        #          \ - (second output) -> users2
+        #          \ - (third output) -> users3
+        # if users2 and users3 are empty then bn will be removed through dead code elimination
+        for user in bn_node.users:
+            if (
+                user.op != "call_function"
+                or user.target != operator.getitem
+                or user.args[1] != 0
+            ):
+                continue
+            user.replace_all_uses_with(conv_node)
+
+    # If the BN node does not have users, erase it from the graph
+    # Note: we need to do this manually because the model can still be in train
+    # mode at this point, in which case DCE won't erase the BN node automatically
+    # since the node refers to a mutating op. Here we still need to call DCE first
+    # to get rid of the unused getitem nodes that consume the BN node.
+    m.graph.eliminate_dead_code()
+    if len(bn_node.users) == 0:
+        m.graph.erase_node(bn_node)
+
+
+# fuse conv bn weights, inplace modification of the graph_module and graph
+def _fuse_conv_bn_(m: GraphModule) -> None:
+    has_bn = any(_is_bn_node(n) for n in m.graph.nodes)
+    if not has_bn:
+        return
+    for n in m.graph.nodes:
+        if n.op != "call_function" or n.target not in (
+            torch.ops.aten._native_batch_norm_legit_no_training.default,
+            torch.ops.aten.batch_norm.default,
+        ):
+            continue
+        bn_node = n
+        n = bn_node.args[0]
+        if not _is_conv_or_conv_transpose_node(n):
+            continue
+        conv_node = n
+        conv_weight_node = conv_node.args[1]
+        conv_bias_node = conv_node.args[2] if len(conv_node.args) > 2 else None
+        fold_bn_weights_into_conv_node(
+            conv_node, conv_weight_node, conv_bias_node, bn_node, m
+        )
+
+    m.graph.eliminate_dead_code()
+    m.recompile()
+
+
+def _get_node_name_to_scope(model: GraphModule) -> dict[str, tuple[str, type]]:
+    # TODO: move this information to fx node itself
+    node_name_to_scope: dict[str, tuple[str, type]] = {}
+    for n in model.graph.nodes:
+        nn_module_stack = n.meta.get("nn_module_stack", None)
+        current_scope = ("", type(None))
+        if nn_module_stack:
+            bt = list(nn_module_stack.values())[-1]
+            current_scope = (bt[0].split(".")[-1], bt[1])
+        node_name_to_scope[n.name] = current_scope
+    return node_name_to_scope
+
+
+def _get_aten_graph_module_for_pattern(
+    pattern: Callable,
+    example_inputs: tuple[Any, ...],
+    is_cuda: bool = False,
+    **kwargs,
+) -> GraphModule:
+    """
+    Convert the pattern to an FX graph with decomposed aten ops.
+    """
+    if is_cuda:
+        example_inputs = tuple(
+            [x.cuda() if isinstance(x, torch.Tensor) else x for x in example_inputs]
+        )
+
+    aten_pattern = torch.export.export_for_training(
+        pattern,  # type: ignore[arg-type]
+        example_inputs,
+        kwargs,
+        strict=True,
+    ).module()
+
+    aten_pattern.graph.eliminate_dead_code()  # type: ignore[operator, union-attr]
+    aten_pattern.recompile()  # type: ignore[operator]
+
+    # ep.module() adds copy_ nodes for the mutated inputs.
+    # For patterns, it doesn't matter
+    for node in aten_pattern.graph.nodes:  # type: ignore[union-attr]
+        if (
+            node.op == "call_function"
+            and node.target == torch.ops.aten.copy_.default
+            and len(node.users) == 0
+        ):
+            aten_pattern.graph.erase_node(node)  # type: ignore[operator, union-attr]
+
+    aten_pattern.graph.eliminate_dead_code()  # type: ignore[operator, union-attr]
+    aten_pattern.recompile()  # type: ignore[operator]
+
+    return aten_pattern  # type: ignore[return-value]
+
+
+def remove_tensor_overload_for_qdq_ops(match_pattern: GraphModule) -> None:
+    """Remove .tensor overload for quantize/dequantize ops so that we can
+    use the match_pattern that we get from torchdynamo export to match the output of convert_pt2e
+    """
+    _MAP = {
+        torch.ops.quantized_decomposed.quantize_per_tensor.default: torch.ops.quantized_decomposed.quantize_per_tensor,
+        torch.ops.quantized_decomposed.dequantize_per_tensor.default: torch.ops.quantized_decomposed.dequantize_per_tensor,
+        torch.ops.quantized_decomposed.quantize_per_tensor.tensor: torch.ops.quantized_decomposed.quantize_per_tensor,
+        torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: torch.ops.quantized_decomposed.dequantize_per_tensor,
+        torch.ops.quantized_decomposed.quantize_per_tensor.tensor2: torch.ops.quantized_decomposed.quantize_per_tensor,
+        torch.ops.quantized_decomposed.dequantize_per_tensor.tensor2: torch.ops.quantized_decomposed.dequantize_per_tensor,
+        torch.ops.quantized_decomposed.quantize_per_channel.default: torch.ops.quantized_decomposed.quantize_per_channel,
+        torch.ops.quantized_decomposed.dequantize_per_channel.default: torch.ops.quantized_decomposed.dequantize_per_channel,
+        torch.ops.aten.clamp.Tensor: torch.ops.aten.clamp,
+    }
+    for n in match_pattern.graph.nodes:
+        if n.op != "call_function":
+            continue
+        if n.target in _MAP:
+            n.target = _MAP[n.target]
+
+
+def _is_literal(arg):
+    if isinstance(arg, (int, float)):
+        return True
+    if isinstance(arg, (tuple, list)):
+        return all(map(_is_literal, arg))
+    return False
+
+
+def _replace_literals_with_new_placeholders(
+    gm: torch.fx.GraphModule,
+    merge_dup: bool = False,
+    exclude_literals: Optional[list[Any]] = None,
+):
+    """Replace the literals in the graph with placeholder nodes that's created on the fly while we
+    traverse the graph, so that the literal arguments in the graph can be matched and replaced
+
+    To use this, the pattern and replacement graph should have the exact same number of literal args
+    and they should be used in the exact same order in the pattern and replacement graph.
+
+    If the literal arguments are not used in the same order in pattern and replacement graph, please
+    use `_replace_literals_with_existing_placeholders` instead
+
+    Args:
+        `gm`: input GraphModule that we'll transform
+        `merge_dup`: boolean flag to indicate that if the same literal appears multiple times in
+         the graph, whether they should correspond to the same placeholder or not
+        `exclude_literals`: a list of literals that will not be replaced with placeholders
+
+    Example:
+
+    # 1. Original Graph
+    def pattern(self, x):
+        return x + 3
+
+    def replacement(self, x):
+        return x - 3
+
+    example_inputs = (torch.randn(1, 3, 3, 3),)
+    pattern_gm = _get_aten_graph_module_for_pattern(pattern, example_inputs)
+    replacement_gm = _get_aten_graph_module_for_pattern(replacement, example_inputs)
+
+    # 2. Before calling replace literals we'll see the following graph:
+    def pattern(self, x):
+        return x + 3
+
+    def replacement(self, x):
+        return x - 3
+
+    pattern_gm = _replace_literals_with_new_placeholders(pattern_gm)
+    replacement_gm = _replace_literals_with_new_placeholders(replacement_gm)
+
+    # 3. After replacing literals with new placeholder nodes
+
+    def pattern(self, x, new_ph):
+        return x + new_ph
+
+    def replacement(self, x, new_ph):
+        return x - new_ph
+
+    """
+    last_ph = None
+    cnt = 0
+    literal_to_ph: dict[Union[float, bool, int, torch.dtype], Node] = {}
+    if exclude_literals is None:
+        exclude_literals = []
+
+    in_spec = gm._in_spec
+    args_spec = in_spec.children_specs[0]
+    for node in gm.graph.nodes:
+        if node.op == "placeholder":
+            last_ph = node
+            cnt += 1
+            continue
+        with gm.graph.inserting_after(last_ph):
+            new_args = []
+            for arg in node.args:
+                if _is_literal(arg) and arg not in exclude_literals:
+                    if merge_dup and arg in literal_to_ph:
+                        new_args.append(literal_to_ph[arg])
+                    else:
+                        ph_node = gm.graph.placeholder("arg" + str(cnt))
+                        new_args.append(ph_node)
+                        args_spec.children_specs.append(LeafSpec())
+                        cnt += 1
+                        if merge_dup:
+                            literal_to_ph[arg] = ph_node
+                else:
+                    new_args.append(arg)
+            new_args = tuple(new_args)
+
+        node.args = new_args
+
+    # Update `num_nodes`, `num_leaves`, `num_children`.
+    args_spec.__post_init__()
+    in_spec.__post_init__()
+    return gm
+
+
+def _replace_literals_with_existing_placeholders(
+    gm: torch.fx.GraphModule,
+    exclude_literals: Optional[list[Any]] = None,
+    literal_to_ph_idx: Optional[dict[Union[float, int, bool, torch.dtype], int]] = None,
+):
+    """Replace the literals in the graph with **existing** placeholder nodes, so that the literal arguments
+    in the graph can be matched and replaced
+
+    To use this, all literal args in the graph should be unique and each of them should correspond
+    to exactly one placeholder node
+
+    # 1. Original Graph
+    def pattern(self, x_i8, scale, zero_point, quant_min, quant_max):
+        return torch.dequantize_per_tensor(x_i8, scale, zero_point, quant_min, quant_max)
+
+    def replacement(x_i8, scale, zero_point, quant_min, quant_max):
+        x_i8 = torch.clamp(x_i8, quant_min, quant_max)
+        return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32)
+
+    example_inputs = (
+        torch.randn(1, 3, 3, 3),
+        1.0,
+        0,
+        -128,
+        127,
+    )
+    pattern_gm = _get_aten_graph_module_for_pattern(pattern, example_inputs)
+    replacement_gm = _get_aten_graph_module_for_pattern(replacement, example_inputs)
+
+    # 2. Before calling replace literals we'll see the following graph:
+    def pattern(self, x_i8, scale, zero_point, quant_min, quant_max):
+        # scale/zero_point/quant_min/quant_max are burnt in since they are scalar values
+        return torch.dequantize_per_tensor(x_i8, 1.0, 0, -128, 127)
+
+    def replacement(x_i8, scale, zero_point, quant_min, quant_max):
+        # scale/zero_point/quant_min/quant_max are burnt in since they are scalar values
+        x_i8 = torch.clamp(x_i8, -128, 127)
+        return ((x_i8.to(torch.float32) - 0) * 1.0).to(dtype=torch.float32)
+
+    # Note that literal args appear in different order in pattern and replacement graph, so
+    # we can't use _replace_literals_with_new_placeholders
+
+    literal_to_ph_idx = {1.0: 1, 0: 2, -128: 3, 127: 4}
+    pattern_gm = _replace_literals_with_existing_placeholders(pattern_gm, literal_to_ph_idx=literal_to_ph_idx)
+    replacement_gm = _replace_literals_with_existing_placeholders(replacement_gm, literal_to_ph_idx=literal_to_ph_idx)
+
+    # 3. After replacing literals with existing placeholder nodes
+
+    def pattern(self, x_i8, scale, zero_point, quant_min, quant_max):
+        # scale/zero_point/quant_min/quant_max are burnt in since they are scalar values
+        return torch.dequantize_per_tensor(x_i8, scale, zero_point, quant_min, quant_max)
+
+    def replacement(x_i8, scale, zero_point, quant_min, quant_max):
+        # scale/zero_point/quant_min/quant_max are burnt in since they are scalar values
+        x_i8 = torch.clamp(x_i8, quant_min, quant_max)
+        return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32)
+    """
+    if exclude_literals is None:
+        exclude_literals = []
+
+    if literal_to_ph_idx is None:
+        literal_to_ph_idx = {}
+
+    phs = [node for node in gm.graph.nodes if node.op == "placeholder"]
+
+    for node in gm.graph.nodes:
+        if node.op != "call_function":
+            continue
+        new_args = []
+        for arg in node.args:
+            if (
+                _is_literal(arg)
+                and arg not in exclude_literals
+                and arg in literal_to_ph_idx
+            ):
+                ph_idx = literal_to_ph_idx[arg]
+                ph_node = phs[ph_idx]
+                new_args.append(ph_node)
+            else:
+                new_args.append(arg)
+        new_args = tuple(new_args)
+        node.args = new_args
+    return gm
+
+
+# TODO: Handle this in export itself and don't wrap the model in another GraphModule
+# in prepare and convert
+def _disallow_eval_train(model: GraphModule):
+    """
+    Disallow calling `model.train()` or `model.eval()` on the given GraphModule.
+    This is useful for exported models, where these methods don't actually behave as expected.
+    """
+    error_message = """
+        Calling train() or eval() is not supported for exported models.
+        Please call `torch.ao.quantization.move_exported_model_to_train(model)` (or eval) instead.
+
+        If you cannot replace the calls to `model.train()` and `model.eval()`, you may override
+        the behavior for these methods by calling `torch.ao.quantization.allow_exported_model_train_eval(model)`,
+        which does the above automatically for you. Note that this has limited effect on switching
+        behavior between train and eval modes, and should be used only for special ops such as dropout
+        and batchnorm.
+        """
+
+    def _train(self, mode: bool = True):
+        raise NotImplementedError(error_message)
+
+    def _eval(self, mode: bool = True):
+        raise NotImplementedError(error_message)
+
+    model.train = types.MethodType(_train, model)  # type: ignore[method-assign]
+    model.eval = types.MethodType(_eval, model)  # type: ignore[method-assign]
+    return model
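+
+
+# Usage sketch (the exported GraphModule below is hypothetical): after prepare/convert,
+# wrap the model so that accidental calls to .train()/.eval() fail loudly instead of
+# silently misbehaving on the exported graph.
+#
+#   exported_gm = _disallow_eval_train(exported_gm)
+#   exported_gm.eval()  # raises NotImplementedError with the message above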
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__init__.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5cd5e8696d39781004960f47e6f44d3b1987ff4
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/__init__.py
@@ -0,0 +1,22 @@
+from .quantizer import (
+    DerivedQuantizationSpec,
+    EdgeOrNode,
+    FixedQParamsQuantizationSpec,
+    QuantizationAnnotation,
+    QuantizationSpec,
+    QuantizationSpecBase,
+    Quantizer,
+    SharedQuantizationSpec,
+)
+
+
+__all__ = [
+    "EdgeOrNode",
+    "Quantizer",
+    "QuantizationSpecBase",
+    "QuantizationSpec",
+    "FixedQParamsQuantizationSpec",
+    "SharedQuantizationSpec",
+    "DerivedQuantizationSpec",
+    "QuantizationAnnotation",
+]
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/composable_quantizer.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/composable_quantizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..15404cc560117713bf8c952f594c051b1c13e3a4
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/composable_quantizer.py
@@ -0,0 +1,83 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from .quantizer import QuantizationAnnotation, Quantizer
+
+
+if TYPE_CHECKING:
+    import torch
+    from torch.fx import Node
+
+__all__ = [
+    "ComposableQuantizer",
+]
+
+
+class ComposableQuantizer(Quantizer):
+    """
+    ComposableQuantizer allows users to combine more than one quantizer into a single quantizer.
+    This allows users to quantize a model with multiple quantizers. E.g., embedding quantization
+    may be supported by one quantizer while linear layers and other ops might be supported by another
+    quantizer.
+
+    ComposableQuantizer is initialized with a list of `Quantizer` instances.
+    The order of the composition matters since that is the order in which the quantizers will be
+    applied.
+    Example:
+    ```
+    embedding_quantizer = EmbeddingQuantizer()
+    linear_quantizer = MyLinearQuantizer()
+    xnnpack_quantizer = (
+        XNNPackQuantizer()
+    )  # to handle ops not quantized by previous two quantizers
+    composed_quantizer = ComposableQuantizer(
+        [embedding_quantizer, linear_quantizer, xnnpack_quantizer]
+    )
+    prepared_m = prepare_pt2e(model, composed_quantizer)
+    ```
+    """
+
+    def __init__(self, quantizers: list[Quantizer]):
+        super().__init__()
+        self.quantizers = quantizers
+        self._graph_annotations: dict[Node, QuantizationAnnotation] = {}
+
+    def _record_and_validate_annotations(
+        self, gm: torch.fx.GraphModule, quantizer: Quantizer
+    ) -> None:
+        for n in gm.graph.nodes:
+            if "quantization_annotation" in n.meta:
+                # check if the annotation has been changed by
+                # comparing QuantizationAnnotation object id
+                if n in self._graph_annotations and (
+                    id(self._graph_annotations[n])
+                    != id(n.meta["quantization_annotation"])
+                ):
+                    raise RuntimeError(
+                        f"Quantizer {quantizer.__class__.__name__} has changed annotations on node {n}"
+                    )
+                else:
+                    self._graph_annotations[n] = n.meta["quantization_annotation"]
+            else:
+                if n in self._graph_annotations:
+                    raise RuntimeError(
+                        f"Quantizer {quantizer.__class__.__name__} has removed annotations on node {n}"
+                    )
+
+    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+        """just handling global spec for now"""
+        for quantizer in self.quantizers:
+            quantizer.annotate(model)
+            self._record_and_validate_annotations(model, quantizer)
+        return model
+
+    def transform_for_annotation(
+        self, model: torch.fx.GraphModule
+    ) -> torch.fx.GraphModule:
+        for quantizer in self.quantizers:
+            model = quantizer.transform_for_annotation(model)
+        return model
+
+    def validate(self, model: torch.fx.GraphModule) -> None:
+        pass
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/embedding_quantizer.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/embedding_quantizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..88bc6f3c8c9ff38b532846bcf279ce7b222898f2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/embedding_quantizer.py
@@ -0,0 +1,97 @@
+# mypy: allow-untyped-defs
+from __future__ import annotations
+
+import copy
+
+import torch
+import torch.nn.functional as F
+from torch.ao.quantization.observer import PerChannelMinMaxObserver
+from torch.ao.quantization.quantizer.quantizer import (
+    QuantizationAnnotation,
+    QuantizationSpec,
+    Quantizer,
+)
+from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
+    OperatorConfig,
+    OperatorPatternType,
+    QuantizationConfig,
+)
+
+
+__all__ = [
+    "get_embedding_operators_config",
+    "EmbeddingQuantizer",
+]
+
+
+def get_embedding_operators_config() -> OperatorConfig:
+    weight_quantization_spec = QuantizationSpec(
+        dtype=torch.uint8,
+        qscheme=torch.per_channel_affine_float_qparams,
+        ch_axis=0,
+        observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(eps=2**-12),
+    )
+    quantization_config = QuantizationConfig(None, None, weight_quantization_spec, None)
+    ops: list[OperatorPatternType] = [[torch.nn.Embedding]]
+    ops.append([F.embedding])
+    supported_config_and_operators = OperatorConfig(
+        config=quantization_config, operators=ops
+    )
+    return copy.deepcopy(supported_config_and_operators)
+
+
+class EmbeddingQuantizer(Quantizer):
+    def __init__(self) -> None:
+        super().__init__()
+
+    @classmethod
+    def get_supported_quantization_configs(cls) -> list[QuantizationConfig]:
+        op_configs: set[QuantizationConfig] = {
+            spec for spec, _ in cls.get_supported_operators()
+        }
+        return list(op_configs)
+
+    @classmethod
+    def get_supported_operator_for_quantization_config(
+        cls, quantization_config: QuantizationConfig
+    ) -> list[OperatorPatternType]:
+        for config, ops in cls.get_supported_operators():
+            # note: this assumes each entry in cls.supported_spec_and_operators
+            # corresponds to one spec, e.g. we don't have
+            # [(spec1, op_list1), (spec1, op_list2), (spec2, op_list3)]
+            # where the first and second entry have the same spec but did not
+            # merge the op list
+            if config == quantization_config:
+                return ops
+        return []
+
+    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+        """just handling global spec for now"""
+        self._annotate_embedding_ops(model.graph)
+        return model
+
+    def _annotate_embedding_ops(self, graph: torch.fx.Graph) -> None:
+        embedding_config: OperatorConfig = get_embedding_operators_config()
+        for node in graph.nodes:
+            # Keep node parsing based annotations instead of module partitioners
+            # just as an example of alternate ways of annotating
+            if (
+                node.op == "call_function"
+                and node.target == torch.ops.aten.embedding.default
+            ):
+                if embedding_config.config.weight is None:
+                    raise ValueError(
+                        "Embedding config must have a valid weight quantization spec."
+                    )
+                node.meta["quantization_annotation"] = QuantizationAnnotation(
+                    input_qspec_map={
+                        node.args[0]: embedding_config.config.weight,
+                    }
+                )
+
+    def validate(self, model: torch.fx.GraphModule) -> None:
+        pass
+
+    @classmethod
+    def get_supported_operators(cls) -> list[OperatorConfig]:
+        return [get_embedding_operators_config()]
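+
+
+# Usage sketch (exported_model and its inputs are hypothetical; prepare_pt2e comes from
+# torch.ao.quantization.quantize_pt2e): annotate embedding weights for uint8
+# per-channel-affine-float-qparams quantization.
+#
+#   quantizer = EmbeddingQuantizer()
+#   prepared = prepare_pt2e(exported_model, quantizer)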
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/quantizer.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/quantizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..7da601052a9c06ad93fb1a9162be7e63f9294fe7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/quantizer.py
@@ -0,0 +1,181 @@
+# mypy: allow-untyped-defs
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
+from typing import Callable, Optional, Union
+
+import torch
+from torch import Tensor
+from torch.ao.quantization import ObserverOrFakeQuantize
+from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
+from torch.fx import Node
+
+
+__all__ = [
+    "Quantizer",
+    "QuantizationSpecBase",
+    "QuantizationSpec",
+    "FixedQParamsQuantizationSpec",
+    "EdgeOrNode",
+    "SharedQuantizationSpec",
+    "DerivedQuantizationSpec",
+    "QuantizationAnnotation",
+]
+
+
+class QuantizationSpecBase(ABC):  # noqa: B024
+    """Base class for different types of quantization specs that allows users to
+    specify how to quantize a Tensor (input/output of a Node) in the model
+    """
+
+
+@dataclass(eq=True, frozen=True)
+class QuantizationSpec(QuantizationSpecBase):
+    """Quantization spec for common operators that allows user to specify how to
+    quantize a Tensor, this includes dtype, quant_min, quant_max etc.
+    """
+
+    dtype: torch.dtype
+    # observer or fake_quantize constructor such as
+    # MinMaxObserver, PerChannelHistogramObserver etc.
+    # or we can attach some custom args to them
+    # e.g. MinMaxObserver.with_args(eps=eps)
+    observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor
+    quant_min: Optional[int] = None
+    quant_max: Optional[int] = None
+    qscheme: Optional[torch.qscheme] = None
+    ch_axis: Optional[int] = None
+    is_dynamic: bool = False
+
+    def __post_init__(self):
+        # TODO: add init for quant_min/quant_max
+        # quant_min must be less than quant_max
+        if (
+            self.quant_min is not None
+            and self.quant_max is not None
+            and self.quant_min > self.quant_max
+        ):
+            raise ValueError(
+                f"quant_min {self.quant_min} must be <= quant_max {self.quant_max}."
+            )
+
+        # ch_axis must be less than the number of channels
+        # but no way to check here. Just check that it is not < 0.
+        if self.ch_axis is not None and self.ch_axis < 0:
+            raise ValueError("Ch_axis is < 0.")
+
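+# A minimal QuantizationSpec construction sketch (values are illustrative, not library
+# defaults): an int8 per-tensor-affine activation spec backed by a HistogramObserver
+# (importable from torch.ao.quantization.observer).
+#
+#   act_spec = QuantizationSpec(
+#       dtype=torch.int8,
+#       quant_min=-128,
+#       quant_max=127,
+#       qscheme=torch.per_tensor_affine,
+#       observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
+#   )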
+
+@dataclass(eq=True, frozen=True)
+class FixedQParamsQuantizationSpec(QuantizationSpecBase):
+    dtype: torch.dtype
+    scale: float
+    zero_point: int
+    quant_min: Optional[int] = None
+    quant_max: Optional[int] = None
+    qscheme: Optional[torch.qscheme] = None
+    is_dynamic: bool = False
+
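+# FixedQParamsQuantizationSpec sketch (the scale/zero_point below are illustrative,
+# e.g. a sigmoid-style output range), for ops whose qparams are fixed ahead of time:
+#
+#   fixed_spec = FixedQParamsQuantizationSpec(
+#       dtype=torch.uint8,
+#       scale=1.0 / 256.0,
+#       zero_point=0,
+#       quant_min=0,
+#       quant_max=255,
+#       qscheme=torch.per_tensor_affine,
+#   )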
+
+"""
+The way we refer to other points of quantization in the graph will be either
+an input edge or an output value
+input edge is the connection between input node and the node consuming the input, so it's a Tuple[Node, Node]
+output value is an fx Node
+"""
+EdgeOrNode = Union[tuple[Node, Node], Node]
+EdgeOrNode.__module__ = "torch.ao.quantization.quantizer.quantizer"
+
+
+@dataclass(eq=True, frozen=True)
+class SharedQuantizationSpec(QuantizationSpecBase):
+    """
+    Quantization spec for the Tensors whose quantization parameters are shared with other Tensors
+    """
+
+    # the edge or node to share observer or fake quant instances with
+    edge_or_node: EdgeOrNode
+
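+# SharedQuantizationSpec sketch (node names are hypothetical): share observers/fake-quants
+# with either an input edge (a (producer, consumer) Node pair) or another node's output value.
+#
+#   share_with_edge = SharedQuantizationSpec((input_node, add_node))
+#   share_with_output = SharedQuantizationSpec(add_node)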
+
+@dataclass(eq=True, frozen=True)
+class DerivedQuantizationSpec(QuantizationSpecBase):
+    """Quantization spec for the Tensors whose quantization parameters are derived from other Tensors"""
+
+    derived_from: list[EdgeOrNode]
+    derive_qparams_fn: Callable[[list[ObserverOrFakeQuantize]], tuple[Tensor, Tensor]]
+    dtype: torch.dtype
+    quant_min: Optional[int] = None
+    quant_max: Optional[int] = None
+    qscheme: Optional[torch.qscheme] = None
+    ch_axis: Optional[int] = None
+    is_dynamic: bool = False
+
+
+@dataclass
+class QuantizationAnnotation:
+    """How are input arguemnt or output should be quantized,
+    expressed as QuantizationSpec, this corresponds to how a Tensor in the
+    operator Graph is observed (PTQ) or fake quantized (QAT)
+    """
+
+    # a map from torch.fx.Node to a type of QuantizationSpecBase
+    input_qspec_map: dict[Node, Optional[QuantizationSpecBase]] = field(
+        default_factory=dict
+    )
+
+    # How the output of this node is quantized, expressed as QuantizationSpec
+    # TODO: change the value to QuantizationSpec in a separate PR
+    output_qspec: Optional[QuantizationSpecBase] = None
+
+    # For a Node: node1 and edge: (node1, node2), since they are observing the same
+    # Tensor, we may want to implicitly share observers, this flag allows people to
+    # turn off this behavior for the output of the node
+    allow_implicit_sharing: bool = True
+
+    # whether the node is annotated or not
+    _annotated: bool = False
+
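+# Sketch of how a quantizer typically attaches an annotation to a node (the node and
+# spec names are hypothetical):
+#
+#   node.meta["quantization_annotation"] = QuantizationAnnotation(
+#       input_qspec_map={input_act_node: act_spec, weight_node: weight_spec},
+#       output_qspec=act_spec,
+#       _annotated=True,
+#   )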
+
+class Quantizer(ABC):
+    def transform_for_annotation(
+        self, model: torch.fx.GraphModule
+    ) -> torch.fx.GraphModule:
+        """Allows for user defined transforms to run before annotating the graph.
+        This allows quantizer to allow quantizing part of the model that are otherwise not quantizable.
+        For example quantizer can
+        a) decompose a compound operator like scaled dot product attention,
+        into bmm and softmax if quantizer knows how to quantize bmm/softmax but not sdpa
+        or b) transform scalars to tensor to allow quantizing scalares.
+
+        Note: this is an optional method
+        """
+        return model
+
+    # annotate nodes in the graph with observer or fake quant constructors
+    # to convey the desired way of quantization
+    @abstractmethod
+    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+        pass
+
+    # validate the annotated graph is supported by the backend
+    @abstractmethod
+    def validate(self, model: torch.fx.GraphModule) -> None:
+        pass
+
+    def prepare_obs_or_fq_callback(
+        self,
+        model: torch.fx.GraphModule,
+        edge_or_node_to_obs_or_fq: dict[EdgeOrNode, ObserverOrFakeQuantize],
+    ) -> None:
+        """A callback that will be called after the observers or fake quants are created
+        for each sharing group, but before they are inserted into the graph. The
+        callback can be used to make final quantization adjustments, such as enforcing
+        specific scale and zero point on model input or output.
+
+        Args:
+          * `model`: the graph module being prepared.
+          * `edge_or_node_to_obs_or_fq`: a dictionary mapping each annotated edge and
+            node to the corresponding observer or fake quant object. Note that multiple
+            edges and/or nodes can map to the same observer / fake quant instance if
+            they were annotated with SharedQuantizationSpec. This dictionary can be
+            modified by the callback.
+        """
+        return
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/utils.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..cae2ec30d1e337a0451097277c1c1049c0e879c7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/utils.py
@@ -0,0 +1,82 @@
+# mypy: allow-untyped-defs
+
+from torch.ao.quantization.pt2e.utils import _is_sym_size_node
+from torch.ao.quantization.quantizer.quantizer import QuantizationAnnotation
+from torch.fx import Node
+
+
+def _annotate_input_qspec_map(node: Node, input_node: Node, qspec):
+    quantization_annotation = node.meta.get(
+        "quantization_annotation", QuantizationAnnotation()
+    )
+    if quantization_annotation.input_qspec_map is None:
+        quantization_annotation.input_qspec_map = {}
+    quantization_annotation.input_qspec_map[input_node] = qspec
+    node.meta["quantization_annotation"] = quantization_annotation
+
+
+def _annotate_output_qspec(node: Node, qspec):
+    quantization_annotation = node.meta.get(
+        "quantization_annotation", QuantizationAnnotation()
+    )
+    quantization_annotation.output_qspec = qspec
+    node.meta["quantization_annotation"] = quantization_annotation
+
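+# Usage sketch (the linear node and specs are hypothetical): these helpers create or update
+# the node's "quantization_annotation" meta entry in place.
+#
+#   _annotate_input_qspec_map(linear_node, linear_node.args[0], act_qspec)
+#   _annotate_output_qspec(linear_node, act_qspec)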
+
+def _node_only_used_for_sym_size(node: Node, partition_nodes: list[Node]):
+    """
+    This utility is used to handle cases when dynamic_shape=True tracing leads
+    to symint nodes in the pattern of a linear module. In those cases, we need to
+    distinguish between the nodes that are inputs used only to extract the value of
+    some dimensions (and symint nodes) vs. the one that is the activation.
+    For example:
+    graph(x, y, weight):
+       size_0 = torch.ops.aten.sym_size([x], [0])
+       size_1 = torch.ops.aten.sym_size([y], [1])
+       view_size = size_0 * size_1
+       size_3 = torch.ops.aten.sym_size([x], [2])
+       view_out = torch.ops.aten.view(x, [view_size, size_3])
+       return mm(view_out, weight)
+    In the example above, the y node is not an actual input; it exists only to extract size_1
+    """
+    if _is_sym_size_node(node):
+        return True
+
+    return all(
+        ((user not in partition_nodes) or _is_sym_size_node(user))
+        for user in node.users
+    )
+
+
+def _get_module_name_filter(module_name: str):
+    """Get the module_name_filter function for a given module name, the filter accepts
+    a node and checks if the node comes from a module that has certain module name
+
+    For example:
+        node: linear_op = call_function[...](...)  # comes from a module with name blocks.sub.linear1
+
+
+    >> module_name_filter = _get_module_name_filter("blocks.sub")
+    >> print(module_name_filter(node))
+    True  # the node is from "blocks.sub" based on the fully qualified name "blocks.sub.linear1"
+    """
+
+    def module_name_filter(n: Node) -> bool:
+        # example: {
+        #    'L__self___sub': ("L['self'].sub", ),
+        #    'L__self___sub_linear': ("L['self'].sub.linear", )
+        # }
+        # get_attr nodes don't have nn_module_stack?
+        nn_module_stack = n.meta.get("nn_module_stack", {})
+
+        def _normalize_path(n):
+            prefix = 0
+            # TODO This is non standard behavior and should be removed when we migrate off capture_pre_autograd_graph.
+            if n.startswith("L['self']."):
+                prefix = len("L['self'].")
+            return n[prefix:]
+
+        names = [_normalize_path(n) for n, _ in nn_module_stack.values()]
+        return module_name in names
+
+    return module_name_filter
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/x86_inductor_quantizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4777645a9e90cb7a94e684a5c3bbeae906568fb
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/x86_inductor_quantizer.py
@@ -0,0 +1,1575 @@
+# mypy: allow-untyped-defs
+import functools
+import itertools
+import operator
+import warnings
+from collections.abc import Sequence
+from dataclasses import dataclass
+from typing import Any, Callable, Optional, TYPE_CHECKING, Union
+from typing_extensions import TypeAlias
+
+import torch
+import torch.nn.functional as F
+from torch.ao.quantization.fake_quantize import (
+    FakeQuantize,
+    FusedMovingAvgObsFakeQuantize,
+)
+from torch.ao.quantization.observer import (
+    HistogramObserver,
+    MovingAverageMinMaxObserver,
+    MovingAveragePerChannelMinMaxObserver,
+    PerChannelMinMaxObserver,
+    PlaceholderObserver,
+)
+from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
+from torch.ao.quantization.quantizer.quantizer import (
+    QuantizationAnnotation,
+    QuantizationSpec,
+    Quantizer,
+    SharedQuantizationSpec,
+)
+from torch.ao.quantization.quantizer.utils import _get_module_name_filter
+from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
+    get_bias_qspec,
+    get_input_act_qspec,
+    get_output_act_qspec,
+    get_weight_qspec,
+    QuantizationConfig,
+)
+from torch.fx import Node
+from torch.fx.passes.utils.source_matcher_utils import (
+    get_source_partitions,
+    SourcePartition,
+)
+
+
+FilterFn: TypeAlias = Callable[[list[Node]], bool]
+
+
+if TYPE_CHECKING:
+    from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
+
+__all__ = [
+    "X86InductorQuantizer",
+    "get_default_x86_inductor_quantization_config",
+    "get_x86_inductor_linear_dynamic_fp16_config",
+]
+
+
+@dataclass
+class _X86InductorQuantizationAnnotation(QuantizationAnnotation):
+    # _is_output_of_quantized_pattern:
+    #  * Node as output node of a fusion pattern.
+    #  * The fusion pattern supports int8 data type.
+    #  * The fusion pattern has inputs annotated to insert observer.
+    #  * The quantization_config is not `None`.
+    _is_output_of_quantized_pattern: bool = False
+
+
+# Operators that:
+# 1. are optimized to run with int8 when an int8 input is provided, and
+# 2. otherwise run with fp32 input and produce fp32 output.
+int8_in_int8_out_ops: set = {
+    torch.ops.aten.max_pool2d.default,
+    torch.ops.aten.cat.default,
+    torch.ops.aten.avg_pool2d.default,
+    torch.ops.aten.adaptive_avg_pool2d.default,
+    torch.ops.aten.flatten.using_ints,
+}
+
+# Operators that support the int8 data type for quantization config propagation.
+# A superset of int8_in_int8_out_ops incorporating additional operators.
+propagation_quantizable_ops = int8_in_int8_out_ops
+
+# Operators that support the int8 data type
+# and whose recipe is configured by default in X86InductorQuantizer.
+default_quantizable_ops = propagation_quantizable_ops | {
+    torch.ops.aten.conv1d.default,
+    torch.ops.aten.conv2d.default,
+    torch.ops.aten.linear.default,
+}
+
+# A superset of default_quantizable_ops that includes operators which support the int8 data type
+# but are not enabled by the default recipe of X86InductorQuantizer.
+quantizable_ops = default_quantizable_ops | {
+    torch.ops.aten.matmul.default,
+}
+
+QUANT_ANNOTATION_KEY = "quantization_annotation"
+
+
+def _skip_annotate(nodes: list[Node], filter_fn: Optional[FilterFn] = None) -> bool:
+    """Determine whether to skip annotation for a list of nodes."""
+
+    # 1) Skip annotate if any node is already annotated
+    if _is_any_annotated(nodes):
+        return True
+
+    # 2) Proceed annotate if a) a filter function is provided
+    # and b) the given nodes list passes the filter function check.
+    if filter_fn and filter_fn(nodes):
+        return False
+
+    return True
+
+
+def _create_module_name_filter(module_name: str) -> FilterFn:
+    """Create a filter function for a given module name.
+
+    The filter function takes a list of nodes (as determined by the annotate function)
+    and returns True if *all* nodes come from the specified module name, False otherwise.
+
+    For example:
+        linear_1: "f32[3, 10]" = torch.ops.aten.linear.default(...) # comes from a module with name `sub.linear1`
+        relu: "f32[3, 10]" = torch.ops.aten.relu.default(linear_1); # comes from a module with name `sub.relu1`
+
+    >> module_name_filter = _create_module_name_filter("sub")
+    >> print(module_name_filter([relu, linear_1]))
+    # True  # These two nodes are determined by `_annotate_linear_unary` function and from "sub".
+    """
+
+    filter_fn = _get_module_name_filter(module_name)
+
+    def check_all_nodes_from_module(nodes: list[Node]) -> bool:
+        all_nodes_from_module_name: bool = all(filter_fn(n) for n in nodes)
+        return all_nodes_from_module_name
+
+    return check_all_nodes_from_module
+
+
+def _create_operator_type_filter(
+    operator_type: Callable,
+) -> FilterFn:
+    """Create a filter function for a given operator type.
+
+    The filter function takes a list of nodes and returns True if it contains
+    exactly one node with the specified operator type, False otherwise.
+
+    For example:
+        linear_1: "f32[3, 10]" = torch.ops.aten.linear.default(...) # comes from a module with name `sub.linear1`
+        relu: "f32[3, 10]" = torch.ops.aten.relu.default(linear_1); # comes from a module with name `sub.relu1`
+
+    >> operator_type_filter = _create_operator_type_filter(torch.ops.aten.linear.default)
+    >> print(operator_type_filter([relu, linear_1]))
+    # True  # These two nodes are determined by `_annotate_linear_unary` function and the second node is `linear`.
+    """
+
+    def operator_type_filter(nodes: list[Node]):
+        num_nodes_with_operator_type = sum(
+            node.target == operator_type for node in nodes
+        )
+        if num_nodes_with_operator_type > 1:
+            raise NotImplementedError(
+                f"Several nodes within a single pattern are {operator_type}."
+            )
+        return num_nodes_with_operator_type == 1
+
+    return operator_type_filter
+
+
+def _global_config_filter(nodes: list[Node]) -> bool:
+    """Filter function for global configuration.
+
+    This filter function takes a list of nodes and returns True if there is exactly one node
+    in the list that is a default quantizable operation, False otherwise.
+    """
+    num_nodes_in_default_quantizable_ops = sum(
+        node.target in default_quantizable_ops for node in nodes
+    )
+    if num_nodes_in_default_quantizable_ops > 1:
+        raise NotImplementedError(
+            "Several nodes within a single pattern are default quantizable operations."
+        )
+    return num_nodes_in_default_quantizable_ops == 1
+
+
+def _map_module_function_to_aten_operator_type():
+    module_function_to_aten_operator: dict[Callable, torch._ops.OpOverloadPacket] = {}
+    map_list = (
+        ([torch.nn.Conv1d, F.conv1d], torch.ops.aten.conv1d.default),
+        ([torch.nn.Conv2d, F.conv2d], torch.ops.aten.conv2d.default),
+        ([torch.nn.Linear, F.linear], torch.ops.aten.linear.default),
+        ([torch.nn.MaxPool2d, F.max_pool2d], torch.ops.aten.max_pool2d.default),
+        (
+            [
+                torch.cat,
+            ],
+            torch.ops.aten.cat.default,
+        ),
+        ([torch.nn.AvgPool2d, F.avg_pool2d], torch.ops.aten.avg_pool2d.default),
+        (
+            [torch.nn.AdaptiveAvgPool2d, F.adaptive_avg_pool2d],
+            torch.ops.aten.adaptive_avg_pool2d.default,
+        ),
+        (
+            [
+                torch.flatten,
+            ],
+            torch.ops.aten.flatten.using_ints,
+        ),
+        (
+            [
+                torch.matmul,
+            ],
+            torch.ops.aten.matmul.default,
+        ),
+    )
+    for map_item in map_list:
+        module_function_to_aten_operator.update(dict.fromkeys(map_item[0], map_item[1]))  # type: ignore[arg-type, call-overload]
+    return module_function_to_aten_operator
+
+
+def _mark_nodes_as_annotated(nodes: list[Node]):
+    for node in nodes:
+        if node is not None:
+            if QUANT_ANNOTATION_KEY not in node.meta:
+                node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation()
+            node.meta[QUANT_ANNOTATION_KEY]._annotated = True
+
+
+def _is_node_annotated(_node):
+    """
+    return True if the node is annotated, otherwise return False
+    """
+    return (
+        QUANT_ANNOTATION_KEY in _node.meta
+        and _node.meta[QUANT_ANNOTATION_KEY]._annotated
+    )
+
+
+def _is_any_annotated(nodes: list[Node]):
+    """
+    Given a list of nodes (that represents an operator pattern),
+    check if any of the node is annotated, return True if any of the node
+    is annotated, otherwise return False.
+    """
+    return any(_is_node_annotated(node) for node in nodes)
+
+
+def _is_all_annotated(nodes: list[Node]):
+    """
+    Given a list of nodes (that represents an operator pattern),
+    return True if all of the node is annotated, otherwise return False.
+    """
+    return all(_is_node_annotated(node) for node in nodes)
+
+
+def _is_quantized_op_pt2e(node: torch.fx.Node):
+    """
+    Used for pt2e flow to check if the node is a quantized node:
+    Case1: the node has been annotated as output node of a fusion pattern.
+    Case2: the node has been annotated as single quantized node.
+    """
+    if not _is_any_annotated([node]):
+        # The node has not been annotated, directly return False
+        return False
+    quantization_annotation = node.meta.get(QUANT_ANNOTATION_KEY, None)
+    assert isinstance(quantization_annotation, _X86InductorQuantizationAnnotation)
+    return quantization_annotation._is_output_of_quantized_pattern
+
+
+@functools.lru_cache
+def get_default_x86_inductor_quantization_config(
+    is_qat: bool = False,
+    is_dynamic: bool = False,
+    reduce_range: bool = False,
+):
+    """
+    reduce_range is False by default. Set it to True on earlier CPUs without VNNI to avoid accuracy issues.
+    """
+    extra_args: dict[str, Any] = {"eps": 2**-12}
+    if is_qat:
+        if is_dynamic:
+            act_observer_or_fake_quant_ctr = FakeQuantize
+            dynamic_quant_observer = MovingAverageMinMaxObserver.with_args(
+                averaging_constant=1
+            )
+            extra_args["observer"] = dynamic_quant_observer
+        else:
+            act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize  # type: ignore[assignment]
+    else:
+        if is_dynamic:
+            act_observer_or_fake_quant_ctr = PlaceholderObserver  # type: ignore[assignment]
+        else:
+            act_observer_or_fake_quant_ctr = HistogramObserver  # type: ignore[assignment]
+
+    # Copy from x86 default qconfig from torch/ao/quantization/qconfig.py
+    act_quantization_spec = QuantizationSpec(
+        dtype=torch.uint8,
+        quant_min=0,
+        quant_max=127 if reduce_range else 255,
+        qscheme=torch.per_tensor_affine,
+        is_dynamic=is_dynamic,
+        observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
+            **extra_args
+        ),
+    )
+
+    weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
+        FusedMovingAvgObsFakeQuantize if is_qat else PerChannelMinMaxObserver
+    )
+
+    if is_qat:
+        # Only support per channel quant for now
+        extra_args["observer"] = MovingAveragePerChannelMinMaxObserver  # type: ignore[dict-item]
+    weight_quantization_spec = QuantizationSpec(
+        dtype=torch.int8,
+        quant_min=-128,
+        quant_max=127,
+        qscheme=torch.per_channel_symmetric,
+        ch_axis=0,  # 0 corresponding to weight shape = (oc, ic, kh, kw) of conv
+        is_dynamic=False,
+        observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
+            **extra_args
+        ),
+    )
+    bias_quantization_spec = None  # will use placeholder observer by default
+    quantization_config = QuantizationConfig(
+        act_quantization_spec,
+        act_quantization_spec,
+        weight_quantization_spec,
+        bias_quantization_spec,
+        is_qat,
+    )
+    return quantization_config
+
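+# Typical flow sketch (exported_model is hypothetical; prepare_pt2e comes from
+# torch.ao.quantization.quantize_pt2e): build the default static int8 config and
+# register it as the quantizer's global config.
+#
+#   quantizer = X86InductorQuantizer()
+#   quantizer.set_global(get_default_x86_inductor_quantization_config())
+#   prepared = prepare_pt2e(exported_model, quantizer)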
+
+@functools.lru_cache
+def get_x86_inductor_linear_dynamic_fp16_config():
+    """
+    For linear_dynamic_fp16. The name may be confusing.
+    The op's behavior is fp32_input * (fp16_weight -> to_fp32) -> fp32_output.
+    """
+    weight_quantization_spec = QuantizationSpec(
+        dtype=torch.float16,
+        observer_or_fake_quant_ctr=PlaceholderObserver,
+    )
+    quantization_config = QuantizationConfig(
+        None,  # input_quantization_spec
+        None,  # output_quantization_spec
+        weight_quantization_spec,
+        None,  # bias_quantization_spec
+    )
+    return quantization_config
+
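+# Sketch: apply the linear_dynamic_fp16 recipe only to Linear modules/functions,
+# leaving other operators untouched.
+#
+#   quantizer = X86InductorQuantizer()
+#   quantizer.set_module_type_qconfig(
+#       torch.nn.Linear, get_x86_inductor_linear_dynamic_fp16_config()
+#   )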
+
+def _annotate_nodes_not_quantize(nodes: Union[Node, list[Node]]) -> None:
+    """Annotate nodes to exclude them from quantization (their `quantization_config` is `None`)."""
+    if not isinstance(nodes, list):
+        nodes = [nodes]
+    for node in nodes:
+        node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
+            _annotated=True
+        )
+
+
+def _config_checker(method: Callable) -> Callable:
+    @functools.wraps(method)
+    def wrapper(
+        quantizer: "X86InductorQuantizer",
+        name: Any,
+        quantization_config: Optional["QuantizationConfig"],
+    ) -> "X86InductorQuantizer":
+        if quantizer._need_skip_config(quantization_config):
+            warnings.warn(
+                f"Skip the quantization config for {name}.",
+            )
+            return quantizer
+        return method(quantizer, name, quantization_config)
+
+    return wrapper
+
+
+@dataclass
+class _CurrentQuantizationMode:
+    r"""Configuration defining the current quantization mode for the quantizer.
+
+    All possible current quantization modes are listed below:
+    ----------------------------------------------------------------------------------------------------------
+                |                                       dynamic_state
+     qat_state  |---------------------------------------------------------------------------------------------
+                |                           None                              |    True       |  False
+    ----------------------------------------------------------------------------------------------------------
+        None    | quantizer does not receive a non-None `quantization_config` | \             | \
+        False   | quantizer will not do QAT                                   | dynamic       | static
+        True    | quantizer will do QAT                                       | QAT + dynamic | QAT + static
+    """
+
+    qat_state: Optional[bool]
+    dynamic_state: Optional[bool]
+
+
+class X86InductorQuantizer(Quantizer):
+    module_function_to_aten_operator_type = _map_module_function_to_aten_operator_type()
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.global_config: Optional[QuantizationConfig] = None
+        self.operator_type_qconfig: dict[
+            torch._ops.OpOverloadPacket, Optional[QuantizationConfig]
+        ] = {}
+        self.module_name_qconfig: dict[str, Optional[QuantizationConfig]] = {}
+
+    def _get_current_quantization_mode(self) -> _CurrentQuantizationMode:
+        """Retrieves the current quantization mode based on all configurations."""
+        qat_state = None
+        dynamic_state = None
+
+        # As we use `_need_skip_config` to skip all invalid configurations,
+        # we can safely assume that all existing non-None configurations
+        # have the same quantization mode.
+        for qconfig in (
+            list(self.module_name_qconfig.values())
+            + list(self.operator_type_qconfig.values())
+            + [self.global_config]
+        ):
+            if qconfig is not None:
+                # Query the `is_qat` state
+                if qat_state is None:
+                    qat_state = qconfig.is_qat
+                else:
+                    assert qat_state == qconfig.is_qat, (
+                        f"All non-None quantization configs should have the same `is_qat`,"
+                        f"but got {qat_state} and {qconfig.is_qat}."
+                    )
+                # Query the `is_dynamic` state
+                input_activation_spec = qconfig.input_activation
+                if input_activation_spec is not None:
+                    if dynamic_state is None:
+                        dynamic_state = input_activation_spec.is_dynamic
+                    else:
+                        assert dynamic_state == input_activation_spec.is_dynamic, (
+                            f"All non-None `input_activation_spec` should have the same `is_dynamic`,"
+                            f"but got {dynamic_state} and {input_activation_spec.is_dynamic}."
+                        )
+        return _CurrentQuantizationMode(
+            qat_state=qat_state, dynamic_state=dynamic_state
+        )
+
+    def _need_skip_config(
+        self, quantization_config: Optional[QuantizationConfig]
+    ) -> bool:
+        """Check if the provided quantization config is valid for X86InductorQuantizer.
+
+        Mixed static/dynamic configurations or mixed QAT/non-QAT configurations are not supported.
+        To avoid such a mix, we compare the incoming configuration with current configuration status.
+        Refer to the `_CurrentQuantizationMode` definition for all possible modes.
+        """
+        if quantization_config is None:
+            return False
+
+        need_skip = False
+        current_mode = self._get_current_quantization_mode()
+        if (
+            current_mode.qat_state is not None
+            and current_mode.qat_state != quantization_config.is_qat
+        ):
+            warnings.warn("Mixed QAT and Non-QAT quantization config is not supported.")
+            need_skip = True
+        if current_mode.dynamic_state is not None:
+            input_activation_spec = quantization_config.input_activation
+            if (
+                input_activation_spec is not None
+                and current_mode.dynamic_state != input_activation_spec.is_dynamic
+            ):
+                warnings.warn(
+                    "Mixed dynamic and static quantization config is not supported."
+                )
+                need_skip = True
+        return need_skip
+
+    def set_global(self, quantization_config: QuantizationConfig):
+        if self._need_skip_config(quantization_config):
+            warnings.warn("Skip the global quantization config.")
+            return self
+        self.global_config = quantization_config
+        return self
+
+    def get_global_quantization_config(self):
+        if not isinstance(self.global_config, QuantizationConfig):
+            warnings.warn(
+                "The global_config for X86InductorQuantizer is currently invalid. \
+                Please ensure that you use set_global to establish the global quantization configuration."
+            )
+        return self.global_config
+
+    @_config_checker
+    def set_function_type_qconfig(
+        self,
+        function_type: Callable,
+        quantization_config: Optional[QuantizationConfig],
+    ) -> "X86InductorQuantizer":
+        if function_type in X86InductorQuantizer.module_function_to_aten_operator_type:
+            self._set_aten_operator_qconfig(
+                X86InductorQuantizer.module_function_to_aten_operator_type[
+                    function_type
+                ],
+                quantization_config,
+            )
+        else:
+            warnings.warn(
+                f"function: Unable to customize quantization config for {function_type} by X86InductorQuantizer."
+            )
+        return self
+
+    @_config_checker
+    def set_module_type_qconfig(
+        self,
+        module_type: torch.nn.Module,
+        quantization_config: Optional[QuantizationConfig],
+    ) -> "X86InductorQuantizer":
+        if module_type in X86InductorQuantizer.module_function_to_aten_operator_type:
+            self._set_aten_operator_qconfig(
+                X86InductorQuantizer.module_function_to_aten_operator_type[module_type],
+                quantization_config,
+            )
+        else:
+            warnings.warn(
+                f"Module: Unable to customize quantization config for {module_type} by X86InductorQuantizer."
+            )
+        return self
+
+    @_config_checker
+    def set_module_name_qconfig(
+        self, module_name: str, quantization_config: Optional[QuantizationConfig]
+    ):
+        """Set quantization_config for a submodule with name: `module_name`, for example:
+        quantizer.set_module_name_qconfig("blocks.sub"), it will quantize all supported operator/operator
+        patterns in the submodule with this module name with the given `quantization_config`
+
+        The supported operators include `quantizable_ops` and `propagation_quantizable_ops`.
+        """
+        self.module_name_qconfig[module_name] = quantization_config
+        return self
+
+    def _set_aten_operator_qconfig(
+        self,
+        operator_type: torch._ops.OpOverloadPacket,
+        quantization_config: Optional[QuantizationConfig],
+    ) -> "X86InductorQuantizer":
+        if operator_type in quantizable_ops:
+            self.operator_type_qconfig[operator_type] = quantization_config
+        else:
+            warnings.warn(
+                f"operator: Unable to quantize {operator} by X86InductorQuantizer."
+            )
+        return self
+
+    def _annotate_conv_node_helper(
+        self,
+        conv_node: torch.fx.Node,
+        annotate_output: bool,
+        quantization_config: Optional[QuantizationConfig],
+    ) -> None:
+        """Helper function to annotate the conv node"""
+        if quantization_config is None:
+            _annotate_nodes_not_quantize(conv_node)
+            return
+        input_qspec_map = {}
+        input_node = conv_node.args[0]
+        assert isinstance(input_node, Node)
+        input_qspec_map[input_node] = get_input_act_qspec(quantization_config)
+        weight_node = conv_node.args[1]
+        assert isinstance(weight_node, Node)
+        input_qspec_map[weight_node] = get_weight_qspec(quantization_config)
+        bias_node = None if len(conv_node.args) == 2 else conv_node.args[2]
+        if isinstance(bias_node, Node):
+            input_qspec_map[bias_node] = get_bias_qspec(quantization_config)
+        if annotate_output:
+            conv_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
+                input_qspec_map=input_qspec_map,
+                _annotated=True,
+                _is_output_of_quantized_pattern=True,
+            )
+        else:
+            conv_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
+                input_qspec_map=input_qspec_map,
+                _annotated=True,
+            )
+
+    def _annotate_linear_node_helper(
+        self,
+        linear_node: torch.fx.Node,
+        annotate_output: bool,
+        quantization_config: Optional[QuantizationConfig],
+    ) -> None:
+        """Helper function to annotate the linear node"""
+        if quantization_config is None:
+            _annotate_nodes_not_quantize(linear_node)
+            return
+        input_qspec_map = {}
+        assert linear_node.target in (torch.ops.aten.linear.default,)
+        has_bias = len(linear_node.args) == 3
+        input_index = 0
+        weight_index = 1
+        bias_index = 2
+
+        input_node = linear_node.args[input_index]
+        assert isinstance(input_node, Node)
+        input_qspec_map[input_node] = get_input_act_qspec(quantization_config)
+
+        weight_node = linear_node.args[weight_index]
+        assert isinstance(weight_node, Node)
+        input_qspec_map[weight_node] = get_weight_qspec(quantization_config)
+
+        bias_node = linear_node.args[bias_index] if has_bias else None
+        if isinstance(bias_node, Node):
+            input_qspec_map[bias_node] = get_bias_qspec(quantization_config)
+
+        if annotate_output:
+            linear_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
+                input_qspec_map=input_qspec_map,
+                _annotated=True,
+                _is_output_of_quantized_pattern=True,
+            )
+        else:
+            linear_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
+                input_qspec_map=input_qspec_map, _annotated=True
+            )
+
+    def _get_output_nodes_of_partitions(
+        self,
+        partition_list: list[SourcePartition],
+    ) -> list[torch.fx.Node]:
+        """Helper function to get the output node list from partition list"""
+        output_node_list = []
+        for partition in partition_list:
+            if len(partition.output_nodes) > 1:
+                raise ValueError("Input partition has more than one output node")
+            output_node = partition.output_nodes[0]
+            assert isinstance(output_node, Node)
+            output_node_list.append(output_node)
+        if len(output_node_list) != len(partition_list):
+            raise ValueError(
+                "length of output_node_list should equal to length of partition_list"
+            )
+        return output_node_list
+
+    def _get_input_idx_for_binary_node(
+        self,
+        conv_gemm_node: torch.fx.Node,
+        binary_node: torch.fx.Node,
+    ):
+        """Helper function to check conv_gemm and extra input node index
+        for binary node fused with conv_gemm.
+        """
+        conv_gemm_node_idx = None
+        extra_input_node_idx = None
+        if (binary_node.args[0].op == "call_function") and (  # type: ignore[union-attr]
+            binary_node.args[0] == conv_gemm_node
+        ):
+            conv_gemm_node_idx = 0
+            extra_input_node_idx = 1
+        elif (binary_node.args[1].op == "call_function") and (  # type: ignore[union-attr]
+            binary_node.args[1] == conv_gemm_node
+        ):
+            conv_gemm_node_idx = 1
+            extra_input_node_idx = 0
+        extra_input_node = binary_node.args[extra_input_node_idx]  # type: ignore[index]
+        assert isinstance(extra_input_node, Node)
+        return conv_gemm_node_idx, extra_input_node_idx
+
+    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+        """Annotate the given model with quantization configurations.
+
+        Annotation contracts:
+        1. Annotate each node according to the user's qconfig in the following order:
+        `module_name_qconfig`, `operator_type_qconfig`, and `global_config`.
+        2. Avoid re-annotating nodes already annotated in prior stages. For example,
+        if `linear1` has been annotated by `module_name_qconfig`, it won't be annotated again
+        during the processing of the 'operator_type_qconfig' or 'global_config'.
+        3. If the config is `None`, the node will be annotated with `_X86InductorQuantizationAnnotation(_annotated=True)`.
+
+        For each pair of (module_name_or_operator_type_or_global, qconfig), a filter function is created.
+        This filter function checks if the node is matched by the current stage and not already annotated by a previous stage.
+        """
+        for module_name, quantization_config in self.module_name_qconfig.items():
+            self._annotate_with_config(
+                model, quantization_config, _create_module_name_filter(module_name)
+            )
+
+        for operator_type, quantization_config in self.operator_type_qconfig.items():
+            self._annotate_with_config(
+                model, quantization_config, _create_operator_type_filter(operator_type)
+            )
+
+        if self.global_config:
+            self._annotate_with_config(
+                model,
+                self.global_config,
+                _global_config_filter,
+            )
+
+        # Once we've annotated the model with quantization configurations, we also need to annotate
+        # the output of quantizable operations. For example, if we annotated `maxpool2d` to quantize its inputs,
+        # we will quantize its output accordingly. This enables us to fuse the dq-operator-q into a quantized op.
+        # Refer to
+        # https://github.com/intel/intel-extension-for-pytorch/blob/90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_recipe.py#L487  # noqa: B950
+
+        self._annotate_output_for_int8_in_int8_out_pattern_entry(model)
+
+        return model
+
+    def _annotate_with_config(
+        self,
+        model: torch.fx.GraphModule,
+        quantization_config: Optional[QuantizationConfig],
+        filter_fn: FilterFn,
+    ) -> None:
+        """Annotate the model with the given quantization configuration.
+
+        High-level description of quantization recipe for X86 Inductor Backend:
+        Step 1: Apply the quantization recipe to conv/linear fusion patterns to actively enable the int8 data type.
+        Step 2: Propagate quantization annotations to patterns besides conv/linear. Go through the patterns in the model
+        from start to end. If a pattern supports computation with the int8 data type and its inputs are connected to
+        quantized patterns, annotate its inputs as a quantized pattern.
+        """
+
+        # Step1: Recipe of fusion patterns like conv/linear.
+        self._annotate_conv2d_fusion_pattern(model, quantization_config, filter_fn)
+        self._annotate_linear_fusion_pattern(model, quantization_config, filter_fn)
+        self._annotate_matmul(model, quantization_config, filter_fn)
+
+        # Step2: Recipe to propagate annotation for patterns beside conv/linear.
+        # Go through all the nodes from start to end.
+        # Recipe refer to
+        # https://github.com/intel/intel-extension-for-pytorch/blob/90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_recipe.py#L538  # noqa: B950
+
+        self._annotate_propagation_quantizable_pattern_entry(
+            model, quantization_config, filter_fn
+        )
+
+    def _annotate_qat_conv2d_fusion_pattern(
+        self,
+        model: torch.fx.GraphModule,
+        quantization_config: Optional[QuantizationConfig],
+        filter_fn: Optional[FilterFn] = None,
+    ):
+        # Annotate QAT Specific patterns
+        self._annotate_qat_conv2d_bn_binary_unary(model, quantization_config, filter_fn)
+        self._annotate_qat_conv2d_bn_binary(model, quantization_config, filter_fn)
+        self._annotate_qat_conv2d_bn_unary(model, quantization_config, filter_fn)
+        self._annotate_qat_conv2d_bn(model, quantization_config, filter_fn)
+
+    def _annotate_qat_conv2d_bn_binary_unary(
+        self,
+        gm: torch.fx.GraphModule,
+        quantization_config: Optional[QuantizationConfig],
+        filter_fn: Optional[FilterFn] = None,
+    ) -> None:
+        fused_partitions = find_sequential_partitions(
+            gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d, operator.add, torch.nn.ReLU]
+        )
+        for fused_partition in fused_partitions:
+            (
+                conv_partition,
+                bn_partition,
+                binary_partition,
+                unary_partition,
+            ) = fused_partition
+
+            (
+                conv_node,
+                bn_output_node,
+                binary_node,
+                unary_node,
+            ) = self._get_output_nodes_of_partitions(
+                [conv_partition, bn_partition, binary_partition, unary_partition]
+            )
+            if len(bn_output_node.users) != 1:
+                # The Conv-BN pattern should only have 1 user.
+                continue
+            (
+                bn_output_node_idx,
+                extra_input_node_idx,
+            ) = self._get_input_idx_for_binary_node(bn_output_node, binary_node)
+            if (bn_output_node_idx is None) or (extra_input_node_idx is None):
+                continue
+            if bn_output_node != binary_node.args[bn_output_node_idx]:
+                raise ValueError(f"{bn_output_node} doesn't match input of binary node")
+            extra_input_node = binary_node.args[extra_input_node_idx]
+
+            if (
+                conv_node.op != "call_function"
+                or conv_node.target != torch.ops.aten.conv2d.default
+            ):
+                continue
+
+            if _skip_annotate(
+                [unary_node, binary_node, bn_output_node, conv_node], filter_fn
+            ):
+                continue
+
+            self._annotate_conv_node_helper(conv_node, False, quantization_config)
+
+            if quantization_config is not None:
+                binary_node_input_qspec_map = {}
+                binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec(
+                    quantization_config
+                )
+                binary_node.meta[QUANT_ANNOTATION_KEY] = (
+                    _X86InductorQuantizationAnnotation(
+                        input_qspec_map=binary_node_input_qspec_map,
+                        _annotated=True,
+                    )
+                )
+                unary_node.meta[QUANT_ANNOTATION_KEY] = (
+                    _X86InductorQuantizationAnnotation(
+                        # TODO: Remove the output annotation in QAT once the QAT util supports the pattern matcher.
+                        output_qspec=get_output_act_qspec(quantization_config),  # type: ignore[arg-type]
+                        _annotated=True,
+                        _is_output_of_quantized_pattern=True,
+                    )
+                )
+            else:
+                _annotate_nodes_not_quantize([binary_node, unary_node])
+            nodes_to_mark_annotated = list(conv_partition.nodes)
+            nodes_to_mark_annotated.extend(list(bn_partition.nodes))
+            nodes_to_mark_annotated.extend(list(binary_partition.nodes))
+            nodes_to_mark_annotated.extend(list(unary_partition.nodes))
+            _mark_nodes_as_annotated(nodes_to_mark_annotated)
+
+    def _annotate_qat_conv2d_bn_binary(
+        self,
+        gm: torch.fx.GraphModule,
+        quantization_config: Optional[QuantizationConfig],
+        filter_fn: Optional[FilterFn] = None,
+    ) -> None:
+        fused_partitions = find_sequential_partitions(
+            gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d, operator.add]
+        )
+        for fused_partition in fused_partitions:
+            conv_partition, bn_partition, binary_partition = fused_partition
+            (
+                conv_node,
+                bn_output_node,
+                binary_node,
+            ) = self._get_output_nodes_of_partitions(
+                [conv_partition, bn_partition, binary_partition]
+            )
+            if len(bn_output_node.users) != 1:
+                # The Conv-BN pattern should only have 1 user.
+                continue
+            (
+                bn_output_node_idx,
+                extra_input_node_idx,
+            ) = self._get_input_idx_for_binary_node(bn_output_node, binary_node)
+            if (bn_output_node_idx is None) or (extra_input_node_idx is None):
+                continue
+            if bn_output_node != binary_node.args[bn_output_node_idx]:
+                raise ValueError(f"{bn_output_node} doesn't match input of binary node")
+
+            extra_input_node = binary_node.args[extra_input_node_idx]
+
+            if (
+                conv_node.op != "call_function"
+                or conv_node.target != torch.ops.aten.conv2d.default
+            ):
+                continue
+
+            if _skip_annotate([binary_node, bn_output_node, conv_node], filter_fn):
+                continue
+
+            self._annotate_conv_node_helper(conv_node, False, quantization_config)
+
+            if quantization_config is not None:
+                binary_node_input_qspec_map = {}
+                binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec(
+                    quantization_config
+                )
+                binary_node.meta[QUANT_ANNOTATION_KEY] = (
+                    _X86InductorQuantizationAnnotation(
+                        input_qspec_map=binary_node_input_qspec_map,
+                        # TODO: Remove the output annotation in QAT once the QAT util supports the pattern matcher.
+                        output_qspec=get_output_act_qspec(quantization_config),  # type: ignore[arg-type]
+                        _annotated=True,
+                        _is_output_of_quantized_pattern=True,
+                    )
+                )
+            else:
+                _annotate_nodes_not_quantize(binary_node)
+            nodes_to_mark_annotated = list(conv_partition.nodes)
+            nodes_to_mark_annotated.extend(list(bn_partition.nodes))
+            nodes_to_mark_annotated.extend(list(binary_partition.nodes))
+            _mark_nodes_as_annotated(nodes_to_mark_annotated)
+
+    def _annotate_qat_conv2d_bn_unary(
+        self,
+        gm: torch.fx.GraphModule,
+        quantization_config: Optional[QuantizationConfig],
+        filter_fn: Optional[FilterFn] = None,
+    ) -> None:
+        fused_partitions = []
+        unary_patterns = [
+            [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU],
+            [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.Hardtanh],
+            [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.Hardswish],
+            [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU6],
+            [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.SiLU],
+        ]
+        for unary_pattern in unary_patterns:
+            partitions = find_sequential_partitions(gm, unary_pattern)
+            if partitions:
+                # Extend the fused_partitions if partitions is not empty
+                fused_partitions.extend(partitions)
+
+        for fused_partition in fused_partitions:
+            conv_partition, bn_partition, unary_partition = fused_partition
+            (
+                conv_node,
+                bn_output_node,
+                unary_node,
+            ) = self._get_output_nodes_of_partitions(
+                [conv_partition, bn_partition, unary_partition]
+            )
+
+            if (
+                conv_node.op != "call_function"
+                or conv_node.target != torch.ops.aten.conv2d.default
+            ):
+                continue
+
+            if _skip_annotate([unary_node, bn_output_node, conv_node], filter_fn):
+                continue
+
+            self._annotate_conv_node_helper(conv_node, False, quantization_config)
+            if quantization_config is not None:
+                unary_node.meta[QUANT_ANNOTATION_KEY] = (
+                    _X86InductorQuantizationAnnotation(
+                        # TODO: Remove the output annotation in QAT once the QAT util supports the pattern matcher.
+                        output_qspec=get_output_act_qspec(quantization_config),  # type: ignore[arg-type]
+                        _annotated=True,
+                        _is_output_of_quantized_pattern=True,
+                    )
+                )
+            else:
+                _annotate_nodes_not_quantize(unary_node)
+            nodes_to_mark_annotated = list(conv_partition.nodes)
+            nodes_to_mark_annotated.extend(list(bn_partition.nodes))
+            nodes_to_mark_annotated.extend(list(unary_partition.nodes))
+            _mark_nodes_as_annotated(nodes_to_mark_annotated)
+
+    def _annotate_qat_conv2d_bn(
+        self,
+        gm: torch.fx.GraphModule,
+        quantization_config: Optional[QuantizationConfig],
+        filter_fn: Optional[FilterFn] = None,
+    ) -> None:
+        fused_partitions = find_sequential_partitions(
+            gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d]
+        )
+        for fused_partition in fused_partitions:
+            conv_partition, bn_partition = fused_partition
+            conv_node, bn_output_node = self._get_output_nodes_of_partitions(
+                [conv_partition, bn_partition]
+            )
+
+            if (
+                conv_node.op != "call_function"
+                or conv_node.target != torch.ops.aten.conv2d.default
+            ):
+                continue
+
+            if _skip_annotate([bn_output_node, conv_node], filter_fn):
+                continue
+
+            self._annotate_conv_node_helper(conv_node, False, quantization_config)
+            if quantization_config is not None:
+                bn_output_node.meta[QUANT_ANNOTATION_KEY] = (
+                    _X86InductorQuantizationAnnotation(
+                        # TODO: Remove the output annotation in QAT once the QAT util supports the pattern matcher.
+                        output_qspec=get_output_act_qspec(quantization_config),  # type: ignore[arg-type]
+                        _annotated=True,
+                        _is_output_of_quantized_pattern=True,
+                    )
+                )
+            else:
+                _annotate_nodes_not_quantize(bn_output_node)
+            nodes_to_mark_annotated = list(conv_partition.nodes)
+            nodes_to_mark_annotated.extend(list(bn_partition.nodes))
+            _mark_nodes_as_annotated(nodes_to_mark_annotated)
+
+    def _annotate_conv2d_fusion_pattern(
+        self,
+        model: torch.fx.GraphModule,
+        quantization_config: Optional[QuantizationConfig],
+        filter_fn: Optional[FilterFn] = None,
+    ):
+        if (quantization_config is None) or (quantization_config.is_qat):
+            # Annotate QAT-specific patterns: mainly because BN is not folded in prepare_qat
+            self._annotate_qat_conv2d_fusion_pattern(
+                model, quantization_config, filter_fn
+            )
+        self._annotate_conv2d_binary_unary(model, quantization_config, filter_fn)
+        self._annotate_conv2d_binary(model, quantization_config, filter_fn)
+        self._annotate_conv2d_unary(model, quantization_config, filter_fn)
+        self._annotate_conv2d(model, quantization_config, filter_fn)
+
+    def _annotate_linear_fusion_pattern(
+        self,
+        model: torch.fx.GraphModule,
+        quantization_config: Optional[QuantizationConfig],
+        filter_fn: Optional[FilterFn] = None,
+    ):
+        self._annotate_linear_binary_unary(model, quantization_config, filter_fn)
+        self._annotate_linear_unary(model, quantization_config, filter_fn)
+        self._annotate_linear(model, quantization_config, filter_fn)
+
+    def _annotate_matmul(
+        self,
+        model: torch.fx.GraphModule,
+        quantization_config: Optional[QuantizationConfig],
+        filter_fn: Optional[FilterFn] = None,
+    ):
+        for node in model.graph.nodes:
+            if node.target != torch.ops.aten.matmul.default:
+                continue
+            if _skip_annotate([node], filter_fn):
+                continue
+
+            if quantization_config is None:
+                _annotate_nodes_not_quantize(node)
+                continue
+
+            input_qspec_map = {}
+            matmul_node = node
+            for input_node in matmul_node.args:
+                input_qspec_map[input_node] = get_input_act_qspec(quantization_config)
+            matmul_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
+                input_qspec_map=input_qspec_map,
+                _annotated=True,
+                _is_output_of_quantized_pattern=True,
+            )
+
+    def _annotate_conv2d_binary_unary(
+        self,
+        gm: torch.fx.GraphModule,
+        quantization_config: Optional[QuantizationConfig],
+        filter_fn: Optional[FilterFn] = None,
+    ) -> None:
+        # Conv2d + add + unary op
+        fused_partitions = find_sequential_partitions(
+            gm, [torch.nn.Conv2d, operator.add, torch.nn.ReLU]
+        )
+        for fused_partition in fused_partitions:
+            conv_partition, binary_partition, unary_partition = fused_partition
+            conv_node, binary_node, unary_node = self._get_output_nodes_of_partitions(
+                [conv_partition, binary_partition, unary_partition]
+            )
+            if len(conv_node.users) != 1:
+                # The conv node should only have 1 user node
+                continue
+            conv_node_idx, extra_input_node_idx = self._get_input_idx_for_binary_node(
+                conv_node, binary_node
+            )
+            if (conv_node_idx is None) or (extra_input_node_idx is None):
+                continue
+            if conv_node != binary_node.args[conv_node_idx]:
+                raise ValueError(f"{conv_node} doesn't match input of binary node")
+            extra_input_node = binary_node.args[extra_input_node_idx]
+            if (
+                conv_node.op != "call_function"
+                or conv_node.target != torch.ops.aten.conv2d.default
+            ):
+                # No conv node found to be fused with add
+                continue
+            if _skip_annotate([unary_node, binary_node, conv_node], filter_fn):
+                continue
+
+            if quantization_config is None:
+                _annotate_nodes_not_quantize([conv_node, binary_node, unary_node])
+                continue
+
+            self._annotate_conv_node_helper(conv_node, False, quantization_config)
+            binary_node_input_qspec_map = {}
+            binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec(
+                quantization_config
+            )
+            binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
+                input_qspec_map=binary_node_input_qspec_map,
+                _annotated=True,
+            )
+            unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
+                _annotated=True,
+                _is_output_of_quantized_pattern=True,
+            )
+
+    def _annotate_conv2d_binary(
+        self,
+        gm: torch.fx.GraphModule,
+        quantization_config: Optional[QuantizationConfig],
+        filter_fn: Optional[FilterFn] = None,
+    ) -> None:
+        # Conv2d + add
+        fused_partitions = find_sequential_partitions(
+            gm, [torch.nn.Conv2d, operator.add]
+        )
+        for fused_partition in fused_partitions:
+            conv_partition, binary_partition = fused_partition
+            conv_node, binary_node = self._get_output_nodes_of_partitions(
+                [conv_partition, binary_partition]
+            )
+            if len(conv_node.users) != 1:
+                # The conv node should only have 1 user node
+                continue
+            conv_node_idx, extra_input_node_idx = self._get_input_idx_for_binary_node(
+                conv_node, binary_node
+            )
+            if (conv_node_idx is None) or (extra_input_node_idx is None):
+                continue
+            if conv_node != binary_node.args[conv_node_idx]:
+                raise ValueError(f"{conv_node} doesn't match input of binary node")
+            extra_input_node = binary_node.args[extra_input_node_idx]
+            assert isinstance(conv_node, Node)
+            if (
+                conv_node.op != "call_function"
+                or conv_node.target != torch.ops.aten.conv2d.default
+            ):
+                # No conv node found to be fused with add
+                continue
+            if _skip_annotate([binary_node, conv_node], filter_fn):
+                continue
+
+            if quantization_config is None:
+                _annotate_nodes_not_quantize([conv_node, binary_node])
+                continue
+
+            self._annotate_conv_node_helper(conv_node, False, quantization_config)
+            binary_node_input_qspec_map = {}
+            binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec(
+                quantization_config
+            )
+            binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
+                input_qspec_map=binary_node_input_qspec_map,
+                _annotated=True,
+                _is_output_of_quantized_pattern=True,
+            )
+
+    def _annotate_conv2d_unary(
+        self,
+        gm: torch.fx.GraphModule,
+        quantization_config: Optional[QuantizationConfig],
+        filter_fn: Optional[FilterFn] = None,
+    ) -> None:
+        fused_partitions = []
+        unary_patterns = [
+            [torch.nn.Conv2d, torch.nn.ReLU],
+            [torch.nn.Conv2d, torch.nn.Hardtanh],
+            [torch.nn.Conv2d, torch.nn.Hardswish],
+            [torch.nn.Conv2d, torch.nn.ReLU6],
+            [torch.nn.Conv2d, torch.nn.SiLU],
+            [torch.nn.Conv1d, torch.nn.ReLU],
+        ]
+        for unary_pattern in unary_patterns:
+            partitions = find_sequential_partitions(gm, unary_pattern)
+            if partitions:
+                # Extend the fused_partitions if partitions is not empty
+                fused_partitions.extend(partitions)
+
+        for fused_partition in fused_partitions:
+            conv_partition, unary_partition = fused_partition
+            conv_node, unary_node = self._get_output_nodes_of_partitions(
+                [conv_partition, unary_partition]
+            )
+            if conv_node.op != "call_function" or conv_node.target not in (
+                torch.ops.aten.conv2d.default,
+                torch.ops.aten.conv1d.default,
+            ):
+                continue
+            if _skip_annotate([unary_node, conv_node], filter_fn):
+                continue
+
+            if quantization_config is None:
+                _annotate_nodes_not_quantize([conv_node, unary_node])
+                continue
+
+            self._annotate_conv_node_helper(conv_node, False, quantization_config)
+            unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
+                _annotated=True,
+                _is_output_of_quantized_pattern=True,
+            )
+
+    def _annotate_conv2d(
+        self,
+        gm: torch.fx.GraphModule,
+        quantization_config: Optional[QuantizationConfig],
+        filter_fn: Optional[FilterFn] = None,
+    ) -> None:
+        conv_partitions = get_source_partitions(
+            gm.graph, [torch.nn.Conv2d, torch.nn.functional.conv2d]
+        )
+        conv_partitions = list(itertools.chain.from_iterable(conv_partitions.values()))
+        for conv_partition in conv_partitions:
+            if len(conv_partition.output_nodes) > 1:
+                raise ValueError("conv partition has more than one output node")
+            conv_node = conv_partition.output_nodes[0]
+            if (
+                conv_node.op != "call_function"
+                or conv_node.target != torch.ops.aten.conv2d.default
+            ):
+                raise ValueError(f"{conv_node} is not an aten conv2d operator")
+            # skip annotation if it is already annotated
+            if _skip_annotate([conv_node], filter_fn):
+                continue
+            self._annotate_conv_node_helper(conv_node, True, quantization_config)
+
+    def _annotate_maxpool2d(
+        self,
+        node: Node,
+        quantization_config: Optional[QuantizationConfig],
+    ) -> None:
+        if node.target is not torch.ops.aten.max_pool2d.default:
+            return
+        if quantization_config is None:
+            _annotate_nodes_not_quantize(node)
+            return
+
+        maxpool_node = node
+        if _is_any_annotated(
+            [
+                maxpool_node,
+            ]
+        ):
+            return
+
+        input_node = maxpool_node.args[0]
+        assert isinstance(input_node, Node)
+        input_qspec_map = {}
+        input_qspec_map[input_node] = get_input_act_qspec(quantization_config)
+        maxpool_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            _annotated=True,
+            _is_output_of_quantized_pattern=True,
+        )
+
+    def _annotate_cat(
+        self, node: Node, quantization_config: Optional[QuantizationConfig]
+    ) -> None:
+        if quantization_config is None:
+            _annotate_nodes_not_quantize(node)
+            return
+        cat_node = node
+        input_nodes = cat_node.args[0]
+        assert isinstance(input_nodes, Sequence)
+        first_input_node = input_nodes[0]
+        input_qspec_map = {}
+        assert isinstance(first_input_node, Node)
+        assert isinstance(cat_node, Node)
+        input_qspec_map[first_input_node] = get_input_act_qspec(quantization_config)
+        share_qparams_with_input_act0_qspec = SharedQuantizationSpec(
+            (first_input_node, cat_node)
+        )
+
+        for input_node in input_nodes[1:]:
+            if input_node not in input_qspec_map:
+                # Handle the case of concatenating the same node twice: torch.cat([input0, input0], 1)
+                assert isinstance(input_node, Node)
+                input_qspec_map[input_node] = share_qparams_with_input_act0_qspec
+
+        cat_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            _annotated=True,
+            _is_output_of_quantized_pattern=True,
+        )
+
+    def _annotate_propagation_quantizable_pattern_entry(
+        self,
+        gm: torch.fx.GraphModule,
+        quantization_config: Optional[QuantizationConfig],
+        filter_fn: Optional[FilterFn] = None,
+    ):
+        for node in gm.graph.nodes:
+            self._annotate_propagation_quantizable_pattern(
+                node, quantization_config, filter_fn
+            )
+
+    def _annotate_propagation_quantizable_pattern(
+        self, node: Node, quantization_config, filter_fn
+    ) -> None:
+        # Propagate annotation to quantizable patterns.
+        if (
+            (node.target in propagation_quantizable_ops)
+            and (not _is_any_annotated([node]))
+            and (node.op == "call_function")
+        ):
+
+            def is_all_inputs_connected_to_quantized_op(input_nodes):
+                # Ensure all the inputs are connected to a fusion pattern or a quantized node
+                for input_node in input_nodes:
+                    if not _is_quantized_op_pt2e(input_node):
+                        return False
+                return True
+
+            if _skip_annotate([node], filter_fn):
+                return
+
+            if quantization_config is None:
+                _annotate_nodes_not_quantize(node)
+                return
+
+            if node.target is torch.ops.aten.max_pool2d.default:
+                # Recipe for maxpool2d: check whether input arg[0] of maxpool2d is quantized
+                input_nodes_to_check = [node.all_input_nodes[0]]
+                if not is_all_inputs_connected_to_quantized_op(input_nodes_to_check):
+                    if quantization_config is not None:
+                        warnings.warn(
+                            f"The input of maxpool2d is not quantized, skip annotate maxpool2d with config {quantization_config}."
+                        )
+                    return
+
+                self._annotate_maxpool2d(node, quantization_config)
+                return
+            elif node.target is torch.ops.aten.cat.default:
+                input_nodes_to_check = node.all_input_nodes
+                if not is_all_inputs_connected_to_quantized_op(input_nodes_to_check):
+                    return
+                self._annotate_cat(node, quantization_config)
+            elif (
+                node.target is torch.ops.aten.flatten.using_ints
+                and len(node.users) > 0
+                and not any(
+                    user.target in quantizable_ops for user in node.users.keys()
+                )
+            ):
+                # Recipe for flatten: check whether any user of the flatten node is a quantizable op
+                return
+            else:
+                input_node = node.all_input_nodes[0]
+                if not is_all_inputs_connected_to_quantized_op(
+                    [
+                        input_node,
+                    ]
+                ):
+                    return
+                input_qspec_map = {}
+                input_qspec_map[input_node] = get_input_act_qspec(quantization_config)
+                node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
+                    input_qspec_map=input_qspec_map,
+                    _annotated=True,
+                    _is_output_of_quantized_pattern=True,
+                )
+        return
+
+    def _annotate_output_share_observer_as_input(
+        self, input_node: Node, source_node: Node
+    ):
+        source_node_quantization_annotation = (
+            source_node.meta[QUANT_ANNOTATION_KEY]
+            if QUANT_ANNOTATION_KEY in source_node.meta
+            else None
+        )
+        if (
+            source_node_quantization_annotation
+            and source_node_quantization_annotation._is_output_of_quantized_pattern
+        ):
+            edge_or_node = (input_node, source_node)
+            source_node_quantization_annotation.output_qspec = SharedQuantizationSpec(
+                edge_or_node
+            )
+        return
+
+    def _annotate_output_for_int8_in_int8_out_pattern_entry(
+        self,
+        model: torch.fx.GraphModule,
+    ):
+        for node in model.graph.nodes:
+            self._annotate_output_for_int8_in_int8_out_pattern(node)
+
+    def _annotate_output_for_int8_in_int8_out_pattern(
+        self,
+        node: Node,
+    ) -> None:
+        r"""
+        Check and, if needed, insert an observer at the output of nodes in int8_in_int8_out_ops.
+        The recipe refers to
+        https://github.com/intel/intel-extension-for-pytorch/blob/90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_utils.py#L495
+        """  # noqa: B950
+        edge_or_node: tuple[Node, Node]
+        if (node.target in int8_in_int8_out_ops) and (_is_any_annotated([node])):
+            if node.target == torch.ops.aten.max_pool2d.default:
+                maxpool_node = node
+                if not _is_all_annotated(
+                    [
+                        maxpool_node,
+                    ]
+                ):
+                    return
+
+                # Get the quantization_annotation from maxpool_node
+                maxpool_node_quantization_annotation = (
+                    maxpool_node.meta[QUANT_ANNOTATION_KEY]
+                    if QUANT_ANNOTATION_KEY in maxpool_node.meta
+                    else None
+                )
+                if (
+                    maxpool_node_quantization_annotation
+                    and maxpool_node_quantization_annotation._is_output_of_quantized_pattern
+                ):
+                    # Annotate the output_qspec of maxpool_node
+                    input_act = maxpool_node.args[0]
+                    assert isinstance(input_act, Node)
+                    assert isinstance(maxpool_node, Node)
+                    edge_or_node = (input_act, maxpool_node)
+                    maxpool_node_quantization_annotation.output_qspec = (
+                        SharedQuantizationSpec(edge_or_node)
+                    )
+            else:
+                input_node = node.all_input_nodes[0]
+                self._annotate_output_share_observer_as_input(input_node, node)
+        return
+
+    def _annotate_linear(
+        self,
+        gm: torch.fx.GraphModule,
+        quantization_config: Optional[QuantizationConfig],
+        filter_fn: Optional[FilterFn] = None,
+    ) -> None:
+        linear_partitions = get_source_partitions(
+            gm.graph, [torch.nn.Linear, torch.nn.functional.linear]
+        )
+        linear_partitions = list(
+            itertools.chain.from_iterable(linear_partitions.values())
+        )
+        for partition in linear_partitions:
+            if len(partition.output_nodes) > 1:
+                raise ValueError(
+                    "Linear partition cannot have more than one output node"
+                )
+            linear_node = partition.output_nodes[0]
+            if linear_node.op != "call_function" or linear_node.target not in (
+                torch.ops.aten.linear.default,
+            ):
+                raise ValueError(f"{linear_node} is not an aten linear operator")
+            # skip annotation if it is already annotated
+            if _skip_annotate([linear_node], filter_fn):
+                continue
+            self._annotate_linear_node_helper(linear_node, True, quantization_config)
+
+    def _annotate_linear_unary(
+        self,
+        gm: torch.fx.GraphModule,
+        quantization_config: Optional[QuantizationConfig],
+        filter_fn: Optional[FilterFn] = None,
+    ) -> None:
+        postop_list = [
+            torch.nn.ReLU,
+            torch.nn.LeakyReLU,
+            torch.nn.Tanh,
+            torch.nn.GELU,
+        ]
+        fused_partitions: list[tuple] = []
+        for postop in postop_list:
+            fused_partitions = fused_partitions + find_sequential_partitions(
+                gm, [torch.nn.Linear, postop]
+            )
+        for fused_partition in fused_partitions:
+            linear_partition, unary_partition = fused_partition
+            linear_node, unary_node = self._get_output_nodes_of_partitions(
+                [linear_partition, unary_partition]
+            )
+            if linear_node.op != "call_function" or linear_node.target not in (
+                torch.ops.aten.linear.default,
+            ):
+                continue
+            if _skip_annotate([unary_node, linear_node], filter_fn):
+                continue
+
+            if quantization_config is None:
+                _annotate_nodes_not_quantize([linear_node, unary_node])
+                continue
+
+            self._annotate_linear_node_helper(linear_node, False, quantization_config)
+            unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
+                _annotated=True,
+                _is_output_of_quantized_pattern=True,
+            )
+
+    def _annotate_linear_binary_unary(
+        self,
+        gm: torch.fx.GraphModule,
+        quantization_config: Optional[QuantizationConfig],
+        filter_fn: Optional[FilterFn] = None,
+    ) -> None:
+        # linear + binary_op + (optional) unary op
+        binary_op_list = [operator.add]
+        unary_op_list = [torch.nn.ReLU, None]
+        combinations = itertools.product(binary_op_list, unary_op_list)
+        for binary_op, unary_op in combinations:
+            has_unary = unary_op is not None
+            seq_partition = [torch.nn.Linear, binary_op]
+            if has_unary:
+                seq_partition.append(unary_op)
+            fused_partitions = find_sequential_partitions(gm, seq_partition)
+            for fused_partition in fused_partitions:
+                unary_partition, unary_node = None, None
+                if has_unary:
+                    (
+                        linear_partition,
+                        binary_partition,
+                        unary_partition,
+                    ) = fused_partition
+                    (
+                        linear_node,
+                        binary_node,
+                        unary_node,
+                    ) = self._get_output_nodes_of_partitions(
+                        [linear_partition, binary_partition, unary_partition]
+                    )
+                else:
+                    linear_partition, binary_partition = fused_partition
+                    linear_node, binary_node = self._get_output_nodes_of_partitions(
+                        [linear_partition, binary_partition]
+                    )
+                if len(linear_node.users) != 1:
+                    # The linear node should only have 1 user node
+                    continue
+                (
+                    linear_node_idx,
+                    extra_input_node_idx,
+                ) = self._get_input_idx_for_binary_node(linear_node, binary_node)
+                if (linear_node_idx is None) or (extra_input_node_idx is None):
+                    continue
+                if linear_node != binary_node.args[linear_node_idx]:
+                    raise ValueError(
+                        f"{linear_node} doesn't match input of binary node"
+                    )
+                assert isinstance(linear_node, Node)
+                if (
+                    linear_node.op != "call_function"
+                    or linear_node.target != torch.ops.aten.linear.default
+                ):
+                    # No linear node found to be fused with add
+                    continue
+                node_list = (
+                    [binary_node, linear_node]
+                    if unary_node is None
+                    else [unary_node, binary_node, linear_node]
+                )
+                if _skip_annotate(node_list, filter_fn):
+                    continue
+
+                if quantization_config is None:
+                    _annotate_nodes_not_quantize(node_list)
+                    continue
+
+                self._annotate_linear_node_helper(
+                    linear_node, False, quantization_config
+                )
+                # We don't insert q-dq before the binary input node due to accuracy issues
+                binary_node.meta[QUANT_ANNOTATION_KEY] = (
+                    _X86InductorQuantizationAnnotation(
+                        input_qspec_map={},
+                        _annotated=True,
+                        _is_output_of_quantized_pattern=(not has_unary),
+                    )
+                )
+                if unary_node is not None:
+                    unary_node.meta[QUANT_ANNOTATION_KEY] = (
+                        _X86InductorQuantizationAnnotation(
+                            _annotated=True,
+                            _is_output_of_quantized_pattern=True,
+                        )
+                    )
+
+    def validate(self, model: torch.fx.GraphModule) -> None:
+        pass
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/xnnpack_quantizer.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/xnnpack_quantizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..63c87a568684686635958c98ecc0d547a9ecc9a9
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/xnnpack_quantizer.py
@@ -0,0 +1,450 @@
+# mypy: allow-untyped-defs
+from __future__ import annotations
+
+import copy
+import functools
+import typing_extensions
+from typing import Any, Callable, Optional, TYPE_CHECKING
+
+import torch
+import torch._dynamo as torchdynamo
+import torch.nn.functional as F
+from torch.ao.quantization.fake_quantize import (
+    FakeQuantize,
+    FusedMovingAvgObsFakeQuantize,
+)
+from torch.ao.quantization.observer import (
+    HistogramObserver,
+    MinMaxObserver,
+    MovingAverageMinMaxObserver,
+    MovingAveragePerChannelMinMaxObserver,
+    PerChannelMinMaxObserver,
+    PlaceholderObserver,
+)
+from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
+from torch.ao.quantization.quantizer.utils import _get_module_name_filter
+from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
+    _convert_scalars_to_attrs,
+    OP_TO_ANNOTATOR,
+    OperatorConfig,
+    OperatorPatternType,
+    propagate_annotation,
+    QuantizationConfig,
+)
+from torch.fx._compatibility import compatibility
+
+
+if TYPE_CHECKING:
+    from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
+    from torch.fx import Node
+
+
+__all__ = [
+    "XNNPACKQuantizer",
+    "get_symmetric_quantization_config",
+]
+
+
+def _get_dynamo_graph(function: Callable, inputs) -> torch.fx.Graph:
+    gm, _ = torchdynamo.export(function, aten_graph=True)(*inputs)
+    gm.graph.eliminate_dead_code()
+    return gm.graph
+
+
+def _get_linear_patterns(input_size: list[int]):
+    in_channels = input_size[-1]
+    out_channels = 8  # hard coding but this should not matter
+    weight = torch.ones((out_channels, in_channels))
+    bias = torch.ones((out_channels,))
+    act = torch.ones(input_size)
+
+    def linear_op(act, weight, bias=None):
+        return F.linear(act, weight, bias)
+
+    pattern_w_bias = _get_dynamo_graph(linear_op, (act, weight, bias))
+    pattern_wo_bias = _get_dynamo_graph(linear_op, (act, weight))
+    return [pattern_w_bias, pattern_wo_bias]
+
+
+def _supported_symmetric_quantized_operators() -> dict[str, list[OperatorPatternType]]:
+    supported_operators: dict[str, list[OperatorPatternType]] = {
+        # Both conv and linear should be able to handle relu + hardtanh fusion since
+        # those are clamp ops
+        "conv2d": [
+            [torch.nn.Conv2d, torch.nn.ReLU],
+            [torch.nn.Conv2d, F.relu],
+            [F.conv2d, torch.nn.ReLU],
+            [F.conv2d, F.relu],
+        ],
+        "linear": [[torch.nn.Linear], [F.linear]],
+        "add": [[torch.add]],
+        "adaptive_avg_pool2d": [
+            [torch.nn.AdaptiveAvgPool2d],
+            [F.adaptive_avg_pool2d],
+        ],
+    }
+    return copy.deepcopy(supported_operators)
+
+
+def _get_supported_symmetric_config_and_operators() -> list[OperatorConfig]:
+    supported_config_and_operators: list[OperatorConfig] = []
+    for quantization_config in [
+        get_symmetric_quantization_config(),
+        get_symmetric_quantization_config(is_qat=True),
+        get_symmetric_quantization_config(is_per_channel=True),
+        get_symmetric_quantization_config(is_per_channel=True, is_qat=True),
+    ]:
+        ops = _supported_symmetric_quantized_operators()
+        supported_config_and_operators.extend(
+            OperatorConfig(quantization_config, pattern_list)
+            for pattern_list in ops.values()
+        )
+    return copy.deepcopy(supported_config_and_operators)
+
+
+@functools.lru_cache
+def get_symmetric_quantization_config(
+    is_per_channel: bool = False,
+    is_qat: bool = False,
+    is_dynamic: bool = False,
+    act_qmin: int = -128,
+    act_qmax: int = 127,
+    weight_qmin: int = -127,
+    weight_qmax: int = 127,
+):
+    extra_args: dict[str, Any] = {"eps": 2**-12}
+    if is_qat:
+        if is_dynamic:
+            act_observer_or_fake_quant_ctr = FakeQuantize
+            dynamic_quant_observer = MovingAverageMinMaxObserver.with_args(
+                averaging_constant=1
+            )
+            extra_args["observer"] = dynamic_quant_observer
+        else:
+            act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize  # type: ignore[assignment]
+    else:
+        if is_dynamic:
+            act_observer_or_fake_quant_ctr = PlaceholderObserver  # type: ignore[assignment]
+        else:
+            act_observer_or_fake_quant_ctr = HistogramObserver  # type: ignore[assignment]
+
+    act_quantization_spec = QuantizationSpec(
+        dtype=torch.int8,
+        quant_min=act_qmin,
+        quant_max=act_qmax,
+        qscheme=torch.per_tensor_affine,
+        is_dynamic=is_dynamic,
+        observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
+            **extra_args,
+        ),
+    )
+    weight_qscheme = (
+        torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric
+    )
+    weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
+        MinMaxObserver
+    )
+    if is_qat:
+        # TODO: qat + per channel?
+        weight_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize
+    elif is_per_channel:
+        weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver
+
+    extra_args: dict[str, Any] = {"eps": 2**-12}
+    if is_qat:
+        if weight_qscheme == torch.per_tensor_symmetric:
+            extra_args["observer"] = MovingAverageMinMaxObserver
+        else:
+            extra_args["observer"] = MovingAveragePerChannelMinMaxObserver  # type: ignore[dict-item]
+    weight_quantization_spec = QuantizationSpec(
+        dtype=torch.int8,
+        quant_min=weight_qmin,
+        quant_max=weight_qmax,
+        qscheme=weight_qscheme,
+        ch_axis=0,
+        is_dynamic=False,
+        observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
+            **extra_args
+        ),
+    )
+
+    bias_quantization_spec = None
+    if is_dynamic:
+        quantization_config = QuantizationConfig(
+            act_quantization_spec,
+            None,
+            weight_quantization_spec,
+            bias_quantization_spec,
+            is_qat,
+        )
+    else:
+        quantization_config = QuantizationConfig(
+            act_quantization_spec,
+            act_quantization_spec,
+            weight_quantization_spec,
+            bias_quantization_spec,
+            is_qat,
+        )
+    return quantization_config
+
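+# Illustrative sketch (added for clarity, not part of the original source): how the
+# flags of get_symmetric_quantization_config map to the specs built above.
+#
+#     cfg = get_symmetric_quantization_config(is_per_channel=True)
+#     # cfg.weight.qscheme is torch.per_channel_symmetric, observed by
+#     # PerChannelMinMaxObserver; activations use HistogramObserver.
+#
+#     dyn_cfg = get_symmetric_quantization_config(is_dynamic=True)
+#     # dyn_cfg.output_activation is None; activations use PlaceholderObserver
+#     # with is_dynamic=True.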
+
+def _get_supported_config_and_operators() -> list[OperatorConfig]:
+    return _get_supported_symmetric_config_and_operators()
+
+
+def _get_module_type_filter(tp: Callable):
+    """Get the module_type_filter function for a given module type, the filter accepts
+    a node and checks if the node comes from a module that has certain module type
+
+    For example:
+        node: linear_op = call_function[...](...)  # comes from a module with type Block -> Sub -> Linear
+
+
+    >> module_type_filter = _get_module_type_filter(Sub)  # submodule with type `Sub`, under the `Block` submodule
+    >> print(module_type_filter(node))
+    True  # the node is from the submodule `Sub` (same for `Block` and `Linear` as well)
+    """
+
+    tp_str = tp.__module__ + "." + tp.__qualname__
+
+    def module_type_filter(n: Node) -> bool:
+        # example: {
+        #     'L__self___sub': ("L['self'].sub", ),
+        #     'L__self___sub_linear': ("L['self'].sub.linear", )
+        # }
+        nn_module_stack = n.meta.get("nn_module_stack", {})
+        types = []
+        for _, t in nn_module_stack.values():
+            # export() returns str, but older APIs (e.g. capture_pre_autograd_graph)
+            # return type. Handle both cases.
+            if isinstance(t, type):
+                t = t.__module__ + "." + t.__qualname__
+            types.append(t)
+        return tp_str in types
+
+    return module_type_filter
+
+
+def _get_not_module_type_or_name_filter(
+    tp_list: list[Callable], module_name_list: list[str]
+) -> Callable[[Node], bool]:
+    module_type_filters = [_get_module_type_filter(tp) for tp in tp_list]
+    module_name_list_filters = [_get_module_name_filter(m) for m in module_name_list]
+
+    def not_module_type_or_name_filter(n: Node) -> bool:
+        return not any(f(n) for f in module_type_filters + module_name_list_filters)
+
+    return not_module_type_or_name_filter
+
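+# Illustrative note (added for clarity): with tp_list=[Sub] and
+# module_name_list=["blocks.fc"], the returned filter accepts only nodes that come
+# neither from a `Sub` submodule nor from the module named "blocks.fc". It is used
+# further below to apply the global config to everything not already covered by a
+# more specific module-type or module-name config.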
+
+@compatibility(is_backward_compatible=False)
+@typing_extensions.deprecated(
+    "XNNPACKQuantizer is deprecated! Please use xnnpack quantizer in "
+    "ExecuTorch (https://github.com/pytorch/executorch/tree/main/backends/xnnpack/quantizer) instead."
+)
+class XNNPACKQuantizer(Quantizer):
+    """
+    !!! DEPRECATED !!!
+    XNNPACKQuantizer is deprecated and will be removed in the future.
+    It has been moved to executorch.backends.xnnpack.quantizer.xnnpack_quantizer.XNNPACKQuantizer.
+    Please use the new quantizer instead.
+    """
+
+    supported_config_and_operators = _get_supported_config_and_operators()
+    STATIC_QAT_ONLY_OPS = [
+        "conv_bn_relu",
+        "conv_bn",
+        "conv_transpose_bn_relu",
+        "conv_transpose_bn",
+    ]
+
+    # static quantization ops (both PTQ and QAT)
+    # Preserve the order that fusions come before singular ops
+    STATIC_OPS = [
+        "linear_relu",
+        "linear",
+        "conv_relu",
+        "conv",
+        "conv_transpose_relu",
+        "adaptive_avg_pool2d",
+        # TODO: move this to BoltNNQuantizer?
+        "gru_io_only",
+        "add_relu",
+        "add",
+        "mul_relu",
+        "mul",
+        "cat",
+    ]
+
+    DYNAMIC_OPS = [
+        "linear",
+    ]
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.global_config: Optional[QuantizationConfig] = None
+        self.operator_type_config: dict[
+            torch._ops.OpOverloadPacket, Optional[QuantizationConfig]
+        ] = {}
+        self.module_type_config: dict[Callable, Optional[QuantizationConfig]] = {}
+        self.module_name_config: dict[str, Optional[QuantizationConfig]] = {}
+
+    @classmethod
+    def get_supported_quantization_configs(cls) -> list[QuantizationConfig]:
+        op_configs: set[QuantizationConfig] = {
+            spec for spec, _ in cls.supported_config_and_operators
+        }
+        return list(op_configs)
+
+    @classmethod
+    def get_supported_operator_for_quantization_config(
+        cls, quantization_config: Optional[QuantizationConfig]
+    ) -> list[OperatorPatternType]:
+        if quantization_config is None:
+            all_ops = []
+            for _, ops in cls.supported_config_and_operators:
+                all_ops.extend(ops)
+            return all_ops
+
+        for config, ops in cls.supported_config_and_operators:
+            # note: this assumes each entry in cls.supported_config_and_operators
+            # corresponds to one spec, e.g. we don't have
+            # [(spec1, op_list1), (spec1, op_list2), (spec2, op_list3)]
+            # where the first and second entry have the same spec but did not
+            # merge the op list
+            if config == quantization_config:
+                return ops
+        return []
+
+    def set_global(self, quantization_config: QuantizationConfig) -> XNNPACKQuantizer:
+        self.global_config = quantization_config
+        return self
+
+    def set_operator_type(
+        self,
+        operator_type: torch._ops.OpOverloadPacket,
+        quantization_config: QuantizationConfig,
+    ) -> XNNPACKQuantizer:
+        self.operator_type_config[operator_type] = quantization_config
+        return self
+
+    def set_module_type(
+        self, module_type: Callable, quantization_config: QuantizationConfig
+    ):
+        """Set quantization_config for a submodule with type: `module_type`, for example:
+        quantizer.set_module_name(Sub) or quantizer.set_module_name(nn.Linear), it will quantize all supported operator/operator
+        patterns in the submodule with this module type with the given `quantization_config`
+        """
+        self.module_type_config[module_type] = quantization_config
+        return self
+
+    def set_module_name(
+        self, module_name: str, quantization_config: Optional[QuantizationConfig]
+    ):
+        """Set quantization_config for a submodule with name: `module_name`, for example:
+        quantizer.set_module_name("blocks.sub"), it will quantize all supported operator/operator
+        patterns in the submodule with this module name with the given `quantization_config`
+        """
+        assert quantization_config is not None, (
+            " quantization_config == None is not supported yet"
+        )
+        self.module_name_config[module_name] = quantization_config
+        return self
+
+    def transform_for_annotation(
+        self, model: torch.fx.GraphModule
+    ) -> torch.fx.GraphModule:
+        """Transforms scalar values to tensor attributes"""
+        return _convert_scalars_to_attrs(model)
+
+    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+        """just handling global spec for now"""
+        # hacked for handling dynamic linear quant. will fix later.
+        if self.global_config and self.global_config.input_activation.is_dynamic:  # type: ignore[union-attr]
+            model = self._annotate_for_dynamic_quantization_config(model)
+        else:
+            model = self._annotate_for_static_quantization_config(model)
+        propagate_annotation(model)
+        return model
+
+    def _annotate_all_static_patterns(
+        self,
+        model: torch.fx.GraphModule,
+        quantization_config: Optional[QuantizationConfig],
+        filter_fn: Optional[Callable[[Node], bool]] = None,
+    ) -> torch.fx.GraphModule:
+        # TODO: implement support for None to cancel out previous annotations
+        if quantization_config is None:
+            return model
+
+        if quantization_config.is_qat:
+            for op in self.STATIC_QAT_ONLY_OPS:
+                OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
+        for op in self.STATIC_OPS:
+            OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
+        return model
+
+    def _annotate_all_dynamic_patterns(
+        self,
+        model: torch.fx.GraphModule,
+        quantization_config: Optional[QuantizationConfig],
+        filter_fn: Optional[Callable[[Node], bool]] = None,
+    ) -> torch.fx.GraphModule:
+        # TODO: implement support for None to cancel out previous annotations
+        if quantization_config is None:
+            return model
+
+        for op in self.DYNAMIC_OPS:
+            OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
+        return model
+
+    def _annotate_for_static_quantization_config(
+        self, model: torch.fx.GraphModule
+    ) -> torch.fx.GraphModule:
+        module_name_list = list(self.module_name_config.keys())
+        for module_name, config in self.module_name_config.items():
+            self._annotate_all_static_patterns(
+                model, config, _get_module_name_filter(module_name)
+            )
+
+        tp_list = list(self.module_type_config.keys())
+        for module_type, config in self.module_type_config.items():
+            self._annotate_all_static_patterns(
+                model, config, _get_module_type_filter(module_type)
+            )
+
+        self._annotate_all_static_patterns(
+            model,
+            self.global_config,
+            _get_not_module_type_or_name_filter(tp_list, module_name_list),
+        )
+        return model
+
+    def _annotate_for_dynamic_quantization_config(
+        self, model: torch.fx.GraphModule
+    ) -> torch.fx.GraphModule:
+        module_name_list = list(self.module_name_config.keys())
+        for module_name, config in self.module_name_config.items():
+            self._annotate_all_dynamic_patterns(
+                model, config, _get_module_name_filter(module_name)
+            )
+
+        tp_list = list(self.module_type_config.keys())
+        for module_type, config in self.module_type_config.items():
+            self._annotate_all_dynamic_patterns(
+                model, config, _get_module_type_filter(module_type)
+            )
+
+        self._annotate_all_dynamic_patterns(
+            model,
+            self.global_config,
+            _get_not_module_type_or_name_filter(tp_list, module_name_list),
+        )
+        return model
+
+    def validate(self, model: torch.fx.GraphModule) -> None:
+        pass
+
+    @classmethod
+    def get_supported_operators(cls) -> list[OperatorConfig]:
+        return cls.supported_config_and_operators
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb3914524edba18a19a732d6457fa24cb05107a5
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py
@@ -0,0 +1,1127 @@
+# mypy: allow-untyped-defs
+import itertools
+import typing
+from dataclasses import dataclass
+from typing import Callable, NamedTuple, Optional
+
+import torch
+import torch.nn.functional as F
+from torch._subclasses import FakeTensor
+from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
+from torch.ao.quantization.pt2e.export_utils import _WrapperModule
+from torch.ao.quantization.pt2e.utils import (
+    _get_aten_graph_module_for_pattern,
+    _is_conv_node,
+    _is_conv_transpose_node,
+)
+from torch.ao.quantization.quantizer import (
+    QuantizationAnnotation,
+    QuantizationSpec,
+    SharedQuantizationSpec,
+)
+from torch.ao.quantization.quantizer.utils import (
+    _annotate_input_qspec_map,
+    _annotate_output_qspec,
+)
+from torch.fx import Node
+from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
+    SubgraphMatcherWithNameNodeMap,
+)
+from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
+
+
+__all__ = [
+    "OperatorConfig",
+    "OperatorPatternType",
+    "QuantizationConfig",
+    "get_input_act_qspec",
+    "get_output_act_qspec",
+    "get_weight_qspec",
+    "get_bias_qspec",
+    "OP_TO_ANNOTATOR",
+    "propagate_annotation",
+]
+
+
+# In the absence of a better name, we are just winging it with QuantizationConfig
+@dataclass(eq=True, frozen=True)
+class QuantizationConfig:
+    input_activation: Optional[QuantizationSpec]
+    output_activation: Optional[QuantizationSpec]
+    weight: Optional[QuantizationSpec]
+    bias: Optional[QuantizationSpec]
+    # TODO: remove, since we can use observer_or_fake_quant_ctr to express this
+    is_qat: bool = False
+
+
+# Use Annotated because list[Callable].__module__ is read-only.
+OperatorPatternType = typing.Annotated[list[Callable], None]
+OperatorPatternType.__module__ = (
+    "torch.ao.quantization.quantizer.xnnpack_quantizer_utils"
+)
+
+AnnotatorType = Callable[
+    [
+        torch.fx.GraphModule,
+        Optional[QuantizationConfig],
+        Optional[Callable[[Node], bool]],
+    ],
+    Optional[list[list[Node]]],
+]
+OP_TO_ANNOTATOR: dict[str, AnnotatorType] = {}
+
+
+def register_annotator(op: str) -> Callable[[AnnotatorType], None]:
+    def decorator(annotator: AnnotatorType) -> None:
+        OP_TO_ANNOTATOR[op] = annotator
+
+    return decorator
+
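+# Illustrative example (assumption, not part of the original source): registering a
+# custom annotator makes it addressable by name through OP_TO_ANNOTATOR.
+#
+#     @register_annotator("my_op")
+#     def _annotate_my_op(gm, quantization_config, filter_fn=None):
+#         ...  # hypothetical annotator body
+#
+#     # OP_TO_ANNOTATOR["my_op"] now refers to _annotate_my_op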
+
+class OperatorConfig(NamedTuple):
+    # TODO: fix List[str] to be List[List[Union[nn.Module, FunctionType, BuiltinFunctionType]]]
+    # Basically we are mapping a quantization config to a list of patterns.
+    # A pattern is defined as a list of nn module, function or builtin function names,
+    # e.g. [nn.Conv2d, torch.relu, torch.add].
+    # We have not resolved whether fusion should be considered an internal detail of the
+    # quantizer, in which case it would not need to be communicated to the user.
+    # Note that this pattern is not very informative, since it does not really
+    # tell us the graph structure resulting from the list of ops.
+    config: QuantizationConfig
+    operators: list[OperatorPatternType]
+
+
+def _is_annotated(nodes: list[Node]):
+    """
+    Given a list of nodes (that represents an operator pattern),
+    check whether any of the nodes is annotated; return True if so,
+    otherwise return False.
+    """
+    annotated = False
+    for node in nodes:
+        annotated = annotated or (
+            "quantization_annotation" in node.meta
+            and node.meta["quantization_annotation"]._annotated
+        )
+    return annotated
+
+
+def _mark_nodes_as_annotated(nodes: list[Node]):
+    for node in nodes:
+        if node is not None:
+            if "quantization_annotation" not in node.meta:
+                node.meta["quantization_annotation"] = QuantizationAnnotation()
+            node.meta["quantization_annotation"]._annotated = True
+
+
+def get_input_act_qspec(quantization_config: Optional[QuantizationConfig]):
+    if quantization_config is None:
+        return None
+    if quantization_config.input_activation is None:
+        return None
+    quantization_spec: QuantizationSpec = quantization_config.input_activation
+    assert quantization_spec.qscheme in [
+        torch.per_tensor_affine,
+        torch.per_tensor_symmetric,
+    ]
+    return quantization_spec
+
+
+def get_output_act_qspec(quantization_config: Optional[QuantizationConfig]):
+    if quantization_config is None:
+        return None
+    if quantization_config.output_activation is None:
+        return None
+    quantization_spec: QuantizationSpec = quantization_config.output_activation
+    assert quantization_spec.qscheme in [
+        torch.per_tensor_affine,
+        torch.per_tensor_symmetric,
+    ]
+    return quantization_spec
+
+
+def get_weight_qspec(quantization_config: Optional[QuantizationConfig]):
+    if quantization_config is None:
+        return None
+    if quantization_config.weight is None:
+        return None
+    quantization_spec: QuantizationSpec = quantization_config.weight
+    if quantization_spec.qscheme not in [
+        torch.per_tensor_symmetric,
+        torch.per_channel_symmetric,
+        None,
+    ]:
+        raise ValueError(
+            f"Unsupported quantization_spec {quantization_spec} for weight"
+        )
+    return quantization_spec
+
+
+def get_bias_qspec(quantization_config: Optional[QuantizationConfig]):
+    if quantization_config is None:
+        return None
+    if quantization_config.bias is None:
+        return None
+    quantization_spec: QuantizationSpec = quantization_config.bias
+    assert quantization_spec.dtype == torch.float, (
+        "Only float dtype is supported for bias right now"
+    )
+    return quantization_spec
+
+
+@register_annotator("linear")
+def _annotate_linear(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[list[list[Node]]]:
+    annotated_partitions = []
+    input_act_qspec = get_input_act_qspec(quantization_config)
+    output_act_qspec = get_output_act_qspec(quantization_config)
+    weight_qspec = get_weight_qspec(quantization_config)
+    bias_qspec = get_bias_qspec(quantization_config)
+    for node in gm.graph.nodes:
+        if node.op != "call_function" or node.target != torch.ops.aten.linear.default:
+            continue
+        if filter_fn and not filter_fn(node):
+            continue
+        act_node = node.args[0]
+        weight_node = node.args[1]
+        bias_node = None
+        if len(node.args) > 2:
+            bias_node = node.args[2]
+
+        if _is_annotated([node]) is False:  # type: ignore[list-item]
+            _annotate_input_qspec_map(
+                node,
+                act_node,
+                input_act_qspec,
+            )
+            _annotate_input_qspec_map(
+                node,
+                weight_node,
+                weight_qspec,
+            )
+            nodes_to_mark_annotated = [node, weight_node]
+            if bias_node:
+                _annotate_input_qspec_map(
+                    node,
+                    bias_node,
+                    bias_qspec,
+                )
+                nodes_to_mark_annotated.append(bias_node)
+            _annotate_output_qspec(node, output_act_qspec)
+            _mark_nodes_as_annotated(nodes_to_mark_annotated)
+            annotated_partitions.append(nodes_to_mark_annotated)
+
+    return annotated_partitions
+
+
+@register_annotator("linear_relu")
+def _annotate_linear_relu(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[list[list[Node]]]:
+    annotated_partitions = []
+    input_act_qspec = get_input_act_qspec(quantization_config)
+    output_act_qspec = get_output_act_qspec(quantization_config)
+    weight_qspec = get_weight_qspec(quantization_config)
+    bias_qspec = get_bias_qspec(quantization_config)
+    for node in gm.graph.nodes:
+        if node.op != "call_function" or node.target not in [
+            torch.ops.aten.relu.default,
+            torch.ops.aten.relu_.default,
+        ]:
+            continue
+        relu_node = node
+        maybe_linear_node = node.args[0]
+        if (
+            not isinstance(maybe_linear_node, Node)
+            or maybe_linear_node.op != "call_function"
+            or maybe_linear_node.target != torch.ops.aten.linear.default
+        ):
+            continue
+
+        linear_node = maybe_linear_node
+        if len(linear_node.users) > 1:
+            # if the linear node has multiple users, it can't be fused with relu
+            continue
+
+        input_qspec_map = {}
+        input_act = linear_node.args[0]
+        assert isinstance(input_act, Node)
+        input_qspec_map[input_act] = input_act_qspec
+
+        weight = linear_node.args[1]
+        assert isinstance(weight, Node)
+        input_qspec_map[weight] = weight_qspec
+
+        # adding weight node to the partition as well
+        partition = [relu_node, linear_node, weight]
+        bias = linear_node.args[2] if len(linear_node.args) > 2 else None
+        if isinstance(bias, Node):
+            input_qspec_map[bias] = bias_qspec
+            partition.append(bias)
+
+        if _is_annotated(partition):
+            continue
+
+        if filter_fn and any(not filter_fn(n) for n in partition):
+            continue
+
+        linear_node.meta["quantization_annotation"] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            _annotated=True,
+        )
+        relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
+            output_qspec=output_act_qspec,
+            _annotated=True,
+        )
+        _mark_nodes_as_annotated(partition)
+        annotated_partitions.append(partition)
+    return annotated_partitions
+
+
+@register_annotator("conv")
+def _annotate_conv(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[list[list[Node]]]:
+    annotated_partitions = []
+    for n in gm.graph.nodes:
+        if n.op != "call_function" or n.target not in [
+            torch.ops.aten.conv1d.default,
+            torch.ops.aten.conv2d.default,
+        ]:
+            continue
+        conv_node = n
+
+        input_qspec_map = {}
+        input_act = conv_node.args[0]
+        assert isinstance(input_act, Node)
+        input_qspec_map[input_act] = get_input_act_qspec(quantization_config)
+
+        weight = conv_node.args[1]
+        assert isinstance(weight, Node)
+        input_qspec_map[weight] = get_weight_qspec(quantization_config)
+
+        # adding weight node to the partition as well
+        partition = [conv_node, conv_node.args[1]]
+
+        bias = conv_node.args[2] if len(conv_node.args) > 2 else None
+        if isinstance(bias, Node):
+            input_qspec_map[bias] = get_bias_qspec(quantization_config)
+            partition.append(bias)
+
+        if _is_annotated(partition):
+            continue
+
+        if filter_fn and any(not filter_fn(n) for n in partition):
+            continue
+
+        conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            output_qspec=get_output_act_qspec(quantization_config),
+            _annotated=True,
+        )
+        _mark_nodes_as_annotated(partition)
+        annotated_partitions.append(partition)
+    return annotated_partitions
+
+
+def _do_annotate_conv_relu(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+    is_conv_transpose: bool = False,
+):
+    annotated_partitions = []
+    for n in gm.graph.nodes:
+        if n.op != "call_function" or n.target not in [
+            torch.ops.aten.relu.default,
+            torch.ops.aten.relu_.default,
+        ]:
+            continue
+        relu_node = n
+        maybe_conv_node = n.args[0]
+
+        is_conv_node = _is_conv_transpose_node if is_conv_transpose else _is_conv_node
+        if not isinstance(maybe_conv_node, Node) or not is_conv_node(maybe_conv_node):
+            continue
+        conv_node = maybe_conv_node
+
+        if len(conv_node.users) > 1:
+            # relu shouldn't be fused into conv if the convolution
+            # has other users
+            continue
+
+        input_qspec_map = {}
+        input_act = conv_node.args[0]
+        assert isinstance(input_act, Node)
+        input_qspec_map[input_act] = get_input_act_qspec(quantization_config)
+
+        weight = conv_node.args[1]
+        assert isinstance(weight, Node)
+        input_qspec_map[weight] = get_weight_qspec(quantization_config)
+
+        # adding weight node to the partition as well
+        partition = [relu_node, conv_node, conv_node.args[1]]
+        bias = conv_node.args[2] if len(conv_node.args) > 2 else None
+        if isinstance(bias, Node):
+            input_qspec_map[bias] = get_bias_qspec(quantization_config)
+            partition.append(bias)
+
+        if _is_annotated(partition):
+            continue
+
+        if filter_fn and any(not filter_fn(n) for n in partition):
+            continue
+
+        conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map, _annotated=True
+        )
+        relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
+            output_qspec=get_output_act_qspec(quantization_config),  # type: ignore[arg-type]
+            _annotated=True,
+        )
+        _mark_nodes_as_annotated(partition)
+        annotated_partitions.append(partition)
+    return annotated_partitions
+
+
+@register_annotator("conv_relu")
+def _annotate_conv_relu(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[list[list[Node]]]:
+    return _do_annotate_conv_relu(
+        gm, quantization_config, filter_fn, is_conv_transpose=False
+    )
+
+
+@register_annotator("conv_transpose_relu")
+def _annotate_conv_transpose_relu(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[list[list[Node]]]:
+    return _do_annotate_conv_relu(
+        gm, quantization_config, filter_fn, is_conv_transpose=True
+    )
+
+
+@register_annotator("conv_bn")
+def _annotate_conv_bn(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[list[list[Node]]]:
+    """
+    Find conv + batchnorm partitions
+    Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv.
+    """
+    return _do_annotate_conv_bn(gm, quantization_config, filter_fn, has_relu=False)
+
+
+@register_annotator("conv_bn_relu")
+def _annotate_conv_bn_relu(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[list[list[Node]]]:
+    """
+    Find conv + batchnorm + relu partitions
+    Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv.
+    """
+    return _do_annotate_conv_bn(gm, quantization_config, filter_fn, has_relu=True)
+
+
+@register_annotator("conv_transpose_bn")
+def _annotate_conv_transpose_bn(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[list[list[Node]]]:
+    """
+    Find conv_transpose + batchnorm partitions
+    Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv.
+    """
+    return _do_annotate_conv_bn(
+        gm, quantization_config, filter_fn, has_relu=False, is_conv_transpose=True
+    )
+
+
+@register_annotator("conv_transpose_bn_relu")
+def _annotate_conv_transpose_bn_relu(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[list[list[Node]]]:
+    """
+    Find conv_transpose + batchnorm + relu partitions
+    Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv.
+    """
+    return _do_annotate_conv_bn(
+        gm, quantization_config, filter_fn, has_relu=True, is_conv_transpose=True
+    )
+
+
+def _do_annotate_conv_bn(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]],
+    has_relu: bool,
+    is_conv_transpose: bool = False,
+) -> list[list[Node]]:
+    """
+    Given a function that takes in a `conv_fn` and returns a conv-bn[-relu] pattern,
+    return a list of annotated partitions.
+
+    The output of the pattern must include a dictionary from string name to node
+    for the following names: "input", "conv", "weight", "bias", and "output".
+    """
+
+    # Example inputs for conv-bn1d patterns
+    _conv1d_bn_example_inputs = (
+        torch.randn(1, 1, 3),  # x
+        torch.randn(1, 1, 1),  # conv_weight
+        torch.randn(1),  # conv_bias
+        torch.randn(1),  # bn_weight
+        torch.randn(1),  # bn_bias
+        torch.randn(1),  # bn_running_mean
+        torch.randn(1),  # bn_running_var
+    )
+
+    # Example inputs for conv-bn2d patterns
+    _conv2d_bn_example_inputs = (
+        torch.randn(1, 1, 3, 3),  # x
+        torch.randn(1, 1, 1, 1),  # conv_weight
+        torch.randn(1),  # conv_bias
+        torch.randn(1),  # bn_weight
+        torch.randn(1),  # bn_bias
+        torch.randn(1),  # bn_running_mean
+        torch.randn(1),  # bn_running_var
+    )
+
+    def get_pattern(conv_fn: Callable, relu_is_inplace: bool):
+        def _conv_bn(x, conv_weight, conv_bias, bn_weight, bn_bias, bn_rm, bn_rv):
+            conv = conv_fn(x, conv_weight, conv_bias)
+            bn = F.batch_norm(conv, bn_rm, bn_rv, bn_weight, bn_bias, training=True)
+            if has_relu:
+                output = F.relu_(bn) if relu_is_inplace else F.relu(bn)
+            else:
+                output = bn
+            return output, {
+                "input": x,
+                "conv": conv,
+                "weight": conv_weight,
+                "bias": conv_bias,
+                "output": output,
+            }
+
+        return _WrapperModule(_conv_bn)
+
+    # Needed for matching, otherwise the matches get filtered out due to unused
+    # nodes returned by batch norm
+    gm.graph.eliminate_dead_code()
+    gm.recompile()
+
+    matches = []
+    if is_conv_transpose:
+        combinations = [
+            (F.conv_transpose1d, _conv1d_bn_example_inputs),
+            (F.conv_transpose2d, _conv2d_bn_example_inputs),
+        ]
+    else:
+        combinations = [
+            (F.conv1d, _conv1d_bn_example_inputs),  # type: ignore[list-item]
+            (F.conv2d, _conv2d_bn_example_inputs),  # type: ignore[list-item]
+        ]
+
+    # Add `is_cuda` and `relu_is_inplace` dimensions
+    combinations = itertools.product(  # type: ignore[assignment]
+        combinations,
+        [True, False] if torch.cuda.is_available() else [False],  # is_cuda
+        [True, False] if has_relu else [False],  # relu_is_inplace
+    )
+
+    # Match against all conv dimensions and cuda variants
+    for (conv_fn, example_inputs), is_cuda, relu_is_inplace in combinations:  # type: ignore[misc]
+        pattern = get_pattern(conv_fn, relu_is_inplace)  # type: ignore[has-type]
+        pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs, is_cuda)  # type: ignore[has-type]
+        pattern.graph.eliminate_dead_code()
+        pattern.recompile()
+        matcher = SubgraphMatcherWithNameNodeMap(pattern, ignore_literals=True)
+        matches.extend(matcher.match(gm.graph))
+
+    # Annotate nodes returned in the matches
+    annotated_partitions = []
+    for match in matches:
+        name_node_map = match.name_node_map
+        input_node = name_node_map["input"]
+        conv_node = name_node_map["conv"]
+        weight_node = name_node_map["weight"]
+        bias_node = name_node_map["bias"]
+        output_node = name_node_map["output"]
+
+        # TODO: annotate the uses of input, weight, and bias separately instead
+        # of assuming they come from a single conv node. This is not possible today
+        # because input may have multiple users, and we can't rely on the conv node
+        # always being the first user. This was the case in models with skip
+        # connections like resnet18
+
+        # Validate conv args
+        if conv_node.args[0] is not input_node:
+            raise ValueError(f"Conv arg did not contain input node {input_node}")
+        if conv_node.args[1] is not weight_node:
+            raise ValueError(f"Conv arg did not contain weight node {weight_node}")
+        if len(conv_node.args) > 2 and conv_node.args[2] is not bias_node:
+            raise ValueError(f"Conv arg did not contain bias node {bias_node}")
+
+        # Skip if the partition is already annotated or is filtered out by the user
+        partition = [conv_node, weight_node]
+        if bias_node is not None:
+            partition.append(bias_node)
+        if _is_annotated(partition):
+            continue
+        if filter_fn and any(not filter_fn(n) for n in partition):
+            continue
+
+        # Annotate conv inputs and pattern output
+        input_qspec_map = {}
+        input_qspec_map[input_node] = get_input_act_qspec(quantization_config)
+        input_qspec_map[weight_node] = get_weight_qspec(quantization_config)
+        if bias_node is not None:
+            input_qspec_map[bias_node] = get_bias_qspec(quantization_config)
+        conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            _annotated=True,
+        )
+        output_node.meta["quantization_annotation"] = QuantizationAnnotation(
+            output_qspec=get_output_act_qspec(quantization_config),  # type: ignore[arg-type]
+            _annotated=True,
+        )
+        _mark_nodes_as_annotated(partition)
+        annotated_partitions.append(partition)
+    return annotated_partitions
+
+
+@register_annotator("gru_io_only")
+def _annotate_gru_io_only(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[list[list[Node]]]:
+    gru_partitions = get_source_partitions(gm.graph, [torch.nn.GRU], filter_fn)
+    gru_partitions = list(itertools.chain.from_iterable(gru_partitions.values()))
+    annotated_partitions = []
+    for gru_partition in gru_partitions:
+        annotated_partitions.append(gru_partition.nodes)
+        output_nodes = gru_partition.output_nodes
+        input_nodes = gru_partition.input_nodes
+        # skip annotation if it is already annotated
+        if _is_annotated(input_nodes + output_nodes):
+            continue
+        # inside each GRU partition, we should be able to annotate each linear
+        # subgraph
+        input_act = input_nodes[0]
+        input_act_user = next(iter(input_act.users.keys()))
+        assert isinstance(input_act, Node)
+        assert isinstance(input_act_user, Node)
+        input_act_user.meta["quantization_annotation"] = QuantizationAnnotation(
+            input_qspec_map={
+                input_act: get_input_act_qspec(quantization_config),
+            },
+            _annotated=True,
+        )
+
+        hidden_state = input_nodes[1]
+        hidden_state_user = next(iter(hidden_state.users.keys()))
+        assert isinstance(hidden_state, Node)
+        assert isinstance(hidden_state_user, Node)
+        hidden_state_user.meta["quantization_annotation"] = QuantizationAnnotation(
+            input_qspec_map={
+                hidden_state: get_input_act_qspec(quantization_config),
+            },
+            _annotated=True,
+        )
+
+        assert len(output_nodes) == 2, "expecting GRU to have two outputs"
+        for output in output_nodes:
+            output.meta["quantization_annotation"] = QuantizationAnnotation(
+                output_qspec=get_output_act_qspec(quantization_config),
+                _annotated=True,
+            )
+        nodes_to_mark_annotated = list(gru_partition.nodes)
+        _mark_nodes_as_annotated(nodes_to_mark_annotated)
+    return annotated_partitions
+
+
+@register_annotator("adaptive_avg_pool2d")
+def _annotate_adaptive_avg_pool2d(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[list[list[Node]]]:
+    """Always annotate adaptive_avg_pool2d op"""
+    module_partitions = get_source_partitions(
+        gm.graph, [torch.nn.AdaptiveAvgPool2d, F.adaptive_avg_pool2d], filter_fn
+    )
+    partitions = list(itertools.chain.from_iterable(module_partitions.values()))
+    annotated_partitions = []
+    for partition in partitions:
+        pool_node = partition.output_nodes[0]
+        if (
+            pool_node.op != "call_function"
+            or pool_node.target != torch.ops.aten.adaptive_avg_pool2d.default
+        ):
+            raise ValueError(f"{pool_node} is not an aten adaptive_avg_pool2d operator")
+
+        if _is_annotated([pool_node]):
+            continue
+
+        annotated_partitions.append(partition.nodes)
+        input_act = pool_node.args[0]
+        assert isinstance(input_act, Node)
+
+        # share the input observer with the producer only when the producer's
+        # output is already annotated; otherwise insert a fresh input observer
+        if (
+            "quantization_annotation" not in input_act.meta
+            or not input_act.meta["quantization_annotation"]._annotated
+            or input_act.meta["quantization_annotation"].output_qspec is None
+        ):
+            input_act_qspec = get_input_act_qspec(quantization_config)
+        else:
+            input_act_qspec = SharedQuantizationSpec(input_act)
+
+        # output sharing with input
+        output_act_qspec = SharedQuantizationSpec((input_act, pool_node))
+        pool_node.meta["quantization_annotation"] = QuantizationAnnotation(
+            input_qspec_map={
+                input_act: input_act_qspec,
+            },
+            output_qspec=output_act_qspec,
+            _annotated=True,
+        )
+    return annotated_partitions
+
+
+def _is_input_large_scalar(node: Node, gm: torch.fx.GraphModule):
+    """Check if input is a large scalar value. So that we can skip quantization for the node
+    since histc op (in HistogramObserver) only works for values up to certain upper bound
+    """
+    if node.op == "get_attr":
+        qualified_name = str(node.target)
+        module_path, _, name = qualified_name.rpartition(".")
+        submod = gm.get_submodule(module_path)
+        tensor = getattr(submod, name)
+        # torch.histc works until this upper bound
+        HISTC_UPPER_BOUND = 3.4028235e15
+        return tensor.numel() == 1 and abs(tensor.item()) > HISTC_UPPER_BOUND
+    return False
+
+
+def _is_input_non_float_tensor(node: Node):
+    """Check if the input is not a float tensor, so that we can skip quantization for the node
+    since observers only work with float Tensors
+    """
+    if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor):
+        return True
+    return node.meta["val"].dtype != torch.float32
+
+
+@register_annotator("add_relu")
+def _annotate_add_relu(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[list[list[Node]]]:
+    annotated_partitions = []
+    for node in gm.graph.nodes:
+        if node.op != "call_function" or node.target not in [
+            torch.ops.aten.relu.default,
+            torch.ops.aten.relu_.default,
+        ]:
+            continue
+        relu_node = node
+        maybe_add = node.args[0]
+        if (
+            not isinstance(maybe_add, Node)
+            or maybe_add.op != "call_function"
+            or maybe_add.target
+            not in [
+                torch.ops.aten.add.Tensor,
+                torch.ops.aten.add_.Tensor,
+            ]
+        ):
+            continue
+
+        add_node = maybe_add
+
+        if len(add_node.users) > 1:
+            # add can't be fused with ReLU if the result of add is used
+            # elsewhere in the graph
+            continue
+
+        partition = [relu_node, add_node]
+
+        if _is_annotated(partition):
+            continue
+
+        if filter_fn and any(not filter_fn(n) for n in partition):
+            continue
+
+        input_act_qspec = get_input_act_qspec(quantization_config)
+        output_act_qspec = get_output_act_qspec(quantization_config)
+
+        input_qspec_map = {}
+        input_act0 = add_node.args[0]
+        if isinstance(input_act0, Node):
+            if _is_input_large_scalar(input_act0, gm):
+                continue
+            if _is_input_non_float_tensor(input_act0):
+                continue
+            partition.append(input_act0)
+            input_qspec_map[input_act0] = input_act_qspec
+
+        input_act1 = add_node.args[1]
+        if isinstance(input_act1, Node):
+            if _is_input_large_scalar(input_act1, gm):
+                continue
+            if _is_input_non_float_tensor(input_act1):
+                continue
+            partition.append(input_act1)
+            input_qspec_map[input_act1] = input_act_qspec
+
+        add_node.meta["quantization_annotation"] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            _annotated=True,
+        )
+        relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
+            output_qspec=output_act_qspec,
+            _annotated=True,
+        )
+        annotated_partitions.append(partition)
+    return annotated_partitions
+
+
+@register_annotator("add")
+def _annotate_add(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[list[list[Node]]]:
+    annotated_partitions = []
+    for node in gm.graph.nodes:
+        if node.op != "call_function" or node.target not in [
+            torch.ops.aten.add.Tensor,
+            torch.ops.aten.add_.Tensor,
+        ]:
+            continue
+        add_node = node
+        partition = [add_node]
+
+        if _is_annotated(partition):
+            continue
+
+        if filter_fn and any(not filter_fn(n) for n in partition):
+            continue
+
+        input_act_qspec = get_input_act_qspec(quantization_config)
+        output_act_qspec = get_output_act_qspec(quantization_config)
+
+        input_qspec_map = {}
+        input_act0 = add_node.args[0]
+        if isinstance(input_act0, Node):
+            if _is_input_large_scalar(input_act0, gm):
+                continue
+            if _is_input_non_float_tensor(input_act0):
+                continue
+            input_qspec_map[input_act0] = input_act_qspec
+            partition.append(input_act0)
+
+        input_act1 = add_node.args[1]
+        if isinstance(input_act1, Node):
+            if _is_input_large_scalar(input_act1, gm):
+                continue
+            if _is_input_non_float_tensor(input_act1):
+                continue
+            input_qspec_map[input_act1] = input_act_qspec
+            partition.append(input_act1)
+
+        add_node.meta["quantization_annotation"] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            output_qspec=output_act_qspec,
+            _annotated=True,
+        )
+        annotated_partitions.append(partition)
+    return annotated_partitions
+
+
+@register_annotator("mul_relu")
+def _annotate_mul_relu(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[list[list[Node]]]:
+    annotated_partitions = []
+    for node in gm.graph.nodes:
+        if node.op != "call_function" or node.target not in [
+            torch.ops.aten.relu.default,
+            torch.ops.aten.relu_.default,
+        ]:
+            continue
+        relu_node = node
+        maybe_mul = node.args[0]
+        if (
+            not isinstance(maybe_mul, Node)
+            or maybe_mul.op != "call_function"
+            or maybe_mul.target
+            not in [
+                torch.ops.aten.mul.Tensor,
+                torch.ops.aten.mul_.Tensor,
+            ]
+        ):
+            continue
+
+        mul_node = maybe_mul
+        if len(mul_node.users) > 1:
+            # mul can't be fused with ReLU if the result of mul is used
+            # elsewhere in the graph
+            continue
+
+        partition = [relu_node, mul_node]
+
+        if _is_annotated(partition):
+            continue
+
+        if filter_fn and any(not filter_fn(n) for n in partition):
+            continue
+
+        input_act_qspec = get_input_act_qspec(quantization_config)
+        output_act_qspec = get_output_act_qspec(quantization_config)
+
+        input_qspec_map = {}
+        input_act0 = mul_node.args[0]
+        if isinstance(input_act0, Node):
+            if _is_input_large_scalar(input_act0, gm):
+                continue
+            if _is_input_non_float_tensor(input_act0):
+                continue
+            partition.append(input_act0)
+            input_qspec_map[input_act0] = input_act_qspec
+
+        input_act1 = mul_node.args[1]
+        if isinstance(input_act1, Node):
+            if _is_input_large_scalar(input_act1, gm):
+                continue
+            if _is_input_non_float_tensor(input_act1):
+                continue
+            partition.append(input_act1)
+            input_qspec_map[input_act1] = input_act_qspec
+
+        mul_node.meta["quantization_annotation"] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            _annotated=True,
+        )
+        relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
+            output_qspec=output_act_qspec,
+            _annotated=True,
+        )
+        annotated_partitions.append(partition)
+    return annotated_partitions
+
+
+@register_annotator("mul")
+def _annotate_mul(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[list[list[Node]]]:
+    annotated_partitions = []
+    for node in gm.graph.nodes:
+        if node.op != "call_function" or node.target not in [
+            torch.ops.aten.mul.Tensor,
+            torch.ops.aten.mul_.Tensor,
+        ]:
+            continue
+
+        mul_node = node
+        partition = [mul_node]
+        if _is_annotated(partition):
+            continue
+
+        if filter_fn and any(not filter_fn(n) for n in partition):
+            continue
+
+        input_act_qspec = get_input_act_qspec(quantization_config)
+        output_act_qspec = get_output_act_qspec(quantization_config)
+
+        input_qspec_map = {}
+        input_act0 = mul_node.args[0]
+        if isinstance(input_act0, Node):
+            if _is_input_large_scalar(input_act0, gm):
+                continue
+            if _is_input_non_float_tensor(input_act0):
+                continue
+            input_qspec_map[input_act0] = input_act_qspec
+            partition.append(input_act0)
+
+        input_act1 = mul_node.args[1]
+        if isinstance(input_act1, Node):
+            if _is_input_large_scalar(input_act1, gm):
+                continue
+            if _is_input_non_float_tensor(input_act1):
+                continue
+            input_qspec_map[input_act1] = input_act_qspec
+            partition.append(input_act1)
+
+        mul_node.meta["quantization_annotation"] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            output_qspec=output_act_qspec,
+            _annotated=True,
+        )
+        annotated_partitions.append(partition)
+    return annotated_partitions
+
+
+# TODO: remove Optional in return type, fix annotated_partitions logic
+@register_annotator("cat")
+def _annotate_cat(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[list[list[Node]]]:
+    cat_partitions = get_source_partitions(gm.graph, [torch.cat], filter_fn)
+    cat_partitions = list(itertools.chain.from_iterable(cat_partitions.values()))
+    annotated_partitions = []
+    for cat_partition in cat_partitions:
+        cat_node = cat_partition.output_nodes[0]
+        if _is_annotated([cat_node]):
+            continue
+
+        if cat_node.target != torch.ops.aten.cat.default:
+            # TODO: change this to AnnotationException
+            raise Exception(  # noqa: TRY002
+                f"Expected cat node: torch.ops.aten.cat.default, but found {cat_node.target}"
+                " please check if you are calling the correct capture API"
+            )
+
+        annotated_partitions.append(cat_partition.nodes)
+
+        input_act_qspec = get_input_act_qspec(quantization_config)
+        inputs = cat_node.args[0]
+
+        input_qspec_map = {}
+        input_act0 = inputs[0]  # type: ignore[index]
+        if isinstance(input_act0, Node):
+            input_qspec_map[input_act0] = input_act_qspec
+
+        shared_with_input0_qspec = SharedQuantizationSpec((input_act0, cat_node))  # type: ignore[arg-type]
+        for input_act in inputs[1:]:  # type: ignore[index, union-attr]
+            if input_act not in input_qspec_map:
+                input_qspec_map[input_act] = shared_with_input0_qspec  # type: ignore[index]
+
+        output_act_qspec = shared_with_input0_qspec
+
+        cat_node.meta["quantization_annotation"] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            output_qspec=output_act_qspec,
+            _annotated=True,
+        )
+    return annotated_partitions
+
+
+def _is_share_obs_or_fq_op(op: Callable) -> bool:
+    return op in [
+        torch.ops.aten.relu.default,
+        torch.ops.aten.hardtanh.default,
+        torch.ops.aten.hardtanh_.default,
+        torch.ops.aten.max_pool2d.default,
+        torch.ops.aten.mean.default,
+        torch.ops.aten.mean.dim,
+        torch.ops.aten.permute.default,
+        torch.ops.aten.permute_copy.default,
+        torch.ops.aten.squeeze.dim,
+        torch.ops.aten.squeeze_copy.dim,
+        # TODO: remove?
+        torch.ops.aten.adaptive_avg_pool2d.default,
+        torch.ops.aten.view_copy.default,
+        torch.ops.aten.view.default,
+        torch.ops.aten.slice_copy.Tensor,
+        torch.ops.aten.flatten.using_ints,
+    ]
+
+
+def propagate_annotation(model: torch.fx.GraphModule) -> None:
+    for n in model.graph.nodes:
+        if n.op != "call_function" or not _is_share_obs_or_fq_op(n.target):
+            continue
+
+        prev_node = n.args[0]
+        if not isinstance(prev_node, Node):
+            continue
+
+        quantization_annotation = prev_node.meta.get("quantization_annotation", None)
+        if not quantization_annotation:
+            continue
+
+        output_qspec = quantization_annotation.output_qspec
+        if not output_qspec:
+            continue
+
+        # make sure current node is not annotated
+        if (
+            "quantization_annotation" in n.meta
+            and n.meta["quantization_annotation"]._annotated
+        ):
+            continue
+
+        shared_qspec = SharedQuantizationSpec(prev_node)
+        # propagate the previous output_qspec to the current node
+        n.meta["quantization_annotation"] = QuantizationAnnotation(
+            input_qspec_map={
+                prev_node: shared_qspec,
+            },
+            output_qspec=shared_qspec,
+            _annotated=True,
+        )
+
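+# Hedged illustration: for a graph like conv -> max_pool2d where the conv output
+# is already annotated, propagate_annotation attaches a SharedQuantizationSpec
+# pointing at the conv node to the max_pool2d node, so both reuse one
+# observer/fake-quant, e.g. (sketch):
+#   propagate_annotation(gm)  # gm: an exported GraphModule already annotated via OP_TO_ANNOTATOR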
+
+# TODO: make the list of ops customizable
+def _convert_scalars_to_attrs(model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    for n in model.graph.nodes:
+        if n.op != "call_function" or n.target not in [
+            torch.ops.aten.add.Tensor,
+            torch.ops.aten.mul.Tensor,
+        ]:
+            continue
+        args = list(n.args)
+        new_args = []
+        for i in range(len(args)):
+            if isinstance(args[i], torch.fx.Node):
+                new_args.append(args[i])
+                continue
+            prefix = "_tensor_constant_"
+            get_new_attr_name = get_new_attr_name_with_prefix(prefix)
+            tensor_constant_name = get_new_attr_name(model)
+            float_tensor = torch.tensor(float(args[i]))
+            model.register_buffer(tensor_constant_name, float_tensor)
+            fake_mode = n.meta["val"].fake_mode
+            with model.graph.inserting_before(n):
+                get_attr_node = model.graph.create_node(
+                    "get_attr", tensor_constant_name, (), {}
+                )
+                get_attr_node.meta["val"] = fake_mode.from_tensor(
+                    float_tensor, static_shapes=True
+                )
+                new_args.append(get_attr_node)
+        n.args = tuple(new_args)
+    model.recompile()
+    return model
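+
+
+# Hedged illustration: after _convert_scalars_to_attrs, a call such as
+# aten.add.Tensor(x, 3.0) is rewritten so that the scalar 3.0 becomes a
+# registered buffer (named with the "_tensor_constant_" prefix) referenced via a
+# get_attr node, letting observers be placed on that input like any other tensor.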
diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/xpu_inductor_quantizer.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/xpu_inductor_quantizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..68dd42936cf529a95987694e9d0165910c64b34c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantizer/xpu_inductor_quantizer.py
@@ -0,0 +1,120 @@
+# mypy: allow-untyped-defs
+import functools
+from typing import Any, Optional, TYPE_CHECKING
+
+import torch
+from torch.ao.quantization.observer import HistogramObserver, PerChannelMinMaxObserver
+from torch.ao.quantization.quantizer.quantizer import QuantizationSpec
+from torch.ao.quantization.quantizer.x86_inductor_quantizer import (
+    _is_any_annotated,
+    FilterFn,
+    int8_in_int8_out_ops,
+    X86InductorQuantizer,
+)
+from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import QuantizationConfig
+from torch.fx import Node
+
+
+if TYPE_CHECKING:
+    from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
+
+__all__ = [
+    "XPUInductorQuantizer",
+    "get_default_xpu_inductor_quantization_config",
+]
+
+
+@functools.lru_cache
+def get_default_xpu_inductor_quantization_config():
+    extra_args: dict[str, Any] = {"eps": 2**-12}
+    act_observer_or_fake_quant_ctr = HistogramObserver
+    act_quantization_spec = QuantizationSpec(
+        dtype=torch.int8,
+        quant_min=-128,
+        quant_max=127,
+        qscheme=torch.per_tensor_affine,
+        is_dynamic=False,
+        observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
+            **extra_args
+        ),
+    )
+
+    weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
+        PerChannelMinMaxObserver
+    )
+
+    weight_quantization_spec = QuantizationSpec(
+        dtype=torch.int8,
+        quant_min=-128,
+        quant_max=127,
+        qscheme=torch.per_channel_symmetric,
+        ch_axis=0,  # 0 corresponding to weight shape = (oc, ic, kh, kw) of conv
+        is_dynamic=False,
+        observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
+            **extra_args
+        ),
+    )
+
+    bias_quantization_spec = None  # will use placeholder observer by default
+    quantization_config = QuantizationConfig(
+        act_quantization_spec,
+        act_quantization_spec,
+        weight_quantization_spec,
+        bias_quantization_spec,
+        False,
+    )
+    return quantization_config
+
+
+class XPUInductorQuantizer(X86InductorQuantizer):
+    """
+    XPUInductorQuantizer is a class designed to facilitate
+    quantization capability at Intel GPU backend. The class
+    highly reuses the existing implementation of
+    X86InductorQuantizer as both are intended to take advantage
+    of the optimized kernels in oneDNN library.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    """
+        Following annotate_xx overrides the impls in base class, as
+        no XPU implementation for these operators currently. We would
+        gradually enable the XPU implementation and remove following
+        overrides. We keep the annotate methods but make the function
+        body empty, aiming to let `_generate_qdq_quantized_model`
+        generate qdq around op and graph execute on fp32 dtype for
+        unspported operators.
+    """
+
+    def _annotate_qat_conv2d_fusion_pattern(
+        self,
+        model: torch.fx.GraphModule,
+        quantization_config: Optional[QuantizationConfig],
+        filter_fn: Optional[FilterFn] = None,
+    ):
+        pass
+
+    def _annotate_maxpool2d(
+        self,
+        node: Node,
+        quantization_config: Optional[QuantizationConfig],
+    ) -> None:
+        """
+        Here we skip the annotate logic for maxpool at XPU backend
+        as the quantized::max_pool2d is only implemented for CPU.
+        """
+        return
+
+    def _annotate_output_for_int8_in_int8_out_pattern(
+        self,
+        node: Node,
+    ) -> None:
+        if (node.target in int8_in_int8_out_ops) and (_is_any_annotated([node])):
+            if node.target == torch.ops.aten.max_pool2d.default:
+                return
+            else:
+                input_node = node.all_input_nodes[0]
+                self._annotate_output_share_observer_as_input(input_node, node)
+        return
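+
+
+# Hedged usage sketch (illustration only, assuming the standard PT2E flow from
+# torch.ao.quantization.quantize_pt2e; names below are placeholders):
+#   quantizer = XPUInductorQuantizer()
+#   quantizer.set_global(get_default_xpu_inductor_quantization_config())
+#   prepared = prepare_pt2e(exported_model, quantizer)
+#   ... run calibration data through `prepared` ...
+#   converted = convert_pt2e(prepared)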
diff --git a/.venv/lib/python3.12/site-packages/torch/backends/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/backends/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..39bc647e6868666524209ef0c3993cd7de042294
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/backends/__pycache__/__init__.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/backends/_coreml/__init__.py b/.venv/lib/python3.12/site-packages/torch/backends/_coreml/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.12/site-packages/torch/backends/_coreml/preprocess.py b/.venv/lib/python3.12/site-packages/torch/backends/_coreml/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..3180e56a6baf96b56c88a712a4426108d8c8e2fc
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/backends/_coreml/preprocess.py
@@ -0,0 +1,150 @@
+# mypy: allow-untyped-defs
+import hashlib
+import json
+
+import coremltools as ct  # type: ignore[import]
+from coremltools.converters.mil.input_types import TensorType  # type: ignore[import]
+from coremltools.converters.mil.mil import types  # type: ignore[import]
+from coremltools.models.neural_network import quantization_utils  # type: ignore[import]
+
+import torch
+
+
+CT_METADATA_VERSION = "com.github.apple.coremltools.version"
+CT_METADATA_SOURCE = "com.github.apple.coremltools.source"
+
+
+class ScalarType:
+    Float = 0
+    Double = 1
+    Int = 2
+    Long = 3
+    Undefined = 4
+
+
+# Supported Tensor types in coremltools:
+# https://github.com/apple/coremltools/blob/main/coremltools/converters/mil/frontend/torch/converter.py#L28
+torch_to_mil_types = {
+    ScalarType.Float: types.fp32,
+    ScalarType.Double: types.fp64,
+    ScalarType.Int: types.int32,
+    ScalarType.Long: types.int64,
+}
+
+
+class CoreMLComputeUnit:
+    CPU = "cpuOnly"
+    CPUAndGPU = "cpuAndGPU"
+    ALL = "all"
+
+
+class CoreMLQuantizationMode:
+    LINEAR = "linear"
+    LINEAR_SYMMETRIC = "linear_symmetric"
+    NONE = "none"
+
+
+def TensorSpec(shape, dtype=ScalarType.Float):
+    return (shape, dtype)
+
+
+def CompileSpec(
+    inputs,
+    outputs,
+    backend=CoreMLComputeUnit.CPU,
+    allow_low_precision=True,
+    quantization_mode=CoreMLQuantizationMode.NONE,
+    mlmodel_export_path=None,
+    convert_to=None,
+):
+    return (
+        inputs,
+        outputs,
+        backend,
+        allow_low_precision,
+        quantization_mode,
+        mlmodel_export_path,
+        convert_to,
+    )
+
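+# Hedged usage sketch (illustration only; shapes are made-up placeholders): the
+# compile spec consumed by `preprocess` below is a dict keyed by method name,
+# e.g.
+#   compile_spec = {
+#       "forward": CompileSpec(
+#           inputs=(TensorSpec(shape=[1, 3, 224, 224]),),
+#           outputs=(TensorSpec(shape=[1, 1000]),),
+#           backend=CoreMLComputeUnit.ALL,
+#       )
+#   }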
+
+def _check_enumerated_shape(shape):
+    for s in shape:
+        if not isinstance(s, (list, tuple)):
+            return False
+    return True
+
+
+def _convert_to_mil_type(shape, dtype, name: str):
+    mil_shape = shape
+    if _check_enumerated_shape(shape):
+        mil_shape = ct.EnumeratedShapes(shape)
+    ml_type = TensorType(shape=mil_shape, dtype=torch_to_mil_types[dtype])
+    ml_type.name = name
+    return ml_type
+
+
+def preprocess(script_module: torch._C.ScriptObject, compile_spec: dict[str, tuple]):
+    spec = compile_spec["forward"]
+    (
+        input_specs,
+        output_specs,
+        backend,
+        allow_low_precision,
+        quantization_mode,
+        mlmodel_export_path,
+        convert_to,
+    ) = spec
+    mil_inputs = []
+    inputs = []
+    for index, input in enumerate(input_specs):
+        shape, dtype = input
+        name = "input_" + str(index)
+        inputs.append([name, str(dtype), str(shape)])
+        ml_type = _convert_to_mil_type(shape, dtype, name)
+        mil_inputs.append(ml_type)
+    model = torch.jit.RecursiveScriptModule._construct(script_module, lambda x: None)
+    mlmodel = ct.convert(model, inputs=mil_inputs, convert_to=convert_to)
+
+    if quantization_mode != CoreMLQuantizationMode.NONE:
+        quant_model_spec = quantization_utils.quantize_weights(
+            mlmodel, nbits=8, quantization_mode=quantization_mode
+        )
+        mlmodel = ct.models.MLModel(quant_model_spec)
+
+    spec = mlmodel.get_spec()
+    assert len(spec.description.output) == len(output_specs)  # type: ignore[attr-defined]
+    outputs = []
+    for index, output in enumerate(output_specs):
+        shape, dtype = output
+        name = spec.description.output[index].name  # type: ignore[attr-defined]
+        outputs.append([name, str(dtype), str(shape)])
+    mlmodel = ct.models.model.MLModel(spec)
+    print(mlmodel)
+
+    if mlmodel_export_path is not None:
+        print(f"Saving CoreML .mlmodel file to {mlmodel_export_path}")
+        mlmodel.save(mlmodel_export_path)
+
+    config = {
+        "spec_ver": str(spec.specificationVersion),  # type: ignore[attr-defined]
+        "backend": backend,
+        "allow_low_precision": str(allow_low_precision),
+    }
+    metadata = {
+        "coremltool_ver": mlmodel.user_defined_metadata[CT_METADATA_VERSION],
+        "torch_ver": mlmodel.user_defined_metadata[CT_METADATA_SOURCE],
+    }
+    coreml_compile_spec = {
+        "inputs": inputs,
+        "outputs": outputs,
+        "config": config,
+        "metadata": metadata,
+    }
+    mlmodel = spec.SerializeToString()  # type: ignore[attr-defined]
+
+    return {
+        "model": mlmodel,
+        "hash": str(hashlib.sha256(mlmodel).hexdigest()),
+        "extra": json.dumps(coreml_compile_spec),
+    }
diff --git a/.venv/lib/python3.12/site-packages/torch/backends/_nnapi/__init__.py b/.venv/lib/python3.12/site-packages/torch/backends/_nnapi/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.12/site-packages/torch/backends/_nnapi/prepare.py b/.venv/lib/python3.12/site-packages/torch/backends/_nnapi/prepare.py
new file mode 100644
index 0000000000000000000000000000000000000000..0fc48d711111ffd417fa1c544bd4b2362e75cf16
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/backends/_nnapi/prepare.py
@@ -0,0 +1,199 @@
+# mypy: allow-untyped-decorators
+# mypy: allow-untyped-defs
+from typing import Optional
+
+import torch
+from torch.backends._nnapi.serializer import _NnapiSerializer
+
+
+ANEURALNETWORKS_PREFER_LOW_POWER = 0
+ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER = 1
+ANEURALNETWORKS_PREFER_SUSTAINED_SPEED = 2
+
+
+class NnapiModule(torch.nn.Module):
+    """Torch Module that wraps an NNAPI Compilation.
+
+    This module handles preparing the weights, initializing the
+    NNAPI TorchBind object, and adjusting the memory formats
+    of all inputs and outputs.
+    """
+
+    # _nnapi.Compilation is defined
+    comp: Optional[torch.classes._nnapi.Compilation]  # type: ignore[name-defined]
+    weights: list[torch.Tensor]
+    out_templates: list[torch.Tensor]
+
+    def __init__(
+        self,
+        shape_compute_module: torch.nn.Module,
+        ser_model: torch.Tensor,
+        weights: list[torch.Tensor],
+        inp_mem_fmts: list[int],
+        out_mem_fmts: list[int],
+        compilation_preference: int,
+        relax_f32_to_f16: bool,
+    ):
+        super().__init__()
+        self.shape_compute_module = shape_compute_module
+        self.ser_model = ser_model
+        self.weights = weights
+        self.inp_mem_fmts = inp_mem_fmts
+        self.out_mem_fmts = out_mem_fmts
+        self.out_templates = []
+        self.comp = None
+        self.compilation_preference = compilation_preference
+        self.relax_f32_to_f16 = relax_f32_to_f16
+
+    @torch.jit.export
+    def init(self, args: list[torch.Tensor]):
+        assert self.comp is None
+        self.out_templates = self.shape_compute_module.prepare(self.ser_model, args)  # type: ignore[operator]
+        self.weights = [w.contiguous() for w in self.weights]
+        comp = torch.classes._nnapi.Compilation()
+        comp.init2(
+            self.ser_model,
+            self.weights,
+            self.compilation_preference,
+            self.relax_f32_to_f16,
+        )
+
+        self.comp = comp
+
+    def forward(self, args: list[torch.Tensor]) -> list[torch.Tensor]:
+        if self.comp is None:
+            self.init(args)
+        comp = self.comp
+        assert comp is not None
+        outs = [torch.empty_like(out) for out in self.out_templates]
+
+        assert len(args) == len(self.inp_mem_fmts)
+        fixed_args = []
+        for idx in range(len(args)):
+            fmt = self.inp_mem_fmts[idx]
+            # These constants match the values in DimOrder in serializer.py
+            # TODO: See if it's possible to use those directly.
+            if fmt == 0:
+                fixed_args.append(args[idx].contiguous())
+            elif fmt == 1:
+                fixed_args.append(args[idx].permute(0, 2, 3, 1).contiguous())
+            else:
+                raise ValueError("Invalid mem_fmt")
+        comp.run(fixed_args, outs)
+        assert len(outs) == len(self.out_mem_fmts)
+        for idx in range(len(self.out_templates)):
+            fmt = self.out_mem_fmts[idx]
+            # These constants match the values in DimOrder in serializer.py
+            # TODO: See if it's possible to use those directly.
+            if fmt in (0, 2):
+                pass
+            elif fmt == 1:
+                outs[idx] = outs[idx].permute(0, 3, 1, 2)
+            else:
+                raise ValueError("Invalid mem_fmt")
+        return outs
+
+
+def convert_model_to_nnapi(
+    model,
+    inputs,
+    serializer=None,
+    return_shapes=None,
+    use_int16_for_qint16=False,
+    compilation_preference=ANEURALNETWORKS_PREFER_SUSTAINED_SPEED,
+    relax_f32_to_f16=False,
+):
+    (
+        shape_compute_module,
+        ser_model_tensor,
+        used_weights,
+        inp_mem_fmts,
+        out_mem_fmts,
+        retval_count,
+    ) = process_for_nnapi(
+        model, inputs, serializer, return_shapes, use_int16_for_qint16
+    )
+
+    nnapi_model = NnapiModule(
+        shape_compute_module,
+        ser_model_tensor,
+        used_weights,
+        inp_mem_fmts,
+        out_mem_fmts,
+        compilation_preference,
+        relax_f32_to_f16,
+    )
+
+    class NnapiInterfaceWrapper(torch.nn.Module):
+        """NNAPI list-ifying and de-list-ifying wrapper.
+
+        NNAPI always expects a list of inputs and provides a list of outputs.
+        This module allows us to accept inputs as separate arguments.
+        It returns results as either a single tensor or tuple,
+        matching the original module.
+        """
+
+        def __init__(self, mod):
+            super().__init__()
+            self.mod = mod
+
+    wrapper_model_py = NnapiInterfaceWrapper(nnapi_model)
+    wrapper_model = torch.jit.script(wrapper_model_py)
+    # TODO: Maybe make these names match the original.
+    arg_list = ", ".join(f"arg_{idx}" for idx in range(len(inputs)))
+    if retval_count < 0:
+        ret_expr = "retvals[0]"
+    else:
+        ret_expr = "".join(f"retvals[{idx}], " for idx in range(retval_count))
+    wrapper_model.define(
+        f"def forward(self, {arg_list}):\n"
+        f"    retvals = self.mod([{arg_list}])\n"
+        f"    return {ret_expr}\n"
+    )
+    return wrapper_model
+
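+# Hedged usage sketch (illustration only): a typical flow is to trace an eval-mode
+# float model with a channels-last example input and hand both to the converter.
+# The `nnapi_nhwc` flag is an assumption about how NHWC inputs are signalled to
+# the serializer.
+#   example = torch.zeros(1, 3, 224, 224).contiguous(memory_format=torch.channels_last)
+#   example.nnapi_nhwc = True
+#   traced = torch.jit.trace(model.eval(), example)
+#   nnapi_module = convert_model_to_nnapi(traced, example)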
+
+def process_for_nnapi(
+    model, inputs, serializer=None, return_shapes=None, use_int16_for_qint16=False
+):
+    model = torch.jit.freeze(model)
+
+    if isinstance(inputs, torch.Tensor):
+        inputs = [inputs]
+
+    serializer = serializer or _NnapiSerializer(
+        config=None, use_int16_for_qint16=use_int16_for_qint16
+    )
+    (
+        ser_model,
+        used_weights,
+        inp_mem_fmts,
+        out_mem_fmts,
+        shape_compute_lines,
+        retval_count,
+    ) = serializer.serialize_model(model, inputs, return_shapes)
+    ser_model_tensor = torch.tensor(ser_model, dtype=torch.int32)
+
+    # We have to create a new class here every time this function is called
+    # because module.define adds a method to the *class*, not the instance.
+    class ShapeComputeModule(torch.nn.Module):
+        """Code-gen-ed module for tensor shape computation.
+
+        module.prepare will mutate ser_model according to the computed operand
+        shapes, based on the shapes of args.  Returns a list of output templates.
+        """
+
+    shape_compute_module = torch.jit.script(ShapeComputeModule())
+    real_shape_compute_lines = [
+        "def prepare(self, ser_model: torch.Tensor, args: List[torch.Tensor]) -> List[torch.Tensor]:\n",
+    ] + [f"    {line}\n" for line in shape_compute_lines]
+    shape_compute_module.define("".join(real_shape_compute_lines))
+
+    return (
+        shape_compute_module,
+        ser_model_tensor,
+        used_weights,
+        inp_mem_fmts,
+        out_mem_fmts,
+        retval_count,
+    )
diff --git a/.venv/lib/python3.12/site-packages/torch/backends/_nnapi/serializer.py b/.venv/lib/python3.12/site-packages/torch/backends/_nnapi/serializer.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2769b69eb83fe89db0d797822facc711dce1ed8
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/backends/_nnapi/serializer.py
@@ -0,0 +1,2228 @@
+# mypy: allow-untyped-defs
+import array
+import enum
+import functools
+import logging
+import operator
+import struct
+import sys
+from typing import NamedTuple, Optional
+
+import torch
+
+
+# TODO: Add type annotations
+# TODO: Check tensor types for ops
+
+
+LOG = logging.getLogger("nnapi_serialize")
+
+
+class NNAPI_OperandCode:
+    FLOAT32 = 0
+    INT32 = 1
+    UINT32 = 2
+    TENSOR_FLOAT32 = 3
+    TENSOR_INT32 = 4
+    TENSOR_QUANT8_ASYMM = 5
+    BOOL = 6
+    TENSOR_QUANT16_SYMM = 7
+    TENSOR_FLOAT16 = 8
+    TENSOR_BOOL8 = 9
+    FLOAT16 = 10
+    TENSOR_QUANT8_SYMM_PER_CHANNEL = 11
+    TENSOR_QUANT16_ASYMM = 12
+
+
+class NNAPI_OperationCode:
+    ADD = 0
+    AVERAGE_POOL_2D = 1
+    CONCATENATION = 2
+    CONV_2D = 3
+    DEPTHWISE_CONV_2D = 4
+    DEPTH_TO_SPACE = 5
+    DEQUANTIZE = 6
+    EMBEDDING_LOOKUP = 7
+    FLOOR = 8
+    FULLY_CONNECTED = 9
+    HASHTABLE_LOOKUP = 10
+    L2_NORMALIZATION = 11
+    L2_POOL_2D = 12
+    LOCAL_RESPONSE_NORMALIZATION = 13
+    LOGISTIC = 14
+    LSH_PROJECTION = 15
+    LSTM = 16
+    MAX_POOL_2D = 17
+    MUL = 18
+    RELU = 19
+    RELU1 = 20
+    RELU6 = 21
+    RESHAPE = 22
+    RESIZE_BILINEAR = 23
+    RNN = 24
+    SOFTMAX = 25
+    SPACE_TO_DEPTH = 26
+    SVDF = 27
+    TANH = 28
+    BATCH_TO_SPACE_ND = 29
+    DIV = 30
+    MEAN = 31
+    PAD = 32
+    SPACE_TO_BATCH_ND = 33
+    SQUEEZE = 34
+    STRIDED_SLICE = 35
+    SUB = 36
+    TRANSPOSE = 37
+    ABS = 38
+    ARGMAX = 39
+    ARGMIN = 40
+    AXIS_ALIGNED_BBOX_TRANSFORM = 41
+    BIDIRECTIONAL_SEQUENCE_LSTM = 42
+    BIDIRECTIONAL_SEQUENCE_RNN = 43
+    BOX_WITH_NMS_LIMIT = 44
+    CAST = 45
+    CHANNEL_SHUFFLE = 46
+    DETECTION_POSTPROCESSING = 47
+    EQUAL = 48
+    EXP = 49
+    EXPAND_DIMS = 50
+    GATHER = 51
+    GENERATE_PROPOSALS = 52
+    GREATER = 53
+    GREATER_EQUAL = 54
+    GROUPED_CONV_2D = 55
+    HEATMAP_MAX_KEYPOINT = 56
+    INSTANCE_NORMALIZATION = 57
+    LESS = 58
+    LESS_EQUAL = 59
+    LOG = 60
+    LOGICAL_AND = 61
+    LOGICAL_NOT = 62
+    LOGICAL_OR = 63
+    LOG_SOFTMAX = 64
+    MAXIMUM = 65
+    MINIMUM = 66
+    NEG = 67
+    NOT_EQUAL = 68
+    PAD_V2 = 69
+    POW = 70
+    PRELU = 71
+    QUANTIZE = 72
+    QUANTIZED_16BIT_LSTM = 73
+    RANDOM_MULTINOMIAL = 74
+    REDUCE_ALL = 75
+    REDUCE_ANY = 76
+    REDUCE_MAX = 77
+    REDUCE_MIN = 78
+    REDUCE_PROD = 79
+    REDUCE_SUM = 80
+    ROI_ALIGN = 81
+    ROI_POOLING = 82
+    RSQRT = 83
+    SELECT = 84
+    SIN = 85
+    SLICE = 86
+    SPLIT = 87
+    SQRT = 88
+    TILE = 89
+    TOPK_V2 = 90
+    TRANSPOSE_CONV_2D = 91
+    UNIDIRECTIONAL_SEQUENCE_LSTM = 92
+    UNIDIRECTIONAL_SEQUENCE_RNN = 93
+    RESIZE_NEAREST_NEIGHBOR = 94
+
+
+class NNAPI_FuseCode:
+    FUSED_NONE = 0
+    FUSED_RELU = 1
+    FUSED_RELU1 = 2
+    FUSED_RELU6 = 3
+
+
+class OperandValueSourceType:
+    IMMEDIATE = 0
+    NUMBERED_BUFFER = 2
+    NUMBERED_MEMORY = 3
+
+
+# Scalar types that appear explicitly in models.
+# These must be kept in sync with
+# AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS.
+# TODO: Expose these directly to Python to avoid maintaining this list.
+class TorchScalarTypes(enum.Enum):
+    QUINT8 = 13
+
+
+def approx_equal(lhs, rhs, tolerance=1e-6):
+    return abs(lhs - rhs) <= tolerance * min(lhs, rhs)
+
+
+def tensor_size(op_type, dims):
+    ITEM_SIZES = {
+        NNAPI_OperandCode.TENSOR_FLOAT32: 4,
+        NNAPI_OperandCode.TENSOR_INT32: 4,
+        NNAPI_OperandCode.TENSOR_QUANT8_ASYMM: 1,
+        NNAPI_OperandCode.TENSOR_QUANT16_SYMM: 2,
+        NNAPI_OperandCode.TENSOR_QUANT16_ASYMM: 2,
+    }
+    size = ITEM_SIZES[op_type]
+    for d in dims:
+        size *= d
+    return size
+
+
+def change_element(tup, index, value):
+    ls = list(tup)
+    ls[index] = value
+    return tuple(ls)
+
+
+class ConvPoolArgs2d(NamedTuple):
+    """Configuration arguments for a convolution."""
+
+    kernel_h: int
+    kernel_w: int
+    stride_h: int
+    stride_w: int
+    pad_t: int
+    pad_b: int
+    pad_l: int
+    pad_r: int
+    dilation_h: int
+    dilation_w: int
+    group: int
+
+
+class DimOrder(enum.Enum):
+    PRESUMED_CONTIGUOUS = 0
+    CHANNELS_LAST = 1
+    SCALAR_OR_VECTOR = 2
+    UNKNOWN_CONSTANT = 999
+
+
+class Operand(NamedTuple):
+    """Represenation of an NNAPI operand."""
+
+    # NNAPI operand type.  One of NNAPI_OperandCode.
+    # TODO: Make this an enum.
+    op_type: int
+
+    # This is always the PyTorch shape, which is NCHW for feature maps.
+    # The actual NNAPI operand might have a transposed shape.
+    # We use 0 for load-time dynamic shapes and -1 for runtime dynamic shapes.
+    shape: tuple[int, ...]
+
+    # Specifies how the shape of the operand that we define in NNAPI
+    # relates to the shape we track above.
+    # - PRESUMED_CONTIGUOUS: physical NNAPI operand will exactly match
+    #   the shape of the PyTorch tensor.
+    # - CHANNELS_LAST: The PyTorch tensor is expected to be NCHW, and
+    #   the NNAPI operand will be represented explicitly as NHWC.
+    dim_order: DimOrder
+
+    # Quantization params
+    scale: float
+    zero_point: int
+
+    def use_nchw(self):
+        if self.dim_order is DimOrder.PRESUMED_CONTIGUOUS:
+            return True
+        if self.dim_order is DimOrder.CHANNELS_LAST:
+            return False
+        raise Exception("Unknown dim order")  # noqa: TRY002
+
+
+def broadcast_shapes(shape1, shape2):
+    assert len(shape1) > 0
+    assert len(shape2) > 0
+    s1 = list(shape1)
+    s2 = list(shape2)
+    # TODO: Support non-equal-rank broadcast where semantics match.
+    # This can be tricky for NHWC tensors because dimension orders
+    # don't match between PT and NNAPI, even though semantics match.
+    if len(s1) > len(s2):
+        # s2 = [1] * (len(s1) - len(s2)) + s2
+        raise Exception(  # noqa: TRY002
+            "Non-equal-rank broadcast is not supported yet."
+        )  # noqa: TRY002
+    if len(s2) > len(s1):
+        # s3 = [1] * (len(s2) - len(s1)) + s1
+        raise Exception(  # noqa: TRY002
+            "Non-equal-rank broadcast is not supported yet."
+        )  # noqa: TRY002
+    ret = []
+    for d1, d2 in zip(s1, s2):
+        if d1 == 1:
+            ret.append(d2)
+        elif d2 == 1:
+            ret.append(d1)
+        elif d1 == d2:
+            ret.append(d1)
+        else:
+            raise Exception(  # noqa: TRY002
+                f"Cannot broadcast shapes: {shape1} and {shape2}"
+            )  # noqa: TRY002
+    return tuple(ret)
+
+
+def get_conv_pool_shape(image_shape, args, out_ch, transpose):
+    batch, _in_c, in_h, in_w = image_shape
+
+    # TODO: Handle dilation
+    if args.dilation_h != 1 or args.dilation_w != 1:
+        raise Exception("Dilation not supported yet.")  # noqa: TRY002
+
+    if transpose:
+        out_h = (in_h - 1) * args.stride_h + args.kernel_h - args.pad_t - args.pad_b
+        out_w = (in_w - 1) * args.stride_w + args.kernel_w - args.pad_l - args.pad_r
+    else:
+        out_h = (in_h - args.kernel_h + args.pad_t + args.pad_b) // args.stride_h + 1
+        out_w = (in_w - args.kernel_w + args.pad_l + args.pad_r) // args.stride_w + 1
+
+    # Handle variable-sized tensors.
+    if in_h == 0:
+        out_h = 0
+    if in_w == 0:
+        out_w = 0
+
+    out_shape = (batch, out_ch, out_h, out_w)
+    return out_shape
+
+
+def fix_shape(shape, dim_order):
+    # Return the actual shape that an operand should have in NNAPI,
+    # given a PyTorch shape and dimension order.  This is where we
+    # convert from PyTorch's "always NCHW" shape to explicit NHWC.
+    if dim_order is DimOrder.PRESUMED_CONTIGUOUS:
+        return shape
+    if dim_order is DimOrder.CHANNELS_LAST:
+        return tuple([shape[0]] + list(shape[2:]) + [shape[1]])
+    if dim_order is DimOrder.SCALAR_OR_VECTOR:
+        assert len(shape) == 0 or len(shape) == 1
+        return shape
+    if dim_order is DimOrder.UNKNOWN_CONSTANT:
+        # XXX think this through
+        return shape
+    raise Exception(f"Bad dim_order: {dim_order!r}.")  # noqa: TRY002
+
+
+def reverse_map_dim(dim_order, d):
+    # Return the original PyTorch dimension position for a given dimension.
+    # d should be the dimension that NNAPI will see.
+    # reverse_map_dim(PRESUMED_CONTIGUOUS, x) == x
+    # reverse_map_dim(CHANNELS_LAST, 3) == 1
+    if dim_order in (DimOrder.PRESUMED_CONTIGUOUS, DimOrder.SCALAR_OR_VECTOR):
+        return d
+    assert dim_order is DimOrder.CHANNELS_LAST
+    return [0, 2, 3, 1][d]
+
+
+def flex_name(op_id, dim):
+    # Return the local variable name for the computed flexible size
+    # for a given op and dimension.
+    return f"s_{op_id}_{dim}"
+
+
+class _NnapiSerializer:
+    def __init__(self, config, use_int16_for_qint16=False):
+        self.operands = []
+        self.values = []
+        self.operations = []
+        self.value_data = []
+        self.operation_args = []
+        self.inputs = []
+        self.outputs = []
+        self.flexible_shape_computation_lines = []
+
+        self.modules = {}
+        self.constants = {}
+        self.tensor_sequences = {}
+        self.jitval_operand_map = {}
+        self.cached_immediates = {}
+        self.used_weights = []
+        self.weight_offset = 0
+        self.use_int16_for_qint16 = use_int16_for_qint16
+
+        if config is None:
+            config = {}
+
+    def get_next_operand_id(self):
+        return len(self.operands)
+
+    # Add a tensor operand corresponding to a JIT Value.
+    # Returns the NNAPI operand ID.  Can be looked up later with
+    # get_tensor_operand_by_jitval.
+    def add_tensor_operand(self, jitval, oper):
+        assert isinstance(oper, Operand)
+        if jitval in self.jitval_operand_map:
+            raise Exception(f"Duplicate tensor: {jitval!r}")  # noqa: TRY002
+
+        operand_id = self.get_next_operand_id()
+        self.operands.append(oper)
+        self.jitval_operand_map[jitval] = operand_id
+        return operand_id
+
+    # Add a tensor operand that does not correspond to a JIT Value.
+    # Useful for cases where multiple NNAPI operands are required
+    # to implement one JIT IR node.  Returns the NNAPI operand ID.
+    def add_anonymous_tensor_operand(self, oper):
+        assert isinstance(oper, Operand)
+        operand_id = self.get_next_operand_id()
+        self.operands.append(oper)
+        return operand_id
+
+    def torch_tensor_to_operand(self, tensor, dim_order):
+        dtype = str(tensor.dtype).replace("torch.", "")
+        scale = 0.0
+        zero_point = 0
+        if dtype == "float32":
+            op_type = NNAPI_OperandCode.TENSOR_FLOAT32
+        elif dtype == "int32":
+            op_type = NNAPI_OperandCode.TENSOR_INT32
+        elif dtype == "quint8":
+            op_type = NNAPI_OperandCode.TENSOR_QUANT8_ASYMM
+            scale = tensor.q_scale()
+            zero_point = tensor.q_zero_point()
+        elif dtype == "qint32":
+            op_type = NNAPI_OperandCode.TENSOR_INT32
+            scale = tensor.q_scale()
+            zero_point = tensor.q_zero_point()
+            assert zero_point == 0
+        elif dtype == "int16":
+            if self.use_int16_for_qint16:
+                nnapi_dtype = getattr(tensor, "nnapi_dtype", None)
+                op_codes = (
+                    NNAPI_OperandCode.TENSOR_QUANT16_SYMM,
+                    NNAPI_OperandCode.TENSOR_QUANT16_ASYMM,
+                )
+                if nnapi_dtype in op_codes:
+                    op_type = nnapi_dtype
+                    scale = tensor.nnapi_scale
+                    zero_point = tensor.nnapi_zero_point
+                else:
+                    raise Exception(  # noqa: TRY002
+                        f"`nnapi_dtype` needs to be one of {op_codes} for `int16`"
+                    )
+            else:
+                raise Exception(  # noqa: TRY002
+                    "`int16` isn't supported. If you're trying to represent NNAPI"
+                    " qint16 with PyTorch int16, set `use_int16_for_qint16 = True`"
+                )
+        else:
+            raise Exception(  # noqa: TRY002
+                f"Can't handle input with dtype '{tensor.dtype}'"
+            )  # noqa: TRY002
+        return Operand(
+            shape=tuple(tensor.shape),
+            op_type=op_type,
+            dim_order=dim_order,
+            scale=scale,
+            zero_point=zero_point,
+        )
+
+    def add_tensor_operand_for_input(self, arg_idx, jitval, tensor):
+        dim_order = (
+            DimOrder.CHANNELS_LAST
+            if getattr(tensor, "nnapi_nhwc", False)
+            else DimOrder.PRESUMED_CONTIGUOUS
+        )
+        toper = self.torch_tensor_to_operand(tensor, dim_order)
+        operand_id = self.add_tensor_operand(jitval, toper)
+        self.inputs.append(operand_id)
+        for dim, size in enumerate(tensor.shape):
+            if size == 0:
+                self.compute_operand_shape(
+                    operand_id, dim, f"args[{arg_idx}].shape[{dim}]"
+                )
+        return operand_id
+
+    def add_tensor_operand_for_weight(
+        self, tensor, dim_order=DimOrder.UNKNOWN_CONSTANT
+    ):
+        toper = self.torch_tensor_to_operand(tensor, dim_order)
+        operand_id = len(self.operands)
+        self.operands.append(toper)
+        tsize = tensor_size(toper.op_type, toper.shape)
+        self.values.append((operand_id, OperandValueSourceType.NUMBERED_BUFFER))
+        buf_num = len(self.used_weights)
+        offset = 0
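+        # Each weight value is a (buffer number, offset, length) triple that
+        # points into self.used_weights; the tensor data itself is returned
+        # separately from serialize_model.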
+        self.value_data.append(struct.pack("iii", buf_num, offset, tsize))
+        # For an NHWC NNAPI operand, lay the data out in the same dim order by permuting the torch tensor
+        if dim_order == DimOrder.CHANNELS_LAST:
+            tensor = tensor.permute(0, 2, 3, 1)
+        self.used_weights.append(tensor)
+        return operand_id
+
+    def add_immediate_operand(self, code, value, dims):
+        assert isinstance(dims, tuple)
+        cache_key = (code, value)
+        if cache_key not in self.cached_immediates:
+            operand_id = len(self.operands)
+            self.operands.append(Operand(code, dims, DimOrder.SCALAR_OR_VECTOR, 0.0, 0))
+            self.values.append((operand_id, OperandValueSourceType.IMMEDIATE))
+            self.value_data.append(value)
+            self.cached_immediates[cache_key] = operand_id
+        return self.cached_immediates[cache_key]
+
+    def add_immediate_int_scalar(self, value):
+        return self.add_immediate_operand(
+            NNAPI_OperandCode.INT32, struct.pack("i", value), ()
+        )
+
+    def add_immediate_float_scalar(self, value):
+        return self.add_immediate_operand(
+            NNAPI_OperandCode.FLOAT32, struct.pack("f", value), ()
+        )
+
+    def add_immediate_bool_scalar(self, value):
+        return self.add_immediate_operand(
+            NNAPI_OperandCode.BOOL, b"\x01" if value else b"\x00", ()
+        )
+
+    def add_immediate_int_vector(self, value):
+        return self.add_immediate_operand(
+            NNAPI_OperandCode.TENSOR_INT32,
+            array.array("i", value).tobytes(),
+            (len(value),),
+        )
+
+    def has_operand_for_jitval(self, jitval):
+        return jitval in self.jitval_operand_map
+
+    def get_tensor_operand_by_jitval(self, jitval):
+        operand_id = self.jitval_operand_map[jitval]
+        return (operand_id, self.operands[operand_id])
+
+    def get_tensor_operand_by_jitval_fixed_size(self, jitval):
+        op_id, oper = self.get_tensor_operand_by_jitval(jitval)
+        for s in oper.shape:
+            if s == 0:
+                # TODO: Improve this error message, possibly after converting
+                # many callsites to support flexible size.
+                raise Exception(  # noqa: TRY002
+                    "Flexible size is not supported for this operand."
+                )  # noqa: TRY002
+            if s < 0:
+                # runtime flex
+                LOG.warning("Operand %s has runtime flex shape", oper)
+        return op_id, oper
+
+    def get_tensor_operand_or_constant(
+        self, jitval, dim_order=DimOrder.PRESUMED_CONTIGUOUS
+    ):
+        operand_id = self.jitval_operand_map.get(jitval)
+        if operand_id is None:
+            _, value = self.get_constant_value(jitval, "TensorType")
+            operand_id = self.add_tensor_operand_for_weight(value, dim_order)
+        return (operand_id, self.operands[operand_id])
+
+    def get_tensor_operand_for_weight(self, jitval):
+        _, value = self.get_constant_value(jitval, "TensorType")
+        operand_id = self.add_tensor_operand_for_weight(value)
+        return (operand_id, self.operands[operand_id])
+
+    def add_operation(self, opcode, inputs, outputs):
+        self.operations.append((opcode, len(inputs), len(outputs)))
+        self.operation_args.extend(inputs + outputs)
+
+    def add_tensor_sequence(self, jitval, values):
+        assert jitval not in self.tensor_sequences
+        self.tensor_sequences[jitval] = values
+
+    def add_constant_value(self, jitval, ctype, value):
+        assert jitval not in self.constants
+        self.constants[jitval] = (ctype, value)
+
+    def get_constant_value(self, jitval, typekind=None):
+        record = self.constants.get(jitval)
+        if record is None:
+            raise Exception(  # noqa: TRY002
+                f"Could not find constant value for '{jitval!r}'."
+            )  # noqa: TRY002
+        ctype, _ = record
+        if typekind is not None and ctype.kind() != typekind:
+            raise Exception(  # noqa: TRY002
+                f"Expected constant value of type {typekind}, but got {ctype.kind()} for value '{jitval!r}'"
+            )
+        return record
+
+    def operand_to_template_torchscript(self, op_id, oper, shape=None):
+        """Return a TorchScript expression to build a template for a given operand."""
+        if shape is None:
+            shape = oper.shape
+        else:
+            assert len(shape) == len(oper.shape)
+
+        shape_parts = ["("]
+        for d, s in enumerate(shape):
+            if s > 0:
+                # Fixed shape dimension: just add the value.
+                shape_parts.append(str(s))
+            elif s == 0:
+                # Load time flexible shape dimension: it should have been computed in a variable.
+                shape_parts.append(flex_name(op_id, d))
+            elif s == -1:
+                # Runtime flexible shape
+                shape_parts.append("0")
+            else:
+                raise Exception(  # noqa: TRY002
+                    "Unknown dim value, dimensions should be >= -1"
+                )  # noqa: TRY002
+            shape_parts.append(",")
+        shape_parts.append(")")
+        shape_code = "".join(shape_parts)
+        if oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32:
+            return f"torch.zeros({shape_code}, dtype=torch.float32)"
+        elif oper.op_type == NNAPI_OperandCode.TENSOR_INT32:
+            return f"torch.zeros({shape_code}, dtype=torch.int32)"
+        elif oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM:
+            return (
+                f"torch.quantize_per_tensor("
+                f"torch.zeros(1), scale={oper.scale}, zero_point={oper.zero_point}, dtype=torch.quint8)"
+                f".expand({shape_code}).contiguous()"
+            )
+        elif oper.op_type in (
+            NNAPI_OperandCode.TENSOR_QUANT16_ASYMM,
+            NNAPI_OperandCode.TENSOR_QUANT16_SYMM,
+        ):
+            if self.use_int16_for_qint16:
+                return f"torch.zeros({shape_code}, dtype=torch.int16)"
+            else:
+                raise Exception(  # noqa: TRY002
+                    "`int16` isn't supported. If you're trying to represent NNAPI"
+                    " qint16 with PyTorch int16, set `use_int16_for_qint16 = True`"
+                )
+
+        raise Exception(  # noqa: TRY002
+            f"Unsupported output operand type: {oper.op_type}"
+        )  # noqa: TRY002
+
+    def forward_operand_shape(self, out_op_id, out_dim, in_op_id, in_dim):
+        self.compute_operand_shape(out_op_id, out_dim, flex_name(in_op_id, in_dim))
+
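+    # Emit one line of the generated prepare() method: assign the flexible-size
+    # variable for (op_id, dim) (see flex_name) from the given expression.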
+    def compute_operand_shape(self, op_id, dim, expr):
+        self.flexible_shape_computation_lines.append(
+            f"{flex_name(op_id, dim)} = {expr}"
+        )
+
+    def transpose_to_nhwc(self, in_id, oper):
+        if oper.shape[2:] != (1, 1):
+            raise Exception(  # noqa: TRY002
+                "Automatic transpose only supported for H,W == 1,1"
+            )  # noqa: TRY002
+
+        out_oper = oper._replace(dim_order=DimOrder.CHANNELS_LAST)
+
+        inputs = [None] * 2
+        inputs[0] = in_id
+        inputs[1] = self.add_immediate_int_vector([0, 2, 3, 1])
+
+        outputs = [None] * 1
+        outputs[0] = self.add_anonymous_tensor_operand(out_oper)
+
+        self.add_operation(NNAPI_OperationCode.TRANSPOSE, inputs, outputs)
+
+        return outputs[0], out_oper
+
+    # Transpose inputs as necessary to allow broadcasting.
+    def transpose_for_broadcast(self, in0_id, in0_oper, in1_id, in1_oper):
+        if in0_oper.dim_order == in1_oper.dim_order:
+            return in0_id, in0_oper, in1_id, in1_oper
+
+        # Assume NHWC is preferred if there is a mismatch.
+        orders = (in0_oper.dim_order, in1_oper.dim_order)
+        if orders == (DimOrder.PRESUMED_CONTIGUOUS, DimOrder.CHANNELS_LAST):
+            return self.transpose_to_nhwc(in0_id, in0_oper) + (in1_id, in1_oper)
+        if orders == (DimOrder.CHANNELS_LAST, DimOrder.PRESUMED_CONTIGUOUS):
+            return (in0_id, in0_oper) + self.transpose_to_nhwc(in1_id, in1_oper)
+
+        raise Exception(  # noqa: TRY002
+            f"Automatic transpose not supported for dim_orders: {in0_oper.dim_order!r}, {in1_oper.dim_order!r}"
+        )
+
+    def get_size_arg(self, jitval):
+        ctype, value = self.get_constant_value(jitval)
+        if ctype.kind() == "ListType":
+            assert ctype.getElementType().kind() == "IntType"
+            return value
+        raise Exception(  # noqa: TRY002
+            f"Can't handle size arg of type '{ctype!r}' for '{jitval!r}'"
+        )  # noqa: TRY002
+
+    def get_conv_pool_args_2d_from_pack(self, kernel_size, packed_config):
+        pc = [i.item() for i in packed_config]
+        assert pc[0] == 2
+        strides = [pc[1], pc[2]]
+        paddings = [pc[3], pc[4]]
+        dilations = [pc[5], pc[6]]
+        output_padding = [pc[7], pc[8]]
+        group_num = pc[9]
+
+        assert len(pc) == 11
+        assert output_padding == [0, 0]
+
+        return self.get_conv_pool_args_2d_common(
+            kernel_size, strides, paddings, dilations, group_num
+        )
+
+    def get_conv_pool_args_2d_from_jit(
+        self, kernel_size, stride, padding, dilation=None, group=None
+    ):
+        strides = self.get_size_arg(stride)
+        paddings = self.get_size_arg(padding)
+        if dilation is None:
+            dilations = [1, 1]
+        else:
+            dilations = self.get_size_arg(dilation)
+        if group is not None:
+            _, group_num = self.get_constant_value(group, "IntType")
+        else:
+            group_num = None
+        return self.get_conv_pool_args_2d_common(
+            kernel_size, strides, paddings, dilations, group_num
+        )
+
+    def get_conv_pool_args_2d_common(
+        self, kernel_size, strides, paddings, dilations, group_num
+    ):
+        kernels = list(kernel_size)
+
+        assert len(kernels) == 2
+        assert len(strides) == 2
+        assert len(paddings) == 2
+        assert len(dilations) == 2
+
+        # NNAPI uses 4 values for padding.
+        ph, pw = paddings
+        real_paddings = [ph, ph, pw, pw]
+
+        return ConvPoolArgs2d(
+            *(kernels + strides + real_paddings + dilations + [group_num])
+        )
+
+    def serialize_model(self, model, inputs, return_shapes=None):
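+        # Register the two boolean immediates up front; later calls with the
+        # same values reuse these cached operand IDs (e.g., for layout flags).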
+        self.add_immediate_bool_scalar(False)
+        self.add_immediate_bool_scalar(True)
+
+        inp_dim_orders = []
+        out_dim_orders = []
+
+        self_jitval = next(model.graph.inputs())
+        self.add_constant_value(self_jitval, self_jitval.type(), model)
+
+        for arg_idx, (input_value, input_tensor) in enumerate(
+            zip(list(model.graph.inputs())[1:], inputs)
+        ):
+            op_id = self.add_tensor_operand_for_input(
+                arg_idx, input_value, input_tensor
+            )
+            inp_dim_orders.append(self.operands[op_id].dim_order.value)
+
+        for idx, node in enumerate(model.graph.nodes()):
+            LOG.debug("Processing node #%d: %r", idx, node)
+            self.add_node(node)
+
+        retn = model.graph.return_node()
+        assert retn.inputsSize() == 1
+        assert retn.outputsSize() == 0
+        retn_input = retn.inputsAt(0)
+        template_return_lines = ["return ["]
+        if retn_input.type().kind() == "TensorType":
+            return_values = [retn_input]
+            retval_count = -1
+        elif retn_input.type().kind() == "TupleType":
+            return_values = self.tensor_sequences[retn_input]
+            retval_count = len(return_values)
+        else:
+            raise Exception(  # noqa: TRY002
+                f"Unsupported return type: {retn_input.type()}"
+            )  # noqa: TRY002
+
+        if return_shapes is not None:
+            assert len(return_shapes) == len(return_values)
+        for i, v in enumerate(return_values):
+            op_id = self.jitval_operand_map[v]
+            self.outputs.append(op_id)
+            out_dim_orders.append(self.operands[op_id].dim_order.value)
+            shape = return_shapes[i] if return_shapes else None
+            template_return_lines.append(
+                self.operand_to_template_torchscript(op_id, self.operands[op_id], shape)
+                + ","
+            )
+        template_return_lines.append("]")
+
+        model = []
+
+        version = 1
+        header = struct.pack(
+            "iiiiii",
+            version,
+            len(self.operands),
+            len(self.values),
+            len(self.operations),
+            len(self.inputs),
+            len(self.outputs),
+        )
+        model.append(header)
+
+        serialized_values, serialized_value_data = self.serialize_values()
+
+        model.extend(
+            struct.pack("iifi", t, len(d), s, z) for (t, d, _m, s, z) in self.operands
+        )
+        model.extend(serialized_values)
+        model.extend(struct.pack("iii", *x) for x in self.operations)
+
+        # Compact the model so we can get its length so far.
+        model = [b"".join(model)]
+        model_offset = len(model[0])
+        # Model offset is the index into the model (in 32-bit words, not bytes)
+        # of the next dimension we're about to serialize.  If that dimension is
+        # 0 (load-time flexible), generate code that patches this position in
+        # ser_model before it is passed to NNAPI.
+        assert model_offset % 4 == 0
+        model_offset = int(model_offset / 4)
+
+        for op_id, (_, dims, dim_order, _, _) in enumerate(self.operands):
+            shape = fix_shape(dims, dim_order)
+            for d, s in enumerate(shape):
+                if s == 0:
+                    pt_d = reverse_map_dim(dim_order, d)
+                    self.flexible_shape_computation_lines.append(
+                        f"ser_model[{model_offset}] = {flex_name(op_id, pt_d)}"
+                    )
+                model_offset += 1
+
+            # convert runtime flex shape from -1 to 0
+            shape = tuple(d if d != -1 else 0 for d in shape)
+            model.append(self.serialize_ints(shape))
+
+        model.extend(serialized_value_data)
+        model.append(self.serialize_ints(self.operation_args))
+        model.append(self.serialize_ints(self.inputs))
+        model.append(self.serialize_ints(self.outputs))
+
+        self.flexible_shape_computation_lines.extend(template_return_lines)
+
+        return (
+            array.array("i", b"".join(model)),
+            self.used_weights,
+            inp_dim_orders,
+            out_dim_orders,
+            self.flexible_shape_computation_lines,
+            retval_count,
+        )
+
+    def serialize_values(self):
+        serialized_values = []
+        serialized_value_data = []
+        assert len(self.values) == len(self.value_data)
+        for (op_index, source_type), data in zip(self.values, self.value_data):
+            source_length = len(data)
+
+            # Pad with 0 bytes out to a multiple of 4 for alignment.
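+            # e.g. a 5-byte immediate is padded to 8 bytes; 4 bytes stays 4.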
+            physical_length = ((source_length - 1) | 0x3) + 1
+            padded_data = data + (b"\0" * (physical_length - source_length))
+
+            serialized_values.append(
+                struct.pack("iii", op_index, source_type, source_length)
+            )
+            serialized_value_data.append(padded_data)
+
+        return serialized_values, serialized_value_data
+
+    @staticmethod
+    def serialize_ints(ints):
+        return array.array("i", ints).tobytes()
+
+    ADDER_MAP = {
+        "prim::GetAttr": lambda self, node: self.add_getattr(node),
+        "prim::Constant": lambda self, node: self.add_constant_node(node),
+        "prim::ListConstruct": lambda self, node: self.add_list_construct(node),
+        "prim::TupleConstruct": lambda self, node: self.add_tuple_construct(node),
+        "aten::unsqueeze": lambda self, node: self.add_unsqueeze(node),
+        "aten::to": lambda self, node: self.add_to(node),
+        "aten::detach": lambda self, node: self._identity(node),
+        "aten::reshape": lambda self, node: self.add_reshape(node),
+        "aten::flatten": lambda self, node: self.add_flatten(node),
+        "aten::slice": lambda self, node: self.add_slice(node),
+        "aten::size": lambda self, node: self.add_size(node),
+        "aten::cat": lambda self, node: self.add_cat(node),
+        "aten::mean": lambda self, node: self.add_mean(node),
+        "aten::quantize_per_tensor": lambda self, node: self.add_quantize(node),
+        "aten::dequantize": lambda self, node: self.add_dequantize(node),
+        "aten::add": lambda self, node: self.add_add_sub_op(
+            node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_NONE
+        ),
+        "aten::sub": lambda self, node: self.add_add_sub_op(
+            node, NNAPI_OperationCode.SUB, NNAPI_FuseCode.FUSED_NONE
+        ),
+        "aten::mul": lambda self, node: self.add_pointwise_simple_binary_broadcast_op(
+            node, NNAPI_OperationCode.MUL, NNAPI_FuseCode.FUSED_NONE
+        ),
+        "aten::div": lambda self, node: self.add_pointwise_simple_binary_broadcast_op(
+            node, NNAPI_OperationCode.DIV, NNAPI_FuseCode.FUSED_NONE
+        ),
+        "aten::relu": lambda self, node: self.add_pointwise_simple_unary_op(
+            node, NNAPI_OperationCode.RELU
+        ),
+        "aten::sigmoid": lambda self, node: self.add_pointwise_simple_unary_op(
+            node, NNAPI_OperationCode.LOGISTIC
+        ),
+        "aten::softmax": lambda self, node: self.add_softmax(node),
+        "aten::hardtanh": lambda self, node: self.add_hardtanh(node),
+        "aten::avg_pool2d": lambda self, node: self.add_avg_pool2d(node),
+        "aten::max_pool2d": lambda self, node: self.add_pool2d_node(
+            node, NNAPI_OperationCode.MAX_POOL_2D
+        ),
+        "aten::adaptive_avg_pool2d": lambda self, node: self.add_adaptive_avg_pool2d(
+            node
+        ),
+        "aten::upsample_nearest2d": lambda self, node: self.add_upsample_nearest2d(
+            node
+        ),
+        "aten::prelu": lambda self, node: self.add_prelu_op(node),
+        "aten::addmm": lambda self, node: self.add_addmm(node),
+        "aten::linear": lambda self, node: self.add_linear(node),
+        "aten::_convolution": lambda self, node: self.add_conv_underscore(node),
+        "aten::conv2d": lambda self, node: self.add_conv2d(node),
+        "aten::log_softmax": lambda self, node: self.add_log_softmax(node),
+        "quantized::linear": lambda self, node: self.add_qlinear(node),
+        "quantized::conv2d": lambda self, node: self.add_qconv2d(
+            node, NNAPI_FuseCode.FUSED_NONE
+        ),
+        "quantized::conv2d_relu": lambda self, node: self.add_qconv2d(
+            node, NNAPI_FuseCode.FUSED_RELU
+        ),
+        "quantized::conv_transpose2d": lambda self, node: self.add_qconv2d(
+            node, NNAPI_FuseCode.FUSED_NONE, transpose=True
+        ),
+        "quantized::add": lambda self, node: self.add_qadd(
+            node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_NONE
+        ),
+        "quantized::add_relu": lambda self, node: self.add_qadd(
+            node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_RELU
+        ),
+        "quantized::mul": lambda self, node: self.add_qadd(
+            node, NNAPI_OperationCode.MUL, NNAPI_FuseCode.FUSED_NONE
+        ),
+    }
+
+    def add_node(self, node):
+        adder = self.ADDER_MAP.get(node.kind())
+        if not adder:
+            raise Exception(  # noqa: TRY002
+                f"Unsupported node kind ({node.kind()!r}) in node {node!r}"
+            )  # noqa: TRY002
+        adder(self, node)
+
+    def _identity(self, node):
+        in_id, _in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
+        jitval = node.outputsAt(0)
+        self.jitval_operand_map[jitval] = in_id
+
+    def add_getattr(self, node):
+        assert node.inputsSize() == 1
+        assert node.outputsSize() == 1
+        obj_ctype, obj = self.get_constant_value(node.inputsAt(0))
+        assert str(obj_ctype).startswith("__torch__.")
+        name = node.s("name")
+        value = getattr(obj, name)
+        output = node.outputsAt(0)
+        ctype = output.type()
+        self.add_constant_value(output, ctype, value)
+
+    def add_constant_node(self, node):
+        assert node.inputsSize() == 0
+        assert node.outputsSize() == 1
+        output = node.outputsAt(0)
+        ctype = output.type()
+        value = output.toIValue()
+        self.add_constant_value(output, ctype, value)
+
+    def add_list_construct(self, node):
+        assert node.outputsSize() == 1
+        output = node.outputsAt(0)
+        ctype = output.type()
+        const_vals: Optional[list] = []
+        tensors: Optional[list] = []
+        for inp in node.inputs():
+            if const_vals is not None and inp in self.constants:
+                _, val = self.get_constant_value(inp)
+                const_vals.append(val)
+            else:
+                const_vals = None
+            if tensors is not None and inp.type().kind() == "TensorType":
+                tensors.append(inp)
+            else:
+                tensors = None
+
+        if const_vals is not None:
+            # NOTE: Now that TorchScript supports list constants,
+            # this code path might not be used anymore.
+            self.add_constant_value(output, ctype, const_vals)
+        if tensors is not None:
+            self.add_tensor_sequence(output, tensors)
+        if const_vals is None and tensors is None:
+            raise Exception(  # noqa: TRY002
+                f"Unable to handle ListConstruct node.  Neither all constants nor all tensors. {node!r}"
+            )
+
+    def add_tuple_construct(self, node):
+        assert node.outputsSize() == 1
+        output = node.outputsAt(0)
+        values = list(node.inputs())
+        self.add_tensor_sequence(output, values)
+
+    def add_unsqueeze(self, node):
+        assert node.inputsSize() == 2
+        assert node.outputsSize() == 1
+
+        in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
+
+        _, dim = self.get_constant_value(node.inputsAt(1), "IntType")
+        assert in_oper.dim_order == DimOrder.PRESUMED_CONTIGUOUS
+
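+        # Normalize a negative dim only for computing the output shape; the
+        # original dim value is what gets passed to EXPAND_DIMS below.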
+        real_dim = dim if dim >= 0 else dim + len(in_oper.shape) + 1
+        out_shape_list = list(in_oper.shape)
+        out_shape_list.insert(real_dim, 1)
+        out_shape = tuple(out_shape_list)
+        out_oper = in_oper._replace(shape=out_shape)
+
+        inputs = [None] * 2
+        inputs[0] = in_id
+        inputs[1] = self.add_immediate_int_scalar(dim)
+
+        outputs = [None] * 1
+        outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)
+
+        self.add_operation(NNAPI_OperationCode.EXPAND_DIMS, inputs, outputs)
+
+    def add_to(self, node):
+        # Handle to("cpu") / to("gpu") case
+        self._identity(node)
+
+    def add_reshape(self, node):
+        assert node.inputsSize() == 2
+        assert node.outputsSize() == 1
+
+        in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
+
+        shape_ctype, shape = self.get_constant_value(node.inputsAt(1))
+        assert shape_ctype.kind() == "ListType"
+        assert shape_ctype.getElementType().kind() == "IntType"
+        is_trivial_reshape = len(shape) == 2 and shape[1] == -1
+
+        if in_oper.dim_order != DimOrder.PRESUMED_CONTIGUOUS and not is_trivial_reshape:
+            raise Exception(  # noqa: TRY002
+                "Currently, reshape is only supported on NHWC tensors if the target size is [X, -1]."
+            )
+
+        # Bit of a hack here.  Use a real tensor to infer the output shape.
+        out_shape = torch.zeros(1).expand(in_oper.shape).reshape(shape).shape
+        out_oper = in_oper._replace(
+            shape=out_shape, dim_order=DimOrder.PRESUMED_CONTIGUOUS
+        )
+
+        inputs = [None] * 2
+        inputs[0] = in_id
+        inputs[1] = self.add_immediate_int_vector(shape)
+
+        outputs = [None] * 1
+        outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)
+
+        self.add_operation(NNAPI_OperationCode.RESHAPE, inputs, outputs)
+
+    def add_flatten(self, node):
+        assert node.inputsSize() == 3
+        assert node.outputsSize() == 1
+
+        in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
+
+        _start_ctype, start_dim = self.get_constant_value(node.inputsAt(1), "IntType")
+        _end_ctype, end_dim = self.get_constant_value(node.inputsAt(2), "IntType")
+
+        # channels last with channels == 1 or (height & width both 1)
+        is_trivial_flatten = len(in_oper.shape) == 4 and (
+            in_oper.shape[1] == 1 or (in_oper.shape[2] == 1 and in_oper.shape[3] == 1)
+        )
+        if in_oper.dim_order != DimOrder.PRESUMED_CONTIGUOUS and not is_trivial_flatten:
+            raise Exception(  # noqa: TRY002
+                "Currently, flatten is not supported on NHWC tensors unless C=1 or H=W=1"
+            )
+
+        if start_dim < 0:
+            start_dim += len(in_oper.shape)
+        if end_dim < 0:
+            end_dim += len(in_oper.shape)
+
+        out_shape = (
+            in_oper.shape[:start_dim]
+            + (functools.reduce(operator.mul, in_oper.shape[start_dim : end_dim + 1]),)
+            + in_oper.shape[end_dim + 1 :]
+        )
+
+        if any(dim == 0 for dim in in_oper.shape[start_dim : end_dim + 1]):
+            raise Exception(  # noqa: TRY002
+                "Flattening flexible dims is not supported yet"
+            )  # noqa: TRY002
+        non_flattened_dims = in_oper.shape[:start_dim] + in_oper.shape[end_dim + 1 :]
+        if non_flattened_dims.count(0) > 1:
+            raise Exception("Only 1 dim can be flexible")  # noqa: TRY002
+
+        out_oper = in_oper._replace(
+            shape=out_shape, dim_order=DimOrder.PRESUMED_CONTIGUOUS
+        )
+        out_id = self.add_tensor_operand(node.outputsAt(0), out_oper)
+
+        for idx, dim in enumerate(out_shape):
+            if dim == 0:
+                self.forward_operand_shape(out_id, idx, in_id, in_oper.shape.index(0))
+
+        inputs_1 = tuple(dim if dim != 0 else -1 for dim in out_shape)
+        inputs = [None] * 2
+        inputs[0] = in_id
+        inputs[1] = self.add_immediate_int_vector(inputs_1)
+
+        outputs = [None] * 1
+        outputs[0] = out_id
+
+        self.add_operation(NNAPI_OperationCode.RESHAPE, inputs, outputs)
+
+    def add_slice(self, node):
+        assert node.inputsSize() == 5
+        assert node.outputsSize() == 1
+
+        in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
+        _, dim_value = self.get_constant_value(node.inputsAt(1))
+        _, start_value = self.get_constant_value(node.inputsAt(2))
+        _, stop_value = self.get_constant_value(node.inputsAt(3))
+        _, step_value = self.get_constant_value(node.inputsAt(4))
+
+        if start_value is None:
+            start_value = 0
+        if stop_value is None:
+            stop_value = sys.maxsize
+
+        if start_value < 0:
+            start_value += in_oper.shape[dim_value]
+        elif start_value == sys.maxsize:
+            start_value = 0
+
+        if start_value == 0 and stop_value == sys.maxsize:
+            self._identity(node)
+            return
+
+        if in_oper.shape[dim_value] == 0:
+            raise Exception("Unable to slice with flexible shape")  # noqa: TRY002
+
+        if stop_value < 0:
+            stop_value += in_oper.shape[dim_value]
+        elif stop_value == sys.maxsize:
+            stop_value = in_oper.shape[dim_value]
+
+        if start_value >= stop_value:
+            raise Exception(  # noqa: TRY002
+                "Slice start value should be less than stop value"
+            )  # noqa: TRY002
+
+        out_len = (stop_value - start_value) // step_value
+        out_shape = tuple(
+            out_len if i == dim_value else dim for i, dim in enumerate(in_oper.shape)
+        )
+        out_id = self.add_tensor_operand(
+            node.outputsAt(0), in_oper._replace(shape=out_shape)
+        )
+
+        # flex inputs
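+        # For flexible (size-0) dims, forward the input's size to the output and
+        # set the corresponding end_mask bit so STRIDED_SLICE takes the full
+        # extent of that dimension instead of the stop value.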
+        end_mask = 0
+        for idx, dim in enumerate(out_shape):
+            if dim == 0:
+                self.forward_operand_shape(out_id, idx, in_id, idx)
+                end_mask |= 1 << idx
+
+        inputs = [None] * 7
+        inputs[0] = in_id
+        inputs[1] = self.add_immediate_int_vector(
+            [start_value if i == dim_value else 0 for i in range(len(in_oper.shape))]
+        )
+        inputs[2] = self.add_immediate_int_vector(
+            [
+                stop_value if i == dim_value else dim
+                for i, dim in enumerate(in_oper.shape)
+            ]
+        )
+        inputs[3] = self.add_immediate_int_vector(
+            [step_value if i == dim_value else 1 for i in range(len(in_oper.shape))]
+        )
+        inputs[4] = self.add_immediate_int_scalar(0)  # begin mask
+        inputs[5] = self.add_immediate_int_scalar(end_mask)
+        inputs[6] = self.add_immediate_int_scalar(0)  # shrink axis mask
+
+        outputs = [None] * 1
+        outputs[0] = out_id
+
+        self.add_operation(NNAPI_OperationCode.STRIDED_SLICE, inputs, outputs)
+
+    def add_size(self, node):
+        assert node.inputsSize() == 2
+        assert node.outputsSize() == 1
+
+        _, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
+        _, value = self.constants[node.inputsAt(1)]
+        res = in_oper.shape[value]
+        output = node.outputsAt(0)
+        self.add_constant_value(output, output.type(), res)
+
+    def add_cat(self, node):
+        assert node.inputsSize() == 2
+        assert node.outputsSize() == 1
+
+        tensors = self.tensor_sequences[node.inputsAt(0)]
+        _, dim = self.get_constant_value(node.inputsAt(1), "IntType")
+
+        assert len(tensors) > 0
+        in_ids = []
+        out_oper = None
+        out_dim_size = 0
+        for inp in tensors:
+            in_id, in_oper = self.get_tensor_operand_by_jitval(inp)
+            if out_oper is None:
+                out_shape = change_element(in_oper.shape, dim, -1)
+                out_oper = in_oper._replace(shape=out_shape)
+            assert in_oper.op_type == out_oper.op_type
+            assert in_oper.dim_order == out_oper.dim_order
+            assert change_element(in_oper.shape, dim, -1) == change_element(
+                out_oper.shape, dim, -1
+            )
+            # TODO: Possibly check scale and zero point.
+            in_ids.append(in_id)
+            # TODO: Possibly support variable-sized inputs.
+            out_dim_size += in_oper.shape[dim]
+
+        assert out_oper is not None
+        out_oper = out_oper._replace(
+            shape=change_element(out_oper.shape, dim, out_dim_size)
+        )
+
+        if in_oper.dim_order == DimOrder.CHANNELS_LAST:  # type: ignore[possibly-undefined]
+            assert len(out_oper.shape) == 4
+            nnapi_dim = [0, 3, 1, 2][dim]
+        else:
+            nnapi_dim = dim
+
+        out_id = self.add_tensor_operand(node.outputsAt(0), out_oper)
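+        # Flexible dims: the concat dim's size is the sum of the inputs' sizes;
+        # any other flexible dim is forwarded from the first input.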
+        for idx, d in enumerate(out_oper.shape):
+            if d == 0:
+                if idx == dim:
+                    shape = " + ".join(flex_name(ip_id, dim) for ip_id in in_ids)
+                    self.compute_operand_shape(out_id, idx, shape)
+                else:
+                    self.forward_operand_shape(out_id, idx, in_ids[0], idx)
+
+        inputs = in_ids + [self.add_immediate_int_scalar(nnapi_dim)]
+
+        outputs = [None] * 1
+        outputs[0] = out_id
+
+        self.add_operation(NNAPI_OperationCode.CONCATENATION, inputs, outputs)
+
+    def add_mean(self, node):
+        assert node.inputsSize() == 4
+        assert node.outputsSize() == 1
+
+        in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
+        dim_ctype, dim = self.get_constant_value(node.inputsAt(1))
+        assert dim_ctype.kind() == "ListType"
+        assert dim_ctype.getElementType().kind() == "IntType"
+        _, keep_dim = self.get_constant_value(node.inputsAt(2), "BoolType")
+        # Expect None for dtype
+        self.get_constant_value(node.inputsAt(3), "NoneType")
+
+        if in_oper.dim_order == DimOrder.CHANNELS_LAST:
+            assert len(in_oper.shape) == 4
+            nnapi_dim = [[0, 3, 1, 2][d] for d in dim]
+        else:
+            nnapi_dim = dim
+
+        collapsed_dims = set()
+        for d in dim:
+            if d < 0:
+                d += len(in_oper.shape)
+            collapsed_dims.add(d)
+
+        if in_oper.dim_order == DimOrder.CHANNELS_LAST and not keep_dim:
+            assert collapsed_dims.issuperset({2, 3})
+            out_dim_order = DimOrder.PRESUMED_CONTIGUOUS
+        else:
+            out_dim_order = in_oper.dim_order
+
+        out_shape = []
+        for i, s in enumerate(in_oper.shape):
+            if i not in collapsed_dims:
+                out_shape.append(s)
+            elif keep_dim:
+                out_shape.append(1)
+
+        out_oper = in_oper._replace(shape=out_shape, dim_order=out_dim_order)
+
+        inputs = [None] * 3
+        inputs[0] = in_id
+        inputs[1] = self.add_immediate_int_vector(nnapi_dim)
+        inputs[2] = self.add_immediate_int_scalar(keep_dim)
+
+        outputs = [None] * 1
+        outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)
+
+        self.add_operation(NNAPI_OperationCode.MEAN, inputs, outputs)
+
+    def add_quantize(self, node):
+        assert node.inputsSize() == 4
+        assert node.outputsSize() == 1
+
+        in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
+        if in_oper.dim_order != DimOrder.CHANNELS_LAST:
+            raise Exception(  # noqa: TRY002
+                "Most hardware backends prefer NHWC quantized tensors.  "
+                "Try setting `t.nnapi_nhwc = True` on your tensor inputs.  "
+            )
+        _, scale = self.get_constant_value(node.inputsAt(1), "FloatType")
+        _, zero_point = self.get_constant_value(node.inputsAt(2), "IntType")
+        _, scalar_type = self.get_constant_value(node.inputsAt(3), "IntType")
+        if scalar_type != TorchScalarTypes.QUINT8.value:
+            raise Exception(  # noqa: TRY002
+                "PyTorch NNAPI export only supports quantized tensors "
+                "with the quint8 dtype."
+            )
+        op_type = NNAPI_OperandCode.TENSOR_QUANT8_ASYMM
+
+        out_oper = in_oper._replace(
+            op_type=op_type,
+            scale=scale,
+            zero_point=zero_point,
+        )
+
+        inputs = [None] * 1
+        inputs[0] = in_id
+
+        outputs = [None] * 1
+        outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)
+
+        self.add_operation(NNAPI_OperationCode.QUANTIZE, inputs, outputs)
+
+    def add_dequantize(self, node):
+        assert node.inputsSize() == 1
+        assert node.outputsSize() == 1
+
+        in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
+        out_oper = in_oper._replace(
+            op_type=NNAPI_OperandCode.TENSOR_FLOAT32,
+            scale=0.0,
+            zero_point=0,
+        )
+
+        inputs = [None] * 1
+        inputs[0] = in_id
+
+        outputs = [None] * 1
+        outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)
+
+        self.add_operation(NNAPI_OperationCode.DEQUANTIZE, inputs, outputs)
+
+    def add_pointwise_simple_unary_op(self, node, opcode):
+        assert node.inputsSize() == 1
+        assert node.outputsSize() == 1
+
+        in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
+
+        out_oper = in_oper
+        if opcode == NNAPI_OperationCode.LOGISTIC:
+            # NNAPI docs: For ANEURALNETWORKS_TENSOR_QUANT8_ASYMM, the scale
+            # must be 1.f / 256 and the zeroPoint must be 0.
+            # https://fburl.com/h52stoog
+            if in_oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM:
+                out_oper = in_oper._replace(zero_point=0, scale=1.0 / 256)
+
+        out_id = self.add_tensor_operand(node.outputsAt(0), out_oper)
+
+        for idx, dim in enumerate(in_oper.shape):
+            if dim == 0:
+                self.forward_operand_shape(out_id, idx, in_id, idx)
+
+        inputs = [None] * 1
+        inputs[0] = in_id
+
+        outputs = [None] * 1
+        outputs[0] = out_id
+
+        self.add_operation(opcode, inputs, outputs)
+
+    def _do_add_binary(self, node, opcode, fuse_code, *, qparams=None):  # noqa: D401
+        """Helper for pointwise binary broadcast ops with superfluous extra args."""
+        assert node.outputsSize() == 1
+
+        assert node.inputsAt(0).type().kind() == "TensorType"
+        assert node.inputsAt(1).type().kind() == "TensorType"
+
+        if self.has_operand_for_jitval(node.inputsAt(0)):
+            in0_id, in0_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
+            in1_id, in1_oper = self.get_tensor_operand_or_constant(
+                node.inputsAt(1), in0_oper.dim_order
+            )
+        elif self.has_operand_for_jitval(node.inputsAt(1)):
+            in1_id, in1_oper = self.get_tensor_operand_by_jitval(node.inputsAt(1))
+            in0_id, in0_oper = self.get_tensor_operand_or_constant(
+                node.inputsAt(0), in1_oper.dim_order
+            )
+        else:
+            raise Exception(  # noqa: TRY002
+                f"Can't do a NNAPI binary op: {opcode} on two constants"
+            )  # noqa: TRY002
+
+        assert in0_oper.op_type == in1_oper.op_type
+        in0_id, in0_oper, in1_id, in1_oper = self.transpose_for_broadcast(
+            in0_id, in0_oper, in1_id, in1_oper
+        )
+        # NOTE: PyTorch and NNAPI have the same broadcast semantics.
+        out_shape = broadcast_shapes(in0_oper.shape, in1_oper.shape)
+        out_oper = in0_oper._replace(shape=out_shape)
+        if qparams is not None:
+            scale, zp = qparams
+            out_oper = out_oper._replace(scale=scale, zero_point=zp)
+
+        out_id = self.add_tensor_operand(node.outputsAt(0), out_oper)
+        for idx, (d0, d1) in enumerate(zip(in0_oper.shape, in1_oper.shape)):
+            if d0 == 1 and d1 == 0:
+                self.forward_operand_shape(out_id, idx, in1_id, idx)
+            elif d0 == 0 and d1 == 1:
+                self.forward_operand_shape(out_id, idx, in0_id, idx)
+            elif d0 == 0 and d1 == 0:
+                self.flexible_shape_computation_lines.append(
+                    f"assert {flex_name(in0_id, idx)} == {flex_name(in1_id, idx)}"
+                )
+                self.forward_operand_shape(out_id, idx, in0_id, idx)
+
+        inputs = [None] * 3
+        inputs[0] = in0_id
+        inputs[1] = in1_id
+        inputs[2] = self.add_immediate_int_scalar(fuse_code)
+
+        outputs = [None] * 1
+        outputs[0] = out_id
+
+        self.add_operation(opcode, inputs, outputs)
+
+    def add_pointwise_simple_binary_broadcast_op(self, node, opcode, fuse_code):
+        assert node.inputsSize() == 2
+        self._do_add_binary(node, opcode, fuse_code)
+
+    def add_add_sub_op(self, node, opcode, fuse_code):
+        assert node.inputsSize() == 3
+
+        _, alpha = self.get_constant_value(node.inputsAt(2), "IntType")
+        if alpha != 1:
+            raise Exception(  # noqa: TRY002
+                "NNAPI does not support add/sub with alpha."
+            )  # noqa: TRY002
+
+        self._do_add_binary(node, opcode, fuse_code)
+
+    def add_qadd(self, node, opcode, fuse_code):
+        assert node.inputsSize() == 4
+
+        _, scale = self.get_constant_value(node.inputsAt(2), "FloatType")
+        _, zero_point = self.get_constant_value(node.inputsAt(3), "IntType")
+
+        self._do_add_binary(node, opcode, fuse_code, qparams=(scale, zero_point))
+
+    def add_softmax(self, node):
+        assert node.inputsSize() == 3
+        in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
+
+        _, softmax_dim = self.get_constant_value(node.inputsAt(1), "IntType")
+
+        out_id = self.add_tensor_operand(node.outputsAt(0), in_oper)
+        for dim, size in enumerate(in_oper.shape):
+            if size == 0:
+                self.forward_operand_shape(out_id, dim, in_id, dim)
+
+        inputs = [None] * 3
+        inputs[0] = in_id
+        inputs[1] = self.add_immediate_float_scalar(
+            1.0
+        )  # positive scaling factor of exponent, beta
+        inputs[2] = self.add_immediate_int_scalar(softmax_dim)
+
+        outputs = [None] * 1
+        outputs[0] = out_id
+
+        self.add_operation(NNAPI_OperationCode.SOFTMAX, inputs, outputs)
+
+    def add_hardtanh(self, node):
+        assert node.inputsSize() == 3
+        assert node.outputsSize() == 1
+
+        in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
+        _, min_val = self.get_constant_value(node.inputsAt(1), "FloatType")
+        _, max_val = self.get_constant_value(node.inputsAt(2), "FloatType")
+
+        op_map = {
+            (-1, 1): NNAPI_OperationCode.RELU1,
+            (0, 6): NNAPI_OperationCode.RELU6,  # noqa: E201
+        }
+
+        opcode = op_map.get((min_val, max_val))
+        if opcode is None:
+            raise Exception(  # noqa: TRY002
+                "NNAPI only supports hardtanh with args (-1, 1) or (0, 6)."
+            )  # noqa: TRY002
+
+        inputs = [None] * 1
+        inputs[0] = in_id
+
+        outputs = [None] * 1
+        outputs[0] = self.add_tensor_operand(node.outputsAt(0), in_oper)
+
+        self.add_operation(opcode, inputs, outputs)
+
+    def add_prelu_op(self, node):
+        assert node.inputsSize() == 2
+        assert node.outputsSize() == 1
+
+        assert node.inputsAt(0).type().kind() == "TensorType"
+        assert node.inputsAt(1).type().kind() == "TensorType"
+
+        in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
+        w_id, w_oper = self.get_tensor_operand_for_weight(node.inputsAt(1))
+        assert len(w_oper.shape) == 1
+        assert w_oper.shape[0] > 0
+        if w_oper.shape[0] > 1:
+            if in_oper.use_nchw():
+                # TODO: Support this by adding trailing 1 dims.
+                raise Exception(  # noqa: TRY002
+                    "Per-channel PReLU only supports channels_last right now."
+                )
+
+        out_id = self.add_tensor_operand(node.outputsAt(0), in_oper)
+        for dim, size in enumerate(in_oper.shape):
+            if size > 0:
+                pass
+            elif dim <= 1:
+                raise Exception(  # noqa: TRY002
+                    "PReLU requires fixed size for dim 0 and dim 1."
+                )  # noqa: TRY002
+            else:
+                self.forward_operand_shape(out_id, dim, in_id, dim)
+
+        inputs = [None] * 2
+        inputs[0] = in_id
+        inputs[1] = w_id
+
+        outputs = [None] * 1
+        outputs[0] = out_id
+
+        self.add_operation(NNAPI_OperationCode.PRELU, inputs, outputs)
+
+    def add_pool2d_node(self, node, opcode):
+        assert node.inputsSize() == 6
+        assert node.outputsSize() == 1
+        image, kernel, stride, padding, dilation, _ceil_mode = node.inputs()
+
+        stride = stride or kernel
+
+        # TODO: Validate ceil_mode semantics.
+
+        args = self.get_conv_pool_args_2d_from_jit(
+            self.get_size_arg(kernel), stride, padding, dilation
+        )
+        if args.dilation_h != 1 or args.dilation_w != 1:
+            raise Exception("NNAPI does not support dilated pooling.")  # noqa: TRY002
+
+        image_id, image_oper = self.get_tensor_operand_by_jitval_fixed_size(image)
+        assert len(image_oper.shape) == 4
+
+        out_shape = get_conv_pool_shape(
+            image_oper.shape, args, image_oper.shape[1], False
+        )
+        use_nchw = image_oper.use_nchw()
+
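+        # Explicit-padding pooling signature: pads (l, r, t, b), strides (w, h),
+        # kernel (w, h), fuse code, then the layout (NCHW) flag.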
+        inputs = [None] * 11
+        inputs[0] = image_id
+        inputs[1] = self.add_immediate_int_scalar(args.pad_l)
+        inputs[2] = self.add_immediate_int_scalar(args.pad_r)
+        inputs[3] = self.add_immediate_int_scalar(args.pad_t)
+        inputs[4] = self.add_immediate_int_scalar(args.pad_b)
+        inputs[5] = self.add_immediate_int_scalar(args.stride_w)
+        inputs[6] = self.add_immediate_int_scalar(args.stride_h)
+        inputs[7] = self.add_immediate_int_scalar(args.kernel_w)
+        inputs[8] = self.add_immediate_int_scalar(args.kernel_h)
+        inputs[9] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE)
+        inputs[10] = self.add_immediate_bool_scalar(use_nchw)
+
+        outputs = [None] * 1
+        outputs[0] = self.add_tensor_operand(
+            node.outputsAt(0), image_oper._replace(shape=out_shape)
+        )
+
+        self.add_operation(opcode, inputs, outputs)
+
+    def add_avg_pool2d(self, node):
+        assert node.inputsSize() == 7
+        assert node.outputsSize() == 1
+        (
+            image,
+            kernel,
+            stride,
+            padding,
+            _ceil_mode,
+            count_include_pad,
+            divisor_override,
+        ) = node.inputs()
+
+        _, count_include_pad_value = self.get_constant_value(count_include_pad)
+        _, divisor_override_value = self.get_constant_value(divisor_override)
+        if not count_include_pad_value or divisor_override_value:
+            raise Exception(  # noqa: TRY002
+                "NNAPI doesn't support count_include_pad=False or divisor_override"
+            )
+
+        args = self.get_conv_pool_args_2d_from_jit(
+            self.get_size_arg(kernel), stride, padding
+        )
+
+        image_id, image_oper = self.get_tensor_operand_by_jitval(image)
+        assert len(image_oper.shape) == 4
+
+        out_shape = get_conv_pool_shape(
+            image_oper.shape, args, image_oper.shape[1], False
+        )
+        use_nchw = image_oper.use_nchw()
+
+        inputs = [None] * 11
+        inputs[0] = image_id
+        inputs[1] = self.add_immediate_int_scalar(args.pad_l)
+        inputs[2] = self.add_immediate_int_scalar(args.pad_r)
+        inputs[3] = self.add_immediate_int_scalar(args.pad_t)
+        inputs[4] = self.add_immediate_int_scalar(args.pad_b)
+        inputs[5] = self.add_immediate_int_scalar(args.stride_w)
+        inputs[6] = self.add_immediate_int_scalar(args.stride_h)
+        inputs[7] = self.add_immediate_int_scalar(args.kernel_w)
+        inputs[8] = self.add_immediate_int_scalar(args.kernel_h)
+        inputs[9] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE)
+        inputs[10] = self.add_immediate_bool_scalar(use_nchw)
+
+        outputs = [None] * 1
+        out_id = self.add_tensor_operand(
+            node.outputsAt(0), image_oper._replace(shape=out_shape)
+        )
+        self._handle_conv_pool_flexible_input(out_id, image, args, False)
+        outputs[0] = out_id
+
+        self.add_operation(NNAPI_OperationCode.AVERAGE_POOL_2D, inputs, outputs)
+
+    def add_adaptive_avg_pool2d(self, node):
+        assert node.inputsSize() == 2
+        assert node.outputsSize() == 1
+
+        image_id, image_oper = self.get_tensor_operand_by_jitval_fixed_size(
+            node.inputsAt(0)
+        )
+        assert len(image_oper.shape) == 4
+
+        size_ctype, size_arg = self.get_constant_value(node.inputsAt(1))
+        assert size_ctype.kind() == "ListType"
+        assert size_ctype.getElementType().kind() == "IntType"
+        if size_arg != [1, 1]:
+            raise Exception(  # noqa: TRY002
+                "NNAPI only supports adaptive_avg_pool2d with output size (1, 1)."
+            )
+
+        out_shape = image_oper.shape[0:2] + tuple(size_arg)
+        use_nchw = image_oper.use_nchw()
+
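+        # adaptive_avg_pool2d with output (1, 1) reduces to a plain AVERAGE_POOL_2D
+        # whose kernel spans the full input H and W, with stride 1 and no padding,
+        # which is exactly what the scalar arguments below encode.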
+        inputs = [None] * 11
+        inputs[0] = image_id
+        inputs[1] = self.add_immediate_int_scalar(0)
+        inputs[2] = self.add_immediate_int_scalar(0)
+        inputs[3] = self.add_immediate_int_scalar(0)
+        inputs[4] = self.add_immediate_int_scalar(0)
+        inputs[5] = self.add_immediate_int_scalar(1)
+        inputs[6] = self.add_immediate_int_scalar(1)
+        inputs[7] = self.add_immediate_int_scalar(image_oper.shape[3])
+        inputs[8] = self.add_immediate_int_scalar(image_oper.shape[2])
+        inputs[9] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE)
+        inputs[10] = self.add_immediate_bool_scalar(use_nchw)
+
+        outputs = [None] * 1
+        outputs[0] = self.add_tensor_operand(
+            node.outputsAt(0), image_oper._replace(shape=out_shape)
+        )
+
+        self.add_operation(NNAPI_OperationCode.AVERAGE_POOL_2D, inputs, outputs)
+
+    def add_upsample_nearest2d(self, node):
+        assert node.inputsSize() == 3 or node.inputsSize() == 4
+        assert node.outputsSize() == 1
+        if node.inputsSize() == 3:
+            image, size_jit, scale_jit = node.inputs()
+        else:
+            image, size_jit, scale_h_jit, scale_w_jit = node.inputs()
+        size_ctype, size_arg = self.get_constant_value(size_jit)
+
+        if node.inputsSize() == 3:
+            scale_ctype, scale_arg = self.get_constant_value(scale_jit)  # type: ignore[possibly-undefined]
+        else:
+            scale_h_ctype, scale_h_arg = self.get_constant_value(scale_h_jit)  # type: ignore[possibly-undefined]
+            scale_w_ctype, _scale_w_arg = self.get_constant_value(scale_w_jit)  # type: ignore[possibly-undefined]
+
+            # The only way for the 4-argument overload of upsample_nearest2d to
+            # have been added to the graph without error is if the scale_h and
+            # scale_w arguments are None
+            assert scale_h_ctype.kind() == "NoneType"
+            assert scale_w_ctype.kind() == "NoneType"
+
+            scale_ctype = scale_h_ctype
+            scale_arg = scale_h_arg
+
+        image_id, image_oper = self.get_tensor_operand_by_jitval(image)
+        assert len(image_oper.shape) == 4
+
+        if size_ctype.kind() != "NoneType" and scale_ctype.kind() != "NoneType":
+            raise Exception("Size and scale cannot both be non-None.")  # noqa: TRY002
+        elif size_ctype.kind() != "NoneType":
+            assert size_ctype.kind() == "ListType"
+            assert size_ctype.getElementType().kind() == "IntType"
+            assert scale_ctype.kind() == "NoneType"
+            assert scale_arg is None
+            assert isinstance(size_arg, list)
+            assert size_arg
+            assert all(isinstance(val, int) for val in size_arg)
+            if len(size_arg) == 1:
+                size_arg = size_arg * 2
+            assert len(size_arg) == 2
+            out_h = size_arg[0]
+            out_w = size_arg[1]
+            arg_h = self.add_immediate_int_scalar(out_h)
+            arg_w = self.add_immediate_int_scalar(out_w)
+        elif scale_ctype.kind() != "NoneType":
+            assert scale_ctype.kind() == "ListType"
+            assert scale_ctype.getElementType().kind() == "FloatType"
+            assert size_ctype.kind() == "NoneType"
+            assert size_arg is None
+            assert isinstance(scale_arg, list)
+            assert scale_arg
+            assert all(isinstance(val, float) for val in scale_arg)
+            if len(scale_arg) == 1:
+                scale_arg = scale_arg * 2
+            assert len(scale_arg) == 2
+            out_h = int(scale_arg[0] * image_oper.shape[2])
+            out_w = int(scale_arg[1] * image_oper.shape[3])
+            arg_h = self.add_immediate_float_scalar(scale_arg[0])
+            arg_w = self.add_immediate_float_scalar(scale_arg[1])
+        else:
+            raise Exception("Size and scale cannot both be None.")  # noqa: TRY002
+
+        out_shape = (image_oper.shape[0], image_oper.shape[1], out_h, out_w)
+        use_nchw = image_oper.use_nchw()
+        out_id = self.add_tensor_operand(
+            node.outputsAt(0), image_oper._replace(shape=out_shape)
+        )
+
+        if image_oper.shape[0] == 0 or image_oper.shape[1] == 0:
+            raise Exception("Flexible batch or channels not supported")  # noqa: TRY002
+
+        # Handle variable input size
+        for dim in (2, 3):  # h, w indices
+            if image_oper.shape[dim] == 0:
+                if size_ctype.kind() != "NoneType":
+                    self.compute_operand_shape(out_id, dim, size_arg[dim - 2])
+                elif scale_ctype.kind() != "NoneType":
+                    self.compute_operand_shape(
+                        out_id,
+                        dim,
+                        f"int({scale_arg[dim - 2]} * {flex_name(image_id, dim)})",
+                    )
+                else:
+                    raise Exception(  # noqa: TRY002
+                        "Size and scale cannot both be None."
+                    )  # noqa: TRY002
+
+        inputs = [None] * 4
+        inputs[0] = image_id
+        inputs[1] = arg_w
+        inputs[2] = arg_h
+        inputs[3] = self.add_immediate_bool_scalar(use_nchw)
+
+        outputs = [None] * 1
+        outputs[0] = out_id
+
+        self.add_operation(NNAPI_OperationCode.RESIZE_NEAREST_NEIGHBOR, inputs, outputs)
+
+    def add_addmm(self, node):
+        assert node.inputsSize() == 5
+        assert node.outputsSize() == 1
+        jit_bias, jit_input, jit_weight, jit_beta, jit_alpha = node.inputs()
+
+        for jitval in (jit_beta, jit_alpha):
+            scale_ctype, scale_value = self.get_constant_value(jitval)
+            assert scale_ctype.kind() in ("IntType", "FloatType")
+            if scale_value != 1:
+                raise Exception(  # noqa: TRY002
+                    "NNAPI Fully-Connected does not support alpha and beta."
+                )
+
+        self.add_addmm_or_linear(node, True, jit_input, jit_weight, jit_bias)
+
+    def add_linear(self, node):
+        assert node.inputsSize() == 3
+        assert node.outputsSize() == 1
+        jit_input, jit_weight, jit_bias = node.inputs()
+
+        self.add_addmm_or_linear(node, False, jit_input, jit_weight, jit_bias)
+
+    def add_addmm_or_linear(
+        self, node, transpose_weight, jit_input, jit_weight, jit_bias
+    ):
+        input_id, input_oper = self.get_tensor_operand_by_jitval(jit_input)
+        bias_id, bias_oper = self.get_tensor_operand_for_weight(jit_bias)
+
+        assert len(input_oper.shape) == 2
+        assert len(bias_oper.shape) == 1
+
+        # TODO: Transform at load time to share weights with CPU model.
+        _, weight_tensor = self.get_constant_value(jit_weight, "TensorType")
+        assert len(weight_tensor.shape) == 2
+        if transpose_weight:
+            nnapi_weight_tensor = weight_tensor.t().contiguous()
+        else:
+            nnapi_weight_tensor = weight_tensor.contiguous()
+        weight_id = self.add_tensor_operand_for_weight(nnapi_weight_tensor)
+        weight_oper = self.operands[weight_id]
+
+        out_shape = (input_oper.shape[0], weight_oper.shape[0])
+        out_id = self.add_tensor_operand(
+            node.outputsAt(0), input_oper._replace(shape=out_shape)
+        )
+
+        if input_oper.shape[0] == 0:
+            self.forward_operand_shape(out_id, 0, input_id, 0)
+
+        inputs = [None] * 4
+        inputs[0] = input_id
+        inputs[1] = weight_id
+        inputs[2] = bias_id
+        inputs[3] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE)
+
+        outputs = [None] * 1
+        outputs[0] = out_id
+
+        self.add_operation(NNAPI_OperationCode.FULLY_CONNECTED, inputs, outputs)
+
+    def add_qlinear(self, node):
+        assert node.inputsSize() == 4
+        assert node.outputsSize() == 1
+        (
+            jit_input,
+            jit_packed_weight,
+            jit_scale,
+            jit_zero_point,
+        ) = node.inputs()
+
+        input_id, input_oper = self.get_tensor_operand_by_jitval_fixed_size(jit_input)
+        # TODO: Support automatic reshape
+        assert len(input_oper.shape) == 2
+
+        _, out_scale = self.get_constant_value(jit_scale, "FloatType")
+        _, out_zero_point = self.get_constant_value(jit_zero_point, "IntType")
+        weight_ctype, packed_weight = self.get_constant_value(jit_packed_weight)
+        assert weight_ctype.name() == "LinearPackedParamsBase"
+        raw_weight, raw_bias = packed_weight.__getstate__()[0]
+        assert raw_bias is not None
+
+        assert len(raw_weight.shape) == 2
+        assert len(raw_bias.shape) == 1
+        assert raw_bias.shape[0] == raw_weight.shape[0]
+        assert raw_weight.shape[1] == input_oper.shape[1]
+
+        assert raw_weight.qscheme() == torch.per_tensor_affine
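+        # NNAPI fully-connected expects unsigned asymmetric (quint8) weights, so
+        # signed qint8 weights are shifted by 128 (values and zero point alike)
+        # to produce an equivalent quint8 tensor.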
+        if raw_weight.dtype == torch.quint8:
+            unsigned_weight = raw_weight
+        else:
+            assert raw_weight.dtype == torch.qint8
+            unsigned_weight = torch._make_per_tensor_quantized_tensor(
+                (raw_weight.int_repr().int() + 128).to(torch.uint8),
+                scale=raw_weight.q_scale(),
+                zero_point=raw_weight.q_zero_point() + 128,
+            )
+        weight_scale = unsigned_weight.q_scale()
+        bias_scale = input_oper.scale * weight_scale
+        int_bias = torch.quantize_per_tensor(raw_bias, bias_scale, 0, torch.qint32)
+        bias_id = self.add_tensor_operand_for_weight(int_bias)
+
+        multiplier = input_oper.scale * weight_scale / out_scale
+        assert multiplier > 0
+        if multiplier >= 1:
+            raise Exception(  # noqa: TRY002
+                "Quantized convolution multiplier is greater than 1.  "
+                "This is supported by NNAPI, but not by most hardware backends.  "
+                "Try training a model without quantization-aware training.  "
+            )
+
+        # TODO: Transform at load time to share weights with CPU model.
+        nnapi_weight_tensor = unsigned_weight.contiguous()
+        weight_id = self.add_tensor_operand_for_weight(nnapi_weight_tensor)
+        weight_oper = self.operands[weight_id]
+
+        out_shape = (input_oper.shape[0], weight_oper.shape[0])
+        out_oper = input_oper._replace(
+            shape=out_shape,
+            scale=out_scale,
+            zero_point=out_zero_point,
+        )
+
+        inputs = [None] * 4
+        inputs[0] = input_id
+        inputs[1] = weight_id
+        inputs[2] = bias_id
+        inputs[3] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE)
+
+        outputs = [None] * 1
+        outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)
+
+        self.add_operation(NNAPI_OperationCode.FULLY_CONNECTED, inputs, outputs)
+
+    def get_optional_bias(self, jit_bias, weight_tensor, transpose=False):
+        ctype, _value = self.get_constant_value(jit_bias)
+        if ctype.kind() == "NoneType":
+            bias_idx = 1 if transpose else 0
+            nnapi_bias_tensor = torch.zeros(
+                weight_tensor.size()[bias_idx], dtype=weight_tensor.dtype
+            )
+            bias_id = self.add_tensor_operand_for_weight(nnapi_bias_tensor)
+            bias_oper = self.operands[bias_id]
+            return bias_id, bias_oper
+        else:
+            return self.get_tensor_operand_for_weight(jit_bias)
+
+    def add_conv2d(self, node):
+        assert node.inputsSize() == 7
+        assert node.outputsSize() == 1
+
+        (
+            jit_image,
+            jit_weight,
+            jit_bias,
+            jit_stride,
+            jit_pad,
+            jit_dilation,
+            jit_groups,
+        ) = node.inputs()
+
+        _, weight_tensor = self.get_constant_value(jit_weight, "TensorType")
+        bias_id, _bias_oper = self.get_optional_bias(jit_bias, weight_tensor)
+        args = self.get_conv_pool_args_2d_from_jit(
+            weight_tensor.shape[2:4], jit_stride, jit_pad, jit_dilation, jit_groups
+        )
+
+        return self.add_conv2d_common(
+            node.outputsAt(0),
+            0.0,
+            0,
+            jit_image,
+            weight_tensor,
+            bias_id,
+            args,
+            False,  # transpose
+            NNAPI_FuseCode.FUSED_NONE,
+        )
+
+    def add_conv_underscore(self, node):
+        assert node.inputsSize() == 13
+        assert node.outputsSize() == 1
+
+        (
+            jit_image,
+            jit_weight,
+            jit_bias,
+            jit_stride,
+            jit_pad,
+            jit_dilation,
+            jit_transpose,
+            _,
+            jit_groups,
+            _,
+            _,
+            _,
+            _,
+        ) = node.inputs()
+
+        _, weight_tensor = self.get_constant_value(jit_weight, "TensorType")
+        _, transpose = self.get_constant_value(jit_transpose)
+        bias_id, _bias_oper = self.get_optional_bias(jit_bias, weight_tensor, transpose)
+        args = self.get_conv_pool_args_2d_from_jit(
+            weight_tensor.shape[2:4], jit_stride, jit_pad, jit_dilation, jit_groups
+        )
+
+        return self.add_conv2d_common(
+            node.outputsAt(0),
+            0.0,
+            0,
+            jit_image,
+            weight_tensor,
+            bias_id,
+            args,
+            transpose,
+            NNAPI_FuseCode.FUSED_NONE,
+        )
+
+    def add_log_softmax(self, node):
+        assert node.inputsSize() == 3
+        assert node.outputsSize() == 1
+
+        jit_input, jit_dim, _jit_half_to_float = node.inputs()
+        input_id, input_oper = self.get_tensor_operand_by_jitval_fixed_size(jit_input)
+        _, dim = self.get_constant_value(jit_dim, "IntType")
+
+        out_shape = input_oper.shape
+
+        inputs = [None] * 3
+        inputs[0] = input_id
+        # specifying 1 as the scaling factor for the exponent, beta
+        inputs[1] = self.add_immediate_float_scalar(1)
+        inputs[2] = self.add_immediate_int_scalar(dim)
+
+        outputs = [None] * 1
+        outputs[0] = self.add_tensor_operand(
+            node.outputsAt(0), input_oper._replace(shape=out_shape)
+        )
+        self.add_operation(NNAPI_OperationCode.LOG_SOFTMAX, inputs, outputs)
+
+    def add_qconv2d(self, node, fuse_code, transpose=False):
+        assert node.inputsSize() == 4
+        assert node.outputsSize() == 1
+
+        (
+            jit_image,
+            jit_packed_weight,
+            jit_scale,
+            jit_zero_point,
+        ) = node.inputs()
+
+        _, out_scale = self.get_constant_value(jit_scale, "FloatType")
+        _, out_zero_point = self.get_constant_value(jit_zero_point, "IntType")
+        weight_ctype, packed_weight = self.get_constant_value(jit_packed_weight)
+        assert weight_ctype.name() == "Conv2dPackedParamsBase"
+        (
+            pack_version,
+            tensors,
+            opt_tensors,
+        ) = packed_weight.__getstate__()[0]
+        assert pack_version == "2"
+        packed_config, raw_weight = tensors
+        (raw_bias,) = opt_tensors
+        assert raw_bias is not None
+        args = self.get_conv_pool_args_2d_from_pack(
+            raw_weight.shape[2:4], packed_config
+        )
+
+        assert raw_weight.qscheme() == torch.per_tensor_affine
+        if raw_weight.dtype == torch.quint8:
+            unsigned_weight = raw_weight
+        else:
+            assert raw_weight.dtype == torch.qint8
+            unsigned_weight = torch._make_per_tensor_quantized_tensor(
+                (raw_weight.int_repr().int() + 128).to(torch.uint8),
+                scale=raw_weight.q_scale(),
+                zero_point=raw_weight.q_zero_point() + 128,
+            )
+        weight_scale = unsigned_weight.q_scale()
+        _, image_oper = self.get_tensor_operand_by_jitval(jit_image)
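+        # NNAPI requires the int32 bias to be quantized with
+        # scale = input_scale * weight_scale and zero_point = 0
+        # (re-checked later in add_conv2d_common).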
+        bias_scale = image_oper.scale * weight_scale
+        int_bias = torch.quantize_per_tensor(raw_bias, bias_scale, 0, torch.qint32)
+        bias_id = self.add_tensor_operand_for_weight(int_bias)
+
+        multiplier = image_oper.scale * weight_scale / out_scale
+        assert multiplier > 0
+        if multiplier >= 1:
+            raise Exception(  # noqa: TRY002
+                "Quantized convolution multiplier is greater than 1.  "
+                "This is supported by NNAPI, but not by most hardware backends.  "
+                "Try training a model without quantization-aware training.  "
+            )
+
+        return self.add_conv2d_common(
+            node.outputsAt(0),
+            out_scale,
+            out_zero_point,
+            jit_image,
+            unsigned_weight,
+            bias_id,
+            args,
+            transpose,
+            fuse_code,
+        )
+
+    def add_conv2d_common(
+        self,
+        jit_out,
+        out_scale,
+        out_zero_point,
+        jit_image,
+        weight_tensor,
+        bias_id,
+        args,
+        transpose,
+        fuse_code,
+    ):
+        image_id, image_oper = self.get_tensor_operand_by_jitval(jit_image)
+        in_c = image_oper.shape[1]
+
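+        # NNAPI weight layouts: CONV_2D and TRANSPOSE_CONV_2D take
+        # [out_c, kh, kw, in_c] weights, while DEPTHWISE_CONV_2D takes
+        # [1, kh, kw, out_c]. The permutations below convert PyTorch's
+        # OIHW (IOHW for transposed conv) weights into those layouts.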
+        if args.group == 1:
+            # Full convolution
+            depthwise = False
+            if transpose:
+                weight_permutation = (1, 2, 3, 0)
+            else:
+                weight_permutation = (0, 2, 3, 1)
+        elif args.group == in_c:
+            # Depthwise convolution
+            depthwise = True
+            weight_permutation = (1, 2, 3, 0)
+        else:
+            raise Exception("Group convolution not supported yet.")  # noqa: TRY002
+
+        # TODO: Transform at load time to share weights with CPU model.
+        nnapi_weight_tensor = weight_tensor.permute(*weight_permutation).contiguous()
+        weight_id = self.add_tensor_operand_for_weight(nnapi_weight_tensor)
+        weight_oper = self.operands[weight_id]
+
+        bias_oper = self.operands[bias_id]
+
+        if image_oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32:
+            assert weight_oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32
+            assert bias_oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32
+        elif image_oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM:
+            assert weight_oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM
+            assert bias_oper.op_type == NNAPI_OperandCode.TENSOR_INT32
+            assert approx_equal(image_oper.scale * weight_oper.scale, bias_oper.scale)
+            assert bias_oper.zero_point == 0
+        else:
+            raise Exception(  # noqa: TRY002
+                f"Unsupported input type for conv2d: {image_oper.op_type}"
+            )  # noqa: TRY002
+
+        assert len(image_oper.shape) == 4
+        assert len(weight_oper.shape) == 4
+        assert len(bias_oper.shape) == 1
+
+        if depthwise:
+            # Depthwise convolution
+            one, _kern_h, _kern_w, out_c = weight_oper.shape
+            assert one == 1
+            assert out_c % in_c == 0
+            channel_multiplier = out_c // in_c
+            assert channel_multiplier == 1  # Don't support multiplier
+            assert out_c == in_c
+        else:
+            # Full convolution
+            out_c, _kern_h, _kern_w, kern_d = weight_oper.shape
+            assert kern_d == in_c
+
+        assert out_c == bias_oper.shape[0]
+
+        use_nchw = image_oper.use_nchw()
+
+        if depthwise:
+            num_args = 12
+            opcode = NNAPI_OperationCode.DEPTHWISE_CONV_2D
+        else:
+            num_args = 11
+            if transpose:
+                opcode = NNAPI_OperationCode.TRANSPOSE_CONV_2D
+            else:
+                opcode = NNAPI_OperationCode.CONV_2D
+
+        inputs = [None] * num_args
+        inputs[0] = image_id
+        inputs[1] = weight_id
+        inputs[2] = bias_id
+        inputs[3] = self.add_immediate_int_scalar(args.pad_l)
+        inputs[4] = self.add_immediate_int_scalar(args.pad_r)
+        inputs[5] = self.add_immediate_int_scalar(args.pad_t)
+        inputs[6] = self.add_immediate_int_scalar(args.pad_b)
+        inputs[7] = self.add_immediate_int_scalar(args.stride_w)
+        inputs[8] = self.add_immediate_int_scalar(args.stride_h)
+        if depthwise:
+            inputs[9] = self.add_immediate_int_scalar(1)
+            inputs[10] = self.add_immediate_int_scalar(fuse_code)
+            inputs[11] = self.add_immediate_bool_scalar(use_nchw)
+        else:
+            inputs[9] = self.add_immediate_int_scalar(fuse_code)
+            inputs[10] = self.add_immediate_bool_scalar(use_nchw)
+
+        outputs = [None] * 1
+        out_shape = get_conv_pool_shape(image_oper.shape, args, out_c, transpose)
+        out_oper = image_oper._replace(
+            shape=out_shape,
+            scale=out_scale,
+            zero_point=out_zero_point,
+        )
+        out_id = self.add_tensor_operand(jit_out, out_oper)
+        self._handle_conv_pool_flexible_input(out_id, jit_image, args, transpose)
+
+        outputs[0] = out_id
+        self.add_operation(opcode, inputs, outputs)
+
+    def _handle_conv_pool_flexible_input(self, out_id, jit_image, args, transpose):
+        image_id, image_oper = self.get_tensor_operand_by_jitval(jit_image)
+        batch, in_ch, in_h, in_w = image_oper.shape
+
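+        # A dimension recorded as 0 is "flexible": its size is only known at run
+        # time, so the corresponding output extent is forwarded or computed
+        # symbolically from the input's flexible size.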
+        if batch == 0:
+            self.forward_operand_shape(out_id, 0, image_id, 0)
+        if in_ch == 0:
+            raise Exception("Input channels can't be flexible")  # noqa: TRY002
+        # H & W
+        if transpose:
+            if in_h == 0:
+                self.compute_operand_shape(
+                    out_id,
+                    2,
+                    f"({flex_name(image_id, 2)} - 1) * {args.stride_h} + {args.kernel_h} - {args.pad_t} - {args.pad_b}",
+                )
+            if in_w == 0:
+                self.compute_operand_shape(
+                    out_id,
+                    3,
+                    f"({flex_name(image_id, 3)} - 1) * {args.stride_w} + {args.kernel_w} - {args.pad_l} - {args.pad_r}",
+                )
+        else:
+            if in_h == 0:
+                self.compute_operand_shape(
+                    out_id,
+                    2,
+                    f"({flex_name(image_id, 2)} - {args.kernel_h} + {args.pad_t} + {args.pad_b}) // {args.stride_h} + 1",
+                )
+            if in_w == 0:
+                self.compute_operand_shape(
+                    out_id,
+                    3,
+                    f"({flex_name(image_id, 3)} - {args.kernel_w} + {args.pad_l} + {args.pad_r}) // {args.stride_w} + 1",
+                )
+
+
+def serialize_model(
+    module, inputs, *, config=None, return_shapes=None, use_int16_for_qint16=False
+):
+    """Convert to NNAPI and serialize torchscript module.
+
+    Parameters:
+        module: Torchscript module to convert
+        inputs: Tensors used to specify input details for NNAPI
+        config (optional): Optional config to attach to module
+        return_shapes (optional): Specify shape of outputs if
+            your module uses runtime flexible shapes to set output
+            buffer size for NNAPI
+        use_int16_for_qint16 (optional): Use Pytorch int16 to represent NNAPI qint16 values
+    """
+    return _NnapiSerializer(config, use_int16_for_qint16).serialize_model(
+        module, inputs, return_shapes
+    )
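+
+
+# Illustrative usage sketch: trace a small module and serialize it for NNAPI.
+# `TinyNet` and the input shape are placeholder assumptions for the example.
+#
+#   import torch
+#
+#   class TinyNet(torch.nn.Module):
+#       def forward(self, x):
+#           return torch.nn.functional.relu(x)
+#
+#   traced = torch.jit.trace(TinyNet().eval(), torch.zeros(1, 3, 224, 224))
+#   # `inputs` specifies the NNAPI input tensors (shape, dtype, layout).
+#   result = serialize_model(traced, [torch.zeros(1, 3, 224, 224)])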
diff --git a/.venv/lib/python3.12/site-packages/torch/backends/cpu/__init__.py b/.venv/lib/python3.12/site-packages/torch/backends/cpu/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..82dc52cd4904c1cda023c876c586550a5a33ff7a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/backends/cpu/__init__.py
@@ -0,0 +1,21 @@
+import torch
+
+
+__all__ = [
+    "get_cpu_capability",
+]
+
+
+def get_cpu_capability() -> str:
+    r"""Return cpu capability as a string value.
+
+    Possible values:
+    - "DEFAULT"
+    - "VSX"
+    - "Z VECTOR"
+    - "NO AVX"
+    - "AVX2"
+    - "AVX512"
+    - "SVE256"
+    """
+    return torch._C._get_cpu_capability()
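+
+
+# Illustrative usage sketch: gate a SIMD-optimized code path on the detected
+# CPU capability.
+#
+#   import torch
+#   if torch.backends.cpu.get_cpu_capability() in ("AVX512", "AVX2"):
+#       ...  # take the vectorized path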
diff --git a/.venv/lib/python3.12/site-packages/torch/backends/cpu/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/backends/cpu/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fa08906df57190bb27ef03c9c7bd610a2af97c21
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/backends/cpu/__pycache__/__init__.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/backends/cuda/__init__.py b/.venv/lib/python3.12/site-packages/torch/backends/cuda/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7541c9c7ca670d4f420c10e38d0b8a5da8496441
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/backends/cuda/__init__.py
@@ -0,0 +1,536 @@
+# mypy: allow-untyped-defs
+import contextlib
+from typing import Union
+from typing_extensions import deprecated
+
+import torch
+
+
+__all__ = [
+    "is_built",
+    "cuFFTPlanCacheAttrContextProp",
+    "cuFFTPlanCache",
+    "cuFFTPlanCacheManager",
+    "cuBLASModule",
+    "preferred_linalg_library",
+    "preferred_blas_library",
+    "preferred_rocm_fa_library",
+    "cufft_plan_cache",
+    "matmul",
+    "SDPAParams",
+    "enable_cudnn_sdp",
+    "cudnn_sdp_enabled",
+    "enable_flash_sdp",
+    "flash_sdp_enabled",
+    "enable_mem_efficient_sdp",
+    "mem_efficient_sdp_enabled",
+    "math_sdp_enabled",
+    "enable_math_sdp",
+    "allow_fp16_bf16_reduction_math_sdp",
+    "fp16_bf16_reduction_math_sdp_allowed",
+    "is_flash_attention_available",
+    "can_use_flash_attention",
+    "can_use_efficient_attention",
+    "can_use_cudnn_attention",
+    "sdp_kernel",
+]
+
+
+def is_built():
+    r"""
+    Return whether PyTorch is built with CUDA support.
+
+    Note that this doesn't necessarily mean CUDA is available; just that if this PyTorch
+    binary were run on a machine with working CUDA drivers and devices, we would be able to use it.
+    """
+    return torch._C._has_cuda
+
+
+class cuFFTPlanCacheAttrContextProp:
+    # Like regular ContextProp, but uses the `.device_index` attribute from the
+    # calling object as the first argument to the getter and setter.
+    def __init__(self, getter, setter):
+        self.getter = getter
+        self.setter = setter
+
+    def __get__(self, obj, objtype):
+        return self.getter(obj.device_index)
+
+    def __set__(self, obj, val):
+        if isinstance(self.setter, str):
+            raise RuntimeError(self.setter)
+        self.setter(obj.device_index, val)
+
+
+class cuFFTPlanCache:
+    r"""
+    Represent a specific plan cache for a specific `device_index`.
+
+    The attributes `size` and `max_size`, and the method `clear`, can fetch and/or
+    change properties of the C++ cuFFT plan cache.
+    """
+
+    def __init__(self, device_index):
+        self.device_index = device_index
+
+    size = cuFFTPlanCacheAttrContextProp(
+        torch._cufft_get_plan_cache_size,
+        ".size is a read-only property showing the number of plans currently in the "
+        "cache. To change the cache capacity, set cufft_plan_cache.max_size.",
+    )
+
+    max_size = cuFFTPlanCacheAttrContextProp(
+        torch._cufft_get_plan_cache_max_size, torch._cufft_set_plan_cache_max_size
+    )
+
+    def clear(self):
+        return torch._cufft_clear_plan_cache(self.device_index)
+
+
+class cuFFTPlanCacheManager:
+    r"""
+    Represent all cuFFT plan caches and return the cuFFTPlanCache for a given device when indexed.
+
+    When this object is used directly as a `cuFFTPlanCache` object (e.g., by
+    setting the `.max_size` attribute), the current device's cuFFT plan cache is
+    used.
+    """
+
+    __initialized = False
+
+    def __init__(self):
+        self.caches = []
+        self.__initialized = True
+
+    def __getitem__(self, device):
+        index = torch.cuda._utils._get_device_index(device)
+        if index < 0 or index >= torch.cuda.device_count():
+            raise RuntimeError(
+                f"cufft_plan_cache: expected 0 <= device index < {torch.cuda.device_count()}, but got "
+                f"device with index {index}"
+            )
+        if len(self.caches) == 0:
+            self.caches.extend(
+                cuFFTPlanCache(index) for index in range(torch.cuda.device_count())
+            )
+        return self.caches[index]
+
+    def __getattr__(self, name):
+        return getattr(self[torch.cuda.current_device()], name)
+
+    def __setattr__(self, name, value):
+        if self.__initialized:
+            return setattr(self[torch.cuda.current_device()], name, value)
+        else:
+            return super().__setattr__(name, value)
+
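+# Illustrative usage sketch: cap the cuFFT plan cache for device 0 and clear
+# the current device's cache (values here are arbitrary).
+#
+#   import torch
+#   torch.backends.cuda.cufft_plan_cache[0].max_size = 32
+#   torch.backends.cuda.cufft_plan_cache.clear()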
+
+class cuBLASModule:
+    def __getattr__(self, name):
+        if name == "allow_tf32":
+            return torch._C._get_cublas_allow_tf32()
+        elif name == "allow_fp16_reduced_precision_reduction":
+            return torch._C._get_cublas_allow_fp16_reduced_precision_reduction()
+        elif name == "allow_bf16_reduced_precision_reduction":
+            return torch._C._get_cublas_allow_bf16_reduced_precision_reduction()
+        elif name == "allow_fp16_accumulation":
+            return torch._C._get_cublas_allow_fp16_accumulation()
+        raise AttributeError("Unknown attribute " + name)
+
+    def __setattr__(self, name, value):
+        if name == "allow_tf32":
+            return torch._C._set_cublas_allow_tf32(value)
+        elif name == "allow_fp16_reduced_precision_reduction":
+            return torch._C._set_cublas_allow_fp16_reduced_precision_reduction(value)
+        elif name == "allow_bf16_reduced_precision_reduction":
+            return torch._C._set_cublas_allow_bf16_reduced_precision_reduction(value)
+        elif name == "allow_fp16_accumulation":
+            return torch._C._set_cublas_allow_fp16_accumulation(value)
+        raise AttributeError("Unknown attribute " + name)
+
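+# Illustrative usage sketch: the module-level `matmul` instance created at the
+# bottom of this file exposes these flags.
+#
+#   import torch
+#   torch.backends.cuda.matmul.allow_tf32 = True
+#   torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False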
+
+_LinalgBackends = {
+    "default": torch._C._LinalgBackend.Default,
+    "cusolver": torch._C._LinalgBackend.Cusolver,
+    "magma": torch._C._LinalgBackend.Magma,
+}
+_LinalgBackends_str = ", ".join(_LinalgBackends.keys())
+
+
+def preferred_linalg_library(
+    backend: Union[None, str, torch._C._LinalgBackend] = None
+) -> torch._C._LinalgBackend:
+    r"""
+    Override the heuristic PyTorch uses to choose between cuSOLVER and MAGMA for CUDA linear algebra operations.
+
+    .. warning:: This flag is experimental and subject to change.
+
+    When PyTorch runs a CUDA linear algebra operation it often uses the cuSOLVER or MAGMA libraries,
+    and if both are available it decides which to use with a heuristic.
+    This flag (a :class:`str`) allows overriding those heuristics.
+
+    * If `"cusolver"` is set then cuSOLVER will be used wherever possible.
+    * If `"magma"` is set then MAGMA will be used wherever possible.
+    * If `"default"` (the default) is set then heuristics will be used to pick between
+      cuSOLVER and MAGMA if both are available.
+    * When no input is given, this function returns the currently preferred library.
+    * User may use the environment variable TORCH_LINALG_PREFER_CUSOLVER=1 to set the preferred library to cuSOLVER
+      globally.
+      This flag only sets the initial value of the preferred library and the preferred library
+      may still be overridden by this function call later in your script.
+
+    Note: When a library is preferred, other libraries may still be used if the preferred library
+    doesn't implement the operation(s) called.
+    This flag may achieve better performance if PyTorch's heuristic library selection is incorrect
+    for your application's inputs.
+
+    Currently supported linalg operators:
+
+    * :func:`torch.linalg.inv`
+    * :func:`torch.linalg.inv_ex`
+    * :func:`torch.linalg.cholesky`
+    * :func:`torch.linalg.cholesky_ex`
+    * :func:`torch.cholesky_solve`
+    * :func:`torch.cholesky_inverse`
+    * :func:`torch.linalg.lu_factor`
+    * :func:`torch.linalg.lu`
+    * :func:`torch.linalg.lu_solve`
+    * :func:`torch.linalg.qr`
+    * :func:`torch.linalg.eigh`
+    * :func:`torch.linalg.eigvalsh`
+    * :func:`torch.linalg.svd`
+    * :func:`torch.linalg.svdvals`
+    """
+    if backend is None:
+        pass
+    elif isinstance(backend, str):
+        if backend not in _LinalgBackends:
+            raise RuntimeError(
+                "Unknown input value. " f"Choose from: {_LinalgBackends_str}."
+            )
+        torch._C._set_linalg_preferred_backend(_LinalgBackends[backend])
+    elif isinstance(backend, torch._C._LinalgBackend):
+        torch._C._set_linalg_preferred_backend(backend)
+    else:
+        raise RuntimeError("Unknown input value type.")
+
+    return torch._C._get_linalg_preferred_backend()
+
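+# Illustrative usage sketch: force cuSOLVER for CUDA linalg operations, then
+# restore the default heuristic.
+#
+#   import torch
+#   torch.backends.cuda.preferred_linalg_library("cusolver")
+#   ...  # torch.linalg.cholesky, torch.linalg.svd, etc.
+#   torch.backends.cuda.preferred_linalg_library("default")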
+
+_BlasBackends = {
+    "default": torch._C._BlasBackend.Default,
+    "cublas": torch._C._BlasBackend.Cublas,
+    "hipblas": torch._C._BlasBackend.Cublas,  # alias
+    "cublaslt": torch._C._BlasBackend.Cublaslt,
+    "hipblaslt": torch._C._BlasBackend.Cublaslt,  # alias
+    "ck": torch._C._BlasBackend.Ck,
+}
+_BlasBackends_str = ", ".join(_BlasBackends.keys())
+
+
+def preferred_blas_library(
+    backend: Union[None, str, torch._C._BlasBackend] = None
+) -> torch._C._BlasBackend:
+    r"""
+    Override the library PyTorch uses for BLAS operations. Choose between cuBLAS, cuBLASLt, and CK [ROCm-only].
+
+    .. warning:: This flag is experimental and subject to change.
+
+    When PyTorch runs a CUDA BLAS operation it defaults to cuBLAS even if both cuBLAS and cuBLASLt are available.
+    For PyTorch built for ROCm, hipBLAS, hipBLASLt, and CK may offer different performance.
+    This flag (a :class:`str`) allows overriding which BLAS library to use.
+
+    * If `"cublas"` is set then cuBLAS will be used wherever possible.
+    * If `"cublaslt"` is set then cuBLASLt will be used wherever possible.
+    * If `"ck"` is set then CK will be used wherever possible.
+    * If `"default"` (the default) is set then heuristics will be used to pick between the other options.
+    * When no input is given, this function returns the currently preferred library.
+    * User may use the environment variable TORCH_BLAS_PREFER_CUBLASLT=1 to set the preferred library to cuBLASLt
+      globally.
+      This flag only sets the initial value of the preferred library and the preferred library
+      may still be overridden by this function call later in your script.
+
+    Note: When a library is preferred, other libraries may still be used if the preferred library
+    doesn't implement the operation(s) called.
+    This flag may achieve better performance if PyTorch's library selection is incorrect
+    for your application's inputs.
+
+    """
+    if backend is None:
+        pass
+    elif isinstance(backend, str):
+        if backend not in _BlasBackends:
+            raise RuntimeError(
+                "Unknown input value. " f"Choose from: {_BlasBackends_str}."
+            )
+        torch._C._set_blas_preferred_backend(_BlasBackends[backend])
+    elif isinstance(backend, torch._C._BlasBackend):
+        torch._C._set_blas_preferred_backend(backend)
+    else:
+        raise RuntimeError("Unknown input value type.")
+
+    return torch._C._get_blas_preferred_backend()
+
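+# Illustrative usage sketch: prefer cuBLASLt for CUDA BLAS operations, or call
+# with no argument to query the current preference.
+#
+#   import torch
+#   torch.backends.cuda.preferred_blas_library("cublaslt")
+#   current = torch.backends.cuda.preferred_blas_library()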
+
+_ROCmFABackends = {
+    "default": torch._C._ROCmFABackend.Default,
+    "aotriton": torch._C._ROCmFABackend.AOTriton,
+    "ck": torch._C._ROCmFABackend.Ck,
+}
+_ROCmFABackends_str = ", ".join(_ROCmFABackends.keys())
+
+
+from torch._C import _SDPAParams as SDPAParams, _SDPBackend as SDPBackend
+
+
+def preferred_rocm_fa_library(
+    backend: Union[None, str, torch._C._ROCmFABackend] = None
+) -> torch._C._ROCmFABackend:
+    r"""
+    [ROCm-only]
+    Override the backend PyTorch uses for Flash Attention in ROCm environments. Choose between AOTriton and CK.
+
+    .. warning:: This flag is experimental and subject to change.
+
+    When Flash Attention is enabled and desired, PyTorch defaults to using AOTriton as the backend.
+    This flag (a :class:`str`) allows users to override this backend to use composable_kernel (CK).
+
+    * If `"default"` is set then the default backend will be used wherever possible. Currently AOTriton.
+    * If `"aotriton"` is set then AOTriton will be used wherever possible.
+    * If `"ck"` is set then CK will be used wherever possible.
+    * When no input is given, this function returns the currently preferred library.
+    * User may use the environment variable TORCH_ROCM_FA_PREFER_CK=1 to set the preferred library to CK
+      globally.
+
+    Note: When a library is preferred, other libraries may still be used if the preferred library
+    doesn't implement the operation(s) called.
+    This flag may achieve better performance if PyTorch's library selection is incorrect
+    for your application's inputs.
+    """
+    if backend is None:
+        pass
+    elif isinstance(backend, str):
+        if backend not in _ROCmFABackends:
+            raise RuntimeError(
+                "Unknown input value. " f"Choose from: {_ROCmFABackends_str}."
+            )
+        torch._C._set_rocm_fa_preferred_backend(_ROCmFABackends[backend])
+    elif isinstance(backend, torch._C._ROCmFABackend):
+        torch._C._set_rocm_fa_preferred_backend(backend)
+    else:
+        raise ValueError("Unknown input value. " f"Choose from: {_ROCmFABackends_str}.")
+
+    return torch._C._get_rocm_fa_preferred_backend()
+
+
+# Set the __module__ attribute
+SDPAParams.__module__ = "torch.backends.cuda"
+SDPAParams.__name__ = "SDPAParams"
+
+
+def flash_sdp_enabled():
+    r"""
+    .. warning:: This flag is beta and subject to change.
+
+    Returns whether flash scaled dot product attention is enabled or not.
+    """
+    return torch._C._get_flash_sdp_enabled()
+
+
+def enable_flash_sdp(enabled: bool):
+    r"""
+    .. warning:: This flag is beta and subject to change.
+
+    Enables or disables flash scaled dot product attention.
+    """
+    torch._C._set_sdp_use_flash(enabled)
+
+
+def mem_efficient_sdp_enabled():
+    r"""
+    .. warning:: This flag is beta and subject to change.
+
+    Returns whether memory efficient scaled dot product attention is enabled or not.
+    """
+    return torch._C._get_mem_efficient_sdp_enabled()
+
+
+def enable_mem_efficient_sdp(enabled: bool):
+    r"""
+    .. warning:: This flag is beta and subject to change.
+
+    Enables or disables memory efficient scaled dot product attention.
+    """
+    torch._C._set_sdp_use_mem_efficient(enabled)
+
+
+def math_sdp_enabled():
+    r"""
+    .. warning:: This flag is beta and subject to change.
+
+    Returns whether math scaled dot product attention is enabled or not.
+    """
+    return torch._C._get_math_sdp_enabled()
+
+
+def enable_math_sdp(enabled: bool):
+    r"""
+    .. warning:: This flag is beta and subject to change.
+
+    Enables or disables math scaled dot product attention.
+    """
+    torch._C._set_sdp_use_math(enabled)
+
+
+def allow_fp16_bf16_reduction_math_sdp(enabled: bool):
+    r"""
+    .. warning:: This flag is beta and subject to change.
+
+    Enables or disables fp16/bf16 reduction in math scaled dot product attention.
+    """
+    torch._C._set_math_sdp_allow_fp16_bf16_reduction(enabled)
+
+
+def fp16_bf16_reduction_math_sdp_allowed():
+    r"""
+    .. warning:: This flag is beta and subject to change.
+
+    Returns whether fp16/bf16 reduction in math scaled dot product attention is enabled or not.
+    """
+    return torch._C._get_math_sdp_allow_fp16_bf16_reduction()
+
+
+def is_flash_attention_available() -> bool:
+    r"""Check if PyTorch was built with FlashAttention for scaled_dot_product_attention.
+
+    Returns:
+        True if FlashAttention is built and available; otherwise, False.
+
+    Note:
+        This function is dependent on a CUDA-enabled build of PyTorch. It will return False
+        in non-CUDA environments.
+    """
+    return torch._C._is_flash_attention_available()
+
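+# Illustrative usage sketch: only enable the flash SDP backend when the build
+# actually ships FlashAttention.
+#
+#   import torch
+#   if torch.backends.cuda.is_flash_attention_available():
+#       torch.backends.cuda.enable_flash_sdp(True)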
+
+def can_use_flash_attention(params: SDPAParams, debug: bool = False) -> bool:
+    r"""Check if FlashAttention can be utilized in scaled_dot_product_attention.
+
+    Args:
+        params: An instance of SDPAParams containing the tensors for query,
+                key, value, an optional attention mask, dropout rate, and
+                a flag indicating if the attention is causal.
+        debug: If True, use ``logging.warn`` to report why FlashAttention could not be run.
+            Defaults to False.
+
+    Returns:
+        True if FlashAttention can be used with the given parameters; otherwise, False.
+
+    Note:
+        This function is dependent on a CUDA-enabled build of PyTorch. It will return False
+        in non-CUDA environments.
+    """
+    return torch._C._can_use_flash_attention(params, debug)
+
+
+def can_use_efficient_attention(params: SDPAParams, debug: bool = False) -> bool:
+    r"""Check if efficient_attention can be utilized in scaled_dot_product_attention.
+
+    Args:
+        params: An instance of SDPAParams containing the tensors for query,
+                key, value, an optional attention mask, dropout rate, and
+                a flag indicating if the attention is causal.
+        debug: If True, use ``logging.warn`` to report why efficient_attention could not be run.
+            Defaults to False.
+
+    Returns:
+        True if efficient_attention can be used with the given parameters; otherwise, False.
+
+    Note:
+        This function is dependent on a CUDA-enabled build of PyTorch. It will return False
+        in non-CUDA environments.
+    """
+    return torch._C._can_use_mem_efficient_attention(params, debug)
+
+
+def can_use_cudnn_attention(params: SDPAParams, debug: bool = False) -> bool:
+    r"""Check if cudnn_attention can be utilized in scaled_dot_product_attention.
+
+    Args:
+        params: An instance of SDPAParams containing the tensors for query,
+                key, value, an optional attention mask, dropout rate, and
+                a flag indicating if the attention is causal.
+        debug: If True, use ``logging.warn`` to report why cuDNN attention could not be run.
+            Defaults to False.
+
+    Returns:
+        True if cuDNN can be used with the given parameters; otherwise, False.
+
+    Note:
+        This function is dependent on a CUDA-enabled build of PyTorch. It will return False
+        in non-CUDA environments.
+    """
+    return torch._C._can_use_cudnn_attention(params, debug)
+
+
+def cudnn_sdp_enabled():
+    r"""
+    .. warning:: This flag is beta and subject to change.
+
+    Returns whether cuDNN scaled dot product attention is enabled or not.
+    """
+    return torch._C._get_cudnn_sdp_enabled()
+
+
+def enable_cudnn_sdp(enabled: bool):
+    r"""
+    .. warning:: This flag is beta and subject to change.
+
+    Enables or disables cuDNN scaled dot product attention.
+    """
+    torch._C._set_sdp_use_cudnn(enabled)
+
+
+@contextlib.contextmanager
+@deprecated(
+    (
+        "`torch.backends.cuda.sdp_kernel()` is deprecated. "
+        "In the future, this context manager will be removed. "
+        "Please see `torch.nn.attention.sdpa_kernel()` for the new context manager, "
+        "with updated signature."
+    ),
+    category=FutureWarning,
+)
+def sdp_kernel(
+    enable_flash: bool = True,
+    enable_math: bool = True,
+    enable_mem_efficient: bool = True,
+    enable_cudnn: bool = True,
+):
+    r"""
+    .. warning:: This flag is beta and subject to change.
+
+    This context manager can be used to temporarily enable or disable any of the four backends for scaled dot product attention.
+    Upon exiting the context manager, the previous state of the flags will be restored.
+    """
+    from torch.nn.attention import sdpa_kernel
+
+    backend_list = []
+    if enable_flash:
+        backend_list.append(SDPBackend.FLASH_ATTENTION)
+    if enable_mem_efficient:
+        backend_list.append(SDPBackend.EFFICIENT_ATTENTION)
+    if enable_math:
+        backend_list.append(SDPBackend.MATH)
+    if enable_cudnn:
+        backend_list.append(SDPBackend.CUDNN_ATTENTION)
+
+    with sdpa_kernel(backend_list) as context:
+        try:
+            yield context
+        finally:
+            pass
+
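+# Migration sketch for the deprecation notice above: the replacement context
+# manager is used like this (q, k, v are placeholder CUDA tensors).
+#
+#   from torch.nn.attention import SDPBackend, sdpa_kernel
+#   with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.MATH]):
+#       out = torch.nn.functional.scaled_dot_product_attention(q, k, v)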
+
+cufft_plan_cache = cuFFTPlanCacheManager()
+matmul = cuBLASModule()
diff --git a/.venv/lib/python3.12/site-packages/torch/backends/cuda/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/backends/cuda/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7e268f572c334009957951bef25afe7cdb504bed
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/backends/cuda/__pycache__/__init__.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/backends/cudnn/__init__.py b/.venv/lib/python3.12/site-packages/torch/backends/cudnn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ba6f9e3b40bdfd8194dc9be97f542910ffb4632
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/backends/cudnn/__init__.py
@@ -0,0 +1,208 @@
+# mypy: allow-untyped-defs
+import os
+import sys
+import warnings
+from contextlib import contextmanager
+from typing import Optional
+
+import torch
+from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule
+
+
+try:
+    from torch._C import _cudnn
+except ImportError:
+    _cudnn = None  # type: ignore[assignment]
+
+# Write:
+#
+#   torch.backends.cudnn.enabled = False
+#
+# to globally disable CuDNN/MIOpen
+
+__cudnn_version: Optional[int] = None
+
+if _cudnn is not None:
+
+    def _init():
+        global __cudnn_version
+        if __cudnn_version is None:
+            __cudnn_version = _cudnn.getVersionInt()
+            runtime_version = _cudnn.getRuntimeVersion()
+            compile_version = _cudnn.getCompileVersion()
+            runtime_major, runtime_minor, _ = runtime_version
+            compile_major, compile_minor, _ = compile_version
+            # Different major versions are always incompatible
+            # Starting with cuDNN 7, minor versions are backwards-compatible
+            # Not sure about MIOpen (ROCm), so always do a strict check
+            if runtime_major != compile_major:
+                cudnn_compatible = False
+            elif runtime_major < 7 or not _cudnn.is_cuda:
+                cudnn_compatible = runtime_minor == compile_minor
+            else:
+                cudnn_compatible = runtime_minor >= compile_minor
+            if not cudnn_compatible:
+                if os.environ.get("PYTORCH_SKIP_CUDNN_COMPATIBILITY_CHECK", "0") == "1":
+                    return True
+                base_error_msg = (
+                    f"cuDNN version incompatibility: "
+                    f"PyTorch was compiled  against {compile_version} "
+                    f"but found runtime version {runtime_version}. "
+                    f"PyTorch already comes bundled with cuDNN. "
+                    f"One option to resolving this error is to ensure PyTorch "
+                    f"can find the bundled cuDNN. "
+                )
+
+                if "LD_LIBRARY_PATH" in os.environ:
+                    ld_library_path = os.environ.get("LD_LIBRARY_PATH", "")
+                    if any(
+                        substring in ld_library_path for substring in ["cuda", "cudnn"]
+                    ):
+                        raise RuntimeError(
+                            f"{base_error_msg}"
+                            f"Looks like your LD_LIBRARY_PATH contains incompatible version of cudnn. "
+                            f"Please either remove it from the path or install cudnn {compile_version}"
+                        )
+                    else:
+                        raise RuntimeError(
+                            f"{base_error_msg}"
+                            f"one possibility is that there is a "
+                            f"conflicting cuDNN in LD_LIBRARY_PATH."
+                        )
+                else:
+                    raise RuntimeError(base_error_msg)
+
+        return True
+
+else:
+
+    def _init():
+        return False
+
+
+def version():
+    """Return the version of cuDNN."""
+    if not _init():
+        return None
+    return __cudnn_version
+
+
+CUDNN_TENSOR_DTYPES = {
+    torch.half,
+    torch.float,
+    torch.double,
+}
+
+
+def is_available():
+    r"""Return a bool indicating if CUDNN is currently available."""
+    return torch._C._has_cudnn
+
+
+def is_acceptable(tensor):
+    if not torch._C._get_cudnn_enabled():
+        return False
+    if tensor.device.type != "cuda" or tensor.dtype not in CUDNN_TENSOR_DTYPES:
+        return False
+    if not is_available():
+        warnings.warn(
+            "PyTorch was compiled without cuDNN/MIOpen support. To use cuDNN/MIOpen, rebuild "
+            "PyTorch making sure the library is visible to the build system."
+        )
+        return False
+    if not _init():
+        warnings.warn(
+            "cuDNN/MIOpen library not found. Check your {libpath}".format(
+                libpath={"darwin": "DYLD_LIBRARY_PATH", "win32": "PATH"}.get(
+                    sys.platform, "LD_LIBRARY_PATH"
+                )
+            )
+        )
+        return False
+    return True
+
+
+def set_flags(
+    _enabled=None,
+    _benchmark=None,
+    _benchmark_limit=None,
+    _deterministic=None,
+    _allow_tf32=None,
+):
+    orig_flags = (
+        torch._C._get_cudnn_enabled(),
+        torch._C._get_cudnn_benchmark(),
+        None if not is_available() else torch._C._cuda_get_cudnn_benchmark_limit(),
+        torch._C._get_cudnn_deterministic(),
+        torch._C._get_cudnn_allow_tf32(),
+    )
+    if _enabled is not None:
+        torch._C._set_cudnn_enabled(_enabled)
+    if _benchmark is not None:
+        torch._C._set_cudnn_benchmark(_benchmark)
+    if _benchmark_limit is not None and is_available():
+        torch._C._cuda_set_cudnn_benchmark_limit(_benchmark_limit)
+    if _deterministic is not None:
+        torch._C._set_cudnn_deterministic(_deterministic)
+    if _allow_tf32 is not None:
+        torch._C._set_cudnn_allow_tf32(_allow_tf32)
+    return orig_flags
+
+
+@contextmanager
+def flags(
+    enabled=False,
+    benchmark=False,
+    benchmark_limit=10,
+    deterministic=False,
+    allow_tf32=True,
+):
+    with __allow_nonbracketed_mutation():
+        orig_flags = set_flags(
+            enabled, benchmark, benchmark_limit, deterministic, allow_tf32
+        )
+    try:
+        yield
+    finally:
+        # recover the previous values
+        with __allow_nonbracketed_mutation():
+            set_flags(*orig_flags)
+
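+# Illustrative usage sketch: temporarily enable cuDNN autotuning for a block of
+# convolutions; previous flag values are restored on exit (`model`, `inp` are
+# placeholders).
+#
+#   import torch
+#   with torch.backends.cudnn.flags(enabled=True, benchmark=True):
+#       out = model(inp)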
+
+# The magic here is to allow us to intercept code like this:
+#
+#   torch.backends.cudnn.enabled = True
+
+
+class CudnnModule(PropModule):
+    def __init__(self, m, name):
+        super().__init__(m, name)
+
+    enabled = ContextProp(torch._C._get_cudnn_enabled, torch._C._set_cudnn_enabled)
+    deterministic = ContextProp(
+        torch._C._get_cudnn_deterministic, torch._C._set_cudnn_deterministic
+    )
+    benchmark = ContextProp(
+        torch._C._get_cudnn_benchmark, torch._C._set_cudnn_benchmark
+    )
+    benchmark_limit = None
+    if is_available():
+        benchmark_limit = ContextProp(
+            torch._C._cuda_get_cudnn_benchmark_limit,
+            torch._C._cuda_set_cudnn_benchmark_limit,
+        )
+    allow_tf32 = ContextProp(
+        torch._C._get_cudnn_allow_tf32, torch._C._set_cudnn_allow_tf32
+    )
+
+
+# This is the sys.modules replacement trick, see
+# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
+sys.modules[__name__] = CudnnModule(sys.modules[__name__], __name__)
+
+# Add type annotation for the replaced module
+enabled: bool
+deterministic: bool
+benchmark: bool
+allow_tf32: bool
+benchmark_limit: int
diff --git a/.venv/lib/python3.12/site-packages/torch/backends/cudnn/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/backends/cudnn/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4cc56a51c004d46672357220e1784bb2ab0a5b23
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/backends/cudnn/__pycache__/__init__.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/backends/cudnn/rnn.py b/.venv/lib/python3.12/site-packages/torch/backends/cudnn/rnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b253e19054027e88c3ead9e536d0fa0cd0da3cc
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/backends/cudnn/rnn.py
@@ -0,0 +1,64 @@
+# mypy: allow-untyped-defs
+import torch.cuda
+
+
+try:
+    from torch._C import _cudnn
+except ImportError:
+    # Uses of all the functions below should be guarded by torch.backends.cudnn.is_available(),
+    # so it's safe to not emit any checks here.
+    _cudnn = None  # type: ignore[assignment]
+
+
+def get_cudnn_mode(mode):
+    if mode == "RNN_RELU":
+        return int(_cudnn.RNNMode.rnn_relu)
+    elif mode == "RNN_TANH":
+        return int(_cudnn.RNNMode.rnn_tanh)
+    elif mode == "LSTM":
+        return int(_cudnn.RNNMode.lstm)
+    elif mode == "GRU":
+        return int(_cudnn.RNNMode.gru)
+    else:
+        raise Exception(f"Unknown mode: {mode}")  # noqa: TRY002
+
+
+# NB: We don't actually need this class anymore (in fact, we could serialize the
+# dropout state for even better reproducibility), but it is kept for backwards
+# compatibility for old models.
+class Unserializable:
+    def __init__(self, inner):
+        self.inner = inner
+
+    def get(self):
+        return self.inner
+
+    def __getstate__(self):
+        # Note: can't return {}, because python2 won't call __setstate__
+        # if the value evaluates to False
+        return ""
+
+    def __setstate__(self, state):
+        self.inner = None
+
+
+def init_dropout_state(dropout, train, dropout_seed, dropout_state):
+    dropout_desc_name = "desc_" + str(torch.cuda.current_device())
+    dropout_p = dropout if train else 0
+    if (dropout_desc_name not in dropout_state) or (
+        dropout_state[dropout_desc_name].get() is None
+    ):
+        if dropout_p == 0:
+            dropout_state[dropout_desc_name] = Unserializable(None)
+        else:
+            dropout_state[dropout_desc_name] = Unserializable(
+                torch._cudnn_init_dropout_state(  # type: ignore[call-arg]
+                    dropout_p,
+                    train,
+                    dropout_seed,
+                    self_ty=torch.uint8,
+                    device=torch.device("cuda"),
+                )
+            )
+    dropout_ts = dropout_state[dropout_desc_name].get()
+    return dropout_ts
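`get_cudnn_mode` maps the RNN mode strings used by `torch.nn` onto the cuDNN enum values and raises on anything else, while `init_dropout_state` caches one dropout descriptor per CUDA device in the supplied `dropout_state` dict. A small sketch of the mode lookup, guarded by `torch.backends.cudnn.is_available()` as the comment above suggests; the printed values depend on the cuDNN build::

    import torch
    import torch.backends.cudnn.rnn as cudnn_rnn

    if torch.backends.cudnn.is_available():
        lstm_mode = cudnn_rnn.get_cudnn_mode("LSTM")  # int value of _cudnn.RNNMode.lstm
        gru_mode = cudnn_rnn.get_cudnn_mode("GRU")
        print(lstm_mode, gru_mode)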
diff --git a/.venv/lib/python3.12/site-packages/torch/backends/cusparselt/__init__.py b/.venv/lib/python3.12/site-packages/torch/backends/cusparselt/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..da46274a2846da188d02f9b5b66a7dbd7d2cd38f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/backends/cusparselt/__init__.py
@@ -0,0 +1,57 @@
+# mypy: allow-untyped-defs
+from typing import Optional
+
+import torch
+
+
+__all__ = [
+    "version",
+    "is_available",
+    "get_max_alg_id",
+]
+
+try:
+    from torch._C import _cusparselt
+except ImportError:
+    _cusparselt = None  # type: ignore[assignment]
+
+__cusparselt_version: Optional[int] = None
+__MAX_ALG_ID: Optional[int] = None
+
+if _cusparselt is not None:
+
+    def _init():
+        global __cusparselt_version
+        global __MAX_ALG_ID
+        if __cusparselt_version is None:
+            __cusparselt_version = _cusparselt.getVersionInt()
+            if __cusparselt_version == 400:
+                __MAX_ALG_ID = 4
+            elif __cusparselt_version == 502:
+                __MAX_ALG_ID = 5
+            elif __cusparselt_version == 602:
+                __MAX_ALG_ID = 37
+        return True
+
+else:
+
+    def _init():
+        return False
+
+
+def version() -> Optional[int]:
+    """Return the version of cuSPARSELt"""
+    if not _init():
+        return None
+    return __cusparselt_version
+
+
+def is_available() -> bool:
+    r"""Return a bool indicating if cuSPARSELt is currently available."""
+    return torch._C._has_cusparselt
+
+
+def get_max_alg_id() -> Optional[int]:
+    if not _init():
+        return None
+    return __MAX_ALG_ID
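Both `version()` and `get_max_alg_id()` return ``None`` when the cuSPARSELt bindings are missing, so callers can probe support without exception handling. A minimal sketch; the printed values depend on how the local PyTorch build was compiled::

    import torch
    import torch.backends.cusparselt as cusparselt

    if cusparselt.is_available():
        print("cuSPARSELt version:", cusparselt.version())
        print("max algorithm id:", cusparselt.get_max_alg_id())
    else:
        print("cuSPARSELt is not available in this build")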
diff --git a/.venv/lib/python3.12/site-packages/torch/backends/cusparselt/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/backends/cusparselt/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1548009f327436ba3bbd80d0d35c79cbafc47e5b
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/backends/cusparselt/__pycache__/__init__.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/backends/kleidiai/__init__.py b/.venv/lib/python3.12/site-packages/torch/backends/kleidiai/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a681b77ef58ce1f390232b82c4a9843d5559ca3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/backends/kleidiai/__init__.py
@@ -0,0 +1,7 @@
+# mypy: allow-untyped-defs
+import torch
+
+
+def is_available():
+    r"""Return whether PyTorch is built with KleidiAI support."""
+    return torch._C._has_kleidiai
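Even a single-function backend like this is useful for gating Arm-specific code paths. A tiny sketch; KleidiAI kernels are only present in builds compiled with them, typically on AArch64::

    import torch
    import torch.backends.kleidiai as kleidiai

    print("KleidiAI kernels available:", kleidiai.is_available())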
diff --git a/.venv/lib/python3.12/site-packages/torch/backends/kleidiai/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/backends/kleidiai/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e3136a590e23c3602ca5def16b155a565771a9c5
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/backends/kleidiai/__pycache__/__init__.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/backends/mha/__init__.py b/.venv/lib/python3.12/site-packages/torch/backends/mha/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1dd2ebd688805bdf3359cb56b64d0854cf258c4
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/backends/mha/__init__.py
@@ -0,0 +1,25 @@
+# Config options to enable/disable C++ kernel for nn.functional.MHA
+# and nn.TransformerEncoder
+import torch
+
+
+_is_fastpath_enabled: bool = True
+
+
+def get_fastpath_enabled() -> bool:
+    """Returns whether fast path for TransformerEncoder and MultiHeadAttention
+    is enabled, or ``True`` if jit is scripting.
+
+    .. note::
+        The fastpath might not be run even if ``get_fastpath_enabled`` returns
+        ``True`` unless all conditions on inputs are met.
+    """
+    if not torch.jit.is_scripting():
+        return _is_fastpath_enabled
+    return True
+
+
+def set_fastpath_enabled(value: bool) -> None:
+    """Sets whether fast path is enabled"""
+    global _is_fastpath_enabled
+    _is_fastpath_enabled = value
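Because the fastpath toggle is a module-level global rather than a context manager, callers are responsible for restoring the previous value. A small sketch, assuming a `torch.nn.TransformerEncoder`; the model dimensions and input below are illustrative::

    import torch
    import torch.backends.mha as mha

    layer = torch.nn.TransformerEncoderLayer(d_model=16, nhead=4, batch_first=True)
    encoder = torch.nn.TransformerEncoder(layer, num_layers=1)
    x = torch.randn(2, 5, 16)

    previous = mha.get_fastpath_enabled()
    mha.set_fastpath_enabled(False)  # force the generic (non-fused) path
    try:
        out = encoder(x)
    finally:
        mha.set_fastpath_enabled(previous)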
diff --git a/.venv/lib/python3.12/site-packages/torch/backends/mha/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/backends/mha/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e96298aa616981ef0566b1d35493bcab7919ccd9
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/backends/mha/__pycache__/__init__.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/backends/mkl/__init__.py b/.venv/lib/python3.12/site-packages/torch/backends/mkl/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f96d692ae02190457b1e38cceadb62d95ab698a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/backends/mkl/__init__.py
@@ -0,0 +1,57 @@
+# mypy: allow-untyped-defs
+import torch
+
+
+def is_available():
+    r"""Return whether PyTorch is built with MKL support."""
+    return torch._C.has_mkl
+
+
+VERBOSE_OFF = 0
+VERBOSE_ON = 1
+
+
+class verbose:
+    """
+    On-demand oneMKL verbosing functionality.
+
+    To make it easier to debug performance issues, oneMKL can dump verbose
+    messages containing execution information like duration while executing
+    the kernel. The verbosing functionality can be invoked via an environment
+    variable named `MKL_VERBOSE`. However, this methodology dumps messages in
+    all steps, which produces a large amount of verbose output. Moreover, for
+    investigating performance issues, verbose messages for a single iteration
+    are usually enough. This on-demand verbosing functionality
+    makes it possible to control scope for verbose message dumping. In the
+    following example, verbose messages will be dumped out for the second
+    inference only.
+
+    .. highlight:: python
+    .. code-block:: python
+
+        import torch
+        model(data)
+        with torch.backends.mkl.verbose(torch.backends.mkl.VERBOSE_ON):
+            model(data)
+
+    Args:
+        level: Verbose level
+            - ``VERBOSE_OFF``: Disable verbosing
+            - ``VERBOSE_ON``:  Enable verbosing
+    """
+
+    def __init__(self, enable):
+        self.enable = enable
+
+    def __enter__(self):
+        if self.enable == VERBOSE_OFF:
+            return
+        st = torch._C._verbose.mkl_set_verbose(self.enable)
+        assert (
+            st
+        ), "Failed to set MKL into verbose mode. Please consider to disable this verbose scope."
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        torch._C._verbose.mkl_set_verbose(VERBOSE_OFF)
+        return False
diff --git a/.venv/lib/python3.12/site-packages/torch/backends/mkl/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/backends/mkl/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3a0402ef18dc02676f5b434aaa6e7d1bb4c0c21c
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/backends/mkl/__pycache__/__init__.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/backends/mkldnn/__init__.py b/.venv/lib/python3.12/site-packages/torch/backends/mkldnn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c97bcd9b079ba06226f210bfd227b2a79a3b216
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/backends/mkldnn/__init__.py
@@ -0,0 +1,114 @@
+# mypy: allow-untyped-defs
+import sys
+from contextlib import contextmanager
+from typing import TYPE_CHECKING
+
+import torch
+from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule
+
+
+def is_available():
+    r"""Return whether PyTorch is built with MKL-DNN support."""
+    return torch._C._has_mkldnn
+
+
+VERBOSE_OFF = 0
+VERBOSE_ON = 1
+VERBOSE_ON_CREATION = 2
+
+
+class verbose:
+    """
+    On-demand oneDNN (former MKL-DNN) verbosing functionality.
+
+    To make it easier to debug performance issues, oneDNN can dump verbose
+    messages containing information like kernel size, input data size and
+    execution duration while executing the kernel. The verbosing functionality
+    can be invoked via an environment variable named `DNNL_VERBOSE`. However,
+    this methodology dumps messages in all steps, which produces a large amount
+    of verbose output. Moreover, for investigating performance issues, verbose
+    messages for a single iteration are usually enough.
+    This on-demand verbosing functionality makes it possible to control scope
+    for verbose message dumping. In the following example, verbose messages
+    will be dumped out for the second inference only.
+
+    .. highlight:: python
+    .. code-block:: python
+
+        import torch
+        model(data)
+        with torch.backends.mkldnn.verbose(torch.backends.mkldnn.VERBOSE_ON):
+            model(data)
+
+    Args:
+        level: Verbose level
+            - ``VERBOSE_OFF``: Disable verbosing
+            - ``VERBOSE_ON``:  Enable verbosing
+            - ``VERBOSE_ON_CREATION``: Enable verbosing, including oneDNN kernel creation
+    """
+
+    def __init__(self, level):
+        self.level = level
+
+    def __enter__(self):
+        if self.level == VERBOSE_OFF:
+            return
+        st = torch._C._verbose.mkldnn_set_verbose(self.level)
+        assert (
+            st
+        ), "Failed to set MKLDNN into verbose mode. Please consider to disable this verbose scope."
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        torch._C._verbose.mkldnn_set_verbose(VERBOSE_OFF)
+        return False
+
+
+def set_flags(_enabled=None, _deterministic=None, _allow_tf32=None):
+    orig_flags = (
+        torch._C._get_mkldnn_enabled(),
+        torch._C._get_mkldnn_deterministic(),
+        torch._C._get_onednn_allow_tf32(),
+    )
+    if _enabled is not None:
+        torch._C._set_mkldnn_enabled(_enabled)
+    if _deterministic is not None:
+        torch._C._set_mkldnn_deterministic(_deterministic)
+    if _allow_tf32 is not None:
+        torch._C._set_onednn_allow_tf32(_allow_tf32)
+    return orig_flags
+
+
+@contextmanager
+def flags(enabled=False, deterministic=False, allow_tf32=True):
+    with __allow_nonbracketed_mutation():
+        orig_flags = set_flags(enabled, deterministic, allow_tf32)
+    try:
+        yield
+    finally:
+        with __allow_nonbracketed_mutation():
+            set_flags(*orig_flags)
+
+
+class MkldnnModule(PropModule):
+    def __init__(self, m, name):
+        super().__init__(m, name)
+
+    def is_available(self):
+        return is_available()
+
+    enabled = ContextProp(torch._C._get_mkldnn_enabled, torch._C._set_mkldnn_enabled)
+    deterministic = ContextProp(
+        torch._C._get_mkldnn_deterministic, torch._C._set_mkldnn_deterministic
+    )
+    allow_tf32 = ContextProp(
+        torch._C._get_onednn_allow_tf32, torch._C._set_onednn_allow_tf32
+    )
+
+
+if TYPE_CHECKING:
+    enabled: ContextProp
+    deterministic: ContextProp
+    allow_tf32: ContextProp
+
+sys.modules[__name__] = MkldnnModule(sys.modules[__name__], __name__)
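Like the cuDNN module, the oneDNN backend is wrapped in a `PropModule`, so the same flags can be flipped either as attributes or inside the scoped `flags()` context manager. A minimal CPU-only sketch; the convolution shapes are illustrative::

    import torch

    print(torch.backends.mkldnn.is_available())

    # Scoped: disable oneDNN kernels for this block only, then restore.
    with torch.backends.mkldnn.flags(enabled=False):
        out = torch.nn.functional.conv2d(torch.randn(1, 3, 8, 8), torch.randn(4, 3, 3, 3))

    # Attribute style: a global toggle.
    torch.backends.mkldnn.enabled = True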
diff --git a/.venv/lib/python3.12/site-packages/torch/backends/mkldnn/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/backends/mkldnn/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b88fc6b2cbedb251d544b9dba5950cbdc86cb4ee
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/backends/mkldnn/__pycache__/__init__.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/backends/mps/__init__.py b/.venv/lib/python3.12/site-packages/torch/backends/mps/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3b934a6ced4c3b5e7be9b798c2c54f13906bd4f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/backends/mps/__init__.py
@@ -0,0 +1,55 @@
+# mypy: allow-untyped-defs
+from functools import lru_cache as _lru_cache
+from typing import Optional, TYPE_CHECKING
+
+import torch
+from torch.library import Library as _Library
+
+
+__all__ = ["is_built", "is_available", "is_macos13_or_newer", "is_macos_or_newer"]
+
+
+def is_built() -> bool:
+    r"""Return whether PyTorch is built with MPS support.
+
+    Note that this doesn't necessarily mean MPS is available; just that
+    if this PyTorch binary were run on a machine with working MPS drivers
+    and devices, we would be able to use it.
+    """
+    return torch._C._has_mps
+
+
+@_lru_cache
+def is_available() -> bool:
+    r"""Return a bool indicating if MPS is currently available."""
+    return torch._C._mps_is_available()
+
+
+@_lru_cache
+def is_macos_or_newer(major: int, minor: int) -> bool:
+    r"""Return a bool indicating whether MPS is running on given MacOS or newer."""
+    return torch._C._mps_is_on_macos_or_newer(major, minor)
+
+
+@_lru_cache
+def is_macos13_or_newer(minor: int = 0) -> bool:
+    r"""Return a bool indicating whether MPS is running on MacOS 13 or newer."""
+    return torch._C._mps_is_on_macos_or_newer(13, minor)
+
+
+_lib: Optional[_Library] = None
+
+
+def _init():
+    r"""Register prims as implementation of var_mean and group_norm."""
+    global _lib
+
+    if _lib is not None or not is_built():
+        return
+
+    from torch._decomp.decompositions import native_group_norm_backward
+    from torch._refs import native_group_norm
+
+    _lib = _Library("aten", "IMPL")  # noqa: TOR901
+    _lib.impl("native_group_norm", native_group_norm, "MPS")
+    _lib.impl("native_group_norm_backward", native_group_norm_backward, "MPS")
diff --git a/.venv/lib/python3.12/site-packages/torch/backends/mps/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/backends/mps/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d827320d37d528c4b8b2d56717d89a69648676ce
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/backends/mps/__pycache__/__init__.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/backends/nnpack/__init__.py b/.venv/lib/python3.12/site-packages/torch/backends/nnpack/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d8a72f3cda9b0da16702c0d7c6fe92ae8f3f153
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/backends/nnpack/__init__.py
@@ -0,0 +1,32 @@
+# mypy: allow-untyped-defs
+from contextlib import contextmanager
+
+import torch
+from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule
+
+
+__all__ = ["is_available", "flags", "set_flags"]
+
+
+def is_available():
+    r"""Return whether PyTorch is built with NNPACK support."""
+    return torch._nnpack_available()
+
+
+def set_flags(_enabled):
+    r"""Set if nnpack is enabled globally"""
+    orig_flags = (torch._C._get_nnpack_enabled(),)
+    torch._C._set_nnpack_enabled(_enabled)
+    return orig_flags
+
+
+@contextmanager
+def flags(enabled=False):
+    r"""Context manager for setting if nnpack is enabled globally"""
+    with __allow_nonbracketed_mutation():
+        orig_flags = set_flags(enabled)
+    try:
+        yield
+    finally:
+        with __allow_nonbracketed_mutation():
+            set_flags(orig_flags[0])
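`set_flags` here takes a single positional value and returns a one-element tuple, which is why the context manager unpacks `orig_flags[0]` on exit. Usage mirrors the other backends; a minimal CPU-only sketch with illustrative shapes::

    import torch
    import torch.backends.nnpack as nnpack

    print(nnpack.is_available())

    # Temporarily disable NNPACK kernels for this block only.
    with nnpack.flags(enabled=False):
        out = torch.nn.functional.conv2d(torch.randn(1, 3, 8, 8), torch.randn(4, 3, 3, 3))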
diff --git a/.venv/lib/python3.12/site-packages/torch/backends/nnpack/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/backends/nnpack/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..befc2760d687f82569d236b43f1f4ee9ee2728e6
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/backends/nnpack/__pycache__/__init__.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/backends/openmp/__init__.py b/.venv/lib/python3.12/site-packages/torch/backends/openmp/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..aff8d46cd4ac2d9ff49942542d99ac2afbb85896
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/backends/openmp/__init__.py
@@ -0,0 +1,7 @@
+# mypy: allow-untyped-defs
+import torch
+
+
+def is_available():
+    r"""Return whether PyTorch is built with OpenMP support."""
+    return torch._C.has_openmp
diff --git a/.venv/lib/python3.12/site-packages/torch/backends/openmp/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/backends/openmp/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..875b1e42201d78d87f2fe5ca134739faf23889e9
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/backends/openmp/__pycache__/__init__.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/backends/opt_einsum/__init__.py b/.venv/lib/python3.12/site-packages/torch/backends/opt_einsum/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..73c107cc1e44832398d3ee03ea2d6073c13af541
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/backends/opt_einsum/__init__.py
@@ -0,0 +1,119 @@
+# mypy: allow-untyped-defs
+import sys
+import warnings
+from contextlib import contextmanager
+from functools import lru_cache as _lru_cache
+from typing import Any
+
+from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule
+
+
+try:
+    import opt_einsum as _opt_einsum  # type: ignore[import]
+except ImportError:
+    _opt_einsum = None
+
+
+@_lru_cache
+def is_available() -> bool:
+    r"""Return a bool indicating if opt_einsum is currently available.
+
+    You must install opt-einsum in order for torch to automatically optimize einsum. To
+    make opt-einsum available, you can install it along with torch: ``pip install torch[opt-einsum]``
+    or by itself: ``pip install opt-einsum``. If the package is installed, torch will import
+    it automatically and use it accordingly. Use this function to check whether opt-einsum
+    was installed and properly imported by torch.
+    """
+    return _opt_einsum is not None
+
+
+def get_opt_einsum() -> Any:
+    r"""Return the opt_einsum package if opt_einsum is currently available, else None."""
+    return _opt_einsum
+
+
+def _set_enabled(_enabled: bool) -> None:
+    if not is_available() and _enabled:
+        raise ValueError(
+            f"opt_einsum is not available, so setting `enabled` to {_enabled} will not reap "
+            "the benefits of calculating an optimal path for einsum. torch.einsum will "
+            "fall back to contracting from left to right. To enable this optimal path "
+            "calculation, please install opt-einsum."
+        )
+    global enabled
+    enabled = _enabled
+
+
+def _get_enabled() -> bool:
+    return enabled
+
+
+def _set_strategy(_strategy: str) -> None:
+    if not is_available():
+        raise ValueError(
+            f"opt_einsum is not available, so setting `strategy` to {_strategy} will not be meaningful. "
+            "torch.einsum will bypass path calculation and simply contract from left to right. "
+            "Please install opt_einsum or unset `strategy`."
+        )
+    if not enabled:
+        raise ValueError(
+            f"opt_einsum is not enabled, so setting a `strategy` to {_strategy} will not be meaningful. "
+            "torch.einsum will bypass path calculation and simply contract from left to right. "
+            "Please set `enabled` to `True` as well or unset `strategy`."
+        )
+    if _strategy not in ["auto", "greedy", "optimal"]:
+        raise ValueError(
+            f"`strategy` must be one of the following: [auto, greedy, optimal] but is {_strategy}"
+        )
+    global strategy
+    strategy = _strategy
+
+
+def _get_strategy() -> str:
+    return strategy
+
+
+def set_flags(_enabled=None, _strategy=None):
+    orig_flags = (enabled, None if not is_available() else strategy)
+    if _enabled is not None:
+        _set_enabled(_enabled)
+    if _strategy is not None:
+        _set_strategy(_strategy)
+    return orig_flags
+
+
+@contextmanager
+def flags(enabled=None, strategy=None):
+    with __allow_nonbracketed_mutation():
+        orig_flags = set_flags(enabled, strategy)
+    try:
+        yield
+    finally:
+        # recover the previous values
+        with __allow_nonbracketed_mutation():
+            set_flags(*orig_flags)
+
+
+# The magic here is to allow us to intercept code like this:
+#
+#   torch.backends.opt_einsum.enabled = True
+
+
+class OptEinsumModule(PropModule):
+    def __init__(self, m, name):
+        super().__init__(m, name)
+
+    global enabled
+    enabled = ContextProp(_get_enabled, _set_enabled)
+    global strategy
+    strategy = None
+    if is_available():
+        strategy = ContextProp(_get_strategy, _set_strategy)
+
+
+# This is the sys.modules replacement trick, see
+# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
+sys.modules[__name__] = OptEinsumModule(sys.modules[__name__], __name__)
+
+enabled = True if is_available() else False
+strategy = "auto" if is_available() else None
diff --git a/.venv/lib/python3.12/site-packages/torch/backends/quantized/__init__.py b/.venv/lib/python3.12/site-packages/torch/backends/quantized/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..caabfdf243783f2161a201c6a6ec9bd6eca83b18
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/backends/quantized/__init__.py
@@ -0,0 +1,65 @@
+# mypy: allow-untyped-defs
+import sys
+import types
+
+import torch
+
+
+# This function should correspond to the enums present in c10/core/QEngine.h
+def _get_qengine_id(qengine: str) -> int:
+    if qengine == "none" or qengine == "" or qengine is None:
+        ret = 0
+    elif qengine == "fbgemm":
+        ret = 1
+    elif qengine == "qnnpack":
+        ret = 2
+    elif qengine == "onednn":
+        ret = 3
+    elif qengine == "x86":
+        ret = 4
+    else:
+        ret = -1
+        raise RuntimeError(f"{qengine} is not a valid value for quantized engine")
+    return ret
+
+
+# This function should correspond to the enums present in c10/core/QEngine.h
+def _get_qengine_str(qengine: int) -> str:
+    all_engines = {0: "none", 1: "fbgemm", 2: "qnnpack", 3: "onednn", 4: "x86"}
+    return all_engines.get(qengine, "*undefined")
+
+
+class _QEngineProp:
+    def __get__(self, obj, objtype) -> str:
+        return _get_qengine_str(torch._C._get_qengine())
+
+    def __set__(self, obj, val: str) -> None:
+        torch._C._set_qengine(_get_qengine_id(val))
+
+
+class _SupportedQEnginesProp:
+    def __get__(self, obj, objtype) -> list[str]:
+        qengines = torch._C._supported_qengines()
+        return [_get_qengine_str(qe) for qe in qengines]
+
+    def __set__(self, obj, val) -> None:
+        raise RuntimeError("Assignment not supported")
+
+
+class QuantizedEngine(types.ModuleType):
+    def __init__(self, m, name):
+        super().__init__(name)
+        self.m = m
+
+    def __getattr__(self, attr):
+        return self.m.__getattribute__(attr)
+
+    engine = _QEngineProp()
+    supported_engines = _SupportedQEnginesProp()
+
+
+# This is the sys.modules replacement trick, see
+# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
+sys.modules[__name__] = QuantizedEngine(sys.modules[__name__], __name__)
+engine: str
+supported_engines: list[str]
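`engine` and `supported_engines` behave like ordinary module attributes thanks to the `sys.modules` replacement; assigning a name that `_get_qengine_id` does not recognize raises a ``RuntimeError``. A minimal sketch; which engines are listed depends on the build and the CPU::

    import torch

    print(torch.backends.quantized.supported_engines)  # e.g. ['none', 'fbgemm', ...]

    if "fbgemm" in torch.backends.quantized.supported_engines:
        torch.backends.quantized.engine = "fbgemm"
    print(torch.backends.quantized.engine)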
diff --git a/.venv/lib/python3.12/site-packages/torch/backends/quantized/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/backends/quantized/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2cb73848ed5b94715c9d668a5242edee222f7e12
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/backends/quantized/__pycache__/__init__.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/backends/xeon/__init__.py b/.venv/lib/python3.12/site-packages/torch/backends/xeon/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.12/site-packages/torch/backends/xeon/run_cpu.py b/.venv/lib/python3.12/site-packages/torch/backends/xeon/run_cpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..5dcff85a8ec43d8b604c1bc97c71056b192a324b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/backends/xeon/run_cpu.py
@@ -0,0 +1,942 @@
+# mypy: allow-untyped-defs
+"""
+This is a script for launching PyTorch inference on Intel(R) Xeon(R) Scalable Processors with optimal configurations.
+
+Single instance inference, multi-instance inference are enabled.
+
+Note: term "instance" here doesn't refer to a cloud instance. This script is executed as a single process. It invokes
+multiple "instances" which are formed from multiple threads for each. "instance" is kind of group of threads in this
+context.
+
+Illustrated as below:
+
+::
+
+    +-----------------------------+----------------------+-------+
+    |            process          |        thread        | core  |
+    +=============================+======================+=======+
+    | torch.backends.xeon.run_cpu | instance 0: thread 0 |   0   |
+    |                             |             thread 1 |   1   |
+    |                             +----------------------+-------+
+    |                             | instance 1: thread 0 |   2   |
+    |                             |             thread 1 |   3   |
+    |                             +----------------------+-------+
+    |                             | ...                  |  ...  |
+    |                             +----------------------+-------+
+    |                             | instance N: thread 0 |   M   |
+    |                             |             thread 1 |  M+1  |
+    +-----------------------------+----------------------+-------+
+
+To get the peak performance on Intel(R) Xeon(R) Scalable Processors, the script optimizes the configuration of thread and memory
+management. For thread management, the script configures thread affinity and preloads the Intel OMP library.
+For memory management, it configures NUMA binding and preloads an optimized memory allocation library (e.g. tcmalloc, jemalloc).
+
+Environment variables that will be set by this script:
+
++------------------+-------------------------------------------------------------------------------------------------+
+| Environ Variable |                                             Value                                               |
++==================+=================================================================================================+
+|    LD_PRELOAD    | Depending on knobs you set, <lib>/libiomp5.so, <lib>/libjemalloc.so, <lib>/libtcmalloc.so might |
+|                  | be appended to LD_PRELOAD.                                                                      |
++------------------+-------------------------------------------------------------------------------------------------+
+|   KMP_AFFINITY   | If libiomp5.so is preloaded, KMP_AFFINITY could be set to "granularity=fine,compact,1,0".       |
++------------------+-------------------------------------------------------------------------------------------------+
+|   KMP_BLOCKTIME  | If libiomp5.so is preloaded, KMP_BLOCKTIME is set to "1".                                       |
++------------------+-------------------------------------------------------------------------------------------------+
+|  OMP_NUM_THREADS | value of ncores_per_instance                                                                    |
++------------------+-------------------------------------------------------------------------------------------------+
+|    MALLOC_CONF   | If libjemalloc.so is preloaded, MALLOC_CONF will be set to                                      |
+|                  | "oversize_threshold:1,background_thread:true,metadata_thp:auto".                                |
++------------------+-------------------------------------------------------------------------------------------------+
+
+*Note*: This script respects environment variables set beforehand, i.e. if you set the environment variables
+mentioned above before running the script, the script will not overwrite their values.
+
+How to use this module:
+~~~~~~~~~~~~~~~~~~~~~~~
+
+Single instance inference
+-------------------------
+
+1. Run single-instance inference on a single node with all CPU nodes.
+
+::
+
+   python -m torch.backends.xeon.run_cpu --throughput-mode script.py args
+
+2. Run single-instance inference on a single CPU node.
+
+::
+
+   python -m torch.backends.xeon.run_cpu --node-id 1 script.py args
+
+Multi-instance inference
+------------------------
+
+1. Multi-instance
+   By default this tool runs one process per node. If you want to set the number of instances and the cores per instance,
+   --ninstances and --ncores-per-instance should be set.
+
+::
+
+   python -m torch.backends.xeon.run_cpu -- python_script args
+
+   eg: on an Intel(R) Xeon(R) Scalable Processor with 14 instances, 4 cores per instance
+
+::
+
+   python -m torch.backends.xeon.run_cpu --ninstances 14 --ncores-per-instance 4 python_script args
+
+2. Run single-instance inference among multiple instances.
+   By default, runs all ninstances. If you want to independently run a single instance among ninstances, specify rank.
+
+   eg: run 0th instance on an Intel(R) Xeon(R) Scalable Processor with 2 instances (i.e., numactl -C 0-27)
+
+::
+
+   python -m torch.backends.xeon.run_cpu --ninstances 2 --rank 0 python_script args
+
+   eg: run 1st instance on an Intel(R) Xeon(R) Scalable Processor with 2 instances (i.e., numactl -C 28-55)
+
+::
+
+   python -m torch.backends.xeon.run_cpu --ninstances 2 --rank 1 python_script args
+
+   eg: run 0th instance on an Intel(R) Xeon(R) Scalable Processor with 2 instances, 2 cores per instance,
+   first four cores (i.e., numactl -C 0-1)
+
+::
+
+   python -m torch.backends.xeon.run_cpu --core-list "0, 1, 2, 3" --ninstances 2 --ncores-per-instance 2
+   --rank 0 python_script args
+
+3. To look up what optional arguments this module offers:
+
+::
+
+    python -m torch.backends.xeon.run_cpu --help
+
+Memory allocator
+----------------
+
+"--enable-tcmalloc" and "--enable-jemalloc" can be used to enable different memory allcator.
+
+"""
+
+import glob
+import logging
+import os
+import platform
+import re
+import subprocess
+import sys
+from argparse import ArgumentParser, RawTextHelpFormatter, REMAINDER
+from os.path import expanduser
+
+from torch.distributed.elastic.multiprocessing import (
+    DefaultLogsSpecs as _DefaultLogsSpecs,
+    start_processes,
+    Std,
+)
+
+
+format_str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+logging.basicConfig(level=logging.INFO, format=format_str)
+logger = logging.getLogger(__name__)
+
+
+class _CPUinfo:
+    """Get CPU information, such as cores list and NUMA information."""
+
+    def __init__(self, test_input=""):
+        self.cpuinfo = []
+        if platform.system() in ["Windows", "Darwin"]:
+            raise RuntimeError(f"{platform.system()} is not supported!!!")
+        elif platform.system() == "Linux":
+            # Sample output of: `lscpu --parse=CPU,Core,Socket,Node`
+            #
+            # # The following is the parsable format, which can be fed to other
+            # # programs. Each different item in every column has an unique ID
+            # # starting from zero.
+            # # CPU,Core,Socket,Node
+            # 0,0,0,0
+            # 1,1,0,0
+            # ...
+            if test_input == "":
+                lscpu_cmd = ["lscpu", "--parse=CPU,Core,Socket,Node"]
+                lscpu_info = subprocess.check_output(
+                    lscpu_cmd, universal_newlines=True
+                ).split("\n")
+            else:
+                lscpu_info = test_input.split("\n")
+
+            # Get information about  cpu, core, socket and node
+            for line in lscpu_info:
+                pattern = r"^([\d]+,[\d]+,[\d]+,[\d]?)"
+                regex_out = re.search(pattern, line)
+                if regex_out:
+                    self.cpuinfo.append(regex_out.group(1).strip().split(","))
+
+            # physical cores := Core column in lscpu output
+            #  logical cores := CPU column in lscpu output
+            self.node_nums = int(max(line[3] for line in self.cpuinfo)) + 1
+            self.node_physical_cores: list[list[int]] = []  # node_id is index
+            self.node_logical_cores: list[list[int]] = []  # node_id is index
+            self.physical_core_node_map = {}  # physical core to numa node id
+            self.logical_core_node_map = {}  # logical core to numa node id
+
+            for node_id in range(self.node_nums):
+                cur_node_physical_core = []
+                cur_node_logical_core = []
+                for cpuinfo in self.cpuinfo:
+                    nid = cpuinfo[3] if cpuinfo[3] != "" else "0"
+                    if node_id == int(nid):
+                        if int(cpuinfo[1]) not in cur_node_physical_core:
+                            cur_node_physical_core.append(int(cpuinfo[1]))
+                            self.physical_core_node_map[int(cpuinfo[1])] = int(node_id)
+                        cur_node_logical_core.append(int(cpuinfo[0]))
+                        self.logical_core_node_map[int(cpuinfo[0])] = int(node_id)
+                self.node_physical_cores.append(cur_node_physical_core)
+                self.node_logical_cores.append(cur_node_logical_core)
+
+    def _physical_core_nums(self):
+        return len(self.node_physical_cores) * len(self.node_physical_cores[0])
+
+    def _logical_core_nums(self):
+        return len(self.node_logical_cores) * len(self.node_logical_cores[0])
+
+    def get_node_physical_cores(self, node_id):
+        if node_id < 0 or node_id > self.node_nums - 1:
+            raise ValueError(
+                f"Invalid node id: {node_id}. Valid node ids: {list(range(len(self.node_physical_cores)))}"
+            )
+        return self.node_physical_cores[node_id]
+
+    def get_node_logical_cores(self, node_id):
+        if node_id < 0 or node_id > self.node_nums - 1:
+            raise ValueError(
+                f"Invalid node id: {node_id}. Valid node ids: {list(range(len(self.node_physical_cores)))}"
+            )
+        return self.node_logical_cores[node_id]
+
+    def get_all_physical_cores(self):
+        all_cores = []
+        for cores in self.node_physical_cores:
+            all_cores.extend(cores)
+        return all_cores
+
+    def get_all_logical_cores(self):
+        all_cores = []
+        for cores in self.node_logical_cores:
+            all_cores.extend(cores)
+        return all_cores
+
+    def numa_aware_check(self, core_list):
+        """
+        Check whether all cores in core_list are in the same NUMA node.
+
+        Cross NUMA will reduce performance.
+        We strongly advise not to use cores on different nodes.
+        """
+        cores_numa_map = self.logical_core_node_map
+        numa_ids = []
+        for core in core_list:
+            numa_id = cores_numa_map[core]
+            if numa_id not in numa_ids:
+                numa_ids.append(numa_id)
+        if len(numa_ids) > 1:
+            logger.warning(
+                "Numa Aware: cores:%s on different NUMA nodes:%s. To avoid \
+this behavior, please use --ncores-per-instance knob to make sure number of cores is divisible by --ncores-per-\
+instance. Alternatively, please use --skip-cross-node-cores knob.",
+                str(core_list),
+                str(numa_ids),
+            )
+        if len(numa_ids) == 0:
+            raise RuntimeError(
+                "invalid number of NUMA nodes; please make sure numa_ids >= 1"
+            )
+        return numa_ids
+
+
+class _Launcher:
+    r"""Class for launcher."""
+
+    msg_lib_notfound = f"Unable to find the {{0}} library file lib{{1}}.so in $CONDA_PREFIX/lib or $VIRTUAL_ENV/lib \
+or /.local/lib/ or /usr/local/lib/ or /usr/local/lib64/ or /usr/lib or /usr/lib64 or \
+{expanduser('~')}/.local/lib/ so the LD_PRELOAD environment variable will not be set."
+
+    def __init__(self) -> None:
+        self.cpuinfo = _CPUinfo()
+
+    def add_lib_preload(self, lib_type):
+        """Enable TCMalloc/JeMalloc/intel OpenMP."""
+        library_paths = []
+        if "CONDA_PREFIX" in os.environ:
+            library_paths.append(f"{os.environ['CONDA_PREFIX']}/lib")
+        if "VIRTUAL_ENV" in os.environ:
+            library_paths.append(f"{os.environ['VIRTUAL_ENV']}/lib")
+
+        library_paths += [
+            f"{expanduser('~')}/.local/lib",
+            "/usr/local/lib",
+            "/usr/local/lib64",
+            "/usr/lib",
+            "/usr/lib64",
+        ]
+
+        lib_find = False
+        lib_set = False
+        for item in os.getenv("LD_PRELOAD", "").split(":"):
+            if item.endswith(f"lib{lib_type}.so"):
+                lib_set = True
+                break
+        if not lib_set:
+            for lib_path in library_paths:
+                library_file = os.path.join(lib_path, f"lib{lib_type}.so")
+                matches = glob.glob(library_file)
+                if len(matches) > 0:
+                    ld_preloads = [f"{matches[0]}", os.getenv("LD_PRELOAD", "")]
+                    os.environ["LD_PRELOAD"] = os.pathsep.join(
+                        [p.strip(os.pathsep) for p in ld_preloads if p]
+                    )
+                    lib_find = True
+                    break
+        return lib_set or lib_find
+
+    def is_numactl_available(self):
+        numactl_available = False
+        try:
+            cmd = ["numactl", "-C", "0", "-m", "0", "hostname"]
+            r = subprocess.run(
+                cmd,
+                env=os.environ,
+                stdout=subprocess.DEVNULL,
+                stderr=subprocess.DEVNULL,
+                check=False,
+            )
+            if r.returncode == 0:
+                numactl_available = True
+        except Exception:
+            pass
+        return numactl_available
+
+    def set_memory_allocator(
+        self, enable_tcmalloc=True, enable_jemalloc=False, use_default_allocator=False
+    ):
+        """
+        Enable TCMalloc/JeMalloc with LD_PRELOAD and set configuration for JeMalloc.
+
+        By default, PTMalloc will be used for PyTorch, but TCMalloc and JeMalloc can get better
+        memory reuse and reduce page faults to improve performance.
+        """
+        if enable_tcmalloc and enable_jemalloc:
+            raise RuntimeError(
+                "Unable to enable TCMalloc and JEMalloc at the same time."
+            )
+
+        if enable_tcmalloc:
+            find_tc = self.add_lib_preload(lib_type="tcmalloc")
+            if not find_tc:
+                msg = f'{self.msg_lib_notfound} you can use "conda install -c conda-forge gperftools" to install {{0}}'
+                logger.warning(msg.format("TCmalloc", "tcmalloc"))  # noqa: G001
+            else:
+                logger.info("Use TCMalloc memory allocator")
+
+        elif enable_jemalloc:
+            find_je = self.add_lib_preload(lib_type="jemalloc")
+            if not find_je:
+                msg = f'{self.msg_lib_notfound} you can use "conda install -c conda-forge jemalloc" to install {{0}}'
+                logger.warning(msg.format("Jemalloc", "jemalloc"))  # noqa: G001
+            else:
+                logger.info("Use JeMalloc memory allocator")
+                self.set_env(
+                    "MALLOC_CONF",
+                    "oversize_threshold:1,background_thread:true,metadata_thp:auto",
+                )
+
+        elif use_default_allocator:
+            pass
+
+        else:
+            find_tc = self.add_lib_preload(lib_type="tcmalloc")
+            if find_tc:
+                logger.info("Use TCMalloc memory allocator")
+                return
+            find_je = self.add_lib_preload(lib_type="jemalloc")
+            if find_je:
+                logger.info("Use JeMalloc memory allocator")
+                return
+            logger.warning(
+                """Neither TCMalloc nor JeMalloc is found in $CONDA_PREFIX/lib or $VIRTUAL_ENV/lib
+                            or /.local/lib/ or /usr/local/lib/ or /usr/local/lib64/ or /usr/lib or /usr/lib64 or
+                           %s/.local/lib/ so the LD_PRELOAD environment variable will not be set.
+                           This may drop the performance""",
+                expanduser("~"),
+            )
+
+    def log_env_var(self, env_var_name=""):
+        if env_var_name in os.environ:
+            logger.info("%s=%s", env_var_name, os.environ[env_var_name])
+
+    def set_env(self, env_name, env_value):
+        if not env_value:
+            logger.warning("%s is None", env_name)
+        if env_name not in os.environ:
+            os.environ[env_name] = env_value
+        elif os.environ[env_name] != env_value:
+            logger.warning(
+                "Overriding value with the one set in environment variable: %s. \
+Value applied: %s. Value ignored: %s",
+                env_name,
+                os.environ[env_name],
+                env_value,
+            )
+        self.log_env_var(env_name)
+
+    # set_kmp_affinity is used to control whether to set KMP_AFFINITY or not.
+    # In the scenario that uses all cores on all nodes, including logical cores, setting KMP_AFFINITY disables logical cores.
+    # In this case, KMP_AFFINITY should not be set.
+    def set_multi_thread_and_allocator(
+        self,
+        ncores_per_instance,
+        disable_iomp=False,
+        set_kmp_affinity=True,
+        enable_tcmalloc=True,
+        enable_jemalloc=False,
+        use_default_allocator=False,
+    ):
+        """
+        Set multi-thread configuration and enable Intel openMP and TCMalloc/JeMalloc.
+
+        By default, GNU OpenMP and PTMalloc are used in PyTorch, but Intel OpenMP and TCMalloc/JeMalloc are better alternatives
+        for getting a performance benefit.
+        """
+        self.set_memory_allocator(
+            enable_tcmalloc, enable_jemalloc, use_default_allocator
+        )
+        self.set_env("OMP_NUM_THREADS", str(ncores_per_instance))
+        if not disable_iomp:
+            find_iomp = self.add_lib_preload(lib_type="iomp5")
+            if not find_iomp:
+                msg = f'{self.msg_lib_notfound} you can use "conda install mkl" to install {{0}}'
+                logger.warning(msg.format("iomp", "iomp5"))  # noqa: G001
+            else:
+                logger.info("Using Intel OpenMP")
+                if set_kmp_affinity:
+                    self.set_env("KMP_AFFINITY", "granularity=fine,compact,1,0")
+                self.set_env("KMP_BLOCKTIME", "1")
+        self.log_env_var("LD_PRELOAD")
+
+    r"""
+     Launcher for single instance and multi-instance
+     """
+
+    def launch(self, args):
+        cores = []
+        set_kmp_affinity = True
+        enable_taskset = False
+        if args.core_list:  # user specify what cores will be used by params
+            cores = [int(x) for x in args.core_list.split(",")]
+            if args.ncores_per_instance == -1:
+                raise RuntimeError(
+                    'please specify the "--ncores-per-instance" if you have passed the --core-list params'
+                )
+            elif (
+                args.ninstances > 1
+                and args.ncores_per_instance * args.ninstances < len(cores)
+            ):
+                logger.warning(
+                    "only first %s cores will be used, \
+but you specify %s cores in core_list",
+                    args.ncores_per_instance * args.ninstances,
+                    len(cores),
+                )
+            else:
+                args.ninstances = len(cores) // args.ncores_per_instance
+
+        else:
+            if args.use_logical_core:
+                if args.node_id != -1:
+                    cores = self.cpuinfo.get_node_logical_cores(args.node_id)
+                else:
+                    cores = self.cpuinfo.get_all_logical_cores()
+                    # When using all cores on all nodes, including logical cores,
+                    # setting KMP_AFFINITY disables logical cores. Thus, KMP_AFFINITY should not be set.
+                    set_kmp_affinity = False
+            else:
+                if args.node_id != -1:
+                    cores = self.cpuinfo.get_node_physical_cores(args.node_id)
+                else:
+                    cores = self.cpuinfo.get_all_physical_cores()
+            if (
+                not args.multi_instance
+                and args.ninstances == -1
+                and args.ncores_per_instance == -1
+            ):
+                args.ninstances = 1
+                args.ncores_per_instance = len(cores)
+            elif (
+                args.multi_instance
+                and args.ninstances == -1
+                and args.ncores_per_instance == -1
+            ):
+                args.throughput_mode = True
+            elif args.ncores_per_instance == -1 and args.ninstances != -1:
+                if args.ninstances > len(cores):
+                    raise RuntimeError(
+                        f"there are {len(cores)} total cores but you specify {args.ninstances} ninstances; \
+please make sure ninstances <= total_cores"
+                    )
+                else:
+                    args.ncores_per_instance = len(cores) // args.ninstances
+            elif args.ncores_per_instance != -1 and args.ninstances == -1:
+                if not args.skip_cross_node_cores:
+                    args.ninstances = len(cores) // args.ncores_per_instance
+                else:
+                    ncore_per_node = len(self.cpuinfo.node_physical_cores[0])
+                    num_leftover_cores = ncore_per_node % args.ncores_per_instance
+                    if args.ncores_per_instance > ncore_per_node:
+                        # too many ncores_per_instance to skip cross-node cores
+                        logger.warning(
+                            "there are %s core(s) per socket, but you specify %s ncores_per_instance and \
+skip_cross_node_cores. Please make sure --ncores-per-instance < core(s) per \
+socket",
+                            ncore_per_node,
+                            args.ncores_per_instance,
+                        )
+                        sys.exit(-1)
+                    elif num_leftover_cores == 0:
+                        # aren't any cross-node cores
+                        logger.info(
+                            "--skip-cross-node-cores is set, but there are no cross-node cores."
+                        )
+                        args.ninstances = len(cores) // args.ncores_per_instance
+                    else:
+                        # skip cross-node cores
+                        if args.ninstances != -1:
+                            logger.warning(
+                                "--skip-cross-node-cores is exclusive to --ninstances. --ninstances \
+won't take effect even if it is set explicitly."
+                            )
+
+                        i = 1
+                        leftover_cores = set()
+                        while ncore_per_node * i <= len(cores):
+                            leftover_cores.update(
+                                cores[
+                                    ncore_per_node * i
+                                    - num_leftover_cores : ncore_per_node * i
+                                ]
+                            )
+                            i += 1
+                        cores = list(set(cores) - leftover_cores)
+                        assert len(cores) % args.ncores_per_instance == 0
+                        args.ninstances = len(cores) // args.ncores_per_instance
+            else:
+                if args.ninstances * args.ncores_per_instance > len(cores):
+                    raise RuntimeError(
+                        "Please make sure ninstances * ncores_per_instance <= total_cores"
+                    )
+            if args.latency_mode:
+                logger.warning(
+                    "--latency-mode is exclusive to --ninstances, --ncores-per-instance, --node-id and \
+--use-logical-core. They won't take effect even if they are set explicitly."
+                )
+                args.ncores_per_instance = 4
+                cores = self.cpuinfo.get_all_physical_cores()
+                args.ninstances = len(cores) // args.ncores_per_instance
+
+            if args.throughput_mode:
+                logger.warning(
+                    "--throughput-mode is exclusive to --ninstances, --ncores-per-instance, --node-id and \
+--use-logical-core. They won't take effect even if they are set explicitly."
+                )
+                args.ninstances = self.cpuinfo.node_nums
+                cores = self.cpuinfo.get_all_physical_cores()
+                args.ncores_per_instance = len(cores) // args.ninstances
+
+        if args.ninstances > 1 and args.rank != -1:
+            logger.info(
+                "assigning %s cores for instance %s",
+                args.ncores_per_instance,
+                args.rank,
+            )
+
+        if not args.disable_numactl:
+            numactl_available = self.is_numactl_available()
+            if not numactl_available:
+                if not args.disable_taskset:
+                    logger.warning(
+                        "Core binding with numactl is not available. Disabling numactl and using taskset instead. \
+                    This may affect performance in multi-socket systems; please use numactl if memory binding is needed."
+                    )
+                    args.disable_numactl = True
+                    enable_taskset = True
+                else:
+                    logger.warning(
+                        "Core binding with numactl is not available, and --disable_taskset is set. \
+                    Please unset --disable_taskset to use taskset instead of numactl."
+                    )
+                    sys.exit(-1)
+
+        if not args.disable_taskset:
+            enable_taskset = True
+
+        self.set_multi_thread_and_allocator(
+            args.ncores_per_instance,
+            args.disable_iomp,
+            set_kmp_affinity,
+            args.enable_tcmalloc,
+            args.enable_jemalloc,
+            args.use_default_allocator,
+        )
+        entrypoint = ""
+        launch_args = {}
+        launch_envs: dict[int, dict] = {}
+        launch_tee = {}
+        # check whether this was launched from torchrun with --nproc-per-node
+        local_size = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
+        local_rank = int(os.environ.get("LOCAL_RANK", 0))
+        for i in range(args.ninstances):
+            cmd = []
+            cur_process_cores = ""
+            if not args.disable_numactl or enable_taskset:
+                if not args.disable_numactl:
+                    cmd = ["numactl"]
+                elif enable_taskset:
+                    cmd = ["taskset"]
+                cores = sorted(cores)
+                if (
+                    args.rank == -1
+                ):  # sequentially assign ncores_per_instance to ninstances
+                    core_list = cores[
+                        i
+                        * args.ncores_per_instance : (i + 1)
+                        * args.ncores_per_instance
+                    ]
+                else:  # assign ncores_per_instance from rank
+                    core_list = cores[
+                        args.rank
+                        * args.ncores_per_instance : (args.rank + 1)
+                        * args.ncores_per_instance
+                    ]
+
+                core_ranges: list[dict] = []
+                if local_size > 1:
+                    total_num_cores = len(core_list)
+                    cores_per_rank = total_num_cores // local_size
+                    assert (
+                        cores_per_rank >= 1
+                    ), "At least one core needs to be assigned to each rank"
+                    core_list = core_list[
+                        cores_per_rank * local_rank : cores_per_rank * (local_rank + 1)
+                    ]
+                for core in core_list:
+                    if len(core_ranges) == 0:
+                        range_elem = {"start": core, "end": core}
+                        core_ranges.append(range_elem)
+                    else:
+                        if core - core_ranges[-1]["end"] == 1:
+                            core_ranges[-1]["end"] = core
+                        else:
+                            range_elem = {"start": core, "end": core}
+                            core_ranges.append(range_elem)
+                for r in core_ranges:
+                    cur_process_cores = f"{cur_process_cores}{r['start']}-{r['end']},"
+                cur_process_cores = cur_process_cores[:-1]
+                if not args.disable_numactl:
+                    numa_params = f"-C {cur_process_cores} "
+                    numa_ids = ",".join(
+                        [
+                            str(numa_id)
+                            for numa_id in self.cpuinfo.numa_aware_check(core_list)
+                        ]
+                    )
+                    numa_params += f"-m {numa_ids}"
+                    cmd.extend(numa_params.split())
+                elif enable_taskset:
+                    taskset_params = f"-c {cur_process_cores} "
+                    cmd.extend(taskset_params.split())
+            with_python = not args.no_python
+            if with_python:
+                cmd.append(sys.executable)
+                cmd.append("-u")
+            if args.module:
+                cmd.append("-m")
+            cmd.append(args.program)
+            cmd.extend(args.program_args)
+            cmd_s = " ".join(cmd)
+            logger.info(cmd_s)
+            if entrypoint == "":
+                entrypoint = cmd[0]
+            del cmd[0]
+            launch_args[i] = tuple(cmd)
+            launch_envs[i] = {}
+            launch_tee[i] = Std.ALL
+
+            if args.rank != -1:  # launches single instance, rank, only
+                break
+
+        ctx = start_processes(
+            name=args.log_file_prefix,
+            entrypoint=entrypoint,
+            args=launch_args,
+            envs=launch_envs,
+            logs_specs=_DefaultLogsSpecs(log_dir=args.log_path, tee=launch_tee),
+        )
+        ctx.wait()
+
+
+def _add_memory_allocator_params(parser):
+    group = parser.add_argument_group("Memory Allocator Parameters")
+    # allocator control
+    group.add_argument(
+        "--enable-tcmalloc",
+        "--enable_tcmalloc",
+        action="store_true",
+        default=False,
+        help="Enable tcmalloc allocator",
+    )
+    group.add_argument(
+        "--enable-jemalloc",
+        "--enable_jemalloc",
+        action="store_true",
+        default=False,
+        help="Enable jemalloc allocator",
+    )
+    group.add_argument(
+        "--use-default-allocator",
+        "--use_default_allocator",
+        action="store_true",
+        default=False,
+        help="Use default memory allocator",
+    )
+
+
+def _add_multi_instance_params(parser):
+    group = parser.add_argument_group("Multi-instance Parameters")
+    # multi-instance control
+    group.add_argument(
+        "--ncores-per-instance",
+        "--ncores_per_instance",
+        metavar="\b",
+        default=-1,
+        type=int,
+        help="Cores per instance",
+    )
+    group.add_argument(
+        "--ninstances",
+        metavar="\b",
+        default=-1,
+        type=int,
+        help="For multi-instance, you should give the cores number you used for per instance.",
+    )
+    group.add_argument(
+        "--skip-cross-node-cores",
+        "--skip_cross_node_cores",
+        action="store_true",
+        default=False,
+        help="If specified --ncores-per-instance, skips cross-node cores.",
+    )
+    group.add_argument(
+        "--rank",
+        metavar="\b",
+        default="-1",
+        type=int,
+        help="Specify instance index to assign ncores_per_instance for rank; \
+otherwise ncores_per_instance will be assigned sequentially to ninstances. Please refer to \
+https://github.com/intel/intel-extension-for-pytorch/blob/master/docs/tutorials/performance_tuning/launch_script.md",
+    )
+    group.add_argument(
+        "--latency-mode",
+        "--latency_mode",
+        action="store_true",
+        default=False,
+        help="By default 4 core per instance and use all physical cores",
+    )
+    group.add_argument(
+        "--throughput-mode",
+        "--throughput_mode",
+        action="store_true",
+        default=False,
+        help="By default one instance per node and use all physical cores",
+    )
+    group.add_argument(
+        "--node-id",
+        "--node_id",
+        metavar="\b",
+        default=-1,
+        type=int,
+        help="node id for multi-instance, by default all nodes will be used",
+    )
+    group.add_argument(
+        "--use-logical-core",
+        "--use_logical_core",
+        action="store_true",
+        default=False,
+        help="Whether only use physical cores",
+    )
+    group.add_argument(
+        "--disable-numactl",
+        "--disable_numactl",
+        action="store_true",
+        default=False,
+        help="Disable numactl",
+    )
+    group.add_argument(
+        "--disable-taskset",
+        "--disable_taskset",
+        action="store_true",
+        default=False,
+        help="Disable taskset",
+    )
+    group.add_argument(
+        "--core-list",
+        "--core_list",
+        metavar="\b",
+        default=None,
+        type=str,
+        help='Specify the core list as "core_id, core_id, ...", otherwise all the cores will be used.',
+    )
+    group.add_argument(
+        "--log-path",
+        "--log_path",
+        metavar="\b",
+        default="",
+        type=str,
+        help="The log file directory. Default path is "
+        ", which means disable logging to files.",
+    )
+    group.add_argument(
+        "--log-file-prefix",
+        "--log_file_prefix",
+        metavar="\b",
+        default="run",
+        type=str,
+        help="log file prefix",
+    )
+
+
+def _add_kmp_iomp_params(parser):
+    group = parser.add_argument_group("IOMP Parameters")
+    group.add_argument(
+        "--disable-iomp",
+        "--disable_iomp",
+        action="store_true",
+        default=False,
+        help="By default, we use Intel OpenMP and libiomp5.so will be add to LD_PRELOAD",
+    )
+
+
+def create_args(parser=None):
+    """
+    Parse the command line options.
+
+    @retval ArgumentParser
+    """
+    parser.add_argument(
+        "--multi-instance",
+        "--multi_instance",
+        action="store_true",
+        default=False,
+        help="Enable multi-instance, by default one instance per node",
+    )
+
+    parser.add_argument(
+        "-m",
+        "--module",
+        default=False,
+        action="store_true",
+        help="Changes each process to interpret the launch script "
+        "as a python module, executing with the same behavior as"
+        '"python -m".',
+    )
+
+    parser.add_argument(
+        "--no-python",
+        "--no_python",
+        default=False,
+        action="store_true",
+        help='Do not prepend the --program script with "python" - just exec '
+        "it directly. Useful when the script is not a Python script.",
+    )
+
+    _add_memory_allocator_params(parser)
+    _add_kmp_iomp_params(parser)
+
+    _add_multi_instance_params(parser)
+    # positional
+    parser.add_argument(
+        "program",
+        type=str,
+        help="The full path to the program/script to be launched. "
+        "followed by all the arguments for the script",
+    )
+
+    # rest from the training program
+    parser.add_argument("program_args", nargs=REMAINDER)
+
+
+def main(args):
+    env_before = set(os.environ.keys())
+    if platform.system() in ["Windows", "Darwin"]:
+        raise RuntimeError(f"{platform.system()} is not supported!!!")
+
+    if args.log_path:
+        os.makedirs(args.log_path, exist_ok=True)
+    else:
+        args.log_path = os.devnull
+
+    if args.latency_mode and args.throughput_mode:
+        raise RuntimeError(
+            "Either args.latency_mode or args.throughput_mode should be set"
+        )
+
+    if not args.no_python and not args.program.endswith(".py"):
+        raise RuntimeError(
+            'For non Python script, you should use "--no-python" parameter.'
+        )
+
+    # Verify LD_PRELOAD
+    if "LD_PRELOAD" in os.environ:
+        lst_valid = []
+        tmp_ldpreload = os.environ["LD_PRELOAD"]
+        for item in tmp_ldpreload.split(":"):
+            matches = glob.glob(item)
+            if len(matches) > 0:
+                lst_valid.append(item)
+            else:
+                logger.warning("%s doesn't exist. Removing it from LD_PRELOAD.", item)
+        if len(lst_valid) > 0:
+            os.environ["LD_PRELOAD"] = ":".join(lst_valid)
+        else:
+            os.environ["LD_PRELOAD"] = ""
+
+    launcher = _Launcher()
+    launcher.launch(args)
+    for x in sorted(set(os.environ.keys()) - env_before):
+        logger.debug("%s=%s", x, os.environ[x])
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser(
+        description="This is a script for launching PyTorch inference on Intel(R) Xeon(R) Scalable "
+        "Processors with optimal configurations. Single instance inference, "
+        "multi-instance inference are enable. To get the peak performance on Intel(R) "
+        "Xeon(R) Scalable Processors, the script optimizes the configuration "
+        "of thread and memory management. For thread management, the script configures thread "
+        "affinity and the preload of Intel OMP library. For memory management, it configures "
+        "NUMA binding and preload optimized memory allocation library (e.g. tcmalloc, jemalloc) "
+        "\n################################# Basic usage ############################# \n"
+        "\n 1. single instance\n"
+        "\n   >>> python -m torch.backends.xeon.run_cpu python_script args \n"
+        "\n2. multi-instance \n"
+        "\n   >>> python -m torch.backends.xeon.run_cpu --ninstances xxx "
+        "--ncores-per-instance xx python_script args\n"
+        "\n############################################################################# \n",
+        formatter_class=RawTextHelpFormatter,
+    )
+    create_args(parser)
+    args = parser.parse_args()
+    main(args)
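Before building the numactl/taskset command, the launcher above compacts the sorted per-instance core list into contiguous start-end ranges. A minimal standalone sketch of that compaction logic (the helper name is illustrative, not part of run_cpu.py):

    # Minimal sketch of the core-range compaction fed to `numactl -C` / `taskset -c`.
    # `compact_core_ranges` is an illustrative name, not part of run_cpu.py.
    def compact_core_ranges(core_list):
        ranges = []
        for core in sorted(core_list):
            if ranges and core - ranges[-1][1] == 1:
                ranges[-1][1] = core          # extend the current contiguous run
            else:
                ranges.append([core, core])   # start a new run
        return ",".join(f"{start}-{end}" for start, end in ranges)

    print(compact_core_ranges([0, 1, 2, 3, 8, 9, 12]))  # -> "0-3,8-9,12-12"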
diff --git a/.venv/lib/python3.12/site-packages/torch/backends/xnnpack/__init__.py b/.venv/lib/python3.12/site-packages/torch/backends/xnnpack/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..31e69876927d01878a9d1cb836d72fd14adf95e9
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/backends/xnnpack/__init__.py
@@ -0,0 +1,29 @@
+# mypy: allow-untyped-defs
+import sys
+import types
+
+import torch
+
+
+class _XNNPACKEnabled:
+    def __get__(self, obj, objtype):
+        return torch._C._is_xnnpack_enabled()
+
+    def __set__(self, obj, val):
+        raise RuntimeError("Assignment not supported")
+
+
+class XNNPACKEngine(types.ModuleType):
+    def __init__(self, m, name):
+        super().__init__(name)
+        self.m = m
+
+    def __getattr__(self, attr):
+        return self.m.__getattribute__(attr)
+
+    enabled = _XNNPACKEnabled()
+
+
+# This is the sys.modules replacement trick, see
+# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
+sys.modules[__name__] = XNNPACKEngine(sys.modules[__name__], __name__)
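Because the module object in sys.modules is replaced by an XNNPACKEngine instance, torch.backends.xnnpack.enabled reads through the _XNNPACKEnabled descriptor and rejects assignment. A short usage sketch (assuming a working torch build):

    import torch.backends.xnnpack as xnnpack

    print(xnnpack.enabled)       # delegates to torch._C._is_xnnpack_enabled()
    try:
        xnnpack.enabled = False  # the descriptor's __set__ raises
    except RuntimeError as err:
        print(err)               # "Assignment not supported"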
diff --git a/.venv/lib/python3.12/site-packages/torch/cuda/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/cuda/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..96534e00ac598b70bab73668051812e09da7ee26
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/cuda/__pycache__/__init__.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/cuda/__pycache__/_memory_viz.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/cuda/__pycache__/_memory_viz.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6de056e2cc5e2c23f21596f26e9e247715225aac
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/cuda/__pycache__/_memory_viz.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/cuda/__pycache__/_pin_memory_utils.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/cuda/__pycache__/_pin_memory_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e21d9145c97b47ecfa956daa2cd800f542d81306
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/cuda/__pycache__/_pin_memory_utils.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/cuda/__pycache__/_utils.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/cuda/__pycache__/_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7f3cd5b75f4875bf1d5fc8d311d4e958b39f55b9
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/cuda/__pycache__/_utils.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/cuda/__pycache__/gds.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/cuda/__pycache__/gds.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a7ec04d62dccb969e18e26001b1c3d12e50aaaf2
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/cuda/__pycache__/gds.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/cuda/__pycache__/graphs.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/cuda/__pycache__/graphs.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..37e6d0c1cb1ce5b35e85ee99bb8b3602bb9e1203
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/cuda/__pycache__/graphs.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/cuda/__pycache__/jiterator.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/cuda/__pycache__/jiterator.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5cd2c435225ddf7e2ab741f8c20dd713d9fd1e15
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/cuda/__pycache__/jiterator.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/cuda/__pycache__/memory.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/cuda/__pycache__/memory.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0206cfb9ec653e0e96cb36da3591efa91267c878
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/cuda/__pycache__/memory.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/cuda/__pycache__/nccl.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/cuda/__pycache__/nccl.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6f1073a6b1071315b61e7594e2093db1a3c7de9c
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/cuda/__pycache__/nccl.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/cuda/__pycache__/nvtx.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/cuda/__pycache__/nvtx.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d63f176ce4428afbb20b15dddca8e4d9a258207b
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/cuda/__pycache__/nvtx.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/cuda/__pycache__/profiler.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/cuda/__pycache__/profiler.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ff48eb200d634cee7e6ad1179a534130db447134
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/cuda/__pycache__/profiler.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/cuda/__pycache__/random.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/cuda/__pycache__/random.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e1ebcfdb1f44f9d0c46287d0d83102ebb6542d4f
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/cuda/__pycache__/random.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/cuda/__pycache__/sparse.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/cuda/__pycache__/sparse.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1eead296c3ee7748ce763f6781189b35b132ba74
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/cuda/__pycache__/sparse.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/cuda/__pycache__/streams.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/cuda/__pycache__/streams.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..00e30d7fa6135da327d57256c7bdbe733898563b
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/cuda/__pycache__/streams.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/cuda/__pycache__/tunable.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/cuda/__pycache__/tunable.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..556e3a9d7766e5bf4934bc6e4ddcaaef20b90908
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/cuda/__pycache__/tunable.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/cuda/amp/__init__.py b/.venv/lib/python3.12/site-packages/torch/cuda/amp/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..74520496372f549356a26c2c70b637b9e3ea4d4d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/cuda/amp/__init__.py
@@ -0,0 +1,12 @@
+from .autocast_mode import autocast, custom_bwd, custom_fwd
+from .common import amp_definitely_not_available
+from .grad_scaler import GradScaler
+
+
+__all__ = [
+    "amp_definitely_not_available",
+    "autocast",
+    "custom_bwd",
+    "custom_fwd",
+    "GradScaler",
+]
diff --git a/.venv/lib/python3.12/site-packages/torch/cuda/amp/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/cuda/amp/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..135cce5bc59a2ea43f99a513420c429ad9898067
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/cuda/amp/__pycache__/__init__.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/cuda/amp/__pycache__/autocast_mode.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/cuda/amp/__pycache__/autocast_mode.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9296f8625f9d333ce96abda63bbd645880e9f27a
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/cuda/amp/__pycache__/autocast_mode.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/cuda/amp/__pycache__/common.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/cuda/amp/__pycache__/common.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bee842832808994147aad72ae1d4c55f2d6ef35b
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/cuda/amp/__pycache__/common.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/cuda/amp/__pycache__/grad_scaler.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/cuda/amp/__pycache__/grad_scaler.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..33062bfbb6012be34c0ce03e42d382771522400d
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/cuda/amp/__pycache__/grad_scaler.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/cuda/amp/autocast_mode.py b/.venv/lib/python3.12/site-packages/torch/cuda/amp/autocast_mode.py
new file mode 100644
index 0000000000000000000000000000000000000000..d52ff7cf672bbdaaf4ac31a06a21cac8aabcfe0f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/cuda/amp/autocast_mode.py
@@ -0,0 +1,90 @@
+# mypy: allow-untyped-defs
+import functools
+from typing import Any
+from typing_extensions import deprecated
+
+import torch
+
+
+__all__ = ["autocast", "custom_fwd", "custom_bwd"]
+
+
+class autocast(torch.amp.autocast_mode.autocast):
+    r"""See :class:`torch.autocast`.
+
+    ``torch.cuda.amp.autocast(args...)`` is deprecated. Please use ``torch.amp.autocast("cuda", args...)`` instead.
+    """
+
+    @deprecated(
+        "`torch.cuda.amp.autocast(args...)` is deprecated. "
+        "Please use `torch.amp.autocast('cuda', args...)` instead.",
+        category=FutureWarning,
+    )
+    def __init__(
+        self,
+        enabled: bool = True,
+        dtype: torch.dtype = torch.float16,
+        cache_enabled: bool = True,
+    ):
+        if torch._jit_internal.is_scripting():
+            self._enabled = enabled
+            self.device = "cuda"
+            self.fast_dtype = dtype
+            return
+        super().__init__(
+            "cuda", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled
+        )
+
+    def __enter__(self):
+        if torch._jit_internal.is_scripting():
+            return self
+        return super().__enter__()
+
+    # TODO: discuss a unified TorchScript-friendly API for autocast
+    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):  # type: ignore[override]
+        if torch._jit_internal.is_scripting():
+            return
+        return super().__exit__(exc_type, exc_val, exc_tb)
+
+    def __call__(self, func):
+        if torch._jit_internal.is_scripting():
+            return func
+        return super().__call__(func)
+
+
+# Preserved only for BC reasons
+@deprecated(
+    "`torch.cuda.amp.autocast_mode._cast(value, dtype)` is deprecated. "
+    "Please use `torch.amp.autocast_mode._cast(value, 'cuda', dtype)` instead.",
+    category=FutureWarning,
+)
+def _cast(value, dtype):
+    return torch.amp.autocast_mode._cast(value, "cuda", dtype)
+
+
+@deprecated(
+    "`torch.cuda.amp.custom_fwd(args...)` is deprecated. "
+    "Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.",
+    category=FutureWarning,
+)
+def custom_fwd(fwd=None, *, cast_inputs=None):
+    """
+    ``torch.cuda.amp.custom_fwd(args...)`` is deprecated. Please use
+    ``torch.amp.custom_fwd(args..., device_type='cuda')`` instead.
+    """
+    return functools.partial(torch.amp.custom_fwd, device_type="cuda")(
+        fwd=fwd, cast_inputs=cast_inputs
+    )
+
+
+@deprecated(
+    "`torch.cuda.amp.custom_bwd(args...)` is deprecated. "
+    "Please use `torch.amp.custom_bwd(args..., device_type='cuda')` instead.",
+    category=FutureWarning,
+)
+def custom_bwd(bwd):
+    """
+    ``torch.cuda.amp.custom_bwd(args...)`` is deprecated. Please use
+    ``torch.amp.custom_bwd(args..., device_type='cuda')`` instead.
+    """
+    return functools.partial(torch.amp.custom_bwd, device_type="cuda")(bwd)
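All of these wrappers forward to the device-generic torch.amp API, as the FutureWarning messages indicate. A migration sketch (assuming a CUDA-capable install):

    import torch

    model = torch.nn.Linear(8, 8).cuda()
    x = torch.randn(4, 8, device="cuda")

    # Deprecated spelling (emits a FutureWarning):
    #     with torch.cuda.amp.autocast():
    #         y = model(x)

    # Preferred, device-generic spelling:
    with torch.amp.autocast("cuda", dtype=torch.float16):
        y = model(x)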
diff --git a/.venv/lib/python3.12/site-packages/torch/cuda/amp/common.py b/.venv/lib/python3.12/site-packages/torch/cuda/amp/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..915a9b4f4a9ca6c147abefd7c8ab1891ee5a8179
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/cuda/amp/common.py
@@ -0,0 +1,11 @@
+# mypy: allow-untyped-defs
+from importlib.util import find_spec
+
+import torch
+
+
+__all__ = ["amp_definitely_not_available"]
+
+
+def amp_definitely_not_available():
+    return not (torch.cuda.is_available() or find_spec("torch_xla"))
diff --git a/.venv/lib/python3.12/site-packages/torch/cuda/amp/grad_scaler.py b/.venv/lib/python3.12/site-packages/torch/cuda/amp/grad_scaler.py
new file mode 100644
index 0000000000000000000000000000000000000000..62e2020073c8ed99f7295edd1aaea4c54d815f63
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/cuda/amp/grad_scaler.py
@@ -0,0 +1,38 @@
+from typing_extensions import deprecated
+
+import torch
+
+# We need to keep this unused import for BC reasons
+from torch.amp.grad_scaler import OptState  # noqa: F401
+
+
+__all__ = ["GradScaler"]
+
+
+class GradScaler(torch.amp.GradScaler):
+    r"""
+    See :class:`torch.amp.GradScaler`.
+    ``torch.cuda.amp.GradScaler(args...)`` is deprecated. Please use ``torch.amp.GradScaler("cuda", args...)`` instead.
+    """
+
+    @deprecated(
+        "`torch.cuda.amp.GradScaler(args...)` is deprecated. "
+        "Please use `torch.amp.GradScaler('cuda', args...)` instead.",
+        category=FutureWarning,
+    )
+    def __init__(
+        self,
+        init_scale: float = 2.0**16,
+        growth_factor: float = 2.0,
+        backoff_factor: float = 0.5,
+        growth_interval: int = 2000,
+        enabled: bool = True,
+    ) -> None:
+        super().__init__(
+            "cuda",
+            init_scale=init_scale,
+            growth_factor=growth_factor,
+            backoff_factor=backoff_factor,
+            growth_interval=growth_interval,
+            enabled=enabled,
+        )
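The same migration applies to GradScaler: construct it through torch.amp with an explicit device string. A brief sketch of a scaled training step (assuming a CUDA-capable install):

    import torch

    model = torch.nn.Linear(8, 8).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # Deprecated:   scaler = torch.cuda.amp.GradScaler()
    scaler = torch.amp.GradScaler("cuda")

    with torch.amp.autocast("cuda"):
        loss = model(torch.randn(4, 8, device="cuda")).sum()

    scaler.scale(loss).backward()   # scale the loss before backward
    scaler.step(optimizer)          # unscale gradients, then optimizer.step()
    scaler.update()                 # adjust the scale for the next iteration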
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..63cd95b7c9797315b8b877a81ab3cb1863d8acd6
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/__init__.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/bernoulli.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/bernoulli.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4d7fa03c25fbf6b084ba125cb7da45724b89ddfd
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/bernoulli.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/beta.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/beta.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7402ac1520492dd6ede9cd0dc79a67ff7a8a13d3
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/beta.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/binomial.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/binomial.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5345daea2039dd4b9ac00ad442e3cd56483cc36a
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/binomial.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/categorical.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/categorical.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9750d2b5df408233bcccbf1a762397018dd999cb
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/categorical.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/cauchy.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/cauchy.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6272a2586fa244643f5f5377d042bb01700eeab0
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/cauchy.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/chi2.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/chi2.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..970b51700b57963c9064c64fa03e7435397bd156
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/chi2.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/constraint_registry.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/constraint_registry.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4d7c90ccd92f17bcf92d4c19426b4e5891699750
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/constraint_registry.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/constraints.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/constraints.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e59f0d0b88906fe634c3f5c662811aa31063cb5e
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/constraints.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/continuous_bernoulli.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/continuous_bernoulli.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1833fab1774cc0bbe27be6c7fa4fa33f51c669f0
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/continuous_bernoulli.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/dirichlet.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/dirichlet.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8b9385ace1360714bb65b010d07c3bdfd750cd39
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/dirichlet.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/distribution.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/distribution.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..13e20b485fa7b2bf5d1882aaa60556d772879e34
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/distribution.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/exp_family.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/exp_family.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..62843675b6133677acb3c0d79682d173000a9179
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/exp_family.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/exponential.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/exponential.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..876d77cd7d0adbeafcda3aada4203fa34e198460
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/exponential.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/fishersnedecor.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/fishersnedecor.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5cbe30c3c46803b56c0faf57cf3a8f6c8fbc0ae5
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/fishersnedecor.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/gamma.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/gamma.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7397692b3cb9cd66c4b4e29e5553c58357de760e
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/gamma.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/generalized_pareto.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/generalized_pareto.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e02f5e97586eef6cee286868072cb5c9e7daa6a1
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/generalized_pareto.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/geometric.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/geometric.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eb6a8810c6a7aec6c435790b90feb2c7e71f4ffc
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/geometric.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/gumbel.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/gumbel.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..054e8c0b427f0877380a890580434c510067930c
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/gumbel.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/half_cauchy.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/half_cauchy.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f235d1209b43919adcccac9d66f83e0e5d581fee
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/half_cauchy.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/half_normal.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/half_normal.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f4be9017cdfaed8779b41558fa8ff32ab4b36dfa
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/half_normal.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/independent.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/independent.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..caa455b7a31abece59c401371a76f285a371ceda
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/independent.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/inverse_gamma.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/inverse_gamma.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a419c467882c28f0e254a11c65e18fff307390e4
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/inverse_gamma.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/kl.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/kl.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..68071b64814f838f88cfa030b9fdb1fe73aed6b3
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/kl.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/kumaraswamy.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/kumaraswamy.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5358f3954c42c002423f95d11220fe86ce6e8161
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/kumaraswamy.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/laplace.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/laplace.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..59bb784b5bf97646a022f0df43f5a758806e98b5
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/laplace.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/lkj_cholesky.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/lkj_cholesky.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b4b6c54d6617ad8241c6054f5a9fc77909eca2e5
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/lkj_cholesky.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/log_normal.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/log_normal.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c615c2962a19a80f501a593e95f7104e40f8f387
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/log_normal.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/logistic_normal.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/logistic_normal.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f22210ecd873bbb75ee492bc23777b8b75ca6138
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/logistic_normal.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/lowrank_multivariate_normal.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/lowrank_multivariate_normal.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f2524cfac39568bfcccf2f7e020e1f93582f2ead
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/lowrank_multivariate_normal.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/mixture_same_family.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/mixture_same_family.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..42413b83f56c30a75257534630a8f4a81c348c4b
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/mixture_same_family.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/multinomial.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/multinomial.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..897d4f4489e7d28c868aefac31e4c0b50a547024
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/multinomial.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/multivariate_normal.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/multivariate_normal.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d23079ef8a2b06c9d770acb54ef8479c2d6c07d1
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/multivariate_normal.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/negative_binomial.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/negative_binomial.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ca0f884525a29383d6450aa5d739512a42870561
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/negative_binomial.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/normal.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/normal.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a8adafe044f0ed3c33ca62aa48d41e5f561aa27b
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/normal.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/one_hot_categorical.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/one_hot_categorical.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..487c2633398603ab178aee0a1c3c53da05071a76
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/one_hot_categorical.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/pareto.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/pareto.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..50fa1b6b35363cb40c9eaa9889921bad275e1f9f
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/pareto.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/poisson.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/poisson.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3973673e07f941a95382615823960c5bceed3345
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/poisson.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/relaxed_bernoulli.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/relaxed_bernoulli.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fca7d20d2bbed93734c68f4e70d703c83edf1728
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/relaxed_bernoulli.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/relaxed_categorical.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/relaxed_categorical.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f38757638c6e2c68a93790b2063305ea762e2db8
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/relaxed_categorical.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/studentT.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/studentT.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9cf575e679781cbed415715eda8ed7d008da0428
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/studentT.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/transformed_distribution.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/transformed_distribution.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..73b4e4ab9e52f57da29bc139db39ae6688b3b8ab
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/transformed_distribution.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/transforms.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/transforms.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..50d4499c6a8fe8341af8c785e21f9dd7aee8c26a
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/transforms.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/uniform.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/uniform.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d1b3242e1470a6e1b4e8fafff1876c53756377cd
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/uniform.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/utils.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f95a0aefdee2c48f6b0f48fc595837f7a35300f3
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/utils.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/von_mises.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/von_mises.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..abac355acd32940ba9913bb3d4b3ff3e59645827
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/von_mises.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/weibull.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/weibull.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..012bba78bbe7c0c7526bd36c70d72e88b0e92a0f
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/weibull.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/wishart.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/wishart.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..14db2d93473d9dbdc09ee4af9aad342be496fd73
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/distributions/__pycache__/wishart.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/func/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/func/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ebc95ee412ae05806ef14ce9c54ebb21a951d621
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/func/__pycache__/__init__.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ddd9c09d63ed9ba41e267e85995410d3303e32b4
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/__init__.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_async.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_async.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..176fd768da06d6c2e4db0a38347d9c7f6f65b3d1
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_async.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_await.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_await.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3df84f7f8c8ac5d40d55240af4643ccddad03112
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_await.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_builtins.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_builtins.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d973f54d64df1850aa97666ebd8cfc9217341a29
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_builtins.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_check.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_check.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c32c52ea763fc2bbf7342e2b717b9bc4d0bd6778
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_check.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_dataclass_impls.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_dataclass_impls.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fe744890430f9423b60010b3cb064f0a8cca41c7
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_dataclass_impls.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_decomposition_utils.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_decomposition_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d20f2e24b0acc5214252bf6c80d153bc5e46da19
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_decomposition_utils.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_freeze.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_freeze.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6bb4959ef5c2331ab50316d6f9fcce256dc55536
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_freeze.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_fuser.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_fuser.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..41d9fead2097cbc8d8dc60922053bb73c3e87aba
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_fuser.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_ir_utils.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_ir_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cccd0ef08215eeb141974579145cb5569c9967f2
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_ir_utils.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_monkeytype_config.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_monkeytype_config.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b3aae4d768b70c764ffef8dcb4cc10912eee0367
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_monkeytype_config.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_recursive.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_recursive.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c47edd90d1d60f8170836a661508b789b5d0ae86
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_recursive.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_script.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_script.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f0e0e2b787ff9df262f1ab9f8152bd5354928441
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_script.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_serialization.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_serialization.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9d3bd3003d42747fc3e6c1c49e7fabecc8b926a4
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_serialization.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_state.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_state.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aa9e43fecf22e7a1f8d5635a201a1c0e8f1902e0
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_state.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_trace.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_trace.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a1827a4dd2da014f9f695f92e3505c2f97ffc00c
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/_trace.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/annotations.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/annotations.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..93b8886b4a3c3531abcc0dd3b3423b3554c72282
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/annotations.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/frontend.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/frontend.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7773dcf4850845c3f28e20ba9143c89fc0e18eb7
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/jit/__pycache__/frontend.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/jit/_passes/__init__.py b/.venv/lib/python3.12/site-packages/torch/jit/_passes/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.12/site-packages/torch/jit/_passes/_property_propagation.py b/.venv/lib/python3.12/site-packages/torch/jit/_passes/_property_propagation.py
new file mode 100644
index 0000000000000000000000000000000000000000..c410b8fbb7fd329442aa867c0e39c03cd4f15199
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/jit/_passes/_property_propagation.py
@@ -0,0 +1,46 @@
+"""
+Tools to help with tensor property propagation.
+
+This is not intended to be imported directly; please use the exposed
+functionalities in `torch.jit`.
+"""
+
+from typing import Any
+
+import torch
+from torch import TensorType
+from torch._C import Graph
+
+
+def apply_input_props_using_example(graph: Graph, example_input: list[Any]) -> None:
+    """
+    Applies properties for each tensor in the graph inputs
+    using the example supplied.
+    """
+    graph_inputs = list(graph.inputs())
+    if len(graph_inputs) == 0:
+        return
+
+    # Strip self args off for methods
+    in_0 = graph_inputs[0]
+    if isinstance(in_0.type(), torch._C.ClassType) and in_0.debugName() == "self":
+        graph_inputs = graph_inputs[1:]
+
+    if not len(graph_inputs) == len(example_input):
+        raise RuntimeError(
+            "Number of inputs in graph does not match number of inputs in the example"
+        )
+
+    for i, (graph_i, example_i) in enumerate(zip(graph_inputs, example_input)):
+        if example_i is None:
+            continue  # Skip the type check
+
+        if isinstance(example_i, torch.Tensor) != isinstance(
+            graph_i.type(), TensorType
+        ):
+            raise RuntimeError(
+                f"Input {i} does not match type of example", graph_i, example_i
+            )
+
+        if isinstance(example_i, torch.Tensor):
+            graph_i.setType(TensorType.create_from_tensor(example_i))  # type: ignore[arg-type]
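Since this pass is internal (the module docstring says not to import it directly), the call below is only a hypothetical sketch of how an example input could be used to stamp tensor properties onto a scripted function's graph inputs:

    import torch
    from torch.jit._passes._property_propagation import (
        apply_input_props_using_example,
    )

    @torch.jit.script
    def double(x: torch.Tensor) -> torch.Tensor:
        return x * 2

    example = [torch.ones(2, 3)]
    apply_input_props_using_example(double.graph, example)
    print(double.graph)  # the input is now typed with the example's shape/dtype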
diff --git a/.venv/lib/python3.12/site-packages/torch/jit/mobile/__init__.py b/.venv/lib/python3.12/site-packages/torch/jit/mobile/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..593dabf2fc43ca075d604977a0d254ea02f3ba55
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/jit/mobile/__init__.py
@@ -0,0 +1,232 @@
+# mypy: allow-untyped-defs
+import os
+
+import torch
+from torch.jit._serialization import validate_map_location
+
+
+def _load_for_lite_interpreter(f, map_location=None):
+    r"""
+    Load a :class:`LiteScriptModule` saved with :func:`torch.jit._save_for_lite_interpreter`.
+
+    Args:
+        f: a file-like object (has to implement read, readline, tell, and seek),
+            or a string containing a file name
+        map_location: a string or torch.device used to dynamically remap
+            storages to an alternative set of devices.
+
+    Returns:
+        A :class:`LiteScriptModule` object.
+
+    Example:
+
+    .. testcode::
+
+        import torch
+        import io
+
+        # Load LiteScriptModule from saved file path
+        torch.jit._load_for_lite_interpreter('lite_script_module.pt')
+
+        # Load LiteScriptModule from io.BytesIO object
+        with open('lite_script_module.pt', 'rb') as f:
+            buffer = io.BytesIO(f.read())
+
+        # Load all tensors to the original device
+        torch.jit.mobile._load_for_lite_interpreter(buffer)
+    """
+    if isinstance(f, (str, os.PathLike)):
+        if not os.path.exists(f):
+            raise ValueError(f"The provided filename {f} does not exist")
+        if os.path.isdir(f):
+            raise ValueError(f"The provided filename {f} is a directory")
+
+    map_location = validate_map_location(map_location)
+
+    if isinstance(f, (str, os.PathLike)):
+        cpp_module = torch._C._load_for_lite_interpreter(os.fspath(f), map_location)
+    else:
+        cpp_module = torch._C._load_for_lite_interpreter_from_buffer(
+            f.read(), map_location
+        )
+
+    return LiteScriptModule(cpp_module)
+
+
+class LiteScriptModule:
+    def __init__(self, cpp_module):
+        self._c = cpp_module
+        super().__init__()
+
+    def __call__(self, *input):
+        return self._c.forward(input)
+
+    def find_method(self, method_name):
+        return self._c.find_method(method_name)
+
+    def forward(self, *input):
+        return self._c.forward(input)
+
+    def run_method(self, method_name, *input):
+        return self._c.run_method(method_name, input)
+
+
+def _export_operator_list(module: LiteScriptModule):
+    r"""Return a set of root operator names (with overload name) that are used by any method in this mobile module."""
+    return torch._C._export_operator_list(module._c)
+
+
+def _get_model_bytecode_version(f_input) -> int:
+    r"""Take a file-like object to return an integer.
+
+    Args:
+        f_input: a file-like object (has to implement read, readline, tell, and seek),
+            or a string containing a file name
+
+    Returns:
+        version: An integer. If the integer is -1, the version is invalid. A warning
+            will show in the log.
+
+    Example:
+
+    .. testcode::
+
+        from torch.jit.mobile import _get_model_bytecode_version
+
+        # Get bytecode version from a saved file path
+        version = _get_model_bytecode_version("path/to/model.ptl")
+
+    """
+    if isinstance(f_input, (str, os.PathLike)):
+        if not os.path.exists(f_input):
+            raise ValueError(f"The provided filename {f_input} does not exist")
+        if os.path.isdir(f_input):
+            raise ValueError(f"The provided filename {f_input} is a directory")
+
+    if isinstance(f_input, (str, os.PathLike)):
+        return torch._C._get_model_bytecode_version(os.fspath(f_input))
+    else:
+        return torch._C._get_model_bytecode_version_from_buffer(f_input.read())
+
+
+def _get_mobile_model_contained_types(f_input) -> set[str]:
+    r"""Take a file name or file-like object and return a set of strings, like {"int", "Optional"}.
+
+    Args:
+        f_input: a file-like object (has to implement read, readline, tell, and seek),
+            or a string containing a file name
+
+    Returns:
+        type_list: A set of strings, like {"int", "Optional"}. These are the types used in the bytecode.
+
+    Example:
+
+    .. testcode::
+
+        from torch.jit.mobile import _get_mobile_model_contained_types
+
+        # Get type list from a saved file path
+        type_list = _get_mobile_model_contained_types("path/to/model.ptl")
+
+    """
+    if isinstance(f_input, (str, os.PathLike)):
+        if not os.path.exists(f_input):
+            raise ValueError(f"The provided filename {f_input} does not exist")
+        if os.path.isdir(f_input):
+            raise ValueError(f"The provided filename {f_input} is a directory")
+
+    if isinstance(f_input, (str, os.PathLike)):
+        return torch._C._get_mobile_model_contained_types(os.fspath(f_input))
+    else:
+        return torch._C._get_mobile_model_contained_types_from_buffer(f_input.read())
+
+
+def _backport_for_mobile(f_input, f_output, to_version):
+    r"""Take a input string containing a file name (file-like object) and a new destination to return a boolean.
+
+    Args:
+        f_input: a file-like object (has to implement read, readline, tell, and seek),
+            or a string containing a file name
+        f_output: path to new model destination
+        to_version: the expected output model bytecode version
+
+    Returns:
+        success: A boolean. True if the backport succeeded, otherwise False.
+    """
+    if isinstance(f_input, (str, os.PathLike)):
+        if not os.path.exists(f_input):
+            raise ValueError(f"The provided filename {f_input} does not exist")
+        if os.path.isdir(f_input):
+            raise ValueError(f"The provided filename {f_input} is a directory")
+
+    if (isinstance(f_input, (str, os.PathLike))) and (
+        isinstance(f_output, (str, os.PathLike))
+    ):
+        return torch._C._backport_for_mobile(
+            os.fspath(f_input), os.fspath(f_output), to_version
+        )
+    else:
+        return torch._C._backport_for_mobile_from_buffer(
+            f_input.read(), str(f_output), to_version
+        )
+
+
+def _backport_for_mobile_to_buffer(f_input, to_version):
+    r"""Take a string containing a file name (file-like object).
+
+    Args:
+        f_input: a file-like object (has to implement read, readline, tell, and seek),
+            or a string containing a file name
+
+    """
+    if isinstance(f_input, (str, os.PathLike)):
+        if not os.path.exists(f_input):
+            raise ValueError(f"The provided filename {f_input} does not exist")
+        if os.path.isdir(f_input):
+            raise ValueError(f"The provided filename {f_input} is a directory")
+
+    if isinstance(f_input, (str, os.PathLike)):
+        return torch._C._backport_for_mobile_to_buffer(os.fspath(f_input), to_version)
+    else:
+        return torch._C._backport_for_mobile_from_buffer_to_buffer(
+            f_input.read(), to_version
+        )
+
+
+def _get_model_ops_and_info(f_input):
+    r"""Retrieve the root (top level) operators of a model and their corresponding compatibility info.
+
+    These root operators can call other operators within them (traced ops), and
+    a root op can call many different traced ops depending on internal code paths in the root op.
+    Traced ops are not returned by this function: they are abstracted into the runtime as an
+    implementation detail (and may themselves call further operators), which makes retrieving them
+    difficult and of little value here, since they differ across runtime versions.
+    Because of this, there is one false positive this API cannot prevent in a compatibility use
+    case: all of a model's root ops may be present in a target runtime while some of its traced
+    ops are not, which still prevents the model from running.
+
+    Args:
+        f_input: a file-like object (has to implement read, readline, tell, and seek),
+            or a string containing a file name
+
+    Returns:
+        Operators and info: A Dictionary mapping strings (the qualified names of the root operators)
+        of the model to their OperatorInfo structs.
+
+    Example:
+
+    .. testcode::
+
+        from torch.jit.mobile import _get_model_ops_and_info
+
+        # Get root operators and their compatibility info from a saved file path
+        ops_and_info = _get_model_ops_and_info("path/to/model.ptl")
+
+    """
+    if isinstance(f_input, (str, os.PathLike)):
+        if not os.path.exists(f_input):
+            raise ValueError(f"The provided filename {f_input} does not exist")
+        if os.path.isdir(f_input):
+            raise ValueError(f"The provided filename {f_input} is a directory")
+
+    if isinstance(f_input, (str, os.PathLike)):
+        return torch._C._get_model_ops_and_info(os.fspath(f_input))
+    else:
+        return torch._C._get_model_ops_and_info(f_input.read())
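The helpers in this module share one pattern: validate the path, then hand off to a `torch._C` binding that accepts either a filesystem path or a raw buffer. A hedged usage sketch; the `model.ptl` path, the target bytecode version, and the input shape are illustrative and assume a model previously saved with `torch.jit._save_for_lite_interpreter`:

```python
import torch
from torch.jit.mobile import (
    _backport_for_mobile,
    _get_model_bytecode_version,
    _load_for_lite_interpreter,
)

# Illustrative path only; "model.ptl" is assumed to exist on disk.
print(_get_model_bytecode_version("model.ptl"))

# Backport to an older bytecode version, then load with the lite interpreter.
if _backport_for_mobile("model.ptl", "model_v5.ptl", to_version=5):
    lite_module = _load_for_lite_interpreter("model_v5.ptl")
    print(lite_module(torch.randn(1, 3)))  # placeholder input shape
```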
diff --git a/.venv/lib/python3.12/site-packages/torch/lib/libshm/alloc_info.h b/.venv/lib/python3.12/site-packages/torch/lib/libshm/alloc_info.h
new file mode 100644
index 0000000000000000000000000000000000000000..e441ff5a28936d8ca999fcb61ddc8dbbb2c8c12b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/lib/libshm/alloc_info.h
@@ -0,0 +1,9 @@
+#pragma once
+
+#include <sys/types.h>
+
+struct AllocInfo {
+  pid_t pid;
+  char free;
+  char filename[60];
+};
diff --git a/.venv/lib/python3.12/site-packages/torch/lib/libshm/err.h b/.venv/lib/python3.12/site-packages/torch/lib/libshm/err.h
new file mode 100644
index 0000000000000000000000000000000000000000..e1e6aa4e277c3a94dd642ff2a27e6cd564322e46
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/lib/libshm/err.h
@@ -0,0 +1,25 @@
+#pragma once
+
+#include <cerrno>
+#include <system_error>
+
+// `errno` is only meaningful when a call fails. E.g., a successful `fork()`
+// can set `errno` to `EINVAL` in the child process on some macOS versions
+// (https://stackoverflow.com/a/20295079), and thus `errno` should really only
+// be inspected if an error occurred.
+//
+// All functions used in `libshm` (so far) indicate error by returning `-1`. If
+// you want to use a function with a different error reporting mechanism, you
+// need to port `SYSCHECK` from `torch/lib/c10d/Utils.hpp`.
+#define SYSCHECK_ERR_RETURN_NEG1(expr)                          \
+  while (true) {                                                \
+    if ((expr) == -1) {                                         \
+      if (errno == EINTR) {                                     \
+        continue;                                               \
+      } else {                                                  \
+        throw std::system_error(errno, std::system_category()); \
+      }                                                         \
+    } else {                                                    \
+      break;                                                    \
+    }                                                           \
+  }
diff --git a/.venv/lib/python3.12/site-packages/torch/lib/libshm/libshm.h b/.venv/lib/python3.12/site-packages/torch/lib/libshm/libshm.h
new file mode 100644
index 0000000000000000000000000000000000000000..28024aa2338d1f46ce280abeb92a633f89be1385
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/lib/libshm/libshm.h
@@ -0,0 +1,46 @@
+#pragma once
+
+#include <string>
+
+#include <ATen/MapAllocator.h>
+
+#ifdef __cplusplus
+
+void libshm_init(const char* manager_exec_path);
+
+// Superclass to run a constructor before at::RefcountedMapAllocator
+class THManagedMapAllocatorInit {
+ protected:
+  THManagedMapAllocatorInit(const char* manager_handle, const char* filename);
+  std::string manager_handle_;
+};
+
+// Like an at::RefcountedMapAllocator, but it also makes use of an external
+// shared memory manager process to ensure that shared memory regions actually
+// get freed in the end (even if processes lose the memory).
+class THManagedMapAllocator : private THManagedMapAllocatorInit,
+                              public at::RefcountedMapAllocator {
+ public:
+  THManagedMapAllocator(
+      const char* manager_handle,
+      const char* filename,
+      int flags,
+      size_t size);
+
+  void close() override;
+
+  ~THManagedMapAllocator() override {
+    close();
+  }
+
+  static at::DataPtr makeDataPtr(
+      const char* manager_handle,
+      const char* filename,
+      int flags,
+      size_t size);
+  static THManagedMapAllocator* fromDataPtr(const at::DataPtr&);
+
+  const char* manager_handle() const {
+    return manager_handle_.c_str();
+  }
+};
+
+#endif
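On Linux, the allocator declared here backs the shared-memory storages used by `torch.multiprocessing`; the external manager process ensures segments are eventually freed even if a client dies. The C++ API is never called directly from Python, but its effect is visible through the ordinary tensor API; a rough sketch of that user-facing behavior:

```python
import torch

t = torch.zeros(4)
t.share_memory_()      # storage moves into (manager-tracked) shared memory
print(t.is_shared())   # True
```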
diff --git a/.venv/lib/python3.12/site-packages/torch/lib/libshm/socket.h b/.venv/lib/python3.12/site-packages/torch/lib/libshm/socket.h
new file mode 100644
index 0000000000000000000000000000000000000000..6b7207eb70a860b4e7977d66cbdc75e6aee11123
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/lib/libshm/socket.h
@@ -0,0 +1,167 @@
+#pragma once
+
+#include <poll.h>
+#include <sys/socket.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <sys/un.h>
+#include <unistd.h>
+#include <cstddef>
+#include <cstring>
+#include <stdexcept>
+#include <string>
+
+#include <libshm/alloc_info.h>
+#include <libshm/err.h>
+
+class Socket {
+ public:
+  int socket_fd;
+  Socket(const Socket& other) = delete;
+
+ protected:
+  Socket() {
+    SYSCHECK_ERR_RETURN_NEG1(socket_fd = socket(AF_UNIX, SOCK_STREAM, 0));
+  }
+  Socket(Socket&& other) noexcept : socket_fd(other.socket_fd) {
+    other.socket_fd = -1;
+  };
+  explicit Socket(int fd) : socket_fd(fd) {}
+
+  virtual ~Socket() {
+    if (socket_fd != -1)
+      close(socket_fd);
+  }
+
+  struct sockaddr_un prepare_address(const char* path) {
+    struct sockaddr_un address;
+    address.sun_family = AF_UNIX;
+    strcpy(address.sun_path, path);
+    return address;
+  }
+
+  // Implemented based on https://man7.org/linux/man-pages/man7/unix.7.html
+  size_t address_length(struct sockaddr_un address) {
+    return offsetof(sockaddr_un, sun_path) + strlen(address.sun_path) + 1;
+  }
+
+  void recv(void* _buffer, size_t num_bytes) {
+    char* buffer = (char*)_buffer;
+    size_t bytes_received = 0;
+    ssize_t step_received;
+    struct pollfd pfd = {};
+    pfd.fd = socket_fd;
+    pfd.events = POLLIN;
+    while (bytes_received < num_bytes) {
+      SYSCHECK_ERR_RETURN_NEG1(poll(&pfd, 1, 1000));
+      if (pfd.revents & POLLIN) {
+        SYSCHECK_ERR_RETURN_NEG1(
+            step_received =
+                ::read(socket_fd, buffer, num_bytes - bytes_received));
+        if (step_received == 0)
+          throw std::runtime_error("Other end has closed the connection");
+        bytes_received += step_received;
+        buffer += step_received;
+      } else if (pfd.revents & (POLLERR | POLLHUP)) {
+        throw std::runtime_error(
+            "An error occurred while waiting for the data");
+      } else {
+        throw std::runtime_error(
+            "Shared memory manager connection has timed out");
+      }
+    }
+  }
+
+  void send(const void* _buffer, size_t num_bytes) {
+    const char* buffer = (const char*)_buffer;
+    size_t bytes_sent = 0;
+    ssize_t step_sent;
+    while (bytes_sent < num_bytes) {
+      SYSCHECK_ERR_RETURN_NEG1(
+          step_sent = ::write(socket_fd, buffer, num_bytes - bytes_sent));
+      bytes_sent += step_sent;
+      buffer += step_sent;
+    }
+  }
+};
+
+class ManagerSocket : public Socket {
+ public:
+  explicit ManagerSocket(int fd) : Socket(fd) {}
+
+  AllocInfo receive() {
+    AllocInfo info;
+    recv(&info, sizeof(info));
+    return info;
+  }
+
+  void confirm() {
+    send("OK", 2);
+  }
+};
+
+class ManagerServerSocket : public Socket {
+ public:
+  explicit ManagerServerSocket(const std::string& path) {
+    socket_path = path;
+    try {
+      struct sockaddr_un address = prepare_address(path.c_str());
+      size_t len = address_length(address);
+      SYSCHECK_ERR_RETURN_NEG1(
+          bind(socket_fd, (struct sockaddr*)&address, len));
+      SYSCHECK_ERR_RETURN_NEG1(listen(socket_fd, 10));
+    } catch (std::exception&) {
+      SYSCHECK_ERR_RETURN_NEG1(close(socket_fd));
+      throw;
+    }
+  }
+
+  void remove() {
+    struct stat file_stat;
+    if (fstat(socket_fd, &file_stat) == 0)
+      SYSCHECK_ERR_RETURN_NEG1(unlink(socket_path.c_str()));
+  }
+
+  ~ManagerServerSocket() override {
+    unlink(socket_path.c_str());
+  }
+
+  ManagerSocket accept() {
+    int client_fd;
+    struct sockaddr_un addr;
+    socklen_t addr_len = sizeof(addr);
+    SYSCHECK_ERR_RETURN_NEG1(
+        client_fd = ::accept(socket_fd, (struct sockaddr*)&addr, &addr_len));
+    return ManagerSocket(client_fd);
+  }
+
+  std::string socket_path;
+};
+
+class ClientSocket : public Socket {
+ public:
+  explicit ClientSocket(const std::string& path) {
+    try {
+      struct sockaddr_un address = prepare_address(path.c_str());
+      size_t len = address_length(address);
+      SYSCHECK_ERR_RETURN_NEG1(
+          connect(socket_fd, (struct sockaddr*)&address, len));
+    } catch (std::exception&) {
+      SYSCHECK_ERR_RETURN_NEG1(close(socket_fd));
+      throw;
+    }
+  }
+
+  void register_allocation(AllocInfo& info) {
+    char buffer[3] = {0, 0, 0};
+    send(&info, sizeof(info));
+    recv(buffer, 2);
+    if (strcmp(buffer, "OK") != 0)
+      throw std::runtime_error(
+          "Shared memory manager didn't respond with an OK");
+  }
+
+  void register_deallocation(AllocInfo& info) {
+    send(&info, sizeof(info));
+  }
+};
diff --git a/.venv/lib/python3.12/site-packages/torch/lib/libshm_windows/libshm.h b/.venv/lib/python3.12/site-packages/torch/lib/libshm_windows/libshm.h
new file mode 100644
index 0000000000000000000000000000000000000000..4dd193df93d110e3a04d33a3f9d3e3ec24948277
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/lib/libshm_windows/libshm.h
@@ -0,0 +1,36 @@
+#pragma once
+
+#include <ATen/MapAllocator.h>
+
+#ifdef __cplusplus
+
+#ifdef SHM_EXPORTS
+#define SHM_API __declspec(dllexport)
+#else
+#define SHM_API __declspec(dllimport)
+#endif
+
+SHM_API void libshm_init(const char* manager_exec_path);
+
+class SHM_API THManagedMapAllocator : public at::RefcountedMapAllocator {
+ public:
+  THManagedMapAllocator(
+      const char* manager_handle,
+      const char* filename,
+      int flags,
+      size_t size)
+      : at::RefcountedMapAllocator(filename, flags, size) {}
+
+  static at::DataPtr makeDataPtr(
+      const char* manager_handle,
+      const char* filename,
+      int flags,
+      size_t size);
+  static THManagedMapAllocator* fromDataPtr(const at::DataPtr&);
+
+  const char* manager_handle() const {
+    return "no_manager";
+  }
+};
+
+#endif
diff --git a/.venv/lib/python3.12/site-packages/torch/masked/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/masked/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a3ebfa6d69e74d0c5e3faccbc2a50b7b052c6c14
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/masked/__pycache__/__init__.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/masked/__pycache__/_docs.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/masked/__pycache__/_docs.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f2c3e626ed44d49481dc7dfca5c31f84a0a571e6
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/masked/__pycache__/_docs.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/masked/__pycache__/_ops.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/masked/__pycache__/_ops.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..078ca5729f9877d7f63fbd1aaa1c7d34119ac328
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/masked/__pycache__/_ops.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__init__.py b/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ef878d3c4b20ef38c7dfd6e14631e99b2fddcc1
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+
+from .binary import _apply_native_binary, _is_native_binary
+from .core import is_masked_tensor, MaskedTensor
+from .passthrough import _apply_pass_through_fn, _is_pass_through_fn
+from .reductions import _apply_reduction, _is_reduction
+from .unary import _apply_native_unary, _is_native_unary
diff --git a/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7916b1b5b59460ee319965da909314d772521dbf
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/__init__.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/binary.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/binary.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..06c8f7ed72a27815d4e8a35f4889c759604e8beb
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/binary.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/core.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/core.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b48bde5ef4ee3c75fd0b7f6ce6594bfd5b8147ef
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/core.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/creation.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/creation.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a6a2315ebd6a4f7ba922cee968fac959e1e8d496
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/creation.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/passthrough.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/passthrough.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5e4ec876ed1e91a5e4a065b88c79f5950582ceb9
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/passthrough.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/reductions.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/reductions.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e0e8373116eee24e2ecd58821b385c9a74a9b698
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/reductions.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/unary.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/unary.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6063f9f800302ef23fb3a5cd6e37c02f9d1a152b
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/__pycache__/unary.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/_ops_refs.py b/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/_ops_refs.py
new file mode 100644
index 0000000000000000000000000000000000000000..719df7eac464f7b44499646f0129623b95aa9461
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/_ops_refs.py
@@ -0,0 +1,531 @@
+# mypy: allow-untyped-defs
+# Copyright (c) Meta Platforms, Inc. and affiliates
+
+from functools import partial
+from typing import Any, Callable, TYPE_CHECKING
+
+import torch
+
+from .binary import _apply_native_binary, NATIVE_BINARY_FNS, NATIVE_INPLACE_BINARY_FNS
+from .core import (
+    _get_data,
+    _masks_match,
+    _maybe_get_mask,
+    is_masked_tensor,
+    MaskedTensor,
+)
+from .passthrough import _apply_pass_through_fn, PASSTHROUGH_FNS
+from .reductions import (
+    _apply_reduction,
+    NATIVE_REDUCE_FNS,
+    TENSOR_REDUCE_FNS,
+    TORCH_REDUCE_FNS,
+)
+from .unary import _apply_native_unary, NATIVE_INPLACE_UNARY_FNS, NATIVE_UNARY_FNS
+
+
+if TYPE_CHECKING:
+    from torch._ops import OpOverload
+
+
+__all__ = []  # type: ignore[var-annotated]
+
+
+def _check_args_kwargs_length(
+    args, kwargs, error_prefix, len_args=None, len_kwargs=None
+):
+    if len_args is not None and len_args != len(args):
+        raise ValueError(
+            f"{error_prefix}: len(args) must be {len_args} but got {len(args)}"
+        )
+    if len_kwargs is not None and len_kwargs != len(kwargs):
+        raise ValueError(
+            f"{error_prefix}: len(kwargs) must be {len_kwargs} but got {len(kwargs)}"
+        )
+
+
+class _MaskedContiguous(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input):
+        if not is_masked_tensor(input):
+            raise ValueError("MaskedContiguous forward: input must be a MaskedTensor.")
+
+        if input.is_contiguous():
+            return input
+
+        data = input.get_data()
+        mask = input.get_mask()
+
+        return MaskedTensor(data.contiguous(), mask.contiguous())
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return grad_output
+
+
+class _MaskedToDense(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input):
+        if not is_masked_tensor(input):
+            raise ValueError("MaskedToDense forward: input must be a MaskedTensor.")
+
+        if input.layout == torch.strided:
+            return input
+
+        ctx.layout = input.layout
+        data = input.get_data()
+        mask = input.get_mask()
+
+        return MaskedTensor(data.to_dense(), mask.to_dense())
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        layout = ctx.layout
+
+        if layout == torch.sparse_coo:
+            return grad_output.to_sparse_coo()
+        elif layout == torch.sparse_csr:
+            return grad_output.to_sparse_csr()
+        elif layout == torch.strided:
+            return grad_output.to_dense()
+        raise ValueError("to_dense: Unsupported input layout: ", layout)
+
+
+class _MaskedToSparse(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input):
+        if not is_masked_tensor(input):
+            raise ValueError("MaskedToSparse forward: input must be a MaskedTensor.")
+
+        # Following the convention from sparse tensors that to_sparse always means that we convert to sparse_coo
+        if input.layout == torch.sparse_coo:
+            return input
+
+        data = input.get_data()
+        mask = input.get_mask()
+        sparse_mask = mask.to_sparse_coo().coalesce()
+        sparse_data = data.sparse_mask(sparse_mask)
+
+        return MaskedTensor(sparse_data, sparse_mask)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return grad_output.to_dense()
+
+
+class _MaskedToSparseCsr(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input):
+        if not is_masked_tensor(input):
+            raise ValueError("MaskedToSparseCsr forward: input must be a MaskedTensor.")
+
+        if input._masked_data.ndim != 2:
+            raise ValueError(
+                f"Only 2D tensors can be converted to the SparseCsr layout but got shape: {input._masked_data.size()}"
+            )
+
+        if input.layout == torch.sparse_csr:
+            return input
+
+        data = input.get_data()
+        mask = input.get_mask()
+        sparse_mask = mask.to_sparse_csr()
+        sparse_data = data.sparse_mask(sparse_mask)
+
+        return MaskedTensor(sparse_data, sparse_mask)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return grad_output.to_dense()
+
+
+class _MaskedWhere(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, cond, self, other):
+        ctx.mark_non_differentiable(cond)
+        ctx.save_for_backward(cond)
+        return torch.ops.aten.where(cond, self, other)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        (cond,) = ctx.saved_tensors
+
+        def masked_out_like(mt):
+            return MaskedTensor(mt.get_data(), torch.zeros_like(mt.get_mask()).bool())
+
+        return (
+            None,
+            torch.ops.aten.where(cond, grad_output, masked_out_like(grad_output)),
+            torch.ops.aten.where(cond, masked_out_like(grad_output), grad_output),
+        )
+
+
+_MASKEDTENSOR_FUNCTION_TABLE = {}
+
+_function_fn_apply_map = {
+    (
+        tuple(NATIVE_REDUCE_FNS),
+        tuple(TORCH_REDUCE_FNS),
+        tuple(TENSOR_REDUCE_FNS),
+    ): _apply_reduction,
+}
+
+for fn_map_list, apply_fn in _function_fn_apply_map.items():
+    for fn_map in fn_map_list:
+        for fn in fn_map:
+            _MASKEDTENSOR_FUNCTION_TABLE[fn] = partial(apply_fn, fn)
+
+
+def register_function_func(ops):
+    """
+    Used for registering a new __torch_function__ function to MaskedTensor
+    Called via _MASKEDTENSOR_FUNCTION_TABLE[func](*args, **kwargs)
+
+    The code to register a new function looks like:
+
+    @register_function_func(list_of_ops)
+    def foo(func, *args, **kwargs):
+        
+    """
+
+    def wrapper(func):
+        for op in ops:
+            _MASKEDTENSOR_FUNCTION_TABLE[op] = partial(func, op)
+
+    return wrapper
+
+
+@register_function_func(NATIVE_REDUCE_FNS + TORCH_REDUCE_FNS + TENSOR_REDUCE_FNS)
+def _general_function_reductions(func, *args, **kwargs):
+    return _apply_reduction(func, *args, **kwargs)
+
+
+@register_function_func([torch.Tensor.where, torch.where])
+def _function_where(func, *args, **kwargs):
+    _check_args_kwargs_length(
+        args, kwargs, "__torch_function__, torch.where", len_args=3, len_kwargs=0
+    )
+    return _MaskedWhere.apply(*args)
+
+
+@register_function_func([torch.Tensor.contiguous])
+def _function_contiguous(func, *args, **kwargs):
+    return _MaskedContiguous.apply(args[0])
+
+
+@register_function_func([torch.Tensor.to_dense])
+def _function_to_dense(func, *args, **kwargs):
+    return _MaskedToDense.apply(args[0])
+
+
+@register_function_func([torch.Tensor.to_sparse])
+def _function_to_sparse(func, *args, **kwargs):
+    return _MaskedToSparse.apply(args[0])
+
+
+@register_function_func([torch.Tensor.to_sparse_csr])
+def _function_to_sparse_csr(func, *args, **kwargs):
+    return _MaskedToSparseCsr.apply(args[0])
+
+
+_MASKEDTENSOR_DISPATCH_TABLE: dict["OpOverload", Callable[..., Any]] = {}
+
+
+def register_dispatch_func(aten_ops):
+    """
+    Used for registering a new __torch_dispatch__ function to MaskedTensor
+    Called via _MASKEDTENSOR_DISPATCH_TABLE[func](*args, **kwargs)
+
+    The code to register a new function looks like:
+
+    @register_dispatch_func(list_of_ops)
+    def foo(func, *args, **kwargs):
+        
+    """
+
+    def wrapper(func):
+        for aten_op in aten_ops:
+            _MASKEDTENSOR_DISPATCH_TABLE[aten_op] = partial(func, aten_op)
+
+    return wrapper
+
+
+@register_dispatch_func(NATIVE_REDUCE_FNS + TORCH_REDUCE_FNS + TENSOR_REDUCE_FNS)
+def _general_reduction(func, *args, **kwargs):
+    return _apply_reduction(func, *args, **kwargs)
+
+
+@register_dispatch_func(PASSTHROUGH_FNS)
+def _general_passthrough(func, *args, **kwargs):
+    return _apply_pass_through_fn(func, *args, **kwargs)
+
+
+@register_dispatch_func(NATIVE_UNARY_FNS + NATIVE_INPLACE_UNARY_FNS)
+def _general_unary(func, *args, **kwargs):
+    return _apply_native_unary(func, *args, **kwargs)
+
+
+@register_dispatch_func(NATIVE_BINARY_FNS + NATIVE_INPLACE_BINARY_FNS)
+def _general_binary(func, *args, **kwargs):
+    return _apply_native_binary(func, *args, **kwargs)
+
+
+@register_dispatch_func([torch.ops.aten.stride])
+def stride(func, *args, **kwargs):
+    return None
+
+
+@register_dispatch_func([torch.ops.aten.sym_stride])
+def sym_stride(func, *args, **kwargs):
+    return None
+
+
+@register_dispatch_func([torch.ops.prim.layout])
+def layout(func, *args, **kwargs):
+    return _get_data(args[0]).layout
+
+
+@register_dispatch_func([torch.ops.aten.is_contiguous])
+def is_contiguous(func, *args, **kwargs):
+    data = _get_data(args[0])
+    if data.is_sparse:
+        raise ValueError("MaskedTensors with sparse data do not have is_contiguous")
+    return func(data, *args[1:], **kwargs)
+
+
+@register_dispatch_func([torch.ops.aten.is_strides_like_format])
+def is_strides_like_format(func, *args, **kwargs):
+    data = _get_data(args[0])
+    if data.is_sparse:
+        raise ValueError(
+            "MaskedTensors with sparse data do not have is_strides_like_format"
+        )
+    return func(data, *args[1:], **kwargs)
+
+
+@register_dispatch_func([torch.ops.aten.is_non_overlapping_and_dense])
+def is_non_overlapping_and_dense(func, *args, **kwargs):
+    data = _get_data(args[0])
+    if data.is_sparse:
+        raise ValueError(
+            "MaskedTensors with sparse data do not have is_non_overlapping_and_dense"
+        )
+    return func(data, *args[1:], **kwargs)
+
+
+@register_dispatch_func([torch.ops.aten.contiguous])
+def contiguous(func, *args, **kwargs):
+    if _get_data(args[0]).is_sparse:
+        raise ValueError("MaskedTensors with sparse data do not have contiguous")
+    return _MaskedContiguous.apply(args[0])
+
+
+@register_dispatch_func([torch.ops.aten.new_empty_strided])
+def new_empty_strided(func, *args, **kwargs):
+    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=3)
+    data = _get_data(args[0])
+    mask = _maybe_get_mask(args[0])
+    if tuple(args[1]) != tuple(data.size()):
+        raise ValueError(
+            f"__torch_dispatch__, {func}: args[1] expected to be the same as data.size()"
+        )
+    if tuple(args[2]) != tuple(data.stride()):
+        raise ValueError(
+            f"__torch_dispatch__, {func}: args[2] expected to be the same as data.stride()"
+        )
+    return MaskedTensor(func(data, args[1], args[2], **kwargs), mask)
+
+
+@register_dispatch_func([torch.ops.aten._local_scalar_dense])
+def _local_scalar_dense(func, *args, **kwargs):
+    if not _maybe_get_mask(args[0]):
+        raise ValueError(f"__torch_dispatch__, {func}: expected a mask tensor")
+    return torch.ops.aten._local_scalar_dense(_get_data(args[0]))
+
+
+@register_dispatch_func([torch.ops.aten.detach, torch.ops.aten.clone])
+def _apply_fn_on_data(func, *args, **kwargs):
+    return MaskedTensor(func(_get_data(args[0])), _maybe_get_mask(args[0]))
+
+
+@register_dispatch_func([torch.ops.aten._to_copy])
+def _to_copy(func, *args, **kwargs):
+    new_data = func(_get_data(args[0]), *args[1:], **kwargs)
+    return MaskedTensor(new_data, _maybe_get_mask(args[0]))
+
+
+@register_dispatch_func([torch.ops.aten._softmax])
+def _softmax(func, *args, **kwargs):
+    _check_args_kwargs_length(
+        args, kwargs, f"__torch_dispatch__, {func}", len_args=3, len_kwargs=0
+    )
+    data = _get_data(args[0])
+    mask = _maybe_get_mask(args[0])
+    result_data = torch.ops.aten._masked_softmax(data, ~mask, args[1], 2)
+    return MaskedTensor(result_data, mask)
+
+
+@register_dispatch_func([torch.ops.aten.ones_like])
+def ones_like(func, *args, **kwargs):
+    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1)
+    result_data = func(_get_data(args[0]), **kwargs)
+    return MaskedTensor(result_data, _maybe_get_mask(args[0]))
+
+
+@register_dispatch_func([torch.ops.aten._softmax_backward_data])
+def _softmax_backward_data(func, *args, **kwargs):
+    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=4)
+    grad, output, dim, _input_dtype = args
+    if is_masked_tensor(grad) and is_masked_tensor(output):
+        if not _masks_match(grad, output):
+            raise ValueError(
+                "__torch_dispatch__, {func}: expected the masks of grad and output to match"
+            )
+        grad_data = _get_data(grad)
+        new_grad_data = torch.ops.aten._masked_softmax_backward(
+            grad_data,
+            _get_data(output),
+            ~_maybe_get_mask(grad),
+            dim % grad_data.ndim,
+        )
+        res = MaskedTensor(new_grad_data, _maybe_get_mask(grad))
+        return res
+    else:
+        raise ValueError(
+            f"__torch_dispatch__, {func}: grad and output must both be MaskedTensors"
+        )
+
+
+@register_dispatch_func([torch.ops.aten.copy_])
+def copy_(func, *args, **kwargs):
+    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=2)
+    if not _masks_match(_maybe_get_mask(args[0]), _maybe_get_mask(args[1])):
+        raise ValueError("args[0] mask and args[1] mask must match but do not")
+    func(_get_data(args[0]), _get_data(args[1]))
+    return args[0]
+
+
+@register_dispatch_func([torch.ops.aten.where])
+def where(func, *args, **kwargs):
+    _check_args_kwargs_length(
+        args, kwargs, f"__torch_dispatch__, {func}", len_args=3, len_kwargs=0
+    )
+    if not torch.is_tensor(args[0]):
+        raise ValueError("__torch_dispatch__, {func}: expected args[0] to be a tensor")
+    mx = args[1]
+    my = args[2]
+    if not is_masked_tensor(mx):
+        mx = MaskedTensor(mx, torch.ones_like(mx, dtype=torch.bool))
+    if not is_masked_tensor(my):
+        my = MaskedTensor(my, torch.ones_like(my, dtype=torch.bool))
+    new_data = func(args[0], mx.get_data(), my.get_data())
+    new_mask = func(args[0], mx.get_mask(), my.get_mask())
+    return MaskedTensor(new_data, new_mask)
+
+
+@register_dispatch_func([torch.ops.aten._to_sparse])
+def _to_sparse(func, *args, **kwargs):
+    _check_args_kwargs_length(
+        args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0
+    )
+    if not torch.is_tensor(args[0]):
+        raise TypeError("__torch_dispatch__, {func}: expected args[0] to be a tensor")
+    mt = args[0]
+    if not is_masked_tensor(mt):
+        mt = MaskedTensor(mt, torch.ones_like(mt, dtype=torch.bool))
+    if mt.is_sparse_coo():
+        return mt
+    new_mask = func(_maybe_get_mask(args[0])).coalesce()
+    new_data = _get_data(args[0]).sparse_mask(new_mask)
+    return MaskedTensor(new_data, new_mask)
+
+
+@register_dispatch_func([torch.ops.aten._to_sparse_csr])
+def _to_sparse_csr(func, *args, **kwargs):
+    _check_args_kwargs_length(
+        args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0
+    )
+    if not torch.is_tensor(args[0]):
+        raise ValueError("__torch_dispatch__, {func}: expected args[0] to be a tensor")
+    mt = args[0]
+    if not is_masked_tensor(mt):
+        mt = MaskedTensor(mt, torch.ones_like(mt).bool())
+    if mt.is_sparse_csr():
+        return mt
+    new_mask = func(_maybe_get_mask(args[0]))
+    new_data = _get_data(args[0]).sparse_mask(new_mask)
+    return MaskedTensor(new_data, new_mask)
+
+
+@register_dispatch_func([torch.ops.aten._to_dense])
+def _to_dense(func, *args, **kwargs):
+    _check_args_kwargs_length(
+        args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0
+    )
+    if not torch.is_tensor(args[0]):
+        raise ValueError("__torch_dispatch__, {func}: expected args[0] to be a tensor")
+    mt = args[0]
+    if not is_masked_tensor(mt):
+        mt = MaskedTensor(mt, torch.ones_like(mt).bool())
+    new_data = func(_get_data(args[0]))
+    new_mask = func(_maybe_get_mask(args[0]))
+    return MaskedTensor(new_data, new_mask)
+
+
+@register_dispatch_func([torch.ops.aten._indices])
+def _indices(func, *args, **kwargs):
+    # Assumes data is sparse
+    _check_args_kwargs_length(
+        args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0
+    )
+    data = _get_data(args[0]).indices()
+    return MaskedTensor(data, torch.ones_like(data).bool())
+
+
+@register_dispatch_func([torch.ops.aten._values])
+def _values(func, *args, **kwargs):
+    _check_args_kwargs_length(
+        args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0
+    )
+    data = _get_data(args[0]).values()
+    return MaskedTensor(data, torch.ones_like(data).bool())
+
+
+@register_dispatch_func([torch.ops.aten._sparse_coo_tensor_with_dims_and_tensors])
+def _sparse_coo_tensor_with_dims_and_tensors(func, *args, **kwargs):
+    new_args = list(args)
+    if is_masked_tensor(args[-1]):
+        new_args[-1] = args[-1].get_data()
+    if is_masked_tensor(args[-2]):
+        new_args[-2] = args[-2].get_data()
+
+    new_data = func(*new_args, **kwargs)
+    new_args[-1] = torch.ones_like(new_args[-1])
+    new_mask = func(*new_args, **kwargs).bool()
+
+    return MaskedTensor(new_data, new_mask)
+
+
+@register_dispatch_func([torch.ops.aten.is_same_size])
+def is_same_size(func, *args, **kwargs):
+    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=2)
+    return _get_data(args[0]).is_same_size(_get_data(args[1]))
+
+
+@register_dispatch_func([torch.ops.aten._is_any_true])
+def _is_any_true(func, *args, **kwargs):
+    _check_args_kwargs_length(
+        args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0
+    )
+    data = _get_data(args[0])
+    mask = _maybe_get_mask(args[0])
+    if mask is None:
+        raise ValueError(
+            f"__torch_dispatch__, {func}: expected args[0] to be a MaskedTensor"
+        )
+    if data.dtype != torch.bool:
+        raise ValueError(f"__torch_dispatch__, {func}: expected a boolean tensor")
+    if data.is_sparse:
+        raise ValueError(f"MaskedTensors with sparse data do not have {func}")
+
+    return MaskedTensor(func(data & mask), torch.tensor(True))
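Each entry registered above routes one `aten` overload packet through `__torch_dispatch__`. A hedged sketch of the effect (MaskedTensor is a prototype API, so exact printing and operator coverage may shift between releases):

```python
import torch
from torch.masked import MaskedTensor

data = torch.arange(6, dtype=torch.float32).reshape(2, 3)
mask = torch.tensor([[True, False, True], [False, True, True]])
mt = MaskedTensor(data, mask)

# aten.clone is handled by _apply_fn_on_data: clone the data, keep the mask.
cloned = mt.clone()
print(cloned.get_mask().equal(mask))  # True

# aten._softmax is rerouted to aten._masked_softmax with the inverted mask,
# so masked-out slots do not contribute to the normalization.
print(torch.softmax(mt, dim=1))
```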
diff --git a/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/binary.py b/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/binary.py
new file mode 100644
index 0000000000000000000000000000000000000000..8315ae11be7175c2b5aaef178a4bc4785dcbcb29
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/binary.py
@@ -0,0 +1,200 @@
+# mypy: allow-untyped-defs
+# Copyright (c) Meta Platforms, Inc. and affiliates
+
+import torch
+
+from .core import (
+    _map_mt_args_kwargs,
+    _masks_match,
+    _tensors_match,
+    _wrap_result,
+    is_masked_tensor,
+)
+
+
+__all__ = []  # type: ignore[var-annotated]
+
+BINARY_NAMES = [
+    "add",
+    "atan2",
+    "arctan2",
+    "bitwise_and",
+    "bitwise_or",
+    "bitwise_xor",
+    "bitwise_left_shift",
+    "bitwise_right_shift",
+    "div",
+    "divide",
+    "floor_divide",
+    "fmod",
+    "logaddexp",
+    "logaddexp2",
+    "mul",
+    "multiply",
+    "nextafter",
+    "remainder",
+    "sub",
+    "subtract",
+    "true_divide",
+    "eq",
+    "ne",
+    "le",
+    "ge",
+    "greater",
+    "greater_equal",
+    "gt",
+    "less_equal",
+    "lt",
+    "less",
+    "maximum",
+    "minimum",
+    "fmax",
+    "fmin",
+    "not_equal",
+]
+
+INPLACE_BINARY_NAMES = [
+    n + "_"
+    for n in (
+        list(
+            set(BINARY_NAMES)
+            - {
+                "logaddexp",
+                "logaddexp2",
+                "equal",
+                "fmin",
+                "minimum",
+                "maximum",
+                "fmax",
+            }
+        )
+    )
+]
+
+
+def _get_at_least_one_mask(a, b):
+    if not is_masked_tensor(a) and not is_masked_tensor(b):
+        raise TypeError("At least one of `a` and `b` must be a MaskedTensor")
+    if not _masks_match(a, b):
+        raise ValueError("a and b must have matching masks")
+    if is_masked_tensor(a):
+        return a.get_mask()
+    return b.get_mask()
+
+
+def _binary_helper(fn, args, kwargs, inplace):
+    if len(kwargs) != 0:
+        raise ValueError("len(kwargs) must equal 0")
+    for a in args[2:]:
+        if torch.is_tensor(a):
+            raise TypeError(
+                "MaskedTensor binary ops do not support Tensor arguments aside from the lhs and rhs"
+            )
+
+    if not _masks_match(*args[:2]):
+        raise ValueError(
+            "Input masks must match. If you need support for this, please open an issue on Github."
+        )
+
+    data_args, _data_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_data())
+    mask_args, _mask_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_mask())
+
+    args0_layout = data_args[0].layout
+    same_layout = (
+        torch.is_tensor(data_args[1]) or is_masked_tensor(data_args[1])
+    ) and (args0_layout == data_args[1].layout)
+
+    if args0_layout == torch.sparse_coo:
+        if same_layout:
+            if not _tensors_match(data_args[0].indices(), data_args[1].indices()):
+                raise ValueError(
+                    "sparse_coo indices must match. If you need support for this, please open an issue on Github."
+                )
+            if data_args[0].size() != data_args[1].size():
+                raise ValueError(
+                    "input1 and input2 must have the same size for binary functions."
+                )
+
+            data_args[1] = data_args[1].values()
+
+        i = data_args[0].indices()
+        size = data_args[0].size()
+        data_args[0] = data_args[0].values()
+        v = fn(*data_args)
+        result_data = torch.sparse_coo_tensor(i, v, size)
+
+    elif args0_layout == torch.sparse_csr:
+        if same_layout:
+            if not (
+                _tensors_match(data_args[0].crow_indices(), data_args[1].crow_indices())
+                and _tensors_match(
+                    data_args[0].col_indices(), data_args[1].col_indices()
+                )
+            ):
+                raise ValueError(
+                    "sparse_csr indices must match. If you need support for this, please open an issue on Github."
+                )
+
+            data_args[1] = data_args[1].values()
+
+        crow = data_args[0].crow_indices()
+        col = data_args[0].col_indices()
+        size = data_args[0].size()
+        data_args[0] = data_args[0].values()
+        v = fn(*data_args)
+        result_data = torch.sparse_csr_tensor(crow, col, v, size)
+
+    else:
+        result_data = fn(*data_args)
+
+    if inplace:
+        args[0]._set_data_mask(result_data, mask_args[0])
+        return args[0]
+    else:
+        result_mask = _get_at_least_one_mask(*args[:2])
+        # sparse tensors don't have strides so we can only expand if the layout is strided
+        if args0_layout == torch.strided:
+            result_mask = result_mask.expand_as(result_data)
+        return _wrap_result(result_data, result_mask)
+
+
+def _torch_binary(fn_name):
+    fn = getattr(torch.ops.aten, fn_name)
+
+    def binary_fn(*args, **kwargs):
+        return _binary_helper(fn, args, kwargs, inplace=False)
+
+    return binary_fn
+
+
+def _torch_inplace_binary(fn_name):
+    fn = getattr(torch.ops.aten, fn_name)
+
+    def binary_fn(*args, **kwargs):
+        return _binary_helper(fn, args, kwargs, inplace=True)
+
+    return binary_fn
+
+
+NATIVE_BINARY_MAP = {
+    getattr(torch.ops.aten, name): _torch_binary(name) for name in BINARY_NAMES
+}
+NATIVE_INPLACE_BINARY_MAP = {
+    getattr(torch.ops.aten, name): _torch_inplace_binary(name)
+    for name in INPLACE_BINARY_NAMES
+}
+
+NATIVE_BINARY_FNS = list(NATIVE_BINARY_MAP.keys())
+NATIVE_INPLACE_BINARY_FNS = list(NATIVE_INPLACE_BINARY_MAP.keys())
+
+
+def _is_native_binary(fn):
+    return fn in NATIVE_BINARY_FNS or fn in NATIVE_INPLACE_BINARY_FNS
+
+
+def _apply_native_binary(fn, *args, **kwargs):
+    if fn in NATIVE_BINARY_FNS:
+        return NATIVE_BINARY_MAP[fn](*args, **kwargs)
+    if fn in NATIVE_INPLACE_BINARY_FNS:
+        return NATIVE_INPLACE_BINARY_MAP[fn](*args, **kwargs)
+    return NotImplemented
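A hedged sketch of the binary path: both operands are MaskedTensors with matching masks, so `aten.add` resolves through `_binary_helper` and the shared mask is carried onto the result.

```python
import torch
from torch.masked import MaskedTensor

mask = torch.tensor([True, False, True])
a = MaskedTensor(torch.tensor([1.0, 2.0, 3.0]), mask)
b = MaskedTensor(torch.tensor([10.0, 20.0, 30.0]), mask)

out = a + b
print(out)              # the masked-out middle slot renders as "--"
print(out.get_mask())   # tensor([ True, False,  True])
```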
diff --git a/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/core.py b/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/core.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e3608b3e6d3daf388a478855d5571a2f7ebddba
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/core.py
@@ -0,0 +1,359 @@
+# mypy: allow-untyped-defs
+# Copyright (c) Meta Platforms, Inc. and affiliates
+
+import warnings
+from typing import Any
+from typing_extensions import TypeIs
+
+import torch
+from torch.overrides import get_default_nowrap_functions
+
+
+__all__ = [
+    "MaskedTensor",
+    "is_masked_tensor",
+]
+
+
+def is_masked_tensor(obj: Any, /) -> TypeIs["MaskedTensor"]:
+    r"""Returns True if the input is a MaskedTensor, else False
+
+    Args:
+        a: any input
+
+    Examples:
+
+        >>> # xdoctest: +SKIP
+        >>> from torch.masked import MaskedTensor
+        >>> data = torch.arange(6).reshape(2, 3)
+        >>> mask = torch.tensor([[True, False, False], [True, True, False]])
+        >>> mt = MaskedTensor(data, mask)
+        >>> is_masked_tensor(mt)
+        True
+    """
+    return isinstance(obj, MaskedTensor)
+
+
+def _tensors_match(a, b, exact=True, rtol=1e-05, atol=1e-08):
+    if is_masked_tensor(a) or is_masked_tensor(b):
+        raise ValueError("Neither `a` nor `b` can be a MaskedTensor.")
+    if a.layout != b.layout:
+        raise ValueError(
+            f"`a` and `b` must have the same layout. Got {a.layout} and {b.layout}"
+        )
+
+    if a.dtype != b.dtype:
+        b = b.type(a.dtype)
+    if a.layout == b.layout == torch.sparse_coo:
+        return _tensors_match(a.values(), b.values(), exact) and _tensors_match(
+            a.indices(), b.indices(), exact
+        )
+    elif a.layout == b.layout == torch.sparse_csr:
+        return (
+            _tensors_match(a.crow_indices(), b.crow_indices(), exact)
+            and _tensors_match(a.col_indices(), b.col_indices(), exact)
+            and _tensors_match(a.values(), b.values(), exact)
+        )
+    if exact:
+        return (a.dim() == b.dim()) and torch.eq(a, b).all().item()
+    return (a.dim() == b.dim()) and torch.allclose(a, b, rtol=rtol, atol=atol)
+
+
+def _masks_match(a, b):
+    if is_masked_tensor(a) and is_masked_tensor(b):
+        mask_a = a.get_mask()
+        mask_b = b.get_mask()
+        return _tensors_match(mask_a, mask_b, exact=True)
+    return True
+
+
+def _map_mt_args_kwargs(args, kwargs, map_fn):
+    def _helper(a, map_fn):
+        if is_masked_tensor(a):
+            return map_fn(a)
+        elif torch.is_tensor(a):
+            return a
+        elif isinstance(a, list):
+            a_impl, _ = _map_mt_args_kwargs(a, {}, map_fn)
+            return a_impl
+        elif isinstance(a, tuple):
+            a_impl, _ = _map_mt_args_kwargs(a, {}, map_fn)
+            return tuple(a_impl)
+        else:
+            return a
+
+    if kwargs is None:
+        kwargs = {}
+    impl_args = []
+    for a in args:
+        impl_args.append(_helper(a, map_fn))
+    impl_kwargs = {}
+    for k, v in kwargs.items():
+        impl_kwargs[k] = _helper(v, map_fn)
+    return impl_args, impl_kwargs
+
+
+def _wrap_result(result_data, result_mask):
+    if isinstance(result_data, list):
+        return [_wrap_result(r, m) for (r, m) in zip(result_data, result_mask)]
+    if isinstance(result_data, tuple):
+        return tuple(_wrap_result(r, m) for (r, m) in zip(result_data, result_mask))
+    if torch.is_tensor(result_data):
+        return MaskedTensor(result_data, result_mask)
+    # Expect result_data and result_mask to be Tensors only
+    return NotImplemented
+
+
+def _masked_tensor_str(data, mask, formatter):
+    if data.layout in {torch.sparse_coo, torch.sparse_csr}:
+        data = data.to_dense()
+        mask = mask.to_dense()
+    if data.dim() == 1:
+        formatted_elements = [
+            formatter.format(d.item()) if isinstance(d.item(), float) else str(d.item())
+            for d in data
+        ]
+        max_len = max(8 if x[1] else len(x[0]) for x in zip(formatted_elements, ~mask))
+        return (
+            "["
+            + ", ".join(
+                [
+                    "--".rjust(max_len) if m else e
+                    for (e, m) in zip(formatted_elements, ~mask)
+                ]
+            )
+            + "]"
+        )
+    sub_strings = [_masked_tensor_str(d, m, formatter) for (d, m) in zip(data, mask)]
+    sub_strings = ["\n".join(["  " + si for si in s.split("\n")]) for s in sub_strings]
+    return "[\n" + ",\n".join(sub_strings) + "\n]"
+
+
+def _get_data(a):
+    if is_masked_tensor(a):
+        return a._masked_data
+    return a
+
+
+def _maybe_get_mask(a):
+    if is_masked_tensor(a):
+        return a.get_mask()
+    return None
+
+
+class MaskedTensor(torch.Tensor):
+    @staticmethod
+    def __new__(cls, data, mask, requires_grad=False):
+        if is_masked_tensor(data) or not torch.is_tensor(data):
+            raise TypeError("data must be a Tensor")
+        if is_masked_tensor(mask) or not torch.is_tensor(mask):
+            raise TypeError("mask must be a Tensor")
+        # Use a Tensor of the given size for the wrapper.
+        kwargs = {
+            "device": data.device,
+            "dtype": data.dtype,
+            "layout": data.layout,
+            "requires_grad": requires_grad,
+            "dispatch_sizes_strides_policy": "strides",
+            "dispatch_layout": True,
+        }
+        warnings.warn(
+            (
+                "The PyTorch API of MaskedTensors is in prototype stage "
+                "and will change in the near future. Please open a Github issue "
+                "for features requests and see our documentation on the torch.masked "
+                "module for further information about the project."
+            ),
+            UserWarning,
+            stacklevel=2,
+        )
+        if data.requires_grad:
+            warnings.warn(
+                "It is not recommended to create a MaskedTensor with a tensor that requires_grad. "
+                "To avoid this, you can use data.detach().clone()",
+                UserWarning,
+                stacklevel=2,
+            )
+        return torch.Tensor._make_wrapper_subclass(cls, data.size(), **kwargs)
+
+    def _preprocess_data(self, data, mask):
+        from .._ops import _sparse_coo_where, _sparse_csr_where
+
+        if data.layout != mask.layout:
+            raise TypeError("data and mask must have the same layout.")
+        if data.layout == torch.sparse_coo:
+            data = data.coalesce()
+            mask = mask.coalesce()
+            if data._nnz() != mask._nnz():
+                data = _sparse_coo_where(mask, data, torch.tensor(0))
+        elif data.layout == torch.sparse_csr:
+            if data._nnz() != mask._nnz():
+                data = _sparse_csr_where(mask, data, torch.tensor(0))
+
+        # Have to pick awkward names to not conflict with existing fields such as data
+        self._masked_data = data.clone()
+        self._masked_mask = mask.clone()
+
+    def _validate_members(self):
+        data = self._masked_data
+        mask = self.get_mask()
+        if type(data) != type(mask):
+            raise TypeError(
+                f"data and mask must have the same type. Got {type(data)} and {type(mask)}"
+            )
+        if data.layout not in {torch.strided, torch.sparse_coo, torch.sparse_csr}:
+            raise TypeError(f"data layout of {data.layout} is not supported.")
+        if data.layout == torch.sparse_coo:
+            if not _tensors_match(data.indices(), mask.indices(), exact=True):
+                raise ValueError(
+                    "data and mask are both sparse COO tensors but do not have the same indices."
+                )
+        elif data.layout == torch.sparse_csr:
+            if not _tensors_match(
+                data.crow_indices(), mask.crow_indices(), exact=True
+            ) or not _tensors_match(data.col_indices(), mask.col_indices(), exact=True):
+                raise ValueError(
+                    "data and mask are both sparse CSR tensors but do not share either crow or col indices."
+                )
+        if mask.dtype != torch.bool:
+            raise TypeError("mask must have dtype bool.")
+        if not (
+            data.dtype == torch.float16
+            or data.dtype == torch.float32
+            or data.dtype == torch.float64
+            or data.dtype == torch.bool
+            or data.dtype == torch.int8
+            or data.dtype == torch.int16
+            or data.dtype == torch.int32
+            or data.dtype == torch.int64
+        ):
+            raise TypeError(f"{data.dtype} is not supported in MaskedTensor.")
+        if data.dim() != mask.dim():
+            raise ValueError("data.dim() must equal mask.dim()")
+        if data.size() != mask.size():
+            raise ValueError("data.size() must equal mask.size()")
+
+    def __init__(self, data, mask, requires_grad=False):
+        self._preprocess_data(data, mask)
+        self._validate_members()
+
+    @staticmethod
+    def _from_values(data, mask):
+        """Differentiable constructor for MaskedTensor"""
+
+        class Constructor(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, data, mask):
+                return MaskedTensor(data, mask)
+
+            @staticmethod
+            def backward(ctx, grad_output):
+                return grad_output, None
+
+        result = Constructor.apply(data, mask)
+        return result
+
+    def _set_data_mask(self, data, mask):
+        self._masked_data = data
+        self._masked_mask = mask
+        self._validate_members()
+
+    def __repr__(self):  # type: ignore[override]
+        formatter = "{0:8.4f}"
+        if self.dim() == 0:
+            scalar_data = self.get_data().item()
+            data_formatted = (
+                formatter.format(scalar_data)
+                if isinstance(scalar_data, float)
+                else str(scalar_data)
+            )
+            if not self.get_mask().item():
+                data_formatted = "--"
+            return (
+                "MaskedTensor("
+                + data_formatted
+                + ", "
+                + str(self.get_mask().item())
+                + ")"
+            )
+        s = _masked_tensor_str(self.get_data(), self.get_mask(), formatter)
+        s = "\n".join("  " + si for si in s.split("\n"))
+        return "MaskedTensor(\n" + s + "\n)"
+
+    # Seems like this needs to be defined before torch_dispatch to work
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        kwargs = kwargs or {}
+
+        from ._ops_refs import _MASKEDTENSOR_FUNCTION_TABLE
+
+        if func in _MASKEDTENSOR_FUNCTION_TABLE:
+            return _MASKEDTENSOR_FUNCTION_TABLE[func](*args, **kwargs)
+
+        if not all(issubclass(cls, t) for t in types):
+            return NotImplemented
+        with torch._C.DisableTorchFunctionSubclass():
+            ret = func(*args, **kwargs)
+            if func in get_default_nowrap_functions():
+                return ret
+            else:
+                return torch._tensor._convert(ret, cls)
+
+    @classmethod
+    def unary(cls, fn, data, mask):
+        return MaskedTensor(fn(data), mask)
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):  # type: ignore[override]
+        func = func.overloadpacket
+
+        from ._ops_refs import _MASKEDTENSOR_DISPATCH_TABLE
+
+        if func in _MASKEDTENSOR_DISPATCH_TABLE:
+            return _MASKEDTENSOR_DISPATCH_TABLE[func](*args, **kwargs)
+
+        msg = (
+            f"{func.__name__} is not implemented in __torch_dispatch__ for MaskedTensor.\n"
+            "If you would like this operator to be supported, please file an issue for a feature request at "
+            "https://github.com/pytorch/maskedtensor/issues with a minimal reproducible code snippet.\n"
+            "In the case that the semantics for the operator are not trivial, it would be appreciated "
+            "to also include a proposal for the semantics."
+        )
+        warnings.warn(msg)
+        return NotImplemented
+
+    def __lt__(self, other):
+        if is_masked_tensor(other):
+            return MaskedTensor(self.get_data() < _get_data(other), self.get_mask())
+        return MaskedTensor(self.get_data() < other, self.get_mask())
+
+    def to_tensor(self, value):
+        return self.get_data().masked_fill(~self.get_mask(), value)
+
+    def get_data(self):
+        class GetData(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, self):
+                return self._masked_data.detach()
+
+            @staticmethod
+            def backward(ctx, grad_output):
+                if is_masked_tensor(grad_output):
+                    return grad_output
+                return MaskedTensor(grad_output, self.get_mask())
+
+        return GetData.apply(self)
+
+    def get_mask(self):
+        return self._masked_mask
+
+    def is_sparse_coo(self):
+        return self.layout == torch.sparse_coo
+
+    def is_sparse_csr(self):  # type: ignore[override]
+        return self.layout == torch.sparse_csr
+
+    # Update later to support more sparse layouts
+    @property
+    def is_sparse(self):  # type: ignore[override]
+        return self.is_sparse_coo() or self.is_sparse_csr()
diff --git a/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/creation.py b/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/creation.py
new file mode 100644
index 0000000000000000000000000000000000000000..35c8e3d2aa9438dbcfc7995a1cdcd3c5cc8dc1fc
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/creation.py
@@ -0,0 +1,24 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+
+from .core import MaskedTensor
+
+
+__all__ = [
+    "as_masked_tensor",
+    "masked_tensor",
+]
+
+
+# These two factory functions are intended to mirror
+#     torch.tensor - guaranteed to be a leaf node
+#     torch.as_tensor - differentiable constructor that preserves the autograd history
+
+
+def masked_tensor(
+    data: object, mask: object, requires_grad: bool = False
+) -> MaskedTensor:
+    return MaskedTensor(data, mask, requires_grad)
+
+
+def as_masked_tensor(data: object, mask: object) -> MaskedTensor:
+    return MaskedTensor._from_values(data, mask)
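+
+
+# A minimal usage sketch (illustrative only, assuming `import torch`):
+#     >>> data = torch.arange(3, dtype=torch.float)
+#     >>> mask = torch.tensor([True, False, True])
+#     >>> mt = masked_tensor(data, mask, requires_grad=True)  # leaf node, like torch.tensor
+#     >>> mt2 = as_masked_tensor(data, mask)  # differentiable constructor, like torch.as_tensor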
diff --git a/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/passthrough.py b/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/passthrough.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba13f50c1fee9c9fc10563ffc9f4ff3211c0dca6
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/passthrough.py
@@ -0,0 +1,50 @@
+# mypy: allow-untyped-defs
+# Copyright (c) Meta Platforms, Inc. and affiliates
+"""
+These are functions that should simply be applied to both mask and data.
+Take select or stack as an example: the operation is applied to both the
+mask and the data of a MaskedTensor, and the two results are wrapped into
+a new MaskedTensor.
+"""
+
+import torch
+
+from .core import _map_mt_args_kwargs, _wrap_result
+
+
+__all__ = []  # type: ignore[var-annotated]
+
+
+PASSTHROUGH_FNS = [
+    torch.ops.aten.select,
+    torch.ops.aten.transpose,
+    torch.ops.aten.split,
+    torch.ops.aten.t,
+    torch.ops.aten.slice,
+    torch.ops.aten.slice_backward,
+    torch.ops.aten.select_backward,
+    torch.ops.aten.index,
+    torch.ops.aten.expand,
+    torch.ops.aten.view,
+    torch.ops.aten._unsafe_view,
+    torch.ops.aten._reshape_alias,
+    torch.ops.aten.cat,
+    torch.ops.aten.unsqueeze,
+    torch.ops.aten.unfold,
+    torch.ops.aten.unfold_backward,
+    torch.ops.aten.im2col,
+    torch.ops.aten.col2im,
+    torch.ops.aten.stack,
+]
+
+
+def _is_pass_through_fn(fn):
+    return fn in PASSTHROUGH_FNS
+
+
+def _apply_pass_through_fn(fn, *args, **kwargs):
+    data_args, data_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_data())
+    result_data = fn(*data_args, **data_kwargs)
+    mask_args, mask_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_mask())
+    result_mask = fn(*mask_args, **mask_kwargs)
+    return _wrap_result(result_data, result_mask)
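+
+
+# Illustrative sketch (assumption): for a MaskedTensor, a pass-through op such as
+# unsqueeze is run once on the data and once on the mask via _apply_pass_through_fn,
+# and the two results are re-wrapped into a new MaskedTensor with matching shapes.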
diff --git a/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/reductions.py b/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/reductions.py
new file mode 100644
index 0000000000000000000000000000000000000000..fedab1c12a63786fe950edda37ced9637beed83c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/reductions.py
@@ -0,0 +1,176 @@
+# mypy: allow-untyped-defs
+# Copyright (c) Meta Platforms, Inc. and affiliates
+
+import warnings
+
+import torch
+
+from .core import is_masked_tensor
+from .creation import as_masked_tensor, masked_tensor
+
+
+__all__ = []  # type: ignore[var-annotated]
+
+
+def _masked_all_all(data, mask=None):
+    if mask is None:
+        return data.all()
+    return data.masked_fill(~mask, True).all()
+
+
+def _masked_all_dim(data, dim, keepdim=False, mask=None):
+    if mask is None:
+        return torch.all(data, dim=dim, keepdim=keepdim)
+    return torch.all(data.masked_fill(~mask, True), dim=dim, keepdim=keepdim)
+
+
+def _masked_all(*args, **kwargs):
+    if len(args) == 1 and len(kwargs) == 1:
+        return _masked_all_all(args[0], mask=kwargs["mask"])
+    return _masked_all_dim(*args, **kwargs)
+
+
+def _multidim_any(mask, dim, keepdim):
+    if isinstance(dim, int):
+        return _multidim_any(mask, [dim], keepdim)
+    for d in sorted(dim, reverse=True):
+        mask = torch.any(mask, dim=d, keepdim=keepdim)
+    return mask
+
+
+def _get_masked_fn(fn):
+    if fn == "all":
+        return _masked_all
+    return getattr(torch.masked, fn)
+
+
+def _torch_reduce_all(fn):
+    def reduce_all(self):
+        masked_fn = _get_masked_fn(fn)
+        data = self.get_data()
+        mask = self.get_mask().values() if self.is_sparse else self.get_mask()
+        # When reduction is "all", then torch.argmin/torch.argmax needs to return the index of the
+        # element corresponding to the min/max, but this operation isn't supported correctly for sparse layouts.
+        # Therefore, this implementation calculates it using the strides.
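+        # For example (illustrative): for a dense size of (2, 3), numel() / cumprod(size)
+        # gives strides (3., 1.), so the flat index is idx[0] * 3 + idx[1] * 1.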
+        if fn == "all":
+            result_data = masked_fn(data, mask=mask)
+
+        elif fn in {"argmin", "argmax"} and self.is_sparse_coo():
+            sparse_idx = masked_fn(data.values(), mask=mask).to(dtype=torch.int)
+            indices = (
+                data.to_sparse_coo().indices()
+                if not self.is_sparse_coo()
+                else data.indices()
+            )
+            idx = indices.unbind(1)[sparse_idx]
+            stride = data.size().numel() / torch.tensor(
+                data.size(), device=data.device
+            ).cumprod(0)
+            result_data = torch.sum(idx * stride)
+
+        # we simply pass in the values for sparse COO/CSR tensors
+        elif self.is_sparse:
+            result_data = masked_fn(masked_tensor(data.values(), mask))
+
+        else:
+            result_data = masked_fn(self, mask=mask)
+
+        return as_masked_tensor(result_data, torch.any(mask))
+
+    return reduce_all
+
+
+def _torch_reduce_dim(fn):
+    def reduce_dim(self, dim, keepdim=False, dtype=None):
+        if self.is_sparse:
+            msg = (
+                f"The sparse version of {fn} is not implemented in reductions.\n"
+                "If you would like this operator to be supported, please file an issue for a feature request at "
+                "https://github.com/pytorch/maskedtensor/issues with a minimal reproducible code snippet.\n"
+                "In the case that the semantics for the operator are not trivial, it would be appreciated "
+                "to also include a proposal for the semantics."
+            )
+            warnings.warn(msg)
+            return NotImplemented
+        if not is_masked_tensor(self):
+            raise TypeError("Input to reduce_dim must be a MaskedTensor")
+
+        masked_fn = _get_masked_fn(fn)
+        data = self.get_data()
+        mask = self.get_mask()
+        if fn == "all":
+            result_data = masked_fn(data, dim=dim, keepdim=keepdim, mask=mask)
+        else:
+            result_data = masked_fn(
+                self, dim=dim, keepdim=keepdim, dtype=dtype, mask=self.get_mask()
+            )
+        return as_masked_tensor(result_data, _multidim_any(mask, dim, keepdim))
+
+    return reduce_dim
+
+
+def _torch_reduce(fn):
+    def reduce_fn(*args, **kwargs):
+        if len(args) == 1 and len(kwargs) == 0:
+            return _torch_reduce_all(fn)(args[0])
+        return _torch_reduce_dim(fn)(*args, **kwargs)
+
+    return reduce_fn
+
+
+def _reduce_dim_args(input, dim, keepdim=False, dtype=None):
+    return input, dim, keepdim, dtype
+
+
+def _torch_grad_reduce(fn):
+    def grad_reduce(*args, **kwargs):
+        if len(args) == 1 and len(kwargs) == 0:
+            return _torch_reduce_all(fn)(args[0])
+        # TODO: autograd.Function doesn't support kwarg
+        input, dim, keepdim, dtype = _reduce_dim_args(*args, **kwargs)
+        return _torch_reduce_dim(fn)(input, dim, keepdim, dtype)
+
+    return grad_reduce
+
+
+REDUCE_NAMES = [
+    "sum",
+    "mean",
+    "amin",
+    "amax",
+    "argmin",
+    "argmax",
+    "prod",
+    "all",
+    "norm",
+    "var",
+    "std",
+]
+
+NATIVE_REDUCE_MAP = {
+    getattr(torch.ops.aten, name): _torch_reduce(name) for name in REDUCE_NAMES
+}
+TORCH_REDUCE_MAP = {
+    getattr(torch, name): _torch_grad_reduce(name) for name in REDUCE_NAMES
+}
+TENSOR_REDUCE_MAP = {
+    getattr(torch.Tensor, name): _torch_grad_reduce(name) for name in REDUCE_NAMES
+}
+
+NATIVE_REDUCE_FNS = list(NATIVE_REDUCE_MAP.keys())
+TORCH_REDUCE_FNS = list(TORCH_REDUCE_MAP.keys())
+TENSOR_REDUCE_FNS = list(TENSOR_REDUCE_MAP.keys())
+
+
+def _is_reduction(fn):
+    return fn in NATIVE_REDUCE_MAP or fn in TORCH_REDUCE_MAP or fn in TENSOR_REDUCE_MAP
+
+
+def _apply_reduction(fn, *args, **kwargs):
+    if fn in NATIVE_REDUCE_MAP:
+        return NATIVE_REDUCE_MAP[fn](*args, **kwargs)
+    if fn in TORCH_REDUCE_MAP:
+        return TORCH_REDUCE_MAP[fn](*args, **kwargs)
+    if fn in TENSOR_REDUCE_MAP:
+        return TENSOR_REDUCE_MAP[fn](*args, **kwargs)
+    return NotImplemented
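+
+
+# Illustrative sketch (assumption): a full reduction on a MaskedTensor only considers
+# the unmasked elements, e.g. summing data [1., 2., 4.] with mask [True, False, True]
+# through the tables above yields 5. wrapped with a scalar mask of True.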
diff --git a/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/unary.py b/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/unary.py
new file mode 100644
index 0000000000000000000000000000000000000000..e04ee6e810a7418829b68323097612391017b14e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/masked/maskedtensor/unary.py
@@ -0,0 +1,194 @@
+# mypy: allow-untyped-defs
+# Copyright (c) Meta Platforms, Inc. and affiliates
+
+import torch
+
+from .core import _map_mt_args_kwargs, _wrap_result
+
+
+__all__ = []  # type: ignore[var-annotated]
+
+
+UNARY_NAMES = [
+    "abs",
+    "absolute",
+    "acos",
+    "arccos",
+    "acosh",
+    "arccosh",
+    "angle",
+    "asin",
+    "arcsin",
+    "asinh",
+    "arcsinh",
+    "atan",
+    "arctan",
+    "atanh",
+    "arctanh",
+    "bitwise_not",
+    "ceil",
+    "clamp",
+    "clip",
+    "conj_physical",
+    "cos",
+    "cosh",
+    "deg2rad",
+    "digamma",
+    "erf",
+    "erfc",
+    "erfinv",
+    "exp",
+    "exp2",
+    "expm1",
+    "fix",
+    "floor",
+    "frac",
+    "lgamma",
+    "log",
+    "log10",
+    "log1p",
+    "log2",
+    "logit",
+    "i0",
+    "isnan",
+    "nan_to_num",
+    "neg",
+    "negative",
+    "positive",
+    "pow",
+    "rad2deg",
+    "reciprocal",
+    "round",
+    "rsqrt",
+    "sigmoid",
+    "sign",
+    "sgn",
+    "signbit",
+    "sin",
+    "sinc",
+    "sinh",
+    "sqrt",
+    "square",
+    "tan",
+    "tanh",
+    "trunc",
+]
+
+INPLACE_UNARY_NAMES = [
+    n + "_"
+    for n in (list(set(UNARY_NAMES) - {"angle", "positive", "signbit", "isnan"}))
+]
+
+# Explicitly tracking functions we know are currently not supported
+# This might be due to missing code gen or because of complex semantics
+UNARY_NAMES_UNSUPPORTED = [
+    "atan2",
+    "arctan2",
+    "bitwise_left_shift",
+    "bitwise_right_shift",
+    "copysign",
+    "float_power",
+    "fmod",
+    "frexp",
+    "gradient",
+    "imag",
+    "ldexp",
+    "lerp",
+    "logical_not",
+    "hypot",
+    "igamma",
+    "igammac",
+    "mvlgamma",
+    "nextafter",
+    "polygamma",
+    "real",
+    "remainder",
+    "true_divide",
+    "xlogy",
+]
+
+
+def _unary_helper(fn, args, kwargs, inplace):
+    if len(kwargs) != 0:
+        raise ValueError(
+            "MaskedTensor unary ops require that len(kwargs) == 0. "
+            "If you need support for this, please open an issue on Github."
+        )
+    for a in args[1:]:
+        if torch.is_tensor(a):
+            raise TypeError(
+                "MaskedTensor unary ops do not support additional Tensor arguments"
+            )
+
+    mask_args, _mask_kwargs = _map_mt_args_kwargs(
+        args, kwargs, lambda x: x._masked_mask
+    )
+    data_args, _data_kwargs = _map_mt_args_kwargs(
+        args, kwargs, lambda x: x._masked_data
+    )
+
+    if args[0].layout == torch.sparse_coo:
+        data_args[0] = data_args[0].coalesce()
+        s = data_args[0].size()
+        i = data_args[0].indices()
+        data_args[0] = data_args[0].coalesce().values()
+        v = fn(*data_args)
+        result_data = torch.sparse_coo_tensor(i, v, size=s)
+
+    elif args[0].layout == torch.sparse_csr:
+        crow = data_args[0].crow_indices()
+        col = data_args[0].col_indices()
+        data_args[0] = data_args[0].values()
+        v = fn(*data_args)
+        result_data = torch.sparse_csr_tensor(crow, col, v)
+
+    else:
+        result_data = fn(*data_args)
+
+    if inplace:
+        args[0]._set_data_mask(result_data, mask_args[0])
+        return args[0]
+    else:
+        return _wrap_result(result_data, mask_args[0])
+
+
+def _torch_unary(fn_name):
+    fn = getattr(torch.ops.aten, fn_name)
+
+    def unary_fn(*args, **kwargs):
+        return _unary_helper(fn, args, kwargs, inplace=False)
+
+    return unary_fn
+
+
+def _torch_inplace_unary(fn_name):
+    fn = getattr(torch.ops.aten, fn_name)
+
+    def unary_fn(*args, **kwargs):
+        return _unary_helper(fn, args, kwargs, inplace=True)
+
+    return unary_fn
+
+
+NATIVE_UNARY_MAP = {
+    getattr(torch.ops.aten, name): _torch_unary(name) for name in UNARY_NAMES
+}
+NATIVE_INPLACE_UNARY_MAP = {
+    getattr(torch.ops.aten, name): _torch_inplace_unary(name)
+    for name in INPLACE_UNARY_NAMES
+}
+
+NATIVE_UNARY_FNS = list(NATIVE_UNARY_MAP.keys())
+NATIVE_INPLACE_UNARY_FNS = list(NATIVE_INPLACE_UNARY_MAP.keys())
+
+
+def _is_native_unary(fn):
+    return fn in NATIVE_UNARY_FNS or fn in NATIVE_INPLACE_UNARY_FNS
+
+
+def _apply_native_unary(fn, *args, **kwargs):
+    if fn in NATIVE_UNARY_FNS:
+        return NATIVE_UNARY_MAP[fn](*args, **kwargs)
+    if fn in NATIVE_INPLACE_UNARY_FNS:
+        return NATIVE_INPLACE_UNARY_MAP[fn](*args, **kwargs)
+    return NotImplemented
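+
+
+# Illustrative sketch (assumption): a registered unary op such as aten.abs is applied to
+# the underlying data only (with sparse values handled specially in _unary_helper), and
+# the mask is carried over unchanged to the resulting MaskedTensor.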
diff --git a/.venv/lib/python3.12/site-packages/torch/nested/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/nested/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..50e5970a9c1969e696cba6ed00c6591337611599
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/nested/__pycache__/__init__.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/nested/_internal/__init__.py b/.venv/lib/python3.12/site-packages/torch/nested/_internal/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.12/site-packages/torch/nested/_internal/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/nested/_internal/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..311d637e1897b89faa6743d9c583a4b51a8d781c
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/nested/_internal/__pycache__/__init__.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/nested/_internal/__pycache__/nested_int.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/nested/_internal/__pycache__/nested_int.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..23b8838ee36c8316537cb530633656afc36e07d8
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/nested/_internal/__pycache__/nested_int.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/nested/_internal/__pycache__/nested_tensor.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/nested/_internal/__pycache__/nested_tensor.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5a756ccdd046f113f359888423b1abd674889bba
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/nested/_internal/__pycache__/nested_tensor.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/nested/_internal/nested_int.py b/.venv/lib/python3.12/site-packages/torch/nested/_internal/nested_int.py
new file mode 100644
index 0000000000000000000000000000000000000000..59090b331d5010568b07151a9db2cafef324e149
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/nested/_internal/nested_int.py
@@ -0,0 +1,116 @@
+from typing import *  # noqa: F403
+
+import torch
+from torch.fx.experimental._constant_symnode import ConstantIntNode
+
+
+__all__ = ["NestedIntNode"]
+
+
+# Python version of aten/src/ATen/core/NestedIntSymNodeImpl.cpp
+def _eq(lhs: Any, rhs: Any) -> bool:
+    return (
+        isinstance(lhs, NestedIntNode)
+        and isinstance(rhs, NestedIntNode)
+        and lhs.t_id == rhs.t_id
+        and lhs.coeff == rhs.coeff
+    )
+
+
+def _ge(lhs: Any, rhs: Any) -> bool:
+    if isinstance(rhs, NestedIntNode) and isinstance(lhs, NestedIntNode):
+        if lhs.t_id == rhs.t_id:
+            return lhs.coeff >= rhs.coeff
+        raise ValueError("ge: relation is indeterminate")
+    elif isinstance(lhs, NestedIntNode):
+        if rhs.is_constant() and rhs.constant_int() <= 2:
+            return True
+        raise ValueError("ge: relation is indeterminate")
+    elif isinstance(rhs, NestedIntNode):
+        if lhs.is_constant() and lhs.constant_int() < 2:
+            return False
+        raise ValueError("ge: relation is indeterminate")
+    else:
+        raise ValueError("inputs unsupported")
+
+
+class NestedIntNode:
+    def __init__(self, t_id: int, coeff: int):
+        self.t_id = t_id
+        self.coeff = coeff
+
+    def nested_int_coeff(self) -> int:
+        return self.coeff
+
+    def maybe_as_int(self) -> Optional[int]:
+        return None
+
+    def is_int(self) -> bool:
+        return True
+
+    def is_float(self) -> bool:
+        return False
+
+    def is_bool(self) -> bool:
+        return False
+
+    def is_nested_int(self) -> bool:
+        return True
+
+    def clone(self) -> "NestedIntNode":
+        return self
+
+    def _str(self) -> Any:
+        if self.coeff == 1:
+            return f"j{self.t_id}"
+        return f"{self.coeff}*j{self.t_id}"
+
+    def str(self) -> Any:
+        return self._str()
+
+    def __str__(self) -> Any:
+        return self._str()
+
+    def __repr__(self) -> Any:
+        return self._str()
+
+    def _graph_repr(self) -> Any:
+        return self._str()
+
+    def mul(self, other: Any) -> "NestedIntNode":
+        if other.is_constant():
+            other = other.constant_int()
+        else:
+            raise ValueError(f"unsupported: {type(other)}")
+        return NestedIntNode(self.t_id, self.coeff * other)
+
+    def eq(self, other: Any) -> Any:
+        return torch._C._get_constant_bool_symnode(_eq(self, other))
+
+    def ne(self, other: Any) -> Any:
+        return torch._C._get_constant_bool_symnode(not _eq(self, other))
+
+    def gt(self, other: Any) -> Any:
+        return torch._C._get_constant_bool_symnode(not _ge(other, self))
+
+    def lt(self, other: Any) -> Any:
+        return torch._C._get_constant_bool_symnode(not _ge(self, other))
+
+    def le(self, other: Any) -> Any:
+        return torch._C._get_constant_bool_symnode(_ge(other, self))
+
+    def ge(self, other: Any) -> Any:
+        return torch._C._get_constant_bool_symnode(_ge(self, other))
+
+    def is_symbolic(self) -> bool:
+        return False
+
+    def nested_int(self) -> int:
+        return self.t_id
+
+    def is_constant(self) -> bool:
+        return False
+
+    def wrap_int(self, num: int) -> ConstantIntNode:
+        assert type(num) is int
+        return ConstantIntNode(num)
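+
+
+# Illustrative semantics (as encoded by _eq/_ge above): two nested ints compare equal only
+# when they share the same t_id and coeff; comparisons against constants rely on the
+# invariant that a nested int is at least 2 (so j0 >= 2 is True, while j0 >= 3 raises
+# "relation is indeterminate").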
diff --git a/.venv/lib/python3.12/site-packages/torch/nested/_internal/nested_tensor.py b/.venv/lib/python3.12/site-packages/torch/nested/_internal/nested_tensor.py
new file mode 100644
index 0000000000000000000000000000000000000000..14e71c506385eb3c827c029bd00b9986c19badc5
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/nested/_internal/nested_tensor.py
@@ -0,0 +1,665 @@
+# mypy: allow-untyped-defs
+from typing import *  # noqa: F403
+
+import torch
+from torch._C import DispatchKey, DispatchKeySet
+from torch._prims_common import is_expandable_to
+from torch.nested._internal.nested_int import NestedIntNode
+from torch.utils.weak import WeakTensorKeyDictionary
+
+
+_tensor_id_counter = 0
+_tensor_symint_registry = WeakTensorKeyDictionary()
+
+
+def get_tensor_symint(tensor, *, coeff=1):
+    from torch._subclasses.fake_tensor import FakeTensor
+    from torch._subclasses.functional_tensor import mb_unwrap_functional_tensor
+
+    # NB: Only FakeTensor is associated with a memo
+    tensor = mb_unwrap_functional_tensor(tensor)
+    if isinstance(tensor, FakeTensor):
+        return tensor.get_nested_int(coeff=coeff)
+
+    global _tensor_id_counter
+
+    tensor_symint = _tensor_symint_registry.get(tensor)
+    if tensor_symint is None:
+        tensor_symint = torch.SymInt(NestedIntNode(_tensor_id_counter, coeff))
+        _tensor_id_counter += 1
+        _tensor_symint_registry[tensor] = tensor_symint
+    return tensor_symint
+
+
+# SDPA metadata; max / min seqlens are needed for e.g. flash
+def _get_sdpa_extreme_seqlen(func, tensor):
+    return int(func(tensor).item())
+
+
+def _store_val_in_tensor(val) -> torch.Tensor:
+    # hack to get dynamic shapes support: store in a (val, 0) shaped tensor
+    return torch.zeros(val, 0)
+
+
+def _load_val_from_tensor(t: torch.Tensor):
+    return t.shape[0]
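+
+
+# Illustrative round trip (assumption): _store_val_in_tensor(5) produces a tensor of
+# shape (5, 0), and _load_val_from_tensor recovers 5 from shape[0].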
+
+
+# serialization function must be defined at top level
+def _rebuild_njt(constructor_kwargs):
+    return NestedTensor(**constructor_kwargs)
+
+
+class NestedTensor(torch.Tensor):
+    _values: torch.Tensor  # type: ignore[assignment]
+    _offsets: torch.Tensor
+    _lengths: Optional[torch.Tensor]
+    # NOTE [ Nested ints for ragged sizes and strides ]
+    #
+    # Jagged layout tensors are tensors that represent an n-dim tensor with a
+    # ragged dimension, but are backed by an (n-1)-dim tensor underneath, e.g.,
+    # a jagged tensor with outer shape [B, x, D] is represented internally by a
+    # tensor with shape [sum(x), D]. Here we introduce what we call a nested int,
+    # denoted "x" (sometimes written "*"), to represent the ragged dimension;
+    # sum(x) is the size of the inner tensor's packed dimension, or equivalently
+    # the sum of the varying lengths of the constituent tensors.
+    #
+    # We also use nested ints to represent the strides of this tensor.
+    # For example, a jagged tensor with shape [B, x, D] can be strided in two
+    # ways: [xD, D, 1] and [x, 1, sum(x)], where xD represents x multiplied by D
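+    # For example (illustrative): three sequences of lengths 2, 3, and 1 with D = 4 are
+    # packed into a values tensor of shape [6, 4] with offsets [0, 2, 5, 6], and the
+    # jagged tensor reports its outer shape as [3, x, 4].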
+    _size: tuple[int, ...]
+    _strides: tuple[int, ...]
+    # Indicates that the nth dimension is ragged
+    _ragged_idx: int
+    _metadata_cache: Dict[str, Any]
+
+    @staticmethod
+    def __new__(
+        cls,
+        values,
+        offsets,
+        *,
+        lengths=None,
+        **kwargs,
+    ):
+        ks = DispatchKeySet(DispatchKey.NestedTensor)
+        ks = ks.add(DispatchKey.AutogradNestedTensor)
+
+        # Only support jagged for now.
+        assert offsets is not None
+        assert offsets.ndim == 1
+        assert not isinstance(values, NestedTensor)
+        assert values.device == offsets.device
+
+        # Query cache for the symint associated with offsets or lengths
+        # (create a new one if needed).
+        ragged_source = offsets if lengths is None else lengths
+        ragged_size = get_tensor_symint(ragged_source, coeff=1)
+        _ragged_idx = kwargs.get("_ragged_idx", 1)
+        B = offsets.shape[0] - 1
+        if lengths is not None:
+            assert B == lengths.shape[0]
+
+        # subtract 1 to convert to values dim space
+        r = _ragged_idx - 1
+        _size = (B, *values.shape[:r], ragged_size, *values.shape[r + 1 :])
+        stride = values.stride()
+        _strides = (ragged_size * stride[r], *stride)
+
+        r = torch.Tensor._make_wrapper_subclass(
+            cls,
+            _size,
+            _strides,
+            0,
+            torch.contiguous_format,
+            values.dtype,
+            torch.jagged,
+            values.device,
+            False,
+            kwargs.get("requires_grad", False),
+            "sizes",
+            False,
+            True,  # dispatch_layout
+            ks,
+            # don't try to calculate storage based on non-zero size
+            storage_size=values.untyped_storage().size(),
+        )
+        r._ragged_idx = _ragged_idx
+        r._size = _size
+        r._strides = _strides
+
+        return r
+
+    def __init__(self, values, offsets, *, lengths=None, **kwargs):
+        super().__init__()
+
+        self._values = values
+        self._offsets = offsets
+        self._lengths = lengths
+
+        # holds properties that are computed lazily
+        self._metadata_cache = kwargs.get("_metadata_cache") or {}
+
+        # collapsed ragged dim must always be dynamic
+        torch._dynamo.maybe_mark_dynamic(self, self._ragged_idx)
+        torch._dynamo.maybe_mark_dynamic(self._values, self._ragged_idx - 1)
+
+        # min / max sequence length should be dynamic if present
+        max_seqlen_tensor = self._metadata_cache.get("max_seqlen", None)
+        if max_seqlen_tensor is not None:
+            torch._dynamo.mark_dynamic(max_seqlen_tensor, 0)
+        min_seqlen_tensor = self._metadata_cache.get("min_seqlen", None)
+        if min_seqlen_tensor is not None:
+            torch._dynamo.mark_dynamic(min_seqlen_tensor, 0)
+
+    def values(self):
+        # dispatch to get proper view relationship
+        return torch._nested_get_values(self)  # type: ignore[attr-defined]
+
+    def offsets(self):
+        return self._offsets
+
+    def lengths(self):
+        return self._lengths
+
+    # Private accessor functions for min / max sequence length. They're
+    # purposefully not @properties because those don't work with PT2 (yet).
+    # These compute / cache if not present.
+    # TODO: Revisit this when @properties are better supported by PT2. I think the ideal
+    # state would be to have public @properties for min / max sequence length that compile
+    # (including setters).
+    def _get_max_seqlen(self):
+        max_seqlen_tensor = self._max_seqlen_tensor
+        if max_seqlen_tensor is None:
+            # compute & cache
+            max_val = _get_sdpa_extreme_seqlen(
+                torch.max,
+                self._offsets.diff() if self._lengths is None else self._lengths,
+            )
+            max_seqlen_tensor = _store_val_in_tensor(max_val)
+            self._metadata_cache["max_seqlen"] = max_seqlen_tensor
+        return _load_val_from_tensor(max_seqlen_tensor)
+
+    def _get_min_seqlen(self):
+        min_seqlen_tensor = self._min_seqlen_tensor
+        if min_seqlen_tensor is None:
+            # compute & cache
+            min_val = _get_sdpa_extreme_seqlen(
+                torch.min,
+                self._offsets.diff() if self._lengths is None else self._lengths,
+            )
+            min_seqlen_tensor = _store_val_in_tensor(min_val)
+            self._metadata_cache["min_seqlen"] = min_seqlen_tensor
+        return _load_val_from_tensor(min_seqlen_tensor)
+
+    # Private accessors used for treating min / max seqlen as inner tensors for
+    # flatten / unflatten. These must be properties to work with the traceable wrapper
+    # subclass logic. These do not compute / cache if not present.
+    @property
+    def _max_seqlen_tensor(self) -> Optional[torch.Tensor]:
+        return self._metadata_cache.get("max_seqlen", None)
+
+    @_max_seqlen_tensor.setter
+    def _max_seqlen_tensor(self, val: Optional[torch.Tensor]) -> None:
+        self._metadata_cache["max_seqlen"] = val
+
+    @property
+    def _min_seqlen_tensor(self) -> Optional[torch.Tensor]:
+        return self._metadata_cache.get("min_seqlen", None)
+
+    @_min_seqlen_tensor.setter
+    def _min_seqlen_tensor(self, val: Optional[torch.Tensor]) -> None:
+        self._metadata_cache["min_seqlen"] = val
+
+    # These are old private @property accessors that are kept around for internal BC
+    # reasons. TODO: Remove these!
+    @property
+    def _max_seqlen(self):
+        return self._get_max_seqlen()
+
+    @property
+    def _min_seqlen(self):
+        return self._get_min_seqlen()
+
+    # Convenience accessors that return a min / max seqlen if one is present and do NOT
+    # compute / cache them if they're not.
+    @property
+    def _maybe_max_seqlen(self) -> Optional[int]:
+        mt = self._max_seqlen_tensor
+        return None if mt is None else _load_val_from_tensor(mt)
+
+    @property
+    def _maybe_min_seqlen(self) -> Optional[int]:
+        mt = self._min_seqlen_tensor
+        return None if mt is None else _load_val_from_tensor(mt)
+
+    def __repr__(self):  # type: ignore[override]
+        # We should implement this in torch/_tensor_str.py instead
+        grad_fn_str = (
+            f", requires_grad={self.requires_grad}" if self.requires_grad else ""
+        )
+        if self.grad_fn:
+            grad_fn_str = f", grad_fn={self.grad_fn}"
+        return f"NestedTensor(size={self._size}, offsets={self._offsets}{grad_fn_str}, contiguous={self.is_contiguous()})"
+
+    # TODO: Remove this in favor of the default tensor subclass serialization logic.
+    # We don't do this today because of https://github.com/pytorch/pytorch/issues/125622.
+    def __reduce_ex__(self, proto):
+        state = torch._utils._get_obj_state(self)
+
+        # Cached PyCapsules for sizes / strides are not serializable.
+        # See Note [Tensor Subclass custom size/stride caching strategy]
+        self._clear_non_serializable_cached_data()
+        # SymNodes are not serializable
+        assert "_size" in state and "_strides" in state
+        state = dict(state)
+        del state["_size"]
+        del state["_strides"]
+
+        func = _rebuild_njt
+        constructor_kwargs = {
+            "values": self._values,
+            "offsets": self._offsets,
+            "lengths": self._lengths,
+            "_ragged_idx": self._ragged_idx,
+            "_metadata_cache": self._metadata_cache,
+            "requires_grad": self.requires_grad,
+        }
+        args = (constructor_kwargs,)
+        return (torch._tensor._rebuild_from_type_v2, (func, type(self), args, state))
+
+    def __tensor_flatten__(self):
+        ctx = {
+            "requires_grad": self.requires_grad,
+            "ragged_idx": self._ragged_idx,
+        }
+        inner_tensors = ["_values", "_offsets"]
+        if self._lengths is not None:
+            inner_tensors.append("_lengths")
+        if self._min_seqlen_tensor is not None:
+            inner_tensors.append("_min_seqlen_tensor")
+        if self._max_seqlen_tensor is not None:
+            inner_tensors.append("_max_seqlen_tensor")
+        return inner_tensors, ctx
+
+    @staticmethod
+    def __tensor_unflatten__(inner_tensors: Dict, meta, outer_size, outer_stride):
+        from torch._subclasses.fake_tensor import FakeTensor
+
+        # inner tensors: _values, _offsets, [_lengths], [_min_seqlen], [_max_seqlen]
+        assert len(inner_tensors) >= 2 and len(inner_tensors) <= 5
+        values = inner_tensors["_values"]
+        offsets = inner_tensors["_offsets"]
+        lengths = inner_tensors.get("_lengths", None)
+        min_seqlen_tensor = inner_tensors.get("_min_seqlen_tensor", None)
+        max_seqlen_tensor = inner_tensors.get("_max_seqlen_tensor", None)
+
+        metadata_cache = {}
+        if min_seqlen_tensor is not None:
+            metadata_cache["min_seqlen"] = min_seqlen_tensor
+        if max_seqlen_tensor is not None:
+            metadata_cache["max_seqlen"] = max_seqlen_tensor
+        ragged_idx = meta["ragged_idx"]
+
+        # Alternatively, we could make it the caller's responsibility to
+        # cache it. But this heuristic seems simple enough.
+        ragged_source = offsets if lengths is None else lengths
+        if isinstance(ragged_source, FakeTensor):
+            ragged_size = outer_size[ragged_idx]
+            ragged_source.nested_int_memo = ragged_size
+
+        return NestedTensor(
+            values,
+            offsets=offsets,
+            lengths=lengths,
+            requires_grad=meta["requires_grad"],
+            _ragged_idx=ragged_idx,
+            _metadata_cache=metadata_cache,
+        )
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):  # type: ignore[override]
+        # If you're wondering why there's a nested tensor with one of its
+        # size = -1, see note: [NJT outer_size in AOTDispatcher]
+        kwargs = {} if kwargs is None else kwargs
+
+        # Lazy import to avoid circular dependency
+        from .ops import lookup_jagged
+
+        fn = lookup_jagged(func, *args, **kwargs)
+        if fn is not None:
+            return fn(*args, **kwargs)
+
+        # Poor man's redispatch for composite ops. This becomes relevant under inference
+        # mode, where disabling autograd key dispatch prevents decomposition.
+        all_dks = (
+            # We want to handle both the cases where NestedTensor overrides the
+            # composite implicit autograd kernel, and the case where it doesn't.
+            # Prioritize calling into NestedTensor's kernel if it exists.
+            torch._C.DispatchKey.CompositeImplicitAutogradNestedTensor,
+            torch._C.DispatchKey.CompositeImplicitAutograd,
+        )
+        for dk in all_dks:
+            if torch._C._dispatch_has_kernel_for_dispatch_key(func.name(), dk):
+                with torch.overrides.enable_reentrant_dispatch():
+                    return func._op_dk(dk, *args, **kwargs)
+
+        raise NotImplementedError(func)
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+
+        from torch.fx.experimental.proxy_tensor import maybe_enable_thunkify
+
+        from .ops import jagged_torch_function
+
+        # This should be removed after
+        # https://github.com/pytorch/pytorch/pull/125941/ lands
+        with maybe_enable_thunkify():
+            try:
+                return jagged_torch_function(func, *args, **kwargs)
+            except NotImplementedError:
+                pass
+            with torch._C.DisableTorchFunctionSubclass():
+                return func(*args, **kwargs)
+
+
+# NB: These fake view autograd.Functions are superseded by real view ops. Don't use them!
+# TODO: Remove ViewBufferFromNested, ViewNestedFromBuffer, and buffer_from_jagged once the
+# internal BC period has passed.
+
+
+# Not actually a view!
+class ViewBufferFromNested(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x: NestedTensor):  # type: ignore[override]
+        ctx.save_for_backward(x.offsets())
+        ctx.metadata_cache = x._metadata_cache
+        ctx.ragged_idx = x._ragged_idx
+        return x._values
+
+    @staticmethod
+    def backward(ctx, gO: torch.Tensor):  # type: ignore[override]
+        (offsets,) = ctx.saved_tensors
+        return NestedTensor(
+            gO,
+            offsets=offsets,
+            _metadata_cache=ctx.metadata_cache,
+            _ragged_idx=ctx.ragged_idx,
+        )
+
+
+# Not actually a view!
+class ViewNestedFromBuffer(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        values: torch.Tensor,
+        offsets: torch.Tensor,
+        metadata_cache: Optional[Dict[str, Any]] = None,
+    ):  # type: ignore[override]
+        # maintain BC with usages of this function where the seqlens are stuffed
+        # directly into the metadata cache as non-Tensors / ints
+        if metadata_cache is not None:
+            min_seqlen = metadata_cache.get("min_seqlen", None)
+            max_seqlen = metadata_cache.get("max_seqlen", None)
+            if min_seqlen is not None and not isinstance(min_seqlen, torch.Tensor):
+                metadata_cache["min_seqlen"] = _store_val_in_tensor(min_seqlen)
+            if max_seqlen is not None and not isinstance(max_seqlen, torch.Tensor):
+                metadata_cache["max_seqlen"] = _store_val_in_tensor(max_seqlen)
+        return NestedTensor(
+            values.detach(),
+            offsets=offsets,
+            _metadata_cache=metadata_cache,
+        )
+
+    @staticmethod
+    def backward(ctx, gO: NestedTensor):  # type: ignore[override]
+        return gO._values, None, None
+
+
+def buffer_from_jagged(jagged):
+    return ViewBufferFromNested.apply(jagged)
+
+
+# Need to make it obvious that users should be passing in offsets
+def jagged_from_list(
+    tensors: List[torch.Tensor],
+    offsets: Optional[torch.Tensor],
+    dtype=None,
+    device=None,
+) -> tuple[NestedTensor, torch.Tensor]:
+    """Constructs a NestedTensor backed by jagged layout from a list of tensors"""
+
+    if len(tensors) == 0:
+        raise RuntimeError("Cannot construct a nested tensor from an empty tensor list")
+    if not len(set(t.dtype for t in tensors)) == 1:  # noqa: C401
+        raise RuntimeError(
+            "When constructing a nested tensor, all tensors in list must have the same dtype"
+        )
+    if not len(set(t.device for t in tensors)) == 1:  # noqa: C401
+        raise RuntimeError(
+            "When constructing a nested tensor, all tensors in list must be on the same device"
+        )
+    if not len(set(t.dim() for t in tensors)) == 1:  # noqa: C401
+        raise RuntimeError(
+            "When constructing a nested tensor, all tensors in list must have the same dim"
+        )
+    component_dim = tensors[0].dim()
+    if component_dim == 0:
+        raise RuntimeError(
+            "Cannot construct a nested tensor from a list of zero-dim tensors"
+        )
+
+    # Check that the NT is representable by the jagged layout, which
+    # allows for a single ragged dimension after the batch dim.
+    # e.g. (B, *, D_0, ..., D_N), (B, D_0, *, ..., D_N), etc.
+    sizes = [t.shape for t in tensors]
+    ragged_idx = None
+    for d in range(component_dim):
+        dim_is_ragged = any(size[d] != sizes[0][d] for size in sizes)
+        if dim_is_ragged:
+            if ragged_idx is None:
+                # add 1 to convert to outer NJT dim space
+                ragged_idx = d + 1
+            else:
+                raise RuntimeError(
+                    "Cannot represent given tensor list as a nested tensor with the jagged layout. "
+                    "Note that the jagged layout only allows for a single ragged dimension. "
+                    "For example: (B, *, D_0, D_1, ..., D_N), with ragged * dim."
+                )
+
+    # allow for a rectangular NJT and default the ragged dim next to the batch dim
+    if ragged_idx is None:
+        ragged_idx = 1
+
+    # Set properties appropriately.
+    values = torch.cat(tensors, dim=(ragged_idx - 1))
+    to_kwargs = {}
+    if device is not None:
+        to_kwargs["device"] = device
+    if dtype is not None:
+        to_kwargs["dtype"] = dtype
+    values = values.to(**to_kwargs)
+
+    # Calculate jagged offsets if not provided.
+    if offsets is None:
+        # Jagged layout specifies that offsets are stored as int64 on the same device as values.
+        # TODO: An alternative way to construct offsets is to use F.pad. This avoids creating
+        # an extra leaf tensor during the forward, potentially resolving compatibility issues.
+        offsets = torch.cat(
+            [
+                torch.zeros(1, dtype=torch.int64, device=values.device),
+                torch.tensor(
+                    [s[ragged_idx - 1] for s in sizes], device=values.device
+                ).cumsum(dim=0),
+            ]
+        )
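+        # Illustrative result (assumption): for per-sample ragged lengths [2, 3, 1] this
+        # yields offsets = [0, 2, 5, 6], i.e. the cumulative sum of lengths prepended with 0.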
+
+    # compute this now since it's easy
+    min_seqlen = min(t.shape[ragged_idx - 1] for t in tensors)
+    max_seqlen = max(t.shape[ragged_idx - 1] for t in tensors)
+    ret_nt = nested_view_from_values_offsets(
+        values,
+        offsets,
+        min_seqlen=min_seqlen,
+        max_seqlen=max_seqlen,
+        ragged_idx=ragged_idx,
+    )
+    return (ret_nt, offsets)  # type: ignore[return-value]
+
+
+def jagged_from_tensor_and_lengths(
+    tensor: torch.Tensor, starts: torch.Tensor, lengths: torch.Tensor
+) -> tuple[NestedTensor, torch.Tensor, Optional[torch.Tensor]]:
+    """Constructs a NestedTensor backed by jagged layout from a tensor, starts of sequences, and sequence lengths"""
+    batch_size = tensor.shape[0]
+    if is_expandable_to(starts.shape, (batch_size,)) and is_expandable_to(
+        lengths.shape, (batch_size,)
+    ):
+        start_list = starts.expand(batch_size)
+        length_list = lengths.expand(batch_size)
+    else:
+        raise RuntimeError(
+            "When constructing a jagged nested tensor using narrow(), "
+            "your start and length must be Tensors that broadcast to input.shape[0]"
+        )
+
+    # Calculate jagged offsets
+    assert len(tensor.shape) >= 2, (
+        "tensor must at least be 2D for the nested narrow op to work"
+    )
+    max_seq_len = tensor.shape[1]
+    offset_lengths = max_seq_len * torch.arange(
+        0, batch_size, dtype=torch.int64, device=tensor.device
+    )
+    # Jagged layout specifies that offsets are stored as int64 on the same device as values.
+    offsets = torch.cat(
+        [
+            start_list + offset_lengths,
+            (start_list[-1] + offset_lengths[-1] + length_list[-1]).unsqueeze(0),
+        ]
+    )
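+    # Worked example (assumption): with max_seq_len = 4, starts = [0, 1], lengths = [3, 2],
+    # offset_lengths = [0, 4] and the resulting offsets = [0, 5, 7].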
+
+    # Reshape buffer to flatten the 1st and 2nd dimension (view used to enforce non-copy)
+    if len(tensor.shape) > 2:
+        values = tensor.view(-1, *tensor.shape[2:])
+    else:
+        values = tensor.view(-1)
+
+    # Check if offsets and lengths make it possibly contiguous and return a regular NT
+    is_contiguous = True
+    orig_dim = tensor.shape[1]
+    if torch.any(length_list[1:-1].ne(orig_dim)):
+        is_contiguous = False
+    if torch.any(offsets[1:-2].diff().ne(orig_dim)):
+        is_contiguous = False
+    if offsets[0] + length_list[0] != orig_dim:
+        is_contiguous = False
+
+    actual_max_seqlen = int(torch.max(lengths).item())
+    min_seqlen = int(torch.min(lengths).item())
+
+    if is_contiguous:
+        ret_nt = nested_view_from_values_offsets(
+            values[offsets[0] : offsets[-1]],
+            offsets - offsets[0],
+            min_seqlen=min_seqlen,
+            max_seqlen=actual_max_seqlen,
+        )
+    else:
+        ret_nt = nested_view_from_values_offsets_lengths(
+            values,
+            offsets,
+            length_list,
+            min_seqlen=min_seqlen,
+            max_seqlen=actual_max_seqlen,
+        )
+
+    return (ret_nt, offsets, None if is_contiguous else length_list)
+
+
+# NB: A dummy arg is required so that NestedTensor.__torch_dispatch__() is invoked
+# for _nested_view_from_values_offsets(). Sizes don't matter much, but they shouldn't be
+# 0/1 because the dummy can be fake-ified and we want to avoid specializing.
+# This arg is otherwise unused.
+_dummy_instance: Optional[torch.Tensor] = None
+
+
+def _nt_view_dummy() -> torch.Tensor:
+    global _dummy_instance
+    if _dummy_instance is None:
+        _dummy_instance = NestedTensor(
+            values=torch.zeros(3, 3, device="meta"),
+            offsets=torch.zeros(3, device="meta", dtype=torch.int64),
+        ).detach()
+    return _dummy_instance
+
+
+def nested_view_from_values_offsets(
+    values, offsets, ragged_idx=1, min_seqlen=None, max_seqlen=None
+):
+    min_seqlen_tensor = None
+    if min_seqlen is not None:
+        min_seqlen_tensor = _store_val_in_tensor(min_seqlen)
+
+    max_seqlen_tensor = None
+    if max_seqlen is not None:
+        max_seqlen_tensor = _store_val_in_tensor(max_seqlen)
+
+    return torch._nested_view_from_jagged(  # type: ignore[attr-defined]
+        values,
+        offsets,
+        _nt_view_dummy(),
+        None,
+        ragged_idx,
+        min_seqlen_tensor,
+        max_seqlen_tensor,
+    )  # type: ignore[return-value]
+
+
+def nested_view_from_values_offsets_lengths(
+    values, offsets, lengths, ragged_idx=1, min_seqlen=None, max_seqlen=None
+):
+    min_seqlen_tensor = None
+    if min_seqlen is not None:
+        min_seqlen_tensor = _store_val_in_tensor(min_seqlen)
+
+    max_seqlen_tensor = None
+    if max_seqlen is not None:
+        max_seqlen_tensor = _store_val_in_tensor(max_seqlen)
+
+    return torch._nested_view_from_jagged(  # type: ignore[attr-defined]
+        values,
+        offsets,
+        _nt_view_dummy(),
+        lengths,
+        ragged_idx,
+        min_seqlen_tensor,
+        max_seqlen_tensor,
+    )  # type: ignore[return-value]
+
+
+def nested_from_padded(
+    padded, offsets, ragged_idx=1, min_seqlen=None, max_seqlen=None, sum_S=None
+):
+    min_seqlen_tensor = None
+    if min_seqlen is not None:
+        min_seqlen_tensor = _store_val_in_tensor(min_seqlen)
+
+    max_seqlen_tensor = None
+    if max_seqlen is not None:
+        max_seqlen_tensor = _store_val_in_tensor(max_seqlen)
+
+    return torch._nested_from_padded_tensor(
+        padded,
+        offsets,
+        _nt_view_dummy(),
+        ragged_idx,
+        min_seqlen_tensor,
+        max_seqlen_tensor,
+        sum_S,
+    )
diff --git a/.venv/lib/python3.12/site-packages/torch/nested/_internal/ops.py b/.venv/lib/python3.12/site-packages/torch/nested/_internal/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..8eb962f8a308d20b96ee8b21cba6911a3429e5d6
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/nested/_internal/ops.py
@@ -0,0 +1,2743 @@
+# mypy: allow-untyped-defs
+import functools
+import math
+import operator
+from typing import *  # noqa: F403
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch.fx.operator_schemas import normalize_function
+from torch.nested._internal.sdpa import jagged_scaled_dot_product_attention
+
+from .nested_tensor import NestedTensor
+
+
+__all__: list[Any] = []
+
+JAGGED_OPS_TABLE: Dict[Any, Any] = {}
+
+
+def _outer_to_inner_dim(ndim, dim, ragged_dim, canonicalize=False):
+    from torch._prims_common import canonicalize_dims
+
+    if isinstance(dim, (tuple, list)):
+        output = type(dim)(_outer_to_inner_dim(ndim, d, ragged_dim) for d in dim)
+        # ensure no duplicates, which can result from both batch and ragged mapping to 0
+        return type(output)(dict.fromkeys(output))
+
+    if canonicalize:
+        dim = canonicalize_dims(ndim, dim)
+
+    assert dim >= 0 and dim < ndim
+
+    # Map dim=0 (AKA batch dim) -> packed dim i.e. outer ragged dim - 1.
+    # For other dims, subtract 1 to convert to inner space.
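+    # Illustrative example (assumption): with ndim=3 and ragged_dim=1, dim=0 maps to
+    # inner dim 0 (the packed dim) and dim=2 maps to inner dim 1.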
+    return ragged_dim - 1 if dim == 0 else dim - 1
+
+
+def _wrap_jagged_dim(
+    ndim,
+    dim,
+    ragged_dim,
+    op_name,
+    convert_to_inner_dim=True,
+    allow_ragged_dim=False,
+    allow_batch_dim=False,
+):
+    from torch._prims_common import canonicalize_dims
+
+    wrapped = canonicalize_dims(ndim, dim)
+    if wrapped == ragged_dim and not allow_ragged_dim:
+        raise RuntimeError(f"{op_name}(): not supported for NestedTensor on ragged dim")
+    elif wrapped == 0 and not allow_batch_dim:
+        raise RuntimeError(f"{op_name}(): not supported for NestedTensor on dim=0")
+    ret = (
+        _outer_to_inner_dim(ndim, wrapped, ragged_dim)
+        if convert_to_inner_dim
+        else wrapped
+    )
+    if allow_batch_dim:
+        # Need to disambiguate whether we're operating on the batch dim or not.
+        # Operating on dim=1 -> dim=0 after the inner dim conversion.
+        operating_on_batch = wrapped == 0
+        return (ret, operating_on_batch)
+    return ret
+
+
+def _wrap_jagged_dims(ndim, dims, op_name, ragged_idx=1):
+    """
+    For NestedTensor operators,
+    wraps dimensions to non-negative values,
+    and returns metadata related to reduction dimension(s).
+    """
+    from torch._prims_common import canonicalize_dims
+
+    assert isinstance(dims, (tuple, list)), (
+        f"_wrap_jagged_dims(): cannot iterate over dimensions of type {type(dims)}"
+    )
+
+    wrapped_dims = [
+        canonicalize_dims(ndim, d) for d in dims
+    ]  # convert all indices to non-negative values
+
+    operate_on_batch = 0 in wrapped_dims
+    operate_on_ragged = ragged_idx in wrapped_dims
+    operate_on_non_batch = any(d != 0 and d != ragged_idx for d in wrapped_dims)
+
+    # ensure no duplicates, which can result from both batch and ragged mapping to 0
+    outer_to_inner_dim = tuple(
+        dict.fromkeys(_outer_to_inner_dim(ndim, d, ragged_idx) for d in wrapped_dims)
+    )
+
+    return outer_to_inner_dim, operate_on_batch, operate_on_ragged, operate_on_non_batch
+
+
+def check_schema(schema_str: str, func, *args, **kwargs) -> None:
+    named_arg_types = schema_str.split(", ")
+    num_optional_args = [x.endswith("?") for x in named_arg_types].count(True)
+    min_args = len(named_arg_types) - num_optional_args
+
+    # special case: ellipses allows for any number of unchecked args at the end
+    if named_arg_types[-1] == "...":
+        named_arg_types = named_arg_types[:-1]
+    else:
+        if not (len(args) >= min_args and len(args) <= len(named_arg_types)):
+            raise ValueError(
+                f"NestedTensor {func.__name__}({schema_str}): expected at least {min_args} "
+                f"arguments and at most {len(named_arg_types)} arguments, but got: "
+                f"{len(args)} arguments"
+            )
+
+    arg_type_check_fns = {
+        "t": lambda x: isinstance(x, torch.Tensor) and not isinstance(x, NestedTensor),
+        "jt": lambda x: isinstance(x, NestedTensor)
+        and x._lengths is None
+        and x._ragged_idx == 1,  # ops with "jt" require contiguous JT only
+        "jt_all": lambda x: isinstance(
+            x, NestedTensor
+        ),  # ops with "jt_all" can accept all kinds of JT
+        "any": lambda x: True,
+    }
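+    # e.g. a schema string like "self: jt, dim: any, keepdim: any?" is matched positionally
+    # against *args, with a trailing "?" marking an optional argument (illustrative).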
+    for i, named_arg_type in enumerate(named_arg_types):
+        name, arg_type = named_arg_type.split(": ")
+        is_optional = arg_type.endswith("?")
+        normalized_arg_type = arg_type[:-1] if is_optional else arg_type
+        if normalized_arg_type not in arg_type_check_fns.keys():
+            raise AssertionError(f"Unknown arg type: {normalized_arg_type}")
+
+        if i >= len(args):
+            if not is_optional:
+                raise ValueError(
+                    f"NestedTensor {func.__name__}({schema_str}) "
+                    f"missing required argument: {name}"
+                )
+            continue
+
+        _check_fn = arg_type_check_fns[normalized_arg_type]
+
+        def check_fn(x, is_optional=is_optional):
+            if is_optional:
+                return x is None or _check_fn(x)
+            else:
+                return _check_fn(x)
+
+        if not check_fn(args[i]):
+            type_to_desc = {
+                "t": "tensor",
+                "t?": "optional tensor",
+                "jt": "contiguous jagged layout NestedTensor",
+                "jt_all": "jagged layout NestedTensor",
+                "any": "",
+            }
+
+            raise ValueError(
+                f"NestedTensor {func.__name__}({schema_str}): expected {name} to be a "
+                f"{type_to_desc[arg_type]}"
+            )
+
+
+def check_ragged_dim_same(
+    func, a: NestedTensor, a_name: str, b: NestedTensor, b_name: str
+) -> None:
+    # Calling into .shape here
+    if a._size[a._ragged_idx] != b._size[b._ragged_idx]:
+        raise RuntimeError(
+            f"NestedTensor {func.__name__}: expected {a_name} and {b_name} to have the "
+            "same exact offsets tensor."
+        )
+
+
+# returns True if the raggedness-relevant portions of the NT shape
+# match those of the specified size
+def raggedness_matches(nt, size):
+    end = nt._ragged_idx + 1
+    nt_ragged = nt._size[:end]
+    size_ragged = size[:end]
+    return len(nt_ragged) == len(size_ragged) and (
+        all(ns == s or s == -1 for ns, s in zip(nt_ragged, size_ragged))
+    )
+
+
+def squeeze_leading_ones(t):
+    # Note: [ Squeezing leading ones ]
+    #
+    # Squeeze leading ones from t.
+    #
+    # We want:
+    #   (B, j0, ?, ?) + (1, 1, ?, ?) -> (B, j0, ?, ?)
+    #   (B, j0, ?, ?) + (1, 1, 1, ?, ?) -> (1, B, j0, ?, ?)  (not yet supported)
+    #
+    # 1) Squeeze extra ones and grab values from NT
+    #   (1, 1, ?, ?) -> (?, ?)   and   (B, j0, ?, ?) -> (sum(*), ?, ?)
+    # 2) Do dense broadcasting:
+    #   (sum(*), ?, ?) + (?, ?) -> (sum(*), ?, ?)
+    # 3) Construct nested tensor
+    #   (sum(*), ?, ?) -> (B, j0, ?, ?)
+    #
+    # If unsqueezing on the 0th dim becomes supported, we would add an unsqueeze
+    # as a fourth step and we would need to update this function to record how
+    # many ones we unsqueezed.
+    while t.dim() > 0 and t.shape[0] == 1:
+        t = t.squeeze(0)
+    return t
+
+
+def register_func(tables, aten_ops, schema_str):
+    if not isinstance(aten_ops, list):
+        aten_ops = [aten_ops]
+    if not isinstance(tables, list):
+        tables = [tables]
+
+    def wrapper(func):
+        for aten_op in aten_ops:
+
+            def get_inner(aten_op):
+                def inner(*args, **kwargs):
+                    check_schema(schema_str, func, *args, **kwargs)
+                    return func(aten_op, *args, **kwargs)
+
+                return inner
+
+            for table in tables:
+                table[aten_op] = get_inner(aten_op)
+        return func
+
+    return wrapper
+
+
+register_jagged_func = functools.partial(register_func, JAGGED_OPS_TABLE)
+
+
+def lookup_jagged(func, *args, **kwargs) -> Optional[Callable]:
+    dispatch_func = JAGGED_OPS_TABLE.get(func, None)
+    if dispatch_func is not None:
+        return dispatch_func
+
+    # Handle pointwise fallbacks
+    if torch.Tag.pointwise in func.tags:
+        from torch.fx.experimental.symbolic_shapes import is_nested_int
+
+        # No pointwise ops legitimately accept nested int inputs. Without this check,
+        # they will be incorrectly interpreted as tensors.
+        # See https://github.com/pytorch/pytorch/issues/138496
+        for arg in args:
+            if is_nested_int(arg):
+                raise RuntimeError(
+                    f"NestedTensor {func.__name__}: invalid argument {arg}"
+                )
+
+        # Assume the only tensor args are the unary/binary operands themselves
+        num_tensor_args = sum(isinstance(x, torch.Tensor) for x in args)
+        if num_tensor_args == 1:
+            # Build up the check schema string. The first tensor arg is assumed to be
+            # an NJT and other args are sent through as-is.
+            schema_parts = []
+            for arg in func._schema.arguments:
+                if isinstance(arg.type, torch.TensorType):
+                    schema_parts.append(f"{arg.name}: jt_all")
+                    break
+                else:
+                    schema_parts.append(f"{arg.name}: any")
+            schema_parts.append("...")
+            check_schema_str = ", ".join(schema_parts)
+            check_schema(check_schema_str, func, *args, **kwargs)
+            return functools.partial(jagged_unary_pointwise, func)
+        elif num_tensor_args == 2:
+            check_schema("lhs: any, rhs: any, ...", func, *args, **kwargs)
+            return functools.partial(jagged_binary_pointwise, func)
+
+    return None
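+
+# Roughly: an explicitly registered op wins; otherwise, ops tagged as pointwise fall back
+# to jagged_unary_pointwise / jagged_binary_pointwise based on how many tensor args they
+# receive (e.g. abs(njt) takes the unary path, njt + dense takes the binary path).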
+
+
+def extract_kwargs(arg):
+    kwargs = {
+        "offsets": arg.offsets(),
+        "lengths": arg.lengths(),
+        "_metadata_cache": arg._metadata_cache,
+        "_ragged_idx": arg._ragged_idx,
+    }
+    return kwargs
+
+
+def jagged_unary_pointwise(func, *args, **kwargs):
+    # assume if we get here that there is a single NJT input in the args
+    njt = next(arg for arg in args if isinstance(arg, NestedTensor))
+    return NestedTensor(
+        func(*(arg._values if arg is njt else arg for arg in args), **kwargs),
+        **extract_kwargs(njt),
+    )
+
+
+def jagged_binary_pointwise(func, *args, **kwargs):
+    a, b = args[0], args[1]
+    assert isinstance(a, NestedTensor) or isinstance(b, NestedTensor)
+
+    mismatch_error_msg = (
+        "cannot call binary pointwise function {} with inputs of shapes {} and {}"
+    )
+    # a is NT, b is NT
+    if isinstance(a, NestedTensor) and isinstance(b, NestedTensor):
+        # ex: (B, j0, D) + (B, j0, D)
+        # ex: (B, j0, D) + (B, j0, 1)
+        if raggedness_matches(a, b._size):
+            return NestedTensor(
+                func(a._values, b._values, *args[2:], **kwargs), **extract_kwargs(a)
+            )
+        raise RuntimeError(mismatch_error_msg.format(func.__name__, a._size, b._size))
+    # either a is NT or b is NT at this point
+    a_is_nt = isinstance(a, NestedTensor)
+    extracted_kwargs = extract_kwargs(a) if a_is_nt else extract_kwargs(b)
+
+    # === Handle broadcasting across the batch / ragged dims ===
+
+    # Easy case: take advantage of pre-existing broadcasting logic
+    # ex: (B, j0, ?, ?) + (?) -> (B, j0, ?, ?)
+    # ex: (B, j0, ?, ?) + (?, ?) -> (B, j0, ?, ?)
+    # ex: (B, j0, ?, ?) + (1, 1, ?, ?) -> (B, j0, ?, ?)
+    nt, t = (a, b) if a_is_nt else (b, a)
+    # See Note: [ Squeezing leading ones ]
+    if t.dim() > nt.dim():
+        raise NotImplementedError("NYI: broadcasting NT with T with larger dim")
+    t_squeezed = squeeze_leading_ones(t)
+    if nt.dim() >= t_squeezed.dim() + 2:
+        lhs, rhs = (nt._values, t_squeezed) if a_is_nt else (t_squeezed, nt._values)
+        return NestedTensor(func(lhs, rhs, *args[2:], **kwargs), **extracted_kwargs)
+
+    # Harder case: do manual broadcasting when NT dim == non-NT dim
+    # ex: (B, j0, D_0, D_1) + (B, 1, D_0, D_1) -> (B, j0, D_0, D_1)
+    if a.dim() == b.dim():
+        # ex: (B, j0, D_0, D_1) + (1, 1, D_0, D_1) -> should
+        # be (B, j0, D_0, D_1) but not yet supported
+        if a.shape[0] != b.shape[0]:
+            raise RuntimeError(
+                mismatch_error_msg.format(func.__name__, a.shape, b.shape)
+            )
+
+        from .nested_tensor import nested_from_padded
+
+        # handle broadcasting via padded dense -> jagged conversion
+        min_seqlen = nt._maybe_min_seqlen
+        max_seqlen = nt._maybe_max_seqlen
+        padded_max_S = max_seqlen
+        total_L = nt._values.shape[nt._ragged_idx - 1]
+        if padded_max_S is None:
+            # use upper bound on max seqlen if it's not present
+            padded_max_S = total_L
+
+        # convert dense tensor -> jagged
+        t = t.expand(
+            [x if i != nt._ragged_idx else padded_max_S for i, x in enumerate(t.shape)]
+        )
+        t_as_nt = nested_from_padded(
+            t,
+            offsets=nt._offsets,
+            ragged_idx=nt._ragged_idx,
+            sum_S=total_L,
+            min_seqlen=min_seqlen,
+            max_seqlen=max_seqlen,
+        )
+
+        # function call with two NJTs
+        lhs, rhs = (nt, t_as_nt) if a_is_nt else (t_as_nt, nt)
+        return func(lhs, rhs, *args[2:], **kwargs)
+
+    # ex: (B, j0, D_0, D_1) + (A, B, 1, D_0, D_1) -> error because this breaks the invariant
+    # that ragged dim is wrt left-most batch dim
+    raise RuntimeError(mismatch_error_msg.format(func.__name__, a.shape, b.shape))
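+
+# Worked shape examples for the branches above, assuming ragged_idx == 1:
+#   (B, j0, D) + (D,)          -> easy case: dense broadcasting against values (sum(j0), D)
+#   (B, j0, D) + (B, 1, D)     -> harder case: the dense side is expanded along the ragged
+#                                 dim, converted via nested_from_padded, and the op is
+#                                 redispatched on two NJTs
+#   (B, j0, D) + (1, 1, 1, D)  -> NotImplementedError (dense side has higher dim than the NJT)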
+
+
+def jagged_torch_function(func, *args, **kwargs):
+    # SDPA has special kernels that handle nested tensors.
+    # Dispatch to the correct implementation here
+    if func is torch._C._nn.scaled_dot_product_attention:
+        return jagged_scaled_dot_product_attention(*args, **kwargs)
+
+    if func.__name__ == "apply_":
+        func(args[0]._values, *args[1:], **kwargs)
+        return args[0]
+
+    # Handle flatten() here because it's CompositeImplicit.
+    if func.__name__ == "flatten":
+
+        def _flatten_sig(input, start_dim=0, end_dim=-1):
+            pass
+
+        _, new_kwargs = normalize_function(  # type: ignore[misc]
+            _flatten_sig, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+        )
+
+        inp = new_kwargs.pop("input")
+
+        # NB: stay in outer dim space because we're going to redispatch on a NT input
+        start_dim = _wrap_jagged_dim(
+            inp.dim(),
+            new_kwargs["start_dim"],
+            inp._ragged_idx,
+            "flatten",
+            convert_to_inner_dim=False,
+        )
+        end_dim = _wrap_jagged_dim(
+            inp.dim(),
+            new_kwargs["end_dim"],
+            inp._ragged_idx,
+            "flatten",
+            convert_to_inner_dim=False,
+        )
+
+        if start_dim == end_dim:
+            return inp
+
+        product = functools.reduce(operator.mul, inp.shape[start_dim : end_dim + 1])
+        new_shape = (*inp.shape[:start_dim], product, *inp.shape[end_dim + 1 :])
+
+        return inp.reshape(*new_shape)
+
+    # Handle nested-specific input validation for CompositeImplicit rms_norm
+    if func.__name__ == "rms_norm":
+
+        def _rms_norm_sig(input, normalized_shape, weight=None, eps=None):
+            pass
+
+        _, new_kwargs = normalize_function(  # type: ignore[misc]
+            _rms_norm_sig, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+        )
+
+        inp = new_kwargs.pop("input")
+        normalized_shape = new_kwargs.pop("normalized_shape")
+
+        # can't normalize over the ragged dim (yet)
+        max_normalizable = inp.dim() - inp._ragged_idx - 1
+        if len(normalized_shape) > max_normalizable:
+            raise ValueError(
+                "rms_norm(): Normalization over the ragged dim not supported for nested tensors"
+            )
+
+        with torch._C.DisableTorchFunctionSubclass():
+            return func(*args, **kwargs)
+
+    raise NotImplementedError(func)
+
+
+@register_jagged_func(
+    [
+        torch.ops.aten.is_non_overlapping_and_dense.default,
+        torch.ops.aten.sym_size.default,
+        torch.ops.aten.dim.default,
+        torch.ops.aten.numel.default,
+        torch.ops.aten.sym_numel.default,
+        torch.ops.aten.sym_stride.default,
+        torch.ops.aten.sym_storage_offset.default,
+    ],
+    "self: jt_all",
+)
+def tensor_attr_supported_getter(func, *args, **kwargs):
+    if func == torch.ops.aten.is_non_overlapping_and_dense.default:
+        return False
+
+    if func == torch.ops.aten.sym_size.default:
+        return args[0]._size
+
+    if func == torch.ops.aten.dim.default:
+        return len(args[0]._size)
+
+    if func in (torch.ops.aten.sym_numel.default, torch.ops.aten.numel.default):
+        if args[0]._lengths is not None:
+            return int(sum(args[0]._lengths) * math.prod(args[0]._size[2:]))
+        return args[0]._values.numel()
+
+    if func == torch.ops.aten.sym_stride.default:
+        return args[0]._strides
+
+    if func == torch.ops.aten.sym_storage_offset.default:
+        return args[0]._values.storage_offset()
+
+
+@register_jagged_func(torch.ops.prim.layout.default, "self: jt_all")
+def prim_layout_default(func, *args, **kwargs):
+    return torch.jagged
+
+
+@register_jagged_func(
+    [torch.ops.aten.size.default],
+    "self: jt_all",
+)
+def tensor_attr_unsupported_getter(func, *args, **kwargs):
+    if func == torch.ops.aten.size.default:
+        raise RuntimeError(
+            "NestedTensor does not support directly calling torch.ops.aten.size; "
+            "please use `nested_tensor.size()` instead."
+        )
+
+
+@register_jagged_func(torch.ops.aten.is_contiguous.default, "self: jt_all")
+def is_contiguous_general(func, *args, **kwargs):
+    from torch._prims_common import is_contiguous_for_memory_format
+
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+    inp = new_kwargs.pop("input")
+
+    # NJTs created via narrow() carry lengths metadata; treat those as non-contiguous
+    if inp.lengths() is not None:
+        return False
+
+    new_kwargs["memory_format"] = new_kwargs.get(
+        "memory_format", torch.contiguous_format
+    )
+    if new_kwargs["memory_format"] == torch.preserve_format:
+        return True
+    return is_contiguous_for_memory_format(inp._values, **new_kwargs)
+
+
+register_jagged_func(
+    torch.ops.aten.is_contiguous.memory_format, "self: jt_all, memory_format: any?"
+)(is_contiguous_general)
+
+
+@register_jagged_func(
+    torch.ops.aten.clone.default, "input: jt_all, memory_format: any?"
+)
+def clone_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    new_meta = extract_kwargs(inp)
+
+    if inp._lengths is not None:
+        if new_kwargs["memory_format"] == torch.contiguous_format:
+            # need to copy to remove "holes" non-contiguity / lengths metadata
+            # TODO: write a kernel for this
+            from .nested_tensor import jagged_from_list
+
+            # TODO: We probably want the output to have the same ragged structure / nested int.
+            assert inp._ragged_idx == 1, (
+                "NJT with ragged_idx != 1 not supported for contiguous clone"
+            )
+            contig, _ = jagged_from_list(inp.unbind(), offsets=None)
+            return contig
+
+    return NestedTensor(func(inp._values, **new_kwargs), **new_meta)
+
+
+@register_jagged_func(torch.ops.aten.linear.default, "input: jt, weight: t, bias: t?")
+def linear_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
+
+
+@register_jagged_func(
+    torch.ops.aten.linear_backward.default,
+    "self: jt, grad_output: jt, weight: t, output_mask: any",
+)
+def linear_backward_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    grad_output = new_kwargs.pop("grad_output")
+    weight = new_kwargs.pop("weight")
+    output_mask = new_kwargs.pop("output_mask")
+
+    ds, dw, db = None, None, None
+    check_ragged_dim_same(func, inp, "self", grad_output, "grad_output")
+    if output_mask[0]:
+        ds = NestedTensor(
+            torch.matmul(grad_output._values, weight), **extract_kwargs(grad_output)
+        )
+    if output_mask[1]:
+        # NB: Fold dims of values for input and grad_output to treat them as 2D. This
+        # trick avoids materializing large intermediates and immediately reducing over
+        # them via sum(). This is equivalent to computing:
+        #     torch.matmul(grad_output._values.transpose(-2, -1), inp._values)
+        # and then summing over the leading dimensions to get a 2D weight grad.
+        grad_2d = grad_output._values.reshape(-1, weight.size(0))
+        input_2d = inp._values.reshape(-1, weight.size(1))
+        dw = torch.matmul(grad_2d.t(), input_2d)
+    if output_mask[2]:
+        # Sum over all but the last dim to get a 1D bias grad. We cannot
+        # rely on the autograd engine to reduce for us, because returning a
+        # tensor aliasing the input would violate the aten signature annotation
+        reduce_dims = tuple(range(grad_output._values.ndim - 1))
+        if reduce_dims == ():
+            db = grad_output._values.clone()
+        else:
+            db = torch.sum(grad_output._values, reduce_dims, keepdim=False)
+    return (ds, dw, db)
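+
+# Shape sketch for the dw computation above, assuming a 3D NJT with ragged_idx == 1:
+# inp._values is (total_L, in_features), grad_output._values is (total_L, out_features),
+# and weight is (out_features, in_features), so grad_2d.t() @ input_2d directly yields
+# the (out_features, in_features) weight grad without a batched intermediate.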
+
+
+@register_jagged_func(torch.ops.aten.to.dtype, "input: jt_all, dtype: any")
+def to_dtype(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
+
+
+@register_jagged_func(torch.ops.aten._to_copy.default, "self: jt_all")
+def to_copy_default(func, *args, **kwargs):
+    from .nested_tensor import _tensor_symint_registry
+
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    # don't change layout
+    new_kwargs.pop("layout")
+
+    new_values = func(inp._values, **new_kwargs)
+    new_offsets = inp._offsets.to(device=new_values.device)
+    new_lengths = None
+    if inp._lengths is not None:
+        new_lengths = inp._lengths.to(device=new_values.device)
+
+    from torch._subclasses.fake_tensor import FakeTensor
+    from torch._subclasses.functional_tensor import (
+        FunctionalTensor,
+        mb_unwrap_functional_tensor,
+    )
+
+    ragged_source = inp._offsets if inp._lengths is None else inp._lengths
+    new_thing = new_offsets if new_lengths is None else new_lengths
+    if isinstance(new_thing, (FakeTensor, FunctionalTensor)):
+        # Temporary hack until we have the union find
+        tgt = mb_unwrap_functional_tensor(new_thing)
+        src = mb_unwrap_functional_tensor(ragged_source)
+        tgt.nested_int_memo = src.nested_int_memo
+    else:
+        _tensor_symint_registry[new_thing] = _tensor_symint_registry[ragged_source]
+    inp_kwargs = extract_kwargs(inp)
+    inp_kwargs["offsets"] = new_offsets
+    inp_kwargs["lengths"] = new_lengths
+
+    output = NestedTensor(new_values, **inp_kwargs)
+    return output
+
+
+@register_jagged_func(
+    torch.ops.aten.copy_.default, "self: jt_all, src: jt_all, non_blocking: any?"
+)
+def copy_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+    inp = new_kwargs.pop("input")
+    src = new_kwargs.pop("src")
+    if inp._size != src._size:
+        # try to recursively copy_ on unbound components to get around nested int mismatch
+        # TODO: eventually do a direct copy when this is possible
+        inp_comps = inp.unbind()
+        inp_comp_shapes = [c.shape for c in inp_comps]
+        src_comps = src.unbind()
+        src_comp_shapes = [c.shape for c in src_comps]
+        if inp_comp_shapes != src_comp_shapes:
+            raise RuntimeError(
+                "copy_(): expected compatible input and src shapes, but got: "
+                f"{inp.shape} and {src.shape}"
+            )
+        for inp_comp, src_comp in zip(inp_comps, src_comps):
+            inp_comp.copy_(src_comp)
+
+    # AOTD allows mutations of inputs only (not views of the inputs).
+    # NJT.values() returns _values.detach() to workaround some issues.
+    # To keep mutation in the graph, AOTD manually calls copy_ on the input (NJT).
+    # Here we directly mutate self._values to not emit .detach() in the graph, which would make it non-compilable.
+    inp._values.copy_(src._values)
+    return inp
+
+
+register_jagged_func(torch.ops.aten.detach.default, "self: jt_all")(
+    jagged_unary_pointwise
+)
+
+
+@register_jagged_func(
+    [
+        torch.ops.aten.empty_like.default,
+        torch.ops.aten.ones_like.default,
+        torch.ops.aten.zeros_like.default,
+        torch.ops.aten.rand_like.default,
+        torch.ops.aten.randn_like.default,
+    ],
+    "self: jt_all",
+)
+def like_factory_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    # Default layout is technically torch.strided but only jagged is supported here.
+    # Rather than force users to specify the layout, assume jagged.
+    # This should be set to strided for redispatching on values.
+    new_kwargs["layout"] = torch.strided
+
+    new_values = func(inp._values, **new_kwargs)
+    new_offsets = inp._offsets.to(device=new_values.device)
+    new_lengths = None
+    if inp._lengths is not None:
+        new_lengths = inp._lengths.to(device=new_values.device)
+    output_kwargs = extract_kwargs(inp)
+    if "offsets" in output_kwargs:
+        output_kwargs["offsets"] = new_offsets
+    if "lengths" in output_kwargs:
+        output_kwargs["lengths"] = new_lengths
+
+    if inp.device != new_values.device:
+        # Update the nested int registry to indicate that the ragged structure is the same
+        # between the two offsets / lengths on different devices.
+        from torch._subclasses.fake_tensor import FakeTensor
+        from torch._subclasses.functional_tensor import (
+            FunctionalTensor,
+            mb_unwrap_functional_tensor,
+        )
+
+        from .nested_tensor import _tensor_symint_registry
+
+        ragged_source = inp._offsets if inp._lengths is None else inp._lengths
+        new_thing = new_offsets if new_lengths is None else new_lengths
+        if isinstance(new_thing, (FakeTensor, FunctionalTensor)):
+            # Temporary hack until we have the union find
+            tgt = mb_unwrap_functional_tensor(new_thing)
+            src = mb_unwrap_functional_tensor(ragged_source)
+            tgt.nested_int_memo = src.nested_int_memo
+        else:
+            _tensor_symint_registry[new_thing] = _tensor_symint_registry[ragged_source]
+
+    return NestedTensor(new_values, **output_kwargs)
+
+
+register_jagged_func(torch.ops.aten.full_like.default, "self: jt_all, fill_value: any")(
+    like_factory_default
+)
+
+register_jagged_func(torch.ops.aten.randint_like.default, "self: jt_all, high: any")(
+    like_factory_default
+)
+
+register_jagged_func(
+    torch.ops.aten.randint_like.low_dtype, "self: jt_all, low: any, high: any"
+)(like_factory_default)
+
+
+@register_jagged_func(torch.ops.aten.zero_.default, "self: jt_all")
+def zero__default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    func(inp._values)
+    return inp
+
+
+@register_jagged_func(
+    torch.ops.aten._softmax.default, "self: jt_all, dim: any, half_to_float: any"
+)
+def _softmax_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    if isinstance(new_kwargs["dim"], tuple):
+        raise RuntimeError(
+            "softmax(): not supported for dimensions of type 'tuple' for NestedTensor"
+        )
+
+    inp = new_kwargs.pop("input")
+
+    (
+        new_kwargs["dim"],
+        reduce_on_batch,
+        reduce_on_ragged,
+        _reduce_on_non_batch,
+    ) = _wrap_jagged_dims(
+        inp.dim(),
+        (new_kwargs["dim"],),
+        "softmax",
+        inp._ragged_idx,
+    )
+
+    if reduce_on_batch:
+        raise RuntimeError(
+            "softmax(): not supported when reducing across the batch dimension for NestedTensor"
+        )
+
+    if reduce_on_ragged and inp._ragged_idx > 1:
+        raise RuntimeError(
+            "softmax(): not supported when reducing along the ragged dimension for ragged_idx > 1 for NestedTensor"
+        )
+
+    if reduce_on_ragged and inp._lengths is not None:
+        raise RuntimeError(
+            "softmax(): not supported where lengths is not None "
+            + "if reducing across the ragged dimension for NestedTensor"
+        )
+
+    new_kwargs["dim"] = new_kwargs["dim"][
+        0
+    ]  # torch.softmax takes in the reduction dimension as an integer
+
+    if reduce_on_ragged:
+        padded_softmax_values = torch.nn.functional.softmax(
+            torch.ops.aten._jagged_to_padded_dense_forward(
+                inp._values.reshape(
+                    inp._values.shape[0], -1
+                ),  # values are required to be 2D tensors for j2pd
+                [inp._offsets],
+                max_lengths=[inp._max_seqlen],  # max length of ragged dimension
+                padding_value=float("-inf"),  # e^-inf = 0
+            ),
+            dim=inp._ragged_idx,
+        )
+
+        softmax_values = torch.ops.aten._padded_dense_to_jagged_forward(
+            padded_softmax_values,
+            [inp._offsets],
+            total_L=inp._values.shape[
+                0
+            ],  # providing this parameter helps avoid a GPU/CPU sync
+        ).reshape(
+            -1, *inp._values.shape[1:]
+        )  # expand softmax_values back to original shape (inp._values.shape)
+
+        return NestedTensor(softmax_values, **extract_kwargs(inp))
+
+    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
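+
+# For illustration: reducing over the ragged dim of an NJT with shape (B, j0, D) first
+# flattens values to 2D, pads them out to a dense (B, max_seqlen, D) tensor using -inf
+# (so padded slots contribute e^-inf = 0), applies torch.softmax along the ragged dim,
+# then converts back to jagged and restores the original values shape.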
+
+
+@register_jagged_func(
+    torch.ops.aten._softmax_backward_data.default,
+    "grad_output: jt, output: jt, dim: any, input_dtype: any",
+)
+def _softmax_backward(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+    grad_out = new_kwargs.pop("grad_output")
+    output = new_kwargs.pop("output")
+    return NestedTensor(
+        func(grad_out._values, output._values, **new_kwargs), **extract_kwargs(grad_out)
+    )
+
+
+@register_jagged_func(
+    torch.ops.aten.native_dropout.default, "self: jt, float: any, train: any?"
+)
+def native_dropout_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    out1, out2 = func(inp._values, **new_kwargs)
+    return (
+        NestedTensor(out1, **extract_kwargs(inp)),
+        NestedTensor(out2, **extract_kwargs(inp)),
+    )
+
+
+@register_jagged_func(
+    torch.ops.aten.native_dropout_backward.default,
+    "grad_output: jt, mask: jt, scale: any",
+)
+def native_dropout_backward_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+    grad_output = new_kwargs.pop("grad_output")
+    mask = new_kwargs.pop("mask")
+    return NestedTensor(
+        func(grad_output._values, mask._values, **new_kwargs),
+        **extract_kwargs(grad_output),
+    )
+
+
+@register_jagged_func(
+    torch.ops.aten.prod.dim_int,
+    "self: jt_all, dim: any, keepdim: any?, dtype: any?",
+)
+def prod_dim_int(func, *args, **kwargs):
+    return _apply_reduction(func, "prod", 1, *args, **kwargs)
+
+
+@register_jagged_func(torch.ops.aten.prod.default, "self: jt_all, dtype: any?")
+def prod_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    return func(inp._values, **new_kwargs)
+
+
+@register_jagged_func(
+    torch.ops.aten.split.Tensor, "self: jt, split_size: any, dim: any?"
+)
+def split_tensor(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    new_kwargs["dim"] = _wrap_jagged_dim(
+        inp.dim(), new_kwargs["dim"], inp._ragged_idx, "split"
+    )
+
+    return tuple(
+        NestedTensor(values=x, **extract_kwargs(inp))
+        for x in func(inp._values, **new_kwargs)
+    )
+
+
+@register_jagged_func(
+    torch.ops.aten.split_with_sizes.default, "self: jt, split_sizes: any, dim: any?"
+)
+def split_with_sizes_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    new_kwargs["dim"] = _wrap_jagged_dim(
+        inp.dim(), new_kwargs["dim"], inp._ragged_idx, "split_with_sizes"
+    )
+
+    return [
+        NestedTensor(values=x, **extract_kwargs(inp))
+        for x in func(inp._values, **new_kwargs)
+    ]
+
+
+@register_jagged_func(
+    torch.ops.aten.narrow.default, "self: jt, dim: any, start: any, length: any"
+)
+def narrow(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+    inp = new_kwargs.pop("input")
+
+    dim = _wrap_jagged_dim(inp.dim(), new_kwargs["dim"], inp._ragged_idx, "narrow")
+    values = func(
+        inp._values,
+        dim=dim,
+        start=new_kwargs["start"],
+        length=new_kwargs["length"],
+    )
+    return NestedTensor(values, **extract_kwargs(inp))
+
+
+@register_jagged_func(torch.ops.aten.chunk.default, "self: jt, chunks: any, dim: any?")
+def chunk_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    new_kwargs["dim"], operating_on_batch = _wrap_jagged_dim(
+        inp.dim(), new_kwargs["dim"], inp._ragged_idx, "chunk", allow_batch_dim=True
+    )
+
+    if operating_on_batch:
+        chunks = new_kwargs["chunks"]
+
+        # get _offsets of the chunks
+        lengths = inp._offsets.diff()
+        chunked_lengths = lengths.chunk(chunks)
+        chunked_offsets = [torch.cumsum(x, dim=0) for x in chunked_lengths]
+        chunked_offsets = [F.pad(x, (1, 0), value=0) for x in chunked_offsets]  # type: ignore[arg-type]
+        nested_kwargs = [
+            {"offsets": per_offsets, "_ragged_idx": inp._ragged_idx}
+            for per_offsets in chunked_offsets
+        ]
+
+        # get _values of the chunks
+        split_sizes = [x.sum().item() for x in chunked_lengths]
+        chunk_values = inp._values.split(split_sizes)
+
+        # Note that the actual number of chunks returned is not necessarily the same as
+        # the input number; it can be counter-intuitive, but it matches dense behavior.
+        return [
+            NestedTensor(values=chunk_values[i], **(nested_kwargs[i]))
+            for i in range(0, len(chunk_values))
+        ]
+    else:
+        return [
+            NestedTensor(values=x, **extract_kwargs(inp))
+            for x in func(inp._values, **new_kwargs)
+        ]
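+
+# For illustration: chunking an NJT with offsets [0, 2, 5, 9] into chunks=2 along the
+# batch dim splits the per-item lengths [2, 3, 4] into [2, 3] and [4], rebuilds per-chunk
+# offsets ([0, 2, 5] and [0, 4]), and splits _values by the summed lengths (5 and 4) to
+# form the output NJTs.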
+
+
+@register_jagged_func(torch.ops.aten.unbind.int, "self: jt_all, dim: any?")
+def unbind_int(func, *args, **kwargs):
+    # Note that this specializes on the length of the offsets
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    dim = new_kwargs["dim"]
+    if dim != 0:
+        raise RuntimeError("unbind(): only supported for NestedTensor on dim=0")
+
+    inp = new_kwargs.pop("input")
+    values = inp.values()
+    offsets = inp.offsets()
+    lengths = inp.lengths()
+    ragged_idx = inp._ragged_idx
+
+    def _torch_check(_lengths: list[int], _offsets: Optional[list[int]] = None):
+        # These torch._check / torch._check_is_size calls are needed for torch.compile
+        # symbolic shapes processing.
+        # offsets and lengths are symbolic variables during compilation, so we guarantee
+        # the correct offsets/lengths correspondence:
+        # - sum of lengths <= total ragged_dim_size
+        # - every length and offset is a size-like variable (allowing symbolic shapes to
+        #   reason about it as being in [2, inf))
+        # - offsets[i] + lengths[i] <= ragged_dim_size, for unbind and split dim correctness
+        # - offsets[i] <= ragged_dim_size
+
+        lengths_sum = 0
+        ragged_dim_size = values.shape[ragged_idx - 1]
+        for i in range(len(_lengths)):
+            torch._check_is_size(_lengths[i])
+            torch._check(_lengths[i] <= ragged_dim_size)
+
+            lengths_sum += _lengths[i]
+            if _offsets is not None:
+                torch._check(
+                    _offsets[i] + _lengths[i] <= ragged_dim_size,
+                    lambda: "unbind(): nested tensor offsets and lengths do not match ragged_idx dimension",
+                )
+        torch._check(lengths_sum <= ragged_dim_size)
+
+        if _offsets is not None:
+            for i in range(len(_offsets)):
+                torch._check_is_size(_offsets[i])
+                torch._check(_offsets[i] <= ragged_dim_size)
+
+    if lengths is None:
+        lengths_scalars = offsets.diff().tolist()
+        _torch_check(lengths_scalars)
+
+        return torch.split(values, lengths_scalars, dim=(ragged_idx - 1))
+
+    if ragged_idx <= 0:
+        raise RuntimeError(
+            "unbind(): nested tensor ragged_idx out of bounds (should be >= 1)"
+        )
+
+    lengths_scalars = lengths.tolist()
+    offsets_scalars = offsets.tolist()
+
+    _torch_check(lengths_scalars, offsets_scalars)
+
+    return [
+        torch.narrow(
+            values,
+            dim=(ragged_idx - 1),
+            start=offsets_scalars[i],
+            length=lengths_scalars[i],
+        )
+        for i in range(lengths.shape[0])
+    ]
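+
+# For illustration: a contiguous NJT with offsets [0, 2, 5] and values of shape (5, D)
+# unbinds via torch.split over the lengths [2, 3] into views of shapes (2, D) and (3, D);
+# the non-contiguous path instead narrows values using each (offset, length) pair.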
+
+
+@register_jagged_func(torch.ops.aten.squeeze.dim, "self: jt, dim: any")
+def squeeze_dim(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    values = inp._values
+
+    new_kwargs["dim"] = _wrap_jagged_dim(
+        len(inp._size), new_kwargs["dim"], inp._ragged_idx, "squeeze"
+    )
+    return NestedTensor(func(values, **new_kwargs), **extract_kwargs(inp))
+
+
+@register_jagged_func(torch.ops.aten.unsqueeze.default, "self: jt_all, dim: any")
+def unsqueeze_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    values = inp._values
+
+    # Account for collapsed jagged dim
+    dim = new_kwargs["dim"]
+    new_kwargs["dim"] = _wrap_jagged_dim(
+        len(inp._size) + 1, dim, inp._ragged_idx, "unsqueeze", allow_ragged_dim=True
+    )
+
+    # ragged_idx changes if a dimension is added before it
+    output_kwargs = extract_kwargs(inp)
+    if new_kwargs["dim"] <= inp._ragged_idx - 1:
+        output_kwargs["_ragged_idx"] += 1
+
+    return NestedTensor(func(values, **new_kwargs), **output_kwargs)
+
+
+@register_jagged_func(torch.ops.aten.cat.default, "tensors: any, dim: any")
+def cat_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    tensors = new_kwargs.pop("tensors")
+
+    # Convert any non-nested to nested
+    nested = [t for t in tensors if t.is_nested]
+    assert len(nested) > 0
+    first = nested[0]
+    tensors = [t if t.is_nested else t.expand_as(first) for t in tensors]
+
+    # Account for collapsed jagged dim
+    dim = new_kwargs["dim"]
+    new_kwargs["dim"] = _wrap_jagged_dim(
+        len(first.shape), dim, first._ragged_idx, "cat"
+    )
+
+    return NestedTensor(
+        func([t._values for t in tensors], **new_kwargs), **extract_kwargs(tensors[0])
+    )
+
+
+@register_jagged_func(torch.ops.aten.matmul.default, "self: any, other: any")
+def matmul_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    other = new_kwargs.pop("other")
+
+    def _unbind_impl(a, b):
+        return [
+            func(a_comp, b_comp) for (a_comp, b_comp) in zip(a.unbind(), b.unbind())
+        ]
+
+    def _padded_impl(a, b):
+        if a.is_nested:
+            nt = a
+        else:
+            nt = b
+
+        from .nested_tensor import nested_from_padded
+
+        min_seqlen = nt._maybe_min_seqlen
+        max_seqlen = nt._maybe_max_seqlen
+        padded_max_S = max_seqlen
+        total_L = nt._values.shape[nt._ragged_idx - 1]
+        if padded_max_S is None:
+            # use upper bound on max seqlen if it's not present
+            padded_max_S = total_L
+
+        padded_shape = (
+            *nt.shape[: nt._ragged_idx],
+            padded_max_S,
+            *nt.shape[nt._ragged_idx + 1 :],
+        )
+        padded_nt = nt.to_padded_tensor(0.0, output_size=padded_shape)
+        if a.is_nested:
+            padded_t = func(padded_nt, b)
+        else:
+            padded_t = func(a, padded_nt)
+        return nested_from_padded(
+            padded_t,
+            offsets=nt._offsets,
+            ragged_idx=nt._ragged_idx,
+            sum_S=total_L,
+            min_seqlen=min_seqlen,
+            max_seqlen=max_seqlen,
+        )
+
+    # TODO: Back these with proper kernels (e.g. grouped GEMM)
+    # NJT x dense
+    if inp.is_nested and not other.is_nested:
+        # (B, j1, D) x (B, D, E) => (B, j1, E)
+        if (
+            inp.dim() >= 3
+            and inp.dim() == other.dim()
+            and inp._ragged_idx < inp.dim() - 1
+        ):
+            # convert to padded for this
+            return _padded_impl(inp, other)
+        # Support broadcasting the dense:
+        # (B, j1, D) x (D, E) => (B, j1, E)
+        # (B, j1, D, E) x (E, F) => (B, j1, D, F)
+        # etc.
+        elif (
+            other.dim() == 2
+            and inp.dim() > other.dim()
+            and inp._ragged_idx < inp.dim() - 1
+        ):
+            return NestedTensor(
+                func(inp._values, other, **new_kwargs), **extract_kwargs(inp)
+            )
+    # Dense x NJT
+    elif not inp.is_nested and other.is_nested:
+        # (B, D, E) x (B, E, j1) => (B, D, j1)
+        if other.dim() >= 3 and other.dim() == inp.dim() and other._ragged_idx >= 2:
+            # convert to padded for this
+            return _padded_impl(inp, other)
+        # Support broadcasting the dense:
+        # (D, E) x (B, E, j1) => (B, D, j1)
+        # (D, E) x (B, E, j1, F) => (B, D, j1, F)
+        # etc.
+        elif inp.dim() == 2 and other.dim() > inp.dim() and other._ragged_idx >= 2:
+            return NestedTensor(
+                func(inp, other._values, **new_kwargs), **extract_kwargs(other)
+            )
+
+    # NJT x NJT
+    elif inp.is_nested and other.is_nested:
+        # Support ragged batch dim:
+        # (B, j1, D, E) x (B, j1, E, F) => (B, j1, D, F), etc.
+        if inp.dim() > 3 and other.dim() > 3 and raggedness_matches(inp, other._size):
+            return NestedTensor(func(inp._values, other._values), **extract_kwargs(inp))
+        # Support reducing over ragged with dense output:
+        # (B, D, j1) x (B, j1, E) => (B, D, E)
+        elif (
+            inp.dim() == 3
+            and other.dim() == 3
+            and inp._ragged_idx == 2
+            and other._ragged_idx == 1
+            and inp.size(inp._ragged_idx) == other.size(other._ragged_idx)
+        ):
+            # do unbind for this; can't use padded conversion due to j1 in last dim
+            return torch.stack(_unbind_impl(inp, other))
+
+    raise RuntimeError(
+        f"matmul(): not supported between inputs of shapes {inp._size} and {other.shape}"
+    )
+
+
+@register_jagged_func(torch.ops.aten.bmm.default, "self: jt_all, mat2: any")
+def bmm_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    other = new_kwargs.pop("mat2")
+
+    if inp.dim() != 3:
+        raise ValueError("bmm(): input must be 3D")
+    if other.dim() != 3:
+        raise ValueError("bmm(): mat2 must be 3D")
+
+    return matmul_default(torch.ops.aten.matmul.default, inp, other)
+
+
+@register_jagged_func(
+    torch.ops.aten.expand.default, "self: jt_all, size: any, implicit: any?"
+)
+def expand_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    size = new_kwargs["size"]
+
+    assert ("implicit" not in new_kwargs) or (not new_kwargs.pop("implicit"))
+    if not raggedness_matches(inp, size):
+        raise RuntimeError(f"expand(): cannot expand shape {inp._size} -> {size}")
+
+    expand_arg = [-1 if d == inp._ragged_idx else size[d] for d in range(1, inp.dim())]
+    return NestedTensor(func(inp._values, expand_arg), **extract_kwargs(inp))
+
+
+@register_jagged_func(torch.ops.aten.expand_as.default, "self: t, other: jt")
+def expand_as_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    other = new_kwargs.pop("other")
+
+    return NestedTensor(func(inp, other._values), **extract_kwargs(other))
+
+
+@register_jagged_func(torch.ops.aten.broadcast_to.default, "self: jt_all, size: any")
+def broadcast_to(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    size = new_kwargs.pop("size")
+
+    if len(size) <= inp.dim():
+        return inp.expand([*(1 for _ in range(inp.dim() - len(size))), *size])
+
+    raise ValueError(
+        "broadcast_to(): broadcasting to a higher-dim shape is currently not supported "
+        "for nested tensors with the jagged layout"
+    )
+
+
+@register_jagged_func(torch.ops.aten.broadcast_tensors.default, "tensors: any")
+def broadcast_tensors(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    tensors = new_kwargs.pop("tensors")
+    if len(tensors) == 0:
+        raise ValueError("broadcast_tensors(): expected at least one tensor input")
+    if len(tensors) == 1:
+        return tensors[0]
+
+    outs = []
+    broadcast_shape = torch.broadcast_shapes(*(t.shape for t in tensors))
+    # Pull out the first NJT. If broadcast_shapes() worked, the nested ints are compatible.
+    njt = next(t for t in tensors if isinstance(t, NestedTensor))
+    for t in tensors:
+        if t.is_nested:
+            outs.append(t.broadcast_to(broadcast_shape))
+        elif t.dim() < len(broadcast_shape):
+            outs.append(
+                NestedTensor(t.broadcast_to(njt._values.shape), **extract_kwargs(njt))
+            )
+        else:
+            raise ValueError(
+                "broadcast_tensors(): broadcasting nested tensors with dense tensors of equal "
+                "or higher dim is not currently supported"
+            )
+
+    return tuple(outs)
+
+
+@register_jagged_func(
+    torch.ops.aten.where.self, "condition: jt_all, self: any, other: any"
+)
+def where_self(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    condition = new_kwargs.pop("condition")
+    inp = new_kwargs.pop("input")
+    other = new_kwargs.pop("other")
+
+    # if the tensors aren't compatible, broadcast_tensors() will let us know
+    condition, inp, other = torch.broadcast_tensors(condition, inp, other)
+
+    return NestedTensor(
+        func(condition._values, inp._values, other._values, **new_kwargs),
+        **extract_kwargs(condition),
+    )
+
+
+@register_jagged_func(torch.ops.aten._pin_memory.default, "self: jt, device: any?")
+def _pin_memory_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
+
+
+@register_jagged_func(torch.ops.aten.is_pinned.default, "self: jt, device: any?")
+def is_pinned_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    return func(inp._values, **new_kwargs)
+
+
+@register_jagged_func(
+    torch.ops.aten.is_same_size.default, "self: jt_all, other: jt_all"
+)
+def is_same_size_default(func, *args, **kwargs):
+    return args[0]._size == args[1]._size
+
+
+def _apply_reduction(func, func_name, identity_element, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    # some ops use dim=None to indicate a full reduction; some use an empty dim list
+    full_reduction = new_kwargs["dim"] is None or (
+        isinstance(new_kwargs["dim"], (tuple, list)) and len(new_kwargs["dim"]) == 0
+    )
+    if full_reduction:
+        out = func(inp._values, **new_kwargs)
+        if new_kwargs.get("keepdim", False):
+            if isinstance(out, (tuple, list)):
+                # some ops return multiple things; unsqueeze all of them
+                out = type(out)(o.unsqueeze(inp._ragged_idx) for o in out)
+            else:
+                out = out.unsqueeze(inp._ragged_idx)
+        return out
+
+    # some ops support lists of dims; some don't
+    dim_to_convert = new_kwargs["dim"]
+    is_dimlist = isinstance(new_kwargs["dim"], (tuple, list))
+    if not is_dimlist:
+        dim_to_convert = [dim_to_convert]
+
+    (
+        converted_dim,
+        reduce_on_batch,
+        reduce_on_ragged,
+        reduce_on_non_batch,
+    ) = _wrap_jagged_dims(
+        inp.dim(),
+        dim_to_convert,
+        f"{func_name}",
+        inp._ragged_idx,
+    )
+
+    if not is_dimlist:
+        # convert back from list
+        converted_dim = converted_dim[0]
+    new_kwargs["dim"] = converted_dim
+
+    if reduce_on_ragged and inp._lengths is not None:
+        raise RuntimeError(
+            f"{func_name}(): reducing across the ragged dimension is not supported "
+            "for non-contiguous nested tensors with holes"
+        )
+
+    from torch.utils._pytree import tree_map
+
+    # raggedness reduced away --> return dense tensor
+    if reduce_on_ragged:
+        # reduction cases: (batch, ragged), (batch, ragged, non-batch), etc.
+        if reduce_on_batch:
+            # no need to read offsets --> apply the reduction directly on values
+            out = func(inp._values, **new_kwargs)
+            if new_kwargs.get("keepdim", False):
+                # some ops return multiple things; unsqueeze all of them
+                out = tree_map(lambda o: o.unsqueeze(0), out)
+            return out
+        else:
+            # invalid reduction cases: (ragged, non-batch), etc.
+            if reduce_on_non_batch:
+                raise RuntimeError(
+                    f"{func_name}(): reducing along a ragged and non-batch dimension "
+                    "is not supported for nested tensors"
+                )
+
+            # reduction cases: (ragged)
+            # convert to padded dense and reduce
+            new_kwargs.pop("dim")
+            dim_to_pass = [inp._ragged_idx] if is_dimlist else inp._ragged_idx
+            return func(
+                inp.to_padded_tensor(identity_element), dim=dim_to_pass, **new_kwargs
+            )
+    # raggedness preserved --> return nested tensor
+    else:
+        # invalid reduction cases: (batch), (batch, non-batch), etc.
+        if reduce_on_batch:
+            raise RuntimeError(
+                f"{func_name}(): reducing along the batch dimension but not "
+                "the ragged dimension is not supported for nested tensors"
+            )
+
+        # reduction cases: (non-batch), (non-batch, non-batch), etc.
+        # apply the reduction directly on values
+        out = func(inp._values, **new_kwargs)
+        out_kwargs = extract_kwargs(inp)
+        if not new_kwargs.get("keepdim", False):
+            # dims are reduced away -> ragged_idx of output needs to be reevaluated
+            dimlist = (
+                new_kwargs["dim"]
+                if isinstance(new_kwargs["dim"], (tuple, list))
+                else [new_kwargs["dim"]]
+            )
+            for d in dimlist:
+                # adjust for all dims reduced before the ragged dim
+                if d < inp._ragged_idx - 1:
+                    out_kwargs["_ragged_idx"] -= 1
+
+        # some ops return multiple things; wrap each of them as an NJT
+        return tree_map(lambda o: NestedTensor(o, **out_kwargs), out)
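+
+# For illustration, for a contiguous NJT of shape (B, j0, D):
+#   sum(dim=(0, 1)) reduces over batch + ragged: applied directly to values -> shape (D,)
+#   sum(dim=1)      reduces over ragged only: goes through to_padded_tensor  -> shape (B, D)
+#   sum(dim=2)      keeps raggedness: applied to values and rewrapped as an NJT (B, j0)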
+
+
+@register_jagged_func(torch.ops.aten.sum.default, "self: jt_all, dtype: any?")
+def sum_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    return func(inp._values, **new_kwargs)
+
+
+@register_jagged_func(
+    torch.ops.aten.sum.dim_IntList,
+    "self: jt_all, dim: any?, keepdim: any?, dtype: any?",
+)
+def sum_dim_IntList(func, *args, **kwargs):
+    return _apply_reduction(func, "sum", 0, *args, **kwargs)
+
+
+@register_jagged_func(
+    torch.ops.aten.transpose.int, "self: jt_all, dim0: any, dim1: any"
+)
+def transpose_int(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    from torch._prims_common import canonicalize_dims
+
+    inp = new_kwargs.pop("input")
+    dim0, dim1 = canonicalize_dims(inp.dim(), (new_kwargs["dim0"], new_kwargs["dim1"]))
+
+    # To support the SDPA API, inputs need to have the ragged idx transposed to dim 2
+    # instead of 1, although the internal Flash and memory-efficient attention
+    # implementations will use the inputs with raggedness in dim 1.
+    if dim0 == inp._ragged_idx or dim1 == inp._ragged_idx:
+        if dim0 == 0 or dim1 == 0:
+            raise ValueError(
+                "Transpose is not supported on the batch dimension for jagged NT"
+            )
+        if dim0 == inp._ragged_idx:
+            to_dim = dim1
+        else:
+            to_dim = dim0
+        inp_kwargs = extract_kwargs(inp)
+        inp_kwargs["_ragged_idx"] = to_dim
+        return NestedTensor(
+            inp.values().transpose(
+                _outer_to_inner_dim(len(inp._size), dim0, inp._ragged_idx),
+                _outer_to_inner_dim(len(inp._size), dim1, inp._ragged_idx),
+            ),
+            **inp_kwargs,
+        )
+
+    new_kwargs["dim0"] = _wrap_jagged_dim(
+        inp.dim(), new_kwargs["dim0"], inp._ragged_idx, "transpose"
+    )
+    new_kwargs["dim1"] = _wrap_jagged_dim(
+        inp.dim(), new_kwargs["dim1"], inp._ragged_idx, "transpose"
+    )
+
+    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
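+
+# For illustration: transposing dims 1 and 2 of an NJT with shape (B, j0, D) yields an NJT
+# reporting shape (B, D, j0) by moving _ragged_idx to 2 and transposing the underlying
+# values; this is the layout the SDPA API expects before its internal kernels operate on
+# the ragged-dim-1 representation.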
+
+
+@register_jagged_func(torch.ops.aten.permute.default, "self: jt_all, dims: any")
+def permute_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+    inp = new_kwargs.pop("input")
+    dims = new_kwargs.pop("dims")
+    inp_kwargs = extract_kwargs(inp)
+    inp_dim = len(inp._size)
+
+    # The first two checks are the same as the checks in the normal permute implementation
+    if inp_dim != len(dims):
+        raise ValueError(
+            f"permute(): number of dimensions in the tensor input ({inp_dim}) "
+            + f"does not match the length of the desired ordering of dimensions ({len(dims)}).",
+        )
+
+    from torch._prims_common import canonicalize_dims
+
+    canonicalized_dims = canonicalize_dims(inp_dim, dims)
+
+    if len(canonicalized_dims) != len(set(canonicalized_dims)):
+        raise ValueError("permute(): duplicate dims are not allowed.")
+
+    if inp._lengths is not None:
+        raise ValueError(
+            "permute(): not supported on jagged layout nested tensor with holes"
+        )
+    if canonicalized_dims[0] != 0:
+        raise ValueError(
+            "Permute is not supported on the batch dimension for jagged NT"
+        )
+    inp_kwargs["_ragged_idx"] = canonicalized_dims.index(inp._ragged_idx)
+    inner_dims = [
+        _outer_to_inner_dim(inp_dim, dim, inp._ragged_idx)
+        for dim in canonicalized_dims[1:]
+    ]
+    new_kwargs["dims"] = inner_dims
+    return NestedTensor(func(inp._values, **new_kwargs), **inp_kwargs)
+
+
+@register_jagged_func(
+    [torch.ops.aten.view.default, torch.ops.aten._unsafe_view.default],
+    "self: jt_all, size: any",
+)
+def view_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    size = new_kwargs.pop("size")
+
+    if inp._ragged_idx != 1 and tuple(inp._size) != tuple(size):
+        raise RuntimeError(
+            f"view(): does not support ragged_idx != 1 except when inp._size == size. "
+            f"inp._size is ({inp._size}) and size is ({size})."
+        )
+
+    # Ensure specified size still includes batch and ragged dims
+    if len(size) < 3 or not raggedness_matches(inp, size):
+        raise RuntimeError(f"view(): cannot view shape {inp._size} as {size}")
+
+    # outer size: the size of the NT, e.g. [3, j0, 10]
+    # inner size: the size of the values, e.g. [8, 10] (e.g. for offsets = [0, 3, 5, 8])
+    # this function gets inner_size[inner_idx] for a given inner_idx.
+    #
+    # example: for outer size [a, b, c, j0, d, e, f]
+    #                         assume that j0 is ragged, other are concrete integers
+    #                         and ragged_idx=3
+    # inner size will be      [b, c, inp._values.size(ragged_idx), d, e, f]
+    # therefore:
+    #    inner_size[0] = outer_size[1]
+    #    inner_size[1] = outer_size[2]
+    #    inner_size[2] = inp._values.size(ragged_idx - 1)
+    #    inner_size[3] = outer_size[4]
+    #    inner_size[4] = outer_size[5]
+    def get_inner_size(inner_idx):
+        nonlocal inp, size
+        if inner_idx == inp._ragged_idx - 1:
+            return inp._values.size(inner_idx)
+        else:
+            return size[inner_idx + 1]
+
+    inner_size = [get_inner_size(i) for i in range(len(size) - 1)]
+
+    # Preserve inference-mode-ness of input.
+    # TODO: Do this for all other views!
+    with torch.inference_mode(inp.is_inference()):
+        return NestedTensor(func(inp._values, inner_size), **extract_kwargs(inp))
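+
+# For illustration: viewing an NJT of shape (B, j0, 12) as (B, j0, 3, 4) (with -1 or the
+# nested int passed for the ragged dim) keeps the ragged dim and views the underlying
+# values of shape (sum(j0), 12) as (sum(j0), 3, 4) using the inner size computed above.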
+
+
+@register_jagged_func(
+    torch.ops.aten.native_layer_norm.default,
+    "input: jt_all, normalized_shape: any, weight: any?, bias: any?, eps: any",
+)
+def native_layer_norm_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    if inp.dim() <= 2:
+        raise RuntimeError(
+            "layer_norm(): not supported for NestedTensor objects with 2 or fewer dimensions"
+        )
+
+    normalized_shape = new_kwargs["normalized_shape"]
+    ragged_size = inp.shape[inp._ragged_idx]
+
+    num_dims_not_normalized = inp.dim() - len(normalized_shape)
+
+    if (
+        num_dims_not_normalized == 0
+    ):  # error if trying to normalize over the batch dimension
+        raise RuntimeError(
+            "layer_norm(): not supported when normalizing over the batch dimension for NestedTensor"
+        )
+
+    if ragged_size in normalized_shape and inp._lengths is not None:
+        raise RuntimeError(
+            "layer_norm(): not supported where lengths is not None if operating on the ragged dimension for NestedTensor"
+        )
+
+    if (
+        ragged_size in normalized_shape
+    ):  # special handling for normalizing over the ragged dimension
+        padded_input = torch.ops.aten._jagged_to_padded_dense_forward(
+            inp._values.flatten(
+                start_dim=inp._ragged_idx
+            ),  # _jagged_to_padded_dense_forward requires values to be a 2D tensor
+            [inp._offsets],
+            max_lengths=[inp._max_seqlen],  # max length of ragged dimension
+        )
+
+        padded_mask = torch.ops.aten._jagged_to_padded_dense_forward(
+            torch.ones((inp._values.shape[0], 1), device=inp.device, dtype=inp.dtype),
+            [inp._offsets],
+            max_lengths=[inp._max_seqlen],  # max length of ragged dimension
+        ).expand(
+            padded_input.shape
+        )  # mask elements outside of the ragged dimension and expand to the same shape as padded input (3D dense tensor)
+
+        ragged_lengths = (
+            inp._offsets.diff().unsqueeze(1).unsqueeze(1) * padded_input.shape[2]
+        )  # ragged dim * inner dim, since we sum over dims (1, 2) (the layer on which we normalize)
+
+        mean = (
+            torch.sum(
+                padded_input,
+                dim=(1, 2),
+                keepdim=True,
+            )
+            / ragged_lengths
+        )  # a sum over (1, 2) ensures layer norm, whereas a sum over (1) would be an instance norm
+
+        padded_normalized = (
+            (padded_input - mean) * padded_mask
+        )  # mask elements outside of the ragged dimension size for correct variance calculation
+
+        variance = (
+            torch.sum(
+                torch.square(padded_normalized),
+                dim=(1, 2),
+                keepdim=True,
+            )
+            / ragged_lengths
+        )  # a sum over (1, 2) ensures layer norm, whereas a sum over (1) would be an instance norm
+
+        std = torch.sqrt(variance + new_kwargs["eps"])
+        padded_layer_norm = padded_normalized / std
+
+        jagged_layer_norm_values = torch.ops.aten._padded_dense_to_jagged_forward(
+            padded_layer_norm,
+            [inp._offsets],
+            total_L=inp._values.shape[
+                0
+            ],  # providing this parameter helps avoid a GPU/CPU sync
+        ).unflatten(
+            -1, inp.shape[inp._ragged_idx + 1 :]
+        )  # unflatten last dimension back into original nested tensor shape, e.g. (B, *, WH) --> (B, *, W, H)
+
+        return (
+            NestedTensor(jagged_layer_norm_values, **extract_kwargs(inp)),
+            mean,
+            std,
+        )
+
+    output, mean, std = func(inp._values, **new_kwargs)
+    return (NestedTensor(output, **extract_kwargs(inp)), mean, std)
+
+
+@register_jagged_func(
+    torch.ops.aten.native_layer_norm_backward.default,
+    "grad_out: jt, input: jt, normalized_shape: any, mean: any, rstd: any, weight: any?, bias: any?, output_mask: any",
+)
+def native_layer_norm_backward_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+    grad_out = new_kwargs.pop("grad_out")
+    inp = new_kwargs.pop("input")
+    d_input, d_gamma, d_beta = func(grad_out._values, inp._values, **new_kwargs)
+    if d_input is None:
+        return (None, d_gamma, d_beta)
+
+    return (NestedTensor(d_input, **extract_kwargs(inp)), d_gamma, d_beta)
+
+
+@register_jagged_func(torch.ops.aten.select.int, "self: jt_all, dim: any, index: any")
+def select_int(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    new_kwargs["dim"], operating_on_batch = _wrap_jagged_dim(
+        inp.dim(), new_kwargs["dim"], inp._ragged_idx, "select", allow_batch_dim=True
+    )
+
+    # handle batch dim slicing via unbind() for now
+    # TODO: make this more efficient
+    if operating_on_batch:
+        return inp.unbind()[new_kwargs["index"]]
+
+    if inp._lengths is not None:
+        raise ValueError(
+            "select(): not yet supported on dim != 0 for non-contiguous nested tensor with holes"
+        )
+
+    # if selecting before the ragged dim, adjust output ragged_idx
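+    # e.g. (illustrative): for an NJT of shape (B, C, j, D) with ragged_idx == 2,
+    # select on dim=1 removes the C dim, giving shape (B, j, D) with ragged_idx == 1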
+    out_kwargs = extract_kwargs(inp)
+    if new_kwargs["dim"] < inp._ragged_idx - 1:
+        out_kwargs["_ragged_idx"] -= 1
+
+    return NestedTensor(func(inp._values, **new_kwargs), **out_kwargs)
+
+
+@register_jagged_func(
+    torch.ops.aten.slice.Tensor,
+    "self: jt, dim: any?, start: any?, end: any?, step: any?",
+)
+def slice_tensor(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    new_kwargs["dim"] = _wrap_jagged_dim(
+        inp.dim(), new_kwargs["dim"], inp._ragged_idx, "slice"
+    )
+
+    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
+
+
+@register_jagged_func(
+    torch.ops.aten.index_put.default,
+    "input: jt_all, indices: any, values: t, accumulate: any?",
+)
+@register_jagged_func(
+    torch.ops.aten.index_put_.default,
+    "input: jt_all, indices: any, values: t, accumulate: any?",
+)
+def index_put_(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp: NestedTensor = new_kwargs.pop("input")
+
+    # For index_put_ to work, we add together the indices of the ragged dimension
+    # and the batch dimension, adding the offsets of each ragged dimension to its
+    # indices
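+    # e.g. (illustrative): with offsets [0, 2, 5], the element at (batch=1,
+    # ragged_index=1) lives at row offsets[1] + 1 == 3 of the packed values buffer,
+    # which is the index computed against inp._values below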
+
+    indices = new_kwargs.pop("indices")
+
+    assert len(indices) <= inp.dim()
+
+    if len(indices) < inp._ragged_idx + 1:
+        if not inp.is_contiguous():
+            raise RuntimeError(
+                "index_put(): If ragged dimension is not part of indices, this only works on contiguous NJTs"
+            )
+        # Ragged dim is NOT part of indices, we need to pad the nested tensor to apply func
+        from .nested_tensor import nested_from_padded
+
+        min_seqlen = inp._maybe_min_seqlen
+        max_seqlen = inp._maybe_max_seqlen
+        padded_max_S = max_seqlen
+        total_L = inp._values.shape[inp._ragged_idx - 1]
+        if padded_max_S is None:
+            # use upper bound on max seqlen if it's not present
+            padded_max_S = total_L
+
+        padded_shape = (
+            *inp.shape[: inp._ragged_idx],
+            padded_max_S,
+            *inp.shape[inp._ragged_idx + 1 :],
+        )
+        padded_inp = inp.to_padded_tensor(0.0, output_size=padded_shape)
+        new_njt = nested_from_padded(
+            func(padded_inp, indices, **new_kwargs),
+            offsets=inp._offsets,
+            ragged_idx=inp._ragged_idx,
+            sum_S=total_L,
+            min_seqlen=min_seqlen,
+            max_seqlen=max_seqlen,
+        )
+
+        if func == torch.ops.aten.index_put_.default:
+            inp._values.copy_(new_njt.values())
+            return inp
+        return new_njt
+
+    # We can run on the underlying values directly
+
+    # Validate indices
+    if inp.lengths() is None:
+        lengths = inp.offsets().diff()
+    else:
+        lengths = inp.lengths()
+    torch._assert_async(
+        torch.all(indices[inp._ragged_idx] < lengths),
+        "Some indices in the ragged dimension are out of bounds!",
+    )
+
+    # Recompute indices for _values
+    ragged_indices = inp.offsets()[indices[0]] + indices[inp._ragged_idx]
+    func_indices = (
+        # before ragged dim
+        indices[1 : inp._ragged_idx]
+        # ragged dim (combined with batch)
+        + [ragged_indices]
+        # after ragged dim
+        + indices[inp._ragged_idx + 1 :]
+    )
+
+    if func == torch.ops.aten.index_put_.default:
+        inp._values = func(inp._values, func_indices, **new_kwargs)
+        return inp
+
+    return NestedTensor(
+        func(inp._values, func_indices, **new_kwargs),
+        **extract_kwargs(inp),
+    )
+
+
+@register_jagged_func(
+    torch.ops.aten.convolution.default,
+    "input: jt, weight: t, bias: t?, stride: any, padding: any, "
+    "dilation: any, transposed: any, output_padding: any, groups: any",
+)
+def convolution_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
+
+
+@register_jagged_func(
+    torch.ops.aten.mean.dim, "self: jt_all, dim: any?, keepdim: any?, dtype: any?"
+)
+def mean_dim(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs["input"]
+    (_, reduce_on_batch, reduce_on_ragged, reduce_on_non_batch) = _wrap_jagged_dims(
+        inp.dim(),
+        new_kwargs["dim"],
+        "mean",
+        inp._ragged_idx,
+    )
+
+    if reduce_on_ragged and not reduce_on_batch:
+        assert not reduce_on_non_batch
+        # calculate an intermediate sum and leave the dim in for normalization purposes
+        keepdim = new_kwargs["keepdim"]
+        new_kwargs["keepdim"] = True
+        intermediate_sum = _apply_reduction(
+            torch.ops.aten.sum.dim_IntList, "mean", 0, **new_kwargs
+        )
+
+        # normalize by sequence lengths
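+        # e.g. (illustrative): for per-sample lengths [2, 3], each batch element's
+        # ragged-dim sum is divided by its own length (2 or 3), giving a true
+        # per-sample mean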
+        lengths = inp._lengths if inp._lengths is not None else inp._offsets.diff()
+        for _ in range(intermediate_sum.dim() - 1):
+            lengths = lengths.unsqueeze(-1)
+        out = intermediate_sum / lengths
+        if not keepdim:
+            out = out.squeeze(inp._ragged_idx)
+        return out
+
+    # at this point, we're just redispatching on the values buffer; since the
+    # identity value is expected to be unused, specify an unusual intermediate
+    # value so that any accidental use of it makes errors obvious
+    intermediate_value = 0.42
+    return _apply_reduction(func, "mean", intermediate_value, **new_kwargs)
+
+
+@register_jagged_func(torch.ops.aten.mean.default, "self: jt_all, dtype: any?")
+def mean_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    return func(inp._values, **new_kwargs)
+
+
+@register_jagged_func(torch.ops.aten.any.dims, "self: jt_all, dim: any?, keepdim: any?")
+def any_dims(func, *args, **kwargs):
+    return _apply_reduction(func, "any", False, *args, **kwargs)
+
+
+@register_jagged_func(torch.ops.aten.any.dim, "self: jt_all, dim: any, keepdim: any?")
+def any_dim(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    # wrap dim in list to redispatch to dims overload
+    new_kwargs["dim"] = [new_kwargs["dim"]]
+    return any_dims(torch.ops.aten.any.dims, **new_kwargs)
+
+
+@register_jagged_func(torch.ops.aten.all.dims, "self: jt_all, dim: any?, keepdim: any?")
+def all_dims(func, *args, **kwargs):
+    return _apply_reduction(func, "all", True, *args, **kwargs)
+
+
+@register_jagged_func(torch.ops.aten.all.dim, "self: jt_all, dim: any, keepdim: any?")
+def all_dim(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    # wrap dim in list to redispatch to dims overload
+    new_kwargs["dim"] = [new_kwargs["dim"]]
+    return all_dims(torch.ops.aten.all.dims, **new_kwargs)
+
+
+@register_jagged_func(
+    [
+        torch.ops.aten.all.default,
+        torch.ops.aten.any.default,
+        torch.ops.aten.max.default,
+        torch.ops.aten.min.default,
+    ],
+    "self: jt_all",
+)
+def all_any_max_min_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    return func(inp._values, **new_kwargs)
+
+
+@register_jagged_func(torch.ops.aten.min.dim, "self: jt_all, dim: any, keepdim: any?")
+def min_dim(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    dtype_max = torch.finfo(new_kwargs["input"].dtype).max
+    return _apply_reduction(func, "min", dtype_max, *args, **kwargs)
+
+
+@register_jagged_func(torch.ops.aten.max.dim, "self: jt_all, dim: any, keepdim: any?")
+def max_dim(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    dtype_min = torch.finfo(new_kwargs["input"].dtype).min
+    return _apply_reduction(func, "max", dtype_min, *args, **kwargs)
+
+
+@register_jagged_func(
+    torch.ops.aten.amin.default, "self: jt_all, dim: any?, keepdim: any?"
+)
+def amin_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    dtype_max = torch.finfo(new_kwargs["input"].dtype).max
+    return _apply_reduction(func, "amin", dtype_max, *args, **kwargs)
+
+
+@register_jagged_func(
+    torch.ops.aten.amax.default, "self: jt_all, dim: any?, keepdim: any?"
+)
+def amax_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    dtype_min = torch.finfo(new_kwargs["input"].dtype).min
+    return _apply_reduction(func, "amax", dtype_min, *args, **kwargs)
+
+
+@register_jagged_func(
+    torch.ops.aten.argmin.default, "self: jt_all, dim: any?, keepdim: any?"
+)
+def argmin_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    dtype_max = torch.finfo(new_kwargs["input"].dtype).max
+    return _apply_reduction(func, "argmin", dtype_max, *args, **kwargs)
+
+
+@register_jagged_func(
+    torch.ops.aten.argmax.default, "self: jt_all, dim: any?, keepdim: any?"
+)
+def argmax_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    dtype_min = torch.finfo(new_kwargs["input"].dtype).min
+    return _apply_reduction(func, "argmax", dtype_min, *args, **kwargs)
+
+
+@register_jagged_func(
+    torch.ops.aten.value_selecting_reduction_backward.default,
+    "grad: jt_all, dim: any, indices: jt_all, sizes: any, keepdim: any",
+)
+def value_selecting_reduction_backward_default(func, *args, **kwargs):
+    from torch.fx.experimental.symbolic_shapes import is_nested_int
+
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    grad = new_kwargs.pop("grad")
+    new_kwargs["grad"] = grad._values
+    indices = new_kwargs.pop("indices")
+    new_kwargs["indices"] = indices._values
+    # should always succeed; sizes should contain a nested int
+    ragged_idx = next(i for i, s in enumerate(new_kwargs["sizes"]) if is_nested_int(s))
+    # convert dim -> values-space dim
+    new_kwargs["dim"] = _wrap_jagged_dim(
+        len(new_kwargs["sizes"]),
+        new_kwargs["dim"],
+        ragged_idx,
+        "value_selecting_reduction_backward",
+    )
+    # convert saved NJT sizes -> values-space sizes
+    sizes = new_kwargs.pop("sizes")
+    sizes[ragged_idx] = indices._values.size(indices._ragged_idx - 1)
+    sizes = sizes[1:]
+    new_kwargs["sizes"] = sizes
+
+    output_kwargs = extract_kwargs(indices)
+    output_kwargs["_ragged_idx"] = ragged_idx
+
+    return NestedTensor(func(**new_kwargs), **output_kwargs)
+
+
+@register_jagged_func(torch.ops.aten.stack.default, "tensors: any, dim: any")
+def stack_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    # guaranteed this is non-empty if we got here
+    tensors = new_kwargs.pop("tensors")
+    for t in tensors:
+        if not isinstance(t, NestedTensor):
+            raise RuntimeError("stack(): expected all nested tensors inputs")
+
+        if t.dim() != tensors[0].dim():
+            raise RuntimeError(
+                "stack(): expected all nested tensors to have the same dim"
+            )
+
+        if not raggedness_matches(t, tensors[0].shape):
+            raise RuntimeError(
+                "stack(): expected all nested tensors to have the same nested structure"
+            )
+
+    new_kwargs["dim"] = _wrap_jagged_dim(
+        tensors[0].dim() + 1, new_kwargs["dim"], tensors[0]._ragged_idx, "stack"
+    )
+
+    return NestedTensor(
+        func([t._values for t in tensors], **new_kwargs), **extract_kwargs(tensors[0])
+    )
+
+
+@register_jagged_func(
+    torch.ops.aten.embedding.default,
+    "weight: t, indices: jt, padding_idx: any?, scale_grad_by_freq: any?, sparse: any?",
+)
+def embedding_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    # guaranteed this is non-empty if we got here
+    indices = new_kwargs.pop("indices")
+    weight = new_kwargs.pop("weight")
+
+    return NestedTensor(
+        func(weight, indices._values, **new_kwargs), **extract_kwargs(indices)
+    )
+
+
+@register_jagged_func(
+    torch.ops.aten.embedding_dense_backward.default,
+    "grad_output: jt, indices: jt, num_weights: any, padding_idx: any, scale_grad_by_freq: any",
+)
+def embedding_dense_backward_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    indices = new_kwargs.pop("indices")
+    grad_output = new_kwargs.pop("grad_output")
+    return func(grad_output._values, indices._values, **new_kwargs)
+
+
+@register_jagged_func(
+    [
+        torch.ops.aten.values.default,
+        torch.ops.aten._nested_get_values.default,
+    ],
+    "self: jt_all",
+)
+def values_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    # TODO: Handle inference mode properly.
+    # See https://github.com/pytorch/pytorch/issues/112024#issuecomment-1779554292
+    return inp._values.detach()
+
+
+@register_jagged_func(torch.ops.aten.all.default, "self: jt_all")
+def all_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    return func(inp._values)
+
+
+@register_jagged_func(
+    torch.ops.aten.to_padded_tensor.default,
+    "self: jt_all, padding: any, output_size: any?",
+)
+def to_padded_tensor_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    if inp._lengths is not None:
+        raise RuntimeError(
+            "to_padded_tensor(): not supported for nested tensors with holes"
+        )
+
+    # TODO: Handle the rest of output_size
+    output_size = new_kwargs["output_size"]
+    if output_size is not None:
+        max_seq_len = output_size[inp._ragged_idx]
+    else:
+        max_seq_len = (
+            inp._max_seqlen
+            if inp._max_seqlen_tensor is not None
+            else inp._values.size(0)
+        )
+
+    # only 2D values with a packed ragged dim=0 are supported by the underlying FBGEMM
+    # kernel, so do shape gymnastics if needed
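+    # e.g. (illustrative): values of shape (sum_L, H, D) with ragged_idx == 1 are
+    # flattened to (sum_L, H * D) before the kernel call, and the padded output is
+    # unflattened back to (B, max_seq_len, H, D) in "part 2" below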
+    values = inp.values()
+    if inp._ragged_idx > 1:
+        values = values.transpose(inp._ragged_idx - 1, 0)
+    values_shape = values.shape
+    if values.dim() > 2:
+        values = values.flatten(start_dim=1)
+    elif values.dim() == 1:
+        values = values.unsqueeze(-1)
+
+    # NB: The CUDA kernel for jagged -> padded dense conversion does not support
+    # integer / bool types; work around this by casting to half.
+    is_bool = values.dtype is torch.bool
+    if is_bool and values.is_cuda:
+        values = values.to(torch.half)
+    padded_out = torch.ops.aten._jagged_to_padded_dense_forward(
+        values,
+        [inp._offsets],
+        [max_seq_len],
+        new_kwargs["padding"],
+    )
+    if is_bool and padded_out.is_cuda:
+        padded_out = padded_out.to(torch.bool)
+
+    # shape gymnastics part 2
+    if len(values_shape) > 2:
+        padded_out = padded_out.unflatten(-1, values_shape[1:])
+    elif len(values_shape) == 1:
+        padded_out = padded_out.squeeze(-1)
+    if inp._ragged_idx > 1:
+        padded_out = padded_out.transpose(inp._ragged_idx, 1)
+
+    return padded_out
+
+
+@register_jagged_func(
+    torch.ops.aten._nested_from_padded_tensor.default,
+    "padded: t, offsets: t, dummy: jt, ragged_idx: any?, min_seqlen: any?, max_seqlen: any?, sum_S: any?",
+)
+def _nested_from_padded_tensor_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    padded, offsets = new_kwargs["padded"], new_kwargs["offsets"]
+    ragged_idx = new_kwargs.get("ragged_idx", 1)
+
+    # only a 3D padded tensor with a packed ragged dim=0 is supported by the underlying
+    # FBGEMM kernel, so do shape gymnastics
+    if ragged_idx > 1:
+        padded = padded.transpose(ragged_idx, 1)
+    padded_ragged_dim1_shape = padded.shape
+    if padded.dim() > 3:
+        padded = padded.flatten(start_dim=2)
+    elif padded.dim() < 3:
+        padded = padded.unsqueeze(-1)
+
+    # NB: The CUDA kernel for padded dense -> jagged conversion does not support
+    # integer / bool types; work around this by casting to half.
+    is_bool = padded.dtype is torch.bool
+    if is_bool and padded.is_cuda:
+        padded = padded.to(torch.half)
+    values = torch.ops.aten._padded_dense_to_jagged_forward(
+        padded, [offsets], new_kwargs["sum_S"]
+    )
+    if is_bool and values.is_cuda:
+        values = values.to(torch.bool)
+
+    # shape gymnastics part 2
+    if len(padded_ragged_dim1_shape) > 3:
+        values = values.unflatten(-1, padded_ragged_dim1_shape[2:])
+    elif len(padded_ragged_dim1_shape) < 3:
+        values = values.squeeze(-1)
+    if ragged_idx > 1:
+        values = values.transpose(ragged_idx - 1, 0)
+
+    min_seqlen = new_kwargs["min_seqlen"]
+    max_seqlen = new_kwargs["max_seqlen"]
+    metadata_cache = {}
+    if min_seqlen is not None:
+        metadata_cache["min_seqlen"] = min_seqlen
+    if max_seqlen is not None:
+        metadata_cache["max_seqlen"] = max_seqlen
+
+    return NestedTensor(
+        values,
+        offsets,
+        _ragged_idx=ragged_idx,
+        _metadata_cache=metadata_cache,
+    )
+
+
+@register_jagged_func(
+    torch.ops.aten._nested_view_from_jagged.default,
+    "values: t, offsets: t, dummy: jt_all, lengths: t?, ragged_idx: any?, min_seqlen: t?, max_seqlen: t?",
+)
+def _nested_view_from_jagged_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    values, offsets, lengths = (
+        new_kwargs["input"],
+        new_kwargs["offsets"],
+        new_kwargs["lengths"],
+    )
+    ragged_idx = new_kwargs["ragged_idx"]
+    min_seqlen = new_kwargs["min_seqlen"]
+    max_seqlen = new_kwargs["max_seqlen"]
+    metadata_cache = {}
+    if min_seqlen is not None:
+        metadata_cache["min_seqlen"] = min_seqlen
+    if max_seqlen is not None:
+        metadata_cache["max_seqlen"] = max_seqlen
+
+    return NestedTensor(
+        values,
+        offsets,
+        lengths=lengths,
+        _ragged_idx=ragged_idx,
+        _metadata_cache=metadata_cache,
+    )
+
+
+@register_jagged_func(torch.ops.aten._nested_get_offsets.default, "self: jt_all")
+def _nested_get_offsets(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    return inp._offsets
+
+
+@register_jagged_func(torch.ops.aten._nested_get_lengths.default, "self: jt_all")
+def _nested_get_lengths(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    return inp._lengths
+
+
+@register_jagged_func(torch.ops.aten._nested_get_ragged_idx.default, "self: jt_all")
+def _nested_get_ragged_idx(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    return inp._ragged_idx
+
+
+@register_jagged_func(torch.ops.aten._nested_get_min_seqlen.default, "self: jt_all")
+def _nested_get_min_seqlen(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    return inp._metadata_cache.get("min_seqlen", None)
+
+
+@register_jagged_func(torch.ops.aten._nested_get_max_seqlen.default, "self: jt_all")
+def _nested_get_max_seqlen(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    return inp._metadata_cache.get("max_seqlen", None)
+
+
+# If a section of the NestedTensor is fully masked out, we still retain the section with a length of 0
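+# e.g. (illustrative): for offsets [0, 2, 4] and a flat mask [T, F, F, T], the new
+# offsets become padded_cumsum(mask)[offsets] == [0, 1, 2], i.e. each sequence keeps
+# exactly the elements its portion of the mask selects (here, one element each)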
+@register_jagged_func(torch.ops.aten.masked_select.default, "self: jt, mask: any")
+def masked_select_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+    inp = new_kwargs.pop("input")
+    mask = new_kwargs.pop("mask")
+
+    if inp.ndim > 2:
+        raise RuntimeError("masked_select only support 2-D selections currently")
+    elif inp.shape != mask.shape:
+        raise RuntimeError(
+            f"Mask with shape {mask.shape} is not compatible with input's shape {inp.shape}"
+        )
+    res_values = inp._values.masked_select(mask.values())
+    mask_cumsum = F.pad(mask.values().cumsum(dim=0), (1, 0))  # type: ignore[arg-type]
+
+    args = extract_kwargs(inp)
+    args["offsets"] = mask_cumsum[inp._offsets]
+    return NestedTensor(
+        values=res_values,
+        **args,
+    )
+
+
+@register_jagged_func(
+    torch.ops.aten._nested_select_backward.default,
+    "grad_output: t, self: jt_all, dim: any, index: any",
+)
+def _nested_select_backward_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    grad_output = new_kwargs.pop("grad_output")
+
+    grad_input = torch.zeros_like(inp, dtype=grad_output.dtype)
+    grad_input.select(new_kwargs["dim"], new_kwargs["index"]).copy_(grad_output)
+
+    return grad_input
+
+
+@register_jagged_func(torch.ops.aten.record_stream.default, "self: jt_all, s: any")
+def record_stream_default(func, *args, **kwargs):
+    inp = args[0]
+    stream = args[1]
+    # ensure all components live until stream computation completes
+    func(inp._values, stream)
+    func(inp._offsets, stream)
+    if inp._lengths is not None:
+        func(inp._lengths, stream)
+
+
+@register_jagged_func(
+    [
+        torch.ops.aten.new_empty.default,
+        torch.ops.aten.new_zeros.default,
+        torch.ops.aten.new_ones.default,
+    ],
+    "self: jt_all, size: any, dtype: any?, layout: any?, device: any?, pin_memory: any?",
+)
+def new_empty_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    if len(new_kwargs["size"]) == 0:
+        return func(inp._values, **new_kwargs)
+
+    raise RuntimeError("new_empty() not supported for NJT with shape != ()")
+
+
+@register_jagged_func(
+    [
+        torch.ops.aten.elu_backward.default,
+        torch.ops.aten.hardshrink_backward.default,
+        torch.ops.aten.hardsigmoid_backward.default,
+        torch.ops.aten.hardtanh_backward.default,
+        torch.ops.aten.softplus_backward.default,
+        torch.ops.aten.softshrink_backward.default,
+    ],
+    "self: jt_all, ...",
+)
+def activation_backward(func, *args, **kwargs):
+    # first NJT arg is expected to be grad_output
+    grad_output = next(arg for arg in args if isinstance(arg, NestedTensor))
+    return NestedTensor(
+        func(
+            *(arg._values if isinstance(arg, NestedTensor) else arg for arg in args),
+            **kwargs,
+        ),
+        **extract_kwargs(grad_output),
+    )
+
+
+@register_jagged_func(torch.ops.aten.fill.Scalar, "self: jt_all, value: any")
+def fill_Scalar(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
+
+
+@register_jagged_func(torch.ops.aten.fill_.Scalar, "self: jt_all, value: any")
+def fill__Scalar(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+
+    func(inp._values, **new_kwargs)
+    return inp
+
+
+@register_jagged_func(torch.ops.aten.frexp.Tensor, "self: jt_all")
+def frexp_Tensor(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    output_kwargs = extract_kwargs(inp)
+
+    mantissa, exponent = func(inp._values)
+    return NestedTensor(mantissa, **output_kwargs), NestedTensor(
+        exponent, **output_kwargs
+    )
+
+
+@register_jagged_func(
+    torch.ops.aten.matmul_backward.default,
+    "grad: any, self: any, other: any, mask: any",
+)
+def matmul_backward_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    grad = new_kwargs.pop("grad")
+    inp = new_kwargs.pop("input")
+    other = new_kwargs.pop("other")
+    grad_input_mask = new_kwargs.pop("mask")
+
+    if grad is None:
+        return (None, None)
+
+    grad_self = None
+    if grad_input_mask[0]:
+        grad_self = torch.matmul(grad, other.transpose(-1, -2))
+
+    grad_other = None
+    if grad_input_mask[1]:
+        grad_other = torch.matmul(inp.transpose(-1, -2), grad)
+
+    return (grad_self, grad_other)
+
+
+from torch._higher_order_ops.flex_attention import (
+    flex_attention as flex_attention_hop,
+    flex_attention_backward as flex_attention_backward_hop,
+)
+from torch.fx.graph_module import GraphModule
+
+
+@flex_attention_hop.py_impl(NestedTensor)  # type: ignore[misc]
+def flex_njt(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    score_mod: Callable,
+    block_mask: Tuple,
+    scale: float,
+    kernel_options: Dict[str, Any],
+    score_mod_other_buffers: Tuple = (),
+    mask_mod_other_buffers: Tuple = (),
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    assert query.dim() == 4 and key.dim() == 4 and value.dim() == 4
+
+    # TODO: Support this if needed; determine if NJT buffers need to be unwrapped as dense.
+    if any(
+        isinstance(buf, torch.Tensor) and buf.is_nested
+        for buf in score_mod_other_buffers + mask_mod_other_buffers
+    ):
+        raise RuntimeError(
+            "flex_attention(): Nested tensor score_mod / mask_mod buffers are not "
+            "currently supported. Please file an issue if this is important to you."
+        )
+
+    # need to pass dense tensor of shape (B, n_heads, sum(seq_len), D)
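+    # (the batch of variable-length sequences is packed into a single dense "batch"
+    # along the sequence dim; the ragged structure is recovered below via
+    # nested_tensor_from_jagged using the query's offsets / lengths)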
+    output = flex_attention_hop(
+        query.values().unsqueeze(0),
+        key.values().unsqueeze(0),
+        value.values().unsqueeze(0),
+        score_mod=score_mod,
+        block_mask=block_mask,
+        scale=scale,
+        kernel_options=kernel_options,
+        score_mod_other_buffers=score_mod_other_buffers,
+        mask_mod_other_buffers=mask_mod_other_buffers,
+    )
+
+    # wrap outputs as NJT
+    output_njt = torch.nested.nested_tensor_from_jagged(
+        output[0].transpose(1, 2).squeeze(0),
+        query._offsets,  # type: ignore[attr-defined]
+        query._lengths,  # type: ignore[attr-defined]
+        min_seqlen=query._maybe_min_seqlen,  # type: ignore[attr-defined]
+        max_seqlen=query._maybe_max_seqlen,  # type: ignore[attr-defined]
+    ).transpose(1, 2)
+
+    logsumexp_njt = torch.nested.nested_tensor_from_jagged(
+        output[1].transpose(1, 2).squeeze(0),
+        query._offsets,  # type: ignore[attr-defined]
+        query._lengths,  # type: ignore[attr-defined]
+        min_seqlen=query._maybe_min_seqlen,  # type: ignore[attr-defined]
+        max_seqlen=query._maybe_max_seqlen,  # type: ignore[attr-defined]
+    ).transpose(1, 2)
+
+    return (output_njt, logsumexp_njt)
+
+
+@flex_attention_backward_hop.py_impl(NestedTensor)  # type: ignore[misc]
+def flex_njt_backward(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    out: torch.Tensor,
+    logsumexp: torch.Tensor,
+    grad_out: torch.Tensor,
+    grad_logsumexp: torch.Tensor,
+    fw_graph: Union[Callable, GraphModule],
+    joint_graph: GraphModule,
+    block_mask: Tuple,
+    scale: float,
+    kernel_options: Dict[str, Any],
+    score_mod_other_buffers: Tuple = (),
+    mask_mod_other_buffers: Tuple = (),
+) -> Tuple[
+    torch.Tensor, torch.Tensor, torch.Tensor, Tuple[Optional[torch.Tensor], ...]
+]:
+    output = flex_attention_backward_hop(
+        query.values().unsqueeze(0),
+        key.values().unsqueeze(0),
+        value.values().unsqueeze(0),
+        out=out.values().unsqueeze(0),
+        logsumexp=logsumexp.values().unsqueeze(0),
+        grad_out=grad_out.values().unsqueeze(0),
+        grad_logsumexp=grad_logsumexp.values().unsqueeze(0),
+        fw_graph=fw_graph,
+        joint_graph=joint_graph,
+        block_mask=block_mask,
+        scale=scale,
+        kernel_options=kernel_options,
+        score_mod_other_buffers=score_mod_other_buffers,
+        mask_mod_other_buffers=mask_mod_other_buffers,
+    )
+
+    # wrap grads as NJTs
+    dense_q_grad, dense_k_grad, dense_v_grad, score_mod_other_buffer_grads = output
+    njt_q_grad = torch.nested.nested_tensor_from_jagged(
+        dense_q_grad.transpose(1, 2).squeeze(0),
+        query._offsets,  # type: ignore[attr-defined]
+        query._lengths,  # type: ignore[attr-defined]
+        min_seqlen=query._maybe_min_seqlen,  # type: ignore[attr-defined]
+        max_seqlen=query._maybe_max_seqlen,  # type: ignore[attr-defined]
+    ).transpose(1, 2)
+    njt_k_grad = torch.nested.nested_tensor_from_jagged(
+        dense_k_grad.transpose(1, 2).squeeze(0),
+        key._offsets,  # type: ignore[attr-defined]
+        key._lengths,  # type: ignore[attr-defined]
+        min_seqlen=key._maybe_min_seqlen,  # type: ignore[attr-defined]
+        max_seqlen=key._maybe_max_seqlen,  # type: ignore[attr-defined]
+    ).transpose(1, 2)
+    njt_v_grad = torch.nested.nested_tensor_from_jagged(
+        dense_v_grad.transpose(1, 2).squeeze(0),
+        value._offsets,  # type: ignore[attr-defined]
+        value._lengths,  # type: ignore[attr-defined]
+        min_seqlen=value._maybe_min_seqlen,  # type: ignore[attr-defined]
+        max_seqlen=value._maybe_max_seqlen,  # type: ignore[attr-defined]
+    ).transpose(1, 2)
+
+    return (njt_q_grad, njt_k_grad, njt_v_grad, score_mod_other_buffer_grads)
+
+
+# Make the dummy available on the C++ side.
+@register_jagged_func(torch.ops.aten._nested_get_jagged_dummy.default, "self: any")
+def _nested_get_jagged_dummy(func, *args, **kwargs):
+    from torch.nested._internal.nested_tensor import _nt_view_dummy
+
+    return _nt_view_dummy()
+
+
+with torch.library._scoped_library("aten", "IMPL") as aten:
+    aten.impl("_nested_get_jagged_dummy", _nested_get_jagged_dummy, "CPU")
+    aten.impl("_nested_get_jagged_dummy", _nested_get_jagged_dummy, "CUDA")
+    aten.impl("_nested_get_jagged_dummy", _nested_get_jagged_dummy, "Meta")
diff --git a/.venv/lib/python3.12/site-packages/torch/nested/_internal/sdpa.py b/.venv/lib/python3.12/site-packages/torch/nested/_internal/sdpa.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ac4cc86a58cc3d7bd385a9290e1ec07ab6f8fc2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/nested/_internal/sdpa.py
@@ -0,0 +1,934 @@
+# mypy: allow-untyped-defs
+import logging
+from typing import Optional
+
+import torch
+import torch.nn
+import torch.nn.functional as F
+from torch.backends.cuda import (
+    can_use_cudnn_attention,
+    can_use_efficient_attention,
+    can_use_flash_attention,
+    cudnn_sdp_enabled,
+    flash_sdp_enabled,
+    math_sdp_enabled,
+    mem_efficient_sdp_enabled,
+    SDPAParams,
+)
+from torch.nn.attention import SDPBackend
+
+from .nested_tensor import NestedTensor
+
+
+log = logging.getLogger(__name__)
+
+
+def _validate_sdpa_input(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    dropout_p=0.0,
+    is_causal=False,
+    scale=None,
+):
+    if (
+        not isinstance(query, NestedTensor)
+        or not isinstance(key, NestedTensor)
+        or not isinstance(value, NestedTensor)
+    ):
+        raise ValueError(
+            f"Expected query, key, and value to be nested tensors, "
+            f"but got query.is_nested: {query.is_nested}, key.is_nested: {key.is_nested}, "
+            f"and value.is_nested: {value.is_nested} instead."
+        )
+    if query.dtype != key.dtype or query.dtype != value.dtype:
+        raise ValueError(
+            f"Expected query, key, and value to have the same dtype, "
+            f"but got query.dtype: {query.dtype}, key.dtype: {key.dtype}, "
+            f"and value.dtype: {value.dtype} instead."
+        )
+    if query.device != key.device or query.device != value.device:
+        raise ValueError(
+            f"Expected query, key, and value to have the same device type, "
+            f"but got query.device: {query.device}, key.device: {key.device}, "
+            f"and value.device: {value.device} instead."
+        )
+    if query.dim() < 3 or key.dim() < 3 or value.dim() < 3:
+        raise ValueError(
+            f"Expected query, key, and value to all be  at least 3 dimensional, but got query.dim: "
+            f"{query.dim()}, key.dim: {key.dim()} and value.dim: {value.dim()} instead."
+        )
+    if query._ragged_idx != key._ragged_idx or query._ragged_idx != value._ragged_idx:
+        raise ValueError(
+            f"Expected query, key, and value to all be ragged on the same dimension, but got ragged "
+            f"dims {query._ragged_idx}, {key._ragged_idx}, and {value._ragged_idx}, respectively."
+        )
+    if attn_mask is not None:
+        # TODO: Figure out whether masks are actually supported for this layout or not
+        raise ValueError("Masks are not yet supported!")
+        if attn_mask.dtype != torch.bool and attn_mask.dtype != query.dtype:
+            raise ValueError(
+                f"Expected attn_mask dtype to be bool or to match query dtype, but got attn_mask.dtype: "
+                f"{attn_mask.dtype}, and query.dtype: {query.dtype} instead."
+            )
+
+
+def _check_batch_size_nested(params: SDPAParams, debug=False) -> bool:
+    # This is expected to be called after check_tensor_shapes ensuring that the
+    # size() calls won't error since the inputs are all 4 dimensional
+    q_batch_size = params.query.size(0)
+    k_batch_size = params.key.size(0)
+    v_batch_size = params.value.size(0)
+
+    # num_heads logic for nested input is checked in
+    # check_for_seq_len_0_nested_tensor as there is handling there to make sure
+    # num_heads is not ragged
+    return q_batch_size == k_batch_size and q_batch_size == v_batch_size
+
+
+def _check_head_dim_size_flash_nested(params: SDPAParams, debug=False) -> bool:
+    max_size = 256
+    query_size_last = params.query.size(-1)
+    key_size_last = params.key.size(-1)
+    value_size_last = params.value.size(-1)
+    same_head_dim_size = (
+        query_size_last == key_size_last and query_size_last == value_size_last
+    )
+    if not (
+        same_head_dim_size
+        and (query_size_last % 8 == 0)
+        and (query_size_last <= max_size)
+    ):
+        if debug:
+            log.warning(
+                "For NestedTensor inputs, Flash attention requires q,k,v to have the same "
+                "last dimension and to be a multiple of 8 and less than or equal to 256. "
+                "Got Query.size(-1): %d, Key.size(-1): %d, Value.size(-1): %d instead.",
+                query_size_last,
+                key_size_last,
+                value_size_last,
+            )
+        return False
+    return True
+
+
+def _check_head_dim_size_cudnn_nested(params: SDPAParams, debug=False) -> bool:
+    max_size = 128
+    query_size_last = params.query.size(-1)
+    key_size_last = params.key.size(-1)
+    value_size_last = params.value.size(-1)
+    same_head_dim_size = (
+        query_size_last == key_size_last and query_size_last == value_size_last
+    )
+    if not (
+        same_head_dim_size
+        and (query_size_last % 8 == 0)
+        and (query_size_last <= max_size)
+    ):
+        if debug:
+            log.warning(
+                "For NestedTensor inputs, cuDNN attention requires q,k,v to have the same "
+                "last dimension and to be a multiple of 8 and less than or equal to 128. "
+                "Got Query.size(-1): %d, Key.size(-1): %d, Value.size(-1): %d instead.",
+                query_size_last,
+                key_size_last,
+                value_size_last,
+            )
+        return False
+    return True
+
+
+def _check_for_seq_len_0_and_consistent_head_dim_nested_helper(
+    param: torch.Tensor, param_name: str, debug=False
+) -> bool:
+    assert isinstance(param, NestedTensor), "param should be a jagged NT"
+
+    if param._ragged_idx == 1:
+        # num_head_dims is ragged
+        if debug:
+            log.warning(
+                "Fused kernels do not support ragged num_head_dims, %s has a ragged num_heads.",
+                param_name,
+            )
+        return False
+
+    # This is being called inside sdp with shape [batch, heads, {seq_len}, dim]
+    if param._get_min_seqlen() == 0:
+        if debug:
+            log.warning(
+                "Fused kernels do not support seq_len == 0, %s has a seq len of 0.",
+                param_name,
+            )
+        return False
+
+    return True
+
+
+def _try_broadcast_param_size(q_size, k_size, v_size, param_name, debug=False) -> bool:
+    max_size = max(q_size, k_size, v_size)
+    if (
+        (q_size != max_size and q_size != 1)
+        or (k_size != max_size and k_size != 1)
+        or (v_size != max_size and v_size != 1)
+    ):
+        if debug:
+            log.warning(
+                "Both fused kernels require query, key and value to have broadcastable %s, "
+                "got Query %s %d, Key %s %d, Value %s %d instead.",
+                param_name,
+                param_name,
+                q_size,
+                param_name,
+                k_size,
+                param_name,
+                v_size,
+            )
+        return False
+    return True
+
+
+def _check_for_seq_len_0_nested(params: SDPAParams, debug=False) -> bool:
+    # When this function is called we are assured that the nt is dim==4
+    q_is_safe = (
+        _check_for_seq_len_0_and_consistent_head_dim_nested_helper(
+            params.query, "query", debug
+        )
+        if params.query.is_nested
+        else True
+    )
+    # short circuit if any is unsafe
+    if not q_is_safe:
+        return False
+
+    k_is_safe = (
+        _check_for_seq_len_0_and_consistent_head_dim_nested_helper(
+            params.key, "key", debug
+        )
+        if params.key.is_nested
+        else True
+    )
+    # short circuit if any is unsafe
+    if not k_is_safe:
+        return False
+
+    v_is_safe = (
+        _check_for_seq_len_0_and_consistent_head_dim_nested_helper(
+            params.value, "value", debug
+        )
+        if params.value.is_nested
+        else True
+    )
+    # short circuit if any is unsafe
+    if not v_is_safe:
+        return False
+
+    # We now know none of the inputs have ragged num_heads, so we can safely
+    # access .size(1)
+    q_num_heads = params.query.size(1)
+    k_num_heads = params.key.size(1)
+    v_num_heads = params.value.size(1)
+    same_num_heads = q_num_heads == k_num_heads and q_num_heads == v_num_heads
+
+    if not same_num_heads:
+        if (
+            params.query.requires_grad
+            or params.key.requires_grad
+            or params.value.requires_grad
+        ):
+            if debug:
+                log.warning(
+                    "Both fused kernels do not support training with broadcasted NT inputs."
+                )
+            return False
+        return _try_broadcast_param_size(
+            q_num_heads, k_num_heads, v_num_heads, "num heads", debug
+        )
+    return True
+
+
+def _can_use_flash_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
+    constraints = (
+        _check_batch_size_nested,
+        _check_head_dim_size_flash_nested,
+        _check_for_seq_len_0_nested,
+    )
+    for constraint in constraints:
+        if not constraint(params, debug):
+            return False
+    return True
+
+
+def _can_use_efficient_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
+    constraints = (
+        _check_batch_size_nested,
+        _check_for_seq_len_0_nested,
+    )
+    for constraint in constraints:
+        if not constraint(params, debug):
+            return False
+    return True
+
+
+def _can_use_math_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
+    if (
+        not params.query.transpose(1, 2).is_contiguous()
+        or not params.key.transpose(1, 2).is_contiguous()
+        or not params.value.transpose(1, 2).is_contiguous()
+    ):
+        if debug:
+            log.warning(
+                "If inputs are nested tensors they must be contiguous after transposing."
+            )
+        return False
+    if params.is_causal:
+        if debug:
+            log.warning(
+                "Nested tensors for query / key are not supported when is_causal=True."
+            )
+        return False
+    return True
+
+
+def _select_sdp_backend(query, key, value, attn_mask, dropout, is_causal, enable_gqa):
+    if (
+        not flash_sdp_enabled()
+        and not mem_efficient_sdp_enabled()
+        and not math_sdp_enabled()
+        and not cudnn_sdp_enabled()
+    ):
+        return SDPBackend.ERROR
+
+    ordering = (
+        SDPBackend.FLASH_ATTENTION,
+        SDPBackend.EFFICIENT_ATTENTION,
+        SDPBackend.MATH,
+        SDPBackend.CUDNN_ATTENTION,
+    )
+
+    params = SDPAParams(query, key, value, attn_mask, dropout, is_causal, enable_gqa)
+
+    for backend in ordering:
+        if backend == SDPBackend.CUDNN_ATTENTION:
+            if can_use_cudnn_attention(params):
+                return SDPBackend.CUDNN_ATTENTION
+        if backend == SDPBackend.FLASH_ATTENTION:
+            if can_use_flash_attention(params) and _can_use_flash_sdpa_jagged(params):
+                return SDPBackend.FLASH_ATTENTION
+        if backend == SDPBackend.EFFICIENT_ATTENTION:
+            if can_use_efficient_attention(params) and _can_use_efficient_sdpa_jagged(
+                params
+            ):
+                return SDPBackend.EFFICIENT_ATTENTION
+        if backend == SDPBackend.MATH:
+            if math_sdp_enabled() and _can_use_math_sdpa_jagged(params):
+                return SDPBackend.MATH
+
+    log.warning("Memory efficient kernel not used because:")
+    can_use_efficient_attention(params, debug=True)
+    _can_use_efficient_sdpa_jagged(params, debug=True)
+    log.warning("Flash attention kernel not used because:")
+    can_use_flash_attention(params, debug=True)
+    _can_use_flash_sdpa_jagged(params, debug=True)
+    log.warning("Math attention kernel not used because:")
+    _can_use_math_sdpa_jagged(params, debug=True)
+    log.warning("cuDNN attention kernel not used because:")
+    can_use_cudnn_attention(params, debug=True)
+    return SDPBackend.ERROR
+
+
+def _cumulative_and_max_seq_len_nnz(qkv: torch.Tensor) -> tuple[torch.Tensor, int, int]:
+    # This function computes the metadata needed by the flash-attention and
+    # efficient_attention kernels: the cumulative sequence lengths over the
+    # batch of sequences, the maximum sequence length, and the total number of
+    # elements (which equals the last entry of the cumulative sequence lengths).
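+    # e.g. (illustrative): for a contiguous NJT with sequence lengths [2, 3, 1]
+    # (offsets [0, 2, 5, 6]), this returns cumulative_seqlen == [0, 2, 5, 6],
+    # max_seqlen == 3, and n_elem == 6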
+    if not isinstance(qkv, NestedTensor):
+        raise ValueError("QKV must be nested for flash cumulative_seq_len calculation.")
+
+    if qkv.lengths() is None:
+        # TODO: Explore performance impact of copying
+        cumulative_seqlen = qkv.offsets().to(dtype=torch.int32, device=qkv.device)
+        max_seqlen = qkv._get_max_seqlen()
+        n_elem = qkv.values().shape[0]
+    else:
+        # TODO: Explore performance impact of copying
+        cumulative_seqlen = (
+            qkv.lengths().cumsum(0).to(dtype=torch.int32, device=qkv.device)
+        )
+        max_seqlen = qkv._get_max_seqlen()
+        # TODO: Explore performance impact when compiling
+        n_elem = int(cumulative_seqlen[-1].item())
+    return cumulative_seqlen, max_seqlen, n_elem
+
+
+def _is_safe_to_get_storage_as_tensor(tensor: torch.Tensor):
+    # This function checks if a nested tensor is valid for
+    # use with the flash-attention and efficient_attention kernels without
+    # needing to call contiguous on the nested tensor input.
+    # It checks that the storage offsets' adjacent differences are a constant
+    # multiple of the previous tensor in the nested tensor and that the strides
+    # are monotonically decreasing. This check is done after calling transpose on
+    # the nested tensor, resulting in an NT of shape [bsz, {seq_len}, num_heads, dim]
+
+    # Returns a boolean: True if the storage can be used directly as a tensor,
+    # False if contiguous() needs to be called on the input first
+    assert isinstance(tensor, NestedTensor)
+    offsets = tensor.offsets()
+    strides = tensor._strides
+
+    n_tensors = offsets.size(0) - 1
+    if n_tensors <= 1:
+        return True
+
+    # Check initially that the tensor strides are in strictly descending order
+    prev_stride = strides[1]
+    for stride in strides[2:]:
+        if prev_stride <= stride:
+            # This would mean that the last stride is greater than the seq_len
+            # stride
+            return False
+        prev_stride = stride
+
+    # Congrats you made it!
+    return True
+
+
+def _view_as_dense(
+    tensor: torch.Tensor, Nnz: int, num_heads: int, head_dim: int
+) -> torch.Tensor:
+    if tensor.is_nested:
+        return tensor.values()
+    return tensor.view(Nnz, num_heads, head_dim)
+
+
+# TODO: Next iteration should add test cases and check it works
+# def _sdpa_nested_preprocessing_with_broadcast(query, key, value):
+#     # Query (Batch x Num_heads x {Q_seq_len}  x Dim_per_head)
+#     # Key   (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
+#     # Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
+#     q_batch_size = query.size(0)
+#     k_batch_size = key.size(0)
+#     v_batch_size = value.size(0)
+
+#     output_batch_size = max(q_batch_size, k_batch_size, v_batch_size)
+
+#     q_num_heads = query.size(1)
+#     k_num_heads = key.size(1)
+#     v_num_heads = value.size(1)
+
+#     output_num_heads = max(q_num_heads, k_num_heads, v_num_heads)
+
+#     head_dim_qk = query.size(3)
+#     head_dim_v = value.size(3)
+
+#     q_t = query.transpose(1, 2)
+#     k_t = key.transpose(1, 2)
+#     v_t = value.transpose(1, 2)
+
+#     # Checks in sdp_utils ensure that if {*}_batch_size/{*}_num_heads !=
+#     # output_batch_size/num_heads then they are 1
+#     q_batch_size_needs_broadcast = q_batch_size != output_batch_size
+#     k_batch_size_needs_broadcast = k_batch_size != output_batch_size
+#     v_batch_size_needs_broadcast = v_batch_size != output_batch_size
+
+#     # If {*}_batch_size_needs_broadcast, then
+#     # (1) max_seqlen_batch_{*} is given by {*}_t.size(1)
+#     #     this is because needs_broadcast indicates that the batch_size is 1
+#     #     and hence there is only 1 value for seq_len
+#     # (2) The cum_seq_lens are given by [0, {*}_t.size(1), 2 * {*}_t.size(1),
+#     # ..., outut_batch_size * {*}_t.size(1)]
+#     # (3) Nnz_{*} is given by output_batch_size * {*}_t.size(1)
+
+#     if q_batch_size_needs_broadcast or not q_t.is_nested:
+#         max_seqlen_batch_q = q_t.size(1)
+#         cumulative_sequence_length_q = torch.arange(
+#             0,
+#             (output_batch_size + 1) * max_seqlen_batch_q,
+#             max_seqlen_batch_q,
+#             device=q_t.device,
+#             dtype=torch.int32,
+#         )
+#         Nnz_q = output_batch_size * max_seqlen_batch_q
+#     else:
+#         (
+#             cumulative_sequence_length_q,
+#             max_seqlen_batch_q,
+#             Nnz_q,
+#         ) = _cumulative_and_max_seq_len_nnz(q_t)
+
+#     if k_batch_size_needs_broadcast and v_batch_size_needs_broadcast:
+#         assert k_t.size(1) == v_t.size(1)
+#         max_seqlen_batch_kv = k_t.size(1)
+#         cumulative_sequence_length_kv = torch.arange(
+#             0,
+#             (output_batch_size + 1) * max_seqlen_batch_kv,
+#             max_seqlen_batch_kv,
+#             device=k_t.device,
+#             dtype=torch.int32,
+#         )
+#         Nnz_kv = output_batch_size * max_seqlen_batch_kv
+#     else:
+#         cumulative_sequence_length_kv, max_seqlen_batch_kv, Nnz_kv = (
+#             _cumulative_and_max_seq_len_nnz(v_t)
+#             if k_batch_size_needs_broadcast
+#             else _cumulative_and_max_seq_len_nnz(k_t)
+#         )
+
+#     q_num_heads_needs_broadcast = q_num_heads != output_num_heads
+#     k_num_heads_needs_broadcast = k_num_heads != output_num_heads
+#     v_num_heads_needs_broadcast = v_num_heads != output_num_heads
+
+#     if not q_t.is_nested:
+#         query_buffer_reshaped = q_t.expand(
+#             output_batch_size, q_t.size(1), output_num_heads, head_dim_qk
+#         )
+#         query_buffer_reshaped = query_buffer_reshaped.reshape(
+#             Nnz_q, output_num_heads, head_dim_qk
+#         )
+#     else:
+#         if not q_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(q_t):
+#             q_t = q_t.contiguous()
+#         # If we are broadcasting then Nnz_q will be the output_batch_size since
+#         # seq_len is 1
+#         effective_batch_size_q = (
+#             output_batch_size if q_batch_size_needs_broadcast else Nnz_q
+#         )
+#         query_buffer_reshaped = _view_as_dense(
+#             q_t, effective_batch_size_q, output_num_heads, head_dim_qk
+#         )
+
+#     # If the physical layout of the NestedTensor's storage
+#     # is not: batch, {seq_len}, num_heads, head_dim then we need
+#     # to call contiguous
+#     if not k_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(k_t):
+#         k_t = k_t.contiguous()
+#     if not v_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(v_t):
+#         v_t = v_t.contiguous()
+
+#     effective_batch_size_k = (
+#         output_batch_size if k_batch_size_needs_broadcast else Nnz_kv
+#     )
+#     key_buffer_reshaped = _view_as_dense(
+#         k_t, effective_batch_size_k, output_num_heads, head_dim_qk
+#     )
+
+#     effective_batch_size_v = (
+#         output_batch_size if v_batch_size_needs_broadcast else Nnz_kv
+#     )
+#     value_buffer_reshaped = _view_as_dense(
+#         v_t, effective_batch_size_v, output_num_heads, head_dim_v
+#     )
+
+#     if not q_batch_size_needs_broadcast:
+#         output_shape = q_t._size
+#         if head_dim_v != head_dim_qk:
+#             output_shape[-1] = head_dim_v
+#         if q_num_heads_needs_broadcast:
+#             output_shape[1] = output_num_heads
+#     else:
+#         output_shape = torch.empty(3, dtype=torch.int64, device=torch.device("cpu"))
+#         output_shape[0] = q_t.size(1)
+#         output_shape[1] = output_num_heads
+#         output_shape[2] = head_dim_v
+
+#     return (
+#         query_buffer_reshaped,
+#         key_buffer_reshaped,
+#         value_buffer_reshaped,
+#         cumulative_sequence_length_q,
+#         cumulative_sequence_length_kv,
+#         max_seqlen_batch_q,
+#         max_seqlen_batch_kv,
+#         output_shape,
+#     )
+
+
+def _sdpa_nested_preprocessing(query, key, value):
+    # Query (Batch x Num_heads x {Q_seq_len}  x Dim_per_head)
+    # Key   (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
+    # Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
+    q_batch_size = query.size(0)
+    k_batch_size = key.size(0)
+    v_batch_size = value.size(0)
+
+    q_num_heads = query.size(1)
+    k_num_heads = key.size(1)
+    v_num_heads = value.size(1)
+
+    if not (q_batch_size == k_batch_size and q_batch_size == v_batch_size) or not (
+        q_num_heads == k_num_heads and k_num_heads == v_num_heads
+    ):
+        raise RuntimeError(
+            "This path is currently not implemented for jagged layout NT."
+        )
+        # return _sdpa_nested_preprocessing_with_broadcast(query, key, value)
+
+    num_heads = query.size(1)
+    head_dim_qk = query.size(3)
+    head_dim_v = value.size(3)
+    q_t = query.transpose(1, 2)
+    k_t = key.transpose(1, 2)
+    v_t = value.transpose(1, 2)
+
+    (
+        cumulative_sequence_length_q,
+        max_seqlen_batch_q,
+        Nnz_q,
+    ) = _cumulative_and_max_seq_len_nnz(q_t)
+    (
+        cumulative_sequence_length_kv,
+        max_seqlen_batch_kv,
+        Nnz_kv,
+    ) = _cumulative_and_max_seq_len_nnz(k_t)
+
+    # [TODO] K and V have to have the same Nnz; we should probably torch_check this.
+    # For now, assume it holds so we don't have to iterate over v.
+
+    # If the physical layout of the NestedTensor's storage
+    # is not: batch, {seq_len}, num_heads, head_dim then we need
+    # to call contiguous
+    if not q_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(q_t):
+        q_t = q_t.contiguous()
+    if not k_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(k_t):
+        k_t = k_t.contiguous()
+    if not v_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(v_t):
+        v_t = v_t.contiguous()
+
+    query_buffer_reshaped = _view_as_dense(q_t, Nnz_q, num_heads, head_dim_qk)
+    key_buffer_reshaped = _view_as_dense(k_t, Nnz_kv, num_heads, head_dim_qk)
+    value_buffer_reshaped = _view_as_dense(v_t, Nnz_kv, num_heads, head_dim_v)
+
+    output_nt_info = {
+        "offsets": q_t.offsets(),
+        "lengths": q_t.lengths(),
+        "max_seqlen": q_t._get_max_seqlen(),
+        "min_seqlen": q_t._get_min_seqlen(),
+    }
+
+    return (
+        query_buffer_reshaped,
+        key_buffer_reshaped,
+        value_buffer_reshaped,
+        cumulative_sequence_length_q,
+        cumulative_sequence_length_kv,
+        max_seqlen_batch_q,
+        max_seqlen_batch_kv,
+        output_nt_info,
+    )
+
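+# Hedged usage sketch (illustrative only, not part of the upstream source). The helper
+# above expects jagged-layout NJT inputs of shape (B, num_heads, j, head_dim) and
+# flattens them into dense (Nnz, num_heads, head_dim) buffers plus cu_seqlens, e.g.:
+#
+#     seqs = [torch.randn(3, 4, 16), torch.randn(5, 4, 16)]   # per-sample (seq, heads, dim)
+#     q = torch.nested.nested_tensor(seqs, layout=torch.jagged).transpose(1, 2)
+#     # _sdpa_nested_preprocessing(q, q, q) would be expected to return buffers of shape
+#     # (8, 4, 16), cu_seqlens tensor([0, 3, 8]), and max_seqlen 5 for both q and kv.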
+
+def _pad_last_dim(
+    tensor: torch.Tensor, alignment_size: int, slice: bool
+) -> torch.Tensor:
+    # FlashAttentionV2 requires the head dimension to be a multiple of 8.
+    # This padding was previously done within the kernel; however, that could
+    # cause the kernel to alias query, key, and value, so instead we pad the
+    # head dimension to a multiple of 8 here in the composite region.
+    last_dim_size = tensor.size(-1)
+    if last_dim_size % alignment_size == 0:
+        return tensor
+    pad_count = alignment_size - (last_dim_size % alignment_size)
+    tensor = torch.nn.functional.pad(tensor, [0, pad_count])
+    if slice:
+        return tensor[..., 0:last_dim_size]
+    return tensor
+
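+# Quick sanity sketch (illustrative, not part of the upstream source):
+#     _pad_last_dim(torch.randn(2, 3, 37), 8, False).shape  # -> (2, 3, 40): padded up to a multiple of 8
+#     _pad_last_dim(torch.randn(2, 3, 37), 8, True).shape   # -> (2, 3, 37): sliced back, but backed by padded storage
+#     _pad_last_dim(torch.randn(2, 3, 40), 8, False).shape  # -> (2, 3, 40): already aligned, returned unchanged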
+
+# TODO: coalesce with torch/nn/utils/attention.py
+def _calculate_scale(query, scale):
+    # TODO: Investigate why math.sqrt() isn't properly handled by Dynamo?
+    softmax_scale = scale if scale is not None else torch.sym_sqrt(1.0 / query.size(-1))
+    return softmax_scale
+
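+# Illustrative note (not part of the upstream source): with scale=None this reproduces the
+# conventional softmax scaling 1/sqrt(E) over the head dimension, e.g.:
+#     q = torch.randn(2, 4, 8, 64)
+#     _calculate_scale(q, None)   # ~0.125, i.e. sqrt(1 / 64)
+#     _calculate_scale(q, 0.5)    # an explicit scale is passed through unchanged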
+
+def _post_process_flash_output(out: torch.Tensor, og_size):
+    if not out.is_nested and out.size(-1) != og_size:
+        out = out[..., 0:og_size]
+    return out
+
+
+def _is_computing_meta_flops(x):
+    # Note: there's a use case of using meta tensors & the dispatch-based flop counter.
+    # We can use this function to check for this scenario in order to handle it specially.
+    if not torch.jit.is_scripting() and x.device.type == "meta":
+        torch_dispatch_mode_stack = (
+            torch.utils._python_dispatch._get_current_dispatch_mode_stack()
+        )
+        return any(
+            type(x) == torch.utils.flop_counter._FlopCounterMode
+            for x in torch_dispatch_mode_stack
+        )
+    return False
+
+
+def _autocast(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor],
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+    """
+    [Autocasting SDPA for NJT]
+
+    Normal autocasting doesn't work for NJT+SDPA right now:
+    * NJT intercepts the __torch_function__ call for scaled_dot_product_attention, which happens
+      before we get to any aten ops or dispatcher logic; then the torch_function logic calls into
+      efficient attention or flash attention. So, autocasting on the scaled_dot_product_attention
+      op won't work because we never see that aten op.
+    * If we put autocasting on `_flash_attention_forward`, then we'll get autocasting to run, but
+      the kernel selection logic in torch_function handling (i.e. jagged_scaled_dot_product_attention)
+      won't work correctly: the kernel selection logic will run before autocasting, and choose
+      a kernel based on the un-autocasted dtypes; but then autocasting will run and the actual
+      attention computation will happen in a different dtype.
+
+    An alternative is to just change the backend selection logic for SDPA+NJT to be autocast-aware
+    and rely on autocasting to do the actual conversions for flash attention / efficient attention.
+    However, by manually doing the actual autocast before the backend selection, we ensure that the
+    autocast handling for backend selection doesn't diverge from the autocast handling for the
+    actual dtype conversions.
+    """
+    device_type = query.device.type
+    # meta device is not supported by autocast, so break early for it
+    if _is_computing_meta_flops(query) or not torch.is_autocast_enabled(device_type):
+        return query, key, value, attn_mask
+
+    def cvt(x):
+        if x is None:
+            return x
+        target_dtype = torch.get_autocast_dtype(device_type)
+        if (
+            (not x.dtype.is_floating_point)
+            or x.dtype == target_dtype
+            or x.dtype == torch.float64
+        ):
+            return x
+        return x.to(target_dtype)
+
+    return cvt(query), cvt(key), cvt(value), cvt(attn_mask)
+
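+# Hedged usage sketch (illustrative, not part of the upstream source): under an enabled
+# autocast context, floating-point inputs below fp64 are converted to the autocast dtype
+# before backend selection, e.g. on a CUDA build:
+#
+#     with torch.autocast("cuda", dtype=torch.bfloat16):
+#         q32 = torch.randn(2, 4, 8, 16, device="cuda")             # float32
+#         mask = torch.ones(8, 8, dtype=torch.bool, device="cuda")  # non-floating point
+#         q, k, v, m = _autocast(q32, q32, q32, mask)
+#         # q.dtype == torch.bfloat16; m is returned unchanged (torch.bool)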
+
+def jagged_scaled_dot_product_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    dropout_p=0.0,
+    is_causal=False,
+    scale=None,
+    enable_gqa=False,
+):
+    query, key, value, attn_mask = _autocast(query, key, value, attn_mask)
+    _validate_sdpa_input(query, key, value, attn_mask, dropout_p, is_causal, scale)
+    # for mypy, ugh
+    assert (
+        isinstance(query, NestedTensor)
+        and isinstance(key, NestedTensor)
+        and isinstance(value, NestedTensor)
+    )
+    from torch.nested._internal.nested_tensor import (
+        nested_view_from_values_offsets_lengths,
+    )
+
+    # Special path for non-ragged sequence length (e.g. for SAM where we have a ragged
+    # second batch dim instead). For this case, we can just send the dense buffers through
+    # vanilla SDPA.
+    if query.dim() > 3 and key.dim() > 3 and value.dim() > 3 and query._ragged_idx == 1:
+        output = F.scaled_dot_product_attention(
+            query.values(),
+            key.values(),
+            value.values(),
+            attn_mask=(
+                attn_mask.values() if isinstance(attn_mask, NestedTensor) else attn_mask
+            ),
+            dropout_p=dropout_p,
+            is_causal=is_causal,
+            scale=scale,
+        )
+        return nested_view_from_values_offsets_lengths(
+            output,
+            query.offsets(),
+            query.lengths(),
+            min_seqlen=query._maybe_min_seqlen,  # type: ignore[attr-defined]
+            max_seqlen=query._maybe_max_seqlen,  # type: ignore[attr-defined]
+        )
+
+    compute_logsumexp = query.requires_grad or key.requires_grad or value.requires_grad
+
+    backend_choice = _select_sdp_backend(
+        query, key, value, attn_mask, dropout_p, is_causal, enable_gqa
+    )
+
+    if _is_computing_meta_flops(query):
+        # Backend choice will probably not be correct if we have a meta device,
+        # because backend choice is device-aware. In this case, we mostly just
+        # want to avoid using math backend (which does a .item() call).
+        # Arbitrarily choose flash attention.
+        backend_choice = SDPBackend.FLASH_ATTENTION
+
+    if backend_choice == SDPBackend.FLASH_ATTENTION:
+        og_size = query.size(-1)
+        query_padded = _pad_last_dim(query, 8, False)
+        key_padded = _pad_last_dim(key, 8, False)
+        value_padded = _pad_last_dim(value, 8, False)
+        # We need to calculate the scale based on the original (pre-padding) head dim size
+        og_scale = _calculate_scale(query, scale)
+        (
+            query_buffer_reshaped,
+            key_buffer_reshaped,
+            value_buffer_reshaped,
+            cumulative_sequence_length_q,
+            cumulative_sequence_length_kv,
+            max_seqlen_batch_q,
+            max_seqlen_batch_kv,
+            output_nt_info,
+        ) = _sdpa_nested_preprocessing(query_padded, key_padded, value_padded)
+        (
+            attention,
+            _logsumexp,
+            _philox_seed,
+            _philox_offset,
+            _debug_attn_mask,
+        ) = torch.ops.aten._flash_attention_forward(
+            query_buffer_reshaped,
+            key_buffer_reshaped,
+            value_buffer_reshaped,
+            cumulative_sequence_length_q,
+            cumulative_sequence_length_kv,
+            max_seqlen_batch_q,
+            max_seqlen_batch_kv,
+            dropout_p,
+            is_causal,
+            False,
+            scale=og_scale,
+        )
+        # Reshape output to convert nnz to batch_size and seq_len
+        attention = nested_view_from_values_offsets_lengths(
+            attention,  # output from flash_attn is [total_q, num_heads, head_size_og]
+            **output_nt_info,
+        ).transpose(1, 2)
+        return _post_process_flash_output(attention, og_size)
+    elif backend_choice == SDPBackend.EFFICIENT_ATTENTION:
+        (
+            query_reshaped,
+            key_reshaped,
+            value_reshaped,
+            cumulative_sequence_length_q,
+            cumulative_sequence_length_kv,
+            max_seqlen_batch_q,
+            max_seqlen_batch_kv,
+            output_nt_info,
+        ) = _sdpa_nested_preprocessing(query, key, value)
+        (
+            attention,
+            log_sumexp,
+            seed,
+            offset,
+            max_seqlen_q,
+            max_seqlen_batch_kv,
+        ) = torch.ops.aten._efficient_attention_forward(
+            query_reshaped.unsqueeze(0),
+            key_reshaped.unsqueeze(0),
+            value_reshaped.unsqueeze(0),
+            None,
+            cumulative_sequence_length_q,
+            cumulative_sequence_length_kv,
+            max_seqlen_batch_q,
+            max_seqlen_batch_kv,
+            dropout_p,
+            int(is_causal),
+            compute_logsumexp,
+            scale=scale,
+        )
+        # Reshape output to convert nnz to batch_size and seq_len
+        return nested_view_from_values_offsets_lengths(
+            attention.squeeze(0),
+            **output_nt_info,
+        ).transpose(1, 2)
+    elif backend_choice == SDPBackend.CUDNN_ATTENTION:
+        (
+            query_reshaped,
+            key_reshaped,
+            value_reshaped,
+            cumulative_sequence_length_q,
+            cumulative_sequence_length_kv,
+            max_seqlen_batch_q,
+            max_seqlen_batch_kv,
+            output_nt_info,
+        ) = _sdpa_nested_preprocessing(query, key, value)
+        (
+            attention,
+            logsumexp,
+            cum_seqlen_q,
+            cum_seqlen_kv,
+            max_seqlen_q,
+            max_seqlen_kv,
+            seed,
+            offset,
+            _,
+        ) = torch.ops.aten._cudnn_attention_forward(
+            query_reshaped,
+            key_reshaped,
+            value_reshaped,
+            attn_mask,
+            cumulative_sequence_length_q,
+            cumulative_sequence_length_kv,
+            max_seqlen_batch_q,
+            max_seqlen_batch_kv,
+            compute_logsumexp,
+            dropout_p,
+            is_causal,
+            False,
+            scale=scale,
+        )
+        return nested_view_from_values_offsets_lengths(
+            attention,
+            **output_nt_info,
+        ).transpose(1, 2)
+    elif backend_choice == SDPBackend.MATH:
+        # save the offsets and shape of the inputs, so we can reshape the final output
+        # query @ key = attn: [B, D1, j0, D'] @ [B, D1, D', j1] = [B, D1, j0, j1]
+        # attn @ value = out: [B, D1, j0, j1] @ [B, D1, j1, D2] = [B, D1, j0, D2]
+        offsets = query.offsets()
+        q_lengths = query.lengths()
+        min_seqlen = query._maybe_min_seqlen
+        max_seqlen = query._maybe_max_seqlen
+        d1 = query._size[1]
+        d2 = value._size[-1]
+
+        # convert the jagged-layout NestedTensor to a strided-layout NestedTensor,
+        # which supports the math implementation of SDPA
+        def get_strided_layout_nested_tensor(jagged_layout_nt):
+            lengths = jagged_layout_nt._offsets[1:] - jagged_layout_nt._offsets[:-1]
+            transpose = torch.transpose(jagged_layout_nt, 1, 2)
+            tensor_list = transpose.values().split(list(lengths), dim=0)
+            strided_nt = torch.nested.as_nested_tensor(list(tensor_list))
+            strided_nt = strided_nt.transpose(1, 2).contiguous()
+            return strided_nt
+
+        query = get_strided_layout_nested_tensor(query)
+        key = get_strided_layout_nested_tensor(key)
+        value = get_strided_layout_nested_tensor(value)
+
+        attn_out = torch._scaled_dot_product_attention_math(
+            query, key, value, attn_mask, dropout_p, is_causal, scale=scale
+        )[0]
+
+        # convert strided layout Nested Tensor back to jagged layout Nested Tensor
+        attn_out = attn_out.transpose(1, 2).contiguous().values()
+        attn_out = attn_out.view(-1, d1, d2)
+        attn_out = nested_view_from_values_offsets_lengths(
+            attn_out,
+            offsets,
+            lengths=q_lengths,
+            min_seqlen=min_seqlen,
+            max_seqlen=max_seqlen,
+        ).transpose(1, 2)
+
+        return attn_out
+    else:
+        raise RuntimeError(
+            "No viable backend for scaled_dot_product_attention was found."
+        )
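+
+
+# Hedged end-to-end sketch (illustrative only, not part of the upstream source). Per the
+# notes above, NJT inputs reach this function via the __torch_function__ handling of
+# F.scaled_dot_product_attention, so a typical call site looks roughly like:
+#
+#     seqs = [torch.randn(3, 4, 16), torch.randn(5, 4, 16)]   # per-sample (seq, heads, dim)
+#     q = torch.nested.nested_tensor(seqs, layout=torch.jagged).transpose(1, 2)
+#     out = F.scaled_dot_product_attention(q, q, q, is_causal=True)
+#     # out is again a jagged-layout NJT of shape (2, 4, j, 16)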
diff --git a/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..45f26c1b7d8085a62a796aa0cd36f204bda544dd
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/__init__.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/_adafactor.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/_adafactor.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d0fded01900bd922f4c0a8459834c6989c6fb22b
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/_adafactor.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/_functional.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/_functional.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..899241db2cb1f6b61aa0074183f56c89f1e84dc8
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/_functional.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/adadelta.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/adadelta.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d36f46bd0d6e49ff330d631ddb01493bb54a866c
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/adadelta.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/adagrad.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/adagrad.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f7bb822dc8684d68e91cc88d871effeafb2e8fd6
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/adagrad.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/adam.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/adam.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7da49c3fb711f120f37e5f0514069dce04bf1597
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/adam.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/adamax.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/adamax.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dc5980c75a8c9c1fcaa222c2bb6c8b073d667586
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/adamax.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/adamw.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/adamw.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0d3f6b19d81dbad707b4cafd2a37aa5c15fa6b3f
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/adamw.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/asgd.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/asgd.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..83ae307b1586a3c679bb484c616e28f691758576
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/asgd.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/lbfgs.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/lbfgs.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c2aa774cff64c15d234ca6df18e0e50be3579e74
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/lbfgs.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/lr_scheduler.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/lr_scheduler.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f44b7205b5f1c360cef894df97c9597dabb9f13b
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/lr_scheduler.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/nadam.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/nadam.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..48bc1cfcadb4a0c16545a54f2c975e81ee749fb1
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/nadam.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/optimizer.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/optimizer.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..976a235f4986a898dd5fc5d728e117f5442f24f8
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/optimizer.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/radam.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/radam.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0182f668e30e3dce42759da21a658b51bd96878b
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/radam.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/rmsprop.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/rmsprop.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..67159aba99a04db976fa160fd8ec8626ae1f7084
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/rmsprop.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/rprop.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/rprop.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..da0f28736efc3e79977696bcf25bcd050e3f211e
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/rprop.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/sgd.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/sgd.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cb693260d07df0e2695cd7ceb11e2e45cb6068ec
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/sgd.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/sparse_adam.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/sparse_adam.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3c9c8ef9f095311563a5a6bb147de887179ac5aa
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/sparse_adam.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/swa_utils.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/swa_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..639cfad912fcb1d342e7bed219b5f96880962c8f
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/optim/__pycache__/swa_utils.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/optim/_multi_tensor/__init__.py b/.venv/lib/python3.12/site-packages/torch/optim/_multi_tensor/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6818e5a50f3b6792e249dc6e45b64bd29d0a067
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/optim/_multi_tensor/__init__.py
@@ -0,0 +1,31 @@
+"""
+:mod:`torch.optim._multi_tensor` is a package implementing various optimization algorithms.
+
+Most commonly used methods are already supported, and the interface is general
+enough that more sophisticated ones can also be easily integrated in the
+future.
+"""
+
+from functools import partialmethod
+
+from torch import optim
+
+
+def partialclass(cls, *args, **kwargs):  # noqa: D103
+    class NewCls(cls):
+        __init__ = partialmethod(cls.__init__, *args, **kwargs)
+
+    return NewCls
+
+
+Adam = partialclass(optim.Adam, foreach=True)
+AdamW = partialclass(optim.AdamW, foreach=True)
+NAdam = partialclass(optim.NAdam, foreach=True)
+SGD = partialclass(optim.SGD, foreach=True)
+RAdam = partialclass(optim.RAdam, foreach=True)
+RMSprop = partialclass(optim.RMSprop, foreach=True)
+Rprop = partialclass(optim.Rprop, foreach=True)
+ASGD = partialclass(optim.ASGD, foreach=True)
+Adamax = partialclass(optim.Adamax, foreach=True)
+Adadelta = partialclass(optim.Adadelta, foreach=True)
+Adagrad = partialclass(optim.Adagrad, foreach=True)
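+
+
+# Hedged usage sketch (illustrative, not part of the upstream source): these aliases behave
+# like the regular optimizers but default to the multi-tensor (foreach=True) implementations:
+#
+#     import torch
+#     from torch.optim import _multi_tensor
+#     model = torch.nn.Linear(4, 2)
+#     opt = _multi_tensor.SGD(model.parameters(), lr=0.1)  # equivalent to optim.SGD(..., foreach=True)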
diff --git a/.venv/lib/python3.12/site-packages/torch/optim/_multi_tensor/__init__.pyi b/.venv/lib/python3.12/site-packages/torch/optim/_multi_tensor/__init__.pyi
new file mode 100644
index 0000000000000000000000000000000000000000..97c3e2df989303c0f4a1cf76977cc47e25dfaaf8
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/optim/_multi_tensor/__init__.pyi
@@ -0,0 +1,15 @@
+from functools import partial
+
+from torch import optim
+
+Adam = partial(optim.Adam, foreach=True)
+AdamW = partial(optim.AdamW, foreach=True)
+NAdam = partial(optim.NAdam, foreach=True)
+SGD = partial(optim.SGD, foreach=True)
+RAdam = partial(optim.RAdam, foreach=True)
+RMSprop = partial(optim.RMSprop, foreach=True)
+Rprop = partial(optim.Rprop, foreach=True)
+ASGD = partial(optim.ASGD, foreach=True)
+Adamax = partial(optim.Adamax, foreach=True)
+Adadelta = partial(optim.Adadelta, foreach=True)
+Adagrad = partial(optim.Adagrad, foreach=True)
diff --git a/.venv/lib/python3.12/site-packages/torch/package/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/package/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5718adb0c1c524a1954bbb453450e99f78a7d326
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/package/__pycache__/__init__.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/package/__pycache__/_digraph.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/package/__pycache__/_digraph.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7235406b71fc1e5998450675f89794d511ade72b
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/package/__pycache__/_digraph.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/package/__pycache__/_directory_reader.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/package/__pycache__/_directory_reader.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a1cb7a9b87d2aceca08d35fca9acf1b8af4db9c3
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/package/__pycache__/_directory_reader.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/package/__pycache__/_importlib.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/package/__pycache__/_importlib.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fd1f044ca0491141e0eaa7509c649be7e2fc6d0f
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/package/__pycache__/_importlib.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/package/__pycache__/_mangling.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/package/__pycache__/_mangling.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6dc9ed63b474b068cbe48ef6a621ee41b0d70080
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/package/__pycache__/_mangling.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/package/__pycache__/_package_pickler.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/package/__pycache__/_package_pickler.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..15d52ed802f1bfe34c87c37a582b5d9cbc88b00c
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/package/__pycache__/_package_pickler.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/package/__pycache__/_package_unpickler.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/package/__pycache__/_package_unpickler.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..256fd354dd1a4f63f3941674c03183729e8dc6f1
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/package/__pycache__/_package_unpickler.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/package/__pycache__/_stdlib.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/package/__pycache__/_stdlib.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..26339cbf0ef33663d0711f84cdba54976fde9b83
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/package/__pycache__/_stdlib.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/package/__pycache__/file_structure_representation.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/package/__pycache__/file_structure_representation.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5f2a4349fb6756906bf900f7e36ca5807c75e643
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/package/__pycache__/file_structure_representation.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/package/__pycache__/find_file_dependencies.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/package/__pycache__/find_file_dependencies.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..45d91120f31f9c6ae9fad06be67c948e109540fe
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/package/__pycache__/find_file_dependencies.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/package/__pycache__/glob_group.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/package/__pycache__/glob_group.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f0c115b481c664c8b6e3825ae6cd95d29a919b08
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/package/__pycache__/glob_group.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/package/__pycache__/importer.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/package/__pycache__/importer.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c5fbd524c03a9042bb48c1c81489a44532f4dbec
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/package/__pycache__/importer.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/package/__pycache__/package_exporter.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/package/__pycache__/package_exporter.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..32c0f77cc042ef04a119f6288983c5e8d5b48a31
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/package/__pycache__/package_exporter.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/package/__pycache__/package_importer.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/package/__pycache__/package_importer.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..70b6e90319f5803a1be3ea31be91538cc7372a09
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/package/__pycache__/package_importer.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/package/analyze/__init__.py b/.venv/lib/python3.12/site-packages/torch/package/analyze/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ef7a1716af241e21f97f593abde2a2b75960814
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/package/analyze/__init__.py
@@ -0,0 +1,2 @@
+from .find_first_use_of_broken_modules import find_first_use_of_broken_modules
+from .trace_dependencies import trace_dependencies
diff --git a/.venv/lib/python3.12/site-packages/torch/package/analyze/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/package/analyze/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d0817dadbdf83f7ad66cf4657d64b58ffa530efe
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/package/analyze/__pycache__/__init__.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/package/analyze/__pycache__/find_first_use_of_broken_modules.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/package/analyze/__pycache__/find_first_use_of_broken_modules.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cb757ec1b7b1185ec69b524132013330c71105c9
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/package/analyze/__pycache__/find_first_use_of_broken_modules.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/package/analyze/__pycache__/is_from_package.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/package/analyze/__pycache__/is_from_package.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2ed2d46888e666bfff97e6afd77b2f7c192923a1
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/package/analyze/__pycache__/is_from_package.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/package/analyze/__pycache__/trace_dependencies.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/package/analyze/__pycache__/trace_dependencies.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..646b1e5d6ea63030129fb2f4e8f6a15a593b8954
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/package/analyze/__pycache__/trace_dependencies.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/package/analyze/find_first_use_of_broken_modules.py b/.venv/lib/python3.12/site-packages/torch/package/analyze/find_first_use_of_broken_modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..728f3289b5cd4fee58bd49346f327419d9d2af25
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/package/analyze/find_first_use_of_broken_modules.py
@@ -0,0 +1,30 @@
+from torch.package.package_exporter import PackagingError
+
+
+__all__ = ["find_first_use_of_broken_modules"]
+
+
+def find_first_use_of_broken_modules(exc: PackagingError) -> dict[str, list[str]]:
+    """
+    Find all broken modules in a PackagingError, and for each one, return the
+    dependency path in which the module was first encountered.
+
+    E.g., if broken module m.n.o was added to the dependency graph while processing a.b.c
+    and then re-encountered while processing d.e.f, this method would return
+    {'m.n.o': ['a', 'b', 'c']}
+
+    Args:
+        exc: a PackagingError
+
+    Returns: A dict from broken module names to lists of module names in the path.
+    """
+
+    assert isinstance(exc, PackagingError), "exception must be a PackagingError"
+    uses = {}
+    broken_module_names = [
+        m for m, attr in exc.dependency_graph.nodes.items() if attr.get("error", False)
+    ]
+    for module_name in broken_module_names:
+        path = exc.dependency_graph.first_path(module_name)
+        uses[module_name] = path
+    return uses
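+
+
+# Hedged usage sketch (illustrative, not part of the upstream source); the package path,
+# resource name, and `my_model` below are hypothetical:
+#
+#     from torch.package import PackageExporter
+#     try:
+#         with PackageExporter("model_pkg.pt") as pe:
+#             pe.save_pickle("model", "model.pkl", my_model)  # may raise PackagingError on close
+#     except PackagingError as exc:
+#         print(find_first_use_of_broken_modules(exc))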
diff --git a/.venv/lib/python3.12/site-packages/torch/package/analyze/is_from_package.py b/.venv/lib/python3.12/site-packages/torch/package/analyze/is_from_package.py
new file mode 100644
index 0000000000000000000000000000000000000000..82ff5896b6ffcc2dcb7b15dc169729aceb8b1d75
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/package/analyze/is_from_package.py
@@ -0,0 +1,16 @@
+from types import ModuleType
+from typing import Any
+
+from .._mangling import is_mangled
+
+
+def is_from_package(obj: Any) -> bool:
+    """
+    Return whether an object was loaded from a package.
+
+    Note: packaged objects from externed modules will return ``False``.
+    """
+    if type(obj) == ModuleType:
+        return is_mangled(obj.__name__)
+    else:
+        return is_mangled(type(obj).__module__)
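+
+
+# Hedged usage sketch (illustrative, not part of the upstream source); the package path and
+# resource name below are hypothetical:
+#
+#     import torch
+#     from torch.package import PackageImporter
+#     imp = PackageImporter("model_pkg.pt")
+#     obj = imp.load_pickle("model", "model.pkl")
+#     is_from_package(obj)    # True: obj's module name was mangled by the importer
+#     is_from_package(torch)  # False: torch is an externed/regular module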
diff --git a/.venv/lib/python3.12/site-packages/torch/package/analyze/trace_dependencies.py b/.venv/lib/python3.12/site-packages/torch/package/analyze/trace_dependencies.py
new file mode 100644
index 0000000000000000000000000000000000000000..e029fe130bdd0fe2e8f0b8a7ac4117cdfa15c317
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/package/analyze/trace_dependencies.py
@@ -0,0 +1,65 @@
+# mypy: allow-untyped-defs
+import sys
+from collections.abc import Iterable
+from typing import Any, Callable
+
+
+__all__ = ["trace_dependencies"]
+
+
+def trace_dependencies(
+    callable: Callable[[Any], Any], inputs: Iterable[tuple[Any, ...]]
+) -> list[str]:
+    """Trace the execution of a callable in order to determine which modules it uses.
+
+    Args:
+        callable: The callable to execute and trace.
+        inputs: The inputs to use during tracing. The modules used by 'callable' when invoked with each set of inputs
+            are unioned to determine all modules used by the callable for the purposes of packaging.
+
+    Returns: A list of the names of all modules used during callable execution.
+    """
+    modules_used = set()
+
+    def record_used_modules(frame, event, arg):
+        # If the event being profiled is not a Python function
+        # call, there is nothing to do.
+        if event != "call":
+            return
+
+        # This is the name of the function that was called.
+        name = frame.f_code.co_name
+        module = None
+
+        # Try to determine the name of the module that the function
+        # is in:
+        #   1) Check the global namespace of the frame.
+        #   2) Check the local namespace of the frame.
+        #   3) To handle class instance method calls, check
+        #       the attribute named 'name' of the object
+        #       in the local namespace corresponding to "self".
+        if name in frame.f_globals:
+            module = frame.f_globals[name].__module__
+        elif name in frame.f_locals:
+            module = frame.f_locals[name].__module__
+        elif "self" in frame.f_locals:
+            method = getattr(frame.f_locals["self"], name, None)
+            module = method.__module__ if method else None
+
+        # If a module was found, add it to the set of used modules.
+        if module:
+            modules_used.add(module)
+
+    try:
+        # Attach record_used_modules as the profiler function.
+        sys.setprofile(record_used_modules)
+
+        # Execute the callable with all inputs.
+        for inp in inputs:
+            callable(*inp)
+
+    finally:
+        # Detach the profiler function.
+        sys.setprofile(None)
+
+    return list(modules_used)
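+
+
+# Hedged usage sketch (illustrative, not part of the upstream source); the model below is
+# hypothetical:
+#
+#     import torch
+#     model = torch.nn.Linear(4, 2)
+#     used = trace_dependencies(lambda x: model(x), [(torch.randn(1, 4),)])
+#     # a module name such as 'torch.nn.modules.linear' would be expected to appear in `used`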
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/testing/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..16a2d552151d4aa2278227008f67ab3f6f3a2e1a
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/testing/__pycache__/__init__.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/__pycache__/_comparison.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/testing/__pycache__/_comparison.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6eb2c02205ed485ea87788a07c499547f5885b31
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/testing/__pycache__/_comparison.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/__pycache__/_creation.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/testing/__pycache__/_creation.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..66855a79691405409344930824da1ffdbe3d5d30
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/testing/__pycache__/_creation.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/__pycache__/_utils.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/testing/__pycache__/_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3210901cf2b09afe1dd507efd9f365e43f88dd51
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/testing/__pycache__/_utils.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/__init__.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9d3ae69a1d060ef1ce6bac8848f47fc8d63c240a
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/__init__.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/logging_tensor.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/logging_tensor.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..27919240b12e22caae6384028b6c1322e944268b
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/testing/_internal/__pycache__/logging_tensor.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/autocast_test_lists.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/autocast_test_lists.py
new file mode 100644
index 0000000000000000000000000000000000000000..a75e9c834b70ad12f90ab9b1ee6a2ea27f6af404
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/autocast_test_lists.py
@@ -0,0 +1,472 @@
+# mypy: ignore-errors
+
+import collections
+
+import torch
+from torch.testing._internal.common_utils import TEST_WITH_ROCM
+from torch.testing._internal.common_utils import TestCase
+
+
+class AutocastTestLists:
+    def _rnn_cell_args(self, n, num_chunks, is_lstm, dev, dtype):
+        input = (torch.randn((n, n), device=dev, dtype=torch.float32),)
+
+        hx = ((torch.randn((n, n), device=dev, dtype=torch.float32),
+               torch.randn((n, n), device=dev, dtype=torch.float32)) if is_lstm else
+              torch.randn((n, n), device=dev, dtype=torch.float32),)
+
+        weights = (torch.randn((num_chunks * n, n), device=dev, dtype=torch.float32),  # weight_ih
+                   torch.randn((num_chunks * n, n), device=dev, dtype=torch.float32),  # weight_hh
+                   torch.randn((num_chunks * n), device=dev, dtype=torch.float32),  # bias_ih
+                   torch.randn((num_chunks * n), device=dev, dtype=torch.float32))  # bias_hh
+
+        # returns args as a tuple
+        return input + hx + weights
+
+    # Supplies ops and arguments for test_autocast_* in test/test_cuda.py
+    def __init__(self, dev):
+        super().__init__()
+        n = 8
+        # Utility arguments, created as one-element tuples
+        pointwise0_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),)
+        pointwise1_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),)
+        pointwise2_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),)
+        mat0_fp16 = (torch.randn((n, n), dtype=torch.float16, device=dev),)
+        mat1_fp16 = (torch.randn((n, n), dtype=torch.float16, device=dev),)
+        mat2_fp16 = (torch.randn((n, n), dtype=torch.float16, device=dev),)
+
+        dimsets = ((n, n, n), (n, n, n, n), (n, n, n, n, n))
+        conv_args_fp32 = [(torch.randn(dimset, dtype=torch.float32, device=dev),
+                           torch.randn(dimset, dtype=torch.float32, device=dev))
+                          for dimset in dimsets]
+        bias_fp32 = (torch.randn((n,), dtype=torch.float32, device=dev),)
+        element0_fp32 = (torch.randn(1, dtype=torch.float32, device=dev),)
+        pointwise0_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),)
+        pointwise1_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),)
+        mat0_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
+        mat1_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
+        mat2_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
+        mat3_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
+
+        # The lists below organize ops that autocast needs to test.
+        # self.list_name corresponds to test_autocast_list_name in test/test_cuda.py.
+        # Each op is associated with a tuple of valid arguments.
+        # In addition, cudnn conv ops are not supported on ROCm and hence will
+        # be skipped by passing the TEST_WITH_ROCM flag to those ops in the self.torch_fp16 list.
+
+        # Some ops implement built-in type promotion.  These don't need autocasting,
+        # but autocasting relies on their promotion, so we include tests to double-check.
+        self.torch_expect_builtin_promote = [
+            ("eq", pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("ge", pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("gt", pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("le", pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("lt", pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("ne", pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("add", pointwise0_fp32 + pointwise1_fp16, torch.float32),
+            ("div", pointwise0_fp32 + pointwise1_fp16, torch.float32),
+            ("mul", pointwise0_fp32 + pointwise1_fp16, torch.float32),
+            ("cat", (pointwise0_fp16 + pointwise1_fp32,), torch.float32),
+            ("equal", pointwise0_fp32 + pointwise1_fp16, torch.float32),
+            ("stack", (pointwise0_fp16 + pointwise1_fp32,), torch.float32),
+        ]
+        self.methods_expect_builtin_promote = [
+            ("__eq__", pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__ge__", pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__gt__", pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__le__", pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__lt__", pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__ne__", pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__add__", pointwise0_fp32 + pointwise1_fp16, torch.float32),
+            ("__div__", pointwise0_fp32 + pointwise1_fp16, torch.float32),
+            ("__mul__", pointwise0_fp32 + pointwise1_fp16, torch.float32),
+        ]
+
+        # The remaining lists organize ops that autocast treats explicitly.
+        self.torch_fp16 = [
+            # deprecated _convolution
+            ("_convolution", conv_args_fp32[1] + bias_fp32 + ((1, 1), (0, 0), (1, 1), False,
+                                                              (0, 0), 1, False, True, True)),
+            # the current _convolution
+            ("_convolution", conv_args_fp32[1] + bias_fp32 + ((1, 1), (0, 0), (1, 1), False,
+                                                              (0, 0), 1, False, True, True, True)),
+            ("conv1d", conv_args_fp32[0]),
+            ("conv2d", conv_args_fp32[1]),
+            ("conv3d", conv_args_fp32[2]),
+            ("conv_tbc", conv_args_fp32[0] + bias_fp32),
+            ("conv_transpose1d", conv_args_fp32[0]),
+            ("conv_transpose2d", conv_args_fp32[1]),
+            ("conv_transpose3d", conv_args_fp32[2]),
+            ("convolution", conv_args_fp32[1] + bias_fp32 + ((1, 1), (0, 0), (1, 1), False, (0, 0), 1)),
+            ("cudnn_convolution", conv_args_fp32[1] + ((0, 0), (1, 1), (1, 1), 1, False, True, True), TEST_WITH_ROCM),
+            ("cudnn_convolution_transpose", conv_args_fp32[1] + ((0, 0), (0, 0), (1, 1),
+                                                                 (1, 1), 1, False, True, True), TEST_WITH_ROCM),
+            ("prelu", pointwise0_fp32 + element0_fp32),
+            ("addmm", mat1_fp32 + mat2_fp32 + mat3_fp32),
+            ("addmv", pointwise0_fp32 + mat2_fp32 + pointwise1_fp32),
+            ("addr", mat0_fp32 + pointwise0_fp32 + pointwise1_fp32),
+            ("matmul", mat0_fp32 + mat1_fp32),
+            ("einsum", "bkhd,bqhd->bqkh", mat0_fp32 + mat1_fp32),
+            ("mm", mat0_fp32 + mat1_fp32),
+            ("mv", mat0_fp32 + pointwise0_fp32),
+            ("chain_matmul", mat0_fp32 + mat1_fp32 + mat2_fp32),
+            ("addbmm", mat0_fp32 + (torch.randn((n, n, n), device=dev, dtype=torch.float32),
+                                    torch.randn((n, n, n), device=dev, dtype=torch.float32))),
+            ("baddbmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32),
+                         torch.randn((n, n, n), device=dev, dtype=torch.float32),
+                         torch.randn((n, n, n), device=dev, dtype=torch.float32))),
+            ("bmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32),
+                     torch.randn((n, n, n), device=dev, dtype=torch.float32))),
+            # _thnn_fused_lstm_cell and _thnn_fused_gru_cell are not Python-exposed as far as I can tell.
+            # ("_thnn_fused_lstm_cell", mat0_fp32 + mat1_fp32 + mat2_fp32 + pointwise0_fp32 + pointwise1_fp32),
+            # ("_thnn_fused_gru_cell", mat0_fp32 + mat1_fp32 + mat2_fp32 + pointwise0_fp32 + pointwise1_fp32),
+            ("lstm_cell", self._rnn_cell_args(n, num_chunks=4, is_lstm=True, dev=dev, dtype=torch.float32)),
+            ("gru_cell", self._rnn_cell_args(n, num_chunks=3, is_lstm=False, dev=dev, dtype=torch.float32)),
+            ("rnn_tanh_cell", self._rnn_cell_args(n, num_chunks=1, is_lstm=False, dev=dev, dtype=torch.float32)),
+            ("rnn_relu_cell", self._rnn_cell_args(n, num_chunks=1, is_lstm=False, dev=dev, dtype=torch.float32)),
+        ]
+        self.torch_fp32 = [
+            ("acos", (pointwise0_fp16[0].clamp(-.9, 0.9),)),
+            ("asin", (pointwise0_fp16[0].clamp(-.9, 0.9),)),
+            ("cosh", pointwise0_fp16),
+            ("erfinv", (pointwise0_fp16[0].clamp(-.9, .9),)),
+            ("exp", pointwise0_fp16),
+            ("expm1", pointwise0_fp16),
+            ("log", (pointwise0_fp16[0].clamp(0.1, 100.0),)),
+            ("log10", (pointwise0_fp16[0].clamp(0.1, 100.0),)),
+            ("log2", (pointwise0_fp16[0].clamp(0.1, 100.0),)),
+            ("log1p", (pointwise0_fp16[0].clamp(-0.9, 100.0),)),
+            ("reciprocal", pointwise0_fp16),
+            ("rsqrt", (pointwise0_fp16[0].clamp(0.0, 100.0),)),
+            ("sinh", pointwise0_fp16),
+            ("tan", (pointwise0_fp16[0].clamp(-3.1 / 2, 3.1 / 2),)),
+            ("pow", ((pointwise0_fp16[0] + 1.).clamp(0.0, 100.0),) + pointwise1_fp16),
+            ("pow", ((pointwise0_fp16[0] + 1.).clamp(0.0, 100.0),) + (1.7,)),
+            # ("pow", (1.7,) + pointwise0_fp16), # This variant has a backend, but is not documented in the API.
+            ("softmax", pointwise0_fp16 + (0,)),
+            ("log_softmax", pointwise0_fp16 + (0,)),
+            ("layer_norm", pointwise0_fp16 + ((pointwise0_fp16[0].numel(),),)),
+            ("group_norm", mat0_fp16 + (1,)),
+            ("norm", pointwise0_fp16),
+            ("norm", pointwise0_fp16, {"dim": 0}),
+            # these need magma
+            # ("norm", mat0_fp16, {"p": "nuc"}),
+            # ("norm", mat0_fp16, {"p": "nuc", "dim": 0}),
+            ("norm", pointwise0_fp16, {"p": 1}),
+            ("norm", pointwise0_fp16, {"p": 1, "dim": 0}),
+            ("cosine_similarity", mat0_fp16 + mat1_fp16),
+            ("poisson_nll_loss", mat0_fp16 + mat1_fp16 + (True, False, 1.e-8, torch.nn._reduction.get_enum('mean'))),
+            ("cosine_embedding_loss", (torch.tensor([[1, 2, 3]], device=dev, dtype=torch.float16),
+                                       torch.tensor([[1, 3, 4]], device=dev, dtype=torch.float16),
+                                       torch.tensor([1], device=dev, dtype=torch.int))),
+            ("hinge_embedding_loss", mat0_fp16 + (torch.ones(n, device=dev, dtype=torch.int),)),
+            ("kl_div", mat0_fp16 + (torch.rand((n, n), device=dev, dtype=torch.float16),)),
+            ("margin_ranking_loss", mat0_fp16 + mat1_fp16 + (torch.ones((n,), device=dev, dtype=torch.float16),)),
+            ("triplet_margin_loss", mat0_fp16 + mat1_fp16 + mat2_fp16),
+            ("binary_cross_entropy_with_logits", mat0_fp16 + (torch.rand((n, n), device=dev, dtype=torch.float16),)),
+            ("cumprod", pointwise0_fp16 + (0,)),
+            ("cumsum", pointwise0_fp16 + (0,)),
+            ("dist", pointwise0_fp16 + pointwise1_fp16),
+            ("pdist", mat0_fp16),
+            ("cdist", mat0_fp16 + mat1_fp16),
+            ("prod", pointwise0_fp16),
+            ("prod", pointwise0_fp16 + (0,)),
+            ("renorm", mat0_fp16 + (2, 0, 1.0)),
+            ("sum", pointwise0_fp16),
+            ("sum", mat0_fp16 + (1,)),
+            ("logsumexp", mat0_fp16 + (1,)),
+        ]
+        self.torch_need_autocast_promote = [
+            ("addcdiv", pointwise0_fp32 + pointwise1_fp16 + (pointwise2_fp16[0].clamp(0.1, 100),)),
+            ("addcmul", pointwise0_fp32 + pointwise1_fp16 + pointwise2_fp16),
+            ("atan2", pointwise0_fp32 + (pointwise1_fp16[0].clamp(0.1, 100),)),
+            ("bilinear", (torch.randn((1, 2), dtype=torch.float16, device=dev),
+                          torch.randn((1, 2), dtype=torch.float32, device=dev),
+                          torch.randn((1, 2, 2), dtype=torch.float16, device=dev),
+                          torch.randn((1,), dtype=torch.float32, device=dev))),
+            ("cross", (torch.randn(3, dtype=torch.float32, device=dev),
+                       torch.randn(3, dtype=torch.float16, device=dev))),
+            ("dot", pointwise0_fp16 + pointwise1_fp32),
+            ("vdot", pointwise0_fp16 + pointwise1_fp32),
+            ("grid_sampler", (torch.randn((2, 3, 33, 22), dtype=torch.float16, device=dev),
+                              torch.randn((2, 22, 11, 2), dtype=torch.float32, device=dev),
+                              0, 0, False)),
+            ("index_put", pointwise0_fp32 + ((torch.tensor([1], device=dev, dtype=torch.long),),
+                                             torch.randn(1, device=dev, dtype=torch.float16))),
+            ("index_put", pointwise0_fp16 + ((torch.tensor([1], device=dev, dtype=torch.long),),
+                                             torch.randn(1, device=dev, dtype=torch.float32))),
+            ("tensordot", (torch.randn((2, 2, 2), dtype=torch.float32, device=dev),
+                           torch.randn((2, 2, 2), dtype=torch.float16, device=dev))),
+            ("scatter_add", (torch.zeros(2, 2, 2, dtype=torch.float32, device=dev),
+                             0,
+                             torch.randint(0, 2, (2, 2, 2), device=dev),
+                             torch.randn((2, 2, 2), dtype=torch.float16, device=dev))),
+            ("scatter_add", (torch.zeros(2, 2, 2, dtype=torch.float16, device=dev),
+                             0,
+                             torch.randint(0, 2, (2, 2, 2), device=dev),
+                             torch.randn((2, 2, 2), dtype=torch.float32, device=dev))),
+        ]
+        self.nn_fp16 = [
+            ("linear", mat0_fp32 + mat1_fp32 + mat2_fp32),
+        ]
+        self.nn_fp32 = [
+            ("softplus", pointwise0_fp16),
+            ("nll_loss", (torch.rand((n, n), device=dev, dtype=torch.float),
+                          torch.zeros((n,), device=dev, dtype=torch.long))),
+            ("nll_loss2d", (torch.rand((n, n, n, n), device=dev, dtype=torch.half),
+                            torch.zeros((n, n, n), device=dev, dtype=torch.long))),
+            ("l1_loss", mat0_fp16 + mat1_fp16),
+            ("smooth_l1_loss", mat0_fp16 + mat1_fp16),
+            ("mse_loss", mat0_fp16 + mat1_fp16),
+            ("multilabel_margin_loss", mat0_fp16 + (torch.ones((n, n), device=dev, dtype=torch.long),)),
+            ("soft_margin_loss", mat0_fp16 + (torch.ones((n, n), device=dev, dtype=torch.long),)),
+            ("multi_margin_loss", mat0_fp16 + (torch.ones((n,), device=dev, dtype=torch.long),)),
+        ]
+        self.linalg_fp16 = [
+            ("linalg_vecdot", mat0_fp32 + mat0_fp32),
+            ("linalg_multi_dot", (mat0_fp32 + mat1_fp32 + mat2_fp32,)),
+        ]
+        self.methods_fp16 = [
+            ("__matmul__", mat0_fp32 + mat1_fp32)
+        ]
+        self.methods_fp32 = [
+            ("__pow__", (torch.rand(n, device=dev, dtype=torch.float16), 1.5)),
+        ]
+        self.banned = [
+            ("binary_cross_entropy", (torch.rand((n, n), device=dev, dtype=torch.float32),
+                                      torch.rand((n, n), device=dev, dtype=torch.float32)), torch._C._nn),
+        ]
+
+
+class AutocastCPUTestLists:
+    # Supplies ops and arguments for test_autocast_* in test/test_cpu.py
+    def __init__(self, dev):
+        super().__init__()
+        n = 8
+        # Utility arguments, created as one-element tuples
+        pointwise0_bf16 = (torch.randn(n, dtype=torch.bfloat16, device=dev),)
+        pointwise1_bf16 = (torch.randn(n, dtype=torch.bfloat16, device=dev),)
+        mat0_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),)
+        mat1_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),)
+        mat2_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),)
+
+        pointwise0_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),)
+        pointwise1_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),)
+
+        dummy_dimsets = ((n,), (n, n), (n, n, n), (n, n, n, n), (n, n, n, n, n))
+
+        dummy_bf16 = [(torch.randn(dimset, dtype=torch.bfloat16, device=dev),)
+                      for dimset in dummy_dimsets]
+
+        dimsets = ((n, n, n), (n, n, n, n), (n, n, n, n, n))
+        conv_args_fp32 = [(torch.randn(dimset, dtype=torch.float32, device=dev),
+                           torch.randn(dimset, dtype=torch.float32, device=dev))
+                          for dimset in dimsets]
+
+        element0_fp32 = (torch.randn(1, dtype=torch.float32, device=dev),)
+        pointwise0_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),)
+        pointwise1_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),)
+        mat0_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
+        mat1_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
+        mat2_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
+        mat3_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
+
+        dummy_fp32 = [  # noqa: F841
+            (torch.randn(dimset, dtype=torch.float32, device=dev),)
+            for dimset in dummy_dimsets
+        ]
+        # The lists below organize ops that autocast needs to test.
+        # self.list_name corresponds to test_autocast_list_name in test/test_cpu.py.
+        # Each op is associated with a tuple of valid arguments.
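+        # A rough consumption sketch (illustrative only; the real consumers live in
+        # test/test_cpu.py, and `lists` is just a placeholder name for an
+        # AutocastCPUTestLists instance):
+        #
+        #   op, args = lists.torch_16[0][0], lists.torch_16[0][1]   # e.g. ("conv1d", ...)
+        #   with torch.amp.autocast(device_type="cpu", dtype=torch.bfloat16):
+        #       out = getattr(torch, op)(*args)
+        #   assert out.dtype == torch.bfloat16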
+
+        # Some ops implement built-in type promotion.  These don't need autocasting,
+        # but autocasting relies on their promotion, so we include tests to double-check.
+        self.torch_expect_builtin_promote = [
+            ("eq", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("ge", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("gt", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("le", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("lt", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("ne", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("add", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
+            ("div", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
+            ("mul", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
+        ]
+
+        self.methods_expect_builtin_promote = [
+            ("__eq__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__ge__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__gt__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__le__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__lt__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__ne__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__add__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
+            ("__div__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
+            ("__mul__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
+        ]
+        # The remaining lists organize ops that autocast treats explicitly.
+        self.torch_16 = [
+            ("conv1d", conv_args_fp32[0]),
+            ("conv2d", conv_args_fp32[1]),
+            ("conv3d", conv_args_fp32[2]),
+            ("bmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32),
+                     torch.randn((n, n, n), device=dev, dtype=torch.float32))),
+            ("mm", mat0_fp32 + mat1_fp32),
+            ("matmul", mat0_fp32 + mat1_fp32),
+            ("baddbmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32),
+                         torch.randn((n, n, n), device=dev, dtype=torch.float32),
+                         torch.randn((n, n, n), device=dev, dtype=torch.float32))),
+            ("addmm", mat1_fp32 + mat2_fp32 + mat3_fp32),
+            ("_addmm_activation", mat1_fp32 + mat2_fp32 + mat3_fp32, {"beta": 1, "alpha": 1, "use_gelu": True}),
+            ("addbmm", mat0_fp32 + (torch.randn((n, n, n), device=dev, dtype=torch.float32),
+                                    torch.randn((n, n, n), device=dev, dtype=torch.float32))),
+            ("conv_tbc", (torch.randn((10, 7, 3), device=dev, dtype=torch.float32),
+                          torch.randn((5, 3, 5), device=dev, dtype=torch.float32),
+                          torch.randn(5, device=dev, dtype=torch.float32),
+                          0)),
+            ("conv_transpose1d", conv_args_fp32[0]),
+            ("conv_transpose2d", conv_args_fp32[1]),
+            ("conv_transpose3d", conv_args_fp32[2]),
+            ("prelu", pointwise0_fp32 + element0_fp32),
+            ("_native_multi_head_attention", (torch.randn((n, n, n), device=dev, dtype=torch.float32),
+                                              torch.randn((n, n, n), device=dev, dtype=torch.float32),
+                                              torch.randn((n, n, n), device=dev, dtype=torch.float32),
+                                              n, 4, torch.randn((3 * n, n), device=dev, dtype=torch.float32),
+                                              torch.randn((3 * n), device=dev, dtype=torch.float32),
+                                              torch.randn((n, n), device=dev, dtype=torch.float32),
+                                              torch.randn((n), device=dev, dtype=torch.float32))),
+        ]
+        self.torch_fp32 = [
+            ("poisson_nll_loss", mat0_bf16 + mat1_bf16 + (True, False, 1.e-8, torch.nn._reduction.get_enum('mean'))),
+            ("cosine_embedding_loss", (torch.tensor([[1, 2, 3]], device=dev, dtype=torch.bfloat16),
+                                       torch.tensor([[1, 3, 4]], device=dev, dtype=torch.bfloat16),
+                                       torch.tensor([1], device=dev, dtype=torch.int))),
+            ("hinge_embedding_loss", mat0_bf16 + (torch.ones(n, device=dev, dtype=torch.int),)),
+            ("margin_ranking_loss", mat0_bf16 + mat1_bf16 + (torch.ones((n,), device=dev, dtype=torch.bfloat16),)),
+            ("triplet_margin_loss", mat0_bf16 + mat1_bf16 + mat2_bf16),
+            ("binary_cross_entropy_with_logits", mat0_bf16 + (torch.rand((n, n), device=dev, dtype=torch.bfloat16),)),
+        ]
+        self.nn_16 = [
+            ("linear", mat0_fp32 + mat1_fp32, {}),
+        ]
+        self.nn_fp32 = [
+            ("avg_pool3d", dummy_bf16[3], {"kernel_size": (3, 3, 3), "stride": (1, 1, 1)}),
+            ("binary_cross_entropy", (torch.rand((n, n), device=dev, dtype=torch.bfloat16),) +
+                                     (torch.rand((n, n), device=dev, dtype=torch.bfloat16),)),
+            ("reflection_pad1d", dummy_bf16[2], {"padding": (3, 3)}),
+            ("nll_loss", (torch.rand((n, n), device=dev, dtype=torch.bfloat16),
+                          torch.zeros((n,), device=dev, dtype=torch.long))),
+            ("nll_loss2d", (torch.rand((n, n, n, n), device=dev, dtype=torch.bfloat16),
+                            torch.zeros((n, n, n), device=dev, dtype=torch.long))),
+            ("l1_loss", mat0_bf16 + mat1_bf16),
+            ("smooth_l1_loss", mat0_bf16 + mat1_bf16),
+            ("mse_loss", mat0_bf16 + mat1_bf16),
+            ("multilabel_margin_loss", mat0_bf16 + (torch.ones((n, n), device=dev, dtype=torch.long),)),
+            ("soft_margin_loss", mat0_bf16 + (torch.ones((n, n), device=dev, dtype=torch.long),)),
+            ("multi_margin_loss", mat0_bf16 + (torch.ones((n,), device=dev, dtype=torch.long),)),
+            ("huber_loss", mat0_bf16 + mat1_bf16),
+        ]
+        self.torch_need_autocast_promote = [
+            ("cat", (pointwise0_bf16 + pointwise1_fp32,), (pointwise0_fp16 + pointwise1_fp32,)),
+            ("stack", (pointwise0_bf16 + pointwise1_fp32,), (pointwise0_fp16 + pointwise1_fp32,)),
+        ]
+
+
+class TestAutocast(TestCase):
+    def args_maybe_kwargs(self, op_with_args):
+        if len(op_with_args) == 2:
+            return op_with_args[0], op_with_args[1], {}
+        else:
+            return op_with_args[0], op_with_args[1], op_with_args[2]
+
+    def _run_autocast_outofplace(
+        self,
+        op,
+        args,
+        run_as_type,
+        device,
+        out_type=None,
+        module=torch,
+        add_kwargs=None,
+        amp_dtype=torch.bfloat16,
+    ):
+        # helper to cast args
+        def cast(val, to_type):
+            if isinstance(val, torch.Tensor):
+                return val.to(to_type) if val.is_floating_point() else val
+            elif isinstance(val, collections.abc.Iterable):
+                return type(val)(cast(v, to_type) for v in val)
+            else:
+                return val
+
+        if add_kwargs is None:
+            add_kwargs = {}
+
+        self.assertFalse(torch.is_autocast_enabled(device_type=device))
+        with torch.amp.autocast(device_type=device, dtype=amp_dtype):
+            self.assertTrue(torch.is_autocast_enabled(device_type=device))
+
+            out_type = out_type if out_type is not None else run_as_type
+            output = output_method = None
+
+            # Try module.* variant, if requested:
+            if module is not None and hasattr(module, op):
+                output = getattr(module, op)(*args, **add_kwargs)
+                if isinstance(output, torch.Tensor):
+                    self.assertTrue(
+                        out_type == output.dtype,
+                        f"autocast for torch.{op} produced {output.dtype}, should produce {out_type}",
+                    )
+            # Try Tensor.* variant:
+            if hasattr(torch.Tensor, op):
+                output_method = getattr(args[0], op)(*args[1:], **add_kwargs)
+                if isinstance(output_method, torch.Tensor):
+                    self.assertTrue(
+                        out_type == output_method.dtype,
+                        f"autocast for torch.{op} produced {output_method.dtype}, should produce torch.{out_type}",
+                    )
+
+            self.assertTrue(
+                (output is not None) or (output_method is not None),
+                f"{op} not found as an attribute on either Tensor or the requested module {module}",
+            )
+
+            # Accounts for ops that return Tensors, iterables, and other non-Tensors.
+            # For example, lstm_cell returns a tuple and equal returns bool.
+            def compare(first, second):
+                if isinstance(first, torch.Tensor):
+                    return torch.equal(first, second)
+                elif isinstance(first, collections.abc.Iterable):
+                    return all(compare(f, s) for f, s in zip(first, second))
+                else:
+                    return first == second
+
+            # If both torch.* and Tensor.* variants were found, check outputs are identical
+            if (output is not None) and (output_method is not None):
+                self.assertTrue(type(output) == type(output_method))
+                comparison = compare(output, output_method)
+                self.assertTrue(
+                    comparison, f"torch.{op} result did not match Tensor.{op} result"
+                )
+
+            # Compare numerics to Python-side "autocasting" that (we expect) does the same thing
+            # as the C++-side autocasting, and should be bitwise accurate.
+            output_to_compare = output if output is not None else output_method
+            with torch.amp.autocast(device_type=device, enabled=False):
+                self.assertFalse(
+                    torch.is_autocast_enabled(device_type=device)
+                )
+
+                if module is not None and hasattr(module, op):
+                    control = getattr(module, op)(
+                        *cast(args, run_as_type), **add_kwargs
+                    )
+                else:
+                    control = getattr(args[0].to(run_as_type), op)(
+                        *cast(args[1:], run_as_type), **add_kwargs
+                    )
+                self.assertTrue(type(output_to_compare) == type(control))
+                comparison = compare(output_to_compare, control)
+                self.assertTrue(comparison, f"torch.{op} result did not match control")
+            self.assertTrue(torch.is_autocast_enabled(device_type=device))
+        self.assertFalse(torch.is_autocast_enabled(device_type=device))
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/autograd_function_db.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/autograd_function_db.py
new file mode 100644
index 0000000000000000000000000000000000000000..46abb4bb758dde5752d974f5459ccd77ac9c0f74
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/autograd_function_db.py
@@ -0,0 +1,633 @@
+# mypy: ignore-errors
+
+import torch
+from functools import partial
+from torch.testing import make_tensor
+from torch.testing._internal.opinfo.core import (
+    OpInfo,
+    SampleInput,
+)
+from torch.testing._internal.common_dtype import all_types_and
+import numpy as np
+
+# Note: [autograd.Function db]
+#
+# This is a collection of autograd.Function test cases written as OpInfos
+# so they can easily be consumed by OpInfo-based tests to check if a subsystem
+# supports autograd.Function.
+#
+# Axes:
+# - saves {output, input, intermediate, non-tensor}
+# - {inputs, output} x {single tensor, tensors, arbitrary objects}
+# - Uses {mark_dirty, mark_non_differentiable, once_differentiable}
+
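+# A minimal consumption sketch (illustrative only; the class and test names below
+# are hypothetical, not code from this file). An OpInfo-based test can parametrize
+# over `autograd_function_db` with the `ops` decorator and run each op on its
+# sample inputs:
+#
+#   from torch.testing._internal.common_device_type import ops, instantiate_device_type_tests
+#
+#   class TestAutogradFunction(TestCase):
+#       @ops(autograd_function_db, allowed_dtypes=(torch.float64,))
+#       def test_gradcheck(self, device, dtype, op):
+#           for sample in op.sample_inputs(device, dtype, requires_grad=True):
+#               torch.autograd.gradcheck(op, (sample.input, *sample.args))
+#
+#   instantiate_device_type_tests(TestAutogradFunction, globals())
+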
+
+def to_numpy(tensor):
+    return tensor.cpu().numpy()
+
+
+class NumpyCube(torch.autograd.Function):
+    @staticmethod
+    def forward(input):
+        input_np = to_numpy(input)
+        dinput = torch.tensor(3 * input_np ** 2, device=input.device)
+        return torch.tensor(input_np ** 3, device=input.device), dinput
+
+    @staticmethod
+    def setup_context(ctx, inputs, output):
+        ctx.save_for_backward(inputs[0], output[1])
+        ctx.save_for_forward(inputs[0], output[1])
+
+    @staticmethod
+    def backward(ctx, grad_output, grad_saved):
+        input, dinput = ctx.saved_tensors
+        return NumpyMul.apply(grad_output, dinput) + 6 * NumpyMul.apply(grad_saved, input)
+
+    @staticmethod
+    def vmap(info, in_dims, input):
+        result = NumpyCube.apply(input)
+        return result, (in_dims[0], in_dims[0])
+
+    @staticmethod
+    def jvp(ctx, input_tangent):
+        input, dinput = ctx.saved_tensors
+        return NumpyMul.apply(input_tangent, dinput), 6 * NumpyMul.apply(input_tangent, input)
+
+
+class CubeGenVmap(torch.autograd.Function):
+    generate_vmap_rule = True
+
+    @staticmethod
+    def forward(x):
+        return x ** 3, 3 * x ** 2
+
+    @staticmethod
+    def setup_context(ctx, inputs, outputs):
+        ctx.save_for_backward(inputs[0], outputs[1])
+        ctx.save_for_forward(inputs[0], outputs[1])
+
+    @staticmethod
+    def backward(ctx, grad_output, grad_saved):
+        _input, dinput = ctx.saved_tensors
+        result = grad_output * dinput + 6 * dinput
+        return result
+
+    @staticmethod
+    def jvp(ctx, input_tangent):
+        input, dinput = ctx.saved_tensors
+        return MulGenVmap.apply(input_tangent, dinput), 6 * NumpyMul.apply(input_tangent, input)
+
+
+def sample_inputs_numpy_cube(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    yield SampleInput(make_arg(1, low=0.8, high=2), args=())
+
+
+class NumpyCubeNotComposable(torch.autograd.Function):
+    @staticmethod
+    def forward(input):
+        input_np = to_numpy(input)
+        return torch.tensor(input_np ** 3, device=input.device), input_np
+
+    @staticmethod
+    def setup_context(ctx, inputs, output):
+        _, input_np = output
+        ctx.input_np = input_np
+        ctx.device = inputs[0].device
+
+    @staticmethod
+    @torch.autograd.function.once_differentiable
+    def backward(ctx, grad_output, grad_saved):
+        result_np = 3 * (ctx.input_np ** 2)
+        return torch.tensor(result_np, device=ctx.device)
+
+
+class NumpyMul(torch.autograd.Function):
+    @staticmethod
+    def forward(x, y):
+        return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)
+
+    @staticmethod
+    def setup_context(ctx, inputs, output):
+        ctx.save_for_backward(*inputs)
+        ctx.save_for_forward(*inputs)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        x, y = ctx.saved_tensors
+        gx = None
+        if ctx.needs_input_grad[0]:
+            gx = NumpyMul.apply(grad_output, y)
+        gy = None
+        if ctx.needs_input_grad[1]:
+            gy = NumpyMul.apply(grad_output, x)
+        return gx, gy
+
+    @staticmethod
+    def vmap(info, in_dims, x, y):
+        x_bdim, y_bdim = in_dims
+        x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
+        y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
+        result = NumpyMul.apply(x, y)
+        result = result.movedim(-1, 0)
+        return result, 0
+
+    @staticmethod
+    def jvp(ctx, x_tangent, y_tangent):
+        x, y = ctx.saved_tensors
+        return x_tangent * y + y_tangent * x
+
+def sample_inputs_numpy_mul(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    # Broadcasting
+    yield SampleInput(make_arg(4, low=0.9, high=2), args=(make_arg(3, 4, low=0.9, high=2),))
+
+def sample_inputs_numpy_mul_scalar(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    yield SampleInput(make_arg(4, low=0.9, high=2), args=(), kwargs={"scalar": 3.14})
+
+class MulGenVmap(torch.autograd.Function):
+    generate_vmap_rule = True
+
+    @staticmethod
+    def forward(x, y):
+        return x * y
+
+    @staticmethod
+    def setup_context(ctx, inputs, outputs):
+        ctx.save_for_backward(*inputs)
+        ctx.save_for_forward(*inputs)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        x, y = ctx.saved_tensors
+        gx = None
+        if ctx.needs_input_grad[0]:
+            gx = MulGenVmap.apply(grad_output, y)
+        gy = None
+        if ctx.needs_input_grad[1]:
+            gy = MulGenVmap.apply(grad_output, x)
+        return gx, gy
+
+    @staticmethod
+    def jvp(ctx, x_tangent, y_tangent):
+        x, y = ctx.saved_tensors
+        return x_tangent * y + y_tangent * x
+
+
+class NumpyExp_(torch.autograd.Function):
+    @staticmethod
+    def forward(x):
+        x_np = to_numpy(x)
+        np.exp(x_np, x_np)
+        return x
+
+    @staticmethod
+    def setup_context(ctx, inputs, output):
+        x, = inputs
+        ctx.mark_dirty(x)
+        ctx.save_for_backward(output)
+        ctx.save_for_forward(output)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        output, = ctx.saved_tensors
+        return NumpyMul.apply(grad_output, output)
+
+    @staticmethod
+    def vmap(info, in_dims, x):
+        NumpyExp_.apply(x)
+        return x, in_dims[0]
+
+    @staticmethod
+    def jvp(ctx, x_tangent):
+        # Doesn't call numpy operations because I didn't want to write NumpyMul_
+        output, = ctx.saved_tensors
+        x_tangent.mul_(output)
+        return x_tangent
+
+class NumpySort(torch.autograd.Function):
+    @staticmethod
+    def forward(x, dim):
+        device = x.device
+        x = to_numpy(x)
+        ind = np.argsort(x, axis=dim)
+        ind_inv = np.argsort(ind, axis=dim)
+        return (
+            torch.tensor(x, device=device),
+            torch.tensor(ind, device=device),
+            torch.tensor(ind_inv, device=device),
+        )
+
+    @staticmethod
+    def setup_context(ctx, inputs, output):
+        _x, dim = inputs
+        _, ind, ind_inv = output
+        ctx.mark_non_differentiable(ind, ind_inv)
+        ctx.save_for_backward(ind, ind_inv)
+        ctx.save_for_forward(ind, ind_inv)
+        ctx.dim = dim
+
+    @staticmethod
+    def backward(ctx, grad_output, _0, _1):
+        ind, ind_inv = ctx.saved_tensors
+        return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None
+
+    @staticmethod
+    def vmap(info, in_dims, x, dim):
+        x_bdim, _ = in_dims
+        x = x.movedim(x_bdim, 0)
+        # wrap dim
+        dim = dim if dim >= 0 else dim + x.dim() - 1
+        return NumpySort.apply(x, dim + 1), (0, 0, 0)
+
+    @staticmethod
+    def jvp(ctx, x_tangent, _):
+        ind, ind_inv = ctx.saved_tensors
+        return NumpyTake.apply(x_tangent, ind, ind_inv, ctx.dim), None, None
+
+class SortGenVmap(torch.autograd.Function):
+    generate_vmap_rule = True
+
+    @staticmethod
+    def forward(x, dim):
+        ind = torch.argsort(x, dim=dim)
+        ind_inv = torch.argsort(ind, axis=dim)
+        result = torch.take_along_dim(x, ind, dim=dim)
+        return result, ind, ind_inv
+
+    @staticmethod
+    def setup_context(ctx, inputs, outputs):
+        x, dim = inputs
+        _, ind, ind_inv = outputs
+        ctx.mark_non_differentiable(ind, ind_inv)
+        ctx.save_for_backward(ind, ind_inv)
+        ctx.save_for_forward(ind, ind_inv)
+        ctx.dim = dim
+
+    @staticmethod
+    def backward(ctx, grad_output, _0, _1):
+        ind, ind_inv = ctx.saved_tensors
+        return TakeGenVmap.apply(grad_output, ind_inv, ind, ctx.dim), None
+
+    @staticmethod
+    def jvp(ctx, x_tangent, _):
+        ind, ind_inv = ctx.saved_tensors
+        return TakeGenVmap.apply(x_tangent, ind, ind_inv, ctx.dim), None, None
+
+
+def sample_inputs_numpy_sort(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    yield SampleInput(make_arg(3, 5), args=(1,))
+
+
+def sample_inputs_numpy_take(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    tensor = make_arg(3, 5)
+    dim = 1
+    _, ind, ind_inv = NumpySort.apply(tensor, 1)
+    yield SampleInput(tensor, args=(ind, ind_inv, dim))
+
+
+class NumpyTake(torch.autograd.Function):
+    @staticmethod
+    def forward(x, ind, ind_inv, dim):
+        device = x.device
+        x = to_numpy(x)
+        ind = to_numpy(ind)
+        return torch.tensor(np.take_along_axis(x, ind, dim), device=device)
+
+    @staticmethod
+    def setup_context(ctx, inputs, output):
+        _x, ind, ind_inv, dim = inputs
+        ctx.save_for_backward(ind, ind_inv)
+        ctx.save_for_forward(ind, ind_inv)
+        ctx.dim = dim
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        ind, ind_inv = ctx.saved_tensors
+        result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim)
+        return result, None, None, None
+
+    @staticmethod
+    def vmap(info, in_dims, x, ind, ind_inv, dim):
+        x_bdim, ind_bdim, ind_inv_bdim, _ = in_dims
+
+        # wrap dim
+        logical_dim = x.dim() if x_bdim is None else x_bdim - 1
+        dim = dim if dim >= 0 else dim + logical_dim
+
+        def expand_bdim(x, x_bdim):
+            if x_bdim is None:
+                return x.expand(info.batch_size, *x.shape)
+            return x.movedim(x_bdim, 0)
+
+        x = expand_bdim(x, x_bdim)
+        ind = expand_bdim(ind, ind_bdim)
+        ind_inv = expand_bdim(ind_inv, ind_inv_bdim)
+
+        return NumpyTake.apply(x, ind, ind_inv, dim + 1), 0
+
+    @staticmethod
+    def jvp(ctx, x_tangent, ind_tangent, ind_inv_tangent, _):
+        assert ind_tangent is None
+        assert ind_inv_tangent is None
+        ind, ind_inv = ctx.saved_tensors
+        return NumpyTake.apply(x_tangent, ind, ind_inv, ctx.dim)
+
+class TakeGenVmap(torch.autograd.Function):
+    generate_vmap_rule = True
+
+    @staticmethod
+    def forward(x, ind, ind_inv, dim):
+        return torch.take_along_dim(x, ind, dim)
+
+    @staticmethod
+    def setup_context(ctx, inputs, outputs):
+        _x, ind, ind_inv, dim = inputs
+        ctx.save_for_backward(ind, ind_inv)
+        ctx.save_for_forward(ind, ind_inv)
+        ctx.dim = dim
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        ind, ind_inv = ctx.saved_tensors
+        result = TakeGenVmap.apply(grad_output, ind_inv, ind, ctx.dim)
+        return result, None, None, None
+
+    @staticmethod
+    def jvp(ctx, x_tangent, ind_tangent, ind_inv_tangent, _):
+        ind, ind_inv = ctx.saved_tensors
+        return TakeGenVmap.apply(x_tangent, ind, ind_inv, ctx.dim)
+
+class Select(torch.autograd.Function):
+    @staticmethod
+    def forward(x, idx):
+        return x[idx]
+
+    @staticmethod
+    def setup_context(ctx, inputs, output):
+        x, idx = inputs
+        ctx.x_shape = x.shape
+        ctx.idx = idx
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        result = grad_output.new_zeros(ctx.x_shape)
+        result[ctx.idx] = grad_output
+        return result, None
+
+    @staticmethod
+    def vmap(info, in_dims, x, idx):
+        x_bdim, _ = in_dims
+        x = x.movedim(x_bdim, 1)
+        return Select.apply(x, idx), 0
+
+    @staticmethod
+    def jvp(ctx, x_tangent, _):
+        return Select.apply(x_tangent, ctx.idx)
+
+class SelectGenVmap(torch.autograd.Function):
+    generate_vmap_rule = True
+
+    @staticmethod
+    def forward(x, idx):
+        return x[idx]
+
+    @staticmethod
+    def setup_context(ctx, inputs, outputs):
+        x, idx = inputs
+        ctx.x_shape = x.shape
+        ctx.idx = idx
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        result = grad_output.new_zeros(ctx.x_shape)
+        result[ctx.idx] = grad_output
+        return result, None
+
+    @staticmethod
+    def jvp(ctx, x_tangent, _):
+        return SelectGenVmap.apply(x_tangent, ctx.idx)
+
+
+def sample_inputs_select(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    yield SampleInput(make_arg(3, 5), args=(2,))
+
+class ScaleGradGenVmap(torch.autograd.Function):
+    generate_vmap_rule = True
+    scale = 3.14
+
+    @staticmethod
+    def forward(x):
+        return x.clone()
+
+    @staticmethod
+    def setup_context(ctx, inputs, outputs):
+        pass
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return grad_output * ScaleGradGenVmap.scale
+
+    @staticmethod
+    def jvp(ctx, x_tangent):
+        return x_tangent * ScaleGradGenVmap.scale
+
+class ZeroGradientsGenVmap(torch.autograd.Function):
+    generate_vmap_rule = True
+
+    @staticmethod
+    def forward(x, y):
+        return x.clone(), y.clone()
+
+    @staticmethod
+    def setup_context(ctx, inputs, outputs):
+        pass
+
+    @staticmethod
+    def backward(ctx, gx, gy):
+        # Intentionally returning torch.zeros instead of zeros_like or new_zeros.
+        # Also intentionally not None.
+        return (
+            # Intentionally too-large gradient
+            torch.zeros(3, 4, *gx.shape, dtype=gx.dtype, device=gx.device),
+            torch.zeros(gy.shape, dtype=gy.dtype, device=gy.device),
+        )
+
+    @staticmethod
+    def jvp(ctx, gx, gy):
+        # Intentionally returning torch.zeros instead of zeros_like or new_zeros.
+        # Also intentionally not None.
+        return (
+            torch.zeros(gx.shape, dtype=gx.dtype, device=gx.device),
+            torch.zeros(gy.shape, dtype=gy.dtype, device=gy.device),
+        )
+
+
+def sample_inputs_forward_default_args(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    yield SampleInput(make_arg(3, 5))
+
+
+class ForwardHasDefaultArgs(torch.autograd.Function):
+    @staticmethod
+    def forward(x, idx=(2,)):
+        return x[idx]
+
+    @staticmethod
+    def setup_context(ctx, inputs, output):
+        x, idx = inputs
+        ctx.x_shape = x.shape
+        ctx.idx = idx
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        result = grad_output.new_zeros(ctx.x_shape)
+        result[ctx.idx] = grad_output
+        return result, None
+
+    @staticmethod
+    def vmap(info, in_dims, x, idx):
+        x_bdim, _ = in_dims
+        x = x.movedim(x_bdim, 1)
+        return ForwardHasDefaultArgs.apply(x, idx), 0
+
+    @staticmethod
+    def jvp(ctx, x_tangent, _):
+        return ForwardHasDefaultArgs.apply(x_tangent, ctx.idx)
+
+
+autograd_function_db = [
+    OpInfo(
+        'NumpyCubeAutogradFunction',
+        op=NumpyCube.apply,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_numpy_cube,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+    ),
+    OpInfo(
+        'NumpyExpMarkDirtyAutogradFunction',
+        op=lambda x: NumpyExp_.apply(x.clone()),
+        inplace_variant=NumpyExp_.apply,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_numpy_cube,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+    ),
+    OpInfo(
+        'NumpyMulAutogradFunction',
+        op=NumpyMul.apply,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_numpy_mul,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+    ),
+    OpInfo(
+        'NumpyCubeNotComposableAutogradFunction',
+        op=lambda x: NumpyCubeNotComposable.apply(x)[0],
+        supports_forward_ad=False,
+        supports_fwgrad_bwgrad=False,
+        sample_inputs_func=sample_inputs_numpy_cube,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+    ),
+    OpInfo(
+        'NumpySortAutogradFunction',
+        op=NumpySort.apply,
+        supports_forward_ad=False,
+        supports_fwgrad_bwgrad=False,
+        sample_inputs_func=sample_inputs_numpy_sort,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+        gradcheck_wrapper=lambda y, ind: y,
+    ),
+    OpInfo(
+        'NumpyTakeAutogradFunction',
+        op=NumpyTake.apply,
+        supports_forward_ad=False,
+        supports_fwgrad_bwgrad=False,
+        sample_inputs_func=sample_inputs_numpy_take,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+    ),
+    OpInfo(
+        'SelectAutogradFunction',
+        op=Select.apply,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_select,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+    ),
+    OpInfo(
+        'CubeGenVmapAutogradFunction',
+        op=CubeGenVmap.apply,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_numpy_cube,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+    ),
+    OpInfo(
+        'MulGenVmapAutogradFunction',
+        op=MulGenVmap.apply,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_numpy_mul,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+    ),
+    OpInfo(
+        'SortGenVmapAutogradFunction',
+        op=SortGenVmap.apply,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_numpy_sort,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+        gradcheck_wrapper=lambda y, ind: y,
+    ),
+    OpInfo(
+        'SelectGenVmapAutogradFunction',
+        op=SelectGenVmap.apply,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_select,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+    ),
+    OpInfo(
+        'ScaleGradGenVmapAutogradFunction',
+        op=ScaleGradGenVmap.apply,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_numpy_cube,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+    ),
+    OpInfo(
+        'ZeroGradientsGenVmapAutogradFunction',
+        op=ZeroGradientsGenVmap.apply,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_numpy_mul,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+    ),
+    OpInfo(
+        'ForwardHasDefaultArgsAutogradFunction',
+        op=ForwardHasDefaultArgs.apply,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_forward_default_args,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+    ),
+]
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/check_kernel_launches.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/check_kernel_launches.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2219ef4ea56aa306dfdd3af18b7403af8384c78
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/check_kernel_launches.py
@@ -0,0 +1,164 @@
+# mypy: ignore-errors
+
+import os
+import re
+import sys
+
+__all__ = [
+    "check_code_for_cuda_kernel_launches",
+    "check_cuda_kernel_launches",
+]
+
+# FILES TO EXCLUDE (match is done with suffix using `endswith`)
+# You wouldn't drive without a seatbelt, though, so why would you
+# launch a kernel without some safety? Use this as a quick workaround
+# for a problem with the checker, fix the checker, then de-exclude
+# the files in question.
+exclude_files: list[str] = []
+
+# Without using a C++ AST we can't 100% detect kernel launches, so we
+# model them as having the pattern "<<<parameters>>>(arguments);"
+# We then require that `C10_CUDA_KERNEL_LAUNCH_CHECK` be
+# the next statement.
+#
+# We model the next statement as ending at the next `}` or `;`.
+# If we see `}` then a clause ended (bad) if we see a semi-colon then
+# we expect the launch check just before it.
+#
+# Since the kernel launch can include lambda statements, it's important
+# to find the correct end-paren of the kernel launch. Doing this with
+# pure regex requires recursive regex, which aren't part of the Python
+# standard library. To avoid an additional dependency, we build a prefix
+# regex that finds the start of a kernel launch, use a paren-matching
+# algorithm to find the end of the launch, and then another regex to
+# determine if a launch check is present.
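+#
+# A rough usage sketch (the input string here is illustrative, not taken from any
+# real source file):
+#
+#   from torch.testing._internal.check_kernel_launches import (
+#       check_code_for_cuda_kernel_launches,
+#   )
+#   snippet = (
+#       "unsafe_kernel<<<blocks, threads>>>(arg);\n"   # no check follows -> flagged
+#       "safe_kernel<<<blocks, threads>>>(arg);\n"
+#       "C10_CUDA_KERNEL_LAUNCH_CHECK();\n"            # check follows -> accepted
+#   )
+#   assert check_code_for_cuda_kernel_launches(snippet) == 1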
+
+# Finds potential starts of kernel launches
+kernel_launch_start = re.compile(
+    r"^.*<<<[^>]+>>>\s*\(", flags=re.MULTILINE
+)
+
+# This pattern should start at the character after the final paren of the
+# kernel launch. It returns a match if the launch check is not the next statement
+has_check = re.compile(
+    r"\s*;(?![^;}]*C10_CUDA_KERNEL_LAUNCH_CHECK\(\);)", flags=re.MULTILINE
+)
+
+def find_matching_paren(s: str, startpos: int) -> int:
+    """Given a string "prefix (unknown number of characters) suffix"
+    and the position of the first `(` returns the index of the character
+    1 past the `)`, accounting for paren nesting
+    """
+    opening = 0
+    for i, c in enumerate(s[startpos:]):
+        if c == '(':
+            opening += 1
+        elif c == ')':
+            opening -= 1
+            if opening == 0:
+                return startpos + i + 1
+
+    raise IndexError("Closing parens not found!")
+
+
+def should_exclude_file(filename) -> bool:
+    for exclude_suffix in exclude_files:
+        if filename.endswith(exclude_suffix):
+            return True
+    return False
+
+
+def check_code_for_cuda_kernel_launches(code, filename=None):
+    """Checks code for CUDA kernel launches without cuda error checks.
+
+    Args:
+        code     - The code to check
+        filename - Filename of file containing the code. Used only for display
+                   purposes, so you can put anything here.
+
+    Returns:
+        The number of unsafe kernel launches in the code
+    """
+    if filename is None:
+        filename = "##Python Function Call##"
+
+    # We break the code apart and put it back together to add
+    # helpful line numberings for identifying problem areas
+    code = enumerate(code.split("\n"))                             # Split by line breaks
+    code = [f"{lineno}: {linecode}" for lineno, linecode in code]  # Number the lines
+    code = '\n'.join(code)                                         # Put it back together
+
+    num_launches_without_checks = 0
+    for m in kernel_launch_start.finditer(code):
+        end_paren = find_matching_paren(code, m.end() - 1)
+        if has_check.match(code, end_paren):
+            num_launches_without_checks += 1
+            context = code[m.start():end_paren + 1]
+            print(f"Missing C10_CUDA_KERNEL_LAUNCH_CHECK in '{filename}'. Context:\n{context}", file=sys.stderr)
+
+    return num_launches_without_checks
+
+
+def check_file(filename):
+    """Checks a file for CUDA kernel launches without cuda error checks
+
+    Args:
+        filename - File to check
+
+    Returns:
+        The number of unsafe kernel launches in the file
+    """
+    if not (filename.endswith((".cu", ".cuh"))):
+        return 0
+    if should_exclude_file(filename):
+        return 0
+    with open(filename) as f:
+        contents = f.read()
+        unsafeCount = check_code_for_cuda_kernel_launches(contents, filename)
+    return unsafeCount
+
+
+def check_cuda_kernel_launches():
+    """Checks all pytorch code for CUDA kernel launches without cuda error checks
+
+    Returns:
+        The number of unsafe kernel launches in the codebase
+    """
+    torch_dir = os.path.dirname(os.path.realpath(__file__))
+    torch_dir = os.path.dirname(torch_dir)  # Go up to parent torch
+    torch_dir = os.path.dirname(torch_dir)  # Go up to parent caffe2
+
+    kernels_without_checks = 0
+    files_without_checks = []
+    for root, dirnames, filenames in os.walk(torch_dir):
+        # `$BASE/build` and `$BASE/torch/include` are generated
+        # so we don't want to flag their contents
+        if root == os.path.join(torch_dir, "build") or root == os.path.join(torch_dir, "torch/include"):
+            # Curtail search by modifying dirnames and filenames in place
+            # Yes, this is the way to do this, see `help(os.walk)`
+            dirnames[:] = []
+            continue
+
+        for x in filenames:
+            filename = os.path.join(root, x)
+            file_result = check_file(filename)
+            if file_result > 0:
+                kernels_without_checks += file_result
+                files_without_checks.append(filename)
+
+    if kernels_without_checks > 0:
+        count_str = f"Found {kernels_without_checks} instances in " \
+                    f"{len(files_without_checks)} files where kernel " \
+                    "launches didn't have checks."
+        print(count_str, file=sys.stderr)
+        print("Files without checks:", file=sys.stderr)
+        for x in files_without_checks:
+            print(f"\t{x}", file=sys.stderr)
+        print(count_str, file=sys.stderr)
+
+    return kernels_without_checks
+
+
+if __name__ == "__main__":
+    unsafe_launches = check_cuda_kernel_launches()
+    sys.exit(0 if unsafe_launches == 0 else 1)
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/codegen/__init__.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/codegen/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e3572cfc4c6a0ddc3d8fa2e1b056415204acdfa
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/codegen/__init__.py
@@ -0,0 +1 @@
+# mypy: ignore-errors
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_cuda.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_cuda.py
new file mode 100644
index 0000000000000000000000000000000000000000..a211851d671fa254caab98cf28867721960f5f5c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_cuda.py
@@ -0,0 +1,347 @@
+# mypy: ignore-errors
+
+r"""This file is allowed to initialize CUDA context when imported."""
+
+import functools
+import torch
+import torch.cuda
+from torch.testing._internal.common_utils import LazyVal, TEST_NUMBA, TEST_WITH_ROCM, TEST_CUDA, IS_WINDOWS, IS_MACOS
+import inspect
+import contextlib
+import os
+import unittest
+
+
+CUDA_ALREADY_INITIALIZED_ON_IMPORT = torch.cuda.is_initialized()
+
+
+TEST_MULTIGPU = TEST_CUDA and torch.cuda.device_count() >= 2
+CUDA_DEVICE = torch.device("cuda:0") if TEST_CUDA else None
+# note: if ROCm is targeted, TEST_CUDNN is code for TEST_MIOPEN
+if TEST_WITH_ROCM:
+    TEST_CUDNN = LazyVal(lambda: TEST_CUDA)
+else:
+    TEST_CUDNN = LazyVal(lambda: TEST_CUDA and torch.backends.cudnn.is_acceptable(torch.tensor(1., device=CUDA_DEVICE)))
+
+TEST_CUDNN_VERSION = LazyVal(lambda: torch.backends.cudnn.version() if TEST_CUDNN else 0)
+
+SM53OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (5, 3))
+SM60OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (6, 0))
+SM70OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (7, 0))
+SM75OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (7, 5))
+SM80OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0))
+SM89OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9))
+SM90OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0))
+SM100OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (10, 0))
+SM120OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (12, 0))
+
+IS_THOR = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 10
+                  and torch.cuda.get_device_capability()[1] > 0)
+IS_JETSON = LazyVal(lambda: torch.cuda.is_available() and (torch.cuda.get_device_capability() in [(7, 2), (8, 7)] or IS_THOR))
+IS_SM89 = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() == (8, 9))
+
+def evaluate_gfx_arch_within(arch_list):
+    if not torch.cuda.is_available():
+        return False
+    gcn_arch_name = torch.cuda.get_device_properties('cuda').gcnArchName
+    effective_arch = os.environ.get('PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE', gcn_arch_name)
+    # gcnArchName can be a complicated string like gfx90a:sramecc+:xnack-,
+    # so instead of an exact comparison we check whether any listed arch
+    # appears as a substring of the effective arch name.
+    return any(arch in effective_arch for arch in arch_list)
+
+def CDNA3OrLater():
+    return evaluate_gfx_arch_within(["gfx940", "gfx941", "gfx942", "gfx950"])
+
+def CDNA2OrLater():
+    return evaluate_gfx_arch_within(["gfx90a", "gfx942"])
+
+def evaluate_platform_supports_flash_attention():
+    if TEST_WITH_ROCM:
+        arch_list = ["gfx90a", "gfx942", "gfx1100", "gfx1201", "gfx950"]
+        if os.environ.get("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "0") != "0":
+            arch_list += ["gfx1101", "gfx1150", "gfx1151", "gfx1200"]
+        return evaluate_gfx_arch_within(arch_list)
+    if TEST_CUDA:
+        return not IS_WINDOWS and SM80OrLater
+    return False
+
+def evaluate_platform_supports_efficient_attention():
+    if TEST_WITH_ROCM:
+        arch_list = ["gfx90a", "gfx942", "gfx1100", "gfx1201", "gfx950"]
+        if os.environ.get("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "0") != "0":
+            arch_list += ["gfx1101", "gfx1150", "gfx1151", "gfx1200"]
+        return evaluate_gfx_arch_within(arch_list)
+    if TEST_CUDA:
+        return True
+    return False
+
+def evaluate_platform_supports_cudnn_attention():
+    return (not TEST_WITH_ROCM) and SM80OrLater and (TEST_CUDNN_VERSION >= 90000)
+
+PLATFORM_SUPPORTS_FLASH_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_flash_attention())
+PLATFORM_SUPPORTS_MEM_EFF_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_efficient_attention())
+PLATFORM_SUPPORTS_CUDNN_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_cudnn_attention())
+# This condition always evaluates to PLATFORM_SUPPORTS_MEM_EFF_ATTENTION but for logical clarity we keep it separate
+PLATFORM_SUPPORTS_FUSED_ATTENTION: bool = LazyVal(lambda: PLATFORM_SUPPORTS_FLASH_ATTENTION or
+                                                  PLATFORM_SUPPORTS_CUDNN_ATTENTION or
+                                                  PLATFORM_SUPPORTS_MEM_EFF_ATTENTION)
+
+PLATFORM_SUPPORTS_FUSED_SDPA: bool = TEST_CUDA and not TEST_WITH_ROCM
+
+PLATFORM_SUPPORTS_BF16: bool = LazyVal(lambda: TEST_CUDA and SM80OrLater)
+
+def evaluate_platform_supports_fp8():
+    if torch.cuda.is_available():
+        if torch.version.hip:
+            ROCM_VERSION = tuple(int(v) for v in torch.version.hip.split('.')[:2])
+            archs = ['gfx94']
+            if ROCM_VERSION >= (6, 3):
+                archs.extend(['gfx120'])
+            if ROCM_VERSION >= (6, 5):
+                archs.append('gfx95')
+            for arch in archs:
+                if arch in torch.cuda.get_device_properties(0).gcnArchName:
+                    return True
+        else:
+            return SM90OrLater or torch.cuda.get_device_capability() == (8, 9)
+    return False
+
+PLATFORM_SUPPORTS_FP8: bool = LazyVal(lambda: evaluate_platform_supports_fp8())
+
+PLATFORM_SUPPORTS_MX_GEMM: bool = LazyVal(lambda: TEST_CUDA and SM100OrLater)
+
+if TEST_NUMBA:
+    try:
+        import numba.cuda
+        TEST_NUMBA_CUDA = numba.cuda.is_available()
+    except Exception:
+        TEST_NUMBA_CUDA = False
+        TEST_NUMBA = False
+else:
+    TEST_NUMBA_CUDA = False
+
+# Used below in `initialize_cuda_context_rng` to ensure that CUDA context and
+# RNG have been initialized.
+__cuda_ctx_rng_initialized = False
+
+
+# after this call, CUDA context and RNG must have been initialized on each GPU
+def initialize_cuda_context_rng():
+    global __cuda_ctx_rng_initialized
+    assert TEST_CUDA, 'CUDA must be available when calling initialize_cuda_context_rng'
+    if not __cuda_ctx_rng_initialized:
+        # initialize cuda context and rng for memory tests
+        for i in range(torch.cuda.device_count()):
+            torch.randn(1, device=f"cuda:{i}")
+        __cuda_ctx_rng_initialized = True
+
+
+@contextlib.contextmanager
+def tf32_off():
+    old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
+    try:
+        torch.backends.cuda.matmul.allow_tf32 = False
+        with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=False):
+            yield
+    finally:
+        torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul
+
+
+@contextlib.contextmanager
+def tf32_on(self, tf32_precision=1e-5):
+    if torch.version.hip:
+        hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
+        os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
+    old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
+    old_precision = self.precision
+    try:
+        torch.backends.cuda.matmul.allow_tf32 = True
+        self.precision = tf32_precision
+        with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=True):
+            yield
+    finally:
+        if torch.version.hip:
+            if hip_allow_tf32 is not None:
+                os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
+            else:
+                del os.environ["HIPBLASLT_ALLOW_TF32"]
+        torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul
+        self.precision = old_precision
+
+
+@contextlib.contextmanager
+def tf32_enabled():
+    """
+    Context manager to temporarily enable TF32 for CUDA operations.
+    Restores the previous TF32 state after exiting the context.
+    """
+    old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
+    try:
+        torch.backends.cuda.matmul.allow_tf32 = True
+        with torch.backends.cudnn.flags(
+            enabled=None, benchmark=None, deterministic=None, allow_tf32=True
+        ):
+            yield
+    finally:
+        torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul
+
+
+# This is a wrapper that runs a test twice: once with allow_tf32=True and
+# once with allow_tf32=False. When running with allow_tf32=True, it uses
+# the reduced precision specified by the argument. For example:
+#    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
+#    @tf32_on_and_off(0.005)
+#    def test_matmul(self, device, dtype):
+#        a = ...; b = ...;
+#        c = torch.matmul(a, b)
+#        self.assertEqual(c, expected)
+# In the above example, when testing torch.float32 and torch.complex64 on CUDA
+# on a CUDA >= 11 build on an >=Ampere architecture, the matmul will be running at
+# TF32 mode and TF32 mode off, and on TF32 mode, the assertEqual will use reduced
+# precision to check values.
+#
+# This decorator can be used for function with or without device/dtype, such as
+# @tf32_on_and_off(0.005)
+# def test_my_op(self)
+# @tf32_on_and_off(0.005)
+# def test_my_op(self, device)
+# @tf32_on_and_off(0.005)
+# def test_my_op(self, device, dtype)
+# @tf32_on_and_off(0.005)
+# def test_my_op(self, dtype)
+# if neither device nor dtype is specified, it will check if the system has ampere device
+# if device is specified, it will check if device is cuda
+# if dtype is specified, it will check if dtype is float32 or complex64
+# tf32 and fp32 are different only when all the three checks pass
+def tf32_on_and_off(tf32_precision=1e-5):
+    def with_tf32_disabled(self, function_call):
+        with tf32_off():
+            function_call()
+
+    def with_tf32_enabled(self, function_call):
+        with tf32_on(self, tf32_precision):
+            function_call()
+
+    def wrapper(f):
+        params = inspect.signature(f).parameters
+        arg_names = tuple(params.keys())
+
+        @functools.wraps(f)
+        def wrapped(*args, **kwargs):
+            kwargs.update(zip(arg_names, args))
+            cond = torch.cuda.is_tf32_supported()
+            if 'device' in kwargs:
+                cond = cond and (torch.device(kwargs['device']).type == 'cuda')
+            if 'dtype' in kwargs:
+                cond = cond and (kwargs['dtype'] in {torch.float32, torch.complex64})
+            if cond:
+                with_tf32_disabled(kwargs['self'], lambda: f(**kwargs))
+                with_tf32_enabled(kwargs['self'], lambda: f(**kwargs))
+            else:
+                f(**kwargs)
+
+        return wrapped
+    return wrapper
+
+
+# This is a wrapper that wraps a test to run it with TF32 turned off.
+# This wrapper is designed to be used when a test uses matmul or convolutions
+# but the purpose of that test is not testing matmul or convolutions.
+# Disabling TF32 ensures that torch.float tensors are always computed
+# at full precision.
+def with_tf32_off(f):
+    @functools.wraps(f)
+    def wrapped(*args, **kwargs):
+        with tf32_off():
+            return f(*args, **kwargs)
+
+    return wrapped
+
+def _get_magma_version():
+    if 'Magma' not in torch.__config__.show():
+        return (0, 0)
+    position = torch.__config__.show().find('Magma ')
+    version_str = torch.__config__.show()[position + len('Magma '):].split('\n')[0]
+    return tuple(int(x) for x in version_str.split("."))
+
+def _get_torch_cuda_version():
+    if torch.version.cuda is None:
+        return (0, 0)
+    cuda_version = str(torch.version.cuda)
+    return tuple(int(x) for x in cuda_version.split("."))
+
+def _get_torch_rocm_version():
+    if not TEST_WITH_ROCM or torch.version.hip is None:
+        return (0, 0)
+    rocm_version = str(torch.version.hip)
+    rocm_version = rocm_version.split("-")[0]    # ignore git sha
+    return tuple(int(x) for x in rocm_version.split("."))
+
+def _check_cusparse_generic_available():
+    return not TEST_WITH_ROCM
+
+def _check_hipsparse_generic_available():
+    if not TEST_WITH_ROCM:
+        return False
+    if not torch.version.hip:
+        return False
+
+    rocm_version = str(torch.version.hip)
+    rocm_version = rocm_version.split("-")[0]    # ignore git sha
+    rocm_version_tuple = tuple(int(x) for x in rocm_version.split("."))
+    return not (rocm_version_tuple is None or rocm_version_tuple < (5, 1))
+
+
+TEST_CUSPARSE_GENERIC = _check_cusparse_generic_available()
+TEST_HIPSPARSE_GENERIC = _check_hipsparse_generic_available()
+
+# Shared by test_torch.py and test_multigpu.py
+def _create_scaling_models_optimizers(device="cuda", optimizer_ctor=torch.optim.SGD, optimizer_kwargs=None):
+    # Create a module+optimizer that will use scaling, and a control module+optimizer
+    # that will not use scaling, against which the scaling-enabled module+optimizer can be compared.
+    mod_control = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)).to(device=device)
+    mod_scaling = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)).to(device=device)
+    with torch.no_grad():
+        for c, s in zip(mod_control.parameters(), mod_scaling.parameters()):
+            s.copy_(c)
+
+    kwargs = {"lr": 1.0}
+    if optimizer_kwargs is not None:
+        kwargs.update(optimizer_kwargs)
+    opt_control = optimizer_ctor(mod_control.parameters(), **kwargs)
+    opt_scaling = optimizer_ctor(mod_scaling.parameters(), **kwargs)
+
+    return mod_control, mod_scaling, opt_control, opt_scaling
+
+# Shared by test_torch.py, test_cuda.py and test_multigpu.py
+def _create_scaling_case(device="cuda", dtype=torch.float, optimizer_ctor=torch.optim.SGD, optimizer_kwargs=None):
+    data = [(torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device)),
+            (torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device)),
+            (torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device)),
+            (torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device))]
+
+    loss_fn = torch.nn.MSELoss().to(device)
+
+    skip_iter = 2
+
+    return _create_scaling_models_optimizers(
+        device=device, optimizer_ctor=optimizer_ctor, optimizer_kwargs=optimizer_kwargs,
+    ) + (data, loss_fn, skip_iter)
+
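+# A rough usage sketch of the scaling helpers above (illustrative only; the real
+# consumers are the GradScaler tests in test_torch.py, test_cuda.py and
+# test_multigpu.py):
+#
+#   mod_control, mod_scaling, opt_control, opt_scaling, data, loss_fn, skip_iter = \
+#       _create_scaling_case()
+#   scaler = torch.amp.GradScaler("cuda")
+#   for inp, target in data:
+#       opt_scaling.zero_grad()
+#       with torch.autocast("cuda"):
+#           loss = loss_fn(mod_scaling(inp), target)
+#       scaler.scale(loss).backward()
+#       scaler.step(opt_scaling)
+#       scaler.update()
+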
+
+def xfailIfSM89(func):
+    return func if not IS_SM89 else unittest.expectedFailure(func)
+
+def xfailIfSM100OrLater(func):
+    return func if not SM100OrLater else unittest.expectedFailure(func)
+
+def xfailIfSM120OrLater(func):
+    return func if not SM120OrLater else unittest.expectedFailure(func)
+
+def xfailIfDistributedNotSupported(func):
+    return func if not (IS_MACOS or IS_JETSON) else unittest.expectedFailure(func)
+
+# Importing this module should NOT eagerly initialize CUDA
+if not CUDA_ALREADY_INITIALIZED_ON_IMPORT:
+    assert not torch.cuda.is_initialized()
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_device_type.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_device_type.py
new file mode 100644
index 0000000000000000000000000000000000000000..01499280da8f5dfcd8ad6f9e15c54093dfe17a52
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_device_type.py
@@ -0,0 +1,1981 @@
+# mypy: ignore-errors
+
+import copy
+import gc
+import inspect
+import os
+import runpy
+import sys
+import threading
+import unittest
+from collections import namedtuple
+from collections.abc import Iterable, Sequence
+from enum import Enum
+from functools import partial, wraps
+from typing import Any, Callable, ClassVar, Optional, TypeVar, Union
+from typing_extensions import ParamSpec
+
+import torch
+from torch._inductor.utils import GPU_TYPES
+from torch.testing._internal.common_cuda import (
+    _get_torch_cuda_version,
+    _get_torch_rocm_version,
+    TEST_CUSPARSE_GENERIC,
+    TEST_HIPSPARSE_GENERIC,
+)
+from torch.testing._internal.common_dtype import get_all_dtypes
+from torch.testing._internal.common_utils import (
+    _TestParametrizer,
+    clear_tracked_input,
+    compose_parametrize_fns,
+    dtype_name,
+    get_tracked_input,
+    IS_FBCODE,
+    IS_MACOS,
+    is_privateuse1_backend_available,
+    IS_REMOTE_GPU,
+    IS_SANDCASTLE,
+    IS_WINDOWS,
+    NATIVE_DEVICES,
+    PRINT_REPRO_ON_FAILURE,
+    skipCUDANonDefaultStreamIf,
+    skipIfTorchDynamo,
+    TEST_HPU,
+    TEST_MKL,
+    TEST_MPS,
+    TEST_WITH_ASAN,
+    TEST_WITH_MIOPEN_SUGGEST_NHWC,
+    TEST_WITH_ROCM,
+    TEST_WITH_TORCHINDUCTOR,
+    TEST_WITH_TSAN,
+    TEST_WITH_UBSAN,
+    TEST_XPU,
+    TestCase,
+)
+
+
+_T = TypeVar("_T")
+_P = ParamSpec("_P")
+
+try:
+    import psutil  # type: ignore[import]
+
+    HAS_PSUTIL = True
+except ModuleNotFoundError:
+    HAS_PSUTIL = False
+    psutil = None
+
+# Note [Writing Test Templates]
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# This note was written shortly after the PyTorch 1.9 release.
+# If you notice it's out-of-date or think it could be improved then please
+# file an issue.
+#
+# PyTorch has its own framework for instantiating test templates. That is, for
+#   taking test classes that look similar to unittest or pytest
+#   compatible test classes and optionally doing the following:
+#
+#     - instantiating a version of the test class for each available device type
+#         (often the CPU, CUDA, and META device types)
+#     - further instantiating a version of each test that's always specialized
+#         on the test class's device type, and optionally specialized further
+#         on datatypes or operators
+#
+# This functionality is similar to pytest's parametrize functionality
+#   (see https://docs.pytest.org/en/6.2.x/parametrize.html), but with considerable
+#   additional logic that specializes the instantiated test classes for their
+#   device types (see CPUTestBase and CUDATestBase below), supports a variety
+#   of composable decorators that allow for test filtering and setting
+#   tolerances, and allows tests parametrized by operators to instantiate
+#   only the subset of device type x dtype that operator supports.
+#
+# This framework was built to make it easier to write tests that run on
+#   multiple device types, multiple datatypes (dtypes), and for multiple
+#   operators. It's also useful for controlling which tests are run. For example,
+#   only tests that use a CUDA device can be run on platforms with CUDA.
+#   Let's dive in with an example to get an idea for how it works:
+#
+# --------------------------------------------------------
+# A template class (looks like a regular unittest TestCase)
+# class TestClassFoo(TestCase):
+#
+#   # A template test that can be specialized with a device
+#   # NOTE: this test case is not runnable by unittest or pytest because it
+#   #   accepts an extra positional argument, "device", that they do not understand
+#   def test_bar(self, device):
+#     pass
+#
+# # Function that instantiates a template class and its tests
+# instantiate_device_type_tests(TestClassFoo, globals())
+# --------------------------------------------------------
+#
+# In the above code example we see a template class and a single test template
+#   that can be instantiated with a device. The function
+#   instantiate_device_type_tests(), called at file scope, instantiates
+#   new test classes, one per available device type, and new tests in those
+#   classes from these templates. It actually does this by removing
+#   the class TestClassFoo and replacing it with classes like TestClassFooCPU
+#   and TestClassFooCUDA, instantiated test classes that inherit from CPUTestBase
+#   and CUDATestBase respectively. Additional device types, like XLA,
+#   (see https://github.com/pytorch/xla) can further extend the set of
+#   instantiated test classes to create classes like TestClassFooXLA.
+#
+# The test template, test_bar(), is also instantiated. In this case the template
+#   is only specialized on a device, so (depending on the available device
+#   types) it might become test_bar_cpu() in TestClassFooCPU and test_bar_cuda()
+#   in TestClassFooCUDA. We can think of the instantiated test classes as
+#   looking like this:
+#
+# --------------------------------------------------------
+# # An instantiated test class for the CPU device type
+# class TestClassFooCPU(CPUTestBase):
+#
+#   # An instantiated test that calls the template with the string representation
+#   #   of a device from the test class's device type
+#   def test_bar_cpu(self):
+#     test_bar(self, 'cpu')
+#
+# # An instantiated test class for the CUDA device type
+# class TestClassFooCUDA(CUDATestBase):
+#
+#   # An instantiated test that calls the template with the string representation
+#   #   of a device from the test class's device type
+#   def test_bar_cuda(self):
+#     test_bar(self, 'cuda:0')
+# --------------------------------------------------------
+#
+# These instantiated test classes ARE discoverable and runnable by both
+#   unittest and pytest. One thing that may be confusing, however, is that
+#   attempting to run "test_bar" will not work, despite it appearing in the
+#   original template code. This is because "test_bar" is no longer discoverable
+#   after instantiate_device_type_tests() runs, as the above snippet shows.
+#   Instead "test_bar_cpu" and "test_bar_cuda" may be run directly, or both
+#   can be run with the option "-k test_bar".
+#
+# Removing the template class and adding the instantiated classes requires
+#   passing "globals()" to instantiate_device_type_tests(), because it
+#   edits the file's Python objects.
+#
+# As mentioned, tests can be additionally parametrized on dtypes or
+#   operators. Datatype parametrization uses the @dtypes decorator and
+#   requires a test template like this:
+#
+# --------------------------------------------------------
+# # A template test that can be specialized with a device and a datatype (dtype)
+# @dtypes(torch.float32, torch.int64)
+# def test_car(self, device, dtype):
+#   pass
+# --------------------------------------------------------
+#
+# If the CPU and CUDA device types are available this test would be
+#   instantiated as 4 tests that cover the cross-product of the two dtypes
+#   and two device types:
+#
+#     - test_car_cpu_float32
+#     - test_car_cpu_int64
+#     - test_car_cuda_float32
+#     - test_car_cuda_int64
+#
+# The dtype is passed as a torch.dtype object.
+#
+# Tests parametrized on operators (actually on OpInfos, more on that in a
+#   moment...) use the @ops decorator and require a test template like this:
+# --------------------------------------------------------
+# # A template test that can be specialized with a device, dtype, and OpInfo
+# @ops(op_db)
+# def test_car(self, device, dtype, op):
+#   pass
+# --------------------------------------------------------
+#
+# See the documentation for the @ops decorator below for additional details
+#   on how to use it and see the note [OpInfos] in
+#   common_methods_invocations.py for more details on OpInfos.
+#
+# A test parametrized over the entire "op_db", which contains hundreds of
+#   OpInfos, will likely have hundreds or thousands of instantiations. The
+#   test will be instantiated on the cross-product of device types, operators,
+#   and the dtypes the operator supports on that device type. The instantiated
+#   tests will have names like:
+#
+#     - test_car_add_cpu_float32
+#     - test_car_sub_cuda_int64
+#
+# The first instantiated test calls the original test_car() with the OpInfo
+#   for torch.add as its "op" argument, the string 'cpu' for its "device" argument,
+#   and the dtype torch.float32 for its "dtype" argument. The second instantiated
+#   test calls test_car() with the OpInfo for torch.sub, a CUDA device string
+#   like 'cuda:0' or 'cuda:1' for its "device" argument, and the dtype
+#   torch.int64 for its "dtype" argument.
+#
+# In addition to parametrizing over device, dtype, and ops via OpInfos, the
+#   @parametrize decorator is supported for arbitrary parametrizations:
+# --------------------------------------------------------
+# # A template test that can be specialized with a device, dtype, and value for x
+# @parametrize("x", range(5))
+# def test_car(self, device, dtype, x):
+#   pass
+# --------------------------------------------------------
+#
+# See the documentation for @parametrize in common_utils.py for additional details
+#   on this. Note that the instantiate_device_type_tests() function will handle
+#   such parametrizations; there is no need to additionally call
+#   instantiate_parametrized_tests().
+#
+# Clever test filtering can be very useful when working with parametrized
+#   tests. "-k test_car" would run every instantiated variant of the test_car()
+#   test template, and "-k test_car_add" runs every variant instantiated with
+#   torch.add.
+#
+# It is important to use the passed device and dtype as appropriate. Use
+#   helper functions like make_tensor() that require explicitly specifying
+#   the device and dtype so they're not forgotten.
+#
+# Test templates can use a variety of composable decorators to specify
+#   additional options and requirements, some are listed here:
+#
+#     - @deviceCountAtLeast()
+#         Passes a list of strings representing all available devices of
+#         the test class's device type as the test template's "device" argument.
+#         If there are fewer devices than the value passed to the decorator
+#         the test is skipped.
+#     - @dtypes()
+#         In addition to accepting multiple dtypes, the @dtypes decorator
+#         can accept a sequence of tuple pairs of dtypes. The test template
+#         will be called with each tuple for its "dtype" argument.
+#     - @onlyNativeDeviceTypes
+#         Skips the test if the device is not a native device type (currently CPU, CUDA, Meta, and PRIVATEUSE1)
+#     - @onlyCPU
+#         Skips the test if the device is not a CPU device
+#     - @onlyCUDA
+#         Skips the test if the device is not a CUDA device
+#     - @onlyMPS
+#         Skips the test if the device is not an MPS device
+#     - @skipCPUIfNoLapack
+#         Skips the test if the device is a CPU device and LAPACK is not installed
+#     - @skipCPUIfNoMkl
+#         Skips the test if the device is a CPU device and MKL is not installed
+#     - @skipCUDAIfNoMagma
+#         Skips the test if the device is a CUDA device and MAGMA is not installed
+#     - @skipCUDAIfRocm
+#         Skips the test if the device is a CUDA device and ROCm is being used
+
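+# For example, a hypothetical template combining several of the decorators above
+# (test names are illustrative only):
+#
+# --------------------------------------------------------
+# @onlyCUDA
+# @skipCUDAIfNoMagma
+# @dtypes(torch.float32, torch.float64)
+# def test_baz(self, device, dtype):
+#   pass
+#
+# @deviceCountAtLeast(2)
+# def test_multi_device_baz(self, devices):
+#   pass
+# --------------------------------------------------------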
+
+# Note [Adding a Device Type]
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# To add a device type:
+#
+#   (1) Create a new "TestBase" extending DeviceTypeTestBase.
+#       See CPUTestBase and CUDATestBase below.
+#   (2) Define the "device_type" attribute of the base to be the
+#       appropriate string.
+#   (3) Add logic to this file that appends your base class to
+#       device_type_test_bases when your device type is available.
+#   (4) (Optional) Write setUpClass/tearDownClass class methods that
+#       instantiate dependencies (see MAGMA in CUDATestBase).
+#   (5) (Optional) Override the "instantiate_test" method for total
+#       control over how your class creates tests.
+#
+# setUpClass is called AFTER tests have been created and BEFORE and ONLY IF
+# they are run. This makes it useful for initializing devices and dependencies.
+
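+# A minimal sketch of steps (1) and (2) above for a hypothetical "foo" backend
+# (illustrative only; a real base would also be appended to device_type_test_bases
+# per step (3)):
+#
+# class FooTestBase(DeviceTypeTestBase):
+#     device_type = "foo"
+#
+#     @classmethod
+#     def setUpClass(cls):
+#         cls.primary_device = "foo:0"
+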
+
+def _dtype_test_suffix(dtypes):
+    """Returns the test suffix for a dtype, sequence of dtypes, or None."""
+    if isinstance(dtypes, (list, tuple)):
+        if len(dtypes) == 0:
+            return ""
+        return "_" + "_".join(dtype_name(d) for d in dtypes)
+    elif dtypes:
+        return f"_{dtype_name(dtypes)}"
+    else:
+        return ""
+
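+# For illustration (assuming dtype_name() yields short names like "float32"):
+#
+#   _dtype_test_suffix(torch.float32)                 -> "_float32"
+#   _dtype_test_suffix((torch.int64, torch.float16))  -> "_int64_float16"
+#   _dtype_test_suffix(None)                          -> ""
+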
+
+def _update_param_kwargs(param_kwargs, name, value):
+    """Adds a kwarg with the specified name and value to the param_kwargs dict."""
+    # Make name plural (e.g. devices / dtypes) if the value is composite.
+    plural_name = f"{name}s"
+
+    # Clear out old entries of the arg if any.
+    if name in param_kwargs:
+        del param_kwargs[name]
+    if plural_name in param_kwargs:
+        del param_kwargs[plural_name]
+
+    if isinstance(value, (list, tuple)):
+        param_kwargs[plural_name] = value
+    elif value is not None:
+        param_kwargs[name] = value
+
+    # Leave param_kwargs as-is when value is None.
+
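+# For illustration (hypothetical values); the dict is mutated in place and the key
+# is pluralized for composite values:
+#
+#   kwargs = {}
+#   _update_param_kwargs(kwargs, "dtype", torch.float32)
+#   # kwargs == {"dtype": torch.float32}
+#   _update_param_kwargs(kwargs, "dtype", [torch.half, torch.float])
+#   # kwargs == {"dtypes": [torch.half, torch.float]} (the singular entry is cleared)
+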
+
+class DeviceTypeTestBase(TestCase):
+    device_type: str = "generic_device_type"
+
+    # Flag to disable test suite early due to unrecoverable error such as CUDA error.
+    _stop_test_suite = False
+
+    # Precision is a thread-local setting since it may be overridden per test
+    _tls = threading.local()
+    _tls.precision = TestCase._precision
+    _tls.rel_tol = TestCase._rel_tol
+
+    @property
+    def precision(self):
+        return self._tls.precision
+
+    @precision.setter
+    def precision(self, prec):
+        self._tls.precision = prec
+
+    @property
+    def rel_tol(self):
+        return self._tls.rel_tol
+
+    @rel_tol.setter
+    def rel_tol(self, prec):
+        self._tls.rel_tol = prec
+
+    # Returns a string representing the device that single device tests should use.
+    # Note: single device tests use this device exclusively.
+    @classmethod
+    def get_primary_device(cls):
+        return cls.device_type
+
+    @classmethod
+    def _init_and_get_primary_device(cls):
+        try:
+            return cls.get_primary_device()
+        except Exception:
+            # For CUDATestBase, XPUTestBase, XLATestBase, and possibly others, the primary device won't be available
+            # until setUpClass() sets it. Call that manually here if needed.
+            if hasattr(cls, "setUpClass"):
+                cls.setUpClass()
+            return cls.get_primary_device()
+
+    # Returns a list of strings representing all available devices of this
+    # device type. The primary device must be the first string in the list
+    # and the list must contain no duplicates.
+    # Note: UNSTABLE API. Will be replaced once PyTorch has a device generic
+    #   mechanism of acquiring all available devices.
+    @classmethod
+    def get_all_devices(cls):
+        return [cls.get_primary_device()]
+
+    # Returns the dtypes the test has requested.
+    # Prefers device-specific dtype specifications over generic ones.
+    @classmethod
+    def _get_dtypes(cls, test):
+        if not hasattr(test, "dtypes"):
+            return None
+
+        default_dtypes = test.dtypes.get("all")
+        msg = f"@dtypes is mandatory when using @dtypesIf however '{test.__name__}' didn't specify it"
+        assert default_dtypes is not None, msg
+
+        return test.dtypes.get(cls.device_type, default_dtypes)
+
+    def _get_precision_override(self, test, dtype):
+        if not hasattr(test, "precision_overrides"):
+            return self.precision
+        return test.precision_overrides.get(dtype, self.precision)
+
+    def _get_tolerance_override(self, test, dtype):
+        if not hasattr(test, "tolerance_overrides"):
+            return self.precision, self.rel_tol
+        return test.tolerance_overrides.get(dtype, tol(self.precision, self.rel_tol))
+
+    def _apply_precision_override_for_test(self, test, param_kwargs):
+        dtype = param_kwargs["dtype"] if "dtype" in param_kwargs else None
+        dtype = param_kwargs["dtypes"] if "dtypes" in param_kwargs else dtype
+        if dtype:
+            self.precision = self._get_precision_override(test, dtype)
+            self.precision, self.rel_tol = self._get_tolerance_override(test, dtype)
+
+    # Creates device-specific tests.
+    @classmethod
+    def instantiate_test(cls, name, test, *, generic_cls=None):
+        def instantiate_test_helper(
+            cls, name, *, test, param_kwargs=None, decorator_fn=lambda _: []
+        ):
+            # Add the device param kwarg if the test needs device or devices.
+            param_kwargs = {} if param_kwargs is None else param_kwargs
+            test_sig_params = inspect.signature(test).parameters
+            if "device" in test_sig_params or "devices" in test_sig_params:
+                device_arg: str = cls._init_and_get_primary_device()
+                if hasattr(test, "num_required_devices"):
+                    device_arg = cls.get_all_devices()
+                _update_param_kwargs(param_kwargs, "device", device_arg)
+
+            # Apply decorators based on param kwargs.
+            for decorator in decorator_fn(param_kwargs):
+                test = decorator(test)
+
+            # Constructs the test
+            @wraps(test)
+            def instantiated_test(self, param_kwargs=param_kwargs):
+                # Sets precision and runs test
+                # Note: precision is reset after the test is run
+                guard_precision = self.precision
+                guard_rel_tol = self.rel_tol
+                try:
+                    self._apply_precision_override_for_test(test, param_kwargs)
+                    result = test(self, **param_kwargs)
+                except RuntimeError as rte:
+                    # check if rte should stop entire test suite.
+                    self._stop_test_suite = self._should_stop_test_suite()
+                    # Check if test has been decorated with `@expectedFailure`
+                    # Using `__unittest_expecting_failure__` attribute, see
+                    # https://github.com/python/cpython/blob/ffa505b580464/Lib/unittest/case.py#L164
+                    # In that case, make it fail with "unexpected success" by suppressing exception
+                    if (
+                        getattr(test, "__unittest_expecting_failure__", False)
+                        and self._stop_test_suite
+                    ):
+                        import sys
+
+                        print(
+                            "Suppressing fatal exception to trigger unexpected success",
+                            file=sys.stderr,
+                        )
+                        return
+                    # raise the runtime error as is for the test suite to record.
+                    raise rte
+                finally:
+                    self.precision = guard_precision
+                    self.rel_tol = guard_rel_tol
+
+                return result
+
+            assert not hasattr(cls, name), f"Redefinition of test {name}"
+            setattr(cls, name, instantiated_test)
+
+        def default_parametrize_fn(test, generic_cls, device_cls):
+            # By default, no parametrization is needed.
+            yield (test, "", {}, lambda _: [])
+
+        # Parametrization decorators set the parametrize_fn attribute on the test.
+        parametrize_fn = getattr(test, "parametrize_fn", default_parametrize_fn)
+
+        # If one of the @dtypes* decorators is present, also parametrize over the dtypes set by it.
+        dtypes = cls._get_dtypes(test)
+        if dtypes is not None:
+
+            def dtype_parametrize_fn(test, generic_cls, device_cls, dtypes=dtypes):
+                for dtype in dtypes:
+                    param_kwargs: dict[str, Any] = {}
+                    _update_param_kwargs(param_kwargs, "dtype", dtype)
+
+                    # Note that an empty test suffix is set here so that the dtype can be appended
+                    # later after the device.
+                    yield (test, "", param_kwargs, lambda _: [])
+
+            parametrize_fn = compose_parametrize_fns(
+                dtype_parametrize_fn, parametrize_fn
+            )
+
+        # Instantiate the parametrized tests.
+        for (
+            test,  # noqa: B020
+            test_suffix,
+            param_kwargs,
+            decorator_fn,
+        ) in parametrize_fn(test, generic_cls, cls):
+            test_suffix = "" if test_suffix == "" else "_" + test_suffix
+            cls_device_type = (
+                cls.device_type
+                if cls.device_type != "privateuse1"
+                else torch._C._get_privateuse1_backend_name()
+            )
+            device_suffix = "_" + cls_device_type
+
+            # Note: device and dtype suffix placement
+            # Special handling here to place dtype(s) after device according to test name convention.
+            dtype_kwarg = None
+            if "dtype" in param_kwargs or "dtypes" in param_kwargs:
+                dtype_kwarg = (
+                    param_kwargs["dtypes"]
+                    if "dtypes" in param_kwargs
+                    else param_kwargs["dtype"]
+                )
+            test_name = (
+                f"{name}{test_suffix}{device_suffix}{_dtype_test_suffix(dtype_kwarg)}"
+            )
+
+            instantiate_test_helper(
+                cls=cls,
+                name=test_name,
+                test=test,
+                param_kwargs=param_kwargs,
+                decorator_fn=decorator_fn,
+            )
+
+    def run(self, result=None):
+        super().run(result=result)
+        # Early terminate test if _stop_test_suite is set.
+        if self._stop_test_suite:
+            result.stop()
+
+
+class CPUTestBase(DeviceTypeTestBase):
+    device_type = "cpu"
+
+    # No critical error should stop CPU test suite
+    def _should_stop_test_suite(self):
+        return False
+
+
+class CUDATestBase(DeviceTypeTestBase):
+    device_type = "cuda"
+    _do_cuda_memory_leak_check = True
+    _do_cuda_non_default_stream = True
+    primary_device: ClassVar[str]
+    cudnn_version: ClassVar[Any]
+    no_magma: ClassVar[bool]
+    no_cudnn: ClassVar[bool]
+
+    def has_cudnn(self):
+        return not self.no_cudnn
+
+    @classmethod
+    def get_primary_device(cls):
+        return cls.primary_device
+
+    @classmethod
+    def get_all_devices(cls):
+        primary_device_idx = int(cls.get_primary_device().split(":")[1])
+        num_devices = torch.cuda.device_count()
+
+        prim_device = cls.get_primary_device()
+        cuda_str = "cuda:{0}"
+        non_primary_devices = [
+            cuda_str.format(idx)
+            for idx in range(num_devices)
+            if idx != primary_device_idx
+        ]
+        return [prim_device] + non_primary_devices
+
+    @classmethod
+    def setUpClass(cls):
+        # has_magma shows up after cuda is initialized
+        t = torch.ones(1).cuda()
+        cls.no_magma = not torch.cuda.has_magma
+
+        # Determines if cuDNN is available and its version
+        cls.no_cudnn = not torch.backends.cudnn.is_acceptable(t)
+        cls.cudnn_version = None if cls.no_cudnn else torch.backends.cudnn.version()
+
+        # Acquires the current device as the primary (test) device
+        cls.primary_device = f"cuda:{torch.cuda.current_device()}"
+
+
+# See Note [Lazy Tensor tests in device agnostic testing]
+lazy_ts_backend_init = False
+
+
+class LazyTestBase(DeviceTypeTestBase):
+    device_type = "lazy"
+
+    def _should_stop_test_suite(self):
+        return False
+
+    @classmethod
+    def setUpClass(cls):
+        import torch._lazy
+        import torch._lazy.metrics
+        import torch._lazy.ts_backend
+
+        global lazy_ts_backend_init
+        if not lazy_ts_backend_init:
+            # Need to connect the TS backend to lazy key before running tests
+            torch._lazy.ts_backend.init()
+            lazy_ts_backend_init = True
+
+
+class MPSTestBase(DeviceTypeTestBase):
+    device_type = "mps"
+    primary_device: ClassVar[str]
+
+    @classmethod
+    def get_primary_device(cls):
+        return cls.primary_device
+
+    @classmethod
+    def get_all_devices(cls):
+        # currently only one device is supported on MPS backend
+        prim_device = cls.get_primary_device()
+        return [prim_device]
+
+    @classmethod
+    def setUpClass(cls):
+        cls.primary_device = "mps:0"
+
+    def _should_stop_test_suite(self):
+        return False
+
+
+class XPUTestBase(DeviceTypeTestBase):
+    device_type = "xpu"
+    primary_device: ClassVar[str]
+
+    @classmethod
+    def get_primary_device(cls):
+        return cls.primary_device
+
+    @classmethod
+    def get_all_devices(cls):
+        # currently only one device is exposed for the XPU backend here
+        prim_device = cls.get_primary_device()
+        return [prim_device]
+
+    @classmethod
+    def setUpClass(cls):
+        cls.primary_device = f"xpu:{torch.xpu.current_device()}"
+
+    def _should_stop_test_suite(self):
+        return False
+
+
+class HPUTestBase(DeviceTypeTestBase):
+    device_type = "hpu"
+    primary_device: ClassVar[str]
+
+    @classmethod
+    def get_primary_device(cls):
+        return cls.primary_device
+
+    @classmethod
+    def setUpClass(cls):
+        cls.primary_device = "hpu:0"
+
+
+class PrivateUse1TestBase(DeviceTypeTestBase):
+    primary_device: ClassVar[str]
+    device_mod = None
+    device_type = "privateuse1"
+
+    @classmethod
+    def get_primary_device(cls):
+        return cls.primary_device
+
+    @classmethod
+    def get_all_devices(cls):
+        primary_device_idx = int(cls.get_primary_device().split(":")[1])
+        num_devices = cls.device_mod.device_count()
+        prim_device = cls.get_primary_device()
+        device_str = f"{cls.device_type}:{{0}}"
+        non_primary_devices = [
+            device_str.format(idx)
+            for idx in range(num_devices)
+            if idx != primary_device_idx
+        ]
+        return [prim_device] + non_primary_devices
+
+    @classmethod
+    def setUpClass(cls):
+        cls.device_type = torch._C._get_privateuse1_backend_name()
+        cls.device_mod = getattr(torch, cls.device_type, None)
+        assert (
+            cls.device_mod is not None
+        ), f"""torch has no module of `{cls.device_type}`, you should register
+                                            a module by `torch._register_device_module`."""
+        cls.primary_device = f"{cls.device_type}:{cls.device_mod.current_device()}"
+
+
+# Adds available device-type-specific test base classes
+def get_device_type_test_bases():
+    # set type to List[Any] due to mypy list-of-union issue:
+    # https://github.com/python/mypy/issues/3351
+    test_bases: list[Any] = []
+
+    if IS_SANDCASTLE or IS_FBCODE:
+        if IS_REMOTE_GPU:
+            # Skip if sanitizer is enabled
+            if not TEST_WITH_ASAN and not TEST_WITH_TSAN and not TEST_WITH_UBSAN:
+                test_bases.append(CUDATestBase)
+        else:
+            test_bases.append(CPUTestBase)
+    else:
+        test_bases.append(CPUTestBase)
+        if torch.cuda.is_available():
+            test_bases.append(CUDATestBase)
+
+        if is_privateuse1_backend_available():
+            test_bases.append(PrivateUse1TestBase)
+        # Disable MPS testing in generic device testing temporarily while we're
+        # ramping up support.
+        # elif torch.backends.mps.is_available():
+        #   test_bases.append(MPSTestBase)
+
+    return test_bases
+
+
+device_type_test_bases = get_device_type_test_bases()
+
+
+def filter_desired_device_types(device_type_test_bases, except_for=None, only_for=None):
+    # device type cannot appear in both except_for and only_for
+    intersect = set(except_for if except_for else []) & set(
+        only_for if only_for else []
+    )
+    assert (
+        not intersect
+    ), f"device ({intersect}) appeared in both except_for and only_for"
+
+    # Map the registered privateuse1 backend name back to 'privateuse1' in the filter lists
+    if is_privateuse1_backend_available():
+        privateuse1_backend_name = torch._C._get_privateuse1_backend_name()
+
+        def func_replace(x: str):
+            return x.replace(privateuse1_backend_name, "privateuse1")
+
+        except_for = (
+            ([func_replace(x) for x in except_for] if except_for is not None else None)
+            if not isinstance(except_for, str)
+            else func_replace(except_for)
+        )
+        only_for = (
+            ([func_replace(x) for x in only_for] if only_for is not None else None)
+            if not isinstance(only_for, str)
+            else func_replace(only_for)
+        )
+
+    if except_for:
+        device_type_test_bases = filter(
+            lambda x: x.device_type not in except_for, device_type_test_bases
+        )
+    if only_for:
+        device_type_test_bases = filter(
+            lambda x: x.device_type in only_for, device_type_test_bases
+        )
+
+    return list(device_type_test_bases)
+
+
+# Note [How to extend DeviceTypeTestBase to add new test device]
+# The following logic optionally allows downstream projects like pytorch/xla to
+# add more test devices.
+# Instructions:
+#  - Add a python file (e.g. pytorch/xla/test/pytorch_test_base.py) in downstream project.
+#    - Inside the file, one should inherit from `DeviceTypeTestBase` class and define
+#      a new DeviceTypeTest class (e.g. `XLATestBase`) with proper implementation of
+#      `instantiate_test` method.
+#    - DO NOT import common_device_type inside the file.
+#      `runpy.run_path` with `globals()` already properly sets up the context so that
+#      `DeviceTypeTestBase` is already available.
+#    - Set a top-level variable `TEST_CLASS` equal to your new class.
+#      E.g. TEST_CLASS = XLATestBase
+#  - To run tests with the new device type, set the `TORCH_TEST_DEVICES` env variable to the path
+#    to this file. Multiple paths can be separated by `:`.
+# See pytorch/xla/test/pytorch_test_base.py for a more detailed example.
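+#
+# A minimal sketch of such a downstream file (class and device names are illustrative):
+#
+#   class XLATestBase(DeviceTypeTestBase):
+#       device_type = "xla"
+#
+#       @classmethod
+#       def instantiate_test(cls, name, test, *, generic_cls=None):
+#           ...  # device-specific instantiation logic
+#
+#   TEST_CLASS = XLATestBase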
+_TORCH_TEST_DEVICES = os.environ.get("TORCH_TEST_DEVICES", None)
+if _TORCH_TEST_DEVICES:
+    for path in _TORCH_TEST_DEVICES.split(":"):
+        # runpy (a stdlib module) lacks annotations
+        mod = runpy.run_path(path, init_globals=globals())  # type: ignore[func-returns-value]
+        device_type_test_bases.append(mod["TEST_CLASS"])
+
+
+PYTORCH_CUDA_MEMCHECK = os.getenv("PYTORCH_CUDA_MEMCHECK", "0") == "1"
+
+PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY = "PYTORCH_TESTING_DEVICE_ONLY_FOR"
+PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY = "PYTORCH_TESTING_DEVICE_EXCEPT_FOR"
+PYTORCH_TESTING_DEVICE_FOR_CUSTOM_KEY = "PYTORCH_TESTING_DEVICE_FOR_CUSTOM"
+
+
+def get_desired_device_type_test_bases(
+    except_for=None, only_for=None, include_lazy=False, allow_mps=False, allow_xpu=False
+):
+    # allow callers to specifically opt tests into being tested on MPS, similar to `include_lazy`
+    test_bases = device_type_test_bases.copy()
+    if allow_mps and TEST_MPS and MPSTestBase not in test_bases:
+        test_bases.append(MPSTestBase)
+    if allow_xpu and TEST_XPU and XPUTestBase not in test_bases:
+        test_bases.append(XPUTestBase)
+    if TEST_HPU and HPUTestBase not in test_bases:
+        test_bases.append(HPUTestBase)
+    # Filter out the device types based on user inputs
+    desired_device_type_test_bases = filter_desired_device_types(
+        test_bases, except_for, only_for
+    )
+    if include_lazy:
+        # Note [Lazy Tensor tests in device agnostic testing]
+        # Right now, test_view_ops.py runs with LazyTensor.
+        # We don't want to opt every device-agnostic test into using the lazy device,
+        # because many of them will fail.
+        # So instead, the only way to opt a specific device-agnostic test file into
+        # lazy tensor testing is with include_lazy=True
+        if IS_FBCODE:
+            print(
+                "TorchScript backend not yet supported in FBCODE/OVRSOURCE builds",
+                file=sys.stderr,
+            )
+        else:
+            desired_device_type_test_bases.append(LazyTestBase)
+
+    def split_if_not_empty(x: str):
+        return x.split(",") if x else []
+
+    # Run some CUDA test cases on other devices if available
+    # Usage:
+    # export PYTORCH_TESTING_DEVICE_FOR_CUSTOM=privateuse1
+    env_custom_only_for = split_if_not_empty(
+        os.getenv(PYTORCH_TESTING_DEVICE_FOR_CUSTOM_KEY, "")
+    )
+    if env_custom_only_for:
+        desired_device_type_test_bases += filter(
+            lambda x: x.device_type in env_custom_only_for, test_bases
+        )
+        desired_device_type_test_bases = list(set(desired_device_type_test_bases))
+
+    # Filter out the device types based on environment variables if available
+    # Usage:
+    # export PYTORCH_TESTING_DEVICE_ONLY_FOR=cuda,cpu
+    # export PYTORCH_TESTING_DEVICE_EXCEPT_FOR=xla
+    env_only_for = split_if_not_empty(
+        os.getenv(PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY, "")
+    )
+    env_except_for = split_if_not_empty(
+        os.getenv(PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY, "")
+    )
+
+    return filter_desired_device_types(
+        desired_device_type_test_bases, env_except_for, env_only_for
+    )
+
+
+# Adds 'instantiated' device-specific test cases to the given scope.
+# The tests in these test cases are derived from the generic tests in
+# generic_test_class. This function should be used instead of
+# instantiate_parametrized_tests() if the test class contains
+# device-specific tests (NB: this supports additional @parametrize usage).
+#
+# See note "Writing Test Templates"
+# TODO: remove "allow_xpu" option after Interl GPU support all test case instantiate by this function.
+def instantiate_device_type_tests(
+    generic_test_class,
+    scope,
+    except_for=None,
+    only_for=None,
+    include_lazy=False,
+    allow_mps=False,
+    allow_xpu=False,
+):
+    # Removes the generic test class from its enclosing scope so its tests
+    # are not discoverable.
+    del scope[generic_test_class.__name__]
+
+    generic_members = set(generic_test_class.__dict__.keys())
+    generic_tests = [x for x in generic_members if x.startswith("test")]
+
+    # Creates device-specific test cases
+    for base in get_desired_device_type_test_bases(
+        except_for, only_for, include_lazy, allow_mps, allow_xpu
+    ):
+        class_name = generic_test_class.__name__ + base.device_type.upper()
+
+        # type set to Any and suppressed due to unsupported runtime class:
+        # https://github.com/python/mypy/wiki/Unsupported-Python-Features
+        device_type_test_class: Any = type(class_name, (base, generic_test_class), {})
+
+        # Arrange for setUpClass and tearDownClass methods defined both in the test template
+        # class and in the generic base to be called. This allows device-parameterized test
+        # classes to support setup and teardown.
+        # NB: This should be done before instantiate_test() is called as that invokes setup.
+        @classmethod
+        def _setUpClass(cls):
+            # This should always be called, whether or not the test class invokes
+            # super().setUpClass(), to set the primary device.
+            base.setUpClass()
+            # We want to call the @classmethod defined in the generic base, but pass
+            # it the device-specific class object (cls), hence the __func__ call.
+            generic_test_class.setUpClass.__func__(cls)
+
+        @classmethod
+        def _tearDownClass(cls):
+            # We want to call the @classmethod defined in the generic base, but pass
+            # it the device-specific class object (cls), hence the __func__ call.
+            generic_test_class.tearDownClass.__func__(cls)
+            base.tearDownClass()
+
+        device_type_test_class.setUpClass = _setUpClass
+        device_type_test_class.tearDownClass = _tearDownClass
+
+        for name in generic_members:
+            if name in generic_tests:  # Instantiates test member
+                test = getattr(generic_test_class, name)
+                # XLA-compat shim (XLA's instantiate_test doesn't take generic_cls)
+                sig = inspect.signature(device_type_test_class.instantiate_test)
+                if len(sig.parameters) == 3:
+                    # Instantiates the device-specific tests
+                    device_type_test_class.instantiate_test(
+                        name, copy.deepcopy(test), generic_cls=generic_test_class
+                    )
+                else:
+                    device_type_test_class.instantiate_test(name, copy.deepcopy(test))
+            # Ports non-test member. Setup / teardown have already been handled above
+            elif name not in device_type_test_class.__dict__:
+                nontest = getattr(generic_test_class, name)
+                setattr(device_type_test_class, name, nontest)
+
+        # Mimics defining the instantiated class in the caller's file
+        # by setting its module to the given class's and adding
+        # the module to the given scope.
+        # This lets the instantiated class be discovered by unittest.
+        device_type_test_class.__module__ = generic_test_class.__module__
+        scope[class_name] = device_type_test_class
+
+    # Delete the generic form of the test functions (e.g. TestFoo.test_bar()) so they're
+    # not discoverable. This mutates the original class (TestFoo), which was removed from
+    # scope above. At this point, device-specific tests (e.g. TestFooCUDA.test_bar_cuda)
+    # have already been created and the generic forms are no longer needed.
+    for name in generic_tests:
+        delattr(generic_test_class, name)
+
+
+# Category of dtypes to run an OpInfo-based test for
+# Example use: @ops(dtype=OpDTypes.supported)
+#
+# There are 7 categories:
+# - supported: Every dtype supported by the operator. Use for exhaustive
+#              testing of all dtypes.
+# - unsupported: Run tests on dtypes not supported by the operator. e.g. for
+#                testing the operator raises an error and doesn't crash.
+# - supported_backward: Every dtype supported by the operator's backward pass.
+# - unsupported_backward: Run tests on dtypes not supported by the operator's backward pass.
+# - any_one: Runs a test for one dtype the operator supports. Prioritizes dtypes the
+#     operator supports in both forward and backward.
+# - none: Useful for tests that are not dtype-specific. No dtype will be passed to the test
+#         when this is selected.
+# - any_common_cpu_cuda_one: Runs a test for one dtype that is supported on both CPU and CUDA.
+class OpDTypes(Enum):
+    supported = 0  # Test all supported dtypes (default)
+    unsupported = 1  # Test only unsupported dtypes
+    supported_backward = 2  # Test all supported backward dtypes
+    unsupported_backward = 3  # Test only unsupported backward dtypes
+    any_one = 4  # Test precisely one supported dtype
+    none = 5  # Instantiate no dtype variants (no dtype kwarg needed)
+    any_common_cpu_cuda_one = (
+        6  # Test precisely one supported dtype that is common to both cuda and cpu
+    )
+
+
+# Arbitrary order
+ANY_DTYPE_ORDER = (
+    torch.float32,
+    torch.float64,
+    torch.complex64,
+    torch.complex128,
+    torch.float16,
+    torch.bfloat16,
+    torch.long,
+    torch.int32,
+    torch.int16,
+    torch.int8,
+    torch.uint8,
+    torch.bool,
+    torch.float8_e4m3fn,
+    torch.float8_e5m2,
+)
+
+
+def _serialize_sample(sample_input):
+    # NB: For OpInfos, SampleInput.summary() prints in a cleaner way.
+    if getattr(sample_input, "summary", None) is not None:
+        return sample_input.summary()
+    return str(sample_input)
+
+
+# Decorator that defines the OpInfos a test template should be instantiated for.
+#
+# Example usage:
+#
+# @ops(unary_ufuncs)
+# def test_numerics(self, device, dtype, op):
+#   <test_code>
+#
+# This will instantiate variants of test_numerics for each given OpInfo,
+# on each device the OpInfo's operator supports, and for every dtype supported by
+# that operator. There are a few caveats to the dtype rule, explained below.
+#
+# The @ops decorator can accept two
+# additional arguments, "dtypes" and "allowed_dtypes". If "dtypes" is specified
+# then the test variants are instantiated for those dtypes, regardless of
+# what the operator supports. If given "allowed_dtypes" then test variants
+# are instantiated only for the intersection of allowed_dtypes and the dtypes
+# they would otherwise be instantiated with. That is, allowed_dtypes composes
+# with the options listed above and below.
+#
+# The "dtypes" argument can also accept additional values (see OpDTypes above):
+#   OpDTypes.supported - the test is instantiated for all dtypes the operator
+#     supports
+#   OpDTypes.unsupported - the test is instantiated for all dtypes the operator
+#     doesn't support
+#   OpDTypes.supported_backward - the test is instantiated for all dtypes the
+#     operator's gradient formula supports
+#   OpDTypes.unsupported_backward - the test is instantiated for all dtypes the
+#     operator's gradient formula doesn't support
+#   OpDTypes.any_one - the test is instantiated for one dtype the
+#     operator supports. The dtype supports forward and backward if possible.
+#   OpDTypes.none - the test is instantiated without any dtype. The test signature
+#     should not include a dtype kwarg in this case.
+#   OpDTypes.any_common_cpu_cuda_one - the test is instantiated for one dtype
+#     that is supported on both CPU and CUDA.
+#
+# These options allow tests to have considerable control over the dtypes
+#   they're instantiated for.
+
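+# For example, hypothetical templates restricting or disabling the dtype variants
+# (test names are illustrative only):
+#
+# @ops(unary_ufuncs, allowed_dtypes=(torch.float32, torch.float64))
+# def test_float_numerics(self, device, dtype, op):
+#   pass
+#
+# @ops(op_db, dtypes=OpDTypes.none)
+# def test_metadata(self, device, op):
+#   pass
+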
+
+class ops(_TestParametrizer):
+    def __init__(
+        self,
+        op_list,
+        *,
+        dtypes: Union[OpDTypes, Sequence[torch.dtype]] = OpDTypes.supported,
+        allowed_dtypes: Optional[Sequence[torch.dtype]] = None,
+        skip_if_dynamo=True,
+    ):
+        self.op_list = list(op_list)
+        self.opinfo_dtypes = dtypes
+        self.allowed_dtypes = (
+            set(allowed_dtypes) if allowed_dtypes is not None else None
+        )
+        self.skip_if_dynamo = skip_if_dynamo
+
+    def _parametrize_test(self, test, generic_cls, device_cls):
+        """Parameterizes the given test function across each op and its associated dtypes."""
+        if device_cls is None:
+            raise RuntimeError(
+                "The @ops decorator is only intended to be used in a device-specific "
+                "context; use it with instantiate_device_type_tests() instead of "
+                "instantiate_parametrized_tests()"
+            )
+
+        op = check_exhausted_iterator = object()
+        for op in self.op_list:
+            # Determine the set of dtypes to use.
+            dtypes: Union[set[torch.dtype], set[None]]
+            if isinstance(self.opinfo_dtypes, Sequence):
+                dtypes = set(self.opinfo_dtypes)
+            elif self.opinfo_dtypes == OpDTypes.unsupported_backward:
+                dtypes = set(get_all_dtypes()).difference(
+                    op.supported_backward_dtypes(device_cls.device_type)
+                )
+            elif self.opinfo_dtypes == OpDTypes.supported_backward:
+                dtypes = op.supported_backward_dtypes(device_cls.device_type)
+            elif self.opinfo_dtypes == OpDTypes.unsupported:
+                dtypes = set(get_all_dtypes()).difference(
+                    op.supported_dtypes(device_cls.device_type)
+                )
+            elif self.opinfo_dtypes == OpDTypes.supported:
+                dtypes = set(op.supported_dtypes(device_cls.device_type))
+            elif self.opinfo_dtypes == OpDTypes.any_one:
+                # Tries to pick a dtype that supports both forward and backward
+                supported = op.supported_dtypes(device_cls.device_type)
+                supported_backward = op.supported_backward_dtypes(
+                    device_cls.device_type
+                )
+                supported_both = supported.intersection(supported_backward)
+                dtype_set = supported_both if len(supported_both) > 0 else supported
+                for dtype in ANY_DTYPE_ORDER:
+                    if dtype in dtype_set:
+                        dtypes = {dtype}
+                        break
+                else:
+                    dtypes = set()
+            elif self.opinfo_dtypes == OpDTypes.any_common_cpu_cuda_one:
+                # Tries to pick a dtype that supports both CPU and CUDA
+                supported = set(op.dtypes).intersection(op.dtypesIfCUDA)
+                if supported:
+                    dtypes = {
+                        next(dtype for dtype in ANY_DTYPE_ORDER if dtype in supported)
+                    }
+                else:
+                    dtypes = set()
+
+            elif self.opinfo_dtypes == OpDTypes.none:
+                dtypes = {None}
+            else:
+                raise RuntimeError(f"Unknown OpDType: {self.opinfo_dtypes}")
+
+            if self.allowed_dtypes is not None:
+                dtypes = dtypes.intersection(self.allowed_dtypes)
+
+            # Construct the test name; device / dtype parts are handled outside.
+            # See [Note: device and dtype suffix placement]
+            test_name = op.formatted_name
+
+            # Filter sample skips / xfails to only those that apply to the OpInfo.
+            # These are defined on the test function via decorators.
+            sample_skips_and_xfails = getattr(test, "sample_skips_and_xfails", None)
+            if sample_skips_and_xfails is not None:
+                sample_skips_and_xfails = [
+                    rule
+                    for rule in sample_skips_and_xfails
+                    if rule.op_match_fn(device_cls.device_type, op)
+                ]
+
+            for dtype in dtypes:
+                # Construct parameter kwargs to pass to the test.
+                param_kwargs = {"op": op}
+                _update_param_kwargs(param_kwargs, "dtype", dtype)
+
+                # NOTE: test_wrapper exists because we don't want to apply
+                #   op-specific decorators to the original test.
+                #   Test-specific decorators are applied to the original test,
+                #   however.
+                try:
+
+                    @wraps(test)
+                    def test_wrapper(*args, **kwargs):
+                        try:
+                            return test(*args, **kwargs)
+                        except unittest.SkipTest as e:
+                            raise e
+                        except Exception as e:
+                            tracked_input = get_tracked_input()
+                            if PRINT_REPRO_ON_FAILURE and tracked_input is not None:
+                                e_tracked = Exception(  # noqa: TRY002
+                                    f"Caused by {tracked_input.type_desc} "
+                                    f"at index {tracked_input.index}: "
+                                    f"{_serialize_sample(tracked_input.val)}"
+                                )
+                                e_tracked._tracked_input = tracked_input  # type: ignore[attr]
+                                raise e_tracked from e
+                            raise e
+                        finally:
+                            clear_tracked_input()
+
+                    if self.skip_if_dynamo and not TEST_WITH_TORCHINDUCTOR:
+                        test_wrapper = skipIfTorchDynamo(
+                            "Policy: we don't run OpInfo tests w/ Dynamo"
+                        )(test_wrapper)
+
+                    # Initialize info for the last input seen. This is useful for tracking
+                    # down which inputs caused a test failure. Note that TrackedInputIter is
+                    # responsible for managing this.
+                    test.tracked_input = None
+
+                    decorator_fn = partial(
+                        op.get_decorators,
+                        generic_cls.__name__,
+                        test.__name__,
+                        device_cls.device_type,
+                        dtype,
+                    )
+
+                    if sample_skips_and_xfails is not None:
+                        test_wrapper.sample_skips_and_xfails = sample_skips_and_xfails
+
+                    yield (test_wrapper, test_name, param_kwargs, decorator_fn)
+                except Exception as ex:
+                    # Provides an error message for debugging before rethrowing the exception
+                    print(f"Failed to instantiate {test_name} for op {op.name}!")
+                    raise ex
+        if op is check_exhausted_iterator:
+            raise ValueError(
+                "An empty op_list was passed to @ops. "
+                "Note that this may result from reuse of a generator."
+            )
+
+
+# Decorator that skips a test if the given condition is true.
+# Notes:
+#   (1) Skip conditions stack.
+#   (2) Skip conditions can be bools or strings. If a string, the
+#       test base must have defined the corresponding attribute to be False
+#       for the test to run. If you want to use a string argument you should
+#       probably define a new decorator instead (see below).
+#   (3) Prefer the existing decorators to defining the 'device_type' kwarg.
+class skipIf:
+    def __init__(self, dep, reason, device_type=None):
+        self.dep = dep
+        self.reason = reason
+        self.device_type = device_type
+
+    def __call__(self, fn):
+        @wraps(fn)
+        def dep_fn(slf, *args, **kwargs):
+            if (
+                self.device_type is None
+                or self.device_type == slf.device_type
+                or (
+                    isinstance(self.device_type, Iterable)
+                    and slf.device_type in self.device_type
+                )
+            ):
+                if (isinstance(self.dep, str) and getattr(slf, self.dep, True)) or (
+                    isinstance(self.dep, bool) and self.dep
+                ):
+                    raise unittest.SkipTest(self.reason)
+
+            return fn(slf, *args, **kwargs)
+
+        return dep_fn
+
+
+# Skips a test on CPU if the condition is true.
+class skipCPUIf(skipIf):
+    def __init__(self, dep, reason):
+        super().__init__(dep, reason, device_type="cpu")
+
+
+# Skips a test on CUDA if the condition is true.
+class skipCUDAIf(skipIf):
+    def __init__(self, dep, reason):
+        super().__init__(dep, reason, device_type="cuda")
+
+
+# Skips a test on XPU if the condition is true.
+class skipXPUIf(skipIf):
+    def __init__(self, dep, reason):
+        super().__init__(dep, reason, device_type="xpu")
+
+
+# Skips a test on XPU or CUDA if the condition is true.
+class skipGPUIf(skipIf):
+    def __init__(self, dep, reason):
+        super().__init__(dep, reason, device_type=GPU_TYPES)
+
+
+# Skips a test on Lazy if the condition is true.
+class skipLazyIf(skipIf):
+    def __init__(self, dep, reason):
+        super().__init__(dep, reason, device_type="lazy")
+
+
+# Skips a test on Meta if the condition is true.
+class skipMetaIf(skipIf):
+    def __init__(self, dep, reason):
+        super().__init__(dep, reason, device_type="meta")
+
+
+# Skips a test on MPS if the condition is true.
+class skipMPSIf(skipIf):
+    def __init__(self, dep, reason):
+        super().__init__(dep, reason, device_type="mps")
+
+
+class skipHPUIf(skipIf):
+    def __init__(self, dep, reason):
+        super().__init__(dep, reason, device_type="hpu")
+
+
+# Skips a test on XLA if the condition is true.
+class skipXLAIf(skipIf):
+    def __init__(self, dep, reason):
+        super().__init__(dep, reason, device_type="xla")
+
+
+class skipPRIVATEUSE1If(skipIf):
+    def __init__(self, dep, reason):
+        device_type = torch._C._get_privateuse1_backend_name()
+        super().__init__(dep, reason, device_type=device_type)
+
+
+def _has_sufficient_memory(device, size):
+    if torch.device(device).type == "cuda":
+        if not torch.cuda.is_available():
+            return False
+        gc.collect()
+        torch.cuda.empty_cache()
+        # torch.cuda.mem_get_info, aka cudaMemGetInfo, returns a tuple of (free memory, total memory) of a GPU
+        if device == "cuda":
+            device = "cuda:0"
+        return (
+            torch.cuda.memory.mem_get_info(device)[0]
+            * torch.cuda.memory.get_per_process_memory_fraction(device)
+        ) >= size
+
+    if device == "xla":
+        raise unittest.SkipTest("TODO: Memory availability checks for XLA?")
+
+    if device == "xpu":
+        raise unittest.SkipTest("TODO: Memory availability checks for Intel GPU?")
+
+    if device != "cpu":
+        raise unittest.SkipTest("Unknown device type")
+
+    # CPU
+    if not HAS_PSUTIL:
+        raise unittest.SkipTest("Need psutil to determine if memory is sufficient")
+
+    # The sanitizers have significant memory overheads
+    if TEST_WITH_ASAN or TEST_WITH_TSAN or TEST_WITH_UBSAN:
+        effective_size = size * 10
+    else:
+        effective_size = size
+
+    if psutil.virtual_memory().available < effective_size:
+        gc.collect()
+    return psutil.virtual_memory().available >= effective_size
+
+
+def largeTensorTest(size, device=None, inductor=TEST_WITH_TORCHINDUCTOR):
+    """Skip test if the device has insufficient memory to run the test
+
+    size may be a number of bytes, a string of the form "N GB", or a callable
+
+    If the test is a device generic test, available memory on the primary device will be checked.
+    It can also be overridden by the optional `device=` argument.
+    In other tests, the `device=` argument needs to be specified.
+    """
+    if isinstance(size, str):
+        assert size.endswith(("GB", "gb")), "only bytes or GB supported"
+        size = 1024**3 * int(size[:-2])
+
+    def inner(fn):
+        @wraps(fn)
+        def dep_fn(self, *args, **kwargs):
+            size_bytes: int = size(self, *args, **kwargs) if callable(size) else size
+            _device = device
+            if _device is None:
+                if hasattr(self, "get_primary_device"):
+                    _device = self.get_primary_device()
+                else:
+                    _device = self.device
+
+            # If this is running with GPU cpp_wrapper, the autotuning step will generate
+            # an additional array of the same size as the input.
+            if inductor and torch._inductor.config.cpp_wrapper and _device != "cpu":
+                size_bytes *= 2
+
+            if not _has_sufficient_memory(_device, size_bytes):
+                raise unittest.SkipTest(f"Insufficient {_device} memory")
+
+            return fn(self, *args, **kwargs)
+
+        return dep_fn
+
+    return inner
+
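+# For illustration, a hypothetical test guarded by available device memory:
+#
+# @largeTensorTest("12GB", device="cuda")
+# def test_huge_tensor(self, device):
+#     ...
+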
+
+class expectedFailure:
+    def __init__(self, device_type):
+        self.device_type = device_type
+
+    def __call__(self, fn):
+        @wraps(fn)
+        def efail_fn(slf, *args, **kwargs):
+            if (
+                not hasattr(slf, "device_type")
+                and hasattr(slf, "device")
+                and isinstance(slf.device, str)
+            ):
+                target_device_type = slf.device
+            else:
+                target_device_type = slf.device_type
+
+            if self.device_type is None or self.device_type == target_device_type:
+                try:
+                    fn(slf, *args, **kwargs)
+                except Exception:
+                    return
+                else:
+                    slf.fail("expected test to fail, but it passed")
+
+            return fn(slf, *args, **kwargs)
+
+        return efail_fn
+
+
+class onlyOn:
+    def __init__(self, device_type):
+        self.device_type = device_type
+
+    def __call__(self, fn):
+        @wraps(fn)
+        def only_fn(slf, *args, **kwargs):
+            if self.device_type != slf.device_type:
+                reason = f"Only runs on {self.device_type}"
+                raise unittest.SkipTest(reason)
+
+            return fn(slf, *args, **kwargs)
+
+        return only_fn
+
+
+# Decorator that provides all available devices of the device type to the test
+# as a list of strings instead of providing a single device string.
+# Skips the test if the number of available devices of the variant's device
+# type is less than the 'num_required_devices' arg.
+class deviceCountAtLeast:
+    def __init__(self, num_required_devices):
+        self.num_required_devices = num_required_devices
+
+    def __call__(self, fn):
+        assert not hasattr(
+            fn, "num_required_devices"
+        ), f"deviceCountAtLeast redefinition for {fn.__name__}"
+        fn.num_required_devices = self.num_required_devices
+
+        @wraps(fn)
+        def multi_fn(slf, devices, *args, **kwargs):
+            if len(devices) < self.num_required_devices:
+                reason = f"fewer than {self.num_required_devices} devices detected"
+                raise unittest.SkipTest(reason)
+
+            return fn(slf, devices, *args, **kwargs)
+
+        return multi_fn
+
+
+# Only runs the test on the native device type (currently CPU, CUDA, Meta and PRIVATEUSE1)
+def onlyNativeDeviceTypes(fn: Callable[_P, _T]) -> Callable[_P, _T]:
+    @wraps(fn)
+    def only_fn(self, *args: _P.args, **kwargs: _P.kwargs) -> _T:
+        if self.device_type not in NATIVE_DEVICES:
+            reason = f"onlyNativeDeviceTypes: doesn't run on {self.device_type}"
+            raise unittest.SkipTest(reason)
+
+        return fn(self, *args, **kwargs)
+
+    return only_fn
+
+
+# Only runs the test on the native device types and devices specified in the devices list
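+# Ex. (illustrative sketch; `test_foo` and the device list are hypothetical):
+#
+# @onlyNativeDeviceTypesAnd(["hpu"])
+# def test_foo(self, device):
+#     ...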
+def onlyNativeDeviceTypesAnd(devices=None):
+    def decorator(fn):
+        @wraps(fn)
+        def only_fn(self, *args, **kwargs):
+            if (
+                self.device_type not in NATIVE_DEVICES
+                and self.device_type not in (devices or [])
+            ):
+                reason = f"onlyNativeDeviceTypesAnd {devices} : doesn't run on {self.device_type}"
+                raise unittest.SkipTest(reason)
+
+            return fn(self, *args, **kwargs)
+
+        return only_fn
+
+    return decorator
+
+
+# Specifies per-dtype precision overrides.
+# Ex.
+#
+# @precisionOverride({torch.half : 1e-2, torch.float : 1e-4})
+# @dtypes(torch.half, torch.float, torch.double)
+# def test_X(self, device, dtype):
+#   ...
+#
+# When the test is instantiated its class's precision will be set to the
+# corresponding override, if it exists.
+# self.precision can be accessed directly, and it also controls the behavior of
+# functions like self.assertEqual().
+#
+# Note that self.precision is a scalar value, so if you require multiple
+# precisions (or are working with multiple dtypes) they should be specified
+# explicitly and computed using self.precision (e.g.
+# self.precision * 2, max(1, self.precision)).
+class precisionOverride:
+    def __init__(self, d):
+        assert isinstance(
+            d, dict
+        ), "precisionOverride not given a dtype : precision dict!"
+        for dtype in d.keys():
+            assert isinstance(
+                dtype, torch.dtype
+            ), f"precisionOverride given unknown dtype {dtype}"
+
+        self.d = d
+
+    def __call__(self, fn):
+        fn.precision_overrides = self.d
+        return fn
+
+
+# Specifies per-dtype tolerance overrides tol(atol, rtol). It has priority over
+# precisionOverride.
+# Ex.
+#
+# @toleranceOverride({torch.float : tol(atol=1e-2, rtol=1e-3),
+#                     torch.double : tol(atol=1e-4, rtol=0)})
+# @dtypes(torch.half, torch.float, torch.double)
+# def test_X(self, device, dtype):
+#   ...
+#
+# When the test is instantiated its class's tolerance will be set to the
+# corresponding override, if it exists.
+# self.rtol and self.precision can be accessed directly, and they also control
+# the behavior of functions like self.assertEqual().
+#
+# The above example sets atol = 1e-2 and rtol = 1e-3 for torch.float and
+# atol = 1e-4 and rtol = 0 for torch.double.
+tol = namedtuple("tol", ["atol", "rtol"])
+
+
+class toleranceOverride:
+    def __init__(self, d):
+        assert isinstance(d, dict), "toleranceOverride not given a dtype : tol dict!"
+        for dtype, prec in d.items():
+            assert isinstance(
+                dtype, torch.dtype
+            ), f"toleranceOverride given unknown dtype {dtype}"
+            assert isinstance(
+                prec, tol
+            ), "toleranceOverride not given a dtype : tol dict!"
+
+        self.d = d
+
+    def __call__(self, fn):
+        fn.tolerance_overrides = self.d
+        return fn
+
+
+# Decorator that instantiates a variant of the test for each given dtype.
+# Notes:
+#   (1) Tests that accept the dtype argument MUST use this decorator.
+#   (2) Can be overridden for CPU or CUDA, respectively, using dtypesIfCPU
+#       or dtypesIfCUDA.
+#   (3) Can accept an iterable of dtypes or an iterable of tuples
+#       of dtypes.
+# Examples:
+# @dtypes(torch.float32, torch.float64)
+# @dtypes((torch.long, torch.float32), (torch.int, torch.float64))
+class dtypes:
+    def __init__(self, *args, device_type="all"):
+        if len(args) > 0 and isinstance(args[0], (list, tuple)):
+            for arg in args:
+                assert isinstance(arg, (list, tuple)), (
+                    "When one dtype variant is a tuple or list, "
+                    "all dtype variants must be. "
+                    f"Received non-list non-tuple dtype {str(arg)}"
+                )
+                assert all(
+                    isinstance(dtype, torch.dtype) for dtype in arg
+                ), f"Unknown dtype in {str(arg)}"
+        else:
+            assert all(
+                isinstance(arg, torch.dtype) for arg in args
+            ), f"Unknown dtype in {str(args)}"
+
+        self.args = args
+        self.device_type = device_type
+
+    def __call__(self, fn):
+        d = getattr(fn, "dtypes", {})
+        assert self.device_type not in d, f"dtypes redefinition for {self.device_type}"
+        d[self.device_type] = self.args
+        fn.dtypes = d
+        return fn
+
+
+# Overrides specified dtypes on the CPU.
+class dtypesIfCPU(dtypes):
+    def __init__(self, *args):
+        super().__init__(*args, device_type="cpu")
+
+
+# Overrides specified dtypes on CUDA.
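+# Ex. (illustrative sketch; `test_op` is a hypothetical test): on CUDA the
+# dtypesIfCUDA list takes precedence over the base @dtypes list.
+#
+# @dtypesIfCUDA(torch.half, torch.float)
+# @dtypes(torch.float, torch.double)
+# def test_op(self, device, dtype):
+#     ...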
+class dtypesIfCUDA(dtypes):
+    def __init__(self, *args):
+        super().__init__(*args, device_type="cuda")
+
+
+class dtypesIfMPS(dtypes):
+    def __init__(self, *args):
+        super().__init__(*args, device_type="mps")
+
+
+class dtypesIfHPU(dtypes):
+    def __init__(self, *args):
+        super().__init__(*args, device_type="hpu")
+
+
+class dtypesIfPRIVATEUSE1(dtypes):
+    def __init__(self, *args):
+        super().__init__(*args, device_type=torch._C._get_privateuse1_backend_name())
+
+
+def onlyCPU(fn):
+    return onlyOn("cpu")(fn)
+
+
+def onlyCUDA(fn):
+    return onlyOn("cuda")(fn)
+
+
+def onlyMPS(fn):
+    return onlyOn("mps")(fn)
+
+
+def onlyXPU(fn):
+    return onlyOn("xpu")(fn)
+
+
+def onlyHPU(fn):
+    return onlyOn("hpu")(fn)
+
+
+def onlyPRIVATEUSE1(fn):
+    device_type = torch._C._get_privateuse1_backend_name()
+    device_mod = getattr(torch, device_type, None)
+    if device_mod is None:
+        reason = f"Skip as torch has no module of {device_type}"
+        return unittest.skip(reason)(fn)
+    return onlyOn(device_type)(fn)
+
+
+def onlyCUDAAndPRIVATEUSE1(fn):
+    @wraps(fn)
+    def only_fn(self, *args, **kwargs):
+        if self.device_type not in ("cuda", torch._C._get_privateuse1_backend_name()):
+            reason = f"onlyCUDAAndPRIVATEUSE1: doesn't run on {self.device_type}"
+            raise unittest.SkipTest(reason)
+
+        return fn(self, *args, **kwargs)
+
+    return only_fn
+
+
+def disablecuDNN(fn):
+    @wraps(fn)
+    def disable_cudnn(self, *args, **kwargs):
+        if self.device_type == "cuda" and self.has_cudnn():
+            with torch.backends.cudnn.flags(enabled=False):
+                return fn(self, *args, **kwargs)
+        return fn(self, *args, **kwargs)
+
+    return disable_cudnn
+
+
+def disableMkldnn(fn):
+    @wraps(fn)
+    def disable_mkldnn(self, *args, **kwargs):
+        if torch.backends.mkldnn.is_available():
+            with torch.backends.mkldnn.flags(enabled=False):
+                return fn(self, *args, **kwargs)
+        return fn(self, *args, **kwargs)
+
+    return disable_mkldnn
+
+
+def expectedFailureCPU(fn):
+    return expectedFailure("cpu")(fn)
+
+
+def expectedFailureCUDA(fn):
+    return expectedFailure("cuda")(fn)
+
+
+def expectedFailureXPU(fn):
+    return expectedFailure("xpu")(fn)
+
+
+def expectedFailureMeta(fn):
+    return skipIfTorchDynamo()(expectedFailure("meta")(fn))
+
+
+def expectedFailureXLA(fn):
+    return expectedFailure("xla")(fn)
+
+
+def expectedFailureHPU(fn):
+    return expectedFailure("hpu")(fn)
+
+
+def expectedFailureMPS(fn):
+    return expectedFailure("mps")(fn)
+
+
+def expectedFailureMPSPre15(fn):
+    import platform
+
+    version = float(".".join(platform.mac_ver()[0].split(".")[:2]) or -1)
+    if not version or version < 1.0:  # cpu or other unsupported device
+        return fn
+    if version < 15.0:
+        return expectedFailure("mps")(fn)
+    return fn
+
+
+def expectedFailureMPSPre14(fn):
+    import platform
+
+    version = float(".".join(platform.mac_ver()[0].split(".")[:2]) or -1)
+    if not version or version < 1.0:  # cpu or other unsupported device
+        return fn
+    if version < 14.0:
+        return expectedFailure("mps")(fn)
+    return fn
+
+
+# Skips a test on CPU if LAPACK is not available.
+def skipCPUIfNoLapack(fn):
+    return skipCPUIf(not torch._C.has_lapack, "PyTorch compiled without Lapack")(fn)
+
+
+# Skips a test on CPU if FFT is not available.
+def skipCPUIfNoFFT(fn):
+    return skipCPUIf(not torch._C.has_spectral, "PyTorch is built without FFT support")(
+        fn
+    )
+
+
+# Skips a test on CPU if MKL is not available.
+def skipCPUIfNoMkl(fn):
+    return skipCPUIf(not TEST_MKL, "PyTorch is built without MKL support")(fn)
+
+
+# Skips a test on CPU if MKL Sparse is not available (it's not linked on Windows).
+def skipCPUIfNoMklSparse(fn):
+    return skipCPUIf(
+        IS_WINDOWS or not TEST_MKL, "PyTorch is built without MKL support"
+    )(fn)
+
+
+# Skips a test on CPU if mkldnn is not available.
+def skipCPUIfNoMkldnn(fn):
+    return skipCPUIf(
+        not torch.backends.mkldnn.is_available(),
+        "PyTorch is built without mkldnn support",
+    )(fn)
+
+
+# Skips a test on CUDA if MAGMA is not available.
+def skipCUDAIfNoMagma(fn):
+    return skipCUDAIf("no_magma", "no MAGMA library detected")(
+        skipCUDANonDefaultStreamIf(True)(fn)
+    )
+
+
+def has_cusolver():
+    return not TEST_WITH_ROCM
+
+
+def has_hipsolver():
+    rocm_version = _get_torch_rocm_version()
+    # hipSOLVER is disabled on ROCM < 5.3
+    return rocm_version >= (5, 3)
+
+
+# Skips a test on CUDA/ROCM if cuSOLVER/hipSOLVER is not available
+def skipCUDAIfNoCusolver(fn):
+    return skipCUDAIf(
+        not has_cusolver() and not has_hipsolver(), "cuSOLVER not available"
+    )(fn)
+
+
+# Skips a test if neither cuSOLVER nor MAGMA is available
+def skipCUDAIfNoMagmaAndNoCusolver(fn):
+    if has_cusolver():
+        return fn
+    else:
+        # cuSolver is disabled on cuda < 10.1.243, tests depend on MAGMA
+        return skipCUDAIfNoMagma(fn)
+
+
+# Skips a test if neither cuSOLVER/hipSOLVER nor MAGMA is available
+def skipCUDAIfNoMagmaAndNoLinalgsolver(fn):
+    if has_cusolver() or has_hipsolver():
+        return fn
+    else:
+        # cuSolver is disabled on cuda < 10.1.243, tests depend on MAGMA
+        return skipCUDAIfNoMagma(fn)
+
+
+# Skips a test on CUDA when using ROCm.
+def skipCUDAIfRocm(func=None, *, msg="test doesn't currently work on the ROCm stack"):
+    def dec_fn(fn):
+        reason = f"skipCUDAIfRocm: {msg}"
+        return skipCUDAIf(TEST_WITH_ROCM, reason=reason)(fn)
+
+    if func:
+        return dec_fn(func)
+    return dec_fn
+
+
+# Skips a test on CUDA when not using ROCm.
+def skipCUDAIfNotRocm(fn):
+    return skipCUDAIf(
+        not TEST_WITH_ROCM, "test doesn't currently work on the CUDA stack"
+    )(fn)
+
+
+# Skips a test on CUDA if ROCm is unavailable or its version is lower than requested.
+def skipCUDAIfRocmVersionLessThan(version=None):
+    def dec_fn(fn):
+        @wraps(fn)
+        def wrap_fn(self, *args, **kwargs):
+            if self.device_type == "cuda":
+                if not TEST_WITH_ROCM:
+                    reason = "ROCm not available"
+                    raise unittest.SkipTest(reason)
+                rocm_version_tuple = _get_torch_rocm_version()
+                if (
+                    rocm_version_tuple is None
+                    or version is None
+                    or rocm_version_tuple < tuple(version)
+                ):
+                    reason = (
+                        f"ROCm {rocm_version_tuple} is available but {version} required"
+                    )
+                    raise unittest.SkipTest(reason)
+
+            return fn(self, *args, **kwargs)
+
+        return wrap_fn
+
+    return dec_fn
+
+
+# Skips a test on ROCm if MIOpen's NHWC suggestion (TEST_WITH_MIOPEN_SUGGEST_NHWC) is not enabled.
+def skipCUDAIfNotMiopenSuggestNHWC(fn):
+    return skipCUDAIf(
+        not TEST_WITH_MIOPEN_SUGGEST_NHWC,
+        "test doesn't currently work without MIOpen NHWC activation",
+    )(fn)
+
+
+# Skips a test for specified CUDA versions, given in the form of a list of [major, minor]s.
+def skipCUDAVersionIn(versions: Optional[list[tuple[int, int]]] = None):
+    def dec_fn(fn):
+        @wraps(fn)
+        def wrap_fn(self, *args, **kwargs):
+            version = _get_torch_cuda_version()
+            if version == (0, 0):  # cpu or rocm
+                return fn(self, *args, **kwargs)
+            if version in (versions or []):
+                reason = f"test skipped for CUDA version {version}"
+                raise unittest.SkipTest(reason)
+            return fn(self, *args, **kwargs)
+
+        return wrap_fn
+
+    return dec_fn
+
+
+# Skips a test for CUDA versions less than specified, given in the form of [major, minor].
+def skipCUDAIfVersionLessThan(versions: Optional[tuple[int, int]] = None):
+    def dec_fn(fn):
+        @wraps(fn)
+        def wrap_fn(self, *args, **kwargs):
+            version = _get_torch_cuda_version()
+            if version == (0, 0):  # cpu or rocm
+                return fn(self, *args, **kwargs)
+            if version < versions:
+                reason = f"test skipped for CUDA versions < {versions}"
+                raise unittest.SkipTest(reason)
+            return fn(self, *args, **kwargs)
+
+        return wrap_fn
+
+    return dec_fn
+
+
+# Skips a test on CUDA if cuDNN is unavailable or its version is lower than requested.
+def skipCUDAIfCudnnVersionLessThan(version=0):
+    def dec_fn(fn):
+        @wraps(fn)
+        def wrap_fn(self, *args, **kwargs):
+            if self.device_type == "cuda":
+                if self.no_cudnn:
+                    reason = "cuDNN not available"
+                    raise unittest.SkipTest(reason)
+                if self.cudnn_version is None or self.cudnn_version < version:
+                    reason = f"cuDNN version {self.cudnn_version} is available but {version} required"
+                    raise unittest.SkipTest(reason)
+
+            return fn(self, *args, **kwargs)
+
+        return wrap_fn
+
+    return dec_fn
+
+
+# Skips a test on CUDA if cuSparse generic API is not available
+def skipCUDAIfNoCusparseGeneric(fn):
+    return skipCUDAIf(not TEST_CUSPARSE_GENERIC, "cuSparse Generic API not available")(
+        fn
+    )
+
+
+def skipCUDAIfNoHipsparseGeneric(fn):
+    return skipCUDAIf(
+        not TEST_HIPSPARSE_GENERIC, "hipSparse Generic API not available"
+    )(fn)
+
+
+def skipCUDAIfNoSparseGeneric(fn):
+    return skipCUDAIf(
+        not (TEST_CUSPARSE_GENERIC or TEST_HIPSPARSE_GENERIC),
+        "Sparse Generic API not available",
+    )(fn)
+
+
+def skipCUDAIfNoCudnn(fn):
+    return skipCUDAIfCudnnVersionLessThan(0)(fn)
+
+
+def skipCUDAIfMiopen(fn):
+    return skipCUDAIf(torch.version.hip is not None, "Marked as skipped for MIOpen")(fn)
+
+
+def skipCUDAIfNoMiopen(fn):
+    return skipCUDAIf(torch.version.hip is None, "MIOpen is not available")(
+        skipCUDAIfNoCudnn(fn)
+    )
+
+
+def skipLazy(fn):
+    return skipLazyIf(True, "test doesn't work with lazy tensors")(fn)
+
+
+def skipMeta(fn):
+    return skipMetaIf(True, "test doesn't work with meta tensors")(fn)
+
+
+def skipXLA(fn):
+    return skipXLAIf(True, "Marked as skipped for XLA")(fn)
+
+
+def skipMPS(fn):
+    return skipMPSIf(True, "test doesn't work on MPS backend")(fn)
+
+
+def skipHPU(fn):
+    return skipHPUIf(True, "test doesn't work on HPU backend")(fn)
+
+
+def skipPRIVATEUSE1(fn):
+    return skipPRIVATEUSE1If(True, "test doesn't work on privateuse1 backend")(fn)
+
+
+# TODO: the "all" in the name hasn't been true for quite some time, as we also have, for example, XLA and MPS now.
+#  This should probably enumerate all available device type test base classes.
+def get_all_device_types() -> list[str]:
+    return ["cpu"] if not torch.cuda.is_available() else ["cpu", "cuda"]
+
+
+# Skip since flex attention currently requires at least `avx2` support on CPU.
+IS_FLEX_ATTENTION_CPU_PLATFORM_SUPPORTED = (
+    not torch.xpu.is_available()
+    and not torch.cuda.is_available()
+    and not IS_MACOS
+    and torch.cpu._is_avx2_supported()
+    and os.getenv("ATEN_CPU_CAPABILITY") != "default"
+)
+flex_attention_supported_platform = unittest.skipUnless(
+    IS_FLEX_ATTENTION_CPU_PLATFORM_SUPPORTED
+    or (
+        torch.cuda.is_available()
+        and torch.utils._triton.has_triton()
+        and torch.cuda.get_device_capability() >= (8, 0)
+    ),
+    "Requires CUDA and Triton, or a CPU with AVX2 or later",
+)
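+
+# Ex. (illustrative sketch; `test_flex_attention_smoke` is a hypothetical test name):
+#
+# @flex_attention_supported_platform
+# def test_flex_attention_smoke(self):
+#     ...
+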
+if torch.version.hip and "gfx94" in torch.cuda.get_device_properties(0).gcnArchName:
+    e4m3_type = torch.float8_e4m3fnuz
+    e5m2_type = torch.float8_e5m2fnuz
+    E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max
+    E5M2_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max
+else:
+    e4m3_type = torch.float8_e4m3fn
+    e5m2_type = torch.float8_e5m2
+    E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
+    E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_dist_composable.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_dist_composable.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd14b85a21915ddf8ab415f3bf5dc6e79db14dfc
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_dist_composable.py
@@ -0,0 +1,112 @@
+# mypy: ignore-errors
+
+# Owner(s): ["oncall: distributed"]
+
+
+import torch
+import torch.nn as nn
+
+
+class UnitModule(nn.Module):
+    def __init__(self, device: torch.device):
+        super().__init__()
+        self.l1 = nn.Linear(100, 100, device=device)
+        self.seq = nn.Sequential(
+            nn.ReLU(),
+            nn.Linear(100, 100, device=device),
+            nn.ReLU(),
+        )
+        self.l2 = nn.Linear(100, 100, device=device)
+
+    def forward(self, x):
+        return self.l2(self.seq(self.l1(x)))
+
+
+class CompositeModel(nn.Module):
+    def __init__(self, device: torch.device):
+        super().__init__()
+        self.l1 = nn.Linear(100, 100, device=device)
+        self.u1 = UnitModule(device)
+        self.u2 = UnitModule(device)
+        self.l2 = nn.Linear(100, 100, device=device)
+
+    def forward(self, x):
+        return self.l2(self.u2(self.u1(self.l1(x))))
+
+
+class UnitParamModule(nn.Module):
+    def __init__(self, device: torch.device):
+        super().__init__()
+        self.l = nn.Linear(100, 100, device=device)
+        self.seq = nn.Sequential(
+            nn.ReLU(),
+            nn.Linear(100, 100, device=device),
+            nn.ReLU(),
+        )
+        self.p = nn.Parameter(torch.randn((100, 100), device=device))
+
+    def forward(self, x):
+        return torch.mm(self.seq(self.l(x)), self.p)
+
+
+class CompositeParamModel(nn.Module):
+    def __init__(self, device: torch.device):
+        super().__init__()
+        self.l = nn.Linear(100, 100, device=device)
+        self.u1 = UnitModule(device)
+        self.u2 = UnitModule(device)
+        self.p = nn.Parameter(torch.randn((100, 100), device=device))
+        self.register_buffer(
+            "buffer", torch.randn((100, 100), device=device), persistent=True
+        )
+
+    def forward(self, x):
+        a = self.u2(self.u1(self.l(x)))
+        b = self.p
+        return torch.mm(a, b)
+
+
+class FakeSequential(nn.Module):
+    # Define this class to achieve a desired nested wrapping using the module
+    # wrap policy with `nn.Sequential`
+    def __init__(self, *modules: tuple[nn.Module, ...]) -> None:
+        super().__init__()
+        self._module_sequence = list(modules)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        for module in self._module_sequence:
+            x = module(x)
+        return x
+
+
+class NestedSequentialModel(nn.Module):
+    def __init__(self, device: torch.device) -> None:
+        super().__init__()
+        # This nested structure exercises traversal order to catch differences
+        # between valid traversals (e.g. BFS and DFS variations).
+        self.seq1 = nn.Sequential(
+            nn.Linear(1, 1, device=device),
+            FakeSequential(
+                nn.Linear(1, 1, device=device),
+                nn.ReLU(),
+                FakeSequential(
+                    nn.Linear(1, 1, device=device),
+                ),
+                nn.ReLU(),
+            ),
+            nn.Linear(1, 2, device=device),
+        )
+        self.lin = nn.Linear(2, 2, device=device)
+        self.seq2 = nn.Sequential(
+            nn.ReLU(),
+            nn.Linear(2, 3, device=device),
+            FakeSequential(
+                nn.Linear(3, 2, bias=False, device=device),
+                nn.Linear(2, 4, bias=False, device=device),
+            ),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.seq2(self.lin(self.seq1(x)))
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_distributed.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_distributed.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9d1c1332724e198a4caec935053b1df9adc9841
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_distributed.py
@@ -0,0 +1,1799 @@
+# mypy: ignore-errors
+
+import faulthandler
+import itertools
+import logging
+import multiprocessing
+import operator
+import os
+import queue
+import subprocess
+import sys
+import tempfile
+import threading
+import time
+import traceback
+import types
+import unittest
+from contextlib import contextmanager
+from dataclasses import dataclass
+from datetime import timedelta
+from enum import Enum
+from functools import partial, reduce, wraps
+from io import StringIO
+from typing import Any, Callable, NamedTuple, Optional, Union
+from unittest.mock import patch
+
+import torch
+import torch._dynamo.test_case
+import torch.cuda.nccl
+import torch.distributed as c10d
+import torch.nn as nn
+from torch._C._autograd import DeviceType
+from torch._C._distributed_c10d import _SymmetricMemory
+from torch._logging._internal import trace_log
+from torch.testing._internal.common_utils import (
+    FILE_SCHEMA,
+    find_free_port,
+    IS_SANDCASTLE,
+    retry_on_connect_failures,
+    skip_but_pass_in_sandcastle,
+    skip_but_pass_in_sandcastle_if,
+    TEST_CUDA,
+    TEST_HPU,
+    TEST_WITH_ROCM,
+    TEST_WITH_TSAN,
+    TEST_XPU,
+    TestCase,
+)
+from torch.testing._internal.distributed.multi_threaded_pg import (
+    _install_threaded_pg,
+    _uninstall_threaded_pg,
+    ProcessLocalGroup,
+)
+
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+ACCELERATOR_DIST_BACKENDS = ["nccl", "xccl", "hccl"]
+DDP_RANK_DEVICES = ["cuda", "xpu"]
+HAS_ACCELERATOR = TEST_CUDA or TEST_HPU or TEST_XPU
+
+
+class TestSkip(NamedTuple):
+    exit_code: int
+    message: str
+
+
+TEST_SKIPS = {
+    "backend_unavailable": TestSkip(
+        72, "Skipped because distributed backend is not available."
+    ),
+    "small_worldsize": TestSkip(73, "Skipped due to small world size."),
+    "odd_worldsize": TestSkip(87, "Skipped due to odd world size."),
+    "no_cuda": TestSkip(74, "CUDA is not available."),
+    "multi-gpu-1": TestSkip(75, "Need at least 1 CUDA device"),
+    "multi-gpu-2": TestSkip(77, "Need at least 2 CUDA devices"),
+    "multi-gpu-3": TestSkip(80, "Need at least 3 CUDA devices"),
+    "multi-gpu-4": TestSkip(81, "Need at least 4 CUDA devices"),
+    "multi-gpu-5": TestSkip(82, "Need at least 5 CUDA devices"),
+    "multi-gpu-6": TestSkip(83, "Need at least 6 CUDA devices"),
+    "multi-gpu-7": TestSkip(84, "Need at least 7 CUDA devices"),
+    "multi-gpu-8": TestSkip(85, "Need at least 8 CUDA devices"),
+    "nccl": TestSkip(76, "c10d not compiled with NCCL support"),
+    "skipIfRocm": TestSkip(78, "Test skipped for ROCm"),
+    "no_peer_access": TestSkip(79, "Test skipped because no GPU peer access"),
+    "generic": TestSkip(
+        86, "Test skipped at subprocess level, look at subprocess log for skip reason"
+    ),
+    "importerror": TestSkip(88, "Test skipped due to missing import"),
+    "no_accelerator": TestSkip(89, "accelerator is not available."),
+}
+
+
+@dataclass
+class DistTestCases:
+    # Backends that do not support a specific collective
+    skip_collective = {}
+    skip_collective["allgather_coalesced"] = {"nccl", "mpi", "ucc"}
+    skip_collective["reduce"] = set()
+    skip_collective["sendrecv anysource"] = {"nccl", "ucc"}
+    skip_collective["cpu barrier"] = {"nccl", "ucc"}
+
+    # Sets showing that something is implemented
+    backend_feature = {}
+    backend_feature["gpu"] = {"nccl", "gloo", "ucc"}
+    backend_feature["cuda"] = {"nccl", "gloo", "ucc"}
+    backend_feature["ddp"] = {"nccl", "gloo", "ucc"}
+    backend_feature["subgroup"] = {"nccl", "gloo", "ucc"}
+    backend_feature["plugin"] = set()
+    if TEST_HPU:
+        backend_feature["hpu"] = {"hccl"}
+    if TEST_XPU:
+        backend_feature["xpu"] = {"xccl"}
+
+
+def requires_ddp_rank(device):
+    return device in DDP_RANK_DEVICES
+
+
+def skip_if_no_gpu(func):
+    """Skips if the world size exceeds the number of GPUs, ensuring that if the
+    test is run, each rank has its own GPU via ``torch.cuda.device(rank)``."""
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        if not (TEST_CUDA or TEST_HPU or TEST_XPU):
+            sys.exit(TEST_SKIPS["no_cuda"].exit_code)
+        world_size = int(os.environ["WORLD_SIZE"])
+        if TEST_CUDA and torch.cuda.device_count() < world_size:
+            sys.exit(TEST_SKIPS[f"multi-gpu-{world_size}"].exit_code)
+        if TEST_HPU and torch.hpu.device_count() < world_size:
+            sys.exit(TEST_SKIPS[f"multi-gpu-{world_size}"].exit_code)
+        if TEST_XPU and torch.xpu.device_count() < world_size:
+            sys.exit(TEST_SKIPS[f"multi-gpu-{world_size}"].exit_code)
+
+        return func(*args, **kwargs)
+
+    return wrapper
+
+
+# TODO (kwen2501): what is the purpose of this decorator?  Tests with this
+# decorator were always skipped. So they may be outdated already.
+# Oct 2024: bumping the small-world criteria to < 8, as we are increasing the
+# number of GPUs in CI from 2 to 4, and we need to continue skipping those tests
+# to keep CI green. But this is just a temporary solution. We should clean up
+# those tests somehow.
+def skip_if_small_worldsize(func):
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        if (os.environ["BACKEND"] != "mpi") and int(os.environ["WORLD_SIZE"]) < 8:
+            sys.exit(TEST_SKIPS["small_worldsize"].exit_code)
+
+        return func(*args, **kwargs)
+
+    return wrapper
+
+
+def skip_if_odd_worldsize(func):
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        if (os.environ["BACKEND"] != "mpi") and int(os.environ["WORLD_SIZE"]) % 2 == 1:
+            sys.exit(TEST_SKIPS["odd_worldsize"].exit_code)
+
+        return func(*args, **kwargs)
+
+    return wrapper
+
+
+def require_n_gpus_for_nccl_backend(n, backend):
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            if backend == "nccl" and torch.cuda.device_count() < n:
+                sys.exit(TEST_SKIPS[f"multi-gpu-{n}"].exit_code)
+            else:
+                return func(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
+def import_transformers_or_skip():
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            try:
+                from transformers import AutoModelForMaskedLM, BertConfig  # noqa: F401
+
+                return func(*args, **kwargs)
+            except ImportError:
+                sys.exit(TEST_SKIPS["importerror"].exit_code)
+
+        return wrapper
+
+    return decorator
+
+
+def at_least_x_gpu(x):
+    if TEST_CUDA and torch.cuda.device_count() >= x:
+        return True
+    if TEST_HPU and torch.hpu.device_count() >= x:
+        return True
+    if TEST_XPU and torch.xpu.device_count() >= x:
+        return True
+    return False
+
+
+def skip_if_lt_x_gpu(x):
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            if torch.cuda.is_available() and torch.cuda.device_count() >= x:
+                return func(*args, **kwargs)
+            if TEST_HPU and torch.hpu.device_count() >= x:
+                return func(*args, **kwargs)
+            if TEST_XPU and torch.xpu.device_count() >= x:
+                return func(*args, **kwargs)
+            sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)
+
+        return wrapper
+
+    return decorator
+
+
+# This decorator helps avoid initializing CUDA while testing other backends.
+def nccl_skip_if_lt_x_gpu(backend, x):
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            if backend != "nccl":
+                return func(*args, **kwargs)
+            if torch.cuda.is_available() and torch.cuda.device_count() >= x:
+                return func(*args, **kwargs)
+            sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)
+
+        return wrapper
+
+    return decorator
+
+
+def verify_ddp_error_logged(model_DDP, err_substr):
+    # Verify error was logged in ddp_logging_data.
+    ddp_logging_data = model_DDP._get_ddp_logging_data()
+    assert "iteration" in ddp_logging_data
+    assert "has_error" in ddp_logging_data
+    assert "error" in ddp_logging_data
+    logging_err = ddp_logging_data["error"]
+    # Remove C++ stacktrace if needed.
+    actual = (
+        err_substr
+        if err_substr.find("\nException raised from ") == -1
+        else err_substr.split("\nException raised from ")[0]
+    )
+    assert (
+        actual in logging_err
+    ), f"Did not find expected {actual} in ddp logging data error: {logging_err}"
+
+
+def with_nccl_blocking_wait(func):
+    """
+    Convenience decorator to set/unset TORCH_NCCL_BLOCKING_WAIT flag. Note that use of
+    this decorator will override the setting of TORCH_NCCL_ASYNC_ERROR_HANDLING for
+    the particular test. After the test, both TORCH_NCCL_BLOCKING_WAIT and
+    TORCH_NCCL_ASYNC_ERROR_HANDLING will be restored to their original values.
+    """
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        # Save and unset TORCH_NCCL_ASYNC_ERROR_HANDLING
+        try:
+            cached_nccl_async_error_handling: Union[str, None] = os.environ[
+                "TORCH_NCCL_ASYNC_ERROR_HANDLING"
+            ]
+            del os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"]
+        except KeyError:
+            # TORCH_NCCL_ASYNC_ERROR_HANDLING was unset
+            cached_nccl_async_error_handling = None
+
+        # Save val of TORCH_NCCL_BLOCKING_WAIT and set it.
+        try:
+            cached_nccl_blocking_wait: Union[str, None] = os.environ[
+                "TORCH_NCCL_BLOCKING_WAIT"
+            ]
+        except KeyError:
+            cached_nccl_blocking_wait = None
+        finally:
+            os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"
+
+        try:
+            ret = func(*args, **kwargs)
+            return ret
+        finally:
+            # restore old values.
+            if cached_nccl_async_error_handling is not None:
+                os.environ[
+                    "TORCH_NCCL_ASYNC_ERROR_HANDLING"
+                ] = cached_nccl_async_error_handling
+
+            if cached_nccl_blocking_wait is not None:
+                os.environ["TORCH_NCCL_BLOCKING_WAIT"] = cached_nccl_blocking_wait
+
+    return wrapper
+
+
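+# Ex. (illustrative sketch; `test_foo` is a hypothetical test):
+#
+# @with_dist_debug_levels(["OFF", "INFO", "DETAIL"])
+# def test_foo(self):
+#     ...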
+def with_dist_debug_levels(levels):
+    """
+    Runs a test for each distributed debug level specified in levels.
+    """
+
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            old_level = os.environ.get("TORCH_DISTRIBUTED_DEBUG", None)
+            for level in levels:
+                os.environ["TORCH_DISTRIBUTED_DEBUG"] = level
+                c10d.set_debug_level_from_env()
+                ret = func(*args, **kwargs)
+                c10d.barrier()
+                if old_level is not None:
+                    os.environ["TORCH_DISTRIBUTED_DEBUG"] = old_level
+            # Only the return value of the last run is returned, but since these
+            # are unittests the return value is not really used, and earlier runs
+            # would've raised had they failed.
+            return ret
+
+        return wrapper
+
+    return decorator
+
+
+def requires_gloo():
+    return skip_but_pass_in_sandcastle_if(
+        not c10d.is_gloo_available(),
+        "c10d was not compiled with the Gloo backend",
+    )
+
+
+def requires_nccl_version(version, msg):
+    if not c10d.is_nccl_available():
+        return skip_but_pass_in_sandcastle(
+            "c10d was not compiled with the NCCL backend",
+        )
+    else:
+        return skip_but_pass_in_sandcastle_if(
+            torch.cuda.nccl.version() < version,
+            f"Requires NCCL version greater than or equal to: {version}, found: {torch.cuda.nccl.version()}, reason: {msg}",
+        )
+
+
+def requires_nccl():
+    return skip_but_pass_in_sandcastle_if(
+        not c10d.is_nccl_available(),
+        "c10d was not compiled with the NCCL backend",
+    )
+
+
+def requires_ucc():
+    return skip_but_pass_in_sandcastle_if(
+        not c10d.is_ucc_available(),
+        "c10d was not compiled with the UCC backend",
+    )
+
+
+def requires_mpi():
+    return skip_but_pass_in_sandcastle_if(
+        not c10d.is_mpi_available(),
+        "c10d was not compiled with the MPI backend",
+    )
+
+
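+# Ex. (illustrative sketch; `test_allreduce` is a hypothetical test):
+#
+# @requires_accelerator_dist_backend(["nccl", "xccl"])
+# def test_allreduce(self):
+#     ...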
+def requires_accelerator_dist_backend(backends=None):
+    """
+    Decorator to skip tests if no accelerator communication backend (NCCL, XCCL, HCCL) is available.
+
+    Args:
+        backends (Optional[List[str]]): Specific accelerator backends to check (e.g., ["nccl", "xccl", "hccl"]).
+                                       If None, checks all supported accelerator backends (NCCL, XCCL, HCCL).
+
+    Returns:
+        callable: A decorator that skips the test if no specified accelerator backend is available.
+    """
+    if backends is None:
+        backends = ACCELERATOR_DIST_BACKENDS
+
+    backend_available = any(
+        {
+            "nccl": c10d.is_nccl_available,
+            "xccl": c10d.is_xccl_available,
+            "hccl": lambda: TEST_HPU,
+        }.get(backend, lambda: False)()
+        for backend in backends
+    )
+
+    return skip_but_pass_in_sandcastle_if(
+        not backend_available,
+        f"No accelerator communication backend available among {backends}",
+    )
+
+
+def requires_multicast_support():
+    has_multicast_support = (
+        torch.cuda.is_available()
+        and _SymmetricMemory.has_multicast_support(DeviceType.CUDA, 0)
+    )
+    return skip_but_pass_in_sandcastle_if(
+        not has_multicast_support,
+        "multicast support is not available",
+    )
+
+
+def skip_if_rocm_multiprocess(func):
+    """Skips a test for ROCm"""
+    func.skip_if_rocm_multiprocess = True
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        if not TEST_WITH_ROCM:
+            return func(*args, **kwargs)
+        sys.exit(TEST_SKIPS["skipIfRocm"].exit_code)
+
+    return wrapper
+
+
+def skip_if_win32():
+    return skip_but_pass_in_sandcastle_if(
+        sys.platform == "win32",
+        "This unit test case is not supported on Windows platform",
+    )
+
+
+def sm_is_or_higher_than(device: torch.device, major: int, minor: int) -> bool:
+    """
+    Returns True if the device's compute capability is (major, minor) or higher.
+    Errors out if the device is not a CUDA device.
+    Returns False if the device is a ROCm device.
+    """
+    if device.type != "cuda":
+        raise ValueError("sm_is_or_higher_than() is only supported for CUDA devices")
+
+    if torch.version.hip is not None:
+        # ROCm devices may have different compute capability codes
+        return False
+
+    return torch.cuda.get_device_capability(device) >= (major, minor)
+
+
+@retry_on_connect_failures
+def create_tcp_store(
+    addr="localhost",
+    world_size=1,
+    is_master=True,
+    timeout=timedelta(minutes=5),
+    wait_for_workers=True,
+    jit_class=False,
+    use_libuv=True,
+):
+    """
+    Creates a TCP store. Retries if the chosen port is already in use.
+    """
+    port = find_free_port()
+    if jit_class:
+        timeout_millisecond = int(timeout / timedelta(milliseconds=1))
+        return torch.classes.dist_c10d.TCPStore(
+            addr, port, world_size, is_master, timeout_millisecond
+        )
+    else:
+        return c10d.TCPStore(
+            addr,
+            port,
+            world_size,
+            is_master,
+            wait_for_workers=wait_for_workers,
+            use_libuv=use_libuv,
+        )
+
+
+if TEST_WITH_TSAN:
+    # TSAN runs much slower.
+    TIMEOUT_DEFAULT = 500
+else:
+    TIMEOUT_DEFAULT = int(os.getenv("DISTRIBUTED_TESTS_DEFAULT_TIMEOUT", "300"))
+TIMEOUT_OVERRIDE = {"test_ddp_uneven_inputs": 400}
+
+
+# https://github.com/pytorch/pytorch/issues/75665
+if TEST_WITH_ROCM:
+    TIMEOUT_OVERRIDE["test_join_kwargs"] = 200
+
+
+def create_device(interface=None, lazy_init: bool = False):
+    if sys.platform == "win32" or interface is None:
+        return c10d.ProcessGroupGloo.create_device(
+            hostname="127.0.0.1", lazy_init=lazy_init
+        )
+    else:
+        return c10d.ProcessGroupGloo.create_device(
+            interface=interface, lazy_init=lazy_init
+        )
+
+
+def get_timeout(test_id) -> int:
+    return TIMEOUT_OVERRIDE.get(test_id.split(".")[-1], TIMEOUT_DEFAULT)
+
+
+@contextmanager
+def captured_output():
+    new_out, new_err = StringIO(), StringIO()
+    old_out, old_err = sys.stdout, sys.stderr
+    try:
+        sys.stdout, sys.stderr = new_out, new_err
+        yield sys.stdout, sys.stderr
+    finally:
+        sys.stdout, sys.stderr = old_out, old_err
+
+
+def simple_sparse_reduce_tests(rank: int, world_size: int, num_inputs: int = 1):
+    """
+    Generate a number of basic test cases for sparse reduction.
+    These cover tensors with a varying number of sparse dimensions and a varying
+    number of dense dimensions. The only reduction operation we support is sum.
+    """
+
+    def generate(rank: int, world_size: int, sparse_dims: int = 1, dense_dims: int = 0):
+        # First sparse dimension is [0..rank].
+        # Subsequent dimensions are always 0, so we know there is
+        # a non-empty intersection between any two sparse tensors.
+        indices = torch.reshape(torch.arange(rank + 1), (1, rank + 1))
+        shape = [world_size] + [2 for _ in range(dense_dims)]
+        for _ in range(sparse_dims - 1):
+            indices = torch.cat((indices, torch.zeros(1, rank + 1)))
+            shape.append(world_size)
+        values = torch.ones([rank + 1] + [2 for _ in range(dense_dims)])
+        return torch.sparse_coo_tensor(indices, values, shape)
+
+    def compute_sum(fn, world_size: int):
+        return reduce(
+            operator.add, [fn(rank, world_size) for rank in range(world_size)]
+        )
+
+    return [
+        (
+            [
+                fn(num_inputs * rank + i, num_inputs * world_size)
+                for i in range(num_inputs)
+            ],
+            [compute_sum(fn, num_inputs * world_size) for i in range(num_inputs)],
+        )
+        for fn in [
+            partial(generate, sparse_dims=1),
+            partial(generate, sparse_dims=2),
+            partial(generate, sparse_dims=3),
+            partial(generate, dense_dims=1),
+            partial(generate, dense_dims=2),
+            partial(generate, dense_dims=3),
+        ]
+    ]
+
+
+# HELPER FOR MULTIGPU TESTS
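+# Ex. (illustrative; the exact mapping depends on the visible device count):
+# with 4 visible GPUs and world_size=4, init_multigpu_helper(4, "nccl")
+# returns {0: [0], 1: [1], 2: [2], 3: [3]}.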
+def init_multigpu_helper(world_size: int, backend: str):
+    """Multi-GPU tests are designed to simulate multiple nodes with multiple
+    GPUs on each node. The NCCL backend requires an equal number of GPUs in
+    each process. On a single node, all visible GPUs are evenly divided into
+    subsets, and each process only uses its own subset.
+    """
+    nGPUs = torch.cuda.device_count()
+    if TEST_HPU:
+        nGPUs = torch.hpu.device_count()
+    if TEST_XPU:
+        nGPUs = torch.xpu.device_count()
+    visible_devices = range(nGPUs)
+
+    # If the world size is less than or equal to the number of available GPUs,
+    # then each rank can be mapped to its own GPU.
+    nGPUs_per_process = 1
+    if world_size > nGPUs:
+        nGPUs_per_process = nGPUs // world_size
+    rank_to_GPU = {
+        i: list(visible_devices[i * nGPUs_per_process : (i + 1) * nGPUs_per_process])
+        for i in range(world_size)
+    }
+    return rank_to_GPU
+
+
+tmp_dir: Optional[tempfile.TemporaryDirectory] = None
+
+
+def initialize_temp_directories(init_method: Optional[str] = None) -> None:
+    global tmp_dir
+    tmp_dir = tempfile.TemporaryDirectory()
+    os.environ["TEMP_DIR"] = tmp_dir.name
+    os.mkdir(os.path.join(tmp_dir.name, "barrier"))
+    os.mkdir(os.path.join(tmp_dir.name, "test_dir"))
+    init_dir_path = os.path.join(tmp_dir.name, "init_dir")
+    os.mkdir(init_dir_path)
+    # Set init method if specified.
+    if init_method is not None:
+        os.environ["INIT_METHOD"] = init_method
+    else:
+        os.environ["INIT_METHOD"] = FILE_SCHEMA + os.path.join(
+            init_dir_path, "shared_init_file"
+        )
+
+
+def cleanup_temp_dir() -> None:
+    if tmp_dir is not None:
+        tmp_dir.cleanup()
+
+
+# Most tests operate with this worldsize
+DEFAULT_WORLD_SIZE = 4
+
+# [How does MultiProcessTestCase work?]
+# Each MultiProcessTestCase instance uses 1 + `world_size()` processes, by
+# default `world_size()` returns 4. Let's take `test_rpc_spawn.py` as an
+# example which inherits from this class. Its `setUp()` method calls into
+# `MultiProcessTestCase._spawn_processes()` which spawns `world_size()`
+# subprocesses. During the spawn, the main process passes the test name to
+# subprocesses, and the name is acquired from self.id(). The subprocesses
+# then use the provided test function name to retrieve the function attribute
+# from the test instance and run it. The main process simply waits for all
+# subprocesses to join.
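+#
+# Ex. (illustrative sketch; `ExampleTest` and its test body are hypothetical):
+#
+# class ExampleTest(MultiProcessTestCase):
+#     def setUp(self):
+#         super().setUp()
+#         self._spawn_processes()
+#
+#     def test_something(self):
+#         # This body runs once in each spawned subprocess; `self.rank`
+#         # identifies the current process.
+#         ...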
+
+
+class MultiProcessTestCase(TestCase):
+    MAIN_PROCESS_RANK = -1
+    # This exit code is used to indicate that the test code had an error and
+    # exited abnormally. There are certain tests that might use sys.exit() to
+    # simulate failures and in those cases, we can't have an exit code of 0,
+    # but we still want to ensure we didn't run into any other errors.
+    TEST_ERROR_EXIT_CODE = 10
+
+    # do not early terminate for distributed tests.
+    def _should_stop_test_suite(self) -> bool:
+        return False
+
+    # Many test cases init a process group but do not destroy it.  This property
+    # determines whether this base test class should call
+    # `destroy_process_group` on behalf of the test. Its value is customizable
+    # by derived TestCases, but it is a pan-TestCase value (it cannot be customized
+    # for each test).
+    @property
+    def destroy_pg_upon_exit(self) -> bool:
+        return True
+
+    @property
+    def world_size(self) -> int:
+        return DEFAULT_WORLD_SIZE
+
+    def join_or_run(self, fn):
+        @wraps(fn)
+        def wrapper(self):
+            if self.rank == self.MAIN_PROCESS_RANK:
+                self._join_processes(fn)
+            else:
+                fn()
+
+        return types.MethodType(wrapper, self)
+
+    # The main process spawns N subprocesses that run the test.
+    # Constructor patches current instance test method to
+    # assume the role of the main process and join its subprocesses,
+    # or run the underlying test function.
+    def __init__(
+        self, method_name: str = "runTest", methodName: str = "runTest"
+    ) -> None:
+        # methodName is the correct naming in unittest, and testslide uses keyword arguments.
+        # So we need to use both to 1) not break BC and 2) support testslide.
+        if methodName != "runTest":
+            method_name = methodName
+        super().__init__(method_name)
+        try:
+            fn = getattr(self, method_name)
+            setattr(self, method_name, self.join_or_run(fn))
+        except AttributeError as e:
+            if methodName != "runTest":
+                # we allow instantiation with no explicit method name
+                # but not an *incorrect* or missing method name
+                raise ValueError(
+                    f"no such test method in {self.__class__}: {methodName}"
+                ) from e
+
+    def setUp(self) -> None:
+        super().setUp()
+
+        # Used for tests that are expected to return a non-0 exit code, such as
+        # SIGABRT thrown by watchdog.
+        self.special_return_code_checks: dict = {}
+
+        # Used for tests that may return any exit code, which makes it hard to
+        # check. This is rare, use with caution.
+        self.skip_return_code_checks: list = []
+
+        self.processes = []  # type: ignore[var-annotated]
+        self.rank = self.MAIN_PROCESS_RANK
+        self.file_name = tempfile.NamedTemporaryFile(delete=False).name
+        # pid to pipe consisting of error message from process.
+        self.pid_to_pipe = {}  # type: ignore[var-annotated]
+
+    def tearDown(self) -> None:
+        super().tearDown()
+        for p in self.processes:
+            p.terminate()
+        # Each Process instance holds a few open file descriptors. The unittest
+        # runner creates a new TestCase instance for each test method and keeps
+        # it alive until the end of the entire suite. We must thus reset the
+        # processes to prevent an effective file descriptor leak.
+        self.processes = []
+
+    def _current_test_name(self) -> str:
+        # self.id() == e.g. '__main__.TestDistributed.TestAdditive.test_get_rank'
+        return self.id().split(".")[-1]
+
+    def _start_processes(self, proc) -> None:
+        self.processes = []
+        for rank in range(int(self.world_size)):
+            parent_conn, child_conn = torch.multiprocessing.Pipe()
+            process = proc(
+                target=self.__class__._run,
+                name="process " + str(rank),
+                args=(rank, self._current_test_name(), self.file_name, child_conn),
+                kwargs={
+                    "fake_pg": getattr(self, "fake_pg", False),
+                },
+            )
+            process.start()
+            logger.info("Started process %s with pid %s", rank, process.pid)
+            self.pid_to_pipe[process.pid] = parent_conn
+            self.processes.append(process)
+
+    def _spawn_processes(self) -> None:
+        try:
+            torch.multiprocessing.set_start_method("spawn")
+        except RuntimeError:
+            pass
+
+        proc = torch.multiprocessing.get_context("spawn").Process
+        self._start_processes(proc)
+
+    class Event(Enum):
+        GET_TRACEBACK = 1
+
+    @staticmethod
+    def _event_listener(parent_pipe, signal_pipe, rank: int):
+        logger.debug("Starting event listener thread for rank %s", rank)
+        while True:
+            ready_pipes = multiprocessing.connection.wait([parent_pipe, signal_pipe])
+
+            if parent_pipe in ready_pipes:
+                if parent_pipe.closed:
+                    logger.debug(
+                        "Pipe closed for process %s, stopping event listener thread",
+                        rank,
+                    )
+                    return
+
+                event = parent_pipe.recv()
+                logger.info("Received event %s on process %s", event, rank)
+
+                if event == MultiProcessTestCase.Event.GET_TRACEBACK:
+                    # Return traceback to the parent process.
+                    with tempfile.NamedTemporaryFile(mode="r+") as tmp_file:
+                        faulthandler.dump_traceback(tmp_file)
+                        # Flush buffers and seek to read from the beginning
+                        tmp_file.flush()
+                        tmp_file.seek(0)
+                        parent_pipe.send(tmp_file.read())
+
+                        logger.info("Process %s sent traceback", rank)
+
+            if signal_pipe in ready_pipes:
+                return
+
+    @classmethod
+    def _run(
+        cls, rank: int, test_name: str, file_name: str, parent_pipe, **kwargs
+    ) -> None:
+        self = cls(test_name)
+        self.rank = rank
+        self.file_name = file_name
+        self.run_test(test_name, parent_pipe)
+
+    def run_test(self, test_name: str, parent_pipe) -> None:
+        # Start event listener thread.
+        signal_recv_pipe, signal_send_pipe = torch.multiprocessing.Pipe(duplex=False)
+        event_listener_thread = threading.Thread(
+            target=MultiProcessTestCase._event_listener,
+            args=(parent_pipe, signal_recv_pipe, self.rank),
+            daemon=True,
+        )
+        event_listener_thread.start()
+        if sys.platform != "win32" and sys.platform != "darwin":
+            # Register signal handler to dump stack traces on FATALs.
+            # Windows and MacOS do not support the signal handlers.
+            torch._C._set_print_stack_traces_on_fatal_signal(True)
+        # Show full C++ stacktraces when a Python error originating from C++ is raised.
+        os.environ["TORCH_SHOW_CPP_STACKTRACES"] = "1"
+
+        # self.id() == e.g. '__main__.TestDistributed.test_get_rank'
+        # We're retrieving a corresponding test and executing it.
+        try:
+            getattr(self, test_name)()
+        except unittest.SkipTest as se:
+            logger.info(
+                "Process %s skipping test %s for following reason: %s",
+                self.rank,
+                test_name,
+                str(se),
+            )
+            sys.exit(TEST_SKIPS["generic"].exit_code)
+        except Exception:
+            logger.error(
+                "Caught exception: \n%s exiting process %s with exit code: %s",
+                traceback.format_exc(),
+                self.rank,
+                MultiProcessTestCase.TEST_ERROR_EXIT_CODE,
+            )
+            # Send error to parent process.
+            parent_pipe.send(traceback.format_exc())
+            sys.exit(MultiProcessTestCase.TEST_ERROR_EXIT_CODE)
+        finally:
+            if signal_send_pipe is not None:
+                signal_send_pipe.send(None)
+
+            assert event_listener_thread is not None
+            event_listener_thread.join()
+            # Close pipe after done with test.
+            parent_pipe.close()
+
+        if self.destroy_pg_upon_exit:
+            try:
+                # Some tests do destroy the pgs, and destroy can't be called twice.
+                # This avoids spewing warnings about improperly shutting down.
+                c10d.destroy_process_group()
+            except (AssertionError, ValueError):
+                pass
+
+    def _get_timedout_process_traceback(self) -> None:
+        pipes = []
+        for i, process in enumerate(self.processes):
+            if process.exitcode is None:
+                pipe = self.pid_to_pipe[process.pid]
+                try:
+                    pipe.send(MultiProcessTestCase.Event.GET_TRACEBACK)
+                    pipes.append((i, pipe))
+                except ConnectionError as e:
+                    logger.error(
+                        "Encountered error while trying to get traceback for process %s: %s",
+                        i,
+                        e,
+                    )
+
+        # Wait for results.
+        for rank, pipe in pipes:
+            try:
+                # Wait for traceback
+                if pipe.poll(5):
+                    if pipe.closed:
+                        logger.info(
+                            "Pipe closed for process %s, cannot retrieve traceback",
+                            rank,
+                        )
+                        continue
+
+                    traceback = pipe.recv()
+                    logger.error(
+                        "Process %s timed out with traceback: \n\n%s", rank, traceback
+                    )
+                else:
+                    logger.error(
+                        "Could not retrieve traceback for timed out process: %s", rank
+                    )
+            except ConnectionError as e:
+                logger.error(
+                    "Encountered error while trying to get traceback for process %s: %s",
+                    rank,
+                    e,
+                )
+
+    def _join_processes(self, fn) -> None:
+        timeout = get_timeout(self.id())
+        start_time = time.time()
+        subprocess_error = False
+        try:
+            while True:
+                # check to see if any subprocess exited with an error early.
+                for i, p in enumerate(self.processes):
+                    # This is the exit code processes exit with if they
+                    # encountered an exception.
+                    if p.exitcode == MultiProcessTestCase.TEST_ERROR_EXIT_CODE:
+                        print(
+                            f"Process {i} terminated with exit code {p.exitcode}, terminating remaining processes."
+                        )
+                        active_children = torch.multiprocessing.active_children()
+                        for ac in active_children:
+                            ac.terminate()
+                        subprocess_error = True
+                        break
+                if subprocess_error:
+                    break
+                # All processes have joined cleanly if they all have a valid exitcode
+                if all(p.exitcode is not None for p in self.processes):
+                    break
+                # Check if we should time out the test. If so, we terminate each process.
+                elapsed = time.time() - start_time
+                if elapsed > timeout:
+                    self._get_timedout_process_traceback()
+                    print(
+                        f"Timing out after {timeout} seconds and killing subprocesses."
+                    )
+                    for p in self.processes:
+                        p.terminate()
+                    break
+                # Sleep to avoid excessive busy polling.
+                time.sleep(0.1)
+
+            elapsed_time = time.time() - start_time
+            self._check_return_codes(fn, elapsed_time)
+        finally:
+            # Close all pipes
+            for pipe in self.pid_to_pipe.values():
+                pipe.close()
+
+    def _check_return_codes(self, fn, elapsed_time) -> None:
+        """
+        Checks that the return codes of all spawned processes match, and skips
+        tests if they returned a return code indicating a skipping condition.
+        """
+        # If no processes are spawned, there is nothing to check.
+        if not self.processes:
+            logger.warning(
+                "Note: no subprocesses were spawned, test was likely skipped."
+            )
+            return
+
+        first_process = self.processes[0]
+        # First, we check if there are errors in the actual processes
+        # (via TEST_ERROR_EXIT_CODE), and raise an exception for those.
+        # the reason we do this is to attempt to raise a more helpful error
+        # message than "Process x terminated/timed out"
+        # TODO: we should pipe the exception of the failed subprocess here.
+        # Currently, the actual exception is displayed as a logging output.
+        errored_processes = [
+            (i, p)
+            for i, p in enumerate(self.processes)
+            if p.exitcode == MultiProcessTestCase.TEST_ERROR_EXIT_CODE
+        ]
+        if errored_processes:
+            error = ""
+            for i, process in errored_processes:
+                # Get error from pipe.
+                error_message = self.pid_to_pipe[process.pid].recv()
+                error += (
+                    f"Process {i} exited with error code {MultiProcessTestCase.TEST_ERROR_EXIT_CODE} "
+                    f"and exception:\n{error_message}\n"
+                )
+
+            raise RuntimeError(error)
+        # If no process exited uncleanly, we check for timeouts, and then ensure
+        # each process exited cleanly.
+        for i, p in enumerate(self.processes):
+            if p.exitcode is None:
+                raise RuntimeError(
+                    f"Process {i} terminated or timed out after {elapsed_time} seconds"
+                )
+
+        # Skip the test return code check
+        if fn in self.skip_return_code_checks:
+            return
+
+        for skip in TEST_SKIPS.values():
+            if first_process.exitcode == skip.exit_code:
+                if IS_SANDCASTLE:
+                    # Don't use unittest.skip to skip the test on sandcastle
+                    # since it creates tasks for skipped tests assuming there
+                    # is some follow-up needed. Instead just "pass" the test
+                    # with an appropriate message.
+                    logger.info(
+                        "Skipping %s on sandcastle for the following reason: %s",
+                        self.id(),
+                        skip.message,
+                    )
+                    return
+                else:
+                    raise unittest.SkipTest(skip.message)
+
+        # In most cases, we expect the test to return exit code 0, indicating success.
+        expected_return_code = 0
+        # Some negative tests expect a non-zero exit code, e.g. when the
+        # watchdog throws SIGABRT.
+        if fn in self.special_return_code_checks:
+            expected_return_code = self.special_return_code_checks[fn]
+
+        self.assertEqual(
+            first_process.exitcode,
+            expected_return_code,
+            msg=f"Expected exit code {expected_return_code} but got {first_process.exitcode} for pid: {first_process.pid}",
+        )
+
+    @property
+    def is_master(self) -> bool:
+        return self.rank == 0
+
+
+# Utility base class for distributed multi-process test cases.
+# It abstracts process group (PG) creation and deletion; the backend is selected
+# based on the device type. Test functions can be instantiated per device type
+# using common_device_type.instantiate_device_type_tests.
+# Other backends can add an entry to the backend() function.
+class DistributedTestBase(MultiProcessTestCase):
+    def setUp(self):
+        super().setUp()
+        os.environ["WORLD_SIZE"] = str(self.world_size)
+        self._spawn_processes()
+
+    def tearDown(self):
+        try:
+            torch.distributed.destroy_process_group()
+        except AssertionError:
+            pass
+        try:
+            os.remove(self.file_name)
+        except OSError:
+            pass
+
+    def backend(self, device) -> str:
+        if "cuda" in device:
+            return "nccl"
+        elif "hpu" in device:  # intel gaudi
+            return "hccl"
+        elif "xpu" in device:
+            return "xccl"
+        else:
+            return "gloo"
+
+    def create_pg(self, device, world_size=None):
+        if world_size is None:
+            world_size = self.world_size
+        num_visible_devices = torch.get_device_module(device).device_count()
+        store = torch.distributed.FileStore(self.file_name, num_visible_devices)
+        torch.distributed.init_process_group(
+            backend=self.backend(device),
+            world_size=world_size,
+            rank=self.rank,
+            store=store,
+        )
+        if "nccl" in self.backend(device) or "xccl" in self.backend(device):
+            torch.accelerator.set_device_index(self.rank)
+        return torch.distributed.distributed_c10d._get_default_group()
+
+    def rank_to_device(self, device):
+        num_visible_devices = torch.get_device_module(device).device_count()
+        return {i: [i % num_visible_devices] for i in range(self.world_size)}
+
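+# Illustrative usage sketch (hypothetical subclass and test shown for clarity, not a
+# prescribed API): a device-generic test case may derive from DistributedTestBase,
+# create the process group for the device under test, and run a collective, e.g.
+#
+#     class MyDistributedTest(DistributedTestBase):
+#         @property
+#         def world_size(self):
+#             return 2
+#
+#         def test_allreduce(self, device):
+#             self.create_pg(device)
+#             t = torch.ones(1, device=device)
+#             torch.distributed.all_reduce(t)
+#             self.assertEqual(t.item(), self.world_size)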
+
+def run_subtests(
+    cls_inst,
+    subtest_config: dict[str, list[Any]],
+    test_fn: Callable,
+    *test_args,
+    **test_kwargs: Any,
+):
+    """
+    Runs a test function given by ``test_fn`` as a subtest according to the
+    configurations specified by ``subtest_config``. This amortizes the
+    costly setup overhead (including process spawn and initializing the
+    process group) over the subtests.
+
+    Args:
+        subtest_config (Dict[str, List[Any]]): A mapping from subtest
+            keyword argument name to a list of its possible values.
+        test_fn (Callable): A callable that runs the actual test.
+        test_args: Positional arguments to pass to ``test_fn``.
+        test_kwargs: Keyword arguments to pass to ``test_fn``.
+    """
+    # Convert the config mapping to a list to have a fixed order
+    subtest_config_items: list[tuple[str, list[Any]]] = list(subtest_config.items())
+    subtest_config_keys: list[str] = [item[0] for item in subtest_config_items]
+    subtest_config_values: list[list[Any]] = [item[1] for item in subtest_config_items]
+    for values in itertools.product(*subtest_config_values):
+        # Map keyword to chosen value
+        subtest_kwargs = dict(zip(subtest_config_keys, values))
+        with cls_inst.subTest(**subtest_kwargs):
+            torch._dynamo.reset()
+            test_fn(*test_args, **test_kwargs, **subtest_kwargs)
+            torch._dynamo.reset()
+        c10d.barrier()
+
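+# Illustrative usage sketch (hypothetical test body shown for clarity): a test method
+# on a MultiProcessTestCase subclass might call ``run_subtests`` roughly as
+#
+#     def test_fsdp_configs(self):
+#         run_subtests(
+#             self,
+#             {"use_orig_params": [False, True], "cpu_offload": [False, True]},
+#             self._test_fsdp_impl,
+#         )
+#
+# Each element of the cartesian product of the config values runs as a unittest
+# subtest, reusing the already-spawned processes and the initialized process group.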
+
+# Cannot use functools.cache as it requires python 3.9
+EFA_PROBE_RESULT = None
+
+
+def has_efa() -> bool:
+    """
+    If the shell command `fi_info -p efa -t FI_EP_RDM` returns exit code 0, then we assume that the machine has
+    Libfabric EFA interfaces and EFA software components installed;
+    see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/efa-start.html.
+    """
+    global EFA_PROBE_RESULT
+    if EFA_PROBE_RESULT is not None:
+        return EFA_PROBE_RESULT
+
+    try:
+        EFA_PROBE_RESULT = (
+            subprocess.run(
+                ["fi_info", "-p", "efa", "-t", "FI_EP_RDM"], check=False
+            ).returncode
+            == 0
+        )
+    except FileNotFoundError:
+        EFA_PROBE_RESULT = False
+    return EFA_PROBE_RESULT
+
+
+def tp_transports():
+    """
+    If the machine has Libfabric EFA interfaces and EFA software components installed, tensorpipe's
+    InfiniBand transport may cause
+    'RuntimeError: In operator() at tensorpipe/common/ibv.h:172 "": Operation not supported',
+    so we exclude it from the tensorpipe transports;
+    see https://github.com/pytorch/pytorch/issues/73885 and https://github.com/pytorch/pytorch/issues/65022
+    """
+    return ["shm", "uv"] if has_efa() else None
+
+
+def spawn_threads_and_init_comms(
+    func=None, timeout=TIMEOUT_DEFAULT, world_size=DEFAULT_WORLD_SIZE
+):
+    """
+    Wrapper to use with a test method
+    """
+    if func is None:
+        return partial(
+            spawn_threads_and_init_comms, timeout=timeout, world_size=world_size
+        )
+
+    def _run_test_method_with_multi_threads(world_size, callback):
+        world = _install_threaded_pg()
+        global_store = c10d.HashStore()
+
+        def world_is_valid():
+            return world == c10d.distributed_c10d._world
+
+        def worker(rank, world_pg, store):
+            c10d.init_process_group(
+                backend="threaded", rank=rank, world_size=world_size, store=store
+            )
+            try:
+                callback()
+            except BaseException as ex:
+                # Exceptions are handled in MultiThreadedTestCase
+                MultiThreadedTestCase.exception_queue.put((rank, sys.exc_info()))
+                ProcessLocalGroup.exception_handle(
+                    ex
+                )  # trigger _terminate event and awaken worker threads
+            finally:
+                if world_is_valid():
+                    c10d.destroy_process_group()
+
+        threads = []
+        for rank in range(world_size):
+            t = threading.Thread(target=worker, args=(rank, world, global_store))
+            t.start()
+            threads.append(t)
+
+        return threads
+
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        # TODO: get test name from kwargs
+        torch._C._distributed_c10d._set_thread_isolation_mode(True)
+        try:
+            threads = _run_test_method_with_multi_threads(
+                world_size, lambda: func(self, *args, **kwargs)
+            )
+            # join and error handling
+            MultiThreadedTestCase._join_threads(threads, func)
+        finally:
+            torch._C._distributed_c10d._set_thread_isolation_mode(False)
+
+    return wrapper
+
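+# Illustrative usage sketch (hypothetical test shown for clarity): the decorator spawns
+# ``world_size`` threads, initializes the "threaded" process group in each, and runs the
+# wrapped test method once per rank, e.g.
+#
+#     class MyCollectiveTest(TestCase):
+#         @spawn_threads_and_init_comms(world_size=4)
+#         def test_broadcast(self):
+#             t = torch.arange(2.0) if c10d.get_rank() == 0 else torch.zeros(2)
+#             c10d.broadcast(t, src=0)
+#             self.assertEqual(t.tolist(), [0.0, 1.0])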
+
+class MultiThreadedTestCase(TestCase):
+    """
+    Test runner that runs all tests with multiple threads, using the in-process
+    ("threaded") process group.
+
+    Each test spawns world_size threads and runs the test method in each thread.
+
+    Differences from the regular MultiProcess test runner:
+    Subclasses must explicitly define setUp and call self._spawn_threads() to run the tests.
+    setUp / tearDown run only in the main thread; use perThreadSetUp / perThreadTearDown
+        to set up / tear down each thread when running each test.
+    No global state is possible.
+        How bad of a limitation is this?
+    """
+
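+    # Illustrative usage sketch (hypothetical subclass shown for clarity):
+    #
+    #     class MyThreadedTest(MultiThreadedTestCase):
+    #         @property
+    #         def world_size(self):
+    #             return 2
+    #
+    #         def setUp(self):
+    #             super().setUp()
+    #             self._spawn_threads()
+    #
+    #         def test_allreduce(self):
+    #             t = torch.ones(2) * self.rank
+    #             c10d.all_reduce(t)
+    #             self.assertEqualOnRank(t.sum().item(), 2.0)
+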
+    exception_queue = queue.Queue()
+
+    MAIN_THREAD_RANK = -1
+
+    def join_or_run(self, fn):
+        @wraps(fn)
+        def wrapper(self):
+            if self.rank == self.MAIN_THREAD_RANK:
+                self._join_threads(self.threads, fn)
+            else:
+                fn()
+
+        return types.MethodType(wrapper, self)
+
+    def __init__(
+        self, method_name: str = "runTest", methodName: str = "runTest"
+    ) -> None:
+        # methodName is the correct naming in unittest, and testslide uses keyword arguments.
+        # So we need to use both to 1) not break BC and 2) support testslide.
+        if methodName != "runTest":
+            method_name = methodName
+        super().__init__(method_name)
+        try:
+            fn = getattr(self, method_name)
+            setattr(self, method_name, self.join_or_run(fn))
+        except AttributeError as e:
+            if methodName != "runTest":
+                # we allow instantiation with no explicit method name
+                # but not an *incorrect* or missing method name
+                raise ValueError(
+                    f"no such test method in {self.__class__}: {methodName}"
+                ) from e
+
+    def perThreadSetUp(self):
+        # super().setUp()  # TestCase.setUp() calls torch.manual_seed()
+        pass
+
+    def perThreadTearDown(self):
+        pass
+
+    def setUp(self) -> None:
+        """
+        setUp only sets up things in the main thread; if you want to configure things
+        in the spawned threads, use perThreadSetUp.
+        """
+        super().setUp()
+        self.rank = self.MAIN_THREAD_RANK
+        self.threads = []
+        # Show full C++ stacktraces when a Python error originating from C++ is raised.
+        os.environ["TORCH_SHOW_CPP_STACKTRACES"] = "1"
+
+    def tearDown(self):
+        """
+        tearDown only tears down things in the main thread; if you want to configure
+        things in the spawned threads, use perThreadTearDown.
+        """
+        super().tearDown()
+        self.threads = []
+
+    def _spawn_threads(self):
+        """
+        Spawns threads and runs the test; use this method in the setUp of your TestCase.
+        """
+        torch._C._distributed_c10d._set_thread_isolation_mode(True)
+        test_name = self._current_test_name
+        # for each test case, we need to create thread local world, and a global store
+        world = _install_threaded_pg()
+        self.__class__.global_store = c10d.HashStore()
+
+        def world_is_valid():
+            return world == c10d.distributed_c10d._world
+
+        if not world_is_valid():
+            raise RuntimeError("Invalid world")
+
+        for rank in range(self.world_size):
+            t = threading.Thread(
+                target=self.__class__._run, args=(test_name, rank, self.world_size)
+            )
+            t.start()
+            self.threads.append(t)
+
+    @classmethod
+    def _run(cls, test_name, rank, world_size, **kwargs):
+        self = cls(test_name)
+        self.rank = rank
+
+        # precision/rel_tol is a thread-local setting since it may be overridden per test;
+        # we need every thread to have the same value. This is relevant for op db tests,
+        # which need those states to be set, e.g. when using instantiate_device_type_tests().
+        # TODO: figure out a better way to do this
+        if hasattr(self, "_tls"):
+            self._tls = threading.local()
+            self._tls.precision = TestCase._precision
+            self._tls.rel_tol = TestCase._rel_tol
+
+        self.run_test_with_threaded_pg(test_name, rank, world_size)
+
+    def run_test_with_threaded_pg(self, test_name, rank, world_size):
+        """
+        Run the current test associated with `test_name` using the threaded process group.
+        """
+        c10d.init_process_group(
+            backend="threaded",
+            rank=rank,
+            world_size=world_size,
+            store=self.__class__.global_store,
+        )
+        self.perThreadSetUp()
+
+        try:
+            getattr(self, test_name)()
+        except BaseException as ex:
+            self.exception_queue.put((rank, sys.exc_info()))
+            ProcessLocalGroup.exception_handle(
+                ex
+            )  # trigger _terminate event and awaken worker threads
+        finally:
+            c10d.destroy_process_group()
+            self.perThreadTearDown()
+
+    @classmethod
+    def _join_threads(cls, threads, fn):
+        timeout = TIMEOUT_DEFAULT
+        try:
+            for idx, thread in enumerate(threads):
+                thread.join(max(0, timeout))
+                if thread.is_alive():
+                    MultiThreadedTestCase.exception_queue.put(
+                        (
+                            idx,
+                            (
+                                TimeoutError,
+                                TimeoutError(
+                                    f"Rank failed to join in under {timeout} seconds"
+                                ),
+                                None,
+                            ),
+                        )
+                    )
+            ProcessLocalGroup.reset()
+            failed_ranks = []
+            while not cls.exception_queue.empty():
+                failure = cls.exception_queue.get()
+                failed_ranks.append(failure)
+        finally:
+            _uninstall_threaded_pg()
+            torch._C._distributed_c10d._set_thread_isolation_mode(False)
+
+        cls._check_return_codes(failed_ranks, timeout, fn)
+
+    @classmethod
+    def _check_return_codes(cls, failed_ranks, timeout, fn):
+        # Print based on exceptions raised from threads
+        #   SkipTest: print info for each thread
+        #   TimeoutError: raise RuntimeError for any timed out thread
+        #   Normal Exception: print error for each thread that raises exception
+        #   and raise a RuntimeError
+        error_msg = ""
+        skip_code = -1
+        for rank, exc_info in failed_ranks:
+            exc = exc_info[1]
+            if isinstance(exc, unittest.SkipTest):
+                logger.info(
+                    "Thread %s skipping test %s for following reason: %s",
+                    rank,
+                    fn,
+                    str(exc),
+                )
+                if skip_code < 0:
+                    skip_code = TEST_SKIPS["generic"].exit_code
+            elif isinstance(exc, TimeoutError):
+                msg = f"Thread {rank} terminated or timed out after {timeout} seconds\n"
+                logger.error(msg)
+                raise RuntimeError(msg)
+            elif isinstance(exc, Exception):
+                msg = "".join(traceback.format_exception(*exc_info))
+                logger.error("Caught exception: \n%s exiting thread %s", msg, rank)
+                error_msg += f"Thread {rank} exited with exception:\n{msg}\n"
+            elif isinstance(exc, SystemExit):
+                if type(exc.code) == int and skip_code < 0:
+                    skip_code = exc.code
+
+        # check exceptions
+        if len(error_msg) > 0:
+            raise RuntimeError(error_msg)
+        # check skip
+        if skip_code > 0:
+            for skip in TEST_SKIPS.values():
+                if skip_code == skip.exit_code:
+                    if IS_SANDCASTLE:
+                        # "pass" the test with an appropriate message.
+                        logger.info(
+                            "Skipping %s on sandcastle for the following reason: %s",
+                            fn,
+                            skip.message,
+                        )
+                        return
+                    else:
+                        raise unittest.SkipTest(skip.message)
+
+    @property
+    def world_size(self) -> int:
+        return DEFAULT_WORLD_SIZE
+
+    @property
+    def _current_test_name(self) -> str:
+        # self.id() == e.g. '__main__.TestDistributed.TestAdditive.test_get_rank'
+        return self.id().split(".")[-1]
+
+    def assertEqualOnRank(self, x, y, msg=None, *, rank=0):
+        """
+        The reason why we have this util function instead of
+        self.assertEqual is that all threads share one CPU RNG,
+        so the assertion result is only reliable on rank 0.
+        """
+        if self.rank == rank:
+            self.assertEqual(x, y, msg)
+
+    def assertNotEqualOnRank(self, x, y, msg=None, *, rank=0):
+        if self.rank == rank:
+            self.assertNotEqual(x, y, msg)
+
+
+class SaveForwardInputsModule(nn.Module):
+    def __init__(
+        self,
+        forward_inputs: dict[nn.Module, torch.Tensor],
+        cast_forward_inputs: bool,
+    ) -> None:
+        super().__init__()
+        self.l = nn.Linear(100, 100)
+        self.forward_inputs = forward_inputs
+        self.cast_forward_inputs = cast_forward_inputs
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        self.forward_inputs[self] = x
+        return self.l(x.to(self.l.weight.dtype) if self.cast_forward_inputs else x)
+
+
+class SaveForwardInputsModel(nn.Module):
+    def __init__(
+        self,
+        forward_inputs: dict[nn.Module, torch.Tensor],
+        cast_forward_inputs: bool,
+    ) -> None:
+        super().__init__()
+        self.c1 = SaveForwardInputsModule(forward_inputs, cast_forward_inputs)
+        self.c2 = SaveForwardInputsModule(forward_inputs, cast_forward_inputs)
+        self.forward_inputs = forward_inputs
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        self.forward_inputs[self] = x
+        return self.c2(self.c1(x))
+
+
+@contextmanager
+def _dynamo_dist_per_rank_init(
+    rank, world_size, backend="nccl", init_pg=True, fake_pg=False
+):
+    # To avoid multiple inheritance from _dynamo.test_case.TestCase and MultiProcessTestCase,
+    # just manually implement the most important parts of the dynamo behavior to reset/clear.
+    if not fake_pg:
+        torch.accelerator.set_device_index(rank)
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "6789"
+    if init_pg:
+        if fake_pg:
+            store = torch.testing._internal.distributed.fake_pg.FakeStore()
+            c10d.init_process_group(
+                backend="fake",
+                world_size=world_size,
+                rank=rank,
+                store=store,
+            )
+        else:
+            c10d.init_process_group(backend=backend, rank=rank, world_size=world_size)
+    torch._dynamo.reset()
+    torch._dynamo.utils.counters.clear()
+    try:
+        yield
+    finally:
+        torch._dynamo.reset()
+        torch._dynamo.utils.counters.clear()
+        if init_pg:
+            c10d.destroy_process_group()
+
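+# Illustrative usage sketch (hypothetical per-rank entry point shown for clarity): a
+# multi-process dynamo test can wrap its per-rank body in this context manager, which
+# sets MASTER_ADDR/MASTER_PORT, initializes the process group, and resets dynamo state
+# on entry and exit, e.g.
+#
+#     def _per_rank_main(rank, world_size):
+#         with _dynamo_dist_per_rank_init(rank, world_size, backend="nccl"):
+#             model = torch.nn.Linear(8, 8, device=f"cuda:{rank}")
+#             compiled = torch.compile(model)
+#             compiled(torch.randn(4, 8, device=f"cuda:{rank}"))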
+
+class DynamoDistributedSingleProcTestCase(torch._dynamo.test_case.TestCase):
+    """
+    Test harness for single-process dynamo distributed tests,
+    initializes dist process group.
+
+    Prefer this for simple tests, as it's easier to debug.
+    """
+
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+        # _exit_stack is set up in TestCase
+        cls._exit_stack.enter_context(
+            patch.dict(
+                os.environ,
+                {
+                    "MASTER_ADDR": "localhost",
+                    "MASTER_PORT": "12355",
+                },
+            )
+        )
+        cls.rank = 0
+        cls.device = f"cuda:{cls.rank}"
+        cls.device_ids = None if "cuda" in cls.device else [cls.rank]
+        c10d.init_process_group("nccl", rank=cls.rank, world_size=1)
+
+    @classmethod
+    def tearDownClass(cls):
+        c10d.destroy_process_group()
+        super().tearDownClass()
+
+
+class DynamoDistributedMultiProcTestCase(DistributedTestBase):
+    """
+    Use this for tests that actually run on multiple GPUs.
+
+    Decorate tests with @skip_if_lt_x_gpu(ngpu)
+
+    Note: MultiProcessTestCase spawns processes per test and is slow.
+    Prefer MultiThreadedTestCase for most tests. Perhaps use this one
+    sparingly for integration tests.
+    """
+
+    @property
+    def world_size(self) -> int:
+        return torch.accelerator.device_count()
+
+    @classmethod
+    def _run(
+        cls, rank: int, test_name: str, file_name: str, parent_pipe, **kwargs
+    ) -> None:
+        trace_log.addHandler(logging.NullHandler())
+
+        # The rest is copypasta from MultiProcessTestCase._run
+        self = cls(test_name)
+        self.rank = rank
+        self.file_name = file_name
+        self.run_test(test_name, parent_pipe)
+
+
+class MultiProcContinousTest(TestCase):
+    # Class variables:
+    MAIN_PROCESS_RANK = -1
+    # number of test processes
+    world_size: int = -2  # unset state
+    # rank of the current process
+    rank: int = -2  # unset state
+    # Rendezvous file
+    rdvz_file: Optional[str] = None
+    # timeout configured per class
+    timeout: timedelta = timedelta(seconds=120)
+    # Poison pill for the rest of the tests if one of them fails
+    poison_pill: bool = False
+
+    @classmethod
+    def backend_str(cls) -> Optional[str]:
+        """
+        ProcessGroup backend string.
+        To be customized by sub test classes, e.g. "nccl".
+        Otherwise we return None -- the backend is then decided lazily from the tensor device.
+        """
+        return None
+
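+    # Illustrative override sketch (hypothetical subclass shown for clarity): a test
+    # class reusing one process group across all of its tests might look like
+    #
+    #     class MyContinuousNCCLTest(MultiProcContinousTest):
+    #         @classmethod
+    #         def backend_str(cls) -> str:
+    #             return "nccl"
+    #
+    #         def test_allreduce(self):
+    #             t = torch.ones(1, device=f"cuda:{self.rank}")
+    #             c10d.all_reduce(t)
+    #             self.assertEqual(t.item(), self.world_size)
+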
+    # Please override if you intend to test on a specific device type
+    @classmethod
+    def device_type(cls) -> str:
+        curr_device = torch.accelerator.current_accelerator()
+        if curr_device is None:
+            return "cpu"
+        return curr_device.type
+
+    @classmethod
+    def opts(cls, high_priority_stream=False):
+        """
+        ProcessGroup init options.
+        To be customized by sub test classes, e.g. ProcessGroupNCCLOpTest.
+        Here we return None.
+        """
+        return None
+
+    @classmethod
+    def _init_pg(cls, rank, world_size, rdvz_file):
+        assert rdvz_file is not None
+        store = c10d.FileStore(rdvz_file, world_size)
+
+        # create the process group with the configured backend and options
+        c10d.init_process_group(
+            backend=cls.backend_str(),
+            world_size=world_size,
+            rank=rank,
+            store=store,
+            pg_options=cls.opts(),
+            timeout=cls.timeout,
+        )
+        cls.pg = c10d.distributed_c10d._get_default_group()
+
+    @classmethod
+    def _run_test_given_id(cls, test_id: str, **kwargs) -> None:
+        # self.id() == e.g. '__main__.TestDistributed.TestAdditive.test_get_rank'
+        test_name = test_id.split(".")[-1]
+        # Get the test function from the test class
+        self = cls(test_name)
+        self.rank = cls.rank
+        self.world_size = cls.world_size
+        test_fn = getattr(self, test_name)
+        # Run the test function
+        test_fn(**kwargs)
+
+    @classmethod
+    def _worker_loop(cls, rank, world_size, rdvz_file, task_queue, completion_queue):
+        # Subtests are going to access these values, so validate them first
+        assert 0 <= rank < world_size
+        # set class variables for the test class
+        cls.rank = rank
+        cls.world_size = world_size
+
+        # Initialize the process group
+        cls._init_pg(rank, world_size, rdvz_file)
+
+        # End of bootstrap
+        logger.info("Setup complete")
+
+        # Loop forever, waiting for a test name to run
+        while True:
+            test_id = task_queue.get()
+            logger.debug(f"Got test {test_id}")  # noqa: G004
+            # None means exit
+            if test_id is None:
+                break
+
+            # Run the test
+            try:
+                cls._run_test_given_id(test_id)
+                completion_queue.put(test_id)
+            except BaseException as ex:
+                # Send the exception back to the dispatcher
+                completion_queue.put(ex)
+
+        # Termination
+        logger.info("Terminating ...")
+        c10d.destroy_process_group()
+
+    @classmethod
+    def _spawn_processes(cls, world_size) -> None:
+        cls.processes = []
+        cls.task_queues = []
+        cls.completion_queues = []
+        # Need a rendezvous file for `init_process_group` purposes.
+        cls.rdvz_file = tempfile.NamedTemporaryFile(delete=False).name
+
+        # CUDA multiprocessing requires spawn instead of fork, to make sure
+        # child processes have their own memory space.
+        try:
+            torch.multiprocessing.set_start_method("spawn")
+        except RuntimeError:
+            # The start method has already been set
+            pass
+
+        for rank in range(int(world_size)):
+            task_queue = torch.multiprocessing.Queue()
+            completion_queue = torch.multiprocessing.Queue()
+            process = torch.multiprocessing.Process(
+                target=cls._worker_loop,
+                name="process " + str(rank),
+                daemon=True,  # so that child processes will exit if the parent decides to terminate
+                args=(rank, world_size, cls.rdvz_file, task_queue, completion_queue),
+            )
+            process.start()
+            cls.processes.append(process)
+            cls.task_queues.append(task_queue)
+            cls.completion_queues.append(completion_queue)
+            logger.info(
+                "Started process %s with pid %s", rank, process.pid
+            )  # noqa: UP031
+
+    @classmethod
+    def setUpClass(cls):
+        """
+        Class-scope test fixture. Run once for the entire test class, before any test starts.
+        Set up the process group.
+        """
+        super().setUpClass()
+
+        device_type = cls.device_type()
+        # If world_size is not set, use the device count as the world size
+        if cls.world_size == -2:
+            cls.world_size = torch.get_device_module(device_type).device_count()
+            if cls.world_size == 0:
+                raise unittest.SkipTest(f"No {device_type} devices available")
+
+        logger.info(
+            f"Testing class {cls.__name__} on {cls.world_size} {device_type}"  # noqa: G004
+        )
+
+        cls._spawn_processes(cls.world_size)
+
+    @classmethod
+    def tearDownClass(cls):
+        """
+        Class-scope test fixture. Run once for the entire test class, after all tests finish.
+        Tear down the process group.
+        """
+        logger.debug(f"Joining {cls.world_size} workers")  # noqa: G004
+        # Enqueue "None" to all workers to tell them to exit
+        for task_queue in cls.task_queues:
+            task_queue.put(None)
+
+        # Wait for all workers to exit
+        for process in cls.processes:
+            process.join()
+
+        # Clean up the rendezvous file
+        try:
+            os.remove(cls.rdvz_file)
+        except OSError:
+            pass
+
+        logger.info(f"Class {cls.__name__} finished")  # noqa: G004
+        super().tearDownClass()
+
+    def setUp(self) -> None:
+        """
+        Test fixture. Run before each test.
+        """
+        super().setUp()
+
+        # I am the dispatcher
+        self.rank = self.MAIN_PROCESS_RANK
+
+        # If this test class hit an exception in one test, skip the rest of the tests
+        if self.__class__.poison_pill:
+            raise unittest.SkipTest(f"Previous test failed, skipping {self.id()}")
+
+        # Enqueue "current test" to all workers
+        for i, task_queue in enumerate(self.task_queues):
+            logger.debug(f"Sending Rank {i}: {self.id()}")  # noqa: G004
+            task_queue.put(self.id())
+
+    def _worker_run_main_wait(self, fn):
+        @wraps(fn)
+        def wrapper(self):
+            if self.rank == self.MAIN_PROCESS_RANK:
+                logger.debug(f"Waiting for workers to finish {self.id()}")  # noqa: G004
+                # Wait for the workers to finish the test
+                for i, completion_queue in enumerate(self.completion_queues):
+                    rv = completion_queue.get()
+                    if isinstance(rv, BaseException):
+                        # Hit an exception, re-raise it in the main process.
+                        logger.warning(
+                            f"Detected failure from Rank {i} in: {self.id()}, "  # noqa: G004
+                            f"skipping rest of tests in Test class: {self.__class__.__name__}"  # noqa: G004
+                        )
+                        # Poison the rest of the tests (because the ProcessGroup
+                        # may not be reusable now)
+                        self.__class__.poison_pill = True
+                        raise rv
+
+                    # Success
+                    assert rv == self.id()
+                    logger.debug(
+                        f"Main proc detected rank {i} finished {self.id()}"  # noqa: G004
+                    )
+            else:
+                # Worker just runs the test
+                fn()
+
+        return types.MethodType(wrapper, self)
+
+    # The main process spawns N subprocesses that run the test.
+    # The constructor patches the current instance's test method to either
+    # assume the role of the main process and join its subprocesses,
+    # or run the underlying test function.
+    def __init__(
+        self, method_name: str = "runTest", methodName: str = "runTest"
+    ) -> None:
+        # methodName is the correct naming in unittest, and testslide uses keyword arguments.
+        # So we need to use both to 1) not break BC and 2) support testslide.
+        if methodName != "runTest":
+            method_name = methodName
+        super().__init__(method_name)
+        try:
+            fn = getattr(self, method_name)
+            setattr(self, method_name, self._worker_run_main_wait(fn))
+        except AttributeError as e:
+            if methodName != "runTest":
+                # we allow instantiation with no explicit method name
+                # but not an *incorrect* or missing method name
+                raise ValueError(
+                    f"no such test method in {self.__class__}: {methodName}"
+                ) from e
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_dtype.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_dtype.py
new file mode 100644
index 0000000000000000000000000000000000000000..774ce179f33e0f386d66ee90bd22576fd55da994
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_dtype.py
@@ -0,0 +1,214 @@
+# mypy: ignore-errors
+
+
+import torch
+
+
+# Functions and classes for describing the dtypes a function supports
+# NOTE: these helpers should correspond to PyTorch's C++ dispatch macros
+
+
+# Verifies each given dtype is a torch.dtype
+def _validate_dtypes(*dtypes):
+    for dtype in dtypes:
+        assert isinstance(dtype, torch.dtype)
+    return dtypes
+
+
+# class for tuples corresponding to a PyTorch dispatch macro
+class _dispatch_dtypes(tuple):
+    __slots__ = ()
+
+    def __add__(self, other):
+        assert isinstance(other, tuple)
+        return _dispatch_dtypes(tuple.__add__(self, other))
+
+
+_empty_types = _dispatch_dtypes(())
+
+
+def empty_types():
+    return _empty_types
+
+
+_floating_types = _dispatch_dtypes((torch.float32, torch.float64))
+
+
+def floating_types():
+    return _floating_types
+
+
+_floating_types_and_half = _floating_types + (torch.half,)
+
+
+def floating_types_and_half():
+    return _floating_types_and_half
+
+
+def floating_types_and(*dtypes):
+    return _floating_types + _validate_dtypes(*dtypes)
+
+
+_floating_and_complex_types = _floating_types + (torch.cfloat, torch.cdouble)
+
+
+def floating_and_complex_types():
+    return _floating_and_complex_types
+
+
+def floating_and_complex_types_and(*dtypes):
+    return _floating_and_complex_types + _validate_dtypes(*dtypes)
+
+
+_double_types = _dispatch_dtypes((torch.float64, torch.complex128))
+
+
+def double_types():
+    return _double_types
+
+
+# NB: Does not contain uint16/uint32/uint64 for BC reasons
+_integral_types = _dispatch_dtypes(
+    (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
+)
+
+
+def integral_types():
+    return _integral_types
+
+
+def integral_types_and(*dtypes):
+    return _integral_types + _validate_dtypes(*dtypes)
+
+
+_all_types = _floating_types + _integral_types
+
+
+def all_types():
+    return _all_types
+
+
+def all_types_and(*dtypes):
+    return _all_types + _validate_dtypes(*dtypes)
+
+
+_complex_types = _dispatch_dtypes((torch.cfloat, torch.cdouble))
+
+
+def complex_types():
+    return _complex_types
+
+
+def complex_types_and(*dtypes):
+    return _complex_types + _validate_dtypes(*dtypes)
+
+
+_all_types_and_complex = _all_types + _complex_types
+
+
+def all_types_and_complex():
+    return _all_types_and_complex
+
+
+def all_types_and_complex_and(*dtypes):
+    return _all_types_and_complex + _validate_dtypes(*dtypes)
+
+
+_all_types_and_half = _all_types + (torch.half,)
+
+
+def all_types_and_half():
+    return _all_types_and_half
+
+
+_float8_types = _dispatch_dtypes(
+    (
+        torch.float8_e4m3fn,
+        torch.float8_e4m3fnuz,
+        torch.float8_e5m2,
+        torch.float8_e5m2fnuz,
+    )
+)
+
+
+def float8_types():
+    return _float8_types
+
+
+def float8_types_and(*dtypes):
+    return _float8_types + _validate_dtypes(*dtypes)
+
+
+def all_types_complex_float8_and(*dtypes):
+    return _all_types + _complex_types + _float8_types + _validate_dtypes(*dtypes)
+
+
+def custom_types(*dtypes):
+    """Create a list of arbitrary dtypes"""
+    return _empty_types + _validate_dtypes(*dtypes)
+
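+# Illustrative composition sketch (shown for clarity): the dispatch tuples compose via
+# concatenation, mirroring the C++ macros, e.g.
+#
+#     floating_types_and(torch.half, torch.bfloat16)
+#     # -> (torch.float32, torch.float64, torch.float16, torch.bfloat16)
+#
+#     all_types_and_complex_and(torch.bool)
+#     # -> all integral, floating, and complex dtypes, plus torch.bool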
+
+# The functions below are used for convenience in our test suite and thus have no corresponding C++ dispatch macro
+
+
+# See AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS.
+def get_all_dtypes(
+    include_half=True,
+    include_bfloat16=True,
+    include_bool=True,
+    include_complex=True,
+    include_complex32=False,
+    include_qint=False,
+) -> list[torch.dtype]:
+    dtypes = get_all_int_dtypes() + get_all_fp_dtypes(
+        include_half=include_half, include_bfloat16=include_bfloat16
+    )
+    if include_bool:
+        dtypes.append(torch.bool)
+    if include_complex:
+        dtypes += get_all_complex_dtypes(include_complex32)
+    if include_qint:
+        dtypes += get_all_qint_dtypes()
+    return dtypes
+
+
+def get_all_math_dtypes(device) -> list[torch.dtype]:
+    return (
+        get_all_int_dtypes()
+        + get_all_fp_dtypes(
+            include_half=device.startswith("cuda"), include_bfloat16=False
+        )
+        + get_all_complex_dtypes()
+    )
+
+
+def get_all_complex_dtypes(include_complex32=False) -> list[torch.dtype]:
+    return (
+        [torch.complex32, torch.complex64, torch.complex128]
+        if include_complex32
+        else [torch.complex64, torch.complex128]
+    )
+
+
+def get_all_int_dtypes() -> list[torch.dtype]:
+    return [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]
+
+
+def get_all_fp_dtypes(include_half=True, include_bfloat16=True) -> list[torch.dtype]:
+    dtypes = [torch.float32, torch.float64]
+    if include_half:
+        dtypes.append(torch.float16)
+    if include_bfloat16:
+        dtypes.append(torch.bfloat16)
+    return dtypes
+
+
+def get_all_qint_dtypes() -> list[torch.dtype]:
+    return [torch.qint8, torch.quint8, torch.qint32, torch.quint4x2, torch.quint2x4]
+
+
+float_to_corresponding_complex_type_map = {
+    torch.float16: torch.complex32,
+    torch.float32: torch.complex64,
+    torch.float64: torch.complex128,
+}
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_fsdp.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_fsdp.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd8ccdcd2ff36480618685dc40f666828c3684a6
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_fsdp.py
@@ -0,0 +1,1576 @@
+# mypy: allow-untyped-defs
+# Owner(s): ["oncall: distributed"]
+
+import contextlib
+import os
+import re
+import sys
+import time
+import warnings
+from abc import ABC, abstractmethod
+from contextlib import nullcontext
+from copy import deepcopy
+from enum import auto, Enum
+from functools import wraps
+from typing import Any, Callable, cast, no_type_check, Optional, Union
+from unittest import mock
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.distributed._composable import checkpoint
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.fsdp import (
+    CPUOffload,
+    fully_shard,
+    FullyShardedDataParallel as FSDP,
+)
+from torch.distributed.fsdp._common_utils import TrainingState
+from torch.distributed.fsdp._fully_shard._fsdp_param_group import (
+    FSDPParamGroup,
+    RegisterPostBackwardFunction,
+)
+from torch.distributed.fsdp._init_utils import NO_RESHARD_AFTER_FORWARD_STRATEGIES
+from torch.distributed.fsdp.fully_sharded_data_parallel import (
+    BackwardPrefetch,
+    MixedPrecision,
+    ShardingStrategy,
+)
+from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+from torch.distributed.fsdp.wrap import always_wrap_policy, ModuleWrapPolicy, wrap
+from torch.distributed.tensor import distribute_tensor, DTensor, Shard
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    parallelize_module,
+    RowwiseParallel,
+    SequenceParallel,
+)
+from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer
+from torch.nn.parallel.distributed import DistributedDataParallel as DDP
+from torch.testing._internal.common_distributed import (
+    MultiProcessTestCase,
+    MultiThreadedTestCase,
+    run_subtests,
+    TEST_SKIPS,
+)
+from torch.testing._internal.common_utils import (
+    FILE_SCHEMA,
+    get_cycles_per_ms,
+    TEST_CUDA,
+    TEST_HPU,
+    TEST_XPU,
+)
+from torch.utils._triton import has_triton
+
+
+DEVICE_COUNT = 4  # default
+
+if TEST_CUDA:
+    DEVICE_TYPE = "cuda"
+    DISTRIBUTED_BACKEND = "nccl"
+    DEVICE_COUNT = torch.cuda.device_count()
+elif TEST_HPU:
+    DEVICE_TYPE = "hpu:0"
+    DISTRIBUTED_BACKEND = "hccl"
+elif TEST_XPU:
+    DEVICE_TYPE = "xpu"
+    DISTRIBUTED_BACKEND = "xccl"
+    DEVICE_COUNT = torch.xpu.device_count()
+else:
+    DEVICE_TYPE = "cpu"
+    DISTRIBUTED_BACKEND = "gloo"
+    DEVICE_COUNT = 1
+
+
+class FSDPInitMode(Enum):
+    # No FSDP wrapping
+    NO_FSDP = auto()
+    # FSDP recursive wrapping
+    RECURSIVE = auto()
+    # TODO: FSDP non-recursive wrapping
+    # NONRECURSIVE = auto()
+
+
+class DEVICEInitMode(Enum):
+    # Move model to DEVICE before passing to the FSDP constructor
+    DEVICE_BEFORE = auto()
+    # Move model to DEVICE after passing to the FSDP constructor
+    DEVICE_AFTER = auto()
+    # Keep on CPU
+    DEVICE_NEVER = auto()
+
+
+class FSDPTestModel(nn.Module, ABC):
+    """This defines the interface expected from all models used commonly for
+    FSDP unit tests."""
+
+    @abstractmethod
+    def get_input(self, device) -> tuple[torch.Tensor, ...]:
+        """Returns an input for the model as as tuple."""
+        ...
+
+    @abstractmethod
+    def get_loss(self, input, output) -> torch.Tensor:
+        """Returns the loss given the input and output."""
+        ...
+
+    @abstractmethod
+    def run_backward(self, loss) -> None:
+        """Runs the backward pass (e.g. including ``loss.backward()``)."""
+        ...
+
+    @staticmethod
+    @abstractmethod
+    def init(*args: Any, **kwargs: Any) -> nn.Module:
+        """Initializes an instance of this model."""
+        ...
+
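+# Illustrative interface sketch (hypothetical model shown for clarity): a minimal
+# FSDPTestModel implementation provides the hooks used by the shared test code,
+# roughly as
+#
+#     class TinyModel(FSDPTestModel):
+#         def __init__(self):
+#             super().__init__()
+#             self.lin = nn.Linear(4, 4)
+#
+#         def forward(self, x):
+#             return self.lin(x)
+#
+#         def get_input(self, device):
+#             return (torch.rand(2, 4, device=device),)
+#
+#         def get_loss(self, input, output):
+#             return output.sum()
+#
+#         def run_backward(self, loss):
+#             loss.backward()
+#
+#         @staticmethod
+#         def init(*args, **kwargs):
+#             return TinyModel()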
+
+def _assert_module_states(
+    model: nn.Module,
+    process_group: dist.ProcessGroup,
+    assert_fn: Callable,
+):
+    """
+    All-gathers module states across ranks and calls ``assert_fn`` on each pair
+    of corresponding states from rank 0 and a nonzero rank. For example, if
+    ``assert_fn`` is ``self.assertEqual()``, then this checks that all module
+    states are equal across ranks.
+    """
+    # Include names for debugging convenience
+    named_module_states = [
+        (param_name, param.detach().cpu())
+        for param_name, param in model.named_parameters()
+    ]
+    named_module_states += [
+        (buffer_name, buffer.detach().cpu())
+        for buffer_name, buffer in model.named_buffers()
+    ]
+    world_size = dist.get_world_size(process_group)
+    olist = [None for _ in range(world_size)]
+    dist.all_gather_object(olist, named_module_states, group=process_group)
+    rank0_states = olist[0]
+    assert rank0_states is not None  # mypy
+    for state in olist[1:]:
+        assert state is not None  # mypy
+        for (_, p1), (_, p2) in zip(rank0_states, state):
+            assert_fn(p1, p2)
+
+
+def get_devtype():
+    return torch.device(DEVICE_TYPE)
+
+
+def _zero_model(
+    model: nn.Module,
+    zero_buffers: bool = False,
+    summon_full=True,
+):
+    """Zeros the parameters and optionally buffers of ``model`` in place."""
+    ctx = FSDP.summon_full_params(model) if summon_full else nullcontext()
+    with ctx:
+        for param in model.parameters():
+            with torch.no_grad():
+                param.zero_()
+        if zero_buffers:
+            for buffer in model.buffers():
+                with torch.no_grad():
+                    buffer.zero_()
+
+
+def _get_state_dict(model, cpu_offload=False, half=False):
+    if not cpu_offload:
+        model = model.to(DEVICE_TYPE)
+    if half:
+        model.half()
+
+    return model.state_dict()
+
+
+def subtest_name(test_name_mapping, *args):
+    return "_".join(
+        [test_name_mapping[str(s)] if s is not None else "none" for s in args]
+    )
+
+
+def _broadcast_state_dict(rank, state_dict):
+    # For non-FSDP roots, some parts of the model state on rank 0 may
+    # not be on CPU, so we move everything to CPU to avoid issues like:
+    # https://github.com/pytorch/pytorch/issues/77113.
+    for param_name, param in state_dict.items():
+        if param.device != torch.device("cpu"):
+            state_dict[param_name] = param.cpu()
+
+    olist = [state_dict if rank == 0 else None]
+    dist.broadcast_object_list(olist)
+    state_dict = cast(dict[str, torch.Tensor], olist[0])
+    # Ensure that the state is on DEVICE
+    for param_name in state_dict.keys():
+        state_dict[param_name] = state_dict[param_name].to(DEVICE_TYPE)
+    return state_dict
+
+
+def get_full_params(model: nn.Module, recurse: bool = True):
+    """
+    Returns the full unsharded parameters of ``model``. Any FSDP-managed
+    parameters offloaded to CPU are moved to GPU in the returned list.
+
+    Args:
+        recurse (bool): If ``False``, only unshards the parameters immediate to
+            ``model``; if ``True``, recurses through the module hierarchy
+            rooted at ``model``.
+    """
+    with FSDP.summon_full_params(model, recurse=recurse):
+        return deepcopy(list(model.parameters()))
+
+
+def _move_to_device(model: nn.Module, move_to_device: bool):
+    return model.to(DEVICE_TYPE) if move_to_device else model
+
+
+def _maybe_wrap_fsdp(model: nn.Module, wrap_fsdp: bool, *args, **kwargs):
+    return model if not wrap_fsdp else FSDP(model, *args, **kwargs)
+
+
+class DummyProcessGroup:
+    def __init__(self, rank: int, size: int):
+        self._rank = rank
+        self._size = size
+
+    def rank(self) -> int:
+        return self._rank
+
+    def size(self) -> int:
+        return self._size
+
+    def allreduce(self, *args, **kwargs):
+        dist_wait = mock.Mock()
+
+        def get_future():
+            future: torch.futures.Future = torch.futures.Future()
+            future.set_result(1)
+            return future
+
+        dist_wait.get_future = get_future
+        return dist_wait
+
+
+class TransformerWithSharedParams(FSDPTestModel):
+    def __init__(
+        self,
+        group: dist.ProcessGroup,
+        device_init_mode: DEVICEInitMode,
+        add_bn: bool,
+        deterministic: bool,
+    ):
+        super().__init__()
+        self.rank = group.rank()
+        self.world_size = group.size()
+        if deterministic:
+            torch.manual_seed(0)
+        d_vocab = 23
+        d_model = 16
+
+        self.embed_tokens = nn.Embedding(d_vocab, d_model)
+        self.transformer = nn.Transformer(
+            d_model=d_model,
+            num_encoder_layers=2,
+            num_decoder_layers=2,
+            dim_feedforward=8,
+            dropout=0.1,
+        )
+        self.output_proj = nn.Linear(d_model, d_vocab)
+
+        # share the embedding and output projection weights
+        self.output_proj.weight = self.embed_tokens.weight
+        self.register_buffer(
+            "vocab_bias", self.embed_tokens.weight.new_ones((d_model,))
+        )
+        self.register_buffer(
+            "long_buffer",
+            torch.zeros_like(self.vocab_bias, dtype=torch.long),  # type: ignore[arg-type]
+        )  # type: ignore[arg-type]
+
+        self.bs = 2
+        self.bn = torch.nn.BatchNorm1d(self.bs) if add_bn else torch.nn.Identity()
+        if device_init_mode == DEVICEInitMode.DEVICE_BEFORE:
+            self = self.to(DEVICE_TYPE)
+        if deterministic:
+            self.eval()
+
+    def get_input(self, device):
+        torch.manual_seed(1 + self.rank)  # keep everything deterministic
+        src = torch.arange(12, device=device).view(6, self.bs)  # T x B
+        tgt = torch.arange(self.bs * 4, device=device).view(4, self.bs)  # T x B
+        return (src, tgt)
+
+    def forward(self, src_ids, tgt_ids):
+        src = self.embed_tokens(src_ids)
+        src = src + self.vocab_bias + self.long_buffer.type_as(src)  # type: ignore[operator]
+        tgt = self.embed_tokens(tgt_ids)
+        tgt = self.bn(tgt)
+        x = self.transformer(src, tgt)
+        return self.output_proj(x)
+
+    def get_loss(self, input, output):
+        _, tgt = input
+        return nn.functional.cross_entropy(
+            output.view(-1, output.size(-1)), tgt.view(-1), reduction="sum"
+        )
+
+    def run_backward(self, loss):
+        loss.backward()
+
+    @staticmethod
+    def init(
+        group: dist.ProcessGroup,
+        fsdp_init_mode: FSDPInitMode,
+        device_init_mode: DEVICEInitMode,
+        fsdp_kwargs: Optional[dict[str, Any]] = None,
+        deterministic: bool = False,
+        add_bn: bool = True,
+    ) -> Union[nn.Module, FSDP]:
+        """
+        Initializes a :class:`TransformerWithSharedParams` instance.
+
+        Args:
+            fsdp_init_mode (FSDPInitMode): If ``NO_FSDP``, then does not wrap
+                any modules with FSDP. If ``RECURSIVE``, then wraps with
+                top-level FSDP. By default, the top-level FSDP uses the
+                ``ModuleWrapPolicy`` for encoder and decoder layers, but a
+                different auto wrap policy may be specified via
+                ``fsdp_kwargs``.
+            device_init_mode (DEVICEInitMode): Determines model movement to DEVICE.
+            fsdp_kwargs (Optional[Dict[str, Any]]): Optional keyword arguments
+                forwarded to the FSDP constructor.
+            deterministic (bool): Whether to make the model deterministic
+                across constructions.
+            add_bn (bool): Whether to include batch norm in the model.
+        """
+
+        if fsdp_kwargs is None:
+            fsdp_kwargs = {}
+        if fsdp_init_mode == FSDPInitMode.NO_FSDP:
+            if isinstance(group, tuple):
+                pg = group[0]
+            else:
+                pg = group
+            return TransformerWithSharedParams(
+                pg, device_init_mode, add_bn, deterministic
+            )
+        elif fsdp_init_mode == FSDPInitMode.RECURSIVE:
+            # Default to the `ModuleWrapPolicy`
+            if "auto_wrap_policy" not in fsdp_kwargs:
+                auto_wrap_policy = ModuleWrapPolicy(
+                    {
+                        TransformerEncoderLayer,
+                        TransformerDecoderLayer,
+                    }
+                )
+            else:
+                auto_wrap_policy = fsdp_kwargs.pop("auto_wrap_policy")
+
+            if (
+                "sharding_strategy" in fsdp_kwargs
+                and fsdp_kwargs["sharding_strategy"]
+                in {ShardingStrategy.HYBRID_SHARD, ShardingStrategy._HYBRID_SHARD_ZERO2}
+                and not isinstance(group, tuple)
+            ):
+                fsdp_pg = None
+            else:
+                fsdp_pg = group
+
+            if isinstance(group, tuple):
+                tformer_pg = group[0]
+            else:
+                tformer_pg = group
+
+            m = TransformerWithSharedParams(
+                tformer_pg, device_init_mode, add_bn, deterministic
+            )
+            fsdp_model = FSDP(
+                m,
+                fsdp_pg,
+                auto_wrap_policy=auto_wrap_policy,
+                **fsdp_kwargs,
+            )
+            if device_init_mode == DEVICEInitMode.DEVICE_AFTER:
+                fsdp_model = fsdp_model.to(DEVICE_TYPE)
+            return fsdp_model
+        raise ValueError(f"Unsupported FSDP init mode: {fsdp_init_mode}")
+
+    def get_ignored_modules(self):
+        return [self.transformer]
+
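+# Illustrative usage sketch (hypothetical call site shown for clarity): once a default
+# process group exists, a test might construct and exercise the model as
+#
+#     model = TransformerWithSharedParams.init(
+#         group=dist.distributed_c10d._get_default_group(),
+#         fsdp_init_mode=FSDPInitMode.NO_FSDP,
+#         device_init_mode=DEVICEInitMode.DEVICE_BEFORE,
+#         deterministic=True,
+#     )
+#     inp = model.get_input(torch.device(DEVICE_TYPE))
+#     loss = model.get_loss(inp, model(*inp))
+#     model.run_backward(loss)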
+
+class NestedWrappedModule(FSDPTestModel):
+    def __init__(
+        self,
+        group: dist.ProcessGroup,
+        wrap_fsdp: bool,
+        device_init_mode: DEVICEInitMode,
+        deterministic: bool,
+        **fsdp_kwargs,
+    ):
+        super().__init__()
+        self.rank = group.rank()
+        self.world_size = group.size()
+        move_to_device = device_init_mode == DEVICEInitMode.DEVICE_BEFORE
+
+        def _maybe_wrap(layer):
+            if wrap_fsdp:
+                return FSDP(layer, group, **fsdp_kwargs)
+            return layer
+
+        if deterministic:
+            torch.manual_seed(0)
+        self.module = nn.Sequential(
+            _move_to_device(nn.Linear(8, 4), move_to_device),
+            _maybe_wrap(
+                nn.Sequential(
+                    _maybe_wrap(_move_to_device(nn.Linear(4, 16), move_to_device)),
+                    _move_to_device(nn.Linear(16, 16), move_to_device),
+                ),
+            ),
+            _maybe_wrap(_move_to_device(nn.Linear(16, 4), move_to_device)),
+            _move_to_device(nn.Linear(4, 8), move_to_device),
+        )
+
+    def get_input(self, device):
+        torch.manual_seed(1 + self.rank)  # keep everything deterministic
+        return (torch.rand(4, 8, device=device),)
+
+    def forward(self, x):
+        return self.module(x)
+
+    def get_loss(self, input, output):
+        loss = output.sum()
+        return loss
+
+    def run_backward(self, loss):
+        loss.backward()
+
+    @staticmethod
+    def init(
+        group: dist.ProcessGroup,
+        fsdp_init_mode: FSDPInitMode,
+        device_init_mode: DEVICEInitMode,
+        fsdp_kwargs: Optional[dict[str, Any]] = None,
+        deterministic: bool = False,
+    ) -> nn.Module:
+        """
+        Initializes a :class:`NestedWrappedModule` instance.
+
+        Args:
+            fsdp_init_mode (FSDPInitMode): If ``NO_FSDP``, then does not wrap
+                any modules with FSDP. If ``RECURSIVE``, then wraps some nested
+                modules with FSDP but not the top-level module. The model may
+                later be wrapped with a top-level FSDP external to this method
+                if desired.
+            device_init_mode (DEVICEInitMode): Determines model movement to DEVICE.
+            fsdp_kwargs (Optional[Dict[str, Any]]): Optional keyword arguments
+                forwarded to the FSDP constructor.
+            deterministic (bool): Whether to make the model deterministic
+                across constructions.
+        """
+        if fsdp_kwargs is None:
+            fsdp_kwargs = {}
+        if fsdp_init_mode == FSDPInitMode.NO_FSDP:
+            return NestedWrappedModule(
+                group,
+                wrap_fsdp=False,
+                device_init_mode=device_init_mode,
+                deterministic=deterministic,
+            )
+        elif fsdp_init_mode == FSDPInitMode.RECURSIVE:
+            # Does not wrap with top-level FSDP
+            fsdp_model = NestedWrappedModule(
+                group,
+                wrap_fsdp=True,
+                device_init_mode=device_init_mode,
+                deterministic=deterministic,
+                **fsdp_kwargs,
+            )
+            if device_init_mode == DEVICEInitMode.DEVICE_AFTER:
+                fsdp_model = fsdp_model.to(DEVICE_TYPE)
+            return fsdp_model
+        raise ValueError(f"Unsupported FSDP init mode: {fsdp_init_mode}")
+
+
+class AlwaysWrapNestedWrappedModule(NestedWrappedModule):
+    @staticmethod
+    def init(
+        group: dist.ProcessGroup,
+        fsdp_init_mode: FSDPInitMode,
+        device_init_mode: DEVICEInitMode,
+        fsdp_kwargs: Optional[dict[str, Any]] = None,
+        deterministic: bool = False,
+    ):
+        """
+        Initializes a :class:`NestedWrappedModule` instance, but unlike
+        :meth:`NestedWrappedModule.init`, for the ``RECURSIVE`` init mode, this
+        wraps with top-level FSDP and the ``always_wrap_policy()`` auto wrap
+        policy.
+        """
+        model = super(
+            AlwaysWrapNestedWrappedModule, AlwaysWrapNestedWrappedModule
+        ).init(
+            group=group,
+            fsdp_init_mode=FSDPInitMode.NO_FSDP,
+            device_init_mode=device_init_mode,
+            fsdp_kwargs=fsdp_kwargs,
+            deterministic=deterministic,
+        )
+        if fsdp_init_mode == FSDPInitMode.NO_FSDP:
+            return model
+        elif fsdp_init_mode == FSDPInitMode.RECURSIVE:
+            fsdp_kwargs = fsdp_kwargs or {}
+            fsdp_model = FSDP(model, auto_wrap_policy=always_wrap_policy, **fsdp_kwargs)
+            if device_init_mode == DEVICEInitMode.DEVICE_AFTER:
+                fsdp_model = fsdp_model.to(DEVICE_TYPE)
+            return fsdp_model
+
+
+class NonUniformReqGradNWM(NestedWrappedModule):
+    def __init__(
+        self,
+        group: dist.ProcessGroup,
+        wrap_fsdp: bool,
+        device_init_mode: DEVICEInitMode,
+        deterministic: bool,
+        **fsdp_kwargs,
+    ):
+        super(NestedWrappedModule, self).__init__()
+        # This `__init__` only differs from `NestedWrappedModule.__init__` in that
+        # the last two `nn.Linear` layers are FSDP wrapped in a `nn.Sequential`
+        # container. This arrangement results in all elements of the last two parameters
+        # residing on a single rank. Freezing all parameters except those two allows us
+        # to verify that `ShardedGradScaler` accommodates situations where some ranks
+        # have no (non-zero sized) parameter shards.
+        self.rank = group.rank()
+        self.world_size = group.size()
+        move_to_device = device_init_mode == DEVICEInitMode.DEVICE_BEFORE
+
+        def _maybe_wrap(layer):
+            if wrap_fsdp:
+                return FSDP(layer, group, **fsdp_kwargs)
+            return layer
+
+        if deterministic:
+            torch.manual_seed(0)
+        self.module = nn.Sequential(
+            _move_to_device(nn.Linear(8, 4), move_to_device),
+            _maybe_wrap(
+                nn.Sequential(
+                    _maybe_wrap(_move_to_device(nn.Linear(4, 16), move_to_device)),
+                    _move_to_device(nn.Linear(16, 16), move_to_device),
+                ),
+            ),
+            _maybe_wrap(
+                nn.Sequential(
+                    _move_to_device(nn.Linear(16, 4), move_to_device),
+                    _move_to_device(nn.Linear(4, 8), move_to_device),
+                ),
+            ),
+        )
+
+    @staticmethod
+    def _set_nonuniform_req_grad(model, req_grad_mask) -> None:
+        for n, p in model.named_parameters():
+            if not re.match(req_grad_mask, n):
+                p.requires_grad_(False)
+
+    @staticmethod
+    def init(
+        group: dist.ProcessGroup,
+        fsdp_init_mode: FSDPInitMode,
+        device_init_mode: DEVICEInitMode,
+        fsdp_kwargs: Optional[dict[str, Any]] = None,
+        deterministic: bool = False,
+    ):
+        """
+        Initializes a :class:`NestedWrappedModule` instance, but unlike
+        :meth:`NestedWrappedModule.init`, it wraps a second :class:`torch.nn.Sequential`
+        container to enable the desired non-uniform ``requires_grad`` tests with
+        ``use_orig_params=True``. For both ``RECURSIVE`` and ``NO_FSDP``
+        init modes, freezes all parameters except the last two to validate
+        ``ShardedGradScaler`` support for ranks with no (non-zero sized) local shards in
+        FSDP ``use_orig_params=True`` mode.
+        """
+        # The parameters that should remain unfrozen are in `module.2.1`. The regex
+        # pattern below matches the relevant parameter names both with and without
+        # an interstitial FSDP module indicator (`_fsdp_wrapped_module`) present.
+        req_grad_pattern = re.compile(r"module\.2.*\.1.*")
+        if fsdp_init_mode == FSDPInitMode.NO_FSDP:
+            ddp_model = NonUniformReqGradNWM(
+                group,
+                wrap_fsdp=False,
+                device_init_mode=device_init_mode,
+                deterministic=deterministic,
+            )
+            NonUniformReqGradNWM._set_nonuniform_req_grad(ddp_model, req_grad_pattern)
+            return ddp_model
+        elif fsdp_init_mode == FSDPInitMode.RECURSIVE:
+            if fsdp_kwargs is None:
+                fsdp_kwargs = {}
+            fsdp_model = NonUniformReqGradNWM(
+                group,
+                wrap_fsdp=True,
+                device_init_mode=device_init_mode,
+                deterministic=deterministic,
+                **fsdp_kwargs,
+            )
+            if device_init_mode == DEVICEInitMode.DEVICE_AFTER:
+                fsdp_model = fsdp_model.to(DEVICE_TYPE)
+            NonUniformReqGradNWM._set_nonuniform_req_grad(fsdp_model, req_grad_pattern)
+            return fsdp_model
+        raise ValueError(f"Unsupported FSDP init mode: {fsdp_init_mode}")
+
+
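+# Illustrative sketch only: the regex used by `NonUniformReqGradNWM.init` above
+# matches the `module.2.1` parameters both with and without FSDP's interstitial
+# `_fsdp_wrapped_module` in the name. The parameter names below are examples
+# chosen for illustration.
+def _example_req_grad_pattern_matches():
+    pattern = re.compile(r"module\.2.*\.1.*")
+    assert pattern.match("module.2.1.weight") is not None
+    assert pattern.match("module.2._fsdp_wrapped_module.1.weight") is not None
+    # Parameters outside `module.2.*.1.*` are frozen by `_set_nonuniform_req_grad`
+    assert pattern.match("module.0.weight") is None
+
+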
+class ModuleWithDelay(FSDPTestModel):
+    """This class wraps a :class:`FSDPTestModel` to optionally add a delay
+    after computing the loss and/or before the gradient reduction."""
+
+    def __init__(
+        self,
+        module: nn.Module,
+        delay_after_loss_ms: int,
+        delay_before_reduction_ms: int,
+    ):
+        super().__init__()
+        self.delay_after_loss_ms = delay_after_loss_ms
+        self.delay_before_reduction_ms = delay_before_reduction_ms
+        self.module = module
+
+    def get_input(self, device):
+        return self.module.get_input(device)  # type: ignore[operator]
+
+    def forward(self, x):
+        return self.module(x)
+
+    def get_loss(self, input, output):
+        loss = self.module.get_loss(input, output)  # type: ignore[operator]
+        if self.delay_after_loss_ms > 0:
+            if TEST_HPU or TEST_XPU:
+                time.sleep(self.delay_after_loss_ms / 1000)
+            elif TEST_CUDA:
+                torch.cuda._sleep(int(self.delay_after_loss_ms * get_cycles_per_ms()))
+
+        return loss
+
+    def run_backward(self, loss):
+        orig_reduce_scatter = torch.distributed.reduce_scatter_tensor
+
+        def _delayed_reduce_scatter(*args, **kwargs):
+            if self.delay_before_reduction_ms > 0:
+                if TEST_CUDA:
+                    torch.cuda._sleep(
+                        int(self.delay_before_reduction_ms * get_cycles_per_ms())
+                    )
+                elif TEST_HPU or TEST_XPU:
+                    time.sleep(self.delay_before_reduction_ms / 1000)
+            return orig_reduce_scatter(*args, **kwargs)
+
+        with mock.patch(
+            "torch.distributed.reduce_scatter_tensor", _delayed_reduce_scatter
+        ):
+            self.module.run_backward(loss)  # type: ignore[operator]
+
+    @staticmethod
+    def init(
+        module_class: type[FSDPTestModel],
+        *model_args: Any,
+        delay_after_loss_ms: int,
+        delay_before_reduction_ms: int,
+        **model_kwargs: Any,
+    ):
+        """
+        Args:
+            module_class (Type[FSDPTestModel]): Wrapped module class to which
+                to add delays.
+            model_args: Positional arguments forwarded to the ``module_class``
+                ``init()``.
+            delay_after_loss_ms (int): Delay after computing the loss/before
+                the optimizer step (in ms).
+            delay_before_reduction_ms (int): Delay before reduce-scattering
+                gradients (in ms).
+            model_kwargs: Keyword arguments forwarded to the ``module_class``
+                ``init()``.
+        """
+        return ModuleWithDelay(
+            module_class.init(*model_args, **model_kwargs),
+            delay_after_loss_ms,
+            delay_before_reduction_ms,
+        )
+
+
+class NestedWrappedModuleWithDelay(ModuleWithDelay):
+    @staticmethod
+    def init(  # type: ignore[override]
+        group: dist.ProcessGroup,
+        fsdp_init_mode: FSDPInitMode,
+        device_init_mode: DEVICEInitMode = DEVICEInitMode.DEVICE_AFTER,
+        fsdp_kwargs: Optional[dict[str, Any]] = None,
+        deterministic: bool = False,
+        delay_after_loss_ms: int = 0,
+        delay_before_reduction_ms: int = 0,
+    ):
+        return ModuleWithDelay.init(
+            NestedWrappedModule,
+            group=group,
+            fsdp_init_mode=fsdp_init_mode,
+            device_init_mode=device_init_mode,
+            fsdp_kwargs=fsdp_kwargs,
+            deterministic=deterministic,
+            delay_after_loss_ms=delay_after_loss_ms,
+            delay_before_reduction_ms=delay_before_reduction_ms,
+        )
+
+
+class DummyDDP(nn.Module):
+    def __init__(self, module):
+        super().__init__()
+        self.module = module
+
+    def forward(self, *args, **kwargs):
+        return self.module(*args, **kwargs)
+
+
+class MixtureOfExperts(NestedWrappedModule):
+    def __init__(
+        self,
+        group: dist.ProcessGroup,
+        wrap_fsdp: bool,
+        device_init_mode: DEVICEInitMode,
+        delay_before_free_ms: int,
+        deterministic: bool,
+        **fsdp_kwargs,
+    ):
+        super().__init__(
+            group=group,
+            wrap_fsdp=wrap_fsdp,
+            device_init_mode=device_init_mode,
+            deterministic=deterministic,
+        )
+        self.group = group
+        self.delay_before_free_ms = delay_before_free_ms
+        self.wrap_fsdp = wrap_fsdp
+        self.move_to_device = device_init_mode == DEVICEInitMode.DEVICE_BEFORE
+        if deterministic:
+            # Give each rank different expert parameters
+            torch.manual_seed(42 + self.rank)
+        d_expert = 23
+        d_shared = 12
+        d_input = 8
+        expert = _move_to_device(nn.Linear(d_expert, d_shared), self.move_to_device)
+
+        self.num_expert_params = sum(p.numel() for p in expert.parameters())
+        for p in expert.parameters():
+            p.expert = True  # type: ignore[attr-defined]
+
+        if deterministic:
+            # Keep all other parameters the same across ranks
+            torch.manual_seed(0)
+
+        shared = _move_to_device(nn.Linear(d_shared, d_expert), self.move_to_device)
+
+        if wrap_fsdp:
+            # we create a process group of size 1 for the expert params
+            expert_group = torch.distributed.new_group(
+                [group.rank()]
+            )  # world size 1 means no shard
+            expert = FSDP(expert, expert_group, **fsdp_kwargs)  # type: ignore[assignment]
+            shared = FSDP(shared, group, **fsdp_kwargs)  # type: ignore[assignment]
+
+        self.module = nn.Sequential(
+            _move_to_device(nn.Linear(d_input, d_shared), self.move_to_device),
+            shared,
+            expert,
+            _move_to_device(nn.Linear(d_shared, d_input), self.move_to_device),
+        )
+
+    def forward(self, x):
+        if self.delay_before_free_ms > 0:
+            expert = self.module[2]
+            if isinstance(expert, FSDP):
+                orig_reshard = torch.distributed.fsdp._runtime_utils._reshard
+
+                def _delayed_reshard(*args, **kwargs):
+                    if TEST_CUDA:
+                        torch.cuda._sleep(
+                            int(self.delay_before_free_ms * get_cycles_per_ms())
+                        )
+                    elif TEST_HPU or TEST_XPU:
+                        time.sleep(self.delay_before_free_ms / 1000)
+
+                    return orig_reshard(*args, **kwargs)
+
+                # This patch covers callers that access `_reshard` through the
+                # `torch.distributed.fsdp._runtime_utils` module at call time.
+                with mock.patch(
+                    "torch.distributed.fsdp._runtime_utils._reshard", _delayed_reshard
+                ):
+                    return self.module(x)
+
+        return self.module(x)
+
+    def run_backward(self, loss):
+        loss.backward()
+        # Manually reduce gradients if not wrapped in FullyShardedDataParallel
+        if not self.wrap_fsdp:
+            with torch.no_grad():
+                for p in self.parameters():
+                    if hasattr(p, "expert"):
+                        continue  # these params don't need grad reduction
+                    if p.grad is not None:
+                        p.grad.div_(self.world_size)
+                        torch.distributed.all_reduce(p.grad, group=self.group)
+
+    @staticmethod
+    def init(
+        group: dist.ProcessGroup,
+        fsdp_init_mode: FSDPInitMode,
+        device_init_mode: DEVICEInitMode,
+        fsdp_kwargs: Optional[dict[str, Any]] = None,
+        deterministic: bool = False,
+        delay_before_free_ms: int = 0,
+    ):
+        """
+        Initializes a :class:`MixtureOfExperts` instance.
+
+        Args:
+            fsdp_init_mode (FSDPInitMode): If ``NO_FSDP``, then does not wrap
+                any modules with FSDP. If ``RECURSIVE``, then wraps some nested
+                modules with FSDP, including the expert and shared layers, but
+                not the top-level module. The model may later be wrapped with a
+                top-level FSDP external to this method if desired.
+            device_init_mode (DEVICEInitMode): Determines when the model is moved to the device.
+            fsdp_kwargs (Optional[Dict[str, Any]]): Optional keyword arguments
+                forwarded to the FSDP constructor.
+            deterministic (bool): Whether to make the model deterministic
+                across constructions.
+            delay_before_free_ms (int): Delay before resharding expert
+                parameters in the forward pass (in ms).
+        """
+        if fsdp_kwargs is None:
+            fsdp_kwargs = {}
+        if fsdp_init_mode == FSDPInitMode.NO_FSDP:
+            return MixtureOfExperts(
+                group,
+                wrap_fsdp=False,
+                device_init_mode=device_init_mode,
+                delay_before_free_ms=delay_before_free_ms,
+                deterministic=deterministic,
+            )
+        elif fsdp_init_mode == FSDPInitMode.RECURSIVE:
+            # Does not wrap with top-level FSDP
+            fsdp_model = MixtureOfExperts(
+                group,
+                wrap_fsdp=True,
+                device_init_mode=device_init_mode,
+                delay_before_free_ms=delay_before_free_ms,
+                deterministic=deterministic,
+                **fsdp_kwargs,
+            )
+            if device_init_mode == DEVICEInitMode.DEVICE_AFTER:
+                fsdp_model = fsdp_model.to(DEVICE_TYPE)
+            return fsdp_model
+        raise ValueError(f"Unsupported FSDP init mode: {fsdp_init_mode}")
+
+
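+# Illustrative sketch only: constructing the mixture-of-experts test model with
+# nested FSDP wrapping and then adding the top-level FSDP wrap externally, which
+# the docstring above notes is left to the caller. The process-group accessor is
+# an assumption for illustration.
+def _example_moe_init_with_top_level_fsdp():
+    group = dist.distributed_c10d._get_default_group()
+    # `RECURSIVE` wraps the expert and shared layers but not the top-level module
+    moe = MixtureOfExperts.init(
+        group,
+        FSDPInitMode.RECURSIVE,
+        DEVICEInitMode.DEVICE_BEFORE,
+        deterministic=True,
+    )
+    return FSDP(moe, group)
+
+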
+class MLP(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        device: Optional[torch.device] = None,
+        *,
+        bias: bool = True,
+        with_buffer: bool = False,
+        dim_multiplier: int = 4,
+    ):
+        super().__init__()
+        self.in_proj = nn.Linear(dim, dim_multiplier * dim, device=device, bias=bias)
+        self.out_proj = nn.Linear(dim_multiplier * dim, dim, device=device, bias=bias)
+        if with_buffer:
+            self.register_buffer("buffer", torch.randn((dim,), device=device))
+        else:
+            self.buffer = None
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        z = self.in_proj(x)
+        z = F.relu(z)
+        z = self.out_proj(z)
+        z = F.relu(z)
+        if self.buffer is not None:
+            z = z + self.buffer
+        return z
+
+    def reset_parameters(self):
+        if self.buffer is not None:
+            torch.nn.init.normal_(self.buffer)
+
+
+class MLPStack(nn.Sequential):
+    def __init__(self, mlp_dim: int, *, with_seq_parallel: bool = False):
+        modules: list[nn.Module] = [
+            # Use multiplier of 3 to exercise uneven case
+            MLP(mlp_dim, dim_multiplier=3),
+            MLP(mlp_dim),
+            MLP(mlp_dim, dim_multiplier=3),
+        ]
+        if with_seq_parallel:
+            modules.append(nn.LayerNorm(mlp_dim, bias=False))
+        super().__init__(*modules)
+        self.with_seq_parallel = with_seq_parallel
+
+    def parallelize(
+        self,
+        tp_mesh: DeviceMesh,
+        dp_mesh: DeviceMesh,
+        use_activation_checkpointing: bool,
+        **fsdp_kwargs,
+    ) -> "MLPStack":
+        parallelize_plan = {
+            # Pass `use_local_output=False` to keep the activations as DTensors
+            # so that the uneven activation dims are preserved
+            "0.in_proj": ColwiseParallel(use_local_output=False),
+            "0.out_proj": RowwiseParallel(use_local_output=False),
+            "1.in_proj": ColwiseParallel(use_local_output=False),
+            "1.out_proj": RowwiseParallel(use_local_output=False),
+            "2.in_proj": ColwiseParallel(use_local_output=False),
+            "2.out_proj": RowwiseParallel(output_layouts=Shard(1))
+            if self.with_seq_parallel
+            else RowwiseParallel(),
+        }
+        if self.with_seq_parallel:
+            parallelize_plan["3"] = SequenceParallel(sequence_dim=1)
+        parallelize_module(self, device_mesh=tp_mesh, parallelize_plan=parallelize_plan)
+        for module in self:
+            if isinstance(module, nn.LayerNorm):
+                continue
+            if use_activation_checkpointing:
+                checkpoint(module)
+            fully_shard(module, mesh=dp_mesh, **fsdp_kwargs)
+        fully_shard(self, mesh=dp_mesh, **fsdp_kwargs)
+        return self
+
+
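+# Illustrative sketch only: applying `MLPStack.parallelize` with a 2D device
+# mesh (data parallel x tensor parallel). The mesh shape, `mlp_dim`, and the
+# use of `init_device_mesh` here are assumptions for illustration.
+def _example_parallelize_mlp_stack():
+    from torch.distributed.device_mesh import init_device_mesh
+
+    # Outer mesh dim for FSDP sharding, inner mesh dim for tensor parallelism
+    mesh_2d = init_device_mesh(DEVICE_TYPE, (2, 2), mesh_dim_names=("dp", "tp"))
+    stack = MLPStack(mlp_dim=16).to(DEVICE_TYPE)
+    return stack.parallelize(
+        tp_mesh=mesh_2d["tp"],
+        dp_mesh=mesh_2d["dp"],
+        use_activation_checkpointing=False,
+    )
+
+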
+class DoubleLinear(nn.Module):
+    """
+    This can be used for returning multiple outputs from a module
+    (``use_second_linear=True``) or for having an unused module (``False``).
+    """
+
+    def __init__(self, dim: int, use_second_linear: bool = True):
+        super().__init__()
+        self.lin1 = nn.Linear(dim, dim)
+        self.lin2 = nn.Linear(dim, dim)
+        self.relu = nn.ReLU()
+        self.use_second_linear = use_second_linear
+
+    def forward(
+        self, x: torch.Tensor
+    ) -> Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
+        if self.use_second_linear:
+            return self.relu(self.lin1(x)), self.relu(self.lin2(x))
+        return self.relu(self.lin1(x))
+
+
+# NOTE: For these patch methods, if we want safety under multi-threading (e.g.
+# when using multi-threaded process group), then we want:
+# (1) a barrier immediately after reading the original value to ensure that all
+# threads see the same original value
+# (2) a barrier immediately before restoring the original value to ensure that
+# all threads use the patched value inside the context
+@contextlib.contextmanager
+def patch_all_gather(new_all_gather_into_tensor: Callable):
+    orig_all_gather = dist.all_gather_into_tensor
+    dist.barrier()
+    dist.all_gather_into_tensor = new_all_gather_into_tensor
+    try:
+        yield
+    finally:
+        dist.barrier()
+        dist.all_gather_into_tensor = orig_all_gather
+
+
+@contextlib.contextmanager
+def patch_reduce_scatter(new_reduce_scatter_tensor: Callable):
+    orig_reduce_scatter = dist.reduce_scatter_tensor
+    dist.barrier()
+    dist.reduce_scatter_tensor = new_reduce_scatter_tensor
+    try:
+        yield
+    finally:
+        dist.barrier()
+        dist.reduce_scatter_tensor = orig_reduce_scatter
+
+
+@contextlib.contextmanager
+def patch_all_reduce(new_all_reduce: Callable):
+    orig_all_reduce = dist.all_reduce
+    dist.barrier()
+    dist.all_reduce = new_all_reduce
+    try:
+        yield
+    finally:
+        dist.barrier()
+        dist.all_reduce = orig_all_reduce
+
+
+@no_type_check
+@contextlib.contextmanager
+def patch_unshard(new_unshard: Callable):
+    orig_unshard = FSDPParamGroup.unshard
+    dist.barrier()
+    FSDPParamGroup.unshard = new_unshard
+    try:
+        yield
+    finally:
+        dist.barrier()
+        FSDPParamGroup.unshard = orig_unshard
+
+
+@no_type_check
+@contextlib.contextmanager
+def patch_reshard(new_reshard: Callable):
+    orig_reshard = FSDPParamGroup.reshard
+    dist.barrier()
+    FSDPParamGroup.reshard = new_reshard
+    try:
+        yield
+    finally:
+        dist.barrier()
+        FSDPParamGroup.reshard = orig_reshard
+
+
+@no_type_check
+@contextlib.contextmanager
+def patch_post_backward(new_post_backward: Callable):
+    orig_post_backward = FSDPParamGroup.post_backward
+    dist.barrier()
+    FSDPParamGroup.post_backward = new_post_backward
+    try:
+        yield
+    finally:
+        dist.barrier()
+        FSDPParamGroup.post_backward = orig_post_backward
+
+
+@no_type_check
+@contextlib.contextmanager
+def patch_register_post_backward_hook_backward(new_backward: Callable):
+    orig_backward = RegisterPostBackwardFunction.backward
+    dist.barrier()
+    RegisterPostBackwardFunction.backward = new_backward
+    try:
+        yield
+    finally:
+        dist.barrier()
+        RegisterPostBackwardFunction.backward = orig_backward
+
+
+def reduce_scatter_with_assert(
+    cls,
+    orig_reduce_scatter: Callable,
+    assert_fn: Callable,  # `assert_fn(output: Tensor)`
+    *args: Any,
+    **kwargs: Any,
+):
+    if len(args) > 0:
+        output = args[0]
+    elif "output" in kwargs:
+        output = kwargs["output"]
+    else:
+        raise AssertionError(
+            f"Cannot get reduce-scatter output from\nargs: {args}\nkwargs: {kwargs}"
+        )
+    assert_fn(output)
+    return orig_reduce_scatter(*args, **kwargs)
+
+
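+# Illustrative sketch only: one way a test might combine `patch_reduce_scatter`
+# with `reduce_scatter_with_assert` above to validate every reduce-scatter
+# output during backward. `cls` (the test case), `model`, `inp`, and `assert_fn`
+# are hypothetical arguments for illustration.
+def _example_assert_on_reduce_scatter(cls, model, inp, assert_fn):
+    orig_reduce_scatter = dist.reduce_scatter_tensor
+
+    def wrapped_reduce_scatter(*args, **kwargs):
+        # Check the reduce-scatter output, then delegate to the original op
+        return reduce_scatter_with_assert(
+            cls, orig_reduce_scatter, assert_fn, *args, **kwargs
+        )
+
+    with patch_reduce_scatter(wrapped_reduce_scatter):
+        model(inp).sum().backward()
+
+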
+def check_sharded_parity(
+    cls,  # unit test class
+    replicated_module: nn.Module,
+    sharded_module: nn.Module,
+    prefixes_to_ignore: tuple[str, ...] = (),
+):
+    for (replicated_name, replicated_param), (sharded_name, sharded_param) in zip(
+        replicated_module.named_parameters(), sharded_module.named_parameters()
+    ):
+        clean_sharded_name = sharded_name
+        for prefix in prefixes_to_ignore:
+            clean_sharded_name = clean_sharded_name.replace(prefix, "")
+        cls.assertEqual(replicated_name, clean_sharded_name)
+        cls.assertIsInstance(sharded_param, DTensor)
+        assert isinstance(sharded_param, DTensor)  # mypy
+        mesh, placements = sharded_param.device_mesh, sharded_param.placements
+        if tuple(placements) == (Shard(0), Shard(0)):
+            raise AssertionError(
+                "FSDP's (Shard(0), Shard(0)) layout differs from distribute_tensor(), "
+                "so we cannot check for equality using it"
+            )
+        sharded_ref_param = distribute_tensor(replicated_param, mesh, placements)
+        cls.assertEqual(sharded_param.to_local(), sharded_ref_param.to_local())
+        if replicated_param.grad is None:
+            cls.assertIsNone(sharded_param.grad)
+            continue
+        cls.assertIsNotNone(sharded_param.grad)
+        sharded_ref_grad = distribute_tensor(replicated_param.grad, mesh, placements)
+        cls.assertIsInstance(sharded_param.grad, DTensor)
+        assert isinstance(sharded_param.grad, DTensor)  # mypy
+        cls.assertEqual(sharded_param.grad.to_local(), sharded_ref_grad.to_local())
+
+
+class FSDPTestMultiThread(MultiThreadedTestCase):
+    @property
+    def world_size(self):
+        return DEVICE_COUNT
+
+    def setUp(self):
+        super().setUp()
+        self._spawn_threads()
+
+    def run_subtests(self, *args, **kwargs):
+        return run_subtests(self, *args, **kwargs)
+
+    def perThreadSetUp(self):
+        torch._dynamo.reset()
+
+    def perThreadTearDown(self):
+        torch._dynamo.reset()
+
+
+class FSDPTest(MultiProcessTestCase):
+    def setUp(self):
+        super().setUp()
+        # Set TORCH_NCCL_DESYNC_DEBUG=0 to disable the NCCL `workCleanupLoop()`,
+        # which can cause unit test flakiness:
+        # https://github.com/pytorch/pytorch/issues/90848
+        os.environ["TORCH_NCCL_DESYNC_DEBUG"] = "0"
+        self._spawn_processes()
+
+    @property
+    def world_size(self):
+        return DEVICE_COUNT
+
+    @property
+    def process_group(self):
+        return dist.distributed_c10d._get_default_group()
+
+    @property
+    def destroy_pg_upon_exit(self) -> bool:
+        # Overriding base test class: do not auto destroy PG upon exit.
+        return False
+
+    @property
+    def init_method(self):
+        return f"{FILE_SCHEMA}{self.file_name}"
+
+    def _check_cpu_offload(self, fsdp_model, cpu_offload):
+        self.assertEqual(cpu_offload, fsdp_model.cpu_offload)
+
+    def _check_backward_prefetch(self, fsdp_model, backward_prefetch):
+        self.assertEqual(backward_prefetch, fsdp_model.backward_prefetch)
+
+    def _check_forward_prefetch(self, fsdp_model, forward_prefetch):
+        self.assertEqual(forward_prefetch, fsdp_model.forward_prefetch)
+
+    def run_subtests(self, *args, **kwargs):
+        return run_subtests(self, *args, **kwargs)
+
+    @classmethod
+    def _run(cls, rank, test_name, file_name, pipe, **kwargs):  # type: ignore[override]
+        self = cls(test_name)
+        self.rank = rank
+        self.file_name = file_name
+        fake_pg = kwargs.get("fake_pg", False)
+
+        print(f"dist init r={self.rank}, world={self.world_size}")
+
+        # Specify the gloo backend to make `init_process_group()` succeed.
+        # Actual tests will be skipped if there are not enough GPUs.
+        try:
+            if fake_pg:
+                store = torch.testing._internal.distributed.fake_pg.FakeStore()
+                dist.init_process_group(
+                    backend="fake",
+                    world_size=self.world_size,
+                    rank=rank,
+                    store=store,
+                )
+            else:
+                dist.init_process_group(
+                    init_method=self.init_method,
+                    backend=DISTRIBUTED_BACKEND,
+                    world_size=int(self.world_size),
+                    rank=self.rank,
+                )
+        except RuntimeError as e:
+            if "recompile" in e.args[0]:
+                sys.exit(TEST_SKIPS["backend_unavailable"].exit_code)
+
+            raise
+
+        device_ids = None
+        device_id = self.rank % DEVICE_COUNT
+        if TEST_CUDA or TEST_XPU:
+            torch.accelerator.set_device_index(device_id)
+        device_ids = [device_id]
+
+        # Execute barrier prior to running test to ensure that every process
+        # has finished initialization and that the following test
+        # immediately exiting due to a skip doesn't cause flakiness.
+        dist.barrier(device_ids=device_ids)
+
+        torch._dynamo.reset()
+        self.run_test(test_name, pipe)
+        torch._dynamo.reset()
+
+        dist.barrier(device_ids=device_ids)
+
+        dist.destroy_process_group()
+
+    def _train_for_several_steps(
+        self,
+        model: nn.Module,
+        num_steps: int,
+        autocast: bool,
+        lr: float = 0.01,
+        fsdp_cpu_offload: Optional[CPUOffload] = None,
+        save_model: bool = False,
+        mixed_precision: Optional[MixedPrecision] = None,
+        enable_sharded_grad_scaler: bool = False,
+        use_pure_fp16: bool = False,
+        sharded_grad_scaler_kwargs: Optional[dict[str, Any]] = None,
+    ):
+        cpu_offload_params = fsdp_cpu_offload and fsdp_cpu_offload.offload_params
+
+        model_device = next(model.parameters()).device
+        if sharded_grad_scaler_kwargs is None:
+            sharded_grad_scaler_kwargs = {}
+        sharded_grad_scaler = ShardedGradScaler(
+            enabled=enable_sharded_grad_scaler, **sharded_grad_scaler_kwargs
+        )
+        # Use SGD with momentum instead of Adam, since Adam is scale-invariant,
+        # which makes it a poor choice for these tests
+        optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
+        for _ in range(num_steps):
+            optim.zero_grad()
+            with torch.amp.autocast(DEVICE_TYPE, enabled=autocast):
+                # Inputs are always on the accelerator device, regardless of CPU
+                # offloading or `model.device`
+                input = model.module.get_input(torch.device(DEVICE_TYPE))  # type: ignore[operator, union-attr]
+                if use_pure_fp16 or (mixed_precision and not isinstance(model, FSDP)):
+                    if isinstance(input, torch.Tensor):
+                        input = input.half()
+                    else:
+                        input = tuple(x.half() for x in input)
+                output = model(*input)
+                # Post-forward, if CPU offloading, the model params should be on CPU.
+                if (
+                    cpu_offload_params
+                    and isinstance(model, FSDP)
+                    # If not resharding after forward, the parameters are still
+                    # exposed as unsharded views into the GPU flat parameter
+                    and model.sharding_strategy
+                    not in NO_RESHARD_AFTER_FORWARD_STRATEGIES
+                ):
+                    for p in model.parameters():
+                        # Params should always be on CPU
+                        self.assertEqual(p.device, torch.device("cpu"))
+
+                loss = model.module.get_loss(input, output).to(model_device)  # type: ignore[operator, union-attr]
+            loss = sharded_grad_scaler.scale(loss)
+
+            if not mixed_precision and not use_pure_fp16:
+                assert loss.dtype == torch.float32, (
+                    "loss data type should be float32, as the original "
+                    "parameter data type is float32."
+                )
+            else:
+                if use_pure_fp16:
+                    self.assertEqual(loss.dtype, torch.float16)
+                # FSDP loss is fp16, DDP AMP loss is fp32
+                elif isinstance(model, FSDP):
+                    assert mixed_precision is not None  # mypy
+                    self.assertEqual(loss.dtype, mixed_precision.param_dtype)
+                else:
+                    self.assertEqual(loss.dtype, torch.float32)
+            model.module.run_backward(loss)  # type: ignore[operator, union-attr]
+            # Post-backward, if CPU offloading, the model params should be on CPU.
+            if cpu_offload_params and isinstance(model, FSDP):
+                for p in model.parameters():
+                    # Params should always be on CPU
+                    self.assertEqual(p.device, torch.device("cpu"))
+            # Unscale the gradients and step
+            sharded_grad_scaler.step(optim)
+            # Update the scale factor
+            sharded_grad_scaler.update()
+            # if save_model, simulate save + load.
+            if save_model:
+                state_dict = {k: v.clone() for k, v in model.state_dict().items()}
+                # Zero the params; if saving/loading the state dict did not work
+                # properly, this would break the parity test with DDP.
+                _zero_model(model)
+                model.load_state_dict(state_dict)
+
+        if isinstance(model, FSDP):
+            model._assert_state(TrainingState.IDLE)
+        return loss.detach()  # type: ignore[possibly-undefined]
+
+    def _test_fsdp_parity(
+        self,
+        model_class: type[FSDPTestModel],
+        fsdp_init_mode: FSDPInitMode,
+        device_init_mode: DEVICEInitMode,
+        ref_init_fn: Optional[Callable] = None,
+        num_iters: int = 2,
+        save_model: bool = True,
+        cpu_offload: CPUOffload = CPUOffload(),
+        backward_prefetch: Optional[BackwardPrefetch] = None,
+        sharding_strategy: Optional[ShardingStrategy] = None,
+        mixed_precision: Optional[MixedPrecision] = None,
+        forward_prefetch: bool = False,
+        use_orig_params: bool = False,
+        enable_sharded_grad_scaler: bool = False,
+        use_pure_fp16: bool = False,
+        init_kwargs: Optional[dict[str, Any]] = None,
+        sharded_grad_scaler_kwargs: Optional[dict[str, Any]] = None,
+        **fsdp_kwargs,
+    ):
+        """
+        Tests FSDP training against a reference, which defaults to DDP but
+        may be customized with ``ref_init_fn``.
+
+        Args:
+            model_class (Type[FSDPTestModel]): A model class that inherits from
+                ``FSDPTestModel``, which defines the expected interface.
+            fsdp_init_mode (FSDPInitMode): The mode to initialize the
+                FSDP-wrapped model. This should not be ``NO_FSDP``.
+            ref_init_fn (Optional[Callable]): A callable to invoke that wraps a
+                non-wrapped model to construct the reference model, where this
+                wrapper should provide data parallel semantics. If ``None``,
+                then the callable defaults to the DDP constructor.
+        """
+        assert (
+            fsdp_init_mode != FSDPInitMode.NO_FSDP
+        ), "Expects an FSDP init mode that wraps with FSDP"
+        if init_kwargs is None:
+            init_kwargs = {}
+        lr = 1e-2
+        rank = self.process_group.rank()
+        # Establish reference behavior with DDP
+        model = model_class.init(
+            self.process_group,
+            FSDPInitMode.NO_FSDP,
+            DEVICEInitMode.DEVICE_BEFORE,
+            deterministic=True,
+            **init_kwargs,
+        )
+        if ref_init_fn is None:
+            if TEST_HPU:
+                ref_model = DDP(
+                    model, device_ids=[DEVICE_TYPE], output_device=DEVICE_TYPE
+                )
+            else:
+                ref_model = DDP(model, device_ids=[rank], output_device=rank)
+        else:
+            ref_model = ref_init_fn(model)
+        if use_pure_fp16:
+            ref_model = ref_model.half()
+        ref_loss = self._train_for_several_steps(
+            ref_model,
+            num_iters,
+            autocast=mixed_precision is not None,
+            lr=lr,
+            fsdp_cpu_offload=cpu_offload,
+            mixed_precision=mixed_precision,
+            enable_sharded_grad_scaler=enable_sharded_grad_scaler,
+            use_pure_fp16=use_pure_fp16,
+            sharded_grad_scaler_kwargs=sharded_grad_scaler_kwargs,
+        )
+        ddp_params = list(ref_model.parameters())
+        # Check against FSDP behavior
+        fsdp_kwargs.update(
+            {
+                "cpu_offload": cpu_offload,
+                "backward_prefetch": backward_prefetch,
+                "sharding_strategy": sharding_strategy,
+                "mixed_precision": mixed_precision,
+                "forward_prefetch": forward_prefetch,
+                "use_orig_params": use_orig_params,
+            }
+        )
+        try:
+            fsdp_model = model_class.init(
+                self.process_group,
+                fsdp_init_mode,
+                device_init_mode,
+                fsdp_kwargs,
+                deterministic=True,
+                **init_kwargs,
+            )
+        except Exception as e:
+            raise ValueError(f"Initializing {model_class} raised error {str(e)}") from e
+        if not isinstance(fsdp_model, FSDP):
+            # Enforce a top-level FSDP wrap since the comparison assumes a data
+            # parallel reference and some test models may not wrap the top-level
+            # module in their `init()` method
+            fsdp_model = FSDP(fsdp_model, self.process_group, **fsdp_kwargs)
+        if use_pure_fp16:
+            # Change the model parameter dtype after FSDP initialization
+            fsdp_model = fsdp_model.half()
+        if device_init_mode == DEVICEInitMode.DEVICE_AFTER:
+            fsdp_model = fsdp_model.to(DEVICE_TYPE)
+        offload_params = cpu_offload is not None and cpu_offload.offload_params
+        # Offloading parameters with `DEVICE_AFTER` should raise an error during
+        # lazy initialization due to the parameter devices not being CPU;
+        # otherwise, all parameter devices should be CPU
+        expects_device_error = (
+            offload_params and device_init_mode == DEVICEInitMode.DEVICE_AFTER
+        )
+        expects_cpu_device = (
+            offload_params and device_init_mode != DEVICEInitMode.DEVICE_AFTER
+        )
+        if expects_cpu_device:
+            cpu_device = torch.device("cpu")
+            for param in fsdp_model.parameters():
+                self.assertEqual(param.device, cpu_device)
+        context = (
+            self.assertRaisesRegex(
+                RuntimeError,
+                "An FSDP-managed module with parameter CPU offloading enabled "
+                f"has parameters on {DEVICE_TYPE}",
+            )
+            if expects_device_error
+            else nullcontext()
+        )
+        with context:
+            fsdp_loss = self._train_for_several_steps(
+                fsdp_model,
+                num_iters,
+                autocast=False,
+                lr=lr,
+                fsdp_cpu_offload=cpu_offload,
+                save_model=save_model,
+                mixed_precision=mixed_precision,
+                enable_sharded_grad_scaler=enable_sharded_grad_scaler,
+                use_pure_fp16=use_pure_fp16,
+                sharded_grad_scaler_kwargs=sharded_grad_scaler_kwargs,
+            )
+        # No need to check for parameter and loss parity if expecting an error
+        if expects_device_error:
+            return
+        # Check parameter devices are CPU if offloading to CPU before calling
+        # `get_full_params()`, which will cast the parameters to FP32
+        if offload_params:
+            cpu_device = torch.device("cpu")
+            for param in fsdp_model.parameters():
+                self.assertEqual(param.device, cpu_device)
+            fsdp_loss = fsdp_loss.to(DEVICE_TYPE)
+        fsdp_unsharded_params = get_full_params(fsdp_model)
+        # Do not check dtype since the reference DDP loss may not be the same
+        # dtype as the FSDP loss in the case of mixed precision
+        torch.testing.assert_close(ref_loss, fsdp_loss, check_dtype=False)
+        # Do not check for parameter parity if using mixed precision since (1)
+        # the DDP parameters are in FP16 (from `half()`) while the FSDP
+        # parameters are in FP32 (from `summon_full_params()`) and (2) DDP runs
+        # the optimizer in FP16 while FSDP runs it in FP32
+        # TODO: Disable checking the parameters for pure FP16 due to floating
+        # point inaccuracy. Note that this means that the backward pass is not
+        # checked: https://github.com/pytorch/pytorch/issues/90784
+        if mixed_precision is None and not use_pure_fp16:
+            self.assertEqual(
+                ddp_params,
+                fsdp_unsharded_params,
+                exact_device=True,
+                msg="FSDP did not match DDP",
+            )
+
+
+def compiled_fsdp_test(compile_compute_on_module: Optional[type] = None):
+    def fully_shard_with_compiled_compute(*args, **kwargs):
+        torch.distributed.fsdp.fully_shard(*args, **kwargs)  # type: ignore[operator]
+        if compile_compute_on_module is None or isinstance(
+            args[0], compile_compute_on_module
+        ):
+            args[0].compile()
+
+    class FullyShardMode(Enum):
+        EAGER = auto()
+        COMPILED_COMPUTE = auto()
+
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            original_fully_shard: Any = torch.distributed.fsdp.fully_shard
+            for mode in FullyShardMode:
+                if mode != FullyShardMode.EAGER and not has_triton():
+                    warnings.warn("Inductor on GPU needs Triton and recent GPU arch")
+                    continue
+                # Barrier to ensure all threads read the same original values
+                original_skip_fsdp_hooks = torch._dynamo.config.skip_fsdp_hooks
+                original_compile_threads = torch._inductor.config.compile_threads
+                torch.distributed.barrier()
+
+                if mode == FullyShardMode.EAGER:
+                    fully_shard_patch = original_fully_shard
+                elif mode == FullyShardMode.COMPILED_COMPUTE:
+                    torch._dynamo.config.skip_fsdp_hooks = True
+                    torch._inductor.config.compile_threads = 1
+                    fully_shard_patch = fully_shard_with_compiled_compute  # type: ignore[assignment]
+                else:
+                    raise NotImplementedError(
+                        f"Need to implement FullyShardMode={mode}"
+                    )
+
+                # fully_shard is imported as a global
+                # through `from ... import fully_shard`
+                func.__globals__[original_fully_shard.__name__] = fully_shard_patch
+                func(*args, **kwargs)
+                # Let other threads finish using the patched func before this
+                # thread restores the original
+                torch.distributed.barrier()
+                func.__globals__[original_fully_shard.__name__] = original_fully_shard
+                torch._dynamo.config.skip_fsdp_hooks = original_skip_fsdp_hooks
+                torch._inductor.config.compile_threads = original_compile_threads
+
+        return wrapper
+
+    return decorator
+
+
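+# Illustrative sketch only: how a unit test might opt into running under both
+# eager and compiled-compute `fully_shard` via the decorator above. The test
+# class, module choice, and body are hypothetical; the class is defined inside
+# a function so that it is not picked up by test discovery.
+def _example_compiled_fsdp_test_usage():
+    class _ExampleTest(FSDPTest):
+        @compiled_fsdp_test(compile_compute_on_module=MLP)
+        def test_train_parity(self):
+            model = MLP(16, device=torch.device(DEVICE_TYPE))
+            # `fully_shard` is swapped per `FullyShardMode` by the decorator
+            fully_shard(model)
+            ...
+
+    return _ExampleTest
+
+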
+class SkipModule(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.lin = nn.Linear(10, 10, bias=False)
+
+    def forward(self, x):
+        return self.lin(x)
+
+
+class NestedLinear(nn.Module):
+    def __init__(self, fsdp_wrap):
+        super().__init__()
+        if fsdp_wrap:
+            self.nested_linear = wrap(nn.Linear(10, 10, bias=False).to(DEVICE_TYPE))
+        else:
+            self.nested_linear = nn.Linear(10, 10, bias=False).to(DEVICE_TYPE)
+
+    def forward(self, x):
+        return self.nested_linear(x)
+
+
+class SkipModel(nn.Module):
+    def __init__(self, double_nest):
+        super().__init__()
+        self.linear = nn.Linear(10, 10, bias=False).to(DEVICE_TYPE)
+        self.linear_skip = SkipModule().to(DEVICE_TYPE)
+        self.nested_linear = wrap(
+            NestedLinear(fsdp_wrap=double_nest), device_id=DEVICE_TYPE
+        )
+
+    def forward(self, x):
+        x = self.linear(x)
+        x = self.linear_skip(x)
+        x = self.nested_linear(x)
+        return x
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_jit.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_jit.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ca05c51189b558f101031af89863197a599ff9d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_jit.py
@@ -0,0 +1,323 @@
+# mypy: ignore-errors
+
+# Torch
+import torch
+import torch.cuda
+import torch.jit
+import torch.jit._logging
+import torch.jit.frontend
+import torch.jit.quantized
+
+# Testing utils
+from torch.testing._internal.common_dtype import floating_and_complex_types_and
+from torch.testing._internal.common_utils import TestCase, \
+    freeze_rng_state, TemporaryFileName, enable_profiling_mode_for_profiling_tests, is_iterable_of_tensors
+from torch.testing._internal.common_utils import enable_profiling_mode  # noqa: F401
+
+# Standard library
+from itertools import chain
+from typing import Union
+from torch._C import TensorType
+
+import io
+
+def check_output_types(self, func, ref_outputs, args, kwargs):
+    graph = getattr(func, 'last_graph', None)
+    types = [o.type() for o in graph.outputs()]
+    self.assertTrue(len(types) == 1)
+    t = types[0]
+    torch._C._jit_assert_is_instance(ref_outputs, t)
+
+# Test names in this set are only checked for a single derivative
+nn_functional_single_grad = frozenset('test_nn_' + name for name in [
+    'pdist',
+    'multilabel_margin_loss',
+    'max_unpool3d',
+    'multi_margin_loss',
+    'binary_cross_entropy',
+    'binary_cross_entropy_size_average',
+    'ctc_loss',
+    'grid_sample',
+])
+
+def check_against_reference(self, func, reference_func, output_func, args, kwargs=None,
+                            allow_unused=True, check_types=True, no_grad=False, no_gradgrad=False):
+    """Verifies a function performs identically to some reference implementation.
+
+    Commonly, this is used to verify that a JIT implementation
+    (output_func) matches the behavior of the eager implementation
+    (reference_func).
+    """
+    kwargs = kwargs if kwargs else {}
+
+    def allSum(vs):
+        if isinstance(vs, torch.Tensor):
+            vs = (vs,)
+        return sum((i + 1) * v.sum().abs() if v.dtype.is_complex else (i + 1) * v.sum()
+                   for i, v in enumerate(vs)
+                   if v is not None and v.dtype in floating_and_complex_types_and(torch.half, torch.bfloat16))
+
+    def clone_tensor(t, preserve_requires_grad):
+        require_grad = preserve_requires_grad and t.requires_grad
+        return t.detach().clone().requires_grad_(require_grad)
+
+    def clone_inputs(preserve_requires_grad: bool):
+        inputs: list[Union[torch.Tensor, list[torch.Tensor]]] = []
+
+        for arg in args:
+            if isinstance(arg, torch.Tensor):
+                inputs.append(clone_tensor(arg, preserve_requires_grad))
+            elif is_iterable_of_tensors(arg):
+                inputs.append([clone_tensor(t, preserve_requires_grad) for t in arg])
+            else:
+                inputs.append(arg)
+
+        return inputs
+
+    # Returns tensors in args that require grad, including tensors in TensorList args
+    def get_recording_tensors(args):
+        recording_tensors: list[torch.Tensor] = []
+
+        for arg in args:
+            if isinstance(arg, torch.Tensor) and arg.requires_grad:
+                recording_tensors.append(arg)
+            elif is_iterable_of_tensors(arg):
+                recording_tensors.extend(filter(lambda t: t.requires_grad, arg))
+
+        return recording_tensors
+
+    # test no gradients case
+    nograd_inputs = clone_inputs(preserve_requires_grad=False)
+    outputs = self.runAndSaveRNG(reference_func, nograd_inputs, kwargs)
+    with enable_profiling_mode_for_profiling_tests():
+        outputs_test = self.runAndSaveRNG(func, nograd_inputs, kwargs)
+    self.assertEqual(outputs, outputs_test)
+
+    if check_types:
+        check_output_types(self, func, outputs_test, nograd_inputs, kwargs)
+
+    if no_grad:
+        # skip grad tests
+        return
+
+    with enable_profiling_mode_for_profiling_tests():
+        # test single grad case
+        recording_inputs = clone_inputs(preserve_requires_grad=True)
+        recording_tensors = get_recording_tensors(recording_inputs)
+        outputs = output_func(self.runAndSaveRNG(reference_func, recording_inputs, kwargs))
+        grads = torch.autograd.grad(allSum(outputs), recording_tensors,
+                                    allow_unused=allow_unused)
+        outputs_test = output_func(self.runAndSaveRNG(func, recording_inputs, kwargs))
+        grads_test = torch.autograd.grad(allSum(outputs_test), recording_tensors,
+                                         allow_unused=allow_unused)
+        self.assertEqual(outputs, outputs_test)
+        self.assertEqual(grads, grads_test)
+        # test the grad grad case
+        if self._testMethodName in nn_functional_single_grad or no_gradgrad:
+            return
+
+        outputs = output_func(self.runAndSaveRNG(reference_func, recording_inputs, kwargs))
+        l1 = allSum(outputs)
+        grads = torch.autograd.grad(l1, recording_tensors, create_graph=True,
+                                    allow_unused=allow_unused)
+
+        l2 = (allSum(grads) * l1)
+        grads2 = torch.autograd.grad(l2, recording_tensors, allow_unused=allow_unused)
+        recording_inputs = clone_inputs(preserve_requires_grad=True)
+        recording_tensors = get_recording_tensors(recording_inputs)
+        outputs_test = output_func(self.runAndSaveRNG(func, recording_inputs, kwargs))
+        l1_test = allSum(outputs_test)
+        grads_test = torch.autograd.grad(
+            l1_test, recording_tensors, create_graph=True, allow_unused=allow_unused)
+
+        l2_test = (allSum(grads_test) * l1_test)
+        grads2_test = torch.autograd.grad(l2_test, recording_tensors, allow_unused=allow_unused)
+
+        self.assertEqual(outputs, outputs_test)
+        self.assertEqual(grads, grads_test)
+        for g2, g2_test in zip(grads2, grads2_test):
+            if g2 is None and g2_test is None:
+                continue
+            self.assertEqual(g2, g2_test, atol=5e-4, rtol=1e-4)
+
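+# Illustrative sketch only: how a JIT test might call `check_against_reference`
+# to compare a scripted function against its eager reference. The test-case
+# instance (`self`, a `JitCommonTestCase`), the function, and the inputs are
+# hypothetical; `check_types=False` is passed since this example does not run
+# under the profiling executor.
+def _example_check_against_reference(self):
+    def eager_fn(x, y):
+        return (x * y).relu()
+
+    scripted_fn = torch.jit.script(eager_fn)
+    args = (torch.randn(3, requires_grad=True), torch.randn(3, requires_grad=True))
+    # `output_func` is the identity since `eager_fn` returns a single tensor
+    check_against_reference(
+        self, scripted_fn, eager_fn, lambda outs: outs, args, check_types=False
+    )
+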
+class JitCommonTestCase(TestCase):
+    def createFunctionFromGraph(self, trace):
+        graph = trace if isinstance(trace, torch._C.Graph) else trace.graph()
+        return torch._C._create_function_from_graph("forward", graph)
+
+    def assertExportImport(self, trace, inputs):
+        m = self.createFunctionFromGraph(trace)
+        self.assertExportImportModule(m, inputs)
+
+    def assertExportImportModule(self, m, inputs):
+        m_import = self.getExportImportCopy(m)
+        a = self.runAndSaveRNG(m, inputs)
+        b = self.runAndSaveRNG(m_import, inputs)
+        self.assertEqual(a, b, "Results of original model and "
+                               "exported/imported version of model differed")
+
+    def runAndSaveRNG(self, func, inputs, kwargs=None):
+        kwargs = kwargs if kwargs else {}
+        with freeze_rng_state():
+            results = func(*inputs, **kwargs)
+        return results
+
+    def getExportImportCopy(self, m, also_test_file=True, map_location=None):
+        buffer = io.BytesIO()
+        torch.jit.save(m, buffer)
+        buffer.seek(0)
+        imported = torch.jit.load(buffer, map_location=map_location)
+
+        if not also_test_file:
+            return imported
+
+        with TemporaryFileName() as fname:
+            torch.jit.save(imported, fname)
+            return torch.jit.load(fname, map_location=map_location)
+
+    def autoDiffErrorMessage(self, should_autodiff_node, nodes_not_in_diff_graph,
+                             fusion_nodes_not_found, non_fusible_nodes_being_fused,
+                             fusion_nodes_found, nodes_in_diff_graph):
+        err_msg = "\nFailure in testing nodes' autodifferentiation. "
+        if should_autodiff_node:
+            err_msg += "One or more nodes were expected to be autodiffed, " \
+                "but were not found in specified fusible/nonfusible " \
+                "DifferentiableGraph groups. \nSpecifically:"
+            # The node is intended to appear in a differentiable graph but doesn't
+            diff_nodes_missing = []
+            # The node is intended to appear in a differentiable graph
+            # outside of a fusion group but instead is in a fusion group
+            diff_nodes_in_fusion = []
+            # The node is intended to appear in a fusion group but doesn't
+            fusion_nodes_missing = []
+            # The node is intended to appear in a fusion group but instead
+            # is just in an outer differentiable graph
+            fusion_nodes_in_diff = []
+            for node in nodes_not_in_diff_graph:
+                if node in non_fusible_nodes_being_fused:
+                    diff_nodes_in_fusion.append(node)
+                else:
+                    diff_nodes_missing.append(node)
+            for node in fusion_nodes_not_found:
+                if node in nodes_in_diff_graph:
+                    fusion_nodes_in_diff.append(node)
+                else:
+                    fusion_nodes_missing.append(node)
+            if len(diff_nodes_missing) > 0:
+                err_msg += f"\n  {diff_nodes_missing} were not in one of the " \
+                    "DifferentiableGraphs when they were expected to be. " \
+                    "Did you intend for these nodes to be autodiffed? " \
+                    "If not, remove them from the list of nonfusible nodes."
+            if len(diff_nodes_in_fusion) > 0:
+                err_msg += f"\n  {diff_nodes_in_fusion} were found in one of the FusionGroups " \
+                    "when they were expected to be just in a DifferentiableGraph. If it was " \
+                    "intended for these nodes to be in FusionGroups, reclassify these nodes as " \
+                    "fusible nodes. If these nodes were not intended to be fused, your " \
+                    "autodifferentiation logic might be wrong."
+            if len(fusion_nodes_missing) > 0:
+                err_msg += f"\n  {fusion_nodes_missing} were not in one of the FusionGroups " \
+                    "of the DifferentiableGraphs when they were expected to be. " \
+                    "They were also not found in an outer DifferentiableGraph. Did you " \
+                    "intend for these nodes to be autodifferentiated? If not, you should " \
+                    "remove these nodes from the test's fusible nodes. Otherwise your " \
+                    "autodifferentiation logic might be wrong."
+            if len(fusion_nodes_in_diff) > 0:
+                err_msg += f"\n  {fusion_nodes_in_diff} were not in one of the FusionGroups " \
+                    "of the DifferentiableGraphs when they were expected to be, " \
+                    "instead they were found just in an outer DifferentiableGraph. " \
+                    "Did you intend for these nodes to be fused? If not, you should " \
+                    "move these nodes into the test's nonfusible nodes. Otherwise your " \
+                    "autodifferentiation logic might be wrong."
+        else:
+            err_msg += "One or more nodes were not expected to be autodiffed " \
+                "but were found in a DifferentiableGraph or in a FusionGroup " \
+                "of a DifferentiableGraph. Did you intend for these nodes to be " \
+                "autodiffed? If so, change this test to expect autodifferentiation. " \
+                "\nSpecifically:"
+            if len(fusion_nodes_found) > 0:
+                err_msg += f"\n  {fusion_nodes_found} were not expected to be in " \
+                    "one of the DifferentiableGraphs, but appeared in a FusionGroup " \
+                    "of a DifferentiableGraph. "
+            if len(nodes_in_diff_graph) > 0:
+                err_msg += f"\n  {nodes_in_diff_graph} were not expected to " \
+                    "be in one of the DifferentiableGraphs but were."
+        return err_msg
+
+    def assertAutodiffNode(self, graph, should_autodiff_node, nonfusible_nodes, fusible_nodes):
+        diff_nodes = graph.findAllNodes('prim::DifferentiableGraph')
+        diff_subgraphs = [node.g('Subgraph') for node in diff_nodes]
+
+        # Note: currently no tests have fusible_nodes
+        fusion_nodes = list(chain.from_iterable([g.findAllNodes('prim::FusionGroup') for g in diff_subgraphs]))
+        fusion_subgraphs = [node.g('Subgraph') for node in fusion_nodes]
+
+        # For any non-fusible node, it must show up in one of the DifferentiableGraphs.
+        nodes_in_diff_graph = []
+        nodes_not_in_diff_graph = []
+        non_fusible_nodes_being_fused = []
+        for node in nonfusible_nodes:
+            if any(g.findNode(node) is not None for g in diff_subgraphs):
+                nodes_in_diff_graph.append(node)
+            else:
+                nodes_not_in_diff_graph.append(node)
+            if any(g.findNode(node) is not None for g in fusion_subgraphs):
+                non_fusible_nodes_being_fused.append(node)
+        found_all_nonfusible_nodes = len(nodes_in_diff_graph) == len(nonfusible_nodes)
+
+        # For any fusible node, it must show up in one of the FusionGroups in one of the DifferentiableGraphs.
+        fusion_nodes_found = []
+        fusion_nodes_not_found = []
+        for node in fusible_nodes:
+            if any(g.findNode(node) is not None for g in fusion_subgraphs):
+                fusion_nodes_found.append(node)
+            else:
+                fusion_nodes_not_found.append(node)
+        found_all_fusible_nodes = len(fusion_nodes_found) == len(fusible_nodes)
+
+        if should_autodiff_node is not None:
+            err_msg = self.autoDiffErrorMessage(should_autodiff_node,
+                                                nodes_not_in_diff_graph,
+                                                fusion_nodes_not_found,
+                                                non_fusible_nodes_being_fused,
+                                                fusion_nodes_found,
+                                                nodes_in_diff_graph)
+            self.assertEqual(should_autodiff_node,
+                             found_all_nonfusible_nodes and found_all_fusible_nodes, err_msg)
+
+    def checkShapeAnalysis(self, out_sizes: Union[list[int], list[list[int]]],
+                           traced_graph, assert_propagation, constant_prop=True):
+        # Re-propagate input shapes provided by tracing.
+        prev_symbolic_shapes_test_enabled = torch._C._jit_symbolic_shapes_test_mode_enabled()
+        for enable_test_mode in [True, False]:
+            # Here we test both allowing and disallowing substituting complete shapes in
+            # as constants; disallowing constants helps stress-test the partial-eval and
+            # substitution pipeline
+            torch._C._jit_set_symbolic_shapes_test_mode(enable_test_mode)
+            torch._C._jit_erase_non_input_shape_information(traced_graph)
+            if constant_prop:
+                torch._C._jit_pass_constant_propagation(traced_graph)
+            torch._C._jit_pass_propagate_shapes_on_graph(traced_graph)
+            # Add sizes to the default tensor type to avoid checking something out of
+            # scope and to avoid difficulties with sizes the tracer leaves in other
+            # parts of the tensor type
+            output = next(traced_graph.outputs()).type()
+
+            def test_type(type, actual_size):
+                sizes = type.symbolic_sizes()
+                out_type = TensorType.get().with_sizes(sizes)
+                actual_type = TensorType.get().with_sizes(actual_size)
+
+                # always check actual shape is a subtype of the output
+                self.assertTrue(actual_type.isSubtypeOf(out_type))
+
+                # and then if assertion flag is provided, check shape analysis
+                # is successful
+                if assert_propagation:
+                    self.assertEqual(out_type.sizes(), actual_size)
+
+            if output.isSubtypeOf(torch._C.TensorType.get()):
+                test_type(output, out_sizes)
+            else:
+                tuple_elements = output.elements()
+                for i in range(len(tuple_elements)):
+                    test_type(tuple_elements[i], out_sizes[i])
+
+        torch._C._jit_set_symbolic_shapes_test_mode(prev_symbolic_shapes_test_enabled)
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_methods_invocations.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_methods_invocations.py
new file mode 100644
index 0000000000000000000000000000000000000000..351dcf159fae7c10830febab43ed77e41149215d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_methods_invocations.py
@@ -0,0 +1,24912 @@
+# mypy: ignore-errors
+
+from functools import wraps, partial
+from itertools import product, chain, islice
+import itertools
+import functools
+import copy
+import operator
+import random
+import unittest
+import math
+import enum
+
+import torch
+import numpy as np
+import numpy.typing as npt
+from torch import inf, nan
+
+from typing import Any, Union
+from collections.abc import Sequence
+from torch.testing import make_tensor
+from torch.testing._internal.common_dtype import (
+    _dispatch_dtypes, floating_types, floating_types_and, complex_types, floating_and_complex_types,
+    floating_and_complex_types_and, all_types_and_complex_and, all_types_and, all_types_and_complex, integral_types_and,
+    empty_types, complex_types_and, integral_types, custom_types, all_types_complex_float8_and, float8_types,
+)
+from torch.testing._internal.common_device_type import \
+    (onlyCPU, onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver,
+     skipCUDAIfNoCusolver, skipCPUIfNoLapack, skipCPUIfNoFFT, skipCUDAIf, precisionOverride,
+     skipCPUIfNoMklSparse,
+     toleranceOverride, tol)
+from torch.testing._internal.common_cuda import (
+    PLATFORM_SUPPORTS_FLASH_ATTENTION, PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
+    SM53OrLater, SM80OrLater, SM89OrLater, with_tf32_off, TEST_CUDNN, _get_torch_cuda_version,
+    _get_torch_rocm_version,
+)
+from torch.testing._internal.common_utils import (
+    make_fullrank_matrices_with_distinct_singular_values,
+    TEST_WITH_ROCM, IS_FBCODE, IS_WINDOWS, IS_MACOS, IS_S390X, TEST_SCIPY,
+    torch_to_numpy_dtype_dict, numpy_to_torch_dtype, TEST_WITH_ASAN,
+    GRADCHECK_NONDET_TOL, slowTest, TEST_WITH_SLOW,
+    TEST_WITH_TORCHINDUCTOR, MACOS_VERSION
+)
+from torch.testing._utils import wrapper_set_seed
+
+import torch._refs as refs  # noqa: F401
+import torch._refs.nn.functional
+import torch._refs.special
+import torch._refs.linalg
+import torch._prims as prims  # noqa: F401
+from torch.utils import _pytree as pytree
+
+
+from torch._vendor.packaging import version
+
+from torch.testing._internal.opinfo.core import (  # noqa: F401
+    L,
+    M,
+    S,
+    XS,
+    _NOTHING,
+    _getattr_qual,
+    DecorateInfo,
+    SampleInput,
+    ErrorInput,
+    AliasInfo,
+    NumericsFilter,
+    OpInfo,
+    _generate_reduction_inputs,
+    _generate_reduction_kwargs,
+    sample_inputs_reduction,
+    ReductionOpInfo,
+    reference_inputs_elementwise_binary,
+    make_error_inputs_elementwise_binary,
+    generate_elementwise_binary_tensors,
+    generate_elementwise_binary_arbitrarily_strided_tensors,
+    generate_elementwise_binary_small_value_tensors,
+    generate_elementwise_binary_large_value_tensors,
+    generate_elementwise_binary_extremal_value_tensors,
+    generate_elementwise_binary_broadcasting_tensors,
+    generate_elementwise_binary_with_scalar_samples,
+    generate_elementwise_binary_with_scalar_and_type_promotion_samples,
+    generate_elementwise_binary_noncontiguous_tensors,
+    sample_inputs_elementwise_binary,
+    BinaryUfuncInfo,
+    sample_inputs_elementwise_unary,
+    generate_elementwise_unary_tensors,
+    generate_elementwise_unary_small_value_tensors,
+    generate_elementwise_unary_large_value_tensors,
+    generate_elementwise_unary_extremal_value_tensors,
+    reference_inputs_elementwise_unary,
+    UnaryUfuncInfo,
+    sample_inputs_spectral_ops,
+    SpectralFuncType,
+    SpectralFuncInfo,
+    ShapeFuncInfo,
+    sample_inputs_foreach,
+    ForeachFuncInfo,
+    gradcheck_wrapper_hermitian_input,
+    gradcheck_wrapper_ctc_loss,
+    gradcheck_wrapper_triangular_input,
+    gradcheck_wrapper_triangular_input_real_positive_diagonal,
+    gradcheck_wrapper_masked_operation,
+    gradcheck_wrapper_masked_pointwise_operation,
+    clone_sample,
+)
+from torch.testing._internal.opinfo.refs import (  # NOQA: F401
+    _find_referenced_opinfo,
+    _inherit_constructor_args,
+    PythonRefInfo,
+    ReductionPythonRefInfo,
+    ElementwiseUnaryPythonRefInfo,
+    ElementwiseBinaryPythonRefInfo,
+)
+from torch.testing._internal.opinfo.utils import (
+    np_unary_ufunc_integer_promotion_wrapper,
+    reference_reduction_numpy,
+    prod_numpy
+)
+from torch.testing._internal import opinfo
+from torch.testing._internal.opinfo.definitions.linalg import (
+    sample_inputs_linalg_cholesky,
+    sample_inputs_linalg_cholesky_inverse,
+    sample_inputs_cross,
+    sample_inputs_linalg_qr_geqrf,
+    sample_inputs_linalg_invertible,
+    sample_inputs_lu_solve,
+    sample_inputs_legacy_solve,
+    sample_inputs_svd,
+    sample_inputs_linalg_det_logdet_slogdet,
+    sample_inputs_linalg_lu,
+    sample_inputs_diagonal_diag_embed,
+    error_inputs_diagonal_diag_embed,
+)
+from torch.testing._internal.opinfo.definitions.special import (
+    sample_inputs_i0_i1,
+    sample_inputs_polygamma,
+    reference_polygamma,
+)
+from torch.testing._internal.opinfo.definitions._masked import (
+    sample_inputs_softmax_variant,
+)
+from torch.testing._internal.opinfo.definitions.sparse import (
+    error_inputs_sparse_like_fns,
+    sample_inputs_sparse_like_fns,
+    error_inputs_sparse_mul,
+    sample_inputs_sparse_mul,
+    error_inputs_sparse_reduction_sum,
+    sample_inputs_sparse_reduction_sum
+)
+
+if TEST_SCIPY:
+    from scipy import stats
+    import scipy.spatial
+    import scipy.special
+
+
+# Test elementwise whether a tensor's values are close to integers
+def close_to_int(x, eps=0.1):
+    if x.is_complex():
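+        # Combine the fractional parts of the real and imaginary components via the complex magnitude.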
+        y = torch.abs(torch.view_as_complex(torch.frac(torch.view_as_real(x))))
+    else:
+        y = torch.abs(torch.frac(x))
+    return (y < eps) | (y > (1 - eps))
+
+
+def sample_inputs_slice(op_info, device, dtype, requires_grad, **kwargs):
+
+    make_input = partial(make_tensor, device=device, dtype=dtype,
+                         low=None, high=None, requires_grad=requires_grad)
+
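+    # Slice along a dimension with positive and negative start/end bounds and an optional step.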
+    yield SampleInput(make_input(3), 0)
+
+    yield SampleInput(make_input(20, 30, 40), dim=1, start=1, end=-2)
+
+    yield SampleInput(make_input(20, 30, 40), dim=1, start=1, end=-2, step=3)
+
+    yield SampleInput(make_input(20, 30, 40), dim=0, start=-10, end=-2, step=2)
+
+
+def sample_inputs_tensor_split(op_info, device, dtype, requires_grad, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype,
+                         low=None, high=None, requires_grad=requires_grad)
+
+    args_cases = (
+        # Cases with tensor indices.
+        (torch.tensor([1, 2, 3]),),
+        (torch.tensor(1),),
+        (torch.tensor([1, 2, 3]), 1),
+        (torch.tensor([1, 4, 2, 5, 3, 6])[::2], 1),
+        # Cases with list of indices.
+        ((2, 4),),
+        ((2, 4), 1),
+        ((2, 4), -1),
+        # Cases with integer section.
+        (3,),
+        (3, 1),
+        (3, -1),
+    )
+
+    for args in args_cases:
+        yield SampleInput(make_input((S, S, S)), args=args)
+
+
+def sample_inputs_hsplit(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device,
+                       low=None, high=None, requires_grad=requires_grad)
+    yield SampleInput(make_arg(6), 2)
+    yield SampleInput(make_arg(S, S, S), [1, 2, 3])
+
+def sample_inputs_vsplit(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device,
+                       low=None, high=None, requires_grad=requires_grad)
+    yield SampleInput(make_arg(6, S), 2)
+    yield SampleInput(make_arg(S, S, S), [1, 2, 3])
+
+def sample_inputs_dsplit(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device,
+                       low=None, high=None, requires_grad=requires_grad)
+    yield SampleInput(make_arg(S, S, S), [1, 2, 3])
+    yield SampleInput(make_arg(S, S, 6), 2)
+
+def error_inputs_hsplit(op_info, device, **kwargs):
+    make_arg = partial(make_tensor, dtype=torch.float32, device=device)
+    err_msg1 = ("torch.hsplit requires a tensor with at least 1 dimension, "
+                "but got a tensor with 0 dimensions!")
+    yield ErrorInput(SampleInput(make_arg(()), 0), error_regex=err_msg1)
+
+    err_msg2 = (f"torch.hsplit attempted to split along dimension 1, "
+                f"but the size of the dimension {S} "
+                f"is not divisible by the split_size 0!")
+    yield ErrorInput(SampleInput(make_arg((S, S, S)), 0), error_regex=err_msg2)
+
+    # Incorrect type for indices_or_section argument
+    err_msg3 = ("received an invalid combination of arguments.")
+    yield ErrorInput(
+        SampleInput(make_arg((S, S, S)), "abc"),
+        error_type=TypeError, error_regex=err_msg3)
+
+def error_inputs_vsplit(op_info, device, **kwargs):
+    make_arg = partial(make_tensor, dtype=torch.float32, device=device)
+    err_msg1 = ("torch.vsplit requires a tensor with at least 2 dimension, "
+                "but got a tensor with 1 dimensions!")
+    yield ErrorInput(SampleInput(make_arg(S), 0), error_regex=err_msg1)
+
+    err_msg2 = (f"torch.vsplit attempted to split along dimension 0, "
+                f"but the size of the dimension {S} "
+                f"is not divisible by the split_size 0!")
+    yield ErrorInput(SampleInput(make_arg(S, S, S), 0),
+                     error_regex=err_msg2)
+
+    # Incorrect type for indices_or_section argument
+    err_msg3 = ("received an invalid combination of arguments.")
+    yield ErrorInput(SampleInput(make_arg(S, S, S), "abc"),
+                     error_type=TypeError, error_regex=err_msg3)
+
+def error_inputs_dsplit(op_info, device, **kwargs):
+    make_arg = partial(make_tensor, dtype=torch.float32, device=device)
+    err_msg1 = ("torch.dsplit requires a tensor with at least 3 dimension, "
+                "but got a tensor with 1 dimensions!")
+    yield ErrorInput(SampleInput(make_arg(S), 0), error_regex=err_msg1)
+
+    err_msg2 = (f"torch.dsplit attempted to split along dimension 2, "
+                f"but the size of the dimension {S} "
+                f"is not divisible by the split_size 0!")
+    yield ErrorInput(SampleInput(make_arg(S, S, S), 0), error_regex=err_msg2)
+
+
+def sample_inputs_as_strided(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # input shape, output shape, output stride, output storage offset
+    test_cases = (
+        ((1,), (1,), (1,), 0),
+        ((3, 3), (2, 2), (1, 2), 0),
+        ((3, 3), (2, 2), (1, 2), 1),
+        ((16,), (2, 2, 2, 2), (1, 1, 1, 1), 0),
+        ((16,), (2, 1, 1, 2), (1, 7, 7, 1), 0),
+    )
+
+    for input_shape, output_shape, stride, storage_offset in test_cases:
+        input_t = make_arg(input_shape)
+        kwargs = dict(storage_offset=storage_offset)
+        yield SampleInput(input_t, args=(output_shape, stride), kwargs=kwargs)
+
+def sample_inputs_as_strided_partial_views(op_info, device, dtype, requires_grad, **kwargs):
+    def make_arg():
+        base = make_tensor((20,), device=device, dtype=dtype)
+        return base[5:15].requires_grad_(requires_grad)
+
+    # as_strided on offset, partial views
+    yield SampleInput(make_arg(), (2, 2), (1, 2))
+    yield SampleInput(make_arg(), (2, 2), (1, 2), storage_offset=0)
+    yield SampleInput(make_arg(), (2, 2), (1, 2), storage_offset=10)
+
+def sample_inputs_as_strided_scatter(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # input shape, output shape, output stride, output storage offset
+    test_cases = [
+        ((1,), (), (), 0),
+        ((1,), (1,), (1,), 0),
+        ((3, 3), (2, 2), (1, 2), 0),
+        ((3, 3), (2, 2), (1, 2), 1),
+        ((3, 3), (2, 2), (2, 1), 0),
+        # Scatter to larger dimensions
+        ((16,), (2, 2, 2, 2), (8, 4, 2, 1), 0),
+        # Scatter to larger dimensions with strides inverted
+        ((16,), (2, 1, 1, 2), (1, 2, 4, 8), 0),
+    ]
+
+    for input_shape, output_shape, stride, storage_offset in test_cases:
+        input_t = make_arg(input_shape)
+        input_src = make_arg(output_shape)
+        yield SampleInput(input_t, input_src, output_shape, stride, storage_offset=storage_offset)
+
+
+def error_inputs_as_strided_scatter(op_info, device, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=torch.float32, requires_grad=False)
+
+    # Create a small tensor and try to scatter it out of bounds
+    input_t = make_arg([4, 4])
+    input_src = make_arg([2, 2])
+    yield ErrorInput(
+        SampleInput(input_t, input_src, [2, 2], [200, 200], storage_offset=0),
+        error_regex="itemsize 4 requiring a storage size of 1604 are out of bounds for storage of size 64"
+    )
+
+
+def sample_inputs_combinations(op_info, device, dtype, requires_grad, **kwargs):
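+    # Samples for torch.combinations: 1-D inputs of increasing length, several r values,
+    # with and without replacement.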
+    inputs = (
+        (0,),
+        (0, 1),
+        (0, 1, 2, 3),
+    )
+
+    rvals = [1, 2, 4]
+
+    products = product(inputs, rvals, [False, True])
+
+    for input_data, r, with_replacement in products:
+        input_t = torch.tensor(input_data, device=device, dtype=dtype, requires_grad=requires_grad)
+        yield SampleInput(input_t, r=r, with_replacement=with_replacement)
+
+def sample_inputs_cartesian_prod(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(torch.tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # construct 1-D tensors with varying numbers of elements
+    a = make_arg((0,))
+    b = make_arg((0, 1))
+    c = make_arg((0, 1, 2, 3))
+
+    # sample with only 1 tensor
+    yield SampleInput(a)
+
+    # sample with 2 tensors
+    yield SampleInput(a, b)
+
+    # sample with 3 tensors
+    yield SampleInput(a, b, c)
+
+def sample_inputs_cosine_similarity(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # Ordered as input_shape, dict of dim and eps
+    cases: tuple[tuple, dict] = (  # type: ignore[assignment]
+        ((S, S), {'dim': 1}),
+        ((S, 2), {'dim': -1}),
+        ((S,), {'dim': 0, 'eps': 0.5}),
+        ((), {'dim': 0}),
+        ((S, S, M), {'dim': 2}),
+        ((S, S), {})
+    )
+
+    for input_shape, kwargs in cases:
+        yield SampleInput(make_arg(input_shape), args=(make_arg(input_shape),), kwargs=kwargs)
+    # Test for Broadcasting
+    yield SampleInput(make_arg((1, 2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -1})
+    yield SampleInput(make_arg((1, 2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -2})
+    yield SampleInput(make_arg((2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -1})
+
+
+def sample_inputs_item(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False)
+
+    cases = (
+        (),
+        (()),
+        (1),
+        ((1,)),
+    )
+
+    for shape in cases:
+        yield SampleInput(make_arg(shape))
+
+def error_inputs_item(op, device, **kwargs):
+    make_arg = partial(make_tensor, dtype=torch.float32, device=device, requires_grad=False)
+
+    cases = (
+        (M),
+        ((S,)),
+        (S, S),
+        (S, M, L),
+    )
+
+    for shape in cases:
+        yield ErrorInput(
+            SampleInput(make_arg(shape)), error_type=RuntimeError,
+            error_regex="elements cannot be converted to Scalar")
+
+
+def sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_arg_without_requires_grad = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
+
+    # Ordered as: input shape, kwargs for training, momentum, eps
+    cases: tuple[tuple[int], dict] = (  # type: ignore[assignment]
+        ((S, S, S), {'training': True, 'momentum': 0.5, 'eps': 0.6}),
+        ((3, 2, 4), {'training': False, 'momentum': -1.2}),
+        ((3, 1), {'training': True, 'momentum': 0.0}),
+        ((0,), {'training': True}),
+        ((0,), {'training': False}),
+        ((3, 2, 3, 4), {'training': True, 'momentum': -1.0, 'eps': 0.5}),
+        ((3, 2, 3, 4), {'training': False, 'momentum': -1.0, 'eps': 0.5}),
+        ((2, 1), {}),
+    )
+
+    for input_shape, kwargs in cases:
+        # args: running_mean, running_var, weight, and bias must all have shape (channels,)
+        channels = input_shape[1] if len(input_shape) > 1 else 0
+        weight = make_arg(channels) if channels > 0 else None
+        bias = make_arg(channels) if channels > 0 else None
+        running_mean = make_arg_without_requires_grad(channels, low=0)
+        running_var = make_arg_without_requires_grad(channels, low=0)
+
+        yield SampleInput(
+            make_arg(input_shape),
+            args=(
+                running_mean,
+                running_var,
+                weight,
+                bias
+            ),
+            kwargs=kwargs
+        )
+
+    # Checking for permutations of weights and biases as `None`
+    weights = [channels, None, None]
+    biases = [None, channels, None]
+    is_training = [True, False, False]
+
+    for weight, bias, training in zip(weights, biases, is_training):
+        yield SampleInput(
+            make_arg(input_shape),
+            args=(
+                running_mean,
+                running_var,
+                make_arg(channels) if weight is not None else None,
+                make_arg(channels) if bias is not None else None
+            ),
+            kwargs={'training': training}
+        )
+
+    # Test case for no optional kwargs
+    # running_mean and running_var are required in evaluation mode (training: False) but not in training mode
+    yield SampleInput(make_arg((1, 2, 3)), args=(None, None, None, None), kwargs={'training': True})
+
+def sample_inputs_softmax_backward_data(op_info, device, dtype, requires_grad, **kwargs):
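+    # Each sample packs a fresh tensor (the incoming gradient), the softmax output, the dim,
+    # and the input dtype; on CUDA with float inputs we also exercise torch.float16 as the input dtype.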
+    make_arg = partial(
+        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+    cases = [
+        ((S,), 0),
+        ((S, S), 0),
+        ((S, M, S), -1),
+    ]
+    input_dtypes = [dtype]
+    if dtype == torch.float and device == 'cuda':
+        input_dtypes += [torch.float16]
+
+    for (shape, dim), input_dtype in product(cases, input_dtypes):
+        input = make_arg(shape)
+        output = torch.nn.functional.softmax(input, dim=dim, dtype=input_dtype)
+        yield SampleInput(make_arg(shape), output, dim, input_dtype)
+
+def sample_inputs_native_batch_norm(op_info, device, dtype, requires_grad, **kwargs):
+    samples = sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs)
+    for sample in samples:
+        # torch.native_batch_norm does not support 0 numel tensors
+        # IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
+        if sample.input.numel() == 0:
+            continue
+        args = sample.args
+        training = sample.kwargs.get('training', True)
+        momentum = sample.kwargs.get('momentum', 0.5)
+        eps = sample.kwargs.get('eps', 1e-5)
+        yield SampleInput(sample.input, args=(args[2], args[3], args[0], args[1], training, momentum, eps))
+
+
+def sample_inputs__native_batch_norm_legit(op_info, device, dtype, requires_grad, **kwargs):
+    samples = sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs)
+    for sample in samples:
+        # torch.native_batch_norm does not support 0 numel tensors
+        # IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
+        if sample.input.numel() == 0:
+            continue
+        args = sample.args
+        training = sample.kwargs.get('training', True)
+        momentum = sample.kwargs.get('momentum', 0.5)
+        eps = sample.kwargs.get('eps', 1e-5)
+        if args[0] is not None and args[1] is not None:
+            yield SampleInput(sample.input, args=(args[2], args[3], args[0], args[1], training, momentum, eps))
+        else:
+            yield SampleInput(sample.input, args=(args[2], args[3], training, momentum, eps))
+
+def sample_inputs__batch_norm_with_update(op_info, device, dtype, requires_grad, **kwargs):
+    samples = sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs)
+    for sample in samples:
+        # torch.native_batch_norm does not support 0 numel tensors
+        # IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
+        if sample.input.numel() == 0:
+            continue
+        args = sample.args
+        momentum = sample.kwargs.get('momentum', 0.5)
+        eps = sample.kwargs.get('eps', 1e-5)
+        if any(args[i] is None for i in range(4)):
+            continue
+        yield SampleInput(sample.input, args=(args[2], args[3], args[0], args[1], momentum, eps))
+
+def sample_inputs_nn_activation_relu(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    cases = (
+        (()),
+        ((S, )),
+        ((S, S)),
+        ((S, M, S))
+    )
+
+    for shape in cases:
+        yield SampleInput(make_arg(shape))
+
+def sample_inputs_prelu(op_info, device, dtype, requires_grad, **kwargs):
+    op_kwargs = op_info.sample_kwargs(device, dtype, None)[0]
+    yield from sample_inputs_elementwise_unary(op_info, device, dtype, requires_grad,
+                                               op_kwargs=op_kwargs)
+
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    cases = (
+        (()),
+        ((S, )),
+        ((S, S)),
+        ((S, M, S))
+    )
+
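+    # Pair each input shape with several scalar weights and with a per-channel 1-D weight tensor.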
+    for shape in cases:
+        for weight in [-1., 0., 0.8, 1.]:
+            weight_tensor = torch.tensor(weight, device=device, dtype=dtype, requires_grad=requires_grad)
+            yield SampleInput(make_arg(shape), args=(weight_tensor,))
+
+        channel_size = shape[1] if len(shape) >= 2 else 1
+        yield SampleInput(make_arg(shape), args=(make_arg((channel_size,)),))
+
+    weight_tensor = torch.tensor(1., device=device, dtype=dtype, requires_grad=requires_grad)
+
+    yield SampleInput(make_arg((S, S)), kwargs=dict(weight=weight_tensor,))
+    yield SampleInput(make_arg((S, S)), kwargs=dict(weight=make_arg((S,)),))
+
+def reference_inputs_prelu(op, device, dtype, requires_grad, **kwargs):
+    yield from sample_inputs_prelu(op, device, dtype, requires_grad, **kwargs)
+    yield from reference_inputs_elementwise_unary(op, device, dtype, requires_grad, **kwargs)
+
+def sample_kwargs_prelu_scalar_weight(device, dtype, input):
+    weight = torch.rand((), device=device, dtype=dtype)
+    # NumPy does not support bfloat16, so we default to float32 (only for NumPy) in that case
+    if dtype == torch.bfloat16:
+        weight_cpu = weight.to(dtype=torch.float32, device="cpu")
+    else:
+        weight_cpu = weight.cpu()
+    np_weight = weight_cpu.numpy()
+    return ({'weight': weight}, {'weight': np_weight})
+
+def error_inputs_prelu(op, device):
+    # Weight has numel != 1, but the input is a zero-dim tensor
+    inp = make_tensor((), device=device, dtype=torch.float32)
+    weight = make_tensor((2,), device=device, dtype=torch.float32)
+    yield ErrorInput(SampleInput(inp, kwargs={'weight': weight}),
+                     error_regex="Not allow zero-dim input tensor.")
+
+    # Weight has numel != 1, but numel does not match channel size
+    inp = make_tensor((2, 8, 3), device=device, dtype=torch.float32)
+    weight = make_tensor((9,), device=device, dtype=torch.float32)
+    yield ErrorInput(SampleInput(inp, kwargs={'weight': weight}),
+                     error_regex="Mismatch of parameter numbers and input channel size.")
+
+    # Weight is neither a scalar nor 1-D tensor
+    inp = make_tensor((2, 8, 3), device=device, dtype=torch.float32)
+    weight = make_tensor((2, 4), device=device, dtype=torch.float32)
+    yield ErrorInput(SampleInput(inp, kwargs={'weight': weight}),
+                     error_regex="prelu: Expected `weight` to be a scalar or 1D tensor, but got: ndim = 2")
+
+def sample_inputs_norm(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # ord = inf is tested in sample_inputs_norm_inf because it fails some tests here
+    cases = [
+        ((S, S), (2,), '2'),
+        ((S, S), (0,), '0'),
+        ((S, S), (0.5,), '0_5'),
+        ((S, S), (1,), '1'),
+        ((S, S), (3,), '3'),
+        ((S, S), (-1,), 'neg_1'),
+        ((S, S), (-2,), 'neg_2'),
+        ((S, S), (-0.5,), 'neg_0_5'),
+        ((S, S), (-1.5,), 'neg_1_5'),
+    ]
+
+    cases_nonzero_input = (
+        ((S, S, S), (1.5,), '1_5_default'),
+        ((S, S, S), (1.5, 1), '1_5_dim'),
+        ((S, S, S), (1.5, -1), '1_5_neg_dim'),
+        ((S, S, S), (1.5, 1, True), 'keepdim_1_5_dim'),
+        ((S, S, S), (1.5, -1, True), 'keepdim_1_5_neg_dim'),
+    )
+
+    cases_posdim = (
+        ((S, S), (-2, 1,), 'neg_2_dim'),
+        ((S, S), (-1, 1,), 'neg_1_dim'),
+        ((S, S), (0, 1,), '0_dim'),
+        ((S, S), (1, 1,), '1_dim'),
+        ((S, S), (2, 1,), '2_dim'),
+        ((S, S), (3, 1,), '3_dim'),
+        ((S, S, S), (2, 1), '2_dim'),
+        ((S, S, S), (3, 1), '3_dim'),
+        ((S, S, S), (2, 1, True), 'keepdim_2_dim'),
+        ((S, S, S), (3, 1, True), 'keepdim_3_dim'),
+        ((), (2, 0), '2_dim_scalar'),
+        ((), (3, 0), '3_dim_scalar'),
+        ((), (2, 0, True), 'keepdim_2_dim_scalar'),
+        ((), (3, 0, True), 'keepdim_3_dim_scalar'),
+    )
+
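+    # Mirror each positive-dim case with the corresponding negative dimension index.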
+    cases_negdim = ((shape, args[:1] + (-args[1],) + args[2:], name.replace("_dim", "_neg_dim"))
+                    for shape, args, name in cases_posdim)
+
+    for shape, args, name in itertools.chain(cases, cases_posdim, cases_negdim):
+        yield SampleInput(make_arg(shape), args=args, name=name)
+
+    for shape, args, name in cases_nonzero_input:
+        yield SampleInput(make_arg(shape, exclude_zero=True), args=args, name=name)
+
+
+def sample_inputs_norm_fro(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    cases = (
+        ((S, S), (), 'default'),
+        ((S, S), ('fro',), 'fro_default'),
+        ((S, S), ('fro', [0, 1],), 'fro'),
+    )
+
+    for shape, args, name in cases:
+        yield SampleInput(make_arg(shape), args=args, name=name)
+
+
+def sample_inputs_norm_nuc(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    cases = (
+        ((S, S), ('nuc',), 'nuc'),
+        ((S, S, S), ('nuc', [1, 2]), 'nuc_batched'),
+    )
+
+    for shape, args, name in cases:
+        yield SampleInput(make_arg(shape), args=args, name=name)
+
+
+def sample_inputs_norm_inf(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    cases = (
+        ((S, S), (-inf,), '-inf'),
+        ((S, S), (inf,), 'inf'),
+        ((S, S), (inf, 1,), 'inf_2_dim'),
+        ((S, S), (inf, -1,), 'inf_2_neg_dim'),
+    )
+
+    for shape, args, name in cases:
+        yield SampleInput(make_arg(shape), args=args, name=name)
+
+
+def sample_inputs_equal(op, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(
+        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    shapes = (
+        ((), ()),
+        ((S,), ()),
+        ((), (S,)),
+        ((S, 1), (S,)),
+        ((M, S), ()),
+        ((S, S), (S, S))
+    )
+
+    for shape_lhs, shape_rhs in shapes:
+        lhs = make_arg(shape_lhs)
+        rhs = make_arg(shape_rhs)
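+        # Mark samples whose first argument is broadcast by the op; in-place variants cannot use these.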
+        broadcasts_input = shape_lhs != torch.broadcast_shapes(shape_lhs, shape_rhs)
+
+        yield SampleInput(lhs, args=(rhs,), broadcasts_input=broadcasts_input)
+        if shape_lhs == shape_rhs:
+            yield SampleInput(lhs, args=(lhs.clone().detach_(),))
+
+
+def sample_inputs_jiterator(op, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    shapes = (
+        ((), ()),
+        ((S,), ()),
+        ((S, 1), (S,)),
+        ((M, S), ()),
+        ((S, M, S), (M, S)),
+        ((S, M, S), (S, M, S)),
+        ((M, 1, S), (M, S)),
+        ((M, 1, S), (1, M, S)),
+        ((0, 1, 3), (0, 10, 3))
+    )
+
+    num_inputs = kwargs.get('num_inputs')
+    sample_kwargs = kwargs.get('sample_kwargs', {})
+
+    for shape_lhs, shape_rhs in shapes:
+        lhs = make_arg(shape_lhs)
+        args = [make_arg(shape_rhs) for _ in range(num_inputs - 1)]
+        broadcasts_input = (shape_lhs != torch.broadcast_shapes(shape_lhs, shape_rhs))
+
+        yield SampleInput(lhs, args=tuple(args), kwargs=sample_kwargs, broadcasts_input=broadcasts_input)
+
+def sample_inputs_broadcast_shapes(op, device, dtype, requires_grad, **kwargs):
+    shapes = (
+        ((), ()),
+        ((S,), ()),
+        ((S, 1), (S,)),
+        ((S, 1), S),
+        ((M, S), ()),
+        ((S, M, S), (M, S)),
+        ((S, M, S), (S, M, S)),
+        ((M, 1, S), (M, S)),
+        ((M, 1, S), (1, M, S)),
+        ((0, 1, 3), (0, 10, 3))
+    )
+
+    for shape in shapes:
+        inp, *arg0 = shape
+        yield SampleInput(inp, args=tuple(arg0))
+
+def sample_inputs_add_sub(op, device, dtype, requires_grad, **kwargs):
+    yield from sample_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs)
+
+    # Adds alpha kwarg cases
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    lhs = make_arg((S, S), **op.lhs_make_tensor_kwargs)
+    rhs = make_arg((S, S), **op.rhs_make_tensor_kwargs)
+    if dtype is not torch.bool:
+        yield SampleInput(lhs, args=(rhs,), kwargs={'alpha': 2})
+    else:
+        yield SampleInput(lhs, args=(rhs,), kwargs={'alpha': True})
+    neg_alpha = -3.125 if (dtype.is_floating_point or dtype.is_complex) else -3
+    lhs = make_arg((S, S), **op.lhs_make_tensor_kwargs)
+    rhs = make_arg((S, S), **op.rhs_make_tensor_kwargs)
+    if dtype is not torch.bool:
+        yield SampleInput(lhs, args=(rhs,), kwargs={'alpha': neg_alpha})
+    else:
+        yield SampleInput(lhs, args=(rhs,), kwargs={'alpha': False})
+
+def error_inputs_arange(op, device, **kwargs):
+    yield ErrorInput(SampleInput(0, args=(3, 0)), error_type=RuntimeError, error_regex='step must be nonzero')
+    yield ErrorInput(SampleInput(0, args=(-3, 2)), error_type=RuntimeError,
+                     error_regex='upper bound and lower bound inconsistent with step sign')
+    yield ErrorInput(SampleInput(0, args=(3, -2)), error_type=RuntimeError,
+                     error_regex='upper bound and lower bound inconsistent with step sign')
+    yield ErrorInput(SampleInput(1549556900, args=(1549556828, 1989724)), error_type=RuntimeError,
+                     error_regex='upper bound and lower bound inconsistent with step sign')
+    yield ErrorInput(SampleInput(0, args=(float('inf'), 2)), error_type=RuntimeError, error_regex='unsupported range')
+    yield ErrorInput(SampleInput(float('-inf'), args=(1, 2)), error_type=RuntimeError, error_regex='unsupported range')
+
+def sample_inputs_arange(op, device, dtype, requires_grad, **kwargs):
+    int_samples = (
+        # positive direction
+        (-1, 2, 2),
+        # negative direction
+        (2, -3, -1),
+        # start == end
+        (1, 1, 1),
+        (1, 1, -1),
+        # divides evenly
+        (0, -8, -4),
+        (1, 5, 2),
+        # bool
+        (False, True, True),
+        # default step
+        (0, 1, None),
+        # default start
+        (None, 3, None),
+    )
+
+    def to_float(start, end, step):
+        start = start + 0.1 if start is not None else None
+        end = end + 0.1
+        step = float(step) if step is not None else None
+        return start, end, step
+
+    float_samples = (
+        # includes endpoint
+        (0., -8. - 1e-6, -4.),
+        (1., 5. + 1e-6, 2.),
+        (0., -8., -4.),
+        (1., 5., 2.),
+        *(to_float(start, end, step) for (start, end, step) in int_samples),
+    )
+
+    large_samples = (
+        (0, 10000, None),
+    )
+
+    samples = int_samples + float_samples
+    if dtype not in (torch.int8, torch.uint8):
+        samples += large_samples
+
+    for start, end, step in samples:
+        if start is None:
+            assert step is None
+            # Pass end as positional arg
+            yield SampleInput(end, kwargs={"dtype": dtype, "device": device})
+            # (Similar to) calling torch.arange(end=3)
+            yield SampleInput(0, kwargs={"end": end, "dtype": dtype, "device": device})
+        elif step is None:
+            yield SampleInput(start, args=(end,), kwargs={"dtype": dtype, "device": device})
+        else:
+            yield SampleInput(start, args=(end, step), kwargs={"dtype": dtype, "device": device})
+
+    yield SampleInput(2)
+    yield SampleInput(1, args=(3, 1))
+
+def sample_inputs_randn(op, device, dtype, requires_grad, **kwargs):
+    shapes = (
+        (M,),
+        (S, S)
+    )
+
+    for shape in shapes:
+        yield SampleInput(input=shape, kwargs=dict(dtype=dtype, device=device, requires_grad=requires_grad))
+
+def sample_inputs_normal(op, device, dtype, requires_grad, **kwargs):
+
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False)
+    samples = (
+        ((S, S), 0, 5),
+        ((S, S, S), -2, 0.5),
+    )
+    for shape, mean, std in samples:
+        yield SampleInput(make_arg(shape), args=(mean, std))
+
+def error_inputs_normal(op, device, **kwargs):
+    t = torch.zeros([10], device=device)
+    invalid_std = -1
+    yield ErrorInput(
+        SampleInput(t, args=(0, invalid_std)),
+        error_type=RuntimeError,
+        error_regex=fr"normal expects std >= 0.0, but found std {invalid_std}",
+    )
+
+def sample_inputs_cauchy(op, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False)
+    samples = (
+        ((M,), 0, 0.5),
+        ((S, S), 0, 1),
+        ((S, S, S), -2, 1),
+    )
+    for shape, median, gamma in samples:
+        yield SampleInput(make_arg(shape), args=(median, gamma))
+
+
+def error_inputs_cauchy(op, device, **kwargs):
+    t = torch.zeros([10], device=device)
+    invalid_scale = 0
+    yield ErrorInput(
+        SampleInput(t, args=(0, invalid_scale,)),
+        error_type=RuntimeError,
+        error_regex=fr"cauchy_ expects sigma > 0.0, but found sigma={invalid_scale}",
+    )
+
+
+def sample_inputs_exponential(op, device, dtype, requires_grad, **kwargs):
+
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False)
+    samples = (
+        ((M,), 0.5),
+        ((S, S), 1),
+        ((S, S, S), 1.5),
+    )
+    for shape, rate in samples:
+        yield SampleInput(make_arg(shape), args=(rate,))
+
+
+def error_inputs_exponential(op, device, **kwargs):
+    t = torch.zeros([10], device=device)
+    invalid_rate = 0
+    yield ErrorInput(
+        SampleInput(t, args=(invalid_rate,)),
+        error_type=RuntimeError,
+        error_regex=fr"exponential_ expects lambda > 0.0, but found lambda={invalid_rate}",
+    )
+
+
+def sample_inputs_geometric(op, device, dtype, requires_grad, **kwargs):
+
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False)
+    samples = (
+        ((M,), 0.2),
+        ((S, S), 0.5),
+        ((S, S, S), 0.8),
+    )
+    for shape, rate in samples:
+        yield SampleInput(make_arg(shape), args=(rate,))
+
+
+def error_inputs_geometric(op, device, **kwargs):
+    t = torch.zeros([10], device=device)
+    neg_prob = -1
+    yield ErrorInput(
+        SampleInput(t, args=(neg_prob,)),
+        error_type=RuntimeError,
+        error_regex=fr"geometric_ expects p to be in \(0, 1\), but got p={neg_prob}",
+    )
+
+
+def sample_inputs_log_normal(op, device, dtype, requires_grad, **kwargs):
+
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False)
+    samples = (
+        ((M,), 0, 0.25),
+        ((S, S), 0.5, 1),
+        ((S, S, S), 0, 0.5),
+    )
+    for shape, mean, std in samples:
+        yield SampleInput(make_arg(shape), args=(mean, std))
+
+
+def error_inputs_log_normal(op, device, **kwargs):
+    t = torch.zeros([10], device=device)
+    invalid_std = 0
+    yield ErrorInput(
+        SampleInput(t, args=(0, invalid_std)),
+        error_type=RuntimeError,
+        error_regex=fr"log_normal_ expects std > 0.0, but found std={invalid_std}",
+    )
+
+
+def sample_inputs_uniform(op, device, dtype, requires_grad, **kwargs):
+
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False)
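+    # Tensor.uniform_(from, to) draws values from the half-open range [from, to).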
+    samples = (
+        ((M,), -100, 100),
+        ((S, S), 0, 1),
+        ((S, S, S), 1, 2),
+    )
+    for shape, low, high in samples:
+        yield SampleInput(make_arg(shape), args=(low, high))
+
+def sample_inputs_ones_zeros(op, device, dtype, requires_grad, **kwargs):
+    # This is a bit messy: SampleInput args must be tuples, so passing the size tuple via args
+    # would nest it inside another tuple; pass the size as the input instead.
+    sizes = (
+        (M,),
+        (S, S),
+    )
+    for size in sizes:
+        yield SampleInput(size, kwargs={'dtype': dtype, 'device': device})
+
+def sample_inputs_full(op, device, dtype, requires_grad, **kwargs):
+    def get_val(dtype):
+        return make_tensor([], dtype=dtype, device="cpu").item()
+
+    sizes = (
+        (M,),
+        (S, S),
+    )
+    fill_values = [get_val(dtype), get_val(torch.int)]
+
+    for size, fill_value in product(sizes, fill_values):
+        yield SampleInput(size, fill_value, dtype=dtype, device=device)
+
+
+def error_inputs_uniform(op, device, **kwargs):
+    t = torch.zeros([10], device=device)
+    yield ErrorInput(
+        SampleInput(t, args=(3, -1)),
+        error_type=RuntimeError,
+        error_regex=r"uniform_ expects to return a \[from, to\) range, but found from=3 > to=-1",
+    )
+
+
+def error_inputs_linspace(op, device, **kwargs):
+    yield ErrorInput(SampleInput(0, args=(3, -1)), error_type=RuntimeError, error_regex='number of steps must be non-negative')
+    yield ErrorInput(
+        SampleInput(0, args=(3, 1.)),
+        error_type=TypeError,
+        error_regex="received an invalid combination of arguments - got \\(int, int, float",
+    )
+    yield ErrorInput(
+        SampleInput(torch.tensor([1, 1], device=device), args=(torch.tensor([3, 3], device=device), 1)),
+        error_type=RuntimeError,
+        error_regex="only supports 0-dimensional start and end tensors"
+    )
+
+
+def sample_inputs_linspace(op, device, dtype, requires_grad, **kwargs):
+    ends = (-3, 0, 1, 4, 50)
+    starts = (-2., 0, 4.3, 50)
+    nsteps = (0, 1, 50)
+    # Extra case to replicate off-by-one issue on CUDA
+    cases = list(product(starts, ends, nsteps)) + [(0, 7, 50)]
+    for start, end, nstep in cases:
+        if dtype == torch.uint8 and (end < 0 or start < 0):
+            continue
+        yield SampleInput(start, args=(end, nstep), kwargs={"dtype": dtype, "device": device})
+
+    yield SampleInput(1, args=(3, 1))
+
+
+def sample_inputs_linspace_tensor_overload(op, device, dtype, requires_grad, **kwargs):
+    ends = (-3, 0, 1, 4, 50)
+    starts = (-2., 0, 4.3, 50)
+    nsteps = (0, 1, 50)
+    is_start_end_tensors = ((True, True), (True, False), (False, True))
+    make_arg = partial(torch.tensor, device=device, requires_grad=False)
+
+    # Extra case to replicate off-by-one issue on CUDA
+    cases = list(product(starts, ends, nsteps, is_start_end_tensors)) + [(0, 7, 50, (True, True))]
+    for start, end, nstep, (is_start_tensor, is_end_tensor) in cases:
+        if dtype == torch.uint8 and (end < 0 or start < 0):
+            continue
+
+        tensor_options = {"dtype": dtype, "device": device}
+        if is_start_tensor:
+            start = make_arg(start, dtype=torch.float32 if isinstance(start, float) else torch.int64)
+        if is_end_tensor:
+            end = make_arg(end, dtype=torch.float32 if isinstance(end, float) else torch.int64)
+
+        yield SampleInput(start, args=(end, nstep), kwargs=tensor_options)
+
+    yield SampleInput(1, args=(3, 1))
+
+
+def sample_inputs_logspace(op, device, dtype, requires_grad, **kwargs):
+    ends = (-3, 0, 1.2, 2, 4)
+    starts = (-2., 0, 1, 2, 4.3)
+    nsteps = (0, 1, 2, 4)
+    bases = (2., 1.1) if dtype in (torch.int8, torch.uint8) else (None, 2., 3., 1.1, 5.)
+    for start, end, nstep, base in product(starts, ends, nsteps, bases):
+        if dtype == torch.uint8 and (end < 0 or start < 0):
+            continue
+        if nstep == 1 and isinstance(start, float) and not (dtype.is_complex or dtype.is_floating_point):
+            # https://github.com/pytorch/pytorch/issues/82242
+            continue
+        if base is None:
+            yield SampleInput(start, args=(end, nstep), kwargs={"dtype": dtype, "device": device})
+        else:
+            yield SampleInput(start, args=(end, nstep, base), kwargs={"dtype": dtype, "device": device})
+
+    yield SampleInput(1, args=(3, 1, 2.))
+
+
+def sample_inputs_logspace_tensor_overload(op, device, dtype, requires_grad, **kwargs):
+    ends = (-3, 0, 1.2, 2, 4)
+    starts = (-2., 0, 1, 2, 4.3)
+    nsteps = (0, 1, 2, 4)
+    bases = (2., 1.1) if dtype in (torch.int8, torch.uint8) else (None, 2., 3., 1.1, 5.)
+    is_start_end_tensors = ((True, True), (True, False), (False, True))
+    make_arg = partial(torch.tensor, device=device)
+    for start, end, nstep, base, (is_start_tensor, is_end_tensor) in product(starts, ends, nsteps, bases, is_start_end_tensors):
+        if dtype == torch.uint8 and (end < 0 or start < 0):
+            continue
+        if nstep == 1 and isinstance(start, float) and not (dtype.is_complex or dtype.is_floating_point):
+            # https://github.com/pytorch/pytorch/issues/82242
+            continue
+
+        tensor_options = {"dtype": dtype, "device": device}
+
+        if is_start_tensor:
+            start = make_arg(start, dtype=torch.float32 if isinstance(start, float) else torch.int64)
+        if is_end_tensor:
+            end = make_arg(end, dtype=torch.float32 if isinstance(end, float) else torch.int64)
+
+        if base is None:
+            yield SampleInput(start, args=(end, nstep), kwargs=tensor_options)
+        else:
+            yield SampleInput(start, args=(end, nstep, base), kwargs=tensor_options)
+
+    yield SampleInput(1, args=(3, 1, 2.))
+
+
+def sample_inputs_isclose(op, device, dtype, requires_grad, **kwargs):
+    yield from sample_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs)
+
+    # Creates additional inputs to test the rtol, atol, and equal_nan params
+    rtols = [0., 1e-7]
+    atols = [0., 1e-7]
+    equal_nans = [False, True]
+
+    products = product(rtols, atols, equal_nans)
+
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    for rtol, atol, equal_nan in products:
+        lhs = make_arg((S, S), **op.lhs_make_tensor_kwargs)
+        rhs = make_arg((S, S), **op.rhs_make_tensor_kwargs)
+
+        yield SampleInput(lhs, args=(rhs,),
+                          kwargs=dict(rtol=rtol, atol=atol, equal_nan=equal_nan))
+
+
+def error_inputs_isclose(op, device, **kwargs):
+    make_float_arg = partial(make_tensor, device=device, dtype=torch.float, requires_grad=False)
+
+    yield ErrorInput(SampleInput(make_float_arg(()), args=(make_float_arg(()),), kwargs={'rtol': -0.4}),
+                     error_type=RuntimeError,
+                     error_regex='rtol must be greater than or equal to zero')
+
+    yield ErrorInput(SampleInput(make_float_arg(()), args=(make_float_arg(()),), kwargs={'atol': -0.4}),
+                     error_type=RuntimeError,
+                     error_regex='atol must be greater than or equal to zero')
+
+
+def sample_inputs_t(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    yield SampleInput(make_arg((1, 2)))
+    yield SampleInput(make_arg((2,)))
+    yield SampleInput(make_arg(()))
+
+
+def sample_inputs_mm(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    def make_arg_conj(size):
+        return make_arg(size).conj().requires_grad_(requires_grad)
+
+    first_shape, second_shape = (S, M), (M, S)
+
+    yield SampleInput(make_arg(first_shape), args=(make_arg(second_shape),))
+
+    if dtype.is_complex:
+        yield SampleInput(make_arg(first_shape), args=(make_arg_conj(second_shape),))
+
+    # Matmul of empty matrices
+    yield SampleInput(make_arg((0, S)), args=(make_arg(S, M),))
+    yield SampleInput(make_arg((S, 0)), args=(make_arg(0, M),))
+
+
+def sample_inputs_addmm(op_info, device, dtype, requires_grad, **kwargs):
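+    # addmm performs: beta * input + alpha * (mat1 @ mat2)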
+    alpha_val = kwargs.get('alpha', 2 + 3j if dtype.is_complex else 0.6)
+    beta_val = kwargs.get('beta', 1 + 2j if dtype.is_complex else 0.2)
+    tests_list = [
+        ((2, 3), (2, 2), (2, 3), False),
+        ((3, 3), (3, 3), (3, 3), False),
+    ]
+    tests_with_lhs_broadcasting = [
+        ((1,), (2, 2), (2, 3), True),
+        ((), (2, 2), (2, 3), True),
+    ]
+    test_cases = tests_list + tests_with_lhs_broadcasting  # type: ignore[operator]
+
+    kwargs = dict(alpha=alpha_val, beta=beta_val)
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+    for shape_a, shape_b, shape_c, broadcasts_input in test_cases:
+        yield SampleInput(
+            make_arg(shape_a),
+            make_arg(shape_b),
+            make_arg(shape_c),
+            **kwargs,
+        ).with_metadata(broadcasts_input=broadcasts_input)
+
+    if dtype.is_complex:
+        shape = (3, 3)
+        yield SampleInput(
+            make_arg(shape),
+            make_arg(shape, requires_grad=False).mH.requires_grad_(requires_grad),
+            make_arg(shape),
+            **kwargs,
+        )
+        yield SampleInput(
+            make_arg(shape),
+            make_arg(shape),
+            make_arg(shape, requires_grad=False).mH.requires_grad_(requires_grad),
+            **kwargs,
+        )
+    # addmm of empty matrices
+    if dtype.is_floating_point:
+        yield SampleInput(make_arg(S, M), make_arg(S, 0), make_arg(0, M), **kwargs)
+        # empty matmul with broadcastable input
+        yield SampleInput(make_arg(M), make_arg(S, 0), make_arg(0, M), **kwargs).with_metadata(broadcasts_input=True)
+
+def sample_inputs_sparse_sampled_addmm(op_info, device, dtype, requires_grad, **kwargs):
+    alpha = 2 + 3j if dtype.is_complex else 0.6
+    beta = 1 + 2j if dtype.is_complex else 0.2
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # sparse.sampled_addmm performs: alpha * (A @ B) * sparse_ones_like(C) + beta * C
+    for m, n, k in itertools.product([0, 5], repeat=3):
+        yield SampleInput(
+            torch.eye(m, n, device=device, dtype=dtype)
+            .to_sparse_csr()
+            .requires_grad_(requires_grad),
+            make_arg((m, k)),
+            make_arg((k, n)),
+            alpha=alpha,
+            beta=beta,
+        )
+
+def sample_inputs_sparse_mm_reduce(op_info, device, dtype, requires_grad, **kwargs):
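+    # Samples for sparse mm with a reduce argument: a sparse CSR identity lhs, a dense rhs,
+    # and each supported reduction.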
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    reductions = ["sum", "mean", "amax", "amin"]
+    for m, k, reduce in product([5, 7], [3, 11], reductions):
+        yield SampleInput(
+            torch.eye(m, m)
+            .to(device=device, dtype=dtype)
+            .to_sparse_csr()
+            .requires_grad_(requires_grad),
+            make_arg((m, k)),
+            reduce,
+        )
+
+
+def sample_inputs_mv(self, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad)
+    yield SampleInput(make_arg(S, M), make_arg(M))
+
+def sample_inputs_bmm(self, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad)
+    yield SampleInput(make_arg(M, S, M), make_arg(M, M, S))
+
+def sample_inputs_dot_vdot(self, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    def make_arg_conj(size):
+        return make_arg(size).conj().requires_grad_(requires_grad)
+
+    yield SampleInput(make_arg((S, )), make_arg((S, )))
+    if dtype.is_complex:
+        # dot/vdot for (conj(input), conj(arg_tensor)) and (conj(input), arg_tensor)
+        # is tested in test_conj_view (which tests operations with only conjugated input tensor
+        # -- not conjugated arg tensors)
+        yield SampleInput(make_arg((S, )), make_arg_conj((S, )))
+
+
+def error_inputs_dot_vdot(op_info, device, is_ref=False, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=torch.float32)
+
+    yield ErrorInput(SampleInput(make_input(1), args=(make_input(3, dtype=torch.float16),)),
+                     error_regex='dot : expected both vectors to have same dtype')
+    yield ErrorInput(SampleInput(make_input(1, 1), args=(make_input(3),)),
+                     error_regex='1D tensors expected')
+    yield ErrorInput(SampleInput(make_input(9), args=(make_input(3),)),
+                     error_regex='inconsistent tensor size')
+    if device != "cpu" and not is_ref:
+        yield ErrorInput(SampleInput(make_input(3), args=(make_input(3, device="cpu"),)),
+                         error_regex='Expected all tensors to be on the same device')
+
+
+def sample_inputs_addmv(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+
+    test_cases = (((S,), (S, M), (M,), 1, 1, False),
+                  ((S,), (S, M), (M,), 0.2, 0.6, False),
+                  )
+
+    test_cases_with_broadcast = (((1,), (S, M), (M,), 1, 1, True),
+                                 ((1,), (S, M), (M,), 0.2, 0.6, True),
+                                 ((), (S, M), (M,), 1, 1, True),
+                                 ((), (S, M), (M,), 0.2, 0.6, True),
+                                 )
+
+    cases = test_cases + test_cases_with_broadcast
+
+    # addmv performs: beta * M + alpha * (mat @ vec)
+    for size, mat, vec, beta, alpha, broadcasts_input in cases:
+        yield SampleInput(make_arg(size), args=(make_arg(mat), make_arg(vec)),
+                          kwargs=dict(beta=beta, alpha=alpha), broadcasts_input=broadcasts_input)
+
+def sample_inputs_addbmm(op_info, device, dtype, requires_grad, **kwargs):
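+    # addbmm performs: beta * input + alpha * sum_i (batch1[i] @ batch2[i])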
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # input_shape, batch1_shape, batch2_shape, beta_val, alpha_val, is_broadcasting
+    test_cases = [((S, M), (S, S, S), (S, S, M), 1, 1, False),
+                  ((1,), (S, S, S), (S, S, M), 1, 1, True),
+                  ((S, M), (S, S, S), (S, S, M), 0.6, 0.2, False),
+                  ((1,), (S, S, S), (S, S, M), 0.6, 0.2, True),
+                  ((), (S, S, S), (S, S, M), 1, 1, True),
+                  ((), (S, S, S), (S, S, M), 0.6, 0.2, True),
+                  ]
+
+    for input_shape, batch1_shape, batch2_shape, beta, alpha, is_broadcasting in test_cases:
+        if dtype.is_complex:
+            beta_complex, alpha_complex = beta * (1 + 2j), alpha * (2 + 3j)
+            yield SampleInput(make_arg(input_shape), args=(make_arg(batch1_shape), make_arg(batch2_shape)),
+                              kwargs=dict(beta=beta_complex, alpha=alpha_complex), broadcasts_input=is_broadcasting)
+        yield SampleInput(make_arg(input_shape), args=(make_arg(batch1_shape), make_arg(batch2_shape)),
+                          kwargs=dict(beta=beta, alpha=alpha), broadcasts_input=is_broadcasting)
+
+def sample_inputs_addcmul_addcdiv(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    test_cases = [(((S, S), (S, S), (S, S)), False),
+                  (((S, S), (S, 1), (1, S)), False),
+                  (((1,), (S, S, 1), (1, S)), True),
+                  (((), (), ()), False),
+                  (((S, S), (), ()), True),
+                  (((), (S, S, 1), (1, S)), True)
+                  ]
+
+    for input_args, broadcasts_input in test_cases:
+        # addcdiv should accept inputs with zero value
+        # Currently, it throws ZeroDivisionError when the denominator is zero
+        # TODO: exclude_zeros can be removed after https://github.com/pytorch/pytorch/issues/73638 is fixed
+        args = tuple(make_arg(arg, exclude_zero=True) if isinstance(arg, tuple) else arg
+                     for arg in input_args)
+        yield SampleInput(*args).with_metadata(broadcasts_input=broadcasts_input)
+
+        # addcdiv should accept inputs with zero value
+        # Currently, it throws ZeroDivisionError when the denominator is zero
+        # TODO: exclude_zeros can be removed after https://github.com/pytorch/pytorch/issues/73638 is fixed
+        args = tuple(make_arg(arg, exclude_zero=True) if isinstance(arg, tuple) else arg
+                     for arg in input_args)
+        yield SampleInput(
+            *args, value=3.14 if dtype.is_floating_point or dtype.is_complex else 3
+        ).with_metadata(broadcasts_input=broadcasts_input)
+
+def reference_inputs_addcmul_addcdiv(op_info, device, dtype, requires_grad, **kwargs):
+    yield from sample_inputs_addcmul_addcdiv(
+        op_info, device, dtype, requires_grad, **kwargs)
+
+    # type promotion cases
+    supported_dtypes = op_info.supported_dtypes(device)
+    make_arg = partial(make_tensor, device=device, requires_grad=requires_grad)
+
+    types = (
+        (torch.float64, torch.complex128),
+        (torch.bfloat16, torch.float32),
+    )
+
+    values = (
+        None,
+        True, False,
+        3.14, 3,
+        1.0, 1,
+        0.0, 0,
+        -3.14, -3,
+        3.14 + 2.71j,
+    )
+
+    for (type2, type3), value in product(types, values):
+        if (type2 not in supported_dtypes or
+                type3 not in supported_dtypes):
+            continue
+
+        # RuntimeError: value cannot be converted without overflow
+        if (type(value) is complex and
+                type2 is not torch.complex128):
+            continue
+
+        arg1 = make_arg([5, 5], dtype=dtype)
+        arg2 = make_arg([5, 5], dtype=type2)
+        arg3 = make_arg([1, 5], dtype=type3)
+
+        # TypeError: addcdiv(): argument 'value' must be Number, not NoneType
+        if value is not None:
+            yield SampleInput(arg1, args=(arg2, arg3), kwargs=dict(value=value))
+        else:
+            yield SampleInput(arg1, args=(arg2, arg3))
+
+def sample_inputs_baddbmm(op_info, device, dtype, requires_grad, **kwargs):
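+    # baddbmm performs: beta * input + alpha * (batch1 @ batch2), batched over the first dimension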
+    test_cases = [((S, S, M), (S, S, S), (S, S, M), 1, 1, False),
+                  ((1,), (S, S, S), (S, S, M), 1, 1, True),
+                  ((S, S, M), (S, S, S), (S, S, M), 0.6, 0.2, False),
+                  ((1,), (S, S, S), (S, S, M), 0.6, 0.2, True),
+                  ((), (S, S, S), (S, S, M), 1, 1, True),
+                  ((), (S, S, S), (S, S, M), 0.6, 0.2, True),
+                  ]
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None)
+    for (input_shape, batch1_shape, batch2_shape, alpha, beta, broadcasts_input) in test_cases:
+        yield SampleInput(
+            make_arg(input_shape),
+            make_arg(batch1_shape),
+            make_arg(batch2_shape),
+            beta=beta,
+            alpha=alpha
+        ).with_metadata(broadcasts_input=broadcasts_input)
+
+        if dtype.is_complex:
+            yield SampleInput(
+                make_arg(input_shape),
+                make_arg(batch1_shape),
+                make_arg(batch2_shape),
+                beta=beta * (1 + 2j),
+                alpha=alpha * (2 + 3j),
+            ).with_metadata(broadcasts_input=broadcasts_input)
+
+    if dtype.is_complex:
+        shapes = [(S, S, S), (S, M, S), (S, S, M)]
+        args = tuple(make_arg(s) for s in shapes)
+        yield SampleInput(
+            args[0].transpose_(-1, 1),
+            args[1].transpose(-1, 1).conj().requires_grad_(requires_grad),
+            args[2].transpose(-1, 1).conj().requires_grad_(requires_grad),
+            beta=beta * (1 + 2j),
+            alpha=alpha * (2 + 3j),
+        )
+
+# TODO: add reduction kwargs
+def sample_inputs_multilabel_soft_margin_loss(op_info, device, dtype, requires_grad, **kwargs):
+    _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    shapes = (
+        (S,),
+        (S, S),
+    )
+
+    for shape in shapes:
+        # Produce one with weight and one without.
+        yield SampleInput(_make_tensor(shape), args=(_make_tensor(shape, requires_grad=False),), kwargs={})
+        yield SampleInput(_make_tensor(shape), args=(_make_tensor(shape, requires_grad=False),),
+                          kwargs={'weight': _make_tensor(shape, requires_grad=False)})
+
+def sample_inputs_addr(op_info, device, dtype, requires_grad, **kwargs):
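+    # addr computes beta * input + alpha * torch.outer(vec1, vec2); the samples below cover
+    # the default scalars, dtype-appropriate alpha/beta values, and broadcasting of a 0-dim input.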
+    make_arg = partial(
+        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None
+    )
+    yield SampleInput(make_arg(S, M), make_arg(S), make_arg(M))
+
+    yield SampleInput(make_arg(), make_arg(S), make_arg(M)).with_metadata(broadcasts_input=True)
+
+    if dtype.is_complex:
+        alpha, beta = 0.1 + 0.3j, 0.4 + 0.6j
+    elif dtype.is_floating_point:
+        alpha, beta = 0.2, 0.6
+    else:
+        alpha, beta = 2, 3
+
+    yield SampleInput(make_arg(S, M), make_arg(S), make_arg(M), beta=beta, alpha=alpha)
+
+    yield SampleInput(
+        make_arg(),
+        make_arg(S),
+        make_arg(M),
+        beta=beta,
+        alpha=alpha,
+    ).with_metadata(broadcasts_input=True)
+
+    # These samples fail gradcheck
+    if dtype.is_floating_point and not requires_grad:
+        tensor_options = dict(device=device, dtype=dtype, requires_grad=requires_grad)
+        yield SampleInput(
+            torch.tensor([[math.nan]], **tensor_options),
+            torch.tensor([0.0], **tensor_options),
+            torch.tensor([0.0], **tensor_options),
+            beta=0.0,
+            alpha=0.0,
+        ).with_metadata(broadcasts_input=True)
+
+        yield SampleInput(
+            torch.tensor([[0.0]], **tensor_options),
+            torch.tensor([math.nan], **tensor_options),
+            torch.tensor([math.nan], **tensor_options),
+            beta=0.0,
+            alpha=0.0,
+        ).with_metadata(broadcasts_input=True)
+
+def sample_inputs_zero_(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    cases = ((), (S, S, S), (S,))
+
+    for shape in cases:
+        yield SampleInput(make_arg(shape))
+
+def sample_inputs_multi_margin_loss(op_info, device, dtype, requires_grad, **kwargs):
+    _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_target = partial(_make_tensor, dtype=torch.long, requires_grad=False)
+    make_weight = partial(_make_tensor, requires_grad=False)
+
+    inputs = (
+        ((), make_target([], low=0, high=1), {}),
+        ((S,), make_target([], low=0, high=S), {"p": 1}),
+        ((S,), make_target([1], low=0, high=S), {"p": 2}),
+        ((S, M), make_target([S], low=0, high=M), {"margin": 1.0}),
+        ((S, M), make_target([S], low=0, high=M), {"margin": -3.14}),
+        ((M, S), make_target([M], low=0, high=S), {"weight": None}),
+        ((M, S), make_target([M], low=0, high=S), {"weight": make_weight([S], low=-10., high=10.)}),
+        ((M, S), make_target([M], low=0, high=S), {"reduction": "none"}),
+        ((M, S), make_target([M], low=0, high=S), {"reduction": "mean"}),
+        ((M, S), make_target([M], low=0, high=S), {"reduction": "sum"}),
+    )
+
+    for input_shape, target, kwargs in inputs:
+        yield SampleInput(_make_tensor(input_shape), args=(target,), kwargs=kwargs)
+
+
+def reference_inputs_multi_margin_loss(op_info, device, dtype, requires_grad, **kwargs):
+    yield from sample_inputs_multi_margin_loss(op_info, device, dtype, requires_grad, **kwargs)
+    _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_target = partial(_make_tensor, dtype=torch.long, requires_grad=False)
+    make_weight = partial(_make_tensor, requires_grad=False)
+
+    inputs = (
+        ((), make_target([], low=0, high=1)),
+        ((S,), make_target([], low=0, high=S)),
+        ((S,), make_target([1], low=0, high=S)),
+        ((M, S), make_target([M], low=0, high=S)),
+    )
+    ps = (1, 2)
+    margins = (0, 7, -3.14)
+    weights = (False, True)
+    reductions = (None, "none", "mean", "sum")
+
+    for (input_shape, target), p, margin, weight, reduction in product(inputs, ps, margins, weights, reductions):
+        input = _make_tensor(input_shape)
+        weight_shape = [input.size(-1)] if input.ndim > 0 else [1]
+        weight = make_weight(weight_shape, low=-10., high=10.) if weight else None
+        kwargs = {"p": p, "margin": margin, "weight": weight}
+        if reduction is not None:
+            kwargs["reduction"] = reduction
+        yield SampleInput(input, args=(target,), kwargs=kwargs)
+
+
+def error_inputs_multi_margin_loss(op, device, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=torch.float32)
+    # invalid reduction
+    yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={'reduction': 'abc'}),
+                     error_type=ValueError, error_regex='abc is not a valid value for reduction')
+    # invalid input
+    yield ErrorInput(SampleInput(make_input(5, 0), args=(make_input(5,),), kwargs={}),
+                     error_type=RuntimeError,
+                     error_regex=r'Expected non-empty vector or matrix with optional 0-dim batch size, but got: \[5, 0\]')
+    yield ErrorInput(SampleInput(make_input(0,), args=(make_input(5,),), kwargs={}),
+                     error_type=RuntimeError,
+                     error_regex=r'Expected non-empty vector or matrix with optional 0-dim batch size, but got: \[0\]')
+    # invalid target
+    yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4),), kwargs={}),
+                     error_type=RuntimeError, error_regex=r'inconsistent target size, expected 5 but got \[5, 4\]')
+    # invalid target dtype
+    yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={}),
+                     error_type=RuntimeError, error_regex='expected scalar type Long but found Float')
+    # invalid weight
+    yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={'weight': make_input(())}),
+                     error_type=ValueError, error_regex='weight must be one-dimensional')
+    yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={'weight': make_input(5, 4)}),
+                     error_type=ValueError, error_regex='weight must be one-dimensional')
+    yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={'weight': make_input(5,)}),
+                     error_type=RuntimeError, error_regex=r'inconsistent weight size, expected 4 but got \[5\]')
+    # invalid p
+    yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={'p': 3}),
+                     error_type=ValueError, error_regex='only p == 1 and p == 2 supported')
+
+
+def sample_inputs_logsumexp(self, device, dtype, requires_grad, **kwargs):
+    inputs = (
+        ((), (0,), True),
+        ((S, S), (1,), True),
+        ((S, S), (1,), False),
+        ((S, S), (-2,), False),
+        ((S, S), (0, 1), False),
+    )
+    # Test large inputs to check numerical stability
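+    # (logsumexp is typically evaluated as max(x) + log(sum(exp(x - max(x)))), so inputs
+    # around 1e6 should not overflow the exponential.)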
+    lows = (None, 1e3, 1e6) if dtype in (torch.float32, torch.float64, torch.complex64, torch.complex128) else (None,)
+    for low in lows:
+        high = low * 2 if low is not None else None
+        for shape, dim, keepdim in inputs:
+            t = make_tensor(shape, dtype=dtype, device=device,
+                            low=low, high=high,
+                            requires_grad=requires_grad)
+            yield SampleInput(t, dim, keepdim)
+
+def reference_inputs_logsumexp(op, device, dtype, requires_grad, **kwargs):
+    yield from sample_inputs_logsumexp(op, device, dtype, requires_grad, **kwargs)
+
+    # https://github.com/pytorch/pytorch/issues/91843
+    t = torch.tensor([20, 30, 100], dtype=dtype, device=device, requires_grad=requires_grad)
+    yield SampleInput(t, 0, False)
+
+    t = torch.tensor((), dtype=dtype, device=device, requires_grad=requires_grad)
+    yield SampleInput(t, 0, False)
+
+    # tests masking
+    # https://github.com/pytorch/pytorch/pull/91860#pullrequestreview-1241344073
+    t = torch.tensor(float("inf"))
+    yield SampleInput(t, 0, True)
+
+def sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs):
+    inputs = [
+        ((), {}),
+        ((S, S), {}),
+        ((0, S, 0), {}),
+        ((S,), {'dtype': dtype, 'device': device}),
+        # Hard-code some dtypes/devices. We want to test cases where the
+        # (dtype, device) is different from the input's (dtype, device)
+        ((S,), {'dtype': torch.double if device != 'mps:0' else torch.float}),
+        ((S,), {'device': 'cpu'}),
+        ((S,), {'dtype': torch.double, 'device': 'cpu'}),
+    ]
+    if torch.cuda.is_available():
+        inputs.append(((S,), {'device': 'cuda'}))
+
+    for shape, kwargs in inputs:
+        t = make_tensor(shape, dtype=dtype, device=device,
+                        low=None, high=None,
+                        requires_grad=requires_grad)
+        yield SampleInput(t, **kwargs)
+
+def reference_inputs_like_fns(op, device, dtype, requires_grad, **kwargs):
+    yield from sample_inputs_like_fns(op, device, dtype, requires_grad, **kwargs)
+
+    # shape
+    cases = (
+        (), (0,), (1, 0), (1, 1, 4, 5), (5, 3, 0, 1), (1, 4, 3, 1, 1)
+    )
+
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+    for shape in cases:
+        yield SampleInput(make_arg(shape))
+        yield SampleInput(make_arg(shape).transpose(0, -1))
+        yield SampleInput(make_arg(shape, noncontiguous=True))
+        yield SampleInput(make_arg(shape, noncontiguous=True).transpose(0, -1))
+
+def sample_inputs_multilabel_margin_loss(op_info, device, dtype, requires_grad, **kwargs):
+    _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_target = partial(_make_tensor, dtype=torch.long, requires_grad=False)
+
+    inputs = (
+        ([], make_target([], low=0, high=1), {}),
+        ([S], make_target([S], low=0, high=S), {}),
+        ([M, S], make_target([M, S], low=0, high=S), {}),
+        ([M, S], make_target([M, S], low=0, high=S), {"reduction": "none"}),
+        ([M, S], make_target([M, S], low=0, high=S), {"reduction": "mean"}),
+        ([M, S], make_target([M, S], low=0, high=S), {"reduction": "sum"}),
+    )
+
+    for shape, target, kwargs in inputs:
+        yield SampleInput(_make_tensor(shape), args=(target,), kwargs=kwargs)
+
+
+def reference_inputs_multilabel_margin_loss(op_info, device, dtype, requires_grad, **kwargs):
+    yield from sample_inputs_multilabel_margin_loss(op_info, device, dtype, requires_grad, **kwargs)
+    _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_target = partial(_make_tensor, dtype=torch.long, requires_grad=False)
+    make_target_tensor = partial(torch.tensor, device=device, dtype=torch.long, requires_grad=False)
+
+    inputs = (
+        # random tests including -1 target labels
+        ([], make_target([], low=-1, high=1)),
+        ([S], make_target([S], low=-1, high=S)),
+        ([M, S], make_target([M, S], low=-1, high=S)),
+        # repeated target labels and -1 (labels after the first -1 are ignored)
+        ([], make_target_tensor(-1)),
+        ([7], make_target_tensor([2, 0, 6, -1, 4, -1, 6])),
+        ([4, 5], make_target_tensor([[4, -1, 0, -1, 2], [0, 0, 4, 1, 4], [-1, 3, -1, 1, 0], [4, 3, 2, 1, 0]])),
+    )
+    reductions = (None, "none", "mean", "sum")
+
+    for (shape, target), reduction in product(inputs, reductions):
+        kwargs = {}
+        if reduction is not None:
+            kwargs["reduction"] = reduction
+        yield SampleInput(_make_tensor(shape), args=(target,), kwargs=kwargs)
+
+
+def error_inputs_multilabel_margin_loss(op, device, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=torch.float32)
+    # invalid reduction
+    yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4),), kwargs={'reduction': 'abc'}),
+                     error_type=ValueError, error_regex='abc is not a valid value for reduction')
+    # invalid input
+    yield ErrorInput(SampleInput(make_input(5, 0), args=(make_input(5, 4),), kwargs={}),
+                     error_type=RuntimeError,
+                     error_regex=r'Expected non-empty vector or matrix with optional 0-dim batch size, but got: \[5, 0\]')
+    yield ErrorInput(SampleInput(make_input(0,), args=(make_input(0,),), kwargs={}),
+                     error_type=RuntimeError,
+                     error_regex=r'Expected non-empty vector or matrix with optional 0-dim batch size, but got: \[0\]')
+    # invalid target
+    yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(4,),), kwargs={}),
+                     error_type=RuntimeError,
+                     error_regex=r'inconsistent target size: \[4\] for input of size: \[5, 4\]')
+    yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input((),),), kwargs={}),
+                     error_type=RuntimeError,
+                     error_regex=r'inconsistent target size: \[\] for input of size: \[5, 4\]')
+
+
+def get_independent_tensor(tensor):
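+    # SampleInputs should not share tensors, so hand out a clone that preserves the
+    # requires_grad flag of the original.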
+    return tensor.clone().requires_grad_(tensor.requires_grad)
+
+def sample_inputs_randint(self, device, dtype, requires_grad, **kwargs):
+    low = 2
+    high = 10
+
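+    # torch.randint has two overloads, randint(high, size, ...) and randint(low, high, size, ...);
+    # both are exercised below by reusing the like-fns samples for shapes and kwargs.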
+    for sample in sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs):
+        sample.kwargs.setdefault('device', device)
+        # With high
+        yield SampleInput(high, sample.input.shape, *sample.args, **sample.kwargs)
+        # With low and high
+        yield SampleInput(low, high, sample.input.shape, *sample.args, **sample.kwargs)
+
+def sample_inputs_randint_like(self, device, dtype, requires_grad, **kwargs):
+    low = 2
+    high = 10
+
+    for sample in sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs):
+        # With high
+        yield SampleInput(
+            sample.input,
+            high,
+            *sample.args,
+            **sample.kwargs)
+        # With low and high
+        yield SampleInput(
+            get_independent_tensor(sample.input),
+            low,
+            high,
+            *sample.args,
+            **sample.kwargs)
+
+def sample_inputs_margin_ranking_loss(op_info, device, dtype, requires_grad, **kwargs):
+    _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    shapes = (
+        (),
+        (S,),
+        (S, S),
+        (S, S, S),
+    )
+
+    margins = (0., 1.)
+    reductions = ('sum', 'mean', 'none')
+
+    for shape in shapes:
+        for margin, reduction in product(margins, reductions):
+            kwargs = {'margin': margin, 'reduction': reduction}
+            yield SampleInput(_make_tensor(shape),
+                              args=(_make_tensor(shape, requires_grad=False),
+                                    _make_tensor(shape, requires_grad=False)),
+                              kwargs=kwargs)
+
+def reference_inputs_margin_ranking_loss(op, device, dtype, requires_grad, **kwargs):
+    yield from sample_inputs_margin_ranking_loss(op, device, dtype, requires_grad, **kwargs)
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    for reduction in ('sum', 'mean', 'none'):
+        if dtype.is_floating_point:  # only supports ints and floats
+            # NaN propagation
+            inp1 = make_input((10, ))
+            inp1[2] = float('nan')
+            inp2 = make_input((10, ))
+            inp2[4] = float('nan')
+            target = make_input((10, ))
+            target[9] = float('nan')
+            yield SampleInput(inp1, args=(inp2, target), kwargs={'reduction': reduction})
+
+            # Inf handling
+            inp1 = make_input((10, ))
+            inp1[1] = float('inf')
+            inp2 = make_input((10, ))
+            inp2[4] = float('inf')
+            target = make_input((10, ))
+            target[7] = float('inf')
+            yield SampleInput(inp1, args=(inp2, target), kwargs={'reduction': reduction})
+
+        # Broadcasting
+        inp1 = make_input((5, 2))
+        inp2 = make_input((5, 1))
+        target = make_input((1, 2))
+        yield SampleInput(inp1, args=(inp2, target), kwargs={'reduction': reduction})
+
+def error_inputs_margin_ranking_loss(op, device, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=torch.float32)
+    # invalid reduction value.
+    yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4), make_input(5, 4),), kwargs={'reduction': 'abc'}),
+                     error_type=ValueError, error_regex='is not a valid value')
+    # invalid input shapes
+    yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4), make_input(5,),)),
+                     error_regex='margin_ranking_loss : All input tensors should')
+
+def sample_inputs_new_fns(self, device, dtype, requires_grad, *, is_strided=False, **kwargs):
+    # input_shape, output_shape, strides, kwargs
+    # lengths of output_shape and strides must be equal
+    inputs = [
+        ((), (), (), {}),
+        ((S, S), (2, 0), (3, 4), {}),
+        ((0, S, 0), (3, 2, 2), (1, 2, 3), {}),
+        ((S,), (2, 3), (7, 8), {'dtype': dtype, 'device': device}),
+        # Hard-code some dtypes/devices. We want to test cases where the
+        # (dtype, device) is different from the input's (dtype, device)
+        ((S,), (10,), (S,), {'dtype': torch.double if device != 'mps:0' else torch.float}),
+        ((S,), (1, 1, 12), (S, L, M), {'device': 'cpu'}),
+        ((S,), (2, 2, 2), (L, M, S), {'dtype': torch.double, 'device': 'cpu'}),
+    ]
+    if torch.cuda.is_available():
+        inputs.append(((S,), (7, 2), (3, 4), {'device': 'cuda'}))
+
+    for input_shape, output_shape, strides, kwargs in inputs:
+        t = make_tensor(input_shape, dtype=dtype, device=device,
+                        low=None, high=None,
+                        requires_grad=requires_grad)
+        if is_strided:
+            yield SampleInput(t, output_shape, strides, **kwargs)
+        else:
+            yield SampleInput(t, output_shape, **kwargs)
+
+def sample_inputs_empty_strided(op, device, dtype, requires_grad=False, **kwargs):
+
+    inputs = [
+        ((), (), {'dtype': dtype, 'device': device}),
+        ((S,), (4,), {'dtype': dtype, 'device': device}),
+        ((S, S), (2, 1), {'dtype': dtype, 'device': device}),
+        ((S, S, S), (2, 0, 1), {'dtype': dtype, 'device': device}),
+    ]
+
+    for shape, strides, kwargs in inputs:
+        yield SampleInput(shape, strides, requires_grad=requires_grad, **kwargs)
+
+def sample_inputs_empty(op, device, dtype, requires_grad, **kwargs):
+    # shape
+    cases = (
+        (), (0,), (1,), (1, 3, 5), (5, 3, 1), (1, 0, 5, 1),
+    )
+
+    for case in cases:
+        yield SampleInput(case, device=device, dtype=dtype, requires_grad=requires_grad)
+
+def sample_inputs_empty_permuted(op, device, dtype, requires_grad, **kwargs):
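+    # empty_permuted(size, physical_layout) creates an uninitialized tensor whose dimensions
+    # are laid out in memory in the order given by physical_layout, so every permutation of
+    # the dims is tried for each shape.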
+    # shape
+    cases = (
+        (), (0,), (1,), (1, 3, 5), (5, 3, 1), (1, 0, 5, 1),
+    )
+
+    for case in cases:
+        for layout in itertools.permutations(range(len(case))):
+            yield SampleInput(case, layout, device=device, dtype=dtype, requires_grad=requires_grad)
+
+def error_inputs_empty_permuted(op_info, device, **kwargs):
+    yield ErrorInput(
+        SampleInput((2,), args=((0, 1),)),
+        error_type=RuntimeError,
+        error_regex="Number of dimensions in size does not match the length of the physical_layout"
+    )
+    yield ErrorInput(
+        SampleInput((2,), args=((3,),)),
+        error_type=RuntimeError,
+        error_regex="Dimension out of range"
+    )
+    yield ErrorInput(
+        SampleInput((2, 3), args=((0, 0),)),
+        error_type=RuntimeError,
+        error_regex="Duplicate dim not allowed"
+    )
+
+def sample_inputs_scalar_tensor(op, device, dtype, requires_grad, **kwargs):
+    # Not including a scalar tensor in vals because meta tests start failing due to
+    # lack of meta support for _local_scalar_dense
+    # torch.tensor(2, device=device)
+    vals = (-5, 0, 1)
+
+    for item in vals:
+        yield SampleInput(item, device=device, dtype=dtype, requires_grad=requires_grad)
+
+def sample_inputs_eye(op, device, dtype, requires_grad, **kwargs):
+    # only ints >= 0 are allowed for both arguments; m may also be omitted
+    sizes = (None, 0, 1, 2, 3, 4, 7, L, M, S)
+
+    for n, m in product(sizes, sizes):
+        if n is None:
+            continue
+
+        # TODO: no layout
+        _kwargs = {'device': device, 'dtype': dtype, 'requires_grad': requires_grad}
+        if m is None:
+            yield SampleInput(n, args=(), kwargs=_kwargs)
+        else:
+            yield SampleInput(n, args=(m,), kwargs=_kwargs)
+
+def error_inputs_eye(op_info, device, **kwargs):
+    # TODO: no layout
+    _kwargs = {'device': device, 'dtype': torch.float32}
+
+    yield ErrorInput(
+        SampleInput(-1, args=(), kwargs=_kwargs),
+        error_regex="n must be greater or equal to 0, got -1"
+    )
+
+    yield ErrorInput(
+        SampleInput(-7, args=(42,), kwargs=_kwargs),
+        error_regex="n must be greater or equal to 0, got -7"
+    )
+
+    yield ErrorInput(
+        SampleInput(0, args=(-3,), kwargs=_kwargs),
+        error_regex="m must be greater or equal to 0, got -3"
+    )
+
+
+def sample_inputs_new_full(self, device, dtype, requires_grad, **kwargs):
+    def get_val(dtype):
+        return make_tensor([], dtype=dtype, device="cpu").item()
+
+    for sample in sample_inputs_new_fns(self, device, dtype, requires_grad, **kwargs):
+        # The scalar we are passing to new_full must be the same dtype
+        # as the one of the resulting tensor
+        use_dtype = sample.kwargs['dtype'] if 'dtype' in sample.kwargs else dtype
+        yield SampleInput(
+            sample.input, *sample.args, get_val(use_dtype), **sample.kwargs)
+
+def sample_inputs_full_like(self, device, dtype, requires_grad, **kwargs):
+    def get_val(dtype):
+        return make_tensor([], dtype=dtype, device="cpu").item()
+
+    double_dtype = torch.double if device != "mps:0" else torch.float
+    inputs = [
+        ((), get_val(dtype), {}),
+        ((S, S), get_val(dtype), {}),
+        ((0, S, 0), get_val(dtype), {}),
+        ((S,), get_val(dtype), {'dtype': dtype, 'device': device}),
+        # Hard-code some dtypes/devices. We want to test cases where the
+        # (dtype, device) is different from the input's (dtype, device)
+        ((S,), get_val(double_dtype), {'dtype': double_dtype}),
+        ((S,), get_val(dtype), {'device': 'cpu'}),
+        ((S,), get_val(double_dtype), {'dtype': double_dtype, 'device': 'cpu'}),
+    ]
+    if torch.cuda.is_available():
+        inputs.append(((S,), get_val(dtype), {'device': 'cuda'}))
+
+    if torch.mps.is_available() and dtype not in [torch.float64, torch.complex128, torch.uint32, torch.uint16]:
+        inputs.append(((S,), get_val(dtype), {'device': 'mps'}))
+
+    if not dtype.is_signed:
+        # For unsigned dtypes, negative values are converted.
+        inputs.append(((S,), -get_val(dtype), {}))
+
+    for shape, fill_value, kwargs in inputs:
+        t = make_tensor(shape, dtype=dtype, device=device,
+                        low=None, high=None,
+                        requires_grad=requires_grad)
+        yield SampleInput(t, fill_value, **kwargs)
+
+def sample_inputs_multinomial(self, device, dtype, requires_grad, **kwargs):
+    cases = [
+        ([3], 3, {}),
+        ([10], 3, {}),
+        ([3, 10], 3, {}),
+        ([3], 3, dict(replacement=False)),
+        ([3], 3, dict(replacement=True)),
+        ([3, 4], 4, dict(replacement=True)),
+        ([3, 4], 4, dict(replacement=False)),
+    ]
+
+    for shape, num_samples, kwargs in cases:
+        t = make_tensor(shape, dtype=dtype, device=device,
+                        low=0, high=None,
+                        requires_grad=requires_grad)
+        yield SampleInput(t, num_samples, **kwargs)
+
+def sample_inputs_normal_common(self, device, dtype, requires_grad, cases, **kwargs):
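+    # Each case gives the mean and std either as a Python scalar or as a shape (a list)
+    # from which a tensor with values >= 0 is built.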
+    def get_value_or_make_tensor(value_or_shape):
+        if isinstance(value_or_shape, list):
+            return make_tensor(value_or_shape, dtype=dtype, device=device,
+                               low=0, high=None,
+                               requires_grad=requires_grad)
+        return value_or_shape
+
+    for value_or_mean_shape, value_or_std_shape, kwargs in cases:
+        mean = get_value_or_make_tensor(value_or_mean_shape)
+        std = get_value_or_make_tensor(value_or_std_shape)
+        yield SampleInput(mean, std, **kwargs)
+
+def sample_inputs_normal_tensor_first(self, device, dtype, requires_grad, **kwargs):
+    # value_or_size, value_or_size, kwargs
+    cases = [
+        ([], [], {}),
+        ([3], [3], {}),
+        ([3, 4, 2], [3, 4, 2], {}),
+        ([2, 3], 1.1, {}),
+        ([1, 2, 3], [5, 2, 3], {}),  # broadcasting
+    ]
+
+    return sample_inputs_normal_common(self, device, dtype, requires_grad, cases, **kwargs)
+
+def sample_inputs_normal_tensor_second(self, device, dtype, requires_grad, **kwargs):
+    yield SampleInput(1.6, 0.3, [2, 3], dtype=dtype, device=device)
+    yield SampleInput(1.6, 0.3, [2, 2, 2], dtype=dtype, layout=torch.strided, device=device)
+    yield SampleInput(2.7, make_tensor([4, 3], dtype=dtype, device=device, low=0, high=None, requires_grad=requires_grad))
+
+def sample_inputs_bernoulli(self, device, dtype, requires_grad, **kwargs):
+    shapes = [
+        [3],
+        [],
+        [0, 3],
+        [2, 3, 4],
+    ]
+
+    for shape in shapes:
+        t = make_tensor(shape, dtype=dtype, device=device,
+                        low=0, high=1,
+                        requires_grad=requires_grad)
+        yield SampleInput(t)
+
+def error_inputs_bernoulli(op_info, device, **kwargs):
+    # more than one element of the written-to tensor refers to a single memory location
+    x = torch.rand((1,), device=device).expand((6,))
+    err_msg = 'unsupported operation'
+    yield ErrorInput(SampleInput(torch.rand_like(x), kwargs={'out': x}),
+                     error_regex=err_msg)
+
+def sample_inputs_logcumsumexp(self, device, dtype, requires_grad, **kwargs):
+    inputs = (
+        ((S, S, S), 0),
+        ((S, S, S), 1),
+        ((), 0),
+    )
+
+    for large_number in (True, False):
+        for shape, dim in inputs:
+            t = make_tensor(shape, dtype=dtype, device=device,
+                            low=None, high=None,
+                            requires_grad=requires_grad)
+
+            if large_number and t.dim() > 0:
+                t[0] = 10000
+            yield SampleInput(t, dim)
+
+def sample_inputs_trace(self, device, dtype, requires_grad, **kwargs):
+    yield SampleInput(
+        make_tensor((S, S), dtype=dtype, device=device,
+                    low=None, high=None,
+                    requires_grad=requires_grad))
+
+
+def error_inputs_trace(op, device):
+    yield ErrorInput(SampleInput(make_tensor((3, 4, 5), dtype=torch.float32, device=device)), error_regex="expected a matrix")
+
+
+def sample_inputs_renorm(self, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
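+    # Each case is (shape, (p, dim, maxnorm)) for torch.renorm.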
+    cases = (((S, S, S), (2, 1, 0.5)),
+             ((S, S, S), (2, -1, 0.5)),
+             ((S, S, S), (1, 2, 3)),
+             ((S, S, S), (float('inf'), 2, 0.5)),
+             )
+
+    for shape, args in cases:
+        yield SampleInput(make_arg(shape), args=args)
+
+
+def sample_inputs_transpose_swapdims(self, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+
+    cases = (((1, 2, 3), (-1, -2)),
+             ((1, 2, 3), (-1, 2)),
+             ((1, 2, 3), (1, -2)),
+             ((1, 2, 3), (1, 2)),
+             ((), (0, 0)),
+             ((1, ), (0, 0)),
+             ((M, M), (0, 1)),
+             ((S, S, S), (2, 0)), )
+
+    for shape, args in cases:
+        yield SampleInput(make_arg(shape), args=args)
+
+def _numpy_ref_transpose(a, dim0, dim1):
+    if a.ndim <= 1:
+        return a
+
+    return np.swapaxes(a, dim0, dim1)
+
+def sample_inputs_adjoint(self, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+
+    shapes = ((1, 2, 3), (M, M), (S, S, S), (S, M, S), (M, S, M, S))
+    return (SampleInput(make_arg(shape)) for shape in shapes)
+
+def sample_inputs_T(self, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+
+    shapes = ((M, M), (M, L))
+    return (SampleInput(make_arg(shape)) for shape in shapes)
+
+def error_inputs_T(self, device, has_ndims_error=False):
+    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
+
+    # Deprecated behavior in regular PyTorch, but throws an error in primTorch:
+    # https://github.com/pytorch/pytorch/issues/86968
+    if has_ndims_error:
+        # ndims == 1
+        yield ErrorInput(SampleInput(make_arg(M)),
+                         error_regex=(r'The use of `x\.T` on tensors of dimension other than 0 or 2 '
+                                      r'to reverse their shape is not supported\.'))
+
+        # ndims > 2
+        yield ErrorInput(SampleInput(make_arg(M, S, L)),
+                         error_regex=(r'The use of `x\.T` on tensors of dimension other than 0 or 2 '
+                                      r'to reverse their shape is not supported\.'))
+
+
+def sample_inputs_singular_matrix_factors(op_info, device, dtype, requires_grad=False):
+    """
+    This function produces two tensors of shape (*, m, k) and (*, n, k) with k <= min(m, n).
+    Their matrix product can be used to generate a tensor of shape (*, m, n) of rank k.
+    """
+
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    batches = [(), (2,)]
+    size = [3, 4]
+    for batch, m, n in product(batches, size, size):
+        k = 2
+        a = make_arg((*batch, m, k))
+        b = make_arg((*batch, n, k))
+        yield a, b
+
+
+def sample_inputs_svd_lowrank(op_info, device, dtype, requires_grad=False, **kwargs):
+    # Function that's well defined on the outputs for complex inputs
+    def fn(usv):
+        U, S, V = usv
+        return U @ V.mH, S
+
+    for (a, b) in sample_inputs_singular_matrix_factors(op_info, device, dtype, requires_grad):
+        *batch, m, k = a.shape
+        n = b.shape[-2]
+
+        # NOTE: since svd_lowrank relies on a non-rank-revealing SVD, it inherits
+        # the problem of unstable behavior with repeated singular values,
+        # including zeros.
+        # Since we want to avoid (repeated) zeros as singular values,
+        # we can only use k for q.
+        # This issue could be resolved by using a rank-revealing SVD
+        # which does not include "zero" singular values.
+        yield SampleInput(a, b, q=k, M=None).with_metadata(output_process_fn_grad=fn)
+
+    for (a, b) in sample_inputs_singular_matrix_factors(op_info, device, dtype, requires_grad):
+        *batch, m, k = a.shape
+        n = b.shape[-2]
+        M = make_tensor((*batch, m, n), dtype=dtype, device=device, requires_grad=requires_grad)
+        yield SampleInput(a, b, q=k, M=M).with_metadata(output_process_fn_grad=fn)
+
+def chunk_iter(iterable, size):
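+    # Yield successive tuples of at most `size` items from `iterable`.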
+    it = iter(iterable)
+    while True:
+        chunk = tuple(islice(it, size))
+        if not chunk:
+            break
+        yield chunk
+
+def sample_inputs_pca_lowrank(op_info, device, dtype, requires_grad=False, **kwargs):
+    # we reuse samples from svd_lowrank, which come in groups of two: one with
+    # kwargs['M'] = None and one with kwargs['M'] set to a tensor
+    samples = sample_inputs_svd_lowrank(op_info, device, dtype, requires_grad, **kwargs)
+    for s1, s2 in chunk_iter(samples, 2):
+        del s1.kwargs['M']
+        del s2.kwargs['M']
+        s1.kwargs['center'] = False
+        s2.kwargs['center'] = True
+        yield s1
+        yield s2
+
+def np_sinc_with_fp16_as_fp32(x):
+    # Wraps numpy's sinc function so that fp16 values are promoted to fp32
+    # before sinc is invoked. Context: numpy's sinc returns NaN when evaluated
+    # at 0 for fp16.
+    if x.dtype == np.float16:
+        return np.sinc(x.astype(np.float32))
+    else:
+        return np.sinc(x)
+
+def sample_inputs_broadcast_to(op_info, device, dtype, requires_grad, **kwargs):
+    test_cases = (
+        ((S, 1, 1), (S, S, S)),
+        ((S, 1, S), (S, S, S)),
+        ((S, 1), (S, S, S)),
+        ((1,), (S, S, S)),
+        ((1, S), (1, 1, S)),
+        ((), ()),
+        ((), (1, 3, 2)),
+    )
+
+    return (
+        SampleInput(
+            make_tensor(size, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad),
+            shape,
+        ) for size, shape in test_cases)
+
+def sample_inputs_broadcast_tensors(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+    test_cases: tuple[tuple] = (((3,), (1, 2, 1), (1, 1), (5, 1, 1),),)
+
+    for shape, *other_shapes in test_cases:
+        yield SampleInput(make_arg(shape), args=tuple(make_arg(s) for s in other_shapes))
+
+def reference_inputs_broadcast_tensors(op, device, dtype, requires_grad, **kwargs):
+    yield from sample_inputs_broadcast_tensors(op, device, dtype, requires_grad, **kwargs)
+
+    m = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+    n = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad, noncontiguous=True)
+
+    cases = (
+        ((), (1, 1), (1, 1, 7, 1), (3, 1, 1)),
+        ((3, 5, 6), (1, 3, 5, 6), (1, 1, 1, 1, 6), (8, 3, 5, 6))
+    )
+
+    for a, b, c, d in cases:
+        yield SampleInput(m(a), args=(m(b), m(c), m(d)))
+        yield SampleInput(n(a), args=(n(b), n(c), n(d)))
+
+def sample_inputs_block_diag(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+    test_cases: tuple[tuple] = (
+        ((1, S), (2, S), (3, S),),
+        ((S, 1), (S, 2), (S, 3),),
+        ((1,), (2,), (3,),),
+        ((2, S), (S,))
+    )
+
+    for shape, *other_shapes in test_cases:
+        yield SampleInput(make_arg(shape), args=tuple(make_arg(s) for s in other_shapes))
+        # We also want to test mixed complex-non-complex inputs to block_diag
+        if dtype == torch.complex32 or dtype == torch.complex64:
+            non_complex_dtype = torch.float32 if dtype == torch.complex32 else torch.float64
+            make_arg_non_complex = partial(make_tensor, dtype=non_complex_dtype, device=device, requires_grad=requires_grad)
+            yield SampleInput(make_arg_non_complex(shape), args=tuple(make_arg(s) for s in other_shapes))
+
+def sample_inputs_cdist(op_info, device, dtype, requires_grad, **kwargs):
+    small_S = 2
+    test_cases = (
+        ((S, S, 2), (S, S + 1, 2)),
+        ((S, S), (S, S)),
+        ((S, S, S), (S, S, S)),
+        ((3, 5), (3, 5)),
+        ((2, 3, 5), (2, 3, 5)),
+        ((1, 2, 3), (1, 2, 3)),
+        ((1, 1), (S, 1)),
+        ((0, 5), (4, 5)),
+        ((4, 5), (0, 5)),
+        ((0, 4, 5), (3, 5)),
+        ((4, 5), (0, 3, 5)),
+        ((0, 4, 5), (1, 3, 5)),
+        ((1, 4, 5), (0, 3, 5)),
+        # Using S here would make this one test take 9s
+        ((small_S, small_S, small_S + 1, 2), (small_S, small_S, small_S + 2, 2)),
+        ((small_S, 1, 1, small_S), (1, small_S, small_S)),
+        ((1, 1, small_S), (small_S, 1, small_S, small_S)),
+    )
+
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
+        # FIXME add an override for JIT and revert 0. back to 0
+        # since it's accepted by eager
+        for p in [0., 1., 2., 3., 0.5, 1.5, 2.5, float("inf")]:
+            for t1_size, t2_size in test_cases:
+                # The args should never be non-contiguous as this is not supported in the backward
+                yield SampleInput(make_arg(t1_size), make_arg(t2_size), p, cm)
+
+def _fill_np(a, value):
+    a = a.copy()
+    a.fill(value)
+    return a
+
+def _fill_sample_kwargs(device, dtype, input):
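+    # Returns two kwargs dicts (identical here); the sample_kwargs hook expects a pair,
+    # presumably one for the torch op and one for its reference function.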
+    if dtype is torch.bool:
+        value = True
+    else:
+        value = 3
+
+    return ({'value': value}, {'value': value})
+
+def sample_inputs_comparison_ops(op, device, dtype, requires_grad, **kwargs):
+    yield from sample_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs)
+
+    # Adds a sample input where both tensors have the same values
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    lhs = make_arg((S, S))
+    yield SampleInput(lhs, args=(lhs.clone(),))
+
+def sample_inputs_stack(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # shape x number of tensors
+    cases = (
+        ((3, 4), 1),
+        ((1, 2, 1, 4), 3),
+        ((0, 1, 0), 2),)
+
+    for shape, num_tensors in cases:
+        tensors = [make_arg(shape) for _ in range(num_tensors)]
+        for dim in range(-1, len(shape) - 1):
+            yield SampleInput(tensors, args=(dim,))
+
+
+def sample_inputs_chunk_cat(op_info, device, dtype, requires_grad, **kwargs):
+    # 1. If input tensors have different ndims, dim should be non-negative and less than
+    #    the ndim of every input tensor. If all input tensors have the same ndim, both
+    #    negative and non-negative dim are supported.
+    # 2. For wrapped_dim, all tensors should have the same size for the 0,...,wrapped_dim-1
+    #    dimensions. There are no requirements on the (wrapped_dim, ...)-th dimensions.
+    # 3. num_chunks is expected to be positive.
+    # 4. The input tensor list is expected to be non-empty, and each input tensor should
+    #    have at least 1 element.
+    # 5. Non-contiguous input tensors are allowed.
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    same_ndim_cases = (
+        (
+            [
+                torch.Size([1, 2, 3]),
+                torch.Size([1, 2, 3]),
+            ], -1, 5
+        ),
+        (
+            [
+                torch.Size([1, 2, 129]),
+                torch.Size([1, 2, 297]),
+            ], -1, 5
+        ),
+        (
+            [
+                torch.Size([1, 2, 3]),
+                torch.Size([1, 2, 3]),
+            ], 1, 5
+        ),
+        (
+            [
+                torch.Size([3, 3, 2, 1]),
+                torch.Size([1, 4, 2, 2]),
+                torch.Size([2, 1, 3, 3]),
+            ], 0, 2
+        ),
+    )
+    for sizes, dim, num_chunks in same_ndim_cases:
+        tensors = [make_arg(size) for size in sizes]
+        yield SampleInput(tensors, args=(dim, num_chunks))
+
+    different_ndim_case = [
+        torch.Size([2, 3, 3]),
+        torch.Size([2, 3, 1, 2]),
+        torch.Size([2, 3]),
+        torch.Size([2, 3, 2]),
+        torch.Size([2, 3, 271]),
+    ]
+    max_dim, num_chunks = 2, 3
+    for dim in range(max_dim):
+        tensors = []
+        for size in different_ndim_case:
+            tensors.append(make_arg(size))
+        yield SampleInput(tensors, args=(dim, num_chunks))
+
+    # non-contiguous
+    for dim in range(max_dim):
+        tensors = []
+        for size in different_ndim_case:
+            # make the last 2 dims column-major (i.e. non-contiguous)
+            t = make_arg(size).transpose(-2, -1).contiguous().transpose(-2, -1)
+            tensors.append(t)
+        yield SampleInput(tensors, args=(dim, num_chunks))
+
+def error_inputs_chunk_cat(op_info, device, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
+
+    # input tensors have different ndims but dim is negative
+    sizes, dim, num_chunks = [torch.Size([2, 3]), torch.Size([4,])], -1, 3
+    tensors = [make_arg(size) for size in sizes]
+    yield ErrorInput(
+        SampleInput(tensors, args=(dim, num_chunks)),
+        error_regex='_chunk_cat expects non-negative dim when input tensors have different ndims',
+    )
+
+    # input tensors have different ndims but dim >= ndim of some input tensors
+    sizes, dim, num_chunks = [torch.Size([2, 3]), torch.Size([4,])], 1, 3
+    tensors = [make_arg(size) for size in sizes]
+    yield ErrorInput(
+        SampleInput(tensors, args=(dim, num_chunks)),
+        error_regex='_chunk_cat expects dim < ndim for all input tensors',
+    )
+
+    # some tensors have different sizes for 0, ..., dim-1 dimensions.
+    sizes, dim, num_chunks = [torch.Size([2, 3, 4]), torch.Size([4, 3])], 1, 3
+    tensors = [make_arg(size) for size in sizes]
+    yield ErrorInput(
+        SampleInput(tensors, args=(dim, num_chunks)),
+        error_regex='_chunk_cat expects same sizes of 0,...,dim-1 dimensions for all tensors',
+    )
+
+    # negative num_chunks
+    sizes, dim, num_chunks = [torch.Size([2,]), torch.Size([3,])], 0, -1
+    tensors = [make_arg(size) for size in sizes]
+    yield ErrorInput(
+        SampleInput(tensors, args=(dim, num_chunks)),
+        error_regex='_chunk_cat expects positive num_chunks',
+    )
+
+    # zero as num_chunks
+    sizes, dim, num_chunks = [torch.Size([2,]), torch.Size([3,])], 0, 0
+    tensors = [make_arg(size) for size in sizes]
+    yield ErrorInput(
+        SampleInput(tensors, args=(dim, num_chunks)),
+        error_regex='_chunk_cat expects positive num_chunks',
+    )
+
+    # empty input tensor list
+    dim, num_chunks = 0, 1
+    yield ErrorInput(
+        SampleInput([], args=(dim, num_chunks)),
+        error_regex='_chunk_cat expects a non-empty input tensor list',
+    )
+
+    # empty input tensor with 0 elements
+    sizes, dim, num_chunks = [torch.Size([0,]), torch.Size([3,])], 0, 1
+    tensors = [make_arg(size) for size in sizes]
+    yield ErrorInput(
+        SampleInput(tensors, args=(dim, num_chunks)),
+        error_regex='_chunk_cat expects non-empty tensor',
+    )
+
+
+def sample_inputs_cat_concat(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    cases: tuple[tuple, tuple, dict] = (  # type: ignore[assignment]
+        ((S, S), (S, S), {'dim': -1}),
+        ((S, S), (S, S), {'dim': 1}),
+        ((M, S), (S, S), {'dim': 0}),  # different shapes
+        ((1, 2, 3), (1, 2, 3), {'dim': -2}),
+        ((0,), (0,), {'dim': 0}),  # empty tensor
+        ((0,), (S, S), {'dim': 1}),  # empty tensor with unempty and dim=1 (special case for legacy_cat_wrap_dim)
+        ((0, S), (S, S), {'dim': 0}),
+        ((1,), (1,), {})  # dim not passed, fallback to default
+    )
+
+    for input_shape1, input_shape2, kwargs in cases:
+        yield SampleInput([make_arg(input_shape1), make_arg(input_shape2)], kwargs=kwargs)
+
+    # from coat_lite_mini
+    yield SampleInput([make_arg((2, 2, 2, 2), memory_format=torch.channels_last)], args=(1,),)
+
+def error_inputs_cat(op_info, device, **kwargs):
+
+    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
+
+    # error input where more than one element of the written-to tensor refers to a single memory location
+    yield ErrorInput(SampleInput([make_arg((S, S)), make_arg((S, S))],
+                                 kwargs={'out': make_arg((1, S)).expand((2 * S, S))}),
+                     error_regex='unsupported operation')
+
+    # error inputs for empty tensors
+    yield ErrorInput(SampleInput([], kwargs={'dim': 1}),
+                     error_regex='non-empty list of Tensors')
+
+    # error inputs for different sizes
+    yield ErrorInput(SampleInput([make_arg((S, S, L, L)), make_arg((S, 0, L - 1, L))], kwargs={'dim': 1}),
+                     error_regex='Sizes of tensors must match except in dimension')
+    yield ErrorInput(SampleInput([make_arg((S, 0, L - 1, L)), make_arg((S, S, L, L))], kwargs={'dim': 1}),
+                     error_regex='Sizes of tensors must match except in dimension')
+
+    # error inputs for different dimensions
+    yield ErrorInput(SampleInput([make_arg((S - 1, 0)), make_arg((S, 0, L - 1, L))], kwargs={'dim': 1}),
+                     error_regex='Tensors must have same number of dimensions')
+    yield ErrorInput(SampleInput([make_arg((S, 0, L - 1, L)), make_arg((S - 1, 0))], kwargs={'dim': 1}),
+                     error_regex='Tensors must have same number of dimensions')
+
+    # error inputs for same memory locations
+    x = torch.zeros((0), device=device)
+    y = torch.randn((4, 6), device=device)
+
+    err_msg = "the written-to tensor refer to a single memory location"
+
+    yield ErrorInput(SampleInput((x, y), kwargs={'dim': 0, 'out': x}),
+                     error_regex=err_msg)
+    yield ErrorInput(SampleInput((x, y), kwargs={'dim': 0, 'out': y}),
+                     error_regex=err_msg)
+
+    z = torch.zeros((4, 6), device=device)
+    yield ErrorInput(SampleInput((y, z), kwargs={'out': z[:2, :]}),
+                     error_regex=err_msg)
+
+    # error inputs for different devices
+    if torch.device(device).type == 'cuda':
+        x_cuda = make_tensor((3, 3), device=device, dtype=torch.float32)
+        y_cpu = make_tensor((3, 3), device='cpu', dtype=torch.float32)
+        yield ErrorInput(SampleInput((x_cuda, y_cpu)),
+                         error_regex='Expected all tensors to be on the same device')
+
+    # error inputs for different input sizes for more than 2 tensors
+    yield ErrorInput(SampleInput([make_arg((L, 1)), make_arg((L, 1, 1)), make_arg((L, 1, 1))]),
+                     error_regex='Tensors must have same number of dimensions')
+
+    yield ErrorInput(SampleInput([make_arg((S, 1, M)), make_arg((S, 1, 1)), make_arg((S, M, 1))],
+                                 kwargs={'dim': 1}),
+                     error_regex='Sizes of tensors must match')
+
+    # error inputs for None input
+    yield ErrorInput(SampleInput((make_arg((S, 1, 1)), None)), error_type=TypeError,
+                     error_regex='got None')
+
+    # error inputs for zero-dimensional tensors
+    yield ErrorInput(SampleInput([make_arg(()), make_arg(())]),
+                     error_regex='zero-dimensional.*cannot be concatenated')
+
+    # error inputs for different dtype of out tensors
+    d = make_tensor((2, 3), device=device, dtype=torch.double)
+    x = make_tensor((2, 3), device=device, dtype=torch.float32)
+    yield ErrorInput(SampleInput(x, kwargs={'out': d}), error_type=TypeError,
+                     error_regex='invalid combination of arguments')
+
+def reference_inputs_cat(op, device, dtype, requires_grad, **kwargs):
+    yield from sample_inputs_cat_concat(op, device, dtype, requires_grad, **kwargs)
+
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # Noncontiguous type promoting tensors
+    a = make_arg((3, 4, 2))
+    b = make_arg((3, 2, 2), noncontiguous=True, dtype=torch.double)
+    c = make_arg((3, 3, 2), dtype=torch.float16).permute(1, 0, 2)
+
+    yield SampleInput((a, b, c), kwargs={'dim': 1})
+
+    # Special 1D tensor with dim length of 0 case
+    a = make_arg((0,))
+    b = make_arg((3, 2, 2))
+
+    yield SampleInput((a, b, a))
+    yield SampleInput((a, a, a))
+
+def _elementwise_type_promo_np(*args, type_promotion_kind):
+    def _maybe_torch(x):
+        if isinstance(x, np.ndarray):
+            return torch.from_numpy(x)
+        return x
+
+    flattened = pytree.arg_tree_leaves(*args)
+    transformed = tuple(_maybe_torch(a) for a in flattened)
+    result_dtype, _ = prims.utils.elementwise_dtypes(
+        *transformed,
+        type_promotion_kind=type_promotion_kind)
+    return torch_to_numpy_dtype_dict[result_dtype]
+
+def _cat_np(input_seq, dim=0):
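+    # Mirror torch.cat's special-casing of 1-D empty tensors by dropping them before
+    # delegating to np.concatenate.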
+    inputs = tuple(a for a in input_seq if not (a.ndim == 1 and a.size == 0))
+
+    if len(inputs) == 0:
+        np_dtype = _elementwise_type_promo_np(
+            input_seq,
+            type_promotion_kind=prims.utils.ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH)
+        return np.empty(0, dtype=np_dtype)
+
+    return np.concatenate(inputs, axis=dim)
+
+def _floor_divide_np(a, b):
+    dtype = _elementwise_type_promo_np(
+        a,
+        b,
+        type_promotion_kind=prims.utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
+    if isinstance(a, np.ndarray):
+        a = a.astype(dtype)
+    if isinstance(b, np.ndarray):
+        b = b.astype(dtype)
+    return np.floor_divide(a, b)
+
+def sample_inputs_hstack_dstack_vstack(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+    tensor_shapes = (
+        # The first tensor being 1-D is a special
+        # case for hstack
+        ((S,), (S,), (S,)),
+        ((S, S), (S, S), (S, S)),
+    )
+    for s1, s2, s3 in tensor_shapes:
+        tensors = (make_arg(s1,), make_arg(s2,), make_arg(s3))
+        yield SampleInput(tensors)
+
+def error_inputs_hstack_dstack_vstack(op, device):
+    make_arg = partial(make_tensor, dtype=torch.int32, device=device, requires_grad=False)
+    tensor_shapes = (
+        ((S,), (S, S, S, S), (S,)),
+    )
+    for s1, s2, s3 in tensor_shapes:
+        tensors = (make_arg(s1,), make_arg(s2,), make_arg(s3))
+        # Different dimension tensor
+        yield ErrorInput(SampleInput(tensors), error_regex="Tensors must have same number of dimensions")
+
+    # empty tensor list
+    yield ErrorInput(SampleInput(()), error_regex="expects a non-empty TensorList")
+
+def sample_inputs_unbind(op_info, device, dtype, requires_grad, **kwargs):
+    # Note: we don't do any tests where we unbind along 0-length dims
+    # because in that case unbind returns an empty tuple, and that breaks
+    # some assumptions in some backward tests in test_ops.py
+    shape_dims = (((S,), 0),
+                  ((S, S), 0),
+                  ((S, S), 1),
+                  ((S, S), -1),
+                  ((S, 0, S), 0),
+                  ((S, S, S), 1),
+                  )
+    for shape, dim in shape_dims:
+        yield SampleInput(make_tensor(shape, dtype=dtype, device=device,
+                                      requires_grad=requires_grad),
+                          args=(dim,))
+
+def error_inputs_unbind(op_info, device):
+    make_arg = partial(make_tensor, dtype=torch.int32, device=device, requires_grad=False)
+    yield ErrorInput(SampleInput(make_arg(()), args=(0,)), error_type=IndexError,
+                     error_regex="Dimension specified as 0 but tensor has no dimensions")
+    yield ErrorInput(SampleInput(make_arg((2,)), args=(2,)), error_type=IndexError,
+                     error_regex="Dimension out of range")
+
+def reference_unbind(t, dim):
+    """A numpy implementation of torch.unbind"""
+    return tuple(s.squeeze(dim) for s in np.split(t, t.shape[dim], dim))
+
+def sample_inputs_gather(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None)
+    yield SampleInput(
+        make_arg((M, S)),
+        0,
+        gather_variable((S, S), 1, M, True, device=device))
+    yield SampleInput(
+        make_arg((M, S)),
+        0,
+        gather_variable((S, S), 1, M, True, device=device).to(torch.int32))
+    yield SampleInput(
+        make_arg((M, S)),
+        1,
+        gather_variable((M, S // 2), 0, S, True, device=device))
+    # Empty index tensor case, see: https://github.com/pytorch/pytorch/pull/65006
+    yield SampleInput(
+        make_arg((S,)),
+        0,
+        torch.tensor([], dtype=torch.uint8, device=device))
+    yield SampleInput(
+        make_arg((S,)),
+        0,
+        torch.tensor([[], []], dtype=torch.uint8, device=device))
+    # 0D tensor case
+    yield SampleInput(
+        make_arg(()),
+        0,
+        torch.tensor([0], dtype=torch.int64, device=device))
+    yield SampleInput(
+        make_arg(()),
+        0,
+        torch.tensor(0, dtype=torch.int64, device=device))
+
+def _fill_indices(idx, dim, dim_size, elems_per_row, m, n, o):
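+    # Fill `idx` along `dim` so that every row along that dimension holds `elems_per_row`
+    # distinct random indices drawn from [0, dim_size).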
+    for i in range(1 if dim == 0 else m):
+        for j in range(1 if dim == 1 else n):
+            for k in range(1 if dim == 2 else o):
+                ii = [i, j, k]
+                ii[dim] = slice(0, idx.size(dim) + 1)
+                idx[tuple(ii)] = torch.randperm(dim_size)[0:elems_per_row]
+
+def error_inputs_gather(op_info, device, **kwargs):
+    # src is [1, 2]
+    #        [3, 4]
+    src = torch.tensor(((1, 2), (3, 4)), device=device, dtype=torch.float32)
+
+    # idx is [0, 0]
+    #        [1, 0]
+    idx = torch.tensor(((0, 0), (1, 0)), device=device, dtype=torch.long)
+
+    # Index should be smaller than self except on dimension 1
+    bad_src = make_tensor((1, 1), device=device, dtype=torch.float32)
+    yield ErrorInput(SampleInput(bad_src, args=(1, idx,)),
+                     error_regex="Size does not match at dimension 0")
+
+    # TODO: FIXME
+    # out.dtype must match src.dtype
+    # Creates new src & idx since SampleInputs can't share tensors
+    src = torch.tensor(((1, 2), (3, 4)), device=device, dtype=torch.float32)
+    idx = torch.tensor(((0, 0), (1, 0)), device=device, dtype=torch.long)
+    out = torch.empty((2, 2), device=device, dtype=torch.float64)
+    yield ErrorInput(SampleInput(src, args=(1, idx), kwargs={'out': out}),
+                     error_regex="Expected out tensor to have dtype")
+
+    # src and index tensors must have the same # of dimensions
+    # idx too few dimensions
+    src = torch.tensor(((1, 2), (3, 4)), device=device, dtype=torch.float32)
+    idx = torch.tensor((0, 0), device=device, dtype=torch.long)
+    yield ErrorInput(SampleInput(src, args=(1, idx)),
+                     error_regex="Index tensor must have the same number of dimensions")
+
+    # src too few dimensions
+    src = torch.tensor((1, 2), device=device, dtype=torch.float32)
+    idx = torch.tensor(((0, 0), (1, 0)), device=device, dtype=torch.long)
+    yield ErrorInput(SampleInput(src, args=(0, idx)),
+                     error_regex="Index tensor must have the same number of dimensions")
+
+    # index out of bounds
+    # NOTE: this ErrorInput is guarded because bounds checking does not occur on CUDA devices
+    if torch.device(device).type == 'cpu':
+        src = torch.tensor(((1, 2), (3, 4)), device=device, dtype=torch.float32)
+        idx = torch.tensor(((0, 23), (1, 0)), device=device, dtype=torch.long)
+        yield ErrorInput(SampleInput(src, args=(1, idx,)),
+                         error_regex="index 23 is out of bounds for dimension")
+
+    x = torch.rand((1,), device=device).expand((3,))
+    src = torch.rand((6,), device=device)
+    ind = torch.tensor([2, 1, 0], device=device, dtype=torch.int64)
+
+    yield ErrorInput(SampleInput(src, args=(0, ind,), kwargs=dict(out=x)),
+                     error_type=RuntimeError,
+                     error_regex='unsupported operation')
+
+    yield ErrorInput(SampleInput(src, args=(0, ind,), kwargs=dict(out=src)),
+                     error_type=RuntimeError,
+                     error_regex='unsupported operation')
+
+    yield ErrorInput(SampleInput(ind.clone(), args=(0, ind[1:],), kwargs=dict(out=ind[:1])),
+                     error_type=RuntimeError,
+                     error_regex='unsupported operation')
+
+def error_inputs_take(op_info, device, **kwargs):
+    x = torch.rand((1,), device=device).expand((3,))
+    src = torch.rand((6,), device=device)
+    ind = torch.tensor([2, 1, 0], device=device, dtype=torch.int64)
+
+    yield ErrorInput(SampleInput(src, args=(ind,), kwargs=dict(out=x)),
+                     error_type=RuntimeError,
+                     error_regex='unsupported operation')
+
+    yield ErrorInput(SampleInput(src, args=(ind,), kwargs=dict(out=src)),
+                     error_type=RuntimeError,
+                     error_regex='unsupported operation')
+
+    yield ErrorInput(SampleInput(ind.clone(), args=(ind[1:],), kwargs=dict(out=ind[:-1])),
+                     error_type=RuntimeError,
+                     error_regex='unsupported operation')
+
+# Error inputs for scatter
+def error_inputs_scatter_and_scatter_add(op_info, device, **kwargs):
+    # Error when self.dtype != src.dtype (and src is not a scalar)
+    src = make_tensor((2, 5), device=device, dtype=torch.float32)
+    idx = torch.tensor(((0, 1), (1, 2)), device=device, dtype=torch.long)
+    dst = torch.zeros((3, 5), device=device, dtype=torch.double)
+    yield ErrorInput(SampleInput(dst, args=(0, idx, src)),
+                     error_regex="Expected self.dtype to be equal to src.dtype")
+
+    # Index and destination must have the same number of dimensions
+    src = make_tensor((2, 5), device=device, dtype=torch.float32)
+    idx = torch.tensor(((0, 1), (1, 2)), device=device, dtype=torch.long)
+    dst = torch.zeros((3, 5, 3), device=device, dtype=torch.float32)
+    yield ErrorInput(SampleInput(dst, args=(0, idx, src)),
+                     error_regex="Index tensor must have the same number of dimensions as self tensor")
+
+    # Index and src must have the same number of dimensions when src is not a scalar
+    src = make_tensor((2, 5, 2), device=device, dtype=torch.float32)
+    idx = torch.tensor(((34, 1), (1, 2)), device=device, dtype=torch.long)
+    dst = torch.zeros((3, 5), device=device, dtype=torch.float32)
+    yield ErrorInput(SampleInput(dst, args=(0, idx, src)),
+                     error_regex="Index tensor must have the same number of dimensions as src tensor")
+
+    # Index out of bounds
+    # NOTE: this ErrorInput is guarded because bounds checking does not occur on CUDA devices
+    if torch.device(device).type == 'cpu':
+        src = make_tensor((2, 5), device=device, dtype=torch.float32)
+        idx = torch.tensor(((34, 1), (1, 2)), device=device, dtype=torch.long)
+        dst = torch.zeros((3, 5), device=device, dtype=torch.float32)
+        yield ErrorInput(SampleInput(dst, args=(0, idx, src)),
+                         error_regex="index 34 is out of bounds for dimension 0 with size 3")
+
+def error_inputs_renorm(op_info, device, **kwargs):
+    zero_d = torch.randn((), device=device)
+    yield ErrorInput(SampleInput(zero_d, args=(0.5, 0, 1.0)), error_type=RuntimeError,
+                     error_regex="needs at least 2 dimensions, got 0 dimensions")
+
+
+def error_inputs_ormqr(op_info, device, **kwargs):
+    zero_d = torch.randn((), device=device)
+    yield ErrorInput(SampleInput(zero_d, args=(zero_d, zero_d)), error_type=RuntimeError,
+                     error_regex="input must have at least 2 dimensions")
+
+    # https://github.com/pytorch/pytorch/issues/85218
+    tensor_0 = torch.full((5, 0,), 1, device=device)
+    tensor_1 = torch.full((5,), 1, device=device)
+    tensor_2 = torch.full((5, 5,), 1, device=device)
+    bool_3 = True
+    bool_4 = True
+    yield ErrorInput(SampleInput(tensor_0, args=(tensor_1, tensor_2, bool_3, bool_4)), error_type=RuntimeError,
+                     error_regex=r"tau.shape\[-1\] must be equal to min\(other.shape\[-2\], input.shape\[-1\]\)")
+
+
+def error_inputs_diag(op_info, device, **kwargs):
+    zero_d = torch.randn((), device=device)
+    yield ErrorInput(SampleInput(zero_d, args=(0,)), error_type=RuntimeError,
+                     error_regex="1D or 2D")
+    zero_d = torch.randn(1, 1, 1, device=device)
+    yield ErrorInput(SampleInput(zero_d, args=(0,)), error_type=RuntimeError,
+                     error_regex="1D or 2D")
+
+def error_inputs_embedding(op_info, device, **kwargs):
+    indices = torch.rand(2, 2, device=device).long()
+    weights = [
+        torch.tensor(1.0, device=device),
+        torch.tensor(1.0, device=device).reshape(1, 1, 1),
+    ]
+
+    for weight in weights:
+        yield ErrorInput(SampleInput(weight, args=(indices,)), error_type=RuntimeError,
+                         error_regex="'weight' must be 2-D")
+
+
+def error_inputs_t(op_info, device, **kwargs):
+    yield ErrorInput(
+        SampleInput(torch.randn(2, 3, 4, 5, device=device)),
+        error_regex="expects a tensor with <= 2",
+    )
+
+
+def error_inputs_multinomial(op_info, device, **kwargs):
+    x = torch.empty(1, 2, 3, dtype=torch.double, device=device)
+    yield ErrorInput(SampleInput(x, args=(2,)),
+                     error_regex="prob_dist must be 1 or 2 dim")
+
+    x = torch.empty(1, 2, dtype=torch.long, device=device)
+    yield ErrorInput(SampleInput(x, args=(2,)),
+                     error_regex="multinomial only supports floating-point dtypes for input")
+
+    x = torch.empty(1, 2, dtype=torch.double, device=device)
+    y = torch.empty(1, 2, dtype=torch.double, device=device)
+    yield ErrorInput(SampleInput(x, args=(2,), kwargs=dict(out=y)),
+                     error_regex="multinomial expects Long tensor out")
+
+    x = torch.empty(2, dtype=torch.double, device=device)
+    yield ErrorInput(SampleInput(x, args=(0,)),
+                     error_regex="cannot sample n_sample <= 0 samples")
+
+    x = torch.empty(2, dtype=torch.double, device=device)
+    yield ErrorInput(SampleInput(x, args=(-1,)),
+                     error_regex="cannot sample n_sample <= 0 samples")
+
+    x = torch.empty(2, dtype=torch.double, device=device)
+    yield ErrorInput(SampleInput(x, args=(3, False,)),
+                     error_regex="cannot sample n_sample > prob_dist")
+
+    x = torch.empty(16777217, dtype=torch.double, device=device)
+    yield ErrorInput(SampleInput(x, args=(3,)),
+                     error_regex="number of categories cannot exceed")
+
+    inputs = ((1., -1., 1.), (1., inf, 1.), (1., -inf, 1.), (1., 1., nan))
+
+    err_msg1 = "probability tensor contains either `inf`, `nan` or element < 0"
+    err_msg2 = "invalid multinomial distribution"
+
+    rep_arg = (False, True) if torch.device(device).type == 'cpu' else (False,)
+
+    if torch.device(device).type == 'cpu':
+        for rep in rep_arg:
+            kwargs = {'num_samples': 2, 'replacement': rep}
+
+            for probs in inputs:
+                # error case when the probability tensor contains `inf`, `nan`, or a negative element
+                yield ErrorInput(SampleInput(torch.tensor(probs), kwargs=kwargs),
+                                 error_regex=err_msg1 if rep is False else err_msg2)
+
+            # error case for the invalid multinomial distribution (sum of probabilities <= 0), 1-D input
+            x = torch.zeros(3, device=device)
+            yield ErrorInput(SampleInput(x, kwargs=kwargs),
+                             error_regex=err_msg2)
+
+            # error case for the invalid multinomial distribution (sum of probabilities <= 0), 2-D input
+            x = torch.zeros(3, 3, device=device)
+            yield ErrorInput(SampleInput(x, kwargs=kwargs),
+                             error_regex=err_msg2)
+
+            # error case for the invalid multinomial distribution
+            x[1, :] = 1
+            yield ErrorInput(SampleInput(x, kwargs=kwargs),
+                             error_regex=err_msg2)
+
+def error_inputs_gradient(op_info, device, **kwargs):
+    for dtype in [torch.long, torch.float32, torch.complex64]:
+        t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], device=device, dtype=dtype)
+
+        dim = (1, 0)
+        spacing = [0.1]
+        yield ErrorInput(SampleInput(t, kwargs=dict(spacing=spacing, dim=dim, edge_order=1)),
+                         error_type=RuntimeError,
+                         error_regex='torch.gradient expected spacing to be unspecified, a scalar ')
+
+        yield ErrorInput(SampleInput(t, kwargs=dict(edge_order=3)),
+                         error_type=RuntimeError,
+                         error_regex='torch.gradient only supports edge_order=1 and edge_order=2.')
+
+        dim = (1, 1)
+        spacing = 0.1
+        yield ErrorInput(SampleInput(t, kwargs=dict(spacing=spacing, dim=dim, edge_order=1)),
+                         error_type=RuntimeError,
+                         error_regex='dim 1 appears multiple times in the list of dims')
+
+        dim = (0, 1)
+        coordinates = [torch.tensor([1, 2, 4], device='cpu'), torch.tensor([1, 2, 4], device='meta')]
+        yield ErrorInput(SampleInput(t, kwargs=dict(spacing=coordinates, dim=dim, edge_order=1)),
+                         error_type=RuntimeError,
+                         error_regex='torch.gradient expected each tensor to be on the same device,')
+
+        yield ErrorInput(SampleInput(t, kwargs=dict(dim=3)),
+                         error_type=IndexError, error_regex='')
+
+        t = torch.tensor([[1], [2], [3]])
+        yield ErrorInput(SampleInput(t, kwargs=dict(edge_order=1)),
+                         error_type=RuntimeError,
+                         error_regex='torch.gradient expected each dimension size to be at least')
+
+        t = torch.tensor([[1, 2], [3, 4]])
+        yield ErrorInput(SampleInput(t, kwargs=dict(edge_order=2)),
+                         error_type=RuntimeError,
+                         error_regex='torch.gradient expected each dimension size to be at least')
+
+def sample_inputs_rrelu(op_info, device, dtype, requires_grad, **kwargs):
+    yield from sample_inputs_elementwise_unary(
+        op_info, device, dtype, requires_grad, op_kwargs=dict(lower=0., upper=1., training=True))
+
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    yield SampleInput(make_arg(S))
+    yield SampleInput(make_arg(S), training=False)
+
+def error_inputs_rrelu(op_info, device, **kwargs):
+    input = make_tensor((S, S), device=device, dtype=torch.float32)
+    yield ErrorInput(SampleInput(input, kwargs={'lower': 0.3, 'upper': 0.1}),
+                     error_regex='Lower bound should be less than or equal to the upper bound')
+
+def error_inputs_masked_select(op_info, device, **kwargs):
+    x = torch.rand((1,), device=device).expand((3,))
+    y = torch.rand((6,), device=device)
+    mask = torch.tensor([True, False, True, True, False, False], device=device)
+
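+    # out= must not internally overlap or alias the input or mask tensors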
+    yield ErrorInput(SampleInput(y, args=(mask,), kwargs=dict(out=x)),
+                     error_type=RuntimeError,
+                     error_regex='unsupported operation')
+
+    yield ErrorInput(SampleInput(y, args=(mask,), kwargs=dict(out=y)),
+                     error_type=RuntimeError,
+                     error_regex='unsupported operation')
+
+    yield ErrorInput(SampleInput(mask.clone(), args=(mask,), kwargs=dict(out=mask)),
+                     error_type=RuntimeError,
+                     error_regex='unsupported operation')
+
+def error_inputs_median(op_info, device, **kwargs):
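+    # a tensor with more than 25 dimensions, which CUDA tensors cannot have (see the error below)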
+    x = torch.tensor([[[[[[[[[[[[[[[[[[[[[[[[[nan],
+                               [nan]]]]]]]]]]]]]]]]]]]]]]]]], device=device)
+    if torch.device(device).type == 'cuda':
+        yield ErrorInput(SampleInput(x, kwargs=dict(dim=(-1))),
+                         error_type=RuntimeError,
+                         error_regex='CUDA Tensors cannot have more than 25 dimensions')
+    else:
+        return
+
+
+def error_inputs_index_select(op_info, device, **kwargs):
+    x = torch.rand((1, 6), device=device).expand((2, 6))
+    y = torch.rand((3, 6), device=device)
+    ind = torch.tensor([0, 1], dtype=torch.int64, device=device)
+
+    yield ErrorInput(SampleInput(y, args=(1, ind,), kwargs=dict(out=x)),
+                     error_type=RuntimeError,
+                     error_regex='unsupported operation')
+
+def error_inputs_index_add(op_info, device, **kwargs):
+    result = torch.tensor([[1., 2.], [4., 5.], [7., 8.]])
+    source = torch.tensor([2., 4.])
+
+    yield ErrorInput(SampleInput(result, args=(0, torch.tensor([0, 2]), source)),
+                     error_type=RuntimeError,
+                     error_regex=r'source tensor shape must match self tensor shape, '
+                     r'excluding the specified dimension. Got self.shape = \[3, 2\] source.shape = \[2\]')
+
+def error_inputs_logcumsumexp(op_info, device, **kwargs):
+    dim = 3
+    srcs = [torch.randn(5, 2, device=device), torch.randn(0, 2, device=device)]
+    for src in srcs:
+        yield ErrorInput(SampleInput(src, args=(dim,)),
+                         error_type=IndexError,
+                         error_regex='Dimension out of range')
+
+def sample_inputs_take_along_dim(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None)
+    yield SampleInput(
+        make_arg((S, S)), gather_variable((S, S), 1, S, True, device=device), 0)
+
+    # `indices` broadcast
+    yield SampleInput(
+        make_arg((S, S)), gather_variable((1, S // 2), 0, S, True, device=device), 1)
+
+    # `self` broadcast
+    yield SampleInput(
+        make_arg((1, S)), gather_variable((S, S // 2), 0, S, True, device=device), 1)
+
+    # without `dim` arg
+    yield SampleInput(
+        make_arg((S, S)), gather_variable((S, S // 2), 0, S, True, device=device))
+
+
+def error_inputs_aminmax_amax_amin(op_info, device, is_ref=False, **kwargs):
+
+    # Error Inputs for tensors with a zero-sized dimension, when the 'dim' arg is not provided.
+    shape = (S, 0, S)
+    err_msg_amax_amin = "reduction"
+    err_msg_aminmax = "cannot compute aminmax over an empty dimension as the operation has no identity"
+    if op_info.name in ['amax', 'amin', '_refs.amax', '_refs.amin']:
+        yield ErrorInput(SampleInput(torch.rand(shape, device=device)), error_regex=err_msg_amax_amin)
+    elif op_info.name in ['aminmax']:
+        yield ErrorInput(SampleInput(torch.rand(shape, device=device)), error_regex=err_msg_aminmax)
+
+    # Error Inputs for tensors with more than 64 dimensions
+    sizes = [1] * 65
+    err_msg1 = "only tensors with up to 64 dims are supported"
+    yield ErrorInput(SampleInput(torch.randn(sizes, device=device), kwargs={'dim': -1}),
+                     error_regex=err_msg1)
+    yield ErrorInput(SampleInput(torch.randn(sizes, device=device), kwargs={'dim': 64}),
+                     error_regex=err_msg1)
+
+    # Error Inputs for repeated 'dim'
+    if op_info.name in ['amax', 'amin', '_refs.amax', '_refs.amin']:
+        dims = [(0, 0), (0, -4)]
+        err_msg2 = "in the list of dims"
+        x = torch.randn(S, S, S, S, device=device)
+        for dim in dims:
+            yield ErrorInput(SampleInput(x, kwargs={'dim': dim}), error_regex=err_msg2)
+
+    # Error Input for illegal dtype
+    input5 = torch.randn(L, L, dtype=torch.float32, device=device)
+    max_values = torch.empty(L, dtype=torch.float32, device=device)
+    min_values = torch.empty(L, dtype=torch.double, device=device)
+    illegal_values = torch.empty(L, dtype=torch.int, device=device)
+
+    # Unlike regular PyTorch, amax and amin refs don't require input and out
+    # dtypes to match exactly:
+    # https://github.com/pytorch/pytorch/pull/87765#pullrequestreview-1162023824
+    if is_ref:
+        err_msg_amax_amin2 = ("Attempting to cast from torch.float32 to out tensor with dtype "
+                              "torch.int32, but this can't be cast because it is not safe!")
+    else:
+        err_msg_amax_amin2 = ("Expected the dtype for input and out to match, but got Float "
+                              "for input's dtype and Int for out's dtype.")
+    err_msg_aminmax2 = "Expected out tensor to have dtype float, but got double instead"
+
+    if op_info.name in ['amax', 'amin', '_refs.amax', '_refs.amin']:
+        yield ErrorInput(SampleInput(input5, kwargs={'dim': 0, 'out': illegal_values}),
+                         error_regex=err_msg_amax_amin2)
+    elif op_info.name in ['aminmax']:
+        yield ErrorInput(SampleInput(input5, kwargs={'dim': 0, 'out': (max_values, min_values)}),
+                         error_regex=err_msg_aminmax2)
+
+    # Error Inputs where the specified reduction dim has size zero
+    err_msg3 = "reduction"
+    # FIXME: eager and ref impl throw different types of errors
+    error_type = IndexError if 'refs' not in op_info.name else RuntimeError
+    yield ErrorInput(SampleInput(torch.rand(shape, device=device), kwargs={'dim': 1}),
+                     error_type=error_type, error_regex=err_msg3)
+
+def sample_inputs_aminmax(op_info, device, dtype, requires_grad, **kwargs):
+    test_cases: tuple[tuple, dict] = (  # type: ignore[assignment]
+        ((S, S, S), {}),
+        ((S, S, S), {'dim': 1}),
+        ((S, S, S), {'dim': 1, 'keepdim': True}),
+        ((), {'dim': 0}),
+        ((), {}),
+        ((), {'dim': 0, 'keepdim': True}),
+        ((S, 0, S), {'dim': 0}),
+    )
+
+    for shape, kwargs in test_cases:
+        yield SampleInput(
+            make_tensor(shape, dtype=dtype, device=device, requires_grad=requires_grad),
+            **kwargs)
+
+def error_inputs_diff(op_info, device, **kwargs):
+    t = torch.rand((1, 3), device=device)
+    n = -1
+    yield ErrorInput(SampleInput(t, args=(n, ), kwargs=kwargs),
+                     error_type=RuntimeError,
+                     error_regex=f'order must be non-negative but got {n}')
+
+def sample_inputs_diff(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+
+    test_cases = (
+        ((1,), 0, None, None),
+        ((S,), 0, None, None),
+        ((S, 1), 0, None, None),
+        ((S, 1), 1, None, None),
+        ((S, S), 0, None, None),
+        ((S, S), 1, None, None),
+        ((S, S), 0, (1, S), (2, S)),
+        ((S, S), 0, None, (2, S)),
+        ((XS, XS, XS), 1, None, None),
+        ((XS, XS, XS), 2, None, None),
+        ((XS, XS, XS), 1, (XS, 1, XS), (XS, 1, XS)),
+        ((XS, XS, XS), 2, (XS, XS, 1), (XS, XS, 1)),
+        ((XS, XS, XS), 2, (XS, XS, XS), (XS, XS, XS)),)
+
+    for size, dim, size_prepend, size_append in test_cases:
+        prepend_size = 0 if (size_prepend is None) else size_prepend[dim]
+        append_size = 0 if (size_append is None) else size_append[dim]
+        dim_size = size[dim] + prepend_size + append_size
+        for n in range(dim_size):
+            input_tensor = make_arg(size)
+            prepend = make_arg(size_prepend) if size_prepend else None
+            append = make_arg(size_append) if size_append else None
+            yield SampleInput(input_tensor, n, dim, prepend, append)
+
+    # add some samples with n > dim_size
+    yield SampleInput(make_arg((XS, XS, XS)), S + 1, 1)
+    yield SampleInput(make_arg((XS, XS, XS)), S * 3 + 2, 2, make_arg((XS, XS, XS)), make_arg((XS, XS, XS)))
+
+def sample_inputs_histogram(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+
+    sizes = ((), (S,), (S, S), (S, S, S), (S, 1, S), (S, 0, S))
+
+    for size, bin_ct, weighted, density in product(sizes, range(1, 5), [False, True], [False, True]):
+        input_tensor = make_arg(size)
+        weight_tensor = make_arg(size) if weighted else None
+
+        yield SampleInput(input_tensor, bin_ct,
+                          weight=weight_tensor, density=density)
+
+        bins_tensor = make_arg((bin_ct + 1,))
+        sorted_bins, _bins_indices = torch.sort(bins_tensor)
+        yield SampleInput(input_tensor, sorted_bins,
+                          weight=weight_tensor, density=density)
+
+def sample_inputs_histogramdd(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+
+    sizes = ((S, S), (S, S, S), (S, 1, S), (S, 0, S))
+    bin_ct_patterns = ((1, 1, 1, 1, 1), (2, 3, 2, 3, 2), (3, 2, 3, 2, 3))
+
+    for size, bin_ct_pattern, weighted, density in product(sizes, bin_ct_patterns, [False, True], [False, True]):
+        input_tensor = make_arg(size)
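+        # keep one bin count per coordinate dimension (the innermost dim of the input)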
+        bin_ct = bin_ct_pattern[:size[-1]]
+        weight_tensor = make_arg(size[:-1]) if weighted else None
+
+        yield SampleInput(input_tensor, bin_ct,
+                          weight=weight_tensor, density=density)
+
+        bins_tensor = [make_arg(ct + 1) for ct in bin_ct]
+        yield SampleInput(input_tensor, bins_tensor,
+                          weight=weight_tensor, density=density)
+
+def error_inputs_histogramdd(opinfo, device, **kwargs):
+    invalid_bins = [1, 1, 1, 1, 1]
+    make_arg = partial(make_tensor, dtype=torch.float, device=device, requires_grad=False)
+    msg = "histogramdd: The size of bins must be equal to the innermost dimension of the input."
+    yield ErrorInput(SampleInput(make_arg(5, 6), invalid_bins), error_regex=msg)
+
+def sample_inputs_histc(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+
+    sizes = ((), (S,), (S, S), (S, S, S), (S, 1, S), (S, 0, S))
+
+    for size, min, max in product(sizes, [0, -10], [0, 10]):
+        # construct sample input omitting bins arg
+        yield SampleInput(make_arg(size), min=min, max=max)
+
+        # construct sample inputs with a few different bins values
+        for bins in [1, 3, 10]:
+            yield SampleInput(make_arg(size), bins=bins, min=min, max=max)
+
+def sample_inputs_bincount(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+
+    for size, weighted in product((S, M), [False, True]):
+        input_tensor = torch.randint(0, size, (size,), dtype=dtype, device=device)
+        weight_tensor = make_arg((size,)) if weighted else None
+
+        max_val = int(input_tensor.max().item())
+
+        for minlength in [0, max_val // 2, max_val, 2 * max_val]:
+            yield SampleInput(
+                input_tensor, weights=weight_tensor, minlength=minlength)
+
+def sample_inputs_bucketize(op_info, device, dtype, requires_grad, reference_inputs_mode=False, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+
+    sizes = (((), S), ((S,), S), ((S, S), S), ((S, S, S), S), ((S, 1, S), S), ((S, 0, S), S))
+
+    if reference_inputs_mode:
+        sizes += (((256,), 128), ((128,), 256), ((32, 32), 11), ((32, 4, 32), 33))
+
+    for (input_shape, nb), out_int32, right in product(sizes, [False, True], [False, True]):
+        input_tensor = make_arg(input_shape)
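+        # bucketize expects the boundaries to be sorted, hence msort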
+        boundaries = make_arg(nb).msort()
+
+        yield SampleInput(input_tensor, boundaries,
+                          out_int32=out_int32, right=right)
+
+reference_inputs_bucketize = partial(sample_inputs_bucketize, reference_inputs_mode=True)
+
+def error_inputs_bucketize(opinfo, device, **kwargs):
+    make_arg = partial(make_tensor, dtype=torch.float, device=device, requires_grad=False)
+    yield ErrorInput(SampleInput(make_arg((S, S, S)), make_arg((S, S))),
+                     error_regex="boundaries tensor must be 1 dimension")
+
+def sample_inputs_searchsorted(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+
+    # (unsorted tensor size, (input sizes,), is_scalar)
+    sizes = (
+        ((0,), ((0,),), False),
+        ((M,), ((), (M,), (M, M)), False),
+        ((0, 0), ((0, 0),), False),
+        ((M, M), ((M, M),), False),
+        ((0, 0, 0), ((0, 0, 0),), False),
+        ((M, M, M), ((M, M, M),), False),
+        ((L,), ((),), True),
+    )
+
+    for (size, input_sizes, is_scalar), noncontiguous, out_int32, right in product(
+        sizes, [False, True], [False, True], [False, True]
+    ):
+        unsorted_tensor = make_arg(size, noncontiguous=noncontiguous)
+        for input_size in input_sizes:
+            input = make_arg(input_size, noncontiguous=noncontiguous)
+            if is_scalar:
+                input = input.item()
+            if np.prod(size) == 0:
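+                # for empty tensors, skip the sort and supply a placeholder sorter of matching shape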
+                boundary_tensor = unsorted_tensor
+                sorter = make_tensor(size, dtype=torch.int64, device=device, noncontiguous=noncontiguous)
+            else:
+                boundary_tensor, sorter = torch.sort(unsorted_tensor)
+            side = "right" if right else "left"
+
+            yield SampleInput(boundary_tensor, input, out_int32=out_int32, right=right)
+            yield SampleInput(boundary_tensor, input, out_int32=out_int32, side=side)
+
+            yield SampleInput(unsorted_tensor, input, out_int32=out_int32, right=right, sorter=sorter)
+            yield SampleInput(unsorted_tensor, input, out_int32=out_int32, side=side, sorter=sorter)
+
+def sample_inputs_gradient(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None)
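+    # cases are ordered as (shape, spacing-or-coordinates, dim, edge_order)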
+    test_cases_float = (
+        ((S,), None, None, 1),
+        ((S,), 2., None, 1),
+        ((S, S), None, None, 2),
+        ((S, S), [2.0, 2.1], None, 1),
+        ((S, S), [2.0, 2.1], (0, 1), 1),
+        ((4, 4, 4), [2., 1.], (0, 1), 2),
+    )
+    for size, spacing, dim, edge_order in test_cases_float:
+        t = make_arg(size)
+        yield SampleInput(t, dim=dim, spacing=spacing, edge_order=edge_order)
+
+    test_cases_tensor = (
+        ((3, 3, 3), ((1.1, 2.0, 3.5), (4.0, 2, 6.0)), (0, -1), 1),
+        ((3, 3, 3), ((1.0, 3.0, 2.0), (8.0, 6.0, 1.0)), (0, 1), 2),
+    )
+    for size, coordinates, dim, edge_order in test_cases_tensor:
+        t = make_arg(size)
+        coordinates_tensor_list = []
+        for coords in coordinates:
+            # `coords` will always contain floating point values, and Python 3.10 does not support
+            # implicitly converting them to an integer via `__int__`
+            # TODO: this can be simplified after https://github.com/pytorch/pytorch/issues/69316 is fixed
+            a = torch.tensor(coords, device=device)
+            coordinates_tensor_list.append(a.to(dtype))
+        yield SampleInput(t, dim=dim, spacing=coordinates_tensor_list, edge_order=edge_order)
+
+def sample_inputs_getitem(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
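+    # a mix of basic indexing (slices) and advanced indexing (lists, tensors, boolean masks)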
+    test_args = [
+        ([1, 2],),
+        (slice(0, 3),),
+        ((slice(0, 3), 1),),
+        (([0, 2, 3], [1, 3, 3], [0, 0, 2]),),
+        (([0, 0, 3], [1, 1, 3], [0, 0, 2]),),
+        ((slice(None), slice(None), [0, 3]),),
+        ((slice(None), [0, 3], slice(None)),),
+        (([0, 3], slice(None), slice(None)),),
+        (([0, 3], [1, 2], slice(None)),),
+        (([0, 3], ),),
+        (([0, 3], slice(None)),),
+        (([0, 3], Ellipsis),),
+        (([0, 2, 3], [1, 3, 3], torch.LongTensor([0, 0, 2])),),
+        (index_variable(2, S, device=device),),
+        (mask_not_all_zeros((S,)),),
+    ]
+
+    for args in test_args:
+        yield SampleInput(make_arg((S, S, S)), args=args)
+
+    yield SampleInput(make_arg((S, S, S, S)), args=((slice(None), [0, 1], slice(None), [0, 1]),))
+
+def sample_inputs_index_put(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+
+    for accumulate in [False, True]:
+        # Test with indices arg
+        yield SampleInput(
+            make_arg((S, S,)),
+            # As defined in the docs, if accumulate is false, duplicate indices are not supported
+            (index_variable(2 if accumulate else 1, S, device=device),),
+            make_arg((2 if accumulate else 1, S)),
+            accumulate=accumulate)
+
+        # Test with mask arg
+        mask = torch.zeros(S, dtype=torch.bool) if accumulate else mask_not_all_zeros((S,))
+        yield SampleInput(
+            make_arg((S, S)), (mask, ), make_arg((S,)), accumulate=accumulate)
+
+def sample_inputs_sort(op_info, device, dtype, requires_grad, **kwargs):
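+    # The helpers below use randperm so every element is distinct and the sorted order is deterministic.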
+    def small_3d_unique():
+        res = torch.randperm(S * S * S, dtype=torch.int64, device=device).view(S, S, S)
+        res = res.to(dtype).requires_grad_(requires_grad)
+        return res
+
+    def large_1d_unique():
+        res = torch.randperm(L * L * L, dtype=torch.int64, device=device)
+        res = res.to(dtype).requires_grad_(requires_grad)
+        return res
+
+    # Test case for large tensor.
+    yield SampleInput(large_1d_unique())
+
+    # Test cases for small 3d tensors.
+    # Imitates legacy tests from test/test_torch.py
+    dims = range(-3, 3)
+    flag = [True, False]
+    for dim, descending, stable in product(dims, flag, flag):
+        # default schema without stable sort
+        if not (dtype == torch.bool and torch.device(device).type == 'cuda'):
+            # bool on CUDA requires a stable sort to give reproducible results, at least
+            # for the returned indices
+            yield SampleInput(small_3d_unique(), dim, descending)
+        # schema with stable sort, no CUDA support yet
+        if torch.device(device).type == 'cpu':
+            yield SampleInput(
+                small_3d_unique(), dim=dim, descending=descending, stable=stable)
+
+    # Test cases for scalar tensor
+    tensor_opt = dict(dtype=dtype, device=device, requires_grad=requires_grad)
+    yield SampleInput(torch.tensor(1, **tensor_opt))
+    yield SampleInput(torch.tensor(1, **tensor_opt), 0)
+    yield SampleInput(torch.tensor(1, **tensor_opt), 0, True)
+
+    # Test cases for empty tensor
+    yield SampleInput(torch.tensor((), **tensor_opt))
+    yield SampleInput(torch.tensor((), **tensor_opt), 0)
+    yield SampleInput(torch.tensor((), **tensor_opt), 0, True)
+
+    # Test cases for stable sort
+    yield SampleInput(small_3d_unique(), stable=True)
+    yield SampleInput(small_3d_unique(), dim=0, stable=True)
+    yield SampleInput(small_3d_unique(), dim=0, descending=True, stable=True)
+
+def sample_inputs_threshold(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+    sizes = ((), (S,), (S, S), (S, S, S))
+    for x_size in sizes:
+        # the threshold and value args must be Python numbers
+        yield SampleInput(make_arg(x_size), make_arg(()).item(), make_arg(()).item())
+
+def sample_inputs_unique(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    sizes = ((), (S,), (S, S), (S, S, S), (S, 1, S), (S, 0, S))
+
+    for shape, sorted, return_inverse, return_counts, dim in \
+            product(sizes, [False, True], [False, True], [False, True], [None, -2, -1, 0, 1, 2]):
+        # torch.unique cannot be called if the input tensor has a zero-sized dimension other than the selected dim
+        if 0 in shape and shape.index(0) != dim:
+            continue
+
+        # skip invalid dim args
+        if dim is not None and (dim < -len(shape) or dim >= len(shape)):
+            continue
+
+        kwargs = dict(sorted=sorted, return_inverse=return_inverse, return_counts=return_counts, dim=dim)
+
+        # construct a test case with only one distinct value
+        input_t = torch.zeros(shape, dtype=dtype, device=device, requires_grad=requires_grad)
+        yield SampleInput(input_t, **kwargs)
+
+        # construct a test case with mixed 0s and 1s
+        input_t = make_arg(shape, dtype=torch.bool, requires_grad=False)\
+            .to(dtype).requires_grad_(requires_grad)
+        yield SampleInput(input_t, **kwargs)
+
+        # construct a test case with many different values
+        yield SampleInput(make_arg(shape), **kwargs)
+
+def sample_inputs_unique_consecutive(*args, **kwargs):
+    for sample_input in sample_inputs_unique(*args, **kwargs):
+        if not sample_input.kwargs["sorted"]:
+            sample_input.kwargs.pop("sorted")
+            yield sample_input
+
+def sample_inputs_adaptive_avg_pool1d(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # Ordered as (input shape, output size)
+    cases = (
+        ((0, 8, 8), (5,)),
+        ((3, 8, 8), 5),
+        ((3, 8, 8), 1)
+    )
+
+    for input_shape, output_size in cases:
+        # Batched
+        yield SampleInput(make_arg(input_shape), args=(output_size,))
+        # Unbatched
+        yield SampleInput(make_arg(input_shape[1:]), args=(output_size,))
+
+
+def error_inputs_adaptive_avg_pool1d(opinfo, device, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
+
+    # error inputs for empty output
+    yield ErrorInput(SampleInput(make_arg((1, 2, 3)), output_size=()),
+                     error_regex="'output_size' should contain one int")
+
+    # error inputs for output_size less than 0
+    yield ErrorInput(SampleInput(make_arg((1, 1, 1)), output_size=(-1,)),
+                     error_regex="elements of output_size must be greater than or equal to 0")
+
+
+def sample_inputs_adaptive_avg_pool2d(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # Ordered as (input shape, output size)
+    cases = (
+        ((1, 8, 8, 8), (5, 7)),
+        ((2, 8, 8, 8), (None, 7)),
+        ((1, 8, 4, 3), (5, None)),
+        ((1, 8, 4, 3), (None, None)),
+        ((1, 8, 4, 3), (5)),
+    )
+
+    for input_shape, output_size in cases:
+        # Batched
+        yield SampleInput(make_arg(input_shape), args=(output_size,))
+        # Unbatched
+        yield SampleInput(make_arg(input_shape[1:]), args=(output_size,))
+
+
+def error_inputs_adaptive_avg_pool2d(opinfo, device, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
+
+    # error inputs for incorrect input dimension
+    yield ErrorInput(SampleInput(make_arg((2, 2)), output_size=(2, 2)),
+                     error_type=ValueError, error_regex="Input dimension should be at least 3")
+
+    # error inputs for empty output
+    yield ErrorInput(SampleInput(make_arg((1, 2, 3, 4)), output_size=()),
+                     error_regex="output_size must be 2")
+
+    # error inputs for output_size less than 0
+    yield ErrorInput(SampleInput(make_arg((1, 1, 1, 1)), output_size=(-1, 0)),
+                     error_regex="elements of output_size must be greater than or equal to 0")
+
+
+def sample_inputs_adaptive_avg_pool3d(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # Ordered as (input shape, output size)
+    cases = (
+        ((0, 8, 8, 8, 8), (5, 7, 4)),
+        ((1, 8, 4, 3, 7), (None, None, None)),
+        ((1, 8, 4, 3, 7), (1, 1, 1)),
+        ((3, 3, 8, 8, 6), (5, 7, None)),
+        ((1, 3, 8, 8, 6), (5, None, 2)),
+        ((3, 3, 8, 8, 6), (None, 3, 2)),
+    )
+
+    for input_shape, output_size in cases:
+        # Batched
+        yield SampleInput(make_arg(input_shape), args=(output_size,))
+        # Unbatched
+        yield SampleInput(make_arg(input_shape[1:]), args=(output_size,))
+
+
+def error_inputs_adaptive_avg_pool3d(opinfo, device, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
+
+    # error inputs for incorrect input dimension
+    yield ErrorInput(SampleInput(make_arg((2, 2, 2)), output_size=(2, 2, 2)),
+                     error_type=ValueError, error_regex="Input dimension should be at least 4")
+
+    # error inputs for empty output
+    yield ErrorInput(SampleInput(make_arg((1, 2, 3, 4)), output_size=()),
+                     error_regex="output_size must be 3")
+
+    # error inputs for output_size less than 0
+    yield ErrorInput(SampleInput(make_arg((1, 1, 1, 1, 1)), output_size=(-1, 0, 2)),
+                     error_regex="elements of output_size must be greater than or equal to 0")
+
+
+def sample_inputs_adaptive_max_pool1d(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # Ordered as (input shape, output size)
+    cases = (
+        # ((0, 8, 8), (5,)),
+        # 0 batch size doesn't work: cannot reshape a tensor of 0 elements into shape [0, 8, -1]
+        ((3, 4, 4), 3),
+        ((3, 4, 4), 1)
+    )
+
+    for (input_shape, output_size), return_idx in product(cases, (True, False)):
+        # Batched
+        yield SampleInput(make_arg(input_shape), args=(output_size, return_idx))
+        # Unbatched
+        yield SampleInput(make_arg(input_shape[1:]), args=(output_size, return_idx))
+
+
+def error_inputs_adaptive_max_pool1d(opinfo, device, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
+
+    # error inputs for empty output
+    yield ErrorInput(SampleInput(make_arg((1, 2, 3)), output_size=()),
+                     error_regex="'output_size' should contain one int")
+
+    # error inputs for output_size less than 0
+    yield ErrorInput(SampleInput(make_arg((1, 1, 1)), output_size=(-1,)),
+                     error_regex="Trying to create tensor with negative dimension")
+
+def sample_inputs_adaptive_max_pool2d(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # Ordered as (input shape, output size)
+    cases = (
+        # ((0, 8, 8, 8), (5, 7)),
+        # 0 batch size doesn't work: cannot reshape a tensor of 0 elements into shape [0, 8, -1]
+        ((1, 4, 4, 4), (2, 3)),
+        ((2, 4, 4, 4), (None, 3)),
+        ((2, 4, 4, 4), (1, 1)),
+        ((1, 4, 4, 3), (3, None)),
+        ((1, 4, 4, 3), (None, None)),
+        ((1, 4, 4, 3), (3)),
+    )
+
+    for (input_shape, output_size), return_idx in product(cases, (True, False)):
+        # Batched
+        yield SampleInput(make_arg(input_shape), args=(output_size, return_idx))
+        # Unbatched
+        yield SampleInput(make_arg(input_shape[1:]), args=(output_size, return_idx))
+
+def error_inputs_adaptive_max_pool2d(opinfo, device, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
+
+    # error inputs for incorrect input dimension
+    yield ErrorInput(SampleInput(make_arg((2, 2)), output_size=(2, 2)),
+                     error_type=ValueError, error_regex="Input dimension should be at least 3")
+
+    # error inputs for empty output
+    yield ErrorInput(SampleInput(make_arg((1, 2, 3, 4)), output_size=()),
+                     error_regex="internal error")
+
+    # error inputs for output_size less than 0
+    yield ErrorInput(SampleInput(make_arg((1, 1, 1, 1)), output_size=(-1, 0)),
+                     error_regex="Trying to create tensor with negative dimension")
+
+
+def sample_inputs_adaptive_max_pool3d(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # Ordered as (input shape, output size)
+    cases = (
+        # ((0, 8, 8, 8, 8), (5, 7, 4)),
+        # 0 batch size doesn't work: cannot reshape a tensor of 0 elements into shape [0, 8, -1]
+        ((1, 4, 4, 3, 5), (None, None, None)),
+        ((1, 4, 4, 3, 5), (1, 1, 1)),
+        ((3, 3, 4, 4, 6), (2, 3, None)),
+        ((1, 3, 4, 4, 6), (3, None, 2)),
+        ((3, 3, 4, 4, 6), (None, 3, 2)),
+    )
+
+    for (input_shape, output_size), return_idx in product(cases, (True, False)):
+        # Batched
+        yield SampleInput(make_arg(input_shape), args=(output_size, return_idx))
+        # Unbatched
+        yield SampleInput(make_arg(input_shape[1:]), args=(output_size, return_idx))
+
+def error_inputs_adaptive_max_pool3d(opinfo, device, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
+
+    # error inputs for incorrect input dimension
+    yield ErrorInput(SampleInput(make_arg((2, 2, 2)), output_size=(2, 2, 2)),
+                     error_type=ValueError, error_regex="Input dimension should be at least 4")
+
+    # error inputs for empty output
+    yield ErrorInput(SampleInput(make_arg((1, 2, 3, 4)), output_size=()),
+                     error_regex="internal error")
+
+    # error inputs for output_size less than 0
+    yield ErrorInput(SampleInput(make_arg((1, 1, 1, 1, 1)), output_size=(-1, 0, 2)),
+                     error_regex="Trying to create tensor with negative dimension")
+
+
+class _TestParamsMaxPoolBase:
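+    # Enumerates (shape, memory_format) pairs and max_pool kwargs combinations for sample_inputs_max_pool below.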
+
+    def __init__(self) -> None:
+        self.kwargs = {
+            'kernel_size': [3],
+            'stride': [2, None],
+            'ceil_mode': [True, False],
+            'padding': [0, 1],
+            'dilation': [1],
+            'return_indices': [True, False]
+        }
+
+        self.shapes = [
+            [1, 2, None],  # batch
+            [2],  # channels
+            [3, 6]  # signal
+        ]
+
+    def _gen_shape(self):
+        for shape in product(*self.shapes):
+            # shape[0] being None indicates a missing batch dimension
+            if shape[0] is None:
+                shape = shape[1:]
+
+            yield shape, torch.contiguous_format
+            # only 2d pooling inputs, i.e. rank-4 (N, C, H, W) tensors, support the channels_last memory format
+            if len(self.shapes) == 4 and len(shape) == 4:
+                yield shape, torch.channels_last
+
+    def _gen_kwargs(self):
+        keys = self.kwargs.keys()
+        for values in product(*self.kwargs.values()):
+            yield dict(zip(keys, values))
+
+    def gen_input_params(self):
+        yield from product(self._gen_shape(), self._gen_kwargs())
+
+class _TestParamsMaxPool1d(_TestParamsMaxPoolBase):
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.kwargs['kernel_size'] += [(3,)]
+        self.kwargs['stride'] += [(2,)]
+        self.kwargs['padding'] += [(1,)]
+        self.kwargs['dilation'] += [(1,)]
+
+class _TestParamsMaxPool2d(_TestParamsMaxPoolBase):
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.kwargs['kernel_size'] += [(3, 2)]
+        self.kwargs['stride'] += [(2, 1)]
+        self.kwargs['padding'] += [(1, 1)]
+        self.kwargs['dilation'] += [(1, 2)]
+
+        self.shapes.append([6])
+
+class _TestParamsMaxPool3d(_TestParamsMaxPoolBase):
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.kwargs['kernel_size'] += [(3, 2, 3)]
+        self.kwargs['stride'] += [(2, 1, 2)]
+        self.kwargs['dilation'] += [(1, 2, 1)]
+
+        self.shapes.append([6])
+        self.shapes.append([5])
+
+def sample_inputs_max_pool(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
+
+    params_generator_type_dict = {
+        'nn.functional.max_pool1d': _TestParamsMaxPool1d,
+        'nn.functional.max_pool2d': _TestParamsMaxPool2d,
+        'nn.functional.max_pool3d': _TestParamsMaxPool3d,
+        'max_pool2d_with_indices_backward': _TestParamsMaxPool2d,
+    }
+
+    params_generator = params_generator_type_dict[op_info.name]()
+    for (shape, memory_format), kwargs in params_generator.gen_input_params():
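+        # build without grad, apply the memory format, then set requires_grad so the sample stays a leaf tensor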
+        arg = make_arg(shape).to(memory_format=memory_format).requires_grad_(requires_grad)
+        yield SampleInput(arg, kwargs=kwargs)
+
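+# Reference path for max_pool2d_with_indices_backward: run the forward with return_indices=True,
+# then call the backward ATen op with a gradient of ones.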
+def max_pool2d_backward(*args, kernel_size=(), stride=(), padding=(0,), dilation=(1,), ceil_mode=False, **kwargs):
+    out, indices = torch.nn.functional.max_pool2d_with_indices(
+        *args, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, ceil_mode=ceil_mode, return_indices=True)
+    grad_out = torch.ones_like(out)
+    if stride is None:
+        stride = kernel_size
+    out_b = torch.ops.aten.max_pool2d_with_indices_backward.default(
+        grad_out, *args, kernel_size, stride, padding, dilation, ceil_mode, indices)
+    return out_b
+
+def error_inputs_max_pool1d(op_info, device, **kwargs):
+    # Toggle requires_grad because `max_pool1d` takes a different code path
+    # depending on whether `requires_grad` is set.
+    for requires_grad in (True, False):
+        make_arg = partial(make_tensor, device=device, dtype=torch.float, requires_grad=requires_grad)
+        # error inputs when pad is negative
+        x = make_arg((0, 1, 49))
+        yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': -1, 'return_indices': True}),
+                         error_regex='pad must be non-negative')
+
+        # error inputs when pad > kernel_size / 2
+        yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 4, 'return_indices': True}),
+                         error_regex='pad should be at most half of effective kernel size')
+
+        # error inputs when pad > ((kernel_size - 1) * dilation + 1) / 2, when dilation is not default
+        yield ErrorInput(SampleInput(x,
+                         kwargs={'kernel_size': 3, 'dilation': 2, 'stride': 1, 'padding': 3, 'return_indices': True}),
+                         error_regex='pad should be at most half of effective kernel size')
+
+        # error inputs for input tensor
+        error_msg = r'Expected 2D or 3D \(batch mode\) tensor with optional 0 dim batch size for input'
+        yield ErrorInput(SampleInput(make_arg((), requires_grad=requires_grad), kwargs={'kernel_size': 1}),
+                         error_regex=error_msg)
+
+        # error inputs for empty input
+        yield ErrorInput(SampleInput(torch.tensor([], device=device, requires_grad=requires_grad),
+                                     kwargs={'kernel_size': 1}),
+                         error_regex=error_msg)
+
+        # error: unbatched input with 0 sized non-batch dims.
+        yield ErrorInput(SampleInput(make_arg((0, 10), requires_grad=requires_grad),
+                                     kwargs={'kernel_size': 1}),
+                         error_regex=error_msg)
+
+        # error: batched input with 0 sized non-batch dims.
+        yield ErrorInput(SampleInput(make_arg((1, 10, 0), requires_grad=requires_grad),
+                                     kwargs={'kernel_size': 1}),
+                         error_regex=error_msg)
+
+        # error inputs when stride=0
+        error_msg = 'stride must be greater than zero, but got 0'
+        yield ErrorInput(SampleInput(make_arg((3, 3, 3)), kwargs={'kernel_size': 1, 'stride': 0}),
+                         error_regex=error_msg)
+
+        # error inputs when dilation=0
+        error_msg = 'dilation must be greater than zero, but got 0'
+        yield ErrorInput(SampleInput(make_arg((3, 3, 3)),
+                                     kwargs={'kernel_size': 1, 'stride': 1, 'padding': 0, 'dilation': 0}),
+                         error_regex=error_msg)
+
+        # error inputs for invalid output size
+        error_msg = 'Invalid computed output size: -2'
+        yield ErrorInput(SampleInput(make_arg((2, 2, 2)),
+                                     kwargs={'kernel_size': 5, 'stride': 1, 'padding': 0, 'dilation': 1}),
+                         error_regex=error_msg)
+
+        # error inputs when kernel_size=0
+        error_msg = 'kernel_size must be greater than zero'
+        yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 0}),
+                         error_regex=error_msg)
+
+        # error inputs when stride=0 (stride must be > 0)
+        error_msg = 'stride must be greater than zero'
+        yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 0}),
+                         error_regex=error_msg)
+
+
+def error_inputs_max_pool2d(op_info, device, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=torch.float, requires_grad=False)
+    # error inputs when pad is negative
+    x = make_arg((0, 1, 49))
+    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': -1, 'return_indices': True}),
+                     error_regex='pad must be non-negative')
+    # 2-dimensional kernel
+    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2), 'stride': 50, 'padding': -1, 'return_indices': True}),
+                     error_regex='pad must be non-negative')
+
+    # error inputs when pad > kernel_size / 2 (kernel_size : int)
+    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 4, 'return_indices': True}),
+                     error_regex='pad should be at most half of effective kernel size')
+
+    # error inputs when pad > kernel_size / 2 (kernel_size : tuple)
+    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2), 'stride': 50, 'padding': 4, 'return_indices': True}),
+                     error_regex='pad should be at most half of effective kernel size')
+
+    # error: unbatched input with 0 sized non-batch dims.
+    err_msg = r'Expected 3D or 4D \(batch mode\) tensor with optional 0 dim batch size for input'
+    yield ErrorInput(SampleInput(make_arg((1, 0, 10)),
+                                 kwargs={'kernel_size': 1}),
+                     error_regex=err_msg)
+
+    # error: batched input with 0 sized non-batch dims.
+    yield ErrorInput(SampleInput(make_arg((2, 1, 10, 0)),
+                                 kwargs={'kernel_size': 1}),
+                     error_regex=err_msg)
+
+
+def error_inputs_max_pool3d(op_info, device, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=torch.float, requires_grad=False)
+    # error inputs when pad is negative
+    x = make_arg((0, 1, 49, 50))
+    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': -1, 'return_indices': True}),
+                     error_regex='pad must be non-negative')
+    # 3-dimensional kernel
+    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2, 2), 'stride': 50,
+                                            'padding': -1, 'return_indices': True}),
+                     error_regex='pad must be non-negative')
+
+    # error inputs when pad > kernel_size / 2 (kernel_size: int)
+    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 4, 'return_indices': True}),
+                     error_regex='pad should be at most half of effective kernel size')
+
+    # error inputs when pad > kernel_size / 2 (kernel_size: tuple)
+    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2, 2), 'stride': 50,
+                                            'padding': 4, 'return_indices': True}),
+                     error_regex='pad should be at most half of effective kernel size')
+
+    # error: unbatched input with 0 sized non-batch dims.
+    err_msg = r'Expected input\'s non-batch dimensions to have positive length'
+    yield ErrorInput(SampleInput(make_arg((0, 1, 2, 10)),
+                                 kwargs={'kernel_size': 1}),
+                     error_regex=err_msg)
+
+    # error: batched inputs with 0 sized non-batch dims.
+    yield ErrorInput(SampleInput(make_arg((2, 1, 0, 1, 2)),
+                                 kwargs={'kernel_size': 1}),
+                     error_regex=err_msg)
+
+
+def sample_inputs_normalize(self, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, low=-1, high=1, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    cases: tuple[tuple[int], dict] = (  # type: ignore[assignment]
+                                     ((2, 1, 4, 5), {'p': 1., 'dim': 2}),
+                                     ((2, 3, 4, 5), {'p': 2., 'dim': 1}),
+                                     ((1, 2, 4, 5), {'p': 0.5, 'dim': 0}),
+                                     ((1, 3, 4, 5), {'p': -1., 'dim': 1}),
+                                     ((1, 3, 4, 5), {'p': 0., 'dim': -1}),
+                                     ((), {'p': 1.2, 'dim': 0}),
+                                     ((2, 3, 4, 5), {}),
+                                     ((2, 3, 4, 5), {'eps': 1e-4}))
+
+    for input_shape, kwargs in cases:
+        yield SampleInput(make_arg(input_shape), kwargs=kwargs)
+
+
+def complex_conv(fn, input_size, weight, grad_output, stride, padding, dilation, groups):
+    # conv(W, x, b) = conv(Wr, xr, br) - conv(Wi, xi, 0) + i(conv(Wi, xr, bi) + conv(Wr, xi, 0))
+    # a = conv(Wr, xr, br),
+    # b = conv(Wi, xi, 0),
+    # c = conv(Wr + Wi, xr + xi, br + bi)
+    # conv(W, x, b) = a - b + i(c - a - b)
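+    # (the classic three-multiplication trick: only the three real convolutions a, b, c are needed)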
+
+    grad_output_ = torch.view_as_real(grad_output)
+    grad_output_r = grad_output_[..., 0]
+    grad_output_i = grad_output_[..., 1]
+
+    weight_ = torch.view_as_real(weight)
+    weight_r = weight_[..., 0]
+    weight_i = weight_[..., 1]
+
+    a = fn(input_size, weight_r, grad_output_r, stride, padding, dilation, groups)
+    b = fn(input_size, weight_i, grad_output_i, stride, padding, dilation, groups)
+    c = fn(input_size, weight_r + weight_i, grad_output_r + grad_output_i, stride, padding, dilation, groups)
+
+    return (a - b) + 1j * (c - a - b)
+
+
+def conv_transpose_ref(input, weight, bias, stride=1, padding=0,
+                       output_padding=0, dilation=1, groups=1,
+                       fn=None):
+    # Derivative of `conv` is `conv_transpose`.
+    # To verify the correctness of `conv_transpose`,
+    # we rely on the `torch.nn.grad` implementation (which is tested in test_nn.py)
+    # for floating dtypes.
+
+    assert fn is not None
+
+    grad_fn_map = {torch.nn.functional.conv_transpose1d: torch.nn.grad.conv1d_input,
+                   torch.nn.functional.conv_transpose2d: torch.nn.grad.conv2d_input,
+                   torch.nn.functional.conv_transpose3d: torch.nn.grad.conv3d_input}
+    batched_dim_map = {torch.nn.functional.conv_transpose1d: 3,
+                       torch.nn.functional.conv_transpose2d: 4,
+                       torch.nn.functional.conv_transpose3d: 5}
+
+    # Input for `ref` is ndarray.
+    input, weight = torch.from_numpy(input), torch.from_numpy(weight)
+
+    is_batched = len(input.shape) == batched_dim_map[fn]
+    if not is_batched:
+        input = input.unsqueeze(0)
+
+    if bias is not None:
+        bias = torch.from_numpy(bias)
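+        # reshape bias to (C, 1, ..., 1) so it broadcasts over the spatial dimensions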
+        unsqueeze_dims = input.ndim - 2
+        for _ in range(unsqueeze_dims):
+            bias = bias.unsqueeze(1)
+
+    grad_output = input
+    # Get the input shape for grad_fn.
+    conv_transpose_output = fn(grad_output.to('meta'), weight.to('meta'), None,
+                               stride=stride, padding=padding, output_padding=output_padding,
+                               groups=groups, dilation=dilation)
+    input_size = conv_transpose_output.shape
+
+    grad_fn = grad_fn_map[fn]
+    if weight.dtype.is_complex:
+        out = complex_conv(grad_fn, input_size, weight, grad_output, stride, padding, dilation, groups)
+    else:  # Floating
+        out = grad_fn(input_size, weight, grad_output, stride, padding, dilation, groups)
+
+    if bias is not None:
+        out = out + bias
+
+    return out.squeeze(0) if not is_batched else out
+
+
+def sample_inputs_conv_transpose1d(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # Ordered as shapes for input, weight, bias
+    # and a dict of values of (stride, padding, output_padding, groups, dilation)
+    cases: tuple[tuple[int], tuple[int], tuple[int], dict] = (  # type: ignore[assignment]
+        ((1, 3, 4), (3, 3, 3), (3,),
+         {'stride': (2,), 'padding': 2, 'output_padding': (1,), 'groups': 1}),
+        ((2, 2, 4), (2, 2, 4), (4,),
+         {'stride': (3,), 'padding': (1,), 'output_padding': (2,), 'groups': 2, 'dilation': (4,)}),
+        ((1, 1, 4), (1, 1, 4), (1,),
+         {'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1, 'dilation': (2,)}),
+        ((1, 1, 4), (1, 2, 3), None,
+         {'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1}),
+        ((1, 4, 5), (4, 8, 3), None,
+         {})
+    )
+
+    for input_shape, weight, bias, kwargs in cases:
+        # Batched
+        yield SampleInput(make_arg(input_shape), args=(
+            make_arg(weight),
+            make_arg(bias) if bias is not None else bias
+        ), kwargs=kwargs)
+        # Unbatched
+        yield SampleInput(make_arg(input_shape[1:]), args=(
+            make_arg(weight),
+            make_arg(bias) if bias is not None else bias
+        ), kwargs=kwargs)
+
+
+def sample_inputs_conv_transpose2d(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # Ordered as shapes for input, weight, bias
+    # and a dict of values of (stride, padding, output_padding, groups, dilation)
+    cases: tuple[tuple[int], tuple[int], tuple[int], dict] = (  # type: ignore[assignment]
+        ((1, 3, 4, 4), (3, 3, 3, 3), (3,),
+         {'stride': (2, 2), 'padding': 2, 'output_padding': (1, 1), 'groups': 1}),
+        ((2, 2, 4, 4), (2, 2, 4, 5), (4,),
+         {'stride': (3, 2), 'padding': (1, 2), 'output_padding': (2, 3), 'groups': 2, 'dilation': (4, 4)}),
+        ((1, 1, 4, 5), (1, 1, 4, 3), (1,),
+         {'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1, 'dilation': (2, 3)}),
+        ((1, 1, 4, 3), (1, 2, 3, 4), None,
+         {'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1}),
+        ((2, 4, 4, 4), (4, 1, 3, 3), None, {'groups': 4}),
+        ((1, 2, 5, 5), (2, 4, 3, 3), None, {})
+    )
+
+    for input_shape, weight, bias, kwargs in cases:
+        # Batched
+        yield SampleInput(make_arg(input_shape), args=(
+            make_arg(weight),
+            make_arg(bias) if bias is not None else bias
+        ), kwargs=kwargs)
+        # Unbatched
+        yield SampleInput(make_arg(input_shape[1:]), args=(
+            make_arg(weight),
+            make_arg(bias) if bias is not None else bias
+        ), kwargs=kwargs)
+
+def sample_inputs_conv_transpose3d(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # Ordered as shapes for input, weight, bias
+    # and a dict of values of (stride, padding, output_padding, groups, dilation)
+    cases: tuple[tuple[int], tuple[int], tuple[int], dict] = (  # type: ignore[assignment]
+        ((1, 3, 4, 4, 4), (3, 3, 3, 3, 3), (3,),
+         {'stride': (2, 2, 2), 'padding': 2, 'output_padding': (1, 1, 1), 'groups': 1}),
+        ((2, 2, 4, 4, 4), (2, 2, 4, 5, 6), (4,),
+         {'stride': (3, 2, 1), 'padding': (1, 2, 3), 'output_padding': (2, 3, 1), 'groups': 2, 'dilation': (4, 4, 4)}),
+        ((1, 1, 4, 5, 2), (1, 1, 4, 3, 1), (1,),
+         {'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1, 'dilation': (2, 3, 2)}),
+        ((1, 1, 4, 3, 4), (1, 2, 3, 4, 5), None,
+         {'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1}),
+        ((1, 4, 5, 5, 5), (4, 8, 3, 3, 3), None,
+         {})
+    )
+
+    for input_shape, weight, bias, kwargs in cases:
+        # Batched
+        yield SampleInput(make_arg(input_shape), args=(
+            make_arg(weight),
+            make_arg(bias) if bias is not None else bias
+        ), kwargs=kwargs)
+        # Unbatched
+        yield SampleInput(make_arg(input_shape[1:]), args=(
+            make_arg(weight),
+            make_arg(bias) if bias is not None else bias
+        ), kwargs=kwargs)
+
+
+def sample_inputs_conv1d(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # Ordered as shapes for input, weight, bias,
+    # and a dict of values of (stride, padding, dilation, groups)
+    cases: tuple = (
+        ((1, 3, 4), (3, 3, 3), (3,), {'stride': (2,), 'padding': 2, 'groups': 1}),
+        ((2, 4, 8), (2, 2, 3), (2,), {'stride': 3, 'padding': 1, 'groups': 2, 'dilation': 2}),
+        ((1, 4, 5), (1, 4, 3), None, {'stride': (2,), 'padding': 'valid'}),
+        ((2, 2, 4), (2, 1, 4), (2,), {'stride': (1,), 'padding': 'same', 'groups': 2, 'dilation': (2,)}),
+        # With defaults
+        ((1, 4, 5), (3, 4, 3), None, {}),
+    )
+
+    for input_shape, weight, bias, kwargs in cases:
+        # Batched
+        yield SampleInput(make_arg(input_shape), args=(
+            make_arg(weight),
+            make_arg(bias) if bias is not None else bias
+        ), kwargs=kwargs)
+        # Unbatched
+        yield SampleInput(make_arg(input_shape[1:]), args=(
+            make_arg(weight),
+            make_arg(bias) if bias is not None else bias
+        ), kwargs=kwargs)
+
+
+def error_inputs_conv1d(opinfo, device, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=torch.float64)
+    make_int_arg = partial(make_tensor, device=device, dtype=torch.int64)
+    make_complex_arg = partial(make_tensor, device=device, dtype=torch.complex128)
+
+    # error inputs for different dtypes of input tensor and bias
+    yield ErrorInput(
+        SampleInput(make_int_arg((1, 1, 4)), args=(make_int_arg((1, 1, 2)), make_arg((1,)))),
+        error_regex="should be the same")
+
+    # error inputs for different dtypes of input tensor and bias
+    yield ErrorInput(
+        SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 1, 2)), make_complex_arg((1,)))),
+        error_regex="should be the same")
+
+    # error inputs for negative strides
+    yield ErrorInput(
+        SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 2, 2)), make_arg((1,))),
+                    kwargs={'stride': (-1,)}), error_regex="non-positive stride is not supported")
+
+    # error inputs for negative padding
+    yield ErrorInput(
+        SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 2, 2)), make_arg((1,))),
+                    kwargs={'padding': (-1,)}), error_regex="negative padding is not supported")
+
+    # error inputs for negative dilation
+    yield ErrorInput(
+        SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 1, 2)), make_arg((1,))),
+                    kwargs={'dilation': (-1,)}), error_regex="dilation should be greater than zero")
+
+    # FIXME: https://github.com/pytorch/pytorch/issues/85656
+    # error inputs for bias shape not equal to the output channels
+    # yield ErrorInput(SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 1, 3)), make_arg((2,)))),
+    #                  error_regex="expected bias to be 1-dimensional with 1 elements")
+
+    # error inputs for a weight tensor with fewer than three dimensions
+    yield ErrorInput(SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 2)), make_arg((1,)))),
+                     error_regex="weight should have at least three dimensions")
+
+    # error inputs where weight.shape[0] is less than the number of groups
+    yield ErrorInput(
+        SampleInput(make_arg((2, 2, 4)), args=(make_arg((2, 2, 2)), make_arg((2,))),
+                    kwargs={'padding': 'same', 'groups': 3}), error_regex="expected weight to be at least 3 at dimension 0")
+
+    # error inputs where weight.shape[0] is less than the number of groups
+    yield ErrorInput(
+        SampleInput(make_arg((2, 2, 4)), args=(make_arg((2, 2, 2)), make_arg((2,))),
+                    kwargs={'groups': 3}), error_regex="expected weight to be at least 3 at dimension 0")
+
+    # error inputs for invalid groups
+    yield ErrorInput(
+        SampleInput(make_arg((2, 2, 4)), args=(make_arg((2, 2, 2)), make_arg((2,))),
+                    kwargs={'padding': 'same', 'groups': -1}), error_regex="non-positive groups is not supported")
+
+    # error inputs for invalid groups
+    yield ErrorInput(
+        SampleInput(make_arg((2, 2, 4)), args=(make_arg((2, 2, 2)), make_arg((2,))),
+                    kwargs={'padding': 'same', 'groups': 0}), error_regex="non-positive groups is not supported")
+
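+# Illustrative sketch (not part of the test suite): error_inputs_* entries are checked by
+# asserting that the op raises and that the message matches error_regex, roughly like the
+# hypothetical helper below (the .sample_input/.error_regex attribute names are assumptions
+# based on how ErrorInput is used in this file):
+def _error_input_check_sketch(op, error_input):
+    import re
+    si = error_input.sample_input
+    try:
+        op(si.input, *si.args, **si.kwargs)
+    except Exception as e:
+        # The raised message should match the expected regex.
+        return re.search(error_input.error_regex, str(e)) is not None
+    return False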
+
+def error_inputs_conv2d(opinfo, device, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=torch.float64)
+    make_int_arg = partial(make_tensor, device=device, dtype=torch.int64)
+    make_complex_arg = partial(make_tensor, device=device, dtype=torch.complex128)
+
+    # error inputs for different dtypes of input tensor and bias
+    yield ErrorInput(
+        SampleInput(make_int_arg((2, 4, 4)), args=(make_int_arg((3, 2, 3, 3)), make_arg((3,)))),
+        error_regex="should be the same")
+
+    # error inputs for different dtypes of input tensor and bias
+    yield ErrorInput(
+        SampleInput(make_arg((2, 4, 4)), args=(make_arg((3, 2, 3, 3)), make_complex_arg((3,)))),
+        error_regex="should be the same")
+
+    # error inputs for negative strides
+    yield ErrorInput(
+        SampleInput(make_arg((1, 1, 4, 4)), args=(make_arg((1, 2, 2, 3)), make_arg((1,))),
+                    kwargs={'stride': (-1,)}), error_regex="non-positive stride is not supported")
+
+    # error inputs for negative padding
+    yield ErrorInput(
+        SampleInput(make_arg((1, 1, 4, 3)), args=(make_arg((1, 2, 2, 4)), make_arg((1,))),
+                    kwargs={'padding': (-1,)}), error_regex="negative padding is not supported")
+
+    # error inputs for negative dilation
+    yield ErrorInput(
+        SampleInput(make_arg((1, 1, 4, 2)), args=(make_arg((1, 1, 2, 5)), make_arg((1,))),
+                    kwargs={'dilation': (-1,)}), error_regex="dilation should be greater than zero")
+
+    # FIXME: https://github.com/pytorch/pytorch/issues/85656
+    # error inputs for bias shape not equal to the output channels
+    # yield ErrorInput(SampleInput(make_arg((1, 1, 4, 4)), args=(make_arg((1, 1, 3, 2)), make_arg((2,)))),
+    #                  error_regex="expected bias to be 1-dimensional with 1 elements")
+
+    # error inputs for input.ndim != weight.ndim
+    yield ErrorInput(
+        SampleInput(make_arg((1, 1, 4, 3)), args=(make_arg((1, 2, 2)), make_arg((1,))),
+                    kwargs={'padding': 'same'}), error_regex="Expected 3-dimensional input for 3-dimensional weight")
+
+    # error inputs where weight.shape[0] is less than the number of groups
+    yield ErrorInput(
+        SampleInput(make_arg((2, 2, 4, 3)), args=(make_arg((2, 2, 1, 3)), make_arg((2,))),
+                    kwargs={'groups': 3}), error_regex="expected weight to be at least 3 at dimension 0")
+
+    # error inputs where weight.shape[0] is less than the number of groups
+    yield ErrorInput(
+        SampleInput(make_arg((2, 2, 4, 3)), args=(make_arg((2, 2, 1, 3)), make_arg((2,))),
+                    kwargs={'padding': 'same', 'groups': 3}), error_regex="expected weight to be at least 3 at dimension 0")
+
+    # error inputs for invalid groups
+    yield ErrorInput(
+        SampleInput(make_arg((2, 2, 4, 5)), args=(make_arg((2, 2, 1, 4)), make_arg((2,))),
+                    kwargs={'padding': 'same', 'groups': -1}), error_regex="non-positive groups is not supported")
+
+    # error inputs for invalid groups
+    yield ErrorInput(
+        SampleInput(make_arg((2, 2, 4, 3)), args=(make_arg((2, 2, 4, 3)), make_arg((2,))),
+                    kwargs={'padding': 'same', 'groups': 0}), error_regex="non-positive groups is not supported")
+
+
+def sample_inputs_conv2d(op_info, device, dtype, requires_grad, jit_fail_sample=False, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # Ordered as shapes for input, weight, bias
+    # and a dict of values of (stride, padding, groups, dilation)
+    cases: tuple = (
+        ((1, 3, 4, 4), (3, 3, 3, 3), (3,),
+            {'stride': (2, 2), 'padding': 2, 'groups': 1}),
+        ((2, 4, 8, 8), (2, 2, 3, 3), (2,),
+            {'stride': (3, 2), 'padding': (2, 1), 'groups': 2, 'dilation': (4, 4)}),
+        ((1, 4, 5, 5), (1, 4, 2, 3), (1,),
+            {'stride': 2, 'padding': 1, 'groups': 1, 'dilation': (2, 3)}),
+        ((1, 4, 5, 5), (1, 4, 2, 3), (1,),
+            {'stride': 2, 'padding': 1, 'groups': 1, 'dilation': (2, 3)}),
+        ((1, 2, 4, 3), (4, 2, 3, 4), None,
+            {'stride': 2, 'padding': 1, 'groups': 1}),
+        ((1, 4, 5, 5), (1, 4, 2, 3), (1,),
+            {'stride': 2, 'padding': "valid"}),
+        ((1, 4, 5, 5), (1, 4, 2, 3), (1,),
+            {'stride': 1, 'padding': "same", 'dilation': 3}),
+        # Below are the group-related samples from common_nn.py
+        ((2, 4, 6, 6), (4, 1, 3, 3), (4,), {'groups': 4}),
+        ((2, 4, 6, 6), (8, 1, 3, 3), (8,), {'groups': 4}),
+        ((2, 4, 6, 6), (8, 1, 3, 3), None, {'groups': 4}),
+        ((2, 4, 6, 6), (4, 1, 3, 3), (4,), {'groups': 4, 'stride': (3, 2)}),
+        ((2, 4, 6, 6), (4, 1, 3, 3), (4,), {'groups': 4, 'padding': (1, 1)}),
+        ((2, 4, 5, 5), (4, 1, 2, 2), (4,), {'groups': 4, 'dilation': (2, 2)}),
+        ((2, 4, 6, 5), (6, 2, 3, 2), (6,), {'groups': 2}),
+        # With defaults
+        ((1, 4, 5, 5), (3, 4, 3, 3), None, {}),
+    )
+
+    for input_shape, weight, bias, kwargs in cases:
+        # Batched
+        yield SampleInput(make_arg(input_shape), args=(
+            make_arg(weight),
+            make_arg(bias) if bias is not None else bias
+        ), kwargs=kwargs)
+        # Unbatched
+        yield SampleInput(make_arg(input_shape[1:]), args=(
+            make_arg(weight),
+            make_arg(bias) if bias is not None else bias
+        ), kwargs=kwargs)
+
+
+def sample_inputs_conv3d(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # Ordered as shapes for input, weight, bias
+    # and dict of values of (stride, padding, dilation, groups)
+    cases: tuple = (
+        ((1, 1, 4, 4, 4), (1, 1, 1, 1, 1), (1,), {'padding': 'same'}),
+        ((1, 1, 4, 4, 4), (1, 1, 4, 4, 4), (1,), {'stride': (2, 2, 2)}),
+        ((1, 1, 5, 5, 5), (1, 1, 3, 3, 3), (1,), {'dilation': 2}),
+        ((1, 1, 1, 1, 10), (1, 1, 1, 1, 4), None, {'padding': 'valid'}),
+        ((1, 1, 10, 11, 12), (1, 1, 1, 2, 5), None, {'padding': 'same'}),
+        ((1, 1, 10, 11, 12), (1, 1, 1, 2, 5), None, {'padding': 'same', 'dilation': 2}),
+        ((1, 1, 10, 11, 12), (1, 1, 4, 4, 4), None, {'padding': 'same', 'dilation': 3}),
+        ((1, 1, 1, 1, 10), (1, 1, 1, 1, 4), None, {'padding': 'valid'}),
+        ((3, 9, 3, 1, 9), (3, 3, 3, 1, 9), (3,), {'groups': 3}),
+        ((3, 9, 3, 1, 9), (3, 3, 3, 1, 9), (3,), {'stride': (2, 2, 2), 'dilation': 1, 'groups': 3}),
+    )
+
+    for input_shape, weight, bias, kwargs in cases:
+        # Batched
+        yield SampleInput(make_arg(input_shape), args=(
+            make_arg(weight),
+            make_arg(bias) if bias is not None else bias
+        ), kwargs=kwargs)
+        # Unbatched
+        yield SampleInput(make_arg(input_shape[1:]), args=(
+            make_arg(weight),
+            make_arg(bias) if bias is not None else bias
+        ), kwargs=kwargs)
+
+
+def error_inputs_conv3d(opinfo, device, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=torch.float64)
+    make_int_arg = partial(make_tensor, device=device, dtype=torch.int64)
+    make_complex_arg = partial(make_tensor, device=device, dtype=torch.complex128)
+
+    # error inputs for different dtypes of input tensor and bias
+    yield ErrorInput(
+        SampleInput(make_int_arg((1, 1, 4, 4, 4)), args=(make_int_arg((1, 1, 2, 2, 2)), make_arg((1,)))),
+        error_regex="should be the same")
+
+    # error inputs for different dtypes of input tensor and bias
+    yield ErrorInput(
+        SampleInput(make_arg((1, 1, 4, 4, 4)), args=(make_arg((1, 1, 2, 2, 2)), make_complex_arg((1,)))),
+        error_regex="should be the same")
+
+    # error inputs for negative strides
+    yield ErrorInput(
+        SampleInput(make_arg((1, 1, 4, 4, 4)), args=(make_arg((1, 1, 2, 2, 2)), make_arg((1,))),
+                    kwargs={'stride': (-1,)}), error_regex="non-positive stride is not supported")
+
+    # error inputs for negative padding
+    yield ErrorInput(
+        SampleInput(make_arg((1, 1, 4, 4, 4)), args=(make_arg((1, 1, 2, 2, 2)), make_arg((1,))),
+                    kwargs={'padding': (-1,)}), error_regex="negative padding is not supported")
+
+    # error inputs for negative dilation
+    yield ErrorInput(
+        SampleInput(make_arg((1, 1, 4, 4, 4)), args=(make_arg((1, 1, 2, 2, 2)), make_arg((1,))),
+                    kwargs={'dilation': (-1,)}), error_regex="dilation should be greater than zero")
+
+    # FIXME: https://github.com/pytorch/pytorch/issues/85656
+    # error inputs for bias shape not equal to the output channels
+    # yield ErrorInput(SampleInput(make_arg((1, 1, 4, 4, 4)), args=(make_arg((1, 1, 3, 3, 3)), make_arg((2,)))),
+    #                  error_regex="expected bias to be 1-dimensional with 1 elements")
+
+    # error inputs for input.ndim != weight.ndim
+    yield ErrorInput(
+        SampleInput(make_arg((1, 1, 3, 4, 5)), args=(make_arg((1, 1, 4, 3)), make_arg((1,))),
+                    kwargs={'padding': 'same'}), error_regex="Expected 4-dimensional input for 4-dimensional weight")
+
+    # error inputs where weight.shape[0] is less than the number of groups
+    yield ErrorInput(
+        SampleInput(make_arg((2, 2, 3, 4, 5)), args=(make_arg((2, 2, 4, 3, 3)),
+                    make_arg((2,))), kwargs={'groups': 3}),
+        error_regex="expected weight to be at least 3 at dimension 0")
+
+    # error inputs where weight.shape[0] is less than the number of groups
+    yield ErrorInput(
+        SampleInput(make_arg((2, 2, 3, 4, 5)), args=(make_arg((2, 2, 4, 3, 3)),
+                    make_arg((2,))), kwargs={'padding': 'same', 'groups': 3}),
+        error_regex="expected weight to be at least 3 at dimension 0")
+
+    # error inputs for invalid groups
+    yield ErrorInput(
+        SampleInput(make_arg((2, 2, 3, 4, 5)), args=(make_arg((2, 2, 4, 3, 3)),
+                    make_arg((2,))), kwargs={'padding': 'same', 'groups': 0}),
+        error_regex="non-positive groups is not supported")
+
+    # error inputs for padding='same' not supported by strided convolutions
+    yield ErrorInput(
+        SampleInput(make_arg((18, 27, 9, 1, 9)), args=(make_arg((9, 9, 9, 1, 9)),
+                    make_arg((9,))), kwargs={'stride': 2, 'padding': 'same', 'groups': 3}),
+        error_regex="padding='same' is not supported for strided convolutions")
+
+
+def sample_inputs_group_norm(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # Ordered as input shape, num groups, and kwargs for eps
+    cases: tuple[tuple[int], int, dict] = (  # type: ignore[assignment]
+        ((1, 6, 3), 2, {'eps' : 0.5}),
+        ((2, 6, 3), 2, {'eps' : -0.5}),
+        ((1, 3), 1, {'eps' : 1e-5}),
+        ((0, 2), 1, {'eps' : 1e-5}),
+        ((S, S, S), 1, {'eps' : 0.5}),
+    )
+
+    # num_channels is inferred to be input.shape[1]
+    for input_shape, num_groups, kwargs in cases:
+        # Shape of weight and bias should be the same as num_channels
+        channels = input_shape[1] if len(input_shape) > 1 else 0
+        weight_tensor = make_arg(channels)
+        bias_tensor = make_arg(channels)
+
+        # Checking for permutations of weights and biases as `None`
+        weights = [weight_tensor, None]
+        biases = [bias_tensor, None]
+        for weight, bias in itertools.product(weights, biases):
+            kwargs = {
+                'weight': weight,
+                'bias': bias,
+                **kwargs
+            }
+            yield SampleInput(make_arg(input_shape), num_groups, **kwargs)
+
+    # Without any optional args
+    yield SampleInput(make_arg((1, 2)), args=(1,))
+
+def reference_inputs_group_norm(op_info, device, dtype, requires_grad, **kwargs):
+    yield from sample_inputs_group_norm(
+        op_info, device, dtype, requires_grad, **kwargs)
+
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # Ordered as input shape, num groups, and kwargs for eps
+    cases: tuple[tuple[int], int, dict] = (  # type: ignore[assignment]
+        ((20, 6, 10, 10), 3, {'eps' : 1e-5}),
+        # equivalent to InstanceNorm (see the illustrative sketch after this function)
+        # GroupNorm(C, num_groups=C) == InstanceNorm(num_features=C)
+        ((20, 6, 10, 10), 6, {'eps' : 1e-5}),
+        # equivalent to LayerNorm
+        # GroupNorm(C, num_groups=1, affine=False) == LayerNorm(normalized_shape=[C, H, W], elementwise_affine=False)
+        ((20, 6, 10, 10), 1, {'eps' : 1e-5}),
+    )
+
+    # num_channels is inferred to be input.shape[1]
+    for input_shape, num_groups, kwargs in cases:
+        # Shape of weight and bias should be the same as num_channels
+        channels = input_shape[1] if len(input_shape) > 1 else 0
+        input_tensor = make_arg(input_shape)
+        weight_tensor = make_arg(channels)
+        bias_tensor = make_arg(channels)
+
+        # Checking for permutations of weights and biases as `None`
+        weights = [weight_tensor, None]
+        biases = [bias_tensor, None]
+        for weight, bias in itertools.product(weights, biases):
+            kwargs = {
+                'weight': weight,
+                'bias': bias,
+                **kwargs
+            }
+            yield SampleInput(input_tensor, num_groups, **kwargs)
+
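+# Illustrative sketch (not part of the test suite) of the equivalences noted in the cases
+# above, assuming plain float32 CPU tensors and no affine parameters:
+def _group_norm_equivalence_sketch():
+    x = torch.randn(20, 6, 10, 10)
+    # num_groups == C behaves like InstanceNorm without running stats.
+    as_instance = torch.nn.functional.group_norm(x, 6)
+    instance = torch.nn.functional.instance_norm(x)
+    # num_groups == 1 behaves like LayerNorm over (C, H, W).
+    as_layer = torch.nn.functional.group_norm(x, 1)
+    layer = torch.nn.functional.layer_norm(x, x.shape[1:])
+    return (torch.allclose(as_instance, instance, atol=1e-5)
+            and torch.allclose(as_layer, layer, atol=1e-5))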
+
+def sample_inputs_instance_norm(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_arg_without_requires_grad = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
+
+    # Ordered as: input shape, kwargs for momentum, eps
+    cases: tuple[tuple[int], dict] = (  # type: ignore[assignment]
+        ((S, S, S), {'momentum': 0.5, 'eps': 0.6}),
+        ((S, S, S), {'momentum': 0.5, 'eps': 0.6, 'use_input_stats': True}),
+        ((3, 2, 4), {'momentum': -1.2}),
+        ((3, 2, 4), {'momentum': 0.0}),
+        ((3, 2, 3, 4), {'momentum': -1.0, 'eps': 0.5}),
+        ((3, 2, 3, 4), {'momentum': -1.0, 'eps': 0.5}),
+    )
+
+    for input_shape, kwargs in cases:
+        # args: running_mean, running_var, weight, and bias must all have shape (channels,)
+        channels = input_shape[1]
+        weight = make_arg(channels)
+        bias = make_arg(channels)
+        running_mean = make_arg_without_requires_grad(channels, low=0)
+        running_var = make_arg_without_requires_grad(channels, low=0)
+        new_kwargs = {
+            'running_mean': running_mean,
+            'running_var': running_var,
+            'weight': weight,
+            'bias': bias,
+            **kwargs
+        }
+
+        yield SampleInput(
+            make_arg(input_shape),
+            args=(),
+            kwargs=new_kwargs
+        )
+
+    # Checking for permutations of weights and biases as `None`
+    # instance_norm assumes that if there's a bias, there's a weight
+    weights = [channels, None]
+    biases = [None, None]
+
+    for weight_channels, bias_channels in zip(weights, biases):
+        running_mean = make_arg_without_requires_grad(channels, low=0)
+        running_var = make_arg_without_requires_grad(channels, low=0)
+        yield SampleInput(
+            make_arg(input_shape),
+            args=(),
+            kwargs={
+                'running_mean': running_mean,
+                'running_var': running_var,
+                'weight': make_arg(weight_channels) if weight_channels is not None else None,
+                'bias': make_arg(bias_channels) if bias_channels is not None else None
+            }
+        )
+
+    # Test case for no optional kwargs
+    yield SampleInput(make_arg((1, 2, 3)), kwargs={})
+
+def sample_inputs_safe_softmax(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
+
+    def make_bool_mask(*shape):
+        return torch.randint(0, 2, shape, device=device, dtype=torch.bool)
+
+    def mask_two_rows(rows, cols):
+        mask_two_rows = torch.ones((rows, cols), dtype=torch.bool, device=device)
+        mask_two_rows[rows - 1] = False
+        mask_two_rows[rows - 3] = False
+        return mask_two_rows
+
+    def convert_to_float_mask(mask: torch.Tensor) -> torch.Tensor:
+        return torch.where(~mask, float('-inf'), 0.0)
+
+    def with_requires_grad(tensor):
+        return tensor.requires_grad_(requires_grad)
+
+    def generate_input_from_mask(mask_shape, dim):
+        mask = make_bool_mask(*mask_shape)
+        input_tensor = make_arg(mask_shape)
+        masked_input = input_tensor + convert_to_float_mask(mask)
+        return SampleInput(with_requires_grad(masked_input), kwargs={'dim': dim})
+
+    samples = [
+        # Basic 3D tensor with mask
+        generate_input_from_mask((2, 3, 4), dim=1),
+        # 2D tensor with mask, testing different dim
+        generate_input_from_mask((5, 5), dim=0),
+        # 4D tensor, testing with a different dim
+        generate_input_from_mask((2, 3, 4, 5), dim=2),
+        # Edge case: 1D tensor
+        generate_input_from_mask((10,), dim=0),
+        # Edge case: tensor with one dimension of size 1
+        generate_input_from_mask((1, 5, 5), dim=1),
+        # Testing with all elements masked
+        SampleInput(
+            with_requires_grad(
+                make_arg((3, 3))
+                + convert_to_float_mask(
+                    torch.zeros((3, 3), dtype=torch.bool, device=device)
+                )
+            ),
+            kwargs={"dim": 1},
+        ),
+        # Testing with no elements masked
+        SampleInput(
+            with_requires_grad(
+                make_arg((3, 3))
+                + convert_to_float_mask(
+                    torch.ones((3, 3), dtype=torch.bool, device=device)
+                )
+            ),
+            kwargs={"dim": 1},
+        ),
+        # Testing with two rows masked
+        SampleInput(
+            with_requires_grad(
+                make_arg((6, 3)) + convert_to_float_mask(mask_two_rows(6, 3))
+            ),
+            kwargs={"dim": 1},
+        ),
+    ]
+    yield from samples
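+
+# Illustrative sketch (not part of the test suite): the "all elements masked" sample above
+# targets rows that are entirely -inf, where a plain softmax produces NaN (which is what a
+# "safe" softmax is meant to avoid); a minimal demonstration of that failure mode:
+def _fully_masked_row_softmax_sketch():
+    row = torch.full((1, 3), float('-inf'))
+    return torch.softmax(row, dim=1).isnan().all()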
+
+def sample_inputs_layer_norm(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # Ordered as input shape, normalized_shape and a kwarg dict for eps
+    cases: tuple[tuple[int], tuple[int], dict] = (  # type: ignore[assignment]
+        ((1, 2, 3), (1, 2, 3), {'eps': 0.5}),
+        ((2, 2, 3), (2, 3), {'eps': -0.5}),
+        ((1,), (1,), {}),
+        ((1, 2), (2,), {}),
+        ((0, 1), (1,), {}),
+    )
+
+    for input_shape, normalized_shape, kwargs in cases:
+        # Shape of weight and bias should be the same as normalized_shape
+        weight = make_arg(normalized_shape)
+        bias = make_arg(normalized_shape)
+        yield SampleInput(
+            make_arg(input_shape),
+            args=(normalized_shape, weight, bias),
+            kwargs=kwargs
+        )
+    # Without any optional args
+    yield SampleInput(make_arg((1, 2)), args=((2,),))
+
+    # TODO: @krshrimali, once to_numpy method in SampleInput class is modified to take None inputs,
+    # enable these inputs; see https://github.com/pytorch/pytorch/pull/63276#discussion_r691950400
+
+    # With weight and a `None` bias
+    # yield SampleInput(make_arg((1, 2)), args=((2,), make_arg((2,)), None))
+
+    # With `None` weight and bias (tests failing for this, see the link above)
+    # yield SampleInput(make_arg((1, 2)), args=((2,), None, make_arg((2,))))
+
+
+def sample_inputs_native_layer_norm(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # Ordered as input shape, normalized_shape, eps
+    cases: tuple[tuple[int], tuple[int], float] = (  # type: ignore[assignment]
+        ((1, 2, 3), (1, 2, 3), 0.5),
+        ((2, 2, 3), (2, 3), -0.5),
+        ((1,), (1,), 1e-5),
+        ((1, 2), (2,), 1e-5),
+        ((0, 1), (1,), 1e-5),
+    )
+
+    for input_shape, normalized_shape, eps in cases:
+        # Shape of weight and bias should be the same as normalized_shape
+        weight = make_arg(normalized_shape)
+        bias = make_arg(normalized_shape)
+        yield SampleInput(
+            make_arg(input_shape),
+            args=(normalized_shape, weight, bias, eps),
+        )
+        yield SampleInput(
+            make_arg(input_shape),
+            args=(normalized_shape, None, bias, eps),
+        )
+        yield SampleInput(
+            make_arg(input_shape),
+            args=(normalized_shape, weight, None, eps),
+        )
+        yield SampleInput(
+            make_arg(input_shape),
+            args=(normalized_shape, None, None, eps),
+        )
+
+def sample_inputs_rms_norm(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, high=1000)
+
+    # Ordered as input shape, normalized_shape and a kwarg dict for eps
+    cases: tuple[tuple[int], tuple[int], dict] = (  # type: ignore[assignment]
+        ((1, 2, 3), (1, 2, 3), {'eps': 0.5}),
+        ((2, 2, 3), (2, 3), {'eps': -0.5}),
+        ((1,), (1,), {}),
+        ((1, 2), (2,), {}),
+        ((0, 1), (1,), {}),
+    )
+
+    for input_shape, normalized_shape, kwargs in cases:
+        # Shape of weight and bias should be the same as normalized_shape
+        weight = make_arg(normalized_shape)
+        yield SampleInput(
+            make_arg(input_shape),
+            args=(normalized_shape, weight),
+            kwargs=kwargs
+        )
+    # Without any optional args
+    yield SampleInput(make_arg((1, 2)), args=((2,),))
+
+def error_inputs_group_norm(opinfo, device, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=torch.float32, requires_grad=False)
+
+    # check that input has minimum number of dimensions
+    err_msg1 = "Expected at least 2 dimensions for input tensor but received"
+    s1 = SampleInput(make_arg(1), args=(1,))
+    yield ErrorInput(s1, error_regex=err_msg1)
+
+    # check that the channels dimension is compatible with number of groups
+    err_msg2 = "Expected number of channels in input to be divisible by num_groups, but got input of shape"
+    s2 = SampleInput(make_arg((2, 7, 4)), args=(2,))
+    yield ErrorInput(s2, error_regex=err_msg2)
+
+def error_inputs_native_layer_norm(opinfo, device, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=torch.float32, requires_grad=False)
+    input_shape = (1, 2, 3)
+
+    err_msg1 = "Expected normalized_shape to be at least 1-dimensional"
+    s1 = SampleInput(
+        make_arg(input_shape), args=((), None, None, 1e-5)
+    )
+    yield ErrorInput(s1, error_regex=err_msg1)
+
+    normalized_shape = (1, 2, 3)
+    weight = make_arg((1, 2))
+    err_msg2 = "Expected weight to be of same shape as normalized_shape"
+    s2 = SampleInput(
+        make_arg(input_shape), args=(normalized_shape, weight, None, 1e-5)
+    )
+    yield ErrorInput(s2, error_regex=err_msg2)
+
+    bias = make_arg((1, 2))
+    err_msg3 = "Expected bias to be of same shape as normalized_shape"
+    s3 = SampleInput(
+        make_arg(input_shape), args=(normalized_shape, None, bias, 1e-5)
+    )
+    yield ErrorInput(s3, error_regex=err_msg3)
+
+    err_msg4 = "Given normalized_shape="
+    s4 = SampleInput(
+        make_arg((2, 2, 3)), args=((2, 2), None, None, 1e-5)
+    )
+    yield ErrorInput(s4, error_regex=err_msg4)
+
+def error_inputs_rms_norm(opinfo, device, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=torch.float32, requires_grad=False)
+    input_shape = (1, 2, 3)
+
+    err_msg1 = "Expected normalized_shape to be at least 1-dimensional"
+    s1 = SampleInput(
+        make_arg(input_shape), args=((), None, 1e-5)
+    )
+    yield ErrorInput(s1, error_regex=err_msg1)
+
+    normalized_shape = (1, 2, 3)
+    weight = make_arg((1, 2))
+    err_msg2 = "Expected weight to be of same shape as normalized_shape"
+    s2 = SampleInput(
+        make_arg(input_shape), args=(normalized_shape, weight, 1e-5)
+    )
+    yield ErrorInput(s2, error_regex=err_msg2)
+
+
+    err_msg4 = "Given normalized_shape="
+    s4 = SampleInput(
+        make_arg((2, 2, 3)), args=((2, 2), None, 1e-5)
+    )
+    yield ErrorInput(s4, error_regex=err_msg4)
+
+
+def sample_inputs_local_response_norm(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # Ordered as input shape, size and a kwarg dict for alpha, beta, and k
+    cases: tuple[tuple[int], int, dict] = (  # type: ignore[assignment]
+        ((1, 6, 3), 2, {'alpha': 3e-05, 'beta': 0.5, 'k': 1.25}),
+        ((1, 6, 3), 2, {'beta': 0.5, 'k': 1.25}),
+        ((1, 6, 3), 2, {'alpha': 3e-05, 'k': 1.25}),
+        ((1, 6, 3), 2, {'alpha': 3e-05, 'beta': 0.5}),
+        ((1, 6, 3), 2, {'alpha': 3e-05}),
+        ((1, 6, 3), 2, {'beta': 0.5}),
+        ((1, 6, 3), 2, {'k': 1.25}),
+        ((1, 6, 3), 2, {}),
+        ((2, 6, 3), 2, {'alpha': 3e-05, 'beta': 0.5, 'k': 1.25}),
+        ((1, 1, 2), 1, {'alpha': 3e-05, 'beta': 0.5, 'k': 1.25}),
+        ((0, 1, 2), 1, {'alpha': 3e-05, 'beta': 0.5, 'k': 1.25}),
+    )
+
+    for input_shape, size, kwargs in cases:
+        yield SampleInput(make_arg(input_shape), args=(size,), kwargs=kwargs)
+
+def sample_inputs_hardswish(self, device, dtype, requires_grad, **kwargs):
+    N = 5
+    # make sure we are testing the -3 -> 3 range; the default is -10 -> 10, so this may be unnecessary
+    make_arg = partial(make_tensor, device=device, dtype=dtype,
+                       requires_grad=requires_grad, low=-5, high=5)
+    return (SampleInput(make_arg((N * 2, N * 2))) for _ in range(1, N))
+
+def sample_inputs_linear(self, device, dtype, requires_grad, **kwargs):
+    features_options = [[3, 4], [8, 8]]
+    batch_options: list[list[int]] = [
+        [],  # no batch
+        [0],
+        [8],
+        [2, 3],
+    ]
+    create_tensor = partial(make_tensor, device=device, dtype=dtype,
+                            requires_grad=requires_grad, low=-2, high=2)
+
+    for has_bias, (in_feat, out_feat), batch_shape in \
+            itertools.product([True, False], features_options, batch_options):
+        input_tensor = create_tensor(batch_shape + [in_feat])
+        weight = create_tensor([out_feat, in_feat])
+        if not has_bias:
+            yield SampleInput(input_tensor, weight)
+            continue
+
+        bias = create_tensor([out_feat])
+        yield SampleInput(input_tensor, weight, bias)
+
+    # 5D tensor, used to crash on MPS, see https://github.com/pytorch/pytorch/issues/114942
+    yield SampleInput(create_tensor(2, 1, 2, 1, 2), create_tensor(4, 2))
+    yield SampleInput(create_tensor(2, 1, 2, 1, 2), create_tensor(4, 2), create_tensor(4))
+
+def sample_inputs_bilinear(self, device, dtype, requires_grad, **kwargs):
+    features_options = [[3, 4, 5], [8, 8, 8]]
+    batch_options: list[list[int]] = [
+        [],  # no batch
+        [0],
+        [8],
+        [2, 3],
+    ]
+    create_tensor = partial(make_tensor, device=device, dtype=dtype,
+                            requires_grad=requires_grad, low=-2, high=2)
+
+    for has_bias, (in_feat1, in_feat2, out_feat), batch_shape in \
+            itertools.product([True, False], features_options, batch_options):
+        input_tensor1 = create_tensor(batch_shape + [in_feat1])
+        input_tensor2 = create_tensor(batch_shape + [in_feat2])
+        weight = create_tensor([out_feat, in_feat1, in_feat2])
+        if not has_bias:
+            yield SampleInput(input_tensor1, input_tensor2, weight)
+            continue
+        bias = create_tensor([out_feat])
+        yield SampleInput(input_tensor1, input_tensor2, weight, bias)
+
+def sample_inputs_glu(self, device, dtype, requires_grad, **kwargs):
+    features_options = [[2], [2, 4], [8, 8], [3, 6, 8], [1, 4, 6, 7]]
+    batch_options: list[list[int]] = [
+        [],  # no batch
+        [0],
+        [8],
+        [2, 3],
+    ]
+    create_tensor = partial(make_tensor, device=device, dtype=dtype,
+                            requires_grad=requires_grad, low=-2, high=2)
+
+    for features, batch_shape in itertools.product(features_options, batch_options):
+        ndim = len(features) + len(batch_shape)
+        for dim in range(ndim):
+            input_tensor = create_tensor(batch_shape + features)
+            dim_size = input_tensor.size(dim)
+            if dim_size > 0 and dim_size % 2 == 0:
+                yield SampleInput(input_tensor, dim)
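+
+# Illustrative sketch (not part of the test suite): glu splits the chosen dimension in
+# half (a * sigmoid(b)), which is why only even-sized dims are sampled above:
+def _glu_halves_dim_sketch():
+    x = torch.randn(2, 6)
+    return torch.nn.functional.glu(x, dim=1).shape == (2, 3)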
+
+def sample_inputs_interpolate(mode, self, device, dtype, requires_grad, **kwargs):
+    N, C = 2, 3
+    D = 4
+    S = 3
+    L = 5
+
+    align_corners_options: tuple[Any, ...] = (None,)
+    if mode in ('linear', 'bilinear', 'bicubic', 'trilinear'):
+        align_corners_options = (True, False, None)
+    ranks_for_mode = {
+        'nearest': [1, 2, 3],
+        'nearest-exact': [1, 2, 3],
+        'linear': [1],
+        'bilinear': [2],
+        'bicubic': [2],
+        'trilinear': [3],
+        'area': [1, 2, 3]
+    }
+
+    def shape(size, rank, with_batch_channel=True):
+        if with_batch_channel:
+            return tuple([N, C] + ([size] * rank))
+        return tuple([size] * rank)
+
+    def uneven_shape(size, rank, with_batch_channel=True):
+        rc = list(shape(size, rank, with_batch_channel))
+        rc[-1] += 1
+        if rank > 2:
+            rc[-2] -= 1
+        return tuple(rc)
+
+    if mode in ('bilinear', 'bicubic') and dtype == torch.uint8:
+        make_arg = partial(
+            make_tensor,
+            device=device,
+            dtype=dtype,
+            requires_grad=requires_grad,
+            # we pick a more realistic upper bound of 256 instead of the default 10 for the uint8 dtype
+            high=256 if dtype == torch.uint8 else None,
+        )
+        # provide a few samples closer to typical image-processing usage
+        rank = 2
+        for memory_format in [torch.contiguous_format, torch.channels_last]:
+            yield SampleInput(
+                make_arg(shape(270, rank), memory_format=memory_format),
+                shape(130, rank, False),
+                scale_factor=None,
+                mode=mode,
+                align_corners=False,
+            )
+
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    for align_corners in align_corners_options:
+        for rank in ranks_for_mode[mode]:
+            yield SampleInput(
+                make_arg(shape(D, rank)),
+                shape(S, rank, False),
+                scale_factor=None,
+                mode=mode,
+                align_corners=align_corners,
+            )
+            yield SampleInput(
+                make_arg(shape(D, rank)),
+                shape(L, rank, False),
+                scale_factor=None,
+                mode=mode,
+                align_corners=align_corners,
+            )
+            if rank > 1 and dtype.is_floating_point:
+                yield SampleInput(
+                    make_arg(uneven_shape(D, rank)),
+                    uneven_shape(S, rank, False),
+                    scale_factor=None,
+                    mode=mode,
+                    align_corners=align_corners,
+                )
+                yield SampleInput(
+                    make_arg(uneven_shape(D, rank)),
+                    uneven_shape(L, rank, False),
+                    scale_factor=None,
+                    mode=mode,
+                    align_corners=align_corners,
+                )
+            for recompute_scale_factor in [False, True]:
+                for scale_factor in [1.7, 0.6]:
+                    yield SampleInput(
+                        make_arg(shape(D, rank)),
+                        size=None,
+                        scale_factor=scale_factor,
+                        mode=mode,
+                        align_corners=align_corners,
+                        recompute_scale_factor=recompute_scale_factor,
+                    )
+
+def reference_inputs_interpolate(mode, self, device, dtype, requires_grad, **kwargs):
+    yield from sample_inputs_interpolate(mode, self, device, dtype, requires_grad, **kwargs)
+
+    if mode in ('bilinear', 'bicubic'):
+        make_arg = partial(
+            make_tensor,
+            device=device,
+            dtype=dtype,
+            requires_grad=requires_grad,
+            # we pick a more realistic upper bound of 256 instead of the default 10 for the uint8 dtype
+            high=256 if dtype == torch.uint8 else None,
+        )
+        # provide a few samples for more typical image-processing usage
+        for memory_format in [torch.contiguous_format, torch.channels_last]:
+            for aa in [True, False]:
+                yield SampleInput(
+                    make_arg((2, 3, 345, 456), memory_format=memory_format),
+                    (270, 270),
+                    scale_factor=None,
+                    mode=mode,
+                    align_corners=False,
+                    antialias=aa,
+                )
+
+def sample_inputs_upsample(mode, self, device, dtype, requires_grad, **kwargs):
+    N, C = 2, 3
+    D = 4
+    S = 3
+    L = 5
+
+    ranks_for_mode = {
+        'nearest': [1, 2, 3],
+        'bilinear': [2],
+    }
+
+    def shape(size, rank, with_batch_channel=True):
+        if with_batch_channel:
+            return torch.Size([N, C] + ([size] * rank))
+        return torch.Size([size] * rank)
+
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    for rank in ranks_for_mode[mode]:
+        yield SampleInput(make_arg(shape(D, rank)), size=shape(S, rank, False))
+        yield SampleInput(make_arg(shape(D, rank)), size=shape(L, rank, False))
+        yield SampleInput(make_arg(shape(D, rank)), scale_factor=1.7)
+        yield SampleInput(make_arg(shape(D, rank)), scale_factor=0.6)
+
+def reference_inputs_upsample(mode, self, device, dtype, requires_grad, **kwargs):
+    yield from sample_inputs_upsample(mode, self, device, dtype, requires_grad, **kwargs)
+
+    if mode in ('bilinear', ):
+        make_arg = partial(
+            make_tensor,
+            device=device,
+            dtype=dtype,
+            requires_grad=requires_grad,
+            # we pick a more realistic upper bound of 256 instead of the default 10 for the uint8 dtype
+            high=256 if dtype == torch.uint8 else None,
+        )
+        # provide a single sample for more typical image processing usage
+        for memory_format in [torch.contiguous_format, torch.channels_last]:
+            yield SampleInput(
+                make_arg((2, 3, 345, 456), memory_format=memory_format),
+                (270, 270),
+            )
+
+def sample_inputs_upsample_aa(mode, self, device, dtype, requires_grad, **kwargs):
+    N = 6
+    C = 3
+    H = 10
+    W = 20
+    S = 3
+    L = 5
+
+    input_tensor = make_tensor(torch.Size([N, C, H, W]), device=device, dtype=dtype, requires_grad=requires_grad)
+
+    yield SampleInput(input_tensor, output_size=torch.Size([S, S]), align_corners=False, scale_factors=None)
+    yield SampleInput(input_tensor, output_size=torch.Size([L, L]), align_corners=False, scale_factors=None)
+    yield SampleInput(input_tensor, output_size=None, align_corners=False, scale_factors=[1.7, 0.9])
+    yield SampleInput(input_tensor, output_size=None, align_corners=True, scale_factors=[0.8, 1.0])
+
+    yield SampleInput(input_tensor, output_size=torch.Size([S, S]), align_corners=False, scales_h=None, scales_w=None)
+    yield SampleInput(input_tensor, output_size=torch.Size([S, S]), align_corners=False, scales_h=1.7, scales_w=0.9)
+    yield SampleInput(input_tensor, output_size=torch.Size([S, S]), align_corners=True, scales_h=1.7, scales_w=0.9)
+
+def sample_inputs_gelu(self, device, dtype, requires_grad, **kwargs):
+    N = 5
+    for _ in range(1, N):
+        for approximate in ['none', 'tanh']:
+            yield SampleInput(
+                make_tensor((N * 2, N * 2), device=device, dtype=dtype,
+                            requires_grad=requires_grad, low=-3, high=3),
+                approximate=approximate)
+
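+# Illustrative sketch (not part of the test suite): the 'tanh' approximation sampled above
+# follows 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3))); a minimal, hypothetical
+# cross-check against that formula:
+def _gelu_tanh_approximation_sketch():
+    import math
+    x = torch.linspace(-3, 3, 25)
+    manual = 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x ** 3)))
+    return torch.allclose(torch.nn.functional.gelu(x, approximate='tanh'), manual, atol=1e-6)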
+
+def error_inputs_gelu(op, device, **kwargs):
+    # Tests that gelu errors out when passed an approximation we don't know.
+    yield ErrorInput(SampleInput(make_tensor((), dtype=torch.float, device=device), kwargs={"approximate": "asdf"}),
+                     error_regex="approximate argument must be either")
+
+
+def sample_inputs_max_min_reduction_with_dim(op_info, device, dtype, requires_grad, **kwargs):
+    args_for_reduction_with_dim = (
+        ((S, S, S), (1,),),
+        ((S, S, S), (1, True, ),),
+        ((), (0,),),
+        ((), (0, True,),),
+    )
+    return ((SampleInput(make_tensor(input_tensor, dtype=dtype, device=device,
+                                     low=None, high=None,
+                                     requires_grad=requires_grad),
+                         *args))
+            for input_tensor, args in args_for_reduction_with_dim)
+
+def sample_inputs_max_min_reduction_no_dim(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None)
+    yield SampleInput(make_arg((S, S, S)))
+    yield SampleInput(make_arg(()))
+
+def _generate_nan_reduction_inputs(device, dtype, requires_grad, **kwargs):
+    yield from _generate_reduction_inputs(device, dtype, requires_grad)
+    # NaN only exists for floating point and complex dtypes
+    if dtype.is_complex or dtype.is_floating_point:
+        yield torch.tensor([2, torch.nan, -1], device=device, dtype=dtype, requires_grad=requires_grad)
+        yield torch.tensor([[torch.nan, 2], [0, 1]], device=device, dtype=dtype, requires_grad=requires_grad)
+
+def sample_inputs_nan_reduction(supports_multiple_dims):
+    # Generates sample inputs for reduction ops that contain the input tensor
+    # and dim and keepdim kwargs. If a reduction op needs to test additional
+    # args/kwargs then create a separate sample_inputs function
+    def fn(op_info, device, dtype, requires_grad, **kwargs):
+        for t in _generate_nan_reduction_inputs(device, dtype, requires_grad):
+            # Add case without dim and keepdim kwargs
+            yield SampleInput(t.clone().requires_grad_(requires_grad))
+            for kwargs in _generate_reduction_kwargs(t.ndim, supports_multiple_dims):
+                yield SampleInput(t.clone().requires_grad_(requires_grad), **kwargs)
+
+    return fn
+
+def sample_inputs_reduction_quantile(op_info, device, dtype, requires_grad, **kwargs):
+    test_quantiles = (0.5, make_tensor((2,), dtype=dtype, device=device, low=0, high=1, requires_grad=requires_grad))
+    test_interpolations = ['linear', 'midpoint']
+
+    for quantiles in test_quantiles:
+        for t in _generate_reduction_inputs(device, dtype, requires_grad):
+            # Add case without dim and keepdim kwargs
+            input = t.clone().requires_grad_(requires_grad)
+            yield SampleInput(input, quantiles)
+            for kwargs in _generate_reduction_kwargs(t.ndim, supports_multiple_dims=False):
+                # The interpolation kwarg is currently only supported when both dim and keepdim are provided
+                kwargs.setdefault('dim', 0)
+                kwargs.setdefault('keepdim', False)
+                for interpolation in test_interpolations:
+                    kwargs['interpolation'] = interpolation
+                    input = t.clone().requires_grad_(requires_grad)
+                    yield SampleInput(input, quantiles, **kwargs)
+
+def sample_inputs_reduction_count_nonzero(*args, **kwargs):
+    """Sample inputs for count_nonzero"""
+    # count_nonzero does not support keepdim yet
+    for sample in sample_inputs_reduction(*args, **kwargs):
+        sample.kwargs.pop('keepdim', None)
+        yield sample
+
+def sample_inputs_leaky_relu(op_info, device, dtype, requires_grad, **kwargs):
+    N = 10
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    return (SampleInput(make_arg((N, N))) for _ in range(1, N))
+
+def sample_inputs_fractional_max_pool2d(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # Order: input_shape, kernel_size
+    cases = (((1, 3, 9, 9), 3),
+             ((1, 3, 9, 9), (4, 4)),
+             ((1, 3, 9, 9), (6, 6)),
+             ((2, 3, 9, 9), (3, 3)),
+             ((1, 1, 4, 4), (2, 2)),
+             ((1, 2, 6, 6), (4, 4)))
+
+    for input_shape, kernel_size in cases:
+        for return_indices in [False, True]:
+            # test case passing a single output size
+            yield SampleInput(
+                make_arg(input_shape),
+                kernel_size,
+                output_size=2,
+                return_indices=return_indices,
+            )
+
+            # test case passing a tuple output size
+            yield SampleInput(
+                make_arg(input_shape),
+                kernel_size,
+                output_size=(2, 3),
+                return_indices=return_indices,
+            )
+
+            # test case passing an output ratio
+            yield SampleInput(
+                make_arg(input_shape),
+                kernel_size,
+                output_ratio=(0.5, 0.5),
+                return_indices=return_indices,
+            )
+
+    yield SampleInput(
+        make_arg((1, 1, 16, 16)),
+        (1, 1),
+        output_ratio=(0.5, 0.5),
+        return_indices=True,
+        _random_samples=make_tensor((1, 1, 2), device=device, dtype=dtype, requires_grad=False),
+    )
+
+def sample_inputs_fractional_max_pool3d(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # Order: input_shape, kernel_size
+    cases = (((2, 3, 5, 5, 5), (2, 2, 2)),
+             ((1, 2, 6, 5, 4), 2),
+             ((1, 2, 5, 6, 5), (2, 3, 2)),
+             ((1, 2, 6, 6, 6), (2, 3, 2)),
+             ((1, 1, 7, 6, 7), (2, 3, 4)),
+             ((1, 1, 4, 5, 4), (2, 2, 1)),
+             ((1, 1, 8, 7, 6), (4, 3, 2)),
+             ((0, 1, 4, 5, 4), (2, 2, 1)))
+
+    for input_shape, kernel_size in cases:
+        for return_indices in [False, True]:
+            # test case passing a single output size
+            yield SampleInput(
+                make_arg(input_shape),
+                kernel_size,
+                output_size=2,
+                return_indices=return_indices,
+            )
+
+            # test case passing a tuple output size
+            yield SampleInput(
+                make_arg(input_shape),
+                kernel_size,
+                output_size=(2, 3, 2),
+                return_indices=return_indices,
+            )
+
+            # test case passing an output ratio
+            yield SampleInput(
+                make_arg(input_shape),
+                kernel_size,
+                output_ratio=(0.5, 0.5, 0.5),
+                return_indices=return_indices,
+            )
+
+def sample_inputs_avgpool2d(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # Order: input_shape, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override
+    cases = (((1, 3, 9, 9), 3, 1, 1, True, False, 2),
+             ((1, 3, 9, 9), (4, 4), (2, 3), 1, True, False, 2),
+             ((1, 3, 9, 9), (6, 6), (3, 3), (2, 3), True, True, 2),
+             ((2, 3, 9, 9), (3, 3), (1, 1), (1, ), True, False, 2),
+             ((1, 1, 4, 4), (2, 2), (), (0, ), False, True, -2),
+             ((1, 2, 6, 6), (4, 4), (2, 2), (2, ), True, True, None))
+
+    for input_shape, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override in cases:
+        yield SampleInput(make_arg(input_shape),
+                          args=(kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override))
+    # Case with just input_shape and kernel_size
+    yield SampleInput(make_arg((1, 3, 9, 9)), args=((3, 3),))
+
+def sample_inputs_avgpool1d(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # Order: input_shape, kernel_size, kwargs
+    cases: list[tuple[tuple[int, ...], Union[int, tuple[int, ...]], dict]] = [
+        ((2, 3, 9), (3,), {}),
+        ((1, 3, 9), 3, dict(stride=1, padding=1, ceil_mode=True, count_include_pad=False)),
+        ((1, 3, 9), (6,), dict(stride=(3,), padding=(2,), ceil_mode=True, count_include_pad=True)),
+        ((2, 3, 9), (3,), dict(stride=(1,), padding=(1,), ceil_mode=False, count_include_pad=True)),
+        ((0, 3, 9), (6,), dict(stride=(3,), padding=(2,), ceil_mode=False, count_include_pad=True)),
+        ((1, 2, 9), (7,), dict(stride=(3,), padding=(2,), ceil_mode=False)),
+        ((1, 2, 9), (7,), dict(stride=(3,), padding=(3,), ceil_mode=True)),
+        ((1, 2, 9), (7,), dict(stride=(3,), ceil_mode=False)),
+        ((1, 2, 9), (7,), dict(stride=(3,), ceil_mode=True)),
+    ]
+
+    for input_shape, kernel_size, kwargs in cases:
+        yield SampleInput(make_arg(input_shape), args=(kernel_size,), kwargs=kwargs)
+
+def sample_inputs_avgpool3d(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # Order: input_shape, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override
+    cases: list[tuple[tuple[int, ...], Union[int, tuple[int, ...]], dict]] = [
+        ((2, 3, 3, 4, 4), (2, 2, 2), {}),
+        ((1, 2, 4, 4, 4), 2, dict(stride=1, padding=1, ceil_mode=True,
+                                  count_include_pad=False, divisor_override=2)),
+        ((1, 2, 5, 5, 5), (2, 3, 4), dict(stride=(1, 2, 2), padding=(0, 1, 2), ceil_mode=True,
+                                          count_include_pad=True, divisor_override=2)),
+        ((1, 2, 5, 5, 5), (2, 3, 4), dict(stride=(1, 2, 2), padding=(0, 1, 2), ceil_mode=False)),
+        ((1, 1, 7, 5, 7), (6, 3, 4), dict(stride=(2, 3, 2), padding=(3, 1, 0), ceil_mode=False,
+                                          count_include_pad=False, divisor_override=2)),
+        ((1, 1, 4, 5, 4), (2, 2, 3), dict(stride=(2, 2, 1), padding=0, ceil_mode=False,
+                                          count_include_pad=True, divisor_override=-2)),
+        ((1, 1, 6, 5, 6), (4, 5, 6), dict(stride=(2, 3, 2), padding=2, ceil_mode=True,
+                                          count_include_pad=True, divisor_override=None)),
+        ((0, 1, 4, 5, 4), (2, 3, 1), dict(stride=(2, 1, 2), padding=0, ceil_mode=False,
+                                          count_include_pad=True, divisor_override=None)),
+    ]
+
+    for input_shape, kernel_size, kwargs in cases:
+        yield SampleInput(make_arg(input_shape), args=(kernel_size,), kwargs=kwargs)
+
+def error_inputs_avg_pool1d(op_info, device, **kwargs):
+    # error inputs when pad is negative
+    x = torch.rand([0, 1, 49], dtype=torch.float32)
+    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': -1}),
+                     error_regex='pad must be non-negative')
+
+    # error inputs when pad > kernel_size / 2
+    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 4}),
+                     error_regex='pad should be at most half of effective kernel size')
+
+def error_inputs_avg_pool2d(op_info, device, **kwargs):
+    # error inputs when pad is negative
+    x = torch.rand([0, 1, 49], dtype=torch.float32)
+    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': -1}),
+                     error_regex='pad must be non-negative')
+    # 2-dimensional kernel
+    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2), 'stride': 50, 'padding': -1}),
+                     error_regex='pad must be non-negative')
+
+    # error inputs when pad > kernel_size / 2
+    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 4}),
+                     error_regex='pad should be at most half of effective kernel size')
+    # 2-dimensional kernel
+    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2), 'stride': 50, 'padding': 4}),
+                     error_regex='pad should be at most half of effective kernel size')
+
+    # error inputs for zero divisor
+    x = torch.zeros(3, 3, 3)
+    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (2, 2), 'divisor_override': 0}),
+                     error_regex='divisor must be not zero')
+
+def error_inputs_avg_pool3d(op_info, device, **kwargs):
+    # error inputs when pad is negative
+    x = torch.rand([0, 1, 49, 50], dtype=torch.float32)
+    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': -1}),
+                     error_regex='pad must be non-negative')
+    # 3-dimensional kernel
+    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2, 2), 'stride': 50, 'padding': -1}),
+                     error_regex='pad must be non-negative')
+
+    # error inputs when pad > kernel_size / 2
+    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 4}),
+                     error_regex='pad should be at most half of effective kernel size')
+    # 3-dimensional kernel
+    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2, 2), 'stride': 50, 'padding': 4}),
+                     error_regex='pad should be at most half of effective kernel size')
+
+    # error inputs for zero divisor
+    x = torch.zeros(3, 3, 3, 3)
+    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (2, 2, 2), 'divisor_override': 0}),
+                     error_regex='divisor must be not zero')
+
+    # error inputs for invalid input dimension
+    x = torch.rand([0, 1, 49], dtype=torch.float32)
+    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 0}),
+                     error_regex='non-empty 4D or 5D')
+
+
+def sample_inputs_to(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    # test_multiple_devices_to_cuda would fail if we use a different device than given
+    devices = [device]
+    if torch.device(device).type == 'cpu':
+        devices = [torch.device('cpu'), torch.device('cuda:0')] if torch.cuda.is_available() else devices
+    memory_formats = [torch.preserve_format, torch.channels_last]
+
+    # TODO: can't switch `to.device` overload to use positional arguments
+    # https://github.com/pytorch/pytorch/issues/84265
+    # to.device overload
+    for device, nb, cp, mem_f in product(devices, [True, False], [True, False], memory_formats):
+        kwargs = {
+            "memory_format": mem_f,
+        }
+        yield SampleInput(make_arg((S, S, S, S)), args=(device, torch.float64, nb, cp), kwargs=kwargs)
+
+    # to.dtype overload
+    for nb, cp, mem_f in product([True, False], [True, False], memory_formats):
+        kwargs = {
+            "memory_format": mem_f,
+        }
+        yield SampleInput(make_arg((S, S, S, S)), args=(torch.float64, nb, cp), kwargs=kwargs)
+
+    # to.other overload
+    for device, nb, cp, mem_f in product(devices, [True, False], [True, False], memory_formats):
+        kwargs = {
+            "memory_format": mem_f,
+        }
+        other = make_arg((S, S, S, S), dtype=torch.float64, device=device)
+        yield SampleInput(make_arg((S, S, S, S)), args=(other, nb, cp), kwargs=kwargs)
+
+
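+# Samples for torch.topk: positional args follow (k, dim, largest, sorted),
+# covering both batched and 0-dim inputs.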
+def sample_inputs_topk(op_info, device, dtype, requires_grad, **kwargs):
+    def get_tensor_input(size):
+        return make_tensor(size, dtype=dtype, device=device, requires_grad=requires_grad)
+
+    yield SampleInput(get_tensor_input((S, M, S)), 3)
+    yield SampleInput(get_tensor_input((S, M, S)), 3, 1)
+    yield SampleInput(get_tensor_input((S, M, S)), 3, -2)
+    yield SampleInput(get_tensor_input((S, M, S)), 3, 1, True)
+    yield SampleInput(get_tensor_input((S, M, S)), 3, -2, True)
+    yield SampleInput(get_tensor_input((S, M, S)), 3, 1, True, True)
+    yield SampleInput(get_tensor_input((S, M, S)), 3, -2, True, True)
+
+    yield SampleInput(get_tensor_input(()), 1)
+    yield SampleInput(get_tensor_input(()), 1, 0)
+    yield SampleInput(get_tensor_input(()), 1, -1)
+    yield SampleInput(get_tensor_input(()), 1, 0, True)
+    yield SampleInput(get_tensor_input(()), 1, -1, True)
+    yield SampleInput(get_tensor_input(()), 1, 0, True, True)
+    yield SampleInput(get_tensor_input(()), 1, -1, True, True)
+
+def sample_inputs_outer(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    yield SampleInput(make_arg(S), make_arg(M))
+
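+# torch.dist(x, y, p) returns the p-norm of (x - y); the shape pairs below
+# exercise broadcasting, including 0-dim tensors.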
+def sample_inputs_dist(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    sizes = ((S, S, S), (S,), (S, 1, S), (), (S, S))
+    ps = (2, 4)
+
+    for size_x, size_y, p in product(sizes, sizes, ps):
+        yield SampleInput(make_arg(size_x), args=(make_arg(size_y), p))
+
+# TODO: test the nondeterminism of the operation
+# https://github.com/pytorch/pytorch/issues/53352
+def sample_inputs_index(op_info, device, dtype, requires_grad, reference=False, **kwargs):
+    # target.index_add(dim, idx, source, *, alpha=1)
+    add = "index_add" in op_info.name
+    # target.index_copy(dim, idx, source)
+    copy = "index_copy" in op_info.name
+    # target.index_fill(dim, idx, value)
+    fill = "index_fill" in op_info.name
+
+    # Extended reference inputs. We generate inputs that exercise atomic adds /
+    # writing several times to one location
+    if reference:
+        make_arg = partial(torch.ones, device=device, dtype=dtype, requires_grad=requires_grad)
+        make_idx = partial(torch.zeros, device=device, dtype=torch.int64)
+    else:
+        make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+        # The indices need to be distinct for copy and add to be deterministic
+        if copy or add:
+            make_idx = partial(torch.randperm, device=device, dtype=torch.int64)
+        else:
+            def make_idx(n):
+                return make_tensor((n,), device=device, dtype=torch.int64, low=0, high=n)
+
+    shapes = [(), (1,), (S, S)]
+    # extra parameter for add
+    if add:
+        if dtype == torch.bool:
+            alphas = (True, False)
+        else:
+            alphas = (-1, 0, 2)
+    else:
+        alphas = (None,)
+
+    if fill:
+        # An unusual value to catch errors.
+        # The former tests `index_fill.int_Scalar`; the latter tests `index_fill.int_Tensor`.
+        values = (make_arg((1,)).item(), make_arg(()))
+    else:
+        values = (None,)
+
+    for shape, alpha, value in product(shapes, alphas, values):
+        t = make_arg(shape)
+        args = []
+
+        # dim. We handle the scalar case
+        dim = -1 if t.ndim == 2 else 0
+        args.append(dim)
+
+        idx = make_idx(t.shape[dim] if t.ndim != 0 else 1)
+        args.append(idx)
+
+        # source
+        if copy or add:
+            args.append(make_arg(shape))
+        elif fill:
+            args.append(value)
+
+        args = tuple(args)
+        kwargs = {} if alpha is None else {"alpha": alpha}
+
+        yield SampleInput(t, args=args, kwargs=kwargs)
+
+def sample_inputs_index_reduce(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    def make_idx(n, m):
+        return make_tensor((n,), device=device, dtype=torch.int64, low=0, high=m)
+
+    shapes = [((), ()), ((1,), (1,)), ((S, S), (S, M)), ((S, S, S), (S, M, S))]
+    include_selfs = (True, False)
+    reduce = op_info.variant_test_name
+    assert reduce in ('prod', 'mean', 'amin', 'amax')
+
+    for shape, include_self in product(shapes, include_selfs):
+        self_shape, src_shape = shape
+        # dim. We handle the scalar case
+        dim = 1 if len(self_shape) >= 2 else 0
+        idx = make_idx(src_shape[dim] if len(src_shape) != 0 else 1,
+                       self_shape[dim] if len(self_shape) != 0 else 1)
+        args = (dim, idx, make_arg(src_shape), reduce)
+        yield SampleInput(make_arg(self_shape),
+                          args=args,
+                          kwargs={'include_self' : include_self})
+
+    # Sample inputs to test edge cases for backward
+    if requires_grad and reduce == 'prod':
+        # Check that gradients are propagated correctly for prod when zeros in self/src are reduced
+        # This sample tests gradients for the following cases
+        # (a) 1 zero reduced (from source (self[0, 1]), from self (self[0, 0]))
+        # (b) 2 zeros reduced (1 from src and 1 from self (self[1, 0], self[1, 1])
+        # (c) no zeros reduced (self[2, 1], self[2, 2])
+        # (d) 2 zeros reduced (both from src) is tested in test/test_autograd.py
+        #     test_scatter_index_reduce_prod_gradgrad_error as this case is not supported for gradgrad
+        input = torch.tensor([[0, 13], [0, 0], [15, 19]], dtype=dtype, device=device, requires_grad=requires_grad)
+        src = torch.tensor([[2, 0], [0, 0], [2, 3], [2, 2]], dtype=dtype, device=device, requires_grad=requires_grad)
+        idx = torch.tensor([0, 1, 2, 0], dtype=torch.long, device=device)
+
+        yield SampleInput(input,
+                          args=(0, idx, src, reduce),
+                          kwargs={'include_self': True})
+
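+# Samples for the internal _unsafe_masked_index op: one index tensor per
+# dimension (each viewed so they broadcast against one another), a boolean
+# mask marking which index combinations fall in bounds, and a scalar fill value.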
+def sample_inputs__unsafe_masked_index(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    def make_idx(n, m, dim, d):
+        view_shape = [1] * dim
+        view_shape[d] = n
+        return make_tensor((n,), device=device, dtype=torch.int64, low=0, high=m).view(view_shape)
+
+    cases = [
+        ((S, S), S, M),
+        ((S, S), M, S),
+        ((S, S, S), S, M),
+    ]
+
+    fill_value = make_tensor([], dtype=dtype, device="cpu").item()
+
+    for c in cases:
+        self_shape, high, idx_size = c
+        dim = len(self_shape)
+        indices = [make_idx(idx_size, high, dim, d) for d in range(dim)]
+        masks = [torch.logical_and(idx >= 0, idx < self_shape[i]) for i, idx in enumerate(indices) if idx is not None]
+        mask = functools.reduce(torch.logical_and, masks)
+        yield SampleInput(make_arg(self_shape), mask, indices, fill_value)
+
+        masks = [torch.logical_and(idx >= 1, idx < self_shape[i] - 1) for i, idx in enumerate(indices) if idx is not None]
+        mask = functools.reduce(torch.logical_and, masks)
+        yield SampleInput(make_arg(self_shape), mask, indices, fill_value)
+
+def sample_inputs__unsafe_masked_index_put_accumulate(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    def make_idx(n, m, dim, d):
+        view_shape = [1] * dim
+        view_shape[d] = n
+        return make_tensor((n,), device=device, dtype=torch.int64, low=0, high=m).view(view_shape)
+
+    cases = [
+        ((S, S), S, (M, M)),
+        ((S, S), M, (S, S + 1)),
+        ((S, S, S), S, (M, M - 1, M + 1)),
+    ]
+
+    for c in cases:
+        self_shape, high, idx_sizes = c
+        dim = len(self_shape)
+        indices = [make_idx(idx_sizes[d], high, dim, d) for d in range(dim)]
+        masks = [torch.logical_and(idx >= 0, idx < self_shape[i]) for i, idx in enumerate(indices) if idx is not None]
+        mask = functools.reduce(torch.logical_and, masks)
+        values = make_arg(idx_sizes)
+        yield SampleInput(make_arg(self_shape), mask, indices, values)
+
+        masks = [torch.logical_and(idx >= 1, idx < self_shape[i] - 1) for i, idx in enumerate(indices) if idx is not None]
+        mask = functools.reduce(torch.logical_and, masks)
+        yield SampleInput(make_arg(self_shape), mask, indices, values)
+
+
+def sample_inputs_mode(op_info, device, dtype, requires_grad, **kwargs):
+    args = (
+        ((S, S, S), (),),
+        ((S, S, S), (1, ),),
+        ((S, S, S), (1, True, ),),
+        ((), (),),
+        ((), (0,),),
+        ((), (0, True,),),
+        # Non-fused mode kernel on CUDA
+        ((3000,), ()),
+    )
+    make_arg = partial(make_tensor, dtype=dtype, device=device,
+                       requires_grad=requires_grad, low=None, high=None)
+    return (SampleInput(make_arg(input_tensor), *args)
+            for input_tensor, args in args)
+
+# TODO: test the nondeterminism of the operation
+# https://github.com/pytorch/pytorch/issues/53352
+def sample_inputs_put(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+    make_idx = partial(make_tensor, low=0, dtype=torch.int64, device=device, requires_grad=False)
+
+    S = 3
+
+    # Generic inputs
+    idx = torch.randperm(S * S, device=device, dtype=torch.int64)[:S]
+    idx_list = [idx, -idx - 1]
+    for idx, acc in product(idx_list, (True, False)):
+        yield SampleInput(input=make_arg((S, S)),
+                          args=(idx.clone(),
+                                make_arg((S,)),
+                                acc))
+
+    # Scalar cases
+    scalar_sizes = [(), (1,)]
+    tgt_gen = (make_arg(size) for size in scalar_sizes)
+    idx_gen = (make_idx(size, high=1) for size in scalar_sizes)
+    src_gen = (make_arg(size) for size in scalar_sizes)
+    for tgt, idx, src, acc in product(tgt_gen, idx_gen, src_gen, (True, False)):
+        yield SampleInput(input=tgt.clone().requires_grad_(requires_grad),
+                          args=(idx.clone(),
+                                src.clone().requires_grad_(requires_grad),
+                                acc))
+
+    # Empty cases
+    tgt_sizes = [(0,), (), (1,), (3, 2)]
+    tgt_gen = (make_arg(size) for size in tgt_sizes)
+    idx = make_idx((0,), high=1)
+    src = make_arg((0,))
+    for tgt, acc in product(tgt_gen, (True, False)):
+        yield SampleInput(input=tgt.clone().requires_grad_(requires_grad),
+                          args=(idx.clone(),
+                                src.clone().requires_grad_(requires_grad),
+                                acc))
+
+def sample_inputs_take(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+    make_idx = partial(make_tensor, low=0, dtype=torch.int64, device=device, requires_grad=False)
+
+    S = 3
+
+    # Generic inputs: take S elements out of S * S
+    index = make_idx((S,), high=(S * S))
+    for idx in (index, -index - 1):
+        yield SampleInput(input=make_arg((S, S)), args=(idx,))
+
+    # Scalar cases
+    scalar_sizes = [(), (1,)]
+    src_gen = (make_arg(size) for size in scalar_sizes)
+    idx_gen = (make_idx(size, high=1) for size in scalar_sizes)
+    for src, idx in product(src_gen, idx_gen):
+        yield SampleInput(input=src.clone().requires_grad_(requires_grad),
+                          args=(idx.clone(),))
+
+    # Empty cases
+    src_sizes = [(0,), (), (1,), (3, 2)]
+    src_gen = (make_arg(size) for size in src_sizes)
+
+    idx = make_idx((0,), high=1)
+    for src in src_gen:
+        yield SampleInput(input=src.clone().requires_grad_(requires_grad),
+                          args=(idx.clone(),))
+
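+# torch.movedim / torch.moveaxis take (source, destination) dims, given either
+# as ints or as matching-length sequences.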
+def sample_movedim_moveaxis(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad)
+    yield SampleInput(make_arg((4, 3, 2, 1)), [0, 1, 2, 3], [3, 2, 1, 0])
+    yield SampleInput(make_arg((4, 3, 2, 1)), [0, -1, -2, -3], [-3, -2, -1, -0])
+
+def reference_movedim_moveaxis(op_info, device, dtype, requires_grad, **kwargs):
+    yield from sample_movedim_moveaxis(op_info, device, dtype, requires_grad, **kwargs)
+
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # shape, source, destination
+    args = (
+        # empty inputs
+        ((), (), ()),
+        # int inputs, negative
+        ((3, 5, 7, 2), -2, 1),
+        # swap bounds
+        ((3, 5, 7, 2), (-1, 0), (0, -1)),
+        # non-sequential, negative
+        ((2, 3, 4, 5, 6), (3, -3, 4), (1, 0, -1)),
+        # idempotence, negative
+        ((2, 3, 4, 5, 6), (-3, 4, 3, 1), (-3, 4, 3, 1)),
+        # reverse, sequential, positive
+        ((6, 2, 3, 5, 4), (4, 3, 2, 1, 0), (0, 1, 2, 3, 4)),
+        # reverse, non-sequential
+        ((6, 2, 3, 5, 4), (-3, -2, -4, -5, -1), (2, 1, 3, 4, 0)),
+        # reverse, sequential, negative
+        ((6, 2, 3, 5, 4), (4, -2, 2, -4, -5), (-5, 1, 2, -2, -1)),
+    )
+
+    for shape, source, destination in args:
+        yield SampleInput(make_arg(shape), args=(source, destination))
+
+def error_movedim_moveaxis(op_info, device, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
+
+    # source length < destination length
+    yield ErrorInput(
+        SampleInput(make_arg(2, 3, 4, 5, 6), args=((3, -3), (1, 0, -1))),
+        error_regex=(r"movedim: Invalid source or destination dims: source "
+                     r"\(\[3, -3\] dims\) should contain the same number of "
+                     r"dims as destination \(\[1, 0, -1\] dims\)"),
+    )
+
+    # source length > destination length
+    yield ErrorInput(
+        SampleInput(make_arg(2, 3, 4, 5, 6), args=((3, -3, 4), (1, 0))),
+        error_regex=(r"movedim: Invalid source or destination dims: source "
+                     r"\(\[3, -3, 4\] dims\) should contain the same number of "
+                     r"dims as destination \(\[1, 0\] dims\)"),
+    )
+
+    # repeated source dim, with negative indices
+    yield ErrorInput(
+        SampleInput(make_arg(2, 3, 4, 5, 6), args=((0, 4, -5), (1, 0, 2))),
+        error_regex=r"movedim: repeated dim in `source` \(\[0, 4, -5\]\)",
+    )
+
+    # repeated destination dim, with negative indices
+    yield ErrorInput(
+        SampleInput(make_arg(2, 3, 4, 5, 6), args=((1, 0, 2), (0, 4, -5))),
+        error_regex=r"movedim: repeated dim in `destination` \(\[0, 4, -5\]\)",
+    )
+
+    # repeated dim (both), with negative indices
+    yield ErrorInput(
+        SampleInput(make_arg(2, 3, 4, 5, 6), args=((1, 0, -4), (0, 4, -5))),
+        error_regex=r"movedim: repeated dim in `source` \(\[1, 0, -4\]\)",
+    )
+
+    # out of bounds source inputs, with negative indices
+    yield ErrorInput(
+        SampleInput(make_arg(2, 3, 4, 5, 6), args=((0, 1, -6), (1, 4, 2))),
+        error_regex=r"Dimension out of range \(expected to be in range of \[-5, 4\], but got -6\)",
+        error_type=IndexError,
+    )
+
+    # out of bounds destination inputs, with negative indices
+    yield ErrorInput(
+        SampleInput(make_arg(2, 3, 4, 5, 6), args=((1, 4, 2), (0, 1, -6))),
+        error_regex=r"Dimension out of range \(expected to be in range of \[-5, 4\], but got -6\)",
+        error_type=IndexError,
+    )
+
+    # out of bounds source input, int
+    yield ErrorInput(
+        SampleInput(make_arg(2, 3, 4, 5, 6), args=(-6, 1)),
+        error_regex=r"Dimension out of range \(expected to be in range of \[-5, 4\], but got -6\)",
+        error_type=IndexError,
+    )
+
+    # out of bounds destination input, int
+    yield ErrorInput(
+        SampleInput(make_arg(2, 3, 4, 5, 6), args=(3, -6)),
+        error_regex=r"Dimension out of range \(expected to be in range of \[-5, 4\], but got -6\)",
+        error_type=IndexError,
+    )
+
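+# Shared samples for torch.Tensor.repeat and torch.tile; repeat additionally
+# requires len(reps) >= tensor.dim(), which is filtered below.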
+def sample_repeat_tile(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+    rep_dims = ((), (0, ), (1, ), (0, 2), (1, 1), (2, 3), (2, 3, 2), (0, 2, 3), (2, 1, 1, 1),)
+    shapes = ((), (0,), (2,), (3, 0), (3, 2), (3, 0, 1))
+
+    if requires_grad:
+        # Tests for variant_consistency_jit, grad, gradgrad
+        # are slower. Use smaller bags of `rep_dims` and `shapes`
+        # in this case.
+        rep_dims = ((), (0, ), (0, 2), (1, 1), (2, 3), (1, 3, 2), (3, 1, 1))  # type: ignore[assignment]
+        shapes = ((), (0,), (2,), (3, 2))  # type: ignore[assignment]
+
+    is_repeat_op = op_info.name in ['repeat', '_refs.repeat']
+    for rep_dim, shape in product(rep_dims, shapes):
+        # `torch.repeat` errors for `len(rep_dims) < t.dim()`,
+        # so we filter such combinations.
+        if is_repeat_op and len(rep_dim) < len(shape):
+            continue
+        yield SampleInput(make_arg(shape), rep_dim)
+
+
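+# Shared samples for torch.narrow and torch.narrow_copy; positional args are
+# (dim, start, length), and is_narrow additionally covers the Tensor-start overload.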
+def sample_inputs_narrow_narrow_copy(op_info, device, dtype, requires_grad, *, is_narrow, **kwargs):
+    shapes_and_args = (
+        ((S, S, S), 1, 2, 2),
+        ((S, S, S), -1, 2, 2),
+        ((S, S, S), 1, 0, 0),
+        ((S, S, S), -1, 0, 0),
+        ((S, S, S), 2, 1, 2),
+    )
+
+    for shape, dim, start, length in shapes_and_args:
+        tensor = make_tensor(shape, dtype=dtype, device=device, low=None, high=None,
+                             requires_grad=requires_grad)
+        yield SampleInput(tensor, dim, start, length)
+        # narrow also accepts a Tensor for the start argument
+        if is_narrow:
+            yield SampleInput(tensor, dim, torch.tensor(start), length)
+
+def reference_inputs_narrow_narrow_copy(op_info, device, dtype, requires_grad, *, is_narrow, **kwargs):
+    yield from sample_inputs_narrow_narrow_copy(op_info, device, dtype, requires_grad, is_narrow=is_narrow, **kwargs)
+
+    shapes_and_args = (
+        # 1-dim
+        ((M,), 0, 0, 0),    # 0 elems from the left
+        ((M,), -1, -1, 0),  # 0 elems from the right
+        ((M,), 0, 5, 3),    # 3 elems from the left
+        ((M,), 0, -5, 2),   # 2 elems from the right
+        ((M,), -1, 0, M),   # M elems from the left
+        ((M,), 0, -M, M),   # M elems from the right
+
+        # 2-dim
+        ((M, S), 1, 0, 0),    # dim 1, 0 elems from the left
+        ((S, M), -2, -1, 0),  # dim 0, 0 elems from the right
+        ((L, S), 1, 2, 3),    # dim 1, 3 elems from the left
+        ((L, S), -1, 3, 2),   # dim 1, 2 elems from the left
+        ((M, L), 0, 0, M),    # dim 0, M elems from the left
+        ((M, L), -1, -L, L),  # dim 1, L elems from the right
+
+        # 3-dim
+        ((L, M, S), 2, 0, 0),    # dim 2, 0 elems from the left
+        ((M, S, L), -1, -1, 0),  # dim 2, 0 elems from the right
+        ((S, L, M), 2, 0, M),    # dim 2, M elems from the left
+        ((L, S, M), -1, -M, M),  # dim 2, M elems from the right
+        ((S, L, M), 1, 0, 0),    # dim 1, 0 elems from the left
+        ((S, L, M), 0, 2, 1),    # dim 0, 1 elem from the left
+        ((M, S, M), -1, -5, 4),  # dim 2, 4 elems from the right
+    )
+
+    for shape, dim, start, length in shapes_and_args:
+        tensor = make_tensor(shape, dtype=dtype, device=device, low=None, high=None,
+                             requires_grad=requires_grad)
+        yield SampleInput(tensor, dim, start, length)
+        # narrow also accepts a Tensor for the start argument
+        if is_narrow:
+            yield SampleInput(tensor, dim, torch.tensor(start), length)
+
+def error_inputs_narrow_narrow_copy(op_info, device, *, is_narrow, is_ref):
+    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
+
+    # 0-dim
+    yield ErrorInput(SampleInput(make_arg(()), 0, 0, 1),
+                     error_type=RuntimeError,
+                     error_regex=r"narrow\(\) cannot be applied to a 0-dim tensor\.")
+
+    # out of bounds dim
+    if not is_narrow and not is_ref and torch.device(device).type == 'cpu':
+        # narrow_copy_dense_cpu_out
+        yield ErrorInput(SampleInput(make_arg((M, S, L)), 3, 0, 0),
+                         error_type=RuntimeError,
+                         error_regex=r"Expected dim < static_cast\(self_sizes.size\(\)\) to be true, but got false\.")
+    else:
+        yield ErrorInput(SampleInput(make_arg((M, S, L)), 3, 0, 0),
+                         error_type=IndexError,
+                         error_regex=r"Dimension out of range \(expected to be in range of \[-3, 2\], but got 3\)")
+    # out of bounds dim (negative)
+    yield ErrorInput(SampleInput(make_arg((L, S, M)), -4, 0, 0),
+                     error_type=IndexError,
+                     error_regex=r"Dimension out of range \(expected to be in range of \[-3, 2\], but got -4\)")
+
+    # out of bounds start
+    yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, M + 1, 0),
+                     error_type=IndexError,
+                     error_regex=r"start out of range \(expected to be in range of \[-10, 10\], but got 11\)")
+    # out of bounds start (negative)
+    yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, -M - 1, 0),
+                     error_type=IndexError,
+                     error_regex=r"start out of range \(expected to be in range of \[-10, 10\], but got -11\)")
+
+    # out of bounds length
+    yield ErrorInput(SampleInput(make_arg((S, L, M)), 2, 0, M + 1),
+                     error_type=RuntimeError,
+                     error_regex=r"start \(0\) \+ length \(11\) exceeds dimension size \(10\)\.")
+    # out of bounds length (negative)
+    if not is_narrow and not is_ref and torch.device(device).type == 'cpu':
+        # narrow_copy_dense_cpu_out
+        yield ErrorInput(SampleInput(make_arg((M,)), 0, 0, -1),
+                         error_type=RuntimeError,
+                         error_regex=r"start \(0\) \+ length \(-1\) exceeds dimension size \(10\)\.")
+    else:
+        yield ErrorInput(SampleInput(make_arg((M,)), 0, 0, -1),
+                         error_type=RuntimeError,
+                         error_regex=r"narrow\(\): length must be non-negative\.")
+
+    # Test the Tensor overload that was added for XLA. Start must be a 0-dim
+    # integral Tensor. narrow_copy doesn't have this overload.
+    # https://github.com/pytorch/pytorch/issues/31558
+    if is_narrow:
+        # *1-dim* integral Tensor
+        yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, make_arg(S, dtype=torch.int), 2),
+                         error_type=RuntimeError,
+                         error_regex=r"start must be an 0-dim integral Tensor\.")
+
+        # 0-dim *bool* Tensor (bools are not allowed)
+        yield ErrorInput(SampleInput(make_arg((L, M, S)), -3, make_arg((), dtype=torch.bool), 3),
+                         error_type=RuntimeError,
+                         error_regex=r"start must be an 0-dim integral Tensor\.")
+
+
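+# Samples for the trapezoid integration ops: y is paired with either an
+# explicit tensor of sample points x (possibly broadcastable against y) or a
+# uniform spacing given via dx.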
+def sample_trapezoid(op_info, device, dtype, requires_grad, **kwargs):
+    y_shape_x_shape_and_kwargs = [
+        ((2, 3), (2, 3), {}),
+        ((2, 3), (2, 3), {'dim': 1}),
+        ((6,), (6,), {}),
+        ((6,), None, {}),
+        # When 'trapezoid' is called with an empty input, it does not produce an output with requires_grad
+        # See Issue #61619
+        # ((6,0), (6,0), {}),
+        ((2, 3), (1, 3), {}),
+        ((3, 3), (3, 3), {}),
+        ((3, 3), (3, 3), {'dim': -2}),
+        ((5,), None, {'dx': 2.0}),
+        ((2, 2), None, {'dx': 3.0})
+    ]
+    make_arg = partial(make_tensor, dtype=dtype, device=device, low=None, high=None,
+                       requires_grad=requires_grad)
+    for y_shape, x_shape, kwarg in y_shape_x_shape_and_kwargs:
+        y_tensor = make_arg(y_shape)
+        if x_shape is not None:
+            x_tensor = make_arg(x_shape)
+            yield SampleInput(y_tensor, x_tensor, **kwarg)
+        else:
+            yield SampleInput(y_tensor, **kwarg)
+
+def sample_cumulative_trapezoid(op_info, device, dtype, requires_grad, **kwargs):
+
+    y_shape_x_shape_and_kwargs = [
+        ((2, 3), (2, 3), {}),
+        ((2, 3), (2, 3), {'dim': 1}),
+        ((6,), (6,), {}),
+        ((6,), None, {}),
+        # When 'cumulative_trapezoid' is called with an empty input, it does not produce an output with requires_grad
+        # See Issue #61619
+        # ((6,0), (6,0), {}),
+        ((2, 3), (1, 3), {}),
+        ((3, 3), (3, 3), {}),
+        ((3, 3), (3, 3), {'dim': -2}),
+        ((5,), None, {'dx': 2.0}),
+        ((2, 2), None, {'dx': 3.0})
+    ]
+    make_arg = partial(make_tensor, device=device, dtype=dtype,
+                       requires_grad=requires_grad, low=None, high=None)
+    for y_shape, x_shape, kwarg in y_shape_x_shape_and_kwargs:
+        y_tensor = make_arg(y_shape)
+        if x_shape is not None:
+            x_tensor = make_arg(x_shape)
+            yield SampleInput(y_tensor, x_tensor, **kwarg)
+        else:
+            yield SampleInput(y_tensor, **kwarg)
+
+def sample_unsqueeze(op_info, device, dtype, requires_grad, **kwargs):
+    shapes_and_axes = [
+        ((3, 4, 5), 0),
+        ((3, 4, 5), 1),
+        ((3, 4, 5), 3),
+        ((3, 4, 5), -1),
+        ((3, 4, 5), -3),
+        ((), 0),
+        ((), -1),
+        ((1,), 0),
+        ((1,), -1),
+    ]
+
+    for shape, axis in shapes_and_axes:
+        tensor = make_tensor(shape, dtype=dtype, device=device, low=None, high=None,
+                             requires_grad=requires_grad)
+        yield SampleInput(tensor, axis)
+
+
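+# Samples for F.unfold (im2col): positional args are
+# (kernel_size, dilation, padding, stride); includes a zero-batch input.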
+def sample_inputs_nn_unfold(op_info, device, dtype, requires_grad, **kwargs):
+    shapes = ((0, 1, 5, 5), (2, 3, 5, 5))
+    kernel_sizes = (2, (2, 2), (2, 3))
+    dilations = (1, 2, (1, 2))
+    paddings = (0, 1, (1, 2))
+    strides = (1, 2, (1, 2))
+
+    cases = product(shapes, kernel_sizes, dilations, paddings, strides)
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    for shape, kernel_size, dilation, padding, stride in cases:
+        tensor = make_arg(shape)
+        yield SampleInput(tensor, kernel_size, dilation, padding, stride)
+
+    # With default args
+    yield SampleInput(make_arg((1, 1, 5, 5)), (3, 3))
+
+
+def sample_inputs_squeeze(op_info, device, dtype, requires_grad, **kwargs):
+    shapes_and_args = (
+        ((S, 1, S, 1), ()),
+        ((1, 1, 1, 1), ()),
+        ((1, 1, 1, 1), (0,)),
+        ((S, 1, S, 1), (1,)),
+        ((S, 1, S, 1), (-1,)),
+        ((S, 1, S, 1), (2,)),
+        ((S, 1, S, 1), (-2,)),
+        ((), (0, )),
+    )
+
+    for shape, args in shapes_and_args:
+        tensor = make_tensor(shape, dtype=dtype, device=device, low=None, high=None,
+                             requires_grad=requires_grad)
+
+        yield SampleInput(tensor, args=args)
+
+
+def sample_inputs_squeeze_multiple(op_info, device, dtype, requires_grad, **kwargs):
+    shapes_and_args = (
+        ((1, 1, 1, 1), ()),
+        ((S, 1, S, 1), (1,)),
+        ((S, 1, S, 1), (-1,)),
+        ((S, 1, S, 1), (1, 3)),
+        ((S, 1, S, 1), (1, 2,)),
+        ((), (0,)),
+    )
+
+    for shape, dims in shapes_and_args:
+        tensor = make_tensor(shape, dtype=dtype, device=device, low=None, high=None,
+                             requires_grad=requires_grad)
+
+        yield SampleInput(tensor, dims)
+
+
+def _squeeze_ref(x, axis=None):
+    # NumPy errors when squeezing a 0-dim array along an explicit axis,
+    # whereas torch treats it as a no-op
+    if x.ndim == 0:
+        return x
+
+    if isinstance(axis, Sequence):
+        # NumPy doesn't allow specifying non-singleton (size != 1) dimensions
+        axis = tuple(a for a in axis if x.shape[a] == 1)
+
+    if isinstance(axis, int) and x.shape[axis] != 1:
+        return x
+
+    return np.squeeze(x, axis)
+
+def sample_inputs_nn_pad(op_info, device, dtype, requires_grad, mode, **kwargs):
+    assert mode in ('constant', 'reflect', 'replicate', 'circular')
+    if mode in ['reflect', 'replicate']:
+        cases: tuple = (  # ignore
+            ((1, 3), (1, 2)),
+            ((1, 3), (0, 1)),
+            ((0, 3, 3), (1, 2)),
+            ((0, 3, 3), (0, 1)),
+            ((1, 3, 3), (1, 2)),
+            ((1, 3, 3), (0, 1)),
+            ((1, 3, 3), (0, 2, 0, 1)),
+            ((0, 3, 3, 3), (0, 2, 0, 1)),
+            ((3, 3, 5, 5), (0, 2, 0, 1)),
+            ((3, 3, 5, 5), (1, 1, 1, 1, 1, 1)),
+            ((1, 3, 3, 3, 3), (1, 1, 1, 1, 1, 1)),
+            ((1, 3, 4, 4), (-1, 1, -2, 1)),
+        )
+    elif mode == 'constant':
+        cases = (
+            ((1, 3), (1, 2)),
+            ((1, 3), (0, 1)),
+            ((1, 3), (0, 2, 0, 1)),
+            ((0, 3, 3), (1, 2)),
+            ((0, 3, 3), (0, 1)),
+            ((0, 3, 3), (0, 2, 0, 1)),
+            ((0, 3, 3), (1, 1, 1, 1, 1, 1)),
+            ((1, 3, 3), (1, 2)),
+            ((1, 3, 3), (0, 1)),
+            ((1, 3, 3), (0, 2, 0, 1)),
+            ((1, 3, 3), (1, 1, 1, 1, 1, 1)),
+            ((0, 3, 3, 3), (1, 2)),
+            ((0, 3, 3, 3), (0, 1)),
+            ((0, 3, 3, 3), (0, 2, 0, 1)),
+            ((0, 3, 3, 3), (1, 1, 1, 1, 1, 1)),
+            ((3, 3, 5, 5), (1, 2)),
+            ((3, 3, 5, 5), (0, 1)),
+            ((3, 3, 5, 5), (0, 2, 0, 1)),
+            ((3, 3, 5, 5), (1, 1, 1, 1, 1, 1)),
+            ((1, 3, 3, 3, 3), (1, 2)),
+            ((1, 3, 3, 3, 3), (0, 1)),
+            ((1, 3, 3, 3, 3), (0, 2, 0, 1)),
+            ((1, 3, 3, 3, 3), (1, 1, 1, 1, 1, 1)),
+            ((1, 3, 4, 4), (-1, 1, -2, 1)),
+        )
+    else:  # mode == 'circular'
+        if dtype == torch.bool:
+            # test_dtypes fails on ASAN for this case with a
+            # runtime error: load of value 190, which is not a valid value for type 'bool'
+            # Reference: https://github.com/pytorch/pytorch/pull/62814#issuecomment-894156562
+            # Reference Issue: https://github.com/pytorch/pytorch/issues/63034
+            cases = (
+                ((2, 3, 3), (1, 2)),
+                ((1, 3, 3), (1, 2)),
+            )
+        else:
+            cases = (
+                ((0, 3, 3), (1, 2)),
+                ((0, 3, 3), (0, 1)),
+                ((1, 3, 3), (1, 2)),
+                ((1, 3, 3), (0, 1)),
+                ((0, 3, 3, 3), (0, 2, 0, 1)),
+                ((3, 3, 5, 5), (0, 2, 0, 1)),
+                ((1, 3, 3, 3, 3), (1, 1, 1, 1, 1, 1)),
+                ((1, 3, 4, 4), (-1, 1, -2, 1)),
+            )
+
+    make_inp = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    if mode == 'constant':
+        # Default args
+        yield SampleInput(make_inp((1, 3, 3)), args=((2, 2),))
+
+    if mode in ['reflect', 'replicate', 'circular']:
+        for shape, pad in cases:
+            yield SampleInput(make_inp(shape), args=(pad, mode))
+    else:  # mode == 'constant'
+        for pad_value in (1., 2.):
+            for shape, pad in cases:
+                yield SampleInput(make_inp(shape), args=(pad, mode, pad_value))
+
+def sample_inputs_nn_pad_replicate_negative(op_info, device, dtype, requires_grad, **kwargs):
+    cases: tuple = (
+        ((5, 3, 4, 4), (-4, 5, 0, 0)),
+        ((6, 2, 4, 4), (0, 0, 2, -4)),
+        ((5, 6, 4, 4), (5, -4, -4, 3)),
+        ((4, 2, 5, 5), (-2, -1, 4, 6)),
+        ((2, 6, 5, 5), (8, -1, -1, -3)),
+        ((8, 1, 5, 5), (-2, -1, -1, -3)),
+    )
+    make_inp = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    for shape, pad in cases:
+        yield SampleInput(make_inp(shape), args=(pad, 'replicate'))
+
+def sample_inputs_constant_pad_nd(op_info, device, dtype, *args, **kwargs):
+    # Inherit sample inputs from nn.pad, but transform them to fit
+    # constant_pad_nd's interface
+    nn_samples = sample_inputs_nn_pad(op_info, device, dtype, *args,
+                                      mode='constant', **kwargs)
+
+    # NOTE: primTorch is more strict about the type of the fill value argument
+    # So we must cast it to the correct dtype
+    from torch._prims_common import dtype_to_type
+    scalar_type = dtype_to_type(dtype)
+
+    def drop_mode_argument(input, pad, mode=None, value=None):
+        if value is None:
+            return SampleInput(input, args=(pad,))
+        else:
+            return SampleInput(input, args=(pad, scalar_type(value)))
+
+    for sample in nn_samples:
+        yield drop_mode_argument(sample.input, *sample.args, **sample.kwargs)
+
+def sample_inputs_repeat_interleave(op_info, device, dtype, requires_grad, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    yield SampleInput(make_input(()), repeats=2)
+    yield SampleInput(make_input((2, 3, 4)), repeats=2)
+    yield SampleInput(make_input((2, 3, 4)), repeats=2, dim=1)
+    yield SampleInput(make_input((2, 3, 4)), repeats=torch.arange(3, device=device), dim=1)
+
+
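+# Samples for torch.stft: cover return_complex, center, hop_length,
+# explicit windows, and onesided=False for real inputs.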
+def sample_inputs_stft(op_info, device, dtype, requires_grad, **kwargs):
+    def mt(shape, **kwargs):
+        return make_tensor(shape, device=device, dtype=dtype,
+                           requires_grad=requires_grad, **kwargs)
+
+    yield SampleInput(mt(100), n_fft=10, return_complex=True)
+    yield SampleInput(mt(100), n_fft=10, return_complex=False)
+    if dtype.is_complex:
+        yield SampleInput(mt(100), n_fft=10)
+
+    for center in [False, True]:
+        yield SampleInput(mt(10), n_fft=7, center=center, return_complex=True)
+        yield SampleInput(mt((10, 100)), n_fft=16, hop_length=4,
+                          center=center, return_complex=True)
+
+    window = mt(16, low=.5, high=2.0)
+    yield SampleInput(
+        mt((2, 100)), kwargs=dict(n_fft=16, window=window, return_complex=True, center=center))
+    yield SampleInput(
+        mt((3, 100)), kwargs=dict(n_fft=16, window=window, return_complex=True, center=center))
+    if not dtype.is_complex:
+        yield SampleInput(
+            mt((10, 100)), n_fft=16, window=window, onesided=False,
+            return_complex=True)
+
+
+def sample_inputs_istft(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    def mt(shape, **kwargs):
+        real_shape = shape if dtype.is_complex else shape + (2,)
+        return make_arg(real_shape, **kwargs)
+
+    yield SampleInput(mt((10, 2)), kwargs=dict(n_fft=10))
+    yield SampleInput(mt((6, 3)), kwargs=dict(n_fft=6, onesided=False))
+    yield SampleInput(mt((6, 4)), kwargs=dict(n_fft=10, onesided=True))
+
+    for center in [False, True]:
+        yield SampleInput(mt((10, 10, 6)), kwargs=dict(n_fft=10, center=center))
+        yield SampleInput(mt((1, 9, 10)), kwargs=dict(n_fft=16, hop_length=4, center=center))
+
+    window = make_arg(10, low=.5, high=2.0)
+    yield SampleInput(mt((10, 10, 6)), kwargs=dict(
+        n_fft=10, window=window, center=center, return_complex=dtype.is_complex))
+    yield SampleInput(mt((10, 10, 10)), kwargs=dict(
+        n_fft=10, window=window[:8], win_length=8, center=center, return_complex=True))
+
+    real_window = window if not dtype.is_complex else window.real
+    yield SampleInput(mt((10, 5, 6)), kwargs=dict(n_fft=8, window=real_window[:8], center=center))
+
+def sample_inputs_ormqr(op_info, device, dtype, requires_grad, **kwargs):
+    # create a helper function wrapping `make_tensor`
+    make_input = partial(make_tensor, dtype=dtype, device=device, low=-1, high=1)
+
+    batches = [(), (0, ), (2, ), (2, 1)]
+    ns = [5, 2, 0]
+    tf = [True, False]
+    for batch, (m, n), left, transpose in product(batches, product(ns, ns), tf, tf):
+        input = make_input((*batch, m, n))
+        reflectors, tau = torch.geqrf(input)
+        reflectors.requires_grad_(requires_grad)
+        tau.requires_grad_(requires_grad)
+        other_matrix_shape = (m, n) if left else (n, m)
+        other = make_input((*batch, *other_matrix_shape), requires_grad=requires_grad)
+        yield SampleInput(reflectors, tau, other, left=left, transpose=transpose)
+
+
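+# torch.cholesky_solve(B, factor): reuse the factor matrices from the
+# cholesky_inverse samples as the second argument and pair each with a random
+# right-hand side of the same shape.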
+def sample_inputs_cholesky_solve(op_info, device, dtype, requires_grad=False, **kwargs):
+    cholesky_inverse_samples = sample_inputs_linalg_cholesky_inverse(
+        op_info, device, dtype, requires_grad=False
+    )
+
+    for sample in cholesky_inverse_samples:
+        psd_matrix = sample.input
+        sample.input = make_tensor(psd_matrix.shape, dtype=dtype, device=device, requires_grad=requires_grad, low=None, high=None)
+        sample.args = (psd_matrix.requires_grad_(requires_grad),)
+        yield sample
+
+
+def sample_inputs_lu(op_info, device, dtype, requires_grad=False, **kwargs):
+    make_arg = partial(make_fullrank_matrices_with_distinct_singular_values,
+                       dtype=dtype, device=device, requires_grad=requires_grad)
+
+    # not needed once OpInfo tests support Iterables
+    batch_shapes = ((), (3,), (3, 3))
+    for batch_shape, get_infos, size_delta in product(batch_shapes, (True, False), (-2, -1, 0, +1, +2)):
+        shape = batch_shape + (S + size_delta, S)
+        input = make_arg(*shape)
+        yield SampleInput(input, args=(True, get_infos))
+
+
+def sample_inputs_lu_unpack(op_info, device, dtype, requires_grad=False, **kwargs):
+    def out_fn(output):
+        return output[1], output[2]
+
+    for lu_sample in sample_inputs_linalg_lu(op_info, device, dtype, requires_grad, **kwargs):
+        lu_data, pivots = torch.linalg.lu_factor(lu_sample.input)
+        lu_data.requires_grad_(requires_grad)
+        yield SampleInput(lu_data, pivots).with_metadata(output_process_fn_grad=out_fn)
+
+
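+# torch.roll takes (shifts, dims); the cases below include shifts larger than
+# the dimension size and rolling a flattened tensor (no dims given).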
+def sample_inputs_roll(op_info, device, dtype, requires_grad=False, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    args = ((0, 0), (1, 2), (0, 2), (2, 0), (-1, 0), (10000, 1), (2,), ((1, 2, -1), (0, 1, 2)))
+
+    for arg in args:
+        yield SampleInput(make_arg((0, 0, 0)), args=arg)
+        yield SampleInput(make_arg((S, S, S)), args=arg)
+
+    # Scalar tensor
+    yield SampleInput(make_arg(()), args=(10, ))
+
+def error_inputs_roll(op_info, device, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
+    err_msg1 = "`shifts` required"
+    s1 = SampleInput(make_arg((S,)), ())
+    yield ErrorInput(s1, error_regex=err_msg1)
+
+    err_msg2 = ("shifts and dimensions must align")
+    s2 = SampleInput(make_arg((S, S)), (2, 1), 0)
+    yield ErrorInput(s2, error_regex=err_msg2)
+
+    err_msg3 = ("out of range")
+    s3 = SampleInput(make_arg((S, )), 0, 2)
+    yield ErrorInput(s3, error_regex=err_msg3, error_type=IndexError)
+
+    err_msg4 = ("Dimension specified as 0")
+    s4 = SampleInput(make_arg(()), 0, 0)
+    yield ErrorInput(s4, error_regex=err_msg4, error_type=IndexError)
+
+def sample_inputs_rot90(op_info, device, dtype, requires_grad=False, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    args = itertools.product(range(-5, 6), [(0, 1), (1, 2), (1, -1)])
+
+    yield SampleInput(make_arg((S, S, S)))
+    for arg in args:
+        yield SampleInput(make_arg((S, S, S)), args=arg)
+
+
+def error_inputs_rot90(op_info, device, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
+    err_msg1 = "expected total rotation dims"
+    s1 = SampleInput(make_arg((S, S)), dims=(0,))
+    yield ErrorInput(s1, error_regex=err_msg1)
+
+    err_msg2 = "expected total dims >= 2"
+    s2 = SampleInput(make_arg((S,)))
+    yield ErrorInput(s2, error_regex=err_msg2)
+
+    err_msg3 = "expected rotation dims to be different"
+    s3 = SampleInput(make_arg((S, S)), dims=(1, 1))
+    yield ErrorInput(s3, error_regex=err_msg3)
+
+
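+# Shared samples for the std/var family, covering dim, keepdim, the legacy
+# unbiased flag, and the correction= keyword (including fractional values).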
+def sample_inputs_std_var(op_info, device, dtype, requires_grad, **kwargs):
+    tensor_nd = partial(make_tensor, (S, S, S), device=device, dtype=dtype,
+                        requires_grad=requires_grad)
+    tensor_1d = partial(make_tensor, (S,), device=device, dtype=dtype,
+                        requires_grad=requires_grad)
+
+    yield SampleInput(tensor_nd())
+    yield SampleInput(tensor_nd(), dim=1)
+    yield SampleInput(tensor_nd(), dim=1, unbiased=True, keepdim=True)
+    yield SampleInput(tensor_1d(), dim=0, unbiased=True, keepdim=True)
+    yield SampleInput(tensor_1d(), dim=0, unbiased=False, keepdim=False)
+
+    yield SampleInput(tensor_nd(), dim=(1,), correction=1.3)
+    yield SampleInput(tensor_nd(), dim=(1,), correction=S // 2)
+    yield SampleInput(tensor_nd(), dim=None, correction=0, keepdim=True)
+    yield SampleInput(tensor_nd(), dim=None, correction=None)
+    yield SampleInput(tensor_nd(), correction=0, keepdim=True)
+    yield SampleInput(make_tensor(3, 4, 5, device=device, dtype=dtype, requires_grad=requires_grad), dim=-3)
+
+
+def sample_inputs_std_var_unbiased(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype,
+                       requires_grad=requires_grad)
+
+    # Test var_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)
+    yield SampleInput(make_arg((S, S)), True)
+    yield SampleInput(make_arg((S,)), False)
+
+
+def _generate_correlation_inputs(device, dtype, requires_grad, **kwargs):
+    shapes = [(2,), (1, 2), (3, 2), (2, 3)]
+    for shape in shapes:
+        yield make_tensor(shape, dtype=dtype, device=device, requires_grad=requires_grad)
+
+
+def sample_inputs_corrcoef(op_info, device, dtype, requires_grad, **kwargs):
+    return (SampleInput(t) for t in _generate_correlation_inputs(device, dtype, requires_grad))
+
+def sample_inputs_copysign(op_info, device, dtype, requires_grad, **kwargs):
+    yield from sample_inputs_elementwise_binary(op_info, device, dtype, requires_grad, **kwargs)
+    if dtype.is_floating_point:
+        yield SampleInput(make_tensor(5, dtype=dtype, device=device, requires_grad=requires_grad), -3.14)
+
+
+def sample_inputs_cov(op_info, device, dtype, requires_grad, **kwargs):
+    for t in _generate_correlation_inputs(device, dtype, requires_grad):
+        yield SampleInput(t)
+        num_observations = t.numel() if t.ndimension() < 2 else t.size(1)
+        fweights = make_tensor((num_observations,), dtype=torch.int, device=device, low=1, high=10)
+        aweights = make_tensor((num_observations,), dtype=torch.float, device=device, low=0, high=1, requires_grad=requires_grad)
+        for correction, fw, aw in product(range(num_observations), [None, fweights], [None, aweights]):
+            yield SampleInput(t.clone().requires_grad_(requires_grad),
+                              correction=correction, fweights=fw, aweights=aw)
+
+
+def error_inputs_cov(op_info, device, **kwargs):
+    a = torch.rand(S, device=device)
+    yield ErrorInput(
+        SampleInput(torch.rand(S, S, S, device=device)),
+        error_regex="expected input to have two or fewer dimensions")
+    yield ErrorInput(
+        SampleInput(a, fweights=torch.rand(S, S, device=device)),
+        error_regex="expected fweights to have one or fewer dimensions")
+    yield ErrorInput(
+        SampleInput(a, aweights=torch.rand(S, S, device=device)),
+        error_regex="expected aweights to have one or fewer dimensions")
+    yield ErrorInput(
+        SampleInput(a, fweights=torch.rand(S, device=device)),
+        error_regex="expected fweights to have integral dtype")
+    yield ErrorInput(
+        SampleInput(a, aweights=torch.tensor([1, 1], device=device)),
+        error_regex="expected aweights to have floating point dtype")
+    yield ErrorInput(
+        SampleInput(a, fweights=torch.tensor([1], device=device)),
+        error_regex="expected fweights to have the same numel")
+    yield ErrorInput(
+        SampleInput(a, aweights=torch.rand(1, device=device)),
+        error_regex="expected aweights to have the same numel")
+    yield ErrorInput(
+        SampleInput(a, fweights=torch.tensor([-1, -2, -3, -4 , -5], device=device)),
+        error_regex="fweights cannot be negative")
+    yield ErrorInput(
+        SampleInput(a, aweights=torch.tensor([-1., -2., -3., -4., -5.], device=device)),
+        error_regex="aweights cannot be negative")
+
+
+def sample_inputs_permute(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    cases = [((1, 2, 3, 4), (0, 2, 3, 1)),
+             ((1, 2, 3, 4), (0, -2, -1, 1)),
+             ((), ()),
+             ((1, 2, 3, 4), (2, 1, 3, 0))]
+
+    for shape, args in cases:
+        yield SampleInput(make_arg(shape), args=(args,))
+
+def reference_inputs_permute(op, device, dtype, requires_grad, **kwargs):
+    yield from sample_inputs_permute(op, device, dtype, requires_grad, **kwargs)
+
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    cases = (
+        ((), ()),
+        ((1,), (0,)),
+        ((2, 2), (1, 0)),
+        ((2, 2), (0, 1)),
+        ((2, 0, 1), (0, 2, 1)),
+        ((3, 4, 2), (2, 1, 0)),
+        ((3, 4, 2), (1, 0, 2)),
+        ((3, 4, 2), (0, 1, 2)),
+    )
+
+    # Adds tricky permutations and permutations with noncontiguity
+    for shape, permutation in cases:
+        for p in itertools.permutations(permutation):
+            a = make_arg(shape).permute(p)
+            yield SampleInput(a, args=(permutation,))
+
+            a = make_arg(shape, noncontiguous=True).permute(p)
+            yield SampleInput(a, args=(permutation,))
+
+def error_inputs_softshrink(op, device, **kwargs):
+    yield ErrorInput(SampleInput(make_tensor((1,), dtype=torch.float, device=device), kwargs={"lambd": -0.5}),
+                     error_regex="lambda must be greater or equal to 0, but found to be -0.5")
+
+def sample_inputs_softshrink(op_info, device, dtype, requires_grad=False, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # The additional samples check values of lambd beyond the default
+    # (the default is already covered by sample_inputs_elementwise_unary)
+    for lbda in (0., 0.5):
+        yield SampleInput(make_arg(S, S), kwargs={"lambd": lbda})
+
+    yield from sample_inputs_elementwise_unary(op_info, device, dtype, requires_grad)
+
+def sample_inputs_hardshrink(op_info, device, dtype, requires_grad=False, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # The additional samples check values of lambd beyond the default
+    # (the default is already covered by sample_inputs_elementwise_unary).
+    # Note that unlike softshrink, lambd is allowed to be negative for hardshrink
+    for lbda in (-0.5, 0., 0.5):
+        yield SampleInput(make_arg(S, S), kwargs={"lambd": lbda})
+
+    yield from sample_inputs_elementwise_unary(op_info, device, dtype, requires_grad)
+
+
+def sample_inputs_hardtanh(op_info, device, dtype, requires_grad=False, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # The additional samples check values of min_val and max_val beyond the defaults
+    # (the defaults are already covered by sample_inputs_elementwise_unary)
+    for max_val, min_val in ((0.5, -0.5), (0., 0.)):
+        yield SampleInput(make_arg(S, S), kwargs={"min_val": min_val, "max_val": max_val})
+
+    yield from sample_inputs_elementwise_unary(op_info, device, dtype, requires_grad)
+
+def error_inputs_hardtanh(op_info, device, **kwargs):
+    # Tests that hardtanh errors out when passed min_val > max_val.
+    yield ErrorInput(SampleInput(make_tensor((1,), dtype=torch.float, device=device), kwargs={"min_val": 0.5, "max_val": -0.5}),
+                     error_type=ValueError, error_regex="min_val cannot be greater than max_val")
+
+def sample_inputs_einsum(op_info, device, dtype, requires_grad=False, **kwargs):
+    def c(t):
+        return t.clone().requires_grad_(requires_grad)
+
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    x = make_arg((3,))
+    y = make_arg((4,))
+    A = make_arg((2, 3,))
+    B = make_arg((1, 3,))
+    C = make_arg((1, 2, 3,))
+    D = make_arg((1, 3, 4,))
+    E = make_arg((4, 4,))
+    H = make_arg((3, 3,))
+    I = make_arg((1, 3, 1,))
+
+    # Vector operations
+    yield SampleInput([c(x)], 'i->')                      # sum
+    yield SampleInput([c(x), c(y)], 'i,j->ij')            # outer
+
+    # Matrix operations
+    yield SampleInput([c(A)], "ij->i")                    # col sum
+    yield SampleInput([c(A), c(B)], "ij,kj->ik")          # matmul
+    yield SampleInput([c(A), c(E)], "ij,Ab->ijAb")        # matrix outer product
+
+    # Tensor operations
+    yield SampleInput([c(C), c(D)], "aij,ajk->aik")       # batch matmul
+    yield SampleInput([c(D), c(E)], "aij,jk->aik")        # tensor matrix contraction
+    yield SampleInput([c(C), c(B)], "ijk,ik->j")          # non contiguous
+
+    # Test diagonals
+    yield SampleInput([c(I)], 'iji->j')                   # non-contiguous trace
+
+    # Test ellipsis
+    yield SampleInput([c(H)], "i...->...")
+    yield SampleInput([c(C), c(x)], '...ik, ...j -> ij')
+
+
+def sample_inputs_flip(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+    sizes = ((S, M, S), (S, 0, M))
+    all_dims = ((0, 1, 2), (0,), (0, 2), (-1,), ())
+
+    for size, dims in product(sizes, all_dims):
+        yield SampleInput(make_arg(size), kwargs={"dims": dims})
+
+def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad, **kwargs):
+    shapes = [
+        (S, M, S),
+        (S, 0, M),
+    ]
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+    return (SampleInput(make_arg(shape, low=None, high=None)) for shape in shapes)
+
+def error_inputs_fliplr(op, device, **kwargs):
+    yield ErrorInput(SampleInput(make_tensor((1,), dtype=torch.float, device=device)),
+                     error_regex="Input must be >= 2-d.")
+
+def error_inputs_flipud(op, device, **kwargs):
+    yield ErrorInput(SampleInput(make_tensor((), dtype=torch.float, device=device)),
+                     error_regex="Input must be >= 1-d.")
+
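+# torch.clamp with tensor bounds: broadcastable min/max tensors, one-sided
+# bounds (None for min or max), and integer bounds to exercise type promotion.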
+def sample_inputs_clamp(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad)
+    make_integral_arg = partial(make_tensor, dtype=torch.int32, device=device, low=None, high=None, requires_grad=False)
+    shape = (S, M, S)
+
+    yield SampleInput(make_arg(shape), args=(make_arg(shape), make_arg(shape)))
+    yield SampleInput(make_arg(shape), args=(make_arg(shape[1:]), make_arg(shape[1:])))
+    yield SampleInput(make_arg(shape), args=(make_arg((S, 1, S)),))
+    yield SampleInput(make_arg(shape), args=(None, make_arg(shape)))
+    yield SampleInput(make_arg(shape), args=(make_arg(shape), None))
+    # test type promotion
+    yield SampleInput(make_arg(shape), args=(make_integral_arg(shape), None))
+    yield SampleInput(make_arg(shape), args=(make_arg(shape), make_integral_arg(shape)))
+
+def reference_inputs_elementwise_ternary(op, device, dtype, requires_grad, *, sample_inputs_func, supports_scalars=False, **kwargs):
+    yield from sample_inputs_func(op, device, dtype, requires_grad, **kwargs)
+
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_scalar_tensor = partial(make_tensor, (), device='cpu', dtype=dtype, requires_grad=requires_grad)
+    supported_dtypes = op.supported_dtypes(device)
+
+    # broadcasting and noncontiguous cases
+    cases = (
+        ((4, 4), (4, 4), (4, 4)),
+        ((4, 4), (1, 4, 4), (4, 4)),
+        ((4, 4), (1, 4, 4), (4, 1, 4)),
+        ((4, 4, 1), (1, 4, 4), (4, 4)),
+        ((4, 1), (1, 4, 4), (1, 4)),
+        ((4, 4), (), (4, 4)),
+        ((4, 4), (), ()),
+        ((), (4, 4), (1, 4, 4)),
+    )
+
+    for a, b, c in cases:
+        yield SampleInput(make_arg(a), args=(make_arg(b), make_arg(c)))
+        yield SampleInput(make_arg(a, noncontiguous=True),
+                          args=(make_arg(b).transpose(0, -1), make_arg(c, noncontiguous=True).transpose(0, -1)))
+
+    # scalar cases
+    if supports_scalars:
+        cases = [
+            ((), 1, 2,),
+            ((), 1., 2),
+            ((4, 4), 1., 2,),
+            ((3, 4), make_scalar_tensor(), make_scalar_tensor()),
+        ]
+
+        if torch.complex64 in supported_dtypes:
+            cases.extend([
+                ((3, 1, 4), complex(1, 2), 3.),
+            ])
+
+        for a, b, c in cases:
+            yield SampleInput(make_arg(a), args=(b, c))
+
+    # type promotion cases
+    # int x float
+    if torch.float in supported_dtypes and torch.long in supported_dtypes:
+        a = make_arg((), dtype=torch.long)
+        b = make_arg((1, 4), dtype=torch.float)
+        c = make_arg((3, 4))
+
+        cases = (
+            (a, b, c),
+            (c, a, b),
+        )
+
+        for a, b, c in cases:
+            yield SampleInput(a, args=(b, c))
+
+    # NaN propagation
+    if dtype.is_floating_point or dtype.is_complex:
+        nan = float('nan') if dtype.is_floating_point else complex(float('nan'), float('nan'))
+
+        a = make_arg((12,))
+        a[4] = nan
+        a[7] = nan
+        b = make_arg((12,))
+        b[1] = nan
+        b[7] = nan
+        c = make_arg((12,))
+        c[9] = nan
+
+        yield SampleInput(a, args=(b, c))
+
+
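+# NumPy reference implementations for clamp_min, clamp_max, and clamp.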
+def _clamp_min_numpy(a, min=None):
+    return np.maximum(a, min)
+
+
+def _clamp_max_numpy(a, max=None):
+    return np.minimum(a, max)
+
+
+def _clamp_numpy(a, min=None, max=None):
+    if min is None:
+        return np.minimum(a, max)
+    if max is None:
+        return np.maximum(a, min)
+
+    return np.minimum(max, np.maximum(a, min))
+
+
+def sample_inputs_cumprod(op_info, device, dtype, requires_grad, **kwargs):
+    def make_arg(shape):
+        # shrink values to be in the interval [-1, +1] for better precision in gradgradcheck
+        return make_tensor(shape, dtype=dtype, device=device, low=-1, high=+1, requires_grad=requires_grad)
+
+    def prod_zeros(dim_select):
+        assert len(dim_select) == 2
+        result = make_arg(3 * (S,))
+        result.narrow(dim_select[0], 0, 1).narrow(dim_select[1], 1, 1).zero_()
+        result.narrow(dim_select[0], 2, 1).narrow(dim_select[1], 3, 1).zero_()
+        result.narrow(dim_select[0], 4, 1).narrow(dim_select[1], 3, 1).zero_()
+        return result
+
+    for dim in range(3):
+        yield SampleInput(make_arg((S, S, S)), args=(dim,))
+    # Scalar tensors and empty tensor
+    for size in [(), (1,), (0,)]:
+        yield SampleInput(make_arg(size), args=(0,))
+
+    yield SampleInput(prod_zeros([0, 1]), args=(1,))
+    yield SampleInput(prod_zeros([0, 2]), args=(1,))
+    yield SampleInput(prod_zeros([1, 2]), args=(1,))
+
+    # test dtype kwarg
+    yield SampleInput(prod_zeros([1, 2]), args=(1,), kwargs={'dtype': dtype})
+
+def sample_inputs_view_as_complex(op_info, device, dtype, requires_grad, **kwargs):
+    yield SampleInput(make_tensor((S, 2), dtype=dtype, device=device, requires_grad=requires_grad))
+
+def sample_inputs_view_as_real(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    sizes = ((S, S), ())
+    return (SampleInput(make_arg(size)) for size in sizes)
+
+def error_inputs_complex(op_info, device, is_ref=False, **kwargs):
+    make_arg = partial(make_tensor, dtype=torch.float32, device=device)
+
+    if is_ref:
+        error_float = "Expected both inputs to be Half, Float or Double tensors but got torch.float32 and torch.int32"
+        error_dtype = "Expected object of scalar type torch.float32 but got scalar type torch.float64 for second argument"
+        error_out = "Expected out tensor to have dtype torch.complex128 but got torch.complex64 instead"
+    else:
+        error_float = "Expected both inputs to be Half, Float or Double tensors but got Float and Int"
+        error_dtype = "Expected object of scalar type Float but got scalar type Double for second argument"
+        error_out = "Expected object of scalar type ComplexDouble but got scalar type ComplexFloat for argument 'out'"
+
+    yield ErrorInput(SampleInput(make_arg(M, S), make_arg(M, S, dtype=torch.int)),
+                     error_type=RuntimeError, error_regex=error_float)
+
+    yield ErrorInput(SampleInput(make_arg(M, S), make_arg(M, S, dtype=torch.float64)),
+                     error_type=RuntimeError, error_regex=error_dtype)
+
+    yield ErrorInput(SampleInput(make_arg(M, S, dtype=torch.float64), make_arg(M, S, dtype=torch.float64),
+                                 out=make_arg(M, S, dtype=torch.complex64)),
+                     error_type=RuntimeError, error_regex=error_out)
+
+def sample_inputs_logaddexp(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    shape = (S, S)
+    yield SampleInput(make_arg(shape), make_arg(shape))
+
+def sample_inputs_prod(op_info, device, dtype, requires_grad, **kwargs):
+    def make_arg(shape):
+        # shrink values to be in the interval [-1, +1] for better precision in gradgradcheck
+        return make_tensor(shape, dtype=dtype, device=device, low=-1, high=+1, requires_grad=requires_grad)
+
+    def prod_single_zero():
+        result = make_arg(2 * (S,))
+        result[0, 1] = 0
+        return result
+
+    for sample in sample_inputs_cumprod(op_info, device, dtype, requires_grad):
+        # use only the input tensor; ignore the other arguments
+        yield SampleInput(sample.input.clone().requires_grad_(requires_grad))
+        yield sample
+
+    # Generates samples with keepdim = True
+    for sample in sample_inputs_cumprod(op_info, device, dtype, requires_grad):
+        sample.kwargs['keepdim'] = True
+        yield sample
+
+    yield SampleInput(prod_single_zero())
+    yield SampleInput(make_arg((3, 3, 3)), args=(1,))
+    yield SampleInput(make_arg((3, 3, 3)), args=(1,), kwargs={'keepdim': True})
+
+    yield SampleInput(make_arg((3, 0)), args=(1,))
+    yield SampleInput(make_arg((3, 0)), args=(1,), kwargs={'keepdim': True})
+    yield SampleInput(torch.tensor([2., 3, 0, 0], dtype=dtype, device=device, requires_grad=requires_grad))
+
+    # test zero scalar tensor
+    zero = make_arg(())
+    zero.zero_()
+    yield SampleInput(zero.clone().requires_grad_(requires_grad))
+    yield SampleInput(zero.clone().requires_grad_(requires_grad), args=(0,))
+    yield SampleInput(zero.clone().requires_grad_(requires_grad),
+                      args=(0,),
+                      kwargs={'keepdim': True})
+
+def error_inputs_neg(op_info, device, **kwargs):
+    si = SampleInput(torch.tensor((False, True), device=device))
+    msg = ("Negation, the `\\-` operator, on a bool tensor is not supported."
+           " If you are trying to invert a mask, use the `\\~` or"
+           " `logical_not\\(\\)` operator instead.")
+    yield ErrorInput(si, error_regex=msg)
+
+def sample_inputs_diag(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None)
+    yield SampleInput(make_arg(M))
+
+    tensors = (
+        make_arg((M, M)),
+        make_arg((3, 5)),
+        make_arg((5, 3)),
+    )
+
+    args = ((), (2,), (-2,), (1,), (2,))
+
+    for tensor, arg in product(tensors, args):
+        yield SampleInput(tensor.clone().requires_grad_(requires_grad), *arg)
+
+def reference_inputs_diagonal_diag_embed(op_info, device, dtype, requires_grad, **kwargs):
+    yield from sample_inputs_diagonal_diag_embed(
+        op_info, device, dtype, requires_grad, **kwargs)
+
+    make_arg = partial(
+        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    shapes1d = ((0,), (1,))
+    shapes2d = ((L, M),)
+    shapes3d = ((L, M, S),)
+
+    # a single empty kwargs dict so that the 1-D shapes are actually yielded below
+    kwargs1d = ({},)
+
+    kwargs2d = (
+        # dim1 > dim2 is allowed
+        dict(dim1=1, dim2=0),
+        # negative dims are allowed
+        dict(dim1=-2, dim2=-1),
+        # one dim negative and the other nonnegative is allowed
+        dict(dim1=-1, dim2=0),
+        # out of bounds offset should return an empty tensor in diagonal and
+        # offset the diagonal in diag_embed
+        dict(offset=100),
+    )
+
+    kwargs3d = kwargs2d + (
+        # make sure we can use non-sequential dims
+        dict(offset=-1, dim1=0, dim2=2),
+    )
+
+    samples1d = product(shapes1d, kwargs1d)
+    samples2d = product(shapes2d, kwargs2d)
+    samples3d = product(shapes3d, kwargs3d)
+
+    for shape, kwargs in chain(samples1d, samples2d, samples3d):
+        if 'diagonal' in op_info.name:
+            # these are error inputs for diagonal
+            if shape in ((0,), (1,)):
+                continue
+        yield SampleInput(input=make_arg(shape), kwargs=kwargs)
+
+
+def sample_inputs_diagonal_scatter(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+
+    # Shapes for 2D Tensors
+    shapes_2d = ((M, M), (3, 5), (5, 3))
+
+    # Shapes for 3D Tensors
+    shapes_3d = ((M, M, M),)
+
+    args_2d = ((), (2,), (-2,), (1,))
+    args_3d = ((1, 1, 2), (2, 0, 1), (-2, 0, 1))
+
+    for input_shape, arg in chain(product(shapes_2d, args_2d), product(shapes_3d, args_3d)):
+        input_ = make_arg(input_shape)
+        # We can programmatically figure out the right shape for src:
+        # It should be the same size as input.diagonal(other_args...)
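+        # e.g. for a (10, 10) input and offset 2, input.diagonal(2) has 8
+        # elements, so src must have shape (8,)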
+        if not isinstance(arg, tuple):
+            arg_tuple = (arg,)
+        else:
+            arg_tuple = arg
+        src_shape = input_.diagonal(*arg_tuple).size()
+        src = make_arg(src_shape)
+        yield SampleInput(input_, args=(src, *arg_tuple))
+
+
+def sample_inputs_to_sparse(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    yield SampleInput(make_arg((S, S))).with_metadata(output_process_fn_grad=lambda x: x.to_dense())
+    yield SampleInput(make_arg((S, S)), 1).with_metadata(output_process_fn_grad=lambda x: x.to_dense())
+
+def sample_inputs_cross_entropy(op_info, device, dtype, requires_grad, **kwargs):
+    batch_size, num_classes = shape = (2, 3)
+    reductions = ("mean", "sum", "none")
+
+    input_shape_and_kwargs: list[tuple[tuple[int, ...], dict[str, Any]]] = [
+        (shape, {}),
+        ((*shape, 1), {}),
+        ((*shape, 1, 2), {}),
+        ((*shape, 1, 2, 3), {}),
+        *[(shape, dict(reduction=reduction)) for reduction in reductions],
+        *[
+            (
+                shape,
+                dict(
+                    weight=make_tensor((num_classes,), device=device, dtype=dtype),
+                    reduction=reduction,
+                ),
+            )
+            for reduction in reductions
+        ],
+        (shape, dict(ignore_index=1)),
+    ]
+
+    for (input_shape, kwargs), probabilities_target in itertools.product(input_shape_and_kwargs, (False, True)):
+        input = make_tensor(input_shape, device=device, dtype=dtype, requires_grad=requires_grad)
+
+        if probabilities_target:
+            # ignore_index is not supported for probabilities target
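+            # (class-probability targets must have the same shape as the input
+            # and a floating dtype, and per-element ignoring is undefined for them)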
+            if "ignore_index" in kwargs:
+                continue
+
+            target = make_tensor(
+                input_shape,
+                low=0,
+                high=1,
+                device=device,
+                dtype=dtype,
+                requires_grad=requires_grad,
+            )
+        else:
+            target = make_tensor(
+                (batch_size, *input_shape[2:]),
+                low=0,
+                high=num_classes,
+                device=device,
+                dtype=torch.long,
+            )
+
+            if "ignore_index" in kwargs and torch.all(target == kwargs["ignore_index"]):
+                # make sure at least one item in target is not ignored
+                target[0] = random.sample(sorted(set(range(num_classes)) - {kwargs["ignore_index"]}), 1)[0]
+
+        yield SampleInput(input, target, **kwargs)
+
+
+def sample_inputs_logit(op_info, device, dtype, requires_grad, **kwargs):
+    low, high = op_info.domain
+
+    # Note: the operator is very sensitive near the start and
+    # end of its domain and produces NaN for float16 if
+    # domain_eps is 1e-5.
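+    # logit(x) = log(x / (1 - x)) diverges at the endpoints 0 and 1, hence the
+    # shrunken sampling interval below.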
+    if dtype.is_floating_point or dtype.is_complex:
+        domain_eps = op_info._domain_eps if dtype != torch.float16 else 3e-2
+
+        low = low + domain_eps
+        high = high - domain_eps
+
+    make_arg = partial(make_tensor, dtype=dtype, device=device, low=low, high=high, requires_grad=requires_grad)
+
+    yield SampleInput(make_arg((S, S, S)))
+    yield SampleInput(make_arg((S, S, S)), 0.2)
+    yield SampleInput(make_arg(()))
+    yield SampleInput(make_arg(()), 0.2)
+
+def sample_inputs_isin(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    # isin has two paths based on the size of elements and test_elements.
+    # if elements.numel() < 10 * pow(test_elements.numel(), 0.145):
+    yield SampleInput(make_arg((L,)), args=(make_arg((S,)),))
+    # else:
+    yield SampleInput(make_arg((S,)), args=(make_arg((L,)),))
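+    # With the usual S=5 and L=20 in this file, the two samples above fall on
+    # opposite sides of that threshold, so both code paths get exercised.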
+
+def sample_inputs_masked_scatter(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    yield SampleInput(make_arg((S, S)), args=(torch.randn(S, S, device=device) > 0, make_arg((S, S))))
+    yield SampleInput(make_arg((S, S)), args=(torch.randn((S,), device=device) > 0, make_arg((S, S))))
+    yield SampleInput(make_arg((S, S)), args=(bernoulli_scalar().to(device), make_arg((S, S))))
+    yield SampleInput(make_arg((S,)),
+                      args=(torch.randn(S, S, device=device) > 0, make_arg((S, S))),
+                      broadcasts_input=True)
+
+def error_inputs_masked_scatter(op_info, device, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=torch.float)
+    for mask_dtype in [torch.float, torch.uint8]:
+        yield ErrorInput(SampleInput(make_arg(1, 3), args=(torch.ones(1, 3, device=device, dtype=mask_dtype),
+                                                           make_arg(3, 4))),
+                         error_regex=r"masked_scatter_ only supports boolean masks")
+
+def sample_inputs_masked_fill(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    yield SampleInput(make_arg((S, S)), args=(torch.randn(S, S, device=device) > 0, 10))
+    yield SampleInput(make_arg((S, S)), args=(torch.randn(S, S, device=device) > 0, make_arg(())))
+    yield SampleInput(make_arg((S, S)), args=(torch.randn(S, device=device) > 0, 10))
+    yield SampleInput(make_arg(()), args=(torch.randn((), device=device) > 0, 10))
+    yield SampleInput(make_arg(()), args=(torch.randn((), device=device) > 0, make_arg(())))
+    yield SampleInput(make_arg((S, S)), args=(torch.randn((), device=device) > 0, 10))
+
+    yield SampleInput(make_arg((S,)),
+                      args=(torch.randn(S, S, device=device) > 0, make_arg(())),
+                      broadcasts_input=True)
+    yield SampleInput(make_arg((S,)),
+                      args=(torch.randn(S, S, device=device) > 0, 10),
+                      broadcasts_input=True)
+
+    if torch.device(device).type == 'cuda':
+        # `self` and `mask` on CUDA but `value` is a CPU scalar tensor.
+        yield SampleInput(make_arg((S, S)),
+                          args=(torch.randn(S, S, device=device) > 0,
+                                make_tensor((), device="cpu", dtype=dtype)))
+
+def error_inputs_masked_fill(op_info, device, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=torch.float, requires_grad=False)
+    # `value` is not a 0-D tensor.
+    yield ErrorInput(SampleInput(make_arg((2, 2)), args=(make_arg(()) > 0, make_arg((1,)))),
+                     error_regex="only supports a 0-dimensional value tensor, but got tensor with 1 dimension")
+    # downcasting complex value (scalar overload)
+    yield ErrorInput(SampleInput(make_arg((2, 2)), args=(make_arg(()) > 0, 1j)),
+                     error_regex=r"value cannot be converted to type .* without overflow")
+    # downcasting complex value (tensor overload)
+    yield ErrorInput(SampleInput(torch.ones(2, dtype=torch.long, device=device),
+                                 args=(make_arg(()) > 0, torch.tensor(1j, device=device))),
+                     error_regex=r"value cannot be converted to type .* without overflow")
+
+    if torch.device(device).type == 'cuda':
+        # `self` and `mask` on CPU but `value` is a CUDA scalar tensor.
+        yield ErrorInput(SampleInput(torch.randn((S, S), device='cpu'),
+                                     args=(torch.randn(S, S, device='cpu') > 0,
+                                           torch.randn((), device='cuda'))),
+                         error_regex=r"to be on same device")
+
+
+def sample_inputs_masked_select(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(
+        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None)
+
+    yield SampleInput(make_arg((M, M)), torch.randn(M, M, device=device) > 0)
+
+    yield SampleInput(make_arg((M, M)), torch.randn((M,), device=device) > 0)
+    yield SampleInput(make_arg((M,)), torch.randn((M, M), device=device) > 0)
+
+    yield SampleInput(make_arg((M, 1, M)), torch.randn((M, M), device=device) > 0)
+
+    yield SampleInput(make_arg(()), torch.tensor(1, device=device, dtype=torch.bool))
+
+    yield SampleInput(make_arg((M, M)), torch.tensor(1, device=device, dtype=torch.bool))
+
+    yield SampleInput(make_arg(()), torch.randn((M, M), device=device) > 0)
+
+def sample_inputs_matrix_exp(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    yield SampleInput(make_arg((S, S)))
+    yield SampleInput(make_arg((S, S, S)))
+
+def sample_inputs_matmul(op_info, device, dtype, requires_grad, is_rmatmul=False, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, low=None,
+                       high=None, requires_grad=requires_grad)
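+    # The shape pairs below cover vector-vector, matrix-vector, vector-matrix,
+    # matrix-matrix, batched, broadcast-batched, and zero-sized-dim cases.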
+    test_cases = (((L,), (L,)),
+                  ((S, M), (M,)),
+                  ((M,), (M, S)),
+                  ((S, M), (M, S)),
+                  ((S, 0), (0, M)),
+                  ((S, S, M), (M,)),
+                  ((S, S, M), (M, S)),
+                  ((S, S, 0), (0, S)),
+                  ((M,), (S, M, S)),
+                  ((S, M), (S, M, S)),
+                  ((0, 0), (S, 0, 0)),
+                  ((S, S, M, M), (S, S, M, S)),
+                  ((S, S, M, M), (M,)),
+                  ((M,), (S, S, M, S)),
+                  ((S, S, S), (1, S, S))
+                  )
+    for lhs_shape, rhs_shape in test_cases:
+        lhs = make_arg(lhs_shape)
+        rhs = make_arg(rhs_shape)
+        if not is_rmatmul:
+            yield SampleInput(lhs, rhs)
+        else:
+            yield SampleInput(rhs, lhs)
+
+
+def sample_inputs_meshgrid(op_info: OpInfo, device: torch.device, dtype: torch.dtype,
+                           requires_grad: bool,
+                           *, variant: str, **kwargs) -> list[SampleInput]:
+    if variant == 'variadic':
+        def make_inputs(
+                tensors: list[torch.Tensor]) -> tuple[Union[torch.Tensor,
+                                                            list[torch.Tensor]],
+                                                      tuple[torch.Tensor, ...]]:
+            return tensors
+    elif variant == 'list':
+        def make_inputs(
+                tensors: list[torch.Tensor]) -> tuple[Union[torch.Tensor,
+                                                            list[torch.Tensor]],
+                                                      tuple[torch.Tensor, ...]]:
+            return [tensors]
+    else:
+        raise ValueError(
+            'Unsupported variant, must be one of {"variadic", "list"}. '
+            f'Got "{variant}".')
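+    # The 'variadic' variant passes the tensors as separate positional args,
+    # while 'list' passes them as a single sequence; torch.meshgrid accepts
+    # both call conventions.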
+
+    SCALAR = torch.Size([])
+    VECTOR = torch.Size([3])
+    test_cases: list[list[torch.Size]] = [
+        [SCALAR],
+        [VECTOR],
+        [VECTOR, SCALAR],
+        [VECTOR, SCALAR, VECTOR],
+        [VECTOR, SCALAR, VECTOR, SCALAR],
+    ]
+
+    for shapes, indexing in itertools.product(test_cases, {'xy', 'ij'}):
+        args = make_inputs(
+            [make_tensor(shape, dtype=dtype, device=device, requires_grad=requires_grad)
+             for shape in shapes])
+        yield SampleInput(*args, indexing=indexing)
+
+
+def sample_inputs_mvlgamma(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    tensor_shapes = ((S, S), ())
+    ns = (1, 2, 3, 4, 5)
+
+    # Since the accepted lower bound for the input
+    # to mvlgamma depends on the `p` argument,
+    # the following function computes the lower bound
+    # that we pass to `make_tensor`.
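+    # (mvlgamma(x, p) is only defined for x > (p - 1) / 2 elementwise, which is
+    # exactly what compute_min_val returns)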
+    def compute_min_val(p):
+        return (p - 1.) / 2
+
+    for shape, n in product(tensor_shapes, ns):
+        min_val = compute_min_val(n)
+        if not dtype.is_floating_point:
+            # Round-up minimum value for integral dtypes
+            min_val += 1
+        else:
+            min_val += 2 * torch.finfo(dtype).eps
+        yield SampleInput(make_arg(shape, low=min_val), args=(n,))
+
+
+# Since `mvlgamma` has multiple OpInfo entries, they share
+# a number of common skips. The following function is a
+# helper that builds them.
+def skips_mvlgamma(skip_redundant=False):
+    skips = (
+        # outside domain values are hard error for mvlgamma op.
+        DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_float_domains'),
+        DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs',
+                     'test_reference_numerics_extremal'),
+        DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                     'test_reference_numerics_large',
+                     dtypes=(torch.float16, torch.int8)),
+        DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                     'test_reference_numerics_small',
+                     dtypes=(torch.int8,)),
+    )
+    if skip_redundant:
+        # Redundant tests
+        skips = skips + (  # type: ignore[assignment]
+            DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestJit'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'),
+        )
+    return skips
+
+
+# To test reference numerics against multiple values of argument `p`,
+# we make multiple OpInfo entries with each entry corresponding to different value of p.
+# We run the op tests from test_ops.py only for `p=1` to avoid redundancy in testing.
+def make_mvlgamma_opinfo(variant_test_name, domain, skips, sample_kwargs):
+    return UnaryUfuncInfo('mvlgamma',
+                          ref=reference_mvlgamma if TEST_SCIPY else None,
+                          aliases=('special.multigammaln',),
+                          variant_test_name=variant_test_name,
+                          domain=domain,
+                          decorators=(precisionOverride({torch.float16: 5e-2}),),
+                          dtypes=all_types_and(torch.half, torch.bfloat16),
+                          dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
+                          sample_inputs_func=sample_inputs_mvlgamma,
+                          supports_forward_ad=True,
+                          supports_fwgrad_bwgrad=True,
+                          promotes_int_to_float=True,
+                          skips=skips,
+                          sample_kwargs=sample_kwargs)
+
+
+def sample_inputs_cumulative_ops(op_info, device, dtype, requires_grad, supports_dtype_kwargs=True, **kwargs):
+    def _make_tensor_helper(shape, low=None, high=None):
+        return make_tensor(shape, dtype=dtype, device=device, low=low, high=high, requires_grad=requires_grad)
+
+    yield SampleInput(_make_tensor_helper((S, S, S)), 0)
+    yield SampleInput(_make_tensor_helper((S, S, S)), 1)
+    yield SampleInput(_make_tensor_helper(()), 0)
+
+    if supports_dtype_kwargs:
+        # NOTE: if `dtype` is not same as input, then inplace variants fail with
+        # `provided dtype must match the dtype of self tensor in cumsum`
+        yield SampleInput(_make_tensor_helper((S, S, S)), 1, dtype=dtype)
+
+
+def sample_inputs_unfold(op_info, device, dtype, requires_grad, **kwargs):
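+    # Each case is (input shape, (dimension, size, step)) as passed to
+    # Tensor.unfold(dimension, size, step).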
+    test_cases = (
+        ((), (0, 1, 1)),
+        ((S, S, S, S), (0, 3, 1)),
+        ((S, S, S, S), (1, 3, 1)),
+        ((S, S, S, S), (2, 3, 1)),
+        ((S, S, S, S), (3, 3, 1)),
+        ((S, S, S, S), (0, 3, 2)),
+        ((S, S, S, S), (1, 3, 2)),
+        ((S, S, S, S), (2, 3, 2)),
+        ((S, S, S, S), (3, 3, 2)),
+        ((S, S, S, S), (0, 4, 1)),
+        ((S, S, S, S), (1, 4, 1)),
+        ((S, S, S, S), (2, 4, 1)),
+        ((S, S, S, S), (3, 4, 1)),
+        ((M,), (0, 3, 1)),
+        ((M,), (0, 3, 2)),
+        ((M,), (0, 3, 3)),
+        ((1000,), (0, 3, 11)),
+        ((1000,), (0, 2, 27)),
+        ((10, 10), (0, 1, 2)),
+        ((10, 10), (1, 2, 3)),
+        ((10, 10), (1, 2, 2)),
+        ((S, S, S), (2, 3, 2)),
+    )
+
+    for shape, arguments in test_cases:
+        yield SampleInput(make_tensor(shape, dtype=dtype, device=device,
+                                      low=None, high=None,
+                                      requires_grad=requires_grad),
+                          *arguments)
+
+def sample_inputs_split(op_info, device, dtype, requires_grad, *, list_args=False, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    if list_args:
+        cases = (
+            ((S, S, S), (torch.Size([int(S / 3), S - int(S / 3) * 2, int(S / 3)]),)),
+            ((S, S, S), (torch.Size([int(S / 2), S - int(S / 2) * 2, int(S / 2)]), 2),),
+            ((S, S, S), (torch.Size([int(S / 2), S - int(S / 2) * 2, int(S / 2)]), -2),)
+        )
+    else:
+        cases = (  # type: ignore[assignment]
+            ((S, S, S), (2,)),
+            ((S, S, S), (S, 1)),
+        )
+
+    for shape, args in cases:
+        yield SampleInput(make_arg(shape), args=args)
+
+
+def sample_inputs_split_with_sizes(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    cases = (((S, S, S), (torch.Size([int(S / 3), S - int(S / 3) * 2, int(S / 3)]),)),
+             ((S, S, S), (torch.Size([int(S / 3), S - int(S / 3), 0]),)),
+             ((S, S, S), (torch.Size([int(S / 3), S - int(S / 3) * 2, int(S / 3)]), 2)),
+             ((S, S, S), (torch.Size([int(S / 3), S - int(S / 3) * 2, int(S / 3)]), -2)),
+             )
+
+    for shape, args in cases:
+        yield SampleInput(make_arg(shape), args=args)
+
+
+def sample_inputs_msort(op_info, device, dtype, requires_grad, **kwargs):
+    def apply_grad(t):
+        if dtype in floating_types_and(torch.float16, torch.bfloat16):
+            t.requires_grad_(requires_grad)
+
+    def large_1d_unique(dtype, device):
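+        # randperm guarantees unique values, so the sorted order (and therefore
+        # the gradient with respect to the input) is unambiguous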
+        res = torch.randperm(L * L * L, dtype=torch.int64, device=device)
+        res = res.to(dtype)
+        apply_grad(res)
+        return res
+
+    # Test case for large tensor.
+    yield SampleInput(large_1d_unique(dtype, device))
+
+    yield SampleInput(make_tensor((S, M, S), dtype=dtype, device=device,
+                                  low=None, high=None,
+                                  requires_grad=requires_grad))
+
+def sample_inputs_lerp(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
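+    # lerp(input, end, weight) computes input + weight * (end - input); the
+    # weight may be a Python scalar or a broadcastable tensor.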
+
+    # no broadcast
+    yield SampleInput(make_arg((S, S)), make_arg((S, S)), 0.4)
+    # broadcast rhs
+    yield SampleInput(make_arg((S, S)), make_arg((S,)), 0.4)
+    # scalar tensor
+    yield SampleInput(make_arg(()), make_arg(()), 0.4)
+    # broadcast rhs scalar-tensor
+    yield SampleInput(make_arg((S, S)), make_arg(()), 0.4)
+    # broadcast rhs with weight tensor
+    yield SampleInput(make_arg((S, S)), make_arg((S,)), make_arg((S, S)))
+    # broadcast rhs and weight tensor
+    yield SampleInput(make_arg((S, S)), make_arg((S, 1)), make_arg((S,)))
+    # broadcast lhs
+    yield SampleInput(make_arg((S,)), make_arg((S, S)), 0.4).with_metadata(broadcasts_input=True)
+    # scalar broadcast_lhs
+    yield SampleInput(make_arg(()), make_arg((S, S)), 0.4).with_metadata(broadcasts_input=True)
+    # broadcast all
+    yield SampleInput(make_arg((S, 1)), make_arg((S, S)), 0.4).with_metadata(broadcasts_input=True)
+    # tensor broadcast all
+    yield SampleInput(make_arg((S, 1)), make_arg((S, S)), make_arg((S, 1))).with_metadata(
+        broadcasts_input=True)
+    # no broadcast with weight tensor
+    yield SampleInput(make_arg((S, S)), make_arg((S, S)), make_arg((S, S)))
+    # broadcast lhs with weight tensor
+    yield SampleInput(make_arg((S,)), make_arg((S, S)), make_arg((S, S))).with_metadata(
+        broadcasts_input=True)
+    # broadcast lhs and weight tensor
+    yield SampleInput(make_arg((S,)), make_arg((S, S, S)), make_arg((S, S))).with_metadata(
+        broadcasts_input=True)
+    # broadcast lhs and weight tensor variant
+    yield SampleInput(make_arg((S, S)), make_arg((S, S, S)), make_arg((S,))).with_metadata(
+        broadcasts_input=True)
+
+    if dtype.is_complex:
+        # no broadcast
+        yield SampleInput(make_arg((S, S)), make_arg((S, S)), 0.4j)
+        yield SampleInput(make_arg((S, S)), make_arg((S, S)), 1.2 + 0.1j)
+        # broadcast rhs
+        yield SampleInput(make_arg((S, S)), make_arg((S,)), 0.4j)
+        yield SampleInput(make_arg((S, S)), make_arg((S, S)), 5.4 + 9j)
+        # scalar tensor
+        yield SampleInput(make_arg(()), make_arg(()), 0.4j)
+        yield SampleInput(make_arg(()), make_arg(()), 6.1 + 0.004j)
+        # broadcast rhs scalar-tensor
+        yield SampleInput(make_arg((S, S)), make_arg(()), 0.4j)
+        yield SampleInput(make_arg((S, S)), make_arg(()), 1 + 2j)
+
+def sample_inputs_tensordot(self, device, dtype, requires_grad, **kwargs):
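+    # `dims` is either an int n (contract the last n dims of the first argument
+    # with the first n dims of the second) or a pair of lists of dims.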
+    cases = (
+        ((2, 2, 2), (2, 2, 2), (2)),
+        ((2, 2, 1), (2, 1, 2), ([0, 1], [2, 0])),
+        ((1, 1, 1), (2, 1, 2), ([0, 1], [2, 0])),
+    )
+    for first_shape, second_shape, dims in cases:
+        yield SampleInput(make_tensor(first_shape, dtype=dtype, device=device,
+                                      requires_grad=requires_grad, low=-1, high=+2),
+                          make_tensor(second_shape, dtype=dtype, device=device,
+                                      requires_grad=requires_grad, low=-1, high=+2),
+                          dims=dims)
+
+def sample_inputs_kron(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(
+        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad, low=None, high=None)
+    test_cases = (
+        ((S, S), (M, L)),
+    )
+
+    for input_shape, other_shape in test_cases:
+        input = make_arg(input_shape)
+        other = make_arg(other_shape)
+        yield SampleInput(input, other)
+
+def sample_inputs_inner(self, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+    yield SampleInput(make_arg(S), make_arg(S))
+    yield SampleInput(make_arg(), make_arg(S, S))
+
+def sample_inputs_scatter(op_info, device, dtype, requires_grad, **kwargs):
+    def _tensor(shape, dtype=dtype, low=None, high=None):
+        return make_tensor(shape, dtype=dtype, device=device, low=low, high=high, requires_grad=requires_grad)
+
+    def _gather(shape, index_dim, max_indices):
+        return gather_variable(shape, index_dim, max_indices, device=device)
+
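+    # For dim=0, scatter writes src[i][j] into self[index[i][j]][j], so index
+    # values must be smaller than self.size(dim); gather_variable builds such
+    # an index with values below max_indices.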
+    zero = torch.tensor(0, dtype=torch.long, device=device)
+    test_cases = (
+        (_tensor((M, S)), (0, _gather((S, S), 1, M), _tensor((S, S)))),
+        (_tensor((M, S)), (0, _gather((S, S), 1, M).to(torch.int32), _tensor((S, S)))),
+        (_tensor((M, S)), (1, _gather((S, S), 0, S), _tensor((S, S)))),
+        (_tensor((M, S)), (-1, _gather((S, S), 0, S), _tensor((S, S)))),
+        (_tensor((M, S)), (0, _gather((M, S // 2), 1, M), _tensor((M, S // 2)))),
+        (_tensor((M, S)), (1, _gather((M, S // 2), 0, S), _tensor((M, S // 2)))),
+        (_tensor((M, S)), (-1, _gather((M, S // 2), 0, S), _tensor((M, S // 2)))),
+        (_tensor(()), (0, zero.detach().clone(), _tensor(()))),
+        (_tensor(()), (0, zero.detach().clone(), 2.5)),
+    )
+
+    for tensor, args in test_cases:
+        yield SampleInput(tensor, *args)
+
+        if not requires_grad:
+            yield SampleInput(tensor.detach().clone(), *args, reduce='add')
+
+            if dtype.is_floating_point:
+                yield SampleInput(tensor.detach().clone(), *args, reduce='multiply')
+
+def sample_inputs_scatter_add(op_info, device, dtype, requires_grad, **kwargs):
+    def _tensor(shape, dtype=dtype, low=None, high=None):
+        return make_tensor(shape, dtype=dtype, device=device, low=low, high=high, requires_grad=requires_grad)
+
+    def _gather(shape, index_dim, max_indices):
+        return gather_variable(shape, index_dim, max_indices, device=device)
+
+    zero = torch.tensor(0, dtype=torch.long, device=device)
+    yield SampleInput(_tensor((M, S)), 0, _gather((S, S), 1, M), _tensor((S, S)))
+    yield SampleInput(_tensor((M, S)), 1, _gather((S, S), 0, S), _tensor((S, S)))
+    yield SampleInput(_tensor((M, S)), -1, _gather((S, S), 0, S), _tensor((S, S)))
+    yield SampleInput(_tensor((M, S)), 0, _gather((M, S // 2), 1, M), _tensor((M, S // 2)))
+    yield SampleInput(_tensor((M, S)), 1, _gather((M, S // 2), 0, S), _tensor((M, S // 2)))
+    yield SampleInput(_tensor((M, S)), -1, _gather((M, S // 2), 0, S), _tensor((M, S // 2)))
+    yield SampleInput(_tensor(()), 0, zero.detach().clone(), _tensor(()))
+
+def sample_inputs_scatter_reduce(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+    gather = partial(gather_variable, device=device)
+
+    zero = torch.tensor(0, dtype=torch.long, device=device)
+    test_cases = (
+        ((M, S), 0, gather((S, S), 1, M), (S, S)),
+        ((M, S), 1, gather((S, S), 0, S), (S, S)),
+        ((M, S), -1, gather((S, S), 0, S), (S, S)),
+        ((M, S), 0, gather((M, S // 2), 1, M), (M, S // 2)),
+        ((M, S), 1, gather((M, S // 2), 0, S), (M, S // 2)),
+        ((M, S), -1, gather((M, S // 2), 0, S), (M, S // 2)),
+        ((), 0, zero.detach().clone(), ()),
+    )
+
+    reduce = op_info.variant_test_name
+    for (inp_shape, dim, index, src_shape), include_self in product(test_cases, [False, True, False]):
+        yield SampleInput(make_arg(inp_shape),
+                          args=(dim, index, make_arg(src_shape), reduce),
+                          kwargs={'include_self': include_self})
+
+
+    # Sample inputs to test edge cases for backward
+    # Check that gradients are propagated correctly for prod when zeros in self/src are reduced
+    if requires_grad and reduce == 'prod':
+        # This sample tests gradients for the following cases
+        # (a) 1 zero reduced (from src (self[0, 1], self[1, 1]), from self (self[0, 0], self[2, 0]))
+        # (b) 2 zeros reduced (1 from src and 1 from self (self[1, 0]))
+        # (c) no zeros reduced (self[2, 1])
+        # (d) 2 zeros reduced (both from src) is tested in test/test_autograd.py
+        #     test_scatter_index_reduce_prod_gradgrad_error as this case is not supported for gradgrad
+        input = torch.tensor([[0, 13], [0, 17], [0, 19]], dtype=dtype, device=device, requires_grad=requires_grad)
+        src = torch.tensor([[0, 1, 2, 3], [0, 4, 0, 1], [2, 3, 5, 6]], dtype=dtype, device=device, requires_grad=requires_grad)
+        idx = torch.tensor([[1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 0, 1]], dtype=torch.long, device=device)
+
+        yield SampleInput(input,
+                          args=(1, idx, src, reduce),
+                          kwargs={'include_self': True})
+
+def sample_inputs_segment_reduce(op_info, device, dtype, requires_grad, *, mode='lengths', **kwargs):
+    def _tensor(shape, dtype=dtype, low=None, high=None):
+        return make_tensor(shape, dtype=dtype, device=device, low=low, high=high, requires_grad=requires_grad)
+
+    test_cases = (
+        # inp_shape, dim, lengths, unsafe
+        ((S,), 0, [0, 1, 2, 2], False),
+        ((S,), 0, [0, 1, 2, 2], True),
+        ((S,), 0, [2, 0, 3, 0], False),
+        ((S, S), 0, [0, 1, 2, 2], False),
+        # test when lengths do not sum to dim size
+        ((M, S, S), 0, [1, 2, 0, 6, 0], True),
+        # test for higher dimensions
+        ((S, S), 1, [[0, 1, 2, 2] for _ in range(S)], False),
+        ((S, S), 1, [[2, 0, 3, 0], [0, 1, 2, 2], [3, 0, 2, 0], [1, 1, 1, 2], [0, 1, 2, 2]], False),
+        ((S, S, S), 1, [[0, 1, 2, 2] for _ in range(S)], False),
+        ((S, S, S), 1, [[2, 0, 3, 0], [0, 1, 2, 2], [3, 0, 2, 0], [1, 1, 1, 2], [0, 1, 2, 2]], False),
+    )
+
+    reductions = ["max", "mean", "min", "sum", "prod"]
+    for args, reduce, initial in product(test_cases, reductions, [1, 2]):
+        inp_shape, dim, lengths, unsafe = args
+        lengths_t = torch.tensor(lengths, dtype=torch.long, device=device)
+        sample_input_kwargs = {'axis': dim, 'unsafe': unsafe, 'initial': initial}
+        if mode == 'lengths':
+            sample_input_kwargs['lengths'] = lengths_t
+        elif mode == 'offsets':
+            zeros_shape = list(lengths_t.shape)
+            zeros_shape[dim] = 1
+            offsets_t = torch.cat((lengths_t.new_zeros(zeros_shape), lengths_t), dim).cumsum_(dim)
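+            # e.g. lengths [0, 1, 2, 2] become offsets [0, 0, 1, 3, 5]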
+            sample_input_kwargs['offsets'] = offsets_t
+        else:
+            raise RuntimeError(f"mode must be one of 'offsets' or 'lengths', got '{mode}'.")
+        yield SampleInput(_tensor(inp_shape),
+                          args=(reduce,),
+                          kwargs=sample_input_kwargs)
+
+
+def sample_inputs_ravel(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device,
+                       low=None, high=None, requires_grad=requires_grad)
+    yield SampleInput(make_arg((S, S, S)))
+    yield SampleInput(make_arg(()))
+    yield SampleInput(make_arg((S, S, S), noncontiguous=True))
+
+def sample_inputs_unravel_index(op_info, device, dtype, requires_grad, **kwargs):
+    yield SampleInput(
+        torch.tensor(
+            [[3, 8, 13], [0, 5, 10]],
+            device=device,
+            dtype=dtype),
+        (4, 5))
+    yield SampleInput(
+        torch.tensor([[3, 8, 13], [0, 5, 10]], device=device, dtype=dtype),
+        (4, 2**30))
+    yield SampleInput(
+        torch.tensor([[3, 8, 13], [0, 5, 10]], device=device, dtype=dtype),
+        (2**30, 4))
+    yield SampleInput(
+        torch.tensor(2, device=device, dtype=dtype),
+        (2, 2))
+    max_val = 2**(8 * dtype.itemsize - (1 if dtype.is_signed else 0)) - 1
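+    # max_val is the largest flat index representable in `dtype`, e.g.
+    # 2**7 - 1 = 127 for int8, keeping the samples below in range.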
+    yield SampleInput(
+        torch.tensor(max_val - 1, device=device, dtype=dtype),
+        (1, max_val))
+    yield SampleInput(
+        torch.tensor([22, 41, 37], device=device, dtype=dtype),
+        (7, 6))
+    yield SampleInput(
+        torch.tensor(min(1621, max_val), device=device, dtype=dtype),
+        (6, 7, 8, 9))
+    yield SampleInput(
+        torch.tensor([], device=device, dtype=dtype),
+        (10, 3, 5))
+    yield SampleInput(
+        torch.tensor(
+            [[1, 0, 1, 2, 3, 4], [1, 6, 1, 3, 2, 0]],
+            device=device,
+            dtype=dtype),
+        (5, 8))
+    yield SampleInput(
+        torch.tensor(
+            [[1, 0, 1, 2, 3, 4], [1, 6, 1, 3, 2, 0], [1, 3, 1, 0, 9, 5]],
+            device=device,
+            dtype=dtype),
+        (5, 8, 10))
+    yield SampleInput(
+        torch.tensor(0, device=device, dtype=dtype),
+        ())
+
+    a = np.array([[2, 4, 5, 6], [7, 8, 1, 15]])
+    b = np.array([[3, 2, 7, 6], [10, 12, 8, 9]])
+    _, i1, i2 = np.intersect1d(a, b, assume_unique=True, return_indices=True)
+    yield SampleInput(torch.tensor(i1, device=device, dtype=dtype), a.shape)
+    yield SampleInput(torch.tensor(i2, device=device, dtype=dtype), b.shape)
+
+    a = np.array([[2, 4, 5, 6, 6], [4, 7, 8, 7, 2]])
+    b = np.array([[3, 2, 7, 7], [10, 12, 8, 7]])
+    _, i1, i2 = np.intersect1d(a, b, return_indices=True)
+    yield SampleInput(torch.tensor(i1, device=device, dtype=dtype), a.shape)
+    yield SampleInput(torch.tensor(i2, device=device, dtype=dtype), b.shape)
+
+
+def sample_inputs_tril_triu(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+    cases = (((M, M), ()),
+             ((M, M), (2,),),
+             ((M, S), ()),
+             ((M, S), (-1,)),
+             ((M, M), (2,),),
+             ((S, M, S), ()),
+             ((S, M, S), (2,)),
+             ((3, 3, S, S), ()),)
+
+    for shape, args in cases:
+        yield SampleInput(make_arg(shape), args=args)
+
+def error_inputs_tril_triu(opinfo, device, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
+
+    # error inputs for input.ndim <= 2
+    yield ErrorInput(SampleInput(make_arg((4,))), error_regex="input tensor must have at least 2 dimensions")
+
+def sample_inputs_trilu_indices(op_info, device, dtype, requires_grad, **kwargs):
+    # (row, col) or (row, col, offset)
+    args_list = ((0, 0),
+                 (20, 0),
+                 (0, 20),
+                 (20, 21, 0),
+                 (20, 21, 7),
+                 (20, 21, -7),
+                 # Large test cases below are deliberately commented out to speed up CI
+                 # tests and to avoid OOM error. When modifying implementations of
+                 # tril_indices and triu_indices, please enable these tests and make sure
+                 # they pass.
+                 # (2, 68435455, 3),
+                 # (5000, 5000),
+                 # (5000, 5000, 1234),
+                 # (5000, 5000, -1233),
+                 )
+    for args in args_list:
+        yield SampleInput(args[0], args=args[1:], kwargs={"dtype": dtype, "device": device})
+
+def sample_inputs_clone_contiguous(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+
+    yield SampleInput(make_arg((S, M, S)))
+    yield SampleInput(make_arg(()))
+
+def reference_inputs_clone_contiguous(op, device, dtype, requires_grad, **kwargs):
+    # NOTE: the default memory format for clone is torch.preserve_format, for contiguous it's torch.contiguous_format
+    # This exploits that default to test torch.preserve_format for clone, without causing an error when testing contiguous
+    yield from sample_inputs_clone_contiguous(op, device, dtype, requires_grad, **kwargs)
+
+    shapes = (
+        (3, 5, 6),
+        (1, 1, 3, 5, 6),
+        (1, 1, 3, 5, 6, 1, 1),
+        (1, 0, 3, 5, 0, 2),
+        (1, 0, 3, 5, 0, 0, 1, 1, 2),
+        (),
+    )
+
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+    for shape in shapes:
+        yield SampleInput(make_arg(shape))
+        yield SampleInput(make_arg(shape).transpose(0, -1))
+        yield SampleInput(make_arg(shape, noncontiguous=True))
+        yield SampleInput(make_arg(shape, noncontiguous=True).transpose(0, -1))
+
+        yield SampleInput(make_arg(shape), kwargs={'memory_format': torch.contiguous_format})
+        yield SampleInput(make_arg(shape).transpose(0, -1), kwargs={'memory_format': torch.contiguous_format})
+        yield SampleInput(make_arg(shape, noncontiguous=True), kwargs={'memory_format': torch.contiguous_format})
+        yield SampleInput(make_arg(shape, noncontiguous=True).transpose(0, -1), kwargs={'memory_format': torch.contiguous_format})
+
+    # shape, strides, offset
+    strided_cases = (
+        ((5, 6, 2), (1, 1, 7), 2),
+        ((5, 5, 4), (1, 1, 7), 2),
+        ((5, 5, 2), (4, 5, 7), 3),
+        ((5, 5, 2), (5, 5, 7), 3),
+        ((5, 5, 2), (5, 5, 5), 3),
+        ((9, 5, 2), (0, 1, 7), 3),
+    )
+
+    for shape, strides, offset in strided_cases:
+        yield SampleInput(make_arg(500,).as_strided(shape, strides, offset))
+        yield SampleInput(make_arg(500,).as_strided(shape, strides, offset), kwargs={'memory_format': torch.contiguous_format})
+
+    # channels last 2D
+    yield SampleInput(make_arg((2, 2, 2, 2)), kwargs={'memory_format': torch.channels_last})
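+    # permuting a contiguous (N, H, W, C) tensor to (N, C, H, W) yields a
+    # tensor whose memory layout already matches channels_last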
+    a = make_arg((2, 2, 2, 2)).permute(0, 3, 1, 2)
+    yield SampleInput(a, kwargs={'memory_format': torch.channels_last})
+
+    # channels last 3D
+    yield SampleInput(make_arg((2, 2, 2, 2, 2)), kwargs={'memory_format': torch.channels_last_3d})
+    a = make_arg((2, 2, 2, 2, 2)).permute(0, 4, 1, 2, 3)
+    yield SampleInput(a, kwargs={'memory_format': torch.channels_last_3d})
+
+
+def sample_inputs_sum_to_size(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+
+    # list of tuples (shape, shape) defining the shapes of the input and output tensors
+    sample_shapes = [
+        ((), ()),
+        ((S,), (1,)),
+        ((S, S), (1, 1)),
+        ((S, S), (1, S)),
+        ((S, S), (S, S)),
+        ((S, S, S), (S, 1, S)),
+    ]
+
+    for input_shape, output_shape in sample_shapes:
+        yield SampleInput(make_arg(input_shape), args=(output_shape,))
+        if output_shape == ():
+            continue
+        yield SampleInput(make_arg(input_shape), args=(list(output_shape),))
+        yield SampleInput(make_arg(input_shape), args=(*output_shape,))
+
+
+def error_inputs_sum_to_size(op_info, device, **kwargs):
+    shape = (M, S, M)
+    err_msg = "is not expandable to size"
+    si = SampleInput(make_tensor(shape, device=device, dtype=torch.float32), args=(M, M))
+    yield ErrorInput(si, error_regex=err_msg)
+
+    shape = (M + 1, S, S, M)
+    err_msg = "is not expandable to size"
+    si = SampleInput(make_tensor(shape, device=device, dtype=torch.float32), args=(M + 1, 1))
+    yield ErrorInput(si, error_regex=err_msg)
+
+
+def sample_inputs_resize_ops(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device)
+    cases = (((S, S, S), (S * S, S)),
+             ((), ()),
+             ((), (1, 1, 1)),
+             )
+
+    for shape, args_or_shape in cases:
+        # Update `args` based on operator
+        if op_info.name == 'resize_':
+            # resize_ takes shape/tuple of ints,
+            args = (args_or_shape, )
+        elif op_info.name == 'resize_as_':
+            # resize_as_ takes another tensor
+            args = (make_arg(shape, requires_grad=False), )  # type:ignore[assignment]
+        else:
+            raise ValueError("sample_inputs_resize_ops is being used with an incorrect operator")
+
+        yield SampleInput(make_arg(shape, requires_grad=requires_grad), args=args)
+
+def sample_inputs_view_reshape(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+
+    cases = (
+        # a, b, is_tensor_supported
+        ((S, S, S), (S * S, S), True),
+        ((S * S, S), (S, S, S), True),
+        ((S * S, S), (S, -1, S), False),  # neg index
+        ((S * S * 2, S), (S, -1), False),  # neg index
+        ((S,), (S,), True),
+        ((), (), False),  # empty
+        ((), (1,), True),
+    )
+
+    for a, b, is_tensor_supported in cases:
+        # skip unsupported cases
+        if kwargs.get("tensor_arg") and not is_tensor_supported:
+            continue
+
+        # convert to tensor
+        if kwargs.get("tensor_arg"):
+            b = make_arg(b, requires_grad=False)
+
+        yield SampleInput(make_arg(a), args=(b,))
+
+def reference_inputs_view_reshape(op, device, dtype, requires_grad, **kwargs):
+    yield from sample_inputs_view_reshape(op, device, dtype, requires_grad, **kwargs)
+
+    cases = (
+        # a, b, is_tensor_supported
+        ((125,), (25, 5), True),
+        ((25, 25), (1, 5, 5, 1, 5, 1, 5, 1), True),
+        ((16, 32), (2, 4, 1, 4, 4, 1, 4), True),
+        ((16, 12), (12, 16), True),
+        ((1, 16, 12), (12, 16), True),
+        ((1, 5, 1, 5), (25, 1), True),
+        ((2, 4, 2), (4, 4), True),
+        ((1, 4), (1, 1, 2, 1, 2), True),
+        ((3, 5, 7), (7, 5, 3), True),
+        ((1,), (), False),  # empty
+        ((5, 0, 2, 3), (5, 0, 2, 3), True),
+        ((2, 1, 0, 3, 1), (5, 0), True),
+        ((1,), (), False),  # empty
+        ((4, 5, 6), (4, 5, 6, 1, 1, 1), True),
+        ((), (1, 1, 1, 1), False),  # empty
+    )
+
+    irreversible_cases = (
+        ((), (-1,), False),  # neg index, empty
+        ((4, 7, 9, 1, 1), (1, 4, 3, -1, 1), False),  # neg index
+    )
+
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+    for a, b, is_tensor_supported in cases:
+        # skip unsupported cases
+        if kwargs.get("tensor_arg") and not is_tensor_supported:
+            continue
+
+        if kwargs.get("tensor_arg"):
+            # convert to tensor
+            yield SampleInput(make_arg(a), args=(make_arg(b, requires_grad=False),))
+            yield SampleInput(make_arg(b), args=(make_arg(a, requires_grad=False),))
+        else:
+            yield SampleInput(make_arg(a), args=(b,))
+            yield SampleInput(make_arg(b), args=(a,))
+
+    for a, b, is_tensor_supported in irreversible_cases:
+        # skip unsupported cases
+        if kwargs.get("tensor_arg") and not is_tensor_supported:
+            continue
+
+        # convert to tensor
+        if kwargs.get("tensor_arg"):
+            b = make_arg(b, requires_grad=False)
+
+        yield SampleInput(make_arg(a), args=(b,))
+
+def error_inputs_view_reshape(op, device, **kwargs):
+
+    cases = (
+        # a, b, is_tensor_supported
+        # Reshape to different numel
+        ((2,), (), False),  # empty
+        ((1, 3, 0), (), False),  # empty
+        ((4, 3), (4, 2), True),
+        ((1, 3, 5), (5, 2, 2), True),
+        # No valid inference
+        ((1, 3, 5), (5, -1, 2), False),  # neg index
+        # Two inferred shapes
+        ((1, 3, 5), (5, -1, -1), False),  # neg index
+        ((1), (0, -1), False),  # neg index
+        ((0, 5), (0, -1), False),  # neg index
+    )
+
+    make_arg = partial(make_tensor, dtype=torch.float32, device=device, requires_grad=False)
+    for a, b, is_tensor_supported in cases:
+        # skip unsupported cases
+        if kwargs.get("tensor_arg") and not is_tensor_supported:
+            continue
+
+        if b == (5, -1, -1):
+            error_regex = "only one dimension can be inferred"
+        elif a == (0, 5):
+            error_regex = (r"cannot reshape tensor of 0 elements into shape "
+                           r"\[0, -1\] because the unspecified dimension size "
+                           r"-1 can be any value and is ambiguous")
+        else:
+            # to avoid having issues with a regex
+            shape = ', '.join(map(str, b))
+            size = a if type(a) is int else functools.reduce(operator.mul, a, 1)
+            error_regex = rf"shape '\[{shape}\]' is invalid for input of size {size}"
+
+        # convert to tensor
+        if kwargs.get("tensor_arg"):
+            b = make_arg(b, requires_grad=False)
+
+        yield ErrorInput(SampleInput(make_arg(a), args=(b,)), error_type=Exception,
+                         error_regex=error_regex)
+
+
+def sample_inputs_atleast1d2d3d(op_info, device, dtype, requires_grad, **kwargs):
+    shapes = ((S, S, S, S), (S, S, S), (S, S), (S, ), (),)
+    make_tensor_partial = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+    for shape in shapes:
+        yield SampleInput(make_tensor_partial(shape))
+    yield SampleInput([make_tensor_partial(shape) for shape in shapes])
+
+def sample_inputs_column_stack(op_info, device, dtype, requires_grad, **kwargs):
+    cases: tuple[tuple, tuple] = (  # type: ignore[assignment]
+        ((S, 2, 1), (S, 3, 1)),
+        ((S), (S, 5)), ((), (1, S))
+    )
+    make_tensor_partial = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+    for shape1, shape2 in cases:
+        yield SampleInput([make_tensor_partial(shape1), make_tensor_partial(shape2)])
+
+def sample_inputs_flatten(op_info, device, dtype, requires_grad, **kwargs):
+    shapes = ((S, S, S), (S, S), (S, ), (),)
+    make_tensor_partial = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+    for shape in shapes:
+        yield SampleInput(make_tensor_partial(shape))
+        if len(shape) > 1:
+            yield SampleInput(make_tensor_partial(shape), start_dim=1, end_dim=-1)
+
+def reference_inputs_flatten(op, device, dtype, requires_grad, **kwargs):
+    yield from sample_inputs_flatten(op, device, dtype, requires_grad, **kwargs)
+
+    # shape x start_dim x end_dim
+    cases = (
+        ((5, 4, 0, 1, 3, 7), 1, 3),
+        ((5, 4, 0, 1, 3, 7), 4, 5),
+        ((5, 4, 1, 1, 3, 7), 2, 3),
+        ((), 0, -1),
+        ((1,), 0, -1),
+        ((3, 7, 5), 1, 2),
+        ((4, 5), 1, 1),
+        ((1, 5, 5, 1, 5, 1, 5, 1), 0, 2),
+        ((1, 5, 5, 1, 5, 1, 5, 1), 3, -1),
+        ((1, 5, 5, 1, 5, 7, 5, 1), -2, -1),
+        ((2, 4, 2), 0, 1),
+        ((4, 2, 2), 1, 2),
+        ((0, 3, 4, 5), 1, 3),
+    )
+
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+    for shape, start, end in cases:
+        yield SampleInput(make_arg(shape), args=(start, end,))
+        yield SampleInput(make_arg(shape, noncontiguous=True).transpose(0, -1), args=(start, end,))
+        yield SampleInput(make_arg(shape).transpose(0, -1), args=(start, end,))
+
+def sample_inputs_unflatten(op_info, device, dtype, requires_grad, **kwargs):
+    # in_shape, dim, sizes
+    args = (((8,), 0, (8,)),
+            ((8,), 0, (4, 2)),
+            ((8,), -1, (2, 2, 2)),
+            ((8,), -1, (-1, 2)),
+            ((3, 6, 2), 1, (2, 3)),
+            ((3, 6, 2), -2, (2, 3)),
+            ((3, 6, 2), -2, (-1, 3)),
+            ((3, 2, 12), 2, (3, 2, 2)),
+            ((4, 0), 0, (2, 2)),
+            ((4, 0), 1, (2, 0, 0, 0)),
+            )
+    make_tensor_partial = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+    for in_shape, dim, sizes in args:
+        yield SampleInput(make_tensor_partial(in_shape), args=(dim, sizes))
+
+
+def sample_inputs_select(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+
+    cases = (((S, S, S), (1, 2)),
+             ((S, S, S), (-1, 2)),
+             ((S, S, S), (-1, -1)),
+             ((S, S, S), (1, -1)),
+             ((S, S), (-1, 2)),
+             ((S,), (0, 2))
+             )
+
+    for shape, args in cases:
+        yield SampleInput(make_arg(shape), args=args)
+
+
+def sample_inputs_select_scatter(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+
+    cases = (((S, S, S), (S, S), (1, 2)),
+             ((S, S, S), (S, S), (-1, 2)),
+             ((S, S, S), (S, S), (-1, -1)),
+             ((S, S, S), (S, S), (1, -1)),
+             ((S,), (), (0, 2))
+             )
+
+    for input_shape, src_shape, args in cases:
+        input_ = make_arg(input_shape)
+        src = make_arg(src_shape)
+        yield SampleInput(input_, args=(src, *args))
+
+
+def sample_inputs_slice_scatter(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+
+    cases = (((L, L, L), (L, L, L,), (0, 0, L, 1)),
+             ((L, L, L), (L // 2, L, L,), (0, L // 2, L, 1)),
+             ((L, L, L), (L // 4, L, L,), (0, L // 2, L, 2)),
+             ((L, L, L), (L, L, L,), (1, 0, L, 1)),
+             ((L, L, L), (L, L // 2, L,), (1, L // 2, L, 1)),
+             ((L, L, L), (L, L // 4, L,), (1, L // 2, L, 2)),
+             ((L, L, L), (L, L, L,), (2, 0, L, 1)),
+             ((L, L, L), (L, L, L // 2,), (2, L // 2, L, 1)),
+             ((L, L, L), (L, L, L // 4,), (2, L // 2, L, 2)),
+             )
+
+    for input_shape, src_shape, args in cases:
+        input_ = make_arg(input_shape)
+        src = make_arg(src_shape)
+        yield SampleInput(input_, args=(src, *args))
+
+def sample_inputs_expand(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+
+    cases = (((S, 1, 1), (S, S, S)),
+             ((S, 1, S), (S, S, S)),
+             ((S, 1, S), (-1, S, -1)),
+             ((S, 1, S), (-1, S, S)),
+             ((S, 1), (S, S, S)),
+             ((1,), (S, S, S)),
+             ((1, S), (1, 1, S)),
+             ((), ()),
+             ((), (1, 3, 2)),
+             )
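+    # a size of -1 in the target shape means the corresponding dimension is
+    # left unchanged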
+
+    for case in cases:
+        shape, args = case
+        yield SampleInput(make_arg(shape), args=(args,))
+
+def sample_inputs_conversion(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+
+    shapes = ((),
+              (2, 3))
+    memory_format_options = [None, torch.contiguous_format]
+
+    for shape, memory_format in itertools.product(shapes, memory_format_options):
+        yield SampleInput(make_arg(shape),
+                          kwargs={'memory_format': memory_format} if memory_format else {})
+    yield SampleInput(make_arg((2, 3, 2, 3)), kwargs={'memory_format': torch.channels_last})
+
+def sample_inputs_byte(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, low=0, high=255, requires_grad=requires_grad)
+
+    shapes = ((),
+              (2, 3))
+    memory_format_options = [None, torch.contiguous_format]
+
+    for shape, memory_format in itertools.product(shapes, memory_format_options):
+        yield SampleInput(make_arg(shape),
+                          kwargs={'memory_format': memory_format} if memory_format else {})
+    yield SampleInput(make_arg((2, 3, 2, 3)), kwargs={'memory_format': torch.channels_last})
+
+def sample_inputs_expand_as(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device)
+
+    cases = (((S, 1, 1), (S, S, S)),
+             ((), ()),
+             ((), (1, 1)),
+             )
+
+    for shape, shape_other in cases:
+        yield SampleInput(make_arg(shape, requires_grad=requires_grad),
+                          args=(make_arg(shape_other, requires_grad=False),))
+
+
+def sample_inputs_where(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+
+    def make_bool_mask(shape):
+        # Make sure at least one element is nonzero,
+        # except for empty tensor
+        mask_t = make_tensor(shape, dtype=torch.bool, device=device, requires_grad=False)
+
+        if mask_t.numel() == 0:
+            return mask_t
+        elif mask_t.numel() == 1:
+            mask_t.fill_(True)
+            return mask_t
+
+        if mask_t.sum() == 0:
+            def random_index(shape):
+                return tuple(random.randrange(0, max_idx) for max_idx in shape)
+
+            mask_t[random_index(mask_t.shape)] = True
+            return mask_t
+
+        return mask_t
+
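+    # Each case is (input shape, cond shape, other shape, broadcasts_input); the flag marks
+    # samples whose input is broadcast, which in-place variants are expected to reject.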
+    cases = (((M, M), (M, M), (M, M), False),
+             ((M, 1, M), (M, M), (M, M, 1), True),
+             ((), (), (), False),
+             ((M, 1, M), (), (M, M, 1), True),
+             ((), (M, M), (), True),
+             ((), (2), (1, 1), True),
+             )
+
+    for shape, mask_shape, other_shape, broadcasts_input in cases:
+        yield SampleInput(make_arg(shape),
+                          args=(make_bool_mask(mask_shape), make_arg(other_shape)),
+                          broadcasts_input=broadcasts_input)
+
+# TODO: add reference inputs for where(condition) signature
+def reference_inputs_where(op, device, dtype, requires_grad, **kwargs):
+    yield from sample_inputs_where(op, device, dtype, requires_grad, **kwargs)
+
+    make_cond = partial(make_tensor, dtype=torch.bool, device=device, requires_grad=requires_grad)
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+
+    # noncontiguous
+    c = make_cond((10, 3), noncontiguous=True)
+    a = make_arg((10, 1), noncontiguous=True)
+    b = make_arg((3, 10, 3)).transpose(0, -1)
+
+    # NOTE that the OpInfo for where takes samples of the form a, cond, b
+    yield SampleInput(a, args=(c, b))
+
+    # MPS does not support float64, which causes issues in the following tests
+    if torch.device(device).type == "mps":
+        return
+
+    # type promoting
+    # FIXME(rec): shouldn't other_dtype be used two lines below?
+    other_dtype = torch.double if dtype is not torch.double else torch.long  # noqa: F841
+    c = make_cond((10, 3), noncontiguous=True)
+    a = make_arg((10, 1), dtype=torch.long)
+    b = make_arg((10, 1))
+
+    yield SampleInput(a, args=(c, b))
+
+    # two python scalars
+    c = make_cond((10, 3), noncontiguous=True)
+    a = make_arg((1,)).item()
+    b = make_arg((1,)).item()
+
+    yield SampleInput(a, args=(c, b))
+
+    # NaN propagation
+    if dtype.is_floating_point or dtype.is_complex:
+        if dtype.is_floating_point:
+            nan = float('nan')
+        else:
+            # dtype.is_complex
+            nan = complex(float('nan'), float('nan'))
+        c = make_cond((1, 10, 3))
+        a = make_arg((10, 3), noncontiguous=True)
+        a[2, 1] = nan
+        b = make_arg((1, 3))
+        b[0, 2] = nan
+
+        yield SampleInput(a, args=(c, b))
+
+    # Python scalars type promotion
+    for scalar in (0, 0.0, 2j, False):
+        yield SampleInput(scalar, args=(c, b))
+        yield SampleInput(a, args=(c, scalar))
+
+
+def error_inputs_where(op_info, device, **kwargs):
+    shape = (S,)
+    err_msg = "Expected all tensors to be on the same device"
+    for devices in product(('cpu', device), repeat=3):
+        if len(set(devices)) == 2:
+            si = SampleInput(make_tensor(shape, device=devices[0], dtype=torch.float32),
+                             args=(make_tensor(shape, dtype=torch.bool, device=devices[1]),
+                             make_tensor(shape, device=devices[2], dtype=torch.float32)))
+            yield ErrorInput(si, error_regex=err_msg)
+
+def sample_inputs_nonzero(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+
+    sizes = ((), (S,), (S, S), (S, S, S), (S, 1, S), (S, 0, S))
+
+    inputs = []
+    for shape in sizes:
+        # construct input without any non-zero elements
+        zeros = torch.zeros(shape, dtype=dtype, device=device, requires_grad=requires_grad)
+        inputs.append(zeros)
+
+        # construct input with mixed zero and non-zero elements
+        mixed = make_arg(shape).requires_grad_(False)
+        mask_t = make_tensor(shape, dtype=torch.bool, device=device, requires_grad=False)
+        mixed[mask_t] = 0
+        inputs.append(mixed)
+
+    for input_t, as_tuple in product(inputs, [False, True]):
+        yield SampleInput(input_t.clone().requires_grad_(requires_grad),
+                          kwargs=dict(as_tuple=as_tuple))
+
+def sample_inputs_nonzero_static(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+
+    sizes = ((), (S,), (S, S), (S, S, S), (S, 1, S), (S, 0, S))
+
+    inputs = []
+    for shape in sizes:
+        # construct input without any non-zero elements
+        zeros = torch.zeros(shape, dtype=dtype, device=device, requires_grad=requires_grad)
+        inputs.append(zeros)
+
+        # construct input with mixed zero and non-zero elements
+        mixed = make_arg(shape).requires_grad_(False)
+        mask_t = make_tensor(shape, dtype=torch.bool, device=device, requires_grad=False)
+        mixed[mask_t] = 0
+        inputs.append(mixed)
+
+    nonzero_sizes = [0, 1, XS, S, M]
+
+    for input_t, nonzero_size in product(inputs, nonzero_sizes):
+        yield SampleInput(input_t.clone().requires_grad_(requires_grad),
+                          kwargs=dict(size=nonzero_size))
+
+def sample_inputs_chunk(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+
+    cases = (((S, S, S), (2,)),
+             ((S, S, S), (S, 1)),
+             ((S, S, S), (S, -1)))
+
+    for case in cases:
+        shape, args = case
+        yield SampleInput(make_arg(shape), args=args)
+
+def reference_inputs_chunk(op, device, dtype, requires_grad, **kwargs):
+    yield from sample_inputs_chunk(op, device, dtype, requires_grad, **kwargs)
+
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+
+    # shape x chunks x dim
+    cases = (
+        ((13, 9, 11), 17, -1),
+        ((13, 9, 11), 11, -1),
+        ((13,), 12, -1),
+        ((15,), 12, -1),
+        ((15,), 7, 0),
+        ((15,), 9, 0),
+        ((3, 7), 9, 1),
+        ((3, 7), 9, 0),
+        ((3, 7), 2, 0),
+        ((3, 7), 3, 0),
+        ((3, 7), 1, 0),
+        ((3, 7), 1, 1),
+        ((4, 4), 2, 0),
+    )
+
+    for shape, chunks, dim in cases:
+        yield SampleInput(make_arg(shape), args=(chunks, dim))
+
+def sample_inputs_kthvalue(op_info, device, dtype, requires_grad, **kwargs):
+    def _tensor(shape, dtype=dtype, low=None, high=None):
+        return make_tensor(shape, dtype=dtype, device=device, low=low, high=high, requires_grad=requires_grad)
+
+    test_cases = [
+        ((S, S, S), (2,)),
+        ((S, S, S), (2, 1,)),
+        ((S, S, S), (2, -1,)),
+        ((S, S, S), (2, 1, True,)),
+        ((S, S, S), (2, -1, True,)),
+        ((S,), (2, 0,)),
+        ((S,), (2, 0, True,)),
+        ((), (1,)),
+        ((), (1, 0,)),
+        ((), (1, 0, True)),
+    ]
+
+    yield from (SampleInput(_tensor(tensor), *args) for tensor, args in test_cases)
+
+def error_inputs_kthvalue(op_info, device, **kwargs):
+    # tests that overlapping outputs fail
+    t = make_tensor(10, dtype=torch.float32, device=device)
+    indices = torch.empty((), device=device, dtype=torch.long)
+    yield ErrorInput(SampleInput(t, 5, out=(t, indices)),
+                     error_regex="unsupported operation")
+
+    k_out_of_range_err = "selected number k out of range for dimension"
+    yield ErrorInput(SampleInput(torch.randn(2, 2, device=device), 3, 0),
+                     error_regex=k_out_of_range_err)
+    yield ErrorInput(SampleInput(torch.randn(2, 2, device=device), 3),
+                     error_regex=k_out_of_range_err)
+    yield ErrorInput(SampleInput(torch.tensor(2, device=device), 3),
+                     error_regex=k_out_of_range_err)
+
+def sample_inputs_dropout(op_info, device, dtype, requires_grad, *,
+                          train=None, valid_input_dim=None, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    if valid_input_dim:
+        cases = ((S,) * i for i in valid_input_dim)
+    else:
+        cases = ((S, S), (S,), ())
+    p_vals = [0.0, 0.5, 1.0]
+    # This handles the special case of feature_alpha_dropout, which supports different
+    # dtypes depending on the `train` parameter
+    training_vals = [train] if train is not None else [True, False]
+
+    for case, p, training in product(cases, p_vals, training_vals):
+        yield SampleInput(make_arg(case), p=p, training=training)
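+    # Also exercise the default p/training arguments on the last generated shape.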
+    yield SampleInput(make_arg(case))
+
+def sample_inputs_dropout_backward(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_mask = partial(make_tensor, device=device, dtype=torch.bool, requires_grad=False)
+
+    cases = ((S, S, S, S), (S,), ())
+    scale_vals = [0.0, 1.0, 2.0]
+
+    for case, scale in product(cases, scale_vals):
+        yield SampleInput(make_arg(case), make_mask(case), scale)
+
+def sample_inputs_embedding_bag(op_info, device, dtype, requires_grad, **kwargs):
+    def make_input(shape):
+        return make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    def make_long_input(shape, *, low, high, noncontiguous=False):
+        return make_tensor(shape, device=device, dtype=torch.long, low=low, high=high,
+                           noncontiguous=noncontiguous)
+
+    def make_per_sample_weight(flag, idx):
+        # a tensor of float / double weights, or None
+        # to indicate all weights should be taken to be 1
+        if flag:
+            return make_input(idx.shape)
+        return None
+
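+    # `offsets` marks where each bag starts within a flat 1-D index tensor
+    # (here two bags: idx[0:3] and idx[3:]).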
+    offsets = torch.tensor([0, 3], device=device, dtype=torch.long)
+    for generate_per_sample_weight in (True, False):
+        for mode in ('sum', 'mean', 'max'):
+            # per_sample_weights is only supported for mode='sum' (got mode='****')
+            if generate_per_sample_weight and mode in ('mean', 'max'):
+                continue
+
+            # 1-D index tensor
+            idx = make_long_input((S,), low=0, high=M)
+            per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx)
+            yield SampleInput(make_input((M, S)), args=(idx,),
+                              kwargs={'offsets': offsets, 'mode': mode,
+                                      'per_sample_weights': per_sample_weights})
+
+            idx = make_long_input((S,), low=0, high=M, noncontiguous=True)
+            per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx)
+            yield SampleInput(make_input((M, S)), args=(idx,),
+                              kwargs={'offsets': offsets, 'mode': mode,
+                                      'per_sample_weights': per_sample_weights})
+
+            # bag with zero length
+            idx = make_long_input((S,), low=0, high=M, noncontiguous=True)
+            per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx)
+            yield SampleInput(make_input((M, S)), args=(idx,),
+                              kwargs={'offsets': torch.tensor([0, 0, 3], device=device, dtype=torch.long),
+                                      'mode': mode,
+                                      'per_sample_weights': per_sample_weights})
+
+            # 2-D index tensor
+            idx = make_long_input((S, S), low=0, high=M)
+            per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx)
+            yield SampleInput(make_input((M, S)), args=(idx,),
+                              kwargs={'mode': mode, 'per_sample_weights': per_sample_weights})
+
+            idx = make_long_input((S, S), low=0, high=M, noncontiguous=True)
+            per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx)
+            yield SampleInput(make_input((M, S)), args=(idx,),
+                              kwargs={'mode': mode, 'per_sample_weights': per_sample_weights})
+
+            # The gradient vector at `padding_idx` is not updated.
+            # Negative padding_idx
+            idx = make_long_input((6,), low=0, high=S)
+            idx[0] = 4
+            idx[4] = 4
+            per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx)
+            yield SampleInput(make_input((S, S)), args=(idx,),
+                              kwargs={'padding_idx': -1, 'offsets': offsets,
+                                      'mode': mode, 'per_sample_weights': per_sample_weights},)
+
+            idx = make_long_input((3, 3), low=0, high=S)
+            # Positive padding_idx
+            idx[0, 0] = 2
+            idx[1, 1] = 2
+            per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx)
+            yield SampleInput(make_input((S, S)), args=(idx,),
+                              kwargs={'padding_idx': 2, 'mode': mode,
+                                      'per_sample_weights': per_sample_weights},)
+
+            idx = make_long_input((6, ), low=0, high=S)
+            weights = make_input((S, S))
+            offsets_ = torch.tensor([0, 3, 6], device=device, dtype=torch.long)
+            per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx)
+            yield SampleInput(weights, args=(idx,),
+                              kwargs={'mode': mode, 'offsets': offsets_, 'include_last_offset': True},)
+
+            if not requires_grad:
+                # The following inputs return a gradient that differs from the numerical gradient.
+                # This is expected and relevant tests are present in `test_nn.py`.
+
+                # Due to inplace renorming of weight, the numerical gradient doesn't match the
+                # analytical gradient.
+                idx = make_long_input((2, 2), low=0, high=S)
+                weights = make_input((S, S)) * 2
+                per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx)
+                yield SampleInput(weights, args=(idx,),
+                                  kwargs={'max_norm': 1., 'mode': mode,
+                                          'per_sample_weights': per_sample_weights},)
+
+                idx = make_long_input((6, ), low=0, high=S)
+                weights = make_input((S, S)) * 2
+                per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx)
+                yield SampleInput(weights, args=(idx,),
+                                  kwargs={'max_norm': 1., 'norm_type': 1.0,
+                                          'mode': mode, 'offsets': offsets,
+                                          'per_sample_weights': per_sample_weights},)
+
+                if mode != 'max':
+                    # Scale the gradient based on the inverse frequency of a particular index.
+                    # Note : max mode does not support sparse weights
+                    idx = make_long_input((2, 2), low=0, high=S)
+                    idx[0, 0] = 1
+                    idx[0, 1] = 1
+                    weights = make_input((S, S))
+                    per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx)
+                    yield SampleInput(weights, args=(idx,),
+                                      kwargs={'scale_grad_by_freq': True, 'mode': mode,
+                                              'per_sample_weights': per_sample_weights},)
+
+                    # gradcheck not implemented for sparse tensors.
+                    # Note : max mode does not support sparse weights
+                    idx = make_long_input((6, ), low=0, high=S)
+                    weights = make_input((S, S))
+                    per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx)
+                    yield SampleInput(weights, args=(idx,),
+                                      kwargs={'sparse': True, 'offsets': offsets,
+                                              'mode': mode, 'per_sample_weights': per_sample_weights})
+
+                    idx = make_long_input((6, ), low=0, high=S)
+                    idx[0] = 1  # freq more than 1
+                    idx[1] = 1  # freq more than 1
+                    idx[3] = 0  # padding_idx
+                    weights = make_input((S, S)) * 2
+                    per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx)
+                    yield SampleInput(weights, args=(idx,),
+                                      kwargs={'sparse': True, 'scale_grad_by_freq': True, 'padding_idx': 0,
+                                              'max_norm': 1., 'offsets': offsets,
+                                              'mode': mode, 'per_sample_weights': per_sample_weights})
+
+
+def sample_inputs_embedding(op_info, device, dtype, requires_grad, **kwargs):
+    def make_input(shape):
+        return make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    def make_long_input(shape, *, low, high):
+        return make_tensor(shape, device=device, dtype=torch.long, low=low, high=high)
+
+    # 0-D index tensor
+    idx = make_long_input((), low=0, high=M)
+    yield SampleInput(make_input((M, S)), args=(idx,),)
+
+    # 1-D index tensor
+    idx = make_long_input((S,), low=0, high=M)
+    yield SampleInput(make_input((M, S)), args=(idx,),)
+
+    # 2-D index tensor
+    idx = make_long_input((S, S), low=0, high=M)
+    yield SampleInput(make_input((M, S)), args=(idx,),)
+
+    if not requires_grad:
+        # The following inputs return a gradient that differs from the numerical gradient.
+        # This is expected and relevant tests are present in `test_nn.py`.
+
+        # The gradient vector at `padding_idx` is not updated.
+        idx = make_long_input((2, 2), low=0, high=S)
+        idx[0, 0] = 2
+        idx[1, 1] = 2
+        yield SampleInput(make_input((S, S)), args=(idx,), kwargs={'padding_idx': 2},)
+
+        idx = make_long_input((2, 2), low=0, high=S)
+        idx[0, 0] = 4
+        idx[1, 1] = 4
+        yield SampleInput(make_input((S, S)), args=(idx,), kwargs={'padding_idx': -1},)
+
+        # Due to inplace renorming of weight, the numerical gradient doesn't match the
+        # analytical gradient.
+        idx = make_long_input((2, 2), low=0, high=S)
+        weights = make_input((S, S)) * 2
+        yield SampleInput(weights, args=(idx,), kwargs={'max_norm': 1.},)
+
+        idx = make_long_input((2, 2), low=0, high=S)
+        weights = make_input((S, S)) * 2
+        yield SampleInput(weights, args=(idx,), kwargs={'max_norm': 1., 'norm_type': 1.0},)
+
+        # Scale the gradient based on the inverse frequency of a particular index.
+        idx = make_long_input((2, 2), low=0, high=S)
+        idx[0, 0] = 1
+        idx[0, 1] = 1
+        weights = make_input((S, S))
+        yield SampleInput(weights, args=(idx,), kwargs={'scale_grad_by_freq': True},)
+
+        # gradcheck not implemented for sparse tensors.
+        idx = make_long_input((2, 2), low=0, high=S)
+        weights = make_input((S, S))
+        yield SampleInput(weights, args=(idx,), kwargs={'sparse': True})
+
+        idx = make_long_input((3, 3), low=0, high=S)
+        idx[0, 0] = 1  # freq more than 1
+        idx[0, 1] = 1  # freq more than 1
+        idx[1, 0] = 0  # padding_idx
+        weights = make_input((S, S)) * 2
+        yield SampleInput(weights, args=(idx,),
+                          kwargs={'sparse': True, 'scale_grad_by_freq': True,
+                                  'padding_idx': 0, 'max_norm': 1.})
+
+
+def sample_inputs_one_hot(op_info, device, dtype, requires_grad, **kwargs):
+    def make_input(shape, *, low, high):
+        return make_tensor(shape, device=device, dtype=dtype, low=low, high=high, requires_grad=requires_grad)
+
+    shapes = ((), (S,), (L, M, S))
+    num_classess = (-1, 10)
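+    # num_classes=-1 lets one_hot infer the number of classes from the largest index present.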
+
+    return (
+        SampleInput(
+            make_input(
+                shape,
+                low=0,
+                high=10 if num_classes == -1 else num_classes // 2,
+            ),
+            kwargs=dict(num_classes=num_classes),
+        )
+        for shape, num_classes in itertools.product(shapes, num_classess)
+    )
+
+
+def sample_inputs_loss(op_info, device, dtype, requires_grad, **kwargs):
+    rhs_requires_grad = kwargs.get('rhs_requires_grad', requires_grad)
+    _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # Although most losses also support the reduce and size_average combination instead of reduction, the former
+    # is deprecated since 0.4.1 and thus is not tested
+    shapes_and_kwargs = (
+        ((), None),
+        ((S,), dict(reduction="mean")),
+        ((S,), dict(reduction="sum")),
+        ((S,), dict(reduction="none")),
+        ((S, S), None),
+        ((S, S, S), None),
+    )
+
+    for shape, kwargs in shapes_and_kwargs:
+        yield SampleInput(_make_tensor(shape),
+                          args=(_make_tensor(shape, requires_grad=rhs_requires_grad),),
+                          kwargs=kwargs)
+
+def sample_inputs_grid_sample(op_info, device, dtype, requires_grad, **kwargs):
+    # Using a value range like [-2, 2] gives better tests: the "useful" range for the grid
+    # (the second tensor argument) is [-1, 1], so this produces a mix of in-range and
+    # out-of-range test cases
+    _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad,
+                           low=-2, high=2)
+
+    batch_size = 2
+    num_channels = 3
+    modes = ("bilinear", "nearest")
+    align_cornerss = (False, True)
+    padding_modes = ("zeros", "border", "reflection")
+
+    for dim in (2, 3):
+
+        modes_ = (*modes, "bicubic") if dim == 2 else modes
+
+        for mode, padding_mode, align_corners in itertools.product(modes_, padding_modes, align_cornerss):
+            yield SampleInput(
+                _make_tensor((batch_size, num_channels, *[S] * dim)),
+                _make_tensor((batch_size, *[S] * dim, dim)),
+                mode=mode,
+                padding_mode=padding_mode,
+                align_corners=align_corners,
+            )
+
+def reference_inputs_grid_sample(op_info, device, dtype, requires_grad, **kwargs):
+
+    batch_size = 2
+    num_channels = 3
+    height = 345
+    width = 456
+    modes = ("bilinear", "nearest", "bicubic")
+    align_cornerss = (False, True)
+    padding_modes = ('zeros', 'border', 'reflection')
+
+    # Create an affine transformation matrix
+    a = torch.deg2rad(torch.tensor(45.0))
+    ca, sa = torch.cos(a), torch.sin(a)  # rotation angles
+    s1, s2 = 1.23, 1.34  # scales
+
+    theta = torch.tensor([[
+        [ca / s1, sa, 0.0],
+        [-sa, ca / s2, 0.0],
+    ]], dtype=dtype, device=device)
+    theta = theta.expand(batch_size, 2, 3).contiguous()
+
+    x = torch.arange(batch_size * num_channels * height * width, device=device)
+    x = x.reshape(batch_size, num_channels, height, width).to(torch.uint8)
+    x = x.to(dtype=dtype)
+    x.requires_grad_(requires_grad)
+
+    for mode, padding_mode, align_corners in itertools.product(modes, padding_modes, align_cornerss):
+        grid = torch.nn.functional.affine_grid(
+            theta, size=(batch_size, num_channels, height, width), align_corners=align_corners
+        )
+        yield SampleInput(
+            x,
+            grid,
+            mode,
+            padding_mode,
+            align_corners,
+        )
+
+def sample_inputs_grid_sampler_2d(op_info, device, dtype, requires_grad, **kwargs):
+    # Using a value range like [-2, 2] gives better tests: the "useful" range for the grid
+    # (the second tensor argument) is [-1, 1], so this produces a mix of in-range and
+    # out-of-range test cases
+    _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad,
+                           low=-2, high=2)
+
+    batch_size = 2
+    num_channels = 3
+    modes = (0, 1, 2)
+    align_cornerss = (False, True)
+    padding_modes = (0, 1, 2)
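+    # The low-level grid_sampler_2d op takes integer codes: interpolation
+    # 0=bilinear, 1=nearest, 2=bicubic; padding 0=zeros, 1=border, 2=reflection.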
+
+    for mode, padding_mode, align_corners in itertools.product(modes, padding_modes, align_cornerss):
+        yield SampleInput(
+            _make_tensor((batch_size, num_channels, S, L)),
+            _make_tensor((batch_size, M + 3, M, 2)),
+            mode,
+            padding_mode,
+            align_corners,
+        )
+
+def sample_inputs_cosine_embedding_loss(op_info, device, dtype, requires_grad, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    def make_target(shape):
+        shape = () if len(shape) == 1 else (shape[0], )
+        t = torch.randint(0, 2, shape, device=device, dtype=torch.long)
+        # Label with -1 or 1
+        t = t * 2 - 1
+        target = t.to(dtype=dtype).detach_().requires_grad_(requires_grad)
+        return target
+
+    shapes = ((S, S), (S,))
+    reductions = ('none', 'mean', 'sum')
+    for s, r in product(shapes, reductions):
+        yield SampleInput(
+            make_input(s),
+            args=(make_input(s), make_target(s)),
+            kwargs=dict(reduction=r, margin=random.uniform(-1, 1))
+        )
+
+def sample_inputs_ctc_loss(op_info, device, dtype, requires_grad, **kwargs):
+    input_length = 50
+    batch = 16
+    num_char = 20
+    target_length = 30
+
+    def make_log_probs(s):
+        t = make_tensor(s, device=device, dtype=dtype)
+        log_probs = t.log_softmax(2).to(device=device, dtype=dtype).detach().requires_grad_(requires_grad=requires_grad)
+        return log_probs
+
+    reductions = ('none', 'mean', 'sum')
+    zero_inf = (True, False)
+    lengths_type = (list, torch.Tensor)
+    for r, z, lt in product(reductions, zero_inf, lengths_type):
+        log_probs = make_log_probs((input_length, batch, num_char))
+        targets = torch.randint(1, num_char, (batch, target_length), dtype=torch.long, device=device)
+        input_lengths = torch.full((batch, ), input_length, dtype=torch.long, device=device)
+        target_lengths = torch.randint(10, target_length, (batch, ), dtype=torch.long, device=device)
+
+        # Don't generate int[] types if reduction == "mean" since this results in non-composite-compliant calls
+        # to ctc_loss.IntList since a tensor needs to be created from the target lengths.
+        # Creating such a tensor requires the use of pointers to copy data from int[] -> torch.Tensor
+        # e.g. via std::copy. Similarly symbolic/real tracing with fx will also not work
+        if lt is list and r in ["none", "sum"]:
+            input_lengths = input_lengths.tolist()
+            target_lengths = target_lengths.tolist()
+
+        yield SampleInput(log_probs, args=(targets, input_lengths, target_lengths,),
+                          kwargs=dict(reduction=r, zero_infinity=z))
+
+
+def sample_inputs_nll_loss(op_info, device, dtype, requires_grad, **kwargs):
+    shape = (2, 3)
+    num_classes = shape[1]
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    # FIXME: Derivative wrt. weight not implemented
+    make_weight = partial(make_tensor, num_classes, device=device, dtype=dtype, requires_grad=False)
+
+    def make_target(shape, zeros=False):
+        s = (shape[0], *shape[2:]) if len(shape) > 1 else ()
+        if zeros:
+            return torch.zeros(s, device=device, dtype=torch.long)
+        else:
+            return make_tensor(s,
+                               low=0,
+                               high=shape[1] if len(shape) > 1 else shape[0],
+                               device=device,
+                               dtype=torch.long)
+
+
+    def gen_shape_kwargs():
+        # Batched, non-batched and 2d
+        shapes = (shape, (num_classes,), shape + (2, 2))
+        reductions = ('none', 'mean', 'sum')
+        for reduction, s in product(reductions, shapes):
+            yield make_input(s), make_target(s), dict(reduction=reduction)
+            yield make_input(s), make_target(s), dict(weight=make_weight(), reduction=reduction)
+            yield make_input(s), make_target(s), dict(weight=make_weight(low=0), reduction=reduction)
+            yield make_input(s), make_target(s), dict(weight=make_weight(high=0), reduction=reduction)
+            t = make_target(s)
+            ignore = num_classes // 2
+            # If "mean", nll returns NaN, so it's not differentiable at those points
+            if t.eq(ignore).all() and reduction == "mean":
+                t.fill_(0)
+            yield make_input(s), t, dict(ignore_index=num_classes // 2, reduction=reduction)
+            yield make_input(s), t, dict(ignore_index=num_classes // 2, reduction=reduction, weight=make_weight())
+            # Test ignoring all the targets
+            # If "mean", nll returns NaN, so it's not differentiable at those points
+            if reduction != "mean":
+                yield make_input(s), make_target(s, zeros=True), dict(ignore_index=0, reduction=reduction)
+
+    for input, target, kwargs in gen_shape_kwargs():
+        yield SampleInput(input, args=(target,), kwargs=kwargs)
+
+    target = torch.tensor([-1, 2], device=device, dtype=torch.long)
+    yield SampleInput(make_input(shape), args=(target,), kwargs={'ignore_index': -1})
+
+
+def sample_inputs_binary_cross_entropy_with_logits(
+    op_info, device, dtype, requires_grad, **kwargs
+):
+    make = partial(make_tensor, device=device, dtype=dtype)
+    make_prob = partial(make, low=0, high=1)
+    reductions = ("mean", "sum", "none")
+
+    def make_weight_shape_kwargs():
+        kwargs = []
+        for shape in ((1,), (1, S), (S,), (S, S)):
+            kwargs.extend([((S, S), dict(reduction=reduction, weight=make(shape))) for reduction in reductions])
+        return kwargs
+
+    shapes_and_kwargs = [
+        *[(shape, None) for shape in ((), (1,), (S,), (S, S), (S, S, S))],
+        *[((S, S), dict(reduction=reduction)) for reduction in reductions],
+        *make_weight_shape_kwargs(),
+        *[((S, S), dict(reduction=reduction, pos_weight=make((S,), low=0))) for reduction in reductions],
+        *[((S, S), dict(reduction=reduction, weight=make((S, S)), pos_weight=make((S,), low=0))) for reduction in reductions],
+    ]
+
+    for shape, kwargs in shapes_and_kwargs:
+        yield SampleInput(
+            make(shape, requires_grad=requires_grad),
+            args=(make_prob(shape, requires_grad=requires_grad),),
+            kwargs=kwargs,
+        )
+
+def sample_inputs_argwhere(op_info, device, dtype, requires_grad, **kwargs):
+    yield SampleInput(torch.tensor([1, 0, 2, 0], dtype=dtype, device=device, requires_grad=requires_grad))
+    mask = torch.tensor([[0, 1, 0, 1, 0],
+                         [1, 1, 1, 1, 0],
+                         [0, 0, 0, 1, 0],
+                         [1, 0, 1, 1, 0],
+                         [1, 0, 0, 1, 0]], dtype=torch.bool, device=device)
+    t = make_tensor((S, S), dtype=dtype, device=device, requires_grad=requires_grad)
+    t[mask] = 0
+    yield SampleInput(t)
+
+    t = make_tensor((S, S), dtype=dtype, device=device, requires_grad=requires_grad, noncontiguous=True)
+    t[mask] = 0
+    yield SampleInput(t)
+
+    t = make_tensor((S, 0), dtype=dtype, device=device, requires_grad=requires_grad)
+    yield SampleInput(t)
+
+    yield SampleInput(torch.zeros((S,), dtype=dtype, device=device, requires_grad=requires_grad))
+    yield SampleInput(make_tensor((), dtype=dtype, device=device, requires_grad=requires_grad))
+
+def _generate_sample_shape_reduction():
+    shapes = ((S,), (S, S), (S, S, S))
+    reductions = ('none', 'mean', 'sum')
+    yield from product(shapes, reductions)
+
+def sample_inputs_gaussian_nll_loss(op_info, device, dtype, requires_grad, **kwargs):
+    _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    # Set low slightly above 0 so gradcheck doesn't accidentally dip below 0
+    make_var = partial(make_tensor, low=0.1, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    def gen_shape(shape):
+        yield shape
+        # Broadcast
+        yield (*shape[:-1], 1)
+        yield shape[:-1]
+
+    def gen_shape_kwargs():
+        for s, r in _generate_sample_shape_reduction():
+            for t_s, v_s in product(gen_shape(s), gen_shape(s)):
+                yield _make_tensor(s), _make_tensor(t_s), make_var(v_s), dict(reduction=r)
+                yield (
+                    _make_tensor(s), _make_tensor(t_s), make_var(v_s),
+                    dict(full=True, reduction=r)
+                )
+                yield (
+                    _make_tensor(s), _make_tensor(t_s), make_var(v_s),
+                    dict(eps=random.uniform(1e-6, 1e-3), reduction=r)
+                )
+                yield (
+                    _make_tensor(s), _make_tensor(t_s), make_var(v_s),
+                    dict(full=True, eps=random.uniform(1e-6, 1e-3), reduction=r)
+                )
+
+    for input, target, var, kwargs in gen_shape_kwargs():
+        yield SampleInput(input, args=(target, var, ), kwargs=kwargs)
+
+def error_inputs_gaussian_nll_loss(op_info, device, **kwargs):
+    _make = partial(make_tensor, device=device, dtype=torch.float32)
+
+    # invalid reduction value
+    yield ErrorInput(SampleInput(_make(10, 2, 3), _make(10, 2, 3), _make((10, 2, 3), low=0), reduction="abc"),
+                     error_type=ValueError, error_regex="abc is not valid")
+
+    # var is of incorrect shape
+    yield ErrorInput(SampleInput(_make(10, 2, 3), _make(10, 2, 3), _make((10, 2, 2), low=0)),
+                     error_type=ValueError, error_regex="var is of incorrect size")
+
+    # target is of incorrect shape
+    yield ErrorInput(SampleInput(_make(10, 2, 3), _make(10, 2, 2), _make((10, 2, 3), low=0)),
+                     error_type=RuntimeError,
+                     error_regex=(r"The size of tensor a \(3\) must match the size of tensor b \(2\) "
+                                  r"at non-singleton dimension 2"))
+
+def _generate_sample_inputs_nn_loss(op_info, device, dtype, requires_grad, **kwargs):
+    _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    for s, r in _generate_sample_shape_reduction():
+        yield _make_tensor(s), _make_tensor(s), dict(reduction=r)
+
+def sample_inputs_hinge_embedding_loss(op_info, device, dtype, requires_grad, **kwargs):
+    for input, target, d in _generate_sample_inputs_nn_loss(op_info, device, dtype, requires_grad, **kwargs):
+        # target should contain either 1 or -1 as per docs
+        mask = torch.rand_like(target) > 0.5
+        target[mask] = 1
+        target[~mask] = -1
+        d['margin'] = random.uniform(-9, 9)
+        yield SampleInput(input, args=(target, ), kwargs=d)
+
+    # scalar input and target.
+    _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    yield SampleInput(_make_tensor(()), args=(_make_tensor(()), ))
+
+def error_inputs_hinge_embedding_loss(op, device, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=torch.float32)
+    # invalid reduction value
+    yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4),), kwargs={'reduction': 'abc'}),
+                     error_type=ValueError, error_regex='is not a valid value')
+
+def reference_inputs_hinge_embedding_loss(op, device, dtype, requires_grad, **kwargs):
+    yield from sample_inputs_hinge_embedding_loss(op, device, dtype, requires_grad, **kwargs)
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    for reduction in ('sum', 'mean', 'none'):
+        if dtype.is_floating_point:  # only supports ints and floats
+            # NaN propagation
+            inp = make_input((10, ))
+            inp[2] = float('nan')
+            target = make_input((10, ))
+            # target should contain either 1 or -1 as per docs
+            mask = torch.rand_like(target) > 0.5
+            target[mask] = -1
+            target[~mask] = 1
+            yield SampleInput(inp, args=(target,), kwargs={'reduction': reduction})
+
+            # Inf Handling
+            inp = make_input((10, ))
+            inp[4] = float('inf')
+            target = make_input((10, ))
+            mask = torch.rand_like(target) > 0.5
+            target[mask] = -1
+            target[~mask] = 1
+            yield SampleInput(inp, args=(target,), kwargs={'reduction': reduction})
+
+        # Broadcasting
+        inp = make_input((5, 5))
+        target = make_input((1, 5))
+        mask = torch.rand_like(target) > 0.5
+        target[mask] = -1
+        target[~mask] = 1
+        yield SampleInput(inp, args=(target,), kwargs={'reduction': reduction})
+
+def sample_inputs_huber_loss(op_info, device, dtype, requires_grad, **kwargs):
+    for input, target, d in _generate_sample_inputs_nn_loss(op_info, device, dtype, requires_grad, **kwargs):
+        d['delta'] = random.uniform(1e-3, 9)
+        yield SampleInput(input, args=(target, ), kwargs=d)
+
+def error_inputs_huber_loss(op, device, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=torch.float32)
+    # invalid reduction value
+    err = 'is not a valid value for reduction'
+    yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4),), kwargs={'reduction': 'abc'}),
+                     error_type=ValueError, error_regex=err)
+    # delta <= 0
+    for delta in (0, -1):
+        err = 'huber_loss does not support non-positive values for delta.'
+        yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4),), kwargs={'delta': delta}),
+                         error_type=RuntimeError, error_regex=err)
+
+def sample_inputs_poisson_nll_loss(op_info, device, dtype, requires_grad, **kwargs):
+    _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    def gen_shape_kwargs():
+        for s, r in _generate_sample_shape_reduction():
+            for li in (True, False):
+                for f in (True, False):
+                    i1 = _make_tensor(s)
+                    i2 = _make_tensor(s)
+                    # For Poisson NLL loss, the target is assumed to come from a
+                    # Poisson distribution, whose samples are always non-negative
+                    t1 = _make_tensor(s, low=0)
+                    t2 = _make_tensor(s, low=0)
+
+                    if not li:
+                        i1.abs_()
+                        i2.abs_()
+                    t1.abs_()
+                    t2.abs_()
+
+                    yield (
+                        i1, t1,
+                        dict(log_input=li, full=f, reduction=r)
+                    )
+                    yield (
+                        i2, t2,
+                        dict(log_input=li, full=f,
+                             eps=random.uniform(1e-8, 1e-3),
+                             reduction=r)
+                    )
+
+    for input, target, kwargs in gen_shape_kwargs():
+        yield SampleInput(input, args=(target, ), kwargs=kwargs)
+
+    # test INT_TO_FLOAT promotion
+    if dtype.is_complex:
+        for d in (torch.bool, torch.int64):
+            yield SampleInput(_make_tensor(dtype=dtype), args=(_make_tensor(dtype=d),))
+            yield SampleInput(_make_tensor(dtype=d), args=(_make_tensor(dtype=dtype),))
+
+def error_inputs_poisson_nll_loss(op_info, device, **kwargs):
+    make = partial(make_tensor, device=device, dtype=torch.float32)
+
+    # invalid reduction value
+    yield ErrorInput(SampleInput(make(5, 4), args=(make(5, 4),),
+                     kwargs={'reduction': 'abc'}),
+                     error_type=ValueError,
+                     error_regex='abc is not a valid value for reduction')
+    # invalid input shapes
+    yield ErrorInput(SampleInput(make(5, 4), args=(make(5,),)),
+                     error_regex=(r'(Attempting to broadcast a dimension of length|'
+                                  r'The size of tensor a \(5\) must match the '
+                                  r'size of tensor b \(4\) at non-singleton '
+                                  r'dimension 1)'))
+
+def error_inputs_soft_margin_loss(op_info, device, **kwargs):
+    make = partial(make_tensor, device=device, dtype=torch.float32)
+
+    # invalid reduction value
+    yield ErrorInput(SampleInput(make(5, 4), args=(make(5, 4),),
+                     kwargs={'reduction': 'abc'}),
+                     error_type=ValueError,
+                     error_regex='abc is not a valid value for reduction')
+    # invalid input shapes
+    yield ErrorInput(SampleInput(make(5, 4), args=(make(5,),)),
+                     error_regex=(r'(Attempting to broadcast a dimension of length|'
+                                  r'The size of tensor a \(4\) must match the '
+                                  r'size of tensor b \(5\) at non-singleton '
+                                  r'dimension 1)'))
+
+def sample_inputs_triplet_margin_loss(op_info, device, dtype, requires_grad, with_distance=False, **kwargs):
+    make = partial(make_tensor, (S, M), device=device, dtype=dtype, requires_grad=requires_grad)
+
+    kwargss = (
+        *[dict(margin=margin) for margin in (1e-6, 1.0, 10.0)],
+        dict(swap=True),
+        *[dict(reduction=reduction) for reduction in ("mean", "sum", "none")],
+    )
+
+    for kwargs in kwargss:
+        input = make()
+        args = (make(), make())
+        if with_distance:
+            kwargs["distance_function"] = torch.nn.PairwiseDistance()
+        yield SampleInput(input, args=args, kwargs=kwargs)
+
+def error_inputs_triplet_margin_loss(op_info, device, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=torch.float32)
+
+    samples = (
+        # input, args, kwargs, error_type, error_regex
+        # invalid reduction
+        (make_input(3, 4), (make_input(3, 4), make_input(3, 4)),
+         dict(reduction="abc"),
+         ValueError, "abc is not a valid value for reduction"),
+
+        # invalid margin
+        (make_input(3, 4), (make_input(3, 4), make_input(3, 4)),
+         dict(margin=-1.0),
+         ValueError, "margin must be greater than 0, got -1.0"),
+
+        # shape mismatch
+        (make_input(3, 5), (make_input(3, 4), make_input(3, 4)),
+         {},
+         RuntimeError,
+         (r'(Attempting to broadcast a dimension of length|'
+          r"The size of tensor a \(5\) must match the size of tensor b \(4\) "
+          r"at non-singleton dimension 1)")),
+        (make_input(3, 4), (make_input(3, 5), make_input(3, 4)),
+         {},
+         RuntimeError,
+         (r'(Attempting to broadcast a dimension of length|'
+          r"The size of tensor a \(4\) must match the size of tensor b \(5\) "
+          r"at non-singleton dimension 1)")),
+        (make_input(3, 4), (make_input(3, 4), make_input(3, 5)),
+         {},
+         RuntimeError,
+         (r'(Attempting to broadcast a dimension of length|'
+          r"The size of tensor a \(4\) must match the size of tensor b \(5\) "
+          r"at non-singleton dimension 1)")),
+
+        # different dimensions
+        (make_input(3,), (make_input(3, 4), make_input(3, 4)),
+         {},
+         RuntimeError,
+         (r"The anchor, positive, and negative tensors are expected to have "
+          r"the same number of dimensions, but got: anchor 1D, positive 2D, "
+          r"and negative 2D inputs")),
+        (make_input(3, 4), (make_input(3,), make_input(3, 4)),
+         {},
+         RuntimeError,
+         (r"The anchor, positive, and negative tensors are expected to have "
+          r"the same number of dimensions, but got: anchor 2D, positive 1D, "
+          r"and negative 2D inputs")),
+        (make_input(3, 4), (make_input(3, 4), make_input(3,)),
+         {},
+         RuntimeError,
+         (r"The anchor, positive, and negative tensors are expected to have "
+          r"the same number of dimensions, but got: anchor 2D, positive 2D, "
+          r"and negative 1D inputs")),
+    )
+
+    for input, args, kwargs, error_type, error_regex in samples:
+        yield ErrorInput(SampleInput(input, args=args, kwargs=kwargs),
+                         error_type=error_type, error_regex=error_regex)
+
+def sample_inputs_scaled_mm(op_info, device, dtype, requires_grad, **kwargs):
+    make_mat_e4m3 = partial(make_tensor, device=device, dtype=torch.float8_e4m3fn, requires_grad=requires_grad)
+    make_mat_e5m2 = partial(make_tensor, device=device, dtype=torch.float8_e5m2, requires_grad=requires_grad)
+    make_scale = partial(make_tensor, device=device, dtype=torch.float, requires_grad=False)
+    M, N, K = 15, 32, 16
+    samples = []
+    # two e4m3
+    mat1 = make_mat_e4m3((M, K))
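+    # .t().contiguous().t() makes mat2 column-major, the layout _scaled_mm's backend expects.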
+    mat2 = make_mat_e4m3((K, N)).t().contiguous().t()
+    scale1 = make_scale((1,))
+    scale2 = make_scale((1,))
+    samples.append(SampleInput(mat1, mat2, scale1, scale2))
+    # mat1 e4m3 mat2 e5m2
+    mat1 = make_mat_e4m3((M, K))
+    mat2 = make_mat_e5m2((K, N)).t().contiguous().t()
+    scale1 = make_scale((1,))
+    scale2 = make_scale((1,))
+    samples.append(SampleInput(mat1, mat2, scale1, scale2))
+    # mat1 e5m2 mat2 e4m3
+    mat1 = make_mat_e5m2((M, K))
+    mat2 = make_mat_e4m3((K, N)).t().contiguous().t()
+    scale1 = make_scale((1,))
+    scale2 = make_scale((1,))
+    samples.append(SampleInput(mat1, mat2, scale1, scale2))
+
+    yield from samples
+
+def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_grad, **kwargs):
+    make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    batch, seq_q, seq_kv, num_heads, head_dim = 4, 3, 6, 4, 8
+
+    dim_3_q_shape = (batch, seq_q, head_dim)
+    dim_3_kv_shape = (batch, seq_kv, head_dim)
+    dim_4_q_shape = (batch, num_heads, seq_q, head_dim)
+    dim_4_kv_shape = (batch, num_heads, seq_kv, head_dim)
+
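+    # Query without a batch dim vs. batched key/value, to exercise broadcasting in SDPA.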
+    broadcast_tuple = ((num_heads, seq_q, head_dim), (batch, num_heads, seq_kv, head_dim))
+
+    qkv_shapes = [(dim_3_q_shape, dim_3_kv_shape), (dim_4_q_shape, dim_4_kv_shape), broadcast_tuple]
+    samples = []
+    gqa_options = [True, False]
+    causal_options = [True, False]
+    for qkv_shape, is_causal, dropout_p, _enable_gqa in product(
+            qkv_shapes, causal_options, [0.0, 0.5], gqa_options):
+        shape_q, shape_kv = qkv_shape
+        samples.append(SampleInput(
+            make(shape_q),
+            make(shape_kv),
+            make(shape_kv),
+            is_causal=is_causal,
+            dropout_p=dropout_p
+        ))
+
+    # Add non standard shapes
+    # FIXME(rec): should diff_v_head_dim be appended to samples?
+    diff_v_head_dim = SampleInput(  # noqa: F841
+        make((batch, num_heads, seq_q, head_dim)),
+        make((batch, num_heads, seq_kv, head_dim)),
+        make((batch, num_heads, seq_kv, head_dim + 8)),
+        is_causal=is_causal,
+        dropout_p=dropout_p
+    )
+
+    # Add an attn_mask
+    samples.append(
+        SampleInput(
+            make((batch, num_heads, seq_q, head_dim)),
+            make((batch, num_heads, seq_kv, head_dim)),
+            make((batch, num_heads, seq_kv, head_dim)),
+            attn_mask=make((seq_q, seq_kv)),
+            is_causal=False,
+            dropout_p=0.0)
+    )
+
+    yield from samples
+
+
+def sample_inputs_efficient_attention_forward(op_info, device, dtype, requires_grad, **kwargs):
+    make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    batch, num_heads, head_dim = 4, 4, 8
+    seq_q = 11
+    seq_kv = 32
+
+    dim_4_q_shape = (batch, num_heads, seq_q, head_dim)
+    dim_4_kv_shape = (batch, num_heads, seq_kv, head_dim)
+
+    qkv_shapes = [(dim_4_q_shape, dim_4_kv_shape)]
+    samples = []
+    mask_types = [1, 2]  # UpperLeft, LowerRight
+    scales = [None, 1.0]
+
+    for qkv_shape, _is_causal, dropout_p, mask_type, scale in product(
+            qkv_shapes, [True, False], [0.0, 0.5], mask_types, scales):
+        shape_q, shape_kv = qkv_shape
+        samples.append(SampleInput(
+            make(shape_q).transpose(1, 2),
+            make(shape_kv).transpose(1, 2),
+            make(shape_kv).transpose(1, 2),
+            bias=None,
+            cu_seqlens_q=None,
+            cu_seqlens_k=None,
+            max_seqlen_q=None,
+            max_seqlen_k=None,
+            dropout_p=dropout_p,
+            custom_mask_type=mask_type,
+            compute_log_sumexp=requires_grad,
+            scale=scale,
+            seqlen_k=None
+        ))
+
+    # Add non standard shapes
+    # FIXME(rec): should diff_v_head_dim be appended to samples?
+    diff_v_head_dim = SampleInput(  # noqa: F841
+        make((batch, seq_q, num_heads, head_dim)),
+        make((batch, seq_kv, num_heads, head_dim)),
+        make((batch, seq_kv, num_heads, head_dim + 8)),
+        bias=None,
+        cu_seqlens_q=None,
+        cu_seqlens_k=None,
+        max_seqlen_q=None,
+        max_seqlen_k=None,
+        dropout_p=dropout_p,
+        custom_mask_type=0,  # No Mask
+        compute_log_sumexp=requires_grad,
+        scale=None,
+        seqlen_k=None
+    )
+
+    # Add an attn_mask
+    samples.append(
+        SampleInput(
+            make((batch, seq_q, num_heads, head_dim)),
+            make((batch, seq_kv, num_heads, head_dim)),
+            make((batch, seq_kv, num_heads, head_dim)),
+            bias=make(batch, num_heads, seq_q, seq_kv),
+            cu_seqlens_q=None,
+            cu_seqlens_k=None,
+            max_seqlen_q=None,
+            max_seqlen_k=None,
+            dropout_p=dropout_p,
+            custom_mask_type=0,  # No Mask
+            compute_log_sumexp=requires_grad,
+            scale=None,
+            seqlen_k=None
+        )
+    )
+
+    # jagged (with query/keys offsets)
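+    # cu_seqlens_q/cu_seqlens_k are cumulative sequence lengths delimiting the
+    # variable-length sequences packed along the token dimension.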
+    cu_seqlens_k = torch.arange(-1, 32 * 2 + 1, 2, dtype=torch.int32, device=device)
+    cu_seqlens_k[-1] = 62
+    cu_seqlens_k[0] = 0
+    samples.append(
+        SampleInput(
+            make((32, 2, 64)).view(-1, 8, 8).unsqueeze(0),
+            make((64, 64)).view(-1, 8, 8).unsqueeze(0),
+            make((64, 64)).view(-1, 8, 8).unsqueeze(0),
+            bias=None,
+            cu_seqlens_q=torch.arange(0, 32 * 2 + 2, 2, dtype=torch.int32, device=device),
+            cu_seqlens_k=cu_seqlens_k,
+            max_seqlen_q=2,
+            max_seqlen_k=2,
+            dropout_p=0.0,
+            custom_mask_type=0,  # No Mask
+            compute_log_sumexp=requires_grad,
+            scale=None,
+            seqlen_k=None,
+        )
+    )
+
+    yield from samples
+
+def sample_inputs_flash_attention_forward(op_info, device, dtype, requires_grad, **kwargs):
+    make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    batch, num_heads, head_dim = 4, 4, 8
+    seq_q = 11
+    seq_kv = 32
+
+    dim_4_q_shape = (batch, num_heads, seq_q, head_dim)
+    dim_4_kv_shape = (batch, num_heads, seq_kv, head_dim)
+
+    qkv_shapes = [(dim_4_q_shape, dim_4_kv_shape)]
+    samples = []
+    scales = [None, 1.0]
+
+    for qkv_shape, is_causal, dropout_p, scale in product(
+            qkv_shapes, [True, False], [0.0, 0.5], scales):
+        shape_q, shape_kv = qkv_shape
+        samples.append(SampleInput(
+            make(shape_q).transpose(1, 2),
+            make(shape_kv).transpose(1, 2),
+            make(shape_kv).transpose(1, 2),
+            cum_seq_q=None,
+            cum_seq_k=None,
+            max_q=seq_q,
+            max_k=seq_kv,
+            dropout_p=dropout_p,
+            is_causal=is_causal,
+            return_debug_mask=False,
+            scale=scale,
+        ))
+
+    yield from samples
+
+def sample_inputs_pairwise_distance(op_info, device, dtype, requires_grad, **kwargs):
+    make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    shape = (3,)
+    batched_shape = (2, *shape)
+    shapes_and_kwargs = [
+        (shape, None),
+        (batched_shape, None),
+        (shape, dict(keepdim=True)),
+        (batched_shape, dict(keepdim=True)),
+        (shape, dict(p=5.0)),
+        (shape, dict(p=-1.0)),
+        (shape, dict(eps=1.0)),
+    ]
+
+    return (
+        SampleInput(make(shape), args=(make(shape),), kwargs=kwargs) for shape, kwargs in shapes_and_kwargs
+    )
+
+def sample_inputs_pixel_shuffle(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    yield from (
+        SampleInput(make_arg((1, 9, 2, 2)), upscale_factor=upscale_factor)
+        for upscale_factor in (1, 3)
+    )
+    yield from (
+        SampleInput(make_arg(shape), upscale_factor=1)
+        for shape in [
+            (1, 0, 1, 1),
+            (1, 1, 0, 1),
+            (1, 1, 1, 0),
+        ]
+    )
+
+def sample_inputs_pixel_unshuffle(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    yield from (
+        SampleInput(make_arg((1, 1, 6, 6)), downscale_factor=downscale_factor)
+        for downscale_factor in (1, 3)
+    )
+    yield from (
+        SampleInput(make_arg(shape), downscale_factor=1)
+        for shape in [
+            (1, 0, 1, 1),
+            (1, 1, 0, 1),
+            (1, 1, 1, 0),
+        ]
+    )
+
+def sample_inputs_channel_shuffle(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    shapes_groups = [
+        ((1, 4, 10, 10), 2),
+        ((2, 6, 8, 8), 3),
+        ((2, 8, 5, 5), 4),
+    ]
+
+    yield from (
+        SampleInput(make_arg(shape), args=(groups,))
+        for shape, groups in shapes_groups
+    )
+
+def sample_inputs_binary_cross_entropy(op_info, device, dtype, requires_grad, logits=False, **kwargs):
+    make = partial(make_tensor, device=device, dtype=dtype)
+    # Lower bounds must be greater than the 'eps' used in gradcheck.py::gradgradcheck();
+    # otherwise the perturbation calculation can make tensor values negative, triggering
+    # a device-side hardware assertion
+    make_prob = partial(make, low=1e-6, high=1)
+
+    reductions = ("mean", "sum", "none")
+
+    shapes_and_kwargs = [
+        *[(shape, None) for shape in ((), (1,), (S,), (S, S), (S, S, S))],
+        *[((S, S), dict(reduction=reduction)) for reduction in reductions],
+        *[((S, S), dict(reduction=reduction, weight=make((S, S)))) for reduction in reductions],
+    ]
+
+    if logits:
+        shapes_and_kwargs.extend(
+            [((S, S), dict(reduction=reduction, pos_weight=make((S,), low=0))) for reduction in reductions]
+        )
+
+    for shape, kwargs in shapes_and_kwargs:
+        yield SampleInput(
+            (make if logits else make_prob)(shape, requires_grad=requires_grad),
+            args=(make_prob(shape, requires_grad=requires_grad),),
+            kwargs=kwargs,
+        )
+
+def sample_inputs_allclose(op_info, device, dtype, requires_grad, **kwargs):
+    sample_shapes = [(), (S,), (S, S, S)]
+    atols = [1e-2, 1e-16]
+    rtols = [1e-1, 0.5]
+    for s, rtol, atol in product(sample_shapes, rtols, atols):
+        # close sample
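+        # offsetting by atol keeps |t - close| <= atol, so this pair stays within tolerance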
+        t = make_tensor(s, device=device, dtype=dtype, requires_grad=requires_grad)
+        close = (t + atol).detach().requires_grad_(requires_grad)
+        yield SampleInput(t, close, rtol=rtol, atol=atol)
+
+        # random sample
+        a = make_tensor(s, device=device, dtype=dtype, requires_grad=requires_grad)
+        b = make_tensor(s, device=device, dtype=dtype, requires_grad=requires_grad)
+        yield SampleInput(a, b, rtol=rtol, atol=atol)
+
+
+def sample_inputs_l1_loss(op_info, device, dtype, requires_grad, **kwargs):
+    yield from sample_inputs_loss(op_info, device, dtype, requires_grad, **kwargs)
+
+    # test COMPLEX_TO_FLOAT promotion
+    if dtype.is_complex:
+        make = partial(make_tensor, (), device=device, requires_grad=requires_grad)
+        yield SampleInput(make(dtype=dtype), args=(make(dtype=torch.double),))
+        yield SampleInput(make(dtype=torch.double), args=(make(dtype=dtype),))
+
+def error_inputs_l1_loss(op_info, device, **kwargs):
+    make = partial(make_tensor, device=device, dtype=torch.float32)
+
+    # invalid reduction value
+    yield ErrorInput(SampleInput(make(5, 4), args=(make(5, 4),),
+                     kwargs={'reduction': 'abc'}),
+                     error_type=ValueError,
+                     error_regex='abc is not a valid value for reduction')
+    # invalid input shapes
+    yield ErrorInput(SampleInput(make(5, 4), args=(make(5,),)),
+                     error_regex=(r'(Attempting to broadcast a dimension of length|'
+                                  r'The size of tensor a \(4\) must match the '
+                                  r'size of tensor b \(5\) at non-singleton '
+                                  r'dimension 1)')
+                     )
+
+def sample_inputs_smooth_l1_loss(op_info, device, dtype, requires_grad, **kwargs):
+    yield from sample_inputs_loss(op_info, device, dtype, requires_grad, **kwargs)
+
+    make = partial(make_tensor, (S, S), device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # This test case always triggers the smooth condition, since the absolute difference
+    # between input and target is smaller than beta
+    yield SampleInput(make(low=0, high=2), args=(make(low=-2, high=0),), kwargs=dict(beta=5))
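+    # beta=0 disables the smooth (quadratic) region, making smooth_l1_loss behave like l1_loss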
+    yield SampleInput(make(), args=(make(),), kwargs=dict(beta=0))
+
+def sample_inputs_kl_div(op_info, device, dtype, requires_grad, **kwargs):
+    # kl_div works with inputs in [0, 1] (i.e. values of a probability distribution).
+    # Since log maps [0, 1] to (-inf, 0], the log-space inputs are generated via log_softmax.
+    make_arg = partial(make_tensor, low=0., device=device, dtype=dtype, requires_grad=requires_grad)
+
+    def make_log(shape):
+        out = torch.nn.functional.log_softmax(make_arg(shape), -1)
+        out.requires_grad_(requires_grad)
+        return out
+
+    def make_prob(shape):
+        out = torch.nn.functional.softmax(make_arg(shape), -1)
+        out.requires_grad_(requires_grad)
+        return out
+
+    shapes = ((2,), (2, 3))
+    reductions = ("none", "mean", "batchmean", "sum")
+    for shape, reduction, log_target in product(shapes, reductions, (True, False)):
+        input = make_log(shape)
+        target = make_log(shape) if log_target else make_prob(shape)
+        yield SampleInput(input, args=(target,), kwargs=dict(reduction=reduction, log_target=log_target))
+
+def sample_inputs_pdist(op_info, device, dtype, requires_grad, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    yield from (SampleInput(make_input((n, m))) for n, m in itertools.product((1, S), repeat=2))
+    yield from (SampleInput(make_input((S, S)), kwargs=dict(p=p)) for p in (0.0, 1.0, 2.0, 10.0, float("inf")))
+
+def reference_pdist(input, p=2):
+    pdist = scipy.spatial.distance.pdist
+    if p == 0:
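+        # scipy's "hamming" is the fraction of differing coordinates, so scale by the
+        # number of columns to match torch.pdist's p=0 (count of nonzero diffs) convention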
+        output = pdist(input, "hamming") * input.shape[1]
+    elif p == float("inf"):
+        output = pdist(input, lambda x, y: np.abs(x - y).max())
+    else:
+        output = pdist(input, "minkowski", p=p)
+    return output.astype(input.dtype)
+
+def sample_inputs_diagflat(op_info, device, dtype, requires_grad, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    yield SampleInput(make_input(()))
+    yield SampleInput(make_input((2,)))
+    yield SampleInput(make_input((2, 2)))
+    yield SampleInput(make_input((2,)), offset=1)
+    yield SampleInput(make_input((2,)), offset=-1)
+
+def sample_inputs_max_unpool(op_info, device, dtype, requires_grad, **kwargs):
+    unpool_name_to_pool_method_dict = {
+        'nn.functional.max_unpool1d': torch.nn.functional.max_pool1d,
+        'nn.functional.max_unpool2d': torch.nn.functional.max_pool2d,
+        'nn.functional.max_unpool3d': torch.nn.functional.max_pool3d
+    }
+
+    unpool_name_to_dim = {
+        'nn.functional.max_unpool1d': 1,
+        'nn.functional.max_unpool2d': 2,
+        'nn.functional.max_unpool3d': 3
+    }
+
+    unpool_to_pool_name_dict = {k: f'nn.functional.{v.__name__}' for k, v in unpool_name_to_pool_method_dict.items()}
+
+    pool_dim = unpool_name_to_dim[op_info.name]
+    pool_method = unpool_name_to_pool_method_dict[op_info.name]
+
+    pool_op_info = copy.copy(op_info)
+    pool_op_info.name = unpool_to_pool_name_dict[op_info.name]
+
+    for sample in sample_inputs_max_pool(pool_op_info, device, dtype, requires_grad, **kwargs):
+        # shapes (C, ...) do not work as of now,
+        # see https://github.com/pytorch/pytorch/issues/68337
+        # TODO: remove once the issue is resolved
+        if sample.input.dim() != pool_dim + 2:
+            continue
+
+        # No dilation > 1 for max_unpool,
+        # see https://github.com/pytorch/pytorch/issues/68420
+        if sample.kwargs['dilation'] != 1:
+            continue
+
+        # Can't unpool without indices
+        if sample.kwargs['return_indices']:
+            pool, indices = pool_method(sample.input, **sample.kwargs)
+            # arg has to be a leaf
+            arg = pool.detach().requires_grad_(requires_grad)
+            sample_kwargs = {
+                'kernel_size': sample.kwargs['kernel_size'],
+                'stride': sample.kwargs['stride'],
+                'padding': sample.kwargs['padding'],
+                # output_size could be None, but we specify it explicitly
+                # to compensate for the information loss in the pooling op due
+                # to the floor/ceil operation used to compute the output shapes
+                'output_size': sample.input.size()
+            }
+
+            yield SampleInput(arg, args=(indices,), kwargs=sample_kwargs)
+
+def sample_inputs_max_unpool_grad(op_info, device, dtype, requires_grad, **kwargs):
+    for sample in sample_inputs_max_unpool(op_info, device, dtype, requires_grad, **kwargs):
+        indices = sample.args[0]
+        # The samples for max_unpool are generated with max_pool.
+        # It could be that a single element of the max_pool's
+        # input is mapped to several locations in its output.
+        # This situation leads to failed gradchecks because
+        # the finite difference algorithm perturbs the elements
+        # of the output one by one, and not in equivalence
+        # classes determined by whether two elements
+        # in the output come from the same location in the
+        # input (simply put, whether they have the same corresponding index).
+        # So, there are two ways to resolve this issue:
+        # 1. Extract a perturbation for one element and apply it to all
+        #    the elements from the same equivalence class, or
+        # 2. Make sure that the equivalence classes are all singletons,
+        #    i.e. the index tensor contains only unique indices.
+        # Here we go with solution 2, the simpler of the two.
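+        # For example, an index tensor containing [0, 0] ties two pooled outputs to the
+        # same input location, so such samples are filtered out below.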
+        if indices.unique().numel() == indices.numel():
+            yield sample
+
+def sample_inputs_multi_head_attention_forward(opinfo, device, dtype, requires_grad, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    if requires_grad:
+        # backward tests would take too long to complete, causing the job to time out.
+        bsz = 2
+        is_batcheds = (True,)
+        use_separate_proj_weights = (False,)
+        emb_sizes = (2,)
+        src_lens = (XS,)
+        tgt_lens = (XS,)
+        heads = (2,)
+        dropouts = (0.5,)
+        mask_types = ("2d",)
+    else:
+        bsz = 2
+        is_batcheds = (False, True)
+        use_separate_proj_weights = (False, True)
+        emb_sizes = (2, 4)
+        src_lens = (XS,)
+        tgt_lens = (XS, S)
+        heads = (1, 2)
+        dropouts = (0.0, 0.5)
+        mask_types = (None, "2d", "3d")
+
+    for is_batched, use_separate_proj_weight, mask_type, emb_size, src_len, tgt_len, num_heads, dropout_p in itertools.product(
+        is_batcheds, use_separate_proj_weights, mask_types, emb_sizes, src_lens, tgt_lens, heads, dropouts
+    ):
+        attn_mask = None
+        if mask_type == "2d":
+            attn_mask = make_input(src_len, tgt_len)
+        elif mask_type == "3d":
+            attn_mask = make_input((bsz if is_batched else 1) * num_heads, src_len, tgt_len)
+
+        if is_batched:
+            q = make_input(src_len, bsz, emb_size)
+            k = make_input(tgt_len, bsz, emb_size)
+            v = make_input(tgt_len, bsz, emb_size)
+        else:
+            q = make_input(src_len, emb_size)
+            k = make_input(tgt_len, emb_size)
+            v = make_input(tgt_len, emb_size)
+        if use_separate_proj_weight:
+            in_proj_weight = None
+            q_proj_weight = make_input(emb_size, emb_size)
+            k_proj_weight = make_input(emb_size, emb_size)
+            v_proj_weight = make_input(emb_size, emb_size)
+        else:
+            in_proj_weight = make_input(emb_size * 3, emb_size)
+            q_proj_weight = None
+            k_proj_weight = None
+            v_proj_weight = None
+
+        bias_k = make_input(emb_size)
+        bias_v = make_input(emb_size)
+        in_proj_bias = make_input(emb_size * 3)
+        out_proj_weight = make_input(emb_size, emb_size)
+        out_proj_bias = make_input(emb_size)
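+        # positional args mirror F.multi_head_attention_forward's signature:
+        # (query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias,
+        #  bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, out_proj_bias)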
+        sample_args = (
+            k, v, emb_size, num_heads, in_proj_weight,
+            in_proj_bias, bias_k, bias_v, False,
+            dropout_p, out_proj_weight, out_proj_bias
+        )
+        sample_kwargs = {
+            "q_proj_weight" : q_proj_weight,
+            "k_proj_weight" : k_proj_weight,
+            "v_proj_weight" : v_proj_weight,
+            "attn_mask" : attn_mask,
+            "training" : True if dropout_p > 0.0 else False,
+            "use_separate_proj_weight" : use_separate_proj_weight
+        }
+
+        yield SampleInput(q, args=sample_args, kwargs=sample_kwargs)
+
+
+# Includes some values such that N * N won't be a multiple of 4,
+# which should ensure we test the vectorized and non-vectorized
+# kernel code paths.
+NUM_SIZE0_TENSORS = 10000
+foreach_num_tensors = [20, 23] if not TEST_WITH_SLOW else [23, 30, 300]
+_foreach_inputs_default_kwargs = {"noncontiguous": False, "same_size": False, "low": None, "high": None}
+
+
+class ForeachRightmostArgType(enum.Enum):
+    TensorList = enum.auto()
+    ScalarList = enum.auto()
+    Scalar = enum.auto()
+    Tensor = enum.auto()
+
+
+class ForeachSampleInput(SampleInput):
+    # For TensorList <op> Scalar/Tensor, we compute the reference
+    # by converting it into TensorList <op> ScalarList/TensorList and
+    # then converting into multiple Tensor <op> Scalar/Tensor.
+    # ref_args contains the args converted to TensorList <op> ScalarList/TensorList
+    ref_args: Any
+    disable_fastpath: bool
+
+    def __init__(self, *args, disable_fastpath=False, ref_args=None, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.ref_args = ref_args or self.args
+        self.disable_fastpath = disable_fastpath
+
+
+class foreach_inputs_sample_func:
+    def __init__(
+        self,
+        arity: int,
+        rightmost_supports_scalar: bool,
+        rightmost_supports_scalarlist: bool,
+        rightmost_supports_tensor: bool = False,
+    ) -> None:
+        self.arity = arity
+        self._set_rightmost_arg_types(
+            rightmost_supports_scalar, rightmost_supports_scalarlist, rightmost_supports_tensor,
+        )
+        self._intersperse_empty = (True, False)
+
+    def _set_rightmost_arg_types(
+        self,
+        rightmost_supports_scalar: bool,
+        rightmost_supports_scalarlist: bool,
+        rightmost_supports_tensor: bool,
+    ) -> None:
+        self._rightmost_arg_types = [ForeachRightmostArgType.TensorList]
+        if self.arity > 1:
+            if rightmost_supports_scalar:
+                self._rightmost_arg_types.append(ForeachRightmostArgType.Scalar)
+            if rightmost_supports_scalarlist:
+                self._rightmost_arg_types.append(ForeachRightmostArgType.ScalarList)
+            if rightmost_supports_tensor:
+                self._rightmost_arg_types.append(ForeachRightmostArgType.Tensor)
+
+    def _sample_rightmost_arg(
+        self,
+        opinfo,
+        rightmost_arg_type,
+        device,
+        dtype,
+        num_tensors,
+        allow_higher_dtype_scalars,
+        **_foreach_inputs_kwargs,
+    ):
+        if rightmost_arg_type == ForeachRightmostArgType.TensorList:
+            return [sample_inputs_foreach(None, device, dtype, num_tensors, **_foreach_inputs_kwargs)]
+        if rightmost_arg_type == ForeachRightmostArgType.Tensor:
+            return [make_tensor(
+                (), device=device, dtype=dtype,
+                noncontiguous=_foreach_inputs_kwargs["noncontiguous"],
+                requires_grad=_foreach_inputs_kwargs.get("requires_grad", False),
+            )]
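+        # For _foreach_pow in reduced precision, restrict scalar exponents to 1.0 / 2.0,
+        # presumably so results stay exactly representable and parity checks don't fail.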
+        should_use_simpler_scalars = opinfo.name == "_foreach_pow" and dtype in (torch.float16, torch.bfloat16)
+
+        def sample_float():
+            s = random.random()
+            if should_use_simpler_scalars:
+                return 1.0 if s > 0.5 else 2.0
+            else:
+                return 1.0 - s
+
+        high = 2 if should_use_simpler_scalars else 9
+        if rightmost_arg_type == ForeachRightmostArgType.ScalarList:
+            scalarlist_list = []
+            scalarlist_list.append([random.randint(0, high) + 1 for _ in range(num_tensors)])
+
+            if allow_higher_dtype_scalars or dtype.is_floating_point:
+                scalarlist_list.append([sample_float() for _ in range(num_tensors)])
+            if allow_higher_dtype_scalars or dtype.is_complex:
+                scalarlist_list.append([complex(sample_float(), sample_float()) for _ in range(num_tensors)])
+                scalarlist_list.append([1, 2.0, 3.0 + 4.5j] + [3.0 for _ in range(num_tensors - 3)])
+                scalarlist_list.append([True, 1, 2.0, 3.0 + 4.5j] + [3.0 for _ in range(num_tensors - 4)])
+            return scalarlist_list
+        if rightmost_arg_type == ForeachRightmostArgType.Scalar:
+            scalars = []
+            scalars.append(random.randint(1, high + 1))
+            if allow_higher_dtype_scalars or dtype.is_floating_point:
+                scalars.append(sample_float())
+            if allow_higher_dtype_scalars or dtype.is_complex:
+                scalars.append(complex(sample_float(), sample_float()))
+            scalars.append(True)
+            return scalars
+        raise AssertionError(f"Invalid rightmost_arg_type of {rightmost_arg_type}")
+
+    def _should_disable_fastpath(self, opinfo, rightmost_arg, rightmost_arg_type, dtype):
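+        # Heuristic: the multi-tensor fastpath generally cannot handle combinations that
+        # require type promotion (e.g. integral/bool tensors with float or complex scalars),
+        # so those cases fall back to the slow path; the checks below mirror that.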
+        if self.arity == 1:
+            if "foreach_abs" in opinfo.name and dtype in complex_types():
+                return True
+            # unary
+            if opinfo.ref in (torch.abs, torch.neg):
+                return False
+            if opinfo.ref_inplace in (torch.Tensor.zero_,):
+                return False
+            return dtype in integral_types_and(torch.bool)
+        if self.arity < 2 or rightmost_arg_type == ForeachRightmostArgType.Tensor:
+            return None
+        if "foreach_pow" in opinfo.name and dtype in integral_types_and(torch.bool):
+            return True
+        if any(
+                foreach_name in opinfo.name
+                for foreach_name in ("foreach_clamp_max", "foreach_clamp_min", "foreach_maximum", "foreach_minimum")
+        ) and dtype in integral_types_and(torch.bool):
+            return True
+        if rightmost_arg_type == ForeachRightmostArgType.TensorList:
+            disable_fastpath = "foreach_div" in opinfo.name and dtype in integral_types_and(torch.bool)
+            if "foreach_add" in opinfo.name and dtype == torch.bool:
+                disable_fastpath = True
+            return disable_fastpath
+        elif rightmost_arg_type == ForeachRightmostArgType.Scalar:
+            disable_fastpath = "foreach_div" in opinfo.name and dtype in integral_types_and(torch.bool)
+            if isinstance(rightmost_arg, bool):
+                disable_fastpath |= dtype == torch.bool
+                if opinfo.ref in (torch.add, torch.mul):
+                    disable_fastpath = False
+            elif isinstance(rightmost_arg, int):
+                disable_fastpath |= dtype == torch.bool
+            elif isinstance(rightmost_arg, float):
+                disable_fastpath |= dtype in integral_types_and(torch.bool)
+            elif isinstance(rightmost_arg, complex):
+                disable_fastpath |= dtype not in complex_types()
+            else:
+                raise AssertionError(f"Invalid scalar of type {rightmost_arg_type} - {rightmost_arg}")
+            return disable_fastpath
+        elif rightmost_arg_type == ForeachRightmostArgType.ScalarList:
+            disable_fastpath = opinfo.ref == torch.div and dtype in integral_types_and(torch.bool)
+            elmt_t = type(rightmost_arg[0])
+            has_same_type = all(isinstance(v, elmt_t) for v in rightmost_arg)
+            if not has_same_type:
+                return dtype not in complex_types()
+            if isinstance(rightmost_arg[0], bool):
+                if ("foreach_add" in opinfo.name or "foreach_mul" in opinfo.name) and dtype == torch.bool:
+                    disable_fastpath = False
+            elif isinstance(rightmost_arg[0], int):
+                disable_fastpath |= dtype == torch.bool
+            elif isinstance(rightmost_arg[0], float):
+                disable_fastpath |= dtype in integral_types_and(torch.bool)
+            elif isinstance(rightmost_arg[0], complex):
+                disable_fastpath |= dtype not in complex_types()
+            else:
+                raise AssertionError(f"Invalid scalarlist of {rightmost_arg}")
+            return disable_fastpath
+        else:
+            raise AssertionError(f"Invalid rightmost_arg_type of {rightmost_arg_type}")
+
+    def _sample_kwargs(self, opinfo, rightmost_arg, rightmost_arg_type, dtype):
+        kwargs = {}
+        if rightmost_arg_type == ForeachRightmostArgType.TensorList and opinfo.supports_alpha_param:
+            if dtype in integral_types_and(torch.bool):
+                kwargs["alpha"] = 3
+            elif dtype.is_complex:
+                kwargs["alpha"] = complex(3, 3)
+            else:
+                kwargs["alpha"] = 3.14
+        if self.arity > 1:
+            kwargs["disable_fastpath"] = self._should_disable_fastpath(opinfo, rightmost_arg, rightmost_arg_type, dtype)
+        return kwargs
+
+    def sample_zero_size_tensor_inputs(self, opinfo, device, dtype, requires_grad, **kwargs):
+        assert "num_input_tensors" not in kwargs
+        _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()}
+        _foreach_inputs_kwargs["requires_grad"] = requires_grad
+        allow_higher_dtype_scalars = kwargs.pop("allow_higher_dtype_scalars", False)
+        for _rightmost_arg_type in self._rightmost_arg_types:
+            zero_size_foreach_inputs_kwargs = copy.deepcopy(_foreach_inputs_kwargs)
+            zero_size_foreach_inputs_kwargs["zero_size"] = True
+            input = sample_inputs_foreach(None, device, dtype, NUM_SIZE0_TENSORS, **zero_size_foreach_inputs_kwargs)
+            if self.arity > 1:
+                args = [
+                    sample_inputs_foreach(None, device, dtype, NUM_SIZE0_TENSORS, **zero_size_foreach_inputs_kwargs)
+                    for _ in range(self.arity - 2)
+                ]
+                args.append(
+                    self._sample_rightmost_arg(
+                        opinfo,
+                        ForeachRightmostArgType.TensorList,
+                        device,
+                        dtype,
+                        NUM_SIZE0_TENSORS,
+                        allow_higher_dtype_scalars=allow_higher_dtype_scalars,
+                        **zero_size_foreach_inputs_kwargs,
+                    )[0])
+                kwargs = self._sample_kwargs(
+                    opinfo, args[-1], ForeachRightmostArgType.TensorList, dtype)
+            else:
+                args = []
+                kwargs = {}
+                if opinfo.ref in (torch.abs, torch.neg):
+                    kwargs["disable_fastpath"] = False
+                else:
+                    kwargs["disable_fastpath"] = dtype in integral_types_and(torch.bool)
+            yield ForeachSampleInput(input, *args, **kwargs)
+
+    def __call__(self, opinfo, device, dtype, requires_grad, **kwargs):
+        num_input_tensors_specified = "num_input_tensors" in kwargs
+        num_input_tensors = kwargs.pop("num_input_tensors") if num_input_tensors_specified else foreach_num_tensors
+        assert isinstance(num_input_tensors, list)
+        _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()}
+        _foreach_inputs_kwargs["requires_grad"] = requires_grad
+        _foreach_inputs_kwargs["zero_size"] = False
+        allow_higher_dtype_scalars = kwargs.pop("allow_higher_dtype_scalars", False)
+
+        # add empty tensor interspersion to fully exercise the fix for #100701
+        for num_tensors, rightmost_arg_type, intersperse_empty_tensors in itertools.product(
+                num_input_tensors, self._rightmost_arg_types, self._intersperse_empty):
+            if intersperse_empty_tensors and (num_tensors != max(num_input_tensors) or str(device) == 'cpu'):
+                # only intersperse empty tensors for the largest N and on non-cpu devices to lessen redundancy
+                continue
+            _foreach_inputs_kwargs["intersperse_empty_tensors"] = intersperse_empty_tensors
+            input = sample_inputs_foreach(
+                None, device, dtype, num_tensors, **_foreach_inputs_kwargs)
+            args = []
+            if self.arity > 1:
+                args = [
+                    sample_inputs_foreach(
+                        None, device, dtype, num_tensors, **_foreach_inputs_kwargs)
+                    for _ in range(self.arity - 2)
+                ]
+                rightmost_arg_list = self._sample_rightmost_arg(
+                    opinfo, rightmost_arg_type, device, dtype, num_tensors, allow_higher_dtype_scalars,
+                    **_foreach_inputs_kwargs)
+                for rightmost_arg in rightmost_arg_list:
+                    args.append(rightmost_arg)
+                    kwargs = self._sample_kwargs(opinfo, rightmost_arg, rightmost_arg_type, dtype)
+                    ref_args = args
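+                    # a single Scalar/Tensor rightmost arg is expanded below into a per-tensor
+                    # list so the reference path can apply it element-wise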
+                    if rightmost_arg_type in (ForeachRightmostArgType.Scalar, ForeachRightmostArgType.Tensor):
+                        ref_args = args[:-1] + [[args[-1] for _ in range(num_tensors)]]
+                    sample = ForeachSampleInput(input, *args, ref_args=ref_args, **kwargs)
+                    yield sample
+                    args.pop()
+            else:
+                yield ForeachSampleInput(
+                    input,
+                    *args,
+                    disable_fastpath=self._should_disable_fastpath(opinfo, None, None, dtype),
+                )
+
+
+class foreach_max_sample_func(foreach_inputs_sample_func):
+    def __init__(
+        self,
+        arity: int,
+        rightmost_supports_scalar: bool,
+        rightmost_supports_scalarlist: bool,
+        rightmost_supports_tensor: bool = False,
+    ) -> None:
+        super().__init__(arity, rightmost_supports_scalar, rightmost_supports_scalarlist, rightmost_supports_tensor)
+        self._intersperse_empty = (False,)
+
+    def sample_zero_size_tensor_inputs(self, opinfo, device, dtype, requires_grad, **kwargs):
+        return []
+
+    def _should_disable_fastpath(self, opinfo, rightmost_arg, rightmost_arg_type, dtype):
+        return False
+
+
+class foreach_norm_sample_func(foreach_inputs_sample_func):
+    def sample_zero_size_tensor_inputs(self, opinfo, device, dtype, requires_grad, **kwargs):
+        assert "num_input_tensors" not in kwargs
+        _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()}
+        _foreach_inputs_kwargs["requires_grad"] = requires_grad
+        for ord in (0, 1, 2, -1, -2, float('inf'), float('-inf')):
+            input = sample_inputs_foreach(None, device, dtype, NUM_SIZE0_TENSORS, zero_size=True, **_foreach_inputs_kwargs)
+            disable_fastpath = True
+            if ord in (1, 2, float('inf')) and dtype in floating_types_and(torch.half, torch.bfloat16):
+                disable_fastpath = False
+            yield ForeachSampleInput(input, ord=ord, disable_fastpath=disable_fastpath)
+
+    def __call__(self, opinfo, device, dtype, requires_grad, **kwargs):
+        num_input_tensors = kwargs.pop("num_input_tensors", foreach_num_tensors)
+        assert isinstance(num_input_tensors, list)
+        _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()}
+        _foreach_inputs_kwargs["requires_grad"] = requires_grad
+        _allow_higher_dtype_scalars = kwargs.pop("allow_higher_dtype_scalars", False)
+
+        for num_tensors, ord, out_dtype, intersperse_empty_tensors in product(
+            num_input_tensors,
+            (0, 1, 2, -1, -2, float('inf'), float('-inf')),
+            (None,) + ((torch.complex128,) if dtype in complex_types() else (torch.float64,)),
+            (True, False),
+        ):
+            # inf norms and negative norms on empty tensors are not supported by our reference function:
+            # linalg.vector_norm cannot compute them on an empty tensor because the reduction has no identity element
+            if (ord in [float('inf'), float('-inf')] or ord < 0) and intersperse_empty_tensors:
+                continue
+
+            _foreach_inputs_kwargs["intersperse_empty_tensors"] = intersperse_empty_tensors
+            input = sample_inputs_foreach(None, device, dtype, num_tensors, zero_size=False, **_foreach_inputs_kwargs)
+            disable_fastpath = True
+            if ord in (1, 2, float('inf')) and dtype in floating_types_and(torch.half, torch.bfloat16):
+                disable_fastpath = False
+            yield ForeachSampleInput(input, ord=ord, disable_fastpath=disable_fastpath, dtype=out_dtype)
+
+        # Also test nan propagation with a single tensor, but skip autograd testing
+        if not requires_grad:
+            nan_inputs = [
+                [float('nan')],
+                [float('nan'), 1.0],
+                [1.0, float('nan')],
+                [1.0, 2.0, 3.0, float('nan'), float('nan'), 7.0, float('nan'), float('nan'), -1.5, 6.0],
+                [7.0, 3.0, float('nan'), float('nan'), -1.5, 6.0],
+                [3.0, float('nan'), float('nan'), -1.5, 6.0],
+            ]
+            for input in nan_inputs:
+                x = torch.tensor(input, device=device)
+                disable_fastpath = True
+                if ord in (1, 2, float('inf')) and dtype in floating_types_and(torch.half, torch.bfloat16):
+                    disable_fastpath = False
+                yield ForeachSampleInput([x], ord=ord, disable_fastpath=disable_fastpath)
+
+
+class foreach_pointwise_sample_func(foreach_inputs_sample_func):
+
+    def __init__(
+        self,
+        arity: int = 3,
+        rightmost_supports_scalar: bool = False,
+        rightmost_supports_scalarlist: bool = False,
+    ):
+        super().__init__(arity, rightmost_supports_scalar, rightmost_supports_scalarlist)
+
+    def _should_disable_fastpath(self, opinfo, rightmost_arg, rightmost_arg_type, dtype):
+        return dtype in integral_types_and(torch.bool) and opinfo.ref in (torch.addcmul,)
+
+    def sample_zero_size_tensor_inputs(self, opinfo, device, dtype, requires_grad, **kwargs):
+        assert "num_input_tensors" not in kwargs
+        _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()}
+        _foreach_inputs_kwargs["requires_grad"] = requires_grad
+        # zero_size tensor
+        input = sample_inputs_foreach(None, device, dtype, NUM_SIZE0_TENSORS, zero_size=True, **_foreach_inputs_kwargs)
+        args = [
+            sample_inputs_foreach(None, device, dtype, NUM_SIZE0_TENSORS, zero_size=True, **_foreach_inputs_kwargs)
+            for _ in range(2)
+        ]
+        if "scalars" in kwargs:
+            del kwargs["scalars"]
+        kwargs.update(self._sample_kwargs(opinfo, args[-1], ForeachRightmostArgType.TensorList, dtype))
+        yield ForeachSampleInput(input, *args, **kwargs)
+
+    def __call__(self, opinfo, device, dtype, requires_grad, **kwargs):
+        num_input_tensors_specified = "num_input_tensors" in kwargs
+        num_input_tensors = kwargs.pop("num_input_tensors") if num_input_tensors_specified else foreach_num_tensors
+        assert isinstance(num_input_tensors, list)
+        _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()}
+        _foreach_inputs_kwargs["requires_grad"] = requires_grad
+        allow_higher_dtype_scalars = kwargs.pop("allow_higher_dtype_scalars", False)
+
+        for num_tensors, rightmost_arg_type, intersperse_empty_tensors in itertools.product(
+                num_input_tensors, self._rightmost_arg_types, (True, False)):
+            _foreach_inputs_kwargs["intersperse_empty_tensors"] = intersperse_empty_tensors
+            input = sample_inputs_foreach(None, device, dtype, num_tensors, zero_size=False, **_foreach_inputs_kwargs)
+            args = [
+                sample_inputs_foreach(None, device, dtype, num_tensors, zero_size=False, **_foreach_inputs_kwargs)
+                for _ in range(2 - int(rightmost_arg_type == ForeachRightmostArgType.TensorList))
+            ]
+            rightmost_arg_list = self._sample_rightmost_arg(
+                opinfo,
+                rightmost_arg_type,
+                device,
+                dtype,
+                num_tensors,
+                zero_size=False,
+                allow_higher_dtype_scalars=False if intersperse_empty_tensors else allow_higher_dtype_scalars,
+                **_foreach_inputs_kwargs,
+            )
+            for rightmost_arg in rightmost_arg_list:
+                kwargs = {}
+                if rightmost_arg_type == ForeachRightmostArgType.TensorList:
+                    args.append(rightmost_arg)
+                elif rightmost_arg_type in [ForeachRightmostArgType.Tensor, ForeachRightmostArgType.ScalarList]:
+                    kwargs["scalars"] = rightmost_arg
+                else:
+                    kwargs["value"] = rightmost_arg
+                kwargs.update(self._sample_kwargs(opinfo, rightmost_arg, rightmost_arg_type, dtype))
+                assert len(args) == 2, f"{len(args)=}"
+                sample = ForeachSampleInput(input, *args, **kwargs)
+                yield sample
+                if rightmost_arg_type == ForeachRightmostArgType.TensorList:
+                    args.pop()
+
+
+foreach_unary_op_db: list[OpInfo] = [
+    ForeachFuncInfo(
+        'exp',
+        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        dtypesIfHpu=custom_types(torch.float32),
+        backward_requires_result=True,
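+        # exp's backward reuses the forward result (d/dx exp(x) = exp(x)),
+        # hence backward_requires_result=True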
+        supports_autograd=True,
+        supports_inplace_autograd=True,
+        supports_forward_ad=True,
+        decorators=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_meta_inplace",
+                dtypes=integral_types_and(torch.bool,),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_inplace",
+                dtypes=integral_types_and(torch.bool,),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_meta_inplace",
+                dtypes=integral_types_and(torch.bool,),
+            ),
+        ),
+    ),
+    ForeachFuncInfo(
+        'acos',
+        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+        supports_autograd=True,
+        supports_inplace_autograd=True,
+        supports_forward_ad=True,
+        decorators=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_meta_inplace",
+                dtypes=integral_types_and(torch.bool,),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_inplace",
+                dtypes=integral_types_and(torch.bool,),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_meta_inplace",
+                dtypes=integral_types_and(torch.bool,),
+            ),
+        ),
+    ),
+    ForeachFuncInfo(
+        'asin',
+        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+        supports_autograd=True,
+        supports_inplace_autograd=True,
+        supports_forward_ad=True,
+        decorators=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_meta_inplace",
+                dtypes=integral_types_and(torch.bool,),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_inplace",
+                dtypes=integral_types_and(torch.bool,),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_meta_inplace",
+                dtypes=integral_types_and(torch.bool,),
+            ),
+        ),
+    ),
+    ForeachFuncInfo(
+        'atan',
+        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+        supports_autograd=True,
+        supports_inplace_autograd=True,
+        supports_forward_ad=True,
+        decorators=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_meta_inplace",
+                dtypes=integral_types_and(torch.bool,),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_inplace",
+                dtypes=integral_types_and(torch.bool,),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_meta_inplace",
+                dtypes=integral_types_and(torch.bool,),
+            ),
+        ),
+    ),
+    ForeachFuncInfo(
+        'cos',
+        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+        supports_autograd=True,
+        supports_inplace_autograd=True,
+        supports_forward_ad=True,
+        decorators=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_meta_inplace",
+                dtypes=integral_types_and(torch.bool,),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_inplace",
+                dtypes=integral_types_and(torch.bool,),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_meta_inplace",
+                dtypes=integral_types_and(torch.bool,),
+            ),
+        ),
+    ),
+    ForeachFuncInfo(
+        'cosh',
+        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+        supports_autograd=True,
+        supports_inplace_autograd=True,
+        supports_forward_ad=True,
+        decorators=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_meta_inplace",
+                dtypes=integral_types_and(torch.bool,),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_inplace",
+                dtypes=integral_types_and(torch.bool,),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_meta_inplace",
+                dtypes=integral_types_and(torch.bool,),
+            ),
+        ),
+    ),
+    ForeachFuncInfo(
+        'log',
+        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+        supports_autograd=True,
+        supports_inplace_autograd=True,
+        supports_forward_ad=True,
+        decorators=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_meta_inplace",
+                dtypes=integral_types_and(torch.bool,),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_inplace",
+                dtypes=integral_types_and(torch.bool,),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_meta_inplace",
+                dtypes=integral_types_and(torch.bool,),
+            ),
+        ),
+    ),
+    ForeachFuncInfo(
+        'log10',
+        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+        supports_autograd=True,
+        supports_inplace_autograd=True,
+        supports_forward_ad=True,
+        decorators=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_meta_inplace",
+                dtypes=integral_types_and(torch.bool,),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_inplace",
+                dtypes=integral_types_and(torch.bool,),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_meta_inplace",
+                dtypes=integral_types_and(torch.bool,),
+            ),
+        ),
+    ),
+    ForeachFuncInfo(
+        'log2',
+        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+        supports_autograd=True,
+        supports_inplace_autograd=True,
+        supports_forward_ad=True,
+        decorators=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_meta_inplace",
+                dtypes=integral_types_and(torch.bool,),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_inplace",
+                dtypes=integral_types_and(torch.bool,),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_meta_inplace",
+                dtypes=integral_types_and(torch.bool,),
+            ),
+        ),
+    ),
+    ForeachFuncInfo(
+        'tan',
+        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+        backward_requires_result=True,
+        supports_autograd=True,
+        supports_inplace_autograd=True,
+        supports_forward_ad=True,
+        decorators=(
+            # due to https://github.com/pytorch/pytorch/pull/102427 enabling jiterator for complex
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_meta_inplace",
+                dtypes=integral_types_and(torch.bool,),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_inplace",
+                dtypes=integral_types_and(torch.bool,),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_meta_inplace",
+                dtypes=integral_types_and(torch.bool,),
+            ),
+            DecorateInfo(
+                toleranceOverride(
+                    {
+                        torch.complex64: tol(atol=3e-04, rtol=2e-05)
+                    }
+                ),
+                'TestForeach',
+                'test_parity',
+                device_type='cuda'
+            ),
+        ),
+    ),
+    ForeachFuncInfo(
+        'tanh',
+        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+        backward_requires_result=True,
+        supports_autograd=True,
+        supports_inplace_autograd=True,
+        supports_forward_ad=True,
+        decorators=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_meta_inplace",
+                dtypes=integral_types_and(torch.bool,),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_inplace",
+                dtypes=integral_types_and(torch.bool,),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_meta_inplace",
+                dtypes=integral_types_and(torch.bool,),
+            ),
+            DecorateInfo(
+                toleranceOverride(
+                    {torch.complex64: tol(atol=5e-03, rtol=1e-04)}
+                ),
+                'TestForeach',
+                'test_parity',
+                device_type='cuda'
+            ),
+        ),
+    ),
+    ForeachFuncInfo(
+        'sin',
+        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+        supports_autograd=True,
+        supports_inplace_autograd=True,
+        supports_forward_ad=True,
+        decorators=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_meta_inplace",
+                dtypes=integral_types_and(torch.bool,),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_inplace",
+                dtypes=integral_types_and(torch.bool,),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_meta_inplace",
+                dtypes=integral_types_and(torch.bool,),
+            ),
+        ),
+    ),
+    ForeachFuncInfo(
+        'sinh',
+        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+        supports_autograd=True,
+        supports_inplace_autograd=True,
+        supports_forward_ad=True,
+        decorators=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_meta_inplace",
+                dtypes=integral_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_inplace",
+                dtypes=integral_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_meta_inplace",
+                dtypes=integral_types_and(torch.bool),
+            ),
+        ),
+    ),
+    ForeachFuncInfo(
+        'neg',
+        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+        supports_autograd=True,
+        supports_inplace_autograd=True,
+        supports_forward_ad=True,
+        decorators=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_meta_inplace",
+                dtypes=(torch.bool,),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_meta_outplace",
+                dtypes=(torch.bool,),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_inplace",
+                dtypes=(torch.bool,),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_outplace",
+                dtypes=(torch.bool,),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_meta_inplace",
+                dtypes=(torch.bool,),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_meta_outplace",
+                dtypes=(torch.bool,),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestForeach",
+                "test_unary_op_tensors_on_different_devices",
+                device_type="cuda",
+                dtypes=(torch.bool,),
+            ),
+        ),
+    ),
+    ForeachFuncInfo(
+        'sqrt',
+        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+        supports_autograd=True,
+        supports_inplace_autograd=True,
+        supports_forward_ad=True,
+        backward_requires_result=True,
+        decorators=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_meta_inplace",
+                dtypes=integral_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_inplace",
+                dtypes=integral_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_meta_inplace",
+                dtypes=integral_types_and(torch.bool),
+            ),
+        ),
+    ),
+    ForeachFuncInfo(
+        'rsqrt',
+        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        supports_autograd=True,
+        supports_inplace_autograd=True,
+        supports_forward_ad=True,
+        backward_requires_result=True,
+        decorators=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_meta_inplace",
+                dtypes=integral_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_inplace",
+                dtypes=integral_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_meta_inplace",
+                dtypes=integral_types_and(torch.bool),
+            ),
+        ),
+    ),
+    ForeachFuncInfo(
+        'ceil',
+        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8),
+        supports_autograd=True,
+        supports_inplace_autograd=True,
+        supports_forward_ad=True,
+        decorators=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_meta_inplace",
+                dtypes=complex_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_meta_outplace",
+                dtypes=complex_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_inplace",
+                dtypes=complex_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_outplace",
+                dtypes=complex_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_meta_inplace",
+                dtypes=complex_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_meta_outplace",
+                dtypes=complex_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestForeach",
+                "test_autodiff",
+                device_type="cuda",
+                dtypes=(torch.complex128,),
+            ),
+        ),
+    ),
+    ForeachFuncInfo(
+        'erf',
+        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+        supports_autograd=True,
+        supports_inplace_autograd=True,
+        supports_forward_ad=True,
+        decorators=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_meta_inplace",
+                dtypes=integral_types_and(torch.bool) + complex_types(),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_meta_outplace",
+                dtypes=complex_types(),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_inplace",
+                dtypes=integral_types_and(torch.bool) + complex_types(),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_outplace",
+                dtypes=complex_types(),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_meta_inplace",
+                dtypes=integral_types_and(torch.bool) + complex_types(),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_meta_outplace",
+                dtypes=complex_types(),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestForeach",
+                "test_autodiff",
+                device_type="cuda",
+                dtypes=(torch.complex128,),
+            ),
+        ),
+    ),
+    ForeachFuncInfo(
+        'erfc',
+        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+        supports_autograd=True,
+        supports_inplace_autograd=True,
+        supports_forward_ad=True,
+        decorators=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_meta_inplace",
+                dtypes=integral_types_and(torch.bool) + complex_types(),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_meta_outplace",
+                dtypes=complex_types(),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_inplace",
+                dtypes=integral_types_and(torch.bool) + complex_types(),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_outplace",
+                dtypes=complex_types(),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_meta_inplace",
+                dtypes=integral_types_and(torch.bool) + complex_types(),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_meta_outplace",
+                dtypes=complex_types(),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestForeach",
+                "test_autodiff",
+                device_type="cuda",
+                dtypes=(torch.complex128,),
+            ),
+        ),
+    ),
+    ForeachFuncInfo(
+        'expm1',
+        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        dtypesIfHpu=custom_types(torch.float32),
+        supports_autograd=True,
+        supports_inplace_autograd=True,
+        supports_forward_ad=True,
+        backward_requires_result=True,
+        decorators=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_meta_inplace",
+                dtypes=integral_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_inplace",
+                dtypes=integral_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_meta_inplace",
+                dtypes=integral_types_and(torch.bool),
+            ),
+        ),
+    ),
+    ForeachFuncInfo(
+        'floor',
+        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+        supports_autograd=True,
+        supports_inplace_autograd=True,
+        supports_forward_ad=True,
+        decorators=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_meta_inplace",
+                dtypes=complex_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_meta_outplace",
+                dtypes=complex_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_inplace",
+                dtypes=complex_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_outplace",
+                dtypes=complex_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_meta_inplace",
+                dtypes=complex_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_meta_outplace",
+                dtypes=complex_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestForeach",
+                "test_autodiff",
+                device_type="cuda",
+                dtypes=(torch.complex128,),
+            ),
+        ),
+    ),
+    ForeachFuncInfo(
+        'log1p',
+        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+        supports_autograd=True,
+        supports_inplace_autograd=True,
+        supports_forward_ad=True,
+        decorators=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_meta_inplace",
+                dtypes=integral_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_inplace",
+                dtypes=integral_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_meta_inplace",
+                dtypes=integral_types_and(torch.bool),
+            ),
+        ),
+    ),
+    ForeachFuncInfo(
+        'round',
+        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+        supports_autograd=True,
+        supports_inplace_autograd=True,
+        supports_forward_ad=True,
+        decorators=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_meta_inplace",
+                dtypes=complex_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_meta_outplace",
+                dtypes=complex_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_inplace",
+                dtypes=complex_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_outplace",
+                dtypes=complex_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_meta_inplace",
+                dtypes=complex_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_meta_outplace",
+                dtypes=complex_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestForeach",
+                "test_autodiff",
+                device_type="cuda",
+                dtypes=(torch.complex128,),
+            ),
+        ),
+    ),
+    ForeachFuncInfo(
+        'frac',
+        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+        supports_autograd=True,
+        supports_inplace_autograd=True,
+        supports_forward_ad=True,
+        decorators=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_meta_inplace",
+                dtypes=integral_types_and(torch.bool) + complex_types(),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_meta_outplace",
+                dtypes=integral_types_and(torch.bool) + complex_types(),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_inplace",
+                dtypes=integral_types_and(torch.bool) + complex_types(),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_outplace",
+                dtypes=integral_types_and(torch.bool) + complex_types(),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_meta_inplace",
+                dtypes=integral_types_and(torch.bool) + complex_types(),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_meta_outplace",
+                dtypes=integral_types_and(torch.bool) + complex_types(),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestForeach",
+                "test_autodiff",
+                device_type="cuda",
+                dtypes=(torch.complex128,),
+            ),
+        ),
+    ),
+    ForeachFuncInfo(
+        'reciprocal',
+        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+        supports_autograd=True,
+        supports_inplace_autograd=True,
+        supports_forward_ad=True,
+        backward_requires_result=True,
+        decorators=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_meta_inplace",
+                dtypes=integral_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_inplace",
+                dtypes=integral_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_meta_inplace",
+                dtypes=integral_types_and(torch.bool),
+            ),
+        ),
+    ),
+    ForeachFuncInfo(
+        'sigmoid',
+        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16),
+        supports_autograd=True,
+        supports_inplace_autograd=True,
+        supports_forward_ad=True,
+        backward_requires_result=True,
+        decorators=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_meta_inplace",
+                dtypes=integral_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_inplace",
+                dtypes=integral_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_meta_inplace",
+                dtypes=integral_types_and(torch.bool),
+            ),
+        ),
+    ),
+    ForeachFuncInfo(
+        'trunc',
+        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8),
+        supports_autograd=True,
+        supports_inplace_autograd=True,
+        supports_forward_ad=True,
+        decorators=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_meta_inplace",
+                dtypes=complex_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_meta_outplace",
+                dtypes=complex_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_inplace",
+                dtypes=complex_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_outplace",
+                dtypes=complex_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_meta_inplace",
+                dtypes=complex_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_meta_outplace",
+                dtypes=complex_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestForeach",
+                "test_autodiff",
+                device_type="cuda",
+                dtypes=(torch.complex128,),
+            ),
+        ),
+    ),
+    ForeachFuncInfo(
+        'abs',
+        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+        supports_autograd=True,
+        supports_inplace_autograd=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        decorators=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_inplace",
+                dtypes=complex_types(),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_meta_inplace",
+                dtypes=complex_types(),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_meta_outplace",
+                device_type="cpu",
+                dtypes=(torch.bool,),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_inplace",
+                device_type="cpu",
+                dtypes=(torch.bool,),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_outplace",
+                device_type="cpu",
+                dtypes=(torch.bool,),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_meta_inplace",
+                device_type="cpu",
+                dtypes=(torch.bool,),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_meta_outplace",
+                device_type="cpu",
+                dtypes=(torch.bool,),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_meta_inplace",
+                device_type="cpu",
+                dtypes=(torch.bool,),
+            ),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", dtypes=complex_types()),
+        ),
+    ),
+    ForeachFuncInfo(
+        'zero',
+        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+        supports_autograd=True,
+        supports_inplace_autograd=True,
+        supports_forward_ad=True,
+        supports_out=False,
+    ),
+    ForeachFuncInfo(
+        'sign',
+        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+        supports_autograd=True,
+        supports_inplace_autograd=True,
+        supports_forward_ad=True,
+        decorators=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_meta_inplace",
+                dtypes=complex_types(),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_meta_outplace",
+                dtypes=complex_types(),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_inplace",
+                dtypes=complex_types(),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_outplace",
+                dtypes=complex_types(),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_meta_inplace",
+                dtypes=complex_types(),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_meta_outplace",
+                dtypes=complex_types(),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestForeach",
+                "test_autodiff",
+                device_type="cuda",
+                dtypes=(torch.complex128,),
+            ),
+        ),
+    ),
+    ForeachFuncInfo(
+        'lgamma',
+        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
+        supports_autograd=True,
+        supports_inplace_autograd=True,
+        supports_forward_ad=True,
+        decorators=(
+            DecorateInfo(unittest.skip("In-place lgamma not supported for integral tensors"), "TestMeta",
+                         "test_dispatch_symbolic_meta_inplace", dtypes=integral_types_and(torch.bool)),
+            # DecorateInfo(unittest.skip("In-place lgamma not supported for integral tensors"), "TestMeta",
+            #              "test_dispatch_meta_inplace", dtypes=integral_types_and(torch.bool)),
+            DecorateInfo(unittest.skip("In-place lgamma not supported for integral tensors"), "TestMeta",
+                         "test_meta_inplace", dtypes=integral_types_and(torch.bool)),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_meta_inplace",
+                dtypes=complex_types() + integral_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_meta_outplace",
+                dtypes=complex_types(),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_inplace",
+                dtypes=complex_types() + integral_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_outplace",
+                dtypes=complex_types(),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_meta_inplace",
+                dtypes=complex_types() + integral_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_meta_outplace",
+                dtypes=complex_types(),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestForeach",
+                "test_autodiff",
+                device_type="cuda",
+                dtypes=(torch.complex128,),
+            ),
+        ),
+    ),
+]
+
+foreach_binary_op_db: list[OpInfo] = [
+    ForeachFuncInfo(
+        "add",
+        sample_inputs_func=foreach_inputs_sample_func(2, True, True, True),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16, torch.int32),
+        supports_alpha_param=True,
+        supports_autograd=True,
+        supports_inplace_autograd=True,
+        supports_forward_ad=True,
+        decorators=(
+            # These tests fail because aten._local_scalar_dense is not implemented.
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace",
+                         dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16)),
+            # Samples have complex types and inplace only works if the dtype is complex.
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace",
+                         dtypes=(torch.bool,)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace",
+                         dtypes=(torch.bool,)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides",
+                         dtypes=integral_types() + complex_types_and(torch.bool, torch.bfloat16, torch.float16, torch.float64)),
+        ),
+    ),
+    ForeachFuncInfo(
+        "sub",
+        sample_inputs_func=foreach_inputs_sample_func(2, True, True),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+        supports_alpha_param=True,
+        supports_autograd=True,
+        supports_inplace_autograd=True,
+        supports_forward_ad=True,
+        decorators=(
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
+            DecorateInfo(unittest.skip("consistently fails internally and causes other tests to appear flaky"),
+                         "TestForeach", "test_parity", dtypes=(torch.complex128,),
+                         active_if=lambda kwargs: IS_FBCODE and not kwargs["noncontiguous"]),
+        ),
+    ),
+    ForeachFuncInfo(
+        "mul",
+        sample_inputs_func=foreach_inputs_sample_func(2, True, True, True),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+        supports_autograd=True,
+        supports_inplace_autograd=True,
+        supports_forward_ad=True,
+        decorators=(
+            # Samples have complex types and inplace only works if the dtype is complex.
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace",
+                         dtypes=(torch.bool,)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace",
+                         dtypes=(torch.bool,)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", dtypes=(torch.bool,)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides",
+                         dtypes=(torch.bool,)),
+            DecorateInfo(unittest.skip("consistently fails internally and causes other tests to appear flaky"),
+                         "TestForeach", "test_parity", dtypes=(torch.complex128,),
+                         active_if=lambda kwargs: IS_FBCODE and not kwargs["noncontiguous"]),
+        ),
+    ),
+    ForeachFuncInfo(
+        "div",
+        sample_inputs_func=foreach_inputs_sample_func(2, True, True, True),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16, torch.int32, torch.int8),
+        supports_autograd=True,
+        supports_inplace_autograd=True,
+        supports_forward_ad=True,
+        decorators=(
+            # Samples have complex types and inplace only works if the dtype is complex.
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace",
+                         dtypes=integral_types_and(torch.bool)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace",
+                         dtypes=integral_types_and(torch.bool)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace",
+                         dtypes=integral_types_and(torch.bool)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides",
+                         dtypes=integral_types_and(torch.bool)),
+        ),
+    ),
+    ForeachFuncInfo(
+        "clamp_min",
+        sample_inputs_func=foreach_inputs_sample_func(2, True, True),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16, torch.int64, torch.int32, torch.int8, torch.bool),
+        supports_autograd=True,
+        supports_inplace_autograd=True,
+        supports_forward_ad=True,
+        decorators=(
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace",
+                         dtypes=complex_types_and(torch.bool)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace",
+                         dtypes=complex_types_and(torch.bool)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace",
+                         dtypes=complex_types_and(torch.bool)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace",
+                         dtypes=complex_types_and(torch.bool)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace",
+                         dtypes=complex_types_and(torch.bool)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace",
+                         dtypes=complex_types_and(torch.bool)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides",
+                         dtypes=complex_types_and(torch.bool)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides",
+                         dtypes=complex_types_and(torch.bool)),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestForeach",
+                "test_autodiff",
+                device_type="cuda",
+                dtypes=(torch.complex128,),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestForeach",
+                "test_binary_op_scalar_with_overlapping_tensors",
+                dtypes=complex_types(),
+            ),
+        ),
+    ),
+    ForeachFuncInfo(
+        "clamp_max",
+        sample_inputs_func=foreach_inputs_sample_func(2, True, True),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16, torch.int64, torch.int32, torch.int8, torch.bool),
+        supports_autograd=True,
+        supports_inplace_autograd=True,
+        supports_forward_ad=True,
+        decorators=(
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace",
+                         dtypes=complex_types_and(torch.bool)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace",
+                         dtypes=complex_types_and(torch.bool)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace",
+                         dtypes=complex_types_and(torch.bool)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace",
+                         dtypes=complex_types_and(torch.bool)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace",
+                         dtypes=complex_types_and(torch.bool)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace",
+                         dtypes=complex_types_and(torch.bool)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides",
+                         dtypes=complex_types_and(torch.bool)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides",
+                         dtypes=complex_types_and(torch.bool)),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestForeach",
+                "test_autodiff",
+                device_type="cuda",
+                dtypes=(torch.complex128,),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestForeach",
+                "test_binary_op_scalar_with_overlapping_tensors",
+                dtypes=complex_types(),
+            ),
+        ),
+    ),
+    # note(crcrpar): forward ad not implemented.
+    ForeachFuncInfo(
+        "minimum",
+        sample_inputs_func=foreach_inputs_sample_func(2, True, True),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32),
+        supports_autograd=True,
+        supports_inplace_autograd=False,
+        supports_forward_ad=False,
+        decorators=(
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace",
+                         dtypes=complex_types_and(torch.bool)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace",
+                         dtypes=complex_types_and(torch.bool)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace",
+                         dtypes=complex_types_and(torch.bool)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace",
+                         dtypes=complex_types_and(torch.bool)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace",
+                         dtypes=complex_types_and(torch.bool)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace",
+                         dtypes=complex_types_and(torch.bool)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides",
+                         dtypes=complex_types_and(torch.bool)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides",
+                         dtypes=complex_types_and(torch.bool)),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestForeach",
+                "test_autodiff",
+                device_type="cuda",
+                dtypes=(torch.complex128,),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestForeach",
+                "test_binary_op_scalar_with_overlapping_tensors",
+                dtypes=complex_types(),
+            ),
+        ),
+    ),
+    # note(crcrpar): forward ad not implemented.
+    ForeachFuncInfo(
+        "maximum",
+        sample_inputs_func=foreach_inputs_sample_func(2, True, True),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32),
+        supports_autograd=True,
+        supports_forward_ad=False,
+        supports_inplace_autograd=False,
+        decorators=(
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace",
+                         dtypes=complex_types_and(torch.bool)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace",
+                         dtypes=complex_types_and(torch.bool)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace",
+                         dtypes=complex_types_and(torch.bool)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace",
+                         dtypes=complex_types_and(torch.bool)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace",
+                         dtypes=complex_types_and(torch.bool)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace",
+                         dtypes=complex_types_and(torch.bool)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides",
+                         dtypes=complex_types_and(torch.bool)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides",
+                         dtypes=complex_types_and(torch.bool)),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestForeach",
+                "test_autodiff",
+                device_type="cuda",
+                dtypes=(torch.complex128,),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestForeach",
+                "test_binary_op_scalar_with_overlapping_tensors",
+                dtypes=complex_types(),
+            ),
+        ),
+    ),
+    ForeachFuncInfo(
+        "pow",
+        supports_alpha_param=False,
+        supports_scalar_self_arg=True,
+        sample_inputs_func=foreach_inputs_sample_func(2, True, True),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16, torch.int32, torch.int8, torch.bool),
+        supports_autograd=True,
+        supports_inplace_autograd=True,
+        supports_forward_ad=True,
+        decorators=(
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", dtypes=(torch.bool,)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace",
+                         dtypes=(torch.bool,)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", dtypes=(torch.bool,)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", dtypes=(torch.bool,)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace",
+                         dtypes=(torch.bool,)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace",
+                         dtypes=(torch.bool,)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace",
+                         dtypes=(torch.bool,)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace",
+                         dtypes=(torch.bool,),),
+            DecorateInfo(unittest.skip("flaky"), "TestForeach", "test_parity", device_type="cpu", dtypes=(torch.complex64,)),
+            DecorateInfo(
+                unittest.skip("failed starting on ROCm 6.2"),
+                "TestForeach",
+                "test_parity",
+                device_type="cuda",
+                dtypes=(torch.complex64,),
+                active_if=TEST_WITH_ROCM),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestForeach",
+                "test_binary_op_with_scalar_self_support",
+                device_type="cuda",
+                dtypes=(torch.bool,),
+                active_if=lambda kwargs: kwargs["is_fastpath"],
+            ),
+        ),
+        backward_requires_result=True,
+    ),
+    ForeachFuncInfo(
+        "copy",
+        sample_inputs_func=foreach_inputs_sample_func(2, False, False),
+        supports_out=False,
+        supports_forward_ad=False,
+        supports_autograd=False,
+        supports_inplace_autograd=False,
+    )
+]
+
+foreach_pointwise_op_db: list[ForeachFuncInfo] = [
+    ForeachFuncInfo(
+        "addcmul",
+        sample_inputs_func=foreach_pointwise_sample_func(4, True, True),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+        supports_autograd=True,
+        supports_inplace_autograd=True,
+        supports_forward_ad=True,
+        decorators=(
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", dtypes=(torch.bool,)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace",
+                         dtypes=(torch.bool,)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", dtypes=(torch.bool,)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides",
+                         dtypes=(torch.bool,)),
+            # Samples have complex types and inplace only works if the dtype is complex.
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", dtypes=(torch.bool,)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace",
+                         dtypes=(torch.bool,)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", dtypes=(torch.bool,)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides",
+                         dtypes=integral_types() + complex_types_and(torch.bool)),
+        ),
+    ),
+    ForeachFuncInfo(
+        "addcdiv",
+        sample_inputs_func=foreach_pointwise_sample_func(4, True, True),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+        supports_autograd=True,
+        supports_inplace_autograd=True,
+        supports_forward_ad=True,
+        decorators=(
+            # Samples have complex types and inplace only works if the dtype is complex.
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace",
+                         dtypes=integral_types_and(torch.bool)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace",
+                         dtypes=integral_types_and(torch.bool)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace",
+                         dtypes=integral_types_and(torch.bool)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides",
+                         dtypes=integral_types() + complex_types_and(torch.bool)),
+            # fails because div_cpu is not implemented for ComplexHalf
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace",
+                         dtypes=integral_types_and(torch.bool)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace",
+                         dtypes=integral_types_and(torch.bool)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace",
+                         dtypes=integral_types_and(torch.bool)),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides",
+                         dtypes=integral_types() + complex_types_and(torch.bool)),
+        ),
+    ),
+]
+
+foreach_reduce_op_db: list[ForeachFuncInfo] = [
+    ForeachFuncInfo(
+        "max",
+        sample_inputs_func=foreach_max_sample_func(1, False, False),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32),
+        supports_autograd=True,
+        supports_inplace_autograd=True,
+        supports_forward_ad=True,
+        decorators=(
+            # no complex support for ordering ops like max
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestForeach",
+                "test_autodiff",
+                dtypes=(torch.complex128, torch.complex64),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestForeach",
+                "test_foreach_reduce_large_input",
+                dtypes=(torch.complex128, torch.complex64),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_outplace",
+                dtypes=(torch.complex128, torch.complex64),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_meta_outplace",
+                dtypes=(torch.complex128, torch.complex64),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_meta_outplace",
+                dtypes=(torch.complex128, torch.complex64),
+            ),
+        ),
+    ),
+    ForeachFuncInfo(
+        "norm",
+        sample_inputs_func=foreach_norm_sample_func(1, False, False),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+        supports_autograd=True,
+        supports_inplace_autograd=True,
+        supports_forward_ad=True,
+        decorators=(
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_meta_outplace",
+                dtypes=integral_types_and(torch.bool),
+            ),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_outplace",
+                dtypes=integral_types_and(torch.bool),
+            ),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_meta_outplace",
+                dtypes=integral_types_and(torch.bool),
+            ),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestForeach",
+                "test_foreach_reduce_large_input",
+                device_type="cuda",
+                dtypes=integral_types_and(torch.bool),
+            ),
+        ),
+    ),
+]
+
+foreach_other_op_db: list[ForeachFuncInfo] = [
+    ForeachFuncInfo(
+        "lerp",
+        sample_inputs_func=foreach_inputs_sample_func(3, True, True),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+        supports_autograd=True,
+        supports_inplace_autograd=True,
+        supports_forward_ad=True,
+        decorators=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_meta_inplace",
+                dtypes=integral_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_meta_outplace",
+                dtypes=integral_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_outplace",
+                dtypes=integral_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_inplace",
+                dtypes=integral_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_meta_inplace",
+                dtypes=integral_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_meta_outplace",
+                dtypes=integral_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_inplace_all_strides",
+                dtypes=integral_types_and(torch.bool),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestMeta",
+                "test_dispatch_symbolic_meta_outplace_all_strides",
+                dtypes=integral_types_and(torch.bool),
+            ),
+        ),
+    ),
+]
+
+def reference_sign(x):
+    if x.dtype == np.bool_:
+        # `np.sign` doesn't support `bool`.
+        # >>> np.sign(True)
+        # ufunc 'sign' did not contain a loop
+        # with signature matching types dtype('bool') -> dtype('bool')
+        return np.sign(x, dtype=np.uint8).astype(np.bool_)
+    return np.sign(x)
+
+
+def reference_sgn(x):
+    # NumPy doesn't have an equivalent to `torch.sgn` when the dtype is complex.
+    # For complex inputs, `np.sign` returns sign(x.real) + 0j if x.real != 0, else sign(x.imag) + 0j,
+    # while `torch.sgn` returns 0 if abs(input) == 0, else input / abs(input).
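+    # For example (illustrative): for x = 3 + 4j, torch.sgn (and this reference) gives
+    # 0.6 + 0.8j, i.e. x / |x|.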
+    if x.dtype not in [np.complex64, np.complex128]:
+        return reference_sign(x)
+
+    out = (x / np.abs(x))
+    if out.ndim == 0:
+        # Handle x == 0 case
+        if (x == 0):
+            # Can't assign into a 0-d complex array, so make a new one.
+            return np.array(complex(0, 0), dtype=x.dtype)
+        return out
+
+    # Handle x == 0 case
+    mask = (x == 0)
+    out[mask] = complex(0, 0)
+    return out
+
+
+def reference_sigmoid(x):
+    # `scipy.special.expit` does not support complex inputs, so compute the sigmoid directly.
+    if x.dtype in [np.complex64, np.complex128]:
+        return (1 / (1 + np.exp(-x)))
+    return scipy.special.expit(x)
+
+
+def reference_logsigmoid(x):
+    return np.where(
+        x < 0,
+        x - np.log1p(np.exp(x)),
+        -np.log1p(np.exp(-x)))
+
+
+def reference_hardsigmoid(x):
+    intermediate = x / 6 + 0.5
+    y = np.clip(intermediate, 0, None)
+    return np.where(y > 1, 1, y).astype(x.dtype)
+
+
+def reference_lgamma(x):
+    # scipy.special.gammaln returns `-inf` when the input is `-inf`,
+    # while PyTorch, C, and C++ all return `inf` for a `-inf` input.
+    # Reference:
+    # https://en.cppreference.com/w/cpp/numeric/math/lgamma
+    # https://en.cppreference.com/w/c/numeric/math/lgamma
+
+    # To handle the above discrepancy,
+    # we replace -inf with inf so values
+    # that were originally -inf map to inf as expected
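+    # Illustrative: per the note above, scipy.special.gammaln(-np.inf) -> -inf, whereas
+    # torch.lgamma(torch.tensor(float('-inf'))) -> inf, so -inf inputs are remapped to inf
+    # before calling scipy.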
+    if x.dtype.kind == 'f':
+        x = np.where(x == float('-inf'), np.array(float('inf'), dtype=x.dtype), x)
+
+    out = scipy.special.gammaln(x)
+
+    if x.dtype == np.float16:
+        # `scipy.special.gammaln` returns float32 output for float16 input,
+        # while `torch.lgamma` preserves `float16`. Due to the smaller range of float16,
+        # the PyTorch version outputs `inf` where SciPy returns finite values.
+        out = out.astype(np.float16)
+
+    return out
+
+
+def reference_mvlgamma(x, d):
+    if x.dtype == np.float16:
+        return scipy.special.multigammaln(x, d).astype(np.float16)
+
+    return scipy.special.multigammaln(x, d)
+
+def reference_softplus(input, beta=1, threshold=20):
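+    # Mirrors F.softplus: log(1 + exp(beta * x)) / beta while x * beta <= threshold,
+    # and the identity (x) once x * beta exceeds the threshold.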
+    non_linear = input * beta <= threshold
+    output = input.copy()
+    output[non_linear] = np.log(1 + np.exp(beta * input[non_linear])) / beta
+    return output
+
+def reference_gelu(X, *, approximate='none'):
+    def _gelu_ref(X):
+        return X * stats.norm.cdf(X)
+
+    def _tanh_gelu_ref(X):
+        M_SQRT_2_PI = math.sqrt(2 / math.pi)
+        Z = M_SQRT_2_PI * (X + 0.044715 * np.power(X, 3.0))
+        return 0.5 * X * (1.0 + np.tanh(Z))
+
+    if approximate == 'tanh':
+        return _tanh_gelu_ref(X)
+    else:
+        return _gelu_ref(X)
+
+
+def reference_one_hot(a: npt.NDArray, num_classes: int = -1) -> npt.NDArray:
+    if num_classes == -1:
+        num_classes = int(np.amax(a) + 1)
+
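+    # Each flattened element i contributes flat index a[i] + i * num_classes into the
+    # (a.size, num_classes) zero matrix below; that position is then set to 1.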
+    idcs = a.reshape(-1) + np.arange(0, a.size, dtype=np.int64) * num_classes
+    one_hot = np.zeros((a.size, num_classes), dtype=a.dtype)
+    np.put(one_hot, idcs, 1)
+    return one_hot.reshape(*a.shape, -1)
+
+
+def reference_mse_loss(input, target, reduction="mean"):
+    se = (input - target) ** 2
+    if reduction == "mean":
+        return np.mean(se)
+    elif reduction == "sum":
+        return np.sum(se)
+    else:  # reduction == "none"
+        return se
+
+
+def reference_layer_norm(inp: npt.NDArray, normalized_shape: tuple[int], weight=None, bias=None, eps=1e-5):
+    return reference_native_layer_norm(inp, normalized_shape, weight, bias, eps)[0]
+
+
+def reference_native_layer_norm(inp: npt.NDArray, normalized_shape: tuple[int], weight, bias, eps):
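+    # Mirrors aten.native_layer_norm: returns (Y, mean, rstd) with
+    # Y = (x - mean) / sqrt(var + eps) * weight + bias, computed over normalized_shape.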
+    feature_size = np.prod(normalized_shape)
+    inp_view = inp.reshape(-1, feature_size)  # type: ignore[call-overload]
+    mean = inp_view.mean(axis=-1, keepdims=True)
+    var = inp_view.var(axis=-1, ddof=0, keepdims=True)
+    Y = (inp_view - mean) / np.sqrt(var + eps)
+    if weight is None and bias is not None:
+        Y = Y + bias.reshape(-1)
+    elif weight is not None and bias is None:
+        Y = Y * weight.reshape(-1)
+    elif weight is not None and bias is not None:
+        Y = Y * weight.reshape(-1) + bias.reshape(-1)
+    axis = inp.ndim - len(normalized_shape)
+    stat_shape = inp.shape[:axis] + (1,) * len(normalized_shape)
+    return Y.reshape(*inp.shape), mean.reshape(stat_shape), (1.0 / np.sqrt(var + eps)).reshape(stat_shape)
+
+
+def reference_rms_norm(inp: npt.NDArray, normalized_shape: tuple[int], weight=None, eps=None):
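+    # RMSNorm reference: Y = x / sqrt(mean(x**2 over normalized_shape) + eps), optionally scaled by weight.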
+    if eps is None:
+        eps = torch.finfo(numpy_to_torch_dtype(inp.dtype)).eps
+    feature_size = np.prod(normalized_shape)
+    inp_view = inp.reshape(-1, feature_size)  # type: ignore[call-overload]
+    rms = np.sqrt((inp_view**2).mean(axis=-1, keepdims=True) + eps)
+    Y = inp_view / rms
+    if weight is not None:
+        Y = Y * weight.reshape(-1)
+    return Y.reshape(*inp.shape)
+
+
+def reference_group_norm(inp: npt.NDArray, num_groups: int, weight=None, bias=None, eps=1e-5):
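+    # Normalizes each of the num_groups channel groups to zero mean / unit variance,
+    # then applies the optional per-channel weight and bias.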
+    inp_view = inp
+    if np.prod(inp.shape) != 0:
+        inp_view = inp.reshape((inp.shape[0], num_groups, -1))
+    mean = inp_view.mean(axis=-1, keepdims=True)
+    var = inp_view.var(axis=-1, ddof=0, keepdims=True)
+    Y = (inp_view - mean) / np.sqrt(var + eps)
+    Y = Y.reshape(inp.shape)
+    if weight is not None:
+        # weight is a vector of length equal to the channel
+        if len(Y.shape) > 2:
+            weight = np.expand_dims(weight, [0] + [idx + 2 for idx in range(inp.ndim - 2)])
+        Y = Y * weight
+    if bias is not None:
+        # bias is a vector of length equal to the channel
+        if len(Y.shape) > 2:
+            bias = np.expand_dims(bias, [0] + [idx + 2 for idx in range(inp.ndim - 2)])
+        Y = Y + bias
+    return Y
+
+
+# Custom reference function: numpy's searchsorted only takes a string `side` arg (there is no separate
+# boolean `right` arg) and has no `out_int32` arg. Additionally, numpy doesn't support searchsorted on
+# ND arrays, so this splits ND inputs into stacked 1D cases.
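+# (So, for 1D inputs, reference_searchsorted(seq, vals, right=True) matches
+# np.searchsorted(seq, vals, side='right'), with an optional cast to int32.)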
+def reference_searchsorted(sorted_sequence, boundary, out_int32=False, right=False, side='left', sorter=None):
+    side = 'right' if (right or side == 'right') else 'left'
+    if len(sorted_sequence.shape) == 1:
+        ret = np.searchsorted(sorted_sequence, boundary, side=side, sorter=sorter)
+        return ret.astype(np.int32) if out_int32 else ret
+    elif sorted_sequence.shape[0] == 0:
+        if sorter is not None:
+            sorter = sorter.flatten()
+        ret = np.searchsorted(sorted_sequence.flatten(), boundary.flatten(), side=side, sorter=sorter)
+        ret = ret.astype(np.int32) if out_int32 else ret
+        return ret.reshape(boundary.shape)
+    else:
+        # numpy searchsorted only supports 1D inputs so we split up ND inputs
+        orig_shape = boundary.shape
+        num_splits = np.prod(sorted_sequence.shape[:-1])
+        splits = range(0, num_splits)
+        sorted_sequence, boundary = sorted_sequence.reshape(num_splits, -1), boundary.reshape(num_splits, -1)
+        if sorter is not None:
+            sorter = sorter.reshape(num_splits, -1)
+
+        split_sequence = [sorted_sequence[i] for i in splits]
+        split_boundary = [boundary[i] for i in splits]
+        split_sorter = [sorter[i] if (sorter is not None) else None for i in splits]
+
+        split_ret = [np.searchsorted(s_seq, b, side=side, sorter=s_sort)
+                     for (s_seq, b, s_sort) in zip(split_sequence, split_boundary, split_sorter)]
+        split_ret = [i.astype(np.int32) for i in split_ret] if out_int32 else split_ret
+        return np.stack(split_ret).reshape(orig_shape)
+
+def loss_reference_reduction_wrapper(fn):
+    def wrapper(input, target, *, size_average=None, reduce=None, reduction="mean", **other_kwargs):
+        if size_average is not None or reduce is not None:
+            raise RuntimeError(
+                "The keyword arguments 'size_average' and 'reduce' are deprecated and not supported by this wrapper"
+            )
+        output = fn(input, target, **other_kwargs)
+        if reduction == "mean":
+            return np.mean(output)
+        elif reduction == "sum":
+            return np.sum(output)
+        else:  # reduction == "none"
+            return output
+
+    return wrapper
+
+@loss_reference_reduction_wrapper
+def reference_smooth_l1_loss(input, target, beta=1.0):
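+    # Smooth L1: diff**2 / (2 * beta) when |diff| < beta, else |diff| - 0.5 * beta.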
+    diff = input - target
+    abs_diff = np.abs(diff)
+    above_threshold = abs_diff >= beta
+
+    loss = np.empty_like(input)
+    loss[above_threshold] = abs_diff[above_threshold] - 0.5 * beta
+    loss[~above_threshold] = diff[~above_threshold] ** 2 / (2 * beta)
+
+    return loss
+
+def reference_std_var(f):
+    """Forwards unbiased/correction kwargs as NumPy's equivalent ddof"""
+    g = reference_reduction_numpy(f)
+
+    @wraps(g)
+    def wrapper(x: npt.NDArray, *args, **kwargs):
+        assert not ('unbiased' in kwargs and 'correction' in kwargs)
+
+        if 'unbiased' in kwargs:
+            kwargs['ddof'] = int(kwargs.pop('unbiased'))
+        elif 'correction' in kwargs:
+            kwargs['ddof'] = kwargs.pop('correction')
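+        # e.g. unbiased=True -> ddof=1 (Bessel's correction), unbiased=False -> ddof=0,
+        # while an explicit correction value is passed through as ddof unchanged.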
+
+        return g(x, *args, **kwargs)
+
+    return wrapper
+
+def generate_std_var_kwargs(t: torch.Tensor, **kwargs):
+    """Generates unbiased/correction kwargs for std/var operators"""
+    yield ((), {'unbiased': True})
+    yield ((), {'unbiased': False})
+
+    # Currently, calling std with correction is only enabled when
+    # both dim and keepdim are provided.
+    if 'dim' in kwargs and 'keepdim' in kwargs:
+        yield ((), {'correction': 0})
+        yield ((), {'correction': 1})
+
+        numel = torch.tensor(t.shape)[kwargs.get('dim')].prod()
+        yield ((), {'correction': numel // 2})
+
+def error_inputs_mean(op_info, device, is_ref=False, **kwargs):
+    if is_ref:
+        err_msg1 = (r"mean\(\): could not infer output dtype. "
+                    r"Input dtype must be either a floating point or complex dtype. "
+                    r"Got: torch.int64")
+    else:
+        err_msg1 = (r"mean\(\): could not infer output dtype. "
+                    r"Input dtype must be either a floating point or complex dtype. "
+                    r"Got: Long")
+    yield ErrorInput(
+        SampleInput(make_tensor((3, 4, 5), dtype=torch.int64, device=device), []),
+        error_regex=err_msg1,
+    )
+
+    if is_ref:
+        err_msg2 = (r"mean\(\): could not infer output dtype. "
+                    r"Optional dtype must be either a floating point or complex dtype. "
+                    r"Got: torch.int64")
+    else:
+        err_msg2 = (r"mean\(\): could not infer output dtype. "
+                    r"Optional dtype must be either a floating point or complex dtype. "
+                    r"Got: Long")
+    yield ErrorInput(
+        SampleInput(
+            make_tensor((3, 4, 5), dtype=torch.float32, device=device),
+            [],
+            dtype=torch.int64),
+        error_regex=err_msg2
+    )
+
+# NumPy implementation of torch.flatten.
+# Unfortunately there's no np.flatten, so we compute the desired shape and call np.reshape.
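+# For example (illustrative): reference_flatten(np.zeros((2, 3, 4)), start_dim=1, end_dim=2)
+# returns an array of shape (2, 12).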
+def reference_flatten(input, start_dim=0, end_dim=-1):
+    in_shape = input.shape
+    in_rank = len(in_shape)
+    for d in start_dim, end_dim:
+        if not ((in_rank == 0 and d in (-1, 0)) or -in_rank <= d < in_rank):
+            raise IndexError(f"Dimension out of range (expected to be in range of [{-in_rank}, {in_rank - 1}], but got {d}")
+    end_dim = end_dim if end_dim >= 0 else in_rank + end_dim
+    start_dim = start_dim if start_dim >= 0 else in_rank + start_dim
+    if in_rank == 0:
+        end_dim = start_dim
+    if end_dim < start_dim:
+        raise RuntimeError("flatten() has invalid args: start_dim cannot come after end_dim")
+    flatten_bit_dim = functools.reduce(operator.mul, in_shape[start_dim:end_dim + 1], 1)
+    out_shape = in_shape[:start_dim] + (flatten_bit_dim,) + in_shape[end_dim + 1:]
+    return np.reshape(input, out_shape)
+
+
+def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs):
+    yield SampleInput(make_tensor((S,), dtype=dtype, device=device, requires_grad=requires_grad))
+    yield SampleInput(make_tensor((), dtype=dtype, device=device, requires_grad=requires_grad))
+
+
+# Operator database (sorted alphabetically)
+op_db: list[OpInfo] = [
+    UnaryUfuncInfo('abs',
+                   aliases=('absolute', ),
+                   ref=np.abs,
+                   dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.chalf),
+                   dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
+                   dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+                   skips=(
+                       DecorateInfo(unittest.skip("In-place abs not supported for complex tensors"), 'TestBwdGradients',
+                                    'test_inplace_grad', dtypes=(torch.cdouble,)),
+                       DecorateInfo(unittest.skip("In-place abs not supported for complex tensors"), 'TestBwdGradients',
+                                    'test_inplace_gradgrad', dtypes=(torch.cdouble,)),
+                       DecorateInfo(unittest.skip("In-place abs not supported for complex tensors"), 'TestFwdGradients',
+                                    'test_inplace_forward_mode_AD', dtypes=(torch.cdouble,)),
+                       DecorateInfo(unittest.skip("In-place abs not supported for complex tensors"), "TestSparseUnaryUfuncs",
+                                    "test_inplace", dtypes=(torch.cdouble, torch.cfloat, torch.chalf)),
+                       # Reference: https://github.com/pytorch/pytorch/issues/49224
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
+                                    dtypes=[torch.int8], active_if=TEST_WITH_ASAN),
+                       # TODO: Fix test_out_arg_all_dtypes as torch.empty_like(expected_output) where expected_output=op(input)
+                       # We can break the logic of the loop over all possible types but it is OK.
+                       # https://github.com/pytorch/pytorch/blob/master/test/test_unary_ufuncs.py#L440-L449
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_out_arg_all_dtypes',
+                                    dtypes=[torch.cfloat, torch.cdouble]),
+                       DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta_inplace',
+                                    dtypes=(torch.cdouble, torch.cfloat, torch.chalf)),
+                       DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_meta_inplace',
+                                    dtypes=(torch.cdouble, torch.cfloat, torch.chalf)),
+                       DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_inplace',
+                                    dtypes=(torch.cdouble, torch.cfloat, torch.chalf)),
+                       DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_inplace_all_strides',
+                                    dtypes=(torch.cdouble, torch.cfloat, torch.chalf)),
+                   ),
+                   supports_fwgrad_bwgrad=True,
+                   assert_autodiffed=True,
+                   supports_sparse=True,
+                   supports_sparse_csr=True,
+                   supports_sparse_csc=True,
+                   supports_sparse_bsr=True,
+                   supports_sparse_bsc=True,
+                   supports_forward_ad=True),
+    # NOTE: CPU complex acos produces incorrect outputs (https://github.com/pytorch/pytorch/issues/42952)
+    UnaryUfuncInfo('acos',
+                   aliases=('arccos', ),
+                   ref=np.arccos,
+                   domain=(-1, 1),
+                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                   dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
+                   dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+                   assert_autodiffed=True,
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   promotes_int_to_float=True,
+                   decorators=(precisionOverride({torch.float16: 1e-2,
+                                                  torch.bfloat16: 1e-1,
+                                                  torch.complex64: 1e-2}),),
+                   skips=(
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal',
+                                    device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
+                                    device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS),
+                       # Failing with wrong imaginary sign on at least some Windows jobs
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
+                                    device_type='cuda', dtypes=[torch.cdouble],
+                                    active_if=IS_WINDOWS),
+                       # Failing with wrong imaginary sign on at least some Windows jobs
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
+                                    device_type='cuda', dtypes=[torch.cdouble],
+                                    active_if=IS_WINDOWS),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad',
+                                    dtypes=[torch.cdouble], active_if=IS_WINDOWS),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_method_grad',
+                                    dtypes=[torch.cdouble], active_if=IS_WINDOWS),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_inplace_grad',
+                                    dtypes=[torch.cdouble], active_if=IS_WINDOWS),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD',
+                                    dtypes=[torch.cdouble], active_if=IS_WINDOWS),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_inplace_forward_mode_AD',
+                                    dtypes=[torch.cdouble], active_if=IS_WINDOWS),)),
+    # NOTE: the derivative for inplace acosh is not implemented
+    UnaryUfuncInfo('acosh',
+                   aliases=('arccosh', ),
+                   ref=np.arccosh,
+                   domain=(1, None),
+                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                   dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
+                   dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+                   decorators=(precisionOverride({torch.bfloat16: 5e-2}),),
+                   supports_inplace_autograd=False,
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   promotes_int_to_float=True,
+                   skips=(
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal',
+                                    device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
+                                    device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
+                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
+                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
+                                    device_type='cuda', dtypes=[torch.cdouble],
+                                    active_if=IS_WINDOWS),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
+                                    device_type='cuda', dtypes=[torch.cdouble],
+                                    active_if=IS_WINDOWS),
+                       # Failing with wrong imaginary sign on at least some Windows jobs
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
+                                    device_type='cuda', dtypes=[torch.cdouble],
+                                    active_if=IS_WINDOWS),
+                   ),
+                   # acosh is not defined at x < 1 (real)
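+                   # (Presumably the filter substitutes safe_val=2, which lies inside acosh's real
+                   #  domain, for inputs where the condition holds before comparing against NumPy.)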
+                   reference_numerics_filter=NumericsFilter(
+                       condition=lambda x: (x < 1 if not x.is_complex() else torch.zeros_like(x, dtype=torch.bool)),
+                       safe_val=2)),
+    BinaryUfuncInfo('add',
+                    # NumPy has no builtin reference for the alpha kwarg, but it is easy enough to emulate
+                    ref=lambda input, other, *, alpha=1: np.add(input, other) if alpha == 1 \
+                    else np.add(input, np.multiply(alpha, other)),
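+                    # e.g. (illustrative): torch.add(a, b, alpha=2) computes a + 2 * b, which the
+                    # reference above emulates as np.add(a, np.multiply(2, b))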
+                    dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16,
+                                                     torch.float16, torch.chalf),
+                    dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32),
+                    assert_autodiffed=True,
+                    sample_inputs_func=sample_inputs_add_sub,
+                    supports_fwgrad_bwgrad=True,
+                    supports_forward_ad=True,
+                    supports_two_python_scalars=True,
+                    decorators=(
+                        DecorateInfo(
+                            toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=0)}),
+                            'TestBinaryUfuncs', 'test_reference_numerics'),
+                    ),
+                    skips=(
+                        # boolean alpha not handled properly
+                        DecorateInfo(unittest.expectedFailure,
+                                     'TestNNCOpInfo',
+                                     'test_nnc_correctness',
+                                     dtypes=(torch.bool,)),
+                        DecorateInfo(unittest.skip("Skipped!"),
+                                     'TestCommon',
+                                     'test_numpy_refs',
+                                     dtypes=(torch.complex128,)),
+                        DecorateInfo(unittest.skip("Skipped!"),
+                                     'TestBinaryUfuncs',
+                                     'test_reference_numerics_extremal_values',
+                                     dtypes=(torch.complex64, torch.complex128)),
+                    )),
+    OpInfo('item',
+           op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.item, inp, *args, **kwargs),
+           ref=np.ndarray.item,
+           method_variant=None,
+           dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.chalf, torch.bool),
+           dtypesIfHpu=custom_types(torch.float32),
+           supports_out=False,
+           supports_autograd=False,
+           error_inputs_func=error_inputs_item,
+           sample_inputs_func=sample_inputs_item,
+           skips=(
+               # Error testing item function variant
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',
+                            dtypes=(torch.float32, torch.complex64)),
+               # FX failed to normalize op - add the op to the op_skip list.
+               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+               # RuntimeError: Composite compliance check failed with the above error.
+               DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'),
+               # Booleans mismatch: AssertionError: False is not true
+               DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_autocast'),
+               # Booleans mismatch: AssertionError: False is not true
+               DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake'),
+           )),
+    OpInfo('arange',
+           dtypes=all_types_and(torch.bfloat16, torch.float16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8),
+           supports_out=True,
+           supports_autograd=False,
+           is_factory_function=True,
+           error_inputs_func=error_inputs_arange,
+           sample_inputs_func=sample_inputs_arange,
+           skips=(
+               # https://github.com/pytorch/pytorch/issues/81774
+               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+
+               # Tests that assume input is a tensor or sequence of tensors
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
+
+               # Lazy tensor failures
+               DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_dispatched_to_lazy'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestLazyOpInfo', 'test_correctness'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestLazyOpInfo', 'test_correctness_with_reusing_ir'),
+
+               # Exception raised from analyzeImpl at ../torch/csrc/jit/ir/alias_analysis.cpp:608
+               # We don't have an op for aten::arange but it isn't a special case.
+               # Argument types: bool, bool, bool, int, int, Device, bool
+               DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness'),
+
+               # Captured graph does not contain aten::arange (succeeds on complex!)
+               # g: graph():
+               #   %25 : Long(1, strides=[1], requires_grad=0, device=cpu) = prim::Constant[value={1}]()
+               #   return (%25)
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
+
+               # UserWarning not triggered: Resized a non-empty tensor but did not warn about it.
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
+           )),
+    OpInfo('cauchy',
+           op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.cauchy_, inp, *args, **kwargs),
+           inplace_variant=torch.Tensor.cauchy_,
+           dtypes=floating_types_and(torch.float16, torch.bfloat16),
+           supports_out=False,
+           supports_autograd=False,
+           allow_cow_input_materialize_forward=[0],
+           sample_inputs_func=sample_inputs_cauchy,
+           error_inputs_func=error_inputs_cauchy,
+           skips=(
+               # Tests that assume input tensor has a meaningful effect on output tensor
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+
+               # AssertionError: JIT Test does not execute any logic
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+
+               # AssertionError: Tensor-likes are not close!
+               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
+
+               # FX failed to normalize op - add the op to the op_skip list.
+               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+
+               # vmap: calling random operator not supported
+               DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"),
+               DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"),
+
+               DecorateInfo(unittest.skip("make_traced() doesn't set seed properly!"), 'TestCommon', 'test_python_ref_executor'),
+
+               DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'),
+           )),
+    OpInfo('exponential',
+           op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.exponential_, inp, *args, **kwargs),
+           inplace_variant=torch.Tensor.exponential_,
+           dtypes=floating_types_and(torch.float16, torch.bfloat16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+           supports_out=False,
+           supports_autograd=False,
+           allow_cow_input_materialize_forward=[0],
+           sample_inputs_func=sample_inputs_exponential,
+           error_inputs_func=error_inputs_exponential,
+           skips=(
+               # Tests that assume input tensor has a meaningful effect on output tensor
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+
+               # AssertionError: JIT Test does not execute any logic
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+
+               # AssertionError: Tensor-likes are not close!
+               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
+
+               # FX failed to normalize op - add the op to the op_skip list.
+               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+
+               # vmap: calling random operator not supported
+               DecorateInfo(unittest.expectedFailure, "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"),
+               DecorateInfo(unittest.expectedFailure, "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"),
+
+               DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'),
+           )),
+    OpInfo('geometric',
+           op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.geometric_, inp, *args, **kwargs),
+           inplace_variant=torch.Tensor.geometric_,
+           dtypes=floating_types_and(torch.float16, torch.bfloat16, torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+           supports_out=False,
+           supports_autograd=False,
+           allow_cow_input_materialize_forward=[0],
+           sample_inputs_func=sample_inputs_geometric,
+           error_inputs_func=error_inputs_geometric,
+           skips=(
+               # Tests that assume input tensor has a meaningful effect on output tensor
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+
+               # AssertionError: JIT Test does not execute any logic
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+
+               # AssertionError: Tensor-likes are not close!
+               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
+
+               # FX failed to normalize op - add the op to the op_skip list.
+               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+
+               # vmap: calling random operator not supported
+               DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"),
+               DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"),
+
+               DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'),
+           )),
+    OpInfo('log_normal',
+           op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.log_normal_, inp, *args, **kwargs),
+           inplace_variant=torch.Tensor.log_normal_,
+           dtypes=floating_types_and(torch.float16, torch.bfloat16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+           supports_out=False,
+           supports_autograd=False,
+           allow_cow_input_materialize_forward=[0],
+           sample_inputs_func=sample_inputs_log_normal,
+           error_inputs_func=error_inputs_log_normal,
+           skips=(
+               # Tests that assume input tensor has a meaningful effect on output tensor
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+
+               # AssertionError: JIT Test does not execute any logic
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+
+               # AssertionError: Tensor-likes are not close!
+               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
+               # FX failed to normalize op - add the op to the op_skip list.
+               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+
+               # vmap: calling random operator not supported
+               DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"),
+               DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"),
+               DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'),
+           )),
+    OpInfo('normal',
+           variant_test_name='in_place',
+           op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.normal_, inp, *args, **kwargs),
+           inplace_variant=torch.Tensor.normal_,
+           dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+           supports_out=False,
+           supports_autograd=False,
+           allow_cow_input_materialize_forward=[0],
+           sample_inputs_func=sample_inputs_normal,
+           error_inputs_func=error_inputs_normal,
+           skips=(
+               # Tests that assume input is a tensor or sequence of tensors
+               DecorateInfo(unittest.skip("Test expects tensor input"), "TestCommon", "test_noncontiguous_samples"),
+
+               # Tests that assume input tensor has a meaningful effect on output tensor
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
+               DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'),
+               # AssertionError: JIT Test does not execute any logic
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+               # AssertionError: Tensor-likes are not close!
+               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
+               # FX failed to normalize op - add the op to the op_skip list.
+               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+               # vmap: calling random operator not supported
+               DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"),
+               DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"),
+           )),
+    OpInfo('uniform',
+           op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.uniform_, inp, *args, **kwargs),
+           method_variant=None,
+           inplace_variant=torch.Tensor.uniform_,
+           dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+           supports_out=False,
+           supports_autograd=False,
+           is_factory_function=False,
+           allow_cow_input_materialize_forward=[0],
+           sample_inputs_func=sample_inputs_uniform,
+           error_inputs_func=error_inputs_uniform,
+           skips=(
+               # FX failed to normalize op - add the op to the op_skip list.
+               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+               # Tests that assume input tensor has a meaningful effect on output tensor
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
+               # AssertionError: JIT Test does not execute any logic
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+               # aten.uniform was not decomposed
+               DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'),
+               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
+           )),
+    BinaryUfuncInfo('clamp_max',
+                    ref=_clamp_max_numpy,
+                    dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
+                    dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8),
+                    supports_forward_ad=True,
+                    supports_rhs_python_scalar=False,
+                    supports_fwgrad_bwgrad=True,
+                    rhs_make_tensor_kwargs=dict(exclude_zero=False),
+                    skips=(
+                        # RuntimeError: "max_elementwise_cuda" not implemented for 'ComplexFloat'
+                        DecorateInfo(unittest.expectedFailure,
+                                     'TestBinaryUfuncs',
+                                     'test_type_promotion',
+                                     device_type='cuda'),
+                        # dispatch to lazy test failed
+                        DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_dispatched_to_lazy'),
+                        # error test marked as expected failure since a non-tensor Python scalar rhs is actually supported
+                        DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_errors'),
+                    )),
+    BinaryUfuncInfo('clamp_min',
+                    ref=_clamp_min_numpy,
+                    dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
+                    dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8),
+                    supports_forward_ad=True,
+                    supports_rhs_python_scalar=False,
+                    supports_fwgrad_bwgrad=True,
+                    rhs_make_tensor_kwargs=dict(exclude_zero=False),
+                    skips=(
+                        # RuntimeError: "min_elementwise_cuda" not implemented for 'ComplexFloat'
+                        DecorateInfo(unittest.expectedFailure,
+                                     'TestBinaryUfuncs',
+                                     'test_type_promotion',
+                                     device_type='cuda'),
+                        # dispatch to lazy test failed
+                        DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_dispatched_to_lazy'),
+                        # error test marked as expected failure since a non-tensor Python scalar rhs is actually supported
+                        DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_errors'),
+                    )),
+    BinaryUfuncInfo('mul',
+                    aliases=('multiply',),
+                    dtypes=all_types_and_complex_and(torch.chalf, torch.float16, torch.bfloat16, torch.bool),
+                    dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+                    assert_autodiffed=True,
+                    supports_forward_ad=True,
+                    supports_fwgrad_bwgrad=True,
+                    supports_two_python_scalars=True,
+                    error_inputs_sparse_func=error_inputs_sparse_mul,
+                    sample_inputs_sparse_coo_func=partial(sample_inputs_sparse_mul, layout=torch.sparse_coo),
+                    sample_inputs_sparse_csr_func=partial(sample_inputs_sparse_mul, layout=torch.sparse_csr),
+                    sample_inputs_sparse_csc_func=partial(sample_inputs_sparse_mul, layout=torch.sparse_csc),
+                    sample_inputs_sparse_bsr_func=partial(sample_inputs_sparse_mul, layout=torch.sparse_bsr),
+                    sample_inputs_sparse_bsc_func=partial(sample_inputs_sparse_mul, layout=torch.sparse_bsc)),
+    BinaryUfuncInfo('sub',
+                    # NumPy has no builtin reference for the alpha kwarg, but it is easy enough to emulate
+                    ref=lambda input, other, *, alpha=1: np.subtract(input, np.multiply(alpha, other)),
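+                    # e.g. (illustrative): torch.sub(a, b, alpha=2) computes a - 2 * b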
+                    aliases=('subtract',),
+                    dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.chalf),
+                    dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+                    assert_autodiffed=True,
+                    supports_forward_ad=True,
+                    supports_fwgrad_bwgrad=True,
+                    sample_inputs_func=sample_inputs_add_sub,
+                    supports_two_python_scalars=True,
+                    decorators=(
+                        DecorateInfo(
+                            toleranceOverride({torch.float16: tol(atol=1e-2, rtol=0),
+                                               torch.bfloat16: tol(atol=1e-5, rtol=5e-3),
+                                               torch.complex32: tol(atol=1e-5, rtol=1e-3)}),
+                            'TestBinaryUfuncs', 'test_reference_numerics'),
+                        DecorateInfo(
+                            toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=0)}),
+                            'TestCommon', 'test_complex_half_reference_testing', device_type='cpu'),
+                        DecorateInfo(
+                            toleranceOverride({torch.chalf: tol(atol=5e-3, rtol=0)}),
+                            'TestDecomp', 'test_comprehensive', device_type='cpu'),
+                        DecorateInfo(
+                            toleranceOverride({torch.chalf: tol(atol=5e-3, rtol=0)}),
+                            'TestDecomp', 'test_quick', device_type='cpu'),
+                    ),
+                    skips=(
+                        DecorateInfo(unittest.skip("Skipped!"),
+                                     'TestBinaryUfuncs',
+                                     'test_reference_numerics',
+                                     dtypes=(torch.uint8,)),
+                        DecorateInfo(unittest.skip("Skipped!"),
+                                     'TestBinaryUfuncs',
+                                     'test_reference_numerics_small_values',
+                                     dtypes=(torch.uint8,)),
+                    )),
+    OpInfo('addmm',
+           # This addmm OpInfo is for when alpha and beta are not both equal to 1.
+           # alpha=beta=1 is tested in the following opinfo, because that special case will
+           # trigger addmm being decomposed by a jit pass.
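+           # For reference (illustrative): addmm(M, mat1, mat2, beta=b, alpha=a) computes
+           # b * M + a * (mat1 @ mat2); the 'decomposed' variant below exercises the a == b == 1 case.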
+           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+           dtypesIfROCM=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+           assert_autodiffed=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+           sample_inputs_func=sample_inputs_addmm,
+           skips=(
+               # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
+               DecorateInfo(
+                   unittest.skip("Skipped!"),
+                   'TestSchemaCheckModeOpInfo',
+                   'test_schema_correctness',
+                   dtypes=(torch.complex64, torch.complex128)),
+               DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=2e-3)}),
+                            "TestConsistency", "test_output_grad_match", device_type="mps"),
+           )),
+    OpInfo('addmm',
+           # When alpha=beta=1 as compile-time constants, JIT will decompose addmm into mm and add.
+           variant_test_name='decomposed',
+           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+           assert_autodiffed=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+           autodiff_nonfusible_nodes=['aten::add', 'aten::mm'],
+           sample_inputs_func=partial(sample_inputs_addmm, alpha=1, beta=1),
+           skips=(
+               # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
+               DecorateInfo(
+                   unittest.skip("Skipped!"),
+                   'TestSchemaCheckModeOpInfo',
+                   'test_schema_correctness',
+                   dtypes=(torch.complex64, torch.complex128)),
+               # https://github.com/pytorch/pytorch/issues/71784
+               DecorateInfo(unittest.skip('Skipped!'), 'TestNNCOpInfo', 'test_nnc_correctness',
+                            device_type='cpu', dtypes=(torch.float16,)),
+           )),
+    OpInfo('addmv',
+           dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16),
+           dtypesIfCUDA=floating_types_and(torch.float16, torch.complex64, torch.complex128,
+                                           torch.bfloat16),
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           decorators=[
+               DecorateInfo(
+                   toleranceOverride({torch.half: tol(atol=1e-5, rtol=3e-3)}),
+                   'TestInductorOpInfo', 'test_comprehensive', device_type='cpu'),
+               DecorateInfo(toleranceOverride({torch.float32: tol(atol=2e-5, rtol=3e-6)}),
+                            "TestConsistency", "test_output_match", device_type="mps"),
+               DecorateInfo(toleranceOverride({torch.float32: tol(atol=2e-5, rtol=3e-6)}),
+                            "TestConsistency", "test_output_grad_match", device_type="mps"),
+           ],
+           sample_inputs_func=sample_inputs_addmv),
+    OpInfo('addbmm',
+           ref=lambda M, batch1, batch2, beta=1, alpha=1: np.add(np.multiply(np.asarray(beta, dtype=M.dtype), M),
+                                                                 np.multiply(np.asarray(alpha, dtype=batch1.dtype),
+                                                                             np.sum(np.matmul(batch1, batch2), axis=0))),
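+           # Illustrative: the reference above mirrors addbmm(M, batch1, batch2, beta=b, alpha=a),
+           # i.e. b * M + a * sum_i(batch1[i] @ batch2[i]), with the batched matmul reduced over the batch dim.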
+           dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
+                                                       *[torch.bfloat16]
+                                                       if SM53OrLater or TEST_WITH_ROCM else []),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+           gradcheck_fast_mode=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           decorators=[
+               DecorateInfo(
+                   toleranceOverride({torch.float32: tol(atol=1.3e-05, rtol=1.3e-05),
+                                      torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}),
+                   'TestCommon', 'test_numpy_refs'),
+               # MPS has slightly worse precision. Is this acceptable?
+               DecorateInfo(
+                   toleranceOverride({torch.float32: tol(atol=1.3e-04, rtol=1.3e-04),
+                                      torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}),
+                   'TestCommon', 'test_numpy_ref_mps'),
+               DecorateInfo(
+                   toleranceOverride({torch.float32: tol(atol=1e-5, rtol=1e-5),
+                                      torch.bfloat16: tol(atol=2e-1, rtol=6e-1)}),
+                   'TestConsistency',
+                   'test_output_match',
+               ),
+               DecorateInfo(
+                   toleranceOverride({torch.float32: tol(atol=1.5e-05, rtol=1e-05)}),
+                   'TestCommon', 'test_out'),
+               DecorateInfo(
+                   toleranceOverride({torch.half: tol(atol=6e-3, rtol=1e-2)}),
+                   'TestInductorOpInfo', 'test_comprehensive', device_type='cpu'),
+           ],
+           skips=(
+               # NVIDIA only guarantees that bfloat16 is supported by bmm if SM >= 5.3
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', device_type='cuda', active_if=not SM53OrLater),
+               # addbmm does not correctly warn when resizing out= inputs
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
+               # https://github.com/pytorch/pytorch/issues/55907
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'),
+           ),
+           sample_inputs_func=sample_inputs_addbmm),
+    OpInfo('baddbmm',
+           dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16),
+           dtypesIfCUDA=floating_types_and(torch.float16, torch.complex64, torch.complex128,
+                                           torch.bfloat16),
+           backward_dtypesIfCUDA=floating_types_and(torch.float16,
+                                                    *[torch.bfloat16] if SM53OrLater or TEST_WITH_ROCM else [],
+                                                    torch.complex64, torch.complex128),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+           gradcheck_fast_mode=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           decorators=[
+               DecorateInfo(
+                   toleranceOverride({torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}),
+                   'TestCommon', 'test_variant_consistency_eager', device_type='cuda'),
+               DecorateInfo(
+                   toleranceOverride({torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}),
+                   'TestMathBits', 'test_conj_view', device_type='cuda'),
+           ],
+           sample_inputs_func=sample_inputs_baddbmm,
+           skips=(
+               # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
+               DecorateInfo(
+                   unittest.skip("Skipped!"),
+                   'TestSchemaCheckModeOpInfo',
+                   'test_schema_correctness',
+                   dtypes=(torch.complex64, torch.complex128)),
+           )),
+    OpInfo('dot',
+           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+           assert_autodiffed=True,
+           sample_inputs_func=sample_inputs_dot_vdot,
+           error_inputs_func=error_inputs_dot_vdot,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           skips=(
+               # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
+               DecorateInfo(
+                   unittest.skip("Skipped!"),
+                   'TestSchemaCheckModeOpInfo',
+                   'test_schema_correctness',
+                   dtypes=(torch.complex64, torch.complex128)),
+           )),
+    OpInfo('vdot',
+           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+           sample_inputs_func=sample_inputs_dot_vdot,
+           error_inputs_func=error_inputs_dot_vdot,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           skips=(
+               # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
+               DecorateInfo(
+                   unittest.skip("Skipped!"),
+                   'TestSchemaCheckModeOpInfo',
+                   'test_schema_correctness',
+                   dtypes=(torch.complex64, torch.complex128)),
+           )),
+    OpInfo('bmm',
+           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
+                                                       *[torch.bfloat16]
+                                                       if SM53OrLater or TEST_WITH_ROCM else []),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+           assert_autodiffed=True,
+           assert_jit_shape_analysis=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           skips=(
+               # NVIDIA only guarantees that bfloat16 is supported by bmm if SM >= 5.3
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', device_type='cuda', active_if=not SM53OrLater),
+               DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-5, rtol=1e-5)}),
+                            "TestCommon", "test_out"),
+               # Possibly fast math on macOS 13?
+               DecorateInfo(
+                   toleranceOverride({torch.float32: tol(atol=2e-5, rtol=5e-6)}),
+                   'TestConsistency',
+                   'test_output_match',
+                   active_if=lambda _: MACOS_VERSION < 14.0,
+                   device_type='mps',
+                   dtypes=(torch.float32,)),
+           ),
+           sample_inputs_func=sample_inputs_bmm),
+    OpInfo('mv',
+           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+           assert_autodiffed=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           sample_inputs_func=sample_inputs_mv),
+    OpInfo('addr',
+           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
+           # Reference: https://github.com/pytorch/pytorch/issues/50747
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           skips=(
+               # Reference: https://github.com/pytorch/pytorch/issues/50747
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager',
+                            dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16)),
+           ),
+           sample_inputs_func=sample_inputs_addr,
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL),
+    OpInfo('addcmul',
+           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+           assert_autodiffed=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           skips=(
+               # TODO: update sample inputs with for_inplace_variant kwarg to support this test
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
+           ),
+           sample_inputs_func=sample_inputs_addcmul_addcdiv,
+           reference_inputs_func=partial(
+               reference_inputs_elementwise_ternary, sample_inputs_func=reference_inputs_addcmul_addcdiv)),
+    OpInfo('addcdiv',
+           dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           skips=(
+               # TODO: update sample inputs with for_inplace_variant kwarg to support this test
+               DecorateInfo(unittest.expectedFailure,
+                            'TestCommon',
+                            'test_variant_consistency_eager'),
+           ),
+           sample_inputs_func=sample_inputs_addcmul_addcdiv,
+           reference_inputs_func=partial(
+               reference_inputs_elementwise_ternary, sample_inputs_func=reference_inputs_addcmul_addcdiv)),
+    UnaryUfuncInfo('asin',
+                   aliases=('arcsin', ),
+                   ref=np.arcsin,
+                   domain=(-1, 1),
+                   supports_sparse=True,
+                   supports_sparse_csr=True,
+                   supports_sparse_csc=True,
+                   supports_sparse_bsr=True,
+                   supports_sparse_bsc=True,
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   promotes_int_to_float=True,
+                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                   dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
+                   dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+                   assert_autodiffed=True,
+                   decorators=[
+                       DecorateInfo(
+                           toleranceOverride({torch.float16: tol(atol=1e-05, rtol=1e-03)}),
+                           'TestUnaryUfuncs', device_type='cuda'
+                       ),
+                       DecorateInfo(
+                           toleranceOverride({torch.float32: tol(atol=8e-5, rtol=4e-5)}),
+                           'TestInductorOpInfo', 'test_comprehensive', device_type='cuda'
+                       ),
+                       DecorateInfo(
+                           toleranceOverride({torch.complex64: tol(atol=5e-05, rtol=2e-05)}),
+                           'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cpu'
+                       ),
+                       precisionOverride({torch.bfloat16: 1e-2}),
+                   ],
+                   skips=(
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
+                                    device_type='cuda', dtypes=[torch.cdouble],
+                                    active_if=IS_WINDOWS),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
+                                    device_type='cuda', dtypes=[torch.cdouble],
+                                    active_if=IS_WINDOWS),
+                       DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
+                                    'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
+                   )),
+    # NOTE: derivative for inplace asinh is not implemented
+    UnaryUfuncInfo('asinh',
+                   aliases=('arcsinh', ),
+                   ref=np.arcsinh,
+                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                   dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
+                   dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+                   decorators=(precisionOverride({torch.bfloat16: 5e-2}),),
+                   supports_inplace_autograd=False,
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   supports_sparse=True,
+                   supports_sparse_csr=True,
+                   supports_sparse_csc=True,
+                   supports_sparse_bsr=True,
+                   supports_sparse_bsc=True,
+                   promotes_int_to_float=True,
+                   skips=(
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
+                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
+                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
+                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal',
+                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
+                                    device_type='cuda', dtypes=[torch.cdouble],
+                                    active_if=IS_WINDOWS),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
+                                    device_type='cuda', dtypes=[torch.cdouble],
+                                    active_if=IS_WINDOWS),
+                       DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
+                                    'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
+                   )),
+    UnaryUfuncInfo('atan',
+                   aliases=('arctan', ),
+                   ref=np.arctan,
+                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                   dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
+                   dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+                   assert_autodiffed=True,
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   supports_sparse=True,
+                   supports_sparse_csr=True,
+                   supports_sparse_csc=True,
+                   supports_sparse_bsr=True,
+                   supports_sparse_bsc=True,
+                   promotes_int_to_float=True,
+                   decorators=(precisionOverride({torch.bfloat16: 1e-2}),),
+                   skips=(
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
+                                    dtypes=[torch.cfloat, torch.cdouble], active_if=(IS_MACOS or IS_WINDOWS)),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
+                                    dtypes=[torch.cfloat, torch.cdouble], active_if=(IS_MACOS or IS_WINDOWS)),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
+                                    dtypes=[torch.cfloat, torch.cdouble], active_if=(IS_MACOS or IS_WINDOWS)),
+                       DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
+                                    'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
+                   )),
+    BinaryUfuncInfo('atan2',
+                    aliases=('arctan2',),
+                    dtypes=all_types_and(torch.bool, torch.bfloat16, torch.half),
+                    dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+                    supports_forward_ad=True,
+                    supports_fwgrad_bwgrad=True,
+                    promotes_int_to_float=True,
+                    supports_rhs_python_scalar=False,
+                    skips=(
+                        # Incorrectly attempts to use a scalar for the second argument
+                        DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_jit_alias_remapping'),
+                    )),
+    UnaryUfuncInfo('atanh',
+                   aliases=('arctanh', ),
+                   ref=np.arctanh,
+                   domain=(-1, 1),
+                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                   dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
+                   dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+                   decorators=[
+                       precisionOverride({torch.bfloat16: 1e-2}),
+                       DecorateInfo(
+                           toleranceOverride({torch.float32: tol(atol=9e-3, rtol=8e-5)}),
+                           "TestInductorOpInfo",
+                           "test_comprehensive",
+                           device_type="cuda"
+                       ),
+                       DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=2e-3)}),
+                                    "TestConsistency", "test_output_grad_match", device_type="mps"),
+                   ],
+                   supports_inplace_autograd=False,
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   supports_sparse=True,
+                   supports_sparse_csr=True,
+                   supports_sparse_csc=True,
+                   supports_sparse_bsr=True,
+                   supports_sparse_bsc=True,
+                   promotes_int_to_float=True,
+                   skips=(
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
+                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
+                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
+                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
+                                    device_type='cuda', dtypes=[torch.cfloat, torch.cdouble],
+                                    active_if=IS_WINDOWS),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
+                                    device_type='cuda', dtypes=[torch.cfloat],
+                                    active_if=IS_WINDOWS),
+                       DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
+                                    'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
+                   )),
+    OpInfo('allclose',
+           dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+           ref=np.allclose,
+           supports_autograd=False,
+           supports_forward_ad=False,
+           sample_inputs_func=sample_inputs_allclose,
+           skips=(
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo'),
+           ),
+           supports_out=False),
+    OpInfo('broadcast_to',
+           ref=np.broadcast_to,
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # See https://github.com/pytorch/pytorch/pull/78358
+           check_batched_forward_grad=False,
+           sample_inputs_func=sample_inputs_broadcast_to),
+    OpInfo('broadcast_shapes',
+           op=torch.broadcast_shapes,
+           ref=np.broadcast_shapes if np.lib.NumpyVersion(np.__version__) >= '1.20.0' else None,
+           dtypes=_dispatch_dtypes((torch.float32,)),
+           supports_out=False,
+           supports_gradgrad=False,
+           assert_autodiffed=False,
+           supports_autograd=False,
+           supports_scripting=False,
+           sample_inputs_func=sample_inputs_broadcast_shapes,
+           skips=(
+               # https://github.com/pytorch/pytorch/issues/64997
+               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+               # skip dtype tests since broadcast_shapes is not device dependent.
+               # having dtypes limited to torch.float32 would cause test_dtypes to report an unexpected success
+               DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_dtypes'),
+               # skip these tests since we have non-tensor input
+               DecorateInfo(unittest.skip('Skipped!'), "TestCommon", "test_noncontiguous_samples"),
+               DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'),
+               DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'),
+           )),
+    OpInfo('broadcast_tensors',
+           ref=np.broadcast_arrays,
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+           sample_inputs_func=sample_inputs_broadcast_tensors,
+           reference_inputs_func=reference_inputs_broadcast_tensors,
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # See https://github.com/pytorch/pytorch/pull/78358
+           check_batched_forward_grad=False,
+           skips=(
+               # https://github.com/pytorch/pytorch/issues/64997
+               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+               # JIT does not support variadic tensors.
+               # RuntimeError: input->type()->kind() == TypeKind::OptionalType
+               # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252,
+               # please report a bug to PyTorch.
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]),
+           )),
+    OpInfo('block_diag',
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # Default batching rule in core doesn't work for ops with TensorList args
+           check_batched_forward_grad=False,
+           skips=(
+               # https://github.com/pytorch/pytorch/issues/64997
+               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+               # JIT does not support variadic tensors.
+               # RuntimeError: input->type()->kind() == TypeKind::OptionalType
+               # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252,
+               # please report a bug to PyTorch.
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]),
+           ),
+           sample_inputs_func=sample_inputs_block_diag),
+    UnaryUfuncInfo('bitwise_not',
+                   ref=np.bitwise_not,
+                   dtypes=integral_types_and(torch.bool),
+                   dtypesIfHpu=custom_types(torch.bool),
+                   operator_variant=operator.invert,
+                   supports_autograd=False),
+    BinaryUfuncInfo('bitwise_left_shift',
+                    op=torch.bitwise_left_shift,
+                    dtypes=integral_types(),
+                    dtypesIfCUDA=integral_types(),
+                    dtypesIfHpu=custom_types(torch.int32, torch.int8, torch.bool),
+                    operator_variant=operator.lshift,
+                    inplace_operator_variant=operator.ilshift,
+                    supports_autograd=False,
+                    supports_one_python_scalar=True,
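+                    # shifting by a negative amount is undefined, so only non-negative rhs values are sampled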
+                    rhs_make_tensor_kwargs=dict(low=0),
+                    skips=(
+                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_type_promotion'),
+                        # https://github.com/pytorch/pytorch/issues/70904
+                        DecorateInfo(unittest.skip("Some inputs produce undefined outputs"), 'TestCommon', 'test_compare_cpu'),
+                    )),
+    BinaryUfuncInfo('bitwise_right_shift',
+                    op=torch.bitwise_right_shift,
+                    dtypes=integral_types(),
+                    dtypesIfCUDA=integral_types(),
+                    dtypesIfHpu=custom_types(torch.int32, torch.int8, torch.bool),
+                    operator_variant=operator.rshift,
+                    inplace_operator_variant=operator.irshift,
+                    supports_autograd=False,
+                    supports_one_python_scalar=True,
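+                    # shifting by a negative amount is undefined, so only non-negative rhs values are sampled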
+                    rhs_make_tensor_kwargs=dict(low=0),
+                    skips=(
+                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_type_promotion'),
+                        # https://github.com/pytorch/pytorch/issues/70904
+                        DecorateInfo(unittest.skip("Some inputs produce undefined outputs"), 'TestCommon', 'test_compare_cpu'),
+                    )),
+    OpInfo('combinations',
+           op=torch.combinations,
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # See https://github.com/pytorch/pytorch/pull/78358
+           check_batched_forward_grad=False,
+           supports_out=False,
+           sample_inputs_func=sample_inputs_combinations),
+    OpInfo('cartesian_prod',
+           op=torch.cartesian_prod,
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # See https://github.com/pytorch/pytorch/pull/78358
+           check_batched_forward_grad=False,
+           sample_inputs_func=sample_inputs_cartesian_prod,
+           skips=(
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
+               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+               # RuntimeError: input->type()->kind() == TypeKind::OptionalType
+               # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270
+               DecorateInfo(unittest.expectedFailure,
+                            'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
+           )),
+    OpInfo('cdist',
+           dtypes=floating_types(),
+           supports_out=False,
+           supports_gradgrad=False,
+           assert_autodiffed=False,
+           sample_inputs_func=sample_inputs_cdist),
+    UnaryUfuncInfo('ceil',
+                   ref=np.ceil,
+                   dtypes=all_types_and(torch.half, torch.bfloat16),
+                   dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8),
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   skips=(
+                       DecorateInfo(unittest.expectedFailure,
+                                    'TestNNCOpInfo',
+                                    'test_nnc_correctness',
+                                    dtypes=tuple(t for t in integral_types() if t != torch.uint8)),
+                   ),
+                   supports_sparse=True,
+                   supports_sparse_csr=True,
+                   supports_sparse_csc=True,
+                   supports_sparse_bsr=True,
+                   supports_sparse_bsc=True,
+                   assert_autodiffed=True),
+    OpInfo('cholesky',
+           dtypes=floating_and_complex_types(),
+           sample_inputs_func=sample_inputs_linalg_cholesky,
+           gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
+           decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],),
+    OpInfo('cholesky_inverse',
+           dtypes=floating_and_complex_types(),
+           backward_dtypes=floating_and_complex_types(),
+           # https://github.com/pytorch/pytorch/issues/80411
+           gradcheck_fast_mode=True,
+           supports_fwgrad_bwgrad=True,
+           supports_forward_ad=True,
+           check_batched_gradgrad=True,
+           sample_inputs_func=sample_inputs_linalg_cholesky_inverse,
+           gradcheck_wrapper=gradcheck_wrapper_triangular_input_real_positive_diagonal,
+           decorators=[
+               skipCUDAIfNoMagma,
+               skipCPUIfNoLapack,
+               DecorateInfo(
+                   toleranceOverride({
+                       torch.float32: tol(atol=5e-03, rtol=1e-04)
+                   }),
+                   'TestCommon', device_type='cpu',
+               ),
+               DecorateInfo(
+                   toleranceOverride({
+                       torch.float32: tol(atol=5e-03, rtol=1e-04)
+                   }),
+                   'TestEagerFusionOpInfo', device_type='cpu',
+               ),
+           ],
+           skips=(
+               # Strides are not the same! Original strides were ((4, 2, 1),) and strides are now ((4, 1, 2),)
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),),
+           ),
+    OpInfo('cholesky_solve',
+           op=torch.cholesky_solve,
+           dtypes=floating_and_complex_types(),
+           sample_inputs_func=sample_inputs_cholesky_solve,
+           check_batched_gradgrad=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
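+           # idx=1: the triangular-input constraint applies to the Cholesky factor (the second argument)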
+           gradcheck_wrapper=lambda *args, **kwargs: gradcheck_wrapper_triangular_input(*args, idx=1, **kwargs),
+           decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack]),
+    OpInfo('chunk',
+           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32),
+           sample_inputs_func=sample_inputs_chunk,
+           reference_inputs_func=reference_inputs_chunk,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           supports_out=False),
+    OpInfo('unsafe_chunk',
+           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
+           sample_inputs_func=sample_inputs_chunk,
+           check_batched_forward_grad=False,
+           reference_inputs_func=reference_inputs_chunk,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           supports_out=False),
+    OpInfo('clone',
+           ref=np.copy,
+           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool),
+           sample_inputs_func=sample_inputs_clone_contiguous,
+           reference_inputs_func=reference_inputs_clone_contiguous,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           supports_out=False,
+           skips=(
+               # TypeError: _copy_dispatcher() got an unexpected keyword argument 'memory_format'
+               # (NumPy reference needs to be extended with memory_format)
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref'),
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref_mps'),
+           ),),
+    OpInfo('contiguous',
+           op=lambda x, *args, **kwargs: x.contiguous(*args, **kwargs),
+           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
+           sample_inputs_func=sample_inputs_clone_contiguous,
+           reference_inputs_func=reference_inputs_clone_contiguous,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           autodiff_fusible_nodes=['aten::contiguous'],
+           assert_jit_shape_analysis=True,
+           supports_out=False,
+           skips=(
+               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+           )),
+    OpInfo('sum_to_size',
+           op=lambda x, *args, **kwargs: x.sum_to_size(*args, **kwargs),
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+           sample_inputs_func=sample_inputs_sum_to_size,
+           error_inputs_func=error_inputs_sum_to_size,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           supports_out=False,
+           skips=(
+               # lambda impl
+               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float,)),
+           )),
+    OpInfo('clamp',
+           aliases=('clip',),
+           ref=_clamp_numpy,
+           dtypes=all_types_and(torch.bfloat16, torch.half),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool),
+           sample_inputs_func=sample_inputs_clamp,
+           reference_inputs_func=partial(reference_inputs_elementwise_ternary, sample_inputs_func=sample_inputs_clamp),
+           assert_autodiffed=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           skips=(
+               # NNC appears not to handle boolean clamp
+               DecorateInfo(unittest.expectedFailure,
+                            'TestNNCOpInfo',
+                            'test_nnc_correctness',
+                            dtypes=(torch.bool,)),
+               # MPS does not support float64, while NumPy performs its internal computations in float64.
+               # See https://github.com/pytorch/pytorch/blob/3c1cf03fde145bdbe1f5ffb81765d076c10b4c04/test/test_ops.py#L260-L264
+               DecorateInfo(unittest.expectedFailure,
+                            'TestCommon',
+                            'test_numpy_ref_mps'),
+           )),
+    UnaryUfuncInfo('positive',
+                   ref=np.positive,
+                   dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.chalf),
+                   supports_out=False,
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   supports_sparse=True,
+                   supports_sparse_csr=True,
+                   supports_sparse_csc=True,
+                   supports_sparse_bsr=True,
+                   supports_sparse_bsc=True,
+                   ),
+    UnaryUfuncInfo('conj',
+                   ref=np.conj,
+                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16,
+                                                    torch.half, torch.chalf),
+                   dtypesIfHpu=custom_types(torch.float32, torch.int32),
+                   supports_sparse=True,
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   # See https://github.com/pytorch/pytorch/pull/78358
+                   check_batched_forward_grad=False,
+                   supports_out=False),
+    UnaryUfuncInfo('conj_physical',
+                   decomp_aten_name='_conj_physical',
+                   ref=np.conj,
+                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16,
+                                                    torch.half, torch.chalf),
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   supports_sparse=True,
+                   supports_sparse_csr=True,
+                   supports_sparse_csc=True,
+                   supports_sparse_bsr=True,
+                   supports_sparse_bsc=True,
+                   skips=(
+                       # RuntimeError: inputSet && outputSet
+                       # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":118,
+                       # please report a bug to PyTorch.
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32, )),
+                       DecorateInfo(unittest.skip("Skipped! conj_physical_ not implemented for sparse"),
+                                    'TestSparseUnaryUfuncs', 'test_inplace'),
+                   )),
+    OpInfo('resolve_conj',
+           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+           sample_inputs_func=sample_inputs_view_as_real,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           supports_out=False,
+           ),
+    OpInfo('resolve_neg',
+           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+           sample_inputs_func=sample_inputs_view_as_real,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           supports_out=False,
+           ),
+    OpInfo('view_as_real',
+           dtypes=complex_types(),
+           supports_forward_ad=True,
+           supports_out=False,
+           supports_fwgrad_bwgrad=True,
+           sample_inputs_func=sample_inputs_view_as_real,
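+           # view_as_real does not accept lazily conjugated inputs, so conjugated samples are skipped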
+           test_conjugated_samples=False,
+           ),
+    OpInfo('view_as_complex',
+           dtypes=floating_types_and(torch.half),
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           test_neg_view=False,
+           sample_inputs_func=sample_inputs_view_as_complex,
+           skips=(
+               # RuntimeError: Tensor must have a last dimension with stride 1
+               DecorateInfo(unittest.expectedFailure, "TestCommon", "test_noncontiguous_samples"),
+               # RuntimeError: "eq_cpu" not implemented for 'ComplexHalf'
+               DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness', dtypes=(torch.half,)),
+               # RuntimeError: view size is not compatible with input tensor's size and stride
+               DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
+           )),
+    BinaryUfuncInfo('complex',
+                    dtypes=floating_types_and(torch.half),
+                    supports_forward_ad=True,
+                    supports_fwgrad_bwgrad=True,
+                    supports_rhs_python_scalar=False,
+                    error_inputs_func=error_inputs_complex,
+                    skips=(
+                        # Tests don't account for complex's type promotion semantics
+                        DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'),
+                        DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out', device_type='mps'),
+                        DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'),)),
+    BinaryUfuncInfo('copysign',
+                    sample_inputs_func=sample_inputs_copysign,
+                    dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
+                    dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+                    promotes_int_to_float=True,
+                    # https://github.com/pytorch/pytorch/issues/80411
+                    gradcheck_fast_mode=True,
+                    supports_forward_ad=True,
+                    supports_fwgrad_bwgrad=True),
+    OpInfo('corrcoef',
+           dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
+           sample_inputs_func=sample_inputs_corrcoef,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # See https://github.com/pytorch/pytorch/pull/78358
+           check_batched_forward_grad=False,
+           skips=(
+               # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
+               DecorateInfo(
+                   unittest.skip("Skipped!"),
+                   'TestSchemaCheckModeOpInfo',
+                   'test_schema_correctness',
+                   dtypes=(torch.complex64, torch.complex128)),
+           ),
+           supports_out=False),
+    UnaryUfuncInfo('cos',
+                   ref=np.cos,
+                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                   dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
+                   dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+                   assert_autodiffed=True,
+                   handles_large_floats=False,
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   promotes_int_to_float=True,
+                   decorators=(precisionOverride({torch.bfloat16: 1e-2}),),
+                   skips=(
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
+                                    dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu', active_if=IS_WINDOWS),
+                       # This fails on CUDA but passes on ROCm
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
+                                    dtypes=(torch.cdouble,), device_type='cuda'),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
+                                    dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
+                                    device_type='cpu',
+                                    dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS),
+                       # AssertionError: Tensor-likes are not close!
+                       # Greatest absolute difference: nan at index (700,) (up to 1e-05 allowed)
+                       # Greatest relative difference: nan at index (700,) (up to 0.001 allowed)
+                       DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_large',
+                                    device_type='cuda',
+                                    dtypes=(torch.chalf,), active_if=IS_WINDOWS),
+                   )),
+    UnaryUfuncInfo('cosh',
+                   ref=np_unary_ufunc_integer_promotion_wrapper(np.cosh),
+                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                   dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
+                   dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+                   assert_autodiffed=True,
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   promotes_int_to_float=True,
+                   skips=(
+                       # Reference: https://github.com/pytorch/pytorch/issues/48641
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
+                                    device_type='cpu', dtypes=[torch.int8]),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
+                                    dtypes=[torch.cdouble]),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
+                                    dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
+                                    dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
+                                    device_type='cpu',
+                                    dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
+                                    device_type='cpu',
+                                    dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS),
+                       # AssertionError: Tensor-likes are not close!
+                       # Greatest absolute difference: nan at index (6000,) (up to 1e-05 allowed)
+                       # Greatest relative difference: nan at index (6000,) (up to 0.001 allowed)
+                       DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_large',
+                                    device_type='cuda',
+                                    dtypes=(torch.chalf,), active_if=IS_WINDOWS),
+                   )),
+    OpInfo('cov',
+           dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
+           sample_inputs_func=sample_inputs_cov,
+           error_inputs_func=error_inputs_cov,
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           skips=(
+               # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
+               DecorateInfo(
+                   unittest.skip("Skipped!"),
+                   'TestSchemaCheckModeOpInfo',
+                   'test_schema_correctness',
+                   dtypes=(torch.complex64, torch.complex128)),
+               # Float did not match double
+               DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_grad'),
+               # Jacobian mismatch
+               DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_gradgrad'),
+               DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD'),
+               DecorateInfo(unittest.skip("Barely fails"), 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'),
+               # JIT test not working for tensor kwargs (https://github.com/pytorch/pytorch/issues/58507)
+               # RuntimeError:
+               # undefined value tensor:
+               #   File "", line 3
+               # def the_method(i0):
+               #     return torch.cov(i0, correction=0, fweights=None, aweights=tensor([0.0518, 0.4681], dtype=torch.float32, requires_grad=True)) # noqa: B950
+               #                                                                ~~~~~~ <--- HERE
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+               DecorateInfo(toleranceOverride({torch.float16: tol(atol=8e-3, rtol=1.4e-3)}),
+                            "TestInductorOpInfo", "test_comprehensive", device_type="cpu"),
+               DecorateInfo(toleranceOverride({torch.float32: tol(atol=3e-4, rtol=1e-4)}),
+                            "TestConsistency", "test_output_grad_match", device_type="mps"),
+           )),
+    OpInfo('cross',
+           dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32),
+           sample_inputs_func=sample_inputs_cross,
+           supports_fwgrad_bwgrad=True,
+           supports_out=True,
+           supports_forward_ad=True),
+    OpInfo('cumsum',
+           dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32),
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           skips=(
+               # cumsum does not correctly handle out= dtypes
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
+           ),
+           sample_inputs_func=sample_inputs_cumulative_ops),
+    OpInfo('cumprod',
+           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           skips=(
+               # cumprod does not correctly handle out= dtypes
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
+           ),
+           # gradgradcheck fails in fast_mode=True: #56275
+           sample_inputs_func=sample_inputs_cumprod,
+           gradcheck_fast_mode=False),
+    OpInfo('cummax',
+           dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
+           sample_inputs_func=partial(sample_inputs_cumulative_ops, supports_dtype_kwargs=False),
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           skips=(
+           ),
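+           # gradients can be nondeterministic when the running maximum is attained at multiple indices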
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL),
+    OpInfo('cummin',
+           dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
+           sample_inputs_func=partial(sample_inputs_cumulative_ops, supports_dtype_kwargs=False),
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           skips=(
+           ),
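+           # gradients can be nondeterministic when the running minimum is attained at multiple indices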
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL),
+    UnaryUfuncInfo('deg2rad',
+                   ref=np.radians,
+                   decorators=(precisionOverride({torch.bfloat16: 7e-1,
+                                                  torch.float16: 7e-1}),),
+                   dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   supports_sparse=True,
+                   supports_sparse_csr=True,
+                   supports_sparse_csc=True,
+                   supports_sparse_bsr=True,
+                   supports_sparse_bsc=True,
+                   promotes_int_to_float=True),
+    OpInfo('diff',
+           op=torch.diff,
+           # np.diff uses np._NoValue as the default for prepend and append; compare_with_reference breaks
+           # if prepend/append are passed as None when converting to numpy
+           ref=lambda input, n=1, dim=-1, prepend=np._NoValue, append=np._NoValue: (
+               np.diff(input, n, dim, np._NoValue if prepend is None else prepend, np._NoValue if append is None else append)
+           ),
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+           gradcheck_fast_mode=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           sample_inputs_func=sample_inputs_diff,
+           error_inputs_func=error_inputs_diff,
+           # See https://github.com/pytorch/pytorch/pull/78358
+           check_batched_forward_grad=False,
+           skips=(
+           )),
+    BinaryUfuncInfo('div',
+                    aliases=('divide',),
+                    variant_test_name='no_rounding_mode',
+                    dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                    dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
+                    dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8),
+                    # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+                    gradcheck_fast_mode=True,
+                    supports_forward_ad=True,
+                    promotes_int_to_float=True,
+                    supports_fwgrad_bwgrad=True,
+                    supports_two_python_scalars=True,
+                    assert_autodiffed=True,
+                    rhs_make_tensor_kwargs=dict(exclude_zero=True),),
+    BinaryUfuncInfo('div',
+                    aliases=('divide',),
+                    variant_test_name='trunc_rounding',
+                    dtypes=all_types_and(torch.half, torch.bfloat16),
+                    dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8),
+                    sample_kwargs=lambda device, dtype, input:
+                        ({"rounding_mode": "trunc"}, {"rounding_mode": "trunc"}),
+                    # https://github.com/pytorch/pytorch/issues/80411
+                    gradcheck_fast_mode=True,
+                    supports_forward_ad=True,
+                    supports_fwgrad_bwgrad=True,
+                    supports_two_python_scalars=True,
+                    assert_autodiffed=True,
+                    rhs_make_tensor_kwargs=dict(exclude_zero=True),
+                    decorators=(
+                        # See https://github.com/pytorch/pytorch/issues/111126
+                        DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'),
+                    ),
+                    skips=(
+                        # RuntimeError: MALFORMED INPUT: Unhandled node kind (in computeValue): aten::div
+                        DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_working'),
+                        # FIXME:
+                        # torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for
+                        # output 0 with respect to input 1,
+                        # numerical:tensor(-17746.9307, dtype=torch.float64)
+                        # analytical:tensor(0., dtype=torch.float64)
+                        DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients',
+                                     'test_fn_grad', device_type='cpu',
+                                     dtypes=(torch.float64,)),
+                    )),
+    BinaryUfuncInfo('div',
+                    aliases=('divide',),
+                    variant_test_name='floor_rounding',
+                    dtypes=all_types_and(torch.half, torch.bfloat16),
+                    dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8),
+                    sample_kwargs=lambda device, dtype, input:
+                        ({"rounding_mode": "floor"}, {"rounding_mode": "floor"}),
+                    # https://github.com/pytorch/pytorch/issues/80411
+                    gradcheck_fast_mode=True,
+                    supports_forward_ad=True,
+                    supports_fwgrad_bwgrad=True,
+                    supports_two_python_scalars=True,
+                    assert_autodiffed=True,
+                    rhs_make_tensor_kwargs=dict(exclude_zero=True),
+                    decorators=(
+                        # See https://github.com/pytorch/pytorch/issues/111126
+                        DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'),
+                    ),
+                    skips=(
+                        # RuntimeError: MALFORMED INPUT: Unhandled node kind (in computeValue): aten::div
+                        DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_working'),
+                        # FIXME:
+                        # torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for
+                        # output 0 with respect to input 1,
+                        # numerical:tensor(-17746.9307, dtype=torch.float64)
+                        # analytical:tensor(0., dtype=torch.float64)
+                        DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients',
+                                     'test_fn_grad',
+                                     dtypes=(torch.float64,),
+                                     device_type='cpu'),
+                        DecorateInfo(unittest.skip("Broken on macOS 13"),
+                                     'TestConsistency',
+                                     'test_output_match',
+                                     device_type='mps',
+                                     dtypes=(torch.float16,),
+                                     active_if=lambda _: MACOS_VERSION < 14.0),
+                    )),
+    BinaryUfuncInfo('true_divide',
+                    dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                    dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
+                    supports_forward_ad=True,
+                    promotes_int_to_float=True,
+                    supports_fwgrad_bwgrad=True,
+                    supports_two_python_scalars=True,
+                    rhs_make_tensor_kwargs=dict(exclude_zero=True)),
+    OpInfo('equal',
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool),
+           ref=lambda input, other: (input == other).all(),
+           sample_inputs_func=sample_inputs_equal,
+           supports_autograd=False,
+           supports_tracing=False,
+           skips=(
+           )),
+    UnaryUfuncInfo('exp',
+                   ref=np_unary_ufunc_integer_promotion_wrapper(np.exp),
+                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                   dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
+                   dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+                   skips=(
+                       # Reference: https://github.com/pytorch/pytorch/issues/48010
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
+                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS),
+                   ),
+                   assert_autodiffed=True,
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   promotes_int_to_float=True),
+    OpInfo('expand',
+           op=lambda self, shape: self.expand(shape),
+           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32),
+           sample_inputs_func=sample_inputs_expand,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           assert_jit_shape_analysis=True,
+           supports_out=False,
+           skips=(
+               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+           )),
+    OpInfo('expand_as',
+           op=lambda self, other: self.expand_as(other),
+           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool),
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           sample_inputs_func=sample_inputs_expand_as,
+           supports_out=False,
+           skips=(
+               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),),
+           ),
+    OpInfo('expand_copy',
+           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+           sample_inputs_func=sample_inputs_expand,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           assert_jit_shape_analysis=True,
+           supports_out=True,
+           skips=(
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
+           )),
+    OpInfo('diag',
+           ref=np.diag,
+           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
+           dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           check_batched_forward_grad=False,
+           sample_inputs_func=sample_inputs_diag,
+           error_inputs_func=error_inputs_diag),
+    OpInfo('diag_embed',
+           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
+           supports_out=False,
+           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+           gradcheck_fast_mode=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           sample_inputs_func=sample_inputs_diagonal_diag_embed,
+           reference_inputs_func=reference_inputs_diagonal_diag_embed,
+           error_inputs_func=error_inputs_diagonal_diag_embed),
+    OpInfo('diagonal',
+           aten_backward_name='diagonal_backward',
+           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           sample_inputs_func=sample_inputs_diagonal_diag_embed,
+           reference_inputs_func=reference_inputs_diagonal_diag_embed,
+           error_inputs_func=error_inputs_diagonal_diag_embed),
+    OpInfo('diagonal_copy',
+           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           sample_inputs_func=sample_inputs_diagonal_diag_embed,
+           reference_inputs_func=reference_inputs_diagonal_diag_embed,
+           error_inputs_func=error_inputs_diagonal_diag_embed),
+    OpInfo('diagonal_scatter',
+           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           sample_inputs_func=sample_inputs_diagonal_scatter),
+    OpInfo('alias_copy',
+           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
+           sample_inputs_func=sample_inputs_alias_copy,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           supports_out=True),
+    BinaryUfuncInfo('eq',
+                    ref=np.equal,
+                    dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
+                    dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool),
+                    always_returns_bool=True,
+                    supports_autograd=False,
+                    sample_inputs_func=sample_inputs_comparison_ops,
+                    skips=(
+                    )),
+    BinaryUfuncInfo('fmax',
+                    op=torch.fmax,
+                    dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
+                    dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool),
+                    supports_forward_ad=True,
+                    supports_fwgrad_bwgrad=True,
+                    supports_rhs_python_scalar=False,
+                    skips=(
+                        # RuntimeError: "max_elementwise_cuda" not implemented for 'ComplexFloat'
+                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_type_promotion'),
+                    )),
+    BinaryUfuncInfo('fmin',
+                    op=torch.fmin,
+                    dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
+                    dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool),
+                    supports_forward_ad=True,
+                    supports_fwgrad_bwgrad=True,
+                    supports_rhs_python_scalar=False,
+                    skips=(
+                        # RuntimeError: "min_elementwise_cuda" not implemented for 'ComplexFloat'
+                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_type_promotion'),
+                    )),
+    BinaryUfuncInfo('fmod',
+                    ref=np.fmod,
+                    dtypes=all_types_and(torch.float16, torch.bfloat16),
+                    dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
+                    dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32),
+                    # https://github.com/pytorch/pytorch/issues/80411
+                    gradcheck_fast_mode=True,
+                    supports_forward_ad=True,
+                    supports_fwgrad_bwgrad=True,
+                    assert_autodiffed=None,
+                    rhs_make_tensor_kwargs={'exclude_zero': True},
+                    decorators=(
+                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
+                                     'test_contig_vs_every_other',
+                                     dtypes=(torch.bfloat16,)),
+                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
+                                     'test_non_contig',
+                                     dtypes=(torch.bfloat16,)),
+                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
+                                     'test_reference_numerics',
+                                     dtypes=(torch.bfloat16,)),
+                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
+                                     'test_reference_numerics_small_values',
+                                     dtypes=(torch.uint8,)),
+                        # FIXME:
+                        # torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for
+                        # output 0 with respect to input 1,
+                        # numerical:tensor(101.6283, dtype=torch.float64)
+                        # analytical:tensor(-18.3575, dtype=torch.float64)
+                        DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients',
+                                     'test_fn_grad',
+                                     dtypes=(torch.float64,),
+                                     device_type='cpu'),
+                    )),
+    BinaryUfuncInfo('remainder',
+                    ref=np.remainder,
+                    dtypes=all_types_and(torch.float16, torch.bfloat16),
+                    dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
+                    dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool),
+                    # https://github.com/pytorch/pytorch/issues/80411
+                    gradcheck_fast_mode=True,
+                    supports_forward_ad=True,
+                    supports_fwgrad_bwgrad=True,
+                    assert_autodiffed=None,
+                    operator_variant=operator.mod,
+                    inplace_operator_variant=operator.imod,
+                    supports_one_python_scalar=True,
+                    rhs_make_tensor_kwargs={'exclude_zero': True},
+                    decorators=(
+                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
+                                     'test_contig_vs_every_other',
+                                     dtypes=(torch.bfloat16,)),
+                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
+                                     'test_non_contig',
+                                     dtypes=(torch.bfloat16,)),
+                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
+                                     'test_reference_numerics',
+                                     dtypes=(torch.bfloat16,)),
+                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
+                                     'test_reference_numerics_small_values',
+                                     dtypes=(torch.uint8,)),
+                        DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo',
+                                     'test_nnc_correctness',
+                                     dtypes=(torch.bfloat16,)),
+                        # Fails on XLA
+                        # False is not true : Tensors failed to compare as equal!
+                        # Attempted to compare equality of tensors with different dtypes
+                        DecorateInfo(unittest.skip("Skipped!"), 'TestOpInfo', device_type='xla', dtypes=(torch.long,)),
+                        # FIXME:
+                        # torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for
+                        # output 0 with respect to input 1,
+                        # numerical:tensor(102.4676, dtype=torch.float64)
+                        # analytical:tensor(-17.5182, dtype=torch.float64)
+                        DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients',
+                                     'test_fn_grad', device_type='cpu',
+                                     dtypes=(torch.float64,)),
+                        DecorateInfo(
+                            toleranceOverride({
+                                torch.float16: tol(atol=5e-4, rtol=3e-3),
+                            }),
+                            "TestInductorOpInfo",
+                            "test_comprehensive",
+                            device_type="cuda"
+                        ),
+                        DecorateInfo(unittest.skip("Broken on macOS 13"),
+                                     'TestConsistency',
+                                     'test_output_match',
+                                     device_type='mps',
+                                     dtypes=(torch.float16,),
+                                     active_if=lambda _: MACOS_VERSION < 14.0),
+                    )),
+    UnaryUfuncInfo('frac',
+                   ref=lambda x: np.modf(x)[0],
+                   dtypes=floating_types_and(torch.bfloat16, torch.float16),
+                   dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+                   dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+                   assert_autodiffed=True,
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   supports_sparse=True,
+                   supports_sparse_csr=True,
+                   supports_sparse_csc=True,
+                   supports_sparse_bsr=True,
+                   supports_sparse_bsc=True,
+                   skips=(
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
+                                    dtypes=(torch.bfloat16, torch.float16, torch.float32, torch.float64)),
+                       # 76047
+                       DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness',
+                                    dtypes=(torch.bfloat16, torch.float32, torch.float64)),
+                   )),
+    OpInfo('stft',
+           decorators=[
+               skipCPUIfNoFFT,
+               DecorateInfo(unittest.skip("Skipped! stft does not match the native function"),
+                            'TestJit', 'test_variant_consistency_jit'),
+           ],
+           dtypes=floating_and_complex_types(),
+           sample_inputs_func=sample_inputs_stft,
+           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+           gradcheck_fast_mode=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           check_batched_forward_grad=False,
+           check_batched_grad=False,
+           check_batched_gradgrad=False,
+           supports_out=False,
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+           ),
+    OpInfo('istft',
+           dtypes=complex_types(),
+           sample_inputs_func=sample_inputs_istft,
+           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+           gradcheck_fast_mode=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           check_batched_forward_grad=False,
+           check_batched_grad=False,
+           check_batched_gradgrad=False,
+           supports_out=False,
+           decorators=(
+               DecorateInfo(unittest.skip("Skipped! istft does not match the native function"),
+                            'TestJit', 'test_variant_consistency_jit'),
+           ),
+           skips=(
+               skipCPUIfNoFFT,
+               # gradcheck fails on ROCm (gh-68429)
+               # the gradient is computed incorrectly (probably for the weights tensor)
+               DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_grad'),
+               # Pre-existing condition (calls .item); needs to be fixed
+               DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'),
+           )),
+    UnaryUfuncInfo('floor',
+                   ref=np.floor,
+                   dtypes=all_types_and(torch.half, torch.bfloat16),
+                   dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8),
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   skips=(
+                       DecorateInfo(unittest.expectedFailure,
+                                    'TestNNCOpInfo',
+                                    'test_nnc_correctness',
+                                    dtypes=tuple(t for t in integral_types() if t != torch.uint8)),
+                   ),
+                   supports_sparse=True,
+                   supports_sparse_csr=True,
+                   supports_sparse_csc=True,
+                   supports_sparse_bsr=True,
+                   supports_sparse_bsc=True,
+                   assert_autodiffed=True),
+    OpInfo('flip',
+           op=torch.flip,
+           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool),
+           sample_inputs_func=sample_inputs_flip,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           supports_out=False),
+    OpInfo('fliplr',
+           op=torch.fliplr,
+           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+           sample_inputs_func=sample_inputs_fliplr_flipud,
+           error_inputs_func=error_inputs_fliplr,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           supports_out=False),
+    OpInfo('flipud',
+           op=torch.flipud,
+           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+           sample_inputs_func=sample_inputs_fliplr_flipud,
+           error_inputs_func=error_inputs_flipud,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           supports_out=False),
+    OpInfo('sparse.sampled_addmm',
+           dtypes=floating_and_complex_types(),
+           supports_autograd=True,
+           sample_inputs_func=sample_inputs_sparse_sampled_addmm,
+           decorators=[
+               skipCUDAIf(not ((_get_torch_cuda_version() >= (11, 3))
+                               or (_get_torch_rocm_version() >= (5, 2))),
+                          "cusparseSDDMM was added in 11.2.1"),
+               skipCPUIfNoMklSparse, ],
+           skips=(
+               # NotImplementedError: Tensors of type SparseCsrTensorImpl do not have is_contiguous
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'),
+               # RuntimeError: Sparse CSR tensors do not have strides.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestTags', 'test_tags'),
+               # RuntimeError: sampled_addmm: Expected result to have sparse csr layout, but got Strided
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out_warning'),
+               # RuntimeError: Sparse CSR tensors do not have strides
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'),
+               # RuntimeError: Sparse CSR tensors do not have strides
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_operator'),
+               # RuntimeError: Sparse CSR tensors do not have strides
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_backward'),
+               # RuntimeError: Sparse CSR tensors do not have strides
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'),
+               # RuntimeError: Sparse CSR tensors do not have strides
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'),
+               # RuntimeError: Sparse CSR tensors do not have strides
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
+               # RuntimeError: Sparse CSR tensors do not have strides
+               # RuntimeError: unsupported memory format option Preserve
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+               # RuntimeError: sparse_mask does not support automatic differentiation for outputs with complex dtype
+               # RuntimeError: Sparse CSR tensors do not have strides
+               DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'),
+               # ValueError: Sparse output is not supported at gradcheck yet. Please call to_dense(masked_grad=...) ...
+               DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad'),
+               # RuntimeError: sparse_mask does not support automatic differentiation for outputs with complex dtype.
+               # RuntimeError: Sparse CSR tensors do not have is_contiguous
+               DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad'),
+               # ValueError: Sparse output is not supported at gradcheck yet. Please call to_dense(masked_grad=...) ...
+               DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'),
+               # NotImplementedError: Could not run 'aten::sparse_sampled_addmm' with arguments from the 'SparseCsrMeta' backend.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_meta_outplace'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_meta_outplace'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake_crossref_backward_no_amp'),
+           )),
+    OpInfo('sparse.mm',
+           dtypes=floating_types_and(torch.bfloat16, torch.float16),
+           variant_test_name='reduce',
+           supports_autograd=True,
+           supports_out=False,
+           supports_gradgrad=False,
+           supports_forward_ad=False,
+           sample_inputs_func=sample_inputs_sparse_mm_reduce,
+           decorators=[onlyCPU],
+           skips=(
+               # NotImplementedError: Tensors of type SparseCsrTensorImpl do not have is_contiguous
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'),
+               # RuntimeError: Sparse CSR tensors do not have strides.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestTags', 'test_tags'),
+               # RuntimeError: Sparse CSR tensors do not have strides
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'),
+               # RuntimeError: Sparse CSR tensors do not have strides
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_operator'),
+               # RuntimeError: Sparse CSR tensors do not have strides
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_backward'),
+               # RuntimeError: Sparse CSR tensors do not have strides
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'),
+               # RuntimeError: Sparse CSR tensors do not have strides
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'),
+               # RuntimeError: Sparse CSR tensors do not have strides
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
+               # RuntimeError: Sparse CSR tensors do not have strides
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+               # RuntimeError: unsupported memory format option Preserve
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+               # ValueError: Sparse output is not supported at gradcheck yet. Please call to_dense(masked_grad=...) ...
+               DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'),
+               # RuntimeError: Sparse CSR tensors do not have is_contiguous
+               DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad'),
+               # ValueError: Sparse output is not supported at gradcheck yet. Please call to_dense(masked_grad=...) ...
+               DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad'),
+               # RuntimeError: Sparse CSR tensors do not have strides
+               DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'),
+               # ValueError: Sparse output is not supported at gradcheck yet. Please call to_dense(masked_grad=...) ...
+               DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_fail_gradgrad'),
+               # NotImplementedError: Could not run 'aten::_sparse_mm_reduce_impl' with arguments from the 'SparseCsrMeta' backend
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_meta_outplace'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_meta_outplace'),
+           )),
+    UnaryUfuncInfo('i0',
+                   ref=np_unary_ufunc_integer_promotion_wrapper(
+                       scipy.special.i0) if TEST_SCIPY else None,
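+                   # The wrapper above (assumed from its name) promotes integer inputs to
+                   # float for the scipy reference, matching promotes_int_to_float=True
+                   # below; e.g. an int input of 2 is compared against I0(2.0) ~= 2.2796.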
+                   aliases=('special.i0',),
+                   decorators=(precisionOverride({torch.bfloat16: 3e-1,
+                                                  torch.float16: 5e-1}),),
+                   dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   promotes_int_to_float=True,
+                   sample_inputs_func=sample_inputs_i0_i1,
+                   skips=(
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
+                                    dtypes=(torch.int8,)),
+                   )),
+    BinaryUfuncInfo('floor_divide',
+                    ref=_floor_divide_np,
+                    dtypes=all_types_and(torch.half, torch.bfloat16),
+                    dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool),
+
+                    supports_autograd=False,
+                    rhs_make_tensor_kwargs=dict(exclude_zero=True),
+                    supports_two_python_scalars=True,
+                    skips=(
+                        # AssertionError: Results of original model and exported/imported version of model differed
+                        DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'),
+                        # bfloat16 floor_divide compared with a float32 reference works inconsistently
+                        DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs',
+                                     dtypes=(torch.bfloat16,)),
+                        # int8 floor divide has different results for -128 // -1 vs. NumPy
+                        DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs', 'test_reference_numerics_small_values',
+                                     dtypes=(torch.int8,)),
+                        # The following test fails on some jobs
+                        DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs', 'test_reference_numerics_extremal_values',
+                                     dtypes=(torch.float16,)),
+                        DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=5e-3)}),
+                                     'TestBinaryUfuncs', 'test_reference_numerics'),
+                    )),
+    UnaryUfuncInfo('frexp',
+                   op=torch.frexp,
+                   ref=np.frexp,
+                   dtypes=floating_types_and(torch.half, torch.bfloat16),
+                   dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+                   # skip testing torch.frexp as it is not supported by the ROCm platform yet
+                   decorators=[],
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   skips=(
+                       # Skip the tests below because torch.frexp returns a tuple-like
+                       # (mantissa, exponent) output, while these tests currently require
+                       # the output to be a single tensor.
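+                       # For illustration: torch.frexp(torch.tensor([8.0])) returns
+                       # (tensor([0.5]), tensor([4])), since 8.0 == 0.5 * 2**4, so there is
+                       # no single output tensor for these tests to compare.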
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_batch_vs_slicing'),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_contig_vs_every_other'),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_contig_vs_transposed'),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_non_contig_expand'),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_variant_consistency'),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_out_arg_all_dtypes'),
+
+                       # Skip the test_reference_numerics variants due to an error in Windows CI:
+                       # np.frexp returns the exponent as np.intc on Windows,
+                       # and np.intc does not have a corresponding torch dtype.
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
+                                    active_if=IS_WINDOWS),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
+                                    active_if=IS_WINDOWS),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
+                                    active_if=IS_WINDOWS),
+                   )),
+    UnaryUfuncInfo('log1p',
+                   ref=np.log1p,
+                   aliases=('special.log1p',),
+                   domain=(-1, None),
+                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                   dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+                   decorators=(precisionOverride({torch.bfloat16: 1e-1}),),
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   supports_sparse=True,
+                   supports_sparse_csr=True,
+                   supports_sparse_csc=True,
+                   supports_sparse_bsr=True,
+                   supports_sparse_bsc=True,
+                   assert_autodiffed=True,
+                   promotes_int_to_float=True),
+    BinaryUfuncInfo('ge',
+                    ref=np.greater_equal,
+                    aliases=('greater_equal',),
+                    dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16),
+                    dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32),
+                    always_returns_bool=True,
+                    supports_autograd=False,
+                    skips=(
+                    )),
+    OpInfo('geqrf',
+           dtypes=floating_and_complex_types(),
+           sample_inputs_func=sample_inputs_linalg_qr_geqrf,
+           decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
+           supports_autograd=False,
+           skips=(
+               # FIXME: geqrf can't forward with complex inputs that require grad
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes'),
+               # Strides are not the same!
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
+           )),
+    BinaryUfuncInfo('gt',
+                    ref=np.greater,
+                    aliases=('greater',),
+                    dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16),
+                    dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8),
+                    always_returns_bool=True,
+                    supports_autograd=False,
+                    skips=(
+                    )),
+    UnaryUfuncInfo('imag',
+                   ref=np.imag,
+                   dtypes=complex_types_and(torch.chalf),
+                   supports_out=False,
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   # See https://github.com/pytorch/pytorch/issues/66357
+                   # RuntimeError: view_as_real doesn't work on unresolved conjugated tensors.
+                   check_batched_forward_grad=False,
+                   skips=(
+                       # Skip since real and imag don't have out variants.
+                       DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_out_arg_all_dtypes'),
+                   )),
+    OpInfo('gradient',
+           dtypes=floating_and_complex_types_and(torch.int8, torch.int16,
+                                                 torch.int32, torch.int64,
+                                                 torch.bfloat16, torch.half),
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # See https://github.com/pytorch/pytorch/pull/78358
+           check_batched_forward_grad=False,
+           skips=(
+               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+               # following tests give a runtime error with undefined value tensor
+               # see discussion : https://github.com/pytorch/pytorch/issues/56660
+               # RuntimeError:
+               # Arguments for call are not valid.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32, torch.complex64)),  # noqa: B950
+               DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo'),
+           ),
+           supports_inplace_autograd=False,
+           sample_inputs_func=sample_inputs_gradient,
+           error_inputs_func=error_inputs_gradient),
+    OpInfo('isin',
+           dtypes=all_types_and(torch.bfloat16, torch.half),
+           supports_autograd=False,
+           sample_inputs_func=sample_inputs_isin),
+    OpInfo('kthvalue',
+           dtypes=all_types_and(torch.bfloat16, torch.float16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32),
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           sample_inputs_func=sample_inputs_kthvalue,
+           error_inputs_func=error_inputs_kthvalue),
+    BinaryUfuncInfo('le',
+                    ref=np.less_equal,
+                    aliases=('less_equal',),
+                    dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16),
+                    always_returns_bool=True,
+                    supports_autograd=False,
+                    skips=(
+                    )),
+    OpInfo('linspace',
+           dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32),
+           is_factory_function=True,
+           supports_out=True,
+           supports_autograd=False,
+           error_inputs_func=error_inputs_linspace,
+           sample_inputs_func=sample_inputs_linspace,
+           skips=(
+               # FX failed to normalize op - add the op to the op_skip list.
+               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+               # Tests that assume input is a tensor or sequence of tensors
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
+
+               # Same failure as arange: cannot find linspace in captured graph
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
+
+               # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
+               # UserWarning: CUDA caching allocator reports a memory leak not verified by the driver API
+               # in __main__.TestJitCUDA.test_variant_consistency_jit_logspace_cuda_complex64!
+               # Caching allocator allocated memory was 0 and is now reported as 307200 on device 0.
+               # CUDA driver allocated memory was 1254555648 and is now 1242955776.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',
+                            dtypes=(torch.cfloat,), device_type="cuda"),
+           )),
+    OpInfo('linspace',
+           dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32),
+           is_factory_function=True,
+           supports_out=True,
+           supports_autograd=False,
+           error_inputs_func=error_inputs_linspace,
+           sample_inputs_func=sample_inputs_linspace_tensor_overload,
+           variant_test_name="tensor_overload",
+           skips=(
+               # FX failed to normalize op - add the op to the op_skip list.
+               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+               # TypeError: 'int' object is not subscriptable
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
+
+               # Same failure as arange: cannot find linspace in captured graph
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
+
+               # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
+               # UserWarning: CUDA caching allocator reports a memory leak not verified by the driver API
+               # in __main__.TestJitCUDA.test_variant_consistency_jit_logspace_cuda_complex64!
+               # Caching allocator allocated memory was 0 and is now reported as 307200 on device 0.
+               # CUDA driver allocated memory was 1254555648 and is now 1242955776.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',
+                            dtypes=(torch.cfloat,), device_type="cuda"),
+           )),
+    OpInfo('logspace',
+           dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+           is_factory_function=True,
+           supports_out=True,
+           supports_autograd=False,
+           error_inputs_func=error_inputs_linspace,
+           sample_inputs_func=sample_inputs_logspace,
+           skips=(
+               # FX failed to normalize op - add the op to the op_skip list.
+               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+               # Tests that assume input is a tensor or sequence of tensors
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
+               # Same failure as arange: cannot find logspace in captured graph
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
+
+               # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
+
+               # Off-by-one issue when casting floats to ints
+               DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick',
+                            dtypes=(torch.int16, torch.int32, torch.int64), device_type="cuda"),
+               DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_comprehensive',
+                            dtypes=(torch.int16, torch.int32, torch.int64), device_type="cuda"),
+               # UserWarning: CUDA caching allocator reports a memory leak not verified by the driver API
+               # in __main__.TestJitCUDA.test_variant_consistency_jit_logspace_cuda_complex64!
+               # Caching allocator allocated memory was 0 and is now reported as 307200 on device 0.
+               # CUDA driver allocated memory was 1254555648 and is now 1242955776.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',
+                            dtypes=(torch.cfloat,), device_type="cuda"),
+           )),
+    OpInfo('logspace',
+           dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+           is_factory_function=True,
+           supports_out=True,
+           supports_autograd=False,
+           error_inputs_func=error_inputs_linspace,
+           sample_inputs_func=sample_inputs_logspace_tensor_overload,
+           variant_test_name="tensor_overload",
+           skips=(
+               # FX failed to normalize op - add the op to the op_skip list.
+               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+               # TypeError: 'int' object is not subscriptable
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
+               # Same failure as arange: cannot find logspace in captured graph
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
+
+               # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
+
+               # Off-by-one issue when casting floats to ints
+               DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick',
+                            dtypes=(torch.int16, torch.int32, torch.int64), device_type="cuda"),
+               DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_comprehensive',
+                            dtypes=(torch.int16, torch.int32, torch.int64), device_type="cuda"),
+               # UserWarning: CUDA caching allocator reports a memory leak not verified by the driver API
+               # in __main__.TestJitCUDA.test_variant_consistency_jit_logspace_cuda_complex64!
+               # Caching allocator allocated memory was 0 and is now reported as 307200 on device 0.
+               # CUDA driver allocated memory was 1254555648 and is now 1242955776.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',
+                            dtypes=(torch.cfloat,), device_type="cuda"),
+           )),
+    UnaryUfuncInfo('log',
+                   ref=np.log,
+                   domain=(0, None),
+                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                   dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
+                   backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.chalf),
+                   dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+                   assert_autodiffed=True,
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   promotes_int_to_float=True,
+                   decorators=(precisionOverride({torch.bfloat16: 5e-2}),),
+                   skips=(
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
+                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
+                                    active_if=IS_WINDOWS),
+                   ),
+                   # log(z)->-inf for |z|->0
+                   reference_numerics_filter=NumericsFilter(condition=lambda x: torch.abs(x) < 0.1, safe_val=1)),
+    UnaryUfuncInfo('log10',
+                   ref=np.log10,
+                   domain=(0, None),
+                   decorators=(precisionOverride({torch.bfloat16: 5e-2}),),
+                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                   dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+                   assert_autodiffed=True,
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   promotes_int_to_float=True,
+                   skips=(
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
+                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
+                                    active_if=IS_WINDOWS),
+                   ),
+                   # log10(z)->-inf for |z|->0
+                   reference_numerics_filter=NumericsFilter(condition=lambda x: torch.abs(x) < 0.1, safe_val=1)),
+    UnaryUfuncInfo('log2',
+                   ref=np.log2,
+                   domain=(0, None),
+                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                   dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+                   assert_autodiffed=True,
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   promotes_int_to_float=True,
+                   decorators=(precisionOverride({torch.bfloat16: 1e-1}),),
+                   skips=(
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
+                                    dtypes=[torch.cfloat, torch.cdouble]),
+                   ),
+                   # log2(z)->-inf for |z|->0
+                   reference_numerics_filter=NumericsFilter(condition=lambda x: torch.abs(x) < 0.1, safe_val=1)),
+    BinaryUfuncInfo('ldexp',
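+                    # For reference: ldexp(input, other) computes input * (2 ** other).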
+                    dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                    # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+                    gradcheck_fast_mode=True,
+                    supports_forward_ad=True,
+                    supports_fwgrad_bwgrad=True,
+                    supports_inplace_autograd=False,
+                    promotes_int_to_float=True,
+                    supports_out=True,
+                    supports_rhs_python_scalar=False,
+                    skips=(
+                        # RuntimeError: mul(): functions with out=... arguments don't support
+                        # automatic differentiation, but one of the arguments requires grad
+                        # https://github.com/pytorch/pytorch/issues/68966
+                        DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
+                        DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
+                        DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+                        DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
+                    ),
+                    decorators=[
+                        DecorateInfo(
+                            toleranceOverride({
+                                torch.complex64: tol(atol=1e-05, rtol=1e-05)
+                            }),
+                            'TestCommon', device_type='cpu',
+                        ),
+                    ], ),
+    BinaryUfuncInfo('logaddexp',
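+                    # For reference: logaddexp(a, b) computes log(exp(a) + exp(b)) in a
+                    # numerically stable way.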
+                    dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16),
+                    dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16),
+                    dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+                    supports_forward_ad=True,
+                    supports_fwgrad_bwgrad=True,
+                    supports_rhs_python_scalar=False,
+                    skips=(
+                        # TODO: FIXME: RuntimeError: not implemented for 'ComplexFloat'
+                        DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion', device_type='cuda'),
+                    )),
+    OpInfo('logaddexp2',
+           dtypes=floating_types_and(torch.bfloat16, torch.half),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           sample_inputs_func=sample_inputs_logaddexp),
+    UnaryUfuncInfo('logical_not',
+                   ref=np.logical_not,
+                   decorators=(precisionOverride({torch.bfloat16: 7e-1,
+                                                  torch.float16: 5e-1}),),
+                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                   dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int8, torch.bool),
+                   supports_autograd=False,
+                   skips=(
+                       # The function variant always returns BoolTensor
+                       # while the inplace variant preserves the input dtype.
+                       # >>> t = torch.randn(3)
+                       # >>> torch.logical_not(t)
+                       # tensor([False, False, False])
+                       # >>> torch.logical_not(t).dtype
+                       # torch.bool
+                       # >>> t.logical_not_().dtype
+                       # torch.float32
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_variant_consistency',
+                                    dtypes=all_types_and_complex_and(torch.half, torch.bfloat16)),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager',
+                                    dtypes=all_types_and_complex_and(torch.half, torch.bfloat16)),
+                   )),
+    BinaryUfuncInfo('lt',
+                    ref=np.less,
+                    aliases=('less',),
+                    dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16),
+                    dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int8, torch.int32),
+                    always_returns_bool=True,
+                    supports_autograd=False,
+                    skips=(
+                    )),
+    OpInfo('lu_unpack',
+           op=torch.lu_unpack,
+           dtypes=floating_and_complex_types(),
+           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+           gradcheck_fast_mode=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           skips=(skipCPUIfNoLapack,),
+           sample_inputs_func=sample_inputs_lu_unpack),
+    OpInfo('lu',
+           op=torch.lu,
+           dtypes=floating_and_complex_types(),
+           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+           gradcheck_fast_mode=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # https://github.com/pytorch/pytorch/issues/66357
+           check_batched_forward_grad=False,
+           sample_inputs_func=sample_inputs_lu,
+           decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
+           skips=(
+               # we skip jit tests because `lu` is a torch function
+               # RuntimeError:
+               # 'Tensor (inferred)' object has no attribute or method 'lu'.:
+               # File "", line 3
+               # def the_method(i0):
+               #     return i0.lu(True, True)
+               #            ~~~~~ <--- HERE
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+               # RuntimeError not raised: Expected RuntimeError when calling with input.device=cpu and out.device=cuda
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
+               # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
+           )),
+    OpInfo('lu_solve',
+           op=torch.lu_solve,
+           dtypes=floating_and_complex_types(),
+           supports_forward_ad=True,
+           # See https://github.com/pytorch/pytorch/issues/66357
+           check_batched_forward_grad=False,
+           supports_fwgrad_bwgrad=True,
+           sample_inputs_func=sample_inputs_lu_solve,
+           skips=(
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out',
+                            device_type='mps', dtypes=[torch.float32]),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager',
+                            device_type='mps', dtypes=[torch.float32]),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',
+                            device_type='mps', dtypes=[torch.float32]),
+               DecorateInfo(unittest.skip("Tests different backward paths"),
+                            "TestCommon", "test_floating_inputs_are_differentiable"),),
+           decorators=[skipCPUIfNoLapack, skipCUDAIfNoMagmaAndNoCusolver]),
+    OpInfo('masked_fill',
+           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int8, torch.bool, torch.int32),
+           sample_inputs_func=sample_inputs_masked_fill,
+           error_inputs_func=error_inputs_masked_fill,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           check_batched_forward_grad=False,
+           supports_out=False),
+    OpInfo('masked_scatter',
+           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int8, torch.bool, torch.int32),
+           sample_inputs_func=sample_inputs_masked_scatter,
+           error_inputs_func=error_inputs_masked_scatter,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # https://github.com/pytorch/pytorch/issues/66357
+           check_batched_forward_grad=False,
+           supports_out=False,
+           skips=(
+               # Compiler issue on ROCm. Regression started in ROCm 6.4.
+               DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_non_standard_bool_values',
+                            dtypes=[torch.bool], active_if=TEST_WITH_ROCM),
+           )),
+    OpInfo('masked_select',
+           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32),
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           sample_inputs_func=sample_inputs_masked_select,
+           error_inputs_func=error_inputs_masked_select,
+           skips=(
+               # Compiler issue on ROCm. Might need to skip until ROCm 5.5.
+               DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_non_standard_bool_values',
+                            dtypes=[torch.bool], active_if=TEST_WITH_ROCM),
+           )),
+    OpInfo('matrix_exp',
+           dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+           aliases=('linalg.matrix_exp',),
+           sample_inputs_func=sample_inputs_matrix_exp,
+           # Needs to construct a 2n x 2n matrix by copy_-ing into it
+           check_batched_grad=False,
+           check_batched_gradgrad=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # https://github.com/pytorch/pytorch/issues/66357
+           check_batched_forward_grad=False,
+           skips=(
+               # mexp does not support bf16 and fp16
+               DecorateInfo(unittest.skip('Skipped!'), 'TestInductorOpInfo', 'test_comprehensive',
+                            dtypes=[torch.half], device_type="cpu"),
+           ),
+           supports_out=False,
+           ),
+    OpInfo('matmul',
+           aliases=('linalg.matmul',),
+           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
+                                                       *[torch.bfloat16]
+                                                       if SM53OrLater or TEST_WITH_ROCM else []),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+           assert_autodiffed=True,
+           assert_jit_shape_analysis=True,
+           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+           gradcheck_fast_mode=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           check_batched_forward_grad=False,
+           sample_inputs_func=partial(sample_inputs_matmul, is_rmatmul=False),
+           decorators=[
+               # NVIDIA only assures that bfloat16 is supported by bmm if SM >= 5.3
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', device_type='cuda', active_if=not SM53OrLater),
+               # ROCm intermittently fails the test with standard atol/rtol
+               DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-4, rtol=0)}),
+                            'TestCommon', 'test_noncontiguous_samples', device_type='cuda',
+                            active_if=TEST_WITH_ROCM),
+               DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-4, rtol=0)}),
+                            'TestCommon', 'test_out', device_type='cuda',
+                            active_if=TEST_WITH_ROCM),
+               # mv for the sample with shapes (S, S, M, M), (M,) has some variance in the
+               # backward on CPU
+               DecorateInfo(toleranceOverride({torch.float32: tol(atol=0, rtol=1e-5)}),
+                            'TestCommon', 'test_noncontiguous_samples',
+                            device_type='cpu'),
+               DecorateInfo(
+                   toleranceOverride({
+                       torch.float32: tol(atol=1e-5, rtol=1e-5),
+                       torch.complex64: tol(atol=1e-5, rtol=1e-5),
+                   }),
+                   "TestDecomp", "test_comprehensive", device_type="cuda",
+               ),
+           ],
+           skips=(
+               # Strides are not the same!
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
+               # https://github.com/pytorch/pytorch/issues/67470
+               DecorateInfo(unittest.skip("67470!"),
+                            'TestCommon', 'test_noncontiguous_samples',
+                            device_type='cpu', dtypes=(torch.long,)),
+               # AssertionError: False is not true : Tensors failed to compare as equal!
+               DecorateInfo(unittest.skip("Skipped!"), 'TestOpInfo',
+                            device_type='xla', dtypes=(torch.long,)),
+               # https://github.com/pytorch/pytorch/issues/71774
+               DecorateInfo(unittest.skip('Skipped!'), 'TestNNCOpInfo', 'test_nnc_correctness',
+                            device_type='cpu', dtypes=(torch.long,)),
+           )),
+    OpInfo('max',
+           variant_test_name='reduction_with_dim',
+           dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32),
+           sample_inputs_func=sample_inputs_max_min_reduction_with_dim,
+           supports_fwgrad_bwgrad=True,
+           skips=(
+           ),
+           supports_forward_ad=True),
+    OpInfo('max',
+           variant_test_name='reduction_no_dim',
+           dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32),
+           supports_out=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           sample_inputs_func=sample_inputs_max_min_reduction_no_dim,
+           skips=(
+           )),
+    OpInfo('median',
+           dtypes=all_types_and(torch.bfloat16, torch.float16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32),
+           # TODO: some signatures of median do support out
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           error_inputs_func=error_inputs_median,
+           sample_inputs_func=partial(sample_inputs_reduction, supports_multiple_dims=False)),
+    OpInfo('nanmedian',
+           dtypes=all_types_and(torch.bfloat16, torch.float16),
+           # TODO: some signatures of nanmedian do support out
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           sample_inputs_func=partial(sample_inputs_reduction, supports_multiple_dims=False)),
+    OpInfo('var_mean',
+           dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+           sample_inputs_func=sample_inputs_std_var,
+           # TODO: some signatures of var_mean do support out
+           supports_out=False,
+           supports_forward_ad=True,
+           check_batched_forward_grad=False,
+           supports_fwgrad_bwgrad=True,
+           decorators=(
+               DecorateInfo(toleranceOverride({torch.float64: tol(atol=2e-7, rtol=2e-7)}),
+                            "TestDecomp", "test_comprehensive", device_type="cuda"),
+               DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=2e-3)}),
+                            "TestInductorOpInfo", "test_comprehensive", device_type="cuda"),
+           )),
+    OpInfo('var_mean',
+           variant_test_name='unbiased',
+           dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+           sample_inputs_func=sample_inputs_std_var_unbiased,
+           # TODO: some signatures of var_mean do support out
+           supports_out=False,
+           supports_forward_ad=True,
+           check_batched_forward_grad=False,
+           supports_fwgrad_bwgrad=True,
+           decorators=(
+               DecorateInfo(toleranceOverride({torch.float64: tol(atol=2e-7, rtol=2e-7)}),
+                            "TestDecomp", "test_comprehensive", device_type="cuda"),
+               DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=2e-3)}),
+                            "TestInductorOpInfo", "test_comprehensive", device_type="cuda"),
+           )),
+    OpInfo('std_mean',
+           dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+           sample_inputs_func=sample_inputs_std_var,
+           # TODO: some signatures of std_mean do support out
+           supports_out=False,
+           supports_forward_ad=True,
+           check_batched_forward_grad=False,
+           supports_fwgrad_bwgrad=True,
+           decorators=(
+               DecorateInfo(toleranceOverride({torch.float64: tol(atol=2e-7, rtol=2e-7)}),
+                            "TestDecomp", "test_comprehensive", device_type="cuda"),
+           )),
+    OpInfo('std_mean',
+           variant_test_name='unbiased',
+           dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+           sample_inputs_func=sample_inputs_std_var_unbiased,
+           # TODO: some signatures of std_mean do support out
+           supports_out=False,
+           supports_forward_ad=True,
+           check_batched_forward_grad=False,
+           supports_fwgrad_bwgrad=True,
+           decorators=(
+               DecorateInfo(
+                   toleranceOverride({
+                       torch.float16: tol(atol=4e-5, rtol=9e-3),
+                       torch.float64: tol(atol=2e-7, rtol=2e-7),
+                   }),
+                   "TestDecomp",
+                   "test_comprehensive",
+                   device_type="cuda"
+               ),
+               DecorateInfo(
+                   toleranceOverride({
+                       torch.float16: tol(atol=4e-5, rtol=9e-3),
+                       torch.float64: tol(atol=2e-7, rtol=2e-7),
+                   }),
+                   "TestInductorOpInfo",
+                   "test_comprehensive",
+                   device_type="cuda"
+               ),
+           )),
+    OpInfo('meshgrid',
+           variant_test_name='variadic_tensors',
+           ref=np.meshgrid,
+           dtypes=all_types_and_complex_and(torch.bfloat16, torch.bool, torch.float16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+           sample_inputs_func=partial(sample_inputs_meshgrid, variant='variadic'),
+           skips=[
+               # JIT does not support variadic tensors.
+               # RuntimeError: input->type()->kind() == TypeKind::OptionalType
+               # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252,
+               # please report a bug to PyTorch.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+               # meshgrid is defined in torch.functional to take a
+               # variadic list of tensors. Variadic parameters are not
+               # compatible with the normalize operator tests.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+               # Skip operator schema test because this is a functional and not an operator
+               DecorateInfo(unittest.skip("Skipped!"), 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
+           ],
+           supports_out=False,
+           supports_fwgrad_bwgrad=True,
+           supports_forward_ad=True,
+           # See https://github.com/pytorch/pytorch/pull/78358
+           check_batched_forward_grad=False,),
+    OpInfo('meshgrid',
+           variant_test_name='list_of_tensors',
+           # Unlike the variant above, we do not use np.meshgrid as a
+           # ref since it does not officially support a list of numpy
+           # arrays.
+           dtypes=all_types_and_complex_and(torch.bfloat16, torch.bool, torch.float16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+           sample_inputs_func=partial(sample_inputs_meshgrid, variant='list'),
+           skips=[
+               # meshgrid is defined in torch.functional to take a
+               # variadic list of tensors. Variadic parameters are not
+               # compatible with the normalize operator tests.
+               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+           ],
+           assert_autodiffed=True,
+           supports_out=False,
+           autodiff_nonfusible_nodes=[],
+           supports_fwgrad_bwgrad=True,
+           supports_forward_ad=True,
+           # See https://github.com/pytorch/pytorch/pull/78358
+           check_batched_forward_grad=False,),
+    OpInfo('min',
+           variant_test_name='reduction_with_dim',
+           dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32),
+           sample_inputs_func=sample_inputs_max_min_reduction_with_dim,
+           supports_fwgrad_bwgrad=True,
+           supports_forward_ad=True,
+           skips=(
+           )),
+    OpInfo('min',
+           variant_test_name='reduction_no_dim',
+           dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
+           supports_out=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           sample_inputs_func=sample_inputs_max_min_reduction_no_dim,
+           skips=(
+           )),
+    OpInfo('quantile',
+           dtypes=floating_types(),
+           sample_inputs_func=sample_inputs_reduction_quantile,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # See https://github.com/pytorch/pytorch/issues/66357
+           # Relies on copy_ to broadcast, but the forward AD path calls broadcast_to which
+           # does not have a batching rule in core
+           check_batched_forward_grad=False),
+    OpInfo('nanquantile',
+           dtypes=floating_types(),
+           sample_inputs_func=sample_inputs_reduction_quantile,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # See https://github.com/pytorch/pytorch/issues/66357
+           # Relies on copy_ to broadcast, but the forward AD path calls broadcast_to which
+           # does not have a batching rule in core
+           check_batched_forward_grad=False),
+    BinaryUfuncInfo(
+        'max',
+        aliases=('maximum',),
+        variant_test_name='binary',
+        dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        assert_autodiffed=True,
+        ref=np.maximum,
+        supports_rhs_python_scalar=False,
+        skips=(
+            # Incorrectly attempts to use a scalar for the second argument
+            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_jit_alias_remapping'),
+            # TODO: FIXME: RuntimeError: "max_elementwise_cuda" not implemented for 'ComplexFloat'
+            DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion', device_type='cuda'),
+        )),
+    BinaryUfuncInfo(
+        'maximum',
+        dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        ref=np.maximum,
+        supports_rhs_python_scalar=False,
+        skips=(
+            # TODO: FIXME: RuntimeError: "max_elementwise_cuda" not implemented for 'ComplexFloat'
+            DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion', device_type='cuda'),
+        )),
+    BinaryUfuncInfo(
+        'min',
+        aliases=('minimum',),
+        variant_test_name='binary',
+        dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        assert_autodiffed=True,
+        ref=np.minimum,
+        supports_rhs_python_scalar=False,
+        skips=(
+            # Incorrectly attempts to use a scalar for the second argument
+            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_jit_alias_remapping'),
+            # TODO: FIXME: RuntimeError: "min_elementwise_cuda" not implemented for 'ComplexFloat'
+            DecorateInfo(unittest.expectedFailure,
+                         'TestBinaryUfuncs',
+                         'test_type_promotion',
+                         device_type='cuda'),
+        )),
+    BinaryUfuncInfo(
+        'minimum',
+        dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        ref=np.minimum,
+        supports_rhs_python_scalar=False,
+        skips=(
+            # TODO: FIXME: RuntimeError: "min_elementwise_cuda" not implemented for 'ComplexFloat'
+            DecorateInfo(unittest.expectedFailure,
+                         'TestBinaryUfuncs',
+                         'test_type_promotion',
+                         device_type='cuda'),
+        ),
+    ),
+    BinaryUfuncInfo('logical_and',
+                    ref=np.logical_and,
+                    dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                    dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool),
+                    supports_autograd=False,
+                    always_returns_bool=True,
+                    supports_rhs_python_scalar=False),
+    BinaryUfuncInfo('logical_or',
+                    ref=np.logical_or,
+                    dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                    dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int8, torch.bool),
+                    supports_autograd=False,
+                    always_returns_bool=True,
+                    supports_rhs_python_scalar=False),
+    BinaryUfuncInfo('logical_xor',
+                    ref=np.logical_xor,
+                    dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                    dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int8, torch.bool),
+                    supports_autograd=False,
+                    always_returns_bool=True,
+                    supports_rhs_python_scalar=False,
+                    skips=(
+                    )),
+    BinaryUfuncInfo('bitwise_and',
+                    ref=np.bitwise_and,
+                    dtypes=integral_types_and(torch.bool),
+                    dtypesIfHpu=custom_types(torch.bool),
+                    operator_variant=operator.and_,
+                    inplace_operator_variant=operator.iand,
+                    supports_autograd=False,
+                    supports_one_python_scalar=True,
+                    skips=(
+                        # RuntimeError: "bitwise_and_cuda" not implemented for 'Half'
+                        DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs',
+                                     'test_type_promotion', device_type='cuda'),
+                    )),
+    BinaryUfuncInfo('bitwise_or',
+                    ref=np.bitwise_or,
+                    dtypes=integral_types_and(torch.bool),
+                    dtypesIfHpu=custom_types(torch.bool),
+                    operator_variant=operator.or_,
+                    inplace_operator_variant=operator.ior,
+                    supports_autograd=False,
+                    supports_one_python_scalar=True,
+                    skips=(
+                        # TODO: FIXME: RuntimeError: "bitwise_or_cuda" not implemented for 'Half'
+                        DecorateInfo(unittest.expectedFailure,
+                                     'TestBinaryUfuncs',
+                                     'test_type_promotion',
+                                     device_type='cuda'),
+                    )),
+    BinaryUfuncInfo('bitwise_xor',
+                    ref=np.bitwise_xor,
+                    dtypes=integral_types_and(torch.bool),
+                    dtypesIfHpu=custom_types(torch.bool),
+                    operator_variant=operator.xor,
+                    inplace_operator_variant=operator.ixor,
+                    supports_autograd=False,
+                    supports_one_python_scalar=True,
+                    skips=(
+                        # TODO: FIXME: RuntimeError: "bitwise_xor_cuda" not implemented for 'Half'
+                        DecorateInfo(unittest.expectedFailure,
+                                     'TestBinaryUfuncs',
+                                     'test_type_promotion',
+                                     device_type='cuda'),
+                    )),
+    BinaryUfuncInfo('heaviside',
+                    ref=lambda a, b: (
+                        # necessary because np.heaviside incorrectly returns float64 when passed args of dtype int64
+                        np.int64(np.heaviside(a, b)) if a.dtype == np.int64 and b.dtype == np.int64 else np.heaviside(a, b)
+                    ),
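+                    # Concretely (illustrative): np.heaviside(np.int64(-1), np.int64(1)) comes
+                    # back as a float64 0.0, whereas torch.heaviside keeps int64 when both
+                    # inputs are int64, so the reference above casts the NumPy result to int64.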
+                    dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
+                    dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32),
+                    supports_autograd=False,
+                    supports_rhs_python_scalar=False,
+                    skips=(
+                        # RuntimeError: heaviside is not yet implemented for tensors with different dtypes.
+                        DecorateInfo(unittest.expectedFailure,
+                                     'TestBinaryUfuncs',
+                                     'test_type_promotion'),
+                        DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'),
+                        # PyTorch's heaviside does not appear to propagate NaNs
+                        DecorateInfo(unittest.skip("Skipped!"),
+                                     'TestBinaryUfuncs',
+                                     'test_reference_numerics_extremal_values'),
+                    )),
+    BinaryUfuncInfo('lcm',
+                    ref=np.lcm,
+                    dtypes=integral_types_and(),
+                    supports_autograd=False,
+                    supports_rhs_python_scalar=False),
+    BinaryUfuncInfo('gcd',
+                    ref=np.gcd,
+                    dtypes=integral_types_and(),
+                    supports_autograd=False,
+                    supports_rhs_python_scalar=False,
+                    skips=(
+                        DecorateInfo(unittest.expectedFailure,
+                                     'TestBinaryUfuncs',
+                                     'test_reference_numerics_small_values',
+                                     dtypes=(torch.int8,)),)),
+    BinaryUfuncInfo('isclose',
+                    ref=np.isclose,
+                    dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+                    sample_inputs_func=sample_inputs_isclose,
+                    error_inputs_func=error_inputs_isclose,
+                    supports_autograd=False,
+                    supports_out=False,
+                    supports_rhs_python_scalar=False,
+                    skips=(
+                        DecorateInfo(unittest.expectedFailure,
+                                     'TestCommon',
+                                     'test_numpy_refs', dtypes=(torch.complex128,)),
+                        # RuntimeError: Short did not match Int
+                        DecorateInfo(unittest.expectedFailure,
+                                     'TestBinaryUfuncs',
+                                     'test_type_promotion'),
+                        DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'),
+                        DecorateInfo(unittest.skip("Skipped!"),
+                                     'TestBinaryUfuncs',
+                                     'test_reference_numerics_extremal_values'),
+                    )),
+    # `softmax` supports different dtypes based on whether the `dtype` argument
+    # is passed or not. Hence two OpInfo entries, one with dtype and the other without.
+    # https://github.com/pytorch/pytorch/issues/68752
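+    # Roughly: torch.softmax(t, dim) rejects non-floating inputs, while
+    # torch.softmax(t, dim, dtype=torch.float32) casts the input to that dtype first,
+    # which is why only the "with_dtype" variant below lists the wider set of input dtypes.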
+    OpInfo('softmax',
+           aliases=('special.softmax', 'nn.functional.softmax',),
+           aten_name='softmax',
+           aten_backward_name='_softmax_backward_data',
+           dtypes=floating_types_and(torch.half, torch.bfloat16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+           sample_inputs_func=sample_inputs_softmax_variant,
+           assert_jit_shape_analysis=True,
+           assert_autodiffed=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           supports_out=True),
+    OpInfo('softmax',
+           aliases=('special.softmax', 'nn.functional.softmax',),
+           variant_test_name="with_dtype",
+           aten_name='softmax',
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+           sample_inputs_func=partial(sample_inputs_softmax_variant, with_dtype=True),
+           assert_autodiffed=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           supports_out=True),
+    OpInfo(
+        '_softmax_backward_data',
+        op=torch.ops.aten._softmax_backward_data,
+        aten_name='_softmax_backward_data',
+        dtypes=floating_types_and(torch.bfloat16, torch.float16),
+        sample_inputs_func=sample_inputs_softmax_backward_data,
+        assert_autodiffed=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        supports_out=False,
+        skips=(
+            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
+        ),
+    ),
+    # `softmin` supports different dtypes based on whether the `dtype` argument
+    # is passed or not. Hence two OpInfo entries, one with dtype and the other without.
+    # https://github.com/pytorch/pytorch/issues/68752
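+    # (softmin(x) is computed as softmax(-x), so the same dtype handling as above applies.)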
+    OpInfo('nn.functional.softmin',
+           aten_name='softmin',
+           dtypes=floating_types_and(torch.half, torch.bfloat16),
+           sample_inputs_func=sample_inputs_softmax_variant,
+           assert_jit_shape_analysis=False,
+           assert_autodiffed=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           supports_out=False),
+    OpInfo('nn.functional.softmin',
+           variant_test_name="with_dtype",
+           aten_name='softmin',
+           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+           sample_inputs_func=partial(sample_inputs_softmax_variant, with_dtype=True),
+           assert_autodiffed=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           supports_out=False),
+    OpInfo(
+        "nn.functional.cross_entropy",
+        dtypes=floating_types_and(torch.float16, torch.bfloat16),
+        sample_inputs_func=sample_inputs_cross_entropy,
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        decorators=(
+            DecorateInfo(
+                toleranceOverride({torch.float32: tol(atol=3e-3, rtol=1e-3)}),
+                "TestJit",
+                "test_variant_consistency_jit",
+                device_type="cpu",
+            ),
+        ),
+        skips=(
+            # AssertionError: False is not true : Scalars failed to compare as equal! 0 != 1536
+            # test_ops.TestJitCUDA.test_variant_consistency_jit_nn_functional_cross_entropy_cuda_float32 leaked
+            # 1536 bytes CUDA memory on device 0
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestJit",
+                "test_variant_consistency_jit",
+                device_type="cuda",
+            ),
+            DecorateInfo(unittest.skip("FP16 cross_entropy cases have not been enabled on MPS yet"),
+                         dtypes=(torch.half,), device_type="mps"),
+        )
+    ),
+    OpInfo('nn.functional.normalize',
+           dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
+           sample_inputs_func=sample_inputs_normalize,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True),
+    OpInfo('aminmax',
+           ref=lambda x, dim=None, keepdim=False: (np.amin(x, axis=dim, keepdims=keepdim), np.amax(x, axis=dim, keepdims=keepdim)),
+           dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8),
+           decorators=(onlyNativeDeviceTypes,),
+           supports_autograd=False,
+           sample_inputs_func=sample_inputs_aminmax,
+           error_inputs_func=error_inputs_aminmax_amax_amin),
+    OpInfo('as_strided',
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool),
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # vmap does not support inplace views
+           check_inplace_batched_forward_grad=False,
+           sample_inputs_func=sample_inputs_as_strided,
+           skips=(
+               # Note: This xfail is fine -- it's inherent to how as_strided works
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples'),
+               # AssertionError: False is not true : Scalars failed to compare as equal!
+               DecorateInfo(unittest.skip("Errors when storage_offset is included"),
+                            'TestCommon', 'test_variant_consistency_eager'),
+               # Not close
+               DecorateInfo(unittest.skip("Errors when storage_offset is included"),
+                            'TestCommon', 'test_complex_half_reference_testing'),
+               # Not close
+               DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_conj_view'),
+               DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_view'),
+               DecorateInfo(unittest.skip("Numerous errors"), 'TestFwdGradients'),
+               DecorateInfo(unittest.skip("Numerous errors"), 'TestBwdGradients'),
+           )),
+    OpInfo('as_strided',
+           variant_test_name='partial_views',
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8, torch.bool),
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # vmap does not support inplace views
+           check_inplace_batched_forward_grad=False,
+           sample_inputs_func=sample_inputs_as_strided_partial_views,
+           skips=(
+               # Note: This xfail is fine -- it's inherent to how as_strided works
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples'),
+               # These fail because the test changes the input's in-memory layout
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_complex_half_reference_testing'),
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu'),
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+               DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_fn_fwgrad_bwgrad',
+                            dtypes=(torch.complex64, torch.complex128)),
+               DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD'),
+               DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_inplace_forward_mode_AD'),
+               DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_inplace_grad'),
+               DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_inplace_gradgrad'),
+               DecorateInfo(unittest.expectedFailure, 'TestProxyTensorOpInfo',
+                            'test_make_fx_symbolic_exhaustive_inplace'),
+               DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness'),
+               # These fail but are also flaky
+               DecorateInfo(unittest.skip("Test changes in memory layout"), 'TestMathBits'),
+               DecorateInfo(unittest.skip("Modifies input strides and storage_offset"), 'TestCommon',
+                            'test_non_standard_bool_values'),
+               # RuntimeError: setStorage: sizes [2, 2], strides [1, 2], storage offset 10, and itemsize 2 requiring a
+               # storage size of 28 are out of bounds for storage of size 20
+               DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta_inplace'),
+               DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_meta_inplace'),
+               DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_inplace'),
+               DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_inplace_all_strides'),
+           )),
+    OpInfo('as_strided_copy',
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
+           supports_out=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # vmap does not support inplace views
+           check_inplace_batched_forward_grad=False,
+           sample_inputs_func=sample_inputs_as_strided,
+           skips=(
+               # Note: This xfail is fine -- it's inherent to how as_strided works
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples'),
+               # AssertionError: False is not true : Scalars failed to compare as equal!
+               DecorateInfo(unittest.skip("Errors when storage_offset is included"),
+                            'TestCommon', 'test_variant_consistency_eager'),
+               # Not close
+               DecorateInfo(unittest.skip("Errors when storage_offset is included"),
+                            'TestCommon', 'test_complex_half_reference_testing'),
+               # Not close
+               DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_conj_view'),
+               DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_view'),
+               DecorateInfo(unittest.skip("Numerous errors"), 'TestFwdGradients'),
+               DecorateInfo(unittest.skip("Numerous errors"), 'TestBwdGradients'),
+               DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'),
+           )),
+    OpInfo('as_strided_scatter',
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # vmap does not support inplace views
+           check_inplace_batched_forward_grad=False,
+           sample_inputs_func=sample_inputs_as_strided_scatter,
+           error_inputs_func=error_inputs_as_strided_scatter,
+           skips=(
+               DecorateInfo(unittest.skip('Works for int64, fails for everything else'), 'TestCommon', 'test_noncontiguous_samples'),  # noqa: B950
+               DecorateInfo(unittest.skip('Fails in most cases, passes on LAZY for some reason'), 'TestCommon', 'test_variant_consistency_eager'),  # noqa: B950
+               DecorateInfo(unittest.skip('Fails on cuda + rocm'), 'TestCommon', 'test_complex_half_reference_testing'),
+               DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_grad'),
+               DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD'),
+               DecorateInfo(unittest.skip('Passes on complex128 and float64 only'), 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'),
+               # AssertionError: Tensor-likes are not close! (new_empty_strided.default)
+               DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), 'TestDecomp', 'test_comprehensive'),)),
+    OpInfo('native_layer_norm',
+           aten_name='native_layer_norm',
+           ref=reference_native_layer_norm,
+           dtypes=floating_types_and(torch.half, torch.bfloat16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+           supports_out=False,
+           assert_jit_shape_analysis=True,
+           supports_fwgrad_bwgrad=True,
+           sample_inputs_func=sample_inputs_native_layer_norm,
+           error_inputs_func=error_inputs_native_layer_norm,
+           skips=(
+               # IndexError: tuple index out of range
+               DecorateInfo(unittest.skip('Skipped!'), 'TestFwdGradients', 'test_forward_mode_AD'),
+               # Tests fail when weight=None and bias is defined
+               # https://github.com/pytorch/pytorch/issues/79705
+               DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_gradgrad'),
+               # JIT test also tries to compute double backward, which fails
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+               DecorateInfo(unittest.skip("Unsupported on MPS for now"), 'TestCommon', 'test_numpy_ref_mps'),
+               DecorateInfo(toleranceOverride({torch.float32: tol(atol=2e-03, rtol=5e-03)}),
+                            "TestDecomp", "test_comprehensive", device_type="cpu"),
+           )),
+    OpInfo('native_batch_norm',
+           aten_name='native_batch_norm',
+           dtypes=floating_types_and(torch.float16, torch.bfloat16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           assert_jit_shape_analysis=True,
+           allow_cow_input_materialize_forward=[3, 4],
+           allow_cow_input_materialize_backward=[3, 4],
+           sample_inputs_func=sample_inputs_native_batch_norm,
+           skips=(
+               # NotImplementedError: Could not run
+               # 'aten::native_batch_norm.out' with arguments from the 'CPU' backend.
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type="cpu"),
+               # RuntimeError: out_invstd.dim() == 1 && out_invstd.is_contiguous() && out_invstd.sizes()[0]
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type="cuda"),
+               # Problem with _get_numerical_jacobian
+               # IndexError: tuple index out of range
+               DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'),
+               # RuntimeError: deepEquals(input.iValue, deepCopiedInput) INTERNAL ASSERT FAILED
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+               # https://github.com/pytorch/pytorch/issues/85960
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu'),
+               # AssertionError: Booleans mismatch: True is not False
+               DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake_autocast'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake'),
+               DecorateInfo(toleranceOverride({torch.float32: tol(atol=5e-5, rtol=5e-5)}),
+                            "TestCompositeCompliance", "test_forward_ad"),
+           )
+           ),
+    OpInfo('_native_batch_norm_legit',
+           aten_name='_native_batch_norm_legit',
+           dtypes=floating_types_and(torch.float16, torch.bfloat16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           assert_jit_shape_analysis=True,
+           allow_cow_input_materialize_forward=[3, 4],
+           allow_cow_input_materialize_backward=[3, 4],
+           sample_inputs_func=sample_inputs__native_batch_norm_legit,
+           skips=(
+               # NotImplementedError: Could not run
+               # 'aten::native_batch_norm.out' with arguments from the 'CPU' backend.
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type="cpu"),
+               # RuntimeError: out_invstd.dim() == 1 && out_invstd.is_contiguous() && out_invstd.sizes()[0]
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type="cuda"),
+               # Problem with _get_numerical_jacobian
+               # IndexError: tuple index out of range
+               DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'),
+               # RuntimeError: deepEquals(input.iValue, deepCopiedInput) INTERNAL ASSERT FAILED
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+               # https://github.com/pytorch/pytorch/issues/85960
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu'),
+               DecorateInfo(toleranceOverride({torch.float32: tol(atol=5e-5, rtol=5e-5)}),
+                            "TestCompositeCompliance", "test_forward_ad"),
+           )
+           ),
+    OpInfo('_batch_norm_with_update',
+           op=torch.ops.aten._batch_norm_with_update,
+           aten_name='_batch_norm_with_update',
+           dtypes=floating_types_and(torch.float16, torch.bfloat16),
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           assert_jit_shape_analysis=True,
+           allow_cow_input_materialize_forward=[3, 4],
+           allow_cow_input_materialize_backward=[3, 4],
+           sample_inputs_func=sample_inputs__batch_norm_with_update,
+           skips=(
+               # NotImplementedError: Could not run
+               # 'aten::native_batch_norm.out' with arguments from the 'CPU' backend.
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type="cpu"),
+               # RuntimeError: out_invstd.dim() == 1 && out_invstd.is_contiguous() && out_invstd.sizes()[0]
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type="cuda"),
+               # Problem with _get_numerical_jacobian
+               # IndexError: tuple index out of range
+               DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'),
+               # RuntimeError: deepEquals(input.iValue, deepCopiedInput) INTERNAL ASSERT FAILED
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+               DecorateInfo(toleranceOverride({torch.float32: tol(atol=5e-5, rtol=5e-5)}),
+                            "TestCompositeCompliance", "test_forward_ad"),
+               # _batch_norm_with_update expects contiguous inputs for cudnn and miopen
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type="cuda"),
+               DecorateInfo(unittest.expectedFailure,
+                            'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides', device_type="cuda"),
+               # _batch_norm_with_update does not have python bindings
+               DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+               # aten out variants do not accept an out= kwarg; only the Python out variants do
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
+           )
+           ),
+    OpInfo('nn.functional.cosine_similarity',
+           aten_name="cosine_similarity",
+           dtypes=floating_types_and(torch.half, torch.bfloat16),
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           decorators=[
+               DecorateInfo(
+                   toleranceOverride({torch.float16: tol(atol=1.3e-5, rtol=2e-2)}),
+                   "TestInductorOpInfo",
+                   "test_comprehensive",
+                   device_type="cuda"
+               ),
+           ],
+           sample_inputs_func=sample_inputs_cosine_similarity),
+    OpInfo('nn.functional.adaptive_avg_pool1d',
+           dtypes=floating_types_and(torch.half, torch.bfloat16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16),
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+           error_inputs_func=error_inputs_adaptive_avg_pool1d,
+           sample_inputs_func=sample_inputs_adaptive_avg_pool1d),
+    OpInfo('nn.functional.adaptive_avg_pool2d',
+           dtypes=floating_types_and(torch.half, torch.bfloat16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16),
+           decorators=(
+               # RuntimeError:
+               # adaptive_avg_pool2d(Tensor input, int[2] output_size) -> (Tensor):
+               # Expected a value of type 'List[int]' for argument 'output_size' but
+               # instead found type 'Tuple[NoneType, int]'. :
+               #   File "", line 3
+               # def the_method(i0):
+               #     return torch.nn.functional.adaptive_avg_pool2d(i0, (None, 7))
+               #            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+           ),
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+           error_inputs_func=error_inputs_adaptive_avg_pool2d,
+           sample_inputs_func=sample_inputs_adaptive_avg_pool2d),
+    OpInfo('nn.functional.adaptive_avg_pool3d',
+           dtypes=floating_types_and(torch.half, torch.bfloat16),
+           dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16),
+           decorators=(
+               # RuntimeError:
+               # adaptive_avg_pool3d(Tensor input, int[3] output_size) -> (Tensor):
+               # Expected a value of type 'List[int]' for argument 'output_size' but
+               # instead found type 'Tuple[NoneType, NoneType, NoneType]'. :
+               #   File "", line 3
+               #
+               # def the_method(i0):
+               #     return torch.nn.functional.adaptive_avg_pool3d(i0, (None, None, None))
+               #            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
+               #
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+           ),
+           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+           gradcheck_fast_mode=True,
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+           error_inputs_func=error_inputs_adaptive_avg_pool3d,
+           sample_inputs_func=sample_inputs_adaptive_avg_pool3d),
+    OpInfo('nn.functional.adaptive_max_pool1d',
+           dtypes=floating_types_and(torch.half, torch.bfloat16),
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # got: Batching rule not implemented for aten::flatten.using_ints
+           check_batched_forward_grad=False,
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+           error_inputs_func=error_inputs_adaptive_max_pool1d,
+           sample_inputs_func=sample_inputs_adaptive_max_pool1d),
+    OpInfo('nn.functional.adaptive_max_pool2d',
+           dtypes=floating_types_and(torch.half, torch.bfloat16),
+           decorators=(
+               # RuntimeError:
+               # adaptive_max_pool2d(Tensor input, int[2] output_size) -> (Tensor):
+               # Expected a value of type 'List[int]' for argument 'output_size' but
+               # instead found type 'Tuple[NoneType, int]'. :
+               #   File "", line 3
+               # def the_method(i0):
+               #     return torch.nn.functional.adaptive_max_pool2d(i0, (None, 7))
+               #            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+           ),
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # got: Batching rule not implemented for aten::flatten.using_ints
+           check_batched_forward_grad=False,
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+           error_inputs_func=error_inputs_adaptive_max_pool2d,
+           sample_inputs_func=sample_inputs_adaptive_max_pool2d),
+    OpInfo('nn.functional.adaptive_max_pool3d',
+           dtypes=floating_types_and(torch.bfloat16, torch.half),
+           decorators=(
+               # RuntimeError:
+               # adaptive_max_pool3d(Tensor input, int[3] output_size) -> (Tensor):
+               # Expected a value of type 'List[int]' for argument 'output_size' but
+               # instead found type 'Tuple[NoneType, NoneType, NoneType]'. :
+               #   File "", line 3
+               #
+               # def the_method(i0):
+               #     return torch.nn.functional.adaptive_max_pool3d(i0, (None, None, None))
+               #            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
+               #
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+           ),
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # got: Batching rule not implemented for aten::flatten.using_ints
+           check_batched_forward_grad=False,
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+           error_inputs_func=error_inputs_adaptive_max_pool3d,
+           sample_inputs_func=sample_inputs_adaptive_max_pool3d),
+    OpInfo('nn.functional.avg_pool1d',
+           aten_name='avg_pool1d',
+           supports_autograd=True,
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           dtypes=floating_types_and(torch.int64, torch.float16, torch.bfloat16),
+           dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16),
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+           error_inputs_func=error_inputs_avg_pool1d,
+           sample_inputs_func=sample_inputs_avgpool1d),
+    OpInfo('nn.functional.avg_pool3d',
+           aten_name='avg_pool3d',
+           supports_autograd=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           dtypes=floating_types_and(torch.int64),
+           dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16),
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+           error_inputs_func=error_inputs_avg_pool3d,
+           sample_inputs_func=sample_inputs_avgpool3d,
+           skips=(
+               # AssertionError: Tensor-likes are not close!
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type='cpu'),
+           )),
+    OpInfo(
+        "nn.functional.binary_cross_entropy_with_logits",
+        aten_name="binary_cross_entropy_with_logits",
+        supports_autograd=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        supports_out=False,
+        dtypes=floating_types_and(torch.half, torch.bfloat16),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16),
+        gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+        sample_inputs_func=sample_inputs_binary_cross_entropy_with_logits,
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                'TestJit',
+                'test_variant_consistency_jit',
+                dtypes=(torch.float32,)
+            ),
+            DecorateInfo(toleranceOverride({torch.float32: tol(atol=2e-5, rtol=3e-6)}),
+                         "TestConsistency", "test_output_match", device_type="mps"),
+        ),
+    ),
+    UnaryUfuncInfo(
+        'nn.functional.relu',
+        aten_name="relu",
+        ref=lambda a: np.where(a <= 0, 0, a),
+        supports_autograd=True,
+        supports_sparse=True,
+        supports_sparse_csr=True,
+        supports_sparse_csc=True,
+        supports_sparse_bsr=True,
+        supports_sparse_bsc=True,
+        dtypes=all_types_and(torch.half, torch.bfloat16),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.float16),
+        sample_inputs_func=sample_inputs_nn_activation_relu,
+        supports_out=False,
+        supports_fwgrad_bwgrad=True,
+        supports_forward_ad=True),
+    OpInfo('nn.functional.conv_transpose1d',
+           # `ref` for this function is the backward of the corresponding `conv*d`
+           ref=partial(conv_transpose_ref, fn=torch.nn.functional.conv_transpose1d),
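+           # Put differently: conv_transpose1d is the input-gradient of the matching conv1d
+           # (roughly, it maps (N, C_out, L_out) back to (N, C_in, L_in)), so the reference
+           # is derived from that backward pass rather than written out directly.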
+           aten_name='conv_transpose1d',
+           aliases=('conv_transpose1d',),
+           dtypes=floating_and_complex_types_and(torch.int64, torch.float16, torch.bfloat16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.chalf,
+                                                       torch.bfloat16),
+           sample_inputs_func=sample_inputs_conv_transpose1d,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           assert_jit_shape_analysis=True,
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+           decorators=(
+               DecorateInfo(
+                   toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1.3e-06), }),
+                   'TestCommon', 'test_variant_consistency_eager', device_type='cuda'),
+               DecorateInfo(
+                   toleranceOverride({torch.chalf: tol(atol=5e-2, rtol=5e-2), }),
+                   'TestCommon', 'test_complex_half_reference_testing'),
+               DecorateInfo(
+                   toleranceOverride({torch.float: tol(atol=1.5e-5, rtol=1.5e-5), }),
+                   'TestCommon', 'test_numpy_ref_mps'),
+               DecorateInfo(
+                   toleranceOverride({torch.half: tol(atol=1e-3, rtol=5e-3), }),
+                   'TestInductorOpInfo', 'test_comprehensive', device_type='cpu'),
+           ),
+           skips=(
+               # Reason for Skip: https://github.com/pytorch/pytorch/pull/79694#issuecomment-1186949486
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',
+                            dtypes=(torch.complex64,)),
+               # RuntimeError: UNSUPPORTED DTYPE: complex
+               DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness',
+                            dtypes=(torch.complex64, torch.complex128)),
+               # RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at
+               # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":104, please report a bug to PyTorch.
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',
+                            dtypes=(torch.float,)),
+               # RuntimeError: "slow_conv2d_cpu_grad_input" not implemented for 'Long'
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref',
+                            dtypes=(torch.int64,)),
+           ),
+           supports_out=False,),
+    OpInfo('nn.functional.conv_transpose2d',
+           aten_name='conv_transpose2d',
+           aliases=('conv_transpose2d',),
+           # `ref` for this function is the backward of the corresponding `conv*d`
+           ref=partial(conv_transpose_ref, fn=torch.nn.functional.conv_transpose2d),
+           dtypes=floating_and_complex_types_and(torch.int64, torch.float16, torch.bfloat16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.chalf,
+                                                       torch.bfloat16),
+           sample_inputs_func=sample_inputs_conv_transpose2d,
+           # Runs very slowly on slow-gradcheck for complex.
+           gradcheck_fast_mode=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           assert_jit_shape_analysis=True,
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+           decorators=[
+               DecorateInfo(
+                   toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1.3e-06), }),
+                   'TestCommon', 'test_variant_consistency_eager', device_type='cuda'),
+               DecorateInfo(
+                   toleranceOverride({torch.float32: tol(atol=2e-05, rtol=5e-05), }),
+                   'TestCommon', 'test_noncontiguous_samples', device_type='cuda'),
+               DecorateInfo(
+                   toleranceOverride({torch.chalf: tol(atol=8e-2, rtol=8e-2), }),
+                   'TestCommon', 'test_complex_half_reference_testing'),
+               DecorateInfo(
+                   toleranceOverride({torch.half: tol(atol=1e-3, rtol=4e-3), }),
+                   'TestInductorOpInfo', 'test_comprehensive', device_type='cpu')],
+           skips=(
+               # RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at
+               # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":104, please report a bug to PyTorch.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+               # RuntimeError: UNSUPPORTED DTYPE: complex
+               DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness',
+                            dtypes=(torch.complex64, torch.complex128)),
+               # RuntimeError: "slow_conv2d_cpu_grad_input" not implemented for 'Long'
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref',
+                            dtypes=(torch.int64,)),
+               # Reference: https://github.com/pytorch/pytorch/issues/86356
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref',
+                            dtypes=(torch.double, torch.cdouble)),
+               DecorateInfo(unittest.skip("Unsupported on MPS for now"), 'TestCommon', 'test_numpy_ref_mps'),
+               # AssertionError: None mismatch: torch.complex64 is not None
+               DecorateInfo(unittest.expectedFailure, 'TestDtypeCustomRules', 'test_custom_rules',
+                            dtypes=(torch.complex64, torch.complex128)),
+           ),
+           supports_out=False,),
+    OpInfo('nn.functional.conv_transpose3d',
+           aten_name='conv_transpose3d',
+           aliases=('conv_transpose3d',),
+           # `ref` for this function is the backward of the corresponding `conv*d`
+           ref=partial(conv_transpose_ref, fn=torch.nn.functional.conv_transpose3d),
+           dtypes=floating_and_complex_types_and(torch.int64, torch.float16, torch.bfloat16),
+           dtypesIfCUDA=floating_and_complex_types_and(
+               torch.float16, torch.chalf, torch.bfloat16),
+           sample_inputs_func=sample_inputs_conv_transpose3d,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           assert_jit_shape_analysis=True,
+           # Runs very slowly on slow-gradcheck - alternatively reduce input sizes
+           gradcheck_fast_mode=True,
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+           decorators=[
+               DecorateInfo(
+                   toleranceOverride({torch.float16: tol(atol=5e-2, rtol=5e-2), }),
+                   'TestInductorOpInfo', 'test_comprehensive', device_type='cuda'),
+               DecorateInfo(
+                   toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1.3e-06),
+                                     torch.complex64: tol(atol=1.3e-04, rtol=1.3e-05)}),
+                   'TestCommon', 'test_variant_consistency_eager', device_type='cuda'),
+               DecorateInfo(
+                   toleranceOverride({torch.float32: tol(atol=2e-04, rtol=2e-04), }),
+                   'TestCompositeCompliance', 'test_operator', device_type='cuda'),
+               DecorateInfo(
+                   toleranceOverride({torch.float32: tol(atol=1.3e-04, rtol=1.3e-06),
+                                     torch.complex64: tol(atol=1.3e-04, rtol=1.3e-05)}),
+                   'TestCommon', 'test_noncontiguous_samples', device_type='cuda'),
+               DecorateInfo(
+                   toleranceOverride({torch.float32: tol(atol=1e-04, rtol=2e-05), }),
+                   'TestCompositeCompliance', 'test_forward_ad', device_type='cuda',
+                   active_if=TEST_CUDNN),
+               DecorateInfo(
+                   toleranceOverride({torch.complex64: tol(atol=1e-4, rtol=1e-4)}),
+                   "TestMathBits", "test_conj_view", device_type='cuda'),
+               DecorateInfo(
+                   toleranceOverride({torch.chalf: tol(atol=9e-2, rtol=9e-2), }),
+                   'TestCommon', 'test_complex_half_reference_testing'),
+               DecorateInfo(
+                   toleranceOverride({torch.half: tol(atol=9e-3, rtol=2e-1), }),
+                   'TestInductorOpInfo', 'test_comprehensive', device_type='cpu')],
+           skips=(
+               # RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at
+               # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":104, please report a bug to PyTorch.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+               # RuntimeError: "slow_conv3d_cpu_grad_input" not implemented for 'Long'
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref',
+                            dtypes=(torch.int64,)),
+               # Reference: https://github.com/pytorch/pytorch/issues/86356
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref',
+                            dtypes=(torch.double, torch.cdouble)),
+               DecorateInfo(unittest.skip("Unsupported on MPS for now"), 'TestCommon', 'test_numpy_ref_mps'),
+               # RuntimeError: UNSUPPORTED DTYPE: complex
+               DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness',
+                            dtypes=(torch.complex64, torch.complex128)),
+               DecorateInfo(unittest.skip('Skipped for ROCm!'), 'TestCommon', 'test_complex_half_reference_testing',
+                            dtypes=[torch.complex32], active_if=TEST_WITH_ROCM),
+           ),
+           supports_out=False,),
+    OpInfo('nn.functional.conv1d',
+           aliases=('conv1d',),
+           aten_name='conv1d',
+           dtypes=floating_and_complex_types_and(torch.int64, torch.float16, torch.bfloat16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.chalf,
+                                                       torch.bfloat16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+           sample_inputs_func=sample_inputs_conv1d,
+           error_inputs_func=error_inputs_conv1d,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           assert_jit_shape_analysis=True,
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+           decorators=(
+               DecorateInfo(
+                   toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=5e-2)}),
+                   'TestCommon', 'test_complex_half_reference_testing'
+               ),
+               DecorateInfo(
+                   toleranceOverride({torch.float16: tol(atol=2e-3, rtol=1e-3)}),
+                   'TestInductorOpInfo', 'test_comprehensive', device_type='cuda',
+               ),
+           ),
+           skips=(
+               # RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at
+               # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":103, please report a bug to PyTorch.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+               # Ref: https://github.com/pytorch/pytorch/issues/75309
+               # AssertionError: None mismatch: torch.complex128 is not None
+               DecorateInfo(unittest.expectedFailure, 'TestDtypeCustomRules',
+                            'test_custom_rules', dtypes=(torch.complex64, torch.complex128)),
+               # Ref: https://github.com/pytorch/pytorch/issues/75309
+               # RuntimeError: UNSUPPORTED DTYPE: complex
+               DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo',
+                            'test_nnc_correctness', dtypes=(torch.complex64, torch.complex128)),
+           ),
+           supports_expanded_weight=True,
+           supports_out=False,),
+    OpInfo('nn.functional.conv2d',
+           aliases=('conv2d',),
+           aten_name='conv2d',
+           dtypes=floating_and_complex_types_and(torch.int64, torch.float16, torch.bfloat16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.chalf,
+                                                       torch.bfloat16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+           sample_inputs_func=sample_inputs_conv2d,
+           error_inputs_func=error_inputs_conv2d,
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+           gradcheck_fast_mode=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           assert_jit_shape_analysis=True,
+           decorators=(
+               DecorateInfo(
+                   toleranceOverride({torch.chalf: tol(atol=6e-2, rtol=5e-2)}),
+                   'TestCommon', 'test_complex_half_reference_testing',
+               ),
+               DecorateInfo(
+                   toleranceOverride({torch.float16: tol(atol=5e-3, rtol=1e-3)}),
+                   'TestInductorOpInfo', 'test_comprehensive',
+               ),
+           ),
+           skips=(
+               # RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at
+               # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":103, please report a bug to PyTorch.
+               DecorateInfo(unittest.skip("Works on some configs!"), 'TestJit', 'test_variant_consistency_jit'),
+               # Ref: https://github.com/pytorch/pytorch/issues/75309
+               # AssertionError: None mismatch: torch.complex128 is not None
+               DecorateInfo(unittest.expectedFailure, 'TestDtypeCustomRules',
+                            'test_custom_rules', dtypes=(torch.complex64, torch.complex128)),
+               # RuntimeError: UNSUPPORTED DTYPE: complex
+               DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo',
+                            'test_nnc_correctness', dtypes=(torch.complex64, torch.complex128)),
+           ),
+           supports_expanded_weight=True,
+           supports_out=False,),
+    OpInfo('nn.functional.conv3d',
+           aliases=('conv3d',),
+           aten_name='conv3d',
+           dtypes=floating_and_complex_types_and(torch.int64, torch.bfloat16, torch.float16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.chalf, torch.bfloat16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+           sample_inputs_func=sample_inputs_conv3d,
+           error_inputs_func=error_inputs_conv3d,
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+           gradcheck_fast_mode=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           decorators=(
+               DecorateInfo(
+                   toleranceOverride({torch.chalf: tol(atol=6e-2, rtol=5e-2)}),
+                   'TestCommon', 'test_complex_half_reference_testing',
+               ),
+               # TF32: reduced-precision float32 matmul/conv on newer CUDA GPUs needs looser tolerances
+               DecorateInfo(
+                   toleranceOverride({torch.float32: tol(atol=5e-3, rtol=1e-3),
+                                     torch.complex64: tol(atol=5e-3, rtol=1e-3)}),
+                   'TestCommon', 'test_noncontiguous_samples',
+               ),
+               DecorateInfo(
+                   toleranceOverride({torch.complex64: tol(atol=2e-5, rtol=3e-6)}),
+                   'TestCommon', 'test_variant_consistency_eager',
+               ),
+               DecorateInfo(
+                   toleranceOverride({torch.complex64: tol(atol=5e-5, rtol=5e-6)}),
+                   'TestMathBits', 'test_conj_view',
+               ),
+               DecorateInfo(
+                   toleranceOverride({torch.float32: tol(atol=5e-5, rtol=5e-6)}),
+                   'TestOperators', 'test_vjpvmap',
+               ),
+               DecorateInfo(
+                   toleranceOverride({torch.float16: tol(atol=5e-3, rtol=1e-3)}),
+                   'TestInductorOpInfo', 'test_comprehensive',
+               ),
+           ),
+           skips=(
+               # RuntimeError: !lhs.isAliasOf(rhs) INTERNAL ASSERT FAILED at
+               # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":103, please report a bug to PyTorch.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+               # RuntimeError: UNSUPPORTED DTYPE: complex
+               DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo',
+                            'test_nnc_correctness', dtypes=(torch.complex64, torch.complex128)),
+               # AssertionError: Tensor-likes are not close!
+               # break slow tests
+               DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_compare_cpu'),
+           ),
+           supports_expanded_weight=True,
+           supports_out=False,),
+    OpInfo('nn.functional.group_norm',
+           aten_name='group_norm',
+           aliases=('group_norm',),
+           ref=reference_group_norm,
+           dtypes=floating_types_and(torch.float16, torch.bfloat16),
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           error_inputs_func=error_inputs_group_norm,
+           decorators=[
+               # RuntimeError: Cannot insert a Tensor that requires grad as a constant.
+               # Consider making it a parameter or input, or detaching the gradient
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
+               DecorateInfo(
+                   toleranceOverride({torch.float32: tol(atol=5e-05, rtol=3e-03)}),
+                   "TestDecomp",
+                   "test_comprehensive",
+                   device_type="cpu"
+               ),
+           ],
+           sample_inputs_func=sample_inputs_group_norm,
+           reference_inputs_func=reference_inputs_group_norm,
+           supports_expanded_weight=True,),
+    OpInfo('nn.functional.instance_norm',
+           # no ref because instance_norm will often have numerical instability (large numbers or nan)
+           dtypes=floating_types_and(torch.float16, torch.bfloat16),
+           dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           allow_cow_input_materialize_forward=['running_mean', 'running_var'],
+           decorators=[
+               # RuntimeError: Cannot insert a Tensor that requires grad as a constant.
+               # Consider making it a parameter or input, or detaching the gradient
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
+           ],
+           sample_inputs_func=sample_inputs_instance_norm,
+           supports_expanded_weight=True,),
+    OpInfo('nn.functional.layer_norm',
+           aten_name='layer_norm',
+           aten_backward_name='layer_norm_backward',
+           aliases=('layer_norm',),
+           ref=reference_layer_norm,
+           dtypes=floating_types_and(torch.half, torch.bfloat16),
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           assert_jit_shape_analysis=True,
+           decorators=[
+               DecorateInfo(
+                   toleranceOverride({torch.float32: tol(atol=1e-05, rtol=1e-03)}),
+                   'TestCommon', 'test_numpy_refs'
+               ),
+               DecorateInfo(unittest.skip("Bug in MPS backend!"), 'TestCommon', 'test_numpy_ref_mps'),
+           ],
+           sample_inputs_func=sample_inputs_layer_norm,
+           supports_expanded_weight=True,),
+    OpInfo('nn.functional.rms_norm',
+           aten_name='rms_norm',
+           aliases=('rms_norm',),
+           ref=reference_rms_norm,
+           dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           sample_inputs_func=sample_inputs_rms_norm,
+           error_inputs_func=error_inputs_rms_norm,),
+    OpInfo('nn.functional.local_response_norm',
+           dtypes=floating_types_and(torch.int64, torch.float16, torch.bfloat16),
+           dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           decorators=[
+               # RuntimeError: falseINTERNAL ASSERT FAILED at
+               # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, please report a bug to PyTorch.
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
+           ],
+           sample_inputs_func=sample_inputs_local_response_norm,),
+    OpInfo('constant_pad_nd',
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half),
+           sample_inputs_func=sample_inputs_constant_pad_nd,
+           supports_out=False,
+           skips=(
+               # bool can't be passed to Scalar arguments in JIT tracer because
+               # BoolType is not a subtype of ScalarType.
+               DecorateInfo(
+                   unittest.expectedFailure, 'TestNNCOpInfo',
+                   'test_nnc_correctness', dtypes=(torch.bool,)),
+           )),
+    OpInfo('nn.functional.pad',
+           variant_test_name='constant',
+           aten_name='constant_pad_nd',
+           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+           gradcheck_fast_mode=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half),
+           sample_inputs_func=partial(sample_inputs_nn_pad, mode='constant'),
+           supports_out=False),
+    OpInfo('nn.functional.pad',
+           variant_test_name='reflect',
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           dtypes=all_types_and_complex_and(torch.bfloat16, torch.half),
+           sample_inputs_func=partial(sample_inputs_nn_pad, mode='reflect'),
+           skips=(
+               # Doesn't have a corresponding aten operator.
+               # RuntimeError: falseINTERNAL ASSERT FAILED at
+               # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, please report a bug to PyTorch.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
+           ),
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+           supports_out=False),
+    OpInfo('nn.functional.pad',
+           variant_test_name='replicate',
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
+           sample_inputs_func=partial(sample_inputs_nn_pad, mode='replicate'),
+           skips=(
+               # Doesn't have a corresponding aten operator.
+               # RuntimeError: falseINTERNAL ASSERT FAILED at
+               # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, please report a bug to PyTorch.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
+           ),
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+           supports_out=False),
+    OpInfo('nn.functional.pad',
+           variant_test_name='replicate_negative',
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
+           sample_inputs_func=sample_inputs_nn_pad_replicate_negative,
+           skips=(
+               # Doesn't have a corresponding aten operator.
+               # RuntimeError: falseINTERNAL ASSERT FAILED at
+               # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, please report a bug to PyTorch.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
+               # Some negative padding cases cause a segfault on MPS
+               DecorateInfo(unittest.skip("Not fully supported on MPS"), 'TestConsistency'),
+           ),
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+           supports_out=False),
+    OpInfo('nn.functional.pad',
+           variant_test_name='circular',
+           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half),
+           sample_inputs_func=partial(sample_inputs_nn_pad, mode='circular'),
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           check_batched_grad=False,
+           # https://github.com/pytorch/pytorch/issues/66357
+           check_batched_forward_grad=False,
+           skips=(
+               # Doesn't have a corresponding aten operator.
+               # RuntimeError: falseINTERNAL ASSERT FAILED at
+               # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, please report a bug to PyTorch.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
+               # The difference from the reference is larger with the new_empty_strided.default decomposition than with the original op on output 0
+               DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), 'TestDecomp', 'test_comprehensive'),
+           ),
+           supports_out=False),
+    OpInfo('nn.functional.hardswish',
+           aten_name="hardswish",
+           aten_backward_name='hardswish_backward',
+           supports_autograd=True,
+           assert_autodiffed=True,
+           sample_inputs_func=sample_inputs_hardswish,
+           dtypes=floating_types_and(torch.bfloat16, torch.half),
+           supports_gradgrad=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           supports_out=False,
+           autodiff_nonfusible_nodes=["aten::hardswish"]),
+    OpInfo('nn.functional.unfold',
+           aten_name='im2col',
+           dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.bool),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.bool),
+           sample_inputs_func=sample_inputs_nn_unfold,
+           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+           gradcheck_fast_mode=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           supports_out=False,
+           skips=(
+               # NOTE: this failure may not reproduce consistently on different systems
+               # false INTERNAL ASSERT FAILED at "...torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185
+               DecorateInfo(unittest.skip("Internal assert failed!"), 'TestJit', 'test_variant_consistency_jit'),
+               # Compiler issue on ROCm. Regression started in ROCm 6.4.
+               DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_non_standard_bool_values',
+                            dtypes=[torch.bool], active_if=TEST_WITH_ROCM),
+           )),
+    OpInfo('nn.functional.interpolate',
+           aten_name="interpolate",
+           variant_test_name='nearest',
+           supports_autograd=True,
+           supports_fwgrad_bwgrad=True,
+           supports_forward_ad=True,
+           dtypes=floating_types_and(torch.uint8, torch.half, torch.bfloat16),
+           sample_inputs_func=partial(sample_inputs_interpolate, 'nearest'),
+           skips=(
+               # RuntimeError: false
+               # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185,
+               # please report a bug to PyTorch.
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+           ),
+           supports_out=False),
+    OpInfo('nn.functional.interpolate',
+           aten_name="interpolate",
+           variant_test_name='nearest-exact',
+           supports_autograd=True,
+           supports_fwgrad_bwgrad=True,
+           supports_forward_ad=True,
+           dtypes=floating_types_and(torch.half, torch.bfloat16, torch.uint8),
+           sample_inputs_func=partial(sample_inputs_interpolate, 'nearest-exact'),
+           skips=(
+               # RuntimeError: false
+               # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185,
+               # please report a bug to PyTorch.
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+               # RuntimeError: aten::_upsample_nearest_exact*d hit the vmap fallback which is currently disabled
+               DecorateInfo(unittest.expectedFailure, 'TestOperators', 'test_vmapjvpall_has_batch_rule'),
+               DecorateInfo(unittest.expectedFailure, 'TestOperators', 'test_vmapvjp_has_batch_rule'),
+               DecorateInfo(unittest.expectedFailure, 'TestVmapOperatorsOpInfo', 'test_op_has_batch_rule'),
+           ),
+           supports_out=False),
+    OpInfo('nn.functional.interpolate',
+           aten_name="interpolate",
+           variant_test_name='linear',
+           supports_autograd=True,
+           supports_fwgrad_bwgrad=True,
+           supports_forward_ad=True,
+           dtypes=floating_types_and(torch.half, torch.bfloat16),
+           sample_inputs_func=partial(sample_inputs_interpolate, 'linear'),
+           skips=(
+               # RuntimeError: false
+               # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185,
+               # please report a bug to PyTorch.
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+           ),
+           supports_out=False),
+    OpInfo('nn.functional.interpolate',
+           aten_name="interpolate",
+           variant_test_name='bilinear',
+           supports_fwgrad_bwgrad=True,
+           supports_autograd=True,
+           supports_forward_ad=True,
+           dtypes=floating_types_and(torch.uint8, torch.half, torch.bfloat16),
+           dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+           sample_inputs_func=partial(sample_inputs_interpolate, 'bilinear'),
+           reference_inputs_func=partial(reference_inputs_interpolate, 'bilinear'),
+           skips=(
+               # RuntimeError: false
+               # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185,
+               # please report a bug to PyTorch.
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+           ),
+           supports_out=False),
+    OpInfo('nn.functional.interpolate',
+           aten_name="interpolate",
+           variant_test_name='bicubic',
+           supports_autograd=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           dtypes=floating_types_and(torch.uint8, torch.half, torch.bfloat16),
+           dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
+           sample_inputs_func=partial(sample_inputs_interpolate, 'bicubic'),
+           reference_inputs_func=partial(reference_inputs_interpolate, 'bicubic'),
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+           skips=(
+               # RuntimeError: false
+               # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185,
+               # please report a bug to PyTorch.
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+           ),
+           supports_out=False),
+    OpInfo('nn.functional.interpolate',
+           aten_name="interpolate",
+           variant_test_name='trilinear',
+           supports_autograd=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           dtypes=floating_types_and(torch.half, torch.bfloat16),
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+           sample_inputs_func=partial(sample_inputs_interpolate, 'trilinear'),
+           skips=(
+               # RuntimeError: false
+               # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185,
+               # please report a bug to PyTorch.
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+           ),
+           supports_out=False),
+    OpInfo('nn.functional.interpolate',
+           aten_name="interpolate",
+           variant_test_name='area',
+           supports_autograd=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           dtypes=floating_types_and(torch.half, torch.bfloat16),
+           dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
+           sample_inputs_func=partial(sample_inputs_interpolate, 'area'),
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+           skips=(
+               # RuntimeError: false
+               # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185,
+               # please report a bug to PyTorch.
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+           ),
+           supports_out=False),
+    OpInfo('nn.functional.upsample_bilinear',
+           supports_autograd=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           dtypes=floating_types_and(torch.uint8, torch.half, torch.bfloat16),
+           dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+           sample_inputs_func=partial(sample_inputs_upsample, 'bilinear'),
+           reference_inputs_func=partial(reference_inputs_upsample, 'bilinear'),
+           skips=(
+               # RuntimeError: false
+               # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185,
+               # please report a bug to PyTorch.
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+           ),
+           supports_out=False),
+    OpInfo('_upsample_bilinear2d_aa',
+           op=torch.ops.aten._upsample_bilinear2d_aa,
+           aten_name='_upsample_bilinear2d_aa',
+           supports_autograd=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           dtypes=floating_types_and(torch.uint8),
+           dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+           sample_inputs_func=partial(sample_inputs_upsample_aa, 'bilinear'),
+           supports_out=False,
+           skips=(
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+               DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'),
+               DecorateInfo(unittest.expectedFailure, 'TestInductorOpInfo', 'test_comprehensive'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+           )),
+    OpInfo(
+        "nn.functional.soft_margin_loss",
+        dtypes=floating_types_and(torch.half, torch.bfloat16),
+        supports_out=False,
+        supports_forward_ad=True,
+        # doesn't support grad on target
+        sample_inputs_func=partial(sample_inputs_loss, rhs_requires_grad=False),
+        error_inputs_func=error_inputs_soft_margin_loss,
+    ),
+    OpInfo('nn.functional.upsample_nearest',
+           supports_autograd=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           dtypes=floating_types_and(torch.uint8, torch.half, torch.bfloat16),
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+           sample_inputs_func=partial(sample_inputs_upsample, 'nearest'),
+           skips=(
+               # RuntimeError: false
+               # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185,
+               # please report a bug to PyTorch.
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+           ),
+           supports_out=False),
+    OpInfo(
+        "nn.functional.margin_ranking_loss",
+        dtypes=all_types_and(torch.half, torch.bfloat16),
+        supports_out=False,
+        sample_inputs_func=sample_inputs_margin_ranking_loss,
+        error_inputs_func=error_inputs_margin_ranking_loss,
+        reference_inputs_func=reference_inputs_margin_ranking_loss,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True),
+    OpInfo(
+        "nn.functional.multi_margin_loss",
+        dtypes=floating_types(),
+        dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16),
+        supports_out=False,
+        supports_gradgrad=False,
+        sample_inputs_func=sample_inputs_multi_margin_loss,
+        reference_inputs_func=reference_inputs_multi_margin_loss,
+        error_inputs_func=error_inputs_multi_margin_loss,
+        decorators=(
+            DecorateInfo(
+                toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-4)}),
+                "TestJit",
+                "test_variant_consistency_jit",
+            ),
+        ),
+    ),
+    OpInfo(
+        "nn.functional.multilabel_margin_loss",
+        dtypes=floating_types(),
+        dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16),
+        supports_out=False,
+        supports_gradgrad=False,
+        sample_inputs_func=sample_inputs_multilabel_margin_loss,
+        reference_inputs_func=reference_inputs_multilabel_margin_loss,
+        error_inputs_func=error_inputs_multilabel_margin_loss,
+    ),
+    OpInfo('nn.functional.leaky_relu',
+           aliases=None,
+           aten_name="leaky_relu",
+           aten_backward_name='leaky_relu_backward',
+           sample_inputs_func=sample_inputs_leaky_relu,
+           dtypes=floating_types_and(torch.bfloat16, torch.float16),
+           inplace_variant=lambda x, negative_slope=0.01:
+               torch.nn.functional.leaky_relu(x, negative_slope, inplace=True),
+           supports_autograd=True,
+           assert_autodiffed=True,
+           supports_gradgrad=True,
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           autodiff_nonfusible_nodes=["aten::leaky_relu"]),
+    OpInfo(
+        "nn.functional.multilabel_soft_margin_loss",
+        supports_out=False,
+        dtypes=floating_types_and(torch.half, torch.bfloat16),
+        sample_inputs_func=sample_inputs_multilabel_soft_margin_loss,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        decorators=(
+            DecorateInfo(
+                toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-4)}),
+                "TestJit",
+                "test_variant_consistency_jit",
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float16: tol(atol=4e-3, rtol=1.3e-3)}),
+                "TestInductorOpInfo",
+                "test_comprehensive",
+                device_type="cuda"
+            ),
+        ),
+        skips=(
+            # AssertionError: False is not true : Scalars failed to compare as equal! 0 != 4096
+            # __main__.TestJitCUDA.test_variant_consistency_jit_nn_functional_multilabel_soft_margin_loss_cuda_float32
+            # leaked 4096 bytes CUDA memory on device 0
+            DecorateInfo(
+                # Skip instead of expectedFailure because this fails
+                # locally for me but passes in CI.
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                device_type="cuda",
+            ),
+        ),
+    ),
+    OpInfo('nn.functional.avg_pool2d',
+           aten_name='avg_pool2d',
+           supports_autograd=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           dtypes=floating_types_and(torch.int64, torch.float16, torch.bfloat16),
+           dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+           error_inputs_func=error_inputs_avg_pool2d,
+           sample_inputs_func=sample_inputs_avgpool2d,
+           skips=(
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type='cuda'),
+           )),
+    OpInfo('nn.functional.fractional_max_pool2d',
+           supports_autograd=True,
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
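+           # fractional_max_pool2d samples its pooling regions randomly, so the op is
+           # wrapped with wrapper_set_seed to reseed the RNG before each call and keep
+           # repeated invocations reproducible.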
+           op=lambda input, *args, **kwargs:
+               wrapper_set_seed(torch.nn.functional.fractional_max_pool2d, input, *args, **kwargs),
+           # vmap does not support random operations
+           check_batched_forward_grad=False,
+           dtypes=floating_types_and(torch.bfloat16, torch.float16),
+           test_neg_view=False,
+           sample_inputs_func=sample_inputs_fractional_max_pool2d,
+           decorators=(
+               # FIXME: AssertionError: False is not true : Tensors failed to compare as equal!
+               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+               # RuntimeError: input->type()->kind() == TypeKind::OptionalType
+               # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit')),
+           skips=(
+               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),)),
+    OpInfo('nn.functional.fractional_max_pool3d',
+           supports_autograd=True,
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           op=lambda input, *args, **kwargs:
+               wrapper_set_seed(torch.nn.functional.fractional_max_pool3d, input, *args, **kwargs),
+           # vmap does not support random operations
+           check_batched_forward_grad=False,
+           dtypes=floating_types_and(torch.bfloat16, torch.float16),
+           test_neg_view=False,
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+           sample_inputs_func=sample_inputs_fractional_max_pool3d,
+           decorators=(
+               # FIXME: both derivatives are implemented incorrectly
+               # https://github.com/pytorch/pytorch/issues/69322
+               # FIXME: AssertionError: False is not true : Tensors failed to compare as equal!
+               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+               # RuntimeError: input->type()->kind() == TypeKind::OptionalType
+               # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit')),
+           skips=(
+               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),)),
+    OpInfo('nn.functional.max_pool1d',
+           aten_name='max_pool1d',
+           supports_autograd=True,
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # got: Batching rule not implemented for aten::flatten.using_ints
+           check_batched_forward_grad=False,
+           # TODO: add shape checks
+           assert_jit_shape_analysis=False,
+           dtypes=floating_types_and(torch.bfloat16, torch.float16),
+           dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+           skips=(
+               # Pre-existing condition; needs to be fixed
+               DecorateInfo(unittest.skip("Works on some configs"), 'TestNNCOpInfo',
+                            'test_nnc_correctness', dtypes=(torch.bfloat16,)),
+               # RuntimeError: The tensor has a non-zero number of elements, but its data is not allocated yet.
+               # Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data()
+               # to actually allocate memory
+               DecorateInfo(unittest.skip("Skipped!"), 'TestTags', 'test_tags'),
+           ),
+           error_inputs_func=error_inputs_max_pool1d,
+           sample_inputs_func=sample_inputs_max_pool),
+    OpInfo('nn.functional.max_pool2d',
+           aten_name='max_pool2d',
+           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+           gradcheck_fast_mode=True,
+           # Vmap is not happy with non-contiguous (channels_last) inputs
+           check_batched_gradgrad=False,
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # got: Batching rule not implemented for aten::flatten.using_ints
+           check_batched_forward_grad=False,
+           assert_jit_shape_analysis=True,
+           dtypes=all_types_and(torch.float16, torch.bfloat16),
+           dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+           error_inputs_func=error_inputs_max_pool2d,
+           sample_inputs_func=sample_inputs_max_pool),
+    OpInfo('max_pool2d_with_indices_backward',
+           op=max_pool2d_backward,
+           # We've defined a custom op, so there's no corresponding aten op
+           aten_name=None,
+           method_variant=None,
+           inplace_variant=None,
+           operator_variant=None,
+           inplace_operator_variant=None,
+           check_batched_gradgrad=False,
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           check_batched_forward_grad=False,
+           assert_jit_shape_analysis=False,
+           dtypes=floating_types_and(torch.bfloat16, torch.float16),
+           sample_inputs_func=sample_inputs_max_pool,
+           skips=(
+               # We've defined a custom op here, and we don't handle the case where we receive an out kwarg
+               DecorateInfo(unittest.skip("Skipped!"), "TestCommon", "test_out"),
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
+               # FX failed to normalize op - add the op to the op_skip list.
+               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+               # object has no attribute max_pool2d_with_indices_backward (it's not available on the torch namespace, so this failure is expected)
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit')
+           )),
+    OpInfo('nn.functional.max_pool3d',
+           aten_name='max_pool3d',
+           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+           gradcheck_fast_mode=True,
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # got: Batching rule not implemented for aten::flatten.using_ints
+           check_batched_forward_grad=False,
+           # TODO: add shape checks
+           assert_jit_shape_analysis=False,
+           dtypes=all_types_and(torch.bfloat16, torch.float16),
+           dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+           # TODO: investigate nondeterminism
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+           error_inputs_func=error_inputs_max_pool3d,
+           sample_inputs_func=sample_inputs_max_pool),
+    OpInfo('nn.functional.max_unpool1d',
+           aten_name='max_unpool1d',
+           supports_autograd=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           supports_out=False,
+           assert_jit_shape_analysis=False,
+           dtypes=floating_types_and(torch.float16, torch.bfloat16),
+           sample_inputs_func=sample_inputs_max_unpool,
+           skips=(
+               # Gradients are tested in `variant_test_name=grad` below.
+               # We skip these tests because the backward pass, which uses gather, is
+               # non-deterministic when there are writes into the same memory location:
+               # if several indices point to the same memory, gradcheck is oblivious to
+               # this and cannot perturb them all at once
+               # (see sample_inputs_max_unpool_grad for more details).
+               DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD',
+                            active_if=(not IS_MACOS)),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_forward_ad',
+                            device_type='cpu'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_quick_core_backward'),
+           )),
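+    # The `grad` variant below exercises the gradients directly, using
+    # sample_inputs_max_unpool_grad to sidestep the non-determinism described above.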
+    OpInfo('nn.functional.max_unpool1d',
+           variant_test_name='grad',
+           aten_name='max_unpool1d',
+           supports_autograd=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           supports_out=False,
+           assert_jit_shape_analysis=False,
+           dtypes=floating_types_and(torch.float16, torch.bfloat16),
+           sample_inputs_func=sample_inputs_max_unpool_grad),
+    OpInfo('nn.functional.max_unpool2d',
+           aten_name='max_unpool2d',
+           supports_autograd=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           supports_out=False,
+           assert_jit_shape_analysis=False,
+           dtypes=floating_types_and(torch.float16, torch.bfloat16),
+           sample_inputs_func=sample_inputs_max_unpool,
+           skips=(
+               # Gradients are tested in `variant_test_name=grad` below.
+               # We skip these tests because the backward pass, which uses gather, is
+               # non-deterministic when there are writes into the same memory location:
+               # if several indices point to the same memory, gradcheck is oblivious to
+               # this and cannot perturb them all at once
+               # (see sample_inputs_max_unpool_grad for more details).
+               DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD',
+                            active_if=(not IS_MACOS)),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_forward_ad'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_quick_core_backward'),
+           )),
+    OpInfo('nn.functional.max_unpool2d',
+           variant_test_name='grad',
+           aten_name='max_unpool2d',
+           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+           gradcheck_fast_mode=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # Vmap is not happy with non-contiguous (channels_last) inputs
+           check_batched_grad=False,
+           supports_out=False,
+           assert_jit_shape_analysis=False,
+           dtypes=floating_types_and(torch.float16, torch.bfloat16),
+           sample_inputs_func=sample_inputs_max_unpool_grad),
+    OpInfo('nn.functional.max_unpool3d',
+           aten_name='max_unpool3d',
+           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+           gradcheck_fast_mode=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           supports_out=False,
+           assert_jit_shape_analysis=False,
+           dtypes=floating_types_and(torch.float16, torch.bfloat16),
+           sample_inputs_func=sample_inputs_max_unpool,
+           skips=(
+               # Gradients are tested in `variant_test_name=grad` below.
+               # We skip these tests because the backward pass, which uses gather, is
+               # non-deterministic when there are writes into the same memory location:
+               # if several indices point to the same memory, gradcheck is oblivious to
+               # this and cannot perturb them all at once
+               # (see sample_inputs_max_unpool_grad for more details).
+               DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD',
+                            active_if=(not IS_MACOS)),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_forward_ad'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_quick_core_backward'),
+           )),
+    OpInfo('nn.functional.max_unpool3d',
+           variant_test_name='grad',
+           aten_name='max_unpool3d',
+           supports_autograd=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           supports_out=False,
+           assert_jit_shape_analysis=False,
+           dtypes=floating_types_and(torch.float16, torch.bfloat16),
+           sample_inputs_func=sample_inputs_max_unpool_grad),
+    OpInfo('nn.functional.linear',
+           aten_name='linear',
+           supports_autograd=True,
+           supports_gradgrad=True,
+           sample_inputs_func=sample_inputs_linear,
+           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+           dtypesIfROCM=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+           backward_dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+           # linear calls mm under the hood which is nondeterministic on CUDA
+           # https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html#torch.use_deterministic_algorithms
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # See https://github.com/pytorch/pytorch/issues/66357
+           check_batched_forward_grad=False,
+           supports_expanded_weight=True,
+           decorators=(
+               # Strides are not the same!
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
+           )),
+    OpInfo('nn.functional.bilinear',
+           aten_name='bilinear',
+           supports_autograd=True,
+           sample_inputs_func=sample_inputs_bilinear,
+           dtypes=all_types_and(torch.float16, torch.bfloat16),
+           dtypesIfCUDA=floating_types_and(torch.float16,
+                                           *[torch.bfloat16] if SM53OrLater or TEST_WITH_ROCM else []),
+           decorators=(
+               DecorateInfo(toleranceOverride({torch.float16: tol(atol=2e-03, rtol=1.3e-03)}),
+                            'TestInductorOpInfo', 'test_comprehensive', device_type='cpu'),
+           ),
+           skips=(
+               # NVIDIA only guarantees bfloat16 support for bmm on SM >= 5.3
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', device_type='cuda', active_if=not SM53OrLater),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness', dtypes=(torch.bfloat16,)),
+           ),
+           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+           gradcheck_fast_mode=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           supports_out=False),
+    OpInfo('nn.functional.glu',
+           aten_name='glu',
+           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+           gradcheck_fast_mode=True,
+           sample_inputs_func=sample_inputs_glu,
+           dtypes=floating_types_and(torch.bfloat16, torch.float16),
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           supports_out=False),
+    UnaryUfuncInfo(
+        'nn.functional.elu',
+        aten_backward_name='elu_backward',
+        ref=lambda x, alpha=1.0, inplace=False:
+            np.maximum(0., x) + np.minimum(0., alpha * (np.exp(x) - 1)),
+        dtypes=floating_types_and(torch.bfloat16, torch.float16),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        supports_autograd=True,
+        assert_autodiffed=False,
+        supports_gradgrad=True,
+        supports_out=False,
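+        # sample_kwargs returns a pair of kwargs dicts: the first is forwarded to the
+        # torch op and the second to the numpy reference above.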
+        sample_kwargs=lambda device, dtype, input:
+            ({'alpha': 0.8}, {'alpha': 0.8}),
+        inplace_variant=lambda x, alpha=1.0:
+            torch.nn.functional.elu(x, alpha, inplace=True),
+        decorators=[
+            DecorateInfo(
+                toleranceOverride({
+                    torch.float16: tol(atol=1e-03, rtol=1.2e-03),
+                    torch.bfloat16: tol(atol=1e-03, rtol=1.2e-03)
+                }),
+                'TestUnaryUfuncs', device_type='cuda',
+            ), ],
+    ),
+    # Marked as a Unary function because it has some rather odd broadcasting semantics in its
+    # second argument
+    UnaryUfuncInfo(
+        'nn.functional.prelu',
+        aten_backward_name='_prelu_kernel_backward',
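+        # The reference reshapes the weight to (1, C, 1, ..., 1) so it broadcasts
+        # against the channel dimension of the input when x has more than one dim.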
+        ref=lambda x, weight:
+            np.maximum(0., x) + np.minimum(0., x) *
+            (weight if x.ndim == 1 else weight.reshape([weight.size if i == 1 else 1 for i in range(0, x.ndim)])),
+        dtypes=floating_types_and(torch.bfloat16, torch.float16),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        supports_autograd=True,
+        assert_autodiffed=False,
+        supports_gradgrad=True,
+        supports_out=False,
+        # test_reference_numerics only tests the case when the weight tensor is a scalar
+        sample_kwargs=sample_kwargs_prelu_scalar_weight,
+        error_inputs_func=error_inputs_prelu,
+        sample_inputs_func=sample_inputs_prelu,
+        reference_inputs_func=reference_inputs_prelu,
+        decorators=[
+            # RuntimeError: Cannot insert a Tensor that requires grad as a constant.
+            # Consider making it a parameter or input, or detaching the gradient
+            # https://github.com/pytorch/pytorch/issues/68752
+            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), ],
+    ),
+    UnaryUfuncInfo(
+        'nn.functional.celu',
+        ref=lambda x, alpha=1.0, inplace=False:
+            np.maximum(0., x) + np.minimum(0., alpha * (np.exp(x / alpha) - 1)),
+        dtypes=floating_types_and(torch.bfloat16, torch.float16),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        supports_autograd=True,
+        assert_autodiffed=False,
+        supports_gradgrad=True,
+        supports_out=False,
+        sample_kwargs=lambda device, dtype, input:
+            ({'alpha': 0.8}, {'alpha': 0.8}),
+        inplace_variant=lambda x, alpha=1.0:
+            torch.nn.functional.celu(x, alpha, inplace=True),
+        decorators=[
+            DecorateInfo(
+                toleranceOverride({
+                    torch.float16: tol(atol=1e-03, rtol=1.2e-03),
+                    torch.bfloat16: tol(atol=1e-03, rtol=1.2e-03)
+                }),
+                'TestUnaryUfuncs', device_type='cuda',
+            ), ],
+    ),
+    UnaryUfuncInfo(
+        'nn.functional.rrelu',
+        aten_backward_name='rrelu_with_noise_backward',
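+        # In training mode rrelu draws its negative slope from a uniform distribution,
+        # so both the functional and in-place variants are routed through
+        # wrapper_set_seed to make the noise reproducible across invocations.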
+        op=lambda input, *args, **kwargs:
+            wrapper_set_seed(torch.nn.functional.rrelu, input, *args, **kwargs),
+        inplace_variant=lambda input, *args, **kwargs:
+            wrapper_set_seed(torch.nn.functional.rrelu, input, *args, inplace=True, **kwargs),
+        dtypes=floating_types_and(torch.bfloat16),
+        dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+        gradcheck_wrapper=wrapper_set_seed,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        supports_out=False,
+        sample_kwargs=lambda device, dtype, input:
+            (dict(lower=0., upper=1., training=True), dict(lower=0., upper=1., training=True)),
+        sample_inputs_func=sample_inputs_rrelu,
+        error_inputs_func=error_inputs_rrelu,
+        decorators=(
+            DecorateInfo(
+                toleranceOverride({
+                    torch.float16: tol(atol=1e-03, rtol=1.2e-03),
+                    torch.bfloat16: tol(atol=1e-03, rtol=1.2e-03)
+                }),
+                'TestUnaryUfuncs', device_type='cuda',
+            ),),
+        skips=(
+            # lambda impl
+            DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+            # lambda impl
+            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+            # In-place operations do not play well with forward AD
+            # https://github.com/pytorch/pytorch/issues/77447
+            DecorateInfo(unittest.expectedFailure, 'TestFwdGradients',
+                         'test_inplace_forward_mode_AD'),
+            # The noise vector that's generated in these tests is not the same elementwise
+            DecorateInfo(unittest.skip("Different noise"), 'TestUnaryUfuncs', 'test_batch_vs_slicing'),
+            DecorateInfo(unittest.skip("Different noise"), 'TestUnaryUfuncs', 'test_contig_vs_every_other'),
+            DecorateInfo(unittest.skip("Different noise"), 'TestUnaryUfuncs', 'test_non_contig_expand'),
+            DecorateInfo(unittest.skip("Different noise"), 'TestUnaryUfuncs', 'test_contig_vs_transposed'),
+            DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu')),
+        skip_correctness_check_compile_vs_eager=True,
+    ),
+    UnaryUfuncInfo(
+        'nn.functional.selu',
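+        # The two literals are SELU's fixed scale (~1.0507) and alpha (~1.6733)
+        # constants from the self-normalizing networks paper.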
+        ref=lambda x, inplace=False:
+            1.0507009873554804934193349852946 * (
+                np.maximum(0., x) + np.minimum(0., 1.6732632423543772848170429916717 * (np.exp(x) - 1))
+            ),
+        dtypes=floating_types_and(torch.bfloat16, torch.float16),
+        supports_forward_ad=True,  # depends on 'elu'
+        supports_fwgrad_bwgrad=True,
+        supports_autograd=True,
+        assert_autodiffed=False,
+        supports_gradgrad=True,
+        supports_out=False,
+        inplace_variant=lambda x: torch.nn.functional.selu(x, inplace=True),
+        decorators=[
+            DecorateInfo(
+                toleranceOverride({
+                    torch.float16: tol(atol=1e-2, rtol=1.8e-2),
+                    torch.bfloat16: tol(atol=1e-2, rtol=1.8e-2)
+                }),
+                'TestUnaryUfuncs', device_type='cuda',
+            ), ],
+    ),
+    OpInfo(
+        'torch._scaled_mm',
+        sample_inputs_func=sample_inputs_scaled_mm,
+        dtypes=float8_types(),
+        dtypesIfCUDA=empty_types() + (torch.float8_e4m3fn,),
+        supports_out=True,
+        supports_forward_ad=False,
+        supports_autograd=False,
+        decorators=[skipCUDAIf(not SM89OrLater or TEST_WITH_ROCM, 'Requires CUDA SM >= 8.9')],
+        skips=(
+            # Sample inputs isn't really parametrized on dtype
+            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes'),
+            # "add_stub" not implemented for 'Float8_e4m3fn'
+            # "ufunc_add_CUDA" not implemented for 'Float8_e4m3fn'
+            # https://github.com/pytorch/pytorch/issues/107256
+            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'),
+            # "mul_cuda" not implemented for float8_e4m3fn
+            # "mul_cpu_reduced_float" not implemented for 'Float8_e4m3fn'
+            # https://github.com/pytorch/pytorch/issues/107256
+            DecorateInfo(unittest.skip("Skipped!"), 'TestSchemaCheckModeOpInfo', 'test_schema_correctness'),
+            # aten::_scaled_mm hit the vmap fallback which is currently disabled
+            DecorateInfo(unittest.skip("Skipped!"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"),
+            DecorateInfo(unittest.skip("Skipped!"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"),
+            DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness',
+                         dtypes=(torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz)),
+        )
+    ),
+    OpInfo(
+        'torch.ops.aten._safe_softmax.default',
+        dtypes=all_types_and(torch.half, torch.bfloat16, torch.bool),
+        sample_inputs_func=sample_inputs_safe_softmax,
+        assert_jit_shape_analysis=True,
+        assert_autodiffed=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        supports_out=False,
+        supports_cow_input_no_materialize_backward=False,
+        decorators=[],
+        skips=(
+            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+        ),
+    ),
+    OpInfo(
+        'nn.functional.scaled_dot_product_attention',
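+        # With dropout_p > 0 the output is stochastic, so the op is wrapped with
+        # wrapper_set_seed for reproducible sampling across invocations.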
+        op=lambda *args, **kwargs:
+               wrapper_set_seed(torch.nn.functional.scaled_dot_product_attention, *args, **kwargs),
+        sample_inputs_func=sample_inputs_scaled_dot_product_attention,
+        dtypes=floating_types_and(torch.float16, torch.bfloat16),
+        supports_out=False,
+        supports_forward_ad=False,
+        supports_fwgrad_bwgrad=True,
+        check_batched_forward_grad=False,
+        decorators=[DecorateInfo(toleranceOverride(
+            {torch.float32: tol(atol=5e-05, rtol=5e-6)}), 'TestCommon',), ],
+        skips=(
+            # When attn mask is a composite tensor this fails backward by returning a none
+            DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_backward', device_type='cuda'),
+            # This only fails on Linux Bionic with Python 3.10 and CUDA 11.6
+            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes',
+                         device_type='cuda', active_if=_get_torch_cuda_version() >= (11, 6)),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples',
+                         dtypes=(torch.float32,)),
+            # AssertionError: JIT Test does not execute any logic
+            DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+            # Forward works for dtype=float64, which is the math path
+            DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'),
+            # Not implemented for Forward AD
+            DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_fn_fwgrad_bwgrad',
+                         device_type='cpu'),
+            # Not implemented for backward derivative
+            DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad',
+                         device_type='cpu'),
+            # CPU and CUDA have inconsistencies for intermediate outputs
+            DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_meta_outplace',
+                         device_type='cpu'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace',
+                         device_type='cpu'),
+            # When changing input from Tensor to CompositeCompliantTensor, input.requires_grad() changes from true to false
+            DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_backward',
+                         device_type='cpu'),
+            # OpInfo was implemented with a lambda
+            DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+            # TODO Need to understand what this is testing and why it doesn't work
+            DecorateInfo(unittest.skip("Skipped"), 'TestDecomp', 'test_comprehensive'),
+            DecorateInfo(unittest.skip('output is non-deterministic (when dropout_p > 0)'), 'TestCommon', 'test_compare_cpu'),
+            # TODO skip this for now since we can't skip on runtime arch support
+            DecorateInfo(unittest.skip('Skipped for now: cannot skip based on runtime arch support'), 'TestInductorOpInfo', 'test_comprehensive'),
+            # skip for sm < 80
+            DecorateInfo(unittest.skip("Skipped!"), 'TestSchemaCheckModeOpInfo', 'test_schema_correctness',
+                         device_type='cuda', dtypes=(torch.bfloat16,), active_if=not SM80OrLater),
+            # FIXME
+            DecorateInfo(unittest.skip('test_cow_input does not work with efficient attention on ROCM'),
+                         'TestCompositeCompliance', 'test_cow_input',
+                         device_type='cuda', dtypes=(torch.bfloat16, torch.float16, torch.float32),
+                         active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_MEM_EFF_ATTENTION),),
+    ),
+    OpInfo(
+        'torch.ops.aten._flash_attention_forward',
+        sample_inputs_func=sample_inputs_flash_attention_forward,
+        dtypes=empty_types(),
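+        # bfloat16 is only exercised on SM80 (Ampere) or newer; older GPUs test
+        # float16 alone.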
+        dtypesIfCUDA=custom_types(torch.float16)
+        if not SM80OrLater
+        else custom_types(torch.float16, torch.bfloat16),
+        supports_out=False,
+        supports_autograd=True,
+        supports_fwgrad_bwgrad=False,
+        supports_forward_ad=False,
+        check_batched_forward_grad=False,
+        decorators=[skipCUDAIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "This platform doesn't support Flash Attention")],
+        skips=(
+            # Checking the scalar value of the philox seed and offset
+            DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator', device_type='cuda'),
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type='cuda'),
+            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', device_type='cuda'),
+            # None Mismatch Tensor
+            DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward', device_type='cuda'),
+        )
+    ),
+    OpInfo(
+        'torch.ops.aten._efficient_attention_forward',
+        sample_inputs_func=sample_inputs_efficient_attention_forward,
+        dtypes=empty_types(),
+        dtypesIfCUDA=custom_types(torch.float16, torch.float32)
+        if not SM80OrLater
+        else custom_types(torch.float16, torch.float32, torch.bfloat16),
+        supports_out=False,
+        supports_autograd=True,
+        supports_fwgrad_bwgrad=False,
+        supports_forward_ad=False,
+        check_batched_forward_grad=False,
+        # TODO: Skip because it produces a CUDA illegal memory access for some reason
+        skip_cow_input_backward=True,
+        # FIXME: mask_type == 2 (LowerRight)
+        decorators=[
+            skipCUDAIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "This platform doesn't support efficient attention"),
+            skipCUDAIf(TEST_WITH_ROCM, "Efficient attention on ROCM doesn't support custom_mask_type==2")],
+        skips=(
+            # Checking the scalar value of the philox seed and offset
+            DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator', device_type='cuda'),
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type='cuda'),
+            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', device_type='cuda'),
+            # None Mismatch Tensor
+            DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward', device_type='cuda'),
+        )
+    ),
+    UnaryUfuncInfo(
+        'nn.functional.silu',
+        aten_backward_name='silu_backward',
+        ref=lambda x, inplace=False: x / (1 + np.exp(-x)),
+        dtypes=floating_types_and(torch.bfloat16, torch.float16),
+        supports_forward_ad=True,
+        supports_autograd=True,
+        supports_fwgrad_bwgrad=True,
+        assert_autodiffed=True,
+        supports_out=False,
+        inplace_variant=lambda x: torch.nn.functional.silu(x, inplace=True),
+        decorators=[
+            DecorateInfo(
+                toleranceOverride({
+                    torch.float16: tol(atol=1e-3, rtol=1e-3),
+                    torch.bfloat16: tol(atol=1e-4, rtol=1e-4)
+                }),
+                'TestUnaryUfuncs', device_type='cuda',
+            ), ],
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal',
+                         dtypes=(torch.cfloat,), device_type='cpu'),
+        ),
+        autodiff_nonfusible_nodes=["aten::silu"],
+    ),
+    # TODO: combine this with the nn.functional.silu OpInfo when
+    # complex autodiff for silu is supported or when
+    # the forward bug is fixed
+    # Note: silu errors when given inputs that require grad
+    #   but whose dtype doesn't support grad.
+    #   This is why the dtypes list above passes test_dtypes:
+    #   it gets lucky and fails in the forward pass
+    #   because test_dtypes sets requires_grad to True.
+    #   THIS IS A BUG
+    UnaryUfuncInfo(
+        'nn.functional.silu',
+        variant_test_name='complex',
+        ref=lambda x, inplace=False:
+            x / (1 + np.exp(-x)),
+        dtypes=complex_types(),
+        dtypesIfCUDA=complex_types(),
+        supports_forward_ad=False,
+        supports_autograd=False,
+        assert_autodiffed=False,
+        supports_out=False,
+        inplace_variant=lambda x: torch.nn.functional.silu(x, inplace=True),
+        decorators=[
+            DecorateInfo(
+                toleranceOverride({
+                    torch.float16: tol(atol=1e-3, rtol=1e-3),
+                    torch.bfloat16: tol(atol=1e-4, rtol=1e-4)
+                }),
+                'TestUnaryUfuncs', device_type='cuda',
+            ), ],
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal',
+                         dtypes=(torch.cfloat,)),
+            # FIXME: intentionally misreports dtypes
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes'),
+            # FIXME: numpy reference diverges: Comparing (nan+nanj) and (-0+0j)
+            DecorateInfo(unittest.skip("Skipped!"),
+                         'TestUnaryUfuncs', 'test_reference_numerics_large',
+                         dtypes=(torch.complex64, torch.cdouble)),
+            DecorateInfo(unittest.skip("Skipped!"),
+                         'TestUnaryUfuncs', 'test_reference_numerics_small',
+                         dtypes=(torch.complex64,)),
+            DecorateInfo(unittest.skip("Skipped!"),
+                         'TestUnaryUfuncs', 'test_reference_numerics_extremal',
+                         dtypes=(torch.complex64,)))),
+    UnaryUfuncInfo(
+        'nn.functional.hardsigmoid',
+        aten_backward_name='hardsigmoid_backward',
+        ref=reference_hardsigmoid,
+        dtypes=floating_types_and(torch.bfloat16, torch.float16),
+        supports_autograd=True,
+        assert_autodiffed=False,
+        supports_gradgrad=False,
+        supports_forward_ad=True,
+        supports_out=False,
+        inplace_variant=partial(torch.nn.functional.hardsigmoid, inplace=True),
+        decorators=[
+            DecorateInfo(
+                toleranceOverride({torch.float16: tol(atol=1e-04, rtol=0.001)}), 'TestUnaryUfuncs', device_type='cuda',), ],
+        skips=[
+            # We still want to test that the first derivative works even though the
+            # second derivative isn't supported.
+            DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', "test_inplace_gradgrad"),
+            # produces 0 instead of nan on ROCM
+            DecorateInfo(unittest.expectedFailure,
+                         'TestUnaryUfuncs', "test_reference_numerics_extremal",
+                         device_type='cuda',
+                         active_if=(TEST_WITH_ROCM)), ]
+    ),
+    UnaryUfuncInfo(
+        'nn.functional.logsigmoid',
+        aten_name="log_sigmoid",
+        aten_backward_name='log_sigmoid_backward',
+        ref=reference_logsigmoid,
+        dtypes=floating_types_and(torch.half, torch.bfloat16),
+        supports_autograd=True,
+        assert_autodiffed=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        supports_gradgrad=True,
+        # autodiff_nonfusible_nodes=["aten::log_sigmoid"],
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float16: 1e-2, torch.bfloat16: 5e-3}),
+                'TestUnaryUfuncs', 'test_reference_numerics_small'),
+            DecorateInfo(
+                precisionOverride({torch.float16: 1e-2, torch.bfloat16: 5e-3}),
+                'TestUnaryUfuncs', 'test_reference_numerics_large'),
+            DecorateInfo(
+                precisionOverride({torch.float16: 1e-2, torch.bfloat16: 5e-3}),
+                'TestUnaryUfuncs', 'test_reference_numerics_extremal'),
+        ],
+        skips=(
+            # Resized a non-empty tensor but did not warn about it.
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type='cpu'),
+        ),
+    ),
+    UnaryUfuncInfo(
+        'nn.functional.mish',
+        aten_backward_name='mish_backward',
+        ref=lambda x: x * np.tanh(reference_softplus(x)),
+        dtypes=floating_types_and(torch.bfloat16, torch.float16),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        supports_autograd=True,
+        assert_autodiffed=False,
+        supports_gradgrad=True,
+        supports_out=False,
+        inplace_variant=partial(torch.nn.functional.mish, inplace=True),
+        decorators=[
+            DecorateInfo(
+                toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-03)}), 'TestUnaryUfuncs',), ],
+    ),
+    UnaryUfuncInfo(
+        'nn.functional.softsign',
+        ref=lambda x: x / (np.abs(x) + 1),
+        dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+        dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16, torch.bool),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        supports_autograd=True,
+        assert_autodiffed=False,
+        supports_gradgrad=True,
+        supports_out=False,
+        decorators=[
+            DecorateInfo(
+                toleranceOverride({torch.float16: tol(atol=1e-03, rtol=1.3e-04)}), 'TestUnaryUfuncs',), ],
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
+                         dtypes=(torch.int, torch.int8)),),
+    ),
+    UnaryUfuncInfo(
+        'nn.functional.tanhshrink',
+        ref=lambda x: x - np.tanh(x),
+        dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        supports_autograd=True,
+        assert_autodiffed=False,
+        supports_gradgrad=True,
+        supports_out=False,
+        decorators=[
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal',
+                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
+            DecorateInfo(
+                toleranceOverride({torch.bfloat16: tol(atol=1e-02, rtol=1.6e-02)}), 'TestUnaryUfuncs',),
+            DecorateInfo(toleranceOverride({torch.complex64: tol(atol=6e-04, rtol=1e-05),
+                                            torch.bfloat16: tol(atol=1e-02, rtol=1.6e-02)}),
+                         'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cuda'),
+        ],
+        skips=(
+            # in each case, PyTorch will produce a nan while NumPy will not
+            DecorateInfo(unittest.skip("Fails on some jobs, works on others!"),
+                         'TestUnaryUfuncs', "test_reference_numerics_large",
+                         dtypes=(torch.complex64, torch.complex128), active_if=(IS_MACOS)),
+            DecorateInfo(unittest.skip("Fails on some jobs, works on others!"),
+                         'TestUnaryUfuncs', "test_reference_numerics_extremal",
+                         dtypes=(torch.complex64, torch.complex128), device_type='cpu',
+                         active_if=(IS_MACOS or IS_WINDOWS)),
+        ),
+        # tan(j * pi/2 * odd_number) is nan, which also makes tanhshrink nan.
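+        # A sketch of the filter's intent: for complex inputs, values whose ratio to (pi/2)*j is
+        # close to an integer are presumably replaced with safe_val=0 so that these singular
+        # points are not compared against the NumPy reference.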
+        reference_numerics_filter=NumericsFilter(
+            condition=lambda x: (close_to_int(x / (math.pi * 0.5j))
+                                 if x.is_complex() else x.new_tensor(False, dtype=torch.bool)),
+            safe_val=0)
+    ),
+    UnaryUfuncInfo(
+        'nn.functional.threshold',
+        ref=lambda x, threshold, value: np.where(x <= threshold, value, x).astype(x.dtype),
+        dtypes=all_types_and(torch.half, torch.bfloat16),
+        inplace_variant=lambda x, threshold, value:
+            torch.nn.functional.threshold(x, threshold, value, inplace=True),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        assert_autodiffed=False,
+        supports_gradgrad=True,
+        supports_out=False,
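+        # A note on the kwargs below: float.fromhex('0x1.3ap-3') is the exactly representable
+        # binary value 0.1533203125, so the op and the NumPy reference see an identical threshold.
+        # The two dicts returned by sample_kwargs are presumably the kwargs for the op and for
+        # the reference, respectively (they are identical here).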
+        sample_kwargs=lambda device, dtype, input: ({'threshold': float.fromhex('0x1.3ap-3'),
+                                                    'value': -9},
+                                                    {'threshold': float.fromhex('0x1.3ap-3'),
+                                                    'value': -9}),
+        # TODO(whc) should not need sample_inputs_func, but without it
+        # kwargs aren't being hooked up properly
+        sample_inputs_func=sample_inputs_threshold,
+    ),
+    OpInfo(
+        "nn.functional.triplet_margin_loss",
+        sample_inputs_func=sample_inputs_triplet_margin_loss,
+        error_inputs_func=error_inputs_triplet_margin_loss,
+        dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+    ),
+    OpInfo(
+        "nn.functional.triplet_margin_with_distance_loss",
+        sample_inputs_func=partial(sample_inputs_triplet_margin_loss, with_distance=True),
+        error_inputs_func=error_inputs_triplet_margin_loss,
+        dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        skips=(
+            # This test cannot handle a callable passed to `distance_function`. If we
+            # used `distance_function=None`, the test would pass.
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestJit",
+                "test_variant_consistency_jit",
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+        ),
+    ),
+    BinaryUfuncInfo('nextafter',
+                    dtypes=floating_types_and(torch.bfloat16, torch.half),
+                    supports_autograd=False,
+                    supports_rhs_python_scalar=False),
+    OpInfo(
+        "to",
+        op=lambda x, *args, **kwargs: x.to(*args, **kwargs),
+        dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        supports_out=False,
+        sample_inputs_func=sample_inputs_to,
+        skips=(
+            # RuntimeError: undefined value cpu
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                device_type="cpu",
+            ),
+            # NotImplementedError: Cannot copy out of meta tensor; no data!
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestMeta",
+                "test_meta_outplace",
+            ),
+            # https://github.com/pytorch/pytorch/issues/84335
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestProxyTensorOpInfo",
+                "test_make_fx_symbolic_exhaustive",
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+        ),
+    ),
+    OpInfo('topk',
+           dtypes=all_types_and(torch.bfloat16, torch.float16),
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           assert_jit_shape_analysis=True,
+           sample_inputs_func=sample_inputs_topk),
+    # Multiple variants for batch_norm to test with cuDNN enabled and disabled
+    # See https://github.com/pytorch/pytorch/pull/63218#discussion_r688549391 for more details
+    OpInfo('nn.functional.batch_norm',
+           aten_name='batch_norm',
+           dtypes=floating_types_and(torch.float16, torch.bfloat16),
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           assert_jit_shape_analysis=True,
+           allow_cow_input_materialize_forward=[1, 2],
+           allow_cow_input_materialize_backward=[1, 2],
+           sample_inputs_func=sample_inputs_batch_norm,
+           skips=(
+               # see https://github.com/pytorch/pytorch/issues/71286
+               DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness'),
+               DecorateInfo(unittest.skip('Skipped!'), 'TestNNCOpInfo', 'test_nnc_correctness',
+                            device_type='cpu', dtypes=(torch.bfloat16, torch.float16)),
+               DecorateInfo(toleranceOverride({torch.float32: tol(atol=5e-05, rtol=1e-05)}),
+                            'TestCompositeCompliance', 'test_forward_ad', device_type="cpu"),
+           )),
+    # This variant tests batch_norm with cuDNN disabled only on CUDA devices
+    OpInfo('nn.functional.batch_norm',
+           variant_test_name='without_cudnn',
+           aten_name='batch_norm',
+           dtypes=empty_types(),
+           dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           allow_cow_input_materialize_forward=[1, 2],
+           allow_cow_input_materialize_backward=[1, 2],
+           decorators=[onlyCUDA, disablecuDNN],
+           skips=(
+               DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-03, rtol=1e-04)}),
+                            'TestJit', 'test_variant_consistency_jit'),
+           ),
+           sample_inputs_func=sample_inputs_batch_norm),
+    OpInfo(
+        "nn.functional.binary_cross_entropy",
+        aten_backward_name='binary_cross_entropy_backward',
+        sample_inputs_func=sample_inputs_binary_cross_entropy,
+        dtypes=floating_types_and(torch.float16, torch.bfloat16),
+        dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+        supports_out=False,
+        gradcheck_fast_mode=False,
+        supports_autograd=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        decorators=(
+            # RuntimeError: expected int at position 0, but got: Tensor
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCudaFuserOpInfo",
+            ),
+            # RuntimeError: expected int at position 0, but got: Tensor
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestNNCOpInfo",
+                "test_nnc_correctness",
+            ),
+            # Fails for unknown reason: https://github.com/pytorch/pytorch/issues/120783
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCompositeCompliance",
+                "test_cow_input",
+                device_type='cuda',
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float32: tol(atol=1e-3, rtol=1e-3)}),
+                "TestJit",
+                "test_variant_consistency_jit",
+            ),
+            # RuntimeError: output with shape [] doesn't match the broadcast shape [5, 5]
+            DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_meta_outplace'),
+            DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_outplace'),
+            DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides'),
+        ),
+        skips=(
+            # RuntimeError: expected int at position 0, but got: Tensor
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestJit",
+                "test_variant_consistency_jit",
+            ),
+        ),
+    ),
+    # We have to add two OpInfo entries each for `igamma` and `igammac`. The first is the
+    # standard entry; the second is to run gradcheck tests on the second argument.
+    BinaryUfuncInfo('igamma',
+                    dtypes=floating_types_and(torch.bfloat16, torch.float16),
+                    aliases=('torch.special.gammainc',),
+                    dtypesIfCUDA=floating_types(),
+                    # TODO: FIXME
+                    supports_rhs_python_scalar=False,
+                    supports_autograd=False,
+                    skips=(
+                        # FIXME: incorrectly tries to pass a rhs scalar
+                        DecorateInfo(unittest.expectedFailure, 'TestJit',
+                                     'test_jit_alias_remapping'),
+                    )),
+    # TODO: FIXME, ideally by implementing grad for both inputs
+    # BinaryUfuncInfo('igamma',
+    #                 variant_test_name='grad_other',
+    #                 # Since autograd formula is implemented only for other and
+    #                 # gradcheck test verifies the formula for input in SampleInput,
+    #                 # we permute the arguments.
+    #                 op=lambda self, other, **kwargs: torch.igamma(other, self, **kwargs),
+    #                 inplace_variant=None,
+    #                 method_variant=None,
+    #                 supports_rhs_python_scalar=False,
+    #                 rhs_make_tensor_kwargs=dict(requires_grad=False),
+    #                 dtypes=floating_types_and(torch.bfloat16, torch.float16),
+    #                 backward_dtypesIfCPU=floating_types_and(torch.bfloat16),
+    #                 dtypesIfCUDA=floating_types(),
+    #                 backward_dtypesIfCUDA=floating_types(),
+    #                 supports_inplace_autograd=False,
+    #                 skips=(
+    #                     # Derivative wrt first tensor not implemented
+    #                     DecorateInfo(unittest.expectedFailure, "TestCommon",
+    #                                  "test_floating_inputs_are_differentiable"),
+    #                     # test does not work with passing lambda for op
+    #                     # AssertionError: False is not true : Tensors failed to compare as equal!
+    #                     DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+    #                     # test fails as we permute the arguments for the function variant
+    #                     # but not for the inplace or method variants.
+    #                     DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'),
+    #                     # TypeError: igamma(): argument 'input' (position 1) must be Tensor, not float
+    #                     DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs'),
+    #                 )),
+    BinaryUfuncInfo('igammac',
+                    dtypes=floating_types_and(torch.bfloat16, torch.float16),
+                    aliases=('torch.special.gammaincc',),
+                    dtypesIfCUDA=floating_types(),
+                    supports_autograd=False,
+                    supports_rhs_python_scalar=False,
+                    skips=(
+                        # FIXME: incorrectly tries to pass a rhs scalar
+                        DecorateInfo(unittest.expectedFailure, 'TestJit',
+                                     'test_jit_alias_remapping'),
+                    )),
+    # TODO: FIXME, ideally by implementing grad for both inputs
+    # BinaryUfuncInfo('igammac',
+    #                 variant_test_name='grad_other',
+    #                 # Since autograd formula is implemented only for other and
+    #                 # gradcheck test verifies the formula for input in SampleInput,
+    #                 # we permute the arguments
+    #                 op=lambda self, other, **kwargs: torch.igammac(other, self, **kwargs),
+    #                 inplace_variant=None,
+    #                 method_variant=None,
+    #                 supports_rhs_python_scalar=False,
+    #                 rhs_make_tensor_kwargs=dict(requires_grad=False),
+    #                 dtypes=floating_types_and(torch.bfloat16, torch.float16),
+    #                 backward_dtypesIfCPU=floating_types_and(torch.bfloat16),
+    #                 dtypesIfCUDA=floating_types(),
+    #                 backward_dtypesIfCUDA=floating_types(),
+    #                 supports_inplace_autograd=False,
+    #                 decorators=[
+    #                     # Derivative wrt first tensor not implemented
+    #                     DecorateInfo(unittest.expectedFailure, "TestCommon",
+    #                                  "test_floating_inputs_are_differentiable"),
+    #                 ],
+    #                 skips=(
+    #                     # test does not work with passing lambda for op
+    #                     # AssertionError: False is not true : Tensors failed to compare as equal!
+    #                     DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+    #                     # test fails as we permute the arguments for the function variant
+    #                     # but not for the inplace or method variants.
+    #                     DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'),
+    #                     # TypeError: igammac(): argument 'input' (position 1) must be Tensor, not float
+    #                     DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs'),
+    #                 )),
+    UnaryUfuncInfo('nn.functional.softshrink',
+                   aten_name="softshrink",
+                   aten_backward_name='softshrink_backward',
+                   dtypes=floating_types_and(torch.bfloat16, torch.float16),
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   assert_autodiffed=False,
+                   sample_inputs_func=sample_inputs_softshrink,
+                   error_inputs_func=error_inputs_softshrink),
+    UnaryUfuncInfo('nn.functional.hardshrink',
+                   aten_name="hardshrink",
+                   aten_backward_name='hardshrink_backward',
+                   dtypes=floating_types_and(torch.bfloat16, torch.float16),
+                   assert_autodiffed=True,
+                   sample_inputs_func=sample_inputs_hardshrink,
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   autodiff_nonfusible_nodes=["aten::hardshrink"]),
+    UnaryUfuncInfo('nn.functional.hardtanh',
+                   aten_name="hardtanh",
+                   aten_backward_name='hardtanh_backward',
+                   dtypes=floating_types_and(torch.int8, torch.int16, torch.int32, torch.int64, torch.half, torch.bfloat16),
+                   backward_dtypes=all_types_and(torch.half, torch.bfloat16),
+                   backward_dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+                   assert_autodiffed=True,
+                   sample_inputs_func=sample_inputs_hardtanh,
+                   error_inputs_func=error_inputs_hardtanh,
+                   supports_out=False,
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   autodiff_nonfusible_nodes=["aten::hardtanh"]),
+    OpInfo('nn.functional.gelu',
+           aten_name="gelu",
+           aten_backward_name='gelu_backward',
+           ref=reference_gelu if TEST_SCIPY else None,
+           error_inputs_func=error_inputs_gelu,
+           supports_autograd=True,
+           assert_autodiffed=True,
+           sample_inputs_func=sample_inputs_gelu,
+           dtypes=floating_types_and(torch.bfloat16, torch.half),
+           supports_gradgrad=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           autodiff_nonfusible_nodes=["aten::gelu"],
+           skips=(
+               # AssertionError: Tensor-likes are not close!
+               # May not replicate in CI
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'),
+               DecorateInfo(unittest.skip("Unsupported on MPS for now"), 'TestCommon', 'test_numpy_ref_mps'),
+           )),
+    UnaryUfuncInfo('nn.functional.relu6',
+                   aten_name="relu6",
+                   dtypes=all_types_and(torch.half, torch.bfloat16),
+                   backward_dtypes=floating_types_and(torch.half, torch.bfloat16),
+                   assert_autodiffed=True,
+                   supports_out=False,
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   autodiff_nonfusible_nodes=["aten::relu6"]),
+    OpInfo('mm',
+           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+           assert_autodiffed=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           sample_inputs_func=sample_inputs_mm,
+           skips=(
+               # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
+               DecorateInfo(
+                   unittest.skip("Skipped!"),
+                   'TestSchemaCheckModeOpInfo',
+                   'test_schema_correctness',
+                   dtypes=(torch.complex64, torch.complex128)),
+               # Fast math on macOS 13?
+               DecorateInfo(
+                   toleranceOverride({torch.float32: tol(atol=2e-5, rtol=5e-6)}),
+                   'TestConsistency',
+                   'test_output_match',
+                   active_if=lambda _: MACOS_VERSION < 14.0,
+                   device_type='mps',
+                   dtypes=(torch.float32,)),
+           )),
+    OpInfo('mode',
+           op=torch.mode,
+           dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           skips=(
+               # Resized a non-empty tensor but did not warn about it
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
+               # FIXME:
+               # Expected 2114 but got 1123.
+               # Absolute difference: 991 (up to 0.001 allowed)
+               # Relative difference: 0.46877956480605487 (up to 0.001 allowed)
+               DecorateInfo(
+                   unittest.skip("Skipped!"),
+                   "TestCommon",
+                   "test_compare_cpu",
+                   dtypes=(torch.float32,),
+                   device_type="cuda",
+               ),
+           ),
+           sample_inputs_func=sample_inputs_mode,),
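+    # For the mvlgamma variants below, sample_kwargs returns two dicts: presumably the first
+    # ({'p': n}) is the kwarg for torch.mvlgamma and the second ({'d': n}) is for the reference
+    # (scipy.special.multigammaln calls the same parameter `d`).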
+    make_mvlgamma_opinfo(variant_test_name='mvlgamma_p_1',
+                         domain=(1, None),
+                         skips=skips_mvlgamma(),
+                         sample_kwargs=lambda device, dtype, input: ({'p': 1}, {'d': 1})),
+    make_mvlgamma_opinfo(variant_test_name='mvlgamma_p_3',
+                         domain=(2, None),
+                         skips=skips_mvlgamma(),
+                         sample_kwargs=lambda device, dtype, input: ({'p': 3}, {'d': 3})),
+    make_mvlgamma_opinfo(variant_test_name='mvlgamma_p_5',
+                         domain=(3, None),
+                         skips=skips_mvlgamma(),
+                         sample_kwargs=lambda device, dtype, input: ({'p': 5}, {'d': 5})),
+    BinaryUfuncInfo('ne',
+                    ref=np.not_equal,
+                    aliases=('not_equal',),
+                    dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
+                    always_returns_bool=True,
+                    supports_autograd=False,
+                    skips=(
+                    )),
+    OpInfo('narrow',
+           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           sample_inputs_func=partial(sample_inputs_narrow_narrow_copy, is_narrow=True),
+           reference_inputs_func=partial(reference_inputs_narrow_narrow_copy, is_narrow=True),
+           error_inputs_func=partial(error_inputs_narrow_narrow_copy, is_narrow=True, is_ref=False),
+           skips=(
+               # Use of .item()
+               DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'),
+               DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'),
+               DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+           )),
+    OpInfo('narrow_copy',
+           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
+           supports_out=True,
+           supports_forward_ad=False,
+           supports_fwgrad_bwgrad=False,
+           supports_autograd=False,
+           # https://github.com/pytorch/pytorch/issues/86931
+           sample_inputs_func=partial(sample_inputs_narrow_narrow_copy, is_narrow=False),
+           reference_inputs_func=partial(reference_inputs_narrow_narrow_copy, is_narrow=False),
+           error_inputs_func=partial(error_inputs_narrow_narrow_copy, is_narrow=False, is_ref=False),
+           skips=(
+               # https://github.com/pytorch/pytorch/issues/84577
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
+               # Could not run 'aten::narrow_copy.out' with arguments from the 'CUDA' backend
+               DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta_outplace',
+                            device_type='cuda'),
+               DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_meta_outplace',
+                            device_type='cuda'),
+               DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_outplace',
+                            device_type='cuda'),
+               DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides'),
+           )),
+    OpInfo('view_copy',
+           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
+           ref=lambda x, newshape: np.reshape(x, newshape).copy(),
+           supports_out=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           supports_autograd=True,
+           sample_inputs_func=sample_inputs_view_reshape,
+           error_inputs_func=error_inputs_view_reshape,
+           skips=(
+               # RuntimeError: view size is not compatible with input tensor's size and stride
+               # (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
+               DecorateInfo(
+                   unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"
+               ),
+           )),
+    UnaryUfuncInfo('neg',
+                   aliases=('negative', ),
+                   ref=np.negative,
+                   dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.chalf),
+                   error_inputs_func=error_inputs_neg,
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   supports_sparse=True,
+                   supports_sparse_csr=True,
+                   supports_sparse_csc=True,
+                   supports_sparse_bsr=True,
+                   supports_sparse_bsc=True,
+                   assert_autodiffed=True),
+    OpInfo('dist',
+           op=torch.dist,
+           dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
+           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+           gradcheck_fast_mode=True,
+           supports_out=False,
+           supports_forward_ad=True,
+           # torch.autograd.gradcheck.GradcheckError: While computing batched gradients, got:
+           # Could not allocate memory to change Tensor SizesAndStrides!
+           check_batched_forward_grad=False,
+           supports_fwgrad_bwgrad=True,
+           sample_inputs_func=sample_inputs_dist),
+    OpInfo('outer',
+           op=torch.outer,
+           aliases=('ger', ),
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # See https://github.com/pytorch/pytorch/pull/78358
+           check_batched_forward_grad=False,
+           sample_inputs_func=sample_inputs_outer,),
+    OpInfo('ormqr',
+           op=torch.ormqr,
+           dtypes=floating_and_complex_types(),
+           # https://github.com/pytorch/pytorch/issues/80411
+           gradcheck_fast_mode=True,
+           supports_forward_ad=False,
+           supports_fwgrad_bwgrad=False,
+           sample_inputs_func=sample_inputs_ormqr,
+           error_inputs_func=error_inputs_ormqr,
+           decorators=[skipCUDAIfNoCusolver, skipCPUIfNoLapack],
+           skips=(
+               # Strides are not the same!
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
+           )),
+    OpInfo('permute',
+           ref=np.transpose,
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
+           supports_out=False,
+           assert_autodiffed=True,
+           autodiff_fusible_nodes=[],  # aliases inputs, shouldn't be fused
+           autodiff_nonfusible_nodes=[],  # aliases inputs, shouldn't be fused
+           assert_jit_shape_analysis=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           supports_varargs=True,
+           sample_inputs_func=sample_inputs_permute,
+           reference_inputs_func=reference_inputs_permute),
+    OpInfo('permute_copy',
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
+           supports_out=True,
+           assert_autodiffed=True,
+           assert_jit_shape_analysis=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           supports_varargs=False,  # torch.permute is also not varargs
+           sample_inputs_func=sample_inputs_permute,
+           reference_inputs_func=reference_inputs_permute,
+           skips=(
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
+           )),
+    BinaryUfuncInfo('pow',
+                    dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
+                    dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16, torch.chalf),
+                    ref=np.power,
+                    # Due to AVX2 currently not being fully supported for Float16, log_vml_cpu can't be enabled
+                    # for Float16, causing this test to fail. pow's autograd for Float16 is thus currently
+                    # unsupported on CPU.
+                    backward_dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
+                    backward_dtypesIfCUDA=floating_and_complex_types_and(torch.bfloat16, torch.half, torch.chalf),
+                    # https://github.com/pytorch/pytorch/issues/80411
+                    gradcheck_fast_mode=True,
+                    supports_inplace_autograd=False,
+                    supports_forward_ad=True,
+                    supports_fwgrad_bwgrad=True,
+                    assert_autodiffed=True,
+                    supports_one_python_scalar=True,
+                    # Integer types do not support negative exponents
+                    rhs_make_tensor_kwargs=dict(low=0),
+                    # Raising negative real numbers to fractional powers is not supported
+                    lhs_make_tensor_kwargs=dict(low=0),
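+                    # Concretely (illustrative examples, not sampled values): integer tensors
+                    # raised to negative powers raise an error in PyTorch, and a negative real
+                    # base raised to a fractional power (e.g. (-2.0) ** 0.5) evaluates to nan.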
+                    decorators=(
+                        DecorateInfo(toleranceOverride({torch.complex64: tol(atol=1e-4, rtol=1.3e-05)}),
+                                     'TestBinaryUfuncs', 'test_reference_numerics'),
+                        DecorateInfo(toleranceOverride({torch.complex64: tol(atol=1e-4, rtol=1.3e-05),
+                                                        torch.complex128: tol(atol=1e-4, rtol=1.3e-05)}),
+                                     'TestBinaryUfuncs', 'test_scalar_support'),
+                    ),
+                    skips=(
+                        # Skipping integer dtypes because raising them to negative powers causes an error
+                        DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_reference_numerics_small_values',
+                                     dtypes=[torch.int8, torch.int16, torch.int32, torch.int64]),
+                        DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_reference_numerics_large_values',
+                                     dtypes=[torch.int16, torch.int32, torch.int64]),
+                        # FIXME Complex values error with: Greatest absolute difference: nan at index
+                        # Ref: https://github.com/pytorch/pytorch/issues/76853
+                        # For `chalf`, reference computation in `numpy` is computed in `cfloat`.
+                        # Output of `chalf` saturates to `inf` quicker than reference due to its small range,
+                        # which leads to failure of this test.
+                        DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_quick',
+                                     dtypes=(torch.complex32,), active_if=TEST_WITH_ROCM),
+                        # FIXME:
+                        # Mismatched elements: 1 / 500 (0.2%)
+                        # Greatest absolute difference: nan at index (7, 9, 0) (up to 1e-05 allowed)
+                        # Greatest relative difference: nan at index (7, 9, 0) (up to 0.001 allowed)
+                        DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_comprehensive',
+                                     dtypes=(torch.complex32,)),
+                        DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_complex_half_reference_testing',
+                                     dtypes=(torch.complex32,), active_if=TEST_WITH_ROCM),
+                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_batch_vs_slicing',
+                                     dtypes=(torch.complex32,)),
+                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_non_contig',
+                                     dtypes=(torch.complex32,)),
+                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics',
+                                     dtypes=(torch.complex32,)),
+                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_small_values',
+                                     dtypes=(torch.complex32, torch.complex64, torch.complex128)),
+                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_large_values',
+                                     dtypes=(torch.complex32, torch.complex64, torch.complex128)),
+                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_extremal_values',
+                                     dtypes=(torch.complex32, torch.complex64, torch.complex128)),
+                    )),
+    BinaryUfuncInfo('float_power',
+                    ref=np.float_power,
+                    dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
+                    promotes_int_to_float=True,
+                    # https://github.com/pytorch/pytorch/issues/80411
+                    gradcheck_fast_mode=True,
+                    supports_forward_ad=True,
+                    supports_fwgrad_bwgrad=True,
+                    supports_one_python_scalar=True,
+                    # Integer types do not support negative exponents
+                    rhs_make_tensor_kwargs=dict(low=0),
+                    # Raising negative real numbers to fractional powers is not supported
+                    lhs_make_tensor_kwargs=dict(low=0),
+                    decorators=(
+                        DecorateInfo(toleranceOverride({torch.complex64: tol(atol=1e-4, rtol=1.3e-05),
+                                                        torch.complex128: tol(atol=1e-4, rtol=1.3e-05)}),
+                                     'TestBinaryUfuncs', 'test_scalar_support'),
+                    ),
+                    skips=(
+                        # FIXME
+                        # AssertionError: Object comparison failed: torch.float64 != torch.float32
+                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_type_promotion'),
+                        # -3.43399e+38 is outside the range of representable values of type 'float'
+                        DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+                        # Complex values error with: Greatest absolute difference: nan at index
+                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_small_values',
+                                     dtypes=[torch.complex64, torch.complex128]),
+                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_large_values',
+                                     dtypes=[torch.complex64, torch.complex128]),
+                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_extremal_values',
+                                     dtypes=[torch.complex64, torch.complex128]),
+                        # Inplace always promotes to double and thus other floating dtypes are not supported
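+                        # (Sketch of the reason, assuming torch.float_power's documented behavior:
+                        # it always computes in double precision, so an in-place call on a
+                        # float16/bfloat16/float32 tensor cannot hold the double result.)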
+                        DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta_inplace',
+                                     dtypes=[torch.bfloat16, torch.float16, torch.float32]),
+                    )),
+    OpInfo('qr',
+           op=torch.qr,
+           dtypes=floating_and_complex_types(),
+           sample_inputs_func=sample_inputs_linalg_qr_geqrf,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # In-place ops
+           check_batched_gradgrad=False,
+           decorators=[skipCUDAIfNoCusolver, skipCPUIfNoLapack]),
+    UnaryUfuncInfo('rad2deg',
+                   ref=np.degrees,
+                   decorators=(precisionOverride({torch.bfloat16: 7e-1,
+                                                  torch.float16: 7e-1}),),
+                   dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   supports_sparse=True,
+                   supports_sparse_csr=True,
+                   supports_sparse_csc=True,
+                   supports_sparse_bsr=True,
+                   supports_sparse_bsc=True,
+                   promotes_int_to_float=True),
+    UnaryUfuncInfo('real',
+                   ref=np.real,
+                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf),
+                   supports_out=False,
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   # See https://github.com/pytorch/pytorch/issues/66357
+                   check_batched_forward_grad=False,
+                   skips=(
+                       # Skip since real and imag don't have out variants.
+                       DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_out_arg_all_dtypes'),
+                   )),
+    OpInfo(
+        "roll",
+        ref=np.roll,
+        dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf),
+        error_inputs_func=error_inputs_roll,
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_roll,
+        decorators=(onlyNativeDeviceTypes,),
+    ),
+    OpInfo(
+        "rot90",
+        dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half),
+        error_inputs_func=error_inputs_rot90,
+        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+        gradcheck_fast_mode=True,
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_rot90,
+    ),
+    # To test reference numerics against multiple values of the argument `decimals`,
+    # we make multiple OpInfo entries, each corresponding to a different value of `decimals`.
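+    # For example, decimals=3 rounds to the nearest 0.001 and decimals=-3 to the nearest 1000
+    # (so 1234.5678 would round to 1000.0); the entries below cover decimals of 0, 3, and -3.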
+    UnaryUfuncInfo('round',
+                   ref=np.round,
+                   aliases=('special.round',),
+                   dtypes=all_types_and(torch.half, torch.bfloat16),
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   skips=(
+                       DecorateInfo(unittest.expectedFailure,
+                                    'TestNNCOpInfo',
+                                    'test_nnc_correctness',
+                                    dtypes=tuple(t for t in integral_types() if t != torch.uint8)),
+                       DecorateInfo(unittest.skip("Skipped!"),
+                                    'TestNNCOpInfo',
+                                    'test_nnc_correctness',
+                                    dtypes=(torch.bfloat16,)),
+                   ),
+                   supports_sparse=True,
+                   supports_sparse_csr=True,
+                   supports_sparse_csc=True,
+                   supports_sparse_bsr=True,
+                   supports_sparse_bsc=True,
+                   assert_autodiffed=True,
+                   ),
+    UnaryUfuncInfo('round',
+                   ref=np.round,
+                   variant_test_name='decimals_0',
+                   aliases=('special.round',),
+                   dtypes=floating_types_and(torch.half, torch.bfloat16),
+                   sample_kwargs=lambda device, dtype, input: ({'decimals': 0}, {'decimals': 0}),
+                   sample_inputs_func=partial(sample_inputs_elementwise_unary, op_kwargs={'decimals': 0}),
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   assert_autodiffed=False,
+                   supports_sparse_csr=False),
+    UnaryUfuncInfo('round',
+                   ref=np.round,
+                   variant_test_name='decimals_3',
+                   aliases=('special.round',),
+                   dtypes=floating_types_and(torch.bfloat16),
+                   dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
+                   sample_kwargs=lambda device, dtype, input: ({'decimals': 3}, {'decimals': 3}),
+                   sample_inputs_func=partial(sample_inputs_elementwise_unary, op_kwargs={'decimals': 3}),
+                   skips=(
+                       # test_ops already tests this overload via the `decimals_0` OpInfo entry
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients'),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients'),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestJit'),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits'),
+                       DecorateInfo(toleranceOverride({torch.bfloat16: tol(atol=1e-3, rtol=0.016)}),
+                                    "TestUnaryUfuncs", "test_reference_numerics_extremal",
+                                    device_type="cuda"),
+                       DecorateInfo(toleranceOverride({torch.bfloat16: tol(atol=1e-3, rtol=0.016)}),
+                                    "TestUnaryUfuncs", "test_reference_numerics_normal",
+                                    device_type="cuda"),
+                   ),
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   assert_autodiffed=False,
+                   supports_sparse_csr=False),
+    UnaryUfuncInfo('round',
+                   ref=np.round,
+                   variant_test_name='decimals_neg_3',
+                   aliases=('special.round',),
+                   dtypes=floating_types_and(torch.bfloat16),
+                   dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
+                   sample_kwargs=lambda device, dtype, input: ({'decimals': -3}, {'decimals': -3}),
+                   sample_inputs_func=partial(sample_inputs_elementwise_unary, op_kwargs={'decimals': -3}),
+                   skips=(
+                       # test_ops already tests this overload via the `decimals_0` OpInfo entry
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients'),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients'),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestJit'),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits'),
+                   ),
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   assert_autodiffed=False,
+                   supports_sparse_csr=False),
+    UnaryUfuncInfo('sin',
+                   ref=np.sin,
+                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                   dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
+                   assert_autodiffed=True,
+                   handles_large_floats=False,
+                   supports_sparse=True,
+                   supports_sparse_csr=True,
+                   supports_sparse_csc=True,
+                   supports_sparse_bsr=True,
+                   supports_sparse_bsc=True,
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   promotes_int_to_float=True,
+                   skips=(
+                       # Fails on CUDA but passes on ROCm
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
+                                    dtypes=(torch.cdouble,), device_type='cuda'),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
+                                    dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu', active_if=IS_WINDOWS),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
+                                    dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu', active_if=IS_WINDOWS),
+                       DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
+                                    'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
+                       DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=2e-3)}),
+                                    "TestConsistency", "test_output_grad_match", device_type="mps"),
+                   ),
+                   decorators=(precisionOverride({torch.bfloat16: 1e-2}),)),
+    UnaryUfuncInfo('sinc',
+                   ref=np_sinc_with_fp16_as_fp32,
+                   aliases=('special.sinc',),
+                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                   handles_large_floats=False,
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   promotes_int_to_float=True),
+    UnaryUfuncInfo('sinh',
+                   ref=np_unary_ufunc_integer_promotion_wrapper(np.sinh),
+                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                   dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
+                   assert_autodiffed=True,
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   supports_sparse=True,
+                   supports_sparse_csr=True,
+                   supports_sparse_csc=True,
+                   supports_sparse_bsr=True,
+                   supports_sparse_bsc=True,
+                   promotes_int_to_float=True,
+                   decorators=(precisionOverride({torch.float16: 1e-2}),),
+                   skips=(
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
+                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
+                                    active_if=(IS_MACOS or IS_WINDOWS)),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
+                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
+                                    active_if=(IS_MACOS or IS_WINDOWS)),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
+                                    dtypes=(torch.cdouble,)),
+                       # Reference: https://github.com/pytorch/pytorch/issues/48641
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
+                                    device_type='cpu', dtypes=[torch.int8]),
+                       DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
+                                    'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
+                   )),
+    UnaryUfuncInfo('sign',
+                   ref=reference_sign,
+                   dtypes=all_types_and(torch.bool, torch.bfloat16, torch.half),
+                   dtypesIfCUDA=all_types_and(torch.bool, torch.bfloat16, torch.half),
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   supports_sparse=True,
+                   supports_sparse_csr=True,
+                   supports_sparse_csc=True,
+                   supports_sparse_bsr=True,
+                   supports_sparse_bsc=True,
+                   skips=(
+                       # Reference: https://github.com/pytorch/pytorch/issues/41245
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
+                                    dtypes=[torch.bfloat16, torch.float16, torch.float32, torch.float64]),
+                   )),
+    UnaryUfuncInfo('sgn',
+                   ref=reference_sgn,
+                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf),
+                   backward_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.half),
+                   backward_dtypesIfCUDA=floating_and_complex_types_and(torch.bfloat16, torch.half, torch.chalf),
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   supports_sparse=True,
+                   supports_sparse_csr=True,
+                   supports_sparse_csc=True,
+                   supports_sparse_bsr=True,
+                   supports_sparse_bsc=True,
+                   skips=(
+                       # Reference: https://github.com/pytorch/pytorch/issues/41245
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
+                                    dtypes=[torch.bfloat16, torch.float16, torch.float32, torch.float64]),
+                       DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
+                                    'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
+                   )),
+    OpInfo('split',
+           dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool, torch.chalf),
+           sample_inputs_func=partial(sample_inputs_split, list_args=False),
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           supports_out=False,
+           autodiff_fusible_nodes=[],  # aliases inputs, shouldn't be fused
+           autodiff_nonfusible_nodes=[],  # aliases inputs, shouldn't be fused
+           assert_autodiffed=True),
+    OpInfo('split',
+           # Cannot declare this aten_name because of
+           # test_variant_consistency_jit_split_list_args_cpu_float32
+           decomp_aten_name='split_with_sizes',
+           variant_test_name='list_args',
+           dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool),
+           sample_inputs_func=partial(sample_inputs_split, list_args=True),
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           supports_out=False),
+    # `unsafe_split` supports only `int` for split_size argument
+    OpInfo('unsafe_split',
+           dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool, torch.chalf),
+           sample_inputs_func=partial(sample_inputs_split, list_args=False),
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           supports_out=False,
+           autodiff_fusible_nodes=[],  # aliases inputs, shouldn't be fused
+           autodiff_nonfusible_nodes=[],  # aliases inputs, shouldn't be fused
+           assert_autodiffed=True,
+           check_batched_forward_grad=False),
+    OpInfo('split_with_sizes',
+           dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool, torch.chalf),
+           sample_inputs_func=sample_inputs_split_with_sizes,
+           autodiff_fusible_nodes=[],  # aliases inputs, shouldn't be fused
+           autodiff_nonfusible_nodes=[],  # aliases inputs, shouldn't be fused
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           assert_autodiffed=True),
+    OpInfo('split_with_sizes_copy',
+           dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool, torch.chalf),
+           sample_inputs_func=sample_inputs_split_with_sizes,
+           supports_out=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           skips=(
+               # No error raised
+               DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out_requires_grad_error"),
+           )),
+    BinaryUfuncInfo('__radd__',
+                    op=torch.Tensor.__radd__,
+                    dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool),
+                    supports_out=False,
+                    skips=(
+                        DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+                        DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',),
+                    ),
+                    assert_autodiffed=True,
+                    supports_forward_ad=True,
+                    supports_fwgrad_bwgrad=True,
+                    autodiff_nonfusible_nodes=['aten::add'],),
+    BinaryUfuncInfo('__rdiv__',
+                    op=torch.Tensor.__rdiv__,
+                    dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool),
+                    promotes_int_to_float=True,
+                    lhs_make_tensor_kwargs={'exclude_zero': True},
+                    # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+                    gradcheck_fast_mode=True,
+                    supports_out=False,
+                    skips=(
+                        # https://github.com/pytorch/pytorch/issues/76806
+                        DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'),
+                        DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+                        DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',),
+                    ),
+                    supports_forward_ad=True,
+                    supports_fwgrad_bwgrad=True,
+                    assert_autodiffed=True,
+                    autodiff_nonfusible_nodes=['aten::mul', 'aten::reciprocal'],),
+    BinaryUfuncInfo('__rmul__',
+                    op=torch.Tensor.__rmul__,
+                    dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool),
+                    supports_out=False,
+                    skips=(
+                        DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+                        DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',),
+                    ),
+                    assert_autodiffed=True,
+                    supports_forward_ad=True,
+                    supports_fwgrad_bwgrad=True,
+                    autodiff_nonfusible_nodes=['aten::mul'],),
+    BinaryUfuncInfo('__rand__',
+                    op=torch.Tensor.__rand__,
+                    dtypes=integral_types_and(torch.bool),
+                    supports_out=False,
+                    supports_autograd=False,
+                    supports_forward_ad=True,
+                    skips=(
+                        DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+                    )),
+    BinaryUfuncInfo('__ror__',
+                    op=torch.Tensor.__ror__,
+                    dtypes=integral_types_and(torch.bool),
+                    supports_out=False,
+                    supports_autograd=False,
+                    supports_forward_ad=True,
+                    skips=(
+                        DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+                    )),
+    BinaryUfuncInfo('__rxor__',
+                    op=torch.Tensor.__rxor__,
+                    dtypes=integral_types_and(torch.bool),
+                    supports_out=False,
+                    supports_autograd=False,
+                    supports_forward_ad=True,
+                    skips=(
+                        DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+                    )),
+    OpInfo('__rmatmul__',
+           op=torch.Tensor.__rmatmul__,
+           dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
+                                                       *[torch.bfloat16]
+                                                       if SM53OrLater or TEST_WITH_ROCM else []),
+           assert_autodiffed=True,
+           sample_inputs_func=partial(sample_inputs_matmul, is_rmatmul=True),
+           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+           gradcheck_fast_mode=True,
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           check_batched_forward_grad=False,
+           decorators=(
+               # NVIDIA only guarantees bfloat16 support for bmm on devices with SM >= 5.3
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', device_type='cuda', active_if=not SM53OrLater),
+               DecorateInfo(toleranceOverride({torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}),
+                            'TestMathBits', 'test_conj_view'),
+               DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-05, rtol=1.2e-03)}),
+                            'TestCommon', 'test_noncontiguous_samples'),
+               DecorateInfo(toleranceOverride({torch.complex64: tol(atol=1e-05, rtol=1e-05)}),
+                            "TestDecomp", "test_comprehensive", device_type="cuda",
+                            active_if=TEST_WITH_ROCM),
+           ),
+           skips=(
+               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',),
+               # https://github.com/pytorch/pytorch/issues/67470
+               DecorateInfo(unittest.skip("67470!"),
+                            'TestCommon', 'test_noncontiguous_samples',
+                            device_type='cpu', dtypes=(torch.long,)),
+               # Fails on XLA.
+               # AssertionError: False is not true : Tensors failed to compare as equal
+               DecorateInfo(unittest.skip("Skipped!"), 'TestOpInfo', device_type='xla', dtypes=(torch.long,)),
+               # https://github.com/pytorch/pytorch/issues/71774
+               DecorateInfo(unittest.skip('Skipped!'), 'TestNNCOpInfo', 'test_nnc_correctness',
+                            device_type='cpu', dtypes=(torch.long,)),
+           )),
+    BinaryUfuncInfo('__rmod__',
+                    op=torch.Tensor.__rmod__,
+                    dtypes=floating_types_and(torch.bfloat16, torch.half,),
+                    dtypesIfCUDA=all_types_and(torch.bfloat16, torch.half),
+                    # https://github.com/pytorch/pytorch/issues/80411
+                    gradcheck_fast_mode=True,
+                    supports_out=False,
+                    supports_forward_ad=True,
+                    supports_fwgrad_bwgrad=True,
+                    supports_one_python_scalar=True,
+                    skips=(
+                        DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+                        DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',),
+                    ),
+                    # Full autograd support can be enabled once torch.remainder(Tensor, Tensor)
+                    # supports autograd with respect to its second argument.
+                    # https://github.com/pytorch/pytorch/pull/58476/files#r637167630
+                    # supports_autograd=False,
+                    assert_autodiffed=True,
+                    autodiff_nonfusible_nodes=['aten::remainder'],),
+    BinaryUfuncInfo('__rpow__',
+                    op=torch.Tensor.__rpow__,
+                    dtypes=all_types_and_complex_and(torch.bfloat16, torch.half),
+                    # Reference: https://github.com/pytorch/pytorch/issues/54774
+                    # "log2" "_vml_cpu" not implemented for Half
+                    backward_dtypes=all_types_and_complex_and(torch.bfloat16, torch.half),
+                    supports_out=False,
+                    supports_forward_ad=True,
+                    supports_fwgrad_bwgrad=True,
+                    supports_one_python_scalar=True,
+                    skips=(
+                        DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+                        DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',),
+                        # TODO: FIXME - the required tolerance is too high, so the gradient test classes are skipped
+                        DecorateInfo(unittest.skip('Skipped!'), 'TestFwdGradients'),
+                        DecorateInfo(unittest.skip('Skipped!'), 'TestBwdGradients'),
+                    ),
+                    assert_autodiffed=True,
+                    autodiff_nonfusible_nodes=['aten::pow'],),
+    BinaryUfuncInfo('__rsub__',
+                    op=torch.Tensor.__rsub__,
+                    dtypes=all_types_and_complex_and(torch.bfloat16, torch.half),
+                    supports_forward_ad=True,
+                    supports_fwgrad_bwgrad=True,
+                    supports_out=False,
+                    supports_one_python_scalar=True,
+                    skips=(
+                        DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+                        DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',),
+                    ),
+                    assert_autodiffed=True,
+                    autodiff_nonfusible_nodes=['aten::rsub'],),
+    BinaryUfuncInfo('rsub',
+                    dtypes=all_types_and_complex_and(torch.bfloat16, torch.half),
+                    supports_forward_ad=True,
+                    supports_fwgrad_bwgrad=True,
+                    supports_out=False,
+                    supports_inplace_autograd=False,
+                    assert_autodiffed=None,
+                    sample_inputs_func=sample_inputs_add_sub),
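+    # select and slice return views of the input; select_scatter and slice_scatter are
+    # their write-back counterparts, embedding a source tensor into a copy of the input
+    # at the selected location.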
+    OpInfo('select',
+           aten_backward_name='select_backward',
+           dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool, torch.chalf),
+           sample_inputs_func=sample_inputs_select,
+           assert_jit_shape_analysis=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           supports_out=False),
+    OpInfo('select_scatter',
+           dtypes=all_types_and(torch.bfloat16, torch.half, torch.bool),
+           sample_inputs_func=sample_inputs_select_scatter,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           supports_out=False),
+    OpInfo('slice',
+           op=torch.ops.aten.slice.Tensor,
+           dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool, torch.chalf),
+           sample_inputs_func=sample_inputs_slice,
+           gradcheck_fast_mode=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           supports_scripting=False,
+           supports_inplace_autograd=False,
+           supports_out=False),
+    OpInfo('slice_scatter',
+           dtypes=all_types_and(torch.bfloat16, torch.half, torch.bool),
+           sample_inputs_func=sample_inputs_slice_scatter,
+           # https://github.com/pytorch/pytorch/issues/80411
+           gradcheck_fast_mode=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           supports_out=True),
+    UnaryUfuncInfo('signbit',
+                   ref=np.signbit,
+                   dtypes=all_types_and(torch.bool, torch.bfloat16, torch.half),
+                   supports_sparse=True,
+                   supports_sparse_csr=True,
+                   supports_sparse_csc=True,
+                   supports_sparse_bsr=True,
+                   supports_sparse_bsc=True,
+                   supports_autograd=False,),
+    UnaryUfuncInfo('tan',
+                   ref=np.tan,
+                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                   dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
+                   decorators=(DecorateInfo(
+                               toleranceOverride({torch.complex64: tol(atol=1e-04, rtol=1e-05)}),
+                               'TestUnaryUfuncs', 'test_reference_numerics_extremal',
+                               device_type='cuda'),),
+                   assert_autodiffed=True,
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   supports_sparse=True,
+                   supports_sparse_csr=True,
+                   supports_sparse_csc=True,
+                   supports_sparse_bsr=True,
+                   supports_sparse_bsc=True,
+                   promotes_int_to_float=True,
+                   skips=(
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
+                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
+                                    active_if=(IS_MACOS or IS_WINDOWS)),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
+                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
+                                    active_if=(IS_MACOS or IS_WINDOWS)),
+                       DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
+                                    'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
+                       # FIXME:
+                       # Mismatched elements: 2 / 400 (0.5%)
+                       # Greatest absolute difference: inf at index (7, 16) (up to 1e-05 allowed)
+                       # Greatest relative difference: nan at index (7, 16) (up to 0.001 allowed)
+                       DecorateInfo(
+                           unittest.skip("Skipped!"),
+                           "TestInductorOpInfo",
+                           "test_comprehensive",
+                           dtypes=(torch.float16,),
+                           device_type="cuda",
+                       ),
+                       DecorateInfo(toleranceOverride({torch.complex64: tol(atol=3e-5, rtol=7e-6)}),
+                                    "TestConsistency", "test_output_match", device_type="mps"),
+                       DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=2e-3)}),
+                                    "TestConsistency", "test_output_grad_match", device_type="mps"),
+                   ),
+                   # tan(pi/2 * odd_number) is nan
+                   reference_numerics_filter=NumericsFilter(
+                       condition=lambda x: close_to_int(x / (math.pi * 0.5)), safe_val=math.pi)),
+    UnaryUfuncInfo('tanh',
+                   ref=np.tanh,
+                   aten_backward_name='tanh_backward',
+                   aliases=('nn.functional.tanh',),
+                   decorators=(precisionOverride({torch.bfloat16: 1e-2}),
+                               DecorateInfo(
+                                   toleranceOverride({torch.complex64: tol(atol=1e-04, rtol=2e-05)}),
+                                   'TestUnaryUfuncs', 'test_reference_numerics_extremal',
+                                   device_type='cuda'),),
+                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                   dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
+                   assert_autodiffed=True,
+                   assert_jit_shape_analysis=True,
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   supports_sparse=True,
+                   supports_sparse_csr=True,
+                   supports_sparse_csc=True,
+                   supports_sparse_bsr=True,
+                   supports_sparse_bsc=True,
+                   promotes_int_to_float=True,
+                   skips=(
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
+                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
+                                    active_if=(IS_MACOS or IS_WINDOWS)),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
+                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
+                                    active_if=(IS_MACOS or IS_WINDOWS)),
+                       DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
+                                    'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
+                       DecorateInfo(toleranceOverride({torch.complex64: tol(atol=3e-5, rtol=7e-6)}),
+                                    "TestConsistency", "test_output_match", device_type="mps"),
+                   ),
+                   # tan(j * pi/2 * odd_number) is nan
+                   reference_numerics_filter=NumericsFilter(
+                       condition=lambda x: (close_to_int(x / (math.pi * 0.5j))
+                                            if x.is_complex() else x.new_tensor(False, dtype=torch.bool)),
+                       safe_val=0)),
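+    # tensor_split and its hsplit/vsplit/dsplit convenience wrappers below return tuples of
+    # views; the h/v/d variants also register error_inputs_func helpers so their
+    # dimensionality checks are exercised.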
+    OpInfo('tensor_split',
+           ref=np.array_split,
+           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
+           dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           skips=(
+               # Pre-existing condition; Needs to be fixed
+               DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'),
+               DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'),
+               DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'),
+           ),
+           sample_inputs_func=sample_inputs_tensor_split,),
+    OpInfo('hsplit',
+           dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.bfloat16, torch.float16),
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # See https://github.com/pytorch/pytorch/pull/78358
+           check_batched_forward_grad=False,
+           sample_inputs_func=sample_inputs_hsplit,
+           error_inputs_func=error_inputs_hsplit,),
+    OpInfo('vsplit',
+           dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.bfloat16, torch.float16),
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # See https://github.com/pytorch/pytorch/pull/78358
+           check_batched_forward_grad=False,
+           sample_inputs_func=sample_inputs_vsplit,
+           error_inputs_func=error_inputs_vsplit,),
+    OpInfo('dsplit',
+           dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.bfloat16, torch.float16),
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # See https://github.com/pytorch/pytorch/pull/78358
+           check_batched_forward_grad=False,
+           sample_inputs_func=sample_inputs_dsplit,
+           error_inputs_func=error_inputs_dsplit,),
+    OpInfo('triangular_solve',
+           op=torch.triangular_solve,
+           dtypes=floating_and_complex_types(),
+           sample_inputs_func=sample_inputs_legacy_solve,
+           check_batched_gradgrad=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           gradcheck_wrapper=lambda *args, **kwargs: gradcheck_wrapper_triangular_input(*args, idx=1, **kwargs),
+           decorators=[
+               skipCUDAIfNoMagma,
+               skipCPUIfNoLapack,
+               DecorateInfo(
+                   toleranceOverride({torch.float32: tol(atol=3e-5, rtol=3e-6)}),
+                   'TestConsistency', 'test_output_match', device_type='cpu',
+               ),
+           ],
+           skips=(
+               # AssertionError: Scalars are not equal!
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
+               # Gradcheck fails
+               DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_fn_fwgrad_bwgrad',
+                            dtypes=floating_and_complex_types()),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out',
+                            device_type='mps', dtypes=[torch.float32]),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager',
+                            device_type='mps', dtypes=[torch.float32]),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',
+                            device_type='mps', dtypes=[torch.float32]),
+           )),
+    UnaryUfuncInfo('trunc',
+                   aliases=('fix', ),
+                   ref=np.trunc,
+                   dtypes=all_types_and(torch.half, torch.bfloat16),
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   supports_sparse=True,
+                   skips=(
+                       DecorateInfo(unittest.expectedFailure,
+                                    'TestNNCOpInfo',
+                                    'test_nnc_correctness',
+                                    dtypes=tuple(t for t in integral_types() if t != torch.uint8)),
+                   ),
+                   supports_sparse_csr=True,
+                   supports_sparse_csc=True,
+                   supports_sparse_bsr=True,
+                   supports_sparse_bsc=True,
+                   assert_autodiffed=True),
+    UnaryUfuncInfo('exp2',
+                   aliases=('special.exp2', ),
+                   ref=np_unary_ufunc_integer_promotion_wrapper(np.exp2),
+                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   promotes_int_to_float=True,
+                   skips=(
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
+                                    dtypes=[torch.cdouble]),
+                       # Reference: https://github.com/pytorch/pytorch/issues/48010
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
+                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS),
+                   )),
+    UnaryUfuncInfo('expm1',
+                   aliases=('special.expm1', ),
+                   ref=np_unary_ufunc_integer_promotion_wrapper(np.expm1),
+                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   supports_sparse=True,
+                   supports_sparse_csr=True,
+                   supports_sparse_csc=True,
+                   supports_sparse_bsr=True,
+                   supports_sparse_bsc=True,
+                   promotes_int_to_float=True,
+                   assert_autodiffed=True,
+                   skips=(
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
+                                    device_type='cuda', dtypes=[torch.complex128]),
+                       DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
+                                    'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
+                   )),
+    UnaryUfuncInfo('nan_to_num',
+                   ref=np.nan_to_num,
+                   dtypes=all_types_and(torch.half, torch.bool, torch.bfloat16),
+                   dtypesIfCUDA=all_types_and(torch.half, torch.bool, torch.bfloat16),
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   supports_sparse=True,
+                   skips=(
+                       DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
+                                    'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
+                   ),
+                   # Pass numpy_kwargs via sample_kwargs: NumPy has no BFloat16 support, so the
+                   # reference comparison runs in float and posinf/neginf must be pinned to
+                   # bfloat16's finite range.
+                   # Ref: https://github.com/pytorch/pytorch/issues/57982#issuecomment-839150556
+                   sample_kwargs=lambda device, dtype, input: ({},
+                                                               {'posinf': torch.finfo(torch.bfloat16).max,
+                                                                'neginf': torch.finfo(torch.bfloat16).min})
+                   if dtype is torch.bfloat16 else ({}, {})),
+    UnaryUfuncInfo('reciprocal',
+                   ref=np_unary_ufunc_integer_promotion_wrapper(np.reciprocal),
+                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                   assert_autodiffed=True,
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   promotes_int_to_float=True,
+                   skips=(
+                       # Reference: https://github.com/pytorch/pytorch/issues/45690
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
+                                    dtypes=[torch.cfloat, torch.cdouble]),
+                   )),
+    UnaryUfuncInfo('rsqrt',
+                   ref=lambda x: np.reciprocal(np.sqrt(x)),
+                   domain=(0, None),
+                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                   dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
+                   decorators=(precisionOverride({torch.half: 5e-2}),),
+                   assert_autodiffed=True,
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   promotes_int_to_float=True,
+                   skips=(
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
+                                    dtypes=(torch.cfloat, torch.cdouble)),
+                       # AssertionError: Tensor-likes are not close!
+                       # Greatest absolute difference: nan at index (700,) (up to 0.01 allowed)
+                       # Greatest relative difference: nan at index (700,) (up to 0.001 allowed)
+                       DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_large',
+                                    dtypes=(torch.chalf,)),
+                   )),
+    UnaryUfuncInfo('sqrt',
+                   ref=np.sqrt,
+                   supports_sparse=True,
+                   domain=(0, None),
+                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                   dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
+                   assert_autodiffed=True,
+                   supports_forward_ad=True,
+                   supports_sparse_csr=True,
+                   supports_sparse_csc=True,
+                   supports_sparse_bsr=True,
+                   supports_sparse_bsc=True,
+                   supports_fwgrad_bwgrad=True,
+                   promotes_int_to_float=True,
+                   decorators=(
+                       precisionOverride({torch.bfloat16: 7e-2}),
+                       DecorateInfo(
+                           toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=0)}),
+                           'TestUnaryUfuncs', 'test_reference_numerics_large'),
+                   ),
+                   skips=(
+                       # Reference: https://github.com/pytorch/pytorch/issues/47358
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
+                                    device_type='cpu', dtypes=(torch.cfloat, torch.cdouble),
+                                    active_if=IS_MACOS),
+                       DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
+                                    'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
+                       DecorateInfo(toleranceOverride({torch.complex64: tol(atol=2e-5, rtol=3e-6)}),
+                                    "TestConsistency", "test_output_match", device_type="mps"),
+                   )),
+    UnaryUfuncInfo('square',
+                   ref=np.square,
+                   dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+                   decorators=(precisionOverride({torch.complex64: 3e-4, torch.bfloat16: 3e-1}),),
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   skips=(
+                       # Reference: https://github.com/pytorch/pytorch/issues/52549
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
+                                    dtypes=[torch.cfloat, torch.cdouble]),
+                       # >>> t = torch.tensor(complex(-0.01, float("inf")))
+                       # >>> np.square(t.numpy())
+                       # (-inf-infj)
+                       # >>> t.square()
+                       # tensor(-inf-infj)
+                       # >>> t.cuda().square()
+                       # tensor(inf+nanj, device='cuda:0')
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
+                                    device_type='cuda', dtypes=[torch.cfloat, torch.cdouble]),
+                       DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta_inplace',
+                                    dtypes=[torch.bool]),
+                       DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_meta_inplace',
+                                    dtypes=[torch.bool]),
+                       DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_inplace',
+                                    dtypes=[torch.bool]),
+                   ),),
+    OpInfo('lerp',
+           dtypes=floating_and_complex_types_and(torch.bfloat16, torch.half),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.chalf, torch.half, torch.bfloat16),
+           sample_inputs_func=sample_inputs_lerp,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           assert_autodiffed=True),
+    UnaryUfuncInfo('angle',
+                   ref=np.angle,
+                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
+                   dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool),
+                   decorators=(precisionOverride({torch.float16: 1e-2,
+                                                  torch.bfloat16: 1e-2}),),
+                   backward_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16),
+                   backward_dtypesIfCUDA=floating_and_complex_types_and(torch.chalf),
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   supports_sparse_csr=True,
+                   supports_sparse_csc=True,
+                   supports_sparse_bsr=True,
+                   supports_sparse_bsc=True,
+                   supports_complex_to_float=True,
+                   skips=(
+                       # Ref: https://github.com/pytorch/pytorch/issues/78413
+                       DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_small',
+                                    dtypes=(torch.bfloat16, torch.float16, torch.float32, torch.float64),),
+                   )),
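+    # The is* predicates below (isfinite, isinf, isposinf, isneginf, isreal, isnan) produce
+    # boolean outputs, so autograd is disabled for all of them; most also opt in to the
+    # sparse layouts (COO/CSR/CSC/BSR/BSC) via the supports_sparse* flags.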
+    UnaryUfuncInfo('isfinite',
+                   ref=np.isfinite,
+                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
+                   supports_out=False,
+                   supports_autograd=False),
+    UnaryUfuncInfo('isinf',
+                   ref=np.isinf,
+                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
+                   supports_out=False,
+                   supports_sparse=True,
+                   supports_sparse_csr=True,
+                   supports_sparse_csc=True,
+                   supports_sparse_bsr=True,
+                   supports_sparse_bsc=True,
+                   supports_autograd=False),
+    UnaryUfuncInfo('isposinf',
+                   ref=np.isposinf,
+                   dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16),
+                   supports_sparse=True,
+                   supports_sparse_csr=True,
+                   supports_sparse_csc=True,
+                   supports_sparse_bsr=True,
+                   supports_sparse_bsc=True,
+                   supports_autograd=False),
+    UnaryUfuncInfo('isneginf',
+                   ref=np.isneginf,
+                   dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16),
+                   supports_sparse=True,
+                   supports_sparse_csr=True,
+                   supports_sparse_csc=True,
+                   supports_sparse_bsr=True,
+                   supports_sparse_bsc=True,
+                   supports_autograd=False),
+    UnaryUfuncInfo('isreal',
+                   ref=np.isreal,
+                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
+                   supports_out=False,
+                   supports_autograd=False),
+    UnaryUfuncInfo('isnan',
+                   ref=np.isnan,
+                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
+                   supports_out=False,
+                   supports_sparse=True,
+                   supports_sparse_csr=True,
+                   supports_sparse_csc=True,
+                   supports_sparse_bsr=True,
+                   supports_sparse_bsc=True,
+                   supports_autograd=False),
+    OpInfo('einsum',
+           # we need this lambda because SampleInput expects tensor input as the first argument
+           # TODO(@heitorschueroff) update SampleInput to handle such cases
+           op=lambda tensors, equation: torch.einsum(equation, tensors),
+           dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
+           backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           check_batched_forward_grad=False,
+           # See https://github.com/pytorch/pytorch/issues/66357
+           sample_inputs_func=sample_inputs_einsum,
+           skips=(
+               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+               # test does not work with passing lambda for op
+               # there's a test `test_einsum` in `test_jit.py` to handle this case
+               # AssertionError: JIT Test does not execute any logic
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+           )),
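+    # The factorization entries below (svd, svd_lowrank, pca_lowrank) disable batched
+    # (grad)grad checks: svd because its checks go through at::allclose, which has no
+    # batching rule, and the low-rank ops because they are randomized and are therefore
+    # wrapped in wrapper_set_seed to keep their output deterministic across calls.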
+    OpInfo('svd',
+           op=torch.svd,
+           dtypes=floating_and_complex_types(),
+           sample_inputs_func=sample_inputs_svd,
+           # Runs very slowly on slow-gradcheck - alternatively reduce input sizes
+           gradcheck_fast_mode=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           check_batched_forward_grad=False,
+           # We're using at::allclose, which does not have a batching rule
+           check_batched_grad=False,
+           check_batched_gradgrad=False,
+           decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
+           skips=(
+               # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
+               DecorateInfo(
+                   unittest.skip("Skipped!"),
+                   'TestSchemaCheckModeOpInfo',
+                   'test_schema_correctness',
+                   dtypes=(torch.complex64, torch.complex128)),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out',
+                            device_type='mps', dtypes=[torch.float32]),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager',
+                            device_type='mps', dtypes=[torch.float32]),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',
+                            device_type='mps', dtypes=[torch.float32]),
+           )),
+    OpInfo('svd_lowrank',
+           op=lambda *args, **kwargs: wrapper_set_seed(
+               lambda a, b, **kwargs: torch.svd_lowrank(a @ b.mT, **kwargs),
+               *args, **kwargs
+           ),
+           dtypes=floating_and_complex_types(),
+           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+           gradcheck_fast_mode=True,
+           supports_out=False,
+           # Due to the use of randomness
+           check_batched_grad=False,
+           check_batched_gradgrad=False,
+           check_batched_forward_grad=False,
+           supports_fwgrad_bwgrad=True,
+           supports_forward_ad=True,
+           sample_inputs_func=sample_inputs_svd_lowrank,
+           decorators=[skipCUDAIfNoCusolver, skipCPUIfNoLapack, with_tf32_off,
+                       DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-03, rtol=1e-03),
+                                                       torch.complex64: tol(atol=1e-02, rtol=1e-02)}),
+                                    'TestCommon', 'test_noncontiguous_samples'),
+                       # FIXME This should be the following, but the toleranceOverride does not seem to do anything!
+                       # DecorateInfo(toleranceOverride({torch.complex128: tol(atol=1e-04, rtol=1e-04)}),
+                       #              'TestFwdGradients', 'test_fn_fwgrad_bwgrad'),
+                       DecorateInfo(unittest.skip("See comment above"),
+                                    'TestFwdGradients',
+                                    'test_fn_fwgrad_bwgrad',
+                                    dtypes=[torch.complex128]),
+                       ],
+           skips=(
+               # test does not work with passing lambda for op
+               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
+               # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
+               DecorateInfo(unittest.expectedFailure, 'TestSchemaCheckModeOpInfo', 'test_schema_correctness',
+                            dtypes=(torch.complex64, torch.complex128)),
+               DecorateInfo(slowTest, 'TestCompositeCompliance', 'test_forward_ad'),
+           )),
+    OpInfo('pca_lowrank',
+           op=lambda *args, **kwargs: wrapper_set_seed(
+               lambda a, b, **kwargs: torch.pca_lowrank(a @ b.mT, **kwargs),
+               *args, **kwargs
+           ),
+           dtypes=floating_and_complex_types(),
+           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+           gradcheck_fast_mode=True,
+           supports_out=False,
+           check_batched_forward_grad=False,
+           check_batched_grad=False,
+           check_batched_gradgrad=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           sample_inputs_func=sample_inputs_pca_lowrank,
+           decorators=[skipCUDAIfNoCusolver, skipCPUIfNoLapack, with_tf32_off,
+                       DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-03, rtol=1e-03),
+                                                       torch.complex64: tol(atol=4e-02, rtol=4e-02)}),
+                                    'TestCommon', 'test_noncontiguous_samples'),
+                       DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-05, rtol=5e-05)}),
+                                    'TestOperators', 'test_grad'),
+                       # FIXME This should be the following, but the toleranceOverride does not seem to do anything!
+                       # DecorateInfo(toleranceOverride({torch.complex128: tol(atol=1e-04, rtol=1e-04)}),
+                       #              'TestFwdGradients', 'test_fn_fwgrad_bwgrad'),
+                       DecorateInfo(unittest.skip("See comment above"),
+                                    'TestFwdGradients',
+                                    'test_fn_fwgrad_bwgrad',
+                                    dtypes=[torch.complex128]),
+                       DecorateInfo(
+                           toleranceOverride({torch.float32: tol(atol=3e-5, rtol=1e-3)}),
+                           'TestInductorOpInfo', 'test_comprehensive', device_type='cuda'),
+                       ],
+           skips=(
+               # test does not work with passing lambda for op
+               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+               # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
+               DecorateInfo(unittest.expectedFailure, 'TestSchemaCheckModeOpInfo', 'test_schema_correctness',
+                            dtypes=(torch.complex64, torch.complex128)),
+               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
+           )),
+    BinaryUfuncInfo('polar',
+                    dtypes=floating_types(),
+                    # polar(abs, angle) is undefined when 'abs' is negative, so lhs samples are restricted to low=0
+                    supports_forward_ad=True,
+                    lhs_make_tensor_kwargs=dict(low=0),
+                    supports_rhs_python_scalar=False,
+                    skips=(
+                        # RuntimeError: Expected object of scalar type Float but got scalar type Double for second argument
+                        DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs', 'test_type_promotion'),
+                        DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'),
+                        # GradcheckError: Jacobian computed with forward mode mismatch for output 0 with respect to input 0
+                        # Numerical:
+                        #  tensor([[0.]], dtype=torch.float64)
+                        # Analytical:
+                        # tensor([[-0.0047]], dtype=torch.float64, grad_fn=)
+                        DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'),
+                    )),
+    # TODO(@kshitij12345): Refactor similar to `mvlgamma` entries.
+    # To test reference numerics against multiple values of argument `n`,
+    # we make multiple OpInfo entries, each corresponding to a different value of n (currently 0 to 4).
+    # We run the op tests from test_ops.py only for `n=0` to avoid redundancy in testing.
+    UnaryUfuncInfo('polygamma',
+                   op=lambda x, n, **kwargs: torch.polygamma(n, x, **kwargs),
+                   variant_test_name='polygamma_n_0',
+                   ref=reference_polygamma if TEST_SCIPY else None,
+                   dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
+                   dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   promotes_int_to_float=True,
+                   sample_inputs_func=sample_inputs_polygamma,
+                   skips=(
+                       DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+                   ),
+                   sample_kwargs=lambda device, dtype, input: ({'n': 0}, {'n': 0}),
+                   # polygamma functions have singularities at non-positive integer values of x
+                   reference_numerics_filter=NumericsFilter(condition=lambda x: (x < 0.1) & ((x - x.round()).abs() < 1e-4),
+                                                            safe_val=1)),
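+    # The star-unpacked generator below creates one 'polygamma' variant per n in (1, 2, 3, 4);
+    # most test suites are skipped for these variants since the n=0 entry above already runs them.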
+    *(UnaryUfuncInfo('polygamma',
+                     op=lambda x, n, **kwargs: torch.polygamma(n, x, **kwargs),
+                     variant_test_name=f'polygamma_n_{n_}',
+                     ref=reference_polygamma if TEST_SCIPY else None,
+                     dtypes=all_types_and(torch.bool, torch.bfloat16),
+                     dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
+                     supports_forward_ad=True,
+                     supports_fwgrad_bwgrad=True,
+                     promotes_int_to_float=True,
+                     sample_inputs_func=sample_inputs_polygamma,
+                     decorators=(
+                         DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-3)}), 'TestUnaryUfuncs'),
+                         DecorateInfo(toleranceOverride({torch.bfloat16: tol(atol=1e1, rtol=1e-1),
+                                                         torch.float32: tol(atol=1e-4, rtol=1e-2)}),
+                                      'TestUnaryUfuncs', 'test_reference_numerics_normal',
+                                      active_if=IS_WINDOWS),
+                     ),
+                     skips=(
+                         # Redundant tests
+                         DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients'),
+                         DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients'),
+                         DecorateInfo(unittest.skip("Skipped!"), 'TestJit'),
+                         DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators'),
+                         DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'),
+                         # Mismatch: https://github.com/pytorch/pytorch/issues/55357
+                         DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal'),
+                         DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large'),
+                     ),
+                     sample_kwargs=lambda device, dtype, input: ({'n': n_}, {'n': n_}),
+                     # polygamma functions have singularities at non-positive integer values of x
+                     reference_numerics_filter=NumericsFilter(condition=lambda x: (x < 0.1) & ((x - x.round()).abs() < 1e-4),
+                                                              safe_val=1))
+      for n_ in (1, 2, 3, 4)),
+    OpInfo('ravel',
+           ref=np.ravel,
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # See https://github.com/pytorch/pytorch/pull/78358
+           check_batched_forward_grad=False,
+           sample_inputs_func=sample_inputs_ravel,
+           ),
+    OpInfo('unravel_index',
+           ref=np.unravel_index,
+           dtypes=integral_types_and(),
+           supports_out=False,
+           supports_autograd=False,
+           sample_inputs_func=sample_inputs_unravel_index,
+           ),
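+    # reshape, reshape_as, view and view_as below share the same sample/reference/error input
+    # helpers; the *_as variants bind a lambda so that the second tensor argument supplies the
+    # target shape.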
+    OpInfo('reshape',
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
+           sample_inputs_func=sample_inputs_view_reshape,
+           reference_inputs_func=reference_inputs_view_reshape,
+           error_inputs_func=error_inputs_view_reshape,
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           ),
+    OpInfo('reshape_as',
+           op=lambda x, other: x.reshape_as(other),
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
+           sample_inputs_func=partial(sample_inputs_view_reshape, tensor_arg=True),
+           reference_inputs_func=partial(reference_inputs_view_reshape, tensor_arg=True),
+           error_inputs_func=partial(error_inputs_view_reshape, tensor_arg=True),
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           skips=(
+               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+           )),
+    OpInfo('view',
+           op=lambda x, shape: x.view(shape),
+           dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           assert_jit_shape_analysis=True,
+           sample_inputs_func=sample_inputs_view_reshape,
+           reference_inputs_func=reference_inputs_view_reshape,
+           error_inputs_func=error_inputs_view_reshape,
+           skips=(
+               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+               # RuntimeError: view size is not compatible with input tensor's size and stride
+               # (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
+               DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
+           )),
+    OpInfo('view_as',
+           op=lambda x, other: x.view_as(other),
+           dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           sample_inputs_func=partial(sample_inputs_view_reshape, tensor_arg=True),
+           reference_inputs_func=partial(reference_inputs_view_reshape, tensor_arg=True),
+           error_inputs_func=partial(error_inputs_view_reshape, tensor_arg=True),
+           skips=(
+               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+               # RuntimeError: view size is not compatible with input tensor's size and stride
+               DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides")
+           )),
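+    # atleast_1d/2d/3d take a variadic list of tensors, which TorchScript cannot represent;
+    # hence the TestJit and TestNormalizeOperators expected failures in each entry below.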
+    OpInfo('atleast_1d',
+           dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
+           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+           gradcheck_fast_mode=True,
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # See https://github.com/pytorch/pytorch/pull/78358
+           check_batched_forward_grad=False,
+           sample_inputs_func=sample_inputs_atleast1d2d3d,
+           skips=(
+               # JIT does not support variadic tensors.
+               # RuntimeError: input->type()->kind() == TypeKind::OptionalType
+               # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252,
+               # please report a bug to PyTorch.
+               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]),
+           ),
+           ),
+    OpInfo('atleast_2d',
+           dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
+           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+           gradcheck_fast_mode=True,
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # See https://github.com/pytorch/pytorch/pull/78358
+           check_batched_forward_grad=False,
+           skips=(
+               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]),
+           ),
+           sample_inputs_func=sample_inputs_atleast1d2d3d,
+           ),
+    OpInfo('atleast_3d',
+           dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
+           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+           gradcheck_fast_mode=True,
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # See https://github.com/pytorch/pytorch/pull/78358
+           check_batched_forward_grad=False,
+           skips=(
+               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]),
+           ),
+           sample_inputs_func=sample_inputs_atleast1d2d3d,
+           ),
+    OpInfo('flatten',
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
+           ref=reference_flatten,
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # See https://github.com/pytorch/pytorch/pull/78358
+           check_batched_forward_grad=False,
+           sample_inputs_func=sample_inputs_flatten,
+           reference_inputs_func=reference_inputs_flatten,
+           ),
+    OpInfo('unflatten',
+           op=torch.unflatten,
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           sample_inputs_func=sample_inputs_unflatten,
+           ),
+    OpInfo('column_stack',
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # See https://github.com/pytorch/pytorch/pull/78358
+           check_batched_forward_grad=False,
+           sample_inputs_func=sample_inputs_column_stack,),
+    OpInfo('pinverse',
+           op=torch.pinverse,
+           dtypes=floating_and_complex_types(),
+           check_batched_grad=False,
+           check_batched_gradgrad=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+           supports_out=False,
+           sample_inputs_func=sample_inputs_linalg_invertible,
+           decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
+           skips=(
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager',
+                            device_type='mps', dtypes=[torch.float32]),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',
+                            device_type='mps', dtypes=[torch.float32]),
+           )),
+    OpInfo('gather',
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+           dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+           sample_inputs_func=sample_inputs_gather,
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           error_inputs_func=error_inputs_gather,
+           ),
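+    # The index_fill/index_copy/index_select/index_add entries below share sample_inputs_index
+    # (with reference=True for the reference inputs); several set gradcheck_nondet_tol because
+    # their gradients can be accumulated in a nondeterministic order on CUDA.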
+    OpInfo('index_fill',
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.complex32),
+           inplace_variant=torch.Tensor.index_fill_,
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # https://github.com/pytorch/pytorch/issues/66357
+           check_batched_forward_grad=False,
+           skips=(
+               # RuntimeError: Mismatch on aten._unique.default: Shapes torch.Size([2]) and torch.Size([1]) are not equal!
+               DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_crossref_backward_no_amp'),
+               # RuntimeError: Mismatch on aten._unique.default: Shapes torch.Size([2]) and torch.Size([1]) are not equal!
+               DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_crossref_backward_amp'),
+           ),
+           sample_inputs_func=sample_inputs_index,
+           reference_inputs_func=partial(sample_inputs_index, reference=True)),
+    OpInfo('index_copy',
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.complex32),
+           supports_out=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # https://github.com/pytorch/pytorch/issues/66357
+           check_batched_forward_grad=False,
+           sample_inputs_func=sample_inputs_index,
+           reference_inputs_func=partial(sample_inputs_index, reference=True),
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL),
+    OpInfo('index_select',
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
+           backward_dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16, torch.chalf),
+           sample_inputs_func=sample_inputs_index,
+           reference_inputs_func=partial(sample_inputs_index, reference=True),
+           error_inputs_func=error_inputs_index_select,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           assert_jit_shape_analysis=True,
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL),
+    OpInfo('index_add',
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
+           inplace_variant=torch.Tensor.index_add_,
+           supports_out=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # https://github.com/pytorch/pytorch/issues/66357
+           check_batched_forward_grad=False,
+           sample_inputs_func=sample_inputs_index,
+           reference_inputs_func=partial(sample_inputs_index, reference=True),
+           error_inputs_func=error_inputs_index_add,
+           skips=(
+               # boolean alpha not handled properly
+               DecorateInfo(unittest.expectedFailure,
+                            'TestNNCOpInfo',
+                            'test_nnc_correctness',
+                            dtypes=(torch.bool,)),
+           ),
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL),
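+    # index_reduce is parametrized: the starred generator below expands into one
+    # OpInfo per reduction type ('mean', 'prod', 'amin', 'amax'), each registered
+    # under its own variant_test_name.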
+    *(OpInfo('index_reduce',
+             variant_test_name=reduction_type,
+             dtypes=all_types_and(torch.float16, torch.bfloat16),
+             skips=(
+                 DecorateInfo(toleranceOverride({torch.float16: tol(atol=2e-3, rtol=3e-3)}),
+                              'TestInductorOpInfo', 'test_comprehensive'),
+             ),
+             supports_out=True,
+             sample_inputs_func=sample_inputs_index_reduce,
+             ) for reduction_type in ('mean', 'prod', 'amin', 'amax')),
+    OpInfo('_unsafe_masked_index',
+           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16, torch.bool),
+           supports_out=False,
+           supports_inplace_autograd=False,
+           supports_scripting=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           sample_inputs_func=sample_inputs__unsafe_masked_index,
+           skips=(
+               DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
+               DecorateInfo(slowTest, 'TestDecomp', 'test_quick_core_backward',
+                            dtypes=(torch.float64,), active_if=IS_WINDOWS),
+           ),),
+    OpInfo('_unsafe_masked_index_put_accumulate',
+           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16, torch.bool),
+           supports_out=False,
+           supports_inplace_autograd=False,
+           supports_scripting=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           decorators=(
+               DecorateInfo(
+                   toleranceOverride({torch.float16: tol(atol=2e-3, rtol=3e-2)}),
+                   'TestInductorOpInfo', 'test_comprehensive', device_type='cpu'
+               ),
+           ),
+           sample_inputs_func=sample_inputs__unsafe_masked_index_put_accumulate,
+           skips=(
+               DecorateInfo(slowTest, 'TestDecomp', 'test_quick_core_backward',
+                            dtypes=(torch.float64,), active_if=IS_WINDOWS),
+           ),),
+    OpInfo('__getitem__',
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
+           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+           gradcheck_fast_mode=True,
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           supports_inplace_autograd=False,
+           supports_scripting=False,
+           op=torch.Tensor.__getitem__,
+           skips=(
+               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+               # AssertionError: False is not true : Scalars failed to compare as equal! 0 != 104448
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', device_type='cuda'),),
+           sample_inputs_func=sample_inputs_getitem),
+    OpInfo('index_put',
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
+           supports_out=False,
+           supports_inplace_autograd=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # https://github.com/pytorch/pytorch/issues/66357
+           check_batched_forward_grad=False,
+           test_neg_view=False,
+           sample_inputs_func=sample_inputs_index_put,
+           skips=(
+               DecorateInfo(unittest.skip("Skipped"), 'TestBwdGradients', 'test_fn_grad', dtypes=[torch.float64],
+                            device_type='cuda', active_if=(TEST_WITH_ROCM and TEST_WITH_TORCHINDUCTOR)),
+           )),
+    OpInfo('sort',
+           dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
+           dtypesIfCUDA=all_types_and(torch.bool, torch.float16, torch.bfloat16),
+           sample_inputs_func=sample_inputs_sort,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           skips=(
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_non_standard_bool_values',
+                            dtypes=[torch.bool], device_type='cuda', active_if=not TEST_WITH_ROCM),
+           )),
+    OpInfo('unique',
+           dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16, torch.uint16, torch.uint32, torch.uint64),
+           sample_inputs_func=sample_inputs_unique,
+           supports_out=False,
+           supports_autograd=False,
+           skips=(
+               # lambda impl
+               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+               DecorateInfo(unittest.skip('Output order is undefined when sorted=False'), 'TestCommon', 'test_compare_cpu'),
+           )),
+    OpInfo('unique_consecutive',
+           dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
+           sample_inputs_func=sample_inputs_unique_consecutive,
+           supports_out=False,
+           supports_autograd=False,
+           skips=(
+               # lambda impl
+               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+           )),
+    OpInfo('put',
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           check_batched_forward_grad=False,
+           check_batched_gradgrad=False,  # vmap complains of the sizes
+           sample_inputs_func=sample_inputs_put),
+    OpInfo('take',
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+           check_batched_grad=False,  # vmap complains of the sizes
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           sample_inputs_func=sample_inputs_take,
+           error_inputs_func=error_inputs_take),
+    OpInfo('scatter',
+           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           sample_inputs_func=sample_inputs_scatter,
+           error_inputs_func=error_inputs_scatter_and_scatter_add,
+           skips=(
+               # Compiler issue on ROCm. Regression started in ROCm 6.4.
+               DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_non_standard_bool_values',
+                            dtypes=[torch.bool], active_if=TEST_WITH_ROCM),
+           )),
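+    # Tensor dtype-conversion methods (.bfloat16(), .bool(), .byte(), ..., .chalf())
+    # are modeled below as unary ufuncs. Because the output dtype differs from the
+    # input dtype, most of them either disable autograd or mark the gradient test
+    # suites as expected failures.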
+    UnaryUfuncInfo(
+        'bfloat16',
+        op=lambda x, *args, **kwargs: x.bfloat16(*args, **kwargs),
+        dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
+        supports_out=False,
+        sample_inputs_func=sample_inputs_conversion,
+        skips=(
+            # autograd tests don't handle operators that change dtype
+            DecorateInfo(unittest.expectedFailure, 'TestFwdGradients'),
+            DecorateInfo(unittest.expectedFailure, 'TestBwdGradients'),
+            DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+            # RuntimeError: attribute lookup is not defined on builtin
+            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
+        )),
+    UnaryUfuncInfo(
+        'bool',
+        op=lambda x, *args, **kwargs: x.bool(*args, **kwargs),
+        dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
+        supports_out=False,
+        sample_inputs_func=sample_inputs_conversion,
+        supports_autograd=False,
+        skips=(
+            DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+            # RuntimeError: attribute lookup is not defined on builtin
+            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+        )),
+    UnaryUfuncInfo(
+        'byte',
+        op=lambda x, *args, **kwargs: x.byte(*args, **kwargs),
+        dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+        supports_out=False,
+        sample_inputs_func=sample_inputs_byte,
+        # The autograd test runner cannot handle functions that change dtype
+        supports_autograd=False,
+        skips=(
+            DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+            # RuntimeError: attribute lookup is not defined on builtin
+            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+            DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'),
+        )),
+    UnaryUfuncInfo(
+        'char',
+        op=lambda x, *args, **kwargs: x.char(*args, **kwargs),
+        dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
+        supports_out=False,
+        sample_inputs_func=sample_inputs_conversion,
+        # The autograd test runner cannot handle functions that change dtype
+        supports_autograd=False,
+        skips=(
+            DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+            # RuntimeError: attribute lookup is not defined on builtin
+            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+            DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'),
+        )),
+    UnaryUfuncInfo(
+        'double',
+        op=lambda x, *args, **kwargs: x.double(*args, **kwargs),
+        dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
+        supports_out=False,
+        sample_inputs_func=sample_inputs_conversion,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        skips=(
+            DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+            # RuntimeError: attribute lookup is not defined on builtin
+            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+        )),
+    UnaryUfuncInfo(
+        'float',
+        op=lambda x, *args, **kwargs: x.float(*args, **kwargs),
+        dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
+        supports_out=False,
+        sample_inputs_func=sample_inputs_conversion,
+        skips=(
+            # autograd tests don't handle operators that change dtype
+            DecorateInfo(unittest.expectedFailure, 'TestFwdGradients'),
+            DecorateInfo(unittest.expectedFailure, 'TestBwdGradients'),
+            DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+            # RuntimeError: attribute lookup is not defined on builtin
+            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+        )),
+    UnaryUfuncInfo(
+        'half',
+        op=lambda x, *args, **kwargs: x.half(*args, **kwargs),
+        dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+        supports_out=False,
+        sample_inputs_func=sample_inputs_conversion,
+        supports_autograd=True,
+        skips=(
+            # autograd tests don't handle operators that change dtype
+            DecorateInfo(unittest.expectedFailure, 'TestFwdGradients'),
+            DecorateInfo(unittest.expectedFailure, 'TestBwdGradients'),
+            DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+            # RuntimeError: attribute lookup is not defined on builtin
+            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+        )),
+    UnaryUfuncInfo(
+        'int',
+        op=lambda x, *args, **kwargs: x.int(*args, **kwargs),
+        dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+        supports_out=False,
+        sample_inputs_func=sample_inputs_conversion,
+        supports_autograd=False,
+        skips=(
+            DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+            # RuntimeError: attribute lookup is not defined on builtin
+            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+            DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'),
+        )),
+    UnaryUfuncInfo(
+        'long',
+        op=lambda x, *args, **kwargs: x.long(*args, **kwargs),
+        dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
+        supports_out=False,
+        sample_inputs_func=sample_inputs_conversion,
+        supports_autograd=False,
+        skips=(
+            DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+            # RuntimeError: attribute lookup is not defined on builtin
+            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+            DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'),
+        )),
+    UnaryUfuncInfo(
+        'short',
+        op=lambda x, *args, **kwargs: x.short(*args, **kwargs),
+        dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+        supports_out=False,
+        sample_inputs_func=sample_inputs_conversion,
+        supports_autograd=False,
+        skips=(
+            DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+            # RuntimeError: attribute lookup is not defined on builtin
+            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+            DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'),
+        )),
+    UnaryUfuncInfo(
+        'cdouble',
+        op=torch.Tensor.cdouble,
+        dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
+        supports_out=False,
+        sample_inputs_func=sample_inputs_conversion,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        skips=(
+            DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+            # RuntimeError: attribute lookup is not defined on builtin
+            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
+        )),
+    UnaryUfuncInfo(
+        'cfloat',
+        op=torch.Tensor.cfloat,
+        dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
+        supports_out=False,
+        sample_inputs_func=sample_inputs_conversion,
+        skips=(
+            # autograd tests don't handle operators that change dtype
+            DecorateInfo(unittest.expectedFailure, 'TestFwdGradients'),
+            DecorateInfo(unittest.expectedFailure, 'TestBwdGradients'),
+            DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+            # RuntimeError: attribute lookup is not defined on builtin
+            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
+        )),
+    UnaryUfuncInfo(
+        'chalf',
+        op=lambda x, *args, **kwargs: x.chalf(*args, **kwargs),
+        dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
+        supports_out=False,
+        sample_inputs_func=sample_inputs_conversion,
+        skips=(
+            # autograd tests don't handle operators that change dtype
+            DecorateInfo(unittest.expectedFailure, 'TestFwdGradients'),
+            DecorateInfo(unittest.expectedFailure, 'TestBwdGradients'),
+            # use of lambda doesn't work with test_normalize_operator_exhaustive
+            DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+            # RuntimeError: "sum_cpu" not implemented for 'ComplexHalf'
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager',
+                         device_type='cpu'),
+            # TypeError: 'int' object is not iterable
+            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+            # RuntimeError: "sum_cpu" not implemented for 'ComplexHalf'
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view',
+                         device_type='cpu'),
+            # RuntimeError: "sum_cpu" not implemented for 'ComplexHalf'
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view',
+                         device_type='cpu'),
+            # RuntimeError: "sum_cpu" not implemented for 'ComplexHalf'
+            # RuntimeError: "neg_conj_cuda" not implemented for 'ComplexHalf'
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
+        )
+    ),
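+    # Factory-style and *_like ops follow. The empty* variants return uninitialized
+    # memory, so every test that compares output values is skipped for them.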
+    OpInfo('empty_like',
+           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
+           supports_out=False,
+           sample_inputs_func=sample_inputs_like_fns,
+           reference_inputs_func=reference_inputs_like_fns,
+           supports_autograd=False,
+           skips=(
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"),
+                            "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_complex_half_reference_testing'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values'),
+               DecorateInfo(unittest.skip("Expected: empty_like is not comparable"), 'TestCompositeCompliance',
+                            'test_operator'),
+               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
+           )),
+    OpInfo('zeros_like',
+           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
+           supports_out=False,
+           sample_inputs_func=sample_inputs_like_fns,
+           supports_autograd=False,
+           error_inputs_sparse_func=error_inputs_sparse_like_fns,
+           sample_inputs_sparse_coo_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_coo),
+           sample_inputs_sparse_csr_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_csr),
+           sample_inputs_sparse_csc_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_csc),
+           sample_inputs_sparse_bsr_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_bsr),
+           sample_inputs_sparse_bsc_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_bsc),
+           skips=(
+           )),
+    OpInfo('ones_like',
+           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
+           supports_out=False,
+           sample_inputs_func=sample_inputs_like_fns,
+           supports_autograd=False,
+           skips=(
+           )),
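+    # Random factories (randn, randn_like, rand_like, randint, ...) are wrapped with
+    # wrapper_set_seed, which reseeds the RNG before each call so that sampled
+    # outputs are reproducible across test reruns.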
+    OpInfo('randn',
+           dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.complex32),
+           op=lambda *args, **kwargs: wrapper_set_seed(torch.randn, *args, **kwargs),
+           supports_out=True,
+           sample_inputs_func=sample_inputs_randn,
+           supports_autograd=False,
+           skips=(
+               # Tests that assume input is a tensor or sequence of tensors
+               DecorateInfo(unittest.skip("Test expects tensor input"), "TestCommon", "test_noncontiguous_samples"),
+               DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"),
+               DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"),
+               # CPU randn generates different values based on the strides of out tensor
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type='cpu'),
+               # randn fails to warn when resizing its out tensor
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
+               # FX failed to normalize op - add the op to the op_skip list.
+               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+               # Tests that assume input tensor has a meaningful effect on output tensor
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
+               # AssertionError: JIT Test does not execute any logic
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+               DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'),
+           )),
+    OpInfo('randn_like',
+           dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.complex32),
+           op=lambda inp, *args, **kwargs:
+               wrapper_set_seed(torch.randn_like, inp, *args, **kwargs),
+           supports_out=False,
+           sample_inputs_func=sample_inputs_like_fns,
+           supports_autograd=False,
+           error_inputs_sparse_func=error_inputs_sparse_like_fns,
+           sample_inputs_sparse_coo_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_coo),
+           sample_inputs_sparse_csr_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_csr),
+           sample_inputs_sparse_csc_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_csc),
+           sample_inputs_sparse_bsr_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_bsr),
+           sample_inputs_sparse_bsc_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_bsc),
+           skips=(
+               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+               # AssertionError: JIT Test does not execute any logic
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+               DecorateInfo(unittest.skip("Expected: randn_like is not comparable between dtypes"),
+                            'TestCommon', 'test_complex_half_reference_testing'),
+               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
+           )),
+    OpInfo('rand_like',
+           dtypes=floating_types_and(torch.half, torch.bfloat16, torch.complex32, torch.complex64, torch.complex128),
+           op=lambda inp, *args, **kwargs:
+               wrapper_set_seed(torch.randn_like, inp, *args, **kwargs),
+           supports_out=False,
+           sample_inputs_func=sample_inputs_like_fns,
+           supports_autograd=False,
+           skips=(
+               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+               # AssertionError: JIT Test does not execute any logic
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+               DecorateInfo(unittest.skip("Expected: randn_like is not comparable between dtypes"),
+                            'TestCommon', 'test_complex_half_reference_testing'),
+               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
+           )),
+    OpInfo('randint',
+           dtypes=all_types_and(torch.half, torch.bfloat16),
+           op=lambda *args, **kwargs:
+               wrapper_set_seed(torch.randint, *args, **kwargs),
+           supports_out=False,
+           sample_inputs_func=sample_inputs_randint,
+           supports_autograd=False,
+           skips=(
+               # Tests that assume input is a tensor or sequence of tensors
+               DecorateInfo(unittest.skip("Test expects tensor input"), "TestCommon", "test_noncontiguous_samples"),
+               DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"),
+               DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"),
+               # CPU randint generates different values based on the strides of out tensor
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
+               # randint fails to warn when resizing its out tensor
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
+               # FX failed to normalize op - add the op to the op_skip list.
+               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+               # Tests that assume input tensor has a meaningful effect on output tensor
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
+               # AssertionError: JIT Test does not execute any logic
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+               # Might need to skip until ROCm 5.5
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_multiple_devices',
+                            dtypes=[torch.float32, torch.int64], active_if=TEST_WITH_ROCM),
+           )),
+    OpInfo('randint_like',
+           dtypes=all_types_and(torch.half, torch.bfloat16),
+           op=lambda inp, *args, **kwargs:
+               wrapper_set_seed(torch.randint_like, inp, *args, **kwargs),
+           supports_out=False,
+           sample_inputs_func=sample_inputs_randint_like,
+           supports_autograd=False,
+           skips=(
+               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+               # AssertionError: JIT Test does not execute any logic
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
+           )),
+    OpInfo('full_like',
+           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16,
+                                            torch.uint16, torch.uint32),
+           supports_out=False,
+           sample_inputs_func=sample_inputs_full_like,
+           supports_autograd=False,
+           ),
+    OpInfo('new_zeros',
+           op=lambda x, *args, **kwargs: x.new_zeros(*args, **kwargs),
+           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
+           supports_out=False,
+           sample_inputs_func=sample_inputs_new_fns,
+           skips=(
+               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+           ),
+           supports_autograd=False),
+    OpInfo('new_ones',
+           op=lambda x, *args, **kwargs: x.new_ones(*args, **kwargs),
+           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
+           supports_out=False,
+           sample_inputs_func=sample_inputs_new_fns,
+           skips=(
+               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+           ),
+           supports_autograd=False),
+    OpInfo('ones',
+           op=torch.ones,
+           supports_autograd=False,
+           supports_varargs=True,
+           is_factory_function=True,
+           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
+           supports_out=True,
+           sample_inputs_func=sample_inputs_ones_zeros,
+           skips=(
+               # Tests that assume input is a tensor or sequence of tensors
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
+
+               # Same failure as arange: cannot find linspace in captured graph
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+
+               # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
+           )),
+    OpInfo('zeros',
+           op=torch.zeros,
+           supports_autograd=False,
+           is_factory_function=True,
+           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
+           supports_out=True,
+           sample_inputs_func=sample_inputs_ones_zeros,
+           skips=(
+               # Tests that assume input is a tensor or sequence of tensors
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
+
+               # Same failure as arange: cannot find linspace in captured graph
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+
+               # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
+           )),
+    OpInfo('full',
+           op=torch.full,
+           supports_autograd=False,
+           is_factory_function=True,
+           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
+           supports_out=True,
+           sample_inputs_func=sample_inputs_full,
+           skips=(
+               # Tests that assume input is a tensor or sequence of tensors
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
+               # Same failure as arange: cannot find linspace in captured graph
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+               # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
+               # RuntimeError: UNSUPPORTED DTYPE: bool
+               DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness', dtypes=(torch.bool,)),
+           )),
+    OpInfo('new_empty',
+           op=lambda x, *args, **kwargs: x.new_empty(*args, **kwargs),
+           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
+           supports_out=False,
+           sample_inputs_func=sample_inputs_new_fns,
+           skips=(
+               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values'),
+               DecorateInfo(unittest.skip("Expected: new_empty is not comparable"), 'TestCompositeCompliance',
+                            'test_operator'),
+               DecorateInfo(unittest.skip("Expected: new_empty is not comparable"),
+                            'TestCommon', 'test_complex_half_reference_testing'),
+               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
+           ),
+           supports_autograd=False),
+    OpInfo('new_empty_strided',
+           op=lambda x, *args, **kwargs: x.new_empty_strided(*args, **kwargs),
+           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
+           supports_out=False,
+           sample_inputs_func=partial(sample_inputs_new_fns, is_strided=True),
+           supports_autograd=False,
+           skips=(
+               # FX failed to normalize op
+               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+               # Lazy tensor failures
+               DecorateInfo(unittest.skip("Skipped!"), 'TestLazyOpInfo', 'test_correctness'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestLazyOpInfo', 'test_correctness_with_reusing_ir'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"),
+                            'TestCommon', 'test_variant_consistency_eager'),
+               DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"),
+                            'TestCommon', 'test_noncontiguous_samples'),
+               DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"),
+                            'TestMathBits', 'test_conj_view'),
+               DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"),
+                            'TestMathBits', 'test_neg_view'),
+               DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"),
+                            'TestMathBits', 'test_neg_conj_view'),
+               DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"),
+                            'TestCommon', 'test_non_standard_bool_values'),
+               DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"),
+                            'TestCommon', 'test_complex_half_reference_testing'),
+               DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"),
+                            'TestCompositeCompliance', 'test_operator'),
+               DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"),
+                            'TestDecomp', 'test_comprehensive'),
+               DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"),
+                            'TestDecomp', 'test_quick'),
+               DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"),
+                            'TestJit', 'test_variant_consistency_jit'),
+               DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"),
+                            'TestProxyTensorOpInfo', 'test_make_fx_exhaustive'),
+               DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"),
+                            'TestProxyTensorOpInfo', 'test_make_fx_fake_exhaustive'),
+               DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"),
+                            'TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive'),
+               DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"),
+                            'TestNNCOpInfo', 'test_nnc_correctness'),
+               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
+           )),
+    OpInfo('empty_strided',
+           op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.empty_strided, inp, *args, **kwargs),
+           dtypes=all_types_and_complex_and(torch.bfloat16, torch.bool, torch.half),
+           supports_out=False,
+           supports_autograd=False,
+           sample_inputs_func=sample_inputs_empty_strided,
+           skips=(
+               # FX failed to normalize op - add the op to the op_skip list.
+               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+               # AssertionError: JIT Test does not execute any logic
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'),
+               DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_compare_cpu'),
+               DecorateInfo(unittest.skip("Expected: empty is not comparable"), 'TestCompositeCompliance', 'test_operator'),
+               # Lazy tensor failures
+               DecorateInfo(unittest.skip("Expected: empty is not comparable"), 'TestLazyOpInfo'),
+               # RuntimeError: unsupported operation: more than one element of the written-to tensor refers to a single
+               # memory location. Please clone() the tensor before performing the operation.
+               DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_meta_outplace'),
+               DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_outplace'),
+               DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides'),
+           )),
+    OpInfo('empty',
+           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
+           sample_inputs_func=sample_inputs_empty,
+           supports_autograd=False,
+           skips=(
+               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values'),
+               DecorateInfo(unittest.skip("Expected: empty is not comparable"), 'TestCompositeCompliance',
+                            'test_operator'),
+               # requires_grad doesn't exist in the jit schema
+               DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
+               DecorateInfo(unittest.skip("Expected: empty is not comparable"),
+                            'TestCommon',
+                            'test_out'),
+               DecorateInfo(unittest.skip("Expected: empty is not comparable"),
+                            'TestCommon',
+                            'test_out_warning'),
+               DecorateInfo(unittest.skip("Expected: empty is not comparable"),
+                            'TestLazyOpInfo'),
+               DecorateInfo(unittest.skip("Expected: empty is not comparable"),
+                            'TestCommon', 'test_complex_half_reference_testing'),
+               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
+           )),
+    OpInfo('eye',
+           dtypes=all_types_complex_float8_and(torch.bool, torch.half, torch.bfloat16),
+           sample_inputs_func=sample_inputs_eye,
+           error_inputs_func=error_inputs_eye,
+           supports_out=True,
+           supports_autograd=False,
+           skips=(
+               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+               # TODO: same as this?
+               # https://github.com/pytorch/pytorch/issues/81774
+               # also see: arange, new_full
+               # fails to match any schemas despite working in the interpreter
+               DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
+               # fails to match any schemas despite working in the interpreter
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+               # skip these tests since we have non tensor input
+               DecorateInfo(unittest.skip('Skipped!'), "TestCommon", "test_noncontiguous_samples"),
+               DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
+               # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
+               # "mul_cpu_reduced_float" not implemented for 'Float8_e4m3fn'
+               DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness',
+                            dtypes=(torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz)),
+           )),
+    OpInfo('empty_permuted',
+           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
+           sample_inputs_func=sample_inputs_empty_permuted,
+           error_inputs_func=error_inputs_empty_permuted,
+           supports_out=False,
+           supports_autograd=False,
+           skips=(
+               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values'),
+               DecorateInfo(unittest.skip("Expected: empty_permuted is not comparable"), 'TestCompositeCompliance',
+                            'test_operator'),
+               # requires_grad doesn't exist in the jit schema
+               DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
+               DecorateInfo(unittest.skip("Expected: empty_permuted is not comparable"),
+                            'TestCommon',
+                            'test_out'),
+               DecorateInfo(unittest.skip("Expected: empty_permuted is not comparable"),
+                            'TestCommon',
+                            'test_out_warning'),
+               DecorateInfo(unittest.skip("Expected: empty_permuted is not comparable"),
+                            'TestLazyOpInfo'),
+               DecorateInfo(unittest.skip("Expected: empty_permuted is not comparable"),
+                            'TestCommon', 'test_complex_half_reference_testing'),
+               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
+           )),
+    OpInfo('scalar_tensor',
+           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
+           sample_inputs_func=sample_inputs_scalar_tensor,
+           supports_autograd=False,
+           supports_out=False,
+           skips=(
+               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+               # fails to match any schemas despite working in the interpreter
+               DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
+               # fails to match any schemas despite working in the interpreter
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+               # skip these tests since we have non tensor input
+               DecorateInfo(unittest.skip('Skipped!'), "TestCommon", "test_noncontiguous_samples"),
+               DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
+           )),
+    OpInfo('new_full',
+           op=lambda x, *args, **kwargs: x.new_full(*args, **kwargs),
+           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
+           supports_out=False,
+           sample_inputs_func=sample_inputs_new_full,
+           skips=(
+               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+           ),
+           supports_autograd=False),
+    OpInfo('multinomial',
+           op=lambda inp, *args, **kwargs:
+               wrapper_set_seed(torch.multinomial, inp, *args, **kwargs),
+           method_variant=lambda inp, *args, **kwargs:
+               wrapper_set_seed(torch.Tensor.multinomial, inp, *args, **kwargs),
+           dtypes=floating_types_and(torch.bfloat16, torch.half),
+           supports_out=True,
+           sample_inputs_func=sample_inputs_multinomial,
+           error_inputs_func=error_inputs_multinomial,
+           skips=(
+               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+               # Strides are not the same!
+               # This may not be reproducible in CI
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'),
+               # AssertionError: JIT Test does not execute any logic
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+               # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
+               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu')),
+           supports_autograd=False),
+    OpInfo('normal',
+           op=lambda inp, *args, **kwargs:
+               wrapper_set_seed(torch.normal, inp, *args, **kwargs),
+           # The inplace variant (Tensor.normal_) is different from torch.normal
+           inplace_variant=None,
+           dtypes=floating_types_and(torch.bfloat16, torch.half),
+           dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.half),
+           supports_out=True,
+           sample_inputs_func=sample_inputs_normal_tensor_first,
+           skips=(
+               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+               # Tensor-likes are not close!
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
+               # AssertionError: JIT Test does not execute any logic
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+               # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
+               # Computed gradient is incorrect -- would be an exfail but gradgrad somehow passes
+               DecorateInfo(unittest.skip("Gradients are incorrect!"), 'TestFwdGradients'),
+               DecorateInfo(unittest.skip("Gradients are incorrect!"), 'TestBwdGradients'),
+               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
+               # RuntimeError: Difference from {dtype} is larger with decomposition
+               DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_comprehensive'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_quick'),
+               # The inplace variant (Tensor.normal_) is different from torch.normal
+               # inplace variant Tensor.normal_ is decomposed using randn_like()
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides'))),
+    OpInfo('normal',
+           # This has its own variant b/c OpInfos assume the first arg is a Tensor but it is not here
+           variant_test_name='number_mean',
+           op=lambda std, mean, *args, **kwargs:
+               wrapper_set_seed(torch.normal, mean, std, *args, **kwargs),
+           # The inplace variant (Tensor.normal_) is different from torch.normal
+           inplace_variant=None,
+           dtypes=floating_types_and(torch.bfloat16, torch.half),
+           dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.half),
+           supports_out=True,
+           sample_inputs_func=sample_inputs_normal_tensor_second,
+           skips=(
+               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+               # AssertionError: JIT Test does not execute any logic
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out_warning'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_backward'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_compare_cpu'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestEagerFusionOpInfo'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestOperators'),
+               # AssertionError
+               DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_comprehensive'),
+               # AssertionError
+               DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_quick'),
+               # AssertionError in CUDA variant
+               DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', device_type='cuda'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestDeviceUtils', 'test_device_mode_ops'))),
+    OpInfo('bernoulli',
+           op=lambda inp, *args, **kwargs:
+               wrapper_set_seed(torch.bernoulli, inp, *args, **kwargs),
+           # The inplace variant (Tensor.bernoulli_) is different from torch.bernoulli
+           inplace_variant=None,
+           method_variant=lambda inp, *args, **kwargs:
+               wrapper_set_seed(torch.Tensor.bernoulli, inp, *args, **kwargs),
+           dtypes=floating_types_and(torch.bfloat16, torch.half),
+           supports_out=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           sample_inputs_func=sample_inputs_bernoulli,
+           error_inputs_func=error_inputs_bernoulli,
+           skips=(
+               # vmap: We do not yet support calling random operations inside of vmap
+               DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD'),
+               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+               # AssertionError: JIT Test does not execute any logic
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+               # Expected RuntimeError when doing an unsafe cast from a result of
+               # dtype torch.float32 into an out= with dtype torch.long
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
+               # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
+               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'))),
+    OpInfo('scatter_add',
+           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+           inplace_variant=torch.Tensor.scatter_add_,
+           sample_inputs_func=sample_inputs_scatter_add,
+           error_inputs_func=error_inputs_scatter_and_scatter_add,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           skips=(
+               # Compiler issue on ROCm. Regression started in ROCm 6.4.
+               DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_non_standard_bool_values',
+                            dtypes=[torch.bool], active_if=TEST_WITH_ROCM),
+           )),
+    OpInfo('stack',
+           dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
+           sample_inputs_func=sample_inputs_stack,
+           assert_autodiffed=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           skips=(
+               # https://github.com/pytorch/pytorch/issues/77046
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
+               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+           )),
+    OpInfo('_chunk_cat',
+           dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
+           sample_inputs_func=sample_inputs_chunk_cat,
+           error_inputs_func=error_inputs_chunk_cat,
+           supports_autograd=False,
+           supports_out=True,
+           ),
+    OpInfo('hstack',
+           dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
+           sample_inputs_func=sample_inputs_hstack_dstack_vstack,
+           error_inputs_func=error_inputs_hstack_dstack_vstack,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           ),
+    BinaryUfuncInfo('hypot',
+                    dtypes=floating_types_and(torch.bfloat16, torch.half),
+                    dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
+                    supports_forward_ad=True,
+                    supports_fwgrad_bwgrad=True,
+                    supports_rhs_python_scalar=False),
+    OpInfo('histogram',
+           dtypes=floating_types(),
+           dtypesIfCUDA=_dispatch_dtypes(),  # histogram is only implemented on CPU
+           sample_inputs_func=sample_inputs_histogram,
+           supports_autograd=False,
+           skips=(
+               # JIT tests don't work with Tensor keyword arguments
+               # https://github.com/pytorch/pytorch/issues/58507
+               # RuntimeError:
+               # undefined value tensor:
+               #   File "", line 3
+               # def the_method(i0):
+               #     return torch.histogram(i0, 1, weight=tensor(-0.5735, dtype=torch.float32), density=False)
+               #                                          ~~~~~~ <--- HERE
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+               # Not Implemented on XLA.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestOpInfo', device_type='xla'),
+           )),
+    OpInfo('histogramdd',
+           dtypes=floating_types(),
+           dtypesIfCUDA=_dispatch_dtypes(),  # histogramdd is only implemented on CPU
+           sample_inputs_func=sample_inputs_histogramdd,
+           error_inputs_func=error_inputs_histogramdd,
+           supports_autograd=False,
+           skips=(
+               # Not implemented on CUDA
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_errors', device_type='cuda'),
+               # JIT tests don't work with Tensor keyword arguments
+               # https://github.com/pytorch/pytorch/issues/58507
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+           )),
+    OpInfo('histc',
+           dtypes=floating_types_and(torch.bfloat16, torch.float16),
+           dtypesIfCUDA=floating_types_and(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64),
+           sample_inputs_func=sample_inputs_histc,
+           supports_out=True,
+           supports_autograd=False,
+           skips=(
+               # CUDA histc returns a float tensor but does not correctly warn when passed an integral out tensor
+               # "AssertionError: RuntimeError not raised : Expected RuntimeError when doing an unsafe cast
+               # from a result of dtype torch.float32 into an out= with dtype torch.long"
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type='cuda'),
+           )),
+    OpInfo('bincount',
+           dtypes=integral_types_and(),
+           sample_inputs_func=sample_inputs_bincount,
+           supports_out=False,
+           supports_autograd=False,
+           skips=(
+               # JIT tests don't work with Tensor keyword arguments
+               # https://github.com/pytorch/pytorch/issues/58507
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+           )),
+    OpInfo('bucketize',
+           dtypes=all_types_and(torch.float16, torch.bfloat16),
+           dtypesIfCUDA=all_types_and(torch.bfloat16, torch.float16),
+           sample_inputs_func=sample_inputs_bucketize,
+           reference_inputs_func=reference_inputs_bucketize,
+           error_inputs_func=error_inputs_bucketize,
+           supports_autograd=False,
+           skips=(
+               # JIT tests don't work with Tensor keyword arguments
+               DecorateInfo(unittest.skip("Expected failure!"), 'TestJit', 'test_variant_consistency_jit'),
+           )),
+    OpInfo('searchsorted',
+           dtypes=all_types_and(torch.bfloat16, torch.float16),
+           dtypesIfCUDA=all_types_and(torch.bfloat16, torch.float16),
+           sample_inputs_func=sample_inputs_searchsorted,
+           supports_autograd=False,
+           ref=reference_searchsorted,
+           skips=(
+               # JIT tests don't work with Tensor keyword arguments
+               # https://github.com/pytorch/pytorch/issues/58507
+               DecorateInfo(unittest.skip("Expected failure!"), 'TestJit', 'test_variant_consistency_jit'),
+           )),
+    OpInfo('cat',
+           ref=_cat_np,
+           aliases=('concat', 'concatenate'),
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.complex32),
+           sample_inputs_func=sample_inputs_cat_concat,
+           reference_inputs_func=reference_inputs_cat,
+           error_inputs_func=error_inputs_cat,
+           # https://github.com/pytorch/pytorch/issues/80411
+           gradcheck_fast_mode=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # See https://github.com/pytorch/pytorch/issues/66357
+           check_batched_forward_grad=False,
+           assert_autodiffed=True,
+           skips=(
+               # https://github.com/pytorch/pytorch/issues/89353
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref_mps'),
+               # RuntimeError: Arguments for call not valid.
+               #               Expected a value of type 'List[Tensor]' for argument
+               #               'tensors' but instead found type 'Tensor (inferred)'.
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_jit_alias_remapping'),
+               # see https://github.com/pytorch/pytorch/issues/71286
+               DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness'),
+               # see https://github.com/pytorch/pytorch/issues/99806
+               # RuntimeError: The size of tensor a (25) must match the size of tensor b (0) at non-singleton dimension 0.
+               DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_gradgrad'),
+           )),
+    OpInfo('unbind',
+           dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
+           ref=reference_unbind,
+           sample_inputs_func=sample_inputs_unbind,
+           error_inputs_func=error_inputs_unbind,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           supports_gradgrad=True,
+           supports_out=False,
+           ),
+    OpInfo('unbind_copy',
+           dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
+           ref=reference_unbind,
+           sample_inputs_func=sample_inputs_unbind,
+           error_inputs_func=error_inputs_unbind,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           supports_gradgrad=True,
+           supports_out=True,
+           check_batched_grad=False,
+           skips=(
+               # Expected __torch_dispatch__ for aten::unbind_copy.int_out to return None
+               # but it returned something else instead.
+               DecorateInfo(
+                   unittest.expectedFailure,
+                   'TestProxyTensorOpInfo',
+                   'test_make_fx_symbolic_exhaustive_out'
+               ),
+           )),
+    OpInfo('vstack',
+           aliases=('row_stack',),
+           dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
+           sample_inputs_func=sample_inputs_hstack_dstack_vstack,
+           error_inputs_func=error_inputs_hstack_dstack_vstack,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           skips=(
+               # RuntimeError: _fn() Expected a value of type
+               #   'Tensor (inferred)' for argument 't0' but instead found type 'tuple'.
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_jit_alias_remapping'),)),
+    OpInfo('dstack',
+           dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
+           sample_inputs_func=sample_inputs_hstack_dstack_vstack,
+           error_inputs_func=error_inputs_hstack_dstack_vstack,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # See https://github.com/pytorch/pytorch/pull/78358
+           check_batched_forward_grad=False,
+           ),
+    OpInfo('unfold',
+           op=lambda x, *args: x.unfold(*args),
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
+           backward_dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+           gradcheck_fast_mode=True,
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           check_batched_gradgrad=False,
+           # See https://github.com/pytorch/pytorch/issues/66357
+           check_batched_forward_grad=False,
+           skips=(
+               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+               # Skip operator schema test because this is a functional and not an operator
+               DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
+           ),
+           sample_inputs_func=sample_inputs_unfold),
+    OpInfo('unfold_copy',
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
+           backward_dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+           gradcheck_fast_mode=True,
+           supports_out=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           check_batched_gradgrad=False,
+           # See https://github.com/pytorch/pytorch/issues/66357
+           check_batched_forward_grad=False,
+           sample_inputs_func=sample_inputs_unfold),
+    OpInfo('msort',
+           dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
+           dtypesIfCUDA=all_types_and(torch.bool, torch.float16, torch.bfloat16),
+           check_batched_gradgrad=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           sample_inputs_func=sample_inputs_msort),
+    OpInfo('movedim',
+           aliases=('moveaxis',),
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # See https://github.com/pytorch/pytorch/pull/78358
+           check_batched_forward_grad=False,
+           sample_inputs_func=sample_movedim_moveaxis,
+           reference_inputs_func=reference_movedim_moveaxis,
+           error_inputs_func=error_movedim_moveaxis),
+    OpInfo('renorm',
+           dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+           sample_inputs_func=sample_inputs_renorm,
+           error_inputs_func=error_inputs_renorm,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           skips=(
+               # RuntimeError: Difference from float64 is larger with decomposition
+               # linalg_vector_norm.default than original on output 0.
+               # Original max diff: 2.560596747969157e-07,
+               # Decomp max diff: 1.8187482915266173e-06
+               DecorateInfo(unittest.skip("Inconsistent accuracy"), 'TestDecomp', 'test_comprehensive',
+                            device_type='cpu', dtypes=(torch.float16,)),
+               DecorateInfo(toleranceOverride({torch.float16: tol(atol=3e-4, rtol=3e-6)}),
+                            "TestConsistency", "test_output_match", device_type="mps"),
+           )),
+    ShapeFuncInfo('repeat',
+                  op=lambda x, dims: x.repeat(dims),
+                  ref=np.tile,
+                  dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+                  # https://github.com/pytorch/pytorch/issues/80411
+                  gradcheck_fast_mode=True,
+                  supports_out=False,
+                  supports_forward_ad=True,
+                  supports_fwgrad_bwgrad=True,
+                  sample_inputs_func=sample_repeat_tile,
+                  skips=(
+                      DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+                  )),
+    OpInfo('squeeze',
+           ref=_squeeze_ref,
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
+           supports_out=False,
+           assert_autodiffed=True,
+           autodiff_fusible_nodes=[],  # aliases inputs, shouldn't be fused
+           autodiff_nonfusible_nodes=[],  # aliases inputs, shouldn't be fused
+           assert_jit_shape_analysis=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # vmap does not support inplace views
+           check_inplace_batched_forward_grad=False,
+           # https://github.com/pytorch/pytorch/issues/66357
+           check_batched_forward_grad=False,
+           sample_inputs_func=sample_inputs_squeeze),
+    OpInfo('squeeze',
+           ref=_squeeze_ref,
+           variant_test_name="multiple",
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
+           supports_out=False,
+           assert_autodiffed=True,
+           autodiff_fusible_nodes=[],  # aliases inputs, shouldn't be fused
+           autodiff_nonfusible_nodes=[],  # aliases inputs, shouldn't be fused
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # vmap does not support inplace views
+           check_inplace_batched_forward_grad=False,
+           # https://github.com/pytorch/pytorch/issues/66357
+           check_batched_forward_grad=False,
+           sample_inputs_func=sample_inputs_squeeze_multiple),
+    OpInfo('squeeze_copy',
+           ref=_squeeze_ref,
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
+           supports_out=True,
+           assert_autodiffed=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # vmap does not support inplace views
+           check_inplace_batched_forward_grad=False,
+           # https://github.com/pytorch/pytorch/issues/66357
+           check_batched_forward_grad=False,
+           sample_inputs_func=sample_inputs_squeeze,
+           skips=(
+               DecorateInfo(
+                   unittest.expectedFailure,
+                   'TestJit',
+                   'test_variant_consistency_jit',
+                   dtypes=(torch.float32,),
+               ),
+           )),
+    UnaryUfuncInfo(
+        'fill',
+        ref=_fill_np,
+        method_variant=None,
+        sample_kwargs=_fill_sample_kwargs,
+        sample_inputs_func=partial(sample_inputs_elementwise_unary, op_kwargs={'value': True}),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # https://github.com/pytorch/pytorch/issues/66357
+        check_batched_forward_grad=False,
+        dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
+        supports_out=False,
+        skips=(
+            # JIT has issue when op is passed as lambda
+            # AssertionError: JIT Test does not execute any logic
+            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+            DecorateInfo(unittest.skip("No fill_ op"), 'TestCudaFuserOpInfo'),
+            DecorateInfo(unittest.skip("No fill_ op"), 'TestNNCOpInfo'),
+        )),
+    OpInfo('resize_',
+           op=lambda x, shape: x.clone().resize_(shape),
+           method_variant=None,
+           inplace_variant=torch.Tensor.resize_,
+           # the test fails because resize_ doesn't work with imag views as expected by the test
+           # https://github.com/pytorch/pytorch/issues/65945
+           test_neg_view=False,
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+           supports_out=False,
+           supports_autograd=False,
+           skips=(
+               # Cannot resize variables that require grad
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes'),
+               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+               DecorateInfo(unittest.skip("Allowed exception"), 'TestCompositeCompliance', 'test_operator'),
+           ),
+           sample_inputs_func=sample_inputs_resize_ops),
+    OpInfo('resize_as_',
+           op=lambda x, other: torch.resize_as_(x.clone(), other),
+           method_variant=None,
+           inplace_variant=torch.Tensor.resize_as_,
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+           supports_out=False,
+           supports_autograd=False,
+           skips=(
+               # Cannot resize variables that require grad
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes'),
+               DecorateInfo(unittest.skip('Allowed exemption'), 'TestCompositeCompliance', 'test_operator'),
+           ),
+           sample_inputs_func=sample_inputs_resize_ops),
+    OpInfo('take_along_dim',
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+           dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+           supports_inplace_autograd=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # See https://github.com/pytorch/pytorch/pull/78358
+           check_batched_forward_grad=False,
+           sample_inputs_func=sample_inputs_take_along_dim,
+           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+           decorators=(
+               # RuntimeError: view size is not compatible with input tensor's size and stride
+               DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
+           )),
+    ShapeFuncInfo('tile',
+                  ref=np.tile,
+                  dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+                  # https://github.com/pytorch/pytorch/issues/80411
+                  gradcheck_fast_mode=True,
+                  supports_out=False,
+                  supports_forward_ad=True,
+                  supports_fwgrad_bwgrad=True,
+                  sample_inputs_func=sample_repeat_tile),
+    OpInfo('trapz',  # TODO: in the future, 'trapz' should be made a proper alias of 'trapezoid'
+           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # See https://github.com/pytorch/pytorch/pull/78358
+           check_batched_forward_grad=False,
+           decorators=[
+               DecorateInfo(
+                   toleranceOverride({torch.half: tol(atol=9e-4, rtol=4.3e-3)}),
+                   'TestInductorOpInfo', 'test_comprehensive', device_type='cuda'
+               ),
+           ],
+           sample_inputs_func=sample_trapezoid),
+    OpInfo('trapezoid',
+           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # See https://github.com/pytorch/pytorch/pull/78358
+           check_batched_forward_grad=False,
+           decorators=[
+               DecorateInfo(
+                   toleranceOverride({torch.half: tol(atol=9e-4, rtol=4.3e-3)}),
+                   'TestInductorOpInfo', 'test_comprehensive', device_type='cuda'
+               ),
+           ],
+           sample_inputs_func=sample_trapezoid),
+    OpInfo('cumulative_trapezoid',
+           dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16),
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # See https://github.com/pytorch/pytorch/pull/78358
+           check_batched_forward_grad=False,
+           supports_out=False,
+           decorators=(
+               DecorateInfo(
+                   toleranceOverride({torch.float16: tol(atol=4e-3, rtol=4e-3)}),
+                   'TestInductorOpInfo', 'test_comprehensive',
+               ),
+           ),
+           sample_inputs_func=sample_cumulative_trapezoid,),
+    OpInfo('unsqueeze',
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # See https://github.com/pytorch/pytorch/pull/78358
+           check_batched_forward_grad=False,
+           # vmap does not support inplace views
+           check_inplace_batched_forward_grad=False,
+           assert_jit_shape_analysis=True,
+           assert_autodiffed=True,
+           autodiff_fusible_nodes=[],  # aliases inputs, shouldn't be fused
+           autodiff_nonfusible_nodes=[],  # aliases inputs, shouldn't be fused
+           sample_inputs_func=sample_unsqueeze),
+    OpInfo('unsqueeze_copy',
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
+           supports_out=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # See https://github.com/pytorch/pytorch/pull/78358
+           check_batched_forward_grad=False,
+           # vmap does not support inplace views
+           check_inplace_batched_forward_grad=False,
+           assert_jit_shape_analysis=True,
+           assert_autodiffed=True,
+           autodiff_fusible_nodes=[],  # aliases inputs, shouldn't be fused
+           autodiff_nonfusible_nodes=[],  # aliases inputs, shouldn't be fused
+           sample_inputs_func=sample_unsqueeze,
+           skips=(
+               DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'),
+               DecorateInfo(
+                   unittest.expectedFailure,
+                   'TestJit',
+                   'test_variant_consistency_jit',
+                   dtypes=(torch.float32,),
+               ),
+           )),
+    BinaryUfuncInfo('xlogy',
+                    aliases=('special.xlogy',),
+                    dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
+                    promotes_int_to_float=True,
+                    supports_forward_ad=True,
+                    supports_fwgrad_bwgrad=True,
+                    supports_one_python_scalar=True,
+                    # We don't test 0 as the gradient will be NaN and it'll break
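+                    # (xlogy(x, y) = x * log(y): d/dx = log(y) -> -inf as y -> 0, and d/dy = x / y,
+                    #  which is 0 / 0 = nan at x = y = 0, so gradcheck would break there)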
+                    rhs_make_tensor_kwargs=dict(low=0.01)),
+    OpInfo('zero_',
+           op=lambda x: torch.zero_(x.clone()),
+           method_variant=None,
+           inplace_variant=torch.Tensor.zero_,
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+           # https://github.com/pytorch/pytorch/issues/80411
+           gradcheck_fast_mode=True,
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           supports_gradgrad=True,
+           skips=(
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+           ),
+           sample_inputs_func=sample_inputs_zero_),
+    OpInfo('logsumexp',
+           aliases=('special.logsumexp',),
+           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+           assert_autodiffed=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           gradcheck_fast_mode=False,
+           sample_inputs_func=sample_inputs_logsumexp,
+           reference_inputs_func=reference_inputs_logsumexp),
+    OpInfo('trace',
+           dtypes=all_types_and_complex(),
+           dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
+           error_inputs_func=error_inputs_trace,
+           supports_inplace_autograd=False,
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           sample_inputs_func=sample_inputs_trace),
+    OpInfo('transpose',
+           ref=_numpy_ref_transpose,
+           aliases=('swapdims', 'swapaxes'),
+           assert_jit_shape_analysis=True,
+           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf),
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # vmap does not support inplace views
+           check_inplace_batched_forward_grad=False,
+           sample_inputs_func=sample_inputs_transpose_swapdims),
+    OpInfo('transpose_copy',
+           assert_jit_shape_analysis=True,
+           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf),
+           supports_out=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # vmap does not support inplace views
+           check_inplace_batched_forward_grad=False,
+           sample_inputs_func=sample_inputs_transpose_swapdims,
+           skips=(
+               DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'),
+               DecorateInfo(
+                   unittest.expectedFailure,
+                   'TestJit',
+                   'test_variant_consistency_jit',
+                   dtypes=(torch.float32,)
+               ),
+           )),
+    OpInfo('T',
+           op=lambda x: x.T,
+           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf),
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           skips=(
+               # lambda impl
+               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+               DecorateInfo(unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"),),
+           sample_inputs_func=sample_inputs_T,
+           error_inputs_func=error_inputs_T),
+    OpInfo('H',
+           op=lambda x: x.H,
+           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf),
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # See https://github.com/pytorch/pytorch/pull/78358
+           check_batched_forward_grad=False,
+           skips=(
+               # lambda impl
+               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+               DecorateInfo(unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"),),
+           sample_inputs_func=sample_inputs_T),
+    OpInfo('mT',
+           op=lambda x: x.mT,
+           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf),
+           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+           gradcheck_fast_mode=True,
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           skips=(
+               # lambda impl
+               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+               DecorateInfo(unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"),),
+           sample_inputs_func=sample_inputs_adjoint),
+    OpInfo('mH',
+           op=lambda x: x.mH,
+           aliases=('adjoint',),
+           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf),
+           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+           gradcheck_fast_mode=True,
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # See https://github.com/pytorch/pytorch/pull/78358
+           check_batched_forward_grad=False,
+           skips=(
+               # lambda impl
+               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+               DecorateInfo(unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"),),
+           sample_inputs_func=sample_inputs_adjoint),
+    OpInfo('tril',
+           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           error_inputs_func=error_inputs_tril_triu,
+           sample_inputs_func=sample_inputs_tril_triu,
+           skips=(
+               # Compiler issue on ROCm. Regression started in ROCm 6.4.
+               DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_non_standard_bool_values',
+                            dtypes=[torch.bool], active_if=TEST_WITH_ROCM),
+           )),
+    OpInfo('triu',
+           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           error_inputs_func=error_inputs_tril_triu,
+           sample_inputs_func=sample_inputs_tril_triu,
+           skips=(
+               # Compiler issue on ROCm. Regression started in ROCm 6.4.
+               DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_non_standard_bool_values',
+                            dtypes=[torch.bool], active_if=TEST_WITH_ROCM),
+           )),
+    OpInfo('triu_indices',
+           dtypes=_dispatch_dtypes((torch.int32, torch.int64)),
+           sample_inputs_func=sample_inputs_trilu_indices,
+           ref=lambda h, w, ofs=0, dtype=torch.long, device='cpu': np.array(np.triu_indices(h, ofs, w), dtype=dtype),
+           supports_out=False,
+           supports_autograd=False,
+           skips=(
+               # skip these tests since we have non-tensor inputs
+               DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_noncontiguous_samples'),
+               DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'),
+               DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'),
+               DecorateInfo(unittest.skip('Skipped!'), 'TestMathBits', 'test_neg_view'),
+           )),
+    OpInfo('tril_indices',
+           dtypes=_dispatch_dtypes((torch.int32, torch.int64)),
+           sample_inputs_func=sample_inputs_trilu_indices,
+           ref=lambda h, w, ofs=0, dtype=torch.long, device='cpu': np.array(np.tril_indices(h, ofs, w), dtype=dtype),
+           supports_out=False,
+           supports_autograd=False,
+           skips=(
+               # skip these tests since we have non-tensor inputs
+               DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_noncontiguous_samples'),
+               DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'),
+               DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'),
+               DecorateInfo(unittest.skip('Skipped!'), 'TestMathBits', 'test_neg_view'),
+           )),
+    OpInfo('kron',
+           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+           dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+           gradcheck_fast_mode=True,
+           supports_inplace_autograd=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           sample_inputs_func=sample_inputs_kron,
+           decorators=(
+               # RuntimeError: view size is not compatible with input tensor's size and stride
+               DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
+           )),
+    OpInfo('inner',
+           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+           dtypesIfROCM=floating_and_complex_types_and(torch.half, torch.bfloat16),
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # See https://github.com/pytorch/pytorch/pull/78358
+           check_batched_forward_grad=False,
+           sample_inputs_func=sample_inputs_inner,
+           ),
+    OpInfo('tensordot',
+           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+           dtypesIfROCM=floating_and_complex_types_and(torch.half, torch.bfloat16),
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # See https://github.com/pytorch/pytorch/pull/78358
+           check_batched_forward_grad=False,
+           sample_inputs_func=sample_inputs_tensordot,
+           skips=(
+               # Skip operator schema test because this is a functional and not an operator.
+               # Reference: https://github.com/pytorch/pytorch/issues/54574
+               DecorateInfo(unittest.skip("Skipped!"), 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
+           )
+           ),
+    OpInfo('to_sparse',
+           op=lambda x, *args: x.to_sparse(*args),
+           sample_inputs_func=sample_inputs_to_sparse,
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+           backward_dtypes=floating_types(),
+           backward_dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+           supports_out=False,
+           supports_sparse_csr=True,
+           supports_sparse_csc=True,
+           check_batched_grad=False,
+           check_batched_gradgrad=False,
+           skips=(
+               # NotImplementedError: Could not run 'aten::normal_' with arguments from the 'SparseCPU' backend
+               DecorateInfo(unittest.skip(""), 'TestCommon', 'test_noncontiguous_samples'),
+               # TODO: FIXME: complex inputs requiring grad error in forward
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes'),
+               # lambda impl
+               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+               # Allowed exception: sparse tensors don't have strides
+               DecorateInfo(unittest.skip("Allowed exception"), 'TestCompositeCompliance', 'test_operator'),
+               DecorateInfo(unittest.skip("Allowed exception"), 'TestCompositeCompliance', 'test_backward'),
+               DecorateInfo(unittest.skip("Allowed exception"), 'TestTags', 'test_tags'),
+               # TODO: implement csr.to_sparse(sample_dim) where sample_dim is 1.
+               DecorateInfo(unittest.skip("csr.to_sparse(1) not implemented. Skipped!"),
+                            'TestSparseCSR', 'test_sparse_csr_consistency'),
+               # Compiler issue on ROCm. Might need to skip until ROCm 5.5
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values',
+                            dtypes=[torch.bool], active_if=TEST_WITH_ROCM),
+           )
+           ),
+    OpInfo('logcumsumexp',
+           dtypes=floating_and_complex_types_and(torch.bfloat16, torch.half),
+           backward_dtypes=floating_and_complex_types_and(torch.bfloat16),
+           backward_dtypesIfCUDA=floating_and_complex_types_and(torch.bfloat16),
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           skips=(
+               # AssertionError: UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type='cuda'),
+               # RuntimeError: "max_values_cpu" not implemented for 'ComplexDouble'
+               # Falling back to non-numerically stabilized exp, causing nan in the results.
+               DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD', dtypes=[torch.complex128]),
+               DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_fn_fwgrad_bwgrad', dtypes=[torch.complex128]),
+               DecorateInfo(
+                   toleranceOverride({
+                       torch.float16: tol(atol=7e-5, rtol=6e-3),
+                   }),
+                   "TestInductorOpInfo",
+                   "test_comprehensive",
+                   device_type="cuda"
+               ),
+           ),
+           sample_inputs_func=sample_inputs_logcumsumexp,
+           error_inputs_func=error_inputs_logcumsumexp),
+    UnaryUfuncInfo('sigmoid',
+                   aliases=('special.expit', 'nn.functional.sigmoid'),
+                   aten_backward_name='sigmoid_backward',
+                   ref=reference_sigmoid if TEST_SCIPY else None,
+                   decorators=(precisionOverride({torch.float16: 1e-2,
+                                                  torch.complex64: 1e-1,
+                                                  torch.bfloat16: 1e-2}),),
+                   skips=(
+                       # Reference: https://github.com/pytorch/pytorch/issues/56012
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
+                                    dtypes=[torch.complex64, torch.cdouble], device_type='cuda'),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
+                                    dtypes=[torch.chalf, torch.complex64, torch.cdouble], device_type='cuda')),
+                   dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+                   dtypesIfCUDA=all_types_and_complex_and(torch.complex32, torch.bool, torch.half, torch.bfloat16),
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   promotes_int_to_float=True,
+                   assert_autodiffed=True,
+                   # sigmoid(z) = 1 / (1 + exp(-z)), at z = j * pi * odd_number, the denominator is zero
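+                   # (1 + exp(-z) = 0  <=>  exp(-z) = -1  <=>  z = i * pi * odd_number, so the filter below
+                   #  screens out complex inputs whose ratio x / (pi * 1j) is close to an integer)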
+                   reference_numerics_filter=NumericsFilter(
+                       condition=lambda x: (close_to_int(x / (math.pi * 1j))
+                                            if x.is_complex() else x.new_tensor(False, dtype=torch.bool)),
+                       safe_val=0)),
+    UnaryUfuncInfo('digamma',
+                   ref=scipy.special.digamma if TEST_SCIPY else None,
+                   aliases=('special.psi', 'special.digamma',),
+                   decorators=(precisionOverride({torch.float16: 5e-1}),),
+                   dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
+                   dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   promotes_int_to_float=True),
+    UnaryUfuncInfo('erf',
+                   ref=scipy.special.erf if TEST_SCIPY else None,
+                   aliases=('special.erf', ),
+                   decorators=(precisionOverride({torch.float16: 1e-2,
+                                                  torch.bfloat16: 1e-2}),),
+                   skips=(
+                       DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
+                                    'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
+
+                   ),
+                   dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
+                   assert_autodiffed=True,
+                   assert_jit_shape_analysis=True,
+                   supports_sparse=True,
+                   supports_sparse_csr=True,
+                   supports_sparse_csc=True,
+                   supports_sparse_bsr=True,
+                   supports_sparse_bsc=True,
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   promotes_int_to_float=True),
+    UnaryUfuncInfo('erfc',
+                   ref=scipy.special.erfc if TEST_SCIPY else None,
+                   aliases=('special.erfc', ),
+                   decorators=(precisionOverride({torch.float16: 1e-2,
+                                                  torch.bfloat16: 1e-2}),),
+                   dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
+                   assert_autodiffed=True,
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   promotes_int_to_float=True),
+    UnaryUfuncInfo('erfinv',
+                   ref=scipy.special.erfinv if TEST_SCIPY else None,
+                   aliases=('special.erfinv', ),
+                   decorators=(precisionOverride({torch.float16: 1e-2,
+                                                  torch.bfloat16: 1e-2,
+                                                  torch.float32: 1e-4}),),
+                   dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
+                   dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
+                   supports_sparse_csr=True,
+                   supports_sparse_csc=True,
+                   supports_sparse_bsr=True,
+                   supports_sparse_bsc=True,
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   promotes_int_to_float=True,
+                   domain=(-1, 1),
+                   skips=(
+                       # Reference: https://github.com/pytorch/pytorch/pull/49155#issuecomment-742664611
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
+                                    active_if=TEST_SCIPY and version.parse(scipy.__version__) < version.parse("1.4.0")),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
+                                    active_if=TEST_SCIPY and version.parse(scipy.__version__) < version.parse("1.4.0")),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
+                                    active_if=TEST_SCIPY and version.parse(scipy.__version__) < version.parse("1.4.0")),
+                   )),
+    OpInfo("nn.functional.smooth_l1_loss",
+           ref=reference_smooth_l1_loss,
+           sample_inputs_func=sample_inputs_smooth_l1_loss,
+           dtypes=floating_types_and(torch.float16, torch.bfloat16),
+           backward_dtypes=floating_types_and(torch.bfloat16),
+           dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+           backward_dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           skips=(
+               # RuntimeError: input->type()->kind() == TypeKind::OptionalTypeINTERNAL ASSERT FAILED
+               # at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270, please report a bug to PyTorch.
+               DecorateInfo(unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"),)),
+    OpInfo(
+        "nn.functional.l1_loss",
+        ref=loss_reference_reduction_wrapper(lambda input, target: np.abs(input - target)),
+        sample_inputs_func=sample_inputs_l1_loss,
+        error_inputs_func=error_inputs_l1_loss,
+        dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        skips=(
+            # RuntimeError: input->type()->kind() == TypeKind::OptionalTypeINTERNAL ASSERT FAILED
+            # at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270, please report a bug to PyTorch.
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestJit",
+                "test_variant_consistency_jit",
+                dtypes=(torch.float32,),
+            ),
+        ),
+    ),
+    UnaryUfuncInfo('lgamma',
+                   ref=reference_lgamma if TEST_SCIPY else None,
+                   aliases=('special.gammaln', ),
+                   decorators=(precisionOverride({torch.float16: 7e-1}),),
+                   dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
+                   dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   promotes_int_to_float=True,
+                   skips=(
+                       # Reference: https://github.com/pytorch/pytorch/pull/50140#issuecomment-756150214
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
+                                    dtypes=[torch.float32, torch.float64], active_if=IS_WINDOWS),
+                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
+                                    dtypes=[torch.float32, torch.float64], active_if=IS_WINDOWS),
+                   ),
+                   # lgamma has multiple singularities at x <= 0
+                   reference_numerics_filter=NumericsFilter(condition=lambda x: x < 0.1, safe_val=1)),
+    OpInfo(
+        'logdet',
+        dtypes=floating_and_complex_types(),
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_linalg_det_logdet_slogdet,
+        decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack]),
+    # `log_softmax` supports different dtypes based on whether the `dtype` argument
+    # is passed or not. Hence two OpInfo entries, one with dtype and the other without.
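+    # For illustration only: torch.log_softmax(t, dim=-1) keeps the input's floating dtype, while
+    # torch.log_softmax(t, dim=-1, dtype=torch.float64) computes in (and returns) float64.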
+    OpInfo(
+        'log_softmax',
+        aliases=('special.log_softmax', 'nn.functional.log_softmax'),
+        supports_out=True,
+        aten_backward_name='_log_softmax_backward_data',
+        dtypes=floating_types_and(torch.float16, torch.bfloat16),
+        sample_inputs_func=sample_inputs_softmax_variant,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        assert_autodiffed=True),
+    OpInfo(
+        'log_softmax',
+        variant_test_name='with_dtype',
+        aliases=('special.log_softmax', 'nn.functional.log_softmax'),
+        supports_out=True,
+        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
+        sample_inputs_func=partial(sample_inputs_softmax_variant, with_dtype=True),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        assert_autodiffed=True),
+    UnaryUfuncInfo('logit',
+                   aten_backward_name='logit_backward',
+                   ref=scipy.special.logit if TEST_SCIPY else None,
+                   domain=(0, 1),
+                   aliases=('special.logit', ),
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   promotes_int_to_float=True,
+                   decorators=(precisionOverride({torch.bfloat16: 5e-1,
+                                                  torch.float16: 5e-1}),),
+                   dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
+                   sample_inputs_func=sample_inputs_logit),
+    OpInfo('where',
+           # Currently only the `input` is tested in gradcheck.
+           # If we pass `condition` first, none of the inputs that support
+           # autograd will be tested. Hence the following lambda.
+           op=lambda self, condition, other, **kwargs: torch.where(condition, self, other, **kwargs),
+           ref=lambda self, condition, other: np.where(condition, self, other),
+           sample_inputs_func=sample_inputs_where,
+           reference_inputs_func=reference_inputs_where,
+           error_inputs_func=error_inputs_where,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           decorators=(
+               DecorateInfo(onlyCUDA, "TestCommon", 'test_errors'),),
+           skips=(
+               # lambda impl
+               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+           ),
+           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf)),
+    OpInfo('nonzero',
+           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
+           sample_inputs_func=sample_inputs_nonzero,
+           supports_autograd=False,
+           skips=(
+               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+               # nonzero(): argument 'out' must be Tensor, not tuple
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
+               # https://github.com/pytorch/pytorch/issues/67458
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+               # nonzero is not raising a warning when the out is resized
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
+               # Can't find schemas for this operator for some reason
+               DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
+               # Compiler issue on ROCm. Might need to skip until ROCm 5.5
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values',
+                            dtypes=[torch.bool], active_if=TEST_WITH_ROCM),
+           )),
+    OpInfo('nonzero_static',
+           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
+           sample_inputs_func=sample_inputs_nonzero_static,
+           supports_out=False,
+           supports_autograd=False,
+           decorators=[onlyCPU],
+           skips=(
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
+               DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'),
+               DecorateInfo(unittest.expectedFailure, 'TestInductorOpInfo', 'test_comprehensive'),
+               DecorateInfo(unittest.expectedFailure, 'TestVmapOperatorsOpInfo', 'test_op_has_batch_rule'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values',
+                            dtypes=[torch.bool], active_if=TEST_WITH_ROCM),
+           )),
+    # Following tests are for jiterator's python interface
+    # Jiterator can be used to author elementwise CUDA kernel
+    # jiterator._create_jit_fn returns a callable that behaves like a regular pytorch op
+    # See create_jit_fn in jiterator.py for more information
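+    # A minimal usage sketch (not part of the OpInfo list; `square_plus`, the tensor shape, and the
+    # availability of a CUDA device are illustrative assumptions, shown only to clarify the interface
+    # exercised by the entries below):
+    #
+    #   fn = torch.cuda.jiterator._create_jit_fn(
+    #       "template <typename T> T square_plus(T x) { return x * x + x; }")
+    #   out = fn(torch.rand(8, device='cuda'))  # the returned callable behaves like an elementwise torch op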
+    UnaryUfuncInfo(
+        'jiterator_unary',
+        op=torch.cuda.jiterator._create_jit_fn("template <typename T> T unary(T x) { return x * x + x; }"),
+        ref=lambda x: x * x + x,
+        dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool),
+        supports_out=False,
+        supports_autograd=False,  # jiterator ops don't have a backward defined
+        decorators=[
+            onlyCUDA,
+            DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
+                         'TestUnaryUfuncs', 'test_reference_numerics_extremal'),
+            DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
+                         'TestUnaryUfuncs', 'test_reference_numerics_hard'),
+            DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
+                         'TestUnaryUfuncs', 'test_reference_numerics_normal'),
+            DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
+                         'TestUnaryUfuncs', 'test_reference_numerics_small'),
+        ],
+        skips=(
+            # Jiterator ops don't support neg or conj views
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
+            # Jiterator ops don't support CompositeCompliantTensor
+            # The following test should be an expectedFailure, but it causes cascading failures in CUDA, thus it is skipped
+            DecorateInfo(unittest.skip("skip"), 'TestCompositeCompliance', 'test_operator'),
+            # Skip reference_numerics tests for bool type, as the defined function doesn't work for bool
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
+                         dtypes=[torch.bool]),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard',
+                         dtypes=[torch.bool]),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal',
+                         dtypes=[torch.bool]),
+            # ROCm generates -inf+infj instead of nan+infj for complex64 for some of the results
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
+                         dtypes=[torch.complex64], active_if=TEST_WITH_ROCM),
+            # Newer numpy generates -inf+infj instead of nan+infj for complex64 for some of the results
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
+                         dtypes=[torch.complex64], device_type='cuda'),
+            # Expected failure: torch.jiterator_unary is not a valid op
+            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+            # Skip Nvfuser
+            DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo'),
+        )
+    ),
+    BinaryUfuncInfo(
+        'jiterator_binary',
+        op=torch.cuda.jiterator._create_jit_fn(
+            "template  T binary(T x, T y, T alpha) { return x + alpha * y; }", alpha=1),
+        ref=lambda input, other, *, alpha=1: np.add(input, other) if alpha == 1 \
+            else np.add(input, np.multiply(alpha, other)),
+        dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool),
+        sample_inputs_func=partial(sample_inputs_jiterator, num_inputs=2, alpha=-3.14),
+        supports_out=False,
+        supports_autograd=False,  # jiterator ops don't have a backward defined
+        supports_rhs_python_scalar=False,
+        decorators=[onlyCUDA],
+        skips=(
+            # Jiterator ops don't support neg or conj views
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
+            # Jiterator ops don't support CompositeCompliantTensor
+            # The following test should be an expectedFailure, but it causes cascading failures in CUDA, thus it is skipped
+            DecorateInfo(unittest.skip("skip"), 'TestCompositeCompliance', 'test_operator'),
+            # Expected failure: torch.jiterator_binary is not a valid op
+            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+            # Skip Nvfuser
+            DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo'),
+        )
+    ),
+    OpInfo(
+        'jiterator_4inputs_with_extra_args',
+        op=torch.cuda.jiterator._create_jit_fn(
+            "template  T binary(T i0, T i1, T i2, T i3, T alpha, T beta) { return alpha * i0 + beta * i1 + i2 + i3; }",
+            alpha=1, beta=1),
+        ref=lambda i0, i1, i2, i3, *, alpha=1, beta=1: alpha * i0 + beta * i1 + i2 + i3,
+        dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool),
+        sample_inputs_func=partial(sample_inputs_jiterator, num_inputs=4, alpha=3.14, beta=-4.20),
+        supports_out=False,
+        supports_autograd=False,  # jiterator ops don't have a backward defined
+        decorators=[onlyCUDA],
+        skips=(
+            # Jiterator ops don't support neg or conj views
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
+            # Jiterator ops don't support CompositeCompliantTensor
+            # The following test should be an expectedFailure, but it causes cascading failures in CUDA, thus it is skipped
+            DecorateInfo(unittest.skip("skip"), 'TestCompositeCompliance', 'test_operator'),
+            # Expected failure: torch.jiterator_4inputs_with_extra_args is not a valid op
+            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+            # Skip Nvfuser
+            DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo'),
+        )
+    ),
+    BinaryUfuncInfo(
+        'jiterator_binary_return_by_ref',
+        op=torch.cuda.jiterator._create_multi_output_jit_fn(
+            """
+            template <typename T>
+            void binary_return_by_ref(T i0, T i1, T& out0) {
+                out0 = i0 + i1;
+            }
+            """,
+            num_outputs=1),
+        ref=operator.add,
+        dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool),
+        sample_inputs_func=partial(sample_inputs_jiterator, num_inputs=2, alpha=-0.42),
+        supports_out=False,
+        supports_autograd=False,  # jiterator ops don't have a backward defined
+        supports_rhs_python_scalar=False,
+        decorators=[onlyCUDA],
+        skips=(
+            # Jiterator ops don't support neg or conj views
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
+            # Jiterator ops don't support CompositeCompliantTensor
+            # The following test should be an expectedFailure, but it causes cascading failures in CUDA, thus it is skipped
+            DecorateInfo(unittest.skip("skip"), 'TestCompositeCompliance', 'test_operator'),
+            # Expected failure: torch.jiterator_binary_return_by_ref is not a valid op
+            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+            # Skip Nvfuser
+            DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo'),
+        )
+    ),
+    OpInfo(
+        'jiterator_2inputs_2outputs',
+        op=torch.cuda.jiterator._create_multi_output_jit_fn(
+            """
+            template <typename T>
+            void binary_2outputs(T i0, T i1, T& out0, T& out1) {
+                out0 = i0 + i1;
+                out1 = i0 - i1;
+            }
+            """,
+            num_outputs=2),
+        ref=lambda i0, i1, *, alpha=1: (i0 + i1, i0 - i1),
+        dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool),
+        sample_inputs_func=partial(sample_inputs_jiterator, num_inputs=2),
+        supports_out=False,
+        supports_autograd=False,  # jiterator ops don't have a backward defined
+        decorators=[onlyCUDA],
+        skips=(
+            # Jiterator ops don't support neg or conj views
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
+            # Jiterator ops don't support CompositeCompliantTensor
+            # The following test should be an expectedFailure, but it causes cascading failures in CUDA, thus it is skipped
+            DecorateInfo(unittest.skip("skip"), 'TestCompositeCompliance', 'test_operator'),
+            # Expected failure: torch.jiterator_2inputs_2outputs is not a valid op
+            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+            # Skip Nvfuser
+            DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo'),
+        )
+    ),
+    # `torch.norm` has multiple code paths depending on the value of `p`.
+    # These paths have different dtype support, and the JIT supports
+    # most variants but not all of them. So we split the OpInfo entries
+    # for `norm` based on the code paths and JIT support (sketched below).
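+    # A rough sketch of the call paths referred to above (illustrative only; `x` is a placeholder tensor):
+    #
+    #   torch.norm(x)                  # default p            -> covered by the base `norm` entry below
+    #   torch.norm(x, p='fro')         # Frobenius norm       -> covered by the 'fro' variant
+    #   torch.norm(x, p='nuc')         # nuclear norm         -> covered by the 'nuc' variant
+    #   torch.norm(x, p=float('inf'))  # infinity norm        -> covered by the 'inf' variant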
+    OpInfo(
+        "norm",
+        sample_inputs_func=sample_inputs_norm,
+        dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16, torch.chalf),
+        dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+        # TODO Benchmark again with the new implementation
+        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+        gradcheck_fast_mode=True,
+        check_batched_forward_grad=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        skips=(
+            # Dispatches in Python to vector_norm. Not sure how to make this test happy
+            # Happens to pass on complex64. Also a mystery
+            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',
+                         dtypes=(torch.float32,)),)
+    ),
+    OpInfo('norm',
+           variant_test_name='nuc',
+           sample_inputs_func=sample_inputs_norm_nuc,
+           decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
+           check_batched_gradgrad=False,
+           # torch.autograd.gradcheck.GradcheckError: While computing batched gradients
+           # got: Could not allocate memory to change Tensor SizesAndStrides!
+           check_batched_forward_grad=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           dtypes=floating_and_complex_types(),
+           dtypesIfCUDA=floating_and_complex_types(),
+           skips=(
+               # Dispatches in Python to matrix_norm. Not sure how to make this test happy
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',
+                            dtypes=(torch.complex64, torch.float32,)),)
+           ),
+    OpInfo('norm',
+           variant_test_name='fro',
+           sample_inputs_func=sample_inputs_norm_fro,
+           dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+           supports_forward_ad=True,
+           # torch.autograd.gradcheck.GradcheckError: While computing batched gradients
+           # got: Could not allocate memory to change Tensor SizesAndStrides!
+           check_batched_forward_grad=False,
+           supports_fwgrad_bwgrad=True,
+           skips=(
+               # MPS has some mild accuracy issues for float16. We divide the tolerances by 10
+               DecorateInfo(
+                   toleranceOverride({torch.float16: tol(atol=1e-4, rtol=0.01)}),
+                   'TestConsistency',
+                   'test_output_match',
+               ),
+               # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
+               DecorateInfo(
+                   unittest.skip("Skipped!"),
+                   'TestSchemaCheckModeOpInfo',
+                   'test_schema_correctness',
+                   dtypes=(torch.complex64, torch.complex128)),
+               # Dispatches in Python to vector_norm. Not sure how to make this test happy
+               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',
+                            dtypes=(torch.complex64, torch.float32,)),)
+           ),
+    OpInfo(
+        "norm",
+        variant_test_name="inf",
+        sample_inputs_func=sample_inputs_norm_inf,
+        dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16, torch.chalf),
+        dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+        supports_forward_ad=True,
+        check_batched_forward_grad=False,
+        supports_fwgrad_bwgrad=True,
+        # fast gradcheck produces NaNs
+        gradcheck_fast_mode=False,
+        skips=(
+            DecorateInfo(
+                toleranceOverride({torch.float16: tol(atol=2e-3, rtol=1e-3)}),
+                'TestInductorOpInfo', 'test_comprehensive', device_type='cuda',
+            ),
+            # Dispatches in Python to vector_norm. Not sure how to make this test happy
+            # Happens to pass on complex64. Also a mystery
+            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',
+                         dtypes=(torch.float32,))
+        ),
+    ),
+    OpInfo('t',
+           sample_inputs_func=sample_inputs_t,
+           supports_out=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # See https://github.com/pytorch/pytorch/pull/78358
+           check_batched_forward_grad=False,
+           # vmap does not support inplace views
+           check_inplace_batched_forward_grad=False,
+           autodiff_fusible_nodes=[],  # aliases inputs, shouldn't be fused
+           autodiff_nonfusible_nodes=[],  # aliases inputs, shouldn't be fused
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+           assert_autodiffed=True,
+           error_inputs_func=error_inputs_t),
+    OpInfo('t_copy',
+           sample_inputs_func=sample_inputs_t,
+           supports_out=True,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           # See https://github.com/pytorch/pytorch/pull/78358
+           check_batched_forward_grad=False,
+           # vmap does not support inplace views
+           check_inplace_batched_forward_grad=False,
+           autodiff_fusible_nodes=[],  # aliases inputs, shouldn't be fused
+           autodiff_nonfusible_nodes=[],  # aliases inputs, shouldn't be fused
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+           assert_autodiffed=True,
+           error_inputs_func=error_inputs_t),
+    OpInfo(
+        "nn.functional.dropout",
+        op=lambda input, *args, **kwargs:
+            wrapper_set_seed(torch.nn.functional.dropout, input, *args, **kwargs),
+        dtypes=floating_types_and(torch.float16, torch.bfloat16),
+        skips=(
+            DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+            # Probably because we have used lambda for the op here
+            # AssertionError: JIT Test does not execute any logic
+            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+            # The inplace variant dispatches to the dropout kernel, while on CUDA
+            # the op dispatches to _fused_dropout (under a few more conditions),
+            # hence the different values and this skip
+            DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view', device_type='cuda'),
+            DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu')),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # https://github.com/pytorch/pytorch/issues/66357
+        check_batched_forward_grad=False,
+        supports_out=False,
+        sample_inputs_func=sample_inputs_dropout,
+        inplace_variant=lambda input, *args, **kwargs:
+            wrapper_set_seed(torch.nn.functional.dropout, input, *args, **kwargs, inplace=True)),
+    OpInfo(
+        "native_dropout_backward",
+        op=torch.ops.aten.native_dropout_backward.default,
+        aten_name="native_dropout_backward",
+        dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
+        dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+        supports_out=False,
+        sample_inputs_func=sample_inputs_dropout_backward,
+        skips=(
+            DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'),
+            # Lazy tensor failures
+            DecorateInfo(unittest.skip('Skipped!'), 'TestLazyOpInfo', 'test_dispatched_to_lazy'),
+            # These tests fail only when built with ASAN
+            DecorateInfo(unittest.skip("Fails with ASAN"), 'TestLazyOpInfo', 'test_correctness', active_if=TEST_WITH_ASAN),
+            DecorateInfo(
+                unittest.skip("Fails with ASAN"),
+                'TestLazyOpInfo',
+                'test_correctness_with_reusing_ir',
+                active_if=TEST_WITH_ASAN
+            ),
+        ),
+    ),
+    OpInfo(
+        "nn.functional.dropout2d",
+        op=lambda input, *args, **kwargs:
+            wrapper_set_seed(torch.nn.functional.dropout2d, input, *args, **kwargs),
+        dtypes=floating_types_and(torch.float16, torch.bfloat16),
+        skips=(
+            # lambda impl
+            DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+            DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu')),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        supports_out=False,
+        check_batched_forward_grad=False,
+        # As per the docs, valid input dims are (3, 4)
+        sample_inputs_func=partial(sample_inputs_dropout, valid_input_dim=(3, 4)),
+        inplace_variant=lambda input, *args, **kwargs:
+            wrapper_set_seed(torch.nn.functional.dropout2d, input, *args, **kwargs, inplace=True)),
+    OpInfo(
+        "nn.functional.dropout3d",
+        op=lambda input, *args, **kwargs:
+            wrapper_set_seed(torch.nn.functional.dropout3d, input, *args, **kwargs),
+        dtypes=floating_types_and(torch.float16, torch.bfloat16),
+        skips=(
+            # lambda impl
+            DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+            DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu')),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        supports_out=False,
+        check_batched_forward_grad=False,
+        # As per the docs, valid input dims are (4, 5)
+        sample_inputs_func=partial(sample_inputs_dropout, valid_input_dim=(4, 5)),
+        inplace_variant=lambda input, *args, **kwargs:
+            wrapper_set_seed(torch.nn.functional.dropout3d, input, *args, **kwargs, inplace=True)),
+    OpInfo(
+        "nn.functional.alpha_dropout",
+        op=lambda input, *args, **kwargs:
+            wrapper_set_seed(torch.nn.functional.alpha_dropout, input, *args, **kwargs),
+        dtypes=floating_types_and(torch.float16, torch.bfloat16),
+        gradcheck_wrapper=wrapper_set_seed,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        supports_out=False,
+        sample_inputs_func=sample_inputs_dropout,
+        check_batched_forward_grad=False,
+        inplace_variant=lambda input, *args, **kwargs:
+            wrapper_set_seed(torch.nn.functional.alpha_dropout, input, *args, **kwargs, inplace=True),
+        skips=(
+            # lambda impl
+            DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+            # AssertionError: Tensor-likes are not close!
+            # Fails in cuda11.7
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu', device_type='cuda'),
+            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),),),
+    # In training mode, feature_alpha_dropout currently doesn't support inputs of complex dtype,
+    # unlike when `train=False`, where complex inputs are supported; hence two OpInfos to cover both cases
+    # (sketched below).
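+    # Illustrative sketch of that split (`z` stands for a placeholder complex-valued input tensor):
+    #   torch.nn.functional.feature_alpha_dropout(z, 0.5, training=False)  # complex input accepted    -> "without_train"
+    #   torch.nn.functional.feature_alpha_dropout(z, 0.5, training=True)   # complex input unsupported -> "with_train"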
+    OpInfo(
+        "nn.functional.feature_alpha_dropout",
+        op=lambda input, *args, **kwargs:
+            wrapper_set_seed(torch.nn.functional.feature_alpha_dropout, input, *args, **kwargs),
+        variant_test_name="with_train",
+        dtypes=floating_types_and(torch.float16, torch.bfloat16),
+        skips=(
+            # lambda impl
+            DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+            # torch.autograd.gradcheck.GradcheckError: While computing batched gradients, got:
+            # vmap: We do not yet support calling random operations inside of vmap.
+            # Please perform random operations outside of vmap as a workaround
+            DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', "test_forward_mode_AD"),
+            DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', "test_inplace_forward_mode_AD"),
+            DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu')),
+        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        supports_out=False,
+        # As per the docs, valid input dims are (4, 5)
+        sample_inputs_func=partial(sample_inputs_dropout, train=True, valid_input_dim=(4, 5)),
+        inplace_variant=lambda input, *args, **kwargs:
+            wrapper_set_seed(torch.nn.functional.feature_alpha_dropout, input, *args, **kwargs, inplace=True)),
+    OpInfo(
+        "nn.functional.feature_alpha_dropout",
+        op=lambda input, *args, **kwargs:
+            wrapper_set_seed(torch.nn.functional.feature_alpha_dropout, input, *args, **kwargs),
+        variant_test_name="without_train",
+        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+        skips=(
+            # lambda impl
+            DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),),
+        gradcheck_wrapper=wrapper_set_seed,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        supports_out=False,
+        sample_inputs_func=partial(sample_inputs_dropout, train=False),
+        inplace_variant=lambda input, *args, **kwargs:
+            wrapper_set_seed(torch.nn.functional.feature_alpha_dropout, input, *args, **kwargs, inplace=True)),
+    OpInfo(
+        "nn.functional.one_hot",
+        ref=reference_one_hot,
+        supports_out=False,
+        dtypes=_dispatch_dtypes((torch.int64,)),
+        sample_inputs_func=sample_inputs_one_hot,
+    ),
+    OpInfo(
+        "nn.functional.embedding",
+        aten_backward_name="embedding_dense_backward",
+        # We use lambda to reshuffle the positional arguments.
+        # This is because currently only the `input` field of SampleInput
+        # is tested in gradient tests.
+        op=lambda weight, idx, **kwargs: torch.nn.functional.embedding(idx, weight, **kwargs),
+        dtypes=floating_types_and(torch.bfloat16, torch.float16),
+        sample_inputs_func=sample_inputs_embedding,
+        allow_cow_input_materialize_forward=[0],
+        error_inputs_func=error_inputs_embedding,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        skips=(
+            # lambda impl
+            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+            DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+            # Fails on CI https://github.com/pytorch/pytorch/issues/85377
+            DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_compare_cpu'),
+            # Reference: https://github.com/pytorch/pytorch/issues/67084
+            DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view', device_type='cuda'),
+            # Not a problem: embedding does weird stuff to its input (it renormalizes)
+            DecorateInfo(unittest.skip('Allowed exemption'), 'TestCompositeCompliance', 'test_operator'),
+            # Fails due to non-determinism (see issue #74679)
+            # TODO: Investigate why more granular skips in the test don't work in CI
+            DecorateInfo(unittest.skip('Skipped!'),
+                         'TestExpandedWeightFunctional',
+                         'test_expanded_weight_forward'),
+        ),
+        supports_expanded_weight=True,
+        supports_out=False,
+    ),
+    OpInfo(
+        "nn.functional.embedding_bag",
+        # We use lambda to reshuffle the positional arguments.
+        # This is because currently only the `input` field of SampleInput
+        # is tested in gradient tests.
+        op=lambda weight, idx, **kwargs: torch.nn.functional.embedding_bag(idx, weight, **kwargs),
+        dtypes=floating_types_and(torch.bfloat16, torch.float16),
+        dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16),
+        # backward is not supported for mode `max` and dtype `bfloat16`
+        backward_dtypesIfCUDA=floating_types_and(torch.float16),
+        sample_inputs_func=sample_inputs_embedding_bag,
+        skips=(
+            # lambda impl
+            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+            DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+            # Not a problem: embedding_bag does weird stuff to its input (it renormalizes)
+            DecorateInfo(unittest.skip('Allowed exemption'), 'TestCompositeCompliance', 'test_operator'),
+        ),
+        gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+        supports_out=False,
+        supports_gradgrad=False,
+        allow_cow_input_materialize_forward=[0],
+    ),
+    OpInfo(
+        "nn.functional.multi_head_attention_forward",
+        op=lambda input, *args, **kwargs:
+            wrapper_set_seed(torch.nn.functional.multi_head_attention_forward, input, *args, **kwargs),
+        dtypes=floating_types_and(torch.bfloat16, torch.float16),
+        sample_inputs_func=sample_inputs_multi_head_attention_forward,
+        skips=(
+            # Tensor-likes are not close
+            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples', dtypes=(torch.float32,)),
+            DecorateInfo(toleranceOverride({torch.float32: tol(atol=5e-3, rtol=0)}), 'TestDecomp', 'test_comprehensive'),
+
+            # TODO skip this for now since we can't skip on runtime arch support (taken from scaled_dot_product_attention)
+            DecorateInfo(unittest.skip("Skipped!"), 'TestInductorOpInfo', 'test_comprehensive'),
+            # randomness
+            DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'),
+            DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
+            # lambda impl
+            # AssertionError: JIT Test does not execute any logic
+            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+            DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+            # these tests run so slowly that they break the slow-test jobs, so we skip them instead of using `slowTest`.
+            DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_forward_ad'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_operator'),
+            DecorateInfo(
+                unittest.skip("Skipped - baddbmm decomp does not have enough precision for 16-bit float"),
+                'TestDecomp',
+                'test_comprehensive',
+                dtypes=(torch.bfloat16, torch.float16),
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped - baddbmm decomp does not have enough precision for 16-bit float"),
+                'TestDecomp',
+                'test_quick',
+                dtypes=(torch.bfloat16, torch.float16))),
+        supports_out=False,
+        supports_gradgrad=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+        gradcheck_fast_mode=True,
+    ),
+    UnaryUfuncInfo(
+        "nn.functional.softplus",
+        aten_backward_name='softplus_backward',
+        ref=reference_softplus,
+        sample_kwargs=lambda device, dtype, input: ({'beta': 3, 'threshold': .2}, {'beta': 3, 'threshold': .2}),
+        sample_inputs_func=partial(sample_inputs_elementwise_unary, op_kwargs={'beta': 3, 'threshold': .2}),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        dtypes=floating_types_and(torch.bfloat16, torch.float16),
+        decorators=(
+            DecorateInfo(
+                toleranceOverride({
+                    torch.half: tol(atol=1e-2, rtol=1e-2),
+                    torch.bfloat16: tol(atol=1e-2, rtol=1e-2),
+                }),
+                'TestUnaryUfuncs'),
+        ),
+    ),
+    OpInfo(
+        "nn.functional.mse_loss",
+        aten_backward_name='mse_loss_backward',
+        ref=loss_reference_reduction_wrapper(lambda input, target: (input - target) ** 2),
+        sample_inputs_func=sample_inputs_loss,
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        dtypes=floating_types_and(torch.float16, torch.bfloat16),
+        skips=(
+            # RuntimeError: input->type()->kind() == TypeKind::OptionalType
+            # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252,
+            # please report a bug to PyTorch.
+            DecorateInfo(unittest.expectedFailure, "TestJit", "test_variant_consistency_jit", dtypes=(torch.float32,),),
+        ),
+    ),
+    OpInfo(
+        "nn.functional.grid_sample",
+        dtypes=floating_types_and(torch.float16, torch.bfloat16),
+        supports_out=False,
+        sample_inputs_func=sample_inputs_grid_sample,
+        reference_inputs_func=reference_inputs_grid_sample,
+        supports_gradgrad=False,
+        gradcheck_nondet_tol=1e-15),
+    # TODO: delete this OpInfo once we add meta support for grid_sampler_3d
+    OpInfo(
+        "grid_sampler_2d",
+        dtypes=floating_types_and(torch.float16, torch.bfloat16),
+        supports_out=False,
+        sample_inputs_func=sample_inputs_grid_sampler_2d,
+        supports_gradgrad=False,
+        gradcheck_nondet_tol=1e-15,
+        skips=(
+            DecorateInfo(slowTest, 'TestDecomp', 'test_comprehensive', dtypes=(torch.float32, torch.float64),
+                         active_if=IS_WINDOWS),
+        ),),
+    OpInfo(
+        "argwhere",
+        ref=np.argwhere,
+        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+        supports_out=False,
+        supports_autograd=False,
+        sample_inputs_func=sample_inputs_argwhere,
+        skips=(
+            # Compiler issue on ROCm. Might need to skip until ROCm5.5
+            DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_non_standard_bool_values',
+                         dtypes=[torch.bool], active_if=TEST_WITH_ROCM),
+        ),
+    ),
+    ReductionOpInfo(
+        'all',
+        identity=True,
+        supports_autograd=False,
+        result_dtype=torch.bool,
+        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+        ref=reference_reduction_numpy(np.all),
+        skips=(
+            # FIXME: uint8 input returns uint8 instead of bool
+            DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_result_dtype', dtypes=[torch.uint8]),
+        ),
+    ),
+    ReductionOpInfo(
+        'any',
+        identity=False,
+        supports_autograd=False,
+        result_dtype=torch.bool,
+        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+        ref=reference_reduction_numpy(np.any),
+        skips=(
+            # FIXME: uint8 input returns uint8 instead of bool
+            DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_result_dtype', dtypes=[torch.uint8]),
+        ),
+    ),
+    ReductionOpInfo(
+        'amax',
+        nan_policy='propagate',
+        supports_forward_ad=True,
+        check_batched_forward_grad=False,
+        supports_fwgrad_bwgrad=True,
+        dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
+        ref=reference_reduction_numpy(np.amax),
+        skips=(
+            # FIXME: reduces all dimensions when dim=[]
+            DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty'),
+            DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'),
+        ),
+        error_inputs_func=error_inputs_aminmax_amax_amin,
+    ),
+    ReductionOpInfo(
+        'amin',
+        nan_policy='propagate',
+        supports_forward_ad=True,
+        check_batched_forward_grad=False,
+        supports_fwgrad_bwgrad=True,
+        dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
+        ref=reference_reduction_numpy(np.amin),
+        skips=(
+            # FIXME: reduces all dimensions when dim=[]
+            DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty'),
+            DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'),
+        ),
+        error_inputs_func=error_inputs_aminmax_amax_amin,
+    ),
+    ReductionOpInfo(
+        'argmax',
+        supports_multiple_dims=False,
+        supports_autograd=False,
+        assert_jit_shape_analysis=True,
+        result_dtype=torch.int64,
+        dtypes=all_types_and(torch.float16, torch.bfloat16),
+        ref=reference_reduction_numpy(np.argmax, supports_keepdims=False),
+    ),
+    ReductionOpInfo(
+        'argmin',
+        supports_multiple_dims=False,
+        supports_autograd=False,
+        result_dtype=torch.int64,
+        dtypes=all_types_and(torch.float16, torch.bfloat16),
+        ref=reference_reduction_numpy(np.argmin, supports_keepdims=False),
+    ),
+    ReductionOpInfo(
+        'count_nonzero',
+        identity=0,
+        supports_out=False,
+        supports_autograd=False,
+        result_dtype=torch.int64,
+        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+        sample_inputs_func=sample_inputs_reduction_count_nonzero,
+        ref=reference_reduction_numpy(np.count_nonzero),
+        skips=(
+            # FIXME: count_nonzero does not accept keepdim kwarg
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none_keepdim'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_single_keepdim'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_multi_keepdim'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_multi_unsorted_keepdim'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_offbounds_keepdim'),
+            # FIXME: dim=[] reduces all dimensions
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
+        ),
+    ),
+    ReductionOpInfo(
+        'mean',
+        nan_policy='propagate',
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # FIXME: mean needs 'dim' parameter when using the 'out' overload.
+        # Adding it with 'generate_args_kwargs' does not work, since these also get passed
+        # onto the reference implementations.
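+        # Illustrative sketch of the limitation above (`x` and `result` are placeholder tensors):
+        #   torch.mean(x, dim=0, out=result)  # the out= overload requires `dim`
+        #   torch.mean(x, out=result)         # no out= overload without `dim`, per the FIXME above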
+        supports_out=True,
+        assert_autodiffed=True,
+        assert_jit_shape_analysis=True,
+        promotes_int_to_float=True,
+        dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+        ref=reference_reduction_numpy(np.mean),
+        error_inputs_func=error_inputs_mean,
+        skips=(
+            # AssertionError: RuntimeError not raised : Expected RuntimeError when doing an unsafe cast from a result
+            # of dtype torch.float32 into an out= with dtype torch.long
+            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out', device_type='cuda', dtypes=[torch.float32]),
+            # FIXME: mean does not support passing keepdim without passing dim
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'),
+            # FIXME: mean reduces all dimensions when dim=[]
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
+            # FIXME: improve precision
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input',
+                         dtypes=[torch.float16]),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_extremal_values',
+                         device_type='cuda', dtypes=[torch.complex64]),
+        ),
+    ),
+    ReductionOpInfo(
+        'nanmean',
+        nan_policy='omit',
+        assert_autodiffed=True,
+        promotes_int_to_float=True,
+        supports_forward_ad=True,
+        check_batched_forward_grad=False,
+        supports_fwgrad_bwgrad=True,
+        dtypes=floating_types_and(torch.float16, torch.bfloat16),
+        dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16, torch.chalf),
+        sample_inputs_func=sample_inputs_nan_reduction(supports_multiple_dims=True),
+        ref=reference_reduction_numpy(np.nanmean),
+        skips=(
+            # AssertionError: False is not true :
+            # Failure in testing nodes' autodifferentiation.
+            DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+            # FIXME: nanmean reduces all dimensions when dim=[]
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
+            # FIXME: improve precision
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input',
+                         dtypes=[torch.float16]),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_duplicate_values',
+                         device_type='cuda', dtypes=[torch.float16]),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_extremal_values',
+                         device_type='cuda', dtypes=[torch.complex64]),
+            DecorateInfo(toleranceOverride({torch.float16: tol(atol=2e-5, rtol=4e-2)}),
+                         "TestConsistency", "test_output_match", device_type="mps"),
+        ),
+    ),
+    ReductionOpInfo(
+        'std',
+        nan_policy='propagate',
+        supports_out=True,
+        complex_to_real=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        assert_autodiffed=True,
+        promotes_int_to_float=True,
+        check_batched_forward_grad=False,
+        dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
+        dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
+        sample_inputs_func=sample_inputs_std_var,
+        ref=reference_std_var(np.std),
+        generate_args_kwargs=generate_std_var_kwargs,
+        skips=(
+            # FIXME: cannot specify keepdim without dim
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'),
+            # FIXME: dim=[] reduces all dimensions
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
+            # FIXME: improve precision
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input',
+                         dtypes=(torch.float16,)),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_duplicate_values',
+                         dtypes=(torch.float16,)),
+        ),
+    ),
+    ReductionOpInfo(
+        'std',
+        variant_test_name='unbiased',
+        nan_policy='propagate',
+        supports_out=False,
+        complex_to_real=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        assert_autodiffed=True,
+        promotes_int_to_float=True,
+        check_batched_forward_grad=False,
+        dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
+        dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
+        sample_inputs_func=sample_inputs_std_var_unbiased,
+        skips=(
+            # FIXME: dim=[] reduces all dimensions
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
+        ),
+    ),
+    ReductionOpInfo(
+        'var',
+        nan_policy='propagate',
+        supports_out=True,
+        assert_autodiffed=True,
+        promotes_int_to_float=True,
+        complex_to_real=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        check_batched_forward_grad=False,
+        dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
+        dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
+        sample_inputs_func=sample_inputs_std_var,
+        ref=reference_std_var(np.var),
+        generate_args_kwargs=generate_std_var_kwargs,
+        skips=(
+            # FIXME: cannot specify keepdim without dim
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'),
+            # FIXME: dim=[] reduces all dimensions
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
+            # FIXME: improve precision
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_duplicate_values'),
+            # NumPy is giving NaN for this
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_large_input'),
+        ),
+    ),
+    ReductionOpInfo(
+        'var',
+        variant_test_name='unbiased',
+        nan_policy='propagate',
+        supports_out=False,
+        complex_to_real=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        assert_autodiffed=True,
+        promotes_int_to_float=True,
+        check_batched_forward_grad=False,
+        dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
+        dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
+        sample_inputs_func=sample_inputs_std_var_unbiased,
+        skips=(
+            # FIXME: dim=[] reduces all dimensions
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
+        ),
+    ),
+    ReductionOpInfo(
+        'prod',
+        identity=1,
+        nan_policy='propagate',
+        supports_multiple_dims=False,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        promotes_int_to_int64=True,
+        gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+        dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
+        dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
+        sample_inputs_func=sample_inputs_prod,
+        ref=prod_numpy,
+        skips=(
+            # FIXME: prod does not support passing keepdim without passing dim
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'),
+            # FIXME: prod reduces all dimensions when dim=[]
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
+            # FIXME: prod does not support passing None to dim
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none_keepdim'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input',
+                         dtypes=[torch.float16, torch.complex64]),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_duplicate_values',
+                         dtypes=[torch.uint8, torch.float16, torch.complex64]),
+            # FIXME: ValueError: The data in MaskedTensor a and Tensor b do not match
+            DecorateInfo(unittest.skip("Skipped!"), 'TestOperators', 'test_reduction_all',
+                         dtypes=[torch.float16]),
+        ),
+    ),
+    ReductionOpInfo(
+        'sum',
+        identity=0,
+        nan_policy='propagate',
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        promotes_int_to_int64=True,
+        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+        dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
+        ref=reference_reduction_numpy(np.sum),
+        error_inputs_sparse_func=error_inputs_sparse_reduction_sum,
+        sample_inputs_sparse_coo_func=partial(sample_inputs_sparse_reduction_sum, layout=torch.sparse_coo),
+        sample_inputs_sparse_csr_func=partial(sample_inputs_sparse_reduction_sum, layout=torch.sparse_csr),
+        sample_inputs_sparse_csc_func=partial(sample_inputs_sparse_reduction_sum, layout=torch.sparse_csc),
+        sample_inputs_sparse_bsr_func=partial(sample_inputs_sparse_reduction_sum, layout=torch.sparse_bsr),
+        sample_inputs_sparse_bsc_func=partial(sample_inputs_sparse_reduction_sum, layout=torch.sparse_bsc),
+        skips=(
+            # FIXME: sum does not support passing keepdim without passing dim
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'),
+            # FIXME: sum reduces all dimensions when dim=[]
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
+            # FIXME: improve precision
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input',
+                         dtypes=[torch.float16]),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_duplicate_values',
+                         dtypes=[torch.float16]),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestOperators', 'test_reduction_all',
+                         dtypes=[torch.float32]),
+        ),
+    ),
+    ReductionOpInfo(
+        'nansum',
+        identity=0,
+        nan_policy='omit',
+        supports_out=True,
+        promotes_int_to_int64=True,
+        supports_forward_ad=True,
+        check_batched_forward_grad=False,
+        supports_fwgrad_bwgrad=True,
+        dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
+        dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
+        sample_inputs_func=sample_inputs_nan_reduction(supports_multiple_dims=True),
+        ref=reference_reduction_numpy(np.nansum),
+        skips=(
+            # please report a bug to PyTorch.
+            DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+            # FIXME: nansum reduces all dimensions when dim=[]
+            DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty'),
+            DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'),
+            # FIXME: flaky test so skipped instead of xfailed
+            # possibly bad low precision reference in numpy
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input',
+                         dtypes=[torch.float16]),
+            DecorateInfo(toleranceOverride({torch.float16: tol(atol=3e-3, rtol=4e-2)}),
+                         "TestConsistency", "test_output_match", device_type="mps"),
+        ),
+    ),
+    OpInfo(
+        "nn.functional.ctc_loss",
+        dtypes=floating_types(),
+        supports_out=False,
+        sample_inputs_func=sample_inputs_ctc_loss,
+        # gradcheck_wrapper, see https://github.com/pytorch/pytorch/issues/52241
+        gradcheck_wrapper=gradcheck_wrapper_ctc_loss,
+        skips=(
+            # RuntimeError: derivative for aten::_ctc_loss_backward is not implemented
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestBwdGradients",
+                "test_fn_gradgrad",
+                dtypes=(torch.float64,),
+            ),
+            # RuntimeError: derivative for aten::_ctc_loss_backward is not implemented
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                dtypes=(torch.float32,),
+            ),
+            # Ref: https://github.com/pytorch/pytorch/issues/85231
+            DecorateInfo(unittest.skip("Fails with ASAN"),
+                         'TestProxyTensorOpInfo',
+                         'test_make_fx_fake_exhaustive', active_if=TEST_WITH_ASAN),
+        ),
+    ),
+    OpInfo(
+        "nn.functional.cosine_embedding_loss",
+        dtypes=all_types_and(torch.half, torch.bfloat16, torch.bool),
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        decorators=[
+            DecorateInfo(
+                toleranceOverride({torch.float16: tol(atol=1e-4, rtol=2e-3)}),
+                'TestInductorOpInfo', 'test_comprehensive', device_type="cuda",
+            ),
+        ],
+        sample_inputs_func=sample_inputs_cosine_embedding_loss,
+    ),
+    OpInfo(
+        "nn.functional.nll_loss",
+        dtypes=floating_types_and(torch.float16, torch.bfloat16),
+        supports_out=False,
+        sample_inputs_func=sample_inputs_nll_loss,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        assert_jit_shape_analysis=True,
+        skips=(
+            # RuntimeError:
+            # undefined value tensor:
+            #   File "<string>", line 3
+            # def the_method(i0, i1):
+            #     return torch.nn.functional.nll_loss(i0, i1, weight=tensor([8.4784, 1.7658, 4.3228], dtype=torch.float32))
+            #                                                        ~~~~~~ <--- HERE
+            DecorateInfo(unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit", dtypes=(torch.float32,),),
+            # Fails for unknown reason: https://github.com/pytorch/pytorch/issues/120782
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCompositeCompliance",
+                "test_cow_input",
+                device_type='cuda',
+            ),
+            DecorateInfo(unittest.skip("FP16 nll_loss cases have not been enabled on MPS yet"),
+                         dtypes=(torch.half,), device_type="mps"),
+        ),
+    ),
+    OpInfo(
+        "nn.functional.gaussian_nll_loss",
+        dtypes=floating_types_and(torch.half, torch.bfloat16),
+        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+        gradcheck_fast_mode=True,
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_gaussian_nll_loss,
+        error_inputs_func=error_inputs_gaussian_nll_loss,
+        skips=(
+            # Pre-existing condition (calls .item); needs to be fixed
+            DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'),
+            DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'),
+            # Pre-existing condition (calls .item); needs to be fixed
+            DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'),
+            # JIT does not support variadic tensors.
+            # RuntimeError: input->type()->kind() == TypeKind::OptionalType
+            # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270,
+            # please report a bug to PyTorch.
+            DecorateInfo(unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit", dtypes=(torch.float32,),),
+            DecorateInfo(toleranceOverride({torch.float16: tol(atol=8e-3, rtol=2e-3)}),
+                         "TestConsistency", "test_output_match", device_type="mps"),
+            DecorateInfo(toleranceOverride({torch.float16: tol(atol=8e-3, rtol=2e-3)}),
+                         "TestConsistency", "test_output_grad_match", device_type="mps"),
+        ),
+    ),
+    OpInfo(
+        "nn.functional.hinge_embedding_loss",
+        dtypes=floating_types_and(torch.half, torch.bfloat16),
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_hinge_embedding_loss,
+        error_inputs_func=error_inputs_hinge_embedding_loss,
+        reference_inputs_func=reference_inputs_hinge_embedding_loss,
+    ),
+    OpInfo(
+        "nn.functional.huber_loss",
+        aten_backward_name='huber_loss_backward',
+        dtypes=floating_types_and(torch.float16, torch.bfloat16),
+        supports_out=False,
+        supports_forward_ad=True,
+        sample_inputs_func=sample_inputs_huber_loss,
+        error_inputs_func=error_inputs_huber_loss,
+        skips=(
+            # JIT does not support variadic tensors.
+            # RuntimeError: input->type()->kind() == TypeKind::OptionalType
+            # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270,
+            # please report a bug to PyTorch.
+            DecorateInfo(unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit", dtypes=(torch.float32,),),
+        )
+    ),
+    OpInfo(
+        "nn.functional.pdist",
+        ref=reference_pdist,
+        sample_inputs_func=sample_inputs_pdist,
+        dtypes=floating_types(),
+        supports_out=False,
+        supports_gradgrad=False,
+        skips=(
+            DecorateInfo(unittest.skip("Unsupported on MPS for now"), 'TestCommon', 'test_numpy_ref_mps'),
+        )
+    ),
+    OpInfo(
+        "nn.functional.poisson_nll_loss",
+        dtypes=all_types_and(torch.half, torch.bfloat16),
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_poisson_nll_loss,
+        error_inputs_func=error_inputs_poisson_nll_loss,
+    ),
+    OpInfo(
+        "argsort",
+        dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
+        dtypesIfCUDA=all_types_and(torch.bool, torch.float16, torch.bfloat16),
+        sample_inputs_func=sample_inputs_sort,
+        supports_out=False,
+        supports_autograd=False,
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                dtypes=(torch.float32,),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestCommon",
+                "test_non_standard_bool_values",
+                dtypes=[torch.bool],
+                device_type='cuda',
+                active_if=not TEST_WITH_ROCM
+            ),
+        ),
+    ),
+    OpInfo(
+        "repeat_interleave",
+        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
+        backward_dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16, torch.chalf),
+        sample_inputs_func=sample_inputs_repeat_interleave,
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                dtypes=(torch.float32, torch.complex64),
+            ),
+        ),
+    ),
+    OpInfo(
+        "nn.functional.pairwise_distance",
+        ref=lambda a, b, p=2.0, eps=1e-6, keepdim=False: (
+            np.sum(np.abs(a - b + eps) ** p, axis=-1, keepdims=keepdim) ** (1 / p)
+        ),
+        sample_inputs_func=sample_inputs_pairwise_distance,
+        dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                dtypes=(torch.float32, torch.complex64),
+            ),
+        ),
+    ),
+    OpInfo(
+        "nn.functional.pixel_shuffle",
+        sample_inputs_func=sample_inputs_pixel_shuffle,
+        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                dtypes=(torch.float32, torch.complex64),
+            ),
+        ),
+    ),
+    OpInfo(
+        "nn.functional.pixel_unshuffle",
+        sample_inputs_func=sample_inputs_pixel_unshuffle,
+        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                dtypes=(torch.float32, torch.complex64),
+            ),
+        ),
+    ),
+    OpInfo(
+        "nn.functional.channel_shuffle",
+        sample_inputs_func=sample_inputs_channel_shuffle,
+        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        allow_cow_input_materialize_forward=[0],
+        allow_cow_input_materialize_backward=[0, 'output grad 0'],
+        skips=(
+            # Skip due to NotImplementedError for MPS device.
+            DecorateInfo(unittest.expectedFailure, 'TestConsistency'),
+            DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'),
+            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
+        ),
+    ),
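+    # (Reading of the allow_cow_input_materialize_* fields above: they appear to list the
+    # argument positions, and the 'output grad 0' string for backward, whose copy-on-write
+    # inputs the channel_shuffle kernels are permitted to materialize in the COW-input
+    # compliance tests. This is an assumption based on the field names, not a spec.)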
+    OpInfo(
+        "nn.functional.kl_div",
+        sample_inputs_func=sample_inputs_kl_div,
+        dtypes=floating_types_and(torch.float16, torch.bfloat16),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+    ),
+    OpInfo(
+        "diagflat",
+        ref=lambda input, offset=0: np.diagflat(input, k=offset),
+        sample_inputs_func=sample_inputs_diagflat,
+        dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
+        dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+    ),
+    OpInfo(
+        'scatter_reduce',
+        variant_test_name='sum',
+        inplace_variant=torch.Tensor.scatter_reduce_,
+        # complex not added to dtypes as complex gradients are not properly handled
+        # and scatter_reduce hasn't been added to the whitelist in gen_variable_type yet
+        dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_scatter_reduce,
+        skips=(
+            # Compiler issue on ROCm. Regression started in ROCm 6.4.
+            DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_non_standard_bool_values',
+                         dtypes=[torch.bool], active_if=TEST_WITH_ROCM),
+        ),
+    ),
+    OpInfo(
+        'scatter_reduce',
+        variant_test_name='prod',
+        # complex not added to dtypes as complex gradients are not properly handled
+        # and scatter_reduce hasn't been added to the whitelist in gen_variable_type yet
+        dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
+        dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+        sample_inputs_func=sample_inputs_scatter_reduce,
+        skips=(
+            # Not implemented
+            DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD'),
+            DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_inplace_forward_mode_AD'),
+            DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'),
+        ),
+    ),
+    OpInfo(
+        'scatter_reduce',
+        variant_test_name='mean',
+        # complex not added to dtypes as complex gradients are not properly handled
+        # and scatter_reduce hasn't been added to the whitelist in gen_variable_type yet
+        dtypes=all_types_and(torch.float16, torch.bfloat16),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+        dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_scatter_reduce,
+    ),
+    OpInfo(
+        'scatter_reduce',
+        variant_test_name='amin',
+        dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
+        dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+        supports_forward_ad=True,
+        check_batched_forward_grad=False,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_scatter_reduce,
+    ),
+    OpInfo(
+        'scatter_reduce',
+        variant_test_name='amax',
+        dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
+        dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
+        dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
+        supports_forward_ad=True,
+        check_batched_forward_grad=False,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_scatter_reduce,
+    ),
+    OpInfo(
+        '_segment_reduce',
+        aten_name='segment_reduce',
+        variant_test_name='lengths',
+        dtypes=floating_types_and(torch.float16, torch.bfloat16),
+        supports_out=False,
+        # RuntimeError: derivative for aten::_segment_reduce_backward is not implemented
+        supports_gradgrad=False,
+        sample_inputs_func=sample_inputs_segment_reduce,
+        skips=(
+            # FIXME: CUDA driver API confirmed a leak in
+            # __main__.TestJitCUDA.test_variant_consistency_jit_segment_reduce_cuda_float32
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                device_type="cuda",
+            ),
+        ),
+    ),
+    OpInfo(
+        '_segment_reduce',
+        aten_name='segment_reduce',
+        variant_test_name='offsets',
+        dtypes=floating_types_and(torch.float16, torch.bfloat16),
+        supports_out=False,
+        # RuntimeError: derivative for aten::_segment_reduce_backward is not implemented
+        supports_gradgrad=False,
+        sample_inputs_func=partial(sample_inputs_segment_reduce, mode='offsets'),
+        skips=(
+            # FIXME: CUDA driver API confirmed a leak in
+            # __main__.TestJitCUDA.test_variant_consistency_jit_segment_reduce_cuda_float32
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                device_type="cuda",
+            ),
+        ),
+    ),
+]
+op_db += opinfo.definitions.op_db
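+# Note: opinfo.definitions.op_db is understood to aggregate the OpInfos defined in the
+# per-module opinfo files (e.g. linalg and special); extending op_db with it keeps every
+# OpInfo reachable from a single registry for the generated device/dtype tests.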
+
+
+# Separate registry for experimental Python Reference OpInfos.
+python_ref_db = [
+    #
+    # Elementwise Unary OpInfos
+    #
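+    # Each *PythonRefInfo below points at an eager-mode OpInfo via torch_opinfo_name and is
+    # expected to inherit that entry's dtypes and sample inputs, so only reference-specific
+    # skips, decorators, and tolerance overrides are spelled out here.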
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.abs",
+        torch_opinfo_name="abs",
+        skips=(
+            # Reference: https://github.com/pytorch/pytorch/issues/49224
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_small',
+                         dtypes=[torch.int8], active_if=TEST_WITH_ASAN),
+        ),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.acos",
+        torch_opinfo_name="acos",
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_normal',
+                         device_type='cuda', dtypes=[torch.cdouble],
+                         active_if=IS_WINDOWS),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_extremal',
+                         device_type='cuda', dtypes=[torch.cdouble],
+                         active_if=IS_WINDOWS),
+            # Failing with wrong imaginary sign on at least some Windows jobs
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_small',
+                         device_type='cuda', dtypes=[torch.cdouble],
+                         active_if=IS_WINDOWS),
+            # Failing with wrong imaginary sign on at least some Windows jobs
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_large',
+                         device_type='cuda', dtypes=[torch.cdouble],
+                         active_if=IS_WINDOWS),
+        )
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.acosh",
+        torch_opinfo_name="acosh",
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_normal',
+                         device_type='cuda', dtypes=[torch.cdouble],
+                         active_if=IS_WINDOWS),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_extremal',
+                         device_type='cuda', dtypes=[torch.cdouble],
+                         active_if=IS_WINDOWS),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_extremal',
+                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_large',
+                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_extremal',
+                         device_type='cuda', dtypes=[torch.cdouble],
+                         active_if=IS_WINDOWS),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_large',
+                         device_type='cuda', dtypes=[torch.cdouble],
+                         active_if=IS_WINDOWS),
+            # Failing with wrong imaginary sign on at least some Windows jobs
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_small',
+                         device_type='cuda', dtypes=[torch.cdouble],
+                         active_if=IS_WINDOWS),
+        ),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.asin",
+        torch_opinfo_name="asin",
+        decorators=[
+            DecorateInfo(
+                toleranceOverride({torch.float16: tol(atol=1e-05, rtol=1e-03)}),
+                'TestUnaryUfuncs', device_type='cuda'),
+            DecorateInfo(
+                toleranceOverride({torch.complex64: tol(atol=5e-05, rtol=2e-05)}),
+                'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cpu'
+            ),
+            precisionOverride({torch.bfloat16: 1e-2}),
+        ],
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_extremal',
+                         device_type='cuda', dtypes=[torch.cdouble],
+                         active_if=IS_WINDOWS),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_large',
+                         device_type='cuda', dtypes=[torch.cdouble],
+                         active_if=IS_WINDOWS),
+        ),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.asinh",
+        torch_opinfo_name="asinh",
+        decorators=(precisionOverride({torch.bfloat16: 5e-2}),),
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_extremal',
+                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_large',
+                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_small',
+                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_normal',
+                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_extremal',
+                         device_type='cuda', dtypes=[torch.cdouble],
+                         active_if=IS_WINDOWS),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_large',
+                         device_type='cuda', dtypes=[torch.cdouble],
+                         active_if=IS_WINDOWS),
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.lerp",
+        torch_opinfo_name="lerp",
+    ),
+    PythonRefInfo(
+        "_refs.ones",
+        torch_opinfo_name="ones",
+        skips=(
+            # Tests that assume input is a tensor or sequence of tensors
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.zeros",
+        torch_opinfo_name="zeros",
+        skips=(
+            # Tests that assume input is a tensor or sequence of tensors
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.cauchy",
+        torch_opinfo_name="cauchy",
+        decorators=(
+            # TODO: RuntimeError: no _refs support for torch.rand_like
+            DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"),
+                         'TestCommon',
+                         'test_python_ref'),
+            # AssertionError: Tensor-likes are not close!
+            DecorateInfo(unittest.skip("Expected: cauchy is not comparable"),
+                         'TestCommon',
+                         'test_out'),
+            DecorateInfo(unittest.skip("Expected: cauchy is not comparable"),
+                         'TestCommon',
+                         'test_out_warning'),
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'),
+            DecorateInfo(unittest.skip("Expected: cauchy is not comparable"),
+                         'TestCommon',
+                         'test_python_ref_torch_fallback'),
+            DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+        )
+    ),
+    PythonRefInfo(
+        "_refs.exponential",
+        torch_opinfo_name="exponential",
+        supports_out=True,
+        decorators=(
+            # dtypes that do not support check_uniform_bounds of rand_like
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta',
+                         dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64)),
+            DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_dtypes'),
+
+            # TODO: RuntimeError: no _refs support for torch.rand_like
+            DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"),
+                         'TestCommon',
+                         'test_python_ref'),
+
+            # AssertionError: Tensor-likes are not close!
+            DecorateInfo(unittest.skip("Expected: exponential is not comparable"),
+                         'TestCommon',
+                         'test_out'),
+            DecorateInfo(unittest.skip("Expected: exponential is not comparable"),
+                         'TestCommon',
+                         'test_out_warning'),
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'),
+            DecorateInfo(unittest.skip("Expected: exponential is not comparable"),
+                         'TestCommon',
+                         'test_python_ref_torch_fallback'),
+            DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+        )
+    ),
+    PythonRefInfo(
+        "_refs.geometric",
+        torch_opinfo_name="geometric",
+        supports_out=True,
+        decorators=(
+            # dtypes that do not support check_uniform_bounds of rand_like
+            DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_dtypes'),
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta',
+                         dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64)),
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback',
+                         dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64)),
+
+            # TODO: RuntimeError: no _refs support for torch.rand_like
+            DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"),
+                         'TestCommon',
+                         'test_python_ref'),
+            DecorateInfo(unittest.skip("Expected: geometric is not comparable"),
+                         'TestCommon',
+                         'test_python_ref_executor', device_type='cuda'),
+
+            # AssertionError: Tensor-likes are not close!
+            DecorateInfo(unittest.skip("Expected: geometric is not comparable"),
+                         'TestCommon',
+                         'test_out'),
+            DecorateInfo(unittest.skip("Expected: geometric is not comparable"),
+                         'TestCommon',
+                         'test_out_warning'),
+            DecorateInfo(unittest.skip("Expected: geometric is not comparable"),
+                         'TestCommon',
+                         'test_python_ref_torch_fallback'),
+            DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+        )
+    ),
+    PythonRefInfo(
+        "_refs.log_normal",
+        torch_opinfo_name="log_normal",
+        supports_out=True,
+        decorators=(
+            # TODO: RuntimeError: no _refs support for torch.rand_like
+            DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"),
+                         'TestCommon',
+                         'test_python_ref'),
+            DecorateInfo(unittest.skip("Expected: log_normal is not comparable"),
+                         'TestCommon',
+                         'test_python_ref_executor', device_type='cuda'),
+
+            # AssertionError: Tensor-likes are not close!
+            DecorateInfo(unittest.skip("Expected: log_normal is not comparable"),
+                         'TestCommon',
+                         'test_out'),
+            DecorateInfo(unittest.skip("Expected: log_normal is not comparable"),
+                         'TestCommon',
+                         'test_out_warning'),
+            DecorateInfo(unittest.skip("Expected: log_normal is not comparable"),
+                         'TestCommon',
+                         'test_python_ref_torch_fallback'),
+            DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+        )
+    ),
+    PythonRefInfo(
+        "_refs.normal",
+        torch_opinfo_name="normal",
+        supports_out=True,
+        decorators=(
+            # TODO: RuntimeError: no _refs support for torch.rand_like
+            DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"),
+                         'TestCommon',
+                         'test_python_ref'),
+
+            # AssertionError: Tensor-likes are not close!
+            DecorateInfo(unittest.skip("Expected: normal is not comparable"),
+                         'TestCommon',
+                         'test_out'),
+            DecorateInfo(unittest.skip("Expected: normal is not comparable"),
+                         'TestCommon',
+                         'test_out_warning'),
+            DecorateInfo(unittest.skip("Expected: normal is not comparable"),
+                         'TestCommon',
+                         'test_python_ref_torch_fallback'),
+            DecorateInfo(unittest.skip("Expected: normal is not comparable"), 'TestDecomp', 'test_comprehensive'),
+            DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
+            DecorateInfo(unittest.skip("make_traced() doesn't set seed properly!"), 'TestCommon', 'test_python_ref_executor'),
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
+        )
+    ),
+    PythonRefInfo(
+        "_refs.normal",
+        torch_opinfo_name="normal",
+        torch_opinfo_variant_name="number_mean",
+        supports_out=True,
+        decorators=(
+            # TODO: RuntimeError: no _refs support for torch.rand_like
+            DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"),
+                         'TestCommon',
+                         'test_python_ref'),
+
+            # AssertionError: Tensor-likes are not close!
+            DecorateInfo(unittest.skip("Expected: normal is not comparable"),
+                         'TestCommon',
+                         'test_out'),
+            DecorateInfo(unittest.skip("Expected: normal is not comparable"),
+                         'TestCommon',
+                         'test_out_warning'),
+            DecorateInfo(unittest.skip("Expected: normal is not comparable"),
+                         'TestCommon',
+                         'test_python_ref_torch_fallback'),
+            DecorateInfo(unittest.skip("Expected: normal is not comparable"), 'TestDecomp', 'test_comprehensive'),
+            DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
+            DecorateInfo(unittest.skip("make_traced() doesn't set seed properly!"), 'TestCommon', 'test_python_ref_executor'),
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
+        )
+    ),
+    PythonRefInfo(
+        "_refs.normal_",
+        op=torch.Tensor.normal_,
+        torch_opinfo_name="normal",
+        torch_opinfo_variant_name="in_place",
+        supports_out=False,
+        decorators=(
+            # TODO: RuntimeError: no _refs support for torch.rand_like
+            DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"),
+                         'TestCommon',
+                         'test_python_ref'),
+
+            # AssertionError: Tensor-likes are not close!
+            DecorateInfo(unittest.skip("Expected: normal is not comparable"),
+                         'TestCommon',
+                         'test_out'),
+            DecorateInfo(unittest.skip("Expected: normal is not comparable"),
+                         'TestCommon',
+                         'test_out_warning'),
+            DecorateInfo(unittest.skip("Expected: normal is not comparable"),
+                         'TestCommon',
+                         'test_python_ref_torch_fallback'),
+            DecorateInfo(unittest.skip("Expected: normal is not comparable"), 'TestDecomp', 'test_comprehensive'),
+            DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
+            DecorateInfo(unittest.skip("make_traced() doesn't set seed properly!"), 'TestCommon', 'test_python_ref_executor'),
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
+        )
+    ),
+    PythonRefInfo(
+        "_refs.arange",
+        torch_opinfo_name="arange",
+        skips=(
+            # Tests that assume input is a tensor or sequence of tensors
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.linspace",
+        torch_opinfo_name="linspace",
+        skips=(
+            # Tests that assume input is a tensor or sequence of tensors
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
+
+            # cpu implementation is wrong on some integral types
+            # https://github.com/pytorch/pytorch/issues/81996
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback',
+                         dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64), device_type="cpu"),
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',
+                         dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64), device_type="cpu"),
+
+            # cuda implementation is off-by-one on some inputs due to precision issues
+            # https://github.com/pytorch/pytorch/issues/82230
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback',
+                         dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64),
+                         device_type="cuda"),
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',
+                         dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64),
+                         device_type="cuda"),
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor',
+                         dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64),
+                         device_type="cuda"),
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.linspace",
+        torch_opinfo_name="linspace",
+        torch_opinfo_variant_name="tensor_overload",
+        skips=(
+            # TypeError: 'int' object is not subscriptable
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
+
+            # cpu implementation is wrong on some integral types
+            # https://github.com/pytorch/pytorch/issues/81996
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback',
+                         dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64), device_type="cpu"),
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',
+                         dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64), device_type="cpu"),
+
+            # cuda implementation is off-by-one on some inputs due to precision issues
+            # https://github.com/pytorch/pytorch/issues/82230
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback',
+                         dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64),
+                         device_type="cuda"),
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',
+                         dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64),
+                         device_type="cuda"),
+            # TODO torch.ops.aten.copy is not in _refs
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',
+                         dtypes=(torch.float32, torch.float64, torch.float16, torch.complex64, torch.complex128, torch.bfloat16),
+                         device_type="cuda"),
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',
+                         dtypes=(torch.float32, torch.float64, torch.float16, torch.complex64, torch.complex128, torch.bfloat16),
+                         device_type="cpu"),
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor',
+                         dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64),
+                         device_type="cuda"),
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.logspace",
+        torch_opinfo_name="logspace",
+        skips=(
+            # Tests that assume input is a tensor or sequence of tensors
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
+
+            # Off-by-one issue when casting floats to ints
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback',
+                         dtypes=(torch.int16, torch.int32, torch.int64),
+                         device_type="cuda"),
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',
+                         dtypes=(torch.int16, torch.int32, torch.int64),
+                         device_type="cuda"),
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor',
+                         dtypes=(torch.int16, torch.int32, torch.int64),
+                         device_type="cuda"),
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.logspace",
+        torch_opinfo_name="logspace",
+        torch_opinfo_variant_name="tensor_overload",
+        skips=(
+            # TypeError: 'int' object is not subscriptable
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
+            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
+
+            # Off-by-one issue when casting floats to ints
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback',
+                         dtypes=(torch.int16, torch.int32, torch.int64),
+                         device_type="cuda"),
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',
+                         dtypes=(torch.int16, torch.int32, torch.int64),
+                         device_type="cuda"),
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor',
+                         dtypes=(torch.int16, torch.int32, torch.int64),
+                         device_type="cuda"),
+            # TODO copy doesn't have prim refs
+            DecorateInfo(
+                unittest.expectedFailure, 'TestCommon', 'test_python_ref',
+                dtypes=(
+                    torch.float32, torch.float64, torch.float16, torch.complex64,
+                    torch.complex128, torch.bfloat16, torch.int8, torch.uint8
+                ),
+                device_type="cuda"
+            ),
+            DecorateInfo(
+                unittest.expectedFailure, 'TestCommon', 'test_python_ref',
+                dtypes=(
+                    torch.float32, torch.float64, torch.float16,
+                    torch.complex64, torch.complex128, torch.bfloat16,
+                    torch.int16, torch.int32, torch.int64, torch.int8, torch.uint8
+                ),
+                device_type="cpu"),
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.meshgrid",
+        torch_opinfo_name="meshgrid",
+        torch_opinfo_variant_name="variadic_tensors",
+    ),
+    PythonRefInfo(
+        "_refs.take_along_dim",
+        torch_opinfo_name="take_along_dim",
+        skips=(
+            DecorateInfo(unittest.expectedFailure,
+                         'TestCommon',
+                         'test_python_ref'),
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.to",
+        torch_opinfo_name="to",
+    ),
+    PythonRefInfo(
+        "_refs.triu",
+        torch_opinfo_name="triu",
+    ),
+    PythonRefInfo(
+        "_refs.tril",
+        torch_opinfo_name="tril",
+    ),
+    PythonRefInfo(
+        "_refs.triu_indices",
+        torch_opinfo_name="triu_indices",
+        # the implementation uses torch.stack, which violates view consistency
+        validate_view_consistency=False,
+        skips=(
+            # skip these tests since we have non-tensor inputs
+            DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_noncontiguous_samples'),
+            DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'),
+            DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'),
+            DecorateInfo(unittest.skip('Skipped!'), 'TestMathBits', 'test_neg_view'),
+        )),
+    PythonRefInfo(
+        "_refs.tril_indices",
+        torch_opinfo_name="tril_indices",
+        # the implementation uses torch.stack, which violates view consistency
+        validate_view_consistency=False,
+        skips=(
+            # skip these tests since we have non-tensor inputs
+            DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_noncontiguous_samples'),
+            DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'),
+            DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'),
+            DecorateInfo(unittest.skip('Skipped!'), 'TestMathBits', 'test_neg_view'),
+        )),
+    PythonRefInfo(
+        "_refs.meshgrid",
+        torch_opinfo_name="meshgrid",
+        torch_opinfo_variant_name="list_of_tensors",
+    ),
+    PythonRefInfo(
+        "_refs.movedim",
+        aliases=('moveaxis',),
+        torch_opinfo_name="movedim",
+    ),
+    PythonRefInfo(
+        "_refs.bucketize",
+        torch_opinfo_name="bucketize",
+        skips=(
+            # RuntimeError: It appears that you're trying to get value out of a tracing tensor with
+            #  aten._local_scalar_dense.default - erroring out! [...]
+            # triggered by mid_val = boundaries[mid]
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_python_ref_executor"),
+        )
+    ),
+    PythonRefInfo(
+        "_refs.equal",
+        torch_opinfo_name="equal",
+        skips=(
+            # RuntimeError: Cannot cast FakeTensor to number
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta',),
+        )
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.atan",
+        torch_opinfo_name="atan",
+        decorators=(precisionOverride({torch.bfloat16: 1e-2}),),
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_extremal',
+                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_large',
+                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_small',
+                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_extremal',
+                         device_type='cuda', dtypes=[torch.cfloat, torch.cdouble],
+                         active_if=IS_WINDOWS),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_large',
+                         device_type='cuda', dtypes=[torch.cfloat, torch.cdouble],
+                         active_if=IS_WINDOWS),
+        ),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.atanh",
+        torch_opinfo_name="atanh",
+        decorators=(precisionOverride({torch.bfloat16: 1e-2}),),
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_small',
+                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_extremal',
+                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_large',
+                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_extremal',
+                         device_type='cuda', dtypes=[torch.cfloat, torch.cdouble],
+                         active_if=IS_WINDOWS),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_large',
+                         device_type='cuda', dtypes=[torch.cfloat],
+                         active_if=IS_WINDOWS),
+        ),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.bitwise_not",
+        torch_opinfo_name="bitwise_not",
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.ceil",
+        torch_opinfo_name="ceil",
+        # Fails on int32
+        # https://github.com/pytorch/pytorch/issues/85258
+    ),
+    PythonRefInfo(
+        "_refs.item",
+        torch_opinfo_name="item",
+        skips=(
+            # RuntimeError: Cannot cast FakeTensor(FakeTensor(..., device='meta', size=()), cpu) to number
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta'),
+            # ValueError: Can't convert a tensor with 10 elements to a number!
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.conj_physical",
+        torch_opinfo_name="conj_physical",
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.cos",
+        torch_opinfo_name="cos",
+        decorators=(precisionOverride({torch.bfloat16: 1e-2}),),
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_large',
+                         dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu',
+                         active_if=IS_WINDOWS),
+            # This fails on CUDA but passes on ROCm
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_large',
+                         dtypes=(torch.cdouble,), device_type='cuda'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_extremal',
+                         dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_extremal',
+                         device_type='cpu',
+                         dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS),
+            # AssertionError: Tensor-likes are not close!
+            # Greatest absolute difference: nan at index (700,) (up to 1e-05 allowed)
+            # Greatest relative difference: nan at index (700,) (up to 0.001 allowed)
+            DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs',
+                         'test_reference_numerics_large',
+                         device_type='cuda',
+                         dtypes=(torch.chalf,), active_if=IS_WINDOWS),
+        ),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.cosh",
+        torch_opinfo_name="cosh",
+        skips=(
+            # Reference: https://github.com/pytorch/pytorch/issues/48641
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_large',
+                         device_type='cpu', dtypes=[torch.int8]),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_large',
+                         dtypes=[torch.cdouble]),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_extremal',
+                         dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_large',
+                         dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_extremal',
+                         device_type='cpu',
+                         dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_large',
+                         device_type='cpu',
+                         dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS),
+            # AssertionError: Tensor-likes are not close!
+            # Greatest absolute difference: nan at index (6000,) (up to 1e-05 allowed)
+            # Greatest relative difference: nan at index (6000,) (up to 0.001 allowed)
+            DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs',
+                         'test_reference_numerics_large',
+                         device_type='cuda',
+                         dtypes=(torch.chalf,), active_if=IS_WINDOWS),
+        ),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.digamma",
+        torch_opinfo_name="digamma",
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.erf",
+        torch_opinfo_name="erf",
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.erfinv",
+        torch_opinfo_name="erfinv",
+        decorators=(precisionOverride({torch.float16: 1e-2,
+                                       torch.bfloat16: 1e-2,
+                                       torch.float32: 1e-4}),),
+        skips=(
+            # Reference: https://github.com/pytorch/pytorch/pull/49155#issuecomment-742664611
+            DecorateInfo(
+                unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                'test_reference_numerics_extremal',
+                active_if=TEST_SCIPY and version.parse(scipy.__version__) < version.parse("1.4.0")),
+            DecorateInfo(
+                unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                'test_reference_numerics_large',
+                active_if=TEST_SCIPY and version.parse(scipy.__version__) < version.parse("1.4.0")),
+            DecorateInfo(
+                unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                'test_reference_numerics_small',
+                active_if=TEST_SCIPY and version.parse(scipy.__version__) < version.parse("1.4.0")),
+        ),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.erfc",
+        torch_opinfo_name="erfc",
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.exp",
+        torch_opinfo_name="exp",
+        skips=(
+            # Reference: https://github.com/pytorch/pytorch/issues/48010
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_extremal',
+                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS),
+        ),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.expm1",
+        torch_opinfo_name="expm1",
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.exp2",
+        torch_opinfo_name="exp2",
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_large',
+                         dtypes=[torch.cdouble]),
+            # Reference: https://github.com/pytorch/pytorch/issues/48010
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_extremal',
+                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS),
+        ),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.fill",
+        torch_opinfo_name="fill",
+        supports_out=True,
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.floor",
+        torch_opinfo_name="floor",
+        # Fails on int32
+        # https://github.com/pytorch/pytorch/issues/85258
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.frexp",
+        torch_opinfo_name="frexp",
+        # Skipped due to numerical failures on Windows CI.
+        # This is also skipped in frexp earlier in the file.
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
+                         active_if=IS_WINDOWS),
+        ),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.frac",
+        torch_opinfo_name="frac",
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                'test_reference_numerics_extremal',
+                dtypes=(torch.bfloat16, torch.float16, torch.float32, torch.float64)),
+        ),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.imag",
+        torch_opinfo_name="imag",
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.isfinite",
+        torch_opinfo_name="isfinite",
+        supports_out=True,
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.isinf",
+        torch_opinfo_name="isinf",
+        supports_out=True,
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.isposinf",
+        torch_opinfo_name="isposinf",
+        supports_out=True,
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.isneginf",
+        torch_opinfo_name="isneginf",
+        supports_out=True,
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.isnan",
+        torch_opinfo_name="isnan",
+        supports_out=True,
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.isreal",
+        torch_opinfo_name="isreal",
+        supports_out=True,
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.i0",
+        torch_opinfo_name="i0",
+        decorators=(precisionOverride({torch.bfloat16: 3e-1,
+                                       torch.float16: 5e-1}),),
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"),
+                         'TestUnaryUfuncs',
+                         'test_reference_numerics_large',
+                         dtypes=(torch.int8,)),
+        ),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.lgamma",
+        torch_opinfo_name="lgamma",
+        decorators=(precisionOverride({torch.float16: 7e-1}),),
+        skips=(
+            # Reference: https://github.com/pytorch/pytorch/pull/50140#issuecomment-756150214
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_extremal',
+                         dtypes=[torch.float32, torch.float64], active_if=IS_WINDOWS),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_large',
+                         dtypes=[torch.float32, torch.float64], active_if=IS_WINDOWS),
+        ),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.special.multigammaln",
+        torch_opinfo_name="mvlgamma",
+        torch_opinfo_variant_name="mvlgamma_p_1",
+        skips=skips_mvlgamma(),
+        decorators=(
+            DecorateInfo(torch.testing._internal.common_utils.markDynamoStrictTest, 'TestUnaryUfuncs',
+                         'test_reference_numerics_large'),
+            DecorateInfo(torch.testing._internal.common_utils.xfailIfTorchDynamo, 'TestUnaryUfuncs',
+                         'test_reference_numerics_large'),
+        ),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.special.multigammaln",
+        torch_opinfo_name="mvlgamma",
+        torch_opinfo_variant_name="mvlgamma_p_3",
+        skips=skips_mvlgamma(),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.special.multigammaln",
+        torch_opinfo_name="mvlgamma",
+        torch_opinfo_variant_name="mvlgamma_p_5",
+        skips=skips_mvlgamma(),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.log",
+        torch_opinfo_name="log",
+        decorators=(precisionOverride({torch.bfloat16: 5e-2}),),
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_extremal',
+                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
+                         active_if=IS_WINDOWS),
+        ),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.log1p",
+        torch_opinfo_name="log1p",
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.log10",
+        torch_opinfo_name="log10",
+        decorators=(precisionOverride({torch.bfloat16: 5e-2}),),
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_extremal',
+                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
+                         active_if=IS_WINDOWS),
+        ),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.log2",
+        torch_opinfo_name="log2",
+        decorators=(precisionOverride({torch.bfloat16: 1e-1}),),
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_extremal',
+                         dtypes=[torch.cfloat, torch.cdouble]),
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.logsumexp",
+        torch_opinfo_name="logsumexp",
+        # When keepdim=False, logsumexp uses a squeeze operation
+        # that is not yet exposed in nvFuser's Python API.
+    ),
+    PythonRefInfo(
+        "_refs.log_softmax",
+        torch_opinfo_name="log_softmax",
+        torch_opinfo_variant_name="with_dtype",
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.nan_to_num",
+        torch_opinfo_name="nan_to_num",
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.neg",
+        torch_opinfo_name="neg",
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.positive",
+        torch_opinfo_name="positive",
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.real",
+        torch_opinfo_name="real",
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.reciprocal",
+        torch_opinfo_name="reciprocal",
+        skips=(
+            # Reference: https://github.com/pytorch/pytorch/issues/45690
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_extremal',
+                         dtypes=[torch.cfloat, torch.cdouble]),
+        ),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.round",
+        torch_opinfo_name="round",
+        # Fails on int32
+        # https://github.com/pytorch/pytorch/issues/85258
+        skips=(
+            DecorateInfo(toleranceOverride({torch.bfloat16: tol(atol=1e-3, rtol=0.016)}),
+                         "TestUnaryUfuncs", "test_reference_numerics_extremal",
+                         device_type="cuda"),
+            DecorateInfo(toleranceOverride({torch.bfloat16: tol(atol=1e-3, rtol=0.016)}),
+                         "TestUnaryUfuncs", "test_reference_numerics_normal",
+                         device_type="cuda"),
+        ),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.rsqrt",
+        torch_opinfo_name="rsqrt",
+        decorators=(precisionOverride({torch.half: 5e-2}),),
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_extremal',
+                         dtypes=(torch.cfloat, torch.cdouble)),
+            # AssertionError: Tensor-likes are not close!
+            # Greatest absolute difference: nan at index (700,) (up to 0.01 allowed)
+            # Greatest relative difference: nan at index (700,) (up to 0.001 allowed)
+            DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs',
+                         'test_reference_numerics_large',
+                         dtypes=(torch.chalf,)),
+        ),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.sigmoid",
+        torch_opinfo_name="sigmoid",
+        aliases=('_refs.special.expit',),
+        # Reference: https://github.com/pytorch/pytorch/issues/56012
+        handles_complex_extremal_values=False,
+        handles_large_floats=False,
+        decorators=(precisionOverride({torch.float16: 1e-2,
+                                       torch.complex64: 1e-1,
+                                       torch.bfloat16: 1e-2}),),
+        skips=(
+            # Reference: https://github.com/pytorch/pytorch/issues/56012
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_extremal',
+                         dtypes=[torch.complex64, torch.cdouble], device_type='cuda'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_large',
+                         dtypes=[torch.chalf, torch.complex64, torch.cdouble], device_type='cuda')
+        ),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.sign",
+        torch_opinfo_name="sign",
+        skips=(
+            # Reference: https://github.com/pytorch/pytorch/issues/41245
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_extremal',
+                         dtypes=[torch.bfloat16, torch.float16, torch.float32,
+                                 torch.float64]),
+        ),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.sgn",
+        torch_opinfo_name="sgn",
+        # This is an issue with the vectorised abs on CPU
+        handles_complex_extremal_values=False,
+        handles_large_floats=False,
+        skips=(
+            # Reference: https://github.com/pytorch/pytorch/issues/41245
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_extremal',
+                         dtypes=[torch.bfloat16, torch.float16, torch.float32,
+                                 torch.float64]),
+        ),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.signbit",
+        torch_opinfo_name="signbit",
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.sin",
+        torch_opinfo_name="sin",
+        decorators=(precisionOverride({torch.bfloat16: 1e-2}),),
+        skips=(
+            # Fails on CUDA but passes on ROCm
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_large',
+                         dtypes=(torch.cdouble,), device_type='cuda'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_extremal',
+                         dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu',
+                         active_if=IS_WINDOWS),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_large',
+                         dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu',
+                         active_if=IS_WINDOWS),
+        ),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.sinc",
+        torch_opinfo_name="sinc",
+        decorators=(precisionOverride({torch.bfloat16: 1e-2,
+                                       torch.float16: 1e-2}),),
+        skips=(
+            # Reference: https://github.com/pytorch/pytorch/issues/49133
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_small',
+                         dtypes=[torch.cfloat]),
+        ),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.sinh",
+        torch_opinfo_name="sinh",
+        decorators=(precisionOverride({torch.float16: 1e-2}),),
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_extremal',
+                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
+                         active_if=(IS_MACOS or IS_WINDOWS)),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_large',
+                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
+                         active_if=(IS_MACOS or IS_WINDOWS)),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_large',
+                         dtypes=(torch.cdouble,)),
+            # Reference: https://github.com/pytorch/pytorch/issues/48641
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_large',
+                         device_type='cpu', dtypes=[torch.int8]),
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.softmax",
+        torch_opinfo_name="softmax",
+        torch_opinfo_variant_name="with_dtype",
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.sqrt",
+        torch_opinfo_name="sqrt",
+        decorators=(
+            precisionOverride({torch.bfloat16: 7e-2}),
+            DecorateInfo(
+                toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=0)}),
+                'TestUnaryUfuncs', 'test_reference_numerics_large'),
+        ),
+        skips=(
+            # Reference: https://github.com/pytorch/pytorch/issues/47358
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_large',
+                         device_type='cpu', dtypes=(torch.cfloat, torch.cdouble),
+                         active_if=IS_MACOS),
+            # Reference: https://github.com/pytorch/pytorch/pull/47293#issuecomment-721774436
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_large',
+                         dtypes=(torch.bfloat16,)),
+        ),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.square",
+        torch_opinfo_name="square",
+        decorators=(precisionOverride({torch.complex64: 3e-4, torch.bfloat16: 3e-1}),),
+        skips=(
+            # AssertionError: Reference result was farther (2.2417024338305655e-07) from the precise computation
+            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_executor', dtypes=(torch.complex64,)),
+            # Reference: https://github.com/pytorch/pytorch/issues/52549
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_large',
+                         dtypes=[torch.cfloat, torch.cdouble]),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_extremal',
+                         device_type='cuda', dtypes=[torch.cfloat, torch.cdouble]),
+        ),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.tan",
+        torch_opinfo_name="tan",
+        decorators=[
+            DecorateInfo(
+                toleranceOverride({torch.complex64: tol(atol=1e-04, rtol=1e-05)}),
+                'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cuda'),
+        ],
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_extremal',
+                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
+                         active_if=(IS_MACOS or IS_WINDOWS)),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_large',
+                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
+                         active_if=(IS_MACOS or IS_WINDOWS)),
+        )
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.tanh",
+        torch_opinfo_name="tanh",
+        decorators=[
+            DecorateInfo(
+                toleranceOverride({torch.complex64: tol(atol=1e-04, rtol=2e-05)}),
+                'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cuda'),
+        ],
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_extremal',
+                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
+                         active_if=(IS_MACOS or IS_WINDOWS)),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_large',
+                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
+                         active_if=(IS_MACOS or IS_WINDOWS)),
+        ),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.trunc",
+        torch_opinfo_name="trunc",
+        # Fails on int32
+        # https://github.com/pytorch/pytorch/issues/85258
+    ),
+    PythonRefInfo(
+        "_refs.special.log_softmax",
+        torch_opinfo_name="log_softmax",  # alias
+        torch_opinfo_variant_name="with_dtype",
+        supports_out=False,
+    ),
+    PythonRefInfo(
+        "_refs.special.softmax",
+        torch_opinfo_name="softmax",  # alias
+        torch_opinfo_variant_name="with_dtype",
+        supports_out=False,
+    ),
+    #
+    # Elementwise Unary Special OpInfos
+    #
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.special.logit",
+        torch_opinfo_name="logit",
+    ),
+    #
+    # Elementwise Unary nn.functional OpInfos
+    #
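+    # Dropout-style refs below disable reference comparisons because their
+    # output is random and therefore not comparable against the ATen result.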
+    PythonRefInfo(
+        "_refs.nn.functional.alpha_dropout",
+        torch_opinfo_name="nn.functional.alpha_dropout",
+        decorators=(
+            DecorateInfo(unittest.skip("Expected: dropout is not comparable"),
+                         'TestCommon',
+                         'test_python_ref'),
+            # AssertionError: Tensor-likes are not close!
+            DecorateInfo(unittest.skip("Expected: dropout is not comparable"),
+                         'TestCommon',
+                         'test_python_ref_torch_fallback'),
+            DecorateInfo(unittest.skip("Expected: dropout is not comparable"),
+                         'TestCommon',
+                         'test_python_ref_executor', device_type='cuda'),
+            # AssertionError: Tensor-likes are not close!
+            DecorateInfo(unittest.skip("Expected: dropout is not comparable"),
+                         'TestMathBits',
+                         'test_neg_view'),
+            # AssertionError: Tensor-likes are not close!
+            DecorateInfo(unittest.skip("Expected: dropout is not comparable"),
+                         'TestCommon',
+                         'test_compare_cpu'),
+        )
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.nn.functional.celu",
+        torch_opinfo_name="nn.functional.celu",
+        supports_out=True,
+    ),
+    PythonRefInfo(
+        "_refs.nn.functional.channel_shuffle",
+        torch_opinfo_name="nn.functional.channel_shuffle",
+        supports_out=True,
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.nn.functional.threshold",
+        torch_opinfo_name="nn.functional.threshold",
+        supports_out=True,
+    ),
+    PythonRefInfo(
+        "_refs.nn.functional.dropout",
+        torch_opinfo_name="nn.functional.dropout",
+        decorators=(
+            DecorateInfo(unittest.skip("Expected: dropout is not comparable"),
+                         'TestCommon',
+                         'test_python_ref'),
+            DecorateInfo(unittest.skip("Expected: dropout is not comparable"),
+                         'TestCommon',
+                         'test_python_ref_torch_fallback'),
+            DecorateInfo(unittest.skip("Expected: dropout is not comparable"),
+                         'TestCommon',
+                         'test_out'),
+            DecorateInfo(unittest.skip("Expected: dropout is not comparable"),
+                         'TestCommon',
+                         'test_out_warning'),
+            DecorateInfo(unittest.skip("Expected: dropout is not comparable"),
+                         'TestMathBits',
+                         'test_conj_view'),
+            DecorateInfo(unittest.skip("Expected: dropout is not comparable"),
+                         'TestMathBits',
+                         'test_neg_conj_view'),
+            DecorateInfo(unittest.skip("Expected: dropout is not comparable"),
+                         'TestMathBits',
+                         'test_neg_view'),
+            # dropout is not comparable
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'),
+            DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
+        )
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.nn.functional.elu",
+        torch_opinfo_name="nn.functional.elu",
+        supports_out=True,
+        decorators=[
+            DecorateInfo(
+                toleranceOverride({
+                    torch.float16: tol(atol=1e-03, rtol=1.2e-03),
+                    torch.bfloat16: tol(atol=1e-03, rtol=1.2e-03)
+                }),
+                'TestUnaryUfuncs', device_type='cuda',
+            ), ],
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.nn.functional.hardtanh",
+        torch_opinfo_name="nn.functional.hardtanh",
+        supports_out=True,
+    ),
+    PythonRefInfo(  # TODO: Port this to an UnaryOpInfo
+        "_refs.nn.functional.gelu",
+        torch_opinfo_name="nn.functional.gelu",
+    ),
+    PythonRefInfo(
+        "_refs.nn.functional.layer_norm",
+        torch_opinfo_name="nn.functional.layer_norm",
+        skips=(
+            # Reference result was farther (3.5762786809723224e-07) from the precise computation
+            # than the torch result was (2.5068410824946596e-07)!
+            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref',
+                         dtypes=(torch.float32,), device_type='cpu'),
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.nn.functional.glu",
+        torch_opinfo_name="nn.functional.glu",
+        supports_out=True,
+    ),
+    PythonRefInfo(
+        "_refs.nn.functional.pairwise_distance",
+        torch_opinfo_name="nn.functional.pairwise_distance",
+        supports_out=True,
+    ),
+    PythonRefInfo(
+        "_refs.nn.functional.pdist",
+        torch_opinfo_name="nn.functional.pdist",
+        supports_out=True,
+        skips=(
+            # RuntimeError: no _refs support for torch.Tensor.index_select
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),
+            # Reference result was farther (1.946091651916504e-05) from the precise
+            # computation than the torch result was (1.1920928955078125e-06)!
+            DecorateInfo(
+                unittest.expectedFailure,
+                'TestCommon',
+                'test_python_ref_torch_fallback',
+                dtypes=(torch.float32,),
+                device_type='cpu',
+            ),
+        )),
+    PythonRefInfo(
+        "_refs.nn.functional.leaky_relu",
+        torch_opinfo_name="nn.functional.leaky_relu",
+        supports_out=True,
+    ),
+    PythonRefInfo(
+        "_refs.nn.functional.log_softmax",
+        torch_opinfo_name="log_softmax",  # alias
+        torch_opinfo_variant_name="with_dtype",
+        supports_out=False,
+    ),
+    PythonRefInfo(
+        "_refs.nn.functional.pixel_shuffle",
+        torch_opinfo_name="nn.functional.pixel_shuffle",
+    ),
+    PythonRefInfo(
+        "_refs.nn.functional.pixel_unshuffle",
+        torch_opinfo_name="nn.functional.pixel_unshuffle",
+    ),
+    PythonRefInfo(
+        "_refs.nn.functional.poisson_nll_loss",
+        torch_opinfo_name="nn.functional.poisson_nll_loss",
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.nn.functional.prelu",
+        torch_opinfo_name="nn.functional.prelu",
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.nn.functional.relu",
+        torch_opinfo_name="nn.functional.relu",
+        supports_out=True,
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.nn.functional.relu6",
+        torch_opinfo_name="nn.functional.relu6",
+        supports_out=True,
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.nn.functional.mish",
+        torch_opinfo_name="nn.functional.mish",
+        supports_out=True,
+        decorators=[
+            DecorateInfo(
+                toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-03)}),
+                'TestUnaryUfuncs',), ],
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.nn.functional.selu",
+        torch_opinfo_name="nn.functional.selu",
+        supports_out=True,
+        decorators=[
+            DecorateInfo(
+                toleranceOverride({
+                    torch.float16: tol(atol=1e-2, rtol=1.8e-2),
+                    torch.bfloat16: tol(atol=1e-2, rtol=1.8e-2)
+                }),
+                'TestUnaryUfuncs', device_type='cuda',
+            ), ],
+    ),
+    PythonRefInfo(
+        "_refs.nn.functional.softmax",
+        torch_opinfo_name="softmax",  # alias
+        torch_opinfo_variant_name="with_dtype",
+        supports_out=False,
+    ),
+    PythonRefInfo(
+        "_refs.nn.functional.softmin",
+        torch_opinfo_name="nn.functional.softmin",
+        torch_opinfo_variant_name="with_dtype",
+        supports_out=False,
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.nn.functional.softplus",
+        torch_opinfo_name="nn.functional.softplus",
+    ),
+    PythonRefInfo(
+        "_refs.nn.functional.l1_loss",
+        torch_opinfo_name="nn.functional.l1_loss",
+    ),
+    PythonRefInfo(
+        "_refs.nn.functional.margin_ranking_loss",
+        torch_opinfo_name="nn.functional.margin_ranking_loss",
+    ),
+    PythonRefInfo(
+        "_refs.nn.functional.mse_loss",
+        torch_opinfo_name="nn.functional.mse_loss",
+    ),
+    PythonRefInfo(
+        "_refs.nn.functional.smooth_l1_loss",
+        torch_opinfo_name="nn.functional.smooth_l1_loss",
+    ),
+    PythonRefInfo(
+        "_refs.nn.functional.hinge_embedding_loss",
+        torch_opinfo_name="nn.functional.hinge_embedding_loss"
+    ),
+    PythonRefInfo(
+        "_refs.nn.functional.nll_loss",
+        torch_opinfo_name="nn.functional.nll_loss",
+        # The corresponding PyTorch op doesn't support out.  But the ref is
+        # registered as a decomp and ATen has an out variant.
+        supports_out=True,
+        # For simpler indexing, we flatten target indices, then reshape the result tensor.
+        # This creates inconsistent view state with the reference impl.
+        validate_view_consistency=False,
+        skips=(
+            # RuntimeError: It appears that you're trying to get value out of a tracing tensor - erroring out!
+            DecorateInfo(
+                unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor', device_type="cuda"
+            ),
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.nn.functional.huber_loss",
+        torch_opinfo_name="nn.functional.huber_loss",
+        # The corresponding PyTorch op doesn't support out.  But the ref is
+        # registered as a decomp and ATen has an out variant.
+        supports_out=True,
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.nn.functional.tanhshrink",
+        torch_opinfo_name="nn.functional.tanhshrink",
+        decorators=[
+            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
+                         'test_reference_numerics_normal',
+                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
+            DecorateInfo(
+                toleranceOverride({torch.bfloat16: tol(atol=1e-02, rtol=1.6e-02),
+                                   torch.complex64: tol(atol=6e-04, rtol=1e-05)}),
+                'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cuda'),
+        ],
+        skips=(
+            # in each case, pytorch will produce a nan while numpy will not
+            DecorateInfo(unittest.skip("Fails on some jobs works on others!"),
+                         'TestUnaryUfuncs', "test_reference_numerics_large",
+                         dtypes=(torch.complex64, torch.complex128),
+                         active_if=(IS_MACOS)),
+            DecorateInfo(unittest.skip("Fails on some jobs works on others!"),
+                         'TestUnaryUfuncs', "test_reference_numerics_extremal",
+                         dtypes=(torch.complex64, torch.complex128),
+                         device_type='cpu',
+                         active_if=(IS_MACOS or IS_WINDOWS)),
+        ),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.nn.functional.hardshrink",
+        torch_opinfo_name="nn.functional.hardshrink",
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.nn.functional.softshrink",
+        torch_opinfo_name="nn.functional.softshrink",
+    ),
+    #
+    # Elementwise Binary Reference OpInfos
+    #
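+    # supports_one_python_scalar / supports_two_python_scalars opt an entry into
+    # sample inputs where one or both operands are Python scalars rather than
+    # tensors (see the issue 76944 links below).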
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.add",
+        torch_opinfo_name="add",
+        # https://github.com/pytorch/pytorch/issues/76944
+        supports_two_python_scalars=True,
+        supports_one_python_scalar=True,
+        decorators=(
+            DecorateInfo(
+                toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=0)}),
+                'TestBinaryUfuncs', 'test_reference_numerics'),
+        ),
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"),
+                         'TestBinaryUfuncs',
+                         'test_reference_numerics_extremal_values',
+                         dtypes=(torch.complex64, torch.complex128)),
+        ),
+    ),
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.atan2",
+        torch_opinfo_name="atan2",
+    ),
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.bitwise_and",
+        torch_opinfo_name="bitwise_and",
+    ),
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.bitwise_left_shift",
+        torch_opinfo_name="bitwise_left_shift",
+        skips=(
+            # https://github.com/pytorch/pytorch/issues/70904
+            DecorateInfo(unittest.skip("Some inputs produce undefined outputs"), 'TestCommon', 'test_compare_cpu'),
+        ),
+    ),
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.bitwise_right_shift",
+        torch_opinfo_name="bitwise_right_shift",
+        skips=(
+            # https://github.com/pytorch/pytorch/issues/70904
+            DecorateInfo(unittest.skip("Skipped: some inputs produce undefined outputs"), 'TestCommon', 'test_compare_cpu'),
+        ),
+    ),
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.bitwise_or",
+        torch_opinfo_name="bitwise_or",
+    ),
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.bitwise_xor",
+        torch_opinfo_name="bitwise_xor",
+    ),
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.copysign",
+        torch_opinfo_name="copysign",
+        skips=(
+            # RuntimeError: Expected divisor (b) to be on the same device (cuda:0) as dividend (a), but it is found on cpu!
+            DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_type_promotion'),
+            # FIXME output 0: meta disagrees with real impl
+            DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'),
+        )
+    ),
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.div",
+        torch_opinfo_name="div",
+        torch_opinfo_variant_name="no_rounding_mode",
+        # https://github.com/pytorch/pytorch/issues/76944
+        supports_two_python_scalars=True,
+        supports_one_python_scalar=True,
+        skips=(
+            # NotImplementedError: argument of type: <class 'complex'>
+            DecorateInfo(
+                unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_executor',
+                dtypes=(torch.complex32, torch.complex64, torch.complex128,)
+            ),
+            # Reference result was farther (0.7433461727239705) from the precise
+            # computation than the torch result was (nan)!
+            DecorateInfo(
+                unittest.expectedFailure, 'TestCommon', 'test_python_ref',
+                dtypes=(torch.complex32,), device_type="cuda"
+            ),
+            # Reference result was farther (0.7433461727239705) from the precise
+            # computation than the torch result was (nan)!
+            DecorateInfo(
+                unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback',
+                dtypes=(torch.complex32,), device_type="cuda"
+            ),
+        ),
+    ),
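+    # The next two entries exercise div's rounding_mode='trunc' and
+    # rounding_mode='floor' variants.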
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.div",
+        torch_opinfo_name="div",
+        torch_opinfo_variant_name="trunc_rounding",
+        # https://github.com/pytorch/pytorch/issues/76944
+        supports_two_python_scalars=True,
+        supports_one_python_scalar=True,
+        decorators=(
+            # See https://github.com/pytorch/pytorch/issues/111126
+            DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'),
+        ),
+    ),
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.div",
+        torch_opinfo_name="div",
+        torch_opinfo_variant_name="floor_rounding",
+        # https://github.com/pytorch/pytorch/issues/76944
+        supports_two_python_scalars=True,
+        supports_one_python_scalar=True,
+        decorators=(
+            # See https://github.com/pytorch/pytorch/issues/111126
+            DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'),
+            # Reference result was farther (nan) from the precise computation than the
+            # torch result was (inf)!
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestCommon",
+                "test_python_ref",
+                dtypes=(torch.bfloat16,),
+                device_type="cpu",
+                active_if=not IS_S390X,
+            ),
+        ),
+    ),
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.eq",
+        torch_opinfo_name="eq",
+    ),
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.float_power",
+        torch_opinfo_name="float_power",
+        skips=(
+            # Test doesn't account for float -> double type promotion
+            DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'),
+            # Complex values error with: Greatest absolute difference: nan at index
+            DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
+                         'test_reference_numerics_small_values',
+                         dtypes=[torch.complex64, torch.complex128]),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
+                         'test_reference_numerics_large_values',
+                         dtypes=[torch.complex64, torch.complex128]),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
+                         'test_reference_numerics_extremal_values',
+                         dtypes=[torch.complex64, torch.complex128]),
+        ),
+    ),
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.logaddexp",
+        torch_opinfo_name="logaddexp",
+        skips=(
+            # failure due to mismatch in edge cases, which boils down to what torch.exp(inf + infj) should be
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', device_type='cpu',
+                         dtypes=(torch.complex64, torch.complex128)),
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', device_type='cpu',
+                         dtypes=(torch.complex64, torch.complex128)),
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.logaddexp2",
+        torch_opinfo_name="logaddexp2",
+    ),
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.floor_divide",
+        torch_opinfo_name="floor_divide",
+        rhs_make_tensor_kwargs=dict(exclude_zero=True),
+        # https://github.com/pytorch/pytorch/issues/76944
+        supports_two_python_scalars=True,
+        supports_one_python_scalar=True,
+        # bfloat16 floor_divide compared with a float32 reference works inconsistently
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref',
+                         dtypes=(torch.bfloat16,)),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_torch_fallback',
+                         dtypes=(torch.bfloat16,)),
+            # bfloat16 floor_divide compared with a float32 reference works inconsistently
+            DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs',
+                         dtypes=(torch.bfloat16,)),
+            # int8 floor divide has different results for -128 // -1 vs. NumPy
+            DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs',
+                         'test_reference_numerics_small_values',
+                         dtypes=(torch.int8,)),
+            # The following test fails on some jobs
+            DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs',
+                         'test_reference_numerics_extremal_values',
+                         dtypes=(torch.float16,)),
+            DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=5e-3)}),
+                         'TestBinaryUfuncs', 'test_reference_numerics'),
+            # FIXME output 0: meta disagrees with real impl
+            DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'),
+        ),
+    ),
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.fmax",
+        torch_opinfo_name="fmax",
+        supports_rhs_python_scalar=False,
+    ),
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.fmin",
+        torch_opinfo_name="fmin",
+        supports_rhs_python_scalar=False,
+    ),
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.fmod",
+        torch_opinfo_name="fmod",
+        rhs_make_tensor_kwargs={'exclude_zero': True},
+        supports_rhs_python_scalar=True,
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref',
+                         dtypes=(torch.bfloat16,), device_type='cpu'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_torch_fallback',
+                         dtypes=(torch.bfloat16,), device_type='cpu'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
+                         'test_contig_vs_every_other',
+                         dtypes=(torch.bfloat16,)),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
+                         'test_non_contig',
+                         dtypes=(torch.bfloat16,)),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
+                         'test_reference_numerics',
+                         dtypes=(torch.bfloat16,)),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
+                         'test_reference_numerics_small_values',
+                         dtypes=(torch.uint8,)),
+        ),
+    ),
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.gcd",
+        torch_opinfo_name="gcd",
+        skips=(
+            DecorateInfo(unittest.expectedFailure,
+                         'TestBinaryUfuncs',
+                         'test_reference_numerics_small_values',
+                         dtypes=(torch.int8,)),
+        ),
+    ),
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.ge",
+        torch_opinfo_name="ge",
+    ),
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.gt",
+        torch_opinfo_name="gt",
+    ),
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.heaviside",
+        torch_opinfo_name="heaviside",
+        supports_rhs_python_scalar=False,
+        skips=(
+            # PyTorch's heaviside does not appear to propagate NaNs
+            DecorateInfo(unittest.skip("Skipped!"),
+                         'TestBinaryUfuncs',
+                         'test_reference_numerics_extremal_values'),
+        ),
+    ),
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.hypot",
+        torch_opinfo_name="hypot",
+        supports_rhs_python_scalar=False,
+    ),
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.igamma",
+        torch_opinfo_name="igamma",
+    ),
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.igammac",
+        torch_opinfo_name="igammac",
+    ),
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.isclose",
+        torch_opinfo_name="isclose",
+        skips=(
+            # Intentional xfail -- isclose does not type promote
+            DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'),
+            DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'),
+            DecorateInfo(unittest.skip("Skipped!"),
+                         'TestBinaryUfuncs',
+                         'test_reference_numerics_extremal_values'),
+        ),
+    ),
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.lcm",
+        torch_opinfo_name="lcm",
+    ),
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.le",
+        torch_opinfo_name="le",
+    ),
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.logical_and",
+        torch_opinfo_name="logical_and",
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.logical_not",
+        torch_opinfo_name="logical_not",
+    ),
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.logical_or",
+        torch_opinfo_name="logical_or",
+    ),
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.logical_xor",
+        torch_opinfo_name="logical_xor",
+    ),
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.lt",
+        torch_opinfo_name="lt",
+    ),
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.maximum",
+        torch_opinfo_name="maximum",
+        skips=(
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),
+        ),
+    ),
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.minimum",
+        torch_opinfo_name="minimum",
+        skips=(
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),
+        ),
+    ),
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.mul",
+        torch_opinfo_name="mul",
+        # https://github.com/pytorch/pytorch/issues/76944
+        supports_two_python_scalars=True,
+        supports_one_python_scalar=True,
+        skips=(
+            # Reference result was farther (0.0) from the precise computation
+            # than the torch result was (nan)!
+            DecorateInfo(
+                unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor',
+                dtypes=(torch.complex32,),
+            ),
+            # Reference result was farther (0.0) from the precise computation
+            # than the torch result was (nan)!
+            DecorateInfo(
+                unittest.expectedFailure, 'TestCommon', 'test_python_ref',
+                dtypes=(torch.complex32,), device_type='cuda'
+            ),
+            # Reference result was farther (0.0) from the precise computation
+            # than the torch result was (nan)!
+            DecorateInfo(
+                unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback',
+                dtypes=(torch.complex32,), device_type='cuda'
+            ),
+        )
+    ),
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.ne",
+        torch_opinfo_name="ne",
+    ),
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.nextafter",
+        torch_opinfo_name="nextafter",
+    ),
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.pow",
+        torch_opinfo_name="pow",
+        decorators=(
+            DecorateInfo(
+                toleranceOverride({torch.complex64: tol(atol=1e-4, rtol=1.3e-05)}),
+                'TestBinaryUfuncs', 'test_reference_numerics'),
+            DecorateInfo(
+                toleranceOverride({torch.complex64: tol(atol=1e-4, rtol=1.3e-05),
+                                   torch.complex128: tol(atol=1e-4, rtol=1.3e-05)}),
+                'TestBinaryUfuncs', 'test_scalar_support'),
+        ),
+        skips=(
+            # Reference result was farther (inf) from the precise
+            # computation than the torch result was (nan)!
+            DecorateInfo(
+                unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor',
+                dtypes=(torch.complex32,),
+            ),
+            # Reference result was farther (inf) from the precise
+            # computation than the torch result was (nan)!
+            DecorateInfo(
+                unittest.expectedFailure, 'TestCommon', 'test_python_ref',
+                dtypes=(torch.complex32,), device_type="cuda"
+            ),
+            # Reference result was farther (inf) from the precise
+            # computation than the torch result was (nan)!
+            DecorateInfo(
+                unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback',
+                dtypes=(torch.complex32,), device_type="cuda"
+            ),
+            # Skipping integers because raising them to negative powers causes an error
+            DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs',
+                         'test_reference_numerics_small_values',
+                         dtypes=[torch.int8, torch.int16, torch.int32, torch.int64]),
+            DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs',
+                         'test_reference_numerics_large_values',
+                         dtypes=[torch.int16, torch.int32, torch.int64]),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
+                         'test_reference_numerics',
+                         dtypes=(torch.complex32,)),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
+                         'test_reference_numerics_small_values',
+                         dtypes=(torch.complex32, torch.complex64, torch.complex128)),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
+                         'test_reference_numerics_large_values',
+                         dtypes=(torch.complex32, torch.complex64, torch.complex128)),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
+                         'test_reference_numerics_extremal_values',
+                         dtypes=(torch.complex32, torch.complex64, torch.complex128)),
+        ),
+    ),
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.remainder",
+        torch_opinfo_name="remainder",
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref',
+                         dtypes=(torch.bfloat16,), device_type='cpu'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_torch_fallback',
+                         dtypes=(torch.bfloat16,), device_type='cpu'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
+                         'test_reference_numerics',
+                         dtypes=(torch.bfloat16,)),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
+                         'test_reference_numerics_small_values',
+                         dtypes=(torch.uint8,)),
+        ),
+    ),
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.rsub",
+        torch_opinfo_name="rsub",
+        # https://github.com/pytorch/pytorch/issues/76944
+        skips=(
+            # Reference result was farther (nan) from the precise computation than
+            # the torch result was (nan)!
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',
+                         dtypes=(torch.chalf,), device_type='cpu'),
+            # Reference result was farther (nan) from the precise computation than
+            # the torch result was (nan)!
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback',
+                         dtypes=(torch.chalf,), device_type='cpu'),
+        ),
+    ),
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.sub",
+        torch_opinfo_name="sub",
+        # https://github.com/pytorch/pytorch/issues/76944
+        supports_two_python_scalars=True,
+        supports_one_python_scalar=True,
+        decorators=(
+            DecorateInfo(
+                toleranceOverride({torch.float16: tol(atol=1e-2, rtol=0),
+                                   torch.bfloat16: tol(atol=1e-5, rtol=5e-3),
+                                   torch.complex32: tol(atol=1e-5, rtol=1e-3)}),
+                'TestBinaryUfuncs', 'test_reference_numerics'),
+            DecorateInfo(
+                toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=0)}),
+                'TestCommon', 'test_complex_half_reference_testing', device_type='cpu'),
+            DecorateInfo(
+                toleranceOverride({torch.chalf: tol(atol=5e-3, rtol=0)}),
+                'TestDecomp', 'test_comprehensive', device_type='cpu'),
+            DecorateInfo(
+                toleranceOverride({torch.chalf: tol(atol=5e-3, rtol=0)}),
+                'TestDecomp', 'test_quick', device_type='cpu'),
+        ),
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"),
+                         'TestBinaryUfuncs',
+                         'test_reference_numerics',
+                         dtypes=(torch.uint8,)),
+            DecorateInfo(unittest.skip("Skipped!"),
+                         'TestBinaryUfuncs',
+                         'test_reference_numerics_small_values',
+                         dtypes=(torch.uint8,)),
+        ),
+    ),
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.true_divide",
+        torch_opinfo_name="true_divide",
+        # https://github.com/pytorch/pytorch/issues/76944
+        supports_two_python_scalars=True,
+        supports_one_python_scalar=True,
+        skips=(
+            # Reference result was farther (0.7433461727239705) from the precise
+            # computation than the torch result was (nan)!
+            DecorateInfo(
+                unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor',
+                dtypes=(torch.complex32,),
+            ),
+            # Reference result was farther (0.7433461727239705) from the precise
+            # computation than the torch result was (nan)!
+            DecorateInfo(
+                unittest.expectedFailure, 'TestCommon', 'test_python_ref',
+                dtypes=(torch.complex32,), device_type="cuda"
+            ),
+            # Reference result was farther (0.7433461727239705) from the precise
+            # computation than the torch result was (nan)!
+            DecorateInfo(
+                unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback',
+                dtypes=(torch.complex32,), device_type="cuda"
+            ),
+        ),
+    ),
+    #
+    # Elementwise Ternary Reference OpInfos
+    #
+    PythonRefInfo(
+        "_refs.addcdiv",
+        torch_opinfo_name="addcdiv",
+    ),
+    PythonRefInfo(
+        "_refs.addcmul",
+        torch_opinfo_name="addcmul",
+        skips=(
+            # Reference result was farther (1.3343989849090576e-05)
+            # from the precise computation than the torch result
+            # was (9.592622518539429e-06)!
+            # FIXME: enable dtype-based tolerances in test_ops.py:TestCommon._ref_test_helper
+            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref',
+                         dtypes=(torch.float16,), device_type="cpu"),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_torch_fallback',
+                         dtypes=(torch.float16,), device_type="cpu"),
+        ),
+    ),
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.clamp_min",
+        torch_opinfo_name="clamp_min",
+        skips=(
+            # error test disabled since an rhs non-tensor Python scalar is supported
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),
+        ),
+    ),
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.clamp_max",
+        torch_opinfo_name="clamp_max",
+        skips=(
+            # error test disabled since an rhs non-tensor Python scalar is supported
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.clamp",
+        torch_opinfo_name="clamp",
+    ),
+    PythonRefInfo(
+        "_refs.nn.functional.triplet_margin_loss",
+        torch_opinfo_name="nn.functional.triplet_margin_loss",
+        supports_out=False,
+        # TODO: Uses minimum and clamp
+        skips=(
+            # AssertionError: Tensor-likes are not close!
+            # Greatest absolute difference: 6.103515625e-05 at index (4,) (up to 1e-05 allowed)
+            # Greatest relative difference: 8.519846983548175e-06 at index (4,) (up to 1.3e-06 allowed)
+            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref',
+                         dtypes=(torch.uint8,), device_type="cpu"),
+        )
+    ),
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.xlogy",
+        torch_opinfo_name="xlogy",
+        supports_one_python_scalar=True,
+    ),
+    #
+    # Elementwise Binary Special OpInfos
+    #
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.special.xlog1py",
+        torch_opinfo_name="special.xlog1py",
+        supports_one_python_scalar=True,
+    ),
+    #
+    # Data Conversion & Data Movement OpInfos
+    #
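+    # The _refs._conversions entries mirror the Tensor dtype-conversion methods
+    # (e.g. Tensor.bfloat16(), Tensor.int()); view consistency is not validated
+    # because the conversion may return self unchanged (issue 86558 below).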
+    ElementwiseUnaryPythonRefInfo(
+        "_refs._conversions.bfloat16",
+        torch_opinfo_name="bfloat16",
+        # TODO: If self already has the correct dtype and device, then self is
+        # returned ignoring memory_format.
+        # https://github.com/pytorch/pytorch/issues/86558
+        validate_view_consistency=False,
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs._conversions.bool",
+        torch_opinfo_name="bool",
+        # TODO: If self already has the correct dtype and device, then self is
+        # returned ignoring memory_format.
+        # https://github.com/pytorch/pytorch/issues/86558
+        validate_view_consistency=False,
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs._conversions.byte",
+        torch_opinfo_name="byte",
+        # TODO: If self already has the correct dtype and device, then self is
+        # returned ignoring memory_format.
+        # https://github.com/pytorch/pytorch/issues/86558
+        validate_view_consistency=False,
+        skips=(
+            DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'),
+        )
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs._conversions.char",
+        torch_opinfo_name="char",
+        # TODO: If self already has the correct dtype and device, then self is
+        # returned ignoring memory_format.
+        # https://github.com/pytorch/pytorch/issues/86558
+        validate_view_consistency=False,
+        skips=(
+            DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'),
+        )
+    ),
+    ElementwiseBinaryPythonRefInfo(
+        "_refs._conversions.complex",
+        torch_opinfo_name="complex",
+        error_inputs_func=partial(error_inputs_complex, is_ref=True),
+        skips=(
+            # Tests don't account for complex's type promotion semantics
+            DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'),
+            DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'),
+        )
+    ),
+    ElementwiseBinaryPythonRefInfo(
+        "_refs._conversions.polar",
+        torch_opinfo_name="polar",
+        skips=(
+            # Tests don't account for complex's type promotion semantics
+            DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'),
+            DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'),
+        )
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs._conversions.double",
+        torch_opinfo_name="double",
+        # TODO: If self already has the correct dtype and device, then self is
+        # returned ignoring memory_format.
+        # https://github.com/pytorch/pytorch/issues/86558
+        validate_view_consistency=False,
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs._conversions.float",
+        torch_opinfo_name="float",
+        # TODO: If self already has the correct dtype and device, then self is
+        # returned ignoring memory_format.
+        # https://github.com/pytorch/pytorch/issues/86558
+        validate_view_consistency=False,
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs._conversions.half",
+        torch_opinfo_name="half",
+        # TODO: If self already has the correct dtype and device, then self is
+        # returned ignoring memory_format.
+        # https://github.com/pytorch/pytorch/issues/86558
+        validate_view_consistency=False,
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs._conversions.int",
+        torch_opinfo_name="int",
+        # TODO: If self already has the correct dtype and device, then self is
+        # returned ignoring memory_format.
+        # https://github.com/pytorch/pytorch/issues/86558
+        validate_view_consistency=False,
+        skips=(
+            DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'),
+        )
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs._conversions.long",
+        torch_opinfo_name="long",
+        # TODO: If self already has the correct dtype and device, then self is
+        # returned ignoring memory_format.
+        # https://github.com/pytorch/pytorch/issues/86558
+        validate_view_consistency=False,
+        skips=(
+            DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'),
+        )
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs._conversions.short",
+        torch_opinfo_name="short",
+        # TODO: If self already has the correct dtype and device, then self is
+        # returned ignoring memory_format.
+        # https://github.com/pytorch/pytorch/issues/86558
+        validate_view_consistency=False,
+        skips=(
+            DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'),
+        )
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs._conversions.chalf",
+        torch_opinfo_name="chalf",
+        # TODO: If self already has the correct dtype and device, then self is
+        # returned ignoring memory_format.
+        # https://github.com/pytorch/pytorch/issues/86558
+        validate_view_consistency=False,
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs._conversions.cfloat",
+        torch_opinfo_name="cfloat",
+        # TODO: If self already has the correct dtype and device, then self is
+        # returned ignoring memory_format.
+        # https://github.com/pytorch/pytorch/issues/86558
+        validate_view_consistency=False,
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs._conversions.cdouble",
+        torch_opinfo_name="cdouble",
+        # TODO: If self already has the correct dtype and device, then self is
+        # returned ignoring memory_format.
+        # https://github.com/pytorch/pytorch/issues/86558
+        validate_view_consistency=False,
+    ),
+    PythonRefInfo(
+        "_refs.clone",
+        torch_opinfo_name="clone",
+    ),
+    #
+    # View & Shape OpInfos
+    #
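+    # Several entries below disable validate_view_consistency because the ref
+    # returns a view of an intermediate tensor or otherwise differs from the
+    # ATen op's view behavior.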
+    PythonRefInfo(
+        "_refs.alias_copy",
+        torch_opinfo_name="alias_copy",
+        supports_out=True,
+    ),
+    PythonRefInfo(
+        "_refs.atleast_1d",
+        torch_opinfo_name="atleast_1d",
+        validate_view_consistency=False,
+    ),
+    PythonRefInfo(
+        "_refs.atleast_2d",
+        torch_opinfo_name="atleast_2d",
+        validate_view_consistency=False,
+    ),
+    PythonRefInfo(
+        "_refs.atleast_3d",
+        torch_opinfo_name="atleast_3d",
+        validate_view_consistency=False,
+    ),
+    PythonRefInfo(
+        "_refs.as_strided",
+        torch_opinfo_name="as_strided",
+        # FIXME: doesn't support chalf
+        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+        skips=(
+            # cloned_mutable_input.is_same(returned_output) INTERNAL ASSERT FAILED
+            DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_view'),
+            DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_conj_view'),
+            DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_conj_view'),
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.as_strided_copy",
+        torch_opinfo_name="as_strided_copy",
+        supports_out=True,
+        # FIXME: doesn't support chalf
+        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+        skips=(
+            # cloned_mutable_input.is_same(returned_output) INTERNAL ASSERT FAILED
+            DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_view'),
+            DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_conj_view'),
+            DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_conj_view'),
+            # The view function this decomposes into does not have a ref
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_python_ref"),
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.as_strided",
+        torch_opinfo_name="as_strided",
+        torch_opinfo_variant_name="partial_views",
+        # FIXME: doesn't support chalf
+        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+        skips=(
+            # cloned_mutable_input.is_same(returned_output) INTERNAL ASSERT FAILED
+            DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_view'),
+            DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_conj_view'),
+            DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_conj_view'),
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu'),
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.as_strided_scatter",
+        torch_opinfo_name="as_strided_scatter",
+        # returns a view of an intermediate tensor (as_strided)
+        validate_view_consistency=False,
+    ),
+    PythonRefInfo(
+        "_refs.block_diag",
+        torch_opinfo_name="block_diag",
+    ),
+    PythonRefInfo(
+        "_refs.broadcast_shapes",
+        torch_opinfo_name="broadcast_shapes",
+    ),
+    PythonRefInfo(
+        "_refs.broadcast_tensors",
+        torch_opinfo_name="broadcast_tensors",
+    ),
+    PythonRefInfo(
+        "_refs.broadcast_to",
+        torch_opinfo_name="broadcast_to",
+    ),
+    PythonRefInfo(
+        "_refs.cat",
+        torch_opinfo_name="cat",
+        skips=(
+            # FIXME: AssertionError: RuntimeError not raised
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.chunk",
+        torch_opinfo_name="chunk",
+    ),
+    PythonRefInfo(
+        "_refs.column_stack",
+        torch_opinfo_name="column_stack",
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.conj",
+        torch_opinfo_name="conj",
+    ),
+    PythonRefInfo(
+        "_refs.constant_pad_nd",
+        torch_opinfo_name="constant_pad_nd",
+    ),
+    PythonRefInfo(
+        "_refs.contiguous",
+        torch_opinfo_name="contiguous",
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.deg2rad",
+        torch_opinfo_name="deg2rad",
+        decorators=(precisionOverride({torch.bfloat16: 7e-1,
+                                       torch.float16: 7e-1}),),
+    ),
+    PythonRefInfo(
+        "_refs.dsplit",
+        torch_opinfo_name="dsplit",
+    ),
+    PythonRefInfo(
+        "_refs.diag",
+        torch_opinfo_name="diag",
+    ),
+    PythonRefInfo(
+        "_refs.diagonal",
+        torch_opinfo_name="diagonal",
+    ),
+    PythonRefInfo(
+        "_refs.diagonal_copy",
+        torch_opinfo_name="diagonal_copy",
+        supports_out=True,
+    ),
+    PythonRefInfo(
+        "_refs.diagonal_scatter",
+        torch_opinfo_name="diagonal_scatter",
+        supports_out=True,
+        # returns a view of an intermediate tensor (as_strided)
+        validate_view_consistency=False,
+    ),
+    PythonRefInfo(
+        "_refs.diag_embed",
+        torch_opinfo_name="diag_embed",
+        supports_out=True,
+    ),
+    PythonRefInfo(
+        "_refs.dstack",
+        torch_opinfo_name="dstack",
+        skips=(
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.expand",
+        torch_opinfo_name="expand",
+    ),
+    PythonRefInfo(
+        "_refs.expand_as",
+        torch_opinfo_name="expand_as",
+    ),
+    PythonRefInfo(
+        "_refs.expand_copy",
+        torch_opinfo_name="expand_copy",
+        supports_out=True,
+    ),
+    PythonRefInfo(
+        "_refs.flatten",
+        torch_opinfo_name="flatten",
+    ),
+    PythonRefInfo(
+        "_refs.flip",
+        torch_opinfo_name="flip",
+    ),
+    PythonRefInfo(
+        "_refs.fliplr",
+        torch_opinfo_name="fliplr",
+    ),
+    PythonRefInfo(
+        "_refs.flipud",
+        torch_opinfo_name="flipud",
+    ),
+    PythonRefInfo(
+        "_refs.hstack",
+        torch_opinfo_name="hstack",
+        skips=(
+            # https://github.com/pytorch/pytorch/issues/78613
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.narrow",
+        torch_opinfo_name="narrow",
+        error_inputs_func=partial(error_inputs_narrow_narrow_copy, is_narrow=True, is_ref=True),
+    ),
+    PythonRefInfo(
+        "_refs.narrow_copy",
+        torch_opinfo_name="narrow_copy",
+        supports_out=True,
+        error_inputs_func=partial(error_inputs_narrow_narrow_copy, is_narrow=False, is_ref=True),
+        skips=(
+            # The view function this decomposes into does not have a ref
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_python_ref"),
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.nn.functional.group_norm",
+        torch_opinfo_name="nn.functional.group_norm",
+        validate_view_consistency=False,
+    ),
+    PythonRefInfo(
+        "_refs.native_layer_norm",
+        torch_opinfo_name="native_layer_norm",
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), "TestCommon", "test_python_ref",
+                         device_type="cpu", dtypes=(torch.float32,)),
+            DecorateInfo(unittest.skip("Skipped!"), "TestCommon", "test_python_ref_torch_fallback",
+                         device_type="cpu", dtypes=(torch.float32,)),
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.permute",
+        torch_opinfo_name="permute",
+    ),
+    PythonRefInfo(
+        "_refs.permute_copy",
+        torch_opinfo_name="permute_copy",
+        supports_out=True,
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.rad2deg",
+        torch_opinfo_name="rad2deg",
+        decorators=(precisionOverride({torch.bfloat16: 7e-1,
+                                       torch.float16: 7e-1}),),
+    ),
+    PythonRefInfo(
+        "_refs.ravel",
+        torch_opinfo_name="ravel",
+    ),
+    PythonRefInfo(
+        "_refs.renorm",
+        torch_opinfo_name="renorm",
+    ),
+    PythonRefInfo(
+        "_refs.repeat",
+        torch_opinfo_name="repeat",
+        validate_view_consistency=False,
+    ),
+    PythonRefInfo(
+        "_refs.reshape",
+        torch_opinfo_name="reshape",
+    ),
+    PythonRefInfo(
+        "_refs.reshape_as",
+        torch_opinfo_name="reshape_as",
+    ),
+    PythonRefInfo(
+        "_refs.roll",
+        torch_opinfo_name="roll",
+        validate_view_consistency=False,
+    ),
+    PythonRefInfo(
+        "_refs.rot90",
+        torch_opinfo_name="rot90",
+        validate_view_consistency=False,
+    ),
+    PythonRefInfo(
+        "_refs.select_scatter",
+        torch_opinfo_name="select_scatter",
+    ),
+    PythonRefInfo(
+        "_refs.stack",
+        torch_opinfo_name="stack",
+        validate_view_consistency=False,
+    ),
+    PythonRefInfo(
+        "_refs.squeeze",
+        torch_opinfo_name="squeeze",
+    ),
+    PythonRefInfo(
+        "_refs.squeeze_copy",
+        torch_opinfo_name="squeeze_copy",
+        supports_out=True,
+    ),
+    PythonRefInfo(
+        "_refs.squeeze",
+        torch_opinfo_name="squeeze",
+        torch_opinfo_variant_name="multiple",
+    ),
+    PythonRefInfo(
+        "_refs.tensor_split",
+        torch_opinfo_name="tensor_split",
+        skips=(
+            # RuntimeError: no _refs support for torch.Tensor.tolist
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.hsplit",
+        torch_opinfo_name="hsplit",
+    ),
+    PythonRefInfo(
+        "_refs.vsplit",
+        torch_opinfo_name="vsplit",
+    ),
+    PythonRefInfo(
+        "_refs.dot",
+        torch_opinfo_name="dot",
+        error_inputs_func=partial(error_inputs_dot_vdot, is_ref=True),
+        # .conj() does not set ._is_view() correctly in ATen
+        validate_view_consistency=False,
+        skips=(
+            # RuntimeError: no _refs support for torch.Tensor.is_conj
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', dtypes=[torch.complex64, torch.complex128]),
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.vdot",
+        torch_opinfo_name="vdot",
+        error_inputs_func=partial(error_inputs_dot_vdot, is_ref=True),
+        # .conj() does not set ._is_view() correctly in ATen
+        validate_view_consistency=False,
+        skips=(
+            # RuntimeError: no _refs support for torch.Tensor.is_conj
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', dtypes=[torch.complex64, torch.complex128]),
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.transpose",
+        torch_opinfo_name="transpose",
+    ),
+    PythonRefInfo(
+        "_refs.transpose_copy",
+        torch_opinfo_name="transpose_copy",
+        supports_out=True,
+    ),
+    PythonRefInfo(
+        "_refs.t",
+        torch_opinfo_name="t",
+    ),
+    PythonRefInfo(
+        "_refs.t_copy",
+        torch_opinfo_name="t_copy",
+        supports_out=True,
+    ),
+    PythonRefInfo(
+        "_refs.T",
+        torch_opinfo_name="T",
+        error_inputs_func=partial(error_inputs_T, has_ndims_error=True),
+    ),
+    PythonRefInfo(
+        "_refs.unbind_copy",
+        torch_opinfo_name="unbind_copy",
+    ),
+    PythonRefInfo(
+        "_refs.unfold",
+        torch_opinfo_name="unfold",
+    ),
+    PythonRefInfo(
+        "_refs.unfold_copy",
+        torch_opinfo_name="unfold_copy",
+        supports_out=True,
+    ),
+    PythonRefInfo(
+        "_refs.unsqueeze",
+        torch_opinfo_name="unsqueeze",
+    ),
+    PythonRefInfo(
+        "_refs.unsqueeze_copy",
+        torch_opinfo_name="unsqueeze_copy",
+        supports_out=True,
+    ),
+    PythonRefInfo(
+        "_refs.view",
+        torch_opinfo_name="view",
+    ),
+    PythonRefInfo(
+        "_refs.view_as",
+        torch_opinfo_name="view_as",
+    ),
+    PythonRefInfo(
+        "_refs.view_copy",
+        torch_opinfo_name="view_copy",
+        supports_out=True,
+    ),
+    PythonRefInfo(
+        "_refs.vstack",
+        torch_opinfo_name="vstack",
+        skips=(
+            # https://github.com/pytorch/pytorch/issues/78613
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.unflatten",
+        torch_opinfo_name="unflatten",
+    ),
+    PythonRefInfo(
+        "_refs.unbind",
+        torch_opinfo_name="unbind",
+    ),
+    #
+    # Reduction Reference OpInfos
+    #
+    ReductionPythonRefInfo(
+        "_refs.all",
+        torch_opinfo_name="all",
+        skips=(
+            # FIXME: uint8 input returns uint8 instead of bool
+            DecorateInfo(
+                unittest.expectedFailure, 'TestReductions', 'test_result_dtype',
+                dtypes=[torch.uint8]),
+        ),
+    ),
+    ReductionPythonRefInfo(
+        "_refs.amax",
+        torch_opinfo_name="amax",
+        error_inputs_func=partial(error_inputs_aminmax_amax_amin, is_ref=True),
+        skips=(
+            # FIXME: reduces all dimensions when dim=[]
+            DecorateInfo(
+                unittest.expectedFailure, 'TestReductions', 'test_dim_empty'),
+            DecorateInfo(
+                unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'),
+        ),
+    ),
+    ReductionPythonRefInfo(
+        "_refs.amin",
+        torch_opinfo_name="amin",
+        error_inputs_func=partial(error_inputs_aminmax_amax_amin, is_ref=True),
+        skips=(
+            # FIXME: reduces all dimensions when dim=[]
+            DecorateInfo(
+                unittest.expectedFailure, 'TestReductions', 'test_dim_empty'),
+            DecorateInfo(
+                unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'),
+        ),
+    ),
+    ReductionPythonRefInfo(
+        "_refs.any",
+        torch_opinfo_name="any",
+        skips=(
+            # FIXME: uint8 input returns uint8 instead of bool
+            DecorateInfo(
+                unittest.expectedFailure, 'TestReductions', 'test_result_dtype',
+                dtypes=[torch.uint8]),
+        ),
+    ),
+    ReductionPythonRefInfo(
+        "_refs.count_nonzero",
+        torch_opinfo_name="count_nonzero",
+        skips=(
+            # FIXME: count_nonzero does not accept keepdim kwarg
+            DecorateInfo(
+                unittest.skip("Skipped!"), 'TestReductions',
+                'test_dim_default_keepdim'),
+            DecorateInfo(
+                unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none_keepdim'),
+            DecorateInfo(
+                unittest.skip("Skipped!"), 'TestReductions', 'test_dim_single_keepdim'),
+            DecorateInfo(
+                unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
+            DecorateInfo(
+                unittest.skip("Skipped!"), 'TestReductions', 'test_dim_multi_keepdim'),
+            DecorateInfo(
+                unittest.skip("Skipped!"), 'TestReductions',
+                'test_dim_multi_unsorted_keepdim'),
+            # FIXME: dim=[] reduces all dimensions
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
+        ),
+    ),
+    ReductionPythonRefInfo(
+        "_refs.mean",
+        torch_opinfo_name="mean",
+        supports_out=True,
+        error_inputs_func=partial(error_inputs_mean, is_ref=True),
+        skips=(
+            # FIXME: reduces all dimensions when dim=[]
+            DecorateInfo(
+                unittest.expectedFailure, 'TestReductions', 'test_dim_empty'),
+            DecorateInfo(
+                unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'),
+        ),
+    ),
+    ReductionPythonRefInfo(
+        "_refs.std",
+        torch_opinfo_name="std",
+        supports_out=True,
+        skips=(
+            # FIXME: reduces all dimensions when dim=[]
+            DecorateInfo(
+                unittest.expectedFailure, 'TestReductions', 'test_dim_empty'),
+            DecorateInfo(
+                unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'),
+            # FIXME: improve precision
+            DecorateInfo(
+                unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input',
+                dtypes=(torch.float16,)),
+            DecorateInfo(
+                unittest.skip("Skipped!"), 'TestReductions',
+                'test_ref_duplicate_values',
+                dtypes=(torch.float16,)),
+        ),
+    ),
+    # std_mean and var_mean are not ReductionInfos
+    PythonRefInfo(
+        "_refs.std_mean",
+        torch_opinfo_name="std_mean",
+    ),
+    ReductionPythonRefInfo(
+        "_refs.sum",
+        torch_opinfo_name="sum",
+        supports_out=True,
+        skips=(
+            # FIXME: doesn't test out behavior properly for this operator
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
+            # FIXME: sum reduces all dimensions when dim=[]
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
+            DecorateInfo(
+                unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
+            # FIXME: improve precision
+            DecorateInfo(
+                unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input',
+                dtypes=[torch.float16]),
+            DecorateInfo(
+                unittest.skip("Skipped!"), 'TestReductions',
+                'test_ref_duplicate_values',
+                dtypes=[torch.float16]),
+            DecorateInfo(
+                unittest.skip("Skipped!"), 'TestOperators', 'test_reduction_all',
+                dtypes=[torch.float32]),
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.cumsum",
+        torch_opinfo_name="cumsum",
+        supports_out=True,
+        skips=(
+            # doesn't test out behavior properly for this operator
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.cumprod",
+        torch_opinfo_name="cumprod",
+        supports_out=True,
+        skips=(
+            # doesn't test out behavior properly for this operator
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.sum_to_size",
+        torch_opinfo_name="sum_to_size",
+        validate_view_consistency=False,
+    ),
+    ReductionPythonRefInfo(
+        "_refs.prod",
+        torch_opinfo_name="prod",
+        supports_out=True,
+        supports_multiple_dims=True,
+        skips=(
+            # FIXME: doesn't test out behavior properly for this operator
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
+            # FIXME: reduces all dimensions when dim=[]
+            DecorateInfo(
+                unittest.expectedFailure, 'TestReductions', 'test_dim_empty'),
+            DecorateInfo(
+                unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'),
+            # FIXME: improve precision
+            DecorateInfo(
+                unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input',
+                dtypes=[torch.float16, torch.complex64]),
+        ),
+    ),
+    ReductionPythonRefInfo(
+        "_refs.var",
+        torch_opinfo_name="var",
+        supports_out=True,
+        skips=(
+            # FIXME: reduces all dimensions when dim=[]
+            DecorateInfo(
+                unittest.expectedFailure, 'TestReductions', 'test_dim_empty'),
+            DecorateInfo(
+                unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'),
+            # FIXME: improve precision
+            DecorateInfo(
+                unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input'),
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.var_mean",
+        torch_opinfo_name="var_mean",
+        validate_view_consistency=False,
+    ),
+    #
+    # Linear Algebra Operators
+    #
+    PythonRefInfo(
+        "_refs.addr",
+        torch_opinfo_name="addr",
+        decorators=(
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',),
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.trace",
+        torch_opinfo_name="trace",
+    ),
+    PythonRefInfo(
+        "_refs.norm",
+        torch_opinfo_name="norm",
+        supports_out=True,
+        # Uses vector_norm inside and vector_norm is affected by
+        # https://github.com/pytorch/pytorch/issues/77216
+        validate_view_consistency=False,
+    ),
+    #
+    # Tensor Creation Reference OpInfos
+    #
+    PythonRefInfo(
+        "_refs.empty",
+        torch_opinfo_name="empty",
+        skips=(
+            DecorateInfo(unittest.skip("Expected: empty is not comparable"),
+                         'TestCommon',
+                         'test_python_ref'),
+            DecorateInfo(unittest.skip("Expected: empty is not comparable"),
+                         'TestCommon',
+                         'test_python_ref_torch_fallback'),
+            DecorateInfo(unittest.skip("Expected: empty is not comparable"),
+                         'TestCommon',
+                         'test_out'),
+            DecorateInfo(unittest.skip("Expected: empty is not comparable"),
+                         'TestCommon',
+                         'test_out_warning'),
+            DecorateInfo(unittest.skip("Expected: empty is not comparable"),
+                         'TestMathBits',
+                         'test_conj_view'),
+            DecorateInfo(unittest.skip("Expected: empty is not comparable"),
+                         'TestMathBits',
+                         'test_neg_conj_view'),
+            DecorateInfo(unittest.skip("Expected: empty is not comparable"),
+                         'TestMathBits',
+                         'test_neg_view'),
+            # FIXME: shouldn't check empty results
+            DecorateInfo(unittest.skip("Can't check result for empty"), 'TestCommon', 'test_python_ref_executor'),
+            DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.empty_like",
+        torch_opinfo_name="empty_like",
+        skips=(
+            DecorateInfo(unittest.skip("Expected: empty is not comparable"),
+                         'TestCommon',
+                         'test_python_ref'),
+            DecorateInfo(unittest.skip("Expected: empty is not comparable"),
+                         'TestCommon',
+                         'test_python_ref_torch_fallback'),
+            DecorateInfo(unittest.skip("Expected: empty is not comparable"),
+                         'TestCommon',
+                         'test_out'),
+            DecorateInfo(unittest.skip("Expected: empty is not comparable"),
+                         'TestCommon',
+                         'test_out_warning'),
+            DecorateInfo(unittest.skip("Expected: empty is not comparable"),
+                         'TestMathBits',
+                         'test_conj_view'),
+            DecorateInfo(unittest.skip("Expected: empty is not comparable"),
+                         'TestMathBits',
+                         'test_neg_conj_view'),
+            DecorateInfo(unittest.skip("Expected: empty is not comparable"),
+                         'TestMathBits',
+                         'test_neg_view'),
+            # FIXME: should not compare results of empty_like
+            DecorateInfo(unittest.skip("Can't check result for empty_like"), 'TestCommon', 'test_python_ref_executor'),
+            DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.randn",
+        torch_opinfo_name="randn",
+        op=lambda *args, **kwargs: wrapper_set_seed(refs.randn, *args, **kwargs),
+        skips=(
+            # see https://github.com/pytorch/pytorch/issues/85121
+            DecorateInfo(unittest.skip("make_traced() doesn't set seed properly!"),
+                         'TestCommon',
+                         'test_python_ref_executor'),
+            # These tests expect the input to be a tensor or a sequence of tensors
+            DecorateInfo(unittest.skip("Test expects tensor input"), "TestCommon", "test_noncontiguous_samples"),
+            DecorateInfo(unittest.skip("Test expects tensor input"), 'TestMathBits', 'test_neg_view'),
+            DecorateInfo(unittest.skip("Test expects tensor input"), 'TestMathBits', 'test_conj_view'),
+            DecorateInfo(unittest.skip("Test expects tensor input"), 'TestMathBits', 'test_neg_conj_view'),
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.eye",
+        torch_opinfo_name="eye",
+        skips=(
+            # Skip these tests since the input is not a tensor
+            DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.new_empty",
+        torch_opinfo_name="new_empty",
+        skips=(
+            DecorateInfo(unittest.skip("Expected: empty is not comparable"),
+                         'TestCommon',
+                         'test_python_ref'),
+            DecorateInfo(unittest.skip("Expected: empty is not comparable"),
+                         'TestCommon',
+                         'test_python_ref_torch_fallback'),
+            DecorateInfo(unittest.skip("Expected: empty is not comparable"),
+                         'TestCommon',
+                         'test_out'),
+            DecorateInfo(unittest.skip("Expected: empty is not comparable"),
+                         'TestCommon',
+                         'test_out_warning'),
+            DecorateInfo(unittest.skip("Expected: empty is not comparable"),
+                         'TestMathBits',
+                         'test_conj_view'),
+            DecorateInfo(unittest.skip("Expected: empty is not comparable"),
+                         'TestMathBits',
+                         'test_neg_conj_view'),
+            DecorateInfo(unittest.skip("Expected: empty is not comparable"),
+                         'TestMathBits',
+                         'test_neg_view'),
+            # FIXME: should not compare results of new_empty
+            DecorateInfo(unittest.skip("Can't check result for new_empty"), 'TestCommon', 'test_python_ref_executor'),
+            DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.new_empty_strided",
+        torch_opinfo_name="new_empty_strided",
+        skips=(
+            DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"),
+                         'TestCommon',
+                         'test_python_ref'),
+            DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"),
+                         'TestCommon',
+                         'test_python_ref_torch_fallback'),
+            DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"),
+                         'TestMathBits',
+                         'test_conj_view'),
+            DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"),
+                         'TestMathBits',
+                         'test_neg_conj_view'),
+            DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"),
+                         'TestMathBits',
+                         'test_neg_view'),
+            DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"),
+                         'TestCommon',
+                         'test_python_ref_executor'),
+            DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
+
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.empty_strided",
+        torch_opinfo_name="empty_strided",
+        skips=(
+            DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"),
+                         'TestCommon',
+                         'test_python_ref'),
+            DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"),
+                         'TestCommon',
+                         'test_python_ref_torch_fallback'),
+            DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"),
+                         'TestMathBits',
+                         'test_conj_view'),
+            DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"),
+                         'TestMathBits',
+                         'test_neg_conj_view'),
+            DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"),
+                         'TestMathBits',
+                         'test_neg_view'),
+            DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"),
+                         'TestCommon',
+                         'test_python_ref_executor'),
+            DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.new_full",
+        torch_opinfo_name="new_full",
+    ),
+    PythonRefInfo(
+        "_refs.new_ones",
+        torch_opinfo_name="new_ones",
+    ),
+    PythonRefInfo(
+        "_refs.new_zeros",
+        torch_opinfo_name="new_zeros",
+    ),
+    #
+    # Conditional Reference OpInfos
+    #
+    PythonRefInfo(
+        "_refs.masked_fill",
+        torch_opinfo_name="masked_fill",
+        skips=(
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.where",
+        torch_opinfo_name="where",
+        op=lambda self, condition, other: refs.where(condition, self, other),
+        supports_out=False,
+        skips=(
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors', device_type='cuda'),
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.index_select",
+        torch_opinfo_name="index_select",
+        # empty_strided
+        skips=(
+            # no _refs support for Tensor.__setitem__
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),
+            # The sample out= tensor has a stride of zero; the _out operation checks
+            # that the input has no internal overlap
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),)
+    ),
+    PythonRefInfo(
+        "_refs.index_copy",
+        torch_opinfo_name="index_copy",
+        # empty_strided
+        skips=(
+            # no _refs support for Tensor.__setitem__
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.index_add",
+        torch_opinfo_name="index_add",
+        # empty_strided
+        skips=(
+            # no _refs support for Tensor.__setitem__
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.index_fill",
+        torch_opinfo_name="index_fill",
+        # empty_strided
+        skips=(
+            # no _refs support for Tensor.__setitem__
+            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),)
+    ),
+    #
+    # Test-related functions
+    #
+    PythonRefInfo(
+        "_refs.allclose",
+        torch_opinfo_name="allclose",
+    ),
+    #
+    # Misc functions
+    #
+    PythonRefInfo(
+        "_refs.stft",
+        torch_opinfo_name="stft",
+        skips=[
+            # RuntimeError: no _refs support for aten.pad
+            DecorateInfo(
+                unittest.expectedFailure, 'TestCommon', 'test_python_ref'
+            ),
+        ],
+    ),
+    PythonRefInfo(
+        "_refs.istft",
+        torch_opinfo_name="istft",
+        skips=[
+            # RuntimeError: no _refs support for aten.unfold_backward
+            DecorateInfo(
+                unittest.expectedFailure, 'TestCommon', 'test_python_ref'
+            ),
+            DecorateInfo(
+                unittest.skip("Expected: unfold_backward() got an unexpected keyword argument 'input_sizes'"),
+                'TestCommon',
+                'test_python_ref_executor',
+                dtypes=(torch.complex64, torch.complex128),
+            ),
+        ],
+    ),
+    PythonRefInfo(
+        "_refs.view_as_complex",
+        torch_opinfo_name="view_as_complex",
+    ),
+    PythonRefInfo(
+        "_refs.split_with_sizes",
+        torch_opinfo_name="split_with_sizes",
+    ),
+]
+python_ref_db += opinfo.definitions.python_ref_db
+
+# Common operator groupings
+ops_and_refs = op_db + python_ref_db
+unary_ufuncs = [op for op in ops_and_refs if isinstance(op, UnaryUfuncInfo)]
+binary_ufuncs = [op for op in ops_and_refs if isinstance(op, BinaryUfuncInfo)]
+binary_ufuncs_and_refs = tuple(op for op in ops_and_refs if isinstance(op, BinaryUfuncInfo))
+spectral_funcs = [op for op in ops_and_refs if isinstance(op, SpectralFuncInfo)]
+sparse_unary_ufuncs = [op for op in op_db if isinstance(op, UnaryUfuncInfo) and op.supports_sparse]
+sparse_csr_unary_ufuncs = [op for op in op_db if isinstance(op, UnaryUfuncInfo) and op.supports_sparse_csr]
+sparse_reduction_ops = [op for op in op_db if isinstance(op, ReductionOpInfo) and op.supports_sparse]
+shape_funcs = [op for op in ops_and_refs if isinstance(op, ShapeFuncInfo)]
+reduction_ops = [op for op in ops_and_refs if isinstance(op, ReductionOpInfo)]
+reference_filtered_ops = [op for op in reduction_ops if op.ref is not None]
+reference_masked_ops = [op for op in reference_filtered_ops if op.name.startswith('masked.')]
+sparse_masked_reduction_ops = [op for op in sparse_reduction_ops if op.name.startswith('masked.')]
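+
+# Illustrative sketch (hypothetical helper, not part of the original file): the
+# groupings above are plain lists of OpInfo objects, so tests usually filter
+# them further before parametrizing. For example, one could keep only the unary
+# ufuncs that advertise float16 support:
+def _example_float16_unary_ufuncs():
+    # OpInfos (ops and refs) whose declared dtypes include float16.
+    return [op for op in unary_ufuncs if torch.float16 in op.dtypes]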
+
+def index_variable(shape, max_indices, device=torch.device('cpu')):
+    if not isinstance(shape, tuple):
+        shape = (shape,)
+    return torch.testing.make_tensor(*shape, dtype=torch.long, device=device, low=0, high=max_indices)
+
+def gather_variable(shape, index_dim, max_indices, duplicate=False, device=torch.device('cpu')):
+    assert len(shape) == 2
+    assert index_dim < 2
+    batch_dim = 1 - index_dim
+    index = torch.zeros(*shape, dtype=torch.long, device=device)
+    for i in range(shape[index_dim]):
+        index.select(index_dim, i).copy_(
+            torch.randperm(max_indices, device=device)[:shape[batch_dim]])
+    if duplicate:
+        index.select(batch_dim, 0).copy_(index.select(batch_dim, 1))
+    return index
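+
+# Illustrative sketch (hypothetical helper, not part of the original file): the
+# index built by gather_variable has the requested shape and holds values in
+# [0, max_indices), so it can be passed directly to torch.gather along index_dim.
+def _example_gather_usage():
+    src = torch.randn(4, 5)
+    idx = gather_variable((4, 5), index_dim=1, max_indices=5)
+    return torch.gather(src, 1, idx)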
+
+def bernoulli_scalar():
+    return torch.tensor(0, dtype=torch.bool).bernoulli_()
+
+def mask_not_all_zeros(shape):
+    assert len(shape) > 0
+    while True:
+        result = torch.randn(shape).gt(0)
+        if result.sum() > 0:
+            return result
+
+# Copied from functorch
+def xfail(op_name, variant_name='', *, device_type=None, dtypes=None):
+    return (op_name, variant_name, device_type, dtypes, True)
+
+
+def skip(op_name, variant_name='', *, device_type=None, dtypes=None):
+    return (op_name, variant_name, device_type, dtypes, False)
+
+
+def skipOps(test_case_name, base_test_name, to_skip):
+    all_opinfos = op_db
+    for xfail in to_skip:
+        op_name, variant_name, device_type, dtypes, expected_failure = xfail
+        matching_opinfos = [o for o in all_opinfos
+                            if o.name == op_name and o.variant_test_name == variant_name]
+        assert len(matching_opinfos) >= 1, f"Couldn't find OpInfo for {xfail}"
+        for op in matching_opinfos:
+            decorators = list(op.decorators)
+            if expected_failure:
+                decorator = DecorateInfo(unittest.expectedFailure,
+                                         test_case_name, base_test_name,
+                                         device_type=device_type, dtypes=dtypes)
+                decorators.append(decorator)
+            else:
+                decorator = DecorateInfo(unittest.skip("Skipped!"),
+                                         test_case_name, base_test_name,
+                                         device_type=device_type, dtypes=dtypes)
+                decorators.append(decorator)
+            op.decorators = tuple(decorators)
+
+    # This decorator doesn't modify fn in any way
+    def wrapped(fn):
+        return fn
+    return wrapped
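+
+
+# Illustrative sketch (hypothetical usage, not part of the original file):
+# skipOps mutates the decorators of the matching OpInfos in op_db and returns a
+# pass-through decorator, so test suites apply it directly on a test method,
+# e.g. @skipOps('TestSuite', 'test_method', {xfail('abs')}).
+def _example_skipops_usage():
+    # Marks the 'abs' OpInfo as an expected failure for a made-up test name.
+    return skipOps('TestExample', 'test_example', {xfail('abs')})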
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_mkldnn.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_mkldnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9a05cf807ae8f1634ac80b20348095d55c76a52
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_mkldnn.py
@@ -0,0 +1,77 @@
+# mypy: ignore-errors
+
+import contextlib
+import functools
+import inspect
+
+import torch
+
+
+# Test whether the hardware BF32 math mode is enabled. It is enabled only when:
+# - MKLDNN is available
+# - BF16 is supported by MKLDNN
+def bf32_is_not_fp32():
+    if not torch.backends.mkldnn.is_available():
+        return False
+    if not torch.ops.mkldnn._is_mkldnn_bf16_supported():
+        return False
+    return True
+
+
+@contextlib.contextmanager
+def bf32_off():
+    old_matmul_precision = torch.get_float32_matmul_precision()
+    try:
+        torch.set_float32_matmul_precision("highest")
+        yield
+    finally:
+        torch.set_float32_matmul_precision(old_matmul_precision)
+
+
+@contextlib.contextmanager
+def bf32_on(self, bf32_precision=1e-5):
+    old_matmul_precision = torch.get_float32_matmul_precision()
+    old_precision = self.precision
+    try:
+        torch.set_float32_matmul_precision("medium")
+        self.precision = bf32_precision
+        yield
+    finally:
+        torch.set_float32_matmul_precision(old_matmul_precision)
+        self.precision = old_precision
+
+
+# This wrapper runs a test twice: once with BF32 disabled and once with BF32
+# enabled. When BF32 is enabled, the test uses the reduced precision specified
+# by the bf32_precision argument.
+def bf32_on_and_off(bf32_precision=1e-5):
+    def with_bf32_disabled(self, function_call):
+        with bf32_off():
+            function_call()
+
+    def with_bf32_enabled(self, function_call):
+        with bf32_on(self, bf32_precision):
+            function_call()
+
+    def wrapper(f):
+        params = inspect.signature(f).parameters
+        arg_names = tuple(params.keys())
+
+        @functools.wraps(f)
+        def wrapped(*args, **kwargs):
+            kwargs.update(zip(arg_names, args))
+            cond = bf32_is_not_fp32()
+            if "device" in kwargs:
+                cond = cond and (torch.device(kwargs["device"]).type == "cpu")
+            if "dtype" in kwargs:
+                cond = cond and (kwargs["dtype"] == torch.float)
+            if cond:
+                with_bf32_disabled(kwargs["self"], lambda: f(**kwargs))
+                with_bf32_enabled(kwargs["self"], lambda: f(**kwargs))
+            else:
+                f(**kwargs)
+
+        return wrapped
+
+    return wrapper
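+
+
+# Illustrative sketch (hypothetical test body, not part of the original file):
+# bf32_on_and_off binds positional args to parameter names, so the wrapped test
+# must accept `self` (a TestCase) and, typically, `device` and `dtype`. On a CPU
+# device with dtype torch.float and MKLDNN BF16 support, the body runs twice:
+# once with float32 matmul precision "highest" and once with "medium".
+@bf32_on_and_off(1e-2)
+def _example_bf32_matmul_test(self, device, dtype):
+    a = torch.randn(8, 8, device=device, dtype=dtype)
+    b = torch.randn(8, 8, device=device, dtype=dtype)
+    self.assertEqual(a @ b, torch.mm(a, b))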
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_modules.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..d713f7b5535d5f4de365027dc00e3343cbbafdf9
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_modules.py
@@ -0,0 +1,4420 @@
+# mypy: ignore-errors
+
+import torch
+import unittest
+from copy import deepcopy
+from enum import Enum
+from functools import wraps, partial
+from itertools import chain, product
+import itertools
+import math
+import torch.nn.functional as F
+from torch.nn.utils.rnn import pack_padded_sequence
+from torch.testing import make_tensor
+from torch.testing._internal.common_cuda import TEST_CUDNN
+from torch.testing._internal.common_dtype import (
+    floating_types, floating_and_complex_types_and, get_all_fp_dtypes)
+from torch.testing._internal.common_device_type import (
+    _TestParametrizer, _update_param_kwargs, expectedFailureMPS, toleranceOverride, tol,
+    skipCUDAIfRocm, precisionOverride, skipMeta, skipMPS)
+from torch.testing._internal.common_methods_invocations import DecorateInfo
+from torch.testing._internal.common_nn import (
+    cosineembeddingloss_reference, cross_entropy_loss_reference, ctcloss_reference,
+    hingeembeddingloss_reference, huberloss_reference, kldivloss_reference,
+    marginrankingloss_reference, multimarginloss_reference, multilabelmarginloss_reference,
+    nllloss_reference, nlllossNd_reference, smoothl1loss_reference, softmarginloss_reference, get_reduction)
+from torch.testing._internal.common_utils import (
+    freeze_rng_state, skipIfMPS, skipIfMPSOnMacOS13, GRADCHECK_NONDET_TOL, TEST_WITH_ROCM, IS_WINDOWS,
+    skipIfTorchDynamo)
+from types import ModuleType
+import operator
+
+# List of all namespaces containing modules to test.
+MODULE_NAMESPACES: list[ModuleType] = [
+    torch.nn.modules,
+    torch.ao.nn.qat.modules,
+    torch.ao.nn.quantizable.modules,
+    torch.ao.nn.quantized.modules,
+]
+
+# Modules that shouldn't be tested for one reason or another.
+MODULES_TO_SKIP: set[type] = {
+    torch.nn.Module,  # abstract base class
+    torch.nn.Container,  # deprecated
+    torch.nn.NLLLoss2d,  # deprecated
+    torch.ao.nn.quantized.MaxPool2d,  # aliases to nn.MaxPool2d
+}
+
+# List of all module classes to test.
+MODULE_CLASSES: list[type] = [*chain.from_iterable([
+    [getattr(namespace, module_name) for module_name in namespace.__all__]  # type: ignore[attr-defined]
+    for namespace in MODULE_NAMESPACES])]
+MODULE_CLASSES = [cls for cls in MODULE_CLASSES if cls not in MODULES_TO_SKIP]
+
+# Dict of module class -> common name. Useful for making test names more intuitive.
+# Example: torch.nn.modules.linear.Linear -> "nn.Linear"
+MODULE_CLASS_NAMES: dict[type, str] = {}
+for namespace in MODULE_NAMESPACES:
+    for module_name in namespace.__all__:  # type: ignore[attr-defined]
+        module_cls = getattr(namespace, module_name)
+        namespace_name = namespace.__name__.replace('torch.', '').replace('.modules', '')
+
+        # Deal with any aliases by preferring earlier names.
+        if module_cls not in MODULE_CLASS_NAMES:
+            MODULE_CLASS_NAMES[module_cls] = f'{namespace_name}.{module_name}'
+
+
+# Specifies the modes (i.e. train, eval) to test over.
+TrainEvalMode = Enum('TrainEvalMode', ('train_only', 'eval_only', 'train_and_eval'))
+
+
+class modules(_TestParametrizer):
+    """ PROTOTYPE: Decorator for specifying a list of modules over which to run a test. """
+
+    def __init__(self, module_info_iterable, allowed_dtypes=None,
+                 train_eval_mode=TrainEvalMode.train_and_eval, skip_if_dynamo=True):
+        self.module_info_list = list(module_info_iterable)
+        self.allowed_dtypes = set(allowed_dtypes) if allowed_dtypes is not None else None
+        self.train_eval_mode = train_eval_mode
+        self.skip_if_dynamo = skip_if_dynamo
+
+    def _get_training_flags(self, module_info):
+        training_flags = []
+        if (self.train_eval_mode == TrainEvalMode.train_only or
+                self.train_eval_mode == TrainEvalMode.train_and_eval):
+            training_flags.append(True)
+
+        if (self.train_eval_mode == TrainEvalMode.eval_only or
+                self.train_eval_mode == TrainEvalMode.train_and_eval):
+            training_flags.append(False)
+
+        # If train and eval modes don't differ for the module, don't bother using more than one.
+        if not module_info.train_and_eval_differ:
+            training_flags = training_flags[:1]
+
+        return training_flags
+
+    def _parametrize_test(self, test, generic_cls, device_cls):
+        if device_cls is None:
+            raise RuntimeError('The @modules decorator is only intended to be used in a device-specific '
+                               'context; use it with instantiate_device_type_tests() instead of '
+                               'instantiate_parametrized_tests()')
+
+        for module_info in self.module_info_list:
+            dtypes = set(module_info.supported_dtypes(device_cls.device_type))
+            if self.allowed_dtypes is not None:
+                dtypes = dtypes.intersection(self.allowed_dtypes)
+
+            training_flags = self._get_training_flags(module_info)
+            for (training, dtype) in product(training_flags, dtypes):
+                # Construct the test name; device / dtype parts are handled outside.
+                # See [Note: device and dtype suffix placement]
+                test_name = module_info.formatted_name
+                if len(training_flags) > 1:
+                    test_name += f"_{'train_mode' if training else 'eval_mode'}"
+
+                # Construct parameter kwargs to pass to the test.
+                param_kwargs = {'module_info': module_info}
+                _update_param_kwargs(param_kwargs, 'dtype', dtype)
+                _update_param_kwargs(param_kwargs, 'training', training)
+
+                try:
+
+                    @wraps(test)
+                    def test_wrapper(*args, **kwargs):
+                        return test(*args, **kwargs)
+
+                    if self.skip_if_dynamo and not torch.testing._internal.common_utils.TEST_WITH_TORCHINDUCTOR:
+                        test_wrapper = skipIfTorchDynamo("Policy: we don't run ModuleInfo tests w/ Dynamo")(test_wrapper)
+
+                    decorator_fn = partial(module_info.get_decorators, generic_cls.__name__,
+                                           test.__name__, device_cls.device_type, dtype)
+
+                    yield (test_wrapper, test_name, param_kwargs, decorator_fn)
+                except Exception as ex:
+                    # Provides an error message for debugging before rethrowing the exception
+                    print(f"Failed to instantiate {test_name} for module {module_info.name}!")
+                    raise ex
+
+
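+# Illustrative sketch (hypothetical usage, not part of the original file): the
+# @modules decorator is consumed by classes processed with
+# instantiate_device_type_tests(). An empty ModuleInfo list keeps this sketch
+# importable; real tests pass module_db, which is defined later in this file.
+@modules([], allowed_dtypes=[torch.float32])
+def _example_module_forward_test(self, device, dtype, module_info, training):
+    # One test instantiation is generated per (module_info, dtype, training) combo.
+    for module_input in module_info.module_inputs_func(
+            module_info, device=device, dtype=dtype, requires_grad=False, training=training):
+        m = module_info.module_cls(*module_input.constructor_input.args,
+                                   **module_input.constructor_input.kwargs)
+        m.to(device=device, dtype=dtype).train(training)
+        m(*module_input.forward_input.args, **module_input.forward_input.kwargs)
+
+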
+def get_module_common_name(module_cls):
+    if module_cls in MODULE_CLASS_NAMES:
+        # Example: "nn.Linear"
+        return MODULE_CLASS_NAMES[module_cls]
+    else:
+        return module_cls.__name__
+
+
+class FunctionInput:
+    """ Contains args and kwargs to pass as input to a function. """
+    __slots__ = ['args', 'kwargs']
+
+    def __init__(self, *args, **kwargs):
+        self.args = args
+        self.kwargs = kwargs
+
+
+class ModuleInput:
+    """ Contains args / kwargs for module instantiation + forward pass. """
+    __slots__ = ['constructor_input', 'forward_input', 'desc', 'reference_fn']
+
+    def __init__(self, constructor_input, forward_input=None, desc='', reference_fn=None):
+        self.constructor_input = constructor_input  # Inputs to pass during construction
+        self.forward_input = forward_input  # Inputs to pass to forward()
+        self.desc = desc  # Description for this set of inputs
+        self.reference_fn = reference_fn  # Reference with signature: reference_fn(module, parameters, *args, **kwargs)
+
+        if reference_fn is not None:
+
+            @wraps(reference_fn)
+            def copy_reference_fn(m, *args, **kwargs):
+                # Copy inputs to avoid undesired side effects from calling the reference.
+                args, kwargs = deepcopy(args), deepcopy(kwargs)
+
+                # Note that module parameters are passed in for convenience.
+                return reference_fn(m, list(m.parameters()), *args, **kwargs)
+
+            self.reference_fn = copy_reference_fn
+
+class ModuleErrorEnum(Enum):
+    """ Enumerates when error is raised when testing modules. """
+    CONSTRUCTION_ERROR = 0
+    FORWARD_ERROR = 1
+
+class ErrorModuleInput:
+    """
+    A ModuleInput that will cause the operation to throw an error plus information
+    about the resulting error.
+    """
+
+    __slots__ = ["module_error_input", "error_on", "error_type", "error_regex"]
+
+    def __init__(self,
+                 module_error_input,
+                 *,
+                 error_on=ModuleErrorEnum.CONSTRUCTION_ERROR,
+                 error_type=RuntimeError,
+                 error_regex):
+        self.module_error_input = module_error_input
+        self.error_on = error_on
+        self.error_type = error_type
+        self.error_regex = error_regex
+
+
+class ModuleInfo:
+    """ Module information to be used in testing. """
+
+    def __init__(self,
+                 module_cls,  # Class object for the module under test
+                 *,
+                 module_inputs_func,  # Function to generate module inputs
+                 skips=(),  # Indicates which tests to skip
+                 decorators=None,  # Additional decorators to apply to generated tests
+                 dtypes=floating_types(),  # dtypes the module is expected to work with
+                 dtypesIfMPS=(torch.float16, torch.float32,),  # dtypes the module is expected to work with on MPS
+                 dtypesIfHpu=(torch.bfloat16, torch.float32,),  # dtypes the module is expected to work with on HPU
+                 supports_gradgrad=True,  # whether the op supports second order gradients
+                 gradcheck_nondet_tol=0.0,  # tolerance for nondeterminism while performing gradcheck
+                 module_memformat_affects_out=False,  # whether converting module to channels last will generate
+                                                      # channels last output
+                 train_and_eval_differ=False,  # whether the module has differing behavior between train and eval
+                 module_error_inputs_func=None,  # Function to generate module inputs that error
+                 gradcheck_fast_mode=None,  # Whether to use the fast implementation for gradcheck/gradgradcheck.
+                                            # When set to None, defers to the default value provided by the wrapper
+                                            # function around gradcheck (testing._internal.common_utils.gradcheck)
+                 ):
+        self.module_cls = module_cls
+        self.module_inputs_func = module_inputs_func
+        self.decorators = (*(decorators if decorators else []), *(skips if skips else []))
+        self.dtypes = dtypes
+        self.dtypesIfMPS = dtypesIfMPS
+        self.dtypesIfHpu = dtypesIfHpu
+        self.supports_gradgrad = supports_gradgrad
+        self.gradcheck_nondet_tol = gradcheck_nondet_tol
+        self.module_memformat_affects_out = module_memformat_affects_out
+        self.train_and_eval_differ = train_and_eval_differ
+        self.module_error_inputs_func = module_error_inputs_func
+        self.gradcheck_fast_mode = gradcheck_fast_mode
+        self.is_lazy = issubclass(module_cls, torch.nn.modules.lazy.LazyModuleMixin)
+
+    def get_decorators(self, test_class, test_name, device, dtype, param_kwargs):
+        result = []
+        for decorator in self.decorators:
+            if isinstance(decorator, DecorateInfo):
+                if decorator.is_active(test_class, test_name, device, dtype, param_kwargs):
+                    result.extend(decorator.decorators)
+            else:
+                result.append(decorator)
+        return result
+
+    def supported_dtypes(self, device_type):
+        if device_type == 'mps':
+            return self.dtypesIfMPS
+        elif device_type == 'hpu':
+            return self.dtypesIfHpu
+        else:
+            return self.dtypes
+
+    @property
+    def name(self):
+        return get_module_common_name(self.module_cls)
+
+    @property
+    def formatted_name(self):
+        return self.name.replace('.', '_')
+
+# Start of module inputs functions.
+
+def module_inputs_torch_nn_Linear(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    module_inputs = [
+        ModuleInput(constructor_input=FunctionInput(10, 8),
+                    forward_input=FunctionInput(input=make_input((4, 10))),
+                    reference_fn=lambda m, p, input: torch.mm(input, p[0].t()) + p[1].view(1, -1).expand(4, 8)),
+        ModuleInput(constructor_input=FunctionInput(10, 8, bias=False),
+                    forward_input=FunctionInput(make_input((4, 10))),
+                    desc='no_bias',
+                    reference_fn=lambda m, p, i: torch.mm(i, p[0].t())),
+        ModuleInput(constructor_input=FunctionInput(3, 5),
+                    forward_input=FunctionInput(make_input(3)),
+                    desc='no_batch_dim',
+                    reference_fn=lambda m, p, i: torch.mm(i.view(1, -1), p[0].t()).view(-1) + p[1])
+    ]
+
+    return module_inputs
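+
+
+# Illustrative sketch (hypothetical helper, not part of the original file): how a
+# test typically consumes the ModuleInputs above: build the module from
+# constructor_input, run forward_input through it, and compare against
+# reference_fn (which receives the module and its parameters automatically).
+def _example_run_linear_module_inputs(device='cpu', dtype=torch.float32):
+    for module_input in module_inputs_torch_nn_Linear(
+            None, device, dtype, requires_grad=False, training=True):
+        m = torch.nn.Linear(*module_input.constructor_input.args,
+                            **module_input.constructor_input.kwargs)
+        out = m(*module_input.forward_input.args, **module_input.forward_input.kwargs)
+        if module_input.reference_fn is not None:
+            expected = module_input.reference_fn(m, *module_input.forward_input.args,
+                                                 **module_input.forward_input.kwargs)
+            torch.testing.assert_close(out, expected)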
+
+
+def module_inputs_torch_nn_Bilinear(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    def bilinear_reference_fn(m, p, x1, x2, bias=True):
+        result = torch.einsum('bn,anm,bm->ba', x1, p[0], x2)
+        if bias:
+            if x1.shape[0] == 1:
+                result = result.view(-1) + p[1]
+            else:
+                result = result + p[1].view(1, -1).expand(x1.shape[0], p[0].shape[0])
+        return result
+
+    module_inputs = [
+        ModuleInput(constructor_input=FunctionInput(2, 3, 4),
+                    forward_input=FunctionInput(make_input((8, 2)), make_input((8, 3))),
+                    reference_fn=bilinear_reference_fn),
+        ModuleInput(constructor_input=FunctionInput(2, 3, 4, bias=False),
+                    forward_input=FunctionInput(make_input((8, 2)), make_input((8, 3))),
+                    desc='no_bias',
+                    reference_fn=lambda m, p, x1, x2: bilinear_reference_fn(m, p, x1, x2, bias=False)),
+        ModuleInput(constructor_input=FunctionInput(2, 3, 4),
+                    forward_input=FunctionInput(make_input(2), make_input(3)),
+                    desc='no_batch_dim',
+                    reference_fn=lambda m, p, x1, x2: bilinear_reference_fn(m, p, x1.view(1, -1), x2.view(1, -1))),
+    ]
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_KLDivLoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_batchmean', {'reduction': 'batchmean'}),
+        ('reduction_none', {'reduction': 'none'}),
+        ('log_target', {'log_target': True})
+    ]
+
+    module_inputs = []
+    for desc, constructor_kwargs in cases:
+        def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs):
+            return kldivloss_reference(i, t, **constructor_kwargs)
+
+        input = make_input((10, 10)).log()
+        target = make_input((10, 10)) if kwargs.get('log_target', False) else make_input((10, 10)).log()
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(input, target),
+                        desc=desc,
+                        reference_fn=reference_fn)
+        )
+
+        scalar_input = make_input(()).log()
+        # FIXME(rec): scalar_target is unused; perhaps it should be an argument to FunctionInput?
+        scalar_target = (  # noqa: F841
+            make_input(()) if kwargs.get('log_target', False) else make_input(()).log()
+        )
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(scalar_input, scalar_input),
+                        desc='scalar_' + desc,
+                        reference_fn=reference_fn)
+        )
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_NLLLoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    def make_input(shape, device=device, dtype=dtype, requires_grad=requires_grad):
+        return make_tensor(shape, device=device, dtype=dtype,
+                           requires_grad=False).log_softmax(dim=1).requires_grad_(requires_grad)
+    make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_none', {'reduction': 'none'}),
+        ('ignore_index', {'ignore_index': 2}),
+        ('weights', {'weight': make_weight(4).abs()}),
+        ('weights_ignore_index', {'weight': make_weight(4).abs(), 'ignore_index': 2}),
+        ('weights_ignore_index_neg', {'weight': make_weight(4).abs(), 'ignore_index': -1})
+    ]
+
+    # TODO: Uncomment when negative weights is supported.
+    # negative_weight = make_weight(10)
+    # negative_weight[0] = -1
+    # cases.append(('weights_negative', {'weight': negative_weight}))
+    module_inputs = []
+    for desc, constructor_kwargs in cases:
+
+        def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs):
+            return nllloss_reference(i, t, **constructor_kwargs)
+
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input((15, 4)),
+                                                    torch.empty(15, device=device).uniform_().mul(4).floor().long()),
+                        desc=desc,
+                        reference_fn=reference_fn)
+        )
+
+        def nd_reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs):
+            return nlllossNd_reference(i, t, **constructor_kwargs)
+
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(
+                            make_input((2, 4, 5, 5)),
+                            torch.empty(2, 5, 5, device=device).uniform_().mul(4).floor().long()),
+                        desc=f"nd_{desc}",
+                        reference_fn=nd_reference_fn)
+        )
+
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(
+                            make_input((2, 4, 5, 5, 2, 2)),
+                            torch.empty(2, 5, 5, 2, 2, device=device).uniform_().mul(4).floor().long()),
+                        desc=f"higher_dim_{desc}",
+                        reference_fn=nd_reference_fn)
+        )
+
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(
+                            make_input((2, 4, 5)),
+                            torch.empty(2, 5, device=device).uniform_().mul(4).floor().long()),
+                        desc=f"3d_{desc}",
+                        reference_fn=nd_reference_fn)
+        )
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_GaussianNLLLoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_mean', {'reduction': 'mean'}),
+        ('reduction_none', {'reduction': 'none'}),
+    ]
+
+    module_inputs = []
+    for desc, constructor_kwargs in cases:
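+        # GaussianNLLLoss takes (input, target, var); the variance must be positive, hence the .abs() below.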
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input(3),
+                                                    make_target(3),
+                                                    make_input(1).abs()),
+                        desc=desc,
+                        reference_fn=no_batch_dim_reference_fn)
+        )
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_PoissonNLLLoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_mean', {'reduction': 'mean'}),
+        ('reduction_none', {'reduction': 'none'}),
+        ('full', {'full': True}),
+        ('no_log_input', {'log_input': False}),
+        ('full_no_log_input', {'full': True, 'log_input': False}),
+    ]
+
+    def poissonnllloss_reference_fn(i, t, log_input=True, full=False, reduction='mean', eps=1e-8):
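+        # Poisson NLL: exp(i) - t * i when the input is in log space, otherwise i - t * log(i + eps).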
+        if log_input:
+            result = i.exp() - t.mul(i)
+        else:
+            result = i - t.mul((i + eps).log())
+
+        if full:
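+            # The 'full' flag adds the Stirling approximation term, zeroed out where t <= 1.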
+            result += (t.mul(t.log()) - t + 0.5 * (2. * math.pi * t).log()).masked_fill(t <= 1, 0)
+
+        if reduction == 'none':
+            return result
+        elif reduction == 'mean':
+            return result.sum() / i.numel()
+        else:
+            return result.sum()
+
+    module_inputs = []
+    for desc, constructor_kwargs in cases:
+        def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs):
+            return poissonnllloss_reference_fn(i, t, **constructor_kwargs)
+
+        log_input = constructor_kwargs.get('log_input', True)
+        input = make_input((2, 3, 4, 5)) if log_input else make_input((2, 3, 4, 5)).abs().add(0.001)
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(input,
+                                                    make_target((2, 3, 4, 5)).floor_().abs_()),
+                        desc=desc,
+                        reference_fn=reference_fn)
+        )
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_MSELoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_mean', {'reduction': 'mean'}),
+        ('reduction_none', {'reduction': 'none'}),
+    ]
+
+    def mse_loss_reference_fn(m, p, i, t, reduction='mean'):
+        if reduction == 'none':
+            return (i - t).pow(2)
+        elif reduction == 'mean':
+            return (i - t).pow(2).sum() / i.numel()
+        else:
+            return (i - t).pow(2).sum()
+
+    module_inputs = []
+    for desc, constructor_kwargs in cases:
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input((2, 3, 4, 5)),
+                                                    make_target((2, 3, 4, 5))),
+                        desc=desc,
+                        reference_fn=partial(mse_loss_reference_fn, **constructor_kwargs))
+        )
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input(()),
+                                                    make_target(())),
+                        desc=f'{desc}_scalar',
+                        reference_fn=partial(mse_loss_reference_fn, **constructor_kwargs))
+        )
+
+    return module_inputs
+
+
+def no_batch_dim_reference_fn(m, p, *args, **kwargs):
+    """Reference function for modules supporting no batch dimensions.
+
+    Unbatched inputs are unsqueezed to form a
+    single batch input before passing them to the module.
+    The output is squeezed to compare with the
+    output of unbatched input to the module.
+
+    Currently it only supports modules which return a single Tensor as output.
+    You can bind the following kwargs.
+    Kwargs:
+        batch_first[bool] : If True, all the Tensors in `args` will be unsqueezed at dim `0`
+                        and the output will be squeezed at dim `0`; otherwise dim `1` is used for both.
+        kwargs_to_batchify[dict] : Dictionary mapping argument names to the dimension to unsqueeze.
+                               Useful when a few arguments have a batch dimension different
+                               from the one selected by `batch_first`.
+        is_criterion[bool] : Specify if the module is a criterion; the output reduction is then handled accordingly.
+    """
+    batch_dim = 0 if kwargs.pop('batch_first', True) else 1
+    kwargs_to_batchify = kwargs.pop('kwargs_to_batchify', None)
+    is_criterion = kwargs.pop('is_criterion', False)
+
+    if kwargs_to_batchify is not None:
+        assert isinstance(kwargs_to_batchify, dict)
+        for k, v in kwargs.items():
+            if k in kwargs_to_batchify and v is not None:
+                bdim = kwargs_to_batchify[k]
+                kwargs[k] = v.unsqueeze(bdim)
+
+    single_batch_input_args = [input.unsqueeze(batch_dim) for input in args]
+    with freeze_rng_state():
+        output = m(*single_batch_input_args, **kwargs).squeeze(batch_dim)
+
+    if is_criterion:
+        reduction = get_reduction(m)
+        if reduction == 'none':
+            return output.squeeze(0)
+    return output
+
+
+def no_batch_dim_reference_mha(m, p, *args, **kwargs):
+    """Reference function for MultiheadAttention supporting no batch dimensions.
+
+    Unbatched inputs are unsqueezed to form a
+    single batch input before passing them to the module.
+    The output is squeezed to compare with the
+    output of unbatched input to the module.
+    """
+    batch_dim = 0 if kwargs.pop('batch_first', True) else 1
+    if 'key_padding_mask' in kwargs and kwargs['key_padding_mask'] is not None:
+        kwargs['key_padding_mask'] = kwargs['key_padding_mask'].unsqueeze(0)
+    single_batch_input_args = [input.unsqueeze(batch_dim) for input in args]
+    with freeze_rng_state():
+        output = m(*single_batch_input_args, **kwargs)
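+        # The attention weights always carry the batch at dim 0, independent of batch_first.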
+        return (output[0].squeeze(batch_dim), output[1].squeeze(0))
+
+
+def no_batch_dim_reference_rnn_gru(m, p, *args, **kwargs):
+    """Reference function for RNN and GRU supporting no batch dimensions.
+
+    Unbatched inputs are unsqueezed to form a
+    single batch input before passing them to the module.
+    The output is squeezed to compare with the
+    output of unbatched input to the module.
+    """
+    if len(args) == 1:
+        inp, = args
+        h = None
+    elif len(args) == 2:
+        inp, h = args
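+        # Hidden state is (num_layers * num_directions, N, H) regardless of batch_first, so batch is added at dim 1.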
+        h = h.unsqueeze(1)
+
+    batch_dim = 0 if kwargs['batch_first'] else 1
+    kwargs.pop('batch_first')
+    inp = inp.unsqueeze(batch_dim)
+    single_batch_input_args = (inp, h)
+    with freeze_rng_state():
+        output = m(*single_batch_input_args, **kwargs)
+        return (output[0].squeeze(batch_dim), output[1].squeeze(1))
+
+
+def no_batch_dim_reference_lstm(m, p, *args, **kwargs):
+    """Reference function for LSTM supporting no batch dimensions.
+
+    Unbatched inputs are unsqueezed to form a
+    single batch input before passing them to the module.
+    The output is squeezed to compare with the
+    output of unbatched input to the module.
+    """
+    if len(args) == 1:
+        inp, = args
+        h = None
+    elif len(args) == 2:
+        inp, h = args
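+        # Both h_0 and c_0 carry the batch at dim 1, so each one is unsqueezed there.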
+        h = (h[0].unsqueeze(1), h[1].unsqueeze(1))
+
+    batch_dim = 0 if kwargs['batch_first'] else 1
+    kwargs.pop('batch_first')
+    inp = inp.unsqueeze(batch_dim)
+    single_batch_input_args = (inp, h)
+    with freeze_rng_state():
+        output = m(*single_batch_input_args, **kwargs)
+        return (output[0].squeeze(batch_dim), (output[1][0].squeeze(1), output[1][1].squeeze(1)))
+
+
+def no_batch_dim_reference_lstmcell(m, p, *args, **kwargs):
+    """Reference function for LSTMCell supporting no batch dimensions.
+
+    The input and hidden states are passed to the module in batched form with a single item.
+    The outputs are squeezed to compare with the module's outputs for the unbatched input.
+    """
+    inp, (h, c) = args
+    single_batch_input_args = (inp.unsqueeze(0), (h.unsqueeze(0), c.unsqueeze(0)))
+    with freeze_rng_state():
+        output = m(*single_batch_input_args, **kwargs)
+        return (output[0].squeeze(0), output[1].squeeze(0))
+
+
+def generate_regression_criterion_inputs(make_input):
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(reduction=reduction),
+            forward_input=FunctionInput(make_input((4,)), make_input((4,))),
+            reference_fn=partial(no_batch_dim_reference_fn, is_criterion=True),
+            desc=f'no_batch_dim_{reduction}'
+        ) for reduction in ['none', 'mean', 'sum']]
+
+
+def module_inputs_torch_nn_AvgPool1d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(kernel_size=2),
+                    forward_input=FunctionInput(make_input((3, 6))),
+                    desc='no_batch_dim',
+                    reference_fn=no_batch_dim_reference_fn),
+        ModuleInput(constructor_input=FunctionInput(2),
+                    forward_input=FunctionInput(make_input((2, 3, 6)))),
+        ModuleInput(constructor_input=FunctionInput((2,), (2,)),
+                    forward_input=FunctionInput(make_input((2, 3, 6))),
+                    desc='stride'),
+        ModuleInput(constructor_input=FunctionInput(2, 2, 1),
+                    forward_input=FunctionInput(make_input((2, 3, 6))),
+                    desc='stride_pad')]
+
+
+def module_inputs_torch_nn_AvgPool2d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput((2, 2)),
+                    forward_input=FunctionInput(make_input((3, 6, 6))),
+                    desc='no_batch_dim',
+                    reference_fn=no_batch_dim_reference_fn),
+        ModuleInput(constructor_input=FunctionInput((2, 2)),
+                    forward_input=FunctionInput(make_input((2, 3, 6, 6)))),
+        ModuleInput(constructor_input=FunctionInput((2, 2), (2, 2)),
+                    forward_input=FunctionInput(make_input((2, 3, 6, 6))),
+                    desc='stride'),
+        ModuleInput(constructor_input=FunctionInput((2, 2), (2, 2), (1, 1)),
+                    forward_input=FunctionInput(make_input((2, 3, 6, 6))),
+                    desc='stride_pad'),
+        ModuleInput(constructor_input=FunctionInput((2, 2), divisor_override=1),
+                    forward_input=FunctionInput(make_input((2, 3, 6, 6))),
+                    desc='divisor'),
+        ModuleInput(constructor_input=FunctionInput((2, 2), (2, 2), divisor_override=1),
+                    forward_input=FunctionInput(make_input((2, 3, 6, 6))),
+                    desc='divisor_stride'),
+        ModuleInput(constructor_input=FunctionInput((2, 2), (2, 2), (1, 1), divisor_override=1),
+                    forward_input=FunctionInput(make_input((2, 3, 6, 6))),
+                    desc='divisor_stride_pad')]
+
+
+def module_inputs_torch_nn_AvgPool3d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
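+    # Positional constructor args are (kernel_size, stride, padding);
+    # divisor_override fixes the divisor used for averaging.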
+    return [
+        ModuleInput(constructor_input=FunctionInput((2, 2, 2)),
+                    forward_input=FunctionInput(make_input((3, 4, 4, 4))),
+                    desc='no_batch_dim',
+                    reference_fn=no_batch_dim_reference_fn),
+        ModuleInput(constructor_input=FunctionInput((2, 2, 2)),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4)))),
+        ModuleInput(constructor_input=FunctionInput(2, (2, 2, 2)),
+                    forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))),
+                    desc='stride'),
+        ModuleInput(constructor_input=FunctionInput(2, 2, (1, 1, 1)),
+                    forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))),
+                    desc='stride_pad'),
+        ModuleInput(constructor_input=FunctionInput(4, 2, (1, 2, 1)),
+                    forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))),
+                    desc='stride_pad_gpu_fixedkw_output'),
+        ModuleInput(constructor_input=FunctionInput((2, 4, 8), 1, (1, 1, 2)),
+                    forward_input=FunctionInput(make_input((2, 3, 2, 4, 8))),
+                    desc='stride_pad_gpu_general_output'),
+        ModuleInput(constructor_input=FunctionInput(3, 1, 0),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))),
+                    desc='stride1_pad0_gpu_input'),
+        ModuleInput(constructor_input=FunctionInput(2, 2, (1, 1, 1)),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))),
+                    desc='stride_pad_gpu_input_nooverlap'),
+        ModuleInput(constructor_input=FunctionInput((2, 2, 2), divisor_override=1),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))),
+                    desc='divisor'),
+        ModuleInput(constructor_input=FunctionInput(2, (2, 2, 2), divisor_override=1),
+                    forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))),
+                    desc='divisor_stride'),
+        ModuleInput(constructor_input=FunctionInput(2, 2, (1, 1, 1), divisor_override=1),
+                    forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))),
+                    desc='divisor_stride_pad'),
+        ModuleInput(constructor_input=FunctionInput(4, 2, (1, 2, 1), divisor_override=1),
+                    forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))),
+                    desc='divisor_stride_pad_gpu_fixedkw_output'),
+        ModuleInput(constructor_input=FunctionInput((2, 4, 8), 1, (1, 1, 2), divisor_override=1),
+                    forward_input=FunctionInput(make_input((2, 3, 2, 4, 8))),
+                    desc='divisor_stride_pad_gpu_general_output'),
+        ModuleInput(constructor_input=FunctionInput(3, 1, 0, divisor_override=1),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))),
+                    desc='divisor_stride1_pad0_gpu_input'),
+        ModuleInput(constructor_input=FunctionInput(2, 2, (1, 1, 1), divisor_override=1),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))),
+                    desc='divisor_stride_pad_gpu_input_nooverlap')]
+
+
+def module_inputs_torch_nn_AdaptiveAvgPool1d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(3,),
+                    forward_input=FunctionInput(make_input((1, 3, 5))),
+                    desc='single'),
+        ModuleInput(constructor_input=FunctionInput(3,),
+                    forward_input=FunctionInput(make_input((3, 5))),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim'),
+        ModuleInput(constructor_input=FunctionInput(1,),
+                    forward_input=FunctionInput(make_input((1, 3, 5))),
+                    desc='one_output')]
+
+
+def module_inputs_torch_nn_AdaptiveAvgPool2d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(3,),
+                    forward_input=FunctionInput(make_input((1, 3, 5, 6))),
+                    desc='single'),
+        ModuleInput(constructor_input=FunctionInput(3,),
+                    forward_input=FunctionInput(make_input((3, 5, 6))),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim'),
+        ModuleInput(constructor_input=FunctionInput(1,),
+                    forward_input=FunctionInput(make_input((1, 3, 5, 6))),
+                    desc='single_1x1output'),
+        ModuleInput(constructor_input=FunctionInput((3, 4)),
+                    forward_input=FunctionInput(make_input((1, 3, 5, 6))),
+                    desc='tuple'),
+        ModuleInput(constructor_input=FunctionInput((3, None)),
+                    forward_input=FunctionInput(make_input((1, 3, 5, 6))),
+                    desc='tuple_none')]
+
+
+def module_inputs_torch_nn_AdaptiveAvgPool3d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(3,),
+                    forward_input=FunctionInput(make_input((2, 3, 5, 2, 7))),
+                    desc='single'),
+        ModuleInput(constructor_input=FunctionInput(3,),
+                    forward_input=FunctionInput(make_input((3, 5, 2, 7))),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim'),
+        ModuleInput(constructor_input=FunctionInput((3, 4, 5)),
+                    forward_input=FunctionInput(make_input((2, 3, 5, 3, 7))),
+                    desc='tuple'),
+        ModuleInput(constructor_input=FunctionInput((None, 4, 5)),
+                    forward_input=FunctionInput(make_input((2, 3, 5, 3, 7))),
+                    desc='tuple_none'),
+        ModuleInput(constructor_input=FunctionInput((3, 2, 2)),
+                    forward_input=FunctionInput(make_input((1, 1, 3, 2, 6))),
+                    desc='last_dim')]
+
+
+def module_inputs_torch_nn_AdaptiveMaxPool1d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(3,),
+                    forward_input=FunctionInput(make_input((1, 3, 5))),
+                    desc='single'),
+        ModuleInput(constructor_input=FunctionInput(3,),
+                    forward_input=FunctionInput(make_input((3, 5))),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim')]
+
+
+def module_inputs_torch_nn_AdaptiveMaxPool2d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(3,),
+                    forward_input=FunctionInput(make_input((1, 3, 5, 6))),
+                    desc='single'),
+        ModuleInput(constructor_input=FunctionInput(3,),
+                    forward_input=FunctionInput(make_input((3, 5, 6))),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim'),
+        ModuleInput(constructor_input=FunctionInput((3, 4)),
+                    forward_input=FunctionInput(make_input((1, 3, 5, 6))),
+                    desc='tuple'),
+        ModuleInput(constructor_input=FunctionInput((3, None)),
+                    forward_input=FunctionInput(make_input((1, 3, 5, 6))),
+                    desc='tuple_none')]
+
+
+def module_inputs_torch_nn_AdaptiveMaxPool3d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(3,),
+                    forward_input=FunctionInput(make_input((2, 3, 5, 6, 7))),
+                    desc='single'),
+        ModuleInput(constructor_input=FunctionInput(3,),
+                    forward_input=FunctionInput(make_input((3, 5, 6, 7))),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim'),
+        ModuleInput(constructor_input=FunctionInput((3, 4, 5)),
+                    forward_input=FunctionInput(make_input((2, 3, 5, 6, 7))),
+                    desc='tuple'),
+        ModuleInput(constructor_input=FunctionInput((3, None, 5)),
+                    forward_input=FunctionInput(make_input((2, 3, 5, 6, 7))),
+                    desc='tuple_none'),
+        ModuleInput(constructor_input=FunctionInput(3),
+                    forward_input=FunctionInput(make_input((2, 3, 12, 9, 3))),
+                    desc='single_nonatomic'),
+        ModuleInput(constructor_input=FunctionInput((3, 4, 5)),
+                    forward_input=FunctionInput(make_input((2, 3, 6, 4, 10))),
+                    desc='tuple_nonatomic')]
+
+
+def module_inputs_torch_nn_BatchNorm1d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
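+    # Positional constructor args are (num_features, eps, momentum, affine, track_running_stats);
+    # momentum=None selects a cumulative moving average for the running stats.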
+    return [
+        ModuleInput(constructor_input=FunctionInput(10,),
+                    forward_input=FunctionInput(make_input((4, 10))),
+                    desc='affine'),
+        ModuleInput(constructor_input=FunctionInput(5,),
+                    forward_input=FunctionInput(make_input((4, 5, 3))),
+                    desc='3d_input'),
+        ModuleInput(constructor_input=FunctionInput(10, 1e-3, None),
+                    forward_input=FunctionInput(make_input((4, 10))),
+                    desc='affine_simple_average'),
+        ModuleInput(constructor_input=FunctionInput(10, 1e-3, 0.3, False),
+                    forward_input=FunctionInput(make_input((4, 10))),
+                    desc='not_affine'),
+        ModuleInput(constructor_input=FunctionInput(10, 1e-3, 0.3, True, False),
+                    forward_input=FunctionInput(make_input((4, 10))),
+                    desc='not_tracking_stats'),
+        ModuleInput(constructor_input=FunctionInput(5, 1e-3, 0.3, False),
+                    forward_input=FunctionInput(make_input((4, 5, 3))),
+                    desc='3d_input_not_affine'),
+        ModuleInput(constructor_input=FunctionInput(5, 1e-3, 0.3, False),
+                    forward_input=FunctionInput(make_input((0, 5, 9))),
+                    desc='zero_batch')]
+
+
+def module_inputs_torch_nn_BatchNorm2d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(3,),
+                    forward_input=FunctionInput(make_input((2, 3, 6, 6)))),
+        ModuleInput(constructor_input=FunctionInput(3, 1e-3, None),
+                    forward_input=FunctionInput(make_input((2, 3, 6, 6))),
+                    desc='2d_simple_average'),
+        ModuleInput(constructor_input=FunctionInput(3, 1e-3, 0.8),
+                    forward_input=FunctionInput(make_input((2, 3, 6, 6))),
+                    desc='momentum'),
+        ModuleInput(constructor_input=FunctionInput(3, 1e-3, 0.8, False),
+                    forward_input=FunctionInput(make_input((2, 3, 6, 6))),
+                    desc='not_affine'),
+        ModuleInput(constructor_input=FunctionInput(3, 1e-3, 0.8, True, False),
+                    forward_input=FunctionInput(make_input((2, 3, 6, 6))),
+                    desc='not_tracking_stats'),
+        ModuleInput(constructor_input=FunctionInput(5, 1e-3, 0.3, False),
+                    forward_input=FunctionInput(make_input((0, 5, 2, 2))),
+                    desc='zero_batch')]
+
+
+def module_inputs_torch_nn_BatchNorm3d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(3,),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4)))),
+        ModuleInput(constructor_input=FunctionInput(3, 1e-3, None),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))),
+                    desc='3d_simple_average'),
+        ModuleInput(constructor_input=FunctionInput(3, 1e-3, 0.7),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))),
+                    desc='momentum'),
+        ModuleInput(constructor_input=FunctionInput(3, 1e-3, 0.7, False),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))),
+                    desc='not_affine'),
+        ModuleInput(constructor_input=FunctionInput(3, 1e-3, 0.7, True, False),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))),
+                    desc='not_tracking_stats'),
+        ModuleInput(constructor_input=FunctionInput(5, 1e-3, 0.3, False),
+                    forward_input=FunctionInput(make_input((0, 5, 2, 2, 2))),
+                    desc='zero_batch')]
+
+
+def module_inputs_torch_nn_ConvNd(module_info, device, dtype, requires_grad, training, **kwargs):
+    N = kwargs['N']
+    lazy = kwargs.get('lazy', False)
+    transposed = kwargs.get('transposed', False)
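+    # Lazy variants infer in_channels from the first input, so their constructor omits C_in below.
+    # padding='same' is only exercised for regular convs; transposed convs do not accept string padding.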
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    conv_kwargs_list = [{}] if transposed else [{}, {'padding': 'same'}]
+    kernel_size, C_in, C_out = 3, 4, 5
+    input_no_batch_shape = (C_in,) + tuple(i + 3 for i in range(N))
+    input_batch_shape = (2,) + input_no_batch_shape
+    return [
+        ModuleInput(constructor_input=(FunctionInput(C_out, kernel_size, **conv_kwargs) if lazy else
+                                       FunctionInput(C_in, C_out, kernel_size, **conv_kwargs)),
+                    forward_input=FunctionInput(make_input(
+                        input_batch_shape if with_batch else input_no_batch_shape)),
+                    desc=('' if with_batch else 'no_batch_dim'),
+                    reference_fn=(None if with_batch else no_batch_dim_reference_fn))
+        for with_batch, conv_kwargs in itertools.product([True, False], conv_kwargs_list)
+    ]
+
+
+def module_inputs_torch_nn_CosineEmbeddingLoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_mean', {'reduction': 'mean'}),
+        ('reduction_none', {'reduction': 'none'}),
+        ('margin', {'margin': 0.7})
+    ]
+
+    module_inputs = []
+    for desc, constructor_kwargs in cases:
+        def reference_fn(m, p, i1, i2, t, constructor_kwargs=constructor_kwargs):
+            return cosineembeddingloss_reference(i1, i2, t, **constructor_kwargs)
+
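+        # CosineEmbeddingLoss expects targets of +1 or -1, hence the .sign() on the target below.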
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input((15, 10)), make_input((15, 10)),
+                                                    make_target((15,)).sign()),
+                        desc=desc,
+                        reference_fn=reference_fn)
+        )
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_ELU(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(alpha=2.),
+                    forward_input=FunctionInput(make_input((3, 2, 5))),
+                    reference_fn=lambda m, p, i: torch.where(i >= 0, i, 2 * (i.exp() - 1))),
+        ModuleInput(constructor_input=FunctionInput(alpha=2.),
+                    forward_input=FunctionInput(make_input(())),
+                    desc='scalar'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((3,))),
+                    desc='no_batch_dim',
+                    reference_fn=no_batch_dim_reference_fn),
+        ModuleInput(constructor_input=FunctionInput(alpha=2.),
+                    forward_input=FunctionInput(make_input((2, 3, 2, 5))),
+                    desc='4d_input')]
+
+
+def module_inputs_torch_nn_CELU(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(alpha=2.),
+                    forward_input=FunctionInput(make_input((3, 2, 5))),
+                    reference_fn=lambda m, p, i: torch.where(i >= 0, i, 2. * ((.5 * i).exp() - 1))),
+        ModuleInput(constructor_input=FunctionInput(alpha=2.),
+                    forward_input=FunctionInput(make_input(())),
+                    reference_fn=lambda m, p, i: torch.where(i >= 0, i, 2. * ((.5 * i).exp() - 1)),
+                    desc='scalar'),
+        ModuleInput(constructor_input=FunctionInput(alpha=2.),
+                    forward_input=FunctionInput(make_input((3,))),
+                    desc='no_batch_dim',
+                    reference_fn=no_batch_dim_reference_fn)]
+
+
+def module_inputs_torch_nn_GLU(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((5, 6)))),
+        ModuleInput(constructor_input=FunctionInput(1),
+                    forward_input=FunctionInput(make_input((5, 6, 7))),
+                    desc='dim'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((4,))),
+                    desc='no_batch_dim',
+                    reference_fn=no_batch_dim_reference_fn)]
+
+
+def module_inputs_torch_nn_GELU(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput('none'),
+                    forward_input=FunctionInput(make_input(())),
+                    reference_fn=lambda m, p, x, *_: x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))),
+                    desc='scalar'),
+        ModuleInput(constructor_input=FunctionInput('none'),
+                    forward_input=FunctionInput(make_input((3, 2, 5))),
+                    reference_fn=lambda m, p, x, *_: x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((3,))),
+                    desc='no_batch_dim',
+                    reference_fn=no_batch_dim_reference_fn)]
+
+
+def module_inputs_torch_nn_ReLU(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(())),
+                    desc='scalar'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(4)),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 5))),
+                    desc='channels_last_mem_format'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((2, 3, 3, 4, 5))),
+                    desc='channels_last_3d_mem_format')]
+
+
+def module_inputs_torch_nn_ReLU6(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(())),
+                    desc='scalar'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(4)),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 5))),
+                    desc='channels_last_mem_format'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((2, 3, 3, 4, 5))),
+                    desc='channels_last_3d_mem_format')]
+
+
+def module_inputs_torch_nn_LeakyReLU(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((3, 2, 5)))),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(4)),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim'),
+        ModuleInput(constructor_input=FunctionInput(0.5),
+                    forward_input=FunctionInput(make_input((3, 2, 5))),
+                    desc='with_negval'),
+        ModuleInput(constructor_input=FunctionInput(0.0),
+                    forward_input=FunctionInput(make_input((10, 10))),
+                    desc='with_zero_negval'),
+        ModuleInput(constructor_input=FunctionInput(0.5),
+                    forward_input=FunctionInput(make_input(())),
+                    desc='with_negval_scalar')]
+
+
+def module_inputs_torch_nn_PReLU(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(())),
+                    desc='scalar'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(4)),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((2, 3, 4))),
+                    reference_fn=lambda m, p, i: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
+                    desc='1d'),
+        ModuleInput(constructor_input=FunctionInput(3),
+                    forward_input=FunctionInput(make_input((2, 3, 4))),
+                    reference_fn=lambda m, p, i: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
+                    desc='1d_multiparam'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 5))),
+                    reference_fn=lambda m, p, i: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
+                    desc='2d'),
+        ModuleInput(constructor_input=FunctionInput(3),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 5))),
+                    reference_fn=lambda m, p, i: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
+                    desc='2d_multiparam'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 5, 6))),
+                    reference_fn=lambda m, p, i: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
+                    desc='3d'),
+        ModuleInput(constructor_input=FunctionInput(3),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 5, 6))),
+                    reference_fn=lambda m, p, i: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
+                    desc='3d_multiparam')]
+
+
+def module_inputs_torch_nn_SELU(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((3, 2, 5)))),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(4)),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(())),
+                    desc='scalar')]
+
+
+def module_inputs_torch_nn_SiLU(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(())),
+                    reference_fn=lambda m, p, x, *_: x * torch.sigmoid(x),
+                    desc='scalar'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(4)),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((5, 6, 7))),
+                    reference_fn=lambda m, p, x, *_: x * torch.sigmoid(x))]
+
+
+def module_inputs_torch_nn_Softmax(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(1),
+                    forward_input=FunctionInput(make_input((10, 20))),
+                    reference_fn=lambda m, p, i: torch.exp(i).div(torch.exp(i).sum(1, True).expand(10, 20))),
+        ModuleInput(constructor_input=FunctionInput(0),
+                    forward_input=FunctionInput(make_input(())),
+                    reference_fn=lambda m, p, i: torch.exp(i).div(torch.exp(i).sum(0, True)),
+                    desc='scalar'),
+        ModuleInput(constructor_input=FunctionInput(-1),
+                    forward_input=FunctionInput(make_input((4, 5))),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim')]
+
+
+def module_inputs_torch_nn_Softmax2d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((1, 3, 10, 20))),
+                    reference_fn=lambda m, p, i: torch.exp(i).div(torch.exp(i).sum(1, False))),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((3, 4, 5))),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim')]
+
+
+def module_inputs_torch_nn_LogSoftmax(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(1),
+                    forward_input=FunctionInput(make_input((10, 20))),
+                    reference_fn=lambda m, p, i: torch.exp(i).div_(torch.exp(i).sum(1, True).expand(10, 20)).log_()),
+        ModuleInput(constructor_input=FunctionInput(1),
+                    forward_input=FunctionInput(make_input((1, 3, 10, 20))),
+                    reference_fn=lambda m, p, i: torch.exp(i).div_(torch.exp(i).sum(1, False)).log_(),
+                    desc='multiparam'),
+        ModuleInput(constructor_input=FunctionInput(0),
+                    forward_input=FunctionInput(make_input(())),
+                    reference_fn=lambda m, p, i: torch.exp(i).div_(torch.exp(i).sum(0, False)).log_(),
+                    desc='multiparam_scalar'),
+        ModuleInput(constructor_input=FunctionInput(-1),
+                    forward_input=FunctionInput(make_input((4, 5))),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim')]
+
+
+def module_inputs_torch_nn_Softmin(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(1),
+                    forward_input=FunctionInput(make_input((10, 20)))),
+        ModuleInput(constructor_input=FunctionInput(1),
+                    forward_input=FunctionInput(make_input((2, 3, 5, 10))),
+                    desc='multidim'),
+        ModuleInput(constructor_input=FunctionInput(0),
+                    forward_input=FunctionInput(make_input(())),
+                    desc='scalar'),
+        ModuleInput(constructor_input=FunctionInput(-1),
+                    forward_input=FunctionInput(make_input((3, 4, 10))),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim')]
+
+
+def module_inputs_torch_nn_Softplus(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((10, 20))),
+                    reference_fn=lambda m, p, i: torch.log1p(torch.exp(i))),
+        ModuleInput(constructor_input=FunctionInput(2),
+                    forward_input=FunctionInput(make_input((10, 20))),
+                    reference_fn=lambda m, p, i: 1. / 2. * torch.log1p(torch.exp(2 * i)),
+                    desc='beta'),
+        ModuleInput(constructor_input=FunctionInput(2, -100),
+                    forward_input=FunctionInput(make_input((10, 20))),
+                    reference_fn=(
+                        lambda m, p, i: ((i * 2) > -100).type_as(i) * i
+                        + ((i * 2) <= -100).type_as(i) * 1. / 2. * torch.log1p(torch.exp(2 * i))),
+                    desc='beta_threshold'),
+        ModuleInput(constructor_input=FunctionInput(2, -100),
+                    forward_input=FunctionInput(make_input(())),
+                    reference_fn=(
+                        lambda m, p, i: ((i * 2) > -100).type_as(i) * i
+                        + ((i * 2) <= -100).type_as(i) * 1. / 2. * torch.log1p(torch.exp(2 * i))),
+                    desc='beta_threshold_scalar'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(4)),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim')]
+
+
+def module_inputs_torch_nn_Softshrink(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((3, 2, 5)))),
+        ModuleInput(constructor_input=FunctionInput(1,),
+                    forward_input=FunctionInput(make_input((3, 2, 5))),
+                    desc='lambda'),
+        ModuleInput(constructor_input=FunctionInput(1,),
+                    forward_input=FunctionInput(make_input(())),
+                    desc='lambda_scalar'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(4)),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim')]
+
+
+def module_inputs_torch_nn_Softsign(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((3, 2, 5))),
+                    reference_fn=lambda m, p, i: i.div(1 + torch.abs(i))),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(())),
+                    reference_fn=lambda m, p, i: i.div(1 + torch.abs(i)),
+                    desc='scalar'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(4)),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim')]
+
+
+def module_inputs_torch_nn_Tanh(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 5)))),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(())),
+                    desc='scalar'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(4)),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim')]
+
+
+def module_inputs_torch_nn_Tanhshrink(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 5)))),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(())),
+                    desc='scalar'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(4)),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim')]
+
+
+def module_inputs_torch_nn_Threshold(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
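+    # Constructor args are (threshold, value): elements <= threshold are replaced by value.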
+    return [
+        ModuleInput(constructor_input=FunctionInput(2., 1.),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 5))),
+                    desc='threshold_value'),
+        ModuleInput(constructor_input=FunctionInput(2., 10.),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 5))),
+                    desc='large_value'),
+        ModuleInput(constructor_input=FunctionInput(2., 1.),
+                    forward_input=FunctionInput(make_input(())),
+                    desc='threshold_value_scalar'),
+        ModuleInput(constructor_input=FunctionInput(2., 1.),
+                    forward_input=FunctionInput(make_input(4)),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim')]
+
+
+def module_inputs_torch_nn_Mish(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((5, 6, 7))),
+                    reference_fn=lambda m, p, i: i * torch.tanh(F.softplus(i))),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(())),
+                    reference_fn=lambda m, p, i: i * torch.tanh(F.softplus(i)),
+                    desc='scalar'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(4)),
+                    reference_fn=no_batch_dim_reference_fn,
+                    desc='no_batch_dim')]
+
+
+def module_inputs_torch_nn_L1Loss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((2, 3, 4)),
+                                                make_input((2, 3, 4))),
+                    reference_fn=lambda m, p, i, t: 1. / i.numel() * sum((a - b).abs().sum()
+                                                                         for a, b in zip(i, t))),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(()), make_input(())),
+                    reference_fn=lambda m, p, i, t: 1. / i.numel() * (i - t).abs().sum(),
+                    desc='scalar')] + generate_regression_criterion_inputs(make_input)
+
+
+def module_inputs_torch_nn_SmoothL1Loss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_mean', {'reduction': 'mean'}),
+        ('reduction_none', {'reduction': 'none'}),
+    ]
+
+    module_inputs = []
+    for desc, constructor_kwargs in cases:
+        def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs):
+            return smoothl1loss_reference(i, t, **constructor_kwargs)
+
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input((5, 10)),
+                                                    make_input((5, 10))),
+                        desc=desc,
+                        reference_fn=reference_fn)
+        )
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input(()),
+                                                    make_input(())),
+                        desc=f'scalar_{desc}',
+                        reference_fn=reference_fn)
+        )
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_BCELoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
+    make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_mean', {'reduction': 'mean'}),
+        ('reduction_none', {'reduction': 'none'}),
+        ('weights', {'weight': make_weight((10,))}),
+    ]
+
+    def bce_loss_reference_fn(m, p, i, t, reduction='mean', weight=None):
+        result = -(t * i.log() + (1 - t) * (1 - i).log())
+
+        if weight is not None:
+            result = result * weight
+
+        if reduction == 'none':
+            return result
+        elif reduction == 'mean':
+            return result.sum() / i.numel()
+        else:
+            return result.sum()
+
+    module_inputs = []
+    for desc, constructor_kwargs in cases:
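+        # BCELoss expects probabilities; inputs are kept in (1e-2, 1 - 1e-2) so the log terms stay finite.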
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input((15, 10), low=1e-2, high=1 - 1e-2),
+                                                    make_target((15, 10)).gt(0).to(dtype)),
+                        desc=desc,
+                        reference_fn=partial(bce_loss_reference_fn, **constructor_kwargs))
+        )
+
+    scalar_weight = make_weight(())
+    module_inputs.append(
+        ModuleInput(constructor_input=FunctionInput(weight=scalar_weight),
+                    forward_input=FunctionInput(make_input((), low=1e-2, high=1 - 1e-2),
+                                                make_target(()).gt(0).to(dtype)),
+                    desc='scalar_weight',
+                    reference_fn=partial(bce_loss_reference_fn, weight=scalar_weight))
+    )
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_BCEWithLogitsLoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
+    make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_mean', {'reduction': 'mean'}),
+        ('reduction_none', {'reduction': 'none'}),
+        ('weights', {'weight': make_weight((10,))}),
+        ('scalar_weights', {'weight': make_weight(())})
+    ]
+
+    def bce_withlogitsloss_reference_fn(m, p, i, t, reduction='mean', weight=None):
+        # TODO: add pos_weight to the definition here and corresponding SampleInputs
+        max_val = (-i).clamp(min=0)
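+        # Numerically stable form: loss = (1 - t) * i + log(1 + exp(-i)), where log(1 + exp(-i)) is expanded as
+        # max_val + log(exp(-max_val) + exp(-i - max_val)) with max_val = max(-i, 0).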
+        result = (1 - t).mul_(i).add_(max_val).add_((-max_val).exp_().add_((-i - max_val).exp_()).log_())
+
+        if weight is not None:
+            result = result * weight
+
+        if reduction == 'none':
+            return result
+        elif reduction == 'mean':
+            return result.sum() / i.numel()
+        else:
+            return result.sum()
+
+    module_inputs = []
+    for desc, constructor_kwargs in cases:
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input((15, 10), low=1e-2, high=1 - 1e-2),
+                                                    make_target((15, 10)).gt(0).to(dtype)),
+                        desc=desc,
+                        reference_fn=partial(bce_withlogitsloss_reference_fn, **constructor_kwargs))
+        )
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_CrossEntropyLoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_target = partial(make_tensor, device=device, dtype=torch.long, requires_grad=False)
+    make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
+
+    reductions: list[str] = ['mean', 'sum', 'none']
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('weights', {'weight': make_weight((3,))}),
+        ('ignore_index', {'ignore_index': 1}),
+        ('label_smoothing', {'label_smoothing': 0.15}),
+        ('ignore_index_label_smoothing', {'ignore_index': 1, 'label_smoothing': 0.15})
+    ]
+
+    module_inputs = []
+    for reduction, (desc, constructor_kwargs) in product(reductions, cases):
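+        # Bind the loop variables as default arguments so each reference_fn
+        # captures its own reduction and constructor_kwargs.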
+        def reference_fn(m, p, i, t, reduction=reduction, constructor_kwargs=constructor_kwargs):
+            return cross_entropy_loss_reference(i, t, reduction=reduction, **constructor_kwargs)
+
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs),
+                        forward_input=FunctionInput(make_input((2, 3, 5, 5)),
+                                                    make_target((2, 5, 5), low=0, high=3)),
+                        desc=f"4d_{desc}_{reduction}",
+                        reference_fn=reference_fn)
+        )
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs),
+                        forward_input=FunctionInput(make_input((2, 3, 5)),
+                                                    make_target((2, 5), low=0, high=3)),
+                        desc=f"3d_{desc}_{reduction}",
+                        reference_fn=reference_fn)
+        )
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs),
+                        forward_input=FunctionInput(make_input((2, 3)),
+                                                    make_target((2), low=0, high=3)),
+                        desc=f"2d_{desc}_{reduction}",
+                        reference_fn=reference_fn)
+        )
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs),
+                        forward_input=FunctionInput(make_input((2, 3, 5, 5, 2, 2)),
+                                                    make_target((2, 5, 5, 2, 2), low=0, high=3)),
+                        desc=f"higher_dim_{desc}_{reduction}",
+                        reference_fn=reference_fn)
+        )
+
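+        # Probability (soft) targets do not support ignore_index, so these samples
+        # are only generated when ignore_index is unset.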
+        if constructor_kwargs.get('ignore_index', None) is None:
+            module_inputs.append(
+                ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs),
+                            forward_input=FunctionInput(make_input((5, 3, 4, 2)),
+                                                        make_input((5, 3, 4, 2)).softmax(dim=1)),
+                            desc=f"4d_prob_target_{desc}_{reduction}",
+                            reference_fn=reference_fn)
+            )
+            module_inputs.append(
+                ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs),
+                            forward_input=FunctionInput(make_input((5, 3, 4)),
+                                                        make_input((5, 3, 4)).softmax(dim=1)),
+                            desc=f"3d_prob_target_{desc}_{reduction}",
+                            reference_fn=reference_fn)
+            )
+            module_inputs.append(
+                ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs),
+                            forward_input=FunctionInput(make_input((5, 3)),
+                                                        make_input((5, 3)).softmax(dim=1)),
+                            desc=f"2d_prob_target_{desc}_{reduction}",
+                            reference_fn=reference_fn)
+            )
+            module_inputs.append(
+                ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs),
+                            forward_input=FunctionInput(make_input((2, 3, 5, 5, 2, 2)),
+                                                        make_input((2, 3, 5, 5, 2, 2)).softmax(dim=1)),
+                            desc=f"higher_dim_prob_target_{desc}_{reduction}",
+                            reference_fn=reference_fn)
+            )
+            module_inputs.append(
+                ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs),
+                            forward_input=FunctionInput(make_input((3,)),
+                                                        make_target((), low=0, high=3)),
+                            desc=f"no_batch_dim_{desc}_{reduction}",
+                            reference_fn=partial(no_batch_dim_reference_fn, is_criterion=True))
+            )
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_CTCLoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_target = partial(make_tensor, device=device, requires_grad=False)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_mean', {'reduction': 'mean'}),
+        ('reduction_none', {'reduction': 'none'}),
+        ('blank', {'blank': 14})
+    ]
+    target_dtypes = [torch.int, torch.long]
+
+    module_inputs = []
+    for target_dtype, (desc, constructor_kwargs) in product(target_dtypes, cases):
+        def reference_fn(m, p, i, t, il, tl, constructor_kwargs=constructor_kwargs):
+            return ctcloss_reference(i, t, il, tl, **constructor_kwargs)
+
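+        # Targets must not contain the blank index: sample from [1, 15) when blank is
+        # the default 0, and from [0, 14) when blank=14.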
+        blank = constructor_kwargs.get('blank', 0)
+        low = 0 if blank == 14 else 1
+        high = 14 if blank == 14 else 15
+
+        module_inputs.append(
+            ModuleInput(
+                constructor_input=FunctionInput(**constructor_kwargs),
+                forward_input=FunctionInput(make_input((50, 3, 15)).log_softmax(2),
+                                            make_target((3, 30), dtype=target_dtype, low=low, high=high),
+                                            (50, 50, 50), (30, 25, 20)),
+                desc=f'{desc}_lengths_intlists',
+                reference_fn=reference_fn)
+        )
+        module_inputs.append(
+            ModuleInput(
+                constructor_input=FunctionInput(**constructor_kwargs),
+                forward_input=FunctionInput(make_input((50, 3, 15)).log_softmax(2),
+                                            make_target((3, 30), dtype=target_dtype, low=low, high=high),
+                                            torch.tensor((50, 50, 50), device=device),
+                                            torch.tensor((30, 25, 20), device=device)),
+                desc=f'{desc}_lengths_tensors',
+                reference_fn=reference_fn)
+        )
+        module_inputs.append(
+            ModuleInput(
+                constructor_input=FunctionInput(**constructor_kwargs),
+                forward_input=FunctionInput(make_input((50, 3, 15)).log_softmax(2),
+                                            make_target((30 + 25 + 20,), dtype=target_dtype, low=low, high=high),
+                                            (50, 50, 50), (30, 25, 20)),
+                desc=f'{desc}_1d_target_lengths_intlists',
+                reference_fn=reference_fn)
+        )
+        module_inputs.append(
+            ModuleInput(
+                constructor_input=FunctionInput(**constructor_kwargs),
+                forward_input=FunctionInput(make_input((50, 3, 15)).log_softmax(2),
+                                            make_target((30 + 25 + 20,), dtype=target_dtype, low=low, high=high),
+                                            torch.tensor((50, 50, 50), device=device),
+                                            torch.tensor((30, 25, 20), device=device)),
+                desc=f'{desc}_1d_target_lengths_tensors',
+                reference_fn=reference_fn)
+        )
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_GroupNorm(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(3, 6, 1e-3),
+            forward_input=FunctionInput(make_input((4, 6, 5))),
+            desc='1d_affine'),
+        ModuleInput(
+            constructor_input=FunctionInput(3, 12, 1e-3),
+            forward_input=FunctionInput(make_input((4, 12))),
+            desc='1d_affine_GN'),
+        ModuleInput(
+            constructor_input=FunctionInput(1, 6, 1e-3),
+            forward_input=FunctionInput(make_input((150, 6))),
+            desc='1d_affine_large_batch'),
+        ModuleInput(
+            constructor_input=FunctionInput(5, 5, 1e-3, False),
+            forward_input=FunctionInput(make_input((4, 5, 5))),
+            desc='1d_no_affine_IN'),
+        ModuleInput(
+            constructor_input=FunctionInput(1, 10, 1e-3, False),
+            forward_input=FunctionInput(make_input((4, 10))),
+            desc='1d_no_affine_LN'),
+        ModuleInput(
+            constructor_input=FunctionInput(3, 6, 1e-3),
+            forward_input=FunctionInput(make_input((4, 6, 2, 3))),
+            desc='2d_affine'),
+        ModuleInput(
+            constructor_input=FunctionInput(3, 3, 1e-3, False),
+            forward_input=FunctionInput(make_input((4, 3, 2, 3))),
+            desc='2d_no_affine_IN'),
+        ModuleInput(
+            constructor_input=FunctionInput(1, 3, 1e-3, False),
+            forward_input=FunctionInput(make_input((4, 3, 2, 3))),
+            desc='2d_no_affine_LN'),
+    ]
+
+
+def module_inputs_torch_nn_Hardshrink(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(2.),
+            forward_input=FunctionInput(make_input((4, 3, 2, 4))),
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(2.),
+            forward_input=FunctionInput(make_input(())),
+            desc='scalar',
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(),
+            forward_input=FunctionInput(make_input(4)),
+            reference_fn=no_batch_dim_reference_fn,
+            desc='no_batch_dim',
+        )
+    ]
+
+
+def module_inputs_torch_nn_Hardswish(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(),
+            forward_input=FunctionInput(make_input(4)),
+            reference_fn=no_batch_dim_reference_fn,
+            desc='no_batch_dim',
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(),
+            forward_input=FunctionInput(make_input((2, 3, 2, 5))),
+            desc='4d_input')
+    ]
+
+
+def module_inputs_torch_nn_Hardtanh(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(),
+            forward_input=FunctionInput(make_input((3, 2, 5))),
+            reference_fn=lambda m, p, i: i.clamp(-1, 1),
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(),
+            forward_input=FunctionInput(make_input(())),
+            reference_fn=lambda m, p, i: i.clamp(-1, 1),
+            desc='scalar',
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(),
+            forward_input=FunctionInput(make_input(4)),
+            reference_fn=no_batch_dim_reference_fn,
+            desc='no_batch_dim',
+        )
+    ]
+
+
+def module_inputs_torch_nn_HingeEmbeddingLoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_mean', {'reduction': 'mean'}),
+        ('reduction_none', {'reduction': 'none'}),
+        ('margin', {'margin': 0.5})
+    ]
+
+    module_inputs = []
+    for desc, constructor_kwargs in cases:
+        def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs):
+            return hingeembeddingloss_reference(i, t, **constructor_kwargs)
+
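+        # Targets take values in {-1, 1}: random values are thresholded at 0 and mapped 0/1 -> -1/1.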
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input((10,)),
+                                                    make_target((10,)).gt(0).to(dtype).mul_(2).sub_(1)),
+                        desc=desc,
+                        reference_fn=reference_fn)
+        )
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input(()),
+                                                    make_target(()).gt(0).to(dtype).mul_(2).sub_(1)),
+                        desc=f'scalar_{desc}',
+                        reference_fn=reference_fn)
+        )
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_HuberLoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_mean', {'reduction': 'mean'}),
+        ('reduction_none', {'reduction': 'none'}),
+    ]
+
+    module_inputs = []
+    for desc, constructor_kwargs in cases:
+        def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs):
+            return huberloss_reference(i, t, **constructor_kwargs)
+
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input((5, 10)),
+                                                    make_input((5, 10))),
+                        desc=desc,
+                        reference_fn=reference_fn)
+        )
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_InstanceNormNd(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    lazy = kwargs.get('lazy', False)
+    N = kwargs['N']
+    num_features, eps, momentum, affine, track_running_stats = 3, 1e-3, 0.3, False, True
+    input_no_batch_shape_dict = {1: (3, 15), 2: (3, 6, 6), 3: (3, 4, 4, 4)}
+    input_no_batch_shape = input_no_batch_shape_dict[N]
+    input_batch_shape = (4,) + input_no_batch_shape
+
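+    # Lazy variants infer num_features from the first input, so their constructors omit it.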
+    return [
+        ModuleInput(
+            constructor_input=(
+                FunctionInput(eps, momentum) if lazy else FunctionInput(num_features, eps, momentum)
+            ),
+            forward_input=FunctionInput(make_input(input_batch_shape))),
+        ModuleInput(
+            constructor_input=(
+                FunctionInput(eps, momentum, affine, track_running_stats) if lazy else
+                FunctionInput(num_features, eps, momentum, affine, track_running_stats)
+            ),
+            forward_input=FunctionInput(make_input(input_batch_shape)),
+            desc='tracking_stats'),
+        ModuleInput(
+            constructor_input=(
+                FunctionInput(eps, momentum) if lazy else FunctionInput(num_features, eps, momentum)
+            ),
+            forward_input=FunctionInput(make_input(input_no_batch_shape)),
+            reference_fn=no_batch_dim_reference_fn,
+            desc='no_batch_dim'),
+        ModuleInput(
+            constructor_input=(
+                FunctionInput(eps, momentum, affine, track_running_stats) if lazy else
+                FunctionInput(num_features, eps, momentum, affine, track_running_stats)
+            ),
+            forward_input=FunctionInput(make_input(input_no_batch_shape)),
+            reference_fn=no_batch_dim_reference_fn,
+            desc='tracking_stats_no_batch_dim')
+    ]
+
+
+def module_inputs_torch_nn_LayerNorm(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput([5], 1e-3),
+            forward_input=FunctionInput(make_input((4, 5, 5))),
+            desc='1d_elementwise_affine'),
+        ModuleInput(
+            constructor_input=FunctionInput([5], 1e-3),
+            forward_input=FunctionInput(make_input((128, 5, 5))),
+            desc='1d_elementwise_affine_large_batch'),
+        ModuleInput(
+            constructor_input=FunctionInput([5], 1e-3, False),
+            forward_input=FunctionInput(make_input((4, 5, 5))),
+            desc='1d_no_elementwise_affine'),
+        ModuleInput(
+            constructor_input=FunctionInput([2, 2, 5], 1e-3),
+            forward_input=FunctionInput(make_input((4, 2, 2, 5))),
+            desc='3d_elementwise_affine'),
+        ModuleInput(
+            constructor_input=FunctionInput([2, 2, 5], 1e-3, False),
+            forward_input=FunctionInput(make_input((4, 2, 2, 5))),
+            desc='3d_no_elementwise_affine'),
+        ModuleInput(
+            constructor_input=FunctionInput([5], 1e-3),
+            forward_input=FunctionInput(make_input((0, 5))),
+            desc='1d_empty_elementwise_affine'),
+        ModuleInput(
+            constructor_input=FunctionInput([2, 2, 5], 1e-3, elementwise_affine=True, bias=False),
+            forward_input=FunctionInput(make_input((4, 2, 2, 5))),
+            desc='3d_elementwise_affine_no_bias'),
+    ]
+
+
+def module_inputs_torch_nn_RMSNorm(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
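+    # Reference RMSNorm: x * rsqrt(mean(x^2) + eps) over the normalized dims,
+    # computed in float32, then scaled by the optional elementwise weight.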
+    def rms_norm_reference_fn(m, p, i):
+        eps = m.eps
+        if eps is None:
+            eps = torch.finfo(i.dtype).eps
+        ndim = i.ndim
+        normalized_shape = m.normalized_shape
+        weight = m.weight
+        dims = [ndim - i - 1 for i in range(len(normalized_shape))]
+        upcasted_i = i.float()
+        result = upcasted_i * torch.rsqrt(upcasted_i.pow(2).mean(dim=dims, keepdim=True) + eps)
+        if weight is not None:
+            result *= weight
+        return result.type_as(i)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput([5], 1e-3),
+            forward_input=FunctionInput(make_input((4, 5, 5))),
+            desc='1d_elementwise_affine',
+            reference_fn=rms_norm_reference_fn),
+        ModuleInput(
+            constructor_input=FunctionInput([5], 1e-3),
+            forward_input=FunctionInput(make_input((128, 5, 5))),
+            desc='1d_elementwise_affine_large_batch',
+            reference_fn=rms_norm_reference_fn),
+        ModuleInput(
+            constructor_input=FunctionInput([5], 1e-3, False),
+            forward_input=FunctionInput(make_input((4, 5, 5))),
+            desc='1d_no_elementwise_affine',
+            reference_fn=rms_norm_reference_fn),
+        ModuleInput(
+            constructor_input=FunctionInput([2, 2, 5], 1e-3),
+            forward_input=FunctionInput(make_input((4, 2, 2, 5))),
+            desc='3d_elementwise_affine',
+            reference_fn=rms_norm_reference_fn),
+        ModuleInput(
+            constructor_input=FunctionInput([2, 2, 5], 1e-3, False),
+            forward_input=FunctionInput(make_input((4, 2, 2, 5))),
+            desc='3d_no_elementwise_affine',
+            reference_fn=rms_norm_reference_fn),
+        ModuleInput(
+            constructor_input=FunctionInput([5], 1e-3),
+            forward_input=FunctionInput(make_input((0, 5))),
+            desc='1d_empty_elementwise_affine',
+            reference_fn=rms_norm_reference_fn),
+    ]
+
+
+def module_inputs_torch_nn_LocalResponseNorm(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(3,),
+            forward_input=FunctionInput(make_input((1, 5, 7))),
+            desc='1d'),
+        ModuleInput(
+            constructor_input=FunctionInput(2,),
+            forward_input=FunctionInput(make_input((1, 5, 7, 7))),
+            desc='2d_uneven_pad'),
+        ModuleInput(
+            constructor_input=FunctionInput(1, 1., 0.5, 2.),
+            forward_input=FunctionInput(make_input((1, 5, 7, 7, 7))),
+            desc='3d_custom_params'),
+    ]
+
+
+def module_inputs_torch_nn_LPPool1d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1.5, 2),
+            forward_input=FunctionInput(make_input((1, 3, 7))),
+            desc='norm'),
+        ModuleInput(
+            constructor_input=FunctionInput(2, 2, 3),
+            forward_input=FunctionInput(make_input((1, 3, 7)))),
+        ModuleInput(
+            constructor_input=FunctionInput(2, 2, 3),
+            forward_input=FunctionInput(make_input((3, 7))),
+            reference_fn=no_batch_dim_reference_fn,
+            desc='no_batch_dim'),
+    ]
+
+
+def module_inputs_torch_nn_LPPool2d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(2, 2, 2),
+            forward_input=FunctionInput(make_input((1, 3, 7, 7)))),
+        ModuleInput(
+            constructor_input=FunctionInput(2, 2, 2),
+            forward_input=FunctionInput(make_input((3, 7, 7))),
+            reference_fn=no_batch_dim_reference_fn,
+            desc='no_batch_dim'),
+        ModuleInput(
+            constructor_input=FunctionInput(1.5, 2),
+            forward_input=FunctionInput(make_input((1, 3, 7, 7))),
+            desc='norm'),
+    ]
+
+
+def module_inputs_torch_nn_LPPool3d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(2, 2, 2),
+            forward_input=FunctionInput(make_input((1, 3, 7, 7, 7)))),
+        ModuleInput(
+            constructor_input=FunctionInput(2, 2, 2),
+            forward_input=FunctionInput(make_input((3, 7, 7, 7))),
+            reference_fn=no_batch_dim_reference_fn,
+            desc='no_batch_dim'),
+        ModuleInput(
+            constructor_input=FunctionInput(1.5, 2),
+            forward_input=FunctionInput(make_input((1, 3, 7, 7, 7))),
+            desc='norm'),
+    ]
+
+
+def module_inputs_torch_nn_MaxPool1d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(4),
+            forward_input=FunctionInput(make_input((2, 10, 4))),
+            desc='3d_input'),
+        ModuleInput(
+            constructor_input=FunctionInput(4, 4),
+            forward_input=FunctionInput(make_input((2, 10, 4))),
+            desc='stride'),
+        ModuleInput(
+            constructor_input=FunctionInput(4, return_indices=True),
+            forward_input=FunctionInput(make_input((2, 10, 4))),
+            desc='return_indices'),
+    ]
+
+
+def module_inputs_torch_nn_MaxPool2d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput((3, 3), (2, 2), (1, 1)),
+            forward_input=FunctionInput(make_input((3, 7, 7))),
+            desc='3d_input'),
+        ModuleInput(
+            constructor_input=FunctionInput((3, 3), (2, 2), (1, 1)),
+            forward_input=FunctionInput(make_input((1, 3, 7, 7))),
+            desc='4d_input'),
+        ModuleInput(
+            constructor_input=FunctionInput((3, 3), (2, 2), (1, 1), return_indices=True),
+            forward_input=FunctionInput(make_input((1, 3, 7, 7))),
+            desc='return_indices'),
+    ]
+
+
+def module_inputs_torch_nn_MaxPool3d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput((2, 2, 2)),
+            forward_input=FunctionInput(make_input((2, 3, 5, 5, 5)))),
+        ModuleInput(
+            constructor_input=FunctionInput(2, (2, 2, 2)),
+            forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))),
+            desc='stride'),
+        ModuleInput(
+            constructor_input=FunctionInput(2, 2, (1, 1, 1)),
+            forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))),
+            desc='stride_padding'),
+        ModuleInput(
+            constructor_input=FunctionInput(2, 2, (1, 1, 1), return_indices=True),
+            forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))),
+            desc='return_indices'),
+    ]
+
+
+def module_inputs_torch_nn_FractionalMaxPool2d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    def make_random_samples():
+        return torch.empty((1, 3, 2), dtype=torch.double, device=device).uniform_()
+
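+    # Passing _random_samples supplies the pooling offsets explicitly rather than
+    # letting the module draw them internally.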
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(2, output_ratio=0.5, _random_samples=make_random_samples()),
+            forward_input=FunctionInput(make_input((1, 3, 5, 7))),
+            desc='ratio'),
+        ModuleInput(
+            constructor_input=FunctionInput((2, 3), output_size=(4, 3), _random_samples=make_random_samples()),
+            forward_input=FunctionInput(make_input((1, 3, 7, 6))),
+            desc='size'),
+        ModuleInput(
+            constructor_input=FunctionInput(
+                2, output_ratio=0.5, _random_samples=make_random_samples(), return_indices=True
+            ),
+            forward_input=FunctionInput(make_input((1, 3, 5, 7))),
+            desc='ratio_return_indices'),
+        ModuleInput(
+            constructor_input=FunctionInput(2, output_ratio=0.5, _random_samples=make_random_samples()),
+            forward_input=FunctionInput(make_input((3, 5, 7))),
+            reference_fn=no_batch_dim_reference_fn,
+            desc='ratio_no_batch_dim'),
+        ModuleInput(
+            constructor_input=FunctionInput((2, 3), output_size=(4, 3), _random_samples=make_random_samples()),
+            forward_input=FunctionInput(make_input((3, 7, 6))),
+            reference_fn=no_batch_dim_reference_fn,
+            desc='size_no_batch_dim'),
+    ]
+
+
+def module_inputs_torch_nn_FractionalMaxPool3d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    def make_random_samples():
+        return torch.empty((2, 4, 3), dtype=torch.double, device=device).uniform_()
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(2, output_ratio=0.5, _random_samples=make_random_samples()),
+            forward_input=FunctionInput(make_input((2, 4, 5, 5, 5))),
+            desc='ratio'),
+        ModuleInput(
+            constructor_input=FunctionInput((2, 2, 2), output_size=(4, 4, 4), _random_samples=make_random_samples()),
+            forward_input=FunctionInput(make_input((2, 4, 7, 7, 7))),
+            desc='size'),
+        ModuleInput(
+            constructor_input=FunctionInput((4, 2, 3), output_size=(10, 3, 2), _random_samples=make_random_samples()),
+            forward_input=FunctionInput(make_input((2, 4, 16, 7, 5))),
+            desc='asymsize'),
+        ModuleInput(
+            constructor_input=FunctionInput(
+                2, output_ratio=0.5, _random_samples=make_random_samples(), return_indices=True
+            ),
+            forward_input=FunctionInput(make_input((2, 4, 5, 5, 5))),
+            desc='ratio_return_indices'),
+        ModuleInput(
+            constructor_input=FunctionInput(2, output_ratio=0.5, _random_samples=make_random_samples()),
+            forward_input=FunctionInput(make_input((4, 5, 5, 5))),
+            reference_fn=no_batch_dim_reference_fn,
+            desc='ratio_no_batch_dim'),
+        ModuleInput(
+            constructor_input=FunctionInput((2, 2, 2), output_size=(4, 4, 4), _random_samples=make_random_samples()),
+            forward_input=FunctionInput(make_input((4, 7, 7, 7))),
+            reference_fn=no_batch_dim_reference_fn,
+            desc='size_no_batch_dim'),
+    ]
+
+
+def module_inputs_torch_nn_Sigmoid(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(),
+            forward_input=FunctionInput(make_input(())),
+            desc='scalar'
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(),
+            forward_input=FunctionInput(make_input(4)),
+            reference_fn=no_batch_dim_reference_fn,
+            desc='no_batch_dim',
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(),
+            forward_input=FunctionInput(make_input((2, 3, 4, 5))),
+            desc='channels_last_mem_format'
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(),
+            forward_input=FunctionInput(make_input((2, 3, 3, 4, 5))),
+            desc='channels_last_3d_mem_format'
+        )
+    ]
+
+
+def module_inputs_torch_nn_LogSigmoid(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(),
+            forward_input=FunctionInput(make_input(())),
+            reference_fn=lambda m, p, i: i.sigmoid().log(),
+            desc='scalar'
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(),
+            forward_input=FunctionInput(make_input((2, 3, 4))),
+            reference_fn=lambda m, p, i: i.sigmoid().log(),
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(),
+            forward_input=FunctionInput(make_input(4)),
+            reference_fn=no_batch_dim_reference_fn,
+            desc='no_batch_dim',
+        ),
+    ]
+
+
+def module_inputs_torch_nn_MarginRankingLoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_target = partial(make_tensor, device=device, dtype=torch.long, requires_grad=False)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_mean', {'reduction': 'mean'}),
+        ('reduction_none', {'reduction': 'none'}),
+        ('margin', {'margin': 0.5})
+    ]
+
+    module_inputs = []
+    for desc, constructor_kwargs in cases:
+        def reference_fn(m, p, i1, i2, t, constructor_kwargs=constructor_kwargs):
+            return marginrankingloss_reference(i1, i2, t, **constructor_kwargs)
+
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input((50,)), make_input((50,)),
+                                                    make_target((50,)).sign()),
+                        desc=desc,
+                        reference_fn=reference_fn)
+        )
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_MultiLabelMarginLoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_target = partial(make_tensor, device=device, dtype=torch.long, requires_grad=False)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_mean', {'reduction': 'mean'}),
+        ('reduction_none', {'reduction': 'none'}),
+    ]
+
+    module_inputs = []
+    for desc, constructor_kwargs in cases:
+        def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs):
+            return multilabelmarginloss_reference(i, t, **constructor_kwargs)
+
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input((10,)),
+                                                    make_target((10), low=0, high=10)),
+                        desc=f'1d_{desc}',
+                        reference_fn=reference_fn)
+        )
+
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input((5, 10)),
+                                                    make_target((5, 10), low=0, high=10)),
+                        desc=desc,
+                        reference_fn=reference_fn)
+        )
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_MultiMarginLoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_target = partial(make_tensor, device=device, dtype=torch.long, requires_grad=False)
+    make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_mean', {'reduction': 'mean'}),
+        ('reduction_none', {'reduction': 'none'}),
+        ('p', {'p': 2}),
+        ('margin', {'margin': 0.5}),
+        ('weights', {'weight': make_weight(10)})
+    ]
+
+    module_inputs = []
+    for desc, constructor_kwargs in cases:
+        def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs):
+            return multimarginloss_reference(i, t, **constructor_kwargs)
+
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input((5, 10)),
+                                                    make_target((5), low=0, high=10)),
+                        desc=desc,
+                        reference_fn=reference_fn)
+        )
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_MultiLabelSoftMarginLoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_target = partial(make_tensor, device=device, dtype=torch.long, requires_grad=False)
+    make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_mean', {'reduction': 'mean'}),
+        ('reduction_none', {'reduction': 'none'}),
+        ('weight', {'weight': make_weight(10)}),
+    ]
+
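+    # Reference: per-class BCE on logits via log-sigmoid, averaged over the class
+    # dimension, with optional per-class weights and the requested reduction.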
+    def multilabelsoftmargin_loss_reference_fn(m, p, i, t, reduction='mean', weight=None):
+        result = t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()
+        if weight is not None:
+            result *= weight
+        result = (-result).sum(i.dim() - 1) / i.size(-1)
+
+        if reduction == 'none':
+            return result
+        elif reduction == 'mean':
+            return result.mean()
+        else:
+            return result.sum()
+
+    module_inputs = []
+    for desc, constructor_kwargs in cases:
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input((5, 10)),
+                                                    make_target((5, 10), low=0, high=2)),
+                        desc=desc,
+                        reference_fn=partial(multilabelsoftmargin_loss_reference_fn, **constructor_kwargs))
+        )
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_SoftMarginLoss(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
+
+    cases: list[tuple[str, dict]] = [
+        ('', {}),
+        ('reduction_sum', {'reduction': 'sum'}),
+        ('reduction_mean', {'reduction': 'mean'}),
+        ('reduction_none', {'reduction': 'none'}),
+    ]
+
+    module_inputs = []
+    for desc, constructor_kwargs in cases:
+        def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs):
+            return softmarginloss_reference(i, t, **constructor_kwargs)
+
+        module_inputs.append(
+            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
+                        forward_input=FunctionInput(make_input((5, 5)),
+                                                    make_target((5, 5)).sign()),
+                        desc=desc,
+                        reference_fn=reference_fn)
+        )
+
+    return module_inputs
+
+
+def module_inputs_torch_nn_TransformerEncoder(module_info, device, dtype, requires_grad, training, **kwargs):
+    # Reuse the TransformerEncoderLayer samples since the forward args are nearly the same.
+    samples = []
+    for layer_module_input in module_inputs_torch_nn_TransformerEncoderLayer(
+            None, device, dtype, requires_grad, training):
+        # Construct a TransformerEncoderLayer object to pass to TransformerEncoder.
+        l_args, l_kwargs = (layer_module_input.constructor_input.args,
+                            layer_module_input.constructor_input.kwargs)
+        l_kwargs['device'] = device
+        l_kwargs['dtype'] = dtype
+        encoder_layer = torch.nn.TransformerEncoderLayer(*l_args, **l_kwargs)
+        num_layers = 2
+        # Note: TransformerEncoderLayer takes a "src_mask" while
+        # TransformerEncoder takes a "mask"; rename kwarg appropriately.
+        forward_input = layer_module_input.forward_input
+        if 'src_mask' in forward_input.kwargs:
+            forward_input.kwargs['mask'] = forward_input.kwargs['src_mask']
+            del forward_input.kwargs['src_mask']
+        samples.append(ModuleInput(
+            constructor_input=FunctionInput(encoder_layer, num_layers),
+            forward_input=forward_input,
+            desc=layer_module_input.desc
+        ))
+    return samples
+
+
+def module_inputs_torch_nn_TransformerEncoderLayer(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    samples = [
+        ModuleInput(
+            constructor_input=FunctionInput(4, 2, 16, 0.0),
+            forward_input=FunctionInput(
+                make_input((2, 3, 4))
+            ),
+            desc='relu_activation'
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(4, 2, 8, 0.0, F.gelu),
+            forward_input=FunctionInput(
+                make_input((2, 3, 4))
+            ),
+            desc='gelu_activation'
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(4, 2, 8, 0.0, bias=False),
+            forward_input=FunctionInput(
+                make_input((2, 3, 4))
+            ),
+            desc='no_bias'
+        ),
+    ]
+
+    # Samples below are for validating the no-batch-dim support.
+    key_padding_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool))
+    attn_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool).expand((3, 3)))
+    for src_mask, src_key_padding_mask, norm_first, batch_first, bias in \
+            itertools.product(attn_masks, key_padding_masks, (True, False), (True, False), (True, False)):
+        samples.append(
+            ModuleInput(
+                constructor_input=FunctionInput(d_model=4, nhead=2, dim_feedforward=8,
+                                                dropout=0.0, batch_first=batch_first,
+                                                norm_first=norm_first, bias=bias),
+                forward_input=FunctionInput(
+                    make_input((3, 4)), src_mask=src_mask, src_key_padding_mask=src_key_padding_mask
+                ),
+                reference_fn=partial(no_batch_dim_reference_fn,
+                                     batch_first=batch_first, kwargs_to_batchify={'src_key_padding_mask': 0}),
+                desc=f'no_batch_dim_batch_first_{batch_first}'
+            ))
+
+    # Samples below that pass a reference_fn are for validating the fast path.
+    # Since the fast path requires no_grad mode, the reference_fn runs the module
+    # in .eval() under torch.no_grad() and that output is compared against the
+    # results obtained in train mode.
+    def fast_path_reference_fn(module, parameters, *args, **kwargs):
+        assert module.training
+        module.train(False)
+        with torch.no_grad():
+            output = module(*args, **kwargs)
+        module.train(True)
+        return output
+
+    if training:
+        for norm_first, bias in itertools.product((True, False), (True, False)):
+            samples.append(
+                ModuleInput(
+                    constructor_input=FunctionInput(
+                        4, 2, 8, dropout=0.0, batch_first=True, norm_first=norm_first, bias=bias
+                    ),
+                    forward_input=FunctionInput(
+                        make_input((2, 3, 4)),
+                    ),
+                    # fastpath doesn't run when bias=False
+                    reference_fn=fast_path_reference_fn if bias else None,
+                    desc=f'fastpath_{bias}_norm_first_{norm_first}'
+                )
+            )
+
+    return samples
+
+
+def module_inputs_torch_nn_TransformerDecoderLayer(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    samples = [
+        ModuleInput(
+            constructor_input=FunctionInput(4, 2, 16, 0.0),
+            forward_input=FunctionInput(
+                make_input((2, 3, 4)), make_input((2, 3, 4))
+            ),
+            desc='relu_activation'
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(4, 2, 8, 0.0, F.gelu),
+            forward_input=FunctionInput(
+                make_input((2, 3, 4)), make_input((2, 3, 4))
+            ),
+            desc='gelu_activation'
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(4, 2, 8, 0.0, bias=False),
+            forward_input=FunctionInput(
+                make_input((2, 3, 4)), make_input((2, 3, 4))
+            ),
+            desc='no_bias'
+        ),
+    ]
+
+    key_padding_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool))
+    attn_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool).expand((3, 3)))
+    for tgt_mask, tgt_key_padding_mask, norm_first, bias, batch_first in \
+            itertools.product(attn_masks, key_padding_masks, (True, False), (True, False), (True, False)):
+        # Using same mask for tgt and memory
+        memory_mask = tgt_mask
+        memory_key_padding_mask = tgt_key_padding_mask
+        samples.append(
+            ModuleInput(
+                constructor_input=FunctionInput(d_model=4, nhead=2, dim_feedforward=8,
+                                                dropout=0.0, batch_first=batch_first,
+                                                norm_first=norm_first, bias=bias),
+                forward_input=FunctionInput(
+                    make_input((3, 4)), make_input((3, 4)), tgt_mask=tgt_mask, memory_mask=memory_mask,
+                    tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask
+                ),
+                reference_fn=partial(no_batch_dim_reference_fn,
+                                     batch_first=batch_first,
+                                     kwargs_to_batchify={'tgt_key_padding_mask': 0, 'memory_key_padding_mask': 0}),
+                desc=f'no_batch_dim_batch_first_{batch_first}'
+            ))
+        src, tgt = make_input((2, 3, 4)), make_input((2, 3, 4))
+        if not batch_first:
+            src, tgt = src.transpose(0, 1), tgt.transpose(0, 1)
+        if tgt_key_padding_mask is not None:
+            memory_key_padding_mask, tgt_key_padding_mask = (tgt_key_padding_mask.expand(2, 3),) * 2
+        samples.append(
+            ModuleInput(
+                constructor_input=FunctionInput(d_model=4, nhead=2, dim_feedforward=8,
+                                                dropout=0.0, batch_first=batch_first,
+                                                norm_first=norm_first, bias=bias),
+                forward_input=FunctionInput(
+                    src, tgt, tgt_mask=tgt_mask, memory_mask=memory_mask,
+                    tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask
+                ),
+                desc=f'norm_first_{norm_first}_batch_first_{batch_first}_bias_{bias}'
+            ))
+
+    return samples
+
+
+def module_inputs_torch_nn_Transformer(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    samples = []
+    # Samples below are for validating the no-batch-dim support.
+    key_padding_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool))
+    attn_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool).expand((3, 3)))
+    for mask, key_padding_mask, norm_first, bias, batch_first in \
+            itertools.product(attn_masks, key_padding_masks, (True, False), (True, False), (True, False)):
+        # Using the same mask for src and tgt
+        src_mask, tgt_mask = (mask,) * 2
+        src_key_padding_mask, tgt_key_padding_mask = (key_padding_mask,) * 2
+        samples.append(
+            ModuleInput(
+                constructor_input=FunctionInput(d_model=4, nhead=2, dim_feedforward=8,
+                                                num_encoder_layers=1, num_decoder_layers=1,
+                                                dropout=0.0, batch_first=batch_first, norm_first=norm_first, bias=bias),
+                forward_input=FunctionInput(
+                    make_input((3, 4)), make_input((3, 4)), tgt_mask=tgt_mask, src_mask=src_mask,
+                    tgt_key_padding_mask=tgt_key_padding_mask, src_key_padding_mask=src_key_padding_mask
+                ),
+                reference_fn=partial(no_batch_dim_reference_fn,
+                                     batch_first=batch_first,
+                                     kwargs_to_batchify={'tgt_key_padding_mask': 0, 'src_key_padding_mask': 0}),
+                desc=f'no_batch_dim_batch_first_{batch_first}'
+            ))
+
+        src, tgt = make_input((2, 3, 4)), make_input((2, 3, 4))
+        if not batch_first:
+            src = src.transpose(0, 1)
+            tgt = tgt.transpose(0, 1)
+        if key_padding_mask is not None:
+            src_key_padding_mask, tgt_key_padding_mask = (key_padding_mask.expand(2, 3),) * 2
+
+        samples.append(
+            ModuleInput(
+                constructor_input=FunctionInput(d_model=4, nhead=2, dim_feedforward=8,
+                                                num_encoder_layers=1, num_decoder_layers=1,
+                                                dropout=0.0, batch_first=batch_first, norm_first=norm_first, bias=bias),
+                forward_input=FunctionInput(
+                    src, tgt, tgt_mask=tgt_mask, src_mask=src_mask,
+                    tgt_key_padding_mask=tgt_key_padding_mask, src_key_padding_mask=src_key_padding_mask
+                ),
+            ))
+    return samples
+
+
+def module_inputs_torch_nn_Embedding(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_empty = partial(torch.empty, device=device, dtype=torch.long, requires_grad=False)
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(num_embeddings=4, embedding_dim=3),
+            forward_input=FunctionInput(make_empty(2, 3).random_(4))
+        ),
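+        # expand() produces a non-contiguous view, exercising the discontiguous-input path.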
+        ModuleInput(
+            constructor_input=FunctionInput(num_embeddings=4, embedding_dim=3),
+            forward_input=FunctionInput(make_empty(1, 512).random_(4).expand(7, 512)),
+            desc='discontiguous'
+        ),
+    ]
+
+
+def module_inputs_torch_nn_MultiheadAttention(module_info, device, dtype, requires_grad, training, **kwargs):
+    # Currently all samples below are for validating the no-batch-dim support.
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    samples = []
+    bool_vals = (True, False)
+    key_padding_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool))
+    attn_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool).expand((3, 3, 3)))
+    products = itertools.product(bool_vals, bool_vals, bool_vals, key_padding_masks, attn_masks)
+    for bias, add_bias_kv, add_zero_attn, key_padding_mask, attn_mask in products:
+        samples.append(
+            ModuleInput(
+                constructor_input=FunctionInput(embed_dim=3, num_heads=3, batch_first=True,
+                                                bias=bias, add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn),
+                forward_input=FunctionInput(make_input((3, 3)), make_input((3, 3)), make_input((3, 3)),
+                                            key_padding_mask=key_padding_mask, attn_mask=attn_mask),
+                reference_fn=no_batch_dim_reference_mha,
+            )
+        )
+        samples.append(
+            ModuleInput(
+                constructor_input=FunctionInput(embed_dim=3, num_heads=3, batch_first=False,
+                                                bias=bias, add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn),
+                forward_input=FunctionInput(make_input((3, 3)), make_input((3, 3)), make_input((3, 3)),
+                                            key_padding_mask=key_padding_mask, attn_mask=attn_mask),
+                reference_fn=partial(no_batch_dim_reference_mha, batch_first=False),
+            )
+        )
+
+    return samples
+
+
+def module_inputs_torch_nn_RNN_GRU_Cell(module_info, device, dtype, requires_grad, training, **kwargs):
+    # Currently all samples below are for validating the no-batch-dim support.
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    samples = [
+        ModuleInput(
+            constructor_input=FunctionInput(5, 10),
+            forward_input=FunctionInput(make_input(5), make_input(10)),
+            reference_fn=no_batch_dim_reference_fn,
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(5, 10, bias=True),
+            forward_input=FunctionInput(make_input(5), make_input(10)),
+            reference_fn=no_batch_dim_reference_fn,
+        )
+    ]
+
+    is_rnn = kwargs.get('is_rnn', False)
+    if is_rnn:
+        # RNN also supports `nonlinearity` argument.
+        # `tanh` is the default, so we check with `relu`
+        samples.append(
+            ModuleInput(
+                constructor_input=FunctionInput(5, 10, bias=True, nonlinearity='relu'),
+                forward_input=FunctionInput(make_input(5), make_input(10)),
+                reference_fn=no_batch_dim_reference_fn,
+            )
+        )
+
+    return samples
+
+
+def module_inputs_torch_nn_LSTMCell(module_info, device, dtype, requires_grad, training, **kwargs):
+    # Currently all samples below are for validating the no-batch-dim support.
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    samples = (
+        ModuleInput(
+            constructor_input=FunctionInput(5, 10),
+            forward_input=FunctionInput(make_input(5), (make_input(10), make_input(10))),
+            reference_fn=no_batch_dim_reference_lstmcell,
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(5, 10, bias=True),
+            forward_input=FunctionInput(make_input(5), (make_input(10), make_input(10))),
+            reference_fn=no_batch_dim_reference_lstmcell,
+        ),
+    )
+
+    return samples
+
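+
+# Helper for the RNN/GRU packed-sequence samples below: packs a padded batch and
+# moves the requires_grad flag onto the packed .data tensor.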
+def make_packed_sequence(inp, batch_sizes):
+    required_grad = inp.requires_grad
+    inp.requires_grad_(False)  # user won't have access to inp so won't be able to get its grads
+    seq = pack_padded_sequence(inp, batch_sizes)
+    seq.data.requires_grad_(required_grad)
+    return seq
+
+
+def module_inputs_torch_nn_RNN_GRU(module_info, device, dtype, requires_grad, training, with_packed_sequence=False, **kwargs):
+    # Currently all samples below are for validating the no-batch-dim support.
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    is_rnn = kwargs['is_rnn']
+    nonlinearity = ('relu', 'tanh')
+    bias = (False, True)
+    batch_first = (False, True)
+    bidirectional = (False, True)
+
+    samples = []
+    if is_rnn:
+        prod_gen = product(nonlinearity, bias, batch_first, bidirectional)
+    else:
+        prod_gen = product(bias, batch_first, bidirectional)
+
+    for args in prod_gen:
+        if is_rnn:
+            nl, b, b_f, bidir = args
+        else:
+            b, b_f, bidir = args
+
+        cons_args = {'input_size': 2, 'hidden_size': 2, 'num_layers': 2,
+                     'batch_first': b_f, 'bias': b, 'bidirectional': bidir}
+        cons_args_hidden = {'input_size': 2, 'hidden_size': 3, 'num_layers': 2,
+                            'batch_first': b_f, 'bias': b, 'bidirectional': bidir}
+
+        if is_rnn:
+            cons_args['nonlinearity'] = nl
+            cons_args_hidden['nonlinearity'] = nl
+        samples.append(
+            ModuleInput(
+                constructor_input=FunctionInput(**cons_args),
+                forward_input=FunctionInput(make_input((3, 2))),
+                reference_fn=partial(no_batch_dim_reference_rnn_gru, batch_first=b_f),
+            )
+        )
+        samples.append(
+            ModuleInput(
+                constructor_input=FunctionInput(**cons_args_hidden),
+                forward_input=FunctionInput(make_input((3, 2)), make_input((4 if bidir else 2, 3))),
+                reference_fn=partial(no_batch_dim_reference_rnn_gru, batch_first=b_f),
+            )
+        )
+        if with_packed_sequence:
+            samples.append(
+                ModuleInput(
+                    constructor_input=FunctionInput(**cons_args),
+                    forward_input=FunctionInput(make_packed_sequence(make_input((5, 2, 2)), torch.tensor([5, 3]))),
+                    reference_fn=partial(no_batch_dim_reference_rnn_gru, batch_first=b_f),
+                )
+            )
+            samples.append(
+                ModuleInput(
+                    constructor_input=FunctionInput(**cons_args),
+                    forward_input=FunctionInput(make_packed_sequence(make_input((5, 5, 2)), torch.tensor([5, 3, 3, 2, 2]))),
+                    reference_fn=partial(no_batch_dim_reference_rnn_gru, batch_first=b_f),
+                )
+            )
+
+    return samples
+
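+
+# Illustrative sketch (hypothetical): the unbatched, bidirectional GRU call exercised above --
+# input is (seq_len, input_size); hx is (num_layers * num_directions, hidden_size).
+def _example_unbatched_gru():
+    gru = torch.nn.GRU(input_size=2, hidden_size=3, num_layers=2, bidirectional=True)
+    out, h = gru(torch.randn(3, 2), torch.randn(4, 3))
+    return out.shape, h.shape  # (3, 6) and (4, 3)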
+
+def module_inputs_torch_nn_LSTM(module_info, device, dtype, requires_grad, training, **kwargs):
+    # Currently all samples below are for validating the no-batch-dim support.
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    bias = (False, True)
+    batch_first = (False, True)
+    bidirectional = (False, True)
+    proj_sizes = (0, 2)
+
+    samples = []
+    prod_gen = product(bias, batch_first, bidirectional, proj_sizes)
+
+    for args in prod_gen:
+        b, b_f, bidir, proj_size = args
+        hidden_size = 3
+        cons_args = {'input_size': 2, 'hidden_size': hidden_size, 'num_layers': 2, 'proj_size': proj_size,
+                     'batch_first': b_f, 'bias': b, 'bidirectional': bidir}
+        cons_args_hidden = {'input_size': 2, 'hidden_size': hidden_size, 'num_layers': 2, 'proj_size': proj_size,
+                            'batch_first': b_f, 'bias': b, 'bidirectional': bidir}
+
+        samples.append(
+            ModuleInput(
+                constructor_input=FunctionInput(**cons_args),
+                forward_input=FunctionInput(make_input((2, 2))),
+                reference_fn=partial(no_batch_dim_reference_lstm, batch_first=b_f),
+            )
+        )
+
+        h_out = proj_size if proj_size > 0 else hidden_size
+        hx = (make_input((4 if bidir else 2, h_out)), make_input((4 if bidir else 2, hidden_size)))
+        samples.append(
+            ModuleInput(
+                constructor_input=FunctionInput(**cons_args_hidden),
+                forward_input=FunctionInput(make_input((3, 2)), hx),
+                reference_fn=partial(no_batch_dim_reference_lstm, batch_first=b_f),
+            )
+        )
+
+    return samples
+
+
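+# Illustrative sketch (hypothetical): with proj_size > 0 the LSTM projects its hidden
+# state, so h_n (and h_0) use proj_size while c_n (and c_0) keep hidden_size.
+def _example_unbatched_lstm_with_projection():
+    lstm = torch.nn.LSTM(input_size=2, hidden_size=3, num_layers=2, proj_size=2)
+    out, (h, c) = lstm(torch.randn(3, 2), (torch.randn(2, 2), torch.randn(2, 3)))
+    return out.shape, h.shape, c.shape  # (3, 2), (2, 2), (2, 3)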
+
+def module_inputs_torch_nn_ReflectionPad1d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1),
+            forward_input=FunctionInput(make_input((2, 3))),
+            reference_fn=no_batch_dim_reference_fn,
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((1, 2)),
+            forward_input=FunctionInput(make_input((2, 3, 4))),
+        ),
+    ]
+
+def module_inputs_torch_nn_ReflectionPad2d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1),
+            forward_input=FunctionInput(make_input((3, 4, 5))),
+            reference_fn=no_batch_dim_reference_fn,
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((1, 2, 3, 4)),
+            forward_input=FunctionInput(make_input((3, 4, 5, 6))),
+        ),
+    ]
+
+def module_inputs_torch_nn_ReflectionPad3d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1),
+            forward_input=FunctionInput(make_input((2, 3, 4, 5))),
+            reference_fn=no_batch_dim_reference_fn
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((1, 2, 1, 2, 1, 2)),
+            forward_input=FunctionInput(make_input((3, 3, 3, 3, 3))),
+        ),
+    ]
+
+def module_inputs_torch_nn_ReplicationPad1d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1),
+            forward_input=FunctionInput(make_input((3, 4))),
+            reference_fn=no_batch_dim_reference_fn
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((1, 2)),
+            forward_input=FunctionInput(make_input((3, 4, 5))),
+        ),
+    ]
+
+def module_inputs_torch_nn_ReplicationPad2d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1),
+            forward_input=FunctionInput(make_input((3, 4, 5))),
+            reference_fn=no_batch_dim_reference_fn,
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((1, 2, 3, 4)),
+            forward_input=FunctionInput(make_input((3, 4, 5, 6))),
+        ),
+    ]
+
+def module_inputs_torch_nn_ReplicationPad3d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1),
+            forward_input=FunctionInput(make_input((3, 4, 5, 6))),
+            reference_fn=no_batch_dim_reference_fn,
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((1, 2, 3, 4, 5, 6)),
+            forward_input=FunctionInput(make_input((3, 4, 5, 6, 7))),
+        ),
+    ]
+
+def module_inputs_torch_nn_ZeroPad1d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1),
+            forward_input=FunctionInput(make_input((3, 4))),
+            reference_fn=no_batch_dim_reference_fn,
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((1, 2)),
+            forward_input=FunctionInput(make_input((3, 4, 5))),
+        ),
+    ]
+
+def module_inputs_torch_nn_ZeroPad2d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1),
+            forward_input=FunctionInput(make_input((1, 2, 3))),
+            reference_fn=no_batch_dim_reference_fn
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((1, 2, 3, 4)),
+            forward_input=FunctionInput(make_input((1, 2, 3, 4))),
+        ),
+    ]
+
+def module_inputs_torch_nn_ZeroPad3d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1),
+            forward_input=FunctionInput(make_input((3, 4, 5, 6))),
+            reference_fn=no_batch_dim_reference_fn,
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((1, 2, 3, 4, 5, 6)),
+            forward_input=FunctionInput(make_input((1, 2, 3, 4, 5))),
+        ),
+    ]
+
+def module_inputs_torch_nn_ConstantPad1d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1, 2),
+            forward_input=FunctionInput(make_input((3, 4))),
+            reference_fn=no_batch_dim_reference_fn,
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((1, 2), 3),
+            forward_input=FunctionInput(make_input((3, 4, 5))),
+        ),
+    ]
+
+def module_inputs_torch_nn_ConstantPad2d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1, 3),
+            forward_input=FunctionInput(make_input((3, 4, 5))),
+            reference_fn=no_batch_dim_reference_fn
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((1, 2, 3, 4), 5),
+            forward_input=FunctionInput(make_input((1, 2, 3, 4))),
+        ),
+    ]
+
+def module_inputs_torch_nn_ConstantPad3d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1, 3),
+            forward_input=FunctionInput(make_input((3, 4, 5, 6))),
+            reference_fn=no_batch_dim_reference_fn,
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((1, 2, 3, 4, 5, 6), 7),
+            forward_input=FunctionInput(make_input((1, 2, 1, 2, 1))),
+        ),
+    ]
+
+def module_inputs_torch_nn_CircularPad1d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    def padding1d_circular_ref(inp, pad):
+        r""" input:
+                [[[0., 1., 2.],
+                  [3., 4., 5.]]]
+                pad: (1, 2)
+                output:
+                    [[[2., 0., 1., 2., 0., 1.],
+                      [5., 3., 4., 5., 3., 4.]]]
+            """
+        return torch.cat([inp[:, :, -pad[0]:], inp, inp[:, :, :pad[1]]], dim=2)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1),
+            forward_input=FunctionInput(make_input((3, 4))),
+            reference_fn=no_batch_dim_reference_fn
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((1, 2)),
+            forward_input=FunctionInput(make_input((1, 2, 3))),
+            reference_fn=lambda m, p, i: padding1d_circular_ref(i, m.padding),
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((3, 1)),
+            forward_input=FunctionInput(make_input((1, 2, 3))),
+            reference_fn=lambda m, p, i: padding1d_circular_ref(i, m.padding),
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((3, 3)),
+            forward_input=FunctionInput(make_input((1, 2, 3))),
+            reference_fn=lambda m, p, i: padding1d_circular_ref(i, m.padding),
+        ),
+    ]
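+
+# Illustrative sketch (hypothetical): padding1d_circular_ref above matches
+# torch.nn.functional.pad(..., mode='circular') on a 3D (N, C, W) input.
+def _example_circular_pad1d_equivalence():
+    inp = torch.arange(6.).reshape(1, 2, 3)
+    via_functional = torch.nn.functional.pad(inp, (1, 2), mode='circular')
+    via_cat = torch.cat([inp[:, :, -1:], inp, inp[:, :, :2]], dim=2)
+    return torch.equal(via_functional, via_cat)  # True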
+
+def module_inputs_torch_nn_CircularPad2d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    def padding2d_circular_ref(inp, pad):
+        r"""input:
+                [[[[0., 1., 2],
+                   [3., 4., 5.]]]]
+                pad: (1, 2, 2, 1)
+        output:
+            [[[[2., 0., 1., 2., 0., 1.],
+               [5., 3., 4., 5., 3., 4.],
+               [2., 0., 1., 2., 0., 1.],
+               [5., 3., 4., 5., 3., 4.],
+               [2., 0., 1., 2., 0., 1.]]]]
+        """
+        inp = torch.cat([inp[:, :, -pad[2]:], inp, inp[:, :, :pad[3]]], dim=2)
+        return torch.cat([inp[:, :, :, -pad[0]:], inp, inp[:, :, :, :pad[1]]], dim=3)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1),
+            forward_input=FunctionInput(make_input((3, 4, 5))),
+            reference_fn=no_batch_dim_reference_fn,
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((1, 2, 2, 1)),
+            forward_input=FunctionInput(make_input((1, 1, 2, 3))),
+            reference_fn=lambda m, p, i: padding2d_circular_ref(i, m.padding),
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((2, 3, 2, 2)),
+            forward_input=FunctionInput(make_input((1, 1, 2, 3))),
+            reference_fn=lambda m, p, i: padding2d_circular_ref(i, m.padding),
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((3, 3, 3, 1)),
+            forward_input=FunctionInput(make_input((1, 1, 3, 3))),
+            reference_fn=lambda m, p, i: padding2d_circular_ref(i, m.padding),
+        ),
+    ]
+
+def module_inputs_torch_nn_CircularPad3d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+
+    def padding3d_circular_ref(inp, pad):
+        r"""input:
+                [[[[[ 0.,  1.,  2.],
+                    [ 3.,  4.,  5.]],
+                   [[ 6.,  7.,  8.],
+                    [ 9., 10., 11.]]]]]
+            pad: (1, 2, 2, 1, 1, 2)
+            output: [[[[[ 8.,  6.,  7.,  8.,  6.,  7.],
+                        [11.,  9., 10., 11.,  9., 10.],
+                        [ 8.,  6.,  7.,  8.,  6.,  7.],
+                        [11.,  9., 10., 11.,  9., 10.],
+                        [ 8.,  6.,  7.,  8.,  6.,  7.]],
+
+                       [[ 2.,  0.,  1.,  2.,  0.,  1.],
+                        [ 5.,  3.,  4.,  5.,  3.,  4.],
+                        [ 2.,  0.,  1.,  2.,  0.,  1.],
+                        [ 5.,  3.,  4.,  5.,  3.,  4.],
+                        [ 2.,  0.,  1.,  2.,  0.,  1.]],
+
+                       [[ 8.,  6.,  7.,  8.,  6.,  7.],
+                        [11.,  9., 10., 11.,  9., 10.],
+                        [ 8.,  6.,  7.,  8.,  6.,  7.],
+                        [11.,  9., 10., 11.,  9., 10.],
+                        [ 8.,  6.,  7.,  8.,  6.,  7.]],
+
+                       [[ 2.,  0.,  1.,  2.,  0.,  1.],
+                        [ 5.,  3.,  4.,  5.,  3.,  4.],
+                        [ 2.,  0.,  1.,  2.,  0.,  1.],
+                        [ 5.,  3.,  4.,  5.,  3.,  4.],
+                        [ 2.,  0.,  1.,  2.,  0.,  1.]],
+
+                       [[ 8.,  6.,  7.,  8.,  6.,  7.],
+                        [11.,  9., 10., 11.,  9., 10.],
+                        [ 8.,  6.,  7.,  8.,  6.,  7.],
+                        [11.,  9., 10., 11.,  9., 10.],
+                        [ 8.,  6.,  7.,  8.,  6.,  7.]]]]]
+        """
+        inp = torch.cat([inp[:, :, -pad[4]:], inp, inp[:, :, :pad[5]]], dim=2)
+        inp = torch.cat([inp[:, :, :, -pad[2]:], inp, inp[:, :, :, :pad[3]]], dim=3)
+        return torch.cat([inp[:, :, :, :, -pad[0]:], inp, inp[:, :, :, :, :pad[1]]], dim=4)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(1),
+            forward_input=FunctionInput(make_input((3, 4, 5, 6))),
+            reference_fn=no_batch_dim_reference_fn,
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((1, 2, 1, 2, 1, 2)),
+            forward_input=FunctionInput(make_input((1, 1, 2, 2, 3))),
+            reference_fn=lambda m, p, i: padding3d_circular_ref(i, m.padding)
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((3, 2, 2, 1, 1, 2)),
+            forward_input=FunctionInput(make_input((1, 1, 2, 2, 3))),
+            reference_fn=lambda m, p, i: padding3d_circular_ref(i, m.padding)
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput((3, 3, 2, 1, 2, 2)),
+            forward_input=FunctionInput(make_input((1, 1, 2, 2, 3))),
+            reference_fn=lambda m, p, i: padding3d_circular_ref(i, m.padding)
+        ),
+    ]
+
+
+# All these operators share similar issues on cuDNN and MIOpen
+rnn_gru_lstm_module_info_decorators = (
+    # RuntimeError: Batching rule not implemented for aten::_cudnn_rnn_backward.
+    # We could not generate a fallback
+    DecorateInfo(
+        unittest.expectedFailure, "TestModule", "test_grad",
+        active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda'
+    ),
+    # NotImplementedError: the derivative for '_cudnn_rnn_backward' is not implemented.
+    # Double backwards is not supported for CuDNN RNNs due to limitations in the CuDNN API
+    DecorateInfo(
+        unittest.expectedFailure, "TestModule", "test_gradgrad",
+        active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda'
+    ),
+    # CUDNN GRU doesn't accept non-contiguous hx
+    DecorateInfo(
+        unittest.expectedFailure, "TestModule", "test_non_contiguous_tensors",
+        active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda'
+    ),
+    # MIOPEN GRU doesn't accept non-contiguous hx (this is dispatched to miopen only for float).
+    DecorateInfo(
+        unittest.expectedFailure, "TestModule", "test_non_contiguous_tensors",
+        active_if=(TEST_CUDNN and TEST_WITH_ROCM), dtypes=(torch.float,), device_type='cuda'
+    )
+)
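+
+# A sketch of intended usage (assumed, not shown here): this tuple is meant to be passed
+# as `decorators=rnn_gru_lstm_module_info_decorators` on the recurrent ModuleInfo entries
+# (RNN, GRU, LSTM) in module_db below.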
+
+# Start of module error inputs functions.
+
+def module_error_inputs_torch_nn_RNN_GRU_Cell(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    samples = [
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(10, 20),
+                forward_input=FunctionInput(make_input(3, 11), make_input(3, 20)),
+            ),
+            error_on=ModuleErrorEnum.FORWARD_ERROR,
+            error_type=RuntimeError,
+            error_regex="input has inconsistent input_size: got 11 expected 10"
+        ),
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(10, 20),
+                forward_input=FunctionInput(make_input(3, 10), make_input(3, 21)),
+            ),
+            error_on=ModuleErrorEnum.FORWARD_ERROR,
+            error_type=RuntimeError,
+            error_regex="hidden0 has inconsistent hidden_size: got 21, expected 20"
+        ),
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(10, 20),
+                forward_input=FunctionInput(make_input(3, 10), make_input(5, 20)),
+            ),
+            error_on=ModuleErrorEnum.FORWARD_ERROR,
+            error_type=RuntimeError,
+            error_regex="Input batch size 3 doesn't match hidden0 batch size 5"
+        ),
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(10, 20),
+                forward_input=FunctionInput(make_input(3, 10), make_input(3, 1, 1, 20)),
+            ),
+            error_on=ModuleErrorEnum.FORWARD_ERROR,
+            error_type=ValueError,
+            error_regex="Expected hidden to be 1D or 2D, got 4D instead"
+        ),
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(10, 20, 'relu'),
+                forward_input=FunctionInput(make_input(3, 10), make_input(3, 21)),
+            ),
+            error_on=ModuleErrorEnum.FORWARD_ERROR,
+            error_type=RuntimeError,
+            error_regex="hidden0 has inconsistent hidden_size: got 21, expected 20"
+        ),
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(10, 20, 'tanh'),
+                forward_input=FunctionInput(make_input(3, 10), make_input(3, 21)),
+            ),
+            error_on=ModuleErrorEnum.FORWARD_ERROR,
+            error_type=RuntimeError,
+            error_regex="hidden0 has inconsistent hidden_size: got 21, expected 20"
+        ),
+    ]
+    return samples
+
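+# Illustrative sketch (hypothetical): reproducing the first forward error above directly.
+def _example_rnn_cell_input_size_error():
+    cell = torch.nn.RNNCell(10, 20)
+    try:
+        cell(torch.randn(3, 11), torch.randn(3, 20))
+    except RuntimeError as e:
+        return str(e)  # "input has inconsistent input_size: got 11 expected 10"
+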
+def module_error_inputs_torch_nn_LSTMCell(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    samples = [
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(10, 20),
+                forward_input=FunctionInput(make_input(3, 11), (make_input(3, 20), make_input(3, 20))),
+            ),
+            error_on=ModuleErrorEnum.FORWARD_ERROR,
+            error_type=RuntimeError,
+            error_regex="input has inconsistent input_size: got 11 expected 10"
+        ),
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(10, 20),
+                forward_input=FunctionInput(make_input(3, 10), (make_input(3, 21), make_input(3, 21))),
+            ),
+            error_on=ModuleErrorEnum.FORWARD_ERROR,
+            error_type=RuntimeError,
+            error_regex="hidden0 has inconsistent hidden_size: got 21, expected 20"
+        ),
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(10, 20),
+                forward_input=FunctionInput(make_input(3, 10), (make_input(5, 20), make_input(5, 20))),
+            ),
+            error_on=ModuleErrorEnum.FORWARD_ERROR,
+            error_type=RuntimeError,
+            error_regex="Input batch size 3 doesn't match hidden0 batch size 5"
+        ),
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(10, 20),
+                forward_input=FunctionInput(make_input(3, 10), (make_input(3, 1, 1, 20), make_input(3, 1, 1, 20))),
+            ),
+            error_on=ModuleErrorEnum.FORWARD_ERROR,
+            error_type=ValueError,
+            error_regex="Expected hx\\[0\\] to be 1D or 2D, got 4D instead"
+        ),
+    ]
+    return samples
+
+
+def module_error_inputs_torch_nn_RNN_GRU(module_info, device, dtype, requires_grad, training, **kwargs):
+    samples = [
+        ErrorModuleInput(
+            ModuleInput(constructor_input=FunctionInput(10, 0, 1)),
+            error_on=ModuleErrorEnum.CONSTRUCTION_ERROR,
+            error_type=ValueError,
+            error_regex="hidden_size must be greater than zero"
+        ),
+        ErrorModuleInput(
+            ModuleInput(constructor_input=FunctionInput(10, 10, 0)),
+            error_on=ModuleErrorEnum.CONSTRUCTION_ERROR,
+            error_type=ValueError,
+            error_regex="num_layers must be greater than zero"
+        ),
+    ]
+    return samples
+
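+# Per the samples above, e.g. torch.nn.GRU(10, 0, 1) is expected to raise
+# ValueError("hidden_size must be greater than zero") at construction time.
+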
+def module_error_inputs_torch_nn_Pad1d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    is_constant = kwargs.get('is_constant', False)
+
+    return [
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(1, 3) if is_constant else FunctionInput(3),
+                forward_input=FunctionInput(make_input((2, 3, 4, 5))),
+            ),
+            error_on=ModuleErrorEnum.FORWARD_ERROR,
+            error_type=ValueError,
+            error_regex=r"expected 2D or 3D input \(got 4D input\)",
+
+        ),
+    ]
+
+def module_error_inputs_torch_nn_Pad2d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    is_constant = kwargs.get('is_constant', False)
+
+    return [
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(1, 3) if is_constant else FunctionInput(3),
+                forward_input=FunctionInput(make_input((2, 3))),
+            ),
+            error_on=ModuleErrorEnum.FORWARD_ERROR,
+            error_type=ValueError,
+            error_regex=r"expected 3D or 4D input \(got 2D input\)",
+
+        ),
+    ]
+
+def module_error_inputs_torch_nn_Pad3d(module_info, device, dtype, requires_grad, training, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    is_constant = kwargs.get('is_constant', False)
+
+    return [
+        ErrorModuleInput(
+            ModuleInput(
+                constructor_input=FunctionInput(1, 3) if is_constant else FunctionInput(3),
+                forward_input=FunctionInput(make_input((2, 3))),
+            ),
+            error_on=ModuleErrorEnum.FORWARD_ERROR,
+            error_type=ValueError,
+            error_regex=r"expected 4D or 5D input \(got 2D input\)",
+
+        ),
+    ]
+
+
+_macos15_or_newer = torch.backends.mps.is_available() and torch.backends.mps.is_macos_or_newer(15, 0)
+
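+
+# Illustrative sketch (hypothetical; assumes FunctionInput exposes `args`/`kwargs`): how a
+# ModuleInput sample is typically consumed -- build the module from constructor_input, run
+# it on forward_input, and compare against reference_fn when one is provided.
+def _example_consume_module_input(module_cls, module_input):
+    ctor = module_input.constructor_input
+    m = module_cls(*ctor.args, **ctor.kwargs)
+    fwd = module_input.forward_input
+    return m(*fwd.args, **fwd.kwargs)
+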
+
+# Database of ModuleInfo entries in alphabetical order.
+module_db: list[ModuleInfo] = [
+    ModuleInfo(torch.nn.AdaptiveAvgPool1d,
+               module_inputs_func=module_inputs_torch_nn_AdaptiveAvgPool1d,
+               skips=(
+                   # Fails on the MPS backend if the input size is not evenly divisible by the output size
+                   DecorateInfo(skipMPS),)
+               ),
+    ModuleInfo(torch.nn.AdaptiveAvgPool2d,
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_inputs_func=module_inputs_torch_nn_AdaptiveAvgPool2d,
+               skips=(
+                   # Fails on the MPS backend if the input size is not evenly divisible by the output size
+                   DecorateInfo(skipMPS),
+                   # Fails on backward check if output size is 1x1
+                   DecorateInfo(
+                       unittest.expectedFailure,
+                       'TestModule',
+                       'test_memory_format',
+                       active_if=operator.itemgetter('training'),
+                   ),)
+               ),
+    ModuleInfo(torch.nn.AdaptiveAvgPool3d,
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_inputs_func=module_inputs_torch_nn_AdaptiveAvgPool3d,
+               skips=(
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # not supported on MPS backend
+                   DecorateInfo(skipMPS),)
+               ),
+    ModuleInfo(torch.nn.AdaptiveMaxPool1d,
+               module_inputs_func=module_inputs_torch_nn_AdaptiveMaxPool1d,
+               ),
+    ModuleInfo(torch.nn.AdaptiveMaxPool2d,
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_inputs_func=module_inputs_torch_nn_AdaptiveMaxPool2d,
+               ),
+    ModuleInfo(torch.nn.AdaptiveMaxPool3d,
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_inputs_func=module_inputs_torch_nn_AdaptiveMaxPool3d,
+               skips=(
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # not supported on MPS backend
+                   DecorateInfo(skipMPS),)
+               ),
+    ModuleInfo(torch.nn.AvgPool1d,
+               module_inputs_func=module_inputs_torch_nn_AvgPool1d,
+               ),
+    ModuleInfo(torch.nn.AvgPool2d,
+               module_inputs_func=module_inputs_torch_nn_AvgPool2d,
+               skips=(
+                   # The difference between channels last backward and
+                   # channels first backward of AvgPool2d on CUDA is too large
+                   # See https://github.com/pytorch/pytorch/issues/107201
+                   DecorateInfo(
+                       unittest.expectedFailure,
+                       'TestModule',
+                       'test_memory_format',
+                       active_if=operator.itemgetter('training'),
+                       device_type='cuda',
+                   ),
+                   # error: input types 'tensor' and 'tensor<15x10xf16>' are not broadcast compatible
+                   DecorateInfo(skipIfMPSOnMacOS13, 'TestModule', dtypes=[torch.float16], device_type='mps',),),
+               ),
+    ModuleInfo(torch.nn.AvgPool3d,
+               module_inputs_func=module_inputs_torch_nn_AvgPool3d,
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               skips=(
+                   # No channels_last support for AvgPool3d as it does not take 4D inputs
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # not supported on MPS backend
+                   DecorateInfo(skipMPS),)
+               ),
+    ModuleInfo(torch.nn.BatchNorm1d,
+               train_and_eval_differ=True,
+               module_inputs_func=module_inputs_torch_nn_BatchNorm1d,
+               skips=(
+                   # tracking here rather than in the list in test_aotdispatch.py as eval mode passes
+                   # RuntimeError: tried to get Double out of SymInt
+                   DecorateInfo(
+                       unittest.expectedFailure, 'TestEagerFusionModuleInfo',
+                       'test_aot_autograd_symbolic_module_exhaustive',
+                       active_if=operator.itemgetter('training')
+                   ),
+                   # torch._subclasses.fake_tensor.DataDependentOutputException: aten._local_scalar_dense.default
+                   DecorateInfo(
+                       unittest.expectedFailure, 'TestEagerFusionModuleInfo',
+                       'test_aot_autograd_module_exhaustive',
+                       active_if=operator.itemgetter('training')
+                   ))
+               ),
+    ModuleInfo(torch.nn.BatchNorm2d,
+               train_and_eval_differ=True,
+               module_inputs_func=module_inputs_torch_nn_BatchNorm2d,
+               skips=(
+                   # See https://github.com/pytorch/pytorch/issues/134580
+                   DecorateInfo(expectedFailureMPS, 'TestModule', 'test_memory_format', active_if=operator.itemgetter('training')),
+                   # tracking here rather than in the list in test_aotdispatch.py as eval mode passes
+                   # RuntimeError: tried to get Double out of SymInt
+                   DecorateInfo(
+                       unittest.expectedFailure, 'TestEagerFusionModuleInfo',
+                       'test_aot_autograd_symbolic_module_exhaustive',
+                       active_if=operator.itemgetter('training')
+                   ),
+                   # torch._subclasses.fake_tensor.DataDependentOutputException: aten._local_scalar_dense.default
+                   DecorateInfo(
+                       unittest.expectedFailure, 'TestEagerFusionModuleInfo',
+                       'test_aot_autograd_module_exhaustive',
+                       active_if=operator.itemgetter('training')
+                   ),)
+               ),
+    ModuleInfo(torch.nn.BatchNorm3d,
+               train_and_eval_differ=True,
+               module_inputs_func=module_inputs_torch_nn_BatchNorm3d,
+               skips=(
+                   # not supported on MPS backend
+                   DecorateInfo(skipMPS),
+                   # tracking here rather than in the list in test_aotdispatch.py as eval mode passes
+                   # RuntimeError: tried to get Double out of SymInt
+                   DecorateInfo(
+                       unittest.expectedFailure, 'TestEagerFusionModuleInfo',
+                       'test_aot_autograd_symbolic_module_exhaustive',
+                       active_if=operator.itemgetter('training')
+                   ),
+                   # torch._subclasses.fake_tensor.DataDependentOutputException: aten._local_scalar_dense.default
+                   DecorateInfo(
+                       unittest.expectedFailure, 'TestEagerFusionModuleInfo',
+                       'test_aot_autograd_module_exhaustive',
+                       active_if=operator.itemgetter('training')
+                   ),)
+               ),
+    ModuleInfo(torch.nn.CELU,
+               module_inputs_func=module_inputs_torch_nn_CELU,
+               # not MPS specific, will be xfailed for all devices in next PR
+               skips=(
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_check_inplace',
+                                device_type='mps', dtypes=[torch.float16]),)
+               ),
+    ModuleInfo(torch.nn.Conv1d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=1, lazy=False),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               skips=(
+                   # Failure on ROCM for float32 issue #70125
+                   DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
+                   # See #119108: MPSNDArrayConvolutionA14.mm:3976: failed assertion `destination datatype must be fp32'
+                   # xfail does not work due to Fatal Python error: Aborted
+                   DecorateInfo(skipIfMPSOnMacOS13, "TestModule", "test_memory_format",
+                                device_type='mps', dtypes=[torch.float16]),
+                   DecorateInfo(skipIfMPSOnMacOS13, "TestModule", "test_non_contiguous_tensors",
+                                device_type='mps', dtypes=[torch.float16]),
+               ),
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+               )),
+    ModuleInfo(torch.nn.Conv2d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=2, lazy=False),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               skips=(
+                   # Failure on ROCM for float32 issue #70125
+                   DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
+                   # This was wrongly being skipped before and needs investigation.
+                   # See https://github.com/pytorch/pytorch/issues/80247
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format",
+                                device_type='cuda', dtypes=[torch.float64]),
+                   # Fails with channels last test on MPS backend
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format",
+                                device_type='mps', dtypes=[torch.float32, torch.float16]),
+                   # See #119108: MPSNDArrayConvolutionA14.mm:3976: failed assertion `destination datatype must be fp32'
+                   # xfail does not work due to Fatal Python error: Aborted
+                   DecorateInfo(skipIfMPSOnMacOS13, "TestModule", "test_memory_format",
+                                device_type='mps', dtypes=[torch.float16]),
+                   DecorateInfo(skipIfMPSOnMacOS13, "TestModule", "test_non_contiguous_tensors",
+                                device_type='mps', dtypes=[torch.float16]),
+               ),
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+               )),
+    ModuleInfo(torch.nn.Conv3d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=3, lazy=False),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               skips=(
+                   # Failure on ROCM for float32 issue #70125
+                   DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
+                   # Conv3d is not supported on MPS backend
+                   DecorateInfo(skipMPS, device_type="mps"),
+                   # This was wrongly being skipped before and needs investigation.
+                   # See https://github.com/pytorch/pytorch/issues/80247
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),
+               ),
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+               )),
+    ModuleInfo(torch.nn.ConvTranspose1d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=1, lazy=False, transposed=True),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               dtypes=floating_and_complex_types_and(torch.chalf),
+               skips=(
+                   # Failure on ROCM for float32 issue #70125
+                   DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
+                   # Not implemented for chalf on CPU
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_cpu_gpu_parity',
+                                dtypes=(torch.chalf,), device_type='cuda'),
+                   # See #119108: MPSNDArrayConvolutionA14.mm:3976: failed assertion `destination datatype must be fp32'
+                   # xfail does not work due to Fatal Python error: Aborted
+                   DecorateInfo(skipIfMPSOnMacOS13, "TestModule", "test_memory_format",
+                                device_type='mps', dtypes=[torch.float16]),
+                   DecorateInfo(skipIfMPSOnMacOS13, "TestModule", "test_non_contiguous_tensors",
+                                device_type='mps', dtypes=[torch.float16]),),
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+                   DecorateInfo(precisionOverride({torch.chalf: 5e-03}), 'TestModule', 'test_memory_format'),
+               )),
+    ModuleInfo(torch.nn.ConvTranspose2d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=2, lazy=False, transposed=True),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               dtypes=floating_and_complex_types_and(torch.chalf),
+               skips=(
+                   # Failure on ROCM for float32 issue #70125
+                   DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
+                   # Fails on backward check because ViewAsRealBackward apply contiguous for grad
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_memory_format',
+                                dtypes=(torch.complex32, torch.complex64, torch.complex128)),
+                   # This was wrongly being skipped before and needs investigation.
+                   # See https://github.com/pytorch/pytorch/issues/80247
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cuda',
+                                dtypes=[torch.float64, torch.complex128]),
+                   # Fails with channels last test on MPS backend
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format",
+                                device_type='mps', dtypes=[torch.float16, torch.float32]),
+                   # Not implemented for chalf on CPU
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_cpu_gpu_parity',
+                                dtypes=(torch.chalf,), device_type='cuda'),
+                   # See #119108: MPSNDArrayConvolutionA14.mm:3976: failed assertion `destination datatype must be fp32'
+                   # xfail does not work due to Fatal Python error: Aborted
+                   DecorateInfo(skipIfMPSOnMacOS13, "TestModule", "test_memory_format",
+                                device_type='mps', dtypes=[torch.float16]),
+                   DecorateInfo(skipIfMPSOnMacOS13, "TestModule", "test_non_contiguous_tensors",
+                                device_type='mps', dtypes=[torch.float16]),
+               ),
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+                   DecorateInfo(precisionOverride({torch.chalf: 5e-03}), 'TestModule', 'test_memory_format'),
+               )),
+    ModuleInfo(torch.nn.ConvTranspose3d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=3, lazy=False, transposed=True),
+               dtypes=floating_and_complex_types_and(torch.chalf),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               skips=(
+                   # Failure on ROCM for float32 issue #70125
+                   DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
+                   # ConvTranspose3d is not supported on MPS backend
+                   DecorateInfo(skipMPS),
+                   # This was wrongly being skipped before and needs investigation.
+                   # See https://github.com/pytorch/pytorch/issues/80247
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),
+                   # These fail only on ROCm
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cuda',
+                                dtypes=[torch.complex32, torch.complex64], active_if=TEST_WITH_ROCM),
+                   # Not implemented for chalf on CPU
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_cpu_gpu_parity',
+                                dtypes=(torch.chalf,), device_type='cuda'),
+               ),
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+                   DecorateInfo(precisionOverride({torch.complex64: 1e-04}), 'TestModule', 'test_cpu_gpu_parity'),
+                   DecorateInfo(precisionOverride({torch.chalf: 5e-03}), 'TestModule', 'test_memory_format'),
+               )),
+    ModuleInfo(torch.nn.CosineEmbeddingLoss,
+               module_inputs_func=module_inputs_torch_nn_CosineEmbeddingLoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.ELU,
+               module_inputs_func=module_inputs_torch_nn_ELU,
+               # not MPS specific, will be xfailed for all devices in next PR
+               skips=(
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_check_inplace',
+                                device_type='mps', dtypes=[torch.float16]),)
+               ),
+    ModuleInfo(torch.nn.FractionalMaxPool2d,
+               module_inputs_func=module_inputs_torch_nn_FractionalMaxPool2d,
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               skips=(
+                   # not supported on MPS backend
+                   DecorateInfo(skipMPS),
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.FractionalMaxPool3d,
+               module_inputs_func=module_inputs_torch_nn_FractionalMaxPool3d,
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               skips=(
+                   # not supported on MPS backend
+                   DecorateInfo(skipMPS),
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.L1Loss,
+               module_inputs_func=module_inputs_torch_nn_L1Loss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.SmoothL1Loss,
+               module_inputs_func=module_inputs_torch_nn_SmoothL1Loss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # See #119108: input types 'tensor' and 'tensor<15x10xf16>' are not broadcast compatible
+                   # NS: Still fails on MacOS15.1
+                   DecorateInfo(skipIfMPS, 'TestModule', 'test_non_contiguous_tensors',
+                                dtypes=[torch.float16], device_type='mps'),),
+               ),
+    ModuleInfo(torch.nn.LazyConv1d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=1, lazy=True),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               skips=(
+                   # Failure on ROCM for float32 issue #70125
+                   DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
+                   # Lazy modules don't currently play well with ModuleInfo tests on the meta device.
+                   # See https://github.com/pytorch/pytorch/issues/70505 for more info.
+                   DecorateInfo(skipMeta),
+                   # See #119108: MPSNDArrayConvolutionA14.mm:3976: failed assertion `destination datatype must be fp32'
+                   # xfail does not work due to Fatal Python error: Aborted
+                   DecorateInfo(skipIfMPSOnMacOS13, "TestModule", "test_memory_format",
+                                device_type='mps', dtypes=[torch.float16]),
+                   DecorateInfo(skipIfMPSOnMacOS13, "TestModule", "test_non_contiguous_tensors",
+                                device_type='mps', dtypes=[torch.float16]),
+               ),
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+               )),
+    ModuleInfo(torch.nn.LazyConv2d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=2, lazy=True),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               skips=(
+                   # Failure on ROCM for float32 issue #70125
+                   DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
+                   # Lazy modules don't currently play well with ModuleInfo tests on the meta device.
+                   # See https://github.com/pytorch/pytorch/issues/70505 for more info.
+                   DecorateInfo(skipMeta),
+                   # This was wrongly being skipped before and needs investigation.
+                   # See https://github.com/pytorch/pytorch/issues/80247
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format",
+                                device_type='cuda', dtypes=[torch.float64]),
+                   # Fails with channels last test on MPS backend
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format",
+                                device_type='mps', dtypes=[torch.float32, torch.float16]),
+                   # See #119108: MPSNDArrayConvolutionA14.mm:3976: failed assertion `destination datatype must be fp32'
+                   # xfail does not work due to Fatal Python error: Aborted
+                   DecorateInfo(skipIfMPSOnMacOS13, "TestModule", "test_memory_format",
+                                device_type='mps', dtypes=[torch.float16]),
+                   DecorateInfo(skipIfMPSOnMacOS13, "TestModule", "test_non_contiguous_tensors",
+                                device_type='mps', dtypes=[torch.float16]),
+               ),
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+               )),
+    ModuleInfo(torch.nn.LazyConv3d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=3, lazy=True),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               skips=(
+                   # Failure on ROCM for float32 issue #70125
+                   DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
+                   # Lazy modules don't currently play well with ModuleInfo tests on the meta device.
+                   # See https://github.com/pytorch/pytorch/issues/70505 for more info.
+                   DecorateInfo(skipMeta),
+                   # LazyConv3d is not supported on MPS backend
+                   DecorateInfo(skipMPS),
+                   # This was wrongly being skipped before and needs investigation.
+                   # See https://github.com/pytorch/pytorch/issues/80247
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),
+               ),
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+               )),
+    ModuleInfo(torch.nn.LazyConvTranspose1d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=1, lazy=True, transposed=True),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               skips=(
+                   # Failure on ROCM for float32 issue #70125
+                   DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
+                   # Lazy modules don't currently play well with ModuleInfo tests on the meta device.
+                   # See https://github.com/pytorch/pytorch/issues/70505 for more info.
+                   DecorateInfo(skipMeta),
+                   # See #119108: MPSNDArrayConvolutionA14.mm:3976: failed assertion `destination datatype must be fp32'
+                   # xfail does not work due to Fatal Python error: Aborted
+                   DecorateInfo(skipIfMPSOnMacOS13, "TestModule", "test_memory_format",
+                                device_type='mps', dtypes=[torch.float16]),
+                   DecorateInfo(skipIfMPSOnMacOS13, "TestModule", "test_non_contiguous_tensors",
+                                device_type='mps', dtypes=[torch.float16]),
+               ),
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+               )),
+    ModuleInfo(torch.nn.LazyConvTranspose2d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=2, lazy=True, transposed=True),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               skips=(
+                   # Failure on ROCM for float32 issue #70125
+                   DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
+                   # Lazy modules don't currently play well with ModuleInfo tests on the meta device.
+                   # See https://github.com/pytorch/pytorch/issues/70505 for more info.
+                   DecorateInfo(skipMeta),
+                   # This was wrongly being skipped before and needs investigation.
+                   # See https://github.com/pytorch/pytorch/issues/80247
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cuda',
+                                dtypes=[torch.float64]),
+                   # Fails with channels last test on MPS backend
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format",
+                                device_type='mps', dtypes=[torch.float32, torch.float16]),
+                   # See #119108: MPSNDArrayConvolutionA14.mm:3976: failed assertion `destination datatype must be fp32'
+                   # xfail does not work due to Fatal Python error: Aborted
+                   DecorateInfo(skipIfMPSOnMacOS13, "TestModule", "test_memory_format",
+                                device_type='mps', dtypes=[torch.float16]),
+                   DecorateInfo(skipIfMPSOnMacOS13, "TestModule", "test_non_contiguous_tensors",
+                                device_type='mps', dtypes=[torch.float16]),
+               ),
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+               )),
+    ModuleInfo(torch.nn.LazyConvTranspose3d,
+               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=3, lazy=True, transposed=True),
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               module_memformat_affects_out=True,
+               skips=(
+                   # Failure on ROCM for float32 issue #70125
+                   DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
+                   # Lazy modules don't currently play well with ModuleInfo tests on the meta device.
+                   # See https://github.com/pytorch/pytorch/issues/70505 for more info.
+                   DecorateInfo(skipMeta),
+                   # LazyConvTranspose3d is not supported on MPS backend
+                   DecorateInfo(skipMPS),
+                   # This was wrongly being skipped before and needs investigation.
+                   # See https://github.com/pytorch/pytorch/issues/80247
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),
+               ),
+               decorators=(
+                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
+               )),
+    ModuleInfo(torch.nn.Linear,
+               module_inputs_func=module_inputs_torch_nn_Linear,
+               skips=(
+                   # No channels_last support for Linear currently.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.Bilinear,
+               module_inputs_func=module_inputs_torch_nn_Bilinear,
+               decorators=[
+                   DecorateInfo(
+                       toleranceOverride({
+                           torch.float32: tol(atol=1e-4, rtol=1e-4),
+                           torch.float64: tol(atol=1e-4, rtol=1e-4)}),
+                       'TestModule', 'test_forward', device_type='cpu'),
+               ],
+               skips=(
+                   # No channels_last support for Bilinear currently.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # See #119108: tolerance issue
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward",
+                                device_type='mps', dtypes=[torch.float16]),)
+               ),
+    ModuleInfo(torch.nn.LPPool1d,
+               module_inputs_func=module_inputs_torch_nn_LPPool1d,
+               skips=(
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_grad'),
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad'),)
+               ),
+    ModuleInfo(torch.nn.LPPool2d,
+               module_inputs_func=module_inputs_torch_nn_LPPool2d,
+               skips=(
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_grad'),
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad'),
+                   # Fails on backward check on MPS
+                   # See https://github.com/pytorch/pytorch/issues/107214
+                   DecorateInfo(
+                       unittest.expectedFailure,
+                       'TestModule',
+                       'test_memory_format',
+                       # combine both conditions in a callable; a bare `and` would collapse to a plain bool
+                       active_if=lambda params: params['training'] and not _macos15_or_newer,
+                       device_type='mps',
+                   ),)
+               ),
+    ModuleInfo(torch.nn.LPPool3d,
+               module_inputs_func=module_inputs_torch_nn_LPPool3d,
+               skips=(
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_grad'),
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad'),
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   DecorateInfo(skipIfMPS, device_type='mps'),)
+               ),
+    ModuleInfo(torch.nn.MaxPool1d,
+               module_inputs_func=module_inputs_torch_nn_MaxPool1d,
+               ),
+    ModuleInfo(torch.nn.MaxPool2d,
+               module_inputs_func=module_inputs_torch_nn_MaxPool2d,
+               ),
+    ModuleInfo(torch.nn.MaxPool3d,
+               module_inputs_func=module_inputs_torch_nn_MaxPool3d,
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               skips=(
+                   # not supported on MPS backend
+                   DecorateInfo(skipIfMPS, device_type='mps'),)
+               ),
+    ModuleInfo(torch.nn.KLDivLoss,
+               module_inputs_func=module_inputs_torch_nn_KLDivLoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # https://github.com/pytorch/pytorch/issues/115588
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_cpu_gpu_parity'),
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_grad'),
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad'),)
+               ),
+    ModuleInfo(torch.nn.MSELoss,
+               module_inputs_func=module_inputs_torch_nn_MSELoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # See #119108: input types 'tensor' and 'tensor<15x10xf16>' are not broadcast compatible
+                   DecorateInfo(skipIfMPSOnMacOS13, 'TestModule', 'test_non_contiguous_tensors',
+                                device_type='mps', dtypes=[torch.float16],),
+                   # See #119108: tolerance issue
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward",
+                                device_type='mps', dtypes=[torch.float16]),)
+               ),
+    ModuleInfo(torch.nn.MarginRankingLoss,
+               module_inputs_func=module_inputs_torch_nn_MarginRankingLoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.MultiLabelMarginLoss,
+               module_inputs_func=module_inputs_torch_nn_MultiLabelMarginLoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # 'aten::multilabel_margin_loss_forward' is not currently implemented for the MPS device.
+                   DecorateInfo(skipIfMPS, 'TestModule', device_type='mps'),
+                   # derivative for aten::multilabel_margin_loss_backward is not implemented
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad'),)
+               ),
+    ModuleInfo(torch.nn.MultiMarginLoss,
+               module_inputs_func=module_inputs_torch_nn_MultiMarginLoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # 'aten::multi_margin_loss' is not currently implemented for the MPS device.
+                   DecorateInfo(skipIfMPS, 'TestModule', device_type='mps'),
+                   # RuntimeError: derivative for aten::multi_margin_loss_backward is not implemented
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad'),)
+               ),
+    ModuleInfo(torch.nn.SoftMarginLoss,
+               module_inputs_func=module_inputs_torch_nn_SoftMarginLoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # See #119108: tolerance issue
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward",
+                                device_type='mps', dtypes=[torch.float16]),)
+               ),
+    ModuleInfo(torch.nn.MultiLabelSoftMarginLoss,
+               module_inputs_func=module_inputs_torch_nn_MultiLabelSoftMarginLoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.NLLLoss,
+               module_inputs_func=module_inputs_torch_nn_NLLLoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # See #119108: tolerance issue
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward",
+                                device_type='mps', dtypes=[torch.float16]),)
+               ),
+    ModuleInfo(torch.nn.GaussianNLLLoss,
+               module_inputs_func=module_inputs_torch_nn_GaussianNLLLoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)),
+    ModuleInfo(torch.nn.PoissonNLLLoss,
+               module_inputs_func=module_inputs_torch_nn_PoissonNLLLoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)),
+    ModuleInfo(torch.nn.HingeEmbeddingLoss,
+               module_inputs_func=module_inputs_torch_nn_HingeEmbeddingLoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.HuberLoss,
+               module_inputs_func=module_inputs_torch_nn_HuberLoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # See #119108: seemingly incorrect output dtype
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward",
+                                device_type='mps', dtypes=[torch.float16]),)
+               ),
+    ModuleInfo(torch.nn.BCELoss,
+               module_inputs_func=module_inputs_torch_nn_BCELoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # error: input types 'tensor' and 'tensor<15x10xf16>' are not broadcast compatible
+                   DecorateInfo(skipIfMPS, 'TestModule', dtypes=[torch.float16], device_type='mps'),)
+               ),
+    ModuleInfo(torch.nn.BCEWithLogitsLoss,
+               module_inputs_func=module_inputs_torch_nn_BCEWithLogitsLoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # See #119108: tolerance issue
+                   DecorateInfo(skipIfMPS, 'TestModule', dtypes=[torch.float16], device_type='mps'),)
+               ),
+    ModuleInfo(torch.nn.CrossEntropyLoss,
+               module_inputs_func=module_inputs_torch_nn_CrossEntropyLoss,
+               dtypes=get_all_fp_dtypes(include_half=True, include_bfloat16=False),
+               decorators=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_memory_format'),
+                   DecorateInfo(toleranceOverride({torch.float16: tol(atol=3e-2, rtol=1e-3)}), "TestModule",
+                                "test_forward", dtypes=[torch.float16], device_type='cpu'),
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_cpu_gpu_parity", dtypes=[torch.float16],
+                                device_type='cuda'),),
+               ),
+    ModuleInfo(torch.nn.CTCLoss,
+               module_inputs_func=module_inputs_torch_nn_CTCLoss,
+               skips=(
+                   # No channels_last support for loss functions.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # The operator aten::_ctc_loss is not currently implemented for the MPS device.
+                   DecorateInfo(skipIfMPS, 'TestModule', device_type='mps',),
+                   # derivative for aten::_ctc_loss_backward is not implemented
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_grad'),
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad'),
+                   # https://github.com/pytorch/pytorch/issues/115585
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_non_contiguous_tensors'),)
+               ),
+    ModuleInfo(torch.nn.GELU,
+               module_inputs_func=module_inputs_torch_nn_GELU,
+               skips=(
+                   # See #119108: tolerance issue
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward",
+                                device_type='mps', dtypes=[torch.float16]),)
+               ),
+    ModuleInfo(torch.nn.GLU,
+               module_inputs_func=module_inputs_torch_nn_GLU,
+               ),
+    ModuleInfo(torch.nn.GroupNorm,
+               module_inputs_func=module_inputs_torch_nn_GroupNorm,
+               dtypes=get_all_fp_dtypes(include_bfloat16=True, include_half=True),
+               skips=(
+                   # Tracking at https://github.com/pytorch/pytorch/issues/98089
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_cpu_gpu_parity'),
+                   DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-4)}),
+                                'TestModule', 'test_memory_format', device_type='cpu'),
+                   # No channels_last support for GroupNorm currently.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format', device_type='cuda'),
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format', device_type='mps'),
+                   DecorateInfo(unittest.skip("Skipped!"), "TestModule", "test_grad",
+                                active_if=TEST_WITH_ROCM, device_type='cuda'),)
+               ),
+    ModuleInfo(torch.nn.Hardshrink,
+               module_inputs_func=module_inputs_torch_nn_Hardshrink,
+               ),
+    ModuleInfo(torch.nn.Hardswish,
+               module_inputs_func=module_inputs_torch_nn_Hardswish,
+               supports_gradgrad=False),
+    ModuleInfo(torch.nn.Hardtanh,
+               module_inputs_func=module_inputs_torch_nn_Hardtanh,
+               ),
+    ModuleInfo(torch.nn.InstanceNorm1d,
+               module_inputs_func=partial(module_inputs_torch_nn_InstanceNormNd, N=1),
+               train_and_eval_differ=True,
+               skips=(
+                   # No channels_last support for InstanceNorm1d currently.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.InstanceNorm2d,
+               module_inputs_func=partial(module_inputs_torch_nn_InstanceNormNd, N=2),
+               train_and_eval_differ=True,
+               skips=(
+                   # No channels_last support for InstanceNorm2d currently.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.InstanceNorm3d,
+               module_inputs_func=partial(module_inputs_torch_nn_InstanceNormNd, N=3),
+               train_and_eval_differ=True,
+               skips=(
+                   # not supported on MPS backend
+                   DecorateInfo(expectedFailureMPS, 'TestModuleMPS', 'test_memory_format'),
+                   DecorateInfo(expectedFailureMPS, 'TestModuleMPS', 'test_non_contiguous_tensors'),
+                   DecorateInfo(expectedFailureMPS, 'TestModuleMPS', 'test_forward'),
+                   DecorateInfo(expectedFailureMPS, 'TestModuleMPS', 'test_non_contiguous'),
+                   DecorateInfo(expectedFailureMPS, 'TestModuleMPS', 'test_save_load'),
+                   # No channels_last support for InstanceNorm3d currently.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.LocalResponseNorm,
+               module_inputs_func=module_inputs_torch_nn_LocalResponseNorm,
+               skips=(
+                   # uses avg_pool3d which is not supported on MPS backend
+                   DecorateInfo(expectedFailureMPS, 'TestModule', 'test_memory_format'),
+                   DecorateInfo(expectedFailureMPS, 'TestModule', 'test_non_contiguous_tensors'),
+                   DecorateInfo(expectedFailureMPS, 'TestModule', 'test_forward'),
+                   DecorateInfo(expectedFailureMPS, 'TestModule', 'test_if_train_and_eval_modes_differ'),
+                   DecorateInfo(expectedFailureMPS, 'TestModule', 'test_non_contiguous'),
+                   DecorateInfo(expectedFailureMPS, 'TestModule', 'test_save_load'),)
+               ),
+    ModuleInfo(torch.nn.LayerNorm,
+               module_inputs_func=module_inputs_torch_nn_LayerNorm,
+               skips=(
+                   # No channels_last support for LayerNorm currently.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.RMSNorm,
+               module_inputs_func=module_inputs_torch_nn_RMSNorm,
+               ),
+    # TransformerEncoder takes the same inputs as TransformerEncoderLayer
+    ModuleInfo(torch.nn.TransformerEncoder,
+               train_and_eval_differ=True,
+               module_inputs_func=module_inputs_torch_nn_TransformerEncoder,
+               decorators=[
+                   # Not implemented for SDPA backward derivative
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad',
+                                device_type='cpu'),
+               ],
+               skips=(
+                   # No channels_last support for TransformerEncoderLayer currently.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # Doesn't support device / dtype kwargs directly because it is just a
+                   # container of TransformerEncoderLayers.
+                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_factory_kwargs'),)
+               ),
+    ModuleInfo(torch.nn.TransformerEncoderLayer,
+               train_and_eval_differ=True,
+               module_inputs_func=module_inputs_torch_nn_TransformerEncoderLayer,
+               decorators=[
+                   DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-4)}),
+                                'TestModule', 'test_non_contiguous_tensors',
+                                device_type='cpu', active_if=IS_WINDOWS),
+                   DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-4, rtol=2e-3)}),
+                                'TestModule', 'test_forward',
+                                device_type='mps'),
+                   # Not implemented for SDPA backward derivative
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad',
+                                device_type='cpu'),
+               ],
+               skips=(
+                   # No channels_last support for TransformerEncoderLayer currently.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.TransformerDecoderLayer,
+               module_inputs_func=module_inputs_torch_nn_TransformerDecoderLayer,
+               decorators=[
+                   # Not implemented for SDPA backward derivative
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad',
+                                device_type='cpu'),
+               ],
+               skips=(
+                   # No channels_last support for TransformerDecoderLayer currently.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.Transformer,
+               module_inputs_func=module_inputs_torch_nn_Transformer,
+               # Inputs are too large to run with slow gradcheck
+               # https://github.com/pytorch/pytorch/issues/117140
+               gradcheck_fast_mode=True,
+               decorators=[
+                   # Not implemented for SDPA backward derivative
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad',
+                                device_type='cpu'),
+               ],
+               skips=(
+                   # No channels_last support for Transformer currently.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.MultiheadAttention,
+               train_and_eval_differ=True,
+               module_inputs_func=module_inputs_torch_nn_MultiheadAttention,
+               skips=(
+                   # No channels_last support for MultiheadAttention currently.
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.Embedding,
+               module_inputs_func=module_inputs_torch_nn_Embedding,
+               decorators=[
+                   DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-4)}),
+                                'TestModule', 'test_non_contiguous_tensors',
+                                device_type='mps')],
+               skips=(
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.ReLU,
+               module_inputs_func=module_inputs_torch_nn_ReLU,
+               skips=None if _macos15_or_newer else (
+                   # Fails on backward check on MPS
+                   # See https://github.com/pytorch/pytorch/issues/107214
+                   DecorateInfo(
+                       unittest.expectedFailure,
+                       'TestModule',
+                       'test_memory_format',
+                       active_if=operator.itemgetter('training'),
+                       device_type='mps',
+                   ),)
+               ),
+    ModuleInfo(torch.nn.LeakyReLU,
+               module_inputs_func=module_inputs_torch_nn_LeakyReLU,
+               ),
+    ModuleInfo(torch.nn.ReLU6,
+               module_inputs_func=module_inputs_torch_nn_ReLU6,
+               skips=(
+                   # test fails on MPS backend and is being investigated.
+                   # See https://github.com/pytorch/pytorch/issues/100914
+                   DecorateInfo(skipMPS),)
+               ),
+    ModuleInfo(torch.nn.PReLU,
+               module_inputs_func=module_inputs_torch_nn_PReLU,
+               skips=(
+                   # test fails on MPS backend and is being investigated.
+                   # See https://github.com/pytorch/pytorch/issues/100914
+                   DecorateInfo(skipMPS),)
+               ),
+    ModuleInfo(torch.nn.RNNCell,
+               module_inputs_func=partial(module_inputs_torch_nn_RNN_GRU_Cell, is_rnn=True),
+               module_error_inputs_func=module_error_inputs_torch_nn_RNN_GRU_Cell,
+               ),
+    ModuleInfo(torch.nn.GRUCell,
+               module_inputs_func=module_inputs_torch_nn_RNN_GRU_Cell,
+               module_error_inputs_func=module_error_inputs_torch_nn_RNN_GRU_Cell,
+               ),
+    ModuleInfo(torch.nn.LSTMCell,
+               module_inputs_func=module_inputs_torch_nn_LSTMCell,
+               module_error_inputs_func=module_error_inputs_torch_nn_LSTMCell,
+               ),
+    ModuleInfo(torch.nn.Sigmoid,
+               module_inputs_func=module_inputs_torch_nn_Sigmoid,
+               skips=None if _macos15_or_newer else (
+                   # Fails on backward check on MPS
+                   # See https://github.com/pytorch/pytorch/issues/107214
+                   DecorateInfo(
+                       unittest.expectedFailure,
+                       'TestModule',
+                       'test_memory_format',
+                       active_if=operator.itemgetter('training'),
+                       device_type='mps',
+                   ),)
+               ),
+    ModuleInfo(torch.nn.LogSigmoid,
+               module_inputs_func=module_inputs_torch_nn_LogSigmoid,
+               skips=(
+                   # See #119108: tolerance issue
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward", device_type='mps', dtypes=[torch.float16]),)
+               ),
+    ModuleInfo(torch.nn.SiLU,
+               module_inputs_func=module_inputs_torch_nn_SiLU,
+               ),
+    ModuleInfo(torch.nn.Softmax,
+               module_inputs_func=module_inputs_torch_nn_Softmax,
+               ),
+    ModuleInfo(torch.nn.Softmax2d,
+               module_inputs_func=module_inputs_torch_nn_Softmax2d,
+               skips=(
+                   # no channels last support for Softmax2d currently
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # See #119108: tolerance issue
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward", device_type='mps', dtypes=[torch.float16]),)
+               ),
+    ModuleInfo(torch.nn.LogSoftmax,
+               module_inputs_func=module_inputs_torch_nn_LogSoftmax,
+               skips=(
+                   # no channels last support for LogSoftmax currently
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
+                   # See #119108: inf nan error
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward", device_type='mps', dtypes=[torch.float16]),)
+               ),
+    ModuleInfo(torch.nn.Softmin,
+               module_inputs_func=module_inputs_torch_nn_Softmin,
+               skips=(
+                   # no channels last support for Softmin currently
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
+               ),
+    ModuleInfo(torch.nn.Softplus,
+               module_inputs_func=module_inputs_torch_nn_Softplus,
+               skips=(
+                   # test fails on MPS backend and is being investigated.
+                   # See https://github.com/pytorch/pytorch/issues/100914
+                   DecorateInfo(skipMPS),)
+               ),
+    ModuleInfo(torch.nn.Softshrink,
+               module_inputs_func=module_inputs_torch_nn_Softshrink,
+               skips=(
+                   # not supported on MPS backend
+                   DecorateInfo(skipMPS),)
+               ),
+    ModuleInfo(torch.nn.Softsign,
+               module_inputs_func=module_inputs_torch_nn_Softsign,
+               ),
+    ModuleInfo(torch.nn.Tanh,
+               module_inputs_func=module_inputs_torch_nn_Tanh,
+               skips=None if _macos15_or_newer else (
+                   # Fails on backward check on MPS
+                   # See https://github.com/pytorch/pytorch/issues/107214
+                   DecorateInfo(
+                       unittest.expectedFailure,
+                       'TestModule',
+                       'test_memory_format',
+                       active_if=operator.itemgetter('training'),
+                       device_type='mps',
+                   ),)
+               ),
+    ModuleInfo(torch.nn.Tanhshrink,
+               module_inputs_func=module_inputs_torch_nn_Tanhshrink,
+               skips=None if _macos15_or_newer else (
+                   # Fails on backward check on MPS
+                   # See https://github.com/pytorch/pytorch/issues/107214
+                   DecorateInfo(
+                       unittest.expectedFailure,
+                       'TestModule',
+                       'test_memory_format',
+                       active_if=operator.itemgetter('training'),
+                       device_type='mps',
+                   ),)
+               ),
+    ModuleInfo(torch.nn.Threshold,
+               module_inputs_func=module_inputs_torch_nn_Threshold,
+               skips=(
+                   # test fails on MPS backend and is being investigated.
+                   # See https://github.com/pytorch/pytorch/issues/100914
+                   DecorateInfo(skipMPS),)
+               ),
+    ModuleInfo(torch.nn.Mish,
+               module_inputs_func=module_inputs_torch_nn_Mish,
+               skips=(
+                   # not supported on MPS backend
+                   DecorateInfo(skipMPS),)
+               ),
+    ModuleInfo(torch.nn.RNN,
+               train_and_eval_differ=True,
+               module_inputs_func=partial(module_inputs_torch_nn_RNN_GRU, is_rnn=True),
+               module_error_inputs_func=module_error_inputs_torch_nn_RNN_GRU,
+               decorators=rnn_gru_lstm_module_info_decorators
+               ),
+    ModuleInfo(torch.nn.GRU,
+               train_and_eval_differ=True,
+               module_inputs_func=partial(module_inputs_torch_nn_RNN_GRU, is_rnn=False),
+               module_error_inputs_func=module_error_inputs_torch_nn_RNN_GRU,
+               decorators=rnn_gru_lstm_module_info_decorators),
+    ModuleInfo(torch.nn.LSTM,
+               train_and_eval_differ=True,
+               module_inputs_func=module_inputs_torch_nn_LSTM,
+               module_error_inputs_func=module_error_inputs_torch_nn_RNN_GRU,
+               skips=(
+                   # LSTM with projections is not currently supported with MPS
+                   DecorateInfo(skipMPS),),
+               decorators=rnn_gru_lstm_module_info_decorators),
+    ModuleInfo(torch.nn.ReflectionPad1d,
+               module_inputs_func=module_inputs_torch_nn_ReflectionPad1d,
+               ),
+    ModuleInfo(torch.nn.ReflectionPad2d,
+               module_inputs_func=module_inputs_torch_nn_ReflectionPad2d,
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               skips=(
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format',
+                                device_type='cuda'),
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format',
+                                device_type='mps'),)
+               ),
+    ModuleInfo(torch.nn.ReflectionPad3d,
+               module_inputs_func=module_inputs_torch_nn_ReflectionPad3d,
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               skips=(
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format',
+                                device_type='cuda'),
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format',
+                                device_type='mps'),)
+               ),
+    ModuleInfo(torch.nn.ReplicationPad1d,
+               module_inputs_func=module_inputs_torch_nn_ReplicationPad1d,
+               ),
+    ModuleInfo(torch.nn.ReplicationPad2d,
+               module_inputs_func=module_inputs_torch_nn_ReplicationPad2d,
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               skips=(
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format',
+                                device_type='cuda'),
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format',
+                                device_type='mps'),)
+               ),
+    ModuleInfo(torch.nn.ReplicationPad3d,
+               module_inputs_func=module_inputs_torch_nn_ReplicationPad3d,
+               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+               skips=(
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format',
+                                device_type='cuda'),
+                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format',
+                                device_type='mps'),)
+               ),
+    ModuleInfo(torch.nn.SELU,
+               module_inputs_func=module_inputs_torch_nn_SELU,
+               skips=(
+                   # test fails on MPS backend and is being investigated.
+                   # See https://github.com/pytorch/pytorch/issues/100914
+                   DecorateInfo(skipMPS),)
+               ),
+    ModuleInfo(torch.nn.ZeroPad1d,
+               module_inputs_func=module_inputs_torch_nn_ZeroPad1d,
+               ),
+    ModuleInfo(torch.nn.ZeroPad2d,
+               module_inputs_func=module_inputs_torch_nn_ZeroPad2d,
+               skips=(
+                   # Fails with channels last test on MPS backend
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='mps'),)
+               ),
+    ModuleInfo(torch.nn.ZeroPad3d,
+               module_inputs_func=module_inputs_torch_nn_ZeroPad3d,
+               skips=(
+                   # Fails with channels last test on MPS backend
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='mps'),)
+               ),
+    ModuleInfo(torch.nn.CircularPad1d,
+               module_inputs_func=module_inputs_torch_nn_CircularPad1d,
+               module_error_inputs_func=module_error_inputs_torch_nn_Pad1d,
+               ),
+    ModuleInfo(torch.nn.CircularPad2d,
+               module_inputs_func=module_inputs_torch_nn_CircularPad2d,
+               module_error_inputs_func=module_error_inputs_torch_nn_Pad2d,
+               ),
+    ModuleInfo(torch.nn.CircularPad3d,
+               module_inputs_func=module_inputs_torch_nn_CircularPad3d,
+               module_error_inputs_func=module_error_inputs_torch_nn_Pad3d,
+               skips=(
+                   # Fails with channels last test on MPS backend
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),)
+               ),
+    ModuleInfo(torch.nn.ConstantPad1d,
+               module_inputs_func=module_inputs_torch_nn_ConstantPad1d,
+               ),
+    ModuleInfo(torch.nn.ConstantPad2d,
+               module_inputs_func=module_inputs_torch_nn_ConstantPad2d,
+               skips=(
+                   # Fails with channels last test on MPS backend
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='mps'),)
+               ),
+    ModuleInfo(torch.nn.ConstantPad3d,
+               module_inputs_func=module_inputs_torch_nn_ConstantPad3d,
+               skips=(
+                   # Fails with channels last test on MPS backend
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='mps'),)
+               )
+]
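+# Reading the entries above: each DecorateInfo names a generated test class and
+# test, and can be narrowed by device_type, dtypes, and an active_if predicate
+# evaluated against the test parametrization. As an illustrative example
+# (mirroring the entries above, not an additional entry), a decorator such as
+#     DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward",
+#                  device_type='mps', dtypes=[torch.float16])
+# marks only TestModule.test_forward on the MPS device with float16 inputs as an
+# expected failure; other device/dtype combinations run normally.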
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_mps.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_mps.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e93a3552dff8a1f6bad87d86340c277151456cb
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_mps.py
@@ -0,0 +1,1006 @@
+import unittest
+from collections.abc import Sequence
+from typing import Optional
+
+import torch
+
+from .common_utils import MACOS_VERSION
+from .opinfo.core import DecorateInfo, OpInfo
+
+
+if torch.backends.mps.is_available():
+
+    def mps_ops_modifier(
+        ops: Sequence[OpInfo],
+        device_type: Optional[str] = None,
+        xfail_exclusion: Optional[list[str]] = None,
+    ) -> Sequence[OpInfo]:
+        if xfail_exclusion is None:
+            xfail_exclusion = []
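+        # Descriptive overview: this helper appends MPS-specific skip/xfail
+        # DecorateInfos to the given OpInfos. The SUPPORTED_COMPLEX_OPS sets
+        # record which ops handle complex dtypes on MPS, and the *_XFAILLIST /
+        # *_SKIPLIST tables below group ops that need decorating (macOS version
+        # specific breakages, unimplemented ops, ops with undefined or random
+        # results, ...); the loop at the end of the function turns those entries
+        # into decorators, gated on MACOS_VERSION where needed. When a
+        # device_type argument is given, it overrides the device_type of every
+        # decorator added here (see addDecorator below).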
+
+        # Supported complex OPS
+        SUPPORTED_COMPLEX_OPS = {
+            "__radd__",
+            "__rmul__",
+            "__rsub__",
+            "__getitem__",
+            "_unsafe_masked_index",
+            "abs",
+            "add",
+            "alias_copy",
+            "argwhere",
+            "atleast_1d",
+            "atleast_2d",
+            "atleast_3d",
+            "as_strided",
+            "as_strided_copy",
+            "as_strided_scatter",
+            "asin",
+            "acos",
+            "atan",
+            "broadcast_tensors",
+            "broadcast_to",
+            "chalf",
+            "cfloat",
+            "chunk",
+            "clone",
+            "conj",
+            "conj_physical",
+            "contiguous",
+            "cos",
+            "cosh",
+            "diag",
+            "diag_embed",
+            "diagflat",
+            "diagonal",
+            "diagonal_copy",
+            "diagonal_scatter",
+            "divno_rounding_mode",
+            "dsplit",
+            "empty",
+            "empty_permuted",
+            "empty_strided",
+            "exp",
+            "expm1",
+            "exp2",
+            "expand",
+            "expand_as",
+            "expand_copy",
+            "flatten",
+            "fill",
+            "full",
+            "full_like",
+            "H",
+            "hsplit",
+            "imag",
+            "index_copy",
+            "index_select",
+            "isfinite",
+            "isinf",
+            "isreal",
+            "item",
+            "kron",
+            "linalg.diagonal",
+            "linalg.svd",
+            "log10",
+            "log1p",
+            "log2",
+            "log",
+            "mH",
+            "mT",
+            "masked_fill",
+            "masked_scatter",
+            "masked_select",
+            "meshgridlist_of_tensors",
+            "meshgridvariadic_tensors",
+            "movedim",
+            "mul",
+            "narrow",
+            "narrow_copy",
+            "neg",
+            "new_full",
+            "new_ones",
+            "new_zeros",
+            "nn.functional.conv1d",
+            "nn.functional.conv2d",
+            "nn.functional.conv_transpose1d",
+            "nn.functional.conv_transpose2d",
+            "nn.functional.conv_transpose3d",
+            "nn.functional.feature_alpha_dropoutwithout_train",
+            "nn.functional.padcircular",
+            "nn.functional.softsign",
+            "nn.functional.tanhshrink",
+            "nn.functional.unfold",
+            "nonzero",
+            "ones",
+            "ones_like",
+            "outer",
+            "permute",
+            "permute_copy",
+            "positive",
+            "randn",
+            "ravel",
+            "real",
+            "repeat_interleave",
+            "reshape_as",
+            "reshape",
+            "resolve_conj",
+            "resolve_neg",
+            "rsqrt",
+            "rsub",
+            "scalar_tensor",
+            "select",
+            "sgn",
+            "sigmoid",
+            "sin",
+            "sinc",
+            "sinh",
+            "slice",
+            "special.spherical_bessel_j0",
+            "special.entr",
+            "special.xlog1py",
+            "special.zeta",
+            "split",
+            "split_with_sizes",
+            "split_with_sizes_copy",
+            "splitlist_args",
+            "sqrt",
+            "squeeze",
+            "squeeze_copy",
+            "squeezemultiple",
+            "sub",
+            "svd",
+            "t",
+            "t_copy",
+            "tanh",
+            "tan",
+            "tensor_split",
+            "transpose",
+            "transpose_copy",
+            "tril",
+            "triu",
+            "true_divide",
+            "T",
+            "unbind",
+            "unbind_copy",
+            "unflatten",
+            "unfold",
+            "unfold_copy",
+            "unsafe_chunk",
+            "unsafe_split",
+            "unsqueeze",
+            "unsqueeze_copy",
+            "view_as",
+            "view_as_real",
+            "view",
+            "view_copy",
+            "vsplit",
+            "zero_",
+            "zeros",
+            "zeros_like",
+        }
+
+        AFTER_MACOS_14_0_SUPPORTED_COMPLEX_OPS = {
+            "__rdiv__",
+            "__rmatmul__",
+            "_chunk_cat",
+            "acosh",
+            "all",
+            "allclose",
+            "angle",
+            "any",
+            "addcdiv",
+            "addcmul",
+            "addmmdecomposed",
+            "addmv",
+            "atanh",
+            "bfloat16",
+            "bmm",
+            "bool",
+            "cartesian_prod",
+            "cat",
+            "char",
+            "column_stack",
+            "combinations",
+            "corrcoef",
+            "constant_pad_nd",
+            "cov",
+            "count_nonzero",
+            "diff",
+            "div",
+            "dot",
+            "dstack",
+            "einsum",
+            "eq",
+            "equal",
+            "eye",
+            "fft.fft",
+            "fft.fft2",
+            "fft.fftn",
+            "fft.fftshift",
+            "fft.ifft",
+            "fft.ifft2",
+            "fft.ifftn",
+            "fft.ifftshift",
+            "fft.irfftn",
+            "fft.irfft2",
+            "fft.irfft",
+            "fft.hfftn",
+            "fft.hfft2",
+            "fft.hfft",
+            "flip",
+            "fliplr",
+            "flipud",
+            "float",
+            "gradient",
+            "half",
+            "hstack",
+            "inner",
+            "int",
+            "isclose",
+            "isnan",
+            "ldexp",
+            "lerp",
+            "linalg.multi_dot",
+            "linalg.pinv",
+            "linspace",
+            "linspacetensor_overload",
+            "logical_and",
+            "logical_not",
+            "logical_or",
+            "logical_xor",
+            "logsumexp",
+            "long",
+            "masked.mean",
+            "masked.prod",
+            "masked.std",
+            "masked.sum",
+            "masked.var",
+            "masked.logsumexp",
+            "matmul",
+            "mean",
+            "mm",
+            "mv",
+            "ne",
+            "nn.functional.padconstant",
+            "nn.functional.padreflect",
+            "nn.functional.padreplicate",
+            "nn.functional.pixel_shuffle",
+            "nn.functional.pixel_unshuffle",
+            "nn.functional.rms_norm",
+            "pinverse",
+            "prod",
+            "reciprocal",
+            "roll",
+            "rot90",
+            "short",
+            "sinh",
+            "sqrt",
+            "square",
+            "stack",
+            "stft",
+            "sum",
+            "sum_to_size",
+            "tensordot",
+            "trace",
+            "trapz",
+            "trapezoid",
+            "vstack",
+            "where",
+            "byte",
+        }
+        # These ops worked on macOS 12 but are broken on macOS 13, see https://github.com/pytorch/pytorch/issues/85758
+        MACOS_BEFORE_13_3_XFAILLIST = {
+            # Failures due to precision issues caused by fast-math. These have been fixed in macOS 13.3+
+            "cdist": [torch.float32],
+            # CPU error: the CPU does not produce NaN for x/0.0
+            "atan2": [
+                torch.bool,
+                torch.int16,
+                torch.int32,
+                torch.int64,
+                torch.uint8,
+                torch.int8,
+            ],
+            # The tests below pass on macOS 12 as they fall back to CPU
+            # Argsort case using duplicate indices (undefined behaviour):
+            #  - CPU output: tensor([2546, 6917, 3181,  ..., 7128, 5133,   30], device='cpu')
+            #  - MPS output: tensor([2546, 6917, 3181,  ..., 7128,   30, 5133], device='mps:0')
+            # Elements from index 30 and 5133 are both equal.
+            # Since CPU is not using argsort with stable=True, these cases result in undefined behaviour.
+            "argsort": [torch.float16, torch.int8, torch.uint8, torch.bool],
+            # Same issue as `argsort` with duplicate indices. This test checks both the sorted values and the indices.
+            # The values of the sorted tensor match the CPU,
+            # but in case of the returned indices this results in undefined behaviour.
+            "sort": [torch.int8, torch.uint8, torch.bool, torch.float16],
+            # Unsupported dtypes
+            "cumsum": [torch.int64],
+            "cumprod": [torch.int64],
+            "cumulative_trapezoid": [torch.int64],
+            "masked.cumsum": [torch.int64],
+            "masked.cumprod": [torch.int64],
+            "linalg.vander": [torch.int64],
+            # Fail with `Expected 1.0 but got nan.` for empty tensors
+            # Caused by sample input at index 23: SampleInput(
+            #     input=Tensor[size=(), device="mps:0", dtype=torch.float32],
+            #     args=(0),
+            #     kwargs={'mask': 'Tensor[size=(), device="mps:0", dtype=torch.bool]'},
+            #     broadcasts_input=False, name='')
+            "masked.softmin": [torch.float32, torch.float16],
+            "masked.softmax": [torch.float32, torch.float16],
+            "masked.log_softmax": [torch.float32, torch.float16],
+        }
+
+        MACOS_AFTER_13_1_XFAILLIST = {
+            # Before macOS 13.2 it falls back to CPU and passes the forward pass
+            "grid_sampler_2d": [
+                torch.float32,
+                torch.float16,
+                torch.bfloat16,
+            ],  # Unsupported Border padding mode
+        }
+
+        MACOS_13_3_XFAILLIST = {
+            # Failure due to a precision issue for fp16:
+            # on both CPU and MPS there are test cases that can produce inf results
+            # 'nn.functional.pairwise_distance': [torch.float16],
+            # The tests below pass on macOS 12 as they fall back to CPU
+            # Argsort case using duplicate indices (undefined behaviour):
+            #  - CPU output: tensor([2546, 6917, 3181,  ..., 7128, 5133,   30], device='cpu')
+            #  - MPS output: tensor([2546, 6917, 3181,  ..., 7128,   30, 5133], device='mps:0')
+            # Elements from index 30 and 5133 are both equal.
+            # Since CPU is not using argsort with stable=True, these cases result in undefined behaviour.
+            "argsort": [
+                torch.float16,
+                torch.int8,
+                torch.uint8,
+                torch.bool,
+                torch.bfloat16,
+            ],
+            # Same issue as `argsort` with duplicate indices. This test checks both the sorted values and the indices.
+            # The values of the sorted tensor match the CPU,
+            # but in case of the returned indices this results in undefined behaviour.
+            "sort": [
+                torch.int8,
+                torch.uint8,
+                torch.bool,
+                torch.float16,
+                torch.bfloat16,
+            ],
+        }
+
+        MACOS_BEFORE_14_4_XFAILLIST = {
+            # These ops work fine in 14.4 but fail in 14.2 or 13.x
+            "fft.hfft2": [torch.complex64],
+        }
+
+        # These ops are not expected to work on the MPS backend
+        UNIMPLEMENTED_XFAILLIST = {
+            # Failures due to lack of op implementation on MPS backend
+            "logspace": None,
+            "logspacetensor_overload": None,
+            "linalg.eig": None,
+            "linalg.eigvals": None,
+            "put": None,
+            "cauchy_": None,
+            "cauchy": None,
+            "cholesky_inverse": None,
+            "cholesky_solve": None,
+            "frexp": None,
+            "gcd": None,
+            "geqrf": None,
+            "nn.functional.grid_sample": None,  # Unsupported Border padding mode
+            "heaviside": None,
+            "igamma": None,
+            "igammac": None,
+            "index_reduceprod": None,
+            "index_reducemean": None,
+            "index_reduceamax": None,
+            "index_reduceamin": None,
+            "kthvalue": None,
+            "lcm": None,
+            "linalg.cond": None,
+            "linalg.eigh": None,
+            "linalg.eigvalsh": None,
+            "linalg.householder_product": None,
+            "linalg.ldl_factor": None,
+            "linalg.ldl_factor_ex": None,
+            "linalg.ldl_solve": None,
+            "linalg.lstsq": None,
+            "linalg.lstsqgrad_oriented": None,
+            "linalg.lu": None,
+            "linalg.lu_solve": None,
+            "linalg.matrix_norm": [torch.float32],
+            "linalg.norm": [torch.float32],
+            "linalg.normsubgradients_at_zero": [torch.float32],
+            "linalg.qr": None,
+            "linalg.svdvals": None,
+            "linalg.vecdot": None,
+            "logcumsumexp": None,
+            "lu_solve": None,
+            "masked.median": None,
+            "matrix_exp": None,
+            "mode": None,
+            "native_dropout_backward": None,
+            "normnuc": None,
+            "nn.functional.fractional_max_pool2d": None,
+            "nn.functional.fractional_max_pool3d": None,
+            "nn.functional.adaptive_avg_pool3d": None,
+            "nn.functional.adaptive_max_pool3d": None,
+            "nn.functional.interpolatearea": None,
+            "nn.functional.interpolatebicubic": [torch.uint8],
+            "nn.functional.max_unpool1dgrad": None,
+            "nn.functional.max_unpool2dgrad": None,
+            "nn.functional.max_unpool3dgrad": None,
+            "nn.functional.avg_pool3d": None,
+            "nn.functional.ctc_loss": None,
+            "nn.functional.embedding_bag": None,
+            "nn.functional.max_pool3d": None,
+            "nn.functional.max_unpool1d": None,
+            "nn.functional.max_unpool2d": None,
+            "nn.functional.max_unpool3d": None,
+            "nn.functional.multi_margin_loss": None,
+            "nn.functional.multilabel_margin_loss": None,
+            "nn.functional.pdist": None,
+            "nn.functional.rrelu": None,
+            "nn.functional.norm": None,
+            "ormqr": None,
+            "pca_lowrank": None,
+            "qr": None,
+            "scatter_reduceamax": [torch.int32, torch.int64]
+            if MACOS_VERSION < 15.0
+            else [torch.int64],
+            "scatter_reduceamin": [torch.int32, torch.int64]
+            if MACOS_VERSION < 15.0
+            else [torch.int64],
+            "segment_reduce": None,
+            "_segment.reduce": None,
+            "segment.reduce": None,
+            "segment_reduce_offsets": None,
+            "_segment_reduce_offsets": None,
+            "_segment_reduce_lengths": None,
+            "_segment_reducelengths": None,
+            "_segment_reduceoffsets": None,
+            "sparse.mm": None,
+            "sparse.sampled_addmm": None,
+            "sparse.mmreduce": None,
+            "special.airy_ai": None,
+            "special.erfcx": None,
+            "special.laguerre_polynomial_l": None,
+            "special.log_ndtr": None,
+            "special.ndtri": None,
+            "svd_lowrank": None,
+            "symeig": None,
+            "take": None,
+            "to": None,
+            "to_sparse": None,
+            "unique": None,
+            "vdot": None,
+            "segment_reduce_": None,
+            "_upsample_bilinear2d_aa": [torch.uint8],  # uint8 is for CPU only
+            "_upsample_bicubic2d_aa": [torch.uint8],  # uint8 is for CPU only
+            "geometric": None,
+            "geometric_": None,
+            "log_normal_": None,
+            "log_normal": None,
+            "cdouble": None,
+            "double": None,
+            "nn.functional.softminwith_dtype": None,
+            "log_softmaxwith_dtype": None,
+            "softmaxwith_dtype": None,
+            "float_power": None,
+            "linalg.matrix_rankhermitian": None,
+            "linalg.pinvhermitian": None,
+            "nonzero_static": None,
+            # MPS: input sizes must be divisible by output sizes
+            "nn.functional.adaptive_avg_pool1d": None,
+            "nn.functional.adaptive_avg_pool2d": None,
+            # Convolution for integral types is not supported on MPS
+            "nn.functional.conv1d": [torch.int64],
+            "nn.functional.conv2d": [torch.int64],
+            "nn.functional.conv3d": [torch.int64],
+            "nn.functional.conv_transpose1d": [torch.int64],
+            "nn.functional.conv_transpose2d": [torch.int64, torch.bfloat16],
+            "nn.functional.conv_transpose3d": [
+                torch.int64,
+                torch.bfloat16,
+                torch.float16,
+            ],
+            # Unsupported dtypes
+            "dot": [torch.int64] if MACOS_VERSION < 14.0 else [],
+            "histc": [torch.float16, torch.bfloat16],
+            "index_add": [torch.int64],
+            # GEMM on MPS is not supported for integral types
+            "nn.functional.linear": [
+                torch.int16,
+                torch.int32,
+                torch.int64,
+                torch.uint8,
+                torch.int8,
+            ],
+            "addmmdecomposed": [
+                torch.int16,
+                torch.int32,
+                torch.int64,
+                torch.uint8,
+                torch.int8,
+            ],
+            "addbmm": [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
+            "addmm": [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
+            "baddbmm": [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
+            "mat": [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
+            "matmul": [torch.int64] if MACOS_VERSION < 14.0 else [],
+            "__rmatmul__": [torch.int64] if MACOS_VERSION < 14.0 else [],
+            # returned output on CPU is float64
+            "bincount": [
+                torch.int16,
+                torch.int32,
+                torch.int64,
+                torch.uint8,
+                torch.int8,
+            ],
+            # round does not work properly for float16 and bfloat16
+            "round": [torch.float16, torch.bfloat16],
+            "rounddecimals_0": [torch.bfloat16],
+            # atomic operations not supported
+            "_unsafe_masked_index_put_accumulate": [
+                torch.int8,
+                torch.uint8,
+                torch.int16,
+                torch.int64,
+            ],
+        }
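+        # Note on the tables above and below: a value of None means the skip or
+        # expected failure is not restricted to particular dtypes (DecorateInfo
+        # with dtypes=None applies to all dtypes), while a list limits the
+        # decoration to the listed dtypes only.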
+
+        if MACOS_VERSION < 14.0:
+            # FFT and BFloat16 support were added in macOS 14
+            UNIMPLEMENTED_XFAILLIST.update(
+                {
+                    "bfloat16": None,
+                    "fft.fft": None,
+                    "fft.fft2": None,
+                    "fft.fftn": None,
+                    "fft.hfft": None,
+                    "fft.hfft2": None,
+                    "fft.hfftn": None,
+                    "fft.ifft": None,
+                    "fft.ifft2": None,
+                    "fft.ifftn": None,
+                    "fft.ihfft": None,
+                    "fft.ihfft2": None,
+                    "fft.ihfftn": None,
+                    "fft.irfft": None,
+                    "fft.irfft2": None,
+                    "fft.irfftn": None,
+                    "fft.rfft": None,
+                    "fft.rfft2": None,
+                    "fft.rfftn": None,
+                    "stft": None,
+                    # TestConsistencyCPU.test_output_match_isin_cpu fails for integer dtypes;
+                    # not reproducible on later OS versions. An assert was added to the op when used on < 14.0
+                    "isin": [
+                        torch.int64,
+                        torch.int32,
+                        torch.int16,
+                        torch.uint8,
+                        torch.int8,
+                    ],
+                    "nn.functional.max_pool2d": [torch.uint8],
+                }
+            )
+
+        if MACOS_VERSION < 15.0:
+            UNIMPLEMENTED_XFAILLIST.update(
+                {
+                    "quantile": None,
+                    "nanquantile": None,
+                }
+            )
+
+        UNDEFINED_XFAILLIST = {
+            # Top 60 operators
+            # topk fails with duplicate indices
+            "topk": [
+                torch.int16,
+                torch.int32,
+                torch.int64,
+                torch.uint8,
+                torch.int8,
+            ],
+            # Failures because these ops generate random output using the
+            # Philox engine, which does not match the CPU results
+            "multinomial": [
+                torch.float16,
+                torch.float32,
+                torch.bfloat16,
+            ],  # random results
+            "uniform": [torch.float16, torch.float32, torch.bfloat16],
+            "rand_like": [torch.float16, torch.float32, torch.bfloat16],
+            "randint": None,
+            "randint_like": None,
+            "randn": None,
+            "randn_like": None,
+            "bernoulli": [torch.float16, torch.float32, torch.bfloat16],
+            "exponential": [torch.float16, torch.float32, torch.bfloat16],
+            "nn.functional.feature_alpha_dropoutwith_train": [
+                torch.float16,
+                torch.float32,
+                torch.bfloat16,
+            ],
+            "normal": [torch.float16, torch.float32, torch.bfloat16],
+            "normalin_place": [torch.float16, torch.float32, torch.bfloat16],
+            "normalnumber_mean": [torch.float16, torch.float32, torch.bfloat16],
+            "nn.functional.alpha_dropout": [
+                torch.float16,
+                torch.float32,
+                torch.bfloat16,
+            ],
+            "nn.functional.dropout": [torch.float16, torch.float32, torch.bfloat16],
+            "nn.functional.dropout2d": [torch.float16, torch.float32, torch.bfloat16],
+            "nn.functional.dropout3d": [torch.float16, torch.float32, torch.bfloat16],
+            # See https://github.com/pytorch/pytorch/issues/111479
+            "nn.functional.multi_head_attention_forward": [
+                torch.float32,
+                torch.float16,
+                torch.bfloat16,
+            ],
+            "index_put": [
+                torch.uint8,
+                torch.int8,
+                torch.int16,
+                torch.int64,
+            ],
+            # zero to negative integer powers are undefined
+            "__rpow__": [torch.int8, torch.int16, torch.int32, torch.int64],
+            "resize_": [torch.float16, torch.float32, torch.bfloat16],
+            "resize_as_": [torch.float16, torch.float32, torch.bfloat16],
+            # CPU Errors:
+            "addr": [
+                torch.bool,
+                torch.int16,
+                torch.int32,
+                torch.int64,
+                torch.uint8,
+                torch.int8,
+            ],  # "addmv_impl_cpu" not implemented for 'Half'
+            "as_stridedpartial_views": None,  # cpu result off, showing random values
+            # random results
+            # mps vs cpu:
+            # Mismatched elements: 40 / 96 (41.7%)
+            # Greatest absolute difference: 17.892311096191406 at index (1, 0, 2) (up to 1e-05 allowed)
+            # Greatest relative difference: inf at index (1, 0, 0) (up to 1.3e-06 allowed)
+            # cuda(2.0.0.dev20230301+cu117) vs cpu:
+            # Mismatched elements: 56 / 96 (58.3%)
+            # Greatest absolute difference: 17.892311096191406 at index (1, 0, 2) (up to 1e-05 allowed)
+            # Greatest relative difference: inf at index (1, 0, 0) (up to 1.3e-06 allowed)
+            "nn.functional.scaled_dot_product_attention": [
+                torch.float32,
+                torch.float16,
+                torch.bfloat16,
+            ],
+        }
+
+        ON_MPS_XFAILLIST = {
+            # Failures due to lack of implementation of downstream functions on MPS backend
+            # TODO: remove these once the downstream function 'aten::_linalg_svd.U' has been implemented
+            "linalg.matrix_rank": None,
+            # Exception: Caused by `torch.arange(-8.001, -4.0, dtype=torch.uint8, device="mps")`
+            "arange": [torch.uint8],
+        }
+
+        EMPTY_OPS_SKIPLIST = {
+            # These ops fill tensors with uninitialized data, causing mismatches with CPU.
+            # They occasionally happen to match, so they are skipped rather than marked as expected failures.
+            # See https://github.com/pytorch/pytorch/issues/100175
+            "new_empty": None,
+            "new_empty_strided": None,
+            "empty_strided": None,
+            # CPU: empty returns all 0's and there is a mismatch with the MPS
+            # allocation (macOS 13). According to
+            # https://pytorch.org/docs/2.0/generated/torch.empty.html,
+            # the returned tensor's values are intentionally uninitialized, so a mismatch is expected.
+            "empty": None,
+            "empty_like": None,
+            "empty_permuted": None,
+        }
+
+        SKIPLIST = {
+            # Unsupported
+            # This doesn't work on M1, but partially works on M2, with the exception of torch.float16
+            "nn.functional.conv3d": None,
+        }
+
+        def addDecorator(op: OpInfo, d: DecorateInfo) -> None:
+            if device_type is not None:
+                d.device_type = device_type
+
+            op.decorators = op.decorators + (d,)
+
+        for op in ops:
+            key = op.name + op.variant_test_name
+            if key in EMPTY_OPS_SKIPLIST:
+                addDecorator(
+                    op,
+                    DecorateInfo(
+                        unittest.skip("Skipping empty ops."),
+                        dtypes=EMPTY_OPS_SKIPLIST[key],
+                    ),
+                )
+            if key in SKIPLIST:
+                addDecorator(
+                    op, DecorateInfo(unittest.skip("Skipped!"), dtypes=SKIPLIST[key])
+                )
+            for xfaillist in [
+                UNIMPLEMENTED_XFAILLIST,
+                UNDEFINED_XFAILLIST,
+                ON_MPS_XFAILLIST,
+            ]:
+                if key in xfaillist and key not in xfail_exclusion:
+                    addDecorator(
+                        op,
+                        DecorateInfo(unittest.expectedFailure, dtypes=xfaillist[key]),
+                    )
+
+            if (
+                key in MACOS_BEFORE_14_4_XFAILLIST
+                and key not in xfail_exclusion
+                and (MACOS_VERSION < 14.4)
+            ):
+                addDecorator(
+                    op,
+                    DecorateInfo(
+                        unittest.expectedFailure,
+                        dtypes=MACOS_BEFORE_14_4_XFAILLIST[key],
+                    ),
+                )
+
+            if (
+                key in MACOS_BEFORE_13_3_XFAILLIST
+                and key not in xfail_exclusion
+                and (torch.backends.mps.is_macos13_or_newer() and MACOS_VERSION < 13.3)
+            ):
+                addDecorator(
+                    op,
+                    DecorateInfo(
+                        unittest.expectedFailure,
+                        dtypes=MACOS_BEFORE_13_3_XFAILLIST[key],
+                    ),
+                )
+
+            if (
+                key in MACOS_AFTER_13_1_XFAILLIST
+                and key not in xfail_exclusion
+                and torch.backends.mps.is_macos13_or_newer(2)
+            ):
+                addDecorator(
+                    op,
+                    DecorateInfo(
+                        unittest.expectedFailure, dtypes=MACOS_AFTER_13_1_XFAILLIST[key]
+                    ),
+                )
+
+            if (
+                key in MACOS_13_3_XFAILLIST
+                and key not in xfail_exclusion
+                and (MACOS_VERSION >= 13.3)
+            ):
+                addDecorator(
+                    op,
+                    DecorateInfo(
+                        unittest.expectedFailure, dtypes=MACOS_13_3_XFAILLIST[key]
+                    ),
+                )
+
+            # If the op is not supported for complex types, expect it to fail
+            if key not in SUPPORTED_COMPLEX_OPS and (
+                key not in AFTER_MACOS_14_0_SUPPORTED_COMPLEX_OPS
+                or MACOS_VERSION < 14.0
+            ):
+                addDecorator(
+                    op,
+                    DecorateInfo(
+                        unittest.expectedFailure,
+                        dtypes=[torch.complex32, torch.complex64],
+                    ),
+                )
+
+        return ops
+
+    def mps_ops_grad_modifier(ops: Sequence[OpInfo]) -> Sequence[OpInfo]:
+        XFAILLIST_GRAD = {
+            # Unimplemented ops
+            "_segment_reduce": [torch.float16, torch.float32],
+            "_chunk_cat": [torch.float16, torch.float32],
+            "_upsample_bilinear2d_aa": None,  # `_upsample_bilinear2d_aa_backward_out` not implemented for MPS
+            "_upsample_bicubic2d_aa": None,  # `_upsample_bilinear2d_aa_backward_out` not implemented for MPS
+            "sparse.mmreduce": [torch.float32],  # csr not supported
+            "unique_consecutive": [torch.float16, torch.float32],
+            "scalar_tensor": [torch.float16, torch.float32],
+            "cdist": [torch.float32],
+            "masked.scatter": [torch.float16, torch.float32],
+            "index_fill": [torch.float16, torch.float32],  # missing `aten::_unique`.
+            "linalg.solve": [torch.float16, torch.float32],  # missing `aten::lu_solve`.
+            "linalg.solve_ex": [
+                torch.float16,
+                torch.float32,
+            ],  # missing `aten::lu_solve`.
+            "linalg.tensorsolve": [
+                torch.float16,
+                torch.float32,
+            ],  # missing `aten::lu_solve`.
+            "linalg.det": [torch.float16, torch.float32],  # missing aten::lu_solve.out
+            "linalg.slogdet": [
+                torch.float16,
+                torch.float32,
+            ],  # missing aten::lu_solve.out
+            "logdet": [torch.float16, torch.float32],  # missing aten::lu_solve.out
+            "aminmax": [torch.float32, torch.float16],
+            "special.i1": [torch.float16],  # "i1_backward" not implemented for 'Half'
+            "special.i1e": [torch.float16],  # "i1e_backward" not implemented for 'Half'
+            # Correctness issues
+            "atanh": [torch.float32],
+            # Random output
+            "exponential": [torch.float16, torch.float32],
+            # CPU errors
+            # derivative for zeta is not implemented
+            "special.zeta": None,
+            # derivative for aten::nextafter is not implemented on CPU
+            "nextafter": None,
+            # derivative for aten::floor_divide is not implemented on CPU
+            "floor_divide": [torch.float16, torch.float32],
+            # derivative for aten::narrow_copy is not implemented on CPU
+            "narrow_copy": [torch.float16, torch.float32],
+            # derivative for aten::_histogramdd_from_bin_cts is not implemented on CPU
+            "histogramdd": [torch.float16, torch.float32],
+            # derivative for aten::histogram is not implemented
+            "histogram": [torch.float16, torch.float32],
+            # 'bool' object is not iterable
+            "allclose": [torch.float16, torch.float32],
+            "equal": [torch.float16, torch.float32],
+            # 'float' object is not iterable
+            "item": [torch.float16, torch.float32],
+            # "smooth_l1_backward_cpu_out" not implemented for 'Half'
+            "nn.functional.smooth_l1_loss": [torch.float16],
+            # cpu error: grad requires non-empty inputs
+            "randn": [torch.float16, torch.float32],
+            "signal.windows.bartlett": [torch.float32],
+            "signal.windows.blackman": [torch.float32],
+            "signal.windows.cosine": [torch.float32],
+            "signal.windows.exponential": [torch.float32],
+            "signal.windows.gaussian": [torch.float32],
+            "signal.windows.general_cosine": [torch.float32],
+            "signal.windows.general_hamming": [torch.float32],
+            "signal.windows.hamming": [torch.float32],
+            "signal.windows.hann": [torch.float32],
+            "signal.windows.kaiser": [torch.float32],
+            "signal.windows.nuttall": [torch.float32],
+            "eye": [torch.float16, torch.float32],
+            # round not working properly for float16
+            "round": [torch.float16],
+            # topk fails with duplicate indices
+            "topk": [torch.float16],
+        }
+
+        MACOS_BEFORE_13_3_XFAILLIST_GRAD = {
+            # Failures due to precision issues (possibly fast-math). These have been fixed in macOS 14
+            "masked.softmin": [torch.float32, torch.float16],
+            "masked.softmax": [torch.float32, torch.float16],
+            "masked.log_softmax": [torch.float32, torch.float16],
+            "atanh": [torch.float16],
+            "triangular_solve": [torch.float32],
+            # Unsupported border padding mode; the forward pass succeeds because it falls back to CPU
+            "grid_sampler_2d": [torch.float32, torch.float16, torch.bfloat16],
+            # Same issue as `argsort` and `sort` with duplicate elements (undefined behaviour).
+            # The forward pass passes since `msort` doesn't return the indices, just the values, which match the CPU.
+            # On the backward pass for `sort` both are used (values and indices), resulting in a mismatch between CPU and MPS.
+            # Running `msort` with stable `sort` passes.
+            "msort": [torch.float16],
+        }
+
+        SKIPLIST_GRAD = {
+            "nn.functional.pairwise_distance": [torch.float16],
+            # failed assertion `destination datatype must be fp32'
+            "nn.functional.conv1d": [torch.float16],
+            "nn.functional.conv2d": [torch.float16],
+            "nn.functional.conv3d": [torch.float16],
+            "nn.functional.conv_transpose1d": [torch.float16],
+            "nn.functional.conv_transpose2d": [torch.float16],
+            "nn.functional.conv_transpose3d": [torch.float16],
+        }
+
+        MACOS_13_3_XFAILLIST_GRAD = {
+            # Same issue as `argsort` and `sort` with duplicate elements (undefined behaviour).
+            # The forward pass passes since `msort` doesn't return the indices, just the values, which match the CPU.
+            # On the backward pass for `sort` both are used (values and indices), resulting in a mismatch between CPU and MPS.
+            # Running `msort` with stable `sort` passes.
+            "msort": [torch.float16],
+        }
+
+        ON_MPS_XFAILLIST = {
+            # Failures due to lack of implementation of downstream functions on MPS backend
+            # TODO: remove these once the downstream function 'aten::_linalg_svd.U' has been implemented
+            "linalg.matrix_rank": None,
+            # Exception: Caused by sample input at index 3 on MPS
+            "nn.functional.conv3d": [torch.float32],
+        }
+
+        def addDecorator(op: OpInfo, d: DecorateInfo) -> None:
+            op.decorators = op.decorators + (d,)
+
+        for op in ops:
+            key = op.name + op.variant_test_name
+            if key in XFAILLIST_GRAD:
+                addDecorator(
+                    op,
+                    DecorateInfo(unittest.expectedFailure, dtypes=XFAILLIST_GRAD[key]),
+                )
+
+            if key in SKIPLIST_GRAD:
+                addDecorator(op, DecorateInfo(unittest.skip, dtypes=SKIPLIST_GRAD[key]))
+
+            if key in ON_MPS_XFAILLIST:
+                addDecorator(
+                    op,
+                    DecorateInfo(
+                        unittest.expectedFailure, dtypes=ON_MPS_XFAILLIST[key]
+                    ),
+                )
+
+            if key in MACOS_BEFORE_13_3_XFAILLIST_GRAD and (
+                torch.backends.mps.is_macos13_or_newer() and MACOS_VERSION < 13.3
+            ):
+                addDecorator(
+                    op,
+                    DecorateInfo(
+                        unittest.expectedFailure,
+                        dtypes=MACOS_BEFORE_13_3_XFAILLIST_GRAD[key],
+                    ),
+                )
+
+            if key in MACOS_13_3_XFAILLIST_GRAD and (MACOS_VERSION >= 13.3):
+                addDecorator(
+                    op,
+                    DecorateInfo(
+                        unittest.expectedFailure, dtypes=MACOS_13_3_XFAILLIST_GRAD[key]
+                    ),
+                )
+        return ops
+
+    def mps_ops_error_inputs_modifier(ops: Sequence[OpInfo]) -> Sequence[OpInfo]:
+        # Error input samples do not take a dtype argument.
+        XFAILLIST = {
+            # Exceptions are not raised
+            "__rmod__",
+            "__rsub__",
+            "__rpow__",
+            "bernoulli",
+            "clamp_max",
+            "clamp_min",
+            "masked_scatter",
+            # unsupported float64 dtype
+            "cat",
+            "complex",
+            "multinomial",
+            "nn.functional.conv1d",
+            "nn.functional.conv2d",
+            "nn.functional.conv3d",
+            "gather",
+            "scatter",
+            "scatter_add",
+            # MPS does not support tensor dimensions > 16
+            "amax",
+            "amin",
+            "aminmax",
+            # memory overlapping checks
+            "index_select",
+            # unimplemented
+            "logcumsumexp",
+        }
+
+        def addDecorator(op: OpInfo, d: DecorateInfo) -> None:
+            op.decorators = op.decorators + (d,)
+
+        for op in ops:
+            key = op.name + op.variant_test_name
+            if key in XFAILLIST:
+                addDecorator(op, DecorateInfo(unittest.expectedFailure))
+
+        return ops
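+
+    # Illustrative usage sketch (not part of the original module): each modifier above
+    # takes a sequence of OpInfo entries, appends skip/xfail DecorateInfo objects to
+    # op.decorators, and returns the sequence. A test file would typically apply one of
+    # them to (a copy of) an OpInfo database before instantiating device tests, e.g.
+    # (hypothetical wiring; the exact names and imports below are assumptions):
+    #
+    #     from copy import deepcopy
+    #     from torch.testing._internal.common_methods_invocations import op_db
+    #
+    #     mps_grad_op_db = mps_ops_grad_modifier(deepcopy(op_db))
+    #     mps_error_op_db = mps_ops_error_inputs_modifier(deepcopy(op_db))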
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_nn.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_nn.py
new file mode 100644
index 0000000000000000000000000000000000000000..135cc6a7bd66d3b1bd3c2fde6cb07a2c0af7c46a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_nn.py
@@ -0,0 +1,3993 @@
+# mypy: ignore-errors
+
+from abc import abstractmethod
+import tempfile
+import unittest
+
+from copy import deepcopy
+from functools import reduce, partial
+from itertools import product
+from operator import mul
+
+
+import torch
+import torch.cuda
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import _reduction as _Reduction
+from torch.testing._internal.common_utils import TestCase, to_gpu, freeze_rng_state, is_iterable, \
+    gradcheck, gradgradcheck, set_default_dtype, skipIfTorchDynamo, TEST_WITH_ROCM
+from torch.testing._internal.common_cuda import TEST_CUDA, SM90OrLater
+from torch.autograd.gradcheck import _get_numerical_jacobian, _iter_tensors
+from torch.autograd import Variable
+from torch.types import _TensorOrTensors
+import torch.backends.cudnn
+
+from typing import Callable, Union, Any
+from collections.abc import Sequence
+
+TemporaryFile = tempfile.TemporaryFile
+PRECISION = 1e-5
+
+
+def get_reduction(m):
+    result = getattr(m, 'reduction', None)
+    if result is None:
+        result = _Reduction.legacy_get_string(getattr(m, 'sizeAverage', None), True, emit_warning=False)
+    assert result is not None
+    return result
+
+
+def get_weight(m):
+    result = getattr(m, 'weight', None)
+    if result is not None:
+        return result
+    return getattr(m, 'weights', None)
+
+# NOTE [How to check NN module / functional API parity between Python and C++ frontends]
+#
+# The way to check API parity is to add parity tests for the NN module / functional of interest.
+# Here are the detailed steps:
+#
+# For NN module:
+# 1. Make sure you already have a test dict with the module configuration you want to test.
+# 2. Add `cpp_constructor_args` entry to the test dict, with its value exactly matching
+#    the Python module constructor arguments. For example, if in the test dict we pass
+#    `(10, 8)` to `torch.nn.Linear` constructor, then we should pass `torch::nn::LinearOptions(10, 8)`
+#    as the corresponding C++ constructor argument to `torch::nn::Linear`.
+# 3. If in the process of performing the above step you referenced any variables
+#    in the `cpp_constructor_args` entry, you must add `cpp_var_map` entry
+#    to the test dict to make sure that those variables are populated with the right Python values.
+#    For example, if the Python constructor call is
+#    `torch.nn.FractionalMaxPool2d(2, output_ratio=0.5, _random_samples=random_samples)`,
+#    the corresponding C++ constructor argument is
+#    `torch::nn::FractionalMaxPool2dOptions(2).output_ratio(0.5)._random_samples(random_samples)`,
+#    and the `cpp_var_map` entry must be
+#    `{'random_samples': random_samples}` in order to populate the C++ variable `random_samples`
+#    used in the C++ constructor argument with the Python tensor value `random_samples`.
+#
+# For NN functional:
+# 1. Make sure you already have a test dict with the functional configuration you want to test.
+# 2. If the test dict's `constructor` entry looks like `wrap_functional(F.some_functional_name, ...)`,
+#    then you must add `cpp_options_args` entry to the test dict, with its value exactly matching the Python
+#    functional optional arguments. For example, if the test dict's `constructor` entry is
+#    `wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest')`,
+#    then the `cpp_options_args` entry should be
+#    "F::InterpolateFuncOptions().size(std::vector({12})).scale_factor(std::nullopt).mode(torch::kNearest)".
+# 3. Otherwise, if the test dict's `constructor` entry looks like
+#    `wrap_functional(lambda i: F.some_functional_name(...))`,
+#    then you must add `cpp_function_call` entry to the test dict, with its value exactly matching the Python
+#    functional function call. For example, if the test dict's `constructor` entry is
+#    `wrap_functional(lambda i: F.poisson_nll_loss(i, t.type_as(i), reduction='none'))`,
+#    then the `cpp_function_call` entry should be
+#    "F::poisson_nll_loss(i, t.to(i.options()), F::PoissonNLLLossFuncOptions().reduction(torch::kNone))".
+# 4. If in the process of performing the above two steps you referenced any variables
+#    in the `cpp_options_args` or `cpp_function_call` entry, you must
+#    add `cpp_var_map` entry to the test dict to make sure that those variables
+#    are populated with the right Python values. For example, if the test dict's `constructor` entry is
+#    `wrap_functional(lambda i: F.poisson_nll_loss(i, t.type_as(i), reduction='none'))`,
+#    then the `cpp_function_call` entry should be
+#    "F::poisson_nll_loss(i, t.to(i.options()), F::PoissonNLLLossFuncOptions().reduction(torch::kNone))".
+#    Notice that there are two variables `i` and `t` that need to have their values provided,
+#    and the way to do so is to add a `cpp_var_map` entry: `cpp_var_map={'i': '_get_input()', 't': t}`.
+#    (Note that for `i`, since we want it to take the Python input value, we pass '_get_input()' string as value
+#    and the C++ parity test mechanism will populate `i` with the Python input value correctly.)
+#
+# There are also a few optional flags in the test dict to control the C++ parity test behavior:
+#
+# - `test_cpp_api_parity`: if `False`, skips the C++ parity test for this test dict. Default: True.
+# - `has_parity`: if `False`, expects this test dict to fail the C++ parity test. Default: True.
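+#
+# As a quick orientation (an illustrative summary that mirrors real entries defined further
+# below in this file), an NN-module test dict with C++ parity information looks like:
+#
+#     dict(
+#         module_name='Linear',
+#         constructor_args=(10, 8),
+#         cpp_constructor_args='torch::nn::LinearOptions(10, 8)',
+#         input_size=(4, 10),
+#     )
+#
+# and an NN-functional test dict pairs `wrap_functional` with a `cpp_function_call`
+# plus a `cpp_var_map` for any referenced variables:
+#
+#     dict(
+#         fullname='PoissonNLLLoss_no_reduce',
+#         constructor=wrap_functional(
+#             lambda i: F.poisson_nll_loss(i, t.type_as(i), reduction='none')),
+#         cpp_function_call='F::poisson_nll_loss('
+#                           'i, t.to(i.options()), F::PoissonNLLLossFuncOptions().reduction(torch::kNone))',
+#         cpp_var_map={'i': '_get_input()', 't': t},
+#     )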
+
+
+module_tests = [
+    dict(
+        module_name='Linear',
+        constructor_args=(10, 8),
+        cpp_constructor_args='torch::nn::LinearOptions(10, 8)',
+        input_size=(4, 10),
+        reference_fn=lambda i, p, _: torch.mm(i, p[0].t()) + p[1].view(1, -1).expand(4, 8),
+        with_tf32=True,
+        tf32_precision=0.005,
+        default_dtype=torch.double,
+    ),
+    dict(
+        module_name='Linear',
+        constructor_args=(10, 8, False),
+        cpp_constructor_args='torch::nn::LinearOptions(10, 8).bias(false)',
+        input_size=(4, 10),
+        desc='no_bias',
+        reference_fn=lambda i, p, _: torch.mm(i, p[0].t()),
+        with_tf32=True,
+        tf32_precision=0.05 if TEST_WITH_ROCM else 0.005,
+        default_dtype=torch.double,
+    ),
+    dict(
+        module_name='RReLU',
+        input_size=(1, 2, 2),
+        test_cuda=False,
+        default_dtype=torch.double,
+    ),
+    dict(
+        module_name='RReLU',
+        constructor_args=(0.1, 0.9),
+        cpp_constructor_args='torch::nn::RReLUOptions().lower(0.1).upper(0.9)',
+        input_size=(4, 4, 5),
+        desc='with_up_down',
+        test_cuda=False,
+        default_dtype=torch.double,
+    ),
+    dict(
+        module_name='Flatten',
+        input_size=(2, 3, 4, 5),
+        reference_fn=lambda i, *_: torch.flatten(i, 1),
+        default_dtype=torch.double,
+    ),
+    # TODO: reference function
+    dict(
+        module_name='CrossMapLRN2d',
+        constructor_args=(5, 5e-3, 1e-3, 2),
+        cpp_constructor_args='torch::nn::CrossMapLRN2dOptions(5).alpha(5e-3).beta(1e-3).k(2)',
+        input_size=(2, 3, 6, 6),
+        check_gradgrad=False,
+        # TODO(#50743): Figure out the error. "RuntimeError: Unrecognized tensor type ID: Batched"
+        check_batched_grad=False,
+        default_dtype=torch.double,
+    ),
+]
+
+
+# Generates a rand tensor with non-equal values. This ensures that duplicate
+# values won't cause test failures for modules like MaxPooling.
+# size should be small, otherwise randperm fails / long overflows.
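+# For example (illustrative), _rand_tensor_non_equal(2, 3) returns a 2x3 double tensor
+# holding a permutation of 0..5, so all six elements are distinct.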
+def _rand_tensor_non_equal(*size):
+    total = reduce(mul, size, 1)
+    return torch.randperm(total).view(*size).double()
+
+
+def wrap_functional(fn, **kwargs):
+    class FunctionalModule(nn.Module):
+        def forward(self, *args):
+            return fn(*args, **kwargs)
+    return FunctionalModule
+
+
+def poissonnllloss_no_reduce_test():
+    t = torch.randn(10, 10)
+    return dict(
+        fullname='PoissonNLLLoss_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.poisson_nll_loss(i, t.type_as(i), reduction='none')),
+        cpp_function_call='F::poisson_nll_loss('
+                          'i, t.to(i.options()), F::PoissonNLLLossFuncOptions().reduction(torch::kNone))',
+        input_fn=lambda: torch.rand(10, 10),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_: i.exp() - t.mul(i),
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def bceloss_no_reduce_test():
+    t = Variable(torch.randn(15, 10).gt(0).to(torch.double))
+    return dict(
+        fullname='BCELoss_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.binary_cross_entropy(i, t.type_as(i), reduction='none')),
+        cpp_function_call='F::binary_cross_entropy('
+                          'i, t.to(i.options()), F::BinaryCrossEntropyFuncOptions().reduction(torch::kNone))',
+        input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_: -(t * i.log() + (1 - t) * (1 - i).log()),
+        pickle=False,
+        precision=7e-4,
+        default_dtype=torch.double)
+
+
+def bceloss_no_reduce_scalar_test():
+    t = torch.randn(()).gt(0).to(torch.double)
+    return dict(
+        fullname='BCELoss_no_reduce_scalar',
+        constructor=wrap_functional(
+            lambda i: F.binary_cross_entropy(i, t.type_as(i), reduction='none')),
+        cpp_function_call='F::binary_cross_entropy('
+                          'i, t.to(i.options()), F::BinaryCrossEntropyFuncOptions().reduction(torch::kNone))',
+        input_fn=lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_: -(t * i.log() + (1 - t) * (1 - i).log()),
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def bceloss_weights_no_reduce_test():
+    t = Variable(torch.randn(15, 10, dtype=torch.double).gt(0).to(torch.double))
+    weights = torch.rand(10, dtype=torch.double)
+    return dict(
+        fullname='BCELoss_weights_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.binary_cross_entropy(i, t.type_as(i),
+                                             weight=weights.type_as(i), reduction='none')),
+        cpp_function_call='F::binary_cross_entropy('
+                          'i, t.to(i.options()), '
+                          'F::BinaryCrossEntropyFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))',
+        input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
+        cpp_var_map={'i': '_get_input()', 't': t, 'weights': weights},
+        reference_fn=lambda i, p, m: -(t * i.log() + (1 - t) * (1 - i).log()) * weights,
+        pickle=False,
+        precision=3e-4,
+        default_dtype=torch.double,
+    )
+
+
+def bceloss_weights_no_reduce_scalar_test():
+    t = torch.randn(()).gt(0).to(torch.double)
+    weights = torch.rand((), dtype=torch.double)
+    return dict(
+        fullname='BCELoss_weights_no_reduce_scalar',
+        constructor=wrap_functional(
+            lambda i: F.binary_cross_entropy(i, t.type_as(i),
+                                             weight=weights.type_as(i), reduction='none')),
+        cpp_function_call='''F::binary_cross_entropy(
+            i, t.to(i.options()),
+            F::BinaryCrossEntropyFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))''',
+        cpp_var_map={'i': '_get_input()', 't': t, 'weights': weights},
+        input_fn=lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2),
+        reference_fn=lambda i, *_: -(t * i.log() + (1 - t) * (1 - i).log()) * weights,
+        pickle=False,
+        default_dtype=torch.double,
+    )
+
+
+def bce_with_logistic_legacy_enum_test():
+    t = Variable(torch.randn(15, 10).gt(0).to(torch.double))
+    sigmoid = nn.Sigmoid()
+    return dict(
+        fullname='BCEWithLogitsLoss_legacy_enum',
+        constructor=wrap_functional(
+            lambda i: F.binary_cross_entropy_with_logits(i, t.type_as(i), reduce=False)),
+        cpp_function_call='''F::binary_cross_entropy_with_logits(
+            i, t.to(i.options()), F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_: -(t * sigmoid(i).log() + (1 - t) * (1 - sigmoid(i)).log()),
+        check_gradgrad=False,
+        pickle=False,
+        default_dtype=torch.double,
+    )
+
+
+def bce_with_logistic_no_reduce_test():
+    t = Variable(torch.randn(15, 10).gt(0).to(torch.double))
+    sigmoid = nn.Sigmoid()
+    return dict(
+        fullname='BCEWithLogitsLoss_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.binary_cross_entropy_with_logits(i, t.type_as(i), reduction='none')),
+        cpp_function_call='''F::binary_cross_entropy_with_logits(
+            i, t.to(i.options()), F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_: -(t * sigmoid(i).log() + (1 - t) * (1 - sigmoid(i)).log()),
+        check_gradgrad=False,
+        pickle=False,
+        default_dtype=torch.double,
+    )
+
+
+def bce_with_logistic_no_reduce_scalar_test():
+    t = torch.randn(()).gt(0).to(torch.double)
+    sigmoid = nn.Sigmoid()
+    return dict(
+        fullname='BCEWithLogitsLoss_no_reduce_scalar',
+        constructor=wrap_functional(
+            lambda i: F.binary_cross_entropy_with_logits(i, t.type_as(i), reduction='none')),
+        cpp_function_call='''F::binary_cross_entropy_with_logits(
+            i, t.to(i.options()), F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_: -(t * sigmoid(i).log() + (1 - t) * (1 - sigmoid(i)).log()),
+        check_gradgrad=False,
+        pickle=False,
+        default_dtype=torch.double,
+    )
+
+
+def kldivloss_with_target_no_reduce_test():
+    t = torch.rand(10, 10, dtype=torch.double)
+    return dict(
+        fullname='KLDivLoss_with_target_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.kl_div(i, t.type_as(i), reduction='none')),
+        cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone))',
+        input_fn=lambda: torch.rand(10, 10).log(),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['KLDivLoss'](i, t.type_as(i), reduction='none'),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def kldivloss_no_reduce_test():
+    t = torch.rand(10, 10, dtype=torch.double)
+    return dict(
+        fullname='KLDivLoss_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.kl_div(i, t.type_as(i), reduction='none')),
+        cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone))',
+        input_fn=lambda: torch.rand(10, 10).log(),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['KLDivLoss'](i, t.type_as(i), reduction='none'),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double,
+    )
+
+
+def kldivloss_no_reduce_scalar_test():
+    t = torch.rand((), dtype=torch.double)
+    return dict(
+        fullname='KLDivLoss_no_reduce_scalar',
+        constructor=wrap_functional(
+            lambda i: F.kl_div(i, t.type_as(i), reduction='none')),
+        cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone))',
+        input_fn=lambda: torch.rand(()).log(),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['KLDivLoss'](i, t.type_as(i), reduction='none'),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def kldivloss_with_log_target_no_reduce_test():
+    t = torch.rand(10, 10, dtype=torch.double).log()
+    return dict(
+        fullname='KLDivLoss_with_log_target_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.kl_div(i, t.type_as(i), reduction='none', log_target=True)),
+        cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone).log_target(true))',
+        input_fn=lambda: torch.rand(10, 10).log(),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['KLDivLoss_log_target'](i, t.type_as(i), reduction='none'),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def kldivloss_no_reduce_log_target_test():
+    t = torch.rand(10, 10, dtype=torch.double).log()
+    return dict(
+        fullname='KLDivLoss_no_reduce_log_target',
+        constructor=wrap_functional(
+            lambda i: F.kl_div(i, t.type_as(i), reduction='none', log_target=True)),
+        cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone).log_target(true))',
+        input_fn=lambda: torch.rand(10, 10).log(),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['KLDivLoss_log_target'](i, t.type_as(i), reduction='none'),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double,
+    )
+
+
+def kldivloss_no_reduce_scalar_log_target_test():
+    t = torch.rand((), dtype=torch.double).log()
+    return dict(
+        fullname='KLDivLoss_no_reduce_scalar_log_target',
+        constructor=wrap_functional(
+            lambda i: F.kl_div(i, t.type_as(i), reduction='none', log_target=True)),
+        cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone).log_target(true))',
+        input_fn=lambda: torch.rand(()).log(),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['KLDivLoss_log_target'](i, t.type_as(i), reduction='none'),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def l1loss_no_reduce_test():
+    t = torch.randn(2, 3, 4, dtype=torch.double)
+    return dict(
+        fullname='L1Loss_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.l1_loss(i, t.type_as(i), reduction='none')),
+        cpp_function_call='F::l1_loss(i, t.to(i.options()), F::L1LossFuncOptions().reduction(torch::kNone))',
+        input_fn=lambda: torch.randn(2, 3, 4),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_: (i - t.type_as(i)).abs(),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def l1loss_no_reduce_complex_test():
+    t = torch.randn(2, 3, 4, dtype=torch.cdouble)
+    return dict(
+        fullname='L1Loss_no_reduce_complex',
+        constructor=wrap_functional(
+            lambda i: F.l1_loss(i, t.type_as(i), reduction='none')),
+        cpp_function_call='F::l1_loss(i, t.to(i.options()), F::L1LossFuncOptions().reduction(torch::kNone))',
+        input_fn=lambda: torch.randn(2, 3, 4, dtype=torch.cdouble),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_: (i - t.type_as(i)).abs(),
+        supports_forward_ad=True,
+        pickle=False)
+
+
+def l1loss_no_reduce_scalar_test():
+    t = torch.randn((), dtype=torch.double)
+    return dict(
+        fullname='L1Loss_no_reduce_scalar',
+        constructor=wrap_functional(
+            lambda i: F.l1_loss(i, t.type_as(i), reduction='none')),
+        cpp_function_call='F::l1_loss(i, t.to(i.options()), F::L1LossFuncOptions().reduction(torch::kNone))',
+        input_fn=lambda: torch.randn(()),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_: (i - t.type_as(i)).abs(),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def mseloss_no_reduce_test():
+    input_size = (2, 3, 4, 5)
+    target = torch.randn(*input_size, dtype=torch.double)
+    return dict(
+        fullname='MSELoss_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.mse_loss(i, target.type_as(i), reduction='none')),
+        cpp_function_call='F::mse_loss(i, target.to(i.options()), F::MSELossFuncOptions().reduction(torch::kNone))',
+        input_size=input_size,
+        cpp_var_map={'i': '_get_input()', 'target': target},
+        reference_fn=lambda i, *_: (i - target).pow(2),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def mseloss_no_reduce_scalar_test():
+    input_size = ()
+    target = torch.randn(input_size, dtype=torch.double)
+    return dict(
+        fullname='MSELoss_no_reduce_scalar',
+        constructor=wrap_functional(
+            lambda i: F.mse_loss(i, target.type_as(i), reduction='none')),
+        cpp_function_call='F::mse_loss(i, target.to(i.options()), F::MSELossFuncOptions().reduction(torch::kNone))',
+        input_size=input_size,
+        cpp_var_map={'i': '_get_input()', 'target': target},
+        reference_fn=lambda i, *_: (i - target).pow(2),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def nllloss_no_reduce_test():
+    t = Variable(torch.empty(15).uniform_().mul(10).floor().long())
+    kwargs = {'reduction': 'none'}
+    return dict(
+        fullname='NLLLoss_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.nll_loss(i, t.type_as(i).long(), reduction=kwargs['reduction'])),
+        cpp_function_call='''F::nll_loss(
+            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.rand(15, 10).log(),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs),
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def nllloss_no_reduce_ignore_index_test():
+    t = Variable(torch.empty(15).uniform_().mul(10).floor().long())
+    kwargs: dict[str, Union[int, str]] = {'ignore_index': 2, 'reduction': 'none'}
+    return dict(
+        fullname='NLLLoss_no_reduce_ignore_index',
+        constructor=wrap_functional(
+            lambda i: F.nll_loss(i, t.type_as(i).long(), ignore_index=int(kwargs['ignore_index']),
+                                 reduction=str(kwargs['reduction']))),
+        cpp_function_call='''F::nll_loss(
+            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().ignore_index(2).reduction(torch::kNone))''',
+        input_fn=lambda: torch.rand(15, 10).log(),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs),
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def nllloss_no_reduce_weights_test():
+    t = Variable(torch.empty(15).uniform_().mul(10).floor().long())
+    weight = torch.rand(10)
+
+    def kwargs(i):
+        return {'weight': weight.type_as(i), 'reduction': 'none'}
+
+    return dict(
+        fullname='NLLLoss_no_reduce_weights',
+        constructor=wrap_functional(
+            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))),
+        cpp_function_call='''F::nll_loss(
+            i, t.to(i.options()).to(torch::kLong),
+            F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone))''',
+        input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
+        cpp_var_map={'i': '_get_input()', 't': t, 'weight': weight},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs(i)),
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def nllloss_no_reduce_weights_ignore_index_test():
+    t = Variable(torch.empty(15).uniform_().mul(10).floor().long())
+    weight = torch.rand(10)
+
+    def kwargs(i):
+        return {'weight': weight.type_as(i), 'reduction': 'none',
+                'ignore_index': 2}
+
+    return dict(
+        fullname='NLLLoss_no_reduce_weights_ignore_index',
+        constructor=wrap_functional(
+            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i.data))),
+        cpp_function_call='''F::nll_loss(
+            i, t.to(i.options()).to(torch::kLong),
+            F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone).ignore_index(2))''',
+        input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
+        cpp_var_map={'i': '_get_input()', 't': t, 'weight': weight},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs(i)),
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def nllloss_no_reduce_weights_ignore_index_neg_test():
+    t = Variable(torch.empty(15).uniform_().mul(10).floor().long())
+    weight = torch.rand(10)
+
+    def kwargs(i):
+        return {'weight': weight.type_as(i), 'reduction': 'none',
+                'ignore_index': -1}
+
+    return dict(
+        fullname='NLLLoss_no_reduce_weights_ignore_index_neg',
+        constructor=wrap_functional(
+            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))),
+        cpp_function_call='''F::nll_loss(
+            i, t.to(i.options()).to(torch::kLong),
+            F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone).ignore_index(-1))''',
+        input=torch.rand(15, 10, dtype=torch.double).add(1e-2).log(),
+        cpp_var_map={'i': '_get_input()', 't': t, 'weight': weight},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs(i)),
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def nllloss2d_no_reduce_test():
+    t = Variable(torch.rand(2, 5, 5).mul(3).floor().long())
+    kwargs = {'reduction': 'none'}
+    return dict(
+        fullname='NLLLoss2d_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.nll_loss(i, t.type_as(i).long(), reduction=kwargs['reduction'])),
+        cpp_function_call='''F::nll_loss(
+            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.rand(2, 3, 5, 5).log(),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs),
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def nllloss2d_no_reduce_ignore_index_test():
+    t = Variable(torch.rand(2, 5, 5).mul(3).floor().long())
+    kwargs: dict[str, Union[int, str]] = {'ignore_index': 1, 'reduction': 'none'}
+    return dict(
+        fullname='NLLLoss2d_no_reduce_ignore_index',
+        constructor=wrap_functional(
+            lambda i: F.nll_loss(i, t.type_as(i).long(), ignore_index=int(kwargs['ignore_index']),
+                                 reduction=str(kwargs['reduction']))),
+        cpp_function_call='''F::nll_loss(
+            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().ignore_index(1).reduction(torch::kNone))''',
+        input_fn=lambda: torch.rand(2, 3, 5, 5).log(),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs),
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def nllloss2d_no_reduce_weights_test():
+    t = Variable(torch.rand(2, 5, 5).mul(3).floor().long())
+    weight = torch.rand(3)
+
+    def kwargs(i):
+        return {'weight': weight.type_as(i), 'reduction': 'none'}
+
+    return dict(
+        fullname='NLLLoss2d_no_reduce_weights',
+        constructor=wrap_functional(
+            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))),
+        cpp_function_call='''F::nll_loss(
+            i, t.to(i.options()).to(torch::kLong),
+            F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone))''',
+        input_fn=lambda: torch.rand(2, 3, 5, 5).log(),
+        cpp_var_map={'i': '_get_input()', 't': t, 'weight': weight},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs(i)),
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def nlllossNd_no_reduce_test():
+    t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long())
+    kwargs = {'reduction': 'none'}
+    return dict(
+        fullname='NLLLossNd_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.nll_loss(i, t.type_as(i).long(), reduction=kwargs['reduction'])),
+        cpp_function_call='''F::nll_loss(
+            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.rand(2, 3, 5, 5, 2, 2).log(),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs),
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def nlllossNd_no_reduce_ignore_index_test():
+    t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long())
+    kwargs: dict[str, Union[int, str]] = {'ignore_index': 1, 'reduction': 'none'}
+    return dict(
+        fullname='NLLLossNd_no_reduce_ignore_index',
+        constructor=wrap_functional(
+            lambda i: F.nll_loss(i, t.type_as(i).long(), ignore_index=int(kwargs['ignore_index']),
+                                 reduction=str(kwargs['reduction']))),
+        cpp_function_call='''F::nll_loss(
+            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().ignore_index(1).reduction(torch::kNone))''',
+        input_fn=lambda: torch.rand(2, 3, 5, 5, 2, 2).log(),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs),
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def nlllossNd_no_reduce_weights_test():
+    t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long())
+    weight = torch.rand(3)
+
+    def kwargs(i):
+        return {'weight': weight.type_as(i), 'reduction': 'none'}
+
+    return dict(
+        fullname='NLLLossNd_no_reduce_weights',
+        constructor=wrap_functional(
+            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))),
+        cpp_function_call='''F::nll_loss(
+            i, t.to(i.options()).to(torch::kLong),
+            F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone))''',
+        input_fn=lambda: torch.rand(2, 3, 5, 5, 2, 2).log(),
+        cpp_var_map={'i': '_get_input()', 't': t, 'weight': weight},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs(i)),
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def smoothl1loss_no_reduce_test():
+    t = torch.randn(2, 3, 4, dtype=torch.double)
+    return dict(
+        fullname='SmoothL1Loss_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction='none')),
+        cpp_function_call='''F::smooth_l1_loss(
+            i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(2, 3, 4),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduction='none'),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def smoothl1loss_no_reduce_scalar_test():
+    t = torch.randn((), dtype=torch.double)
+    return dict(
+        fullname='SmoothL1Loss_no_reduce_scalar',
+        constructor=wrap_functional(
+            lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction='none')),
+        cpp_function_call='''F::smooth_l1_loss(
+            i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(()),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduction='none'),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def smoothl1loss_beta_test():
+    t = torch.randn(2, 3, 4, dtype=torch.double)
+    return dict(
+        fullname='SmoothL1Loss_beta',
+        constructor=wrap_functional(
+            lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction='none', beta=0.5)),
+        cpp_function_call='''F::smooth_l1_loss(
+            i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone), 0.5)''',
+        input_fn=lambda: torch.randn(2, 3, 4),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduction='none', beta=0.5),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def smoothl1loss_zero_beta_test():
+    t = torch.randn(2, 3, 4, dtype=torch.double)
+    return dict(
+        fullname='SmoothL1Loss_zero_beta',
+        constructor=wrap_functional(
+            lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction='none', beta=0)),
+        cpp_function_call='''F::smooth_l1_loss(
+            i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone), 0)''',
+        input_fn=lambda: torch.randn(2, 3, 4),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduction='none', beta=0),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def huberloss_delta_test():
+    t = torch.randn(2, 3, 4)
+    return dict(
+        fullname='HuberLoss_delta',
+        constructor=wrap_functional(
+            lambda i: F.huber_loss(i, t.type_as(i), reduction='none', delta=0.5)),
+        cpp_function_call='''F::huber_loss(
+            i, t.to(i.options()), F::HuberLossFuncOptions().reduction(torch::kNone).delta(0.5))''',
+        input_fn=lambda: torch.randn(2, 3, 4),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['HuberLoss'](i, t.type_as(i), reduction='none', delta=0.5),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def multilabelmarginloss_0d_no_reduce_test():
+    t = torch.zeros(()).long()
+    return dict(
+        fullname='MultiLabelMarginLoss_0d_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')),
+        cpp_function_call='''F::multilabel_margin_loss(
+            i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(()),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
+        check_sum_reduction=True,
+        check_gradgrad=False,
+        pickle=False)
+
+
+def multilabelmarginloss_1d_no_reduce_test():
+    t = Variable(torch.rand(10).mul(10).floor().long())
+    return dict(
+        fullname='MultiLabelMarginLoss_1d_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')),
+        cpp_function_call='''F::multilabel_margin_loss(
+            i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(10),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
+        check_sum_reduction=True,
+        check_gradgrad=False,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def multilabelmarginloss_index_neg_test():
+    t = Variable(torch.clamp(torch.rand(5, 10).add(-.5).mul(20).floor().long(), min=-1))
+    return dict(
+        fullname='MultiLabelMarginLoss_index_neg',
+        constructor=wrap_functional(
+            lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')),
+        cpp_function_call='''F::multilabel_margin_loss(
+            i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(5, 10),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
+        check_sum_reduction=True,
+        check_gradgrad=False,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def multilabelmarginloss_no_reduce_test():
+    t = Variable(torch.rand(5, 10).mul(10).floor().long())
+    return dict(
+        fullname='MultiLabelMarginLoss_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')),
+        cpp_function_call='''F::multilabel_margin_loss(
+            i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(5, 10),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
+        check_sum_reduction=True,
+        check_gradgrad=False,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def hingeembeddingloss_no_reduce_test():
+    t = Variable(torch.randn(10).gt(0).to(torch.double).mul_(2).sub(1))
+    return dict(
+        fullname='HingeEmbeddingLoss_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.hinge_embedding_loss(i, t.type_as(i), reduction='none')),
+        cpp_function_call='''F::hinge_embedding_loss(
+            i, t.to(i.options()), F::HingeEmbeddingLossFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(10),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['HingeEmbeddingLoss'](i, t.type_as(i), reduction='none'),
+        check_sum_reduction=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def hingeembeddingloss_margin_no_reduce_test():
+    t = Variable(torch.randn(10).gt(0).to(torch.double).mul_(2).sub(1))
+    return dict(
+        fullname='HingeEmbeddingLoss_margin_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.hinge_embedding_loss(i, t.type_as(i), margin=0.5, reduction='none')),
+        cpp_function_call='''F::hinge_embedding_loss(
+            i, t.to(i.options()), F::HingeEmbeddingLossFuncOptions().margin(0.5).reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(10),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['HingeEmbeddingLoss'](i, t.type_as(i), margin=0.5, reduction='none'),
+        check_sum_reduction=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def softmarginloss_no_reduce_test():
+    t = torch.randn(5, 5, dtype=torch.double)
+    return dict(
+        fullname='SoftMarginLoss_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.soft_margin_loss(i, t.type_as(i), reduction='none')),
+        cpp_function_call='''F::soft_margin_loss(
+            i, t.to(i.options()), F::SoftMarginLossFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(5, 5),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['SoftMarginLoss'](i, t.type_as(i), reduction='none'),
+        supports_forward_ad=True,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def multilabelsoftmarginloss_no_reduce_test():
+    t = torch.rand(5, 10).mul(2).floor()
+    return dict(
+        fullname='MultiLabelSoftMarginLoss_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.multilabel_soft_margin_loss(i, t.type_as(i), reduction='none')),
+        cpp_function_call='''F::multilabel_soft_margin_loss(
+            i, t.to(i.options()), F::MultilabelSoftMarginLossFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(5, 10),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            (-(t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log())).sum(dim=1) / i.size(1),
+        check_gradgrad=False,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def multilabelsoftmarginloss_weights_no_reduce_test():
+    t = torch.rand(5, 10).mul(2).floor()
+    weights = torch.rand(10)
+    return dict(
+        fullname='MultiLabelSoftMarginLoss_weights_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.multilabel_soft_margin_loss(i, t.type_as(i),
+                                                    weight=weights.type_as(i), reduction='none')),
+        cpp_function_call='''F::multilabel_soft_margin_loss(
+            i, t.to(i.options()),
+            F::MultilabelSoftMarginLossFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(5, 10),
+        cpp_var_map={'i': '_get_input()', 't': t, 'weights': weights},
+        reference_fn=lambda i, *_:
+            (-(t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()) * weights).sum(dim=1) / i.size(1),
+        check_sum_reduction=True,
+        check_gradgrad=False,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def multimarginloss_no_reduce_test():
+    t = torch.rand(5).mul(8).floor().long()
+    return dict(
+        fullname='MultiMarginLoss_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduction='none')),
+        cpp_function_call='''F::multi_margin_loss(
+            i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(5, 10),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
+        check_sum_reduction=True,
+        check_gradgrad=False,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def multimarginloss_1d_no_reduce_test():
+    t = torch.rand(1).mul(8).floor().long()
+    return dict(
+        fullname='MultiMarginLoss_1d_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduction='none')),
+        cpp_function_call='''F::multi_margin_loss(
+            i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(10),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
+        check_sum_reduction=True,
+        check_gradgrad=False,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def multimarginloss_1d_input_0d_target_no_reduce_test():
+    t = torch.rand(()).mul(8).floor().long()
+    return dict(
+        fullname='multimarginloss_1d_input_0d_target_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduction='none')),
+        cpp_function_call='''F::multi_margin_loss(
+            i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(10),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
+        check_sum_reduction=True,
+        check_gradgrad=False,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def multimarginloss_p_no_reduce_test():
+    t = torch.rand(5).mul(8).floor().long()
+    return dict(
+        fullname='MultiMarginLoss_p_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), p=2, reduction='none')),
+        cpp_function_call='''F::multi_margin_loss(
+            i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().p(2).reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(5, 10).clamp_(1e-2, 1 - 1e-2),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), p=2, reduction='none'),
+        check_sum_reduction=True,
+        check_gradgrad=False,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def multimarginloss_margin_no_reduce_test():
+    t = torch.rand(5).mul(8).floor().long()
+    return dict(
+        fullname='MultiMarginLoss_margin_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), margin=0.5, reduction='none')),
+        cpp_function_call='''F::multi_margin_loss(
+            i, t.to(i.options()).to(torch::kLong),
+            F::MultiMarginLossFuncOptions().margin(0.5).reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(5, 10),
+        cpp_var_map={'i': '_get_input()', 't': t},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(),
+                                                  margin=0.5, reduction='none'),
+        check_sum_reduction=True,
+        check_gradgrad=False,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def multimarginloss_weights_no_reduce_test():
+    t = torch.rand(5).mul(8).floor().long()
+    weights = torch.rand(10, dtype=torch.double)
+    return dict(
+        fullname='MultiMarginLoss_weights_no_reduce',
+        constructor=wrap_functional(
+            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), weight=weights.type_as(i),
+                                          reduction='none')),
+        cpp_function_call='''F::multi_margin_loss(
+            i, t.to(i.options()).to(torch::kLong),
+            F::MultiMarginLossFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))''',
+        input_fn=lambda: torch.randn(5, 10),
+        cpp_var_map={'i': '_get_input()', 't': t, 'weights': weights},
+        reference_fn=lambda i, *_:
+            loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(),
+                                                  weight=weights, reduction='none'),
+        check_sum_reduction=True,
+        check_gradgrad=False,
+        pickle=False,
+        default_dtype=torch.double)
+
+
+def single_batch_reference_fn(input, parameters, module):
+    """Reference function for modules supporting no batch dimensions.
+
+    The module is passed the input and target in batched form with a single item.
+    The output is squeezed to compare with the no-batch input.
+    """
+    def unsqueeze_inp(inp):
+        if isinstance(inp, (list, tuple)):
+            return [t.unsqueeze(0) for t in inp]
+        return inp.unsqueeze(0)
+
+    single_batch_input = unsqueeze_inp(input)
+    single_batch_input = [single_batch_input] if isinstance(single_batch_input, torch.Tensor) else single_batch_input
+    with freeze_rng_state():
+        return module(*single_batch_input).squeeze(0)
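+
+# A minimal usage sketch of the reference function above (illustrative only; the
+# concrete module and shapes here are assumptions, not fixed by the harness):
+#
+#     module = nn.ReplicationPad3d((1, 2, 3, 3, 2, 1))
+#     inp = torch.rand(3, 2, 2, 2, dtype=torch.double)          # no batch dimension
+#     expected = single_batch_reference_fn(inp, None, module)   # unsqueeze -> module -> squeeze(0)
+#     torch.testing.assert_close(module(inp), expected)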
+
+
+def get_new_module_tests():
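+    """Build the list of module test configurations consumed by the NN test harness.
+
+    The list starts with the no-reduce loss checks produced by the helper
+    functions above, followed by per-module configurations: constructor
+    arguments, the matching C++ API constructor/options strings, input sizes or
+    input functions, and per-test tolerances such as ``tf32_precision``.
+    """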
+    new_module_tests = [
+        poissonnllloss_no_reduce_test(),
+        bceloss_no_reduce_test(),
+        bceloss_weights_no_reduce_test(),
+        bce_with_logistic_legacy_enum_test(),
+        bce_with_logistic_no_reduce_test(),
+        bceloss_no_reduce_scalar_test(),
+        bceloss_weights_no_reduce_scalar_test(),
+        bce_with_logistic_no_reduce_scalar_test(),
+        kldivloss_with_target_no_reduce_test(),
+        kldivloss_no_reduce_test(),
+        kldivloss_no_reduce_scalar_test(),
+        kldivloss_with_log_target_no_reduce_test(),
+        kldivloss_no_reduce_log_target_test(),
+        kldivloss_no_reduce_scalar_log_target_test(),
+        l1loss_no_reduce_test(),
+        l1loss_no_reduce_complex_test(),
+        l1loss_no_reduce_scalar_test(),
+        mseloss_no_reduce_test(),
+        mseloss_no_reduce_scalar_test(),
+        nllloss_no_reduce_test(),
+        nllloss_no_reduce_ignore_index_test(),
+        nllloss_no_reduce_weights_test(),
+        nllloss_no_reduce_weights_ignore_index_test(),
+        nllloss_no_reduce_weights_ignore_index_neg_test(),
+        nllloss2d_no_reduce_test(),
+        nllloss2d_no_reduce_weights_test(),
+        nllloss2d_no_reduce_ignore_index_test(),
+        nlllossNd_no_reduce_test(),
+        nlllossNd_no_reduce_weights_test(),
+        nlllossNd_no_reduce_ignore_index_test(),
+        smoothl1loss_no_reduce_test(),
+        smoothl1loss_no_reduce_scalar_test(),
+        smoothl1loss_beta_test(),
+        smoothl1loss_zero_beta_test(),
+        huberloss_delta_test(),
+        multilabelmarginloss_0d_no_reduce_test(),
+        multilabelmarginloss_1d_no_reduce_test(),
+        multilabelmarginloss_index_neg_test(),
+        multilabelmarginloss_no_reduce_test(),
+        hingeembeddingloss_no_reduce_test(),
+        hingeembeddingloss_margin_no_reduce_test(),
+        softmarginloss_no_reduce_test(),
+        multilabelsoftmarginloss_no_reduce_test(),
+        multilabelsoftmarginloss_weights_no_reduce_test(),
+        multimarginloss_no_reduce_test(),
+        multimarginloss_1d_no_reduce_test(),
+        multimarginloss_1d_input_0d_target_no_reduce_test(),
+        multimarginloss_p_no_reduce_test(),
+        multimarginloss_margin_no_reduce_test(),
+        multimarginloss_weights_no_reduce_test(),
+        dict(
+            module_name='Conv1d',
+            constructor_args=(4, 5, 3),
+            cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3)',
+            input_size=(2, 4, 10),
+            cudnn=True,
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv1d',
+            constructor_args=(4, 5, 3, 2),
+            cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).stride(2)',
+            input_size=(2, 4, 10),
+            cudnn=True,
+            desc='stride',
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv1d',
+            constructor_args=(4, 5, 3, 1, 1),
+            cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).stride(1).padding(1)',
+            input_size=(2, 4, 10),
+            cudnn=True,
+            desc='pad1',
+            with_tf32=True,
+            tf32_precision=0.01,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv1d',
+            constructor_args=(4, 5, 5, 1, 2),
+            cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 5).stride(1).padding(2)',
+            input_size=(2, 4, 10),
+            cudnn=True,
+            desc='pad2',
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv1d',
+            constructor_args=(4, 4, 3, 1, 1),
+            cpp_constructor_args='torch::nn::Conv1dOptions(4, 4, 3).stride(1).padding(1)',
+            input_size=(1, 4, 1),
+            cudnn=True,
+            desc='pad1size1',
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv1d',
+            constructor_args=(4, 4, 5, 1, 2),
+            cpp_constructor_args='torch::nn::Conv1dOptions(4, 4, 5).stride(1).padding(2)',
+            input_size=(1, 4, 1),
+            cudnn=True,
+            desc='pad2size1',
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv1d',
+            constructor_args=(4, 5, 3),
+            cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3)',
+            input_size=(0, 4, 10),
+            cudnn=True,
+            desc='zero_batch',
+            with_tf32=True,
+            tf32_precision=0.005,
+        ),
+        dict(
+            fullname='Conv1d_dilated',
+            constructor=lambda: nn.Conv1d(4, 5, kernel_size=3, dilation=2),
+            cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).dilation(2)',
+            input_size=(2, 4, 10),
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv1d_groups',
+            constructor=lambda: nn.Conv1d(4, 6, kernel_size=3, groups=2),
+            cpp_constructor_args='torch::nn::Conv1dOptions(4, 6, 3).groups(2)',
+            input_size=(2, 4, 6),
+            cudnn=True,
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv1d_pad_valid',
+            constructor=lambda: nn.Conv1d(4, 5, 3, padding="valid"),
+            cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).padding(torch::kValid)',
+            input_size=(2, 4, 10),
+            cudnn=True,
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv1d_pad_same',
+            constructor=lambda: nn.Conv1d(4, 5, 3, padding="same"),
+            cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).padding(torch::kSame)',
+            input_size=(2, 4, 10),
+            cudnn=True,
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv1d_pad_same2',
+            constructor=lambda: nn.Conv1d(4, 5, 4, padding="same"),
+            cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 4).padding(torch::kSame)',
+            input_size=(2, 4, 10),
+            cudnn=True,
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv1d_pad_same_dilated',
+            constructor=lambda: nn.Conv1d(4, 5, 4, padding="same", dilation=2),
+            cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 4).padding(torch::kSame).dilation(2)',
+            input_size=(2, 4, 10),
+            cudnn=True,
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='ConvTranspose1d',
+            constructor=lambda: nn.ConvTranspose1d(3, 4, kernel_size=3, stride=(3,), padding=1, output_padding=(1,)),
+            cpp_constructor_args='torch::nn::ConvTranspose1dOptions(3, 4, 3).stride(3).padding(1).output_padding(1)',
+            cudnn=True,
+            input_size=(1, 3, 7),
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='ConvTranspose1d',
+            constructor_args=(3, 4, 3, 2, 1, 1, 1, False),
+            cpp_constructor_args='''torch::nn::ConvTranspose1dOptions(3, 4, 3)
+                                    .stride(2).padding(1).output_padding(1).groups(1).bias(false)''',
+            input_size=(1, 3, 6),
+            cudnn=True,
+            desc='no_bias',
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='ConvTranspose1d',
+            constructor_args=(3, 4, 3, 2, 1, 1, 1, True, 2),
+            cpp_constructor_args='''torch::nn::ConvTranspose1dOptions(3, 4, 3)
+                                    .stride(2).padding(1).output_padding(1).groups(1).bias(true).dilation(2)''',
+            input_size=(1, 3, 6),
+            cudnn=True,
+            desc='dilated',
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='ConvTranspose1d_groups',
+            constructor=lambda: nn.ConvTranspose1d(4, 6, 3, stride=(3,), padding=1, output_padding=(1,), groups=2),
+            cpp_constructor_args='''torch::nn::ConvTranspose1dOptions(4, 6, 3)
+                                    .stride(3).padding(1).output_padding(1).groups(2)''',
+            cudnn=True,
+            input_size=(2, 4, 7),
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv2d',
+            constructor_args=(3, 4, (3, 2)),
+            cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 2})',
+            input_size=(2, 3, 7, 5),
+            cudnn=True,
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv2d',
+            constructor_args=(3, 4, (3, 3), (2, 2)),
+            cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 3}).stride({2, 2})',
+            input_size=(2, 3, 6, 6),
+            cudnn=True,
+            desc='strided',
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv2d',
+            constructor_args=(3, 4, (3, 3), (2, 2), (1, 1)),
+            cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 3}).stride({2, 2}).padding({1, 1})',
+            input_size=(2, 3, 6, 6),
+            cudnn=True,
+            desc='padding',
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv2d',
+            constructor_args=(3, 2, (3, 3), (2, 2), (1, 1), (2, 2)),
+            cpp_constructor_args='torch::nn::Conv2dOptions(3, 2, {3, 3}).stride({2, 2}).padding({1, 1}).dilation({2, 2})',
+            input_size=(2, 3, 8, 8),
+            cudnn=True,
+            desc='dilated',
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv2d',
+            constructor_args=(3, 4, (3, 2), 1, 0, 1, 1, False),
+            cpp_constructor_args='''torch::nn::Conv2dOptions(3, 4, {3, 2})
+                                    .stride(1).padding(0).dilation(1).groups(1).bias(false)''',
+            input_size=(2, 3, 6, 5),
+            cudnn=True,
+            desc='no_bias',
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.015,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv2d',
+            constructor_args=(3, 4, (3, 2)),
+            cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 2})',
+            input_size=(0, 3, 7, 5),
+            cudnn=True,
+            desc='zero_batch',
+            check_with_long_tensor=True,
+            with_tf32=True,
+        ),
+        dict(
+            fullname='Conv2d_groups',
+            constructor=lambda: nn.Conv2d(4, 6, (3, 2), groups=2),
+            cpp_constructor_args='torch::nn::Conv2dOptions(4, 6, {3, 2}).groups(2)',
+            input_size=(2, 4, 6, 5),
+            cudnn=True,
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.015,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv2d_groups_thnn',
+            constructor=lambda: nn.Conv2d(4, 6, (3, 2), groups=2),
+            cpp_constructor_args='torch::nn::Conv2dOptions(4, 6, {3, 2}).groups(2)',
+            input_size=(2, 4, 6, 5),
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.015,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv2d_pad_valid',
+            constructor=lambda: nn.Conv2d(2, 4, (3, 4), padding="valid"),
+            cpp_constructor_args='torch::nn::Conv2dOptions(2, 4, {3, 4}).padding(torch::kValid)',
+            input_size=(2, 2, 6, 5),
+            cudnn=True,
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv2d_pad_same',
+            constructor=lambda: nn.Conv2d(2, 4, (3, 4), padding="same"),
+            cpp_constructor_args='torch::nn::Conv2dOptions(2, 4, {3, 4}).padding(torch::kSame)',
+            input_size=(2, 2, 6, 5),
+            cudnn=True,
+            with_tf32=True,
+            tf32_precision=0.01,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv2d_pad_same_dilated',
+            constructor=lambda: nn.Conv2d(2, 4, (3, 4), padding="same", dilation=2),
+            cpp_constructor_args='torch::nn::Conv2dOptions(2, 4, {3, 4}).padding(torch::kSame).dilation(2)',
+            input_size=(2, 2, 6, 5),
+            cudnn=True,
+            with_tf32=True,
+            tf32_precision=0.01,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='ConvTranspose2d',
+            constructor_args=(3, 4, 3, (3, 2), 1, (1, 1)),
+            cpp_constructor_args='''torch::nn::ConvTranspose2dOptions(3, 4, 3)
+                                    .stride({3, 2}).padding(1).output_padding({1, 1})''',
+            cudnn=True,
+            input_size=(1, 3, 7, 6),
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.01,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='ConvTranspose2d',
+            constructor_args=(3, 4, 3, (2, 3), 1, (1, 1), 1, False, (2, 2)),
+            cpp_constructor_args='''torch::nn::ConvTranspose2dOptions(3, 4, 3)
+                                    .stride({2, 3})
+                                    .padding(1)
+                                    .output_padding({1, 1})
+                                    .groups(1)
+                                    .bias(false)
+                                    .dilation({2, 2})''',
+            input_size=(1, 3, 6, 7),
+            cudnn=True,
+            desc='dilated',
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.01,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='ConvTranspose2d',
+            constructor_args=(3, 4, 3, (2, 3), 1, (1, 1), 1, False),
+            cpp_constructor_args='''torch::nn::ConvTranspose2dOptions(3, 4, 3)
+                                    .stride({2, 3}).padding(1).output_padding({1, 1}).groups(1).bias(false)''',
+            input_size=(1, 3, 6, 7),
+            cudnn=True,
+            desc='no_bias',
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.01,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='ConvTranspose2d_groups',
+            constructor=lambda: nn.ConvTranspose2d(2, 4, (2, 3), groups=2),
+            cpp_constructor_args='torch::nn::ConvTranspose2dOptions(2, 4, {2, 3}).groups(2)',
+            input_size=(1, 2, 4, 5),
+            cudnn=True,
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.01,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv2d_depthwise',
+            constructor=lambda: nn.Conv2d(4, 4, (3, 3), groups=4),
+            cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {3, 3}).groups(4)',
+            input_size=(2, 4, 6, 6),
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv2d_depthwise_with_multiplier',
+            constructor=lambda: nn.Conv2d(4, 8, (3, 3), groups=4),
+            cpp_constructor_args='torch::nn::Conv2dOptions(4, 8, {3, 3}).groups(4)',
+            input_size=(2, 4, 6, 6),
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv2d_depthwise_strided',
+            constructor=lambda: nn.Conv2d(4, 4, (3, 3), stride=(2, 2), groups=4),
+            cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {3, 3}).stride({2, 2}).groups(4)',
+            input_size=(2, 4, 6, 6),
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv2d_depthwise_padded',
+            constructor=lambda: nn.Conv2d(4, 4, (3, 3), padding=(1, 1), groups=4),
+            cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {3, 3}).padding({1, 1}).groups(4)',
+            input_size=(2, 4, 6, 6),
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv2d_depthwise_dilated',
+            constructor=lambda: nn.Conv2d(4, 4, (2, 2), dilation=(2, 2), groups=4),
+            cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {2, 2}).dilation({2, 2}).groups(4)',
+            input_size=(2, 4, 5, 5),
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv3d',
+            constructor_args=(2, 3, (2, 3, 2)),
+            cpp_constructor_args='torch::nn::Conv3dOptions(2, 3, {2, 3, 2})',
+            input_size=(1, 2, 4, 5, 4),
+            cudnn=True,
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.05,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv3d',
+            constructor_args=(2, 3, (2, 3, 4), 1, 0, 1, 1, False),
+            cpp_constructor_args='''torch::nn::Conv3dOptions(2, 3, {2, 3, 4})
+                                    .stride(1).padding(0).dilation(1).groups(1).bias(false)''',
+            input_size=(1, 2, 3, 4, 5),
+            cudnn=True,
+            desc='no_bias',
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.05,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv3d',
+            constructor_args=(2, 3, (1, 1, 1), 1, 0, 1, 1, False),
+            cpp_constructor_args='''torch::nn::Conv3dOptions(2, 3, {1, 1, 1})
+                                    .stride(1).padding(0).dilation(1).groups(1).bias(false)''',
+            input_size=(1, 2, 3, 4, 5),
+            cudnn=True,
+            desc='1x1x1_no_bias',
+            check_with_long_tensor=False,
+            with_tf32=True,
+            tf32_precision=0.05,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv3d',
+            constructor_args=(3, 4, 2, 2),
+            cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).stride(2)',
+            input_size=(2, 3, 5, 5, 5),
+            cudnn=True,
+            desc='stride',
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.05,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv3d',
+            constructor_args=(3, 4, 2, 2, 1),
+            cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).stride(2).padding(1)',
+            input_size=(2, 3, 5, 5, 5),
+            cudnn=True,
+            desc='stride_padding',
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.05,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Conv3d',
+            constructor_args=(3, 4, (2, 3, 4)),
+            cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4})',
+            input_size=(0, 3, 3, 4, 5),
+            cudnn=True,
+            check_with_long_tensor=True,
+            desc='zero_batch',
+            with_tf32=True,
+        ),
+        dict(
+            fullname='Conv3d_groups',
+            constructor=lambda: nn.Conv3d(2, 4, kernel_size=3, groups=2),
+            cpp_constructor_args='torch::nn::Conv3dOptions(2, 4, 3).groups(2)',
+            input_size=(1, 2, 4, 5, 4),
+            cudnn=True,
+            check_with_long_tensor=True,
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv3d_dilated',
+            constructor=lambda: nn.Conv3d(3, 4, kernel_size=2, dilation=2),
+            cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).dilation(2)',
+            input_size=(2, 3, 5, 5, 5),
+            with_tf32=True,
+            tf32_precision=0.05,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv3d_dilated_strided',
+            constructor=lambda: nn.Conv3d(3, 4, kernel_size=2, dilation=2, stride=2),
+            cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).dilation(2).stride(2)',
+            input_size=(2, 3, 5, 5, 5),
+            with_tf32=True,
+            tf32_precision=0.05,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv3d_pad_valid',
+            constructor=lambda: nn.Conv3d(3, 4, (2, 3, 4), padding="valid"),
+            cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4}).padding(torch::kValid)',
+            input_size=(2, 3, 6, 5, 4),
+            cudnn=True,
+            with_tf32=True,
+            tf32_precision=0.05,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv3d_pad_same',
+            constructor=lambda: nn.Conv3d(3, 4, (2, 3, 4), padding="same"),
+            cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4}).padding(torch::kSame)',
+            input_size=(2, 3, 6, 5, 4),
+            cudnn=True,
+            with_tf32=True,
+            tf32_precision=0.05,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Conv3d_pad_same_dilated',
+            constructor=lambda: nn.Conv3d(3, 4, (2, 3, 4), padding="same", dilation=2),
+            cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4}).padding(torch::kSame).dilation(2)',
+            input_size=(2, 3, 6, 5, 4),
+            cudnn=True,
+            with_tf32=True,
+            tf32_precision=0.05,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='ConvTranspose3d',
+            constructor_args=(2, 3, (2, 3, 2)),
+            cpp_constructor_args='torch::nn::ConvTranspose3dOptions(2, 3, {2, 3, 2})',
+            cudnn=True,
+            input_size=(1, 2, 4, 5, 4),
+            with_tf32=True,
+            tf32_precision=0.05,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='ConvTranspose3d',
+            constructor_args=(2, 3, (2, 3, 2), 1, 0, 0, 1, True, (2, 2, 2)),
+            cpp_constructor_args='''torch::nn::ConvTranspose3dOptions(2, 3, {2, 3, 2})
+                                    .stride(1).padding(0).output_padding(0).groups(1).bias(true).dilation({2, 2, 2})''',
+            cudnn=True,
+            input_size=(1, 2, 4, 5, 4),
+            desc='dilated',
+            with_tf32=True,
+            tf32_precision=0.05,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='ReplicationPad3d',
+            constructor_args=((1, 2, 3, 3, 2, 1),),
+            cpp_constructor_args='torch::nn::ReplicationPad3dOptions({1, 2, 3, 3, 2, 1})',
+            input_size=(2, 3, 2, 2, 2),
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='ReplicationPad3d',
+            constructor_args=((1, 2, 3, 3, 2, 1),),
+            cpp_constructor_args='torch::nn::ReplicationPad3dOptions({1, 2, 3, 3, 2, 1})',
+            input_size=(3, 2, 2, 2),
+            reference_fn=single_batch_reference_fn,
+            desc='no_batch_dim',
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='ReplicationPad3d',
+            constructor_args=((1, 2, 3, 3, 2, 1),),
+            cpp_constructor_args='torch::nn::ReplicationPad3dOptions({1, 2, 3, 3, 2, 1})',
+            input_fn=lambda: torch.rand(2, 3, 2, 2, 2, dtype=torch.complex128, requires_grad=True),
+            skip_half=True,
+            desc='complex'
+        ),
+        dict(
+            module_name='Embedding',
+            constructor_args=(4, 3),
+            cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3)',
+            input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
+            check_gradgrad=False,
+            default_dtype=torch.double,
+            decorator=skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/117971")
+        ),
+        dict(
+            module_name='Embedding',
+            constructor_args=(4, 3),
+            cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3)',
+            input_fn=lambda: torch.empty(1, 512, dtype=torch.long).random_(4).expand(7, 512),
+            check_gradgrad=False,
+            desc='discontiguous',
+            default_dtype=torch.double,
+            decorator=skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/117971")
+        ),
+        dict(
+            module_name='EmbeddingBag',
+            constructor_args=(4, 3),
+            cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3)',
+            input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
+            check_gradgrad=False,
+            desc='mean',
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='EmbeddingBag',
+            constructor_args=(4, 3),
+            cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3)',
+            input_fn=lambda: torch.empty(1, 512, dtype=torch.long).random_(4).expand(7, 512),
+            check_gradgrad=False,
+            desc='discontiguous',
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='EmbeddingBag',
+            constructor_args=(4, 3, None, 2., False, 'sum'),
+            cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3)
+                                    .max_norm(std::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kSum)''',
+            input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
+            check_gradgrad=False,
+            desc='sum',
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='EmbeddingBag',
+            constructor_args=(4, 3, None, 2., False, 'max'),
+            cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3)
+                                    .max_norm(std::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kMax)''',
+            input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
+            check_gradgrad=False,
+            desc='max',
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='EmbeddingBag_mean_padding_idx',
+            constructor=lambda: nn.EmbeddingBag(4, 3, padding_idx=1),
+            cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3).padding_idx(1)',
+            input_fn=lambda: torch.stack([torch.randperm(3), torch.randperm(3)]),
+            check_gradgrad=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='EmbeddingBag_sum_padding_idx',
+            constructor=lambda: nn.EmbeddingBag(4, 3, None, 2., False, 'sum', padding_idx=1),
+            cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3)
+                                    .max_norm(std::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kSum).padding_idx(1)''',
+            input_fn=lambda: torch.stack([torch.randperm(3), torch.randperm(3)]),
+            check_gradgrad=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='EmbeddingBag_max_padding_idx',
+            constructor=lambda: nn.EmbeddingBag(4, 3, None, 2., False, 'max', padding_idx=1),
+            cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3)
+                                    .max_norm(std::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kMax).padding_idx(1)''',
+            input_fn=lambda: torch.stack([torch.randperm(3), torch.randperm(3)]),
+            check_gradgrad=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='EmbeddingBag_sparse',
+            constructor=lambda: nn.EmbeddingBag(4, 3, sparse=True, dtype=torch.double),
+            cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3)
+                                    .sparse(true)._weight(torch::rand({4, 3}).to(torch::kFloat64))''',
+            input_fn=lambda: torch.randperm(2).repeat(1, 2),
+            check_gradgrad=False,
+            has_sparse_gradients=True,
+        ),
+        dict(
+            constructor=lambda: nn.Embedding(4, 3, dtype=torch.double, sparse=True),
+            cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3).sparse(true)._weight(torch::rand({4, 3}).to(torch::kFloat64))',
+            input_fn=lambda: torch.randperm(2).repeat(1, 2),
+            fullname='Embedding_sparse',
+            check_gradgrad=False,
+            has_sparse_gradients=True,
+        ),
+        dict(
+            module_name='PixelShuffle',
+            constructor_args=(3,),
+            cpp_constructor_args='torch::nn::PixelShuffleOptions(3)',
+            input_size=(1, 9, 4, 4),
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='PixelUnshuffle',
+            constructor_args=(3,),
+            cpp_constructor_args='torch::nn::PixelUnshuffleOptions(3)',
+            input_size=(1, 1, 12, 12),
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector<int64_t>({12})).scale_factor(std::nullopt).mode(torch::kNearest)''',
+            input_size=(1, 2, 4),
+            fullname='interpolate_nearest_1d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector<int64_t>({12})).scale_factor(std::nullopt).mode(torch::kNearest)''',
+            input_size=(0, 2, 4),
+            fullname='interpolate_nearest_1d_zero_dim',
+            pickle=False,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=(12, ), scale_factor=None, mode='nearest'),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector<int64_t>({12})).scale_factor(std::nullopt).mode(torch::kNearest)''',
+            input_size=(1, 2, 3),
+            fullname='interpolate_nearest_tuple_1d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='nearest'),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::nullopt).scale_factor(std::vector<double>({4.})).mode(torch::kNearest)''',
+            input_size=(1, 2, 4),
+            fullname='interpolate_nearest_scale_1d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='linear', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector<int64_t>({12}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kLinear)
+                                .align_corners(false)''',
+            input_size=(1, 2, 4),
+            fullname='interpolate_linear_1d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=(4, ), scale_factor=None, mode='linear', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector<int64_t>({4}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kLinear)
+                                .align_corners(false)''',
+            input_size=(1, 2, 3),
+            fullname='interpolate_linear_tuple_1d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='linear', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::nullopt)
+                                .scale_factor(std::vector<double>({4.}))
+                                .mode(torch::kLinear)
+                                .align_corners(false)''',
+            input_size=(1, 2, 4),
+            fullname='interpolate_linear_scale_1d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='linear', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector<int64_t>({12}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kLinear)
+                                .align_corners(false)''',
+            input_size=(0, 2, 4),
+            fullname='interpolate_linear_1d_zero_dim',
+            pickle=False,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='linear', align_corners=True),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector<int64_t>({12}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kLinear)
+                                .align_corners(true)''',
+            input_size=(1, 2, 4),
+            fullname='interpolate_linear_1d_align_corners',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='linear', align_corners=True),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::nullopt)
+                                .scale_factor(std::vector<double>({4.}))
+                                .mode(torch::kLinear)
+                                .align_corners(true)''',
+            input_size=(1, 2, 4),
+            fullname='interpolate_linear_scale_1d_align_corners',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=2, scale_factor=None, mode='nearest'),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector<int64_t>({2, 2}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kNearest)''',
+            input_size=(1, 128, 1, 1),
+            fullname='interpolate_nearest_2d_launch_configs',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector<int64_t>({12, 12}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kNearest)''',
+            input_size=(1, 2, 4, 4),
+            fullname='interpolate_nearest_2d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=(12, 16), scale_factor=None, mode='nearest'),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector<int64_t>({12, 16}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kNearest)''',
+            input_size=(1, 2, 3, 4),
+            fullname='interpolate_nearest_tuple_2d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='nearest'),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::nullopt)
+                                .scale_factor(std::vector<double>({4., 4.}))
+                                .mode(torch::kNearest)''',
+            input_size=(1, 2, 4, 4),
+            fullname='interpolate_nearest_scale_2d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector<int64_t>({12, 12}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kNearest)''',
+            input_size=(0, 2, 4, 4),
+            fullname='interpolate_nearest_2d_zero_dim',
+            pickle=False,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bilinear', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector<int64_t>({12, 12}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kBilinear)
+                                .align_corners(false)''',
+            input_size=(1, 2, 4, 4),
+            fullname='interpolate_bilinear_2d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bilinear', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector<int64_t>({12, 12}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kBilinear)
+                                .align_corners(false)''',
+            input_size=(0, 2, 4, 4),
+            fullname='interpolate_bilinear_2d_zero_dim',
+            pickle=False,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None,
+                                        mode='bilinear', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector<int64_t>({4, 6}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kBilinear)
+                                .align_corners(false)''',
+            input_size=(1, 2, 2, 3),
+            fullname='interpolate_bilinear_tuple_2d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=None, scale_factor=4.,
+                                        mode='bilinear', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::nullopt)
+                                .scale_factor(std::vector<double>({4., 4.}))
+                                .mode(torch::kBilinear)
+                                .align_corners(false)''',
+            input_size=(1, 2, 4, 4),
+            fullname='interpolate_bilinear_scale_2d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 2.),
+                                        mode='bilinear', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::nullopt)
+                                .scale_factor(std::vector<double>({2., 2.}))
+                                .mode(torch::kBilinear)
+                                .align_corners(false)''',
+            input_size=(1, 2, 4, 4),
+            fullname='interpolate_bilinear_scale_tuple_shared_2d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.),
+                                        mode='bilinear', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::nullopt)
+                                .scale_factor(std::vector<double>({2., 1.}))
+                                .mode(torch::kBilinear)
+                                .align_corners(false)''',
+            input_size=(1, 2, 4, 4),
+            fullname='interpolate_bilinear_scale_tuple_skewed_2d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None, mode='bilinear', align_corners=True),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector<int64_t>({4, 6}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kBilinear)
+                                .align_corners(true)''',
+            input_size=(1, 2, 4, 4),
+            fullname='interpolate_bilinear_tuple_2d_align_corners',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.),
+                                        mode='bilinear', align_corners=True),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::nullopt)
+                                .scale_factor(std::vector<double>({2., 1.}))
+                                .mode(torch::kBilinear)
+                                .align_corners(true)''',
+            input_size=(1, 2, 4, 4),
+            fullname='interpolate_bilinear_scale_tuple_skewed_2d_align_corners',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bicubic', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector<int64_t>({12, 12}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kBicubic)
+                                .align_corners(false)''',
+            input_size=(1, 2, 4, 4),
+            fullname='interpolate_bicubic_2d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bicubic', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector<int64_t>({12, 12}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kBicubic)
+                                .align_corners(false)''',
+            input_size=(0, 2, 4, 4),
+            fullname='interpolate_bicubic_2d_zero_dim',
+            pickle=False,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None,
+                                        mode='bicubic', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector<int64_t>({4, 6}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kBicubic)
+                                .align_corners(false)''',
+            input_size=(1, 2, 2, 3),
+            fullname='interpolate_bicubic_tuple_2d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='bicubic', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::nullopt)
+                                .scale_factor(std::vector<double>({4., 4.}))
+                                .mode(torch::kBicubic)
+                                .align_corners(false)''',
+            input_size=(1, 2, 4, 4),
+            fullname='interpolate_bicubic_scale_2d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 2.),
+                                        mode='bicubic', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::nullopt)
+                                .scale_factor(std::vector<double>({2., 2.}))
+                                .mode(torch::kBicubic)
+                                .align_corners(false)''',
+            input_size=(1, 2, 4, 4),
+            fullname='interpolate_bicubic_scale_tuple_shared_2d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.),
+                                        mode='bicubic', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .scale_factor(std::vector<double>({2., 1.}))
+                                .scale_factor(std::vector({2., 1.}))
+                                .mode(torch::kBicubic)
+                                .align_corners(false)''',
+            input_size=(1, 2, 4, 4),
+            fullname='interpolate_bicubic_scale_tuple_skewed_2d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None, mode='bicubic', align_corners=True),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector<int64_t>({4, 6}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kBicubic)
+                                .align_corners(true)''',
+            input_size=(1, 2, 4, 4),
+            fullname='interpolate_bicubic_tuple_2d_align_corners',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.),
+                                        mode='bicubic', align_corners=True),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::nullopt)
+                                .scale_factor(std::vector<double>({2., 1.}))
+                                .mode(torch::kBicubic)
+                                .align_corners(true)''',
+            input_size=(1, 2, 4, 4),
+            fullname='interpolate_bicubic_scale_tuple_skewed_2d_align_corners',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector<int64_t>({12, 12, 12}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kNearest)''',
+            input_size=(1, 2, 4, 4, 4),
+            fullname='interpolate_nearest_3d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector<int64_t>({12, 12, 12}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kNearest)''',
+            input_size=(0, 2, 4, 4, 4),
+            fullname='interpolate_nearest_3d_zero_dim',
+            pickle=False,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=(12, 16, 16), scale_factor=None, mode='nearest'),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector<int64_t>({12, 16, 16}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kNearest)''',
+            input_size=(1, 2, 3, 4, 4),
+            fullname='interpolate_nearest_tuple_3d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='nearest'),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::nullopt)
+                                .scale_factor(std::vector({4., 4., 4.}))
+                                .mode(torch::kNearest)''',
+            input_size=(1, 2, 4, 4, 4),
+            fullname='interpolate_nearest_scale_3d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='trilinear', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({12, 12, 12}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kTrilinear)
+                                .align_corners(false)''',
+            input_size=(1, 2, 4, 4, 4),
+            fullname='interpolate_trilinear_3d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='trilinear', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({12, 12, 12}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kTrilinear)
+                                .align_corners(false)''',
+            input_size=(0, 2, 4, 4, 4),
+            fullname='interpolate_trilinear_3d_zero_dim',
+            pickle=False,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=(4, 6, 6),
+                                        scale_factor=None, mode='trilinear', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({4, 6, 6}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kTrilinear)
+                                .align_corners(false)''',
+            input_size=(1, 2, 2, 3, 3),
+            fullname='interpolate_trilinear_tuple_3d',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=None, scale_factor=3., mode='trilinear', align_corners=False),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::nullopt)
+                                .scale_factor(std::vector({3., 3., 3.}))
+                                .mode(torch::kTrilinear)
+                                .align_corners(false)''',
+            input_size=(1, 2, 3, 4, 5),
+            fullname='interpolate_trilinear_scale_3d',
+            # See https://github.com/pytorch/pytorch/issues/5006
+            precision=3e-4,
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=(4, 6, 6), scale_factor=None,
+                                        mode='trilinear', align_corners=True),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::vector({4, 6, 6}))
+                                .scale_factor(std::nullopt)
+                                .mode(torch::kTrilinear)
+                                .align_corners(true)''',
+            input_size=(1, 2, 2, 3, 3),
+            fullname='interpolate_trilinear_tuple_3d_align_corners',
+            pickle=False,
+            default_dtype=torch.double
+        ),
+        dict(
+            constructor=wrap_functional(F.interpolate, size=None, scale_factor=3., mode='trilinear', align_corners=True),
+            cpp_options_args='''F::InterpolateFuncOptions()
+                                .size(std::nullopt)
+                                .scale_factor(std::vector({3., 3., 3.}))
+                                .mode(torch::kTrilinear)
+                                .align_corners(true)''',
+            input_size=(1, 2, 3, 4, 4),
+            fullname='interpolate_trilinear_scale_3d_align_corners',
+            # See https://github.com/pytorch/pytorch/issues/5006
+            precision=3e-4,
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.softmax, dim=-1),
+            cpp_options_args='F::SoftmaxFuncOptions(-1)',
+            input_size=(2, 128),  # trigger the last-dim algo in CUDA
+            fullname='softmax_lastdim',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.softmax, dim=1, dtype=torch.float64),
+            cpp_options_args='F::SoftmaxFuncOptions(1).dtype(torch::kFloat64)',
+            input_size=(2, 128),
+            fullname='softmax_lastdim_dtype',
+            pickle=False,
+            test_cuda=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.softmax, dim=1),
+            cpp_options_args='F::SoftmaxFuncOptions(1)',
+            input_size=(2, 128, 2, 2),  # trigger special case of spatial CUDA algo
+            fullname='softmax_spatial_special',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.softmax, dim=1),
+            cpp_options_args='F::SoftmaxFuncOptions(1)',
+            input_size=(2, 2, 4, 4),  # regular spatial algorithm
+            fullname='softmax_spatial',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.softmax, dim=1, dtype=torch.float64),
+            cpp_options_args='F::SoftmaxFuncOptions(1).dtype(torch::kFloat64)',
+            input_size=(2, 2, 4, 4),  # regular spatial algorithm
+            fullname='softmax_spatial_dtype',
+            pickle=False,
+            test_cuda=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.softmax, dim=0),
+            cpp_options_args='F::SoftmaxFuncOptions(0)',
+            input_size=(2, 3, 4, 5),
+            fullname='softmax_functional_dim0',
+            test_cuda=False,
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.softmax, dim=3),
+            cpp_options_args='F::SoftmaxFuncOptions(3)',
+            input_size=(2, 3, 4, 5),
+            fullname='softmax_functional_dim3',
+            test_cuda=False,
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.softmax, dim=-1),
+            cpp_options_args='F::SoftmaxFuncOptions(-1)',
+            input_size=(),
+            fullname='softmax_functional_scalar',
+            test_cuda=False,
+            pickle=False,
+        ),
+        dict(
+            constructor=wrap_functional(F.log_softmax, dim=-1),
+            cpp_options_args='F::LogSoftmaxFuncOptions(-1)',
+            input_size=(2, 128),  # trigger the last-dim algo in CUDA
+            fullname='log_softmax_lastdim',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.log_softmax, dim=1),
+            cpp_options_args='F::LogSoftmaxFuncOptions(1)',
+            input_size=(2, 128, 2, 2),  # trigger special case of spatial CUDA algo
+            fullname='log_softmax_spatial_special',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.log_softmax, dim=1),
+            cpp_options_args='F::LogSoftmaxFuncOptions(1)',
+            input_size=(2, 2, 4, 4),  # regular spatial algorithm
+            fullname='log_softmax_spatial',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.log_softmax, dim=0),
+            cpp_options_args='F::LogSoftmaxFuncOptions(0)',
+            input_size=(2, 3, 4, 5),
+            fullname='log_softmax_dim0',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.log_softmax, dim=3),
+            cpp_options_args='F::LogSoftmaxFuncOptions(3)',
+            input_size=(2, 3, 4, 5),
+            fullname='log_softmax_dim3',
+            pickle=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            constructor=wrap_functional(F.log_softmax, dim=0),
+            cpp_options_args='F::LogSoftmaxFuncOptions(0)',
+            input_size=(),
+            fullname='log_softmax_scalar',
+            pickle=False,
+        ),
+        dict(
+            fullname='Unfold',
+            constructor=lambda: nn.Unfold((2, 2), (1, 1), (0, 0), (1, 1)),
+            cpp_constructor_args='torch::nn::UnfoldOptions({2, 2}).dilation({1, 1}).padding({0, 0}).stride({1, 1})',
+            input_size=(2, 4, 3, 3),
+            check_gradgrad=False,
+            test_cuda=True,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Fold',
+            constructor=lambda: nn.Fold((3, 3), (2, 2), (1, 1), (0, 0), (1, 1)),
+            cpp_constructor_args='torch::nn::FoldOptions({3, 3}, {2, 2}).dilation({1, 1}).padding({0, 0}).stride({1, 1})',
+            input_size=(2, 16, 4),
+            check_gradgrad=False,
+            test_cuda=True,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Fold_no_batch_dim_input',
+            constructor=lambda: nn.Fold((3, 3), (2, 2), (1, 1), (0, 0), (1, 1)),
+            cpp_constructor_args='torch::nn::FoldOptions({3, 3}, {2, 2}).dilation({1, 1}).padding({0, 0}).stride({1, 1})',
+            input_size=(16, 4),
+            check_gradgrad=False,
+            ref=single_batch_reference_fn,
+            test_cuda=True,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Unfold_int_input',
+            constructor=lambda: nn.Unfold(2, 1, 0, 1),
+            cpp_constructor_args='torch::nn::UnfoldOptions(2).dilation(1).padding(0).stride(1)',
+            input_size=(2, 4, 3, 3),
+            check_gradgrad=False,
+            test_cuda=True,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Fold_int_input',
+            constructor=lambda: nn.Fold(3, 2, 1, 0, 1),
+            cpp_constructor_args='torch::nn::FoldOptions(3, 2).dilation(1).padding(0).stride(1)',
+            input_size=(2, 16, 4),
+            check_gradgrad=False,
+            test_cuda=True,
+            default_dtype=torch.double,
+        ),
+        dict(
+            fullname='Fold_no_batch_dim_int_input',
+            constructor=lambda: nn.Fold(3, 2, 1, 0, 1),
+            cpp_constructor_args='torch::nn::FoldOptions(3, 2).dilation(1).padding(0).stride(1)',
+            input_size=(16, 4),
+            ref=single_batch_reference_fn,
+            check_gradgrad=False,
+            test_cuda=True,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='RReLU',
+            constructor_args=(0.1, 0.9),
+            cpp_constructor_args='torch::nn::RReLUOptions().lower(0.1).upper(0.9)',
+            input_size=(),
+            desc='with_up_down_scalar',
+            test_cuda=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='PairwiseDistance',
+            input_fn=lambda: (torch.randn(10, 8), torch.randn(10, 8)),
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='PairwiseDistance',
+            input_fn=lambda: (torch.randn(10, 1), torch.randn(10, 8)),
+            desc='broadcast_lhs',
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='PairwiseDistance',
+            input_fn=lambda: (torch.randn(10, 8), torch.randn(1, 8)),
+            desc='broadcast_rhs',
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='PairwiseDistance',
+            constructor_args=(1.5, 1e-05, True),
+            cpp_constructor_args='torch::nn::PairwiseDistanceOptions().p(1.5).eps(1e-05).keepdim(true)',
+            input_fn=lambda: (torch.randn(10, 8), torch.randn(10, 8)),
+            desc='with_non_default_args',
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='PairwiseDistance',
+            input_fn=lambda: (torch.randn(8), torch.randn(8)),
+            reference_fn=single_batch_reference_fn,
+            desc='no_batch_dim',
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='TransformerEncoderLayer',
+            constructor_args=(4, 2, 16, 0.0),
+            cpp_constructor_args='''torch::nn::TransformerEncoderLayerOptions(4, 2)
+                                    .dim_feedforward(16)
+                                    .dropout(0.0)''',
+            input_size=(2, 3, 4),
+            desc='relu_activation',
+            with_tf32=True,
+            tf32_precision=0.1,
+            # TODO(#50743): figure out the error
+            # RuntimeError: The size of tensor a (6) must match the size of tensor b (4)
+            # at non-singleton dimension 2
+            check_batched_grad=False,
+            check_gradgrad=False,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='TransformerEncoderLayer',
+            constructor_args=(4, 2, 8, 0.0, F.gelu),
+            cpp_constructor_args='''torch::nn::TransformerEncoderLayerOptions(4, 2)
+                                    .dim_feedforward(8)
+                                    .dropout(0.0)
+                                    .activation(torch::kGELU)''',
+            input_size=(2, 3, 4),
+            check_gradgrad=False,
+            desc='gelu_activation',
+            with_tf32=True,
+            tf32_precision=0.08 if SM90OrLater else 0.05,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='TransformerDecoderLayer',
+            constructor_args=(4, 2, 8, 0.0),
+            cpp_constructor_args='''torch::nn::TransformerDecoderLayerOptions(4, 2)
+                                    .dim_feedforward(8)
+                                    .dropout(0.0)''',
+            input_fn=lambda: (torch.rand(3, 3, 4), torch.rand(2, 3, 4)),
+            check_gradgrad=False,
+            desc='relu_activation',
+            with_tf32=True,
+            tf32_precision=0.05,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='TransformerDecoderLayer',
+            constructor_args=(4, 2, 8, 0.0, F.gelu),
+            cpp_constructor_args='''torch::nn::TransformerDecoderLayerOptions(4, 2)
+                                    .dim_feedforward(8)
+                                    .dropout(0.0)
+                                    .activation(torch::kGELU)''',
+            input_fn=lambda: (torch.rand(3, 3, 4), torch.rand(2, 3, 4)),
+            check_gradgrad=False,
+            desc='gelu_activation',
+            with_tf32=True,
+            tf32_precision=0.05,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Transformer',
+            constructor_args=(4, 2, 2, 2, 8, 0.0, F.relu),
+            cpp_constructor_args='''torch::nn::TransformerOptions()
+                                    .d_model(4)
+                                    .nhead(2)
+                                    .num_encoder_layers(2)
+                                    .num_decoder_layers(2)
+                                    .dim_feedforward(8)
+                                    .dropout(0.0)
+                                    .activation(torch::kReLU)''',
+            input_fn=lambda: (torch.rand(3, 3, 4), torch.rand(2, 3, 4), torch.rand(3, 3)),
+            check_gradgrad=False,
+            desc='multilayer_coder',
+            with_tf32=True,
+            tf32_precision=0.05 if SM90OrLater else 0.03,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Linear',
+            constructor_args=(3, 5),
+            cpp_constructor_args='torch::nn::LinearOptions(3, 5)',
+            input_fn=lambda: torch.rand(3),
+            reference_fn=lambda i, p, _: torch.mm(i.view(1, -1), p[0].t()).view(-1) + p[1],
+            desc="no_batch_dim",
+            with_tf32=True,
+            tf32_precision=0.005,
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Flatten',
+            cpp_constructor_args='torch::nn::FlattenOptions().start_dim(-3).end_dim(-1)',
+            constructor_args=(-3, -1),
+            input_size=(3, 4, 5),
+            reference_fn=single_batch_reference_fn,
+            desc="no_batch_dim",
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='Unflatten',
+            cpp_constructor_args='torch::nn::UnflattenOptions(-2, {2, 2})',
+            constructor_args=(-2, torch.Size([2, 2])),
+            input_size=(3, 4, 5),
+            reference_fn=single_batch_reference_fn,
+            desc="no_batch_dim",
+            default_dtype=torch.double,
+        ),
+        dict(
+            module_name='LayerNorm',
+            constructor_args=([56, 56, 56], 1e-5, False),
+            cpp_constructor_args='torch::nn::LayerNormOptions({56, 56, 56}).eps(1e-5).elementwise_affine(false)',
+            input_size=(4, 56, 56, 56),
+            cudnn=True,
+            check_eval=True,
+            gradcheck_fast_mode=True,
+            check_half=True,
+            desc='3d_no_affine_large_feature',
+        ),
+    ]
+
+    # add conv padding mode tests:
+    for padding_mode, cpp_padding_mode in zip(
+            ['reflect', 'circular', 'replicate', 'zeros'],
+            ['torch::kReflect', 'torch::kCircular', 'torch::kReplicate', 'torch::kZeros']):
+        # conv signature:
+        #     in_channels, out_channels, kernel_size, stride=1,
+        #     padding=0, dilation=1, groups=1,
+        #     bias=True, padding_mode='zeros'
+        for d in (1, 2, 3):
+            if d == 3 and padding_mode == 'reflect':
+                # FIXME: remove after implementing reflection pad 3d
+                #        https://github.com/pytorch/pytorch/issues/27655
+                continue
+            padding = tuple(range(1, d + 1))
+            cpp_padding = '{' + ', '.join(map(str, padding)) + '}'
+            input_size = (2, 2) + (4,) * d
+            output_size = (2, 3) + tuple(p + 1 for p in padding)  # simplified from `(4 + 2 * p - 3) // 2 + 1`
+            new_module_tests.append(
+                dict(
+                    module_name=f'Conv{d}d',
+                    constructor_args=(2, 3, 3, 2, padding, 1, 1, True, padding_mode),
+                    cpp_constructor_args=f'''torch::nn::Conv{d}dOptions(2, 3, 3)
+                                            .stride(2)
+                                            .padding({cpp_padding})
+                                            .dilation(1)
+                                            .groups(1)
+                                            .bias(true)
+                                            .padding_mode({cpp_padding_mode})''',
+                    input_size=input_size,
+                    output_size=output_size,
+                    cudnn=True,
+                    desc=f'{padding_mode}_stride2_pad2',
+                    with_tf32=True,
+                    tf32_precision=0.05,
+                    default_dtype=torch.double,
+                ),
+            )
+
+    # Check that non-linear activations work with no batch dimensions
+    non_linear_activations_no_batch = [
+        'ELU', 'Hardshrink', 'Hardsigmoid', 'Hardtanh', 'Hardswish', 'LeakyReLU',
+        'LogSigmoid', 'PReLU', 'ReLU', 'ReLU6', 'RReLU', 'SELU', 'CELU', 'GELU', 'GLU',
+        'Sigmoid', 'SiLU', 'Mish', 'Softplus', 'Softshrink', 'Softsign', 'Tanh',
+        'Tanhshrink', 'Threshold'
+    ]
+    non_linear_activations_extra_info: dict[str, dict] = {
+        'CELU': {'constructor_args': (2.,), 'default_dtype': torch.double},
+        'Threshold': {'constructor_args': (2., 1.)},
+        'Hardsigmoid': {'check_gradgrad': False, 'check_jit': False, 'default_dtype': torch.double},
+        'Hardswish': {'check_gradgrad': False, 'check_jit': False, 'default_dtype': torch.double},
+        # For RReLU, tests that compare CPU and GPU results fail because the
+        # RNG differs between CPU and GPU
+        'RReLU': {'test_cuda': False, 'default_dtype': torch.double},
+        'ELU': {'default_dtype': torch.double},
+        'GELU': {'default_dtype': torch.double},
+        'GLU': {'default_dtype': torch.double},
+        'Hardshrink': {'default_dtype': torch.double},
+        'Hardtanh': {'default_dtype': torch.double},
+        'LeakyReLU': {'default_dtype': torch.double},
+        'LogSigmoid': {'default_dtype': torch.double},
+        'Mish': {'default_dtype': torch.double},
+        'PReLU': {'default_dtype': torch.double},
+        'ReLU6': {'default_dtype': torch.double},
+        'ReLU': {'default_dtype': torch.double},
+        'SELU': {'default_dtype': torch.double},
+        'SiLU': {'default_dtype': torch.double},
+        'Sigmoid': {'default_dtype': torch.double},
+        'Softplus': {'default_dtype': torch.double},
+        'Softshrink': {'default_dtype': torch.double},
+        'Softsign': {'default_dtype': torch.double},
+        'Tanh': {'default_dtype': torch.double},
+        'Tanhshrink': {'default_dtype': torch.double},
+    }
+    for non_linear_activation in non_linear_activations_no_batch:
+        activation_test_info = dict(
+            module_name=non_linear_activation,
+            input_size=(4,),
+            reference_fn=single_batch_reference_fn,
+            desc='no_batch_dim',
+            test_cpp_api_parity=False,
+        )
+        extra_info = non_linear_activations_extra_info.get(non_linear_activation, {})
+        activation_test_info.update(extra_info)
+        new_module_tests.append(activation_test_info)
+
+
+    return new_module_tests
+
+
+def kldivloss_reference(input, target, reduction='mean', log_target=False):
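+    # `input` holds log-probabilities (as for nn.KLDivLoss); the pointwise KL term is
+    # target * (log(target) - input), computed directly from log-space targets when
+    # log_target is True.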
+    if log_target:
+        result = torch.exp(target) * (target - input)
+    else:
+        result = target * (target.log() - input)
+    if reduction == 'mean':
+        return result.mean()
+    elif reduction == 'sum':
+        return result.sum()
+    elif reduction == 'batchmean' and result.dim() != 0:
+        return result.sum() / result.size(0)
+    return result
+
+
+def nlllossNd_reference(input, target, weight=None, ignore_index=-100,
+                        reduction='mean'):
+    assert input.dim() >= 3
+    N = input.size(0)
+    C = input.size(1)
+    out_size = (N,) + input.size()[2:]
+    output = torch.zeros(out_size).type_as(input)
+
+    if weight is None:
+        weight = torch.ones(C).type_as(input)
+    total_weight = 0
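+    # Walk every output position (N, d1, d2, ...); the loss there is the negated
+    # log-probability of the target class scaled by its class weight, and positions
+    # whose target equals ignore_index contribute zero loss and zero weight.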
+    for tup in product(*[range(size) for size in out_size]):
+        t_nx = target[tup]
+        norm = 0. if ignore_index == t_nx else weight[t_nx].item()
+        input_index = list(tup)
+        input_index.insert(1, t_nx)
+        output[tup] = -input[tuple(input_index)] * norm
+        total_weight += norm
+
+    if reduction == 'mean':
+        return output.sum() / total_weight
+    elif reduction == 'sum':
+        return output.sum()
+    return output
+
+
+def cross_entropy_loss_prob_target_reference(input, target, weight=None, reduction='mean',
+                                             label_smoothing=0.0):
+    assert input.dim() >= 2
+
+    input = torch.log_softmax(input, 1)
+    C = input.size(1)
+    if weight is None:
+        weight = torch.ones(C).type_as(input)
+    weight = weight.view(1, C, *(1 for _ in input.shape[2:]))
+
+    if label_smoothing > 0.0:
+        assert label_smoothing <= 1.0
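+        # Blend the probability targets with a uniform distribution over the C classes.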
+        target = (target * (1 - label_smoothing) + label_smoothing / C)
+
+    output = -(input * target * weight).sum(dim=1)
+    if reduction == 'mean':
+        return output.mean()
+    elif reduction == 'sum':
+        return output.sum()
+    return output
+
+
+def cross_entropy_loss_indices_target_reference(input, target, weight=None, ignore_index=-100,
+                                                reduction='mean', label_smoothing=0.0):
+    log_softmax_input = torch.log_softmax(input, 1)
+    nllloss = F.nll_loss(
+        log_softmax_input,
+        target,
+        weight,
+        ignore_index=ignore_index,
+        reduction=reduction)
+
+    if label_smoothing == 0.0:
+        return nllloss
+
+    assert 0.0 < label_smoothing <= 1.0
+
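+    # Label smoothing mixes the hard-target NLL with the cross entropy against a uniform
+    # distribution over the C classes; `smooth_loss` below is the per-sample sum of
+    # negative log-probabilities, and the two pieces are recombined in the final return.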
+    input = torch.log_softmax(input, 1)
+    C = input.size(1)
+    if weight is not None:
+        input = input * weight.view(1, C, *(1 for _ in input.shape[2:]))
+
+    smooth_loss = -torch.sum(input, 1)
+
+    ignore_mask = target == ignore_index
+    smooth_loss.masked_fill_(ignore_mask, 0.0)
+
+    if reduction == 'mean':
+        if weight is not None:
+            # TODO: This code path can be removed if #61309 is resolved
+            # The loss is normalized by the weights to be consistent with nll_loss_nd
+            ret = torch.sum(smooth_loss) / weight.gather(0, target.masked_select(ignore_mask.logical_not()).flatten()).sum()
+        else:
+            ret = torch.mean(smooth_loss.masked_select(ignore_mask.logical_not()))
+    elif reduction == 'sum':
+        ret = torch.sum(smooth_loss)
+    else:
+        ret = smooth_loss
+
+    return (1 - label_smoothing) * nllloss + ret * (label_smoothing / C)
+
+
+def cross_entropy_loss_reference(input, target, weight=None, ignore_index=-100, reduction='mean',
+                                 label_smoothing=0.0):
+    if input.shape == target.shape:
+        return cross_entropy_loss_prob_target_reference(
+            input,
+            target,
+            weight=weight,
+            reduction=reduction,
+            label_smoothing=label_smoothing)
+    else:
+        return cross_entropy_loss_indices_target_reference(
+            input, target, weight=weight, reduction=reduction,
+            ignore_index=ignore_index, label_smoothing=label_smoothing
+        )
+
+
+def nllloss_reference(input, target, weight=None, ignore_index=-100,
+                      reduction='mean'):
+
+    def nll_loss_helper(input, target, weight, ignore_index):
+        if target == ignore_index:
+            return (0, 0)
+        norm = 1 if weight is None else weight[target]
+        result = -input[target] * norm
+        return (result, norm)
+
+    losses_and_weights = [nll_loss_helper(i, t, weight, ignore_index)
+                          for i, t in zip(input, target)]
+    losses, weights = zip(*losses_and_weights)
+    losses_tensor = input.new_tensor(losses)
+    if reduction == 'mean':
+        return sum(losses_tensor) / sum(weights)
+    elif reduction == 'sum':
+        return sum(losses_tensor)
+    else:
+        return losses_tensor
+
+
+def smoothl1loss_reference(input, target, reduction='mean', beta=1.0):
+    abs_diff = (input - target).abs()
+    ge_beta_mask = (abs_diff >= beta).type_as(abs_diff)
+    lt_beta_mask = (abs_diff < beta).type_as(abs_diff)
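+    # Piecewise definition: 0.5 * diff ** 2 / beta where |diff| < beta, and |diff| - 0.5 * beta otherwise.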
+    # when beta == 0 this reduces to l1_loss
+    if beta == 0:
+        output = abs_diff
+    else:
+        output = ge_beta_mask * (abs_diff - 0.5 * beta) + lt_beta_mask * 0.5 * (abs_diff ** 2) / beta
+    if reduction == 'mean':
+        return output.mean()
+    elif reduction == 'sum':
+        return output.sum()
+    return output
+
+
+def huberloss_reference(input, target, reduction='mean', delta=1.0):
+    abs_diff = (input - target).abs()
+    ge_delta_mask = (abs_diff >= delta)
+    lt_delta_mask = (abs_diff < delta)
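+    # Piecewise definition: 0.5 * diff ** 2 where |diff| < delta, and delta * (|diff| - 0.5 * delta) otherwise.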
+    output = ge_delta_mask * delta * (abs_diff - 0.5 * delta) + lt_delta_mask * 0.5 * (abs_diff ** 2)
+    if reduction == 'mean':
+        return output.mean()
+    elif reduction == 'sum':
+        return output.sum()
+    return output
+
+
+def _multilabelmarginloss_reference(input, target):
+    targets = []
+    for target_index in target:
+        if target_index < 0:
+            break
+        targets.append(target_index)
+
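+    # For every valid target class, accumulate the hinge term max(0, 1 - x[target] + x[i])
+    # over all classes i that are not themselves targets.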
+    sum = 0
+    for target_index in targets:
+        for i in range(0, len(input)):
+            if i not in targets:
+                sum += max(0, 1 - input[target_index] + input[i])
+
+    return sum
+
+
+def multilabelmarginloss_reference(input, target, reduction='mean'):
+    # make everything 2-dimensional
+    input_dim = input.dim()
+    if input.dim() < 2:
+        assert target.dim() < 2
+        input = input.unsqueeze(0) if input.dim() == 1 else input.unsqueeze(0).unsqueeze(0)
+        target = target.unsqueeze(0) if target.dim() == 1 else target.unsqueeze(0).unsqueeze(0)
+
+    n = input.size(0)
+    dim = input.size(1)
+    output = input.new(n).zero_()
+    for i in range(0, n):
+        output[i] = _multilabelmarginloss_reference(input[i], target[i])
+
+    if reduction == 'mean':
+        return output.mean() / dim
+    elif reduction == 'sum':
+        return output.sum() / dim
+    elif input_dim < 2:
+        # we know we have (1, C) X (1, C) -> (1,), so squeeze will get us
+        # back to correct dimensionality
+        return output.squeeze() / dim
+    else:
+        return output / dim
+
+
+def hingeembeddingloss_reference(input, target, margin=1.0, reduction='mean'):
+    margin_clamp = (margin - input).clamp(min=0).type_as(input)
+    output = torch.where(target == 1, input, margin_clamp)
+
+    if reduction == 'mean':
+        return output.mean()
+    elif reduction == 'sum':
+        return output.sum()
+    return output
+
+
+def softmarginloss_reference(input, target, reduction='mean'):
+    output = (1 + (-input * target).exp()).log()
+
+    if reduction == 'mean':
+        return output.mean()
+    elif reduction == 'sum':
+        return output.sum()
+    return output
+
+
+def _multimarginloss_reference(input, target_idx, p, margin, weight):
+    if weight is None:
+        weight = input.new(len(input)).fill_(1)
+
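+    # Sum over all non-target classes i of weight[target] * max(0, margin - x[target] + x[i]) ** p.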
+    output = 0
+    for i in range(0, len(input)):
+        if i != target_idx:
+            output += weight[target_idx] * (max(0, (margin - input[target_idx] + input[i])) ** p)
+    return output
+
+
+def multimarginloss_reference(input, target, p=1, margin=1, weight=None, reduction='mean'):
+    if input.dim() < 2:
+        input = input.unsqueeze(0) if input.dim() == 1 else input.unsqueeze(0).unsqueeze(0)
+
+    target_dim = target.dim()
+    if target.dim() == 0:
+        target = target.unsqueeze(0)
+
+    n = input.size(0)
+    dim = input.size(1)
+    output = input.new(n)
+    for x in range(0, n):
+        output[x] = _multimarginloss_reference(input[x], target[x], p, margin, weight)
+
+    if reduction == 'mean':
+        return output.mean() / dim
+    elif reduction == 'sum':
+        return output.sum() / dim
+    elif target_dim == 0:
+        return output.squeeze(0) / dim
+    return output / dim
+
+
+def cosineembeddingloss_reference(input1, input2, target, margin=0, reduction='mean'):
+    def _cos(a, b):
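+        # Row-wise cosine similarity, with 1e-12 added to each squared norm for numerical stability.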
+        cos = a.new(a.size(0))
+        for i in range(0, a.size(0)):
+            cos[i] = (a[i] * b[i]).sum() / ((((a[i] * a[i]).sum() + 1e-12) * ((b[i] * b[i]).sum() + 1e-12)) ** 0.5)
+        return cos
+
+    output = torch.where(target == 1, 1 - _cos(input1, input2), (_cos(input1, input2) - margin).clamp(min=0))
+
+    if reduction == 'mean':
+        return output.mean()
+    elif reduction == 'sum':
+        return output.sum()
+    return output
+
+
+def tripletmarginloss_reference(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, swap=False,
+                                reduction='mean'):
+    d_p = torch.pairwise_distance(anchor, positive, p, eps)
+    d_n = torch.pairwise_distance(anchor, negative, p, eps)
+    if swap:
+        d_s = torch.pairwise_distance(positive, negative, p, eps)
+        d_n = torch.min(d_n, d_s)
+
+    output = torch.clamp(margin + d_p - d_n, min=0.0)
+    if reduction == 'mean':
+        return output.mean()
+    elif reduction == 'sum':
+        return output.sum()
+    return output
+
+
+def marginrankingloss_reference(input1, input2, target, margin=0, reduction='mean'):
+    output = (-target * (input1 - input2) + margin).clamp(min=0)
+    if reduction == 'mean':
+        return output.mean()
+    elif reduction == 'sum':
+        return output.sum()
+    return output
+
+
+# This directly follows Graves et al.'s paper; in contrast to the production implementation, it does not use log-space
+def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean'):
+    input_lengths = torch.as_tensor(input_lengths, dtype=torch.long)
+    target_lengths = torch.as_tensor(target_lengths, dtype=torch.long)
+    dt = log_probs.dtype
+    log_probs = log_probs.double()  # we need the extra accuracy since we are not in log-space
+    targets = targets.long()
+    cum_target_lengths = target_lengths.cumsum(0)
+    losses = []
+    for i in range(log_probs.size(1)):
+        input_length = input_lengths[i].item()
+        target_length = target_lengths[i].item()
+        cum_target_length = cum_target_lengths[i].item()
+        targets_prime = targets.new_full((2 * target_length + 1,), blank)
+        if targets.dim() == 2:
+            targets_prime[1::2] = targets[i, :target_length]
+        else:
+            targets_prime[1::2] = targets[cum_target_length - target_length:cum_target_length]
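+        # `targets_prime` now holds the blank-interleaved label sequence
+        # (blank, t1, blank, t2, ..., blank); the recursion below is the standard CTC
+        # forward (alpha) pass over that sequence, done in probability space.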
+        probs = log_probs[:input_length, i].exp()
+        alpha = log_probs.new_zeros((target_length * 2 + 1,))
+        alpha[0] = probs[0, blank]
+        alpha[1] = probs[0, targets_prime[1]]
+        mask_third = (targets_prime[:-2] != targets_prime[2:])
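+        # A jump from two positions back is only allowed when the two labels differ,
+        # which rules out skipping across repeated labels or between consecutive blanks.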
+        for t in range(1, input_length):
+            alpha_next = alpha.clone()
+            alpha_next[1:] += alpha[:-1]
+            alpha_next[2:] += torch.where(mask_third, alpha[:-2], alpha.new_zeros(1))
+            alpha = probs[t, targets_prime] * alpha_next
+        losses.append(-alpha[-2:].sum().log()[None])
+    output = torch.cat(losses, 0)
+    if reduction == 'mean':
+        output = (output / target_lengths.to(dtype=output.dtype, device=output.device)).mean()
+    elif reduction == 'sum':
+        output = output.sum()
+    output = output.to(dt)
+    return output
+
+
+loss_reference_fns: dict[str, Callable] = {
+    'KLDivLoss': kldivloss_reference,
+    'KLDivLoss_log_target': partial(kldivloss_reference, log_target=True),
+    'NLLLoss': nllloss_reference,
+    'NLLLossNd': nlllossNd_reference,
+    'SmoothL1Loss': smoothl1loss_reference,
+    'HuberLoss': huberloss_reference,
+    'MultiLabelMarginLoss': multilabelmarginloss_reference,
+    'HingeEmbeddingLoss': hingeembeddingloss_reference,
+    'SoftMarginLoss': softmarginloss_reference,
+    'MultiMarginLoss': multimarginloss_reference,
+    'CosineEmbeddingLoss': cosineembeddingloss_reference,
+    'TripletMarginLoss': tripletmarginloss_reference,
+    'MarginRankingLoss': marginrankingloss_reference,
+    'CTCLoss': ctcloss_reference,
+    'CrossEntropyLoss': cross_entropy_loss_reference
+}
+
+
+criterion_tests = []
+
+
+def single_batch_reference_criterion_fn(*args):
+    """Reference function for criterion supporting no batch dimensions.
+
+    The criterion is passed the input and target in batched form with a single item.
+    The output is squeezed to compare with the no-batch input.
+    """
+    criterion = args[-1]
+
+    def unsqueeze_inp(inp):
+        if isinstance(inp, (list, tuple)):
+            return [t.unsqueeze(0) for t in inp]
+        return inp.unsqueeze(0)
+
+    def flatten(xs):
+        result = []
+        if isinstance(xs, (list, tuple)):
+            for x in xs:
+                result.extend(flatten(x))
+        else:
+            result.append(xs)
+        return result
+
+    single_batch_input_args = flatten([unsqueeze_inp(input) for input in args[:-1]])
+
+    output = criterion(*single_batch_input_args)
+    reduction = get_reduction(criterion)
+
+    if reduction == 'none':
+        return output.squeeze(0)
+    # reduction is 'sum' or 'mean' which results in a scalar
+    return output
+
+
+# Check that regression criteria work with no batch dimensions
+regression_criterion_no_batch = [
+    'L1Loss', 'MSELoss', 'PoissonNLLLoss', 'HuberLoss', 'SmoothL1Loss'
+]
+reductions = ['none', 'mean', 'sum']
+for name, reduction in product(regression_criterion_no_batch, reductions):
+    regression_test_info = dict(
+        fullname=f"{name}_no_batch_dim_{reduction}",
+        constructor=lambda *args, name=name, reduction=reduction: getattr(nn, name)(reduction=reduction),
+        input_size=(3, ),
+        target_size=(3, ),
+        reference_fn=single_batch_reference_criterion_fn,
+        test_cpp_api_parity=False,
+        default_dtype=torch.double,
+    )
+    criterion_tests.append(regression_test_info)
+
+
+for reduction in reductions:
+    regression_test_info = dict(
+        fullname=f"KLDivLoss_no_batch_dim_{reduction}",
+        constructor=lambda reduction=reduction: nn.KLDivLoss(reduction=reduction),
+        input_fn=lambda: torch.rand((3,)).log(),
+        target_fn=lambda: torch.rand((3,)),
+        reference_fn=single_batch_reference_criterion_fn,
+        test_cpp_api_parity=False,
+        default_dtype=torch.double,
+    )
+    criterion_tests.append(regression_test_info)
+
+
+# Check that classification criteria work with no batch dimensions
+# List of tuples of (name, input_fn, target_fn)
+classification_criterion_no_batch = [
+    (
+        'BCELoss',
+        lambda: torch.sigmoid(torch.randn(9, dtype=torch.double)),
+        lambda: torch.randn(9, dtype=torch.double).gt(0).to(torch.double)
+    ),
+    ('BCEWithLogitsLoss', lambda: torch.randn(9, dtype=torch.double), lambda: torch.randn(9, dtype=torch.double)),
+    ('HingeEmbeddingLoss', lambda: torch.randn(9, dtype=torch.double), lambda: torch.tensor([-1, 1, 1] * 3)),
+    ('MultiLabelMarginLoss', lambda: torch.randn(4, dtype=torch.double), lambda: torch.tensor([3, 0, -1, 1])),
+    ('SoftMarginLoss', lambda: torch.randn(9, dtype=torch.double), lambda: torch.tensor([-1, 1, 1] * 3)),
+    ('NLLLoss', lambda: F.log_softmax(torch.randn(3, dtype=torch.double), dim=0), lambda: torch.tensor(1)),
+    (
+        'CosineEmbeddingLoss',
+        lambda: (torch.randn(9, dtype=torch.double), torch.randn(9, dtype=torch.double)),
+        lambda: torch.tensor(1, dtype=torch.double)
+    ),
+    # For MarginRankingLoss, input_fn : (x1, x2) and target_fn : target
+    ('MarginRankingLoss', lambda: (torch.randn(()), torch.randn(())), lambda: torch.randn(()).sign()),
+    # For TripletMarginLoss, input_fn : (anchor, positive) and target_fn : negative
+    (
+        'TripletMarginLoss',
+        lambda: (torch.randn(9, dtype=torch.double), torch.randn(9, dtype=torch.double)),
+        lambda: torch.randn(9, dtype=torch.double)
+    ),
+    ('MultiLabelSoftMarginLoss', lambda: torch.randn(9, dtype=torch.double), lambda: torch.randn(9)),
+]
+classification_criterion_no_batch_extra_info: dict[str, dict] = {
+    'MultiLabelMarginLoss': {'check_gradgrad': False},
+}
+# TODO: Fix these discrepancies
+classification_cpp_parity = {
+    'BCELoss': False,
+    'BCEWithLogitsLoss': False,
+    'HingeEmbeddingLoss': False,
+    'NLLLoss': False,
+    'SoftMarginLoss': False,
+}
+reductions = ['none', 'mean', 'sum']
+for (name, input_fn, target_fn), reduction in product(classification_criterion_no_batch,
+                                                      reductions):
+    classification_test_info = dict(
+        fullname=f"{name}_no_batch_dim_{reduction}",
+        constructor=lambda *args, name=name, reduction=reduction: getattr(nn, name)(reduction=reduction),
+        input_fn=lambda f=input_fn: f(),
+        target_fn=lambda f=target_fn: f(),
+        reference_fn=single_batch_reference_criterion_fn,
+        test_cpp_api_parity=True,
+        has_parity=classification_cpp_parity.get(name, True)
+    )
+    extra_info = classification_criterion_no_batch_extra_info.get(name, {})
+    classification_test_info.update(extra_info)
+    criterion_tests.append(classification_test_info)
+
+
+class NNTestCase(TestCase):
+
+    # _forward is defined in classes inheriting from NNTestCase
+    @abstractmethod
+    def _forward(self, *args, **kwargs):
+        raise NotImplementedError
+
+    @abstractmethod
+    def _get_parameters(self, module: nn.Module) -> tuple[list[nn.Parameter], list[nn.Parameter]]:
+        raise NotImplementedError
+
+    @abstractmethod
+    def _zero_grad_parameters(self, module: nn.Module) -> None:
+        raise NotImplementedError
+
+    @abstractmethod
+    def _backward(self, module: nn.Module,
+                  input: _TensorOrTensors, output: torch.Tensor,
+                  grad_output: Union[torch.Tensor, Sequence[torch.Tensor]],
+                  create_graph: bool = False):
+        raise NotImplementedError
+
+    def _jacobian(self, input, num_out):
+        if isinstance(input, tuple):
+            return tuple(self._jacobian(elem, num_out) for elem in input)
+        elif isinstance(input, list):
+            return [self._jacobian(elem, num_out) for elem in input]
+        else:
+            return torch.zeros(input.nelement(), num_out)
+
+    def _flatten_tensors(self, x):
+        if isinstance(x, torch.Tensor):
+            if x.is_sparse:
+                return x.to_dense().view(-1)
+            else:
+                return x.view(-1)
+        else:
+            return tuple(self._flatten_tensors(a) for a in x)
+
+    def _zero_grad_input(self, input):
+        if isinstance(input, torch.Tensor):
+            if input.requires_grad and input.grad is not None:
+                input.grad.zero_()
+                input.grad.detach_()
+        else:
+            for i in input:
+                self._zero_grad_input(i)
+
+    def _analytical_jacobian(self, module, input: _TensorOrTensors, jacobian_input=True, jacobian_parameters=True):
+        output = self._forward(module, input)
+        output_size = output.nelement()
+
+        if jacobian_input:
+            jacobian_inp = self._jacobian(input, output_size)
+            flat_jacobian_input = list(_iter_tensors(jacobian_inp))
+
+        if jacobian_parameters:
+            num_param = sum(p.numel() for p in self._get_parameters(module)[0])
+            jacobian_param = torch.zeros(num_param, output_size)
+
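+        # Build the Jacobian one column at a time: back-propagate a one-hot grad_output for
+        # each output element i and record the resulting input / parameter gradients as column i.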
+        for i in range(output_size):
+            param, d_param = self._get_parameters(module)
+            # substitute zeros for parameters that have no gradient
+            d_param = [torch.zeros_like(p) if d is None else d for (p, d) in zip(param, d_param)]
+
+            d_out = torch.zeros_like(output)
+            flat_d_out = d_out.view(-1)
+            flat_d_out[i] = 1
+
+            if jacobian_parameters:
+                self._zero_grad_parameters(module)
+            # Tensors will accumulate gradient from multiple steps
+            if jacobian_input:
+                self._zero_grad_input(input)
+            d_input = self._backward(module, input, output, d_out)
+
+            if jacobian_input:
+                for jacobian_x, d_x in zip(flat_jacobian_input, _iter_tensors(d_input)):
+                    jacobian_x[:, i] = d_x.contiguous().view(-1)
+            if jacobian_parameters:
+                jacobian_param[:, i] = torch.cat(self._flatten_tensors(d_param), 0)
+
+        res: tuple[torch.Tensor, ...] = ()
+        if jacobian_input:
+            res += jacobian_inp,
+        if jacobian_parameters:
+            res += jacobian_param,
+
+        return res
+
+    def _numerical_jacobian(self, module, input: _TensorOrTensors, jacobian_input=True, jacobian_parameters=True):
+        def fw(*input):
+            return self._forward(module, input).detach()
+
+        res: tuple[torch.Tensor, ...] = ()
+        if jacobian_input:
+            res += _get_numerical_jacobian(fw, input, eps=1e-6),
+        if jacobian_parameters:
+            param, _ = self._get_parameters(module)
+            to_cat = []
+            for p in param:
+                jacobian = _get_numerical_jacobian(fw, input, target=p, eps=1e-6)
+                # get_numerical_jacobian returns a list of tuples but we require a tensor
+                to_cat.append(jacobian[0][0])
+            res += (torch.cat(to_cat, 0),)
+        return res
+
+    def check_jacobian(self, module, input: _TensorOrTensors, jacobian_input=True):
+        jacobian_parameters = bool(self._get_parameters(module)[0])
+        analytical = self._analytical_jacobian(module, input, jacobian_input, jacobian_parameters)
+        numerical = self._numerical_jacobian(module, input, jacobian_input, jacobian_parameters)
+        analytical_t = list(_iter_tensors(analytical))
+        numerical_t = list(_iter_tensors(numerical))
+
+        differences = []
+        for a, n in zip(analytical_t, numerical_t):
+            if a.numel() != 0:
+                differences.append(a.add(n, alpha=-1).abs().max())
+            # TODO: compare structure (ensure analytic jacobian has correct shape)
+        if len(differences) > 0:
+            self.assertLessEqual(max(differences), PRECISION)  # type: ignore[type-var]
+
+
+class TestBase:
+
+    _required_arg_names = {'constructor_args', 'input', 'extra_args'}
+
+    def __init__(self, constructor, desc='', reference_fn=None, fullname=None, **kwargs):
+        self.desc = desc
+        self.fullname = fullname
+        self.constructor = constructor
+        self.reference_fn = reference_fn
+        for name in self._required_arg_names:
+            if name not in kwargs and name + '_fn' not in kwargs and name + '_size' not in kwargs:
+                if name in {'constructor_args', 'extra_args'}:
+                    kwargs[name] = ()
+                else:
+                    raise ValueError(f"{self.get_name()}: Specify {name} by a value, a function to generate it, or its size!")
+        self._extra_kwargs = kwargs
+        self._arg_cache = {}
+
+    def get_name(self):
+        if self.fullname is not None:
+            return 'test_' + self.fullname
+
+        test_name = 'test_' + self.constructor.__name__
+        if self.desc:
+            test_name += '_' + self.desc
+        return test_name
+
+    def _unpack(self, value):
+        if isinstance(value, torch.Tensor):
+            return value
+        elif is_iterable(value):
+            return type(value)(self._unpack(v) for v in value)
+        else:
+            return value
+
+    @property
+    def constructor_args(self):
+        return self._get_arg('constructor_args', True)
+
+    @property
+    def extra_args(self):
+        return self._get_arg('extra_args', True)
+
+    def _get_arg(self, name, unpack):
+        assert name in self._required_arg_names
+
+        if name not in self._arg_cache:
+            fn_name = name + '_fn'
+            size_name = name + '_size'
+
+            if name in self._extra_kwargs:
+                self._arg_cache[name] = self._extra_kwargs[name]
+            elif fn_name in self._extra_kwargs:
+                self._arg_cache[name] = self._extra_kwargs[fn_name]()
+            else:
+                assert size_name in self._extra_kwargs, \
+                    f"Missing `{name}`, `{size_name}` or `{fn_name}` for {self.get_name()}"
+
+                def map_tensor_sizes(sizes):
+                    if isinstance(sizes, list):
+                        return [map_tensor_sizes(s) for s in sizes]
+                    elif isinstance(sizes, torch.Tensor):
+                        return sizes.double()
+                    else:
+                        return torch.randn(sizes)
+
+                self._arg_cache[name] = map_tensor_sizes(self._extra_kwargs[size_name])
+
+        return self._unpack(self._arg_cache[name]) if unpack else self._arg_cache[name]
+
+    def _get_input(self, unpack=True):
+        return self._get_arg('input', unpack)
+
+    def __call__(self, test_case):
+        raise NotImplementedError
+
+
+class ModuleTest(TestBase):
+
+    @abstractmethod
+    def _do_test(self, test_case: Any, module: nn.Module, input: Any) -> Any:
+        raise NotImplementedError
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.jacobian_input = kwargs.get('jacobian_input', True)
+        self.should_test_cuda = kwargs.get('test_cuda', True)
+        self.should_test_pickle = kwargs.get('pickle', True)
+        self.check_gradgrad = kwargs.get('check_gradgrad', True)
+        self.FIXME_no_cuda_gradgrad_comparison = \
+            kwargs.get('FIXME_no_cuda_gradgrad_comparison', False)
+        self.precision = kwargs.get('precision', 2e-4)
+        self.check_forward_only = kwargs.get('check_forward_only', False)
+        self.default_dtype = kwargs.get('default_dtype', None)
+        if self.default_dtype is None:
+            self.default_dtype = torch.get_default_dtype()
+
+    def __call__(self, test_case):
+        with set_default_dtype(self.default_dtype):
+            module = self.constructor(*self.constructor_args)
+            input = self._get_input()
+
+            if self.reference_fn is not None:
+                out = test_case._forward(module, input)
+                ref_input = deepcopy(input)
+                ref_module = deepcopy(module)
+                expected_out = self.reference_fn(ref_input, test_case._get_parameters(module)[0], ref_module)
+                test_case.assertEqual(out, expected_out, exact_dtype=False)
+            if self.check_forward_only:
+                return
+            self.test_noncontig(test_case, module, input)
+
+            if self.should_test_pickle:
+                # TODO: do this with in-memory files once torch.save supports it
+                with tempfile.TemporaryFile() as f:
+                    test_case._forward(module, input)
+                    torch.save(module, f)
+                    f.seek(0)
+                    # weights_only=False as this is legacy code that saves the model
+                    module_copy = torch.load(f, weights_only=False)
+                    test_case.assertEqual(test_case._forward(module, input), test_case._forward(module_copy, input))
+
+            self._do_test(test_case, module, input)
+
+    def noncontiguize(self, obj):
+        if isinstance(obj, list):
+            return [self.noncontiguize(o) for o in obj]
+        elif isinstance(obj, tuple):
+            return tuple(self.noncontiguize(o) for o in obj)
+        tensor = obj
+        ndim = tensor.dim()
+        # Always making only the last dimension noncontiguous makes it easy to
+        # hide bugs because .view(-1) will still work. So try to find a dim
+        # with size > 1 and make that non-contiguous, i.e., stack + select on
+        # the dimension directly after that.
+        dim = ndim
+        for d in range(ndim):
+            if tensor.size(d) > 1:
+                dim = d + 1
+                break
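+        # Stacking with a throwaway tensor and selecting the second slice keeps the original
+        # values but leaves strides that step over the throwaway data, so the result is
+        # non-contiguous (except for tensors with at most one element).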
+        noncontig = torch.stack([torch.empty_like(tensor), tensor], dim).select(dim, 1).detach()
+        assert noncontig.numel() == 1 or noncontig.numel() == 0 or not noncontig.is_contiguous()
+        noncontig.requires_grad = tensor.requires_grad
+        return noncontig
+
+    def test_noncontig(self, test_case, module, input):
+        # skip scalar inputs: they cannot be made non-contiguous
+        if isinstance(input, torch.Tensor) and input.dim() == 0:
+            return
+        if any(i.dim() == 0 for i in input if isinstance(i, torch.Tensor)):
+            return
+
+        test_case._zero_grad_parameters(module)
+        test_case._zero_grad_input(input)
+        with freeze_rng_state():
+            output = test_case._forward(module, input)
+            if getattr(module, "return_indices", False):
+                output = output[0]
+            grad_output = output.new(output.shape).normal_()
+            output = output.clone()
+            d_input = deepcopy(test_case._backward(module, input, output, grad_output))
+            d_param = deepcopy(test_case._get_parameters(module)[1])
+
+        nc_input = self.noncontiguize(input)
+        nc_grad_output = self.noncontiguize(grad_output)
+        for contig_i, contig_g in product((True, False), repeat=2):
+            i = input if contig_i else nc_input
+            # Some ops, e.g., nn.Flatten, return a gradient that shares
+            # storage with the grad_output. Hence we copy here.
+            go = deepcopy(grad_output if contig_g else nc_grad_output)
+            test_case._zero_grad_parameters(module)
+            test_case._zero_grad_input(i)
+            with freeze_rng_state():
+                out = test_case._forward(module, i)
+                if getattr(module, "return_indices", False):
+                    out = out[0]
+                grad = test_case._backward(module, i, out, go)
+
+                test_case.assertEqual(out, output)
+                test_case.assertEqual(grad, d_input, atol=1e-4, rtol=0)
+                test_case.assertEqual(test_case._get_parameters(module)[1], d_param)
+
+    def test_cuda(self, test_case):
+        if not TEST_CUDA or not self.should_test_cuda:
+            raise unittest.SkipTest('Excluded from CUDA tests')
+
+        with set_default_dtype(self.default_dtype):
+            cpu_input = self._get_input()
+
+            type_map = {torch.double: torch.float}
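+            # The GPU module and inputs run in float32 (doubles are mapped down via type_map),
+            # so CPU/GPU comparisons use the looser self.precision tolerance with exact_dtype=False.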
+            cpu_input_tuple = cpu_input if isinstance(cpu_input, tuple) else (cpu_input,)
+
+            is_any_input_complex = any(isinstance(t, torch.Tensor) and t.dtype.is_complex for t in cpu_input_tuple)
+
+            gpu_input_tuple = to_gpu(cpu_input_tuple, type_map=type_map)
+
+            cpu_module = self.constructor(*self.constructor_args)
+            gpu_module = self.constructor(*self.constructor_args).float().cuda()
+            cpu_param = test_case._get_parameters(cpu_module)
+            gpu_param = test_case._get_parameters(gpu_module)
+            for cpu_p, gpu_p in zip(cpu_param[0], gpu_param[0]):
+                gpu_p.data.copy_(cpu_p)
+
+            test_case._zero_grad_input(cpu_input_tuple)
+            test_case._zero_grad_input(gpu_input_tuple)
+            test_case._zero_grad_parameters(cpu_module)
+            test_case._zero_grad_parameters(gpu_module)
+            cpu_output = test_case._forward(cpu_module, cpu_input_tuple)
+            gpu_output = test_case._forward(gpu_module, gpu_input_tuple)
+            if getattr(cpu_module, "return_indices", False):
+                cpu_output = cpu_output[0]
+                gpu_output = gpu_output[0]
+            test_case.assertEqual(cpu_output, gpu_output, atol=self.precision, rtol=0, exact_dtype=False)
+
+            # Run backwards on CPU and GPU and compare results
+            for _ in range(5):
+                cpu_gradOutput = cpu_output.clone().normal_()
+                gpu_gradOutput = cpu_gradOutput.type_as(gpu_output)
+                cpu_gradInput = test_case._backward(cpu_module, cpu_input_tuple, cpu_output, cpu_gradOutput)
+                gpu_gradInput = test_case._backward(gpu_module, gpu_input_tuple, gpu_output, gpu_gradOutput)
+                test_case.assertEqual(cpu_gradInput, gpu_gradInput, atol=self.precision, rtol=0, exact_dtype=False)
+                for cpu_d_p, gpu_d_p in zip(cpu_param[1], gpu_param[1]):
+                    test_case.assertEqual(cpu_d_p, gpu_d_p, atol=self.precision, rtol=0)
+
+            # Run double-backwards on CPU and GPU and compare results
+            if self.check_gradgrad and not self.FIXME_no_cuda_gradgrad_comparison:
+                cpu_output = cpu_module(*cpu_input_tuple)
+                gpu_output = gpu_module(*gpu_input_tuple)
+                if getattr(cpu_module, "return_indices", False):
+                    cpu_output = cpu_output[0]
+                    gpu_output = gpu_output[0]
+
+                cpu_gradOutput = torch.randn_like(cpu_output, requires_grad=True)
+                gpu_gradOutput = cpu_gradOutput.type_as(gpu_output).detach()
+                gpu_gradOutput.requires_grad = True
+
+                cpu_gradInputs = torch.autograd.grad(
+                    cpu_output,
+                    cpu_input_tuple + tuple(cpu_module.parameters()),
+                    cpu_gradOutput,
+                    create_graph=True)
+                gpu_gradInputs = torch.autograd.grad(
+                    gpu_output,
+                    gpu_input_tuple + tuple(gpu_module.parameters()),
+                    gpu_gradOutput,
+                    create_graph=True)
+
+                for cpu_d_i, gpu_d_i in zip(cpu_gradInputs, gpu_gradInputs):
+                    test_case.assertEqual(cpu_d_i, gpu_d_i, atol=self.precision, rtol=0, exact_dtype=False)
+
+                # We mix output into the second backwards computation so that
+                # torch.autograd.grad doesn't complain that some inputs
+                # are unreachable (which can happen if you differentiate
+                # only on the gradient).
+                if is_any_input_complex:
+                    outputs_cpu = cpu_output.sum().abs() + sum(x.sum().abs() for x in cpu_gradInputs)
+                    outputs_gpu = gpu_output.sum().abs() + sum(x.sum().abs() for x in gpu_gradInputs)
+                else:
+                    outputs_cpu = cpu_output.sum() + sum(x.sum() for x in cpu_gradInputs)
+                    outputs_gpu = gpu_output.sum() + sum(x.sum() for x in gpu_gradInputs)
+
+                cpu_gg = torch.autograd.grad(
+                    outputs_cpu,
+                    cpu_input_tuple + (cpu_gradOutput,) + tuple(cpu_module.parameters()),
+                    retain_graph=True)
+                gpu_gg = torch.autograd.grad(
+                    outputs_gpu,
+                    gpu_input_tuple + (gpu_gradOutput,) + tuple(gpu_module.parameters()),
+                    retain_graph=True)
+                test_case.assertEqual(cpu_gradInput, gpu_gradInput, atol=self.precision, rtol=0, exact_dtype=False)
+                for cpu_d_p, gpu_d_p in zip(cpu_gg, gpu_gg):
+                    test_case.assertEqual(cpu_d_p, gpu_d_p, atol=self.precision, rtol=0, exact_dtype=False)
+
+            self.test_noncontig(test_case, gpu_module, gpu_input_tuple)
+
+
+class InputVariableMixin:
+    def _get_input(self):
+        input = TestBase._get_input(self, False)  # type: ignore[arg-type]
+
+        def map_variables(i):
+            if isinstance(i, torch.Tensor):
+                if i.is_floating_point() or i.is_complex():
+                    i.requires_grad = True
+                return i
+            else:
+                return type(i)(map_variables(elem) for elem in i)
+
+        return map_variables(input)
+
+
+class NewModuleTest(InputVariableMixin, ModuleTest):  # type: ignore[misc]
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.cudnn = kwargs.get('cudnn', False)
+        self.check_inplace = kwargs.get('check_inplace', False)
+        self.check_gradgrad = kwargs.get('check_gradgrad', True)
+        self.skip_double = kwargs.get('skip_double', False)
+        self.skip_half = kwargs.get('skip_half', False)
+        self.with_tf32 = kwargs.get('with_tf32', False)
+        self.tf32_precision = kwargs.get('tf32_precision', 0.001)
+        self.test_cpu = kwargs.get('test_cpu', True)
+        self.has_sparse_gradients = kwargs.get('has_sparse_gradients', False)
+        self.check_batched_grad = kwargs.get('check_batched_grad', True)
+        self.gradcheck_fast_mode = kwargs.get('gradcheck_fast_mode', None)
+        self.supports_forward_ad = kwargs.get('supports_forward_ad', False)
+        self.supports_fwgrad_bwgrad = kwargs.get('supports_fwgrad_bwgrad', False)
+
+    def _check_gradients(self, test_case, module, input_tuple):
+        params = tuple(x for x in module.parameters())
+        num_inputs = len(input_tuple)
+
+        def fn_to_gradcheck(*inputs_and_params, **kwargs):
+            assert not kwargs
+            return test_case._forward(module, inputs_and_params[:num_inputs])
+
+        # gradcheck doesn't support operators that take in dense inputs but
+        # produce sparse gradients for their parameters. This only happens for nn.Embedding
+        # and nn.EmbeddingBag. Instead, we call `self.check_jacobian`, which
+        # is a slightly different version of gradcheck that can handle this.
+        if self.has_sparse_gradients:
+            assert num_inputs == 1
+            test_input_jacobian = torch.is_floating_point(input_tuple[0])
+            test_case.check_jacobian(module, input_tuple[0], test_input_jacobian)
+        else:
+            test_case.assertTrue(gradcheck(fn_to_gradcheck, input_tuple + params,
+                                           check_batched_grad=self.check_batched_grad,
+                                           fast_mode=self.gradcheck_fast_mode,
+                                           check_forward_ad=self.supports_forward_ad))
+
+        if self.check_gradgrad:
+            test_case.assertTrue(gradgradcheck(fn_to_gradcheck, input_tuple + params,
+                                               check_batched_grad=self.check_batched_grad,
+                                               fast_mode=self.gradcheck_fast_mode,
+                                               check_fwd_over_rev=self.supports_fwgrad_bwgrad))
+
+    def _do_test(self, test_case, module, input):
+        num_threads = torch.get_num_threads()
+        torch.set_num_threads(1)
+        input_tuple = input if isinstance(input, tuple) else (input,)
+
+        self._check_gradients(test_case, module, input_tuple)
+
+        # check if module can be printed
+        module.__repr__()
+
+        if self.check_inplace:
+            # check if the inplace variant of the module gives the same result
+            # as the out-of-place
+
+            # check_inplace doesn't support multiple input tensors, since we don't have any modules
+            # that modify the inputs in-place and that accept more than one input
+            assert len(input_tuple) == 1
+            input = input_tuple[0]
+
+            module_ip = self.constructor(*self.constructor_args, inplace=True)
+
+            input_version = input._version
+            with freeze_rng_state():
+                output = module(input)
+            test_case.assertEqual(input._version, input_version)
+
+            input_ip = deepcopy(input)
+            input_ip_clone = input_ip.clone()
+            with freeze_rng_state():
+                output_ip = module_ip(input_ip_clone)
+            test_case.assertNotEqual(input_ip_clone._version, input_version)
+            test_case.assertEqual(output, output_ip)
+            grad = output.data.clone().normal_()
+            if input.grad is not None:
+                with torch.no_grad():
+                    input.grad.zero_()
+            if input_ip.grad is not None:
+                with torch.no_grad():
+                    input_ip.grad.zero_()
+            output.backward(grad)
+            output_ip.backward(grad)
+            test_case.assertEqual(input.grad, input_ip.grad)
+
+        def assert_module_parameters_are(tensor_type, device_id=None):
+            for p in module.parameters():
+                test_case.assertIsInstance(p, tensor_type)
+                if device_id is not None:
+                    test_case.assertEqual(p.get_device(), device_id)
+
+        if all(isinstance(t, torch.LongTensor) for t in input_tuple) and TEST_CUDA:
+            # check that cuda() moves module parameters to correct GPU device,
+            # and that float() casts parameters correctly
+            input_tuple = tuple(t.cuda() for t in input_tuple)
+            module.float().cuda()
+            module(*input_tuple)
+            assert_module_parameters_are(torch.cuda.FloatTensor, 0)  # type: ignore[attr-defined]
+
+            if torch.cuda.device_count() > 1:
+                input_tuple = tuple(t.cuda(1) for t in input_tuple)
+                module.cuda(1)
+                with torch.cuda.device(1):
+                    module(*input_tuple)
+                assert_module_parameters_are(torch.cuda.FloatTensor, 1)  # type: ignore[attr-defined]
+        else:
+            # check that float()/double() casters work correctly
+            def to_type(tensor, real, complex):
+                if tensor.is_complex():
+                    return tensor.to(complex)
+                elif tensor.is_floating_point():
+                    return tensor.to(real)
+                else:
+                    return tensor
+
+            def to_half(x):
+                # TODO: torch.complex32 when properly supported
+                return to_type(x, torch.float16, None)
+
+            def to_single(x):
+                return to_type(x, torch.float32, torch.complex64)
+
+            def to_double(x):
+                return to_type(x, torch.float64, torch.complex128)
+
+            # to float
+            input_tuple = tuple(to_single(t) for t in input_tuple)
+            module.float()
+            module(*input_tuple)
+            assert_module_parameters_are(torch.FloatTensor)
+
+            # and back to double
+            input_tuple = tuple(to_double(t) for t in input_tuple)
+            module.double()
+            module(*input_tuple)
+            assert_module_parameters_are(torch.DoubleTensor)
+
+            if TEST_CUDA and self.should_test_cuda:
+                # check that cuda() moves module parameters to correct GPU device,
+                # and that float() casts parameters correctly
+
+                # to GPU0
+                input_tuple = tuple(to_single(t).cuda() for t in input_tuple)
+                module.float().cuda()
+                module(*input_tuple)
+                assert_module_parameters_are(torch.cuda.FloatTensor, 0)  # type: ignore[attr-defined]
+
+                # to CPU
+                input_tuple = tuple(t.cpu() for t in input_tuple)
+                module.cpu()
+                module(*input_tuple)
+                assert_module_parameters_are(torch.FloatTensor)
+
+                # back to GPU0
+                input_tuple = tuple(t.cuda() for t in input_tuple)
+                module.cuda()
+                module(*input_tuple)
+                assert_module_parameters_are(torch.cuda.FloatTensor, 0)  # type: ignore[attr-defined]
+
+                # test that the module's forward pass runs correctly without cuDNN
+                if self.cudnn:
+                    with torch.backends.cudnn.flags(enabled=False):
+                        module(*input_tuple)
+                        assert_module_parameters_are(torch.cuda.FloatTensor, 0)  # type: ignore[attr-defined]
+
+                if torch.cuda.device_count() >= 2:
+                    # test cross-GPU transfer works
+                    # to GPU1
+                    input_tuple = tuple(t.cuda(1) for t in input_tuple)
+                    module.cuda(1)
+                    with torch.cuda.device(1):
+                        module(*input_tuple)
+                    assert_module_parameters_are(torch.cuda.FloatTensor, 1)  # type: ignore[attr-defined]
+
+                if not self.skip_double:
+                    # test double()
+                    input_tuple = tuple(to_double(t).cuda() for t in input_tuple)
+                    module.double().cuda()
+                    module(*input_tuple)
+                    assert_module_parameters_are(torch.cuda.DoubleTensor, 0)  # type: ignore[attr-defined]
+
+                # test half()
+                if not self.skip_half:
+                    input_tuple = tuple(to_half(t).cuda() for t in input_tuple)
+                    module.half().cuda()
+                    module(*input_tuple)
+                    assert_module_parameters_are(torch.cuda.HalfTensor, 0)  # type: ignore[attr-defined]
+        torch.set_num_threads(num_threads)
+
+    def _get_target(self):
+        return self._get_arg('target', False)
+
+    @property
+    def constructor_args(self):
+        return self._get_arg('constructor_args', False)
+
+
+class CriterionTest(InputVariableMixin, TestBase):  # type: ignore[misc]
+    # TODO: check that criterions don't ignore grad_output
+
+    _required_arg_names = TestBase._required_arg_names.union({'target'})
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.should_test_cuda = kwargs.get('test_cuda', True)
+        self.check_forward_only = kwargs.get('check_forward_only', False)
+        self.check_gradgrad = kwargs.get('check_gradgrad', True)
+        self.check_half = kwargs.get('check_half', True)
+        self.check_bfloat16 = kwargs.get('check_bfloat16', False)
+        self.check_complex = kwargs.get('check_complex', False)
+        self.test_cpu = kwargs.get('test_cpu', True)
+        self.with_tf32 = kwargs.get('with_tf32', True)
+        self.tf32_precision = kwargs.get('tf32_precision', 0.001)
+        self.check_batched_grad = kwargs.get('check_batched_grad', True)
+        self.default_dtype = kwargs.get('default_dtype', None)
+        if self.default_dtype is None:
+            self.default_dtype = torch.get_default_dtype()
+
+    def __call__(self, test_case):
+        with set_default_dtype(self.default_dtype):
+            module = self.constructor(*self.constructor_args)
+            input = self._get_input()
+
+            # Check that these methods don't raise errors
+            module.__repr__()
+            str(module)
+
+            target = self._get_target()
+
+            if self.reference_fn is not None:
+                out = test_case._forward_criterion(module, input, target, extra_args=self.extra_args)
+                ref_args = (deepcopy(input), deepcopy(target)) + self.extra_args + (module,)
+                expected_out = self.reference_fn(*ref_args)
+                test_case.assertEqual(out, expected_out)
+
+            if self.check_forward_only:
+                return
+
+            params = tuple(x for x in module.parameters())
+            if not isinstance(input, tuple):
+                inputs = (input,) + params + (target,)
+
+                def apply_fn(input, target, *params):
+                    return module(input, target)
+            else:
+                inputs = input + params + (target,)
+
+                def apply_fn(input1, input2, target, *params):  # type: ignore[misc]
+                    return module(input1, input2, target)
+
+            gradcheck(apply_fn, inputs, check_batched_grad=self.check_batched_grad)
+
+            if self.check_gradgrad:
+                gradgradcheck(apply_fn, inputs, check_batched_grad=self.check_batched_grad)
+
+    def test_cuda(self, test_case, dtype, extra_args=None):
+        def convert_dtype(obj, dtype, requires_grad=False):
+            if isinstance(obj, torch.Tensor):
+                return obj.detach().to(dtype=dtype).requires_grad_(requires_grad)
+            elif isinstance(obj, tuple):
+                return tuple(convert_dtype(o, dtype, requires_grad) for o in obj)
+            else:
+                return obj
+
+        if not TEST_CUDA or not self.should_test_cuda:
+            raise unittest.SkipTest('Excluded from CUDA tests')
+
+        with set_default_dtype(self.default_dtype):
+            cpu_input = self._get_input()
+            cpu_target = self._get_target()
+            cpu_module = self.constructor(*self.constructor_args)
+            gpu_module = self.constructor(*self.constructor_args)
+
+            # Convert input, target and module parameters to dtype
+            cpu_input = convert_dtype(cpu_input, dtype, True)
+            if cpu_target.is_floating_point() or cpu_target.is_complex():
+                cpu_target = convert_dtype(cpu_target, dtype)
+            cpu_module.type(dtype)
+            gpu_module.type(dtype)
+
+            # GPU setup
+            gpu_input = to_gpu(cpu_input)
+            gpu_target = to_gpu(cpu_target)
+            gpu_module.cuda()
+
+            # torch.HalfTensor doesn't support most operations, converting back to default
+            if dtype in {torch.half, torch.bfloat16}:
+                cpu_input = self._get_input()
+                cpu_target = self._get_target()
+                # Loss modules with weights require consistent input/module weight types
+                cpu_module = self.constructor(*self.constructor_args)
+
+            cpu_output = test_case._forward_criterion(cpu_module, cpu_input, cpu_target, extra_args=extra_args)
+            gpu_output = test_case._forward_criterion(gpu_module, gpu_input, gpu_target, extra_args=extra_args)
+            # dtype could previously be None, so set the precision this way instead of via a precision map
+            test_case.assertEqual(cpu_output, gpu_output,
+                                  atol=1e-1 if dtype in {torch.half, torch.bfloat16} else 4e-4, rtol=0, exact_dtype=False)
+
+            cpu_gradInput = test_case._backward_criterion(
+                cpu_module, cpu_input, cpu_output, cpu_target, extra_args=extra_args)
+            gpu_gradInput = test_case._backward_criterion(
+                gpu_module, gpu_input, gpu_output, gpu_target, extra_args=extra_args)
+            # dtype could previously be None, so set the precision this way instead of via a precision map
+            test_case.assertEqual(cpu_gradInput, gpu_gradInput,
+                                  atol=1e-1 if dtype in {torch.half, torch.bfloat16} else 4e-4, rtol=0, exact_dtype=False)
+
+    def _get_target(self):
+        return self._get_arg('target', False)
+
+    @property
+    def constructor_args(self):
+        return self._get_arg('constructor_args', False)
+
+    @property
+    def extra_args(self):
+        return self._get_arg('extra_args', False)
+
+
+def _test_bfloat16_ops(test_case, op, device, inp_dims=(), prec=1e-2, scale_factor=None):
+    # fp32 compute
+    input1 = torch.randn(inp_dims, dtype=torch.float32, device=device, requires_grad=True)
+    if scale_factor is not None:
+        input1 = (torch.rand(inp_dims, dtype=torch.bfloat16, device=device) * scale_factor).float().requires_grad_()
+    out1 = op(input1)
+    grad_input1 = torch.randn_like(out1, device=device)
+    out1.backward(grad_input1)
+
+    # bfloat16 compute
+    op_bfp16 = op.bfloat16()
+    input2 = input1.detach().bfloat16().requires_grad_()
+    grad_input2 = grad_input1.bfloat16()
+    out2 = op_bfp16(input2)
+    out2.backward(grad_input2)
+
+    test_case.assertEqual(out1, out2, atol=prec, rtol=prec, exact_dtype=False)
+    test_case.assertEqual(input1.grad.data, input2.grad.data, atol=prec, rtol=prec, exact_dtype=False)
+
+def _test_module_empty_input(test_case, module, inp, check_size=True, inference=False):
+    if not inference:
+        inp.requires_grad_(True)
+    out = module(inp)
+    if not inference:
+        gO = torch.rand_like(out)
+        out.backward(gO)
+    if check_size:
+        test_case.assertEqual(out.size(), inp.size())
+    if not inference:
+        for p in module.parameters():
+            if p.requires_grad:
+                test_case.assertEqual(p.grad, torch.zeros_like(p.grad))
+        test_case.assertEqual(inp.grad, torch.zeros_like(inp))
+
+
+def _create_basic_net():
+    class Layer(nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.layer_dummy_param = nn.Parameter(torch.empty(3, 5))
+            self.layer_dummy_buf = nn.Buffer(torch.zeros(1, 3, 3, 7))
+
+    class Net(nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.l1 = Layer()
+            self.dummy_param = nn.Parameter(torch.empty(3, 5))
+            self.dummy_buf = nn.Buffer(torch.zeros(7, 3, 3, 1))
+
+    l = Layer()
+    n = Net()
+    s = nn.Sequential(n, n)
+
+    return l, n, s
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_optimizers.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_optimizers.py
new file mode 100644
index 0000000000000000000000000000000000000000..780514e6743970e4ade78368029f5cd403adfca7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_optimizers.py
@@ -0,0 +1,2211 @@
+# mypy: ignore-errors
+
+import functools
+import itertools
+import sys
+import unittest
+from copy import deepcopy
+from enum import Enum
+from typing import Any, Union
+
+import torch
+from torch import Tensor
+from torch.nn import Parameter
+from torch.optim import (
+    Adadelta,
+    Adafactor,
+    Adagrad,
+    Adam,
+    Adamax,
+    AdamW,
+    ASGD,
+    LBFGS,
+    NAdam,
+    Optimizer,
+    RAdam,
+    RMSprop,
+    Rprop,
+    SGD,
+    SparseAdam,
+)
+from torch.optim.lr_scheduler import (
+    ConstantLR,
+    ExponentialLR,
+    LinearLR,
+    PolynomialLR,
+    ReduceLROnPlateau,
+    StepLR,
+)
+from torch.testing._internal.common_device_type import tol, toleranceOverride
+from torch.testing._internal.common_methods_invocations import DecorateInfo
+from torch.testing._internal.common_utils import (
+    _TestParametrizer,
+    skipIfMPS,
+    skipIfTorchDynamo,
+    skipIfXpu,
+    TEST_WITH_TORCHDYNAMO,
+)
+from torch.utils._foreach_utils import _get_foreach_kernels_supported_devices
+
+
+class OptimizerInput:
+    """Contains args / kwargs to be passed to an optimizer constructor."""
+
+    __slots__ = ["params", "kwargs", "desc"]
+
+    def __init__(
+        self,
+        params: Union[
+            list[Parameter], list[Tensor], dict[Any, Any], list[dict[str, Any]]
+        ],
+        kwargs: dict[str, Any],
+        desc: str = "",
+    ):
+        # params can be a list of Tensors OR param_groups OR None
+        self.params = params
+        self.kwargs = kwargs
+        self.desc = desc
+
+    def __repr__(self):
+        return f"params={self.params}, kwargs={self.kwargs}, desc={self.desc}"
+
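+# For illustration: an OptimizerInput usually carries only kwargs and a description,
+# and the consuming test supplies the params (see the optim_inputs_func_* helpers below):
+#
+#   OptimizerInput(params=None, kwargs={"lr": 0.01, "weight_decay": 0.1},
+#                  desc="non-default lr, nonzero weight_decay")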
+
+class OptimizerErrorEnum(Enum):
+    """Enumerates when an error is raised when testing optimizers."""
+
+    CONSTRUCTION_ERROR = 0
+    STEP_ERROR = 1
+
+
+class ErrorOptimizerInput:
+    """
+    An OptimizerInput that will cause the optimizer to throw an error when constructed or stepped.
+    Includes the type and string of the resulting error.
+    """
+
+    __slots__ = ["optimizer_error_input", "error_on", "error_type", "error_regex"]
+
+    def __init__(
+        self,
+        optimizer_error_input,
+        *,
+        error_on=OptimizerErrorEnum.CONSTRUCTION_ERROR,
+        error_type=RuntimeError,
+        error_regex="",
+    ):
+        self.optimizer_error_input = optimizer_error_input
+        self.error_on = error_on
+        self.error_type = error_type
+        self.error_regex = error_regex
+
+
+class OptimizerInfo:
+    """Optimizer information to be used in testing."""
+
+    def __init__(
+        self,
+        optim_cls: Optimizer,  # Class object for the Optimizer under test
+        *,
+        # Function to generate optimizer inputs EXCLUDING params. We delegate params responsibility
+        # to the test using the OptimizerInfo. OptimizerInput.params is likely None.
+        # Can optionally take in device to filter out certain unsupported configs
+        optim_inputs_func,
+        # Tuple of lambdas to generate LRScheduler instances to run with the optimizer for the
+        # LRScheduler tests like test_forloop_goes_right_direction with_lrsched.
+        # We DO NOT expect to thoroughly test LRSchedulers through the optimizers, so not every
+        # LRScheduler configuration will be included. See test_lrscheduler.py for that instead.
+        # A few optimizers like SGD and Adam will test more LRSchedulers.
+        scheduler_inputs=(
+            [
+                lambda opt: StepLR(opt, gamma=0.9, step_size=10),
+                lambda opt: ReduceLROnPlateau(opt),
+            ],
+        ),
+        # A subset of the global-cliquey flags (fused, foreach, differentiable) the optimizer
+        # supports. See NOTE: [optimizer kwarg categories] for what global-cliquey means.
+        supported_impls: tuple[str, ...] = ("foreach", "differentiable"),
+        # A subset of all flags, signifying which ones were only supported after the
+        # original optimizer had already been released. aka impls where we need to check BC.
+        not_og_supported_flags: tuple[str, ...] = (
+            "foreach",
+            "differentiable",
+            "maximize",
+            "capturable",
+        ),
+        # the optim supports passing in sparse gradients as well as dense grads
+        supports_sparse: bool = False,
+        # the optimizer constructor supports passing in capturable as a kwarg
+        has_capturable_arg: bool = False,
+        # the optim only supports one config: sparse grads w/ dense params, see SparseAdam
+        only_supports_sparse_grads: bool = False,
+        # Tuple of (optimizer kwargs, schedulers_constructors) specifically for sparse tests,
+        # with especially tuned hyperparameters. These only apply if the optimizer supports
+        # sparse parameters or grads.
+        metadata_for_sparse=({}, []),
+        # the optim supports complex parameters
+        supports_complex: bool = True,
+        # whether the optimizer.step() function requires a closure to be passed
+        step_requires_closure: bool = False,
+        # whether the optimizer supports per-param options with parameter groups
+        supports_param_groups: bool = True,
+        # whether the optimizer supports parameters on multiple devices
+        supports_multiple_devices: bool = True,
+        skips=(),  # Indicates which tests to skip
+        decorators=None,  # Additional decorators to apply to generated tests
+        optim_error_inputs_func=None,  # Function to generate optim inputs that error
+        supports_fused_on: tuple[str, ...] = (),
+    ):
+        self.optim_cls = optim_cls
+        self.optim_inputs_func = optim_inputs_func
+        self.scheduler_inputs = scheduler_inputs
+        self.supported_impls = supported_impls
+        self.not_og_supported_flags = not_og_supported_flags
+        self.supports_sparse = supports_sparse
+        self.has_capturable_arg = has_capturable_arg
+        self.metadata_for_sparse = metadata_for_sparse
+        self.only_supports_sparse_grads = only_supports_sparse_grads
+        self.supports_complex = supports_complex
+        self.step_requires_closure = step_requires_closure
+        self.supports_param_groups = supports_param_groups
+        self.supports_multiple_devices = supports_multiple_devices
+        self.decorators = (
+            *(decorators if decorators else []),
+            *(skips if skips else []),
+        )
+        self.optim_error_inputs_func = optim_error_inputs_func
+        self.supports_fused_on = supports_fused_on
+
+    def get_decorators(self, test_class, test_name, device, dtype, param_kwargs):
+        result = []
+        for decorator in self.decorators:
+            if isinstance(decorator, DecorateInfo):
+                if decorator.is_active(
+                    test_class, test_name, device, dtype, param_kwargs
+                ):
+                    result.extend(decorator.decorators)
+            else:
+                result.append(decorator)
+        return result
+
+    @property
+    def name(self):
+        return self.optim_cls.__name__
+
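+# A rough sketch of how OptimizerInfo is typically instantiated, using the Adadelta
+# helpers defined below (the kwargs of the real database entries may differ):
+#
+#   OptimizerInfo(
+#       Adadelta,
+#       optim_inputs_func=optim_inputs_func_adadelta,
+#       optim_error_inputs_func=optim_error_inputs_func_adadelta,
+#       supported_impls=("foreach", "differentiable"),
+#       has_capturable_arg=True,
+#   )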
+
+class optims(_TestParametrizer):
+    """Decorator for specifying a list of optimizers over which to run a test."""
+
+    def __init__(self, optim_info_iterable, dtypes=None):
+        self.optim_info_list = list(optim_info_iterable)
+
+        # Optimizers aren't limited to one dtype, as parameters can have different dtypes.
+        # We default to torch.float32, but dtypes should be specified through the passed-in
+        # parameters.
+        self.dtypes = dtypes if dtypes is not None else [torch.float32]
+
+    def _parametrize_test(self, test, generic_cls, device_cls):
+        if device_cls is None:
+            raise RuntimeError(
+                "The @optims decorator is only intended to be used in a device-specific "
+                "context; use it with instantiate_device_type_tests() instead of "
+                "instantiate_parametrized_tests()"
+            )
+
+        for optim_info, dtype in itertools.product(self.optim_info_list, self.dtypes):
+            # Construct the test name; device / dtype parts are handled outside.
+            # See [Note: device and dtype suffix placement]
+            test_name = optim_info.name
+
+            # Construct parameter kwargs to pass to the test.
+            param_kwargs = {"optim_info": optim_info, "dtype": dtype}
+
+            try:
+
+                @functools.wraps(test)
+                def test_wrapper(*args, **kwargs):
+                    return test(*args, **kwargs)
+
+                decorator_fn = functools.partial(
+                    optim_info.get_decorators,
+                    generic_cls.__name__,
+                    test.__name__,
+                    device_cls.device_type,
+                    dtype,
+                )
+
+                yield (test_wrapper, test_name, param_kwargs, decorator_fn)
+            except Exception as ex:
+                # Provides an error message for debugging before rethrowing the exception
+                print(
+                    f"Failed to instantiate {test_name} for module {optim_info.name}!"
+                )
+                raise ex
+
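+# A rough usage sketch (names other than `optims` are assumptions about the consuming
+# test file, not definitions made here): the decorator is paired with
+# instantiate_device_type_tests(), as the error message above suggests.
+#
+#   class TestOptimExample(TestCase):
+#       @optims(optim_db, dtypes=[torch.float32])
+#       def test_step(self, device, dtype, optim_info):
+#           ...
+#
+#   instantiate_device_type_tests(TestOptimExample, globals())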
+
+# Helper function for generating error inputs for all optimizers, used below.
+def get_error_inputs_for_all_optims(device, dtype):
+    if _get_device_type(device) == "cpu":
+        sample_param = Parameter(torch.randn(1, device=device, dtype=dtype))
+        sample_param2 = Parameter(torch.randn(1, device=device, dtype=dtype))
+        return [
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=sample_param,
+                    kwargs={},
+                    desc="invalid param type",
+                ),
+                error_type=TypeError,
+                error_regex="params argument given to the optimizer should be an iterable of Tensors or dicts",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=[sample_param, sample_param],
+                    kwargs={},
+                    desc="a param group cannot have duplicate parameters",
+                ),
+                error_type=UserWarning,
+                error_regex=".*a parameter group with duplicate parameters.*",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=[{"params": sample_param}, {"params": sample_param}],
+                    kwargs={},
+                    desc="duplicate parameters should not occur across param groups either",
+                ),
+                error_type=ValueError,
+                error_regex="some parameters appear in more than one parameter group",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=torch.tensor([0.001, 0.001])),
+                    desc="Tensor lr must be 1-element",
+                ),
+                error_type=ValueError,
+                error_regex="Tensor lr must be 1-element",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=[("weight", sample_param), sample_param2],
+                    kwargs={},
+                    desc="all optimizer params should be with/without names",
+                ),
+                error_type=ValueError,
+                error_regex="all optimizer params should be with/without names. Some param names are missing",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=[
+                        {"params": [sample_param], "lr": 1e-2},
+                        {"params": [("weight", sample_param2)]},
+                    ],
+                    kwargs={},
+                    desc="all optimizer param groups should be with/without names.",
+                ),
+                error_type=ValueError,
+                error_regex="all optimizer param groups should be with/without names. "
+                "cannot add param group with names to the optimizer",
+            ),
+        ]
+    else:
+        return []
+
+
+# ------------------------------------------------------------------------------------------
+# NOTE: [optimizer kwarg categories]
+# We categorize optimizer kwargs as 3 types:
+#  1. optimizer-specific flags are like amsgrad or rho or beta, flags that are specific to
+#     algorithms and thus only show up for certain optimizers. There are many of these, so I
+#     do not bother gathering them all and listing them here. The converse to these would be
+#     global flags that every optimizer ideally _should_ support. We break global flags into
+#     2 further categories and list them all below.
+#  2. global-friendly = ["lr", "weight_decay", "maximize", "capturable"]
+#     global-friendly flags are global flags that play nicely with all other global flags,
+#     i.e., their functionality does not conflict. This means that any pair of the following
+#     flags can be toggled at once (e.g., maximize and weight_decay). Furthermore, any of the
+#     following flags theoretically can be enabled with ANY other global flag, including the
+#     cliquey ones (e.g, capturable and foreach).
+#  3. global-cliquey = ["foreach", "fused", "differentiable"]
+#     global-cliquey flags are global flags that do NOT coexist with other cliquey flags,
+#     usually because they contradict each other in function. For example, one should not flip
+#     both foreach AND fused to True, because they are two differing performance optimizations
+#     of which you can opt into only one.
+#
+# The following optim_inputs_func_* sampling functions only return constructor combinations of
+# optimizer-specific and global-friendly flags. This is because we are confident they would mesh
+# well with additional kwargs. On the flip side of the same coin, we reserve setting the
+# global-cliquey flags to individual tests and fully expect tests to edit OptimizerInput.kwargs.
+
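+# As a concrete sketch of the note above (`params` stands for any iterable of parameters):
+# global-friendly flags compose freely, while cliquey flags conflict with each other, e.g.
+#
+#   >>> Adam(params, lr=1e-3, weight_decay=0.1, maximize=True)  # fine
+#   >>> Adam(params, foreach=True, fused=True)
+#   RuntimeError: `fused` and `foreach` cannot be `True` together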
+
+def optim_inputs_func_adadelta(device, dtype=None):
+    cuda_supported_configs = [
+        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "capturable": True},
+            desc="capturable with weight decay",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"lr": torch.tensor(0.001), "capturable": True},
+            desc="Tensor lr with capturable",
+        ),
+    ]
+
+    return [
+        OptimizerInput(params=None, kwargs={}, desc="default"),
+        OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"),
+        OptimizerInput(
+            params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
+        ),
+        OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "maximize": True},
+            desc="maximize, weight_decay",
+        ),
+        OptimizerInput(
+            params=None, kwargs={"rho": 0.95, "weight_decay": 0.9}, desc="rho"
+        ),
+    ] + (cuda_supported_configs if _get_device_type(device) == "cuda" else [])
+
+
+def optim_error_inputs_func_adadelta(device, dtype):
+    error_inputs = get_error_inputs_for_all_optims(device, dtype)
+    if _get_device_type(device) == "cpu":
+        error_inputs += [
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=1e-2, rho=1.1),
+                    desc="rho should be between 0 and 1",
+                ),
+                error_type=ValueError,
+                error_regex="Invalid rho value: 1.1",
+            ),
+        ]
+    return error_inputs
+
+
+def optim_inputs_func_adafactor(device, dtype=None):
+    return [
+        OptimizerInput(params=None, kwargs={}, desc="default"),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "lr": 0.01},
+            desc="nonzero weight_decay",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "maximize": True},
+            desc="maximize",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"beta2_decay": -1.0},
+            desc="non-default beta2_decay",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"d": 1.5},
+            desc="non-default clipping threshold d",
+        ),
+    ]
+
+
+def optim_error_inputs_func_adafactor(device, dtype):
+    error_inputs = get_error_inputs_for_all_optims(device, dtype)
+    if _get_device_type(device) == "cpu":
+        complex_param = torch.rand(2, 3, device=device, dtype=torch.complex64)
+        complex_param.grad = torch.rand_like(complex_param)
+        error_inputs += [
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(eps=(-1e-30, 1e-3)),
+                    desc="epsilon1 should be >= 0",
+                ),
+                error_type=ValueError,
+                error_regex="epsilon1 should be >= 0",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(d=0.0),
+                    desc="invalid d",
+                ),
+                error_type=ValueError,
+                error_regex="Clipping threshold d should be >= 1",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(beta2_decay=0.8),
+                    desc="invalid beta2_decay",
+                ),
+                error_type=ValueError,
+                error_regex="beta2_decay should be <= 0",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=[complex_param],
+                    kwargs=dict(),
+                    desc="does not support complex parameters",
+                ),
+                error_type=RuntimeError,
+                error_regex="Adafactor does not support complex parameters",
+                error_on=OptimizerErrorEnum.STEP_ERROR,
+            ),
+        ]
+    return error_inputs
+
+
+def optim_inputs_func_adagrad(device, dtype=None):
+    return [
+        OptimizerInput(params=None, kwargs={}, desc="default"),
+        OptimizerInput(
+            params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "maximize": True},
+            desc="maximize",
+        ),
+        OptimizerInput(params=None, kwargs={"lr": 0.1}, desc="non-default lr"),
+        OptimizerInput(
+            params=None,
+            kwargs={"initial_accumulator_value": 0.1, "weight_decay": 0.1},
+            desc="initial_accumulator_value",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"lr": 0.1, "lr_decay": 0.5, "weight_decay": 0.1},
+            desc="lr_decay",
+        ),  # TODO: Move out to testing in param_group?
+        OptimizerInput(
+            params=None,
+            kwargs={"lr": torch.tensor(0.001)},
+            desc="Tensor lr",
+        ),
+    ]
+
+
+def optim_error_inputs_func_adagrad(device, dtype):
+    error_inputs = get_error_inputs_for_all_optims(device, dtype)
+    if _get_device_type(device) == "cpu":
+        error_inputs += [
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=1e-2, lr_decay=-0.5),
+                    desc="lr_decay must be bigger than 0",
+                ),
+                error_type=ValueError,
+                error_regex="Invalid lr_decay value: -0.5",
+            ),
+        ]
+    return error_inputs
+
+
+# TODO: consider tensor LR! See multi_tensor_optimizer_configs in test_optim.py --> tensor LR should work
+# with all implementation code paths...
+def optim_inputs_func_adam(device, dtype=None):
+    cuda_supported_configs = [
+        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "amsgrad": True, "capturable": True},
+            desc="capturable, amsgrad",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"lr": torch.tensor(0.001), "amsgrad": True, "capturable": True},
+            desc="Tensor lr with capturable and amsgrad",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "lr": torch.tensor(0.001),
+                "betas": (torch.tensor(0.9), torch.tensor(0.99)),
+                "amsgrad": True,
+                "capturable": True,
+            },
+            desc="Tensor lr, Tensor betas, with capturable and amsgrad",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "lr": torch.tensor(0.001),
+                "betas": (torch.tensor(0.9), torch.tensor(0.99)),
+                "amsgrad": False,
+                "capturable": True,
+            },
+            desc="Tensor lr, Tensor betas, with capturable",
+        ),
+    ]
+    mps_supported_configs = [
+        OptimizerInput(
+            params=None, kwargs={"lr": torch.tensor(0.01)}, desc="Tensor lr"
+        ),
+    ]
+
+    total = (
+        [
+            OptimizerInput(params=None, kwargs={}, desc="default"),
+            OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"),
+            OptimizerInput(
+                params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
+            ),
+            OptimizerInput(
+                params=None,
+                kwargs={"weight_decay": 0.1, "maximize": True},
+                desc="maximize",
+            ),
+            OptimizerInput(
+                params=None,
+                kwargs={"weight_decay": 0.1, "amsgrad": True},
+                desc="amsgrad",
+            ),
+        ]
+        + (cuda_supported_configs if _get_device_type(device) == "cuda" else [])
+        + (mps_supported_configs if _get_device_type(device) == "mps" else [])
+    )
+    if dtype in (torch.float16,):
+        for input in total:
+            """
+            An eps that is too small makes denom zero for low-precision dtypes:
+            denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
+            For example,
+            >>> a = torch.tensor([0.], dtype=torch.float16)
+            >>> a
+            tensor([0.], dtype=torch.float16)
+            >>> a + 1e-8
+            tensor([0.], dtype=torch.float16)
+            """
+            input.kwargs["eps"] = 0.1
+    return total
+
+
+def optim_error_inputs_func_adam(device, dtype):
+    error_inputs = get_error_inputs_for_all_optims(device, dtype)
+    if _get_device_type(device) == "cpu":
+        error_inputs += [
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=1e-2, betas=(1.0, 0.0)),
+                    desc="beta1 should be between 0 and 1",
+                ),
+                error_type=ValueError,
+                error_regex="Invalid beta parameter at index 0: 1.0",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=1e-2, weight_decay=-1),
+                    desc="weight_decay should be >= 0",
+                ),
+                error_type=ValueError,
+                error_regex="Invalid weight_decay value: -1",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=torch.tensor(0.001), foreach=True),
+                    desc="lr as Tensor doesn't work with foreach & not capturable",
+                ),
+                error_type=ValueError,
+                error_regex="lr as a Tensor is not supported for capturable=False and foreach=True",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=1e-2, betas=(0.9, torch.tensor(0.99))),
+                    desc="betas must be either both floats or both Tensors",
+                ),
+                error_type=ValueError,
+                error_regex="betas must be either both floats or both Tensors",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=1e-2, betas=(torch.tensor(0.9), 0.99)),
+                    desc="betas must be either both floats or both Tensors",
+                ),
+                error_type=ValueError,
+                error_regex="betas must be either both floats or both Tensors",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(
+                        lr=1e-2,
+                        betas=(torch.tensor(0.9), torch.tensor(0.99)),
+                        foreach=True,
+                    ),
+                    desc=r"betas\[0\] as a Tensor is not supported for capturable=False and foreach=True",
+                ),
+                error_type=ValueError,
+                error_regex=r"betas\[0\] as a Tensor is not supported for capturable=False and foreach=True",
+            ),
+        ]
+    if _get_device_type(device) == "cuda":
+        sample_tensor = torch.empty((), device=device, dtype=dtype)
+        error_inputs += [
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=[sample_tensor],
+                    kwargs={"foreach": True, "fused": True},
+                    desc="`fused` and `foreach` cannot be `True` together",
+                ),
+                error_type=RuntimeError,
+                error_regex="`fused` and `foreach` cannot be `True` together",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=[sample_tensor],
+                    kwargs={"fused": True, "differentiable": True},
+                    desc="`fused` does not support `differentiable`",
+                ),
+                error_type=RuntimeError,
+                error_regex="`fused` does not support `differentiable`",
+            ),
+        ]
+    return error_inputs
+
+
+def optim_inputs_func_adamax(device, dtype=None):
+    cuda_supported_configs = [
+        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.9, "maximize": True, "capturable": True},
+            desc="capturable, maximize, weight_decay",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0, "maximize": True, "capturable": True},
+            desc="capturable, maximize",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.9, "maximize": False, "capturable": True},
+            desc="capturable, weight_decay",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "lr": torch.tensor(0.001),
+                "weight_decay": 0.9,
+                "maximize": False,
+                "capturable": True,
+            },
+            desc="capturable, weight_decay, tensor LR",
+        ),
+    ]
+
+    return [
+        OptimizerInput(params=None, kwargs={}, desc="default"),
+        OptimizerInput(params=None, kwargs={"lr": 0.1}, desc="non-default lr"),
+        OptimizerInput(
+            params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"maximize": True},
+            desc="maximize",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "maximize": True},
+            desc="maximize, weight_decay",
+        ),
+    ] + (cuda_supported_configs if _get_device_type(device) == "cuda" else [])
+
+
+def optim_error_inputs_func_adamax(device, dtype):
+    error_inputs = get_error_inputs_for_all_optims(device, dtype)
+    if _get_device_type(device) == "cpu":
+        error_inputs += [
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=1e-2, betas=(0.0, 1.0)),
+                    desc="beta2 should be between 0 and 1",
+                ),
+                error_type=ValueError,
+                error_regex="Invalid beta parameter at index 1: 1.0",
+            ),
+        ]
+    return error_inputs
+
+
+def optim_inputs_func_adamw(device, dtype=None):
+    return optim_inputs_func_adam(device, dtype)
+
+
+def optim_error_inputs_func_adamw(device, dtype):
+    return optim_error_inputs_func_adam(device, dtype)
+
+
+def optim_inputs_func_asgd(device, dtype=None):
+    cuda_supported_configs = [
+        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
+        OptimizerInput(
+            params=None,
+            kwargs={"maximize": True, "capturable": True},
+            desc="maximize, capturable",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "capturable": True},
+            desc="weight_decay, capturable",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "maximize": True, "capturable": True},
+            desc="maximize, weight_decay, capturable",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "lr": torch.tensor(0.001),
+                "weight_decay": 0.1,
+                "maximize": True,
+                "capturable": True,
+            },
+            desc="maximize, weight_decay, capturable, tensor LR",
+        ),
+    ]
+    return [
+        OptimizerInput(params=None, kwargs={}, desc="default"),
+        OptimizerInput(params=None, kwargs={"lambd": 0.1}, desc="non-default lambd"),
+        OptimizerInput(params=None, kwargs={"lr": 0.02}, desc="non-default lr"),
+        OptimizerInput(params=None, kwargs={"t0": 100}, desc="t0"),
+        OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"),
+        OptimizerInput(
+            params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "maximize": True},
+            desc="maximize, nonzero weight_decay",
+        ),
+    ] + (cuda_supported_configs if _get_device_type(device) == "cuda" else [])
+
+
+def optim_error_inputs_func_asgd(device, dtype):
+    error_inputs = get_error_inputs_for_all_optims(device, dtype)
+    if _get_device_type(device) == "cpu":
+        error_inputs += [
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=1e-2, weight_decay=-0.5),
+                    desc="weight_decay should be >= 0",
+                ),
+                error_type=ValueError,
+                error_regex="Invalid weight_decay value: -0.5",
+            ),
+        ]
+    return error_inputs
+
+
+def optim_inputs_func_lbfgs(device, dtype=None):
+    return [
+        OptimizerInput(params=None, kwargs={}, desc="default"),
+        OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"),
+        OptimizerInput(
+            params=None, kwargs={"lr": torch.tensor(0.001)}, desc="Tensor lr"
+        ),
+        OptimizerInput(
+            params=None, kwargs={"tolerance_grad": 1e-6}, desc="tolerance_grad"
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"line_search_fn": "strong_wolfe"},
+            desc="strong_wolfe",
+        ),
+    ]
+
+
+def optim_error_inputs_func_lbfgs(device, dtype):
+    error_inputs = get_error_inputs_for_all_optims(device, dtype)
+    return error_inputs
+
+
+def optim_inputs_func_nadam(device, dtype=None):
+    cuda_supported_configs = [
+        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.9, "momentum_decay": 6e-3, "capturable": True},
+            desc="weight_decay, capturable",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "weight_decay": 0.9,
+                "momentum_decay": 6e-3,
+                "decoupled_weight_decay": True,
+                "capturable": True,
+            },
+            desc="decoupled_weight_decay, capturable",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "lr": torch.tensor(0.001),
+                "weight_decay": 0.9,
+                "momentum_decay": 6e-3,
+                "decoupled_weight_decay": True,
+                "capturable": True,
+            },
+            desc="decoupled_weight_decay, capturable, tensor LR",
+        ),
+    ]
+    return [
+        OptimizerInput(params=None, kwargs={}, desc="default"),
+        OptimizerInput(params=None, kwargs={"lr": 1e-3}, desc="non-default lr"),
+        OptimizerInput(
+            params=None,
+            kwargs={"momentum_decay": 6e-3},
+            desc="non-zero momentum_decay",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "weight_decay": 0.1,
+            },
+            desc="weight_decay",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "momentum_decay": 6e-3},
+            desc="weight_decay, momentum_decay",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "weight_decay": 0.1,
+                "momentum_decay": 6e-3,
+                "decoupled_weight_decay": True,
+            },
+            desc="decoupled_weight_decay",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "maximize": True},
+            desc="maximize",
+        ),
+    ] + (cuda_supported_configs if _get_device_type(device) == "cuda" else [])
+
+
+def optim_error_inputs_func_nadam(device, dtype):
+    error_inputs = get_error_inputs_for_all_optims(device, dtype)
+    if _get_device_type(device) == "cpu":
+        error_inputs += [
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=1e-2, betas=(1.0, 0.0)),
+                    desc="beta1 should be between 0 and 1",
+                ),
+                error_type=ValueError,
+                error_regex="Invalid beta parameter at index 0: 1.0",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=1e-2, momentum_decay=-0.2),
+                    desc="momentum_decay should be >= 0",
+                ),
+                error_type=ValueError,
+                error_regex="Invalid momentum_decay value: -0.2",
+            ),
+        ]
+    return error_inputs
+
+
+# Note: NAdam and RAdam did not originally support the maximize flag.
+def optim_inputs_func_radam(device=None, dtype=None):
+    cuda_supported_configs = [
+        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "capturable": True,
+                "weight_decay": 0.1,
+            },
+            desc="capturable, weight_decay",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "capturable": True,
+                "weight_decay": 0.1,
+                "decoupled_weight_decay": True,
+            },
+            desc="capturable, weight_decay, decoupled_weight_decay",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "lr": torch.tensor(0.001),
+                "capturable": True,
+                "weight_decay": 0.1,
+                "decoupled_weight_decay": True,
+            },
+            desc="capturable, weight_decay, decoupled_weight_decay, tensor LR",
+        ),
+    ]
+    return [
+        OptimizerInput(params=None, kwargs={}, desc="default"),
+        OptimizerInput(params=None, kwargs={"lr": 2e-3}, desc="non-default lr"),
+        OptimizerInput(params=None, kwargs={"eps": 1e-6}, desc="non-default eps"),
+        OptimizerInput(
+            params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "decoupled_weight_decay": True},
+            desc="decoupled_weight_decay",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "maximize": True},
+            desc="maximize",
+        ),
+    ] + (cuda_supported_configs if _get_device_type(device) == "cuda" else [])
+
+
+def optim_error_inputs_func_radam(device, dtype):
+    error_inputs = get_error_inputs_for_all_optims(device, dtype)
+    if _get_device_type(device) == "cpu":
+        error_inputs += [
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=1e-2, betas=(1.0, 0.0)),
+                    desc="beta1 should be between 0 and 1",
+                ),
+                error_type=ValueError,
+                error_regex="Invalid beta parameter at index 0: 1.0",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=1e-2, weight_decay=-1),
+                    desc="weight_decay should be > 0",
+                ),
+                error_type=ValueError,
+                error_regex="Invalid weight_decay value: -1",
+            ),
+        ]
+    return error_inputs
+
+
+def optim_inputs_func_rmsprop(device, dtype=None):
+    cuda_supported_configs = [
+        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "maximize": True, "capturable": True},
+            desc="capturable, maximize",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"lr": torch.tensor(0.001), "capturable": True},
+            desc="Tensor lr with capturable",
+        ),
+    ]
+
+    return [
+        OptimizerInput(params=None, kwargs={}, desc="default"),
+        OptimizerInput(params=None, kwargs={"lr": 1e-3}, desc="non-default lr"),
+        OptimizerInput(
+            params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "maximize": True,
+            },
+            desc="maximize",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "centered": True},
+            desc="centered",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "maximize": True,
+                "weight_decay": 0.1,
+            },
+            desc="maximize, weight_decay",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "centered": True, "momentum": 0.1},
+            desc="momentum",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "weight_decay": 0.1,
+                "centered": True,
+                "momentum": 0.1,
+                "maximize": True,
+            },
+            desc="maximize, centered, weight_decay, w/ momentum",
+        ),
+    ] + (cuda_supported_configs if _get_device_type(device) == "cuda" else [])
+
+
+def optim_error_inputs_func_rmsprop(device, dtype):
+    error_inputs = get_error_inputs_for_all_optims(device, dtype)
+    if _get_device_type(device) == "cpu":
+        error_inputs += [
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=1e-2, momentum=-1.0),
+                    desc="momentum should be between 0 and 1",
+                ),
+                error_type=ValueError,
+                error_regex="Invalid momentum value: -1.0",
+            ),
+        ]
+    return error_inputs
+
+
+def optim_inputs_func_rprop(device, dtype=None):
+    cuda_supported_configs = [
+        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
+        OptimizerInput(
+            params=None,
+            kwargs={"lr": torch.tensor(0.001), "capturable": True},
+            desc="Tensor lr with capturable",
+        ),
+    ]
+
+    return [
+        OptimizerInput(params=None, kwargs={}, desc="default"),
+        OptimizerInput(params=None, kwargs={"lr": 2e-4}, desc="non-default lr"),
+        OptimizerInput(
+            params=None, kwargs={"etas": (0.5, 1.5)}, desc="non-default etas"
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"step_sizes": (2e-6, 100)},
+            desc="non-default step_sizes",
+        ),
+        OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"),
+    ] + (cuda_supported_configs if _get_device_type(device) == "cuda" else [])
+
+
+def optim_error_inputs_func_rprop(device, dtype):
+    error_inputs = get_error_inputs_for_all_optims(device, dtype)
+    if _get_device_type(device) == "cpu":
+        error_inputs += [
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=1e-2, etas=(1.0, 0.5)),
+                    desc="0 < eta1 < 1 < eta2",
+                ),
+                error_type=ValueError,
+                error_regex="Invalid eta values: 1.0, 0.5",
+            ),
+        ]
+    return error_inputs
+
+
+def optim_inputs_func_sgd(device, dtype=None):
+    return [
+        OptimizerInput(params=None, kwargs={}, desc="default"),
+        OptimizerInput(params=None, kwargs={"lr": 1e-2}, desc="non-default lr"),
+        OptimizerInput(
+            params=None, kwargs={"lr": torch.tensor(0.001)}, desc="tensor lr"
+        ),
+        OptimizerInput(
+            params=None, kwargs={"weight_decay": 0.5}, desc="non-zero weight_decay"
+        ),
+        OptimizerInput(params=None, kwargs={"momentum": 0.9}, desc="momentum"),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "maximize": True},
+            desc="maximize",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"momentum": 0.9, "dampening": 0.5},
+            desc="dampening",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"momentum": 0.9, "weight_decay": 0.1},
+            desc="weight_decay w/ momentum",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"momentum": 0.9, "nesterov": True, "weight_decay": 0.1},
+            desc="nesterov",
+        ),
+    ]
+
+
+def optim_error_inputs_func_sgd(device, dtype):
+    error_inputs = get_error_inputs_for_all_optims(device, dtype)
+    if _get_device_type(device) == "cpu":
+        error_inputs += [
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=1e-2, momentum=-0.5),
+                    desc="momentum should be between 0 and 1",
+                ),
+                error_type=ValueError,
+                error_regex="Invalid momentum value: -0.5",
+            ),
+        ]
+    return error_inputs
+
+
+def optim_inputs_func_sparseadam(device, dtype=None):
+    return [
+        OptimizerInput(params=None, kwargs={}, desc="default"),
+        OptimizerInput(
+            params=None, kwargs={"lr": 0.01}, desc="non-default lr"
+        ),  # TODO: Move out to testing in param_group?
+        OptimizerInput(
+            params=None, kwargs={"lr": torch.tensor(0.001)}, desc="Tensor lr"
+        ),
+        OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"),
+    ]
+
+
+def optim_error_inputs_func_sparseadam(device, dtype):
+    error_inputs = get_error_inputs_for_all_optims(device, dtype)
+
+    if _get_device_type(device) == "cpu":
+        error_inputs += [
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(lr=1e-2, betas=(1.0, 0.0)),
+                    desc="beta1 should be between 0 and 1",
+                ),
+                error_type=ValueError,
+                error_regex="Invalid beta parameter at index 0: 1.0",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=[
+                        torch.zeros(
+                            3, layout=torch.sparse_coo, device=device, dtype=dtype
+                        )
+                    ],
+                    kwargs={},
+                    desc="dense params required",
+                ),
+                error_type=ValueError,
+                error_regex="SparseAdam requires dense parameter tensors",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=[
+                        {
+                            "params": [
+                                torch.zeros(
+                                    3,
+                                    layout=torch.sparse_coo,
+                                    device=device,
+                                    dtype=dtype,
+                                )
+                            ]
+                        }
+                    ],
+                    kwargs={},
+                    desc="dense params required in param_groups",
+                ),
+                error_type=ValueError,
+                error_regex="SparseAdam requires dense parameter tensors",
+            ),
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=[torch.rand(2, 3, device=device, dtype=torch.complex64)],
+                    kwargs={},
+                    desc="complex not supported",
+                ),
+                error_type=ValueError,
+                error_regex="SparseAdam does not support complex parameters",
+            ),
+        ]
+    return error_inputs
+
+
+def _get_device_type(device: Union[str, torch.device]) -> str:
+    # Returns the device type as a string, e.g., "cpu" or "cuda"
+    if isinstance(device, torch.device):
+        device = str(device.type)
+    assert isinstance(device, str)
+    return device.split(":")[0]
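+# Example behavior (illustrative values): _get_device_type("cuda:1") returns "cuda",
+# and _get_device_type(torch.device("mps")) returns "mps".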
+
+
+def _get_optim_inputs_including_global_cliquey_kwargs(
+    device, dtype, optim_info, skip=()
+) -> list[OptimizerInput]:
+    """
+    Return a list of all configs for a given optimizer as a list of OptimizerInputs,
+    including configs that have supported global cliquey kwargs (foreach, fused,
+    differentiable) based on optim_info.supported_impls.
+
+    The configs (optim_inputs) returned by optim_info.optim_inputs_func(...)
+    intentionally do NOT include global cliquey kwargs to give flexibility to tests.
+    For example, testing correctness between toggling foreach on and off is now
+    trivial. That said, we sometimes want to test for all possible configs on an
+    optimizer including all supported flags, so this helper returns all optim inputs.
+    """
+    assert all(
+        x in ["foreach", "fused", "differentiable"] for x in skip
+    ), "skip must be a subset of ['foreach', 'fused', 'differentiable']"
+
+    optim_inputs = optim_info.optim_inputs_func(device)
+
+    supported_impls = tuple(
+        x
+        for x in optim_info.supported_impls
+        if x not in skip
+        and (_get_device_type(device) in optim_info.supports_fused_on or x != "fused")
+        and (
+            _get_device_type(device) in _get_foreach_kernels_supported_devices()
+            or x != "foreach"
+        )
+    )
+
+    all_optim_inputs = []
+    for optim_input in optim_inputs:
+        # Add the base config where all the flags are False
+        base_kwargs = deepcopy(optim_input.kwargs)
+        if len(supported_impls) != 0:
+            for flag in supported_impls:
+                base_kwargs[flag] = False
+            all_optim_inputs.append(
+                OptimizerInput(params=None, kwargs=base_kwargs, desc=optim_input.desc)
+            )
+        else:
+            all_optim_inputs.append(optim_input)
+        # Add a config for when each of the global cliquey kwargs is True
+        # Note that in [optimizer kwarg categories], these kwargs are mutually
+        # exclusive, so we do not need to take their Cartesian product.
+        for flag in supported_impls:
+            new_kwargs = deepcopy(base_kwargs)
+            new_kwargs[flag] = True
+            all_optim_inputs.append(
+                OptimizerInput(
+                    params=None, kwargs=new_kwargs, desc=f"{optim_input.desc} & {flag}"
+                )
+            )
+    return all_optim_inputs
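+# Illustrative expansion (a sketch derived from the code above, not real test output):
+# for an optimizer with supported_impls=("foreach", "differentiable"), a base config with
+# desc="default" becomes three OptimizerInputs:
+#   - desc="default":                  {"foreach": False, "differentiable": False}
+#   - desc="default & foreach":        {"foreach": True, "differentiable": False}
+#   - desc="default & differentiable": {"foreach": False, "differentiable": True}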
+
+
+# Database of OptimizerInfo entries in alphabetical order.
+optim_db: list[OptimizerInfo] = [
+    OptimizerInfo(
+        Adadelta,
+        optim_inputs_func=optim_inputs_func_adadelta,
+        optim_error_inputs_func=optim_error_inputs_func_adadelta,
+        supported_impls=("foreach", "differentiable"),
+        has_capturable_arg=True,
+        skips=(
+            DecorateInfo(
+                skipIfTorchDynamo("See #116028"),
+                "TestOptimRenewed",
+                "test_set_default_dtype_works_with_foreach",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
+                ),
+                "TestOptimRenewed",
+                "test_complex_2d",
+            ),
+            # Note on tolerances:
+            # test_correctness_Adadelta_cuda_float32
+            # Mismatched elements: 10 / 100 (10.0%)
+            # Greatest absolute difference: 4.838220775127411e-05 at index (7, 4) (up to 1e-05 allowed)
+            # Greatest relative difference: 0.007270356640219688 at index (7, 2) (up to 1e-05 allowed)
+            # This is due to floating point ordering error + usage of sqrt
+            DecorateInfo(
+                toleranceOverride(
+                    {
+                        torch.float32: tol(
+                            rtol=5.5e-4,
+                            atol=5e-5,
+                        )
+                    }
+                ),
+                "CompiledOptimizerParityTests",
+                "test_correctness",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "This test uses mocks, which dynamo does not support"
+                ),
+                "TestOptimRenewed",
+                "test_defaults_changed_to_foreach",
+            ),
+        ),
+    ),
+    OptimizerInfo(
+        Adafactor,
+        optim_inputs_func=optim_inputs_func_adafactor,
+        optim_error_inputs_func=optim_error_inputs_func_adafactor,
+        supported_impls=("foreach",),
+        not_og_supported_flags=("foreach",),
+        supports_complex=False,
+        skips=(
+            DecorateInfo(
+                unittest.skip("See #133268 regarding dtype being None"),
+                "CompiledOptimizerParityTests",
+                "test_correctness",
+                device_type="cuda",
+                active_if=lambda kwargs: kwargs.get("use_closure", False),
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_can_load_older_state_dict",
+                device_type="cuda",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_deepcopy_copies_all_public_attrs",
+                device_type="cuda",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_foreach_large_tensor",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_foreach_matches_forloop",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_load_nontensor_step",
+                device_type="cuda",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_mixed_device_dtype",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_param_groups_lr",
+                device_type="cuda",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_param_groups_weight_decay",
+                device_type="cuda",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_peak_memory_foreach",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_save_load_equality_with_weights_only",
+                device_type="cuda",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #116028 regarding copy not supported"),
+                "TestOptimRenewed",
+                "test_set_default_dtype_works_with_foreach",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_state_dict_deterministic",
+                device_type="cuda",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_step_is_noop_for_zero_grads",
+                device_type="cuda",
+            ),
+            DecorateInfo(
+                unittest.skip("See #133268 regarding dtype being None"),
+                "CompiledOptimizerParityTests",
+                "test_correctness",
+                device_type="xpu",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_can_load_older_state_dict",
+                device_type="xpu",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_deepcopy_copies_all_public_attrs",
+                device_type="xpu",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_load_nontensor_step",
+                device_type="xpu",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_param_groups_lr",
+                device_type="xpu",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_param_groups_weight_decay",
+                device_type="xpu",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_save_load_equality_with_weights_only",
+                device_type="xpu",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_state_dict_deterministic",
+                device_type="xpu",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("See #133268 regarding dtype being None"),
+                "TestOptimRenewed",
+                "test_step_is_noop_for_zero_grads",
+                device_type="xpu",
+            ),
+        ),
+    ),
+    OptimizerInfo(
+        Adagrad,
+        optim_inputs_func=optim_inputs_func_adagrad,
+        optim_error_inputs_func=optim_error_inputs_func_adagrad,
+        supported_impls=("foreach", "differentiable", "fused"),
+        not_og_supported_flags=(
+            "foreach",
+            "differentiable",
+            "fused",
+            "maximize",
+            "capturable",
+        ),
+        supports_fused_on=("cpu",),
+        supports_sparse=True,
+        metadata_for_sparse=(
+            {"lr": 0.1, "weight_decay": 0, "lr_decay": 0},
+            [
+                lambda opt: StepLR(opt, gamma=1 - 1e-5, step_size=500),
+                lambda opt: ReduceLROnPlateau(opt, threshold=1e-4),
+            ],
+        ),
+        decorators=(
+            DecorateInfo(
+                #  Note on tolerances:
+                #  The difference comes from the fact that the non-fused kernels have
+                #  more dtype cast operations. Another test, test_fused_cpu_matches_cuda,
+                #  makes sure there are no discrepancies between the cuda fused kernel
+                #  and the cpu fused kernel.
+                toleranceOverride(
+                    {
+                        torch.bfloat16: tol(atol=5e-3, rtol=5e-3),
+                        torch.float16: tol(atol=5e-3, rtol=5e-3),
+                    }
+                ),
+                "TestOptimRenewed",
+                "test_fused_matches_forloop",
+            ),
+        ),
+        skips=(
+            DecorateInfo(
+                skipIfTorchDynamo("See #116028"),
+                "TestOptimRenewed",
+                "test_set_default_dtype_works_with_foreach",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
+                ),
+                "TestOptimRenewed",
+                "test_complex_2d",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "This test uses mocks, which dynamo does not support"
+                ),
+                "TestOptimRenewed",
+                "test_defaults_changed_to_foreach",
+            ),
+        ),
+    ),
+    OptimizerInfo(
+        Adam,
+        optim_inputs_func=optim_inputs_func_adam,
+        scheduler_inputs=(
+            [lambda opt: ExponentialLR(opt, gamma=0.9)],
+            [lambda opt: LinearLR(opt, start_factor=0.4, total_iters=4)],
+            [
+                lambda opt: ConstantLR(opt, factor=0.4, total_iters=4),
+                lambda opt: ExponentialLR(opt, gamma=0.9),
+            ],
+            [
+                lambda opt: ExponentialLR(opt, gamma=0.9),
+                lambda opt: ReduceLROnPlateau(opt),
+            ],
+            [lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)],
+            [lambda opt: PolynomialLR(opt, power=0.9, total_iters=4)],
+            [
+                lambda opt: StepLR(opt, gamma=0.9, step_size=10),
+                lambda opt: ReduceLROnPlateau(opt),
+            ],
+        ),
+        optim_error_inputs_func=optim_error_inputs_func_adam,
+        supported_impls=("foreach", "differentiable", "fused"),
+        has_capturable_arg=True,
+        not_og_supported_flags=(
+            "foreach",
+            "differentiable",
+            "fused",
+            "maximize",
+            "capturable",
+        ),
+        supports_fused_on=("cpu", "cuda", "mps"),
+        decorators=(
+            # Expected floating point error between fused and compiled forloop
+            DecorateInfo(
+                toleranceOverride({torch.float64: tol(atol=4.5e-7, rtol=2.2e-6)}),
+                "TestOptimRenewed",
+                "test_fused_matches_forloop",
+                active_if=lambda kwargs: TEST_WITH_TORCHDYNAMO
+                and kwargs["dtype"] == torch.float64,
+            ),
+            DecorateInfo(
+                #  Note on tolerances:
+                #  The difference comes from the fact that the non-fused kernels have
+                #  more dtype cast operations. Another test, test_fused_cpu_matches_cuda,
+                #  makes sure there are no discrepancies between the cuda fused kernel
+                #  and the cpu fused kernel.
+                toleranceOverride(
+                    {
+                        torch.bfloat16: tol(atol=5e-3, rtol=5e-3),
+                        torch.float16: tol(atol=5e-3, rtol=5e-3),
+                    }
+                ),
+                "TestOptimRenewed",
+                "test_fused_matches_forloop",
+            ),
+            DecorateInfo(
+                # Note on tolerances:
+                # Tracking through #127000
+                toleranceOverride(
+                    {
+                        torch.float32: tol(atol=3e-5, rtol=1.3e-06),
+                    }
+                ),
+                "TestCudaOptims",
+                "test_grad_scaling_autocast_fused_optimizers",
+            ),
+        ),
+        skips=(
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
+                ),
+                "TestOptimRenewed",
+                "test_set_default_dtype_works_with_foreach",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
+                ),
+                "TestOptimRenewed",
+                "test_complex_2d",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "This test uses mocks, which dynamo does not support"
+                ),
+                "TestOptimRenewed",
+                "test_defaults_changed_to_foreach",
+            ),
+        ),
+    ),
+    OptimizerInfo(
+        Adamax,
+        optim_inputs_func=optim_inputs_func_adamax,
+        optim_error_inputs_func=optim_error_inputs_func_adamax,
+        supported_impls=("foreach", "differentiable"),
+        has_capturable_arg=True,
+        skips=(
+            DecorateInfo(
+                skipIfTorchDynamo("See #116028"),
+                "TestOptimRenewed",
+                "test_set_default_dtype_works_with_foreach",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
+                ),
+                "TestOptimRenewed",
+                "test_complex_2d",
+            ),
+            DecorateInfo(
+                unittest.skip("Uses too much memory, even for H100, surprisingly."),
+                "TestOptimRenewed",
+                "test_foreach_large_tensor",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "This test uses mocks, which dynamo does not support"
+                ),
+                "TestOptimRenewed",
+                "test_defaults_changed_to_foreach",
+            ),
+        ),
+    ),
+    OptimizerInfo(
+        AdamW,
+        optim_inputs_func=optim_inputs_func_adamw,
+        optim_error_inputs_func=optim_error_inputs_func_adamw,
+        supported_impls=("foreach", "differentiable", "fused"),
+        not_og_supported_flags=(
+            "foreach",
+            "differentiable",
+            "fused",
+            "maximize",
+            "capturable",
+        ),
+        supports_fused_on=("cpu", "cuda", "mps"),
+        has_capturable_arg=True,
+        decorators=(
+            # Expected error between compiled forloop and fused optimizers
+            DecorateInfo(
+                toleranceOverride({torch.float64: tol(atol=4.5e-7, rtol=2.2e-6)}),
+                "TestOptimRenewed",
+                "test_fused_matches_forloop",
+                active_if=lambda kwargs: TEST_WITH_TORCHDYNAMO
+                and kwargs["dtype"] == torch.float64,
+            ),
+            DecorateInfo(
+                toleranceOverride(
+                    #  Note on tolerances:
+                    #  The difference comes from the fact that the non-fused kernels have
+                    #  more dtype cast operations. Another test, test_fused_cpu_matches_cuda,
+                    #  makes sure there are no discrepancies between the cuda fused kernel
+                    #  and the cpu fused kernel.
+                    {
+                        torch.bfloat16: tol(atol=5e-3, rtol=5e-3),
+                        torch.float16: tol(atol=5e-3, rtol=5e-3),
+                    }
+                ),
+                "TestOptimRenewed",
+                "test_fused_matches_forloop",
+            ),
+            # Note on tolerances:
+            # Tracking through #127000
+            DecorateInfo(
+                toleranceOverride(
+                    {
+                        torch.float32: tol(
+                            atol=3e-5,
+                            rtol=1.3e-06,
+                        )
+                    }
+                ),
+                "TestCudaOptims",
+                "test_grad_scaling_autocast_fused_optimizers",
+            ),
+        ),
+        skips=(
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
+                ),
+                "TestOptimRenewed",
+                "test_set_default_dtype_works_with_foreach",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
+                ),
+                "TestOptimRenewed",
+                "test_complex_2d",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "This test uses mocks, which dynamo does not support"
+                ),
+                "TestOptimRenewed",
+                "test_defaults_changed_to_foreach",
+            ),
+        ),
+    ),
+    OptimizerInfo(
+        ASGD,
+        optim_inputs_func=optim_inputs_func_asgd,
+        optim_error_inputs_func=optim_error_inputs_func_asgd,
+        supported_impls=("foreach", "differentiable"),
+        has_capturable_arg=True,
+        skips=(
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
+                ),
+                "TestOptimRenewed",
+                "test_set_default_dtype_works_with_foreach",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
+                ),
+                "TestOptimRenewed",
+                "test_complex_2d",
+            ),
+            DecorateInfo(
+                toleranceOverride(
+                    {
+                        torch.float32: tol(atol=1.5e-5, rtol=1e-5),
+                    }
+                ),
+                "TestOptimRenewed",
+                "test_step_is_noop_for_zero_grads",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "This test uses mocks, which dynamo does not support"
+                ),
+                "TestOptimRenewed",
+                "test_defaults_changed_to_foreach",
+            ),
+            DecorateInfo(
+                unittest.skip(
+                    "ASGD internally changes the weights even with zero grad"
+                ),
+                "TestOptimRenewed",
+                "test_step_is_noop_for_zero_grads",
+            ),
+        ),
+    ),
+    OptimizerInfo(
+        LBFGS,
+        optim_inputs_func=optim_inputs_func_lbfgs,
+        optim_error_inputs_func=optim_error_inputs_func_lbfgs,
+        supported_impls=(),
+        step_requires_closure=True,
+        supports_param_groups=False,
+        supports_multiple_devices=False,
+        skips=(
+            # Fails on MacOS 13.2.1 in CI https://github.com/pytorch/pytorch/issues/117094
+            DecorateInfo(
+                skipIfMPS,
+                "TestOptimRenewed",
+                "test_can_load_older_state_dict",
+                device_type="mps",
+            ),
+            DecorateInfo(
+                toleranceOverride(
+                    {
+                        torch.complex64: tol(
+                            rtol=4.5e-5,
+                            atol=5e-5,
+                        )
+                    }
+                ),
+                "TestOptimRenewed",
+                "test_complex_2d",
+            ),
+            DecorateInfo(
+                unittest.skip("Does not support param groups"),
+                "TestOptimRenewed",
+                "test_param_groups_lr",
+            ),
+            DecorateInfo(
+                unittest.skip("Does not support param groups"),
+                "TestOptimRenewed",
+                "test_param_groups_weight_decay",
+            ),
+            DecorateInfo(
+                unittest.skip("LBFGS doesn't support multidevice"),
+                "TestOptimRenewed",
+                "test_forloop_goes_right_direction_multigpu",
+            ),
+            DecorateInfo(
+                unittest.skip("Does not support param groups"),
+                "TestOptimRenewed",
+                "test_param_group_with_lrscheduler_goes_right_direction",
+            ),
+            # https://github.com/pytorch/pytorch/issues/131398
+            DecorateInfo(
+                unittest.expectedFailure,
+                "CompiledOptimizerParityTests",
+                "test_correctness",
+                active_if=lambda kwargs: sys.platform == "darwin"
+                and kwargs["use_closure"],
+            ),
+        ),
+    ),
+    OptimizerInfo(
+        NAdam,
+        optim_inputs_func=optim_inputs_func_nadam,
+        optim_error_inputs_func=optim_error_inputs_func_nadam,
+        supported_impls=("foreach", "differentiable"),
+        has_capturable_arg=True,
+        skips=(
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
+                ),
+                "TestOptimRenewed",
+                "test_set_default_dtype_works_with_foreach",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
+                ),
+                "TestOptimRenewed",
+                "test_complex_2d",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Errors, https://github.com/pytorch/pytorch/issues/117150"
+                ),
+                "TestOptimRenewed",
+                "test_load_nontensor_step",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "This test uses mocks, which dynamo does not support"
+                ),
+                "TestOptimRenewed",
+                "test_defaults_changed_to_foreach",
+            ),
+        ),
+    ),
+    OptimizerInfo(
+        RAdam,
+        optim_inputs_func=optim_inputs_func_radam,
+        optim_error_inputs_func=optim_error_inputs_func_radam,
+        supported_impls=("foreach", "differentiable"),
+        has_capturable_arg=True,
+        skips=(
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
+                ),
+                "TestOptimRenewed",
+                "test_set_default_dtype_works_with_foreach",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
+                ),
+                "TestOptimRenewed",
+                "test_complex_2d",
+            ),
+            DecorateInfo(
+                toleranceOverride(
+                    {
+                        # previously atol=1e-7, rtol=1e-7
+                        torch.float64: tol(atol=1.5e-7, rtol=1.1e-7)
+                    }
+                ),
+                "TestOptimRenewed",
+                "test_foreach_matches_forloop",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "This test uses mocks, which dynamo does not support"
+                ),
+                "TestOptimRenewed",
+                "test_defaults_changed_to_foreach",
+            ),
+        ),
+    ),
+    OptimizerInfo(
+        RMSprop,
+        optim_inputs_func=optim_inputs_func_rmsprop,
+        optim_error_inputs_func=optim_error_inputs_func_rmsprop,
+        supported_impls=("foreach", "differentiable"),
+        has_capturable_arg=True,
+        skips=(
+            DecorateInfo(
+                skipIfTorchDynamo("See #116028"),
+                "TestOptimRenewed",
+                "test_set_default_dtype_works_with_foreach",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
+                ),
+                "TestOptimRenewed",
+                "test_complex_2d",
+            ),
+            DecorateInfo(
+                toleranceOverride(
+                    {  # previously atol=5e-05, rtol=0.001, https://github.com/pytorch/pytorch/issues/116202
+                        torch.float32: tol(atol=5e-04, rtol=0.01),
+                    }
+                ),
+                "TestOptimRenewed",
+                "test_mixed_device_dtype",
+                active_if=TEST_WITH_TORCHDYNAMO,
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "This test uses mocks, which dynamo does not support"
+                ),
+                "TestOptimRenewed",
+                "test_defaults_changed_to_foreach",
+            ),
+        ),
+    ),
+    OptimizerInfo(
+        Rprop,
+        optim_inputs_func=optim_inputs_func_rprop,
+        optim_error_inputs_func=optim_error_inputs_func_rprop,
+        supported_impls=("foreach", "differentiable"),
+        has_capturable_arg=True,
+        skips=(
+            DecorateInfo(
+                skipIfTorchDynamo("See #116028"),
+                "TestOptimRenewed",
+                "test_set_default_dtype_works_with_foreach",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
+                ),
+                "TestOptimRenewed",
+                "test_complex_2d",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "This test uses mocks, which dynamo does not support"
+                ),
+                "TestOptimRenewed",
+                "test_defaults_changed_to_foreach",
+            ),
+        ),
+    ),
+    OptimizerInfo(
+        SGD,
+        optim_inputs_func=optim_inputs_func_sgd,
+        scheduler_inputs=(
+            [lambda opt: StepLR(opt, gamma=0.9, step_size=10)],
+            [
+                lambda opt: LinearLR(
+                    opt, start_factor=0.4, end_factor=0.8, total_iters=4
+                )
+            ],
+            [
+                lambda opt: StepLR(opt, gamma=0.9, step_size=10),
+                lambda opt: LinearLR(
+                    opt, start_factor=0.4, end_factor=0.6, total_iters=4
+                ),
+            ],
+            [
+                lambda opt: StepLR(opt, gamma=0.99, step_size=10),
+                lambda opt: ExponentialLR(opt, gamma=0.99),
+                lambda opt: ReduceLROnPlateau(opt),
+            ],
+            [lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)],
+            [lambda opt: PolynomialLR(opt, power=0.9, total_iters=4)],
+            [
+                lambda opt: StepLR(opt, gamma=0.9, step_size=10),
+                lambda opt: ReduceLROnPlateau(opt),
+            ],
+        ),
+        optim_error_inputs_func=optim_error_inputs_func_sgd,
+        supported_impls=("foreach", "differentiable", "fused"),
+        not_og_supported_flags=(
+            "foreach",
+            "differentiable",
+            "fused",
+            "maximize",
+            "capturable",
+        ),
+        supports_sparse=True,
+        metadata_for_sparse=(
+            {
+                "lr": 4.8e-3,
+                "maximize": False,
+                "momentum": 0,
+                "nesterov": False,
+                "weight_decay": 0,
+            },
+            [lambda opt: StepLR(opt, gamma=0.99999, step_size=300)],
+        ),
+        supports_fused_on=(
+            "cpu",
+            "cuda",
+            "mps",
+        ),
+        skips=(
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
+                ),
+                "TestOptimRenewed",
+                "test_set_default_dtype_works_with_foreach",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
+                ),
+                "TestOptimRenewed",
+                "test_complex_2d",
+            ),
+            DecorateInfo(
+                toleranceOverride(
+                    {  # previously atol=5e-05, rtol=0.001, https://github.com/pytorch/pytorch/issues/116202
+                        torch.float32: tol(atol=5e-04, rtol=0.007),
+                    }
+                ),
+                "TestOptimRenewed",
+                "test_mixed_device_dtype",
+                active_if=TEST_WITH_TORCHDYNAMO,
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "This test uses mocks, which dynamo does not support"
+                ),
+                "TestOptimRenewed",
+                "test_defaults_changed_to_foreach",
+            ),
+        ),
+    ),
+    OptimizerInfo(
+        SparseAdam,
+        optim_inputs_func=optim_inputs_func_sparseadam,
+        optim_error_inputs_func=optim_error_inputs_func_sparseadam,
+        supported_impls=(),
+        only_supports_sparse_grads=True,
+        metadata_for_sparse=({"lr": 4e-2}, []),
+        supports_complex=False,  # Missing complex support, see #118153
+        skips=(
+            DecorateInfo(
+                skipIfMPS,  # SparseAdam does not support MPS
+                "TestOptimRenewed",
+                device_type="mps",
+            ),
+            DecorateInfo(
+                skipIfXpu(msg="SparseAdam is not yet supported on the XPU stack"),
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
+                "TestOptimRenewed",
+                "test_param_groups_lr",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
+                "TestOptimRenewed",
+                "test_tensor_lr",
+            ),
+            DecorateInfo(
+                unittest.skip(
+                    "SparseAdam does not support dense gradients, see #116507"
+                ),
+                "TestOptimRenewed",
+                "test_can_load_older_state_dict",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
+                "TestOptimRenewed",
+                "test_load_nontensor_step",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
+                "TestOptimRenewed",
+                "test_forloop_goes_right_direction",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
+                "TestOptimRenewed",
+                "test_forloop_goes_right_direction_multigpu",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
+                "TestOptimRenewed",
+                "test_param_group_with_lrscheduler_goes_right_direction",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
+                "TestOptimRenewed",
+                "test_state_dict_with_cuda_params",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
+                "TestOptimRenewed",
+                "test_deepcopy_copies_all_public_attrs",
+            ),
+        ),
+    ),
+]
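+# A hedged usage sketch (illustrative; not part of this module's API surface): tests
+# typically iterate this database, construct each optimizer from its info entry, and
+# exercise every config, roughly:
+#
+#     for optim_info in optim_db:
+#         for optim_input in optim_info.optim_inputs_func(device="cpu"):
+#             params = [torch.randn(2, 3, requires_grad=True)]
+#             optimizer = optim_info.optim_cls(params, **optim_input.kwargs)
+#             # ... run steps / checks guided by optim_info.skips, supported_impls, etc.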
+
+
+class TensorTracker:
+    """
+    A utility to track tensor clones in a list, with the expectation of popping them later (in
+    order) to make fair comparisons between two multi-step computations. The intended use case is
+    comparing two supposedly equal computations, such as two optimizer runs that each consist of
+    multiple steps, where numerical deviations could compound.
+
+    The goal is to be able to compare and align numbers at every milestone so as to minimize
+    numerical discrepancies, so that when the test fails, it is likely a real problem.
+    """
+
+    def __init__(self, assert_eq_kwargs=None):
+        if assert_eq_kwargs is None:
+            assert_eq_kwargs = {}
+        self.assert_eq_kwargs = assert_eq_kwargs
+        self.tensors = []
+
+    def add(self, tensor):
+        """
+        Add a detach().clone()'d version of the tensor
+        """
+        self.tensors.append(tensor.detach().clone())
+
+    # pops from beginning, like a queue and not a stack!
+    def pop_check_set(self, tensor_to_set, testcase):
+        """
+        Pop the first element in the tensor tracker, assert equality between the popped tensor and
+        the input tensor, and then set the input tensor to have the same values as the popped tensor
+        (with copy_).
+        """
+        testcase.assertGreater(len(self.tensors), 0, "no tensors to pop")
+        ref = self.tensors.pop(0)
+
+        testcase.assertTrue(isinstance(ref, Tensor), f"{type(ref)=}")
+        testcase.assertEqual(tensor_to_set, ref, **self.assert_eq_kwargs)
+
+        with torch.no_grad():
+            tensor_to_set.copy_(ref)
+
+    def all_popped(self):
+        return len(self.tensors) == 0
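+
+
+# A minimal usage sketch (hypothetical test code, not part of this file): record each
+# intermediate result of a reference run, then re-run the computation under test and
+# align it against the recorded tensors step by step.
+#
+#     tracker = TensorTracker()
+#     for _ in range(num_steps):
+#         ref_optimizer.step()
+#         tracker.add(ref_param)
+#     for _ in range(num_steps):
+#         test_optimizer.step()
+#         tracker.pop_check_set(test_param, testcase)  # assert equality, then copy ref values in
+#     assert tracker.all_popped()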
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_pruning.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_pruning.py
new file mode 100644
index 0000000000000000000000000000000000000000..13cd86e05bd6f7b4e9515cf102cc1e6d3b49781d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_pruning.py
@@ -0,0 +1,385 @@
+# Owner(s): ["module: unknown"]
+
+from typing import Any
+from torch.ao.pruning import BaseSparsifier
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+class ImplementedSparsifier(BaseSparsifier):
+    def __init__(self, **kwargs: dict[str, Any]) -> None:
+        super().__init__(defaults=kwargs)
+
+    def update_mask(self, module: nn.Module, tensor_name: str, **kwargs: dict[str, Any]) -> None:
+        module.parametrizations.weight[0].mask[0] = 0  # type: ignore[index, union-attr]
+        linear_state = self.state['linear1.weight']
+        linear_state['step_count'] = linear_state.get('step_count', 0) + 1
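+# A hedged usage sketch for the sparsifier above (illustrative; the model and config values
+# are assumptions): a BaseSparsifier subclass is prepared on a model and then stepped, which
+# calls update_mask for each configured tensor, roughly:
+#
+#     model = SimpleLinear()
+#     sparsifier = ImplementedSparsifier(test=3)
+#     sparsifier.prepare(model, config=[{"tensor_fqn": "linear1.weight"}])
+#     sparsifier.step()  # zeroes the first row of the mask and bumps 'step_count'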
+
+
+class MockSparseLinear(nn.Linear):
+    """
+    This class is a MockSparseLinear class to check convert functionality.
+    It is the same as a normal Linear layer, except with a different type, as
+    well as an additional from_dense method.
+    """
+    @classmethod
+    def from_dense(cls, mod: nn.Linear) -> 'MockSparseLinear':
+        """
+        Construct a MockSparseLinear with the same in_features/out_features as the given
+        dense nn.Linear. Weights are not copied; only the layer shape is preserved.
+        """
+        linear = cls(mod.in_features,
+                     mod.out_features)
+        return linear
+
+
+def rows_are_subset(subset_tensor: torch.Tensor, superset_tensor: torch.Tensor) -> bool:
+    """
+    Checks whether all rows of subset_tensor appear in superset_tensor, in the same relative order
+    """
+    i = 0
+    for row in subset_tensor:
+        while i < len(superset_tensor):
+            if not torch.equal(row, superset_tensor[i]):
+                i += 1
+            else:
+                break
+        else:
+            return False
+    return True
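+# Illustrative example (values are made up): rows must appear in the same relative order,
+# because the scan index `i` only moves forward through the superset.
+#
+#     sup = torch.tensor([[1., 2.], [3., 4.], [5., 6.]])
+#     rows_are_subset(torch.tensor([[1., 2.], [5., 6.]]), sup)  # True
+#     rows_are_subset(torch.tensor([[5., 6.], [1., 2.]]), sup)  # False: [1., 2.] comes earlier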
+
+
+class SimpleLinear(nn.Module):
+    r"""Model with only Linear layers without biases, some wrapped in a Sequential,
+    some following the Sequential. Used to test basic pruned Linear-Linear fusion."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.seq = nn.Sequential(
+            nn.Linear(7, 5, bias=False),
+            nn.Linear(5, 6, bias=False),
+            nn.Linear(6, 4, bias=False),
+        )
+        self.linear1 = nn.Linear(4, 4, bias=False)
+        self.linear2 = nn.Linear(4, 10, bias=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.seq(x)
+        x = self.linear1(x)
+        x = self.linear2(x)
+        return x
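+# Shape sketch (illustrative input): the Sequential maps 7 -> 5 -> 6 -> 4 features,
+# then linear1 (4 -> 4) and linear2 (4 -> 10), e.g.
+#
+#     out = SimpleLinear()(torch.randn(2, 7))  # out.shape == (2, 10)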
+
+
+class LinearBias(nn.Module):
+    r"""Model with only Linear layers, alternating layers with biases,
+    wrapped in a Sequential. Used to test pruned Linear-Bias-Linear fusion."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.seq = nn.Sequential(
+            nn.Linear(7, 5, bias=True),
+            nn.Linear(5, 6, bias=False),
+            nn.Linear(6, 3, bias=True),
+            nn.Linear(3, 3, bias=True),
+            nn.Linear(3, 10, bias=False),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.seq(x)
+        return x
+
+
+class LinearActivation(nn.Module):
+    r"""Model with only Linear layers, some with bias, some in a Sequential and some following.
+    Activation function modules in between each Linear in the Sequential, and after each outside layer.
+    Used to test pruned Linear(Bias)-Activation-Linear fusion."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.seq = nn.Sequential(
+            nn.Linear(7, 5, bias=True),
+            nn.ReLU(),
+            nn.Linear(5, 6, bias=False),
+            nn.Tanh(),
+            nn.Linear(6, 4, bias=True),
+        )
+        self.linear1 = nn.Linear(4, 3, bias=True)
+        self.act1 = nn.ReLU()
+        self.linear2 = nn.Linear(3, 10, bias=False)
+        self.act2 = nn.Tanh()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.seq(x)
+        x = self.linear1(x)
+        x = self.act1(x)
+        x = self.linear2(x)
+        x = self.act2(x)
+        return x
+
+
+class LinearActivationFunctional(nn.Module):
+    r"""Model with only Linear layers, some with bias, some in a Sequential and some following.
+    Activation function modules in between each Linear in the Sequential, and functional
+    activations are called in between each outside layer.
+    Used to test pruned Linear(Bias)-Activation-Linear fusion."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.seq = nn.Sequential(
+            nn.Linear(7, 5, bias=True),
+            nn.ReLU(),
+            nn.Linear(5, 6, bias=False),
+            nn.ReLU(),
+            nn.Linear(6, 4, bias=True),
+        )
+        self.linear1 = nn.Linear(4, 3, bias=True)
+        self.linear2 = nn.Linear(3, 8, bias=False)
+        self.linear3 = nn.Linear(8, 10, bias=False)
+        self.act1 = nn.ReLU()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.seq(x)
+        x = self.linear1(x)
+        x = F.relu(x)
+        x = self.linear2(x)
+        x = F.relu(x)
+        x = self.linear3(x)
+        x = F.relu(x)
+        return x
+
+
+class SimpleConv2d(nn.Module):
+    r"""Model with only Conv2d layers, all without bias, some in a Sequential and some following.
+    Used to test pruned Conv2d-Conv2d fusion."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.seq = nn.Sequential(
+            nn.Conv2d(1, 32, 3, 1, bias=False),
+            nn.Conv2d(32, 64, 3, 1, bias=False),
+        )
+        self.conv2d1 = nn.Conv2d(64, 48, 3, 1, bias=False)
+        self.conv2d2 = nn.Conv2d(48, 52, 3, 1, bias=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.seq(x)
+        x = self.conv2d1(x)
+        x = self.conv2d2(x)
+        return x
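+# Shape sketch (illustrative input): four 3x3 convolutions with stride 1 and no padding
+# each shrink the spatial size by 2, so e.g.
+#
+#     out = SimpleConv2d()(torch.randn(1, 1, 28, 28))  # out.shape == (1, 52, 20, 20)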
+
+
+class Conv2dBias(nn.Module):
+    r"""Model with only Conv2d layers, some with bias, some in a Sequential and some outside.
+    Used to test pruned Conv2d-Bias-Conv2d fusion."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.seq = nn.Sequential(
+            nn.Conv2d(1, 32, 3, 1, bias=True),
+            nn.Conv2d(32, 32, 3, 1, bias=True),
+            nn.Conv2d(32, 64, 3, 1, bias=False),
+        )
+        self.conv2d1 = nn.Conv2d(64, 48, 3, 1, bias=True)
+        self.conv2d2 = nn.Conv2d(48, 52, 3, 1, bias=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.seq(x)
+        x = self.conv2d1(x)
+        x = self.conv2d2(x)
+        return x
+
+
+class Conv2dActivation(nn.Module):
+    r"""Model with only Conv2d layers, some with bias, some in a Sequential and some following.
+    Activation function modules in between each Sequential layer, functional activations called
+    in-between each outside layer.
+    Used to test pruned Conv2d-Bias-Activation-Conv2d fusion."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.seq = nn.Sequential(
+            nn.Conv2d(1, 32, 3, 1, bias=True),
+            nn.ReLU(),
+            nn.Conv2d(32, 64, 3, 1, bias=True),
+            nn.Tanh(),
+            nn.Conv2d(64, 64, 3, 1, bias=False),
+            nn.ReLU(),
+        )
+        self.conv2d1 = nn.Conv2d(64, 48, 3, 1, bias=False)
+        self.conv2d2 = nn.Conv2d(48, 52, 3, 1, bias=True)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.seq(x)
+        x = self.conv2d1(x)
+        x = F.relu(x)
+        x = self.conv2d2(x)
+        x = F.hardtanh(x)
+        return x
+
+
+class Conv2dPadBias(nn.Module):
+    r"""Model with only Conv2d layers, all with bias and some with padding > 0,
+    some in a Sequential and some following. Activation function modules in between each layer.
+    Used to test that bias is propagated correctly in the special case of
+    pruned Conv2d-Bias-(Activation)Conv2d fusion, when the second Conv2d layer has padding > 0."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.seq = nn.Sequential(
+            nn.Conv2d(1, 32, 3, 1, padding=1, bias=True),
+            nn.ReLU(),
+            nn.Conv2d(32, 32, 3, 1, bias=False),
+            nn.ReLU(),
+            nn.Conv2d(32, 32, 3, 1, padding=1, bias=True),
+            nn.ReLU(),
+            nn.Conv2d(32, 32, 3, 1, padding=1, bias=True),
+            nn.ReLU(),
+            nn.Conv2d(32, 64, 3, 1, bias=True),
+            nn.Tanh(),
+        )
+        self.conv2d1 = nn.Conv2d(64, 48, 3, 1, padding=1, bias=True)
+        self.act1 = nn.ReLU()
+        self.conv2d2 = nn.Conv2d(48, 52, 3, 1, padding=1, bias=True)
+        self.act2 = nn.Tanh()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.seq(x)
+        x = self.conv2d1(x)
+        x = self.act1(x)
+        x = self.conv2d2(x)
+        x = self.act2(x)
+        return x
+
+
+class Conv2dPool(nn.Module):
+    r"""Model with only Conv2d layers, all with bias, some in a Sequential and some following.
+    Activation function modules in between each layer, Pool2d modules in between each layer.
+    Used to test pruned Conv2d-Pool2d-Conv2d fusion."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.seq = nn.Sequential(
+            nn.Conv2d(1, 32, kernel_size=3, padding=1, bias=True),
+            nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
+            nn.ReLU(),
+            nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=True),
+            nn.Tanh(),
+            nn.AvgPool2d(kernel_size=2, stride=2, padding=1),
+        )
+        self.conv2d1 = nn.Conv2d(64, 48, kernel_size=3, padding=1, bias=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, padding=1)
+        self.af1 = nn.ReLU()
+        self.conv2d2 = nn.Conv2d(48, 52, kernel_size=3, padding=1, bias=True)
+        self.conv2d3 = nn.Conv2d(52, 52, kernel_size=3, padding=1, bias=True)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.seq(x)
+        x = self.conv2d1(x)
+        x = self.maxpool(x)
+        x = self.af1(x)
+        x = self.conv2d2(x)
+        x = F.avg_pool2d(x, kernel_size=2, stride=2, padding=1)
+        x = F.relu(x)
+        x = self.conv2d3(x)
+        return x
+
+
+class Conv2dPoolFlattenFunctional(nn.Module):
+    r"""Model with Conv2d layers, all with bias, some in a Sequential and some following, and then a Pool2d
+    and a functional Flatten followed by a Linear layer.
+    Activation functions and Pool2ds in between each layer also.
+    Used to test pruned Conv2d-Pool2d-Flatten-Linear fusion."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.seq = nn.Sequential(
+            nn.Conv2d(1, 3, kernel_size=3, padding=1, bias=True),
+            nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
+            nn.ReLU(),
+            nn.Conv2d(3, 5, kernel_size=3, padding=1, bias=True),
+            nn.Tanh(),
+            nn.AvgPool2d(kernel_size=2, stride=2, padding=1),
+        )
+        self.conv2d1 = nn.Conv2d(5, 7, kernel_size=3, padding=1, bias=True)
+        self.af1 = nn.ReLU()
+        self.conv2d2 = nn.Conv2d(7, 11, kernel_size=3, padding=1, bias=True)
+        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
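+        # AdaptiveAvgPool2d((1, 1)) leaves one value per channel, so fc sees 11 input features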
+        self.fc = nn.Linear(11, 13, bias=True)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.seq(x)
+        x = self.conv2d1(x)
+        x = F.max_pool2d(x, kernel_size=2, stride=2, padding=1)
+        x = self.af1(x)
+        x = self.conv2d2(x)
+        x = self.avg_pool(x)
+        x = torch.flatten(x, 1)  # test functional flatten
+        x = self.fc(x)
+        return x
+
+
+class Conv2dPoolFlatten(nn.Module):
+    r"""Model with Conv2d layers, all with bias, some in a Sequential and some following, and then a Pool2d
+    and a Flatten module followed by a Linear layer.
+    Activation functions and Pool2ds in between each layer also.
+    Used to test pruned Conv2d-Pool2d-Flatten-Linear fusion."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.seq = nn.Sequential(
+            nn.Conv2d(1, 3, kernel_size=3, padding=1, bias=True),
+            nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
+            nn.ReLU(),
+            nn.Conv2d(3, 5, kernel_size=3, padding=1, bias=True),
+            nn.Tanh(),
+            nn.AvgPool2d(kernel_size=2, stride=2, padding=1),
+        )
+        self.conv2d1 = nn.Conv2d(5, 7, kernel_size=3, padding=1, bias=True)
+        self.af1 = nn.ReLU()
+        self.conv2d2 = nn.Conv2d(7, 11, kernel_size=3, padding=1, bias=True)
+        self.avg_pool = nn.AdaptiveAvgPool2d((2, 2))
+        self.flatten = nn.Flatten()
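+        # 44 = 11 channels * the 2 * 2 spatial output of AdaptiveAvgPool2d((2, 2))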
+        self.fc = nn.Linear(44, 13, bias=True)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.seq(x)
+        x = self.conv2d1(x)
+        x = F.max_pool2d(x, kernel_size=2, stride=2, padding=1)
+        x = self.af1(x)
+        x = self.conv2d2(x)
+        x = self.avg_pool(x)
+        x = self.flatten(x)
+        x = self.fc(x)
+        return x
+
+
+class LSTMLinearModel(nn.Module):
+    """Container module with an encoder, a recurrent module, and a linear."""
+
+    def __init__(
+        self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int
+    ) -> None:
+        super().__init__()
+        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers)
+        self.linear = nn.Linear(hidden_dim, output_dim)
+
+    def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        output, _hidden = self.lstm(input)
+        decoded = self.linear(output)
+        return decoded, output
+
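+# Illustrative usage (shapes assumed for this sketch): nn.LSTM defaults to a
+# (seq_len, batch, input_dim) layout, so e.g.
+#     model = LSTMLinearModel(input_dim=8, hidden_dim=16, output_dim=10, num_layers=1)
+#     decoded, output = model(torch.randn(5, 3, 8))  # decoded: (5, 3, 10), output: (5, 3, 16)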
+
+class LSTMLayerNormLinearModel(nn.Module):
+    """Container module with an LSTM, a LayerNorm, and a linear."""
+
+    def __init__(
+        self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int
+    ) -> None:
+        super().__init__()
+        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers)
+        self.norm = nn.LayerNorm(hidden_dim)
+        self.linear = nn.Linear(hidden_dim, output_dim)
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        x, state = self.lstm(x)
+        x = self.norm(x)
+        x = self.linear(x)
+        return x, state
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_quantization.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_quantization.py
new file mode 100644
index 0000000000000000000000000000000000000000..211b282c4fc4a81214c894afd00cf050a92e4ebe
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_quantization.py
@@ -0,0 +1,3411 @@
+# mypy: ignore-errors
+
+r"""Importing this file includes common utility methods and base classes for
+checking quantization api and properties of resulting modules.
+"""
+
+import torch
+import torch.ao.nn.intrinsic.quantized.dynamic as nniqd
+import torch.ao.nn.quantized as nnq
+import torch.ao.nn.quantized.dynamic as nnqd
+import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
+from functorch.experimental import control_flow
+from torch.ao.nn.intrinsic import _FusedModule
+from torch.ao.quantization import (
+    convert,
+    default_dynamic_qat_qconfig,
+    default_dynamic_qconfig,
+    default_dynamic_quant_observer,
+    default_embedding_qat_qconfig,
+    default_observer,
+    default_per_channel_qconfig,
+    default_qconfig,
+    default_symmetric_qnnpack_qat_qconfig,
+    default_weight_observer,
+    DeQuantStub,
+    float_qparams_weight_only_qconfig,
+    get_default_qat_qconfig,
+    get_default_qat_qconfig_mapping,
+    get_default_qconfig,
+    get_default_qconfig_mapping,
+    PerChannelMinMaxObserver,
+    propagate_qconfig_,
+    QConfig,
+    QConfigMapping,
+    quantize,
+    quantize_dynamic_jit,
+    quantize_jit,
+    QuantStub,
+    QuantType,
+    QuantWrapper,
+)
+from torch.ao.quantization.backend_config import get_executorch_backend_config
+from torch.ao.quantization.quantization_mappings import (
+    get_default_dynamic_quant_module_mappings,
+    get_default_qat_module_mappings,
+    get_default_qconfig_propagation_list,
+)
+from torch.ao.quantization.quantize_pt2e import (
+    _convert_to_reference_decomposed_fx,
+    convert_pt2e,
+    prepare_pt2e,
+    prepare_qat_pt2e,
+)
+from torch.ao.quantization.quantizer.xnnpack_quantizer import (
+    get_symmetric_quantization_config,
+    XNNPACKQuantizer,
+)
+
+from torch.export import export_for_training
+from torch.jit.mobile import _load_for_lite_interpreter
+from torch.testing._internal.common_quantized import override_quantized_engine
+from torch.testing._internal.common_utils import TEST_WITH_ROCM, TestCase
+
+try:
+    from torch.ao.ns.fx.ns_types import NSSingleResultValuesType, NSSubgraph
+
+    # graph mode quantization based on fx
+    from torch.ao.quantization.quantize_fx import (
+        convert_fx,
+        convert_to_reference_fx,
+        prepare_fx,
+        prepare_qat_fx,
+    )
+    from torch.fx import GraphModule
+    from torch.fx.graph import Node
+
+    HAS_FX = True
+except ImportError:
+    HAS_FX = False
+
+import contextlib
+import copy
+import functools
+import io
+import os
+
+import unittest
+from typing import Any, Callable, Optional, Union
+
+import numpy as np
+import torch._dynamo as torchdynamo
+import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
+import torch.ao.quantization.quantizer.xpu_inductor_quantizer as xpuiq
+from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
+from torch.ao.quantization.quantizer.xpu_inductor_quantizer import XPUInductorQuantizer
+from torch.testing import FileCheck
+
+
+class NodeSpec:
+    """Used for checking GraphModule Node"""
+
+    def __init__(self, op, target):
+        """
+        op: call_function | call_method | call_module
+        target:
+          for call_function, target is the function being called
+          for call_method, target is the method name (a string)
+          for call_module, target is the type of the PyTorch module
+        """
+        self.op = op
+        self.target = target
+
+    @classmethod
+    def call_function(cls, target):
+        return NodeSpec("call_function", target)
+
+    @classmethod
+    def call_method(cls, target):
+        return NodeSpec("call_method", target)
+
+    @classmethod
+    def call_module(cls, target):
+        return NodeSpec("call_module", target)
+
+    def __hash__(self):
+        return hash((self.op, self.target))
+
+    def __eq__(self, other):
+        if not isinstance(other, NodeSpec):
+            return NotImplemented
+
+        return self.op == other.op and self.target == other.target
+
+    def __repr__(self):
+        return repr(self.op) + " " + repr(self.target)
+
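+# Illustrative NodeSpec construction (targets here are assumptions chosen only for the example):
+#     NodeSpec.call_function(torch.quantize_per_tensor)  # matches call_function nodes with that target
+#     NodeSpec.call_method("dequantize")                 # matches call_method nodes by method name
+#     NodeSpec.call_module(nnq.Conv2d)                   # matches call_module nodes whose module is of that type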
+
+def get_supported_device_types():
+    return (
+        ["cpu", "cuda"] if torch.cuda.is_available() and not TEST_WITH_ROCM else ["cpu"]
+    )
+
+
+def test_only_eval_fn(model, calib_data):
+    r"""
+    Default evaluation function takes a torch.utils.data.Dataset or a list of
+    input Tensors and runs the model on the dataset
+    """
+    for inp in calib_data:
+        model(*inp)
+
+
+_default_loss_fn = torch.nn.CrossEntropyLoss()
+
+
+def test_only_train_fn(model, train_data, loss_fn=_default_loss_fn):
+    r"""
+    Default train function takes a torch.utils.data.Dataset and trains the model
+    on the dataset
+    """
+    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+    train_loss, correct, total = 0, 0, 0
+    for _ in range(10):
+        model.train()
+
+        for data, target in train_data:
+            optimizer.zero_grad()
+            output = model(data)
+            loss = loss_fn(output, target)
+            loss.backward()
+            optimizer.step()
+            train_loss += loss.item()
+            _, predicted = torch.max(output, 1)
+            total += target.size(0)
+            correct += (predicted == target).sum().item()
+    return train_loss, correct, total
+
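+# Illustrative call (argument names assumed): train_data is an iterable of (data, target)
+# pairs, e.g. QuantizationTestCase.train_data below, and the function returns the accumulated
+# loss together with the running correct/total counts over its 10 passes through the data.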
+
+class AverageMeter:
+    """Computes and stores the average and current value"""
+
+    def __init__(self, name, fmt=":f"):
+        self.name = name
+        self.fmt = fmt
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+    def __str__(self):
+        fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
+        return fmtstr.format(**self.__dict__)
+
+
+def accuracy(output, target, topk=(1,)):
+    """Computes the accuracy over the k top predictions for the specified values of k"""
+    with torch.no_grad():
+        maxk = max(topk)
+        batch_size = target.size(0)
+
+        _, pred = output.topk(maxk, 1, True, True)
+        pred = pred.t()
+        correct = pred.eq(target.view(1, -1).expand_as(pred))
+
+        res = []
+        for k in topk:
+            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
+            res.append(correct_k.mul_(100.0 / batch_size))
+        return res
+
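+# Worked example (numbers assumed): with topk=(1, 5) and a batch of 8 samples in which 6
+# top-1 and 7 top-5 predictions are correct, accuracy() returns [tensor([75.]), tensor([87.5])],
+# i.e. percentages of the batch.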
+
+def train_one_epoch(model, criterion, optimizer, data_loader, device, ntrain_batches):
+    model.train()
+    for cnt, (image, target) in enumerate(data_loader, start=1):
+        print(".", end="")
+        image, target = image.to(device), target.to(device)
+        output = model(image)
+        loss = criterion(output, target)
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+        accuracy(output, target, topk=(1, 5))
+        if cnt >= ntrain_batches:
+            return
+    return
+
+
+def ddp_setup(rank, world_size):
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "12355"
+
+    # initialize the process group
+    dist.init_process_group("gloo", rank=rank, world_size=world_size)
+
+
+def ddp_cleanup():
+    dist.destroy_process_group()
+
+
+def run_ddp(rank, world_size, prepared):
+    ddp_setup(rank, world_size)
+    prepared.cuda()
+    prepared = torch.nn.parallel.DistributedDataParallel(prepared, device_ids=[rank])
+    prepared.to(rank)
+    model_with_ddp = prepared
+    optimizer = torch.optim.SGD(model_with_ddp.parameters(), lr=0.0001)
+    train_one_epoch(model_with_ddp, criterion, optimizer, dataset, rank, 1)  # noqa: F821
+    ddp_cleanup()
+
+
+def convert_dynamic(module):
+    convert(module, get_default_dynamic_quant_module_mappings(), inplace=True)
+
+
+def prepare_dynamic(model, qconfig_dict=None):
+    propagate_qconfig_(model, qconfig_dict)
+
+
+def _make_conv_test_input(
+    batch_size,
+    in_channels_per_group,
+    input_feature_map_size,
+    out_channels_per_group,
+    groups,
+    kernel_size,
+    X_scale,
+    X_zero_point,
+    W_scale,
+    W_zero_point,
+    use_bias,
+    use_channelwise,
+):
+    in_channels = in_channels_per_group * groups
+    out_channels = out_channels_per_group * groups
+
+    (X_value_min, X_value_max) = (0, 4)
+    X_init = torch.randint(
+        X_value_min,
+        X_value_max,
+        (
+            batch_size,
+            in_channels,
+        )
+        + input_feature_map_size,
+    )
+    X = X_scale * (X_init - X_zero_point).float()
+    X_q = torch.quantize_per_tensor(
+        X, scale=X_scale, zero_point=X_zero_point, dtype=torch.quint8
+    )
+
+    W_scale = W_scale * out_channels
+    W_zero_point = W_zero_point * out_channels
+    # Resize the W_scale and W_zero_point lists to have out_channels elements
+    W_scale = W_scale[:out_channels]
+    W_zero_point = W_zero_point[:out_channels]
+    # For testing, we use small values for weights and for activations so that
+    # no overflow occurs in the vpmaddubsw instruction. If the overflow occurs in
+    # the qconv implementation but not in the reference, we can't exactly match
+    # the results with the reference.
+    # Please see the comment in qconv implementation file
+    #   aten/src/ATen/native/quantized/cpu/qconv.cpp for more details.
+    (W_value_min, W_value_max) = (-5, 5)
+    # The operator expects them in the format
+    # (out_channels, in_channels/groups,) + kernel_size
+    W_init = torch.randint(
+        W_value_min,
+        W_value_max,
+        (
+            out_channels,
+            in_channels_per_group,
+        )
+        + kernel_size,
+    )
+    b_init = torch.randint(0, 10, (out_channels,))
+
+    if use_channelwise:
+        W_shape = (-1, 1) + (1,) * len(kernel_size)
+        W_scales_tensor = torch.tensor(W_scale, dtype=torch.float)
+        W_zero_points_tensor = torch.tensor(W_zero_point, dtype=torch.float)
+        W = (
+            W_scales_tensor.reshape(*W_shape)
+            * (W_init.float() - W_zero_points_tensor.reshape(*W_shape)).float()
+        )
+        b = X_scale * W_scales_tensor * b_init.float()
+        W_q = torch.quantize_per_channel(
+            W,
+            W_scales_tensor.double(),
+            W_zero_points_tensor.long(),
+            0,
+            dtype=torch.qint8,
+        )
+    else:
+        W = W_scale[0] * (W_init - W_zero_point[0]).float()
+        b = X_scale * W_scale[0] * b_init.float()
+        W_q = torch.quantize_per_tensor(
+            W, scale=W_scale[0], zero_point=W_zero_point[0], dtype=torch.qint8
+        )
+
+    return (X, X_q, W, W_q, b if use_bias else None)
+
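+# Illustrative note (shapes assumed): for a 2-D conv, kernel_size is a tuple such as (3, 3), and
+# W_scale / W_zero_point are lists that the helper tiles and truncates to out_channels entries;
+# the returned tuple is (float input, quantized input, float weight, quantized weight, bias or None).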
+
+def _make_conv_add_extra_input_tensor(scale, zero_point, sizes):
+    (X_value_min, X_value_max) = (0, 4)
+    X_init = torch.randint(
+        X_value_min,
+        X_value_max,
+        sizes,  # shape of the extra input tensor used for the add
+    )
+    X = scale * (X_init - zero_point).float()
+    X_q = torch.quantize_per_tensor(
+        X, scale=scale, zero_point=zero_point, dtype=torch.quint8
+    )
+    return X, X_q
+
+
+def skipIfNoFBGEMM(fn):
+    reason = "Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs with instruction set support AVX2 or newer."
+    if isinstance(fn, type):
+        if "fbgemm" not in torch.backends.quantized.supported_engines:
+            fn.__unittest_skip__ = True
+            fn.__unittest_skip_why__ = reason
+        return fn
+
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        if "fbgemm" not in torch.backends.quantized.supported_engines:
+            raise unittest.SkipTest(reason)
+        else:
+            fn(*args, **kwargs)
+
+    return wrapper
+
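+# Illustrative usage (test name assumed): the skip decorators in this file can wrap either a
+# TestCase subclass or an individual test method, e.g.
+#     @skipIfNoFBGEMM
+#     def test_quantized_linear(self):
+#         ...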
+
+def skipIfNoQNNPACK(fn):
+    reason = "Quantized operations require QNNPACK."
+    if isinstance(fn, type):
+        if "qnnpack" not in torch.backends.quantized.supported_engines:
+            fn.__unittest_skip__ = True
+            fn.__unittest_skip_why__ = reason
+        return fn
+
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        if "qnnpack" not in torch.backends.quantized.supported_engines:
+            raise unittest.SkipTest(reason)
+        else:
+            fn(*args, **kwargs)
+
+    return wrapper
+
+
+def withQNNPACKBackend(fn):
+    # TODO(future PR): consider combining with skipIfNoQNNPACK,
+    # will require testing of existing callsites
+    reason = "Quantized operations require QNNPACK."
+    if isinstance(fn, type):
+        if "qnnpack" not in torch.backends.quantized.supported_engines:
+            fn.__unittest_skip__ = True
+            fn.__unittest_skip_why__ = reason
+        return fn
+
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        if "qnnpack" not in torch.backends.quantized.supported_engines:
+            raise unittest.SkipTest(reason)
+        with override_quantized_engine("qnnpack"):
+            fn(*args, **kwargs)
+
+    return wrapper
+
+
+def skipIfNoONEDNN(fn):
+    reason = "Quantized operations require ONEDNN."
+    if isinstance(fn, type):
+        if "onednn" not in torch.backends.quantized.supported_engines:
+            fn.__unittest_skip__ = True
+            fn.__unittest_skip_why__ = reason
+        return fn
+
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        if "onednn" not in torch.backends.quantized.supported_engines:
+            raise unittest.SkipTest(reason)
+        else:
+            fn(*args, **kwargs)
+
+    return wrapper
+
+
+def skipIfNoONEDNNBF16(fn):
+    reason = "Quantized operations require BF16 support."
+    if isinstance(fn, type):
+        if not torch.ops.mkldnn._is_mkldnn_bf16_supported():
+            fn.__unittest_skip__ = True
+            fn.__unittest_skip_why__ = reason
+        return fn
+
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        if not torch.ops.mkldnn._is_mkldnn_bf16_supported():
+            raise unittest.SkipTest(reason)
+        else:
+            fn(*args, **kwargs)
+
+    return wrapper
+
+
+def skipIfNoX86(fn):
+    reason = "Quantized operations require X86."
+    if isinstance(fn, type):
+        if "x86" not in torch.backends.quantized.supported_engines:
+            fn.__unittest_skip__ = True
+            fn.__unittest_skip_why__ = reason
+        return fn
+
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        if "x86" not in torch.backends.quantized.supported_engines:
+            raise unittest.SkipTest(reason)
+        else:
+            fn(*args, **kwargs)
+
+    return wrapper
+
+
+def skipIfNoDynamoSupport(fn):
+    reason = "dynamo doesn't support."
+    if isinstance(fn, type):
+        if not torchdynamo.is_dynamo_supported():
+            fn.__unittest_skip__ = True
+            fn.__unittest_skip_why__ = reason
+        return fn
+
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        if not torchdynamo.is_dynamo_supported():
+            raise unittest.SkipTest(reason)
+        else:
+            fn(*args, **kwargs)
+
+    return wrapper
+
+
+def skipIfNoInductorSupport(fn):
+    reason = "inductor doesn't support."
+    if isinstance(fn, type):
+        if not torchdynamo.is_inductor_supported():
+            fn.__unittest_skip__ = True
+            fn.__unittest_skip_why__ = reason
+        return fn
+
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        if not torchdynamo.is_inductor_supported():
+            raise unittest.SkipTest(reason)
+        else:
+            fn(*args, **kwargs)
+
+    return wrapper
+
+
+try:
+    import torchvision  # noqa: F401
+
+    HAS_TORCHVISION = True
+except ImportError:
+    HAS_TORCHVISION = False
+skip_if_no_torchvision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
+
+
+def get_script_module(model, tracing, data):
+    return torch.jit.trace(model, data) if tracing else torch.jit.script(model)
+
+
+def lengths_to_offsets(t, offset_type=np.int64, use_begin_offset=True):
+    """
+    Convert lengths to offsets for embedding_bag
+    """
+    tt = np.zeros((t.shape[0] + 1,), dtype=offset_type)
+    tt[1:] = t
+    tt = torch.from_numpy(np.cumsum(tt, dtype=offset_type))
+    if use_begin_offset:
+        return tt[:-1]
+    return tt[1:]
+
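+# Worked example (values assumed): for lengths t = [2, 3, 1], lengths_to_offsets returns
+# offsets [0, 2, 5] with use_begin_offset=True and [2, 5, 6] otherwise.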
+
+def _group_quantize_tensor(w, n_bit=4, q_group_size=16):
+    assert w.dim() == 2
+    w = w.transpose(0, 1).contiguous()
+    assert q_group_size > 1
+    assert w.shape[-1] % q_group_size == 0
+
+    to_quant = w.reshape(-1, q_group_size)
+    assert torch.isnan(to_quant).sum() == 0
+
+    max_val = to_quant.amax(dim=1, keepdim=True)
+    min_val = to_quant.amin(dim=1, keepdim=True)
+    max_int = 2**n_bit - 1
+    min_int = 0
+    scales = (max_val - min_val).clamp(min=1e-6) / max_int
+    assert torch.isnan(scales).sum() == 0
+
+    zeros = min_val + scales * (2 ** (n_bit - 1))
+    assert torch.isnan(zeros).sum() == 0
+
+    out = to_quant.sub(min_val).div(scales).round().clamp_(min_int, max_int)
+    assert torch.isnan(out).sum() == 0
+
+    out = out.to(dtype=torch.int32).reshape(w.shape)
+    if out.device != torch.device("cpu"):
+        out = (out[::, ::2] << 4 | out[::, 1::2]).to(torch.uint8)
+
+    # Scales and zeros for the same q-group should be contiguous, so we can
+    # load as a 32-bit word
+    scales = scales.view(w.shape[0], -1)
+    zeros = zeros.view(w.shape[0], -1)
+    scales_and_zeros = (
+        torch.cat(
+            [
+                scales.reshape(scales.size(0), scales.size(1), 1),
+                zeros.reshape(zeros.size(0), zeros.size(1), 1),
+            ],
+            2,
+        )
+        .transpose(0, 1)
+        .contiguous()
+    )
+
+    return out, scales_and_zeros
+
+
+def _group_quantize_tensor_symmetric(w, n_bit=4, groupsize=32):
+    # W is of shape [K x N]
+    # We transpose W because quantization is applied on [N x K]
+    w = w.transpose(0, 1).contiguous()
+    assert w.dim() == 2
+    assert groupsize > 1
+    assert w.shape[-1] % groupsize == 0
+    # Calculate scale and zeros
+    to_quant = w.reshape(-1, groupsize)
+    max_val = to_quant.abs().amax(dim=1, keepdim=True)
+    eps = torch.finfo(max_val.dtype).eps
+    max_int = 2 ** (n_bit - 1) - 1  # For 4-bit, this is 7
+    scales = max_val.clamp(min=eps) / max_int
+    zeros = torch.zeros_like(scales)
+
+    # Quantize the weight
+    scales = scales.to(torch.float32).reshape(w.shape[0], -1)
+    zeros = zeros.to(torch.float32).reshape(w.shape[0], -1)
+    scales = scales.reshape(-1, 1)
+    zeros = zeros.reshape(-1, 1)
+    max_int = 2**n_bit - 1
+    w_int8 = to_quant.div(scales).add(8.5).to(torch.int8).clamp(max=max_int)
+    # We pack two signed int4 values into one unsigned uint8 container.
+    # This halves the weight size and improves load performance.
+    out_uint8 = (w_int8[::, 1::2] << 4 | w_int8[::, ::2]).to(torch.uint8)
+
+    scales_and_zeros = scales.squeeze().contiguous()
+
+    return out_uint8, scales_and_zeros
+
+
+def _dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
+    # source: https://github.com/pytorch-labs/gpt-fast/blob/main/quantize.py
+    # default setup for affine quantization of activations
+    x_dtype = x.dtype
+    x = x.float()
+    eps = torch.finfo(torch.float32).eps
+
+    # get min and max
+    min_val, max_val = torch.aminmax(x, dim=1)
+
+    # calculate scales and zero_points based on min and max
+    # reference: https://fburl.com/code/srbiybme
+    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
+    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
+    device = min_val_neg.device
+
+    # reference: https://fburl.com/code/4wll53rk
+    max_val_pos = torch.max(-min_val_neg, max_val_pos)
+    scales = max_val_pos / (float(quant_max - quant_min) / 2)
+    # ensure scales is the same dtype as the original tensor
+    scales = torch.clamp(scales, min=eps).to(x.dtype)
+    zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
+
+    # quantize based on qmin/qmax/scales/zp
+    x_div = x / scales.unsqueeze(-1)
+    x_round = torch.round(x_div)
+    x_zp = x_round + zero_points.unsqueeze(-1)
+    quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype)
+
+    return quant, scales.to(x_dtype), zero_points
+
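+# Illustrative call (arguments assumed): for a 2-D activation tensor x,
+#     quant, scales, zero_points = _dynamically_quantize_per_channel(x, -128, 127, torch.int8)
+# quantizes each row of x symmetrically to int8, with one scale per row and zero_points of 0.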
+
+# QuantizationTestCase used as a base class for testing quantization on modules
+class QuantizationTestCase(TestCase):
+    def setUp(self):
+        super().setUp()
+        self.calib_data = [[torch.rand(2, 5, dtype=torch.float)] for _ in range(2)]
+        self.train_data = [
+            [
+                torch.rand(2, 5, dtype=torch.float),
+                torch.randint(0, 1, (2,), dtype=torch.long),
+            ]
+            for _ in range(2)
+        ]
+        self.img_data_1d = [[torch.rand(2, 3, 10, dtype=torch.float)] for _ in range(2)]
+        self.img_data_2d = [
+            [torch.rand(1, 3, 10, 10, dtype=torch.float)] for _ in range(2)
+        ]
+        self.img_data_3d = [
+            [torch.rand(1, 3, 5, 5, 5, dtype=torch.float)] for _ in range(2)
+        ]
+        self.img_data_1d_train = [
+            [
+                torch.rand(2, 3, 10, dtype=torch.float),
+                torch.randint(0, 1, (1,), dtype=torch.long),
+            ]
+            for _ in range(2)
+        ]
+        self.img_data_2d_train = [
+            [
+                torch.rand(1, 3, 10, 10, dtype=torch.float),
+                torch.randint(0, 1, (1,), dtype=torch.long),
+            ]
+            for _ in range(2)
+        ]
+        self.img_data_3d_train = [
+            [
+                torch.rand(1, 3, 5, 5, 5, dtype=torch.float),
+                torch.randint(0, 1, (1,), dtype=torch.long),
+            ]
+            for _ in range(2)
+        ]
+
+        self.img_data_dict = {
+            1: self.img_data_1d,
+            2: self.img_data_2d,
+            3: self.img_data_3d,
+        }
+
+        # Quant types that produce statically quantized ops
+        self.static_quant_types = [QuantType.STATIC, QuantType.QAT]
+        # All quant types for (fx based) graph mode quantization
+        self.all_quant_types = [QuantType.DYNAMIC, QuantType.STATIC, QuantType.QAT]
+
+    def checkNoPrepModules(self, module):
+        r"""Checks the module does not contain child
+        modules for quantization preparation, e.g.
+        quant, dequant and observer
+        """
+        self.assertFalse(hasattr(module, "quant"))
+        self.assertFalse(hasattr(module, "dequant"))
+
+    def checkNoQconfig(self, module):
+        r"""Checks the module does not contain qconfig"""
+        self.assertFalse(hasattr(module, "qconfig"))
+
+        for child in module.children():
+            self.checkNoQconfig(child)
+
+    def checkHasPrepModules(self, module):
+        r"""Checks the module contains child
+        modules for quantization preparation, e.g.
+        quant, dequant and observer
+        """
+        self.assertTrue(hasattr(module, "module"))
+        self.assertTrue(hasattr(module, "quant"))
+        self.assertTrue(hasattr(module, "dequant"))
+
+    def checkObservers(
+        self, module, propagate_qconfig_list=None, prepare_custom_config_dict=None
+    ):
+        r"""Checks the module or module's leaf descendants
+        have observers in preparation for quantization
+        """
+        if propagate_qconfig_list is None:
+            propagate_qconfig_list = get_default_qconfig_propagation_list()
+        if prepare_custom_config_dict is None:
+            prepare_custom_config_dict = {}
+        float_to_observed_module_class_mapping = prepare_custom_config_dict.get(
+            "float_to_observed_custom_module_class", {}
+        )
+
+        # check if a module is a leaf module, ignoring activation_post_process attribute
+        def is_leaf_module(module):
+            submodule_name_count = 0
+            for name, _ in module.named_children():
+                if name != "activation_post_process":
+                    submodule_name_count += 1
+            return submodule_name_count == 0
+
+        if (
+            hasattr(module, "qconfig")
+            and module.qconfig is not None
+            and (
+                (
+                    is_leaf_module(module)
+                    and not isinstance(module, torch.nn.Sequential)
+                    and type(module) in propagate_qconfig_list
+                )
+                or type(module) in float_to_observed_module_class_mapping.keys()
+            )
+            and not isinstance(module, torch.ao.quantization.DeQuantStub)
+        ):
+            self.assertTrue(
+                hasattr(module, "activation_post_process"),
+                "module: " + str(type(module)) + " do not have observer",
+            )
+        # we don't need to check observers for child modules of the
+        # qat modules
+        if (
+            type(module) not in get_default_qat_module_mappings().values()
+            and type(module) not in float_to_observed_module_class_mapping.values()
+            and not isinstance(module, _FusedModule)
+        ):
+            for child in module.children():
+                if type(child) in [nn.Dropout]:
+                    continue
+                self.checkObservers(
+                    child, propagate_qconfig_list, prepare_custom_config_dict
+                )
+
+    def checkQuantDequant(self, mod):
+        r"""Checks that mod has nn.Quantize and
+        nn.DeQuantize submodules inserted
+        """
+        self.assertEqual(type(mod.quant), nnq.Quantize)
+        self.assertEqual(type(mod.dequant), nnq.DeQuantize)
+
+    def checkWrappedQuantizedLinear(self, mod):
+        r"""Checks that mod has been swapped for an nnq.Linear
+        module, that the bias is qint32, and that the module
+        has Quantize and DeQuantize submodules
+        """
+        self.assertEqual(type(mod.module), nnq.Linear)
+        self.checkQuantDequant(mod)
+
+    def checkQuantizedLinear(self, mod):
+        self.assertEqual(type(mod), nnq.Linear)
+
+    def checkDynamicQuantizedLinear(self, mod, dtype):
+        r"""Checks that mod has been swapped for an nnqd.Linear
+        module and that the bias is float.
+        """
+        self.assertEqual(type(mod), nnqd.Linear)
+        self.assertEqual(mod._packed_params.dtype, dtype)
+
+    def checkDynamicQuantizedLinearRelu(self, mod, dtype):
+        r"""Checks that mod has been swapped for an nnqd.Linear
+        module, the bias is float.
+        """
+        self.assertEqual(type(mod), nniqd.LinearReLU)
+        self.assertEqual(mod._packed_params.dtype, dtype)
+
+    def check_eager_serialization(self, ref_model, loaded_model, x):
+        # Check state dict serialization and torch.save APIs
+        model_dict = ref_model.state_dict()
+        b = io.BytesIO()
+        torch.save(model_dict, b)
+        b.seek(0)
+        # weights_only=False as we sometimes get a ScriptObject here (weird)
+        loaded_dict = torch.load(b, weights_only=False)
+        loaded_model.load_state_dict(loaded_dict)
+        ref_out = ref_model(*x)
+        load_out = loaded_model(*x)
+
+        def check_outputs(ref_out, load_out):
+            self.assertEqual(ref_out[0], load_out[0])
+            if isinstance(ref_out[1], tuple):
+                self.assertEqual(ref_out[1][0], load_out[1][0])
+                self.assertEqual(ref_out[1][1], load_out[1][1])
+            else:
+                self.assertEqual(ref_out[1], load_out[1])
+
+        check_outputs(ref_out, load_out)
+        b = io.BytesIO()
+        torch.save(ref_model, b)
+        b.seek(0)
+        # weights_only=False as this is legacy code that saves the model
+        loaded = torch.load(b, weights_only=False)
+        load_out = loaded(*x)
+        check_outputs(ref_out, load_out)
+
+    def check_weight_bias_api(self, ref_model, weight_keys, bias_keys):
+        weight = ref_model.get_weight()
+        bias = ref_model.get_bias()
+        self.assertEqual(weight_keys ^ weight.keys(), set())
+        self.assertEqual(bias_keys ^ bias.keys(), set())
+
+    def checkDynamicQuantizedLSTM(self, mod, reference_module_type, dtype):
+        r"""Checks that mod has been swapped for an nnqd.LSTM type
+        module, the bias is float.
+        """
+        wt_dtype_map = {
+            torch.qint8: "quantized_dynamic",
+            torch.float16: "quantized_fp16",
+        }
+        self.assertEqual(type(mod), reference_module_type)
+        for packed_params in mod._all_weight_values:
+            self.assertEqual(
+                packed_params.param.__getstate__()[0][0], wt_dtype_map[dtype]
+            )
+
+    def checkLinear(self, mod):
+        self.assertEqual(type(mod), torch.nn.Linear)
+
+    def checkDynamicQuantizedModule(self, mod, reference_module_type, dtype):
+        r"""Checks that mod has been swapped for an nnqd.Linear
+        module, the bias is float.
+        """
+        wt_dtype_map = {
+            torch.qint8: "quantized_dynamic",
+            torch.float16: "quantized_fp16",
+        }
+        self.assertEqual(type(mod), reference_module_type)
+        if hasattr(mod, "_all_weight_values"):
+            for packed_params in mod._all_weight_values:
+                self.assertEqual(
+                    packed_params.param.__getstate__()[0][0], wt_dtype_map[dtype]
+                )
+
+    def checkScriptable(self, orig_mod, calib_data, check_save_load=False):
+        scripted = torch.jit.script(orig_mod)
+        self._checkScriptable(orig_mod, scripted, calib_data, check_save_load)
+
+        # Use first calib_data entry as trace input
+        traced = torch.jit.trace(orig_mod, calib_data[0])
+        self._checkScriptable(orig_mod, traced, calib_data, check_save_load)
+
+    # Call this twice: once for a scripted module and once for a traced module
+    def _checkScriptable(self, orig_mod, script_mod, calib_data, check_save_load):
+        self._checkModuleCorrectnessAgainstOrig(orig_mod, script_mod, calib_data)
+
+        # Test save/load
+        buffer = io.BytesIO()
+        torch.jit.save(script_mod, buffer)
+
+        buffer.seek(0)
+        loaded_mod = torch.jit.load(buffer)
+        # Pending __getstate__ and __setstate__ support
+        # See tracking task https://github.com/pytorch/pytorch/issues/23984
+        if check_save_load:
+            self._checkModuleCorrectnessAgainstOrig(orig_mod, loaded_mod, calib_data)
+
+    def _checkModuleCorrectnessAgainstOrig(self, orig_mod, test_mod, calib_data):
+        for inp in calib_data:
+            ref_output = orig_mod(*inp)
+            scripted_output = test_mod(*inp)
+            self.assertEqual(scripted_output, ref_output)
+
+    def checkGraphModeOp(
+        self,
+        module,
+        inputs,
+        quantized_op,
+        tracing=False,
+        debug=False,
+        check=True,
+        eval_mode=True,
+        dynamic=False,
+        qconfig=None,
+    ):
+        if debug:
+            print("Testing:", str(module))
+        qconfig_dict = {"": get_default_qconfig(torch.backends.quantized.engine)}
+
+        if eval_mode:
+            module = module.eval()
+        if dynamic:
+            qconfig_dict = {"": default_dynamic_qconfig if qconfig is None else qconfig}
+        model = get_script_module(module, tracing, inputs[0]).eval()
+        if debug:
+            print("input graph:", model.graph)
+        models = {}
+        outputs = {}
+        for debug in [True, False]:
+            if dynamic:
+                models[debug] = quantize_dynamic_jit(model, qconfig_dict, debug=debug)
+                # make sure it runs
+                outputs[debug] = models[debug](inputs)
+            else:
+                # module under test can contain in-place ops, and we depend on
+                # input data staying constant for comparisons
+                inputs_copy = copy.deepcopy(inputs)
+                models[debug] = quantize_jit(
+                    model,
+                    qconfig_dict,
+                    test_only_eval_fn,
+                    [inputs_copy],
+                    inplace=False,
+                    debug=debug,
+                )
+                # make sure it runs
+                outputs[debug] = models[debug](*inputs[0])
+
+        if debug:
+            print("debug graph:", models[True].graph)
+            print("non debug graph:", models[False].graph)
+
+        if check:
+            # debug and non-debug option should have the same numerics
+            self.assertEqual(outputs[True], outputs[False])
+
+            # non debug graph should produce quantized op
+            FileCheck().check(quantized_op).run(models[False].graph)
+
+        return models[False]
+
+    def checkGraphModuleNodes(
+        self,
+        graph_module,
+        expected_node=None,
+        expected_node_occurrence=None,
+        expected_node_list=None,
+    ):
+        """Check if GraphModule contains the target node
+        Args:
+            graph_module: the GraphModule instance we want to check
+            expected_node, expected_node_occurrence, expected_node_list:
+               see docs for checkGraphModeFxOp
+        """
+        nodes_in_graph = {}
+        node_list = []
+        modules = dict(graph_module.named_modules(remove_duplicate=False))
+        for node in graph_module.graph.nodes:
+            n = None
+            if node.op == "call_function" or node.op == "call_method":
+                n = NodeSpec(node.op, node.target)
+            elif node.op == "call_module":
+                n = NodeSpec(node.op, type(modules[node.target]))
+
+            if n is not None:
+                node_list.append(n)
+                if n in nodes_in_graph:
+                    nodes_in_graph[n] += 1
+                else:
+                    nodes_in_graph[n] = 1
+
+        if expected_node is not None:
+            self.assertTrue(
+                expected_node in nodes_in_graph,
+                "node:" + str(expected_node) + " not found in the graph module",
+            )
+
+        if expected_node_occurrence is not None:
+            for expected_node, occurrence in expected_node_occurrence.items():
+                if occurrence != 0:
+                    self.assertTrue(
+                        expected_node in nodes_in_graph,
+                        "Check failed for node:" + str(expected_node) + " not found",
+                    )
+                    self.assertTrue(
+                        nodes_in_graph[expected_node] == occurrence,
+                        "Check failed for node:"
+                        + str(expected_node)
+                        + " Expected occurrence:"
+                        + str(occurrence)
+                        + " Found occurrence:"
+                        + str(nodes_in_graph[expected_node]),
+                    )
+                else:
+                    self.assertTrue(
+                        expected_node not in nodes_in_graph,
+                        "Check failed for node:"
+                        + str(expected_node)
+                        + " expected no occurrence but found",
+                    )
+
+        if expected_node_list is not None:
+            cur_index = 0
+            for n in node_list:
+                if cur_index == len(expected_node_list):
+                    return
+                if n == expected_node_list[cur_index]:
+                    cur_index += 1
+            self.assertTrue(
+                cur_index == len(expected_node_list),
+                "Check failed for graph:"
+                + self.printGraphModule(graph_module, print_str=False)
+                + "Expected ordered list:"
+                + str(expected_node_list),
+            )
+
+    def printGraphModule(self, graph_module, print_str=True):
+        modules = dict(graph_module.named_modules(remove_duplicate=False))
+        node_infos = []
+        for n in graph_module.graph.nodes:
+            node_info = " ".join(map(repr, [n.op, n.name, n.target, n.args, n.kwargs]))
+            if n.op == "call_module":
+                node_info += " module type: " + repr(type(modules[n.target]))
+            node_infos.append(node_info)
+        str_to_print = "\n".join(node_infos)
+        if print_str:
+            print(str_to_print)
+        return str_to_print
+
+    if HAS_FX:
+
+        def assert_types_for_matched_subgraph_pairs(
+            self,
+            matched_subgraph_pairs: dict[str, tuple[NSSubgraph, NSSubgraph]],
+            expected_types: dict[
+                str, tuple[tuple[Callable, Callable], tuple[Callable, Callable]]
+            ],
+            gm_a: GraphModule,
+            gm_b: GraphModule,
+        ) -> None:
+            """
+            Verifies that the types specified in expected_types match
+            the underlying objects pointed to by the nodes in matched_subgraph_pairs.
+
+            An example successful test case:
+
+              matched_subgraph_pairs = {'x0': (graph_a_conv_0_node, graph_b_conv_0_node)}
+              expected_types = {'x0': (nn.Conv2d, nnq.Conv2d)}
+
+            The function tests for key equivalence, and verifies types with
+            instance checks.
+            """
+
+            def _get_underlying_op_type(
+                node: Node, gm: GraphModule
+            ) -> Union[Callable, str]:
+                if node.op == "call_module":
+                    mod = getattr(gm, node.target)
+                    return type(mod)
+                else:
+                    assert node.op in ("call_function", "call_method")
+                    return node.target
+
+            self.assertTrue(
+                len(matched_subgraph_pairs) == len(expected_types),
+                f"Expected length of results to match, but got {len(matched_subgraph_pairs)} and {len(expected_types)}",
+            )
+            for k, v in expected_types.items():
+                expected_types_a, expected_types_b = v
+                exp_type_start_a, exp_type_end_a = expected_types_a
+                exp_type_start_b, exp_type_end_b = expected_types_b
+                subgraph_a, subgraph_b = matched_subgraph_pairs[k]
+
+                act_type_start_a = _get_underlying_op_type(subgraph_a.start_node, gm_a)
+                act_type_start_b = _get_underlying_op_type(subgraph_b.start_node, gm_b)
+                act_type_end_a = _get_underlying_op_type(subgraph_a.end_node, gm_a)
+                act_type_end_b = _get_underlying_op_type(subgraph_b.end_node, gm_b)
+                types_match = (
+                    (exp_type_start_a is act_type_start_a)
+                    and (exp_type_end_a is act_type_end_a)
+                    and (exp_type_start_b is act_type_start_b)
+                    and (exp_type_end_b is act_type_end_b)
+                )
+                self.assertTrue(
+                    types_match,
+                    f"Type mismatch at {k}: expected {(exp_type_start_a, exp_type_end_a, exp_type_start_b, exp_type_end_b)}, "
+                    f"got {(act_type_start_a, act_type_end_a, act_type_start_b, act_type_end_b)}",
+                )
+
+        def assert_ns_compare_dict_valid(
+            self,
+            act_compare_dict: dict[str, dict[str, dict[str, Any]]],
+        ) -> None:
+            """
+            Verifies that the act_compare_dict (output of Numeric Suite APIs) is valid:
+            1. for each layer, results are recorded for two models
+            2. number of seen tensors match
+            3. shapes of each pair of seen tensors match
+            """
+            for layer_name, result_type_to_data in act_compare_dict.items():
+                for result_type, layer_data in result_type_to_data.items():
+                    self.assertTrue(
+                        len(layer_data) == 2,
+                        f"Layer {layer_name} does not have exactly two model results.",
+                    )
+                    model_name_0, model_name_1 = layer_data.keys()
+                    for res_idx in range(len(layer_data[model_name_0])):
+                        layer_data_0 = layer_data[model_name_0][res_idx]
+                        layer_data_1 = layer_data[model_name_1][res_idx]
+                        self.assertTrue(
+                            layer_data_0["type"] == layer_data_0["type"],
+                            f"Layer {layer_name}, {model_name_0} and {model_name_1} do not have the same type.",
+                        )
+
+                        self.assertTrue(
+                            len(layer_data_0["values"]) == len(layer_data_1["values"]),
+                            f"Layer {layer_name}, {model_name_0} and {model_name_1} do not have the same number of seen Tensors.",
+                        )
+
+                        # F.conv1d weight has rank 3, and toq.conv1d unpacked weight
+                        # has rank 4. For now, skip the length check for conv1d only.
+                        is_weight_functional_conv1d = (
+                            result_type == NSSingleResultValuesType.WEIGHT.value
+                            and (
+                                "conv1d" in layer_data_0["prev_node_target_type"]
+                                or "conv1d" in layer_data_1["prev_node_target_type"]
+                            )
+                        )
+                        if not is_weight_functional_conv1d:
+                            for idx in range(len(layer_data_0["values"])):
+                                values_0 = layer_data_0["values"][idx]
+                                values_1 = layer_data_1["values"][idx]
+                                if isinstance(values_0, torch.Tensor):
+                                    self.assertTrue(
+                                        values_0.shape == values_1.shape,
+                                        f"Layer {layer_name}, {model_name_0} and {model_name_1} "
+                                        + f"have a shape mismatch at idx {idx}.",
+                                    )
+                                elif isinstance(values_0, list):
+                                    values_0 = values_0[0]
+                                    values_1 = values_1[0]
+                                    self.assertTrue(
+                                        values_0.shape == values_1.shape,
+                                        f"Layer {layer_name}, {model_name_0} and {model_name_1} "
+                                        + f"have a shape mismatch at idx {idx}.",
+                                    )
+                                else:
+                                    assert isinstance(
+                                        values_0, tuple
+                                    ), f"unhandled type {type(values_0)}"
+                                    assert len(values_0) == 2
+                                    assert len(values_0[1]) == 2
+                                    assert values_0[0].shape == values_1[0].shape
+                                    assert values_0[1][0].shape == values_1[1][0].shape
+                                    assert values_0[1][1].shape == values_1[1][1].shape
+
+                        # verify that ref_node_name is valid
+                        ref_node_name_0 = layer_data_0["ref_node_name"]
+                        ref_node_name_1 = layer_data_1["ref_node_name"]
+                        prev_node_name_0 = layer_data_0["prev_node_name"]
+                        prev_node_name_1 = layer_data_1["prev_node_name"]
+                        if (
+                            layer_data_0["type"]
+                            == NSSingleResultValuesType.NODE_OUTPUT.value
+                        ):
+                            self.assertTrue(ref_node_name_0 == prev_node_name_0)
+                            self.assertTrue(ref_node_name_1 == prev_node_name_1)
+                        elif (
+                            layer_data_0["type"]
+                            == NSSingleResultValuesType.NODE_INPUT.value
+                        ):
+                            self.assertTrue(ref_node_name_0 != prev_node_name_0)
+                            self.assertTrue(ref_node_name_1 != prev_node_name_1)
+
+        def checkGraphModeFxOp(
+            self,
+            model,
+            inputs,
+            quant_type,
+            expected_node=None,
+            expected_node_occurrence=None,
+            expected_node_list=None,
+            is_reference=False,
+            print_debug_info=False,
+            custom_qconfig_dict=None,
+            prepare_expected_node=None,
+            prepare_expected_node_occurrence=None,
+            prepare_expected_node_list=None,
+            prepare_custom_config=None,
+            backend_config=None,
+        ):
+            """Quantizes model with graph mode quantization on fx and check if the
+            quantized model contains the quantized_node
+
+            Args:
+                model: floating point torch.nn.Module
+                inputs: a tuple of positional sample input arguments for the model
+                expected_node: NodeSpec
+                    e.g. NodeSpec.call_function(torch.quantize_per_tensor)
+                expected_node_occurrence: a dict from NodeSpec to
+                    expected number of occurrences (int)
+                    e.g. {NodeSpec.call_function(torch.quantize_per_tensor) : 1,
+                            NodeSpec.call_method('dequantize'): 1}
+                expected_node_list: a list of NodeSpec, used to check the order
+                    of the occurrence of Node
+                    e.g. [NodeSpec.call_function(torch.quantize_per_tensor),
+                            NodeSpec.call_module(nnq.Conv2d),
+                            NodeSpec.call_function(F.hardtanh_),
+                            NodeSpec.call_method('dequantize')]
+                is_reference: if True, enables reference mode
+                print_debug_info: if True, prints debug info
+                custom_qconfig_dict: overrides default qconfig_dict
+                prepare_expected_node: same as expected_node, but for prepare
+                prepare_expected_node_occurrence: same as
+                    expected_node_occurrence, but for prepare
+                prepare_expected_node_list: same as expected_node_list, but
+                    for prepare
+
+            Returns:
+                A dictionary with the following structure:
+               {
+                   "prepared": ...,  # the prepared model
+                   "quantized": ...,  # the quantized non-reference model
+                   "quantized_reference": ...,  # the quantized reference model
+                   "result": ...,  # the result for either quantized or
+                                   # quantized_reference model depending on the
+                                   # is_reference argument
+               }
+            """
+            # TODO: make img_data a single example instead of a list
+            if type(inputs) == list:
+                inputs = inputs[0]
+
+            if quant_type == QuantType.QAT:
+                qconfig_mapping = get_default_qat_qconfig_mapping(
+                    torch.backends.quantized.engine
+                )
+                model.train()
+            elif quant_type == QuantType.STATIC:
+                qconfig_mapping = get_default_qconfig_mapping(
+                    torch.backends.quantized.engine
+                )
+                model.eval()
+            else:
+                qconfig = default_dynamic_qconfig
+                qconfig_mapping = QConfigMapping().set_global(qconfig)
+                model.eval()
+
+            if quant_type == QuantType.QAT:
+                prepare = prepare_qat_fx
+            else:
+                prepare = prepare_fx
+
+            # overwrite qconfig_dict with custom_qconfig_dict
+            if custom_qconfig_dict is not None:
+                assert type(custom_qconfig_dict) in (
+                    QConfigMapping,
+                    dict,
+                ), "custom_qconfig_dict should be a QConfigMapping or a dict"
+                if isinstance(custom_qconfig_dict, QConfigMapping):
+                    qconfig_mapping = custom_qconfig_dict
+                else:
+                    qconfig_mapping = QConfigMapping.from_dict(custom_qconfig_dict)
+            prepared = prepare(
+                model,
+                qconfig_mapping,
+                example_inputs=inputs,
+                prepare_custom_config=prepare_custom_config,
+                backend_config=backend_config,
+            )
+            if not quant_type == QuantType.DYNAMIC:
+                prepared(*inputs)
+
+            if print_debug_info:
+                print()
+                print("quant type:\n", quant_type)
+                print("original model:\n", model)
+                print()
+                print("prepared model:\n", prepared)
+
+            self.checkGraphModuleNodes(
+                prepared,
+                prepare_expected_node,
+                prepare_expected_node_occurrence,
+                prepare_expected_node_list,
+            )
+
+            prepared_copy = copy.deepcopy(prepared)
+            qgraph = convert_fx(copy.deepcopy(prepared))
+            qgraph_reference = convert_to_reference_fx(copy.deepcopy(prepared))
+            result = qgraph(*inputs)
+            result_reference = qgraph_reference(*inputs)
+            qgraph_copy = copy.deepcopy(qgraph)
+            qgraph_reference_copy = copy.deepcopy(qgraph_reference)
+
+            qgraph_to_check = qgraph_reference if is_reference else qgraph
+            if print_debug_info:
+                print()
+                print("quantized model:\n", qgraph_to_check)
+                self.printGraphModule(qgraph_to_check)
+                print()
+            self.checkGraphModuleNodes(
+                qgraph_to_check,
+                expected_node,
+                expected_node_occurrence,
+                expected_node_list,
+            )
+            return {
+                "prepared": prepared_copy,
+                "quantized": qgraph_copy,
+                "quantized_reference": qgraph_reference_copy,
+                "quantized_output": result,
+                "quantized_reference_output": result_reference,
+            }
+
+    def checkEmbeddingSerialization(
+        self,
+        qemb,
+        num_embeddings,
+        embedding_dim,
+        indices,
+        offsets,
+        set_qconfig,
+        is_emb_bag,
+        dtype=torch.quint8,
+    ):
+        # Test serialization of dynamic EmbeddingBag module using state_dict
+        if is_emb_bag:
+            inputs = [indices, offsets]
+        else:
+            inputs = [indices]
+        emb_dict = qemb.state_dict()
+        b = io.BytesIO()
+        torch.save(emb_dict, b)
+        b.seek(0)
+        loaded_dict = torch.load(b)
+        embedding_unpack = torch.ops.quantized.embedding_bag_unpack
+        # Check unpacked weight values explicitly
+        for key in emb_dict:
+            if isinstance(emb_dict[key], torch._C.ScriptObject):
+                assert isinstance(loaded_dict[key], torch._C.ScriptObject)
+                emb_weight = embedding_unpack(emb_dict[key])
+                loaded_weight = embedding_unpack(loaded_dict[key])
+                self.assertEqual(emb_weight, loaded_weight)
+
+        # Check state dict serialization and torch.save APIs
+        if is_emb_bag:
+            loaded_qemb = nnq.EmbeddingBag(
+                num_embeddings=num_embeddings,
+                embedding_dim=embedding_dim,
+                include_last_offset=True,
+                mode="sum",
+                dtype=dtype,
+            )
+        else:
+            loaded_qemb = nnq.Embedding(
+                num_embeddings=num_embeddings, embedding_dim=embedding_dim, dtype=dtype
+            )
+        self.check_eager_serialization(qemb, loaded_qemb, inputs)
+
+        loaded_qemb.load_state_dict(loaded_dict)
+        self.assertEqual(
+            embedding_unpack(qemb._packed_params._packed_weight),
+            embedding_unpack(loaded_qemb._packed_params._packed_weight),
+        )
+
+        # Test JIT serialization
+        self.checkScriptable(qemb, [inputs], check_save_load=True)
+
+        # Test from_float call
+        if is_emb_bag:
+            float_embedding = torch.nn.EmbeddingBag(
+                num_embeddings=num_embeddings,
+                embedding_dim=embedding_dim,
+                include_last_offset=True,
+                scale_grad_by_freq=False,
+                mode="sum",
+            )
+        else:
+            float_embedding = torch.nn.Embedding(
+                num_embeddings=num_embeddings, embedding_dim=embedding_dim
+            )
+
+        if set_qconfig:
+            float_qparams_observer = PerChannelMinMaxObserver.with_args(
+                dtype=dtype, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0
+            )
+            float_embedding.qconfig = QConfig(
+                activation=default_dynamic_quant_observer, weight=float_qparams_observer
+            )
+
+        prepare_dynamic(float_embedding)
+
+        float_embedding(*inputs)
+        if is_emb_bag:
+            q_embeddingbag = nnq.EmbeddingBag.from_float(float_embedding)
+            expected_name = "QuantizedEmbeddingBag"
+        else:
+            q_embeddingbag = nnq.Embedding.from_float(float_embedding)
+            expected_name = "QuantizedEmbedding"
+
+        q_embeddingbag(*inputs)
+
+        self.assertTrue(expected_name in str(q_embeddingbag))
+
+
+class QuantizationLiteTestCase(QuantizationTestCase):
+    def _create_quantized_model(self, model_class: type[torch.nn.Module], **kwargs):
+        # Creates quantized model for testing mobile script modules
+        qengine = "qnnpack"
+        with override_quantized_engine(qengine):
+            # FIXME(rec): shouldn't qconfig be passed to quantize?
+            qconfig = torch.ao.quantization.get_default_qconfig(qengine)  # noqa: F841
+            model = model_class(**kwargs)
+            model = quantize(model, test_only_eval_fn, [self.calib_data])
+
+        return model
+
+    def _compare_script_and_mobile(self, model: torch.nn.Module, input: torch.Tensor):
+        # Compares the numerical outputs for script and lite modules
+        qengine = "qnnpack"
+        with override_quantized_engine(qengine):
+            script_module = torch.jit.script(model)
+            script_module_result = script_module(input)
+
+            max_retry = 5
+            for retry in range(1, max_retry + 1):
+                # retry up to `max_retry` times; break on success, re-raise on the final failure
+                try:
+                    buffer = io.BytesIO(
+                        script_module._save_to_buffer_for_lite_interpreter()
+                    )
+                    buffer.seek(0)
+                    mobile_module = _load_for_lite_interpreter(buffer)
+
+                    mobile_module_result = mobile_module(input)
+
+                    torch.testing.assert_close(
+                        script_module_result, mobile_module_result
+                    )
+                    mobile_module_forward_result = mobile_module.forward(input)
+                    torch.testing.assert_close(
+                        script_module_result, mobile_module_forward_result
+                    )
+
+                    mobile_module_run_method_result = mobile_module.run_method(
+                        "forward", input
+                    )
+                    torch.testing.assert_close(
+                        script_module_result, mobile_module_run_method_result
+                    )
+                except AssertionError as e:
+                    if retry == max_retry:
+                        raise e
+                    else:
+                        continue
+                break
+
+
+class PT2EQuantizationTestCase(QuantizationTestCase):
+    """
+    Base QuantizationTestCase for PT2 with some helper methods.
+    """
+
+    _MAP_TO_FX_TRACED_OPS = {
+        torch.ops.quantized_decomposed.quantize_per_tensor: torch.ops.quantized_decomposed.quantize_per_tensor.default,
+        torch.ops.quantized_decomposed.dequantize_per_tensor: torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+        torch.ops.quantized_decomposed.quantize_per_channel: torch.ops.quantized_decomposed.quantize_per_channel.default,
+        torch.ops.quantized_decomposed.dequantize_per_channel: torch.ops.quantized_decomposed.dequantize_per_channel.default,
+        torch.ops.quantized_decomposed.quantize_per_tensor.tensor: torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
+        torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
+    }
+
+    def _test_quantizer(
+        self,
+        model,
+        example_inputs,
+        quantizer,
+        expected_node_occurrence,
+        expected_node_list=None,
+        check_against_fx_quant=False,
+        fx_qconfig_mapping=None,
+        export_with_dynamic_shape=False,
+        is_qat=False,
+        is_debug_mode=False,
+        training_ir_node_occurrence=None,
+    ):
+        # resetting dynamo cache
+        torch._dynamo.reset()
+        m_eager = model.eval()
+
+        # program capture
+        m = copy.deepcopy(m_eager)
+        dynamic_shapes = tuple(
+            {0: torch.export.Dim("dim")} if i == 0 else None
+            for i in range(len(example_inputs))
+        )
+        m = export_for_training(
+            m,
+            example_inputs,
+            dynamic_shapes=dynamic_shapes if export_with_dynamic_shape else None,
+            strict=True,
+        ).module()
+
+        if is_qat:
+            m = prepare_qat_pt2e(m, quantizer)
+        else:
+            m = prepare_pt2e(m, quantizer)
+        if is_debug_mode:
+            print("prepared model:", m)
+        # Calibrate
+        m(*example_inputs)
+        m = convert_pt2e(m)
+        if is_debug_mode:
+            print("quantized model", m)
+
+        pt2_quant_output = m(*example_inputs)
+        ns = NodeSpec
+        node_occurrence = {
+            ns.call_function(k): v for k, v in expected_node_occurrence.items()
+        }
+        if expected_node_list is None:
+            expected_node_list = []
+        node_list = [ns.call_function(n) for n in expected_node_list]
+        self.checkGraphModuleNodes(
+            m, expected_node_occurrence=node_occurrence, expected_node_list=node_list
+        )
+        if check_against_fx_quant:
+            qconfig_mapping = fx_qconfig_mapping
+            backend_config = get_executorch_backend_config()
+            m_copy = copy.deepcopy(m_eager)
+            m_fx = prepare_fx(
+                m_copy, qconfig_mapping, example_inputs, backend_config=backend_config
+            )
+            m_fx(*example_inputs)
+            m_fx = _convert_to_reference_decomposed_fx(
+                m_fx, backend_config=backend_config
+            )
+            m_fx = export_for_training(
+                m_fx,
+                example_inputs,
+                dynamic_shapes=dynamic_shapes if export_with_dynamic_shape else None,
+                strict=True,
+            ).module()
+            node_occurrence = {}
+            for k, v in PT2EQuantizationTestCase._MAP_TO_FX_TRACED_OPS.items():
+                if k in expected_node_occurrence:
+                    node_occurrence[ns.call_function(v)] = expected_node_occurrence[k]
+            if training_ir_node_occurrence is not None:
+                node_occurrence = {
+                    ns.call_function(k): v
+                    for k, v in training_ir_node_occurrence.items()
+                }
+            self.checkGraphModuleNodes(m_fx, expected_node_occurrence=node_occurrence)
+            fx_quant_output = m_fx(*example_inputs)
+            self.assertEqual(fx_quant_output, pt2_quant_output)
+        return m
+
+    def _quantize(self, m, quantizer, example_inputs, is_qat: bool = False):
+        # resetting dynamo cache
+        torch._dynamo.reset()
+
+        m = export_for_training(m, example_inputs, strict=True).module()
+        if is_qat:
+            m = prepare_qat_pt2e(m, quantizer)
+        else:
+            m = prepare_pt2e(m, quantizer)
+        m(*example_inputs)
+        m = convert_pt2e(m)
+        return m
+
+    def _get_pt2e_quantized_linear(self, is_per_channel=False) -> torch.fx.GraphModule:
+        class M(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.linear = torch.nn.Linear(2, 2)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        quantizer = XNNPACKQuantizer()
+        operator_config = get_symmetric_quantization_config(
+            is_per_channel=is_per_channel
+        )
+        quantizer.set_global(operator_config)
+        example_inputs = (torch.randn(2, 2),)
+        m = M().eval()
+        return self._quantize(m, quantizer, example_inputs)
+
+
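+# Illustrative sketch (an assumption, not part of the upstream helpers; the function
+# name is hypothetical): the PT2E flow that the methods above wrap -- capture with
+# export_for_training, annotate via prepare_pt2e, calibrate, then convert_pt2e.
+def _example_pt2e_quantize_flow():
+    model = torch.nn.Sequential(torch.nn.Linear(2, 2)).eval()
+    example_inputs = (torch.randn(2, 2),)
+    quantizer = XNNPACKQuantizer()
+    quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))
+    m = export_for_training(model, example_inputs, strict=True).module()
+    m = prepare_pt2e(m, quantizer)
+    m(*example_inputs)  # calibration pass
+    return convert_pt2e(m)
+
+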
+# Below are a series of toy models to use in testing quantization
+
+
+class SingleLayerLinearModel(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        return x
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        return (torch.rand(1, 5),)
+
+
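+# Illustrative sketch (an assumption, not part of the upstream file; the function name
+# is hypothetical): FX graph-mode post-training quantization of the toy model above,
+# mirroring how the test cases drive prepare_fx/convert_fx.
+def _example_fx_quantize_toy_model():
+    model = SingleLayerLinearModel().eval()
+    example_inputs = model.get_example_inputs()
+    qconfig_mapping = QConfigMapping().set_global(default_qconfig)
+    prepared = prepare_fx(model, qconfig_mapping, example_inputs)
+    prepared(*example_inputs)  # calibration with observers inserted
+    quantized = convert_fx(prepared)
+    return quantized(*example_inputs)
+
+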
+class AnnotatedSingleLayerLinearModel(torch.nn.Module):
+    def __init__(self, qengine="fbgemm"):
+        super().__init__()
+        self.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
+        self.fc1 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float))
+
+    def forward(self, x):
+        x = self.fc1(x)
+        return x
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        return (torch.rand(1, 5),)
+
+
+class SingleLayerLinearDynamicModel(torch.nn.Module):
+    def __init__(self, qengine="fbgemm"):
+        super().__init__()
+        self.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
+        self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        return x
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        return (torch.rand(1, 5),)
+
+
+class LinearAddModel(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
+        self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = torch.add(x, 5)
+        x = self.fc2(x)
+        return x
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        return (torch.rand(1, 5),)
+
+
+class RNNDynamicModel(torch.nn.Module):
+    def __init__(self, mod_type):
+        super().__init__()
+        self.qconfig = default_dynamic_qconfig
+        if mod_type == "GRU":
+            self.mod = torch.nn.GRU(2, 2).to(dtype=torch.float)
+        if mod_type == "LSTM":
+            self.mod = torch.nn.LSTM(2, 2).to(dtype=torch.float)
+
+    def forward(self, x):
+        x = self.mod(x)
+        return x
+
+
+class RNNCellDynamicModel(torch.nn.Module):
+    def __init__(self, mod_type):
+        super().__init__()
+        self.qconfig = default_dynamic_qconfig
+        if mod_type == "GRUCell":
+            self.mod = torch.nn.GRUCell(2, 2).to(dtype=torch.float)
+        if mod_type == "LSTMCell":
+            self.mod = torch.nn.LSTMCell(2, 2).to(dtype=torch.float)
+        if mod_type == "RNNReLU":
+            self.mod = torch.nn.RNNCell(2, 2, nonlinearity="relu").to(dtype=torch.float)
+        if mod_type == "RNNTanh":
+            self.mod = torch.nn.RNNCell(2, 2, nonlinearity="tanh").to(dtype=torch.float)
+
+    def forward(self, x):
+        x = self.mod(x)
+        return x
+
+
+class LSTMwithHiddenDynamicModel(torch.nn.Module):
+    def __init__(self, qengine="fbgemm"):
+        super().__init__()
+        self.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
+        self.lstm = torch.nn.LSTM(2, 2).to(dtype=torch.float)
+
+    def forward(self, x, hid):
+        x, hid = self.lstm(x, hid)
+        return x, hid
+
+
+class ConvModel(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
+
+    def forward(self, x):
+        x = self.conv(x)
+        return x
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        return (torch.rand(1, 3, 5, 5),)
+
+
+class ConvTransposeModel(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.conv = torch.nn.ConvTranspose2d(3, 5, 3, bias=False).to(dtype=torch.float)
+
+    def forward(self, x):
+        x = self.conv(x)
+        return x
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        return (torch.rand(1, 3, 5, 5),)
+
+
+class AnnotatedConvModel(torch.nn.Module):
+    def __init__(self, qengine):
+        super().__init__()
+        self.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
+        self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
+        self.quant = QuantStub()
+        self.dequant = DeQuantStub()
+
+    def forward(self, x):
+        x = self.quant(x)
+        x = self.conv(x)
+        x = self.dequant(x)
+        return x
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        return (torch.rand(1, 3, 5, 5),)
+
+
+class AnnotatedConvTransposeModel(torch.nn.Module):
+    def __init__(self, qengine):
+        super().__init__()
+        self.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
+        self.conv = torch.nn.ConvTranspose2d(3, 5, 3, bias=False).to(dtype=torch.float)
+        self.quant = QuantStub()
+        self.dequant = DeQuantStub()
+
+    def forward(self, x):
+        x = self.quant(x)
+        x = self.conv(x)
+        x = self.dequant(x)
+        return x
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        return (torch.rand(1, 3, 5, 5),)
+
+
+class ConvBnModel(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
+        self.bn = torch.nn.BatchNorm2d(5).to(dtype=torch.float)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.bn(x)
+        return x
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        return (torch.rand(1, 3, 5, 5),)
+
+
+class AnnotatedConvBnModel(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.qconfig = default_qconfig
+        self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
+        self.bn = torch.nn.BatchNorm2d(5).to(dtype=torch.float)
+        self.quant = QuantStub()
+        self.dequant = DeQuantStub()
+
+    def forward(self, x):
+        x = self.quant(x)
+        x = self.conv(x)
+        x = self.bn(x)
+        x = self.dequant(x)
+        return x
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        return (torch.rand(1, 3, 5, 5),)
+
+
+class ConvBnReLUModel(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
+        self.bn = torch.nn.BatchNorm2d(5).to(dtype=torch.float)
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.bn(x)
+        x = self.relu(x)
+        return x
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        return (torch.rand(1, 3, 5, 5),)
+
+
+class AnnotatedConvBnReLUModel(torch.nn.Module):
+    def __init__(self, qengine="fbgemm"):
+        super().__init__()
+        self.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
+        self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
+        self.bn = torch.nn.BatchNorm2d(5).to(dtype=torch.float)
+        self.relu = nn.ReLU(inplace=True)
+        self.quant = QuantStub()
+        self.dequant = DeQuantStub()
+
+    def forward(self, x):
+        x = self.quant(x)
+        x = self.conv(x)
+        x = self.bn(x)
+        x = self.relu(x)
+        x = self.dequant(x)
+        return x
+
+    def fuse_model(self):
+        # TODO: remove this check and define two fuse_modules functions on this module
+        if self.training:
+            torch.ao.quantization.fuse_modules_qat(
+                self, [["conv", "bn", "relu"]], inplace=True
+            )
+        else:
+            torch.ao.quantization.fuse_modules(
+                self, [["conv", "bn", "relu"]], inplace=True
+            )
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        return (torch.rand(1, 3, 5, 5),)
+
+
+class TwoLayerConvModel(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.conv1 = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
+        self.conv2 = torch.nn.Conv2d(5, 5, 1, bias=False).to(dtype=torch.float)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.conv2(x)
+        return x
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        return (torch.rand(1, 3, 5, 5),)
+
+
+class TwoLayerLinearModel(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
+        self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.fc2(x)
+        return x
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        return (torch.rand(1, 5),)
+
+
+class LinearModelWithSubmodule(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.subm = TwoLayerLinearModel()
+        self.fc = nn.Linear(5, 5)
+
+    def forward(self, x):
+        x = self.subm(x)
+        x = self.fc(x)
+        return x
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        return self.subm.get_example_inputs()
+
+
+class AnnotatedTwoLayerLinearModel(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
+        self.fc2 = QuantWrapper(torch.nn.Linear(8, 5).to(dtype=torch.float))
+        self.fc2.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.fc2(x)
+        return x
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        return (torch.rand(1, 5),)
+
+
+class ActivationsTestModel(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
+        self.quant = torch.ao.quantization.QuantStub()
+        self.hardswish = torch.nn.Hardswish().to(dtype=torch.float)
+        self.elu = torch.nn.ELU().to(dtype=torch.float)
+        self.dequant = torch.ao.quantization.DeQuantStub()
+
+    def forward(self, x):
+        x = self.quant(x)
+        x = self.hardswish(x)
+        x = self.elu(x)
+        x = self.dequant(x)
+        return x
+
+
+class LinearReluModel(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)
+        self.relu = torch.nn.ReLU()
+
+    def forward(self, x):
+        x = self.relu(self.fc(x))
+        return x
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        return (torch.rand(1, 5),)
+
+
+class LinearReluLinearModel(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
+        self.relu = torch.nn.ReLU()
+        self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.relu(x)
+        x = self.fc2(x)
+        return x
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        return (torch.rand(1, 5),)
+
+
+class LinearReluAddModel(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float)
+        self.relu = torch.nn.ReLU()
+        self.fc2 = torch.nn.Linear(5, 5).to(dtype=torch.float)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.relu(x)
+        x = torch.add(x, 5)
+        x = self.fc2(x)
+        self.relu = torch.nn.ReLU()
+        return x
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        return (torch.rand(1, 5),)
+
+
+class LinearBnLeakyReluModel(torch.nn.Module):
+    def __init__(self, with_bn=True):
+        super().__init__()
+        self.linear = nn.Linear(5, 5)
+        self.bn1d = nn.BatchNorm1d(5)
+        self.leaky_relu = nn.LeakyReLU(0.01)
+        self.with_bn = with_bn
+
+    def forward(self, x):
+        x = self.linear(x)
+        if self.with_bn:
+            x = self.bn1d(x)
+        x = self.leaky_relu(x)
+        return x
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        return (torch.rand(1, 5),)
+
+
+class LinearTanhModel(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.linear = nn.Linear(5, 5)
+        self.tanh = nn.Tanh()
+
+    def forward(self, x):
+        x = self.linear(x)
+        x = self.tanh(x)
+        return x
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        return (torch.rand(1, 5),)
+
+
+class ConvBnAddReluModel(torch.nn.Module):
+    def __init__(
+        self,
+        with_bn=True,
+        with_relu=True,
+        left_conv=True,
+        two_conv=True,
+        use_torch_add=True,
+    ):
+        super().__init__()
+        self.conv = nn.Conv2d(5, 5, (2, 2))
+        self.conv2 = nn.Conv2d(5, 5, (2, 2))
+        self.bn = nn.BatchNorm2d(5)
+        self.relu = nn.ReLU()
+        self.with_bn = with_bn
+        self.with_relu = with_relu
+        self.two_conv = two_conv
+        self.left_conv = left_conv
+        self.use_torch_add = use_torch_add
+
+    def forward(self, x1, x2):
+        if self.two_conv:
+            if self.use_torch_add:
+                if self.with_bn:
+                    x = torch.add(self.bn(self.conv(x1)), self.conv2(x1))
+                else:
+                    x = torch.add(self.conv(x1), self.conv2(x1))
+            else:
+                if self.with_bn:
+                    x = self.bn(self.conv(x1)) + self.conv2(x1)
+                else:
+                    x = self.conv(x1) + self.conv2(x1)
+        else:
+            if self.use_torch_add:
+                if self.left_conv:
+                    if self.with_bn:
+                        x = torch.add(self.bn(self.conv(x1)), x2)
+                    else:
+                        x = torch.add(self.conv(x1), x2)
+                else:
+                    if self.with_bn:
+                        x = torch.add(x2, self.bn(self.conv(x1)))
+                    else:
+                        x = torch.add(x2, self.conv(x1))
+            else:
+                if self.left_conv:
+                    if self.with_bn:
+                        x = self.bn(self.conv(x1)) + x2
+                    else:
+                        x = self.conv(x1) + x2
+                else:
+                    if self.with_bn:
+                        x = x2 + self.bn(self.conv(x1))
+                    else:
+                        x = x2 + self.conv(x1)
+        if self.with_relu:
+            x = self.relu(x)
+        return x
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        return (torch.rand(1, 5, 3, 3), torch.rand(1, 5, 2, 2))
+
+
+# TODO: self.fc should be self.conv
+class ConvReluModel(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.fc = torch.nn.Conv2d(3, 5, 3).to(dtype=torch.float)
+        self.relu = torch.nn.ReLU()
+
+    def forward(self, x):
+        x = self.relu(self.fc(x))
+        return x
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        return (torch.rand(1, 3, 5, 5),)
+
+
+# TODO: self.fc should be self.conv
+class ConvReluConvModel(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.fc1 = torch.nn.Conv2d(3, 5, 3).to(dtype=torch.float)
+        self.relu = torch.nn.ReLU()
+        self.fc2 = torch.nn.Conv2d(5, 5, 1).to(dtype=torch.float)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.relu(x)
+        x = self.fc2(x)
+        return x
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        return (torch.rand(1, 3, 5, 5),)
+
+
+# TODO: self.fc should be self.conv
+class ConvReluAddModel(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.fc1 = torch.nn.Conv2d(3, 5, 3).to(dtype=torch.float)
+        self.relu = torch.nn.ReLU()
+        self.fc2 = torch.nn.Conv2d(5, 5, 1).to(dtype=torch.float)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.relu(x)
+        x = torch.add(x, 5)
+        x = self.fc2(x)
+        self.relu = torch.nn.ReLU()
+        return x
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        return (torch.rand(1, 3, 5, 5),)
+
+
+class NormalizationTestModel(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.quant = torch.ao.quantization.QuantStub()
+        self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
+        self.layer_norm = torch.nn.LayerNorm(8)
+        self.group_norm = torch.nn.GroupNorm(2, 8)
+        self.instance_norm1d = torch.nn.InstanceNorm1d(8)
+        self.instance_norm2d = torch.nn.InstanceNorm2d(8)
+        self.instance_norm3d = torch.nn.InstanceNorm3d(8)
+
+    def forward(self, x):
+        x = self.quant(x)
+        x = self.fc1(x)
+        x = self.layer_norm(x)
+        x = self.group_norm(x.unsqueeze(-1).repeat(1, 1, 3))
+        x = self.instance_norm1d(x)
+        x = self.instance_norm2d(x.unsqueeze(-1))
+        x = self.instance_norm3d(x.unsqueeze(-1))
+        return x
+
+
+class NestedModel(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.sub1 = LinearReluModel()
+        self.sub2 = TwoLayerLinearModel()
+        self.fc3 = torch.nn.Linear(5, 5).to(dtype=torch.float)
+
+    def forward(self, x):
+        x = self.sub1(x)
+        x = self.sub2(x)
+        x = self.fc3(x)
+        return x
+
+
+class AnnotatedNestedModel(torch.nn.Module):
+    def __init__(self, qengine):
+        super().__init__()
+        self.sub1 = LinearReluModel()
+        self.sub2 = TwoLayerLinearModel()
+        self.fc3 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float))
+        self.fc3.qconfig = default_qconfig
+        self.sub2.fc1 = QuantWrapper(self.sub2.fc1)
+        if qengine == "fbgemm":
+            self.sub2.fc1.qconfig = default_per_channel_qconfig
+        else:
+            self.sub2.fc1.qconfig = default_qconfig
+
+    def forward(self, x):
+        x = self.sub1(x)
+        x = self.sub2(x)
+        x = self.fc3(x)
+        return x
+
+
+class AnnotatedSubNestedModel(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.sub1 = LinearReluModel()
+        self.sub2 = QuantWrapper(TwoLayerLinearModel())
+        self.fc3 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float))
+        self.fc3.qconfig = default_qconfig
+        self.sub2.qconfig = default_qconfig
+
+    def forward(self, x):
+        x = self.sub1(x)
+        x = self.sub2(x)
+        x = self.fc3(x)
+        return x
+
+
+class AnnotatedCustomConfigNestedModel(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.sub1 = LinearReluModel()
+        self.sub2 = TwoLayerLinearModel()
+        self.fc3 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float))
+        self.fc3.qconfig = default_qconfig
+        self.sub2.qconfig = default_qconfig
+
+        custom_options = {"dtype": torch.quint8, "qscheme": torch.per_tensor_affine}
+        custom_qconfig = QConfig(
+            activation=default_observer.with_args(**custom_options),
+            weight=default_weight_observer,
+        )
+        self.sub2.fc1.qconfig = custom_qconfig
+
+        self.sub2.fc1 = QuantWrapper(self.sub2.fc1)
+        self.sub2.fc2 = QuantWrapper(self.sub2.fc2)
+
+    def forward(self, x):
+        x = self.sub1(x)
+        x = self.sub2(x)
+        x = self.fc3(x)
+        return x
+
+
+class QuantSubModel(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.sub1 = LinearReluModel()
+        self.sub2 = QuantWrapper(TwoLayerLinearModel())
+        self.sub2.qconfig = default_qconfig
+        self.fc3 = torch.nn.Linear(5, 5).to(dtype=torch.float)
+        self.fc3.qconfig = default_qconfig
+
+    def forward(self, x):
+        x = self.sub1(x)
+        x = self.sub2(x)
+        x = self.fc3(x)
+        return x
+
+
+class InnerModule(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
+        self.relu1 = torch.nn.ReLU()
+        self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float)
+        self.relu2 = torch.nn.ReLU()
+
+    def forward(self, x):
+        return self.relu2(self.fc2(self.relu1(self.fc1(x))))
+
+    def fuse_modules(self):
+        fusable_layers = []
+        named_children = list(self.named_children())
+        for idx, (current_name, layer) in enumerate(named_children):
+            if isinstance(layer, torch.nn.Linear):
+                if idx >= len(named_children) - 1:
+                    break
+                if isinstance(named_children[idx + 1][1], torch.nn.ReLU):
+                    fusable_layers.append([current_name, named_children[idx + 1][0]])
+        # TODO: remove this check and define two fuse_modules functions on this module
+        if self.training:
+            torch.ao.quantization.fuse_modules_qat(self, fusable_layers, inplace=True)
+        else:
+            torch.ao.quantization.fuse_modules(self, fusable_layers, inplace=True)
+
+
+class FunctionalLinear(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.weight = torch.rand((5, 5))
+        self.bias = torch.zeros(5)
+
+    def forward(self, x):
+        return F.linear(x, self.weight, self.bias)
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        return (torch.rand(1, 5),)
+
+
+class SingleLayerFunctionalLinearModel(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.linear1 = FunctionalLinear()
+
+    def forward(self, x):
+        x = self.linear1(x)
+        return x
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        return self.linear1.get_example_inputs()
+
+
+class TwoLayerFunctionalLinearModel(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.linear1 = FunctionalLinear()
+        self.linear2 = FunctionalLinear()
+
+    def forward(self, x):
+        x = self.linear1(x)
+        x = self.linear2(x)
+        return x
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        return self.linear1.get_example_inputs()
+
+
+class FunctionalLinearAddModel(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.linear1 = FunctionalLinear()
+        self.linear2 = FunctionalLinear()
+
+    def forward(self, x):
+        x = self.linear1(x)
+        x = torch.add(x, 5)
+        x = self.linear2(x)
+        return x
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        return self.linear1.get_example_inputs()
+
+
+class FunctionalLinearReluModel(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.linear = FunctionalLinear()
+
+    def forward(self, x):
+        x = self.linear(x)
+        x = F.relu(x)
+        return x
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        return self.linear.get_example_inputs()
+
+
+class FunctionalLinearReluLinearModel(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.linear1 = FunctionalLinear()
+        self.relu = nn.ReLU()
+        self.linear2 = FunctionalLinear()
+
+    def forward(self, x):
+        x = self.linear1(x)
+        x = self.relu(x)
+        x = self.linear2(x)
+        return x
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        return self.linear1.get_example_inputs()
+
+
+class FunctionalConv2d(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.weight = torch.rand(3, 3, 3, 3)
+        self.bias = torch.rand(3)
+        self.stride = (1, 1)
+        self.padding = (0, 0)
+        self.dilation = (1, 1)
+        self.groups = 1
+
+    def forward(self, x):
+        return F.conv2d(
+            x,
+            self.weight,
+            self.bias,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.groups,
+        )
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        return (torch.rand(1, 3, 5, 5),)
+
+
+class SingleLayerFunctionalConvModel(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.conv1 = FunctionalConv2d()
+
+    def forward(self, x):
+        x = self.conv1(x)
+        return x
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        return self.conv1.get_example_inputs()
+
+
+class TwoLayerFunctionalConvModel(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.conv1 = FunctionalConv2d()
+        self.conv2 = FunctionalConv2d()
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.conv2(x)
+        return x
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        return self.conv1.get_example_inputs()
+
+
+class FunctionalConvReluModel(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.conv = FunctionalConv2d()
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = F.relu(x)
+        return x
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        return self.conv.get_example_inputs()
+
+
+class FunctionalConvReluConvModel(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.conv1 = FunctionalConv2d()
+        self.relu = nn.ReLU()
+        self.conv2 = FunctionalConv2d()
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.relu(x)
+        x = self.conv2(x)
+        return x
+
+    def get_example_inputs(self) -> tuple[Any, ...]:
+        return self.conv1.get_example_inputs()
+
+
+class SkipQuantModel(torch.nn.Module):
+    r"""We can skip quantization by explicitly
+    setting the qconfig of a submodule to None.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.sub = InnerModule()
+        self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)
+
+    def forward(self, x):
+        return self.fc(self.sub(x))
+
+    def fuse_modules(self):
+        self.sub.fuse_modules()
+
+
+class AnnotatedSkipQuantModel(torch.nn.Module):
+    r"""We can skip quantization by explicitly
+    setting the qconfig of a submodule to None.
+    """
+
+    def __init__(self, qengine):
+        super().__init__()
+        self.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
+        self.sub = QuantWrapper(InnerModule())
+        self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)
+        # don't quantize this fc
+        self.fc.qconfig = None
+
+    def forward(self, x):
+        return self.fc(self.sub(x))
+
+    def fuse_modules(self):
+        self.sub.module.fuse_modules()
+
+
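+# Illustrative sketch (an assumption, not part of the upstream file; the function name
+# is hypothetical): eager-mode flow exercising the skip behaviour described above --
+# the submodule whose qconfig is None (here self.fc) stays in float while the wrapped
+# submodule is quantized.
+def _example_skip_quant_flow():
+    model = AnnotatedSkipQuantModel("fbgemm").eval()
+    model.fuse_modules()
+    prepared = torch.ao.quantization.prepare(model)
+    prepared(torch.rand(2, 5))  # calibration; observers only where a qconfig is set
+    quantized = torch.ao.quantization.convert(prepared)
+    return quantized(torch.rand(2, 5))
+
+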
+class QuantStubModel(torch.nn.Module):
+    r"""A Module with manually inserted `QuantStub` and `DeQuantStub`"""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.qconfig = torch.ao.quantization.get_default_qconfig("qnnpack")
+        self.quant = QuantStub()
+        self.dequant = DeQuantStub()
+        self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)
+
+    def forward(self, x):
+        x = self.quant(x)
+        x = self.fc(x)
+        return self.dequant(x)
+
+
+class ManualLinearQATModel(torch.nn.Module):
+    r"""A Module with manually inserted `QuantStub` and `DeQuantStub`"""
+
+    def __init__(self, qengine):
+        super().__init__()
+        self.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
+        self.quant = QuantStub()
+        self.dequant = DeQuantStub()
+        self.fc1 = torch.nn.Linear(5, 1).to(dtype=torch.float)
+        self.fc2 = torch.nn.Linear(1, 10).to(dtype=torch.float)
+
+    def forward(self, x):
+        x = self.quant(x)
+        x = self.fc1(x)
+        x = self.fc2(x)
+        return self.dequant(x)
+
+
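+# Illustrative sketch (an assumption, not part of the upstream file; the function name
+# is hypothetical): eager-mode QAT flow for the manually annotated model above --
+# prepare_qat in train mode, run a few forward passes as a stand-in for training,
+# then convert in eval mode.
+def _example_eager_qat_flow():
+    model = ManualLinearQATModel("fbgemm").train()
+    prepared = torch.ao.quantization.prepare_qat(model)
+    for _ in range(3):
+        prepared(torch.rand(4, 5))  # stand-in for real training iterations
+    prepared.eval()
+    return torch.ao.quantization.convert(prepared)
+
+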
+class ManualDropoutQATModel(torch.nn.Module):
+    r"""A Module with manually inserted `QuantStub` and `DeQuantStub`"""
+
+    def __init__(self, qengine):
+        super().__init__()
+        self.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
+        self.quant = QuantStub()
+        self.dequant = DeQuantStub()
+        self.fc1 = torch.nn.Linear(5, 1).to(dtype=torch.float)
+        self.dropout = torch.nn.Dropout(0.5)
+
+    def forward(self, x):
+        x = self.quant(x)
+        x = self.fc1(x)
+        x = self.dropout(x)
+        return self.dequant(x)
+
+
+class ManualLinearDynamicQATModel(torch.nn.Module):
+    r"""A Module that uses dynamic QAT by default."""
+
+    def __init__(self, qconfig=None):
+        super().__init__()
+        self.qconfig = qconfig or default_dynamic_qat_qconfig
+        self.fc1 = torch.nn.Linear(5, 1).to(dtype=torch.float)
+        self.fc2 = torch.nn.Linear(1, 10).to(dtype=torch.float)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.fc2(x)
+        return x
+
+
+class ManualConvLinearQATModel(torch.nn.Module):
+    r"""A module with manually inserted `QuantStub` and `DeQuantStub`
+    that contains both linear and conv modules.
+    """
+
+    def __init__(self, qconfig=None):
+        super().__init__()
+        self.qconfig = (
+            qconfig
+            if qconfig
+            else torch.ao.quantization.get_default_qat_qconfig("qnnpack")
+        )
+        self.quant = QuantStub()
+        self.dequant = DeQuantStub()
+        self.conv = torch.nn.Conv2d(3, 1, kernel_size=3).to(dtype=torch.float)
+        self.fc1 = torch.nn.Linear(64, 10).to(dtype=torch.float)
+        self.fc2 = torch.nn.Linear(10, 10).to(dtype=torch.float)
+
+    def forward(self, x):
+        x = self.quant(x)
+        x = self.conv(x)
+        x = x.view(-1, 64).contiguous()
+        x = self.fc1(x)
+        x = self.fc2(x)
+        return self.dequant(x)
+
+
+class ManualConvLinearSymmQATModel(ManualConvLinearQATModel):
+    r"""Same as ManualConvLinearQATModel but with symmetric quantization.
+    Supported only with qnnpack.
+    """
+
+    def __init__(self) -> None:
+        super().__init__(default_symmetric_qnnpack_qat_qconfig)
+
+
+class ManualEmbeddingBagLinear(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.emb = nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, mode="sum")
+        self.emb.qconfig = default_embedding_qat_qconfig
+        self.quant = QuantStub()
+        self.dequant = DeQuantStub()
+        self.linear = nn.Linear(12, 1).to(dtype=torch.float)
+        self.qconfig = get_default_qat_qconfig("qnnpack")
+
+    def forward(
+        self,
+        input: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+        per_sample_weights: Optional[torch.Tensor] = None,
+    ):
+        x = self.emb(input, offsets, per_sample_weights)
+        x = self.quant(x)
+        x = self.linear(x)
+        return self.dequant(x)
+
+
+class DeFusedEmbeddingBagLinear(nn.Module):
+    r"""A module that simulates a QAT embedding bag followed by a linear layer.
+    It uses a separate embedding and a bagging op (torch.sum), similar to
+    what is described in the EmbeddingBag documentation.
+
+    https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.emb = nn.Embedding(num_embeddings=10, embedding_dim=12)
+        self.emb.qconfig = default_embedding_qat_qconfig
+        self.bagging_op = torch.sum
+        self.quant = QuantStub()
+        self.dequant = DeQuantStub()
+        self.linear = nn.Linear(12, 1).to(dtype=torch.float)
+        self.qconfig = get_default_qat_qconfig("qnnpack")
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        x = self.bagging_op(self.emb(input), dim=1)
+        x = self.quant(x)
+        x = self.linear(x)
+        return self.dequant(x)
+
+
+class SubModelForFusion(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.conv = nn.Conv2d(2, 2, 1, bias=None).to(dtype=torch.float)
+        self.bn = nn.BatchNorm2d(2).to(dtype=torch.float)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.bn(x)
+        return x
+
+
+class SubModelWithoutFusion(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.conv = nn.Conv2d(2, 2, 1, bias=None).to(dtype=torch.float)
+        self.relu = nn.ReLU(inplace=False).to(dtype=torch.float)
+
+    def forward(self, x):
+        return self.relu(self.conv(x))
+
+
+class ModelForFusion(nn.Module):
+    def __init__(self, qconfig):
+        super().__init__()
+        self.conv1 = nn.Conv2d(3, 2, 1, bias=None).to(dtype=torch.float)
+        self.bn1 = nn.BatchNorm2d(2).to(dtype=torch.float)
+        self.relu1 = nn.ReLU(inplace=True).to(dtype=torch.float)
+        self.sub1 = SubModelForFusion()
+        self.sub2 = SubModelWithoutFusion()
+        self.fc = nn.Linear(36, 10).to(dtype=torch.float)
+        self.quant = QuantStub()
+        self.dequant = DeQuantStub()
+        self.qconfig = qconfig
+        self.conv2 = nn.Conv3d(3, 2, (1, 1, 1), bias=None).to(dtype=torch.float)
+        self.relu2 = nn.ReLU(inplace=False).to(dtype=torch.float)
+        self.bn2 = nn.BatchNorm3d(2).to(dtype=torch.float)
+        self.relu3 = nn.ReLU(inplace=True).to(dtype=torch.float)
+        self.conv3 = nn.Conv1d(3, 3, 2).to(dtype=torch.float)
+        self.bn3 = nn.BatchNorm1d(3).to(dtype=torch.float)
+        self.relu4 = nn.ReLU(inplace=True).to(dtype=torch.float)
+        # don't quantize sub2
+        self.sub2.qconfig = None
+        self.fc.qconfig = None
+
+    def forward(self, x):
+        x = x.squeeze(2)
+        x = self.quant(x)
+        x = self.conv3(x)
+        x = self.bn3(x)
+        x = self.relu4(x)
+        x = x.unsqueeze(2)
+        y = x.unsqueeze(2)
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu1(x)
+        x = self.sub1(x)
+        x = self.dequant(x)
+        x = self.sub2(x)
+        x = x.reshape(-1, 36).contiguous()
+        x = self.fc(x)
+        y = self.conv2(y)
+        y = self.relu2(y)
+        y = self.bn2(y)
+        y = self.relu3(y)
+        y = self.dequant(y)
+        return x
+
+
+class ConvBNReLU(nn.Sequential):
+    def __init__(self) -> None:
+        super().__init__(
+            nn.Conv2d(3, 3, 1, 1, bias=False), nn.BatchNorm2d(3), nn.ReLU(inplace=False)
+        )
+
+
+class ModelWithSequentialFusion(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.conv1 = nn.Conv2d(3, 3, 1)
+        self.relu1 = nn.ReLU(inplace=False)
+        layers = [ConvBNReLU() for _ in range(3)]
+        self.features = nn.Sequential(*layers)
+        head = [nn.Linear(300, 10), nn.ReLU(inplace=False)]
+        self.classifier = nn.Sequential(*head)
+        self.seq = nn.Sequential()
+        self.quant = QuantStub()
+        self.dequant = DeQuantStub()
+
+    def forward(self, x):
+        x = self.quant(x)
+        x = self.conv1(x)
+        x = self.relu1(x)
+        x = self.features(x)
+        x = torch.reshape(x, (-1, 3 * 10 * 10))
+        x = self.classifier(x)
+        x = self.seq(x)
+        x = self.dequant(x)
+        return x
+
+
+class ModelForFusionWithBias(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.conv1 = nn.Conv2d(3, 2, 5, bias=True).to(dtype=torch.float)
+        self.bn1 = nn.BatchNorm2d(2).to(dtype=torch.float)
+        self.relu1 = nn.ReLU(inplace=True).to(dtype=torch.float)
+        self.conv2 = nn.Conv2d(2, 2, 1, bias=True).to(dtype=torch.float)
+        self.bn2 = nn.BatchNorm2d(2).to(dtype=torch.float)
+        self.quant = QuantStub()
+        self.dequant = DeQuantStub()
+
+    def forward(self, x):
+        x = self.quant(x)
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu1(x)
+        x = self.conv2(x)
+        x = self.bn2(x)
+        x = self.dequant(x)
+        return x
+
+
+class ModelForLinearBNFusion(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.fc = nn.Linear(20, 10)
+        self.bn = nn.BatchNorm1d(10)
+        nn.init.uniform_(self.bn.weight)
+        nn.init.uniform_(self.bn.bias)
+
+    def forward(self, x):
+        return self.bn(self.fc(x))
+
+
+class DummyObserver(torch.nn.Module):
+    def calculate_qparams(self):
+        return 1.0, 0
+
+    def forward(self, x):
+        return x
+
+
+class ModelForConvTransposeBNFusion(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.conv1 = nn.ConvTranspose1d(3, 3, 1)
+        self.bn1 = nn.BatchNorm1d(3)
+        self.conv2 = nn.ConvTranspose2d(3, 3, 1)
+        self.bn2 = nn.BatchNorm2d(3)
+        self.conv3 = nn.ConvTranspose3d(3, 3, 1)
+        self.bn3 = nn.BatchNorm3d(3)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = x.unsqueeze(2)
+        x = self.conv2(x)
+        x = self.bn2(x)
+        x = x.unsqueeze(2)
+        x = self.conv3(x)
+        x = self.bn3(x)
+        return x
+
+
+class ModelWithFunctionals(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.mycat = nnq.FloatFunctional()
+        self.myadd = nnq.FloatFunctional()
+        self.myadd_relu = nnq.FloatFunctional()
+        self.mymatmul = nnq.FloatFunctional()
+        # Tracing doesn't work yet for c10 ops with scalar inputs
+        # https://github.com/pytorch/pytorch/issues/27097
+        # self.my_scalar_add = nnq.FloatFunctional()
+        # self.my_scalar_mul = nnq.FloatFunctional()
+
+    def forward(self, x):
+        y = self.mycat.cat([x, x, x])
+        z = self.myadd.add(y, y)
+        w = self.myadd_relu.add_relu(z, z)
+        u = self.mymatmul.matmul(w, w.T)
+        # Tracing doesn't work yet for c10 ops with scalar inputs
+        # https://github.com/pytorch/pytorch/issues/27097
+        # w = self.my_scalar_add.add_scalar(w, -0.5)
+        # w = self.my_scalar_mul.mul_scalar(w, 0.5)
+        return u
+
+
+class ResNetBase(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        norm_layer = nn.BatchNorm2d
+        inplanes = 3
+        self.conv1 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False)
+        self.bn1 = norm_layer(inplanes)
+        self.relu1 = nn.ReLU()
+        self.relu2 = nn.ReLU()
+        self.downsample = torch.nn.Identity()
+        self.myop = nn.quantized.FloatFunctional()
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.fc = torch.nn.Linear(inplanes, 1)
+
+    def forward(self, x):
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu1(out)
+        identity = self.downsample(x)
+        out = self.myop.add(out, identity)
+        out = self.relu2(out)
+        out = self.avgpool(out)
+        out = torch.flatten(out, 1)
+        out = self.fc(out)
+        return out
+
+    def fuse_model(self):
+        # TODO: remove this check and define two fuse_model functions on this module
+        if self.training:
+            torch.ao.quantization.fuse_modules_qat(
+                self, [["conv1", "bn1", "relu1"]], inplace=True
+            )
+        else:
+            torch.ao.quantization.fuse_modules(
+                self, [["conv1", "bn1", "relu1"]], inplace=True
+            )
+
+
+class ModelMultipleOps(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        norm_layer = nn.BatchNorm2d
+        inplanes = 3
+        self.conv1 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False)
+        self.conv2 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False)
+        self.bn1 = norm_layer(inplanes)
+        self.relu1 = nn.ReLU()
+        self.relu2 = nn.ReLU()
+        self.downsample = torch.nn.Identity()
+        self.skip_add = nn.quantized.FloatFunctional()
+        self.cat = nn.quantized.FloatFunctional()
+        self.avgpool = nn.AdaptiveAvgPool2d((4, 4))
+        self.fc = nn.Linear(12, 6)
+
+    def forward(self, x):
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu1(out)
+        identity = self.downsample(x)
+        out = self.skip_add.add(out, identity)
+        out = self.relu2(out)
+        out = self.avgpool(out)
+        out = self.conv2(out)
+        out = torch.nn.functional.max_pool2d(out, 2, 2)
+        out = self.cat.cat([out, out])
+        out = out.reshape(-1, 3 * 2 * 2)
+        out = self.fc(out)
+        return out
+
+
+# Model to ensure consistency of fake quant with true quant.
+# Average pooling and mean operations are not modelled accurately
+# with fake-quant, so this model does not contain those operations.
+class ModelMultipleOpsNoAvgPool(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        norm_layer = nn.BatchNorm2d
+        inplanes = 3
+        self.conv1 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False)
+        self.conv2 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False)
+        self.bn1 = norm_layer(inplanes)
+        self.relu1 = nn.ReLU()
+        self.relu2 = nn.ReLU()
+        self.skip_add = nn.quantized.FloatFunctional()
+        self.cat = nn.quantized.FloatFunctional()
+        self.maxpool = nn.MaxPool2d((4, 4))
+        self.fc = nn.Linear(12, 6)
+
+    def forward(self, x):
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu1(out)
+        skip = self.conv2(x)
+        out = self.skip_add.add(out, skip)
+        out = self.relu2(out)
+        out = self.maxpool(out)
+        out = self.conv2(out)
+        out = torch.nn.functional.max_pool2d(out, 2, 2)
+        out = self.cat.cat([out, out])
+        out = out.reshape(-1, 3 * 2 * 2)
+        out = self.fc(out)
+        return out
+
+
+class EmbeddingBagModule(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.emb = torch.nn.EmbeddingBag(
+            num_embeddings=10,
+            embedding_dim=12,
+            include_last_offset=True,
+            scale_grad_by_freq=False,
+            mode="sum",
+        )
+
+    def forward(self, indices, offsets, per_sample_weights):
+        return self.emb(indices, offsets, per_sample_weights)
+
+
+class EmbeddingModule(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12)
+
+    def forward(self, indices):
+        return self.emb(indices)
+
+
+class EmbeddingWithStaticLinear(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12)
+        self.fc = torch.nn.Linear(4, 2)
+        self.emb.qconfig = float_qparams_weight_only_qconfig
+        self.qconfig = default_qconfig
+        self.quant = QuantStub()
+        self.dequant = DeQuantStub()
+
+    def forward(self, indices, offsets, linear_in):
+        emb = self.emb(indices, offsets)
+        q_x = self.quant(linear_in)
+        fc = self.fc(q_x)
+        fc = self.dequant(fc)
+        features = torch.cat([fc] + [emb], dim=1)
+        return features
+
+
+class DenseTopMLP(nn.Module):
+    def __init__(
+        self, dense_dim, dense_out, embedding_dim, top_out_in, top_out_out
+    ) -> None:
+        super().__init__()
+
+        self.dense_mlp = nn.Sequential(
+            nn.Linear(dense_dim, dense_out),
+        )
+        self.top_mlp = nn.Sequential(
+            nn.Linear(dense_out + embedding_dim, top_out_in),
+            nn.Linear(top_out_in, top_out_out),
+        )
+
+    def forward(
+        self,
+        sparse_feature: torch.Tensor,
+        dense: torch.Tensor,
+    ) -> torch.Tensor:
+        dense_feature = self.dense_mlp(dense)
+        features = torch.cat([dense_feature] + [sparse_feature], dim=1)
+
+        out = self.top_mlp(features)
+        return out
+
+
+# Thin wrapper around nn.EmbeddingBag, because tracing inside nn.EmbeddingBag is
+# not supported at the moment and this wrapper keeps it at the top level.
+class EmbBagWrapper(nn.Module):
+    def __init__(self, num_embeddings, embedding_dim):
+        super().__init__()
+        self.emb_bag = nn.EmbeddingBag(num_embeddings, embedding_dim, mode="sum")
+
+    def forward(self, indices, offsets):
+        return self.emb_bag(indices, offsets)
+
+
+class SparseNNModel(nn.Module):
+    _NUM_EMBEDDINGS = 10
+    _EMBEDDING_DIM = 5
+    _DENSE_DIM = 4
+    _DENSE_OUTPUT = 2
+    _TOP_OUT_IN = 2
+    _TOP_OUT_OUT = 2
+    _TOP_MLP_DIM = 1
+
+    def __init__(self) -> None:
+        super().__init__()
+
+        self.model_sparse = EmbBagWrapper(self._NUM_EMBEDDINGS, self._EMBEDDING_DIM)
+        self.dense_top = DenseTopMLP(
+            self._DENSE_DIM,
+            self._DENSE_OUTPUT,
+            self._EMBEDDING_DIM,
+            self._TOP_OUT_IN,
+            self._TOP_OUT_OUT,
+        )
+
+    def forward(
+        self,
+        sparse_indices: torch.Tensor,
+        sparse_offsets: torch.Tensor,
+        dense: torch.Tensor,
+    ) -> torch.Tensor:
+        sparse_feature = self.model_sparse(sparse_indices, sparse_offsets)
+        out = self.dense_top(sparse_feature, dense)
+
+        return out
+
+
+class TestHelperModules:
+    class ControlFlow(torch.nn.Module):
+        def forward(
+            self,
+            xs: torch.Tensor,
+            pred1: torch.Tensor,
+            pred2: torch.Tensor,
+            y: torch.Tensor,
+        ) -> torch.Tensor:
+            def true_nested(y: torch.Tensor) -> torch.Tensor:
+                y = y + y
+                y = torch.mm(y, y)
+                return y
+
+            def false_nested(y: torch.Tensor) -> torch.Tensor:
+                return torch.mm(y, y)
+
+            def true_fn(x: torch.Tensor, pred2: torch.Tensor) -> torch.Tensor:
+                z = control_flow.cond(pred2, true_nested, false_nested, [x])
+                return x + z
+
+            def false_fn(x: torch.Tensor, _) -> torch.Tensor:
+                return x.cos()
+
+            def map_fn(
+                x: torch.Tensor,
+                pred1: torch.Tensor,
+                pred2: torch.Tensor,
+                y: torch.Tensor,
+            ) -> torch.Tensor:
+                x = x.cos()
+                y = control_flow.cond(pred1, true_fn, false_fn, [y, pred2])
+                x = x + y
+                return x.sin()
+
+            y = torch.mm(y, y)
+            return control_flow.map(map_fn, xs, pred1, pred2, y)
+
+        def example_inputs(self):
+            return (
+                torch.ones(2, 2),
+                torch.tensor([False]),
+                torch.tensor([False]),
+                torch.ones(2, 2),
+            )
+
+    class Conv2dPropAnnotaton(torch.nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 3, 3)
+            self.linear = torch.nn.Linear(3, 3)
+
+        def forward(self, x):
+            x = self.conv(x)
+            x = x.view(-1, 3)
+            x = torch.nn.functional.hardtanh(x, -0.5, 0.5)
+            x = self.linear(x)
+            return x
+
+    class Conv2dWithObsSharingOps(torch.nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 3, 3)
+            self.hardtanh = torch.nn.Hardtanh()
+            self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
+
+        def forward(self, x):
+            x = self.conv(x)
+            x = self.adaptive_avg_pool2d(x)
+            x = self.hardtanh(x)
+            x = torch.mean(x)
+            return x
+
+    class Conv2dWithTwoLinearPermute(torch.nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 16, 3)
+            self.linear1 = torch.nn.Linear(16, 8, bias=False)
+            self.linear2 = torch.nn.Linear(8, 8)
+
+        def forward(self, x):
+            conv_out = self.conv(x)
+            permute_out = torch.permute(conv_out, (0, 2, 3, 1))
+            return self.linear2(self.linear1(permute_out))
+
+    class Conv2dWithTwoLinear(torch.nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 16, 3)
+            self.linear1 = torch.nn.Linear(64, 8, bias=False)
+            self.linear2 = torch.nn.Linear(8, 8)
+
+        def forward(self, x):
+            conv_out = self.conv(x)
+            reshape_out = torch.reshape(conv_out, (2, 64))
+            return self.linear2(self.linear1(reshape_out))
+
+    class ConvLinearWPermute(torch.nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 8, 3)
+            self.linear1 = torch.nn.Linear(8, 8)
+
+        def forward(self, x):
+            conv_out = self.conv(x)
+            permute_out = torch.permute(conv_out, (0, 2, 3, 1))
+            return self.linear1(permute_out)
+
+    class TwoLinearModule(torch.nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.linear1 = torch.nn.Linear(8, 16, bias=False)
+            self.linear2 = torch.nn.Linear(16, 8)
+
+        def forward(self, x):
+            return self.linear2(self.linear1(x))
+
+        def example_inputs(self):
+            return (torch.randn(2, 8),)
+
+    class ConvMaxPool2d(torch.nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.conv = torch.nn.Conv2d(2, 2, 1)
+            self.pool = torch.nn.MaxPool2d(1, 1)
+
+        def forward(self, x):
+            x = self.conv(x)
+            x = self.pool(x)
+            return x
+
+    class ConvWithAdaptiveAvgPool2d(torch.nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 3, 3)
+            self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
+
+        def forward(self, x):
+            x = self.conv(x)
+            x = self.adaptive_avg_pool2d(x)
+            return x
+
+    class ConvWithBNRelu(torch.nn.Module):
+        def __init__(self, relu, dim=2, bn=True, bias=True, padding=0):
+            super().__init__()
+            convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
+            bns = {1: torch.nn.BatchNorm1d, 2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d}
+            self.conv = convs[dim](3, 3, 3, bias=bias, padding=padding)
+
+            if bn:
+                self.bn = bns[dim](3)
+            else:
+                self.bn = torch.nn.Identity()
+            if relu:
+                self.relu = torch.nn.ReLU()
+            else:
+                self.relu = torch.nn.Identity()
+
+        def forward(self, x):
+            x = self.conv(x)
+            x = self.bn(x)
+            return self.relu(x)
+
+    class ConvTWithBNRelu(torch.nn.Module):
+        def __init__(self, relu, dim=2, bn=True, bias=True):
+            super().__init__()
+            convts = {1: torch.nn.ConvTranspose1d, 2: torch.nn.ConvTranspose2d}
+            bns = {1: torch.nn.BatchNorm1d, 2: torch.nn.BatchNorm2d}
+            self.convt = convts[dim](3, 3, 3, bias=bias)
+
+            if bn:
+                self.bn = bns[dim](3)
+            else:
+                self.bn = torch.nn.Identity()
+            if relu:
+                self.relu = torch.nn.ReLU()
+            else:
+                self.relu = torch.nn.Identity()
+
+        def forward(self, x):
+            x = self.convt(x)
+            x = self.bn(x)
+            return self.relu(x)
+
+    class Conv2dThenConv1d(torch.nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.conv1d = torch.nn.Conv1d(3, 3, 3)
+            self.conv2d = torch.nn.Conv2d(3, 3, 3)
+
+        def forward(self, x):
+            x = self.conv2d(x)
+            x = x.squeeze(0)
+            x = self.conv1d(x)
+            return x
+
+        def example_inputs(self):
+            return (torch.randn(1, 3, 5, 5),)
+
+    class Conv2dWithCat(torch.nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.conv1 = torch.nn.Conv2d(3, 3, 3)
+            self.conv2 = torch.nn.Conv2d(3, 3, 3)
+
+        def forward(self, x, y):
+            x = self.conv1(x)
+            y = self.conv2(y)
+            z = torch.cat([x, y], dim=1)
+            return z
+
+    class Conv2dWithTwoCat(torch.nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.conv1 = torch.nn.Conv2d(3, 3, 3)
+            self.conv2 = torch.nn.Conv2d(3, 3, 3)
+
+        def forward(self, x1, x2, x3, x4):
+            x1 = self.conv1(x1)
+            x2 = self.conv2(x2)
+            y = torch.cat([x1, x2], dim=1)
+            z = x3 + x4
+            w = torch.cat([z, y])
+            return w
+
+    class Conv2dWithSplit(torch.nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.conv1 = torch.nn.Conv2d(3, 3, 3)
+            self.conv2 = torch.nn.Conv2d(3, 3, 3)
+
+        def forward(self, x):
+            x = self.conv1(x)
+            # use split so we get a list of Tensors
+            x1, x2 = torch.split(x, 2, dim=1)
+            y = torch.cat([x1, x2], dim=1)
+            return y
+
+        def example_inputs(self):
+            return (torch.randn(1, 3, 16, 16),)
+
+    class ThreeAdd(torch.nn.Module):
+        def forward(self, x1, x2, x3, x4):
+            y = x1 + x2
+            z = x3 + x4
+            w = y + z
+            return w
+
+    class EmbeddingModule(torch.nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12)
+
+        def forward(self, indices):
+            return self.emb(indices)
+
+    class EmbeddingConvLinearModule(torch.nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=8)
+            self.conv = torch.nn.Conv2d(8, 16, (1, 3))
+            self.linear = torch.nn.Linear(16, 8)
+
+        def forward(self, indices):
+            embeddings = self.emb(indices)
+            embeddings = torch.unsqueeze(embeddings, dim=0)
+            embeddings = torch.permute(embeddings, (0, 3, 1, 2))
+            conv_out = self.conv(embeddings)
+            conv_out = torch.permute(conv_out, (0, 2, 3, 1))
+            conv_out = torch.squeeze(conv_out, dim=0)
+            return self.linear(conv_out)
+
+    class AddInplaceAdd(torch.nn.Module):
+        def forward(self, x, y):
+            x = x + y
+            x += y
+            return x
+
+    class MulInplaceMul(torch.nn.Module):
+        def forward(self, x, y):
+            x = x * y
+            x *= y
+            return x
+
+    class AddMulScalar(torch.nn.Module):
+        def forward(self, x):
+            x = x + 3
+            x = x * 3
+            x += 3
+            x *= 3
+            return x
+
+    class ConvBnReLU2dAndLinearReLU(torch.nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.conv_bn_relu = TestHelperModules.ConvWithBNRelu(relu=True)
+            self.linear = torch.nn.Linear(3, 8, bias=False)
+            self.relu = torch.nn.ReLU()
+
+        def forward(self, x):
+            x = self.conv_bn_relu(x)
+            permute_out = torch.permute(x, (0, 2, 3, 1))
+            linear_out = self.linear(permute_out)
+            return linear_out
+
+    class GroupwiseConv2d(torch.nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.conv = torch.nn.Conv2d(4, 4, 3, groups=2)
+
+        def forward(self, x):
+            return self.conv(x)
+
+        def example_inputs(self):
+            return (torch.randn(2, 4, 10, 10),)
+
+    class LinearReluModel(torch.nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)
+            self.relu = torch.nn.ReLU()
+
+        def forward(self, x):
+            x = self.relu(self.fc(x))
+            return x
+
+
+def _generate_qdq_quantized_model(
+    mod, inputs, is_qat=False, is_dynamic=False, quantizer=None
+):
+    def get_default_quantizer(is_qat, is_dynamic, inputs):
+        has_xpu = any(
+            isinstance(input, torch.Tensor) and input.device.type == "xpu"
+            for input in inputs
+        )
+        if has_xpu:
+            quantizer = XPUInductorQuantizer()
+            assert (not is_qat) and (
+                not is_dynamic
+            ), "QAT and dynamic quantization are not supported on the XPU backend currently"
+            quantizer.set_global(xpuiq.get_default_xpu_inductor_quantization_config())
+        else:
+            quantizer = X86InductorQuantizer()
+            quantizer.set_global(
+                xiq.get_default_x86_inductor_quantization_config(
+                    is_qat=is_qat, is_dynamic=is_dynamic
+                )
+            )
+        return quantizer
+
+    maybe_no_grad = contextlib.nullcontext() if is_qat else torch.no_grad()
+    with maybe_no_grad:
+        export_model = export_for_training(mod, inputs, strict=True).module()
+        quantizer = (
+            quantizer
+            if quantizer
+            else get_default_quantizer(is_qat, is_dynamic, inputs)
+        )
+        prepare_model = (
+            prepare_qat_pt2e(export_model, quantizer)
+            if is_qat
+            else prepare_pt2e(export_model, quantizer)
+        )
+        prepare_model(*inputs)
+        torch.ao.quantization.move_exported_model_to_eval(prepare_model)
+        convert_model = convert_pt2e(prepare_model)
+        return convert_model
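+
+# Hypothetical invocation (assumption, not part of the upstream helpers): quantize one
+# of the helper modules above with the default post-training (X86Inductor) flow:
+#
+#     m = TestHelperModules.ConvWithBNRelu(relu=True).eval()
+#     example_inputs = (torch.randn(1, 3, 8, 8),)
+#     converted = _generate_qdq_quantized_model(m, example_inputs)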
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_quantized.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_quantized.py
new file mode 100644
index 0000000000000000000000000000000000000000..e433d907c3b99d9ddf5f1e50d2200862e246c6bc
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_quantized.py
@@ -0,0 +1,482 @@
+# mypy: ignore-errors
+
+r"""Importing this file includes common utility methods for checking quantized
+tensors and modules.
+"""
+import numpy as np
+import torch
+from torch import Tensor
+from contextlib import contextmanager
+from torch.testing._internal.common_utils import TEST_WITH_TSAN, IS_PPC, IS_MACOS, IS_WINDOWS
+
+supported_qengines = torch.backends.quantized.supported_engines
+supported_qengines.remove('none')
+# Note: We currently do not run QNNPACK tests on WINDOWS and MACOS as it is flaky. Issue #29326
+# QNNPACK is not supported on PPC
+if 'qnnpack' in supported_qengines and any([IS_PPC, TEST_WITH_TSAN, IS_MACOS, IS_WINDOWS]):
+    supported_qengines.remove('qnnpack')
+
+def _conv_output_shape(input_size, kernel_size, padding, stride, dilation,
+                       output_padding=0):
+    """Computes the output shape given convolution parameters."""
+    return np.floor((input_size + 2 * padding - kernel_size - (kernel_size - 1)
+                     * (dilation - 1)) / stride) + 2 * output_padding + 1
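+
+# Minimal worked example (assumption, illustrative only): a length-5 input with a
+# size-3 kernel, no padding, stride 1 and dilation 1 gives
+# _conv_output_shape(5, 3, 0, 1, 1) == 3.0.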
+
+# Quantization references
+def _quantize(x, scale, zero_point, qmin=None, qmax=None, dtype=np.uint8):
+    """Quantizes a numpy array."""
+    if qmin is None:
+        qmin = np.iinfo(dtype).min
+    if qmax is None:
+        qmax = np.iinfo(dtype).max
+    qx = np.round(x / scale + zero_point).astype(np.int64)
+    qx = np.clip(qx, qmin, qmax)
+    qx = qx.astype(dtype)
+    return qx
+
+
+def _dequantize(qx, scale, zero_point):
+    """Dequantizes a numpy array."""
+    x = (qx.astype(float) - zero_point) * scale
+    return x
+
+
+def _requantize(x, multiplier, zero_point, qmin=0, qmax=255, qtype=np.uint8):
+    """Requantizes a numpy array, i.e., intermediate int32 or int16 values are
+    converted back to the given type."""
+    qx = (x * multiplier).round() + zero_point
+    qx = np.clip(qx, qmin, qmax).astype(qtype)
+    return qx
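+
+# Minimal round-trip sketch (assumption, not part of the upstream helpers) for the
+# affine reference above, using uint8 with scale=0.01 and zero_point=10:
+#
+#     x = np.array([0.0, 0.5, 1.0], dtype=np.float32)
+#     qx = _quantize(x, scale=0.01, zero_point=10)        # -> array([ 10,  60, 110], dtype=uint8)
+#     x_hat = _dequantize(qx, scale=0.01, zero_point=10)  # approximately recovers [0.0, 0.5, 1.0]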
+
+def _calculate_dynamic_qparams(X, dtype, reduce_range=False, qscheme=torch.per_tensor_affine):
+    """Calculate the dynamic quantization parameters (scale, zero_point)
+    according to the min and max element of the tensor"""
+    assert qscheme in (torch.per_tensor_affine, torch.per_tensor_symmetric)
+    if qscheme == torch.per_tensor_symmetric:
+        assert dtype == torch.qint8
+    if isinstance(X, torch.Tensor):
+        X = X.numpy()
+    if dtype == torch.qint8:
+        if reduce_range:
+            qmin, qmax = -64, 63
+        else:
+            qmin, qmax = -128, 127
+    else:  # dtype == torch.quint8
+        if reduce_range:
+            qmin, qmax = 0, 127
+        else:
+            qmin, qmax = 0, 255
+    min_val = X.min()
+    max_val = X.max()
+    is_symmetric = (qscheme == torch.per_tensor_symmetric)
+    if min_val == max_val:
+        scale = 1.0
+        zero_point = 0
+    else:
+        if is_symmetric:
+            max_val = max(max_val, -min_val)
+            min_val = -max_val
+            scale = (max_val - min_val) / (qmax - qmin)
+            scale = max(scale, np.finfo(np.float32).eps)
+            zero_point = 0
+        else:
+            max_val = max(max_val, 0.0)
+            min_val = min(min_val, 0.0)
+            scale = (max_val - min_val) / (qmax - qmin)
+            scale = max(scale, np.finfo(np.float32).eps)
+            zero_point = qmin - round(min_val / scale)
+            zero_point = max(qmin, zero_point)
+            zero_point = min(qmax, zero_point)
+    return [float(scale), int(zero_point)]
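+
+# Minimal worked example (assumption, not part of the upstream helpers): for a
+# quint8 tensor spanning [0.0, 1.0] without reduce_range, qmin/qmax are 0/255, so
+#
+#     _calculate_dynamic_qparams(torch.tensor([0.0, 1.0]), torch.quint8)
+#
+# returns a scale of 1/255 (~0.00392) and a zero_point of 0.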
+
+def _calculate_dynamic_per_channel_qparams(X, dtype):
+    """Calculate the dynamic quantization parameters (scale, zero_point)
+    according to the min and max element of the tensor"""
+    if isinstance(X, torch.Tensor):
+        X = X.numpy()
+    qmin, qmax = torch.iinfo(dtype).min, torch.iinfo(dtype).max
+    n_levels = qmax - qmin
+    scale = np.zeros(X.shape[0], dtype=np.float64)
+    zero_point = np.zeros(X.shape[0], dtype=np.int64)
+    for i in range(zero_point.shape[0]):
+        min_val = X.min()
+        max_val = X.max()
+        if min_val == max_val:
+            scale[i] = 1.0
+            zero_point[i] = 0
+        else:
+            max_val = max(max_val, 0.0)
+            min_val = min(min_val, 0.0)
+            scale[i] = (max_val - min_val) / n_levels
+            scale[i] = max(scale[i], np.finfo(np.float32).eps)
+            zero_point[i] = qmin - round(min_val / scale[i])
+            zero_point[i] = max(qmin, zero_point[i])
+            zero_point[i] = min(qmax, zero_point[i])
+
+    return scale, zero_point
+
+def _snr(x, x_hat):
+    """Calculates the signal to noise ratio and returns the signal and noise
+    power, as well as the SNR in dB.
+    If the input is a list/tuple this function is called recursively on each
+    element. The result will have the same nested structure as the inputs.
+
+    Args:
+        x, x_hat: Either a tensor or a nested list/tuple of tensors.
+    Returns:
+        signal, noise, SNR(in dB): Either floats or a nested list of floats
+    """
+    if isinstance(x, (list, tuple)):
+        assert len(x) == len(x_hat)
+        res = [_snr(x[idx], x_hat[idx]) for idx in range(len(x))]
+        return res
+    if x_hat.is_quantized:
+        x_hat = x_hat.dequantize()
+    if x.is_quantized:
+        x = x.dequantize()
+    noise = (x - x_hat).norm()
+    if noise == 0:
+        return 0.0, float('inf'), float('inf')
+    signal = x.norm()
+    snr = signal / noise
+    snr_db = 20 * snr.log10()
+    return signal, noise, snr_db
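+
+# Illustrative use (assumption, not part of the upstream helpers): comparing a tensor
+# with a lossy reconstruction of it; a larger snr_db indicates a closer reconstruction:
+#
+#     x = torch.randn(100)
+#     x_hat = x + 0.01 * torch.randn(100)
+#     signal, noise, snr_db = _snr(x, x_hat)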
+
+@contextmanager
+def override_quantized_engine(qengine):
+    previous = torch.backends.quantized.engine
+    torch.backends.quantized.engine = qengine
+    try:
+        yield
+    finally:
+        torch.backends.quantized.engine = previous
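+
+# Illustrative use (assumption; engine availability depends on the platform):
+#
+#     with override_quantized_engine("fbgemm"):
+#         ...  # runs with the fbgemm backend; the previous engine is restored on exit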
+
+@contextmanager
+def override_cpu_allocator_for_qnnpack(qengine_is_qnnpack):
+    try:
+        if qengine_is_qnnpack:
+            torch._C._set_default_mobile_cpu_allocator()
+        yield
+    finally:
+        if qengine_is_qnnpack:
+            torch._C._unset_default_mobile_cpu_allocator()
+
+# TODO: Update all quantization tests to use this decorator.
+# Currently for some of the tests it seems to have inconsistent params
+# for fbgemm vs qnnpack.
+def override_qengines(qfunction):
+    def test_fn(*args, **kwargs):
+        for qengine in supported_qengines:
+            with override_quantized_engine(qengine):
+                # qfunction should not return anything.
+                qfunction(*args, **kwargs)
+    return test_fn
+
+def qengine_is_fbgemm():
+    return torch.backends.quantized.engine == 'fbgemm'
+def qengine_is_qnnpack():
+    return torch.backends.quantized.engine == 'qnnpack'
+def qengine_is_onednn():
+    return torch.backends.quantized.engine == 'onednn'
+def qengine_is_x86():
+    return torch.backends.quantized.engine == 'x86'
+
+# Helper function used to simulate per-channel fake-quant against any axis
+def _permute_to_axis_zero(X, axis):
+    new_axis_list = list(range(X.dim()))
+    new_axis_list[axis] = 0
+    new_axis_list[0] = axis
+    y = X.permute(tuple(new_axis_list))
+    return y, new_axis_list
+
+# Reference method for fake quantize
+# Note: because scale/zero_point are left as float in the actual kernel, this mimics how fake_quant works for float16/64
+def _fake_quantize_per_channel_affine_reference(X, per_channel_scale, per_channel_zero_point, axis, quant_min, quant_max):
+    dtype = X.dtype
+    X, permute_axis_list = _permute_to_axis_zero(X.to(torch.float32), axis)
+    res = torch.zeros_like(X)
+
+    for i in range(X.size()[0]):
+        res[i] = (torch.clamp(torch.round(X[i] * (1.0 / per_channel_scale[i]) +
+                  per_channel_zero_point[i]), quant_min, quant_max) - per_channel_zero_point[i]) * per_channel_scale[i]
+
+    out = res.permute(tuple(permute_axis_list))
+    return out.to(dtype)
+
+# Reference method for the gradient of the fake quantize operator
+# Note: because scale/zero_point are left as float in the actual kernel, this mimics how fake_quant works for float16/64
+def _fake_quantize_per_channel_affine_grad_reference(dY, X, per_channel_scale, per_channel_zero_point, axis, quant_min, quant_max):
+    dtype = X.dtype
+    X, permute_axis_list = _permute_to_axis_zero(X.to(torch.float32), axis)
+    Xq = torch.zeros_like(X)
+    for i in range(X.size()[0]):
+        Xq[i] = torch.round(X[i] * (1.0 / per_channel_scale[i]) + per_channel_zero_point[i])
+    Xq = Xq.permute(tuple(permute_axis_list))
+    mask = (Xq >= quant_min) * (Xq <= quant_max)
+    res = torch.zeros_like(dY)
+    res[mask] = dY[mask]
+    return res.to(dtype)
+
+def to_tensor(X, device):
+    if not isinstance(X, torch.Tensor):
+        X = torch.tensor(X)
+    else:
+        X = X.detach().clone()
+    return X.to(device=torch.device(device), dtype=torch.float32)
+
+# copy-pasted from
+# https://github.com/pytorch/ao/blob/bc4f51da86956275da7db0da6e420c506df97820/torchao/prototype/custom_fp_utils.py#L27C1-L142C29
+def _n_ones(n: int) -> int:
+    return (1 << n) - 1
+
+EBITS_F32, MBITS_F32 = 8, 23
+F32_EXP_BIAS = _n_ones(EBITS_F32 - 1)
+
+# copy-pasted from
+# https://github.com/pytorch/ao/blob/bc4f51da86956275da7db0da6e420c506df97820/torchao/prototype/custom_fp_utils.py#L27C1-L142C29
+def _f32_to_floatx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
+    """Convert FP32 numbers to sub-byte floating point numbers with the given
+    number of exponent and mantissa bits.
+
+    Input: torch.Tensor of dtype torch.float
+    Output: torch.Tensor of dtype torch.uint8, where the bit encoding is stored
+    in the least significant bits. e.g.
+      fp4: bits 0-3 empty and bits 4-7 in fp4_e2m1 encoding
+      fp6: bits 0-1 empty and bits 2-7 in fp6_e2m3 or fp6_e3m2 encoding
+
+    Note: there is no support for special values (NaN, inf) in this code. Values
+    outside the representable range of Floatx after rounding are clamped to the
+    maximum Floatx magnitude (sign is preserved).
+
+    Code below is an adaptation of https://fburl.com/code/ciwofcg4
+
+    Background 1: last answer in https://stackoverflow.com/q/8981913
+    Background 2: Computer Organization and Design, RISC-V edition, Chapter 3.5
+    """
+    assert x.dtype == torch.float
+    assert 1 + ebits + mbits <= 8
+
+    # calculate constants
+    exp_bias = _n_ones(ebits - 1)
+    max_int = _n_ones(ebits + mbits)
+    sign_mask = 1 << (ebits + mbits)
+
+    # TODO document this better
+    magic_adder = _n_ones(MBITS_F32 - mbits - 1)
+
+    # all E bits and M bits are 1s
+    max_normal = 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / (2**mbits))
+
+    # E bits = 1, M bits = 0
+    min_normal = 2 ** (1 - exp_bias)
+
+    denorm_exp = (
+        # exp bias conversion between formats
+        (F32_EXP_BIAS - exp_bias)
+        # mantissa length difference between formats
+        + (MBITS_F32 - mbits)
+        # add one to encoded exponent for denormalized numbers
+        + 1
+    )
+    denorm_mask_int = denorm_exp << MBITS_F32
+
+    # reinterpret int32 as float32
+    denorm_mask_float = torch.tensor(denorm_mask_int, dtype=torch.int32).view(
+        torch.float32
+    )
+
+    # save the sign
+    # Note that we have torch.uint32, but some ops like cpu bit shifts
+    # do not work on it. So, we stay in int32.
+    x = x.view(torch.int32)
+    sign = x & 0x80000000
+
+    # set everything to positive, will add sign back at the end
+    x = x ^ sign
+
+    # TODO: can the branch floating point comparisons below be done without
+    # converting to float? probably but need to verify
+    x = x.view(torch.float)
+
+    # rewrite saturate/denorm/norm branches without explicit data dependent
+    # control flow, to be more compiler friendly
+    saturate_mask = x >= max_normal
+    denormal_mask = torch.logical_and(torch.logical_not(saturate_mask), x < min_normal)
+    normal_mask = torch.logical_not(torch.logical_or(saturate_mask, denormal_mask))
+
+    #
+    # branch 1: saturate to max val - handled later in the code which combines
+    #   the branches
+    #
+
+    #
+    # branch 2: conversion to denormal, as well as rounding up to normal
+    #
+    denormal_x = x + denorm_mask_float
+    denormal_x = denormal_x.view(torch.int32)
+    denormal_x -= denorm_mask_int
+    denormal_x = denormal_x.to(torch.uint8)
+
+    #
+    # branch 3: stay in normal range, adjust the exponent and round
+    #
+    normal_x = x.view(torch.int32)
+    # resulting mantissa is odd
+    mant_odd = (normal_x >> (MBITS_F32 - mbits)) & 1
+    # update exponent, rounding bias part 1
+    val_to_add = ((exp_bias - F32_EXP_BIAS) << MBITS_F32) + magic_adder
+    normal_x += val_to_add
+    # rounding bias part 2
+    normal_x += mant_odd
+    # take the bits!
+    normal_x = normal_x >> (MBITS_F32 - mbits)
+    normal_x = normal_x.to(torch.uint8)
+
+    #
+    # combine the branches
+    #
+    x = torch.full_like(x, max_int, dtype=torch.uint8)
+    x = torch.where(denormal_mask, denormal_x, x)
+    x = torch.where(normal_mask, normal_x, x)
+
+    # add sign back
+    sign_lp = sign >> (MBITS_F32 + EBITS_F32 - mbits - ebits)
+    sign_lp = sign_lp.to(torch.uint8)
+    # Right shift of a negative signed integer can fill the least significant
+    # bits with either 1s or 0s, depending on the implementation. Since PyTorch
+    # doesn't have an uint32 dtype, we mask out these bits to get just the
+    # f4 sign bit
+    sign_lp = sign_lp & sign_mask
+    x = x | sign_lp
+
+    return x.to(torch.uint8)
+
+
+# copy-pasted from
+# https://github.com/pytorch/ao/blob/29488018d99af7f7339f06353c6b5bbeae8a1493/torchao/prototype/custom_fp_utils.py#L147
+def _floatx_unpacked_to_f32(x: Tensor, ebits: int, mbits: int) -> Tensor:
+    """Convert sub-byte floating point numbers with the given number of exponent
+    and mantissa bits to FP32.
+
+    Input: torch.Tensor of dtype uint8, where the bit encoding is stored
+    in the least significant bits. e.g.
+      fp4: bits 0-3 empty and bits 4-7 in fp4_e2m1 encoding
+      fp6: bits 0-1 empty and bits 2-7 in fp6_e2m3 or fp6_e3m2 encoding
+    Output: torch.Tensor of dtype fp32 with the dequantized value
+    """
+    assert x.dtype == torch.uint8
+    assert 1 + ebits + mbits <= 8
+
+    sign_mask = 1 << (ebits + mbits)
+    exp_bias = _n_ones(ebits - 1)
+    mantissa_mask = _n_ones(mbits)
+
+    # save the sign
+    sign_lp = x & sign_mask
+
+    # set everything to positive, will add sign back at the end
+    x_pos = x ^ sign_lp
+
+    #
+    # 1. Calculate zero mask
+    #
+    zero_mask = x_pos == 0
+
+    #
+    # 2. Calculate the denormal path mask
+    #
+    denormal_mask = torch.logical_and((x_pos > 0), ((x_pos >> mbits) == 0))
+
+    #
+    # 3. Calculate the normal path
+    #
+
+    # calculate the new exponent and shift it to bits 2:9 of the result
+    exp_biased_lp = x_pos >> mbits
+    exp_biased_f32 = exp_biased_lp - exp_bias + F32_EXP_BIAS
+    exp_biased_f32 = exp_biased_f32.to(torch.int32) << MBITS_F32
+
+    # shift the mantissa to bits 10:32 of the result
+    mantissa_lp_int32 = (x_pos & mantissa_mask).to(torch.int32)
+    mantissa_f32 = mantissa_lp_int32 << (MBITS_F32 - mbits)
+    result = exp_biased_f32 | mantissa_f32
+
+    #
+    # 4. Add the zero and denormal casts to the already casted normal path
+    #
+    result[zero_mask] = 0
+
+    denormal_exp_biased = 1 - exp_bias + F32_EXP_BIAS
+
+    # fast path.
+    # without this, performance for FP4_E2M1 is slower by 2x
+    if mbits == 1:
+        result[denormal_mask] = (denormal_exp_biased - mbits) << MBITS_F32
+
+    else:
+        # iterate over all possible values of mantissa
+        # i=0, j=1
+        # i=1, j=10,11
+        # i=2, j=100,101,110,111
+        # and so on
+        for i in range(mbits):
+            for mantissa_cmp in range(1 << i, 1 << (i + 1)):
+                # left shift mantissa until it overflows (create an implicit 1)
+                # subtract exponent by the same amount
+                left_shift = mbits - i
+                mantissa_f32 = (mantissa_cmp - (1 << i)) << (
+                    left_shift + MBITS_F32 - mbits
+                )
+                exp_biased_f32 = (denormal_exp_biased - left_shift) << MBITS_F32
+
+                # we can update this in-place since the values won't overlap
+                # torch.compile() may complain unsupported operand type(s) for |: 'SymInt' and 'int'
+                # thus we use + instead of | here
+                mantissa_lp_int32[mantissa_lp_int32 == mantissa_cmp] = (
+                    exp_biased_f32 + mantissa_f32
+                )
+
+        result = torch.where(denormal_mask, mantissa_lp_int32, result)
+
+    # add sign back
+    sign_f32 = sign_lp.to(torch.int32) << (MBITS_F32 - mbits + EBITS_F32 - ebits)
+    result = result | sign_f32
+
+    return result.view(torch.float)
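+
+# Illustrative round trip (assumption, not part of the upstream file): values that are
+# exactly representable in the low-precision format survive encode/decode unchanged.
+# For fp4_e2m1 (ebits=2, mbits=1) that includes 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0 and 6.0:
+#
+#     x = torch.tensor([0.5, 1.5, 6.0])
+#     packed = _f32_to_floatx_unpacked(x, ebits=2, mbits=1)
+#     assert torch.equal(_floatx_unpacked_to_f32(packed, ebits=2, mbits=1), x)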
+
+# copied from https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/mx/to_blocked.py
+def ceil_div(a, b):
+    return (a + b - 1) // b
+
+def to_blocked(input_matrix) -> torch.Tensor:
+    """
+    Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.
+
+    See:
+        https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
+
+    Args:
+        input_matrix: Input tensor of shape (H, W)
+
+    Returns:
+        Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4))
+    """
+    rows, cols = input_matrix.shape
+    n_row_blocks = ceil_div(rows, 128)
+    n_col_blocks = ceil_div(cols, 4)
+
+    # Calculate the padded shape
+    padded_rows = n_row_blocks * 128
+    padded_cols = n_col_blocks * 4
+
+    padded = input_matrix
+    # Ideally we would use torch.nn.functional.pad, but it doesn't support float8_e8m0fnu for now
+    if (rows, cols) != (padded_rows, padded_cols):
+        padded = torch.zeros((padded_rows, padded_cols), device=input_matrix.device, dtype=input_matrix.dtype)
+        padded[:rows, :cols] = input_matrix
+
+    # Rearrange the blocks
+    blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
+    rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
+
+    return rearranged.flatten()
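+
+# Illustrative shape check (assumption, not part of the upstream file): a (128, 4)
+# scale matrix fits exactly one block, so the result is a flattened tensor of
+# 32 * 16 = 512 elements:
+#
+#     scales = torch.randn(128, 4)
+#     assert to_blocked(scales).shape == (512,)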
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_subclass.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_subclass.py
new file mode 100644
index 0000000000000000000000000000000000000000..3aeb78035cb845938a087f0cd58ece1eb045edc4
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_subclass.py
@@ -0,0 +1,346 @@
+# mypy: ignore-errors
+
+import torch
+from copy import deepcopy
+from torch.utils._pytree import tree_map
+import torch.utils._pytree as pytree
+
+
+# TODO: Move LoggingTensor here.
+from torch.testing._internal.logging_tensor import LoggingTensor
+
+
+# Base class for wrapper-style tensors.
+class WrapperTensor(torch.Tensor):
+    @staticmethod
+    def __new__(cls, *args, **kwargs):
+        t, kwargs = cls.get_wrapper_properties(*args, **kwargs)
+        if "size" not in kwargs:
+            size = t.size()
+        else:
+            size = kwargs["size"]
+            del kwargs["size"]
+        if "dtype" not in kwargs:
+            kwargs["dtype"] = t.dtype
+        if "layout" not in kwargs:
+            kwargs["layout"] = t.layout
+        if "device" not in kwargs:
+            kwargs["device"] = t.device
+        if "requires_grad" not in kwargs:
+            kwargs["requires_grad"] = False
+        # Ignore memory_format and pin memory for now as I don't know how to
+        # safely access them on a Tensor (if possible??)
+
+        wrapper = torch.Tensor._make_wrapper_subclass(cls, size, **kwargs)
+        wrapper._validate_methods()
+        return wrapper
+
+    @classmethod
+    def get_wrapper_properties(cls, *args, **kwargs):
+        # Should return both an example Tensor and a dictionary of kwargs
+        # to override any of that example Tensor's properties.
+        # This is very similar to the `t.new_*(args)` API
+        raise NotImplementedError("You need to implement get_wrapper_properties")
+
+    def _validate_methods(self):
+        # Skip this if not in debug mode?
+        # Changing these on the python side is wrong as it would not be properly reflected
+        # on the c++ side
+        # This doesn't catch attributes set in the __init__
+        forbidden_overrides = ["size", "stride", "dtype", "layout", "device", "requires_grad"]
+        for el in forbidden_overrides:
+            if getattr(self.__class__, el) is not getattr(torch.Tensor, el):
+                raise RuntimeError(f"Subclass {self.__class__.__name__} is overwriting the "
+                                   f"property {el} but this is not allowed as such change would "
+                                   "not be reflected to c++ callers.")
+
+
+class WrapperTensorWithCustomSizes(WrapperTensor):
+    @classmethod
+    def get_wrapper_properties(cls, t, requires_grad=False):
+        return t, {"requires_grad": requires_grad, "dispatch_sizes_strides_policy": "sizes"}
+
+    def __init__(self, t, requires_grad=False):
+        self.t = t
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+        if not all(issubclass(cls, t) for t in types):
+            return NotImplemented
+
+        if kwargs is None:
+            kwargs = {}
+
+        def unwrap(e):
+            return e.t if isinstance(e, WrapperTensorWithCustomSizes) else e
+
+        def wrap(e):
+            return WrapperTensorWithCustomSizes(e) if isinstance(e, torch.Tensor) else e
+
+        rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {})))
+        return rs
+
+    def __repr__(self):
+        return super().__repr__(tensor_contents=f"t={self.t}")
+
+
+class WrapperTensorWithCustomStrides(WrapperTensor):
+    @classmethod
+    def get_wrapper_properties(cls, t, requires_grad=False):
+        return t, {"requires_grad": requires_grad, "dispatch_sizes_strides_policy": "strides"}
+
+    def __init__(self, t, requires_grad=False):
+        self.t = t
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+        if not all(issubclass(cls, t) for t in types):
+            return NotImplemented
+
+        if kwargs is None:
+            kwargs = {}
+
+        def unwrap(e):
+            return e.t if isinstance(e, WrapperTensorWithCustomStrides) else e
+
+        def wrap(e):
+            return WrapperTensorWithCustomStrides(e) if isinstance(e, torch.Tensor) else e
+
+        rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {})))
+        return rs
+
+    def __repr__(self):
+        return super().__repr__(tensor_contents=f"t={self.t}")
+
+
+class DiagTensorBelow(WrapperTensor):
+    @classmethod
+    def get_wrapper_properties(cls, diag, requires_grad=False):
+        assert diag.ndim == 1
+        return diag, {"size": diag.size() + diag.size(), "requires_grad": requires_grad}
+
+    def __init__(self, diag, requires_grad=False):
+        self.diag = diag
+
+    handled_ops = {}
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+        if not all(issubclass(cls, t) for t in types):
+            return NotImplemented
+
+        # For everything else, call the handler:
+        fn = cls.handled_ops.get(func.__name__, None)
+        if fn:
+            return fn(*args, **(kwargs or {}))
+        else:
+            # Note that here, because we don't need to provide the autograd formulas
+            # we can have a default "fallback" that creates a plain Tensor based
+            # on the diag elements and calls the func again.
+
+            def unwrap(e):
+                return e.diag.diag() if isinstance(e, DiagTensorBelow) else e
+
+            def wrap(e):
+                if isinstance(e, torch.Tensor) and e.ndim == 1:
+                    return DiagTensorBelow(e)
+                if isinstance(e, torch.Tensor) and e.ndim == 2 and e.count_nonzero() == e.diag().count_nonzero():
+                    return DiagTensorBelow(e.diag())
+                return e
+
+            rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {})))
+            return rs
+
+    def __repr__(self):
+        return super().__repr__(tensor_contents=f"diag={self.diag}")
+
+
+class SparseTensor(WrapperTensor):
+    @classmethod
+    def get_wrapper_properties(cls, size, values, indices, requires_grad=False):
+        assert values.device == indices.device
+        return values, {"size": size, "requires_grad": requires_grad}
+
+    def __init__(self, size, values, indices, requires_grad=False):
+        self.values = values
+        self.indices = indices
+
+    def __repr__(self):
+        return super().__repr__(tensor_contents=f"values={self.values}, indices={self.indices}")
+
+    def sparse_to_dense(self):
+        res = torch.zeros(self.size(), dtype=self.values.dtype)
+        res[self.indices.unbind(1)] = self.values
+        return res
+
+    @staticmethod
+    def from_dense(t):
+        indices = t.nonzero()
+        values = t[indices.unbind(1)]
+        return SparseTensor(t.size(), values, indices)
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+        func_name = f"{func.__module__}.{func.__name__}"
+
+        res = cls._try_call_special_impl(func_name, args, kwargs)
+        if res is not NotImplemented:
+            return res
+
+        # Otherwise, use a default implementation that construct dense
+        # tensors and use that to compute values
+        def unwrap(e):
+            return e.sparse_to_dense() if isinstance(e, SparseTensor) else e
+
+        # Wrap back all Tensors into our custom class
+        def wrap(e):
+            # Check for zeros and use that to get indices
+            return SparseTensor.from_dense(e) if isinstance(e, torch.Tensor) else e
+
+        rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {})))
+        return rs
+
+    # To show how things happen later
+    def __rmul__(self, other):
+        return super().__rmul__(other)
+
+    _SPECIAL_IMPLS = {}
+
+    @classmethod
+    def _try_call_special_impl(cls, func, args, kwargs):
+        if func not in cls._SPECIAL_IMPLS:
+            return NotImplemented
+        return cls._SPECIAL_IMPLS[func](args, kwargs)
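+
+# Illustrative round trip (assumption, not part of the upstream file): SparseTensor
+# stores only the non-zero values and their indices, and can rebuild the dense tensor:
+#
+#     dense = torch.tensor([[1., 0.], [0., 2.]])
+#     sp = SparseTensor.from_dense(dense)
+#     assert torch.equal(sp.sparse_to_dense(), dense)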
+
+
+# Example non-wrapper subclass that stores extra state.
+class NonWrapperTensor(torch.Tensor):
+    def __new__(cls, data):
+        t = torch.Tensor._make_subclass(cls, data)
+        t.extra_state = {
+            'last_func_called': None
+        }
+        return t
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        result = super().__torch_function__(func, types, args, kwargs)
+
+        if isinstance(result, cls):
+            # Do something with the extra state. For the example here, just store the name of the
+            # last function called (skip for deepcopy so the copy has the same extra state).
+            if func is torch.Tensor.__deepcopy__:
+                result.extra_state = deepcopy(args[0].extra_state)
+            else:
+                result.extra_state = {
+                    'last_func_called': func.__name__,
+                }
+
+        return result
+
+    # new_empty() must be defined for deepcopy to work
+    def new_empty(self, shape):
+        return type(self)(torch.empty(shape))
+
+
+# Class used to store info about subclass tensors used in testing.
+class SubclassInfo:
+
+    __slots__ = ['name', 'create_fn', 'closed_under_ops']
+
+    def __init__(self, name, create_fn, closed_under_ops=True):
+        self.name = name
+        self.create_fn = create_fn  # create_fn(shape) -> tensor instance
+        self.closed_under_ops = closed_under_ops
+
+
+# Helper function to create a subclass of the given class and possibly cache sizes / strides.
+def _create_and_access_shape(cls, shape):
+    sub = cls(torch.randn(shape))
+    # NB: Wrapper subclasses with custom dispatched sizes / strides cache this info
+    # on the first call via non-serializable PyCapsules. We purposefully trigger cache
+    # population here for serialization / deepcopy tests to verify that the presence of this
+    # cache info doesn't cause problems.
+    sub.size()
+    sub.stride()
+    return sub
+
+
+subclass_db = {
+    torch.Tensor: SubclassInfo(
+        'base_tensor', create_fn=torch.randn
+    ),
+    NonWrapperTensor: SubclassInfo(
+        'non_wrapper_tensor',
+        create_fn=lambda shape: NonWrapperTensor(torch.randn(shape))
+    ),
+    LoggingTensor: SubclassInfo(
+        'logging_tensor',
+        create_fn=lambda shape: LoggingTensor(torch.randn(shape))
+    ),
+    SparseTensor: SubclassInfo(
+        'sparse_tensor',
+        create_fn=lambda shape: SparseTensor.from_dense(torch.randn(shape).relu())
+    ),
+    DiagTensorBelow: SubclassInfo(
+        'diag_tensor_below',
+        create_fn=lambda shape: DiagTensorBelow(torch.randn(shape)),
+        closed_under_ops=False  # sparse semantics
+    ),
+    WrapperTensorWithCustomSizes: SubclassInfo(
+        'wrapper_with_custom_sizes',
+        create_fn=lambda shape: _create_and_access_shape(WrapperTensorWithCustomSizes, shape),
+        closed_under_ops=False,
+    ),
+    WrapperTensorWithCustomStrides: SubclassInfo(
+        'wrapper_with_custom_strides',
+        create_fn=lambda shape: _create_and_access_shape(WrapperTensorWithCustomStrides, shape),
+        closed_under_ops=False,
+    ),
+}
+
+class SubclassWithTensorFactory(torch.Tensor):
+    @staticmethod
+    def __new__(cls, src):
+        shape = src.shape
+        kwargs = {}
+        kwargs["strides"] = src.stride()
+        kwargs["storage_offset"] = src.storage_offset()
+        kwargs["device"] = src.device
+        kwargs["layout"] = src.layout
+        kwargs["requires_grad"] = src.requires_grad
+        kwargs["dtype"] = src.dtype
+        out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
+        return out
+
+    def __init__(self, src):
+        self.src = src
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}"
+
+    def __tensor_flatten__(self):
+        return ["src"], None
+
+    @classmethod
+    def __tensor_unflatten__(cls, inner_tensors, meta, outer_size, outer_stride):
+        src = inner_tensors["src"]
+        return cls(src)
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        if kwargs is None:
+            kwargs = {}
+
+        def _fn(x):
+            return x.src * torch.ones(x.src.shape) if x.src.dtype == torch.float32 else x.src
+
+        _args = pytree.tree_map_only(cls, _fn, args)
+        _kwargs = pytree.tree_map_only(cls, _fn, kwargs)
+
+        _out = func(*_args, **_kwargs)
+
+        _out_flat, _out_spec = pytree.tree_flatten(_out)
+
+        out_flat = [cls(o) if isinstance(o, torch.Tensor) else o for o in _out_flat]
+        return pytree.tree_unflatten(out_flat, _out_spec)
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_utils.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..45b7378f88cc8545a540009064725bde63ec50da
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/common_utils.py
@@ -0,0 +1,5712 @@
+# mypy: allow-untyped-defs
+
+r"""Importing this file must **not** initialize CUDA context. test_distributed
+relies on this assumption to properly run. This means that when this is imported
+no CUDA calls shall be made, including torch.cuda.device_count(), etc.
+
+torch.testing._internal.common_cuda.py can freely initialize CUDA context when imported.
+"""
+
+import argparse
+import contextlib
+import copy
+import ctypes
+import errno
+import functools
+import gc
+import hashlib
+import inspect
+import io
+import json
+import logging
+import math
+import operator
+import os
+import pathlib
+import platform
+import random
+import re
+import shutil
+import signal
+import socket
+import subprocess
+import sys
+import tempfile
+import threading
+import time
+import types
+import unittest
+import warnings
+from collections.abc import Mapping, Sequence
+from contextlib import closing, contextmanager
+from copy import deepcopy
+from dataclasses import dataclass
+from enum import Enum
+from functools import partial, wraps
+from itertools import product, chain
+from pathlib import Path
+from statistics import mean
+from typing import (
+    Any,
+    Callable,
+    Optional,
+    TypeVar,
+    Union,
+)
+from collections.abc import Iterable, Iterator
+from unittest.mock import MagicMock
+
+import expecttest
+import numpy as np
+
+import __main__  # type: ignore[import]
+import torch
+import torch.backends.cudnn
+import torch.backends.mkl
+import torch.backends.mps
+import torch.backends.xnnpack
+import torch.cuda
+from torch import Tensor
+from torch._C import ScriptDict, ScriptList  # type: ignore[attr-defined]
+from torch._dynamo.trace_rules import _as_posix_path
+from torch._utils_internal import get_writable_path
+from torch._logging.scribe import open_source_signpost
+from torch.nn import (
+    ModuleDict,
+    ModuleList,
+    ParameterDict,
+    ParameterList,
+    Sequential,
+)
+from torch.onnx import (
+    register_custom_op_symbolic,
+    unregister_custom_op_symbolic,
+)
+from torch.testing import make_tensor
+from torch.testing._comparison import (
+    BooleanPair,
+    NonePair,
+    NumberPair,
+    Pair,
+    TensorLikePair,
+)
+from torch.testing._comparison import not_close_error_metas
+from torch.testing._internal.common_dtype import get_all_dtypes
+from torch.utils._import_utils import _check_module_exists
+import torch.utils._pytree as pytree
+from torch.utils import cpp_extension
+try:
+    import pytest
+    has_pytest = True
+except ImportError:
+    has_pytest = False
+
+
+MI300_ARCH = ("gfx942",)
+
+
+def freeze_rng_state(*args, **kwargs):
+    return torch.testing._utils.freeze_rng_state(*args, **kwargs)
+
+
+# Class to keep track of test flags configurable by environment variables.
+# Flags set here are intended to be read-only and should not be modified after
+# definition.
+# TODO: Expand this class to handle arbitrary settings in addition to boolean flags?
+class TestEnvironment:
+    # Set of env vars to set for the repro command that is output on test failure.
+    # Specifically, this includes env vars that are set to non-default values and
+    # are not implied. Maps from env var name -> value (int)
+    repro_env_vars: dict = {}
+
+    # Defines a flag usable throughout the test suite, determining its value by querying
+    # the specified environment variable.
+    #
+    # Args:
+    #     name (str): The name of the flag. A global variable with this name will be set
+    #         for convenient access throughout the test suite.
+    #     env_var (str): The name of the primary environment variable from which to
+    #         determine the value of this flag. If this is None or the environment variable
+    #         is unset, the default value will be used unless otherwise implied (see
+    #         implied_by_fn). Default: None
+    #     default (bool): The default value to use for the flag if unset by the environment
+    #         variable and unimplied. Default: False
+    #     include_in_repro (bool): Indicates whether this flag should be included in the
+    #         repro command that is output on test failure (i.e. whether it is possibly
+    #         relevant to reproducing the test failure). Default: True
+    #     enabled_fn (Callable): Callable returning whether the flag should be enabled
+    #         given the environment variable value and the default value. Default: Lambda
+    #         requiring "0" to disable if on by default OR "1" to enable if off by default.
+    #     implied_by_fn (Callable): Thunk returning a bool to imply this flag as enabled
+    #         by something outside of its primary environment variable setting. For example,
+    #         this can be useful if the value of another environment variable implies the flag
+    #         as enabled. Default: Lambda returning False to indicate no implications.
+    @staticmethod
+    def def_flag(
+        name,
+        env_var=None,
+        default=False,
+        include_in_repro=True,
+        enabled_fn=lambda env_var_val, default: (
+            (env_var_val != "0") if default else (env_var_val == "1")),
+        implied_by_fn=lambda: False,
+    ):
+        enabled = default
+        env_var_val = None
+        if env_var is not None:
+            env_var_val = os.getenv(env_var)
+            enabled = enabled_fn(env_var_val, default)
+        implied = implied_by_fn()
+        enabled = enabled or implied
+        if include_in_repro and (env_var is not None) and (enabled != default) and not implied:
+            TestEnvironment.repro_env_vars[env_var] = env_var_val
+
+        # export flag globally for convenience
+        assert name not in globals(), f"duplicate definition of flag '{name}'"
+        globals()[name] = enabled
+        return enabled
+
+    # Defines a setting usable throughout the test suite, determining its value by querying
+    # the specified environment variable. This differs from a flag in that it's not restricted
+    # to a boolean value.
+    #
+    # Args:
+    #     name (str): The name of the setting. A global variable with this name will be set
+    #         for convenient access throughout the test suite.
+    #     env_var (str): The name of the primary environment variable from which to
+    #         determine the value of this setting. If this is None or the environment variable
+    #         is unset, the default value will be used. Default: None
+    #     default (Any): The default value to use for the setting if unset by the environment
+    #         variable. Default: None
+    #     include_in_repro (bool): Indicates whether this setting should be included in the
+    #         repro command that is output on test failure (i.e. whether it is possibly
+    #         relevant to reproducing the test failure). Default: True
+    #     parse_fn (Callable): Callable parsing the env var string. Default value just uses
+    #         the string itself.
+    @staticmethod
+    def def_setting(
+        name,
+        env_var=None,
+        default=None,
+        include_in_repro=True,
+        parse_fn=lambda maybe_val_str: maybe_val_str,
+    ):
+        value = default if env_var is None else os.getenv(env_var)
+        value = parse_fn(value)
+        if include_in_repro and (value != default):
+            TestEnvironment.repro_env_vars[env_var] = value
+
+        # export setting globally for convenience
+        assert name not in globals(), f"duplicate definition of setting '{name}'"
+        globals()[name] = value
+        return value
+
+    # Returns a string prefix usable to set environment variables for any test
+    # settings that should be explicitly set to match this instantiation of the
+    # test suite.
+    # Example: "PYTORCH_TEST_WITH_ASAN=1 PYTORCH_TEST_WITH_ROCM=1"
+    @staticmethod
+    def repro_env_var_prefix() -> str:
+        return " ".join([f"{env_var}={value}"
+                         for env_var, value in TestEnvironment.repro_env_vars.items()])
+
+
+log = logging.getLogger(__name__)
+torch.backends.disable_global_flags()
+
+FILE_SCHEMA = "file://"
+if sys.platform == 'win32':
+    FILE_SCHEMA = "file:///"
+
+# NB: This flag differs semantically from others in that setting the env var to any
+# non-empty value will cause it to be true:
+#   CI=1, CI="true", CI=0, etc. all set the flag to be true.
+#   CI= and an unset CI set the flag to be false.
+# GitHub sets the value to CI="true" to enable it.
+IS_CI: bool = TestEnvironment.def_flag(
+    "IS_CI",
+    env_var="CI",
+    include_in_repro=False,
+    enabled_fn=lambda env_var_value, _: bool(env_var_value),
+)
+IS_SANDCASTLE: bool = TestEnvironment.def_flag(
+    "IS_SANDCASTLE",
+    env_var="SANDCASTLE",
+    implied_by_fn=lambda: os.getenv("TW_JOB_USER") == "sandcastle",
+    include_in_repro=False,
+)
+IN_RE_WORKER: bool = os.environ.get("INSIDE_RE_WORKER") is not None
+
+_is_fbcode_default = (
+    hasattr(torch._utils_internal, "IS_FBSOURCE") and
+    torch._utils_internal.IS_FBSOURCE
+)
+
+IS_FBCODE: bool = TestEnvironment.def_flag(
+    "IS_FBCODE",
+    env_var="PYTORCH_TEST_FBCODE",
+    default=_is_fbcode_default,
+    include_in_repro=False,
+)
+IS_REMOTE_GPU: bool = TestEnvironment.def_flag(
+    "IS_REMOTE_GPU",
+    env_var="PYTORCH_TEST_REMOTE_GPU",
+    include_in_repro=False,
+)
+
+DISABLE_RUNNING_SCRIPT_CHK: bool = TestEnvironment.def_flag(
+    "DISABLE_RUNNING_SCRIPT_CHK",
+    env_var="PYTORCH_DISABLE_RUNNING_SCRIPT_CHK",
+    include_in_repro=False,
+)
+# NB: enabled by default unless in an fbcode context.
+PRINT_REPRO_ON_FAILURE: bool = TestEnvironment.def_flag(
+    "PRINT_REPRO_ON_FAILURE",
+    env_var="PYTORCH_PRINT_REPRO_ON_FAILURE",
+    default=(not IS_FBCODE),
+    include_in_repro=False,
+)
+
+# possibly restrict OpInfo tests to a single sample input
+OPINFO_SAMPLE_INPUT_INDEX: Optional[int] = TestEnvironment.def_setting(
+    "OPINFO_SAMPLE_INPUT_INDEX",
+    env_var="PYTORCH_OPINFO_SAMPLE_INPUT_INDEX",
+    default=None,
+    # Don't include the env var value in the repro command because the info will
+    # be queried from the tracked sample input instead
+    include_in_repro=False,
+    parse_fn=lambda val: None if val is None else int(val),
+)
+
+DEFAULT_DISABLED_TESTS_FILE = '.pytorch-disabled-tests.json'
+DEFAULT_SLOW_TESTS_FILE = 'slow_tests.json'
+
+disabled_tests_dict = {}
+slow_tests_dict = {}
+
+def maybe_load_json(filename):
+    if os.path.isfile(filename):
+        with open(filename) as fp:
+            return json.load(fp)
+    log.warning("Attempted to load json file '%s' but it does not exist.", filename)
+    return {}
+
+# set them here in case the tests are running in a subprocess that doesn't call run_tests
+if os.getenv("SLOW_TESTS_FILE", ""):
+    slow_tests_dict = maybe_load_json(os.getenv("SLOW_TESTS_FILE", ""))
+if os.getenv("DISABLED_TESTS_FILE", ""):
+    disabled_tests_dict = maybe_load_json(os.getenv("DISABLED_TESTS_FILE", ""))
+
+NATIVE_DEVICES = ('cpu', 'cuda', 'xpu', 'meta', torch._C._get_privateuse1_backend_name())
+
+# used for managing devices testing for torch profiler UTs
+# for now cpu, cuda and xpu are added for testing torch profiler UTs
+DEVICE_LIST_SUPPORT_PROFILING_TEST = ('cpu', 'cuda', 'xpu')
+ALLOW_XPU_PROFILING_TEST = True
+
+check_names = ['orin', 'concord', 'galen', 'xavier', 'nano', 'jetson', 'tegra', 'thor']
+IS_JETSON = any(name in platform.platform() for name in check_names)
+
+def gcIfJetson(fn):
+    # Irregular Jetson host/device memory setup requires cleanup to avoid tests being killed
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        if IS_JETSON:
+            gc.collect()
+            torch.cuda.empty_cache()
+        fn(*args, **kwargs)
+    return wrapper
+
+# Tries to extract the current test function by crawling the stack.
+# If unsuccessful, return None.
+def extract_test_fn() -> Optional[Callable]:
+    try:
+        stack = inspect.stack()
+        for frame_info in stack:
+            frame = frame_info.frame
+            if "self" not in frame.f_locals:
+                continue
+            self_val = frame.f_locals["self"]
+            if isinstance(self_val, unittest.TestCase):
+                test_id = self_val.id()
+                test_name = test_id.split('.')[2]
+                test_fn = getattr(self_val, test_name).__func__
+                return test_fn
+    except Exception:
+        pass
+    return None
+
+# Contains tracked input data useful for debugging purposes
+@dataclass
+class TrackedInput:
+    index: int
+    val: Any
+    type_desc: str
+
+# Attempt to pull out tracked input information from the test function.
+# A TrackedInputIter is used to insert this information.
+def get_tracked_input() -> Optional[TrackedInput]:
+    test_fn = extract_test_fn()
+    if test_fn is None:
+        return None
+    return getattr(test_fn, "tracked_input", None)
+
+def clear_tracked_input() -> None:
+    test_fn = extract_test_fn()
+    if test_fn is None:
+        return
+    if not hasattr(test_fn, "tracked_input"):
+        return
+    test_fn.tracked_input = None  # type: ignore[attr-defined]
+
+# Wraps an iterator and tracks the most recent value the iterator produces
+# for debugging purposes. Tracked values are stored on the test function.
+class TrackedInputIter:
+    def __init__(
+        self,
+        child_iter,
+        input_type_desc,
+        item_callback=None,
+        track_callback=None,
+        set_seed=True,
+        restrict_to_index=None
+    ):
+        self.child_iter = enumerate(child_iter)
+        # Input type describes the things we're tracking (e.g. "sample input", "error input").
+        self.input_type_desc = input_type_desc
+        # NB: The two types of callbacks below exist because the thing we want to track isn't
+        # always the same as the thing we want returned from the iterator. An example of this
+        # is ErrorInput, which we want returned from the iterator, but which contains a
+        # SampleInput that we want to track.
+        # Item callback is run on each (iterated thing, index) to get the thing to return.
+        self.item_callback = item_callback
+        if self.item_callback is None:
+            self.item_callback = lambda x, i: x
+        # Track callback is run on each iterated thing to get the thing to track.
+        self.track_callback = track_callback
+        if self.track_callback is None:
+            self.track_callback = lambda x: x
+        self.test_fn = extract_test_fn()
+        # Indicates whether the random seed should be set before each call to the iterator
+        self.set_seed = set_seed
+        # Indicates that iteration should be restricted to only the provided index.
+        # If None, no restriction is done
+        self.restrict_to_index = restrict_to_index
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        while True:
+            if self.set_seed:
+                # use a test-name-specific hash for the seed if possible
+                seed = (
+                    int.from_bytes(hashlib.sha256(
+                        self.test_fn.__qualname__.encode("utf-8")).digest()[:4], 'little')
+                    if self.test_fn is not None else SEED
+                )
+                set_rng_seed(seed)
+
+            # allow StopIteration to bubble up
+            input_idx, input_val = next(self.child_iter)
+            if (self.restrict_to_index is None) or (input_idx == self.restrict_to_index):
+                break
+
+        self._set_tracked_input(
+            TrackedInput(
+                index=input_idx, val=self.track_callback(input_val), type_desc=self.input_type_desc
+            )
+        )
+        return self.item_callback(input_val, input_idx)
+
+    def _set_tracked_input(self, tracked_input: TrackedInput):
+        if self.test_fn is None:
+            return
+        if not hasattr(self.test_fn, "tracked_input"):
+            return
+        self.test_fn.tracked_input = tracked_input  # type: ignore[attr-defined]
+
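+# A minimal usage sketch (hypothetical test body; the real callers live in the OpInfo-based
+# test machinery): wrapping an iterator of sample inputs in TrackedInputIter records the most
+# recent item on the test function, so failure handling can later report which input was being
+# processed via get_tracked_input():
+#
+#     for sample in TrackedInputIter(op.sample_inputs(device, dtype), "sample input"):
+#         ...  # exercise the op; on failure, get_tracked_input() identifies the sample
+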
+class _TestParametrizer:
+    """
+    Decorator class for parametrizing a test function, yielding a set of new tests spawned
+    from the original generic test, each specialized for a specific set of test inputs. For
+    example, parametrizing a test across the set of ops will result in a test function per op.
+
+    The decision of how to parametrize / what to parametrize over is intended to be implemented
+    by each derived class.
+
+    In the details, the decorator adds a 'parametrize_fn' property to the test function. This function
+    is intended to be called later by one of:
+      * Device-specific test instantiation via instantiate_device_type_tests(). Note that for this
+        case there is no need to explicitly parametrize over device type, as that is handled separately.
+      * Device-agnostic parametrized test instantiation via instantiate_parametrized_tests().
+
+    If the decorator is applied to a test function that already has a 'parametrize_fn' property, a new
+    composite 'parametrize_fn' will be created that generates tests with the product of the parameters
+    generated by the old and new parametrize_fns. This allows for convenient composability of decorators.
+    """
+    def _parametrize_test(self, test, generic_cls, device_cls):
+        """
+        Parametrizes the given test function across whatever dimension is specified by the derived class.
+        Tests can be parametrized over any arbitrary dimension or combination of dimensions, such as all
+        ops, all modules, or all ops + their associated dtypes.
+
+        Args:
+            test (fn): Test function to parametrize over
+            generic_cls (class): Generic test class object containing tests (e.g. TestFoo)
+            device_cls (class): Device-specialized test class object (e.g. TestFooCPU); set to None
+                if the tests are not part of a device-specific set
+
+        Returns:
+            Generator object returning 4-tuples of:
+                test (fn): Parametrized test function; must support a device arg and args for any params
+                test_name (str): Parametrized suffix for the test (e.g. opname_int64); will be appended to
+                    the base name of the test
+                param_kwargs (dict): Param kwargs to pass to the test (e.g. {'op': 'add', 'dtype': torch.int64})
+                decorator_fn (callable): Callable[[Dict], List] for list of decorators to apply given param_kwargs
+        """
+        raise NotImplementedError
+
+    def __call__(self, fn):
+        if hasattr(fn, 'parametrize_fn'):
+            # Do composition with the product of args.
+            old_parametrize_fn = fn.parametrize_fn
+            new_parametrize_fn = self._parametrize_test
+            fn.parametrize_fn = compose_parametrize_fns(old_parametrize_fn, new_parametrize_fn)
+        else:
+            fn.parametrize_fn = self._parametrize_test
+        return fn
+
+
+def compose_parametrize_fns(old_parametrize_fn, new_parametrize_fn):
+    """
+    Returns a parametrize_fn that parametrizes over the product of the parameters handled
+    by the given parametrize_fns. Each given parametrize_fn should have the signature
+    f(test, generic_cls, device_cls).
+
+    The test names will be a combination of the names produced by the parametrize_fns in
+    "<new_name>_<old_name>" order. This order is done to match intuition for constructed names
+    when composing multiple decorators; the names will be built in top to bottom order when stacking
+    parametrization decorators.
+
+    Args:
+        old_parametrize_fn (callable) - First parametrize_fn to compose.
+        new_parametrize_fn (callable) - Second parametrize_fn to compose.
+    """
+
+    def composite_fn(test, generic_cls, device_cls,
+                     old_parametrize_fn=old_parametrize_fn,
+                     new_parametrize_fn=new_parametrize_fn):
+        old_tests = list(old_parametrize_fn(test, generic_cls, device_cls))
+        for (old_test, old_test_name, old_param_kwargs, old_dec_fn) in old_tests:
+            for (new_test, new_test_name, new_param_kwargs, new_dec_fn) in \
+                    new_parametrize_fn(old_test, generic_cls, device_cls):
+                redundant_params = set(old_param_kwargs.keys()).intersection(new_param_kwargs.keys())
+                if redundant_params:
+                    raise RuntimeError('Parametrization over the same parameter by multiple parametrization '
+                                       f'decorators is not supported. For test "{test.__name__}", the following parameters '
+                                       f'are handled multiple times: {redundant_params}')
+                full_param_kwargs = {**old_param_kwargs, **new_param_kwargs}
+                merged_test_name = '{}{}{}'.format(new_test_name,
+                                                   '_' if old_test_name != '' and new_test_name != '' else '',
+                                                   old_test_name)
+
+                def merged_decorator_fn(param_kwargs, old_dec_fn=old_dec_fn, new_dec_fn=new_dec_fn):
+                    return list(old_dec_fn(param_kwargs)) + list(new_dec_fn(param_kwargs))
+
+                yield (new_test, merged_test_name, full_param_kwargs, merged_decorator_fn)
+
+    return composite_fn
+
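+# A small sketch of the composition above (hypothetical test): stacking two @parametrize
+# decorators composes their parametrize_fns, yielding the product of the parameters with
+# names built in "<new_name>_<old_name>" order (top decorator first in the generated name):
+#
+#     @parametrize("x", [1, 2])
+#     @parametrize("y", ["a", "b"])
+#     def test_foo(self, x, y):
+#         ...
+#
+#     # instantiation then produces test_foo_x_1_y_a, test_foo_x_1_y_b,
+#     # test_foo_x_2_y_a and test_foo_x_2_y_b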
+
+def instantiate_parametrized_tests(generic_cls):
+    """
+    Instantiates tests that have been decorated with a parametrize_fn. This is generally performed by a
+    decorator subclass of _TestParametrizer. The generic test will be replaced on the test class by
+    parametrized tests with specialized names. This should be used instead of
+    instantiate_device_type_tests() if the test class contains device-agnostic tests.
+
+    You can also use it as a class decorator. E.g.
+
+    ```
+    @instantiate_parametrized_tests
+    class TestFoo(TestCase):
+        ...
+    ```
+
+    Args:
+        generic_cls (class): Generic test class object containing tests (e.g. TestFoo)
+    """
+    for attr_name in tuple(dir(generic_cls)):
+        class_attr = getattr(generic_cls, attr_name)
+        if not hasattr(class_attr, 'parametrize_fn'):
+            continue
+
+        # Remove the generic test from the test class.
+        delattr(generic_cls, attr_name)
+
+        # Add parametrized tests to the test class.
+        def instantiate_test_helper(cls, name, test, param_kwargs):
+            @wraps(test)
+            def instantiated_test(self, param_kwargs=param_kwargs):
+                test(self, **param_kwargs)
+
+            assert not hasattr(generic_cls, name), f"Redefinition of test {name}"
+            setattr(generic_cls, name, instantiated_test)
+
+        for (test, test_suffix, param_kwargs, decorator_fn) in class_attr.parametrize_fn(
+                class_attr, generic_cls=generic_cls, device_cls=None):
+            full_name = f'{test.__name__}_{test_suffix}'
+
+            # Apply decorators based on full param kwargs.
+            for decorator in decorator_fn(param_kwargs):
+                test = decorator(test)
+
+            instantiate_test_helper(cls=generic_cls, name=full_name, test=test, param_kwargs=param_kwargs)
+    return generic_cls
+
+
+class subtest:
+    """
+    Explicit subtest case for use with test parametrization.
+    Allows for explicit naming of individual subtest cases as well as applying
+    decorators to the parametrized test.
+
+    Args:
+        arg_values (iterable): Iterable of arg values (e.g. range(10)) or
+            tuples of arg values (e.g. [(1, 2), (3, 4)]).
+        name (str): Optional name to use for the test.
+        decorators (iterable): Iterable of decorators to apply to the generated test.
+    """
+    __slots__ = ['arg_values', 'name', 'decorators']
+
+    def __init__(self, arg_values, name=None, decorators=None):
+        self.arg_values = arg_values
+        self.name = name
+        self.decorators = decorators if decorators else []
+
+
+class parametrize(_TestParametrizer):
+    """
+    Decorator for applying generic test parametrizations.
+
+    The interface for this decorator is modeled after `@pytest.mark.parametrize`.
+    Basic usage between this decorator and pytest's is identical. The first argument
+    should be a string containing comma-separated names of parameters for the test, and
+    the second argument should be an iterable returning values or tuples of values for
+    the case of multiple parameters.
+
+    Beyond this basic usage, the decorator provides some additional functionality that
+    pytest does not.
+
+    1. Parametrized tests end up as generated test functions on unittest test classes.
+    Since this differs from how pytest works, this decorator takes on the additional
+    responsibility of naming these test functions. The default test names consist of
+    the test's base name followed by each parameter name + value (e.g. "test_bar_x_1_y_foo"),
+    but custom names can be defined using `name_fn` or the `subtest` structure (see below).
+
+    2. The decorator specially handles parameter values of type `subtest`, which allows for
+    more fine-grained control over both test naming and test execution. In particular, it can
+    be used to tag subtests with explicit test names or apply arbitrary decorators (see examples
+    below).
+
+    Examples::
+
+        @parametrize("x", range(5))
+        def test_foo(self, x):
+            ...
+
+        @parametrize("x,y", [(1, 'foo'), (2, 'bar'), (3, 'baz')])
+        def test_bar(self, x, y):
+            ...
+
+        @parametrize("x,y", [(1, 'foo'), (2, 'bar'), (3, 'baz')],
+                     name_fn=lambda x, y: '{}_{}'.format(x, y))
+        def test_bar_custom_names(self, x, y):
+            ...
+
+        @parametrize("x, y", [subtest((1, 2), name='double'),
+                              subtest((1, 3), name='triple', decorators=[unittest.expectedFailure]),
+                              subtest((1, 4), name='quadruple')])
+        def test_baz(self, x, y):
+            ...
+
+    To actually instantiate the parametrized tests, one of instantiate_parametrized_tests() or
+    instantiate_device_type_tests() should be called. The former is intended for test classes
+    that contain device-agnostic tests, while the latter should be used for test classes that
+    contain device-specific tests. Both support arbitrary parametrizations using the decorator.
+
+    Args:
+        arg_str (str): String of arg names separated by commas (e.g. "x,y").
+        arg_values (iterable): Iterable of arg values (e.g. range(10)) or
+            tuples of arg values (e.g. [(1, 2), (3, 4)]).
+        name_fn (Callable): Optional function that takes in parameters and returns subtest name.
+    """
+    def __init__(self, arg_str, arg_values, name_fn=None):
+        self.arg_names: list[str] = [s.strip() for s in arg_str.split(',') if s != '']
+        self.arg_values = arg_values
+        self.name_fn = name_fn
+
+    def _formatted_str_repr(self, idx, name, value):
+        """ Returns a string representation for the given arg that is suitable for use in test function names. """
+        if isinstance(value, torch.dtype):
+            return dtype_name(value)
+        elif isinstance(value, torch.device):
+            return str(value)
+        # Can't use isinstance as it would cause a circular import
+        elif type(value).__name__ in {'OpInfo', 'ModuleInfo'}:
+            return value.formatted_name
+        elif isinstance(value, (int, float, str)):
+            return f"{name}_{str(value).replace('.', '_')}"
+        else:
+            return f"{name}{idx}"
+
+    def _default_subtest_name(self, idx, values):
+        return '_'.join([self._formatted_str_repr(idx, a, v) for a, v in zip(self.arg_names, values)])
+
+    def _get_subtest_name(self, idx, values, explicit_name=None):
+        if explicit_name:
+            subtest_name = explicit_name
+        elif self.name_fn:
+            subtest_name = self.name_fn(*values)
+        else:
+            subtest_name = self._default_subtest_name(idx, values)
+        return subtest_name
+
+    def _parametrize_test(self, test, generic_cls, device_cls):
+        if len(self.arg_names) == 0:
+            # No additional parameters needed for the test.
+            test_name = ''
+            yield (test, test_name, {}, lambda _: [])
+        else:
+            # Each "values" item is expected to be either:
+            # * A tuple of values with one for each arg. For a single arg, a single item is expected.
+            # * A subtest instance with arg_values matching the previous.
+            values = check_exhausted_iterator = object()
+            for idx, values in enumerate(self.arg_values):
+                maybe_name = None
+
+                decorators: list[Any] = []
+                if isinstance(values, subtest):
+                    sub = values
+                    values = sub.arg_values
+                    maybe_name = sub.name
+
+                    @wraps(test)
+                    def test_wrapper(*args, **kwargs):
+                        return test(*args, **kwargs)
+
+                    decorators = sub.decorators
+                    gen_test = test_wrapper
+                else:
+                    gen_test = test
+
+                values = list(values) if len(self.arg_names) > 1 else [values]  # type: ignore[call-overload]
+                if len(values) != len(self.arg_names):
+                    raise RuntimeError(f'Expected # values == # arg names, but got: {len(values)} '
+                                       f'values and {len(self.arg_names)} names for test "{test.__name__}"')
+
+                param_kwargs = dict(zip(self.arg_names, values))
+
+                test_name = self._get_subtest_name(idx, values, explicit_name=maybe_name)
+
+                def decorator_fn(_, decorators=decorators):
+                    return decorators
+
+                yield (gen_test, test_name, param_kwargs, decorator_fn)
+
+            if values is check_exhausted_iterator:
+                raise ValueError(f'{test}: An empty arg_values was passed to @parametrize. '
+                                 'Note that this may result from reuse of a generator.')
+
+
+class reparametrize(_TestParametrizer):
+    """
+    Decorator for adjusting the way an existing parametrizer operates. This class runs
+    the given adapter_fn on each parametrization produced by the given parametrizer,
+    allowing for on-the-fly parametrization more flexible than the default,
+    product-based composition that occurs when stacking parametrization decorators.
+
+    If the adapter_fn returns None for a given test parametrization, that parametrization
+    will be excluded. Otherwise, it's expected that the adapter_fn returns an iterable of
+    modified parametrizations, with tweaked test names and parameter kwargs.
+
+    Examples::
+
+        def include_is_even_arg(test_name, param_kwargs):
+            x = param_kwargs["x"]
+            is_even = x % 2 == 0
+            new_param_kwargs = dict(param_kwargs)
+            new_param_kwargs["is_even"] = is_even
+            is_even_suffix = "_even" if is_even else "_odd"
+            new_test_name = f"{test_name}{is_even_suffix}"
+            yield (new_test_name, new_param_kwargs)
+
+        ...
+
+        @reparametrize(parametrize("x", range(5)), include_is_even_arg)
+        def test_foo(self, x, is_even):
+            ...
+
+        def exclude_odds(test_name, param_kwargs):
+            x = param_kwargs["x"]
+            is_even = x % 2 == 0
+            yield None if not is_even else (test_name, param_kwargs)
+
+        ...
+
+        @reparametrize(parametrize("x", range(5)), exclude_odds)
+        def test_bar(self, x):
+            ...
+
+    """
+    def __init__(self, parametrizer, adapter_fn):
+        self.parametrizer = parametrizer
+        self.adapter_fn = adapter_fn
+
+    def _parametrize_test(self, test, generic_cls, device_cls):
+        for (gen_test, test_name, param_kwargs, decorator_fn) in \
+                self.parametrizer._parametrize_test(test, generic_cls, device_cls):
+            adapted = self.adapter_fn(test_name, param_kwargs)
+            if adapted is not None:
+                for adapted_item in adapted:
+                    if adapted_item is not None:
+                        new_test_name, new_param_kwargs = adapted_item
+                        yield (gen_test, new_test_name, new_param_kwargs, decorator_fn)
+
+
+class decorateIf(_TestParametrizer):
+    """
+    Decorator for applying parameter-specific conditional decoration.
+    Composes with other test parametrizers (e.g. @modules, @ops, @parametrize, etc.).
+
+    Examples::
+
+        @decorateIf(unittest.skip, lambda params: params["x"] == 2)
+        @parametrize("x", range(5))
+        def test_foo(self, x):
+            ...
+
+        @parametrize("x,y", [(1, 'foo'), (2, 'bar'), (3, 'baz')])
+        @decorateIf(
+            unittest.expectedFailure,
+            lambda params: params["x"] == 3 and params["y"] == "baz"
+        )
+        def test_bar(self, x, y):
+            ...
+
+        @decorateIf(
+            unittest.expectedFailure,
+            lambda params: params["op"].name == "add" and params["dtype"] == torch.float16
+        )
+        @ops(op_db)
+        def test_op_foo(self, device, dtype, op):
+            ...
+
+        @decorateIf(
+            unittest.skip,
+            lambda params: params["module_info"].module_cls is torch.nn.Linear and \
+                params["device"] == "cpu"
+        )
+        @modules(module_db)
+        def test_module_foo(self, device, dtype, module_info):
+            ...
+
+    Args:
+        decorator: Test decorator to apply if the predicate is satisfied.
+        predicate_fn (Callable): Function taking in a dict of params and returning a boolean
+            indicating whether the decorator should be applied or not.
+    """
+    def __init__(self, decorator, predicate_fn):
+        self.decorator = decorator
+        self.predicate_fn = predicate_fn
+
+    def _parametrize_test(self, test, generic_cls, device_cls):
+
+        # Leave test as-is and return the appropriate decorator_fn.
+        def decorator_fn(params, decorator=self.decorator, predicate_fn=self.predicate_fn):
+            if predicate_fn(params):
+                return [decorator]
+            else:
+                return []
+
+        @wraps(test)
+        def test_wrapper(*args, **kwargs):
+            return test(*args, **kwargs)
+
+        test_name = ''
+        yield (test_wrapper, test_name, {}, decorator_fn)
+
+
+class ProfilingMode(Enum):
+    LEGACY = 1
+    SIMPLE = 2
+    PROFILING = 3
+
+def cppProfilingFlagsToProfilingMode():
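+    # The setters return the previous state, so read the current flags by setting a value
+    # and immediately restoring what was there before.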
+    old_prof_exec_state = torch._C._jit_set_profiling_executor(True)
+    old_prof_mode_state = torch._C._get_graph_executor_optimize(True)
+    torch._C._jit_set_profiling_executor(old_prof_exec_state)
+    torch._C._get_graph_executor_optimize(old_prof_mode_state)
+
+    if old_prof_exec_state:
+        if old_prof_mode_state:
+            return ProfilingMode.PROFILING
+        else:
+            return ProfilingMode.SIMPLE
+    else:
+        return ProfilingMode.LEGACY
+
+@contextmanager
+def enable_profiling_mode_for_profiling_tests():
+    old_prof_exec_state = False
+    old_prof_mode_state = False
+    if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
+        old_prof_exec_state = torch._C._jit_set_profiling_executor(True)
+        old_prof_mode_state = torch._C._get_graph_executor_optimize(True)
+    try:
+        yield
+    finally:
+        if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
+            torch._C._jit_set_profiling_executor(old_prof_exec_state)
+            torch._C._get_graph_executor_optimize(old_prof_mode_state)
+
+@contextmanager
+def enable_profiling_mode():
+    old_prof_exec_state = torch._C._jit_set_profiling_executor(True)
+    old_prof_mode_state = torch._C._get_graph_executor_optimize(True)
+    try:
+        yield
+    finally:
+        torch._C._jit_set_profiling_executor(old_prof_exec_state)
+        torch._C._get_graph_executor_optimize(old_prof_mode_state)
+
+@contextmanager
+def num_profiled_runs(num_runs):
+    old_num_runs = torch._C._jit_set_num_profiled_runs(num_runs)
+    try:
+        yield
+    finally:
+        torch._C._jit_set_num_profiled_runs(old_num_runs)
+
+func_call = torch._C.ScriptFunction.__call__
+meth_call = torch._C.ScriptMethod.__call__
+
+def prof_callable(callable, *args, **kwargs):
+    if 'profile_and_replay' in kwargs:
+        del kwargs['profile_and_replay']
+        if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
+            with enable_profiling_mode_for_profiling_tests():
+                callable(*args, **kwargs)
+                return callable(*args, **kwargs)
+
+    return callable(*args, **kwargs)
+
+def raise_on_run_directly(file_to_call):
+    raise RuntimeError("This test file is not meant to be run directly, "
+                       f"use:\n\n\tpython {file_to_call} TESTNAME\n\n"
+                       "instead.")
+
+def prof_func_call(*args, **kwargs):
+    return prof_callable(func_call, *args, **kwargs)
+
+def prof_meth_call(*args, **kwargs):
+    return prof_callable(meth_call, *args, **kwargs)
+
+torch._C.ScriptFunction.__call__ = prof_func_call  # type: ignore[method-assign]
+torch._C.ScriptMethod.__call__ = prof_meth_call  # type: ignore[method-assign]
+
+def _get_test_report_path():
+    # allow users to override the test file location. We need this
+    # because the distributed tests run the same test file multiple
+    # times with different configurations.
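+    # For example (hypothetical value), TEST_REPORT_SOURCE_OVERRIDE=dist-gloo yields
+    # 'test-reports/dist-gloo' instead of the default 'test-reports/python-unittest'.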
+    override = os.environ.get('TEST_REPORT_SOURCE_OVERRIDE')
+    test_source = override if override is not None else 'python-unittest'
+    return os.path.join('test-reports', test_source)
+
+is_running_via_run_test = "run_test.py" in getattr(__main__, "__file__", "")
+parser = argparse.ArgumentParser(add_help=not is_running_via_run_test, allow_abbrev=False)
+parser.add_argument('--subprocess', action='store_true',
+                    help='whether to run each test in a subprocess')
+parser.add_argument('--seed', type=int, default=1234)
+parser.add_argument('--accept', action='store_true')
+parser.add_argument('--jit-executor', '--jit_executor', type=str)
+parser.add_argument('--repeat', type=int, default=1)
+parser.add_argument('--test-bailouts', '--test_bailouts', action='store_true')
+parser.add_argument('--use-pytest', action='store_true')
+parser.add_argument('--save-xml', nargs='?', type=str,
+                    const=_get_test_report_path(),
+                    default=_get_test_report_path() if IS_CI else None)
+parser.add_argument('--discover-tests', action='store_true')
+parser.add_argument('--log-suffix', type=str, default="")
+parser.add_argument('--run-parallel', type=int, default=1)
+parser.add_argument('--import-slow-tests', type=str, nargs='?', const=DEFAULT_SLOW_TESTS_FILE)
+parser.add_argument('--import-disabled-tests', type=str, nargs='?', const=DEFAULT_DISABLED_TESTS_FILE)
+parser.add_argument('--rerun-disabled-tests', action='store_true')
+parser.add_argument('--pytest-single-test', type=str, nargs=1)
+parser.add_argument('--showlocals', action=argparse.BooleanOptionalAction, default=False)
+
+# Only run when -h or --help flag is active to display both unittest and parser help messages.
+def run_unittest_help(argv):
+    unittest.main(argv=argv)
+
+if '-h' in sys.argv or '--help' in sys.argv:
+    help_thread = threading.Thread(target=run_unittest_help, args=(sys.argv,))
+    help_thread.start()
+    help_thread.join()
+
+args, remaining = parser.parse_known_args()
+if args.jit_executor == 'legacy':
+    GRAPH_EXECUTOR = ProfilingMode.LEGACY
+elif args.jit_executor == 'profiling':
+    GRAPH_EXECUTOR = ProfilingMode.PROFILING
+elif args.jit_executor == 'simple':
+    GRAPH_EXECUTOR = ProfilingMode.SIMPLE
+else:
+    # infer flags based on the default settings
+    GRAPH_EXECUTOR = cppProfilingFlagsToProfilingMode()
+
+RERUN_DISABLED_TESTS = args.rerun_disabled_tests
+
+SLOW_TESTS_FILE = args.import_slow_tests
+DISABLED_TESTS_FILE = args.import_disabled_tests
+LOG_SUFFIX = args.log_suffix
+RUN_PARALLEL = args.run_parallel
+TEST_BAILOUTS = args.test_bailouts
+USE_PYTEST = args.use_pytest
+PYTEST_SINGLE_TEST = args.pytest_single_test
+TEST_DISCOVER = args.discover_tests
+TEST_IN_SUBPROCESS = args.subprocess
+TEST_SAVE_XML = args.save_xml
+REPEAT_COUNT = args.repeat
+SEED = args.seed
+SHOWLOCALS = args.showlocals
+if not getattr(expecttest, "ACCEPT", False):
+    expecttest.ACCEPT = args.accept
+UNITTEST_ARGS = [sys.argv[0]] + remaining
+torch.manual_seed(SEED)
+
+# CI prefix paths, used only in the CI environment
+CI_TEST_PREFIX = str(Path(os.getcwd()))
+CI_PT_ROOT = str(Path(os.getcwd()).parent)
+CI_FUNCTORCH_ROOT = str(os.path.join(Path(os.getcwd()).parent, "functorch"))
+
+def wait_for_process(p, timeout=None):
+    try:
+        return p.wait(timeout=timeout)
+    except KeyboardInterrupt:
+        # Give `p` a chance to handle KeyboardInterrupt. Without this,
+        # `pytest` can't print errors it collected so far upon KeyboardInterrupt.
+        exit_status = p.wait(timeout=5)
+        if exit_status is not None:
+            return exit_status
+        else:
+            p.kill()
+            raise
+    except subprocess.TimeoutExpired:
+        # send SIGINT to give pytest a chance to make xml
+        p.send_signal(signal.SIGINT)
+        exit_status = None
+        try:
+            exit_status = p.wait(timeout=5)
+        # also handle the case where p.wait(timeout=5) itself times out; otherwise
+        # the wait() call in the finally block can potentially hang
+        except subprocess.TimeoutExpired:
+            pass
+        if exit_status is not None:
+            return exit_status
+        else:
+            p.kill()
+        raise
+    except:  # noqa: B001,E722, copied from python core library
+        p.kill()
+        raise
+    finally:
+        # Always call p.wait() to ensure exit
+        p.wait()
+
+def shell(command, cwd=None, env=None, stdout=None, stderr=None, timeout=None):
+    sys.stdout.flush()
+    sys.stderr.flush()
+    # The following snippet is copied from the Py3 core library's subprocess.call,
+    # with the following changes:
+    #   1. an `except KeyboardInterrupt` block added for SIGINT handling;
+    #   2. in Py2, subprocess.Popen doesn't return a context manager, so we do
+    #      `p.wait()` in a `finally` block to keep the code portable.
+    #
+    # https://github.com/python/cpython/blob/71b6c1af727fbe13525fb734568057d78cea33f3/Lib/subprocess.py#L309-L323
+    assert not isinstance(command, str), "Command to shell should be a list or tuple of tokens"
+    p = subprocess.Popen(command, universal_newlines=True, cwd=cwd, env=env, stdout=stdout, stderr=stderr)
+    return wait_for_process(p, timeout=timeout)
+
+
+def retry_shell(
+    command,
+    cwd=None,
+    env=None,
+    stdout=None,
+    stderr=None,
+    timeout=None,
+    retries=1,
+    was_rerun=False,
+) -> tuple[int, bool]:
+    # Returns the exit code and whether the command was rerun
+    assert (
+        retries >= 0
+    ), f"Expecting non negative number for number of retries, got {retries}"
+    try:
+        exit_code = shell(
+            command, cwd=cwd, env=env, stdout=stdout, stderr=stderr, timeout=timeout
+        )
+        if exit_code == 0 or retries == 0:
+            return exit_code, was_rerun
+        print(
+            f"Got exit code {exit_code}, retrying (retries left={retries})",
+            file=stdout,
+            flush=True,
+        )
+    except subprocess.TimeoutExpired:
+        if retries == 0:
+            print(
+                f"Command took >{timeout // 60}min, returning 124",
+                file=stdout,
+                flush=True,
+            )
+            return 124, was_rerun
+        print(
+            f"Command took >{timeout // 60}min, retrying (retries left={retries})",
+            file=stdout,
+            flush=True,
+        )
+    return retry_shell(
+        command,
+        cwd=cwd,
+        env=env,
+        stdout=stdout,
+        stderr=stderr,
+        timeout=timeout,
+        retries=retries - 1,
+        was_rerun=True,
+    )
+
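+# Usage sketch (hypothetical command): retry a flaky invocation once with a timeout and
+# inspect both the final exit code and whether a rerun happened:
+#
+#     exit_code, was_rerun = retry_shell(
+#         [sys.executable, "test_foo.py", "TestFoo.test_bar"], timeout=15 * 60, retries=1)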
+
+def discover_test_cases_recursively(suite_or_case):
+    if isinstance(suite_or_case, unittest.TestCase):
+        return [suite_or_case]
+    rc = []
+    for element in suite_or_case:
+        print(element)
+        rc.extend(discover_test_cases_recursively(element))
+    return rc
+
+def get_test_names(test_cases):
+    return ['.'.join(case.id().split('.')[-2:]) for case in test_cases]
+
+def _print_test_names():
+    suite = unittest.TestLoader().loadTestsFromModule(__main__)
+    test_cases = discover_test_cases_recursively(suite)
+    for name in get_test_names(test_cases):
+        print(name)
+
+def chunk_list(lst, nchunks):
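+    # round-robin split, e.g. chunk_list([0, 1, 2, 3, 4], 2) -> [[0, 2, 4], [1, 3]]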
+    return [lst[i::nchunks] for i in range(nchunks)]
+
+# sanitize filename e.g., distributed/pipeline/sync/skip/test_api.py -> distributed.pipeline.sync.skip.test_api
+def sanitize_test_filename(filename):
+    # inspect.getfile returns absolute path in some CI jobs, converting it to relative path if needed
+    if filename.startswith(CI_TEST_PREFIX):
+        filename = filename[len(CI_TEST_PREFIX) + 1:]
+    strip_py = re.sub(r'.py$', '', filename)
+    return re.sub('/', r'.', strip_py)
+
+def lint_test_case_extension(suite):
+    succeed = True
+    for test_case_or_suite in suite:
+        test_case = test_case_or_suite
+        if isinstance(test_case_or_suite, unittest.TestSuite):
+            first_test = test_case_or_suite._tests[0] if len(test_case_or_suite._tests) > 0 else None
+            if first_test is not None and isinstance(first_test, unittest.TestSuite):
+                return succeed and lint_test_case_extension(test_case_or_suite)
+            test_case = first_test
+
+        if test_case is not None:
+            if not isinstance(test_case, TestCase):
+                test_class = test_case.id().split('.', 1)[1].split('.')[0]
+                err = "This test class should extend from torch.testing._internal.common_utils.TestCase but it doesn't."
+                print(f"{test_class} - failed. {err}")
+                succeed = False
+    return succeed
+
+
+def get_report_path(argv=UNITTEST_ARGS, pytest=False):
+    test_filename = sanitize_test_filename(argv[0])
+    test_report_path = TEST_SAVE_XML + LOG_SUFFIX
+    test_report_path = os.path.join(test_report_path, test_filename)
+    if pytest:
+        test_report_path = test_report_path.replace('python-unittest', 'python-pytest')
+        os.makedirs(test_report_path, exist_ok=True)
+        test_report_path = os.path.join(test_report_path, f"{test_filename}-{os.urandom(8).hex()}.xml")
+        return test_report_path
+    os.makedirs(test_report_path, exist_ok=True)
+    return test_report_path
+
+
+def sanitize_pytest_xml(xml_file: str):
+    # pytest xml is different from unittest xml; this function makes pytest xml more similar to unittest xml
+    # consider somehow modifying the XML logger in conftest to do this instead
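+    # e.g. a classname of "test.test_foo.TestFoo" (hypothetical) becomes classname="TestFoo"
+    # with file="test_foo.py"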
+    import xml.etree.ElementTree as ET
+    tree = ET.parse(xml_file)
+    for testcase in tree.iter('testcase'):
+        full_classname = testcase.attrib.get("classname")
+        if full_classname is None:
+            continue
+        # The test prefix is optional
+        regex_result = re.search(r"^(test\.)?(?P<file>.*)\.(?P<classname>[^\.]*)$", full_classname)
+        if regex_result is None:
+            continue
+        classname = regex_result.group("classname")
+        file = regex_result.group("file").replace(".", "/")
+        testcase.set("classname", classname)
+        testcase.set("file", f"{file}.py")
+    tree.write(xml_file)
+
+
+def get_pytest_test_cases(argv: list[str]) -> list[str]:
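+    # Collect pytest node ids without running the tests, by invoking pytest with
+    # --collect-only and a plugin that records session.items on collection finish.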
+    class TestCollectorPlugin:
+        def __init__(self) -> None:
+            self.tests: list[Any] = []
+
+        def pytest_collection_finish(self, session):
+            for item in session.items:
+                self.tests.append(session.config.cwd_relative_nodeid(item.nodeid))
+
+    test_collector_plugin = TestCollectorPlugin()
+    import pytest
+    pytest.main(
+        [arg for arg in argv if arg != '-vv'] + ['--collect-only', '-qq', '--use-main-module'],
+        plugins=[test_collector_plugin]
+    )
+    return test_collector_plugin.tests
+
+
+def run_tests(argv=UNITTEST_ARGS):
+    # import test files.
+    if SLOW_TESTS_FILE:
+        if os.path.exists(SLOW_TESTS_FILE):
+            with open(SLOW_TESTS_FILE) as fp:
+                global slow_tests_dict
+                slow_tests_dict = json.load(fp)
+                # use env vars so pytest-xdist subprocesses can still access them
+                os.environ['SLOW_TESTS_FILE'] = SLOW_TESTS_FILE
+        else:
+            warnings.warn(f'slow test file provided but not found: {SLOW_TESTS_FILE}')
+    if DISABLED_TESTS_FILE:
+        if os.path.exists(DISABLED_TESTS_FILE):
+            with open(DISABLED_TESTS_FILE) as fp:
+                global disabled_tests_dict
+                disabled_tests_dict = json.load(fp)
+                os.environ['DISABLED_TESTS_FILE'] = DISABLED_TESTS_FILE
+        else:
+            warnings.warn(f'disabled test file provided but not found: {DISABLED_TESTS_FILE}')
+    # Determine the test launch mechanism
+    if TEST_DISCOVER:
+        _print_test_names()
+        return
+
+    # Before running the tests, lint to check that every test class extends from TestCase
+    suite = unittest.TestLoader().loadTestsFromModule(__main__)
+    if not lint_test_case_extension(suite):
+        sys.exit(1)
+
+    if SHOWLOCALS:
+        argv = [
+            argv[0],
+            *(["--showlocals", "--tb=long", "--color=yes"] if USE_PYTEST else ["--locals"]),
+            *argv[1:],
+        ]
+
+    if TEST_IN_SUBPROCESS:
+        other_args = []
+        if DISABLED_TESTS_FILE:
+            other_args.append("--import-disabled-tests")
+        if SLOW_TESTS_FILE:
+            other_args.append("--import-slow-tests")
+        if USE_PYTEST:
+            other_args.append("--use-pytest")
+        if RERUN_DISABLED_TESTS:
+            other_args.append("--rerun-disabled-tests")
+        if TEST_SAVE_XML:
+            other_args += ['--save-xml', TEST_SAVE_XML]
+
+        test_cases = (
+            get_pytest_test_cases(argv) if USE_PYTEST else
+            [case.id().split('.', 1)[1] for case in discover_test_cases_recursively(suite)]
+        )
+
+        failed_tests = []
+
+        for test_case_full_name in test_cases:
+
+            cmd = (
+                [sys.executable] + [argv[0]] + other_args + argv[1:] +
+                (["--pytest-single-test"] if USE_PYTEST else []) +
+                [test_case_full_name]
+            )
+            string_cmd = " ".join(cmd)
+
+            timeout = None if RERUN_DISABLED_TESTS else 15 * 60
+
+            exitcode, _ = retry_shell(cmd, timeout=timeout, retries=0 if RERUN_DISABLED_TESTS else 1)
+
+            if exitcode != 0:
+                # This is sort of hacky, but add on relevant env variables for distributed tests.
+                if 'TestDistBackendWithSpawn' in test_case_full_name:
+                    backend = os.environ.get("BACKEND", "")
+                    world_size = os.environ.get("WORLD_SIZE", "")
+                    env_prefix = f"BACKEND={backend} WORLD_SIZE={world_size}"
+                    string_cmd = env_prefix + " " + string_cmd
+                # Log the command to reproduce the failure.
+                print(f"Test exited with non-zero exitcode {exitcode}. Command to reproduce: {string_cmd}")
+                failed_tests.append(test_case_full_name)
+
+            assert len(failed_tests) == 0, "{} unit test(s) failed:\n\t{}".format(
+                len(failed_tests), '\n\t'.join(failed_tests))
+
+    elif RUN_PARALLEL > 1:
+        test_cases = discover_test_cases_recursively(suite)
+        test_batches = chunk_list(get_test_names(test_cases), RUN_PARALLEL)
+        processes = []
+        for i in range(RUN_PARALLEL):
+            command = [sys.executable] + argv + [f'--log-suffix=-shard-{i + 1}'] + test_batches[i]
+            processes.append(subprocess.Popen(command, universal_newlines=True))
+        failed = False
+        for p in processes:
+            failed |= wait_for_process(p) != 0
+        assert not failed, "Some test shards have failed"
+    elif USE_PYTEST:
+        pytest_args = argv + ["--use-main-module"]
+        test_report_path = ""
+        if TEST_SAVE_XML:
+            test_report_path = get_report_path(pytest=True)
+            print(f'Test results will be stored in {test_report_path}')
+            pytest_args.append(f'--junit-xml-reruns={test_report_path}')
+        if PYTEST_SINGLE_TEST:
+            pytest_args = PYTEST_SINGLE_TEST + pytest_args[1:]
+
+        import pytest
+        os.environ["NO_COLOR"] = "1"
+        exit_code = pytest.main(args=pytest_args)
+        if TEST_SAVE_XML:
+            sanitize_pytest_xml(test_report_path)
+
+        # exitcode of 5 means no tests were found, which happens since some test configs don't
+        # run tests from certain files
+        sys.exit(0 if exit_code == 5 else exit_code)
+    elif TEST_SAVE_XML:
+        # import here so that non-CI doesn't need xmlrunner installed
+        import xmlrunner  # type: ignore[import]
+        from xmlrunner.result import _XMLTestResult  # type: ignore[import]
+
+        class XMLTestResultVerbose(_XMLTestResult):
+            """
+            Adding verbosity to test outputs:
+            by default test summary prints 'skip',
+            but we want to also print the skip reason.
+            GH issue: https://github.com/pytorch/pytorch/issues/69014
+
+            This works with unittest_xml_reporting<=3.2.0,>=2.0.0
+            (3.2.0 is latest at the moment)
+            """
+            def __init__(self, *args, **kwargs):
+                super().__init__(*args, **kwargs)
+
+            def addSkip(self, test, reason):
+                super().addSkip(test, reason)
+                for c in self.callback.__closure__:
+                    if isinstance(c.cell_contents, str) and c.cell_contents == 'skip':
+                        # this message is printed in test summary;
+                        # it stands for `verbose_str` captured in the closure
+                        c.cell_contents = f"skip: {reason}"
+
+            def printErrors(self) -> None:
+                super().printErrors()
+                self.printErrorList("XPASS", self.unexpectedSuccesses)
+        test_report_path = get_report_path()
+        verbose = '--verbose' in argv or '-v' in argv
+        if verbose:
+            print(f'Test results will be stored in {test_report_path}')
+        unittest.main(argv=argv, testRunner=xmlrunner.XMLTestRunner(
+            output=test_report_path,
+            verbosity=2 if verbose else 1,
+            resultclass=XMLTestResultVerbose))
+    elif REPEAT_COUNT > 1:
+        for _ in range(REPEAT_COUNT):
+            if not unittest.main(exit=False, argv=argv).result.wasSuccessful():
+                sys.exit(-1)
+    else:
+        unittest.main(argv=argv)
+
+IS_LINUX = sys.platform == "linux"
+IS_WINDOWS = sys.platform == "win32"
+IS_MACOS = sys.platform == "darwin"
+IS_PPC = platform.machine() == "ppc64le"
+IS_X86 = platform.machine() in ('x86_64', 'i386')
+IS_ARM64 = platform.machine() in ('arm64', 'aarch64')
+IS_S390X = platform.machine() == "s390x"
+
+def is_avx512_vnni_supported():
+    if sys.platform != 'linux':
+        return False
+    with open("/proc/cpuinfo", encoding="ascii") as f:
+        lines = f.read()
+    return "vnni" in lines
+
+IS_AVX512_VNNI_SUPPORTED = is_avx512_vnni_supported()
+
+if IS_WINDOWS:
+    @contextmanager
+    def TemporaryFileName(*args, **kwargs):
+        # Ideally we would like to not have to manually delete the file, but NamedTemporaryFile
+        # opens the file, and it cannot be opened multiple times in Windows. To support Windows,
+        # close the file after creation and try to remove it manually
+        if 'delete' in kwargs:
+            if kwargs['delete'] is not False:
+                raise UserWarning("only TemporaryFileName with delete=False is supported on Windows.")
+        else:
+            kwargs['delete'] = False
+        f = tempfile.NamedTemporaryFile(*args, **kwargs)
+        try:
+            f.close()
+            yield f.name
+        finally:
+            os.unlink(f.name)
+else:
+    @contextmanager  # noqa: T484
+    def TemporaryFileName(*args, **kwargs):
+        with tempfile.NamedTemporaryFile(*args, **kwargs) as f:
+            yield f.name
+
+if IS_WINDOWS:
+    @contextmanager
+    def TemporaryDirectoryName(suffix=None):
+        # On Windows the directory created by TemporaryDirectory is likely to be removed prematurely,
+        # so we first create the directory using mkdtemp and then remove it manually
+        try:
+            dir_name = tempfile.mkdtemp(suffix=suffix)
+            yield dir_name
+        finally:
+            shutil.rmtree(dir_name)
+else:
+    @contextmanager  # noqa: T484
+    def TemporaryDirectoryName(suffix=None):
+        with tempfile.TemporaryDirectory(suffix=suffix) as d:
+            yield d
+
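+# Usage sketch: both helpers are platform-portable context managers that yield a path string:
+#
+#     with TemporaryFileName(suffix=".pt") as fname:
+#         torch.save(torch.ones(2), fname)
+#     with TemporaryDirectoryName() as dirname:
+#         ...  # write files under dirname; the directory is removed on exit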
+
+def is_privateuse1_backend_available():
+    privateuse1_backend_name = torch._C._get_privateuse1_backend_name()
+    privateuse1_backend_module = getattr(torch, privateuse1_backend_name, None)
+    return (is_available := getattr(privateuse1_backend_module, "is_available", None)) and is_available()
+
+
+IS_FILESYSTEM_UTF8_ENCODING = sys.getfilesystemencoding() == 'utf-8'
+
+TEST_NUMPY = _check_module_exists('numpy')
+TEST_FAIRSEQ = _check_module_exists('fairseq')
+TEST_SCIPY = _check_module_exists('scipy')
+TEST_MKL = torch.backends.mkl.is_available()
+TEST_ACL = torch.backends.mkldnn.is_available() and torch.ops.mkldnn._is_mkldnn_acl_supported()
+TEST_MPS = torch.backends.mps.is_available()
+MACOS_VERSION = float('.'.join(platform.mac_ver()[0].split('.')[:2]) or -1)
+TEST_XPU = torch.xpu.is_available()
+TEST_HPU = hasattr(torch, "hpu") and torch.hpu.is_available()
+TEST_CUDA = torch.cuda.is_available()
+custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name(), None)
+TEST_PRIVATEUSE1 = is_privateuse1_backend_available()
+TEST_PRIVATEUSE1_DEVICE_TYPE = torch._C._get_privateuse1_backend_name()
+TEST_NUMBA = _check_module_exists('numba')
+TEST_TRANSFORMERS = _check_module_exists('transformers')
+TEST_DILL = _check_module_exists('dill')
+
+TEST_LIBROSA = _check_module_exists('librosa') and not IS_ARM64
+
+TEST_OPT_EINSUM = _check_module_exists('opt_einsum')
+
+TEST_Z3 = _check_module_exists('z3')
+
+def split_if_not_empty(x: str):
+    return x.split(",") if len(x) != 0 else []
+
+NOTEST_CPU = "cpu" in split_if_not_empty(os.getenv('PYTORCH_TESTING_DEVICE_EXCEPT_FOR', ''))
+
+skipIfNoDill = unittest.skipIf(not TEST_DILL, "no dill")
+
+
+NO_MULTIPROCESSING_SPAWN: bool = False
+TEST_WITH_ASAN: bool = TestEnvironment.def_flag(
+    "TEST_WITH_ASAN",
+    env_var="PYTORCH_TEST_WITH_ASAN",
+)
+TEST_WITH_DEV_DBG_ASAN: bool = TestEnvironment.def_flag(
+    "TEST_WITH_DEV_DBG_ASAN",
+    env_var="PYTORCH_TEST_WITH_DEV_DBG_ASAN",
+)
+TEST_WITH_TSAN: bool = TestEnvironment.def_flag(
+    "TEST_WITH_TSAN",
+    env_var="PYTORCH_TEST_WITH_TSAN",
+)
+TEST_WITH_UBSAN: bool = TestEnvironment.def_flag(
+    "TEST_WITH_UBSAN",
+    env_var="PYTORCH_TEST_WITH_UBSAN",
+)
+TEST_WITH_ROCM: bool = TestEnvironment.def_flag(
+    "TEST_WITH_ROCM",
+    env_var="PYTORCH_TEST_WITH_ROCM",
+)
+
+# TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen
+# See #64427
+TEST_WITH_MIOPEN_SUGGEST_NHWC = os.getenv('PYTORCH_MIOPEN_SUGGEST_NHWC', '0') == '1'
+# Enables tests that are slow to run (disabled by default)
+TEST_WITH_SLOW: bool = TestEnvironment.def_flag(
+    "TEST_WITH_SLOW",
+    env_var="PYTORCH_TEST_WITH_SLOW",
+)
+
+# Disables non-slow tests (these tests are enabled by default)
+# This is usually used in conjunction with TEST_WITH_SLOW to
+# run *only* slow tests.  (I could have done an enum, but
+# it felt a little awkward.)
+TEST_SKIP_FAST: bool = TestEnvironment.def_flag(
+    "TEST_SKIP_FAST",
+    env_var="PYTORCH_TEST_SKIP_FAST",
+)
+
+# Enables crossref tests, in addition to standard tests which
+# are being run.  crossref tests work by installing a torch
+# function mode that runs extra compute alongside the regular
+# computation that happens with the test.  After both computations
+# are done, we cross-reference them (thus the name) to check for
+# correctness, before throwing out the extra compute and proceeding
+# as we had before.  By default, we don't run these tests.
+TEST_WITH_CROSSREF: bool = TestEnvironment.def_flag(
+    "TEST_WITH_CROSSREF",
+    env_var="PYTORCH_TEST_WITH_CROSSREF",
+)
+
+TEST_SKIP_CUDAGRAPH: bool = TestEnvironment.def_flag(
+    "TEST_SKIP_CUDAGRAPH",
+    env_var="PYTORCH_TEST_SKIP_CUDAGRAPH",
+)
+TEST_CUDA_GRAPH = TEST_CUDA and (not TEST_SKIP_CUDAGRAPH) and (
+    torch.version.cuda or
+    (torch.version.hip and float(".".join(torch.version.hip.split(".")[0:2])) >= 5.3)
+)
+
+TEST_CUDA_CUDSS = TEST_CUDA and (torch.version.cuda and int(torch.version.cuda.split(".")[0]) >= 12)
+
+TEST_CUDA_PYTHON_BINDINGS = _check_module_exists("cuda.bindings") and (
+    torch.version.cuda and int(torch.version.cuda.split(".")[0]) >= 12
+)
+
+if TEST_CUDA_PYTHON_BINDINGS:
+    def cuda_python_error_check(function_call_output):
+        """Makes calls to cuda-python's cuda runtime functions more
+        pythonic by throwing an exception if they return a status
+        which is not cudaSuccess
+        """
+        import cuda.bindings  # type: ignore[import]
+
+        error, *others = function_call_output
+        if error != cuda.bindings.runtime.cudaError_t.cudaSuccess:
+            raise ValueError(f"CUDA failure! {error}")
+        else:
+            return tuple(others)
+else:
+    cuda_python_error_check = None  # type: ignore[assignment]
+
+def allocator_option_enabled_fn(allocator_config, _, option):
+    if allocator_config is None:
+        return False
+    allocator_config = allocator_config.split(',') if ',' in allocator_config else [allocator_config]
+    mapping = dict([var.split(':') for var in allocator_config])
+
+    if option in mapping and mapping[option] == 'True':
+        return True
+    else:
+        return False
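+
+# For example (hypothetical setting), PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True,max_split_size_mb:128"
+# makes allocator_option_enabled_fn(..., option='expandable_segments') return True, which in turn
+# enables the EXPANDABLE_SEGMENTS flag below.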
+
+EXPANDABLE_SEGMENTS: bool = TestEnvironment.def_flag(
+    "EXPANDABLE_SEGMENTS",
+    env_var="PYTORCH_CUDA_ALLOC_CONF",
+    enabled_fn=functools.partial(allocator_option_enabled_fn, option='expandable_segments'),
+)
+
+if TEST_CUDA and 'NUM_PARALLEL_PROCS' in os.environ:
+    num_procs = int(os.getenv("NUM_PARALLEL_PROCS", "2"))
+    gb_available = torch.cuda.mem_get_info()[1] / 2 ** 30
+    # other libraries take up a little under 1 GB of space per process
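+    # e.g. with 16 GB free and NUM_PARALLEL_PROCS=2 (hypothetical values):
+    # round((16 - 2 * 0.85) / 16 / 2, 2) == 0.45, i.e. each process may use ~45% of device memory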
+    torch.cuda.set_per_process_memory_fraction(round((gb_available - num_procs * .85) / gb_available / num_procs, 2))
+
+requires_cuda = unittest.skipUnless(torch.cuda.is_available(), "Requires CUDA")
+
+def skipIfCrossRef(fn):
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        if TEST_WITH_CROSSREF:
+            raise unittest.SkipTest("test doesn't currently work with crossref")
+        else:
+            fn(*args, **kwargs)
+    return wrapper
+
+class CrossRefMode(torch.overrides.TorchFunctionMode):
+    def __torch_function__(self, func, types, args=(), kwargs=None):
+        kwargs = kwargs or {}
+        r = func(*args, **kwargs)
+        return r
+
+# Run PyTorch tests with TorchDynamo
+TEST_WITH_TORCHINDUCTOR: bool = TestEnvironment.def_flag(
+    "TEST_WITH_TORCHINDUCTOR",
+    env_var="PYTORCH_TEST_WITH_INDUCTOR",
+)
+# AOT_EAGER not tested in ci, useful for debugging
+TEST_WITH_AOT_EAGER: bool = TestEnvironment.def_flag(
+    "TEST_WITH_AOT_EAGER",
+    env_var="PYTORCH_TEST_WITH_AOT_EAGER",
+)
+TEST_WITH_TORCHDYNAMO: bool = TestEnvironment.def_flag(
+    "TEST_WITH_TORCHDYNAMO",
+    env_var="PYTORCH_TEST_WITH_DYNAMO",
+    implied_by_fn=lambda: TEST_WITH_TORCHINDUCTOR or TEST_WITH_AOT_EAGER,
+)
+TEST_WITHOUT_COMPILED_AUTOGRAD: bool = TestEnvironment.def_flag(
+    "TEST_WITHOUT_COMPILED_AUTOGRAD",
+    env_var="PYTORCH_TEST_WITHOUT_COMPILED_AUTOGRAD",
+)
+
+if TEST_WITH_TORCHDYNAMO:
+    import torch._dynamo
+    # Do not spend time on helper functions that are called with different inputs
+    torch._dynamo.config.accumulated_recompile_limit = 64
+    # Do not log compilation metrics from unit tests
+    torch._dynamo.config.log_compilation_metrics = False
+    # Silence 3.13.0 guard performance warnings
+    torch._dynamo.config.issue_3_13_0_warning = False
+    if TEST_WITH_TORCHINDUCTOR:
+        import torch._inductor.config
+        torch._inductor.config.fallback_random = True
+    else:
+        # only dynamo for now
+        torch._dynamo.config.compiled_autograd = not TEST_WITHOUT_COMPILED_AUTOGRAD
+
+
+# seems like this is only used in test/torch_np
+def xpassIfTorchDynamo_np(func):
+    # numpy 2.0+ is causing issues
+    if TEST_WITH_TORCHDYNAMO and np.__version__[0] == '2':
+        return unittest.skip("skipping numpy 2.0+ dynamo-wrapped test")(func)
+    return func if TEST_WITH_TORCHDYNAMO else unittest.expectedFailure(func)
+
+
+def xfailIfACL(func):
+    return unittest.expectedFailure(func) if TEST_ACL else func
+
+
+def xfailIfTorchDynamo(func):
+    return unittest.expectedFailure(func) if TEST_WITH_TORCHDYNAMO else func
+
+
+def xfailIfPy312Plus(func):
+    return unittest.expectedFailure(func) if sys.version_info >= (3, 12) else func
+
+
+def xfailIfLinux(func):
+    return unittest.expectedFailure(func) if IS_LINUX and not TEST_WITH_ROCM and not IS_FBCODE else func
+
+
+def skipIfTorchDynamo(msg="test doesn't currently work with dynamo"):
+    """
+    Usage:
+    @skipIfTorchDynamo(msg)
+    def test_blah(self):
+        ...
+    """
+    assert isinstance(msg, str), "Are you using skipIfTorchDynamo correctly?"
+
+    def decorator(fn):
+        if not isinstance(fn, type):
+            @wraps(fn)
+            def wrapper(*args, **kwargs):
+                if TEST_WITH_TORCHDYNAMO:
+                    raise unittest.SkipTest(msg)
+                else:
+                    fn(*args, **kwargs)
+            return wrapper
+
+        assert isinstance(fn, type)
+        if TEST_WITH_TORCHDYNAMO:
+            fn.__unittest_skip__ = True  # type: ignore[attr-defined]
+            fn.__unittest_skip_why__ = msg  # type: ignore[attr-defined]
+
+        return fn
+
+    return decorator
+
+def skipIfTorchInductor(msg="test doesn't currently work with torchinductor",
+                        condition=TEST_WITH_TORCHINDUCTOR):
+    def decorator(fn):
+        if not isinstance(fn, type):
+            @wraps(fn)
+            def wrapper(*args, **kwargs):
+                if condition:
+                    raise unittest.SkipTest(msg)
+                else:
+                    fn(*args, **kwargs)
+            return wrapper
+
+        assert isinstance(fn, type)
+        if condition:
+            fn.__unittest_skip__ = True  # type: ignore[attr-defined]
+            fn.__unittest_skip_why__ = msg  # type: ignore[attr-defined]
+
+        return fn
+
+    return decorator
+
+def runWithoutCompiledAutograd(msg="test doesn't currently work with compiled autograd"):
+    """
+    Usage:
+    @runWithoutCompiledAutograd(msg)
+    def test_blah(self):
+        ...
+    """
+    assert isinstance(msg, str)
+
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            with torch._dynamo.compiled_autograd._disable():
+                func(*args, **kwargs)
+        return wrapper
+
+    return decorator
+
+def serialTest(condition=True):
+    """
+    Decorator for running tests serially.  Requires pytest
+    """
+    def decorator(fn):
+        if has_pytest and condition:
+            return pytest.mark.serial(fn)
+        return fn
+    return decorator
+
+def unMarkDynamoStrictTest(cls=None):
+    def decorator(cls):
+        cls.dynamo_strict = False
+        return cls
+
+    if cls is None:
+        return decorator
+    else:
+        return decorator(cls)
+
+
+def markDynamoStrictTest(cls_or_func=None, nopython=False):
+    """
+    Marks the test as 'strict'. In strict mode, we reset before and after the
+    test, and run without suppressing errors.
+
+    Args:
+    - nopython: if we should run torch._dynamo.optimize with nopython={True/False}.
+    """
+    def decorator(cls_or_func):
+        if inspect.isclass(cls_or_func):
+            cls_or_func.dynamo_strict = True
+            cls_or_func.dynamo_strict_nopython = nopython
+            return cls_or_func
+
+        fn = cls_or_func
+
+        @wraps(fn)
+        def wrapper(*args, **kwargs):
+            torch._dynamo.reset()
+            with unittest.mock.patch("torch._dynamo.config.suppress_errors", False):
+                fn(*args, **kwargs)
+            torch._dynamo.reset()
+        return wrapper
+
+    if cls_or_func is None:
+        return decorator
+    else:
+        return decorator(cls_or_func)
+
+
+def skipRocmIfTorchInductor(msg="test doesn't currently work with torchinductor on the ROCm stack"):
+    return skipIfTorchInductor(msg=msg, condition=TEST_WITH_ROCM and TEST_WITH_TORCHINDUCTOR)
+
+def skipIfLegacyJitExecutor(msg="test doesn't currently work with legacy JIT executor"):
+    def decorator(fn):
+        if not isinstance(fn, type):
+            @wraps(fn)
+            def wrapper(*args, **kwargs):
+                if GRAPH_EXECUTOR == ProfilingMode.LEGACY:
+                    raise unittest.SkipTest(msg)
+                else:
+                    fn(*args, **kwargs)
+            return wrapper
+
+        assert isinstance(fn, type)
+        if GRAPH_EXECUTOR == ProfilingMode.LEGACY:
+            fn.__unittest_skip__ = True  # type: ignore[attr-defined]
+            fn.__unittest_skip_why__ = msg  # type: ignore[attr-defined]
+
+        return fn
+
+
+    return decorator
+
+
+def make_dynamo_test(
+    fn: Optional[Callable[..., Any]] = None
+) -> Callable[..., Any]:
+    """
+    Decorator function to create a dynamo test case. A function annotated with
+    this decorator takes a unittest object as input.
+    """
+    from torch._dynamo.testing import CompileCounter, reset, optimize_assert
+    if fn is None:
+        return lambda fn: make_dynamo_test(fn)
+
+    def standard_test(
+        self: Any,
+        fn: Callable[..., Any],
+        kwargs,
+    ) -> None:
+        def dummy() -> None:
+            fn(self, **kwargs)
+
+        actual = CompileCounter()
+
+        dummy()
+        reset()
+        opt_fn = optimize_assert(actual)(dummy)
+        opt_fn()
+        reset()
+
+    @functools.wraps(fn)
+    def test_fn(self: Any, **kwargs) -> None:
+        return standard_test(
+            self,
+            fn=fn,
+            kwargs=kwargs,
+        )
+
+    return test_fn
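+
+# Illustrative usage (comment only, hypothetical test): the decorated body is run
+# once eagerly and once under torch._dynamo.testing.optimize_assert.
+#
+#   class MyDynamoTests(TestCase):
+#       @make_dynamo_test
+#       def test_simple(self):
+#           self.assertEqual(torch.ones(2).sum().item(), 2.0)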
+
+
+# Run PyTorch tests with translation validation on.
+TEST_WITH_TV = os.getenv('PYTORCH_TEST_WITH_TV') == '1'
+
+if TEST_WITH_TV:
+    torch.fx.experimental._config.translation_validation = True
+
+# Determine whether to enable cuda memory leak check.
+# CUDA mem leak check is expensive and thus we don't want to execute it on every
+# test case / configuration.
+# If this is True then CUDA memory leak checks are performed (for tests that opt
+#   in); if this is False then they are skipped.
+# See: https://github.com/pytorch/pytorch/pull/59402#issuecomment-858811135
+TEST_CUDA_MEM_LEAK_CHECK: bool = TestEnvironment.def_flag(
+    "TEST_CUDA_MEM_LEAK_CHECK",
+    env_var="PYTORCH_TEST_CUDA_MEM_LEAK_CHECK",
+)
+
+
+# Dict of NumPy dtype -> torch dtype (when the correspondence exists)
+numpy_to_torch_dtype_dict = {
+    np.bool_      : torch.bool,
+    np.uint8      : torch.uint8,
+    np.uint16     : torch.uint16,
+    np.uint32     : torch.uint32,
+    np.uint64     : torch.uint64,
+    np.int8       : torch.int8,
+    np.int16      : torch.int16,
+    np.int32      : torch.int32,
+    np.int64      : torch.int64,
+    np.float16    : torch.float16,
+    np.float32    : torch.float32,
+    np.float64    : torch.float64,
+    np.complex64  : torch.complex64,
+    np.complex128 : torch.complex128
+}
+
+
+# numpy dtypes like np.float64 are not instances, but rather classes. This leads to rather absurd cases like
+# np.float64 != np.dtype("float64") but np.float64 == np.dtype("float64").type.
+# Especially when checking against a reference we can't be sure which variant we get, so we simply try both.
+def numpy_to_torch_dtype(np_dtype):
+    try:
+        return numpy_to_torch_dtype_dict[np_dtype]
+    except KeyError:
+        return numpy_to_torch_dtype_dict[np_dtype.type]
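+
+# For example, both spellings of a NumPy dtype resolve to the same torch dtype
+# (the second via the direct lookup or the ``.type`` fallback, depending on the
+# NumPy version):
+#
+#   numpy_to_torch_dtype(np.float64)            # torch.float64
+#   numpy_to_torch_dtype(np.dtype("float64"))   # torch.float64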
+
+
+def has_corresponding_torch_dtype(np_dtype):
+    try:
+        numpy_to_torch_dtype(np_dtype)
+        return True
+    except KeyError:
+        return False
+
+
+if IS_WINDOWS:
+    # Size of `np.intc` is platform defined.
+    # It is returned by functions like `bitwise_not`.
+    # On Windows `int` is 32-bit
+    # https://docs.microsoft.com/en-us/cpp/cpp/data-type-ranges?view=msvc-160
+    numpy_to_torch_dtype_dict[np.intc] = torch.int
+
+# Dict of torch dtype -> NumPy dtype
+torch_to_numpy_dtype_dict = {value : key for (key, value) in numpy_to_torch_dtype_dict.items()}
+torch_to_numpy_dtype_dict.update({
+    torch.bfloat16: np.float32,
+    torch.complex32: np.complex64
+})
+
+def skipIfNNModuleInlined(
+    msg="test doesn't currently work with nn module inlining",
+    condition=torch._dynamo.config.inline_inbuilt_nn_modules,
+):
+    def decorator(fn):
+        if not isinstance(fn, type):
+
+            @wraps(fn)
+            def wrapper(*args, **kwargs):
+                if condition:
+                    raise unittest.SkipTest(msg)
+                else:
+                    fn(*args, **kwargs)
+
+            return wrapper
+
+        assert isinstance(fn, type)
+        if condition:
+            fn.__unittest_skip__ = True  # type: ignore[attr-defined]
+            fn.__unittest_skip_why__ = msg  # type: ignore[attr-defined]
+
+        return fn
+
+    return decorator
+
+def skipIfRocm(func=None, *, msg="test doesn't currently work on the ROCm stack"):
+    def dec_fn(fn):
+        reason = f"skipIfRocm: {msg}"
+
+        @wraps(fn)
+        def wrapper(*args, **kwargs):
+            if TEST_WITH_ROCM:
+                raise unittest.SkipTest(reason)
+            else:
+                return fn(*args, **kwargs)
+        return wrapper
+    if func:
+        return dec_fn(func)
+    return dec_fn
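+
+# Illustrative usage (comment only, hypothetical tests): skipIfRocm works both as a
+# bare decorator and as a decorator factory with a custom message.
+#
+#   @skipIfRocm
+#   def test_cudnn_specific(self):
+#       ...
+#
+#   @skipIfRocm(msg="not supported by MIOpen yet")
+#   def test_other(self):
+#       ...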
+
+def skipIfRocmArch(arch: tuple[str, ...]):
+    def dec_fn(fn):
+        @wraps(fn)
+        def wrap_fn(self, *args, **kwargs):
+            if TEST_WITH_ROCM:
+                prop = torch.cuda.get_device_properties(0)
+                if prop.gcnArchName.split(":")[0] in arch:
+                    reason = f"skipIfRocm: test skipped on {arch}"
+                    raise unittest.SkipTest(reason)
+            return fn(self, *args, **kwargs)
+        return wrap_fn
+    return dec_fn
+
+def runOnRocm(fn):
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        if TEST_WITH_ROCM:
+            fn(*args, **kwargs)
+        else:
+            raise unittest.SkipTest("test currently only works on the ROCm stack")
+    return wrapper
+
+def runOnRocmArch(arch: tuple[str, ...]):
+    def dec_fn(fn):
+        @wraps(fn)
+        def wrap_fn(self, *args, **kwargs):
+            if TEST_WITH_ROCM:
+                prop = torch.cuda.get_device_properties(0)
+                if prop.gcnArchName.split(":")[0] not in arch:
+                    reason = f"skipIfRocm: test only runs on {arch}"
+                    raise unittest.SkipTest(reason)
+            return fn(self, *args, **kwargs)
+        return wrap_fn
+    return dec_fn
+
+def xfailIfS390X(func):
+    return unittest.expectedFailure(func) if IS_S390X else func
+
+def skipIfXpu(func=None, *, msg="test doesn't currently work on the XPU stack"):
+    def dec_fn(fn):
+        reason = f"skipIfXpu: {msg}"
+
+        @wraps(fn)
+        def wrapper(*args, **kwargs):
+            if TEST_XPU:
+                raise unittest.SkipTest(reason)
+            else:
+                return fn(*args, **kwargs)
+        return wrapper
+    if func:
+        return dec_fn(func)
+    return dec_fn
+
+def skipIfMPS(fn):
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        if TEST_MPS:
+            raise unittest.SkipTest("test doesn't currently work with MPS")
+        else:
+            fn(*args, **kwargs)
+    return wrapper
+
+
+def skipIfMPSOnMacOS13(fn):
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        if TEST_MPS and int(MACOS_VERSION) == 13:
+            raise unittest.SkipTest("Test crashes MPSGraph on MacOS13")
+        else:
+            fn(*args, **kwargs)
+    return wrapper
+
+
+def skipIfHpu(fn):
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        if TEST_HPU:
+            raise unittest.SkipTest("test doesn't currently work with HPU")
+        else:
+            fn(*args, **kwargs)
+    return wrapper
+
+# Skips a test on CUDA if ROCm is available and its version is lower than requested.
+def skipIfRocmVersionLessThan(version=None):
+    def dec_fn(fn):
+        @wraps(fn)
+        def wrap_fn(self, *args, **kwargs):
+            if TEST_WITH_ROCM:
+                rocm_version = str(torch.version.hip)
+                rocm_version = rocm_version.split("-")[0]    # ignore git sha
+                rocm_version_tuple = tuple(int(x) for x in rocm_version.split("."))
+                if rocm_version_tuple is None or version is None or rocm_version_tuple < tuple(version):
+                    reason = f"ROCm {rocm_version_tuple} is available but {version} required"
+                    raise unittest.SkipTest(reason)
+            return fn(self, *args, **kwargs)
+        return wrap_fn
+    return dec_fn
+
+def skipIfNotMiopenSuggestNHWC(fn):
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        if not TEST_WITH_MIOPEN_SUGGEST_NHWC:
+            raise unittest.SkipTest("test doesn't currently work without MIOpen NHWC activation")
+        else:
+            fn(*args, **kwargs)
+    return wrapper
+
+def skipIfWindows(func=None, *, msg="test doesn't currently work on the Windows stack"):
+    def dec_fn(fn):
+        reason = f"skipIfWindows: {msg}"
+
+        @wraps(fn)
+        def wrapper(*args, **kwargs):
+            if IS_WINDOWS:  # noqa: F821
+                raise unittest.SkipTest(reason)
+            else:
+                return fn(*args, **kwargs)
+        return wrapper
+    if func:
+        return dec_fn(func)
+    return dec_fn
+
+# Reverts the linalg backend back to default to make sure potential failures in one
+# test do not affect other tests
+def setLinalgBackendsToDefaultFinally(fn):
+    @wraps(fn)
+    def _fn(*args, **kwargs):
+        _preferred_backend = torch.backends.cuda.preferred_linalg_library()
+        try:
+            fn(*args, **kwargs)
+        finally:
+            torch.backends.cuda.preferred_linalg_library(_preferred_backend)
+    return _fn
+
+
+# Reverts the blas backend back to default to make sure potential failures in one
+# test do not affect other tests
+def setBlasBackendsToDefaultFinally(fn):
+    @wraps(fn)
+    def _fn(*args, **kwargs):
+        _preferred_backend = torch.backends.cuda.preferred_blas_library()
+        try:
+            fn(*args, **kwargs)
+        finally:
+            torch.backends.cuda.preferred_blas_library(_preferred_backend)
+    return _fn
+
+
+# Context manager for setting deterministic flag and automatically
+# resetting it to its original value
+class DeterministicGuard:
+    def __init__(self, deterministic, *, warn_only=False, fill_uninitialized_memory=True):
+        self.deterministic = deterministic
+        self.warn_only = warn_only
+        self.fill_uninitialized_memory = fill_uninitialized_memory
+
+    @classmethod
+    def _current_state(cls):
+        return cls(
+            torch.are_deterministic_algorithms_enabled(),
+            warn_only=torch.is_deterministic_algorithms_warn_only_enabled(),
+            fill_uninitialized_memory=torch.utils.deterministic.fill_uninitialized_memory,  # type: ignore[attr-defined]
+        )
+
+    def _update(self):
+        torch.use_deterministic_algorithms(self.deterministic, warn_only=self.warn_only)
+        torch.utils.deterministic.fill_uninitialized_memory = self.fill_uninitialized_memory  # type: ignore[attr-defined]
+
+    def __enter__(self):
+        self._restore = self._current_state()
+        self._update()
+
+    def __exit__(self, exception_type, exception_value, traceback):
+        self._restore._update()
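+
+# Illustrative usage (comment only): temporarily enable deterministic algorithms and
+# restore the previous flags on exit.
+#
+#   with DeterministicGuard(True, warn_only=True):
+#       out = op_under_test(x)  # hypothetical call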
+
+class AlwaysWarnTypedStorageRemoval:
+    def __init__(self, always_warn):
+        assert isinstance(always_warn, bool)
+        self.always_warn = always_warn
+
+    def __enter__(self):
+        self.always_warn_restore = torch.storage._get_always_warn_typed_storage_removal()
+        torch.storage._set_always_warn_typed_storage_removal(self.always_warn)
+
+    def __exit__(self, exception_type, exception_value, traceback):
+        torch.storage._set_always_warn_typed_storage_removal(self.always_warn_restore)
+
+# Context manager for setting cuda sync debug mode and reset it
+# to original value
+# we are not exposing it to the core because sync debug mode is
+# global and thus not thread safe
+class CudaSyncGuard:
+    def __init__(self, sync_debug_mode):
+        self.mode = sync_debug_mode
+
+    def __enter__(self):
+        self.debug_mode_restore = torch.cuda.get_sync_debug_mode()
+        torch.cuda.set_sync_debug_mode(self.mode)
+
+    def __exit__(self, exception_type, exception_value, traceback):
+        torch.cuda.set_sync_debug_mode(self.debug_mode_restore)
+
+# Context manager for setting torch.__future__.set_swap_module_params_on_conversion
+# and automatically resetting it to its original value
+class SwapTensorsGuard:
+    def __init__(self, use_swap_tensors):
+        self.use_swap_tensors = use_swap_tensors
+
+    def __enter__(self):
+        self.swap_tensors_restore = torch.__future__.get_swap_module_params_on_conversion()
+        if self.use_swap_tensors is not None:
+            torch.__future__.set_swap_module_params_on_conversion(self.use_swap_tensors)
+
+    def __exit__(self, exception_type, exception_value, traceback):
+        torch.__future__.set_swap_module_params_on_conversion(self.swap_tensors_restore)
+
+# This decorator can be used for API tests that call
+# torch.use_deterministic_algorithms().  When the test is finished, it will
+# restore the previous deterministic flag setting.
+#
+# If CUDA >= 10.2, this will set the environment variable
+# CUBLAS_WORKSPACE_CONFIG=:4096:8 so that the error associated with that
+# setting is not thrown during the test unless the test changes that variable
+# on purpose. The previous CUBLAS_WORKSPACE_CONFIG setting will also be
+# restored once the test is finished.
+#
+# Note that if a test requires CUDA to actually register the changed
+# CUBLAS_WORKSPACE_CONFIG variable, a new subprocess must be created, because
+# CUDA only checks the variable when the runtime initializes. Tests can be
+# run inside a subprocess like so:
+#
+#   import subprocess, sys, os
+#   script = '''
+#   # Test code should go here
+#   '''
+#   try:
+#       subprocess.check_output(
+#           [sys.executable, '-c', script],
+#           stderr=subprocess.STDOUT,
+#           cwd=os.path.dirname(os.path.realpath(__file__)),
+#           env=os.environ.copy())
+#   except subprocess.CalledProcessError as e:
+#       error_message = e.output.decode('utf-8')
+#       # Handle exceptions raised by the subprocess here
+#
+def wrapDeterministicFlagAPITest(fn):
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        with DeterministicGuard(
+                torch.are_deterministic_algorithms_enabled(),
+                warn_only=torch.is_deterministic_algorithms_warn_only_enabled()):
+            class CuBLASConfigGuard:
+                cublas_var_name = 'CUBLAS_WORKSPACE_CONFIG'
+
+                def __enter__(self):
+                    self.cublas_config_restore = os.environ.get(self.cublas_var_name)
+                    os.environ[self.cublas_var_name] = ':4096:8'
+
+                def __exit__(self, exception_type, exception_value, traceback):
+                    cur_cublas_config = os.environ.get(self.cublas_var_name)
+                    if self.cublas_config_restore is None:
+                        if cur_cublas_config is not None:
+                            del os.environ[self.cublas_var_name]
+                    else:
+                        os.environ[self.cublas_var_name] = self.cublas_config_restore
+            with CuBLASConfigGuard():
+                fn(*args, **kwargs)
+    return wrapper
+
+# This decorator can be used for API tests that want to safely call
+# torch.__future__.set_swap_module_params_on_conversion.  `swap` can be set to
+# True, False or None where None indicates that the context manager does not
+# set the flag. When the test is finished, it will restore the previous swap
+# flag setting.
+def wrapSwapTensorsTest(swap=None):
+    def dec_fn(fn):
+        @wraps(fn)
+        def wrapper(*args, **kwargs):
+            with SwapTensorsGuard(swap):
+                fn(*args, **kwargs)
+        return wrapper
+    return dec_fn
+
+# test parametrizer for swapping
+class swap(_TestParametrizer):
+    def __init__(self, swap_values):
+        super().__init__()
+        self.swap_values = swap_values
+
+    def _parametrize_test(self, test, generic_cls, device_cls):
+        for swap in self.swap_values:
+            yield wrapSwapTensorsTest(swap)(test), f'swap_{swap}', {}, lambda _: []
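+
+# Illustrative usage (comment only, hypothetical test; assumes the owning class is
+# processed by instantiate_parametrized_tests): one variant is generated per value,
+# each wrapped in wrapSwapTensorsTest.
+#
+#   @swap([True, False])
+#   def test_load_state_dict(self):
+#       ...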
+
+def skipIfCompiledWithoutNumpy(fn):
+    # Even if the numpy module is present, if `USE_NUMPY=0` is used during the
+    # build, numpy tests will fail
+    numpy_support = TEST_NUMPY
+    if numpy_support:
+        try:
+            # The numpy module is present, verify that PyTorch is compiled with
+            # numpy support
+            torch.from_numpy(np.array([2, 2]))
+        except RuntimeError:
+            numpy_support = False
+
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        if not numpy_support:
+            raise unittest.SkipTest("PyTorch was compiled without numpy support")
+        else:
+            fn(*args, **kwargs)
+    return wrapper
+
+def _test_function(fn, device):
+    def run_test_function(self):
+        return fn(self, device)
+    return run_test_function
+
+def skipIfNoXNNPACK(fn):
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        if not torch.backends.xnnpack.enabled:  # type: ignore[attr-defined]
+            raise unittest.SkipTest('XNNPACK must be enabled for these tests. Please build with USE_XNNPACK=1.')
+        else:
+            fn(*args, **kwargs)
+    return wrapper
+
+def skipIfNoLapack(fn):
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        if not torch._C.has_lapack:
+            raise unittest.SkipTest('PyTorch compiled without Lapack')
+        else:
+            fn(*args, **kwargs)
+    return wrapper
+
+def skipIfNotRegistered(op_name, message):
+    """Wraps the decorator to hide the import of the `core`.
+
+    Args:
+        op_name: Check if this op is registered in `core._REGISTERED_OPERATORS`.
+        message: message to fail with.
+
+    Usage:
+        @skipIfNotRegistered('MyOp', 'MyOp is not linked!')
+
+    This used to check whether 'MyOp' is registered in caffe2.python.core;
+    since Caffe2 support has been removed, the decorator now unconditionally
+    skips the decorated test.
+    """
+    return unittest.skip("PyTorch is compiled without Caffe2")
+
+def skipIfNoSciPy(fn):
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        if not TEST_SCIPY:
+            raise unittest.SkipTest("test require SciPy, but SciPy not found")
+        else:
+            fn(*args, **kwargs)
+    return wrapper
+
+def skip_if_pytest(fn):
+    @wraps(fn)
+    def wrapped(*args, **kwargs):
+        if "PYTEST_CURRENT_TEST" in os.environ:
+            raise unittest.SkipTest("does not work under pytest")
+        return fn(*args, **kwargs)
+
+    return wrapped
+
+def skipIfNoXPU(fn):
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        if not TEST_XPU:
+            raise unittest.SkipTest("test required PyTorched compiled with XPU")
+        else:
+            fn(*args, **kwargs)
+    return wrapper
+
+def slowTest(fn):
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        if not TEST_WITH_SLOW:
+            raise unittest.SkipTest("test is slow; run with PYTORCH_TEST_WITH_SLOW to enable test")
+        else:
+            fn(*args, **kwargs)
+    wrapper.__dict__['slow_test'] = True
+    return wrapper
+
+
+def slowTestIf(condition):
+    return slowTest if condition else lambda fn: fn
+
+
+def skipCUDAMemoryLeakCheckIf(condition):
+    def dec(fn):
+        if getattr(fn, '_do_cuda_memory_leak_check', True):  # if current True
+            fn._do_cuda_memory_leak_check = not condition
+        return fn
+    return dec
+
+def skipCUDANonDefaultStreamIf(condition):
+    def dec(fn):
+        if getattr(fn, '_do_cuda_non_default_stream', True):  # if current True
+            fn._do_cuda_non_default_stream = not condition
+        return fn
+    return dec
+
+def suppress_warnings(fn):
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            fn(*args, **kwargs)
+    return wrapper
+
+
+def to_gpu(obj, type_map=None):
+    if type_map is None:
+        type_map = {}
+    if isinstance(obj, torch.Tensor):
+        assert obj.is_leaf
+        t = type_map.get(obj.dtype, obj.dtype)
+        with torch.no_grad():
+            res = obj.to(dtype=t, device="cuda", copy=True)
+            res.requires_grad = obj.requires_grad
+        return res
+    elif torch.is_storage(obj):
+        return obj.new().resize_(obj.size()).copy_(obj)  # type: ignore[attr-defined, union-attr]
+    elif isinstance(obj, list):
+        return [to_gpu(o, type_map) for o in obj]
+    elif isinstance(obj, tuple):
+        return tuple(to_gpu(o, type_map) for o in obj)
+    else:
+        return deepcopy(obj)
+
+
+def get_function_arglist(func):
+    return inspect.getfullargspec(func).args
+
+
+def set_rng_seed(seed):
+    torch.manual_seed(seed)
+    random.seed(seed)
+    if TEST_NUMPY:
+        np.random.seed(seed)
+
+
+@contextlib.contextmanager
+def set_default_dtype(dtype):
+    saved_dtype = torch.get_default_dtype()
+    torch.set_default_dtype(dtype)
+    try:
+        yield
+    finally:
+        torch.set_default_dtype(saved_dtype)
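+
+# Illustrative usage (comment only):
+#
+#   with set_default_dtype(torch.double):
+#       x = torch.tensor([1.0])  # x.dtype is torch.float64 inside the block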
+
+@contextlib.contextmanager
+def set_default_tensor_type(tensor_type):
+    saved_tensor_type = torch.tensor([]).type()
+    torch.set_default_tensor_type(tensor_type)
+    try:
+        yield
+    finally:
+        torch.set_default_tensor_type(saved_tensor_type)
+
+def iter_indices(tensor):
+    if tensor.dim() == 0:
+        return range(0)
+    if tensor.dim() == 1:
+        return range(tensor.size(0))
+    return product(*(range(s) for s in tensor.size()))
+
+
+def is_iterable(obj):
+    try:
+        iter(obj)
+        return True
+    except TypeError:
+        return False
+
+
+def is_iterable_of_tensors(iterable, include_empty=False):
+    """ Returns True if iterable is an iterable of tensors and False o.w.
+
+        If the iterable is empty, the return value is :attr:`include_empty`
+    """
+    # Tensor itself is iterable so we check this first
+    if isinstance(iterable, torch.Tensor):
+        return False
+
+    try:
+        if len(iterable) == 0:
+            return include_empty
+
+        for t in iter(iterable):
+            if not isinstance(t, torch.Tensor):
+                return False
+
+    except TypeError:
+        return False
+
+    return True
+
+
+class CudaNonDefaultStream:
+    def __enter__(self):
+        # Before starting CUDA test save currently active streams on all
+        # CUDA devices and set new non default streams to all CUDA devices
+        # to ensure CUDA tests do not use default stream by mistake.
+        beforeDevice = torch.cuda.current_device()
+        self.beforeStreams = []
+        for d in range(torch.cuda.device_count()):
+            self.beforeStreams.append(torch.cuda.current_stream(d))
+            deviceStream = torch.cuda.Stream(device=d)
+            self.beforeStreams[-1].synchronize()
+            torch._C._cuda_setStream(stream_id=deviceStream.stream_id,
+                                     device_index=deviceStream.device_index,
+                                     device_type=deviceStream.device_type)
+        torch._C._cuda_setDevice(beforeDevice)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        # After completing CUDA test load previously active streams on all
+        # CUDA devices.
+        beforeDevice = torch.cuda.current_device()
+        for d in range(torch.cuda.device_count()):
+            torch._C._cuda_setStream(stream_id=self.beforeStreams[d].stream_id,
+                                     device_index=self.beforeStreams[d].device_index,
+                                     device_type=self.beforeStreams[d].device_type)
+        torch._C._cuda_setDevice(beforeDevice)
+
+class CudaMemoryLeakCheck:
+    def __init__(self, testcase, name=None):
+        self.name = testcase.id() if name is None else name
+        self.testcase = testcase
+
+        # initialize context & RNG to prevent false positive detections
+        # when the test is the first to initialize those
+        from torch.testing._internal.common_cuda import initialize_cuda_context_rng
+        initialize_cuda_context_rng()
+
+    # Stores CUDA memory data provided by PyTorch's caching allocator and
+    #   the CUDA driver.
+    #
+    # NOTE: The undocumented torch.cuda.mem_get_info() returns
+    #   (#free bytes, #total bytes available) on the GPU
+    def __enter__(self):
+        self.caching_allocator_befores = []
+        self.driver_befores = []
+
+        # Performs a gc if required (required if any CUDA memory is held)
+        num_devices = torch.cuda.device_count()
+        for i in range(num_devices):
+            caching_allocator_mem_allocated = torch.cuda.memory_allocated(i)
+            # NOTE: gc is based exclusively on caching allocator memory
+            #   because the driver will always have some bytes in use (context size?)
+            if caching_allocator_mem_allocated > 0:
+                gc.collect()
+                torch._C._cuda_clearCublasWorkspaces()
+                torch.cuda.empty_cache()
+                break
+
+        # Acquires caching allocator and driver statistics before the test is run
+        for i in range(num_devices):
+            self.caching_allocator_befores.append(torch.cuda.memory_allocated(i))
+            bytes_free, bytes_total = torch.cuda.mem_get_info(i)
+            driver_mem_allocated = bytes_total - bytes_free
+            self.driver_befores.append(driver_mem_allocated)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        # Don't check for leaks if an exception was thrown
+        if exc_type is not None:
+            return
+
+        # Compares caching allocator before/after statistics
+        # An increase in allocated memory is a discrepancy indicating a possible
+        #   memory leak
+        discrepancy_detected = False
+        num_devices = torch.cuda.device_count()
+        for i in range(num_devices):
+            # avoid counting cublasWorkspace allocations
+            torch._C._cuda_clearCublasWorkspaces()
+            caching_allocator_mem_allocated = torch.cuda.memory_allocated(i)
+
+            if caching_allocator_mem_allocated > self.caching_allocator_befores[i]:
+                discrepancy_detected = True
+                break
+
+        # Short-circuits if no discrepancy detected
+        if not discrepancy_detected:
+            return
+
+        # Validates the discrepancy persists after garbage collection and
+        #   is confirmed by the driver API
+
+        # NOTE: driver API discrepancies alone are ignored because with the jiterator
+        #   some tests may permanently increase the CUDA context size and
+        #   that will appear as a driver memory leak but is the expected behavior.
+
+        # GCs and clears the cache
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        for i in range(num_devices):
+
+            discrepancy_detected = True
+
+            # Query memory multiple times to ensure the leak was not transient
+            for _ in range(3):
+                caching_allocator_mem_allocated = torch.cuda.memory_allocated(i)
+                bytes_free, bytes_total = torch.cuda.mem_get_info(i)
+                driver_mem_allocated = bytes_total - bytes_free
+
+                caching_allocator_discrepancy = False
+                driver_discrepancy = False
+
+                if caching_allocator_mem_allocated > self.caching_allocator_befores[i]:
+                    caching_allocator_discrepancy = True
+
+                if driver_mem_allocated > self.driver_befores[i]:
+                    driver_discrepancy = True
+
+                if not (caching_allocator_discrepancy or driver_discrepancy):
+                    # Leak was false positive, exit loop
+                    discrepancy_detected = False
+                    break
+
+            if not discrepancy_detected:
+                continue
+
+            if caching_allocator_discrepancy and not driver_discrepancy:  # type: ignore[possibly-undefined]
+                # Just raises a warning if the leak is not validated by the
+                #   driver API
+                # NOTE: this may be a problem with how the caching allocator collects its
+                #   statistics or a leak too small to trigger the allocation of an
+                #   additional block of memory by the CUDA driver
+                msg = ("CUDA caching allocator reports a memory leak not "  # type: ignore[possibly-undefined]
+                       f"verified by the driver API in {self.name}! "
+                       f"Caching allocator allocated memory was {self.caching_allocator_befores[i]} "
+                       f"and is now reported as {caching_allocator_mem_allocated} "
+                       f"on device {i}. "
+                       f"CUDA driver allocated memory was {self.driver_befores[i]} and is now {driver_mem_allocated}.")
+                warnings.warn(msg)
+            elif caching_allocator_discrepancy and driver_discrepancy:
+                # A caching allocator discrepancy validated by the driver API is a
+                #   failure (except on ROCm, see below)
+                msg = (f"CUDA driver API confirmed a leak in {self.name}! "  # type: ignore[possibly-undefined]
+                       f"Caching allocator allocated memory was {self.caching_allocator_befores[i]} "
+                       f"and is now reported as {caching_allocator_mem_allocated} "
+                       f"on device {i}. "
+                       f"CUDA driver allocated memory was {self.driver_befores[i]} and is now {driver_mem_allocated}.")
+
+                raise RuntimeError(msg)
+
+@contextmanager
+def skip_exception_type(exc_type):
+    try:
+        yield
+    except exc_type as e:
+        raise unittest.SkipTest(f"not implemented: {e}") from e
+
+@contextmanager
+def print_repro_on_failure(repro_parts):
+    try:
+        yield
+    except unittest.SkipTest:
+        raise
+    except Exception as e:
+        # Get the index of the sample input that failed the test if possible.
+        sample_isolation_prefix = ""
+        tracked_input = getattr(e, "_tracked_input", None)
+        if tracked_input is not None:
+            sample_isolation_prefix = f"PYTORCH_OPINFO_SAMPLE_INPUT_INDEX={tracked_input.index}"
+
+        repro_str = " ".join(filter(None, (sample_isolation_prefix, *repro_parts)))
+
+        open_source_signpost(
+            subsystem="test_repros",
+            name="test_failure",
+            parameters=json.dumps(
+                {
+                    "repro": " ".join(filter(None, (sample_isolation_prefix, *repro_parts))),
+                }
+            ),
+        )
+
+        repro_msg = f"""
+To execute this test, run the following from the base repo dir:
+    {repro_str}
+
+This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0"""
+
+        # NB: Hacking the exception args is the cleanest way I've found to append
+        # failure reproduction info without poisoning the stack trace.
+        if len(e.args) >= 1:
+            e.args = (f"{e.args[0]}\n{repro_msg}", *e.args[1:])
+        raise
+
+#  "min_satisfying_examples" setting has been deprecated in hypothesis
+#  3.56.0 and removed in hypothesis 4.x
+try:
+    import hypothesis
+
+    def settings(*args, **kwargs):
+        if 'min_satisfying_examples' in kwargs and hypothesis.version.__version_info__ >= (3, 56, 0):
+            kwargs.pop('min_satisfying_examples')
+        return hypothesis.settings(*args, **kwargs)
+
+
+    hypothesis.settings.register_profile(
+        "pytorch_ci",
+        settings(
+            derandomize=True,
+            suppress_health_check=[hypothesis.HealthCheck.too_slow],
+            database=None,
+            max_examples=50,
+            verbosity=hypothesis.Verbosity.normal))
+    hypothesis.settings.register_profile(
+        "dev",
+        settings(
+            suppress_health_check=[hypothesis.HealthCheck.too_slow],
+            database=None,
+            max_examples=10,
+            verbosity=hypothesis.Verbosity.normal))
+    hypothesis.settings.register_profile(
+        "debug",
+        settings(
+            suppress_health_check=[hypothesis.HealthCheck.too_slow],
+            database=None,
+            max_examples=1000,
+            verbosity=hypothesis.Verbosity.verbose))
+
+    hypothesis.settings.load_profile(
+        "pytorch_ci" if IS_CI else os.getenv('PYTORCH_HYPOTHESIS_PROFILE', 'dev')
+    )
+except ImportError:
+    warnings.warn('Failed to import hypothesis in common_utils; tests are not derandomized', ImportWarning)
+
+# Used in check_if_enable to see if a test method should be disabled by an issue.
+# It sanitizes a test method name by stripping device/dtype suffixes appended by
+# @dtypes parametrization, e.g. an issue titled
+# "DISABLED test_bitwise_ops (__main__.TestBinaryUfuncs)" should disable ALL
+# parametrized test_bitwise_ops tests, such as test_bitwise_ops_cuda_int32.
+def remove_device_and_dtype_suffixes(test_name: str) -> str:
+    # import statement is localized to avoid circular dependency issues with common_device_type.py
+    from torch.testing._internal.common_device_type import get_device_type_test_bases
+    device_suffixes = [x.device_type for x in get_device_type_test_bases()]
+    dtype_suffixes = [str(dt)[len("torch."):] for dt in get_all_dtypes()]
+
+    test_name_chunks = test_name.split("_")
+    if len(test_name_chunks) > 0 and test_name_chunks[-1] in dtype_suffixes:
+        if len(test_name_chunks) > 1 and test_name_chunks[-2] in device_suffixes:
+            return "_".join(test_name_chunks[0:-2])
+        return "_".join(test_name_chunks[0:-1])
+    return test_name
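+
+# For example (assuming a "cuda" device test base is registered):
+#
+#   remove_device_and_dtype_suffixes("test_bitwise_ops_cuda_int32")  # "test_bitwise_ops"
+#   remove_device_and_dtype_suffixes("test_bitwise_ops_int32")       # "test_bitwise_ops"
+#   remove_device_and_dtype_suffixes("test_bitwise_ops")             # unchanged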
+
+
+def check_if_enable(test: unittest.TestCase):
+    classname = str(test.__class__).split("'")[1].split(".")[-1]
+    sanitized_testname = remove_device_and_dtype_suffixes(test._testMethodName)
+
+    def matches_test(target: str):
+        target_test_parts = target.split()
+        if len(target_test_parts) < 2:
+            # poorly formed target test name
+            return False
+        target_testname = target_test_parts[0]
+        target_classname = target_test_parts[1][1:-1].split(".")[-1]
+        # Match if the test method name or its sanitized version exactly matches the
+        # disabled test method name, while allowing a non-parametrized suite name to
+        # disable its parametrized variants (e.g. TestSuite disables TestSuiteCPU).
+        return classname.startswith(target_classname) and (target_testname in (test._testMethodName, sanitized_testname))
+
+    if any(matches_test(x) for x in slow_tests_dict.keys()):
+        getattr(test, test._testMethodName).__dict__['slow_test'] = True
+        if not TEST_WITH_SLOW:
+            raise unittest.SkipTest("test is slow; run with PYTORCH_TEST_WITH_SLOW to enable test")
+
+    if not IS_SANDCASTLE:
+        should_skip = False
+        skip_msg = ""
+
+        for disabled_test, (issue_url, platforms) in disabled_tests_dict.items():
+            if matches_test(disabled_test):
+                platform_to_conditional: dict = {
+                    "mac": IS_MACOS,
+                    "macos": IS_MACOS,
+                    "win": IS_WINDOWS,
+                    "windows": IS_WINDOWS,
+                    "linux": IS_LINUX,
+                    "rocm": TEST_WITH_ROCM,
+                    "xpu": TEST_XPU,
+                    "asan": TEST_WITH_ASAN,
+                    "dynamo": TEST_WITH_TORCHDYNAMO,
+                    "dynamo_wrapped": TEST_WITH_TORCHDYNAMO,
+                    "inductor": TEST_WITH_TORCHINDUCTOR,
+                    "slow": TEST_WITH_SLOW,
+                }
+
+                invalid_platforms = list(filter(lambda p: p not in platform_to_conditional, platforms))
+                if len(invalid_platforms) > 0:
+                    invalid_plats_str = ", ".join(invalid_platforms)
+                    valid_plats = ", ".join(platform_to_conditional.keys())
+
+                    print(f"Test {disabled_test} is disabled for some unrecognized ",
+                          f"platforms: [{invalid_plats_str}]. Please edit issue {issue_url} to fix the platforms ",
+                          'assigned to this flaky test, changing "Platforms: ..." to a comma separated ',
+                          f"subset of the following (or leave it blank to match all platforms): {valid_plats}")
+
+                    # Sanitize the platforms list so that we continue to disable the test for any valid platforms given
+                    platforms = list(filter(lambda p: p in platform_to_conditional, platforms))
+
+                if platforms == [] or any(platform_to_conditional[platform] for platform in platforms):
+                    should_skip = True
+                    skip_msg = f"Test is disabled because an issue exists disabling it: {issue_url}" \
+                        f" for {'all' if platforms == [] else ''}platform(s) {', '.join(platforms)}. " \
+                        "If you're seeing this on your local machine and would like to enable this test, " \
+                        "please make sure CI is not set and you are not using the flag --import-disabled-tests."
+                    break
+
+        if should_skip and not RERUN_DISABLED_TESTS:
+            # Skip the disabled test when not running under --rerun-disabled-tests verification mode
+            raise unittest.SkipTest(skip_msg)
+
+        if not should_skip and RERUN_DISABLED_TESTS:
+            # Probably test has disable issue but not for this platform
+            skip_msg = "Test is enabled but --rerun-disabled-tests verification mode is set, so only" \
+                " disabled tests are run"
+            raise unittest.SkipTest(skip_msg)
+
+    if TEST_SKIP_FAST:
+        if hasattr(test, test._testMethodName) and not getattr(test, test._testMethodName).__dict__.get('slow_test', False):
+            raise unittest.SkipTest("test is fast; we disabled it with PYTORCH_TEST_SKIP_FAST")
+
+
+# `TestCase.assertEqual` is very permissive and coerces the inputs into a format that can be compared. This is very
+# convenient when writing tests, but not so much while reviewing them. By default, the comparison `Pair` framework of
+# `torch.testing._comparison.are_equal`, used for example by the public testing function
+# `torch.testing.assert_close`, is more strict. In order to use the same framework and thus reduce the divergence
+# between internal and external comparison logic as much as possible, we define some "relaxed" pairs here. They only
+# change the supported inputs, but the comparison logic is the same.
+# TODO: Revisit the relaxed pairs and check how much work it is to fix the tests that would fail without the relaxation.
+
+class RelaxedBooleanPair(BooleanPair):
+    """Pair for boolean-like inputs.
+
+    In contrast to the builtin :class:`BooleanPair`, this class also supports one input being a number or a single
+    element tensor-like.
+    """
+    _supported_number_types = NumberPair(0, 0)._supported_types
+
+    def _process_inputs(self, actual, expected, *, id):
+        # We require at least one of the inputs to be a boolean; the other can be a boolean, a
+        # number, or a single element tensor or array, whereas the default BooleanPair requires both inputs to be booleans.
+        tensor_or_array_types: tuple[type, ...] = (torch.Tensor, np.ndarray)
+        other_supported_types = (*self._supported_types, *self._supported_number_types, *tensor_or_array_types)
+        if not (
+            (isinstance(actual, self._supported_types) and isinstance(expected, other_supported_types))
+            or (isinstance(expected, self._supported_types) and isinstance(actual, other_supported_types))
+        ):
+            self._inputs_not_supported()
+
+        return [self._to_bool(input, id=id) for input in (actual, expected)]
+
+    def _to_bool(self, bool_like, *, id):
+        if isinstance(bool_like, np.number):
+            return bool(bool_like.item())
+        elif type(bool_like) in self._supported_number_types:
+            return bool(bool_like)
+        elif isinstance(bool_like, (torch.Tensor, np.ndarray)):
+            numel = bool_like.numel() if isinstance(bool_like, torch.Tensor) else bool_like.size
+            if numel > 1:
+                self._fail(
+                    ValueError,
+                    f"Only single element tensor-likes can be compared against a boolean. "
+                    f"Got {numel} elements instead.",
+                    id=id
+                )
+
+            return bool(bool_like.item())
+        else:
+            return super()._to_bool(bool_like, id=id)
+
+
+class RelaxedNumberPair(NumberPair):
+    """Pair for number-like inputs.
+
+    In contrast to the builtin :class:`NumberPair`, this class also supports one input being a single element
+    tensor-like or a :class:`enum.Enum`. (D)Type checks are disabled, meaning comparing 1 to 1.0 succeeds even when
+    ``check_dtype=True`` is passed.
+
+    In addition, this class uses looser default tolerances for :class:`float` and :class:`complex` inputs. Also
+    supports overriding the absolute and relative tolerance through the ``@precisionOverride`` and
+    ``@toleranceOverride`` decorators.
+    """
+    _TYPE_TO_DTYPE = {
+        int: torch.int64,
+        float: torch.float32,
+        complex: torch.complex64,
+    }
+
+    def __init__(
+            self, actual, expected, *, rtol_override=0.0, atol_override=0.0, check_dtype=None, **other_parameters
+    ) -> None:
+        super().__init__(actual, expected, check_dtype=False, **other_parameters)
+        self.rtol = max(self.rtol, rtol_override)
+        self.atol = max(self.atol, atol_override)
+
+    def _process_inputs(self, actual, expected, *, id):
+        # We require at least one of the inputs to be a number; the other can be a number or a single
+        # element tensor or array, whereas the default NumberPair requires both inputs to be numbers.
+        tensor_or_array_types: tuple[type, ...] = (torch.Tensor, np.ndarray)
+        other_supported_types = (*self._supported_types, *tensor_or_array_types)
+        if not (
+                (isinstance(actual, self._supported_types) and isinstance(expected, other_supported_types))
+                or (isinstance(expected, self._supported_types) and isinstance(actual, other_supported_types))
+        ):
+            self._inputs_not_supported()
+
+        return [self._to_number(input, id=id) for input in (actual, expected)]
+
+    def _to_number(self, number_like, *, id):
+        if isinstance(number_like, (torch.Tensor, np.ndarray)):
+            numel = number_like.numel() if isinstance(number_like, torch.Tensor) else number_like.size
+            if numel > 1:
+                self._fail(
+                    ValueError,
+                    f"Only single element tensor-likes can be compared against a number. "
+                    f"Got {numel} elements instead.",
+                    id=id
+                )
+            number = number_like.item()
+            if isinstance(number, bool):
+                number = int(number)
+
+            return number
+        elif isinstance(number_like, Enum):
+            return int(number_like)  # type: ignore[call-overload]
+        else:
+            number = super()._to_number(number_like, id=id)
+            if type(number) not in self._TYPE_TO_DTYPE.keys():
+                self._inputs_not_supported()
+            return number
+
+
+class TensorOrArrayPair(TensorLikePair):
+    """Pair for tensor-like inputs.
+
+    On the one hand this class is stricter than the builtin :class:`TensorLikePair` since it only allows instances of
+    :class:`torch.Tensor` and :class:`numpy.ndarray` rather than allowing any tensor-like than can be converted into a
+    tensor. On the other hand this class is looser since it converts all inputs into tensors with no regard of their
+    relationship, e.g. comparing a :class:`torch.Tensor` to :class:`numpy.ndarray` is fine.
+
+    In addition, this class supports overriding the absolute and relative tolerance through the ``@precisionOverride``
+    and ``@toleranceOverride`` decorators.
+    """
+    def __init__(self, actual, expected, *, rtol_override=0.0, atol_override=0.0, **other_parameters):
+        super().__init__(actual, expected, **other_parameters)
+        self.rtol = max(self.rtol, rtol_override)
+        self.atol = max(self.atol, atol_override)
+
+    def _process_inputs(self, actual, expected, *, id, allow_subclasses):
+        self._check_inputs_isinstance(actual, expected, cls=(torch.Tensor, np.ndarray))
+
+        actual, expected = (self._to_tensor(input) for input in (actual, expected))
+        for tensor in (actual, expected):
+            self._check_supported(tensor, id=id)
+        return actual, expected
+
+
+class TypedStoragePair(TensorLikePair):
+    """Pair for :class:`torch.storage.TypedStorage` inputs."""
+    def __init__(self, actual, expected, *, rtol_override=0.0, atol_override=0.0, **other_parameters):
+        self._check_inputs_isinstance(actual, expected, cls=torch.storage.TypedStorage)
+        super().__init__(actual, expected, **other_parameters)
+        self.rtol = max(self.rtol, rtol_override)
+        self.atol = max(self.atol, atol_override)
+
+    def _to_tensor(self, typed_storage):
+        return torch.tensor(
+            typed_storage._untyped_storage,
+            dtype={
+                torch.quint8: torch.uint8,
+                torch.quint4x2: torch.uint8,
+                torch.quint2x4: torch.uint8,
+                torch.qint32: torch.int32,
+                torch.qint8: torch.int8
+            }.get(typed_storage.dtype, typed_storage.dtype),
+            device=typed_storage.device,
+        )
+
+
+class UnittestPair(Pair):
+    """Fallback ABC pair that handles non-numeric inputs.
+
+    To avoid recreating the mismatch messages of :meth:`unittest.TestCase.assertEqual`, this pair simply wraps it in
+    order to use it with the :class:`Pair` "framework" from :func:`are_equal`.
+
+    Define the :attr:`UnittestPair.CLS` in a subclass to indicate which class(es) of the inputs the pair should support.
+    """
+    CLS: Union[type, tuple[type, ...]]
+    TYPE_NAME: Optional[str] = None
+
+    def __init__(self, actual, expected, **other_parameters):
+        self._check_inputs_isinstance(actual, expected, cls=self.CLS)
+        super().__init__(actual, expected, **other_parameters)
+
+    def compare(self):
+        test_case = unittest.TestCase()
+
+        try:
+            return test_case.assertEqual(self.actual, self.expected)
+        except test_case.failureException as error:
+            msg = str(error)
+
+        type_name = self.TYPE_NAME or (self.CLS if isinstance(self.CLS, type) else self.CLS[0]).__name__
+        self._fail(AssertionError, f"{type_name.title()} comparison failed: {msg}")
+
+
+class StringPair(UnittestPair):
+    CLS = (str, bytes)
+    TYPE_NAME = "string"
+
+
+class SetPair(UnittestPair):
+    CLS = set
+
+
+class TypePair(UnittestPair):
+    CLS = type
+
+
+class ObjectPair(UnittestPair):
+    CLS = object
+
+
+# This implements a variant of assertRaises/assertRaisesRegex where we first test
+# if the exception is NotImplementedError, and if so just skip the test instead
+# of failing it.
+#
+# This is implemented by inheriting from the (private) implementation of
+# assertRaises from unittest.case, and slightly tweaking it for this new
+# behavior.  The year is 2021: this private class hierarchy hasn't changed since
+# 2010, seems low risk to inherit from.
+class AssertRaisesContextIgnoreNotImplementedError(unittest.case._AssertRaisesContext):
+    def __exit__(self, exc_type, exc_value, tb):
+        if exc_type is not None and issubclass(exc_type, NotImplementedError):
+            self.test_case.skipTest(f"not_implemented: {exc_value}")  # type: ignore[attr-defined]
+        return super().__exit__(exc_type, exc_value, tb)
+
+
+@contextmanager
+def set_warn_always_context(new_val: bool):
+    old_val = torch.is_warn_always_enabled()
+    torch.set_warn_always(new_val)
+    try:
+        yield
+    finally:
+        torch.set_warn_always(old_val)
+
+
+class NoTest:
+    # causes pytest to not recognize this class as a test
+    __test__ = False
+
+
+class TestCase(expecttest.TestCase):
+    # NOTE: "precision" lets classes and generated tests set minimum
+    # atol values when comparing tensors. Used by @precisionOverride and @toleranceOverride, for
+    # example.
+    # NOTE: "rel_tol" lets classes and generated tests set minimum
+    # rtol values when comparing tensors. Used by @toleranceOverride, for example.
+    _precision: float = 0
+    _rel_tol: float = 0
+
+    # Toggles whether to assert that `torch.get_default_dtype()` returns
+    # `torch.float` when `setUp` and `tearDown` are called.
+    _default_dtype_check_enabled: bool = False
+
+    # Always use difflib to print diffs on multi line equality.
+    # Undocumented feature in unittest
+    _diffThreshold = sys.maxsize
+    maxDiff = None
+
+    # Checker to terminate the test suite early if an unrecoverable failure occurs.
+    def _should_stop_test_suite(self):
+        if torch.cuda.is_initialized():
+            # A CUDA device-side error will cause subsequent test cases to fail.
+            # Stop the entire test suite if we catch a RuntimeError during torch.cuda.synchronize().
+            try:
+                torch.cuda.synchronize()
+            except RuntimeError as rte:
+                print("TEST SUITE EARLY TERMINATION due to torch.cuda.synchronize() failure", file=sys.stderr)
+                print(str(rte), file=sys.stderr)
+                return True
+            return False
+        else:
+            return False
+
+    @property
+    def precision(self) -> float:
+        return self._precision
+
+    @precision.setter
+    def precision(self, prec: float) -> None:
+        self._precision = prec
+
+    @property
+    def rel_tol(self) -> float:
+        return self._rel_tol
+
+    @rel_tol.setter
+    def rel_tol(self, prec: float) -> None:
+        self._rel_tol = prec
+
+    _do_cuda_memory_leak_check = False
+    _do_cuda_non_default_stream = False
+
+    # When True, if a test case raises a NotImplementedError, skip it instead of
+    # failing it.
+    _ignore_not_implemented_error = False
+
+    def __init__(self, method_name='runTest', methodName='runTest'):
+        # methodName is the correct naming in unittest and testslide uses keyword arguments.
+        # So we need to use both to 1) not break BC and, 2) support testslide.
+        if methodName != "runTest":
+            method_name = methodName
+        super().__init__(method_name)
+
+        test_method = getattr(self, method_name, None)
+        if test_method is not None:
+            # Wraps the tested method if we should do CUDA memory check.
+            if TEST_CUDA_MEM_LEAK_CHECK:
+                self._do_cuda_memory_leak_check &= getattr(test_method, '_do_cuda_memory_leak_check', True)
+                # FIXME: figure out the flaky -1024 anti-leaks on windows. See #8044
+                if self._do_cuda_memory_leak_check and not IS_WINDOWS:
+                    self.wrap_with_cuda_policy(method_name, self.assertLeaksNoCudaTensors)
+
+            # Wraps the tested method if we should enforce non default CUDA stream.
+            self._do_cuda_non_default_stream &= getattr(test_method, '_do_cuda_non_default_stream', True)
+            if self._do_cuda_non_default_stream and not IS_WINDOWS:
+                self.wrap_with_cuda_policy(method_name, self.enforceNonDefaultStream)
+
+            if self._ignore_not_implemented_error:
+                self.wrap_with_policy(method_name, lambda: skip_exception_type(NotImplementedError))
+
+            if PRINT_REPRO_ON_FAILURE:
+                try:
+                    def _get_rel_test_path(abs_test_path):
+                        # Attempt to get relative path based on the "test" dir.
+                        # In CI, the working dir is not guaranteed to be the base repo dir so
+                        # we can't just compute relative path from that.
+                        parts = Path(abs_test_path).parts
+                        for i, part in enumerate(parts):
+                            if part == "test":
+                                base_dir = os.path.join(*parts[:i]) if i > 0 else ''
+                                return os.path.relpath(abs_test_path, start=base_dir)
+
+                        # Can't determine containing dir; just return the test filename.
+                        # The path isn't strictly correct but it's arguably better than nothing.
+                        return os.path.split(abs_test_path)[1]
+
+                    abs_test_path = inspect.getfile(type(self))
+                    test_filename = _get_rel_test_path(abs_test_path)
+                    class_name = type(self).__name__
+                    test_run_cmd = f"python {test_filename} {class_name}.{method_name}"
+                    env_var_prefix = TestEnvironment.repro_env_var_prefix()
+                    repro_parts = [env_var_prefix, test_run_cmd]
+                    self.wrap_with_policy(
+                        method_name,
+                        lambda repro_parts=repro_parts: print_repro_on_failure(repro_parts))
+                except Exception as e:
+                    # Don't fail entirely if we can't get the test filename
+                    log.info("could not print repro string", extra=str(e))  # type: ignore[arg-type]
+
+    def assertLeaksNoCudaTensors(self, name=None):
+        name = self.id() if name is None else name
+        return CudaMemoryLeakCheck(self, name)
+
+    def enforceNonDefaultStream(self):
+        return CudaNonDefaultStream()
+
+    def _remove_ansi_escape(self, input):
+        # 7-bit C1 ANSI sequences
+        ansi_escape = re.compile(r'''
+            \x1B  # ESC
+            (?:   # 7-bit C1 Fe (except CSI)
+                [@-Z\\-_]
+            |     # or [ for CSI, followed by a control sequence
+                \[
+                [0-?]*  # Parameter bytes
+                [ -/]*  # Intermediate bytes
+                [@-~]   # Final byte
+            )
+        ''', re.VERBOSE)
+        return ansi_escape.sub('', input)
+
+    def remove_comment_lines(self, input_string):
+        lines = input_string.split('\n')
+        filtered_lines = [line for line in lines if not line.strip().startswith('#')]
+        return '\n'.join(filtered_lines)
+
+    def remove_empty_lines(self, input_string):
+        lines = input_string.split('\n')
+        filtered_lines = [line for line in lines if not line.strip() == '']
+        return '\n'.join(filtered_lines)
+
+    # When ignore_comments is True, lines that start with '#' (after being stripped) are ignored.
+    def assertExpectedInline(self, actual, expect, skip=0, ignore_comments=False, ignore_empty_lines=False):
+        actual = actual if isinstance(actual, str) else str(actual)
+        actual = self._remove_ansi_escape(actual)
+        expect = self._remove_ansi_escape(expect)
+        if ignore_comments:
+            actual = self.remove_comment_lines(actual)
+            expect = self.remove_comment_lines(expect)
+
+        if ignore_empty_lines:
+            actual = self.remove_empty_lines(actual)
+            expect = self.remove_empty_lines(expect)
+
+        return super().assertExpectedInline(actual if isinstance(actual, str) else str(actual), expect, skip + 1)
+
+    # Munges exceptions that internally contain stack traces, using munge_exc
+    def assertExpectedInlineMunged(
+        self, exc_type, callable, expect, *, skip=0, suppress_suffix=True, post_munge=None,
+    ):
+        try:
+            callable()
+        except exc_type as e:
+            munged = munge_exc(e, suppress_suffix=suppress_suffix, skip=skip + 1)
+            if post_munge:
+                munged = post_munge(munged)
+            self.assertExpectedInline(
+                munged, expect, skip=skip + 1
+            )
+            return
+        self.fail(msg="Did not raise when expected to")
+
+    def assertLogs(self, logger=None, level=None):
+        if logger is None:
+            logger = logging.getLogger("torch")
+        return super().assertLogs(logger, level)
+
+    def assertNoLogs(self, logger=None, level=None):
+        if logger is None:
+            logger = logging.getLogger("torch")
+        return super().assertNoLogs(logger, level)
+
+    def wrap_with_cuda_policy(self, method_name, policy):
+        test_method = getattr(self, method_name)
+        # the import below may initialize CUDA context, so we do it only if
+        # self._do_cuda_memory_leak_check or self._do_cuda_non_default_stream
+        # is True.
+        # TODO: sure looks like we unconditionally initialize the context here
+        # -- ezyang
+        from torch.testing._internal.common_cuda import TEST_CUDA
+        fullname = self.id().lower()  # class_name.method_name
+        if TEST_CUDA and ('gpu' in fullname or 'cuda' in fullname):
+            setattr(self, method_name, self.wrap_method_with_policy(test_method, policy))
+
+    def wrap_with_policy(self, method_name, policy):
+        test_method = getattr(self, method_name)
+        setattr(self, method_name, self.wrap_method_with_policy(test_method, policy))
+
+    # A policy is a zero-argument function that returns a context manager.
+    # We don't take the context manager directly as it may be necessary to
+    # construct it once per test method
+    def wrap_method_with_policy(self, method, policy):
+        # Assumes that `method` is the tested function in `self`.
+        # NOTE: Python exceptions (e.g., unittest.SkipTest) keep objects in scope
+        #       alive, so this cannot be done in setUp and tearDown because
+        #       tearDown is run unconditionally no matter whether the test
+        #       passes or not. For the same reason, we can't wrap the `method`
+        #       call in try-finally and always do the check.
+        @wraps(method)
+        def wrapper(self, *args, **kwargs):
+            with policy():
+                method(*args, **kwargs)
+        return types.MethodType(wrapper, self)
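+    # A minimal sketch (hypothetical, not part of the suite) of a policy: a
+    # zero-argument callable that returns a fresh context manager each time.
+    #
+    #   def no_grad_policy():
+    #       return torch.no_grad()
+    #
+    #   self.wrap_with_policy('test_something', no_grad_policy)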
+
+    def wrap_with_cuda_memory_check(self, method):
+        return self.wrap_method_with_policy(method, self.assertLeaksNoCudaTensors)
+
+    def _dynamo_test_key(self):
+        return f"{self.__class__.__name__}.{self._testMethodName}"
+
+    def compile_fn(self, fn, backend, nopython):
+        # Allows subclasses to control compilation
+        return torch._dynamo.optimize(backend, nopython=nopython)(fn)
+
+    def _run_custom(self, result=None):
+        using_unittest = isinstance(result, unittest.TestResult)
+
+        super_run = super().run
+        test_cls = super_run.__self__  # type: ignore[attr-defined]
+
+        # Are we compiling?
+        compiled = TEST_WITH_TORCHDYNAMO or TEST_WITH_AOT_EAGER or TEST_WITH_TORCHINDUCTOR
+        # Is the class strict and compiling?
+        strict_default = False
+        should_reset_dynamo = False
+
+        # We disable size_asserts for test_ops since some tests fail
+        # due to a mismatch of strides returned from eager vs. meta kernels.
+        # Only some of the ops have this problem, but since tests in
+        # test_ops.py are parametrized, it's hard to do this specifically
+        # for the affected ops.
+        # It's not a big deal since these problems are captured by
+        # test_torchinductor_opinfo.py as well.
+        should_disable_size_asserts = False
+        if compiled:
+            try:
+                path = inspect.getfile(type(test_cls))
+                full_path = os.path.abspath(path)
+                match = re.match(r".*/test/(.*).py", full_path)
+                if match is not None:
+                    filename = match.group(1)
+                    if TEST_WITH_TORCHINDUCTOR:
+                        from .dynamo_test_failures import FIXME_inductor_non_strict
+                        strict_default = filename not in FIXME_inductor_non_strict
+                        should_reset_dynamo = True
+
+                        if filename == "test_ops":
+                            should_disable_size_asserts = True
+                    else:
+                        strict_default = True
+            # inspect.getfile can fail with these
+            except (OSError, TypeError):
+                pass
+            if "STRICT_DEFAULT" in os.environ:
+                if os.environ["STRICT_DEFAULT"] == "1":
+                    strict_default = True
+
+        strict_mode = False
+        if compiled:
+            test_method = getattr(self, self._testMethodName)
+            if hasattr(test_method, "dynamo_strict"):
+                strict_mode = test_method.dynamo_strict
+            elif hasattr(test_cls, "dynamo_strict"):
+                strict_mode = test_cls.dynamo_strict
+            else:
+                strict_mode = strict_default
+        nopython = getattr(test_cls, "dynamo_strict_nopython", False) and compiled
+
+        if strict_mode or should_reset_dynamo:
+            torch._dynamo.reset()
+
+        torch.compiler.set_stance("default")
+
+        # TODO: Remove this; it is grandfathered in because we previously
+        # suppressed errors on the test suite.
+        # When strict mode is False, suppress_errors is True.
+        if compiled:
+            suppress_errors = not strict_mode
+        else:
+            suppress_errors = torch._dynamo.config.suppress_errors
+
+        maybe_disable_size_asserts = (
+            torch._inductor.config.patch(size_asserts=False)
+            if should_disable_size_asserts
+            else contextlib.nullcontext()
+        )
+
+        with unittest.mock.patch("torch._dynamo.config.suppress_errors", suppress_errors), maybe_disable_size_asserts:
+            if TEST_WITH_AOT_EAGER:
+                super_run = self.compile_fn(super_run, "aot_eager_decomp_partition", nopython)
+            elif TEST_WITH_TORCHDYNAMO or TEST_WITH_TORCHINDUCTOR:
+                if TEST_WITH_TORCHINDUCTOR:
+                    super_run = self.compile_fn(super_run, "inductor", nopython)
+                else:
+                    # Assume eager-generated GraphModules will not error out.
+                    # If they do, this is probably a Dynamo bug!
+                    super_run = self.compile_fn(super_run, "eager_noexcept", nopython)
+
+                key = self._dynamo_test_key()
+
+                def expect_failure(f, file_name):
+                    @wraps(f)
+                    def wrapper(*args, **kwargs):
+                        try:
+                            f(*args, **kwargs)
+                        except BaseException as e:
+                            self.skipTest(e)
+                        raise RuntimeError(f"Unexpected success, please remove `{file_name}`")
+                    return wrapper
+
+                if TEST_WITH_TORCHINDUCTOR:
+                    subdir = "test/inductor_expected_failures"
+                    from .dynamo_test_failures import inductor_expected_failures as expected_failures
+                else:
+                    subdir = "test/dynamo_expected_failures"
+                    from .dynamo_test_failures import dynamo_expected_failures as expected_failures
+
+                if key in expected_failures:
+                    method = getattr(self, self._testMethodName)
+                    file_name = os.path.join(subdir, key)
+                    setattr(self, self._testMethodName, expect_failure(method, file_name))
+
+                def ignore_failure(f, file_name):
+                    @wraps(f)
+                    def wrapper(*args, **kwargs):
+                        try:
+                            f(*args, **kwargs)
+                        except BaseException as e:
+                            self.skipTest(e)
+                        method = getattr(self, self._testMethodName)
+                        if getattr(method, "__unittest_expecting_failure__", False):
+                            self.skipTest("unexpected success")
+                        else:
+                            self.skipTest(f"This test passed, maybe we can remove `{file_name}`")
+                    return wrapper
+
+                if TEST_WITH_TORCHINDUCTOR:
+                    subdir = "test/inductor_skips"
+                    from .dynamo_test_failures import inductor_skips as skips
+                else:
+                    subdir = "test/dynamo_skips"
+                    from .dynamo_test_failures import dynamo_skips as skips
+
+                if key in skips:
+                    method = getattr(self, self._testMethodName)
+                    file_name = os.path.join(subdir, key)
+                    setattr(self, self._testMethodName, ignore_failure(method, file_name))
+
+                from .dynamo_test_failures import compiled_autograd_skips
+                if torch._dynamo.config.compiled_autograd and key in compiled_autograd_skips:
+                    # Still run the test, but with compiled autograd disabled
+                    super_run = runWithoutCompiledAutograd()(super_run)
+
+            super_run(result=result)
+
+        if strict_mode or should_reset_dynamo:
+            torch._dynamo.reset()
+
+        # Terminate the test suite early if necessary.  If using pytest, use the -x flag instead.
+        if using_unittest and self._should_stop_test_suite():
+            if result.wasSuccessful():
+                case = TestCase()
+                if TEST_SAVE_XML is not None:
+                    # This is a bit hacky: XMLRunner modifies the expected type from TestCase to TestInfo.
+                    # Create dummy TestInfo to record results correctly
+                    from xmlrunner.result import _TestInfo  # type: ignore[import]
+                    case = _TestInfo(result, case)
+                    case.output = _TestInfo.ERROR  # type: ignore[attr-defined]
+                    case.elapsed_time = 0.0  # type: ignore[attr-defined]
+                    case.test_description = "TestSuiteEarlyFailure"  # type: ignore[attr-defined]
+                # This shouldn't really happen, but if it does, add a fake failure.
+                # For more details see https://github.com/pytorch/pytorch/issues/71973
+                result.failures.append((case, "TestSuite execution was aborted early"))
+                assert result.wasSuccessful() is False
+            result.stop()
+
+
+    def run(self, result=None):
+        with contextlib.ExitStack() as stack:
+            if TEST_WITH_CROSSREF:
+                stack.enter_context(CrossRefMode())
+            self._run_custom(
+                result=result,
+            )
+
+    def setUp(self):
+        check_if_enable(self)
+        set_rng_seed(SEED)
+
+        # Save the global sparse tensor invariant-checking state so that it
+        # can be restored in tearDown:
+        self._check_invariants = torch.sparse.check_sparse_tensor_invariants.is_enabled()
+
+        # Enable invariant checks for all sparse tensor constructions,
+        # including the unsafe ones. If this is not desired for some
+        # test case, pass the optional check_invariants=False argument to
+        # the sparse tensor constructors or use the
+        # @torch.sparse.check_sparse_tensor_invariants(False)
+        # decorator to disable the invariant checks.
+        torch.sparse.check_sparse_tensor_invariants.enable()
+
+        if self._default_dtype_check_enabled:
+            assert torch.get_default_dtype() == torch.float
+
+        # save the grad-enabled state so it can be restored at the end of the test
+        self._prev_grad_state = torch.is_grad_enabled()
+
+    def tearDown(self):
+        # There exist test cases that override the TestCase.setUp
+        # definition, so we cannot assume that the _check_invariants
+        # attribute is defined in general.
+        if hasattr(self, '_check_invariants'):
+            # Restore the global check sparse tensor invariants state
+            if self._check_invariants:
+                torch.sparse.check_sparse_tensor_invariants.enable()
+            else:
+                torch.sparse.check_sparse_tensor_invariants.disable()
+
+        if self._default_dtype_check_enabled:
+            assert torch.get_default_dtype() == torch.float
+
+        # attribute may not be defined, per above
+        if hasattr(self, '_prev_grad_state'):
+            torch.set_grad_enabled(self._prev_grad_state)
+
+    @staticmethod
+    def _make_crow_indices(n_rows, n_cols, nnz,
+                           *, device, dtype, random=True):
+        """Return crow_indices of a CSR tensor with size (n_rows, n_cols) and
+        the number of specified elements nnz.
+
+        If random is True, the column counts of rows are in random
+        order. Otherwise, the column counts of rows are defined by the
+        sampling method described below.
+
+        Sampling method
+        ---------------
+
+        The sampling method used here was introduced in
+        https://pearu.github.io/csr_sampling.html; below we give
+        only an overall description of the method.
+
+        Notice that crow_indices can be defined as cumsum(counts)
+        where counts is a sequence of non-negative integers satisfying
+        the following conditions:
+
+          len(counts) == n_rows + 1
+          counts.max() <= n_cols
+
+        where counts[i + 1] is interpreted as the number of specified
+        elements in the i-th row.
+
+        The sampling method aims at increasing the diversity of
+        CSR samples, that is, a CSR sample should contain (i) rows
+        that are completely filled, (ii) rows with no elements at all, and
+        (iii) rows that are partially filled. At the same time, for
+        the given total number of specified elements (nnz), there
+        should be minimal preference for rows with a given number of
+        elements.  To achieve this, the sampling method is built on
+        a sawteeth model for counts. In the simplest case, we
+        would have
+
+          counts = arange(n_rows + 1) % (n_cols + 1)
+
+        which has an equal number of all possible column counts per row.
+        This formula can be used only for specific input values of
+        n_rows, n_cols, and nnz. To generalize this model to any
+        combination of inputs, the counts model above is extended
+        with an incomplete sawtooth and with right and lower
+        rectangular parts that guarantee that
+
+          counts.sum() == nnz
+
+        for any combination of n_rows, n_cols, and nnz. Basically,
+        we find a maximal window in the (n_rows + 1, n_cols + 1) grid
+        that is able to hold a sequence of sawteeth and a so-called
+        final correction, while the part of the grid outside the window
+        is filled with counts to meet the nnz constraint exactly.
+        """
+        assert 0 <= nnz <= n_rows * n_cols, (nnz, n_rows, n_cols)
+
+        def sawteeth(n, m):
+            # Return the total number of counts in the sequence of
+            # sawteeth, where n and m define a window in the (n_rows + 1,
+            # n_cols + 1) rectangle in which the sequence of sawteeth
+            # fits exactly.
+            M = (n_cols - m) * (n_cols - m + 1) // 2
+            K = (n_rows - n) % (n_cols - m + 1)
+            return M * ((n_rows - n) // (n_cols - m + 1)) + K * (K - 1) // 2
+
+        # Unlike in the original method description, counts here
+        # has the leading 0 required by crow_indices:
+        counts = torch.zeros(n_rows + 1, dtype=dtype, device=torch.device('cpu'))
+
+        n = m = 0
+        N = sawteeth(n, m)
+        if N and nnz >= max(N, n_cols):
+            # determine the width of the sawteeth window. We use bisection to solve
+            #   N(n, 0) == 0 or nnz - n * n_cols < max(N(n, 0), n_cols)
+            # for n
+            n_left = n
+            n_right = n_rows - 1
+            N_right = sawteeth(n_right, m)
+            while n_right - n_left > 1:
+                n_middle = (n_left + n_right) // 2
+                N_middle = sawteeth(n_middle, m)
+                if N_middle == 0 or nnz - n_middle * n_cols < max(N_middle, n_cols):
+                    n_right, N_right = n_middle, N_middle
+                else:
+                    n_left = n_middle
+            n, N = n_right, N_right
+            # fill the right rectangle with counts:
+            assert n
+            counts[-n:].fill_(n_cols)
+
+        if N and nnz - n * n_cols >= max(N, n_rows - n):
+            # determine the height of the sawteeth window. We use bisection to solve
+            #   N(n, m) == 0 or nnz - n * n_cols - m * (n_rows - n) < max(N(n, m), n_rows - n)
+            # for m.
+            m_left = m
+            m_right = n_cols - 1
+            N_right = sawteeth(n, m_right)
+            while m_right - m_left > 1:
+                m_middle = (m_left + m_right) // 2
+                N_middle = sawteeth(n, m_middle)
+                if N_middle == 0 or nnz - n * n_cols - m_middle * (n_rows - n) < max(N_middle, n_rows - n):
+                    m_right, N_right = m_middle, N_middle
+                else:
+                    m_left = m_middle
+            m, N = m_right, N_right
+            # fill the bottom rectangle with counts:
+            assert m
+            counts[1:n_rows - n + 1].fill_(m)
+
+        if N:
+            # fill the sawteeth window with counts
+            q, r = divmod(nnz - n * n_cols - m * (n_rows - n),
+                          (n_cols - m) * (n_cols - m + 1) // 2)
+            p = 1 + q * (n_cols - m + 1)
+            k = math.isqrt(2 * r)
+            if k * (k + 1) > 2 * r:
+                k -= 1
+            corr = r - k * (k + 1) // 2
+            assert not ((p > 1) and (m > 0))  # full sawteeth are never on top of a bottom rectangle
+            # sequence of full sawteeth:
+            counts[1:p] = torch.arange(p - 1, dtype=dtype, device=counts.device) % (n_cols - m + 1)
+            # incomplete sawtooth:
+            counts[p:p + k + 1] += torch.arange(k + 1, dtype=dtype, device=counts.device)
+        else:
+            # given input does not support sawteeth
+            p = 1
+            corr = nnz - n * n_cols - m * (n_rows - n)
+
+        # correction that will guarantee counts.sum() == nnz:
+        counts[p] += corr
+
+        if random:
+            # randomize crow_indices by shuffling the sawteeth
+            # sequence:
+            perm = torch.randperm(n_rows, device=counts.device)
+            counts[1:] = counts[1:][perm]
+
+        # compute crow_indices:
+        crow_indices = counts
+        crow_indices.cumsum_(dim=0)
+        return crow_indices.to(device=device)
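+    # A small sketch (hypothetical) of the guarantees above: for a 3 x 4 CSR
+    # pattern with nnz=5, the returned crow_indices has n_rows + 1 entries,
+    # starts at 0, is non-decreasing, and ends at nnz.
+    #
+    #   crow = TestCase._make_crow_indices(3, 4, 5, device='cpu', dtype=torch.int64)
+    #   assert crow[0] == 0 and crow[-1] == 5 and len(crow) == 4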
+
+    def genSparseCompressedTensor(self, size, nnz, *, layout, device, dtype, index_dtype, blocksize=(), dense_dims=0):
+        from operator import mul
+        from functools import reduce
+        sparse_dim = 2
+        assert all(size[d] > 0 for d in range(len(size))) or nnz == 0, 'invalid arguments'
+        assert len(size) >= sparse_dim
+        if blocksize:
+            assert len(blocksize) == 2, (size, blocksize)
+            assert size[-2 - dense_dims] % blocksize[0] == 0, (size, blocksize)
+            assert size[-1 - dense_dims] % blocksize[1] == 0, (size, blocksize)
+            blocksize0, blocksize1 = blocksize
+        else:
+            blocksize0 = blocksize1 = 1
+
+        size = tuple(size)
+        dense_size = size[(len(size) - dense_dims):]
+
+        def random_sparse_compressed(n_compressed_dims, n_plain_dims, nnz):
+            compressed_indices = self._make_crow_indices(n_compressed_dims, n_plain_dims, nnz, device=device, dtype=index_dtype)
+            plain_indices = torch.zeros(nnz, dtype=index_dtype, device=device)
+            for i in range(n_compressed_dims):
+                count = compressed_indices[i + 1] - compressed_indices[i]
+                plain_indices[compressed_indices[i]:compressed_indices[i + 1]], _ = torch.sort(
+                    torch.randperm(n_plain_dims, dtype=index_dtype, device=device)[:count])
+            low = -1 if dtype != torch.uint8 else 0
+            high = 1 if dtype != torch.uint8 else 2
+            values = make_tensor((nnz,) + blocksize + dense_size, device=device, dtype=dtype, low=low, high=high)
+            return values, compressed_indices, plain_indices
+
+        batch_shape = size[:-2 - dense_dims]
+        n_batch = reduce(mul, batch_shape, 1)
+
+        if layout in {torch.sparse_csr, torch.sparse_bsr}:
+            n_compressed_dims, n_plain_dims = size[-2 - dense_dims] // blocksize0, size[-1 - dense_dims] // blocksize1
+        else:
+            n_compressed_dims, n_plain_dims = size[-1 - dense_dims] // blocksize1, size[-2 - dense_dims] // blocksize0
+        blocknnz = nnz // (blocksize0 * blocksize1)
+        sparse_tensors = [random_sparse_compressed(n_compressed_dims, n_plain_dims, blocknnz) for _ in range(n_batch)]
+        sparse_tensors_it = map(list, zip(*sparse_tensors))
+
+        values = torch.stack(next(sparse_tensors_it)).reshape(*batch_shape, blocknnz, *blocksize, *dense_size)
+        compressed_indices = torch.stack(next(sparse_tensors_it)).reshape(*batch_shape, -1)
+        plain_indices = torch.stack(next(sparse_tensors_it)).reshape(*batch_shape, -1)
+        return torch.sparse_compressed_tensor(compressed_indices, plain_indices,
+                                              values, size=size, dtype=dtype, layout=layout, device=device)
+
+    def genSparseCSRTensor(self, size, nnz, *, device, dtype, index_dtype, dense_dims=0):
+        return self.genSparseCompressedTensor(size, nnz, layout=torch.sparse_csr, device=device,
+                                              dtype=dtype, index_dtype=index_dtype, blocksize=(), dense_dims=dense_dims)
+
+    def genSparseCSCTensor(self, size, nnz, *, device, dtype, index_dtype, dense_dims=0):
+        return self.genSparseCompressedTensor(size, nnz, layout=torch.sparse_csc, device=device,
+                                              dtype=dtype, index_dtype=index_dtype, blocksize=(), dense_dims=dense_dims)
+
+    def genSparseBSRTensor(self, size, blocksize, nnz, *, device, dtype, index_dtype, dense_dims=0):
+        assert len(blocksize) == 2
+        return self.genSparseCompressedTensor(size, nnz, layout=torch.sparse_bsr, device=device,
+                                              dtype=dtype, index_dtype=index_dtype, blocksize=blocksize, dense_dims=dense_dims)
+
+    def genSparseBSCTensor(self, size, blocksize, nnz, *, device, dtype, index_dtype, dense_dims=0):
+        assert len(blocksize) == 2
+        return self.genSparseCompressedTensor(size, nnz, layout=torch.sparse_bsc, device=device,
+                                              dtype=dtype, index_dtype=index_dtype, blocksize=blocksize, dense_dims=dense_dims)
+
+    def genSparseTensor(self, size, sparse_dim, nnz, is_uncoalesced, device, dtype):
+        # Assert not given impossible combination, where the sparse dims have
+        # empty numel, but nnz > 0 makes the indices containing values.
+        assert all(size[d] > 0 for d in range(sparse_dim)) or nnz == 0, 'invalid arguments'
+
+        v_size = [nnz] + list(size[sparse_dim:])
+        v = make_tensor(v_size, device=device, dtype=dtype, low=-1, high=1)
+        i = torch.rand(sparse_dim, nnz, device=device)
+        i.mul_(torch.tensor(size[:sparse_dim]).unsqueeze(1).to(i))
+        i = i.to(torch.long)
+        if is_uncoalesced:
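+            # the two halves below overlap, so some indices appear twice,
+            # keeping the constructed COO tensor uncoalesced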
+            i1 = i[:, :(nnz // 2), ...]
+            i2 = i[:, :((nnz + 1) // 2), ...]
+            i = torch.cat([i1, i2], 1)
+        x = torch.sparse_coo_tensor(i, v, torch.Size(size), dtype=dtype, device=device)
+
+        if not is_uncoalesced:
+            x = x.coalesce()
+        else:
+            # FIXME: `x` is a sparse view of `v`. Currently rebase_history for
+            #        sparse views is not implemented, so this workaround is
+            #        needed for inplace operations done on `x`, e.g., copy_().
+            #        Remove after implementing something equivalent to CopySlice
+            #        for sparse views.
+            # NOTE: We do clone() after detach() here because we need to be able to change size/storage of x afterwards
+            x = x.detach().clone()._coalesced_(False)
+        return x, x._indices().clone(), x._values().clone()
+
+    def generate_simple_inputs(self, layout,
+                               device=None,
+                               dtype=None,
+                               index_dtype=None,
+                               pin_memory=None,
+                               members_pin_memory=None,
+                               enable_batch=True,
+                               enable_hybrid=True,
+                               enable_zero_sized=True,
+                               enable_non_contiguous_indices=True,
+                               enable_non_contiguous_values=True,
+                               enable_batch_variable_nse=False,
+                               output_tensor=True,
+                               patterns=None):
+        """Generator of simple inputs for tensor constructors of the given layout.
+
+        The generated tensor inputs have the following properties:
+
+        - tensor shapes are minimal but not trivial
+        - tensor values are sorted sequences for COO and CSR formats, e.g. [1, 2, 3, 4]
+        - the generated tensors represent the same mathematical tensor for all layouts
+        - the generated tensors include regular, zero-sized, and optionally batched and/or hybrid tensors.
+        - the generated tensors include tensors that are contiguous or non-contiguous in both indices and values
+
+        If output_tensor is True, yield tensors with the given
+        layout. Otherwise, yield inputs to the corresponding tensor
+        constructors:
+
+          - sparse compressed input is defined as
+            (compressed_indices, plain_indices, values), dict(size=expected_size_from_shape_inference, device=device, dtype=dtype,
+                                                              pin_memory=pin_memory)
+
+          - sparse COO input is defined as
+            (indices, values), dict(size=expected_size_from_shape_inference, device=device, dtype=dtype, pin_memory=pin_memory)
+
+          - strided input is defined as
+            (values,), dict(device=device, dtype=dtype)
+        """
+        if index_dtype is None:
+            index_dtype = torch.int64
+
+        is_compressed_sparse_layout = layout in {torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}
+
+        if output_tensor:
+            for args, kwargs in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype,
+                                                            pin_memory=pin_memory,
+                                                            enable_batch=enable_batch, enable_hybrid=enable_hybrid,
+                                                            enable_zero_sized=enable_zero_sized,
+                                                            enable_non_contiguous_indices=enable_non_contiguous_indices,
+                                                            enable_non_contiguous_values=enable_non_contiguous_values,
+                                                            enable_batch_variable_nse=enable_batch_variable_nse,
+                                                            output_tensor=False):
+                if members_pin_memory:
+                    args = tuple(a.pin_memory() for a in args)
+                if layout is torch.strided:
+                    assert len(args) == 1
+                    size = kwargs.pop('size', None)  # to ensure that a zero-sized tensor has the desired shape
+                    assert size is not None
+                    if pin_memory:
+                        yield args[0].reshape(size).pin_memory()
+                    else:
+                        yield args[0].reshape(size)
+                elif layout is torch.sparse_coo:
+                    yield torch.sparse_coo_tensor(*args, **kwargs)
+                elif is_compressed_sparse_layout:
+                    kwargs.update(layout=layout)
+                    yield torch.sparse_compressed_tensor(*args, **kwargs)
+                else:
+                    assert 0  # unreachable
+            return
+
+        def get_blockpattern(pattern, blocksize):
+            basesize = pattern.shape
+            assert basesize[0] % blocksize[0] == 0, (basesize, blocksize)
+            assert basesize[1] % blocksize[1] == 0, (basesize, blocksize)
+            blockpattern = pattern.reshape(-1,
+                                           blocksize[0],
+                                           basesize[1] // blocksize[1],
+                                           blocksize[1]).transpose(-3, -2).any(-1).any(-1)
+            block_ids = torch.arange(1, blockpattern.numel() + 1).reshape(blockpattern.shape)
+            return (blockpattern != 0) * block_ids
+
+        def get_sparse_data(pattern):
+            basesize = pattern.shape
+            assert len(basesize) == 2, basesize  # pattern is expected to be a matrix
+
+            # We cannot use `torch.sparse_xyz_tensor(pattern)` to
+            # compute the sparse layout indices and values because
+            # generate_simple_inputs is used to generate the inputs to
+            # test `torch.sparse_xyz_tensor` factory functions, so
+            # we'll compute the indices and values independently of
+            # the factory functions.
+
+            indices = torch.where(pattern != 0)
+            coo_indices = torch.stack(indices)
+            crow_indices = torch.zeros(basesize[0] + 1, dtype=torch.int64)
+            crow_indices[1:] = torch.cumsum(coo_indices[0].bincount(minlength=basesize[0]), 0)
+            col_indices = coo_indices[1]
+            strided_values = torch.zeros(basesize, dtype=torch.int64)
+
+            # the property of `values == range(1, 1+nnz)` is used in
+            # get_sparse_data_with_block to relate BSR and BSC values,
+            # so, don't change the following line:
+            values = torch.arange(1, 1 + len(indices[0]), dtype=torch.int64)
+            strided_values[indices] = values
+
+            indices_T = torch.where(pattern.transpose(0, 1) != 0)
+            coo_indices_T = torch.stack(indices_T)
+            ccol_indices = torch.zeros(basesize[1] + 1, dtype=torch.int64)
+            ccol_indices[1:] = torch.cumsum(coo_indices_T[0].bincount(minlength=basesize[1]), 0)
+            row_indices = coo_indices_T[1]
+            csc_values = strided_values.transpose(0, 1)[indices_T]
+
+            return {torch.sparse_coo: (coo_indices, values),
+                    torch.sparse_csr: (crow_indices, col_indices, values),
+                    torch.sparse_csc: (ccol_indices, row_indices, csc_values),
+                    torch.strided: (strided_values,)}
+
+        def get_sparse_data_with_block(pattern, blocksize):
+            nonblock_data = get_sparse_data(pattern)
+            blockpattern = get_blockpattern(pattern, blocksize)
+            block_data = get_sparse_data(blockpattern)
+
+            strided_values = nonblock_data[torch.strided][0]
+            block_indices = block_data[torch.sparse_coo][0]
+            bsr_values = torch.stack([strided_values[bi * blocksize[0]:(bi + 1) * blocksize[0],
+                                                     bj * blocksize[1]:(bj + 1) * blocksize[1]]
+                                      for bi, bj in block_indices.transpose(0, 1)])
+
+            # here we use the property `values == range(1, 1+nnz)` and
+            # the relation of `values` to `csc_values` (see get_sparse_data)
+            # to get the BSC blocks by reordering the BSR blocks:
+            bsc_values = bsr_values[block_data[torch.sparse_csc][2] - 1]
+
+            return {torch.sparse_bsr: (*block_data[torch.sparse_csr][:2], bsr_values),
+                    torch.sparse_bsc: (*block_data[torch.sparse_csc][:2], bsc_values),
+                    **nonblock_data}
+
+        def get_batch_sparse_data(pattern, blocksize):
+            size = pattern.shape
+            if len(size) <= 2:  # non-batch
+                return get_sparse_data_with_block(pattern, blocksize)
+
+            # batch data is created recursively:
+            batch_data = {}  # type: ignore[var-annotated]
+            for i, item in enumerate(pattern):
+                for layout, d in get_batch_sparse_data(item, blocksize).items():
+                    target = batch_data.get(layout)
+                    if layout is torch.sparse_coo:
+                        # a "batch COO" means a COO with the leading
+                        # sparse dimensions interpreted as batch
+                        # dimensions
+                        ext_coo_indices1 = torch.cat((torch.full((1, len(d[1])), i, dtype=torch.int64), d[0]))
+                        if target is None:
+                            target = batch_data[layout] = (ext_coo_indices1, d[1])
+                        else:
+                            target[0].set_(torch.cat((target[0], ext_coo_indices1), 1))  # type: ignore[call-overload]
+                            target[1].set_(torch.cat((target[1], d[1])))
+                    else:
+                        if target is None:
+                            target = batch_data[layout] = tuple(d[j].unsqueeze(0) for j in range(len(d)))
+                        else:
+                            for j in range(len(d)):
+                                target[j].set_(torch.cat((target[j], d[j].unsqueeze(0))))  # type: ignore[call-overload]
+            return batch_data
+
+        def generate_values(base, densesize):
+            """Generates a tensor of shape densesize with values equal to
+
+              base + i_1 * 10^0 + ... + i_d * 10^{d - 1}
+
+            at indices i_1, ..., i_d (with 0 <= i_j < densesize[j] for any 1 <= j <=
+            len(densesize))
+
+            This mapping produces unique values as long as
+            densesize[i] < 10 for all i in range(len(densesize)).
+            """
+
+            if not densesize:
+                return base
+            if not isinstance(base, int) and base.ndim > 0:
+                return torch.stack([generate_values(b, densesize) for b in base])
+            if base == 0:
+                return torch.zeros(densesize, dtype=torch.int64)
+            r = torch.arange(densesize[0], dtype=torch.int64)
+            for i, d in enumerate(densesize[1:]):
+                y = torch.arange(d, dtype=torch.int64) * (10 ** (i + 1))
+                r = r[..., None] + y[None, ...]
+            r.add_(base)
+            return r
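+        # Worked example (sketch) of the mapping above: generate_values(3, (2, 2))
+        # yields [[3, 13], [4, 14]], i.e. 3 + i_1 * 1 + i_2 * 10 at dense index
+        # (i_1, i_2).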
+
+        if patterns is None:
+            # A pattern is a 3-tuple with the following items:
+            #
+            # - a nested list of integers with depth two or more. The
+            #   integers define the sparsity patterns of the generated
+            #   inputs: zero values correspond to unspecified
+            #   elements/blocks, and non-zero values to the specified
+            #   elements.
+            #
+            #   For debugging convenience, elements with the same
+            #   value typically belong to the same block. However, this
+            #   is not a hard requirement: as long as the shape of a
+            #   pattern is divisible by the block sizes, the pattern is
+            #   valid.
+            #
+            #   If the depth of the list is larger than two, inputs
+            #   with batch dimensions will be generated.
+            #
+            # - a list of 2-tuples of block sizes, used to generate
+            #   BSR/BSC tensors with various block size parameters
+            #
+            # - a list of tuples of dense dimensions, used to generate
+            #   hybrid tensors with various dense dimensions
+            #
+            patterns = [
+                # a simple 3 x 2 tensor: non-hybrid, hybrid with 1 and 2 dense dimensions
+                ([[1, 2, 0],
+                  [1, 0, 3]], [(2, 1), (1, 3)], [(), (2,), (4, 5)]),
+                # 2 x 3 batch of 3 x 2 tensors: non-hybrid and hybrid with 2 dense dimensions
+                ([[[[1, 2, 0],
+                    [1, 0, 3]],
+                   [[1, 2, 3],
+                    [1, 0, 0]],
+                   [[1, 0, 0],
+                    [1, 2, 3]]],
+                  [[[0, 2, 0],
+                    [1, 2, 3]],
+                   [[1, 0, 3],
+                    [1, 2, 0]],
+                   [[1, 2, 3],
+                    [0, 2, 0]]]], [(2, 1), (2, 3)], [(), (2,)]),
+                # tensor with non-trivial blocksize
+                ([[0, 1, 0, 2, 0, 2],
+                  [0, 1, 0, 0, 2, 0],
+                  [3, 3, 3, 0, 0, 0],
+                  [0, 0, 0, 0, 0, 0],
+                  [0, 5, 0, 6, 6, 6],
+                  [5, 0, 5, 6, 6, 6],
+                  [0, 0, 0, 0, 8, 8],
+                  [7, 7, 7, 0, 8, 8]], [(2, 3)], [(), (4, 5)]),
+                # batch tensor with variable NSE
+                # Requires https://github.com/pytorch/pytorch/pull/84843 or similar.
+                ([[[1, 2],
+                   [3, 4]],
+                  [[1, 0],
+                   [0, 0]]], [(1, 1)], ([()] if enable_batch_variable_nse else []))]
+
+        def non_contiguous_copy(t, dim=-1, offset=0):
+            # return a copy of t that is non-contiguous along the
+            # given dimension and with the given storage offset
+            self.assertTrue(t.is_contiguous())
+            if dim < 0:
+                dim = dim + t.ndim
+            assert dim >= 0 and dim < t.ndim
+            step = max(2, offset + 1)
+            tmp = torch.zeros((*t.shape[:dim], t.shape[dim] * step, *t.shape[dim + 1:]), dtype=t.dtype, device=t.device)
+            dim_slices = (*((slice(None),) * dim), slice(offset, None, step))
+            r = tmp[dim_slices].copy_(t)
+            self.assertFalse(r.is_contiguous())
+            self.assertEqual(t, r)
+            return r
+
+        # the main loop of the method:
+        for pattern, blocksizes, densesizes in patterns:
+            if not enable_hybrid:
+                densesizes = [s for s in densesizes if not s]
+            if not (densesizes and blocksizes):
+                continue
+            pattern = torch.tensor(pattern, dtype=torch.int64)
+            if not enable_batch and pattern.ndim > 2:
+                continue
+            for blocksize in blocksizes:
+                data = get_batch_sparse_data(pattern, blocksize)[layout]
+                for densesize in densesizes:
+                    indices = [a.to(device=device, dtype=index_dtype) for a in data[:-1]]
+                    values = generate_values(data[-1], densesize).to(device=device, dtype=dtype)
+                    kwargs = dict(device=device, dtype=dtype, size=pattern.shape + densesize)
+                    if pin_memory is not None:
+                        kwargs.update(pin_memory=pin_memory)
+
+                    yield (*indices, values), kwargs.copy()
+                    if enable_non_contiguous_indices and pattern.ndim > 2:
+                        # sparse compressed indices can be sliced only along batch dimensions
+                        for (dim, offset) in {(0, 1), (-2, 0)}:
+                            indices_copy = [non_contiguous_copy(a, dim=dim, offset=offset) for a in indices]
+                            yield (*indices_copy, values), kwargs.copy()
+
+                            if enable_non_contiguous_values:
+                                values_copy = non_contiguous_copy(values, dim=-1, offset=1)
+                                yield (*indices_copy, values_copy), kwargs.copy()
+
+                    if enable_non_contiguous_values:
+                        values_copy = non_contiguous_copy(values, dim=-1, offset=1)
+                        yield (*indices, values_copy), kwargs.copy()
+
+        # zero-sized tensor inputs, non-batch, non-hybrid/hybrid
+        if enable_zero_sized:
+            for basesize, blocksizes, densesizes in [
+                    ((2, 0), [(1, 2)], [(), (2,), (2, 3)] if enable_hybrid else [()]),
+                    ((0, 2), [(1, 2), (2, 1), (3, 2)], [()]),
+                    ((0, 0), [(1, 2)], [()]),
+            ]:
+                for blocksize in blocksizes:
+                    for densesize in densesizes:  # type: ignore[attr-defined]
+                        if layout == torch.strided:
+                            indices = ()  # type: ignore[assignment]
+                            values = torch.empty((basesize + densesize), device=device, dtype=dtype)
+                        elif layout == torch.sparse_coo:
+                            indices = (torch.empty(len(basesize), 0, device=device, dtype=index_dtype),)  # type: ignore[assignment]
+                            values = torch.empty((0, *densesize), device=device, dtype=dtype)
+                        elif layout == torch.sparse_csr:
+                            crow_indices = torch.tensor([0] * (basesize[0] + 1), device=device, dtype=index_dtype)
+                            col_indices = torch.empty(0, device=device, dtype=index_dtype)
+                            indices = (crow_indices, col_indices)  # type: ignore[assignment]
+                            values = torch.empty((0, *densesize), device=device, dtype=dtype)
+                        elif layout == torch.sparse_csc:
+                            ccol_indices = torch.tensor([0] * (basesize[1] + 1), device=device, dtype=index_dtype)
+                            row_indices = torch.empty(0, device=device, dtype=index_dtype)
+                            indices = (ccol_indices, row_indices)  # type: ignore[assignment]
+                            values = torch.empty((0, *densesize), device=device, dtype=dtype)
+                        elif layout == torch.sparse_bsr:
+                            crow_indices = torch.tensor([0] * (basesize[0] // blocksize[0] + 1), device=device, dtype=index_dtype)
+                            col_indices = torch.empty(0, device=device, dtype=index_dtype)
+                            indices = (crow_indices, col_indices)  # type: ignore[assignment]
+                            values = torch.empty((0, *blocksize, *densesize), device=device, dtype=dtype)
+                        elif layout == torch.sparse_bsc:
+                            ccol_indices = torch.tensor([0] * (basesize[1] // blocksize[1] + 1), device=device, dtype=index_dtype)
+                            row_indices = torch.empty(0, device=device, dtype=index_dtype)
+                            indices = (ccol_indices, row_indices)  # type: ignore[assignment]
+                            values = torch.empty((0, *blocksize, *densesize), device=device, dtype=dtype)
+                        else:
+                            assert 0  # unreachable
+                        kwargs = dict(device=device, dtype=dtype, size=basesize + densesize)
+                        if pin_memory is not None:
+                            kwargs.update(pin_memory=pin_memory)
+                        yield (*indices, values), kwargs
+
+    def safeToDense(self, t):
+        # coalesce is only implemented for COO
+        if t.layout == torch.sparse_coo:
+            t = t.coalesce()
+        return t.to_dense()
+
+    # Compares a torch function with a reference function for a given sample input (an instance of SampleInput).
+    # Note: only values are compared; type comparison is not done here.
+    def compare_with_reference(self, torch_fn, ref_fn, sample_input, **kwargs):
+        numpy_sample = sample_input.numpy()
+        n_inp, n_args, n_kwargs = numpy_sample.input, numpy_sample.args, numpy_sample.kwargs
+        t_inp, t_args, t_kwargs = sample_input.input, sample_input.args, sample_input.kwargs
+
+        actual = torch_fn(t_inp, *t_args, **t_kwargs)
+        expected = ref_fn(n_inp, *n_args, **n_kwargs)
+
+        self.assertEqual(actual, expected, exact_device=False, **kwargs)
+
+    # Compares the given Torch and NumPy functions on the given tensor-like object.
+    # NOTE: both torch_fn and np_fn should be functions that take a single
+    #   tensor (array). If the torch and/or NumPy function require additional
+    #   arguments then wrap the function in a lambda or pass a partial function.
+    # TODO: add args/kwargs for passing to assertEqual (e.g. rtol, atol)
+    def compare_with_numpy(self, torch_fn, np_fn, tensor_like,
+                           device=None, dtype=None, **kwargs):
+        assert TEST_NUMPY
+
+        if isinstance(tensor_like, torch.Tensor):
+            assert device is None
+            assert dtype is None
+            t_cpu = tensor_like.detach().cpu()
+            if t_cpu.dtype is torch.bfloat16:
+                t_cpu = t_cpu.float()
+            a = t_cpu.numpy()
+            t = tensor_like
+        else:
+            d = copy.copy(torch_to_numpy_dtype_dict)
+            d[torch.bfloat16] = np.float32
+            a = np.array(tensor_like, dtype=d[dtype])
+            t = torch.tensor(tensor_like, device=device, dtype=dtype)
+
+        np_result = np_fn(a)
+        torch_result = torch_fn(t).cpu()
+
+        # Converts arrays to tensors
+        if isinstance(np_result, np.ndarray):
+            try:
+                np_result = torch.from_numpy(np_result)
+            except Exception:
+                # NOTE: copying an array before conversion is necessary when,
+                #   for example, the array has negative strides.
+                np_result = torch.from_numpy(np_result.copy())
+            if t.dtype is torch.bfloat16 and torch_result.dtype is torch.bfloat16 and np_result.dtype is torch.float:
+                torch_result = torch_result.to(torch.float)
+
+        self.assertEqual(np_result, torch_result, **kwargs)
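+    # Illustrative (hypothetical) usage of the helper above, comparing torch.exp
+    # against np.exp on a plain Python list:
+    #
+    #   self.compare_with_numpy(torch.exp, np.exp, [0.0, 1.0, 2.0],
+    #                           device='cpu', dtype=torch.float64)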
+
+    def assertEqualIgnoreType(self, *args, **kwargs) -> None:
+        # If you are seeing this function used, the test is written incorrectly
+        # and deserves detailed investigation
+        return self.assertEqual(*args, exact_dtype=False, **kwargs)
+
+    def assertEqualBroadcasting(self, x, y, *args, **kwargs) -> None:
+        r"""Tests if tensor x equals to y, if y to be broadcast to x.shape.
+        """
+        if not isinstance(y, Iterable):
+            # int, float, etc. or different shape tensors
+            y = torch.ones_like(x) * y
+        if not isinstance(y, torch.Tensor):
+            # iterable, but not a tensor
+            y = torch.ones_like(x) * torch.tensor(y)
+        return self.assertEqual(x, y, *args, **kwargs)
+
+    def assertEqual(
+            self,
+            x,
+            y,
+            msg: Optional[Union[str, Callable[[str], str]]] = None,
+            *,
+            atol: Optional[float] = None,
+            rtol: Optional[float] = None,
+            equal_nan=True,
+            exact_dtype=True,
+            # TODO: default this to True
+            exact_device=False,
+            exact_layout=False,
+            exact_stride=False,
+            exact_is_coalesced=False
+    ):
+        # Hide this function from `pytest`'s traceback
+        __tracebackhide__ = True
+
+        # numpy's dtypes are a superset of what PyTorch supports. In case we encounter an unsupported dtype, we fall
+        # back to an elementwise comparison. Note that this has to happen here and not for example in
+        # `TensorOrArrayPair`, since at that stage we can no longer split the array into its elements and perform
+        # multiple comparisons.
+        if any(
+            isinstance(input, np.ndarray) and not has_corresponding_torch_dtype(input.dtype) for input in (x, y)
+        ):
+            def to_list(input):
+                return input.tolist() if isinstance(input, (torch.Tensor, np.ndarray)) else list(input)
+
+            x = to_list(x)
+            y = to_list(y)
+        # When comparing a sequence of numbers to a tensor, we need to convert the sequence to a tensor here.
+        # Otherwise, the pair origination of `are_equal` will fail, because the sequence is recognized as a container
+        # that should be checked elementwise while the tensor is not.
+        elif isinstance(x, torch.Tensor) and isinstance(y, Sequence):
+            y = torch.as_tensor(y, dtype=x.dtype, device=x.device)
+        elif isinstance(x, Sequence) and isinstance(y, torch.Tensor):
+            x = torch.as_tensor(x, dtype=y.dtype, device=y.device)
+
+        # unbind NSTs to compare them; don't do this for NJTs
+        if isinstance(x, torch.Tensor) and x.is_nested and x.layout == torch.strided:
+            x = x.unbind()
+        if isinstance(y, torch.Tensor) and y.is_nested and y.layout == torch.strided:
+            y = y.unbind()
+
+        error_metas = not_close_error_metas(
+            x,
+            y,
+            pair_types=(
+                NonePair,
+                RelaxedBooleanPair,
+                RelaxedNumberPair,
+                TensorOrArrayPair,
+                TypedStoragePair,
+                StringPair,
+                SetPair,
+                TypePair,
+                ObjectPair,
+            ),
+            sequence_types=(
+                Sequence,
+                Sequential,
+                ModuleList,
+                ParameterList,
+                ScriptList,
+                torch.utils.data.dataset.Subset,
+            ),
+            mapping_types=(Mapping, ModuleDict, ParameterDict, ScriptDict),
+            rtol=rtol,
+            rtol_override=self.rel_tol,
+            atol=atol,
+            atol_override=self.precision,
+            equal_nan=equal_nan,
+            check_device=exact_device,
+            check_dtype=exact_dtype,
+            check_layout=exact_layout,
+            check_stride=exact_stride,
+            check_is_coalesced=exact_is_coalesced,
+        )
+
+        if error_metas:
+            # See [ErrorMeta Cycles]
+            error_metas = [error_metas]  # type: ignore[list-item]
+            # TODO: compose all metas into one AssertionError
+            raise error_metas.pop()[0].to_error(  # type: ignore[index]
+                # This emulates unittest.TestCase's behavior if a custom message passed and
+                # TestCase.longMessage (https://docs.python.org/3/library/unittest.html#unittest.TestCase.longMessage)
+                # is True (default)
+                (lambda generated_msg: f"{generated_msg}\n{msg}") if isinstance(msg, str) and self.longMessage else msg
+            )
+
+    def assertNotEqual(self, x, y, msg: Optional[str] = None, *,                                       # type: ignore[override]
+                       atol: Optional[float] = None, rtol: Optional[float] = None, **kwargs) -> None:
+        with self.assertRaises(AssertionError, msg=msg):
+            self.assertEqual(x, y, msg, atol=atol, rtol=rtol, **kwargs)
+
+    def assertEqualTypeString(self, x, y) -> None:
+        # This API is used to simulate the deprecated x.type() == y.type()
+        self.assertEqual(x.device, y.device)
+        self.assertEqual(x.dtype, y.dtype)
+        self.assertEqual(x.is_sparse, y.is_sparse)
+
+    def assertObjectIn(self, obj: Any, iterable: Iterable[Any]) -> None:
+        for elem in iterable:
+            if id(obj) == id(elem):
+                return
+        raise AssertionError("object not found in iterable")
+
+    # Reimplemented to provide special behavior when
+    # _ignore_not_implemented_error is True
+    def assertRaises(self, expected_exception, *args, **kwargs):
+        if self._ignore_not_implemented_error:
+            context: Optional[AssertRaisesContextIgnoreNotImplementedError] = \
+                AssertRaisesContextIgnoreNotImplementedError(expected_exception, self)  # type: ignore[call-arg]
+            try:
+                return context.handle('assertRaises', args, kwargs)  # type: ignore[union-attr, arg-type]
+            finally:
+                # see https://bugs.python.org/issue23890
+                context = None
+        else:
+            return super().assertRaises(expected_exception, *args, **kwargs)
+
+    # Reimplemented to provide special behavior when
+    # _ignore_not_implemented_error is True
+    def assertRaisesRegex(self, expected_exception, expected_regex, *args, **kwargs):
+        # Verifies that an exception with the type expected_exception and message
+        # matching the regular expression defined by expected_regex is thrown.
+        # If the test is instantiated for a non-native device type (like XLA)
+        # then the message is not validated.
+
+        # Checks whether the test is instantiated for a device type by testing
+        # if the test class has defined the device_type attribute and,
+        # if so, tests whether the instantiated device type is native or not
+        if hasattr(self, 'device_type') and self.device_type not in NATIVE_DEVICES and self.device_type != "mps":  # type: ignore[attr-defined]
+            # empty string matches any string
+            expected_regex = ''
+
+        if self._ignore_not_implemented_error:
+            context = AssertRaisesContextIgnoreNotImplementedError(  # type: ignore[call-arg]
+                expected_exception, self, expected_regex)
+            return context.handle('assertRaisesRegex', args, kwargs)  # type: ignore[attr-defined, arg-type]
+        else:
+            return super().assertRaisesRegex(expected_exception, expected_regex, *args, **kwargs)
+
+    # Verifies that no unraisable exceptions are raised by callable.  Unlike regular
+    # exceptions, these do not actually propagate to the caller and are
+    # suppressed.  We must test for them specially.
+    def assertNoUnraisable(self, callable, *args, **kwargs):
+        raised = None
+
+        def record_unraisable(unraisable):
+            nonlocal raised
+            raised = unraisable
+
+        # Disable GC when running the callable to prevent spurious flakiness
+        # from unlucky GCs inside the callable
+        prev = gc.isenabled()
+        gc.disable()
+        try:
+            with unittest.mock.patch("sys.unraisablehook", record_unraisable):
+                callable(*args, **kwargs)
+        finally:
+            if prev:
+                gc.enable()
+
+        self.assertIsNone(raised)
+
+    # TODO: Support context manager interface
+    # NB: The kwargs forwarding to callable robs the 'subname' parameter.
+    # If you need it, manually apply your callable in a lambda instead.
+    def assertExpectedRaises(self, exc_type, callable, *args, **kwargs):
+        subname = None
+        if 'subname' in kwargs:
+            subname = kwargs['subname']
+            del kwargs['subname']
+        try:
+            callable(*args, **kwargs)
+        except exc_type as e:
+            self.assertExpected(str(e), subname)
+            return
+        # Don't put this in the try block; if exc_type is AssertionError, the except clause would swallow the failure from self.fail
+        self.fail(msg="Did not raise when expected to")
+
+    def assertNotWarn(self, callable, msg=''):
+        r"""
+        Test if :attr:`callable` does not raise a warning.
+        """
+        with warnings.catch_warnings(record=True) as ws:
+            warnings.simplefilter("always")  # allow any warning to be raised
+            with set_warn_always_context(True):
+                callable()
+            self.assertTrue(len(ws) == 0, msg)
+
+    @contextmanager
+    def assertWarnsOnceRegex(self, category, regex=''):
+        """Context manager for code that *must always* warn
+
+        This filters expected warnings from the test and fails if
+        the expected warning is not caught. It uses set_warn_always() to force
+        TORCH_WARN_ONCE to behave like TORCH_WARN
+        """
+        pattern = re.compile(regex)
+        with warnings.catch_warnings(record=True) as ws:
+            warnings.simplefilter("always")  # allow any warning to be raised
+            with set_warn_always_context(True):
+                yield
+            if len(ws) == 0:
+                self.fail('no warning caught')
+            self.assertTrue(any(type(w.message) is category for w in ws))
+            self.assertTrue(
+                any(re.match(pattern, str(w.message)) for w in ws),
+                f'{pattern}, {[w.message for w in ws if type(w.message) is category]}')
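+    # A minimal sketch (hypothetical) of how the context manager above is used;
+    # note that re.match anchors at the start of the warning message:
+    #
+    #   with self.assertWarnsOnceRegex(UserWarning, "deprecated"):
+    #       warnings.warn("deprecated API, use foo instead", UserWarning)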
+
+    def assertExpected(self, s, subname=None):
+        r"""
+        Test that a string matches the recorded contents of a file
+        derived from the name of this test and subname.  This file
+        is placed in the 'expect' directory in the same directory
+        as the test script. You can automatically update the recorded test
+        output using --accept.
+
+        If you call this multiple times in a single function, you must
+        give a unique subname each time.
+        """
+        if not isinstance(s, str):
+            raise TypeError("assertExpected is strings only")
+
+        def remove_prefix(text, prefix):
+            if text.startswith(prefix):
+                return text[len(prefix):]
+            return text
+        # NB: we take __file__ from the module that defined the test
+        # class, so we place the expect directory where the test script
+        # lives, NOT where test/common_utils.py lives.  This doesn't matter in
+        # PyTorch where all test scripts are in the same directory as
+        # test/common_utils.py, but it matters in onnx-pytorch
+        module_id = self.__class__.__module__
+        munged_id = remove_prefix(self.id(), module_id + ".")
+        test_file = os.path.realpath(sys.modules[module_id].__file__)  # type: ignore[type-var]
+        expected_file = os.path.join(os.path.dirname(test_file),  # type: ignore[type-var, arg-type]
+                                     "expect",
+                                     munged_id)
+
+        subname_output = ""
+        if subname:
+            expected_file += "-" + subname
+            subname_output = f" ({subname})"
+        expected_file += ".expect"
+        expected = None
+
+        def accept_output(update_type):
+            print(f"Accepting {update_type} for {munged_id}{subname_output}:\n\n{s}")
+            with open(expected_file, 'w') as f:
+                # Adjust for producer_version, leave s unmodified
+                s_tag = re.sub(r'(producer_version): "[0-9.]*"',
+                               r'\1: "CURRENT_VERSION"', s)
+                f.write(s_tag)
+
+        try:
+            with open(expected_file) as f:
+                expected = f.read()
+        except OSError as e:
+            if e.errno != errno.ENOENT:
+                raise
+            elif expecttest.ACCEPT:
+                return accept_output("output")
+            else:
+                raise RuntimeError(
+                      f"I got this output for {munged_id}{subname_output}:\n\n{s}\n\n"
+                      "No expect file exists; to accept the current output, run:\n"
+                      f"python {__main__.__file__} {munged_id} --accept") from None
+
+        # a hack for JIT tests
+        if IS_WINDOWS:
+            expected = re.sub(r'CppOp\[(.+?)\]', 'CppOp[]', expected)
+            s = re.sub(r'CppOp\[(.+?)\]', 'CppOp[]', s)
+
+        # Adjust for producer_version
+        expected = expected.replace(
+            'producer_version: "CURRENT_VERSION"',
+            f'producer_version: "{torch.onnx.producer_version}"'
+        )
+        if expecttest.ACCEPT:
+            if expected != s:
+                return accept_output("updated output")
+        else:
+            if hasattr(self, "assertMultiLineEqual"):
+                # assertMultiLineEqual produces a much nicer diff for multi-line strings
+                # NB: Python considers lhs "old" and rhs "new".
+                self.assertMultiLineEqual(expected, s)
+            else:
+                self.assertEqual(s, expected)
+
+    def assertExpectedStripMangled(self, s, subname=None):
+        s = re.sub(r'__torch__[^ ]+', '', s)
+        self.assertExpected(s, subname)
+
+    def assertGreaterAlmostEqual(self, first, second, places=None, msg=None, delta=None):
+        """Assert that ``first`` is greater than or almost equal to ``second``.
+
+        The equality of ``first`` and ``second`` is determined in a similar way to
+        the ``assertAlmostEqual`` function of the standard library.
+        """
+        if delta is not None and places is not None:
+            raise TypeError("specify delta or places not both")
+
+        if first >= second:
+            return
+
+        diff = second - first
+        if delta is not None:
+            if diff <= delta:
+                return
+
+            standardMsg = f"{first} not greater than or equal to {second} within {delta} delta"
+        else:
+            if places is None:
+                places = 7
+
+            if round(diff, places) == 0:
+                return
+
+            standardMsg = f"{first} not greater than or equal to {second} within {places} places"
+
+        msg = self._formatMessage(msg, standardMsg)
+        raise self.failureException(msg)
+
+    def assertAtenOp(self, onnx_model, operator, overload_name=""):
+        all_aten_nodes = [p for p in onnx_model.graph.node
+                          if p.op_type == "ATen" and p.domain == "org.pytorch.aten"]
+        self.assertTrue(all_aten_nodes)
+
+        for op in all_aten_nodes:
+            attrs = {attr.name: attr.s.decode() for attr in op.attribute}
+            if attrs.get("operator") == operator:
+                break
+
+        self.assertEqual(attrs["operator"], operator)  # type: ignore[possibly-undefined]
+        self.assertEqual(attrs.get("overload_name", ""), overload_name)
+
+    def check_nondeterministic_alert(self, fn, caller_name, should_alert=True):
+        '''Checks that an operation produces a nondeterministic alert when
+        expected while `torch.use_deterministic_algorithms(True)` is set.
+
+        Args:
+          fn (callable): Function to check for a nondeterministic alert
+
+          caller_name (str): Name of the operation that produces the
+              nondeterministic alert. This name is expected to appear at the
+              beginning of the error/warning message.
+
+          should_alert (bool, optional): If True, then the check will only pass
+              if calling `fn` produces a nondeterministic error/warning with the
+              expected message. If False, then the check will only pass if
+              calling `fn` does not produce an error. Default: `True`.
+        '''
+
+        alert_message = '^' + caller_name + ' does not have a deterministic implementation, but you set'
+
+        # Check that errors are thrown correctly
+        with DeterministicGuard(True):
+            if should_alert:
+                with self.assertRaisesRegex(
+                        RuntimeError,
+                        alert_message,
+                        msg='expected a non-deterministic error, but it was not raised'):
+                    fn()
+
+            else:
+                # If a nondeterministic error is not expected, make sure
+                # that it is not raised
+                try:
+                    fn()
+                except RuntimeError as e:
+                    if 'does not have a deterministic implementation' in str(e):
+                        self.fail(
+                            'did not expect non-deterministic error message, '
+                            + 'but got one anyway: "' + str(e) + '"')
+                    # Reraise exceptions unrelated to nondeterminism
+                    raise
+
+        # Check that warnings are thrown correctly
+        with DeterministicGuard(True, warn_only=True):
+            if should_alert:
+                with self.assertWarnsRegex(
+                        UserWarning,
+                        alert_message):
+                    fn()
+            else:
+                with warnings.catch_warnings(record=True) as w:
+                    warnings.simplefilter("always")
+                    fn()
+                    for warning in w:
+                        # `w` holds WarningMessage records, so check the underlying warning object
+                        if isinstance(warning.message, UserWarning):
+                            self.assertTrue(re.search(alert_message, str(warning.message)) is None)
+
+    # run code in subprocess and capture exceptions.
+    @staticmethod
+    def run_process_no_exception(code, env=None):
+        import subprocess
+
+        popen = subprocess.Popen(
+            [sys.executable, '-c', code],
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            env=env)
+        (stdout, stderr) = popen.communicate()
+        return (stdout, stderr)
+
+    # returns captured stderr
+    @staticmethod
+    def runWithPytorchAPIUsageStderr(code):
+        env = os.environ.copy()
+        env["PYTORCH_API_USAGE_STDERR"] = "1"
+        # remove CI flag since this is a wrapped test process.
+        # CI flag should be set in the parent process only.
+        env.pop("CI", None)
+        env.pop("TEST_SHOWLOCALS", None)
+        _stdout, stderr = TestCase.run_process_no_exception(code, env=env)
+        return stderr.decode('ascii')
+
+    def _attempt_load_from_subprocess(
+        self,
+        file: pathlib.Path,
+        import_string: str,
+        expected_failure_message: Optional[str] = None
+    ) -> None:
+        """
+        Attempts weights_only `torch.load` in a subprocess. This is used to test that
+        weights_only `torch.load` works as expected without global imports.
+
+        Args:
+            file (pathlib.Path): The path to the checkpoint to load.
+            import_string (str): import string to add to the script
+            expected_failure_message (str, optional): The expected failure message if the
+                checkpoint is expected to fail to load. If None, the load is expected to succeed.
+        """
+        script = f"import torch;{import_string}torch.load(r'{file}', weights_only=True)"
+        cm = (
+            self.assertRaisesRegex(RuntimeError, re.escape(expected_failure_message))
+            if expected_failure_message else contextlib.nullcontext()
+        )
+        with cm:
+            try:
+                subprocess.check_output(
+                    [sys.executable, "-c", script],
+                    # On Windows, opening the subprocess with the default CWD makes `import torch`
+                    # fail, so just set CWD to this script's directory
+                    cwd=os.path.dirname(os.path.realpath(__file__)),
+                    stderr=subprocess.STDOUT,
+                )
+            except subprocess.CalledProcessError as e:
+                raise RuntimeError(e.output.decode("utf-8")) from None
+
+
+class TestCaseBase(TestCase):
+    # Calls to super() in dynamically created classes are a bit odd.
+    # See https://github.com/pytorch/pytorch/pull/118586 for more info
+    # Subclassing this class and then calling super(TestCaseBase) will run
+    # TestCase's setUp, tearDown, etc. functions
+    pass
+
+
+def download_file(url, binary=True):
+    from urllib.parse import urlsplit
+    from urllib import request, error
+
+    filename = os.path.basename(urlsplit(url)[2])
+    data_dir = get_writable_path(os.path.join(os.path.dirname(__file__), 'data'))
+    path = os.path.join(data_dir, filename)
+
+    if os.path.exists(path):
+        return path
+    try:
+        data = request.urlopen(url, timeout=15).read()
+        with open(path, 'wb' if binary else 'w') as f:
+            f.write(data)
+        return path
+    except error.URLError as e:
+        msg = f"could not download test file '{url}'"
+        warnings.warn(msg, RuntimeWarning)
+        raise unittest.SkipTest(msg) from e
+
+def find_free_port():
+    """
+    Finds an available port and returns that port number.
+
+    NOTE: If this function is being used to allocate a port to Store (or
+    indirectly via init_process_group or init_rpc), it should be used
+    in conjunction with the `retry_on_connect_failures` decorator, as there is a potential
+    race condition where the allocated port may become unavailable before it can be used.
+    """
+    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
+        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+        sock.bind(('localhost', 0))
+        _, port = sock.getsockname()
+        return port
+
+# Errors that we can get in c10d initialization and for which we should retry tests.
+ADDRESS_IN_USE = "Address already in use"
+CONNECT_TIMEOUT = "connect() timed out."
+
+def retry_on_connect_failures(func=None, connect_errors=(ADDRESS_IN_USE,)):
+    """Reruns a test if the test returns a RuntimeError and the exception
+    contains one of the strings in connect_errors."""
+    # This if block is executed when using this function as a decorator with arguments.
+    if func is None:
+        return partial(retry_on_connect_failures, connect_errors=connect_errors)
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        n_retries = 10
+        tries_remaining = n_retries
+        while True:
+            try:
+                return func(*args, **kwargs)
+            except RuntimeError as error:
+                if any(connect_error in str(error) for connect_error in connect_errors):
+                    tries_remaining -= 1
+                    if tries_remaining == 0:
+                        raise RuntimeError(f"Failing after {n_retries} retries with error: {str(error)}") from error
+                    time.sleep(random.random())
+                    continue
+                raise
+    return wrapper
+
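+
+# Illustrative sketch (not part of the original module): how the decorator above is typically
+# combined with find_free_port when binding a c10d TCPStore. The helper name
+# `_example_create_store` is hypothetical, and this assumes a build with torch.distributed.
+def _example_create_store():
+    import torch.distributed as dist
+
+    @retry_on_connect_failures
+    def _create():
+        port = find_free_port()
+        # The port can be taken between find_free_port() and the bind below;
+        # the decorator retries on "Address already in use".
+        return dist.TCPStore("localhost", port, 1, True)
+
+    return _create()
+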
+
+# Decorator to retry upon certain Exceptions.
+def retry(ExceptionToCheck, tries=3, delay=3, skip_after_retries=False):
+    def deco_retry(f):
+        @wraps(f)
+        def f_retry(*args, **kwargs):
+            mtries, mdelay = tries, delay
+            while mtries > 1:
+                try:
+                    return f(*args, **kwargs)
+                except ExceptionToCheck as e:
+                    msg = f"{e}, Retrying in {mdelay:d} seconds..."
+                    print(msg)
+                    time.sleep(mdelay)
+                    mtries -= 1
+            try:
+                return f(*args, **kwargs)
+            except ExceptionToCheck as e:
+                if skip_after_retries:
+                    raise unittest.SkipTest(f"Skipping after {tries} consecutive {str(e)}") from e
+                raise
+        return f_retry  # true decorator
+    return deco_retry
+
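+
+# Illustrative sketch (not part of the original module): `retry` re-runs a flaky helper on the
+# given exception type, sleeping between attempts; with skip_after_retries=True the test is
+# skipped once every attempt has failed. The helper name and URL below are hypothetical.
+def _example_retry_usage():
+    from urllib import error, request
+
+    @retry(error.URLError, tries=3, delay=1, skip_after_retries=True)
+    def fetch():
+        return request.urlopen("https://example.com/resource", timeout=5).read()
+
+    return fetch()
+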
+
+# FIXME: modernize these to be consistent with make_tensor
+#   and review including them in torch.testing
+# Methods for matrix generation
+
+def random_square_matrix_of_rank(l, rank, dtype=torch.double, device='cpu'):
+    assert rank <= l
+    A = torch.randn(l, l, dtype=dtype, device=device)
+    u, s, vh = torch.linalg.svd(A, full_matrices=False)
+    for i in range(l):
+        if i >= rank:
+            s[i] = 0
+        elif s[i] == 0:
+            s[i] = 1
+    return (u * s.to(dtype).unsqueeze(-2)) @ vh
+
+def random_well_conditioned_matrix(*shape, dtype, device, mean=1.0, sigma=0.001):
+    """
+    Returns a random rectangular matrix (batch of matrices)
+    with singular values sampled from a Gaussian with
+    mean `mean` and standard deviation `sigma`.
+    The smaller the `sigma`, the better conditioned
+    the output matrix is.
+    """
+    primitive_dtype = {
+        torch.float: torch.float,
+        torch.double: torch.double,
+        torch.cfloat: torch.float,
+        torch.cdouble: torch.double
+    }
+    x = torch.rand(shape, dtype=dtype, device=device)
+    m = x.size(-2)
+    n = x.size(-1)
+    u, _, vh = torch.linalg.svd(x, full_matrices=False)
+    s = (torch.randn(*(shape[:-2] + (min(m, n),)), dtype=primitive_dtype[dtype], device=device) * sigma + mean) \
+        .sort(-1, descending=True).values.to(dtype)
+    return (u * s.unsqueeze(-2)) @ vh
+
+# Returns a noncontiguous tensor with the same shape and values as t.
+# The noncontiguous tensor is constructed such that elements in the innermost
+#   dimension are separated by zeros or (whenever possible) nans.
+# TODO: consider more complicated noncontiguity schemes
+def noncontiguous_like(t):
+    # Short-circuits if t is already noncontiguous
+    if not t.is_contiguous():
+        return t
+
+    # Choose a "weird" value that won't be accessed
+    if t.dtype.is_floating_point or t.dtype.is_complex:
+        value = math.nan
+    elif t.dtype == torch.bool:
+        value = True
+    else:
+        value = 12
+
+    result = t.new_empty(t.shape + (2,))
+    result[..., 0] = value
+    result[..., 1] = t.detach()
+    result = result[..., 1]
+    result.requires_grad_(t.requires_grad)
+    return result
+
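+
+# Illustrative sketch (not part of the original module): noncontiguous_like interleaves a
+# sentinel value along a new innermost dimension and then slices it away, so the values are
+# unchanged while the innermost stride doubles.
+def _example_noncontiguous_like():
+    t = torch.arange(6.).reshape(2, 3)
+    nc = noncontiguous_like(t)
+    assert not nc.is_contiguous()
+    assert torch.equal(nc, t)                  # same shape and values
+    assert nc.stride(-1) == 2 * t.stride(-1)   # innermost elements separated by NaN slots
+    return nc
+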
+# TODO: remove this (prefer make_symmetric_matrices below)
+def random_symmetric_matrix(l, *batches, **kwargs):
+    dtype = kwargs.get('dtype', torch.double)
+    device = kwargs.get('device', 'cpu')
+    A = torch.randn(*(batches + (l, l)), dtype=dtype, device=device)
+    A = (A + A.mT).div_(2)
+    return A
+
+# Creates a symmetric matrix or batch of symmetric matrices
+# Shape must be a square matrix or batch of square matrices
+def make_symmetric_matrices(*shape, device, dtype):
+    assert shape[-1] == shape[-2]
+    t = make_tensor(shape, device=device, dtype=dtype)
+    t = (t + t.mT).div_(2)
+    return t
+
+def random_hermitian_matrix(l, *batches, **kwargs):
+    dtype = kwargs.get('dtype', torch.double)
+    device = kwargs.get('device', 'cpu')
+    A = torch.randn(*(batches + (l, l)), dtype=dtype, device=device)
+    A = (A + A.mH).div_(2)
+    return A
+
+
+def random_symmetric_psd_matrix(l, *batches, **kwargs):
+    """
+    Returns a batch of random symmetric positive-semi-definite matrices.
+    The shape of the result is batch_dims + (matrix_size, matrix_size)
+    The following example creates a tensor of size 2 x 4 x 3 x 3
+    >>> # xdoctest: +SKIP("undefined variables")
+    >>> matrices = random_symmetric_psd_matrix(3, 2, 4, dtype=dtype, device=device)
+    """
+    dtype = kwargs.get('dtype', torch.double)
+    device = kwargs.get('device', 'cpu')
+    A = torch.randn(*(batches + (l, l)), dtype=dtype, device=device)
+    return A @ A.mT
+
+
+def random_hermitian_psd_matrix(matrix_size, *batch_dims, dtype=torch.double, device='cpu'):
+    """
+    Returns a batch of random Hermitian positive-semi-definite matrices.
+    The shape of the result is batch_dims + (matrix_size, matrix_size)
+    The following example creates a tensor of size 2 x 4 x 3 x 3
+    >>> # xdoctest: +SKIP("undefined variables")
+    >>> matrices = random_hermitian_psd_matrix(3, 2, 4, dtype=dtype, device=device)
+    """
+    A = torch.randn(*(batch_dims + (matrix_size, matrix_size)), dtype=dtype, device=device)
+    return A @ A.mH
+
+
+# TODO: remove this (prefer make_symmetric_pd_matrices below)
+def random_symmetric_pd_matrix(matrix_size, *batch_dims, **kwargs):
+    dtype = kwargs.get('dtype', torch.double)
+    device = kwargs.get('device', 'cpu')
+    A = torch.randn(*(batch_dims + (matrix_size, matrix_size)),
+                    dtype=dtype, device=device)
+    return torch.matmul(A, A.mT) \
+        + torch.eye(matrix_size, dtype=dtype, device=device) * 1e-5
+
+
+# Creates a symmetric positive-definite matrix or batch of
+#   such matrices
+def make_symmetric_pd_matrices(*shape, device, dtype):
+    assert shape[-1] == shape[-2]
+    t = make_tensor(shape, device=device, dtype=dtype)
+    i = torch.eye(shape[-1], device=device, dtype=dtype) * 1e-5
+    return t @ t.mT + i
+
+def random_hermitian_pd_matrix(matrix_size, *batch_dims, dtype, device):
+    """
+    Returns a batch of random Hermitian positive-definite matrices.
+    The shape of the result is batch_dims + (matrix_size, matrix_size)
+    The following example creates a tensor of size 2 x 4 x 3 x 3
+    >>> # xdoctest: +SKIP("undefined variables")
+    >>> matrices = random_hermitian_pd_matrix(3, 2, 4, dtype=dtype, device=device)
+    """
+    A = torch.randn(*(batch_dims + (matrix_size, matrix_size)),
+                    dtype=dtype, device=device)
+    return A @ A.mH + torch.eye(matrix_size, dtype=dtype, device=device)
+
+# Creates a full rank matrix with distinct singular values or
+#   a batch of such matrices
+def make_fullrank_matrices_with_distinct_singular_values(*shape, device, dtype, requires_grad=False):
+    with torch.no_grad():
+        t = make_tensor(shape, device=device, dtype=dtype)
+        u, _, vh = torch.linalg.svd(t, full_matrices=False)
+        real_dtype = t.real.dtype if t.dtype.is_complex else t.dtype
+        k = min(shape[-1], shape[-2])
+        # We choose the singular values to be "around one"
+        # This is to make the matrix well conditioned
+        # s = [2, 3, ..., k+1]
+        s = torch.arange(2, k + 2, dtype=real_dtype, device=device)
+        # s = [2, -3, 4, ..., (-1)^k k+1]
+        s[1::2] *= -1.
+        # 1 + 1/s so that the singular values are in the range [2/3, 3/2]
+        # This gives a condition number of 9/4, which should be good enough
+        s.reciprocal_().add_(1.)
+        # Note that the singular values need not be ordered in an SVD so
+        # we don't need to sort s
+        x = (u * s.to(u.dtype)) @ vh
+    x.requires_grad_(requires_grad)
+    return x
+
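+
+# Illustrative sketch (not part of the original module): the construction above places the
+# singular values in [2/3, 3/2], so the condition number of each generated matrix is bounded
+# by (3/2) / (2/3) = 9/4.
+def _example_fullrank_condition_number():
+    m = make_fullrank_matrices_with_distinct_singular_values(4, 4, device='cpu', dtype=torch.double)
+    s = torch.linalg.svdvals(m)
+    assert s.max() / s.min() <= 9 / 4 + 1e-6
+    return s
+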
+def random_matrix(rows, columns, *batch_dims, **kwargs):
+    """Return rectangular matrix or batches of rectangular matrices.
+
+    Parameters:
+      dtype - the data type
+      device - the device kind
+      singular - when True, the output will be singular
+    """
+    dtype = kwargs.get('dtype', torch.double)
+    device = kwargs.get('device', 'cpu')
+    silent = kwargs.get("silent", False)
+    singular = kwargs.get("singular", False)
+    if silent and not torch._C.has_lapack:
+        return torch.ones(rows, columns, dtype=dtype, device=device)
+
+    A = torch.randn(batch_dims + (rows, columns), dtype=dtype, device=device)
+    if A.numel() == 0:
+        return A
+    u, _, vh = torch.linalg.svd(A, full_matrices=False)
+    k = min(rows, columns)
+    s = torch.linspace(1 / (k + 1), 1, k, dtype=dtype, device=device)
+    if singular:
+        # make matrix singular
+        s[k - 1] = 0
+        if k > 2:
+            # increase the order of singularity so that the pivoting
+            # in LU factorization will be non-trivial
+            s[0] = 0
+    return (u * s.unsqueeze(-2)) @ vh
+
+
+def random_lowrank_matrix(rank, rows, columns, *batch_dims, **kwargs):
+    """Return rectangular matrix or batches of rectangular matrices with
+    given rank.
+    """
+    B = random_matrix(rows, rank, *batch_dims, **kwargs)
+    C = random_matrix(rank, columns, *batch_dims, **kwargs)
+    return B.matmul(C)
+
+
+def _generate_indices_prefer_all_rows(rows: int, cols: int, num_indices: int) -> torch.Tensor:
+    """Generate indices for a row x cols matrix, preferring at least one index per row if possible."""
+    indices = []  # type: ignore[var-annotated]
+    n_per_row = math.ceil(num_indices / rows)
+    col_indices = list(range(cols))
+
+    for r in range(rows):
+        # Note that this can yield overlapping indices
+        indices.extend((r, c) for c in random.choices(col_indices, k=n_per_row))
+
+    return torch.tensor(indices[:num_indices])
+
+
+def random_sparse_matrix(rows, columns, density=0.01, **kwargs):
+    """Return rectangular random sparse matrix within given density.
+
+    The density of the result approaches to given density as the size
+    of the matrix is increased and a relatively small value of density
+    is specified but higher than min(rows, columns)/(rows * columns)
+    for non-singular matrices.
+    """
+    dtype = kwargs.get('dtype', torch.double)
+    device = kwargs.get('device', 'cpu')
+
+    nonzero_elements = max(min(rows, columns), int(rows * columns * density))
+    indices = _generate_indices_prefer_all_rows(rows, columns, nonzero_elements)
+    values = torch.randn(nonzero_elements, dtype=dtype, device=device)
+
+    # ensure that the diagonal dominates
+    values *= torch.tensor([-float(i - j)**2 for i, j in indices], dtype=dtype, device=device).exp()
+    A = torch.sparse_coo_tensor(indices.t(), values, (rows, columns), device=device)
+    return A.coalesce()
+
+
+def random_sparse_pd_matrix(matrix_size, density=0.01, **kwargs):
+    """Return random sparse positive-definite matrix with given density.
+
+    The eigenvalues of the matrix are defined as::
+      arange(1, matrix_size+1)/matrix_size
+
+    Algorithm:
+      A = diag(arange(1, matrix_size+1)/matrix_size)
+      while <the number of nonzero entries is below density * matrix_size**2>:
+          <pick random indices i != j and a random rotation angle theta>
+          R = <Givens rotation matrix for (i, j, theta)>
+          A = R^T A R
+    """
+    import math
+    torch = kwargs.get('torch', globals()['torch'])
+    dtype = kwargs.get('dtype', torch.double)
+    device = kwargs.get('device', 'cpu')
+    data = {(i, i): float(i + 1) / matrix_size
+            for i in range(matrix_size)}
+
+
+    def multiply(data, N, i, j, cs, sn, left=True):
+        for k in range(N):
+            if left:
+                ik, jk = (k, i), (k, j)
+            else:
+                ik, jk = (i, k), (j, k)
+            aik, ajk = data.get(ik, 0), data.get(jk, 0)
+            aik, ajk = cs * aik + sn * ajk, -sn * aik + cs * ajk
+            if aik:
+                data[ik] = aik
+            else:
+                data.pop(ik, None)
+            if ajk:
+                data[jk] = ajk
+            else:
+                data.pop(jk, None)
+
+    target_nnz = density * matrix_size * matrix_size
+    while len(data) < target_nnz:
+        i = random.randint(0, matrix_size - 1)
+        j = random.randint(0, matrix_size - 1)
+        if i != j:
+            theta = random.uniform(0, 2 * math.pi)
+            cs = math.cos(theta)
+            sn = math.sin(theta)
+            multiply(data, matrix_size, i, j, cs, sn, left=True)
+            multiply(data, matrix_size, i, j, cs, sn, left=False)
+    icoords, jcoords, values = [], [], []
+    for (i, j), v in sorted(data.items()):
+        icoords.append(i)
+        jcoords.append(j)
+        values.append(v)
+    indices_tensor = torch.tensor([icoords, jcoords])
+    return torch.sparse_coo_tensor(indices_tensor, values, (matrix_size, matrix_size), dtype=dtype, device=device)
+
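+
+# Illustrative sketch (not part of the original module): the rotations above preserve the
+# spectrum, so the densified result should remain positive definite with eigenvalues close
+# to arange(1, matrix_size + 1) / matrix_size.
+def _example_random_sparse_pd_matrix():
+    A = random_sparse_pd_matrix(8, density=0.2)
+    evals = torch.linalg.eigvalsh(A.to_dense())
+    assert bool(torch.all(evals > 0))
+    return evals
+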
+# FIXME: remove this by updating test suites using it
+def do_test_dtypes(self, dtypes, layout, device):
+    for dtype in dtypes:
+        if dtype != torch.float16:
+            out = torch.zeros((2, 3), dtype=dtype, layout=layout, device=device)
+            self.assertIs(dtype, out.dtype)
+            self.assertIs(layout, out.layout)
+            self.assertEqual(device, out.device)
+
+# FIXME: remove this by updating test suites using it
+def do_test_empty_full(self, dtypes, layout, device):
+    shape = torch.Size([2, 3])
+
+    def check_value(tensor, dtype, layout, device, value, requires_grad):
+        self.assertEqual(shape, tensor.shape)
+        self.assertIs(dtype, tensor.dtype)
+        self.assertIs(layout, tensor.layout)
+        self.assertEqual(tensor.requires_grad, requires_grad)
+        if tensor.is_cuda and device is not None:
+            self.assertEqual(device, tensor.device)
+        if value is not None:
+            fill = tensor.new(shape).fill_(value)
+            self.assertEqual(tensor, fill)
+
+    def get_int64_dtype(dtype):
+        module = '.'.join(str(dtype).split('.')[1:-1])
+        if not module:
+            return torch.int64
+        return operator.attrgetter(module)(torch).int64
+
+    default_dtype = torch.get_default_dtype()
+    check_value(torch.empty(shape), default_dtype, torch.strided, -1, None, False)
+    check_value(torch.full(shape, -5.), default_dtype, torch.strided, -1, None, False)
+    for dtype in dtypes:
+        for rg in {dtype.is_floating_point, False}:
+            int64_dtype = get_int64_dtype(dtype)
+            v = torch.empty(shape, dtype=dtype, device=device, layout=layout, requires_grad=rg)
+            check_value(v, dtype, layout, device, None, rg)
+            out = v.new()
+            check_value(torch.empty(shape, out=out, device=device, layout=layout, requires_grad=rg),
+                        dtype, layout, device, None, rg)
+            check_value(v.new_empty(shape), dtype, layout, device, None, False)
+            check_value(v.new_empty(shape, dtype=int64_dtype, device=device, requires_grad=False),
+                        int64_dtype, layout, device, None, False)
+            check_value(torch.empty_like(v), dtype, layout, device, None, False)
+            check_value(torch.empty_like(v, dtype=int64_dtype, layout=layout, device=device, requires_grad=False),
+                        int64_dtype, layout, device, None, False)
+
+            if dtype is not torch.float16 and layout != torch.sparse_coo:
+                fv = 3
+                v = torch.full(shape, fv, dtype=dtype, layout=layout, device=device, requires_grad=rg)
+                check_value(v, dtype, layout, device, fv, rg)
+                check_value(v.new_full(shape, fv + 1), dtype, layout, device, fv + 1, False)
+                out = v.new()
+                check_value(torch.full(shape, fv + 2, out=out, device=device, layout=layout, requires_grad=rg),
+                            dtype, layout, device, fv + 2, rg)
+                check_value(v.new_full(shape, fv + 3, dtype=int64_dtype, device=device, requires_grad=False),
+                            int64_dtype, layout, device, fv + 3, False)
+                check_value(torch.full_like(v, fv + 4), dtype, layout, device, fv + 4, False)
+                check_value(torch.full_like(v, fv + 5,
+                                            dtype=int64_dtype, layout=layout, device=device, requires_grad=False),
+                            int64_dtype, layout, device, fv + 5, False)
+
+# FIXME: improve load_tests() documentation here
+running_script_path = None  # type: ignore[var-annotated]
+def set_running_script_path():
+    global running_script_path
+    try:
+        running_file = os.path.abspath(os.path.realpath(sys.argv[0]))
+        if running_file.endswith('.py'):  # skip if the running file is not a script
+            running_script_path = running_file
+    except Exception:
+        pass
+
+def check_test_defined_in_running_script(test_case):
+    if running_script_path is None:
+        return
+    test_case_class_file = os.path.abspath(os.path.realpath(inspect.getfile(test_case.__class__)))
+    assert test_case_class_file == running_script_path, f'Class of loaded TestCase "{test_case.id()}" ' \
+        f'is not defined in the running script "{running_script_path}", but in "{test_case_class_file}". Did you ' \
+        "accidentally import a unittest.TestCase from another file?"
+
+def load_tests(loader, tests, pattern):
+    set_running_script_path()
+    test_suite = unittest.TestSuite()
+    for test_group in tests:
+        if not DISABLE_RUNNING_SCRIPT_CHK:
+            for test in test_group:
+                check_test_defined_in_running_script(test)
+        if test_group._tests:
+            test_suite.addTest(test_group)
+    return test_suite
+
+# FIXME: document this and move it to test_serialization
+class BytesIOContext(io.BytesIO):
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args):
+        pass
+
+# Tentative value for nondet_tol for gradcheck when backward implementation
+# relies on nondeterministic operations, i.e., those listed here:
+# https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html
+#
+# For more information see https://github.com/pytorch/pytorch/issues/56202
+GRADCHECK_NONDET_TOL = 1e-12
+
+TEST_WITH_SLOW_GRADCHECK: bool = TestEnvironment.def_flag(
+    "TEST_WITH_SLOW_GRADCHECK",
+    env_var="PYTORCH_TEST_WITH_SLOW_GRADCHECK",
+)
+
+skipIfSlowGradcheckEnv = unittest.skipIf(
+    TEST_WITH_SLOW_GRADCHECK,
+    "Tests that don't use gradcheck don't need to run on slow_gradcheck CI",
+)
+
+
+def gradcheck(fn, inputs, **kwargs):
+    # Wrapper around gradcheck that enables certain keys by default.
+    # Use this testing-internal gradcheck instead of autograd.gradcheck so that new features like vmap and
+    # forward-mode AD are tested by default. We create this wrapper because we'd like to keep new checks
+    # disabled by default in the public-facing API, to avoid breaking user code.
+    #
+    # All PyTorch devs doing testing should use this wrapper instead of autograd.gradcheck.
+    default_values = {
+        "check_batched_grad": True,
+        "fast_mode": True,
+    }
+
+    if TEST_WITH_SLOW_GRADCHECK:
+        default_values["fast_mode"] = False
+
+    for key, value in default_values.items():
+        # default values override values explicitly set to None
+        k = kwargs.get(key, None)
+        kwargs[key] = k if k is not None else value
+
+    return torch.autograd.gradcheck(fn, inputs, **kwargs)
+
+def gradgradcheck(fn, inputs, grad_outputs=None, **kwargs):
+    # Wrapper around gradgradcheck that enables certain keys by default
+    # See gradcheck above for an explanation of why we need something like this.
+    #
+    # All PyTorch devs doing testing should use this wrapper instead of autograd.gradgradcheck
+    default_values = {
+        "check_batched_grad": True,
+        "fast_mode": True,
+    }
+
+    if TEST_WITH_SLOW_GRADCHECK:
+        default_values["fast_mode"] = False
+
+    for key, value in default_values.items():
+        # default values override values explicitly set to None
+        k = kwargs.get(key, None)
+        kwargs[key] = k if k is not None else value
+
+    return torch.autograd.gradgradcheck(fn, inputs, grad_outputs, **kwargs)
+
+
+def _assertGradAndGradgradChecks(test_case, apply_fn, inputs, **kwargs):
+    # Call the assert functions rather than returning a bool, since it's nicer
+    # to know whether the failure came from the gradcheck or the gradgradcheck.
+    test_case.assertTrue(gradcheck(apply_fn, inputs, **kwargs))
+    test_case.assertTrue(gradgradcheck(apply_fn, inputs, **kwargs))
+
+
+@contextmanager
+def set_cwd(path: str) -> Iterator[None]:
+    old_cwd = os.getcwd()
+    try:
+        os.chdir(path)
+        yield
+    finally:
+        os.chdir(old_cwd)
+
+
+# FIXME: delete this
+# Using @toleranceOverride specific to your test is the recommended way
+# of doing this. These are just some values that worked for test_nn.
+dtype2prec_DONTUSE = {torch.float: 1e-5,
+                      torch.double: 1e-5,
+                      torch.half: 1e-2,
+                      torch.bfloat16: 1e-1}
+
+# FIXME: move to test_sparse or sparse utils
+# This is a wrapper that runs a test twice, once with coalesced=True and
+# once with coalesced=False, to cover both coalesced and uncoalesced sparse tensors.
+def coalescedonoff(f):
+    @wraps(f)
+    def wrapped(self, *args, **kwargs):
+        f(self, *args, **kwargs, coalesced=True)
+        f(self, *args, **kwargs, coalesced=False)
+    return wrapped
+
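+
+# Illustrative sketch (not part of the original module): the wrapped test receives an extra
+# `coalesced` keyword argument and is executed once per value. The class and test names
+# below are hypothetical.
+def _example_coalescedonoff():
+    class _SparseTests(TestCase):
+        @coalescedonoff
+        def test_sum(self, coalesced):
+            i = torch.tensor([[0, 0, 1]])
+            v = torch.tensor([1., 2., 3.])
+            s = torch.sparse_coo_tensor(i, v, (2,))
+            if coalesced:
+                s = s.coalesce()
+            self.assertEqual(s.to_dense().sum(), v.sum())
+
+    return _SparseTests
+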
+
+def is_coalesced_indices(s):
+    indices = s._indices()
+    hash_coeffs = (1,) + s.shape[s.sparse_dim() - 1:0:-1]
+    hash_indices = torch.tensor(hash_coeffs, device=s.device).cumprod(-1).flip(-1)
+    if s.sparse_dim() > 1:
+        hash_indices.unsqueeze_(-1)
+        hash_indices = (indices * hash_indices).sum(0)
+    else:
+        hash_indices = indices * hash_indices
+
+    # check if indices are sorted
+    res = torch.allclose(hash_indices, hash_indices.sort()[0])
+
+    # check if there are no repeated indices
+    res = res and torch.allclose(hash_indices, hash_indices.unique())
+
+    return res
+
+
+@contextlib.contextmanager
+def disable_gc():
+    if gc.isenabled():
+        try:
+            gc.disable()
+            yield
+        finally:
+            gc.enable()
+    else:
+        yield
+
+
+def find_library_location(lib_name: str) -> Path:
+    # Return the shared library file from the installed folder if it exists,
+    # otherwise the file from the build folder.
+    torch_root = Path(torch.__file__).resolve().parent
+    path = torch_root / 'lib' / lib_name
+    if os.path.exists(path):
+        return path
+    torch_root = Path(__file__).resolve().parents[2]
+    return torch_root / 'build' / 'lib' / lib_name
+
+def skip_but_pass_in_sandcastle(reason):
+    """
+    Similar to unittest.skip; however, in the sandcastle environment it just
+    "passes" the test instead, to avoid creating tasks complaining about tests
+    skipping continuously.
+    """
+    def decorator(func):
+        if not IS_SANDCASTLE:
+            func.__unittest_skip__ = True
+            func.__unittest_skip_why__ = reason
+            return func
+
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            print(f'Skipping {func.__name__} on sandcastle for following reason: {reason}', file=sys.stderr)
+            return
+        return wrapper
+
+    return decorator
+
+def mock_wrapper(method):
+    """
+    Returns a function that calls the real implementation of a method
+    in addition to passing args to a mock object.
+    """
+    mock = MagicMock()
+
+    @wraps(method)
+    def wrapper(self, *args, **kwargs):
+        mock(*args, **kwargs)
+        return method(self, *args, **kwargs)
+    wrapper.mock = mock  # type: ignore[attr-defined]
+    return wrapper
+
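+
+# Illustrative sketch (not part of the original module): patching a method with mock_wrapper
+# keeps the real implementation running while recording call arguments on `wrapped.mock`.
+# The Counter class below is hypothetical.
+def _example_mock_wrapper():
+    from unittest.mock import patch
+
+    class Counter:
+        def bump(self, n):
+            return n + 1
+
+    wrapped = mock_wrapper(Counter.bump)
+    with patch.object(Counter, "bump", wrapped):
+        assert Counter().bump(1) == 2          # real implementation still runs
+    wrapped.mock.assert_called_once_with(1)    # arguments were recorded on the mock
+    return wrapped.mock
+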
+def get_tensors_from(args, kwargs):
+    """ Returns a set of all Tensor objects in the given args and kwargs. """
+    return set([arg for arg in args if isinstance(arg, Tensor)] +
+               [v for v in kwargs.values() if isinstance(v, Tensor)])
+
+
+# Returns scalar tensor representation of a list of integer byte values
+def bytes_to_scalar(byte_list: list[int], dtype: torch.dtype, device: torch.device):
+    dtype_to_ctype: dict[torch.dtype, Any] = {
+        torch.int8: ctypes.c_int8,
+        torch.uint8: ctypes.c_uint8,
+        torch.uint16: ctypes.c_uint16,
+        torch.uint32: ctypes.c_uint32,
+        torch.uint64: ctypes.c_uint64,
+        torch.int16: ctypes.c_int16,
+        torch.int32: ctypes.c_int32,
+        torch.int64: ctypes.c_int64,
+        torch.bool: ctypes.c_bool,
+        torch.float32: ctypes.c_float,
+        torch.complex64: ctypes.c_float,
+        torch.float64: ctypes.c_double,
+        torch.complex128: ctypes.c_double,
+    }
+    ctype = dtype_to_ctype[dtype]
+    num_bytes = ctypes.sizeof(ctype)
+
+    def check_bytes(byte_list):
+        for byte in byte_list:
+            assert 0 <= byte <= 255
+
+    if dtype.is_complex:
+        assert len(byte_list) == (num_bytes * 2)
+        check_bytes(byte_list)
+        real = ctype.from_buffer((ctypes.c_byte * num_bytes)(
+            *byte_list[:num_bytes])).value
+        imag = ctype.from_buffer((ctypes.c_byte * num_bytes)(
+            *byte_list[num_bytes:])).value
+        res = real + 1j * imag
+    else:
+        assert len(byte_list) == num_bytes
+        check_bytes(byte_list)
+        res = ctype.from_buffer((ctypes.c_byte * num_bytes)(
+            *byte_list)).value
+
+    return torch.tensor(res, device=device, dtype=dtype)
+
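+
+# Illustrative sketch (not part of the original module): bytes are interpreted in native byte
+# order, so on little-endian platforms [1, 0] decodes to 1 and [0, 1] decodes to 256 for int16.
+def _example_bytes_to_scalar():
+    lo = bytes_to_scalar([1, 0], torch.int16, 'cpu')
+    hi = bytes_to_scalar([0, 1], torch.int16, 'cpu')
+    if sys.byteorder == 'little':
+        assert lo.item() == 1 and hi.item() == 256
+    return lo, hi
+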
+
+def copy_func(f):
+    """Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard)"""
+    g = types.FunctionType(f.__code__, f.__globals__, name=f.__name__,
+                           argdefs=f.__defaults__,
+                           closure=f.__closure__)
+    g = functools.update_wrapper(g, f)
+    g.__kwdefaults__ = f.__kwdefaults__  # type: ignore[attr-defined]
+    return g
+
+
+def xfail_inherited_tests(tests):
+    """
+    Given a list of test names which are defined by a superclass of the
+    class this decorates, mark them as expected failure.  This is useful
+    if you are doing poor man's parameterized tests by subclassing a generic
+    test class.
+    """
+    def deco(cls):
+        for t in tests:
+            # NB: expectedFailure operates by mutating the method in question,
+            # which is why you have to copy the function first
+            setattr(cls, t, unittest.expectedFailure(copy_func(getattr(cls, t))))
+        return cls
+    return deco
+
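+
+# Illustrative sketch (not part of the original module): a subclass re-running a generic test
+# class can mark inherited tests that are known to fail under its configuration. The class
+# and test names below are hypothetical.
+def _example_xfail_inherited_tests():
+    class _GenericTests(TestCase):
+        def test_unsupported_here(self):
+            raise RuntimeError("not supported in this configuration")
+
+    @xfail_inherited_tests(["test_unsupported_here"])
+    class _BackendTests(_GenericTests):
+        pass
+
+    return _BackendTests
+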
+
+def skip_but_pass_in_sandcastle_if(condition, reason):
+    """
+    Similar to unittest.skipIf; however, in the sandcastle environment it just
+    "passes" the test instead, to avoid creating tasks complaining about tests
+    skipping continuously.
+    """
+    def decorator(func):
+        if condition:
+            if IS_SANDCASTLE:
+                @wraps(func)
+                def wrapper(*args, **kwargs):
+                    print(f'Skipping {func.__name__} on sandcastle for following reason: {reason}', file=sys.stderr)
+                return wrapper
+            else:
+                func.__unittest_skip__ = True
+                func.__unittest_skip_why__ = reason
+
+        return func
+
+    return decorator
+
+def dtype_name(dtype):
+    """ Returns the pretty name of the dtype (e.g. torch.int64 -> int64). """
+    return str(dtype).split('.')[1]
+
+
+@functools.lru_cache
+def get_cycles_per_ms() -> float:
+    """Measure and return approximate number of cycles per millisecond for torch.cuda._sleep
+    """
+
+    def measure() -> float:
+        start = torch.cuda.Event(enable_timing=True)
+        end = torch.cuda.Event(enable_timing=True)
+        start.record()
+        torch.cuda._sleep(1000000)
+        end.record()
+        end.synchronize()
+        cycles_per_ms = 1000000 / start.elapsed_time(end)
+        return cycles_per_ms
+
+    # Get 10 values and remove the 2 max and 2 min and return the avg.
+    # This is to avoid system disturbances that skew the results, e.g.
+    # the very first cuda call likely does a bunch of init, which takes
+    # much longer than subsequent calls.
+    #
+    # Tested on both Tesla V100, Quadro GP100, Titan RTX, RTX 3090 GPUs
+    # and seems to return stable values. Therefore, we enable caching
+    # using lru_cache decorator above.
+    num = 10
+    vals = [measure() for _ in range(num)]
+    vals = sorted(vals)
+    return mean(vals[2 : num - 2])
+
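+
+# Illustrative sketch (not part of the original module): converting a desired delay in
+# milliseconds into a cycle count for torch.cuda._sleep (requires a CUDA device).
+def _example_sleep_for_ms(ms=5.0):
+    if torch.cuda.is_available():
+        torch.cuda._sleep(int(ms * get_cycles_per_ms()))
+        torch.cuda.synchronize()
+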
+
+# OpInfo utils
+
+T = TypeVar('T')
+def first_sample(self: unittest.TestCase, samples: Iterable[T]) -> T:
+    """
+    Returns the first sample from an iterable of samples, like those returned by OpInfo.
+    The test will be skipped if no samples are available.
+    """
+    try:
+        return next(iter(samples))
+    except StopIteration as e:
+        raise unittest.SkipTest('Skipped! Need at least 1 sample input') from e
+
+# This helper recursively clones the tensor-type inputs
+# of operators tested by OpInfo
+def clone_input_helper(input):
+    if isinstance(input, torch.Tensor):
+        return torch.clone(input)
+
+    if isinstance(input, Sequence):
+        return tuple(map(clone_input_helper, input))
+
+    return input
+
+@contextmanager
+def custom_op(opname, symbolic_fn, opset_version):
+    """Context manager/decorator to test ONNX export with custom operator"""
+    try:
+        register_custom_op_symbolic(opname, symbolic_fn, opset_version)
+        yield
+    finally:
+        unregister_custom_op_symbolic(opname, opset_version)
+
+
+def outs_and_grads(fn, graph_inps, inps):
+    outs = fn(*graph_inps)
+    for out in pytree.tree_leaves(outs):
+        if isinstance(out, torch.Tensor) and out.requires_grad:
+            out.sum().backward(retain_graph=True)
+    grads = [inp.grad for inp in pytree.tree_leaves(inps) if isinstance(inp, torch.Tensor)]
+    for inp in pytree.tree_leaves(inps):
+        if isinstance(inp, torch.Tensor):
+            inp.grad = None
+    return outs, grads
+
+def compare_equal_outs_and_grads(test, m1, m2, inps):
+    r1, g1 = outs_and_grads(m1, inps, inps)
+    r2, g2 = outs_and_grads(m2, inps, inps)
+    test.assertEqual(r1, r2)
+    test.assertEqual(g1, g2)
+
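+
+# Illustrative sketch (not part of the original module): comparing outputs and input gradients
+# of two functions expected to be numerically equivalent; `test_case` is any TestCase instance.
+def _example_compare_equal_outs_and_grads(test_case):
+    def f1(x):
+        return (x * 2).sin()
+
+    def f2(x):
+        return torch.sin(x + x)
+
+    inps = (torch.randn(4, requires_grad=True),)
+    compare_equal_outs_and_grads(test_case, f1, f2, inps)
+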
+class TestGradients(TestCase):
+    exact_dtype = True
+
+    # Copies inputs to inplace operations to avoid inplace modifications
+    #   to leaves requiring gradient
+    def _get_safe_inplace(self, inplace_variant):
+        @wraps(inplace_variant)
+        def _fn(t, *args, **kwargs):
+            return inplace_variant(t.clone(), *args, **kwargs)
+
+        return _fn
+
+    def _check_helper(self, device, dtype, op, variant, check, *, check_forward_ad=False, check_backward_ad=True,
+                      check_batched_grad=None, check_batched_forward_grad=False):
+        assert check in ('gradcheck', 'bwgrad_bwgrad', 'fwgrad_bwgrad')
+        # NB: check_backward_ad does not affect gradgradcheck (always True)
+        if variant is None:
+            self.skipTest("Skipped! Variant not implemented.")
+        if not op.supports_dtype(dtype, torch.device(device).type):
+            self.skipTest(f"Skipped! {op.name} does not support dtype {str(dtype)}")
+
+        def is_inplace(variant):
+            if hasattr(variant, "__wrapped__"):
+                return variant.__wrapped__ is op.get_inplace()
+            return variant is op.get_inplace()
+
+        include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
+
+        samples = op.sample_inputs(device, dtype, requires_grad=True, include_conjugated_inputs=include_conjugated_inputs,
+                                   small_inputs_only=TEST_WITH_SLOW_GRADCHECK)
+
+        for sample in samples:
+            if sample.broadcasts_input and is_inplace(variant):
+                continue
+
+            # Gradcheck expects tensors as its input, but autograd actually supports tensorlists
+            #   and tensors passed as kwargs. The following creates a function that accepts just
+            #   the tensors that require grad as varargs, and then recomposes them back into the
+            #   original input.
+
+            # Creates gradcheck inputs by identifying tensors requiring grad
+            all_args = None
+            if is_iterable_of_tensors(sample.input):
+                all_args = chain(sample.input, sample.args, sample.kwargs.values())
+            else:
+                all_args = tuple(chain((sample.input,), sample.args, sample.kwargs.values()))  # type: ignore[assignment]
+            gradcheck_args = tuple(x for x in all_args if (isinstance(x, torch.Tensor) and x.requires_grad))  # type: ignore[union-attr]
+
+            # Verifies that sample input tensors have no grad
+            # This may happen if the same tensor is used in two different SampleInputs
+            for t in gradcheck_args:
+                self.assertIsNone(t.grad,
+                                  "A sampled input has a gradient before running autograd. "
+                                  "This usually means that (at least) one input tensor is reused "
+                                  "across different SampleInputs. "
+                                  "Please create a new tensor for each SampleInput.")
+
+            def _input_recomposition_helper(inputs, inp, input_idx):
+                if is_iterable_of_tensors(inp):
+                    tensor_list = []
+                    for x in inp:
+                        if isinstance(x, torch.Tensor) and x.requires_grad:
+                            tensor_list.append(inputs[input_idx])
+                            input_idx = input_idx + 1
+                        else:
+                            tensor_list.append(x)
+                    return tensor_list, input_idx
+                elif isinstance(inp, torch.Tensor) and inp.requires_grad:
+                    return inputs[input_idx], input_idx + 1
+                else:
+                    return inp, input_idx
+
+            def fn(*inputs):
+                # Puts inputs back into sample properly
+                positional_args = []
+                input_idx = 0
+                inp, input_idx = _input_recomposition_helper(inputs, sample.input, input_idx)
+                positional_args.append(inp)
+
+                for x in sample.args:
+                    inp, input_idx = _input_recomposition_helper(inputs, x, input_idx)
+                    positional_args.append(inp)
+
+                # Recreates kwargs
+                kwargs = {}
+                for k, v in sample.kwargs.items():
+                    inp, input_idx = _input_recomposition_helper(inputs, v, input_idx)
+                    kwargs[k] = inp
+
+                output = op.gradcheck_wrapper(variant, *positional_args, **kwargs)
+                if sample.output_process_fn_grad is not None:
+                    return sample.output_process_fn_grad(output)
+                return output
+
+            if check == 'gradcheck':
+                if check_batched_grad is None:
+                    check_batched_grad = op.check_batched_grad
+                self.assertTrue(gradcheck(fn, gradcheck_args,
+                                          check_batched_grad=check_batched_grad,
+                                          check_grad_dtypes=True,
+                                          nondet_tol=op.gradcheck_nondet_tol,
+                                          fast_mode=op.gradcheck_fast_mode,
+                                          check_forward_ad=check_forward_ad,
+                                          check_backward_ad=check_backward_ad,
+                                          check_undefined_grad=True,
+                                          check_batched_forward_grad=check_batched_forward_grad))
+            elif check in ('bwgrad_bwgrad', 'fwgrad_bwgrad'):  # gradgrad check
+                self.assertFalse(check_forward_ad, msg="Cannot run forward AD check for gradgradcheck")
+                for gen_non_contig_grad_outputs in (False, True):
+                    kwargs = {
+                        "gen_non_contig_grad_outputs": gen_non_contig_grad_outputs,
+                        "check_batched_grad": op.check_batched_gradgrad,
+                        "check_grad_dtypes": True,
+                        "nondet_tol": op.gradcheck_nondet_tol,
+                        "fast_mode": op.gradcheck_fast_mode
+                    }
+                    if check == "fwgrad_bwgrad":
+                        kwargs["check_fwd_over_rev"] = True
+                        kwargs["check_rev_over_rev"] = False
+                        kwargs["check_batched_grad"] = False
+                        kwargs["check_undefined_grad"] = False
+
+                    self.assertTrue(gradgradcheck(fn, gradcheck_args, **kwargs))
+            else:
+                self.assertTrue(False, msg="Unknown check requested!")
+
+    def _grad_test_helper(self, device, dtype, op, variant, *, check_forward_ad=False, check_backward_ad=True,
+                          check_batched_grad=None, check_batched_forward_grad=False):
+        return self._check_helper(device, dtype, op, variant, 'gradcheck', check_forward_ad=check_forward_ad,
+                                  check_backward_ad=check_backward_ad, check_batched_grad=check_batched_grad,
+                                  check_batched_forward_grad=check_batched_forward_grad)
+
+    def _skip_helper(self, op, device, dtype):
+        if dtype not in op.supported_backward_dtypes(torch.device(device).type):
+            self.skipTest("Skipped! Op doesn't support autograd for this dtype.")
+        if not op.supports_autograd and not op.supports_forward_ad:
+            self.skipTest("Skipped! autograd not supported.")
+
+def make_lazy_class(cls):
+
+    def lazy_init(self, cb):
+        self._cb = cb
+        self._value = None
+
+    cls.__init__ = lazy_init
+
+    for basename in [
+        "add", "sub", "mul", "truediv", "floordiv", "mod", "divmod", "pow",
+        "lshift", "rshift", "and", "or", "xor", "neg", "pos", "abs", "invert",
+        "eq", "ne", "lt", "le", "gt", "ge", "bool", "int", "index",
+    ]:
+        name = f"__{basename}__"
+
+        def inner_wrapper(name):
+            use_operator = basename not in ("bool", "int")
+
+            def wrapped(self, *args, **kwargs):
+                if self._cb is not None:
+                    self._value = self._cb()
+                    self._cb = None
+                if not use_operator:
+                    return getattr(self._value, name)(*args, **kwargs)
+                else:
+                    return getattr(operator, name)(self._value, *args, **kwargs)
+            return wrapped
+
+        setattr(cls, name, inner_wrapper(name))
+
+    return cls
+
+
+# Base TestCase for NT tests; used to define common helpers, etc.
+class NestedTensorTestCase(TestCase):
+    def assertEqualIgnoringNestedInts(self, a, b):
+        # unbinding NJTs allows us to compare them as essentially equal without
+        # caring about exact nested int comparison
+        def _unbind_njts(x):
+            if isinstance(x, torch.Tensor) and x.is_nested and x.layout == torch.jagged:
+                return x.unbind()
+            else:
+                return x
+
+        self.assertEqual(pytree.tree_map(_unbind_njts, a), pytree.tree_map(_unbind_njts, b))
+
+    def assertEqualNoncontigAware(self, a, b):
+        # assertEqual() doesn't take into account lengths, so hack around this
+        # by comparing unbound components and shapes
+        self.assertEqualIgnoringNestedInts(a, b)
+
+        def _get_njt_shapes(x):
+            return (
+                x.shape
+                if isinstance(x, torch.Tensor) and x.is_nested
+                else None
+            )
+
+        a_shapes = pytree.tree_map(_get_njt_shapes, a)
+        b_shapes = pytree.tree_map(_get_njt_shapes, b)
+        self.assertEqual(a_shapes, b_shapes)
+
+    @contextlib.contextmanager
+    def branch_nested_state(self):
+        """Context manager to branch and restore the nested tensor state."""
+        nested_tensor_module = torch.nested._internal.nested_tensor
+        original_tensor_symint_registry = nested_tensor_module._tensor_symint_registry.copy()
+        original_tensor_id_counter = nested_tensor_module._tensor_id_counter
+        try:
+            yield
+        finally:
+            nested_tensor_module._tensor_id_counter = original_tensor_id_counter
+            nested_tensor_module._tensor_symint_registry = original_tensor_symint_registry
+
+
+@make_lazy_class
+class LazyVal:
+    pass
+
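+
+# Illustrative sketch (not part of the original module): LazyVal defers calling its callback
+# until the value is first used in an arithmetic or comparison expression, then caches it.
+def _example_lazy_val():
+    calls = []
+
+    def compute():
+        calls.append(1)
+        return 41
+
+    v = LazyVal(compute)
+    assert len(calls) == 0      # nothing computed yet
+    assert v + 1 == 42          # first use triggers the callback
+    assert len(calls) == 1      # and the result is cached
+    return v
+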
+
+def munge_exc(e, *, suppress_suffix=True, suppress_prefix=True, file=None, skip=0):
+    if file is None:
+        file = inspect.stack()[1 + skip].filename  # skip one frame
+
+    file = _as_posix_path(file)
+    s = _as_posix_path(str(e))
+
+    # Remove everything that looks like a stack frame from files other than this one
+    def repl_frame(m):
+        if m.group(1) != file:
+            return ""
+        # Don't accept top-level frames, even for this script; these wobble
+        # depending on how the testing script was invoked
+        if m.group(2) == "":
+            return ""
+
+        return m.group(0)
+
+    s = re.sub(r'  File "([^"]+)", line \d+, in (.+)\n(    .+\n( +[~^]+ *\n)?)+', repl_frame, s)
+    s = re.sub(r"line \d+", "line N", s)
+    s = re.sub(r".py:\d+", ".py:N", s)
+    s = re.sub(r'https:/([a-zA-Z0-9_.-]+)', r'https://\1', s)
+    s = re.sub(file, _as_posix_path(os.path.basename(file)), s)
+    s = re.sub(_as_posix_path(os.path.join(os.path.dirname(torch.__file__), "")), "", s)
+    if suppress_suffix:
+        s = re.sub(r"\n*Set TORCH_LOGS.+", "", s, flags=re.DOTALL)
+        s = re.sub(r"\n*You can suppress this exception.+", "", s, flags=re.DOTALL)
+        s = re.sub(r"\n*Set TORCHDYNAMO_VERBOSE=1.+", "", s, flags=re.DOTALL)
+    if suppress_prefix:
+        s = re.sub(r"Cannot export model.+\n\n", "", s)
+    s = re.sub(r" +$", "", s, flags=re.MULTILINE)
+    return s
+
+
+@contextmanager
+def check_leaked_tensors(limit=1, matched_type=torch.Tensor):
+    """Wrap around operations you want to ensure are not leaking tensor memory.
+
+    This code intentionally ignores other reference cycles, which can be benign and which we have plenty
+    of in pytorch code.  It focuses on any reference cycles that directly or indirectly result in holding a Tensor alive,
+    since this is likely a more serious leak than typical python refcycles.
+
+    limit specifies how many tensors to dump debug graphs for (default=1)
+    """
+    def match_obj(obj):
+        return isinstance(obj, matched_type)
+
+    try:
+        gc.collect()
+        gc.set_debug(gc.DEBUG_SAVEALL)
+        garbage_objs = []  # type: ignore[var-annotated]
+
+        # run the user code, after cleaning any existing refcycles, and then check for new ones
+        # also allow usercode to check the garbage objs (e.g. for assertion) after exiting ctxmgr
+        yield garbage_objs
+
+        gc.collect()
+        garbage_objs.extend(filter(match_obj, gc.garbage))
+        num_garbage_objs = len(garbage_objs)
+        if num_garbage_objs > 0:
+            warnings.warn(
+                f"{num_garbage_objs} tensors were found in the garbage. Did you introduce a reference cycle?"
+            )
+            try:
+                import objgraph  # type: ignore[import-not-found,import-untyped]
+                warnings.warn(
+                    f"Dumping first {limit} objgraphs of leaked {matched_type}s rendered to png"
+                )
+                for g in garbage_objs[:limit]:
+                    objgraph.show_backrefs([g], max_depth=10)
+            except ImportError:
+                warnings.warn("`pip install objgraph` to enable memory leak debugging")
+
+    finally:
+        gc.set_debug(0)
+
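+
+# Illustrative sketch (not part of the original module): a reference cycle that keeps a tensor
+# alive is surfaced in the yielded list (and triggers a warning) after the block exits.
+def _example_check_leaked_tensors():
+    with check_leaked_tensors() as leaked:
+        class _Holder:
+            pass
+        h = _Holder()
+        h.tensor = torch.ones(3)
+        h.self_ref = h          # cycle keeping the tensor alive past `del h`
+        del h
+    return leaked               # may contain the leaked tensor(s)
+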
+
+def remove_cpp_extensions_build_root():
+    """
+    Removes the default root folder under which extensions are built.
+    """
+    default_build_root = cpp_extension.get_default_build_root()
+    if os.path.exists(default_build_root):
+        if IS_WINDOWS:
+            # On Windows, rmtree fails with a permission error
+            # ([WinError 5] Access is denied); shelling out to `rm -rf` is a workaround
+            subprocess.run(["rm", "-rf", default_build_root], stdout=subprocess.PIPE)
+        else:
+            shutil.rmtree(default_build_root, ignore_errors=True)
+
+
+def install_cpp_extension(extension_root):
+    # Wipe the build / install dirs if they exist
+    build_dir = os.path.join(extension_root, "build")
+    install_dir = os.path.join(extension_root, "install")
+    for d in (build_dir, install_dir):
+        if os.path.exists(d):
+            shutil.rmtree(d)
+
+    # Build the extension
+    setup_py_path = os.path.join(extension_root, "setup.py")
+    cmd = [sys.executable, setup_py_path, "install", "--root", install_dir]
+    return_code = shell(cmd, cwd=extension_root, env=os.environ)
+    if return_code != 0:
+        raise RuntimeError(f"build failed for cpp extension at {extension_root}")
+
+    mod_install_dir = None
+    # install directory is the one that is named site-packages
+    for root, directories, _ in os.walk(install_dir):
+        for directory in directories:
+            if "-packages" in directory:
+                mod_install_dir = os.path.join(root, directory)
+
+    if mod_install_dir is None:
+        raise RuntimeError(f"installation failed for cpp extension at {extension_root}")
+
+    if mod_install_dir not in sys.path:
+        sys.path.insert(0, mod_install_dir)
+
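+# A hedged usage sketch (the extension path and module name below are hypothetical; assumes
+# the directory contains a standard setup.py-based cpp extension project):
+#
+#     install_cpp_extension("/path/to/my_cpp_extension")
+#     import my_cpp_extension  # importable because its site-packages dir was added to sys.path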
+
+# Decorator that provides a helper to load inline extensions into a temporary directory
+def scoped_load_inline(func):
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        def load_inline(*args, **kwargs):
+            if IS_WINDOWS:
+                # TODO(xmfan): even using TemporaryDirectoryName will result in permission error
+                return cpp_extension.load_inline(*args, **kwargs)
+
+            assert "build_directory" not in kwargs
+            with TemporaryDirectoryName() as temp_dir_name:
+                if kwargs.get("verbose", False):
+                    print(f'Using temporary extension directory {temp_dir_name}...', file=sys.stderr)
+                kwargs["build_directory"] = temp_dir_name
+                return cpp_extension.load_inline(*args, **kwargs)
+
+        return func(*args, load_inline=load_inline, **kwargs)
+
+    return wrapper
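+
+# A hedged usage sketch (the test name and sources are hypothetical; assumes a unittest-style
+# test method decorated so that a `load_inline` helper is injected as a keyword argument):
+#
+#     @scoped_load_inline
+#     def test_my_inline_extension(self, load_inline):
+#         module = load_inline(name="my_ext", cpp_sources=cpp_source, functions=["my_fn"])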
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/composite_compliance.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/composite_compliance.py
new file mode 100644
index 0000000000000000000000000000000000000000..8007d3563091649bb5f84199da0c38f35d32f6a5
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/composite_compliance.py
@@ -0,0 +1,608 @@
+# mypy: ignore-errors
+
+import torch
+from torch import Tensor
+import itertools
+
+from torch.utils._python_dispatch import TorchDispatchMode
+from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
+from torch.utils import _pytree as pytree
+from functools import partial
+from torch.utils._mode_utils import no_dispatch, all_same_mode
+import torch.autograd.forward_ad as fwAD
+from typing import Callable
+import re
+
+
+def check_attr_consistency(wrapper_tensor, metadata_name, metadata_accessor):
+    elem = wrapper_tensor.elem
+    metadata_wrapper_tensor = metadata_accessor(wrapper_tensor)
+    metadata_elem = metadata_accessor(elem)
+    if metadata_wrapper_tensor == metadata_elem:
+        return
+    raise RuntimeError(
+        f"This operator is not Composite Compliant: the "
+        f"{metadata_name} of the tensor was modified directly without "
+        f"going through the PyTorch dispatcher.")
+
+def check_metadata_consistency(wrapper_tensor, CCT):
+    # CCT: CompositeCompliantTensor class which is generated using generate_cct
+    if not isinstance(wrapper_tensor, CCT):
+        return
+    things_to_check = {
+        'shape': Tensor.size,
+        'dtype': lambda x: x.dtype,
+        'device': lambda x: x.device,
+        'numel': Tensor.numel,
+        'stride': Tensor.stride,
+        'storage_offset': Tensor.storage_offset,
+    }
+    for metadata_name, metadata_accessor in things_to_check.items():
+        check_attr_consistency(wrapper_tensor, metadata_name, metadata_accessor)
+
+def is_view_fn(func):
+    return func.overloadpacket.__name__ in {
+        'as_strided',
+        'detach',
+        'diagonal',
+        'expand',
+        'expand_as',
+        'movedim',
+        'narrow',
+        'permute',
+        'select',
+        'squeeze',
+        'transpose',
+        't',
+        'real',
+        'imag',
+        'view_as_real',
+        'view_as_complex',
+        'unflatten',
+        'unfold',
+        'unsqueeze',
+        'view',
+        'view_as',
+        'unbind',
+        'split',
+        'split_with_sizes',
+        'vsplit',
+        'hsplit',
+        'tensor_split',
+        'chunk',
+        'swapaxes',
+        'slice',
+        '_reshape_alias',
+        '_unsafe_view',
+        '_conj',
+        'alias',
+    }
+
+# manually populated from native_functions that have inplace_view: True.
+# In the future we will probably be able to grab that list directly
+def is_inplace_view_fn(func):
+    return func.overloadpacket.__name__ in {
+        'as_strided_',
+        'detach_',
+        'squeeze_',
+        'swapaxes_',
+        'swapdims_',
+        't_',
+        'transpose_',
+        'unsqueeze_',
+    }
+
+
+# Introspection please save us
+def is_inplace(func):
+    name = func.overloadpacket.__name__
+    if re.match('__i.+__', name):
+        return True
+    if re.match('__.+__', name):
+        return False
+    return name[-1] == '_'
+
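+# Hedged illustrations of the heuristic above (names follow the aten naming convention):
+#     is_inplace(torch.ops.aten.add_.Tensor)  -> True   (trailing underscore)
+#     is_inplace(torch.ops.aten.add.Tensor)   -> False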
+
+def generate_cct_and_mode(autograd_view_consistency=True):
+    # This function returns a new class CompositeCompliantTensor together with an
+    # instance of CompositeCompliantTensorMode. The argument controls the behaviour described below.
+
+    # autograd_view_consistency:
+    #   If True, alias the result using `set_` if func returns a view
+    #   (See Note [Alias Result]).
+    #   Since Forward AD doesn't work with `set_`,
+    #   forward-AD checks disable aliasing by passing autograd_view_consistency=False.
+
+    class CompositeCompliantTensor(torch.Tensor):
+        elem: torch.Tensor
+
+        __slots__ = ['elem']
+
+        @staticmethod
+        def __new__(cls, elem, mode, *args, **kwargs):
+            assert type(elem) is not cls, \
+                "Wrapping a CompositeCompliantTensor in a CompositeCompliantTensor is not supported"
+
+            # The storage of CompositeCompliantTensor should never be used directly
+            # by a Composite operation; if the Composite
+            # operator attempts to read from the storage without dispatching then it'll
+            # raise a RuntimeError due to it being a meta storage.
+            r = torch.Tensor._make_wrapper_subclass(
+                cls, elem.size(),
+                dtype=elem.dtype, layout=elem.layout,
+                device=elem.device, requires_grad=elem.requires_grad,
+                strides=elem.stride(), storage_offset=elem.storage_offset())
+
+            if elem.requires_grad:
+                # CompositeCompliantTensor steals the "requires_grad"-ness.
+                # Why a new copy of `elem`? Because sometimes OpInfo shares inputs between tests...
+                tmp = torch.empty(
+                    (),
+                    dtype=elem.dtype,
+                    device=elem.device,
+                    layout=elem.layout,
+                    requires_grad=False,
+                )
+                # Use set_ rather than empty_strided() + copy_ so that we can preserve
+                # things like storage_offset.
+                tmp.set_(
+                    source=elem.untyped_storage().clone(),
+                    storage_offset=elem.storage_offset(),
+                    size=elem.size(),
+                    stride=elem.stride(),
+                )
+                r.elem = tmp
+            else:
+                r.elem = elem
+
+            assert r.stride() == r.elem.stride()
+
+            # Propagate conjugate bits to the wrapper tensor
+            # Ref: https://github.com/albanD/subclass_zoo/issues/24
+            # Ref: https://github.com/albanD/subclass_zoo/issues/21
+            torch._C._set_conj(r, r.elem.is_conj())
+            torch._C._set_neg(r, r.elem.is_neg())
+
+            r.mode = mode
+            return r
+
+        def __repr__(self):
+            return f"CompositeCompliantTensor({self.elem})"
+
+        @classmethod
+        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+            all_args = pytree.arg_tree_leaves(*args, **(kwargs or {}))
+            modes = tuple(e.mode for e in all_args if isinstance(e, CompositeCompliantTensor))
+            if not all_same_mode(modes):
+                raise RuntimeError("Multiple CompositeCompliantTensorModes NYI")
+            with modes[0]:
+                return func(*args, **kwargs)
+
+    class CompositeCompliantTensorMode(TorchDispatchMode):
+        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+            def unwrap(e):
+                return e.elem if isinstance(e, CompositeCompliantTensor) else e
+
+            def wrap(e):
+                return CompositeCompliantTensor(e, self) if isinstance(e, torch.Tensor) else e
+
+            if func == torch.ops.aten._local_scalar_dense.default:
+                raise RuntimeError(
+                    ".item() is not allowed to be called inside of composite "
+                    "functions in the PyTorch library because not all backends "
+                    "and/or Tensor subclasses (e.g. vmap, ProxyTensor) support them.")
+
+            if func.overloadpacket.__name__ in ('set_', 'resize_'):
+                raise RuntimeError(
+                    f"{func.__name__} is not allowed to be called inside of "
+                    f"Composite operators.")
+
+            if is_inplace(func):
+                # NB: We are making an assumption that if the function is in-place,
+                # then the first argument is being written to. Introspection please save us!
+                mutated_argument = args[0]
+                if not isinstance(mutated_argument, CompositeCompliantTensor) and \
+                        any(isinstance(a, CompositeCompliantTensor) for a in args[1:]):
+                    raise RuntimeError(
+                        'Not composite compliant: performing in-place operation '
+                        f'{func.__name__} where the Tensor being written to is '
+                        'regular Tensor but the other tensors are Tensor Subclasses. '
+                        'Please try to avoid this in-place operation.')
+
+            unwrapped_args = tree_map(unwrap, args)
+            unwrapped_kwargs = tree_map(unwrap, kwargs)
+            unwrapped_rs = func(*unwrapped_args, **unwrapped_kwargs)
+            rs = tree_map(wrap, unwrapped_rs)
+
+            if is_view_fn(func) and autograd_view_consistency:
+                # Note [Alias Result]
+                # Autograd asserts that for B = A.view_fn(...), B and A's storages
+                # are the same. Here we try to make B alias A to avoid those asserts.
+                # See https://github.com/pytorch/pytorch/issues/65339 for more information
+                # about the issue.
+                with no_dispatch():
+                    # Idea: this is a weird way of getting a storage that aliases the input.
+                    # This is a workaround for #65339.
+                    # 1. under no_dispatch, all of the wrapper tensors look like regular
+                    #    tensors with special storage (the storage is nullptr and
+                    #    advertises a CPU/CUDA device).
+                    # 2. we run func, which ends up running the view operation
+                    # 3. All view operations reuse the input's storage and return
+                    #    result Tensor(s) with new sizes/strides/offset that alias
+                    #    the input.
+                    # 4. we set the storage (and sizes/strides/offset) of the wrapper
+                    #    tensor results to be that of the tensors that alias the input
+                    result = func(*args, **kwargs)
+                    if isinstance(result, (tuple, list)):
+                        for a, b in zip(rs, result):
+                            a.set_(b)
+                    else:
+                        rs.set_(result)
+
+            # Some operations are allowed to in-place modify the metadata of the
+            # inputs. The only ones are the "inplace view functions"; when we
+            # run into these, we manually modify the metadata of the input.
+            with no_dispatch():
+                if is_inplace_view_fn(func):
+                    func(*args, **kwargs)
+
+            # For each CompositeCompliantTensor t, we check that t and t.elem
+            # have consistent metadata. If they don't have consistent metadata,
+            # that means the operator did something fishy.
+            check = partial(check_metadata_consistency, CCT=CompositeCompliantTensor)
+            pytree.tree_map_(check, args)
+            pytree.tree_map_(check, kwargs)
+            pytree.tree_map_(check, rs)
+            return rs
+
+    return CompositeCompliantTensor, CompositeCompliantTensorMode()
+
+def is_tensorlist(lst):
+    if not isinstance(lst, list) and not isinstance(lst, tuple):
+        return False
+    if len(lst) == 0:
+        return False
+    all_tensors = all(isinstance(elt, torch.Tensor) for elt in lst)
+    if all_tensors:
+        return True
+    exists_one_tensor = any(isinstance(elt, torch.Tensor) for elt in lst)
+    if exists_one_tensor:
+        raise RuntimeError('This test assumes that PyTorch APIs cannot take '
+                           'mixed lists of Tensor and other things')
+    return False
+
+
+def maybe_map(fn, should_map, arg):
+    return fn(arg) if should_map else arg
+
+
+def wrap(arg, CCT, cct_mode):
+    # CCT: CompositeCompliantTensor class which is generated using generate_cct_and_mode
+    if isinstance(arg, torch.Tensor):
+        return CCT(arg, cct_mode)
+    if is_tensorlist(arg):
+        return [CCT(a, cct_mode) for a in arg]
+    raise RuntimeError("wrap assumes that the input can be wrapped")
+
+
+# Given a list of flat arguments, some of which may be Tensors, return all
+# possible ways some of the arguments could be CompositeCompliantTensors (CCT).
+# For example, given Tensors A, B, C and flat_args = [A, 1, B],
+# We would return the following 4 options:
+# [CCT(A), 1, CCT(B)]
+# [CCT(A), 1, B]
+# [A, 1, CCT(B)]
+# [A, 1, B]
+# NB: Yes, this is exponential. No, we don't care too much because PyTorch ops
+# don't accept that many input Tensors.
+def generate_subclass_choices(flat_args, CCT, cct_mode):
+    # CCT: CompositeCompliantTensor class which is generated using generate_cct_and_mode
+    is_tensor_likes = [isinstance(arg, torch.Tensor) or is_tensorlist(arg) for arg in flat_args]
+    subclass_options = [[False, True] if is_tensor_like else [False] for is_tensor_like in is_tensor_likes]
+
+    for which_args_are_wrapped in itertools.product(*subclass_options):
+
+        result = [maybe_map(partial(wrap, CCT=CCT, cct_mode=cct_mode), should_wrap_arg, arg)
+                  for should_wrap_arg, arg in zip(which_args_are_wrapped, flat_args)]
+        yield result, which_args_are_wrapped
+
+
+# For an operation f(*args, **kwargs), each Tensor argument may either be
+# a regular Tensor or a Tensor Subclass. This iterator iterates through
+# all of those options.
+def generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode):
+    # CCT: CompositeCompliantTensor class which is generated using generate_cct_and_mode
+    flat_kwargs, spec = tree_flatten(kwargs)
+    flat_args_kwargs = list(args) + list(flat_kwargs)
+    for choice, debug_metadata in generate_subclass_choices(flat_args_kwargs, CCT, cct_mode):
+        new_args = choice[:len(args)]
+        new_kwargs = tree_unflatten(choice[len(args):], spec)
+        which_args_are_wrapped = debug_metadata[:len(args)]
+        which_kwargs_are_wrapped = tree_unflatten(debug_metadata[len(args):], spec)
+        yield new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped
+
+
+def raise_composite_compliance_error(err, additional_info=''):
+    raise RuntimeError(
+        "Composite compliance check failed with "
+        "the above error.\n"
+        f"{additional_info}"
+        "If you are adding an OpInfo of an "
+        "existing operator, please feel free to skip this test "
+        "because the problem was pre-existing and file an issue. "
+        "Otherwise, if you added a new operator, please read "
+        "through the Composite Compliance section in "
+        "aten/src/ATen/native/README.md for how to resolve this. "
+    ) from err
+
+
+# This test checks ALL possible permutations of calling `op` with arguments
+# that are individually either a regular Tensor or a Tensor subclass.
+#
+# The general strategy is to wrap some Tensor args and kwargs in
+# CompositeCompliantTensor wrappers and call the operation.
+
+# If some composite operation does any non-compliant behavior,
+# CompositeCompliantTensor will raise an error.
+def check_all_permutations(op, args, kwargs, assert_equal_fn):
+    CCT, cct_mode = generate_cct_and_mode()
+    expected = op(*args, **kwargs)
+    for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode):
+        new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped = choice
+
+        try:
+            actual = op(*new_args, **new_kwargs)
+        # NOTE: [What errors are Composite Compliance trying to catch?]
+        #
+        # There are two things we want to catch:
+        # - errors that would raise within the torch_dispatch impl
+        # - data_ptr accesses
+        # The first is easy to filter for (we could make the error a different
+        # error class), the second is always going to be a RuntimeError due to
+        # how it is implemented (if you try to access the data_ptr of the
+        # wrapper Tensor, it raises an internal RuntimeError).
+        #
+        # So the most general thing to catch here is RuntimeError. If you
+        # are here and debugging why your test failed, it's plausible that
+        # the operator itself is broken and that there are other tests failing.
+        except RuntimeError as err:
+            raise_composite_compliance_error(
+                err,
+                f"- wrapped_args: {which_args_are_wrapped}\n"
+                f"- wrapped_kwargs: {which_kwargs_are_wrapped}\n"
+            )
+
+        def unwrap(e):
+            return e.elem if isinstance(e, CCT) else e
+
+        assert_equal_fn(tree_map(unwrap, actual), expected)
+
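+# A hedged usage sketch for check_all_permutations (assumes torch.testing.assert_close as the
+# comparison function; torch.add stands in for an arbitrary op under test):
+#
+#     check_all_permutations(torch.add, (torch.randn(3), torch.randn(3)), {},
+#                            assert_equal_fn=torch.testing.assert_close)
+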
+# Uses a torch dispatch mode to check for certain anti-patterns that are
+# not composite compliant.
+#
+# In particular, the anti-pattern we are trying to prevent is a user
+# creating an empty tensor and then resize_-ing it. Torch Dispatch Mode helps
+# here because all factory functions will create tensors that are
+# CompositeCompliantTensor.
+#
+# The general strategy is to wrap all Tensor args and kwargs in
+# CompositeCompliantTensor wrappers. If an operator that is
+# Composite does any non-compliant behavior,
+# CompositeCompliantTensor will raise an error.
+def check_with_mode(op, args, kwargs, assert_equal_fn):
+    CCT, cct_mode = generate_cct_and_mode()
+
+    def wrap(e):
+        return CCT(e, cct_mode) if isinstance(e, torch.Tensor) else e
+
+    expected = op(*args, **kwargs)
+
+    args = tree_map(wrap, args)
+    kwargs = tree_map(wrap, kwargs)
+    try:
+        with cct_mode:
+            actual = op(*args, **kwargs)
+    # see NOTE: [What errors are Composite Compliance trying to catch?]
+    except RuntimeError as err:
+        raise_composite_compliance_error(err)
+
+    def unwrap(e):
+        return e.elem if isinstance(e, CCT) else e
+
+    assert_equal_fn(tree_map(unwrap, actual), expected)
+
+def gather_leaf_tensors(args, kwargs):
+    leaf_tensors = []
+    args, _args_spec = tree_flatten(args)
+    kwargs, _kwargs_spec = tree_flatten(kwargs)
+    args = args + kwargs
+    for arg in args:
+        if not isinstance(arg, torch.Tensor):
+            continue
+        if arg.requires_grad:
+            leaf_tensors.append(arg)
+    return leaf_tensors
+
+
+def compute_expected_grads(op, args, kwargs, output_process_fn_grad=None, gradcheck_wrapper=None):
+    if gradcheck_wrapper is None:
+        results = op(*args, **kwargs)
+    else:
+        results = gradcheck_wrapper(op, *args, **kwargs)
+
+    if output_process_fn_grad is not None:
+        results = output_process_fn_grad(results)
+
+    flat_results = pytree.tree_leaves(results)
+    flat_results = [r for r in flat_results if isinstance(r, torch.Tensor)]
+    flat_diff_results = [r for r in flat_results if r.requires_grad]
+    assert len(flat_diff_results) > 0
+
+    grads = [torch.ones(r.shape, device=r.device, dtype=r.dtype) for r in flat_diff_results]
+    leaf_tensors = gather_leaf_tensors(args, kwargs)
+    assert len(leaf_tensors) > 0
+    return torch.autograd.grad(flat_diff_results, leaf_tensors,
+                               grads, allow_unused=True, retain_graph=True)
+
+
+# Checks if the backward formula is composite compliant by testing
+# all possible permutations of {inputs, grad_outputs} being
+# CompositeCompliantTensor or regular Tensors.
+#
+# NB: it is important that op is accepted as a Callable and not an OpInfo;
+# this means we can apply check_backward_formula to things that aren't OpInfos
+# while debugging.
+def check_backward_formula(op: Callable, args, kwargs,
+                           output_process_fn_grad=None,
+                           gradcheck_wrapper=None, assert_equal_fn=None):
+    CCT, cct_mode = generate_cct_and_mode()
+
+    expected = compute_expected_grads(op, args, kwargs, output_process_fn_grad, gradcheck_wrapper)
+
+    for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode):
+        new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped = choice
+        leaf_tensors = gather_leaf_tensors(new_args, new_kwargs)
+        assert len(leaf_tensors) > 0
+
+        try:
+            if gradcheck_wrapper is None:
+                results = op(*new_args, **new_kwargs)
+            else:
+                results = gradcheck_wrapper(op, *new_args, **new_kwargs)
+            if output_process_fn_grad is not None:
+                results = output_process_fn_grad(results)
+        # see NOTE: [What errors are Composite Compliance trying to catch?]
+        except RuntimeError as err:
+            raise_composite_compliance_error(
+                err,
+                f"- wrapped_args: {which_args_are_wrapped}\n"
+                f"- wrapped_kwargs: {which_kwargs_are_wrapped}\n"
+            )
+
+        flat_results = pytree.tree_leaves(results)
+        flat_results = [r for r in flat_results if isinstance(r, torch.Tensor)]
+        flat_diff_results = [r for r in flat_results if r.requires_grad]
+        assert len(flat_diff_results) > 0
+
+        # NB: ones, not ones_like, so we get a regular Tensor here
+        grads = [torch.ones(r.shape, device=r.device, dtype=r.dtype)
+                 for r in flat_diff_results]
+        for flat_new_grads, which_grad_is_batched in generate_subclass_choices(grads, CCT, cct_mode):
+            try:
+                actual = torch.autograd.grad(flat_diff_results, leaf_tensors, flat_new_grads,
+                                             allow_unused=True, retain_graph=True)
+            # see NOTE: [What errors are Composite Compliance trying to catch?]
+            except RuntimeError as err:
+                raise_composite_compliance_error(
+                    err,
+                    f"- wrapped_args: {which_args_are_wrapped}\n"
+                    f"- wrapped_kwargs: {which_kwargs_are_wrapped}\n"
+                    f"- wrapped_grads: {which_grad_is_batched}\n"
+                )
+
+            def unwrap(e):
+                return e.elem if isinstance(e, CCT) else e
+
+            assert_equal_fn(tuple(map(unwrap, actual)), expected, equal_nan=True)
+
+# Checks if the forward AD formula is composite compliant by testing
+# all possible permutations of {primals, tangents} being
+# CompositeCompliantTensor or regular Tensors.
+#
+# NB: it is important that op is accepted as a Callable and not an OpInfo;
+# this means we can apply check_forward_ad_formula to things that aren't OpInfos
+# while debugging.
+def check_forward_ad_formula(op: Callable, args, kwargs, gradcheck_wrapper=None, assert_equal_fn=None):
+    CCT, cct_mode = generate_cct_and_mode(autograd_view_consistency=False)
+
+    def maybe_tangent(t):
+        assert type(t) is not CCT
+        # Generate a `tangent` tensor
+        # if the given object is a Tensor with requires_grad set.
+        if isinstance(t, torch.Tensor) and t.requires_grad:
+            return torch.randn_like(t)
+        elif is_tensorlist(t):
+            return [torch.randn_like(e) if e.requires_grad else None for e in t]
+        return None
+
+    tangent_args = tuple(maybe_tangent(arg) for arg in args)
+    flat_kwargs, spec = tree_flatten(kwargs)
+    flat_tangent_kwargs = tuple(maybe_tangent(arg) for arg in flat_kwargs)
+    tangent_kwargs = tree_unflatten(flat_tangent_kwargs, spec)
+
+    with fwAD.dual_level():
+        def maybe_make_dual(dual):
+            # Returns dual tensor if primal is a tensor/tensor subclass
+            # with requires_grad set.
+            primal, tangent = dual
+            if isinstance(primal, torch.Tensor) and primal.requires_grad:
+                return fwAD.make_dual(primal.detach(), tangent)
+            elif is_tensorlist(primal):
+                return tuple(fwAD.make_dual(pri.detach(), tang) if tang is not None else pri
+                             for pri, tang in zip(primal, tangent))
+            return primal
+
+        def compute_expected_grad(args, tangent_args, kwargs, tangent_kwargs):
+            op_args = tuple(map(maybe_make_dual, zip(args, tangent_args)))
+            op_kwargs = {k: maybe_make_dual((v, tangent_kwargs[k])) for k, v in kwargs.items()}
+
+            if gradcheck_wrapper is None:
+                return op(*op_args, **op_kwargs)
+            return gradcheck_wrapper(op, *op_args, **op_kwargs)
+
+        expected = compute_expected_grad(args, tangent_args, kwargs, tangent_kwargs)
+        expected = tree_map(fwAD.unpack_dual, expected)
+        expected_primals = tree_map(
+            lambda x: x.primal,
+            expected,
+            is_leaf=lambda x: type(x) is fwAD.UnpackedDualTensor,
+        )
+        expected_tangents = tree_map(
+            lambda x: x.tangent,
+            expected,
+            is_leaf=lambda x: type(x) is fwAD.UnpackedDualTensor,
+        )
+
+        # Permutations of args and kwargs in CCT.
+        for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode):
+            new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped = choice
+
+            # Permutations of tangent args and tangent kwargs in CCT.
+            for tang_choice in generate_subclass_choices_args_kwargs(tangent_args, tangent_kwargs, CCT, cct_mode):
+                new_tang_args, new_tang_kwargs, \
+                    which_tang_args_are_wrapped, which_tang_kwargs_are_wrapped = tang_choice
+
+                op_args = tuple(map(maybe_make_dual, zip(new_args, new_tang_args)))
+                op_kwargs = {k: maybe_make_dual((v, new_tang_kwargs[k])) for k, v in new_kwargs.items()}
+
+                try:
+                    if gradcheck_wrapper is None:
+                        actual = op(*op_args, **op_kwargs)
+                    else:
+                        actual = gradcheck_wrapper(op, *op_args, **op_kwargs)
+                # see NOTE: [What errors are Composite Compliance trying to catch?]
+                except RuntimeError as err:
+                    raise_composite_compliance_error(
+                        err,
+                        f"- wrapped_args: {which_args_are_wrapped}\n"
+                        f"- wrapped_kwargs: {which_kwargs_are_wrapped}\n"
+                        f"- wrapped_tangent_args: {which_tang_args_are_wrapped}\n"
+                        f"- wrapped_tangent_kwargs: {which_tang_kwargs_are_wrapped}\n"
+                    )
+
+                def unwrap(e):
+                    return e.elem if isinstance(e, CCT) else e
+
+                actual = tree_map(fwAD.unpack_dual, actual)
+                actual_primals = tree_map(
+                    lambda x: unwrap(x.primal),
+                    actual,
+                    is_leaf=lambda x: type(x) is fwAD.UnpackedDualTensor,
+                )
+                actual_tangents = tree_map(
+                    lambda x: unwrap(x.tangent),
+                    actual,
+                    is_leaf=lambda x: type(x) is fwAD.UnpackedDualTensor,
+                )
+                assert_equal_fn(actual_primals, expected_primals, equal_nan=True)
+                assert_equal_fn(actual_tangents, expected_tangents, equal_nan=True)
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/custom_op_db.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/custom_op_db.py
new file mode 100644
index 0000000000000000000000000000000000000000..32982d0a3e2a358a2530abd234b37a24c6efe77d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/custom_op_db.py
@@ -0,0 +1,585 @@
+# mypy: allow-untyped-defs
+import torch
+import functools
+from torch.testing import make_tensor
+from torch.testing._internal.opinfo.core import (
+    OpInfo,
+    SampleInput,
+)
+from torch.testing._internal.common_dtype import all_types_and
+import numpy as np
+from torch.testing._internal.autograd_function_db import (
+    sample_inputs_numpy_cube,
+    sample_inputs_numpy_mul,
+    sample_inputs_numpy_mul_scalar,
+    sample_inputs_numpy_sort,
+    sample_inputs_numpy_take,
+)
+from torch import Tensor
+from torch.types import Number
+from typing import *  # noqa: F403
+
+# Note: [custom op db]
+#
+# This is a collection of custom operator test cases written as OpInfos
+# so they can easily be consumed by OpInfo-based tests to check if subsystems
+# support them correctly.
+
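+# A hedged consumption sketch (assumes the standard OpInfo.sample_inputs API; device
+# and dtype below are illustrative):
+#
+#     for opinfo in custom_op_db:
+#         for sample in opinfo.sample_inputs("cpu", torch.float32):
+#             opinfo.op(sample.input, *sample.args, **sample.kwargs)
+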
+def to_numpy(tensor):
+    return tensor.cpu().numpy()
+
+@torch.library.custom_op("_torch_testing::numpy_cube", mutates_args=())
+def numpy_cube(x: Tensor) -> tuple[Tensor, Tensor]:
+    x_np = to_numpy(x)
+    dx = torch.tensor(3 * x_np ** 2, device=x.device)
+    return torch.tensor(x_np ** 3, device=x.device), dx
+
+@numpy_cube.register_fake
+def _(x):
+    return x.clone(), x.clone()
+
+def numpy_cube_setup_context(ctx, inputs, output):
+    x, = inputs
+    _cube, dx = output
+    ctx.save_for_backward(x, dx)
+
+def numpy_cube_backward(ctx, grad_out, grad_dx):
+    x, dx = ctx.saved_tensors
+    grad_x = numpy_mul(grad_out, dx) + 6 * numpy_mul(grad_dx, x)
+    return grad_x
+
+numpy_cube.register_autograd(numpy_cube_backward, setup_context=numpy_cube_setup_context)
+
+def numpy_cube_vmap(info, in_dims, x):
+    result = numpy_cube(x)
+    return result, (in_dims[0], in_dims[0])
+
+numpy_cube.register_vmap(numpy_cube_vmap)
+
+@torch.library.custom_op("_torch_testing::numpy_mul", mutates_args=())
+def numpy_mul(x: Tensor, y: Tensor) -> Tensor:
+    return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)
+
+@numpy_mul.register_fake
+def _(x, y):
+    assert x.device == y.device
+    return (x * y).contiguous()
+
+def numpy_mul_setup_context(ctx, inputs, output):
+    ctx.save_for_backward(*inputs)
+
+def numpy_mul_backward(ctx, grad_out):
+    x, y = ctx.saved_tensors
+    grad_x = grad_out * y if ctx.needs_input_grad[0] else None
+    grad_y = grad_out * x if ctx.needs_input_grad[1] else None
+    return grad_x, grad_y
+
+numpy_mul.register_autograd(numpy_mul_backward, setup_context=numpy_mul_setup_context)
+
+def numpy_mul_vmap(info, in_dims, x, y):
+    x_bdim, y_bdim = in_dims
+    x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
+    y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
+    result = x * y
+    result = result.movedim(-1, 0)
+    return result, 0
+
+numpy_mul.register_vmap(numpy_mul_vmap)
+
+@torch.library.custom_op("_torch_testing::numpy_mul_scalar", mutates_args=())
+def numpy_mul_scalar(x: Tensor, *, scalar: float) -> Tensor:
+    return torch.tensor(to_numpy(x) * scalar, device=x.device)
+
+@numpy_mul_scalar.register_fake
+def _(x, *, scalar):
+    return (x * scalar).contiguous()
+
+def numpy_mul_scalar_setup_context(ctx, inputs, keyword_only_inputs, output):
+    ctx.scalar = keyword_only_inputs["scalar"]
+
+def numpy_mul_scalar_backward(ctx, grad_out):
+    grad_x = grad_out * ctx.scalar
+    return grad_x
+
+numpy_mul_scalar.register_autograd(numpy_mul_scalar_backward, setup_context=numpy_mul_scalar_setup_context)
+
+def numpy_mul_scalar_vmap(info, in_dims, x, *, scalar):
+    x_bdim, = in_dims
+    x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
+    result = x * scalar
+    result = result.movedim(-1, 0)
+    return result, 0
+
+numpy_mul_scalar.register_vmap(numpy_mul_scalar_vmap)
+
+@torch.library.custom_op("_torch_testing::numpy_sort", mutates_args=())
+def numpy_sort(x: Tensor, dim: int) -> tuple[Tensor, Tensor, Tensor]:
+    device = x.device
+    x = to_numpy(x)
+    ind = np.argsort(x, axis=dim)
+    ind_inv = np.argsort(ind, axis=dim)
+    result = np.take_along_axis(x, ind, axis=dim)
+    return (
+        torch.tensor(result, device=device),
+        torch.tensor(ind, device=device),
+        torch.tensor(ind_inv, device=device),
+    )
+
+@numpy_sort.register_fake
+def _(x, dim):
+    return torch.empty_like(x), torch.empty_like(x, dtype=torch.long), torch.empty_like(x, dtype=torch.long)
+
+def numpy_sort_setup_context(ctx, inputs, output):
+    _out, ind, ind_inv = output
+    ctx.dim = inputs[1]
+    ctx.save_for_backward(ind, ind_inv)
+    ctx.mark_non_differentiable(ind, ind_inv)
+
+def numpy_sort_backward(ctx, grad_out, grad_ind, grad_ind_inv):
+    ind, ind_inv = ctx.saved_tensors
+    return numpy_take(grad_out, ind_inv, ind, ctx.dim), None
+
+numpy_sort.register_autograd(numpy_sort_backward, setup_context=numpy_sort_setup_context)
+
+def numpy_sort_vmap(info, in_dims, x, dim):
+    x_bdim, _ = in_dims
+    x = x.movedim(x_bdim, 0)
+    dim = dim if dim >= 0 else dim + x.dim() - 1
+    result = numpy_sort(x, dim + 1)
+    return result, (0, 0, 0)
+
+numpy_sort.register_vmap(numpy_sort_vmap)
+
+@torch.library.custom_op("_torch_testing::numpy_take", mutates_args=())
+def numpy_take(x: Tensor, ind: Tensor, ind_inv: Tensor, dim: int) -> Tensor:
+    device = x.device
+    x = to_numpy(x)
+    ind = to_numpy(ind)
+    return torch.tensor(np.take_along_axis(x, ind, dim), device=device)
+
+@numpy_take.register_fake
+def _(x, ind, ind_inv, dim):
+    assert x.device == ind.device
+    assert x.device == ind_inv.device
+    assert ind.dtype == torch.long
+    assert ind_inv.dtype == torch.long
+    return torch.empty_like(x)
+
+def numpy_take_setup_context(ctx, inputs, output):
+    _x, ind, ind_inv, dim = inputs
+    ctx.dim = dim
+    ctx.save_for_backward(ind, ind_inv)
+
+def numpy_take_backward(ctx, grad_out):
+    ind, ind_inv = ctx.saved_tensors
+    grad_x = numpy_take(grad_out, ind_inv, ind, ctx.dim)
+    return grad_x, None, None, None
+
+numpy_take.register_autograd(numpy_take_backward, setup_context=numpy_take_setup_context)
+
+def numpy_take_vmap(info, in_dims, x, ind, ind_inv, dim):
+    x_bdim, ind_bdim, ind_inv_bdim, _ = in_dims
+
+    # wrap dim
+    logical_dim = x.dim() if x_bdim is None else x_bdim - 1
+    dim = dim if dim >= 0 else dim + logical_dim
+
+    def expand_bdim(x, x_bdim):
+        if x_bdim is None:
+            return x.expand(info.batch_size, *x.shape)
+        return x.movedim(x_bdim, 0)
+
+    x = expand_bdim(x, x_bdim)
+    ind = expand_bdim(ind, ind_bdim)
+    ind_inv = expand_bdim(ind_inv, ind_inv_bdim)
+
+    return numpy_take(x, ind, ind_inv, dim + 1), 0
+
+numpy_take.register_vmap(numpy_take_vmap)
+
+@torch.library.custom_op("_torch_testing::numpy_nonzero", mutates_args=())
+def numpy_nonzero(x: Tensor) -> Tensor:
+    x_np = to_numpy(x)
+    res = np.stack(np.nonzero(x_np), axis=1)
+    if res.shape[0] <= 1:
+        raise RuntimeError("not supported")
+    return torch.tensor(res, device=x.device)
+
+@numpy_nonzero.register_fake
+def _(x):
+    ctx = torch._custom_op.impl.get_ctx()
+    i0 = ctx.create_unbacked_symint()
+    shape = [i0, x.dim()]
+    result = x.new_empty(shape, dtype=torch.long)
+    return result
+
+def sample_inputs_numpy_nonzero(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    shape = 10
+    result = make_arg(shape, low=0.9, high=2)
+    mask = make_tensor(shape, low=0, high=2, device=device, dtype=torch.long)
+    with torch.no_grad():
+        result *= mask
+
+    yield SampleInput(result, args=())
+
+def numpy_nonzero_vmap(info, in_dims, x):
+    raise NotImplementedError("Operator is data-dependent and cannot be vmapped.")
+
+numpy_nonzero.register_vmap(numpy_nonzero_vmap)
+
+@torch.library.custom_op("_torch_testing::numpy_view_copy", mutates_args=())
+def numpy_view_copy(x: Tensor, shape: Sequence[int]) -> Tensor:
+    return torch.tensor(np.copy(to_numpy(x).reshape(shape)), device=x.device)
+
+@numpy_view_copy.register_fake
+def _(x, shape) -> Tensor:
+    return x.clone().view(shape).clone()
+
+def numpy_view_copy_setup_context(ctx, inputs, output) -> None:
+    ctx.x_shape = inputs[0].shape
+
+def numpy_view_copy_backward(ctx, grad_out):
+    return torch.ops._torch_testing.numpy_view_copy(grad_out, ctx.x_shape), None
+
+numpy_view_copy.register_autograd(numpy_view_copy_backward, setup_context=numpy_view_copy_setup_context)
+
+def numpy_view_copy_vmap(info, in_dims, x, shape):
+    x_bdim, _ = in_dims
+    x = x.movedim(x_bdim, 0)
+    x_shape = x.shape[0]
+    batch_shape = (x_shape, *shape)
+    result = numpy_view_copy(x, batch_shape)
+    return result, 0
+
+numpy_view_copy.register_vmap(numpy_view_copy_vmap)
+
+def sample_inputs_numpy_view_copy(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    result = make_arg(2, 3, 4, low=0.9, high=2)
+    yield SampleInput(result, args=([2, 12],))
+
+@torch.library.custom_op('_torch_testing::numpy_cat', mutates_args=())
+def numpy_cat(xs: Sequence[Tensor], dim: int) -> Tensor:
+    assert len(xs) > 0
+    assert all(x.device == xs[0].device for x in xs)
+    assert all(x.dtype == xs[0].dtype for x in xs)
+    np_xs = [to_numpy(x) for x in xs]
+    np_out = np.concatenate(np_xs, axis=dim)
+    return torch.tensor(np_out, device=xs[0].device)
+
+@numpy_cat.register_fake
+def _(xs, dim):
+    assert len(xs) > 0
+    assert all(x.device == xs[0].device for x in xs)
+    assert all(x.dtype == xs[0].dtype for x in xs)
+    return torch.cat(xs, dim=dim)
+
+def numpy_cat_setup_context(ctx, inputs, output):
+    xs, dim = inputs
+    ctx.dim_sizes = [x.shape[dim] for x in xs]
+    ctx.dim = dim
+
+def numpy_cat_backward(ctx, grad_out):
+    dim_sizes = ctx.dim_sizes
+    dim = ctx.dim
+
+    splits = list(np.cumsum(dim_sizes)[:-1])
+    grad_xs = torch.ops._torch_testing.numpy_split_copy(grad_out, splits, dim)
+    return grad_xs, None
+
+numpy_cat.register_autograd(numpy_cat_backward, setup_context=numpy_cat_setup_context)
+
+def numpy_cat_vmap(info, in_dims, x, dim):
+    x_bdim, = in_dims
+    result = numpy_cat(x, dim)
+    return result, x_bdim
+
+numpy_cat.register_vmap(numpy_cat_vmap)
+
+def sample_inputs_numpy_cat(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    r0 = make_arg(2, 3, 4, low=0.9, high=2)
+    r1 = make_arg(4, 3, 4, low=0.9, high=2)
+    r2 = make_arg(5, 3, 4, low=0.9, high=2)
+    yield SampleInput([r0, r1, r2], args=(0,))
+
+@torch.library.custom_op('_torch_testing::numpy_split_copy', mutates_args=())
+def numpy_split_copy(x: Tensor, splits: Sequence[int], dim: int) -> List[Tensor]:
+    x_np = to_numpy(x)
+    arrs = np.split(x_np, splits, axis=dim)
+    return [torch.tensor(arr, device=x.device, dtype=x.dtype) for arr in arrs]
+
+@numpy_split_copy.register_fake
+def _(x, splits, dim):
+    return [xi.clone() for xi in torch.tensor_split(x, splits, dim)]
+
+def numpy_split_copy_setup_context(ctx, inputs, output):
+    _, _, dim = inputs
+    ctx.dim = dim
+
+def numpy_split_copy_backward(ctx, grad_out):
+    result = torch.ops._torch_testing.numpy_cat(grad_out, dim=ctx.dim)
+    return result, None, None
+
+numpy_split_copy.register_autograd(numpy_split_copy_backward, setup_context=numpy_split_copy_setup_context)
+
+def numpy_split_copy_vmap(info, in_dims, x, splits, dim):
+    x_bdim, _ , _ = in_dims
+    x = x.movedim(x_bdim, 0)
+    result = numpy_split_copy(x, splits, dim + 1)
+    return result, 0
+
+numpy_split_copy.register_vmap(numpy_split_copy_vmap)
+
+def sample_inputs_numpy_split_copy(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    x = make_arg(2, 9, low=0.9, high=2)
+    yield SampleInput(x, args=([1, 3, 6], 1))
+
+@torch.library.custom_op('_torch_testing::numpy_split_copy_with_int', mutates_args=())
+def numpy_split_copy_with_int(x: Tensor, splits: Sequence[int], dim: int) -> tuple[List[Tensor], int]:
+    x_np = to_numpy(x)
+    arrs = np.split(x_np, splits, axis=dim)
+    return [torch.tensor(arr, device=x.device, dtype=x.dtype) for arr in arrs], len(splits)
+
+@numpy_split_copy_with_int.register_fake
+def _(x, splits, dim):
+    return [xi.clone() for xi in torch.tensor_split(x, splits, dim)], len(splits)
+
+def numpy_split_copy_with_int_setup_context(ctx, inputs, output):
+    _, _, dim = inputs
+    ctx.dim = dim
+
+def numpy_split_copy_with_int_backward(ctx, grad_out, _):
+    return torch.ops._torch_testing.numpy_cat(grad_out, dim=ctx.dim), None, None
+
+numpy_split_copy_with_int.register_autograd(
+    numpy_split_copy_with_int_backward,
+    setup_context=numpy_split_copy_with_int_setup_context)
+
+def numpy_split_copy_with_int_vmap(info, in_dims, x, splits, dim):
+    x_bdim, _ , _ = in_dims
+    x = x.movedim(x_bdim, 0)
+    result, len_split = numpy_split_copy_with_int(x, splits, dim + 1)
+    return (result, len_split), ([0 for _ in range(len(result))], None)
+
+numpy_split_copy_with_int.register_vmap(numpy_split_copy_with_int_vmap)
+
+@torch.library.custom_op("_torch_testing::numpy_nms", mutates_args=())
+def numpy_nms(boxes: Tensor, scores: Tensor, iou_threshold: Number) -> Tensor:
+    # Adapted from Ross Girshick's fast-rcnn implementation at
+    # https://github.com/rbgirshick/fast-rcnn/blob/master/lib/utils/nms.py
+    assert boxes.device == scores.device
+    device = boxes.device
+
+    boxes = to_numpy(boxes)
+    scores = to_numpy(scores)
+
+    N = boxes.shape[0]
+    assert boxes.shape == (N, 4)
+    assert scores.shape == (N,)
+
+    x1 = boxes[:, 0]
+    y1 = boxes[:, 1]
+    x2 = boxes[:, 2]
+    y2 = boxes[:, 3]
+
+    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+    order = scores.argsort()[::-1]
+
+    keep = []
+    while order.size > 0:
+        i = order[0]
+        keep.append(i)
+        xx1 = np.maximum(x1[i], x1[order[1:]])
+        yy1 = np.maximum(y1[i], y1[order[1:]])
+        xx2 = np.minimum(x2[i], x2[order[1:]])
+        yy2 = np.minimum(y2[i], y2[order[1:]])
+
+        w = np.maximum(0.0, xx2 - xx1 + 1)
+        h = np.maximum(0.0, yy2 - yy1 + 1)
+        inter = w * h
+        ovr = inter / (areas[i] + areas[order[1:]] - inter)
+
+        inds = np.where(ovr <= iou_threshold)[0]
+        order = order[inds + 1]
+
+    result = torch.tensor(np.stack(keep), device=device)
+    # Needed for data-dependent condition :(
+    assert result.size(0) >= 2
+    return result
+
+@numpy_nms.register_fake
+def _(boxes, scores, iou_threshold):
+    assert boxes.device == scores.device
+    N = boxes.shape[0]
+    assert boxes.shape == (N, 4)
+    assert scores.shape == (N,)
+
+    ctx = torch._custom_op.impl.get_ctx()
+    i0 = ctx.create_unbacked_symint()
+    result = boxes.new_empty([i0], dtype=torch.int64)
+    return result
+
+def numpy_nms_vmap(info, in_dims, boxes, scores, iou_threshold):
+    raise NotImplementedError("Operator is data-dependent and cannot be vmapped.")
+
+numpy_nms.register_vmap(numpy_nms_vmap)
+
+def sample_inputs_numpy_nms(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = functools.partial(make_tensor, device=device, dtype=dtype)
+    N = 64
+    xs = make_arg([N], low=0, high=28)
+    dx = make_arg([N], low=0, high=4)
+    ys = make_arg([N], low=0, high=28)
+    dy = make_arg([N], low=0, high=4)
+    boxes = torch.stack([xs, ys, xs + dx, ys + dy], dim=1).requires_grad_(requires_grad)
+    scores = make_arg([N], low=0, high=1, requires_grad=requires_grad)
+    iou_threshold = make_arg([], low=0, high=1).item()
+
+    yield SampleInput(boxes, args=(scores, iou_threshold))
+
+custom_op_db = [
+    OpInfo(
+        'NumpyCubeCustomOp',
+        op=numpy_cube._opoverload,
+        sample_inputs_func=sample_inputs_numpy_cube,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+    ),
+    OpInfo(
+        'NumpyMulCustomOp',
+        op=numpy_mul._opoverload,
+        sample_inputs_func=sample_inputs_numpy_mul,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+    ),
+    OpInfo(
+        'NumpyMulScalarCustomOp',
+        op=numpy_mul_scalar._opoverload,
+        sample_inputs_func=sample_inputs_numpy_mul_scalar,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+    ),
+    OpInfo(
+        'NumpySortCustomOp',
+        op=numpy_sort._opoverload,
+        sample_inputs_func=sample_inputs_numpy_sort,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+    ),
+    OpInfo(
+        'NumpyTakeCustomOp',
+        op=numpy_take._opoverload,
+        sample_inputs_func=sample_inputs_numpy_take,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+    ),
+    OpInfo(
+        'NumpyNonzeroCustomOp',
+        op=numpy_nonzero._opoverload,
+        sample_inputs_func=sample_inputs_numpy_nonzero,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_autograd=False,
+        supports_out=False,
+    ),
+    OpInfo(
+        'NumpyNMSCustomOp',
+        op=torch.ops._torch_testing.numpy_nms,
+        sample_inputs_func=sample_inputs_numpy_nms,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_autograd=False,
+        supports_out=False,
+    ),
+    OpInfo(
+        'NumpyViewCopyCustomOp',
+        op=torch.ops._torch_testing.numpy_view_copy,
+        sample_inputs_func=sample_inputs_numpy_view_copy,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_autograd=True,
+        supports_out=False,
+    ),
+    OpInfo(
+        'NumpyCatCustomOp',
+        op=torch.ops._torch_testing.numpy_cat,
+        sample_inputs_func=sample_inputs_numpy_cat,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_autograd=True,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        supports_out=False,
+    ),
+    OpInfo(
+        'NumpySplitCopyCustomOp',
+        op=torch.ops._torch_testing.numpy_split_copy,
+        sample_inputs_func=sample_inputs_numpy_split_copy,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_autograd=True,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        supports_out=False,
+    ),
+    OpInfo(
+        'NumpySplitCopyWithIntCustomOp',
+        op=torch.ops._torch_testing.numpy_split_copy_with_int,
+        sample_inputs_func=sample_inputs_numpy_split_copy,
+        dtypes=all_types_and(torch.bool, torch.half),
+        gradcheck_wrapper=lambda op, *args, **kwargs: op(*args, **kwargs)[0],
+        supports_autograd=True,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        supports_out=False,
+    ),
+]
+
+
+# ==============================================================
+# some mechanical test cases
+# ==============================================================
+
+lib = torch.library.Library("_torch_testing", "FRAGMENT")  # noqa: TOR901
+
+lib.define("source0(Tensor x) -> Tensor")
+
+@torch.library.register_fake("_torch_testing::source0", lib=lib)
+def _(x):
+    return x.clone()
+
+lib.define("source1(Tensor x) -> Tensor")
+
+def source1_fake(x):
+    return x.clone()
+
+torch.library.register_fake("_torch_testing::source1", source1_fake, lib=lib)
+
+lib.define("source2(Tensor x) -> Tensor")
+
+@torch.library.register_fake("_torch_testing::source2", lib=lib)
+def _(x):
+    return x.clone()
+
+lib.define("source3(Tensor x) -> Tensor")
+
+def source3_fake(x):
+    return x.clone()
+
+torch.library.register_fake("_torch_testing::source3", source3_fake, lib=lib)
+
+
+@torch.library.custom_op("_torch_testing::source4", mutates_args=())
+def source4(x: Tensor) -> Tensor:
+    return x.clone()
+
+@source4.register_fake
+def _(x):
+    return x.clone()
+
+@torch.library.custom_op("_torch_testing::source5", mutates_args=())
+def source5(x: Tensor) -> Tensor:
+    return x.clone()
+
+def source5_fake(x):
+    return x.clone()
+
+source5.register_fake(source5_fake)
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/custom_tensor.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/custom_tensor.py
new file mode 100644
index 0000000000000000000000000000000000000000..9fa6f79ec68a71f830acba604a5343dcbe44589f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/custom_tensor.py
@@ -0,0 +1,158 @@
+# mypy: ignore-errors
+
+
+from collections import namedtuple
+
+import torch
+import torch.utils._pytree as pytree
+from torch.utils._python_dispatch import return_and_correct_aliasing
+
+
+FancyNamedTuple = namedtuple("FancyNamedTuple", ["foo", "bar"])
+
+
+# A simple tensor subclass that holds a tensor with custom metadata and a custom method
+class ConstantExtraMetadataTensor(torch.Tensor):
+    @staticmethod
+    def __new__(cls, elem):
+        shape = elem.shape
+        kwargs = {}
+        kwargs["strides"] = elem.stride()
+        kwargs["storage_offset"] = elem.storage_offset()
+        kwargs["device"] = elem.device
+        kwargs["layout"] = elem.layout
+        kwargs["requires_grad"] = elem.requires_grad
+        kwargs["dtype"] = elem.dtype
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
+
+    def __init__(self, elem):
+        self.elem = elem
+        self.constant_attribute = 4
+
+    def __repr__(self):
+        inner_repr = repr(self.elem)
+        return f"CustomTensor({inner_repr})"
+
+    def get_complicated_metadata(self):
+        return FancyNamedTuple(self.constant_attribute, self.constant_attribute)
+
+    def __tensor_flatten__(self):
+        return ["elem"], self.constant_attribute
+
+    def add_constant(self, a):
+        self.constant_attribute += a
+
+    @staticmethod
+    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
+        assert meta is not None
+        elem = inner_tensors["elem"]
+        out = ConstantExtraMetadataTensor(elem)
+        out.constant_attribute = meta
+        return out
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        if kwargs is None:
+            kwargs = {}
+        args_inner = pytree.tree_map_only(
+            ConstantExtraMetadataTensor, lambda x: x.elem, args
+        )
+
+        kwargs_inner = pytree.tree_map_only(
+            ConstantExtraMetadataTensor, lambda x: x.elem, kwargs
+        )
+
+        out_inner = func(*args_inner, **kwargs_inner)
+        out_inner_flat, spec = pytree.tree_flatten(out_inner)
+        # for aten ops that return non-tensors, just assume that
+        # our custom inner tensors return the same value
+        out_flat = [
+            ConstantExtraMetadataTensor(o_inner)
+            if isinstance(o_inner, torch.Tensor)
+            else o_inner
+            for o_inner in out_inner_flat
+        ]
+        out = pytree.tree_unflatten(out_flat, spec)
+        return return_and_correct_aliasing(func, args, kwargs, out)
+
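+# A hedged usage sketch (the wrapped element below is illustrative):
+#
+#     t = ConstantExtraMetadataTensor(torch.randn(3))
+#     t.add_constant(2)                    # constant_attribute: 4 -> 6
+#     meta = t.get_complicated_metadata()  # FancyNamedTuple(foo=6, bar=6)
+#     y = t + 1                            # dispatch unwraps, computes, and re-wraps the result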
+
+# A simple tensor subclass that always returns a plain tensor during __torch_dispatch__.
+# It is similar to TwoTensor and is used to simulate torchao quantized tensors.
+class CustomTensorPlainOut(torch.Tensor):
+    @staticmethod
+    def __new__(cls, elem1, elem2):
+        shape = elem1.shape
+        kwargs = {}
+        kwargs["strides"] = elem1.stride()
+        kwargs["storage_offset"] = elem1.storage_offset()
+        kwargs["device"] = elem1.device
+        kwargs["layout"] = elem1.layout
+        kwargs["requires_grad"] = elem1.requires_grad
+        kwargs["dtype"] = elem1.dtype
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
+
+    def __init__(self, elem1, elem2):
+        self.elem1 = elem1
+        self.elem2 = elem2
+
+    def get_elem(self):
+        return self.elem1
+
+    def __repr__(self):
+        inner_repr_1 = repr(self.elem1)
+        inner_repr_2 = repr(self.elem2)
+        return f"CustomTensorPlainOut({inner_repr_1}, {inner_repr_2})"
+
+    def __tensor_flatten__(self):
+        return ["elem1", "elem2"], None
+
+    @staticmethod
+    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
+        elem1 = inner_tensors["elem1"]
+        elem2 = inner_tensors["elem2"]
+        out = CustomTensorPlainOut(elem1, elem2)
+        return out
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        # Don't use this tensor with view ops
+        if kwargs is None:
+            kwargs = {}
+        args_inner_1 = pytree.tree_map_only(
+            CustomTensorPlainOut, lambda x: x.elem1, args
+        )
+
+        kwargs_inner_1 = pytree.tree_map_only(
+            CustomTensorPlainOut, lambda x: x.elem1, kwargs
+        )
+
+        args_inner_2 = pytree.tree_map_only(
+            CustomTensorPlainOut, lambda x: x.elem2, args
+        )
+
+        kwargs_inner_2 = pytree.tree_map_only(
+            CustomTensorPlainOut, lambda x: x.elem2, kwargs
+        )
+
+        out_inner_1 = func(*args_inner_1, **kwargs_inner_1)
+        out_inner_2 = func(*args_inner_2, **kwargs_inner_2)
+
+        out_inner_flat_1, spec = pytree.tree_flatten(out_inner_1)
+        out_inner_flat_2, spec = pytree.tree_flatten(out_inner_2)
+
+        if func.is_view:
+            new_out = pytree.tree_unflatten(
+                (
+                    CustomTensorPlainOut(tensor1, tensor2)
+                    for tensor1, tensor2 in zip(out_inner_flat_1, out_inner_flat_2)
+                ),
+                spec,
+            )
+            return return_and_correct_aliasing(func, args, kwargs, new_out)
+
+        out_new = (
+            out_inner_flat_1[ix] + out_inner_flat_2[ix]
+            for ix in range(len(out_inner_flat_1))
+        )
+
+        return pytree.tree_unflatten(out_new, spec)
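+
+
+# A hedged usage sketch (for non-view ops the two inner results are summed into a plain tensor):
+#
+#     t = CustomTensorPlainOut(torch.ones(2), torch.ones(2))
+#     out = t * 3   # plain torch.Tensor equal to elem1 * 3 + elem2 * 3, i.e. tensor([6., 6.])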
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/data/__init__.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e3572cfc4c6a0ddc3d8fa2e1b056415204acdfa
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/data/__init__.py
@@ -0,0 +1 @@
+# mypy: ignore-errors
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/data/network1.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/data/network1.py
new file mode 100644
index 0000000000000000000000000000000000000000..8755643a78cca80668988df9e9db3de75778b5db
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/data/network1.py
@@ -0,0 +1,10 @@
+# mypy: ignore-errors
+
+import torch.nn as nn
+
+
+class Net(nn.Module):
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.linear = nn.Linear(10, 20)
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/data/network2.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/data/network2.py
new file mode 100644
index 0000000000000000000000000000000000000000..19b0b8ee53d3b530aa33978c7a13da4e5fee4ebd
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/data/network2.py
@@ -0,0 +1,11 @@
+# mypy: ignore-errors
+
+import torch.nn as nn
+
+
+class Net(nn.Module):
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.linear = nn.Linear(10, 20)
+        self.relu = nn.ReLU()
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/dist_utils.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/dist_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..45af2552cf25cef03a517f5b136c1a2e61c3a61d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/dist_utils.py
@@ -0,0 +1,199 @@
+# mypy: ignore-errors
+
+import re
+import sys
+import time
+from functools import partial, wraps
+
+import torch.distributed as dist
+import torch.distributed.rpc as rpc
+from torch.distributed.rpc import _rref_context_get_debug_info
+from torch.testing._internal.common_utils import FILE_SCHEMA, TEST_WITH_TSAN
+
+
+if not dist.is_available():
+    print("c10d not available, skipping tests", file=sys.stderr)
+    sys.exit(0)
+
+
+INIT_METHOD_TEMPLATE = FILE_SCHEMA + "{file_name}"
+
+def dist_init(
+    old_test_method=None,
+    setup_rpc: bool = True,
+    clean_shutdown: bool = True,
+    faulty_messages=None,
+    messages_to_delay=None,
+):
+    """
+    We use this decorator for setting up and tearing down state since
+    MultiProcessTestCase runs each `test*` method in a separate process and
+    each process just runs the `test*` method without actually calling
+    'setUp' and 'tearDown' methods of unittest.
+
+    Note: pass the string representation of MessageTypes that should be used
+    with the faulty agent's send function. By default, all retriable messages
+    ("RREF_FORK_REQUEST", "RREF_CHILD_ACCEPT", "RREF_USER_DELETE",
+    "CLEANUP_AUTOGRAD_CONTEXT_REQ") will use the faulty send (this default is
+    set from faulty_rpc_agent_test_fixture.py).
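+
+    Example (illustrative sketch only; ``MyRpcTest`` and the test method names
+    are hypothetical and not defined in this module)::
+
+        >>> # xdoctest: +SKIP
+        >>> class MyRpcTest(RpcAgentTestFixture):
+        >>>     @dist_init
+        >>>     def test_rpc_call(self):
+        >>>         ...  # rpc.init_rpc has already run for this process
+        >>>
+        >>>     @dist_init(setup_rpc=False)
+        >>>     def test_manual_setup(self):
+        >>>         ...  # the decorator skips rpc.init_rpc / rpc.shutdown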
+    """
+    # If we use dist_init without arguments (ex: @dist_init), old_test_method is
+    # appropriately set and we return the wrapper appropriately. On the other
+    # hand if dist_init has arguments (ex: @dist_init(clean_shutdown=False)),
+    # old_test_method is None and we return a functools.partial which is the real
+    # decorator that is used and as a result we recursively call dist_init with
+    # old_test_method and the rest of the arguments appropriately set.
+    if old_test_method is None:
+        return partial(
+            dist_init,
+            setup_rpc=setup_rpc,
+            clean_shutdown=clean_shutdown,
+            faulty_messages=faulty_messages,
+            messages_to_delay=messages_to_delay,
+        )
+
+    @wraps(old_test_method)
+    def new_test_method(self, *arg, **kwargs):
+        # Set _ignore_rref_leak to False to make sure OwnerRRefs are properly
+        # deleted in tests.
+        import torch.distributed.rpc.api as api
+
+        api._ignore_rref_leak = False
+        self.worker_id = self.rank
+        self.setup_fault_injection(faulty_messages, messages_to_delay)
+
+        rpc_backend_options = self.rpc_backend_options
+        if setup_rpc:
+            if TEST_WITH_TSAN:
+                # TSAN runs much slower.
+                rpc_backend_options.rpc_timeout = rpc.constants.DEFAULT_RPC_TIMEOUT_SEC * 5
+                rpc.constants.DEFAULT_SHUTDOWN_TIMEOUT = 60
+
+            rpc.init_rpc(
+                name=f"worker{self.rank:d}",
+                backend=self.rpc_backend,
+                rank=self.rank,
+                world_size=self.world_size,
+                rpc_backend_options=rpc_backend_options,
+            )
+
+        return_value = old_test_method(self, *arg, **kwargs)
+
+        if setup_rpc:
+            rpc.shutdown(graceful=clean_shutdown)
+
+        return return_value
+
+    return new_test_method
+
+
+def noop() -> None:
+    pass
+
+
+def wait_until_node_failure(rank: int, expected_error_regex: str = ".*") -> str:
+    """
+    Loops until an RPC to the given rank fails. This is used to
+    indicate that the node has failed in unit tests.
+    Args:
+        rank (int): Rank of the node expected to fail.
+        expected_error_regex (optional, str): Regex of the expected exception
+            message. Useful to ensure a specific failure occurs, not just any.
+    """
+    while True:
+        try:
+            rpc.rpc_sync(f"worker{rank}", noop, args=())
+            time.sleep(0.1)
+        except Exception as e:
+            if re.search(pattern=expected_error_regex, string=str(e)):
+                return str(e)
+
+
+def wait_until_pending_futures_and_users_flushed(timeout: int = 20) -> None:
+    """
+    The RRef protocol holds forkIds of rrefs in a map until those forks are
+    confirmed by the owner. The message confirming the fork may arrive after
+    our tests check whether this map is empty, which leads to failures and
+    flaky tests. to_here also does not guarantee that we have finished
+    processing the owner's confirmation message for the RRef. This function
+    loops until the map is empty, which means the messages have been received
+    and processed. Call this function before asserting the map returned by
+    _get_debug_info is empty.
+    """
+    start = time.time()
+    while True:
+        debug_info = _rref_context_get_debug_info()
+        num_pending_futures = int(debug_info["num_pending_futures"])
+        num_pending_users = int(debug_info["num_pending_users"])
+        if num_pending_futures == 0 and num_pending_users == 0:
+            break
+        time.sleep(0.1)
+        if time.time() - start > timeout:
+            raise ValueError(
+                f"Timed out waiting to flush pending futures and users, "
+                f"had {num_pending_futures} pending futures and {num_pending_users} pending users"
+            )
+
+
+def get_num_owners_and_forks() -> tuple[str, str]:
+    """
+    Retrieves number of OwnerRRefs and forks on this node from
+    _rref_context_get_debug_info.
+    """
+    rref_dbg_info = _rref_context_get_debug_info()
+    num_owners = rref_dbg_info["num_owner_rrefs"]
+    num_forks = rref_dbg_info["num_forks"]
+    return num_owners, num_forks
+
+
+def wait_until_owners_and_forks_on_rank(
+    num_owners: int, num_forks: int, rank: int, timeout: int = 20
+) -> None:
+    """
+    Waits until timeout for num_forks and num_owners to exist on the rank. Used
+    to ensure proper deletion of RRefs in tests.
+    """
+    start = time.time()
+    while True:
+        num_owners_on_rank, num_forks_on_rank = rpc.rpc_sync(
+            worker_name(rank), get_num_owners_and_forks, args=(), timeout=5
+        )
+        num_owners_on_rank = int(num_owners_on_rank)
+        num_forks_on_rank = int(num_forks_on_rank)
+        if num_owners_on_rank == num_owners and num_forks_on_rank == num_forks:
+            return
+        time.sleep(1)
+        if time.time() - start > timeout:
+            raise ValueError(
+                f"Timed out waiting {timeout} sec for {num_owners} owners and {num_forks} forks on rank,"
+                f" had {num_owners_on_rank} owners and {num_forks_on_rank} forks"
+            )
+
+
+def initialize_pg(init_method, rank: int, world_size: int) -> None:
+    # This is for tests using `dist.barrier`.
+    if not dist.is_initialized():
+        dist.init_process_group(
+            backend="gloo",
+            init_method=init_method,
+            rank=rank,
+            world_size=world_size,
+        )
+
+
+def worker_name(rank: int) -> str:
+    return f"worker{rank}"
+
+
+def get_function_event(function_events, partial_event_name):
+    """
+    Returns the first event that matches partial_event_name in the provided
+    function_events. These function_events should be the output of
+    torch.autograd.profiler.function_events().
+
+    Args:
+        function_events: function_events returned by the profiler.
+        partial_event_name (str): partial key that the event was profiled with.
+    """
+    event = [event for event in function_events if partial_event_name in event.name][0]  # noqa: RUF015
+    return event
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/__init__.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b1481de490d6ecbd7018f70131a83f764f39cc73
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/__init__.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/fake_pg.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/fake_pg.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7e604206fc37e683779696febcca49d9e09dec99
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/__pycache__/fake_pg.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/__init__.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..acc7005c6b9e3d64d1ca50714839b0732d41b5a5
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/__init__.py
@@ -0,0 +1 @@
+# mypy: allow-untyped-defs
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..838c5fd01adfc51b59c92a315542aee88a54e4f7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py
@@ -0,0 +1,103 @@
+# mypy: allow-untyped-defs
+
+import sys
+from functools import partial, wraps
+
+import torch
+import torch.distributed as dist
+from torch.distributed import rpc
+from torch.testing._internal.common_distributed import (
+    MultiProcessTestCase,
+    TEST_SKIPS,
+    tp_transports,
+)
+
+
+TEST_GPU_NUM = 4
+
+
+class ShardedTensorTestBase(MultiProcessTestCase):
+    @property
+    def world_size(self):
+        return TEST_GPU_NUM
+
+    def init_pg(self, backend="nccl"):
+        if backend not in ["nccl", "gloo", "mpi"]:
+            raise RuntimeError(f"Backend {backend} not supported!")
+
+        dist.init_process_group(
+            backend=backend,
+            world_size=self.world_size,
+            rank=self.rank,
+            init_method=f"file://{self.file_name}",
+        )
+
+        # set device for nccl pg for collectives
+        if backend == "nccl":
+            torch.cuda.set_device(self.rank)
+
+    def init_rpc(self):
+        rpc_backend_options = rpc.TensorPipeRpcBackendOptions(
+            _transports=tp_transports()
+        )
+        rpc_backend_options.init_method = f"file://{self.file_name}"
+        for rank in range(self.world_size):
+            rpc_backend_options.set_device_map(
+                f"worker{rank}", {rank: self.rank, self.rank: rank}
+            )
+
+        rpc.init_rpc(
+            name=f"worker{self.rank:d}",
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=rpc_backend_options,
+        )
+
+    def init_comms(self, init_rpc=True, backend="nccl"):
+        if init_rpc:
+            self.init_rpc()
+        self.init_pg(backend=backend)
+
+    def destroy_comms(self, destroy_rpc=True):
+        # Wait for all ranks to reach here before starting shutdown.
+        dist.barrier()
+
+        if destroy_rpc:
+            rpc.shutdown()
+        dist.destroy_process_group()
+
+    def setUp(self) -> None:
+        super().setUp()
+        self._spawn_processes()
+
+    def assert_sharded_tensor_equal(self, st1, st2):
+        st1_local_shards = st1.local_shards()
+        st2_local_shards = st2.local_shards()
+        self.assertEqual(len(st1_local_shards), len(st2_local_shards))
+        for i, st1_local_shard in enumerate(st1_local_shards):
+            self.assertEqual(st1_local_shard.tensor, st2_local_shards[i].tensor)
+            self.assertEqual(st1_local_shard.metadata, st2_local_shards[i].metadata)
+
+        self.assertEqual(st1.metadata(), st2.metadata())
+        self.assertEqual(st1.sharding_spec(), st2.sharding_spec())
+        self.assertEqual(len(st1.remote_shards()), len(st2.remote_shards()))
+
+
+# wrapper to initialize comms (process group + RPC)
+def with_comms(func=None, init_rpc=True, backend="nccl"):
+    if func is None:
+        return partial(
+            with_comms,
+            init_rpc=init_rpc,
+            backend=backend,
+        )
+
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        if backend == "nccl" and torch.cuda.device_count() < self.world_size:
+            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
+        self.init_comms(init_rpc=init_rpc, backend=backend)
+        func(self, *args, **kwargs)
+        self.destroy_comms(destroy_rpc=init_rpc)
+
+    return wrapper
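+
+
+# Illustrative usage sketch (the test class and method names below are
+# hypothetical, not part of this module):
+#
+#     class MyShardedTensorTest(ShardedTensorTestBase):
+#         @with_comms(init_rpc=False, backend="gloo")
+#         def test_something(self):
+#             ...  # process group (and optionally RPC) is set up and torn down
+#                  # around the body by init_comms()/destroy_comms()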
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/_test_ops_common.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/_test_ops_common.py
new file mode 100644
index 0000000000000000000000000000000000000000..e83bc3a35102a051d42587352c2dcb7967510903
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/_test_ops_common.py
@@ -0,0 +1,137 @@
+# mypy: allow-untyped-defs
+
+import builtins
+
+import torch
+from torch.distributed._shard.sharding_spec import (
+    ChunkShardingSpec,
+    EnumerableShardingSpec,
+    ShardMetadata,
+)
+from torch.distributed._shard.sharding_spec._internals import (
+    get_chunked_dim_size,
+    get_split_size,
+)
+
+
+def generate_chunk_sharding_specs_for_test(sharding_dim):
+    return [
+        ChunkShardingSpec(
+            dim=sharding_dim,
+            placements=[
+                "rank:0/cuda:0",
+                "rank:1/cuda:1",
+                "rank:2/cuda:2",
+                "rank:3/cuda:3",
+            ],
+        ),
+        # Test different ordering. (Case 1)
+        ChunkShardingSpec(
+            dim=sharding_dim,
+            placements=[
+                "rank:2/cuda:2",
+                "rank:3/cuda:3",
+                "rank:0/cuda:0",
+                "rank:1/cuda:1",
+            ],
+        ),
+        # Test different ordering. (Case 2)
+        ChunkShardingSpec(
+            dim=sharding_dim,
+            placements=[
+                "rank:3/cuda:3",
+                "rank:0/cuda:0",
+                "rank:1/cuda:1",
+                "rank:2/cuda:2",
+            ],
+        ),
+    ]
+
+
+def generate_enumerable_sharding_specs_for_test():
+    return [
+        EnumerableShardingSpec(
+            [
+                ShardMetadata(
+                    shard_offsets=[0, 0],
+                    shard_sizes=[5, 5],
+                    placement="rank:0/cuda:0",
+                ),
+                ShardMetadata(
+                    shard_offsets=[5, 0],
+                    shard_sizes=[5, 5],
+                    placement="rank:1/cuda:1",
+                ),
+                ShardMetadata(
+                    shard_offsets=[0, 5],
+                    shard_sizes=[5, 5],
+                    placement="rank:2/cuda:2",
+                ),
+                ShardMetadata(
+                    shard_offsets=[5, 5],
+                    shard_sizes=[5, 5],
+                    placement="rank:3/cuda:3",
+                ),
+            ]
+        )
+    ]
+
+
+def generate_local_weight_sharding_params_for_test(
+    local_weight, sharded_dim, gpu_num, spec, rank
+):
+    """
+    Shard the local weight based on the given spec, so we can compare against
+    the one from the sharded tensor.
+
+    Args:
+        local_weight: weight matrix to be sharded.
+        sharded_dim: The dimension which we shard on.
+        gpu_num: number of ranks.
+        spec: sharding spec.
+        rank: rank of the CUDA process.
+
+    Returns:
+        start_pos: start position of sharded weight on the given rank.
+        chunk_size: chunk size of sharded weight on the given rank.
+    """
+    sharding_dim_size = local_weight.size(sharded_dim)
+    split_size = get_split_size(sharding_dim_size, gpu_num)
+    current_offsets = 0
+    start_pos = current_offsets
+    for idx, placement in enumerate(spec.placements):
+        chunk_size = get_chunked_dim_size(sharding_dim_size, split_size, idx)
+        if rank == placement.rank():
+            start_pos = current_offsets
+            break
+        current_offsets += chunk_size
+    return start_pos, chunk_size
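+
+
+# Illustrative worked example (assuming the chunking helpers above split by
+# ceil division): a weight of size 10 on the sharded dim over 4 ranks gives
+# chunk sizes [3, 3, 3, 1], so with placements ordered rank 0..3, rank 2 gets
+# start_pos=6, chunk_size=3 and rank 3 gets start_pos=9, chunk_size=1.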
+
+
+def clone_module_parameter(module, param_name):
+    """
+    Clone a parameter from a given existing module.
+
+    Args:
+        module (:class:`torch.nn.Module`): Module whose parameter needs to be cloned.
+        param_name (str): Name of the parameter of ``module`` that needs to be cloned.
+
+    Returns: cloned tensor as :class:`torch.nn.Parameter`.
+    """
+    tensor = getattr(module, param_name)
+    return torch.nn.Parameter(tensor.detach().clone())
+
+
+def gen_binary_op_func(python_op, inplace=False):
+    src_lines = ["def f(lhs, rhs):"]
+    if "torch" in python_op:
+        src_lines.append(f"  return {python_op}(lhs, rhs)\n")
+    elif inplace:
+        src_lines.append(f"  lhs {python_op}= rhs\n  return lhs\n")
+    else:
+        src_lines.append(f"  return lhs {python_op} rhs\n")
+
+    code_str = "\n".join(src_lines)
+    g = {"torch": torch}
+    builtins.exec(code_str, g)
+    return g["f"]
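+
+
+# Illustrative examples of the generated functions (based on the templates above):
+#     gen_binary_op_func("+")                 # f(lhs, rhs) -> lhs + rhs
+#     gen_binary_op_func("+", inplace=True)   # f(lhs, rhs) -> lhs += rhs; return lhs
+#     gen_binary_op_func("torch.add")         # f(lhs, rhs) -> torch.add(lhs, rhs)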
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/_test_st_common.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/_test_st_common.py
new file mode 100644
index 0000000000000000000000000000000000000000..1fe82a8dc43f8f876cb4c8d0c000cda9a32d46fb
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/sharded_tensor/_test_st_common.py
@@ -0,0 +1,56 @@
+# mypy: allow-untyped-defs
+
+import copy
+import random
+
+import torch
+from torch.distributed._shard import sharded_tensor
+from torch.distributed._shard.sharding_spec import ChunkShardingSpec
+
+
+PLACEMENTS = [
+    "rank:0/cuda:0",
+    "rank:1/cuda:1",
+    "rank:2/cuda:2",
+    "rank:3/cuda:3",
+]
+
+DEFAULT_GPU_NUM = 4
+
+
+def _chunk_sharding_specs_list_for_test(sharding_dims, seed=0):
+    spec_list = []
+    for i in range(len(sharding_dims)):
+        random.Random(seed + i).shuffle(PLACEMENTS)
+        spec_list.append(
+            ChunkShardingSpec(
+                dim=sharding_dims[i],
+                placements=copy.deepcopy(PLACEMENTS),
+            )
+        )
+    return spec_list
+
+
+class MyShardedModel2(torch.nn.Module):
+    def __init__(self, spec=None, group=None, init_rrefs=True) -> None:
+        super().__init__()
+        if spec is not None:
+            self.sharded_tensor2 = sharded_tensor.rand(
+                spec, 10, 20, process_group=group, init_rrefs=init_rrefs
+            )
+        else:
+            self.sharded_tensor2 = None
+        self.random_tensor2 = torch.nn.Parameter(torch.rand(2, 2))
+
+
+class MyShardedModel1(torch.nn.Module):
+    def __init__(self, spec=None, group=None, init_rrefs=True) -> None:
+        super().__init__()
+        if spec is not None:
+            self.sharded_tensor1 = sharded_tensor.rand(
+                spec, 10, 20, process_group=group, init_rrefs=init_rrefs
+            )
+        else:
+            self.sharded_tensor1 = None
+        self.random_tensor1 = torch.nn.Parameter(torch.rand(2, 2))
+        self.submodule = MyShardedModel2(spec, group, init_rrefs)
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/test_common.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/test_common.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9390da489851872ec1d0715a0b3e46275e5752b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_shard/test_common.py
@@ -0,0 +1,41 @@
+# mypy: allow-untyped-defs
+
+import torch
+import torch.nn as nn
+from torch.distributed._shard.sharded_tensor import ShardedTensor
+
+
+class SimpleMegatronLM(nn.Module):
+    def __init__(self, linear_size, rank=None, dtype=torch.float32):
+        super().__init__()
+        self.fc1 = nn.Linear(*linear_size[0], dtype=dtype)
+        self.gelu = nn.GELU()
+        self.fc2 = nn.Linear(*linear_size[1], dtype=dtype)
+        if rank is not None:
+            self.fc1.cuda(rank)
+            self.fc2.cuda(rank)
+
+    def forward(self, inp):
+        return self.fc2(self.gelu(self.fc1(inp)))
+
+    def get_weights(self):
+        if isinstance(self.fc1.weight, ShardedTensor):
+            weight1 = self.fc1.weight.local_tensor()
+        else:
+            weight1 = self.fc1.weight
+
+        if isinstance(self.fc2.weight, ShardedTensor):
+            weight2 = self.fc2.weight.local_tensor()
+        else:
+            weight2 = self.fc2.weight
+
+        return (weight1, weight2)
+
+    def get_biases(self):
+        return (self.fc1.bias, self.fc2.bias)
+
+    def get_weight_grads(self):
+        return (self.fc1.weight.grad, self.fc2.weight.grad)
+
+    def get_bias_grads(self):
+        return (self.fc1.bias.grad, self.fc2.bias.grad)
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_tensor/__init__.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_tensor/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_tensor/common_dtensor.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_tensor/common_dtensor.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e9a9a55f6774ab67114abe8d7895fd0829cf88a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/_tensor/common_dtensor.py
@@ -0,0 +1,615 @@
+# mypy: allow-untyped-defs
+
+# Copyright (c) Meta Platforms, Inc. and affiliates
+
+import itertools
+import sys
+from collections.abc import Iterator, Sequence
+from dataclasses import dataclass
+from functools import partial, wraps
+from typing import Any, Callable, cast, Optional, TypeVar, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
+from torch._utils import _get_device_module
+from torch.distributed.tensor import (
+    DeviceMesh,
+    distribute_tensor,
+    Placement,
+    Replicate,
+    Shard,
+)
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    parallelize_module,
+    PrepareModuleInput,
+    RowwiseParallel,
+    SequenceParallel,
+)
+from torch.testing._internal.common_distributed import (
+    MultiProcessTestCase,
+    MultiThreadedTestCase,
+    run_subtests,
+    skip_if_lt_x_gpu,
+    TEST_SKIPS,
+)
+from torch.testing._internal.common_utils import TEST_CUDA, TEST_HPU, TEST_XPU
+from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
+
+
+if TEST_CUDA:
+    DEVICE_TYPE = "cuda"
+    PG_BACKEND = "nccl"
+    DEVICE_COUNT = _get_device_module("cuda").device_count()
+elif TEST_HPU:
+    DEVICE_TYPE = "hpu"
+    PG_BACKEND = "hccl"
+    DEVICE_COUNT = _get_device_module("hpu").device_count()
+elif TEST_XPU:
+    DEVICE_TYPE = "xpu"
+    PG_BACKEND = "xccl"
+    DEVICE_COUNT = _get_device_module("xpu").device_count()
+else:
+    DEVICE_TYPE = "cpu"
+    PG_BACKEND = "gloo"
+
+NUM_DEVICES = 4
+
+# We use this as a proxy for "multiple GPUs exist"
+if (TEST_CUDA or TEST_XPU or TEST_HPU) and DEVICE_COUNT > 1:
+    # when we actually have multiple devices, relax the requirement to the number available.
+    NUM_DEVICES = min(NUM_DEVICES, DEVICE_COUNT)
+
+T = TypeVar("T")
+
+
+# simple RMSNorm layer for testing
+class RMSNormPython(torch.nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+        self.weight = torch.nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x):
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x):
+        output = self._norm(x)
+        return output * self.weight
+
+
+class MLPModule(nn.Module):
+    def __init__(self, device, bias: bool = True):
+        super().__init__()
+        torch.manual_seed(5)
+        self.net1 = nn.Linear(10, 16, bias=bias, device=device)
+        self.relu = nn.ReLU()
+        self.net2 = nn.Linear(16, 10, bias=bias, device=device)
+
+    def forward(self, x):
+        return self.net2(self.relu(self.net1(x)))
+
+    def reset_parameters(self):
+        self.net1.reset_parameters()
+        self.net2.reset_parameters()
+
+
+class MLPStacked(nn.Module):
+    def __init__(self, device, n_layers: int = 2):
+        super().__init__()
+        self.layers = nn.ModuleList([MLPModule(device) for i in range(n_layers)])
+
+    def forward(self, x):
+        for layer in self.layers:
+            x = layer(x)
+        return x
+
+
+@dataclass
+class ModelArgs:
+    n_layers: int = 2
+    vocab_size: int = 8
+    max_seq_len: int = 16
+    dim: int = 16
+    n_heads: int = 4
+    dropout_p: float = 0.1
+    use_attn_mask: bool = True
+    weight_tying: bool = True
+    checkpoint_activations: bool = False
+
+
+class Attention(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        assert args.dim % args.n_heads == 0
+        self.head_dim = args.dim // args.n_heads
+        self.n_heads = args.n_heads
+        self.dropout_p = args.dropout_p
+        self.resid_dropout = nn.Dropout(args.dropout_p)
+        self.use_attn_mask = args.use_attn_mask
+
+        self.wq = nn.Linear(args.dim, args.dim, bias=False)
+        self.wk = nn.Linear(args.dim, args.dim, bias=False)
+        self.wv = nn.Linear(args.dim, args.dim, bias=False)
+        self.wo = nn.Linear(args.dim, args.dim, bias=False)
+
+    def forward(self, x):
+        bsz, seq_len, _ = x.size()
+        queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
+        queries = queries.view(bsz, seq_len, self.n_heads, self.head_dim)
+        keys = keys.view(bsz, seq_len, self.n_heads, self.head_dim)
+        values = values.view(bsz, seq_len, self.n_heads, self.head_dim)
+
+        queries = queries.transpose(1, 2)  # (bsz, n_heads, seq_len, head_dim)
+        keys = keys.transpose(1, 2)  # (bsz, n_heads, seq_len, head_dim)
+        values = values.transpose(1, 2)  # (bsz, n_heads, seq_len, head_dim)
+
+        output = F.scaled_dot_product_attention(
+            queries,
+            keys,
+            values,
+            None,
+            self.dropout_p if self.training else 0,
+            self.use_attn_mask,
+        )
+        output = output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
+        return self.resid_dropout(self.wo(output))
+
+
+class FeedForward(nn.Module):
+    def __init__(self, dim, hidden_dim, dropout_p):
+        super().__init__()
+        self.w1 = nn.Linear(dim, hidden_dim)
+        self.gelu = nn.GELU()
+        self.w2 = nn.Linear(hidden_dim, dim)
+        self.resid_dropout = nn.Dropout(dropout_p)
+
+    def forward(self, x):
+        return self.resid_dropout(self.w2(self.gelu(self.w1(x))))
+
+
+class TransformerBlock(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.attention_norm = nn.LayerNorm(args.dim)
+        self.attention = Attention(args)
+        self.ffn_norm = nn.LayerNorm(args.dim)
+        self.feed_forward = FeedForward(
+            args.dim, hidden_dim=4 * args.dim, dropout_p=args.dropout_p
+        )
+
+    def forward(self, x):
+        h = x + self.attention(self.attention_norm(x))
+        out = h + self.feed_forward(self.ffn_norm(h))
+        return out
+
+
+# A toy transformer model, partly inspired by the nanoGPT model:
+# https://github.com/karpathy/nanoGPT.
+class Transformer(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        assert args.vocab_size is not None
+        assert args.max_seq_len is not None
+        self.model_args = args
+        self.max_seq_len = args.max_seq_len
+        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
+        self.pos_embeddings = nn.Embedding(args.max_seq_len, args.dim)
+        self.dropout = nn.Dropout(args.dropout_p)
+        self.layers = nn.ModuleList()
+        for _ in range(args.n_layers):
+            self.layers.append(TransformerBlock(args))
+        self.norm = nn.LayerNorm(args.dim)
+        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
+        if args.weight_tying:
+            self.output.weight = self.tok_embeddings.weight
+        self.checkpoint_activations = args.checkpoint_activations
+
+    def forward(self, tokens):
+        _bsz, seq_len = tokens.size()
+        assert seq_len <= self.max_seq_len
+        h = self.tok_embeddings(tokens)
+        pos = torch.arange(0, seq_len, device=tokens.device)
+        p = self.pos_embeddings(pos)  # positional embeddings of shape (seq_len, dim)
+        h = h + p
+        h = self.dropout(h)
+        for layer in self.layers:
+            if self.checkpoint_activations:
+                h = torch.utils.checkpoint.checkpoint(layer, h, use_reentrant=False)
+            else:
+                h = layer(h)
+        h = self.norm(h)
+        output = self.output(h).float()
+        return output
+
+    @staticmethod
+    def parallelize(
+        module: "Transformer",
+        device_mesh: DeviceMesh,
+        use_seq_parallel: bool,
+        local_output_for_attn: bool = False,
+    ) -> nn.Module:
+        assert isinstance(module, Transformer), f"Requires Transformer but got {module}"
+        # Parallelize the root submodules.
+        if use_seq_parallel:
+            root_plan = {
+                "tok_embeddings": RowwiseParallel(
+                    input_layouts=Replicate(), output_layouts=Shard(1)
+                ),
+                "pos_embeddings": RowwiseParallel(
+                    input_layouts=Replicate(), output_layouts=Shard(0)
+                ),
+                "norm": SequenceParallel(),
+            }
+        else:
+            root_plan = {
+                "tok_embeddings": RowwiseParallel(
+                    input_layouts=Replicate(), output_layouts=Replicate()
+                ),
+                "pos_embeddings": RowwiseParallel(
+                    input_layouts=Replicate(), output_layouts=Replicate()
+                ),
+            }
+
+        module_tp = parallelize_module(module, device_mesh, root_plan)
+        # Parallelize the attention and feed forward submodules.
+        for layer in module_tp.layers:
+            layer_parallelize_plan = {}
+            if use_seq_parallel:
+                layer_parallelize_plan["attention"] = PrepareModuleInput(
+                    input_layouts=Shard(1),
+                    desired_input_layouts=Replicate(),
+                )
+                # shard the RMSNorms
+                layer_parallelize_plan["attention_norm"] = SequenceParallel()
+                layer_parallelize_plan["ffn_norm"] = SequenceParallel()
+            layer_parallelize_plan["attention.wq"] = ColwiseParallel(
+                use_local_output=local_output_for_attn
+            )
+            layer_parallelize_plan["attention.wk"] = ColwiseParallel(
+                use_local_output=local_output_for_attn
+            )
+            layer_parallelize_plan["attention.wv"] = ColwiseParallel(
+                use_local_output=local_output_for_attn
+            )
+            layer_parallelize_plan["attention.wo"] = (
+                RowwiseParallel(output_layouts=Shard(1))
+                if use_seq_parallel
+                else RowwiseParallel()
+            )
+
+            layer_parallelize_plan["feed_forward.w1"] = (
+                ColwiseParallel(input_layouts=Shard(1))
+                if use_seq_parallel
+                else ColwiseParallel()
+            )
+            layer_parallelize_plan["feed_forward.w2"] = (
+                RowwiseParallel(output_layouts=Shard(1))
+                if use_seq_parallel
+                else RowwiseParallel()
+            )
+
+            parallelize_module(layer, device_mesh, layer_parallelize_plan)
+
+        # Parallelize the output submodule. If weight tying is enabled, we need to
+        # make sure output.weight is sharded consistently with tok_embeddings.weight,
+        # at the cost of the all_reduce operation using RowwiseParallel.
+        output_parallelize_plan = (
+            ColwiseParallel(
+                input_layouts=Shard(1),
+                output_layouts=Replicate(),
+            )
+            if use_seq_parallel
+            else ColwiseParallel(output_layouts=Replicate())
+        )
+        parallelize_module(module_tp.output, device_mesh, output_parallelize_plan)
+
+        if local_output_for_attn:
+            for layer in module_tp.layers:
+                layer.attention.n_heads = (
+                    module_tp.model_args.n_heads // device_mesh.size()
+                )
+
+        # Manually set output.weight so that parameters and gradients are shared.
+        if module_tp.model_args.weight_tying:
+            module_tp.output.weight = module_tp.tok_embeddings.weight
+
+        return module_tp
+
+
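+# Illustrative usage sketch (``device_mesh`` is a 1-D DeviceMesh built
+# elsewhere, with a size that divides ModelArgs.n_heads):
+#
+#     args = ModelArgs()
+#     model = Transformer(args).to(DEVICE_TYPE)
+#     tp_model = Transformer.parallelize(model, device_mesh, use_seq_parallel=True)
+#     tokens = torch.randint(0, args.vocab_size, (2, args.max_seq_len), device=DEVICE_TYPE)
+#     logits = tp_model(tokens)  # shape (2, max_seq_len, vocab_size)
+
+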
+def skip_unless_torch_gpu(method: T) -> T:
+    """
+    Test decorator which skips the test unless there's a GPU available to torch.
+
+    >>> # xdoctest: +SKIP
+    >>> @skip_unless_torch_gpu
+    >>> def test_some_method(self) -> None:
+    >>>   ...
+    """
+    # The builtin @skip_if_no_gpu relies on os.environ['WORLD_SIZE'] being set.
+    return cast(T, skip_if_lt_x_gpu(NUM_DEVICES)(method))
+
+
+class DTensorTestBase(MultiProcessTestCase):
+    @property
+    def world_size(self) -> int:
+        return NUM_DEVICES
+
+    @property
+    def device_type(self) -> str:
+        # if there are enough GPUs/XPUs/HPUs we can use those devices, otherwise we fall back to CPU
+        if not (TEST_CUDA or TEST_XPU or TEST_HPU) or DEVICE_COUNT < self.world_size:
+            return "cpu"
+        else:
+            return DEVICE_TYPE
+
+    @property
+    def backend(self) -> str:
+        backend = dist.get_default_backend_for_device(DEVICE_TYPE)
+        return backend
+
+    def build_device_mesh(self) -> DeviceMesh:
+        return DeviceMesh(self.device_type, list(range(self.world_size)))
+
+    def init_pg(self, eager_init) -> None:
+        if "nccl" in self.backend and torch.cuda.device_count() < self.world_size:
+            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
+
+        if self.backend not in [
+            "nccl",
+            "gloo",
+            "mpi",
+            "cpu:gloo,cuda:nccl",
+            "hccl",
+            "xccl",
+        ]:
+            raise RuntimeError(f"Backend {self.backend} not supported!")
+
+        device_id = None
+        if "nccl" in self.backend or "xccl" in self.backend:
+            # set device for nccl pg for collectives
+            torch.accelerator.set_device_index(self.rank)
+            # we only need to set device_id for nccl backend with eager init
+            device_id = (
+                torch.device(f"{self.device_type}:{self.rank}") if eager_init else None
+            )
+        # For the nccl backend, bind the device to the process if device_id is not None
+        # so the nccl communicator is immediately formed and we can use `ncclCommSplit`
+        # to form subgroups, avoiding unnecessary overhead.
+        dist.init_process_group(
+            backend=self.backend,
+            world_size=self.world_size,
+            rank=self.rank,  # pyre-ignore[16]
+            init_method=f"file://{self.file_name}",  # pyre-ignore[16]
+            device_id=device_id,
+        )
+
+    def destroy_pg(self, device_id: Optional[int] = None) -> None:
+        # Wait for all ranks to reach here before starting shutdown.
+        # FIXME dist.barrier deadlocks with multiple threads and NCCL: https://github.com/pytorch/pytorch/issues/95895
+        # dist.all_reduce(torch.zeros((1,), device="cuda" if TEST_CUDA else "cpu"))
+        # FIXME can't use the above all_reduce as it causes hangs on bionic and focal. It hangs:
+        #  test_dtensor.py  -- DTensorMeshTest.test_dtensor_device_mesh_device_conversion
+        if device_id is None:
+            device_id = (
+                torch.cuda.current_device() if self.device_type == "cuda" else self.rank
+            )
+        dist.barrier(device_ids=[device_id])
+        dist.destroy_process_group()
+
+    def setUp(self) -> None:
+        super().setUp()
+        self._spawn_processes()
+
+    # pyre-ignore[2]:
+    def _test_op(self, mesh: DeviceMesh, op_call, *args, **kwargs) -> None:
+        out = op_call(*args, **kwargs)
+        dtc = DTensorConverter(mesh, args, kwargs)
+        for d_args, d_kwargs in dtc:
+            # pyre can't find assertTrue anymore?
+            self.assertEqual(dtc.successful(), True)
+            d_out = op_call(*d_args, **d_kwargs)
+            self.assertEqual(d_out.full_tensor(), out)
+
+    def run_subtests(self, *args, **kwargs):
+        return run_subtests(self, *args, **kwargs)
+
+
+TestFunc = Callable[[...], object]
+
+
+# wrapper to initialize comms (process group)
+def with_comms(eager_init: Union[TestFunc, bool] = False) -> TestFunc:
+    def decorator(func, eager_init: bool = False):
+        @wraps(func)  # pyre-ignore[6]
+        def wrapper(
+            self, *args: tuple[object], **kwargs: dict[str, Any]  # type: ignore[misc]
+        ) -> None:
+            self.init_pg(eager_init)
+
+            try:
+                func(self, *args, **kwargs)  # type: ignore[misc]
+            except Exception as e:
+                dist.destroy_process_group()
+                raise e
+
+            self.destroy_pg()
+
+        return wrapper
+
+    return (
+        decorator(func=eager_init)
+        if callable(eager_init)
+        else partial(decorator, eager_init=eager_init)
+    )
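+
+
+# Illustrative usage sketch (test method names are hypothetical):
+#
+#     class MyDTensorTest(DTensorTestBase):
+#         @with_comms
+#         def test_lazy_pg_init(self):
+#             ...  # init_pg(eager_init=False) ran before this body
+#
+#         @with_comms(eager_init=True)
+#         def test_eager_pg_init(self):
+#             ...  # the process group was bound to the device eagerly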
+
+
+class DTensorOpTestBase(MultiThreadedTestCase):
+    @property
+    def world_size(self) -> int:
+        return NUM_DEVICES
+
+    @property
+    def device_type(self) -> str:
+        return DEVICE_TYPE
+
+    def build_device_mesh(self):
+        return DeviceMesh(self.device_type, list(range(self.world_size)))
+
+    def setUp(self) -> None:
+        super().setUp()
+        self._spawn_threads()
+
+
+# This is a class for converting args/kwargs of an op into distributed args/kwargs
+class DTensorConverter:
+    def __init__(
+        self,
+        mesh: DeviceMesh,
+        args: tuple[object, ...],
+        kwargs: dict[str, object],
+    ) -> None:
+        self.hit = 0
+        self.miss = 0
+        self.mesh = mesh
+        self.args = args
+        self.kwargs = kwargs
+        flatten_args, flatten_args_spec = tree_flatten(args)
+        flatten_kwargs, flatten_kwargs_spec = tree_flatten(kwargs)
+
+        self.flatten_args: list[object] = flatten_args
+        self.flatten_args_spec: TreeSpec = flatten_args_spec
+        self.flatten_kwargs: list[object] = flatten_kwargs
+        self.flatten_kwargs_spec: TreeSpec = flatten_kwargs_spec
+
+        choices_for_args = [
+            self.gen_sharding_choices_for_arg(arg)
+            for arg in self.flatten_args
+            if isinstance(arg, torch.Tensor)
+        ]
+
+        choices_for_args.extend(
+            self.gen_sharding_choices_for_arg(arg)
+            for arg in self.flatten_kwargs
+            if isinstance(arg, torch.Tensor)
+        )
+
+        self.sharding_combs: Iterator[Sequence[Placement]] = iter(
+            itertools.product(*choices_for_args)
+        )
+
+    def successful(self) -> bool:
+        return self.hit > 0 and self.miss == 0
+
+    def is_supported_tensor(self, t: torch.Tensor) -> bool:
+        # TODO: dist tensor needs to support quantized and sparse
+        # tensors. Quantized tensors might be relatively easy, but
+        # sparse tensors have special layouts that we may need to
+        # deal with; until we are clear about them, we don't officially
+        # support them.
+        return not any(
+            [
+                t.is_sparse_csr,
+                t.is_sparse,
+                t.is_mkldnn,
+                t.is_quantized,
+                t.is_nested,
+                torch._is_functional_tensor(t),
+                t.is_neg(),
+                t.is_conj(),
+                t.device.type in ("lazy", "meta"),
+                # We need a way to test if a tensor is batched but there
+                # is no official API to do it
+                # torch._C._is_batched(t),
+            ]
+        )
+
+    def gen_sharding_choices_for_arg(self, arg: torch.Tensor) -> Sequence[Placement]:
+        mesh_size = self.mesh.size()
+        sharding_choices: list[Placement] = [Replicate()]
+        # c10d collective does not support bool tensor
+        # for bool tensor we treat it as replicated
+        if arg.dtype != torch.bool:
+            # only generating choices with: replicate, or sharding
+            # evenly on a dimension that could be sharded
+            sharding_choices = sharding_choices + [
+                Shard(i)
+                for i, s in enumerate(arg.shape)
+                if s > 1 and s % mesh_size == 0
+            ]
+        # TODO: add multi mesh choices
+        # all_choices = itertools.product(
+        #     *(self.mesh.ndim * [sharding_choices])
+        # )
+        return sharding_choices
+
+    def __iter__(self) -> "DTensorConverter":
+        return self
+
+    def __next__(self) -> tuple[tuple[object, ...], dict[str, object]]:
+        try:
+            next_sharding_choices = next(self.sharding_combs)
+            idx = 0
+
+            new_args: list[object] = []
+            for arg in self.flatten_args:
+                if isinstance(arg, torch.Tensor):
+                    new_args.append(
+                        self.to_dist_tensor(
+                            arg, self.mesh, [next_sharding_choices[idx]]
+                        )
+                    )
+                    idx += 1
+                else:
+                    new_args.append(arg)
+
+            new_kwargs: list[object] = []
+            for arg in self.flatten_kwargs:
+                if isinstance(arg, torch.Tensor):
+                    new_kwargs.append(
+                        self.to_dist_tensor(
+                            arg, self.mesh, [next_sharding_choices[idx]]
+                        )
+                    )
+                    idx += 1
+                else:
+                    new_kwargs.append(arg)
+
+            return (
+                tree_unflatten(new_args, self.flatten_args_spec),
+                tree_unflatten(new_kwargs, self.flatten_kwargs_spec),
+            )
+        except StopIteration as e:
+            raise StopIteration from e
+
+    def to_dist_tensor(
+        self, t: torch.Tensor, mesh: DeviceMesh, placements: list[Placement]
+    ) -> torch.Tensor:
+        if type(t) is torch.Tensor or type(t) is nn.Parameter:
+            if self.is_supported_tensor(t):
+                self.hit += 1
+                if t.ndim == 0:
+                    # scalar tensor by default will be replicated
+                    r = distribute_tensor(t, mesh, [Replicate()] * mesh.ndim)
+                else:
+                    # distribute non-scalar tensors
+                    r = distribute_tensor(t, mesh, placements)
+                if type(t) is nn.Parameter:
+                    r = nn.Parameter(  # type: ignore[assignment]
+                        r, requires_grad=r.requires_grad
+                    )
+                return r
+            else:
+                self.miss += 1
+                return t
+        elif torch.overrides.is_tensor_like(t):
+            # Blindly converting tensor subclasses to dist tensor can cause
+            # unpredictable problems, so we explicitly disable this conversion
+            # for now (i.e. we don't support DTensor holding a tensor subclass
+            # until there's a strong reason to later).
+            self.miss += 1
+            return t
+        else:
+            raise RuntimeError(f"Trying to convert to DTensor, but got {type(t)}")
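+
+
+# Illustrative usage sketch (``mesh`` is any 1-D DeviceMesh built elsewhere,
+# e.g. via DTensorTestBase.build_device_mesh()):
+#
+#     dtc = DTensorConverter(mesh, (torch.randn(4, 8),), {})
+#     for d_args, d_kwargs in dtc:
+#         d_out = torch.sum(*d_args, **d_kwargs)  # run the op on DTensor inputs
+#     assert dtc.successful()  # every tensor argument was convertible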
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/checkpoint_utils.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/checkpoint_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d4d4a216270ce60e68b562c579eac68688e2a02
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/checkpoint_utils.py
@@ -0,0 +1,159 @@
+# mypy: allow-untyped-defs
+
+# Copyright (c) Meta Platforms, Inc. and affiliates
+
+import io
+import os
+import shutil
+import tempfile
+from functools import wraps
+from typing import Any, Callable, cast, IO, Optional
+
+# introduced as collections.abc.Buffer in Python 3.12
+from typing_extensions import Buffer
+
+import torch.distributed as dist
+from torch.distributed.checkpoint._extension import (
+    ExtensionRegistry,
+    StreamTransformExtension,
+)
+
+
+class Rot13Example(StreamTransformExtension):
+    """
+    This is an example stream transform extension which just does rot13 on each
+    alphanumeric character of the stream.  It is mainly intended as a demonstration
+    and for testing; there isn't a production use case for this.
+    """
+
+    def __init__(self, chunk_size: int = io.DEFAULT_BUFFER_SIZE) -> None:
+        super().__init__()
+        self._chunk_size = chunk_size
+
+    @staticmethod
+    def from_descriptor(version: str) -> "Rot13Example":
+        if version.partition(".")[0] != "1":
+            raise ValueError(f"Unknown extension {version=}")
+        return Rot13Example()
+
+    @staticmethod
+    def registry_name() -> str:
+        return "stream.rot13"
+
+    def get_descriptor(self) -> str:
+        return f"{self.registry_name()}/1"
+
+    @staticmethod
+    def _rot13bytes(b: Buffer, count: int) -> None:
+        b = memoryview(b)
+        for i in range(count):
+            ch = b[i]
+            if ch >= ord("A") and ch <= ord("Z"):
+                ch += ord("a") - ord("A")
+            elif ch >= ord("a") and ch <= ord("z"):
+                ch += ord("A") - ord("a")
+            b[i] = ch
+
+    def transform_to(self, output: IO[bytes]) -> IO[bytes]:
+        class Writer(io.RawIOBase):
+            def __init__(self, output: IO[bytes]) -> None:
+                self.output = output
+
+            def writable(self) -> bool:
+                return True
+
+            def write(self, b: Buffer) -> Optional[int]:
+                # Don't mutate the input
+                chunk = bytearray(b)
+                Rot13Example._rot13bytes(chunk, len(chunk))
+                return self.output.write(chunk)
+
+            def flush(self) -> None:
+                self.output.flush()
+
+        return cast(IO[bytes], Writer(output))
+
+    def transform_from(self, input: IO[bytes]) -> IO[bytes]:
+        class Reader(io.RawIOBase):
+            def __init__(self, input: IO[bytes]) -> None:
+                self.input = input
+
+            def readable(self) -> bool:
+                return True
+
+            def readinto(self, b: Buffer) -> Optional[int]:
+                if hasattr(self.input, "readinto"):
+                    count = self.input.readinto(b)
+                else:
+                    # It's possible self.input is an IO[bytes] with no readinto method.
+                    # In that case, we emulate with a read and copy.  In practice,
+                    # all of the current concrete extensions have readinto.
+                    view = memoryview(b)
+                    r = self.input.read(len(view))
+                    if r is None:
+                        count = None
+                    else:
+                        count = len(r)
+                        view[:count] = r
+                if count == 0 or count is None:
+                    return count
+
+                Rot13Example._rot13bytes(b, count)
+                return count
+
+            def seekable(self) -> bool:
+                return self.input.seekable()
+
+            def seek(self, offset: int, whence: int = os.SEEK_SET) -> int:
+                return self.input.seek(offset, whence)
+
+            def tell(self) -> int:
+                return self.input.tell()
+
+        return cast(IO[bytes], Reader(input))
+
+
+def get_test_extension_registry() -> ExtensionRegistry:
+    registry = ExtensionRegistry()
+    registry.register(Rot13Example)
+    return registry
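+
+
+# Illustrative round-trip sketch for the extension above (io.BytesIO stands in
+# for the real checkpoint streams; the transformed on-disk bytes themselves are
+# not asserted here):
+#
+#     buf = io.BytesIO()
+#     writer = Rot13Example().transform_to(buf)
+#     writer.write(b"hello world")
+#     reader = Rot13Example().transform_from(io.BytesIO(buf.getvalue()))
+#     assert reader.read() == b"hello world"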
+
+
+def with_temp_dir(
+    func: Optional[Callable] = None,
+) -> Optional[Callable]:
+    """
+    Wrapper to initialize temp directory for distributed checkpoint.
+    """
+    assert func is not None
+
+    @wraps(func)
+    def wrapper(self, *args: tuple[object], **kwargs: dict[str, Any]) -> None:
+        if dist.is_initialized():
+            # Only create temp_dir when rank is 0
+            if dist.get_rank() == 0:
+                temp_dir = tempfile.mkdtemp()
+                print(f"Using temp directory: {temp_dir}")
+            else:
+                temp_dir = ""
+            object_list = [temp_dir]
+
+            # Broadcast temp_dir to all the other ranks
+            os.sync()
+            dist.broadcast_object_list(object_list)
+            self.temp_dir = object_list[0]
+            os.sync()
+        else:
+            temp_dir = tempfile.mkdtemp()
+            print(f"No process group initialized, using temp directory: {temp_dir}")
+            self.temp_dir = temp_dir
+
+        try:
+            func(self, *args, **kwargs)
+        finally:
+            shutil.rmtree(self.temp_dir, ignore_errors=True)
+
+    return wrapper
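+
+
+# Illustrative usage sketch (assumes a test class that already initializes a
+# process group, e.g. via the with_comms decorator from common_dtensor; the
+# class and method names are hypothetical):
+#
+#     class MyCheckpointTest(DTensorTestBase):
+#         @with_comms
+#         @with_temp_dir
+#         def test_save_load(self):
+#             checkpoint_dir = self.temp_dir  # same path on every rank (broadcast from rank 0)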
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/common_state_dict.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/common_state_dict.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7d79907bdbed55911b5eb319efa69adc7a56dcb
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/common_state_dict.py
@@ -0,0 +1,170 @@
+# mypy: allow-untyped-defs
+
+# Owner(s): ["oncall: distributed"]
+
+import copy
+from itertools import chain
+from typing import Any
+
+import torch
+import torch.nn as nn
+from torch.distributed._sharded_tensor import ShardedTensor
+from torch.distributed._state_dict_utils import _gather_state_dict
+from torch.distributed.checkpoint.state_dict import (
+    _PG,
+    _STATE,
+    set_state_dict,
+    StateDictOptions,
+)
+from torch.distributed.tensor import DTensor
+
+
+class VerifyStateDictMixin:
+    def _compare_tensor(self, orig_tensor, dist_tensor, offload_to_cpu=False):
+        if isinstance(dist_tensor, (DTensor, ShardedTensor)):
+            dist_tensor = _gather_state_dict({"mykey": dist_tensor}).pop("mykey")
+
+        if offload_to_cpu:
+            orig_tensor = orig_tensor.cpu()
+            dist_tensor = dist_tensor.cpu()
+        self.assertTrue(isinstance(dist_tensor, torch.Tensor))
+        self.assertTrue(torch.allclose(orig_tensor, dist_tensor))
+
+    def _verify_msd(
+        self,
+        msd: dict[str, Any],
+        dist_msd: dict[str, Any],
+        options: StateDictOptions = StateDictOptions(),
+        offload_to_cpu=False,
+    ) -> None:
+        if not options.ignore_frozen_params:
+            self.assertEqual(len(msd), len(dist_msd))
+        for fqn, param in msd.items():
+            dist_param = dist_msd.get(fqn, None)
+            if not options.ignore_frozen_params:
+                self.assertIsNotNone(dist_param, f"{fqn=}")
+                try:
+                    self._compare_tensor(param, dist_param, offload_to_cpu)
+                except AssertionError as e:
+                    raise AssertionError(
+                        f"{fqn} has mismatched value {param} {dist_param}"
+                    ) from e
+            elif dist_param is None:
+                self.assertFalse(param.requires_grad, f"{fqn=}")
+
+    def _verify_osd(
+        self,
+        model: nn.Module,
+        optim: torch.optim.Optimizer,
+        osd: dict[str, Any],
+        dist_osd: dict[str, Any],
+    ) -> None:
+        params = list(chain.from_iterable(g["params"] for g in optim.param_groups))
+        param_pid_mapping = dict(zip(params, range(len(params))))
+        fqn_pid_mapping = {}
+        for fqn, param in model.named_parameters():
+            pid = param_pid_mapping[param]
+            fqn_pid_mapping[fqn] = pid
+            fqn_pid_mapping[pid] = fqn
+        # Check optimizer_state_dict state
+
+        self.assertEqual(len(osd[_STATE]), len(dist_osd[_STATE]))
+        for pid, states in osd[_STATE].items():
+            fqn = fqn_pid_mapping[pid]
+            dist_states = dist_osd[_STATE].get(fqn, None)
+            self.assertIsNotNone(dist_states, fqn)
+            self.assertEqual(len(states), len(dist_states))
+            for key, state in states.items():
+                dist_state = dist_states.get(key, None)
+                self.assertIsNotNone(dist_state)
+                self._compare_tensor(state, dist_state)
+
+        # Check optimizer_state_dict param_group
+        old_dist_osd_pg = dist_osd[_PG]
+        if len(osd[_PG]) != len(dist_osd[_PG]):
+            self.assertTrue(len(dist_osd[_PG]) > len(osd[_PG]))
+            new_pg = copy.deepcopy(dist_osd[_PG][0])
+            new_pg["params"] = []
+            for dist_group in dist_osd[_PG]:
+                new_pg["params"].extend(dist_group["params"])
+            dist_osd[_PG] = [new_pg]
+
+        self.assertEqual(len(osd[_PG]), len(dist_osd[_PG]))
+        for group, dist_group in zip(osd[_PG], dist_osd[_PG]):
+            self.assertEqual(len(group), len(dist_group))
+            for key, value in group.items():
+                # Below doesn't work because param_groups can have None
+                # values.
+                # dist_value = dist_group.get(key, None)
+                # self.assertIsNotNone(dist_value, (dist_group, group))
+                dist_value = dist_group[key]
+                if key == "params":
+                    fqns = [fqn_pid_mapping[pid] for pid in value]
+                    self.assertEqual(sorted(fqns), sorted(dist_value))
+                else:
+                    self.assertEqual(value, dist_value)
+        dist_osd[_PG] = old_dist_osd_pg
+
+    def _verify_osd_by_load(
+        self,
+        model: nn.Module,
+        optim: torch.optim.Optimizer,
+        new_optim: torch.optim.Optimizer,
+        dist_osd: dict[str, Any],
+    ) -> None:
+        new_dist_osd = _gather_state_dict(dist_osd)
+        set_state_dict(
+            model,
+            optimizers=new_optim,
+            model_state_dict={},
+            optim_state_dict=new_dist_osd,
+        )
+        self.assertEqual(optim.state_dict(), new_optim.state_dict())
+
+
+class FusionEmbedding(nn.Module):
+    def __init__(self, vocab_size: int, fusion_vocab_size: int, embed_dim: int) -> None:
+        super().__init__()
+        self.embedding = nn.Embedding(vocab_size, embed_dim)
+        self.fusion_embedding = nn.Embedding(fusion_vocab_size, embed_dim)
+
+
+class FusionEmbeddingWithHook(nn.Module):
+    def __init__(self, vocab_size: int, fusion_vocab_size: int, embed_dim: int) -> None:
+        super().__init__()
+        self.embedding = nn.Embedding(vocab_size, embed_dim)
+        self.fusion_embedding = nn.Embedding(fusion_vocab_size, embed_dim)
+        self._register_state_dict_hook(FusionEmbeddingWithHook._state_dict_hook)
+        self._register_load_state_dict_pre_hook(
+            FusionEmbeddingWithHook._load_state_dict_hook, with_module=True
+        )
+
+    def _state_dict_hook(self, destination, prefix, keep_vars):
+        """Remove "embedding" from the original embedding in the state_dict
+        name. This keeps the original state dict name for the embedding
+        from before fusing with the FusionEmbedding.
+        """
+        key = prefix + "embedding.weight"
+        new_key = prefix + "weight"
+        destination[new_key] = destination[key]
+        del destination[key]
+
+    def _load_state_dict_hook(self, state_dict, prefix, *args, **kwargs):
+        """Apply extra "embedding" prefix to the state_dict key to
+        account for the FusionEmbedding wrapping.
+        """
+        if state_dict:
+            key = prefix + "weight"
+            new_key = prefix + "embedding.weight"
+            state_dict[new_key] = state_dict[key]
+            del state_dict[key]
+
+
+class FusionEmbeddingWithModifier(FusionEmbeddingWithHook):
+    # _fqn_modifiers is a private function that acts as a contract with distributed state dict (DSD).
+    # When users change the state_dict keys, they need to provide a mapping from the new key to the
+    # original key. This is used to ensure consistency between the state_dict keys and the FQNs.
+    def _fqn_modifiers(self) -> dict[str, str]:
+        return {
+            "weight": "embedding",
+        }
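+
+
+# A minimal, illustrative sketch (not exercised by any test) of what the hooks
+# above do to state_dict keys: saving exposes "weight" in place of
+# "embedding.weight", and loading maps the key back. The helper name below is
+# hypothetical and only documents the round trip.
+def _example_fusion_hook_roundtrip() -> None:
+    module = FusionEmbeddingWithHook(vocab_size=4, fusion_vocab_size=2, embed_dim=3)
+    sd = module.state_dict()
+    # _state_dict_hook renamed "embedding.weight" to "weight".
+    assert "weight" in sd and "embedding.weight" not in sd
+    # _load_state_dict_hook restores the "embedding." prefix while loading, so
+    # the renamed state_dict loads back without key errors.
+    module.load_state_dict(sd)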
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ac9252d498e06659c94b76fb265bb404b01f2b8
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py
@@ -0,0 +1,743 @@
+# mypy: allow-untyped-defs
+
+import contextlib
+import enum
+import logging
+import os
+import threading
+from typing import NamedTuple
+
+import torch
+import torch.distributed as dist
+import torch.distributed.autograd as dist_autograd
+import torch.nn as nn
+from torch.distributed import rpc
+from torch.distributed.nn import RemoteModule
+from torch.nn.parallel import DistributedDataParallel
+from torch.testing._internal.common_distributed import (
+    requires_gloo,
+    requires_nccl,
+    skip_if_lt_x_gpu,
+    skip_if_rocm_multiprocess,
+)
+from torch.testing._internal.dist_utils import dist_init, INIT_METHOD_TEMPLATE
+from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
+    RpcAgentTestFixture,
+)
+
+
+NUM_EM_ROW = 2
+D_SPARSE = 3
+D_DENSE = 2
+D_HID = 3
+D_OUT = 1
+NUM_TRAINERS = 4
+# Trainers + the master + the remote worker
+WORLD_SIZE = NUM_TRAINERS + 2
+TRAINER_RANKS = list(range(NUM_TRAINERS))
+REMOTE_WORKER_RANK = TRAINER_RANKS[-1] + 1
+MASTER_RANK = REMOTE_WORKER_RANK + 1
+
+
+class DdpMode(enum.Enum):
+    # Don't apply DDP
+    NONE = enum.auto()
+    # Apply DDP to the top level nn.Module
+    OUTSIDE = enum.auto()
+    # Embed DDP inside the top level nn.Module
+    INSIDE = enum.auto()
+
+
+def init_logger():
+    logger = logging.getLogger(__name__)
+    level = logging.DEBUG if "debug" in os.environ else logging.INFO
+    logger.setLevel(level)
+    console = logging.StreamHandler()
+    formatter = logging.Formatter(
+        "%(asctime)s %(filename)s:%(lineno)s %(levelname)s p:%(processName)s t:%(threadName)s: %(message)s"
+    )
+    console.setFormatter(formatter)
+    console.setLevel(level)
+    # add the handlers to the logger
+    logger.addHandler(console)
+    logger.propagate = False
+    return logger
+
+
+gLogger = init_logger()
+
+
+class FeatureSet(NamedTuple):
+    """A feature set has 2 types of features"""
+
+    dense_features: torch.Tensor
+    sparse_features: torch.LongTensor
+    values: torch.Tensor
+
+
+def _call_method(method, rref, *args, **kwargs):
+    return method(rref.local_value(), *args, **kwargs)
+
+
+def _remote_method(method, rref, *args, **kwargs):
+    args_tup = tuple([method, rref] + list(args))
+    return rpc.rpc_sync(rref.owner(), _call_method, args=args_tup, kwargs=kwargs)
+
+
+def _remote_method_async(method, rref, *args, **kwargs):
+    args_tup = tuple([method, rref] + list(args))
+    return rpc.rpc_async(rref.owner(), _call_method, args=args_tup, kwargs=kwargs)
+
+
+class RemoteEM(nn.Module):
+    def __init__(self, num_embeddings: int, embedding_dim: int):
+        gLogger.info("Initing RemoteEM with %s %s", num_embeddings, embedding_dim)
+        super().__init__()
+        init_em = [0.5] * embedding_dim
+        self.em = nn.EmbeddingBag(
+            num_embeddings,
+            embedding_dim,
+            _weight=torch.tensor([init_em] * num_embeddings),
+        )
+
+    def forward(self, input: torch.Tensor):
+        gLogger.debug("Running RemoteEM.forward() on: %s", input)
+        return self.em(input, offsets=torch.LongTensor(range(input.shape[0])))
+
+
+# Return a linear module with predefined parameters.
+def getLinear(d_in, d_out):
+    l = nn.Linear(d_in, d_out, bias=False)
+    w = torch.ones((d_out, d_in))
+    w[0][0] = -1
+    w.requires_grad_()
+    l.weight.data = w
+    return l
+
+
+class RemoteNet(nn.Module):
+    def __init__(self, d_in: int, d_out: int):
+        gLogger.info("Initing RemoteNet with %s %s", d_in, d_out)
+        super().__init__()
+        self.fc = getLinear(d_in, d_out)
+        self.relu = nn.ReLU()
+
+    def forward(self, input: torch.Tensor):
+        gLogger.debug("Running RemoteNet.forward() on: %s", input)
+        return self.relu(self.fc(input))
+
+
+class HybridModel(nn.Module):
+    def __init__(
+        self,
+        remote_em_rref: rpc.RRef,
+        remote_net_rref: rpc.RRef,
+        process_group_for_ddp: dist.ProcessGroup = None,
+    ):
+        super().__init__()
+        self.remote_em_rref = remote_em_rref
+        self.remote_net_rref = remote_net_rref
+        self.fc1 = getLinear(D_DENSE, D_DENSE)
+        self.fc2 = getLinear(D_HID, D_OUT)
+
+        self.non_ddp_params = tuple(self.fc1.parameters()) + tuple(
+            self.fc2.parameters()
+        )
+        self.ddp_params = ()
+
+        if process_group_for_ddp is not None:
+            self.non_ddp_params, self.ddp_params = (
+                tuple(self.fc1.parameters()),
+                tuple(self.fc2.parameters()),
+            )
+            gLogger.info("Use DDP for the second local net.")
+            self.fc2 = DistributedDataParallel(
+                self.fc2, check_reduction=True, process_group=process_group_for_ddp
+            )
+
+        gLogger.info(
+            "HybridModel has %s groups of parameters.", len(list(self.parameters()))
+        )
+
+    def forward(self, input: FeatureSet):
+        gLogger.debug("Running HybridModel.forward on %s", input)
+        sparse = _remote_method(
+            RemoteEM.forward, self.remote_em_rref, input.sparse_features
+        )
+        # The sparse output must have the same mini-batch size as the dense
+        # features (the full shape flow is sketched after this class).
+        assert sparse.shape[0] == input.dense_features.shape[0]
+        dense = self.fc1(input.dense_features)
+        x = torch.cat((dense, sparse), 1)
+        gLogger.debug("Concatenated feature: %s", x)
+        x = _remote_method(RemoteNet.forward, self.remote_net_rref, x)
+        return self.fc2(x)
+
+
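+# A hedged, local-only sketch of the shape flow through HybridModel.forward,
+# running the "remote" modules in-process (no RPC) purely to illustrate the
+# dimensions; the helper name is hypothetical:
+# dense (N, D_DENSE) -> fc1 -> (N, D_DENSE); sparse (N,) -> RemoteEM -> (N, D_SPARSE);
+# cat -> (N, D_DENSE + D_SPARSE) -> RemoteNet -> (N, D_HID) -> fc2 -> (N, D_OUT).
+def _example_hybrid_shape_flow() -> None:
+    n = 4
+    em = RemoteEM(NUM_EM_ROW, D_SPARSE)
+    net = RemoteNet(D_DENSE + D_SPARSE, D_HID)
+    fc1 = getLinear(D_DENSE, D_DENSE)
+    fc2 = getLinear(D_HID, D_OUT)
+    dense = fc1(torch.zeros(n, D_DENSE))
+    sparse = em(torch.zeros(n, dtype=torch.long))
+    out = fc2(net(torch.cat((dense, sparse), 1)))
+    assert out.shape == (n, D_OUT)
+
+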
+class Trainer:
+    def __init__(
+        self,
+        remote_em_rref: rpc.RRef,
+        remote_net_rref: rpc.RRef,
+        ddp_mode: DdpMode,
+        rank: int,
+    ):
+        self.rank = rank
+        self.trainer_group = (
+            dist.new_group(TRAINER_RANKS)
+            if ddp_mode in (DdpMode.INSIDE, DdpMode.OUTSIDE)
+            else None
+        )
+        self.remote_em_rref = remote_em_rref
+        self.remote_net_rref = remote_net_rref
+        self.hybrid_module = HybridModel(
+            self.remote_em_rref,
+            self.remote_net_rref,
+            self.trainer_group if ddp_mode in (DdpMode.INSIDE,) else None,
+        )
+        self.ddp_params, self.non_ddp_params = (
+            self.hybrid_module.ddp_params,
+            self.hybrid_module.non_ddp_params,
+        )
+        if ddp_mode == DdpMode.OUTSIDE:
+            gLogger.info("Wrapping the whole hybrid module into DDP.")
+            self.ddp_params += self.non_ddp_params
+            self.non_ddp_params = ()
+            self.hybrid_module = DistributedDataParallel(
+                self.hybrid_module,
+                check_reduction=True,
+                process_group=self.trainer_group,
+            )
+        gLogger.info(
+            "Succeeded in creating a HybridModel instance with "
+            "%s ddp params and %s other local params.",
+            len(self.ddp_params),
+            len(self.non_ddp_params),
+        )
+
+    def destroy_pg(self):
+        if self.trainer_group:
+            dist.destroy_process_group(self.trainer_group)
+
+    def train_batch(
+        self,
+        mini_batch: FeatureSet,
+        trainer_has_less_inputs: bool,
+        simulate_uneven_inputs: bool,
+    ):
+        grads_dict = None
+
+        if not simulate_uneven_inputs:
+            input_batches = [mini_batch]
+        else:
+            # Split into microbatches, and trim to simulate uneven inputs.
+            dense_features = mini_batch.dense_features
+            sparse_features = mini_batch.sparse_features
+            values = mini_batch.values
+
+            dense_microbatch = torch.split(dense_features, 2)
+            sparse_microbatch = torch.split(sparse_features, 2)
+            values_microbatch = torch.split(values, 2)
+            batches = []
+            for d, s, v in zip(dense_microbatch, sparse_microbatch, values_microbatch):
+                feature_set = FeatureSet(dense_features=d, sparse_features=s, values=v)
+                batches.append(feature_set)
+
+            if trainer_has_less_inputs:
+                input_batches = batches[: len(batches) // 2]
+                gLogger.info(
+                    "Trainer reduced input patches from %s "
+                    "to %s to simulate uneven inputs.",
+                    len(batches),
+                    len(input_batches),
+                )
+            else:
+                input_batches = batches
+
+        with self.hybrid_module.join() if simulate_uneven_inputs else contextlib.nullcontext():
+            for b in input_batches:
+                with dist_autograd.context() as context_id:
+                    output = self.hybrid_module.forward(b)
+                    loss = (output * mini_batch.values).sum()
+                    dist_autograd.backward(context_id, [loss])
+                    grads_dict = dist_autograd.get_gradients(context_id)
+                    gLogger.info(
+                        "Loss is %s for mini batch: %s. "
+                        "Grads dict has %s entries: %s",
+                        loss,
+                        mini_batch,
+                        len(grads_dict),
+                        grads_dict,
+                    )
+        return (
+            tuple(grads_dict[param] for param in self.ddp_params),
+            tuple(grads_dict[param] for param in self.non_ddp_params),
+        )
+
+
+def get_training_examples():
+    n = 16
+    training_examples = FeatureSet(
+        dense_features=torch.zeros((n, D_DENSE)),
+        sparse_features=torch.zeros(n, dtype=torch.long),
+        values=torch.zeros(n),
+    )
+    idx = 0
+    # Every example has another one with exactly the same features but the
+    # opposite value, so their grads cancel each other out in all-reduce
+    # (see the sketch after this function).
+    for value in (-1, 1):
+        for x in (-1.0 * value, 1.0 * value):
+            for y in (1.0 * value, -1.0 * value):
+                for z in (0, 1):
+                    training_examples.dense_features[idx, :] = torch.tensor((x, y))
+                    training_examples.sparse_features[idx] = z
+                    training_examples.values[idx] = value
+                    idx += 1
+
+    # Split the examples among NUM_TRAINERS trainers
+    assert 0 == (n % NUM_TRAINERS)
+    examples_per_trainer = int(n / NUM_TRAINERS)
+    return [
+        FeatureSet(
+            dense_features=training_examples.dense_features[
+                start : start + examples_per_trainer, :
+            ],
+            sparse_features=training_examples.sparse_features[
+                start : start + examples_per_trainer
+            ],
+            values=training_examples.values[start : start + examples_per_trainer],
+        )
+        for start in range(0, n, examples_per_trainer)
+    ]
+
+
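+# A small sketch (assumption: a single linear model stands in for the trainers)
+# of the cancellation argument above: two examples with identical features and
+# opposite values contribute equal-and-opposite gradients, so their sum, which
+# is what DDP's all-reduce effectively computes, is zero.
+def _example_gradient_cancellation() -> None:
+    lin = nn.Linear(D_DENSE, 1, bias=False)
+    x = torch.ones(1, D_DENSE)
+    for value in (1.0, -1.0):
+        (lin(x) * value).sum().backward()
+    assert torch.allclose(lin.weight.grad, torch.zeros_like(lin.weight.grad))
+
+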
+shutdown_signal = threading.Condition()
+
+
+def set_shutdown_signal():
+    global shutdown_signal
+    with shutdown_signal:
+        shutdown_signal.notify()
+
+
+class DdpUnderDistAutogradTest(RpcAgentTestFixture):
+    @property
+    def world_size(self) -> int:
+        return WORLD_SIZE
+
+    def remote_worker_name(self) -> str:
+        # The name has to be consistent with that in the 'dist_init' decorator.
+        return f"worker{REMOTE_WORKER_RANK}"
+
+    def trainer_name(self, rank):
+        # The name has to be consistent with that in the 'dist_init' decorator.
+        return f"worker{rank}"
+
+    def _remote_worker_process(self, ddp_mode):
+        gLogger.info("The remote worker is running.")
+        dist.init_process_group(
+            backend="gloo",
+            init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
+            world_size=self.world_size,
+            rank=self.rank,
+        )
+
+        if ddp_mode in (DdpMode.INSIDE, DdpMode.OUTSIDE):
+            # new_group needs to be called on ranks.
+            dist.new_group(TRAINER_RANKS)
+
+        global shutdown_signal
+        with shutdown_signal:
+            shutdown_signal.wait()
+        gLogger.info("Exiting remote worker.")
+        dist.destroy_process_group()
+
+    def _trainer_process(self, rank: int):
+        gLogger.info("Running the trainer #%s...", rank)
+        gLogger.info(
+            "Initing trainer process group by trainer #%s with ranks %s",
+            rank,
+            TRAINER_RANKS,
+        )
+        dist.init_process_group(
+            backend="gloo",
+            init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
+            world_size=self.world_size,
+            rank=self.rank,
+        )
+
+        gLogger.info("Waiting for shutdown signal on trainer #%s...", rank)
+
+        global shutdown_signal
+        with shutdown_signal:
+            shutdown_signal.wait()
+        gLogger.info("Exiting the trainer #%s...", rank)
+        dist.destroy_process_group()
+
+    def _master_process(self, ddp_mode: DdpMode, simulate_uneven_inputs: bool):
+        gLogger.info("Running the master process...")
+        dist.init_process_group(
+            backend="gloo",
+            init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
+            world_size=self.world_size,
+            rank=self.rank,
+        )
+
+        remote_em_rref = rpc.remote(
+            self.remote_worker_name(), RemoteEM, args=(NUM_EM_ROW, D_SPARSE)
+        )
+        remote_net_rref = rpc.remote(
+            self.remote_worker_name(), RemoteNet, args=(D_DENSE + D_SPARSE, D_HID)
+        )
+        gLogger.info("Created remote rrefs on master")
+        self.do_test_on_master(
+            ddp_mode, simulate_uneven_inputs, remote_em_rref, remote_net_rref
+        )
+
+    def do_test_on_master(
+        self,
+        ddp_mode: DdpMode,
+        simulate_uneven_inputs: bool,
+        remote_em_rref: rpc.RRef,
+        remote_net_rref: rpc.RRef,
+    ):
+        if simulate_uneven_inputs:
+            gLogger.info(
+                "Running DDP + RPC test with simulating uneven inputs across trainers."
+            )
+
+        trainer_rrefs = []
+        for rank in TRAINER_RANKS:
+            trainer = self.trainer_name(rank)
+            trainer_rrefs.append(
+                rpc.remote(
+                    trainer,
+                    Trainer,
+                    args=(remote_em_rref, remote_net_rref, ddp_mode, rank),
+                )
+            )
+
+        if ddp_mode in (DdpMode.INSIDE, DdpMode.OUTSIDE):
+            # new_group needs to be called on ranks.
+            dist.new_group(TRAINER_RANKS)
+
+        training_examples = get_training_examples()
+        for _ in range(3):
+            futures = []
+            num_trainers = len(trainer_rrefs)
+            for idx, trainer_rref in enumerate(trainer_rrefs):
+                # Half the trainers will deplete inputs earlier than the rest.
+                trainer_has_less_inputs = (
+                    simulate_uneven_inputs and idx < num_trainers // 2
+                )
+                futures.append(
+                    _remote_method_async(
+                        Trainer.train_batch,
+                        trainer_rref,
+                        training_examples[idx],
+                        trainer_has_less_inputs,
+                        simulate_uneven_inputs,
+                    )
+                )
+
+            for future in futures:
+                ddp_grads, non_ddp_grads = future.wait()
+                # When there are uneven inputs, it is not necessary that grads
+                # cancel each other out, since some trainers contribute 0 grad.
+                if not simulate_uneven_inputs:
+                    for grad in ddp_grads:
+                        self.assertEqual(
+                            grad,
+                            torch.zeros_like(grad),
+                            msg=f"The grad for any ddp parameter should be zeros, because "
+                            "the training examples' grads cancel each other. Received "
+                            f"gradient {grad}",
+                        )
+                for grad in non_ddp_grads:
+                    self.assertNotEqual(
+                        grad,
+                        torch.zeros_like(grad),
+                        msg="The grad for any non-ddp parameter shouldn't be zeros",
+                    )
+
+        # Destroy process groups
+        for idx, trainer_rref in enumerate(trainer_rrefs):
+            _remote_method_async(Trainer.destroy_pg, trainer_rref).wait()
+
+        # Send shutdown signals.
+        for rank in TRAINER_RANKS:
+            trainer = self.trainer_name(rank)
+            rpc.rpc_sync(trainer, set_shutdown_signal, args=())
+
+        rpc.rpc_sync(self.remote_worker_name(), set_shutdown_signal, args=())
+
+    def _do_test(self, ddp_mode, simulate_uneven_inputs=False):
+        if self.rank == MASTER_RANK:
+            self._master_process(ddp_mode, simulate_uneven_inputs)
+        elif self.rank == REMOTE_WORKER_RANK:
+            self._remote_worker_process(ddp_mode)
+        elif self.rank in TRAINER_RANKS:
+            self._trainer_process(self.rank)
+        else:
+            raise RuntimeError(f"Unknown process rank: {self.rank}")
+
+    @requires_gloo()
+    @dist_init
+    def test_backward_no_ddp(self):
+        self._do_test(DdpMode.NONE)
+
+    @requires_gloo()
+    @dist_init
+    def test_backward_ddp_outside(self):
+        self._do_test(DdpMode.OUTSIDE)
+
+    @requires_gloo()
+    @dist_init
+    def test_backward_ddp_outside_uneven_inputs(self):
+        self._do_test(DdpMode.OUTSIDE, simulate_uneven_inputs=True)
+
+    @requires_gloo()
+    @dist_init
+    def test_backward_ddp_inside(self):
+        self._do_test(DdpMode.INSIDE)
+
+
+# Common utils for both CPU and CUDA test suites
+class CommonDdpComparisonTest(RpcAgentTestFixture):
+    @property
+    def world_size(self) -> int:
+        return NUM_TRAINERS
+
+    def trainer_name(self, rank):
+        # The name has to be consistent with that in the 'dist_init' decorator.
+        return f"worker{rank}"
+
+    @staticmethod
+    def get_remote_grads(rref, context_id):
+        return dist_autograd.get_gradients(context_id)[rref.local_value().weight]
+
+
+class DdpComparisonTest(CommonDdpComparisonTest):
+    def _run_test_ddp_comparision(self, simulate_uneven_inputs=False):
+        gLogger.info("Running trainer rank: %s", self.rank)
+        # Each trainer uses a different random seed. Otherwise, they are going
+        # to have exactly the same initial model parameters, input, and
+        # therefore grads. That means the grads will be the same before and
+        # after DDP's all-reduce.
+        torch.manual_seed(self.rank)
+        dist.init_process_group(
+            backend="gloo",
+            # Postfix file_name with "pg" since file_name is also used by RPC agent
+            init_method=INIT_METHOD_TEMPLATE.format(file_name=f"{self.file_name}_pg"),
+            world_size=self.world_size,
+            rank=self.rank,
+        )
+        net = nn.Linear(2, 3)
+        ddp_net = DistributedDataParallel(net)
+
+        # Odd ranks join early if simulate_uneven_inputs.
+        num_inputs = 1
+        if simulate_uneven_inputs:
+            if self.rank % 2 == 0:
+                num_inputs += 2
+        inputs_list = [torch.rand((3, 2)) for _ in range(num_inputs)]
+
+        if simulate_uneven_inputs:
+            gLogger.info(
+                "Rank %s training with %s inputs.", self.rank, len(inputs_list)
+            )
+
+        # Use distributed autograd. The gradients will be in RPC context map.
+        grads_dict = {}
+        with ddp_net.join(simulate_uneven_inputs):
+            for i, inputs in enumerate(inputs_list):
+                with dist_autograd.context() as context_id:
+                    loss = ddp_net(inputs).norm()
+                    dist_autograd.backward(context_id, [loss])
+                    grads_dict = dist_autograd.get_gradients(context_id)
+                gLogger.info("Trainer #%s got grad dict: %s", self.rank, grads_dict)
+
+                # Use local autograd. The gradients will be in each variable's '.grad'.
+                ddp_net.zero_grad()
+                loss = ddp_net(inputs).norm()
+                loss.backward()
+
+                # The gradients should be the same
+                for param in net.parameters():
+                    self.assertTrue(
+                        param in grads_dict,
+                        msg=f"Param {param} is not in dist_auto grad dict {grads_dict} for iteration {i}",
+                    )
+                    self.assertEqual(
+                        grads_dict[param],
+                        param.grad,
+                        msg=f"The grads for param {param} are different under local "
+                        f"and dist autograd: {param.grad} \n---\n {grads_dict[param]} for iteration {i}",
+                    )
+        dist.destroy_process_group()
+
+    @requires_gloo()
+    @dist_init
+    def test_ddp_comparison(self):
+        self._run_test_ddp_comparision()
+
+    @requires_gloo()
+    @dist_init
+    def test_ddp_comparison_uneven_inputs(self):
+        # test with simulating uneven inputs in DDP
+        self._run_test_ddp_comparision(simulate_uneven_inputs=True)
+
+    @requires_gloo()
+    @dist_init
+    def test_ddp_dist_autograd_sparse_grads(self):
+        # Each trainer uses a different random seed. Otherwise, they are going
+        # to have exactly the same initial model parameters, input, and
+        # therefore grads. That means the grads will be the same before and
+        # after DDP's all-reduce.
+        torch.manual_seed(self.rank)
+        dist.init_process_group(
+            backend="gloo",
+            init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
+            world_size=self.world_size,
+            rank=self.rank,
+        )
+
+        model = nn.EmbeddingBag(10, 3, sparse=True)
+        ddp_model = DistributedDataParallel(model)
+
+        # Different inputs for each rank.
+        input = torch.LongTensor(10).random_(0, 10)
+        offsets = torch.LongTensor([0, 4])
+
+        # Run local.
+        loss = ddp_model(input, offsets).sum()
+        loss.backward()
+
+        with dist_autograd.context() as context_id:
+            loss = ddp_model(input, offsets).sum()
+            dist_autograd.backward(context_id, [loss])
+            grads_dict = dist_autograd.get_gradients(context_id)
+            self.assertEqual(1, len(grads_dict))
+            self.assertEqual(model.weight.grad, grads_dict[model.weight])
+
+    @requires_gloo()
+    @dist_init
+    def test_ddp_dist_autograd_local_vs_remote(self):
+        # Each trainer uses a different random seed. Otherwise, they are going
+        # to have exactly the same initial model parameters, input, and
+        # therefore grads. That means the grads will be the same before and
+        # after DDP's all-reduce.
+        torch.manual_seed(self.rank)
+        dist.init_process_group(
+            backend="gloo",
+            init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
+            world_size=self.world_size,
+            rank=self.rank,
+        )
+
+        # Use two different remote device input strings, with and without the
+        # default device string "cpu", respectively.
+        for remote_device in ["worker0/cpu", "worker0"]:
+            remote_layer1 = RemoteModule(
+                remote_device=remote_device, module_cls=nn.Linear, args=(10, 5, False)
+            )
+            layer1 = nn.Linear(10, 5, False)
+            # Start with the same parameters for remote and local
+            layer1.weight = remote_layer1.module_rref.to_here().weight
+
+            # Run local case.
+            layer2 = nn.Linear(5, 1)
+            inputs = torch.rand((10, 10))
+            ddp_model = DistributedDataParallel(layer2)
+            loss = ddp_model(layer1(inputs)).sum()
+            loss.backward()
+
+            # Run remote case.
+            with dist_autograd.context() as context_id:
+                loss = ddp_model(remote_layer1(inputs)).sum()
+                dist_autograd.backward(context_id, [loss])
+                grads_dict = dist_autograd.get_gradients(context_id)
+                dist.barrier()
+                self.assertEqual(layer2.weight.grad, grads_dict[layer2.weight])
+                self.assertEqual(
+                    layer1.weight.grad,
+                    rpc.rpc_sync(
+                        "worker0",
+                        CommonDdpComparisonTest.get_remote_grads,
+                        args=(remote_layer1.module_rref, context_id),
+                    ),
+                )
+
+
+class CudaDdpComparisonTest(CommonDdpComparisonTest):
+    @skip_if_lt_x_gpu(NUM_TRAINERS)
+    @requires_nccl()
+    @dist_init
+    @skip_if_rocm_multiprocess
+    def test_ddp_dist_autograd_local_vs_remote_gpu(self):
+        # Each trainer uses a different random seed. Otherwise, they are going
+        # to have exactly the same initial model parameters, input, and
+        # therefore grads. That means the grads will be the same before and
+        # after DDP's all-reduce.
+        torch.manual_seed(self.rank)
+        dist.init_process_group(
+            backend="gloo",
+            init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
+            world_size=self.world_size,
+            rank=self.rank,
+        )
+
+        remote_layer1 = RemoteModule(
+            remote_device="worker0/cpu", module_cls=nn.Linear, args=(10, 7, False)
+        )
+        layer1 = nn.Linear(10, 7, False)
+        # Start with the same parameters for remote and local
+        layer1.weight = remote_layer1.module_rref.to_here().weight
+
+        layer2 = nn.Linear(7, 5).cuda(self.rank)
+        ddp_layer2 = DistributedDataParallel(layer2, device_ids=[self.rank])
+
+        remote_layer3 = RemoteModule(
+            remote_device="worker0/cpu", module_cls=nn.Linear, args=(5, 3, False)
+        )
+        layer3 = nn.Linear(5, 3, False)
+        # Start with the same parameters for remote and local
+        layer3.weight = remote_layer3.module_rref.to_here().weight
+
+        layer4 = nn.Linear(3, 1).cuda(self.rank)
+        ddp_layer4 = DistributedDataParallel(layer4, device_ids=[self.rank])
+
+        # Run local case.
+        inputs = torch.rand((10, 10))
+        loss = ddp_layer4(
+            layer3(ddp_layer2(layer1(inputs).cuda(self.rank)).cpu()).cuda(self.rank)
+        ).sum()
+        loss.backward()
+
+        # Run remote case.
+        with dist_autograd.context() as context_id:
+            loss = ddp_layer4(
+                remote_layer3(
+                    ddp_layer2(remote_layer1(inputs).cuda(self.rank)).cpu()
+                ).cuda(self.rank)
+            ).sum()
+            dist_autograd.backward(context_id, [loss])
+            grads_dict = dist_autograd.get_gradients(context_id)
+            dist.barrier()
+            self.assertEqual(
+                layer1.weight.grad,
+                rpc.rpc_sync(
+                    "worker0",
+                    CommonDdpComparisonTest.get_remote_grads,
+                    args=(remote_layer1.module_rref, context_id),
+                ),
+            )
+            self.assertEqual(layer2.weight.grad, grads_dict[layer2.weight])
+            self.assertEqual(
+                layer3.weight.grad,
+                rpc.rpc_sync(
+                    "worker0",
+                    CommonDdpComparisonTest.get_remote_grads,
+                    args=(remote_layer3.module_rref, context_id),
+                ),
+            )
+            self.assertEqual(layer4.weight.grad, grads_dict[layer4.weight])
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/distributed_test.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/distributed_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..28b761a37d58cc4053a3a35de8a654d5c6098276
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/distributed_test.py
@@ -0,0 +1,10391 @@
+# mypy: allow-untyped-defs
+
+import copy
+import itertools
+import json
+import math
+import operator
+import os
+import random
+import re
+import sys
+import tempfile
+import time
+import unittest
+from collections import defaultdict, namedtuple, OrderedDict
+from contextlib import contextmanager, nullcontext
+from dataclasses import dataclass
+from datetime import timedelta
+from functools import reduce
+from typing import Any, Callable, NamedTuple, Union
+
+import numpy as np
+
+import torch
+import torch.cuda
+import torch.distributed as dist
+import torch.distributed.algorithms.model_averaging.averagers as averagers
+import torch.distributed.algorithms.model_averaging.hierarchical_model_averager as hierarchicalSGD
+import torch.distributed.algorithms.model_averaging.utils as model_averaging_utils
+import torch.distributed.optim.post_localSGD_optimizer as post_localSGD_optimizer
+import torch.nn as nn
+import torch.nn.functional as F
+from torch._utils_internal import (
+    TEST_MASTER_ADDR as MASTER_ADDR,
+    TEST_MASTER_PORT as MASTER_PORT,
+)
+from torch.autograd import DeviceType
+from torch.cuda.amp import autocast, GradScaler
+from torch.distributed.algorithms.ddp_comm_hooks import (
+    default_hooks as default,
+    post_localSGD_hook as post_localSGD,
+    powerSGD_hook as powerSGD,
+    quantization as quantization_hooks,
+)
+from torch.distributed.distributed_c10d import (
+    _get_default_group,
+    _get_pg_config,
+    get_world_size,
+)
+from torch.distributed.optim import _apply_optimizer_in_backward
+from torch.distributed.utils import (
+    _sync_module_states,
+    _verify_param_shape_across_processes,
+)
+from torch.nn.parallel import DistributedDataParallel
+from torch.nn.parallel.distributed import _dump_DDP_relevant_env_vars, _MixedPrecision
+from torch.profiler import ExecutionTraceObserver, ProfilerActivity
+from torch.testing._internal.common_distributed import (
+    captured_output,
+    cleanup_temp_dir,
+    DistTestCases,
+    init_multigpu_helper,
+    initialize_temp_directories,
+    MultiProcessTestCase,
+    nccl_skip_if_lt_x_gpu,
+    require_n_gpus_for_nccl_backend,
+    requires_nccl_version,
+    simple_sparse_reduce_tests,
+    skip_if_lt_x_gpu,
+    skip_if_no_gpu,
+    skip_if_odd_worldsize,
+    skip_if_rocm_multiprocess,
+    skip_if_small_worldsize,
+    TEST_SKIPS,
+    verify_ddp_error_logged,
+    with_dist_debug_levels,
+    with_nccl_blocking_wait,
+)
+from torch.testing._internal.common_utils import (
+    FILE_SCHEMA,
+    instantiate_parametrized_tests,
+    IS_FBCODE,
+    IS_MACOS,
+    IS_SANDCASTLE,
+    IS_WINDOWS,
+    skip_but_pass_in_sandcastle,
+    skip_but_pass_in_sandcastle_if,
+    skipIfRocm,
+)
+from torch.utils._python_dispatch import TorchDispatchMode
+from torch.utils.data.distributed import DistributedSampler
+
+
+try:
+    import torchvision
+
+    HAS_TORCHVISION = True
+except Exception:  # Covering both ImportError and RuntimeError
+    HAS_TORCHVISION = False
+
+if sys.platform == "win32":
+    import msvcrt
+else:
+    import fcntl
+
+
+class NetWithBuffers(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.a = nn.Linear(10, 10, bias=False)
+        self.b = nn.Linear(10, 1, bias=False)
+        self.register_buffer("buffer", torch.randn(1, 2))
+
+    def forward(self, x):
+        self.buffer.add_(1)
+        return self.b(self.a(x))
+
+
+class Foo:
+    def __init__(self, x):
+        # Can be tensor or int
+        self.x = x
+
+    def __eq__(self, other):
+        def eq(value, other):
+            if isinstance(value, torch.Tensor):
+                return torch.equal(value, other)
+            return value == other
+
+        for attr, value in self.__dict__.items():
+            other_value = other.__dict__[attr]
+            if not eq(value, other_value):
+                return False
+        return True
+
+
+f = Foo(10)
+f.bar = 1
+
+foo_cpu_tensor = Foo(torch.randn(3, 3))
+
+
+COLLECTIVES_OBJECT_TEST_LIST = [
+    {"key1": 3, "key2": 4, "key3": {"nested": True}},
+    f,
+    foo_cpu_tensor,
+    "foo",
+    [1, 2, True, "string", [4, 5, "nested"]],
+]
+
+# Allowlist of distributed backends where profiling collectives is supported.
+PROFILING_SUPPORTED_BACKENDS = [
+    dist.Backend.NCCL,
+    dist.Backend.GLOO,
+    dist.Backend.MPI,
+    dist.Backend.UCC,
+]
+
+# Allowlist of distributed backends where profiling is supported with use_cuda=True
+CUDA_PROFILING_SUPPORTED_BACKENDS = [
+    dist.Backend.GLOO,
+    dist.Backend.MPI,
+    dist.Backend.NCCL,
+    dist.Backend.UCC,
+]
+
+# Allowlist of distributed backends where profiling is supported for p2p ops
+SEND_RECV_PROFILING_SUPPORTED_BACKENDS = [
+    dist.Backend.MPI,
+    dist.Backend.GLOO,
+    dist.Backend.NCCL,
+    dist.Backend.UCC,
+]
+
+# Dummy NamedTuple data structures to test DDP support for NamedTuple types.
+EXPECTED_FIELDS = ("a", "b")
+TestNamedTupleInput_0 = namedtuple("NamedTuple", EXPECTED_FIELDS)
+
+
+class TestNamedTupleInput_1(NamedTuple):
+    a: torch.tensor
+    b: torch.tensor
+
+
+skipIfNoTorchVision = skip_but_pass_in_sandcastle_if(
+    not HAS_TORCHVISION, "no torchvision"
+)
+
+BACKEND = os.environ["BACKEND"]
+INIT_METHOD = os.getenv("INIT_METHOD", "env://")
+
+DEFAULT_TIMEOUT = 300
+CUSTOMIZED_TIMEOUT = {"test_DistributedDataParallel": 500}
+
+
+def get_profiling_event(event_name, profiler, dedup_gpu_user_annotation=False):
+    event_list = (
+        profiler.events()
+        if isinstance(profiler, torch.profiler.profile)
+        else profiler.function_events
+    )
+    return [
+        event
+        for event in event_list
+        if (
+            (event.name.endswith(event_name) or event.name.startswith(event_name))
+            and (not dedup_gpu_user_annotation or event.device_type != DeviceType.CUDA)
+        )
+    ]
+
+
+def get_profiler_nccl_meta(prof):
+    """Torch profiler includes nccl metadata in an inserted operator called "record_param_comms"
+    We will need to test metadata obtained from profiler here"""
+    tf = tempfile.NamedTemporaryFile(mode="w+t", suffix=".json", delete=False)
+    tf.close()
+    trace_file = tf.name
+
+    prof.export_chrome_trace(trace_file)
+    with open(trace_file) as f:
+        events = json.load(f)["traceEvents"]
+    print(f"Trace saved to {trace_file}")
+
+    # Comment out the removal below to keep the trace file around for debugging.
+    os.remove(trace_file)
+
+    return [e for e in events if e.get("name") == "record_param_comms"]
+
+
+# Error message substring about unfinished reductions from the prior iteration.
+ddp_prev_reduction_unfinished_str = (
+    "Expected to have finished reduction in the prior iteration"
+)
+# Error message substring when find_unused_parameters=True has not been passed
+ddp_recommend_find_unused_params_str = (
+    "passing the keyword argument `find_unused_parameters=True`"
+)
+# Error message substring when find_unused_parameters=True is enabled
+ddp_find_unused_params_enabled_str = "Since `find_unused_parameters=True` is enabled"
+# Error message substring for possibility of not all model outputs being used
+# in loss computation
+ddp_outputs_not_used_in_loss_str = (
+    "`forward` function outputs participate in calculating loss"
+)
+# Error message substring suggesting to use TORCH_DISTRIBUTED_DEBUG
+ddp_suggest_debug_mode_str = (
+    "set the environment variable TORCH_DISTRIBUTED_DEBUG to either INFO or DETAIL"
+)
+
+
+class DDPUnevenTestInput(NamedTuple):
+    name: str
+    model: nn.Module
+    inp: Union[torch.tensor, tuple]
+    sync_interval: int
+    throw_on_early_termination: bool = False
+    hook: Callable = None
+    state: Any = None
+
+
+class _FC2(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.fc = nn.Linear(10, 50, bias=True)
+        self.fc.bias.requires_grad = False
+
+    def forward(self, x):
+        x = self.fc(x)
+        return x
+
+
+class Net(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.fc1 = nn.Linear(2, 10, bias=False)
+        self.fc2 = _FC2()
+        self.fc3 = nn.Linear(50, 4, bias=False)
+        self.relu = nn.ReLU()
+        self.no_grad_param = nn.Parameter(
+            torch.tensor([2, 2]).long(), requires_grad=False
+        )
+
+    def forward(self, x):
+        x = self.relu(self.fc1(x))
+        x = self.relu(self.fc2(x))
+        x = self.fc3(x)
+        return F.softmax(x, dim=1)
+
+
+class LargeNet(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.fc1 = nn.Linear(1000, 2000, bias=False)
+        self.fc2 = nn.Linear(2000, 500, bias=False)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.fc2(x)
+        return x
+
+
+class Task(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.p = nn.Parameter(torch.ones(2, 2))
+
+    def forward(self, x):
+        return self.p + x
+
+
+class BatchNormNet(nn.Module):
+    def __init__(self, affine=True):
+        super().__init__()
+        self.fc1 = nn.Linear(2, 40, bias=False)
+        self.bn = nn.BatchNorm1d(4, affine=affine)
+        self.fc2 = nn.Linear(40, 4, bias=False)
+
+    def forward(self, x):
+        x = torch.reshape(self.fc1(x), (-1, 4, 10))
+        x = self.bn(x)
+        x = torch.reshape(x, (-1, 40))
+        x = self.fc2(x)
+        return F.softmax(x, dim=1)
+
+
+class UnusedParamTwoLinLayerNet(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.a = nn.Linear(10, 10, bias=False)
+        self.b = nn.Linear(10, 10, bias=False)
+        self.c = nn.Linear(5, 5, bias=False)
+
+    def forward(self, x):
+        a = self.a(x)
+        b = self.b(x)
+        return (a, b)
+
+
+class DictOutputModule(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.module = UnusedParamTwoLinLayerNet()
+
+    def forward(self, x):
+        predictions = self.module(x)
+        loss = (predictions[0] + predictions[1]).sum()
+        return {
+            "predictions": predictions,
+            "loss": loss,
+        }
+
+
+class TwoLinLayerNet(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.a = nn.Linear(10, 10, bias=False)
+        self.b = nn.Linear(10, 1, bias=False)
+
+    def forward(self, x):
+        a = self.a(x)
+        b = self.b(x)
+        return (a, b)
+
+
+class EmbeddingNetDifferentParams(nn.Module):
+    """
+    A module containing an embedding whose dimension or number of parameters
+    differs depending on the rank.
+    """
+
+    def __init__(self, rank, diff_num_params=False):
+        super().__init__()
+        embedding_dim = 500 if diff_num_params or rank == 0 else 50
+        self.embedding = nn.Embedding(num_embeddings=10, embedding_dim=embedding_dim)
+        self.lin = nn.Linear(embedding_dim, 1)
+        if diff_num_params:
+            self.lin2 = nn.Linear(1, 1, bias=False)
+
+    def forward(self, x):
+        x = self.embedding(x)
+        return self.lin(x)
+
+
+class ControlFlowToyModel(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.lin1 = nn.Linear(10, 10, bias=False)
+        self.lin2 = nn.Linear(10, 10, bias=False)
+
+    def forward(self, x):
+        # Whether the second layer is used depends on the input x.
+        use_second_layer = torch.equal(x, torch.ones(20, 10, device=x.device))
+        if use_second_layer:
+            return self.lin2(F.relu(self.lin1(x)))
+        else:
+            return F.relu(self.lin1(x))
+
+
+DDP_NET = Net()
+BN_NET = BatchNormNet()
+BN_NET_NO_AFFINE = BatchNormNet(affine=False)
+ONLY_SBN_NET = nn.SyncBatchNorm(2, momentum=0.99)
+
+
+def get_timeout(test_id):
+    test_name = test_id.split(".")[-1]
+    if test_name in CUSTOMIZED_TIMEOUT:
+        return CUSTOMIZED_TIMEOUT[test_name]
+    else:
+        return DEFAULT_TIMEOUT
+
+
+default_pg_timeout = 60
+
+CUSTOM_PG_TIMEOUT = {
+    # This test runs slowly and needs additional time to complete; otherwise it
+    # can be taken down by TORCH_NCCL_ASYNC_ERROR_HANDLING.
+    "test_ddp_uneven_inputs": 300,
+    # This test has a short timeout since it tests being taken down by
+    # TORCH_NCCL_ASYNC_ERROR_HANDLING which we want to happen quickly.
+    "test_ddp_model_diff_across_ranks": 5,
+    # This test has a short timeout since it tests being taken down by
+    # TORCH_NCCL_ASYNC_ERROR_HANDLING which we want to happen quickly.
+    "test_ddp_has_finalized": 5,
+}
+
+
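+# Illustrative lookup for the per-test overrides above (the same pattern is
+# used in TestDistBackend._run below); the second test name is hypothetical
+# and simply falls back to the 60-second default.
+def _example_pg_timeout_lookup() -> None:
+    assert CUSTOM_PG_TIMEOUT.get("test_ddp_uneven_inputs", default_pg_timeout) == 300
+    assert CUSTOM_PG_TIMEOUT.get("test_unlisted_name", default_pg_timeout) == 60
+
+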
+def require_backend_is_available(backends):
+    def check(backend):
+        if backend == dist.Backend.GLOO:
+            return dist.is_gloo_available()
+        if backend == dist.Backend.NCCL:
+            return dist.is_nccl_available()
+        if backend == dist.Backend.MPI:
+            return dist.is_mpi_available()
+        if backend == dist.Backend.UCC:
+            return dist.is_ucc_available()
+        if backend in DistTestCases.backend_feature["plugin"]:
+            return True
+        return False
+
+    if BACKEND not in backends:
+        return skip_but_pass_in_sandcastle(
+            f"Test requires backend {BACKEND} to be one of {backends}"
+        )
+
+    if not check(dist.Backend(BACKEND)):
+        return skip_but_pass_in_sandcastle(
+            f"Test requires backend {BACKEND} to be available"
+        )
+    return lambda func: func
+
+
+def require_world_size(world_size):
+    if int(os.environ["WORLD_SIZE"]) < world_size:
+        return skip_but_pass_in_sandcastle(
+            f"Test requires world size of {world_size:d}"
+        )
+    return lambda func: func
+
+
+@contextmanager
+def _lock():
+    TEMP_DIR = os.environ["TEMP_DIR"]
+    lockfile = os.path.join(TEMP_DIR, "lockfile")
+    with open(lockfile, "w") as lf:
+        try:
+            if sys.platform == "win32":
+                msvcrt.locking(lf.fileno(), msvcrt.LK_RLCK, 1)
+                yield
+            else:
+                fcntl.flock(lf.fileno(), fcntl.LOCK_EX)
+                yield
+        finally:
+            if sys.platform == "win32":
+                msvcrt.locking(lf.fileno(), msvcrt.LK_UNLCK, 1)
+            else:
+                fcntl.flock(lf.fileno(), fcntl.LOCK_UN)
+            lf.close()
+
+
+@contextmanager
+def _rank_temp_file():
+    if dist.get_rank() == 0:
+        fd, name = tempfile.mkstemp()
+        os.close(fd)
+    else:
+        name = None
+    object_list = [name]
+    dist.broadcast_object_list(object_list)
+    name = object_list[0]
+    try:
+        yield name
+    finally:
+        if dist.get_rank() == 0:
+            os.remove(name)
+
+
+def _build_tensor(size, value=None, dtype=torch.float, device_id=None):
+    if value is None:
+        value = size
+    if device_id is None:
+        return torch.empty(size, size, size, dtype=dtype).fill_(value)
+    else:
+        return torch.empty(size, size, size, dtype=dtype).fill_(value).cuda(device_id)
+
+
+def _build_multidim_tensor(dim, dim_size, value=None, dtype=torch.float):
+    if value is None:
+        value = dim
+    return torch.empty(size=[dim_size for _ in range(dim)], dtype=dtype).fill_(value)
+
+
+def _create_autograd_profiler():
+    return torch.autograd.profiler.profile(record_shapes=True)
+
+
+def _create_torch_profiler():
+    return torch.profiler.profile(
+        activities=[
+            torch.profiler.ProfilerActivity.CPU,
+        ],
+        record_shapes=True,
+    )
+
+
+class Barrier:
+    barrier_id = 0
+
+    @classmethod
+    def init(cls):
+        cls.barrier_id = 0
+        barrier_dir = os.path.join(os.environ["TEMP_DIR"], "barrier")
+        for f_name in os.listdir(barrier_dir):
+            os.unlink(os.path.join(barrier_dir, f_name))
+
+    @classmethod
+    def sync(cls, wait_for=None, timeout=10):
+        if wait_for is None:
+            wait_for = dist.get_world_size()
+        cls.barrier_id += 1
+        barrier_dir = os.path.join(os.environ["TEMP_DIR"], "barrier")
+        pid = str(os.getpid())
+        barrier_file = os.path.join(barrier_dir, pid)
+        with _lock():
+            with open(barrier_file, "w") as f:
+                f.write(str(cls.barrier_id))
+
+        start_time = time.time()
+        while True:
+            arrived = 0
+            with _lock():
+                for f_name in os.listdir(barrier_dir):
+                    with open(os.path.join(barrier_dir, f_name)) as f:
+                        data = f.read()
+                        if int(data) >= cls.barrier_id:
+                            arrived += 1
+            if arrived == wait_for:
+                break
+
+            if time.time() - start_time > timeout:
+                raise RuntimeError("barrier timeout")
+            time.sleep(0.1)
+
+
+class TestDistBackend(MultiProcessTestCase):
+    @classmethod
+    def setUpClass(cls):
+        os.environ["MASTER_ADDR"] = str(MASTER_ADDR)
+        # MASTER_PORT is not set, so a random free port will be used.
+        super().setUpClass()
+
+    def setUp(self):
+        super().setUp()
+        # initialize temp directories
+        initialize_temp_directories()
+        # initialize Barrier
+        Barrier.init()
+        # Skip return code checking for the following tests, as they are
+        # expected to crash a process due to TORCH_NCCL_ASYNC_ERROR_HANDLING.
+        self.skip_return_code_checks = [self.test_ddp_has_finalized.__wrapped__]
+
+    def tearDown(self):
+        cleanup_temp_dir()
+        super().tearDown()
+
+    @property
+    def init_method(self):
+        return f"{FILE_SCHEMA}{self.file_name}"
+
+    @property
+    def destroy_pg_upon_exit(self) -> bool:
+        # Overriding base test class: do not auto destroy PG upon exit.
+        return False
+
+    @classmethod
+    def _run(cls, rank, test_name, file_name, pipe, **kwargs):
+        if BACKEND == "nccl" and not torch.cuda.is_available():
+            sys.exit(TEST_SKIPS["no_cuda"].exit_code)
+        self = cls(test_name)
+        self.rank = rank
+        self.file_name = file_name
+
+        if torch.cuda.is_available() and torch.cuda.device_count() < int(
+            self.world_size
+        ):
+            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
+        try:
+            pg_timeout_seconds = CUSTOM_PG_TIMEOUT.get(test_name, default_pg_timeout)
+            timeout = timedelta(seconds=pg_timeout_seconds)
+            dist.init_process_group(
+                init_method=self.init_method,
+                backend=BACKEND,
+                world_size=int(self.world_size),
+                rank=self.rank,
+                timeout=timeout,
+            )
+        except RuntimeError as e:
+            if "recompile" in e.args[0]:
+                sys.exit(TEST_SKIPS["backend_unavailable"].exit_code)
+
+            raise
+
+        # Execute barrier prior to running test to ensure that every process
+        # has finished initialization and that the following test
+        # immediately exiting due to a skip doesn't cause flakiness.
+        self._barrier()
+
+        self.run_test(test_name, pipe)
+        self._barrier()
+        dist.destroy_process_group()
+        sys.exit(0)
+
+    # Needed since MultiProcessTestCase assumes a world_size of 4, but we
+    # run these tests under various other world sizes.
+    @property
+    def world_size(self):
+        return os.environ["WORLD_SIZE"]
+
+
+class DistributedTest:
+    class _DistTestBase:
+        def _barrier(self, *args, **kwargs):
+            Barrier.sync(*args, **kwargs)
+
+        def _init_group_test(self, **kwargs):
+            group = [1, 2]
+            group_id = dist.new_group(group, **kwargs)
+            rank = dist.get_rank()
+            if rank not in group:
+                return ([], None, rank)
+
+            return (group, group_id, rank)
+
+        def _init_full_group_test(self, **kwargs):
+            group = list(range(0, dist.get_world_size()))
+            group_id = dist.new_group(**kwargs)
+            rank = dist.get_rank()
+            return (group, group_id, rank)
+
+        def _init_global_test(self):
+            group = list(range(0, dist.get_world_size()))
+            group_id = dist.group.WORLD
+            rank = dist.get_rank()
+            return (group, group_id, rank)
+
+        def _verify_buffers_equal(self, m1, m2):
+            # verify buffers across models
+            m1_buf_dict = dict(m1.module.named_buffers())
+            for name, buf in m2.module.named_buffers():
+                self.assertEqual(buf, m1_buf_dict[name])
+
+            # Verify buffers across ranks.
+            m1_buffers = list(m1.buffers())
+            m2_buffers = list(m2.buffers())
+            for buf1, buf2 in zip(m1_buffers, m2_buffers):
+                gathered_bufs = [
+                    torch.empty_like(buf1) for _ in range(dist.get_world_size())
+                ]
+                dist.all_gather(gathered_bufs, buf1)
+                gathered_bufs_m2 = [
+                    torch.empty_like(buf2) for _ in range(dist.get_world_size())
+                ]
+                for b in gathered_bufs:
+                    self.assertEqual(b, buf1)
+                dist.all_gather(gathered_bufs_m2, buf2)
+                for b in gathered_bufs_m2:
+                    self.assertEqual(b, buf2)
+
+        def _sanity_check_profiler_nccl_meta(self, nccl_meta_events):
+            """Torch profiler includes nccl metadata in an inserted operator called "record_param_comms"
+            We test for basic fields in this profiler event that correspond to the nccl communication
+            collectives"""
+            per_coll_meta = defaultdict(list)
+            for e in nccl_meta_events:
+                args = e.get("args", {})
+                collname = args.get("Collective name", "")
+                self.assertNotEqual(collname, "")
+                self.assertNotEqual(args.get("dtype", ""), "")
+
+                per_coll_meta[collname].append(args)
+                if collname in {"wait"}:
+                    continue
+
+                self.assertEqual(args["Process Group Description"], "default_pg")
+                self.assertNotEqual(args["Process Group Ranks"], "")
+
+                self.assertGreaterEqual(args.get("In msg nelems", -1), 0)
+                self.assertGreaterEqual(args.get("Out msg nelems", -1), 0)
+                self.assertGreaterEqual(args.get("Group size", -1), 0)
+                self.assertGreaterEqual(args.get("Global rank start", -1), 0)
+                self.assertGreaterEqual(args.get("Global rank stride", -1), 0)
+
+            # print(per_coll_meta)
+            return per_coll_meta
+
+        def test_dump_DDP_relevant_env_vars(self):
+            with captured_output() as (out, _):
+                _dump_DDP_relevant_env_vars()
+                lines = out.getvalue().splitlines()
+
+            def format_line(var):
+                return f"env:{var}={os.environ[var] if var in os.environ else 'N/A'}"
+
+            # Check relevant env vars
+            vars = [
+                "MASTER_ADDR",
+                "MASTER_PORT",
+                "WORLD_SIZE",
+                "NCCL_TOPO_DUMP_FILE",  # N/A
+                "TORCH_NCCL_ASYNC_ERROR_HANDLING",
+            ]
+            for var in vars:
+                line = format_line(var)
+                self.assertIn(line, lines)
+            # Check irrelevant env vars
+            vars = [
+                "xxx",
+                "yyy",
+                "zzz",
+            ]
+            for var in vars:
+                line = format_line(var)
+                self.assertNotIn(line, lines)
+
+        # GET RANK
+        def test_get_rank(self):
+            test_dir = os.path.join(os.environ["TEMP_DIR"], "test_dir")
+            pid = str(os.getpid())
+            num_processes = dist.get_world_size()
+            with open(os.path.join(test_dir, pid), "w") as f:
+                f.write(str(dist.get_rank()))
+
+            self._barrier()
+
+            all_ranks = set()
+            for f_name in os.listdir(test_dir):
+                with open(os.path.join(test_dir, f_name)) as f:
+                    all_ranks.add(int(f.read()))
+            self.assertEqual(len(all_ranks), num_processes)
+
+            self._barrier()
+
+            if dist.get_rank() == 0:
+                for f_name in os.listdir(test_dir):
+                    os.unlink(os.path.join(test_dir, f_name))
+
+            self._barrier()
+
+        def test_get_backend(self):
+            if dist.get_world_size() > 2:
+                group = [1, 2]
+            else:
+                group = [0, 1]
+            group_id = dist.new_group(group)
+            backend_str = BACKEND.lower()
+            self.assertEqual(dist.get_backend(), backend_str)
+            if dist.get_rank() in group:
+                self.assertEqual(dist.get_backend(group_id), backend_str)
+            else:
+                with self.assertRaisesRegex(
+                    ValueError, "Invalid process group specified"
+                ):
+                    dist.get_backend(group_id)
+
+        def test_Backend_enum_class(self):
+            # test parsing
+            backend = BACKEND.lower()
+            self.assertEqual(dist.Backend(BACKEND.upper()), backend)
+            self.assertEqual(dist.Backend(BACKEND), backend)
+            with self.assertRaises(ValueError):
+                dist.Backend(None)
+            with self.assertRaises(ValueError):
+                dist.Backend(3)
+            with self.assertRaises(ValueError):
+                dist.Backend(["gloo"])
+
+        # Test destroy
+        def test_destroy_group(self):
+            if dist.get_world_size() > 2:
+                group = [1, 2]
+            else:
+                group = [0, 1]
+            group_id = dist.new_group(group)
+            self._barrier()
+            dist.destroy_process_group(group_id)
+
+        # Test get rank and size of group
+        def test_get_rank_size_group(self):
+            if dist.get_world_size() > 2:
+                group = [1, 2]
+            else:
+                group = [0, 1]
+            group_id = dist.new_group(group)
+            if dist.get_rank() in group:
+                self.assertEqual(dist.get_world_size(group_id), 2)
+                self.assertTrue(dist.get_rank(group_id) in list(range(2)))
+            else:
+                self.assertEqual(dist.get_world_size(group_id), -1)
+                self.assertEqual(dist.get_rank(group_id), -1)
+
+        # Test destroy full groups
+        def test_destroy_full_group(self):
+            _, group_id, _ = self._init_full_group_test()
+            self._barrier()
+            dist.destroy_process_group(group_id)
+
+        # Test get rank and size of full group
+        def test_get_rank_size_full_group(self):
+            _, group_id, _ = self._init_full_group_test()
+            self.assertEqual(dist.get_world_size(group_id), dist.get_world_size())
+            self.assertEqual(dist.get_rank(group_id), dist.get_rank())
+
+        def _test_barrier_timeout(self, group_id, timeout):
+            local_rank = dist.get_rank(group_id)
+
+            # Only execute barrier on rank == 0, causing it to timeout
+            if local_rank == 0:
+                expected_time = time.time() + timeout.total_seconds()
+                # In debug mode, we execute a monitored_barrier before the
+                # collective, so assert on that.
+                if dist.get_debug_level() == dist.DebugLevel.DETAIL:
+                    exception_ctx = self.assertRaisesRegex(
+                        Exception, "failed to pass monitoredBarrier"
+                    )
+                else:
+                    exception_ctx = self.assertRaisesRegex(
+                        Exception, " (Timed out|closed|timeout) "
+                    )
+                with exception_ctx:
+                    dist.barrier(group_id)
+                self.assertGreaterAlmostEqual(time.time(), expected_time, delta=0.1)
+            else:
+                pass
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "gloo", "Only gloo backend supports timeouts"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            not INIT_METHOD.startswith("file://"),
+            "Requires file:// initialization method. "
+            + "Both tcp:// and env:// rely on the TCP store for which "
+            "reinitialization has proven racy.",
+        )
+        def test_barrier_timeout_global(self):
+            dist.destroy_process_group()
+
+            # Explicitly pass world size to the barrier because we've
+            # just destroyed any state in torch.distributed.
+            self._barrier(wait_for=int(os.environ["WORLD_SIZE"]))
+
+            # Reinitialize global process group
+            timeout = timedelta(seconds=1)
+            dist.init_process_group(
+                init_method=INIT_METHOD,
+                backend=BACKEND,
+                world_size=int(os.environ["WORLD_SIZE"]),
+                rank=self.rank,
+                timeout=timeout,
+            )
+            self._test_barrier_timeout(dist.group.WORLD, timeout)
+
+        @skip_if_small_worldsize
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "gloo", "Only gloo backend supports timeouts"
+        )
+        def test_barrier_timeout_group(self):
+            timeout = timedelta(seconds=5)
+            _, group_id, _ = self._init_group_test(timeout=timeout)
+            if group_id is not None:
+                self._test_barrier_timeout(group_id, timeout)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "gloo", "Only gloo backend supports timeouts"
+        )
+        def test_barrier_timeout_full_group(self):
+            timeout = timedelta(seconds=1)
+            _, group_id, _ = self._init_full_group_test(timeout=timeout)
+            if group_id is not None:
+                self._test_barrier_timeout(group_id, timeout)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["subgroup"],
+            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
+        )
+        @require_world_size(4)
+        @skip_if_lt_x_gpu(2)
+        def test_new_subgroups(self):
+            subgroup_size = 2
+            cur_subgroup, subgroups = dist.new_subgroups(subgroup_size)
+
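+            # Every rank should land in exactly one subgroup of the requested size.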
+            world_size = dist.get_world_size()
+            self.assertEqual(cur_subgroup.size(), subgroup_size)
+            self.assertEqual(len(subgroups), world_size / subgroup_size)
+            self.assertFalse(dist._rank_not_in_group(cur_subgroup))
+
+            for subgroup in subgroups:
+                dist.destroy_process_group(subgroup)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["subgroup"],
+            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
+        )
+        @require_world_size(4)
+        @skip_if_lt_x_gpu(4)
+        def test_new_subgroups_with_group_param(self):
+            # Initialize global test environment
+            self._init_global_test()
+            # Set up GPU devices for each rank
+            init_multigpu_helper(dist.get_world_size(), BACKEND)
+            # Create two subgroups: one with ranks [0,2] and another with ranks [1,3]
+            cur_subgroup, subgroups = dist.new_subgroups_by_enumeration(
+                ranks_per_subgroup_list=[[0, 2], [1, 3]]
+            )
+
+            # Further divide the current subgroup into sub-subgroups of size 1
+            cur_sub_subgroup, sub_subgroups = dist.new_subgroups(
+                group_size=1, group=cur_subgroup
+            )
+            # Verify we have 2 sub-subgroups (one for each rank in the original subgroup)
+            self.assertEqual(len(sub_subgroups), 2)
+            # Verify the current process's sub-subgroup has size 1
+            self.assertEqual(cur_sub_subgroup.size(), 1)
+            # Verify the current process is in its assigned sub-subgroup
+            self.assertFalse(dist._rank_not_in_group(group=cur_sub_subgroup))
+
+            # Clean up by destroying all created process groups
+            for sub_subgroup in sub_subgroups:
+                dist.destroy_process_group(sub_subgroup)
+
+            for subgroup in subgroups:
+                dist.destroy_process_group(subgroup)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["subgroup"],
+            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
+        )
+        @skip_if_no_gpu
+        def test_new_subgroups_group_size_exceeds_world_size(self):
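+            # A requested subgroup size larger than the world size must be rejected.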
+            with self.assertRaisesRegex(ValueError, "must not exceed"):
+                dist.new_subgroups(100)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["subgroup"],
+            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
+        )
+        @require_world_size(4)
+        @skip_if_lt_x_gpu(4)
+        def test_new_subgroups_world_size_not_divisible_by_group_size(self):
+            with self.assertRaisesRegex(
+                ValueError,
+                re.escape("The world size (4) must be divisible by 'group_size=3'"),
+            ):
+                dist.new_subgroups(3)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["subgroup"],
+            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
+        )
+        @require_world_size(4)
+        @skip_if_lt_x_gpu(4)
+        def test_new_subgroups_by_enumeration(self):
+            _group, _group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            device_id = rank_to_GPU[rank][0]
+            cur_subgroup, subgroups = dist.new_subgroups_by_enumeration(
+                ranks_per_subgroup_list=[[0, 2], [1, 3]]
+            )
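+            # Ranks 0 and 2 form the first subgroup and ranks 1 and 3 the second;
+            # any rank outside the enumeration is not assigned a subgroup.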
+            if device_id >= 4:
+                self.assertIsNone(cur_subgroup)
+            else:
+                self.assertEqual(cur_subgroup.size(), 2)
+                self.assertEqual(len(subgroups), 2)
+                if device_id == 0 or device_id == 2:
+                    self.assertEqual(cur_subgroup, subgroups[0])
+                else:
+                    self.assertEqual(cur_subgroup, subgroups[1])
+
+            for subgroup in subgroups:
+                dist.destroy_process_group(subgroup)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["subgroup"],
+            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
+        )
+        @require_world_size(4)
+        @skip_if_lt_x_gpu(4)
+        def test_new_subgroups_by_enumeration_input_rank_exceeds_world_size(self):
+            _group, group_id, _rank = self._init_global_test()
+            init_multigpu_helper(dist.get_world_size(), BACKEND)
+            world_size = get_world_size(group_id)
+
+            with self.assertRaisesRegex(
+                ValueError,
+                "The new group's rank should be within the world_size set by init_process_group",
+            ):
+                dist.new_subgroups_by_enumeration(
+                    ranks_per_subgroup_list=[[0, 1], [world_size, 2]]
+                )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["subgroup"],
+            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
+        )
+        @skip_if_no_gpu
+        def test_new_subgroups_by_enumeration_negative_input_rank(self):
+            self._init_global_test()
+
+            with self.assertRaisesRegex(
+                ValueError,
+                "The new group's rank should be within the world_size set by init_process_group",
+            ):
+                dist.new_subgroups_by_enumeration(
+                    ranks_per_subgroup_list=[[-1, -2], [-3, -4]]
+                )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["subgroup"],
+            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
+        )
+        @require_world_size(4)
+        @skip_if_lt_x_gpu(4)
+        def test_new_subgroups_overlap_not_allowed(self):
+            with self.assertRaisesRegex(
+                ValueError, "Rank 1 has appeared in both subgroup"
+            ):
+                dist.new_subgroups_by_enumeration(
+                    ranks_per_subgroup_list=[[0], [1, 2], [1, 3]]
+                )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["subgroup"],
+            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
+        )
+        @skip_if_lt_x_gpu(2)
+        def test_average_parameters(self):
+            rank = dist.get_rank()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            device_id = rank_to_GPU[rank][0]
+
+            model = nn.Sequential(
+                nn.Conv2d(3, 3, kernel_size=3, padding=1),
+                nn.ReLU(),
+                nn.Linear(1, 5, bias=False),
+            ).cuda(device_id)
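+            # The layer shapes are arbitrary; only the parameter values are averaged.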
+            # Test global model averaging
+            for p in model.parameters():
+                p.data = torch.ones_like(p.data)
+            model_averaging_utils.average_parameters(
+                params=model.parameters(), process_group=None
+            )
+            # Every element will be the same as the input.
+            for p in model.parameters():
+                self.assertEqual(p.data, torch.ones_like(p.data))
+
+            # Test partial model averaging
+            for p in model.parameters():
+                p.data = torch.ones_like(p.data) * rank
+            group_nccl = dist.new_group(ranks=[0, 1], backend="nccl")
+            model_averaging_utils.average_parameters(
+                params=model.parameters(), process_group=group_nccl
+            )
+            if not dist._rank_not_in_group(group_nccl):
+                # Every element on device 0 or 1 should be the average of 0 and 1, i.e., 0.5.
+                for p in model.parameters():
+                    self.assertEqual(p.data, torch.ones_like(p.data) * 0.5)
+            else:
+                # Every element on device not in the subgroup should remain the same.
+                for p in model.parameters():
+                    self.assertEqual(p.data, torch.ones_like(p.data) * rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["subgroup"],
+            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
+        )
+        @skip_if_lt_x_gpu(2)
+        def test_periodic_model_averager(self):
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+            rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
+            device_id = rank_to_GPU[rank][0]
+
+            model = nn.Linear(1, 5, bias=False).cuda(device_id)
+            param = next(model.parameters())
+            tensor = torch.ones_like(param.data) * rank
+            expected_avg_tensor = (
+                torch.ones_like(param.data) * sum(range(world_size)) / world_size
+            )
+            period = 4
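+            # Averaging should only occur at steps where (step - warmup_steps)
+            # is a non-negative multiple of the period.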
+            for warmup_steps in [12, 13, 14, 15]:
+                averager = averagers.PeriodicModelAverager(
+                    period=period, warmup_steps=warmup_steps
+                )
+                for step in range(0, 20):
+                    # Reset the parameters at every step.
+                    param.data = copy.deepcopy(tensor)
+                    for params in model.parameters():
+                        # mock grad
+                        params.grad = torch.ones_like(param.data)
+                    averager.average_parameters(model.parameters())
+                    if step >= warmup_steps and (step - warmup_steps) % period == 0:
+                        self.assertEqual(param.data, expected_avg_tensor)
+                    else:
+                        # No model averaging, so the parameters are not updated.
+                        self.assertEqual(param.data, tensor)
+
+        @skip_if_lt_x_gpu(2)
+        def test_periodic_model_averager_param_group(self):
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+            rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
+            device_id = rank_to_GPU[rank][0]
+
+            model = nn.Linear(1, 5, bias=False).cuda(device_id)
+            param = next(model.parameters())
+            opt = torch.optim.SGD(model.parameters(), lr=0.1)
+
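+            # Same schedule as above, but the averager consumes optimizer
+            # param_groups instead of model.parameters().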
+            period = 4
+            for warmup_steps in [12, 13, 14, 15]:
+                averager = averagers.PeriodicModelAverager(
+                    period=period, warmup_steps=warmup_steps
+                )
+                for step in range(0, 20):
+                    # Reset the parameters at every step.
+                    for param_group in opt.param_groups:
+                        for params in param_group["params"]:
+                            # mock grad
+                            params.grad = torch.ones_like(param.data) * rank
+                            params.data = torch.ones_like(param.data) * rank
+                    averager.average_parameters(opt.param_groups)
+                    if step >= warmup_steps and (step - warmup_steps) % period == 0:
+                        for param_group in opt.param_groups:
+                            for params in param_group["params"]:
+                                if params.grad is None:
+                                    continue
+                                self.assertEqual(
+                                    param.data,
+                                    torch.ones_like(param.data)
+                                    * sum(range(world_size))
+                                    / world_size,
+                                )
+                    else:
+                        # No model averaging, so the parameters are not updated.
+                        for param_group in opt.param_groups:
+                            for params in param_group["params"]:
+                                if params.grad is None:
+                                    continue
+                                self.assertEqual(
+                                    param.data, torch.ones_like(param.data) * rank
+                                )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["subgroup"],
+            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
+        )
+        @skip_if_lt_x_gpu(2)
+        def test_1_level_hierarchical_model_averager_equivalent_to_periodic_model_averager(
+            self,
+        ):
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+            rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
+            device_id = rank_to_GPU[rank][0]
+
+            model = nn.Linear(1, 5, bias=False).cuda(device_id)
+            param = next(model.parameters())
+            tensor = torch.ones_like(param.data) * rank
+            expected_avg_tensor = (
+                torch.ones_like(param.data) * sum(range(world_size)) / world_size
+            )
+            period = 4
+            for warmup_steps in [12, 13, 14, 15]:
+                averager = hierarchicalSGD.HierarchicalModelAverager(
+                    # Run the global averaging at a period of 4,
+                    # which is equivalent to the above periodic model averaging test case.
+                    period_group_size_dict=OrderedDict([(period, world_size)]),
+                    warmup_steps=warmup_steps,
+                )
+
+                for step in range(0, 20):
+                    # Reset the parameters at every step.
+                    param.data = copy.deepcopy(tensor)
+                    for params in model.parameters():
+                        # mock grad
+                        params.grad = torch.ones_like(param.data)
+                    averager.average_parameters(model.parameters())
+                    if step >= warmup_steps and (step - warmup_steps) % period == 0:
+                        self.assertEqual(param.data, expected_avg_tensor)
+                    else:
+                        # No model averaging, so the parameters are not updated.
+                        self.assertEqual(param.data, tensor)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["subgroup"],
+            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
+        )
+        @require_world_size(4)
+        @skip_if_lt_x_gpu(4)
+        def test_3_level_hierarchical_model_averager(self):
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+            rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
+            device_id = rank_to_GPU[rank][0]
+
+            model = nn.Linear(1, 5, bias=False).cuda(device_id)
+            param = next(model.parameters())
+            tensor = torch.ones_like(param.data) * rank
+            # Set up such a hierarchical model averaging as follows:
+            # after the first 10 warmup steps,
+            # run model averaging every 2 steps within each subgroup of size 2,
+            # run model averaging every 4 steps within each subgroup of size 4,
+            # and run the global model averaging every 8 steps.
+            # If there is a conflict in model averaging at a step, only run the highest-level model averaging.
+            warmup_steps = 10
+            subgroup_size1 = 2
+            subgroup_avg_period1 = 2
+            subgroup_size2 = 4
+            subgroup_avg_period2 = 4
+            global_avg_period = 8
+            period_group_size_dict = OrderedDict(
+                [
+                    (subgroup_avg_period1, subgroup_size1),
+                    (subgroup_avg_period2, subgroup_size2),
+                    (global_avg_period, world_size),
+                ]
+            )
+            averager = hierarchicalSGD.HierarchicalModelAverager(
+                period_group_size_dict=period_group_size_dict, warmup_steps=warmup_steps
+            )
+            self.assertEqual(dist.get_pg_count(), len(period_group_size_dict))
+
+            subgroup1 = averager.period_process_group_dict[subgroup_avg_period1]
+            subgroup2 = averager.period_process_group_dict[subgroup_avg_period2]
+            real_group_ranks_res1 = _get_pg_config(subgroup1)["ranks"]
+            real_group_ranks_res2 = _get_pg_config(subgroup2)["ranks"]
+
+            expect_group_ranks_res1 = (
+                rank // subgroup_size1 * subgroup_size1
+                + np.array(list(range(subgroup_size1)))
+            ).tolist()
+            expect_group_ranks_res2 = (
+                rank // subgroup_size2 * subgroup_size2
+                + np.array(list(range(subgroup_size2)))
+            ).tolist()
+            self.assertEqual(real_group_ranks_res1, expect_group_ranks_res1)
+            self.assertEqual(real_group_ranks_res2, expect_group_ranks_res2)
+
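+            # Expected parameter values after averaging within subgroup 1,
+            # within subgroup 2, and across the whole world, respectively.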
+            expected_avg_tensor_within_subgroup1 = (
+                torch.ones_like(param.data)
+                * sum(real_group_ranks_res1)
+                / subgroup_size1
+            )
+            expected_avg_tensor_within_subgroup2 = (
+                torch.ones_like(param.data)
+                * sum(real_group_ranks_res2)
+                / subgroup_size2
+            )
+            expected_global_avg_tensor = (
+                torch.ones_like(param.data) * sum(range(world_size)) / world_size
+            )
+            for step in range(0, 25):
+                # Reset the parameters at every step.
+                param.data = copy.deepcopy(tensor)
+                for params in model.parameters():
+                    # mock grad
+                    params.grad = torch.ones_like(param.data)
+                averager.average_parameters(model.parameters())
+                if step == 16 or step == 24:
+                    # Run global model averaging when `step` can be divided by 8.
+                    self.assertEqual(param.data, expected_global_avg_tensor)
+                elif step == 12 or step == 20:
+                    # Run model averaging within subgroup when `step` can be divided by 4 but not by 8.
+                    self.assertEqual(param.data, expected_avg_tensor_within_subgroup2)
+                elif step == 10 or step == 14 or step == 18 or step == 22:
+                    # Run model averaging within subgroup when `step` can be divided by 2 but not by 4 or 8.
+                    self.assertEqual(param.data, expected_avg_tensor_within_subgroup1)
+                else:
+                    # No model averaging, so the parameters are not updated.
+                    self.assertEqual(param.data, tensor)
+
+        # Coalescing manager (sync mode)
+        @skip_if_no_gpu
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl" or IS_FBCODE or IS_SANDCASTLE,
+            "Coalescing manager currently tests with NCCL only; internal test flaky",
+        )
+        def test_coalescing_manager(self):
+            self._barrier()
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+            rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
+            device_id = rank_to_GPU[rank][0]
+            torch.cuda.set_device(device_id)
+            num_colls = 2
+            size_per_coll = 8
+            small_tensors = [
+                torch.ones(size_per_coll, device=device_id) for _ in range(num_colls)
+            ]
+
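+            # All-reducing the small tensors inside the coalescing manager should
+            # match all-reducing one big tensor holding the same values.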
+            with dist._coalescing_manager():
+                for i in range(num_colls):
+                    dist.all_reduce(small_tensors[i])
+
+            big_tensor = torch.ones(num_colls * size_per_coll, device=device_id)
+            dist.all_reduce(big_tensor)
+
+            for i in range(num_colls):
+                self.assertEqual(
+                    small_tensors[i],
+                    big_tensor[i * size_per_coll : (i + 1) * size_per_coll],
+                )
+
+            self._barrier()
+
+        # Coalescing manager (async mode)
+        @skip_if_no_gpu
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl" or IS_FBCODE or IS_SANDCASTLE,
+            "Coalescing manager currently tests with NCCL only; internal test flaky",
+        )
+        def test_coalescing_manager_async(self):
+            self._barrier()
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+            rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
+            device_id = rank_to_GPU[rank][0]
+            torch.cuda.set_device(device_id)
+            num_colls = 2
+            size_per_coll = 8
+            small_tensors = [
+                torch.ones(size_per_coll, device=device_id) for _ in range(num_colls)
+            ]
+
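+            # With async_ops=True the collectives only complete after cm.wait().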
+            with dist._coalescing_manager(async_ops=True) as cm:
+                for i in range(num_colls):
+                    dist.all_reduce(small_tensors[i])
+            cm.wait()
+
+            big_tensor = torch.ones(num_colls * size_per_coll, device=device_id)
+            dist.all_reduce(big_tensor)
+
+            for i in range(num_colls):
+                self.assertEqual(
+                    small_tensors[i],
+                    big_tensor[i * size_per_coll : (i + 1) * size_per_coll],
+                )
+
+            self._barrier()
+
+        # NCCL Batch SEND RECV
+        @skip_if_no_gpu
+        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
+        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
+        def test_batch_isend_irecv_nccl(self):
+            self._barrier()
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+            rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
+            device_id = rank_to_GPU[rank][0]
+            torch.cuda.set_device(device_id)
+            p2p_op_list = []
+            recv_tensors = [None for _ in range(world_size)]
+            expected_tensors = [None for _ in range(world_size)]
+
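+            # Exercise NCCL with blocking wait both enabled ("1") and disabled ("0").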
+            for val in ["1", "0"]:
+                os.environ["TORCH_NCCL_BLOCKING_WAIT"] = val
+                for src in range(0, world_size):
+                    send_tensor = _build_tensor(rank + 1, device_id=device_id).fill_(
+                        src
+                    )
+                    recv_tensors[src] = _build_tensor(
+                        src + 1, value=-1, device_id=device_id
+                    ).fill_(-1)
+                    expected_tensors[src] = _build_tensor(
+                        src + 1, value=-1, device_id=device_id
+                    ).fill_(rank)
+                    recv_op = dist.P2POp(dist.irecv, recv_tensors[src], src)
+                    p2p_op_list.append(recv_op)
+                    send_op = dist.P2POp(dist.isend, send_tensor, src)
+                    p2p_op_list.append(send_op)
+
+                reqs = dist.batch_isend_irecv(p2p_op_list)
+                for req in reqs:
+                    req.wait()
+
+                for src in range(0, world_size):
+                    self.assertEqual(recv_tensors[src], expected_tensors[src])
+
+            self._barrier()
+
+        @skip_if_no_gpu
+        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
+        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
+        def test_batch_isend_irecv_ring_exchange_nccl(self):
+            self._barrier()
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+            rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
+            device_id = rank_to_GPU[rank][0]
+            torch.cuda.set_device(device_id)
+
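+            # Each rank sends to its right neighbor and receives from its left neighbor.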
+            send_tensor = _build_tensor(world_size, device_id=device_id)
+            recv_tensor = _build_tensor(world_size, value=-1, device_id=device_id)
+            send_op = dist.P2POp(dist.isend, send_tensor, (rank + 1) % world_size)
+            recv_op = dist.P2POp(
+                dist.irecv, recv_tensor, (rank - 1 + world_size) % world_size
+            )
+            reqs = dist.batch_isend_irecv([send_op, recv_op])
+            for req in reqs:
+                req.wait()
+
+            self._barrier()
+
+        @skip_if_no_gpu
+        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
+        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
+        def test_batch_isend_irecv_self_nccl(self):
+            self._barrier()
+            # Ensure the process group has been fully initialized (needed by
+            # the first sub-group batch_isend_irecv call)
+            dist.barrier()
+            rank = dist.get_rank()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            device_id = rank_to_GPU[rank][0]
+            p2p_op_list = []
+
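+            # Only rank 0 participates, sending to and receiving from itself.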
+            if rank == 0:
+                send_tensor = _build_tensor(rank + 1, device_id=device_id)
+                recv_tensor = _build_tensor(rank + 1, value=-1, device_id=device_id)
+                recv_op = dist.P2POp(dist.irecv, recv_tensor, 0)
+                p2p_op_list.append(recv_op)
+                send_op = dist.P2POp(dist.isend, send_tensor, 0)
+                p2p_op_list.append(send_op)
+
+                reqs = dist.batch_isend_irecv(p2p_op_list)
+                for req in reqs:
+                    req.wait()
+
+            self._barrier()
+
+        @skip_if_no_gpu
+        @skip_if_small_worldsize
+        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
+        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
+        def test_batch_isend_irecv_no_rank_zero_nccl(self):
+            self._barrier()
+            # Ensure the process group has been fully initialized (needed by
+            # the first sub-group batch_isend_irecv call)
+            dist.barrier()
+            rank = dist.get_rank()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            device_id = rank_to_GPU[rank][0]
+            torch.cuda.set_device(device_id)
+            p2p_op_list = []
+
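+            # Only ranks 1 and 2 exchange tensors; all other ranks stay idle.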
+            if rank == 1:
+                peer = 2
+            elif rank == 2:
+                peer = 1
+
+            if rank in [1, 2]:
+                send_tensor = _build_tensor(rank + 1, device_id=device_id)
+                recv_tensor = _build_tensor(peer + 1, value=-1, device_id=device_id)
+                recv_op = dist.P2POp(dist.irecv, recv_tensor, peer)
+                p2p_op_list.append(recv_op)
+                send_op = dist.P2POp(dist.isend, send_tensor, peer)
+                p2p_op_list.append(send_op)
+
+                reqs = dist.batch_isend_irecv(p2p_op_list)
+                for req in reqs:
+                    req.wait()
+
+            self._barrier()
+
+        # GLOO Batch SEND RECV CPU
+        @skip_but_pass_in_sandcastle_if(BACKEND != "gloo", "GLOO Batch Send Recv CPU")
+        def test_batch_isend_irecv_gloo(self):
+            self._barrier()
+            rank = dist.get_rank()
+            p2p_op_list = []
+
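+            # Post an irecv and an isend for every peer, then wait on all requests.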
+            for src in range(0, dist.get_world_size()):
+                if src == rank:
+                    continue
+                send_tensor = _build_tensor(rank + 1)
+                recv_tensor = _build_tensor(src + 1, value=-1)
+                recv_op = dist.P2POp(dist.irecv, recv_tensor, src)
+                p2p_op_list.append(recv_op)
+                send_op = dist.P2POp(dist.isend, send_tensor, src)
+                p2p_op_list.append(send_op)
+
+            reqs = dist.batch_isend_irecv(p2p_op_list)
+            for req in reqs:
+                req.wait()
+
+            self._barrier()
+
+        # GLOO Batch SEND RECV CPU with provided tags
+        @skip_but_pass_in_sandcastle_if(BACKEND != "gloo", "GLOO Batch Send Recv CPU")
+        def test_batch_isend_irecv_gloo_tags(self):
+            self._barrier()
+            rank = dist.get_rank()
+            p2p_op_list = []
+
+            for src in range(0, dist.get_world_size()):
+                if src == rank:
+                    continue
+                send_tensor = _build_tensor(rank + 1)
+                recv_tensor = _build_tensor(src + 1, value=-1)
+                recv_op = dist.P2POp(dist.irecv, recv_tensor, src, tag=src)
+                p2p_op_list.append(recv_op)
+                send_op = dist.P2POp(dist.isend, send_tensor, src, tag=rank)
+                p2p_op_list.append(send_op)
+
+            reqs = dist.batch_isend_irecv(p2p_op_list)
+            for req in reqs:
+                req.wait()
+
+            self._barrier()
+
+        # NCCL Batch SEND RECV Op Error
+        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
+        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
+        def test_batch_isend_irecv_op_err(self):
+            self._barrier()
+            rank = dist.get_rank()
+            if rank == 0:
+                rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+                device_id = rank_to_GPU[rank][0]
+                with self.assertRaisesRegex(ValueError, "^Invalid ``op``"):
+                    send_tensor = _build_tensor(rank + 1, device_id=device_id)
+                    send_op = dist.P2POp(dist.broadcast, send_tensor, 1)
+                    dist.batch_isend_irecv([send_op])
+
+        # NCCL Batch SEND RECV p2p_op_list Error
+        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
+        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
+        def test_batch_isend_irecv_op_list_err(self):
+            self._barrier()
+            rank = dist.get_rank()
+            if rank == 0:
+                with self.assertRaisesRegex(ValueError, "^Invalid ``p2p_op_list``"):
+                    dist.batch_isend_irecv([1, 2])
+
+        # NCCL Batch SEND RECV Mixed Backend Error
+        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
+        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
+        def test_batch_isend_irecv_mixed_backend_err(self):
+            self._barrier()
+            rank = dist.get_rank()
+            init_multigpu_helper(dist.get_world_size(), BACKEND)
+            group_gloo = dist.new_group(ranks=[0, 1], backend="gloo")
+            group_nccl = dist.new_group(ranks=[0, 1], backend="nccl")
+            if rank == 0:
+                with self.assertRaisesRegex(
+                    ValueError, "All ops need to use the same group"
+                ):
+                    send_tensor = _build_tensor(rank + 1)
+                    send_op_gloo = dist.P2POp(dist.isend, send_tensor, 1, group_gloo)
+                    send_op_nccl = dist.P2POp(dist.isend, send_tensor, 1, group_nccl)
+                    dist.batch_isend_irecv([send_op_gloo, send_op_nccl])
+
+        # NCCL SEND RECV
+        @skip_if_no_gpu
+        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Send Recv Only")
+        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
+        def _test_send_recv_nccl(self, profiler_ctx=None):
+            # TODO: now that nccl send/recv is supported, there does not seem to
+            # be a need to have nccl send/recv be tested separately.
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+            rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
+            device_id = rank_to_GPU[rank][0]
+            torch.cuda.set_device(device_id)
+
+            tensor = _build_tensor(rank + 1, device_id=device_id)
+            ctx = profiler_ctx if profiler_ctx is not None else nullcontext()
+            with ctx as prof:
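+                # Each rank takes one turn as the sender; all other ranks receive
+                # and verify the payload.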
+                for src in range(0, world_size):
+                    if src == rank:
+                        # Send mode
+                        for dst in range(0, world_size):
+                            if dst == rank:
+                                continue
+                            dist.send(tensor, dst)
+                    else:
+                        # Recv mode
+                        expected_tensor = _build_tensor(src + 1)
+                        output_tensor = _build_tensor(
+                            src + 1, value=-1, device_id=device_id
+                        )
+                        dist.recv(output_tensor, src)
+                        self.assertEqual(output_tensor, expected_tensor)
+
+                self._barrier()
+
+            if profiler_ctx is not None:
+                backend = dist.get_backend()
+                if backend in SEND_RECV_PROFILING_SUPPORTED_BACKENDS:
+                    for event_name in [f"{backend}:send", f"{backend}:recv"]:
+                        events = get_profiling_event(
+                            event_name, prof, dedup_gpu_user_annotation=True
+                        )
+                        self.assertTrue(events)
+                        # Event order is not deterministic, so simply assert their shape
+                        # is found in the following list.
+                        expected_shapes = [
+                            [[rank + 1] * 3] for rank in range(dist.get_world_size())
+                        ]
+                        for event in events:
+                            self.assertTrue(event.input_shapes in expected_shapes)
+
+        @skip_if_no_gpu
+        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Send Recv Only")
+        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
+        def test_send_recv_nccl(self):
+            self._test_send_recv_nccl()
+
+        @skip_if_no_gpu
+        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Send Recv Only")
+        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
+        def test_send_recv_nccl_autograd_profiler(self):
+            profiler_ctx = torch.autograd.profiler.profile(record_shapes=True)
+            self._test_send_recv_nccl(profiler_ctx)
+
+        @skip_if_no_gpu
+        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Send Recv Only")
+        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
+        @skip_but_pass_in_sandcastle_if(IS_FBCODE, "Kineto in fbcode causes hang")
+        @skip_but_pass_in_sandcastle_if(
+            IS_MACOS or IS_WINDOWS,
+            "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124",
+        )
+        def test_send_recv_nccl_torch_profiler(self):
+            profiler_ctx = torch.profiler.profile(
+                activities=[
+                    torch.profiler.ProfilerActivity.CPU,
+                    torch.profiler.ProfilerActivity.CUDA,
+                ],
+                record_shapes=True,
+            )
+            self._test_send_recv_nccl(profiler_ctx)
+
+        # SEND RECV
+        def _test_send_recv(self, profiler_ctx):
+            rank = dist.get_rank()
+            send_size = rank + 1
+            tensor = _build_tensor(send_size)
+            ctx = profiler_ctx if profiler_ctx is not None else nullcontext()
+            with ctx as prof:
+                for src in range(0, dist.get_world_size()):
+                    if src == rank:
+                        # Send mode
+                        for dst in range(0, dist.get_world_size()):
+                            if dst == rank:
+                                continue
+                            dist.send(tensor, dst)
+                    else:
+                        # Recv mode
+                        recv_size = src + 1
+                        expected_tensor = _build_tensor(recv_size)
+                        output_tensor = _build_tensor(recv_size, value=-1)
+                        dist.recv(output_tensor, src)
+                        self.assertEqual(output_tensor, expected_tensor)
+
+            if profiler_ctx is not None:
+                backend = dist.get_backend()
+                if backend in SEND_RECV_PROFILING_SUPPORTED_BACKENDS:
+                    for event_name in [f"{backend}:send", f"{backend}:recv"]:
+                        events = get_profiling_event(event_name, prof)
+                        # Each rank sends/recvs from all other ranks.
+                        event_count = sum(e.count for e in events)
+                        expected_event_count = dist.get_world_size() - 1
+                        self.assertEqual(event_count, expected_event_count)
+                        # Event order is not deterministic, so simply assert their shape
+                        # is found in the following list.
+                        expected_shapes = [
+                            [[rank + 1] * 3] for rank in range(dist.get_world_size())
+                        ]
+                        for event in events:
+                            self.assertTrue(event.is_async)
+                            self.assertTrue(event.input_shapes in expected_shapes)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl send/recv tested by test_send_recv_nccl"
+        )
+        def test_send_recv(self):
+            self._test_send_recv(profiler_ctx=None)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "NCCL send/recv tested by test_send_recv_nccl"
+        )
+        def test_send_recv_autograd_profiler(self):
+            autograd_profiler_ctx = _create_autograd_profiler()
+            self._test_send_recv(profiler_ctx=autograd_profiler_ctx)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "NCCL send/recv tested by test_send_recv_nccl"
+        )
+        @skip_but_pass_in_sandcastle_if(IS_FBCODE, "Kineto in fbcode causes hang")
+        @skip_but_pass_in_sandcastle_if(
+            IS_MACOS or IS_WINDOWS,
+            "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124",
+        )
+        def test_send_recv_torch_profiler(self):
+            torch_profiler_ctx = _create_torch_profiler()
+            return self._test_send_recv(profiler_ctx=torch_profiler_ctx)
+
+        # SEND RECV ANY SOURCE
+        def _test_send_recv_any_source(self, profiler_ctx):
+            rank = dist.get_rank()
+            send_recv_size = 10
+            tensor = _build_tensor(send_recv_size, value=rank)
+            recv_ranks = []
+            irecv_ranks = []
+
+            ctx = profiler_ctx if profiler_ctx is not None else nullcontext()
+            with ctx as prof:
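+                # Receive from any source via both recv and irecv, recording the
+                # sender's rank and checking that the payload matches it.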
+                for dst in range(0, dist.get_world_size()):
+                    if dst == rank:
+                        # Recv mode
+                        for src in range(0, dist.get_world_size()):
+                            if src == rank:
+                                continue
+
+                            for recv in ["recv", "irecv"]:
+                                output_tensor = _build_tensor(send_recv_size, value=-1)
+
+                                if recv == "recv":
+                                    sender = dist.recv(output_tensor)
+                                    recv_ranks.append(sender)
+                                elif recv == "irecv":
+                                    work = dist.irecv(output_tensor)
+                                    work.wait()
+                                    sender = work._source_rank()
+                                    irecv_ranks.append(sender)
+
+                                # Assert the scalar value "sender" that should be
+                                # equal to the rank of the sender is equal to all
+                                # values in the received tensor.
+                                self.assertTrue(output_tensor.eq(sender).all())
+                    else:
+                        # Send mode
+                        dist.send(tensor, dst)  # recv
+                        dist.send(tensor, dst)  # irecv
+
+            if profiler_ctx is not None:
+                backend = dist.get_backend()
+                if backend in SEND_RECV_PROFILING_SUPPORTED_BACKENDS:
+                    for event_name in [f"{backend}:send", f"{backend}:recvAnySource"]:
+                        events = get_profiling_event(event_name, prof)
+                        # Each rank sends/recvs from other rank twice.
+                        self.assertEqual(
+                            sum(event.count for event in events),
+                            2 * (dist.get_world_size() - 1),
+                        )
+                        for event in events:
+                            self.assertTrue(event.is_async)
+                            self.assertEqual(event.input_shapes, [[send_recv_size] * 3])
+
+                # Each rank would have 2 * (world_size - 1) sends, verify that
+                # globally we receive the same amount on the other end.
+                recv_ranks_tensor = torch.cat(
+                    (torch.tensor(recv_ranks), torch.tensor(irecv_ranks)), 0
+                )
+                global_recv_ranks = [
+                    torch.empty_like(recv_ranks_tensor)
+                    for _ in range(dist.get_world_size())
+                ]
+                dist.all_gather(global_recv_ranks, recv_ranks_tensor)
+                global_recv_ranks_list = []
+                for tensor in global_recv_ranks:
+                    global_recv_ranks_list += tensor.tolist()
+
+                from itertools import groupby
+
+                global_recv_ranks_list.sort()
+                frequency = [
+                    len(list(group)) for key, group in groupby(global_recv_ranks_list)
+                ]
+                self.assertEqual(dist.get_world_size(), len(frequency))
+                self.assertEqual(
+                    [2 * (dist.get_world_size() - 1)] * dist.get_world_size(), frequency
+                )
+                self._barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["sendrecv anysource"],
+            f"{BACKEND} does not support send/recv from any source",
+        )
+        def test_send_recv_any_source(self):
+            self._test_send_recv_any_source(profiler_ctx=None)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["sendrecv anysource"],
+            f"{BACKEND} does not support send/recv from any source",
+        )
+        def test_send_recv_any_source_autograd_profiler(self):
+            autograd_profiler_ctx = _create_autograd_profiler()
+            self._test_send_recv_any_source(profiler_ctx=autograd_profiler_ctx)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["sendrecv anysource"],
+            f"{BACKEND} does not support send/recv from any source",
+        )
+        @skip_but_pass_in_sandcastle_if(IS_FBCODE, "Kineto in fbcode code causes hang")
+        @skip_but_pass_in_sandcastle_if(
+            IS_MACOS or IS_WINDOWS,
+            "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124",
+        )
+        def test_send_recv_any_source_torch_profiler(self):
+            torch_profiler_ctx = _create_torch_profiler()
+            return self._test_send_recv_any_source(profiler_ctx=torch_profiler_ctx)
+
+        # SEND RECV WITH TAG
+        def _test_send_recv_with_tag(self, profiler_ctx):
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+            send_recv_size = 10
+            tensor = _build_tensor(send_recv_size, value=rank)
+            ctx = profiler_ctx if profiler_ctx is not None else nullcontext()
+            with ctx as prof:
+                for dst in range(0, world_size):
+                    if dst == rank:
+                        # Recv mode
+                        for src in range(0, world_size):
+                            if src == rank:
+                                continue
+                            output_tensor = _build_tensor(send_recv_size, value=-1)
+                            dist.recv(output_tensor, src, tag=src)
+                            self.assertTrue(output_tensor.eq(src).all())
+                    else:
+                        # Send mode
+                        dist.send(tensor, dst, tag=rank)
+
+            if profiler_ctx is not None:
+                backend = dist.get_backend()
+                if backend in SEND_RECV_PROFILING_SUPPORTED_BACKENDS:
+                    for event_name in [f"{backend}:send", f"{backend}:recv"]:
+                        events = get_profiling_event(event_name, prof)
+                        # Each rank sends/recvs from all other ranks
+                        event_count = sum(e.count for e in events)
+                        expected_event_count = dist.get_world_size() - 1
+                        self.assertEqual(event_count, expected_event_count)
+                        for event in events:
+                            self.assertTrue(event.is_async)
+                            self.assertEqual(event.name, event_name)
+                            self.assertEqual(event.input_shapes, [[send_recv_size] * 3])
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "NCCL send/recv tested by test_send_recv_nccl"
+        )
+        def test_send_recv_with_tag(self):
+            self._test_send_recv_with_tag(profiler_ctx=None)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "NCCL send/recv tested by test_send_recv_nccl"
+        )
+        def test_send_recv_with_tag_autograd_profiler(self):
+            autograd_profiler_ctx = _create_autograd_profiler()
+            return self._test_send_recv_with_tag(profiler_ctx=autograd_profiler_ctx)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "NCCL send/recv tested by test_send_recv_nccl"
+        )
+        @skip_but_pass_in_sandcastle_if(IS_FBCODE, "Kineto in fbcode code causes hang")
+        @skip_but_pass_in_sandcastle_if(
+            IS_MACOS or IS_WINDOWS,
+            "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124",
+        )
+        def test_send_recv_with_tag_torch_profiler(self):
+            torch_profiler_ctx = _create_torch_profiler()
+            return self._test_send_recv_with_tag(profiler_ctx=torch_profiler_ctx)
+
+        # ISEND
+        def _test_isend(self, profiler_ctx):
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+            ctx = profiler_ctx if profiler_ctx is not None else nullcontext()
+            with ctx as prof:
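+                # Rank 0 posts an isend to every other rank; the rest do a blocking recv.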
+                if rank == 0:
+                    requests = [
+                        dist.isend(_build_tensor(dest, 10), dest)
+                        for dest in range(1, world_size)
+                    ]
+                    for request in requests:
+                        request.wait()
+                        self.assertTrue(request.is_completed())
+                else:
+                    tensor = _build_tensor(rank, -1)
+                    dist.recv(tensor, 0)
+                    self.assertEqual(tensor, _build_tensor(rank, 10))
+
+                self._barrier()
+
+            if profiler_ctx is not None:
+                backend = dist.get_backend()
+                if backend in SEND_RECV_PROFILING_SUPPORTED_BACKENDS:
+                    expected_event_name = (
+                        f"{backend}:send" if rank == 0 else f"{backend}:recv"
+                    )
+                    events = get_profiling_event(expected_event_name, prof)
+                    event_count = sum(e.count for e in events)
+                    expected_count = dist.get_world_size() - 1 if rank == 0 else 1
+                    self.assertEqual(expected_count, event_count)
+                    # Event ordering is not guaranteed, so simply ensure the shapes are
+                    # found in the following map.
+                    expected_shapes = {
+                        r: [[r] * 3] for r in range(1, dist.get_world_size())
+                    }
+                    for event in events:
+                        self.assertTrue(event.is_async)
+                        self.assertEqual(event.name, expected_event_name)
+                        if rank == 0:
+                            self.assertTrue(
+                                event.input_shapes in expected_shapes.values()
+                            )
+                        else:
+                            self.assertEqual(event.input_shapes, expected_shapes[rank])
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support isend"
+        )
+        def test_isend(self):
+            self._test_isend(profiler_ctx=None)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support isend"
+        )
+        def test_isend_autograd_profiler(self):
+            autograd_profiler_ctx = _create_autograd_profiler()
+            self._test_isend(profiler_ctx=autograd_profiler_ctx)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support isend"
+        )
+        @skip_but_pass_in_sandcastle_if(IS_FBCODE, "Kineto in fbcode code causes hang")
+        @skip_but_pass_in_sandcastle_if(
+            IS_MACOS or IS_WINDOWS,
+            "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124",
+        )
+        def test_isend_torch_profiler(self):
+            torch_profiler_ctx = _create_torch_profiler()
+            self._test_isend(profiler_ctx=torch_profiler_ctx)
+
+        # IRECV
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support irecv"
+        )
+        def test_irecv(self):
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+
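+            # Rank 0 posts an irecv for every other rank; each of those ranks sends once.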
+            if rank == 0:
+                expected_tensors = [
+                    _build_tensor(src, -1) for src in range(1, world_size)
+                ]
+                requests = [
+                    dist.irecv(expected_tensors[src - 1], src)
+                    for src in range(1, world_size)
+                ]
+
+                for src in range(1, world_size):
+                    requests[src - 1].wait()
+                    self.assertTrue(requests[src - 1].is_completed())
+                    self.assertEqual(expected_tensors[src - 1], _build_tensor(src, 10))
+            else:
+                tensor = _build_tensor(rank, 10)
+                dist.send(tensor, 0)
+
+            self._barrier()
+
+        # BROADCAST
+        def _test_broadcast_helper(
+            self,
+            group,
+            group_id,
+            rank,
+            cuda=False,
+            rank_to_GPU=None,
+            with_options=False,
+        ):
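+            # Broadcast tensors of several dtypes from every rank in the group and
+            # check that the receivers observe exactly the source values.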
+            for dtype, value, requires_cuda in [
+                (torch.float, -1e-10, False),
+                (torch.double, -1e-100, False),
+                (torch.half, -0.1, True),
+                (torch.int8, -2, False),
+                (torch.uint8, 129, False),
+                (torch.int, -1e5, False),
+                (torch.long, -1e15, False),
+            ]:
+                if requires_cuda and not cuda:
+                    continue
+                for src in group:
+                    expected_tensor = _build_tensor(src + 1, value, dtype)
+                    if cuda:
+                        expected_tensor = expected_tensor.cuda(rank_to_GPU[rank][0])
+                    if rank == src:
+                        if with_options:
+                            opts = dist.BroadcastOptions()
+                            opts.rootTensor = 0
+                            opts.rootRank = src
+                            self.call_dist_op(
+                                ":broadcast",
+                                True,
+                                group_id.broadcast,
+                                [expected_tensor],
+                                opts,
+                            )
+                        else:
+                            self.call_dist_op(
+                                ":broadcast",
+                                False,
+                                dist.broadcast,
+                                expected_tensor,
+                                src,
+                                group_id,
+                            )
+                    else:
+                        tensor = _build_tensor(src + 1, -1, dtype)
+                        if cuda:
+                            tensor = tensor.cuda(rank_to_GPU[rank][0])
+                        if with_options:
+                            opts = dist.BroadcastOptions()
+                            opts.rootTensor = 0
+                            opts.rootRank = src
+                            self.call_dist_op(
+                                ":broadcast", True, group_id.broadcast, [tensor], opts
+                            )
+                        else:
+                            self.call_dist_op(
+                                ":broadcast",
+                                False,
+                                dist.broadcast,
+                                tensor,
+                                src,
+                                group_id,
+                            )
+                        self.assertEqual(tensor.size(), expected_tensor.size())
+                        self.assertEqual(
+                            tensor.ne(expected_tensor).max(), torch.tensor(False)
+                        )
+
+            self._barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_broadcast(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_broadcast_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "gloo" and BACKEND != "nccl",
+            "Only Gloo and Nccl backend supports CUDA allReduce",
+        )
+        @skip_if_no_gpu
+        def test_broadcast_cuda(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            device_id = rank_to_GPU[rank][0]
+            torch.cuda.set_device(device_id)
+            self._test_broadcast_helper(group, group_id, rank, True, rank_to_GPU)
+
+        @skip_if_small_worldsize
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_broadcast_group(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_broadcast_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_broadcast_full_group(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_broadcast_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl",
+            "Only NCCL backend supports high priority stream",
+        )
+        @skip_if_no_gpu
+        def test_nccl_high_priority_stream(self):
+            group, _, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            device_id = rank_to_GPU[rank][0]
+            torch.cuda.set_device(device_id)
+
+            new_port = str(MASTER_PORT + 1)
+            os.environ["MASTER_PORT"] = new_port
+            gen_iterator = dist.rendezvous("env://", rank, dist.get_world_size())
+            store, rank, size = next(gen_iterator)
+            store = dist.PrefixStore(new_port, store)
+
+            opts = dist.ProcessGroupNCCL.Options()
+            opts.is_high_priority_stream = False
+            group_id = dist.ProcessGroupNCCL(store, rank, size, opts)
+
+            self._test_broadcast_helper(group, group_id, rank, True, rank_to_GPU, True)
+
+        # REDUCE
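+        # dist.reduce(tensor, dst, op, group) reduces `tensor` across the
+        # group and delivers the result only on rank `dst`, so the helper
+        # below checks the value solely on the destination rank. For SUM with
+        # master_value=2 and worker_value=10 the expected result is
+        # 2 + 10 * (world_size - 1).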
+        def _test_reduce_helper(
+            self,
+            group,
+            group_id,
+            rank,
+            op,
+            master_value,
+            worker_value,
+            expected_value,
+            cuda=False,
+            rank_to_GPU=None,
+        ):
+            for src in group:
+                tensor = _build_tensor(src + 1).fill_(
+                    master_value if rank == src else worker_value
+                )
+                if cuda:
+                    tensor = tensor.cuda(rank_to_GPU[rank][0])
+                self.call_dist_op(
+                    ":reduce",
+                    False,
+                    dist.reduce,
+                    tensor,
+                    src,
+                    op,
+                    group_id,
+                    tensor_shapes=[tensor.shape],
+                )
+                if rank == src:
+                    self.assertEqual(tensor, _build_tensor(src + 1, expected_value))
+
+            self._barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        def test_reduce_sum(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.SUM,
+                2,
+                10,
+                2 + (10 * (len(group) - 1)),
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA reduce"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        @skip_if_no_gpu
+        def test_reduce_sum_cuda(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            device_id = rank_to_GPU[rank][0]
+            torch.cuda.set_device(device_id)
+            self._test_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.SUM,
+                2,
+                10,
+                2 + 10 * (len(group) - 1),
+                True,
+                rank_to_GPU,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        def test_reduce_product(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.PRODUCT,
+                2,
+                10,
+                reduce(operator.mul, [10] * (len(group) - 1), 2),
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        def test_reduce_min(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_reduce_helper(
+                group, group_id, rank, dist.ReduceOp.MIN, 1010, 1, 1
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        def test_reduce_max(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_reduce_helper(
+                group, group_id, rank, dist.ReduceOp.MAX, -1, 10, 10
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        @skip_if_small_worldsize
+        def test_reduce_group_sum(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.SUM,
+                2,
+                10,
+                2 + (10 * (len(group) - 1)),
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        @skip_if_small_worldsize
+        def test_reduce_group_product(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.PRODUCT,
+                2,
+                10,
+                reduce(operator.mul, [10] * (len(group) - 1), 2),
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        @skip_if_small_worldsize
+        def test_reduce_group_min(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_reduce_helper(
+                group, group_id, rank, dist.ReduceOp.MIN, 1010, 1, 1
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        @skip_if_small_worldsize
+        def test_reduce_group_max(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_reduce_helper(
+                group, group_id, rank, dist.ReduceOp.MAX, -1, 10, 10
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        def test_reduce_full_group_sum(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.SUM,
+                2,
+                10,
+                2 + (10 * (len(group) - 1)),
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        def test_reduce_full_group_product(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.PRODUCT,
+                2,
+                10,
+                reduce(operator.mul, [10] * (len(group) - 1), 2),
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        def test_reduce_full_group_min(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_reduce_helper(
+                group, group_id, rank, dist.ReduceOp.MIN, 1010, 1, 1
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        def test_reduce_full_group_max(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_reduce_helper(
+                group, group_id, rank, dist.ReduceOp.MAX, -1, 10, 10
+            )
+
+        # REDUCE TWICE
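+        # Same reduction semantics as above, but two dist.reduce calls are
+        # issued back to back (the second via call_dist_op's
+        # secondary_op_call) so that both collectives run, and are profiled,
+        # within a single profiler context.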
+        def _test_reduce_twice_helper(
+            self,
+            group,
+            group_id,
+            rank,
+            op,
+            master_value,
+            worker_value,
+            expected_value,
+            cuda=False,
+            rank_to_GPU=None,
+        ):
+            for src in group:
+                tensors = [
+                    _build_tensor(src + 1).fill_(
+                        master_value if rank == src else worker_value
+                    )
+                    for _ in range(2)
+                ]
+                if cuda:
+                    for i in range(2):
+                        tensors[i] = tensors[i].cuda(rank_to_GPU[rank][0])
+                self.call_dist_op(
+                    ":reduce",
+                    False,
+                    dist.reduce,
+                    tensors[0],
+                    src,
+                    op,
+                    group_id,
+                    secondary_op_call=lambda: dist.reduce(
+                        tensors[1], src, op, group_id
+                    ),
+                    tensor_shapes=[tensors[0].shape],
+                )
+                if rank == src:
+                    for tensor in tensors:
+                        self.assertEqual(tensor, _build_tensor(src + 1, expected_value))
+
+            self._barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        def test_reduce_sum_twice(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_reduce_twice_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.SUM,
+                2,
+                10,
+                2 + (10 * (len(group) - 1)),
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA reduce"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        @skip_if_no_gpu
+        def test_reduce_sum_cuda_twice(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            device_id = rank_to_GPU[rank][0]
+            torch.cuda.set_device(device_id)
+            self._test_reduce_twice_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.SUM,
+                2,
+                10,
+                2 + 10 * (len(group) - 1),
+                True,
+                rank_to_GPU,
+            )
+
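+        # "reduce_scatter_v": reduce_scatter with unequal per-rank chunk
+        # sizes, expressed below as a list of torch.split views. Each rank
+        # keeps the reduced chunk whose length matches
+        # input_split_sizes[rank]; only NCCL is exercised here.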
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports reduce_scatter_v"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["reduce"],
+            f"{BACKEND} does not support reduce",
+        )
+        @skip_if_no_gpu
+        def test_reduce_scatter_v_cuda(self):
+            self._barrier()
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            device_id = rank_to_GPU[rank][0]
+
+            input_split_sizes = [src + 1 for src in group]
+            start_len = sum(input_split_sizes[:rank])
+            end_len = start_len + input_split_sizes[rank]
+            sum_len = sum(input_split_sizes)
+            master_value = 2
+            worker_value = 10
+
+            for async_val in [True, False]:
+                tensor = _build_tensor(sum_len, worker_value, device_id=device_id)
+                tensor[start_len:end_len].fill_(master_value)
+                out_tensor = (
+                    torch.empty(
+                        input_split_sizes[rank], sum_len, sum_len, dtype=torch.float
+                    )
+                    .fill_(-1)
+                    .cuda(device_id)
+                )
+
+                req = dist.reduce_scatter(
+                    out_tensor,
+                    list(torch.split(tensor, input_split_sizes)),
+                    dist.ReduceOp.SUM,
+                    group_id,
+                    async_val,
+                )
+                if async_val:
+                    req.wait()
+
+                expected_value = 2 + (10 * (len(group) - 1))
+                expected_tensor = torch.empty(
+                    input_split_sizes[rank], sum_len, sum_len, dtype=torch.float
+                )
+                expected_tensor = expected_tensor.fill_(expected_value).cuda(device_id)
+
+                self.assertEqual(out_tensor, expected_tensor)
+            self._barrier()
+
+        # Test reduce_scatter_tensor accepting single tensor as input
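+        # dist.reduce_scatter_tensor(output, input, op, group) reduces the
+        # single input tensor across ranks and leaves each rank with its own
+        # chunk. The input may be laid out either concatenated along dim 0
+        # (world_size * chunk) or stacked (world_size, chunk, ...); both
+        # layouts are checked below and must yield the same result.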
+        def _reduce_scatter_tensor_helper(
+            self, tensor_out, tensor_in, group_id, rank, cuda=True, rank_to_GPU=None
+        ):
+            if cuda:
+                tensor_in = tensor_in.cuda(rank_to_GPU[rank][0])
+                tensor_out = tensor_out.cuda(rank_to_GPU[rank][0])
+            tensor_shapes = [tensor_out.shape]
+            self.call_dist_op(
+                ":reduce_scatter_tensor",
+                False,
+                dist.reduce_scatter_tensor,
+                tensor_out,
+                tensor_in,
+                dist.ReduceOp.SUM,
+                group_id,
+                False,
+                expect_event=False,
+                tensor_shapes=tensor_shapes,
+            )
+            return tensor_out
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA reduce_scatter_tensor"
+        )
+        @skip_if_no_gpu
+        def test_reduce_scatter_tensor_cuda(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            size = 2
+            tensor_out = torch.zeros(size, dtype=torch.int64)
+
+            # Concatenated input
+            tensor_in = torch.arange(len(group) * size)
+            tensor_out = self._reduce_scatter_tensor_helper(
+                tensor_out, tensor_in, group_id, rank, True, rank_to_GPU
+            )
+            # Check result
+            expected_tensor = torch.arange(rank * size, (rank + 1) * size) * len(group)
+            self.assertEqual(tensor_out, expected_tensor)
+            self._barrier()
+
+            # Stacked input
+            tensor_in = torch.reshape(tensor_in, (len(group), size))
+            tensor_out = self._reduce_scatter_tensor_helper(
+                tensor_out, tensor_in, group_id, rank, True, rank_to_GPU
+            )
+            # Check result
+            # Should be the same as the result in concatenated case
+            self.assertEqual(tensor_out, expected_tensor)
+            self._barrier()
+
+        def call_dist_op(
+            self,
+            profiling_title_postfix,
+            is_async,
+            op,
+            *args,
+            expect_event=True,
+            secondary_op_call=None,
+            profile_cuda=False,
+            tensor_shapes=None,
+            **kwargs,
+        ):
+            op_calls = [lambda: op(*args, **kwargs)]
+            if secondary_op_call is not None:
+                op_calls.append(secondary_op_call)
+
+            autograd_profiler_ctx = torch.autograd.profiler.profile(
+                use_cuda=profile_cuda, record_shapes=True
+            )
+
+            # TODO: move this test to use torch.profiler once kineto issues are
+            # fixed internally.
+            with autograd_profiler_ctx:
+                works = [op_call() for op_call in op_calls]
+                if is_async:
+                    for work in works:
+                        work.wait()
+
+            if expect_event and dist.get_backend() in PROFILING_SUPPORTED_BACKENDS:
+                # We are only interested in the backend's implementation, not the dispatcher wrapper.
+                events = get_profiling_event(
+                    dist.get_backend() + profiling_title_postfix, autograd_profiler_ctx
+                )
+                # DETAIL debug mode can use a pg wrapper that issues more collectives
+                # under the hood
+                if dist.get_debug_level() != dist.DebugLevel.DETAIL:
+                    self.assertEqual(len(events), len(op_calls))
+                for e in events:
+                    self.assertTrue(e.is_async)
+                    self.assertEqual(e.count, 1)
+                    self.assertGreaterEqual(e.cpu_time, 0)
+                    # Verify tensor shapes if given
+                    # DETAIL debug mode can use a pg wrapper that issues more collectives
+                    # under the hood
+                    if (
+                        tensor_shapes is not None
+                        and dist.get_debug_level() != dist.DebugLevel.DETAIL
+                    ):
+                        self.assertEqual(
+                            e.input_shapes,
+                            tensor_shapes,
+                            f"event shape: {e.input_shapes} vs tensor {tensor_shapes}",
+                        )
+
+        # ALL REDUCE
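+        # dist.all_reduce(tensor, op, group) reduces `tensor` in place and,
+        # unlike dist.reduce, leaves the result on every rank. The helper
+        # below also re-runs the collective under the CUDA profiler for
+        # src == 0 on backends listed in CUDA_PROFILING_SUPPORTED_BACKENDS.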
+        def _test_all_reduce_helper(
+            self,
+            group,
+            group_id,
+            rank,
+            op,
+            master_value,
+            worker_value,
+            expected_value,
+            cuda=False,
+            rank_to_GPU=None,
+            dtype=torch.float,
+            async_op=False,
+        ):
+            for src in group:
+                curr_value = master_value if rank == src else worker_value
+
+                tensor = _build_tensor(src + 1, dtype=dtype).fill_(curr_value)
+                if cuda:
+                    tensor = tensor.cuda(rank_to_GPU[rank][0])
+                if tensor.dtype == torch.complex64:
+                    tensor_shapes = [torch.view_as_real(tensor).shape]
+                else:
+                    tensor_shapes = [tensor.shape]
+                self.call_dist_op(
+                    ":all_reduce",
+                    async_op,
+                    dist.all_reduce,
+                    tensor,
+                    op,
+                    group_id,
+                    async_op=async_op,
+                    tensor_shapes=tensor_shapes,
+                )
+                # Currently, only the Gloo backend has profiling tested with CUDA enabled.
+                # Only run the CUDA profiling test for src == 0 to speed things up, since
+                # running with a different src rank does not affect correctness.
+                if (
+                    src == 0
+                    and cuda
+                    and dist.get_backend() in CUDA_PROFILING_SUPPORTED_BACKENDS
+                ):
+                    self.call_dist_op(
+                        ":all_reduce",
+                        async_op,
+                        dist.all_reduce,
+                        tensor,
+                        op,
+                        group_id,
+                        async_op=async_op,
+                        profile_cuda=True,
+                        tensor_shapes=tensor_shapes,
+                    )
+
+            self._barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_sum(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.SUM,
+                2,
+                10,
+                2 + (10 * (len(group) - 1)),
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_sum_async(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.SUM,
+                2,
+                10,
+                2 + (10 * (len(group) - 1)),
+                async_op=True,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "gloo" and BACKEND != "nccl",
+            "Only Gloo and NCCL backends will have CUDA allReduce tested",
+        )
+        @skip_if_no_gpu
+        def test_all_reduce_sum_cuda(self):
+            torch.cuda.set_device(self.rank)
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.SUM,
+                2,
+                10,
+                2 + (10 * (len(group) - 1)),
+                True,
+                rank_to_GPU,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "gloo" and BACKEND != "nccl",
+            "Only Gloo and NCCL backends will have CUDA allReduce tested",
+        )
+        @skip_if_no_gpu
+        def test_all_reduce_sum_cuda_async(self):
+            torch.cuda.set_device(self.rank)
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.SUM,
+                2,
+                10,
+                2 + (10 * (len(group) - 1)),
+                True,
+                rank_to_GPU,
+                async_op=True,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_sum_complex(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.SUM,
+                complex(2, 3),
+                complex(10, 11),
+                complex(2, 3) + (complex(10, 11) * (len(group) - 1)),
+                dtype=torch.cfloat,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_complex_unsupported_ops(self):
+            unsupported_ops = [
+                dist.ReduceOp.MAX,
+                dist.ReduceOp.MIN,
+                dist.ReduceOp.PRODUCT,
+                dist.ReduceOp.BAND,
+                dist.ReduceOp.BOR,
+                dist.ReduceOp.BXOR,
+            ]
+            _group, group_id, _rank = self._init_global_test()
+            for unsupported_op in unsupported_ops:
+                with self.assertRaisesRegex(ValueError, "all_reduce does not support"):
+                    dist.all_reduce(
+                        _build_tensor(1, dtype=torch.cfloat), unsupported_op, group_id
+                    )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "gloo" and BACKEND != "nccl",
+            "Only Gloo and NCCL backends will have CUDA allReduce tested",
+        )
+        @skip_if_no_gpu
+        def test_all_reduce_sum_cuda_complex(self):
+            torch.cuda.set_device(self.rank)
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.SUM,
+                complex(2, 3),
+                complex(10, 11),
+                complex(2, 3) + (complex(10, 11) * (len(group) - 1)),
+                True,
+                rank_to_GPU,
+                dtype=torch.cfloat,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_product(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.PRODUCT,
+                2,
+                10,
+                reduce(operator.mul, [10] * (len(group) - 1), 2),
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_min(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_reduce_helper(
+                group, group_id, rank, dist.ReduceOp.MIN, 1010, 1, 1
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_max(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_reduce_helper(
+                group, group_id, rank, dist.ReduceOp.MAX, -1, 10, 10
+            )
+
+        @skip_if_small_worldsize
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_group_sum(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_all_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.SUM,
+                2,
+                10,
+                2 + (10 * (len(group) - 1)),
+            )
+
+        @skip_if_small_worldsize
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_group_product(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_all_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.PRODUCT,
+                2,
+                10,
+                reduce(operator.mul, [10] * (len(group) - 1), 2),
+            )
+
+        @skip_if_small_worldsize
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_group_min(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_all_reduce_helper(
+                group, group_id, rank, dist.ReduceOp.MIN, 1010, 1, 1
+            )
+
+        @skip_if_small_worldsize
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_group_max(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_all_reduce_helper(
+                group, group_id, rank, dist.ReduceOp.MAX, -1, 10, 10
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_full_group_sum(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_all_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.SUM,
+                2,
+                10,
+                2 + (10 * (len(group) - 1)),
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_full_group_product(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_all_reduce_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.PRODUCT,
+                2,
+                10,
+                reduce(operator.mul, [10] * (len(group) - 1), 2),
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_full_group_min(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_all_reduce_helper(
+                group, group_id, rank, dist.ReduceOp.MIN, 1010, 1, 1
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_full_group_max(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_all_reduce_helper(
+                group, group_id, rank, dist.ReduceOp.MAX, -1, 10, 10
+            )
+
+        # SPARSE ALL REDUCE
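+        # Sparse all_reduce is only exercised for the Gloo backend (see the
+        # skip decorators below); inputs and expected outputs come from
+        # simple_sparse_reduce_tests, and the reduced sparse tensor is
+        # compared against the expected SUM result.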
+        def _test_sparse_all_reduce_sum(self, fn):
+            _group, group_id, rank = self._init_global_test()
+
+            tests = simple_sparse_reduce_tests(
+                rank, dist.get_world_size(), num_inputs=1
+            )
+            for inputs, outputs in tests:
+                tensors = [fn(input) for input in inputs]
+                dist.all_reduce(tensors[0], dist.ReduceOp.SUM, group_id)
+                self.assertEqual(tensors[0], outputs[0])
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "gloo", "Only Gloo backend support sparse all reduce"
+        )
+        def test_sparse_all_reduce_sum(self):
+            self._test_sparse_all_reduce_sum(lambda t: t)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "gloo", "Only Gloo backend support sparse all reduce"
+        )
+        @skip_if_no_gpu
+        def test_sparse_all_reduce_sum_cuda(self):
+            self._test_sparse_all_reduce_sum(lambda t: t.clone().cuda())
+
+        # ALL REDUCE - COALESCED
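+        # dist.all_reduce_coalesced(tensors, op, group) applies the same
+        # reduction, in place and on every rank, to a whole list of tensors
+        # in a single call; the test cases below mix float and complex
+        # dtypes. The static helpers return per-op
+        # (master, worker, expected, dtype) tuples.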
+        @staticmethod
+        def _all_reduce_coalesced_sum_test_cases(group_size):
+            return (
+                [2, 3, complex(2, 3)],
+                [10, 11, complex(10, 11)],
+                [
+                    2 + 10 * (group_size - 1),
+                    3 + 11 * (group_size - 1),
+                    complex(2, 3) + complex(10, 11) * (group_size - 1),
+                ],
+                [torch.float, torch.float, torch.cfloat],
+            )
+
+        @staticmethod
+        def _all_reduce_coalesced_product_test_cases(group_size):
+            return (
+                [1, 2],
+                [3, 4],
+                [1 * 3 ** (group_size - 1), 2 * 4 ** (group_size - 1)],
+                [torch.float, torch.float],
+            )
+
+        @staticmethod
+        def _all_reduce_coalesced_min_test_cases(group_size):
+            return (
+                [1, 4],
+                [2, 3],
+                [1, 3],
+                [torch.float, torch.float],
+            )
+
+        @staticmethod
+        def _all_reduce_coalesced_max_test_cases(group_size):
+            return (
+                [1, 4],
+                [2, 3],
+                [2, 4],
+                [torch.float, torch.float],
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_reduce_coalesced_max_complex_unsupported(self):
+            _group, group_id, _rank = self._init_global_test()
+            with self.assertRaisesRegex(ValueError, "all_reduce does not support"):
+                dist.all_reduce_coalesced(
+                    [_build_tensor(1, dtype=torch.cfloat)], dist.ReduceOp.MAX, group_id
+                )
+
+        def _test_all_reduce_coalesced_helper(
+            self,
+            group,
+            group_id,
+            rank,
+            op,
+            cuda=False,
+            rank_to_GPU=None,
+        ):
+            test_case_func = {
+                dist.ReduceOp.SUM: self._all_reduce_coalesced_sum_test_cases,
+                dist.ReduceOp.PRODUCT: self._all_reduce_coalesced_product_test_cases,
+                dist.ReduceOp.MIN: self._all_reduce_coalesced_min_test_cases,
+                dist.ReduceOp.MAX: self._all_reduce_coalesced_max_test_cases,
+            }[op]
+
+            master_values, worker_values, expected_values, dtypes = test_case_func(
+                len(group)
+            )
+
+            for src in group:
+                curr_values = master_values if rank == src else worker_values
+                tensors = [
+                    _build_tensor(src + 1, val, dtype=dtype)
+                    for dtype, val in zip(dtypes, curr_values)
+                ]
+                if cuda:
+                    tensors = [t.cuda(rank_to_GPU[rank][0]) for t in tensors]
+                tensor_shapes = []
+                for tensor in tensors:
+                    if tensor.dtype == torch.complex64:
+                        tensor_shapes.append(torch.view_as_real(tensor).shape)
+                    else:
+                        tensor_shapes.append(tensor.shape)
+                self.call_dist_op(
+                    ":all_reduce",
+                    False,
+                    dist.all_reduce_coalesced,
+                    tensors,
+                    op,
+                    group_id,
+                    tensor_shapes=tensor_shapes,
+                )
+                expected_tensors = [
+                    _build_tensor(src + 1, expected_value, dtype=dtype)
+                    for dtype, expected_value in zip(dtypes, expected_values)
+                ]
+                self.assertEqual(tensors, expected_tensors)
+
+            self._barrier()
+
+        @require_backend_is_available({"gloo"})
+        def test_all_reduce_coalesced_sum(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_reduce_coalesced_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.SUM,
+                cuda=False,
+                rank_to_GPU=None,
+            )
+
+        @require_backend_is_available({"gloo"})
+        def test_all_reduce_coalesced_product(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_reduce_coalesced_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.PRODUCT,
+                cuda=False,
+                rank_to_GPU=None,
+            )
+
+        @require_backend_is_available({"gloo"})
+        def test_all_reduce_coalesced_min(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_reduce_coalesced_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.MIN,
+                cuda=False,
+                rank_to_GPU=None,
+            )
+
+        @require_backend_is_available({"gloo"})
+        def test_all_reduce_coalesced_max(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_reduce_coalesced_helper(
+                group, group_id, rank, dist.ReduceOp.MAX, cuda=False, rank_to_GPU=None
+            )
+
+        @skip_if_small_worldsize
+        @require_backend_is_available({"gloo"})
+        def test_all_reduce_coalesced_group_sum(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_all_reduce_coalesced_helper(
+                group, group_id, rank, dist.ReduceOp.SUM, cuda=False, rank_to_GPU=None
+            )
+
+        @skip_if_small_worldsize
+        @require_backend_is_available({"gloo"})
+        def test_all_reduce_coalesced_group_product(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_all_reduce_coalesced_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.PRODUCT,
+                cuda=False,
+                rank_to_GPU=None,
+            )
+
+        @skip_if_small_worldsize
+        @require_backend_is_available({"gloo"})
+        def test_all_reduce_coalesced_group_min(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_all_reduce_coalesced_helper(
+                group, group_id, rank, dist.ReduceOp.MIN, cuda=False, rank_to_GPU=None
+            )
+
+        @skip_if_small_worldsize
+        @require_backend_is_available({"gloo"})
+        def test_all_reduce_coalesced_group_max(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_all_reduce_coalesced_helper(
+                group, group_id, rank, dist.ReduceOp.MAX, cuda=False, rank_to_GPU=None
+            )
+
+        @require_backend_is_available({"gloo"})
+        def test_all_reduce_coalesced_full_group_sum(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_all_reduce_coalesced_helper(
+                group, group_id, rank, dist.ReduceOp.SUM, cuda=False, rank_to_GPU=None
+            )
+
+        @require_backend_is_available({"gloo"})
+        def test_all_reduce_coalesced_full_group_product(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_all_reduce_coalesced_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.PRODUCT,
+                cuda=False,
+                rank_to_GPU=None,
+            )
+
+        @require_backend_is_available({"gloo"})
+        def test_all_reduce_coalesced_full_group_min(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_all_reduce_coalesced_helper(
+                group,
+                group_id,
+                rank,
+                dist.ReduceOp.MIN,
+                cuda=False,
+                rank_to_GPU=None,
+            )
+
+        @require_backend_is_available({"gloo"})
+        def test_all_reduce_coalesced_full_group_max(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_all_reduce_coalesced_helper(
+                group, group_id, rank, dist.ReduceOp.MAX, cuda=False, rank_to_GPU=None
+            )
+
+        # SCATTER
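+        # dist.scatter(tensor, scatter_list=..., src=..., group=...): only the
+        # source rank provides scatter_list (one tensor per rank); every rank,
+        # including the source, receives its rank-indexed entry into `tensor`.
+        # Non-source ranks pass an empty list, as the helper below does.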
+        def _test_scatter_helper(
+            self, group, group_id, rank, cuda=False, rank_to_GPU=None, dtype=torch.float
+        ):
+            for dest in group:
+                tensor = _build_tensor(dest + 1, -1, dtype=dtype)
+                expected_tensor = _build_tensor(dest + 1, rank, dtype=dtype)
+                tensors = (
+                    [_build_tensor(dest + 1, i, dtype=dtype) for i in group]
+                    if rank == dest
+                    else []
+                )
+                if cuda:
+                    tensor = tensor.cuda(rank_to_GPU[rank][0])
+                    tensors = [t.cuda(rank_to_GPU[rank][0]) for t in tensors]
+                if dtype == torch.complex64:
+                    tensor_shapes = [torch.view_as_real(t).shape for t in tensors]
+                else:
+                    tensor_shapes = [t.shape for t in tensors]
+                self.call_dist_op(
+                    ":scatter",
+                    False,
+                    dist.scatter,
+                    tensor,
+                    src=dest,
+                    scatter_list=tensors,
+                    group=group_id,
+                    expect_event=False,
+                    tensor_shapes=tensor_shapes,
+                )
+                self.assertEqual(tensor, expected_tensor)
+
+            self._barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
+        )
+        def test_scatter_checks(self):
+            group, _group_id, rank = self._init_global_test()
+            one = torch.ones([1])
+
+            # Specify scatter_list argument only on source rank.
+            output = one.clone() * -1
+            if rank == 0:
+                scatter_list = [one.clone() * i for i in group]
+                dist.scatter(output, src=0, scatter_list=scatter_list)
+            else:
+                dist.scatter(output, src=0)
+            self.assertEqual(output, one * rank)
+
+            # Don't specify src argument.
+            output = one.clone() * -1
+            if rank == 0:
+                scatter_list = [one.clone() * i for i in group]
+                dist.scatter(output, scatter_list=scatter_list)
+            else:
+                dist.scatter(output)
+            self.assertEqual(output, one * rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
+        )
+        def test_scatter(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_scatter_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA gather"
+        )
+        @skip_if_no_gpu
+        def test_scatter_cuda(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_scatter_helper(group, group_id, rank, True, rank_to_GPU)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
+        )
+        def test_scatter_complex(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_scatter_helper(group, group_id, rank, dtype=torch.cfloat)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA gather"
+        )
+        @skip_if_no_gpu
+        def test_scatter_cuda_complex(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_scatter_helper(
+                group, group_id, rank, True, rank_to_GPU, dtype=torch.cfloat
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
+        )
+        @skip_if_small_worldsize
+        def test_scatter_group(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_scatter_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
+        )
+        def test_scatter_full_group(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_scatter_helper(group, group_id, rank)
+
+        # GATHER
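+        # dist.gather(tensor, gather_list=..., dst=..., group=...): every rank
+        # sends `tensor`, and only the destination rank provides gather_list
+        # and receives one tensor per rank; the helper below checks the
+        # gathered values on the destination only.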
+        def _test_gather_helper(
+            self, group, group_id, rank, cuda=False, rank_to_GPU=None
+        ):
+            for dest in group:
+                tensor = _build_tensor(dest + 1, rank)
+                tensors = (
+                    [_build_tensor(dest + 1, -1) for i in group] if rank == dest else []
+                )
+                if cuda:
+                    tensor = tensor.cuda(rank_to_GPU[rank][0])
+                    tensors = [t.cuda(rank_to_GPU[rank][0]) for t in tensors]
+                self.call_dist_op(
+                    ":gather",
+                    False,
+                    dist.gather,
+                    tensor,
+                    dst=dest,
+                    gather_list=tensors,
+                    group=group_id,
+                    expect_event=False,
+                    tensor_shapes=[tensors[0].shape] if len(tensors) > 0 else None,
+                )
+                if rank == dest:
+                    expected_tensors = [_build_tensor(dest + 1, i) for i in group]
+                    for t1, t2 in zip(tensors, expected_tensors):
+                        self.assertEqual(t1, t2)
+
+            self._barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
+        )
+        def test_gather_checks(self):
+            group, _group_id, rank = self._init_global_test()
+            one = torch.ones([1])
+
+            # Specify gather_list argument only on destination rank.
+            if rank == 0:
+                gather_list = [one.clone() for _ in group]
+                dist.gather(one * rank, dst=0, gather_list=gather_list)
+                for i in group:
+                    self.assertEqual(gather_list[i], one * i)
+            else:
+                dist.gather(one * rank, dst=0)
+
+            # Don't specify dst argument.
+            if rank == 0:
+                gather_list = [one.clone() for _ in group]
+                dist.gather(one * rank, gather_list=gather_list)
+                for i in group:
+                    self.assertEqual(gather_list[i], one * i)
+            else:
+                dist.gather(one * rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
+        )
+        def test_gather(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_gather_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA gather"
+        )
+        @skip_if_no_gpu
+        def test_gather_cuda(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_gather_helper(group, group_id, rank, True, rank_to_GPU)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
+        )
+        @skip_if_small_worldsize
+        def test_gather_group(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_gather_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
+        )
+        def test_gather_full_group(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_gather_helper(group, group_id, rank)
+
+        # ALL GATHER
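+        # dist.all_gather(tensor_list, tensor, group): each rank contributes
+        # `tensor` and every rank receives the full list of per-rank tensors,
+        # so the helper below verifies the gathered values on all ranks.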
+        def _test_all_gather_helper(
+            self, group, group_id, rank, cuda=False, rank_to_GPU=None, dtype=torch.float
+        ):
+            for dest in group:
+                tensor = _build_tensor(dest + 1, rank, dtype=dtype)
+                tensors = [_build_tensor(dest + 1, -1, dtype=dtype) for i in group]
+                allgather = dist.all_gather
+                if cuda:
+                    tensor = tensor.cuda(rank_to_GPU[rank][0])
+                    tensors = [t.cuda(rank_to_GPU[rank][0]) for t in tensors]
+                if tensors[0].dtype == torch.complex64:
+                    tensor_shapes = [torch.view_as_real(tensors[0]).shape]
+                else:
+                    tensor_shapes = [tensors[0].shape]
+                self.call_dist_op(
+                    ":all_gather",
+                    False,
+                    allgather,
+                    tensors,
+                    tensor,
+                    group_id,
+                    False,
+                    tensor_shapes=tensor_shapes,
+                )
+
+                expected_tensors = [
+                    _build_tensor(dest + 1, i, dtype=dtype) for i in group
+                ]
+                for t1, t2 in zip(tensors, expected_tensors):
+                    self.assertEqual(t1, t2)
+
+            self._barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_gather(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_gather_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA all gather"
+        )
+        @skip_if_no_gpu
+        def test_all_gather_cuda(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_gather_helper(group, group_id, rank, True, rank_to_GPU)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_gather_complex(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_gather_helper(group, group_id, rank, dtype=torch.cfloat)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA all gather"
+        )
+        @skip_if_no_gpu
+        def test_all_gather_cuda_complex(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_gather_helper(
+                group, group_id, rank, True, rank_to_GPU, dtype=torch.cfloat
+            )
+
+        @skip_if_small_worldsize
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_gather_group(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_all_gather_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "Nccl does not support CPU tensors"
+        )
+        def test_all_gather_full_group(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_all_gather_helper(group, group_id, rank)
+
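+        # "all_gather_v": all_gather with unequal per-rank tensor sizes; the
+        # output list is built from torch.split views of a single buffer, and
+        # each rank contributes a tensor whose leading dim is rank + 1. Only
+        # NCCL is exercised here.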
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports all_gather_v"
+        )
+        @skip_if_no_gpu
+        def test_all_gather_v_cuda(self):
+            self._barrier()
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            device_id = rank_to_GPU[rank][0]
+
+            output_split_sizes = [dst + 1 for dst in group]
+            sum_len = sum(output_split_sizes)
+            value = 2
+
+            for async_val in [True, False]:
+                tensor = (
+                    torch.empty(
+                        output_split_sizes[rank], sum_len, sum_len, dtype=torch.float
+                    )
+                    .fill_(value)
+                    .cuda(device_id)
+                )
+                out_tensor = _build_tensor(sum_len, -1, device_id=device_id)
+
+                req = dist.all_gather(
+                    list(torch.split(out_tensor, output_split_sizes)),
+                    tensor,
+                    group_id,
+                    async_val,
+                )
+                if async_val:
+                    req.wait()
+
+                expected_value = value
+                expected_tensor = _build_tensor(
+                    sum_len, expected_value, device_id=device_id
+                )
+
+                self.assertEqual(out_tensor, expected_tensor)
+            self._barrier()
+
+        # Test all_gather accepting single tensor as output
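+        # dist.all_gather_into_tensor(output, input, group) gathers the input
+        # from all ranks into one output tensor, which may be laid out either
+        # concatenated along dim 0 or stacked with a leading world_size dim;
+        # the two tests below cover both layouts.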
+        def _all_gather_into_tensor_helper(
+            self, tensor_out, tensor_in, group_id, rank, cuda=True, rank_to_GPU=None
+        ):
+            if cuda:
+                tensor_in = tensor_in.cuda(rank_to_GPU[rank][0])
+                tensor_out = tensor_out.cuda(rank_to_GPU[rank][0])
+            if tensor_out.dtype == torch.complex64:
+                tensor_shapes = [torch.view_as_real(tensor_in).shape]
+            else:
+                tensor_shapes = [tensor_in.shape]
+            self.call_dist_op(
+                ":all_gather_into_tensor",
+                False,
+                dist.all_gather_into_tensor,
+                tensor_out,
+                tensor_in,
+                group_id,
+                False,
+                expect_event=False,
+                tensor_shapes=tensor_shapes,
+            )
+            return tensor_out
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA all_gather_into_tensor"
+        )
+        @skip_if_no_gpu
+        def test_all_gather_into_cat_tensor_cuda(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            size = 2
+            tensor_in = torch.ones([size, size]) * rank
+            # Concatenated output
+            tensor_out = torch.ones([len(group) * size, size]) * (-1)
+            tensor_out = self._all_gather_into_tensor_helper(
+                tensor_out, tensor_in, group_id, rank, True, rank_to_GPU
+            )
+
+            # Check result
+            # Concatenate all blocks into a bigger tensor
+            expected_tensor = torch.cat([torch.ones([size, size]) * i for i in group])
+            self.assertEqual(tensor_out, expected_tensor)
+            self._barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA all_gather_into_tensor"
+        )
+        @skip_if_no_gpu
+        def test_all_gather_into_stack_tensor_cuda(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            size = 2
+            tensor_in = torch.ones([size, size]) * rank
+            # Stacked output
+            tensor_out = torch.ones([len(group), size, size]) * (-1)
+            tensor_out = self._all_gather_into_tensor_helper(
+                tensor_out, tensor_in, group_id, rank, True, rank_to_GPU
+            )
+
+            # Check result
+            # Stack all blocks into a bigger tensor
+            expected_tensor = torch.stack([torch.ones([size, size]) * i for i in group])
+            self.assertEqual(tensor_out, expected_tensor)
+            self._barrier()
+
+        def _run_all_gather_coalesced_and_verify(
+            self, output_tensor_lists, input_tensors, expected_tensors, group_id
+        ):
+            """
+            Helper that runs all_gather_coalesced and returns True if every
+            output tensor matches the corresponding expected tensor.
+            """
+            tensor_shapes = []
+            for input_tensor in input_tensors:
+                if input_tensor.dtype == torch.complex64:
+                    tensor_shapes.append(torch.view_as_real(input_tensor).shape)
+                else:
+                    tensor_shapes.append(input_tensor.shape)
+            self.call_dist_op(
+                ":all_gather",
+                False,
+                dist.all_gather_coalesced,
+                output_tensor_lists,
+                input_tensors,
+                group_id,
+                tensor_shapes=tensor_shapes,
+            )
+
+            for l1, l2 in zip(output_tensor_lists, expected_tensors):
+                for t1, t2 in zip(l1, l2):
+                    if not torch.equal(t1, t2):
+                        return False
+            return True
+
+        def _test_all_gather_coalesced_helper(
+            self, group, group_id, rank, dtype=torch.float
+        ):
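+            # Gathers batches of 1-3 tensors of growing dimensionality in a single
+            # coalesced call; the tensor with id t gathered from rank r is expected
+            # to be filled with r + t.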
+            # TODO: Instead we should probably go through _rank_not_in_group
+            # mechanism to disable sending tensors
+            if group_id is not None:
+                for test_case_id in range(2, 5):
+                    # Make sure we create tensors of different sizes, e.g.
+                    # [1], [2x2], [3x3x3], ... to be sent in one batch
+                    input_tensors = [
+                        _build_multidim_tensor(
+                            tensor_id, tensor_id, rank + tensor_id, dtype=dtype
+                        )
+                        for tensor_id in range(1, test_case_id)
+                    ]
+                    output_tensor_lists = [
+                        [
+                            _build_multidim_tensor(
+                                tensor_id, tensor_id, -1, dtype=dtype
+                            )
+                            for tensor_id in range(1, test_case_id)
+                        ]
+                        for _ in group
+                    ]
+                    expected_tensors = [
+                        [
+                            _build_multidim_tensor(
+                                tensor_id, tensor_id, rank_iter + tensor_id, dtype=dtype
+                            )
+                            for tensor_id in range(1, test_case_id)
+                        ]
+                        for rank_iter in group
+                    ]
+                    assert self._run_all_gather_coalesced_and_verify(
+                        output_tensor_lists, input_tensors, expected_tensors, group_id
+                    ), "output tensors do not match expected outputs"
+
+            self._barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["allgather_coalesced"],
+            f"{BACKEND} does not support all_gather_coalesced",
+        )
+        def test_all_gather_coalesced_simple(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_gather_coalesced_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["allgather_coalesced"],
+            f"{BACKEND} does not support all_gather_coalesced",
+        )
+        def test_all_gather_coalesced_complex(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_gather_coalesced_helper(
+                group, group_id, rank, dtype=torch.cfloat
+            )
+
+        @skip_if_small_worldsize
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["allgather_coalesced"],
+            f"{BACKEND} does not support all_gather_coalesced",
+        )
+        def test_all_gather_coalesced_group(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_all_gather_coalesced_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["allgather_coalesced"],
+            f"{BACKEND} does not support all_gather_coalesced",
+        )
+        def test_all_gather_coalesced_full_group(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_all_gather_coalesced_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["allgather_coalesced"],
+            f"{BACKEND} does not support all_gather_coalesced",
+        )
+        def test_all_gather_coalesced_with_empty(self):
+            group, group_id, rank = self._init_global_test()
+            input_tensors = [
+                rank * torch.ones([2, 2]),
+                torch.ones([0]),
+                (rank + 1) * torch.ones([3, 3]),
+                torch.ones([0]),
+                torch.ones([0]),
+            ]
+            output_tensors_lists = [
+                [
+                    -1 * torch.ones([2, 2]),
+                    -1 * torch.ones([0]),
+                    -1 * torch.ones([3, 3]),
+                    -1 * torch.ones([0]),
+                    -1 * torch.ones([0]),
+                ]
+                for _ in group
+            ]
+            expected_tensors = [
+                [
+                    r * torch.ones([2, 2]),
+                    torch.ones([0]),
+                    (r + 1) * torch.ones([3, 3]),
+                    torch.ones([0]),
+                    torch.ones([0]),
+                ]
+                for r in group
+            ]
+            assert self._run_all_gather_coalesced_and_verify(
+                output_tensors_lists, input_tensors, expected_tensors, group_id
+            )
+            self._barrier()
+
+        # AllToAll
+        def _test_all_to_all_single_equal_split_helper(
+            self, group, group_id, rank, cuda=False, rank_to_GPU=None, dtype=torch.float
+        ):
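+            # Equal split: each rank sends one [1, size] row to every peer, so
+            # row i of the output comes from rank i and is filled with i.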
+            if group_id is not None:
+                size = len(group)
+                in_tensor = torch.ones([size, size], dtype=dtype) * rank
+                expected_tensor = torch.cat(
+                    [torch.ones([1, size], dtype=dtype) * i for i in group]
+                )
+                out_tensor = torch.ones([size, size], dtype=dtype) * -1
+                if cuda:
+                    in_tensor = in_tensor.cuda(rank_to_GPU[rank][0])
+                    expected_tensor = expected_tensor.cuda(rank_to_GPU[rank][0])
+                    out_tensor = out_tensor.cuda(rank_to_GPU[rank][0])
+                if dtype == torch.complex64:
+                    tensor_shapes = [torch.view_as_real(in_tensor).shape]
+                else:
+                    tensor_shapes = [in_tensor.shape]
+                self.call_dist_op(
+                    ":all_to_all",
+                    False,
+                    dist.all_to_all_single,
+                    out_tensor,
+                    in_tensor,
+                    group=group_id,
+                    tensor_shapes=tensor_shapes,
+                )
+                self.assertEqual(out_tensor, expected_tensor)
+            self._barrier()
+
+        def _test_all_to_all_single_unequal_split_helper(
+            self, group, group_id, rank, cuda=False, rank_to_GPU=None, dtype=torch.float
+        ):
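+            # Unequal split: this rank sends i + 1 rows to rank i and receives
+            # rank + 1 rows from every peer, giving a [(rank + 1) * size, size]
+            # output filled block-wise with the sending rank's index.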
+            if group_id is not None:
+                size = len(group)
+                in_splits = [i + 1 for i in group]
+                out_splits = [rank + 1 for _ in group]
+                in_tensor = torch.ones([sum(in_splits), size], dtype=dtype) * rank
+                out_tensor = torch.ones([(rank + 1) * size, size], dtype=dtype)
+                expected_tensor = torch.cat(
+                    [torch.ones([rank + 1, size], dtype=dtype) * i for i in group]
+                )
+                if cuda:
+                    in_tensor = in_tensor.cuda(rank_to_GPU[rank][0])
+                    expected_tensor = expected_tensor.cuda(rank_to_GPU[rank][0])
+                    out_tensor = out_tensor.cuda(rank_to_GPU[rank][0])
+                dist.all_to_all_single(
+                    out_tensor, in_tensor, out_splits, in_splits, group=group_id
+                )
+                self.assertEqual(out_tensor, expected_tensor)
+            self._barrier()
+
+        def _test_all_to_all_helper(
+            self,
+            group,
+            group_id,
+            rank,
+            cuda=False,
+            rank_to_GPU=None,
+            dtype=torch.float,
+        ):
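+            # Tensor-list variant: in_tensors[i] is sent to rank i, so
+            # out_tensors[i] should come back filled with the value i.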
+            if group_id is not None:
+                size = len(group)
+                in_splits = [i + 1 for i in group]
+                in_tensors = [
+                    torch.ones([in_splits[i], size], dtype=dtype) * rank
+                    for i, _ in enumerate(group)
+                ]
+                out_tensors = [
+                    torch.ones([(rank + 1), size], dtype=dtype) for _ in group
+                ]
+                expected_tensors = [
+                    torch.ones([rank + 1, size], dtype=dtype) * i for i in group
+                ]
+                if cuda:
+                    in_tensors = [t.cuda(rank_to_GPU[rank][0]) for t in in_tensors]
+                    expected_tensors = [
+                        t.cuda(rank_to_GPU[rank][0]) for t in expected_tensors
+                    ]
+                    out_tensors = [t.cuda(rank_to_GPU[rank][0]) for t in out_tensors]
+                dist.all_to_all(out_tensors, in_tensors, group=group_id)
+                for t1, t2 in zip(out_tensors, expected_tensors):
+                    self.assertEqual(t1, t2)
+            self._barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
+        )
+        def test_all_to_all_single_equal_split(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_to_all_single_equal_split_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
+        )
+        @skip_if_no_gpu
+        def test_all_to_all_single_equal_split_cuda(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_to_all_single_equal_split_helper(
+                group,
+                group_id,
+                rank,
+                True,
+                rank_to_GPU,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
+        )
+        def test_all_to_all_single_equal_split_complex(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_to_all_single_equal_split_helper(
+                group, group_id, rank, dtype=torch.cfloat
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
+        )
+        @skip_if_no_gpu
+        def test_all_to_all_single_equal_split_cuda_complex(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_to_all_single_equal_split_helper(
+                group, group_id, rank, True, rank_to_GPU, dtype=torch.cfloat
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
+        )
+        def test_all_to_all_single_unequal_split(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_to_all_single_unequal_split_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
+        )
+        @skip_if_no_gpu
+        def test_all_to_all_single_unequal_split_cuda(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_to_all_single_unequal_split_helper(
+                group,
+                group_id,
+                rank,
+                True,
+                rank_to_GPU,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
+        )
+        def test_all_to_all_single_unequal_split_complex(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_to_all_single_unequal_split_helper(
+                group, group_id, rank, dtype=torch.cfloat
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
+        )
+        @skip_if_no_gpu
+        def test_all_to_all_single_unequal_split_cuda_complex(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_to_all_single_unequal_split_helper(
+                group,
+                group_id,
+                rank,
+                True,
+                rank_to_GPU,
+                dtype=torch.cfloat,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi", "Only MPI supports all_to_all"
+        )
+        def test_all_to_all(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_to_all_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only NCCL supports CUDA all_to_all"
+        )
+        @skip_if_rocm_multiprocess
+        def test_all_to_all_cuda(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_to_all_helper(group, group_id, rank, True, rank_to_GPU)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi", "Only MPI supports all_to_all"
+        )
+        def test_all_to_all_complex(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_to_all_helper(group, group_id, rank, dtype=torch.cfloat)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only NCCL supports CUDA all_to_all"
+        )
+        @skip_if_rocm_multiprocess
+        def test_all_to_all_cuda_complex(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_to_all_helper(
+                group, group_id, rank, True, rank_to_GPU, dtype=torch.cfloat
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
+        )
+        @skip_if_small_worldsize
+        def test_all_to_all_single_equal_split_group(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_all_to_all_single_equal_split_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
+        )
+        @skip_if_no_gpu
+        @skip_if_small_worldsize
+        def test_all_to_all_single_equal_split_group_cuda(self):
+            group, group_id, rank = self._init_group_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_to_all_single_equal_split_helper(
+                group,
+                group_id,
+                rank,
+                True,
+                rank_to_GPU,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
+        )
+        @skip_if_small_worldsize
+        def test_all_to_all_single_unequal_split_group(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_all_to_all_single_unequal_split_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
+        )
+        @skip_if_no_gpu
+        @skip_if_small_worldsize
+        def test_all_to_all_single_unequal_split_group_cuda(self):
+            group, group_id, rank = self._init_group_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_to_all_single_unequal_split_helper(
+                group,
+                group_id,
+                rank,
+                True,
+                rank_to_GPU,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi", "Only MPI supports all_to_all"
+        )
+        @skip_if_small_worldsize
+        def test_all_to_all_group(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_all_to_all_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
+        )
+        @skip_if_small_worldsize
+        @skip_if_rocm_multiprocess
+        def test_all_to_all_group_cuda(self):
+            group, group_id, rank = self._init_group_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_to_all_helper(group, group_id, rank, True, rank_to_GPU)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
+        )
+        def test_all_to_all_single_equal_split_full_group(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_all_to_all_single_equal_split_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
+        )
+        @skip_if_no_gpu
+        def test_all_to_all_single_equal_split_full_group_cuda(self):
+            group, group_id, rank = self._init_full_group_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_to_all_single_equal_split_helper(
+                group,
+                group_id,
+                rank,
+                True,
+                rank_to_GPU,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
+        )
+        def test_all_to_all_single_unequal_split_full_group(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_all_to_all_single_unequal_split_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
+        )
+        @skip_if_no_gpu
+        def test_all_to_all_single_unequal_split_full_group_cuda(self):
+            group, group_id, rank = self._init_full_group_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_to_all_single_unequal_split_helper(
+                group,
+                group_id,
+                rank,
+                True,
+                rank_to_GPU,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi", "Only MPI supports all_to_all"
+        )
+        def test_all_to_all_full_group(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_all_to_all_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl", "Only NCCL supports CUDA all_to_all"
+        )
+        @skip_if_rocm_multiprocess
+        def test_all_to_all_full_group_cuda(self):
+            group, group_id, rank = self._init_full_group_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_all_to_all_helper(group, group_id, rank, True, rank_to_GPU)
+
+        # BARRIER
+        def _test_barrier_helper(
+            self, group, group_id, rank, cuda=False, rank_to_GPU=None
+        ):
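+            # Each rank in turn broadcasts a target timestamp (now + WAIT_TIME),
+            # sleeps past it, and then enters the barrier; the remaining ranks
+            # assert they are released only after that timestamp, i.e. the
+            # barrier actually blocked them.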
+            WAIT_TIME = 0.3  # seconds
+
+            for dest in group:
+                expected_time = torch.DoubleTensor(1).fill_(0.0)
+                if cuda:
+                    expected_time = expected_time.cuda(rank_to_GPU[rank][0])
+                if dest == rank:
+                    expected_time.fill_(time.time() + WAIT_TIME)
+                    dist.broadcast(expected_time, dest, group_id)
+                    time.sleep(WAIT_TIME + 0.1)  # sleep a little bit longer
+                    dist.barrier(group_id)
+                else:
+                    dist.broadcast(expected_time, dest, group_id)
+                    dist.barrier(group_id)
+                    self.assertGreaterAlmostEqual(
+                        float(time.time()),
+                        float(expected_time[0]),
+                        msg=f"destination rank: {dest:d}, my rank: {rank:d}"
+                        + " (if you see this failure, please report in #14554)",
+                    )
+
+            # Use higher timeout for the instance where the test runs
+            # against a subgroup and uses a CUDA tensor for expected time.
+            # The CUDA initialization for the participating processes can
+            # take long enough for the barrier timeout to trigger on the
+            # process that doesn't participate in the group.
+            self._barrier(timeout=20)
+
+        @skip_if_no_gpu
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "mpi", "MPI doesn't supports GPU barrier"
+        )
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally"
+        )
+        def test_barrier_cuda(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_barrier_helper(group, group_id, rank, True, rank_to_GPU)
+
+        @skip_if_small_worldsize
+        @skip_if_no_gpu
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "mpi", "MPI doesn't supports GPU barrier"
+        )
+        def test_barrier_group_cuda(self):
+            group, group_id, rank = self._init_group_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_barrier_helper(group, group_id, rank, True, rank_to_GPU)
+
+        @skip_if_small_worldsize
+        @skip_if_no_gpu
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "mpi", "MPI doesn't supports GPU barrier"
+        )
+        def test_barrier_full_group_cuda(self):
+            group, group_id, rank = self._init_full_group_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            self._test_barrier_helper(group, group_id, rank, True, rank_to_GPU)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["cpu barrier"],
+            f"{BACKEND} does not support CPU barrier",
+        )
+        def test_barrier(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_barrier_helper(group, group_id, rank)
+
+        @skip_if_small_worldsize
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["cpu barrier"],
+            f"{BACKEND} does not support CPU barrier",
+        )
+        def test_barrier_group(self):
+            group, group_id, rank = self._init_group_test()
+            self._test_barrier_helper(group, group_id, rank)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND in DistTestCases.skip_collective["cpu barrier"],
+            f"{BACKEND} does not support CPU barrier",
+        )
+        def test_barrier_full_group(self):
+            group, group_id, rank = self._init_full_group_test()
+            self._test_barrier_helper(group, group_id, rank)
+
+        def _model_step(self, model):
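+            # Trivial manual update: add the gradient to the parameter in place
+            # and clear it; the same rule is applied to the baseline and the DDP
+            # model so their parameters remain directly comparable.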
+            for param in model.parameters():
+                if param.grad is not None:
+                    with torch.no_grad():
+                        param += param.grad
+                    param.grad = None
+
+        def _model_step_with_zero_grad(self, model):
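+            # Same trivial update, but clears the gradient by zeroing it in place
+            # rather than setting it to None (useful when gradients are views into
+            # DDP buckets).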
+            for param in model.parameters():
+                if param.grad is not None:
+                    with torch.no_grad():
+                        param += param.grad
+                    param.grad.requires_grad_(False)
+                    param.grad.zero_()
+
+        def _prepare_dummy_data(self, local_bs):
+            # global_bs for DDP should be divisible by WORLD_SIZE
+            world_size = int(os.environ["WORLD_SIZE"])
+            global_bs = world_size * local_bs
+            input_cpu = torch.randn(global_bs, 2)
+            target = torch.randn(global_bs, 4)
+            loss = nn.MSELoss()
+            return global_bs, input_cpu, target, loss
+
+        # END TO END TEST FOR DISTRIBUTEDDATAPARALLEL
+        def _test_DDP_helper(
+            self, model, input_var, target, loss, scale_factor=1.0, memory_format=None
+        ):
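+            # One training step: forward, scaled loss, backward; optionally asserts
+            # the output keeps the requested memory format.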
+            model.train()
+            output = model(input_var)
+            l = loss(output, target) * scale_factor
+            l.backward()
+            if memory_format is not None:
+                self.assertTrue(output.is_contiguous(memory_format=memory_format))
+
+        def _assert_equal_param(self, param_gpu, param_DDP):
+            self.assertEqual(len(param_gpu), len(param_DDP))
+            for p_gpu, p_DDP in zip(param_gpu, param_DDP):
+                self.assertEqual(p_gpu, p_DDP)
+
+        def _test_DDP_niter(
+            self,
+            model_base,
+            model_DDP,
+            input,
+            target,
+            loss,
+            local_bs,
+            rank,
+            batch_size,
+            test_save,
+            offset=None,
+            world_size=0,
+            zero_grad=False,
+            memory_format=None,
+            n_iter=5,
+        ):
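+            # Runs n_iter steps: the baseline model trains on the full global batch
+            # while the DDP model trains on its local_bs slice (loss scaled by
+            # world_size * local_bs / batch_size), then asserts both parameter sets
+            # match after every manual step.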
+            for idx in range(n_iter):
+                # single cpu/gpu training
+                self._test_DDP_helper(
+                    model_base, input, target, loss, memory_format=memory_format
+                )
+
+                if offset is None:
+                    offset = rank * local_bs
+
+                # DDP training: each rank trains on its own local_bs slice of input_cpu
+                self._test_DDP_helper(
+                    model_DDP,
+                    input[offset : offset + local_bs],
+                    target[offset : offset + local_bs],
+                    loss,
+                    world_size * local_bs / batch_size if world_size != 0 else 1,
+                    memory_format=memory_format,
+                )
+
+                # Update weights and run a second iteration to shake out errors
+                if zero_grad:
+                    self._model_step_with_zero_grad(model_base)
+                    self._model_step_with_zero_grad(model_DDP)
+                else:
+                    self._model_step(model_base)
+                    self._model_step(model_DDP)
+                self._assert_equal_param(
+                    list(model_base.parameters()), list(model_DDP.module.parameters())
+                )
+
+                # Shuffle the input so that DDP input is different
+                input = input[torch.randperm(batch_size)]
+
+                # save the model in the middle and reload
+                if test_save and idx == 2 and INIT_METHOD.startswith("file://"):
+                    with tempfile.NamedTemporaryFile() as tmp:
+                        if sys.platform == "win32":
+                            torch.save(model_DDP, tmp)
+                            tmp.seek(0)
+                            # weights_only=False as this is legacy code that saves the model
+                            model_DDP = torch.load(tmp, weights_only=False)
+                        else:
+                            torch.save(model_DDP, tmp.name)
+                            # weights_only=False as this is legacy code that saves the model
+                            model_DDP = torch.load(tmp.name, weights_only=False)
+
+            with tempfile.TemporaryFile() as tmp_file:
+                torch.save(model_DDP, tmp_file)
+                tmp_file.seek(0)
+                # weights_only=False as this is legacy code that saves the model
+                saved_model = torch.load(tmp_file, weights_only=False)
+            for k in model_DDP.state_dict():
+                self.assertEqual(model_DDP.state_dict()[k], saved_model.state_dict()[k])
+
+        def _test_DistributedDataParallel(
+            self,
+            gpu_subset,
+            rank,
+            output_device=None,
+            gradient_as_bucket_view=False,
+            static_graph=False,
+            set_static_graph_twice=False,
+        ):
+            # Run a simple end to end DDP model, use result of single node model
+            # as baseline
+
+            # cpu training setup
+            model = DDP_NET
+
+            # single gpu training setup
+            model_gpu = copy.deepcopy(model)
+            model_gpu.cuda(gpu_subset[0])
+
+            # DDP training setup
+            model_DDP = copy.deepcopy(model)
+            model_DDP.cuda(gpu_subset[0])
+            model_DDP = nn.parallel.DistributedDataParallel(
+                model_DDP,
+                device_ids=gpu_subset,
+                gradient_as_bucket_view=gradient_as_bucket_view,
+                static_graph=static_graph,
+            )
+
+            if set_static_graph_twice:
+                model_DDP._set_static_graph()
+
+            # test serializable/unserializable
+            with tempfile.NamedTemporaryFile() as tmp:
+                if sys.platform == "win32":
+                    torch.save(model_DDP, tmp)
+                    tmp.seek(0)
+                    # weights_only=False as this is legacy code that saves the model
+                    model_DDP = torch.load(tmp, weights_only=False)
+                else:
+                    torch.save(model_DDP, tmp.name)
+                    # weights_only=False as this is legacy code that saves the model
+                    model_DDP = torch.load(tmp.name, weights_only=False)
+
+            # dummy data initialization
+            local_bs = len(gpu_subset)
+            global_bs, input_cpu, target, loss = self._prepare_dummy_data(local_bs)
+
+            # check two model parameters over 5 iterations
+            self._test_DDP_niter(
+                model_gpu,
+                model_DDP,
+                input_cpu.cuda(gpu_subset[0]),
+                target.cuda(gpu_subset[0]),
+                loss,
+                local_bs,
+                rank,
+                global_bs,
+                True,
+            )
+            self._barrier()
+
+        def _test_DistributedDataParallelCPU(self, gradient_as_bucket_view=False):
+            # Run a simple end to end DDP-CPU model, use result of single node
+            # model as baseline
+            _group, _group_id, rank = self._init_global_test()
+
+            # cpu training setup
+            model_base = DDP_NET
+
+            # DDP-CPU training setup
+            model_DDP = copy.deepcopy(model_base)
+            model_DDP = nn.parallel.DistributedDataParallel(
+                model_DDP, gradient_as_bucket_view=gradient_as_bucket_view
+            )
+
+            # dummy data initialization
+            local_bs = 2
+            global_bs, input_cpu, target, loss = self._prepare_dummy_data(local_bs)
+
+            # check two model parameters over 5 iterations
+            self._test_DDP_niter(
+                model_base,
+                model_DDP,
+                input_cpu,
+                target,
+                loss,
+                local_bs,
+                rank,
+                global_bs,
+                False,
+                zero_grad=True,
+            )
+            self._barrier()
+
+            return model_DDP
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "nccl does not support DDP on CPU models"
+        )
+        def test_DistributedDataParallelCPU(self):
+            self._test_DistributedDataParallelCPU()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "nccl does not support DDP on CPU models"
+        )
+        def test_DistributedDataParallelCPU_grad_is_view(self):
+            self._test_DistributedDataParallelCPU(gradient_as_bucket_view=True)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_DistributedDataParallel_requires_grad(self):
+            # A module with no parameters that require gradients shouldn't be accepted
+            self.assertRaises(
+                RuntimeError, lambda: nn.parallel.DistributedDataParallel(nn.Module())
+            )
+            self._barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
+        def test_ddp_zero_output_features(self):
+            class ToyModel(nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.net1 = nn.Linear(10, 10)
+                    self.relu = nn.ReLU()
+                    self.net2 = nn.Linear(10, 0)
+
+            model = ToyModel().to(self.rank)
+            nn.parallel.DistributedDataParallel(model, device_ids=[self.rank])
+
+        @skip_but_pass_in_sandcastle_if(BACKEND == "nccl", "Gloo-only test")
+        def test_ddp_create_graph(self):
+            class Model(nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.p = nn.Parameter(torch.tensor(1.0))
+
+                def forward(self):
+                    return self.p.pow(2)
+
+            model = Model()
+            ddp_model = torch.nn.parallel.DistributedDataParallel(model)
+            for _ in range(6):
+                # Verify DDP doesn't throw when run with create_graph=True.
+                # Although we do warn about potential issues, please see
+                # https://github.com/pytorch/pytorch/issues/63929 for details.
+                ddp_model().backward(create_graph=True)
+                # grad tensors should require grad.
+                self.assertTrue(
+                    all(param.requires_grad for param in ddp_model.parameters())
+                )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
+        def test_DistributedDataParallel_non_default_stream(self):
+            stream = torch.cuda.Stream(self.rank)
+            rank = self.rank
+            with torch.cuda.stream(stream):
+                net = torch.nn.parallel.DistributedDataParallel(
+                    torch.nn.Linear(1, 1, bias=False).cuda(rank), device_ids=[rank]
+                )
+                for i in range(1000):
+                    # Clear gradients manually
+                    grad = net.module.weight.grad
+                    if grad is not None:
+                        grad.requires_grad_(False)
+                        grad.zero_()
+                    # Forward + BW
+                    batch = torch.tensor([rank]).float().cuda(rank)
+                    loss = net(batch).sum()
+                    loss.backward()
+                    # For each worker, the gradient on the weight should be worker_rank.
+                    grad = net.module.weight.grad
+                    avg = grad.clone()
+                    # Each worker's gradient should already be the all-reduced
+                    # average, so summing them and dividing by world size must
+                    # reproduce the same value. If not, one of the workers has
+                    # not written back the averaged gradient before this call.
+                    dist.all_reduce(avg)
+                    world_size = int(os.environ["WORLD_SIZE"])
+                    avg.div_(world_size)
+                    expected_grad = sum(i for i in range(world_size)) / world_size
+                    self.assertEqual(
+                        avg[0, 0],
+                        expected_grad,
+                        msg=f"Expected gradient of {expected_grad} but got {avg} on rank {self.rank}",
+                    )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["cuda"],
+            f"The {BACKEND} backend does not support DDP communication hook on CUDA devices",
+        )
+        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
+        def test_ddp_comm_hook_logging(self):
+            hooks = [
+                default.allreduce_hook,
+                default.fp16_compress_hook,
+                powerSGD.powerSGD_hook,
+                powerSGD.batched_powerSGD_hook,
+                quantization_hooks.quantization_pertensor_hook,
+                quantization_hooks.quantization_perchannel_hook,
+            ]
+
+            cpp_builtin_hooks = [
+                dist.BuiltinCommHookType.ALLREDUCE,
+                dist.BuiltinCommHookType.FP16_COMPRESS,
+            ]
+
+            for hook in hooks:
+                ddp_model = torch.nn.parallel.DistributedDataParallel(
+                    torch.nn.Linear(1, 1, bias=False).cuda(self.rank),
+                    device_ids=[self.rank],
+                )
+                ddp_logging_data = ddp_model._get_ddp_logging_data()
+                # Hook not registered yet, so should be empty
+                self.assertEqual(ddp_logging_data.get("comm_hook"), None)
+                ddp_model.register_comm_hook(None, hook)
+                ddp_logging_data = ddp_model._get_ddp_logging_data()
+                self.assertEqual(ddp_logging_data.get("comm_hook"), hook.__qualname__)
+
+            for hook in cpp_builtin_hooks:
+                ddp_model = torch.nn.parallel.DistributedDataParallel(
+                    torch.nn.Linear(1, 1, bias=False).cuda(self.rank),
+                    device_ids=[self.rank],
+                )
+                ddp_logging_data = ddp_model._get_ddp_logging_data()
+                # Hook not registered yet, so should be empty
+                self.assertEqual(ddp_logging_data.get("comm_hook"), None)
+                ddp_model._register_builtin_comm_hook(hook)
+                ddp_logging_data = ddp_model._get_ddp_logging_data()
+                self.assertEqual(ddp_logging_data.get("comm_hook"), str(hook))
+
+            # No hook registered
+            ddp_model = torch.nn.parallel.DistributedDataParallel(
+                torch.nn.Linear(1, 1, bias=False).cuda(self.rank),
+                device_ids=[self.rank],
+            )
+            ddp_logging_data = ddp_model._get_ddp_logging_data()
+            # Hook not registered yet, so should be empty
+            self.assertEqual(ddp_logging_data.get("comm_hook"), None)
+            # Even after forward/backward passes run, no hook has been registered,
+            # so the logged comm_hook entry should remain empty.
+            for _ in range(2):
+                inp = torch.ones(1, 1, device=self.rank)
+                loss = ddp_model(inp).sum()
+                loss.backward()
+
+            ddp_logging_data = ddp_model._get_ddp_logging_data()
+            # Note: DETAIL debug mode logs DDP logging data to stdout and
+            # thus accesses std::map, which fills in a default value for the
+            # type if it didn't exist.
+            self.assertEqual(ddp_logging_data.get("comm_hook", ""), "")
+
+        def _test_ddp_hook_with_optimizer_parity(
+            self,
+            grad_as_bucket_view,
+            static_graph,
+            optim_cls,
+            optimize_subset,
+            *functional_optim_args,
+            **functional_optim_kwargs,
+        ):
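+            # Parity check: a DDP model whose optimizer is fused into the comm hook
+            # via _register_fused_optim must end up with the same parameters as a
+            # DDP model that calls optim_cls.step() after each backward pass.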
+            rank = self.rank
+            torch.cuda.set_device(rank)
+            torch.manual_seed(rank)
+            torch.cuda.manual_seed(rank)
+            models_to_test = [
+                (LargeNet(), torch.randn(1, 1000).cuda()),
+            ]
+            if HAS_TORCHVISION:
+                models_to_test.append(
+                    (torchvision.models.resnet50(), torch.randn(1, 3, 3, 1000).cuda())
+                )
+            for model, inp in models_to_test:
+                # Enable determinism in cudnn operators
+                with torch.backends.cudnn.flags(
+                    enabled=True, deterministic=True, benchmark=False
+                ):
+                    # Create DDP model that runs optimizer in fused fashion.
+                    ddp_model_with_optimizer_hook = (
+                        torch.nn.parallel.DistributedDataParallel(
+                            copy.deepcopy(model).cuda(),
+                            device_ids=[self.rank],
+                            gradient_as_bucket_view=grad_as_bucket_view,
+                            static_graph=static_graph,
+                        )
+                    )
+
+                    # Create DDP model with no hook that does optimizer after
+                    # backward.
+                    ddp_model_with_no_hook = torch.nn.parallel.DistributedDataParallel(
+                        copy.deepcopy(model).cuda(),
+                        device_ids=[self.rank],
+                        gradient_as_bucket_view=grad_as_bucket_view,
+                        static_graph=static_graph,
+                    )
+                    hook_params = ddp_model_with_optimizer_hook.parameters()
+                    no_hook_params = ddp_model_with_no_hook.parameters()
+                    if optimize_subset:
+                        hook_params = list(hook_params)
+                        no_hook_params = list(no_hook_params)
+                        self.assertGreater(len(hook_params), 0)
+                        hook_params = [hook_params[0]]
+                        no_hook_params = [no_hook_params[0]]
+
+                    # Register a fused optimizer that will run optimizer in step
+                    # with allreduce.
+
+                    if optimize_subset:
+                        # API where optim_params is specified.
+                        ddp_model_with_optimizer_hook._register_fused_optim(
+                            optim_cls,
+                            *functional_optim_args,
+                            optim_params=hook_params,
+                            **functional_optim_kwargs,
+                        )
+                    else:
+                        # API where optim_params is omitted
+                        ddp_model_with_optimizer_hook._register_fused_optim(
+                            optim_cls,
+                            *functional_optim_args,
+                            **functional_optim_kwargs,
+                        )
+
+                    optimizer_no_hook = optim_cls(
+                        no_hook_params,
+                        *functional_optim_args,
+                        **functional_optim_kwargs,
+                    )
+
+                    # Verify parameters are equal initially.
+                    for hook_param, allreduce_param in zip(
+                        ddp_model_with_optimizer_hook.parameters(),
+                        ddp_model_with_no_hook.parameters(),
+                    ):
+                        self.assertEqual(hook_param, allreduce_param)
+
+                    # Save old parameters to later verify optimizer modified them.
+                    opt_hook_init_params = copy.deepcopy(
+                        list(ddp_model_with_optimizer_hook.parameters())
+                    )
+
+                    # Run optimizer with hook model.
+                    for _ in range(6):
+                        ddp_model_with_optimizer_hook.zero_grad()
+                        out = ddp_model_with_optimizer_hook(inp)
+                        loss = out.sum()
+                        loss.backward()
+
+                    dist.barrier()
+
+                    # Run regular model.
+                    for _ in range(6):
+                        ddp_model_with_no_hook.zero_grad()
+                        out = ddp_model_with_no_hook(inp)
+                        loss = out.sum()
+                        loss.backward()
+                        optimizer_no_hook.step()
+
+                    dist.barrier()
+
+                    # Now verify parameters are equal.
+                    for hook_param, allreduce_param in zip(
+                        ddp_model_with_optimizer_hook.parameters(),
+                        ddp_model_with_no_hook.parameters(),
+                    ):
+                        self.assertEqual(hook_param, allreduce_param)
+
+                    # Verify optimizer modified appropriate parameter set,
+                    # otherwise they'd be trivially equal above.
+                    if optimize_subset:
+                        self.assertNotEqual(
+                            opt_hook_init_params[0],
+                            next(iter(ddp_model_with_optimizer_hook.parameters())),
+                        )
+                        # Untouched params should be equal
+                        self.assertEqual(
+                            opt_hook_init_params[1:],
+                            list(ddp_model_with_optimizer_hook.parameters())[1:],
+                        )
+                    else:
+                        self.assertNotEqual(
+                            opt_hook_init_params,
+                            list(ddp_model_with_optimizer_hook.parameters()),
+                        )
+                    dist.barrier()
+
+        """
+        # Commenting out the following 3 tests as they cause Sandcastle jobs to fail
+        # Failure signature:
+        # AttributeError: type object 'TestDistBackendWithSpawn' has no attribute 'test_ddp_hook_with_optimizer_parity_adamw
+
+        from torch.testing._internal.common_utils import parametrize
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl" or BACKEND == "ucc",
+            "Issues with async error handling, see https://github.com/pytorch/pytorch/issues/73259",
+        )
+        @skip_if_lt_x_gpu(2)
+        @parametrize("grad_as_bucket_view", [True, False])
+        @parametrize("static_graph", [True, False])
+        @parametrize("optimize_subset", [True, False])
+        def test_ddp_hook_with_optimizer_parity_adamw(
+            self,
+            grad_as_bucket_view,
+            static_graph,
+            optimize_subset,
+        ):
+            adamw_lr = 1e-2
+            adamw_betas = (0.9, 0.99)
+            adamw_eps = 1e-6
+            self._test_ddp_hook_with_optimizer_parity(
+                grad_as_bucket_view,
+                static_graph,
+                torch.optim.AdamW,
+                optimize_subset,
+                adamw_lr,
+                betas=adamw_betas,
+                eps=adamw_eps,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl" or BACKEND == "ucc",
+            "Issues with async error handling, see https://github.com/pytorch/pytorch/issues/73259",
+        )
+        @skip_if_lt_x_gpu(2)
+        @parametrize("optimize_subset", [True, False])
+        def test_ddp_hook_with_optimizer_parity_adam(self, optimize_subset):
+            adam_lr = 1e-2
+            adam_betas = (0.9, 0.99)
+            adam_eps = 1e-6
+            self._test_ddp_hook_with_optimizer_parity(
+                True,  # grad as bucket view
+                False,  # static graph
+                torch.optim.Adam,
+                optimize_subset,
+                adam_lr,
+                betas=adam_betas,
+                eps=adam_eps,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl" or BACKEND == "ucc",
+            "Issues with async error handling, see https://github.com/pytorch/pytorch/issues/73259",
+        )
+        @skip_if_lt_x_gpu(2)
+        @parametrize("optimize_subset", [True, False])
+        def test_ddp_hook_with_optimizer_parity_sgd(self, optimize_subset):
+            sgd_lr = 1e-2
+            sgd_momentum = 0.9
+            sgd_weight_decay = 0.01
+            # Not testing grad_as_bucket_view and static_graph as they are
+            # tested in AdamW test above.
+            self._test_ddp_hook_with_optimizer_parity(
+                True,  # grad as bucket view
+                False,  # static_graph
+                torch.optim.SGD,
+                optimize_subset,
+                sgd_lr,
+                momentum=sgd_momentum,
+                weight_decay=sgd_weight_decay,
+            )
+        """
+
+        @skip_if_lt_x_gpu(2)
+        def test_get_data_parallel_params(self):
+            torch.cuda.set_device(self.rank)
+            model = TwoLinLayerNet().cuda()
+            # Parameters to ignore are in the format {module_name}.{param_name}
+            params_to_ignore = ["a.weight"]
+            torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
+                model, params_to_ignore
+            )
+            torch.nn.parallel.DistributedDataParallel(model, device_ids=[self.rank])
+            dp_params = (
+                torch.nn.parallel.DistributedDataParallel._get_data_parallel_params(
+                    model, named_params=True
+                )
+            )
+            for name, _ in dp_params:
+                self.assertNotEqual(f"module.{params_to_ignore[0]}", name)
+
+            # Test named_params=False: just check that it returns the expected
+            # number of parameters.
+            num_ddp_params = len(list(model.parameters())) - 1
+            count = 0
+            dp_params = (
+                torch.nn.parallel.DistributedDataParallel._get_data_parallel_params(
+                    model, named_params=False
+                )
+            )
+            for _ in dp_params:
+                count += 1
+            self.assertEqual(count, num_ddp_params)
+
+        def _test_ddp_apply_optim_in_backward(
+            self,
+            optim_cls,
+            optim_kwargs,
+            init_before,
+            gradient_as_bucket_view=True,
+        ):
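+            # Compares a plain DDP model stepped with optim_cls after backward
+            # against a copy whose optimizer runs inside backward via
+            # _apply_optimizer_in_backward; parameters must match at every
+            # iteration and the in-backward model must never hold a .grad.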
+            # Need to seed to ensure inputs are unique across ranks. Otherwise,
+            # allreduce won't have any effect.
+            torch.manual_seed(self.rank)
+            torch.cuda.manual_seed(self.rank)
+            torch.cuda.set_device(self.rank)
+
+            # Test a simple linear as well as a ResNet model.
+            models_to_test = [
+                nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3), nn.Linear(3, 3)).cuda()
+            ]
+            if HAS_TORCHVISION:
+                models_to_test.append(torchvision.models.resnet50().cuda())
+
+            for j, model in enumerate(models_to_test):
+                model_optim_in_bwd = copy.deepcopy(model)
+                model = nn.parallel.DistributedDataParallel(
+                    model,
+                    device_ids=[self.rank],
+                    gradient_as_bucket_view=gradient_as_bucket_view,
+                )
+                optim = optim_cls(model.parameters(), **optim_kwargs)
+                if init_before:
+                    _apply_optimizer_in_backward(
+                        optimizer_class=optim_cls,
+                        params=model_optim_in_bwd.parameters(),
+                        optimizer_kwargs=optim_kwargs,
+                    )
+                model_optim_in_bwd = nn.parallel.DistributedDataParallel(
+                    model_optim_in_bwd,
+                    device_ids=[self.rank],
+                    gradient_as_bucket_view=gradient_as_bucket_view,
+                )
+                if not init_before:
+                    _apply_optimizer_in_backward(
+                        optimizer_class=optim_cls,
+                        params=model_optim_in_bwd.parameters(),
+                        optimizer_kwargs=optim_kwargs,
+                    )
+
+                for p1, p2 in zip(model.parameters(), model_optim_in_bwd.parameters()):
+                    self.assertEqual(p1, p2, "Parameters not initially equal!")
+                # Enable determinism in cudnn operators
+                with torch.backends.cudnn.flags(
+                    enabled=True, deterministic=True, benchmark=False
+                ):
+                    for i in range(8):
+                        inp = (
+                            torch.randn(1, 3, 1000, 1000, device="cuda")
+                            if j == 1
+                            else torch.randn(10, 3, device="cuda")
+                        )
+                        model(inp).sum().backward()
+                        optim.step()
+                        model_optim_in_bwd(
+                            inp
+                        ).sum().backward()  # runs optimizer as well
+                        for p1, p2 in zip(
+                            model.parameters(), model_optim_in_bwd.parameters()
+                        ):
+                            self.assertEqual(
+                                p1, p2, f"Params not equal at iteration {i}"
+                            )
+                            self.assertTrue(
+                                p2.grad is None,
+                                f"Optim in backward grad is not None at {i}",
+                            )
+
+                        # set_to_none for regular optimizer to match in backward
+                        # case.
+                        optim.zero_grad(set_to_none=True)
+
+        @skipIfRocm
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_apply_optim_in_backward(self):
+            for optim_cls, init_before in itertools.product(
+                [torch.optim.SGD, torch.optim.Adam], [True, False]
+            ):
+                with self.subTest(optim_cls=optim_cls):
+                    self._test_ddp_apply_optim_in_backward(
+                        optim_cls=optim_cls,
+                        optim_kwargs={"lr": 0.03},
+                        init_before=init_before,
+                    )
+
+        @skipIfRocm
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_apply_optim_in_backward_grad_as_bucket_view_false(self):
+            for init_before in [True, False]:
+                self._test_ddp_apply_optim_in_backward(
+                    optim_cls=torch.optim.SGD,
+                    optim_kwargs={"lr": 0.03},
+                    init_before=init_before,
+                    gradient_as_bucket_view=False,
+                )
+
+        @skipIfRocm
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_apply_optim_in_backward_ignored_params(self):
+            torch.cuda.set_device(self.rank)
+            for init_before in [True, False]:
+                with self.subTest(init_before=init_before):
+                    torch.manual_seed(self.rank)
+                    torch.cuda.manual_seed(self.rank)
+                    model = TwoLinLayerNet()
+                    # Parameters to ignore are in the format {module_name}.{param_name}
+                    params_to_ignore = ["a.weight"]
+                    torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
+                        model, params_to_ignore
+                    )
+                    if init_before:
+                        _apply_optimizer_in_backward(
+                            optimizer_class=torch.optim.SGD,
+                            params=model.parameters(),
+                            optimizer_kwargs={"lr": 0.03},
+                        )
+                    net = torch.nn.parallel.DistributedDataParallel(
+                        model.cuda(self.rank),
+                        device_ids=[self.rank],
+                    )
+                    if not init_before:
+                        _apply_optimizer_in_backward(
+                            optimizer_class=torch.optim.SGD,
+                            params=model.parameters(),
+                            optimizer_kwargs={"lr": 0.03},
+                        )
+                    inp = torch.randn(1, 10)
+                    a, b = net(inp)
+                    (a.transpose(0, 1) @ b).sum().backward()
+                    # a.weight did not go through allreduce, so optimizer acted on local
+                    # gradient, which should be different across ranks. Remaining params
+                    # should be equal.
+                    models = [None for _ in range(dist.get_world_size())]
+                    dist.all_gather_object(models, model)
+                    rank0_model, remainder = models[0], models[1:]
+                    for m in remainder:
+                        self.assertNotEqual(rank0_model.a.weight, m.a.weight)
+                        self.assertEqual(
+                            list(rank0_model.b.parameters()), list(m.b.parameters())
+                        )
+                        self.assertEqual(rank0_model.a.bias, m.a.bias)
+
+        def _get_fp16_config(self) -> _MixedPrecision:
+            return _MixedPrecision(
+                param_dtype=torch.float16,
+                reduce_dtype=torch.float16,
+                buffer_dtype=torch.float16,
+            )
+
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_native_mixed_precision_ignored_params(self):
+            rank = self.rank
+            torch.manual_seed(rank)
+            torch.cuda.manual_seed(rank)
+            torch.cuda.set_device(rank)
+            model = TwoLinLayerNet()
+            model.register_buffer("buffer", torch.ones(5))
+            # Parameters to ignore are in the format {module_name}.{param_name}
+            to_ignore = ["a.weight", "buffer"]
+            torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
+                model,
+                to_ignore,
+            )
+            mp_config = self._get_fp16_config()
+            net = torch.nn.parallel.DistributedDataParallel(
+                model.to(rank),
+                device_ids=[rank],
+                mixed_precision=mp_config,
+                gradient_as_bucket_view=True,
+            )
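+            # DDP prefixes parameter and buffer names with 'module.', so adjust the
+            # ignored names to match named_parameters() / named_buffers().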
+            to_ignore = [f"module.{name}" for name in to_ignore]
+            expected_ignored = len(to_ignore)
+            n_ignored = 0
+            # ignored params should not have _mp_param or _fp_param fields.
+            for n, p in itertools.chain(net.named_parameters(), net.named_buffers()):
+                if n in to_ignore:
+                    n_ignored += 1
+                    self.assertFalse(hasattr(p, "_mp_param"))
+                    self.assertFalse(hasattr(p, "_fp_param"))
+                else:
+                    self.assertEqual(mp_config.param_dtype, p._mp_param.dtype)
+                    self.assertEqual(torch.float32, p._fp_param.dtype)
+
+            self.assertEqual(expected_ignored, n_ignored)
+
+        def _test_ddp_native_mixed_precision(
+            self, gradient_as_bucket_view, set_grad_to_none
+        ):
+            rank = self.rank
+            torch.manual_seed(rank)
+            torch.cuda.manual_seed(rank)
+            torch.cuda.set_device(rank)
+            inp = torch.randn(10, 1)
+            mp_config = self._get_fp16_config()
+
+            class MyModel(torch.nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.m = torch.nn.Linear(1, 5)
+                    self.register_buffer("buffer", torch.randn(1, 2))
+                    self.p = torch.nn.Parameter(torch.randn(10, 5), requires_grad=False)
+
+                def forward(self_, x):  # noqa: B902
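+                    # By the time forward runs, DDP native mixed precision should have
+                    # cast the params, buffers, and input to the low-precision dtype.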
+                    params = self_.m.parameters()
+                    for p in params:
+                        self.assertEqual(mp_config.param_dtype, p.dtype)
+
+                    self.assertEqual(self_.buffer.dtype, mp_config.buffer_dtype)
+
+                    self.assertEqual(mp_config.param_dtype, x.dtype)
+                    return self_.m(x) + self_.p
+
+            m = MyModel()
+
+            net = torch.nn.parallel.DistributedDataParallel(
+                m.to(rank),
+                device_ids=[rank],
+                mixed_precision=mp_config,
+                gradient_as_bucket_view=gradient_as_bucket_view,
+            )
+            # Buffers are cast in the constructor.
+            self.assertEqual(net.module.buffer.dtype, mp_config.buffer_dtype)
+            # Each param should have an mp_param in the lower precision, and
+            # an fp_param in the higher precision.
+            for p in net.parameters():
+                self.assertEqual(mp_config.param_dtype, p._mp_param.dtype)
+                self.assertEqual(torch.float32, p._fp_param.dtype)
+
+            for _ in range(6):
+                loss = net(inp).sum()
+                loss.backward()
+                # Verify gradient synchronization and params and grads are fp32.
+                for n, param in net.named_parameters():
+                    self.assertEqual(param.dtype, torch.float32)
+                    if param.grad is None:
+                        assert n == "module.p"  # Only param that doesn't require grad
+                    else:
+                        self.assertEqual(param.grad.dtype, torch.float32)
+                        tensor_list = [
+                            torch.zeros_like(param.grad)
+                            for _ in range(dist.get_world_size(net.process_group))
+                        ]
+                        dist.all_gather(tensor_list, param.grad)
+                        g, rest = tensor_list[0], tensor_list[1:]
+                        self.assertEqual(g.dtype, torch.float32)
+                        for g_ in rest:
+                            self.assertEqual(g_.dtype, torch.float32)
+                            self.assertEqual(g, g_)
+                net.zero_grad(set_to_none=set_grad_to_none)
+
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_native_mixed_precision_no_grad_as_bucket_view_no_set_grad_none(
+            self,
+        ):
+            self._test_ddp_native_mixed_precision(
+                gradient_as_bucket_view=False,
+                set_grad_to_none=False,
+            )
+
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_native_mixed_precision_grad_as_bucket_view_no_set_grad_none(self):
+            self._test_ddp_native_mixed_precision(
+                gradient_as_bucket_view=True,
+                set_grad_to_none=False,
+            )
+
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_native_mixed_precision_grad_as_bucket_view_set_grad_to_none(self):
+            self._test_ddp_native_mixed_precision(
+                gradient_as_bucket_view=True, set_grad_to_none=True
+            )
+
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_native_mixed_precision_no_grad_as_bucket_view_set_grad_to_none(
+            self,
+        ):
+            self._test_ddp_native_mixed_precision(
+                gradient_as_bucket_view=False, set_grad_to_none=True
+            )
+
+        def _test_ddp_hook_parity(self, state, hook, num_validated_iters=100):
+            rank = self.rank
+            m = torch.nn.Linear(1, 5)
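+            # `state` may be a comm hook state object that carries a process group
+            # (e.g. PowerSGDState), or the process group itself (possibly None) for
+            # simpler hooks such as the default allreduce hook.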
+            try:
+                process_group = state.process_group
+            except AttributeError:
+                process_group = state
+
+            net_with_hook = torch.nn.parallel.DistributedDataParallel(
+                copy.deepcopy(m).to(rank),
+                device_ids=[rank],
+                process_group=process_group,
+            )
+            net_with_hook.register_comm_hook(state=state, hook=hook)
+            net_without_hook = torch.nn.parallel.DistributedDataParallel(
+                copy.deepcopy(m).to(rank),
+                device_ids=[rank],
+                process_group=process_group,
+            )
+            for i in range(100):
+                # Clear gradients manually.
+                for g in [
+                    net_without_hook.module.weight.grad,
+                    net_with_hook.module.weight.grad,
+                ]:
+                    if g is not None:
+                        g.requires_grad_(False)
+                        g.zero_()
+                # Forward + BW
+                batch = torch.tensor([rank]).float().cuda(rank)
+                loss = net_without_hook(batch).sum()
+                loss.backward()
+                # For each worker, the gradient on the weight should be worker_rank.
+                grad = net_without_hook.module.weight.grad
+                avg = grad.clone()
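+                # Each rank feeds its rank into Linear(1, 5), so the local weight grad
+                # equals the rank; after DDP's allreduce-mean it equals the average of
+                # all ranks.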
+                expected_grad = (
+                    sum(i for i in range(dist.get_world_size())) / dist.get_world_size()
+                )
+                loss_hook = net_with_hook(batch).sum()
+                loss_hook.backward()
+                grad_hook = net_with_hook.module.weight.grad
+                avg_hook = grad_hook.clone()
+
+                if i < num_validated_iters:
+                    # Verify hook grad with expected.
+                    self.assertEqual(
+                        avg_hook[0, 0].item(),
+                        expected_grad,
+                        msg=f"Expected hook grad of {expected_grad} but got {avg_hook[0, 0]}",
+                    )
+                    # Verify hook grad with vanilla allreduce
+                    self.assertEqual(
+                        avg_hook[0, 0],
+                        avg[0, 0],
+                        msg=f"Expected hook grad to be close to allreduce {avg[0, 0]}, but got {avg_hook[0, 0]}",
+                    )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["cuda"],
+            f"The {BACKEND} backend does not support DDP communication hook on CUDA devices",
+        )
+        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
+        def test_ddp_hook_parity_allreduce(self):
+            self._test_ddp_hook_parity(state=None, hook=default.allreduce_hook)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["cuda"],
+            f"The {BACKEND} backend does not support DDP communication hook on CUDA devices",
+        )
+        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
+        def test_ddp_hook_parity_allreduce_process_group(self):
+            # process_group is passed in to both DDP and comm. hook
+            world_size = dist.get_world_size()
+            rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
+            gpus = [rank_to_GPU[int(r)][0] for r in range(world_size)]
+            process_group = torch.distributed.new_group(gpus)
+            self._test_ddp_hook_parity(state=process_group, hook=default.allreduce_hook)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["cuda"],
+            f"The {BACKEND} backend does not support DDP communication hook on CUDA devices",
+        )
+        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
+        def test_ddp_hook_parity_powerSGD(self):
+            for warm_start in [True, False]:
+                powersgd_state = powerSGD.PowerSGDState(
+                    process_group=None,
+                    matrix_approximation_rank=1,
+                    start_powerSGD_iter=2,
+                    warm_start=warm_start,
+                )
+                self._test_ddp_hook_parity(
+                    state=powersgd_state, hook=powerSGD.powerSGD_hook
+                )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["cuda"],
+            f"The {BACKEND} backend does not support DDP communication hook on CUDA devices",
+        )
+        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
+        def test_ddp_hook_parity_post_localSGD(self):
+            # Although we start running local SGD at iteration 10, we still use the global process group to run it,
+            # so post-localSGD still allreduces gradients globally for the remaining iterations.
+            state = post_localSGD.PostLocalSGDState(
+                process_group=None, subgroup=dist.group.WORLD, start_localSGD_iter=10
+            )
+            self._test_ddp_hook_parity(
+                state=state, hook=post_localSGD.post_localSGD_hook
+            )
+            # Only validate the warmup iterations before local SGD is applied,
+            # because when `post_local_gradient_allreduce` is disabled, the gradients will not be synchronized at all.
+            # Note that in practice a model averager has to be applied to run model averaging,
+            # so local gradient averaging is not necessary.
+            start_localSGD_iter = 10
+            state = post_localSGD.PostLocalSGDState(
+                process_group=None,
+                subgroup=dist.group.WORLD,
+                start_localSGD_iter=start_localSGD_iter,
+                post_local_gradient_allreduce=False,
+            )
+            self._test_ddp_hook_parity(
+                state=state,
+                hook=post_localSGD.post_localSGD_hook,
+                num_validated_iters=start_localSGD_iter,
+            )
+
+            # When `subgroup` is None, it defaults to the intra-node subgroup on each node.
+            # For this single-node test environment, the intra-node process group is equivalent to
+            # the global process group.
+            if self.world_size == dist.get_world_size():
+                state = post_localSGD.PostLocalSGDState(
+                    process_group=None, subgroup=None, start_localSGD_iter=10
+                )
+                self._test_ddp_hook_parity(
+                    state=state, hook=post_localSGD.post_localSGD_hook
+                )
+
+            # Since local SGD starts after the total number of 100 iterations,
+            # no local SGD is actually executed, and we don't even need to provide a subgroup for this case.
+            state = post_localSGD.PostLocalSGDState(
+                process_group=None, subgroup=None, start_localSGD_iter=1000
+            )
+            self._test_ddp_hook_parity(
+                state=state, hook=post_localSGD.post_localSGD_hook
+            )
+
+        def _prepare_single_device_module(
+            self,
+            rank,
+            process_group,
+            devices,
+            device_ids,
+            global_batch_size,
+            gradient_as_bucket_view=False,
+        ):
+            model = Net()
+            device = devices[0] if devices else torch.device(f"cuda:{rank:d}")
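+            # Use a very small bucket cap so gradients are split across many small
+            # buckets, exercising DDP's bucketing logic.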
+            ddp_model = DistributedDataParallel(
+                copy.deepcopy(model).to(device),
+                device_ids=device_ids,
+                process_group=process_group,
+                bucket_cap_mb=0.001,
+                gradient_as_bucket_view=gradient_as_bucket_view,
+            )
+
+            model.to(device)
+
+            input = torch.randn(global_batch_size, 2).to(device)
+            target = torch.randn(global_batch_size, 4).to(device)
+
+            return model, ddp_model, input, target
+
+        def _prepare_cpu_module(
+            self,
+            process_group,
+            global_batch_size,
+            gradient_as_bucket_view=False,
+        ):
+            model = Net()
+            ddp_model = DistributedDataParallel(
+                copy.deepcopy(model),
+                process_group=process_group,
+                bucket_cap_mb=0.001,
+                gradient_as_bucket_view=gradient_as_bucket_view,
+            )
+            input = torch.randn(global_batch_size, 2)
+            target = torch.randn(global_batch_size, 4)
+            return model, ddp_model, input, target
+
+        def _test_accumulate_gradients_no_sync(
+            self, num_iters=2, ddp_comm_hook=None, gradient_as_bucket_view=False
+        ):
+            """
+            Tests gradient accumulation with ``no_sync``, which is the recommended way
+            to accumulate gradients. If ``ddp_comm_hook`` is specified, that hook is
+            also registered on the ``ddp_model``. The hook fed into this function
+            should not change the resulting gradients.
+            """
+            _group, group_id, rank = self._init_global_test()
+            world_size = get_world_size()
+
+            # FIXME: Add testing for gloo/CUDA
+            if BACKEND == "mpi" or BACKEND == "gloo":
+                global_batch_size = world_size
+                local_batch_size = 1
+                model, ddp_model, input, target = self._prepare_cpu_module(
+                    group_id, global_batch_size, gradient_as_bucket_view
+                )
+
+            if BACKEND == "nccl":
+                rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+                int_devices = rank_to_GPU[rank][:1]
+                devices = [torch.device("cuda:" + str(i)) for i in int_devices]
+                global_batch_size = world_size
+                local_batch_size = len(devices)
+                model, ddp_model, input, target = self._prepare_single_device_module(
+                    rank,
+                    group_id,
+                    devices,
+                    devices,
+                    global_batch_size,
+                    gradient_as_bucket_view,
+                )
+
+            if ddp_comm_hook is not None:
+                ddp_model.register_comm_hook(group_id, ddp_comm_hook)
+
+            def step_model(model, input, target):
+                model.train()
+                output = model(input)
+                loss = F.mse_loss(output, target.to(output.device))
+                loss.backward()
+
+            # ensure gradient accumulation works under no_grad => no grads are accumulated.
+            with torch.no_grad():
+                with ddp_model.no_sync():
+                    ddp_model.train()
+                    ddp_model(input)
+
+            # check two model parameters over num_iters iterations
+            for iteration in range(num_iters):
+                step_model(model, input, target)
+
+                ddp_input = input[
+                    rank * local_batch_size : (rank + 1) * local_batch_size
+                ]
+                ddp_target = target[
+                    rank * local_batch_size : (rank + 1) * local_batch_size
+                ]
+
+                if iteration % 2 == 0:
+                    # accumulate grads locally
+                    with ddp_model.no_sync():
+                        step_model(ddp_model, ddp_input, ddp_target)
+                else:
+                    # sync grads
+                    step_model(ddp_model, ddp_input, ddp_target)
+
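+                # On even iterations the DDP gradients were accumulated locally under
+                # no_sync and differ from the single-process reference; after the
+                # synced step on odd iterations they match.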
+                for i, j in zip(model.parameters(), ddp_model.parameters()):
+                    if not i.requires_grad:
+                        continue
+                    if iteration % 2 == 0:
+                        self.assertNotEqual(i.grad, j.grad)
+                    else:
+                        self.assertEqual(i.grad, j.grad)
+
+                # Shuffle the input so that DDP input is different
+                torch.manual_seed(1337 + iteration)
+                input = input[torch.randperm(global_batch_size)]
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
+            "get_future is only supported on mpi, nccl and gloo",
+        )
+        @nccl_skip_if_lt_x_gpu(BACKEND, 2)
+        def test_accumulate_gradients_no_sync(self):
+            """
+            Runs _test_accumulate_gradients_no_sync using default inputs
+            """
+            self._test_accumulate_gradients_no_sync()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
+            "get_future is only supported on mpi, nccl and gloo",
+        )
+        @nccl_skip_if_lt_x_gpu(BACKEND, 2)
+        def test_accumulate_gradients_no_sync_grad_is_view(self):
+            """
+            Runs _test_accumulate_gradients_no_sync with gradient_as_bucket_view=True
+            """
+            self._test_accumulate_gradients_no_sync(gradient_as_bucket_view=True)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
+            "get_future is only supported on mpi, nccl and gloo",
+        )
+        @nccl_skip_if_lt_x_gpu(BACKEND, 2)
+        def test_accumulate_gradients_no_sync_allreduce_hook(self):
+            """
+            Runs multiple iterations of _test_accumulate_gradients_no_sync using an
+            allreduce hook and validates that the future result was properly passed
+            as gradients to the reducer.
+            """
+
+            world_size = get_world_size()
+
+            def allreduce_hook(
+                group_id: object, bucket: dist.GradBucket
+            ) -> torch.futures.Future[torch.Tensor]:
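+                # Pre-divide by world_size so the summed allreduce yields the averaged
+                # gradient.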
+                tensors = [bucket.buffer() / world_size]
+                return (
+                    group_id.allreduce(tensors)
+                    .get_future()
+                    .then(lambda fut: fut.value()[0])
+                )
+
+            self._test_accumulate_gradients_no_sync(
+                num_iters=4, ddp_comm_hook=allreduce_hook
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
+            "get_future is only supported on mpi, nccl and gloo",
+        )
+        @nccl_skip_if_lt_x_gpu(BACKEND, 2)
+        def test_accumulate_gradients_no_sync_allreduce_with_then_hook(self):
+            """
+            Runs multiple iterations of _test_accumulate_gradients_no_sync using an
+            allreduce hook that also uses ``then`` callbacks. The first callback
+            multiplies the result by 2, and the second divides it by 2 * world_size.
+            It validates that the final result was properly passed as gradients to
+            the reducer.
+            """
+
+            world_size = get_world_size()
+
+            def allreduce_with_then_hook(
+                group_id: object, bucket: dist.GradBucket
+            ) -> torch.futures.Future[torch.Tensor]:
+                fut = group_id.allreduce([bucket.buffer()]).get_future()
+
+                def mult(fut):
+                    # Multiply the result by 2.
+                    return 2 * fut.wait()[0]
+
+                def div(fut):
+                    # Divide the result by 2 * world_size.
+                    return fut.wait() / (2 * world_size)
+
+                return fut.then(mult).then(div)
+
+            self._test_accumulate_gradients_no_sync(
+                num_iters=4, ddp_comm_hook=allreduce_with_then_hook
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
+            "get_future is only supported on mpi, nccl and gloo",
+        )
+        @nccl_skip_if_lt_x_gpu(BACKEND, 2)
+        def test_get_future(self):
+            def mult(fut):
+                return [t * 3 for t in fut.wait()]
+
+            def add(fut):
+                return [t + 1 for t in fut.wait()]
+
+            group, group_id, rank = self._init_global_test()
+            input = _build_tensor(3, 2)
+            if BACKEND == "nccl":
+                rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+                device_id = rank_to_GPU[rank][0]
+                input = input.to(device_id)
+            fut = group_id.allreduce([input]).get_future()
+            res = fut.then(mult).then(add).wait()
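+            # The allreduce sums the value 2 across all ranks in the group; the chained
+            # callbacks then multiply by 3 and add 1.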
+            expected = _build_tensor(3, 2 * len(group) * 3 + 1)
+
+            self.assertEqual(res[0], expected)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_no_gpu
+        def test_DistributedDataParallel(self):
+            _group, _group_id, rank = self._init_global_test()
+            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
+            gpus = list(rank_to_GPU[rank])
+
+            for use_bucket_view, static_graph in itertools.product(
+                (False, True), (False, True)
+            ):
+                self._test_DistributedDataParallel(
+                    gpu_subset=gpus,
+                    rank=rank,
+                    gradient_as_bucket_view=use_bucket_view,
+                    static_graph=static_graph,
+                )
+
+                # test set static graph twice
+                self._test_DistributedDataParallel(
+                    gpu_subset=gpus,
+                    rank=rank,
+                    gradient_as_bucket_view=use_bucket_view,
+                    static_graph=static_graph,
+                    set_static_graph_twice=True,
+                )
+
+                # test output_device
+                self._test_DistributedDataParallel(
+                    gpu_subset=gpus,
+                    rank=rank,
+                    output_device=torch.device("cuda"),
+                    gradient_as_bucket_view=use_bucket_view,
+                    static_graph=static_graph,
+                )
+
+                # test device_ids
+                gpus_list = [torch.device("cuda:" + str(i)) for i in gpus]
+                self._test_DistributedDataParallel(
+                    gpu_subset=gpus_list,
+                    rank=rank,
+                    output_device=torch.device("cuda"),
+                    gradient_as_bucket_view=use_bucket_view,
+                    static_graph=static_graph,
+                )
+
+        def _test_DistributedDataParallel_with_amp(self, grad_is_view=False):
+            torch.manual_seed(31415)
+            # Creates model and optimizer in default precision
+            model = copy.deepcopy(DDP_NET).cuda()
+            optimizer = torch.optim.SGD(model.parameters(), lr=0.03)
+
+            # Creates a GradScaler once at the beginning of training.
+            scaler = GradScaler()
+
+            ddp_model = nn.parallel.DistributedDataParallel(
+                model, device_ids=[self.rank], gradient_as_bucket_view=grad_is_view
+            )
+
+            input = torch.randn(dist.get_world_size() * 2, 2).cuda()
+            target = torch.randn(dist.get_world_size() * 2, 4).cuda()
+            loss_fn = nn.MSELoss()
+
+            # verify grads are none before training
+            for p in ddp_model.parameters():
+                self.assertTrue(p is not None)
+                self.assertTrue(p.grad is None)
+
+            for idx in range(20):
+                optimizer.zero_grad()
+                # Runs the forward pass with autocasting.
+                with autocast():
+                    output = ddp_model(input)
+                    loss = loss_fn(output, target)
+
+                # Scales loss.  Calls backward() on scaled loss to create scaled gradients.
+                # Backward passes under autocast are not recommended.
+                # Backward ops run in the same dtype autocast chose for corresponding forward ops.
+                scaler.scale(loss).backward()
+
+                # verify grads are not none and are valid during training
+                for p in ddp_model.parameters():
+                    if p.requires_grad:
+                        self.assertTrue(p.grad is not None)
+                        self.assertFalse(p.grad.isnan().any())
+                        self.assertFalse(p.grad.isinf().any())
+
+                # scaler.step() first unscales the gradients of the optimizer's assigned params.
+                # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
+                # otherwise, optimizer.step() is skipped.
+                scaler.step(optimizer)
+
+                # Updates the scale for next iteration.
+                scaler.update()
+
+                # Shuffle the input so that DDP input is different
+                torch.manual_seed(1337 + idx)
+                input = input[torch.randperm(dist.get_world_size() * 2)]
+
+            return ddp_model
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_no_gpu
+        def test_DistributedDataParallel_with_amp_and_grad_is_view(self):
+            torch.cuda.set_device(self.rank)
+            ddp_model_grad_not_view = self._test_DistributedDataParallel_with_amp(
+                grad_is_view=False
+            )
+            ddp_model_grad_is_view = self._test_DistributedDataParallel_with_amp(
+                grad_is_view=True
+            )
+            for i, j in zip(
+                ddp_model_grad_not_view.parameters(),
+                ddp_model_grad_is_view.parameters(),
+            ):
+                self.assertEqual(i, j)
+
+        def _test_DistributedDataParallel_SyncBatchNorm(
+            self,
+            gpu_subset,
+            rank,
+            local_bs,
+            global_bs,
+            offset,
+            output_device=None,
+            affine=True,
+        ):
+            # Run a simple end-to-end DDP model, using the result of the single-node
+            # model as the baseline
+
+            # cpu training setup
+            model = BN_NET if affine else BN_NET_NO_AFFINE
+
+            # single gpu training setup
+            model_gpu = copy.deepcopy(model)
+            model_gpu.cuda(gpu_subset[0])
+
+            # DDP training setup
+            model_DDP = nn.SyncBatchNorm.convert_sync_batchnorm(copy.deepcopy(model))
+            model_DDP.cuda(gpu_subset[0])
+            model_DDP = nn.parallel.DistributedDataParallel(
+                model_DDP, device_ids=gpu_subset
+            )
+
+            # test that the DDP-wrapped model can be serialized and deserialized
+            with tempfile.NamedTemporaryFile() as tmp:
+                if sys.platform == "win32":
+                    torch.save(model_DDP, tmp)
+                    tmp.seek(0)
+                    # weights_only=False as this is legacy code that saves the model
+                    model_DDP = torch.load(tmp, weights_only=False)
+                else:
+                    torch.save(model_DDP, tmp.name)
+                    # weights_only=False as this is legacy code that saves the model
+                    model_DDP = torch.load(tmp.name, weights_only=False)
+
+            # data initialization
+            input_cpu = torch.randn(global_bs, 2)
+            target = torch.randn(global_bs, 4)
+            loss = nn.MSELoss()
+
+            # check two model parameters over 5 iterations
+            self._test_DDP_niter(
+                model_gpu,
+                model_DDP,
+                input_cpu.cuda(gpu_subset[0]),
+                target.cuda(gpu_subset[0]),
+                loss,
+                local_bs,
+                rank,
+                global_bs,
+                True,
+                offset,
+                dist.get_world_size(),
+                5 if affine else 2,
+            )
+            self._barrier()
+
+        def _test_post_localSGD_optimizer_parity(self, create_averager, grad_is_view):
+            learning_rate = 0.03
+
+            net = torch.nn.parallel.DistributedDataParallel(
+                copy.deepcopy(DDP_NET).cuda(),
+                device_ids=[self.rank],
+                gradient_as_bucket_view=grad_is_view,
+            )
+            averager = create_averager()
+            opt = torch.optim.SGD(net.parameters(), lr=learning_rate)
+
+            net_using_post_localSGD_opt = torch.nn.parallel.DistributedDataParallel(
+                copy.deepcopy(DDP_NET).cuda(),
+                device_ids=[self.rank],
+                gradient_as_bucket_view=grad_is_view,
+            )
+            # Process group cannot be pickled in some environments,
+            # so cannot deep copy an averager. See:
+            # https://github.com/pytorch/pytorch/pull/74737#pullrequestreview-922487496
+            averager2 = create_averager()
+            post_localSGD_opt = self._create_post_localSGD_optimizer(
+                net_using_post_localSGD_opt, learning_rate, averager2
+            )
+
+            input = torch.randn(dist.get_world_size() * 2, 2).cuda()
+            target = torch.randn(dist.get_world_size() * 2, 4).cuda()
+            loss_fn = nn.MSELoss()
+
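+            # Manually averaging parameters after each SGD step should match
+            # PostLocalSGDOptimizer, which fuses the optimizer step and the averaging.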
+            for _ in range(20):
+                self._perform_a_train_step(opt, net, loss_fn, input, target)
+                averager.average_parameters(net.parameters())
+
+                self._perform_a_train_step(
+                    post_localSGD_opt,
+                    net_using_post_localSGD_opt,
+                    loss_fn,
+                    input,
+                    target,
+                )
+                for p1, p2 in zip(
+                    net.parameters(), net_using_post_localSGD_opt.parameters()
+                ):
+                    self.assertEqual(p1.data, p2.data)
+
+            # Also check if the built-in step counters are the same to prevent a bug like #74737.
+            self.assertEqual(averager.step, averager2.step)
+
+        def _create_periodic_model_averager(self):
+            return averagers.PeriodicModelAverager(period=4, warmup_steps=10)
+
+        def _create_post_localSGD_optimizer(self, net, learning_rate, averager):
+            return post_localSGD_optimizer.PostLocalSGDOptimizer(
+                optim=torch.optim.SGD(net.parameters(), lr=learning_rate),
+                averager=averager,
+            )
+
+        def _perform_a_train_step(self, optimizer, net, loss_fn, input, target):
+            optimizer.zero_grad()
+            output = net(input)
+            loss = loss_fn(output, target)
+            loss.backward()
+            optimizer.step()
+
+        def _test_post_localSGD_optimizer_step_reload(
+            self, create_averager, chkpt_file
+        ):
+            learning_rate = 0.03
+
+            net_using_post_localSGD_opt = torch.nn.parallel.DistributedDataParallel(
+                copy.deepcopy(DDP_NET).cuda(), device_ids=[self.rank]
+            )
+
+            averager = create_averager()
+            post_localSGD_opt = self._create_post_localSGD_optimizer(
+                net_using_post_localSGD_opt, learning_rate, averager
+            )
+
+            averager2 = create_averager()
+            dummy_post_localSGD_opt = self._create_post_localSGD_optimizer(
+                net_using_post_localSGD_opt, learning_rate, averager2
+            )
+
+            input = torch.randn(dist.get_world_size() * 2, 2).cuda()
+            target = torch.randn(dist.get_world_size() * 2, 4).cuda()
+            loss_fn = nn.MSELoss()
+
+            for _ in range(20):
+                self._perform_a_train_step(
+                    post_localSGD_opt,
+                    net_using_post_localSGD_opt,
+                    loss_fn,
+                    input,
+                    target,
+                )
+
+            if self.rank == 0:
+                torch.save(
+                    {"optimizer_state_dict": post_localSGD_opt.state_dict()}, chkpt_file
+                )
+
+            dist.barrier()
+            map_location = {"cuda:0": f"cuda:{self.rank:d}"}
+            checkpoint = torch.load(chkpt_file, map_location=map_location)
+            dummy_post_localSGD_opt.load_state_dict(checkpoint["optimizer_state_dict"])
+
+            # Check that we didn't hit the trivial case
+            self.assertNotEqual(averager2.step, 0)
+            # Check if dummy averager was initialized to a correct value
+            self.assertEqual(averager.step, averager2.step)
+
+            # Remove the 'step' entry from the checkpoint
+            # and make sure it is no longer in the state dictionary.
+            del checkpoint["optimizer_state_dict"]["step"]
+            self.assertNotIn("step", checkpoint["optimizer_state_dict"])
+
+            # Check that loading a checkpoint without a 'step' entry triggers a warning
+            with self.assertWarnsRegex(
+                expected_warning=UserWarning,
+                expected_regex="Loaded state dict does not contain a step counter for an averager. "
+                "Setting step counter to 0.",
+            ):
+                dummy_post_localSGD_opt.load_state_dict(
+                    checkpoint["optimizer_state_dict"]
+                )
+
+            self.assertEqual(averager2.step, 0)
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_post_localSGD_optimizer_parity(self):
+            torch.cuda.set_device(self.rank)
+            self._test_post_localSGD_optimizer_parity(
+                self._create_periodic_model_averager,
+                grad_is_view=False,
+            )
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_post_localSGD_optimizer_parity_grad_is_view(self):
+            torch.cuda.set_device(self.rank)
+            self._test_post_localSGD_optimizer_parity(
+                self._create_periodic_model_averager,
+                grad_is_view=True,
+            )
+
+        def _create_hierarchical_model_averager(self):
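+            # Keys are averaging periods and values are group sizes: average within
+            # subgroups of size 2 every 2 steps, and across all ranks every 4 steps.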
+            period_group_size_dict = OrderedDict([(2, 2), (4, dist.get_world_size())])
+            return hierarchicalSGD.HierarchicalModelAverager(
+                period_group_size_dict=period_group_size_dict, warmup_steps=4
+            )
+
+        @skip_if_lt_x_gpu(4)
+        @skip_if_odd_worldsize
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_post_localSGD_optimizer_parity_with_hierarchical_sgd(self):
+            torch.cuda.set_device(self.rank)
+            self._test_post_localSGD_optimizer_parity(
+                self._create_hierarchical_model_averager,
+                grad_is_view=False,
+            )
+
+        @skip_if_lt_x_gpu(4)
+        @skip_if_odd_worldsize
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_post_localSGD_optimizer_parity_with_hierarchical_sgd_grad_is_view(
+            self,
+        ):
+            torch.cuda.set_device(self.rank)
+            self._test_post_localSGD_optimizer_parity(
+                self._create_hierarchical_model_averager,
+                grad_is_view=True,
+            )
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_post_localSGD_optimizer_step_reload(self):
+            torch.cuda.set_device(self.rank)
+            with _rank_temp_file() as tmp_file:
+                self._test_post_localSGD_optimizer_step_reload(
+                    self._create_periodic_model_averager, tmp_file
+                )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_no_gpu
+        def test_DistributedDataParallel_SyncBatchNorm_Channels_Last(self):
+            self._test_DistributedDataParallel_SyncBatchNorm_with_memory_format(
+                torch.channels_last
+            )
+            self._test_DistributedDataParallel_SyncBatchNorm_with_memory_format(
+                torch.channels_last_3d
+            )
+
+        def _test_DistributedDataParallel_SyncBatchNorm_with_memory_format(
+            self, memory_format
+        ):
+            _group, _group_id, rank = self._init_global_test()
+            num_processes = dist.get_world_size()
+            local_bs = 2
+            bs_offset = int(rank * 2)
+            global_bs = int(num_processes * 2)
+
+            model = ONLY_SBN_NET
+            model_gpu = copy.deepcopy(model).cuda(rank)
+            model_DDP = nn.parallel.DistributedDataParallel(
+                model_gpu, device_ids=[rank]
+            )
+
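+            # channels_last expects a 4D input; channels_last_3d needs an extra
+            # spatial dimension (5D).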
+            shapes = [global_bs, 2, 4, 4] + (
+                [] if memory_format is torch.channels_last else [4]
+            )
+
+            input_gpu = (
+                torch.randn(*shapes, dtype=torch.float)
+                .cuda(rank)
+                .to(memory_format=memory_format)
+            )
+            target_gpu = (
+                torch.randn(*shapes, dtype=torch.float)
+                .cuda(rank)
+                .to(memory_format=memory_format)
+            )
+            loss = nn.MSELoss()
+
+            # check two model parameters over 5 iterations
+            self._test_DDP_niter(
+                model_gpu,
+                model_DDP,
+                input_gpu,
+                target_gpu,
+                loss,
+                local_bs,
+                rank,
+                global_bs,
+                True,
+                bs_offset,
+                dist.get_world_size(),
+                memory_format=memory_format,
+            )
+            self._barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_no_gpu
+        def test_DistributedDataParallel_SyncBatchNorm(self):
+            _group, _group_id, rank = self._init_global_test()
+            world_size = dist.get_world_size()
+            # DDP does not support replicating BN layers within a process, hence
+            # testing with one module replica per process
+            gpus = [rank]
+
+            local_bs = 2
+            bs_offset = int(rank * 2)
+            global_bs = int(world_size * 2)
+
+            self._test_DistributedDataParallel_SyncBatchNorm(
+                gpu_subset=gpus,
+                rank=rank,
+                local_bs=local_bs,
+                global_bs=global_bs,
+                offset=bs_offset,
+            )
+
+            # test output_device
+            self._test_DistributedDataParallel_SyncBatchNorm(
+                gpu_subset=gpus,
+                rank=rank,
+                local_bs=local_bs,
+                global_bs=global_bs,
+                offset=bs_offset,
+                output_device=torch.device("cuda"),
+            )
+
+            # test device_ids
+            gpus = [torch.device("cuda:" + str(i)) for i in gpus]
+            self._test_DistributedDataParallel_SyncBatchNorm(
+                gpu_subset=gpus,
+                rank=rank,
+                local_bs=local_bs,
+                global_bs=global_bs,
+                offset=bs_offset,
+                output_device=torch.device("cuda"),
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_no_gpu
+        def test_DistributedDataParallel_SyncBatchNorm_No_Affine(self):
+            _group, _group_id, rank = self._init_global_test()
+            world_size = dist.get_world_size()
+            # DDP does not support replicating BN layers within a process, hence
+            # testing with one module replica per process
+            gpus = [rank]
+
+            local_bs = 2
+            bs_offset = int(rank * 2)
+            global_bs = int(world_size * 2)
+
+            self._test_DistributedDataParallel_SyncBatchNorm(
+                gpu_subset=gpus,
+                rank=rank,
+                local_bs=local_bs,
+                global_bs=global_bs,
+                offset=bs_offset,
+                affine=False,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_no_gpu
+        def test_DistributedDataParallel_SyncBatchNorm_2D_Input(self):
+            _group, _group_id, rank = self._init_global_test()
+            # DDP does not support replicating BN layers within a process, hence
+            # testing with one module replica per process
+            gpus = [rank]
+
+            model = nn.BatchNorm1d(2)
+
+            # single gpu training setup
+            model_gpu = copy.deepcopy(model)
+            model_gpu.cuda(gpus[0])
+
+            # DDP training setup
+            model_DDP = nn.SyncBatchNorm.convert_sync_batchnorm(copy.deepcopy(model))
+            model_DDP.cuda(gpus[0])
+            model_DDP = nn.parallel.DistributedDataParallel(model_DDP, device_ids=gpus)
+
+            local_bs = len(gpus) * 2
+            global_bs = dist.get_world_size() * local_bs
+            input_cpu = torch.randn(global_bs, 2)
+            target = torch.randn(global_bs, 2)
+            loss = nn.MSELoss()
+
+            # Disable cudnn so that SyncBatchNorm goes through the native_batch_norm
+            # kernel, which avoids the numerical issue created by the divergent code
+            # path.
+            with torch.backends.cudnn.flags(False):
+                # check two model parameters over 5 iterations
+                self._test_DDP_niter(
+                    model_gpu,
+                    model_DDP,
+                    input_cpu.cuda(gpus[0]),
+                    target.cuda(gpus[0]),
+                    loss,
+                    local_bs,
+                    rank,
+                    global_bs,
+                    True,
+                )
+                self._barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_no_gpu
+        @require_world_size(2)
+        def test_DistributedDataParallel_SyncBatchNorm_Single_Input_Per_Process(self):
+            _group, _group_id, rank = self._init_global_test()
+            # DDP does not support replicating BN layers within a process, hence
+            # testing with one module replica per process
+            gpus = [rank]
+
+            model = nn.BatchNorm1d(2)
+
+            # single gpu training setup
+            model_gpu = copy.deepcopy(model)
+            model_gpu.cuda(gpus[0])
+
+            # DDP training setup
+            model_DDP = nn.SyncBatchNorm.convert_sync_batchnorm(copy.deepcopy(model))
+            model_DDP.cuda(gpus[0])
+            model_DDP = nn.parallel.DistributedDataParallel(model_DDP, device_ids=gpus)
+
+            local_bs = 1
+            global_bs = dist.get_world_size()
+            input_cpu = torch.randn(global_bs, 2)
+            target = torch.randn(global_bs, 2)
+            loss = nn.MSELoss()
+
+            # Disable cudnn so that SyncBatchNorm goes through the native_batch_norm
+            # kernel, which avoids the numerical issue created by the divergent code
+            # path.
+            with torch.backends.cudnn.flags(False):
+                # check two model parameters over 5 iterations
+                self._test_DDP_niter(
+                    model_gpu,
+                    model_DDP,
+                    input_cpu.cuda(gpus[0]),
+                    target.cuda(gpus[0]),
+                    loss,
+                    local_bs,
+                    rank,
+                    global_bs,
+                    True,
+                )
+                self._barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_no_gpu
+        def test_DistributedDataParallel_SyncBatchNorm_Diff_Input_Sizes_Running_Value(
+            self,
+        ):
+            _group, _group_id, rank = self._init_global_test()
+            model = nn.parallel.DistributedDataParallel(
+                ONLY_SBN_NET.cuda(rank), device_ids=[rank]
+            )
+
+            input_var = []
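+            # Build, for each rank i, an input of length 10 ** (i + 1) with a
+            # rank-dependent scale; each rank uses its own entry, so SyncBatchNorm's
+            # running stats must aggregate uneven per-rank statistics.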
+            for i in range(dist.get_world_size()):
+                input_var_rank = torch.cat(
+                    [
+                        torch.ones(2, 1, 10 ** (i + 1)) * (0.1 ** (i - 1)),
+                        torch.ones(2, 1, 10 ** (i + 1)) * (0.3 ** (i - 1)),
+                    ],
+                    dim=1,
+                )
+                input_var.append(input_var_rank)
+
+            all_input_var = torch.cat(
+                [
+                    x.permute(1, 0, 2).contiguous().view(ONLY_SBN_NET.num_features, -1)
+                    for x in input_var
+                ],
+                dim=1,
+            ).cuda(rank)
+
+            for i in range(100):
+                y = model(input_var[rank].cuda(rank))
+                y.mean().backward()
+
+            running_mean, running_var = (
+                model.module.running_mean,
+                model.module.running_var,
+            )
+            torch.testing.assert_close(running_mean, all_input_var.mean(1))
+            torch.testing.assert_close(running_var, all_input_var.var(1))
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_no_gpu
+        def test_DistributedDataParallel_SyncBatchNorm_Diff_Input_Sizes_gradient(self):
+            _group, _group_id, rank = self._init_global_test()
+            # only do single GPU per process
+            gpus = [rank]
+
+            # cpu training setup
+            num_processes = dist.get_world_size()
+            local_bs = rank + 2
+            bs_offset = int((rank + 3) * rank / 2)
+            global_bs = int((num_processes + 3) * num_processes / 2)
+
+            self._test_DistributedDataParallel_SyncBatchNorm(
+                gpu_subset=gpus,
+                rank=rank,
+                local_bs=local_bs,
+                global_bs=global_bs,
+                offset=bs_offset,
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_no_gpu
+        def test_DistributedDataParallel_SyncBatchNorm_half(self):
+            _group, _group_id, rank = self._init_global_test()
+
+            model = copy.deepcopy(BN_NET)
+            model = model.half()
+            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+            model = nn.parallel.DistributedDataParallel(
+                model.cuda(rank), device_ids=[rank]
+            )
+            inp = torch.randn(2, 2, dtype=torch.float16, device=torch.device(rank))
+            # Check that forward/backward do not error with dtype mismatch
+            out = model(inp)
+            self.assertEqual(out.dtype, torch.float16)
+            out.sum().backward()
+            for param in model.parameters():
+                self.assertEqual(param.grad.dtype, torch.float16)
+
+        def _test_ddp_logging_data(self, is_gpu):
+            rank = dist.get_rank()
+            model_DDP = copy.deepcopy(DDP_NET)
+            if is_gpu:
+                model_DDP = nn.parallel.DistributedDataParallel(
+                    model_DDP.cuda(rank), device_ids=[rank]
+                )
+            else:
+                model_DDP = nn.parallel.DistributedDataParallel(model_DDP)
+
+            # dummy data initialization
+            local_bs = 2
+            batch_size, input, target, loss = self._prepare_dummy_data(local_bs)
+            if is_gpu:
+                input = input.cuda(rank)
+                target = target.cuda(rank)
+
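+            # Sample runtime stats every 2 iterations; the assertions below expect
+            # stats to be collected for the first 10 iterations and on every sampled
+            # iteration after that.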
+            model_DDP._set_ddp_runtime_logging_sample_rate(2)
+
+            for idx in range(20):
+                offset = rank * local_bs
+
+                # DDP training, DDP scatters subsets of input to nodes/GPUs
+                self._test_DDP_helper(
+                    model_DDP,
+                    input[offset : offset + local_bs],
+                    target[offset : offset + local_bs],
+                    loss,
+                    1,
+                )
+
+                self._model_step_with_zero_grad(model_DDP)
+
+                # Verify DDP logging data is sampled as expected.
+                # Runtime stats are collected for the first 10 iterations and,
+                # after that, only on sampled iterations; for those iterations
+                # the runtime stats for this idx-th iteration will not be zero.
+                ddp_logging_data = model_DDP._get_ddp_logging_data()
+                if idx > 0 and (idx < 10 or idx % 2 == 0):
+                    self.assertGreaterEqual(
+                        ddp_logging_data.get("forward_compute_time"), 1
+                    )
+                    self.assertGreaterEqual(
+                        ddp_logging_data.get("backward_compute_time"), 1
+                    )
+                    self.assertGreaterEqual(
+                        ddp_logging_data.get("backward_comm_time"), 1
+                    )
+                    self.assertGreaterEqual(
+                        ddp_logging_data.get("backward_compute_time"),
+                        ddp_logging_data.get("backward_compute_comm_overlap_time"),
+                    )
+                    self.assertGreaterEqual(
+                        ddp_logging_data.get("backward_comm_time"),
+                        ddp_logging_data.get("backward_compute_comm_overlap_time"),
+                    )
+                    self.assertEqual(ddp_logging_data.get("iteration"), idx)
+                elif idx > 0:
+                    # If the idx-th iteration is not sampled for runtime stats,
+                    # ddp_logging_data.iteration will not be updated to the
+                    # current iteration.
+                    self.assertNotEqual(ddp_logging_data.get("iteration"), idx)
+
+                # Shuffle the input so that DDP input is different
+                input = input[torch.randperm(batch_size)]
+
+            return model_DDP
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "nccl does not support DDP on CPU models"
+        )
+        def test_ddp_logging_data_cpu(self):
+            def parse_env(var):
+                return os.environ.get(var, "N/A")
+
+            dist.set_debug_level(dist.DebugLevel.INFO)
+            _, group_id, _ = self._init_global_test()
+            model_DDP = self._test_ddp_logging_data(is_gpu=False)
+
+            ddp_logging_data = model_DDP._get_ddp_logging_data()
+            self.assertEqual(ddp_logging_data.get("world_size"), dist.get_world_size())
+            self.assertEqual(ddp_logging_data.get("rank"), dist.get_rank())
+            self.assertEqual(ddp_logging_data.get("module_name"), "Net")
+            self.assertEqual(ddp_logging_data.get("device_ids"), "")
+            # output_device defaults to -1 if it is not set; e.g. the
+            # output_device for CPU training is -1.
+            self.assertEqual(ddp_logging_data.get("output_device"), -1)
+            self.assertEqual(ddp_logging_data.get("broadcast_buffers"), 1)
+            self.assertEqual(ddp_logging_data.get("bucket_cap_bytes"), 25 * 1024 * 1024)
+            self.assertEqual(ddp_logging_data.get("find_unused_parameters"), 0)
+            self.assertEqual(ddp_logging_data.get("gradient_as_bucket_view"), 0)
+            self.assertEqual(
+                ddp_logging_data.get("backend_name"), dist.get_backend(group_id)
+            )
+            self.assertEqual(ddp_logging_data.get("iteration"), 18)
+            params = list(model_DDP.parameters())
+            num_params = 0
+            param_size = 0
+            params = list(filter(lambda parameter: parameter.requires_grad, params))
+            for p in params:
+                num_params += 1
+                param_size += p.numel() * p.element_size()
+            self.assertEqual(ddp_logging_data.get("dtypes"), "float")
+            self.assertEqual(
+                ddp_logging_data.get("total_parameter_size_bytes"), param_size
+            )
+            self.assertEqual(ddp_logging_data.get("num_parameter_tensors"), num_params)
+            self.assertEqual(ddp_logging_data.get("bucket_sizes"), str(param_size))
+            self.assertEqual(
+                ddp_logging_data.get("master_port"), parse_env("MASTER_PORT")
+            )
+            self.assertEqual(
+                ddp_logging_data.get("master_addr"), parse_env("MASTER_ADDR")
+            )
+            self.assertEqual(
+                ddp_logging_data.get("torch_distributed_debug"),
+                parse_env("TORCH_DISTRIBUTED_DEBUG"),
+            )
+            self.assertEqual(
+                ddp_logging_data.get("cuda_visible_devices"),
+                parse_env("CUDA_VISIBLE_DEVICES"),
+            )
+            if ddp_logging_data.get("backend_name") == "gloo":
+                self.assertEqual(
+                    ddp_logging_data.get("gloo_socket_ifname"),
+                    parse_env("GLOO_SOCKET_IFNAME"),
+                )
+                self.assertEqual(
+                    ddp_logging_data.get("gloo_device_transport"),
+                    parse_env("GLOO_DEVICE_TRANSPORT"),
+                )
+                default_gloo_threads = 2
+                self.assertEqual(
+                    ddp_logging_data.get("gloo_num_threads"),
+                    default_gloo_threads,
+                )
+
+            self.assertEqual(ddp_logging_data.get("nccl_socket_ifname"), None)
+            self.assertEqual(ddp_logging_data.get("nccl_blocking_wait"), None)
+            self.assertEqual(ddp_logging_data.get("nccl_async_error_handling"), None)
+            self.assertEqual(ddp_logging_data.get("nccl_debug"), None)
+            self.assertEqual(ddp_logging_data.get("nccl_nthreads"), None)
+            self.assertEqual(ddp_logging_data.get("nccl_ib_timeout"), None)
+            # Test runtime logging fields.
+            # Note: DETAIL debug mode logs DDP logging data to stdout and
+            # thus accesses the underlying std::map, which inserts a default
+            # value for a key if it did not exist.
+            self.assertEqual(ddp_logging_data.get("unused_parameter_size", 0), 0)
+            self.assertEqual(ddp_logging_data.get("has_rebuilt_buckets"), 1)
+            self.assertEqual(
+                ddp_logging_data.get("rebuilt_bucket_sizes"), str(param_size)
+            )
+            grad_ready_order = ddp_logging_data.get(
+                "prev_iteration_grad_ready_order_indices"
+            )
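+            # Gradients are reduced in roughly the reverse order of parameter
+            # registration (they become ready from the last layer backwards),
+            # so the expected ready order below is the reversed index list.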
+            expected_order = list(reversed([str(x) for x in range(3)]))
+            self.assertEqual(grad_ready_order, ", ".join(expected_order))
+            bucket_indices = ddp_logging_data.get("rebuilt_per_bucket_param_indices")
+            self.assertEqual(bucket_indices, " ".join(expected_order))
+            # It is hard to assert exact latencies, but we can check that each
+            # latency is a valid value within the expected range.
+            self.assertGreaterEqual(ddp_logging_data.get("avg_forward_compute_time"), 1)
+            self.assertGreaterEqual(
+                ddp_logging_data.get("avg_backward_compute_time"), 1
+            )
+            self.assertGreaterEqual(ddp_logging_data.get("avg_backward_comm_time"), 1)
+            self.assertGreaterEqual(
+                ddp_logging_data.get("avg_backward_compute_time"),
+                ddp_logging_data.get("avg_backward_compute_comm_overlap_time"),
+            )
+            self.assertGreaterEqual(
+                ddp_logging_data.get("avg_backward_comm_time"),
+                ddp_logging_data.get("avg_backward_compute_comm_overlap_time"),
+            )
+            # Test host-side times are roughly in the order that we expect
+            fwd_host_side_time = ddp_logging_data.get("forward_compute_time_start")
+            bwd_comp_start_host_side_time = ddp_logging_data.get(
+                "backward_compute_time_start"
+            )
+            bwd_comp_end_host_side_time = ddp_logging_data.get(
+                "backward_compute_time_end"
+            )
+            bwd_comm_start_host_side_time = ddp_logging_data.get(
+                "backward_comm_time_start"
+            )
+            bwd_comm_end_host_side_time = ddp_logging_data.get("backward_comm_time_end")
+            self.assertGreaterEqual(
+                bwd_comm_end_host_side_time, bwd_comm_start_host_side_time
+            )
+            self.assertGreaterEqual(
+                bwd_comm_start_host_side_time, bwd_comp_start_host_side_time
+            )
+            self.assertGreaterEqual(
+                bwd_comp_end_host_side_time, bwd_comp_start_host_side_time
+            )
+            self.assertGreaterEqual(bwd_comp_start_host_side_time, fwd_host_side_time)
+
+            # test larger net with mixed data types, verify multiple bucket sizes
+            model = LargeNet()
+            model.float()
+            model.fc1.double()
+            model_DDP = nn.parallel.DistributedDataParallel(model, bucket_cap_mb=1.5)
+            ddp_logging_data = model_DDP._get_ddp_logging_data()
+            params = list(model_DDP.parameters())
+            self.assertEqual(
+                ddp_logging_data.get("bucket_cap_bytes"), int(1.5 * 1024 * 1024)
+            )
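+            # With a 1.5 MiB bucket cap and mixed parameter dtypes, the double
+            # fc1 parameters and the remaining float parameters are expected to
+            # land in separate buckets, reported in DDP's bucket order (here the
+            # reverse of the parameter order).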
+            bucket_sizes = [
+                params[1].numel() * params[1].element_size(),
+                params[0].numel() * params[0].element_size(),
+            ]
+            self.assertEqual(
+                ddp_logging_data.get("bucket_sizes"),
+                ", ".join(str(x) for x in bucket_sizes),
+            )
+            self.assertEqual(ddp_logging_data.get("dtypes"), "double, float")
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_no_gpu
+        def test_ddp_logging_data_gpu(self):
+            _group, _group_id, rank = self._init_global_test()
+            model_DDP = self._test_ddp_logging_data(is_gpu=True)
+            ddp_logging_data = model_DDP._get_ddp_logging_data()
+            self.assertEqual(ddp_logging_data.get("device_ids"), str(rank))
+            self.assertEqual(ddp_logging_data.get("output_device"), rank)
+            grad_ready_order = ddp_logging_data.get(
+                "prev_iteration_grad_ready_order_indices"
+            )
+            expected_order = list(reversed([str(x) for x in range(3)]))
+            self.assertEqual(grad_ready_order, ", ".join(expected_order))
+            bucket_indices = ddp_logging_data.get("rebuilt_per_bucket_param_indices")
+            self.assertEqual(bucket_indices, " ".join(expected_order))
+            # test runtime logging fields
+            # It is hard to assert exact latencies, but we can check that each
+            # latency is a valid value within the expected range.
+            self.assertGreaterEqual(ddp_logging_data.get("avg_forward_compute_time"), 1)
+            self.assertGreaterEqual(
+                ddp_logging_data.get("avg_backward_compute_comm_overlap_time"), 1
+            )
+            self.assertGreaterEqual(
+                ddp_logging_data.get("avg_backward_compute_time"),
+                ddp_logging_data.get("avg_backward_compute_comm_overlap_time"),
+            )
+            self.assertGreaterEqual(
+                ddp_logging_data.get("avg_backward_comm_time"),
+                ddp_logging_data.get("avg_backward_compute_comm_overlap_time"),
+            )
+            # Test host-side times are roughly in the order that we expect
+            fwd_host_side_time = ddp_logging_data.get("forward_compute_time_start")
+            bwd_comp_start_host_side_time = ddp_logging_data.get(
+                "backward_compute_time_start"
+            )
+            bwd_comp_end_host_side_time = ddp_logging_data.get(
+                "backward_compute_time_end"
+            )
+            bwd_comm_start_host_side_time = ddp_logging_data.get(
+                "backward_comm_time_start"
+            )
+            bwd_comm_end_host_side_time = ddp_logging_data.get("backward_comm_time_end")
+            self.assertGreaterEqual(
+                bwd_comm_end_host_side_time, bwd_comm_start_host_side_time
+            )
+            self.assertGreaterEqual(
+                bwd_comm_start_host_side_time, bwd_comp_start_host_side_time
+            )
+            self.assertGreaterEqual(
+                bwd_comp_end_host_side_time, bwd_comp_start_host_side_time
+            )
+            self.assertGreaterEqual(bwd_comp_start_host_side_time, fwd_host_side_time)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "nccl", "nccl does not support DDP on CPU models"
+        )
+        def test_static_graph_api_cpu(self):
+            model_DDP = nn.parallel.DistributedDataParallel(DDP_NET)
+            expected_err = "should be called before training loop starts"
+            with self.assertRaisesRegex(RuntimeError, expected_err):
+                local_bs = 2
+                _batch_size, input, target, loss = self._prepare_dummy_data(local_bs)
+                offset = dist.get_rank() * local_bs
+
+                # DDP training; each rank trains on its own subset of the input
+                self._test_DDP_helper(
+                    model_DDP,
+                    input[offset : offset + local_bs],
+                    target[offset : offset + local_bs],
+                    loss,
+                    1,
+                )
+                model_DDP._set_static_graph()
+
+            # Verify error was logged in ddp_logging_data.
+            verify_ddp_error_logged(model_DDP, expected_err)
+
+        @skipIfNoTorchVision
+        def test_SyncBatchNorm_process_group(self):
+            # When using `convert_sync_batchnorm` to convert an `nn.Module`,
+            # the `process_group` needs to be passed recursively into the module when a `SyncBatchNorm`
+            # is nested in a sub-module or sub-sub-module (e.g. resnet50 in torchvision.models).
+
+            process_ids = 0
+            process_group = torch.distributed.new_group([process_ids])
+            res50_model = torchvision.models.resnet50()
+            res50_model_sync = nn.SyncBatchNorm.convert_sync_batchnorm(
+                copy.deepcopy(res50_model), process_group
+            )
+            process_group_sync = res50_model_sync.layer1[0].bn1.process_group
+            self.assertEqual(process_group_sync, process_group)
+
+        def _run_reduction_test(
+            self, tensor, expected_tensor, op, reduction_fn=dist.all_reduce, dst=None
+        ):
+            if reduction_fn != dist.all_reduce and dst is None:
+                raise ValueError(f"Reduction fn {reduction_fn} must specify dst!")
+            if dst is not None:
+                reduction_fn(tensor, dst, op)
+                # Only the destination rank's tensor is expected to hold the final result.
+                if dist.get_rank() == dst:
+                    self.assertEqual(tensor, expected_tensor)
+            else:
+                reduction_fn(tensor, op)
+                self.assertEqual(tensor, expected_tensor)
+
+        @require_backend_is_available({"nccl"})
+        @skip_if_lt_x_gpu(2)
+        def test_nccl_backend_bool_allreduce(self):
+            torch.cuda.set_device(self.rank)
+            # Run all_reduce with PRODUCT
+            element = self.rank % 2 == 0
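+            # Ranks alternate between contributing True and False, so PRODUCT
+            # and MIN are expected to reduce to False, while SUM and MAX
+            # (tested below) reduce to True as long as at least one rank
+            # contributes True.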
+            for op in [dist.ReduceOp.PRODUCT, dist.ReduceOp.MIN]:
+                input_tensor = torch.tensor([element, element]).to(self.rank)
+                self._run_reduction_test(
+                    input_tensor, torch.tensor([False, False]).to(self.rank), op
+                )
+                # Ensure that when all ranks contribute True (cast to 1), the
+                # reduction result is correct.
+                input_tensor = torch.tensor([True, True]).to(self.rank)
+                expected_tensor = input_tensor.clone()
+                self._run_reduction_test(input_tensor, expected_tensor, op)
+
+            # Run all_reduce with SUM
+            for op in [dist.ReduceOp.SUM, dist.ReduceOp.MAX]:
+                input_tensor = torch.tensor([element, element]).to(self.rank)
+                self._run_reduction_test(
+                    input_tensor, torch.tensor([True, True]).to(self.rank), op
+                )
+            # TODO: NCCL backend does not work correctly for bitwise reduction ops
+            # (see https://github.com/pytorch/pytorch/issues/41362). Add tests for
+            # these once it is supported.
+
+        @require_backend_is_available({"nccl"})
+        @skip_if_lt_x_gpu(2)
+        def test_nccl_backend_bool_allgather(self):
+            torch.cuda.set_device(self.rank)
+            inp = {0: [True, True], 1: [False, True]}
+            input_tensor = torch.tensor(inp[self.rank % 2]).to(self.rank)
+            # Preserve a copy of the tensor to compare against after allgather.
+            input_tensor_copy = input_tensor.clone()
+            tensor_list = [
+                torch.tensor([False, False]).to(self.rank)
+                for _ in range(dist.get_world_size())
+            ]
+            dist.all_gather(tensor_list, input_tensor)
+
+            self.assertEqual(len(tensor_list), dist.get_world_size())
+            for i, t in enumerate(tensor_list):
+                expected = torch.tensor(inp[i % 2]).to(self.rank)
+                self.assertEqual(t, expected)
+            # Ensure that the input tensor is not modified, since this collective
+            # does not modify its input.
+            self.assertEqual(input_tensor_copy, input_tensor)
+
+        @require_backend_is_available({"nccl"})
+        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
+        def test_nccl_backend_bool_reduce(self):
+            torch.cuda.set_device(self.rank)
+            inp = {0: [True, True], 1: [False, False]}
+            # Run reduce() with product op
+            for op in [dist.ReduceOp.PRODUCT, dist.ReduceOp.MIN]:
+                # make sure rank 0 gets False if WORLD_SIZE=1 to match expected tensor
+                input_tensor = torch.tensor(inp[(self.rank + 1) % 2]).to(self.rank)
+                expected = torch.tensor([False, False]).to(self.rank)
+                self._run_reduction_test(input_tensor, expected, op, dist.reduce, dst=0)
+                # Ensure that when all ranks contribute True (cast to 1), the
+                # reduction result is correct.
+                input_tensor = torch.tensor([True, True]).to(self.rank)
+                expected_tensor = input_tensor.clone()
+                self._run_reduction_test(
+                    input_tensor, expected_tensor, op, dist.reduce, dst=0
+                )
+
+            for op in [dist.ReduceOp.SUM, dist.ReduceOp.MAX]:
+                input_tensor = torch.tensor(inp[self.rank % 2]).to(self.rank)
+                expected = (
+                    torch.tensor([True, True]).to(self.rank)
+                    if self.rank == 0
+                    else input_tensor.clone()
+                )
+                self._run_reduction_test(input_tensor, expected, op, dist.reduce, dst=0)
+
+        @require_backend_is_available({"nccl"})
+        @skip_if_lt_x_gpu(2)
+        def test_nccl_backend_bool_broadcast(self):
+            tensor_size = 10
+            bcast_tensor = torch.tensor(
+                [
+                    (random.random() < 0.5 if self.rank == 0 else False)
+                    for _ in range(tensor_size)
+                ]
+            ).to(self.rank)
+            dist.broadcast(bcast_tensor, src=0)
+            # Now allgather and ensure the tensors are equal.
+            tensor_list = [
+                torch.tensor([False for _ in range(tensor_size)]).to(self.rank)
+                for _ in range(dist.get_world_size())
+            ]
+            dist.all_gather(tensor_list, bcast_tensor)
+            expected = tensor_list[0]
+            for tensor in tensor_list[1:]:
+                self.assertEqual(tensor, expected)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
+        def test_DistributedSampler_padding(self):
+            # Tests padding of distributed sampler.
+            world_size = dist.get_world_size()
+
+            # Simulates a typical dataset size that is not evenly divisible by the world size.
+            dataset_size = 100 + world_size + 1
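+            # e.g. with a world size of 2 this gives 103 samples, which does
+            # not divide evenly across ranks.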
+            dataset = [torch.ones(1).to(self.rank) * i for i in range(dataset_size)]
+
+            # Simulates a tiny dataset size that is smaller than the world size.
+            dataset_tiny_size = max(world_size // 2 - 1, 1)
+            dataset_tiny = [
+                torch.ones(1).to(self.rank) * i for i in range(dataset_tiny_size)
+            ]
+
+            # Specifying drop_last=True will cause the tail of the data to be dropped.
+            dist_sampler = DistributedSampler(dataset=dataset, drop_last=True)
+            local_num_samples, local_dataset_size = (
+                dist_sampler.num_samples,
+                dist_sampler.total_size,
+            )
+            # The effective dataset size should be the greatest integer that is
+            # <= dataset_size and divisible by the world_size. This ensures each
+            # rank processes the same number of samples.
+            effective_dataset_size = (
+                math.ceil((dataset_size - world_size) / world_size)
+                if dataset_size % world_size != 0
+                else dataset_size / world_size
+            )
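+            # For example, with 103 samples and a world size of 2, each rank
+            # keeps floor(103 / 2) = 51 samples and one sample is dropped.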
+            self.assertEqual(local_num_samples, effective_dataset_size)
+            self.assertEqual(local_dataset_size, local_num_samples * world_size)
+            indices_list = list(iter(dist_sampler))
+            self.assertEqual(len(indices_list), local_num_samples)
+
+            def validate_global_samples(local_num_samples):
+                # Ensure that each rank processes the same number of samples.
+                world_samples = [
+                    torch.LongTensor([0]).to(self.rank) for _ in range(world_size)
+                ]
+                dist.all_gather(
+                    world_samples, torch.tensor([local_num_samples]).to(self.rank)
+                )
+                world_samples = [sample.item() for sample in world_samples]
+                self.assertEqual(len(set(world_samples)), 1)
+
+            validate_global_samples(local_num_samples)
+
+            # drop_last=False is the default and will add additional indices to be sampled,
+            # increasing the effective dataset size.
+            dist_sampler_added_samples = DistributedSampler(dataset=dataset)
+            local_num_samples, local_dataset_size = (
+                dist_sampler_added_samples.num_samples,
+                dist_sampler_added_samples.total_size,
+            )
+            # The effective dataset size is the smallest integer that is >= dataset_size
+            # and divisible by the world size.
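+            # e.g. 103 samples with a world size of 2 are padded to 104 in
+            # total, i.e. ceil(103 / 2) = 52 samples per rank.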
+            self.assertEqual(local_num_samples, math.ceil(dataset_size / world_size))
+            self.assertEqual(local_dataset_size, local_num_samples * world_size)
+            indices_list = list(iter(dist_sampler_added_samples))
+            self.assertEqual(len(indices_list), local_num_samples)
+
+            # Ensure that each rank processes the same number of samples.
+            validate_global_samples(local_num_samples)
+
+            # Ensure additional samples are padded in even when an
+            # extremely small dataset is given.
+            dist_sampler_added_samples_tiny = DistributedSampler(dataset=dataset_tiny)
+            local_num_samples, local_dataset_size = (
+                dist_sampler_added_samples_tiny.num_samples,
+                dist_sampler_added_samples_tiny.total_size,
+            )
+            self.assertEqual(
+                local_num_samples, math.ceil(dataset_tiny_size / world_size)
+            )
+            self.assertEqual(local_dataset_size, local_num_samples * world_size)
+            indices_list = list(iter(dist_sampler_added_samples_tiny))
+            self.assertEqual(len(indices_list), local_num_samples)
+            validate_global_samples(local_num_samples)
+
+        def _test_allgather_object(self, subgroup=None):
+            # Only set device for NCCL backend since it must use GPUs.
+
+            gather_objects = COLLECTIVES_OBJECT_TEST_LIST.copy()
+
+            backend = os.environ["BACKEND"]
+            if backend == "nccl":
+                # Case where rank != GPU device.
+                next_rank = (self.rank + 1) % int(self.world_size)
+                torch.cuda.set_device(next_rank)
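+                # Using a device other than self.rank is intended to exercise
+                # the object collectives' handling of the current CUDA device
+                # rather than assuming device == rank.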
+
+            # If GPU test, add object with GPU tensor
+            if backend == "nccl":
+                gather_objects.append(Foo(torch.randn(3, 3, device=0)))
+
+            output_gathered = [None for _ in range(dist.get_world_size())]
+            dist.all_gather_object(
+                output_gathered,
+                gather_objects[self.rank % len(gather_objects)],
+                group=subgroup,
+            )
+
+            for i, val in enumerate(output_gathered):
+                expected = gather_objects[i % len(gather_objects)]
+                self.assertEqual(val, expected)
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @require_n_gpus_for_nccl_backend(
+            int(os.environ["WORLD_SIZE"]), os.environ["BACKEND"]
+        )
+        @with_dist_debug_levels(levels=["OFF", "INFO", "DETAIL"])
+        def test_all_gather_object_default_pg(self):
+            return self._test_allgather_object()
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @require_n_gpus_for_nccl_backend(
+            int(os.environ["WORLD_SIZE"]), os.environ["BACKEND"]
+        )
+        @with_dist_debug_levels(levels=["DETAIL", "OFF", "INFO"])
+        def test_all_gather_object_subgroup(self):
+            default = _get_default_group()
+            backend = dist.get_backend(default)
+            subgroup = dist.new_group(backend=backend)
+            return self._test_allgather_object(subgroup=subgroup)
+
+        def _test_gather_object(self, pg=None):
+            # Ensure stateful objects can be gathered
+            gather_objects = COLLECTIVES_OBJECT_TEST_LIST.copy()
+            my_rank = dist.get_rank(pg)
+
+            backend = os.environ["BACKEND"]
+            if backend == "nccl":
+                # Case where rank != GPU device.
+                next_rank = (self.rank + 1) % int(self.world_size)
+                torch.cuda.set_device(next_rank)
+
+            # If GPU test, add object with GPU tensor
+            if backend == "nccl":
+                gather_objects.append(Foo(torch.randn(3, 3, device=my_rank)))
+
+            output_gathered = [None for _ in range(dist.get_world_size(pg))]
+            gather_on_rank = 0
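+            # Only the destination rank receives the gathered objects; all other
+            # ranks should see their output list left as all None, which the
+            # branches below verify.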
+            dist.gather_object(
+                gather_objects[self.rank % len(gather_objects)],
+                object_gather_list=output_gathered
+                if my_rank == gather_on_rank
+                else None,
+                dst=gather_on_rank,
+                group=pg,
+            )
+            if my_rank != gather_on_rank:
+                self.assertEqual(
+                    output_gathered, [None for _ in range(dist.get_world_size())]
+                )
+            else:
+                for i, val in enumerate(output_gathered):
+                    expected = gather_objects[i % len(gather_objects)]
+                    self.assertEqual(val, expected)
+
+            # Validate errors when objects can't be pickled.
+            class Bar:
+                pass
+
+            b = Bar()
+            gather_objects = [b for _ in range(dist.get_world_size())]
+            with self.assertRaises(AttributeError):
+                dist.all_gather_object(
+                    [None for _ in range(dist.get_world_size())],
+                    gather_objects[self.rank],
+                    group=pg,
+                )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
+        )
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @with_dist_debug_levels(levels=["DETAIL", "OFF", "INFO"])
+        def test_gather_object(self):
+            return self._test_gather_object()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
+        )
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @with_dist_debug_levels(levels=["DETAIL", "OFF", "INFO"])
+        def test_gather_object_subgroup(self):
+            default = _get_default_group()
+            backend = dist.get_backend(default)
+            subgroup = dist.new_group(backend=backend)
+            return self._test_gather_object(subgroup)
+
+        def validate_net_equivalence(self, net):
+            # Helper to validate synchronization of nets across ranks.
+            net_module_states = list(net.module.state_dict().values())
+            # Check that all tensors in module's state_dict() are equal.
+            for t in net_module_states:
+                tensor_list = [
+                    torch.zeros_like(t) for _ in range(dist.get_world_size())
+                ]
+                dist.all_gather(tensor_list, t)
+                for tensor in tensor_list:
+                    self.assertEqual(tensor, t)
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_sync_module_states(self):
+            # Test that after calling _sync_module_states, models across ranks
+            # are the same and are equal to the model on the input rank.
+            dim = 2
+            rank = self.rank
+            rank_to_broadcast = 1
+            # Seed to ensure that ranks are initialized with different initial models.
+            torch.manual_seed(rank)
+            model = nn.Linear(dim, dim, bias=False)
+            net = torch.nn.parallel.DistributedDataParallel(
+                model.cuda(rank), device_ids=[self.rank], bucket_cap_mb=1
+            )
+            new_model = nn.Linear(dim, dim, bias=False).cuda(rank)
+            net.module = copy.deepcopy(new_model)
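+            # Replacing net.module intentionally desynchronizes the replicas so
+            # that _sync_module_states below has real work to do.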
+            # Assert params are different
+            net_module_states = list(net.module.state_dict().values())
+            for t in net_module_states:
+                tensor_list = [
+                    torch.zeros_like(t) for _ in range(dist.get_world_size())
+                ]
+                dist.all_gather(tensor_list, t)
+                for i, tensor in enumerate(tensor_list):
+                    if i == rank:
+                        self.assertEqual(t, tensor)
+                    else:
+                        # tensor from another rank should be different.
+                        self.assertNotEqual(t, tensor)
+
+            _sync_module_states(
+                module=net.module,
+                process_group=net.process_group,
+                broadcast_bucket_size=net.broadcast_bucket_size,
+                src=rank_to_broadcast,
+                params_and_buffers_to_ignore=net.parameters_to_ignore,
+            )
+            # Now all model params should be the same.
+            self.validate_net_equivalence(net)
+            # Since the network params were broadcast from rank_to_broadcast, validate that
+            # they are the same as new_model on rank_to_broadcast.
+            if rank == rank_to_broadcast:
+                expected_states = new_model.state_dict().values()
+                for t, expected in zip(net_module_states, expected_states):
+                    self.assertEqual(t, expected)
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_grad_div_uneven_inputs(self):
+            # Test gradient division during training with join() API. If
+            # divide_by_initial_world_size=False, we scale by the effective world
+            # size when allreducing grads.
+            dim = 5
+            batch = 1
+            grad_scale = 50
+            rank = self.rank
+            model = nn.Linear(dim, dim, bias=False)
+            inp = torch.ones(batch, dim, device=self.rank) * grad_scale
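+            # For a bias-free Linear layer with out = inp @ W.T and
+            # loss = out.sum(), dLoss/dW sums the outer products of ones with
+            # each input row; with batch == 1 and inp filled with grad_scale,
+            # that is a (dim, dim) matrix filled with grad_scale.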
+            net = torch.nn.parallel.DistributedDataParallel(
+                model.cuda(rank), device_ids=[self.rank], bucket_cap_mb=1
+            )
+            n_iters = 3
+            if self.rank > 0:
+                n_iters += 2
+
+            with net.join(divide_by_initial_world_size=False):
+                for _ in range(n_iters):
+                    loss = net(inp).sum()
+                    loss.backward()
+                    # The grad is always expected_grad, since we divide by the number
+                    # of currently active processes and inactive processes contribute
+                    # zero gradient. If we kept dividing by static initial world
+                    # size as processes leave, the grad would be smaller.
+                    expected_grad = torch.ones(dim, dim, device=self.rank) * grad_scale
+                    param = next(iter(net.parameters()))
+                    self.assertEqual(expected_grad, param.grad)
+                    # Avoid accumulating grads so that it's the same every iteration
+                    net.zero_grad()
+                    torch.cuda.synchronize(device=self.rank)
+
+            # If divide_by_initial_world_size=True (default), we always scale grads
+            # by the initial world_size.
+            with net.join(divide_by_initial_world_size=True):
+                for i in range(n_iters):
+                    loss = net(inp).sum()
+                    loss.backward()
+                    effective_ws = dist.get_world_size()
+                    if i >= 3:
+                        effective_ws -= 1
+                    expected_grad = (
+                        torch.ones(dim, dim, device=self.rank)
+                        * grad_scale
+                        * effective_ws
+                    ) / dist.get_world_size()
+                    param = next(iter(net.parameters()))
+                    self.assertEqual(expected_grad, param.grad)
+                    # Avoid accumulating grad so that it's the same every iteration.
+                    net.zero_grad()
+                    torch.cuda.synchronize(device=self.rank)
+
+        def _test_ddp_profiling(self, profiler_ctx, profiler_ctx2=None):
+            """Runs DDP based model training and captures profiles.
+            This test will do two profiler runs.
+            1. An initial basic run to check if profiler events are correctly captured.
+            2. A second profiling pass after running some iterations of DDP, to check robustness of thread local state.
+
+            args
+                profiler_ctx : Profiler context manager for pass 1
+                profiler_ctx2 : Profiler context manager for pass 2.
+                    This can be left out as None, in which case a deepcopy
+                    of profiler_ctx is used.
+            Returns:
+                prof: Instantiated profiler object that can be used for post analysis.
+            """
+            batch = 3
+            dim = 10
+            num_iters = 6
+            torch.cuda.set_device(self.rank)
+            model = nn.Linear(dim, dim, bias=False)
+            inp = torch.rand(batch, dim, device=self.rank)
+            net = torch.nn.parallel.DistributedDataParallel(
+                model.cuda(self.rank),
+                device_ids=[self.rank],
+            )
+            if profiler_ctx2 is None:
+                profiler_ctx2 = copy.deepcopy(profiler_ctx)
+
+            with profiler_ctx as prof:
+                for _ in range(num_iters):
+                    loss = net(inp).sum()
+                    loss.backward()
+
+            all_reduce_event_name = f"{dist.get_backend()}:all_reduce"
+            events = get_profiling_event(
+                all_reduce_event_name, prof, dedup_gpu_user_annotation=True
+            )
+            event_count = sum(e.count for e in events)
+            self.assertEqual(event_count, num_iters)
+            for event in events:
+                self.assertTrue(event.is_async)
+                self.assertEqual(event.name, all_reduce_event_name)
+
+            broadcast_event_name = f"{dist.get_backend()}:broadcast"
+            broadcast_events = get_profiling_event(
+                broadcast_event_name, prof, dedup_gpu_user_annotation=True
+            )
+            event_count = sum(e.count for e in broadcast_events)
+            # Broadcast is called during rebuild_buckets
+            self.assertGreaterEqual(event_count, 1)
+            for event in broadcast_events:
+                self.assertEqual(event.name, broadcast_event_name)
+
+            # Run DDP without profiling for a few iterations, then enable
+            # profiling for a single pass, and ensure it is recorded. This
+            # tests that the thread-local state is correctly updated.
+            net = torch.nn.parallel.DistributedDataParallel(
+                model.cuda(self.rank),
+                device_ids=[self.rank],
+                find_unused_parameters=True,
+            )
+            for _ in range(3):
+                loss = net(inp).sum()
+                loss.backward()
+            # Now enable the profiler.
+            with profiler_ctx2 as prof:
+                loss = net(inp).sum()
+                loss.backward()
+
+            events = get_profiling_event(
+                all_reduce_event_name, prof, dedup_gpu_user_annotation=True
+            )
+            self.assertGreaterEqual(len(events), 1)
+            self.assertGreaterEqual(events[0].count, 1)
+            self.assertEqual(events[0].name, all_reduce_event_name)
+            for event in events:
+                self.assertTrue(event.is_async)
+            # Ensure searching unused parameters was profiled
+            events = get_profiling_event("search_unused_parameters", prof)
+            self.assertEqual(len(events), 1)
+
+            return prof
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle("Currently failing in NVIDIA internal CI")
+        def test_ddp_profiling_autograd_profiler(self):
+            autograd_profiler_ctx = torch.autograd.profiler.profile()
+            return self._test_ddp_profiling(profiler_ctx=autograd_profiler_ctx)
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(IS_FBCODE, "Kineto in fbcode code causes hang")
+        @skip_but_pass_in_sandcastle_if(
+            IS_MACOS or IS_WINDOWS,
+            "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124",
+        )
+        def test_ddp_profiling_torch_profiler(self):
+            cpu_act = torch.profiler.ProfilerActivity.CPU
+            cuda_act = torch.profiler.ProfilerActivity.CUDA
+            torch_profiler_ctx = torch.profiler.profile(activities=[cpu_act, cuda_act])
+            prof = self._test_ddp_profiling(profiler_ctx=torch_profiler_ctx)
+
+            if dist.get_backend() != "nccl":
+                return
+
+            # Note: comment out the `os.remove(trace_file)` call in
+            # `get_profiler_nccl_meta()` to debug any mismatches.
+            nccl_meta_events = get_profiler_nccl_meta(prof)
+            self.assertGreater(len(nccl_meta_events), 0)
+
+            nccl_meta = self._sanity_check_profiler_nccl_meta(nccl_meta_events)
+
+            # additionally check the specific collectives in this test case
+            self.assertEqual(len(nccl_meta["allreduce"]), 2)
+            self.assertEqual(len(nccl_meta["wait"]), 1)
+
+            # check allreduce message sizes
+            a0 = nccl_meta["allreduce"][0]
+            self.assertEqual(a0["Out msg nelems"], 100, msg=f"{a0}")
+            self.assertEqual(a0["dtype"], "Float", msg=f"{a0}")
+            a1 = nccl_meta["allreduce"][1]
+            self.assertEqual(a1["Out msg nelems"], 1, msg=f"{a1}")
+            self.assertEqual(a1["dtype"], "Int", msg=f"{a1}")
+
+        def _validate_execution_trace_nccl(self, et_file: str) -> None:
+            """Torch profiler includes nccl metadata in an inserted operator called "record_param_comms"
+            We test for basic fields in these nodes in the Execution Trace.
+            """
+            with open(et_file) as f:
+                et = json.load(f)
+            pg_cfg_node = [
+                n for n in et["nodes"] if n["name"] == "## process_group:init ##"
+            ]
+            self.assertGreaterEqual(len(pg_cfg_node), 1)
+            nccl_meta_nodes = [
+                n for n in et["nodes"] if n["name"] == "record_param_comms"
+            ]
+            self.assertEqual(len(nccl_meta_nodes), 3)
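+            # Three record_param_comms nodes are expected in this trace: two
+            # allreduce collectives and one wait, matching the per-collective
+            # counts asserted further below.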
+            per_coll_meta = defaultdict(list)
+
+            # Sanity check NCCL metadata nodes
+            for n in nccl_meta_nodes:
+                attrs_list = n.get("attrs", [])
+                self.assertGreater(len(attrs_list), 0)
+                attrs = {a["name"]: a["value"] for a in attrs_list}
+
+                collname = attrs.get("collective_name", "")
+                self.assertNotEqual(collname, "")
+                self.assertNotEqual(attrs.get("dtype", ""), "")
+
+                per_coll_meta[collname].append(attrs)
+                if collname in {"wait"}:
+                    continue
+
+                self.assertEqual(attrs["pg_name"], "0")  # yes this is a string
+                self.assertEqual(attrs["pg_desc"], "default_pg")
+                self.assertEqual(attrs["pg_size"], 2)
+
+                self.assertGreaterEqual(attrs.get("in_msg_nelems", -1), 0)
+                self.assertGreaterEqual(attrs.get("out_msg_nelems", -1), 0)
+                self.assertTrue("in_split_size" in attrs.keys())
+                self.assertTrue("out_split_size" in attrs.keys())
+                self.assertEqual(attrs.get("global_rank_start", -1), 0)
+                self.assertEqual(attrs.get("global_rank_stride", -1), 1)
+
+            # print(per_coll_meta)
+            self.assertEqual(len(per_coll_meta["allreduce"]), 2)
+            self.assertEqual(len(per_coll_meta["wait"]), 1)
+
+            # check allreduce message sizes
+            a0 = per_coll_meta["allreduce"][0]
+            self.assertEqual(a0["out_msg_nelems"], 100, msg=f"{a0}")
+            self.assertEqual(a0["dtype"], "Float", msg=f"{a0}")
+            a1 = per_coll_meta["allreduce"][1]
+            self.assertEqual(a1["out_msg_nelems"], 1, msg=f"{a1}")
+            self.assertEqual(a1["dtype"], "Int", msg=f"{a1}")
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(IS_FBCODE, "Kineto in fbcode code causes hang")
+        @skip_but_pass_in_sandcastle_if(
+            IS_MACOS or IS_WINDOWS,
+            "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124",
+        )
+        @unittest.skipIf(BACKEND != "nccl", "Tests nccl metadata primarily.")
+        def test_ddp_profiling_execution_trace(self):
+            self.assertEqual(dist.get_backend(), "nccl")
+            # Create a temp file to save execution trace data
+            fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False)
+            fp.close()
+            et_file = fp.name
+            et = ExecutionTraceObserver().register_callback(et_file)
+
+            # The first profiler context need not have an execution trace observer.
+            torch_profiler_ctx1 = torch.profiler.profile(
+                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+            )
+            # collect ET in second profiler pass
+            torch_profiler_ctx2 = torch.profiler.profile(
+                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+                execution_trace_observer=et,
+            )
+            self._test_ddp_profiling(
+                profiler_ctx=torch_profiler_ctx1,
+                profiler_ctx2=torch_profiler_ctx2,
+            )
+
+            print(f"Execution trace saved at {fp.name}")
+            self._validate_execution_trace_nccl(et_file)
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_join_model_equivalence(self):
+            # Verifies equivalence with model training locally and with DDP under
+            # the join context manager.
+            batch = 3
+            dim = 10
+            learning_rate = 0.03
+            model = nn.Linear(dim, dim, bias=False)
+            inp = torch.rand(batch, dim, device=self.rank)
+            local_model = copy.deepcopy(model)
+            local_model = local_model.cuda(self.rank)
+            rank_to_iter_mapping = {
+                rank: 2 * (rank + 1) for rank in range(dist.get_world_size())
+            }
+            # run local model
+            local_iters = sum(rank_to_iter_mapping.values())
+            local_optim = torch.optim.SGD(local_model.parameters(), lr=learning_rate)
+            for _ in range(local_iters):
+                local_optim.zero_grad()
+                out = local_model(inp)
+                loss = out.sum()
+                loss.backward()
+                local_optim.step()
+
+            # run DDP model with join API
+            num_iters = rank_to_iter_mapping[self.rank]
+            net = torch.nn.parallel.DistributedDataParallel(
+                model.cuda(self.rank), device_ids=[self.rank]
+            )
+            ddp_optim = torch.optim.SGD(
+                model.parameters(), lr=learning_rate * dist.get_world_size()
+            )
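+            # The DDP learning rate is scaled by the world size to compensate
+            # for DDP averaging gradients across ranks, so that the cumulative
+            # parameter updates can match the locally trained baseline.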
+            with net.join():
+                for _ in range(num_iters):
+                    ddp_optim.zero_grad()
+                    out = net(inp)
+                    loss = out.sum()
+                    loss.backward()
+                    torch.cuda.synchronize(device=self.rank)
+                    ddp_optim.step()
+
+            # Validate model state dicts are equal
+            for (_, local_tensor), (_, dist_tensor) in zip(
+                local_model.state_dict().items(), net.module.state_dict().items()
+            ):
+                self.assertEqual(local_tensor, dist_tensor)
+
+        def _run_uneven_inputs_test(
+            self,
+            test_case,
+            iteration_mapping,
+            find_unused_params,
+        ):
+            model = test_case.model
+            inp = test_case.inp
+            rank = self.rank
+            sync_interval = test_case.sync_interval
+            torch.cuda.set_device(rank)
+            # Ensure all outstanding GPU work is completed so this test runs independently.
+            dist.barrier()
+            # Bucket_cap_mb is intentionally low to test allreduce scheduling when
+            # there are many buckets.
+            net = torch.nn.parallel.DistributedDataParallel(
+                model.cuda(rank),
+                device_ids=[rank],
+                bucket_cap_mb=1,
+                find_unused_parameters=find_unused_params,
+            )
+            # Register hook if specified
+            if test_case.hook is not None:
+                net.register_comm_hook(test_case.state, test_case.hook)
+                print(f"registered hook {test_case.hook}")
+
+            # Determine num iters for this rank via the passed in mapping.
+            num_iters = iteration_mapping[rank]
+            # If we throw when the earliest rank terminates, we should ensure
+            # that we iterate exactly that minimum number of times.
+            num_iters_tensor = torch.tensor(
+                [num_iters], device=torch.cuda.current_device()
+            )
+            dist.all_reduce(num_iters_tensor, op=dist.ReduceOp.MIN)
+            min_num_iters = num_iters_tensor.item()
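+            # min_num_iters is the smallest per-rank iteration count; every rank
+            # is guaranteed to run at least this many synchronized iterations
+            # before any rank can exhaust its inputs.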
+            total_iters = 0
+            if test_case.throw_on_early_termination:
+                if min_num_iters == num_iters:
+                    # Early termination rank(s)
+                    exception_ctx = self.assertRaisesRegex(
+                        RuntimeError, f"Rank {self.rank} exhausted all inputs"
+                    )
+                else:
+                    # Non early termination rank
+                    exception_ctx = self.assertRaisesRegex(
+                        RuntimeError,
+                        "Detected at least one rank that exhausted inputs.",
+                    )
+            else:
+                exception_ctx = nullcontext()
+            with exception_ctx:
+                with net.join(
+                    throw_on_early_termination=test_case.throw_on_early_termination
+                ):
+                    for i in range(num_iters):
+                        # Use net.no_sync() to disable grad synchronization on
+                        # iterations that are not a multiple of sync_interval.
+                        if i % sync_interval != 0:
+                            context = net.no_sync()
+                        else:
+                            context = nullcontext()
+                        with context:
+                            if isinstance(inp, tuple):
+                                loss = net(*inp).sum()
+                            else:
+                                loss = net(inp).sum()
+                            loss.backward()
+                            self._model_step(net)
+                            # Ensure completion of GPU kernels (including allreduce). If the
+                            # join API is not properly implemented, then this should hang
+                            # since the allreduce will hang.
+                            torch.cuda.synchronize(device=rank)
+                        total_iters += 1
+            if test_case.throw_on_early_termination:
+                # Ensure we iterated min_num_iters times.
+                self.assertEqual(total_iters, min_num_iters)
+            else:
+                # Ensure we iterated at least min_num_iters times.
+                self.assertGreaterEqual(total_iters, min_num_iters)
+
+            # Ensure completion of all GPU kernels.
+            torch.cuda.synchronize(device=rank)
+            # When throwing on early rank termination, we do not
+            # broadcast model state from an authoritative rank. All models
+            # should already be in sync.
+            if not test_case.throw_on_early_termination:
+                self.assertTrue(net._authoritative_rank)
+                # All ranks should have agreed on the same authoritative_rank!
+                final_rank_tensor = torch.tensor(
+                    [net._authoritative_rank], device=self.rank
+                )
+                tensor_list = [
+                    torch.zeros_like(final_rank_tensor)
+                    for _ in range(dist.get_world_size())
+                ]
+                dist.all_gather(tensor_list, final_rank_tensor)
+                max_rank = dist.get_world_size() - 1
+                self.assertSetEqual(
+                    {max_rank}, {tensor.item() for tensor in tensor_list}
+                )
+                # Ensure that all models are the same across ranks after all have joined.
+                self.validate_net_equivalence(net)
+                # Ensure that running with DDP uneven inputs was logged.
+                ddp_logging_data = net._get_ddp_logging_data()
+                self.assertTrue(ddp_logging_data.get("join_uneven_inputs"))
+                dist.barrier()
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_uneven_inputs_stop_iteration_sync_bn(self):
+            # Tests that uneven inputs join handler correctly throws StopIteration
+            # for models with SyncBN or general collective comm when
+            # throw_on_early_termination=True.
+            class ModelWithComm(torch.nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.lin = nn.Linear(2, 40, bias=False)
+
+                def forward(self, x):
+                    x = self.lin(x)
+                    dist.all_reduce(x)
+                    return x
+
+            torch.cuda.set_device(self.rank)
+            model_bn = BN_NET
+            model_bn = nn.SyncBatchNorm.convert_sync_batchnorm(
+                copy.deepcopy(model_bn)
+            ).cuda(self.rank)
+            comm_model = ModelWithComm().cuda(self.rank)
+            model_input = torch.randn(10, 2).cuda(torch.cuda.current_device())
+
+            for model in [model_bn, comm_model]:
+                model = torch.nn.parallel.DistributedDataParallel(
+                    model,
+                    device_ids=[self.rank],
+                )
+                min_num_iters = 5
+                if self.rank != 0:
+                    # Early termination rank(s)
+                    num_iters = min_num_iters
+                    exception_ctx = self.assertRaisesRegex(
+                        RuntimeError, f"Rank {self.rank} exhausted all inputs"
+                    )
+                else:
+                    # Non early termination rank
+                    num_iters = min_num_iters * 2
+                    exception_ctx = self.assertRaisesRegex(
+                        RuntimeError,
+                        "Detected at least one rank that exhausted inputs.",
+                    )
+                n = 0
+                with exception_ctx:
+                    with model.join(throw_on_early_termination=True):
+                        for _ in range(num_iters):
+                            loss = model(model_input).sum()
+                            loss.backward()
+                            self._model_step(model)
+                            n += 1
+
+                self.assertEqual(n, min_num_iters)
+                # Verify model equivalence
+                self.validate_net_equivalence(model)
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_uneven_inputs(self):
+            dim = 1000
+            batch = 1
+            # Create a variety of models to run uneven input tests on.
+            large_model = nn.Sequential(
+                nn.Conv2d(1, 20, 5),
+                nn.ReLU(),
+                nn.Conv2d(20, 32, 5),
+                nn.ReLU(),
+                nn.Conv2d(32, 256, 5),
+                nn.ReLU(),
+            )
+            small_model = nn.Linear(dim, dim, bias=False)
+            bn_net = BatchNormNet()
+
+            class UnusedParamModule(nn.Module):
+                def __init__(self, unused_params_rank):
+                    super().__init__()
+                    self.t0 = Task()
+                    self.t1 = Task()
+                    self.unused_params_rank = unused_params_rank
+
+                def task_parameters(self):
+                    return (self.t0.p, self.t1.p)
+
+                def forward(self, x, rank):
+                    return (
+                        self.t1(self.t0(x))
+                        if rank != self.unused_params_rank
+                        else self.t1(x)
+                    )
+
+            unjoined_rank_with_unused_params_model = UnusedParamModule(1)
+            joined_rank_with_unused_params_model = UnusedParamModule(0)
+
+            rank = self.rank
+            models_to_test = [
+                # Network with batchnorm
+                DDPUnevenTestInput(
+                    name="batch_norm_net",
+                    model=bn_net,
+                    inp=torch.ones(batch, 2, device=rank),
+                    sync_interval=1,
+                ),
+                DDPUnevenTestInput(
+                    name="large_conv_model",
+                    model=large_model,
+                    inp=torch.ones(batch, batch, dim, dim, device=rank),
+                    sync_interval=1,
+                ),
+                DDPUnevenTestInput(
+                    name="small_model",
+                    model=small_model,
+                    inp=torch.ones(batch, dim, device=rank),
+                    sync_interval=1,
+                ),
+                # Unused-parameter test where the rank that does not join early has unused params
+                DDPUnevenTestInput(
+                    name="unjoined_rank_with_unused_params_model",
+                    model=unjoined_rank_with_unused_params_model,
+                    inp=(torch.ones(batch, 2, device=rank), rank),
+                    sync_interval=1,
+                ),
+                # Unused-parameter test where the rank that does join early has unused params
+                DDPUnevenTestInput(
+                    name="joined_rank_with_unused_params_model",
+                    model=joined_rank_with_unused_params_model,
+                    inp=(torch.ones(batch, 2, device=rank), rank),
+                    sync_interval=1,
+                ),
+            ]
+
+            # Test models that have hook installed.
+            models_with_hook = [
+                DDPUnevenTestInput(
+                    name="small_model_allreduce_hook",
+                    model=small_model,
+                    hook=default.allreduce_hook,
+                    state=None,
+                    inp=torch.ones(batch, dim, device=rank),
+                    sync_interval=1,
+                ),
+                DDPUnevenTestInput(
+                    name="small_model_power_sgd_hook",
+                    model=small_model,
+                    hook=powerSGD.powerSGD_hook,
+                    state=powerSGD.PowerSGDState(
+                        process_group=None,
+                        matrix_approximation_rank=1,
+                        # Config so that powerSGD runs immediately instead of
+                        # allreduce.
+                        start_powerSGD_iter=1,
+                        warm_start=False,
+                        use_error_feedback=False,
+                    ),
+                    inp=torch.ones(batch, dim, device=rank),
+                    sync_interval=1,
+                ),
+            ]
+            models_to_test.extend(models_with_hook)
+
+            # Add resnet model if we have torchvision installed.
+            if HAS_TORCHVISION:
+                resnet_model = torchvision.models.resnet50()
+                models_to_test.append(
+                    DDPUnevenTestInput(
+                        name="resnet_model",
+                        model=resnet_model,
+                        inp=torch.ones(1, 3, 1000, 1000),
+                        sync_interval=1,
+                    )
+                )
+
+            # Test with no_sync every 2, 3, 4, ... iterations.
+            models_with_sync = []
+            for i, test_input in enumerate(models_to_test):
+                models_with_sync.append(
+                    DDPUnevenTestInput(
+                        name=test_input.name,
+                        model=test_input.model,
+                        inp=test_input.inp,
+                        sync_interval=i + 2,
+                    )
+                )
+
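+            # Also run each test case with throw_on_early_termination=True to
+            # exercise the join mode that raises once any rank exhausts its inputs.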
+            throw_on_early_term_tests = []
+            for test_input in models_to_test:
+                throw_on_early_term_tests.append(
+                    DDPUnevenTestInput(
+                        name=test_input.name,
+                        model=test_input.model,
+                        inp=test_input.inp,
+                        sync_interval=test_input.sync_interval,
+                        throw_on_early_termination=True,
+                    )
+                )
+
+            models_to_test.extend(models_with_sync)
+            models_to_test.extend(throw_on_early_term_tests)
+
+            # Include 0-iteration tests for the case where one process does not train
+            # the model at all; in that case we must shadow the broadcast calls made
+            # when rebuilding buckets.
+            baseline_num_iters = [0, 5]
+            iteration_offsets = [2, 3, 10]
+            num_uneven_ranks = [1]
+            if dist.get_world_size() > 2:
+                num_uneven_ranks.append(2)
+            iteration_mappings = []
+            # Generate rank : num_iters mappings for various uneven input scenarios.
+            # This includes cases where rank 0 joins early and all other ranks join
+            # later, as well as cases where multiple ranks join early at different
+            # iterations and the remaining ranks join later.
+            for num_early_join_ranks in num_uneven_ranks:
+                for baseline_iter in baseline_num_iters:
+                    for offset in iteration_offsets:
+                        mapping = dict.fromkeys(
+                            range(0, num_early_join_ranks), baseline_iter
+                        )
+                        # If num_early_join_ranks > 1, early-joining ranks > 0 iterate
+                        # offset//2 more times than rank 0, to test nodes depleting
+                        # their inputs at different times.
+                        if num_early_join_ranks > 1:
+                            for rank in mapping.keys():
+                                if rank > 0:
+                                    mapping[rank] += offset // 2
+                        mapping.update(
+                            dict.fromkeys(
+                                range(num_early_join_ranks, dist.get_world_size()),
+                                baseline_iter + offset,
+                            )
+                        )
+                        iteration_mappings.append(mapping)
+
+            for test_case, iteration_mapping in itertools.product(
+                models_to_test, iteration_mappings
+            ):
+                if self.rank == 0:
+                    print(
+                        f"Running test: {test_case.name} sync interval "
+                        f"{test_case.sync_interval} with iteration mapping "
+                        f"{iteration_mapping}"
+                    )
+                self._run_uneven_inputs_test(
+                    test_case,
+                    iteration_mapping,
+                    find_unused_params=("unused_params_model" in test_case.name),
+                )
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_uneven_input_join_disable(self):
+            # Tests that if net.join() is entered with enable=False, DDP works as
+            # expected with even inputs.
+            torch.manual_seed(self.rank)
+            net = torch.nn.parallel.DistributedDataParallel(
+                torch.nn.Linear(1, 1).cuda(self.rank), device_ids=[self.rank]
+            )
+            inp = torch.ones(1) * self.rank
+            n_iters = 5
+            world_size = dist.get_world_size()
+            with net.join(enable=False):
+                for _ in range(n_iters):
+                    # Clear grads
+                    grad = net.module.weight.grad
+                    if grad is not None:
+                        grad.requires_grad_(False)
+                        grad.zero_()
+                    out = net(inp)
+                    loss = out.sum()
+                    loss.backward()
+                    # Validate gradients to ensure that we divide by the correct
+                    # world_size when join mode is disabled.
+                    expected_grad = sum(i for i in range(world_size)) / world_size
+                    self.assertEqual(net.module.weight.grad.item(), expected_grad)
+
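+            # Since join mode was disabled, verify that the recorded join config
+            # reflects this and that all ranks ended up with equivalent models.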
+            join_config = net._join_config
+            self.assertFalse(join_config.enable)
+            self.validate_net_equivalence(net)
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_uneven_input_exception(self):
+            # Tests that exceptions during training are correctly propagated by the
+            # context manager.
+            error_str = "Intentional error"
+
+            class ExceptionModule(nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.param = nn.Parameter(torch.ones(1, requires_grad=True))
+
+                def forward(self, _):
+                    raise ValueError(error_str)
+
+            exception_module = ExceptionModule()
+            net = torch.nn.parallel.DistributedDataParallel(
+                exception_module.cuda(self.rank), device_ids=[self.rank]
+            )
+            inp = torch.ones(1)
+            with self.assertRaisesRegex(ValueError, error_str):
+                with net.join():
+                    out = net(inp)
+                    loss = out.sum()
+                    loss.backward()
+
+        def _test_broadcast_object_list(self, group=None):
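+            # Exercises dist.broadcast_object_list() with single- and multi-object
+            # lists, with and without an explicit device, for both gloo and nccl.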
+            gather_objects = COLLECTIVES_OBJECT_TEST_LIST.copy()
+
+            # Only set the device for the NCCL backend since it must use GPUs.
+            # This tests the case where rank != GPU device.
+            next_rank = (self.rank + 1) % int(self.world_size)
+            backend = os.environ["BACKEND"]
+            if backend == "nccl":
+                torch.cuda.set_device(next_rank)
+
+            src_rank = 0
+            # If GPU test, add object with GPU tensor
+            if backend == "nccl":
+                gather_objects.append(Foo(torch.randn(3, 3, device=0)))
+
+            if IS_FBCODE:
+                # Create Tensor with > 2^31 Bytes storage requirements
+                # Only on FBCODE as testing OOMs in OSS
+                gather_objects.append(Foo(torch.randn(3, 178956971)))
+            objects = (
+                gather_objects
+                if self.rank == src_rank
+                else [None for _ in gather_objects]
+            )
+
+            # Single object test with device specified. Backend="gloo", device=cpu
+            if backend != "nccl":
+                single_obj_list = [objects[0]]
+                if self.rank != src_rank:
+                    self.assertNotEqual(single_obj_list[0], gather_objects[0])
+                dist.broadcast_object_list(
+                    single_obj_list, src=0, group=group, device=torch.device("cpu")
+                )
+                self.assertEqual(single_obj_list[0], gather_objects[0])
+
+            # Single object test with device specified. Backend="gloo", device=current_device+1
+            # The test is gated on the GPU count matching the world size to avoid the
+            # case where the backend is gloo but there are not enough GPU devices.
+            if backend != "nccl" and torch.cuda.device_count() == int(self.world_size):
+                single_obj_list = [objects[0]]
+                if self.rank != src_rank:
+                    self.assertNotEqual(single_obj_list[0], gather_objects[0])
+                dist.broadcast_object_list(
+                    single_obj_list, src=0, group=group, device=torch.device(next_rank)
+                )
+                self.assertEqual(single_obj_list[0], gather_objects[0])
+
+            # Single object test with device specified. Backend="nccl", device=current_device+1
+            if backend == "nccl" and torch.cuda.device_count() == int(self.world_size):
+                single_obj_list = [objects[0]]
+                if self.rank != src_rank:
+                    self.assertNotEqual(single_obj_list[0], gather_objects[0])
+                dist.broadcast_object_list(
+                    single_obj_list, src=0, group=group, device=torch.device(next_rank)
+                )
+                self.assertEqual(single_obj_list[0], gather_objects[0])
+
+            # Single object test: backward compatibility with device unspecified
+            single_obj_list = [objects[0]]
+            if self.rank != src_rank:
+                self.assertNotEqual(single_obj_list[0], gather_objects[0])
+            dist.broadcast_object_list(single_obj_list, src=0, group=group)
+            self.assertEqual(single_obj_list[0], gather_objects[0])
+
+            # Multiple input objects test
+            if self.rank != src_rank:
+                self.assertNotEqual(objects, gather_objects)
+            dist.broadcast_object_list(objects, src=0, group=group)
+            self.assertEqual(objects, gather_objects)
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @require_n_gpus_for_nccl_backend(
+            int(os.environ["WORLD_SIZE"]), os.environ["BACKEND"]
+        )
+        @with_dist_debug_levels(levels=["DETAIL"])
+        @unittest.skip(
+            "Test is failing, see https://github.com/pytorch/pytorch/pull/113620"
+        )
+        def test_broadcast_object_list(self):
+            return self._test_broadcast_object_list()
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @require_n_gpus_for_nccl_backend(
+            int(os.environ["WORLD_SIZE"]), os.environ["BACKEND"]
+        )
+        @with_dist_debug_levels(levels=["DETAIL"])
+        def _test_broadcast_object_list_subgroup(self):
+            default = _get_default_group()
+            backend = dist.get_backend(default)
+            subgroup = dist.new_group(backend=backend)
+            return self._test_broadcast_object_list(subgroup)
+
+        def _test_ddp_ignore_params_arg(self, static_graph=False):
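+            # Verifies that parameters and buffers passed to
+            # _set_params_and_buffers_to_ignore_for_model are left untouched by DDP,
+            # so they can be re-materialized after wrapping.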
+            class TestModel(nn.Module):
+                def __init__(self, rank):
+                    self.rank = rank
+                    super().__init__()
+                    self.fc1 = nn.Linear(1, 1, bias=False)
+                    # Proxy layer that will be materialized to a different architecture
+                    # later (after the model is wrapped with DDP).
+                    if self.rank == 0:
+                        self.fc2 = nn.Linear(1, 10, bias=False)
+                    else:
+                        self.fc2 = nn.Linear(10, 10, bias=False)
+
+                def forward(self, x):
+                    x = self.fc1(x)
+                    x = self.fc2(x)
+                    return x
+
+            device_id = self.rank
+            # Ensure the test works for all combinations of the find_unused_parameters
+            # and broadcast_buffers settings.
+            for find_unused, broadcast_buffers in itertools.product(
+                [False, True], [False, True]
+            ):
+                model = TestModel(self.rank).float().to(device_id)
+                # Note that the model can have buffers of different shapes across ranks
+                # if we pass them in to be ignored as well.
+                model.fc2.register_buffer(
+                    "ignore_buffer", torch.zeros(5 + self.rank, device=self.rank)
+                )
+                proxy_params = list(model.fc2.parameters())
+                model_fc2_name = next(
+                    module_name
+                    for module_name, module in model.named_modules()
+                    if module is model.fc2
+                )
+                proxy_param_names = [
+                    f"{model_fc2_name}.{param_name}"
+                    for param_name, _ in model.fc2.named_parameters()
+                ]
+                proxy_buffer_names = [
+                    f"{model_fc2_name}.{buf_name}"
+                    for buf_name, _ in model.fc2.named_buffers()
+                ]
+                # Specify that we should ignore proxy_params since it will be
+                # materialized later.
+                torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
+                    model, proxy_param_names + proxy_buffer_names
+                )
+                ddp = torch.nn.parallel.DistributedDataParallel(
+                    model,
+                    device_ids=[device_id],
+                    find_unused_parameters=find_unused,
+                    broadcast_buffers=broadcast_buffers,
+                    static_graph=static_graph,
+                )
+                # Materialize new params. These are not registered in DDP and thus
+                # don't have autograd hooks installed on them.
+                ddp.module.fc2 = nn.Linear(1, 1, bias=False).to(device_id)
+
+                # local model with the new materialized parameters.
+                local_model = copy.deepcopy(ddp.module).cuda(self.rank)
+
+                inp = torch.ones(1, dtype=torch.float).to(device_id) * (self.rank + 1)
+                for _ in range(6):
+                    ddp(inp).sum().backward()
+
+                    local_model(inp).sum().backward()
+                    # materialized param grad is not touched by DDP, so its grad should
+                    # be the same as if running locally.
+                    for materialized_param, local_param in zip(
+                        ddp.module.fc2.parameters(), local_model.fc2.parameters()
+                    ):
+                        self.assertEqual(materialized_param.grad, local_param.grad)
+
+                    # fc1 parameter grad should still be different, due to allreduce.
+                    for synced_param, local_param in zip(
+                        ddp.module.fc1.parameters(), local_model.fc1.parameters()
+                    ):
+                        self.assertFalse(synced_param.grad == local_param.grad)
+
+                    # Proxy module grad should not be touched
+                    for proxy_param in proxy_params:
+                        self.assertTrue(proxy_param.grad is None)
+
+                # Synchronize since we run multiple iterations of this test, so that
+                # any hang can be isolated to a specific iteration.
+                torch.cuda.synchronize(device=self.rank)
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_ignore_params_arg(self):
+            self._test_ddp_ignore_params_arg(static_graph=False)
+            self._test_ddp_ignore_params_arg(static_graph=True)
+
+        @with_dist_debug_levels(levels=["OFF", "INFO", "DETAIL"])
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_unused_params_rebuild_buckets_exception(self):
+            class ToyModel(nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.net1 = nn.Linear(10, 10, bias=False)
+                    self.net2 = nn.Linear(10, 10, bias=False)
+
+                def forward(self, x):
+                    return self.net1(x)
+
+            ddp = torch.nn.parallel.DistributedDataParallel(
+                ToyModel().cuda(self.rank), device_ids=[self.rank]
+            )
+            for i in range(2):
+                inp = torch.rand(1, 10)
+                if i > 0:
+                    # On 2nd iteration, this will fail during rebuild_buckets,
+                    # but we should report an error regarding unused parameters
+                    # since that is the underlying root cause.
+                    try:
+                        ddp(inp).sum().backward()
+                    except RuntimeError as e:
+                        msg = str(e)
+                        verify_ddp_error_logged(ddp, msg)
+                        expected_strs = [
+                            ddp_prev_reduction_unfinished_str,
+                            ddp_recommend_find_unused_params_str,
+                            ddp_outputs_not_used_in_loss_str,
+                        ]
+                        # In debug mode, should show parameters that weren't reduced.
+                        # Without debug mode, should show suggestion to use debug mode.
+                        if dist.get_debug_level() == dist.DebugLevel.OFF:
+                            expected_strs.append(ddp_suggest_debug_mode_str)
+                        else:
+                            unreduced_params = ", ".join(["net2.weight"])
+                            expected_strs.append(
+                                f"did not receive grad for rank {self.rank}: {unreduced_params}"
+                            )
+                        for s in expected_strs:
+                            self.assertTrue(s in msg, f"Expected {s} to be in {msg}")
+                        self.assertFalse(ddp_find_unused_params_enabled_str in msg)
+                    else:
+                        self.fail("DDP unused parameters error not raised.")
+                else:
+                    ddp(inp).sum().backward()
+
+            dist.barrier()
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_shared_grad_acc_unused_params(self):
+            # When find_unused_parameters=True, ensure we mark unused parameters
+            # even if they share gradient accumulators.
+            class ToyModel(nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    # net1, bias, and net1.bias are all unused params.
+                    self.net1 = nn.Linear(10, 5, bias=False)
+                    self.bias = nn.Parameter(torch.zeros(5))
+                    # net1.bias and self.bias are names for the same underlying
+                    # parameter, so they share the same gradient accumulator. This
+                    # caused the bug reported in https://github.com/pytorch/pytorch/issues/41324.
+                    self.net1.bias = self.bias
+                    self.net2 = nn.Linear(10, 5)
+
+                def forward(self, x):
+                    return self.net2(x).sum()
+
+            torch.cuda.set_device(self.rank)
+            model = ToyModel().to(torch.cuda.current_device())
+            for static in [True, False]:
+                ddp_model = torch.nn.parallel.DistributedDataParallel(
+                    copy.deepcopy(model),
+                    device_ids=[self.rank],
+                    find_unused_parameters=True,
+                    static_graph=static,
+                )
+                inp = torch.randn(20, 10, device=self.rank)
+                for _ in range(6):
+                    loss = ddp_model(inp)
+                    # To test https://github.com/pytorch/pytorch/issues/61982
+                    loss /= 10
+                    loss.backward()
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_device(self):
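+            # Verifies that DDP moves tuple, list, namedtuple, dict, and custom-type
+            # inputs to the correct device (or leaves custom types alone) before forward.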
+            expected_len = 2
+
+            class TensorWrapper:
+                __slots__ = ["t", "moved_to_gpu"]
+
+                def __init__(self, t):
+                    self.t = t
+                    self.moved_to_gpu = False
+
+            # Handlers for specific types of validation we want to do based on
+            # the input type.
+
+            def tuple_and_list_validator(x):
+                self.assertEqual(len(x), expected_len)
+                self.assertEqual(1, len({t.device for t in x}))
+                self.assertEqual(x[0].device.index, self.rank)
+                return x[0] + x[1]
+
+            def namedtuple_validator(x):
+                self.assertEqual(x._fields, EXPECTED_FIELDS)
+                self.assertEqual(x.a.device.index, x.b.device.index)
+                self.assertEqual(x.a.device.index, self.rank)
+                return x.a + x.b
+
+            def custom_type_validator(x):
+                self.assertTrue(x.moved_to_gpu or (str(x.t.device) == "cpu"))
+                x.t = x.t.to(self.rank)
+                x.moved_to_gpu = True
+                return x.t
+
+            def dict_validator(x):
+                self.assertTrue(EXPECTED_FIELDS[0] in x.keys())
+                self.assertTrue(EXPECTED_FIELDS[1] in x.keys())
+                self.assertEqual(1, len({t.device for t in x.values()}))
+                self.assertEqual(x[EXPECTED_FIELDS[0]].device.index, self.rank)
+                return x[EXPECTED_FIELDS[0]] + x[EXPECTED_FIELDS[1]]
+
+            validators = {
+                TensorWrapper: custom_type_validator,
+                tuple: tuple_and_list_validator,
+                list: tuple_and_list_validator,
+                TestNamedTupleInput_0: namedtuple_validator,
+                TestNamedTupleInput_1: namedtuple_validator,
+                dict: dict_validator,
+            }
+
+            class ToyModel(torch.nn.Module):
+                def __init__(self_):  # noqa: B902
+                    super().__init__()
+                    self_.lin = nn.Linear(10, 10, bias=False)
+
+                def forward(self_, x, expected_type):  # noqa: B902
+                    # As with scatter, the recursive to() in the single-device case
+                    # does not move tensors if they are wrapped in a custom type.
+                    self.assertTrue(isinstance(x, expected_type))
+                    fwd_tensor = validators[expected_type](x)
+                    return self_.lin(fwd_tensor)
+
+            model = torch.nn.parallel.DistributedDataParallel(
+                ToyModel().to(self.rank), device_ids=[self.rank]
+            )
+
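+            # Run a few forward/backward passes for a given input container type.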
+            def train_iter(inp, input_type):
+                for _ in range(4):
+                    out = model(inp, input_type)
+                    out.sum().backward()
+
+            # CPU tuple input, should be moved to the proper device before call
+            # to forward.
+            inp = tuple(torch.randn(10, 10) for _ in range(expected_len))
+            train_iter(inp, tuple)
+
+            # List CPU input, should be moved to proper device before call to
+            # forward.
+            inp = [torch.randn(10, 10) for _ in range(expected_len)]
+            train_iter(inp, list)
+            # Custom type containing tensor. The type is maintained, but the
+            # device is not propagated (which is what happens with scatter too)
+            inp = TensorWrapper(torch.randn(10, 10))
+            train_iter(inp, TensorWrapper)
+            # NamedTuple input. The type should be maintained and tensor inputs
+            # should be moved to the correct device as in scatter.
+            batch = 5
+            dim = 10
+            a = torch.rand(batch, dim)
+            b = torch.rand(batch, dim)
+
+            inp = TestNamedTupleInput_0(a, b)
+            train_iter(inp, type(inp))
+
+            inp = TestNamedTupleInput_1(a, b)
+            train_iter(inp, type(inp))
+
+            # dictionary input.
+            inp = {
+                EXPECTED_FIELDS[0]: a,
+                EXPECTED_FIELDS[1]: b,
+            }
+            train_iter(inp, type(inp))
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_namedtuple(self):
+            batch = 5
+            dim = 10
+
+            a = torch.rand(batch, dim, device=self.rank)
+            b = torch.rand(batch, dim, device=self.rank)
+
+            class NamedTupleModule(torch.nn.Module):
+                def __init__(self_):  # noqa: B902
+                    super().__init__()
+                    self_.lin = nn.Linear(10, 1)
+
+                def forward(self_, input, expected_type):  # noqa: B902
+                    # Without NamedTuple support, this would be of type tuple.
+                    self.assertTrue(
+                        isinstance(input, expected_type),
+                        f"Expected type {expected_type} but got {type(input)}",
+                    )
+                    self.assertEqual(input._fields, EXPECTED_FIELDS)
+                    self.assertEqual(a, input.a)
+                    self.assertEqual(b, input.b)
+                    return self_.lin(torch.mul(input.a, input.b))
+
+            model = torch.nn.parallel.DistributedDataParallel(
+                NamedTupleModule().cuda(self.rank), device_ids=[self.rank]
+            )
+            inp = TestNamedTupleInput_0(a, b)
+            # The following would fail if DDP does not propagate NamedTuples correctly.
+            model(inp, type(inp))
+
+            inp = TestNamedTupleInput_1(a, b)
+            model(inp, type(inp))
+
+        @require_backend_is_available({"gloo"})
+        def test_grads_same_across_ranks_with_no_sync(self):
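+            # Verifies that after gradient accumulation under no_sync(), the final
+            # synchronized gradients are identical on all ranks even though each
+            # rank exercises different submodules.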
+            _group, _group_id, rank = self._init_global_test()
+            world_size = dist.get_world_size()
+            if world_size < 2:
+                self.skipTest("This test requires at least two ranks.")
+
+            class SimpleConditionalModel(nn.Module):
+                # If the rank is 0, uses nn1 on the first pass and nn2 on the second;
+                # otherwise, uses nn3 on the first pass and nn4 on the second.
+
+                def __init__(self, rank):
+                    super().__init__()
+
+                    self.rank = rank
+                    self.nn1 = nn.Linear(1, 1)
+                    self.nn2 = nn.Linear(1, 1)
+                    self.nn3 = nn.Linear(1, 1)
+                    self.nn4 = nn.Linear(1, 1)
+                    self.state = 0
+
+                def forward(self, input):
+                    if self.state == 0:
+                        self.state = 1
+                        if self.rank == 0:
+                            return self.nn1(input)
+                        else:
+                            return self.nn3(input)
+                    else:
+                        self.state = 0
+                        if self.rank == 0:
+                            return self.nn2(input)
+                        else:
+                            return self.nn4(input)
+
+            model = torch.nn.parallel.DistributedDataParallel(
+                SimpleConditionalModel(rank), find_unused_parameters=True
+            )
+            mse_loss = nn.MSELoss()
+            grad_accumulation = 2
+
+            for microbatch_idx in range(grad_accumulation):
+                if microbatch_idx < grad_accumulation - 1:
+                    context = model.no_sync
+                else:
+                    context = nullcontext
+
+                with context():
+                    input = torch.rand((1,))
+                    output = model.forward(input)
+                    target = torch.rand((1,))
+
+                    loss = mse_loss(output, target)
+                    loss.backward()
+
+            self.assertTrue(
+                not any(p.grad is None for p in model.parameters()),
+                "Gradients can't be None for any model parameter.",
+            )
+            grads = torch.cat([p.grad.view(-1) for p in model.parameters()])
+
+            # Gather all gradients to rank 0.
+            if rank == 0:
+                gathered_grads = [torch.zeros_like(grads) for _ in range(world_size)]
+            else:
+                gathered_grads = []
+
+            dist.gather(grads, gather_list=gathered_grads, dst=0)
+            if rank == 0:
+                for g in gathered_grads[1:]:
+                    self.assertTrue(
+                        torch.allclose(gathered_grads[0], g),
+                        "Gradients are not the same for all ranks.",
+                    )
+
+        @with_dist_debug_levels(levels=["OFF", "INFO", "DETAIL"])
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_control_flow_same_across_ranks(self):
+            # Control flow that is the same across ranks.
+            batch = 20
+            dim = 10
+
+            world_size = dist.get_world_size()
+            torch.cuda.set_device(self.rank)
+            model = torch.nn.parallel.DistributedDataParallel(
+                ControlFlowToyModel().cuda(self.rank),
+                device_ids=[self.rank],
+                find_unused_parameters=True,
+            )
+            random_input = torch.randn(batch, dim, device=self.rank)
+            ones_input = torch.ones(batch, dim, device=self.rank)
+            for i in range(6):
+                if i % 2 == 0:
+                    out = model(random_input)
+                else:
+                    out = model(ones_input)
+                loss = out.sum()
+                loss.backward()
+                # On even iterations the 2nd param goes unused; on odd iterations it
+                # is used.
+                local_used_map = model.reducer._get_local_used_map()
+                if i % 2 == 0:
+                    expected = torch.tensor(
+                        [world_size, 0], device=self.rank, dtype=torch.int32
+                    )
+                else:
+                    expected = torch.tensor(
+                        [world_size, world_size], device=self.rank, dtype=torch.int32
+                    )
+
+                # Validate parameter usage.
+                variable_usage_tensor = local_used_map
+                self.assertEqual(variable_usage_tensor, expected)
+
+            # Validate appropriate error message when DDP is used with
+            # find_unused_parameters=False.
+            model = torch.nn.parallel.DistributedDataParallel(
+                ControlFlowToyModel().cuda(self.rank),
+                device_ids=[self.rank],
+                find_unused_parameters=False,
+            )
+            for i in range(2):
+                if i == 0:
+                    loss = model(random_input).sum()
+                    loss.backward()
+                else:
+                    try:
+                        loss = model(random_input).sum()
+                        loss.backward()
+                    except RuntimeError as e:
+                        msg = str(e)
+                        verify_ddp_error_logged(model, msg)
+                        # 2nd linear layer is unused
+                        unused_param_index = 1
+                        expected_strs = [
+                            ddp_prev_reduction_unfinished_str,
+                            ddp_recommend_find_unused_params_str,
+                            ddp_outputs_not_used_in_loss_str,
+                            f"Parameter indices which did not receive grad for rank {self.rank}: {unused_param_index}",
+                        ]
+                        # In debug mode, should show parameters that weren't reduced.
+                        # Without debug mode, should show suggestion to use debug mode.
+                        if dist.get_debug_level() == dist.DebugLevel.OFF:
+                            expected_strs.append(ddp_suggest_debug_mode_str)
+                        else:
+                            unreduced_params = ", ".join(["lin2.weight"])
+                            expected_strs.append(
+                                f"did not receive grad for rank {self.rank}: {unreduced_params}"
+                            )
+                        for s in expected_strs:
+                            self.assertTrue(s in msg, f"Expected {s} to be in {msg}")
+                        self.assertFalse(ddp_find_unused_params_enabled_str in msg)
+                    else:
+                        self.fail("DDP error not raised")
+
+            dist.barrier()
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_invalid_static_graph(self):
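+            # With static_graph=True, changing which parameters are used between
+            # iterations should raise an error and log it.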
+            torch.cuda.set_device(self.rank)
+            model = torch.nn.parallel.DistributedDataParallel(
+                ControlFlowToyModel().cuda(self.rank),
+                device_ids=[self.rank],
+                static_graph=True,
+            )
+            random_input = torch.randn(20, 10, device=self.rank)
+            ones_input = torch.ones(20, 10, device=self.rank)
+            # A parameter that was unused in the first iteration becomes used in the
+            # second iteration.
+            expected_err = "Your training graph has changed in this iteration"
+            with self.assertRaisesRegex(RuntimeError, expected_err):
+                for i in range(2):
+                    if i % 2 == 0:
+                        out = model(random_input)
+                    else:
+                        out = model(ones_input)
+                    loss = out.sum()
+                    loss.backward()
+
+            verify_ddp_error_logged(model, expected_err)
+
+            # A parameter that was used in the first iteration becomes unused in the
+            # second iteration.
+            with self.assertRaisesRegex(
+                RuntimeError,
+                "Expected to have finished reduction in the prior iteration "
+                "before starting a new one. This error indicates that your "
+                "training graph has changed in this iteration, "
+                "e.g., one parameter is used in first iteration, "
+                "but then got unused in the second iteration. "
+                "this is not compatible with static_graph set to True.\n"
+                "Parameter indices which did not receive grad for",
+            ):
+                for i in range(2):
+                    if i % 2 != 0:
+                        out = model(random_input)
+                    else:
+                        out = model(ones_input)
+                    loss = out.sum()
+                    loss.backward()
+
+            verify_ddp_error_logged(model, "Expected to have finished reduction")
+
+        @with_dist_debug_levels(levels=["OFF", "INFO", "DETAIL"])
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_control_flow_different_across_ranks(self):
+            # Control flow that is different across ranks.
+            batch = 20
+            dim = 10
+
+            class ToyModel(nn.Module):
+                def __init__(self, rank):
+                    super().__init__()
+                    self.lin1 = nn.Linear(10, 10, bias=False)
+                    self.lin2 = nn.Linear(10, 10, bias=False)
+                    self.rank = rank
+
+                def forward(self, x):
+                    # Control-flow that is rank and input dependent for the
+                    # model.
+                    use_second_layer = (
+                        torch.equal(x, torch.ones(batch, dim, device=x.device))
+                        and self.rank == 1
+                    )
+
+                    if use_second_layer:
+                        return self.lin2(F.relu(self.lin1(x)))
+                    else:
+                        return F.relu(self.lin1(x))
+
+            world_size = dist.get_world_size()
+            torch.cuda.set_device(self.rank)
+            model = torch.nn.parallel.DistributedDataParallel(
+                ToyModel(self.rank).cuda(self.rank),
+                device_ids=[self.rank],
+                find_unused_parameters=True,
+            )
+            random_input = torch.randn(batch, dim, device=self.rank)
+            ones_input = torch.ones(batch, dim, device=self.rank)
+            for i in range(6):
+                if i % 2 == 0:
+                    out = model(random_input)
+                else:
+                    out = model(ones_input)
+                loss = out.sum()
+                loss.backward()
+                # On even iterations the 2nd param goes unused; on odd iterations it
+                # is used only on rank 1.
+                local_used_map = model.reducer._get_local_used_map()
+
+                if i % 2 == 0:
+                    expected = torch.tensor(
+                        [world_size, 0], device=self.rank, dtype=torch.int32
+                    )
+                else:
+                    expected = torch.tensor(
+                        [world_size, 1], device=self.rank, dtype=torch.int32
+                    )
+
+                variable_usage_tensor = local_used_map
+                # Validate parameter usage. On odd iterations, 2nd param is only
+                # used on rank 1.
+                self.assertEqual(variable_usage_tensor, expected)
+
+            # Validate appropriate error message when DDP is used with
+            # find_unused_parameters=False.
+            model = torch.nn.parallel.DistributedDataParallel(
+                ToyModel(self.rank).cuda(self.rank),
+                device_ids=[self.rank],
+                find_unused_parameters=False,
+            )
+            for i in range(2):
+                if i == 0:
+                    loss = model(random_input).sum()
+                    loss.backward()
+                else:
+                    try:
+                        loss = model(random_input).sum()
+                        loss.backward()
+                    except RuntimeError as e:
+                        msg = str(e)
+                        verify_ddp_error_logged(model, msg)
+                        unused_param_index = 1
+                        expected_strs = [
+                            ddp_prev_reduction_unfinished_str,
+                            ddp_recommend_find_unused_params_str,
+                            ddp_outputs_not_used_in_loss_str,
+                            f"Parameter indices which did not receive grad for rank {self.rank}: {unused_param_index}",
+                        ]
+                        # In debug mode, should show parameters that weren't reduced.
+                        # Without debug mode, should show suggestion to use debug mode.
+                        if dist.get_debug_level() == dist.DebugLevel.OFF:
+                            expected_strs.append(ddp_suggest_debug_mode_str)
+                        else:
+                            unreduced_params = ", ".join(["lin2.weight"])
+                            expected_strs.append(
+                                f"did not receive grad for rank {self.rank}: {unreduced_params}"
+                            )
+                        for s in expected_strs:
+                            self.assertTrue(s in msg, f"Expected {s} to be in {msg}")
+                        self.assertFalse(ddp_find_unused_params_enabled_str in msg)
+                    else:
+                        self.fail("DDP error not raised")
+
+            dist.barrier()
+
+        @require_backend_is_available({"gloo"})
+        def test_scatter_object_list(self):
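+            # Scatter a list of picklable objects from src_rank and check that each
+            # rank receives the expected element; also validate argument checking.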
+            src_rank = 0
+            scatter_list = (
+                COLLECTIVES_OBJECT_TEST_LIST
+                if self.rank == src_rank
+                else [None for _ in COLLECTIVES_OBJECT_TEST_LIST]
+            )
+            world_size = dist.get_world_size()
+            scatter_list = scatter_list[:world_size]
+            i = 0
+            while len(scatter_list) < world_size:
+                scatter_list.append(scatter_list[i])
+                i += 1
+
+            output_obj_list = [None]
+            dist.scatter_object_list(output_obj_list, scatter_list, src=src_rank)
+            self.assertEqual(
+                output_obj_list[0],
+                COLLECTIVES_OBJECT_TEST_LIST[
+                    self.rank % len(COLLECTIVES_OBJECT_TEST_LIST)
+                ],
+            )
+            # Ensure errors are raised upon incorrect arguments.
+            with self.assertRaisesRegex(
+                ValueError,
+                "Expected argument scatter_object_output_list to be a list of size at least 1.",
+            ):
+                dist.scatter_object_list([], scatter_list, src=src_rank)
+
+        def _generate_sparse_tensors_for_bucket_assignment_test(self):
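+            # Build a small set of sparse tensors with mixed dtypes; bucket assignment
+            # is expected to reject sparse tensors.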
+            tensors = [
+                torch.empty([50], dtype=torch.float),
+                torch.empty([25], dtype=torch.double),
+                torch.empty([50], dtype=torch.float),
+                torch.empty([25], dtype=torch.double),
+                torch.empty([50], dtype=torch.float),
+                torch.empty([25], dtype=torch.double),
+            ]
+
+            tensors_sparse = [t.to_sparse() for t in tensors]
+            return tensors_sparse
+
+        def _test_compute_bucket_assignment_by_size(self, use_logger):
+            group_gloo = dist.new_group(
+                timeout=timedelta(seconds=60), backend=dist.Backend.GLOO
+            )
+            # Set TORCH_NCCL_BLOCKING_WAIT and use a new NCCL group to improve test
+            # determinism.
+            os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"
+            group_to_use = dist.new_group(
+                backend=dist.get_backend(), timeout=timedelta(seconds=5)
+            )
+            torch.cuda.set_device(self.rank)
+
+            # Create a valid model. The constructor initializes the logger that we use later.
+            # We never actually use the rest of the model - we only need its logger.
+            net = EmbeddingNetDifferentParams(0)
+            net = torch.nn.parallel.DistributedDataParallel(
+                net.to(self.rank),
+                device_ids=[self.rank],
+                process_group=group_to_use,
+            )
+
+            # If we don't pass a logger, we can only check that an exception was thrown.
+            expected_err = "No support for sparse tensors."
+            with self.assertRaisesRegex(RuntimeError, expected_err):
+                tensors_sparse = (
+                    self._generate_sparse_tensors_for_bucket_assignment_test()
+                )
+                if use_logger:
+                    dist._compute_bucket_assignment_by_size(
+                        tensors_sparse, [400], logger=net.logger
+                    )
+                else:
+                    dist._compute_bucket_assignment_by_size(tensors_sparse, [400])
+            if use_logger:
+                verify_ddp_error_logged(net, expected_err)
+
+            # Perform gloo-based barrier to ensure one rank doesn't exit test
+            # early which causes failure with Barrier.sync.
+            dist.barrier(group_gloo)
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_compute_bucket_assignment_by_size_sparse_error_without_logger(self):
+            self._test_compute_bucket_assignment_by_size(use_logger=False)
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_compute_bucket_assignment_by_size_sparse_error_with_logger(self):
+            self._test_compute_bucket_assignment_by_size(use_logger=True)
+
+        def _test_verify_model_across_rank(self, use_logger):
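+            # Checks that _verify_param_shape_across_processes raises on at least one
+            # rank when parameter shapes differ, optionally logging via the DDP logger.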
+            group_gloo = dist.new_group(
+                timeout=timedelta(seconds=60), backend=dist.Backend.GLOO
+            )
+            group_to_use = dist.new_group(
+                backend=dist.get_backend(), timeout=timedelta(seconds=5)
+            )
+            torch.cuda.set_device(self.rank)
+
+            # Create a valid model. The constructor initializes the logger that we use later.
+            net = EmbeddingNetDifferentParams(0)
+            net = torch.nn.parallel.DistributedDataParallel(
+                net.to(self.rank),
+                device_ids=[self.rank],
+                process_group=group_to_use,
+            )
+
+            # Modify the model so that the number of parameters is different on each
+            # rank. This will cause a RuntimeError to be thrown below in
+            # _verify_param_shape_across_processes, so we can check that the correct
+            # error is thrown and logged. We can't do this in the constructor above,
+            # otherwise the logger would not be properly initialized.
+            net.module.lin = nn.Linear(100 if self.rank == 0 else 10, 1)
+
+            # If we pass a logger, we can verify that the error was logged.
+            caught = 0
+            try:
+                if use_logger:
+                    _verify_param_shape_across_processes(
+                        net.process_group, list(net.parameters()), net.logger
+                    )
+                else:
+                    _verify_param_shape_across_processes(
+                        net.process_group, list(net.parameters())
+                    )
+            except Exception:
+                caught = 1
+
+            # The test passes as long as at least one rank caught the exception.
+            t = torch.Tensor([caught])
+            dist.all_reduce(t, group=group_gloo)
+            self.assertGreater(t, 0)
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally"
+        )
+        @skip_if_lt_x_gpu(2)
+        def test_verify_model_across_rank_with_logger(self):
+            self._test_verify_model_across_rank(use_logger=True)
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally"
+        )
+        @skip_if_lt_x_gpu(2)
+        def test_verify_model_across_rank_without_logger(self):
+            self._test_verify_model_across_rank(use_logger=False)
+
+        def _run_test_ddp_model_with_diff_params(self, net, ddp_group, group_gloo):
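+            # Wrapping the model with DDP should fail on at least one rank when the
+            # models differ across ranks; the allreduce below checks that.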
+            caught = 0
+            try:
+                net = torch.nn.parallel.DistributedDataParallel(
+                    net.to(self.rank), device_ids=[self.rank], process_group=ddp_group
+                )
+            except Exception:
+                caught = 1
+
+            # The test passes as long as at least one rank caught the exception.
+            t = torch.Tensor([caught])
+            dist.all_reduce(t, group=group_gloo)
+            self.assertGreater(t, 0)
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally"
+        )
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_model_diff_shape_across_ranks(self):
+            group_gloo = dist.new_group(
+                timeout=timedelta(seconds=60), backend=dist.Backend.GLOO
+            )
+            group_to_use = dist.new_group(
+                backend=dist.get_backend(), timeout=timedelta(seconds=10)
+            )
+            torch.cuda.set_device(self.rank)
+            # Creates a network with a different-sized embedding table on each rank.
+            # This should throw an error during DDP init.
+            net = EmbeddingNetDifferentParams(self.rank)
+            self._run_test_ddp_model_with_diff_params(net, group_to_use, group_gloo)
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally"
+        )
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_model_diff_num_params_across_ranks(self):
+            group_gloo = dist.new_group(
+                timeout=timedelta(seconds=60), backend=dist.Backend.GLOO
+            )
+            group_to_use = dist.new_group(
+                backend=dist.get_backend(), timeout=timedelta(seconds=10)
+            )
+            torch.cuda.set_device(self.rank)
+
+            # Creates a network with a different number of parameters across ranks;
+            # the reducer should recognize this and throw an appropriate error.
+            net = EmbeddingNetDifferentParams(
+                self.rank, diff_num_params=(self.rank == 1)
+            )
+
+            self._run_test_ddp_model_with_diff_params(
+                net,
+                group_to_use,
+                group_gloo,
+            )
+
+        def _test_output_unused_in_loss(self, module_cls, gradient_as_bucket_view):
+            model = module_cls()
+            local_net = copy.deepcopy(model)
+            net = torch.nn.parallel.DistributedDataParallel(
+                copy.deepcopy(model).cuda(self.rank),
+                device_ids=[self.rank],
+                find_unused_parameters=True,
+                gradient_as_bucket_view=gradient_as_bucket_view,
+            )
+
+            # Tests that parameters which receive no gradient, because the output they
+            # produce is unused in the loss computation, are supported. Specifically,
+            # checks that their grads remain unchanged and are the same as in local
+            # training.
+            inp = torch.randn(10, 10)
+
+            # Ensure that if a param is not used in the loss computation, its gradient
+            # is untouched, i.e. if it was None before the backward pass it stays None
+            # afterwards rather than being zeroed.
+            if module_cls == DictOutputModule:
+                a, b = local_net(inp)["predictions"]
+                a_dist, b_dist = net(inp)["predictions"]
+            else:
+                a, b = local_net(inp)
+                a_dist, b_dist = net(inp)
+
+            loss_dist = b_dist.sum()
+            loss_dist.backward()
+
+            # Ensure that gradient corresponding to parameter "a" was not
+            # touched, i.e. it is None and matches the local grad.
+            if module_cls == DictOutputModule:
+                self.assertTrue(net.module.module.a.weight.grad is None)
+                self.assertEqual(
+                    net.module.module.a.weight.grad, local_net.module.a.weight.grad
+                )
+            else:
+                self.assertTrue(net.module.a.weight.grad is None)
+                self.assertEqual(net.module.a.weight.grad, local_net.a.weight.grad)
+
+            saved_a_local_grad = None
+            saved_a_dist_grad = None
+            net.zero_grad()
+            local_net.zero_grad()
+            for i in range(6):
+                if module_cls == DictOutputModule:
+                    a, b = local_net(inp)["predictions"]
+                    a_dist, b_dist = net(inp)["predictions"]
+                else:
+                    a, b = local_net(inp)
+                    a_dist, b_dist = net(inp)
+                if i < 2:
+                    # Use both params in loss computation. Later, "a" will go
+                    # unused and we check to ensure DDP supports this and
+                    # gradients remain the same as local training.
+                    t = a @ b
+                    t_dist = a_dist @ b_dist
+                    loss = t.sum()
+                    loss_dist = t_dist.sum()
+                else:
+                    # Model output "a" unused in loss.
+                    loss = b.sum()
+                    loss_dist = b_dist.sum()
+                loss.backward()
+                loss_dist.backward()
+                if i == 1:
+                    # Save grads to compare with them in next iterations.
+                    if module_cls == DictOutputModule:
+                        saved_a_local_grad = local_net.module.a.weight.grad
+                        saved_a_dist_grad = net.module.module.a.weight.grad
+                    else:
+                        saved_a_local_grad = local_net.a.weight.grad
+                        saved_a_dist_grad = net.module.a.weight.grad
+                    self.assertEqual(saved_a_local_grad, saved_a_dist_grad)
+                elif i >= 2:
+                    # parameter "a" of both models should be the same and not change
+                    if module_cls == DictOutputModule:
+                        self.assertEqual(
+                            net.module.module.a.weight.grad, saved_a_dist_grad
+                        )
+                        self.assertEqual(
+                            local_net.module.a.weight.grad, saved_a_local_grad
+                        )
+                    else:
+                        self.assertEqual(net.module.a.weight.grad, saved_a_dist_grad)
+                        self.assertEqual(local_net.a.weight.grad, saved_a_local_grad)
+
+                # Verify grads are the same
+                for local_param, dist_param in zip(
+                    local_net.parameters(), net.parameters()
+                ):
+                    local_grad = local_param.grad
+                    dist_grad = dist_param.grad
+                    self.assertEqual(local_grad, dist_grad)
+
+            dist.barrier()
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_lt_x_gpu(2)
+        def test_output_unused_in_loss_tuple_module(self):
+            module_cls = UnusedParamTwoLinLayerNet
+            for grad_as_bucket_view in [True, False]:
+                self._test_output_unused_in_loss(module_cls, grad_as_bucket_view)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_lt_x_gpu(2)
+        def test_output_unused_in_loss_dict_module(self):
+            module_cls = DictOutputModule
+            for grad_as_bucket_view in [True, False]:
+                self._test_output_unused_in_loss(module_cls, grad_as_bucket_view)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_lt_x_gpu(2)
+        def test_undefined_grad_parity_unused_parameters(self):
+            # TODO: enable this for general training use cases:
+            # https://github.com/pytorch/pytorch/issues/58511.
+            x = torch.ones(1, 2).to(self.rank)
+            net = Net().to(self.rank)
+            local_net = copy.deepcopy(net)
+            net = torch.nn.parallel.DistributedDataParallel(
+                net,
+                device_ids=[self.rank],
+                find_unused_parameters=True,
+            )
+            out = net(x).sum()
+            local_out = local_net(x).sum()
+            # Simulates undefined gradients.
+            torch._C._functions.UndefinedGrad()(out).backward()
+            torch._C._functions.UndefinedGrad()(local_out).backward()
+            for (dist_param_name, dist_param), (local_param_name, local_param) in zip(
+                net.named_parameters(), local_net.named_parameters()
+            ):
+                dist_grad = dist_param.grad
+                local_grad = local_param.grad
+                self.assertEqual(
+                    dist_grad,
+                    local_grad,
+                    f"""DDP param {dist_param_name} with grad {dist_grad}
+                    does not match local param {local_param_name} with grad
+                    {local_grad}""",
+                )
+
+        def _test_different_graph_across_ranks(
+            self, find_unused_parameters=False, static_graph=False
+        ):
+            class ToyModel(nn.Module):
+                def __init__(self, rank):
+                    super().__init__()
+                    self.lin1 = nn.Linear(10, 10, bias=False)
+                    self.lin2 = nn.Linear(10, 10, bias=False)
+                    self.rank = rank
+
+                def forward(self, x):
+                    if self.rank == 0:
+                        return self.lin2(F.relu(self.lin1(x)))
+                    else:
+                        return F.relu(self.lin1(x))
+
+            torch.manual_seed(31415)
+            torch.cuda.set_device(self.rank)
+            model = ToyModel(self.rank).cuda(self.rank)
+            ddp_model = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[self.rank],
+                find_unused_parameters=find_unused_parameters,
+                gradient_as_bucket_view=True,
+                static_graph=static_graph,
+            )
+            random_input = torch.randn(20, 10, device=self.rank)
+            for _ in range(10):
+                out = ddp_model(random_input)
+                loss = out.sum()
+                loss.backward()
+            return ddp_model
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_different_graph_across_ranks(self):
+            base_model = self._test_different_graph_across_ranks(
+                find_unused_parameters=True
+            )
+            self.assertFalse(
+                base_model._get_ddp_logging_data().get("has_rebuilt_buckets", 0)
+            )
+            static_model = self._test_different_graph_across_ranks(static_graph=True)
+            self.assertTrue(
+                static_model._get_ddp_logging_data().get("has_rebuilt_buckets", 0)
+            )
+            for i, j in zip(base_model.parameters(), static_model.parameters()):
+                self.assertEqual(i, j)
+
+        @require_backend_is_available({"gloo"})
+        @skip_but_pass_in_sandcastle_if(
+            IS_MACOS or IS_WINDOWS,
+            "MacOS uses uv transport which does not have as robust error handling as tcp transport",
+        )
+        def test_monitored_barrier_gloo(self):
+            tensors = [torch.ones(10) * self.rank]
+            # Kick off some allreduce work on all ranks
+            for _ in range(10):
+                dist.all_reduce(torch.cat(tensors))
+            # Run monitored barrier and ensure it passes
+            timeout = timedelta(seconds=2)
+            dist.monitored_barrier(timeout=timeout)
+            # Check monitored_barrier success with wait_all_ranks=True
+            for _ in range(10):
+                dist.all_reduce(torch.cat(tensors))
+            dist.monitored_barrier(timeout=timeout, wait_all_ranks=True)
+            # All ranks besides 1 call into barrier; rank 0 should report failure
+            # while the others report a gloo error.
+            failed_rank = 1
+            src_rank = 0
+            if self.rank == src_rank:
+                with self.assertRaisesRegex(
+                    RuntimeError, f"Rank {failed_rank} failed to pass monitoredBarrier"
+                ):
+                    dist.monitored_barrier(timeout=timeout)
+            elif self.rank != failed_rank:
+                # Other ranks should not pass barrier since rank 0 failed.
+                err_regex = (
+                    f"Rank {self.rank} successfully reached monitoredBarrier,"
+                    f" but received errors while waiting for send/recv from rank"
+                    f" {src_rank}"
+                )
+                with self.assertRaisesRegex(RuntimeError, err_regex):
+                    dist.monitored_barrier(timeout=timeout)
+
+            # We need a barrier since otherwise failed_rank exits too early
+            # and causes a timeout.
+            self._barrier(timeout=30)
+
+        @require_backend_is_available({"gloo"})
+        def test_monitored_barrier_gloo_subgroup(self):
+            # Tests that monitored_barrier works as expected on non-default
+            # process groups.
+            failed_rank = 1
+            timeout = 0.1
+            subgroup = dist.new_group(ranks=[0, 1])
+
+            if self.rank == failed_rank:
+                return
+
+            if self.rank == 0:
+                with self.assertRaisesRegex(
+                    RuntimeError, f"Rank {failed_rank} failed to pass monitoredBarrier"
+                ):
+                    dist.monitored_barrier(subgroup, timeout)
+            else:
+                # Other ranks call into monitored_barrier, but this should be a
+                # noop because they are not part of the subgroup. Verify that
+                # there are no errors here.
+                dist.monitored_barrier(subgroup, timeout)
+
+        def _test_monitored_barrier_allreduce_hang(self, wait_all_ranks):
+            # tests expected behavior when nonzero rank hangs.
+            nccl_pg = dist.new_group(
+                ranks=list(range(int(self.world_size))),
+                # provide sufficient timeout so communicators
+                # can be initialized in ctor.
+                timeout=timedelta(seconds=15),
+                backend=dist.Backend.NCCL,
+            )
+            gloo_pg = dist.new_group(
+                ranks=list(range(int(self.world_size))),
+                backend=dist.Backend.GLOO,
+            )
+            tensors = [torch.ones(10, device=self.rank) * self.rank]
+            # Let all ranks call allreduce first to set up communicators etc.
+            # Directly simulating an error here would run into the store issue
+            # described in https://github.com/pytorch/pytorch/issues/54524.
+            nccl_pg.allreduce(tensors).wait(timedelta(seconds=5))
+            # All ranks besides 0 call into allreduce. This is to simulate a
+            # desync across the world, where some ranks call into
+            # monitored_barrier() and others are stuck in collective comm. In
+            # practice, we don't need TORCH_NCCL_BLOCKING_WAIT, but we use it in this
+            # test to ensure it exits cleanly.
+            if self.rank != 0:
+                # Can get different errors here depending on whether gloo-based
+                # wrapper PG is enabled or not, since with wrapper pg, it will
+                # fail in a collective synchronization check and not actually
+                # call into the nccl pg.
+                if dist.get_debug_level() == dist.DebugLevel.DETAIL:
+                    err_regex = "Timed out waiting"
+                else:
+                    err_regex = "caught collective operation timeout"
+                with self.assertRaisesRegex(RuntimeError, err_regex):
+                    nccl_pg.allreduce(tensors).wait(timedelta(seconds=0.1))
+            else:
+                # Rank 0 should report either the first (in rank order) timed-out
+                # rank or all timed-out ranks, depending on the wait_all_ranks flag
+                # passed into monitored_barrier.
+                if wait_all_ranks:
+                    rank_str = ", ".join(
+                        [str(i) for i in range(1, int(self.world_size))]
+                    )
+                    err_regex = f"Ranks {rank_str} failed to pass monitoredBarrier"
+                else:
+                    expected_first_fail_rank = 1
+                    err_regex = f"Rank {expected_first_fail_rank} failed to pass monitoredBarrier"
+                monitored_barrier_timeout_seconds = timedelta(seconds=0.1)
+                with self.assertRaisesRegex(RuntimeError, err_regex):
+                    gloo_pg.monitored_barrier(
+                        monitored_barrier_timeout_seconds, wait_all_ranks=wait_all_ranks
+                    )
+
+            self._barrier(timeout=30)
+
+        @with_nccl_blocking_wait
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
+        def test_monitored_barrier_allreduce_hang(self):
+            # tests expected behavior when nonzero rank hangs and we want to
+            # report first timed out rank.
+            self._test_monitored_barrier_allreduce_hang(wait_all_ranks=False)
+
+        @with_nccl_blocking_wait
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
+        def test_monitored_barrier_allreduce_hang_wait_all_ranks(self):
+            # Need to disable TORCH_NCCL_DUMP_ON_TIMEOUT otherwise this test times out
+            os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = "0"
+            # tests expected behavior when nonzero rank hangs and we want to
+            # report all timed out ranks.
+            self._test_monitored_barrier_allreduce_hang(wait_all_ranks=True)
+
+        @require_backend_is_available({"gloo"})
+        def test_monitored_barrier_gloo_rank_0_timeout(self):
+            # tests error when rank 0 exhausts its given timeout.
+            process_group = dist.new_group(ranks=list(range(int(self.world_size))))
+            timeout = timedelta(seconds=0)
+            if self.rank == 0:
+                with self.assertRaisesRegex(
+                    RuntimeError, f"Rank {self.rank} timed out in monitoredBarrier"
+                ):
+                    process_group.monitored_barrier(timeout)
+
+        @require_backend_is_available({"gloo"})
+        @skip_if_small_worldsize
+        @skip_but_pass_in_sandcastle_if(
+            IS_MACOS or IS_WINDOWS,
+            "MacOS uses uv transport which does not have as robust error handling as tcp transport",
+        )
+        def test_monitored_barrier_failure_order(self):
+            # Ensure that the first (in sorted order) rank is reported when
+            # multiple ranks fail to pass the monitored_barrier.
+            # TODO(#54879): Provide ability to wait and report all failed ranks
+            expected_first_failed_rank = 2
+            timeout = timedelta(seconds=2)
+            src_rank = 0
+            if self.rank == src_rank:
+                with self.assertRaisesRegex(
+                    RuntimeError, f"Rank {expected_first_failed_rank}"
+                ):
+                    dist.monitored_barrier(timeout=timeout)
+            elif self.rank == 1:
+                err_regex = (
+                    f"Rank {self.rank} successfully reached monitoredBarrier,"
+                    f" but received errors while waiting for send/recv from rank"
+                    f" {src_rank}"
+                )
+                with self.assertRaisesRegex(RuntimeError, err_regex):
+                    dist.monitored_barrier(timeout=timeout)
+
+        @require_backend_is_available({"gloo"})
+        @skip_if_small_worldsize
+        def test_monitored_barrier_wait_all_ranks(self):
+            # Tests simple case where > 1 rank does not call into monitored
+            # barrier and verifies all ranks are reported by rank 0.
+            if self.rank == 0:
+                timeout = timedelta(seconds=0.1)
+                rank_str = ", ".join([str(i) for i in range(1, int(self.world_size))])
+                err_regex = f"Ranks {rank_str} failed to pass monitoredBarrier"
+                with self.assertRaisesRegex(RuntimeError, err_regex):
+                    dist.monitored_barrier(timeout=timeout, wait_all_ranks=True)
+
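+        # Taken together, the monitored_barrier tests above exercise its reporting
+        # behavior: rank 0 acts as the coordinating rank and raises an error naming
+        # the first rank (in rank order) that failed to reach the barrier within the
+        # timeout, or all missing ranks when wait_all_ranks=True. A rough usage
+        # sketch (the timeout value here is illustrative only):
+        #
+        #   dist.monitored_barrier(timeout=timedelta(seconds=2), wait_all_ranks=True)
+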
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @with_dist_debug_levels(levels=["INFO"])
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_build_debug_param_to_name_mapping(self):
+            model = TwoLinLayerNet()
+            net = torch.nn.parallel.DistributedDataParallel(
+                model.cuda(self.rank),
+                device_ids=[self.rank],
+            )
+            expected_mapping = {0: "a.weight", 1: "b.weight"}
+            net_params, _ = net._build_params_for_reducer()
+            param_to_name_mapping = net._build_debug_param_to_name_mapping(net_params)
+            self.assertDictEqual(expected_mapping, param_to_name_mapping)
+
+            # Test when DDP is used with ignored parameters.
+            model = TwoLinLayerNet()
+            # Parameters to ignore are in the format {module_name}.{param_name}
+            params_to_ignore = ["a.weight"]
+            torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
+                model, params_to_ignore
+            )
+            net = torch.nn.parallel.DistributedDataParallel(
+                model.cuda(self.rank),
+                device_ids=[self.rank],
+            )
+            expected_mapping = {0: "b.weight"}
+            net_params, _ = net._build_params_for_reducer()
+            param_to_name_mapping = net._build_debug_param_to_name_mapping(net_params)
+            self.assertDictEqual(expected_mapping, param_to_name_mapping)
+
+            # Test errors are raised when DDP and module parameters mismatch.
+            # This generally indicates a bug with DDP and is not expected to
+            # happen in user applications.
+            model = TwoLinLayerNet()
+            net = torch.nn.parallel.DistributedDataParallel(
+                model.cuda(self.rank),
+                device_ids=[self.rank],
+            )
+            net_params, _ = net._build_params_for_reducer()
+            if self.rank == 0:
+                print(type(net_params[0]))
+
+            net_params.extend(
+                [
+                    torch.nn.Parameter(torch.ones(1)),
+                    torch.nn.Parameter(torch.ones(1)),
+                ]
+            )
+
+            with self.assertRaisesRegex(ValueError, "Expected param to name mapping"):
+                net._build_debug_param_to_name_mapping(net_params)
+
+            net_params = net_params[:-3]
+            with self.assertRaisesRegex(ValueError, "Param with name"):
+                net._build_debug_param_to_name_mapping(net_params)
+
+            net_params.extend(
+                [
+                    torch.nn.Parameter(torch.ones(1)),
+                    torch.nn.Parameter(torch.ones(1)),
+                ]
+            )
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @with_dist_debug_levels(levels=["INFO"])
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_build_debug_param_to_name_mapping_requires_grad(self):
+            class Net(nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.lin = nn.Linear(10, 10)
+                    # The bias is not tracked by DDP and should not show up in the
+                    # param-to-name mapping.
+                    self.lin.bias.requires_grad_(False)
+
+                def forward(self, x):
+                    return self.lin(x)
+
+            model = Net()
+            net = torch.nn.parallel.DistributedDataParallel(
+                model.cuda(self.rank), device_ids=[self.rank]
+            )
+            expected_mapping = {
+                0: "lin.weight",
+            }
+            net_params, _ = net._build_params_for_reducer()
+            param_to_name_mapping = net._build_debug_param_to_name_mapping(net_params)
+            self.assertEqual(param_to_name_mapping, expected_mapping)
+
+        def _test_ddp_multiple_nested_unused_params_error(self, ignore_sparse):
+            debug_mode_off = dist.get_debug_level() == dist.DebugLevel.OFF
+
+            class SubModule(nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.embedding_net = EmbeddingNetDifferentParams(0)
+                    self.lin = TwoLinLayerNet()
+                    self.bn = BatchNormNet()
+                    self.lin_layer = nn.Linear(4, 10, bias=False)
+
+                def forward(self, x):
+                    x = self.bn(x)
+                    x = self.lin_layer(x)
+                    x = self.lin.a(x)  # self.lin.b param unused
+                    # EmbeddingNetDifferentParams entirely unused: self.embedding_net.embedding and
+                    # self.embedding_net.lin unused.
+                    return x
+
+            class MyModel(nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.sub_module = SubModule()
+
+                def forward(self, x):
+                    return self.sub_module(x)
+
+            model = MyModel()
+            sparse_embedding_fqns = []
+            if ignore_sparse:
+                for module_name, module in model.named_modules():
+                    if module == model.sub_module.embedding_net.embedding:
+                        for parameter_name, _param in module.named_parameters(
+                            recurse=False
+                        ):
+                            fqn = f"{module_name}.{parameter_name}"
+                            sparse_embedding_fqns.append(fqn)
+
+                torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
+                    model, sparse_embedding_fqns
+                )
+                unused_modules = [
+                    model.sub_module.embedding_net.lin,
+                    model.sub_module.lin.b,
+                ]
+            else:
+                unused_modules = list(model.sub_module.embedding_net.modules()) + [
+                    model.sub_module.lin.b,
+                ]
+
+            expected_unused_param_fqns = []
+            used_param_fqns = []  # Validate that these don't mistakenly show up.
+            fqn_to_param_index = {}
+            index = 0
+            for module_name, module in model.named_modules():
+                for parameter_name, _param in module.named_parameters(recurse=False):
+                    fqn = f"{module_name}.{parameter_name}"
+                    fqn_to_param_index[fqn] = index
+                    if fqn not in sparse_embedding_fqns:
+                        index += 1
+                    if module in unused_modules:
+                        expected_unused_param_fqns.append(fqn)
+                    else:
+                        if (
+                            not ignore_sparse
+                            or module != model.sub_module.embedding_net.embedding
+                        ):
+                            used_param_fqns.append(fqn)
+
+            net = torch.nn.parallel.DistributedDataParallel(
+                model.cuda(self.rank),
+                device_ids=[self.rank],
+            )
+            batch, dim = 10, 2
+            inp = torch.ones(batch, dim)
+            for i in range(2):
+                if i == 0:
+                    out = net(inp)
+                    loss = out.sum()
+                    loss.backward()
+                else:
+                    try:
+                        out = net(inp)
+                        loss = out.sum()
+                        loss.backward()
+                    except RuntimeError as e:
+                        e = str(e)
+
+                        unused_param_substr = e[e.find("did not receive grad") :]
+                        # Validate that each unused param fully qualified name
+                        # shows up in error logs. We do this instead of
+                        # constructing a joined string since order of parameters
+                        # can be different in Reducer. In addition, validate
+                        # param indices show up as well.
+                        for unused_param_fqn in expected_unused_param_fqns:
+                            self.assertTrue(
+                                unused_param_fqn in unused_param_substr
+                                or debug_mode_off
+                            )
+                            self.assertTrue(
+                                str(fqn_to_param_index[unused_param_fqn])
+                                in unused_param_substr,
+                                f"Did not find index {fqn_to_param_index[unused_param_fqn]} for {unused_param_fqn}",
+                            )
+
+                        # Validate that used param fqns don't show up in error
+                        # logs.
+                        for used_param_fqn in used_param_fqns:
+                            self.assertFalse(used_param_fqn in unused_param_substr)
+                        # Validate that ignored param fqns don't show up as unused
+                        # (since DDP does not track them)
+                        for sparse_param_fqn in sparse_embedding_fqns:
+                            self.assertFalse(sparse_param_fqn in unused_param_substr)
+                    else:
+                        self.assertTrue(False, "Expected error was not raised!")
+
+        @with_dist_debug_levels(levels=["OFF", "INFO", "DETAIL"])
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_multiple_nested_unused_params_error(self):
+            self._test_ddp_multiple_nested_unused_params_error(ignore_sparse=False)
+
+        @with_dist_debug_levels(levels=["OFF", "INFO", "DETAIL"])
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_multiple_nested_unused_params_err_ignore_params(self):
+            # Tests unused parameter reporting when DDP is configured to ignore
+            # certain parameters.
+            self._test_ddp_multiple_nested_unused_params_error(ignore_sparse=True)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_inference(self):
+            # Tests that a DDP module can run under no_grad or in eval mode on a
+            # single rank without hanging.
+            rank = self.rank
+            torch.cuda.set_device(rank)
+            model = Net().cuda()
+            local_model = copy.deepcopy(model)
+            model = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[rank],
+            )
+            syncbn_model = nn.SyncBatchNorm(
+                2, momentum=0.99, track_running_stats=False
+            ).cuda()
+            local_syncbn_model = copy.deepcopy(syncbn_model)
+            syncbn_model = torch.nn.parallel.DistributedDataParallel(
+                syncbn_model, device_ids=[rank]
+            )
+            inp = torch.randn(10, 2, device=rank)
+            inp_syncbn = torch.randn(10, 2, 4, 4, device=rank)
+            tests = [
+                (model, local_model, inp),
+                (syncbn_model, local_syncbn_model, inp_syncbn),
+            ]
+            for test in tests:
+                test_model, test_local_model, test_inp = test
+                if self.rank == 0:
+                    test_model.eval()
+                    test_local_model.eval()
+                    for _ in range(6):
+                        self.assertEqual(
+                            test_model(test_inp), test_local_model(test_inp)
+                        )
+
+            # Barrier since only rank 0 runs inference. Test should be
+            # much faster than 30s, but this is to avoid flakiness.
+            self._barrier(timeout=30)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @skip_if_lt_x_gpu(2)
+        @unittest.skip(
+            "Test is failing, see https://github.com/pytorch/pytorch/pull/113620"
+        )
+        def test_ddp_sync_bn_training_vs_eval(self):
+            rank = self.rank
+            torch.cuda.set_device(rank)
+            # Need to set track_running_stats=False; when track_running_stats=True,
+            # bn_training is False and sync cannot occur in eval mode.
+            model = nn.SyncBatchNorm(2, momentum=0.99, track_running_stats=False).cuda(
+                rank
+            )
+            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])
+            # Test sync occurs in training mode.
+            with torch.autograd.profiler.profile() as prof:
+                for _ in range(6):
+                    inp = torch.randn(10, 2, 4, 4).cuda(rank)
+                    out = model(inp)
+                    loss = out.sum()
+                    loss.backward()
+
+            # SyncBN allgathers stats across all ranks, so verify call to
+            # all_gather in profiler.
+            if BACKEND == "nccl":
+                all_gather_calls = get_profiling_event("_all_gather_base", prof)
+            else:
+                all_gather_calls = get_profiling_event("all_gather", prof)
+            self.assertNotEqual([], all_gather_calls)
+
+            # Only do inference on one rank. If SyncBN did collective stats sync,
+            # this would hang/error.
+            model_inference = model.module
+            if self.rank == 0:
+                model_inference.eval()
+                with torch.autograd.profiler.profile() as prof:
+                    for _ in range(6):
+                        inp = torch.randn(10, 2, 4, 4).cuda(rank)
+                        out = model_inference(inp)
+                        loss = out.sum()
+                        loss.backward()
+
+                # Ensure sync does not occur in eval() mode.
+                if BACKEND == "nccl":
+                    all_gather_calls = get_profiling_event("_all_gather_base", prof)
+                else:
+                    all_gather_calls = get_profiling_event("all_gather", prof)
+                self.assertEqual([], all_gather_calls)
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_python_error_logged(self):
+            # Most Python exceptions in DDP are raised during init, before the
+            # reducer is constructed, so we don't have a logger in those cases.
+            # However, the case below is one example where a Python error is
+            # thrown after the reducer is constructed.
+            model = TwoLinLayerNet().cuda(self.rank)
+            model = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[self.rank],
+            )
+            expected_err = "must be callable"
+            with self.assertRaisesRegex(TypeError, expected_err):
+                model.register_comm_hook({}, {})
+
+            verify_ddp_error_logged(model, expected_err)
+
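+        # A comm hook passed to register_comm_hook must be a callable with signature
+        # (state, bucket) -> torch.futures.Future[torch.Tensor], which is why the
+        # dict argument above fails with "must be callable". As a rough sketch (the
+        # hook name below is illustrative only), a minimal allreduce hook might look
+        # like:
+        #
+        #   def example_allreduce_hook(state, bucket):
+        #       tensor = bucket.buffer() / dist.get_world_size()
+        #       fut = dist.all_reduce(tensor, async_op=True).get_future()
+        #       return fut.then(lambda f: f.value()[0])
+        #
+        #   ddp_model.register_comm_hook(state=None, hook=example_allreduce_hook)
+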
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_static_graph_nested_types(self):
+            # Tests for static graph training when outputs are not just tensors
+            # but can be (nested) tuple, list, dict, etc.
+            rank = self.rank
+            torch.cuda.set_device(rank)
+
+            class NestedOutputModule(torch.nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.lin = nn.Linear(100, 1, bias=False)
+
+                def forward(self, inp, output_type):
+                    if output_type == "tuple":
+                        return (
+                            self.lin(inp),
+                            (
+                                self.lin(inp),
+                                self.lin(inp),
+                            ),
+                        )
+                    elif output_type == "list":
+                        return [
+                            self.lin(inp),
+                            [
+                                self.lin(inp),
+                                self.lin(inp),
+                            ],
+                        ]
+                    elif output_type == "dict":
+                        return {
+                            "a": self.lin(inp),
+                            "b": {
+                                "c": self.lin(inp),
+                            },
+                        }
+
+            def get_loss(model_output):
+                loss = 0.0
+                if isinstance(model_output, torch.Tensor):
+                    return model_output.sum()
+                elif isinstance(model_output, dict):
+                    for value in model_output.values():
+                        loss += get_loss(value)
+                elif isinstance(model_output, (tuple, list)):
+                    for x in model_output:
+                        loss += get_loss(x)
+                else:
+                    raise ValueError(f"Unknown model output type {type(model_output)}")
+                return loss
+
+            model = NestedOutputModule().cuda(rank)
+            model_static_graph = copy.deepcopy(model)
+            model = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[rank],
+            )
+            model_static_graph = torch.nn.parallel.DistributedDataParallel(
+                model_static_graph,
+                device_ids=[rank],
+                static_graph=True,
+            )
+            inp = torch.randn(10, 100)
+            type_mapping = {
+                "list": list,
+                "tuple": tuple,
+                "dict": dict,
+            }
+            for output_type in type_mapping.keys():
+                for _ in range(6):
+                    out = model(inp, output_type=output_type)
+                    loss = get_loss(out)
+                    loss.backward()
+                    self._model_step(model)
+                    out_static = model_static_graph(inp, output_type=output_type)
+                    self.assertTrue(isinstance(out_static, type_mapping[output_type]))
+                    loss_static = get_loss(out_static)
+                    loss_static.backward()
+                    self._model_step(model_static_graph)
+                    for p, p_static in zip(
+                        model.parameters(), model_static_graph.parameters()
+                    ):
+                        self.assertEqual(p, p_static)
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_returns_tensor_with_no_grad(self):
+            # Tests case where module returns tensor that does not require grad.
+            torch.cuda.set_device(self.rank)
+
+            class MyModel(nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.fc1 = nn.Linear(10, 10, bias=False)
+                    self.fc2 = nn.Linear(10, 10, bias=False)
+
+                def forward(self, x):
+                    x = self.fc2(F.relu(self.fc1(x)))
+                    y = x.clone()
+                    x = x.detach()
+                    assert not x.requires_grad
+                    return (x, y)
+
+            model = MyModel().to(self.rank)
+            inp = torch.randn(1, 10, device=self.rank)
+            for find_unused, static_graph in itertools.product(
+                [True, False], [True, False]
+            ):
+                ddp = DistributedDataParallel(
+                    model,
+                    device_ids=[self.rank],
+                    output_device=self.rank,
+                    find_unused_parameters=find_unused,
+                    static_graph=static_graph,
+                )
+                for _ in range(6):
+                    out = ddp(inp)
+                    self.assertFalse(out[0].requires_grad)
+                    o = (out[0] + out[1]).sum()
+                    o.backward()
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_detect_ddp_is_actually_static(self):
+            class ToyModel(nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.net1 = nn.Linear(10, 10, bias=False)
+                    self.net2 = nn.Linear(10, 10)
+
+                def forward(self, x, find_unused, dynamic):
+                    if find_unused:
+                        if dynamic:
+                            return self.net2(self.net1(x))
+                        else:
+                            return self.net2(x)
+                    else:
+                        return self.net2(self.net1(x))
+
+            # The set of unused parameters does not change across iterations.
+            torch.cuda.set_device(self.rank)
+            model = ToyModel().cuda()
+            for find_unused in [True, False]:
+                ddp = torch.nn.parallel.DistributedDataParallel(
+                    model,
+                    device_ids=[self.rank],
+                    find_unused_parameters=find_unused,
+                )
+                inp = torch.randn(1, 10, device="cuda")
+                for _ in range(6):
+                    out = ddp(inp, find_unused=find_unused, dynamic=False)
+                    loss = out.sum()
+                    loss.backward()
+                    self.assertTrue(ddp.reducer._ddp_graph_static())
+
+            # The set of unused parameters changes dynamically across iterations.
+            ddp = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[self.rank],
+                find_unused_parameters=True,
+            )
+            inp = torch.randn(1, 10, device="cuda")
+            for i in range(6):
+                out = ddp(inp, find_unused=True, dynamic=i % 2 == 0)
+                loss = out.sum()
+                loss.backward()
+            self.assertFalse(ddp.reducer._ddp_graph_static())
+
+        def _test_ddp_new_tensor_in_fwd(self, static_graph):
+            # Test from https://github.com/pytorch/pytorch/issues/60733
+            class MyModel(nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.fc1 = nn.Linear(10, 10, bias=False)
+                    self.fc2 = nn.Linear(10, 10, bias=False)
+                    self.device = self.fc1.weight.device
+
+                def __init_opt(self):
+                    opt = torch.randn(1, 10, device=self.device)
+                    return opt
+
+                def forward(self, x, opt_1, opt_2, opt_nested):
+                    x = F.relu(self.fc1(x))
+                    x = self.fc2(x)
+                    if opt_1 is None:
+                        opt_1 = self.__init_opt()
+                    if opt_2 is None:
+                        opt_2 = self.__init_opt()
+                    if opt_nested is None or not torch.is_tensor(opt_nested):
+                        opt_nested = self.__init_opt()
+                    # Test multiple tensors as well as newly created tensors
+                    # within a struct.
+                    return x, opt_1, opt_2, {"tensor": opt_nested}
+
+            model = MyModel().to(self.rank)
+            for find_unused in [True, False]:
+                ddp = DistributedDataParallel(
+                    model,
+                    device_ids=[self.rank],
+                    output_device=self.rank,
+                    broadcast_buffers=False,
+                    find_unused_parameters=find_unused,
+                    static_graph=static_graph,
+                )
+
+                opt = [None for _ in range(3)]
+                for i in range(2):
+                    ddp.zero_grad()
+                    x = torch.randn(1, 10, device=self.rank)
+                    out, opt[0], opt[1], opt[2] = ddp(
+                        x, opt_1=opt[0], opt_2=opt[1], opt_nested=opt[2]
+                    )
+                    for i in range(len(opt)):
+                        if torch.is_tensor(opt[i]):
+                            self.assertEqual(opt[i].grad_fn, None)
+                        else:
+                            self.assertEqual(opt[i]["tensor"].grad_fn, None)
+                    out.mean().backward()
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_new_tensor_in_fwd(self):
+            return self._test_ddp_new_tensor_in_fwd(static_graph=False)
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_new_tensor_in_fwd_static_graph(self):
+            return self._test_ddp_new_tensor_in_fwd(static_graph=True)
+
+        def _test_ddp_buffer_hook_allreduce(self, return_futures):
+            rank = self.rank
+            torch.cuda.set_device(rank)
+            torch.manual_seed(rank)
+            torch.cuda.manual_seed(rank)
+
+            def buffer_comm_hook(ddp, named_buffers):
+                buffers = [buffer for (_, buffer) in named_buffers.items()]
+                futs = [
+                    dist.all_reduce(
+                        buffer, group=ddp.process_group, async_op=True
+                    ).get_future()
+                    for buffer in buffers
+                ]
+                if return_futures:
+                    return futs
+                else:
+                    torch.futures.collect_all(futs).wait()
+
+            hook_pre_fwd = (
+                torch.nn.parallel.distributed._BufferCommHookLocation.PRE_FORWARD
+            )
+            hook_post_fwd = (
+                torch.nn.parallel.distributed._BufferCommHookLocation.POST_FORWARD
+            )
+            for hook_run_location in [
+                hook_pre_fwd,
+                hook_post_fwd,
+            ]:
+                model = NetWithBuffers().cuda(rank)
+                model_ddp = torch.nn.parallel.DistributedDataParallel(
+                    model,
+                    device_ids=[self.rank],
+                )
+                model_ddp._register_buffer_comm_hook(
+                    model_ddp, buffer_comm_hook, hook_run_location
+                )
+                model_ddp_no_hook = torch.nn.parallel.DistributedDataParallel(
+                    copy.deepcopy(model),
+                    device_ids=[self.rank],
+                    broadcast_buffers=False,
+                )
+                inp = torch.randn(2, 10, device=rank)
+                for _ in range(2):
+                    loss_hook = model_ddp(inp).sum()
+                    # Simulate the buffer allreduce for the no-hook model, matching
+                    # the hook run location (pre- or post-forward).
+                    if hook_run_location == hook_pre_fwd:
+                        model_no_hook_buffers = list(model_ddp_no_hook.module.buffers())
+                        for tensor in model_no_hook_buffers:
+                            dist.all_reduce(tensor)
+
+                    loss_no_hook = model_ddp_no_hook(inp).sum()
+                    if hook_run_location == hook_post_fwd:
+                        model_no_hook_buffers = list(model_ddp_no_hook.module.buffers())
+                        for tensor in model_no_hook_buffers:
+                            dist.all_reduce(tensor)
+                    torch.cuda.synchronize()
+
+                    # if return_futures, they are only awaited on by DDP
+                    # at the end of the backwards pass for maximum overlap.
+                    if not return_futures:
+                        self._verify_buffers_equal(model_ddp, model_ddp_no_hook)
+                    loss_hook.backward()
+                    loss_no_hook.backward()
+                    # Note that when custom hooks return futures, this
+                    # comparison is not expected to work when hook run location
+                    # is pre-forward pass. This is because the hook does async
+                    # communication and forward pass modifies the buffer without
+                    # appropriate synchronization. Therefore, if returning
+                    # futures from custom buffer hooks, it is advised to set
+                    # hook run location to post forward.
+                    if return_futures and hook_run_location == hook_post_fwd:
+                        self._verify_buffers_equal(model_ddp, model_ddp_no_hook)
+                dist.barrier()
+
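+        # As exercised by the helper above, a buffer comm hook receives
+        # (ddp_module, named_buffers), where named_buffers maps buffer names to
+        # tensors. The hook may either run the communication synchronously and
+        # return None, or return a list of futures; returned futures are only
+        # awaited by DDP at the end of the backward pass, so the post-forward
+        # location is the safer choice in that case. A rough sketch (the hook name
+        # is illustrative only):
+        #
+        #   def example_buffer_hook(ddp, named_buffers):
+        #       return [
+        #           dist.all_reduce(buf, group=ddp.process_group, async_op=True)
+        #           .get_future()
+        #           for buf in named_buffers.values()
+        #       ]
+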
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_buffer_hook_allreduce_return_future(self):
+            self._test_ddp_buffer_hook_allreduce(return_futures=True)
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_buffer_hook_allreduce(self):
+            self._test_ddp_buffer_hook_allreduce(return_futures=False)
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_broadcast_buffer_via_hook(self):
+            # test that _distributed_broadcast_coalesced via registered hook is
+            # equivalent to DDP's default broadcast coalesced.
+            rank = self.rank
+            torch.cuda.set_device(rank)
+            torch.manual_seed(rank)
+            torch.cuda.manual_seed(rank)
+
+            def buffer_comm_hook(ddp, named_buffers):
+                # named_buffers is a Dict[str, Tensor] representing a mapping
+                # from buffer name to buffer.
+                buffers = [buffer for (_, buffer) in named_buffers.items()]
+                ddp._default_broadcast_coalesced(buffers)
+
+            model = NetWithBuffers().cuda(rank)
+            model_ddp = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[self.rank],
+            )
+            model_ddp._register_buffer_comm_hook(model_ddp, buffer_comm_hook)
+            model_ddp_no_hook = torch.nn.parallel.DistributedDataParallel(
+                copy.deepcopy(model),
+                device_ids=[self.rank],
+            )
+            inp = torch.randn(2, 10, device=rank)
+            for _ in range(2):
+                loss_hook = model_ddp(inp).sum()
+                loss_no_hook = model_ddp_no_hook(inp).sum()
+                self._verify_buffers_equal(model_ddp, model_ddp_no_hook)
+                loss_hook.backward()
+                loss_no_hook.backward()
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_remove_autograd_hooks(self):
+            class SimulateError(torch.autograd.Function):
+                @staticmethod
+                def forward(ctx, input):
+                    return input
+
+                @staticmethod
+                def backward(ctx, grad_output):
+                    raise RuntimeError
+
+            class MyModel(nn.Module):
+                def __init__(self, device):
+                    super().__init__()
+                    self.error = True
+                    self.fc1 = nn.Linear(10, 10).cuda(device)
+
+                def forward(self, inp):
+                    if self.error:
+                        return self.fc1(SimulateError.apply(inp))
+                    else:
+                        return self.fc1(inp)
+
+            # Run with an error to trigger a backward pass that marks fc1 as ready.
+            # If we don't remove the autograd hooks before running below, it would
+            # fail on the old autograd hook.
+            model = MyModel(self.rank)
+            input = torch.rand(10, 10, requires_grad=True).cuda(self.rank)
+            model_ddp1 = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[self.rank],
+            )
+
+            with self.assertRaises(RuntimeError):
+                model_ddp1(input).sum().backward()
+
+            # Remove autograd hooks on old instance.
+            model_ddp1._remove_autograd_hooks()
+
+            # Try another DDP instance without error now.
+            model.error = False
+            model_ddp2 = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[self.rank],
+            )
+            model_ddp2(input).sum().backward()
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        @unittest.skip(
+            "Test is failing, tracking issue at https://github.com/pytorch/pytorch/issues/102751"
+        )
+        def test_ddp_has_finalized(self):
+            @dataclass
+            class MyClass:
+                obj: torch.Tensor
+
+            class MyModel(nn.Module):
+                def __init__(self, rank):
+                    super().__init__()
+                    self.rank = rank
+                    self.fc1 = nn.Linear(1024, 1024).cuda(rank)
+                    self.fc2 = nn.Linear(1024, 2 * 1024).cuda(rank)
+
+                def forward(self, inp):
+                    if self.rank == 0:
+                        return self.fc1(inp), MyClass(self.fc2(inp))
+                    else:
+                        return self.fc1(inp), self.fc2(inp)
+
+            model = MyModel(self.rank)
+            input = torch.rand(10, 1024, requires_grad=True).cuda(self.rank)
+            ddp = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[self.rank],
+                find_unused_parameters=True,
+                bucket_cap_mb=(1024 * 4 / 1024 / 1024),  # One bucket per parameter.
+            )
+
+            if self.rank == 0:
+                out1, _ = ddp(input)
+                out1.sum().backward()
+            else:
+                out1, out2 = ddp(input)
+                (out1.sum() + out2.sum()).backward()
+
+            if self.rank == 0:
+                with self.assertRaisesRegex(
+                    RuntimeError,
+                    "Expected to have finished reduction in the prior iteration",
+                ):
+                    ddp._check_reducer_finalized()
+
+                with self.assertRaisesRegex(
+                    RuntimeError,
+                    "Expected to have finished reduction in the prior iteration",
+                ):
+                    ddp(input)
+            else:
+                ddp._check_reducer_finalized()
+                ddp(input)
+
+        """
+        # The set of "test_ddp_update_process_group..." below failed after
+        # upgrading CI from 2 GPUs to 4 GPUs.
+        # Commented out for now.
+        # Test purpose needs better documentation.
+
+        def _run_ddp_update_process_group(self, new_pg):
+            def get_num_torch_recompiles():
+                guard_failures = torch._dynamo.utils.guard_failures
+                num_recompiles = [len(guard_failures[code]) for code in guard_failures]
+                return 0 if len(num_recompiles) == 0 else max(num_recompiles)
+
+            class SimulateError(torch.autograd.Function):
+                @staticmethod
+                def forward(ctx, input):
+                    return input
+
+                @staticmethod
+                def backward(ctx, grad_output):
+                    raise RuntimeError
+
+            class MyModel(torch.nn.Module):
+                def __init__(self, device):
+                    super().__init__()
+                    # 4MB for multiple buckets.
+                    self.fc1 = torch.nn.Linear(1024, 1024).cuda(device)
+                    self.fc2 = torch.nn.Linear(1024, 1024).cuda(device)
+                    self.fc3 = torch.nn.Linear(1024, 1024).cuda(device)
+
+                def forward(self, inp, error):
+                    if error:
+                        return self.fc3(self.fc2(self.fc1(SimulateError.apply(inp))))
+                    else:
+                        return self.fc3(self.fc2(self.fc1(inp)))
+
+
+            input = torch.rand(10, 1024, requires_grad=True).cuda(self.rank)
+            ddp = torch.nn.parallel.DistributedDataParallel(
+                MyModel(self.rank),
+                device_ids=[self.rank],
+                find_unused_parameters=True,
+                bucket_cap_mb=1,
+            )
+            model = torch.compile(ddp)
+
+            def run_iteration():
+                # Run regular iteration.
+                out = model(input, error=False)
+                out.sum().backward()
+                torch.cuda.synchronize()
+
+                # Run with error.
+                with self.assertRaises(RuntimeError):
+                    out = model(input, error=True)
+                    out.sum().backward()
+                torch.cuda.synchronize()
+
+            run_iteration()
+            assert 0 == get_num_torch_recompiles()
+
+            if new_pg:
+                # Now reduce world_size and run iteration.
+                group_size_2 = dist.new_group(ranks=[0, 1])
+                ddp._update_process_group(group_size_2)
+                if self.rank in [0, 1]:
+                    run_iteration()
+
+                # Increase the world size and run iteration.
+                group_size_3 = dist.new_group(ranks=[1, 2, 3])
+                ddp._update_process_group(group_size_3)
+                if self.rank in [1, 2, 3]:
+                    run_iteration()
+
+                # Back to default size.
+                ddp._update_process_group(_get_default_group())
+                run_iteration()
+            else:
+                # Create default pg of smaller size.
+                dist.destroy_process_group()
+
+                if self.rank in [1, 2, 3]:
+                    dist.init_process_group(
+                        init_method=self.init_method,
+                        backend=BACKEND,
+                        world_size=3,
+                        rank=self.rank - 1,
+                        timeout=timedelta(seconds=default_pg_timeout),
+                    )
+                    ddp._update_process_group(_get_default_group())
+                    run_iteration()
+                    dist.destroy_process_group()
+
+                # Need a barrier here to ensure ranks 1, 2 and 3 are done.
+                self._barrier(wait_for=4)
+
+                # Need to init pg again for "_barrier" to succeed.
+                dist.init_process_group(
+                    init_method=self.init_method,
+                    backend=BACKEND,
+                    world_size=4,
+                    rank=self.rank,
+                    timeout=timedelta(seconds=default_pg_timeout),
+                )
+
+            # Validate no more recompiles.
+            assert 0 == get_num_torch_recompiles()
+
+        @skip_if_lt_x_gpu(4)
+        @require_world_size(4)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_update_process_group_new_group(self):
+            self._run_ddp_update_process_group(new_pg=True)
+
+        @skip_if_lt_x_gpu(4)
+        @require_world_size(4)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_update_process_group_default_group(self):
+            self._run_ddp_update_process_group(new_pg=False)
+
+        @skip_if_lt_x_gpu(4)
+        @require_world_size(4)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_update_process_group_grad_undefined(self):
+            class SimulateError(torch.autograd.Function):
+                @staticmethod
+                def forward(ctx, input):
+                    return input
+
+                @staticmethod
+                def backward(ctx, grad_output):
+                    raise RuntimeError
+
+            class MyModel(torch.nn.Module):
+                def __init__(self, device):
+                    super().__init__()
+                    self.fc1 = torch.nn.Linear(10, 10).cuda(device)
+                    self.fc2 = torch.nn.Linear(10, 10).cuda(device)
+                    self.fc3 = torch.nn.Linear(10, 10).cuda(device)
+
+                def forward(self, inp, error):
+                    if error:
+                        return self.fc3(self.fc2(self.fc1(SimulateError.apply(inp))))
+                    else:
+                        return self.fc2(self.fc1(inp))
+
+
+            input = torch.rand(10, 10, requires_grad=True).cuda(self.rank)
+            ddp = torch.nn.parallel.DistributedDataParallel(
+                MyModel(self.rank),
+                device_ids=[self.rank],
+                find_unused_parameters=True,
+                bucket_cap_mb=1,
+            )
+
+            try:
+                ddp(input, True).sum().backward()
+            except RuntimeError:
+                ddp._update_process_group(_get_default_group())
+
+            # Reset grads.
+            for param in ddp.parameters():
+                param.grad = None
+
+            # Run ddp again.
+            ddp(input, False).sum().backward()
+
+        @skip_if_lt_x_gpu(4)
+        @require_world_size(4)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_update_process_group_no_find_unused(self):
+            ddp = torch.nn.parallel.DistributedDataParallel(
+                torch.nn.Linear(10, 10).cuda(self.rank),
+                device_ids=[self.rank],
+                find_unused_parameters=False,
+            )
+            ddp._update_process_group(_get_default_group())
+        """
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_broadcast_buffer(self):
+            rank = self.rank
+            torch.cuda.set_device(rank)
+            torch.manual_seed(rank)
+            torch.cuda.manual_seed(rank)
+
+            class NetWithBuffers(nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.a = nn.Linear(10, 10, bias=False)
+                    self.b = nn.Linear(10, 1, bias=False)
+                    self.register_buffer("buffer", torch.randn(1, 2))
+
+                def forward(self, x):
+                    return self.b(self.a(x))
+
+            model = NetWithBuffers().cuda(rank)
+            model_ddp = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[self.rank],
+            )
+            inp = torch.randn(2, 10, device=rank)
+            for _ in range(2):
+                if rank == 0:
+                    model_ddp.module.buffer = model_ddp.module.buffer + 1
+                loss = model_ddp(inp).sum()
+                loss.backward()
+                # Ensure all buffers are synchronized.
+                bufs = [
+                    torch.empty_like(model_ddp.module.buffer)
+                    for _ in range(dist.get_world_size())
+                ]
+                dist.all_gather(bufs, model_ddp.module.buffer)
+                rank_0_buf = bufs[0]
+                for buf in bufs[1:]:
+                    self.assertEqual(rank_0_buf, buf)
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl" and BACKEND != "gloo",
+            "Only Nccl & Gloo backend support DistributedDataParallel",
+        )
+        def test_static_graph_multi_forward(self):
+            class Net(nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.lin = nn.Linear(10, 10)
+                    self.relu = nn.ReLU()
+
+                def forward(self, x):
+                    return self.relu(self.lin(x))
+
+            torch.cuda.set_device(self.rank)
+            torch.manual_seed(42 << 1337 % (self.rank + 1))
+            model = Net().cuda(self.rank)
+            local_model = copy.deepcopy(model)
+            model = torch.nn.parallel.DistributedDataParallel(
+                model, device_ids=[self.rank], static_graph=True
+            )
+            inp = torch.ones(2, 10, device="cuda")
+            for _ in range(3):
+                model.zero_grad()
+                local_model.zero_grad()
+                a = model(inp)
+                b = model(inp)
+                loss = a.sum() + b.sum()
+                loss.backward()
+                # Grads should equal those of a local model that ran through inp
+                # `world_size` times and averaged its grads.
+                if self.rank == 0:
+                    inp_clone = inp.clone()
+                    iters = dist.get_world_size()
+                    for _ in range(iters):
+                        a = local_model(inp_clone)
+                        b = local_model(inp_clone)
+                        loss = a.sum() + b.sum()
+                        loss.backward()
+
+                    for p in local_model.parameters():
+                        p.grad.data = p.grad / iters
+
+                    for p_ddp, p_local in zip(
+                        model.parameters(), local_model.parameters()
+                    ):
+                        self.assertTrue(
+                            torch.allclose(p_ddp.grad, p_local.grad),
+                            f"{p_ddp.grad} vs {p_local.grad}",
+                        )
+
+            dist.barrier()
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND != "nccl" and BACKEND != "gloo",
+            "Only Nccl & Gloo backend support DistributedDataParallel",
+        )
+        def test_sync_bn_logged(self):
+            model = BN_NET
+            rank = self.rank
+            # single gpu training setup
+            model_gpu = model.cuda(rank)
+            no_sync_bn = torch.nn.parallel.DistributedDataParallel(
+                copy.deepcopy(model_gpu),
+                device_ids=[self.rank],
+            )
+            ddp_logging_data = no_sync_bn._get_ddp_logging_data()
+            sync_bn_logged = ddp_logging_data.get("has_sync_bn", True)
+            self.assertFalse(sync_bn_logged)
+            model_DDP = nn.SyncBatchNorm.convert_sync_batchnorm(model_gpu)
+            model_DDP = torch.nn.parallel.DistributedDataParallel(
+                model_DDP,
+                device_ids=[self.rank],
+            )
+            ddp_logging_data = model_DDP._get_ddp_logging_data()
+            sync_bn_logged = ddp_logging_data.get("has_sync_bn", False)
+            self.assertTrue(sync_bn_logged)
+
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_stateless_api_with_ddp(self):
+            class MockModule(torch.nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.l1 = torch.nn.Linear(1, 1)
+                    buffer = torch.ones(1)
+                    self.register_buffer("buffer", buffer)
+
+                def forward(self, x):
+                    return self.l1(x) + self.buffer
+
+            device = self.rank
+            module = MockModule().to(device)
+            module = torch.nn.parallel.DistributedDataParallel(
+                module, device_ids=[device]
+            )
+            x = torch.rand((1, 1)).to(device)
+            weight = torch.tensor([[1.0]], device=device, requires_grad=True)
+            bias = torch.tensor([0.0], device=device, requires_grad=True)
+            buffer = torch.tensor([0.0], device=device)
+            parameters = {
+                "module.l1.weight": weight,
+                "module.l1.bias": bias,
+                "module.buffer": buffer,
+            }
+            prev_weight = module.module.l1.weight.clone()
+            prev_buffer = module.module.buffer.clone()
+
+            res = torch.func.functional_call(module, parameters, x)
+            self.assertEqual(x, res)
+            # check that the weight and buffer remain unmodified
+            cur_weight = module.module.l1.weight
+            cur_buffer = module.module.buffer
+            self.assertEqual(cur_weight, prev_weight)
+            self.assertEqual(cur_buffer, prev_buffer)
+            # run a backward pass and check the gradients
+            res.backward()
+            self.assertIsNotNone(weight.grad)
+            self.assertIsNotNone(bias.grad)
+            # Gradients were not computed for the module state and buffers
+            self.assertIsNone(buffer.grad)
+            self.assertIsNone(module.module.l1.weight.grad)
+            self.assertIsNone(module.module.l1.bias.grad)
+            self.assertIsNone(module.module.buffer.grad)
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_forward_backward_hook(self):
+            class DummyTestModel(nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    torch.manual_seed(0)
+                    self.fc = nn.Linear(2, 2)
+
+                def forward(self, x):
+                    return self.fc(x)
+
+            def relu_hook(module, input):
+                return nn.functional.relu(input[0])
+
+            def gelu_hook(module, _input, output):
+                return nn.functional.gelu(output)
+
+            def celu_hook(module, _input, output):
+                return (nn.functional.celu(output[0]),)
+
+            local_model = DummyTestModel()
+            ddp_model = DummyTestModel()
+            local_model.fc.register_forward_pre_hook(relu_hook)
+            local_model.fc.register_forward_hook(gelu_hook)
+            ddp_model.fc.register_forward_pre_hook(relu_hook)
+            ddp_model.fc.register_forward_hook(gelu_hook)
+            local_model.fc.register_backward_hook(celu_hook)
+            ddp_model.fc.register_backward_hook(celu_hook)
+            ddp_model = DistributedDataParallel(
+                ddp_model.to(self.rank), device_ids=[self.rank]
+            )
+            input_data = torch.rand(5, 2)
+            output_local = local_model(input_data)
+            output_ddp = ddp_model(input_data.to(self.rank))
+            self.assertEqual(output_local, output_ddp)
+            output_local.sum().backward()
+            output_ddp.sum().backward()
+            ddp_grads = [p.grad for p in ddp_model.parameters()]
+            self.assertEqual(ddp_grads[0], local_model.fc.weight.grad)
+            self.assertEqual(ddp_grads[1], local_model.fc.bias.grad)
+
+        def _test_hook_pickling(self, hook, hook_state):
+            torch.manual_seed(0)
+            learning_rate = 0.01
+            chkpt_file = tempfile.gettempdir() + "/checkpoint.pt"
+            rank = self.rank
+
+            input = torch.randn(7, 1, device=rank)
+            target = torch.randn(7, 5, device=rank)
+            net = torch.nn.Linear(1, 5).to(rank)
+            ddp_model = DistributedDataParallel(copy.deepcopy(net), device_ids=[rank])
+            dummy_ddp_model = DistributedDataParallel(
+                copy.deepcopy(net), device_ids=[rank]
+            )
+            optimizer = torch.optim.SGD(ddp_model.parameters(), lr=learning_rate)
+            ddp_model.register_comm_hook(hook_state, hook)
+            ddp_model.train()
+
+            for _ in range(10):
+                optimizer.zero_grad()
+                out = ddp_model(input)
+                loss = F.mse_loss(out, target)
+                loss.backward()
+                optimizer.step()
+
+            state = {
+                "state_dict": ddp_model.state_dict(),
+                "comm_hook": hook,
+                "comm_hook_state": hook_state,
+            }
+
+            if rank == 0:
+                with self.assertLogs("torch.distributed") as captured:
+                    torch.save(state, chkpt_file)
+
+                # Check that the logger has only one entry
+                self.assertEqual(len(captured.records), 1)
+                # Check that the logger has an expected entry
+                self.assertEqual(
+                    captured.records[0].getMessage(),
+                    "NOTE: Process group is not serializable and excluded from a saved state.",
+                )
+
+            dist.barrier()
+            map_location = {"cuda:0": f"cuda:{rank:d}"}
+            with self.assertLogs("torch.distributed") as captured:
+                checkpoint = torch.load(chkpt_file, map_location=map_location)
+
+            # Check that the logger has only one entry
+            self.assertEqual(len(captured.records), 1)
+            # Check that the logger has an expected entry
+            self.assertEqual(
+                captured.records[0].getMessage(),
+                "NOTE: Process group will be set to a default group (i.e. the world size).\
+                If a different group is desired, please set `self.process_group` after PowerSGD state is loaded.",
+            )
+
+            dummy_ddp_model.load_state_dict(checkpoint["state_dict"])
+            dummy_hook = checkpoint["comm_hook"]
+            dummy_hook_state = checkpoint["comm_hook_state"]
+            dummy_optimizer = torch.optim.SGD(
+                dummy_ddp_model.parameters(), lr=learning_rate
+            )
+
+            # Check that loaded function is correct
+            self.assertEqual(dummy_hook.__qualname__, hook.__qualname__)
+
+            # Check that all slots' keys were restored correctly
+            self.assertEqual(hook_state.__slots__, dummy_hook_state.__slots__)
+
+            # Check that all slots' attributes are restored correctly
+            # Excluding ``process_group`` and ``rng``.
+            for entry in dummy_hook_state.__slots__:
+                if entry != "process_group" and entry != "rng":
+                    self.assertEqual(
+                        getattr(dummy_hook_state, entry), getattr(hook_state, entry)
+                    )
+
+            # Check that ``process_group`` was set to default
+            self.assertEqual(dummy_hook_state.process_group, _get_default_group())
+
+            # Check that a random state was restored properly:
+            # ``np.random.RandomState.get_state`` returns a tuple with entries:
+            # ``bit_generator`` - str,
+            # ``state.key`` - ndarray dtype[uint32],
+            # ``state.pos`` - int,
+            # ``has_gauss`` - int,
+            # ``gauss`` - float
+            #  (refer to https://github.com/numpy/numpy/blob/266aad7478bc7fbcc55eea7f942a0d373b838396/numpy/random/mtrand.pyi)
+            # To make sure random state was restored properly, all entries should equal the original
+            for entry1, entry2 in zip(
+                hook_state.rng.get_state(), dummy_hook_state.rng.get_state()
+            ):
+                np.testing.assert_array_equal(entry1, entry2)
+
+            dummy_ddp_model.register_comm_hook(dummy_hook_state, dummy_hook)
+            dummy_ddp_model.train()
+
+            for _ in range(10):
+                optimizer.zero_grad()
+                dummy_optimizer.zero_grad()
+                out_origin = ddp_model(input)
+                out_dummy = dummy_ddp_model(input)
+                loss_origin = F.mse_loss(out_origin, target)
+                loss_dummy = F.mse_loss(out_dummy, target)
+                loss_origin.backward()
+                loss_dummy.backward()
+                optimizer.step()
+                dummy_optimizer.step()
+
+            # Check that gradients after 10 epochs are the same
+            for orig_param, dummy_param in zip(
+                ddp_model.parameters(), dummy_ddp_model.parameters()
+            ):
+                self.assertEqual(orig_param.grad, dummy_param.grad)
+
+            dist.barrier()
+            if rank == 0:
+                os.remove(chkpt_file)
+
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["cuda"],
+            f"The {BACKEND} backend does not support DDP communication hook on CUDA devices",
+        )
+        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
+        @skip_but_pass_in_sandcastle_if(True, "Skipped due to flakiness")
+        def test_ddp_hook_pickling_powerSGD(self):
+            hook = powerSGD.powerSGD_hook
+            powersgd_state = powerSGD.PowerSGDState(
+                process_group=None,
+                matrix_approximation_rank=1,
+                start_powerSGD_iter=4,
+            )
+            self._test_hook_pickling(hook, powersgd_state)
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_ddp_device_mesh_initialization(self):
+            """
+            Test DDP with device_mesh initialization.
+            """
+            world_size = int(os.environ["WORLD_SIZE"])
+
+            from torch.distributed.device_mesh import init_device_mesh
+
+            device_mesh = init_device_mesh("cuda", (world_size,))
+
+            pg = _get_default_group()
+
+            torch.cuda.set_device(self.rank)
+            model = TwoLinLayerNet().cuda()
+            ddp_model = torch.nn.parallel.DistributedDataParallel(
+                model, device_mesh=device_mesh
+            )
+            self.assertEqual(ddp_model.device_mesh, device_mesh)
+
+            with self.assertRaisesRegex(
+                RuntimeError,
+                "Cannot specify both process_group and device_mesh arguments.",
+            ):
+                ddp_model = torch.nn.parallel.DistributedDataParallel(
+                    model, process_group=pg, device_mesh=device_mesh
+                )
+
+            with self.assertRaisesRegex(
+                RuntimeError, "Only 1D device mesh is supported,"
+            ):
+                device_mesh = init_device_mesh("cuda", (2, world_size // 2))
+                ddp_model = torch.nn.parallel.DistributedDataParallel(
+                    model, device_mesh=device_mesh
+                )
+
+        @skip_if_lt_x_gpu(2)
+        @require_world_size(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_compile_static_graph(self):
+            "Tests that DDP works with torch compile when static_graph=True"
+            model = torch.nn.Linear(10, 10).cuda(self.rank)
+            model_clone = copy.deepcopy(model)
+            ddp = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[self.rank],
+            )
+            ddp_static = torch.nn.parallel.DistributedDataParallel(
+                model_clone, device_ids=[self.rank], static_graph=True
+            )
+            ddp = torch.compile(ddp)
+            ddp_static = torch.compile(ddp_static)
+            input = torch.rand(10, 10).cuda(self.rank)
+            # verify output and gradient parity
+            for _ in range(6):
+                out_ddp = ddp(input).sum()
+                out_ddp_static = ddp_static(input).sum()
+                self.assertEqual(out_ddp, out_ddp_static)
+                out_ddp.backward()
+                out_ddp_static.backward()
+                for p1, p2 in zip(ddp.parameters(), ddp_static.parameters()):
+                    self.assertEqual(p1.grad, p2.grad)
+
+        @skip_if_lt_x_gpu(2)
+        @require_world_size(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_sink_noclone(self):
+            "Tests that we can configure DDP to avoid clone"
+
+            class OpPatcher(TorchDispatchMode):
+                def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+                    func_packet = func._overloadpacket
+                    if func_packet == torch.ops.aten.clone:
+                        raise RuntimeError("clone encountered!")
+                    kwargs = kwargs if kwargs else {}
+                    return func(*args, **kwargs)
+
+            class MyModel(torch.nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.fc = torch.nn.Linear(10, 10)
+
+                def forward(self, input):
+                    return self.fc(input)
+
+            model = MyModel().cuda(self.rank)
+            ddp = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[self.rank],
+                find_unused_parameters=True,
+            )
+            ddp._set_ddp_sink_clone(False)
+            input = torch.rand(10, 10).cuda(self.rank)
+
+            with OpPatcher():
+                ddp(input).sum().backward()
+
+        def _test_skip_all_reduce_unused_parameters(
+            self,
+            find_unused_parameters=False,
+            static_graph=False,
+            skip_all_reduce_unused_params=False,
+        ):
+            class LargeNet(nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+                    self.fc1 = nn.Linear(100, 5000, bias=False)
+                    # fc2 is unused
+                    self.fc2 = nn.Linear(100, 100, bias=False)
+
+                def forward(self, x):
+                    y = self.fc1(x)
+                    return y
+
+            torch.manual_seed(31415)
+            torch.cuda.set_device(self.rank)
+            model = LargeNet().cuda(self.rank)
+            ddp_model = torch.nn.parallel.DistributedDataParallel(
+                model,
+                find_unused_parameters=find_unused_parameters,
+                static_graph=static_graph,
+                bucket_cap_mb=1.5,
+                skip_all_reduce_unused_params=skip_all_reduce_unused_params,
+            )
+            random_input = torch.randn(20, 100, device=self.rank)
+            for _ in range(10):
+                out = ddp_model(random_input)
+                loss = out.sum()
+                loss.backward()
+            return ddp_model
+
+        @require_backend_is_available(DistTestCases.backend_feature["gpu"])
+        @skip_if_lt_x_gpu(2)
+        def test_skip_all_reduce_unused_parameters(self):
+            base_model = self._test_skip_all_reduce_unused_parameters(
+                find_unused_parameters=True, static_graph=False
+            )
+            test_model_1 = self._test_skip_all_reduce_unused_parameters(
+                find_unused_parameters=True,
+                static_graph=False,
+                skip_all_reduce_unused_params=True,
+            )
+
+            self.assertEqual(
+                base_model._get_ddp_logging_data().get("num_buckets_reduced"), 2
+            )
+            self.assertEqual(
+                test_model_1._get_ddp_logging_data().get("num_buckets_reduced"), 1
+            )
+
+            for i, j in zip(base_model.parameters(), test_model_1.parameters()):
+                self.assertEqual(i, j)
+
+
+instantiate_parametrized_tests(DistributedTest._DistTestBase)
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/distributed_utils.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/distributed_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..502a300493cea5aa7d1d244ecf67970faff9ece3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/distributed_utils.py
@@ -0,0 +1,70 @@
+# mypy: allow-untyped-defs
+
+from contextlib import contextmanager
+from datetime import timedelta
+from functools import partial, wraps
+
+import torch.distributed as dist
+import torch.distributed.distributed_c10d as c10d
+
+
+class MockProcessGroup(dist.ProcessGroup):
+    def __init__(self, rank, world):
+        super().__init__(rank, world)
+
+    def getBackendName(self):
+        return "mock_process_group"
+
+
+def create_mock_pg(prefix_store, rank, world_size, timeout):
+    return MockProcessGroup(rank, world_size)
+
+
+dist.Backend.register_backend("mock_process_group", create_mock_pg)
+
+
+def mock_init_dist(rank, world_size):
+    # !!! WARNING !!!
+    # Kids, don't try this at home: this is a cute pile of hacks that
+    # depends on a small mountain of c10d internals.
+    assert not dist.is_initialized()
+    store = dist.HashStore()
+    # Trick _store_based_barrier into believing everyone else has already checked in
+    # Zero is the group index
+    store.add(f"{c10d.STORE_BASED_BARRIER_PREFIX}:0", world_size - 1)
+    dist.init_process_group(
+        backend="mock_process_group",
+        rank=rank,
+        world_size=world_size,
+        store=store,
+        group_name="fake",
+        timeout=timedelta(seconds=1),
+    )
+
+
+@contextmanager
+def with_dist(rank=0, world_size=2):
+    """
+    Context manager that initializes c10d with a fake process group.
+    """
+    mock_init_dist(rank=rank, world_size=world_size)
+    try:
+        yield
+    finally:
+        dist.destroy_process_group()
+
+
+def with_fake_comms(func=None, rank=0, world_size=2):
+    """
+    Decorator that initializes a fake process group designed for testing.
+    Right now only querying the world size is supported.
+    """
+    if func is None:
+        return partial(with_fake_comms, rank=rank, world_size=world_size)
+
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        with with_dist(rank, world_size):
+            func(self, *args, **kwargs)
+
+    return wrapper
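+
+
+# Illustrative usage sketch for the helpers above; the test class and method
+# names are hypothetical and not part of this module:
+#
+#   class FakeCommsTest(unittest.TestCase):
+#       @with_fake_comms(rank=0, world_size=4)
+#       def test_world_size(self):
+#           self.assertEqual(dist.get_world_size(), 4)
+#
+#   # or, equivalently, using the context manager directly:
+#   with with_dist(rank=0, world_size=4):
+#       assert dist.get_world_size() == 4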
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/fake_pg.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/fake_pg.py
new file mode 100644
index 0000000000000000000000000000000000000000..a34ee75cf600e49e4af8247445e303a669ac1c12
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/fake_pg.py
@@ -0,0 +1,28 @@
+# mypy: allow-untyped-defs
+
+import torch.distributed as dist
+from torch._C._distributed_c10d import FakeProcessGroup
+
+
+class FakeStore(dist.Store):
+    """
+    A fake key-value store used only to initialize the fake process group;
+    one can use either FakeStore or HashStore.
+    """
+
+
+def _create_fake_pg(prefix_store, rank, world_size, timeout):
+    """
+    A fake process group (not related to FakeTensor) is a process group that
+    doesn't actually do any communication; it just hallucinates the
+    communication. You can run a single rank with a fake process group
+    without needing multiple processes (it simulates per-rank behavior).
+
+    NOTE: This is not a real process group, and it would produce wrong results
+    for every collective. It should only be used as a convenience when
+    experimenting with distributed code where the actual data does not matter.
+    """
+    return FakeProcessGroup(rank, world_size)
+
+
+dist.Backend.register_backend("fake", _create_fake_pg, devices=["cpu", "cuda", "hpu"])
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/multi_threaded_pg.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/multi_threaded_pg.py
new file mode 100644
index 0000000000000000000000000000000000000000..8de13414dd475f30d3c7f1b50cb00b1cd3322e52
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/multi_threaded_pg.py
@@ -0,0 +1,573 @@
+# mypy: allow-untyped-defs
+
+import sys
+import threading
+import weakref
+from dataclasses import dataclass
+from functools import partial, reduce
+from typing import Optional, Union
+
+import torch
+import torch.distributed as dist
+from torch._C._distributed_c10d import (
+    _create_work_from_future,
+    AllgatherOptions,
+    AllreduceOptions,
+    AllToAllOptions,
+    BarrierOptions,
+    BroadcastOptions,
+    ReduceOp,
+    ReduceScatterOptions,
+    ScatterOptions,
+    Store,
+)
+from torch.distributed.distributed_c10d import _CollOp, _store_based_barrier, P2POp
+from torch.futures import Future
+from torch.utils import _pytree as pytree
+
+
+"""
+TODO:
+Lots of missing collectives.
+Collectives validation.
+Make timeout robust by making collectives respect the test deadline.
+Make tests robust by making collectives interruptible.
+We need some synchronization around cleanup to ensure that timed-out ranks don't cause spurious failures.
+
+"""
+
+
+def flatten_list(lst):
+    return pytree.tree_leaves(lst)
+
+
+def ret_work(ret):
+    fut = Future()
+    fut.set_result(ret)
+    return _create_work_from_future(fut)
+
+
+def binop_reduce(tensors, op):
+    res = op(torch.stack(tensors), dim=0)
+    if isinstance(res, torch.Tensor):
+        return res
+    # min/max return a namedtuple
+    return res.values
+
+
+def bitwise_reduce(tensors, op):
+    return reduce(op, tensors)
+
+
+_reduce_ops = {
+    ReduceOp.SUM: partial(binop_reduce, op=torch.sum),
+    ReduceOp.AVG: partial(binop_reduce, op=torch.mean),
+    ReduceOp.PRODUCT: partial(binop_reduce, op=torch.prod),
+    ReduceOp.MIN: partial(binop_reduce, op=torch.min),
+    ReduceOp.MAX: partial(binop_reduce, op=torch.max),
+    ReduceOp.BAND: partial(bitwise_reduce, op=torch.bitwise_and),
+    ReduceOp.BOR: partial(bitwise_reduce, op=torch.bitwise_or),
+    ReduceOp.BXOR: partial(bitwise_reduce, op=torch.bitwise_xor),
+}
+
+
+class AllToAll:
+    @torch.no_grad()
+    def work(self, data):
+        world_size = len(data)
+        for dest_rank in range(world_size):
+            output_tensor_list, _ = data[dest_rank]
+            for src_rank in range(world_size):
+                _, input_tensor_list = data[src_rank]
+                output_tensor_list[src_rank].copy_(input_tensor_list[dest_rank])
+
+
+class AllToAllBase:
+    @torch.no_grad()
+    def work(self, data):
+        world_size = len(data)
+        for dest_rank in range(world_size):
+            output_buffer, _, output_split_sizes, _ = data[dest_rank]
+
+            output_indexes = self._size_cumsum(
+                output_buffer.size(0), output_split_sizes, world_size
+            )
+
+            for src_rank in range(world_size):
+                _, input_buffer, _, input_split_sizes = data[src_rank]
+                input_indexes = self._size_cumsum(
+                    input_buffer.size(0), input_split_sizes, world_size
+                )
+
+                output_buffer[
+                    output_indexes[src_rank] : output_indexes[src_rank + 1]
+                ].copy_(
+                    input_buffer[
+                        input_indexes[dest_rank] : input_indexes[dest_rank + 1]
+                    ]
+                )
+
+    def _size_cumsum(
+        self,
+        buf_size: int,
+        sizes: Union[torch.Tensor, list[int], None],
+        world_size: int,
+    ) -> torch.Tensor:
+        if sizes is None or len(sizes) == 0:
+            sizes = torch.full((world_size,), buf_size // world_size, dtype=torch.int64)
+        if not isinstance(sizes, torch.Tensor):
+            sizes = torch.tensor(sizes, dtype=torch.int64)
+        assert sizes.dtype == torch.int64
+        sizes = torch.cumsum(
+            torch.cat(
+                (torch.tensor([0], dtype=torch.int64, device=sizes.device), sizes),
+                dim=0,
+            ),
+            dim=0,
+        )
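+        # e.g. world_size=3 and sizes=[2, 3, 1] yield tensor([0, 2, 5, 6]);
+        # slice i of the flat buffer then spans [sizes[i], sizes[i + 1]).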
+        return sizes
+
+
+class AllReduce:
+    def __init__(self, op):
+        if op.op not in _reduce_ops:
+            raise NotImplementedError(
+                f"AllReduce op {op.op} not supported on multithreaded pg for now."
+            )
+        self.op = op.op
+
+    @torch.no_grad()
+    def work(self, data):
+        for i in range(len(data[0])):
+            # use rank 0's device as the device for the reduction
+            rank_0_device = data[0][i].device
+            # collect the tensors from every rank and move them
+            # all to the rank 0 device
+            tensors = [
+                data[src_rank][i].to(rank_0_device) for src_rank in range(0, len(data))
+            ]
+
+            # now mimic reduce across all ranks
+            res = _reduce_ops[self.op](tensors)
+
+            # copy the reduced value back to each rank
+            for src_rank in range(len(data)):
+                data[src_rank][i].copy_(res.to(data[src_rank][i].device))
+
+
+class AllGather:
+    @torch.no_grad()
+    def work(self, data):
+        for src_rank in range(len(data)):
+            in_tensor_list = data[src_rank][1]
+            # Can't handle all_gather with multiple tensors
+            assert len(in_tensor_list) == 1
+            src_tensor = in_tensor_list[0]
+
+            for dest in data:
+                dest_tensor = dest[0][0][src_rank]
+                dest_tensor.copy_(src_tensor)
+
+
+class Scatter:
+    def __init__(self, src):
+        self.src = src
+
+    @torch.no_grad()
+    def work(self, data):
+        src_in_tensor_list = data[self.src][1]
+        # Can't handle scatter with multiple input tensor lists
+        assert len(src_in_tensor_list) == 1
+        src_in_tensors = src_in_tensor_list[0]
+
+        for rank, each_rank_data in enumerate(data):
+            out_tensor_list = each_rank_data[0]
+            # Can't handle scatter with multiple output tensor lists
+            assert len(out_tensor_list) == 1
+            dest_tensor = out_tensor_list[0]
+            dest_tensor.copy_(src_in_tensors[rank])
+
+
+class Gather:
+    def __init__(self, dst):
+        self.dst = dst
+
+    @torch.no_grad()
+    def work(self, data):
+        # Can't handle gather with multiple tensor lists
+        assert len(data[self.dst][0]) == 1
+        out_tensor_list = data[self.dst][0][0]
+        for rank, each_rank_data in enumerate(data):
+            src_in_tensor_list = each_rank_data[1]
+            # Can't handle gather with multiple tensor lists
+            assert len(src_in_tensor_list) == 1
+            dest_tensor = out_tensor_list[rank]
+            dest_tensor.copy_(src_in_tensor_list[0])
+
+
+class ReduceScatter:
+    def __init__(self, op):
+        if op != dist.ReduceOp.SUM and op != dist.ReduceOp.AVG:
+            raise NotImplementedError(f"ReduceScatter does not support {op}")
+        self.op = op
+
+    @torch.no_grad()
+    def work(self, data):
+        start_reduction = [False for _ in range(len(data))]
+        for each_rank_data in data:
+            # Can't handle reduce_scatter with multiple scatter lists
+            assert len(each_rank_data[1]) == 1
+            to_scatter = each_rank_data[1][0]
+            for i in range(len(to_scatter)):
+                dest_tensor_on_rank_i = data[i][0]
+                # Can't handle reduce_scatter with multiple output tensors
+                assert len(dest_tensor_on_rank_i) == 1
+                dst_tensor_device = dest_tensor_on_rank_i[0].device
+                if not start_reduction[i]:
+                    dest_tensor_on_rank_i[0].copy_(to_scatter[i].to(dst_tensor_device))
+                    start_reduction[i] = True
+                else:
+                    dest_tensor_on_rank_i[0].add_(to_scatter[i].to(dst_tensor_device))
+        if self.op == dist.ReduceOp.AVG:
+            num_ranks = len(data)
+            for each_rank_data in data:
+                each_rank_data[0][0] /= num_ranks
+
+
+class Broadcast:
+    def __init__(self, src):
+        self.src = src
+
+    @torch.no_grad()
+    def work(self, data):
+        in_tensor_list = flatten_list(data[self.src])
+        for i in range(len(data)):
+            out_tensor_list = flatten_list(data[i])
+            for j in range(len(in_tensor_list)):
+                out_tensor_list[j].copy_(in_tensor_list[j])
+
+
+class Collective:
+    def __init__(self, world_size, collective, pg):
+        self._world_size = world_size
+        self._collective = collective
+
+        self._start_cond = threading.Condition()
+        self._done_cond = threading.Condition()
+
+        self._data = [None] * world_size
+        self._count = 0
+        self._done = False
+
+        self._pg = pg
+
+    def join(self, rank, data):
+        with self._start_cond:
+            self._data[rank] = data
+            self._count += 1
+
+            # notify rank 0
+            if self._count == self._world_size:
+                if rank > 0:
+                    self._start_cond.notify()
+
+            if rank == 0:
+                self._start_cond.wait_for(
+                    lambda: self._count == self._world_size
+                    or self._pg._terminate.is_set()
+                )
+                # SystemExit derives from BaseException rather than Exception, so it
+                # can be distinguished from normal exceptions raised by program errors
+                # and hidden from the exception queue.
+                if self._pg._terminate.is_set():
+                    sys.exit("Test termination event occurs.")
+
+        with self._done_cond:
+            # wait for rank 0 to finish
+            if rank > 0:
+                self._done_cond.wait_for(
+                    lambda: self._done or self._pg._terminate.is_set()
+                )
+                if self._pg._terminate.is_set():
+                    sys.exit("Test termination event occurs.")
+            else:
+                # copy data around
+                self._collective.work(self._data)
+                self._done = True
+                self._done_cond.notify_all()
+        return ret_work(data)
+
+
+class ProcessLocalGroup(dist.ProcessGroup):
+    _coll_lock = threading.Lock()
+    _cur_coll_on_pgs = {}
+
+    _terminate = threading.Event()
+
+    @classmethod
+    def _start_coll(cls, collective, pg):
+        with cls._coll_lock:
+            # pg_name is unique, we use that to record the mapping between pg and collective
+            if pg.pg_name not in cls._cur_coll_on_pgs:
+                cls._cur_coll_on_pgs[pg.pg_name] = Collective(
+                    pg.size(), collective, cls
+                )
+            return cls._cur_coll_on_pgs[pg.pg_name]
+
+    @classmethod
+    def _end_coll(cls, collective, pg):
+        # This is racily called by all ranks, so only one will work
+        with cls._coll_lock:
+            if (
+                pg.pg_name in cls._cur_coll_on_pgs
+                and cls._cur_coll_on_pgs[pg.pg_name] == collective
+            ):
+                cls._cur_coll_on_pgs.pop(pg.pg_name)
+
+    @classmethod
+    def exception_handle(cls, exc):
+        cls._terminate.set()
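+        # Wake up any ranks blocked inside Collective.join so they can observe
+        # the terminate flag and exit.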
+        for coll in cls._cur_coll_on_pgs.values():
+            with coll._start_cond:
+                coll._start_cond.notify()
+            with coll._done_cond:
+                coll._done_cond.notify_all()
+
+    @classmethod
+    def reset(cls):
+        with cls._coll_lock:
+            cls._cur_coll_on_pgs = {}
+            cls._terminate.clear()
+
+    def alltoall_base(
+        self,
+        output_buffer: torch.Tensor,
+        input_buffer: torch.Tensor,
+        output_split_sizes: Optional[list[int]],
+        input_split_sizes: Optional[list[int]],
+        opts=AllToAllOptions(),
+    ) -> torch.Tensor:
+        coll = ProcessLocalGroup._start_coll(AllToAllBase(), self)
+        res = coll.join(
+            self._rank,
+            (output_buffer, input_buffer, output_split_sizes, input_split_sizes),
+        )
+        ProcessLocalGroup._end_coll(coll, self)
+        return res
+
+    def alltoall(self, output_tensor_list, input_tensor_list, opts=AllToAllOptions()):
+        coll = ProcessLocalGroup._start_coll(AllToAll(), self)
+        res = coll.join(self._rank, (output_tensor_list, input_tensor_list))
+        ProcessLocalGroup._end_coll(coll, self)
+        return res
+
+    def allreduce(self, tensor_list, opts=AllreduceOptions()):
+        coll = ProcessLocalGroup._start_coll(AllReduce(opts.reduceOp), self)
+        res = coll.join(self._rank, tensor_list)
+        ProcessLocalGroup._end_coll(coll, self)
+        return res
+
+    def allreduce_coalesced(self, tensor_list, opts=AllreduceOptions()):
+        coll = ProcessLocalGroup._start_coll(AllReduce(opts.reduceOp), self)
+        res = coll.join(self._rank, tensor_list)
+        ProcessLocalGroup._end_coll(coll, self)
+        return res
+
+    def barrier(self, opts=BarrierOptions()):
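+        # Emulate a barrier with an all-reduce of a dummy tensor: every rank has
+        # to join the collective before any rank's barrier returns.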
+        return self.allreduce(tensor_list=[torch.ones(1)])
+
+    def allgather(self, output_tensors, input_tensor, opts=AllgatherOptions()):
+        coll = ProcessLocalGroup._start_coll(AllGather(), self)
+        res = coll.join(self._rank, (output_tensors, input_tensor))
+        ProcessLocalGroup._end_coll(coll, self)
+        return res
+
+    def _allgather_base(self, output_tensor, input_tensor, opts=AllgatherOptions()):
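+        # View the flat output tensor as world_size chunks so that the list-based
+        # allgather below fills each rank's slot in place.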
+        tensor_list = list(torch.chunk(output_tensor, self._world_size))
+        return self.allgather([tensor_list], [input_tensor], opts)
+
+    def broadcast(self, tensor_list, opts=BroadcastOptions()):
+        coll = ProcessLocalGroup._start_coll(Broadcast(opts.rootRank), self)
+        res = coll.join(self._rank, tensor_list)
+        ProcessLocalGroup._end_coll(coll, self)
+        return res
+
+    def scatter(self, output_tensors, input_tensors, opts=ScatterOptions()):
+        coll = ProcessLocalGroup._start_coll(Scatter(opts.rootRank), self)
+        res = coll.join(self._rank, (output_tensors, input_tensors))
+        ProcessLocalGroup._end_coll(coll, self)
+        return res
+
+    def gather(self, output_tensors, input_tensors, opts=ScatterOptions()):
+        coll = ProcessLocalGroup._start_coll(Gather(opts.rootRank), self)
+        res = coll.join(self._rank, (output_tensors, input_tensors))
+        ProcessLocalGroup._end_coll(coll, self)
+        return res
+
+    def reduce_scatter(self, output_tensor, scatter_list, opts=ReduceScatterOptions()):
+        coll = ProcessLocalGroup._start_coll(ReduceScatter(opts.reduceOp), self)
+        res = coll.join(self._rank, (output_tensor, scatter_list))
+        ProcessLocalGroup._end_coll(coll, self)
+        return res
+
+    def _reduce_scatter_base(
+        self, output_tensor, input_tensor, opts=ReduceScatterOptions()
+    ):
+        tensor_list = list(torch.chunk(input_tensor, self._world_size))
+        return self.reduce_scatter([output_tensor], [tensor_list], opts)
+
+    def reduce_scatter_tensor_coalesced(
+        self, output_tensors, input_tensors, opts=ReduceScatterOptions()
+    ):
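+        # Issue one reduce_scatter per tensor pair, wait on all but the last work
+        # object, and hand the last one back for the caller to wait on.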
+        works = [
+            self._reduce_scatter_base(output_tensor, input_tensor, opts)
+            for output_tensor, input_tensor in zip(output_tensors, input_tensors)
+        ]
+        for work in works[:-1]:
+            work.wait()
+        return works[-1]
+
+    def allgather_into_tensor_coalesced(
+        self, output_tensor_list, input_tensor_list, opts=AllgatherOptions()
+    ):
+        res = None
+        for o_t, i_t in zip(output_tensor_list, input_tensor_list):
+            res = self._allgather_base(o_t, i_t)
+        return res
+
+    def __init__(self, rank, world_size):
+        super().__init__(rank, world_size)
+        self._rank = rank
+        self._world_size = world_size
+        world = dist.distributed_c10d._world
+        if isinstance(world, ThreadLocalWorld):
+            world = world._get_world()
+        self._world = weakref.ref(world)
+        self._ctx = torch.autograd.set_multithreading_enabled(False)
+
+    def size(self):
+        return self._world_size
+
+    @property
+    def pg_name(self):
+        """
+        Return the globally registered name of the current pg in the world.
+        """
+        return self._world().pg_names[self]
+
+    @property
+    def group_name(self):
+        return self.pg_name
+
+    def getBackendName(self):
+        return "threaded"
+
+    def __repr__(self):
+        return f"ThreadedPG world_size:{self._world_size} rank:{self._rank}"
+
+
+def _create_threaded_pg(prefix_store, rank, world_size, timeout):
+    pg = ProcessLocalGroup(rank, world_size)
+    # https://github.com/pytorch/pytorch/pull/103033 made the store-based barrier optional.
+    # When a device mesh involves sub-groups and the store-based barrier is not enabled
+    # in c10d, different threads may be initializing different groups concurrently
+    # (even though the threaded pg's collectives themselves are assumed to run
+    # single-threaded), leading to race conditions.
+    # For example, with a mesh of [[0, 1], [2, 3]], the sub-groups along dim 0 and dim 1
+    # would be initialized independently in different threads.
+    # In this case we can no longer rely on class or global variables; we have to rely
+    # on the store-based barrier to make sure each group is ready separately before
+    # we can invoke collectives in any of the groups.
+
+    # the prefix store is already per group so we pass an empty name here
+    _store_based_barrier(rank, prefix_store, "", world_size, timeout)
+    return pg
+
+
+dist.Backend.register_backend("threaded", _create_threaded_pg, devices=["cpu", "cuda"])
+
+
+@dataclass
+class WorldData:
+    default_pg: dist.ProcessGroup
+    pg_map: dict[dist.ProcessGroup, tuple[str, Optional[Store]]]
+    pg_names: dict[dist.ProcessGroup, str]
+    pg_group_ranks: dict[dist.ProcessGroup, dict[int, int]]
+    pg_backend_config: dict[dist.ProcessGroup, str]
+    group_count: int
+    tags_to_pg: dict[str, list[dist.ProcessGroup]]
+    pg_to_tag: dict[dist.ProcessGroup, str]
+    pg_coalesce_state: dict[dist.ProcessGroup, list[Union[_CollOp, P2POp]]]
+
+
+class ThreadLocalWorld:
+    _world = threading.local()
+
+    def _get_world(self) -> WorldData:
+        if not hasattr(ThreadLocalWorld._world, "world"):
+            ThreadLocalWorld._world.world = WorldData(
+                None, {}, {}, {}, {}, 0, {}, {}, {}
+            )
+        return ThreadLocalWorld._world.world
+
+    @property
+    def default_pg(self):
+        return self._get_world().default_pg
+
+    @default_pg.setter
+    def default_pg(self, value):
+        self._get_world().default_pg = value
+
+    @property
+    def pg_map(self):
+        return self._get_world().pg_map
+
+    @property
+    def pg_names(self):
+        return self._get_world().pg_names
+
+    @property
+    def pg_group_ranks(self):
+        return self._get_world().pg_group_ranks
+
+    @property
+    def pg_backend_config(self):
+        return self._get_world().pg_backend_config
+
+    @property
+    def group_count(self) -> int:
+        return self._get_world().group_count
+
+    @group_count.setter
+    def group_count(self, value):
+        self._get_world().group_count = value
+
+    @property
+    def tags_to_pg(self):
+        return self._get_world().tags_to_pg
+
+    @property
+    def pg_to_tag(self):
+        return self._get_world().pg_to_tag
+
+    @property
+    def pg_coalesce_state(self) -> dict[dist.ProcessGroup, list[Union[_CollOp, P2POp]]]:
+        return self._get_world().pg_coalesce_state
+
+
+_old_pg_world = None
+_ctx_manager = None
+
+
+def _install_threaded_pg():
+    global _old_pg_world
+    global _ctx_manager
+    _old_pg_world = dist.distributed_c10d._world
+    dist.distributed_c10d._world = ThreadLocalWorld()
+    _ctx_manager = torch.autograd.set_multithreading_enabled(False)
+
+    return dist.distributed_c10d._world
+
+
+def _uninstall_threaded_pg():
+    dist.distributed_c10d._world = _old_pg_world
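+
+
+# Illustrative usage sketch, assuming one Python thread per simulated rank and a
+# shared store (the driver code below is hypothetical, not part of this module):
+#
+#   _install_threaded_pg()
+#   store = dist.HashStore()
+#   # each worker thread r then runs:
+#   #   dist.init_process_group("threaded", rank=r, world_size=n, store=store)
+#   #   ... collectives ...
+#   #   dist.destroy_process_group()
+#   _uninstall_threaded_pg()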
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/nn/__init__.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/nn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/nn/api/__init__.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/nn/api/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/nn/api/remote_module_test.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/nn/api/remote_module_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..79c55f5b8847b1d872985ab49c40cbf053955f86
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/nn/api/remote_module_test.py
@@ -0,0 +1,752 @@
+# mypy: allow-untyped-defs
+
+import enum
+
+import torch
+import torch.distributed.rpc as rpc
+import torch.testing._internal.dist_utils as dist_utils
+from torch import nn, Tensor
+from torch._jit_internal import Future
+from torch.distributed.nn import RemoteModule
+from torch.distributed.nn.api.remote_module import (
+    _REMOTE_MODULE_PICKLED_ATTRIBUTES,
+    _RemoteModule,
+)
+from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
+from torch.testing._internal.common_utils import TemporaryFileName, TEST_WITH_ROCM
+from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
+    RpcAgentTestFixture,
+)
+
+
+_PARAM_VAL = torch.nn.Parameter(torch.ones(1))
+
+
+# RPC handler for querying the device on the destination worker.
+def remote_device(module_rref):
+    for param in module_rref.local_value().parameters():
+        return param.device
+
+
+# RPC handler for querying __dict__ on the destination worker.
+def remote_module_attributes(remote_module):
+    return remote_module.__dict__
+
+
+# RPC handler for running forward on the destination worker.
+def remote_forward(remote_module, args):
+    return remote_module.forward(*args)
+
+
+# RPC handler for running forward_async on the destination worker.
+def remote_forward_async(remote_module, args):
+    # Since a future cannot be pickled and sent over the RPC layer,
+    # we have to wait here and behave just like ``forward_sync``.
+    return remote_module.forward_async(*args).wait()
+
+
+# RPC handler for getting training mode on the destination worker.
+def get_remote_training_arg(module_rref):
+    return module_rref.local_value().training
+
+
+class ModuleCreationMode(enum.Enum):
+    MODULE_CTOR_WITH_INTERFACE = "module_ctor_with_interface"
+    MODULE_CTOR = "module_ctor"
+
+
+@torch.jit.interface
+class MyModuleInterface:
+    def forward(
+        self, tensor: Tensor, number: int, word: str = "default"
+    ) -> tuple[str, int, Tensor]:
+        # pyre-ignore[7]: Pyre and torch.jit.interface don't mix well
+        pass
+
+
+@torch.jit.interface
+class RemoteMyModuleInterface:
+    def forward(
+        self, tensor: Tensor, number: int, word: str = "default"
+    ) -> tuple[str, int, Tensor]:
+        # pyre-ignore[7]: Pyre and torch.jit.interface don't mix well
+        pass
+
+    def forward_async(
+        self, tensor: Tensor, number: int, word: str = "default"
+    ) -> Future[tuple[str, int, Tensor]]:
+        pass
+
+
+class MyModule(nn.Module):
+    def __init__(self, first_arg, first_kwarg=-1):
+        super().__init__()
+        self.param1 = _PARAM_VAL
+
+    def forward(
+        self, tensor: Tensor, number: int, word: str = "default"
+    ) -> tuple[str, int, Tensor]:
+        return word, number, tensor
+
+
+class BadModule:
+    def __init__(self, first_arg, first_kwarg=-1):
+        pass
+
+
+def create_scripted_module(first_arg, first_kwarg=-1):
+    module = MyModule(first_arg, first_kwarg=first_kwarg)
+    scripted_module = torch.jit.script(module)
+    return scripted_module
+
+
+# Common utils for both CPU and CUDA test suites
+class CommonRemoteModuleTest(RpcAgentTestFixture):
+    @property
+    def world_size(self):  # Override setting in RpcAgentTestFixture
+        return 2
+
+    @staticmethod
+    def _create_remote_module_iter(remote_device, modes=None):
+        if modes is None:
+            modes = ModuleCreationMode.__members__.values()
+
+        args = (1,)
+        kwargs = dict(first_kwarg=2)
+
+        if ModuleCreationMode.MODULE_CTOR in modes:
+            remote_module = RemoteModule(remote_device, MyModule, args, kwargs)
+            yield remote_module
+
+        if ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE in modes:
+            remote_module = _RemoteModule(
+                remote_device,
+                create_scripted_module,
+                args,
+                kwargs,
+                _module_interface_cls=MyModuleInterface,
+            )
+            scripted_remote_module = torch.jit.script(remote_module)
+            yield scripted_remote_module
+
+
+class RemoteModuleTest(CommonRemoteModuleTest):
+    @dist_utils.dist_init
+    def test_bad_module(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+        remote_device = f"{dst_worker_name}/cpu"
+        args = (1,)
+        kwargs = dict(first_kwarg=2)
+
+        with self.assertRaisesRegex(
+            ValueError,
+            r"Expect `module_cls\(\*args, \*\*kwargs\)` returns an instance of ,",
+        ):
+            RemoteModule(remote_device, BadModule, args, kwargs).forward()
+
+        with self.assertRaisesRegex(
+            ValueError,
+            r"Expect `module_cls\(\*args, \*\*kwargs\)` returns an instance of ,",
+        ):
+            RemoteModule(remote_device, BadModule, args, kwargs).forward()
+
+    @dist_utils.dist_init
+    def test_forward_async(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+        args = (torch.ones(1), 2, "3")
+        for remote_module in self._create_remote_module_iter(dst_worker_name):
+            ret_fut = remote_module.forward_async(*args)
+            ret = ret_fut.wait()
+            self.assertEqual(ret, tuple(reversed(args)))
+
+    @dist_utils.dist_init
+    def test_forward_async_script(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+
+        scripted_remote_module = next(
+            self._create_remote_module_iter(
+                dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE]
+            )
+        )
+
+        @torch.jit.script
+        def run_forward_async(scripted_remote_module: RemoteMyModuleInterface):
+            ret_fut = scripted_remote_module.forward_async(torch.ones(1), 2, "3")
+            ret = ret_fut.wait()
+            return ret
+
+        ret = run_forward_async(scripted_remote_module)
+
+        self.assertEqual(ret, ("3", 2, torch.ones(1)))
+
+    @dist_utils.dist_init
+    def test_forward_sync(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+        args = (torch.ones(1), 2, "3")
+        for remote_module in self._create_remote_module_iter(dst_worker_name):
+            ret = remote_module.forward(*args)
+            self.assertEqual(ret, tuple(reversed(args)))
+
+    @dist_utils.dist_init
+    def test_forward_sync_script(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+
+        scripted_remote_module = next(
+            self._create_remote_module_iter(
+                dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE]
+            )
+        )
+
+        @torch.jit.script
+        def run_forward(scripted_remote_module: MyModuleInterface):
+            ret = scripted_remote_module.forward(torch.ones(1), 2, "3")
+            return ret
+
+        ret = run_forward(scripted_remote_module)
+
+        self.assertEqual(ret, ("3", 2, torch.ones(1)))
+
+    @dist_utils.dist_init
+    def test_forward_with_kwargs(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+        args = (torch.ones(1), 2)
+        kwargs = dict(word="3")
+        # Only test Python nn.Module, because script module methods don't support taking kwargs.
+        for remote_module in self._create_remote_module_iter(
+            dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]
+        ):
+            ret_fut = remote_module.forward_async(*args, **kwargs)
+            ret = ret_fut.wait()
+            self.assertEqual(ret, tuple(reversed(args + ("3",))))
+
+            ret = remote_module.forward(*args, **kwargs)
+            self.assertEqual(ret, tuple(reversed(args + ("3",))))
+
+    @dist_utils.dist_init
+    def test_remote_parameters(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+
+        # Only test Python nn.Module, because script module methods don't support ``remote_parameters``.
+        for remote_module in self._create_remote_module_iter(
+            dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]
+        ):
+            param_rrefs = remote_module.remote_parameters()
+            self.assertEqual(len(param_rrefs), 1)
+            self.assertTrue(torch.equal(param_rrefs[0].to_here(), _PARAM_VAL))
+
+    @dist_utils.dist_init
+    def test_get_module_rref(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+
+        # Only test Python nn.Module, because script module methods don't support ``get_module_rref``.
+        for remote_module in self._create_remote_module_iter(
+            dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]
+        ):
+            rref = remote_module.get_module_rref()
+            self.assertEqual(rref, remote_module.module_rref)
+            for param in rref.to_here().parameters():
+                self.assertTrue(torch.equal(param, _PARAM_VAL))
+
+    @dist_utils.dist_init
+    def test_train_eval(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+
+        for remote_module in self._create_remote_module_iter(
+            dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]
+        ):
+            remote_module.train()
+            ret1 = rpc.rpc_sync(
+                dst_worker_name,
+                get_remote_training_arg,
+                args=(remote_module.get_module_rref(),),
+            )
+            self.assertEqual(ret1, True)
+
+            remote_module.eval()
+            ret2 = rpc.rpc_sync(
+                dst_worker_name,
+                get_remote_training_arg,
+                args=(remote_module.get_module_rref(),),
+            )
+            self.assertEqual(ret2, False)
+
+    @dist_utils.dist_init
+    def test_unsupported_methods(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+
+        for remote_module in self._create_remote_module_iter(
+            dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]
+        ):
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``register_buffer`` not supported for RemoteModule"
+            ):
+                remote_module.register_buffer("buffer", torch.ones(5))
+            with self.assertRaisesRegex(
+                ValueError,
+                r"Method ``register_parameter`` not supported for RemoteModule",
+            ):
+                remote_module.register_parameter(
+                    "param", torch.nn.Parameter(torch.ones(1))
+                )
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``add_module`` not supported for RemoteModule"
+            ):
+                remote_module.add_module("empty", None)
+
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``apply`` not supported for RemoteModule"
+            ):
+                fn = torch.rand((3, 3), requires_grad=False)
+                remote_module.apply(fn)
+
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``cuda`` not supported for RemoteModule"
+            ):
+                remote_module.cuda()
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``cpu`` not supported for RemoteModule"
+            ):
+                remote_module.cpu()
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``type`` not supported for RemoteModule"
+            ):
+                remote_module.type(torch.FloatTensor)
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``float`` not supported for RemoteModule"
+            ):
+                remote_module.float()
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``double`` not supported for RemoteModule"
+            ):
+                remote_module.double()
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``bfloat16`` not supported for RemoteModule"
+            ):
+                remote_module.bfloat16()
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``to`` not supported for RemoteModule"
+            ):
+                remote_module.to("cpu", dtype=torch.int32)
+
+            def hook(module, grad_input, grad_output):
+                pass
+
+            with self.assertRaisesRegex(
+                ValueError,
+                r"Method ``register_backward_hook`` not supported for RemoteModule",
+            ):
+                remote_module.register_backward_hook(hook)
+            with self.assertRaisesRegex(
+                ValueError,
+                r"Method ``register_forward_pre_hook`` not supported for RemoteModule",
+            ):
+                remote_module.register_forward_pre_hook(hook)
+            with self.assertRaisesRegex(
+                ValueError,
+                r"Method ``register_forward_hook`` not supported for RemoteModule",
+            ):
+                remote_module.register_forward_hook(hook)
+
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``state_dict`` not supported for RemoteModule"
+            ):
+                remote_module.state_dict()
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``load_state_dict`` not supported for RemoteModule"
+            ):
+                remote_module.load_state_dict({})
+
+            with self.assertRaisesRegex(
+                ValueError,
+                r"Method ``parameters`` not supported for RemoteModule. Please use ``remote_parameters`` instead.",
+            ):
+                remote_module.parameters()
+            with self.assertRaisesRegex(
+                ValueError,
+                r"Method ``named_parameters`` not supported for RemoteModule",
+            ):
+                remote_module.named_parameters()
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``buffers`` not supported for RemoteModule"
+            ):
+                remote_module.buffers()
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``named_buffers`` not supported for RemoteModule"
+            ):
+                remote_module.named_buffers()
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``children`` not supported for RemoteModule"
+            ):
+                remote_module.children()
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``named_children`` not supported for RemoteModule"
+            ):
+                remote_module.named_children()
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``modules`` not supported for RemoteModule"
+            ):
+                remote_module.modules()
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``named_modules`` not supported for RemoteModule"
+            ):
+                remote_module.named_modules()
+
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``requires_grad_`` not supported for RemoteModule"
+            ):
+                remote_module.requires_grad_()
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``zero_grad`` not supported for RemoteModule"
+            ):
+                remote_module.zero_grad()
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``share_memory`` not supported for RemoteModule"
+            ):
+                remote_module.share_memory()
+            with self.assertRaisesRegex(
+                ValueError, r"Method ``extra_repr`` not supported for RemoteModule"
+            ):
+                remote_module.extra_repr()
+
+    @dist_utils.dist_init
+    def test_send_remote_module_with_a_new_attribute_not_pickled_over_the_wire(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+
+        # If a new attribute is added to this RemoteModule after initialization
+        # and the module is then sent over the wire by RPC,
+        # the new field will not be pickled, because it is not listed in _REMOTE_MODULE_PICKLED_ATTRIBUTES.
+        # Note that adding a new attribute outside of the constructor should rarely happen.
+        # If a new attribute is added to the RemoteModule constructor,
+        # a sanity check forces developers to add it to either
+        # _REMOTE_MODULE_PICKLED_ATTRIBUTES or _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING.
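+        # (Conceptually, the pickling follows an attribute allow-list. As a rough,
+        # illustrative sketch only, not the actual RemoteModule implementation,
+        # the pickled state is roughly:
+        #     state = {k: getattr(remote_module, k) for k in _REMOTE_MODULE_PICKLED_ATTRIBUTES}
+        # so any attribute outside the allow-list, like the one added below, is dropped.)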
+        for remote_module in self._create_remote_module_iter(
+            dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]
+        ):
+            new_attr_name = "new_attr"
+            setattr(remote_module, new_attr_name, 1)
+
+            attrs = rpc.rpc_sync(
+                dst_worker_name, remote_module_attributes, (remote_module,)
+            )
+            self.assertNotIn(new_attr_name, attrs)
+
+    @dist_utils.dist_init
+    def test_remote_module_py_pickle_not_supported(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+
+        for remote_module in self._create_remote_module_iter(
+            dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]
+        ):
+            with TemporaryFileName() as fname:
+                with self.assertRaisesRegex(
+                    RuntimeError,
+                    "Cannot pickle RemoteModule in python pickler. RemoteModule can only be pickled when using RPC",
+                ):
+                    torch.save(remote_module, fname)
+
+    @dist_utils.dist_init
+    def test_remote_module_py_pickle_not_supported_script(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+
+        for remote_module in self._create_remote_module_iter(
+            dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE]
+        ):
+            with TemporaryFileName() as fname:
+                with self.assertRaisesRegex(
+                    torch.jit.Error, "can only be pickled when using RPC"
+                ):
+                    torch.save(remote_module, fname)
+
+
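+# Tests that need three workers: a remote module created on one worker is passed to,
+# or re-created on, another worker over the RPC layer.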
+class ThreeWorkersRemoteModuleTest(CommonRemoteModuleTest):
+    @property
+    def world_size(self):  # Override setting in CommonRemoteModuleTest
+        return 3
+
+    @dist_utils.dist_init
+    def test_send_remote_module_over_the_wire(self):
+        if self.rank != 0:
+            return
+        dst_worker1_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+        dst_worker2_name = dist_utils.worker_name((self.rank + 2) % self.world_size)
+
+        # Unpickled attributes include both the inherent attributes of RemoteModule
+        # (not inherited from the superclass) and two installed methods.
+        expected_unpickled_attrs = list(_REMOTE_MODULE_PICKLED_ATTRIBUTES)
+        expected_unpickled_attrs.append("forward_async")
+        expected_unpickled_attrs.append("forward")
+
+        # Create a remote module on worker1 and then pass it to worker2 over the RPC layer.
+        for remote_module in self._create_remote_module_iter(
+            dst_worker1_name, modes=[ModuleCreationMode.MODULE_CTOR]
+        ):
+            # Test querying some simple attributes from worker2.
+            attrs = rpc.rpc_sync(
+                dst_worker2_name, remote_module_attributes, (remote_module,)
+            )
+            self.assertListEqual(list(attrs.keys()), expected_unpickled_attrs)
+            self.assertEqual(attrs["on"], "worker1")
+            self.assertEqual(attrs["device"], "cpu")
+            self.assertFalse(attrs["is_device_map_set"])
+            self.assertFalse(attrs["is_scriptable"])
+
+            # Test that the methods installed on worker1 can be invoked by worker2 over the RPC layer.
+            # NOTE: In practice a remote module should be stored directly on the worker that runs ``forward`` or ``forward_async``,
+            # rather than having another worker initiate forward over the RPC layer.
+            args = (torch.ones(1), 2, "3")
+            ret1 = rpc.rpc_sync(dst_worker2_name, remote_forward, (remote_module, args))
+            self.assertEqual(ret1, tuple(reversed(args)))
+            ret2 = rpc.rpc_sync(
+                dst_worker2_name, remote_forward_async, (remote_module, args)
+            )
+            self.assertEqual(ret2, tuple(reversed(args)))
+
+    @dist_utils.dist_init
+    def test_send_remote_module_over_the_wire_script_not_supported(self):
+        if self.rank != 0:
+            return
+        dst_worker1_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+        dst_worker2_name = dist_utils.worker_name((self.rank + 2) % self.world_size)
+
+        # Unpickled attributes include both the inherent attributes of RemoteModule
+        # (not inherited from the superclass) and two installed methods.
+        expected_unpickled_attrs = list(_REMOTE_MODULE_PICKLED_ATTRIBUTES)
+        expected_unpickled_attrs.append("forward_async")
+        expected_unpickled_attrs.append("forward")
+
+        with self.assertRaisesRegex(
+            RuntimeError, "Passing a script RemoteModule over RPC is not supported."
+        ):
+            # Create a remote module on worker1 and then pass it to worker2 over the RPC layer.
+            for remote_module in self._create_remote_module_iter(
+                dst_worker1_name, modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE]
+            ):
+                # Test querying some simple attributes from worker2.
+                rpc.rpc_sync(
+                    dst_worker2_name, remote_module_attributes, (remote_module,)
+                )
+
+    @dist_utils.dist_init
+    def test_create_remote_module_from_module_rref(self):
+        if self.rank != 0:
+            return
+        dst_worker1_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+        dst_worker2_name = dist_utils.worker_name((self.rank + 2) % self.world_size)
+
+        # Create a remote module on worker1 and then pass its `module_rref` to worker2 over the RPC layer.
+        for remote_module in self._create_remote_module_iter(
+            dst_worker1_name, modes=[ModuleCreationMode.MODULE_CTOR]
+        ):
+            remote_module2 = rpc.rpc_sync(
+                dst_worker2_name,
+                RemoteModule.init_from_module_rref,
+                (dst_worker2_name, remote_module.get_module_rref()),
+            )
+
+            args = (torch.ones(1), 2, "3")
+            ret1 = rpc.rpc_sync(dst_worker1_name, remote_forward, (remote_module, args))
+            ret2 = rpc.rpc_sync(
+                dst_worker2_name, remote_forward, (remote_module2, args)
+            )
+            self.assertEqual(ret1, ret2)
+
+
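+# Tests for RemoteModule when the remote_device string also specifies a device
+# (e.g. "worker1/cuda:0" or "rank:1/cuda:0"), including handling of invalid device strings.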
+class CudaRemoteModuleTest(CommonRemoteModuleTest):
+    @skip_if_lt_x_gpu(1)
+    @dist_utils.dist_init
+    def test_valid_device(self):
+        if self.rank != 0:
+            return
+        dst_rank = (self.rank + 1) % self.world_size
+        dst_worker_name = dist_utils.worker_name(dst_rank)
+
+        for remote_module in self._create_remote_module_iter(
+            f"{dst_worker_name}/cuda:0", modes=[ModuleCreationMode.MODULE_CTOR]
+        ):
+            device = rpc.rpc_sync(
+                dst_worker_name, remote_device, (remote_module.module_rref,)
+            )
+            self.assertEqual(device.type, "cuda")
+            self.assertEqual(device.index, 0)
+
+        # Test rank works as well.
+        for remote_module in self._create_remote_module_iter(
+            f"rank:{dst_rank}/cuda:0", modes=[ModuleCreationMode.MODULE_CTOR]
+        ):
+            device = rpc.rpc_sync(
+                dst_worker_name, remote_device, (remote_module.module_rref,)
+            )
+            self.assertEqual(device.type, "cuda")
+            self.assertEqual(device.index, 0)
+
+    @skip_if_lt_x_gpu(1)
+    @dist_utils.dist_init
+    def test_invalid_devices(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+
+        with self.assertRaisesRegex(
+            RuntimeError,
+            r"Expected one of .+ device type at start of device string",
+        ):
+            [
+                m.forward()
+                for m in self._create_remote_module_iter(
+                    f"{dst_worker_name}/foo",
+                    modes=[ModuleCreationMode.MODULE_CTOR],
+                )
+            ]
+
+        if TEST_WITH_ROCM:
+            errorString = (
+                r"HIP error: invalid device ordinal\n"
+                r"HIP kernel errors might be asynchronously reported at some other API call, "
+                r"so the stacktrace below might be incorrect.\n"
+                r"For debugging consider passing AMD_SERIALIZE_KERNEL=3"
+            )
+        else:
+            errorString = r"CUDA error: invalid device ordinal"
+        with self.assertRaisesRegex(RuntimeError, errorString):
+            [
+                m.forward()
+                for m in self._create_remote_module_iter(
+                    f"{dst_worker_name}/cuda:100",
+                    modes=[ModuleCreationMode.MODULE_CTOR],
+                )
+            ]
+
+        with self.assertRaisesRegex(RuntimeError, r"Invalid device string: 'cpu2'"):
+            [
+                m.forward()
+                for m in self._create_remote_module_iter(
+                    f"{dst_worker_name}/cpu2",
+                    modes=[ModuleCreationMode.MODULE_CTOR],
+                )
+            ]
+
+        with self.assertRaisesRegex(RuntimeError, r"Device string must not be empty"):
+            [
+                m.forward()
+                for m in self._create_remote_module_iter(
+                    f"{dst_worker_name}/",
+                    modes=[ModuleCreationMode.MODULE_CTOR],
+                )
+            ]
+
+        with self.assertRaisesRegex(
+            ValueError,
+            r"Could not parse remote_device: worker1/cuda:0/cuda:1. The valid format is '<workername>/<device>'",
+        ):
+            [
+                m.forward()
+                for m in self._create_remote_module_iter(
+                    f"{dst_worker_name}/cuda:0/cuda:1",
+                    modes=[ModuleCreationMode.MODULE_CTOR],
+                )
+            ]
+
+        with self.assertRaisesRegex(
+            ValueError,
+            r"Could not parse remote_device: /. The valid format is '<workername>/<device>'",
+        ):
+            [
+                m.forward()
+                for m in self._create_remote_module_iter(
+                    "/",
+                    modes=[ModuleCreationMode.MODULE_CTOR],
+                )
+            ]
+
+        with self.assertRaisesRegex(
+            ValueError,
+            r"Could not parse remote_device: /cuda:0. The valid format is '<workername>/<device>'",
+        ):
+            [
+                m.forward()
+                for m in self._create_remote_module_iter(
+                    "/cuda:0",
+                    modes=[ModuleCreationMode.MODULE_CTOR],
+                )
+            ]
+
+    @skip_if_lt_x_gpu(1)
+    @dist_utils.dist_init
+    def test_input_moved_to_cuda_device(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+
+        # These two CPU tensors (in args and kwargs) should be implicitly moved to an appropriate cuda device.
+        t1 = torch.ones(1)
+        args = (t1, 2)
+        t2 = t1 * 2
+        kwargs = dict(word=t2)
+
+        # Only test Python nn.Module, because script module methods don't support taking kwargs.
+        for remote_module in self._create_remote_module_iter(
+            f"{dst_worker_name}/cuda:0", modes=[ModuleCreationMode.MODULE_CTOR]
+        ):
+            ret_fut = remote_module.forward_async(*args, **kwargs)
+            ret = ret_fut.wait()
+            self.assertEqual(ret, tuple(reversed(args + (t2,))))
+            # TODO: Once the RPC backend can support directly sending GPU tensors, the expected device type should be "cuda:0".
+            self.assertEqual(ret[0].device.type, "cpu")
+            self.assertEqual(ret[2].device.type, "cpu")
+
+            ret = remote_module.forward(*args, **kwargs)
+            self.assertEqual(ret, tuple(reversed(args + (t2,))))
+            # TODO: Once the RPC backend can support directly sending GPU tensors, the expected device type should be "cuda:0".
+            self.assertEqual(ret[0].device.type, "cpu")
+            self.assertEqual(ret[2].device.type, "cpu")
+
+    @skip_if_lt_x_gpu(1)
+    @dist_utils.dist_init
+    def test_input_moved_to_cuda_device_script(self):
+        if self.rank != 0:
+            return
+        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+
+        scripted_remote_module = next(
+            self._create_remote_module_iter(
+                f"{dst_worker_name}/cuda:0",
+                modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE],
+            )
+        )
+
+        @torch.jit.script
+        def run_forward(scripted_remote_module: MyModuleInterface):
+            ret = scripted_remote_module.forward(torch.ones(1), 2, "3")
+            return ret
+
+        ret = run_forward(scripted_remote_module)
+
+        self.assertEqual(ret, ("3", 2, torch.ones(1)))
+        # TODO: Once the RPC backend can support directly sending GPU tensors, the expected device type should be "cuda:0".
+        self.assertEqual(ret[2].device.type, "cpu")
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/__init__.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/dist_autograd_test.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/dist_autograd_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7cb2075e373434441ef5b90a24700b3a864f483
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/dist_autograd_test.py
@@ -0,0 +1,2755 @@
+# mypy: allow-untyped-defs
+
+import random
+import sys
+import threading
+import time
+from datetime import timedelta
+from enum import Enum
+
+import torch
+import torch.distributed as dist
+import torch.distributed.autograd as dist_autograd
+import torch.distributed.rpc as rpc
+import torch.nn as nn
+import torch.testing._internal.dist_utils
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.distributed.rpc import RRef
+from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
+from torch.testing._internal.common_utils import (
+    IS_MACOS,
+    skip_but_pass_in_sandcastle_if,
+)
+from torch.testing._internal.dist_utils import (
+    dist_init,
+    initialize_pg,
+    wait_until_node_failure,
+    worker_name,
+)
+from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
+    RpcAgentTestFixture,
+)
+
+
+# Right now we test up to 3-layer nested rpc calls.
+# rpc_done[1] and ctx_ids[1] indicate that the rpc from the previous rank is done
+# and hold the context id sent from that rank, respectively.
+# rpc_done[2] and ctx_ids[2] do the same for the rank two hops back.
+# rpc_done[3] and ctx_ids[3] do the same for the rank three hops back.
+# rpc_done[0] and ctx_ids[0] are for the current rank, but are mostly unused.
+rpc_done = [False, False, False, False]
+ctx_ids = [-1, -1, -1, -1]
+
+known_context_ids = set()
+
+requires_grad_tensor = torch.ones(3, 3, requires_grad=True)
+
+
+# Send rpc done info and context_id to
+# dst_rank = (self.rank + rank_distance) % self.world_size.
+# We don't need a lock here since the GIL is held while executing remote
+# Python UDFs, so access is serialized across several workers.
+def _set_rpc_done(ctx_id, rank_distance):
+    global rpc_done
+    global ctx_ids
+    global known_context_ids
+    rpc_done[rank_distance] = True
+    ctx_ids[rank_distance] = ctx_id
+    known_context_ids.add(ctx_id)
+
+
+def _check_rpc_done(rank_distance):
+    while not rpc_done[rank_distance]:
+        time.sleep(0.1)
+
+
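+# Tensor factory that can be invoked over RPC to create a tensor on a remote worker.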
+def _torch_ones(sizes, requires_grad=False):
+    return torch.ones(sizes, requires_grad=requires_grad)
+
+
+# This method must be called on the rref owner and verifies that the grad of
+# the rref tensor equals the given grad.
+def _compare_owner_value(context_id, rref, grad):
+    grads = dist_autograd.get_gradients(context_id)
+    x = grads[rref.local_value()]
+    if x.is_sparse:
+        assert grad.is_sparse
+        x = x.to_dense()
+        grad = grad.to_dense()
+    else:
+        assert not grad.is_sparse
+    return torch.equal(x, grad)
+
+
+def create_tensor():
+    return torch.ones((3, 3), requires_grad=True)
+
+
+def build_sparse_tensor(coalesce=False, requires_grad=True, dtype=torch.float32):
+    i = [[0, 1, 1], [2, 0, 2]]
+    v = [3.2, 4.1, 5.3]
+    tensor = torch.sparse_coo_tensor(
+        i, v, (3, 3), requires_grad=requires_grad, dtype=dtype
+    )
+    if coalesce:
+        tensor = tensor.coalesce()
+    return tensor
+
+
+@torch.jit.script
+def create_torchscript_tensor() -> torch.Tensor:
+    return torch.ones((3, 3)).requires_grad_()
+
+
+def my_py_add(t1, t2):
+    return torch.add(t1, t2)
+
+
+def my_scalar_add(a, b):
+    return a + b
+
+
+def my_rref_add(rref_t1, t2):
+    ret = torch.add(rref_t1.local_value(), t2)
+    return ret
+
+
+@torch.jit.script
+def my_script_add(t1, t2):
+    return torch.add(t1, t2)
+
+
+@torch.jit.script
+def my_script_ref_add(ref_t1: RRef[torch.Tensor], t2: torch.Tensor) -> torch.Tensor:
+    t1 = ref_t1.to_here()
+    return torch.add(t1, t2)
+
+
+def my_nested_rref_add(dst, rref_t1, t2):
+    return rpc.rpc_sync(dst, my_rref_add, args=(rref_t1, t2))
+
+
+def ret_requires_grad():
+    return requires_grad_tensor
+
+
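+# Forwards t1 and t2 through a chain of nested rpc calls ("hops" levels deep) and
+# finally computes torch.add(t1, t2) via my_py_add on the last worker in the chain.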
+def my_py_nested_call(t1, t2, dst, world_size, hops):
+    next_dst = (dst + 1) % world_size
+    if hops > 0:
+        return rpc.rpc_sync(
+            worker_name(next_dst),
+            my_py_nested_call,
+            args=(t1, t2, next_dst, world_size, hops - 1),
+        )
+    else:
+        return rpc.rpc_sync(worker_name(next_dst), my_py_add, args=(t1, t2))
+
+
+# After a dist autograd context is cleaned up locally, it should also be cleaned
+# up on the other nodes. This helper waits up to timeout_seconds for those RPCs
+# to complete, and checks that all known contexts have been cleaned up in that timeframe.
+def _all_contexts_cleaned_up(timeout_seconds=10):
+    global known_context_ids
+    start = time.time()
+    context_id_to_raised = set()
+    while (
+        time.time() - start < timeout_seconds
+        and context_id_to_raised != known_context_ids
+    ):
+        for context_id in known_context_ids:
+            try:
+                dist_autograd._retrieve_context(context_id)
+            except RuntimeError:
+                context_id_to_raised.add(context_id)
+    # All contexts have been cleaned up if retrieving every known context raised a RuntimeError.
+    success = context_id_to_raised == known_context_ids
+    return success
+
+
+# This function creates a dist autograd context, runs rpc_sync on the given ps,
+# and then blocks until the ps has verified that the grads are correctly accumulated.
+def _run_trainer(rref_t1, t2, ps, rank_diff, sparse):
+    with dist_autograd.context() as context_id:
+        ret = rpc.rpc_sync(ps, my_rref_add, args=(rref_t1, t2))
+        if sparse:
+            loss = torch.sparse.sum(ret)
+        else:
+            loss = ret.sum()
+        dist_autograd.backward(context_id, [loss])
+        # prevent deleting dist autograd context
+        rpc.rpc_sync(ps, _set_rpc_done, args=(context_id, rank_diff))
+        rpc.rpc_sync(ps, _check_rpc_done, args=(0,))
+
+
+# This function is the same as _run_trainer, except that the rpc calls the TorchScript
+# function "my_script_ref_add" instead of the Python function "my_rref_add".
+def _run_trainer_torchscript(rref_t1, t2, ps, rank_diff, sparse):
+    with dist_autograd.context() as context_id:
+        ret = rpc.rpc_sync(ps, my_script_ref_add, args=(rref_t1, t2))
+        if sparse:
+            loss = torch.sparse.sum(ret)
+        else:
+            loss = ret.sum()
+        dist_autograd.backward(context_id, [loss])
+        # prevent deleting dist autograd context
+        rpc.rpc_sync(ps, _set_rpc_done, args=(context_id, rank_diff))
+        rpc.rpc_sync(ps, _check_rpc_done, args=(0,))
+
+
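+# Autograd Function that is the identity in forward and, when _simulate_error is set,
+# raises an error in backward; used to test error handling in the distributed backward pass.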
+class SimulateBackwardError(Function):
+    _simulate_error = True
+
+    @staticmethod
+    def forward(ctx, input):
+        return input
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, input):
+        if SimulateBackwardError._simulate_error:
+            raise Exception("Simulate error on backward pass")  # noqa: TRY002
+        else:
+            return input
+
+
+class ExecMode(Enum):
+    LOCAL = 1  # Run the operation locally.
+    RPC_SYNC = 2  # Run the operation using rpc_sync
+    REMOTE = 3  # Run the operation using remote.
+    RPC_ASYNC = 4  # Run the operation using rpc_async
+
+
+# Common utils for both CPU and CUDA test suites
+class CommonDistAutogradTest(RpcAgentTestFixture):
+    def _exec_func_with_dst(self, dst, exec_mode, method, *args):
+        if ExecMode.LOCAL == exec_mode:
+            if len(args) == 1 and isinstance(args[0], list):
+                return method(*args[0])
+            return method(*args)
+        elif ExecMode.RPC_SYNC == exec_mode:
+            return rpc.rpc_sync(worker_name(dst), method, args=(args))
+        elif ExecMode.REMOTE == exec_mode:
+            return rpc.remote(worker_name(dst), method, args=(args)).to_here()
+        elif ExecMode.RPC_ASYNC == exec_mode:
+            fut = rpc.rpc_async(worker_name(dst), method, args=(args))
+            return fut.wait()
+        else:
+            raise ValueError(f"Unrecognized ExecMode {exec_mode}")
+
+    def _exec_func(self, exec_mode, method, *args):
+        return self._exec_func_with_dst(self._next_rank(), exec_mode, method, *args)
+
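+    # Picks the next destination rank in a round-robin fashion over the other ranks, skipping self.rank.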
+    def _next_rank(self):
+        if hasattr(self, "dst_rank"):
+            self.dst_rank = (self.dst_rank + 1) % self.world_size
+            if self.dst_rank == self.rank:
+                return self._next_rank()
+        else:
+            self.dst_rank = (self.rank + 1) % self.world_size
+        return self.dst_rank
+
+    def _check_rpc_done(self, rank_distance):
+        _check_rpc_done(rank_distance)
+
+    def _verify_backwards(self, exec_mode, tensors, context_id, local_grads, *args):
+        if exec_mode == ExecMode.LOCAL:
+            torch.autograd.backward(tensors)
+            return [arg.grad for arg in args]
+        else:
+            self._verify_backwards_remote(tensors, context_id, local_grads, *args)
+
+    def _verify_backwards_remote(self, tensors, context_id, local_grads, *args):
+        dist_autograd.backward(context_id, tensors)
+
+        # Verify grads were accumulated appropriately.
+        grads = dist_autograd.get_gradients(context_id)
+        nargs = len(args)
+        ngrads = 0
+        for i in range(0, nargs):
+            if local_grads[i] is not None:
+                self.assertIn(args[i], grads)
+                self.assertEqual(local_grads[i], grads[args[i]])
+                ngrads += 1
+            else:
+                self.assertNotIn(args[i], grads)
+
+        self.assertEqual(ngrads, len(grads))
+
+    def _test_graph(self, fn, exec_mode, sparse):
+        dst_rank = (self.rank + 1) % self.world_size
+
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        with dist_autograd.context() as context_id:
+            if sparse:
+                t1 = build_sparse_tensor()
+                t2 = build_sparse_tensor()
+            else:
+                t1 = torch.ones(3, 3, requires_grad=True)
+                t2 = torch.zeros(3, 3, requires_grad=True)
+            if ExecMode.RPC_SYNC == exec_mode:
+                ret = rpc.rpc_sync(worker_name(dst_rank), fn, args=(t1, t2))
+            elif ExecMode.REMOTE == exec_mode:
+                ret = rpc.remote(worker_name(dst_rank), fn, args=(t1, t2)).to_here()
+            else:
+                raise ValueError(f"Unrecognized ExecMode {exec_mode}")
+
+            rpc.rpc_sync(worker_name(dst_rank), _set_rpc_done, args=(context_id, 1))
+
+            # Verify graph for current context id.
+            ctx = dist_autograd._current_context()
+            self.assertEqual(context_id, ctx._context_id())
+            send_functions = ctx._send_functions()
+            self.assertEqual(1, len(send_functions))
+            recv_functions = ctx._recv_functions()
+            self.assertEqual(1, len(recv_functions))
+            self._verify_graph_for_first_rpc_call(
+                next(iter(send_functions.values())),
+                next(iter(recv_functions.values())),
+                t1,
+                t2,
+                ret,
+            )
+
+            # Wait for the prev rank to be done with rpc.
+            self._check_rpc_done(1)
+            # Verify graph for previous context id.
+            ctx = dist_autograd._retrieve_context(ctx_ids[1])
+            send_functions = ctx._send_functions()
+            self.assertEqual(1, len(send_functions))
+            self._verify_graph_for_rpc_call_exec(next(iter(send_functions.values())))
+            # this barrier is needed so one worker does not clean up their
+            # autograd context before another worker tries to access it.
+            dist.barrier()
+
+        # autograd context should be cleaned up by now.
+        with self.assertRaises(RuntimeError):
+            ctx = dist_autograd._retrieve_context(context_id)
+
+        # No autograd context available.
+        with self.assertRaises(RuntimeError):
+            ctx = dist_autograd._current_context()
+
+    # 3-layer nested calls
+    def _test_graph_for_py_nested_call(self, exec_mode, sparse):
+        dst_rank = (self.rank + 1) % self.world_size
+
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        with dist_autograd.context() as context_id:
+            if sparse:
+                t1 = build_sparse_tensor(requires_grad=True)
+                t2 = build_sparse_tensor(requires_grad=True)
+            else:
+                t1 = torch.ones(3, 3, requires_grad=True)
+                t2 = torch.zeros(3, 3, requires_grad=True)
+            if ExecMode.RPC_SYNC == exec_mode:
+                ret = rpc.rpc_sync(
+                    worker_name(dst_rank),
+                    my_py_nested_call,
+                    args=(t1, t2, dst_rank, self.world_size, 1),
+                )
+            elif ExecMode.REMOTE == exec_mode:
+                ret = rpc.remote(
+                    worker_name(dst_rank),
+                    my_py_nested_call,
+                    args=(t1, t2, dst_rank, self.world_size, 1),
+                ).to_here()
+            else:
+                raise ValueError(f"Unrecognized ExecMode {exec_mode}")
+
+            # Barrier to ensure all RPCs are done.
+            dist.barrier()
+
+            for rd in [1, 2, 3]:
+                rpc.rpc_sync(
+                    worker_name((self.rank + rd) % self.world_size),
+                    _set_rpc_done,
+                    args=(context_id, rd),
+                )
+
+            # Barrier to ensure all set_rpc_done have completed.
+            dist.barrier()
+
+            # For self.rank, there are 4 graphs to verify.
+            # The first is for the current context id, when this rank sends the first rpc call.
+            # The second is for the prev context id, when this rank makes the 1st nested call.
+            # The third is for the prev prev context id, when this rank makes the 2nd nested call.
+            # The last is for the prev prev prev context id, when this rank executes the torch.add() operator.
+
+            # Verify first graph for current context id.
+            ctx = dist_autograd._current_context()
+            self.assertEqual(context_id, ctx._context_id())
+            send_functions = ctx._send_functions()
+            self.assertEqual(1, len(send_functions))
+            recv_functions = ctx._recv_functions()
+            self.assertEqual(1, len(recv_functions))
+            self._verify_graph_for_first_rpc_call(
+                next(iter(send_functions.values())),
+                next(iter(recv_functions.values())),
+                t1,
+                t2,
+                ret,
+            )
+
+            # Verify second graph for 1st nested call.
+            ctx = dist_autograd._retrieve_context(ctx_ids[1])
+            self._verify_graph_for_nested_rpc_call(ctx)
+
+            # Verify third graph for 2nd nested call.
+            ctx = dist_autograd._retrieve_context(ctx_ids[2])
+            self._verify_graph_for_nested_rpc_call(ctx)
+
+            # verify last graph for rpc call execution.
+            ctx = dist_autograd._retrieve_context(ctx_ids[3])
+            send_functions = ctx._send_functions()
+            self.assertEqual(1, len(send_functions))
+            self._verify_graph_for_rpc_call_exec(next(iter(send_functions.values())))
+            # this barrier is needed so one worker does not clean up their
+            # autograd context before another worker tries to access it.
+            dist.barrier()
+
+    # Rank0->Rank1->Rank0
+    def _test_graph_for_py_nested_call_itself(self, exec_mode, sparse):
+        dst_rank = (self.rank + 1) % self.world_size
+
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        with dist_autograd.context() as context_id:
+            if sparse:
+                t1 = build_sparse_tensor(requires_grad=True)
+                t2 = build_sparse_tensor(requires_grad=True)
+            else:
+                t1 = torch.ones(3, 3, requires_grad=True)
+                t2 = torch.zeros(3, 3, requires_grad=True)
+            if ExecMode.RPC_SYNC == exec_mode:
+                ret = rpc.rpc_sync(
+                    worker_name(dst_rank),
+                    my_py_nested_call,
+                    args=(
+                        t1,
+                        t2,
+                        (self.rank - 1 + self.world_size) % self.world_size,
+                        self.world_size,
+                        0,
+                    ),
+                )
+            elif ExecMode.REMOTE == exec_mode:
+                ret = rpc.remote(
+                    worker_name(dst_rank),
+                    my_py_nested_call,
+                    args=(
+                        t1,
+                        t2,
+                        (self.rank - 1 + self.world_size) % self.world_size,
+                        self.world_size,
+                        0,
+                    ),
+                ).to_here()
+            else:
+                raise ValueError(f"Unrecognized ExecMode {exec_mode}")
+
+            rpc.rpc_sync(
+                worker_name((self.rank + 1) % self.world_size),
+                _set_rpc_done,
+                args=(context_id, 1),
+            )
+
+            # For self.rank, there are 2 graphs to verify.
+            # One is for the current context id, when this rank sends the first rpc
+            # call and executes the torch.add() operator.
+            # The other is for the prev context id, when this rank makes the
+            # nested call.
+            ctx = dist_autograd._current_context()
+            self.assertEqual(context_id, ctx._context_id())
+            send_functions = ctx._send_functions()
+            self.assertEqual(2, len(send_functions))
+            recv_functions = ctx._recv_functions()
+            self.assertEqual(2, len(recv_functions))
+            self._verify_graph_for_first_rpc_call(
+                next(iter(send_functions.values())),
+                list(recv_functions.values())[1],
+                t1,
+                t2,
+                ret,
+            )
+            self._verify_graph_for_rpc_call_exec(list(send_functions.values())[1])
+
+            # Verify two pairs of send and recv functions for nested
+            # call
+            self._check_rpc_done(1)
+            ctx = dist_autograd._retrieve_context(ctx_ids[1])
+            self._verify_graph_for_nested_rpc_call(ctx)
+            # this barrier is needed so one worker does not clean up their
+            # autograd context before another worker tries to access it.
+            dist.barrier()
+
+    def _test_no_graph_with_tensors_not_require_grad(self, exec_mode, sparse):
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        dst_rank = (self.rank + 1) % self.world_size
+        with dist_autograd.context() as context_id:
+            if sparse:
+                t1 = build_sparse_tensor(requires_grad=False)
+                t2 = build_sparse_tensor(requires_grad=False)
+            else:
+                t1 = torch.ones(3, 3, requires_grad=False)
+                t2 = torch.zeros(3, 3, requires_grad=False)
+            if ExecMode.RPC_SYNC == exec_mode:
+                rpc.rpc_sync(worker_name(dst_rank), torch.add, args=(t1, t2))
+            elif ExecMode.REMOTE == exec_mode:
+                rpc.remote(worker_name(dst_rank), torch.add, args=(t1, t2)).to_here()
+            else:
+                raise ValueError(f"Unrecognized ExecMode {exec_mode}")
+
+            rpc.rpc_sync(worker_name(dst_rank), _set_rpc_done, args=(context_id, 1))
+
+            ctx = dist_autograd._current_context()
+            send_functions = ctx._send_functions()
+            self.assertEqual(len(send_functions), 0)
+            recv_functions = ctx._recv_functions()
+            self.assertEqual(len(recv_functions), 0)
+
+            # Wait for the prev rank to be done with rpc.
+            self._check_rpc_done(1)
+            # NB: RRef.to_here() always passes the autograd context to the
+            # callee, as the caller does not know whether the return
+            # value will contain a requires_grad tensor or not.
+            #
+            # rpc/remote with a udf (_set_rpc_done here) also always passes the
+            # autograd context to the callee for the same reason.
+            self.assertNotEqual(-1, dist_autograd._retrieve_context(ctx_ids[1]))
+            dist.barrier()
+
+    def _test_rpc_complex_args(self, exec_mode, sparse):
+        with dist_autograd.context():
+            num_tensors = 10
+            tensors = []
+            for i in range(num_tensors):
+                if sparse:
+                    tensor = build_sparse_tensor(requires_grad=(i % 2 == 0))
+                else:
+                    tensor = torch.ones(3, 3, requires_grad=(i % 2 == 0))
+                tensors.append(tensor)
+            dst_rank = self._next_rank()
+            if ExecMode.RPC_SYNC == exec_mode:
+                ret = rpc.rpc_sync(worker_name(dst_rank), torch.stack, args=(tensors,))
+            elif ExecMode.REMOTE == exec_mode:
+                ret = rpc.remote(
+                    worker_name(dst_rank), torch.stack, args=(tensors,)
+                ).to_here()
+            else:
+                raise ValueError(f"Unrecognized ExecMode {exec_mode}")
+
+            self.assertEqual(torch.stack(tensors), ret)
+
+            # Verify the appropriate tensors have been attached to the autograd graph.
+            next_funcs = next(
+                iter(dist_autograd._current_context()._send_functions().values())
+            ).next_functions
+            for i in range(len(next_funcs)):
+                self.assertEqual(
+                    "torch::autograd::AccumulateGrad", next_funcs[i][0].name()
+                )
+                self.assertEqual(tensors[i], next_funcs[i][0].variable)
+
+            # Verify that the worker id has been recorded in the context
+            ctx = dist_autograd._current_context()
+            worker_ids = ctx._known_worker_ids()
+            self.assertEqual(len(worker_ids), 1)
+            self.assertEqual(worker_ids, {dst_rank})
+
+    def context_cleanup_test_helper(self, rpc_args, func, nested=False):
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        # Test that in dist autograd, even when the tensors communicated over RPC do
+        # NOT require grad, we still clean up the dist autograd contexts created
+        # on other nodes. This is because the autograd context is still
+        # communicated over RPC even if the tensor arguments do not require grad,
+        # as it is possible that the response could.
+        if nested:
+            dst_rank = (self.rank + 1) % self.world_size
+            nested_dst_rank = (dst_rank + 1) % self.world_size
+            dst_ranks = {dst_rank}
+        else:
+            dst_ranks = {rank for rank in range(self.world_size) if rank != self.rank}
+
+        with dist_autograd.context() as context_id:
+            for dst_rank in dst_ranks:
+                rpc.rpc_sync(worker_name(dst_rank), func, args=rpc_args)
+                rpc.rpc_sync(worker_name(dst_rank), _set_rpc_done, args=(context_id, 1))
+                if nested:
+                    rpc.rpc_sync(
+                        worker_name(nested_dst_rank),
+                        _set_rpc_done,
+                        args=(context_id, 2),
+                    )
+        # the thread's context id should be cleaned up
+        with self.assertRaises(RuntimeError):
+            dist_autograd._retrieve_context(context_id)
+        # Ensure all peers have finished mutating the
+        # `known_context_ids` set.
+        dist.barrier()
+        # check that all contexts have been cleaned up.
+        success = _all_contexts_cleaned_up()
+        self.assertTrue(success)
+
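+    # Runs a distributed backward pass and checks that it does not populate the local
+    # .grad fields; gradients are only recorded in the dist autograd context.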
+    def _backward_no_grad_on_tensor(self, t1, t2, sparse):
+        with dist_autograd.context() as context_id:
+            loss = rpc.rpc_sync(
+                worker_name(self._next_rank()), torch.add, args=(t1, t2)
+            )
+            if sparse:
+                loss = torch.sparse.sum(loss)
+            else:
+                loss = loss.sum()
+            dist_autograd.backward(context_id, [loss], retain_graph=True)
+            self.assertIsNone(t1.grad)
+            self.assertIsNone(t2.grad)
+
+            # Now populate .grad with local autograd engine and
+            # verify dist autograd doesn't mess with it.
+            loss_local = torch.add(t1, t2)
+            if sparse:
+                loss_local = torch.sparse.sum(loss_local)
+            else:
+                loss_local = loss_local.sum()
+            loss_local.backward()
+            self.assertIsNotNone(t1.grad)
+            self.assertIsNotNone(t2.grad)
+
+            t1_grad_before = t1.grad
+            t2_grad_before = t2.grad
+            dist_autograd.backward(context_id, [loss])
+            self.assertEqual(t1_grad_before, t1.grad)
+            self.assertEqual(t2_grad_before, t2.grad)
+
+    # The current rank first creates a tensor on the rref_owner, and then passes
+    # the rref with another tensor to the callee to run either my_rref_add or
+    # my_nested_rref_add, depending on whether the callee is the rref owner.
+    # The grad of the tensor lives on the current rank, and the grad of the rref
+    # tensor lives on the rref owner.
+    def _backward_rref(self, callee, rref_owner, t1, t2, local_grads, sparse):
+        local_ret = torch.add(t1, t2)
+        if sparse:
+            local_ret = torch.sparse.sum(local_ret)
+        else:
+            local_ret = local_ret.sum()
+        local_ret.backward()
+        with dist_autograd.context() as context_id:
+            if sparse:
+                rref_t1 = rpc.remote(
+                    rref_owner,
+                    build_sparse_tensor,
+                    args=(
+                        False,
+                        True,
+                    ),
+                )
+            else:
+                rref_t1 = rpc.remote(
+                    rref_owner,
+                    _torch_ones,
+                    args=((3, 3),),
+                    kwargs={"requires_grad": True},
+                )
+            if callee == rref_owner:
+                rref = rpc.remote(callee, my_rref_add, args=(rref_t1, t2))
+            else:
+                rref = rpc.remote(
+                    callee, my_nested_rref_add, args=(rref_owner, rref_t1, t2)
+                )
+            ret = rref.to_here()
+            if sparse:
+                ret = torch.sparse.sum(ret)
+            else:
+                ret = ret.sum()
+            dist_autograd.backward(context_id, [ret])
+
+            # verify grads on caller
+            grads = dist_autograd.get_gradients(context_id)
+            self.assertIn(t2, grads)
+            self.assertEqual(grads[t2], t2.grad)
+
+            # verify grads on rref owner
+            self.assertTrue(
+                rpc.rpc_sync(
+                    rref_owner,
+                    _compare_owner_value,
+                    args=(context_id, rref_t1, t1.grad),
+                )
+            )
+
+    # In this test, every rank serves as a parameter server (ps) and a
+    # driver, and kicks off trainers on the other three ranks. So, we have:
+    # ps = rank0 with trainers = rank1/2/3
+    # ps = rank1 with trainers = rank2/3/0
+    # ps = rank2 with trainers = rank3/0/1
+    # ps = rank3 with trainers = rank0/1/2
+    #
+    # These four test ps-trainer groups run on completely separate autograd
+    # graphs, but they share the same set of underlying RpcAgents.
+    def _test_trainer_ps(self, create_ref_fn, trainer_fn, sparse):
+        if sparse:
+            t1 = build_sparse_tensor(requires_grad=True)
+            t2 = build_sparse_tensor(requires_grad=True)
+        else:
+            t1 = torch.ones((3, 3), requires_grad=True)
+            t2 = torch.zeros((3, 3), requires_grad=True)
+
+        local_ret = torch.add(t1, t2)
+        if sparse:
+            torch.sparse.sum(local_ret).backward()
+        else:
+            local_ret.sum().backward()
+
+        # create rref on self
+        rref_t1 = rpc.remote(worker_name(self.rank), create_ref_fn, args=())
+
+        # kick off forward and backward pass on three other workers (trainers)
+        rank_diffs = [1, 2, 3]
+        futures = [
+            rpc.rpc_async(
+                worker_name((self.rank + rank_diff) % self.world_size),
+                trainer_fn,
+                args=(rref_t1, t2, worker_name(self.rank), rank_diff, sparse),
+            )
+            for rank_diff in rank_diffs
+        ]
+
+        # Check that the trainers are done with their backward pass.
+        for rank_diff in rank_diffs:
+            self._check_rpc_done(rank_diff)
+
+        # trainers are done and holding the context for verification
+        for rank_diff in rank_diffs:
+            # make sure grads are accumulated for the same tensors and values
+            # are all correct
+            ctx_id = ctx_ids[rank_diff]
+            grads = dist_autograd.get_gradients(ctx_id)
+            local_t1 = rref_t1.to_here()
+            self.assertIn(local_t1, grads)
+            self.assertEqual(grads[local_t1], t1.grad)
+
+        # unblock trainers
+        _set_rpc_done(None, 0)
+
+        # wait until all trainers are done
+        torch.futures.wait_all(futures)
+
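+    # Chains several RPCs across different nodes and verifies that the distributed
+    # backward pass produces the same gradients as running the same ops locally.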
+    def _backward_multiple_round_trips(self, t1, t2, t3, t4, t5, local_grads, sparse):
+        for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]:
+            with dist_autograd.context() as context_id:
+                # Multiple RPCs between different nodes.
+                val = self._exec_func(exec_mode, torch.add, t1, t2)
+                val = self._exec_func(exec_mode, torch.mul, t3, val)
+                s1 = self._exec_func(exec_mode, torch.stack, (t4, val))
+                s2 = self._exec_func(exec_mode, torch.stack, (t5, val))
+                if sparse:
+                    val = self._exec_func(exec_mode, torch.mul, s1, s2)
+                    val = self._exec_func(exec_mode, torch.mul, val, val)
+                    loss = torch.sparse.sum(val)
+                else:
+                    val = self._exec_func(exec_mode, torch.bmm, s1, s2)
+                    val = self._exec_func(exec_mode, torch.matmul, val, val)
+                    loss = val.sum()
+
+                ret = self._verify_backwards(
+                    exec_mode, [loss], context_id, local_grads, t1, t2, t3, t4, t5
+                )
+                local_grads = ret if ret else local_grads
+
+    def _backward_different_dtypes(self, t1, t2, sparse):
+        local_grads = None
+        for exec_mode in [ExecMode.LOCAL, ExecMode.REMOTE]:
+            with dist_autograd.context() as context_id:
+                loss = self._exec_func(exec_mode, torch.add, t1, t2)
+                if sparse:
+                    loss = torch.sparse.sum(loss)
+                else:
+                    loss = loss.sum()
+                local_grads = self._verify_backwards(
+                    exec_mode, [loss], context_id, local_grads, t1, t2
+                )
+
+    # Run the same code locally and with dist autograd and verify the gradients
+    # are the same.
+    def _backward_simple_python_udf(self, t1, t2, sparse):
+        local_grads = None
+        for exec_mode in [ExecMode.LOCAL, ExecMode.REMOTE]:
+            with dist_autograd.context() as context_id:
+                ret = self._exec_func(exec_mode, my_py_add, t1, t2)
+                if sparse:
+                    loss = torch.sparse.sum(ret)
+                else:
+                    loss = ret.sum()
+                local_grads = self._verify_backwards(
+                    exec_mode, [loss], context_id, local_grads, t1, t2
+                )
+
+    # Run the same code locally and with dist autograd and verify the gradients
+    # are the same.
+    def _backward_simple_script_call(self, t1, t2, sparse):
+        local_grads = None
+        for exec_mode in [
+            ExecMode.LOCAL,
+            ExecMode.RPC_SYNC,
+            ExecMode.RPC_ASYNC,
+            ExecMode.REMOTE,
+        ]:
+            with dist_autograd.context() as context_id:
+                forward_ret = self._exec_func(exec_mode, my_script_add, t1, t2)
+                if sparse:
+                    loss = torch.sparse.sum(forward_ret)
+                else:
+                    loss = forward_ret.sum()
+                ret = self._verify_backwards(
+                    exec_mode, [loss], context_id, local_grads, t1, t2
+                )
+                local_grads = ret if ret else local_grads
+
+    def _nested_backward_accumulate_grads(self, t1, t2, sparse):
+        with dist_autograd.context() as context_id:
+            ret = rpc.rpc_sync(
+                worker_name(self._next_rank()),
+                DistAutogradTest._test_nested_backward_accumulate_grads,
+                args=(t1, t2, self._next_rank()),
+            )
+            if sparse:
+                loss = torch.sparse.sum(ret)
+            else:
+                loss = ret.sum()
+            # Run backward twice.
+            dist_autograd.backward(context_id, [loss], retain_graph=True)
+            dist_autograd.backward(context_id, [loss])
+
+    def _backwards_nested_python_udf(self, t1, t2, sparse):
+        t3 = t1 * t2
+        t4 = t1 + t2
+        res = t3 + t4
+        loss = t1 * t2 * t3 * t4 * res
+        if sparse:
+            loss = torch.sparse.sum(loss)
+        else:
+            loss = loss.sum()
+        torch.autograd.backward([loss])
+
+        # Now run distributed autograd.
+        with dist_autograd.context() as context_id:
+            loss = rpc.rpc_sync(
+                worker_name(self._next_rank()),
+                DistAutogradTest._nested_python_udf,
+                args=(t1, t2, self._next_rank()),
+            )
+            if sparse:
+                loss = torch.sparse.sum(loss)
+            else:
+                loss = loss.sum()
+            dist_autograd.backward(context_id, [loss])
+            grads = dist_autograd.get_gradients(context_id)
+            self.assertEqual(t1.grad, grads[t1])
+            self.assertEqual(t2.grad, grads[t2])
+
+    def _mixed_requires_grad(self, t1, t2, sparse):
+        for exec_mode in [ExecMode.RPC_SYNC, ExecMode.REMOTE]:
+            with dist_autograd.context() as context_id:
+                ret = self._exec_func(
+                    exec_mode, DistAutogradTest._mixed_requires_grad_operaton, t1, t2
+                )
+                self.assertEqual(t1 * t2, ret)
+                if sparse:
+                    loss = torch.sparse.sum(ret)
+                else:
+                    loss = ret.sum()
+                dist_autograd.backward(context_id, [loss])
+                self.assertTrue(t1.requires_grad)
+                self.assertFalse(t2.requires_grad)
+                grads = dist_autograd.get_gradients(context_id)
+                self.assertIn(t1, grads)
+                self.assertNotIn(t2, grads)
+                self.assertEqual(t2, grads[t1])
+
+    def _multiple_backward(self, t1, t2, sparse):
+        with dist_autograd.context() as context_id:
+            loss = rpc.rpc_sync(
+                worker_name(self._next_rank()), torch.add, args=(t1, t2)
+            )
+            if sparse:
+                loss = torch.sparse.sum(loss)
+            else:
+                loss = loss.sum()
+            # Run backward in a loop multiple times.
+            for _ in range(1000):
+                dist_autograd.backward(context_id, [loss], retain_graph=True)
+
+    # For the current context, this rank sends tensors t1 and t2 to dst_rank,
+    # then gets back the result tensor t3 = torch.add(t1, t2).
+    # For the current context on this rank, the expected graph looks like this:
+    #  send function:
+    #              rpcSendBackward
+    #                  /          \
+    #  t1.AccumulateGrad         t2.AccumulateGrad
+    #
+    #  recv function:
+    #
+    #            |
+    #          t3.rpcRecvBackward
+    #
+    def _verify_graph_for_first_rpc_call(
+        self, send_function, recv_function, t1, t2, ret
+    ):
+        # Retrieve the next functions in the graph.
+        next_funcs = send_function.next_functions
+        self.assertEqual(2, len(next_funcs))
+
+        # We should now hit t1 and t2 in the autograd graph.
+        self.assertEqual("torch::autograd::AccumulateGrad", next_funcs[0][0].name())
+        self.assertEqual(t1, next_funcs[0][0].variable)
+        self.assertEqual(0, next_funcs[0][1])
+        self.assertEqual("torch::autograd::AccumulateGrad", next_funcs[1][0].name())
+        self.assertEqual(t2, next_funcs[1][0].variable)
+        self.assertEqual(0, next_funcs[1][1])
+
+        # Test recv functions.
+        self.assertEqual(ret.grad_fn, recv_function)
+
+    # Run the same code locally and with dist autograd and verify the gradients
+    # are the same.
+    def _backward_simple(self, dst, t1, t2, local_grads, sparse):
+        for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]:
+            with dist_autograd.context() as context_id:
+                ret = self._exec_func_with_dst(dst, exec_mode, torch.add, t1, t2)
+                if sparse:
+                    loss = torch.sparse.sum(ret)
+                else:
+                    loss = ret.sum()
+                ret = self._verify_backwards(
+                    exec_mode, [loss], context_id, local_grads, t1, t2
+                )
+                local_grads = ret if ret else local_grads
+
+    # For a context passed from previous nested chain calls, this rank
+    # receives two tensors t1 and t2, executes torch.add(t1, t2) and sends
+    # the result tensor t3 back.
+    # For this context on this rank, the expected graph looks like this:
+    #  send and recv functions:
+    #       rpcSendBackward
+    #           |
+    #          t3.AddBackward0
+    #          /             \
+    # t1.recvRpcBackward    t2.recvRpcBackward
+    def _verify_graph_for_rpc_call_exec(self, send_function):
+        # Verify next function is AddBackward0
+        next_funcs = send_function.next_functions
+        self.assertEqual(1, len(next_funcs))
+        add_backward_fn = next_funcs[0][0]
+        self.assertEqual("AddBackward0", add_backward_fn.name())
+
+        # Verify the next two functions are the same recv backward function.
+        next_funcs = add_backward_fn.next_functions
+        self.assertEqual(2, len(next_funcs))
+        self.assertEqual(
+            "torch::distributed::autograd::RecvRpcBackward", next_funcs[0][0].name()
+        )
+        self.assertEqual(
+            "torch::distributed::autograd::RecvRpcBackward", next_funcs[1][0].name()
+        )
+        self.assertEqual(next_funcs[0][0], next_funcs[1][0])
+
+    # For a context passed from previous nested chain calls, this rank
+    # receives two tensors t1 and t2 and forwards them via a nested rpc call
+    # to the next dst. On the return route, it receives the result tensor t3
+    # from the next dst and forwards t3 back to the previous caller.
+    # For this context on this rank, the expected graph looks like this:
+    #  send and recv functions for receiving and forwarding t1 and t2:
+    #       rpcSendBackward
+    #          /          \
+    # t1.recvRpcBackward    t2.recvRpcBackward
+    #  send and recv functions for receiving and forwarding t3:
+    #       rpcSendBackward
+    #             |
+    #           t3.recvRpcBackward
+    def _verify_graph_for_nested_rpc_call(self, ctx):
+        send_functions = ctx._send_functions()
+        self.assertEqual(2, len(send_functions))
+
+        # For the send function created when making the nested rpc call,
+        # the next functions of the send function are the two recv functions
+        # for the two tensors received from the previous call.
+        next_funcs = next(iter(send_functions.values())).next_functions
+        self.assertEqual(2, len(next_funcs))
+        self.assertEqual(
+            "torch::distributed::autograd::RecvRpcBackward", next_funcs[0][0].name()
+        )
+        self.assertEqual(
+            "torch::distributed::autograd::RecvRpcBackward", next_funcs[1][0].name()
+        )
+        self.assertEqual(next_funcs[0][0], next_funcs[1][0])
+
+        # For the send function created when returning the response to the
+        # previous call, the next function of the send function is the recv
+        # function for the result tensor returned from the nested call.
+        next_funcs = list(send_functions.values())[1].next_functions
+        self.assertEqual(1, len(next_funcs))
+        self.assertEqual(
+            "torch::distributed::autograd::RecvRpcBackward", next_funcs[0][0].name()
+        )
+
+
+class TensorPipeAgentDistAutogradTest(CommonDistAutogradTest):
+    # Sparse tests only work with TensorPipeAgent.
+    @dist_init
+    def test_graph_for_builtin_call_sparse(self):
+        self._test_graph(torch.add, ExecMode.RPC_SYNC, True)
+
+    @dist_init
+    def test_graph_for_python_call_sparse(self):
+        self._test_graph(my_py_add, ExecMode.RPC_SYNC, True)
+
+    @dist_init
+    def test_graph_for_builtin_remote_call_sparse(self):
+        self._test_graph(torch.add, ExecMode.REMOTE, True)
+
+    @dist_init
+    def test_graph_for_python_remote_call_sparse(self):
+        self._test_graph(my_py_add, ExecMode.REMOTE, True)
+
+    @dist_init
+    def test_graph_for_py_nested_call_sparse(self):
+        self._test_graph_for_py_nested_call(ExecMode.RPC_SYNC, True)
+
+    @dist_init
+    def test_graph_for_py_nested_remote_call_sparse(self):
+        self._test_graph_for_py_nested_call(ExecMode.REMOTE, True)
+
+    @dist_init
+    def test_graph_for_py_nested_call_itself_sparse(self):
+        self._test_graph_for_py_nested_call_itself(ExecMode.RPC_SYNC, True)
+
+    @dist_init
+    def test_graph_for_py_nested_remote_call_itself_sparse(self):
+        self._test_graph_for_py_nested_call_itself(ExecMode.REMOTE, True)
+
+    @dist_init
+    def test_no_graph_with_tensors_not_require_grad_sparse(self):
+        self._test_no_graph_with_tensors_not_require_grad(ExecMode.RPC_SYNC, True)
+
+    @dist_init
+    def test_no_graph_with_tensors_not_require_grad_remote_sparse(self):
+        self._test_no_graph_with_tensors_not_require_grad(ExecMode.REMOTE, True)
+
+    @dist_init
+    def test_rpc_complex_args_sparse(self):
+        self._test_rpc_complex_args(ExecMode.RPC_SYNC, True)
+
+    @dist_init
+    def test_remote_complex_args_sparse(self):
+        self._test_rpc_complex_args(ExecMode.REMOTE, True)
+
+    @dist_init
+    def test_context_cleanup_tensor_with_grad_sparse(self):
+        t1 = build_sparse_tensor(requires_grad=True)
+        t2 = build_sparse_tensor(requires_grad=True)
+        self.context_cleanup_test_helper(rpc_args=(t1, t2), func=torch.add)
+
+    @dist_init
+    def test_context_cleanup_tensor_no_grad_sparse(self):
+        t1 = build_sparse_tensor(requires_grad=False)
+        self.context_cleanup_test_helper(rpc_args=(t1, t1), func=torch.add)
+
+    @dist_init
+    def test_context_cleanup_nested_rpc_sparse(self):
+        t1 = build_sparse_tensor(requires_grad=True)
+        t2 = build_sparse_tensor(requires_grad=True)
+        dst_rank = (self.rank + 1) % self.world_size
+        args = (t1, t2, dst_rank, self.world_size, 0)
+        self.context_cleanup_test_helper(
+            rpc_args=args, func=my_py_nested_call, nested=True
+        )
+
+    @dist_init
+    def test_backward_no_grad_on_tensor_sparse(self):
+        self._backward_no_grad_on_tensor(
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=True),
+            True,
+        )
+
+    @dist_init
+    def test_backward_simple_sparse(self):
+        self._backward_simple(
+            self._next_rank(),
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=True),
+            None,
+            True,
+        )
+
+    @dist_init
+    def test_backward_simple_self_sparse(self):
+        self._backward_simple(
+            self.rank,
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=True),
+            None,
+            True,
+        )
+
+    @dist_init
+    def test_backward_rref_multi_sparse(self):
+        if self.rank > 0:
+            callee = "worker0"
+            rref_owner = callee
+            self._backward_rref(
+                callee,
+                rref_owner,
+                build_sparse_tensor(requires_grad=True),
+                build_sparse_tensor(requires_grad=True),
+                None,
+                True,
+            )
+
+    @dist_init
+    def test_backward_rref_sparse(self):
+        callee = worker_name(self._next_rank())
+        rref_owner = callee
+        self._backward_rref(
+            callee,
+            rref_owner,
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=True),
+            None,
+            True,
+        )
+
+    @dist_init
+    def test_backward_rref_nested_sparse(self):
+        callee = worker_name((self.rank + 1) % self.world_size)
+        rref_owner = worker_name((self.rank + 2) % self.world_size)
+        self._backward_rref(
+            callee,
+            rref_owner,
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=True),
+            None,
+            True,
+        )
+
+    @dist_init
+    def test_trainer_ps_sparse(self):
+        self._test_trainer_ps(build_sparse_tensor, _run_trainer, True)
+
+    @dist_init
+    def test_backward_multiple_round_trips_sparse(self):
+        self._backward_multiple_round_trips(
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=False),
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=False),
+            build_sparse_tensor(requires_grad=True),
+            None,
+            True,
+        )
+
+    @dist_init
+    def test_backward_different_dtypes_sparse(self):
+        self._backward_different_dtypes(
+            build_sparse_tensor(requires_grad=True, dtype=torch.float32),
+            build_sparse_tensor(requires_grad=True, dtype=torch.float64),
+            True,
+        )
+
+    @dist_init
+    def test_backward_simple_python_udf_sparse(self):
+        self._backward_simple_python_udf(
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=True),
+            True,
+        )
+
+    @dist_init
+    def test_backward_simple_script_call_sparse(self):
+        self._backward_simple_script_call(
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=True),
+            True,
+        )
+
+    @dist_init
+    def test_nested_backward_accumulate_grads_sparse(self):
+        self._nested_backward_accumulate_grads(
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=True),
+            True,
+        )
+
+    @dist_init
+    def test_backwards_nested_python_udf_sparse(self):
+        # Run equivalent of _nested_python_udf locally.
+        self._backwards_nested_python_udf(
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=True),
+            True,
+        )
+
+    @dist_init
+    def test_mixed_requires_grad_sparse(self):
+        self._mixed_requires_grad(
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=False),
+            True,
+        )
+
+    @dist_init
+    def test_multiple_backward_sparse(self):
+        self._multiple_backward(
+            build_sparse_tensor(requires_grad=True),
+            build_sparse_tensor(requires_grad=True),
+            True,
+        )
+
+    @dist_init
+    def test_embedding_bag_with_no_grad_tensors(self):
+        dst = self._next_rank()
+        remote_embedding = rpc.remote(
+            worker_name(dst),
+            torch.nn.EmbeddingBag,
+            args=(16, 16),
+            kwargs={"mode": "sum", "sparse": True},
+        )
+        local_embedding = torch.nn.EmbeddingBag(16, 16, mode="sum", sparse=True)
+
+        input = torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9])
+        # requires_grad = True to record send/recv functions
+        per_sample_weights = torch.rand((8), requires_grad=True)
+        offsets = torch.LongTensor([0, 4])
+
+        local_res = local_embedding(input, offsets, per_sample_weights)
+
+        # Run backward twice.
+        torch.autograd.backward([local_res.sum()], retain_graph=True)
+        torch.autograd.backward([local_res.sum()])
+        local_grad = local_embedding.weight.grad
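+        # With sparse=True, the gradient for the EmbeddingBag weight is a sparse
+        # tensor, so running backward twice exercises sparse gradient accumulation.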
+
+        with dist_autograd.context() as context_id:
+            res = rpc.rpc_sync(
+                worker_name(dst),
+                DistAutogradTest._call_remote_embedding,
+                args=(remote_embedding, input, offsets, per_sample_weights),
+            )
+
+            # Run backward twice to test accumulation of sparse gradients.
+            dist_autograd.backward(context_id, [res.sum()], retain_graph=True)
+            dist_autograd.backward(context_id, [res.sum()])
+
+            remote_grad = rpc.rpc_sync(
+                worker_name(dst),
+                DistAutogradTest._get_grad,
+                args=(remote_embedding, context_id),
+            )
+
+            self.assertEqual(local_grad, remote_grad)
+
+
+class DistAutogradTest(CommonDistAutogradTest):
+    @dist_init
+    def test_autograd_context(self):
+        # Verify max possible id.
+        max_auto_increment = 281474976710655
+        self.assertEqual(
+            max_auto_increment + (self.worker_id << 48), dist_autograd._get_max_id()
+        )
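+        # The context id packs the 16-bit worker_id into the high bits and a 48-bit
+        # auto-increment counter into the low bits, so the largest counter value is
+        # 2**48 - 1 == 281474976710655 (checked above).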
+
+        context_ids = []
+        for _ in range(200):
+            with dist_autograd.context() as context_id:
+                self.assertEqual(
+                    context_id,
+                    dist_autograd._retrieve_context(context_id)._context_id(),
+                )
+                # First 16 bits should be worker_id.
+                self.assertEqual(self.worker_id, context_id >> 48)
+                context_ids.append(context_id)
+
+        for context_id in context_ids:
+            with self.assertRaisesRegex(
+                RuntimeError,
+                f"Could not find autograd context with id: {context_id}",
+            ):
+                dist_autograd._retrieve_context(context_id)
+
+    @dist_init
+    def test_nested_context(self):
+        with dist_autograd.context():
+            # Nested contexts not supported.
+            with self.assertRaisesRegex(
+                RuntimeError, "Already have an autograd context id for this thread"
+            ):
+                with dist_autograd.context():
+                    pass
+
+    @dist_init
+    def test_graph_for_builtin_call(self):
+        self._test_graph(torch.add, ExecMode.RPC_SYNC, False)
+
+    @dist_init
+    def test_graph_for_python_call(self):
+        self._test_graph(my_py_add, ExecMode.RPC_SYNC, False)
+
+    @dist_init
+    def test_graph_for_builtin_remote_call(self):
+        self._test_graph(torch.add, ExecMode.REMOTE, False)
+
+    @dist_init
+    def test_graph_for_python_remote_call(self):
+        self._test_graph(my_py_add, ExecMode.REMOTE, False)
+
+    @dist_init
+    def test_graph_for_py_nested_call(self):
+        self._test_graph_for_py_nested_call(ExecMode.RPC_SYNC, False)
+
+    @dist_init
+    def test_graph_for_py_nested_remote_call(self):
+        self._test_graph_for_py_nested_call(ExecMode.REMOTE, False)
+
+    @dist_init
+    def test_graph_for_py_nested_call_itself(self):
+        self._test_graph_for_py_nested_call_itself(ExecMode.RPC_SYNC, False)
+
+    @dist_init
+    def test_graph_for_py_nested_remote_call_itself(self):
+        self._test_graph_for_py_nested_call_itself(ExecMode.REMOTE, False)
+
+    @dist_init
+    def test_no_graph_with_tensors_not_require_grad(self):
+        self._test_no_graph_with_tensors_not_require_grad(ExecMode.RPC_SYNC, False)
+
+    @dist_init
+    def test_no_graph_with_tensors_not_require_grad_remote(self):
+        self._test_no_graph_with_tensors_not_require_grad(ExecMode.REMOTE, False)
+
+    def _test_grad_only_on_return_value(self, exec_mode):
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        dst_rank = (self.rank + 1) % self.world_size
+        with dist_autograd.context() as context_id:
+            if ExecMode.RPC_SYNC == exec_mode:
+                ret = rpc.rpc_sync(worker_name(dst_rank), ret_requires_grad)
+            elif ExecMode.REMOTE == exec_mode:
+                ret = rpc.remote(worker_name(dst_rank), ret_requires_grad).to_here()
+            else:
+                raise ValueError(f"Unrecognized ExecMode {exec_mode}")
+
+            dist_autograd.backward(context_id, [ret.sum()])
+
+            rpc.rpc_sync(worker_name(dst_rank), _set_rpc_done, args=(context_id, 1))
+
+            # Wait for the prev rank to be done with rpc.
+            self._check_rpc_done(1)
+            grads = dist_autograd.get_gradients(ctx_ids[1])
+            self.assertEqual(1, len(grads))
+            self.assertIn(requires_grad_tensor, grads)
+            self.assertEqual(torch.ones_like(ret), grads[requires_grad_tensor])
+            # Due to the above get_gradients call, ensure that dist autograd
+            # contexts aren't cleaned up until all workers have exited their
+            # context managers.
+            dist.barrier()
+
+    @dist_init
+    def test_grad_only_on_return_value(self):
+        self._test_grad_only_on_return_value(ExecMode.RPC_SYNC)
+
+    @dist_init
+    def test_grad_only_on_return_value_remote(self):
+        self._test_grad_only_on_return_value(ExecMode.REMOTE)
+
+    @dist_init
+    def test_rpc_complex_args(self):
+        self._test_rpc_complex_args(ExecMode.RPC_SYNC, False)
+
+    @dist_init
+    def test_remote_complex_args(self):
+        self._test_rpc_complex_args(ExecMode.REMOTE, False)
+
+    @dist_init
+    def test_context_cleanup_tensor_with_grad(self):
+        t1 = torch.ones(3, 3, requires_grad=True)
+        t2 = torch.zeros(3, 3, requires_grad=True)
+        self.context_cleanup_test_helper(rpc_args=(t1, t2), func=torch.add)
+
+    @dist_init
+    def test_context_cleanup_tensor_no_grad(self):
+        t1 = torch.ones(3, 3, requires_grad=False)
+        self.context_cleanup_test_helper(rpc_args=(t1, t1), func=torch.add)
+
+    @dist_init
+    def test_context_cleanup_no_tensors(self):
+        self.context_cleanup_test_helper(rpc_args=(1, 1), func=my_scalar_add)
+
+    @dist_init
+    def test_context_cleanup_nested_rpc(self):
+        t1 = torch.ones(3, 3, requires_grad=True)
+        t2 = torch.zeros(3, 3, requires_grad=True)
+        dst_rank = (self.rank + 1) % self.world_size
+        args = (t1, t2, dst_rank, self.world_size, 0)
+        self.context_cleanup_test_helper(
+            rpc_args=args, func=my_py_nested_call, nested=True
+        )
+
+    @dist_init
+    def test_worker_ids_recorded(self):
+        dst_ranks = {rank for rank in range(self.world_size) if rank != self.rank}
+        with dist_autograd.context() as context_id:
+            # if no tensors require grad, we should still record worker_ids, as
+            # the autograd context ID is still passed to other workers.
+            t1 = torch.ones(3, 3, requires_grad=False)
+            t2 = torch.zeros(3, 3, requires_grad=False)
+            for dst_rank in dst_ranks:
+                rpc.rpc_sync(worker_name(dst_rank), torch.add, args=(t1, t2))
+                rpc.rpc_sync(worker_name(dst_rank), _set_rpc_done, args=(context_id, 1))
+            # all worker_ids in dst_ranks should be recorded.
+            ctx = dist_autograd._current_context()
+            worker_ids = ctx._known_worker_ids()
+            self.assertEqual(worker_ids, dst_ranks)
+
+            # worker_ids should be recorded when tensors do require grad
+            t1.requires_grad = True
+            t2.requires_grad = True
+            for dst_rank in dst_ranks:
+                rpc.rpc_sync(worker_name(dst_rank), torch.add, args=(t1, t2))
+                rpc.rpc_sync(worker_name(dst_rank), _set_rpc_done, args=(context_id, 1))
+            # all worker_ids in dst_ranks should be recorded.
+            worker_ids = ctx._known_worker_ids()
+            self.assertEqual(worker_ids, dst_ranks)
+
+    @dist_init
+    def test_dist_autograd_profiling(self):
+        with dist_autograd.context() as context_id:
+            t1 = torch.rand(3, 3, requires_grad=True)
+            t2 = torch.rand(3, 3, requires_grad=True)
+            loss = rpc.rpc_sync(
+                worker_name(self._next_rank()), torch.add, args=(t1, t2)
+            ).sum()
+            with torch.autograd.profiler.profile() as p:
+                dist_autograd.backward(context_id, [loss])
+
+        function_events = p.function_events
+
+        def get_event(partial_key):
+            return next(event for event in function_events if partial_key in event.name)
+
+        send_event = get_event("SendRpcBackward")
+        recv_event = get_event("RecvRpcBackward")
+        backward_event = get_event("torch::distributed::autograd::backward")
+        # There should be at least one send event and one recv event each, corresponding to the send/recv functions executed.
+        self.assertEqual(send_event.count, 1)
+        self.assertEqual(recv_event.count, 1)
+        # The CPU total for the backward event should be greater than that of send and recv,
+        # since applying those functions in the backward pass is a subset of the entire backward pass.
+        self.assertGreater(backward_event.cpu_time_total, send_event.cpu_time_total)
+        self.assertGreater(backward_event.cpu_time_total, recv_event.cpu_time_total)
+
+    @dist_init
+    def test_error_in_context(self):
+        with dist_autograd.context():
+            t1 = torch.rand(3, 3, requires_grad=True)
+            t2 = torch.rand(6, 6, requires_grad=True)
+
+            with self.assertRaises(RuntimeError):
+                # This should throw an error since matrix sizes don't match.
+                rpc.rpc_sync(
+                    worker_name(self._next_rank()), torch.matmul, args=(t1, t2)
+                )
+
+    @dist_init
+    def test_backward_no_grad_on_tensor(self):
+        self._backward_no_grad_on_tensor(
+            torch.rand((3, 3), requires_grad=True),
+            torch.rand((3, 3), requires_grad=True),
+            False,
+        )
+
+    @dist_init
+    def test_backward_simple(self):
+        self._backward_simple(
+            self._next_rank(),
+            torch.rand((3, 3), requires_grad=True),
+            torch.rand((3, 3), requires_grad=True),
+            None,
+            False,
+        )
+
+    @dist_init
+    def test_backward_simple_self(self):
+        self._backward_simple(
+            self.rank,
+            torch.rand((3, 3), requires_grad=True),
+            torch.rand((3, 3), requires_grad=True),
+            None,
+            False,
+        )
+
+    @dist_init
+    def test_backward_rref(self):
+        callee = worker_name(self._next_rank())
+        rref_owner = callee
+        self._backward_rref(
+            callee,
+            rref_owner,
+            torch.rand((3, 3), requires_grad=True),
+            torch.rand((3, 3), requires_grad=True),
+            None,
+            False,
+        )
+
+    @dist_init
+    def test_backward_rref_multi(self):
+        if self.rank > 0:
+            callee = "worker0"
+            rref_owner = callee
+            self._backward_rref(
+                callee,
+                rref_owner,
+                torch.rand((3, 3), requires_grad=True),
+                torch.rand((3, 3), requires_grad=True),
+                None,
+                False,
+            )
+
+    @dist_init
+    def test_backward_rref_nested(self):
+        callee = worker_name((self.rank + 1) % self.world_size)
+        rref_owner = worker_name((self.rank + 2) % self.world_size)
+        self._backward_rref(
+            callee,
+            rref_owner,
+            torch.rand((3, 3), requires_grad=True),
+            torch.rand((3, 3), requires_grad=True),
+            None,
+            False,
+        )
+
+    @dist_init
+    def test_trainer_ps(self):
+        self._test_trainer_ps(create_tensor, _run_trainer, False)
+
+    @dist_init
+    def test_trainer_ps_torchscript_functions(self):
+        # TODO: needs more investigation. There is an RRef leak when shutting
+        # down; we suspect it is because the RRef passed as an arg crosses the
+        # pybind boundary and is not garbage collected by Python when calling
+        # shutdown().
+        import torch.distributed.rpc.api as api
+
+        api._ignore_rref_leak = True
+
+        self._test_trainer_ps(
+            create_torchscript_tensor, _run_trainer_torchscript, False
+        )
+
+    @dist_init
+    def test_backward_multiple_round_trips(self):
+        self._backward_multiple_round_trips(
+            torch.rand((3, 3), requires_grad=True),
+            torch.rand((3, 3)),
+            torch.rand((3, 3), requires_grad=True),
+            torch.rand((3, 3)),
+            torch.rand((3, 3), requires_grad=True),
+            None,
+            False,
+        )
+
+    @dist_init
+    def test_backward_different_tensor_dims(self):
+        local_grads = None
+        t1 = torch.rand((4, 6), requires_grad=True)
+        t2 = torch.rand((6, 5))
+        t3 = torch.rand((5, 7), requires_grad=True)
+        t4 = torch.rand((7, 9))
+
+        for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]:
+            with dist_autograd.context() as context_id:
+                val = self._exec_func(exec_mode, torch.matmul, t1, t2)
+                val = self._exec_func(exec_mode, torch.linalg.multi_dot, (val, t3, t4))
+                loss = val.sum()
+
+                ret = self._verify_backwards(
+                    exec_mode, [loss], context_id, local_grads, t1, t2, t2, t3, t4
+                )
+                local_grads = ret if ret else local_grads
+
+    @dist_init
+    def test_backward_unused_tensors(self):
+        local_grads = None
+        t1 = torch.rand((3, 3), requires_grad=True)
+        t2 = torch.rand((3, 3), requires_grad=True)
+        t3 = torch.rand((3, 3), requires_grad=True)
+        for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]:
+            with dist_autograd.context() as context_id:
+                s = self._exec_func(exec_mode, torch.stack, (t1, t2, t3))
+                val = self._exec_func(
+                    exec_mode,
+                    torch.matmul,
+                    torch.narrow(s, 0, 0, 1),
+                    torch.narrow(s, 0, 2, 1),
+                )
+
+                loss = val.sum()
+                ret = self._verify_backwards(
+                    exec_mode, [loss], context_id, local_grads, t1, t2, t3
+                )
+                local_grads = ret if ret else local_grads
+
+    @dist_init
+    def test_backward_multiple_output_tensors(self):
+        local_grads = None
+        t = torch.rand((10, 2), requires_grad=True)
+        for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]:
+            with dist_autograd.context() as context_id:
+                tensor_list = self._exec_func(exec_mode, torch.split, t, 2)
+                t1 = tensor_list[0]
+                t2 = tensor_list[2]
+                t3 = tensor_list[4]
+
+                val = self._exec_func(exec_mode, torch.linalg.multi_dot, (t1, t2, t3))
+
+                loss = val.sum()
+                ret = self._verify_backwards(
+                    exec_mode, [loss], context_id, local_grads, t
+                )
+                local_grads = ret if ret else local_grads
+
+    def _run_test_backward_unused_send_function_in_thread(self):
+        with dist_autograd.context() as context_id:
+            t1 = torch.rand((3, 3), requires_grad=True)
+            t2 = torch.rand((3, 3), requires_grad=True)
+
+            # We don't use the result of the RPC; as a result, the backward
+            # pass would hang in the "FAST" mode.
+            rpc.rpc_sync(worker_name(self._next_rank()), torch.add, args=(t1, t2))
+
+            val = torch.mul(t1, t2)
+
+            # Run backward, this would hang forever.
+            dist_autograd.backward(context_id, [val.sum()])
+
+    @dist_init
+    def test_backward_unused_send_function(self):
+        # Run the test in a thread which would never finish.
+        t = threading.Thread(
+            target=self._run_test_backward_unused_send_function_in_thread
+        )
+        t.daemon = True
+        t.start()
+        t.join(10)  # Wait for 10s.
+
+        # Verify thread is still alive (indicating backward hasn't completed yet).
+        self.assertTrue(t.is_alive())
+
+    @dist_init
+    def test_backward_autograd_engine_error(self):
+        with dist_autograd.context() as context_id:
+            t1 = torch.rand((3, 3), requires_grad=True)
+            t2 = torch.rand((3, 3), requires_grad=True)
+            # Perform some ops before error simulation.
+            tmp = (t1 + t2) * (t1 + t2)
+            t3 = SimulateBackwardError.apply(tmp)
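+            # SimulateBackwardError is expected to raise "Simulate error on backward
+            # pass" when its backward runs (see the assertion below).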
+
+            # Run multiple round trips across different nodes and verify the
+            # original node receives an error thrown on a node deep in the chain.
+            val = rpc.rpc_sync(worker_name(self._next_rank()), torch.add, args=(t2, t3))
+            val = rpc.rpc_sync(
+                worker_name(self._next_rank()), torch.mul, args=(val, t2)
+            )
+            val = rpc.rpc_sync(
+                worker_name(self._next_rank()), torch.matmul, args=(val, t2)
+            )
+            val = rpc.rpc_sync(
+                worker_name(self._next_rank()), torch.div, args=(val, t2)
+            )
+
+            with self.assertRaisesRegex(
+                RuntimeError, "Error on Node [0-9]+: Simulate error on backward pass"
+            ):
+                # Run backwards, and validate we receive an error.
+                dist_autograd.backward(context_id, [val.sum()])
+
+    @dist_init(clean_shutdown=False)
+    @skip_but_pass_in_sandcastle_if(
+        IS_MACOS,
+        "Test is flaky on MacOS since libuv error handling is not as robust as TCP",
+    )
+    def test_backward_node_failure(self):
+        rpc._set_rpc_timeout(5)  # 5 seconds
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        with dist_autograd.context() as context_id:
+            t1 = torch.rand((3, 3), requires_grad=True)
+            t2 = torch.rand((3, 3), requires_grad=True)
+            res = rpc.rpc_sync(worker_name(self._next_rank()), torch.add, args=(t1, t2))
+
+            # Wait for all RPCs to be done.
+            dist.barrier()
+
+            # Kill all odd rank nodes.
+            if self.rank % 2 == 0:
+                shutdown_error_regex = self.get_shutdown_error_regex()
+                # Wait for all other nodes to die.
+                for rank in range(self.world_size):
+                    if rank % 2 != 0:
+                        wait_until_node_failure(rank, shutdown_error_regex)
+
+                # Shutdown sequence is not very well defined and as a result
+                # we might see any error given by get_shutdown_error_regex()
+                with self.assertRaisesRegex(RuntimeError, shutdown_error_regex):
+                    # Run backwards, and validate we receive an error since all
+                    # other nodes are dead.
+                    dist_autograd.backward(context_id, [res.sum()])
+            else:
+                # Exit all other nodes.
+                pass
+
+    @dist_init
+    def test_backward_without_context(self):
+        t1 = torch.rand((3, 3), requires_grad=True)
+        t2 = torch.rand((3, 3), requires_grad=True)
+
+        context_id = 100  # dummy context_id
+        with self.assertRaisesRegex(
+            RuntimeError,
+            f"Could not find autograd context with id: {context_id}",
+        ):
+            res = rpc.rpc_sync(worker_name(self._next_rank()), torch.add, args=(t1, t2))
+            dist_autograd.backward(context_id, [res.sum()])
+
+    @dist_init
+    def test_backward_without_rpc(self):
+        with dist_autograd.context() as context_id:
+            t1 = torch.rand((3, 3), requires_grad=True)
+            t2 = torch.rand((3, 3), requires_grad=True)
+            t3 = torch.add(t1, t2)
+
+            dist_autograd.backward(context_id, [t3.sum()])
+            grads = dist_autograd.get_gradients(context_id)
+            self.assertEqual(2, len(grads))
+            self.assertIn(t1, grads)
+            self.assertIn(t2, grads)
+            self.assertEqual(torch.ones(3, 3), grads[t1])
+            self.assertEqual(torch.ones(3, 3), grads[t2])
+
+    @dist_init
+    def test_backward_invalid_args(self):
+        with dist_autograd.context() as context_id:
+            with self.assertRaisesRegex(TypeError, "incompatible function arguments"):
+                dist_autograd.backward(context_id, None)
+
+            with self.assertRaisesRegex(TypeError, "incompatible function arguments"):
+                dist_autograd.backward(None, None)
+
+            with self.assertRaisesRegex(
+                RuntimeError, "No tensors provided for gradient computation"
+            ):
+                dist_autograd.backward(context_id, [])
+
+            with self.assertRaisesRegex(RuntimeError, "requires_grad not set on"):
+                t = torch.rand(3, 3)
+                dist_autograd.backward(context_id, [t])
+
+            with self.assertRaisesRegex(
+                RuntimeError, "is not a scalar, all roots need to be scalar"
+            ):
+                t = torch.rand(3, 3, requires_grad=True)
+                dist_autograd.backward(context_id, [t])
+
+            with self.assertRaisesRegex(
+                RuntimeError, "does not have a valid gradient function"
+            ):
+                t = torch.rand(1, requires_grad=True)
+                dist_autograd.backward(context_id, [t])
+
+    @dist_init
+    def test_backward_multiple_roots(self):
+        local_grads = None
+        t1 = torch.rand((3, 3), requires_grad=True)
+        t2 = torch.rand((3, 3), requires_grad=True)
+        for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC]:
+            with dist_autograd.context() as context_id:
+                r1 = self._exec_func(exec_mode, torch.add, t1, t2).sum()
+                r2 = self._exec_func(exec_mode, torch.mul, t1, t2).sum()
+                r3 = self._exec_func(exec_mode, torch.cos, t1).sum()
+                r4 = self._exec_func(exec_mode, torch.div, t1, t2).sum()
+
+                local_grads = self._verify_backwards(
+                    exec_mode, [r1, r2, r3, r4], context_id, local_grads, t1, t2
+                )
+
+    @dist_init
+    def test_backward_different_dtypes(self):
+        self._backward_different_dtypes(
+            torch.rand((3, 3), requires_grad=True, dtype=torch.float32),
+            torch.rand((3, 3), requires_grad=True, dtype=torch.float64),
+            False,
+        )
+
+    @dist_init
+    def test_backward_simple_python_udf(self):
+        self._backward_simple_python_udf(
+            torch.rand(3, 3, requires_grad=True),
+            torch.rand(3, 3, requires_grad=True),
+            False,
+        )
+
+    @dist_init
+    def test_backward_simple_script_call(self):
+        self._backward_simple_script_call(
+            torch.rand(3, 3, requires_grad=True),
+            torch.rand(3, 3, requires_grad=True),
+            False,
+        )
+
+    @staticmethod
+    def _complex_python_udf(t1, t2):
+        t3 = torch.nn.functional.linear(t1, t2)
+        t4 = torch.nn.functional.linear(t2, t3)
+        t5 = torch.nn.functional.linear(t3, t4)
+        return torch.linalg.multi_dot([t1, t2, t3, t4, t5])
+
+    @dist_init
+    def test_backward_complex_python_udf(self):
+        # Run the same code locally and with dist autograd and verify the
+        # gradients are the same.
+        local_grads = None
+        t1 = torch.rand((3, 3), requires_grad=True)
+        t2 = torch.rand((3, 3), requires_grad=True)
+        for exec_mode in [ExecMode.LOCAL, ExecMode.REMOTE]:
+            with dist_autograd.context() as context_id:
+                ret = self._exec_func(
+                    exec_mode, DistAutogradTest._complex_python_udf, t1, t2
+                )
+                loss = ret.sum()
+                local_grads = self._verify_backwards(
+                    exec_mode, [loss], context_id, local_grads, t1, t2
+                )
+
+    @staticmethod
+    def _python_udf_with_backward_error(t1, t2):
+        t3 = t1 + t2
+        t4 = SimulateBackwardError.apply(t3)
+        return torch.linalg.multi_dot([t1, t2, t3, t4])
+
+    @staticmethod
+    def _nested_rpc_call_backward_error(t1, t2, dst):
+        t1 = t1 * t2
+        t2 = t1 + t2
+        res = rpc.rpc_sync(
+            worker_name(dst),
+            DistAutogradTest._python_udf_with_backward_error,
+            args=(t1, t2),
+        )
+        return torch.linalg.multi_dot([t1, t2, res])
+
+    @dist_init
+    def test_backward_python_udf_error(self):
+        t1 = torch.rand((3, 3), requires_grad=True)
+        t2 = torch.rand((3, 3), requires_grad=True)
+        with dist_autograd.context() as context_id:
+            loss = rpc.rpc_sync(
+                worker_name(self._next_rank()),
+                DistAutogradTest._nested_rpc_call_backward_error,
+                args=(t1, t2, self._next_rank()),
+            )
+            with self.assertRaisesRegex(
+                RuntimeError, "Simulate error on backward pass"
+            ):
+                dist_autograd.backward(context_id, [loss.sum()])
+
+    _backward_done = False
+
+    @dist_init(clean_shutdown=False)
+    @skip_but_pass_in_sandcastle_if(
+        IS_MACOS,
+        "Test is flaky on MacOS since libuv error handling is not as robust as TCP",
+    )
+    def test_backward_node_failure_python_udf(self):
+        # Set a short timeout to quickly time out failed RPCs.
+        rpc._set_rpc_timeout(5)  # 5 seconds
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        with dist_autograd.context() as context_id:
+            t1 = torch.rand((3, 3), requires_grad=True)
+            t2 = torch.rand((3, 3), requires_grad=True)
+
+            dst = self._next_rank()
+            res = rpc.rpc_sync(
+                worker_name(dst),
+                my_py_nested_call,
+                args=(t1, t2, dst, self.world_size, 1),
+            )
+
+            dist.barrier()
+
+            # Kill rank 2 (last hop of nested rpc) and verify rank 0 receives an error.
+            if self.rank == 2:
+                return
+
+            store = dist.distributed_c10d._get_default_store()
+            if self.rank == 0:
+                # Wait for rank 2 to die.
+                shutdown_error_regex = self.get_shutdown_error_regex()
+                wait_until_node_failure(2, shutdown_error_regex)
+                # Shutdown sequence is not very well defined and as a result
+                # we might see any error given by get_shutdown_error_regex().
+                with self.assertRaisesRegex(RuntimeError, shutdown_error_regex):
+                    # Run backwards, and validate we receive an error since rank 2 is dead.
+                    dist_autograd.backward(context_id, [res.sum()])
+
+                # Mark rank 0 as done in the store, since the RPC framework on
+                # some nodes might be broken at this point.
+                store.set("test_backward_node_failure_python_udf_rank0_done", "True")
+            else:
+                # Wait for backward to finish on rank 0.
+                store.wait(
+                    ["test_backward_node_failure_python_udf_rank0_done"],
+                    timedelta(seconds=10),
+                )
+
+    @staticmethod
+    def _nested_python_udf(t1, t2, dst):
+        t3 = t1 * t2
+        t4 = t1 + t2
+        res = rpc.rpc_sync(worker_name(dst), my_py_add, args=(t3, t4))
+        return t1 * t2 * t3 * t4 * res
+
+    @dist_init
+    def test_backwards_nested_python_udf(self):
+        # Run equivalent of _nested_python_udf locally.
+        self._backwards_nested_python_udf(
+            torch.rand(3, 3, requires_grad=True),
+            torch.rand(3, 3, requires_grad=True),
+            False,
+        )
+
+    _test_clean_context_backward_context_id = None
+
+    class MyBackwardFunc(Function):
+        @staticmethod
+        def forward(ctx, input):
+            return input
+
+        @staticmethod
+        @once_differentiable
+        def backward(ctx, input):
+            assert DistAutogradTest._test_clean_context_backward_context_id is not None
+
+            # Release the context to simulate error (use barrier before releasing
+            # context to ensure all nodes execute the backward function).
+            dist.barrier()
+            dist_autograd._release_context(
+                DistAutogradTest._test_clean_context_backward_context_id
+            )
+
+            # Verify all contexts are cleaned up.
+            assert _all_contexts_cleaned_up()
+
+            return input
+
+    @dist_init
+    def test_clean_context_during_backward(self):
+        """
+        This test simulates the situation where the 'backward' call might throw
+        an exception locally which would lead to the autograd context being
+        cleaned up if we're using the context manager. As a result, the autograd
+        context might be cleaned up while some threads are still using the
+        autograd context.
+
+        It is fine for the 'backward' call to throw an exception in this test,
+        but the process should not crash.
+        """
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        context = dist_autograd._new_context()
+        context_id = context._context_id()
+        DistAutogradTest._test_clean_context_backward_context_id = context_id
+
+        # Send the context id to all nodes.
+        for i in range(0, self.world_size):
+            if i != self.rank:
+                rank_distance = (i - self.rank + self.world_size) % self.world_size
+                rpc.rpc_sync(
+                    worker_name(i),
+                    _set_rpc_done,
+                    args=(context_id, rank_distance),
+                )
+
+        dist.barrier()
+
+        # Verify all context ids have been received.
+        self.assertEqual(self.world_size - 1, len(known_context_ids))
+
+        t1 = torch.rand((3, 3), requires_grad=True)
+        for i in range(0, 100):
+            dst = self._next_rank()
+            t1 = rpc.rpc_sync(worker_name(dst), torch.add, args=(t1, t1))
+
+        # Call MyBackwardFunc as the first op of the backward pass to
+        # ensure we release the context early in the backward pass.
+        t1 = DistAutogradTest.MyBackwardFunc.apply(t1)
+        self.assertEqual(100, len(context._send_functions()))
+
+        context_id = 100  # dummy context_id
+        with self.assertRaisesRegex(
+            RuntimeError,
+            f"Could not find autograd context with id: {context_id}",
+        ):
+            dist_autograd.backward(context_id, [t1.sum()])
+
+        # HACK: Killing workers since otherwise the autograd engine gets stuck on
+        # other nodes. The proper fix would be addressing:
+        # https://github.com/pytorch/pytorch/issues/27643, which would inform
+        # other nodes about the failure.
+        # The autograd engine gets stuck on other nodes since they're waiting to
+        # receive gradients from the node that received an error (and as a
+        # result it didn't execute the rest of the graph).
+        dist.barrier()
+        rpc.shutdown(graceful=False)
+        sys.exit(0)
+
+    @classmethod
+    def _call_remote_embedding(cls, embedding_rref, input, offsets, per_sample_weights):
+        embedding = embedding_rref.local_value()
+        return embedding(input, offsets, per_sample_weights)
+
+    @classmethod
+    def _get_grad(cls, embedding_rref, context_id):
+        embedding = embedding_rref.local_value()
+        grad_map = dist_autograd.get_gradients(context_id)
+        return grad_map[embedding.weight]
+
+    @classmethod
+    def _mixed_requires_grad_operation(cls, t1, t2):
+        if t2.requires_grad:
+            return t1 - t2
+        else:
+            return t1 * t2
+
+    @dist_init
+    def test_mixed_requires_grad(self):
+        self._mixed_requires_grad(
+            torch.rand(3, 3, requires_grad=True),
+            torch.rand(3, 3, requires_grad=False),
+            False,
+        )
+
+    class TestDebugInfoFunc(Function):
+        @staticmethod
+        def forward(ctx, input):
+            return input
+
+        @staticmethod
+        @once_differentiable
+        def backward(ctx, input):
+            debug_info = dist_autograd._get_debug_info()
+            assert debug_info is not None
+            backward_passes = int(debug_info["num_current_backward_passes"])
+
+            # Hard to validate exact numbers because of the distributed nature.
+            # We can't use a barrier() here since that would block the single
+            # CPU thread available for autograd and can cause deadlocks.
+            assert backward_passes >= 1 and backward_passes <= 4
+            return input
+
+    @dist_init
+    def test_debug_info(self):
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        t1 = torch.rand((3, 3), requires_grad=True)
+        t2 = torch.rand((3, 3), requires_grad=True)
+        with dist_autograd.context() as context_id:
+            i = 0
+            res = {}
+            res[i] = t1
+            for rank in range(self.world_size):
+                if rank != self.rank:
+                    res[i + 1] = rpc.rpc_sync(
+                        worker_name(rank), torch.add, args=(res[i], t2)
+                    )
+                    i += 1
+
+            # Call custom function in middle of backward pass to ensure all
+            # nodes are still waiting on a backward().
+            res[i + 1] = DistAutogradTest.TestDebugInfoFunc.apply(res[i])
+            i += 1
+
+            for rank in range(self.world_size):
+                if rank != self.rank:
+                    res[i + 1] = rpc.rpc_sync(
+                        worker_name(rank), torch.add, args=(res[i], t2)
+                    )
+                    i += 1
+
+            dist_autograd.backward(context_id, [res[i].sum()])
+
+            debug_info = dist_autograd._get_debug_info()
+            num_autograd_context = int(debug_info["num_autograd_contexts"])
+            # Need at least one context and not more than 4.
+            self.assertTrue(num_autograd_context >= 1 and num_autograd_context <= 4)
+
+        for rd in range(self.world_size - 1):
+            rpc.rpc_sync(
+                worker_name((self.rank + rd + 1) % self.world_size),
+                _set_rpc_done,
+                args=(context_id, rd + 1),
+            )
+
+        dist.barrier()
+
+        # Validate information
+        debug_info = dist_autograd._get_debug_info()
+        assert debug_info is not None
+        self.assertEqual(0, int(debug_info["num_current_backward_passes"]))
+        # Only `num_current_backward_passes` and `num_autograd_contexts` are present.
+        self.assertTrue(len(debug_info) == 2)
+
+        self.assertTrue(_all_contexts_cleaned_up())
+
+        # All contexts should be cleaned up.
+        debug_info = dist_autograd._get_debug_info()
+        self.assertEqual(0, int(debug_info["num_autograd_contexts"]))
+
+    @staticmethod
+    def _workload_thread():
+        t1 = torch.rand((3, 3), requires_grad=True)
+        t2 = torch.rand((3, 3), requires_grad=True)
+        with dist_autograd.context() as context_id:
+            t3 = rpc.rpc_sync("worker0", torch.add, args=(t1, t2))
+            t4 = rpc.rpc_sync("worker0", torch.mul, args=(t2, t3))
+            t5 = rpc.rpc_sync("worker0", torch.matmul, args=(t3, t4))
+            t6 = rpc.rpc_sync("worker0", torch.add, args=(t4, t5))
+
+            dist_autograd.backward(context_id, [t6.sum()])
+
+    @dist_init
+    def test_async_dist_autograd(self):
+        """
+        This test ensures async processing for distributed autograd works
+        appropriately. This is achieved by spawning multiple threads and
+        hammering a single node with a lot of backward() calls.
+        """
+
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        if self.rank != 0:
+            # All other ranks schedule work on rank 0.
+            threads = []
+            for _ in range(20):
+                t = threading.Thread(target=DistAutogradTest._workload_thread)
+                t.start()
+                threads.append(t)
+
+            for thread in threads:
+                thread.join()
+
+        dist.barrier()
+
+    @dist_init
+    def test_backward_accumulate_grads(self):
+        t1 = torch.rand((3, 3), requires_grad=True)
+        t2 = torch.rand((3, 3), requires_grad=True)
+        with dist_autograd.context() as context_id:
+            t3 = torch.matmul(t1, t2)
+            # Run backward twice.
+            torch.autograd.backward([t3.sum()], retain_graph=True)
+            torch.autograd.backward([t3.sum()])
+
+            t3 = rpc.rpc_sync(
+                worker_name(self._next_rank()), torch.matmul, args=(t1, t2)
+            )
+            # Run backward twice.
+            dist_autograd.backward(context_id, [t3.sum()], retain_graph=True)
+            dist_autograd.backward(context_id, [t3.sum()])
+
+            # Verify the gradients are the same for local and remote execution.
+            grads = dist_autograd.get_gradients(context_id)
+            self.assertEqual(2, len(grads))
+            self.assertIn(t1, grads)
+            self.assertIn(t2, grads)
+            self.assertEqual(t1.grad, grads[t1])
+            self.assertEqual(t2.grad, grads[t2])
+
+    @staticmethod
+    def _test_nested_backward_accumulate_grads(t1, t2, dst_rank):
+        return rpc.rpc_sync(worker_name(dst_rank), torch.add, args=(t1, t2))
+
+    @dist_init
+    def test_nested_backward_accumulate_grads(self):
+        self._nested_backward_accumulate_grads(
+            torch.rand(3, 3, requires_grad=True),
+            torch.rand(3, 3, requires_grad=True),
+            False,
+        )
+
+    @dist_init
+    def test_multiple_backward(self):
+        self._multiple_backward(
+            torch.rand(3, 3, requires_grad=True),
+            torch.rand(3, 3, requires_grad=True),
+            False,
+        )
+
+    @dist_init(clean_shutdown=False)
+    def test_multiple_backward_with_errors(self):
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        t1 = torch.rand((3, 3), requires_grad=True)
+        t2 = torch.rand((3, 3), requires_grad=True)
+        with dist_autograd.context() as context_id:
+            loss = rpc.rpc_sync(
+                f"worker{self._next_rank()}",
+                DistAutogradTest._python_udf_with_backward_error,
+                args=(t1, t2),
+            ).sum()
+
+            try:
+                # Run backward in a loop multiple times.
+                for i in range(100):
+                    if i < 50:
+                        with self.assertRaisesRegex(
+                            RuntimeError, "Simulate error on backward pass"
+                        ):
+                            dist_autograd.backward(
+                                context_id, [loss], retain_graph=True
+                            )
+                    elif i > 50:
+                        # Recovered from error.
+                        dist_autograd.backward(context_id, [loss], retain_graph=True)
+                    else:
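+                        # At i == 50, synchronize all ranks and turn off the simulated
+                        # error so the remaining iterations can recover.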
+                        dist.barrier()
+                        SimulateBackwardError._simulate_error = False
+                        dist.barrier()
+            finally:
+                # Sync before resetting flag.
+                dist.barrier()
+
+                # Reset the flag.
+                SimulateBackwardError._simulate_error = True
+
+    @dist_init
+    def test_backward_verify_hooks(self):
+        t1 = torch.ones((3, 3), requires_grad=True)
+        # Double the gradient.
+        t1.register_hook(lambda grad: grad * 2)
+        t2 = torch.ones((3, 3), requires_grad=True)
+        local_grads = None
+        for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]:
+            with dist_autograd.context() as context_id:
+                ret = self._exec_func(exec_mode, torch.matmul, t1, t2)
+                loss = ret.sum()
+                ret = self._verify_backwards(
+                    exec_mode, [loss], context_id, local_grads, t1, t2
+                )
+                local_grads = ret if ret else local_grads
+
+    @dist_init
+    def test_no_grad_copy(self):
+        """
+        Similar to test in test_autograd.py.
+        """
+
+        # create autograd function that saves grad pointer as class static
+        class MyFunc(Function):
+            static_grad_ptr = None
+
+            @staticmethod
+            def forward(ctx, inp1, inp2):
+                return inp1 + inp2
+
+            @staticmethod
+            def backward(ctx, grad):
+                MyFunc.static_grad_ptr = grad.data_ptr()
+                return grad, grad
+
+        class MyFuncSingleGrad(Function):
+            static_grad_ptr = None
+
+            @staticmethod
+            def forward(ctx, inp):
+                return inp
+
+            @staticmethod
+            def backward(ctx, grad):
+                MyFuncSingleGrad.static_grad_ptr = grad.data_ptr()
+                return grad
+
+        class NonContGradFunc(Function):
+            @staticmethod
+            def forward(ctx, inp1):
+                ctx.size = inp1.size()
+                return torch.tensor([1.0])
+
+            @staticmethod
+            def backward(ctx, grad):
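+                # expand() returns a non-contiguous tensor here, so the gradient must
+                # be copied during accumulation (verified below).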
+                return torch.ones(1).expand(ctx.size)
+
+        a = torch.randn(5, 6, requires_grad=True)
+        b = torch.randn(5, 6, requires_grad=True)
+        # non-contiguous grad should be copied
+        with dist_autograd.context() as context_id:
+            dist_autograd.backward(
+                context_id, [NonContGradFunc.apply(MyFunc.apply(a, b))]
+            )
+            grads = dist_autograd.get_gradients(context_id)
+            self.assertFalse(grads[a].data_ptr() == MyFunc.static_grad_ptr)
+            self.assertFalse(grads[b].data_ptr() == MyFunc.static_grad_ptr)
+
+        # test case that should trigger no copy for a
+        with dist_autograd.context() as context_id:
+            dist_autograd.backward(context_id, [MyFuncSingleGrad.apply(a)[1][0]])
+            grads = dist_autograd.get_gradients(context_id)
+            p_g = MyFuncSingleGrad.static_grad_ptr
+            p_a = grads[a].data_ptr()
+            # Verify there was no clone.
+            self.assertTrue(p_a == p_g)
+
+        # Test case that should trigger copy for both of a,b. This is
+        # different in the distributed autograd case since we hold
+        # a reference to all grads in a vector until all accumulation is done.
+        with dist_autograd.context() as context_id:
+            dist_autograd.backward(context_id, [MyFunc.apply(a, b)[1][0]])
+            grads = dist_autograd.get_gradients(context_id)
+            p_g = MyFunc.static_grad_ptr
+            p_a = grads[a].data_ptr()
+            p_b = grads[b].data_ptr()
+            # Check that a and b use different grad buffers.
+            self.assertFalse(p_a == p_b)
+            # both should be copied.
+            self.assertFalse(grads[a].data_ptr() == MyFunc.static_grad_ptr)
+            self.assertFalse(grads[b].data_ptr() == MyFunc.static_grad_ptr)
+
+    @dist_init
+    def test_no_grad_copy_sparse(self):
+        # create autograd function that saves grad pointer as class static
+        class MyFunc(Function):
+            static_grad_ptr = None
+
+            @staticmethod
+            def forward(ctx, inp):
+                return inp
+
+            @staticmethod
+            def backward(ctx, grad):
+                MyFunc.static_grad_ptr = grad._values().data_ptr()
+                return grad
+
+        class NonContGradFunc(Function):
+            static_grad_ptr = None
+
+            @staticmethod
+            def forward(ctx, inp1, inp2):
+                return inp1 + inp2
+
+            @staticmethod
+            def backward(ctx, grad):
+                # Create a sparse tensor with non-contiguous indices and values
+                # and return as grad.
+                v = torch.rand(1, 3)
+                i = torch.ones(1, 1, dtype=torch.long)
+                nv = v.expand(8, 3)
+                ni = i.expand(1, 8)
+                ngrad = torch.sparse_coo_tensor(ni, nv, (10, 3), dtype=torch.float32)
+                NonContGradFunc.static_grad_ptr = ngrad._values().data_ptr()
+                return ngrad, ngrad
+
+        a = torch.randn(10, 3, requires_grad=True)
+        b = torch.randn(10, 3, requires_grad=True)
+        input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
+        offsets = torch.tensor([0, 4])
+        import torch.nn.functional as F
+
+        # test case that should trigger no copy for a.
+        with dist_autograd.context() as context_id:
+            emb_matrix = MyFunc.apply(a)
+            loss = F.embedding_bag(emb_matrix, input, offsets, sparse=True).sum()
+            dist_autograd.backward(context_id, [loss], retain_graph=True)
+            grads = dist_autograd.get_gradients(context_id)
+            p_g = MyFunc.static_grad_ptr
+            p_a = grads[a]._values().data_ptr()
+            # Check that a's grad reuses the same buffer (no copy).
+            self.assertTrue(p_a == p_g)
+
+            # Run backwards multiple times.
+            for _ in range(10):
+                dist_autograd.backward(context_id, [loss], retain_graph=True)
+
+        # Non-contiguous indices and values should trigger a copy.
+        with dist_autograd.context() as context_id:
+            emb_matrix = NonContGradFunc.apply(a, b)
+            loss = F.embedding_bag(emb_matrix, input, offsets, sparse=True).sum()
+            dist_autograd.backward(context_id, [loss], retain_graph=True)
+            grads = dist_autograd.get_gradients(context_id)
+            p_g = NonContGradFunc.static_grad_ptr
+            p_a = grads[a]._values().data_ptr()
+            p_b = grads[b]._values().data_ptr()
+            # Check that a and b use different grad buffers.
+            self.assertFalse(p_a == p_b)
+            # Verify we cloned both grads.
+            self.assertFalse(p_a == p_g)
+            self.assertFalse(p_b == p_g)
+
+            # Run backwards multiple times to verify accumulation.
+            for _ in range(10):
+                dist_autograd.backward(context_id, [loss], retain_graph=True)
+
+    @dist_init
+    def test_grad_copy_sparse_indices_extra_ref(self):
+        # create autograd function that saves grad pointer as class static
+        class MyFunc(Function):
+            static_grad_ptr = None
+            static_grad_indices_ref = None
+            static_grad_values_ref = None
+
+            @staticmethod
+            def forward(ctx, inp):
+                return inp
+
+            @staticmethod
+            def backward(ctx, grad):
+                MyFunc.static_grad_ptr = grad._values().data_ptr()
+                # indices() and values() return views, so holding references
+                # to them does not increment the refcount of the indices and
+                # values inside the sparse tensor.
+                MyFunc.static_grad_indices_ref = grad._indices()
+                MyFunc.static_grad_values_ref = grad._values()
+                return grad
+
+        a = torch.randn(10, 3, requires_grad=True)
+        input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
+        offsets = torch.tensor([0, 4])
+        import torch.nn.functional as F
+
+        with dist_autograd.context() as context_id:
+            emb_matrix = MyFunc.apply(a)
+            loss = F.embedding_bag(emb_matrix, input, offsets, sparse=True).sum()
+            dist_autograd.backward(context_id, [loss], retain_graph=True)
+            grads = dist_autograd.get_gradients(context_id)
+            p_g = MyFunc.static_grad_ptr
+            p_a = grads[a]._values().data_ptr()
+            self.assertIsNotNone(MyFunc.static_grad_indices_ref)
+            self.assertIsNotNone(MyFunc.static_grad_values_ref)
+            # grad would be stolen, since static_grad_indices_ref and
+            # static_grad_values_ref are holding onto views and don't bump the
+            # refcount.
+            self.assertTrue(p_g == p_a)
+
+    @dist_init
+    def test_post_hooks(self):
+        self.hook_called_times = 0
+
+        def post_hook_add_one(output_grads, input_grads):
+            self.hook_called_times += 1
+            return output_grads
+
+        def post_hook_add_two(output_grads, input_grads):
+            self.hook_called_times += 2
+            return output_grads
+
+        t = torch.rand(10, 10, requires_grad=True)
+        a = t + t
+
+        # Register post hooks
+        accumulate_grad_0 = a.grad_fn.next_functions[0][0]
+        accumulate_grad_0.register_hook(post_hook_add_one)
+        accumulate_grad_0.register_hook(post_hook_add_two)
+
+        accumulate_grad_1 = a.grad_fn.next_functions[1][0]
+        accumulate_grad_1.register_hook(post_hook_add_two)
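+        # Each registered hook fires once during the single backward pass below,
+        # so the counter is expected to read 1 + 2 + 2 = 5.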
+
+        with dist_autograd.context() as context_id:
+            loss = a.sum()
+            dist_autograd.backward(context_id, [loss])
+            self.assertEqual(5, self.hook_called_times)
+            grads = dist_autograd.get_gradients(context_id)
+            self.assertEqual(1, len(grads))
+            self.assertTrue(t in grads)
+
+    @staticmethod
+    def _slow_add(t1, t2):
+        time.sleep(1)
+        t3 = t1 + t2
+        t3.requires_grad = True
+        return t3
+
+    @dist_init
+    def test_thread_local_context_id(self):
+        t1 = torch.rand((3, 3))
+        t2 = torch.rand((3, 3))
+
+        t3 = t1 + t2
+        t3.requires_grad = True
+        t3.sum().backward()
+
+        dst = worker_name((self.rank + 1) % self.world_size)
+        rref = rpc.remote(dst, DistAutogradTest._slow_add, args=(t1, t2))
+
+        with dist_autograd.context() as context_id:
+            loss = rref.to_here().sum()
+            # Due to the slow add, the continuation of this backward pass will
+            # be invoked by the earlier rpc.remote thread, which does not have
+            # a valid context_id. This tests whether thread_local state is
+            # propagated properly when jumping across threads on the server
+            # side.
+            dist_autograd.backward(context_id, [loss])
+            self.assertTrue(
+                rpc.rpc_sync(
+                    dst, _compare_owner_value, args=(context_id, rref, t3.grad)
+                )
+            )
+
+
+class CudaDistAutogradTest(CommonDistAutogradTest):
+    @skip_if_lt_x_gpu(1)
+    @dist_init
+    def test_gpu_simple(self):
+        t1 = torch.rand(3, 3, requires_grad=True, device="cuda:0")
+        t2 = torch.rand(3, 3, requires_grad=True, device="cuda:0")
+        (t1 + t2).sum().backward()
+        with dist_autograd.context() as context_id:
+            t3 = t1 + t2
+            dist_autograd.backward(context_id, [t3.sum()])
+            grads = dist_autograd.get_gradients(context_id)
+            self.assertEqual(2, len(grads))
+            self.assertEqual(t1.grad, grads[t1])
+            self.assertEqual(t2.grad, grads[t2])
+
+    @skip_if_lt_x_gpu(1)
+    @dist_init
+    def test_gpu_to_cpu_continuation(self):
+        t1 = torch.rand(3, 3, requires_grad=True, device="cuda:0")
+        t2 = torch.rand(3, 3, requires_grad=True)
+        # Run a few iterations.
+        for _ in range(3):
+            t1.grad = None
+            t2.grad = None
+            # Root is CPU
+            local_grads = None
+            for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC]:
+                with dist_autograd.context() as context_id:
+                    t3 = self._exec_func(exec_mode, torch.add, t2, t2)
+                    t4 = t3.cuda(0) + t1
+                    t5 = self._exec_func(exec_mode, torch.add, t4.cpu(), t2)
+                    t6 = t5.cuda(0) + t4
+                    t7 = self._exec_func(exec_mode, torch.add, t6.cpu(), t5)
+                    # Autograd graph consists of CPU -> GPU -> CPU execution.
+                    ret = self._verify_backwards(
+                        exec_mode, [t7.sum()], context_id, local_grads, t1, t2
+                    )
+                    local_grads = ret if ret else local_grads
+
+    @skip_if_lt_x_gpu(1)
+    @dist_init
+    def test_gpu_to_cpu_continuation_gpu_root(self):
+        t1 = torch.rand(3, 3, requires_grad=True, device="cuda:0")
+        t2 = torch.rand(3, 3, requires_grad=True)
+        # Run a few iterations.
+        for _ in range(3):
+            t1.grad = None
+            t2.grad = None
+            # Root is on GPU
+            local_grads = None
+            for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC]:
+                with dist_autograd.context() as context_id:
+                    t3 = self._exec_func(exec_mode, torch.add, t2, t2)
+                    t4 = t3.cuda(0) + t1
+                    t5 = self._exec_func(exec_mode, torch.add, t4.cpu(), t2)
+                    t6 = t5.cuda(0) + t4
+                    # Autograd graph consists of CPU -> GPU -> CPU execution.
+                    ret = self._verify_backwards(
+                        exec_mode, [t6.sum()], context_id, local_grads, t1, t2
+                    )
+                    local_grads = ret if ret else local_grads
+
+
+class FaultyAgentDistAutogradTest(RpcAgentTestFixture):
+    # Reusing a simplified helper function from DistAutogradTest to ensure
+    # autograd context is successfully cleaned up even when RPCs are failing.
+    def context_cleanup_test_helper(self, rpc_args, func):
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        # Test that the dist autograd contexts created on other nodes are still
+        # cleaned up even when the tensors communicated over RPC do NOT require
+        # grad. The autograd context is communicated over RPC regardless of
+        # whether the tensor arguments require grad, since the response might.
+        dst_ranks = {rank for rank in range(self.world_size) if rank != self.rank}
+
+        with dist_autograd.context() as context_id:
+            for dst_rank in dst_ranks:
+                rpc.rpc_sync(worker_name(dst_rank), func, args=rpc_args)
+                rpc.rpc_sync(worker_name(dst_rank), _set_rpc_done, args=(context_id, 1))
+        # the thread's context id should be cleaned up
+        with self.assertRaises(RuntimeError):
+            dist_autograd._retrieve_context(context_id)
+        # Ensure all peers have finished mutating the
+        # `known_context_ids` set.
+        dist.barrier()
+        # check that all contexts have been cleaned up.
+        success = _all_contexts_cleaned_up()
+        self.assertTrue(success)
+
+    # no faulty_messages defined so this fails all retryable messages - see
+    # faulty_rpc_agent_test_fixture.py for the list of retryable messages.
+    @dist_init
+    def test_context_cleanup_tensor_with_grad(self):
+        t1 = torch.ones(3, 3, requires_grad=True)
+        t2 = torch.zeros(3, 3, requires_grad=True)
+        self.context_cleanup_test_helper(rpc_args=(t1, t2), func=torch.add)
+
+    @dist_init
+    def test_verify_backend_options(self):
+        self.assertEqual(
+            self.rpc_backend, rpc.backend_registry.BackendType.FAULTY_TENSORPIPE
+        )
+        self.assertEqual(self.rpc_backend_options.num_worker_threads, 8)
+        self.assertEqual(self.rpc_backend_options.num_fail_sends, 3)
+        self.assertEqual(len(self.rpc_backend_options.messages_to_fail), 4)
+
+
+class WrapperModule(nn.Module):
+    def __init__(self, model, device):
+        super().__init__()
+        self.model = model.to(device)
+
+    def forward(self, *args):
+        return self.model(*args)
+
+    def gradients(self, ctx_id):
+        grads = dist_autograd.get_gradients(ctx_id)
+        return [grads[p] for p in self.model.parameters()]
+
+
+class TensorPipeCudaDistAutogradTest(RpcAgentTestFixture):
+    @skip_if_lt_x_gpu(4)
+    def test_device_maps_backward_pass(self):
+        options = self.rpc_backend_options
+        dst = worker_name((self.rank + 1) % self.world_size)
+
+        # The reverse of this device mapping should be used for the backward pass.
+        options.set_device_map(dst, {self.rank: (self.rank + 1) % self.world_size})
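+        # For example, on rank 0 this maps its cuda:0 to the peer's cuda:1 for the
+        # forward RPC; the backward pass is expected to apply the inverse map so the
+        # returned gradients land on t1/t2's original device, which the device
+        # assertions below check.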
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        t1 = torch.rand(10, device=self.rank, requires_grad=True)
+        t2 = torch.rand(10, device=self.rank, requires_grad=True)
+        with dist_autograd.context() as context_id:
+            res = rpc.rpc_sync(dst, torch.add, args=(t1, t2))
+            dist_autograd.backward(context_id, [res.sum()])
+            grads = dist_autograd.get_gradients(context_id)
+            self.assertEqual(torch.ones(10), grads[t1])
+            self.assertEqual(torch.ones(10), grads[t2])
+            self.assertEqual(t1.device, grads[t1].device)
+            self.assertEqual(t2.device, grads[t2].device)
+
+        rpc.shutdown()
+
+    class MyRemoteCompute(torch.nn.Module):
+        def forward(self, input):
+            input = input * 2.0
+            return input
+
+    class MyLocalCompute(torch.nn.Module):
+        def __init__(self, next_stage):
+            super().__init__()
+            self.next_stage = next_stage
+
+        def forward(self, input):
+            return self.next_stage.rpc_sync().forward(input)
+
+    @skip_if_lt_x_gpu(4)
+    def test_dist_autograd_sync_streams(self):
+        options = self.rpc_backend_options
+        dst = worker_name((self.rank + 1) % self.world_size)
+
+        # The reverse of this device mapping should be used for the backward pass.
+        options.set_device_map(dst, {self.rank: (self.rank + 1) % self.world_size})
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        remote_compute = rpc.remote(dst, TensorPipeCudaDistAutogradTest.MyRemoteCompute)
+        local_compute = TensorPipeCudaDistAutogradTest.MyLocalCompute(remote_compute)
+        for _ in range(10):
+            input = torch.rand([1000, 10000], device=self.rank, requires_grad=True)
+            # Run local autograd
+            result = input * 2.0
+            r = random.random()
+            loss = result.sum() * r
+            loss.backward()
+
+            # Run distributed autograd
+            with dist_autograd.context() as context_id:
+                result = local_compute(input)
+                loss = result.sum() * r
+                dist_autograd.backward(context_id, [loss])
+
+                # Compare grads.
+                grads = dist_autograd.get_gradients(context_id)
+                self.assertEqual(input.grad, grads[input])
+
+        rpc.shutdown()
+
+    @skip_if_lt_x_gpu(4)
+    def test_gradients_synchronizations(self):
+        options = self.rpc_backend_options
+        for peer_rank in range(self.world_size):
+            options.set_device_map(worker_name(peer_rank), {self.rank: peer_rank})
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        if self.rank == 0:
+            # Rank 0 acts as the master and drives both iterations.
+            layers = [nn.Linear(2000, 2000) for _ in range(self.world_size - 1)]
+            local_layers = [l.to(0) for l in layers]
+            remote_layers = [
+                rpc.remote(
+                    worker_name(rank), WrapperModule, args=(layers[rank - 1], rank)
+                )
+                for rank in range(1, self.world_size)
+            ]
+
+            x = torch.randn(5000, 2000).to(0)
+            # local iteration
+            local_model = nn.Sequential(*local_layers)
+            local_model(x).sum().backward()
+
+            # remote iteration
+            with dist_autograd.context() as context_id:
+                for remote_layer in remote_layers:
+                    x = remote_layer.rpc_sync().forward(x)
+
+                dist_autograd.backward(context_id, [x.sum()])
+
+                futs = []
+                for remote_layer in remote_layers:
+                    futs.append(remote_layer.rpc_async().gradients(context_id))
+
+                for i in range(len(futs)):
+                    local_gradients = [p.grad for p in local_layers[i].parameters()]
+                    for g1, g2 in zip(futs[i].wait(), local_gradients):
+                        self.assertEqual(g1, g2)
+
+        rpc.shutdown()
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/dist_optimizer_test.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/dist_optimizer_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d335325f8364241dd14517da5c67c2a6e6a032b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/dist_optimizer_test.py
@@ -0,0 +1,281 @@
+# mypy: allow-untyped-defs
+
+
+import threading
+
+import torch
+import torch.distributed.autograd as dist_autograd
+import torch.distributed.rpc as rpc
+from torch import optim
+from torch.distributed.optim import DistributedOptimizer
+from torch.testing._internal.dist_utils import dist_init
+from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
+    RpcAgentTestFixture,
+)
+
+
+class MyModule:
+    lock = threading.Lock()
+
+    def __init__(self, requires_grad=True):
+        # cannot directly use torch.manual_seed(0) as all threads share the same
+        # default generator. The race from multiple RPC threads could mess up
+        # the draw order from the default RNG instance, leading to
+        # non-deterministic behavior. Hence, create a dedicated RNG here.
+        g_cpu = torch.Generator()
+        g_cpu.manual_seed(0)
+        self.w = torch.rand((3, 3), requires_grad=requires_grad, generator=g_cpu)
+
+    def forward(self, t1):
+        return torch.mm(self.w, t1)
+
+    def get_w(self):
+        return self.w
+
+
+class FailingOptimizer(optim.Optimizer):
+    def __init__(self, params):
+        super().__init__(params, {})
+
+    def step(self, closure=None):
+        raise ValueError("Error running optimizer.")
+
+
+class OptimizerFailingOnConstructor(optim.Optimizer):
+    def __init__(self, params):
+        super().__init__(params, {})
+        raise ValueError("Error creating optimizer.")
+
+    def step(self, closure=None):
+        raise NotImplementedError
+
+
+def _call_method(method, obj_rref, *args, **kwargs):
+    return method(obj_rref.local_value(), *args, **kwargs)
+
+
+def remote_method(method, obj_rref, *args, **kwargs):
+    """
+    Call rpc.remote on a method in a remote object.
+
+    Args:
+        method: the method (for example, Class.method)
+        obj_rref (RRef): remote reference to the object
+        args: positional arguments to pass to the method
+        kwargs: keyword arguments to pass to the method
+
+    Returns an RRef to the remote method call result.
+    """
+    return rpc.remote(
+        obj_rref.owner(),
+        _call_method,
+        args=[method, obj_rref] + list(args),
+        kwargs=kwargs,
+    )
+
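+# Usage sketch: fetch a remote module's parameter as an RRef owned by that worker,
+# e.g. `w_rref = remote_method(MyModule.get_w, module_rref)` (hypothetical names).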
+
+def rpc_async_method(method, obj_rref, *args, **kwargs):
+    """
+    Call rpc.rpc_async on a method in a remote object.
+
+    Args:
+        method: the method (for example, Class.method)
+        obj_rref (RRef): remote reference to the object
+        args: positional arguments to pass to the method
+        kwargs: keyword arguments to pass to the method
+
+    Returns a Future to the method call result.
+    """
+    return rpc.rpc_async(
+        obj_rref.owner(),
+        _call_method,
+        args=[method, obj_rref] + list(args),
+        kwargs=kwargs,
+    )
+
+
+class DistOptimizerTest(RpcAgentTestFixture):
+    @dist_init()
+    def test_dist_optim_exception(self):
+        # distributed version
+        owner1 = f"worker{(self.rank + 1) % self.world_size:d}"
+        owner2 = f"worker{(self.rank + 2) % self.world_size:d}"
+
+        remote_module1 = rpc.remote(owner1, MyModule)
+        remote_module2 = rpc.remote(owner2, MyModule)
+        remote_param1 = remote_method(MyModule.get_w, remote_module1)
+        remote_param2 = remote_method(MyModule.get_w, remote_module2)
+
+        dist_optim = DistributedOptimizer(
+            FailingOptimizer, [remote_param1, remote_param2]
+        )
+
+        with dist_autograd.context() as context_id:
+            g_cpu = torch.Generator()
+            g_cpu.manual_seed(0)
+            t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
+            t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
+            output1 = rpc_async_method(MyModule.forward, remote_module1, t2)
+            output2 = rpc_async_method(MyModule.forward, remote_module2, output1.wait())
+            loss = torch.add(output2.wait(), t1).sum()
+
+            dist_autograd.backward(context_id, [loss])
+            with self.assertRaisesRegex(Exception, "Error running optimizer"):
+                dist_optim.step(context_id)
+
+    @dist_init()
+    def test_dist_optim_exception_on_constructor(self):
+        # distributed version
+        owner1 = f"worker{(self.rank + 1) % self.world_size:d}"
+        owner2 = f"worker{(self.rank + 2) % self.world_size:d}"
+
+        remote_module1 = rpc.remote(owner1, MyModule)
+        remote_module2 = rpc.remote(owner2, MyModule)
+        remote_param1 = remote_method(MyModule.get_w, remote_module1)
+        remote_param2 = remote_method(MyModule.get_w, remote_module2)
+
+        with self.assertRaisesRegex(Exception, "Error creating optimizer."):
+            DistributedOptimizer(
+                OptimizerFailingOnConstructor, [remote_param1, remote_param2]
+            )
+
+    def _test_dist_optim_base(self, optim_cls, *args, **kwargs):
+        # local version
+        module1 = MyModule()
+        module2 = MyModule()
+        params = [module1.get_w(), module2.get_w()]
+        local_optim = optim_cls(params, *args, **kwargs)
+
+        old_w1 = module1.w.detach().clone()
+        old_w2 = module2.w.detach().clone()
+
+        g_cpu = torch.Generator()
+        g_cpu.manual_seed(0)
+        t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
+        t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
+        output1 = module1.forward(t2)
+        output2 = module2.forward(output1)
+        loss = torch.add(output2, t1).sum()
+
+        loss.backward()
+        local_optim.step()
+
+        # distributed version
+        owner1 = f"worker{(self.rank + 1) % self.world_size:d}"
+        owner2 = f"worker{(self.rank + 2) % self.world_size:d}"
+
+        remote_module1 = rpc.remote(owner1, MyModule)
+        remote_module2 = rpc.remote(owner2, MyModule)
+        remote_param1 = remote_method(MyModule.get_w, remote_module1)
+        remote_param2 = remote_method(MyModule.get_w, remote_module2)
+
+        # sanity check: local and remote initial weights should match
+        self.assertEqual(old_w1, remote_param1.to_here())
+        self.assertEqual(old_w2, remote_param2.to_here())
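+        # They match because MyModule seeds a fresh torch.Generator with 0 in both
+        # the local and the remote constructor.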
+
+        dist_optim = DistributedOptimizer(
+            optim_cls, [remote_param1, remote_param2], *args, **kwargs
+        )
+
+        with dist_autograd.context() as context_id:
+            g_cpu.manual_seed(0)
+            t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
+            t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
+            output1 = rpc_async_method(MyModule.forward, remote_module1, t2)
+            output2 = rpc_async_method(MyModule.forward, remote_module2, output1.wait())
+            loss = torch.add(output2.wait(), t1)
+
+            dist_autograd.backward(context_id, [loss.sum()])
+            dist_optim.step(context_id)
+
+            new_w1 = rpc_async_method(MyModule.get_w, remote_module1).wait()
+            new_w2 = rpc_async_method(MyModule.get_w, remote_module2).wait()
+
+            # ensure optimizer changed weights
+            self.assertNotEqual(old_w1, new_w1)
+            self.assertNotEqual(old_w2, new_w2)
+            # ensure local equals remote
+            self.assertEqual(new_w1, module1.get_w())
+            self.assertEqual(new_w2, module2.get_w())
+
+    @dist_init()
+    def test_dist_optim(self):
+        self._test_dist_optim_base(optim.Adagrad, lr=0.05)
+        self._test_dist_optim_base(optim.Adam, lr=1e-2, amsgrad=True)
+        self._test_dist_optim_base(optim.AdamW, lr=0.05, amsgrad=True)
+        self._test_dist_optim_base(optim.SGD, lr=0.05)
+        self._test_dist_optim_base(
+            optim.SGD, lr=1e-3, momentum=1, weight_decay=1, nesterov=True
+        )
+        self._test_dist_optim_base(optim.Adadelta, rho=0.95)
+        self._test_dist_optim_base(optim.RMSprop, lr=0.05)
+        self._test_dist_optim_base(optim.Adamax, lr=0.05)
+        self._test_dist_optim_base(optim.Rprop, lr=0.05)
+
+    def _test_dist_optim_none_grads(self, optim_cls, *args, **kwargs):
+        # local version
+        module1 = MyModule()
+        module2 = MyModule(requires_grad=False)
+        params = [module1.get_w(), module2.get_w()]
+        local_optim = optim_cls(params, *args, **kwargs)
+
+        old_w1 = module1.w.detach().clone()
+        old_w2 = module2.w.detach().clone()
+
+        g_cpu = torch.Generator()
+        g_cpu.manual_seed(0)
+        t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
+        t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
+        output1 = module1.forward(t2)
+        output2 = module2.forward(output1)
+        loss = torch.add(output2, t1).sum()
+
+        loss.backward()
+        local_optim.step()
+
+        # distributed version
+        owner1 = f"worker{(self.rank + 1) % self.world_size:d}"
+        owner2 = f"worker{(self.rank + 2) % self.world_size:d}"
+
+        remote_module1 = rpc.remote(owner1, MyModule)
+        remote_module2 = rpc.remote(owner2, MyModule, args=(False,))
+        remote_param1 = remote_module1.remote().get_w()
+        remote_param2 = remote_module2.remote().get_w()
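+        # RRef.remote() / RRef.rpc_async() are built-in proxies that dispatch the
+        # call to the RRef's owner, analogous to the remote_method / rpc_async_method
+        # helpers defined above.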
+
+        # sanity check: local and remote initial weights should match
+        self.assertEqual(old_w1, remote_param1.to_here())
+        self.assertEqual(old_w2, remote_param2.to_here())
+
+        dist_optim = DistributedOptimizer(
+            optim_cls, [remote_param1, remote_param2], *args, **kwargs
+        )
+
+        with dist_autograd.context() as context_id:
+            g_cpu.manual_seed(0)
+            t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
+            t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
+            output1 = remote_module1.rpc_async().forward(t2)
+            output2 = remote_module2.rpc_async().forward(output1.wait())
+            loss = torch.add(output2.wait(), t1)
+
+            dist_autograd.backward(context_id, [loss.sum()])
+            dist_optim.step(context_id)
+
+            new_w1 = remote_module1.rpc_async().get_w().wait()
+            new_w2 = remote_module2.rpc_async().get_w().wait()
+
+            # ensure optimizer changed weights for w1
+            self.assertNotEqual(old_w1, new_w1)
+
+            # ensure optimizer not changed weights for w2
+            self.assertEqual(old_w2, new_w2)
+            # ensure local equals remote
+            self.assertEqual(new_w1, module1.get_w())
+            self.assertEqual(new_w2, module2.get_w())
+
+    @dist_init()
+    def test_dist_optim_none_grads(self):
+        self._test_dist_optim_none_grads(optim.SGD, lr=0.05)
+        self._test_dist_optim_none_grads(optim.RMSprop, lr=0.05)
+        self._test_dist_optim_none_grads(optim.Rprop, lr=0.05)
+        self._test_dist_optim_none_grads(optim.Adadelta, rho=0.95)
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/examples/__init__.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/examples/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/examples/parameter_server_test.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/examples/parameter_server_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..f84ba5225c6e833a82d2d6e43f7934628a1dc589
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/examples/parameter_server_test.py
@@ -0,0 +1,140 @@
+# mypy: allow-untyped-defs
+
+# If you need to modify this file to make this test pass, please also apply the same edits to
+# https://github.com/pytorch/examples/blob/master/distributed/rpc/batch/parameter_server.py
+# and https://pytorch.org/tutorials/intermediate/rpc_async_execution.html#batch-updating-parameter-server
+
+import threading
+from datetime import datetime
+from time import perf_counter
+
+import torch
+import torch.distributed.rpc as rpc
+import torch.nn as nn
+from torch import optim
+from torch.testing._internal.dist_utils import dist_init, worker_name
+from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
+    RpcAgentTestFixture,
+)
+
+
+batch_size = 20
+in_features = 100
+out_features = 30
+num_batches = 4
+
+
+def timed_log(text):
+    print(f"{datetime.now().strftime('%H:%M:%S')} {text}")
+
+
+class BatchUpdateParameterServer:
+    def __init__(self, batch_update_size):
+        self.model = nn.Linear(in_features, out_features)
+        self.lock = threading.Lock()
+        self.future_model = torch.futures.Future()
+        self.batch_update_size = batch_update_size
+        self.curr_update_size = 0
+        self.optimizer = optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)
+        for p in self.model.parameters():
+            p.grad = torch.zeros_like(p)
+
+    def get_model(self):
+        return self.model
+
+    @staticmethod
+    @rpc.functions.async_execution
+    def update_and_fetch_model(ps_rref, grads):
+        self = ps_rref.local_value()
+        for p, g in zip(self.model.parameters(), grads):
+            if p.grad is None:
+                p.grad = g
+            else:
+                p.grad += g
+        with self.lock:
+            timed_log(
+                f"PS got {self.curr_update_size}/{self.batch_update_size} updates"
+            )
+            self.curr_update_size += 1
+            fut = self.future_model
+
+            if self.curr_update_size >= self.batch_update_size:
+                for p in self.model.parameters():
+                    p.grad /= self.batch_update_size
+                self.curr_update_size = 0
+                self.optimizer.step()
+                self.optimizer.zero_grad()
+                fut.set_result(self.model)
+                timed_log("PS updated model")
+                self.future_model = torch.futures.Future()
+
+        return fut
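+        # With @rpc.functions.async_execution, returning a Future does not end the
+        # RPC; the caller's rpc_sync() only unblocks once fut.set_result() runs,
+        # i.e. after the last trainer in the batch has contributed its gradients.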
+
+
+class Trainer:
+    def __init__(self, ps_rref):
+        self.ps_rref = ps_rref
+        self.loss_fn = nn.L1Loss()
+
+    def get_next_batch(self):
+        for _ in range(num_batches):
+            inputs = torch.randn(batch_size, in_features)
+            labels = torch.zeros(batch_size, out_features)
+            yield inputs, labels
+
+    def train(self):
+        name = rpc.get_worker_info().name
+        m = self.ps_rref.rpc_sync().get_model()
+        for inputs, labels in self.get_next_batch():
+            timed_log(f"{name} processing one batch")
+            self.loss_fn(m(inputs), labels).backward()
+            timed_log(f"{name} reporting grads")
+            m = rpc.rpc_sync(
+                self.ps_rref.owner(),
+                BatchUpdateParameterServer.update_and_fetch_model,
+                args=(self.ps_rref, [p.grad for p in m.cpu().parameters()]),
+            )
+            timed_log(f"{name} got updated model")
+
+
+def run_trainer(ps_rref):
+    trainer = Trainer(ps_rref)
+    trainer.train()
+
+
+def run_ps(trainers):
+    timed_log("Start training")
+    start = perf_counter()
+    ps_rref = rpc.RRef(BatchUpdateParameterServer(len(trainers)))
+    futs = [
+        rpc.rpc_async(trainer, run_trainer, args=(ps_rref,)) for trainer in trainers
+    ]
+
+    torch.futures.wait_all(futs)
+    stop = perf_counter()
+    timed_log("Finish training")
+    timed_log(f"Time spent training: {stop - start}s")
+
+
+class ParameterServerTest(RpcAgentTestFixture):
+    @dist_init(setup_rpc=False)
+    def test_batch_updating_parameter_server(self):
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+        if self.rank == 0:
+            run_ps([worker_name(r) for r in range(1, self.world_size)])
+
+        rpc.shutdown()
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/examples/reinforcement_learning_rpc_test.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/examples/reinforcement_learning_rpc_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..beb08a25484de7d23fde9e314ec881056da72b30
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/examples/reinforcement_learning_rpc_test.py
@@ -0,0 +1,265 @@
+# mypy: allow-untyped-defs
+
+# If you need to modify this file to make this test pass, please also apply the same edits to
+# https://github.com/pytorch/examples/blob/master/distributed/rpc/rl/main.py
+# and https://pytorch.org/tutorials/intermediate/rpc_tutorial.html
+
+import numpy as np
+
+import torch
+import torch.distributed.rpc as rpc
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.distributed.rpc import remote, rpc_async, rpc_sync, RRef
+from torch.distributions import Categorical
+from torch.testing._internal.dist_utils import dist_init, worker_name
+from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
+    RpcAgentTestFixture,
+)
+
+
+TOTAL_EPISODE_STEP = 5000
+GAMMA = 0.1
+SEED = 543
+
+
+def _call_method(method, rref, *args, **kwargs):
+    r"""
+    a helper function to call a method on the given RRef
+    """
+    return method(rref.local_value(), *args, **kwargs)
+
+
+def _remote_method(method, rref, *args, **kwargs):
+    r"""
+    A helper function to run a method on the owner of the given RRef and fetch
+    the result back using RPC.
+    """
+    args = [method, rref] + list(args)
+    return rpc_sync(rref.owner(), _call_method, args=args, kwargs=kwargs)
+
+
+class Policy(nn.Module):
+    r"""
+    Borrowing the ``Policy`` class from the Reinforcement Learning example.
+    Copying the code to make these two examples independent.
+    See https://github.com/pytorch/examples/tree/master/reinforcement_learning
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.affine1 = nn.Linear(4, 128)
+        self.dropout = nn.Dropout(p=0.6)
+        self.affine2 = nn.Linear(128, 2)
+
+        self.saved_log_probs = []
+        self.rewards = []
+
+    def forward(self, x):
+        x = self.affine1(x)
+        x = self.dropout(x)
+        x = F.relu(x)
+        action_scores = self.affine2(x)
+        return F.softmax(action_scores, dim=1)
+
+
+class DummyEnv:
+    r"""
+    A dummy environment that implements the required subset of the OpenAI gym
+    interface. It exists only to avoid a dependency on gym for running the
+    tests in this file. It is designed to run for a set max number of iterations,
+    returning random states and rewards at each step.
+    """
+
+    def __init__(self, state_dim=4, num_iters=10, reward_threshold=475.0):
+        self.state_dim = state_dim
+        self.num_iters = num_iters
+        self.iter = 0
+        self.reward_threshold = reward_threshold
+
+    def seed(self, manual_seed):
+        torch.manual_seed(manual_seed)
+
+    def reset(self):
+        self.iter = 0
+        return torch.randn(self.state_dim)
+
+    def step(self, action):
+        self.iter += 1
+        state = torch.randn(self.state_dim)
+        reward = torch.rand(1).item() * self.reward_threshold
+        done = self.iter >= self.num_iters
+        info = {}
+        return state, reward, done, info
+
+
+class Observer:
+    r"""
+    An observer has exclusive access to its own environment. Each observer
+    captures the state from its environment and sends it to the agent, which
+    selects an action. The observer then applies the action to its environment
+    and reports the reward to the agent.
+    """
+
+    def __init__(self) -> None:
+        self.id = rpc.get_worker_info().id
+        self.env = DummyEnv()
+        self.env.seed(SEED)
+
+    def run_episode(self, agent_rref, n_steps):
+        r"""
+        Run one episode of n_steps.
+        Arguments:
+            agent_rref (RRef): an RRef referencing the agent object.
+            n_steps (int): number of steps in this episode
+        """
+        state, _ep_reward = self.env.reset(), 0
+        for _ in range(n_steps):
+            # send the state to the agent to get an action
+            action = _remote_method(Agent.select_action, agent_rref, self.id, state)
+
+            # apply the action to the environment, and get the reward
+            state, reward, done, _ = self.env.step(action)
+
+            # report the reward to the agent for training purposes
+            _remote_method(Agent.report_reward, agent_rref, self.id, reward)
+
+            if done:
+                break
+
+
+class Agent:
+    def __init__(self, world_size):
+        self.ob_rrefs = []
+        self.agent_rref = RRef(self)
+        self.rewards = {}
+        self.saved_log_probs = {}
+        self.policy = Policy()
+        self.optimizer = optim.Adam(self.policy.parameters(), lr=1e-2)
+        self.eps = np.finfo(np.float32).eps.item()
+        self.running_reward = 0
+        self.reward_threshold = DummyEnv().reward_threshold
+        for ob_rank in range(1, world_size):
+            ob_info = rpc.get_worker_info(worker_name(ob_rank))
+            self.ob_rrefs.append(remote(ob_info, Observer))
+            self.rewards[ob_info.id] = []
+            self.saved_log_probs[ob_info.id] = []
+
+    def select_action(self, ob_id, state):
+        r"""
+        This function is mostly borrowed from the Reinforcement Learning example.
+        See https://github.com/pytorch/examples/tree/master/reinforcement_learning
+        The main difference is that instead of keeping all probs in one list,
+        the agent keeps probs in a dictionary, one key per observer.
+
+        NB: no need to enforce thread-safety here, as the GIL will serialize
+        executions.
+        """
+        probs = self.policy(state.unsqueeze(0))
+        m = Categorical(probs)
+        action = m.sample()
+        self.saved_log_probs[ob_id].append(m.log_prob(action))
+        return action.item()
+
+    def report_reward(self, ob_id, reward):
+        r"""
+        Observers call this function to report rewards.
+        """
+        self.rewards[ob_id].append(reward)
+
+    def run_episode(self, n_steps=0):
+        r"""
+        Run one episode. The agent will tell each observer to run n_steps.
+        """
+        # make async RPC to kick off an episode on all observers
+        futs = [
+            rpc_async(
+                ob_rref.owner(),
+                _call_method,
+                args=(Observer.run_episode, ob_rref, self.agent_rref, n_steps),
+            )
+            for ob_rref in self.ob_rrefs
+        ]
+
+        # wait until all observers have finished this episode
+        for fut in futs:
+            fut.wait()
+
+    def finish_episode(self):
+        r"""
+        This function is mostly borrowed from the Reinforcement Learning example.
+        See https://github.com/pytorch/examples/tree/master/reinforcement_learning
+        The main difference is that it joins all probs and rewards from
+        different observers into one list, and uses the minimum observer rewards
+        as the reward of the current episode.
+        """
+
+        # joins probs and rewards from different observers into lists
+        R, probs, rewards = 0, [], []
+        for ob_id in self.rewards:
+            probs.extend(self.saved_log_probs[ob_id])
+            rewards.extend(self.rewards[ob_id])
+
+        # use the minimum observer reward to calculate the running reward
+        min_reward = min(sum(self.rewards[ob_id]) for ob_id in self.rewards)
+        self.running_reward = 0.05 * min_reward + (1 - 0.05) * self.running_reward
+
+        # clear saved probs and rewards
+        for ob_id in self.rewards:
+            self.rewards[ob_id] = []
+            self.saved_log_probs[ob_id] = []
+
+        policy_loss, returns = [], []
+        for r in rewards[::-1]:
+            R = r + GAMMA * R
+            returns.insert(0, R)
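+        # The recurrence R = r + GAMMA * R walks the rewards backwards, so after the
+        # loop returns[i] == rewards[i] + GAMMA * rewards[i+1] + GAMMA**2 * rewards[i+2] + ...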
+        returns = torch.tensor(returns)
+        returns = (returns - returns.mean()) / (returns.std() + self.eps)
+        for log_prob, R in zip(probs, returns):
+            policy_loss.append(-log_prob * R)
+        self.optimizer.zero_grad()
+        policy_loss = torch.cat(policy_loss).sum()
+        policy_loss.backward()
+        self.optimizer.step()
+        return min_reward
+
+
+def run_agent(agent, n_steps):
+    while True:
+        agent.run_episode(n_steps=n_steps)
+        agent.finish_episode()
+
+        if agent.running_reward > agent.reward_threshold:
+            print(f"Solved! Running reward is now {agent.running_reward}!")
+            break
+
+
+class ReinforcementLearningRpcTest(RpcAgentTestFixture):
+    @dist_init(setup_rpc=False)
+    def test_rl_rpc(self):
+        if self.rank == 0:
+            # Rank 0 is the agent.
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rank=self.rank,
+                world_size=self.world_size,
+                rpc_backend_options=self.rpc_backend_options,
+            )
+            agent = Agent(self.world_size)
+            run_agent(agent, n_steps=int(TOTAL_EPISODE_STEP / (self.world_size - 1)))
+
+            # Ensure training was run. We don't really care about whether the task was learned,
+            # since the purpose of the test is to check the API calls.
+            self.assertGreater(agent.running_reward, 0.0)
+        else:
+            # Other ranks are observers that passively wait for instructions from the agent.
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rank=self.rank,
+                world_size=self.world_size,
+                rpc_backend_options=self.rpc_backend_options,
+            )
+        rpc.shutdown()
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/faulty_agent_rpc_test.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/faulty_agent_rpc_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..747155e3e1cbce8f8e8c14756fe3f98bf22a8987
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/faulty_agent_rpc_test.py
@@ -0,0 +1,337 @@
+# mypy: allow-untyped-defs
+
+import time
+
+import torch
+import torch.distributed.rpc as rpc
+from torch.distributed.rpc.api import _delete_all_user_and_unforked_owner_rrefs
+from torch.testing._internal.dist_utils import (
+    dist_init,
+    wait_until_owners_and_forks_on_rank,
+    wait_until_pending_futures_and_users_flushed,
+    worker_name,
+)
+from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
+    RpcAgentTestFixture,
+)
+
+
+def my_sleep_func(seconds=1):
+    time.sleep(seconds)
+    return torch.mul(torch.tensor(1), torch.tensor(1))
+
+
+@torch.jit.script
+def my_script_func(tensor):
+    return torch.add(tensor, tensor)
+
+
+def add_rref_to_value(rref, value):
+    return rref.to_here() + value
+
+
+class FaultyAgentRpcTest(RpcAgentTestFixture):
+    # no faulty_messages defined so this fails all retryable messages - see
+    # faulty_rpc_agent_test_fixture.py for the list of retryable messages.
+    @dist_init(messages_to_delay={})
+    def test_check_failed_messages(self):
+        if self.rank == 0:
+            dst_worker_b = worker_name((self.rank + 1) % self.world_size)
+            dst_worker_c = worker_name((self.rank + 2) % self.world_size)
+
+            # Worker0 sends RPC to Worker1 and creates an RRef there
+            rref = rpc.remote(
+                dst_worker_b, torch.add, args=(torch.ones(2, 2), torch.ones(2, 2))
+            )
+            # Worker0 sends an RPC to Worker2 with the RRef as an arg
+            rpc.remote(dst_worker_c, add_rref_to_value, args=(rref, torch.ones(2, 2)))
+            # check if the output is as expected
+            self.assertEqual(
+                rref.to_here(), torch.add(torch.ones(2, 2), torch.ones(2, 2))
+            )
+        # explicitly delete all User RRefs
+        _delete_all_user_and_unforked_owner_rrefs()
+
+    @dist_init
+    def test_verify_backend_options(self):
+        self.assertEqual(
+            self.rpc_backend, rpc.backend_registry.BackendType.FAULTY_TENSORPIPE
+        )
+        self.assertEqual(self.rpc_backend_options.num_worker_threads, 8)
+        self.assertEqual(self.rpc_backend_options.num_fail_sends, 3)
+        self.assertEqual(len(self.rpc_backend_options.messages_to_fail), 4)
+        self.assertEqual(len(self.rpc_backend_options.messages_to_delay), 2)
+        self.assertEqual(
+            self.rpc_backend_options.rpc_timeout, rpc.constants.DEFAULT_RPC_TIMEOUT_SEC
+        )
+
+    @dist_init(faulty_messages=["RREF_FORK_REQUEST", "RREF_CHILD_ACCEPT"])
+    def test_custom_faulty_messages(self):
+        self.assertEqual(
+            {"RREF_FORK_REQUEST", "RREF_CHILD_ACCEPT"},
+            set(self.rpc_backend_options.messages_to_fail),
+        )
+
+    @dist_init(faulty_messages=[])
+    def test_no_faulty_messages(self):
+        self.assertEqual(len(self.rpc_backend_options.messages_to_fail), 0)
+
+    @dist_init(messages_to_delay={"SCRIPT_CALL": 1.5})
+    def test_custom_messages_to_delay(self):
+        self.assertEqual(
+            self.rpc_backend_options.messages_to_delay, {"SCRIPT_CALL": 1.5}
+        )
+
+    def _test_remote_message_dropped_pickle(self, dst=None):
+        if self.rank != 0:
+            return
+        dst_rank = dst if dst is not None else (self.rank + 1) % self.world_size
+        dst_worker = f"worker{dst_rank}"
+        # Since we fail python_remote_call messages synchronously, the future
+        # corresponding to this remote call will be marked with an error when
+        # this function returns.
+        rref = rpc.remote(dst_worker, my_sleep_func, args=(1,))
+        # Call to ensure pending callbacks are run.
+        wait_until_pending_futures_and_users_flushed()
+        # Attempting to fork the RRef should raise an error indicating the rpc.remote timeout.
+        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
+            rref._serialize()
+        # Test that passing the RRef as an arg over RPC (which forks it) results
+        # in the same error.
+        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
+            rpc.rpc_async(dst_worker, add_rref_to_value, args=(rref, 1))
+
+    @dist_init(faulty_messages=["PYTHON_REMOTE_CALL"])
+    def test_remote_message_dropped_pickle(self):
+        self._test_remote_message_dropped_pickle()
+
+    @dist_init(faulty_messages=["PYTHON_REMOTE_CALL"])
+    def test_remote_message_dropped_pickle_to_self(self):
+        self._test_remote_message_dropped_pickle(self.rank)
+
+    def _test_remote_message_dropped_timeout(self, func, args, dst=None):
+        if self.rank != 0:
+            return
+
+        # Test the case where the rpc.remote() creation message is dropped entirely.
+        dst_rank = dst if dst is not None else (self.rank + 1) % self.world_size
+        dst_worker = f"worker{dst_rank}"
+        # Since we fail python_remote_call messages synchronously, the future
+        # corresponding to this remote call will be marked with an error when
+        # this function returns.
+        rref = rpc.remote(dst_worker, func, args=args)
+        # Call to ensure pending callbacks are run.
+        wait_until_pending_futures_and_users_flushed()
+        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
+            rref.to_here()
+        # Note: during shutdown, logs will indicate "Could not find OwnerRRef..."
+        # on the owning nodes; this is expected because the OwnerRRef was never
+        # successfully created. Therefore, delAllUsers will work as expected.
+
+    @dist_init(faulty_messages=["SCRIPT_REMOTE_CALL"])
+    def test_builtin_remote_message_dropped_timeout(self):
+        func = torch.add
+        args = (torch.tensor(1), torch.tensor(1))
+        self._test_remote_message_dropped_timeout(func, args)
+
+    @dist_init(faulty_messages=["SCRIPT_REMOTE_CALL"])
+    def test_builtin_remote_message_dropped_timeout_to_self(self):
+        func = torch.add
+        args = (torch.tensor(1), torch.tensor(1))
+        self._test_remote_message_dropped_timeout(func, args, dst=0)
+
+    @dist_init(faulty_messages=["PYTHON_REMOTE_CALL"])
+    def test_udf_remote_message_dropped_timeout(self):
+        func = my_sleep_func
+        args = (2,)
+        self._test_remote_message_dropped_timeout(func, args)
+
+    @dist_init(faulty_messages=["PYTHON_REMOTE_CALL"])
+    def test_udf_remote_message_dropped_timeout_to_self(self):
+        func = my_sleep_func
+        args = (2,)
+        self._test_remote_message_dropped_timeout(func, args, dst=0)
+
+    def _test_remote_message_delay_timeout(self, func, args, dst=None):
+        if self.rank != 0:
+            return
+        # Test the case where remote message is eventually processed on the owner,
+        # but the future on the creator times out before the response comes back.
+        dst_rank = dst if dst is not None else (self.rank + 1) % self.world_size
+        dst_worker = f"worker{dst_rank}"
+        # 1 ms timeout (0.001 s)
+        rref = rpc.remote(dst_worker, func, args=args, timeout=0.001)
+        # Future corresponding to the remote creation should time out.
+        expected_error = self.get_timeout_error_regex()
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            rref._get_future().wait()
+
+        # Call to ensure pending callbacks are run.
+        wait_until_pending_futures_and_users_flushed()
+        # to_here() should now pick up that rpc.remote() creation has failed.
+        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
+            rref.to_here()
+
+        # Test the case where rpc.remote() times out, but to_here() has already
+        # started blocking before.
+        # NOTE: we only test this when not sending to self, as to_here() calls
+        # localValue(), which does not send an RPC and thus does not have
+        # a timeout. This can be supported by allowing future.wait() to
+        # take in an optional timeout (https://github.com/pytorch/pytorch/issues/39280)
+        if dst_rank != self.rank:
+            slow_rref = rpc.remote(dst_worker, func, args=args, timeout=2)
+
+            with self.assertRaisesRegex(RuntimeError, expected_error):
+                # to_here() should raise timeout error, since it does not know about the
+                # status of rpc.remote().
+                slow_rref.to_here(0.001)
+        # Note: If we proceed with shutdown, UserRRef will send out an RRefUserDelete
+        # but this can be a noop since it may not exist on the owner yet. Later,
+        # the owner can process the RRef creation and wait for the delete message,
+        # thus leading to a timeout.
+        # Therefore, we wait until we get notification that pending owners have
+        # been confirmed before sending out RRefUserDeletes.
+        if dst_rank != self.rank:
+            wait_until_owners_and_forks_on_rank(2, 2, rank=dst_rank)
+
+    @dist_init(faulty_messages=[], messages_to_delay={"PYTHON_REMOTE_CALL": 2})
+    def test_udf_remote_message_delay_timeout(self):
+        func = my_sleep_func
+        args = (2,)
+        self._test_remote_message_delay_timeout(func, args)
+
+    @dist_init(faulty_messages=[], messages_to_delay={"PYTHON_REMOTE_CALL": 2})
+    def test_udf_remote_message_delay_timeout_to_self(self):
+        func = my_sleep_func
+        args = (1,)
+        self._test_remote_message_delay_timeout(func, args, dst=0)
+
+    @dist_init(
+        faulty_messages=[],
+        messages_to_delay={"SCRIPT_REMOTE_CALL": 2, "SCRIPT_RREF_FETCH_CALL": 1},
+    )
+    def test_remote_message_builtin_delay_timeout(self):
+        func = torch.add
+        args = (torch.tensor(1), torch.tensor(1))
+        self._test_remote_message_delay_timeout(func, args)
+
+    @dist_init(
+        faulty_messages=[],
+        messages_to_delay={"SCRIPT_REMOTE_CALL": 2, "SCRIPT_RREF_FETCH_CALL": 1},
+    )
+    def test_remote_message_builtin_delay_timeout_to_self(self):
+        func = torch.add
+        args = (torch.tensor(1), torch.tensor(1))
+        self._test_remote_message_delay_timeout(func, args, dst=0)
+
+    @dist_init(
+        faulty_messages=[],
+        messages_to_delay={"SCRIPT_REMOTE_CALL": 2, "SCRIPT_RREF_FETCH_CALL": 1},
+    )
+    def test_remote_message_script_delay_timeout(self):
+        func = my_script_func
+        args = (torch.tensor(1),)
+        self._test_remote_message_delay_timeout(func, args)
+
+    @dist_init(
+        faulty_messages=[],
+        messages_to_delay={"SCRIPT_REMOTE_CALL": 2, "SCRIPT_RREF_FETCH_CALL": 1},
+    )
+    def test_remote_message_script_delay_timeout_to_self(self):
+        func = my_script_func
+        args = (torch.tensor(1),)
+        self._test_remote_message_delay_timeout(func, args, dst=0)
+
+    @dist_init(faulty_messages=[], messages_to_delay={"SCRIPT_RREF_FETCH_CALL": 1})
+    def test_rref_to_here_timeout(self):
+        if self.rank != 0:
+            return
+
+        dst_rank = (self.rank + 1) % self.world_size
+        dst_worker = f"worker{dst_rank}"
+        rref = rpc.remote(
+            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
+        )
+        expected_error = self.get_timeout_error_regex()
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            rref.to_here(0.01)
+
+        rref.to_here()
+
+    @dist_init(faulty_messages=[])
+    def test_rpc_builtin_timeout(self):
+        next_rank = (self.rank + 1) % self.world_size
+        dst_worker = worker_name(next_rank)
+        expected_error = self.get_timeout_error_regex()
+        # SCRIPT_CALL messages, which cover builtin ops invoked over RPC, get a
+        # delay by default (see faulty_rpc_agent_test_fixture)
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            rpc.rpc_sync(
+                dst_worker,
+                torch.add,
+                args=(torch.tensor(1), torch.tensor(1)),
+                timeout=1,
+            )
+
+        fut = rpc.rpc_async(
+            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1)), timeout=1
+        )
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            fut.wait()
+
+        # Ensure that the currently set default timeout is large enough such
+        # that RPCs with delays still complete.
+        fut = rpc.rpc_async(
+            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
+        )
+        fut.wait()
+
+        # Ensure timeout if we set a new default and don't override
+        rpc._set_rpc_timeout(0.001)
+        fut = rpc.rpc_async(
+            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
+        )
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            fut.wait()
+
+        # Ensure the RPC runs to completion if we specify a timeout of 0
+        fut = rpc.rpc_async(
+            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1)), timeout=0
+        )
+        fut.wait()
+        # Reset for clean shutdown
+        rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)
+
+    @dist_init(faulty_messages=[], messages_to_delay={"SCRIPT_CALL": 1.5})
+    def test_rpc_script_timeout(self):
+        next_rank = (self.rank + 1) % self.world_size
+        dst_worker = worker_name(next_rank)
+        expected_error = self.get_timeout_error_regex()
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            rpc.rpc_sync(dst_worker, my_script_func, args=(torch.tensor(1),), timeout=1)
+
+        fut = rpc.rpc_async(
+            dst_worker, my_script_func, args=(torch.tensor(1),), timeout=1
+        )
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            fut.wait()
+
+        # Ensure that the currently set default timeout is large enough such
+        # that RPCs with delays still complete.
+        fut = rpc.rpc_async(dst_worker, my_script_func, args=(torch.tensor(1),))
+        fut.wait()
+
+        # Ensure timeout if we set a new default and don't override
+        rpc._set_rpc_timeout(0.001)
+        fut = rpc.rpc_async(dst_worker, my_script_func, args=(torch.tensor(1),))
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            fut.wait()
+
+        # Ensure run to completion if we specify timeout of 0
+        rpc._set_rpc_timeout(0.001)
+        fut = rpc.rpc_async(
+            dst_worker, my_script_func, args=(torch.tensor(1),), timeout=0
+        )
+        fut.wait()
+        # Reset for clean shutdown
+        rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)
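As a compact reference for the timeout behaviors these tests exercise — a per-call timeout, a process-wide default set via `rpc._set_rpc_timeout`, and `timeout=0` meaning "no timeout" — here is a minimal eager-mode sketch. It assumes an already-initialized RPC group with a peer named "worker1"; the helper name is made up for illustration.

```python
import torch
import torch.distributed.rpc as rpc


def timeout_patterns(dst_worker: str = "worker1") -> None:
    # Hypothetical helper; assumes rpc.init_rpc(...) was called elsewhere.
    args = (torch.tensor(1), torch.tensor(1))

    # Per-call timeout in seconds; the future raises RuntimeError if exceeded.
    fut = rpc.rpc_async(dst_worker, torch.add, args=args, timeout=1)
    try:
        fut.wait()
    except RuntimeError:
        pass  # raised if the RPC exceeded 1 second

    # A small process-wide default applies when no per-call timeout is given.
    rpc._set_rpc_timeout(0.001)

    # timeout=0 disables the timeout for this call, overriding the default.
    rpc.rpc_async(dst_worker, torch.add, args=args, timeout=0).wait()

    # Restore the library default before shutdown.
    rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)
```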
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/faulty_rpc_agent_test_fixture.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/faulty_rpc_agent_test_fixture.py
new file mode 100644
index 0000000000000000000000000000000000000000..aff7d556d10621e7290c07ecb433b865d7133bb2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/faulty_rpc_agent_test_fixture.py
@@ -0,0 +1,64 @@
+# mypy: allow-untyped-defs
+
+import torch.distributed.rpc as rpc
+import torch.distributed.rpc._testing  # noqa: F401
+from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
+    RpcAgentTestFixture,
+)
+
+
+# The following message types are currently retried in the RREF protocol and
+# distributed autograd. Thus only these messages should be tested with the
+# Faulty RPC Agent.
+retryable_message_types = [
+    "RREF_FORK_REQUEST",
+    "RREF_CHILD_ACCEPT",
+    "RREF_USER_DELETE",
+    "CLEANUP_AUTOGRAD_CONTEXT_REQ",
+]
+
+# The following messages incur the corresponding delay in seconds while being
+# processed in FaultyTensorPipeAgent's enqueueSend() function.
+default_messages_to_delay = {
+    "PYTHON_CALL": 1.5,  # Python UDF
+    "SCRIPT_CALL": 1.5,  # Script/Builtin
+}
+
+
+class FaultyRpcAgentTestFixture(RpcAgentTestFixture):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.messages_to_fail = retryable_message_types
+        self.messages_to_delay = default_messages_to_delay
+
+    @property
+    def rpc_backend(self):
+        return rpc.backend_registry.BackendType["FAULTY_TENSORPIPE"]
+
+    @property
+    def rpc_backend_options(self):
+        return rpc.backend_registry.construct_rpc_backend_options(
+            self.rpc_backend,
+            init_method=self.init_method,
+            num_worker_threads=8,
+            num_fail_sends=3,
+            messages_to_fail=self.messages_to_fail,
+            messages_to_delay=self.messages_to_delay,
+        )
+
+    def setup_fault_injection(self, faulty_messages, messages_to_delay):
+        if faulty_messages is not None:
+            self.messages_to_fail = faulty_messages
+        if messages_to_delay is not None:
+            self.messages_to_delay = messages_to_delay
+
+    def get_shutdown_error_regex(self):
+        error_regexes = [
+            "Exception in thread pool task",
+            "Connection reset by peer",
+            "Connection closed by peer",
+        ]
+        return "|".join([f"({error_str})" for error_str in error_regexes])
+
+    def get_timeout_error_regex(self):
+        return "RPC ran for more than"
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/__init__.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/dist_autograd_test.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/dist_autograd_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..fde1fe2355c2968e1b351b288d20c674835b0ca2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/dist_autograd_test.py
@@ -0,0 +1,113 @@
+# mypy: allow-untyped-defs
+
+
+import torch
+import torch.distributed.autograd as dist_autograd
+import torch.distributed.rpc as rpc
+from torch import Tensor
+from torch.distributed.rpc import rpc_async
+from torch.testing import FileCheck
+from torch.testing._internal.dist_utils import dist_init, worker_name
+from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
+    RpcAgentTestFixture,
+)
+
+
+@torch.jit.script
+def local_add(t1, t2):
+    return torch.add(t1, t2)
+
+
+@torch.jit.script
+def remote_add(t1, t2, dst: str):  # noqa: E999
+    return rpc_async(dst, local_add, (t1, t2)).wait()
+
+
+@torch.jit.script
+def fork_add(t1, t2, dst: str):
+    fut = torch.jit._fork(remote_add, t1, t2, dst)
+    return torch.jit._wait(fut)
+
+
+class JitDistAutogradTest(RpcAgentTestFixture):
+    @dist_init
+    def test_get_gradients(self):
+        @torch.jit.script
+        def dist_get_gradients(context_id: int) -> dict[Tensor, Tensor]:
+            return dist_autograd.get_gradients(context_id)
+
+        FileCheck().check("get_gradients").run(str(dist_get_gradients.graph))
+        with dist_autograd.context() as context_id:
+            t1 = torch.rand((3, 3), requires_grad=True)
+            t2 = torch.rand((3, 3), requires_grad=True)
+            t3 = torch.add(t1, t2)
+
+            dist_autograd.backward(context_id, [t3.sum()])
+            grads = dist_get_gradients(context_id)
+
+            self.assertEqual(2, len(grads))
+            self.assertIn(t1, grads)
+            self.assertIn(t2, grads)
+            self.assertEqual(torch.ones(3, 3), grads[t1])
+            self.assertEqual(torch.ones(3, 3), grads[t2])
+
+    @dist_init
+    def test_dist_backward(self):
+        if self.rank != 0:
+            return
+
+        @torch.jit.script
+        def dist_backward_script(context_id: int, loss: torch.Tensor):
+            dist_autograd.backward(context_id, [loss])
+
+        FileCheck().check("dist_backward").run(str(dist_backward_script.graph))
+        with dist_autograd.context() as context_id:
+            t1 = torch.rand(3, 3, requires_grad=True)
+            t2 = torch.rand(3, 3, requires_grad=True)
+            dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+            loss = rpc.rpc_sync(dst_worker_name, torch.add, args=(t1, t2)).sum()
+            dist_backward_script(context_id, loss)
+
+    @dist_init
+    def test_jit_fork_within_context(self):
+        with dist_autograd.context() as context_id:
+            t1 = torch.rand((3, 3), requires_grad=True)
+            t2 = torch.rand((3, 3), requires_grad=True)
+            dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+            res = fork_add(t1, t2, dst_worker_name)
+            loss = res.sum()
+            dist_autograd.backward(context_id, [loss])
+
+            grads = dist_autograd.get_gradients(context_id)
+            self.assertEqual(2, len(grads))
+            self.assertIn(t1, grads)
+            self.assertIn(t2, grads)
+
+    @dist_init
+    def test_restore_context_after_switch_to_jit_thread(self):
+        if self.rank != 0:
+            return
+
+        @torch.jit.script
+        def forward_script(
+            context_id: int, dst_worker_name: str, t1: Tensor, t2: Tensor
+        ) -> tuple[Tensor, Tensor]:
+            res1_fut = rpc.rpc_async(dst_worker_name, local_add, (t1, t1))
+            res1 = res1_fut.wait()  # After this, the script runs in a new JIT thread.
+            loss1 = res1.sum()
+
+            # SendRpcBackward is not attached, since DistAutogradContext is lost here.
+            res2_fut = rpc.rpc_async(dst_worker_name, local_add, (t2, t2))
+            res2 = res2_fut.wait()
+            loss2 = res2.sum()
+
+            return loss1, loss2
+
+        with dist_autograd.context() as context_id:
+            t1 = torch.ones((2, 3), requires_grad=True)
+            t2 = torch.ones((2, 3), requires_grad=True)
+            dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+            loss0, loss1 = forward_script(context_id, dst_worker_name, t1, t2)
+            dist_autograd.backward(context_id, [loss0, loss1])
+            grad0, grad1 = dist_autograd.get_gradients(context_id)
+            self.assertEqual(grad0, grad1)
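For orientation, the end-to-end pattern these JIT tests build on — run an RPC inside a distributed autograd context, call the distributed backward pass, then read back the accumulated gradients — looks roughly like this eager-mode sketch. Worker names and world size are assumptions; it presumes `rpc.init_rpc` has already been called on each worker.

```python
import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc

# Assumes this process is "worker0" and a peer named "worker1" is up.
t1 = torch.rand(3, 3, requires_grad=True)
t2 = torch.rand(3, 3, requires_grad=True)

with dist_autograd.context() as context_id:
    # The remote add records Send/RecvRpcBackward functions in this context.
    loss = rpc.rpc_sync("worker1", torch.add, args=(t1, t2)).sum()
    dist_autograd.backward(context_id, [loss])
    grads = dist_autograd.get_gradients(context_id)  # maps t1 and t2 to their grads
```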
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/rpc_test.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/rpc_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec2f2b9499074c4d6bc646e82fda06cc7d07b7f3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/rpc_test.py
@@ -0,0 +1,1383 @@
+# mypy: allow-untyped-defs
+
+import io
+import time
+from typing import Any
+
+import torch
+import torch.distributed as dist
+import torch.distributed.rpc as rpc
+from torch import Tensor
+from torch.autograd.profiler import record_function
+from torch.autograd.profiler_legacy import profile as _profile
+from torch.distributed.rpc import RRef
+from torch.distributed.rpc.internal import _build_rpc_profiling_key, RPCExecMode
+from torch.futures import Future
+from torch.testing._internal.common_utils import TemporaryFileName
+from torch.testing._internal.dist_utils import (
+    dist_init,
+    get_function_event,
+    initialize_pg,
+    worker_name,
+)
+from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
+    RpcAgentTestFixture,
+)
+
+
+def rref_isinstance(rref, cls_to_check):
+    return isinstance(rref.local_value(), cls_to_check)
+
+
+def sleep(t):
+    time.sleep(t)
+
+
+def rpc_return_rref(dst):
+    return rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 1))
+
+
+@torch.jit.script
+def rref_local_value(rref: RRef[Tensor]) -> Tensor:
+    return rref.local_value()
+
+
+@torch.jit.script
+def list_create() -> list[int]:
+    global_list = [1, 2, 3]
+    return global_list
+
+
+@torch.jit.script
+def rref_list_mutate(rref: RRef[list[int]]) -> None:
+    rref.local_value().append(4)
+    rref.to_here().append(5)
+    rref.to_here(5.0).append(6)
+
+
+def return_value(value: int) -> int:
+    return value
+
+
+class RRefAPITest:
+    @dist_init
+    def test_rref_is_owner(self):
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+        rref_var = rpc_return_rref(dst_worker_name)
+
+        @torch.jit.script
+        def rref_tensor_is_owner(rref_var: RRef[Tensor]) -> bool:
+            return rref_var.is_owner()
+
+        res = rref_tensor_is_owner(rref_var)
+        self.assertEqual(res, False)
+
+    @dist_init
+    def test_rref_local_value(self):
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+        rref = rpc_return_rref(dst_worker_name)
+
+        with self.assertRaisesRegex(
+            RuntimeError, r"Can't call RRef.local_value\(\) on a non-owner RRef"
+        ):
+            rref_local_value(rref)
+
+        ret = rpc.rpc_sync(dst_worker_name, rref_local_value, (rref,))
+        self.assertEqual(ret, torch.add(torch.ones(2, 2), 1))
+
+    @dist_init
+    def test_local_rref_local_value(self):
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name(self.rank)
+        rref = rpc.remote(dst_worker_name, return_value, (5,), {})
+
+        ret = rref_local_value(rref)
+        self.assertEqual(ret, 5)
+
+    def _create_rref(self):
+        owner_rank = (self.rank + 2) % self.world_size
+        return rpc.remote(
+            worker_name(owner_rank), torch.add, args=(torch.zeros(2, 2), 1)
+        )
+
+    @dist_init
+    def test_user_rrefs_confirmed(self):
+        dst_rank = (self.rank + 1) % self.world_size
+        rref = self._create_rref()
+        ret = rpc.rpc_sync(
+            worker_name(dst_rank), script_check_rref_confirmed, args=(rref,)
+        )
+        self.assertEqual(ret, True)
+
+    @dist_init
+    def test_user_rrefs_confirmed_remote(self):
+        dst_rank = (self.rank + 1) % self.world_size
+        rref = self._create_rref()
+        ret_rref = rpc.remote(
+            worker_name(dst_rank), script_check_rref_confirmed, args=(rref,)
+        )
+        self.assertEqual(ret_rref.to_here(), True)
+
+    @dist_init
+    def test_rref_list_mutate(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        list_rref = rpc.remote(dst, list_create)
+
+        rpc.rpc_sync(dst, rref_list_mutate, args=(list_rref,))
+        self.assertEqual(list_rref.to_here(), [1, 2, 3, 4, 5, 6])
+
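To summarize the RRef lifecycle the tests above rely on, here is a small eager-mode sketch distinguishing the owner-side and user-side views of an RRef. It assumes an initialized RPC group with a peer named "worker1".

```python
import torch
import torch.distributed.rpc as rpc

# User side: create an RRef whose value lives on the owner, "worker1".
rref = rpc.remote("worker1", torch.add, args=(torch.ones(2, 2), 1))

rref.is_owner()         # False on this worker; True only on the owning worker.
value = rref.to_here()  # fetches a copy of the owned value to this worker

# local_value() is only legal on the owner; elsewhere it raises RuntimeError.
# On "worker1", rref.local_value() would return the owned tensor directly.
```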
+
+@torch.jit.script
+def no_arg():
+    return 0
+
+
+@torch.jit.script
+def one_arg(value):
+    return value + 1
+
+
+@torch.jit.script
+def script_add_ones(x):
+    return torch.add(x, torch.ones(1))
+
+
+@torch.jit.script
+def script_add_ones_with_record_function(x, block: str):
+    with record_function(block):
+        return torch.add(x, torch.ones(1))
+
+
+@torch.jit.script
+def record_function_on_caller_rpc_async(dst_worker_name: str, block: str) -> Tensor:
+    t: Tensor = torch.ones(1)
+    with record_function(block):
+        fut1 = rpc.rpc_async(dst_worker_name, script_add_ones, (t,))
+        # Extra operator call to avoid de-duplication of the next async call
+        # see https://github.com/pytorch/pytorch/pull/62710#discussion_r694680279
+        zero = torch.zeros_like(t)
+        fut2 = rpc.rpc_async(dst_worker_name, script_add_ones, (t,))
+        res = fut1.wait() + fut2.wait() + zero
+    return res
+
+
+@torch.jit.script
+def script_fork_wait_udf(tensor):
+    fut = torch.jit._fork(script_add_ones, tensor)
+    x = torch.jit._wait(fut)
+    return x
+
+
+@torch.jit.script
+def rref_to_here(rref_var: RRef[Tensor]) -> Tensor:
+    return rref_var.to_here()
+
+
+@torch.jit.script
+def return_rref(rref_var: RRef[Tensor]) -> RRef[Tensor]:
+    return rref_var
+
+
+@torch.jit.script
+def script_raise_func(value):
+    if value.numel() == 2:
+        raise ValueError("Expected error")
+    return value + 1
+
+
+@torch.jit.script
+def script_fork_wait_throw(invalue):
+    fut = torch.jit._fork(script_raise_func, invalue)
+    value = torch.jit._wait(fut)
+    return value
+
+
+@torch.jit.script
+def call_rpc_with_profiling(
+    record: torch.classes.profiler._RecordFunction, dst_worker_name: str
+) -> Tensor:
+    # Call rpc_async from within ScriptFunction and ensure that we can attach
+    # profiling callbacks. Note that handle here is a Tensor representation of
+    # RecordFunction.
+    fut = rpc.rpc_async(dst_worker_name, one_arg, (torch.tensor(1),))
+    torch.ops.profiler._call_end_callbacks_on_jit_fut(record, fut)
+    ret = fut.wait()
+    return ret
+
+
+@torch.jit.script
+def call_rpc_torchscript_with_record_function(
+    dst_worker_name: str, block: str
+) -> Tensor:
+    fut = rpc.rpc_async(
+        dst_worker_name, script_add_ones_with_record_function, (torch.tensor(1), block)
+    )
+    return fut.wait()
+
+
+@torch.jit.script
+def call_fork_with_profiling(record: torch.classes.profiler._RecordFunction) -> Tensor:
+    # Call fork from within ScriptFunction and ensure that we can attach profiling
+    # callbacks to the resulting future. Note that handle here is a Tensor
+    # representation of RecordFunction.
+    fut = torch.jit._fork(one_arg, torch.tensor(1))
+    torch.ops.profiler._call_end_callbacks_on_jit_fut(record, fut)
+    ret = fut.wait()
+    return ret
+
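The same "end the profiling range when the future completes" idea can also be used from eager Python, which is what test_record_function_jit_end_callbacks_with_fork later in this file checks. A minimal sketch, assuming nothing beyond the profiler and torch.jit._fork APIs already used in this file:

```python
import time
import torch
from torch.autograd.profiler_legacy import profile as _profile


def sleep(t: float) -> None:
    time.sleep(t)


with _profile() as prof:
    with torch.autograd.profiler.record_function("foo") as rf:
        fut = torch.jit._fork(sleep, 1.0)
        # Close the "foo" range only once the forked work has finished.
        rf._call_end_callbacks_on_future(fut)
    fut.wait()
# prof.function_events now contains a "foo" event spanning the sleep.
```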
+
+class MyScriptModuleWithRRefs(torch.jit.ScriptModule):
+    def __init__(self, dst_worker):
+        super().__init__()
+        self.rrefs = []
+        for _ in range(4):
+            self.rrefs.append(rpc_return_rref(dst_worker))
+
+    @torch.jit.script_method
+    def forward(self) -> Tensor:
+        res_tensor = torch.ones(2, 2)
+        for rref in self.rrefs:
+            res_tensor += rref.to_here()
+
+        return res_tensor
+
+
+@torch.jit.ignore
+def rref_python_annotation(rref_var: RRef[Tensor]) -> RRef[Tensor]:
+    return rref_var
+
+
+@torch.jit.script
+def rref_script_annotation(rref_var: RRef[Tensor]) -> Tensor:
+    return rref_python_annotation(rref_var).to_here()
+
+
+class RRefTypingTest:
+    @dist_init
+    def test_rref_as_arg_and_return(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        local_ret = one_arg(torch.ones(2, 2))
+
+        # create rref on current rank
+        rref = rpc.remote(worker_name(self.rank), one_arg, args=(torch.ones(2, 2),))
+
+        # pass rref to another user in rpc call
+        ret = rpc.rpc_sync(worker_name(dst_rank), rref_to_here, args=(rref,))
+        self.assertEqual(ret, local_ret)
+
+        # return rref in rpc call
+        rref1 = rpc.rpc_sync(worker_name(dst_rank), return_rref, args=(rref,))
+        self.assertEqual(rref1.to_here(), local_ret)
+
+        # pass rref to another user in remote call
+        rref2 = rpc.remote(worker_name(dst_rank), rref_to_here, args=(rref,))
+        self.assertEqual(rref2.to_here(), local_ret)
+
+        # return rref in remote call
+        rref3 = rpc.remote(worker_name(dst_rank), return_rref, args=(rref,))
+        self.assertEqual(rref3.to_here().to_here(), local_ret)
+
+    @dist_init
+    def test_my_script_module_with_rrefs(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+
+        module_with_rrefs = MyScriptModuleWithRRefs(worker_name(dst_rank))
+        res = module_with_rrefs()
+        self.assertEqual(res, torch.ones(2, 2) * 9)
+
+    @dist_init
+    def test_rref_python_annotation(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        rref_var = rpc_return_rref(worker_name(dst_rank))
+
+        res = rref_script_annotation(rref_var)
+        self.assertEqual(res, torch.ones(2, 2) + 1)
+
+
+class FutureTypingTest:
+    @dist_init
+    def test_future_passed_between_python_and_jit(self):
+        dst_rank = (self.rank + 1) % self.world_size
+        inputs = (torch.tensor([1, 1]), torch.tensor([2, 2]))
+        ret_fut = rpc.rpc_async(worker_name(dst_rank), two_args_two_kwargs, args=inputs)
+        expected_res = torch.tensor([10, 10])
+
+        @torch.jit.script
+        def future_wait_in_script(fut: Future[Tensor]) -> Tensor:
+            return fut.wait()
+
+        self.assertEqual(future_wait_in_script(ret_fut), expected_res)
+
+        @torch.jit.script
+        def future_return_to_python(
+            dst_rank: int, inputs: tuple[Tensor, Tensor]
+        ) -> Future[Tensor]:
+            return rpc.rpc_async(f"worker{dst_rank}", two_args_two_kwargs, inputs)
+
+        fut_res = future_return_to_python(dst_rank, inputs)
+        self.assertEqual(fut_res.wait(), expected_res)
+
+    @dist_init
+    def test_future_python_annotation(self):
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+        input_0 = torch.ones(2, 2)
+        input_1 = 1
+        expected_res = torch.add(input_0, input_1)
+
+        @torch.jit.ignore
+        def python_return_future() -> Future[Tensor]:
+            fut = rpc.rpc_async(dst_worker_name, torch.add, (input_0, input_1), {})
+            return fut
+
+        @torch.jit.script
+        def script_use_future() -> Tensor:
+            fut = python_return_future()
+            return fut.wait()
+
+        res = script_use_future()
+        self.assertEqual(res, expected_res)
+
+
+@torch.jit.script
+class MyScriptClass:
+    def __init__(self, a: int):
+        self.a = a
+
+    def get_value(self) -> int:
+        return self.a
+
+
+@torch.jit.interface
+class MyModuleInterface(torch.nn.Module):
+    def forward(self) -> Tensor:
+        # pyre-ignore[7]: Pyre and torch.jit.interface don't mix well
+        pass
+
+
+class MyScriptModule(torch.jit.ScriptModule):
+    def __init__(self, rank):
+        super().__init__()
+        self.a = torch.ones(rank)
+
+    @torch.jit.script_method
+    def forward(self) -> Tensor:
+        return self.a
+
+    @torch.jit.script_method
+    def custom_func(self) -> Tensor:
+        return self.a
+
+
+def owner_create_rref_my_script_class(a):
+    return rpc.RRef(MyScriptClass(a))
+
+
+def owner_create_rref_my_script_module(a):
+    return rpc.RRef(MyScriptModule(a), type_hint=MyModuleInterface)
+
+
+@torch.jit.script
+def script_rref_get_value_my_script_class(rref: RRef[MyScriptClass]) -> int:
+    return rref.to_here().get_value()
+
+
+@torch.jit.script
+def script_rref_run_forward_my_script_module(rref: RRef[MyModuleInterface]) -> Tensor:
+    return rref.to_here().forward()
+
+
+class LocalRRefTest:
+    @dist_init
+    def test_create_local_script_class_rref_in_py(self):
+        if self.rank != 0:
+            return
+
+        # Create a local RRef.
+        rref_script_class = rpc.RRef(MyScriptClass(self.rank))
+        ret = rref_script_class.to_here().get_value()
+        self.assertEqual(ret, self.rank)
+
+    @dist_init
+    def test_create_local_script_module_rref_in_py(self):
+        if self.rank != 0:
+            return
+
+        # Create a local RRef.
+        rref_script_module = rpc.RRef(MyScriptModule(self.rank), MyModuleInterface)
+        ret = rref_script_module.to_here().forward()
+        self.assertEqual(ret, torch.ones(self.rank))
+
+        # Create a local RRef without type hint.
+        with self.assertRaisesRegex(
+            RuntimeError,
+            (
+                "The RRef being created contains a ScriptModule, "
+                "must provide its ModuleInterface type hint."
+            ),
+        ):
+            rref_script_module = rpc.RRef(MyScriptModule(self.rank))
+
+    @dist_init
+    def test_return_local_script_class_rref_in_py_and_use_in_script(self):
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+
+        # Create a local RRef remotely in Python.
+        rref = rpc.rpc_sync(
+            dst_worker_name, owner_create_rref_my_script_class, args=(self.rank,)
+        )
+
+        def use_rref_on_owner(rref: RRef[MyScriptClass]) -> int:
+            args = (rref,)
+            kwargs: dict[str, Any] = {}
+            fut = rpc.rpc_async(
+                rref.owner(), script_rref_get_value_my_script_class, args, kwargs
+            )
+            ret = fut.wait()
+            return ret
+
+        # Use RRef in local Python RPC and remote Script run.
+        ret = use_rref_on_owner(rref)
+        self.assertEqual(ret, self.rank)
+
+        # Use RRef in local Script RPC and remote Script run.
+        use_rref_on_owner_script = torch.jit.script(use_rref_on_owner)
+        ret = use_rref_on_owner_script(rref)
+        self.assertEqual(ret, self.rank)
+
+    @dist_init
+    def test_return_local_script_module_rref_in_py_and_use_in_script(self):
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+
+        # Create a local RRef remotely in Python.
+        rref = rpc.rpc_sync(
+            dst_worker_name, owner_create_rref_my_script_module, args=(self.rank,)
+        )
+
+        def use_rref_on_owner(rref: RRef[MyModuleInterface]) -> Tensor:
+            args = (rref,)
+            kwargs: dict[str, Any] = {}
+            fut = rpc.rpc_async(
+                rref.owner_name(),
+                script_rref_run_forward_my_script_module,
+                args,
+                kwargs,
+            )
+            ret = fut.wait()
+            return ret
+
+        # Use RRef in local Python RPC and remote Script run.
+        ret = use_rref_on_owner(rref)
+        self.assertEqual(ret, torch.ones(self.rank))
+
+        # Use RRef in local Script RPC and remote Script run.
+        use_rref_on_owner_script = torch.jit.script(use_rref_on_owner)
+        ret = use_rref_on_owner_script(rref)
+        self.assertEqual(ret, torch.ones(self.rank))
+
+
+def python_function():
+    return 0
+
+
+@torch.jit.script
+def two_args_two_kwargs(
+    first_arg,
+    second_arg,
+    first_kwarg=torch.tensor([3, 3]),
+    second_kwarg=torch.tensor([4, 4]),
+):
+    return first_arg + second_arg + first_kwarg + second_kwarg
+
+
+@torch.jit.script
+def assorted_types_args_kwargs(
+    tensor_arg: Tensor,  # noqa: E999
+    str_arg: str,
+    int_arg: int,
+    tensor_kwarg: Tensor = torch.tensor([2, 2]),
+    str_kwarg: str = "str_kwarg",
+    int_kwarg: int = 2,
+):
+    return tensor_arg + tensor_kwarg, str_arg + str_kwarg, int_arg + int_kwarg
+
+
+@torch.jit.script
+def raise_script():
+    raise RuntimeError("Expected error")
+
+
+@torch.jit.script
+def script_rpc_async_call(
+    dst_worker_name: str, args: tuple[Tensor, Tensor], kwargs: dict[str, Tensor]
+):
+    fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
+    ret = fut.wait()
+    return ret
+
+
+@torch.jit.script
+def script_rpc_sync_call(
+    dst_worker_name: str, args: tuple[Tensor, Tensor], kwargs: dict[str, Tensor]
+):
+    res = rpc.rpc_sync(dst_worker_name, two_args_two_kwargs, args, kwargs)
+    return res
+
+
+@torch.jit.script
+def script_rpc_remote_call(
+    dst_worker_name: str, args: tuple[Tensor, Tensor], kwargs: dict[str, Tensor]
+):
+    rref_res = rpc.remote(dst_worker_name, two_args_two_kwargs, args, kwargs)
+    return rref_res.to_here()
+
+
+class JitRpcOpTest:
+    # Call functions remotely from Script.
+    @dist_init
+    def test_all_kwargs_are_populated_by_defaults(self):
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+
+        args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
+        kwargs = {}
+
+        for script_op in [
+            script_rpc_async_call,
+            script_rpc_sync_call,
+            script_rpc_remote_call,
+        ]:
+            ret = script_op(dst_worker_name, args, kwargs)
+            self.assertEqual(ret, torch.tensor([10, 10]))
+
+    @dist_init
+    def test_some_kwargs_are_populated_by_defaults(self):
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+
+        args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
+        kwargs = {"first_kwarg": torch.tensor([2, 2])}
+
+        for script_op in [
+            script_rpc_async_call,
+            script_rpc_sync_call,
+            script_rpc_remote_call,
+        ]:
+            ret = script_op(dst_worker_name, args, kwargs)
+            self.assertEqual(ret, torch.tensor([9, 9]))
+
+    @dist_init
+    def test_no_kwargs_are_populated_by_defaults(self):
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+
+        args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
+        kwargs = {
+            "first_kwarg": torch.tensor([2, 2]),
+            "second_kwarg": torch.tensor([3, 3]),
+        }
+        for script_op in [
+            script_rpc_async_call,
+            script_rpc_sync_call,
+            script_rpc_remote_call,
+        ]:
+            ret = script_op(dst_worker_name, args, kwargs)
+            self.assertEqual(ret, torch.tensor([8, 8]))
+
+    @dist_init
+    def test_args_and_kwargs_contain_different_types(self):
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+
+        @torch.jit.script
+        def script_rpc_async_call_with_assorted_types(
+            dst_worker_name: str,
+        ):
+            args = (torch.tensor([1, 1]), "str_arg", 1)
+            # Must annotate the value type as `Any`, because JIT type inference
+            # does not support multiple types when defining a Dict.
+            # The error JIT gives is,
+            # "Dict values must contain only a single type, "
+            # "expected: Tensor but found str instead."
+            kwargs: dict[str, Any] = {
+                "tensor_kwarg": torch.tensor([3, 3]),
+                "str_kwarg": "_str_kwarg",
+                "int_kwarg": 3,
+            }
+            fut = rpc.rpc_async(
+                dst_worker_name, assorted_types_args_kwargs, args, kwargs
+            )
+            ret = fut.wait()
+            return ret
+
+        ret = script_rpc_async_call_with_assorted_types(dst_worker_name)
+        self.assertEqual(ret, (torch.tensor([4, 4]), "str_arg_str_kwarg", 4))
+
+    @dist_init
+    def test_kwargs_not_passed(self):
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+
+        @torch.jit.script
+        def script_rpc_async_call_without_kwargs_passed(
+            dst_worker_name: str,
+        ):
+            args = ()
+            fut = rpc.rpc_async(dst_worker_name, no_arg, args)
+            ret = fut.wait()
+            return ret
+
+        ret = script_rpc_async_call_without_kwargs_passed(dst_worker_name)
+        self.assertEqual(ret, 0)
+
+    @dist_init
+    def test_args_kwargs_are_neither_passed(self):
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+
+        @torch.jit.script
+        def script_rpc_async_call_without_args_kwargs_passed(
+            dst_worker_name: str,
+        ):
+            fut = rpc.rpc_async(dst_worker_name, no_arg)
+            ret = fut.wait()
+            return ret
+
+        ret = script_rpc_async_call_without_args_kwargs_passed(dst_worker_name)
+        self.assertEqual(ret, 0)
+
+    @dist_init
+    def test_less_than_needed_args_are_specified(self):
+        if self.rank != 0:
+            return
+
+        # Notice, args matching happens during scripting.
+        with self.assertRaisesRegex(RuntimeError, "Argument second_arg not provided"):
+
+            @torch.jit.script
+            def script_rpc_async_call_with_less_args(
+                dst_worker_name: str,  # noqa: E999
+            ):
+                args = (torch.tensor([1, 1]),)
+                kwargs = {}
+                fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
+                ret = fut.wait()
+                return ret
+
+    @dist_init
+    def test_more_than_needed_args_are_specified(self):
+        if self.rank != 0:
+            return
+
+        # Notice, args matching happens during scripting.
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "Expected at most 4 arguments but found 5 positional arguments",
+        ):
+
+            @torch.jit.script
+            def script_rpc_async_call_with_more_args(
+                dst_worker_name: str,
+            ):
+                args = (
+                    torch.tensor([1, 1]),
+                    torch.tensor([2, 2]),
+                    torch.tensor([3, 3]),
+                    torch.tensor([4, 4]),
+                    torch.tensor([5, 5]),
+                )
+                kwargs = {}
+                fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
+                ret = fut.wait()
+                return ret
+
+    @dist_init
+    def test_unexpected_kwarg_is_specified(self):
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+
+        # Notice, kwargs matching happens during execution.
+        @torch.jit.script
+        def script_rpc_async_call_with_unexpected_kwarg(
+            dst_worker_name: str,  # noqa: E999
+        ):
+            args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
+            kwargs = {"third_kwarg": torch.tensor([1, 1])}
+            fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
+            ret = fut.wait()
+            return ret
+
+        with self.assertRaisesRegex(
+            RuntimeError, "Unknown keyword argument 'third_kwarg'"
+        ):
+            ret = script_rpc_async_call_with_unexpected_kwarg(dst_worker_name)
+            self.assertEqual(ret, 0)
+
+    @dist_init
+    def test_call_python_function_remotely_from_script_not_supported(self):
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+
+        @torch.jit.script
+        def rpc_async_call_remote_py_function_in_torchscript(dst_worker_name: str):
+            args = ()
+            kwargs = {}
+            fut = rpc.rpc_async(dst_worker_name, python_function, args, kwargs)
+            ret = fut.wait()
+            return ret
+
+        with self.assertRaisesRegex(
+            RuntimeError, "attempted to get undefined function"
+        ):
+            ret = rpc_async_call_remote_py_function_in_torchscript(dst_worker_name)
+            self.assertEqual(ret, 0)
+
+    @dist_init
+    def test_call_script_function_that_raises_remotely_from_script(self):
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+
+        # Note that TorchScript always translates (emits) a Python `raise` statement
+        # with the generic exception message "Exception", regardless of the exception
+        # type and exception message used in the statement.
+        @torch.jit.script
+        def rpc_async_call_remote_raising_torchscript_in_torchscript(
+            dst_worker_name: str,
+        ):
+            args = ()
+            kwargs = {}
+            fut = rpc.rpc_async(dst_worker_name, raise_script, args, kwargs)
+            ret = fut.wait()
+            return ret
+
+        with self.assertRaisesRegex(RuntimeError, "Expected error"):
+            ret = rpc_async_call_remote_raising_torchscript_in_torchscript(
+                dst_worker_name
+            )
+            self.assertEqual(ret, 0)
+
+    @dist_init
+    def test_call_script_function_that_not_exists_remotely_from_script(self):
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+
+        @torch.jit.script
+        def nonexisting_script():
+            return 0
+
+        @torch.jit.script
+        def rpc_async_call_remote_nonexisting_torchscript_in_torchscript(
+            dst_worker_name: str,
+        ):
+            args = ()
+            kwargs = {}
+            fut = rpc.rpc_async(dst_worker_name, nonexisting_script, args, kwargs)
+            ret = fut.wait()
+            return ret
+
+        with self.assertRaisesRegex(
+            RuntimeError, "attempted to get undefined function nonexisting_script"
+        ):
+            ret = rpc_async_call_remote_nonexisting_torchscript_in_torchscript(
+                dst_worker_name
+            )
+            self.assertEqual(ret, 0)
+
+
+@torch.jit.ignore
+def my_script_module_init(rank: int) -> MyModuleInterface:
+    return MyScriptModule(rank)
+
+
+@torch.jit.script
+def construct_my_script_module(rank: int) -> MyModuleInterface:
+    return my_script_module_init(rank)
+
+
+@torch.jit.script
+def run_ref_script_module(
+    ref_script_module: RRef[MyModuleInterface], t: Tensor
+) -> Tensor:
+    module = ref_script_module.to_here()
+    return module.forward() + t
+
+
+@torch.jit.script
+def script_check_rref_confirmed(rref: RRef[Tensor]) -> bool:
+    return rref.confirmed_by_owner()
+
+
+@torch.jit.script
+def save_rref(rref_var: RRef[Tensor], fname: str) -> None:
+    torch.save(rref_var, fname)
+
+
+@torch.jit.script
+def script_add(x: Tensor, y: Tensor) -> Tensor:
+    return x + y
+
+
+@rpc.functions.async_execution
+@torch.jit.script
+def async_add(to: str, x: Tensor, y: Tensor) -> Future[Tensor]:
+    return rpc.rpc_async(to, script_add, (x, y))
+
+
+@rpc.functions.async_execution
+@torch.jit.script
+def async_wrong_type() -> Tensor:
+    return torch.zeros(2)
+
+
+def load_script_module_with_pickled_rref(pickled_script_module):
+    f = io.BytesIO(pickled_script_module)
+    m = torch.jit.load(f)
+    return m()
+
+
+class JitRpcTest(
+    RRefAPITest,
+    RRefTypingTest,
+    LocalRRefTest,
+    JitRpcOpTest,
+    FutureTypingTest,
+    RpcAgentTestFixture,
+):
+    @dist_init
+    def test_torchscript_function(self):
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+        local_ret = one_arg(torch.ones(2, 2))
+        ret = rpc.rpc_sync(dst_worker_name, one_arg, args=(torch.ones(2, 2),))
+        self.assertEqual(ret, local_ret)
+        rref = rpc.remote(dst_worker_name, one_arg, args=(torch.ones(2, 2),))
+        self.assertEqual(rref.to_here(), local_ret)
+        # create rref to itself
+        local_rref = rpc.remote(
+            worker_name(self.rank), one_arg, args=(torch.ones(2, 2),)
+        )
+        self.assertEqual(local_rref.to_here(), local_ret)
+
+    @dist_init
+    def test_torchscript_function_exception(self):
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+        with self.assertRaisesRegex(RuntimeError, r"one_arg\(\) expected at most"):
+            rpc.rpc_sync(dst_worker_name, one_arg, args=(10, 20))
+
+        with self.assertRaisesRegex(RuntimeError, r"one_arg\(\) expected at most"):
+            rpc.remote(dst_worker_name, one_arg, args=(10, 20))
+
+    @dist_init
+    def test_torchscript_functions_not_supported(self):
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+
+        my_local_script_module = MyScriptModule(self.rank)
+
+        # Instantiating MyScriptModule in multiple threads is not thread safe, so
+        # wait for the local MyScriptModule instantiation to finish; otherwise it
+        # could be instantiated in parallel with the server thread below.
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        dist.barrier()
+
+        # rpc_sync still accepts a script class and runs it in the same code
+        # path as a Python call.
+        rpc.rpc_sync(dst_worker_name, MyScriptClass, args=(self.rank,))
+
+        # rpc_sync does not accept a script module method.
+        # Python 3.5 and Python 3.6 throw different error messages; the only
+        # common word that can be grepped for is "pickle".
+        with self.assertRaisesRegex(TypeError, "pickle"):
+            rpc.rpc_async(dst_worker_name, my_local_script_module.forward, args=())
+
+    @dist_init
+    def test_remote_script_module(self):
+        # TODO: needs more investigation. There is an RRef leak when shutting
+        # down; we suspect it is because the RRef passed as an arg crosses the
+        # pybind boundary and is not garbage collected by Python when calling
+        # shutdown().
+        import torch.distributed.rpc.api as api
+
+        api._ignore_rref_leak = True
+
+        local_ret = torch.ones(self.rank) + torch.ones(self.rank)
+
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        remote_ref = rpc.remote(
+            worker_name(dst_rank), construct_my_script_module, args=(self.rank,)
+        )
+
+        # pass rref arg to owner
+        ret = rpc.rpc_sync(
+            worker_name(dst_rank),
+            run_ref_script_module,
+            args=(remote_ref, torch.ones(self.rank)),
+        )
+        self.assertEqual(ret, local_ret)
+
+        # pass rref arg to self/user
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "is an RRef to a ScriptModule. It can't be sent through RPC from owner,",
+        ):
+            ret = rpc.rpc_sync(
+                worker_name(self.rank),
+                run_ref_script_module,
+                args=(remote_ref, torch.ones(self.rank)),
+            )
+
+    @dist_init
+    def test_create_script_module_on_remote(self):
+        dst_name = worker_name((self.rank + 1) % self.world_size)
+        # Construct on remote end with rpc_sync
+        created_script_module = rpc.rpc_sync(
+            dst_name, MyScriptModule, args=(self.rank,)
+        )
+        # Forward should output a ones tensor of self.rank.
+        self.assertTrue(isinstance(created_script_module, torch.jit.ScriptModule))
+        rank_ones_tensor = created_script_module()
+        self.assertEqual(torch.ones(self.rank), rank_ones_tensor)
+
+        # Construct ScriptModule with rpc.remote.
+        remote_script_module = rpc.remote(dst_name, MyScriptModule, args=(self.rank,))
+        # Verify it is an instance of ScriptModule on remote end.
+        remote_end_is_script = rpc.rpc_sync(
+            remote_script_module.owner(),
+            rref_isinstance,
+            args=(remote_script_module, torch.jit.ScriptModule),
+        )
+        self.assertTrue(remote_end_is_script)
+        # Run forward pass remotely.
+        remote_forward_output = remote_script_module.rpc_sync().forward()
+        self.assertEqual(remote_forward_output, torch.ones(self.rank))
+        # Run function defined on ScriptModule remotely.
+        remote_func_output = remote_script_module.rpc_sync().custom_func()
+        self.assertEqual(remote_func_output, torch.ones(self.rank))
+        # Ensure we can transfer ScriptModule RRef to this rank and run
+        # forward pass.
+        local_script_module = remote_script_module.to_here()
+        self.assertTrue(isinstance(local_script_module, torch.jit.ScriptModule))
+        rank_ones_tensor = local_script_module()
+        self.assertEqual(rank_ones_tensor, torch.ones(self.rank))
+        local_script_func_output = local_script_module.custom_func()
+        self.assertEqual(local_script_func_output, torch.ones(self.rank))
+
+    @dist_init
+    def test_load_script_module_with_pickled_rref(self):
+        dst_name = worker_name((self.rank + 1) % self.world_size)
+        m1 = MyScriptModuleWithRRefs(dst_name)
+        m2 = MyScriptModuleWithRRefs(dst_name)
+
+        f = io.BytesIO()
+
+        rpc._enable_jit_rref_pickle()
+        torch.jit.save(m1, f)
+        rpc._disable_jit_rref_pickle()
+
+        out1 = rpc.rpc_sync(
+            dst_name, load_script_module_with_pickled_rref, args=(f.getvalue(),)
+        )
+        out2 = m2()
+        self.assertEqual(out1, out2)
+
+    @dist_init
+    def test_rref_jit_pickle_not_supported(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        rref_var = rpc_return_rref(worker_name(dst_rank))
+        with TemporaryFileName() as fname:
+            with self.assertRaisesRegex(
+                RuntimeError, "RRef jit pickling is only allowed inside RPC calls"
+            ):
+                save_rref(rref_var, fname)
+
+    @dist_init
+    def test_remote_script_throw(self):
+        rref = rpc.remote(
+            worker_name((self.rank + 1) % self.world_size),
+            script_raise_func,
+            args=(torch.ones(2),),
+        )
+        with self.assertRaisesRegex(Exception, ".*Expected error.*"):
+            rref.to_here()
+
+    @dist_init
+    def test_remote_script_udf(self):
+        rref = rpc.remote(
+            worker_name((self.rank + 1) % self.world_size),
+            script_fork_wait_udf,
+            args=(torch.ones(2),),
+        )
+        self.assertEqual(rref.to_here(), torch.ones(2) * 2)
+
+    @dist_init
+    def test_async_script_udf(self):
+        future = rpc.rpc_async(
+            worker_name((self.rank + 1) % self.world_size),
+            script_fork_wait_udf,
+            args=(torch.ones(2),),
+        )
+        self.assertEqual(future.wait(), torch.ones(2) * 2)
+
+    @dist_init
+    def test_callback_simple(self):
+        def callback(fut):
+            return fut.wait() + 1
+
+        future = rpc.rpc_async(
+            worker_name((self.rank + 1) % self.world_size),
+            script_fork_wait_udf,
+            args=(torch.ones(2),),
+        ).then(callback)
+        self.assertEqual(future.wait(), torch.ones(2) * 2 + 1)
+
+    @dist_init
+    def test_callback_chain(self):
+        n = self.rank + 1
+
+        def callback(fut):
+            return fut.wait() + 1
+
+        fut = rpc.rpc_async(
+            worker_name(n % self.world_size), one_arg, args=(torch.ones(n, n),)
+        )
+
+        num_cbs = 20
+        for _ in range(num_cbs):
+            fut = fut.then(callback)
+
+        self.assertEqual(fut.wait(), torch.ones(n, n) + 1 + num_cbs)
+
+    @dist_init
+    def test_add_done_callback(self):
+        callback_called = None
+
+        def callback(fut):
+            nonlocal callback_called
+            callback_called = fut.wait() * 2
+
+        future = rpc.rpc_async(
+            worker_name((self.rank + 1) % self.world_size),
+            script_fork_wait_udf,
+            args=(torch.ones(2),),
+        )
+
+        future.add_done_callback(callback)
+        future_then = future.then(lambda _: True)
+
+        self.assertEqual(future.wait(), torch.ones(2) * 2)
+
+        # There is no guarantee that the add_done_callback fn executes before the
+        # test finishes, so add a 'then' callback that runs afterwards and wait on
+        # it to guarantee the first callback has run.
+        future_then.wait()
+        self.assertEqual(callback_called, torch.ones(2) * 4)
+
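The three tests above lean on the semantics of torch.futures.Future callbacks: `then()` returns a new Future carrying the callback's result, while `add_done_callback()` only runs for side effects. A standalone sketch (no RPC needed):

```python
from torch.futures import Future

fut = Future()  # Future[int]

chained = fut.then(lambda f: f.wait() + 1)          # new Future with the callback's result
fut.add_done_callback(lambda f: print(f.wait()))    # side effects only, returns nothing

fut.set_result(1)
assert chained.wait() == 2
```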
+    @dist_init
+    def test_async_script_throw(self):
+        future = rpc.rpc_async(
+            worker_name((self.rank + 1) % self.world_size),
+            script_fork_wait_throw,
+            args=(torch.ones(2),),
+        )
+        with self.assertRaisesRegex(Exception, ".*Expected error.*"):
+            future.wait()
+
+    @dist_init
+    def test_callback_with_exception(self):
+        def callback(fut):
+            with self.assertRaisesRegex(Exception, ".*Expected error.*"):
+                fut.wait()
+            raise RuntimeError("Another expected error")
+
+        future = rpc.rpc_async(
+            worker_name((self.rank + 1) % self.world_size),
+            script_fork_wait_throw,
+            args=(torch.ones(2),),
+        ).then(callback)
+
+        with self.assertRaisesRegex(RuntimeError, "Another expected error"):
+            future.wait()
+
+    @dist_init
+    def test_call_rpc_with_profiling(self):
+        # Ensures that we can call torch.ops.profiler._call_end_callbacks_on_jit_fut on a jit
+        # future from within a script function that calls rpc_async
+        if self.rank == 0:
+            with _profile() as prof:
+                prof_key = _build_rpc_profiling_key(
+                    RPCExecMode.ASYNC,
+                    torch._jit_internal._qualified_name(one_arg),
+                    "worker0",
+                    "worker1",
+                )
+                with torch.autograd.profiler.record_function(prof_key) as rf:
+                    call_rpc_with_profiling(rf.record, "worker1")
+            # TODO: Can't get a reliable time for this profiling event since
+            # it's hard to estimate the execution time on the remote end for non-UDFs.
+            # This can be resolved by https://github.com/pytorch/pytorch/issues/36272.
+            # After that, this test should be modified to validate the function time.
+            events = prof.function_events
+            function_event = get_function_event(events, prof_key)
+            self.assertTrue(
+                torch._jit_internal._qualified_name(one_arg) in function_event.name
+            )
+
+    @dist_init
+    def test_rpc_async_jit_profiled(self):
+        # Tests that rpc_async calls made from within a TorchScript function are
+        # profiled.
+        if self.rank == 0:
+            dst_rank = (self.rank + 1) % self.world_size
+            dst_worker_name = worker_name(dst_rank)
+            args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
+            kwargs = {}
+            with _profile() as prof:
+                script_rpc_async_call(dst_worker_name, args, kwargs)
+
+            # Ensure rpc_async call is profiled
+            function_events = prof.function_events
+            qual_name = torch._jit_internal._qualified_name(two_args_two_kwargs)
+            rpc_async_jit_event = [
+                event
+                for event in function_events
+                if qual_name in event.name and event.node_id == self.rank
+            ]
+            self.assertEqual(len(rpc_async_jit_event), 1)
+            rpc_async_jit_event = rpc_async_jit_event[0]
+            profiled_name = _build_rpc_profiling_key(
+                RPCExecMode.ASYNC_JIT,
+                qual_name,
+                worker_name(self.rank),
+                dst_worker_name,
+            )
+            self.assertEqual(profiled_name, rpc_async_jit_event.name)
+            remote_events = [event for event in function_events if event.is_remote]
+            # All remote events should have taken place on dst_rank
+            remote_event_node_ids = {
+                remote_event.node_id for remote_event in remote_events
+            }
+            self.assertEqual(remote_event_node_ids, {dst_rank})
+            # script_rpc_async_call invokes add operator
+            # so we should see this as a remote event.
+            remote_add = next(
+                remote_event
+                for remote_event in remote_events
+                if "aten::add" in remote_event.name
+            )
+            remote_add_profiled_name = f"{profiled_name}#remote_op: aten::add"
+            self.assertEqual(remote_add.name, remote_add_profiled_name)
+
+    @dist_init
+    def test_record_function_on_caller_rpc_async(self):
+        if self.rank == 0:
+            dst_rank = (self.rank + 1) % self.world_size
+            dst_worker_name = worker_name(dst_rank)
+            block_scope = "foo"
+            with _profile() as prof:
+                # Runs 2 rpc_async calls within JIT under record_function.
+                record_function_on_caller_rpc_async(dst_worker_name, block_scope)
+
+            # Ensure record_function event is profiled.
+            function_events = prof.function_events
+            record_function_scope_event = [
+                event for event in function_events if event.name == block_scope
+            ]
+            self.assertEqual(1, len(record_function_scope_event))
+            record_function_scope_event = record_function_scope_event[0]
+            # Ensure RPC future is profiled.
+            expected_key = _build_rpc_profiling_key(
+                RPCExecMode.ASYNC_JIT,
+                torch._jit_internal._qualified_name(script_add_ones),
+                worker_name(self.rank),
+                dst_worker_name,
+            )
+            jit_rpc_events = [
+                event for event in function_events if event.name == expected_key
+            ]
+            self.assertEqual(2, len(jit_rpc_events))
+            # Validate that the record_function scope time is greater than both
+            # of the individual RPC async call times. The reason it is not necessarily
+            # greater than the sum is because the two can execute in parallel.
+            for jit_rpc_event in jit_rpc_events:
+                self.assertTrue(
+                    record_function_scope_event.cpu_time_total
+                    > jit_rpc_event.cpu_time_total
+                )
+
+    @dist_init
+    def test_rpc_torchscript_record_function(self):
+        # tests that torchscript functions can be profiled using with
+        # record_function(...) over RPC.
+        REMOTE_OP_STR = "#remote_op: "
+        if self.rank == 0:
+            dst_rank = (self.rank + 1) % self.world_size
+            dst_worker_name = worker_name(dst_rank)
+            block_scope = "foo"
+            with _profile() as prof:
+                call_rpc_torchscript_with_record_function(dst_worker_name, block_scope)
+
+            # Need to call below to populate CPU children.
+            prof.key_averages()
+            function_events = prof.function_events
+            expected_key = (
+                _build_rpc_profiling_key(
+                    RPCExecMode.ASYNC_JIT,
+                    torch._jit_internal._qualified_name(
+                        script_add_ones_with_record_function
+                    ),
+                    worker_name(self.rank),
+                    dst_worker_name,
+                )
+                + REMOTE_OP_STR
+                + block_scope
+            )
+            remote_record_function_event = next(
+                evt for evt in function_events if evt.name == expected_key
+            )
+            self.assertTrue(block_scope in remote_record_function_event.name)
+            remote_children = remote_record_function_event.cpu_children
+            self.assertTrue(any("aten::add" in child.name for child in remote_children))
+
+    def test_record_function_jit_end_callbacks_with_fork(self):
+        # Ensures that we can call rf._call_end_callbacks_on_future on a jit
+        # future in python eager mode with torch.jit.fork
+        sleep_interval = 1
+        with _profile() as prof:
+            with torch.autograd.profiler.record_function("foo") as rf:
+                fut = torch.jit._fork(sleep, sleep_interval)
+                rf._call_end_callbacks_on_future(fut)
+            fut.wait()
+
+        function_events = prof.function_events
+        sleep_event = get_function_event(function_events, "foo")
+        self.assertEqual(sleep_event.name, "foo")
+        # Validate that callbacks were fired at the right time by checking the
+        # profiling event cpu time
+        self.assertGreaterAlmostEqual(sleep_event.cpu_time * 1e-6, sleep_interval)
+
+    def test_call_fork_in_jit_with_profiling(self):
+        # Ensures that we can call torch.ops.profiler._call_end_callbacks_on_jit_fut on a jit
+        # future from within a script function with torch.jit.fork
+        with _profile() as prof:
+            with torch.autograd.profiler.record_function("foo") as rf:
+                call_fork_with_profiling(rf.record)
+
+        events = prof.function_events
+        function_event = get_function_event(events, "foo")
+        self.assertEqual(function_event.name, "foo")
+
+    @dist_init
+    def test_async_function_simple(self):
+        dst1 = worker_name((self.rank + 1) % self.world_size)
+        dst2 = worker_name((self.rank + 2) % self.world_size)
+
+        ret = rpc.rpc_sync(
+            dst1, async_add, args=(dst2, torch.ones(2, 2), torch.ones(2, 2))
+        )
+        self.assertEqual(ret, torch.ones(2, 2) + 1)
+
+    @dist_init
+    def test_async_function_wrong_return_type(self):
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "Async functions must return an IValue of Future type, but got Tensor",
+        ):
+            rpc.rpc_sync(
+                worker_name((self.rank + 1) % self.world_size), async_wrong_type
+            )
+
+    @dist_init
+    def test_async_function_wrong_decorator_order(self):
+        # @torch.jit.script complains about undefined value rpc. Error is shown
+        # below. The reason for not checking error string is to avoid making
+        # JIT error handling code depend on RPC tests, as we don't have any
+        # restrictions on the error message here.
+        #
+        # RuntimeError:
+        # undefined value rpc:
+        # def async_wrong_decorator_order(to, x, y):
+        #    # type: (str, Tensor, Tensor) -> Future[Tensor]
+        #    return rpc.rpc_async(to, script_add, (x, y))
+        #           ~~~ <--- HERE
+        with self.assertRaises(RuntimeError):
+
+            @torch.jit.script
+            @rpc.functions.async_execution
+            def async_wrong_decorator_order(
+                to: str, x: Tensor, y: Tensor
+            ) -> Future[Tensor]:
+                return rpc.rpc_async(to, script_add, (x, y))
+
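For contrast with the failing order above, the working combination (also defined near the top of this file as `async_add`) places `@rpc.functions.async_execution` outermost, so scripting happens first. A brief usage sketch, assuming an initialized RPC group with workers "worker1" and "worker2":

```python
import torch
import torch.distributed.rpc as rpc
from torch import Tensor
from torch.futures import Future


@torch.jit.script
def script_add(x: Tensor, y: Tensor) -> Tensor:
    return x + y


@rpc.functions.async_execution  # outermost: applied to the already-scripted function
@torch.jit.script
def async_add(to: str, x: Tensor, y: Tensor) -> Future[Tensor]:
    return rpc.rpc_async(to, script_add, (x, y))


# "worker1" runs async_add, which forwards the add to "worker2"; the caller
# receives the final tensor once the inner future completes.
ret = rpc.rpc_sync(
    "worker1", async_add, args=("worker2", torch.ones(2, 2), torch.ones(2, 2))
)
```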
+    @dist_init
+    def test_async_function_remote(self):
+        dst1 = worker_name((self.rank + 1) % self.world_size)
+        dst2 = worker_name((self.rank + 2) % self.world_size)
+
+        rref = rpc.remote(
+            dst1, async_add, args=(dst2, torch.ones(2, 2), torch.ones(2, 2))
+        )
+        self.assertEqual(rref.to_here(), torch.ones(2, 2) + 1)
+
+    @dist_init
+    def test_async_function_remote_multi(self):
+        dst1 = worker_name((self.rank + 1) % self.world_size)
+        dst2 = worker_name((self.rank + 2) % self.world_size)
+
+        num = 20
+        rrefs = [
+            rpc.remote(
+                dst1, async_add, args=(dst2, torch.ones(2, 2), torch.ones(2, 2) * i)
+            )
+            for i in range(num)
+        ]
+
+        for i in range(num):
+            self.assertEqual(rrefs[i].to_here(), torch.ones(2, 2) + i)
+
+    @dist_init
+    def test_async_function_wrong_return_type_remote(self):
+        rref = rpc.remote(
+            worker_name((self.rank + 1) % self.world_size), async_wrong_type
+        )
+
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "Async functions must return an IValue of Future type, but got Tensor",
+        ):
+            rref.to_here()
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/rpc_test_faulty.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/rpc_test_faulty.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bedaad32d0e904a9a7523f31eced9cef96e832d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/jit/rpc_test_faulty.py
@@ -0,0 +1,219 @@
+# mypy: allow-untyped-defs
+
+
+import torch
+import torch.distributed.rpc as rpc
+from torch import Tensor
+from torch.distributed.rpc import RRef
+from torch.testing._internal.dist_utils import (
+    dist_init,
+    wait_until_pending_futures_and_users_flushed,
+    worker_name,
+)
+from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
+    RpcAgentTestFixture,
+)
+
+
+@torch.jit.script
+def two_args_two_kwargs(
+    first_arg,
+    second_arg,
+    first_kwarg=torch.tensor([3, 3]),
+    second_kwarg=torch.tensor([4, 4]),
+):
+    return first_arg + second_arg + first_kwarg + second_kwarg
+
+
+@torch.jit.script
+def script_rpc_async_call(
+    dst_worker_name: str, args: tuple[Tensor, Tensor], kwargs: dict[str, Tensor]
+):
+    fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
+    ret = fut.wait()
+    return ret
+
+
+@torch.jit.script
+def rpc_async_call_with_timeout(
+    dst_worker_name: str,
+    args: tuple[Tensor, Tensor],
+    kwargs: dict[str, Tensor],
+    timeout: float,
+):
+    fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs, timeout)
+    ret = fut.wait()
+    return ret
+
+
+@torch.jit.script
+def rpc_async_call_with_timeout_future_ret(
+    dst_worker_name: str,
+    args: tuple[Tensor, Tensor],
+    kwargs: dict[str, Tensor],
+    timeout: float,
+):
+    fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs, timeout)
+    return fut
+
+
+@torch.jit.script
+def rpc_async_call_future_ret(
+    dst_worker_name: str, args: tuple[Tensor, Tensor], kwargs: dict[str, Tensor]
+):
+    fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
+    return fut
+
+
+@torch.jit.script
+def rref_to_here(rref_var: RRef[Tensor]) -> Tensor:
+    return rref_var.to_here()
+
+
+@torch.jit.script
+def rref_to_here_with_timeout(rref_var: RRef[Tensor], timeout: float) -> Tensor:
+    return rref_var.to_here(timeout)
+
+
+@torch.jit.script
+def rpc_async_with_rref_arg(dst_worker_name: str, args: tuple[RRef[Tensor]]) -> Tensor:
+    fut = rpc.rpc_async(dst_worker_name, rref_to_here, args)
+    ret = fut.wait()
+    return ret
+
+
+class JitFaultyAgentRpcTest(RpcAgentTestFixture):
+    """
+    Run tests for rpc_async in JIT under the faulty agent test fixture to test
+    arbitrary timeouts.
+    """
+
+    @dist_init(faulty_messages=[], messages_to_delay={"SCRIPT_CALL": 1.5})
+    def test_timeout_in_torchscript_function(self):
+        # Call rpc_async + fut.wait() in a torchscript function and ensure that
+        # a timeout is raised.
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+
+        args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
+        kwargs = {
+            "first_kwarg": torch.tensor([2, 2]),
+            "second_kwarg": torch.tensor([3, 3]),
+        }
+        expected_error = self.get_timeout_error_regex()
+        # Ensure that we get a timeout if we override the default timeout and
+        # the RPC takes longer than that to execute.
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            rpc_async_call_with_timeout(dst_worker_name, args, kwargs, 0.5)
+
+        # Ensure that we time out if we don't specify a timeout but the default
+        # timeout is shorter than the time the RPC takes to execute.
+        rpc._set_rpc_timeout(0.001)
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            script_rpc_async_call(dst_worker_name, args, kwargs)
+
+        # Ensure that we run to completion if zero timeout is specified.
+        ret = rpc_async_call_with_timeout(dst_worker_name, args, kwargs, 0)
+        self.assertEqual(ret, torch.tensor([8, 8]))
+        # reset for clean shutdown
+        rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)
+
+    @dist_init(faulty_messages=[], messages_to_delay={"SCRIPT_CALL": 1.5})
+    def test_timeout_in_python(self):
+        # Ensures timeouts are raised if we call rpc_async from within a
+        # torchscript function, but wait on the future in python.
+        if self.rank != 0:
+            return
+
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+        args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
+        kwargs = {
+            "first_kwarg": torch.tensor([2, 2]),
+            "second_kwarg": torch.tensor([3, 3]),
+        }
+        expected_error = self.get_timeout_error_regex()
+
+        fut = rpc_async_call_with_timeout_future_ret(dst_worker_name, args, kwargs, 0.5)
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            fut.wait()
+
+        # Ensure that we time out if we don't specify a timeout but the default
+        # timeout is shorter than the time the RPC takes to execute.
+        rpc._set_rpc_timeout(0.001)
+        fut = rpc_async_call_future_ret(dst_worker_name, args, kwargs)
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            fut.wait()
+
+        # Ensure that we run to completion if zero timeout is specified
+        fut = rpc_async_call_with_timeout_future_ret(dst_worker_name, args, kwargs, 0)
+        result = fut.wait()
+        self.assertEqual(result, torch.tensor([8, 8]))
+        # reset for clean shutdown
+        rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)
+
+    @dist_init(faulty_messages=["SCRIPT_REMOTE_CALL"])
+    def test_remote_timeout_to_here_in_jit(self):
+        # Test that calling to_here() in JIT will raise timeout error if
+        # rpc.remote failed.
+        if self.rank != 0:
+            return
+        dst_rank = (self.rank + 1) % self.world_size
+        dst_worker = f"worker{dst_rank}"
+        rref = rpc.remote(
+            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
+        )
+        # Will ensure error handling callbacks are run.
+        wait_until_pending_futures_and_users_flushed()
+        # Call to_here() within a ScriptFunction and ensure it raises the expected error
+        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
+            rref_to_here(rref)
+
+    @dist_init(faulty_messages=[], messages_to_delay={"SCRIPT_RREF_FETCH_CALL": 1})
+    def test_rref_to_here_timeout_in_jit(self):
+        if self.rank != 0:
+            return
+
+        dst_rank = (self.rank + 1) % self.world_size
+        dst_worker = f"worker{dst_rank}"
+        rref = rpc.remote(
+            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
+        )
+        expected_error = self.get_timeout_error_regex()
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            rref_to_here_with_timeout(rref, 0.01)
+
+        rref_to_here_with_timeout(rref, 100)
+
+    @dist_init(faulty_messages=["SCRIPT_REMOTE_CALL"])
+    def test_rref_timeout_pickle_in_jit(self):
+        if self.rank != 0:
+            return
+        dst_rank = (self.rank + 1) % self.world_size
+        dst_worker = f"worker{dst_rank}"
+        rref = rpc.remote(
+            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
+        )
+        # Will ensure error handling callbacks are run.
+        wait_until_pending_futures_and_users_flushed()
+        # Call RPC with an RRef arg in JIT, which will go through JIT pickling, and
+        # ensure the error is raised.
+        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
+            rpc_async_with_rref_arg(dst_worker, (rref,))
+
+    @dist_init(faulty_messages=["SCRIPT_REMOTE_CALL"])
+    def test_rref_timeout_pickle_script_func(self):
+        # Similar to the above test, but calls Python RPC with a script function.
+        if self.rank != 0:
+            return
+        dst_rank = (self.rank + 1) % self.world_size
+        dst_worker = f"worker{dst_rank}"
+        rref = rpc.remote(
+            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
+        )
+        # Will ensure error handling callbacks are run.
+        wait_until_pending_futures_and_users_flushed()
+        # Call RPC with a script function that takes an RRef, and ensure the error is raised during pickling
+        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
+            rpc.rpc_sync(dst_worker, rref_to_here, args=(rref,))
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/rpc_agent_test_fixture.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/rpc_agent_test_fixture.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a684b73d2f315a00465371fad3050a795251ddb
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/rpc_agent_test_fixture.py
@@ -0,0 +1,63 @@
+# mypy: allow-untyped-defs
+
+import os
+from abc import ABC, abstractmethod
+
+import torch.testing._internal.dist_utils
+
+
+class RpcAgentTestFixture(ABC):
+    @property
+    def world_size(self) -> int:
+        return 4
+
+    @property
+    def init_method(self):
+        use_tcp_init = os.environ.get("RPC_INIT_WITH_TCP", None)
+        if use_tcp_init == "1":
+            master_addr = os.environ["MASTER_ADDR"]
+            master_port = os.environ["MASTER_PORT"]
+            return f"tcp://{master_addr}:{master_port}"
+        else:
+            return self.file_init_method
+
+    @property
+    def file_init_method(self):
+        return torch.testing._internal.dist_utils.INIT_METHOD_TEMPLATE.format(
+            file_name=self.file_name
+        )
+
+    @property
+    @abstractmethod
+    def rpc_backend(self):
+        pass
+
+    @property
+    @abstractmethod
+    def rpc_backend_options(self):
+        pass
+
+    def setup_fault_injection(self, faulty_messages, messages_to_delay):  # noqa: B027
+        """Method used by dist_init to prepare the faulty agent.
+
+        Does nothing for other agents.
+        """
+
+    # Shutdown sequence is not well defined, so we may see any of the following
+    # errors when running tests that simulate errors via a shutdown on the
+    # remote end.
+    @abstractmethod
+    def get_shutdown_error_regex(self):
+        """
+        Return the various error messages we may see from RPC agents while running
+        tests that check for failures. This function is used to match against
+        possible errors to ensure failures were raised properly.
+        """
+
+    @abstractmethod
+    def get_timeout_error_regex(self):
+        """
+        Return a partial string indicating the error we should receive when an
+        RPC has timed out. Useful with assertRaisesRegex() to ensure we match
+        the right errors on timeout.
+        """
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/rpc_test.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/rpc_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..5149c66810f8d5b89746f4d40d4683ab51ef5e6b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/rpc_test.py
@@ -0,0 +1,6318 @@
+# mypy: allow-untyped-defs
+
+import concurrent.futures
+import contextlib
+import json
+import operator
+import os
+import sys
+import threading
+import time
+from collections import namedtuple
+from functools import partial
+from threading import Event, Lock
+from unittest import mock
+
+import torch
+import torch.distributed as dist
+import torch.distributed.autograd as dist_autograd
+import torch.distributed.rpc as rpc
+import torch.nn as nn
+from torch.autograd.profiler_legacy import profile as _profile
+from torch.distributed.rpc import (
+    _get_debug_info,
+    _rref_context_get_debug_info,
+    RRef,
+    WorkerInfo,
+)
+from torch.distributed.rpc.api import _thread_local_var, _use_rpc_pickler, _wait_all
+from torch.distributed.rpc.internal import (
+    _build_rpc_profiling_key,
+    _internal_rpc_pickler,
+    PythonUDF,
+    RPCExecMode,
+)
+from torch.futures import Future
+from torch.testing._internal.common_distributed import (
+    captured_output,
+    skip_if_lt_x_gpu,
+    tp_transports,
+)
+from torch.testing._internal.common_utils import (
+    get_cycles_per_ms,
+    IS_MACOS,
+    load_tests,
+    skip_but_pass_in_sandcastle_if,
+    TemporaryFileName,
+)
+from torch.testing._internal.dist_utils import (
+    dist_init,
+    get_function_event,
+    initialize_pg,
+    wait_until_node_failure,
+    wait_until_owners_and_forks_on_rank,
+    wait_until_pending_futures_and_users_flushed,
+    worker_name,
+)
+from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
+    RpcAgentTestFixture,
+)
+
+
+def foo_add():
+    return torch.add(torch.ones(1), torch.ones(1))
+
+
+def udf_with_torch_ops(device=-1, use_record_function=False):
+    device_ctx = contextlib.nullcontext() if device == -1 else torch.cuda.device(device)
+    record_function_ctx = (
+        torch.autograd.profiler.record_function("##forward##")
+        if use_record_function
+        else contextlib.nullcontext()
+    )
+    with device_ctx, record_function_ctx:
+        t1, t2 = torch.ones(1), torch.ones(1)
+        t = torch.add(t1, t2)
+        t = torch.mul(t, t)
+        t = t.relu()
+        t = t.sigmoid()
+
+
+# Events (operator invocations) that are expected to run as part of the above
+# function.
+EXPECTED_REMOTE_EVENTS = [
+    "aten::ones",
+    "aten::ones",
+    "aten::add",
+    "aten::mul",
+    "aten::relu",
+    "aten::clamp_min",
+    "aten::sigmoid",
+]
+
+# Remote operations are prefixed with the following string for RPC profiling.
+REMOTE_OP_STR = "#remote_op: "
+
+
+VALUE_FUTURE = concurrent.futures.Future()
+DONE_FUTURE = concurrent.futures.Future()
+
+FIFTY_MIL_CYCLES = 50000000
+
+_rpc_barrier_count = 0
+
+
+def _increment_count():
+    global _rpc_barrier_count
+    _rpc_barrier_count += 1
+
+
+def _reset_count():
+    global _rpc_barrier_count
+    _rpc_barrier_count = 0
+
+
+class StubRpcAgent:
+    def __init__(self, world_size):
+        self.world_size = world_size
+
+    def get_worker_infos(self):
+        return {
+            WorkerInfo(name=worker_name(rank), id=rank)
+            for rank in range(self.world_size)
+        }
+
+
+def _stub_construct_rpc_backend_options_handler(**kwargs):
+    return mock.Mock()  # RpcBackendOptions.
+
+
+def _stub_init_rpc_backend_handler(store, name, rank, world_size, rpc_backend_options):
+    return StubRpcAgent(world_size=world_size)
+
+
+def set_value(value):
+    VALUE_FUTURE.set_result(value)
+
+
+def wait_for_value_future():
+    return VALUE_FUTURE.result()
+
+
+def set_and_check_done(value):
+    VALUE_FUTURE.set_result(value)
+    return DONE_FUTURE.result()
+
+
+# Used to test Python user-defined functions over RPC. The classes and
+# functions below are used to test Python user-defined classes and methods
+# over RPC.
+TensorClass = namedtuple("TensorClass", ["tensors"])
+
+
+class MyPickleClass:
+    def __init__(self) -> None:
+        self.t = None
+
+    def __getstate__(self):
+        (pickled_python_udf, tensors) = _internal_rpc_pickler.serialize(
+            PythonUDF(my_tensor_function, (torch.ones(2, 2), torch.ones(2, 2)), None)
+        )
+        return (pickled_python_udf, tensors)
+
+    def __setstate__(self, obj):
+        python_udf = _internal_rpc_pickler.deserialize(obj[0], obj[1])
+        result = python_udf.func(python_udf.args[0], python_udf.args[1])
+        self.t = result
+
+    def set(self, val):
+        self.t = val
+
+
+class SlowPickleClass:
+    def __init__(self, t):
+        self.t = t
+
+    def __getstate__(self):
+        time.sleep(self.t)
+        return (self.t,)
+
+    def __setstate__(self, obj):
+        self.t = obj[0]
+        time.sleep(self.t)
+
+
+class MyClass:
+    def __init__(self, a, delay=False):
+        self.a = a
+        # delay initialization to simulate errors if specified
+        if delay:
+            time.sleep(2)
+
+    def my_instance_method(self, b):
+        return self.a + b
+
+    @classmethod
+    def my_class_method(cls, d, e):
+        return d + e
+
+    @staticmethod
+    def my_static_method(f):
+        return f > 10
+
+    def increment_value(self, increment):
+        self.a += increment
+
+    def get_value(self):
+        return self.a
+
+    def my_slow_method(self, my_tensor_arg):
+        time.sleep(5)
+        return torch.add(self.a, my_tensor_arg)
+
+
+def _call_method_on_rref(method, rref, *args, **kwargs):
+    return method(rref.local_value(), *args, **kwargs)
+
+
+def get_rref_list(values):
+    return [RRef(MyClass(a)) for a in values]
+
+
+def add_rref_to_value(rref, value):
+    return rref.to_here() + value
+
+
+def run_nested_pickle(pickle_cls_instance, tensor):
+    return pickle_cls_instance.t + tensor
+
+
+def build_sparse_tensor(coalesce=False):
+    i = [[0, 1, 1], [2, 0, 2]]
+    v = [3, 4, 5]
+    tensor = torch.sparse_coo_tensor(i, v, (2, 3))
+    if coalesce:
+        tensor = tensor.coalesce()
+    return tensor
+
+
+def build_complex_tensors():
+    a = torch.ones(3, 3)
+    b = [a, a]
+    c = [b, b]
+    d = [a, b]
+    e = {a: d}
+    return [a, b, c, d, e]
+
+
+def non_cont_test(t_view, t_cont):
+    if t_view.is_contiguous():
+        raise Exception("t_view is contiguous!")  # noqa: TRY002
+    if not t_cont.is_contiguous():
+        raise Exception("t_cont is not contiguous!")  # noqa: TRY002
+    if not torch.equal(t_view, t_cont):
+        raise Exception("t_view is not equal to t_cont!")  # noqa: TRY002
+    return t_view
+
+
+def my_function(a, b, c):
+    return a + b + c
+
+
+def my_tensor_function(a, b):
+    return a + b
+
+
+def my_container_sum(a):
+    result = a[0]
+    for tensor in a[1:]:
+        result += tensor
+    return result
+
+
+def my_sleep_func(seconds=1):
+    time.sleep(seconds)
+    return torch.mul(torch.tensor(1), torch.tensor(1))
+
+
+def my_complex_tensor_function(list_input, tensor_class_input, dict_input):
+    res = list_input[0]
+    for t in list_input:
+        res += t
+    for v in dict_input.values():
+        res += v
+    complex_tensors = tensor_class_input.tensors
+    return (res, complex_tensors[0], complex_tensors[1], complex_tensors[2])
+
+
+def my_rref_function(rref_a, rref_b):
+    return rref_a.to_here() + rref_b.to_here()
+
+
+def delayed_add(a, b, seconds=0.05):
+    time.sleep(seconds)
+    return a + b
+
+
+def identity(a):
+    return a
+
+
+def no_result():
+    print("do nothing")
+
+
+def raise_or_inc(value):
+    if value.numel() == 2:
+        raise ValueError("Expected error")
+    return value + 1
+
+
+def nested_rpc(dst):
+    return rpc.rpc_sync(dst, torch.add, args=(torch.ones(2, 2), 1))
+
+
+def nested_rpc_sparse(dst):
+    return rpc.rpc_sync(
+        dst, torch.add, args=(build_sparse_tensor(), build_sparse_tensor())
+    )
+
+
+def multi_layer_nested_async_rpc(dst, world_size, ttl):
+    # this method returns immediately without blocking the callee, but will
+    # generate additional requests.
+    if ttl > 0:
+        current_dst = worker_name(dst)
+        next_dst = (dst + 1) % world_size
+        rpc.rpc_async(
+            current_dst,
+            multi_layer_nested_async_rpc,
+            args=(next_dst, world_size, ttl - 1),
+        )
+        return 0
+
+
+def nested_rref(dst):
+    return (
+        rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 1)),
+        rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 2)),
+    )
+
+
+def nested_rref_sparse(dst):
+    return (
+        rpc.remote(dst, torch.add, args=(build_sparse_tensor(), build_sparse_tensor())),
+        rpc.remote(dst, torch.add, args=(build_sparse_tensor(), build_sparse_tensor())),
+    )
+
+
+def nested_remote(dst):
+    rref = rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 3))
+    return rref.to_here()
+
+
+def nested_remote_sparse(dst):
+    rref = rpc.remote(
+        dst, torch.add, args=(build_sparse_tensor(), build_sparse_tensor())
+    )
+    return rref.to_here()
+
+
+def rref_forward_chain(dst, world_size, rref, ttl):
+    if ttl > 0:
+        current_dst = worker_name(dst)
+        next_dst = (dst + 1) % world_size
+        ret_rref = rpc.remote(
+            current_dst, rref_forward_chain, args=(next_dst, world_size, rref, ttl - 1)
+        )
+        return [ret_rref]
+    else:
+        return rref.to_here()
+
+
+def rpc_return_rref(dst):
+    return rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 1))
+
+
+def light_rpc():
+    return 0
+
+
+def heavy_rpc(tensor):
+    for i in range(1, 100):
+        tensor *= i
+        tensor /= i + 1
+    return 0
+
+
+def heavy_rpc_sparse(tensor):
+    for i in range(1, 100):
+        tensor *= i
+        tensor = tensor / (i + 1)
+    return 0
+
+
+@torch.jit.script
+def heavy_rpc_torchscript(tensor):
+    for i in range(1, 100):
+        tensor *= i
+        tensor /= i + 1
+    return 0
+
+
+@torch.jit.script
+def my_script_func(tensor):
+    return torch.add(tensor, tensor)
+
+
+expected_err = "Expected error"
+
+
+# Note that it needs to inherit from Exception, not BaseException. See comment
+# in rpc/internal.py
+class CustomException(Exception):
+    def __init__(self, bool, msg):
+        self.bool = bool
+        super().__init__(msg)
+
+
+def raise_func():
+    raise ValueError(expected_err)
+
+
+def custom_raise_func():
+    raise CustomException(True, "foo")
+
+
+@torch.jit.script
+def raise_func_script(expected_err: str) -> torch.Tensor:
+    raise ValueError(expected_err)
+
+
+expected_err_escape = (
+    "\nFirst line of error \n next line of error \n last line of error"
+)
+
+
+def raise_func_escape():
+    raise ValueError(expected_err_escape)
+
+
+global_rref = None
+
+
+def set_global_rref(rref):
+    global global_rref
+    global_rref = rref
+
+
+def clear_global_rref():
+    global global_rref
+    global_rref = None
+
+
+def check_rref_confirmed(rref):
+    return rref.confirmed_by_owner()
+
+
+def get_rref_debug_info():
+    return _rref_context_get_debug_info()
+
+
+def add_use_future_cb(to, x, y, z):
+    out = concurrent.futures.Future()
+
+    def callback(fut):
+        out.set_result(fut.wait() + z)
+
+    fut = rpc.rpc_async(to, torch.add, args=(x, y))
+    fut.then(callback)
+    return out.result()
+
+
+def get_events_from_profile(profile_rref):
+    return profile_rref.local_value().process_global_function_events
+
+
+def add_use_future_set_result(to, x, y, z):
+    out = torch.futures.Future()
+    fut = rpc.rpc_async(to, torch.add, args=(x, y))
+    fut.then(lambda fut: out.set_result(fut.wait() + z))
+    return out.wait()
+
+
+def add_use_future_nested_cb(to, x, y, z):
+    out = torch.futures.Future()
+
+    def callback(fut1):
+        fut2 = rpc.rpc_async(to, torch.add, args=(fut1.wait(), z))
+        fut2.then(lambda fut2: out.set_result(fut2.wait()))
+
+    fut1 = rpc.rpc_async(to, torch.add, args=(x, y))
+    fut1.then(callback)
+    return out.wait()
+
+
+def fail_on_fut(fut):
+    pass
+
+
+@rpc.functions.async_execution
+def async_raise_func():
+    raise RuntimeError("Expected error")
+
+
+@rpc.functions.async_execution
+def async_wrong_type():
+    return torch.zeros(2, 2)
+
+
+@rpc.functions.async_execution
+def async_add(to, x, y):
+    return rpc.rpc_async(to, torch.add, args=(x, y))
+
+
+def slow_add(x, y, device="cpu"):
+    time.sleep(1)
+    x = x.to(device)
+    y = y.to(device)
+    return torch.add(x, y).cpu()
+
+
+@rpc.functions.async_execution
+def slow_async_add(to, x, y, device="cpu"):
+    return rpc.rpc_async(to, slow_add, args=(x, y, device))
+
+
+@rpc.functions.async_execution
+def async_add_with_future_ctor(to, x, y, z):
+    fut = torch.futures.Future()
+    rpc.rpc_async(to, torch.add, args=(x, y)).then(
+        lambda fut1: fut.set_result(fut1.wait() + z)
+    )
+    return fut
+
+
+@rpc.functions.async_execution
+def async_add_chained(to, x, y, z):
+    return rpc.rpc_async(to, torch.add, args=(x, y)).then(lambda fut: fut.wait() + z)
+
+
+@rpc.functions.async_execution
+def async_add_chained_multi(to, x, num, step):
+    fut = rpc.rpc_async(to, torch.add, args=(x, 0))
+    for _ in range(num):
+        fut = fut.then(lambda fut: fut.wait() + step)
+    return fut
+
+
+@rpc.functions.async_execution
+def async_add_nested(to, x, y, z):
+    return rpc.rpc_async(to, async_add, args=(to, x, y)).then(
+        lambda fut: fut.wait() + z
+    )
+
+
+@rpc.functions.async_execution
+def async_add_multi_fanout(to, x, num, step):
+    futs = []
+    for i in range(num):
+        if i == 0:
+            futs.append(rpc.rpc_async(to, torch.add, args=(x, step)))
+        else:
+            futs.append(rpc.rpc_async(to, torch.add, args=(0, step)))
+
+    # TODO: use torch.futures.collect_all
+    lock = Lock()
+    state = {"cnt": 0, "ret": torch.zeros_like(x)}
+    ret_future = torch.futures.Future()
+
+    def inc_and_set(fut):
+        with lock:
+            state["cnt"] += 1
+            state["ret"] += fut.wait()
+            if state["cnt"] >= len(futs):
+                ret_future.set_result(state["ret"])
+
+    for fut in futs:
+        fut.then(inc_and_set)
+
+    return ret_future
+
+
+@rpc.functions.async_execution
+def async_cuda_sleep_and_set_to_one(t):
+    device = t.device
+    original_stream = torch.cuda.current_stream(device)
+    new_stream = torch.cuda.Stream(device)
+    new_stream.wait_stream(original_stream)
+    with torch.cuda.stream(new_stream):
+        torch.cuda._sleep(int(1000 * get_cycles_per_ms()))
+        t.fill_(1)
+        fut = Future(devices=[device])
+        fut.set_result(t)
+        return fut
+
+
+@rpc.functions.async_execution
+def async_cuda_nested_add(to, x, y, z):
+    def cb(fut):
+        torch.cuda._sleep(int(1000 * get_cycles_per_ms()))
+        return fut.value() + z
+
+    return rpc.rpc_async(to, torch.add, args=(x, y)).then(cb)
+
+
+# A custom Python class that contains a tensor, needed to see if we correctly
+# use the Python pickler to extract tensors from non-IValue-convertible types.
+class TensorWrapper:
+    __slots__ = ("tensor", "lock", "event", "thread")
+
+    def __init__(self, t):
+        self.tensor = t
+        # Add one non-picklable field, to ensure it's ignored/skipped.
+        self.lock = Lock()
+        self.event = torch.cuda.Event(enable_timing=True)
+        self.thread = threading.Thread()
+        self.thread.start()
+
+    def increase(self, v):
+        with self.lock:
+            self.tensor += v
+
+    def sum(self):
+        with self.lock:
+            self.event.record()
+            return self.tensor.sum()
+
+
+class AsyncExecutionClass:
+    @staticmethod
+    @rpc.functions.async_execution
+    def static_async_add(to, x, y, z):
+        return rpc.rpc_async(to, torch.add, args=(x, y)).then(
+            lambda fut: fut.wait() + z
+        )
+
+    @classmethod
+    @rpc.functions.async_execution
+    def class_async_add(cls, to, x, y, z):
+        ret_fut = torch.futures.Future()
+        rpc.rpc_async(to, torch.add, args=(x, y)).then(
+            lambda fut: ret_fut.set_result(fut.wait() + z)
+        )
+        return ret_fut
+
+    @rpc.functions.async_execution
+    def bound_async_add(self, to, x, y, z):
+        return rpc.rpc_async(to, torch.add, args=(x, y)).then(
+            lambda fut: fut.wait() + z
+        )
+
+
+def return_future():
+    return torch.futures.Future()
+
+
+class FooBackendOptions(rpc.RpcBackendOptions):
+    def __init__(self, init_method):
+        # Must call the __init__ of the superclass (and do so directly,
+        # without using super()) because... pybind.
+        rpc.RpcBackendOptions.__init__(self)
+        self.init_method = init_method
+
+
+# load_tests from common_utils is used to automatically filter tests for
+# sharding on sandcastle. This line silences flake warnings
+load_tests = load_tests
+
+
+class MyEmbeddingBagModel(torch.nn.Module):
+    def __init__(self, sparse):
+        super().__init__()
+        self.eb = torch.nn.EmbeddingBag(10, 10, sparse=sparse)
+
+    def forward(self, x):
+        return self.eb(x)
+
+
+class MyParameterServer:
+    def __init__(self, trainers):
+        self.lock = Lock()
+        self.trainers = trainers
+        self.iteration = 0
+        self.updates = 0
+        self.futures = []
+        self.total = None
+        self.gradient = None
+
+    @staticmethod
+    def get_gradient(rref):
+        return rref.local_value().gradient
+
+    @staticmethod
+    @rpc.functions.async_execution
+    def average(rref, riteration, tensor):
+        self = rref.local_value()
+        fut = torch.futures.Future()
+        with self.lock:
+            if riteration > self.iteration:
+                self.iteration = riteration
+                self.updates = 0
+                self.futures.clear()
+            self.futures.append(fut)
+            if self.total is None:
+                self.total = tensor
+            else:
+                self.total += tensor
+            self.updates += 1
+            if self.trainers == self.updates:
+                self.gradient = self.total / float(self.trainers)
+                for fut in self.futures:
+                    result = self.total / float(self.trainers)
+                    fut.set_result(result)
+        return fut
+
+
+class MyConvNetForMNIST(nn.Module):
+    def __init__(self, device):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Conv2d(1, 16, 3, 1),
+            nn.ReLU(),
+            nn.Conv2d(16, 32, 3, 1),
+            nn.ReLU(),
+            nn.MaxPool2d(2),
+            nn.Flatten(1),
+            nn.Linear(4608, 128),
+            nn.ReLU(),
+            nn.Linear(128, 10),
+        ).to(device)
+        self.device = device
+
+    def forward(self, x, is_rref=False):
+        x = x.to_here() if is_rref else x
+        with torch.cuda.stream(torch.cuda.current_stream(self.device)):
+            # intentionally adding delay to current CUDA stream
+            torch.cuda._sleep(10 * FIFTY_MIL_CYCLES)
+            return self.net(x)
+
+    def __getstate__(self):
+        # return an empty dict to avoid inspecting the model contents on the
+        # owner
+        return {}
+
+
+class RpcTestCommon:
+    def _run_func_in_mode(self, to, fn, mode, args=None, kwargs=None):
+        if mode == RPCExecMode.SYNC:
+            return rpc.rpc_sync(to, fn, args=args, kwargs=kwargs)
+        elif mode == RPCExecMode.ASYNC:
+            return rpc.rpc_async(to, fn, args=args, kwargs=kwargs).wait()
+        elif mode == RPCExecMode.REMOTE:
+            return rpc.remote(to, fn, args=args, kwargs=kwargs).to_here()
+
+    def _self_py_udf_remote(self, worker_info, x, y, z):
+        rref = rpc.remote(worker_info, my_function, args=(x, y, z))
+        self.assertEqual(rref.to_here(), x + y + z)
+
+    def _self_remote_rref_as_rpc_arg(self, dst, x, y, z):
+        self_worker_info = rpc.get_worker_info()
+        rref = rpc.remote(self_worker_info, my_function, args=(x, y, z))
+        fut = rpc.rpc_async(dst, add_rref_to_value, args=(rref, x))
+        ret = rpc.rpc_sync(dst, add_rref_to_value, args=(rref, x + y))
+        self.assertEqual(ret, x + y + z + x + y)
+        self.assertEqual(fut.wait(), x + y + z + x)
+
+    def _self_remote_rref_as_remote_arg(self, dst, x, y, z):
+        self_worker_info = rpc.get_worker_info()
+        rref = rpc.remote(self_worker_info, my_function, args=(x, y, z))
+        ret_rref = rpc.remote(dst, add_rref_to_value, args=(rref, x))
+        self.assertEqual(ret_rref.to_here(), x + y + z + x)
+
+    def _world_size_one(self, a, b):
+        if self.rank == 0:
+            rpc.init_rpc(
+                name="me",
+                backend=self.rpc_backend,
+                rank=0,
+                world_size=1,
+                rpc_backend_options=self.rpc_backend_options,
+            )
+
+            def _rpc_sync(x, y):
+                expect = x * 2
+                result = rpc.rpc_sync("me", my_tensor_function, args=(x, y))
+                self.assertEqual(expect, result)
+
+            def _rpc_async(x, y):
+                expect = x * 2
+                result = rpc.rpc_async("me", my_tensor_function, args=(x, y)).wait()
+                self.assertEqual(expect, result)
+
+            def _remote(x, y):
+                expect = x * 2
+                result = rpc.remote("me", my_tensor_function, args=(x, y)).to_here()
+                self.assertEqual(expect, result)
+
+            _rpc_sync(a, b)
+            _rpc_async(a, b)
+            _remote(a, b)
+
+            rpc.shutdown()
+
+    def _multi_rpc(self, sparse):
+        dst_rank = (self.rank + 1) % self.world_size
+        for i in range(20):
+            n = i + self.rank + 1
+            if sparse:
+                x = build_sparse_tensor() * n
+                y = build_sparse_tensor() * n
+            else:
+                x = torch.ones(2, 2)
+                y = torch.ones(2, 2)
+            ret = rpc.rpc_sync(
+                worker_name(dst_rank),
+                torch.add,
+                args=(x, y),
+            )
+            self.assertEqual(ret, x * 2)
+
+    def _run_uneven_workload(self, f, x, num_repeat=30):
+        # worker0 drives and waits for worker1 and worker2
+        # throughout the test.
+        if self.rank == 0:
+            self.assertTrue(self.world_size >= 3)
+
+            # Phase 1: Only worker1 has workload.
+            dst = "worker1"
+            futs = []
+            for _ in range(num_repeat):
+                fut = rpc.rpc_async(dst, f, args=(x,))
+                futs.append(fut)
+
+            for fut in torch.futures.collect_all(futs).wait():
+                self.assertEqual(fut.wait(), 0)
+
+            # Phase 2: Only worker2 has workload.
+            # If join is not correctly implemented,
+            # worker2 should be closed by now.
+            dst = "worker2"
+            futs = []
+            for _ in range(num_repeat):
+                fut = rpc.rpc_async(dst, f, args=(x,))
+                futs.append(fut)
+
+            for val in torch.futures.wait_all(futs):
+                self.assertEqual(val, 0)
+
+    def _wait_all_workers(self, f, x):
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        rpc.init_rpc(
+            name=f"worker{self.rank:d}",
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+
+        self._run_uneven_workload(f, x)
+
+        # worker0 calls this at the end after waiting for RPC responses.
+        # worker1/2 call this immediately and still have some work after it.
+        # worker3 calls this immediately and has no more work.
+        rpc.api._wait_all_workers()
+
+        # Wait before proceeding to shutdown to ensure worker0 RPCs make
+        # it through to other workers.
+        dist.barrier()
+        rpc.shutdown(graceful=False)
+
+    def _wait_all_workers_twice(self, f, x):
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        rpc.init_rpc(
+            name=f"worker{self.rank:d}",
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+
+        self._run_uneven_workload(f, x)
+
+        # worker0 calls this at the end after waiting for RPC responses.
+        # worker1/2 call this immediately and still have some work after it.
+        # worker3 calls this immediately and has no more work.
+        rpc.api._wait_all_workers()
+        rpc.api._wait_all_workers()
+
+        # Wait before proceeding to shutdown to ensure worker0 RPCs make
+        # it through to other workers.
+        dist.barrier()
+        rpc.shutdown(graceful=False)
+
+    def _nested_rpc(self, f, expected):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        ret = rpc.rpc_sync(
+            worker_name(dst_rank),
+            f,
+            args=(worker_name(self.rank),),
+        )
+        self.assertEqual(ret, expected)
+
+    def _stress_test_rpc(self, f, repeat=1000, args=()):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        futs = []
+        tik = time.time()
+        for _ in range(repeat):
+            fut = rpc.rpc_async(worker_name(dst_rank), f, args=args)
+            futs.append(fut)
+
+        for val in torch.futures.wait_all(futs):
+            self.assertEqual(val, 0)
+        tok = time.time()
+        print(
+            f"Rank {self.rank} finished testing {repeat} times in {tok - tik} seconds."
+        )
+
+    def _builtin_remote_ret(self, x, y, expected):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        rref = rpc.remote(
+            worker_name(dst_rank),
+            torch.add,
+            args=(x, y),
+        )
+        self.assertEqual(rref.to_here(), expected)
+
+    def _builtin_remote_self(self, x, y, expected):
+        rref = rpc.remote(
+            worker_name(self.rank),
+            torch.add,
+            args=(x, y),
+        )
+        self.assertEqual(rref.local_value(), expected)
+
+    def _test_multi_remote_call(
+        self, fn, sparse, args_fn=lambda x, y: (), kwargs_fn=lambda x, y: {}
+    ):
+        m = 10
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        rrefs = []
+        expected = []
+        for i in range(m):
+            n = n + i
+            rrefs.append(
+                rpc.remote(
+                    worker_name(dst_rank),
+                    fn,
+                    args=args_fn(n, sparse),
+                    kwargs=kwargs_fn(n, sparse),
+                )
+            )
+            expected.append(fn(*args_fn(n, sparse), **kwargs_fn(n, sparse)))
+
+        for i in range(m):
+            self.assertEqual(rrefs[i].to_here(), expected[i])
+
+    def _py_rref_args(self, a, b, x, y, expected):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        rref_a = rpc.remote(worker_name(dst_rank), torch.add, args=(a, b))
+        rref_b = rpc.remote(worker_name(dst_rank), torch.add, args=(x, y))
+        rref_c = rpc.remote(
+            worker_name(dst_rank), my_rref_function, args=(rref_a, rref_b)
+        )
+        self.assertEqual(rref_c.to_here(), expected)
+
+    def _py_rref_args_user_share(self, a, b, c, x, y, z, expected):
+        n = self.rank + 1
+        owner_rank = n % self.world_size
+        user_rank = (n + 1) % self.world_size
+        rref_a = rpc.remote(worker_name(owner_rank), my_function, args=(a, b, c))
+        rref_b = rpc.remote(worker_name(owner_rank), my_function, args=(x, y, z))
+        rref_c = rpc.remote(
+            worker_name(user_rank), my_rref_function, args=(rref_a, rref_b)
+        )
+        self.assertEqual(rref_c.to_here(), expected)
+
+    def _py_rpc_rref_args(self, a, b, c, x, y, z, expected):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        rref_a = rpc.remote(worker_name(dst_rank), my_function, args=(a, b, c))
+        rref_b = rpc.remote(worker_name(dst_rank), my_function, args=(x, y, z))
+
+        c = rpc.rpc_sync(worker_name(dst_rank), my_rref_function, args=(rref_a, rref_b))
+        self.assertEqual(c, expected)
+
+    def _nested_remote(self, f, expected):
+        n = self.rank + 1
+        dst_rank1 = n % self.world_size
+        dst_rank2 = (n + 1) % self.world_size
+
+        rref = rpc.remote(
+            worker_name(dst_rank1),
+            f,
+            args=(worker_name(dst_rank2),),
+        )
+        self.assertEqual(rref.to_here(), expected)
+
+    def _nested_rref(self, f, expected1, expected2):
+        n = self.rank + 1
+        dst_rank1 = n % self.world_size
+        dst_rank2 = (n + 1) % self.world_size
+        rref_of_rrefs = rpc.remote(
+            worker_name(dst_rank1),
+            f,
+            args=(worker_name(dst_rank2),),
+        )
+
+        # Say C has 2 OwnerRRefs.
+        # B has 2 UserRRefs to those 2 OwnerRRefs, respectively.
+        # This call is effectively A asking B to share its 2 UserRRefs.
+        rrefs = rref_of_rrefs.to_here()
+
+        self.assertEqual(len(rrefs), 2)
+        self.assertEqual(rrefs[0].to_here(), expected1)
+        self.assertEqual(rrefs[1].to_here(), expected2)
+
+    def _nested_rref_stress(self, f, expected1, expected2):
+        n = self.rank + 1
+        dst_rank1 = n % self.world_size
+        dst_rank2 = (n + 1) % self.world_size
+        all_rrefs = [
+            rpc.remote(
+                worker_name(dst_rank1),
+                f,
+                args=(worker_name(dst_rank2),),
+            )
+            for _ in range(20)
+        ]
+
+        for i in range(20):
+            rref_of_rrefs = all_rrefs[i]
+            rrefs = rref_of_rrefs.to_here()
+            self.assertEqual(len(rrefs), 2)
+            self.assertEqual(rrefs[0].to_here(), expected1)
+            self.assertEqual(rrefs[1].to_here(), expected2)
+
+    def _trainer_func(self, rref, sparse):
+        m = MyEmbeddingBagModel(sparse=sparse)
+        loss_fn = nn.MSELoss()
+        for i in range(10):
+            outputs = m(torch.rand(10, 10).long())
+            loss_fn(outputs, torch.rand(10, 10)).backward()
+            gradient = next(iter(m.parameters())).grad
+            fut = rref.rpc_async().average(rref, i, gradient)
+            gradient = fut.wait()
+            if gradient.is_sparse:
+                gradient = gradient.to_dense().double()
+            ps_gradient = rref.rpc_sync().get_gradient(rref)
+            if ps_gradient.is_sparse:
+                ps_gradient = ps_gradient.to_dense().double()
+            self.assertTrue(torch.equal(gradient, ps_gradient))
+
+    def _my_parameter_server(self, sparse):
+        ps_rref = RRef(MyParameterServer(self.world_size - 1))
+        futures = [
+            rpc.rpc_async(
+                worker_name((self.rank + index) % self.world_size),
+                self._trainer_func,
+                args=(ps_rref, sparse),
+            )
+            for index in range(1, self.world_size)
+        ]
+        torch.futures.wait_all(futures)
+
+    def _test_cuda_future_extraction(self, wrapper, unwrapper, sparse_tensor):
+        # We check proper CUDA stream synchronization by adding to the tensor
+        # in one stream to get the expected value, and reading it from another stream.
+        future = Future(devices=["cuda:0"])
+        with torch.cuda.device("cuda:0"):
+            stream = torch.cuda.Stream()
+            another_stream = torch.cuda.Stream()
+            with torch.cuda.stream(stream):
+                if sparse_tensor:
+                    tensor = build_sparse_tensor().to("cuda:0")
+                    add_tensor = build_sparse_tensor().to("cuda:0")
+                    expected_tensor = (tensor + add_tensor).coalesce()
+                else:
+                    tensor = torch.zeros((100,), device="cuda:0")
+                    add_tensor = torch.ones((100,), device="cuda:0")
+                    expected_tensor = tensor + add_tensor
+                torch.cuda._sleep(int(1000 * get_cycles_per_ms()))
+                tensor += add_tensor
+                if sparse_tensor:
+                    tensor = tensor.coalesce()
+                future.set_result(wrapper(tensor))
+            with torch.cuda.stream(another_stream):
+                tensor = unwrapper(future.wait())
+                if sparse_tensor:
+                    self.assertTrue(
+                        torch.eq(tensor.indices(), expected_tensor.indices())
+                        .all()
+                        .item()
+                    )
+                    self.assertTrue(
+                        torch.eq(tensor.values(), expected_tensor.values()).all().item()
+                    )
+                    self.assertEqual(tensor.size(), expected_tensor.size())
+                else:
+                    self.assertTrue(torch.eq(tensor, expected_tensor).all().item())
+
+
+class RpcTest(RpcAgentTestFixture, RpcTestCommon):
+    @dist_init
+    def test_worker_id(self):
+        n = self.rank + 1
+        peer_rank = n % self.world_size
+        self_worker_info = rpc.get_worker_info()
+        peer_worker_info = rpc.get_worker_info(worker_name(peer_rank))
+
+        self.assertEqual(self_worker_info.name, worker_name(self.rank))
+        self.assertEqual(peer_worker_info.name, worker_name(peer_rank))
+
+        with self.assertRaisesRegex(RuntimeError, "could not find destination"):
+            rpc.get_worker_info("WorkerUnknown")
+
+    @dist_init
+    def test_get_worker_infos(self):
+        worker_infos = rpc.api._get_current_rpc_agent().get_worker_infos()
+
+        worker_names = {worker_info.name for worker_info in worker_infos}
+        expected_worker_names = {worker_name(rank) for rank in range(self.world_size)}
+        self.assertEqual(worker_names, expected_worker_names)
+
+        worker_ids = {worker_info.id for worker_info in worker_infos}
+        expected_worker_ids = set(range(self.world_size))
+        self.assertEqual(worker_ids, expected_worker_ids)
+
+    @dist_init
+    def test_self_add(self):
+        self_worker_info = rpc.get_worker_info()
+        fut = rpc.rpc_async(self_worker_info, torch.add, args=(torch.ones(2, 2), 1))
+        ret = rpc.rpc_sync(self_worker_info, torch.add, args=(torch.ones(2, 2), 1))
+        self.assertEqual(fut.wait(), torch.ones(2, 2) + 1)
+        self.assertEqual(ret, torch.ones(2, 2) + 1)
+
+    @dist_init
+    def test_send_to_rank(self):
+        dst_rank = (self.rank + 1) % self.world_size
+
+        # Test dense tensor
+        for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]:
+            ret = self._run_func_in_mode(
+                dst_rank, torch.add, exec_mode, args=(torch.ones(2, 2), 1)
+            )
+            self.assertEqual(ret, torch.ones(2, 2) + 1)
+
+        # Test invalid ranks
+        for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]:
+            with self.assertRaises(RuntimeError):
+                self._run_func_in_mode(
+                    self.world_size + 1,
+                    torch.add,
+                    exec_mode,
+                    args=(torch.ones(2, 2), 1),
+                )
+
+        for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]:
+            with self.assertRaises(RuntimeError):
+                self._run_func_in_mode(
+                    -1, torch.add, exec_mode, args=(torch.ones(2, 2), 1)
+                )
+
+        for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]:
+            with self.assertRaises(ValueError):
+                self._run_func_in_mode(
+                    dst_rank + 0.5, torch.add, exec_mode, args=(torch.ones(2, 2), 1)
+                )
+
+        for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]:
+            with self.assertRaises(ValueError):
+                self._run_func_in_mode(
+                    dst_rank - 0.5, torch.add, exec_mode, args=(torch.ones(2, 2), 1)
+                )
+
+    @dist_init
+    def test_self_py_udf_remote(self):
+        self._self_py_udf_remote(rpc.get_worker_info(), torch.ones(2, 2), 1, 3)
+
+    @dist_init
+    def test_self_remote_rref_as_rpc_arg(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        self._self_remote_rref_as_rpc_arg(dst, torch.ones(2, 2), 1, 3)
+
+    @dist_init
+    def test_self_remote_rref_as_self_rpc_arg(self):
+        self._self_remote_rref_as_rpc_arg(rpc.get_worker_info(), torch.ones(2, 2), 1, 3)
+
+    @dist_init
+    def test_self_remote_rref_as_remote_arg(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        self._self_remote_rref_as_remote_arg(dst, torch.ones(2, 2), 1, 3)
+
+    @dist_init
+    def test_self_remote_rref_as_self_remote_arg(self):
+        self._self_remote_rref_as_remote_arg(
+            rpc.get_worker_info(), torch.ones(2, 2), 1, 3
+        )
+
+    @dist_init
+    def test_rref_proxy_non_exist(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        rref = rpc.remote(dst, my_function, args=(torch.ones(2, 2), 1, 3))
+        msg = "has no attribute 'non_exist'"
+        with self.assertRaisesRegex(AttributeError, msg):
+            rref.rpc_sync().non_exist()
+
+        with self.assertRaisesRegex(AttributeError, msg):
+            rref.rpc_async().non_exist().wait()
+
+        with self.assertRaisesRegex(AttributeError, msg):
+            rref.remote().non_exist()
+
+    def _test_rref_proxy_tensor(self, dst):
+        rref = rpc.remote(dst, my_function, args=(torch.ones(2, 2), 1, 3))
+
+        expected = torch.ones(2, 2) + 1 + 3
+        self.assertEqual(expected.size(), rref.rpc_sync().size())
+        self.assertEqual(expected + 1, rref.rpc_async().add(1).wait())
+        self.assertEqual(expected.view(1, 4), rref.remote().view(1, 4).to_here())
+
+    @dist_init
+    def test_rref_proxy_tensor(self):
+        self._test_rref_proxy_tensor(worker_name((self.rank + 1) % self.world_size))
+
+    @dist_init
+    def test_rref_proxy_tensor_self(self):
+        self._test_rref_proxy_tensor(rpc.get_worker_info())
+
+    @dist_init
+    def test_rref_proxy_reuse(self):
+        rref = rpc.remote(
+            worker_name((self.rank + 1) % self.world_size),
+            my_function,
+            args=(torch.ones(2, 2), 1, 3),
+        )
+        expected = torch.ones(2, 2) + 1 + 3
+
+        proxy_rpc_sync = rref.rpc_sync()
+        proxy_rpc_async = rref.rpc_async()
+        proxy_remote = rref.remote()
+
+        self.assertEqual(expected.size(), proxy_rpc_sync.size())
+        self.assertEqual(expected + 1, proxy_rpc_sync.add(1))
+        self.assertEqual(expected.view(1, 4), proxy_rpc_sync.view(1, 4))
+
+        self.assertEqual(expected.size(), proxy_rpc_async.size().wait())
+        self.assertEqual(expected + 3, proxy_rpc_async.add(3).wait())
+        self.assertEqual(expected.view(4, 1), proxy_rpc_async.view(4, 1).wait())
+
+        self.assertEqual(expected.size(), proxy_remote.size().to_here())
+        self.assertEqual(expected + 5, proxy_remote.add(5).to_here())
+        self.assertEqual(expected.view(-1), proxy_remote.view(-1).to_here())
+
+    def _test_rref_proxy_class(self, dst):
+        rref = rpc.remote(dst, MyClass, args=(7,))
+        expected = MyClass(7)
+        self.assertEqual(expected.get_value(), rref.rpc_sync().get_value())
+        self.assertEqual(expected.get_value(), rref.rpc_async().get_value().wait())
+        self.assertEqual(expected.get_value(), rref.remote().get_value().to_here())
+
+        expected.increment_value(3)
+        self.assertEqual(None, rref.rpc_sync().increment_value(1))
+        self.assertEqual(None, rref.rpc_async().increment_value(1).wait())
+        self.assertEqual(None, rref.remote().increment_value(1).to_here())
+
+        self.assertEqual(expected.get_value(), rref.rpc_sync().get_value())
+        self.assertEqual(expected.get_value(), rref.rpc_async().get_value().wait())
+        self.assertEqual(expected.get_value(), rref.remote().get_value().to_here())
+
+        self.assertEqual(
+            expected.my_instance_method(2), rref.rpc_sync().my_instance_method(2)
+        )
+        self.assertEqual(
+            expected.my_instance_method(3),
+            rref.rpc_async().my_instance_method(3).wait(),
+        )
+        self.assertEqual(
+            expected.my_instance_method(4),
+            rref.remote().my_instance_method(4).to_here(),
+        )
+
+        self.assertEqual(
+            expected.my_static_method(9), rref.rpc_sync().my_static_method(9)
+        )
+        self.assertEqual(
+            expected.my_static_method(10), rref.rpc_async().my_static_method(10).wait()
+        )
+        self.assertEqual(
+            expected.my_static_method(11), rref.remote().my_static_method(11).to_here()
+        )
+
+        self.assertEqual(
+            expected.my_class_method(2, torch.zeros(2, 2)),
+            rref.rpc_sync().my_class_method(2, torch.zeros(2, 2)),
+        )
+        self.assertEqual(
+            expected.my_class_method(2, torch.ones(3, 3)),
+            rref.rpc_async().my_class_method(2, torch.ones(3, 3)).wait(),
+        )
+        self.assertEqual(
+            expected.my_class_method(2, torch.ones(4, 4)),
+            rref.remote().my_class_method(2, torch.ones(4, 4)).to_here(),
+        )
+
+    @dist_init
+    def test_rref_proxy_class(self):
+        self._test_rref_proxy_class(worker_name((self.rank + 1) % self.world_size))
+
+    @dist_init
+    def test_rref_proxy_class_self(self):
+        self._test_rref_proxy_class(rpc.get_worker_info())
+
+    @mock.patch.object(torch.distributed.autograd, "_init")
+    @mock.patch.object(torch.distributed.rpc.api, "_set_and_start_rpc_agent")
+    @dist_init(setup_rpc=False)
+    def test_register_rpc_backend_and_set_and_start_rpc_backend(
+        self, mock_rpc_agent, mock_dist_autograd_init
+    ):
+        backend_name = "stub_backend"
+
+        backend = rpc.backend_registry.register_backend(
+            backend_name,
+            _stub_construct_rpc_backend_options_handler,
+            _stub_init_rpc_backend_handler,
+        )
+
+        with self.assertRaisesRegex(
+            RuntimeError, "^RPC backend .+: already registered$"
+        ):
+            backend = rpc.backend_registry.register_backend(
+                backend_name,
+                _stub_construct_rpc_backend_options_handler,
+                _stub_init_rpc_backend_handler,
+            )
+
+        rpc.init_rpc(
+            name="worker1",
+            backend=backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+
+    @dist_init(setup_rpc=False)
+    def test_duplicate_name(self):
+        with self.assertRaisesRegex(RuntimeError, "is not unique"):
+            store, _, _ = next(
+                torch.distributed.rendezvous(
+                    self.init_method, rank=self.rank, world_size=self.world_size
+                )
+            )
+            rpc._init_rpc_backend(
+                backend=self.rpc_backend,
+                store=store,
+                name="duplicate_name",
+                rank=self.rank,
+                world_size=self.world_size,
+                rpc_backend_options=self.rpc_backend_options,
+            )
+
+    @dist_init(setup_rpc=False)
+    def test_duplicate_name_2(self):
+        with self.assertRaisesRegex(RuntimeError, "is not unique"):
+            rpc.init_rpc(
+                name=worker_name(self.rank % (self.world_size - 1)),
+                backend=self.rpc_backend,
+                rank=self.rank,
+                world_size=self.world_size,
+                rpc_backend_options=self.rpc_backend_options,
+            )
+
+    @dist_init(setup_rpc=False)
+    def test_reinit(self):
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        # Wait for all init to complete.
+        dist.barrier()
+
+        # TODO: with TCP init, rank 0 raises Address already in use because
+        # rank 0 is the start daemon and the store is created before checking if
+        # RPC is already initialized in init_rpc.
+        if os.environ.get("RPC_INIT_WITH_TCP", None) == "1" and self.rank == 0:
+            expected_reinit_err = "Address already in use"
+        else:
+            expected_reinit_err = "is already initialized"
+
+        with self.assertRaisesRegex(RuntimeError, expected_reinit_err):
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rank=self.rank,
+                world_size=self.world_size,
+                rpc_backend_options=self.rpc_backend_options,
+            )
+        rpc.shutdown()
+
+    @dist_init(setup_rpc=False)
+    def test_pg_init_no_rpc_init(self):
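+        # Creating RRefs must fail with a clear error when only the process
+        # group (and DDP) has been initialized but RPC has not.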
+        dist.init_process_group(
+            backend="gloo",
+            init_method=self.file_init_method,
+            rank=self.rank,
+            world_size=self.world_size,
+        )
+
+        class MyModel(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.lin = torch.nn.Linear(3, 4)
+
+            def forward(self, x):
+                return self.lin(x)
+
+        model = MyModel()
+        model.train()
+        model = torch.nn.parallel.DistributedDataParallel(model)
+
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "Current RPC agent is not set! Did you initialize the RPC framework",
+        ):
+            [RRef(param) for param in model.parameters()]
+
+    def test_world_size_one(self):
+        self._world_size_one(torch.ones(2, 2), torch.ones(2, 2))
+
+    @dist_init(setup_rpc=False)
+    def test_invalid_names(self):
+        worker_id = 0
+        with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
+            WorkerInfo("abc*", worker_id)
+
+        with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
+            WorkerInfo(" ", worker_id)
+
+        with self.assertRaisesRegex(RuntimeError, "must be non-empty"):
+            WorkerInfo("", worker_id)
+
+        # If the number in the message does not match, it is likely that the
+        # value of MAX_NAME_LEN in RPC WorkerInfo has changed.
+        with self.assertRaisesRegex(RuntimeError, "shorter than 128"):
+            WorkerInfo("".join(["a" for i in range(500)]), worker_id)
+
+    # Test that WorkerInfo can be pickled and sent in an RPC call.
+    @dist_init
+    def test_worker_info_pickle(self):
+        dst_rank = (self.rank + 1) % self.world_size
+        worker_info = rpc.api.get_worker_info()
+        ret = rpc.rpc_sync(worker_name(dst_rank), identity, args=(worker_info,))
+        self.assertEqual(ret, worker_info)
+
+    @dist_init
+    def test_add(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        ret = rpc.rpc_sync(
+            worker_name(dst_rank),
+            torch.add,
+            args=(torch.ones(n, n), torch.ones(n, n)),
+        )
+        self.assertEqual(ret, torch.ones(n, n) * 2)
+
+    @staticmethod
+    def return_callee_id():
+        return rpc.get_worker_info().id
+
+    @dist_init
+    def test_int_callee(self):
+        dst_rank = (self.rank + 1) % self.world_size
+        ret = rpc.rpc_sync(dst_rank, RpcTest.return_callee_id)
+        self.assertEqual(ret, dst_rank)
+
+    @dist_init
+    def test_add_with_id(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        worker_info = rpc.get_worker_info(worker_name(dst_rank))
+
+        ret = rpc.rpc_sync(
+            worker_info, torch.add, args=(torch.ones(n, n), torch.ones(n, n))
+        )
+        self.assertEqual(ret, torch.ones(n, n) * 2)
+
+    @dist_init
+    def test_scalar_add(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        ret = rpc.rpc_sync(worker_name(dst_rank), torch.add, args=(torch.ones(n, n), n))
+        self.assertEqual(ret, (torch.ones(n, n) + n))
+
+    @dist_init
+    def test_async_add(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        fut = rpc.rpc_async(
+            worker_name(dst_rank),
+            torch.add,
+            args=(torch.ones(n, n), torch.ones(n, n)),
+        )
+        self.assertEqual(fut.wait(), torch.ones(n, n) * 2)
+
+    @dist_init
+    def test_nonzero(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        x = torch.ones(self.world_size, self.world_size)
+        x[self.rank][self.rank] = 0
+        ret = rpc.rpc_sync(worker_name(dst_rank), torch.nonzero, args=(x,))
+        self.assertEqual(ret, x.nonzero())
+
+    @dist_init
+    def test_multi_rpc(self):
+        self._multi_rpc(False)
+
+    @dist_init
+    def test_future_wait_twice(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        futs = [rpc.rpc_async(dst, raise_func) for _ in range(20)]
+
+        with self.assertRaisesRegex(ValueError, "Expected error"):
+            torch.futures.wait_all(futs)
+
+        for fut in futs:
+            with self.assertRaisesRegex(ValueError, "Expected error"):
+                fut.wait()
+
+    @dist_init(setup_rpc=False)
+    def test_wait_all_workers_timeout(self):
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+
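+        # Replace _wait_all_workers with a version that pickles slowly so that
+        # a graceful shutdown with a tiny timeout is guaranteed to fail; the
+        # original is restored afterwards.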
+        og_func = rpc.api._wait_all_workers
+
+        def wait_all_workers_sleep(timeout):
+            rpc.api._all_gather(SlowPickleClass(0.5), timeout=timeout)
+
+        rpc.api._wait_all_workers = wait_all_workers_sleep
+
+        try:
+            with self.assertRaisesRegex(RuntimeError, ""):
+                rpc.shutdown(graceful=True, timeout=0.01)
+        finally:
+            rpc.api._wait_all_workers = og_func
+        dist.barrier()
+
+    def test_wait_all_workers_dense(self):
+        self._wait_all_workers(heavy_rpc, torch.ones(100, 100))
+
+    def test_wait_all_workers_twice_dense(self):
+        self._wait_all_workers_twice(heavy_rpc, torch.ones(100, 100))
+
+    @dist_init
+    def test_all_gather(self):
+        info = rpc.get_worker_info()
+        results = rpc.api._all_gather(info.id)
+        expected = {}
+        for info in rpc._get_current_rpc_agent().get_worker_infos():
+            expected[info.name] = info.id
+
+        self.assertEqual(expected, results)
+
+    @dist_init
+    def test_all_gather_timeout(self):
+        rpc._set_rpc_timeout(0.1)
+
+        if self.rank == 0:
+            with self.assertRaisesRegex(
+                RuntimeError, "timed out in _all_gather after 0\\.10 seconds"
+            ):
+                rpc.api._all_gather(SlowPickleClass(0.5))
+        else:
+            expected_error = self.get_timeout_error_regex()
+            with self.assertRaisesRegex(RuntimeError, expected_error):
+                rpc.api._all_gather(SlowPickleClass(0.5))
+
+    def _test_barrier_helper(self, info, names, multi_threaded=False):
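+        # Helper for the barrier tests: the first worker (sorted by name) acts
+        # as the leader. Its counter is reset, each participant increments it
+        # once between two barriers, and the leader then verifies the count
+        # equals the number of participants.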
+        names = sorted(names)
+        leader = names[0]
+        rpc.rpc_sync(leader, _reset_count)
+        if not multi_threaded and info.name == leader:
+            self.assertEqual(_rpc_barrier_count, 0)
+        rpc.api._barrier(names)
+        rpc.rpc_sync(leader, _increment_count)
+        rpc.api._barrier(names)
+        if not multi_threaded and info.name == leader:
+            self.assertEqual(_rpc_barrier_count, len(names))
+
+    @dist_init
+    def test_rpc_barrier_all(self):
+        # Test rpc barrier when called with full list of workers
+        info = rpc.get_worker_info()
+        all_worker_info = rpc._get_current_rpc_agent().get_worker_infos()
+        names = [worker.name for worker in all_worker_info]
+        self._test_barrier_helper(info, names)
+
+    @dist_init
+    def test_rpc_barrier_subset(self):
+        # Test rpc barrier when processes are called with different subsets of the full list
+        info = rpc.get_worker_info()
+        all_worker_info = rpc._get_current_rpc_agent().get_worker_infos()
+        if info.id % 2:
+            names = [worker.name for worker in all_worker_info if worker.id % 2]
+        else:
+            names = [worker.name for worker in all_worker_info if not worker.id % 2]
+        self._test_barrier_helper(info, names)
+
+    @dist_init
+    def test_rpc_barrier_partial_subset(self):
+        # Test rpc barrier when some processes are not involved in the barrier
+        info = rpc.get_worker_info()
+        all_worker_info = rpc._get_current_rpc_agent().get_worker_infos()
+        if info.id % 2:
+            names = [worker.name for worker in all_worker_info if worker.id % 2]
+        else:
+            names = [f"worker{info.id}"]
+        self._test_barrier_helper(info, names)
+
+    @dist_init
+    def test_rpc_barrier_multithreaded(self):
+        # This test validates the implementation of barrier when multiple threads call into it.
+        # We only need to check that it does not hang in this case.
+        info = rpc.get_worker_info()
+        all_worker_info = rpc._get_current_rpc_agent().get_worker_infos()
+        names = [worker.name for worker in all_worker_info]
+        threads = []
+        for _ in range(3):
+            th = threading.Thread(
+                target=self._test_barrier_helper, args=(info, names, True)
+            )
+            threads.append(th)
+            th.start()
+        for th in threads:
+            th.join()
+
+    @dist_init
+    def test_graceful_shutdown_with_uneven_workload(self):
+        """Test graceful termination."""
+        self._run_uneven_workload(heavy_rpc, torch.ones(100, 100))
+
+    @dist_init(setup_rpc=False)
+    def test_shutdown_followed_by_rpc(self):
+        # Initialize RPC.
+        rpc.init_rpc(
+            name=f"worker{self.rank:d}",
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        ret = rpc.rpc_sync(
+            worker_name(dst_rank),
+            torch.add,
+            args=(torch.ones(n, n), torch.ones(n, n)),
+        )
+        self.assertEqual(ret, torch.ones(n, n) * 2)
+        rpc.shutdown()
+
+        with self.assertRaisesRegex(RuntimeError, "^RPC has not been initialized"):
+            rpc.rpc_sync(
+                worker_name(dst_rank),
+                torch.add,
+                args=(torch.ones(n, n), torch.ones(n, n)),
+            )
+
+    @dist_init
+    def test_expected_src(self):
+        dst_rank = (self.rank + 1) % self.world_size
+        expected_src_rank = (self.rank - 1) % self.world_size
+        rpc.rpc_sync(worker_name(dst_rank), set_value, args=(self.rank,))
+        value = VALUE_FUTURE.result()
+        self.assertEqual(value, expected_src_rank)
+
+    @dist_init
+    def test_py_built_in(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        ret = rpc.rpc_sync(worker_name(dst_rank), min, args=(n, n + 1, n + 2))
+        self.assertEqual(ret, min(n, n + 1, n + 2))
+
+    @dist_init
+    def test_py_user_defined(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        ret = rpc.rpc_sync(
+            worker_name(dst_rank),
+            my_function,
+            kwargs={"a": n, "b": n + 1, "c": n + 2},
+        )
+        self.assertEqual(ret, my_function(n, n + 1, n + 2))
+
+    def test_build_rpc_profiling_key(self):
+        # Tests that the name that shows up as an Event in profiling RPCs has all
+        # the necessary information.
+        for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]:
+            rpc_profiling_key = _build_rpc_profiling_key(
+                exec_mode, "foo", "worker0", "worker1"
+            )
+            self.assertIn(exec_mode.value, rpc_profiling_key)
+            self.assertIn("foo", rpc_profiling_key)
+            self.assertIn("worker0", rpc_profiling_key)
+            self.assertIn("worker1", rpc_profiling_key)
+
+    def check_profiling_info(
+        self, self_worker_name, dst_worker_name, func, rpc_event, rpc_exec_mode
+    ):
+        self.assertTrue(self_worker_name in rpc_event.name)
+        self.assertTrue(dst_worker_name in rpc_event.name)
+        if isinstance(func, torch.jit.ScriptFunction):
+            self.assertTrue(torch._jit_internal._qualified_name(func) in rpc_event.name)
+        else:
+            self.assertTrue(func.__name__ in rpc_event.name)
+        self.assertTrue(rpc_exec_mode.value in rpc_event.name)
+        self.assertEqual(rpc_event.count, 1)
+
+    @dist_init
+    def test_profiler_rpc_record_shapes(self):
+        if self.rank != 1:
+            return
+        dst = (self.rank + 1) % self.world_size
+        dst_worker = worker_name(dst)
+        t1, t2 = torch.ones(100), torch.ones(100)
+        with _profile(record_shapes=True) as prof:
+            rpc.rpc_sync(dst_worker, torch.add, args=(t1, t2))
+
+        function_events = prof.function_events
+        remote_events = [event for event in function_events if event.is_remote]
+        remote_add_event = next(
+            event for event in remote_events if "aten::add" in event.name
+        )
+        remote_add_input_shapes = remote_add_event.input_shapes
+        # Run profiler on equivalent local op and validate shapes are the same.
+        with _profile(record_shapes=True) as prof:
+            torch.add(t1, t2)
+
+        local_function_events = prof.function_events
+        local_add_event = next(
+            event for event in local_function_events if "aten::add" in event.name
+        )
+        local_add_input_shapes = local_add_event.input_shapes
+        self.assertEqual(remote_add_input_shapes, local_add_input_shapes)
+
+    @dist_init
+    def test_profiler_rpc_memory(self):
+        if self.rank != 1:
+            return
+        dst = (self.rank + 1) % self.world_size
+        dst_worker = worker_name(dst)
+        with _profile(profile_memory=True) as p:
+            fut = rpc.rpc_async(dst_worker, udf_with_torch_ops, args=())
+            fut.wait()
+
+        function_events = p.function_events
+        event_cpu_mem_usages = {event.cpu_memory_usage for event in function_events}
+        # If cpu_memory_usage was not propagated over the wire, this set would
+        # only contain 0 (indicating that no memory was profiled).
+        self.assertNotEqual({0}, event_cpu_mem_usages)
+        # No memory profiled if profile_memory=False
+        with _profile(profile_memory=False) as p:
+            fut = rpc.rpc_async(dst_worker, udf_with_torch_ops, args=())
+            fut.wait()
+
+        function_events = p.function_events
+        event_cpu_mem_usages = {event.cpu_memory_usage for event in function_events}
+        self.assertEqual({0}, event_cpu_mem_usages)
+
+    @dist_init
+    def test_profiler_export_trace(self):
+        if self.rank != 1:
+            return
+        dst = (self.rank + 1) % self.world_size
+        dst_worker = worker_name(dst)
+        with _profile() as p:
+            fut = rpc.rpc_async(dst_worker, udf_with_torch_ops, args=())
+            fut.wait()
+
+        with TemporaryFileName() as fname:
+            path = fname
+            p.export_chrome_trace(path)
+            with open(path) as f:
+                trace = json.load(f)
+                event_names = [event["name"] for event in trace]
+                for expected_event_name in EXPECTED_REMOTE_EVENTS + [
+                    RPCExecMode.ASYNC.value
+                ]:
+                    event_exists = any(
+                        expected_event_name in event_name for event_name in event_names
+                    )
+                    self.assertTrue(event_exists)
+
+    @dist_init
+    def test_profiler_rpc_key_names(self):
+        # tests that remote events are properly prefixed with the RPC profiling key.
+        if self.rank != 1:
+            return
+
+        # Spawn multiple threads that send RPCs to ensure keys are correctly
+        # prefixed when there are multiple RPCs being created/in flight at the
+        # same time.
+        dst_ranks = [rank for rank in range(0, self.world_size) if rank != self.rank]
+
+        def rpc_with_profiling(dst_worker):
+            with _profile() as prof:
+                fut = rpc.rpc_async(dst_worker, udf_with_torch_ops, args=())
+                fut.wait()
+
+            events = prof.function_events
+            remote_event_names = {
+                event.name: event for event in events if event.is_remote
+            }
+            rpc_profiling_key = _build_rpc_profiling_key(
+                RPCExecMode.ASYNC,
+                udf_with_torch_ops.__qualname__,
+                worker_name(self.rank),
+                dst_worker,
+            )
+
+            remote_event_name_set = set(EXPECTED_REMOTE_EVENTS)
+            for name, event in remote_event_names.items():
+                # Ensure that we have the expected key as part of the remote
+                # event.
+                self.assertTrue(name.startswith(rpc_profiling_key))
+                self.assertTrue(event.is_remote)
+                self.assertTrue(event.node_id == rpc.get_worker_info(dst_worker).id)
+                # Ensure that the remote event name also contains the operator.
+                operator_name_substr = name[len(rpc_profiling_key) :]
+                # Note: we don't assert that every remote event is in the
+                # above set; the set is just a representative sample of what
+                # we expect to see. The profiler can change and add more
+                # events, but we should always expect to see this
+                # representative set.
+                matching_event = {
+                    remote_event_name
+                    for remote_event_name in remote_event_name_set
+                    if remote_event_name in operator_name_substr
+                }
+                remote_event_name_set -= matching_event
+
+            # The set should be empty; otherwise, its remaining elements did
+            # not show up in the remote profiler output.
+            self.assertTrue(
+                remote_event_name_set == set(),
+                f"Expected {remote_event_name_set} to be included in remote profiler output.",
+            )
+
+        for dst in dst_ranks:
+            dst_worker = worker_name(dst)
+            num_parallel_rpcs = 2
+            with concurrent.futures.ThreadPoolExecutor(
+                max_workers=num_parallel_rpcs
+            ) as executor:
+                futs = [
+                    executor.submit(rpc_with_profiling, dst_worker)
+                    for _ in range(num_parallel_rpcs)
+                ]
+                # Wait for workers to finish test
+                for fut in futs:
+                    fut.result()
+
+    def _run_test_profiler_remote_events_profiled(self):
+        # Tests that we can successfully invoke the profiler on a remote node,
+        # and collect the remote events back in the local profiler.
+        if self.rank != 1:
+            return
+
+        dst_ranks = [rank for rank in range(0, self.world_size) if rank != self.rank]
+        for dst in dst_ranks:
+            dst_worker = worker_name(dst)
+            with _profile() as prof:
+                fut = rpc.rpc_async(dst_worker, udf_with_torch_ops, args=())
+                fut.wait()
+
+            events = prof.function_events
+
+            rpc_event = get_function_event(events, RPCExecMode.ASYNC.value)
+            self.check_profiling_info(
+                worker_name(self.rank),
+                dst_worker,
+                udf_with_torch_ops,
+                rpc_event,
+                RPCExecMode.ASYNC,
+            )
+
+            remote_events = {event.name: event for event in events if event.is_remote}
+            rpc_profiling_key = _build_rpc_profiling_key(
+                RPCExecMode.ASYNC,
+                udf_with_torch_ops.__qualname__,
+                worker_name(self.rank),
+                worker_name(dst),
+            )
+
+            for expected_remote_event_name in EXPECTED_REMOTE_EVENTS:
+                expected_key = (
+                    rpc_profiling_key + REMOTE_OP_STR + expected_remote_event_name
+                )
+                self.assertTrue(expected_key in remote_events)
+                remote_event = remote_events[expected_key]
+                # Remote event should have a node ID corresponding to the worker
+                # it ran on.
+                self.assertEqual(remote_event.node_id, dst)
+
+            # Validate order remote events show up in profiling output.
+            def convert_remote_to_local(event_name):
+                remote_op_key = rpc_profiling_key + REMOTE_OP_STR
+                return event_name[event_name.find(remote_op_key) + len(remote_op_key) :]
+
+            remote_events_list = [
+                convert_remote_to_local(event.name)
+                for event in events
+                if convert_remote_to_local(event.name) in EXPECTED_REMOTE_EVENTS
+            ]
+            self.assertEqual(
+                set(remote_events_list),
+                set(EXPECTED_REMOTE_EVENTS),
+                f"Mismatch between profiled events: {set(remote_events_list)} and expected events: {set(EXPECTED_REMOTE_EVENTS)}",
+            )
+
+    @dist_init
+    def test_profiler_remote_events_profiled(self):
+        self._run_test_profiler_remote_events_profiled()
+
+    @dist_init
+    def test_profiler_remote_events_profiled_single_threaded(self):
+        self._run_test_profiler_remote_events_profiled()
+
+    def run_profiling_workload(self, dst):
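+        # Issue a single async torch.mul RPC to the destination and wait for it
+        # to complete.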
+        fut = rpc.rpc_async(
+            worker_name(dst),
+            torch.mul,
+            args=(
+                torch.tensor(1.0, requires_grad=True),
+                torch.tensor(1.0, requires_grad=True),
+            ),
+        )
+        fut.wait()
+
+    def _run_rpc_profiling_async_function(self, device="cpu"):
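+        # Profile a nested async RPC: rank 1 calls slow_async_add on dst1, which
+        # in turn issues an RPC to dst2; events from both hops should show up in
+        # the local profiler output.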
+        if self.rank != 1:
+            return
+
+        dst1 = worker_name((self.rank + 1) % self.world_size)
+        dst2 = worker_name((self.rank + 2) % self.world_size)
+        x = torch.ones(2)
+        y = torch.ones(2)
+        with _profile() as prof:
+            ret = rpc.rpc_async(
+                dst1, slow_async_add, args=(dst2, x, y, device), timeout=20
+            )
+            ret.wait()
+
+        function_events = prof.function_events
+        # slow_async_add resulted in an RPC from dst1 -> dst2, so this should be
+        # recorded.
+        key_prefix = _build_rpc_profiling_key(
+            RPCExecMode.ASYNC, slow_async_add.__qualname__, worker_name(self.rank), dst1
+        )
+
+        nested_rpc_key_prefix = _build_rpc_profiling_key(
+            RPCExecMode.ASYNC, slow_add.__qualname__, dst1, dst2
+        )
+        expected_key = key_prefix + REMOTE_OP_STR + nested_rpc_key_prefix
+        remote_events = [event for event in function_events if event.is_remote]
+        rpc_remote_event = [
+            event for event in remote_events if event.name == expected_key
+        ]
+        self.assertEqual(1, len(rpc_remote_event))
+        rpc_remote_event = rpc_remote_event[0]
+        self.assertEqual(rpc_remote_event.node_id, (self.rank + 1) % self.world_size)
+        # slow_async_add's RPC does an add on dst2, which should be reflected as well.
+        remote_add_key = (
+            expected_key + REMOTE_OP_STR + torch.jit._builtins._find_builtin(torch.add)
+        )
+        remote_add_event = [
+            event for event in remote_events if event.name == remote_add_key
+        ]
+        self.assertEqual(1, len(remote_add_event))
+        remote_add_event = remote_add_event[0]
+        # Validate that node_id is dst2.
+        self.assertEqual(remote_add_event.node_id, (self.rank + 2) % self.world_size)
+
+    @dist_init
+    def test_rpc_profiling_async_function(self):
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        self._run_rpc_profiling_async_function()
+        if torch.cuda.is_available():
+            dist.barrier()
+            self._run_rpc_profiling_async_function(device="cuda:0")
+
+    @dist_init
+    def test_rpc_profiling_async_function_single_threaded(self):
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        self._run_rpc_profiling_async_function()
+        if torch.cuda.is_available():
+            dist.barrier()
+            self._run_rpc_profiling_async_function(device="cuda:0")
+
+    @dist_init
+    def test_rpc_profiling_remote_record_function(self):
+        # test that functions run over RPC with record_function show the expected
+        # profiled block.
+        if self.rank != 1:
+            return
+        dst_ranks = [i for i in range(self.world_size) if i != self.rank]
+        for dst_rank in dst_ranks:
+            dst_worker = worker_name(dst_rank)
+            with _profile() as prof:
+                fut = rpc.rpc_async(dst_worker, udf_with_torch_ops, args=(-1, True))
+                fut.wait()
+
+            function_events = prof.function_events
+            record_function_remote_event = [
+                evt for evt in function_events if "##forward##" in evt.name
+            ]
+            self.assertEqual(1, len(record_function_remote_event))
+            record_function_remote_event = record_function_remote_event[0]
+            self.assertEqual(record_function_remote_event.node_id, dst_rank)
+            # cpu_children only returns direct children, so here we get all
+            # children recursively.
+
+            def get_cpu_children(event):
+                if not event.cpu_children:
+                    return []
+                cpu_children = event.cpu_children
+                for e in event.cpu_children:
+                    cpu_children.extend(get_cpu_children(e))
+                return cpu_children
+
+            remote_children = get_cpu_children(record_function_remote_event)
+            # Get local children and verify parity.
+            with _profile() as prof:
+                udf_with_torch_ops(-1, True)
+
+            local_function_events = prof.function_events
+            local_record_function_event = next(
+                evt for evt in local_function_events if "##forward##" in evt.name
+            )
+            local_children = get_cpu_children(local_record_function_event)
+            local_children_names = [evt.name for evt in local_children]
+
+            REMOTE_OP_STR = "#remote_op: "
+
+            def convert_remote_to_local(event_name):
+                remote_op_key = REMOTE_OP_STR
+                return event_name[event_name.find(remote_op_key) + len(remote_op_key) :]
+
+            for evt in remote_children:
+                local_name = convert_remote_to_local(evt.name)
+                self.assertTrue(local_name in local_children_names)
+
+    def validate_profiling_workload(self, dst, prof):
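+        # Check that the remote aten::mul event was captured, ran on the
+        # expected node, and carries the full RPC profiling metadata.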
+        def convert_remote_to_local(event_name):
+            return event_name[event_name.find(REMOTE_OP_STR) + len(REMOTE_OP_STR) :]
+
+        events = prof.function_events
+        remote_events = {
+            convert_remote_to_local(event.name): event
+            for event in events
+            if event.is_remote
+        }
+        self.assertTrue("aten::mul" in remote_events)
+        remote_mul_event = remote_events["aten::mul"]
+        self.assertEqual(remote_mul_event.node_id, dst)
+        self.check_profiling_info(
+            worker_name(self.rank),
+            worker_name(dst),
+            torch.mul,
+            remote_mul_event,
+            RPCExecMode.ASYNC,
+        )
+
+    def _run_test_profiler_with_autograd_context(self):
+        dst = (self.rank + 1) % self.world_size
+        if self.rank == 1:
+            # Cases where we can double wrap messages with profiling information and autograd info.
+            with dist_autograd.context():
+                with _profile() as prof:
+                    self.run_profiling_workload(dst)
+
+            self.validate_profiling_workload(dst, prof)
+
+            # Ensure that flipped order of ctx managers results in events being
+            # recorded as expected.
+            with _profile() as prof:
+                with dist_autograd.context():
+                    self.run_profiling_workload(dst)
+
+            self.validate_profiling_workload(dst, prof)
+
+    @dist_init
+    def test_profiler_with_autograd_context_single_threaded(self):
+        self._run_test_profiler_with_autograd_context()
+
+    @dist_init
+    def test_profiler_with_autograd_context(self):
+        self._run_test_profiler_with_autograd_context()
+
+    def _profiler_test_with_rpc(
+        self,
+        rpc_exec_mode,
+        func,
+        args,
+        use_record_function=False,
+        dst=None,
+        kineto_profile=False,
+    ):
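+        # Generic profiler test driver: rank 1 profiles a SYNC/ASYNC/REMOTE RPC
+        # to dst (optionally wrapped in record_function and/or using the kineto
+        # profiler) and validates the recorded events.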
+        dst = dst if dst is not None else (self.rank + 1) % self.world_size
+
+        # only run profiler on rank 1.
+        p = _profile if not kineto_profile else torch.profiler.profile  # kineto
+        if self.rank == 1:
+            with p() as prof:
+                record_function_ctx_mgr = (
+                    contextlib.nullcontext()
+                    if not use_record_function
+                    else torch.autograd.profiler.record_function("foo")
+                )
+                with record_function_ctx_mgr:
+                    if rpc_exec_mode == RPCExecMode.SYNC:
+                        rpc.rpc_sync(worker_name(dst), func, args=args)
+                    elif rpc_exec_mode == RPCExecMode.ASYNC:
+                        fut = rpc.rpc_async(worker_name(dst), func, args=args)
+                        if kineto_profile:
+                            # Ensure multiple async RPCs don't cause issues.
+                            # This would have raised "RuntimeError: Cannot call
+                            # RemoteProfilerManager::setCurrentKey when current
+                            # key is already set." if RPC profiling had not been
+                            # disabled properly for kineto.
+                            fut2 = rpc.rpc_async(worker_name(dst), func, args=args)
+                            fut2.wait()
+                        fut.wait()
+                    else:
+                        self.assertTrue(rpc_exec_mode == RPCExecMode.REMOTE)
+                        rref = rpc.remote(worker_name(dst), func, args=args)
+                        rref.to_here()
+                        # To avoid flakiness, wait for the RRef to be profiled. This
+                        # means that we received the acknowledgement of successful
+                        # creation on the owner and ran the callbacks responsible
+                        # for recording the profiling event.
+                        rref._get_profiling_future().wait()
+
+            events = prof.function_events if not kineto_profile else prof.events()
+            if kineto_profile:
+                # RPC profiling is disabled so there should be no rpc related
+                # events.
+                with self.assertRaises(IndexError):
+                    get_function_event(events, rpc_exec_mode.value)
+
+                return
+
+            rpc_event = get_function_event(events, rpc_exec_mode.value)
+            # verify Node ID for this rpc event.
+            self.assertEqual(rpc_event.node_id, self.rank)
+            # Ensure recording of remote events.
+            remote_events = {event for event in events if event.node_id == dst} - {
+                rpc_event
+            }
+            self.assertGreaterEqual(len(remote_events), 1)
+            for remote_event in remote_events:
+                self.assertEqual(remote_event.node_id, dst)
+
+            if use_record_function:
+                scope_event = get_function_event(events, "foo")
+                # Since the RPC call is within the scope, its CPU interval should
+                # be contained within foo's interval.
+                self.assertLessEqual(
+                    scope_event.time_range.start, rpc_event.time_range.start
+                )
+                self.assertGreaterEqual(
+                    scope_event.time_range.end, rpc_event.time_range.end
+                )
+            # the sender, dest worker, function run, and type of RPC should all
+            # be recorded.
+            self_worker_name = worker_name(self.rank)
+            dst_worker_name = worker_name(dst)
+            self.check_profiling_info(
+                self_worker_name, dst_worker_name, func, rpc_event, rpc_exec_mode
+            )
+            if use_record_function:
+                # verify order by ensuring that the outer context comes
+                # before the rpc event.
+                foo_event_ix = next(
+                    i for i, event in enumerate(events) if "foo" in event.name
+                )
+                rpc_event_idx = next(
+                    i
+                    for i, event in enumerate(events)
+                    if rpc_exec_mode.value in event.name
+                )
+                self.assertLess(foo_event_ix, rpc_event_idx)
+
+    def _run_test_profiler_with_sync_rpc_udf(self):
+        self._profiler_test_with_rpc(RPCExecMode.SYNC, my_sleep_func, args=(1,))
+        self._profiler_test_with_rpc(
+            RPCExecMode.SYNC, my_sleep_func, args=(1,), use_record_function=True
+        )
+
+    @dist_init
+    def test_profiler_with_sync_rpc_udf(self):
+        self._run_test_profiler_with_sync_rpc_udf()
+
+    @dist_init
+    def test_profiler_with_sync_rpc_udf_single_threaded(self):
+        self._run_test_profiler_with_sync_rpc_udf()
+
+    def _run_test_profiler_with_sync_rpc_builtin(self):
+        self._profiler_test_with_rpc(
+            RPCExecMode.SYNC, torch.mul, args=(torch.ones(1), torch.ones(1))
+        )
+        self._profiler_test_with_rpc(
+            RPCExecMode.SYNC,
+            torch.mul,
+            args=(torch.ones(1), torch.ones(1)),
+            use_record_function=True,
+        )
+
+    @dist_init
+    def test_profiler_with_sync_rpc_builtin(self):
+        self._run_test_profiler_with_sync_rpc_builtin()
+
+    @dist_init
+    def test_profiler_with_sync_rpc_builtin_single_threaded(self):
+        self._run_test_profiler_with_sync_rpc_builtin()
+
+    def _run_test_profiler_with_async_rpc_udf(self):
+        self._profiler_test_with_rpc(RPCExecMode.ASYNC, my_sleep_func, args=(1,))
+        self._profiler_test_with_rpc(
+            RPCExecMode.ASYNC, my_sleep_func, args=(1,), use_record_function=True
+        )
+        # Test to ensure that enabling the kineto profiler around RPC calls does
+        # not turn on RPC profiling (which is unsupported) and does not cause
+        # issues.
+        self._profiler_test_with_rpc(
+            RPCExecMode.ASYNC, my_sleep_func, args=(1,), kineto_profile=True
+        )
+
+    @dist_init
+    def test_profiler_with_async_rpc_udf(self):
+        self._run_test_profiler_with_async_rpc_udf()
+
+    @dist_init
+    def test_profiler_with_async_rpc_udf_single_threaded(self):
+        self._run_test_profiler_with_async_rpc_udf()
+
+    def _run_test_profiler_with_async_rpc_builtin(self):
+        self._profiler_test_with_rpc(
+            RPCExecMode.ASYNC, torch.mul, args=(torch.ones(1), torch.ones(1))
+        )
+        self._profiler_test_with_rpc(
+            RPCExecMode.ASYNC,
+            torch.mul,
+            args=(torch.ones(1), torch.ones(1)),
+            use_record_function=True,
+        )
+
+    @dist_init
+    def test_profiler_with_async_rpc_builtin(self):
+        self._run_test_profiler_with_async_rpc_builtin()
+
+    @dist_init
+    def test_profiler_with_async_rpc_builtin_single_threaded(self):
+        self._run_test_profiler_with_async_rpc_builtin()
+
+    def _run_test_profiler_with_remote_udf(self):
+        self._profiler_test_with_rpc(RPCExecMode.REMOTE, my_sleep_func, args=(1,))
+        self._profiler_test_with_rpc(
+            RPCExecMode.REMOTE, my_sleep_func, args=(1,), use_record_function=True
+        )
+        # test remote to self
+        self._profiler_test_with_rpc(
+            RPCExecMode.REMOTE, my_sleep_func, args=(1,), dst=self.rank
+        )
+
+    @dist_init
+    def test_profiler_with_remote_udf(self):
+        self._run_test_profiler_with_remote_udf()
+
+    @dist_init
+    def test_profiler_with_remote_udf_single_threaded(self):
+        self._run_test_profiler_with_remote_udf()
+
+    def _run_test_profiler_with_remote_builtin(self):
+        self._profiler_test_with_rpc(
+            RPCExecMode.REMOTE, torch.mul, args=(torch.ones(1), torch.ones(1))
+        )
+        self._profiler_test_with_rpc(
+            RPCExecMode.REMOTE,
+            torch.mul,
+            args=(torch.ones(1), torch.ones(1)),
+            use_record_function=True,
+        )
+        # test remote to self
+        self._profiler_test_with_rpc(
+            RPCExecMode.REMOTE,
+            torch.mul,
+            args=(torch.ones(1), torch.ones(1)),
+            dst=self.rank,
+        )
+
+    @dist_init
+    def test_profiler_with_remote_builtin(self):
+        self._run_test_profiler_with_remote_builtin()
+
+    @dist_init
+    def test_profiler_with_remote_builtin_single_threaded(self):
+        self._run_test_profiler_with_remote_builtin()
+
+    def _run_test_profiler_with_script_async_rpc(self):
+        self._profiler_test_with_rpc(
+            RPCExecMode.ASYNC, my_script_func, args=(torch.tensor(1),)
+        )
+        self._profiler_test_with_rpc(
+            RPCExecMode.ASYNC,
+            my_script_func,
+            args=(torch.tensor(1),),
+            use_record_function=True,
+        )
+
+    @dist_init
+    def test_profiler_with_script_async_rpc(self):
+        self._run_test_profiler_with_script_async_rpc()
+
+    @dist_init
+    def test_profiler_with_script_async_rpc_single_threaded(self):
+        self._run_test_profiler_with_script_async_rpc()
+
+    def _run_test_profiler_with_script_sync_rpc(self):
+        self._profiler_test_with_rpc(
+            RPCExecMode.SYNC, my_script_func, args=(torch.tensor(1),)
+        )
+        self._profiler_test_with_rpc(
+            RPCExecMode.SYNC,
+            my_script_func,
+            args=(torch.tensor(1),),
+            use_record_function=True,
+        )
+
+    @dist_init
+    def test_profiler_with_script_sync_rpc(self):
+        self._run_test_profiler_with_script_sync_rpc()
+
+    @dist_init
+    def test_profiler_with_script_sync_rpc_single_threaded(self):
+        self._run_test_profiler_with_script_sync_rpc()
+
+    def _run_test_profiler_with_script_remote_rpc(self):
+        self._profiler_test_with_rpc(
+            RPCExecMode.REMOTE, my_script_func, args=(torch.tensor(1),)
+        )
+        self._profiler_test_with_rpc(
+            RPCExecMode.REMOTE,
+            my_script_func,
+            args=(torch.tensor(1),),
+            use_record_function=True,
+        )
+        # test remote to self
+        self._profiler_test_with_rpc(
+            RPCExecMode.REMOTE, my_script_func, args=(torch.tensor(1),), dst=self.rank
+        )
+
+    @dist_init
+    def test_profiler_with_script_remote_rpc(self):
+        self._run_test_profiler_with_script_remote_rpc()
+
+    @dist_init
+    def test_profiler_with_script_remote_rpc_single_threaded(self):
+        self._run_test_profiler_with_script_remote_rpc()
+
+    def _assert_top_level_events(
+        self, process_global_events, expected_top_level_event_names
+    ):
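+        # An event is considered top-level if it starts after the previous
+        # top-level event on the same thread has ended; collect those names and
+        # compare them against the expected set.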
+        top_level_event_names = []
+        for thread_local_events in process_global_events:
+            # Get top-level events from all events that happened on this thread.
+            last_end_time = 0
+            for event in thread_local_events:
+                event_name = event.name
+                time_range = event.time_range
+                if time_range.start > last_end_time:
+                    top_level_event_names.append(event_name)
+                    last_end_time = time_range.end
+        top_level_event_names = sorted(top_level_event_names)
+        expected_top_level_event_names = sorted(expected_top_level_event_names)
+        self.assertEqual(
+            top_level_event_names,
+            expected_top_level_event_names,
+            f"Expected events {expected_top_level_event_names}, but got {top_level_event_names}",
+        )
+
+    @dist_init
+    def test_server_process_global_profiler(self):
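+        # Enable two nested server-side (process-global) profilers on the
+        # destination via RRefs; the inner one should only see aten::sub while
+        # the outer one sees both aten::sub and aten::add.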
+        if self.rank != 0:
+            return
+
+        dst_rank = (self.rank + 1) % self.world_size
+        dst_worker_name = worker_name(dst_rank)
+
+        x = torch.tensor(1)
+        y = torch.tensor(2)
+
+        outer_profile_rref = rpc.remote(
+            dst_worker_name, rpc._server_process_global_profile
+        )
+        outer_profile_rref.rpc_sync().__enter__()
+        rpc.rpc_sync(dst_worker_name, torch.add, (x, y))
+        inner_profile_rref = rpc.remote(
+            dst_worker_name, rpc._server_process_global_profile
+        )
+        inner_profile_rref.rpc_sync().__enter__()
+        rpc.rpc_sync(dst_worker_name, torch.sub, (x, y))
+        inner_profile_rref.rpc_sync().__exit__(None, None, None)
+        outer_profile_rref.rpc_sync().__exit__(None, None, None)
+
+        inner_events = rpc.rpc_sync(
+            dst_worker_name, get_events_from_profile, (inner_profile_rref,)
+        )
+        expected_inner_events = ["aten::sub"]
+        expected_outer_events = expected_inner_events + ["aten::add"]
+
+        self._assert_top_level_events(inner_events, expected_inner_events)
+        outer_events = rpc.rpc_sync(
+            dst_worker_name, get_events_from_profile, (outer_profile_rref,)
+        )
+        self._assert_top_level_events(outer_events, expected_outer_events)
+
+        inner_profile_rref.rpc_sync().key_averages()
+        outer_profile_rref.rpc_sync().key_averages()
+
+    @dist_init
+    def test_async_record_function_double_end_callbacks(self):
+        num_sleep_seconds = 1
+        if self.rank == 1:
+            # Validate that calling the function twice results in an error.
+            with _profile():
+                with torch.autograd.profiler.record_function("foo") as rf:
+                    fut = rpc.rpc_async(
+                        worker_name(0), my_sleep_func, args=(num_sleep_seconds,)
+                    )
+                    rf._call_end_callbacks_on_future(fut)
+                    with self.assertRaisesRegex(
+                        RuntimeError, "can only be called once."
+                    ):
+                        rf._call_end_callbacks_on_future(fut)
+                fut.wait()
+
+    @dist_init
+    def test_async_record_function_legacy(self):
+        # Test the legacy _record_function ops work
+        # Note: These exist for backward compatibility with TorchScript
+        num_sleep_seconds = 1
+        if self.rank == 1:
+            with _profile():
+                try:
+                    handle = torch.ops.profiler._record_function_enter("foo", None)
+                    fut = rpc.rpc_async(
+                        worker_name(0), my_sleep_func, args=(num_sleep_seconds,)
+                    )
+                    torch.ops.profiler._call_end_callbacks_on_jit_fut(handle, fut)
+                finally:
+                    torch.ops.profiler._record_function_exit(handle)
+
+                fut.wait()
+
+    @dist_init
+    def test_async_record_function_cbs_jit_call(self):
+        if self.rank == 1:
+            with _profile() as pf:
+                key = _build_rpc_profiling_key(
+                    RPCExecMode.ASYNC,
+                    torch._jit_internal._qualified_name(my_script_func),
+                    "worker1",
+                    "worker0",
+                )
+                with torch.autograd.profiler.record_function(key) as rf:
+                    fut = rpc.rpc_async(
+                        worker_name(0), my_script_func, args=(torch.tensor(1),)
+                    )
+                    # Intentionally calling record_function internals
+                    fut = torch.ops.profiler._call_end_callbacks_on_jit_fut(
+                        rf.record, fut
+                    )
+                result = fut.wait()
+                # Validate that the profiling future returns the same value as the RPC
+                # future.
+                expected = torch.add(torch.tensor(1), torch.tensor(1))
+                self.assertEqual(result, expected)
+            events = pf.function_events
+            rpc_event = get_function_event(
+                events, torch._jit_internal._qualified_name(my_script_func)
+            )
+            self.assertTrue(
+                torch._jit_internal._qualified_name(my_script_func) in rpc_event.name
+            )
+
+    @dist_init
+    def test_py_class_constructor(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        ret = rpc.rpc_sync(worker_name(dst_rank), MyClass, args=(n,))
+        self.assertEqual(ret.a, n)
+
+    @dist_init
+    def test_py_class_instance_method(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        ret = rpc.rpc_sync(
+            worker_name(dst_rank), MyClass(2).my_instance_method, args=(n,)
+        )
+        self.assertEqual(ret, MyClass(2).my_instance_method(n))
+
+    @dist_init
+    def test_py_class_method(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        ret = rpc.rpc_sync(
+            worker_name(dst_rank), MyClass.my_class_method, args=(n, n + 1)
+        )
+        self.assertEqual(ret, MyClass.my_class_method(n, n + 1))
+
+    @dist_init
+    def test_py_class_static_method(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        ret = rpc.rpc_sync(
+            worker_name(dst_rank), MyClass.my_static_method, args=(n + 10,)
+        )
+        self.assertEqual(ret, MyClass.my_static_method(n + 10))
+
+    @dist_init
+    def test_py_multi_async_call(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        dst_worker_info = rpc.get_worker_info(worker_name(dst_rank))
+        fut1 = rpc.rpc_async(dst_worker_info, MyClass.my_static_method, args=(n + 10,))
+        fut2 = rpc.rpc_async(dst_worker_info, min, args=(n, n + 1, n + 2))
+        self.assertEqual(fut1.wait(), MyClass.my_static_method(n + 10))
+        self.assertEqual(fut2.wait(), min(n, n + 1, n + 2))
+
+    @dist_init
+    def test_py_no_return_result(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        ret = rpc.rpc_sync(worker_name(dst_rank), no_result)
+        self.assertEqual(ret, no_result())
+
+    @dist_init
+    def test_py_tensors(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        ret = rpc.rpc_sync(
+            worker_name(dst_rank),
+            my_tensor_function,
+            args=(torch.ones(n, n), torch.ones(n, n)),
+        )
+        self.assertEqual(ret, my_tensor_function(torch.ones(n, n), torch.ones(n, n)))
+
+    @dist_init
+    def test_py_tensors_multi_async_call(self):
+        futs = []
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        for i in range(100):
+            fut = rpc.rpc_async(
+                worker_name(dst_rank),
+                my_tensor_function,
+                args=(torch.ones(i, i), torch.ones(i, i)),
+            )
+            futs.append(fut)
+
+        for j, val in enumerate(torch.futures.wait_all(futs)):
+            self.assertEqual(
+                val, my_tensor_function(torch.ones(j, j), torch.ones(j, j))
+            )
+
+    @dist_init
+    def test_py_tensors_in_container(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        a = [torch.ones(n, n), torch.ones(n, n)]
+        b = TensorClass(build_complex_tensors())
+        c = {"foo": torch.ones(n, n), "bar": torch.ones(n, n)}
+        ret = rpc.rpc_sync(
+            worker_name(dst_rank), my_complex_tensor_function, args=(a, b, c)
+        )
+        self.assertEqual(ret, my_complex_tensor_function(a, b, c))
+
+    @dist_init
+    def test_py_nested_pickle(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+
+        ret = rpc.rpc_sync(
+            worker_name(dst_rank),
+            run_nested_pickle,
+            args=(MyPickleClass(), torch.ones(2, 2)),
+        )
+
+        m = MyPickleClass()
+        m.set(my_tensor_function(torch.ones(2, 2), torch.ones(2, 2)))
+        self.assertEqual(ret, run_nested_pickle(m, torch.ones(2, 2)))
+
+    @dist_init
+    def test_py_function_exception(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        with self.assertRaises(TypeError):
+            rpc.rpc_sync(worker_name(dst_rank), no_result, args=(10,))
+
+    @dist_init
+    def test_py_raise_in_user_func(self):
+        with captured_output() as (_, err):
+            # This barrier prevents a race condition where the main thread has
+            # not entered the context manager when the remote function runs.
+            initialize_pg(self.file_init_method, self.rank, self.world_size)
+            dist.barrier()
+            n = self.rank + 1
+            dst_rank = n % self.world_size
+            fut = rpc.rpc_async(worker_name(dst_rank), raise_func)
+            with self.assertRaisesRegex(ValueError, expected_err):
+                fut.wait()
+            # This barrier prevents a race condition where the main thread exits
+            # the context manager before the remote function has run.
+            dist.barrier()
+
+        # Validate that trainers log errors when running functions.
+        stderr_lines = err.getvalue()
+        self.assertTrue(expected_err in stderr_lines)
+
+    @dist_init
+    def test_py_raise_in_user_func_escaped_str(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        fut = rpc.rpc_async(worker_name(dst_rank), raise_func_escape)
+        try:
+            fut.wait()
+        except ValueError as e:
+            msg = str(e)
+            # Ensure newlines are unescaped to provide a better repr of error.
+            self.assertEqual(msg, msg.encode("utf-8").decode("unicode_escape"))
+        else:
+            self.assertTrue(False, "expected raise_func_escape to raise ValueError.")
+
+    @dist_init
+    def test_nested_rpc(self):
+        self._nested_rpc(nested_rpc, torch.ones(2, 2) + 1)
+
+    @dist_init
+    def test_stress_light_rpc(self):
+        self._stress_test_rpc(light_rpc)
+
+    @dist_init
+    def test_stress_heavy_rpc(self):
+        self._stress_test_rpc(heavy_rpc, repeat=20, args=(torch.ones(100, 100),))
+
+    @dist_init
+    def test_stress_heavy_rpc_torchscript(self):
+        self._stress_test_rpc(
+            heavy_rpc_torchscript, repeat=20, args=(torch.ones(100, 100),)
+        )
+
+    @dist_init
+    def test_builtin_remote_ret(self):
+        self._builtin_remote_ret(
+            torch.ones(2, 2), torch.ones(2, 2), torch.ones(2, 2) * 2
+        )
+
+    @dist_init
+    def test_builtin_remote_self(self):
+        self._builtin_remote_self(
+            torch.ones(2, 2), torch.ones(2, 2), torch.ones(2, 2) * 2
+        )
+
+    @staticmethod
+    def _multi_args_fn(n, sparse=False):
+        if sparse:
+            return (build_sparse_tensor(), build_sparse_tensor())
+        else:
+            return (torch.ones(n, n), torch.ones(n, n))
+
+    @dist_init
+    def test_multi_builtin_remote_ret(self):
+        self._test_multi_remote_call(torch.add, False, args_fn=RpcTest._multi_args_fn)
+
+    @dist_init
+    def test_py_udf_remote(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        rref = rpc.remote(
+            worker_name(dst_rank),
+            my_function,
+            kwargs={"a": n, "b": n + 1, "c": n + 2},
+        )
+        self.assertEqual(rref.to_here(), my_function(n, n + 1, n + 2))
+
+    @staticmethod
+    def _multi_kwargs_fn(n, sparse=False):
+        if sparse:
+            return {
+                "a": build_sparse_tensor(),
+                "b": build_sparse_tensor(),
+                "c": build_sparse_tensor(),
+            }
+        else:
+            return {"a": torch.ones(n, n), "b": torch.ones(n, n), "c": torch.ones(n, n)}
+
+    @dist_init
+    def test_multi_py_udf_remote(self):
+        self._test_multi_remote_call(
+            my_function, False, kwargs_fn=RpcTest._multi_kwargs_fn
+        )
+
+    @dist_init
+    def test_py_rref_args(self):
+        self._py_rref_args(
+            torch.ones(2, 2), 1, torch.ones(2, 2), 2, torch.ones(2, 2) * 2 + 3
+        )
+
+    @dist_init
+    def test_py_rref_args_user_share(self):
+        self._py_rref_args_user_share(
+            torch.ones(2, 2), 1, 2, torch.ones(2, 2), 3, 4, torch.ones(2, 2) * 2 + 10
+        )
+
+    @dist_init
+    def test_py_rpc_rref_args(self):
+        self._py_rpc_rref_args(
+            torch.ones(2, 2), 1, 2, torch.ones(2, 2), 3, 4, torch.ones(2, 2) * 2 + 10
+        )
+
+    @dist_init
+    def test_nested_remote(self):
+        self._nested_remote(nested_remote, torch.ones(2, 2) + 3)
+
+    @dist_init
+    def test_nested_rref(self):
+        self._nested_rref(nested_rref, torch.ones(2, 2) + 1, torch.ones(2, 2) + 2)
+
+    @dist_init
+    def test_nested_rref_stress(self):
+        self._nested_rref_stress(
+            nested_rref, torch.ones(2, 2) + 1, torch.ones(2, 2) + 2
+        )
+
+    @dist_init
+    def test_multi_layer_nested_async_rpc(self):
+        # This test will exit right away, but there will be a chain of async
+        # RPCs. The termination algorithm should detect those messages properly.
+        # Otherwise, some peer could exit early, leaving others to hit timeout
+        # or connection-closed errors.
+        ttl = 20
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+
+        multi_layer_nested_async_rpc(dst_rank, self.world_size, ttl)
+
+    @dist_init
+    def test_remote_with_exception(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        # check ref to other workers
+        rref = rpc.remote(worker_name(dst_rank), raise_func)
+        with self.assertRaises(ValueError):
+            rref.to_here()
+        # check ref to itself
+        rref = rpc.remote(worker_name(self.rank), no_result, args=(10,))
+        with self.assertRaises(TypeError):
+            rref.to_here()
+
+    @dist_init
+    def test_rpc_return_rref(self):
+        n = self.rank + 1
+        dst_rank1 = n % self.world_size
+        dst_rank2 = (n + 1) % self.world_size
+        rref = rpc.rpc_sync(
+            worker_name(dst_rank1),
+            rpc_return_rref,
+            args=(worker_name(dst_rank2),),
+        )
+        self.assertEqual(rref.to_here(), torch.ones(2, 2) + 1)
+
+    @dist_init
+    def test_rref_forward_chain(self):
+        ttl = 8
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+
+        rref = rpc.remote(worker_name(dst_rank), torch.add, args=(torch.ones(n, n), 1))
+
+        ret_rref = rref_forward_chain(dst_rank, self.world_size, rref, ttl)
+
+        for _ in range(ttl):
+            self.assertEqual(len(ret_rref), 1)
+            ret_rref = ret_rref[0].to_here()
+
+        ret = ret_rref
+        self.assertEqual(ret, torch.add(torch.ones(n, n), 1))
+
+    @dist_init
+    def test_local_rref_no_fork(self):
+        local_rref = RRef(35)
+        self.assertEqual(local_rref.local_value(), 35)
+
+    @dist_init
+    def test_local_value_not_on_owner(self):
+        # Ensure that an error is raised if a user tries to call
+        # local_value() on a non-owning node.
+        next_rank = (self.rank + 1) % self.world_size
+        rref = rpc.remote(
+            worker_name(next_rank), torch.add, args=(torch.ones(1), torch.ones(1))
+        )
+        with self.assertRaisesRegex(
+            RuntimeError,
+            (
+                rf"For UserRRef\(rref_id=GloballyUniqueId\(created_on={self.rank}, local_id=0\), "
+                rf"fork_id=GloballyUniqueId\(created_on={self.rank}, local_id=1\)\), "
+                r"can't call localValue\(\) on user "
+                rf"WorkerInfo\(id={self.rank}, name={worker_name(self.rank)}\). "
+                rf"Call it on owner WorkerInfo\(id={next_rank}, name={worker_name(next_rank)}\)"
+            ),
+        ):
+            rref.local_value()
+
+    @dist_init
+    def test_return_local_rrefs(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+
+        rref_list = rpc.rpc_sync(
+            worker_name(dst_rank), get_rref_list, args=([1, 2, 3],)
+        )
+
+        for rref in rref_list:
+            rpc.rpc_sync(
+                rref.owner(),
+                _call_method_on_rref,
+                args=(MyClass.increment_value, rref, 10),
+            )
+
+        rets = [
+            rpc.rpc_sync(
+                rref.owner(), _call_method_on_rref, args=(MyClass.get_value, rref)
+            )
+            for rref in rref_list
+        ]
+
+        self.assertEqual(rets, [11, 12, 13])
+
+    @dist_init
+    def _test_rref_type(self, blocking):
+        def launched_rpc(events):
+            expected_name = f"rpc_{RPCExecMode.ASYNC.value}#_rref_typeof_on_owner"
+            return any(e.name.startswith(expected_name) for e in events)
+
+        dst = worker_name((self.rank + 1) % self.world_size)
+        rref = rpc.remote(dst, torch.add, args=(torch.ones(2), 1))
+
+        with _profile() as p:
+            t = rref._get_type(blocking=blocking)
+            if not blocking:
+                t = t.wait()
+
+        self.assertTrue(launched_rpc(p.function_events))
+        expected_type = type(torch.ones(2))
+        self.assertEqual(t, expected_type)
+
+        futs = []
+
+        def verify(fut):
+            self.assertEqual(fut.value(), expected_type)
+
+        with _profile() as p:
+            for _ in range(10):
+                t = rref._get_type(blocking=blocking)
+                if not blocking:
+                    futs.append(t)
+                    t.add_done_callback(verify)
+                    t = t.wait()
+                self.assertEqual(t, expected_type)
+
+        if not blocking:
+            # Note that cached calls with blocking=False all return the same
+            # cached original future.
+            first_fut = futs[0]
+            for f in futs[1:]:
+                self.assertTrue(f is first_fut)
+        # Ensure we never launch another RPC, other than for the very
+        # first call.
+        self.assertFalse(launched_rpc(p.function_events))
+        self.assertEqual(t, type(torch.ones(2)))
+
+        rref = rpc.remote(dst, MyClass, args=(0,))
+        rref_type = rref._get_type(blocking=blocking)
+        if not blocking:
+            rref_type = rref_type.wait()
+        self.assertEqual(rref_type, MyClass)
+
+    def test_rref_type_blocking(self):
+        self._test_rref_type(blocking=True)
+
+    def test_rref_type_non_blocking(self):
+        self._test_rref_type(blocking=False)
+
+    @dist_init
+    def _test_rref_type_with_error(self, blocking):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        # raise_func raises an error on the owner during RRef creation.
+        rref = rpc.remote(dst, raise_func)
+        # Blocking: error raised inline
+        if blocking:
+            with self.assertRaisesRegex(ValueError, "Expected error"):
+                rref._get_type(blocking=blocking)
+        else:
+            # Non-blocking: Immediately return future, block on wait
+            fut = rref._get_type(blocking=blocking)
+            with self.assertRaisesRegex(ValueError, "Expected error"):
+                fut.wait()
+
+    def test_rref_type_with_error_blocking(self):
+        self._test_rref_type_with_error(blocking=True)
+
+    def test_rref_type_with_error_non_blocking(self):
+        self._test_rref_type_with_error(blocking=False)
+
+    @dist_init
+    def _test_rref_type_owner(self, blocking):
+        rref = RRef(torch.ones(2) + 1)
+        rref_type = rref._get_type(blocking=blocking)
+        if not blocking:
+            rref_type = rref_type.wait()
+        self.assertEqual(rref_type, type(torch.ones(2)))
+
+        rref = RRef(MyClass(0))
+        rref_type = rref._get_type(blocking=blocking)
+        if not blocking:
+            rref_type = rref_type.wait()
+        self.assertEqual(rref_type, MyClass)
+
+    def test_rref_type_owner_blocking(self):
+        self._test_rref_type_owner(blocking=True)
+
+    def test_rref_type_owner_non_blocking(self):
+        self._test_rref_type_owner(blocking=False)
+
+    @staticmethod
+    def _slow_add(x, y):
+        time.sleep(1)
+        return x + y
+
+    @dist_init
+    def test_rref_type_slow_init(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        rref = rpc.remote(dst, RpcTest._slow_add, args=(torch.ones(2), 1))
+        self.assertEqual(rref._get_type(), type(torch.ones(2)))
+
+    @dist_init
+    def test_owner_equality(self):
+        a = RRef(40)
+        b = RRef(50)
+
+        other_rank = (self.rank + 1) % self.world_size
+        other_a = rpc.remote(
+            worker_name(other_rank), torch.add, args=(torch.ones(1), 1)
+        )
+        other_b = rpc.remote(
+            worker_name(other_rank), torch.add, args=(torch.ones(1), 1)
+        )
+        other_a.to_here()  # to ensure clean termination
+        other_b.to_here()
+
+        self.assertNotEqual(a.owner(), 23)
+        self.assertEqual(other_a.owner(), other_b.owner())
+        self.assertNotEqual(a.owner(), other_a.owner())
+        self.assertEqual(other_a.owner(), other_a.owner())
+        self.assertEqual(other_a.owner(), other_b.owner())
+        self.assertEqual(a.owner(), a.owner())
+        self.assertEqual(a.owner(), b.owner())
+        self.assertEqual(a.owner(), rpc.get_worker_info())
+        x = {}
+        x[a.owner()] = a
+        x[other_a.owner()] = other_a
+        self.assertEqual(x[a.owner()], a)
+        self.assertEqual(x[b.owner()], a)
+        self.assertEqual(x[other_a.owner()], other_a)
+        self.assertEqual(x[other_b.owner()], other_a)
+        self.assertEqual(len(x), 2)
+
+    @dist_init
+    def test_pass_local_rrefs(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        dst_worker = worker_name(dst_rank)
+
+        rref = RRef(40)
+        self.assertEqual(
+            rpc.rpc_sync(dst_worker, add_rref_to_value, args=(rref, 50)), 90
+        )
+        self.assertEqual(
+            rpc.rpc_async(dst_worker, add_rref_to_value, args=(rref, 50)).wait(), 90
+        )
+        self.assertEqual(
+            rpc.remote(dst_worker, add_rref_to_value, args=(rref, 50)).to_here(), 90
+        )
+
+    @dist_init
+    def test_remote_same_worker(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        rref_a = rpc.remote(
+            worker_name(dst_rank), torch.add, args=(torch.ones(n, n), 2)
+        )
+        rref_b = rpc.remote(
+            worker_name(dst_rank), torch.add, args=(torch.ones(n, n), 1)
+        )
+        rref_c = rpc.remote(
+            worker_name(dst_rank), my_rref_function, args=(rref_a, rref_b)
+        )
+        self.assertEqual(rref_c.to_here(), torch.ones(n, n) + 4)
+
+    @dist_init(setup_rpc=True)
+    def test_call_method_on_rref(self):
+        """
+        Tests that it is possible to call an instance method on a remote object
+        by using rref.owner() as the destination of the call.
+        """
+        vals = [10, 2, 5, 7]
+        dst_rank = (self.rank + 1) % self.world_size
+        dst_worker = worker_name(dst_rank)
+
+        # creates a remote object
+        rref = rpc.remote(dst_worker, MyClass, args=(vals[0],))
+
+        # modifies state of the remote object
+        rpc.rpc_sync(
+            rref.owner(),
+            _call_method_on_rref,
+            args=(MyClass.increment_value, rref, vals[1]),
+        )
+        rpc.rpc_async(
+            rref.owner(),
+            _call_method_on_rref,
+            args=(MyClass.increment_value, rref, vals[2]),
+        ).wait()
+        rpc.remote(
+            rref.owner(),
+            _call_method_on_rref,
+            args=(MyClass.increment_value, rref, vals[3]),
+        ).to_here()
+
+        # queries state of the remote object
+        result = rpc.rpc_sync(
+            dst_worker, _call_method_on_rref, args=(MyClass.get_value, rref)
+        )
+
+        self.assertEqual(result, sum(vals))
+
+    # Notice `rpc.api.shutdown()` accesses
+    # `_delete_all_user_and_unforked_owner_rrefs` through
+    # `torch.distributed.rpc.api`, so patching
+    # `torch.distributed.rpc._delete_all_user_and_unforked_owner_rrefs` will
+    # not help.
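+    # A minimal illustration of the patch-where-it-is-looked-up rule (the
+    # `pkg` module names below are hypothetical, not part of this suite):
+    #
+    #     # pkg/api.py does: from pkg.utils import helper
+    #     mock.patch("pkg.utils.helper")  # does NOT affect the name pkg.api uses
+    #     mock.patch("pkg.api.helper")    # patches the name pkg.api looks up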
+    @mock.patch.object(
+        torch.distributed.rpc.api, "_delete_all_user_and_unforked_owner_rrefs"
+    )
+    def _test_rref_leak(
+        self, _mock_delete_all_user_and_unforked_owner_rrefs, ignore_leak
+    ):
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        # Wait for all init to complete.
+        dist.barrier()
+
+        rref = rpc.remote(  # noqa: F841
+            worker_name((self.rank + 1) % self.world_size),
+            torch.add,
+            args=(torch.ones(2, 2), 1),
+        )
+
+        import torch.distributed.rpc.api as api
+
+        if ignore_leak:
+            api._ignore_rref_leak = True
+            rpc.shutdown(graceful=True)
+        else:
+            api._ignore_rref_leak = False
+            with self.assertRaisesRegex(RuntimeError, "Leaking RRef"):
+                rpc.shutdown(graceful=True)
+
+    @dist_init(setup_rpc=False)
+    def test_rref_leak(self):
+        self._test_rref_leak(ignore_leak=False)
+
+    @dist_init(setup_rpc=False)
+    def test_ignore_rref_leak(self):
+        self._test_rref_leak(ignore_leak=True)
+
+    @dist_init
+    def test_rref_str(self):
+        rref1 = RRef(self.rank)
+        id_class = "GloballyUniqueId"
+        self.assertEqual(
+            f"OwnerRRef({id_class}(created_on={self.rank}, local_id=0))",
+            rref1.__str__(),
+        )
+
+        dst_rank = (self.rank + 1) % self.world_size
+        rref2 = rpc.remote(worker_name(dst_rank), torch.add, args=(torch.ones(2, 2), 1))
+        self.assertEqual(
+            rref2.__str__(),
+            f"UserRRef(RRefId = {id_class}(created_on={self.rank}, local_id=1), "
+            f"ForkId = {id_class}(created_on={self.rank}, local_id=2))",
+        )
+
+    @dist_init
+    def test_rref_get_future(self):
+        # Tests that we can obtain the future corresponding to the creation of
+        # the RRef on the remote end
+        if self.rank == 0:
+            # Builtin
+            rref = rpc.remote(worker_name(1), torch.add, args=(1, 1))
+            rref.to_here()
+            fut = rref._get_future()
+            self.assertIsInstance(fut, torch._C.Future)
+
+            # UDF
+            rref = rpc.remote(worker_name(1), foo_add, args=())
+            rref.to_here()
+            fut = rref._get_future()
+            self.assertIsInstance(fut, torch._C.Future)
+
+            # Script
+            rref = rpc.remote(worker_name(1), my_script_func, args=(torch.tensor(1),))
+            rref.to_here()
+            fut = rref._get_future()
+            self.assertIsInstance(fut, torch._C.Future)
+
+    @dist_init
+    def test_rref_context_debug_info(self):
+        # This test checks local states that are modified by remote workers,
+        # which means we need a barrier before and after every check. The
+        # barrier before a check makes sure that all previous states have been
+        # cleared globally; the barrier after ensures that no subsequent state
+        # change leaks into the current check.
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        # Check 1: local RRef does not update owners_ map or add a pending user.
+        #################################################
+
+        rref1 = RRef(self.rank)
+
+        # don't need a barrier here as local RRef is handled by this thread
+        info = _rref_context_get_debug_info()
+        self.assertIn("num_owner_rrefs", info)
+        self.assertIn("num_pending_users", info)
+        # RRef on local value is not added to context until shared across RPC
+        self.assertEqual(0, int(info["num_owner_rrefs"]))
+        self.assertEqual(0, int(info["num_pending_users"]))
+        # barrier after the check 1
+        dist.barrier()
+
+        # Check 2: Sharing RRef as an arg should update owners_ map
+        ###########################################################
+
+        dst_rank = (self.rank + 1) % self.world_size
+        rpc.rpc_sync(worker_name(dst_rank), set_global_rref, args=(rref1,))
+
+        # barrier before check 2
+        wait_until_pending_futures_and_users_flushed()
+        dist.barrier()
+
+        info = _rref_context_get_debug_info()
+        self.assertIn("num_owner_rrefs", info)
+        self.assertEqual(1, int(info["num_owner_rrefs"]))
+        # no pending users since the fork is finished
+        self.assertEqual(0, int(info["num_pending_users"]))
+        # barrier after check 2
+        dist.barrier()
+
+        # clear states for check 2
+        rpc.rpc_sync(worker_name(dst_rank), clear_global_rref)
+
+        # Wait for owner rref to be cleared.
+        while int(info["num_owner_rrefs"]) != 0:
+            info = _rref_context_get_debug_info()
+            time.sleep(0.1)
+        dist.barrier()
+
+        # Check 3: rpc.remote call should update owners_ map
+        ####################################################
+        rref2 = rpc.remote(worker_name(dst_rank), torch.add, args=(torch.ones(2, 2), 1))
+        rref3 = rpc.remote(worker_name(dst_rank), torch.add, args=(torch.ones(2, 2), 1))
+        rref2.to_here()
+        rref3.to_here()
+
+        # barrier before check 3
+        wait_until_pending_futures_and_users_flushed()
+        dist.barrier()
+
+        info = _rref_context_get_debug_info()
+        self.assertIn("num_owner_rrefs", info)
+        self.assertEqual(2, int(info["num_owner_rrefs"]))
+        # no pending users since the fork is finished
+        self.assertEqual(0, int(info["num_pending_users"]))
+
+        # barrier after check 3
+        dist.barrier()
+
+    @dist_init
+    def test_disable_gil_profiling(self):
+        # Test that GIL wait time is not recorded when GIL profiling is
+        # disabled, i.e. before rpc.enable_gil_profiling(True) is called.
+
+        # GIL profiling should be disabled by default.
+        dst_rank = (self.rank + 1) % self.world_size
+        rpc.rpc_sync(
+            worker_name(dst_rank), torch.add, args=(torch.ones(1), torch.ones(1))
+        )
+        info = rpc.api._get_current_rpc_agent().get_debug_info()
+        self.assertRaises(KeyError, lambda: info["agent.gil_average_wait_time_us"])
+        rpc.enable_gil_profiling(True)
+        rpc.rpc_sync(
+            worker_name(dst_rank), torch.add, args=(torch.ones(1), torch.ones(1))
+        )
+        info = rpc.api._get_current_rpc_agent().get_debug_info()
+        self.assertIn("agent.gil_average_wait_time_us", info)
+
+    @dist_init(setup_rpc=False)
+    def test_local_shutdown(self):
+        # test that we can start RPC and then immediately locally shutdown
+        # without sending any messages.
+        rpc.init_rpc(
+            name=f"worker{self.rank:d}",
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+        # pass in graceful=False to ensure that we don't wait for other workers.
+        rpc.shutdown(graceful=False)
+
+    @dist_init
+    def test_debug_info(self):
+        # only test keys in this test case. Values should be covered by
+        # individual module debug info tests
+        import torch.distributed.autograd as dist_autograd
+
+        info = _get_debug_info()
+        rref_info = _rref_context_get_debug_info()
+        agent_info = rpc.api._get_current_rpc_agent().get_debug_info()
+        autograd_info = dist_autograd._get_debug_info()
+        common_keys = rref_info.keys() & agent_info.keys() & autograd_info.keys()
+        self.assertEqual(0, len(common_keys))
+        expected = {}
+        expected.update(rref_info)
+        expected.update(agent_info)
+        expected.update(autograd_info)
+        # NB: dict key ordering is only guaranteed from Python 3.7, so instead
+        # of comparing the dicts directly, we manually check that the key sets
+        # are equal.
+        for key in expected.keys():
+            self.assertIn(key, info.keys())
+
+        for key in info.keys():
+            self.assertIn(key, expected.keys())
+
+    @dist_init(setup_rpc=False)
+    @skip_but_pass_in_sandcastle_if(
+        IS_MACOS,
+        "Test is flaky on MacOS since libuv error handling is not as robust as TCP",
+    )
+    def test_handle_send_exceptions(self):
+        # test that if a callee node has gone down, we raise an appropriate
+        # exception instead of just crashing.
+        rpc.init_rpc(
+            name=f"worker{self.rank:d}",
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+        rpc._set_rpc_timeout(10)
+        # This barrier is needed to ensure that some workers do not exit before
+        # others have been brought up.
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        dist.barrier()
+        if self.rank == 1:
+            dst_rank = (self.rank + 1) % self.world_size
+            dst_worker = worker_name(dst_rank)
+            # allow destination worker to exit without joining
+            error_str = self.get_shutdown_error_regex()
+            wait_until_node_failure(dst_rank, error_str)
+            fut = rpc.rpc_async(dst_worker, torch.add, args=(torch.ones(1), 3))
+            # The shutdown sequence is not well defined, so we may see any of
+            # the error messages defined in get_shutdown_error_regex.
+            with self.assertRaisesRegex(RuntimeError, error_str):
+                fut.wait()
+        # exit all workers non-gracefully.
+        rpc.shutdown(graceful=False)
+
+    @dist_init
+    def test_deadlock(self):
+        # this test is copied from https://github.com/pytorch/pytorch/issues/45089
+        if self.rank == 1:
+            dst1 = worker_name((self.rank + 1) % self.world_size)
+            x = torch.ones(2)
+            y = torch.ones(2)
+            rpc.rpc_async(dst1, RpcTest._slow_add, args=(x, y), timeout=15).wait()
+
+        dist_initialized = dist.is_initialized()
+        if not dist_initialized:
+            dist.init_process_group(
+                backend="gloo",
+                init_method=self.file_init_method,
+                rank=self.rank,
+                world_size=self.world_size,
+            )
+
+    @dist_init(setup_rpc=False)
+    def test_local_shutdown_with_rpc(self):
+        # test that we can start RPC, send RPCs, and then run local shutdown.
+        rpc.init_rpc(
+            name=f"worker{self.rank:d}",
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        rpc.rpc_sync(
+            worker_name(dst_rank),
+            torch.add,
+            args=(torch.ones(n, n), torch.ones(n, n)),
+        )
+        # A barrier is needed to ensure that all RPCs are processed.
+        # Otherwise, some RPCs can time out because the receiving end
+        # has already terminated.
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        dist.barrier()
+        # pass in graceful=False to ensure that we don't wait for other workers.
+        rpc.shutdown(graceful=False)
+
+    @dist_init(setup_rpc=False)
+    def test_set_and_get_default_rpc_timeout(self):
+        timeout = 0.5
+
+        # A new `RpcBackendOptions` is constructed
+        # when accessing `self.rpc_backend_options`.
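+        # Hence, capture one instance locally, set its timeout, and pass that
+        # same instance to init_rpc below.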
+        rpc_backend_options = self.rpc_backend_options
+        rpc_backend_options.rpc_timeout = timeout
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=rpc_backend_options,
+        )
+        set_timeout = rpc.get_rpc_timeout()
+        self.assertEqual(timeout, set_timeout)
+        rpc.shutdown()
+
+    @dist_init
+    def test_default_timeout_used(self):
+        """
+        Tests that if no timeout is passed into rpc_async and rpc_sync, then the
+        default timeout is used.
+        """
+        dst_rank = (self.rank + 1) % self.world_size
+        rpc._set_rpc_timeout(0.001)  # 1 ms
+        # Futures should time out and be marked with a timeout exception.
+        futs = [
+            rpc.rpc_async(worker_name(dst_rank), my_sleep_func, args=())
+            for _ in range(10)
+        ]
+        expected_error = self.get_timeout_error_regex()
+        for fut in futs:
+            with self.assertRaisesRegex(RuntimeError, expected_error):
+                fut.wait()
+
+        # ensure that if a new timeout is set old futures don't time out but new ones do.
+        rpc._set_rpc_timeout(200)  # 200 seconds
+        # create a longstanding RPC.
+        fut1 = rpc.rpc_async(worker_name(dst_rank), my_sleep_func, args=(1,))
+        # now, set a short timeout.
+        rpc._set_rpc_timeout(0.001)
+        # fut2 should time out, fut1 should not.
+        fut2 = rpc.rpc_async(worker_name(dst_rank), my_sleep_func, args=(1,))
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            fut2.wait()
+        fut1.wait()
+
+        # Zero timeout means infinity, so future should run to completion.
+        rpc._set_rpc_timeout(0)
+        rpc.rpc_async(worker_name(dst_rank), my_sleep_func, args=()).wait()
+
+        # reset to default timeout so shutdown messages can process cleanly.
+        rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)
+
+    @dist_init
+    def test_rpc_timeouts(self):
+        # TODO: enable timeouts for rpc.remote/RRef (https://github.com/pytorch/pytorch/issues/33803)
+        dst_rank = (self.rank + 1) % self.world_size
+        dst_worker = worker_name(dst_rank)
+        timeout = 0.1  # 100 ms
+        expected_error = self.get_timeout_error_regex()
+        # Test async UDF
+        fut = rpc.rpc_async(dst_worker, my_sleep_func, args=(1,), timeout=timeout)
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            fut.wait()
+
+        # Ensure the RPC runs to completion when no timeout is passed and the
+        # default RPC timeout is used.
+        rpc.rpc_async(dst_worker, my_sleep_func, args=(1,)).wait()
+
+        # Test sync UDF
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            rpc.rpc_sync(dst_worker, my_sleep_func, args=(1,), timeout=timeout)
+
+        # Ensure the RPC runs to completion when no timeout is passed and the
+        # default RPC timeout is used.
+        rpc.rpc_sync(dst_worker, my_sleep_func, args=(1,))
+
+        # If we set a default timeout for RPCs, it should be respected, though
+        # still overridden if we pass in a different timeout to the APIs.
+        rpc._set_rpc_timeout(0.001)
+        fut = rpc.rpc_async(dst_worker, my_sleep_func, args=(1,))
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            fut.wait()
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            rpc.rpc_sync(dst_worker, my_sleep_func, args=(1,))
+
+        # The RPCs should run to completion since we override the timeout.
+        rpc.rpc_async(dst_worker, my_sleep_func, args=(1,), timeout=5).wait()
+        rpc.rpc_sync(dst_worker, my_sleep_func, args=(1,), timeout=5)
+        # Passing in a zero timeout should ensure that the RPC won't time out.
+        rpc.rpc_async(dst_worker, my_sleep_func, args=(1,), timeout=0).wait()
+        rpc.rpc_sync(dst_worker, my_sleep_func, args=(1,), timeout=0)
+        # Reset for clean shutdown
+        rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)
+
+    def test_dist_init_decorator(self):
+        @dist_init(setup_rpc=False)
+        def test_func(self):
+            return "expected result"
+
+        self.assertEqual(test_func(self), "expected result")
+
+        @dist_init
+        def test_func(self):
+            return "expected result"
+
+        self.assertEqual(test_func(self), "expected result")
+
+    def test_use_rpc_pickler(self):
+        class TestPickler:
+            pass
+
+        test_pickler = TestPickler()
+        with _use_rpc_pickler(test_pickler):
+            self.assertTrue(torch.distributed.rpc.api._default_pickler is test_pickler)
+        self.assertTrue(
+            torch.distributed.rpc.api._default_pickler is _internal_rpc_pickler
+        )
+
+    @dist_init
+    def test_wait_all(self):
+        with _wait_all():
+            self.assertTrue(_thread_local_var.future_list == [])
+            dst = worker_name((self.rank + 1) % self.world_size)
+            fut = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1))
+            self.assertTrue(len(_thread_local_var.future_list) == 1)
+            self.assertTrue(
+                isinstance(_thread_local_var.future_list[0], torch._C.Future)
+            )
+        self.assertTrue(fut.done())
+        self.assertEqual(fut.wait(), torch.ones(2, 2) + 1)
+        self.assertFalse(hasattr(_thread_local_var, "future_list"))
+
+    @dist_init
+    def test_wait_all_multiple_call(self):
+        with _wait_all():
+            self.assertTrue(_thread_local_var.future_list == [])
+            dst = worker_name((self.rank + 1) % self.world_size)
+            for i in range(20):
+                fut = rpc.rpc_async(dst, torch.add, (torch.ones(i, i), 1))
+                res = rpc.rpc_sync(dst, torch.add, (torch.ones(i, i), 1))
+                self.assertEqual(res, torch.ones(i, i) + 1)
+                self.assertEqual(fut.wait(), torch.ones(i, i) + 1)
+            self.assertTrue(len(_thread_local_var.future_list) == 20)
+        self.assertFalse(hasattr(_thread_local_var, "future_list"))
+
+    @dist_init
+    def test_wait_all_timeout(self):
+        expected_error = self.get_timeout_error_regex()
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            with _wait_all():
+                self.assertTrue(_thread_local_var.future_list == [])
+                dst = worker_name((self.rank + 1) % self.world_size)
+                timeout = 0.1  # 100 ms
+                rpc.rpc_async(dst, my_sleep_func, args=(1,), timeout=timeout)
+        self.assertFalse(hasattr(_thread_local_var, "future_list"))
+
+    @dist_init
+    def test_wait_all_raise_in_user_func(self):
+        with self.assertRaises(ValueError):
+            with _wait_all():
+                self.assertTrue(_thread_local_var.future_list == [])
+                dst = worker_name((self.rank + 1) % self.world_size)
+                rpc.rpc_async(dst, raise_func)
+        self.assertFalse(hasattr(_thread_local_var, "future_list"))
+
+    @dist_init
+    def test_wait_all_raise_in_body(self):
+        with self.assertRaises(ValueError):
+            with _wait_all():
+                raise_func()
+        self.assertFalse(hasattr(_thread_local_var, "future_list"))
+
+    @dist_init
+    def test_custom_exception_throw_during_reconstruction(self):
+        """
+        Test that we still surface info about the remote-side exception even
+        when we cannot recreate it on the client side.
+        """
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        if self.rank != 0:
+            exc_caught = False
+            dst = worker_name(0)
+            try:
+                rpc.rpc_sync(dst, custom_raise_func, args=())
+            except RuntimeError as e:
+                exc_caught = True
+                msg = str(e)
+                print(f"Got msg {msg}")
+                self.assertTrue("Original exception on remote side was" in msg)
+                self.assertTrue("CustomException" in msg)
+            except BaseException as e:
+                raise RuntimeError(f"Failure - expected RuntimeError, got {e}") from e
+            finally:
+                self.assertTrue(exc_caught)
+
+        dist.barrier()
+
+    timed_out_rpc_event = None
+
+    @staticmethod
+    def timed_out_rpc():
+        RpcTest.timed_out_rpc_event.wait()
+
+    @dist_init
+    def test_wait_all_exit_early_python(self):
+        # Initialize the event in the subprocess.
+        RpcTest.timed_out_rpc_event = Event()
+
+        # Wait for all processes to initialize event.
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        dist.barrier()
+
+        dst = worker_name((self.rank + 1) % self.world_size)
+        fut1 = rpc.rpc_async(dst, RpcTest.timed_out_rpc)
+        fut2 = rpc.rpc_async(dst, raise_func)
+        fut3 = rpc.rpc_async(dst, raise_func)
+
+        # We should receive the error from fut2
+        with self.assertRaisesRegex(ValueError, expected_err):
+            torch.futures.wait_all([fut1, fut2, fut3])
+
+        # Unblock RPC thread for fut1
+        RpcTest.timed_out_rpc_event.set()
+
+    @dist_init
+    def test_wait_all_exit_early_builtin(self):
+        # Initialize the event in the subprocess.
+        RpcTest.timed_out_rpc_event = Event()
+
+        # Wait for all processes to initialize event.
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        dist.barrier()
+
+        dst = worker_name((self.rank + 1) % self.world_size)
+        fut1 = rpc.rpc_async(dst, RpcTest.timed_out_rpc)
+        fut2 = rpc.rpc_async(dst, torch.add, args=(torch.rand(10), torch.rand(5)))
+        fut3 = rpc.rpc_async(dst, torch.add, args=(torch.rand(10), torch.rand(5)))
+
+        # We should receive the error from fut2
+        with self.assertRaisesRegex(RuntimeError, "size of tensor"):
+            torch.futures.wait_all([fut1, fut2, fut3])
+
+        # Unblock RPC thread for fut1
+        RpcTest.timed_out_rpc_event.set()
+
+    @dist_init
+    def test_wait_all_exit_early_script_function(self):
+        # Initialize the event in the subprocess.
+        RpcTest.timed_out_rpc_event = Event()
+
+        # Wait for all processes to initialize event.
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+        dist.barrier()
+
+        dst = worker_name((self.rank + 1) % self.world_size)
+        fut1 = rpc.rpc_async(dst, RpcTest.timed_out_rpc)
+        fut2 = rpc.rpc_async(dst, raise_func_script, args=(expected_err,))
+        fut3 = rpc.rpc_async(dst, raise_func_script, args=(expected_err,))
+
+        # We should receive the error from fut2
+        with self.assertRaisesRegex(RuntimeError, expected_err):
+            torch.futures.wait_all([fut1, fut2, fut3])
+
+        # Unblock RPC thread for fut1
+        RpcTest.timed_out_rpc_event.set()
+
+    @dist_init
+    def test_function_not_on_callee(self):
+        # Test that if a function does not exist on a callee, we don't crash;
+        # instead, an error is raised indicating that the function cannot be
+        # resolved on the callee.
+        this_module = sys.modules[__name__]
+        caller_worker = "worker0"
+        callee_worker = "worker1"
+
+        if self.rank == 1:
+            # Use delattr to remove the binding of the func on this node.
+            delattr(this_module, "foo_add")
+            # notify remote end that we have removed it.
+            rpc.rpc_sync(caller_worker, set_value, args=(self.rank,))
+
+        if self.rank == 0:
+            # func exists on caller, but not callee.
+            # wait for remote end to remove the binding of foo_add func.
+            wait_for_value_future()
+            # Ensure that we have the attribute on this module. Otherwise, the
+            # test could fail due to a caller-side pickling error.
+            self.assertTrue(hasattr(this_module, "foo_add"))
+            with self.assertRaisesRegex(RuntimeError, "RPC pickler does not serialize"):
+                rpc.rpc_sync(callee_worker, foo_add, args=())
+
+    @dist_init
+    def test_non_garbage_collected_user_rref_due_to_local_circular_dependency(self):
+        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
+
+        a = MyClass(1)
+        b = MyClass(2)
+
+        # Create a reference cycle so that Python does not garbage collect
+        # a and b.
+        a.other = b
+        b.other = a
+
+        n = self.rank
+        a.rref = rpc.remote(dst_worker_name, torch.add, args=(torch.ones(n, n), 2))
+
+    @dist_init(setup_rpc=False)
+    def test_use_rref_after_shutdown(self):
+        rpc.init_rpc(
+            name=f"worker{self.rank:d}",
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        rref = rpc.remote(
+            worker_name(dst_rank),
+            torch.add,
+            args=(torch.ones(n, n), torch.ones(n, n)),
+        )
+        # pass in graceful=True to ensure that local UserRRefs are deleted.
+        rpc.shutdown(graceful=True)
+
+        with self.assertRaisesRegex(
+            RuntimeError, "Cannot call to_here\\(\\) on it after deletion."
+        ):
+            rref.to_here()
+
+        with self.assertRaisesRegex(
+            RuntimeError, "Cannot call fork an UserRRef after deletion."
+        ):
+            import torch.distributed.rpc.internal as internal
+
+            internal.serialize(rref)
+
+    @staticmethod
+    def _return_gpu_tensor():
+        return torch.rand(3, 3).cuda(0)
+
+    @staticmethod
+    def _return_gpu_tensor_list():
+        return [torch.rand(3, 3).cuda(0), torch.rand(3, 3).cuda(1)]
+
+    @staticmethod
+    def _gpu_tensor_list_arg(tensor_list):
+        return torch.rand(3, 3)
+
+    def _create_rref(self):
+        owner_rank = (self.rank + 2) % self.world_size
+        return rpc.remote(
+            worker_name(owner_rank), torch.add, args=(torch.zeros(2, 2), 1)
+        )
+
+    @dist_init
+    def test_user_rrefs_confirmed(self):
+        dst_rank = (self.rank + 1) % self.world_size
+        rref = self._create_rref()
+        ret = rpc.rpc_sync(worker_name(dst_rank), check_rref_confirmed, args=(rref,))
+        self.assertEqual(ret, True)
+
+    @dist_init
+    def test_user_rrefs_confirmed_remote(self):
+        dst_rank = (self.rank + 1) % self.world_size
+        rref = self._create_rref()
+        ret_rref = rpc.remote(worker_name(dst_rank), check_rref_confirmed, args=(rref,))
+        self.assertEqual(ret_rref.to_here(), True)
+
+    @dist_init
+    def test_rref_py_pickle_not_supported(self):
+        local_rref = RRef(35)
+        with TemporaryFileName() as fname:
+            with self.assertRaisesRegex(
+                RuntimeError, "Can not pickle rref in python pickler"
+            ):
+                torch.save(local_rref, fname)
+
+    @dist_init
+    def test_remote_throw(self):
+        rref = rpc.remote(
+            worker_name((self.rank + 1) % self.world_size),
+            raise_or_inc,
+            args=(torch.ones(2),),
+        )
+        with self.assertRaisesRegex(Exception, ".*Expected error.*"):
+            rref.to_here()
+
+    @dist_init
+    def test_non_cont_tensors(self):
+        if self.rank == 0:
+            # Create a non-contiguous tensor.
+            t = torch.rand(5, 5)
+            t_view = t.narrow(1, 2, 2)
+            self.assertFalse(t_view.is_contiguous())
+            t_cont = t_view.contiguous()
+            self.assertTrue(t_cont.is_contiguous())
+            self.assertEqual(t_view, t_cont)
+
+            # Send non-cont tensor over RPC.
+            next_rank = (self.rank + 1) % self.world_size
+            t_ret = rpc.rpc_sync(
+                worker_name(next_rank), non_cont_test, args=(t_view, t_cont)
+            )
+
+            # Verify the returned tensor.
+            self.assertEqual(t_view, t_ret)
+            self.assertFalse(t_ret.is_contiguous())
+
+    @dist_init
+    def test_callback_simple(self):
+        set_by_cb = concurrent.futures.Future()
+        n = self.rank + 1
+
+        def callback(fut):
+            ret = fut.wait()
+            self.assertEqual(ret, torch.ones(n, n) * 2)
+            set_by_cb.set_result(ret.clone() + 1)
+
+        fut = rpc.rpc_async(
+            worker_name(n % self.world_size),
+            torch.add,
+            args=(torch.ones(n, n), torch.ones(n, n)),
+        )
+
+        fut.then(callback)
+
+        self.assertEqual(fut.wait(), torch.ones(n, n) * 2)
+        self.assertEqual(set_by_cb.result(), torch.ones(n, n) * 2 + 1)
+        self.assertEqual(fut.wait(), torch.ones(n, n) * 2)
+
+    @dist_init
+    def test_callback_wrong_arg_num(self):
+        n = self.rank + 1
+
+        fut = rpc.rpc_async(
+            worker_name(n % self.world_size),
+            torch.add,
+            args=(torch.ones(n, n), torch.ones(n, n)),
+        )
+
+        cb_fut = fut.then(my_function)
+
+        self.assertEqual(fut.wait(), torch.ones(n, n) * 2)
+
+        with self.assertRaisesRegex(
+            RuntimeError, "my\\_function\\(\\) missing 2 required positional arguments"
+        ):
+            cb_fut.wait()
+
+    @dist_init
+    def test_callback_wrong_arg_type(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+
+        fut0 = rpc.rpc_async(dst, torch.add, args=(torch.ones(2, 2), 1))
+        fut1 = fut0.then(lambda x: x + 1)
+
+        with self.assertRaisesRegex(
+            RuntimeError, "unsupported operand type\\(s\\) for \\+"
+        ):
+            fut1.wait()
+
+    @dist_init
+    def test_callback_multi(self):
+        num_cbs = 10
+        n = self.rank + 1
+
+        def callback(idx, fut):
+            ret = fut.wait()
+            self.assertEqual(ret, torch.ones(n, n) * 2)
+            return ret + idx
+
+        fut = rpc.rpc_async(
+            worker_name(n % self.world_size),
+            torch.add,
+            args=(torch.ones(n, n), torch.ones(n, n)),
+        )
+
+        cb_futs = [fut.then(partial(callback, idx)) for idx in range(num_cbs)]
+
+        self.assertEqual(fut.wait(), torch.ones(n, n) * 2)
+
+        for idx in range(num_cbs):
+            self.assertEqual(cb_futs[idx].wait(), torch.ones(n, n) * 2 + idx)
+
+        self.assertEqual(fut.wait(), torch.ones(n, n) * 2)
+
+    @dist_init
+    def test_callback_chain(self):
+        n = self.rank + 1
+
+        def callback(fut):
+            return fut.wait() + 1
+
+        fut = rpc.rpc_async(
+            worker_name(n % self.world_size), torch.add, args=(torch.ones(n, n), 1)
+        )
+
+        num_cbs = 20
+        for _ in range(num_cbs):
+            fut = fut.then(callback)
+
+        self.assertEqual(fut.wait(), torch.ones(n, n) + 1 + num_cbs)
+
+    @dist_init
+    def test_callback_in_rpc(self):
+        dst1 = worker_name((self.rank + 1) % self.world_size)
+        dst2 = worker_name((self.rank + 2) % self.world_size)
+
+        ret = rpc.rpc_sync(dst1, add_use_future_cb, args=(dst2, torch.ones(2, 2), 1, 2))
+        self.assertEqual(ret, torch.ones(2, 2) + 1 + 2)
+
+    @dist_init
+    def test_callback_with_ret(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+
+        def callback(fut0):
+            fut2 = rpc.rpc_async(dst, torch.add, args=(fut0.wait(), 1)).then(
+                lambda fut1: fut1.wait() + 1
+            )
+
+            return fut2.wait()
+
+        fut3 = rpc.rpc_async(dst, torch.add, args=(torch.ones(2, 2), 1)).then(callback)
+
+        self.assertEqual(fut3.wait(), torch.ones(2, 2) + 3)
+
+    @dist_init
+    def test_callback_with_error(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+
+        def callback(fut0):
+            with self.assertRaisesRegex(ValueError, "Expected error"):
+                fut0.wait()
+            raise RuntimeError("Another expected error")
+
+        fut1 = rpc.rpc_async(dst, raise_func).then(callback)
+        with self.assertRaisesRegex(RuntimeError, "Another expected error"):
+            fut1.wait()
+
+    @dist_init
+    def test_callback_none(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        with self.assertRaisesRegex(TypeError, "incompatible function arguments."):
+            rpc.rpc_async(dst, raise_func).then(None)
+
+    @dist_init
+    def test_add_done_callback(self):
+        set_by_cb = False
+        n = self.rank + 1
+
+        def callback(fut):
+            nonlocal set_by_cb
+            fut.wait()
+            set_by_cb = True
+
+        fut = rpc.rpc_async(
+            worker_name(n % self.world_size),
+            torch.add,
+            args=(torch.ones(n, n), torch.ones(n, n)),
+        )
+
+        fut.add_done_callback(callback)
+        fut_then = fut.then(lambda _: True)
+
+        self.assertEqual(fut.wait(), torch.ones(n, n) * 2)
+
+        # We have no guarantee that the add_done_callback fn will execute before
+        # the test finishes, so we add a 'then' callback that runs afterwards and
+        # wait on it to guarantee the first callback has run.
+        fut_then.wait()
+        self.assertTrue(set_by_cb)
+        self.assertEqual(fut.wait(), torch.ones(n, n) * 2)
+
+    @dist_init
+    def test_mark_future_twice(self):
+        fut = rpc.rpc_async(
+            worker_name((self.rank + 1) % self.world_size),
+            torch.add,
+            args=(torch.zeros(2, 2), 1),
+        )
+        self.assertEqual(fut.wait(), torch.zeros(2, 2) + 1)
+        with self.assertRaisesRegex(
+            RuntimeError, "Future can only be marked completed once"
+        ):
+            fut.set_result(1)
+
+    @dist_init
+    def test_pickle_future(self):
+        fut = torch.futures.Future()
+        errMsg = "Can not pickle torch.futures.Future"
+
+        dst = worker_name((self.rank + 1) % self.world_size)
+        with TemporaryFileName():
+            with self.assertRaisesRegex(RuntimeError, errMsg):
+                rpc.rpc_sync(dst, fail_on_fut, args=(fut,))
+
+        with TemporaryFileName():
+            with self.assertRaisesRegex(RuntimeError, errMsg):
+                rpc.rpc_async(dst, fail_on_fut, args=(fut,))
+
+        with TemporaryFileName():
+            with self.assertRaisesRegex(RuntimeError, errMsg):
+                rpc.remote(dst, fail_on_fut, args=(fut,))
+
+    @dist_init
+    def test_future_done(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        fut = rpc.rpc_async(dst, torch.add, args=(torch.zeros(2), 1))
+        fut.wait()
+        self.assertTrue(fut.done())
+
+    @dist_init
+    def test_future_done_exception(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        fut = rpc.rpc_async(dst, raise_func)
+        with self.assertRaisesRegex(ValueError, "Expected error"):
+            fut.wait()
+        self.assertTrue(fut.done())
+
+    def _test_future_cb(self, func):
+        dst1 = worker_name((self.rank + 1) % self.world_size)
+        dst2 = worker_name((self.rank + 2) % self.world_size)
+
+        ret = rpc.rpc_sync(dst1, func, args=(dst2, torch.ones(2, 2), 1, 2))
+        self.assertEqual(ret, torch.ones(2, 2) + 1 + 2)
+
+    @dist_init
+    def test_future_in_rpc(self):
+        self._test_future_cb(add_use_future_set_result)
+
+    @dist_init
+    def test_future_nested_callback(self):
+        self._test_future_cb(add_use_future_nested_cb)
+
+    def _test_async_function_raise(self, mode):
+        with self.assertRaisesRegex(RuntimeError, "Expected error"):
+            self._run_func_in_mode(
+                worker_name((self.rank + 1) % self.world_size), async_raise_func, mode
+            )
+
+    @dist_init
+    def test_async_function_raise(self):
+        self._test_async_function_raise(RPCExecMode.SYNC)
+
+    @dist_init
+    def test_async_function_raise_async(self):
+        self._test_async_function_raise(RPCExecMode.ASYNC)
+
+    @dist_init
+    def test_async_function_raise_remote(self):
+        self._test_async_function_raise(RPCExecMode.REMOTE)
+
+    def _test_async_function_wrong_return_type(self, mode):
+        errMsg = (
+            "Functions decorated with @rpc\\.async_function must return a "
+            "torch\\.futures\\.Future object,"
+        )
+        with self.assertRaisesRegex(RuntimeError, errMsg):
+            self._run_func_in_mode(
+                worker_name((self.rank + 1) % self.world_size), async_wrong_type, mode
+            )
+
+    @dist_init
+    def test_async_function_wrong_return_type(self):
+        self._test_async_function_wrong_return_type(RPCExecMode.SYNC)
+
+    @dist_init
+    def test_async_function_wrong_return_type_async(self):
+        self._test_async_function_wrong_return_type(RPCExecMode.ASYNC)
+
+    @dist_init
+    def test_async_function_wrong_return_type_remote(self):
+        self._test_async_function_wrong_return_type(RPCExecMode.REMOTE)
+
+    @dist_init
+    def test_async_function_simple(self):
+        dst1 = worker_name((self.rank + 1) % self.world_size)
+        dst2 = worker_name((self.rank + 2) % self.world_size)
+
+        ret = rpc.rpc_sync(dst1, async_add, args=(dst2, torch.ones(2, 2), 1))
+        self.assertEqual(ret, torch.ones(2, 2) + 1)
+
+    def _test_async_function(self, fn, mode=RPCExecMode.SYNC):
+        dst1 = worker_name((self.rank + 1) % self.world_size)
+        dst2 = worker_name((self.rank + 2) % self.world_size)
+
+        args = (dst2, torch.ones(2, 2), 1, 2)
+        ret = self._run_func_in_mode(dst1, fn, mode, args=args)
+        self.assertEqual(ret, torch.ones(2, 2) + 3)
+
+    @dist_init
+    def test_async_function_with_future_ctor(self):
+        self._test_async_function(async_add_with_future_ctor)
+
+    @dist_init
+    def test_async_function_with_future_ctor_remote(self):
+        self._test_async_function(async_add_with_future_ctor, RPCExecMode.REMOTE)
+
+    @dist_init
+    def test_async_function_chained(self):
+        self._test_async_function(async_add_chained)
+
+    @dist_init
+    def test_async_function_chained_remote(self):
+        self._test_async_function(async_add_chained, RPCExecMode.REMOTE)
+
+    @dist_init
+    def test_async_function_nested(self):
+        self._test_async_function(async_add_nested)
+
+    @dist_init
+    def test_async_function_nested_remote(self):
+        self._test_async_function(async_add_nested, RPCExecMode.REMOTE)
+
+    @dist_init
+    def test_async_static_method(self):
+        self._test_async_function(AsyncExecutionClass.static_async_add)
+
+    @dist_init
+    def test_async_static_method_remote(self):
+        self._test_async_function(
+            AsyncExecutionClass.static_async_add, RPCExecMode.REMOTE
+        )
+
+    @dist_init
+    def test_async_class_method(self):
+        self._test_async_function(AsyncExecutionClass.class_async_add)
+
+    @dist_init
+    def test_async_class_method_remote(self):
+        self._test_async_function(
+            AsyncExecutionClass.class_async_add, RPCExecMode.REMOTE
+        )
+
+    def _test_async_class_rref_proxy(self, mode=RPCExecMode.SYNC):
+        dst1 = worker_name((self.rank + 1) % self.world_size)
+        dst2 = worker_name((self.rank + 2) % self.world_size)
+        rref = rpc.remote(dst1, AsyncExecutionClass)
+
+        x = torch.ones(2, 2)
+        y = torch.ones(2, 2) + 1
+        if mode == RPCExecMode.SYNC:
+            ret = rref.rpc_sync().static_async_add(dst2, x, x, y)
+            ret += rref.rpc_sync().class_async_add(dst2, x, x, y)
+            ret += rref.rpc_sync().bound_async_add(dst2, x, x, y)
+        elif mode == RPCExecMode.ASYNC:
+            ret = rref.rpc_async().static_async_add(dst2, x, x, y).wait()
+            ret += rref.rpc_async().class_async_add(dst2, x, x, y).wait()
+            ret += rref.rpc_async().bound_async_add(dst2, x, x, y).wait()
+        elif mode == RPCExecMode.REMOTE:
+            ret = rref.remote().static_async_add(dst2, x, x, y).to_here()
+            ret += rref.remote().class_async_add(dst2, x, x, y).to_here()
+            ret += rref.remote().bound_async_add(dst2, x, x, y).to_here()
+
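+        # Each proxy call presumably computes x + x + y = 4 * x element-wise
+        # (since y == 2 * x here), and three such results are summed: 3 * 4 * x.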
+        self.assertEqual(ret, 3 * 4 * x)
+
+    @dist_init
+    def test_async_class_rref_proxy(self):
+        self._test_async_class_rref_proxy()
+
+    @dist_init
+    def test_async_class_rref_proxy_async(self):
+        self._test_async_class_rref_proxy(mode=RPCExecMode.ASYNC)
+
+    @dist_init
+    def test_async_class_rref_proxy_remote(self):
+        self._test_async_class_rref_proxy(mode=RPCExecMode.REMOTE)
+
+    def _test_async_function_multi(self, fn, mode=RPCExecMode.SYNC):
+        dst1 = worker_name((self.rank + 1) % self.world_size)
+        dst2 = worker_name((self.rank + 2) % self.world_size)
+
+        num = 20
+        step = 3
+        args = (dst2, torch.ones(2, 2), num, step)
+        ret = self._run_func_in_mode(dst1, fn, mode, args=args)
+        self.assertEqual(ret, torch.ones(2, 2) + num * step)
+
+    @dist_init
+    def test_async_function_multi_chained(self):
+        self._test_async_function_multi(async_add_chained_multi)
+
+    @dist_init
+    def test_async_function_multi_chained_async(self):
+        self._test_async_function_multi(async_add_chained_multi, RPCExecMode.ASYNC)
+
+    @dist_init
+    def test_async_function_multi_chained_remote(self):
+        self._test_async_function_multi(async_add_chained_multi, RPCExecMode.REMOTE)
+
+    @dist_init
+    def test_async_function_multi_fanout(self):
+        self._test_async_function_multi(async_add_multi_fanout)
+
+    @dist_init
+    def test_async_function_multi_fanout_async(self):
+        self._test_async_function_multi(async_add_multi_fanout, RPCExecMode.ASYNC)
+
+    @dist_init
+    def test_async_function_multi_fanout_remote(self):
+        self._test_async_function_multi(async_add_multi_fanout, RPCExecMode.REMOTE)
+
+    def _test_return_future(self, mode):
+        with self.assertRaisesRegex(
+            RuntimeError, "Can not pickle torch.futures.Future"
+        ):
+            self._run_func_in_mode(
+                worker_name((self.rank + 1) % self.world_size), return_future, mode
+            )
+
+    @dist_init
+    def test_return_future(self):
+        self._test_return_future(RPCExecMode.SYNC)
+
+    @dist_init
+    def test_return_future_async(self):
+        self._test_return_future(RPCExecMode.ASYNC)
+
+    @dist_init
+    def test_return_future_remote(self):
+        self._test_return_future(RPCExecMode.REMOTE)
+
+    @dist_init
+    def test_rref_timeout(self):
+        # This test is similar to ones in FaultyProcessGroupTest, but is meant to be
+        # run with other backends besides ProcessGroup.
+        if self.rank != 0:
+            return
+
+        dst_rank = (self.rank + 1) % self.world_size
+        dst_worker = f"worker{dst_rank}"
+        # 10 ms timeout
+        rref = rpc.remote(dst_worker, my_sleep_func, args=(2,), timeout=0.01)
+        # Future corresponding to the remote creation should time out.
+        expected_error = self.get_timeout_error_regex()
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            rref._get_future().wait()
+        # Call to ensure pending callbacks are run.
+        wait_until_pending_futures_and_users_flushed()
+        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
+            rref.to_here()
+
+        wait_until_owners_and_forks_on_rank(1, 1, rank=1)
+
+    @dist_init(setup_rpc=False)
+    @skip_but_pass_in_sandcastle_if(
+        os.environ.get("RPC_INIT_WITH_TCP", None) == "1",
+        "init_pg_then_rpc does not work with TCP init, see https://github.com/pytorch/pytorch/issues/41614.",
+    )
+    def test_init_pg_then_rpc(self):
+        dist.init_process_group(
+            backend="gloo",
+            init_method=self.init_method,
+            rank=self.rank,
+            world_size=self.world_size,
+        )
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+
+        # Test RPC.
+        next_rank = (self.rank + 1) % self.world_size
+        ret = rpc.rpc_sync(
+            worker_name(next_rank), torch.add, args=(torch.ones(2, 2), 1)
+        )
+        self.assertEqual(ret, torch.ones(2, 2) + 1)
+
+        # Test PG
+        dist.barrier()
+
+        rpc.shutdown()
+
+    @dist_init(setup_rpc=False)
+    @skip_but_pass_in_sandcastle_if(
+        os.environ.get("RPC_INIT_WITH_TCP", None) == "1",
+        "init_rpc_then_pg does not work with TCP init, see https://github.com/pytorch/pytorch/issues/41614.",
+    )
+    def test_init_rpc_then_pg(self):
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+
+        dist.init_process_group(
+            backend="gloo",
+            init_method=self.init_method,
+            rank=self.rank,
+            world_size=self.world_size,
+        )
+
+        # Test RPC.
+        next_rank = (self.rank + 1) % self.world_size
+        ret = rpc.rpc_sync(
+            worker_name(next_rank), torch.add, args=(torch.ones(2, 2), 1)
+        )
+        self.assertEqual(ret, torch.ones(2, 2) + 1)
+
+        # Test PG
+        dist.barrier()
+
+        rpc.shutdown()
+
+    @dist_init
+    def test_wait_all_with_exception(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        futs = [rpc.rpc_async(dst, raise_func) for _ in range(10)]
+
+        with self.assertRaisesRegex(ValueError, "Expected error"):
+            torch.futures.wait_all(futs)
+
+    @dist_init
+    def test_wait_all_with_partial_exception(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        futs = [
+            rpc.rpc_async(dst, torch.add, args=(torch.ones(2), 1)) for _ in range(10)
+        ]
+
+        futs.append(rpc.rpc_async(dst, raise_func))
+
+        with self.assertRaisesRegex(ValueError, "Expected error"):
+            torch.futures.wait_all(futs)
+
+    @dist_init(setup_rpc=False)
+    @skip_but_pass_in_sandcastle_if(
+        os.environ.get("RPC_INIT_WITH_TCP", None) == "1",
+        "Test does not work with TCP init, see https://github.com/pytorch/pytorch/issues/46491",
+    )
+    def test_init_rpc_twice(self):
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+        rpc.shutdown()
+
+        # Wait for all init to complete.
+        dist.barrier()
+
+        # Use a different file name for the next initialization
+        new_backend_options = self.rpc_backend_options
+        new_backend_options.init_method += "init_2"
+
+        # Ensure rpc initialization works again.
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=new_backend_options,
+        )
+
+        # Verify RPCs work after re-init.
+        dst = worker_name((self.rank + 1) % self.world_size)
+        rpc.rpc_sync(dst, torch.add, args=(torch.ones(2, 2), 1))
+        rpc.rpc_sync(dst, foo_add, args=())
+
+        rpc.shutdown()
+
+    def test_wrong_types(self):
+        with self.assertRaisesRegex(
+            TypeError,
+            "Argument backend must be a member of BackendType",
+        ):
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                rank=self.rank,
+                world_size=self.world_size,
+                backend="TENSORPIPE",
+            )
+
+        with self.assertRaisesRegex(
+            TypeError,
+            "Argument rpc_backend_options must be an instance of RpcBackendOptions",
+        ):
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                rank=self.rank,
+                world_size=self.world_size,
+                backend=self.rpc_backend,
+                rpc_backend_options={"init_method": self.init_method},
+            )
+
+    def test_cannot_infer_backend_from_options(self):
+        # An exception should be raised if the backend isn't specified but
+        # options are given which are not an instance of any of the known
+        # agents' option classes.
+        rpc_backend_options = FooBackendOptions(self.init_method)
+
+        with self.assertRaisesRegex(TypeError, "Could not infer backend for options"):
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                rank=self.rank,
+                world_size=self.world_size,
+                # Do _not_ pass backend.
+                rpc_backend_options=rpc_backend_options,
+            )
+
+    @dist_init
+    def test_owner_rref_backward(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        t1 = torch.rand(10, 10, requires_grad=True)
+        rref = rpc.RRef(t1.sum() + t1.sum())
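+        # Without a dist_autograd context id, backward() on an owner RRef runs the
+        # local autograd engine and accumulates gradients into t1.grad.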
+        rref.backward()
+        expected_grad = torch.ones_like(t1) * 2
+        self.assertEqual(expected_grad, t1.grad)
+
+        with dist_autograd.context() as context_id:
+            t2 = rpc.rpc_sync(dst, torch.add, args=(t1, t1))
+            rref = rpc.RRef(t2.sum())
+            rref.backward(context_id)
+            self.assertEqual(expected_grad, dist_autograd.get_gradients(context_id)[t1])
+
+        # Double backward.
+        with dist_autograd.context() as context_id:
+            t2 = rpc.rpc_sync(dst, torch.add, args=(t1, t1))
+            rref = rpc.RRef(t2.sum())
+            rref.backward(context_id, retain_graph=True)
+            rref.backward(context_id)
+            self.assertEqual(
+                expected_grad * 2, dist_autograd.get_gradients(context_id)[t1]
+            )
+
+        # Test errors.
+        with self.assertRaisesRegex(
+            RuntimeError, "tensors does not require grad and does not have a grad_fn"
+        ):
+            rpc.RRef(torch.rand(10)).backward()
+
+        with self.assertRaisesRegex(
+            RuntimeError, "grad can be implicitly created only for scalar outputs"
+        ):
+            rpc.RRef(torch.rand(10, requires_grad=True)).backward()
+
+        with self.assertRaisesRegex(
+            RuntimeError, "Could not find autograd context with id: 100"
+        ):
+            rpc.RRef(torch.rand(10, requires_grad=True).sum()).backward(100)
+
+        with self.assertRaisesRegex(
+            RuntimeError, "RRef should contain a tensor for .backward()"
+        ):
+            rpc.RRef("foo").backward()
+
+    @staticmethod
+    def _sum(x):
+        return x.sum()
+
+    @staticmethod
+    def _identity(x):
+        return x
+
+    @dist_init
+    def test_user_rref_backward(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        t = torch.rand(10, requires_grad=True)
+        with dist_autograd.context() as context_id:
+            rref = rpc.remote(dst, RpcTest._sum, args=(t,))
+            rref.backward(context_id, retain_graph=True)
+            rref.backward(context_id)
+            self.assertEqual(
+                torch.ones_like(t) * 2, dist_autograd.get_gradients(context_id)[t]
+            )
+
+        with dist_autograd.context() as context_id:
+            rref = rpc.remote(dst, RpcTest._identity, args=("foo",))
+            with self.assertRaisesRegex(
+                RuntimeError, "RRef should contain a tensor for .backward()"
+            ):
+                rref.backward(context_id)
+
+            with self.assertRaisesRegex(
+                RuntimeError,
+                "User RRefs require 'dist_autograd_ctx_id' to be specified",
+            ):
+                rref.backward()
+
+    @dist_init(setup_rpc=False)
+    def test_shutdown_errors(self):
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+
+        if self.rank != 0:
+            og_func = rpc.api._broadcast_to_followers
+            og_rref_func = rpc.api._delete_all_user_and_unforked_owner_rrefs
+
+            # Monkey-patch _broadcast_to_followers to fail, which ensures that
+            # _all_gather on the leader raises an exception.
+            def raise_error(sequence_id, objects_map):
+                og_func(sequence_id, objects_map)
+                raise RuntimeError("simulation")
+
+            # Monkey-patch _delete_all_user_and_unforked_owner_rrefs to fail,
+            # which ensures the barrier is not called on followers.
+            def rref_error():
+                raise RuntimeError("simulation rref")
+
+            try:
+                rpc.api._broadcast_to_followers = raise_error
+                rpc.api._delete_all_user_and_unforked_owner_rrefs = rref_error
+                with self.assertRaisesRegex(RuntimeError, "simulation rref"):
+                    rpc.shutdown()
+            finally:
+                rpc.api._broadcast_to_followers = og_func
+                rpc.api._delete_all_user_and_unforked_owner_rrefs = og_rref_func
+        else:
+            with self.assertRaisesRegex(RuntimeError, "timed out in _all_gather"):
+                rpc.shutdown()
+
+        dist.barrier()
+
+    @dist_init
+    def test_my_parameter_server(self):
+        self._my_parameter_server(False)
+
+
+class CudaRpcTest(RpcAgentTestFixture):
+    @skip_if_lt_x_gpu(2)
+    @dist_init
+    def test_profiler_remote_cuda(self):
+        if self.rank != 1:
+            return
+
+        dst_cuda_0 = (self.rank + 1) % self.world_size
+        dst_cuda_1 = (self.rank + 2) % self.world_size
+        dst_worker_cuda_0 = worker_name(dst_cuda_0)
+        dst_worker_cuda_1 = worker_name(dst_cuda_1)
+
+        with _profile(use_cuda=True) as p:
+            fut1 = rpc.rpc_async(dst_worker_cuda_0, udf_with_torch_ops, args=(0,))
+            fut2 = rpc.rpc_async(dst_worker_cuda_1, udf_with_torch_ops, args=(1,))
+            fut1.wait()
+            fut2.wait()
+
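+        # Remote profiler events are prefixed with REMOTE_OP_STR; strip that prefix
+        # so the bare op names can be compared against EXPECTED_REMOTE_EVENTS.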
+        def get_name(event):
+            return event.name[event.name.find(REMOTE_OP_STR) + len(REMOTE_OP_STR) :]
+
+        function_events = p.function_events
+        for event in function_events:
+            if event.is_async:
+                self.assertEqual(0, event.device_time_total)
+                self.assertEqual([], event.kernels)
+                self.assertEqual(0, event.device_time)
+            else:
+                if event.node_id == 1:
+                    continue
+                self.assertTrue(event.node_id in [dst_cuda_0, dst_cuda_1])
+                if get_name(event) in EXPECTED_REMOTE_EVENTS:
+                    self.assertGreater(event.device_time_total, 0)
+                    self.assertEqual(1, len(event.kernels))
+                    kernel = event.kernels[0]
+                    if event.node_id == dst_cuda_0:
+                        self.assertEqual(kernel.device, 0)
+                    if event.node_id == dst_cuda_1:
+                        self.assertEqual(kernel.device, 1)
+                    self.assertGreater(event.device_time, 0)
+
+        # Validate that EXPECTED_REMOTE_EVENTS is a subset of remotely profiled
+        # events.
+        remote_events = [event for event in function_events if event.is_remote]
+        remote_event_names = [
+            get_name(event)
+            for event in remote_events
+            if get_name(event) in EXPECTED_REMOTE_EVENTS
+        ]
+        self.assertEqual(set(remote_event_names), set(EXPECTED_REMOTE_EVENTS))
+
+
+class TensorPipeAgentRpcTest(RpcAgentTestFixture, RpcTestCommon):
+    def test_mismatched_type_for_options(self):
+        # An exception should be raised if the options are not an instance of
+        # TensorPipeRpcBackendOptions.
+        rpc_backend_options = FooBackendOptions(self.init_method)
+
+        with self.assertRaisesRegex(
+            TypeError, "`rpc_backend_options` must be a `TensorPipeRpcBackendOptions`"
+        ):
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                rank=self.rank,
+                world_size=self.world_size,
+                backend=rpc.BackendType.TENSORPIPE,
+                rpc_backend_options=rpc_backend_options,
+            )
+
+    def test_infer_backend_from_options(self):
+        rpc_backend_options = rpc.TensorPipeRpcBackendOptions(
+            init_method=self.init_method, _transports=tp_transports()
+        )
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            rank=self.rank,
+            world_size=self.world_size,
+            # Do _not_ pass backend.
+            rpc_backend_options=rpc_backend_options,
+        )
+
+        self.assertIsInstance(rpc.api._get_current_rpc_agent(), rpc.TensorPipeAgent)
+
+    # FIXME Merge this test with the corresponding one in RpcTest.
+    @dist_init(setup_rpc=False)
+    def test_set_and_get_num_worker_threads(self):
+        NUM_THREADS = 27
+        rpc_backend_options = rpc.TensorPipeRpcBackendOptions(
+            init_method=self.rpc_backend_options.init_method,
+            num_worker_threads=NUM_THREADS,
+            _transports=tp_transports(),
+        )
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=rpc_backend_options,
+        )
+
+        info = rpc.api._get_current_rpc_agent().get_debug_info()
+        self.assertEqual(int(info["agent.thread_pool_size"]), NUM_THREADS)
+        rpc.shutdown()
+
+    # FIXME Merge this test with the corresponding one in RpcTest.
+    @dist_init(setup_rpc=False)
+    def test_tensorpipe_set_default_timeout(self):
+        # Set a high timeout since it doesn't affect test runtime and ensures
+        # the test doesn't erroneously time out on slow machines.
+        timeout = 100
+        rpc_backend_options = rpc.TensorPipeRpcBackendOptions(
+            init_method=self.rpc_backend_options.init_method,
+            num_worker_threads=self.rpc_backend_options.num_worker_threads,
+            rpc_timeout=timeout,
+            _transports=tp_transports(),
+        )
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=rpc_backend_options,
+        )
+
+        default_timeout = rpc.get_rpc_timeout()
+        self.assertEqual(default_timeout, timeout)
+        rpc.shutdown()
+
+    # FIXME Merge this test with the corresponding one in RpcTest.
+    @dist_init(setup_rpc=False)
+    def test_tensorpipe_options_throw_on_timedelta_timeout(self):
+        from datetime import timedelta
+
+        timeout = timedelta()
+        # Ensure that constructing TensorPipeRpcBackendOptions with timedelta fails
+        with self.assertRaisesRegex(TypeError, "incompatible constructor arguments"):
+            rpc.TensorPipeRpcBackendOptions(
+                init_method=self.rpc_backend_options.init_method,
+                num_worker_threads=self.rpc_backend_options.num_worker_threads,
+                rpc_timeout=timeout,
+            )
+
+    @dist_init
+    def _test_rref_get_type_timeout(self, blocking):
+        # Test where we try to get the type of an RRef from the owner, but RRef
+        # creation is slower than the timeout passed into _get_type.
+        dst_rank = (self.rank + 1) % self.world_size
+        dst = worker_name(dst_rank)
+        slow_rref = rpc.remote(dst, MyClass, args=(torch.ones(2, 2), True))
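+        # The True flag makes the remote MyClass construction slow, so the
+        # _get_type() calls below exceed the short timeout.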
+        timeout = 0.5
+        expected_err = self.get_timeout_error_regex()
+        # Blocking: blocks on inline call
+        if blocking:
+            with self.assertRaisesRegex(RuntimeError, expected_err):
+                slow_rref._get_type(timeout=timeout, blocking=blocking)
+        # Non-blocking: blocks on wait
+        else:
+            fut = slow_rref._get_type(timeout=timeout, blocking=blocking)
+            with self.assertRaisesRegex(RuntimeError, expected_err):
+                fut.wait()
+
+        # FIXME We wait until the remote has finished creating the OwnerRRef
+        # because there's currently a race if we shut down RPC before that.
+        slow_rref.to_here()
+
+    def test_rref_get_type_timeout_blocking(self):
+        self._test_rref_get_type_timeout(blocking=True)
+
+    def test_rref_get_type_timeout_non_blocking(self):
+        self._test_rref_get_type_timeout(blocking=False)
+
+    @dist_init
+    def test_op_with_invalid_args(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "Overloaded torch operator invoked from Python failed to match any schema",
+        ):
+            rpc.rpc_sync(dst, torch.add, args=())
+
+    def _test_rref_proxy_timeout(self, rref_proxy_api):
+        dst_rank = (self.rank + 1) % self.world_size
+        dst = worker_name(dst_rank)
+        rref = rpc.remote(dst, MyClass, args=(torch.ones(2, 2),))
+        # Ensure RRef is created on remote node.
+        rref.to_here()
+        rref_api = getattr(rref, rref_proxy_api)
+        self.assertTrue(
+            rref_api is not None, f"Failed to get RRef proxy api: {rref_proxy_api}"
+        )
+        expected_error = self.get_timeout_error_regex()
+        timeout = 2
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            result = rref_api(timeout=timeout).my_slow_method(torch.ones(2, 2))
+            if rref_api == rref.rpc_async:
+                result.wait()
+            elif rref_api == rref.remote:
+                result._get_future().wait()
+
+        # Case where rpc.remote() is stuck and exceeds timeout
+        slow_rref = rpc.remote(dst, MyClass, args=(torch.ones(2, 2), True))
+        timeout = 0.01
+        rref_api = getattr(slow_rref, rref_proxy_api)
+        # Note that even when we call rref.rpc_async() in this case, we time out
+        # during future creation, not while waiting on the future. This is because
+        # the rref proxy function calls rref._get_type before returning the future,
+        # which blocks until the RRef is created on the owner node or the specified
+        # timeout elapses.
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            result = rref_api(timeout=timeout).my_instance_method(torch.ones(2, 2))
+            # rpc_async returns immediately and surfaces the timeout through wait().
+            if rref_api == slow_rref.rpc_async:
+                result.wait()
+
+        # FIXME We wait until the remote has finished creating the OwnerRRef
+        # because there's currently a race if we shut down RPC before that.
+        slow_rref.to_here()
+
+    @dist_init
+    def test_rref_proxy_timeout(self):
+        for rpc_api in ["rpc_sync", "rpc_async", "remote"]:
+            self._test_rref_proxy_timeout(rpc_api)
+
+    @dist_init
+    def test_send_to_rank_sparse(self):
+        dst_rank = (self.rank + 1) % self.world_size
+
+        # Test sparse tensor
+        for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]:
+            x = build_sparse_tensor()
+            y = build_sparse_tensor()
+            expected_tensor = x + y
+            ret = self._run_func_in_mode(dst_rank, torch.add, exec_mode, args=(x, y))
+            self.assertEqual(expected_tensor, ret)
+
+        for exec_mode in [RPCExecMode.SYNC, RPCExecMode.ASYNC, RPCExecMode.REMOTE]:
+            x = build_sparse_tensor(coalesce=True)
+            y = build_sparse_tensor(coalesce=True)
+            expected_tensor = x + y
+            ret = self._run_func_in_mode(dst_rank, torch.add, exec_mode, args=(x, y))
+            self.assertEqual(expected_tensor, ret)
+
+    @dist_init
+    def test_self_py_udf_remote_sparse(self):
+        self._self_py_udf_remote(
+            rpc.get_worker_info(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+        )
+
+    @dist_init
+    def test_self_remote_rref_as_rpc_arg_sparse(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        self._self_remote_rref_as_rpc_arg(
+            dst, build_sparse_tensor(), build_sparse_tensor(), build_sparse_tensor()
+        )
+
+    @dist_init
+    def test_self_remote_rref_as_self_rpc_arg_sparse(self):
+        self._self_remote_rref_as_rpc_arg(
+            rpc.get_worker_info(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+        )
+
+    @dist_init
+    def test_self_remote_rref_as_remote_arg_sparse(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        self._self_remote_rref_as_remote_arg(
+            dst, build_sparse_tensor(), build_sparse_tensor(), build_sparse_tensor()
+        )
+
+    @dist_init
+    def test_self_remote_rref_as_self_remote_arg_sparse(self):
+        self._self_remote_rref_as_remote_arg(
+            rpc.get_worker_info(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+        )
+
+    def test_world_size_one_sparse(self):
+        self._world_size_one(build_sparse_tensor(), build_sparse_tensor())
+
+    @dist_init
+    def test_multi_rpc_sparse(self):
+        self._multi_rpc(True)
+
+    def test_wait_all_workers_sparse(self):
+        self._wait_all_workers(heavy_rpc_sparse, build_sparse_tensor())
+
+    def test_wait_all_workers_twice_sparse(self):
+        self._wait_all_workers_twice(heavy_rpc_sparse, build_sparse_tensor())
+
+    @dist_init
+    def test_py_sparse_tensors_in_container(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        a = [build_sparse_tensor(), build_sparse_tensor()]
+        ret = rpc.rpc_sync(worker_name(dst_rank), my_container_sum, args=(a,))
+        self.assertEqual(ret, my_container_sum(a))
+
+    @dist_init
+    def test_nested_rpc_sparse(self):
+        self._nested_rpc(nested_rpc_sparse, build_sparse_tensor() * 2)
+
+    @dist_init
+    def test_stress_heavy_rpc_sparse(self):
+        self._stress_test_rpc(
+            heavy_rpc_sparse, repeat=20, args=(build_sparse_tensor(),)
+        )
+
+    @dist_init
+    def test_builtin_remote_ret_sparse(self):
+        self._builtin_remote_ret(
+            build_sparse_tensor(), build_sparse_tensor(), build_sparse_tensor() * 2
+        )
+
+    @dist_init
+    def test_builtin_remote_self_sparse(self):
+        self._builtin_remote_self(
+            build_sparse_tensor(), build_sparse_tensor(), build_sparse_tensor() * 2
+        )
+
+    @dist_init
+    def test_multi_builtin_remote_ret_sparse(self):
+        self._test_multi_remote_call(torch.add, True, args_fn=RpcTest._multi_args_fn)
+
+    @dist_init
+    def test_multi_py_udf_remote_sparse(self):
+        self._test_multi_remote_call(
+            my_function, True, kwargs_fn=RpcTest._multi_kwargs_fn
+        )
+
+    @dist_init
+    def test_py_rref_args_sparse(self):
+        self._py_rref_args(
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor() * 4,
+        )
+
+    @dist_init
+    def test_py_rref_args_user_share_sparse(self):
+        self._py_rref_args_user_share(
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor() * 6,
+        )
+
+    @dist_init
+    def test_py_rpc_rref_args_sparse(self):
+        self._py_rpc_rref_args(
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor(),
+            build_sparse_tensor() * 6,
+        )
+
+    @dist_init
+    def test_nested_remote_sparse(self):
+        self._nested_remote(
+            nested_remote_sparse, build_sparse_tensor() + build_sparse_tensor()
+        )
+
+    @dist_init
+    def test_nested_rref_sparse(self):
+        self._nested_rref(
+            nested_rref_sparse, build_sparse_tensor() * 2, build_sparse_tensor() * 2
+        )
+
+    @dist_init
+    def test_nested_rref_stress_sparse(self):
+        self._nested_rref_stress(
+            nested_rref_sparse, build_sparse_tensor() * 2, build_sparse_tensor() * 2
+        )
+
+    @dist_init
+    def test_my_parameter_server_sparse(self):
+        self._my_parameter_server(True)
+
+    # Test init_rpc without world_size argument
+    @dist_init(setup_rpc=False)
+    def test_dynamic_rpc_init_rpc(self):
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+        rpc.shutdown()
+
+    # Dynamic RPC new ranks communicate with existing ranks
+    @dist_init(setup_rpc=False)
+    def test_dynamic_rpc_new_rank_can_communicate_with_existing_rank(self):
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        if self.rank == 0:
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rank=self.rank,
+                rpc_backend_options=self.rpc_backend_options,
+            )
+
+        # By the time this barrier completes, rank 0 will have initialized RPC.
+        dist.barrier()
+
+        if self.rank != 0:
+            # Newly joined ranks can communicate with rank 0, since its RPC agent was initialized first.
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rank=self.rank,
+                rpc_backend_options=self.rpc_backend_options,
+            )
+            result = rpc.rpc_sync(
+                worker_name(0), torch.add, args=(torch.tensor(1), torch.tensor(1))
+            )
+            self.assertEqual(torch.add(torch.tensor(1), torch.tensor(1)), result)
+
+        # Barrier to ensure that all rpc_sync calls are finished
+        dist.barrier()
+        rpc.shutdown()
+
+    # Dynamic RPC existing ranks can communicate with new ranks
+    @dist_init(setup_rpc=False)
+    def test_dynamic_rpc_existing_rank_can_communicate_with_new_rank(self):
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        if self.rank == 0:
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rank=self.rank,
+                rpc_backend_options=self.rpc_backend_options,
+            )
+
+        # By the time this barrier completes, rank 0 will have initialized RPC.
+        dist.barrier()
+
+        # Rest of ranks join after barrier
+        if self.rank != 0:
+            # Newly joined ranks can communicate with rank 0, since its RPC agent was initialized first.
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rank=self.rank,
+                rpc_backend_options=self.rpc_backend_options,
+            )
+
+        dist.barrier()
+        if self.rank == 0:
+            for i in range(1, self.world_size):
+                result = rpc.rpc_sync(
+                    worker_name(i), torch.add, args=(torch.tensor(1), torch.tensor(1))
+                )
+                self.assertEqual(torch.add(torch.tensor(1), torch.tensor(1)), result)
+
+        # Barrier to ensure that all rpc_sync calls are finished
+        dist.barrier()
+        rpc.shutdown()
+
+    # Dynamic RPC existing ranks can communicate with new ranks using CUDA rpc
+    @skip_if_lt_x_gpu(2)
+    @dist_init(setup_rpc=False)
+    def test_dynamic_rpc_existing_rank_can_communicate_with_new_rank_cuda(self):
+        initialize_pg(self.file_init_method, self.rank, self.world_size)
+
+        if self.rank == 0:
+            options = self.rpc_backend_options
+            for i in range(1, self.world_size):
+                dst = worker_name(i)
+                options.set_device_map(dst, {1: 0})
+                options.set_device_map(dst, {0: 1})
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rank=self.rank,
+                rpc_backend_options=options,
+            )
+
+        # By the time this barrier completes, rank 0 will have initialized RPC.
+        dist.barrier()
+
+        # Rest of ranks join after barrier
+        if self.rank != 0:
+            # Newly joined ranks can communicate with rank 0, since its RPC agent was initialized first.
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rank=self.rank,
+                rpc_backend_options=self.rpc_backend_options,
+            )
+
+        # TODO: Cuda RPC is failing due to:
+        # terminate called after throwing an instance of 'c10::Error'
+        # what():  0 <= device && static_cast<size_t>(device) < device_allocator.size()
+        # INTERNAL ASSERT FAILED at "../c10/cuda/CUDACachingAllocator.cpp":1937,
+        # please report a bug to PyTorch. Allocator not initialized for device 1: did you call init?
+        # dist.barrier()
+        # if self.rank == 0:
+        #     for i in range(1, self.world_size):
+        #         x = torch.ones(2)
+        #         result_on_device_0 = rpc.rpc_sync(worker_name(i), torch.add, args=(x.to(0), 1))
+        #         result_on_device_1 = rpc.rpc_sync(worker_name(i), torch.add, args=(x.to(1), 1))
+        #         self.assertEqual(torch.add(torch.ones(2), 1), result_on_device_0)
+        #         self.assertEqual(torch.device('cuda:0'), result_on_device_0.device)
+        #         self.assertEqual(torch.add(torch.ones(2), 1), result_on_device_1)
+        #         self.assertEqual(torch.device('cuda:1'), result_on_device_1.device)
+
+        # Barrier to ensure that all rpc_sync calls are finished
+        dist.barrier()
+        rpc.shutdown()
+
+    @dist_init(setup_rpc=False)
+    def test_dynamic_rpc_init_rpc_without_rank(self):
+        # Default initialization uses file init, which requires an explicit rank.
+        with self.assertRaisesRegex(ValueError, "rank parameter missing"):
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rpc_backend_options=self.rpc_backend_options,
+            )
+
+        # env init
+        with self.assertRaisesRegex(ValueError, "environment variable RANK expected"):
+            rpc_backend_options = rpc.TensorPipeRpcBackendOptions(init_method="env://")
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rpc_backend_options=rpc_backend_options,
+            )
+
+        # tcp init
+        with self.assertRaisesRegex(ValueError, "rank parameter missing"):
+            rpc_backend_options = rpc.TensorPipeRpcBackendOptions(
+                init_method="tcp://127.0.0.1:23456"
+            )
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rpc_backend_options=rpc_backend_options,
+            )
+
+    @dist_init(setup_rpc=False)
+    def test_dynamic_and_static_init_rpc_together(self):
+        # Initialize a static rpc group with size = self.world_size - 1
+        dist.init_process_group(
+            backend="gloo",
+            init_method=self.file_init_method,
+            rank=self.rank,
+            world_size=self.world_size,
+        )
+
+        world_size_minus_one = self.world_size - 1
+        if self.rank < world_size_minus_one:
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rank=self.rank,
+                world_size=world_size_minus_one,
+                rpc_backend_options=self.rpc_backend_options,
+            )
+
+        dist.barrier()
+
+        # Attempt to add an additional dynamic group member
+        if self.rank == world_size_minus_one:
+            # Expect error message to be thrown
+            with self.assertRaisesRegex(
+                RuntimeError,
+                "RPC group mixes statically and dynamically initialized members "
+                "which is not supported.",
+            ):
+                rpc.init_rpc(
+                    name=worker_name(self.rank),
+                    backend=self.rpc_backend,
+                    rank=self.rank,
+                    rpc_backend_options=self.rpc_backend_options,
+                )
+
+
+class TensorPipeAgentCudaRpcTest(RpcAgentTestFixture, RpcTestCommon):
+    def _test_device_maps(self, options, errMsg):
+        with self.assertRaisesRegex(ValueError, errMsg):
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rank=self.rank,
+                world_size=self.world_size,
+                rpc_backend_options=options,
+            )
+
+        self.assertFalse(rpc.api._is_current_rpc_agent_set())
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_maps_wrong_worker_name(self):
+        options = self.rpc_backend_options
+        options.set_device_map("none_exist", {0: 1})
+
+        self._test_device_maps(
+            options,
+            errMsg="Node worker0 has invalid target node names in its device maps",
+        )
+
+    @skip_if_lt_x_gpu(1)
+    def test_device_maps_invalid_max_local_device(self):
+        options = self.rpc_backend_options
+        dst = worker_name((self.rank + 1) % self.world_size)
+        options.set_device_map(dst, {torch.cuda.device_count(): 0})
+
+        self._test_device_maps(
+            options,
+            errMsg="Node worker0 has source devices with invalid indices in its device map for worker1",
+        )
+
+    @skip_if_lt_x_gpu(1)
+    def test_device_maps_invalid_max_remote_device(self):
+        options = self.rpc_backend_options
+        dst = worker_name((self.rank + 1) % self.world_size)
+        options.set_device_map(dst, {0: torch.cuda.device_count()})
+
+        self._test_device_maps(
+            options,
+            errMsg="Node worker0 has target devices with invalid indices in its device map for worker1",
+        )
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_maps_many_to_one(self):
+        options = self.rpc_backend_options
+        dst = worker_name((self.rank + 1) % self.world_size)
+        options.set_device_map(dst, {1: 0})
+        options.set_device_map(dst, {0: 0})
+
+        self._test_device_maps(
+            options,
+            errMsg="Node worker0 has duplicated target devices in its device map for worker1",
+        )
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_maps_one_to_many(self):
+        if self.rank == 0:
+            options = self.rpc_backend_options
+            dst = worker_name((self.rank + 1) % self.world_size)
+            options.set_device_map(dst, {0: 1})
+            with self.assertRaisesRegex(
+                ValueError, "`set_device_map` only supports 1-to-1 mapping"
+            ):
+                options.set_device_map(dst, {0: 0})
+
+    @skip_if_lt_x_gpu(1)
+    def test_device_maps_invalid_min_device(self):
+        options = self.rpc_backend_options
+        dst = worker_name((self.rank + 1) % self.world_size)
+        with self.assertRaisesRegex(RuntimeError, "Device index must not be negative"):
+            options.set_device_map(dst, {-1: 0})
+
+        with self.assertRaisesRegex(RuntimeError, "Device index must not be negative"):
+            options.set_device_map(dst, {0: -1})
+
+    @staticmethod
+    def _gpu_add(x, y):
+        if all([x.is_cuda, x.device.index == 1, y.is_cuda, y.device.index == 1]):
+            return (x + y).to(0)
+        else:
+            raise ValueError("Wrong device affinity")
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_maps_gpu(self):
+        options = self.rpc_backend_options
+        dst = worker_name((self.rank + 1) % self.world_size)
+        options.set_device_map(dst, {0: 1, 1: 0})
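+        # Local cuda:0 maps to remote cuda:1 and vice versa, so the tensors created
+        # on device 0 below arrive on device 1 on the callee.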
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        ret = rpc.rpc_sync(
+            dst,
+            TensorPipeAgentCudaRpcTest._gpu_add,
+            args=(torch.zeros(2).to(0), torch.ones(2).to(0)),
+        )
+        self.assertEqual(ret.device, torch.device(1))
+        self.assertEqual(ret, (torch.zeros(2) + torch.ones(2)).to(1))
+        rpc.shutdown()
+
+    @staticmethod
+    def _gpu_add_given_devices(x, y, x_to, y_to, z_to):
+        x_device = "cpu" if x.device.type == "cpu" else x.device.index
+        y_device = "cpu" if y.device.type == "cpu" else y.device.index
+        if x_device == x_to and y_device == y_to:
+            return x.to(z_to) + y.to(z_to)
+        else:
+            raise ValueError("Wrong device affinity")
+
+    def _test_device_maps_gpu(
+        self, x_from, y_from, z_to, device_map, dst=None, fn=None
+    ):
+        fn = TensorPipeAgentCudaRpcTest._gpu_add_given_devices if fn is None else fn
+        x_to = device_map[x_from]
+        y_to = device_map[y_from]
+
+        options = self.rpc_backend_options
+        dst = worker_name((self.rank + 1) % self.world_size) if dst is None else dst
+        options.set_device_map(dst, device_map)
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        x = torch.zeros(2).to(x_from)
+        y = torch.ones(2).to(y_from)
+
+        ret = rpc.rpc_sync(dst, fn, args=(x, y, x_to, y_to, z_to))
+
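+        # The response is mapped back through the inverse of device_map, so the
+        # result should land locally on the device that maps to z_to.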
+        reverse_device_map = {device_map[k]: k for k in device_map}
+        z_from = reverse_device_map[z_to]
+
+        ret_device = "cpu" if ret.device.type == "cpu" else ret.device.index
+        self.assertEqual(ret_device, z_from)
+        self.assertEqual(ret, torch.ones(2).to(z_from))
+
+        rpc.shutdown()
+
+    def test_device_map_cpu(self):
+        self._test_device_maps_gpu(
+            x_from="cpu",
+            y_from="cpu",
+            z_to="cpu",
+            device_map={"cpu": "cpu"},
+            fn=TensorPipeAgentCudaRpcTest._gpu_add_given_devices,
+        )
+
+    @skip_if_lt_x_gpu(1)
+    def test_device_map_cpu_to_gpu_default(self):
+        self._test_device_maps_gpu(
+            x_from="cpu",
+            y_from="cpu",
+            z_to=0,
+            device_map={"cpu": 0},
+            fn=TensorPipeAgentCudaRpcTest._gpu_add_given_devices,
+        )
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_cpu_to_gpu_non_default(self):
+        self._test_device_maps_gpu(
+            x_from="cpu",
+            y_from="cpu",
+            z_to=1,
+            device_map={"cpu": 1},
+            fn=TensorPipeAgentCudaRpcTest._gpu_add_given_devices,
+        )
+
+    @skip_if_lt_x_gpu(1)
+    def test_device_map_gpu_to_cpu_default(self):
+        self._test_device_maps_gpu(
+            x_from=0,
+            y_from=0,
+            z_to="cpu",
+            device_map={0: "cpu"},
+            fn=TensorPipeAgentCudaRpcTest._gpu_add_given_devices,
+        )
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_to_cpu_non_default(self):
+        self._test_device_maps_gpu(
+            x_from=1,
+            y_from=1,
+            z_to="cpu",
+            device_map={1: "cpu"},
+            fn=TensorPipeAgentCudaRpcTest._gpu_add_given_devices,
+        )
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_default(self):
+        self._test_device_maps_gpu(x_from=0, y_from=0, z_to=0, device_map={0: 0})
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_non_default(self):
+        self._test_device_maps_gpu(x_from=1, y_from=1, z_to=1, device_map={1: 1})
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_default_to_non_default(self):
+        self._test_device_maps_gpu(x_from=0, y_from=0, z_to=1, device_map={0: 1})
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_non_default_to_default(self):
+        self._test_device_maps_gpu(x_from=1, y_from=1, z_to=0, device_map={1: 0})
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_1(self):
+        self._test_device_maps_gpu(x_from=0, y_from=1, z_to=0, device_map={0: 0, 1: 1})
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_2(self):
+        self._test_device_maps_gpu(x_from=0, y_from=1, z_to=1, device_map={0: 0, 1: 1})
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_3(self):
+        self._test_device_maps_gpu(x_from=1, y_from=0, z_to=0, device_map={0: 0, 1: 1})
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_4(self):
+        self._test_device_maps_gpu(x_from=1, y_from=0, z_to=1, device_map={0: 0, 1: 1})
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_5(self):
+        self._test_device_maps_gpu(x_from=0, y_from=1, z_to=0, device_map={0: 1, 1: 0})
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_6(self):
+        self._test_device_maps_gpu(x_from=0, y_from=1, z_to=1, device_map={0: 1, 1: 0})
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_7(self):
+        self._test_device_maps_gpu(x_from=1, y_from=0, z_to=0, device_map={0: 1, 1: 0})
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_8(self):
+        self._test_device_maps_gpu(x_from=1, y_from=0, z_to=1, device_map={0: 1, 1: 0})
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_self_1(self):
+        self._test_device_maps_gpu(
+            x_from=0,
+            y_from=1,
+            z_to=0,
+            device_map={0: 0, 1: 1},
+            dst=worker_name(self.rank),
+        )
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_self_2(self):
+        self._test_device_maps_gpu(
+            x_from=0,
+            y_from=1,
+            z_to=1,
+            device_map={0: 0, 1: 1},
+            dst=worker_name(self.rank),
+        )
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_self_3(self):
+        self._test_device_maps_gpu(
+            x_from=1,
+            y_from=0,
+            z_to=0,
+            device_map={0: 0, 1: 1},
+            dst=worker_name(self.rank),
+        )
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_self_4(self):
+        self._test_device_maps_gpu(
+            x_from=1,
+            y_from=0,
+            z_to=1,
+            device_map={0: 0, 1: 1},
+            dst=worker_name(self.rank),
+        )
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_self_5(self):
+        self._test_device_maps_gpu(
+            x_from=0,
+            y_from=1,
+            z_to=0,
+            device_map={0: 1, 1: 0},
+            dst=worker_name(self.rank),
+        )
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_self_6(self):
+        self._test_device_maps_gpu(
+            x_from=0,
+            y_from=1,
+            z_to=1,
+            device_map={0: 1, 1: 0},
+            dst=worker_name(self.rank),
+        )
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_self_7(self):
+        self._test_device_maps_gpu(
+            x_from=1,
+            y_from=0,
+            z_to=0,
+            device_map={0: 1, 1: 0},
+            dst=worker_name(self.rank),
+        )
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_mixed_self_8(self):
+        self._test_device_maps_gpu(
+            x_from=1,
+            y_from=0,
+            z_to=1,
+            device_map={0: 1, 1: 0},
+            dst=worker_name(self.rank),
+        )
+
+    @staticmethod
+    def _gpu_add_multi_gpu(x, y):
+        if all([x.is_cuda, x.device.index == 1, y.is_cuda, y.device.index == 0]):
+            return x.to(0) + y, x - y.to(1)
+        else:
+            raise ValueError("Wrong device affinity")
+
+    def _test_device_maps_multi_gpu(self, dst):
+        options = self.rpc_backend_options
+        options.set_device_map(dst, {0: 1})
+        options.set_device_map(dst, {1: 0})
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        x = torch.zeros(2).to(0)
+        y = torch.ones(2).to(1)
+        rets = rpc.rpc_sync(
+            dst, TensorPipeAgentCudaRpcTest._gpu_add_multi_gpu, args=(x, y)
+        )
+
+        self.assertEqual(rets[0].device, torch.device(1))
+        self.assertEqual(rets[1].device, torch.device(0))
+        self.assertEqual(rets[0], (torch.zeros(2) + torch.ones(2)).to(1))
+        self.assertEqual(rets[1], (torch.zeros(2) - torch.ones(2)).to(0))
+        rpc.shutdown()
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_maps_multi_gpu(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        self._test_device_maps_multi_gpu(dst)
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_maps_multi_gpu_self(self):
+        dst = worker_name(self.rank)
+        self._test_device_maps_multi_gpu(dst)
+
+    @staticmethod
+    def _gpu_add_return_to_gpu(x, y):
+        if x.device.type == "cpu" and y.device.type == "cpu":
+            return (x + y).to(0), (x - y).to(1), (x * y).to(2), (x / y).to(3)
+        else:
+            raise ValueError("Wrong device affinity")
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_maps_in_options(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        options = self.rpc_backend_options
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=rpc.TensorPipeRpcBackendOptions(
+                init_method=options.init_method,
+                num_worker_threads=options.num_worker_threads,
+                device_maps={dst: {0: 1, 1: 0}},
+                _transports=tp_transports(),
+            ),
+        )
+
+        rets = rpc.rpc_sync(
+            dst,
+            TensorPipeAgentCudaRpcTest._gpu_add_multi_gpu,
+            args=(torch.zeros(2).to(0), torch.ones(2).to(1)),
+        )
+        self.assertEqual(rets[0].device, torch.device(1))
+        self.assertEqual(rets[1].device, torch.device(0))
+        self.assertEqual(rets[0], (torch.zeros(2) + torch.ones(2)).to(1))
+        self.assertEqual(rets[1], (torch.zeros(2) - torch.ones(2)).to(0))
+        rpc.shutdown()
+
+    def _test_device_maps_return_to_gpu(self, dst):
+        options = self.rpc_backend_options
+
+        options.set_device_map(dst, {0: 1})
+        options.set_device_map(dst, {1: 2})
+        options.set_device_map(dst, {2: 3})
+        options.set_device_map(dst, {3: 0})
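+        # set_device_map merges entries for the same destination, so these calls
+        # build a single cyclic map {0: 1, 1: 2, 2: 3, 3: 0}.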
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        rets = rpc.rpc_sync(
+            dst,
+            TensorPipeAgentCudaRpcTest._gpu_add_return_to_gpu,
+            args=(torch.zeros(2), torch.ones(2)),
+        )
+        for i in range(len(rets)):
+            self.assertEqual(rets[i].device, torch.device((3 + i) % 4))
+        self.assertEqual(rets[0], (torch.zeros(2) + torch.ones(2)).to(3))
+        self.assertEqual(rets[1], (torch.zeros(2) - torch.ones(2)).to(0))
+        self.assertEqual(rets[2], (torch.zeros(2) * torch.ones(2)).to(1))
+        self.assertEqual(rets[3], (torch.zeros(2) / torch.ones(2)).to(2))
+        rpc.shutdown()
+
+    @skip_if_lt_x_gpu(4)
+    def test_device_maps_return_to_gpu(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        self._test_device_maps_return_to_gpu(dst)
+
+    @skip_if_lt_x_gpu(4)
+    def test_device_maps_return_to_gpu_self(self):
+        dst = worker_name(self.rank)
+        self._test_device_maps_return_to_gpu(dst)
+
+    @staticmethod
+    def _add_to_gpu(x, y):
+        return (x + y).to(0)
+
+    def _test_device_maps_missing_config(self, mode):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        errMsg = (
+            "TensorPipe RPC backend only supports CPU tensors by default.*"
+            "`set_device_map` on `TensorPipeRpcBackendOptions`"
+        )
+
+        with self.assertRaisesRegex(RuntimeError, errMsg):
+            if mode == RPCExecMode.SYNC:
+                rpc.rpc_sync(dst, torch.add, args=(torch.zeros(2).to(0), 1))
+            elif mode == RPCExecMode.REMOTE:
+                rpc.remote(dst, torch.add, args=(torch.zeros(2).to(0), 1)).to_here()
+            else:
+                raise ValueError(f"unexpected mode {mode}")
+
+        # make sure RPC is still functioning
+        ret = rpc.rpc_sync(dst, torch.add, args=(torch.ones(2), 1))
+        self.assertEqual(ret, torch.ones(2) + 1)
+
+    def _test_device_maps_missing_config_response(self, mode):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        errMsg = "Response device mapping is not available"
+
+        with self.assertRaisesRegex(RuntimeError, errMsg):
+            if mode == RPCExecMode.SYNC:
+                rpc.rpc_sync(
+                    dst,
+                    TensorPipeAgentCudaRpcTest._add_to_gpu,
+                    args=(torch.zeros(2), 1),
+                )
+            elif mode == RPCExecMode.REMOTE:
+                rpc.remote(
+                    dst,
+                    TensorPipeAgentCudaRpcTest._add_to_gpu,
+                    args=(torch.zeros(2), 1),
+                ).to_here()
+            else:
+                raise ValueError(f"unexpected mode {mode}")
+
+        # make sure RPC is still functioning
+        ret = rpc.rpc_sync(dst, torch.add, args=(torch.ones(2), 1))
+        self.assertEqual(ret, torch.ones(2) + 1)
+
+    @skip_if_lt_x_gpu(1)
+    @dist_init
+    def test_device_maps_missing_config(self):
+        self._test_device_maps_missing_config(RPCExecMode.SYNC)
+
+    @skip_if_lt_x_gpu(1)
+    def test_device_maps_missing_config_not_timeout(self):
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+
+        timeout = rpc.get_rpc_timeout()
+
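+        # The missing-config error should be raised immediately; verify the failing
+        # call returns well before the configured RPC timeout rather than hanging.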
+        tik = time.time()
+        self._test_device_maps_missing_config(RPCExecMode.SYNC)
+        rpc.shutdown()
+        tok = time.time()
+
+        self.assertTrue(tok - tik < timeout)
+
+    @skip_if_lt_x_gpu(1)
+    @dist_init
+    def test_device_maps_missing_config_loop(self):
+        for _ in range(self.rpc_backend_options.num_worker_threads + 5):
+            self._test_device_maps_missing_config(RPCExecMode.SYNC)
+
+    @skip_if_lt_x_gpu(1)
+    @dist_init
+    def test_device_maps_missing_config_response(self):
+        self._test_device_maps_missing_config_response(RPCExecMode.SYNC)
+
+    @skip_if_lt_x_gpu(1)
+    @dist_init
+    def test_device_maps_missing_config_response_loop(self):
+        for _ in range(self.rpc_backend_options.num_worker_threads + 5):
+            self._test_device_maps_missing_config_response(RPCExecMode.SYNC)
+
+    @skip_if_lt_x_gpu(1)
+    @dist_init
+    def test_device_maps_missing_config_remote(self):
+        self._test_device_maps_missing_config(RPCExecMode.REMOTE)
+
+    @skip_if_lt_x_gpu(1)
+    @dist_init
+    def test_device_maps_missing_config_remote_response(self):
+        self._test_device_maps_missing_config_response(RPCExecMode.REMOTE)
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_maps_remote(self):
+        options = self.rpc_backend_options
+        dst = worker_name((self.rank + 1) % self.world_size)
+        options.set_device_map(dst, {1: 0})
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        rref = rpc.remote(
+            dst, TensorPipeAgentCudaRpcTest._add_to_gpu, args=(torch.zeros(2), 1)
+        )
+
+        self.assertEqual(rref.to_here().device.index, 1)
+        self.assertEqual(rref.to_here(), torch.ones(2).to(1))
+
+        rpc.shutdown()
+
+    @staticmethod
+    def _slow_add_on_user_stream(x, y):
+        s0 = torch.cuda.current_stream(x.device)
+        s1 = torch.cuda.Stream(device=x.device)
+        s1.wait_stream(s0)
+        x.record_stream(s1)
+        y.record_stream(s1)
+        with torch.cuda.stream(s1):
+            torch.cuda._sleep(10 * FIFTY_MIL_CYCLES)
+            z = x + y
+        s0.wait_stream(s1)
+        z.record_stream(s0)
+        return z
+
+    def _test_custom_stream(self, fn, device_map):
+        options = self.rpc_backend_options
+        dst = worker_name((self.rank + 1) % self.world_size)
+        options.set_device_map(dst, device_map)
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        fn(dst)
+
+        rpc.shutdown()
+
+    def _test_stream_sync(self, dst):
+        x = torch.ones(2, 2).to(0)
+        ret = rpc.rpc_sync(
+            dst, TensorPipeAgentCudaRpcTest._slow_add_on_user_stream, args=(x, x)
+        )
+        self.assertEqual(ret, 2 * x)
+
+    @skip_if_lt_x_gpu(2)
+    def test_custom_stream(self):
+        self._test_custom_stream(self._test_stream_sync, {"cuda:0": "cuda:1"})
+
+    def _test_stream_multi_async(self, dst):
+        futs = []
+        for i in range(20):
+            x = torch.ones(2, 2).to(0) * i
+            futs.append(
+                rpc.rpc_async(
+                    dst,
+                    TensorPipeAgentCudaRpcTest._slow_add_on_user_stream,
+                    args=(x, x),
+                )
+            )
+
+        for i in range(20):
+            self.assertEqual(futs[i].wait(), 2 * torch.ones(2, 2).to(0) * i)
+
+    @skip_if_lt_x_gpu(2)
+    def test_custom_stream_multi(self):
+        self._test_custom_stream(self._test_stream_multi_async, {"cuda:0": "cuda:1"})
+
+    @staticmethod
+    def _nested_slow_add_on_user_stream(dst, x, y, z):
+        ret = rpc.rpc_sync(
+            dst, TensorPipeAgentCudaRpcTest._slow_add_on_user_stream, args=(x, y)
+        )
+
+        return TensorPipeAgentCudaRpcTest._slow_add_on_user_stream(ret, z)
+
+    def _test_stream_nested_sync(self, dst):
+        x = torch.ones(2, 2).to(0)
+        y = torch.ones(2, 2).to(0) * 2
+        z = torch.ones(2, 2).to(0) * 3
+        nested_dst = worker_name((self.rank + 2) % self.world_size)
+        ret = rpc.rpc_sync(
+            dst,
+            TensorPipeAgentCudaRpcTest._nested_slow_add_on_user_stream,
+            args=(nested_dst, x, y, z),
+        )
+        self.assertEqual(ret, 6 * x)
+
+    @skip_if_lt_x_gpu(2)
+    def test_custom_stream_nested(self):
+        self._test_custom_stream(
+            self._test_stream_nested_sync, {"cuda:0": "cuda:1", "cuda:1": "cuda:0"}
+        )
+
+    def _test_stream_nested_multi_async(self, dst):
+        if self.rank == 0:
+            futs = []
+            n = 5
+            xs, ys, zs = [], [], []
+            for i in range(n):
+                x = torch.ones(2, 2).to(0) * (i - 1)
+                y = torch.ones(2, 2).to(0) * i
+                z = torch.ones(2, 2).to(0) * (i + 1)
+                xs.append(x)
+                ys.append(y)
+                zs.append(z)
+                nested_dst = worker_name((self.rank + 2) % self.world_size)
+                futs.append(
+                    rpc.rpc_async(
+                        dst,
+                        TensorPipeAgentCudaRpcTest._nested_slow_add_on_user_stream,
+                        args=(nested_dst, x, y, z),
+                    )
+                )
+
+            for i in range(n):
+                self.assertEqual(futs[i].wait(), xs[i] + ys[i] + zs[i])
+
+    @skip_if_lt_x_gpu(2)
+    def test_custom_stream_nested_multi(self):
+        self._test_custom_stream(
+            self._test_stream_nested_multi_async,
+            {"cuda:0": "cuda:1", "cuda:1": "cuda:0"},
+        )
+
+    @staticmethod
+    def _gpu_add_wrong_gpus(x, y):
+        if x.is_cuda and y.is_cuda:
+            return x.cpu() + y.cuda()
+        else:
+            raise ValueError("Wrong device affinity")
+
+    @skip_if_lt_x_gpu(1)
+    def test_device_mismatch(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        options = self.rpc_backend_options
+        options.set_device_map(dst, {0: 0})
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        x = torch.zeros(2).to(0)
+        y = torch.ones(2).to(0)
+
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "Expected all tensors to be on the same device, but found at least two devices",
+        ):
+            rpc.rpc_sync(
+                dst, TensorPipeAgentCudaRpcTest._gpu_add_wrong_gpus, args=(x, y)
+            )
+
+        rpc.shutdown()
+
+    def _test_rref_synchronization(self, local_device, remote_device):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        options = self.rpc_backend_options
+        options.set_device_map(dst, {local_device: remote_device})
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        if self.rank == 1:
+            # This test compares rref.rpc_sync().forward(x) vs rref.remote().forward(x).to_here()
+            # If to_here() is properly synchronized with forward(x) the results must be identical
+            # This test needs multiple iterations and significant batch size to simulate real
+            # training of a CNN on MNIST-like data.
+            # see https://github.com/pytorch/pytorch/issues/54771
+            rref = rpc.remote(dst, MyConvNetForMNIST, args=(remote_device,))
+            for _ in range(10):
+                x = torch.randn(200, 1, 28, 28).to(local_device)
+                actual = rref.remote().forward(x).to_here()
+                expected = rref.rpc_sync().forward(x)
+                self.assertEqual(actual, expected)
+
+        rpc.shutdown()
+
+    @skip_if_lt_x_gpu(1)
+    def test_rref_to_here_synchronization1(self):
+        self._test_rref_synchronization("cuda:0", "cuda:0")
+
+    @skip_if_lt_x_gpu(2)
+    def test_rref_to_here_synchronization2(self):
+        self._test_rref_synchronization("cuda:1", "cuda:0")
+
+    @skip_if_lt_x_gpu(2)
+    def test_rref_to_here_synchronization3(self):
+        self._test_rref_synchronization("cuda:1", "cuda:1")
+
+    @skip_if_lt_x_gpu(2)
+    def test_rref_to_here_synchronization4(self):
+        self._test_rref_synchronization("cuda:0", "cuda:1")
+
+    def _test_rref_as_arg_synchronization(
+        self, local_device, remote_device, devicesOptions=None
+    ):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        options = self.rpc_backend_options
+        options.set_device_map(dst, {local_device: remote_device})
+
+        input_src = worker_name((self.rank - 1 + self.world_size) % self.world_size)
+        options.set_device_map(input_src, {remote_device: local_device})
+
+        if devicesOptions is not None:
+            options.set_devices(devicesOptions[self.rank])
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        if self.rank == 1:
+            # This test compares rref.rpc_sync().forward(x) vs rref.remote().forward(x).to_here()
+            # If to_here() is properly synchronized with forward(x) the results must be identical
+            # This test needs multiple iterations and significant batch size to simulate real
+            # training of a CNN on MNIST-like data.
+            # see https://github.com/pytorch/pytorch/issues/54771
+            rref = rpc.remote(dst, MyConvNetForMNIST, args=(remote_device,))
+            for _ in range(10):
+                rref_x = RRef(torch.randn(200, 1, 28, 28).to(local_device))
+                actual = rref.remote().forward(rref_x, True).to_here()
+                expected = rref.rpc_sync().forward(rref_x, True)
+                self.assertEqual(actual, expected)
+
+        rpc.shutdown()
+
+    @skip_if_lt_x_gpu(1)
+    def test_rref_as_arg_synchronization1(self):
+        self._test_rref_as_arg_synchronization("cuda:0", "cuda:0")
+
+    @skip_if_lt_x_gpu(2)
+    def test_rref_as_arg_synchronization2(self):
+        self._test_rref_as_arg_synchronization("cuda:1", "cuda:0")
+
+    @skip_if_lt_x_gpu(2)
+    def test_rref_as_arg_synchronization3(self):
+        self._test_rref_as_arg_synchronization("cuda:1", "cuda:1")
+
+    @skip_if_lt_x_gpu(2)
+    def test_rref_as_arg_synchronization4(self):
+        self._test_rref_as_arg_synchronization("cuda:0", "cuda:1")
+
+    @skip_if_lt_x_gpu(1)
+    def test_rref_as_arg_synchronization5(self):
+        self._test_rref_as_arg_synchronization(
+            "cuda:0",
+            "cuda:0",
+            [["cuda:0"] for _ in range(4)],  # devicesOptions
+        )
+
+    @staticmethod
+    def _rref_relay(rref):
+        return rref.to_here()
+
+    def _test_rref_forward_synchronization(self, local_device, remote_device):
+        options = self.rpc_backend_options
+
+        input_src = worker_name(0)
+        model_dst = worker_name(1)
+        out_relay = worker_name(2)
+
+        if self.rank == 0:
+            # for 1) model construction and 2) forward execution
+            options.set_device_map(model_dst, {local_device: remote_device})
+
+            # The forward output will first be copied to the relay node before
+            # returning to this worker. This is intentional, to test RRef
+            # forward CUDA stream synchronization.
+            options.set_device_map(out_relay, {local_device: local_device})
+        elif self.rank == 1:
+            # worker1 hosts the model and runs forward. The forward function
+            # calls RRef.to_here(), hence worker1 needs to configure the device map.
+            options.set_device_map(input_src, {remote_device: local_device})
+        elif self.rank == 2:
+            # worker2 will get the output RRef and call to_here() on it, hence it
+            # needs to configure the device map.
+            options.set_device_map(model_dst, {local_device: remote_device})
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        if self.rank == 0:
+            # This test compares rref.rpc_sync().forward(x) vs rref.remote().forward(x).to_here().
+            # If to_here() is properly synchronized with forward(x), the results must be identical.
+            # This test needs multiple iterations and a significant batch size to simulate real
+            # training of a CNN on MNIST-like data.
+            # See https://github.com/pytorch/pytorch/issues/54771
+            rref = rpc.remote(model_dst, MyConvNetForMNIST, args=(remote_device,))
+            for _ in range(10):
+                rref_input = RRef(torch.randn(200, 1, 28, 28).to(local_device))
+                rref_out = rref.remote().forward(rref_input, True)
+                out = rpc.remote(
+                    out_relay, TensorPipeAgentCudaRpcTest._rref_relay, args=(rref_out,)
+                ).to_here()
+                expected = rref.rpc_sync().forward(rref_input, True)
+                self.assertEqual(out, expected)
+
+        rpc.shutdown()
+
+    @skip_if_lt_x_gpu(1)
+    def test_rref_forward_synchronization1(self):
+        self._test_rref_forward_synchronization("cuda:0", "cuda:0")
+
+    @skip_if_lt_x_gpu(2)
+    def test_rref_forward_synchronization2(self):
+        self._test_rref_forward_synchronization("cuda:0", "cuda:1")
+
+    @skip_if_lt_x_gpu(2)
+    def test_rref_forward_synchronization3(self):
+        self._test_rref_forward_synchronization("cuda:1", "cuda:0")
+
+    @skip_if_lt_x_gpu(2)
+    def test_rref_forward_synchronization4(self):
+        self._test_rref_forward_synchronization("cuda:1", "cuda:1")
+
+    def _test_owner_rref_forward_synchronization(self, local_device, remote_device):
+        if self.rank == 0:
+            options = self.rpc_backend_options
+            options.set_device_map("w0", {local_device: remote_device})
+            rpc.init_rpc("w0", rank=0, world_size=1, rpc_backend_options=options)
+
+            model = (
+                rpc.remote("w0", torch.nn.Linear, (2048, 20000))
+                .remote()
+                .to(remote_device)
+            )
+            for _ in range(30):
+                data = torch.rand(2048, 2048).to(local_device)
+                output = model.rpc_sync().forward(data)
+                # to_here() internally calls localValue as the caller is
+                # the owner of the RRef.
+                v0 = rpc.RRef(output).remote().sum().to_here().item()
+                v1 = output.sum().item()
+                self.assertEqual(v0, v1)
+
+            rpc.shutdown()
+
+    @skip_if_lt_x_gpu(1)
+    def test_owner_rref_forward_synchronization1(self):
+        self._test_owner_rref_forward_synchronization("cuda:0", "cuda:0")
+
+    @skip_if_lt_x_gpu(2)
+    def test_owner_rref_forward_synchronization2(self):
+        self._test_owner_rref_forward_synchronization("cuda:0", "cuda:1")
+
+    @skip_if_lt_x_gpu(2)
+    def test_owner_rref_forward_synchronization3(self):
+        self._test_owner_rref_forward_synchronization("cuda:1", "cuda:0")
+
+    @skip_if_lt_x_gpu(2)
+    def test_owner_rref_forward_synchronization4(self):
+        self._test_owner_rref_forward_synchronization("cuda:1", "cuda:1")
+
+    @staticmethod
+    def _return_tensor_view(i):
+        x = torch.ones(1000, 200).cuda(0) * i
+        torch.cuda._sleep(10 * FIFTY_MIL_CYCLES)
+        # serialization of the return value will create a new tensor from the
+        # view, which is done outside of the user function.
+        return x.split(100)[0]
+
+    @skip_if_lt_x_gpu(1)
+    def test_tensor_view_as_return_value(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        options = self.rpc_backend_options
+        options.set_device_map(dst, {0: 0})
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        futs = [
+            rpc.rpc_async(
+                dst, TensorPipeAgentCudaRpcTest._return_tensor_view, args=(i,)
+            )
+            for i in range(5)
+        ]
+
+        for i in range(5):
+            self.assertEqual(torch.ones(100, 200) * i, futs[i].wait())
+
+        rpc.shutdown()
+
+    @skip_if_lt_x_gpu(2)
+    def test_devices_option_mismatch(self):
+        with self.assertRaisesRegex(
+            ValueError,
+            "Node worker0 has unexpected source devices in its device map for worker1",
+        ):
+            dst = worker_name((self.rank + 1) % self.world_size)
+            options = self.rpc_backend_options
+            options.set_device_map(dst, {0: 0})
+            options.set_devices([1])
+
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rank=self.rank,
+                world_size=self.world_size,
+                rpc_backend_options=options,
+            )
+
+            rpc.shutdown()
+
+    @skip_if_lt_x_gpu(2)
+    def test_devices_option_mismatch_reverse(self):
+        with self.assertRaisesRegex(
+            ValueError,
+            "Node worker0 has unexpected target devices in its device map for worker1",
+        ):
+            dst = worker_name((self.rank + 1) % self.world_size)
+
+            options = rpc.TensorPipeRpcBackendOptions(
+                init_method=self.rpc_backend_options.init_method,
+                num_worker_threads=self.rpc_backend_options.num_worker_threads,
+                device_maps={dst: {0: 1}},
+                devices=[0],
+            )
+
+            rpc.init_rpc(
+                name=worker_name(self.rank),
+                backend=self.rpc_backend,
+                rank=self.rank,
+                world_size=self.world_size,
+                rpc_backend_options=options,
+            )
+
+            rpc.shutdown()
+
+    @skip_if_lt_x_gpu(1)
+    def test_cuda_future_device_as_int(self):
+        Future(devices=[0])
+
+    @skip_if_lt_x_gpu(1)
+    def test_cuda_future_device_as_str(self):
+        Future(devices=["cuda:0"])
+
+    @skip_if_lt_x_gpu(1)
+    def test_cuda_future_device_as_device(self):
+        Future(devices=[torch.device("cuda", 0)])
+
+    @skip_if_lt_x_gpu(1)
+    def test_cuda_future_device_not_cuda(self):
+        with self.assertRaisesRegex(
+            ValueError, "Expected devices to have indices, got cpu"
+        ):
+            Future(devices=["cpu"])
+
+    @skip_if_lt_x_gpu(1)
+    def test_cuda_future_can_extract_cuda_tensor(self):
+        self._test_cuda_future_extraction(
+            wrapper=lambda t: t, unwrapper=lambda v: v, sparse_tensor=False
+        )
+
+    @skip_if_lt_x_gpu(1)
+    def test_cuda_future_can_extract_list_with_cuda_tensor(self):
+        self._test_cuda_future_extraction(
+            wrapper=lambda t: [t], unwrapper=operator.itemgetter(0), sparse_tensor=False
+        )
+
+    @skip_if_lt_x_gpu(1)
+    def test_cuda_future_can_extract_custom_class_with_cuda_tensor(self):
+        self._test_cuda_future_extraction(
+            wrapper=TensorWrapper, unwrapper=lambda v: v.tensor, sparse_tensor=False
+        )
+
+    @skip_if_lt_x_gpu(2)
+    def test_cuda_future_callback_changes_devices(self):
+        # We check proper CUDA stream synchronization by filling the tensor with
+        # the expected value in one stream, and reading it from another stream.
+        tensor0 = torch.zeros((100,), device="cuda:0")
+        tensor1 = torch.zeros((100,), device="cuda:1")
+        parent_future = Future(devices=["cuda:0", "cuda:1"])
+
+        def cb(fut):
+            t0 = fut.value()
+            tensor1.copy_(t0, non_blocking=True)
+            return tensor1
+
+        child_future = parent_future.then(cb)
+        with torch.cuda.device("cuda:0"):
+            stream = torch.cuda.Stream()
+            with torch.cuda.stream(stream):
+                torch.cuda._sleep(int(1000 * get_cycles_per_ms()))
+                tensor0.fill_(1)
+                parent_future.set_result(tensor0)
+        with torch.cuda.device("cuda:1"):
+            another_stream = torch.cuda.Stream()
+            with torch.cuda.stream(another_stream):
+                self.assertTrue(torch.eq(child_future.wait(), 1).all().item())
+
+    @skip_if_lt_x_gpu(2)
+    def test_cuda_future_value_on_bad_device(self):
+        tensor0 = torch.zeros((100,), device="cuda:0")
+        tensor1 = torch.zeros((100,), device="cuda:1")
+        parent_future = Future(devices=["cuda:1"])
+
+        # As a plus, we test that futures still invoke callbacks even in case of
+        # error, and that the child futures are successful if those callbacks
+        # don't access the parent future.
+        def cb(fut):
+            with torch.cuda.device("cuda:1"):
+                torch.cuda._sleep(int(1000 * get_cycles_per_ms()))
+                tensor1.fill_(1)
+                return tensor1
+
+        child_future = parent_future.then(cb)
+        with torch.cuda.device("cuda:0"):
+            stream = torch.cuda.Stream()
+            with torch.cuda.stream(stream):
+                torch.cuda._sleep(int(1000 * get_cycles_per_ms()))
+                tensor0.fill_(1)
+                parent_future.set_result(tensor0)
+        with self.assertRaisesRegex(
+            ValueError,
+            r"The result contained tensors residing on device\(s\) cuda:0 "
+            r"which are not among the expected device\(s\) cuda:1",
+        ):
+            parent_future.wait()
+        with torch.cuda.device("cuda:1"):
+            another_stream = torch.cuda.Stream()
+            with torch.cuda.stream(another_stream):
+                self.assertTrue(torch.eq(child_future.wait(), 1).all().item())
+
+    @skip_if_lt_x_gpu(1)
+    def test_async_execution_with_cuda_future(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        options = self.rpc_backend_options
+        options.set_device_map(dst, {"cuda:0": "cuda:0"})
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        t = torch.zeros((100,), device="cuda:0")
+        fut = rpc.rpc_async(dst, async_cuda_sleep_and_set_to_one, args=(t,))
+        another_stream = torch.cuda.Stream("cuda:0")
+        with torch.cuda.stream(another_stream):
+            self.assertTrue(torch.eq(fut.wait(), 1).all().item())
+
+        rpc.shutdown()
+
+    @skip_if_lt_x_gpu(1)
+    def test_async_execution_nested_with_cuda_future(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        nested_dst = worker_name((self.rank + 2) % self.world_size)
+        options = self.rpc_backend_options
+        options.set_device_map(dst, {"cuda:0": "cuda:0"})
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        a = torch.ones((100,), device="cuda:0")
+        b = torch.ones((100,), device="cuda:0")
+        c = torch.ones((100,), device="cuda:0")
+        fut = rpc.rpc_async(dst, async_cuda_nested_add, args=(nested_dst, a, b, c))
+        another_stream = torch.cuda.Stream("cuda:0")
+        with torch.cuda.stream(another_stream):
+            self.assertTrue(torch.eq(fut.wait(), 3).all().item())
+
+        rpc.shutdown()
+
+    @skip_if_lt_x_gpu(1)
+    def test_cuda_future_modify_tensor_inplace(self):
+        tensor = torch.zeros((100,), device="cuda:0")
+        future = Future(devices=["cuda:0"])
+        future.set_result(tensor)
+        # It's weird to modify the value of a future once it's complete, but
+        # technically possible. Currently this is considered undefined behavior
+        # (in practice the future will ignore the modification and still
+        # synchronize with the original value). We could one day add logic to
+        # detect and warn or throw in such cases, but for now we just check that
+        # this doesn't crash.
+        tensor.fill_(1)
+        future.wait()
+
+    @skip_if_lt_x_gpu(1)
+    def test_cuda_future_replace_tensor(self):
+        tensor_list = [torch.zeros((100,), device="cuda:0")]
+        future = Future(devices=["cuda:0"])
+        future.set_result(tensor_list)
+        # It's weird to modify the value of a future once it's complete, but
+        # technically possible. Currently this is considered undefined behavior
+        # (in practice the future will ignore the modification and still
+        # synchronize with the original value). We could one day add logic to
+        # detect and warn or throw in such cases, but for now we just check that
+        # this doesn't crash.
+        # We set things up so that the original tensor contained in the list
+        # gets deleted once we replace it with the other one. This will
+        # invalidate any cached information held by the future.
+        tensor_list[0] = torch.ones((100,), device="cuda:0")
+        future.wait()
+
+    @skip_if_lt_x_gpu(1)
+    def test_rref_with_unpickleable_attributes(self):
+        dst = worker_name((self.rank + 1) % self.world_size)
+        options = self.rpc_backend_options
+        options.set_device_map(dst, {"cuda:0": "cuda:0"})
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        rref = rpc.remote(dst, TensorWrapper, args=(torch.zeros(42, device="cuda:0"),))
+        rref.rpc_sync().increase(1)
+        ret = rref.rpc_sync().sum()
+        self.assertEqual(ret, 42)
+
+        rpc.shutdown()
+
+    @skip_if_lt_x_gpu(1)
+    def test_cuda_future_can_extract_cuda_sparse_tensor(self):
+        self._test_cuda_future_extraction(
+            wrapper=lambda t: t, unwrapper=lambda v: v, sparse_tensor=True
+        )
+
+    @skip_if_lt_x_gpu(1)
+    def test_cuda_future_can_extract_list_with_cuda_sparse_tensor(self):
+        self._test_cuda_future_extraction(
+            wrapper=lambda t: [t], unwrapper=operator.itemgetter(0), sparse_tensor=True
+        )
+
+    @skip_if_lt_x_gpu(1)
+    def test_cuda_future_can_extract_custom_class_with_cuda_sparse_tensor(self):
+        self._test_cuda_future_extraction(
+            wrapper=TensorWrapper, unwrapper=lambda v: v.tensor, sparse_tensor=True
+        )
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/tensorpipe_rpc_agent_test_fixture.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/tensorpipe_rpc_agent_test_fixture.py
new file mode 100644
index 0000000000000000000000000000000000000000..021ae60468009d2fd4fa947c90455d99c1c6d54e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc/tensorpipe_rpc_agent_test_fixture.py
@@ -0,0 +1,28 @@
+# mypy: allow-untyped-defs
+
+import torch.distributed.rpc as rpc
+from torch.testing._internal.common_distributed import tp_transports
+from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
+    RpcAgentTestFixture,
+)
+
+
+class TensorPipeRpcAgentTestFixture(RpcAgentTestFixture):
+    @property
+    def rpc_backend(self):
+        return rpc.backend_registry.BackendType["TENSORPIPE"]
+
+    @property
+    def rpc_backend_options(self):
+        return rpc.backend_registry.construct_rpc_backend_options(
+            self.rpc_backend, init_method=self.init_method, _transports=tp_transports()
+        )
+
+    def get_shutdown_error_regex(self):
+        # FIXME Once we consolidate the error messages returned by the
+        # TensorPipe agent, put a more specific regex here.
+        error_regexes = [".*"]
+        return "|".join([f"({error_str})" for error_str in error_regexes])
+
+    def get_timeout_error_regex(self):
+        return "RPC ran for more than"
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc_utils.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a24e4f97f05df22396dc08e3e6bc381085477882
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/distributed/rpc_utils.py
@@ -0,0 +1,188 @@
+# mypy: allow-untyped-defs
+
+import os
+import sys
+import unittest
+
+from torch.testing._internal.common_distributed import MultiProcessTestCase
+from torch.testing._internal.common_utils import (
+    find_free_port,
+    IS_SANDCASTLE,
+    TEST_WITH_DEV_DBG_ASAN,
+)
+from torch.testing._internal.distributed.ddp_under_dist_autograd_test import (
+    CudaDdpComparisonTest,
+    DdpComparisonTest,
+    DdpUnderDistAutogradTest,
+)
+from torch.testing._internal.distributed.nn.api.remote_module_test import (
+    CudaRemoteModuleTest,
+    RemoteModuleTest,
+    ThreeWorkersRemoteModuleTest,
+)
+from torch.testing._internal.distributed.rpc.dist_autograd_test import (
+    CudaDistAutogradTest,
+    DistAutogradTest,
+    FaultyAgentDistAutogradTest,
+    TensorPipeAgentDistAutogradTest,
+    TensorPipeCudaDistAutogradTest,
+)
+from torch.testing._internal.distributed.rpc.dist_optimizer_test import (
+    DistOptimizerTest,
+)
+from torch.testing._internal.distributed.rpc.examples.parameter_server_test import (
+    ParameterServerTest,
+)
+from torch.testing._internal.distributed.rpc.examples.reinforcement_learning_rpc_test import (
+    ReinforcementLearningRpcTest,
+)
+from torch.testing._internal.distributed.rpc.faulty_agent_rpc_test import (
+    FaultyAgentRpcTest,
+)
+from torch.testing._internal.distributed.rpc.jit.dist_autograd_test import (
+    JitDistAutogradTest,
+)
+from torch.testing._internal.distributed.rpc.jit.rpc_test import JitRpcTest
+from torch.testing._internal.distributed.rpc.jit.rpc_test_faulty import (
+    JitFaultyAgentRpcTest,
+)
+from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
+    RpcAgentTestFixture,
+)
+from torch.testing._internal.distributed.rpc.rpc_test import (
+    CudaRpcTest,
+    RpcTest,
+    TensorPipeAgentCudaRpcTest,
+    TensorPipeAgentRpcTest,
+)
+
+
+def _check_and_set_tcp_init():
+    # if we are running with TCP init, set main address and port
+    # before spawning subprocesses, since different processes could find
+    # different ports.
+    use_tcp_init = os.environ.get("RPC_INIT_WITH_TCP", None)
+    if use_tcp_init == "1":
+        os.environ["MASTER_ADDR"] = "127.0.0.1"
+        os.environ["MASTER_PORT"] = str(find_free_port())
+
+
+def _check_and_unset_tcp_init():
+    use_tcp_init = os.environ.get("RPC_INIT_WITH_TCP", None)
+    if use_tcp_init == "1":
+        del os.environ["MASTER_ADDR"]
+        del os.environ["MASTER_PORT"]
+
+
+# The tests for the RPC module need to cover multiple possible combinations:
+# - different aspects of the API, each one having its own suite of tests;
+# - different agents (ProcessGroup, TensorPipe, ...);
+# To avoid a combinatorial explosion in code size, and to prevent forgetting to
+# add a combination, these are generated automatically by the code in this file.
+# Here, we collect all the test suites that we need to cover.
+# We then have one separate file per agent, from which we call the
+# generate_tests function of this file, passing it a fixture for that agent,
+# which then gets mixed in with each test suite.
+
+
+@unittest.skipIf(
+    TEST_WITH_DEV_DBG_ASAN,
+    "Skip ASAN as torch + multiprocessing spawn have known issues",
+)
+class SpawnHelper(MultiProcessTestCase):
+    def setUp(self):
+        super().setUp()
+        _check_and_set_tcp_init()
+        self._spawn_processes()
+
+    def tearDown(self):
+        _check_and_unset_tcp_init()
+        super().tearDown()
+
+
+# This list contains test suites that are agent-agnostic and that only verify
+# compliance with the generic RPC interface specification. These tests should
+# *not* make use of implementation details of a specific agent (options,
+# attributes, ...). These test suites will be instantiated multiple times, once
+# for each agent (except the faulty agent, which is special).
+GENERIC_TESTS = [
+    RpcTest,
+    ParameterServerTest,
+    DistAutogradTest,
+    DistOptimizerTest,
+    JitRpcTest,
+    JitDistAutogradTest,
+    RemoteModuleTest,
+    ThreeWorkersRemoteModuleTest,
+    DdpUnderDistAutogradTest,
+    DdpComparisonTest,
+    ReinforcementLearningRpcTest,
+]
+GENERIC_CUDA_TESTS = [
+    CudaRpcTest,
+    CudaDistAutogradTest,
+    CudaRemoteModuleTest,
+    CudaDdpComparisonTest,
+]
+
+
+# This list contains test suites that will only be run on the TensorPipeAgent.
+# These suites should be standalone, and separate from the ones in the generic
+# list (not subclasses of those!).
+TENSORPIPE_TESTS = [
+    TensorPipeAgentRpcTest,
+    TensorPipeAgentDistAutogradTest,
+]
+TENSORPIPE_CUDA_TESTS = [
+    TensorPipeAgentCudaRpcTest,
+    TensorPipeCudaDistAutogradTest,
+]
+
+
+# This list contains test suites that will only be run on the faulty RPC agent.
+# That agent is special as it's only used to perform fault injection in order to
+# verify the error handling behavior. Thus the faulty agent will only run the
+# suites in this list, which were designed to test such behaviors, and not the
+# ones in the generic list.
+FAULTY_AGENT_TESTS = [
+    FaultyAgentRpcTest,
+    FaultyAgentDistAutogradTest,
+    JitFaultyAgentRpcTest,
+]
+
+
+def generate_tests(
+    prefix: str,
+    mixin: type[RpcAgentTestFixture],
+    tests: list[type[RpcAgentTestFixture]],
+    module_name: str,
+) -> dict[str, type[RpcAgentTestFixture]]:
+    """Mix in the classes needed to autogenerate the tests based on the params.
+
+    Takes a series of test suites, each written against a "generic" agent (i.e.,
+    derived from the abstract RpcAgentTestFixture class), as the `tests` arg.
+    Takes a concrete subclass of RpcAgentTestFixture, which specializes it for a
+    certain agent, as the `mixin` arg. Produces all combinations of them.
+    Returns a dictionary of class names to class type
+    objects which can be inserted into the global namespace of the calling
+    module. The name of each test will be a concatenation of the `prefix` arg
+    and the original name of the test suite.
+    The `module_name` should be the name of the calling module so
+    that the classes can be fixed to make it look like they belong to it, which
+    is necessary for pickling to work on them.
+    """
+    ret: dict[str, type[RpcAgentTestFixture]] = {}
+    for test_class in tests:
+        if IS_SANDCASTLE and TEST_WITH_DEV_DBG_ASAN:
+            print(
+                f"Skipping test {test_class} on sandcastle for the following reason: "
+                "Skip dev-asan as torch + multiprocessing spawn have known issues",
+                file=sys.stderr,
+            )
+            continue
+
+        name = f"{prefix}{test_class.__name__}"
+        class_ = type(name, (test_class, mixin, SpawnHelper), {})
+        class_.__module__ = module_name
+        ret[name] = class_
+    return ret
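+
+
+# Illustrative usage sketch (not part of this module's public surface): an
+# agent-specific test module typically consumes generate_tests roughly as
+# follows, mixing its fixture into every suite and injecting the generated
+# classes into its own globals. The suite lists and fixture name below are only
+# an example of the pattern; it is kept as a comment so that importing this
+# helper module has no side effects.
+#
+#   from torch.testing._internal.distributed.rpc.tensorpipe_rpc_agent_test_fixture import (
+#       TensorPipeRpcAgentTestFixture,
+#   )
+#   from torch.testing._internal.distributed.rpc_utils import (
+#       GENERIC_TESTS,
+#       TENSORPIPE_TESTS,
+#       generate_tests,
+#   )
+#
+#   # Produces e.g. "TensorPipeRpcTest", "TensorPipeTensorPipeAgentRpcTest", ...
+#   globals().update(
+#       generate_tests(
+#           "TensorPipe",
+#           TensorPipeRpcAgentTestFixture,
+#           GENERIC_TESTS + TENSORPIPE_TESTS,
+#           __name__,
+#       )
+#   )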
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/dynamo_test_failures.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/dynamo_test_failures.py
new file mode 100644
index 0000000000000000000000000000000000000000..08246cc65132b785233b32c263790c17e5a447de
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/dynamo_test_failures.py
@@ -0,0 +1,144 @@
+# mypy: allow-untyped-defs
+import logging
+import os
+import sys
+
+
+# NOTE: [dynamo_test_failures.py]
+#
+# We generate xFailIfTorchDynamo* for all tests in `dynamo_expected_failures`
+# We generate skipIfTorchDynamo* for all tests in `dynamo_skips`
+# We generate runWithoutCompiledAutograd for all tests in `compiled_autograd_skips`
+#
+# For an easier-than-manual way of generating and updating these lists,
+# see scripts/compile_tests/update_failures.py
+#
+# If you're adding a new test and it fails under PYTORCH_TEST_WITH_DYNAMO=1,
+# either add the appropriate decorators to your test or add an entry for it
+# under test/dynamo_skips or test/dynamo_expected_failures.
+#
+# *These are not exactly unittest.expectedFailure and unittest.skip. We'll
+# always execute the test and then suppress the signal, if necessary.
+# If your test crashes or is slow, please use @skipIfTorchDynamo instead.
+#
+# The expected failure and skip files are located in test/dynamo_skips and
+# test/dynamo_expected_failures. They're individual files rather than a single
+# list so that git can merge changes more easily.
+#
+# (An illustrative sketch of how these sets are keyed and consulted appears at
+# the end of this file.)
+
+
+def find_test_dir():
+    # Find the path to the dynamo expected failure and skip files.
+    from os.path import abspath, basename, dirname, exists, join, normpath
+
+    if sys.platform == "win32":
+        return None
+
+    # Check relative to this file (local build):
+    test_dir = normpath(join(dirname(abspath(__file__)), "../../../test"))
+    if exists(join(test_dir, "dynamo_expected_failures")):
+        return test_dir
+
+    # Check relative to __main__ (installed builds relative to test file):
+    main = sys.modules["__main__"]
+    file = getattr(main, "__file__", None)
+    if file is None:
+        # Generated files do not have a module.__file__
+        return None
+    test_dir = dirname(abspath(file))
+    while dirname(test_dir) != test_dir:
+        if basename(test_dir) == "test" and exists(
+            join(test_dir, "dynamo_expected_failures")
+        ):
+            return test_dir
+        test_dir = dirname(test_dir)
+
+    # Not found
+    return None
+
+
+test_dir = find_test_dir()
+if not test_dir:
+    logger = logging.getLogger(__name__)
+    logger.warning(
+        "test/dynamo_expected_failures directory not found - known dynamo errors won't be skipped."
+    )
+
+# Tests that run without strict mode in PYTORCH_TEST_WITH_INDUCTOR=1.
+# Please don't add anything to this list.
+FIXME_inductor_non_strict = {
+    "test_modules",
+    "test_ops",
+    "test_ops_gradients",
+    "test_torch",
+}
+
+# We generate unittest.expectedFailure for all of the following tests
+# when run under PYTORCH_TEST_WITH_DYNAMO=1.
+# See NOTE [dynamo_test_failures.py] for more details.
+#
+# This list exists so we can more easily add large numbers of failing tests.
+if test_dir is None:
+    dynamo_expected_failures = set()
+    dynamo_skips = set()
+
+    inductor_expected_failures = set()
+    inductor_skips = set()
+
+    compiled_autograd_skips = set()
+else:
+    dynamo_failures_directory = os.path.join(test_dir, "dynamo_expected_failures")
+    dynamo_skips_directory = os.path.join(test_dir, "dynamo_skips")
+
+    dynamo_expected_failures = set(os.listdir(dynamo_failures_directory))
+    dynamo_skips = set(os.listdir(dynamo_skips_directory))
+
+    inductor_failures_directory = os.path.join(test_dir, "inductor_expected_failures")
+    inductor_skips_directory = os.path.join(test_dir, "inductor_skips")
+
+    inductor_expected_failures = set(os.listdir(inductor_failures_directory))
+    inductor_skips = set(os.listdir(inductor_skips_directory))
+
+    compiled_autograd_skips_directory = os.path.join(
+        test_dir, "compiled_autograd_skips"
+    )
+    compiled_autograd_skips = set(os.listdir(compiled_autograd_skips_directory))
+
+# TODO: due to case sensitivity problems, for now list these files by hand
+extra_dynamo_skips = {
+    "TestProxyTensorOpInfoCPU.test_make_fx_exhaustive_T_cpu_float32",
+    "TestProxyTensorOpInfoCPU.test_make_fx_exhaustive_t_cpu_float32",
+    "TestProxyTensorOpInfoCPU.test_make_fx_fake_exhaustive_T_cpu_float32",
+    "TestProxyTensorOpInfoCPU.test_make_fx_fake_exhaustive_t_cpu_float32",
+    "TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_T_cpu_float32",
+    "TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_t_cpu_float32",
+    "TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_inplace_T_cpu_float32",
+    "TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_inplace_t_cpu_float32",
+    "TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_out_T_cpu_float32",
+    "TestProxyTensorOpInfoCPU.test_make_fx_symbolic_exhaustive_out_t_cpu_float32",
+}
+dynamo_skips = dynamo_skips.union(extra_dynamo_skips)
+
+
+# verify some invariants
+for test in (
+    dynamo_expected_failures
+    | dynamo_skips
+    | inductor_expected_failures
+    | inductor_skips
+):
+    if len(test.split(".")) != 2:
+        raise AssertionError(f'Invalid test name: "{test}"')
+
+dynamo_intersection = dynamo_expected_failures.intersection(dynamo_skips)
+if len(dynamo_intersection) > 0:
+    raise AssertionError(
+        "there should be no overlap between dynamo_expected_failures "
+        "and dynamo_skips, got " + str(dynamo_intersection)
+    )
+
+inductor_intersection = inductor_expected_failures.intersection(inductor_skips)
+if len(inductor_intersection) > 0:
+    raise AssertionError(
+        "there should be no overlap between inductor_expected_failures "
+        "and inductor_skips, got " + str(inductor_intersection)
+    )
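+
+
+# Illustrative sketch only: the sets above are keyed as "TestClassName.test_name"
+# (the invariant checked just above). A hypothetical consumer could classify a
+# test roughly like the helper below; the real decoration logic is applied by
+# the common test harness (torch.testing._internal.common_utils), not here.
+def _example_dynamo_disposition(test_class_name, test_name):
+    # Build the "TestClassName.test_name" key used by the skip/xfail sets.
+    key = f"{test_class_name}.{test_name}"
+    if key in dynamo_skips:
+        # Run the test but suppress its outcome under PYTORCH_TEST_WITH_DYNAMO=1.
+        return "skip"
+    if key in dynamo_expected_failures:
+        # Run the test and expect (and suppress) a failure under dynamo.
+        return "xfail"
+    return "run"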
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/fake_config_module.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/fake_config_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e93c41de72a765415a8dce5d3c98c8cd0cf2c41
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/fake_config_module.py
@@ -0,0 +1,44 @@
+import sys
+from typing import Optional
+
+from torch.utils._config_module import Config, install_config_module
+
+
+e_bool = True
+e_int = 1
+e_float = 1.0
+e_string = "string"
+e_list = [1]
+e_set = {1}
+e_tuple = (1,)
+e_dict = {1: 2}
+e_none: Optional[bool] = None
+e_optional: Optional[bool] = True
+e_ignored = True
+_e_ignored = True
+magic_cache_config_ignored = True
+# [@compile_ignored: debug]
+e_compile_ignored = True
+e_config: bool = Config(default=True)
+e_jk: bool = Config(justknob="does_not_exist", default=True)
+e_jk_false: bool = Config(justknob="does_not_exist", default=False)
+e_env_default: bool = Config(env_name_default="ENV_TRUE", default=False)
+e_env_default_FALSE: bool = Config(env_name_default="ENV_FALSE", default=True)
+e_env_default_str: bool = Config(env_name_default="ENV_STR", default="default")
+e_env_default_str_empty: bool = Config(
+    env_name_default="ENV_STR_EMPTY", default="default"
+)
+e_env_force: bool = Config(env_name_force="ENV_TRUE", default=False)
+e_aliased_bool: bool = Config(
+    alias="torch.testing._internal.fake_config_module2.e_aliasing_bool"
+)
+
+
+class nested:
+    e_bool = True
+
+
+_cache_config_ignore_prefix = ["magic_cache_config"]
+_save_config_ignore = ["e_ignored"]
+
+install_config_module(sys.modules[__name__])
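+
+
+# Illustrative sketch (assumption: after install_config_module runs, this module
+# behaves like other torch config modules such as torch._dynamo.config): the
+# entries above become plain attributes, and the module-level patch() helper can
+# temporarily override them. Kept as a comment because such usage belongs in a
+# test, not in this fixture module.
+#
+#   from torch.testing._internal import fake_config_module as fake_config
+#
+#   assert fake_config.e_bool is True      # plain attribute read
+#   with fake_config.patch(e_int=2):       # temporary override via context manager
+#       assert fake_config.e_int == 2
+#   assert fake_config.e_int == 1          # original value restored on exit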
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/fake_config_module2.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/fake_config_module2.py
new file mode 100644
index 0000000000000000000000000000000000000000..77c2e2baa4ddca7685adf734809488979c21ab63
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/fake_config_module2.py
@@ -0,0 +1,13 @@
+import sys
+
+from torch.utils._config_module import Config, install_config_module
+
+
+e_aliasing_bool = False
+
+e_env_default_multi: bool = Config(
+    env_name_default=["ENV_TRUE", "ENV_FALSE"], default=False
+)
+e_env_force_multi: bool = Config(env_name_force=["ENV_FAKE", "ENV_TRUE"], default=False)
+
+install_config_module(sys.modules[__name__])
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/fake_config_module3.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/fake_config_module3.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d3d7f15d901aa45fa46aa86422d50a69ced84be
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/fake_config_module3.py
@@ -0,0 +1,11 @@
+import sys
+from typing import Callable, Optional
+
+from torch.utils._config_module import install_config_module
+
+
+e_list = [1]
+e_set = {1}
+e_func: Optional[Callable] = None
+
+install_config_module(sys.modules[__name__])
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/generated/__init__.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/generated/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/generated/annotated_fn_args.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/generated/annotated_fn_args.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7c11addb9199b63209872053002b5946c58af1b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/generated/annotated_fn_args.py
@@ -0,0 +1,2891 @@
+"""
+This file is needed for generating procedural tests required for
+testing __torch_function__. See tests/test_overrides.py.
+"""
+
+# flake8: noqa
+import torch
+
+annotated_args = {
+    torch._C._VariableFunctions._cast_Byte: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._cast_Char: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._cast_Double: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._cast_Float: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._cast_Int: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._cast_Long: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._cast_Short: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._cast_Half: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._make_dual: [{'is_kwarg_only': 'False', 'name': 'primal', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tangent', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'level', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions._unpack_dual: [{'is_kwarg_only': 'False', 'name': 'dual', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'level', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.align_tensors: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._assert_async: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._assert_async: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'assert_msg', 'simple_type': 'c10::string_view'}],
+    torch._C._VariableFunctions._assert_scalar: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'assert_msg', 'simple_type': 'c10::string_view'}],
+    torch._C._VariableFunctions._functional_assert_scalar: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'assert_msg', 'simple_type': 'c10::string_view'}, {'is_kwarg_only': 'False', 'name': 'dep_token', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._functional_assert_async: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'assert_msg', 'simple_type': 'c10::string_view'}, {'is_kwarg_only': 'False', 'name': 'dep_token', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._assert_tensor_metadata: [{'is_kwarg_only': 'False', 'name': 'a', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._print: [{'is_kwarg_only': 'False', 'name': 's', 'simple_type': 'c10::string_view'}],
+    torch._C._VariableFunctions.sym_constrain_range: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.sym_constrain_range_for_size: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions._functional_sym_constrain_range: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'int64_t?'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'int64_t?'}, {'is_kwarg_only': 'False', 'name': 'dep_token', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._functional_sym_constrain_range_for_size: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'int64_t?'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'int64_t?'}, {'is_kwarg_only': 'False', 'name': 'dep_token', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._make_dep_token: [],
+    torch._C._VariableFunctions._use_cudnn_ctc_loss: [{'is_kwarg_only': 'False', 'name': 'log_probs', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'targets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input_lengths', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'target_lengths', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'blank', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions._use_cudnn_ctc_loss: [{'is_kwarg_only': 'False', 'name': 'log_probs', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'targets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input_lengths', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target_lengths', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'blank', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions._cudnn_ctc_loss: [{'is_kwarg_only': 'False', 'name': 'log_probs', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'targets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input_lengths', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'target_lengths', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'blank', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'deterministic', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'zero_infinity', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions._cudnn_ctc_loss: [{'is_kwarg_only': 'False', 'name': 'log_probs', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'targets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input_lengths', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target_lengths', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'blank', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'deterministic', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'zero_infinity', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions._use_cudnn_rnn_flatten_weight: [],
+    torch._C._VariableFunctions._cudnn_rnn_flatten_weight: [{'is_kwarg_only': 'False', 'name': 'weight_arr', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'weight_stride0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'input_size', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'hidden_size', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'proj_size', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'batch_first', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions._cudnn_rnn: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'weight_stride0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'weight_buf', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'cx', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'hidden_size', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'proj_size', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'batch_first', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'batch_sizes', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dropout_state', 'simple_type': 'Tensor?'}],
+    torch._C._VariableFunctions._cudnn_init_dropout_state: [{'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'dropout_seed', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions._debug_has_internal_overlap: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._fused_dropout: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}],
+    torch._C._VariableFunctions._masked_scale: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'double'}],
+    torch._C._VariableFunctions.native_dropout: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool?'}],
+    torch._C._VariableFunctions._sobol_engine_draw: [{'is_kwarg_only': 'False', 'name': 'quasi', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'sobolstate', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dimension', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'num_generated', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dtype', 'simple_type': 'ScalarType?'}],
+    torch._C._VariableFunctions._sobol_engine_ff_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'sobolstate', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dimension', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'num_generated', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions._sobol_engine_scramble_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'ltm', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dimension', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions._sobol_engine_initialize_state_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dimension', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions._reshape_from_tensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'shape', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._shape_as_tensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.dropout: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions.dropout_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions.feature_dropout: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions.feature_dropout_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions.alpha_dropout: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions.alpha_dropout_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions.feature_alpha_dropout: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions.feature_alpha_dropout_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions.abs: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.abs: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.abs_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.absolute: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.absolute: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.angle: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.angle: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.view_as_real: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.view_as_complex: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.sgn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.sgn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.real: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.imag: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._conj: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.conj: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._conj_physical: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.conj_physical: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.conj_physical: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.conj_physical_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.resolve_conj: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.resolve_neg: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._neg_view: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.acos: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.acos: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.acos_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.arccos: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.arccos: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.arccos_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.avg_pool1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 1}],
+    torch._C._VariableFunctions.adaptive_avg_pool1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 1}],
+    torch._C._VariableFunctions.adaptive_max_pool1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 1}],
+    torch._C._VariableFunctions.add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._add_relu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._add_relu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._add_relu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions._add_relu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._add_relu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.addmv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.addmv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.addmv_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.addr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec2', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.addr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec2', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.affine_grid_generator: [{'is_kwarg_only': 'False', 'name': 'theta', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions._is_all_true: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._is_any_true: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._test_check_tensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._test_functorch_fallback: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch._C._VariableFunctions.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch._C._VariableFunctions.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.allclose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch._C._VariableFunctions.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch._C._VariableFunctions.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.arange: [{'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.arange: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.arange: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.arange: [{'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.arange: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions._dim_arange: [{'is_kwarg_only': 'False', 'name': 'like', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.argmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.argmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.argmin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.argmin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.acosh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.acosh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.acosh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.arccosh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.arccosh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.arccosh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.asinh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.asinh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.asinh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.arcsinh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.arcsinh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.arcsinh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.atanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.atanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.atanh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.arctanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.arctanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.arctanh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.as_strided: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}],
+    torch._C._VariableFunctions.as_strided_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}],
+    torch._C._VariableFunctions.asin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.asin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.asin_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.arcsin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.arcsin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.arcsin_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.atan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.atan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.atan_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.arctan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.arctan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.arctan_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.atleast_1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.atleast_1d: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions.atleast_2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.atleast_2d: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions.atleast_3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.atleast_3d: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions.baddbmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch2', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.baddbmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch2', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.baddbmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'out_dtype', 'simple_type': 'ScalarType'}],
+    torch._C._VariableFunctions.baddbmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'out_dtype', 'simple_type': 'ScalarType'}],
+    torch._C._VariableFunctions.bartlett_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.bartlett_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'periodic', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions.batch_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'training', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'cudnn_enabled', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions.quantized_batch_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'var', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'output_scale', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'output_zero_point', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions._batch_norm_impl_index: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'training', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'cudnn_enabled', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions.bernoulli: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.bernoulli: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.bernoulli: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}],
+    torch._C._VariableFunctions.bilinear: [{'is_kwarg_only': 'False', 'name': 'input1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.binary_cross_entropy_with_logits: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.bincount: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.bitwise_not: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.bitwise_not: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.copysign: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.copysign: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.copysign: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.copysign: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions._lazy_clone: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.logical_not: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.logical_not: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.logical_xor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.logical_xor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.logical_and: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.logical_and: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.logical_or: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.logical_or: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.blackman_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.blackman_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'periodic', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions.bmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.bmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.bmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'out_dtype', 'simple_type': 'ScalarType'}],
+    torch._C._VariableFunctions.bmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'out_dtype', 'simple_type': 'ScalarType'}],
+    torch._C._VariableFunctions.broadcast_tensors: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions.broadcast_to: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}],
+    torch._C._VariableFunctions._sparse_broadcast_to: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}],
+    torch._C._VariableFunctions.cat: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions.cat: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions.cat: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch._C._VariableFunctions.cat: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch._C._VariableFunctions.concat: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions.concat: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions.concat: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch._C._VariableFunctions.concat: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch._C._VariableFunctions.concatenate: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions.concatenate: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions.concatenate: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch._C._VariableFunctions.concatenate: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch._C._VariableFunctions.block_diag: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions.ceil: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.ceil: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.ceil_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.chain_matmul: [{'is_kwarg_only': 'False', 'name': 'matrices', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions.chain_matmul: [{'is_kwarg_only': 'False', 'name': 'matrices', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions.unsafe_chunk: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'chunks', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.chunk: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'chunks', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.tensor_split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'sections', 'simple_type': 'SymInt'}],
+    torch._C._VariableFunctions.tensor_split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'SymIntArrayRef'}],
+    torch._C._VariableFunctions.tensor_split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor_indices_or_sections', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.clamp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.clamp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.clamp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.clamp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.clamp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.clamp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.clamp_max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.clamp_max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.clamp_max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.clamp_max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.clamp_max_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.clamp_max_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.clamp_min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.clamp_min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.clamp_min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.clamp_min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.clamp_min_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.clamp_min_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.clip: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.clip: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.clip: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.clip: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.clip_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.clip_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.cudnn_is_acceptable: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.complex: [{'is_kwarg_only': 'False', 'name': 'real', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'imag', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.complex: [{'is_kwarg_only': 'False', 'name': 'real', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'imag', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.polar: [{'is_kwarg_only': 'False', 'name': 'abs', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'angle', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.polar: [{'is_kwarg_only': 'False', 'name': 'abs', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'angle', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.constant_pad_nd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'pad', 'simple_type': 'SymIntArrayRef'}],
+    torch._C._VariableFunctions.convolution: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'transposed', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'output_padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}],
+    torch._C._VariableFunctions._convolution: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'transposed', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'output_padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'benchmark', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'deterministic', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'cudnn_enabled', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'allow_tf32', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions._convolution: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'transposed', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'output_padding', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'benchmark', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'deterministic', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'cudnn_enabled', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions._convolution_mode: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'c10::string_view'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}],
+    torch._C._VariableFunctions.conv1d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.conv1d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.conv2d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.conv2d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.conv3d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.conv3d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.conv_tbc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.conv_transpose1d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.conv_transpose2d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.conv_transpose3d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._copy_from: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dst', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._copy_from_and_resize: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dst', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.cos: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.cos: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.cos_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.cosh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.cosh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.cosh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.cosine_embedding_loss: [{'is_kwarg_only': 'False', 'name': 'input1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.count_nonzero: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef'}],
+    torch._C._VariableFunctions.count_nonzero: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.cov: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.corrcoef: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.cudnn_affine_grid_generator: [{'is_kwarg_only': 'False', 'name': 'theta', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'N', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'C', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'H', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'W', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.cudnn_batch_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'training', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'exponential_average_factor', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'epsilon', 'simple_type': 'double'}],
+    torch._C._VariableFunctions.cudnn_convolution: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'benchmark', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'deterministic', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'allow_tf32', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions.cudnn_convolution: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'benchmark', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'deterministic', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'allow_tf32', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions.cudnn_convolution_transpose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'output_padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'benchmark', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'deterministic', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'allow_tf32', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions._mps_convolution_transpose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'output_padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}],
+    torch._C._VariableFunctions.cudnn_convolution_relu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}],
+    torch._C._VariableFunctions.cudnn_convolution_add_relu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'z', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'alpha', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}],
+    torch._C._VariableFunctions.cudnn_grid_sampler: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'grid', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.cummax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.cummax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.cummax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch._C._VariableFunctions.cummax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch._C._VariableFunctions._cummax_helper: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.cummin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.cummin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.cummin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch._C._VariableFunctions.cummin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch._C._VariableFunctions._cummin_helper: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.cumprod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.cumprod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.cumprod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch._C._VariableFunctions.cumprod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch._C._VariableFunctions.cumsum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.cumsum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.cumsum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch._C._VariableFunctions.cumsum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch._C._VariableFunctions.cumulative_trapezoid: [{'is_kwarg_only': 'False', 'name': 'y', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.cumulative_trapezoid: [{'is_kwarg_only': 'False', 'name': 'y', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.ctc_loss: [{'is_kwarg_only': 'False', 'name': 'log_probs', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'targets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input_lengths', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'target_lengths', 'simple_type': 'IntArrayRef'}],
+    torch._C._VariableFunctions.ctc_loss: [{'is_kwarg_only': 'False', 'name': 'log_probs', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'targets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input_lengths', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target_lengths', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._ctc_loss: [{'is_kwarg_only': 'False', 'name': 'log_probs', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'targets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input_lengths', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'target_lengths', 'simple_type': 'IntArrayRef'}],
+    torch._C._VariableFunctions._ctc_loss: [{'is_kwarg_only': 'False', 'name': 'log_probs', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'targets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input_lengths', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target_lengths', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.diag_embed: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.diagflat: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.diagonal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.diagonal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.diff: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.diff: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.gradient: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.gradient: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'spacing', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'True', 'name': 'dim', 'simple_type': 'IntArrayRef'}],
+    torch._C._VariableFunctions.gradient: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'dim', 'simple_type': 'IntArrayRef'}],
+    torch._C._VariableFunctions.gradient: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'spacing', 'simple_type': 'ScalarList'}],
+    torch._C._VariableFunctions.gradient: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'spacing', 'simple_type': 'ScalarList'}, {'is_kwarg_only': 'True', 'name': 'dim', 'simple_type': 'IntArrayRef'}],
+    torch._C._VariableFunctions.gradient: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'spacing', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions.gradient: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'spacing', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'True', 'name': 'dim', 'simple_type': 'IntArrayRef'}],
+    torch._C._VariableFunctions.div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}],
+    torch._C._VariableFunctions.div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}],
+    torch._C._VariableFunctions.div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}],
+    torch._C._VariableFunctions.divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}],
+    torch._C._VariableFunctions.divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}],
+    torch._C._VariableFunctions.divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}],
+    torch._C._VariableFunctions.true_divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.true_divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.true_divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.dot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.dot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.vdot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.vdot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.einsum: [{'is_kwarg_only': 'False', 'name': 'equation', 'simple_type': 'c10::string_view'}, {'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions.embedding: [{'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.embedding_renorm_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'max_norm', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'norm_type', 'simple_type': 'double'}],
+    torch._C._VariableFunctions._embedding_bag_forward_only: [{'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'offsets', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._rowwise_prune: [{'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'compressed_indices_dtype', 'simple_type': 'ScalarType'}],
+    torch._C._VariableFunctions.row_stack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions.row_stack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions.embedding_bag: [{'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'offsets', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.embedding_bag: [{'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'offsets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_grad_by_freq', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'sparse', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'per_sample_weights', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'include_last_offset', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'padding_idx', 'simple_type': 'int64_t?'}],
+    torch._C._VariableFunctions._embedding_bag: [{'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'offsets', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.empty: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'names', 'simple_type': 'DimnameList?'}],
+    torch._C._VariableFunctions.empty: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}],
+    torch._C._VariableFunctions.empty: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}],
+    torch._C._VariableFunctions.empty_permuted: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'physical_layout', 'simple_type': 'IntArrayRef'}],
+    torch._C._VariableFunctions._empty_affine_quantized: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}],
+    torch._C._VariableFunctions._empty_per_channel_affine_quantized: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'scales', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'zero_points', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'axis', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions._resize_output_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'device', 'simple_type': 'Device'}],
+    torch._C._VariableFunctions.empty_quantized: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'qtensor', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.empty_like: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.empty_strided: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}],
+    torch._C._VariableFunctions.erf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.erf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.erf_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.erfc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.erfc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.erfc_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.exp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.exp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.exp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.exp2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.exp2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.exp2_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.expm1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.expm1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.expm1_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.eye: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'SymInt'}],
+    torch._C._VariableFunctions.eye: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'm', 'simple_type': 'SymInt'}],
+    torch._C._VariableFunctions.eye: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'SymInt'}],
+    torch._C._VariableFunctions.eye: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'm', 'simple_type': 'SymInt'}],
+    torch._C._VariableFunctions.flatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.flatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'start_dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'end_dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'out_dim', 'simple_type': 'Dimname'}],
+    torch._C._VariableFunctions.flatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'start_dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'end_dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'out_dim', 'simple_type': 'Dimname'}],
+    torch._C._VariableFunctions.flatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims', 'simple_type': 'DimnameList'}, {'is_kwarg_only': 'False', 'name': 'out_dim', 'simple_type': 'Dimname'}],
+    torch._C._VariableFunctions.unflatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'sizes', 'simple_type': 'SymIntArrayRef'}],
+    torch._C._VariableFunctions.unflatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'sizes', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'names', 'simple_type': 'DimnameList'}],
+    torch._C._VariableFunctions.fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.fill_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.fill_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.floor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.floor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.floor_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.floor_divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.floor_divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.floor_divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.frac: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.frac: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.frac_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.full: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'fill_value', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'True', 'name': 'names', 'simple_type': 'DimnameList?'}],
+    torch._C._VariableFunctions.full: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'fill_value', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.full: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'fill_value', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.full_like: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'fill_value', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.from_file: [{'is_kwarg_only': 'False', 'name': 'filename', 'simple_type': 'c10::string_view'}],
+    torch._C._VariableFunctions.gcd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.gcd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.gcd_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.lcm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.lcm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.lcm_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.grid_sampler: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'grid', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'interpolation_mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'padding_mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions.grid_sampler_2d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'grid', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'interpolation_mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'padding_mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions._grid_sampler_2d_cpu_fallback: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'grid', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'interpolation_mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'padding_mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions.grid_sampler_3d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'grid', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'interpolation_mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'padding_mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions.hann_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.hann_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'periodic', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions.hamming_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.hamming_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'periodic', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions.hamming_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'periodic', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'alpha', 'simple_type': 'double'}],
+    torch._C._VariableFunctions.hamming_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'periodic', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'alpha', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'beta', 'simple_type': 'double'}],
+    torch._C._VariableFunctions.kaiser_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.kaiser_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'periodic', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions.kaiser_window: [{'is_kwarg_only': 'False', 'name': 'window_length', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'periodic', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'beta', 'simple_type': 'double'}],
+    torch._C._VariableFunctions.hinge_embedding_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.group_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'num_groups', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.native_group_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'N', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'C', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'HxW', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'group', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}],
+    torch._C._VariableFunctions._fft_r2c: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'normalization', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'onesided', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions._fft_r2c: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'normalization', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'onesided', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions._fft_c2r: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'normalization', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'last_dim_size', 'simple_type': 'SymInt'}],
+    torch._C._VariableFunctions._fft_c2r: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'normalization', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'last_dim_size', 'simple_type': 'SymInt'}],
+    torch._C._VariableFunctions._fft_c2c: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'normalization', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'forward', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions._fft_c2c: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'normalization', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'forward', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions._validate_compressed_sparse_indices: [{'is_kwarg_only': 'False', 'name': 'is_crow', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'compressed_idx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'plain_idx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'cdim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'nnz', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions._cufft_get_plan_cache_size: [{'is_kwarg_only': 'False', 'name': 'device_index', 'simple_type': 'DeviceIndex'}],
+    torch._C._VariableFunctions._cufft_get_plan_cache_max_size: [{'is_kwarg_only': 'False', 'name': 'device_index', 'simple_type': 'DeviceIndex'}],
+    torch._C._VariableFunctions._cufft_set_plan_cache_max_size: [{'is_kwarg_only': 'False', 'name': 'device_index', 'simple_type': 'DeviceIndex'}, {'is_kwarg_only': 'False', 'name': 'max_size', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions._cufft_clear_plan_cache: [{'is_kwarg_only': 'False', 'name': 'device_index', 'simple_type': 'DeviceIndex'}],
+    torch._C._VariableFunctions._unsafe_index: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'c10::List<::std::optional>'}],
+    torch._C._VariableFunctions._unsafe_masked_index: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'c10::List<::std::optional>'}, {'is_kwarg_only': 'False', 'name': 'fill', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions._unsafe_masked_index_put_accumulate: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'c10::List<::std::optional>'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.index_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.index_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.index_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.index_put_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'c10::List<::std::optional>'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.index_put: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'c10::List<::std::optional>'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._unsafe_index_put: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'c10::List<::std::optional>'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._index_put_impl_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'c10::List<::std::optional>'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.instance_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'use_input_stats', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'cudnn_enabled', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions.isclose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.isin: [{'is_kwarg_only': 'False', 'name': 'elements', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'test_elements', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.isin: [{'is_kwarg_only': 'False', 'name': 'elements', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'test_elements', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.isin: [{'is_kwarg_only': 'False', 'name': 'elements', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'test_element', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.isin: [{'is_kwarg_only': 'False', 'name': 'elements', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'test_element', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.isin: [{'is_kwarg_only': 'False', 'name': 'element', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'test_elements', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.isin: [{'is_kwarg_only': 'False', 'name': 'element', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'test_elements', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.isnan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.is_distributed: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.is_floating_point: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.is_complex: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.is_conj: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._is_zerotensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.is_neg: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.isreal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.is_nonzero: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.is_same_size: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.is_signed: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.is_inference: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.kl_div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.kron: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.kron: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.kthvalue: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'k', 'simple_type': 'SymInt'}],
+    torch._C._VariableFunctions.kthvalue: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'k', 'simple_type': 'SymInt'}],
+    torch._C._VariableFunctions.kthvalue: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'k', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch._C._VariableFunctions.kthvalue: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'k', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch._C._VariableFunctions.layer_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'normalized_shape', 'simple_type': 'SymIntArrayRef'}],
+    torch._C._VariableFunctions.native_layer_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'normalized_shape', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}],
+    torch._C._VariableFunctions.rms_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'normalized_shape', 'simple_type': 'SymIntArrayRef'}],
+    torch._C._VariableFunctions._fused_rms_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'normalized_shape_ndim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}],
+    torch._C._VariableFunctions.nan_to_num: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.nan_to_num: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.nan_to_num_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.mkldnn_linear_backward_weights: [{'is_kwarg_only': 'False', 'name': 'grad_output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias_defined', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions._cslt_compress: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._cslt_sparse_mm: [{'is_kwarg_only': 'False', 'name': 'compressed_A', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dense_B', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._cslt_sparse_mm_search: [{'is_kwarg_only': 'False', 'name': 'compressed_A', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dense_B', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._sparse_semi_structured_tile: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._sparse_semi_structured_apply: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'thread_masks', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._sparse_semi_structured_apply_dense: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'thread_masks', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._sparse_semi_structured_linear: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'meta', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._sparse_semi_structured_mm: [{'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1_meta', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._sparse_semi_structured_addmm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1_meta', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._mixed_dtypes_linear: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.fbgemm_linear_int8_weight_fp32_activation: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_offsets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight_scale', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'weight_zero_point', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.fbgemm_linear_int8_weight: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_offsets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight_scale', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'weight_zero_point', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.fbgemm_linear_quantize_weight: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.fbgemm_pack_gemm_matrix_fp16: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._wrapped_linear_prepack: [{'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight_scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight_zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._wrapped_quantized_linear_prepacked: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input_scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input_zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'out_channel', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.fbgemm_linear_fp16_weight_fp32_activation: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.fbgemm_linear_fp16_weight: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.fbgemm_pack_quantized_matrix: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.fbgemm_pack_quantized_matrix: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'K', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'N', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.ldexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.ldexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.ldexp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.linspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.linspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.linspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.linspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.linspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.linspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.linspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.linspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.log: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.log: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.log_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.log10: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.log10: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.log10_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.log1p: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.log1p: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.log1p_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.log2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.log2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.log2_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.logaddexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.logaddexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.logaddexp2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.logaddexp2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.xlogy_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.xlogy_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.logspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.logspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.logspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.logspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.logspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.logspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.logspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.logspace: [{'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'steps', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.log_softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.log_softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.log_softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch._C._VariableFunctions._log_softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'half_to_float', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions._log_softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'half_to_float', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions._log_softmax_backward_data: [{'is_kwarg_only': 'False', 'name': 'grad_output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'input_dtype', 'simple_type': 'ScalarType'}],
+    torch._C._VariableFunctions._log_softmax_backward_data: [{'is_kwarg_only': 'False', 'name': 'grad_output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'input_dtype', 'simple_type': 'ScalarType'}],
+    torch._C._VariableFunctions._logcumsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions._logcumsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.logcumsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.logcumsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.logcumsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch._C._VariableFunctions.logcumsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch._C._VariableFunctions.logsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}],
+    torch._C._VariableFunctions.logsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}],
+    torch._C._VariableFunctions.logsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}],
+    torch._C._VariableFunctions.logsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}],
+    torch._C._VariableFunctions.margin_ranking_loss: [{'is_kwarg_only': 'False', 'name': 'input1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.matmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.matmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.matrix_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.matrix_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.matrix_exp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._aminmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._aminmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.aminmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.aminmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._compute_linear_combination: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'coefficients', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._compute_linear_combination: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'coefficients', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch._C._VariableFunctions.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch._C._VariableFunctions.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.amax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.amax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.max_pool1d_with_indices: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 1}],
+    torch._C._VariableFunctions.max_pool1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 1}],
+    torch._C._VariableFunctions.max_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}],
+    torch._C._VariableFunctions.mkldnn_max_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}],
+    torch._C._VariableFunctions.mkldnn_max_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 3}],
+    torch._C._VariableFunctions.quantized_max_pool1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 1}],
+    torch._C._VariableFunctions.quantized_max_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}],
+    torch._C._VariableFunctions.quantized_max_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 3}],
+    torch._C._VariableFunctions.max_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 3}],
+    torch._C._VariableFunctions.mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}],
+    torch._C._VariableFunctions.mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}],
+    torch._C._VariableFunctions.mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}],
+    torch._C._VariableFunctions.mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}],
+    torch._C._VariableFunctions.nanmean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.nanmean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.median: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.median: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.median: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.median: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch._C._VariableFunctions.median: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch._C._VariableFunctions.nanmedian: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.nanmedian: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.nanmedian: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.nanmedian: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch._C._VariableFunctions.nanmedian: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch._C._VariableFunctions.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch._C._VariableFunctions.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch._C._VariableFunctions.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.amin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.amin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._mps_convolution: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}],
+    torch._C._VariableFunctions.mkldnn_convolution: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}],
+    torch._C._VariableFunctions.mkldnn_rnn_layer: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight0', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight3', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx_', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'cx_', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'reverse', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'batch_sizes', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'hidden_size', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'has_biases', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'batch_first', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions.miopen_batch_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'training', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'exponential_average_factor', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'epsilon', 'simple_type': 'double'}],
+    torch._C._VariableFunctions.miopen_convolution: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'benchmark', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'deterministic', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions.miopen_convolution_transpose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'output_padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'benchmark', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'deterministic', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions.miopen_depthwise_convolution: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'benchmark', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'deterministic', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions.miopen_convolution_relu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}],
+    torch._C._VariableFunctions.miopen_convolution_add_relu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'z', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'alpha', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}],
+    torch._C._VariableFunctions.miopen_rnn: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'weight_stride0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'cx', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'mode', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'hidden_size', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'batch_first', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'batch_sizes', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dropout_state', 'simple_type': 'Tensor?'}],
+    torch._C._VariableFunctions.mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'out_dtype', 'simple_type': 'ScalarType'}],
+    torch._C._VariableFunctions.mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'out_dtype', 'simple_type': 'ScalarType'}],
+    torch._C._VariableFunctions._int_mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._int_mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._convert_weight_to_int4pack: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'innerKTiles', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions._weight_int4pack_mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'qGroupSize', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'qScaleAndZeros', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._weight_int4pack_mm_with_scales_and_zeros: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'qGroupSize', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'qScale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'qZeros', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._convert_weight_to_int4pack_for_cpu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'innerKTiles', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions._weight_int4pack_mm_for_cpu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'qGroupSize', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'qScaleAndZeros', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._dyn_quant_pack_4bit_weight: [{'is_kwarg_only': 'False', 'name': 'weights', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scales_zeros', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'block_size', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'in_features', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'out_features', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions._dyn_quant_matmul_4bit: [{'is_kwarg_only': 'False', 'name': 'inp', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_weights', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'block_size', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'in_features', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'out_features', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions._weight_int8pack_mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scales', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._sparse_sparse_matmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.mode: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.mode: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.mode: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch._C._VariableFunctions.mode: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch._C._VariableFunctions.mul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.mul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.multiply: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.multiply: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.multiply: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.mv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.mv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.mvlgamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.mvlgamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.narrow_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'length', 'simple_type': 'SymInt'}],
+    torch._C._VariableFunctions.narrow_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'length', 'simple_type': 'SymInt'}],
+    torch._C._VariableFunctions.narrow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'length', 'simple_type': 'SymInt'}],
+    torch._C._VariableFunctions.narrow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'length', 'simple_type': 'SymInt'}],
+    torch._C._VariableFunctions.native_batch_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'training', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}],
+    torch._C._VariableFunctions.native_batch_norm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'training', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}],
+    torch._C._VariableFunctions._native_batch_norm_legit: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'training', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}],
+    torch._C._VariableFunctions._native_batch_norm_legit: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'training', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}],
+    torch._C._VariableFunctions._native_batch_norm_legit: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'training', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}],
+    torch._C._VariableFunctions._native_batch_norm_legit: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'training', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}],
+    torch._C._VariableFunctions._native_batch_norm_legit_no_training: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}],
+    torch._C._VariableFunctions.batch_norm_stats: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}],
+    torch._C._VariableFunctions.batch_norm_elemt: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'invstd', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}],
+    torch._C._VariableFunctions.batch_norm_elemt: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'invstd', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}],
+    torch._C._VariableFunctions.batch_norm_gather_stats: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'invstd', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'count', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.batch_norm_gather_stats_with_counts: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'invstd', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'counts', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.batch_norm_backward_reduce: [{'is_kwarg_only': 'False', 'name': 'grad_out', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'invstd', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'input_g', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'weight_g', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bias_g', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions.batch_norm_backward_elemt: [{'is_kwarg_only': 'False', 'name': 'grad_out', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'invstd', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'sum_dy', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'sum_dy_xmu', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'count', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.batch_norm_update_stats: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'running_mean', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'running_var', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'momentum', 'simple_type': 'double'}],
+    torch._C._VariableFunctions.is_vulkan_available: [],
+    torch._C._VariableFunctions._nnpack_available: [],
+    torch._C._VariableFunctions._nnpack_spatial_convolution: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 2}],
+    torch._C._VariableFunctions.ones: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'names', 'simple_type': 'DimnameList?'}],
+    torch._C._VariableFunctions.ones: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}],
+    torch._C._VariableFunctions.ones: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}],
+    torch._C._VariableFunctions.ones_like: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.pairwise_distance: [{'is_kwarg_only': 'False', 'name': 'x1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'x2', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.cdist: [{'is_kwarg_only': 'False', 'name': 'x1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'x2', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._euclidean_dist: [{'is_kwarg_only': 'False', 'name': 'x1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'x2', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.pdist: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.cosine_similarity: [{'is_kwarg_only': 'False', 'name': 'x1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'x2', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.permute: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims', 'simple_type': 'IntArrayRef'}],
+    torch._C._VariableFunctions.movedim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'destination', 'simple_type': 'IntArrayRef'}],
+    torch._C._VariableFunctions.movedim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'destination', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.moveaxis: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'destination', 'simple_type': 'IntArrayRef'}],
+    torch._C._VariableFunctions.moveaxis: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'destination', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.adjoint: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.pixel_shuffle: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'upscale_factor', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.pixel_unshuffle: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'downscale_factor', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.channel_shuffle: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}],
+    torch._C._VariableFunctions.native_channel_shuffle: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'groups', 'simple_type': 'SymInt'}],
+    torch._C._VariableFunctions._pin_memory: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.pinverse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.poisson_nll_loss: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'log_input', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'full', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'reduction', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.rad2deg: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.rad2deg: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.rad2deg_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.deg2rad: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.deg2rad: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.deg2rad_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.scalar_tensor: [{'is_kwarg_only': 'False', 'name': 's', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.rand: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'names', 'simple_type': 'DimnameList?'}],
+    torch._C._VariableFunctions.rand: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}, {'is_kwarg_only': 'True', 'name': 'names', 'simple_type': 'DimnameList?'}],
+    torch._C._VariableFunctions.rand: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}],
+    torch._C._VariableFunctions.rand: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}],
+    torch._C._VariableFunctions.rand: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}],
+    torch._C._VariableFunctions.rand: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}],
+    torch._C._VariableFunctions.rand_like: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.randint: [{'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}],
+    torch._C._VariableFunctions.randint: [{'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}],
+    torch._C._VariableFunctions.randint: [{'is_kwarg_only': 'False', 'name': 'low', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}],
+    torch._C._VariableFunctions.randint: [{'is_kwarg_only': 'False', 'name': 'low', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}],
+    torch._C._VariableFunctions.randint: [{'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}],
+    torch._C._VariableFunctions.randint: [{'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}],
+    torch._C._VariableFunctions.randint: [{'is_kwarg_only': 'False', 'name': 'low', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}],
+    torch._C._VariableFunctions.randint: [{'is_kwarg_only': 'False', 'name': 'low', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}],
+    torch._C._VariableFunctions.randint_like: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'SymInt'}],
+    torch._C._VariableFunctions.randint_like: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.randint_like: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'low', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'high', 'simple_type': 'SymInt'}],
+    torch._C._VariableFunctions.randn: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}],
+    torch._C._VariableFunctions.randn: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}],
+    torch._C._VariableFunctions.randn: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'names', 'simple_type': 'DimnameList?'}],
+    torch._C._VariableFunctions.randn: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}, {'is_kwarg_only': 'True', 'name': 'names', 'simple_type': 'DimnameList?'}],
+    torch._C._VariableFunctions.randn: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}],
+    torch._C._VariableFunctions.randn: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}],
+    torch._C._VariableFunctions.randn_like: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.randperm: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'SymInt'}],
+    torch._C._VariableFunctions.randperm: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}],
+    torch._C._VariableFunctions.randperm: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'SymInt'}],
+    torch._C._VariableFunctions.randperm: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'True', 'name': 'generator', 'simple_type': 'Generator?'}],
+    torch._C._VariableFunctions.ravel: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.reciprocal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.reciprocal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.reciprocal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.neg: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.neg: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.neg_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.negative: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.negative: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.negative_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.repeat_interleave: [{'is_kwarg_only': 'False', 'name': 'repeats', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.repeat_interleave: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'repeats', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.repeat_interleave: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'repeats', 'simple_type': 'SymInt'}],
+    torch._C._VariableFunctions.reshape: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'shape', 'simple_type': 'SymIntArrayRef'}],
+    torch._C._VariableFunctions._mkldnn_reshape: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'shape', 'simple_type': 'IntArrayRef'}],
+    torch._C._VariableFunctions.round: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.round: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.round: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.round: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.round_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.round_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.rrelu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.rrelu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.relu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.relu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.prelu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._prelu_kernel: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.hardshrink: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.hardshrink: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.rsqrt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.rsqrt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.rsqrt_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'SymInt'}],
+    torch._C._VariableFunctions.selu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.selu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.celu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.celu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.sigmoid: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.sigmoid: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.sigmoid_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.logit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.logit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.logit_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.sin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.sin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.sin_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.sinc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.sinc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.sinc_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.sinh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.sinh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.sinh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.detach: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.detach_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.slice_inverse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.slice_scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.slice_scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.select_scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'SymInt'}],
+    torch._C._VariableFunctions.diagonal_scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.as_strided_scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}],
+    torch._C._VariableFunctions.smm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch._C._VariableFunctions._softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'half_to_float', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions._softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'half_to_float', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions._softmax_backward_data: [{'is_kwarg_only': 'False', 'name': 'grad_output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'input_dtype', 'simple_type': 'ScalarType'}],
+    torch._C._VariableFunctions._softmax_backward_data: [{'is_kwarg_only': 'False', 'name': 'grad_output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'input_dtype', 'simple_type': 'ScalarType'}],
+    torch._C._VariableFunctions.unsafe_split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_size', 'simple_type': 'SymInt'}],
+    torch._C._VariableFunctions.split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_size', 'simple_type': 'SymInt'}],
+    torch._C._VariableFunctions.split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_size', 'simple_type': 'SymIntArrayRef'}],
+    torch._C._VariableFunctions.unsafe_split_with_sizes: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_sizes', 'simple_type': 'SymIntArrayRef'}],
+    torch._C._VariableFunctions.split_with_sizes: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_sizes', 'simple_type': 'SymIntArrayRef'}],
+    torch._C._VariableFunctions.hsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'sections', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.hsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'IntArrayRef'}],
+    torch._C._VariableFunctions.vsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'sections', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.vsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'IntArrayRef'}],
+    torch._C._VariableFunctions.dsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'sections', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.dsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'IntArrayRef'}],
+    torch._C._VariableFunctions.squeeze: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.squeeze: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.squeeze: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch._C._VariableFunctions.squeeze: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef'}],
+    torch._C._VariableFunctions.sspaddmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.sspaddmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._chunk_cat: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'num_chunks', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions._chunk_cat: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'num_chunks', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.stack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions.stack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._stack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._stack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions.hstack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions.hstack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions.vstack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions.vstack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions.dstack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions.dstack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions.stft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n_fft', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.stft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n_fft', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.istft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n_fft', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}],
+    torch._C._VariableFunctions.sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}],
+    torch._C._VariableFunctions.sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}],
+    torch._C._VariableFunctions.sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}],
+    torch._C._VariableFunctions.nansum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.nansum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.sqrt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.sqrt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.sqrt_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.square: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.square: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.square_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}],
+    torch._C._VariableFunctions.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}],
+    torch._C._VariableFunctions.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}],
+    torch._C._VariableFunctions.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}],
+    torch._C._VariableFunctions.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}],
+    torch._C._VariableFunctions.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}],
+    torch._C._VariableFunctions.std_mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.std_mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}],
+    torch._C._VariableFunctions.std_mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.std_mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}],
+    torch._C._VariableFunctions.std_mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}],
+    torch._C._VariableFunctions.prod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.prod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.prod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.prod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch._C._VariableFunctions.prod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch._C._VariableFunctions.t: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.tan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.tan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.tan_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.tanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.tanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.tanh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.tensordot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims_self', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dims_other', 'simple_type': 'IntArrayRef'}],
+    torch._C._VariableFunctions.tensordot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims_self', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'dims_other', 'simple_type': 'IntArrayRef'}],
+    torch._C._VariableFunctions.threshold: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'threshold', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.threshold: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'threshold', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.threshold_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'threshold', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.tile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims', 'simple_type': 'SymIntArrayRef'}],
+    torch._C._VariableFunctions.transpose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.transpose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'Dimname'}],
+    torch._C._VariableFunctions._mkldnn_transpose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions._mkldnn_transpose_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.flip: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims', 'simple_type': 'IntArrayRef'}],
+    torch._C._VariableFunctions.fliplr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.flipud: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.roll: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'shifts', 'simple_type': 'SymIntArrayRef', 'size': 1}],
+    torch._C._VariableFunctions.rot90: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.trapezoid: [{'is_kwarg_only': 'False', 'name': 'y', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.trapezoid: [{'is_kwarg_only': 'False', 'name': 'y', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.trapz: [{'is_kwarg_only': 'False', 'name': 'y', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.trapz: [{'is_kwarg_only': 'False', 'name': 'y', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._transform_bias_rescale_qkv: [{'is_kwarg_only': 'False', 'name': 'qkv', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'qkv_bias', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'num_heads', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions._nested_tensor_from_mask: [{'is_kwarg_only': 'False', 'name': 't', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._nested_tensor_from_mask_left_aligned: [{'is_kwarg_only': 'False', 'name': 't', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._nested_from_padded: [{'is_kwarg_only': 'False', 'name': 'padded', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'cpu_nested_shape_example', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._nested_from_padded_and_nested_example: [{'is_kwarg_only': 'False', 'name': 'padded', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'nt_example', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._nested_view_from_buffer: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'nested_size', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'nested_strides', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'offsets', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._nested_view_from_buffer_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'nested_size', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'nested_strides', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'offsets', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._nested_view_from_buffer_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'nested_size', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'nested_strides', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'offsets', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._nested_view_from_jagged: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'offsets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dummy', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._nested_view_from_jagged_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'offsets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dummy', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._nested_view_from_jagged_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'offsets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dummy', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._nested_get_values: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._nested_get_values_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._nested_get_values_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._nested_get_offsets: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._nested_get_lengths: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._nested_get_ragged_idx: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._nested_get_min_seqlen: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._nested_get_max_seqlen: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._nested_get_jagged_dummy: [{'is_kwarg_only': 'False', 'name': 'any', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._nested_compute_contiguous_strides_offsets: [{'is_kwarg_only': 'False', 'name': 'nested_size', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._trilinear: [{'is_kwarg_only': 'False', 'name': 'i1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'i2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'i3', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'expand1', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'expand2', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'expand3', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'sumdim', 'simple_type': 'IntArrayRef'}],
+    torch._C._VariableFunctions.triplet_margin_loss: [{'is_kwarg_only': 'False', 'name': 'anchor', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'positive', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'negative', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.trunc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.trunc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.trunc_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.fix: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.fix: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.fix_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._has_compatible_shallow_copy_type: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'from', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._unique: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.unique_dim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.unique_consecutive: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._unique2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.unsqueeze: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.vander: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}],
+    torch._C._VariableFunctions.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}],
+    torch._C._VariableFunctions.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}],
+    torch._C._VariableFunctions.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}],
+    torch._C._VariableFunctions.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}],
+    torch._C._VariableFunctions.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}],
+    torch._C._VariableFunctions.var_mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.var_mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}],
+    torch._C._VariableFunctions.var_mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.var_mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}],
+    torch._C._VariableFunctions.var_mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}],
+    torch._C._VariableFunctions.where: [{'is_kwarg_only': 'False', 'name': 'condition', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.where: [{'is_kwarg_only': 'False', 'name': 'condition', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.where: [{'is_kwarg_only': 'False', 'name': 'condition', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.where: [{'is_kwarg_only': 'False', 'name': 'condition', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.where: [{'is_kwarg_only': 'False', 'name': 'condition', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.where: [{'is_kwarg_only': 'False', 'name': 'condition', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.norm_except_dim: [{'is_kwarg_only': 'False', 'name': 'v', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._weight_norm: [{'is_kwarg_only': 'False', 'name': 'v', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'g', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._weight_norm_interface: [{'is_kwarg_only': 'False', 'name': 'v', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'g', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.zeros: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'True', 'name': 'names', 'simple_type': 'DimnameList?'}],
+    torch._C._VariableFunctions.zeros: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}],
+    torch._C._VariableFunctions.zeros: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}],
+    torch._C._VariableFunctions._efficientzerotensor: [{'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}],
+    torch._C._VariableFunctions.zeros_like: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._standard_gamma_grad: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._standard_gamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._dirichlet_grad: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'alpha', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'total', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._sample_dirichlet: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.poisson: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.binomial: [{'is_kwarg_only': 'False', 'name': 'count', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'prob', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.native_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.native_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}, {'is_kwarg_only': 'False', 'name': 'keepdim', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'dtype', 'simple_type': 'ScalarType?'}],
+    torch._C._VariableFunctions._sparse_sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._sparse_sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'dtype', 'simple_type': 'ScalarType'}],
+    torch._C._VariableFunctions._sparse_sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}],
+    torch._C._VariableFunctions._sparse_sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}, {'is_kwarg_only': 'True', 'name': 'dtype', 'simple_type': 'ScalarType'}],
+    torch._C._VariableFunctions._sparse_csr_sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}],
+    torch._C._VariableFunctions._sparse_csr_prod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}],
+    torch._C._VariableFunctions._sparse_softmax_backward_data: [{'is_kwarg_only': 'False', 'name': 'grad_output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._sparse_log_softmax_backward_data: [{'is_kwarg_only': 'False', 'name': 'grad_output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'True', 'name': 'dtype', 'simple_type': 'ScalarType'}],
+    torch._C._VariableFunctions.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}, {'is_kwarg_only': 'False', 'name': 'keepdim', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'dtype', 'simple_type': 'ScalarType'}],
+    torch._C._VariableFunctions.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}],
+    torch._C._VariableFunctions.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}, {'is_kwarg_only': 'False', 'name': 'keepdim', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'dtype', 'simple_type': 'ScalarType'}],
+    torch._C._VariableFunctions.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}],
+    torch._C._VariableFunctions.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}, {'is_kwarg_only': 'False', 'name': 'keepdim', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'dtype', 'simple_type': 'ScalarType'}],
+    torch._C._VariableFunctions.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}],
+    torch._C._VariableFunctions.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}, {'is_kwarg_only': 'False', 'name': 'keepdim', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'dtype', 'simple_type': 'ScalarType'}],
+    torch._C._VariableFunctions.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}],
+    torch._C._VariableFunctions.frexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.frexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.frobenius_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}],
+    torch._C._VariableFunctions.frobenius_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}],
+    torch._C._VariableFunctions.nuclear_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.nuclear_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.nuclear_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 2}],
+    torch._C._VariableFunctions.nuclear_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 2}],
+    torch._C._VariableFunctions.clone: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.positive: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.resize_as_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'the_template', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.resize_as_sparse_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'the_template', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.zero_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.sub: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.sub: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.subtract: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.subtract: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.subtract: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.rsub: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.rsub: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.heaviside: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.heaviside: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.addmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.addmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.addmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'out_dtype', 'simple_type': 'ScalarType'}],
+    torch._C._VariableFunctions.addmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'out_dtype', 'simple_type': 'ScalarType'}],
+    torch._C._VariableFunctions._addmm_activation: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._addmm_activation: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._scaled_mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_a', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_b', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._scaled_mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_a', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_b', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._scaled_grouped_mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_a', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_b', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._grouped_mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._validate_sparse_coo_tensor_args: [{'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}],
+    torch._C._VariableFunctions._validate_sparse_compressed_tensor_args: [{'is_kwarg_only': 'False', 'name': 'compressed_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'plain_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'layout', 'simple_type': 'Layout'}],
+    torch._C._VariableFunctions._validate_sparse_csr_tensor_args: [{'is_kwarg_only': 'False', 'name': 'crow_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}],
+    torch._C._VariableFunctions._validate_sparse_csc_tensor_args: [{'is_kwarg_only': 'False', 'name': 'ccol_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'row_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}],
+    torch._C._VariableFunctions._validate_sparse_bsr_tensor_args: [{'is_kwarg_only': 'False', 'name': 'crow_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}],
+    torch._C._VariableFunctions._validate_sparse_bsc_tensor_args: [{'is_kwarg_only': 'False', 'name': 'ccol_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'row_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}],
+    torch._C._VariableFunctions._to_cpu: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._coalesce: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.hspmm: [{'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.hspmm: [{'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.unbind: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.unbind: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch._C._VariableFunctions._to_sparse_semi_structured: [{'is_kwarg_only': 'False', 'name': 'dense', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.quantize_per_tensor_dynamic: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dtype', 'simple_type': 'ScalarType'}, {'is_kwarg_only': 'False', 'name': 'reduce_range', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions.quantize_per_tensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dtype', 'simple_type': 'ScalarType'}],
+    torch._C._VariableFunctions.quantize_per_tensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dtype', 'simple_type': 'ScalarType'}],
+    torch._C._VariableFunctions.quantize_per_tensor: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scales', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'zero_points', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dtype', 'simple_type': 'ScalarType'}],
+    torch._C._VariableFunctions.quantize_per_channel: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scales', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'zero_points', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'axis', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dtype', 'simple_type': 'ScalarType'}],
+    torch._C._VariableFunctions.dequantize: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.dequantize: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions.q_scale: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.q_zero_point: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.q_per_channel_scales: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.q_per_channel_zero_points: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.q_per_channel_axis: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.int_repr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._make_per_tensor_quantized_tensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions._make_per_channel_quantized_tensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'axis', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.fake_quantize_per_tensor_affine: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'quant_min', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'quant_max', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.fake_quantize_per_tensor_affine: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'quant_min', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'quant_max', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions._fake_quantize_per_tensor_affine_cachemask_tensor_qparams: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'fake_quant_enabled', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'quant_min', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'quant_max', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions._fake_quantize_learnable_per_tensor_affine: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'quant_min', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'quant_max', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.fake_quantize_per_channel_affine: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'axis', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'quant_min', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'quant_max', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions._fake_quantize_learnable_per_channel_affine: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'axis', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'quant_min', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'quant_max', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.fused_moving_avg_obs_fake_quant: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'observer_on', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'fake_quant_on', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'running_min', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'running_max', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'averaging_const', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'quant_min', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'quant_max', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'ch_axis', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions._fused_moving_avg_obs_fq_helper: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'observer_on', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'fake_quant_on', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'running_min', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'running_max', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'zero_point', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'averaging_const', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'quant_min', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'quant_max', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'ch_axis', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions._choose_qparams_per_tensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._saturate_weight_to_fp16: [{'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.choose_qparams_optimized: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'numel', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'n_bins', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'ratio', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'bit_width', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.meshgrid: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions.meshgrid: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'True', 'name': 'indexing', 'simple_type': 'c10::string_view'}],
+    torch._C._VariableFunctions.cartesian_prod: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions.combinations: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.result_type: [{'is_kwarg_only': 'False', 'name': 'tensor', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.result_type: [{'is_kwarg_only': 'False', 'name': 'tensor', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.result_type: [{'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'tensor', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.result_type: [{'is_kwarg_only': 'False', 'name': 'scalar1', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'scalar2', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.can_cast: [{'is_kwarg_only': 'False', 'name': 'from_', 'simple_type': 'ScalarType'}, {'is_kwarg_only': 'False', 'name': 'to', 'simple_type': 'ScalarType'}],
+    torch._C._VariableFunctions.promote_types: [{'is_kwarg_only': 'False', 'name': 'type1', 'simple_type': 'ScalarType'}, {'is_kwarg_only': 'False', 'name': 'type2', 'simple_type': 'ScalarType'}],
+    torch._C._VariableFunctions._lstm_mps: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'params', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'has_biases', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'batch_first', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions.lstm: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'params', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'has_biases', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'batch_first', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions.lstm: [{'is_kwarg_only': 'False', 'name': 'data', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch_sizes', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'params', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'has_biases', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions.gru: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'params', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'has_biases', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'batch_first', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions.gru: [{'is_kwarg_only': 'False', 'name': 'data', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch_sizes', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'params', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'has_biases', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions.rnn_tanh: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'params', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'has_biases', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'batch_first', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions.rnn_tanh: [{'is_kwarg_only': 'False', 'name': 'data', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch_sizes', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'params', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'has_biases', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions.rnn_relu: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'params', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'has_biases', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'batch_first', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions.rnn_relu: [{'is_kwarg_only': 'False', 'name': 'data', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch_sizes', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'params', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'has_biases', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'num_layers', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dropout', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'train', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'bidirectional', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions.lstm_cell: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'w_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_hh', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.gru_cell: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_hh', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.rnn_tanh_cell: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_hh', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.rnn_relu_cell: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_hh', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.quantized_lstm_cell: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'w_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'b_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'b_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_offsets_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_offsets_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_ih', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'scale_hh', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'zero_point_ih', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'zero_point_hh', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.quantized_gru_cell: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'b_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'b_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_offsets_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_offsets_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_ih', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'scale_hh', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'zero_point_ih', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'zero_point_hh', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.quantized_rnn_relu_cell: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'b_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'b_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_offsets_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_offsets_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_ih', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'scale_hh', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'zero_point_ih', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'zero_point_hh', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.quantized_rnn_tanh_cell: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'hx', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'w_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'b_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'b_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'packed_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_offsets_ih', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_offsets_hh', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_ih', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'scale_hh', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'zero_point_ih', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'zero_point_hh', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions._pack_padded_sequence: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'lengths', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch_first', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions._pad_packed_sequence: [{'is_kwarg_only': 'False', 'name': 'data', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch_sizes', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch_first', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'padding_value', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'total_length', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.masked_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.masked_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.masked_scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._masked_softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.put: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.index_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.index_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.index_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.index_reduce: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'reduce', 'simple_type': 'c10::string_view'}],
+    torch._C._VariableFunctions.index_reduce: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'reduce', 'simple_type': 'c10::string_view'}],
+    torch._C._VariableFunctions.index_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.index_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.index_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.index_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'reduce', 'simple_type': 'c10::string_view'}],
+    torch._C._VariableFunctions.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'reduce', 'simple_type': 'c10::string_view'}],
+    torch._C._VariableFunctions.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'True', 'name': 'reduce', 'simple_type': 'c10::string_view'}],
+    torch._C._VariableFunctions.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'True', 'name': 'reduce', 'simple_type': 'c10::string_view'}],
+    torch._C._VariableFunctions.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.scatter_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.scatter_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.scatter_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.scatter_reduce: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'reduce', 'simple_type': 'c10::string_view'}],
+    torch._C._VariableFunctions.scatter_reduce: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'reduce', 'simple_type': 'c10::string_view'}],
+    torch._C._VariableFunctions.bitwise_and: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.bitwise_and: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.bitwise_and: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.bitwise_and: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.bitwise_and: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.__and__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.__and__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.bitwise_or: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.bitwise_or: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.bitwise_or: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.bitwise_or: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.bitwise_or: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.__or__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.__or__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.bitwise_xor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.bitwise_xor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.bitwise_xor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.bitwise_xor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.bitwise_xor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.__xor__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.__xor__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.__lshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.__lshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.bitwise_left_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.bitwise_left_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.bitwise_left_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.bitwise_left_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.bitwise_left_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.__rshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.__rshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.bitwise_right_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.bitwise_right_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.bitwise_right_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.bitwise_right_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.bitwise_right_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.addbmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch2', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.addbmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch2', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.diag: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.diag: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.cross: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.cross: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.triu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.triu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.tril: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.tril: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.tril_indices: [{'is_kwarg_only': 'False', 'name': 'row', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'col', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.triu_indices: [{'is_kwarg_only': 'False', 'name': 'row', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'col', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.trace: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.ne: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.ne: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.ne: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.ne: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.not_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.not_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.not_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.not_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.eq: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.eq: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.eq: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.eq: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.ge: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.ge: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.ge: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.ge: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.greater_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.greater_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.greater_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.greater_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.le: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.le: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.le: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.le: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.less_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.less_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.less_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.less_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.gt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.gt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.gt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.gt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.greater: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.greater: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.greater: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.greater: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.lt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.lt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.lt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.lt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.less: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.less: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.less: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.less: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.take: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.take: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.take_along_dim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.take_along_dim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.index_select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.index_select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.index_select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.index_select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.masked_select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.masked_select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.nonzero_static: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'size', 'simple_type': 'SymInt'}],
+    torch._C._VariableFunctions.nonzero_static: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'size', 'simple_type': 'SymInt'}],
+    torch._C._VariableFunctions.argwhere: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.gather: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.gather: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.gather: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.gather: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.addcmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.addcmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.addcdiv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.addcdiv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.triangular_solve: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.triangular_solve: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._linalg_check_errors: [{'is_kwarg_only': 'False', 'name': 'info', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'api_name', 'simple_type': 'c10::string_view'}, {'is_kwarg_only': 'True', 'name': 'is_matrix', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions.svd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.svd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.swapaxes: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'axis0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'axis1', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.swapdims: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.cholesky: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.cholesky: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.cholesky_solve: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.cholesky_solve: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.cholesky_inverse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.cholesky_inverse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.qr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.qr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.geqrf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.geqrf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.orgqr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.orgqr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.ormqr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input3', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.ormqr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input3', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._lu_with_info: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.lu_solve: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'LU_data', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'LU_pivots', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.lu_solve: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'LU_data', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'LU_pivots', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.lu_unpack: [{'is_kwarg_only': 'False', 'name': 'LU_data', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'LU_pivots', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.lu_unpack: [{'is_kwarg_only': 'False', 'name': 'LU_data', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'LU_pivots', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.multinomial: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'num_samples', 'simple_type': 'SymInt'}],
+    torch._C._VariableFunctions.multinomial: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'num_samples', 'simple_type': 'SymInt'}],
+    torch._C._VariableFunctions.lgamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.lgamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.digamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.digamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.polygamma: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.polygamma: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.erfinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.erfinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.i0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.i0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.i0_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.sign: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.sign: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.signbit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.signbit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.dist: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.atan2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.atan2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.arctan2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.arctan2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.lerp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.lerp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.lerp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.lerp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.histc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.histc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.histogram: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bins', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.histogram: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bins', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.histogram: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.histogram: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._histogramdd_bin_edges: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bins', 'simple_type': 'IntArrayRef'}],
+    torch._C._VariableFunctions._histogramdd_from_bin_cts: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bins', 'simple_type': 'IntArrayRef'}],
+    torch._C._VariableFunctions._histogramdd_from_bin_tensors: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bins', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions.histogramdd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bins', 'simple_type': 'IntArrayRef'}],
+    torch._C._VariableFunctions.histogramdd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bins', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.histogramdd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bins', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions.fmod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.fmod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.fmod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.fmod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.hypot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.hypot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.igamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.igamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.igammac: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.igammac: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.nextafter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.nextafter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.remainder: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.remainder: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.remainder: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.remainder: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.remainder: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.fmin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.fmin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.fmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.fmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.maximum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.maximum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.minimum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.minimum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.quantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.quantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.quantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'double'}],
+    torch._C._VariableFunctions.quantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'double'}],
+    torch._C._VariableFunctions.nanquantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.nanquantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.nanquantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'double'}],
+    torch._C._VariableFunctions.nanquantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'double'}],
+    torch._C._VariableFunctions.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'stable', 'simple_type': 'bool?'}],
+    torch._C._VariableFunctions.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'stable', 'simple_type': 'bool?'}],
+    torch._C._VariableFunctions.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch._C._VariableFunctions.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'stable', 'simple_type': 'bool?'}, {'is_kwarg_only': 'True', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch._C._VariableFunctions.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch._C._VariableFunctions.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'stable', 'simple_type': 'bool?'}, {'is_kwarg_only': 'True', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch._C._VariableFunctions.msort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.msort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.argsort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.argsort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'stable', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions.argsort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'stable', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions.argsort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch._C._VariableFunctions.topk: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'k', 'simple_type': 'SymInt'}],
+    torch._C._VariableFunctions.topk: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'k', 'simple_type': 'SymInt'}],
+    torch._C._VariableFunctions.renorm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'maxnorm', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.renorm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'maxnorm', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.float_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.float_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.float_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.float_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.float_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.float_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.normal: [{'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.normal: [{'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.normal: [{'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'std', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.normal: [{'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'std', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.normal: [{'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'std', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.normal: [{'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'std', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.normal: [{'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'std', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}],
+    torch._C._VariableFunctions.normal: [{'is_kwarg_only': 'False', 'name': 'mean', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'std', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}],
+    torch._C._VariableFunctions._amp_foreach_non_finite_check_and_unscale_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'found_inf', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'inv_scale', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._amp_update_scale_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'growth_tracker', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'found_inf', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'scale_growth_factor', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'scale_backoff_factor', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'growth_interval', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions._foreach_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions._foreach_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}],
+    torch._C._VariableFunctions._foreach_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._foreach_add_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions._foreach_add_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_add_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}],
+    torch._C._VariableFunctions._foreach_add_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._foreach_sub: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions._foreach_sub: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_sub: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}],
+    torch._C._VariableFunctions._foreach_sub_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions._foreach_sub_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_sub_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}],
+    torch._C._VariableFunctions._foreach_mul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions._foreach_mul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_mul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}],
+    torch._C._VariableFunctions._foreach_mul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._foreach_mul_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions._foreach_mul_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_mul_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}],
+    torch._C._VariableFunctions._foreach_mul_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._foreach_div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions._foreach_div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}],
+    torch._C._VariableFunctions._foreach_div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._foreach_div_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions._foreach_div_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_div_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}],
+    torch._C._VariableFunctions._foreach_div_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._foreach_clamp_max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions._foreach_clamp_max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_clamp_max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}],
+    torch._C._VariableFunctions._foreach_clamp_max_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions._foreach_clamp_max_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_clamp_max_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}],
+    torch._C._VariableFunctions._foreach_clamp_min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions._foreach_clamp_min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_clamp_min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}],
+    torch._C._VariableFunctions._foreach_clamp_min_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions._foreach_clamp_min_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_clamp_min_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}],
+    torch._C._VariableFunctions._foreach_maximum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions._foreach_maximum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_maximum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}],
+    torch._C._VariableFunctions._foreach_maximum_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions._foreach_maximum_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_maximum_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}],
+    torch._C._VariableFunctions._foreach_minimum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions._foreach_minimum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_minimum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}],
+    torch._C._VariableFunctions._foreach_minimum_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalar', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions._foreach_minimum_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_minimum_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}],
+    torch._C._VariableFunctions._foreach_addcdiv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_addcdiv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}],
+    torch._C._VariableFunctions._foreach_addcdiv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._foreach_addcdiv_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_addcdiv_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}],
+    torch._C._VariableFunctions._foreach_addcdiv_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._foreach_addcmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_addcmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}],
+    torch._C._VariableFunctions._foreach_addcmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._foreach_addcmul_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_addcmul_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'ScalarList'}],
+    torch._C._VariableFunctions._foreach_addcmul_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'scalars', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._foreach_abs: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_abs_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_acos: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_acos_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_asin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_asin_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_atan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_atan_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_ceil: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_ceil_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_cos: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_cos_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_cosh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_cosh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_erf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_erf_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_erfc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_erfc_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_exp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_exp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_expm1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_expm1_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_floor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_floor_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_frac: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_frac_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_lerp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensors1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'weights', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_lerp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensors1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions._foreach_lerp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensors1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'ScalarList'}],
+    torch._C._VariableFunctions._foreach_lerp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensors1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'weights', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_lerp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensors1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions._foreach_lerp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'tensors1', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'ScalarList'}],
+    torch._C._VariableFunctions._foreach_lgamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_lgamma_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_log: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_log_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_log10: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_log10_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_log1p: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_log1p_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_log2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_log2_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_neg: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_neg_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions._foreach_pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'ScalarList'}],
+    torch._C._VariableFunctions._foreach_pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_pow_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_pow_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions._foreach_pow_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'ScalarList'}],
+    torch._C._VariableFunctions._foreach_reciprocal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_reciprocal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_round: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_round_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_rsqrt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_rsqrt_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_sigmoid: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_sigmoid_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_sign: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_sign_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_sin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_sin_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_sinh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_sinh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_sqrt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_sqrt_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_tan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_tan_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_tanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_tanh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_trunc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_trunc_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_zero_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._foreach_copy_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions.bucketize: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'boundaries', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.bucketize: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'boundaries', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.bucketize: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'boundaries', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.searchsorted: [{'is_kwarg_only': 'False', 'name': 'sorted_sequence', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.searchsorted: [{'is_kwarg_only': 'False', 'name': 'sorted_sequence', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.searchsorted: [{'is_kwarg_only': 'False', 'name': 'sorted_sequence', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions.searchsorted: [{'is_kwarg_only': 'False', 'name': 'sorted_sequence', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}],
+    torch._C._VariableFunctions._convert_indices_from_coo_to_csr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions._convert_indices_from_coo_to_csr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions._convert_indices_from_csr_to_coo: [{'is_kwarg_only': 'False', 'name': 'crow_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_indices', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._convert_indices_from_csr_to_coo: [{'is_kwarg_only': 'False', 'name': 'crow_indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'col_indices', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.mkldnn_adaptive_avg_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 2}],
+    torch._C._VariableFunctions.mkldnn_adaptive_avg_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 2}],
+    torch._C._VariableFunctions._adaptive_avg_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}],
+    torch._C._VariableFunctions._adaptive_avg_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 3}],
+    torch._C._VariableFunctions.column_stack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions.column_stack: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions.isfinite: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.isinf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.isposinf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.isposinf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.isneginf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.isneginf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._add_batch_dim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch_dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'level', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions._remove_batch_dim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'level', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'batch_size', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'out_dim', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions._linalg_det: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._linalg_det: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.det: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._linalg_slogdet: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._linalg_slogdet: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.slogdet: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.slogdet: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.logdet: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._linalg_eigh: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._linalg_eigh: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.inverse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.inverse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.inner: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.inner: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.outer: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec2', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.outer: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec2', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.ger: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec2', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.ger: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec2', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._linalg_svd: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._linalg_svd: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._linalg_solve_ex: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._linalg_solve_ex: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._test_serialization_subcmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._test_parallel_materialize: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'num_parallel', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions._test_autograd_multiple_dispatch: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._test_autograd_multiple_dispatch: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'b', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions._test_autograd_multiple_dispatch_view: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._test_autograd_multiple_dispatch_view_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._test_autograd_multiple_dispatch_view_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.segment_reduce: [{'is_kwarg_only': 'False', 'name': 'data', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'reduce', 'simple_type': 'c10::string_view'}],
+    torch._C._VariableFunctions._nested_tensor_from_tensor_list: [{'is_kwarg_only': 'False', 'name': 'list', 'simple_type': 'TensorList'}],
+    torch._C._VariableFunctions._fw_primal_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'level', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions._fw_primal_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'level', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions._make_dual_copy: [{'is_kwarg_only': 'False', 'name': 'primal', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tangent', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'level', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions._make_dual_copy: [{'is_kwarg_only': 'False', 'name': 'primal', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tangent', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'level', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.view_as_real_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.view_as_real_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.view_as_complex_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.view_as_complex_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._conj_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._conj_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._neg_view_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._neg_view_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.as_strided_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}],
+    torch._C._VariableFunctions.as_strided_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}],
+    torch._C._VariableFunctions._sparse_broadcast_to_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}],
+    torch._C._VariableFunctions._sparse_broadcast_to_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}],
+    torch._C._VariableFunctions.diagonal_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.diagonal_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.expand_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}],
+    torch._C._VariableFunctions.expand_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}],
+    torch._C._VariableFunctions.permute_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims', 'simple_type': 'IntArrayRef'}],
+    torch._C._VariableFunctions.permute_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims', 'simple_type': 'IntArrayRef'}],
+    torch._C._VariableFunctions._reshape_alias_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}],
+    torch._C._VariableFunctions._reshape_alias_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}],
+    torch._C._VariableFunctions.select_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'SymInt'}],
+    torch._C._VariableFunctions.select_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'SymInt'}],
+    torch._C._VariableFunctions.detach_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.detach_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.slice_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.slice_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.split_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_size', 'simple_type': 'SymInt'}],
+    torch._C._VariableFunctions.split_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_size', 'simple_type': 'SymInt'}],
+    torch._C._VariableFunctions.split_with_sizes_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_sizes', 'simple_type': 'SymIntArrayRef'}],
+    torch._C._VariableFunctions.split_with_sizes_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_sizes', 'simple_type': 'SymIntArrayRef'}],
+    torch._C._VariableFunctions.squeeze_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.squeeze_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.squeeze_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef'}],
+    torch._C._VariableFunctions.squeeze_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.squeeze_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.squeeze_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef'}],
+    torch._C._VariableFunctions.t_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.t_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.transpose_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.transpose_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.unsqueeze_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.unsqueeze_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions._indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._values_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._values_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.values_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.values_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.crow_indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.crow_indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.col_indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.col_indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.ccol_indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.ccol_indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.row_indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.row_indices_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.unbind_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.unbind_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.view_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}],
+    torch._C._VariableFunctions.view_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dtype', 'simple_type': 'ScalarType'}],
+    torch._C._VariableFunctions.view_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}],
+    torch._C._VariableFunctions.view_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dtype', 'simple_type': 'ScalarType'}],
+    torch._C._VariableFunctions.unfold_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dimension', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'step', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.unfold_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dimension', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'step', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions.alias_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions.alias_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._nested_from_padded_tensor: [{'is_kwarg_only': 'False', 'name': 'padded', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'offsets', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dummy', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._nested_tensor_softmax_with_shape: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'query', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._safe_softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions._transformer_encoder_layer_fwd: [{'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'embed_dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'num_heads', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'qkv_weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'qkv_bias', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'proj_weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'proj_bias', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'use_gelu', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'norm_first', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'norm_weight_1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'norm_bias_1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'norm_weight_2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'norm_bias_2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'ffn_weight_1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'ffn_bias_1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'ffn_weight_2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'ffn_bias_2', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._native_multi_head_attention: [{'is_kwarg_only': 'False', 'name': 'query', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'key', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'embed_dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'num_head', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'qkv_weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'qkv_bias', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'proj_weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'proj_bias', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._fused_sdp_choice: [{'is_kwarg_only': 'False', 'name': 'query', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'key', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._scaled_dot_product_attention_math: [{'is_kwarg_only': 'False', 'name': 'query', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'key', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._scaled_dot_product_attention_math_for_mps: [{'is_kwarg_only': 'False', 'name': 'query', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'key', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._scaled_dot_product_flash_attention: [{'is_kwarg_only': 'False', 'name': 'query', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'key', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._scaled_dot_product_flash_attention_for_cpu: [{'is_kwarg_only': 'False', 'name': 'query', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'key', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._scaled_dot_product_efficient_attention: [{'is_kwarg_only': 'False', 'name': 'query', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'key', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'attn_bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'compute_log_sumexp', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions._scaled_dot_product_cudnn_attention: [{'is_kwarg_only': 'False', 'name': 'query', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'key', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'attn_bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'compute_log_sumexp', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions._triton_scaled_dot_attention: [{'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'k', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'v', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._fill_mem_eff_dropout_mask_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dropout_p', 'simple_type': 'double'}, {'is_kwarg_only': 'False', 'name': 'seed', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'offset', 'simple_type': 'int64_t'}],
+    torch._C._VariableFunctions._triton_multi_head_attention: [{'is_kwarg_only': 'False', 'name': 'query', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'key', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'embed_dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'num_head', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'qkv_weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'qkv_bias', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'proj_weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'proj_bias', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._foobar: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._VariableFunctions._fused_adam_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'grads', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exp_avgs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exp_avg_sqs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'max_exp_avg_sqs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'state_steps', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'True', 'name': 'lr', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'beta1', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'beta2', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'weight_decay', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'amsgrad', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'maximize', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions._fused_adam_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'grads', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exp_avgs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exp_avg_sqs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'max_exp_avg_sqs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'state_steps', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'True', 'name': 'lr', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'beta1', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'beta2', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'weight_decay', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'amsgrad', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'maximize', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions._fused_adamw_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'grads', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exp_avgs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exp_avg_sqs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'max_exp_avg_sqs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'state_steps', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'True', 'name': 'lr', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'beta1', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'beta2', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'weight_decay', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'amsgrad', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'maximize', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions._fused_adamw_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'grads', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exp_avgs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'exp_avg_sqs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'max_exp_avg_sqs', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'state_steps', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'True', 'name': 'lr', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'beta1', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'beta2', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'weight_decay', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'amsgrad', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'maximize', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions._fused_sgd_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'grads', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'momentum_buffer_list', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'True', 'name': 'weight_decay', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'lr', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'dampening', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'nesterov', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'maximize', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'is_first_step', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions._fused_sgd_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'grads', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'momentum_buffer_list', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'True', 'name': 'weight_decay', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'momentum', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'lr', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'dampening', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'nesterov', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'maximize', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'is_first_step', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions._fused_adagrad_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'grads', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'state_sums', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'state_steps', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'True', 'name': 'lr', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'lr_decay', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'weight_decay', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'maximize', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions._fused_adagrad_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'grads', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'state_sums', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'False', 'name': 'state_steps', 'simple_type': 'TensorList'}, {'is_kwarg_only': 'True', 'name': 'lr', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'lr_decay', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'weight_decay', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'eps', 'simple_type': 'double'}, {'is_kwarg_only': 'True', 'name': 'maximize', 'simple_type': 'bool'}],
+    torch._C._VariableFunctions._propagate_xla_data: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output', 'simple_type': 'Tensor'}],
+    torch._C._nn.binary_cross_entropy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}],
+    torch._C._nn.binary_cross_entropy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}],
+    torch._C._nn.linear: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}],
+    torch._C._nn.linear: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}],
+    torch._C._nn.mkldnn_linear: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}],
+    torch._C._nn.relu6: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._nn.relu6_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._nn.gelu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._nn.gelu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._nn.gelu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._nn.silu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._nn.silu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._nn.silu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._nn.mish: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._nn.mish: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._nn.mish_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._nn.one_hot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._nn.mkldnn_reorder_conv2d_weight: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._nn.mkldnn_reorder_conv3d_weight: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._nn.cross_entropy_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}],
+    torch._C._nn.mse_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}],
+    torch._C._nn.mse_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}],
+    torch._C._nn.l1_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}],
+    torch._C._nn.multi_margin_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}],
+    torch._C._nn.multi_margin_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}],
+    torch._C._nn.multilabel_margin_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}],
+    torch._C._nn.multilabel_margin_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}],
+    torch._C._nn.nll_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}],
+    torch._C._nn.nll_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}],
+    torch._C._nn.nll_loss_nd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}],
+    torch._C._nn.nll_loss2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}],
+    torch._C._nn.nll_loss2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}],
+    torch._C._nn.smooth_l1_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}],
+    torch._C._nn.smooth_l1_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}],
+    torch._C._nn.huber_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}],
+    torch._C._nn.huber_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}],
+    torch._C._nn.soft_margin_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}],
+    torch._C._nn.soft_margin_loss: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'target', 'simple_type': 'Tensor'}],
+    torch._C._nn.elu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._nn.elu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._nn.elu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._nn.glu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._nn.glu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._nn.hardsigmoid: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._nn.hardsigmoid: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._nn.hardsigmoid_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._nn.hardtanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._nn.hardtanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._nn.hardtanh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._nn.hardswish: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._nn.hardswish: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._nn.hardswish_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._nn.leaky_relu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._nn.leaky_relu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._nn.leaky_relu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._nn.log_sigmoid: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._nn.log_sigmoid: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._nn.rrelu_with_noise: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'noise', 'simple_type': 'Tensor'}],
+    torch._C._nn.rrelu_with_noise: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'noise', 'simple_type': 'Tensor'}],
+    torch._C._nn.rrelu_with_noise_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'noise', 'simple_type': 'Tensor'}],
+    torch._C._nn.softplus: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._nn.softplus: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._nn.softshrink: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._nn.softshrink: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._nn.adaptive_avg_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}],
+    torch._C._nn.adaptive_avg_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}],
+    torch._C._nn.adaptive_avg_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 3}],
+    torch._C._nn.adaptive_avg_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 3}],
+    torch._C._nn.adaptive_max_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 2}],
+    torch._C._nn.adaptive_max_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 2}],
+    torch._C._nn.adaptive_max_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 3}],
+    torch._C._nn.adaptive_max_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 3}],
+    torch._C._nn.avg_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}],
+    torch._C._nn.avg_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}],
+    torch._C._nn.avg_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 3}],
+    torch._C._nn.avg_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 3}],
+    torch._C._nn.fractional_max_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'random_samples', 'simple_type': 'Tensor'}],
+    torch._C._nn.fractional_max_pool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'random_samples', 'simple_type': 'Tensor'}],
+    torch._C._nn.fractional_max_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'random_samples', 'simple_type': 'Tensor'}],
+    torch._C._nn.fractional_max_pool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'IntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'random_samples', 'simple_type': 'Tensor'}],
+    torch._C._nn.max_pool2d_with_indices: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}],
+    torch._C._nn.max_pool2d_with_indices: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}],
+    torch._C._nn.max_pool3d_with_indices: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 3}],
+    torch._C._nn.max_pool3d_with_indices: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 3}],
+    torch._C._nn.max_unpool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}],
+    torch._C._nn.max_unpool2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}],
+    torch._C._nn.max_unpool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'IntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'IntArrayRef', 'size': 3}],
+    torch._C._nn.max_unpool3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'IntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'IntArrayRef', 'size': 3}],
+    torch._C._nn.reflection_pad1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 2}],
+    torch._C._nn.reflection_pad1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 2}],
+    torch._C._nn.reflection_pad2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 4}],
+    torch._C._nn.reflection_pad2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 4}],
+    torch._C._nn.reflection_pad3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 6}],
+    torch._C._nn.reflection_pad3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 6}],
+    torch._C._nn.replication_pad1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 2}],
+    torch._C._nn.replication_pad1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 2}],
+    torch._C._nn.replication_pad2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 4}],
+    torch._C._nn.replication_pad2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 4}],
+    torch._C._nn.replication_pad3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 6}],
+    torch._C._nn.replication_pad3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 6}],
+    torch._C._nn._pad_circular: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'pad', 'simple_type': 'SymIntArrayRef'}],
+    torch._C._nn._pad_enum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'pad', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'mode', 'simple_type': 'int64_t'}],
+    torch._C._nn.pad: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'pad', 'simple_type': 'SymIntArrayRef'}],
+    torch._C._nn.upsample_linear1d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}],
+    torch._C._nn.upsample_linear1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 1}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}],
+    torch._C._nn.upsample_linear1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 1}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}],
+    torch._C._nn.upsample_bilinear2d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}],
+    torch._C._nn.upsample_bilinear2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}],
+    torch._C._nn.upsample_bilinear2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}],
+    torch._C._nn._upsample_bilinear2d_aa: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}],
+    torch._C._nn._upsample_bilinear2d_aa: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}],
+    torch._C._nn._upsample_bilinear2d_aa: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}],
+    torch._C._nn.upsample_trilinear3d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}],
+    torch._C._nn.upsample_trilinear3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}],
+    torch._C._nn.upsample_trilinear3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}],
+    torch._C._nn.upsample_bicubic2d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}],
+    torch._C._nn.upsample_bicubic2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}],
+    torch._C._nn.upsample_bicubic2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}],
+    torch._C._nn._upsample_bicubic2d_aa: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}],
+    torch._C._nn._upsample_bicubic2d_aa: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}],
+    torch._C._nn._upsample_bicubic2d_aa: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'align_corners', 'simple_type': 'bool'}],
+    torch._C._nn.upsample_nearest1d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}],
+    torch._C._nn.upsample_nearest1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 1}],
+    torch._C._nn.upsample_nearest1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 1}],
+    torch._C._nn._upsample_nearest_exact1d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}],
+    torch._C._nn._upsample_nearest_exact1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 1}],
+    torch._C._nn._upsample_nearest_exact1d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 1}],
+    torch._C._nn.upsample_nearest2d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}],
+    torch._C._nn.upsample_nearest2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}],
+    torch._C._nn.upsample_nearest2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}],
+    torch._C._nn._upsample_nearest_exact2d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}],
+    torch._C._nn._upsample_nearest_exact2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}],
+    torch._C._nn._upsample_nearest_exact2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}],
+    torch._C._nn.upsample_nearest3d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}],
+    torch._C._nn.upsample_nearest3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 3}],
+    torch._C._nn.upsample_nearest3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 3}],
+    torch._C._nn._upsample_nearest_exact3d: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef?'}, {'is_kwarg_only': 'False', 'name': 'scale_factors', 'simple_type': 'ArrayRef?'}],
+    torch._C._nn._upsample_nearest_exact3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 3}],
+    torch._C._nn._upsample_nearest_exact3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 3}],
+    torch._C._nn.slow_conv_transpose2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 2}],
+    torch._C._nn.slow_conv_transpose2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 2}],
+    torch._C._nn.slow_conv_transpose3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 3}],
+    torch._C._nn.slow_conv_transpose3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 3}],
+    torch._C._nn.thnn_conv2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 2}],
+    torch._C._nn.thnn_conv2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 2}],
+    torch._C._nn._conv_depthwise2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef', 'size': 2}],
+    torch._C._nn._conv_depthwise2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef', 'size': 2}],
+    torch._C._nn.conv_depthwise3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'bias', 'simple_type': 'Tensor?'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'SymIntArrayRef', 'size': 3}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'SymIntArrayRef', 'size': 3}],
+    torch._C._nn.slow_conv3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 3}],
+    torch._C._nn.slow_conv3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 3}],
+    torch._C._nn.slow_conv_dilated2d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 2}],
+    torch._C._nn.slow_conv_dilated3d: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'SymIntArrayRef', 'size': 3}],
+    torch._C._nn.col2im: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'IntArrayRef', 'size': 2}],
+    torch._C._nn.col2im: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'output_size', 'simple_type': 'SymIntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'IntArrayRef', 'size': 2}],
+    torch._C._nn.im2col: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'IntArrayRef', 'size': 2}],
+    torch._C._nn.im2col: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'kernel_size', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'dilation', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'IntArrayRef', 'size': 2}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'IntArrayRef', 'size': 2}],
+    torch._C._nn._test_optional_intlist: [{'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'addends', 'simple_type': 'IntArrayRef?'}],
+    torch._C._nn._test_optional_filled_intlist: [{'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'addends', 'simple_type': 'IntArrayRef?', 'size': 2}],
+    torch._C._nn._test_optional_floatlist: [{'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'addends', 'simple_type': 'ArrayRef?'}],
+    torch._C._nn._test_string_default: [{'is_kwarg_only': 'False', 'name': 'dummy', 'simple_type': 'Tensor'}],
+    torch._C._nn._test_ambiguous_defaults: [{'is_kwarg_only': 'False', 'name': 'dummy', 'simple_type': 'Tensor'}],
+    torch._C._nn._test_ambiguous_defaults: [{'is_kwarg_only': 'False', 'name': 'dummy', 'simple_type': 'Tensor'}],
+    torch._C._nn._test_warn_in_autograd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._nn.pad_sequence: [{'is_kwarg_only': 'False', 'name': 'sequences', 'simple_type': 'TensorList'}],
+    torch._C._nn.flatten_dense_tensors: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}],
+    torch._C._nn.unflatten_dense_tensors: [{'is_kwarg_only': 'False', 'name': 'flat', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}],
+    torch._C._nn.scaled_dot_product_attention: [{'is_kwarg_only': 'False', 'name': 'query', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'key', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_diagonal: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_solve_triangular: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'upper', 'simple_type': 'bool'}],
+    torch._C._linalg.linalg_solve_triangular: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'upper', 'simple_type': 'bool'}],
+    torch._C._linalg.linalg_vander: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_cholesky_ex: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_cholesky_ex: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_cholesky: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_cholesky: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_cross: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_cross: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_lu_factor: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_lu_factor: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_lu_factor_ex: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_lu_factor_ex: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_lu: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_lu: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_lu_solve: [{'is_kwarg_only': 'False', 'name': 'LU', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'pivots', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_lu_solve: [{'is_kwarg_only': 'False', 'name': 'LU', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'pivots', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_det: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_det: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_ldl_factor_ex: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_ldl_factor_ex: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_ldl_factor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_ldl_factor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_ldl_solve: [{'is_kwarg_only': 'False', 'name': 'LD', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'pivots', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_ldl_solve: [{'is_kwarg_only': 'False', 'name': 'LD', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'pivots', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_lstsq: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'b', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_lstsq: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'b', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_matmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_matmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_vecdot: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'y', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_vecdot: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'y', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_matrix_exp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_slogdet: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_slogdet: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_eig: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_eig: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._linalg._linalg_eigvals: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_eigvals: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_eigvals: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_eigh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_eigh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_eigvalsh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_eigvalsh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_householder_product: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tau', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_householder_product: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tau', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_inv_ex: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_inv_ex: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_inv: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_inv: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'ord', 'simple_type': 'c10::string_view'}],
+    torch._C._linalg.linalg_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'ord', 'simple_type': 'c10::string_view'}],
+    torch._C._linalg.linalg_vector_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_vector_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_matrix_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'ord', 'simple_type': 'Scalar'}],
+    torch._C._linalg.linalg_matrix_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'ord', 'simple_type': 'Scalar'}],
+    torch._C._linalg.linalg_matrix_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_matrix_norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_svd: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_svd: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_svdvals: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_svdvals: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_cond: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_cond: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_cond: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'c10::string_view'}],
+    torch._C._linalg.linalg_cond: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'c10::string_view'}],
+    torch._C._linalg.linalg_pinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_pinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_pinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_pinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_pinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'rcond', 'simple_type': 'double'}],
+    torch._C._linalg.linalg_pinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'rcond', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_pinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'rcond', 'simple_type': 'double'}],
+    torch._C._linalg.linalg_pinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'rcond', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_solve_ex: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_solve_ex: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_solve: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_solve: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'B', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_tensorinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_tensorinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_tensorsolve: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_tensorsolve: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_qr: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_qr: [{'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_matrix_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}],
+    torch._C._linalg.linalg_matrix_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}],
+    torch._C._linalg.linalg_matrix_rank: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_matrix_rank: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_matrix_rank: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_matrix_rank: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_matrix_rank: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tol', 'simple_type': 'double'}],
+    torch._C._linalg.linalg_matrix_rank: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tol', 'simple_type': 'double'}],
+    torch._C._linalg.linalg_matrix_rank: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tol', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_matrix_rank: [{'is_kwarg_only': 'False', 'name': 'input', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tol', 'simple_type': 'Tensor'}],
+    torch._C._linalg.linalg_multi_dot: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}],
+    torch._C._linalg.linalg_multi_dot: [{'is_kwarg_only': 'False', 'name': 'tensors', 'simple_type': 'TensorList'}],
+    torch._C._special.special_entr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_entr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_ndtri: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_ndtri: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_log_ndtr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_log_ndtr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_expm1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_expm1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_exp2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_exp2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_psi: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_psi: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_digamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_digamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_gammaln: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_gammaln: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_erf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_erf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_erfc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_erfc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_erfcx: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_erfcx: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_erfinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_erfinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_ndtr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_ndtr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_xlog1py: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._special.special_xlog1py: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._special.special_xlog1py: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._special.special_xlog1py: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._special.special_xlog1py: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._special.special_xlog1py: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._special.special_xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._special.special_xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._special.special_xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._special.special_xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._special.special_xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._special.special_xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._special.special_zeta: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._special.special_zeta: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._special.special_zeta: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._special.special_zeta: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._special.special_zeta: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._special.special_zeta: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch._C._special.special_i0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_i0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_i0e: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_i0e: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_i1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_i1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_i1e: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_i1e: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_logit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_logit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_polygamma: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_polygamma: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_logsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}],
+    torch._C._special.special_logsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}],
+    torch._C._special.special_expit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_expit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_sinc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_sinc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_round: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_round: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_log1p: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_log1p: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_log_softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch._C._special.special_gammainc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._special.special_gammainc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._special.special_gammaincc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._special.special_gammaincc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch._C._special.special_multigammaln: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'int64_t'}],
+    torch._C._special.special_multigammaln: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'int64_t'}],
+    torch._C._special.special_softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch._C._special.special_airy_ai: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}],
+    torch._C._special.special_airy_ai: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}],
+    torch._C._special.special_bessel_j0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_bessel_j0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_bessel_j1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_bessel_j1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_bessel_y0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_bessel_y0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_bessel_y1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_bessel_y1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}],
+    torch._C._special.special_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}],
+    torch._C._special.special_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}],
+    torch._C._special.special_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}],
+    torch._C._special.special_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}],
+    torch._C._special.special_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}],
+    torch._C._special.special_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}],
+    torch._C._special.special_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}],
+    torch._C._special.special_hermite_polynomial_h: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_hermite_polynomial_h: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_hermite_polynomial_h: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}],
+    torch._C._special.special_hermite_polynomial_h: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_hermite_polynomial_h: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_hermite_polynomial_h: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}],
+    torch._C._special.special_hermite_polynomial_he: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_hermite_polynomial_he: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_hermite_polynomial_he: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}],
+    torch._C._special.special_hermite_polynomial_he: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_hermite_polynomial_he: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_hermite_polynomial_he: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}],
+    torch._C._special.special_laguerre_polynomial_l: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_laguerre_polynomial_l: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_laguerre_polynomial_l: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}],
+    torch._C._special.special_laguerre_polynomial_l: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_laguerre_polynomial_l: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_laguerre_polynomial_l: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}],
+    torch._C._special.special_legendre_polynomial_p: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_legendre_polynomial_p: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_legendre_polynomial_p: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}],
+    torch._C._special.special_legendre_polynomial_p: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_legendre_polynomial_p: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_legendre_polynomial_p: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}],
+    torch._C._special.special_modified_bessel_i0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_modified_bessel_i0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_modified_bessel_i1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_modified_bessel_i1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_modified_bessel_k0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_modified_bessel_k0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_modified_bessel_k1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_modified_bessel_k1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._special.special_scaled_modified_bessel_k0: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}],
+    torch._C._special.special_scaled_modified_bessel_k0: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}],
+    torch._C._special.special_scaled_modified_bessel_k1: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}],
+    torch._C._special.special_scaled_modified_bessel_k1: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}],
+    torch._C._special.special_shifted_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_shifted_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_shifted_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}],
+    torch._C._special.special_shifted_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_shifted_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_shifted_chebyshev_polynomial_t: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}],
+    torch._C._special.special_shifted_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_shifted_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_shifted_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}],
+    torch._C._special.special_shifted_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_shifted_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_shifted_chebyshev_polynomial_u: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}],
+    torch._C._special.special_shifted_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_shifted_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_shifted_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}],
+    torch._C._special.special_shifted_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_shifted_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_shifted_chebyshev_polynomial_v: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}],
+    torch._C._special.special_shifted_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_shifted_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_shifted_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}],
+    torch._C._special.special_shifted_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_shifted_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Tensor'}],
+    torch._C._special.special_shifted_chebyshev_polynomial_w: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'Scalar'}],
+    torch._C._special.special_spherical_bessel_j0: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}],
+    torch._C._special.special_spherical_bessel_j0: [{'is_kwarg_only': 'False', 'name': 'x', 'simple_type': 'Tensor'}],
+    torch._C._fft.fft_fft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._fft.fft_fft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._fft.fft_ifft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._fft.fft_ifft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._fft.fft_rfft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._fft.fft_rfft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._fft.fft_irfft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._fft.fft_irfft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._fft.fft_hfft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._fft.fft_hfft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._fft.fft_ihfft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._fft.fft_ihfft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._fft.fft_fft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._fft.fft_fft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._fft.fft_ifft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._fft.fft_ifft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._fft.fft_rfft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._fft.fft_rfft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._fft.fft_irfft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._fft.fft_irfft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._fft.fft_hfft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._fft.fft_hfft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._fft.fft_ihfft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._fft.fft_ihfft2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._fft.fft_fftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._fft.fft_fftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._fft.fft_ifftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._fft.fft_ifftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._fft.fft_rfftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._fft.fft_rfftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._fft.fft_irfftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._fft.fft_irfftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._fft.fft_hfftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._fft.fft_hfftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._fft.fft_ihfftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._fft.fft_ihfftn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._fft.fft_fftfreq: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}],
+    torch._C._fft.fft_fftfreq: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}],
+    torch._C._fft.fft_rfftfreq: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}],
+    torch._C._fft.fft_rfftfreq: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}],
+    torch._C._fft.fft_fftshift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch._C._fft.fft_ifftshift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.retain_grad: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.rename_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'names', 'simple_type': 'DimnameList?'}],
+    torch.Tensor.rename: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'names', 'simple_type': 'DimnameList?'}],
+    torch.Tensor.align_to: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'names', 'simple_type': 'DimnameList'}],
+    torch.Tensor.align_to: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'order', 'simple_type': 'DimnameList'}, {'is_kwarg_only': 'False', 'name': 'ellipsis_idx', 'simple_type': 'int64_t'}],
+    torch.Tensor.align_as: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.refine_names: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'names', 'simple_type': 'DimnameList'}],
+    torch.Tensor.abs: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.abs_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.absolute: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.absolute_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.angle: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.sgn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.sgn_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.chalf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor._conj: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.conj: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor._conj_physical: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.conj_physical: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.conj_physical_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.resolve_conj: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.resolve_neg: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor._neg_view: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.acos: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.acos_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.arccos: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.arccos_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.add_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.addmv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec', 'simple_type': 'Tensor'}],
+    torch.Tensor.addmv_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec', 'simple_type': 'Tensor'}],
+    torch.Tensor.addr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec2', 'simple_type': 'Tensor'}],
+    torch.Tensor.addr_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec2', 'simple_type': 'Tensor'}],
+    torch.Tensor._is_all_true: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor._is_any_true: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch.Tensor.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch.Tensor.all: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.allclose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch.Tensor.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch.Tensor.any: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.argmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.argmin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.acosh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.acosh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.arccosh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.arccosh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.asinh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.asinh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.arcsinh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.arcsinh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.atanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.atanh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.arctanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.arctanh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.as_strided: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}],
+    torch.Tensor.as_strided_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}],
+    torch.Tensor.asin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.asin_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.arcsin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.arcsin_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.atan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.atan_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.arctan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.arctan_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.baddbmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch2', 'simple_type': 'Tensor'}],
+    torch.Tensor.baddbmm_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch2', 'simple_type': 'Tensor'}],
+    torch.Tensor.bernoulli: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.bernoulli: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}],
+    torch.Tensor.bernoulli_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Tensor'}],
+    torch.Tensor.bernoulli_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.bincount: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.bitwise_not: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.bitwise_not_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.copysign: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.copysign: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.copysign_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.copysign_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor._lazy_clone: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.logical_not: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.logical_not_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.logical_xor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.logical_xor_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.logical_and: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.logical_and_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.logical_or: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.logical_or_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.bmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}],
+    torch.Tensor.broadcast_to: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}],
+    torch.Tensor.ceil: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.ceil_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.unsafe_chunk: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'chunks', 'simple_type': 'int64_t'}],
+    torch.Tensor.chunk: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'chunks', 'simple_type': 'int64_t'}],
+    torch.Tensor.tensor_split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'sections', 'simple_type': 'SymInt'}],
+    torch.Tensor.tensor_split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'SymIntArrayRef'}],
+    torch.Tensor.tensor_split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor_indices_or_sections', 'simple_type': 'Tensor'}],
+    torch.Tensor.clamp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.clamp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.clamp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.clamp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.clamp_max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'Scalar'}],
+    torch.Tensor.clamp_max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'Tensor'}],
+    torch.Tensor.clamp_max_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'Scalar'}],
+    torch.Tensor.clamp_max_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'max', 'simple_type': 'Tensor'}],
+    torch.Tensor.clamp_min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'Scalar'}],
+    torch.Tensor.clamp_min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'Tensor'}],
+    torch.Tensor.clamp_min_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'Scalar'}],
+    torch.Tensor.clamp_min_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'min', 'simple_type': 'Tensor'}],
+    torch.Tensor.clip: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.clip: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.clip_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.clip_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.cos: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.cos_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.cosh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.cosh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.count_nonzero: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef'}],
+    torch.Tensor.count_nonzero: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.cov: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.corrcoef: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.cummax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch.Tensor.cummax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch.Tensor.cummin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch.Tensor.cummin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch.Tensor.cumprod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch.Tensor.cumprod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch.Tensor.cumprod_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch.Tensor.cumprod_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch.Tensor.cumsum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch.Tensor.cumsum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch.Tensor.cumsum_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch.Tensor.cumsum_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch.Tensor.diag_embed: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.diagflat: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.diagonal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.diagonal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.fill_diagonal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'fill_value', 'simple_type': 'Scalar'}],
+    torch.Tensor.diff: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}],
+    torch.Tensor.div: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}],
+    torch.Tensor.div_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.div_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}],
+    torch.Tensor.div_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}],
+    torch.Tensor.divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}],
+    torch.Tensor.divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}],
+    torch.Tensor.divide_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.divide_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.divide_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}],
+    torch.Tensor.divide_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'True', 'name': 'rounding_mode', 'simple_type': 'c10::string_view?'}],
+    torch.Tensor.true_divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.true_divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.true_divide_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.true_divide_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.dot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor', 'simple_type': 'Tensor'}],
+    torch.Tensor.vdot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.new_empty: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}],
+    torch.Tensor.new_empty_strided: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}],
+    torch.Tensor.new_full: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'fill_value', 'simple_type': 'Scalar'}],
+    torch.Tensor.new_zeros: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}],
+    torch.Tensor.new_ones: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}],
+    torch.Tensor.resize_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}],
+    torch.Tensor.erf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.erf_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.erfc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.erfc_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.exp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.exp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.exp2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.exp2_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.expm1: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.expm1_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.expand: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}],
+    torch.Tensor.expand_as: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.flatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.flatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'start_dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'end_dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'out_dim', 'simple_type': 'Dimname'}],
+    torch.Tensor.flatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'start_dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'end_dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'out_dim', 'simple_type': 'Dimname'}],
+    torch.Tensor.flatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims', 'simple_type': 'DimnameList'}, {'is_kwarg_only': 'False', 'name': 'out_dim', 'simple_type': 'Dimname'}],
+    torch.Tensor.unflatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'sizes', 'simple_type': 'SymIntArrayRef'}],
+    torch.Tensor.unflatten: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'sizes', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'names', 'simple_type': 'DimnameList'}],
+    torch.Tensor.fill_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}],
+    torch.Tensor.fill_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}],
+    torch.Tensor.floor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.floor_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.floor_divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.floor_divide: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.floor_divide_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.floor_divide_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.frac: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.frac_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.gcd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.gcd_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.lcm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.lcm_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.index_copy_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}],
+    torch.Tensor.index_copy_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}],
+    torch.Tensor.index_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}],
+    torch.Tensor.index_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}],
+    torch.Tensor.index_put_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'c10::List<::std::optional>'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}],
+    torch.Tensor.index_put: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'c10::List<::std::optional>'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}],
+    torch.Tensor.isclose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.isnan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.is_distributed: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.is_floating_point: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.is_complex: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.is_conj: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor._is_zerotensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.is_neg: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.isreal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.is_nonzero: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.is_same_size: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.is_signed: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.is_inference: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.kron: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.kthvalue: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'k', 'simple_type': 'SymInt'}],
+    torch.Tensor.kthvalue: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'k', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch.Tensor.nan_to_num: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.nan_to_num_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.ldexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.ldexp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.log: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.log_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.log10: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.log10_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.log1p: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.log1p_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.log2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.log2_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.logaddexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.logaddexp2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.xlogy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.xlogy_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.xlogy_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.log_softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch.Tensor.log_softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch.Tensor.logcumsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch.Tensor.logcumsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch.Tensor.logsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}],
+    torch.Tensor.logsumexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}],
+    torch.Tensor.matmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.matrix_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}],
+    torch.Tensor.matrix_exp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.aminmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch.Tensor.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch.Tensor.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.max: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.amax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}],
+    torch.Tensor.mean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}],
+    torch.Tensor.nanmean: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.median: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.median: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch.Tensor.median: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch.Tensor.nanmedian: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.nanmedian: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch.Tensor.nanmedian: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch.Tensor.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch.Tensor.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch.Tensor.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.min: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.amin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.mm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}],
+    torch.Tensor.mode: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.mode: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch.Tensor.mul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.mul_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.multiply: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.multiply: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.multiply_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.multiply_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.mv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec', 'simple_type': 'Tensor'}],
+    torch.Tensor.mvlgamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'int64_t'}],
+    torch.Tensor.mvlgamma_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'int64_t'}],
+    torch.Tensor.narrow_copy: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'length', 'simple_type': 'SymInt'}],
+    torch.Tensor.narrow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'SymInt'}, {'is_kwarg_only': 'False', 'name': 'length', 'simple_type': 'SymInt'}],
+    torch.Tensor.narrow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'start', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'length', 'simple_type': 'SymInt'}],
+    torch.Tensor.permute: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims', 'simple_type': 'IntArrayRef'}],
+    torch.Tensor.movedim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'destination', 'simple_type': 'IntArrayRef'}],
+    torch.Tensor.movedim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'destination', 'simple_type': 'int64_t'}],
+    torch.Tensor.moveaxis: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'destination', 'simple_type': 'IntArrayRef'}],
+    torch.Tensor.moveaxis: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'destination', 'simple_type': 'int64_t'}],
+    torch.Tensor.adjoint: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.is_pinned: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.pin_memory: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.pinverse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.rad2deg: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.rad2deg_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.deg2rad: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.deg2rad_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.ravel: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.reciprocal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.reciprocal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.neg: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.neg_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.negative: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.negative_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.repeat: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'repeats', 'simple_type': 'SymIntArrayRef'}],
+    torch.Tensor.repeat_interleave: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'repeats', 'simple_type': 'Tensor'}],
+    torch.Tensor.repeat_interleave: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'repeats', 'simple_type': 'SymInt'}],
+    torch.Tensor.reshape: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'shape', 'simple_type': 'SymIntArrayRef'}],
+    torch.Tensor.reshape_as: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.round: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.round: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.round_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.round_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.relu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.relu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.prelu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}],
+    torch.Tensor.hardshrink: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.rsqrt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.rsqrt_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'int64_t'}],
+    torch.Tensor.select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'SymInt'}],
+    torch.Tensor.sigmoid: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.sigmoid_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.logit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.logit_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.sin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.sin_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.sinc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.sinc_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.sinh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.sinh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.detach: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.detach_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.slice_inverse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}],
+    torch.Tensor.slice_scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}],
+    torch.Tensor.select_scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'SymInt'}],
+    torch.Tensor.diagonal_scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}],
+    torch.Tensor.as_strided_scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'stride', 'simple_type': 'SymIntArrayRef'}],
+    torch.Tensor.smm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}],
+    torch.Tensor.softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch.Tensor.softmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch.Tensor.unsafe_split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_size', 'simple_type': 'SymInt'}],
+    torch.Tensor.split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_size', 'simple_type': 'SymInt'}],
+    torch.Tensor.split: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_size', 'simple_type': 'SymIntArrayRef'}],
+    torch.Tensor.unsafe_split_with_sizes: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_sizes', 'simple_type': 'SymIntArrayRef'}],
+    torch.Tensor.split_with_sizes: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'split_sizes', 'simple_type': 'SymIntArrayRef'}],
+    torch.Tensor.hsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'sections', 'simple_type': 'int64_t'}],
+    torch.Tensor.hsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'IntArrayRef'}],
+    torch.Tensor.vsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'sections', 'simple_type': 'int64_t'}],
+    torch.Tensor.vsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'IntArrayRef'}],
+    torch.Tensor.dsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'sections', 'simple_type': 'int64_t'}],
+    torch.Tensor.dsplit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'IntArrayRef'}],
+    torch.Tensor.squeeze: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.squeeze: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch.Tensor.squeeze: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch.Tensor.squeeze: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef'}],
+    torch.Tensor.squeeze_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.squeeze_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch.Tensor.squeeze_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef'}],
+    torch.Tensor.squeeze_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch.Tensor.sspaddmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}],
+    torch.Tensor.stft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n_fft', 'simple_type': 'int64_t'}],
+    torch.Tensor.stft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n_fft', 'simple_type': 'int64_t'}],
+    torch.Tensor.istft: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n_fft', 'simple_type': 'int64_t'}],
+    torch.Tensor.sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}],
+    torch.Tensor.sum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}],
+    torch.Tensor.nansum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.sum_to_size: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}],
+    torch.Tensor.sqrt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.sqrt_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.square: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.square_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}],
+    torch.Tensor.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}],
+    torch.Tensor.std: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}],
+    torch.Tensor.prod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.prod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch.Tensor.prod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch.Tensor.t: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.t_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.tan: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.tan_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.tanh: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.tanh_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.tile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims', 'simple_type': 'SymIntArrayRef'}],
+    torch.Tensor.transpose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'int64_t'}],
+    torch.Tensor.transpose: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'Dimname'}],
+    torch.Tensor.transpose_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'int64_t'}],
+    torch.Tensor.flip: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dims', 'simple_type': 'IntArrayRef'}],
+    torch.Tensor.fliplr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.flipud: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.roll: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'shifts', 'simple_type': 'SymIntArrayRef', 'size': 1}],
+    torch.Tensor.rot90: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor._nested_tensor_size: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor._nested_tensor_strides: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor._nested_tensor_storage_offsets: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.trunc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.trunc_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.fix: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.fix_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.type_as: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.unsqueeze: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch.Tensor.unsqueeze_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}],
+    torch.Tensor.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef?', 'size': 1}],
+    torch.Tensor.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}],
+    torch.Tensor.var: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}],
+    torch.Tensor.view_as: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.where: [{'is_kwarg_only': 'False', 'name': 'condition', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.where: [{'is_kwarg_only': 'False', 'name': 'condition', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'True', 'name': 'dtype', 'simple_type': 'ScalarType'}],
+    torch.Tensor.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}, {'is_kwarg_only': 'False', 'name': 'keepdim', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'dtype', 'simple_type': 'ScalarType'}],
+    torch.Tensor.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'IntArrayRef', 'size': 1}],
+    torch.Tensor.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}, {'is_kwarg_only': 'False', 'name': 'keepdim', 'simple_type': 'bool'}, {'is_kwarg_only': 'True', 'name': 'dtype', 'simple_type': 'ScalarType'}],
+    torch.Tensor.norm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar?'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'DimnameList', 'size': 1}],
+    torch.Tensor.frexp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.clone: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.positive: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.resize_as_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'the_template', 'simple_type': 'Tensor'}],
+    torch.Tensor.resize_as_sparse_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'the_template', 'simple_type': 'Tensor'}],
+    torch.Tensor.zero_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.sub: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.sub_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.subtract: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.subtract: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.subtract_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.subtract_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.heaviside: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}],
+    torch.Tensor.heaviside_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'values', 'simple_type': 'Tensor'}],
+    torch.Tensor.addmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}],
+    torch.Tensor.addmm_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}],
+    torch.Tensor._addmm_activation: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mat2', 'simple_type': 'Tensor'}],
+    torch.Tensor.sparse_resize_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'sparse_dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dense_dim', 'simple_type': 'int64_t'}],
+    torch.Tensor.sparse_resize_and_clear_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'IntArrayRef'}, {'is_kwarg_only': 'False', 'name': 'sparse_dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dense_dim', 'simple_type': 'int64_t'}],
+    torch.Tensor.sparse_mask: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}],
+    torch.Tensor._sparse_mask_projection: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}],
+    torch.Tensor.to_dense: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor._to_dense: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.sparse_dim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor._dimI: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.dense_dim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor._dimV: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor._nnz: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.coalesce: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.is_coalesced: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor._indices: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor._values: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor._coalesced_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'coalesced', 'simple_type': 'bool'}],
+    torch.Tensor.indices: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.values: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.crow_indices: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.col_indices: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.ccol_indices: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.row_indices: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.unbind: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.unbind: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch.Tensor.to_sparse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'sparse_dim', 'simple_type': 'int64_t'}],
+    torch.Tensor.to_sparse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor._to_sparse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'sparse_dim', 'simple_type': 'int64_t'}],
+    torch.Tensor._to_sparse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.to_sparse_csr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor._to_sparse_csr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.to_sparse_csc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor._to_sparse_csc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.to_sparse_bsr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'blocksize', 'simple_type': 'IntArrayRef', 'size': 2}],
+    torch.Tensor._to_sparse_bsr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'blocksize', 'simple_type': 'IntArrayRef', 'size': 2}],
+    torch.Tensor.to_sparse_bsc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'blocksize', 'simple_type': 'IntArrayRef', 'size': 2}],
+    torch.Tensor._to_sparse_bsc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'blocksize', 'simple_type': 'IntArrayRef', 'size': 2}],
+    torch.Tensor.to_mkldnn: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.dequantize: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.q_scale: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.q_zero_point: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.q_per_channel_scales: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.q_per_channel_zero_points: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.q_per_channel_axis: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.int_repr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.qscheme: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor._autocast_to_reduced_precision: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'cuda_enabled', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'cpu_enabled', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'cuda_dtype', 'simple_type': 'ScalarType'}, {'is_kwarg_only': 'False', 'name': 'cpu_dtype', 'simple_type': 'ScalarType'}],
+    torch.Tensor._autocast_to_full_precision: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'cuda_enabled', 'simple_type': 'bool'}, {'is_kwarg_only': 'False', 'name': 'cpu_enabled', 'simple_type': 'bool'}],
+    torch.Tensor.is_set_to: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor', 'simple_type': 'Tensor'}],
+    torch.Tensor.masked_fill_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}],
+    torch.Tensor.masked_fill_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}],
+    torch.Tensor.masked_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}],
+    torch.Tensor.masked_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}],
+    torch.Tensor.masked_scatter_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}],
+    torch.Tensor.masked_scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}],
+    torch.Tensor.view: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'SymIntArrayRef'}],
+    torch.Tensor.view: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dtype', 'simple_type': 'ScalarType'}],
+    torch.Tensor.put_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}],
+    torch.Tensor.put: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}],
+    torch.Tensor.index_add_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}],
+    torch.Tensor.index_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}],
+    torch.Tensor.index_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}],
+    torch.Tensor.index_reduce_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'reduce', 'simple_type': 'c10::string_view'}],
+    torch.Tensor.index_reduce: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'source', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'reduce', 'simple_type': 'c10::string_view'}],
+    torch.Tensor.index_fill_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}],
+    torch.Tensor.index_fill_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}],
+    torch.Tensor.index_fill_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}],
+    torch.Tensor.index_fill_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}],
+    torch.Tensor.index_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}],
+    torch.Tensor.index_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}],
+    torch.Tensor.index_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}],
+    torch.Tensor.index_fill: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Tensor'}],
+    torch.Tensor.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}],
+    torch.Tensor.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}],
+    torch.Tensor.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'reduce', 'simple_type': 'c10::string_view'}],
+    torch.Tensor.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'True', 'name': 'reduce', 'simple_type': 'c10::string_view'}],
+    torch.Tensor.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}],
+    torch.Tensor.scatter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}],
+    torch.Tensor.scatter_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}],
+    torch.Tensor.scatter_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}],
+    torch.Tensor.scatter_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}],
+    torch.Tensor.scatter_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'value', 'simple_type': 'Scalar'}],
+    torch.Tensor.scatter_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}],
+    torch.Tensor.scatter_add: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}],
+    torch.Tensor.scatter_add_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}],
+    torch.Tensor.scatter_reduce: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'reduce', 'simple_type': 'c10::string_view'}],
+    torch.Tensor.scatter_reduce_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'src', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'reduce', 'simple_type': 'c10::string_view'}],
+    torch.Tensor.eq_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.eq_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.bitwise_and: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.bitwise_and: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.bitwise_and_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.bitwise_and_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.__and__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.__and__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.__iand__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.__iand__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.bitwise_or: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.bitwise_or: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.bitwise_or_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.bitwise_or_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.__or__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.__or__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.__ior__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.__ior__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.bitwise_xor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.bitwise_xor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.bitwise_xor_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.bitwise_xor_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.__xor__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.__xor__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.__ixor__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.__ixor__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.__lshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.__lshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.__ilshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.__ilshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.bitwise_left_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.bitwise_left_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.bitwise_left_shift_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.bitwise_left_shift_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.__rshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.__rshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.__irshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.__irshift__: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.bitwise_right_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.bitwise_right_shift: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.bitwise_right_shift_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.bitwise_right_shift_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.tril_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.triu_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.digamma_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.lerp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Scalar'}],
+    torch.Tensor.lerp_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}],
+    torch.Tensor.addbmm_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch2', 'simple_type': 'Tensor'}],
+    torch.Tensor.addbmm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'batch2', 'simple_type': 'Tensor'}],
+    torch.Tensor.random_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'from', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'to', 'simple_type': 'int64_t?'}],
+    torch.Tensor.random_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'to', 'simple_type': 'int64_t'}],
+    torch.Tensor.random_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.uniform_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.cauchy_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.log_normal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.exponential_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.geometric_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'double'}],
+    torch.Tensor.diag: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.cross: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.triu: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.tril: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.trace: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.ne: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.ne: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.ne_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.ne_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.not_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.not_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.not_equal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.not_equal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.eq: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.eq: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.ge: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.ge: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.ge_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.ge_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.greater_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.greater_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.greater_equal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.greater_equal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.le: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.le: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.le_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.le_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.less_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.less_equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.less_equal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.less_equal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.gt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.gt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.gt_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.gt_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.greater: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.greater: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.greater_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.greater_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.lt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.lt: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.lt_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.lt_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.less: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.less: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.less_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.less_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.take: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}],
+    torch.Tensor.take_along_dim: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'indices', 'simple_type': 'Tensor'}],
+    torch.Tensor.index_select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}],
+    torch.Tensor.index_select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}],
+    torch.Tensor.masked_select: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'mask', 'simple_type': 'Tensor'}],
+    torch.Tensor.nonzero_static: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'size', 'simple_type': 'SymInt'}],
+    torch.Tensor.argwhere: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.gather: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}],
+    torch.Tensor.gather: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}, {'is_kwarg_only': 'False', 'name': 'index', 'simple_type': 'Tensor'}],
+    torch.Tensor.addcmul: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'Tensor'}],
+    torch.Tensor.addcmul_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'Tensor'}],
+    torch.Tensor.addcdiv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'Tensor'}],
+    torch.Tensor.addcdiv_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor1', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'tensor2', 'simple_type': 'Tensor'}],
+    torch.Tensor.triangular_solve: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'A', 'simple_type': 'Tensor'}],
+    torch.Tensor.svd: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.swapaxes: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'axis0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'axis1', 'simple_type': 'int64_t'}],
+    torch.Tensor.swapaxes_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'axis0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'axis1', 'simple_type': 'int64_t'}],
+    torch.Tensor.swapdims: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'int64_t'}],
+    torch.Tensor.swapdims_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim0', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'dim1', 'simple_type': 'int64_t'}],
+    torch.Tensor.cholesky: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.cholesky_solve: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}],
+    torch.Tensor.cholesky_inverse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.qr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.geqrf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.orgqr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}],
+    torch.Tensor.ormqr: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input2', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'input3', 'simple_type': 'Tensor'}],
+    torch.Tensor.lu_solve: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'LU_data', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'LU_pivots', 'simple_type': 'Tensor'}],
+    torch.Tensor.multinomial: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'num_samples', 'simple_type': 'SymInt'}],
+    torch.Tensor.lgamma_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.lgamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.digamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.polygamma: [{'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.polygamma_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'n', 'simple_type': 'int64_t'}],
+    torch.Tensor.erfinv: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.erfinv_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.i0: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.i0_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.sign: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.sign_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.signbit: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.dist: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.atan2_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.atan2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.arctan2: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.arctan2_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.lerp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Scalar'}],
+    torch.Tensor.lerp: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'end', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'weight', 'simple_type': 'Tensor'}],
+    torch.Tensor.histc: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.histogram: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'bins', 'simple_type': 'Tensor'}],
+    torch.Tensor.histogram: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.fmod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.fmod: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.fmod_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.fmod_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.hypot: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.hypot_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.igamma: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.igamma_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.igammac: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.igammac_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.nextafter: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.nextafter_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.remainder: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.remainder: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.remainder_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Scalar'}],
+    torch.Tensor.remainder_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.fmin: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.fmax: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.maximum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.minimum: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.quantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'Tensor'}],
+    torch.Tensor.quantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'double'}],
+    torch.Tensor.nanquantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'Tensor'}],
+    torch.Tensor.nanquantile: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'q', 'simple_type': 'double'}],
+    torch.Tensor.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'stable', 'simple_type': 'bool?'}],
+    torch.Tensor.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch.Tensor.sort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'stable', 'simple_type': 'bool?'}, {'is_kwarg_only': 'True', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch.Tensor.msort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.argsort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.argsort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'True', 'name': 'stable', 'simple_type': 'bool'}],
+    torch.Tensor.argsort: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'Dimname'}],
+    torch.Tensor.topk: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'k', 'simple_type': 'SymInt'}],
+    torch.Tensor.renorm: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'maxnorm', 'simple_type': 'Scalar'}],
+    torch.Tensor.renorm_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'p', 'simple_type': 'Scalar'}, {'is_kwarg_only': 'False', 'name': 'dim', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'maxnorm', 'simple_type': 'Scalar'}],
+    torch.Tensor.unfold: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'dimension', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'size', 'simple_type': 'int64_t'}, {'is_kwarg_only': 'False', 'name': 'step', 'simple_type': 'int64_t'}],
+    torch.Tensor.equal: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}],
+    torch.Tensor.pow: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Scalar'}],
+    torch.Tensor.pow_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Scalar'}],
+    torch.Tensor.pow_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}],
+    torch.Tensor.float_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}],
+    torch.Tensor.float_power: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Scalar'}],
+    torch.Tensor.float_power_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Scalar'}],
+    torch.Tensor.float_power_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'exponent', 'simple_type': 'Tensor'}],
+    torch.Tensor.normal_: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.isfinite: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.isinf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.record_stream: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 's', 'simple_type': 'Stream'}],
+    torch.Tensor.isposinf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.isneginf: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.det: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.slogdet: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.logdet: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.inverse: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
+    torch.Tensor.inner: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'other', 'simple_type': 'Tensor'}],
+    torch.Tensor.outer: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec2', 'simple_type': 'Tensor'}],
+    torch.Tensor.ger: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'vec2', 'simple_type': 'Tensor'}],
+    torch.Tensor.to_padded_tensor: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}, {'is_kwarg_only': 'False', 'name': 'padding', 'simple_type': 'double'}],
+}
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/hop_db.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/hop_db.py
new file mode 100644
index 0000000000000000000000000000000000000000..50462859019c4e6812a6f6b12a3625bf8b2b3834
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/hop_db.py
@@ -0,0 +1,432 @@
+# mypy: ignore-errors
+
+import functools
+import unittest
+
+import torch
+from functorch.experimental.control_flow import map
+from torch.nn.attention.flex_attention import _create_empty_block_mask, flex_attention
+from torch.testing import make_tensor
+from torch.testing._internal.common_device_type import onlyCUDA
+from torch.testing._internal.common_dtype import all_types_and, custom_types
+from torch.testing._internal.opinfo.core import DecorateInfo, OpInfo, SampleInput
+from torch._higher_order_ops.invoke_subgraph import mark_compile_region
+from torch._higher_order_ops import InvokeQuant, invoke_quant_packed
+
+
+def sample_inputs_map(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = functools.partial(
+        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+    yield SampleInput(
+        [make_arg(2, 2, 2, low=0.1, high=2), make_arg(2, 2, 2, low=0.1, high=2)],
+        args=(make_arg(1, low=0.1, high=2), make_arg(1, low=0.1, high=2)),
+    )
+
+
+def inner_f(x, y0, y1):
+    return [x[0].cos().add_(1.0) * y0, (x[1] + y1.sin()).cos_().view(x[1].size())]
+
+
+def simple_map(xs, y0, y1):
+    def f(x, y0, y1):
+        return inner_f(x, y0, y1)
+
+    return map(f, xs, y0, y1)
+
+
+def nested_map(xs, y0, y1):
+    def f1(xx, y0, y1):
+        def f2(x, y0, y1):
+            return inner_f(x, y0, y1)
+
+        return map(f2, xx, y0, y1)
+
+    return map(f1, xs, y0, y1)
+
+
+def triple_nested_map(xs, y0, y1):
+    def f0(xs, y0, y1):
+        def f1(xx, y0, y1):
+            def f2(x, y0, y1):
+                return inner_f(x, y0, y1)
+
+            return map(f2, xx, y0, y1)
+
+        return map(f1, xs, y0, y1)
+
+    return map(f0, xs, y0, y1)
+
+
+# PLEASE DON'T ADD ANYTHING NEW TO THIS LIST,
+# and do add an OpInfo for your HOP.
+# The OpInfo lets us do automated testing for the HOP to check that
+# your HOP will work correctly with PyTorch!
+#
+# Your new HOP may fail some automated testing. That's OK. If you don't
+# care about certain features (like torch.export), it's fine to xfail those
+# failing tests. It is less fine to xfail a more critical check (like checking
+# if torch.compile works with your HOP, or if your HOP has a docstring).
+# If you don't know if a test is fine to xfail, please ask. (A minimal
+# illustrative OpInfo sketch follows the allowlist below.)
+#
+# There are legitimate reasons why something cannot be added to this list
+# (e.g. it uses executorch which is not in PyTorch). If that's the case then
+# please leave a comment.
+FIXME_hop_that_doesnt_have_opinfo_test_allowlist = [
+    "custom_function_call",
+    "autograd_function_apply",
+    "run_and_save_rng_state",
+    "run_with_rng_state",
+    "graphsafe_run_with_rng_state",
+    "out_dtype",
+    "trace_wrapped",
+    'tag_activation_checkpoint',
+    'executorch_call_delegate',
+    'wrap',
+    'wrap_with_set_grad_enabled',
+    'auto_functionalized_v2',
+    'associative_scan',
+    'flat_apply',  # is WIP, doesn't pass any of the tests yet
+    'wrap_with_autocast',
+    'wrap_activation_checkpoint',
+    'run_const_graph',
+    'auto_functionalized',
+    "map",  # T183144629
+    "map_impl",
+    "with_effects",
+    "strict_mode",
+    "_export_tracepoint",
+    "call_torchbind",
+    "triton_kernel_wrapper_mutation",
+    "triton_kernel_wrapper_functional",
+    "hints_wrapper",
+    "dynamo_bypassing_wrapper",  # TODO(soulitzer)
+    "foreach_map",
+    "aoti_call_delegate",
+]
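+
+# Illustrative sketch only (hypothetical names: "my_hop", simple_my_hop,
+# sample_inputs_my_hop): roughly what a minimal OpInfo entry for a new HOP could
+# look like, with one export test xfailed via DecorateInfo as described above.
+# Real entries live in `hop_db` below.
+#
+#     OpInfo(
+#         name="my_hop",
+#         variant_test_name="simple",
+#         op=simple_my_hop,
+#         sample_inputs_func=sample_inputs_my_hop,
+#         dtypes=all_types_and(torch.bool, torch.half),
+#         supports_out=False,
+#         skips=(
+#             DecorateInfo(unittest.expectedFailure, "TestHOP", "test_aot_export"),
+#         ),
+#     )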
+
+torch.library.define(
+    "testlib::mutating_custom_op",
+    "(Tensor(a!) x, Tensor(b!) z) -> (Tensor, Tensor, Tensor)",
+    tags=torch.Tag.pt2_compliant_tag,
+)
+
+
+@torch.library.impl("testlib::mutating_custom_op", "cpu")
+def foo_impl_cpu(x, z):
+    x.add_(5)
+    z.add_(5)
+    return x, z, x + z
+
+
+@torch.library.impl("testlib::mutating_custom_op", "cuda")
+def foo_impl_cuda(x, z):
+    x.add_(5)
+    z.add_(5)
+    return x, z, x + z
+
+
+@torch.library.register_fake("testlib::mutating_custom_op")
+def foo_impl_abstract(x, z):
+    return x, z, x + z
+
+
+def sample_inputs_cond(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = functools.partial(
+        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+    yield SampleInput(make_arg(2, 2, 2, low=0.1, high=2))
+
+
+def simple_cond(x):
+    return torch.cond(x.sum() > 2, lambda x: (x.cos(),), lambda x: (x.sin(),), [x])
+
+
+def sample_inputs_invoke_subgraph(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = functools.partial(
+        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+    yield SampleInput(make_arg(2, 2, 2, low=0.1, high=2))
+
+
+@mark_compile_region
+def fn_for_invoke_subgraph(x):
+    return torch.sin(x)
+
+def simple_invoke_subgraph(x):
+    return fn_for_invoke_subgraph(x)
+
+
+def sample_inputs_auto_functionalize(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = functools.partial(
+        make_tensor, device=device, dtype=dtype, requires_grad=False
+    )
+    yield SampleInput(
+        make_arg(2, 2, 2, low=0.1, high=2), make_arg(2, 2, 2, low=0.1, high=2)
+    )
+
+
+def simple_auto_functionalize(x, z):
+    return torch.ops.testlib.mutating_custom_op(x, z)
+
+
+def sample_inputs_flex_attention(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = functools.partial(
+        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+
+    def score_mod(score, b, h, m, n):
+        return score + h
+
+    q, k, v = (make_arg(2, 2, 128, 8, low=0.1, high=2) for _ in range(3))
+    block_mask = _create_empty_block_mask(q, k)
+    yield SampleInput(q, k, v, score_mod, block_mask)
+
+
+def sample_inputs_while_loop(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = functools.partial(
+        make_tensor, device=device, dtype=dtype, requires_grad=False
+    )
+    yield SampleInput(
+        torch.tensor(3),
+        make_arg(2, 3, 4, low=0.1, high=2),
+    )
+
+
+def simple_while_loop(iter_t, x):
+    def cond_fn(iter_t, x):
+        return iter_t > 0
+
+    def body_fn(iter_t, x):
+        return iter_t - 1, x.cos()
+
+    return torch._higher_order_ops.while_loop(cond_fn, body_fn, (iter_t, x))
+
+
+def sample_inputs_scan(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = functools.partial(
+        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+    yield SampleInput(
+        make_arg(2, 2, low=0.1, high=2),
+        make_arg(2, 2, 2, low=0.1, high=2),
+    )
+
+
+def simple_scan(init, xs):
+
+    def combine_fn(carry, x):
+        result = carry @ x + x
+        return result, carry.clone()
+
+    return torch._higher_order_ops.scan(combine_fn, init, xs)
+
+
+quant_tracer = InvokeQuant()
+
+
+def simple_invoke_quant(x):
+    def fn(x, y):
+        return (torch.sin(x) * y,)
+
+    return quant_tracer(fn, x, x)[0] * 2.
+
+
+def simple_invoke_quant_packed(x):
+    def fn(x):
+        return (torch.sin(x),)
+
+    return invoke_quant_packed(fn, x)[0] * 2.
+
+
+
+hop_db = [
+    OpInfo(
+        name="scan",
+        variant_test_name="simple",
+        op=simple_scan,
+        sample_inputs_func=sample_inputs_scan,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        check_batched_forward_grad=False,
+        check_inplace_batched_forward_grad=False,
+        supports_autograd=False,
+        # "torch.compile with aot_autograd does not currently support double backward."
+        supports_gradgrad=False,
+    ),
+    OpInfo(
+        name="invoke_subgraph",
+        variant_test_name="simple",
+        op=simple_invoke_subgraph,
+        sample_inputs_func=sample_inputs_invoke_subgraph,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        check_batched_forward_grad=False,
+        check_inplace_batched_forward_grad=False,
+        supports_autograd=True,
+        # "torch.compile with aot_autograd does not currently support double backward."
+        supports_gradgrad=False,
+    ),
+    OpInfo(
+        name="map",
+        variant_test_name="simple",
+        op=simple_map,
+        sample_inputs_func=sample_inputs_map,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        check_batched_forward_grad=False,
+        check_inplace_batched_forward_grad=False,
+    ),
+    OpInfo(
+        name="map",
+        variant_test_name="nested",
+        op=nested_map,
+        sample_inputs_func=sample_inputs_map,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        check_batched_forward_grad=False,
+        check_inplace_batched_forward_grad=False,
+    ),
+    OpInfo(
+        name="map",
+        variant_test_name="triple_nested",
+        op=triple_nested_map,
+        sample_inputs_func=sample_inputs_map,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        check_batched_forward_grad=False,
+        check_inplace_batched_forward_grad=False,
+    ),
+    OpInfo(
+        name="cond",
+        variant_test_name="simple",
+        op=simple_cond,
+        sample_inputs_func=sample_inputs_cond,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        check_batched_forward_grad=False,
+        check_inplace_batched_forward_grad=False,
+        supports_autograd=True,
+        # "torch.compile with aot_autograd does not currently support double backward."
+        supports_gradgrad=False,
+    ),
+    OpInfo(
+        name="invoke_quant",
+        variant_test_name="simple",
+        op=simple_invoke_quant,
+        sample_inputs_func=sample_inputs_invoke_subgraph,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        check_batched_forward_grad=False,
+        check_inplace_batched_forward_grad=False,
+        supports_autograd=True,
+        # "torch.compile with aot_autograd does not currently support double backward."
+        skips=(
+            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_aot_export"),
+            DecorateInfo(
+                unittest.expectedFailure, "TestHOP", "test_pre_dispatch_export"
+            ),
+            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_serialize_export"),
+            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_retrace_export"),
+        ),
+        # "torch.compile with aot_autograd does not currently support double backward."
+        supports_gradgrad=False,
+    ),
+    OpInfo(
+        name="invoke_quant_packed",
+        variant_test_name="simple",
+        op=simple_invoke_quant_packed,
+        sample_inputs_func=sample_inputs_invoke_subgraph,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        check_batched_forward_grad=False,
+        check_inplace_batched_forward_grad=False,
+        supports_autograd=True,
+        # "torch.compile with aot_autograd does not currently support double backward."
+        supports_gradgrad=False,
+    ),
+    OpInfo(
+        name="while_loop",
+        variant_test_name="simple",
+        op=simple_while_loop,
+        sample_inputs_func=sample_inputs_while_loop,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        check_batched_forward_grad=False,
+        check_inplace_batched_forward_grad=False,
+        supports_autograd=False,
+    ),
+    OpInfo(
+        name="auto_functionalize",
+        variant_test_name="simple",
+        op=simple_auto_functionalize,
+        sample_inputs_func=sample_inputs_auto_functionalize,
+        dtypes=all_types_and(torch.bool, torch.half),
+        supports_out=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        check_batched_forward_grad=False,
+        check_inplace_batched_forward_grad=False,
+        supports_autograd=False,
+    ),
+    OpInfo(
+        name="flex_attention",
+        variant_test_name="simple",
+        op=flex_attention,
+        sample_inputs_func=sample_inputs_flex_attention,
+        dtypes=custom_types(torch.float16, torch.float32),
+        supports_out=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        check_batched_forward_grad=False,
+        check_inplace_batched_forward_grad=False,
+        skips=(
+            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_aot_export"),
+            DecorateInfo(
+                unittest.expectedFailure, "TestHOP", "test_pre_dispatch_export"
+            ),
+            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_serialize_export"),
+            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_retrace_export"),
+        ),
+        decorators=[onlyCUDA],
+    ),
+    OpInfo(
+        name="flex_attention_backward",
+        variant_test_name="simple",
+        op=flex_attention,
+        sample_inputs_func=sample_inputs_flex_attention,
+        dtypes=custom_types(torch.float16, torch.float32),
+        supports_out=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        check_batched_forward_grad=False,
+        check_inplace_batched_forward_grad=False,
+        skips=(
+            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_aot_export"),
+            DecorateInfo(
+                unittest.expectedFailure, "TestHOP", "test_pre_dispatch_export"
+            ),
+            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_serialize_export"),
+            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_retrace_export"),
+        ),
+        decorators=[onlyCUDA],
+    ),
+]
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/hypothesis_utils.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/hypothesis_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f02ef4c9e04b0026dec01c54b1d616d0764d854a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/hypothesis_utils.py
@@ -0,0 +1,367 @@
+# mypy: ignore-errors
+
+from collections import defaultdict
+from collections.abc import Iterable
+import numpy as np
+import torch
+
+import hypothesis
+from functools import reduce
+from hypothesis import assume
+from hypothesis import settings
+from hypothesis import strategies as st
+from hypothesis.extra import numpy as stnp
+from hypothesis.strategies import SearchStrategy
+
+from torch.testing._internal.common_quantized import _calculate_dynamic_qparams, _calculate_dynamic_per_channel_qparams
+
+# Setup for the hypothesis tests.
+# The tuples are (torch_quantized_dtype, zero_point_enforce), where the last
+# element is the enforced zero_point. If it is None, any zero_point within the
+# range of the data type is OK.
+
+# Tuple with all quantized data types.
+_ALL_QINT_TYPES = (
+    torch.quint8,
+    torch.qint8,
+    torch.qint32,
+)
+
+# Enforced zero point for every quantized data type.
+# If None, any zero_point within the range of the data type is OK.
+_ENFORCED_ZERO_POINT = defaultdict(lambda: None, {
+    torch.quint8: None,
+    torch.qint8: None,
+    torch.qint32: 0
+})
+
+def _get_valid_min_max(qparams):
+    scale, zero_point, _quantized_type = qparams
+    adjustment = 1 + torch.finfo(torch.float).eps
+    _long_type_info = torch.iinfo(torch.long)
+    long_min, long_max = _long_type_info.min / adjustment, _long_type_info.max / adjustment
+    # make sure intermediate results are within the range of long
+    min_value = max((long_min - zero_point) * scale, (long_min / scale + zero_point))
+    max_value = min((long_max - zero_point) * scale, (long_max / scale + zero_point))
+    return np.float32(min_value), np.float32(max_value)
+
+# This wrapper wraps `st.floats` and checks the version of `hypothesis`; if it
+# is too old, it removes the `width` parameter (which was introduced in 3.67.0).
+def _floats_wrapper(*args, **kwargs):
+    if 'width' in kwargs and hypothesis.version.__version_info__ < (3, 67, 0):
+        # As long as nan, inf, min, max are not specified, reimplement the width
+        # parameter for older versions of hypothesis.
+        no_nan_and_inf = (
+            (('allow_nan' in kwargs and not kwargs['allow_nan']) or
+             'allow_nan' not in kwargs) and
+            (('allow_infinity' in kwargs and not kwargs['allow_infinity']) or
+             'allow_infinity' not in kwargs))
+        min_and_max_not_specified = (
+            len(args) == 0 and
+            'min_value' not in kwargs and
+            'max_value' not in kwargs
+        )
+        if no_nan_and_inf and min_and_max_not_specified:
+            if kwargs['width'] == 16:
+                kwargs['min_value'] = torch.finfo(torch.float16).min
+                kwargs['max_value'] = torch.finfo(torch.float16).max
+            elif kwargs['width'] == 32:
+                kwargs['min_value'] = torch.finfo(torch.float32).min
+                kwargs['max_value'] = torch.finfo(torch.float32).max
+            elif kwargs['width'] == 64:
+                kwargs['min_value'] = torch.finfo(torch.float64).min
+                kwargs['max_value'] = torch.finfo(torch.float64).max
+        kwargs.pop('width')
+    return st.floats(*args, **kwargs)
+
+def floats(*args, **kwargs):
+    if 'width' not in kwargs:
+        kwargs['width'] = 32
+    return _floats_wrapper(*args, **kwargs)
+
+"""Hypothesis filter to avoid overflows with quantized tensors.
+
+Args:
+    tensor: Tensor of floats to filter
+    qparams: Quantization parameters as returned by the `qparams`.
+
+Returns:
+    True
+
+Raises:
+    hypothesis.UnsatisfiedAssumption
+
+Note: This filter is slow. Use it only when filtering of the test cases is
+      absolutely necessary!
+"""
+def assume_not_overflowing(tensor, qparams):
+    min_value, max_value = _get_valid_min_max(qparams)
+    assume(tensor.min() >= min_value)
+    assume(tensor.max() <= max_value)
+    return True
+
+"""Strategy for generating the quantization parameters.
+
+Args:
+    dtypes: quantized data types to sample from.
+    scale_min / scale_max: Min and max scales. If None, set to 1e-3 / 1e3.
+    zero_point_min / zero_point_max: Min and max for the zero point. If None,
+        set to the minimum and maximum of the quantized data type.
+        Note: The min and max are only valid if the zero_point is not enforced
+              by the data type itself.
+
+Generates:
+    scale: Sampled scale.
+    zero_point: Sampled zero point.
+    quantized_type: Sampled quantized type.
+"""
+@st.composite
+def qparams(draw, dtypes=None, scale_min=None, scale_max=None,
+            zero_point_min=None, zero_point_max=None):
+    if dtypes is None:
+        dtypes = _ALL_QINT_TYPES
+    if not isinstance(dtypes, (list, tuple)):
+        dtypes = (dtypes,)
+    quantized_type = draw(st.sampled_from(dtypes))
+
+    _type_info = torch.iinfo(quantized_type)
+    qmin, qmax = _type_info.min, _type_info.max
+
+    # TODO: Maybe embed the enforced zero_point in the `torch.iinfo`.
+    _zp_enforced = _ENFORCED_ZERO_POINT[quantized_type]
+    if _zp_enforced is not None:
+        zero_point = _zp_enforced
+    else:
+        _zp_min = qmin if zero_point_min is None else zero_point_min
+        _zp_max = qmax if zero_point_max is None else zero_point_max
+        zero_point = draw(st.integers(min_value=_zp_min, max_value=_zp_max))
+
+    if scale_min is None:
+        scale_min = torch.finfo(torch.float).eps
+    if scale_max is None:
+        scale_max = torch.finfo(torch.float).max
+    scale = draw(floats(min_value=scale_min, max_value=scale_max, width=32))
+
+    return scale, zero_point, quantized_type
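+
+# Illustrative usage sketch (the test below is hypothetical): `qparams` is a
+# hypothesis strategy, so it is typically passed to `@given` or drawn inside
+# another composite strategy.
+#
+#     @given(qp=qparams(dtypes=torch.qint8, scale_min=1e-3, scale_max=1e3))
+#     def test_something(self, qp):
+#         scale, zero_point, torch_type = qp
+#         ...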
+
+"""Strategy to create different shapes.
+Args:
+    min_dims / max_dims: minimum and maximum rank.
+    min_side / max_side: minimum and maximum dimensions per rank.
+
+Generates:
+    Possible shapes for a tensor, constrained to the rank and dimensionality.
+
+Example:
+    # Generates 3D and 4D tensors.
+    @given(Q = qtensor(shapes=array_shapes(min_dims=3, max_dims=4)))
+    def some_test(self, Q): ...
+"""
+@st.composite
+def array_shapes(draw, min_dims=1, max_dims=None, min_side=1, max_side=None, max_numel=None):
+    """Return a strategy for array shapes (tuples of int >= 1)."""
+    assert min_dims < 32
+    if max_dims is None:
+        max_dims = min(min_dims + 2, 32)
+    assert max_dims < 32
+    if max_side is None:
+        max_side = min_side + 5
+    candidate = st.lists(st.integers(min_side, max_side), min_size=min_dims, max_size=max_dims)
+    if max_numel is not None:
+        candidate = candidate.filter(lambda x: reduce(int.__mul__, x, 1) <= max_numel)
+    return draw(candidate.map(tuple))
+
+
+"""Strategy for generating test cases for tensors.
+The resulting tensor is in float32 format.
+
+Args:
+    shapes: Shapes under test for the tensor. Could be either a hypothesis
+            strategy, or an iterable of different shapes to sample from.
+    elements: Elements to generate from for the returned data type.
+              If None, the strategy resolves to float within range [-1e6, 1e6].
+    qparams: Instance of the qparams strategy. This is used to filter the tensor
+             so that overflow does not happen.
+
+Generates:
+    X: Tensor of type float32. Note that NaN and +/-inf are not included.
+    qparams: (If `qparams` arg is set) Quantization parameters for X.
+        The returned parameters are `(scale, zero_point, quantization_type)`.
+        (If `qparams` arg is None), returns None.
+"""
+@st.composite
+def tensor(draw, shapes=None, elements=None, qparams=None, dtype=np.float32):
+    if isinstance(shapes, SearchStrategy):
+        _shape = draw(shapes)
+    else:
+        _shape = draw(st.sampled_from(shapes))
+    if qparams is None:
+        if elements is None:
+            elements = floats(-1e6, 1e6, allow_nan=False, width=32)
+        X = draw(stnp.arrays(dtype=dtype, elements=elements, shape=_shape))
+        assume(not (np.isnan(X).any() or np.isinf(X).any()))
+        return X, None
+    qparams = draw(qparams)
+    if elements is None:
+        min_value, max_value = _get_valid_min_max(qparams)
+        elements = floats(min_value, max_value, allow_infinity=False,
+                          allow_nan=False, width=32)
+    X = draw(stnp.arrays(dtype=dtype, elements=elements, shape=_shape))
+    # Recompute the scale and zero_points according to the X statistics.
+    scale, zp = _calculate_dynamic_qparams(X, qparams[2])
+    enforced_zp = _ENFORCED_ZERO_POINT.get(qparams[2], None)
+    if enforced_zp is not None:
+        zp = enforced_zp
+    return X, (scale, zp, qparams[2])
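+
+# Illustrative usage sketch (the test below is hypothetical): combining the
+# `tensor` and `qparams` strategies; the draw yields float32 data plus
+# (scale, zero_point, torch_type).
+#
+#     @given(X=tensor(shapes=((3, 4), (2, 3, 4)), qparams=qparams()))
+#     def test_quantize_roundtrip(self, X):
+#         data, (scale, zero_point, torch_type) = X
+#         qx = torch.quantize_per_tensor(
+#             torch.from_numpy(data), scale, zero_point, torch_type)
+#         ...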
+
+@st.composite
+def per_channel_tensor(draw, shapes=None, elements=None, qparams=None):
+    if isinstance(shapes, SearchStrategy):
+        _shape = draw(shapes)
+    else:
+        _shape = draw(st.sampled_from(shapes))
+    if qparams is None:
+        if elements is None:
+            elements = floats(-1e6, 1e6, allow_nan=False, width=32)
+        X = draw(stnp.arrays(dtype=np.float32, elements=elements, shape=_shape))
+        assume(not (np.isnan(X).any() or np.isinf(X).any()))
+        return X, None
+    qparams = draw(qparams)
+    if elements is None:
+        min_value, max_value = _get_valid_min_max(qparams)
+        elements = floats(min_value, max_value, allow_infinity=False,
+                          allow_nan=False, width=32)
+    X = draw(stnp.arrays(dtype=np.float32, elements=elements, shape=_shape))
+    # Recompute the scale and zero_points according to the X statistics.
+    scale, zp = _calculate_dynamic_per_channel_qparams(X, qparams[2])
+    enforced_zp = _ENFORCED_ZERO_POINT.get(qparams[2], None)
+    if enforced_zp is not None:
+        zp = enforced_zp
+    # Permute to model quantization along an axis
+    axis = int(np.random.randint(0, X.ndim, 1))
+    permute_axes = np.arange(X.ndim)
+    permute_axes[0] = axis
+    permute_axes[axis] = 0
+    X = np.transpose(X, permute_axes)
+
+    return X, (scale, zp, axis, qparams[2])
+
+"""Strategy for generating test cases for tensors used in Conv.
+The resulting tensors are in float32 format.
+
+Args:
+    spatial_dim: Spatial dimension for feature maps. If given as an iterable, one
+                 is picked randomly from the pool to be the spatial dimension.
+    batch_size_range: Range to generate `batch_size`.
+                      Must be tuple of `(min, max)`.
+    input_channels_per_group_range:
+        Range to generate `input_channels_per_group`.
+        Must be tuple of `(min, max)`.
+    output_channels_per_group_range:
+        Range to generate `output_channels_per_group`.
+        Must be tuple of `(min, max)`.
+    feature_map_range: Range to generate feature map size for each spatial_dim.
+                       Must be tuple of `(min, max)`.
+    kernel_range: Range to generate kernel size for each spatial_dim. Must be
+                  tuple of `(min, max)`.
+    max_groups: Maximum number of groups to generate.
+    elements: Elements to generate from for the returned data type.
+              If None, the strategy resolves to float within range [-1e6, 1e6].
+    qparams: Strategy for quantization parameters for X, w, and b.
+             Could be either a single strategy (used for all) or a list of
+             three strategies for X, w, b.
+Generates:
+    (X, W, b, g): Tensors of type `float32` of the following drawn shapes:
+        X: `(batch_size, input_channels, H, W)`
+        W: `(output_channels, input_channels_per_group) + kernel_shape`
+        b: `(output_channels,)`
+        groups: Number of groups the input is divided into
+Note: X, W, b are tuples of (Tensor, qparams), where qparams could be either
+      None or (scale, zero_point, quantized_type)
+
+
+Example:
+    @given(tensor_conv(
+        spatial_dim=2,
+        batch_size_range=(1, 3),
+        input_channels_per_group_range=(1, 7),
+        output_channels_per_group_range=(1, 7),
+        feature_map_range=(6, 12),
+        kernel_range=(3, 5),
+        max_groups=4,
+        elements=st.floats(-1.0, 1.0),
+        qparams=qparams()
+    ))
+"""
+@st.composite
+def tensor_conv(
+    draw, spatial_dim=2, batch_size_range=(1, 4),
+    input_channels_per_group_range=(3, 7),
+    output_channels_per_group_range=(3, 7), feature_map_range=(6, 12),
+    kernel_range=(3, 7), max_groups=1, can_be_transposed=False,
+    elements=None, qparams=None
+):
+
+    # Resolve the minibatch, in_channels, out_channels, iH/iW, kH/kW
+    batch_size = draw(st.integers(*batch_size_range))
+    input_channels_per_group = draw(
+        st.integers(*input_channels_per_group_range))
+    output_channels_per_group = draw(
+        st.integers(*output_channels_per_group_range))
+    groups = draw(st.integers(1, max_groups))
+    input_channels = input_channels_per_group * groups
+    output_channels = output_channels_per_group * groups
+
+    if isinstance(spatial_dim, Iterable):
+        spatial_dim = draw(st.sampled_from(spatial_dim))
+
+    feature_map_shape = [draw(st.integers(*feature_map_range)) for _ in range(spatial_dim)]
+
+    kernels = [draw(st.integers(*kernel_range)) for _ in range(spatial_dim)]
+
+    tr = False
+    weight_shape = (output_channels, input_channels_per_group) + tuple(kernels)
+    bias_shape = output_channels
+    if can_be_transposed:
+        tr = draw(st.booleans())
+        if tr:
+            weight_shape = (input_channels, output_channels_per_group) + tuple(kernels)
+            bias_shape = output_channels
+
+    # Resolve the tensors
+    if qparams is not None:
+        if isinstance(qparams, (list, tuple)):
+            assert len(qparams) == 3, "Need 3 qparams for X, w, b"
+        else:
+            qparams = [qparams] * 3
+
+    X = draw(tensor(shapes=(
+        (batch_size, input_channels) + tuple(feature_map_shape),),
+        elements=elements, qparams=qparams[0]))
+    W = draw(tensor(shapes=(weight_shape,), elements=elements,
+                    qparams=qparams[1]))
+    b = draw(tensor(shapes=(bias_shape,), elements=elements,
+                    qparams=qparams[2]))
+
+    return X, W, b, groups, tr
+
+# We set the deadline in the currently loaded profile.
+# Creating (and loading) a separate profile overrides any settings the user
+# already specified.
+hypothesis_version = hypothesis.version.__version_info__
+current_settings = settings._profiles[settings._current_profile].__dict__
+current_settings['deadline'] = None
+if hypothesis_version >= (3, 16, 0) and hypothesis_version < (5, 0, 0):
+    current_settings['timeout'] = hypothesis.unlimited
+def assert_deadline_disabled():
+    if hypothesis_version < (3, 27, 0):
+        import warnings
+        warning_message = (
+            "Your version of hypothesis is outdated. "
+            "To avoid `DeadlineExceeded` errors, please update. "
+            f"Current hypothesis version: {hypothesis.__version__}"
+        )
+        warnings.warn(warning_message)
+    else:
+        assert settings().deadline is None
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/inductor_utils.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/inductor_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..91a4aaa5728a824e93e37f6bb58a99597a302cdd
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/inductor_utils.py
@@ -0,0 +1,343 @@
+# mypy: ignore-errors
+
+import logging
+import torch
+import re
+import unittest
+import functools
+import contextlib
+import os
+from subprocess import CalledProcessError
+import sys
+import torch._inductor.async_compile  # noqa: F401 required to warm up AsyncCompile pools
+from torch.fx.experimental.proxy_tensor import make_fx
+from torch._inductor.graph import GraphLowering
+from torch._inductor.compile_fx import shape_env_from_inputs
+from torch._inductor.codecache import CppCodeCache
+from torch._inductor.custom_graph_pass import CustomGraphModulePass
+from torch._inductor.codegen.common import (
+    get_custom_backend_pass_for_device,
+    get_scheduling_for_device,
+    get_wrapper_codegen_for_device,
+    init_backend_registration,
+    register_backend_for_device
+)
+from torch._inductor.codegen.wrapper import PythonWrapperCodegen
+from torch._inductor.utils import get_gpu_shared_memory, is_big_gpu
+from torch._inductor.utils import GPU_TYPES, get_gpu_type, is_gpu
+from torch.utils._helion import has_helion
+from torch.utils._triton import has_triton
+from torch.testing._internal.common_device_type import (
+    get_desired_device_type_test_bases,
+)
+from torch.testing._internal.common_utils import (
+    LazyVal,
+    IS_FBCODE,
+)
+from torch.testing._internal.common_utils import (
+    TestCase,
+    IS_CI,
+    IS_WINDOWS,
+)
+
+log: logging.Logger = logging.getLogger(__name__)
+
+def test_cpu():
+    try:
+        CppCodeCache.load("")
+        return not IS_FBCODE
+    except (
+        CalledProcessError,
+        OSError,
+        torch._inductor.exc.InvalidCxxCompiler,
+        torch._inductor.exc.CppCompileError,
+    ):
+        return False
+
+HAS_CPU = LazyVal(test_cpu)
+
+HAS_TRITON = has_triton()
+
+HAS_HELION = has_helion()
+
+if HAS_TRITON:
+    import triton
+    TRITON_HAS_CPU = "cpu" in triton.backends.backends
+else:
+    TRITON_HAS_CPU = False
+
+
+HAS_CUDA = torch.cuda.is_available() and HAS_TRITON
+
+HAS_XPU = torch.xpu.is_available() and HAS_TRITON
+
+HAS_MPS = torch.mps.is_available()
+
+HAS_GPU = HAS_CUDA or HAS_XPU
+
+GPU_TYPE = get_gpu_type()
+
+HAS_MULTIGPU = any(
+    getattr(torch, gpu).is_available() and getattr(torch, gpu).device_count() >= 2
+    for gpu in GPU_TYPES
+)
+
+_desired_test_bases = get_desired_device_type_test_bases(allow_xpu=True)
+RUN_GPU = (
+    HAS_GPU
+    and any(is_gpu(getattr(x, "device_type", "")) for x in _desired_test_bases)
+)
+
+RUN_CPU = (
+    HAS_CPU
+    and any(getattr(x, "device_type", "") == "cpu" for x in _desired_test_bases)
+)
+
+def _check_has_dynamic_shape(
+    self: TestCase,
+    code,
+):
+    for_loop_found = False
+    has_dynamic = False
+    lines = code.split("\n")
+    for line in lines:
+        if "for(" in line:
+            for_loop_found = True
+            if re.search(r";.*ks.*;", line) is not None:
+                has_dynamic = True
+                break
+    self.assertTrue(
+        has_dynamic, msg=f"Failed to find dynamic for loop variable\n{code}"
+    )
+    self.assertTrue(for_loop_found, f"Failed to find for loop\n{code}")
+
+
+def skipDeviceIf(cond, msg, *, device):
+    if cond:
+        def decorate_fn(fn):
+            @functools.wraps(fn)
+            def inner(self, *args, **kwargs):
+                if not hasattr(self, "device"):
+                    warn_msg = "Expect the test class to have attribute device but not found. "
+                    if hasattr(self, "device_type"):
+                        warn_msg += "Consider using the skip device decorators in common_device_type.py"
+                    log.warning(warn_msg)
+                if self.device == device:
+                    raise unittest.SkipTest(msg)
+                return fn(self, *args, **kwargs)
+            return inner
+    else:
+        def decorate_fn(fn):
+            return fn
+
+    return decorate_fn
+
+def skip_windows_ci(name: str, file: str) -> None:
+    if IS_WINDOWS and IS_CI:
+        module = os.path.basename(file).removesuffix(".py")
+        sys.stderr.write(
+            f"Windows CI does not have necessary dependencies for {module} tests yet\n"
+        )
+        if name == "__main__":
+            sys.exit(0)
+        raise unittest.SkipTest("requires sympy/functorch/filelock")
+
+# TODO: Remove HAS_MPS condition when `HAS_GPU` includes HAS_MPS
+requires_gpu = functools.partial(unittest.skipIf, not (HAS_GPU or HAS_MPS), "requires gpu")
+requires_triton = functools.partial(unittest.skipIf, not HAS_TRITON, "requires triton")
+requires_helion = functools.partial(unittest.skipIf, not HAS_HELION, "requires helion")
+
+def requires_cuda_with_enough_memory(min_mem_required):
+    def inner(fn):
+        if not torch.cuda.is_available() or torch.cuda.get_device_properties().total_memory < min_mem_required:
+            return unittest.skip(f"Requires a CUDA device with at least {min_mem_required / 1e9:.3f}GB of memory")(fn)
+        else:
+            return fn
+
+    return inner
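+
+# Illustrative usage sketch (hypothetical test): the decorator factory above
+# skips the test unless a CUDA device with the requested memory is available.
+#
+#     @requires_cuda_with_enough_memory(2e10)
+#     def test_needs_big_gpu(self):
+#         ...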
+
+skipCUDAIf = functools.partial(skipDeviceIf, device="cuda")
+skipXPUIf = functools.partial(skipDeviceIf, device="xpu")
+skipCPUIf = functools.partial(skipDeviceIf, device="cpu")
+
+IS_A100 = LazyVal(
+    lambda: HAS_CUDA
+    and get_gpu_shared_memory() == 166912
+)
+
+IS_H100 = LazyVal(
+    lambda: HAS_CUDA
+    and get_gpu_shared_memory() == 232448
+)
+
+IS_BIG_GPU = LazyVal(lambda: HAS_CUDA and is_big_gpu())
+
+def dummy_graph() -> GraphLowering:
+    """
+    Create a graph. This is useful for unit testing code which accesses
+    V.graph.sizevars.
+    """
+    example_inputs = [torch.randn(10) for _ in range(2)]
+    gm = make_fx(torch.add, tracing_mode="fake")(*example_inputs)
+    shape_env = shape_env_from_inputs(example_inputs)
+    graph = GraphLowering(
+        gm,
+        shape_env=shape_env,
+    )
+
+    return graph
+
+def maybe_skip_size_asserts(op):
+    """
+    For certain ops, the meta and eager implementations return different
+    strides. This causes size/stride asserts to fail. Skip adding those
+    asserts for now.
+    """
+    if (
+        op.aten_name
+        in (
+            "fft_hfftn",
+            "fft_hfft",
+            "fft_hfft2",
+            "fft_ihfftn",
+            "fft_fft",
+            "fft_fft2",
+            "fft_fftn",
+            "fft_ifft",
+            "fft_ifft2",
+            "fft_ifftn",
+            "fft_irfft",
+            "fft_irfft2",
+            "fft_irfftn",
+            "fft_ihfft",
+            "fft_ihfft2",
+            "fft_rfft",
+            "fft_rfft2",
+            "fft_rfftn",
+            "linalg_eig",
+            "linalg_eigvals",
+        )
+        and "TORCHINDUCTOR_SIZE_ASSERTS" not in os.environ
+    ):
+        return torch._inductor.config.patch(size_asserts=False)
+    else:
+        return contextlib.nullcontext()
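+
+# Illustrative usage sketch (hypothetical call site; `fn` and `args` stand in for
+# whatever the test exercises): wrap the invocation so that size/stride asserts
+# are disabled only for the ops listed above.
+#
+#     with maybe_skip_size_asserts(op):
+#         torch.compile(fn)(*args)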
+
+def get_func_call() -> str:
+    return "void inductor_entry_impl(" if torch._inductor.config.cpp_wrapper else "def call("
+
+def get_kernel_launch() -> str:
+    return "call_triton_" if torch._inductor.config.cpp_wrapper else ".run("
+
+def clone_preserve_strides_offset(x, device=None):
+    if not isinstance(x, torch.Tensor):
+        return x
+    buffer = torch.as_strided(
+        x, (x.untyped_storage().size() // x.element_size(),), (1,), 0
+    )
+    if not device:
+        buffer = buffer.clone()
+    else:
+        buffer = buffer.to(device, copy=True)
+    out = torch.as_strided(buffer, x.size(), x.stride(), x.storage_offset())
+    return out
+
+# define the e4m3/e5m2 constants
+E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
+E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max
+E4M3FNUZ_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max
+E5M2FNUZ_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max
+
+FP16_MAX_POS: float = torch.finfo(torch.float16).max
+EPS: float = 1e-12
+
+Tensor = torch.Tensor
+
+def _to_fp8_saturated(x: Tensor, float8_dtype: torch.dtype) -> Tensor:
+    # The default behavior in PyTorch for casting to `float8_e4m3fn`
+    # and `e5m2` is to not saturate. In this context, we should saturate.
+    # A common case where we want to saturate is when the history of a
+    # tensor has a maximum value of `amax1`, and the current amax value
+    # is `amax2`, where `amax1 < amax2`. This is common when using delayed
+    # scaling.
+    if float8_dtype == torch.float8_e4m3fn:
+        x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
+    elif float8_dtype == torch.float8_e5m2:
+        x = x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS)
+    elif float8_dtype == torch.float8_e4m3fnuz:
+        x = x.clamp(min=-1 * E4M3FNUZ_MAX_POS, max=E4M3FNUZ_MAX_POS)
+    elif float8_dtype == torch.float8_e5m2fnuz:
+        x = x.clamp(min=-1 * E5M2FNUZ_MAX_POS, max=E5M2FNUZ_MAX_POS)
+    else:
+        raise TypeError(f"Unsupported float8_dtype: {float8_dtype}")
+    return x.to(float8_dtype)
+
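+# Hedged illustration of the saturation behaviour described above: values
+# beyond the e4m3 range are clamped to +/-E4M3_MAX_POS before the cast rather
+# than being left to overflow.
+def _example_to_fp8_saturated():
+    x = torch.tensor([1e6, -1e6, 0.5])
+    y = _to_fp8_saturated(x, torch.float8_e4m3fn)
+    assert y.float().abs().max().item() <= E4M3_MAX_POS
+    return y
+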
+@torch.no_grad()
+def _amax_to_scale(
+    amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
+) -> torch.Tensor:
+    # Make the scale dtype fp32 for accuracy
+    amax = amax.float()
+    if float8_dtype == torch.float8_e4m3fn:
+        res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
+    else:  # e5m2
+        res = E5M2_MAX_POS / torch.clamp(amax, min=EPS)
+
+    # Ensure that the scale is representable in float16; this helps when
+    # amax is small. We assume we don't need to care about this for
+    # float32/bfloat16.
+    if orig_dtype is torch.float16:
+        res = torch.clamp(res, max=FP16_MAX_POS)
+    return res
+
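+# Rough numeric sketch of the clamping above: an amax of zero is clamped to
+# EPS so the scale stays finite, and for fp16 inputs the scale is further
+# capped at FP16_MAX_POS so it remains representable.
+def _example_amax_to_scale():
+    scale = _amax_to_scale(torch.tensor(0.0), torch.float8_e4m3fn, torch.float16)
+    assert scale.item() <= FP16_MAX_POS
+    return scale
+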
+def _quantize_tensorwise(x: Tensor, float8_dtype: torch.dtype):
+    amax = torch.max(torch.abs(x))
+    scale = _amax_to_scale(amax, float8_dtype, x.dtype)
+    x_fp8 = _to_fp8_saturated(x * scale, float8_dtype)
+    inverse_scale = scale.reciprocal()
+    return x_fp8, inverse_scale
+
+def _quantize_rowwise(x: Tensor, float8_dtype: torch.dtype):
+    amax = torch.max(torch.abs(x), dim=1, keepdim=True).values
+    scale = _amax_to_scale(amax, float8_dtype, x.dtype)
+    x_fp8 = _to_fp8_saturated(x * scale, float8_dtype)
+    inverse_scale = scale.reciprocal()
+    return x_fp8, inverse_scale
+
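+# Minimal round-trip sketch (illustrative only): quantize to e4m3 with a
+# tensor-wise scale, then dequantize with the returned inverse scale; the
+# reconstruction should match the input up to fp8 precision.
+def _example_quantize_roundtrip():
+    x = torch.randn(4, 8)
+    x_fp8, inv_scale = _quantize_tensorwise(x, torch.float8_e4m3fn)
+    x_deq = x_fp8.float() * inv_scale
+    assert torch.allclose(x_deq, x, atol=0.25, rtol=0.1)
+    return x_deq
+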
+@contextlib.contextmanager
+def patch_inductor_backend(
+    device: str,
+    python_wrapper_codegen: PythonWrapperCodegen = None,
+    custom_pass: CustomGraphModulePass = None
+):
+    """
+    Patch the inductor backend for a specific device.
+    """
+    # Make sure the backend is already registered
+    init_backend_registration()
+
+    # Get the original registration parameters
+    original_scheduling = get_scheduling_for_device(device)
+    original_python_wrapper = get_wrapper_codegen_for_device(device, False)
+    original_cpp_wrapper = get_wrapper_codegen_for_device(device, True)
+    original_custom_pass = get_custom_backend_pass_for_device(device)
+
+    try:
+        # Register modified backend for the device
+        register_backend_for_device(
+            device,
+            original_scheduling,
+            python_wrapper_codegen if python_wrapper_codegen is not None else original_python_wrapper,
+            original_cpp_wrapper,
+            custom_pass if custom_pass is not None else original_custom_pass
+        )
+        yield
+    finally:
+        # Restore the original backend
+        register_backend_for_device(
+            device,
+            original_scheduling,
+            original_python_wrapper,
+            original_cpp_wrapper,
+            original_custom_pass
+        )
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/jit_metaprogramming_utils.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/jit_metaprogramming_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3dbb95f4ba9c3c430e27a677fc3850aee2b3549
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/jit_metaprogramming_utils.py
@@ -0,0 +1,725 @@
+# mypy: ignore-errors
+
+# Torch
+from torch.jit.annotations import BroadcastingList2, BroadcastingList3  # noqa: F401
+import torch.nn.functional as F
+import torch
+import torch.cuda
+import torch.jit
+import torch.jit._logging
+import torch.jit.frontend
+from torch.testing._internal.common_nn import module_tests, get_new_module_tests
+from torch.testing._internal.common_utils import is_iterable_of_tensors, noncontiguous_like
+
+import collections
+from copy import deepcopy
+from typing import Any, Union
+import math  # noqa: F401
+
+# Testing utils
+from torch import inf
+
+assert torch.get_default_dtype() == torch.float32
+
+L = 20
+M = 10
+S = 5
+
+
+def unpack_variables(args):
+    if isinstance(args, tuple):
+        return tuple(unpack_variables(elem) for elem in args)
+    else:
+        return args
+
+class dont_convert(tuple):
+    __slots__ = ()
+
+non_differentiable = collections.namedtuple('non_differentiable', ['tensor'])
+
+def create_input(call_args, requires_grad=True, non_contiguous=False, call_kwargs=None, dtype=torch.float, device=None):
+    if not isinstance(call_args, tuple):
+        call_args = (call_args,)
+
+    def map_arg(arg):
+        def maybe_non_contig(tensor):
+            if not non_contiguous or tensor.numel() < 2:
+                return tensor.clone()
+
+            return noncontiguous_like(tensor)
+
+        def conjugate(tensor):
+            return tensor.conj()
+
+        if isinstance(arg, (torch.Size, dont_convert)):
+            return arg
+        elif isinstance(arg, tuple) and len(arg) == 0:
+            var = conjugate(torch.randn((), dtype=dtype, device=device))
+            var.requires_grad = requires_grad
+            return var
+        elif isinstance(arg, tuple) and not isinstance(arg[0], torch.Tensor):
+            return conjugate(maybe_non_contig(torch.randn(*arg, dtype=dtype, device=device))).requires_grad_(requires_grad)
+        # double check casting
+        elif isinstance(arg, non_differentiable):
+            if isinstance(arg.tensor, torch.Tensor):
+                return conjugate(maybe_non_contig(arg.tensor.to(device=device)))
+            return conjugate(maybe_non_contig(arg.tensor.to(device=device)))
+        elif isinstance(arg, torch.Tensor):
+            if arg.is_complex() != dtype.is_complex:
+                raise RuntimeError("User provided tensor is real for a test that runs with complex dtype, ",
+                                   "which is not supported for now")
+            # NOTE: We do clone() after detach() here because we need to be able to change size/storage of v afterwards
+            v = conjugate(maybe_non_contig(arg)).detach().to(device=device).clone()
+            v.requires_grad = requires_grad and (v.is_floating_point() or v.is_complex())
+            return v
+        elif callable(arg):
+            return map_arg(arg(dtype=dtype, device=device))
+        else:
+            return arg
+    args_out = tuple(map_arg(arg) for arg in call_args)
+    kwargs_out = {k: map_arg(v) for k, v in call_kwargs.items()} if call_kwargs else {}
+    return args_out, kwargs_out
+
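+# Illustrative sketch of how create_input expands a spec (using the S constant
+# defined above): a shape tuple becomes a requires-grad random tensor, a
+# non_differentiable wrapper stays non-differentiable, and plain scalars pass
+# through unchanged.
+def _example_create_input():
+    (a, b, c), _ = create_input(((S, S), non_differentiable(torch.ones(S)), 2.5))
+    assert a.shape == (S, S) and a.requires_grad
+    assert not b.requires_grad
+    assert c == 2.5
+    return a, b, c
+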
+# NB: JIT script tests for all nn functional interfaces; script mode does
+# not support in-place operations yet, so no inplace operation tests are added.
+# Deprecated functions have been removed.
+#
+# (
+#   method name,
+#   input size/constructing fn,
+#   args (tuple represents shape of a tensor arg),
+#   test variant name (will be used as the test name suffix,
+#       'inplace' skips grad tests),                         // optional
+#   (True, nonfusible_nodes, fusible_nodes) for autodiff     // optional
+#   fn to determine if test should be skipped,               // optional
+#   fn mapping output to part that should be gradcheck'ed,   // optional
+#   kwargs for function,                                     // optional
+# )
+def get_nn_functional_tests():
+    nn_functional_tests = [
+        ('conv1d', (S, S, S), ((S, S, S),)),
+        ('conv2d', (S, S, S, S), ((S, S, S, S),)),
+        ('conv3d', (S, S, S, S, S), ((S, S, S, S, S),)),
+        ('conv_transpose1d', (S, S, S), ((S, S, S),)),
+        ('conv_transpose2d', (S, S, S, S), ((S, S, S, S),)),
+        ('conv_transpose3d', (S, S, S, S, S), ((S, S, S, S, S),)),
+        ('conv_tbc', (S, S, S), ((S, S, S), (S,), 2)),
+        ('avg_pool1d', (S, S, S), (3,)),
+        ('avg_pool2d', (S, S, S, S), (3,), '', (True,)),
+        ('avg_pool3d', (S, S, S, S, S), (3,)),
+        ('fractional_max_pool2d', (S, S, S, S), (3, [2, 3],)),
+        ('max_pool1d', (S, S, S), (2, 1)),
+        ('max_pool1d', (S, S, S), (2, 1, 1, 1, False, True), 'with_indices'),
+        ('max_pool2d', (S, S, S, S), (2, 1), '', (True, 'aten::max_pool2d_with_indices')),
+        ('max_pool2d', (S, S, S, S), (2, 1, 1, 1, False, True), 'with_indices', (True, 'aten::max_pool2d_with_indices')),
+        ('max_pool3d', (S, S, S, S, S), (2, 1)),
+        ('max_unpool1d', torch.tensor([[[2., 4]]]), (torch.tensor([[[1, 3]]]), 2, 2, 0)),
+        ('max_unpool2d', torch.tensor([[[[2., 4]]]]), (torch.tensor([[[[1, 3]]]]), 2, 2, 0)),
+        ('max_unpool3d', torch.tensor([[[[[2., 4]]]]]), (torch.tensor([[[[[1, 3]]]]]), 2, 2, 0)),
+        ('lp_pool1d', (S, S, S), (2., 3, 2,)),
+        ('lp_pool2d', (S, S, S, S), (2., 3, 2,)),
+        ('lp_pool3d', (S, S, S, S, S), (2., 3, 2,)),
+        ('adaptive_max_pool1d', (S, S, S), (5,)),
+        ('adaptive_max_pool2d', (S, S, S, S), ([5, 7],)),
+        ('adaptive_max_pool3d', (S, S, S, S, S), ([3, 2, 2],)),
+        ('adaptive_avg_pool1d', (S, S, S), (5,), '', (True,)),
+        ('adaptive_avg_pool2d', (S, S, S, S), ([5, 7],), '', (True,)),
+        ('adaptive_avg_pool3d', (S, S, S, S, S), ([3, 2, 2],), '', (True,)),
+        ('dropout', (S, S, S), (0.5,), '', (True, 'aten::native_dropout')),
+        ('alpha_dropout', (S, S, S), (0.5,)),
+        ('dropout2d', (S, S, S), (0.5,)),
+        ('dropout2d', (S, S, S, S), (0.5,), 'batched'),
+        ('dropout3d', (S, S, S, S), (0.5,)),
+        ('dropout3d', (S, S, S, S, S), (0.5,), 'batched'),
+        ('feature_alpha_dropout', (S, S, S), (0.5,)),
+        ('threshold', (S, S, S), (0.1, 2.), '', (True,)),
+        ('threshold', (S, S, S), (0.1, 2., True), 'inplace'),
+        ('relu', (S, S, S), (), '', (True,)),
+        ('relu', (S, S, S), (), 'inplace'),
+        ('glu', (S - 1, S - 1, S - 1), (),),
+        ('hardtanh', (S, S, S), (-0.5, 0.5), '', (True,)),
+        ('hardtanh', (S, S, S), (-0.5, 0.5, True), 'inplace'),
+        ('relu6', (S, S, S), (), '', (True,)),
+        ('relu6', (S, S, S), (True), 'inplace'),
+        ('elu', (S, S, S), (0.9,),),
+        ('elu', (S, S, S), (0.9, True), 'inplace'),
+        ('selu', (S, S, S), (),),
+        ('selu', (S, S, S), (True), 'inplace'),
+        ('celu', (S, S, S), (0.9,),),
+        ('celu', (S, S, S), (0.9, True), 'inplace'),
+        ('leaky_relu', (S, S, S), (0.02,), '', (True,)),
+        ('leaky_relu', (S, S, S), (0.02,), 'inplace'),
+        ('rrelu', (S, S), (0.1, 0.3, False),),
+        ('rrelu', (S, S), (0.1, 0.3, False, True), 'inplace'),
+        ('hardshrink', (S, S, S), (0.4,), '', (True,)),
+        ('tanhshrink', (S, S, S), (),),
+        ('softsign', (S, S, S), (),),
+        ('softplus', (S, S, S), (), '', (True,)),
+        ('softmin', (S, S, S), (0,),),
+        ('softmax', (S, S, S), (0,), '', (True,)),
+        ('softmax', (S, S, S), (0, 3, torch.double), 'with_all_args', (True,)),
+        ('tanh', (S, S, S), (), '', (True,)),
+        ('sigmoid', (S, S, S), (), '', (True,)),
+        ('silu', (S, S, S), (), '', (True,)),
+        ('log_softmax', (S, S, S), (0,), '', (True,)),
+        ('linear', (S, S), ((M, S),), '', (True, ['aten::linear'])),
+        ('linear', (S, S), ((M, S), (M,)), 'addmm', (True, ['aten::linear'])),
+        ('bilinear', (S, S, S), ((S, S, M), torch.zeros(M, S, M),),),
+        ('embedding', torch.tensor([[1, 2, 4, 5], [4, 3, 2, 5]]), (torch.rand(6, 3), ), '', (True,)),
+        ('embedding_bag', torch.tensor([1, 2, 4, 2]), (torch.rand(5, 3), torch.tensor([0, 4]),),),
+        ('batch_norm', (S, S),
+            (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), None, None, True, ),
+            'training', (True, 'aten::_batch_norm_impl_index')),
+        ('batch_norm', (0, S, S, S),
+            (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+             non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ),
+            'size_zero', (True, 'aten::_batch_norm_impl_index')),
+        ('batch_norm', (0, S, S, S),
+            (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+             non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ),
+            'size_zero_inference', (True, 'aten::_batch_norm_impl_index')),
+        ('batch_norm', (S, S),
+            (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+             non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ),
+            'with_weight_and_bias_training', (True, 'aten::_batch_norm_impl_index')),
+        ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+                                None, non_differentiable(torch.ones(S)), True, ),
+            'with_only_bias_training', (True, 'aten::_batch_norm_impl_index')),
+        ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+                                non_differentiable(torch.randn(S)), None, True, ),
+            'with_only_weight_training', (True, 'aten::_batch_norm_impl_index')),
+        ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+                                None, None, False, ),
+            'inference', (True, 'aten::_batch_norm_impl_index')),
+        ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+                                non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), False, ),
+            'with_weight_and_bias_inference', (True, 'aten::_batch_norm_impl_index')),
+        ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+                                None, non_differentiable(torch.ones(S)), False, ),
+            'with_only_bias_inference', (True, 'aten::_batch_norm_impl_index')),
+        ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+                                non_differentiable(torch.randn(S)), None, False, ),
+            'with_only_weight_inference', (True, 'aten::_batch_norm_impl_index')),
+        ('instance_norm', (S, S, S), (non_differentiable(torch.zeros(S)), non_differentiable(torch.ones(S))),),
+        ('layer_norm', (S, S, S, S), ([5],), '',
+         (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])),
+        ('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)),), 'with_only_weight',
+         (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])),
+        ('layer_norm', (S, S, S, S), ([5], None, non_differentiable(torch.rand(S)),), 'with_only_bias',
+         (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])),
+        ('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)),
+                                      non_differentiable(torch.rand(S))), 'with_weight_and_bias',
+         (False, ['aten::contiguous', 'aten::_batch_norm_impl_index', 'aten::addcmul'])),
+        ('group_norm', (S, S, S), (1, torch.rand(5),),),
+        ('local_response_norm', (S, S, S), (2, ),),
+        ('nll_loss', F.log_softmax(torch.randn(3, 5), dim=0), (torch.tensor([1, 0, 4]),), '',),
+        ('poisson_nll_loss', torch.rand(S, 2), (torch.rand(S, 2),),),
+        ('poisson_nll_loss', torch.rand(S, 2), (torch.rand(S, 2), True, True), 'full'),
+        ('kl_div', F.log_softmax(torch.randn(S, 10), 1), (F.softmax(torch.randn(S, 10), 1),),),
+        ('cross_entropy', (3, S), (torch.randint(S, (3,), dtype=torch.int64),),),
+        ('binary_cross_entropy_with_logits', (3,), (torch.empty(3).random_(2), ),),
+        ('smooth_l1_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
+        ('huber_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
+        ('l1_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
+        ('mse_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
+        ('smooth_l1_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
+        ('huber_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
+        ('l1_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
+        ('mse_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
+        ('margin_ranking_loss', (S,), ((S,), (S,)),),
+        ('hinge_embedding_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
+        ('soft_margin_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
+        ('multilabel_soft_margin_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
+        ('cosine_embedding_loss', (S, S), ((S, S), non_differentiable(torch.rand(S,))),),
+        ('pixel_shuffle', (1, 9, 4, 4), (3,),),
+        ('pixel_unshuffle', (1, 1, 12, 12), (3,),),
+        ('affine_grid', (S, 2, 3), (torch.Size([S, 1, 7, 7]),),),
+        ('pad', (3, 3, 4, 2), ([1, 1],),),
+        ('pairwise_distance', (S, S), ((S, S),),),
+        ('pdist', (S, S), (),),
+        ('cosine_similarity', (S, S), ((S, S),),),
+        ('triplet_margin_loss', (S, S), ((S, S), (S, S)),),
+        ('normalize', (S, S, S), (),),
+        ('unfold', (S, S, S, S), ([2, 3]),),
+        ('fold', (1, 3 * 2 * 2, 12), ([4, 5], [2, 2]),),
+        ('grid_sample', (S, S, S, S), (non_differentiable(torch.rand(S, S, S, 2)),),),
+        ('gumbel_softmax', (S, S), (2.,), '', (True, ['aten::softmax', 'aten::add', 'aten::div'], ['aten::neg'])),
+        ('gumbel_softmax', (S, S), (2., True,), 'hard', (True, ['aten::softmax', 'aten::add', 'aten::div'], ['aten::neg'])),
+        ('multilabel_margin_loss', torch.tensor([[0.2, -0.2, 0.07]]), (torch.tensor([[0, 0, 1]]),),),
+        ('multi_margin_loss', (S, S), (non_differentiable(torch.randint(S, (S, ), dtype=torch.int64)),
+                                       1, 1., non_differentiable(torch.randn(S))),),
+        ('binary_cross_entropy', torch.randn(3, 2).sigmoid(), (non_differentiable(torch.rand(3, 2)),
+                                                               non_differentiable(torch.randn(3, 2))),),
+        ('binary_cross_entropy', torch.randn(3, 2).sigmoid(),
+            (non_differentiable(torch.rand(3, 2)),
+             non_differentiable(torch.randn(3, 2)), None, None, 'mean'), 'size_average'),
+        ('ctc_loss', torch.rand(S, S, S).log_softmax(2).detach().requires_grad_(),
+         (torch.randint(1, S, (S, S), dtype=torch.long), torch.full((S,), S, dtype=torch.long),
+          torch.randint(1, S, (S,), dtype=torch.long))),
+        ('upsample', torch.randn(S, S, M, M), (None, 2.), 'with_scale'),
+        ('upsample', torch.randn(S, S, M, M), (4,), 'with_size'),
+        ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'nearest_4d'),
+        ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'nearest_4d_with_scale'),
+        ('interpolate', torch.randn(S, S, M, M), (4,), 'nearest_4d_with_size'),
+        ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'area_4d'),
+        ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'area_4d_with_scale'),
+        ('interpolate', torch.randn(S, S, M, M), (4,), 'area_4d_with_size'),
+        ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bilinear_4d'),
+        ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bilinear_4d_with_scale'),
+        ('interpolate', torch.randn(S, S, M, M), (4,), 'bilinear_4d_with_size'),
+        ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bicubic_4d'),
+        ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bicubic_4d_with_scale'),
+        ('interpolate', torch.randn(S, S, M, M), (4,), 'bicubic_4d_with_size'),
+        ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'nearest_3d'),
+        ('interpolate', torch.randn(S, M, M), (None, 2.), 'nearest_3d_with_scale'),
+        ('interpolate', torch.randn(S, M, M), (4,), 'nearest_3d_with_size'),
+        ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'area_3d'),
+        ('interpolate', torch.randn(S, M, M), (None, 2.), 'area_3d_with_scale'),
+        ('interpolate', torch.randn(S, M, M), (4,), 'area_3d_with_size'),
+        ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'linear_3d'),
+        ('interpolate', torch.randn(S, M, M), (None, 2.), 'linear_3d_with_scale'),
+        ('interpolate', torch.randn(S, M, M), (4,), 'linear_3d_with_size'),
+        ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'nearest_5d_with_scale'),
+        ('interpolate', torch.randn(S, M, M, M, M), (4,), 'nearest_5d_with_size'),
+        ('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'area_5d'),
+        ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'area_5d_with_scale'),
+        ('interpolate', torch.randn(S, M, M, M, M), (4,), 'area_5d_with_size'),
+        ('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'trilinear_5d'),
+        ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'trilinear_5d_with_scale'),
+        ('interpolate', torch.randn(S, M, M, M, M), (4,), 'trilinear_5d_with_size'),
+        ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2, None, 'nearest', None, False),
+         'nearest_4d_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, S, M, M), (4, None, 'nearest', None, False),
+         'nearest_4d_with_size_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, S, M, M), (None, 2., 'bilinear', None, False),
+         'bilinear_4d_with_scale_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, S, M, M), (4, None, 'bilinear', None, False),
+         'bilinear_4d_with_size_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, S, M, M), (None, 2., 'bicubic', None, False),
+         'bicubic_4d_with_scale_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, S, M, M), (4, None, 'bicubic', None, False),
+         'bicubic_4d_with_size_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, M, M), (None, 2., 'nearest', None, False),
+         'nearest_3d_with_scale_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, M, M), (4, None, 'nearest', None, False),
+         'nearest_3d_with_size_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, M, M), (None, 2., 'linear', None, False),
+         'linear_3d_with_scale_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, M, M), (4, None, 'linear', None, False),
+         'linear_3d_with_size_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, M, M, M, M), (None, 2., 'nearest', None, False),
+         'nearest_5d_with_scale_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, M, M, M, M), (4, None, 'nearest', None, False),
+         'nearest_5d_with_size_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, M, M, M, M), (None, 2., 'trilinear', None, False),
+         'trilinear_5d_with_scale_not_recompute_scale_factor'),
+        ('interpolate', torch.randn(S, M, M, M, M), (4, None, 'trilinear', None, False),
+         'trilinear_5d_with_size_not_recompute_scale_factor'),
+    ]
+    return nn_functional_tests
+
+script_template = '''
+def the_method({}):
+    return {}
+'''
+
+def value_to_literal(value):
+    if isinstance(value, str):
+        # Quotes string and escapes special characters
+        return ascii(value)
+    if isinstance(value, torch.Tensor):
+        return 'torch.' + str(value)
+    else:
+        return str(value)
+
+def get_call(method_name, func_type, args, kwargs):
+    kwargs_str = ', '.join([k + '=' + value_to_literal(v) for k, v in kwargs.items()])
+    self_arg = args[0]
+    if func_type == 'method':
+        args = args[1:]
+
+    argument_str = ', '.join(args)
+    argument_str += ', ' if len(args) and len(kwargs) else ''
+    argument_str += kwargs_str
+
+    if func_type == 'functional' or func_type == 'function':
+        call = f'torch.{method_name}({argument_str})'
+    elif func_type == 'method':
+        call = f'{self_arg}.{method_name}({argument_str})'
+    elif func_type == 'nn_functional':
+        call = f'torch.nn.functional.{method_name}({argument_str})'
+    else:
+        raise TypeError('Unsupported function type')
+
+    return call
+
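+# Small illustration (hypothetical argument names 'x' and 'y'): the same
+# method name is rendered differently depending on func_type.
+def _example_get_call():
+    assert get_call('add', 'method', ['x', 'y'], {}) == 'x.add(y)'
+    assert get_call('add', 'functional', ['x', 'y'], {}) == 'torch.add(x, y)'
+    assert get_call('relu', 'nn_functional', ['x'], {'inplace': False}) == \
+        'torch.nn.functional.relu(x, inplace=False)'
+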
+def get_constant(x):
+    if x == inf:
+        return 'math.inf'
+    if x == -inf:
+        return '-math.inf'
+    return x
+
+def get_script_args(args):
+    formals: list[str] = []
+    tensors: list[Union[torch.Tensor, list[torch.Tensor]]] = []
+    actuals: list[str] = []
+    for arg in args:
+        if isinstance(arg, torch.Tensor):
+            name = f'i{len(formals)}'
+            formals.append(name)
+            actuals.append(name)
+            tensors.append(arg)
+        elif is_iterable_of_tensors(arg):
+            name = f'i{len(formals)}'
+            formals.append(name + ': List[torch.Tensor]')
+            actuals.append(name)
+            tensors.append(list(arg))
+        elif isinstance(arg, str):
+            actuals.append(f"'{arg}'")
+        else:
+            actuals.append(str(get_constant(arg)))
+    return (formals, tensors, actuals)
+
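+# Sketch of what get_script_args produces: tensors become generated formal
+# names (i0, i1, ...) and are collected, while strings and scalars are inlined
+# as literals in the actual-argument list.
+def _example_get_script_args():
+    formals, tensors, actuals = get_script_args((torch.randn(2), 'mean', 3))
+    assert formals == ['i0'] and len(tensors) == 1
+    assert actuals == ['i0', "'mean'", '3']
+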
+# create a script function from (name, func_type, output_process_fn),
+# and return the compiled function and example inputs
+def gen_script_fn_and_args(method_name, func_type, *args, **kwargs):
+    formals, tensors, actuals = get_script_args(args)
+    call = get_call(method_name, func_type, actuals, kwargs)
+    script = script_template.format(', '.join(formals), call)
+    CU = torch.jit.CompilationUnit(script)
+    return CU.the_method, tensors
+
+# create a script function from (name, func_type),
+# and return a function that takes (args, kwargs) and runs the compiled function
+def create_script_fn(self, method_name, func_type):
+    # function returns tuple containing original output and
+    # filtered output to be used in checking gradients
+    def script_fn(*args, **kwargs):
+        fn, tensors = gen_script_fn_and_args(method_name, func_type, *args, **kwargs)
+        self.assertExportImport(fn.graph, tensors)
+        output = fn(*tensors)
+        # skip type annotate function attributes for now, see: https://github.com/python/mypy/issues/2087
+        script_fn.last_graph = fn.graph_for(*tensors)  # type: ignore[attr-defined]
+        return output
+    return script_fn
+
+class SplitInputs:
+    all_tensors: list[Any]
+    tensor_args: list[Any]
+    nontensor_args: list[Any]
+    arg_types: list[str]
+    tensor_kwargs: dict[str, Any]
+    kwarg_order: list[str]
+    nontensor_kwargs: dict[str, Any]
+    kwarg_types: dict[str, Any]
+
+    @staticmethod
+    def _is_tensor_input(arg):
+        return isinstance(arg, torch.Tensor) or is_iterable_of_tensors(arg)
+
+    def __init__(self, args, kwargs):
+        self.arg_types = ['t' if self._is_tensor_input(arg) else 's' for arg in args]
+        self.kwarg_types = {k: 't' if self._is_tensor_input(v) else 's' for k, v in kwargs.items()}
+        self.tensor_args = [arg for arg in args if self._is_tensor_input(arg)]
+        self.nontensor_args = [arg for arg in args if not self._is_tensor_input(arg)]
+        self.tensor_kwargs = {k: v for k, v in kwargs.items() if self._is_tensor_input(v)}
+        self.nontensor_kwargs = {k: v for k, v in kwargs.items() if not self._is_tensor_input(v)}
+        self.all_tensors = [*self.tensor_args, *[v for k, v in self.tensor_kwargs.items()]]
+        self.kwarg_order = [k for k, v in kwargs.items()]
+
+    def nontensors_match(self, other: 'SplitInputs'):
+        if self.arg_types != other.arg_types:
+            return False
+        if self.kwarg_types != other.kwarg_types:
+            return False
+        if self.kwarg_order != other.kwarg_order:
+            return False
+        if self.nontensor_args != other.nontensor_args:
+            return False
+        if self.nontensor_kwargs != other.nontensor_kwargs:
+            return False
+        return True
+
+# make a new function where all non-tensor arguments in 'args' have been partially
+# applied, and all tensor arguments remain.
+# used to trace functions when some arguments are not tensors
+def partial_apply_nontensors(fn, args, kwargs):
+    inputs = SplitInputs(args, kwargs)
+
+    def new_fn(*tensors_):
+        tensors = iter(tensors_)
+        full_args = [args[i] if s == 's' else next(tensors) for i, s in enumerate(inputs.arg_types)]
+        full_kwargs = {k: kwargs[k] if s == 's' else next(tensors) for k, s in inputs.kwarg_types.items()}
+        return fn(*full_args, **full_kwargs)
+
+    return new_fn, inputs
+
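+# Minimal sketch: only the tensor arguments become positional inputs of the
+# wrapped function, which is what makes the result traceable; the non-tensor
+# dim argument is captured from the original call.
+def _example_partial_apply_nontensors():
+    def f(x, dim, y):
+        return x.sum(dim) + y
+    x, y = torch.randn(3, 4), torch.randn(4)
+    new_fn, inputs = partial_apply_nontensors(f, (x, 0, y), {})
+    assert inputs.arg_types == ['t', 's', 't']
+    assert torch.equal(new_fn(*inputs.all_tensors), f(x, 0, y))
+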
+# create a trace function from input fn
+def create_traced_fn(self, fn, cache_traced_fn=False):
+    def traced_fn(*inputs, **kwargs):
+        # `check_trace` is set to False because check_trace is run with @no_grad
+        # Also, `check_against_reference` already does all the checks
+        # against python function
+        fn_tensors, split_inputs = partial_apply_nontensors(fn, inputs, kwargs)
+        if not cache_traced_fn or not hasattr(traced_fn, 'traced'):
+            traced = torch.jit.trace(fn_tensors, split_inputs.all_tensors, check_trace=False)
+            self.assertExportImport(traced.graph, split_inputs.all_tensors)
+            output = traced(*split_inputs.all_tensors)
+            if cache_traced_fn:
+                traced_fn.traced = traced
+                traced_fn.split_inputs = split_inputs
+        else:
+            # Guard to check that nontensor inputs are the same as during tracing
+            self.assertTrue(traced_fn.split_inputs.nontensors_match(split_inputs))
+            output = traced_fn.traced(*split_inputs.all_tensors)
+            traced = traced_fn.traced
+        # skip type annotate function attributes for now, see: https://github.com/python/mypy/issues/2087
+        traced_fn.last_graph = traced.graph_for(*split_inputs.all_tensors)  # type: ignore[attr-defined]
+        traced_fn.graph = traced.graph  # type: ignore[attr-defined]
+        return output
+    return traced_fn
+
+# known to be failing in script
+EXCLUDE_SCRIPT = {
+    'test_norm_fro_default',
+    'test_norm_fro_cpu',
+    'test_norm_nuc',
+    'test_norm_fro',
+    'test_norm_nuc_batched',
+
+    # aten op has additional cudnn argument
+    'test_nn_unfold',
+
+    # flaky test - TODO fix
+    'test_nn_ctc_loss',
+
+    # unknown builtin op
+    'test_nn_fold',
+
+    # jit doesn't support sparse tensors.
+    'test_to_sparse',
+    'test_to_sparse_dim',
+}
+
+# generates a script function and set of example inputs
+# from a specified test in the format of nn_functional_tests
+def get_nn_functional_compiled_fn_and_inputs(name, self_size, args, variant_name='', *extra_args):
+    test_name = 'test_nn_' + name
+
+    if variant_name != '':
+        test_name = test_name + '_' + variant_name
+
+    self_variable = create_input((self_size,))[0][0]
+
+    # need to record this because methods can change the size (e.g. unsqueeze)
+    args_variable, _kwargs_variable = create_input(args)
+
+    self_tensor = deepcopy(self_variable.data)
+    args_tensor = deepcopy(unpack_variables(args_variable))
+
+    f_args_variable = (self_variable,) + args_variable
+    f_args_tensor = (self_tensor,) + args_tensor  # noqa: F841
+    with torch._jit_internal._disable_emit_hooks():
+        script_fn, inputs = gen_script_fn_and_args(name, "nn_functional", *f_args_variable)
+    return script_fn, inputs
+
+
+
+EXCLUDE_SCRIPT_MODULES = {
+    'test_nn_AdaptiveAvgPool2d_tuple_none',
+    'test_nn_AdaptiveAvgPool3d_tuple_none',
+    'test_nn_AdaptiveMaxPool2d_tuple_none',
+    'test_nn_AdaptiveMaxPool3d_tuple_none',
+
+    # Doesn't use future division, so this is not supported
+    'test_nn_CrossMapLRN2d',
+    # Derivative for aten::_scaled_dot_product_flash_attention_backward is not implemented
+    'test_nn_TransformerDecoderLayer_gelu_activation',
+    'test_nn_TransformerDecoderLayer_relu_activation',
+    'test_nn_TransformerEncoderLayer_gelu_activation',
+    'test_nn_TransformerEncoderLayer_relu_activation',
+    'test_nn_Transformer_multilayer_coder',
+}
+
+script_method_template = '''
+def forward({}):
+    return {}
+'''
+
+def create_script_module(self, nn_module, constructor_args, *args, **kwargs):
+    def script_module(*args, **kwargs):
+        _formals, tensors, actuals = get_script_args(args)
+
+        method_args = ', '.join(['self'] + actuals)
+        call_args_str = ', '.join(actuals)
+        call = f"self.submodule({call_args_str})"
+        script = script_method_template.format(method_args, call)
+
+        submodule_constants = []
+        if kwargs.get('is_constant'):
+            submodule_constants = ['submodule']
+
+        # Create module to use the script method
+        class TheModule(torch.jit.ScriptModule):
+            __constants__ = submodule_constants
+
+            def __init__(self) -> None:
+                super().__init__()
+                self.submodule = nn_module(*constructor_args)
+
+        def make_module(script):
+            module = TheModule()
+            # check __repr__
+            str(module)
+            module.define(script)
+            return module
+
+        module = make_module(script)
+        if self:
+            self.assertExportImportModule(module, tensors)
+            module(*args)
+        # skip type annotate function attributes for now, see: https://github.com/python/mypy/issues/2087
+        create_script_module.last_graph = module.graph  # type: ignore[attr-defined]
+        return module
+    return script_module
+
+def check_alias_annotation(method_name, args, kwargs, *, aten_name, func_type='method'):
+    formals, tensors, actuals = get_script_args(args)
+    call = get_call(method_name, func_type, actuals, kwargs)
+    script = script_template.format(', '.join(formals), call)
+    CU = torch.jit.CompilationUnit(script)
+    # to clean up IR
+    torch._C._jit_pass_inline(CU.the_method.graph)
+    torch._C._jit_pass_constant_propagation(CU.the_method.graph)
+    torch._C._jit_check_alias_annotation(CU.the_method.graph, tuple(tensors), aten_name)
+
+def get_nn_module_name_from_kwargs(**kwargs):
+    if 'module_name' in kwargs:
+        return kwargs['module_name']
+    elif 'fullname' in kwargs:
+        return kwargs['fullname']
+    elif 'constructor' in kwargs:
+        return kwargs['constructor'].__name__
+
+def get_nn_mod_test_name(**kwargs):
+    if 'fullname' in kwargs:
+        test_name = kwargs['fullname']
+    else:
+        test_name = get_nn_module_name_from_kwargs(**kwargs)
+        if 'desc' in kwargs:
+            test_name = f"{test_name}_{kwargs['desc']}"
+    return f'test_nn_{test_name}'
+
+def get_nn_module_class_from_kwargs(**kwargs):
+    name = get_nn_module_name_from_kwargs(**kwargs)
+    index = name.find("_")
+    if index == -1:
+        return name
+    else:
+        return name[0:name.find("_")]
+
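+# Small illustration (hypothetical test names): the module class is taken to
+# be everything before the first underscore of the resolved name.
+def _example_get_nn_module_class():
+    assert get_nn_module_class_from_kwargs(fullname='Conv2d_dilated') == 'Conv2d'
+    assert get_nn_module_class_from_kwargs(module_name='Linear') == 'Linear'
+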
+def try_get_nn_module_compiled_mod_and_inputs(*args, **kwargs):
+    name = get_nn_module_name_from_kwargs(**kwargs)
+
+    if 'desc' in kwargs and 'eval' in kwargs['desc']:
+        # eval() is not supported, so skip these tests
+        return
+
+    test_name = name
+    if 'desc' in kwargs:
+        test_name = f"{test_name}_{kwargs['desc']}"
+    test_name = get_nn_mod_test_name(**kwargs)
+
+    if test_name in EXCLUDE_SCRIPT_MODULES:
+        return
+    if 'constructor' in kwargs:
+        nn_module = kwargs['constructor']
+    else:
+        nn_module = getattr(torch.nn, name)
+
+    if "FunctionalModule" in str(nn_module):
+        return
+
+    if 'constructor_args_fn' in kwargs:
+        constructor_args = kwargs['constructor_args_fn']()
+    else:
+        constructor_args = kwargs.get('constructor_args', ())
+
+    # Set up inputs from tuple of sizes or constructor fn
+    input_dtype = torch.double
+    if 'input_fn' in kwargs:
+        input = kwargs['input_fn']()
+        if isinstance(input, torch.Tensor):
+            input = (input,)
+
+        if all(tensor.is_complex() for tensor in input):
+            input_dtype = torch.cdouble
+    else:
+        input = (kwargs['input_size'],)
+
+    # Extra parameters to forward()
+    if 'extra_args' in kwargs:
+        input = input + kwargs['extra_args']
+
+    if 'target_size' in kwargs:
+        input = input + (kwargs['target_size'],)
+    elif 'target_fn' in kwargs:
+        if torch.is_tensor(input):
+            input = (input,)
+        input = input + (kwargs['target_fn'](),)
+
+    args_variable, _kwargs_variable = create_input(input, dtype=input_dtype)
+    f_args_variable = deepcopy(unpack_variables(args_variable))
+    out_var = deepcopy(f_args_variable)
+
+
+    _args, mod = f_args_variable, create_script_module(
+        None, nn_module, constructor_args, *f_args_variable
+    )(*f_args_variable)
+
+    return mod, out_var
+
+
+def get_all_nn_module_tests():
+    # additional modules test
+    # TODO: delete this list once we make all nn_tests work
+    additional_module_tests = [
+        {
+            'module_name': 'Bilinear',
+            'constructor_args': (S, S, M),
+            'input_size': (S, S),
+            'extra_args': ((S, S),)
+        },
+        {
+            'module_name': 'RNNCell',
+            'constructor_args': (S, S),
+            'input_size': (S, S),
+        },
+        {
+            'module_name': 'LSTMCell',
+            'constructor_args': (S, S),
+            'input_size': (S, S),
+        },
+        {
+            'module_name': 'GRUCell',
+            'constructor_args': (S, S),
+            'input_size': (S, S),
+        },
+        {
+            'module_name': 'MultiheadAttention',
+            'constructor_args': (128, 8),
+            'input_size': (10, 8, 128),
+            'extra_args': (torch.randn(10, 8, 128), torch.randn(10, 8, 128)),
+            'slowTest': True
+        },
+        {
+            'module_name': 'Transformer',
+            'constructor_args': (1, 1, 1, 1, 2),
+            'input_size': (3, 1, 1),
+            'extra_args': (torch.randn(1, 1, 1),),
+            'slowTest': True
+        }
+    ]
+
+    return module_tests + get_new_module_tests() + additional_module_tests
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/jit_utils.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/jit_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4bc0738ec2f3845a51a1b84eb59e583679860def
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/jit_utils.py
@@ -0,0 +1,893 @@
+# mypy: ignore-errors
+
+# Torch
+from torch.autograd import Variable
+from torch.autograd.function import _nested_map
+from torch.jit.annotations import BroadcastingList2, BroadcastingList3  # noqa: F401
+
+from torch.onnx import OperatorExportTypes
+import torch
+import torch.cuda
+import torch.jit
+import torch.jit._logging
+import torch.jit.frontend
+import torch.jit.quantized
+import zipfile
+import functools
+
+# Testing utils
+from torch.testing import FileCheck
+from torch.testing._internal.common_utils import IS_WINDOWS, \
+    freeze_rng_state, enable_profiling_mode_for_profiling_tests, ProfilingMode, TEST_BAILOUTS, \
+    is_iterable_of_tensors
+from torch.testing._internal.common_jit import JitCommonTestCase
+from torch.testing._internal.common_utils import enable_profiling_mode  # noqa: F401
+
+# Standard library
+from contextlib import contextmanager
+from functools import reduce
+from io import StringIO
+from collections import defaultdict
+
+import importlib.util
+import inspect
+import io
+import math
+import os
+import pickle
+import sys
+import tempfile
+import textwrap
+from importlib.abc import Loader
+from typing import Any, Union
+
+RUN_CUDA = torch.cuda.is_available()
+RUN_CUDA_MULTI_GPU = RUN_CUDA and torch.cuda.device_count() > 1
+RUN_CUDA_HALF = RUN_CUDA
+# HIP supports half, no version check necessary
+if torch.cuda.is_available() and not torch.version.hip:
+    CUDA_VERSION = torch._C._cuda_getCompiledVersion()
+    for d in range(torch.cuda.device_count()):
+        major = torch.cuda.get_device_capability(d)[0]
+        if (major < 6):
+            RUN_CUDA_HALF = False
+
+def execWrapper(code, glob, loc):
+    exec(code, glob, loc)
+
+def do_input_map(fn, input):
+    return _nested_map(lambda t: isinstance(t, torch.Tensor), fn)(input)
+
+def clear_class_registry():
+    torch._C._jit_clear_class_registry()
+    torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore()
+    torch.jit._state._clear_class_state()
+
+def get_execution_plan(graph_executor_state):
+    execution_plans = list(graph_executor_state.execution_plans.values())
+    num_plans = len(execution_plans)
+    if num_plans != 1:
+        raise RuntimeError('This test assumes this GraphExecutor should '
+                           f'only have one execution plan, got: {num_plans}')
+    return execution_plans[0]
+
+class _AssertRaisesRegexWithHighlightContext:
+    """
+    A context manager that is useful for checking that error messages highlight
+    the correct part of the source code.
+    """
+
+    def __init__(self, test_case, exception, regex, highlight):
+        self.test_case = test_case
+        self.exception_type = exception
+        self.regex = regex
+        self.highlight = highlight
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, type, value, traceback):
+        with self.test_case.assertRaisesRegex(self.exception_type, self.regex):
+            if type:
+                raise value
+
+        if self.highlight:
+            FileCheck().check_source_highlighted(self.highlight).run(str(value))
+
+        return True
+
+FUSION_GROUP = "prim::TensorExprGroup"
+
+class JitTestCase(JitCommonTestCase):
+    _do_cuda_memory_leak_check = True
+    _restored_warnings = False
+
+    class capture_stdout(list):
+        """
+        Replace sys.stdout with a temporary StringIO
+        """
+        def __enter__(self):
+            self.sys_stdout = sys.stdout
+            self.stringio = StringIO()
+            sys.stdout = self.stringio
+            return self
+
+        def __exit__(self, *args):
+            self.append(str(self.stringio.getvalue()))
+            del self.stringio
+            sys.stdout = self.sys_stdout
+
+    class capture_stderr(list):
+        """
+        Replace sys.stderr with a temporary StringIO
+        """
+        def __enter__(self):
+            self.sys_stderr = sys.stderr
+            self.stringio = StringIO()
+            sys.stderr = self.stringio
+            return self
+
+        def __exit__(self, *args):
+            self.append(str(self.stringio.getvalue()))
+            del self.stringio
+            sys.stderr = self.sys_stderr
+
+    def setHooks(self):
+        torch._C._jit_set_emit_hooks(self.emitModuleHook, self.emitFunctionHook)
+
+    def clearHooks(self):
+        torch._C._jit_set_emit_hooks(None, None)
+
+    def setUp(self):
+        super().setUp()
+        # unittest overrides all warning filters and forces all of them to show up
+        # after we install our own to silence those coming from inside PyTorch.
+        # This will ensure that our filter still takes precedence.
+        if not JitTestCase._restored_warnings:
+            torch.jit.TracerWarning.ignore_lib_warnings()
+            JitTestCase._restored_warnings = True
+        self.setHooks()
+
+    def tearDown(self):
+        super().tearDown()
+        # needs to be cleared because python might be unloaded before
+        # the callback gets destructed
+        self.clearHooks()
+        clear_class_registry()
+
+    def assertAllFused(self, graph, except_for=()):
+
+        # note this helper collects nodes on 'fast path' only
+        # i.e. the true blocks of specialized checks
+        def get_nodes_and_parents_recursively(block, kind, acc):
+            for node in block.nodes():
+                if node.kind() == kind:
+                    acc[block].append(node)
+                elif node.kind() == 'prim::DifferentiableGraph':
+                    get_nodes_and_parents_recursively(node.g('Subgraph'), kind, acc)
+                elif node.kind() == 'prim::If' and (node.inputs().__next__().node().kind() == 'aten::all' or
+                                                    node.inputs().__next__().node().kind() == 'prim::TypeCheck' or
+                                                    node.inputs().__next__().node().kind() == 'prim::RequiresGradCheck'):
+                    get_nodes_and_parents_recursively(node.blocks().__next__(), kind, acc)
+                else:
+                    for inner_block in node.blocks():
+                        get_nodes_and_parents_recursively(inner_block, kind, acc)
+
+        allowed_nodes = {'prim::Constant', FUSION_GROUP, 'prim::BailoutTemplate',
+                         'prim::TupleConstruct', 'prim::If', 'prim::TypeCheck', 'prim::RequiresGradCheck'} | set(except_for)
+
+        fusion_groups : dict[torch._C.Block, list[torch._C.Node]] = defaultdict(list)
+        get_nodes_and_parents_recursively(graph, FUSION_GROUP, fusion_groups)
+        self.assertTrue(len(fusion_groups) == 1, f'got {graph}')
+        (graph, fusion_nodes) = next(iter(fusion_groups.items()))
+        # the block contains one FUSION_GROUP and the rest of the nodes are in `allowed_nodes`
+        self.assertTrue(len(fusion_nodes) == 1, f'got {graph}')
+        self.assertTrue(all(node.kind() in allowed_nodes for node in graph.nodes()),
+                        f'got {graph}')
+
+    def _isHookExceptionOk(self, e):
+        se = str(e)
+        allowed = ("Could not export Python function",
+                   "closures are not exportable")
+        for a in allowed:
+            if a in se:
+                return True
+        return False
+
+    def _compared_saved_loaded(self, m):
+        def extract_files(buffer):
+            # crack open the zip format to get at the main module code
+            archive = zipfile.ZipFile(buffer)
+            # check that we have no duplicate names
+            self.assertEqual(len(set(archive.namelist())), len(archive.namelist()))
+            files = list(filter(lambda x: x.startswith('archive/code/'), archive.namelist()))
+            # unwrap all the code files into strings
+            code_files_str = filter(lambda x: x.endswith('.py'), files)
+            code_files_stream = (archive.open(f) for f in code_files_str)
+            code_files = ("".join([line.decode() for line in file]) for file in code_files_stream)
+
+            # unpickle all the debug files
+            debug_files_str = filter(lambda f: f.endswith('.debug_pkl'), files)
+            debug_files_stream = (archive.open(f) for f in debug_files_str)
+            debug_files = (pickle.load(f) for f in debug_files_stream)
+            return code_files, debug_files
+
+        # disable the hook while we parse code, otherwise we will re-enter the hook
+        with torch._jit_internal._disable_emit_hooks():
+            try:
+                # short-circuit if this is an empty function or module
+                if len(m.code) == 0:
+                    return
+                if isinstance(m, torch._C.ScriptModule):
+                    if len(m._method_names()) == 0:
+                        return
+
+                # save the module to a buffer
+                buffer = io.BytesIO()
+                torch.jit.save(m, buffer)
+                # copy the data in the buffer so we can restore it later. This
+                # is because py2 and py3 have different semantics with zipfile
+                # and it's easier to just work with a fresh copy each time.
+                buffer_copy = buffer.getvalue()
+
+                code_files, _debug_files = extract_files(buffer)
+
+            except RuntimeError as e:
+                if not self._isHookExceptionOk(e):
+                    raise
+                else:
+                    return
+
+            # import the model again (from the copy we made of the original)
+            buffer2 = io.BytesIO(buffer_copy)
+            imported = torch.jit.load(buffer2)
+
+            # save it again
+            saved_module_buffer_2 = io.BytesIO()
+            torch.jit.save(imported, saved_module_buffer_2)
+
+            saved_module_buffer_2.seek(0)
+            code_files_2, _debug_files_2 = extract_files(saved_module_buffer_2)
+
+            for a, b in zip(code_files, code_files_2):
+                self.assertMultiLineEqual(a, b)
+
+            if isinstance(m, torch._C.ScriptModule):
+                self.assertTrue(torch._C._ivalue_tags_match(m, imported._c))
+
+
+    def emitFunctionHook(self, func):
+        # func has invalid names for export, skip the jitter check
+        if func.name == "" or "aten::" in func.name:
+            return
+        self._compared_saved_loaded(func)
+
+    def emitModuleHook(self, module):
+        self._compared_saved_loaded(module)
+
+
+    def getExportImportCopyWithPacking(self, m, also_test_file=True, map_location=None):
+        buffer = io.BytesIO()
+        m.apply(lambda s: s._pack() if s._c._has_method('_pack') else None)
+        torch.jit.save(m, buffer)
+        m.apply(lambda s: s._unpack() if s._c._has_method('_unpack') else None)
+        buffer.seek(0)
+        imported = torch.jit.load(buffer, map_location=map_location)
+        imported.apply(lambda s: s._unpack() if s._c._has_method('_unpack') else None)
+
+        if not also_test_file:
+            return imported
+
+        # Ideally we would like to not have to manually delete the file, but NamedTemporaryFile
+        # opens the file, and it cannot be opened multiple times in Windows. To support Windows,
+        # close the file after creation and try to remove it manually
+        f = tempfile.NamedTemporaryFile(delete=False)
+        try:
+            f.close()
+            imported.save(f.name)
+            result = torch.jit.load(f.name, map_location=map_location)
+        finally:
+            os.unlink(f.name)
+
+        result.apply(lambda s: s._unpack() if s._c._has_method('_unpack') else None)
+        return result
+
+    def assertGraphContains(self, graph, kind, consider_subgraphs=False):
+
+        if consider_subgraphs:
+            strgraph = str(graph)
+            count = strgraph.count(kind) - strgraph.count(f'with {kind}')
+            self.assertTrue(count > 0)
+            return
+
+        def nodes(block):
+            out = []
+            for node in block.nodes():
+                if node.kind() == kind:
+                    out.append(node)
+                for block in node.blocks():
+                    out += nodes(block)
+            return out
+
+        out_nodes = nodes(graph)
+        self.assertTrue(len(out_nodes) > 0)
+
+    def assertGraphContainsExactly(self, graph, kind, num_kind_nodes, consider_subgraphs=False):
+        def perform_assert(graph, kind, actual, expected, consider_subgraphs):
+            if actual == expected:
+                return
+            subgraph = 'including' if consider_subgraphs else 'excluding'
+            raise AssertionError(
+                f'{graph}\nError: graph contains {actual} {kind} nodes ({subgraph} subgraphs) but expected {expected}')
+
+        if consider_subgraphs:
+            strgraph = str(graph)
+            count = strgraph.count(kind) - strgraph.count(f'with {kind}')
+            perform_assert(graph, kind, count, num_kind_nodes,
+                           consider_subgraphs)
+            return
+
+        def nodes(block):
+            out = []
+            for node in block.nodes():
+                if node.kind() == kind:
+                    out.append(node)
+                for block in node.blocks():
+                    out += nodes(block)
+            return out
+
+        out_nodes = nodes(graph)
+        perform_assert(graph, kind, len(out_nodes), num_kind_nodes,
+                       consider_subgraphs)
+
+    def assertExpectedONNXGraph(self, g, *args, **kwargs):
+        g = torch.onnx._optimize_trace(g, operator_export_type=OperatorExportTypes.ONNX)
+        self.assertExpectedGraph(g, *args, **kwargs)
+
+    def assertExpectedGraph(self, trace, *args, **kwargs):
+        if isinstance(trace, torch._C.Graph):
+            graph = trace
+        else:
+            graph = trace.graph()
+
+        torch._C._jit_pass_lint(graph)
+        torch._C._jit_pass_dce(graph)
+        torch._C._jit_pass_lint(graph)
+        graph = torch._C._jit_pass_canonicalize(graph)
+        torch._C._jit_pass_lint(graph)
+        self.assertExpected(str(graph), *args, **kwargs)
+
+    def run_pass(self, name, trace):
+        if isinstance(trace, torch._C.Graph):
+            graph = trace
+            set_graph = False
+        else:
+            set_graph = True
+            graph = trace.graph()
+
+        torch._C._jit_pass_lint(graph)
+        result = getattr(torch._C, '_jit_pass_' + name)(graph)
+        if result is not None and not isinstance(result, bool):
+            graph = result
+        torch._C._jit_pass_lint(graph)
+
+        if set_graph:
+            trace.set_graph(graph)
+        return graph
+
+    def get_frame_vars(self, frames_up):
+        frame = inspect.currentframe()
+        if not frame:
+            raise RuntimeError("failed to inspect frame")
+        i = 0
+        while i < frames_up + 1:
+            frame = frame.f_back
+            if not frame:
+                raise RuntimeError("failed to get frame")
+            i += 1
+        defined_vars: dict[str, Any] = {}
+        defined_vars.update(frame.f_locals)
+        defined_vars.update(frame.f_globals)
+        return defined_vars
+
+    def assertRaisesRegexWithHighlight(self, exception, regex, highlight):
+        return _AssertRaisesRegexWithHighlightContext(self, exception, regex, highlight)
+
+    def checkScriptRaisesRegex(self, script, inputs, exception, regex,
+                               name=None, outputs=None, capture_output=False,
+                               frames_up=1, profiling=ProfilingMode.PROFILING):
+        """
+        Checks that a given function will throw the correct exception,
+        when executed with normal python, the string frontend, and the
+        AST frontend. Logic taken from `checkScript` (see comments there
+        for details)
+        """
+        with enable_profiling_mode_for_profiling_tests():
+            # Normal Python
+            with self.assertRaisesRegex(exception, regex):
+                if isinstance(script, str):
+                    frame = self.get_frame_vars(frames_up)
+                    the_locals: dict[str, Any] = {}
+                    execWrapper(script, glob=frame, loc=the_locals)
+                    frame.update(the_locals)
+
+                    python_fn = frame[name]
+                else:
+                    python_fn = script
+
+                python_fn(*inputs)
+
+            # String frontend
+            with self.assertRaisesRegex(exception, regex):
+                if isinstance(script, str):
+                    cu = torch.jit.CompilationUnit(script, _frames_up=frames_up)
+                    string_frontend = getattr(cu, name)
+                else:
+                    source = textwrap.dedent(inspect.getsource(script))
+                    cu = torch.jit.CompilationUnit(source, _frames_up=frames_up)
+                    string_frontend = getattr(cu, script.__name__)
+
+                string_frontend(*inputs)
+
+            # Python AST frontend
+            if not isinstance(script, str):
+                with self.assertRaisesRegex(exception, regex):
+                    ge = torch.jit.script(python_fn)
+                    ge(*inputs)
+
+    def checkBailouts(self, model, inputs, expected):
+        state = model.get_debug_state()
+        plan = get_execution_plan(state)
+        num_bailouts = plan.code.num_bailouts()
+        for i in range(0, num_bailouts):
+            plan.code.request_bailout(i)
+            bailout_outputs = model(*inputs)
+            self.assertEqual(bailout_outputs, expected)
+
+    def checkScript(self,
+                    script,
+                    inputs,
+                    name='func',
+                    optimize=True,
+                    inputs_requires_grad=False,
+                    capture_output=False,
+                    frames_up=1,
+                    profiling=ProfilingMode.PROFILING,
+                    atol=None,
+                    rtol=None):
+        """
+        Checks that a given script generates the same output as the Python
+        version using the given inputs.
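+
+        Example (illustrative sketch; `fn` is a hypothetical function):
+
+            def fn(x):
+                return x + 1
+
+            self.checkScript(fn, (torch.randn(3),))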
+        """
+        with torch.jit.optimized_execution(optimize):
+            with enable_profiling_mode_for_profiling_tests():
+                extra_profile_runs = any(isinstance(x, torch.Tensor) and x.requires_grad for x in inputs)
+                if isinstance(script, str):
+                    # Compile the string to a Script function
+                    # with enable_profiling_mode():
+                    cu = torch.jit.CompilationUnit(script, _frames_up=frames_up)
+
+                    # Execute the Python function so we can run it later and get its
+                    # outputs
+
+                    frame = self.get_frame_vars(frames_up)
+                    the_locals: dict[str, Any] = {}
+                    execWrapper(script, glob=frame, loc=the_locals)
+                    frame.update(the_locals)
+
+                    python_fn = frame[name]
+                    scripted_fn = getattr(cu, name)
+                else:
+
+                    # Check the string frontend first
+                    source = textwrap.dedent(inspect.getsource(script))
+                    self.checkScript(
+                        source,
+                        inputs,
+                        script.__name__,
+                        optimize=optimize,
+                        inputs_requires_grad=inputs_requires_grad,
+                        capture_output=capture_output,
+                        profiling=profiling,
+                        frames_up=2)
+
+                    # Continue checking the Python frontend
+                    scripted_fn = torch.jit.script(script, _frames_up=1)
+                    python_fn = script
+
+                if inputs_requires_grad:
+                    recording_inputs = do_input_map(lambda t: t.detach().requires_grad_(), inputs)
+                else:
+                    recording_inputs = inputs
+
+                if capture_output:
+                    with self.capture_stdout() as script_stdout:
+                        script_outputs = scripted_fn(*recording_inputs)
+                    with self.capture_stdout():
+                        opt_script_outputs = scripted_fn(*recording_inputs)
+                    with self.capture_stdout():
+                        python_outputs = python_fn(*inputs)
+                    if not IS_WINDOWS:
+                        self.assertExpected(script_stdout[0], subname='stdout')
+                    self.assertEqual(python_outputs, opt_script_outputs, atol=atol, rtol=rtol)
+                else:
+                    # profiling run
+                    script_outputs = scripted_fn(*recording_inputs)
+                    if inputs_requires_grad or extra_profile_runs:
+                        opt_script_outputs = scripted_fn(*recording_inputs)
+                    # optimized run
+                    opt_script_outputs = scripted_fn(*recording_inputs)
+                    if TEST_BAILOUTS:
+                        self.checkBailouts(scripted_fn, inputs, opt_script_outputs)
+                    python_outputs = python_fn(*inputs)
+                self.assertEqual(python_outputs, script_outputs, atol=atol, rtol=rtol)
+                self.assertEqual(script_outputs, opt_script_outputs, atol=atol, rtol=rtol)
+                return scripted_fn
+
+    def checkTrace(self, func, reference_tensors, input_tensors=None,
+                   drop=None, allow_unused=False, verbose=False,
+                   inputs_require_grads=True, check_tolerance=1e-5, export_import=True,
+                   _force_outplace=False, grad_atol=None, grad_rtol=None):
+
+        # TODO: check gradients for parameters, not just inputs
+        def allSum(vs):
+            # drop allows us to remove some values from ever being used
+            # to test unused outputs
+            if drop is not None:
+                vs = vs[:-drop]
+            # we don't want all the grad for all the outputs to be the same
+            # so we multiply each by a constant
+            return sum(math.log(i + 2) * v.sum() for i, v in enumerate(vs) if v is not None)
+        if input_tensors is None:
+            input_tensors = reference_tensors
+
+        def flatten_inputs(inputs):
+            def input_reduce(input, fn, acc):
+                if isinstance(input, torch.Tensor):
+                    fn(input, acc)
+                elif isinstance(input, dict):
+                    reduce(lambda acc, key: input_reduce(input[key], fn, acc), input, acc)
+                else:
+                    reduce(lambda acc, val: input_reduce(val, fn, acc), input, acc)
+                return acc
+            return tuple(input_reduce(inputs, lambda t, acc: acc.append(t), []))
+
+        nograd_inputs = reference_tensors
+        if inputs_require_grads:
+            recording_inputs = do_input_map(lambda t: t.clone().requires_grad_(), reference_tensors)
+            flattened_recording_inputs = flatten_inputs(recording_inputs)
+        else:
+            recording_inputs = reference_tensors
+
+        # `check_trace` is set to False because check_trace is run with @no_grad
+        # Also, `checkTrace` already does all the checks
+        # against the Python function
+        ge = torch.jit.trace(func, input_tensors, check_tolerance=check_tolerance,
+                             _force_outplace=_force_outplace, check_trace=False)
+
+        if export_import:
+            ge = self.getExportImportCopy(ge)
+
+        if verbose:
+            print(ge.graph)
+
+        # test no gradients case
+        outputs = func(*nograd_inputs)
+        outputs_ge = ge(*nograd_inputs)
+        self.assertEqual(outputs, outputs_ge)
+
+        # test gradients case
+        outputs = func(*recording_inputs)
+        if inputs_require_grads:
+            grads = torch.autograd.grad(allSum(outputs), flattened_recording_inputs,
+                                        allow_unused=allow_unused)
+
+        outputs_ge = ge(*recording_inputs)
+        if inputs_require_grads:
+            grads_ge = torch.autograd.grad(allSum(outputs_ge), flattened_recording_inputs,
+                                           allow_unused=allow_unused)
+        self.assertEqual(outputs, outputs_ge)
+        if inputs_require_grads:
+            self.assertEqual(grads, grads_ge, atol=grad_atol, rtol=grad_rtol)
+
+        # test the grad grad case
+        outputs = func(*recording_inputs)
+        l1 = allSum(outputs)
+        if inputs_require_grads:
+            grads = torch.autograd.grad(l1, flattened_recording_inputs, create_graph=True,
+                                        allow_unused=allow_unused)
+        if inputs_require_grads:
+            l2 = (allSum(grads) * l1)
+            grads2 = torch.autograd.grad(l2, flattened_recording_inputs, allow_unused=allow_unused)
+
+        if inputs_require_grads:
+            recording_inputs = do_input_map(lambda t: Variable(t, requires_grad=True), reference_tensors)
+            flattened_recording_inputs = flatten_inputs(recording_inputs)
+
+        outputs_ge = ge(*recording_inputs)
+        l1_ge = allSum(outputs_ge)
+        if inputs_require_grads:
+            grads_ge = torch.autograd.grad(
+                l1_ge, flattened_recording_inputs, create_graph=True, allow_unused=allow_unused)
+
+        if inputs_require_grads:
+            l2_ge = (allSum(grads_ge) * l1_ge)
+            grads2_ge = torch.autograd.grad(l2_ge, flattened_recording_inputs, allow_unused=allow_unused)
+
+        self.assertEqual(outputs, outputs_ge)
+        if inputs_require_grads:
+            self.assertEqual(grads, grads_ge, atol=grad_atol, rtol=grad_rtol)
+            for g2, g2_ge in zip(grads2, grads2_ge):
+                if g2 is None and g2_ge is None:
+                    continue
+                self.assertEqual(g2, g2_ge, atol=8e-4, rtol=8e-4)
+
+        return ge
+
+    def checkModule(self, nn_module, args):
+        """
+        Check that an nn.Module's results in Script mode match eager mode and
+        that it can be exported.
+        """
+        sm = torch.jit.script(nn_module)
+
+        with freeze_rng_state():
+            eager_out = nn_module(*args)
+
+        with freeze_rng_state():
+            script_out = sm(*args)
+
+        self.assertEqual(eager_out, script_out)
+        self.assertExportImportModule(sm, args)
+
+        return sm
+
+class NoTracerWarnContextManager:
+    def __enter__(self):
+        self.prev = torch._C._jit_get_tracer_state_warn()
+        torch._C._jit_set_tracer_state_warn(False)
+
+    def __exit__(self, *args):
+        torch._C._jit_set_tracer_state_warn(self.prev)
+
+@contextmanager
+def inline_everything_mode(should_inline):
+    old = torch._C._jit_get_inline_everything_mode()
+    torch._C._jit_set_inline_everything_mode(should_inline)
+    try:
+        yield
+    finally:
+        torch._C._jit_set_inline_everything_mode(old)
+
+@contextmanager
+def set_fusion_group_inlining(inlining):
+    old = torch._C._debug_get_fusion_group_inlining()
+    torch._C._debug_set_fusion_group_inlining(inlining)
+    try:
+        yield
+    finally:
+        torch._C._debug_set_fusion_group_inlining(old)
+
+# note: not re-entrant, use unnested only
+@contextmanager
+def disable_autodiff_subgraph_inlining(enabled=True):
+    torch._C._debug_set_autodiff_subgraph_inlining(not enabled)
+    try:
+        yield
+    finally:
+        torch._C._debug_set_autodiff_subgraph_inlining(True)
+
+def _inline_everything(fn):
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        with inline_everything_mode(True):
+            fn(*args, **kwargs)
+    return wrapper
+
+# this exists for forward compatibility reasons temporarily.
+# TODO(suo) remove
+def _tmp_donotuse_dont_inline_everything(fn):
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        with inline_everything_mode(False):
+            fn(*args, **kwargs)
+    return wrapper
+
+# make it easy to quickly define/trace a function for these tests
+def _trace(*args, **kwargs):
+    def wrapper(func):
+        return torch.jit.trace(func, args, **kwargs)
+    return wrapper
+
+
+def enable_cpu_fuser(fn):
+    def wrapper(*args, **kwargs):
+        torch._C._jit_override_can_fuse_on_cpu_legacy(True)
+        torch._C._jit_override_can_fuse_on_cpu(True)
+        torch._C._jit_set_te_must_use_llvm_cpu(False)
+        try:
+            fn(*args, **kwargs)
+        finally:
+            torch._C._jit_override_can_fuse_on_cpu_legacy(False)
+            torch._C._jit_override_can_fuse_on_cpu(False)
+            torch._C._jit_set_te_must_use_llvm_cpu(True)
+    return wrapper
+
+
+def enable_cpu_fuser_if(cond):
+    if cond:
+        return enable_cpu_fuser
+    else:
+        def noop_fuser(fn):
+            def wrapper(*args, **kwargs):
+                return fn(*args, **kwargs)
+            return wrapper
+        return noop_fuser
+
+def get_forward(c):
+    return c._get_method('forward')
+
+def get_forward_graph(c):
+    return c._get_method('forward').graph
+
+def get_module_method(m, module, method):
+    return m._c.getattr(module)._get_method(method)
+
+def attrs_with_prefix(module, prefix):
+    return [x for x, _ in module._modules._c.items()
+            if x.startswith(prefix)]
+
+def warmup_backward(f, *args):
+    profiling_count = 3
+    results = []
+    for _ in range(profiling_count):
+        if len(args) > 0:
+            r = torch.autograd.grad(f, *args)
+            results.append(r)
+        else:
+            f.backward(retain_graph=True)
+
+    return results
+
+# TODO: Remove me once https://bugs.python.org/issue42666 is resolved
+def make_global(*args):
+    for arg in args:
+        setattr(sys.modules[arg.__module__], arg.__name__, arg)
+
+# Helper function to load a function from a code string by writing it to a
+# temporary module and importing it (historically needed to avoid py2 syntax errors)
+def _get_py3_code(code, fn_name):
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        script_path = os.path.join(tmp_dir, 'script.py')
+        with open(script_path, 'w') as f:
+            f.write(code)
+        spec = importlib.util.spec_from_file_location(fn_name, script_path)
+        module = importlib.util.module_from_spec(spec)
+        loader = spec.loader
+        assert isinstance(loader, Loader)  # Assert type to meet MyPy requirement
+        loader.exec_module(module)
+        fn = getattr(module, fn_name)
+        return fn
+
+class TensorExprTestOptions:
+    def __init__(self) -> None:
+        self.old_profiling_executor = torch._C._jit_set_profiling_executor(True)
+        self.old_profiling_mode = torch._C._get_graph_executor_optimize(True)
+
+        self.old_cpu_fuser_state = torch._C._jit_can_fuse_on_cpu()
+        self.old_gpu_fuser_state = torch._C._jit_can_fuse_on_gpu()
+        torch._C._jit_override_can_fuse_on_cpu(True)
+        torch._C._jit_override_can_fuse_on_gpu(True)
+        self.texpr_fuser_state = torch._C._jit_texpr_fuser_enabled()
+        torch._C._jit_set_texpr_fuser_enabled(True)
+        self.old_fusion_inlining = torch._C._debug_get_fusion_group_inlining()
+        torch._C._debug_set_fusion_group_inlining(False)
+        self.old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu()
+        torch._C._jit_set_te_must_use_llvm_cpu(False)
+
+    def restore(self):
+        torch._C._jit_set_profiling_executor(self.old_profiling_executor)
+        torch._C._get_graph_executor_optimize(self.old_profiling_mode)
+
+        torch._C._jit_set_texpr_fuser_enabled(self.texpr_fuser_state)
+        torch._C._jit_override_can_fuse_on_gpu(self.old_gpu_fuser_state)
+        torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuser_state)
+        torch._C._debug_set_fusion_group_inlining(self.old_fusion_inlining)
+        torch._C._jit_set_te_must_use_llvm_cpu(self.old_te_must_use_llvm_cpu)
+
+def clone_inputs(args):
+    inputs: list[Union[torch.Tensor, list[torch.Tensor]]] = []
+
+    for arg in args:
+        if isinstance(arg, torch.Tensor):
+            inputs.append(arg.detach().clone())
+        elif is_iterable_of_tensors(arg):
+            inputs.append([t.detach().clone() for t in arg])
+        else:
+            inputs.append(arg)
+
+    return inputs
+
+def get_traced_sample_variant_pairs(device, dtype, op):
+    # tuples of (variant, sample)
+    outputs: list[tuple[Any, Any]] = []
+
+    samples = op.sample_inputs(device, dtype)
+
+    # Acquires variants to test
+    func = op.get_op()
+    method = op.get_method()
+    variants = {
+        # TODO: inplace tests currently fail, fix and add inplace variant
+        'function': func, 'method': method,
+    }
+
+    # TODO: find better way to standardize on op registration itself..
+    has_fake_function = op.name in ["resize_", 'resize_as_']
+
+    if has_fake_function:
+        variants = {'method': getattr(torch.Tensor, op.name)}
+
+    # In eager mode, these ops can take (Tensor, bool) args; but in
+    # JIT they can only take (Tensor, Scalar), and bool is not a
+    # scalar in the JIT type system. So to test these in JIT, the bool
+    # is converted to an int for the test.
+    ops_with_unsupported_bool_args = [
+        {
+            "name": "div_floor_rounding",
+            "arg_idx": [0],
+        },
+        {
+            "name": "div_no_rounding_mode",
+            "arg_idx": [0],
+        },
+        {
+            "name": "div_trunc_rounding",
+            "arg_idx": [0],
+        },
+        {
+            "name": "index_fill",
+            "arg_idx": [2],
+        },
+        {
+            "name": "full_like",
+            "arg_idx": [0],
+        },
+        {
+            "name": "mul",
+            "arg_idx": [0],
+        },
+        {
+            "name": "new_full",
+            "arg_idx": [1],
+        },
+    ]
+
+    # doesn't support tracing
+    if has_fake_function:
+        return outputs
+
+    for sample in samples:
+        for variant in variants.values():
+            if variant is None:
+                continue
+
+            if is_lambda(variant):
+                continue
+
+            matching_ops = filter(lambda x: op.formatted_name == x["name"], ops_with_unsupported_bool_args)
+            for op_data in matching_ops:
+                for idx in op_data["arg_idx"]:
+                    args = list(sample.args)
+                    if len(sample.args) > idx and isinstance(sample.args[idx], bool):
+                        args[idx] = int(args[idx])
+                    sample.args = tuple(args)
+
+            outputs.append((variant, sample))
+
+    return outputs
+
+# types.LambdaType gave false positives
+def is_lambda(lamb):
+    LAMBDA = lambda: 0  # noqa: E731
+    return isinstance(lamb, type(LAMBDA)) and lamb.__name__ == LAMBDA.__name__
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/logging_tensor.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/logging_tensor.py
new file mode 100644
index 0000000000000000000000000000000000000000..e71f0f46854756a4b4251df6a53a03a288183172
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/logging_tensor.py
@@ -0,0 +1,168 @@
+# mypy: ignore-errors
+
+import torch
+from torch.utils._pytree import tree_map
+from typing import Optional
+from collections.abc import Iterator
+import logging
+import contextlib
+import itertools
+from torch.utils._dtype_abbrs import dtype_abbrs as _dtype_abbrs
+from torch.utils._python_dispatch import TorchDispatchMode
+from torch.utils.weak import WeakTensorKeyDictionary
+import functools
+from torch._C._profiler import gather_traceback, symbolize_tracebacks
+
+logger = logging.getLogger("LoggingTensor")
+
+# How the chain of calls works for LoggingTensor:
+# 1. Call torch.sin
+# 2. Attempt __torch_function__. In LoggingTensor torch function is disabled so we bypass it entirely
+# 3. Enter dispatcher, wind your way through Autograd
+# 4. Hit Python dispatch key, call __torch_dispatch__
+
+# This Tensor can work with autograd in two ways:
+#  - The wrapped Tensor does not require gradients. In that case, the LoggingTensor
+#    can require gradients if the user asks for it as a constructor kwarg.
+#  - The wrapped Tensor can require gradients. In that case autograd will be tracked
+#    for the wrapped Tensor and the LoggingTensor itself cannot require gradients.
+# WARNING: We allow these two possibilities for testing purposes. You should NEVER use both in a single
+# test or you might get surprising behavior.
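+#
+# A minimal usage sketch (illustrative only; the exact log strings may differ):
+#
+#   x = LoggingTensor(torch.randn(3))
+#   with capture_logs() as logs:
+#       log_input("x", x)
+#       torch.sin(x)
+#   # logs now holds entries roughly of the form
+#   # "$1: f32[3] = torch._ops.aten.sin.default($0)"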
+
+# TODO: TensorBase should work
+class LoggingTensor(torch.Tensor):
+    elem: torch.Tensor
+
+    __slots__ = ['elem']
+
+    context = contextlib.nullcontext
+
+    @staticmethod
+    def __new__(cls, elem, *args, **kwargs):
+        # The wrapping tensor (LoggingTensor) shouldn't hold any
+        # memory for the class in question, but it should still
+        # advertise the same device as before
+        r = torch.Tensor._make_wrapper_subclass(
+            cls, elem.size(),
+            strides=elem.stride(), storage_offset=elem.storage_offset(),
+            # TODO: clone storage aliasing
+            dtype=elem.dtype, layout=elem.layout,
+            device=elem.device, requires_grad=kwargs.get("requires_grad", False)
+        )
+        # ...the real tensor is held as an element on the tensor.
+        r.elem = elem.detach() if r.requires_grad else elem
+        return r
+
+    def __repr__(self):
+        return super().__repr__(tensor_contents=f"{self.elem}")
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+        def unwrap(e):
+            return e.elem if isinstance(e, cls) else e
+
+        def wrap(e):
+            return cls(e) if isinstance(e, torch.Tensor) else e
+
+        with cls.context():
+            rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
+        logging.getLogger("LoggingTensor").info(f"{func.__module__}.{func.__name__}", args, kwargs, rs)  # noqa: G004
+        return rs
+
+class LoggingTensorMode(TorchDispatchMode):
+    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+        rs = func(*args, **kwargs)
+        logging.getLogger("LoggingTensor").info(f"{func.__module__}.{func.__name__}", args, kwargs, rs)  # noqa: G004
+        return rs
+
+class LoggingTensorReentrant(LoggingTensor):
+    context = torch.overrides.enable_reentrant_dispatch
+
+# https://stackoverflow.com/questions/36408496/python-logging-handler-to-append-to-list
+class LoggingTensorHandler(logging.Handler):
+    def __init__(
+            self, log_list: list[str], use_shortid_for_all_tensors: bool,
+            with_type: bool, tracebacks_list: Optional[list]) -> None:
+        logging.Handler.__init__(self)
+        self.log_list = log_list
+        self.use_shortid_for_all_tensors = use_shortid_for_all_tensors
+        self.tracebacks_list = tracebacks_list
+        self.memo = WeakTensorKeyDictionary()
+        self.next_id = 0
+        self.with_type = with_type
+
+    def _shortid(self, t: torch.Tensor) -> int:
+        if t not in self.memo:
+            self.memo[t] = self.next_id
+            self.next_id += 1
+        return self.memo[t]
+
+    def _fmt(self, a: object, with_type: bool = False) -> str:
+        cond_cls = torch.Tensor if self.use_shortid_for_all_tensors else LoggingTensor
+        if isinstance(a, cond_cls):
+            maybe_type = ""
+            if with_type and self.with_type:
+                maybe_type = f": {_dtype_abbrs[a.dtype]}[{', '.join(map(str, a.shape))}]"
+            x = f"${self._shortid(a)}{maybe_type}"
+            return x
+        else:
+            return repr(a)
+
+    def emit(self, record):
+        fmt_args = ", ".join(
+            itertools.chain(
+                (str(tree_map(self._fmt, a)) for a in record.args[0]),
+                (f"{k}={str(tree_map(self._fmt, v))}" for k, v in record.args[1].items()),
+            )
+        )
+        fmt_rets = tree_map(functools.partial(self._fmt, with_type=True), record.args[2])
+        self.log_list.append(f'{fmt_rets} = {record.msg}({fmt_args})')
+        if self.tracebacks_list is not None:
+            self.tracebacks_list.append(record.traceback)
+
+def log_input(name: str, var: object) -> None:
+    logger.info("input", (name,), {}, var)  # noqa: PLE1205
+
+class GatherTraceback(logging.Filter):
+    def __init__(self, python=True, script=True, cpp=False):
+        self.python = python
+        self.script = script
+        self.cpp = cpp
+
+    def filter(self, record):
+        record.traceback = gather_traceback(python=self.python, script=self.script, cpp=self.cpp)
+        return True
+
+@contextlib.contextmanager
+def capture_logs(is_mode=False, python_tb=False, script_tb=False, cpp_tb=False) -> Iterator[list[str]]:
+    collect_traceback = python_tb or script_tb or cpp_tb
+    log_list: list[str] = []
+    tracebacks_list: list[str] = []
+    handler = LoggingTensorHandler(
+        log_list,
+        with_type=True,
+        use_shortid_for_all_tensors=is_mode,
+        tracebacks_list=tracebacks_list if collect_traceback else None
+    )
+    logger.addHandler(handler)
+    logger.setLevel(logging.INFO)
+    logger.propagate = False
+    if collect_traceback:
+        logger.addFilter(GatherTraceback(python=python_tb, script=script_tb, cpp=cpp_tb))
+    try:
+        if collect_traceback:
+            yield log_list, tracebacks_list
+        else:
+            yield log_list
+    finally:
+        symbolized_tracebacks = symbolize_tracebacks(tracebacks_list)
+        tracebacks_list.clear()
+        tracebacks_list.extend(symbolized_tracebacks)
+        logger.removeHandler(handler)
+
+@contextlib.contextmanager
+def capture_logs_with_logging_tensor_mode(python_tb=False, script_tb=False, cpp_tb=False):
+    with LoggingTensorMode(), capture_logs(True, python_tb, script_tb, cpp_tb) as logs:
+        yield logs
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/logging_utils.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/logging_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d017ffe14ff40dca3fd19779ee20097cb3e885b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/logging_utils.py
@@ -0,0 +1,243 @@
+# mypy: ignore-errors
+
+import torch._dynamo.test_case
+import unittest.mock
+import os
+import contextlib
+import torch._logging
+import torch._logging._internal
+from contextlib import AbstractContextManager
+from typing import Callable
+from torch._dynamo.utils import LazyString
+from torch._inductor import config as inductor_config
+import logging
+import io
+
+@contextlib.contextmanager
+def preserve_log_state():
+    prev_state = torch._logging._internal._get_log_state()
+    torch._logging._internal._set_log_state(torch._logging._internal.LogState())
+    try:
+        yield
+    finally:
+        torch._logging._internal._set_log_state(prev_state)
+        torch._logging._internal._init_logs()
+
+def log_settings(settings):
+    exit_stack = contextlib.ExitStack()
+    settings_patch = unittest.mock.patch.dict(os.environ, {"TORCH_LOGS": settings})
+    exit_stack.enter_context(preserve_log_state())
+    exit_stack.enter_context(settings_patch)
+    torch._logging._internal._init_logs()
+    return exit_stack
+
+def log_api(**kwargs):
+    exit_stack = contextlib.ExitStack()
+    exit_stack.enter_context(preserve_log_state())
+    torch._logging.set_logs(**kwargs)
+    return exit_stack
+
+
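+# For example (illustrative), kwargs_to_settings(dynamo=logging.DEBUG, graph_breaks=True)
+# produces the TORCH_LOGS-style string "+dynamo,graph_breaks".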
+def kwargs_to_settings(**kwargs):
+    INT_TO_VERBOSITY = {10: "+", 20: "", 40: "-"}
+
+    settings = []
+
+    def append_setting(name, level):
+        if isinstance(name, str) and isinstance(level, int) and level in INT_TO_VERBOSITY:
+            settings.append(INT_TO_VERBOSITY[level] + name)
+            return
+        else:
+            raise ValueError("Invalid value for setting")
+
+    for name, val in kwargs.items():
+        if isinstance(val, bool):
+            settings.append(name)
+        elif isinstance(val, int):
+            append_setting(name, val)
+        elif isinstance(val, dict) and name == "modules":
+            for module_qname, level in val.items():
+                append_setting(module_qname, level)
+        else:
+            raise ValueError("Invalid value for setting")
+
+    return ",".join(settings)
+
+
+# Note on testing strategy:
+# This test machinery does three things:
+# 1. Runs two versions of each test:
+#    1a. with the env var log settings patched to a specific value
+#    1b. with torch._logging.set_logs(..) called with the equivalent kwargs
+# 2. Patches the emit method of each configured handler to gather the records
+#    emitted to each console stream
+# 3. Passes a reference to the gathered records to each test case for checking
+#
+# The overall goal is to ensure that, given some settings env var, the logs are
+# set up correctly and capture the correct records.
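+#
+# A hypothetical usage sketch (the test class, body, and names below are invented
+# for illustration):
+#
+#   class MyLoggingTests(LoggingTestCase):
+#       @make_logging_test(dynamo=logging.INFO)
+#       def test_compile_logs(self, records):
+#           torch.compile(lambda x: x + 1, backend="eager")(torch.ones(3))
+#           self.assertGreater(len(records), 0)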
+def make_logging_test(**kwargs):
+    def wrapper(fn):
+        @inductor_config.patch({"fx_graph_cache": False})
+        def test_fn(self):
+
+            torch._dynamo.reset()
+            records = []
+            # run with env var
+            if len(kwargs) == 0:
+                with self._handler_watcher(records):
+                    fn(self, records)
+            else:
+                with log_settings(kwargs_to_settings(**kwargs)), self._handler_watcher(records):
+                    fn(self, records)
+
+            # run with API
+            torch._dynamo.reset()
+            records.clear()
+            with log_api(**kwargs), self._handler_watcher(records):
+                fn(self, records)
+
+
+        return test_fn
+
+    return wrapper
+
+def make_settings_test(settings):
+    def wrapper(fn):
+        def test_fn(self):
+            torch._dynamo.reset()
+            records = []
+            # run with env var
+            with log_settings(settings), self._handler_watcher(records):
+                fn(self, records)
+
+        return test_fn
+
+    return wrapper
+
+class LoggingTestCase(torch._dynamo.test_case.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+        cls._exit_stack.enter_context(
+            unittest.mock.patch.dict(os.environ, {"___LOG_TESTING": ""})
+        )
+        cls._exit_stack.enter_context(
+            unittest.mock.patch("torch._dynamo.config.suppress_errors", True)
+        )
+        cls._exit_stack.enter_context(
+            unittest.mock.patch("torch._dynamo.config.verbose", False)
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        cls._exit_stack.close()
+        torch._logging._internal.log_state.clear()
+        torch._logging._init_logs()
+
+    def hasRecord(self, records, m):
+        return any(m in r.getMessage() for r in records)
+
+    def getRecord(self, records, m):
+        record = None
+        for r in records:
+            # NB: not r.msg because it looks like 3.11 changed how they
+            # structure log records
+            if m in r.getMessage():
+                self.assertIsNone(
+                    record,
+                    msg=LazyString(
+                        lambda: f"multiple matching records: {record} and {r} among {records}"
+                    ),
+                )
+                record = r
+        if record is None:
+            self.fail(f"did not find record with {m} among {records}")
+        return record
+
+    # This patches the emit method of each handler to gather records
+    # as they are emitted
+    def _handler_watcher(self, record_list):
+        exit_stack = contextlib.ExitStack()
+
+        def emit_post_hook(record):
+            nonlocal record_list
+            record_list.append(record)
+
+        # registered logs are the only ones with handlers, so patch those
+        for log_qname in torch._logging._internal.log_registry.get_log_qnames():
+            logger = logging.getLogger(log_qname)
+            num_handlers = len(logger.handlers)
+            self.assertLessEqual(
+                num_handlers,
+                2,
+                "All pt2 loggers should only have at most two handlers (debug artifacts and messages above debug level).",
+            )
+
+            self.assertGreater(num_handlers, 0, "All pt2 loggers should have more than zero handlers")
+
+            for handler in logger.handlers:
+                old_emit = handler.emit
+
+                def new_emit(record):
+                    old_emit(record)
+                    emit_post_hook(record)
+
+                exit_stack.enter_context(
+                    unittest.mock.patch.object(handler, "emit", new_emit)
+                )
+
+        return exit_stack
+
+
+def logs_to_string(module, log_option):
+    """Example:
+    logs_to_string("torch._inductor.compile_fx", "post_grad_graphs")
+    returns the output of TORCH_LOGS="post_grad_graphs" from the
+    torch._inductor.compile_fx module.
+    """
+    log_stream = io.StringIO()
+    handler = logging.StreamHandler(stream=log_stream)
+
+    @contextlib.contextmanager
+    def tmp_redirect_logs():
+        try:
+            logger = torch._logging.getArtifactLogger(module, log_option)
+            logger.addHandler(handler)
+            yield
+        finally:
+            logger.removeHandler(handler)
+
+    def ctx_manager():
+        exit_stack = log_settings(log_option)
+        exit_stack.enter_context(tmp_redirect_logs())
+        return exit_stack
+
+    return log_stream, ctx_manager
+
+
+def multiple_logs_to_string(module: str, *log_options: str) -> tuple[list[io.StringIO], Callable[[], AbstractContextManager[None]]]:
+    """Example:
+    multiple_logs_to_string("torch._inductor.compile_fx", "pre_grad_graphs", "post_grad_graphs")
+    returns the output of TORCH_LOGS="pre_grad_graphs, post_grad_graphs" from the
+    torch._inductor.compile_fx module.
+    """
+    log_streams = [io.StringIO() for _ in range(len(log_options))]
+    handlers = [logging.StreamHandler(stream=log_stream) for log_stream in log_streams]
+
+    @contextlib.contextmanager
+    def tmp_redirect_logs():
+        loggers = [torch._logging.getArtifactLogger(module, option) for option in log_options]
+        try:
+            for logger, handler in zip(loggers, handlers):
+                logger.addHandler(handler)
+            yield
+        finally:
+            for logger, handler in zip(loggers, handlers):
+                logger.removeHandler(handler)
+
+    def ctx_manager() -> AbstractContextManager[None]:
+        exit_stack = log_settings(", ".join(log_options))
+        exit_stack.enter_context(tmp_redirect_logs())
+        return exit_stack  # type: ignore[return-value]
+
+    return log_streams, ctx_manager
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/__init__.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..97c38f3560625213fbd59d09a9cfd22bad26ba04
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/__init__.py
@@ -0,0 +1,4 @@
+# mypy: ignore-errors
+
+import torch.testing._internal.opinfo.core
+import torch.testing._internal.opinfo.definitions
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/core.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/core.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cd248792dcb11321c2993231d552a81f7ded373
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/core.py
@@ -0,0 +1,3218 @@
+# mypy: ignore-errors
+
+import collections
+import collections.abc
+import contextlib
+import logging
+import math
+import operator
+import unittest
+from abc import ABC, abstractmethod
+from collections.abc import Iterable
+from dataclasses import asdict, dataclass, field
+from enum import Enum
+from functools import partial
+from itertools import product
+from typing import Any, Callable, Optional, TypeVar, Union
+
+import torch
+from torch.testing import make_tensor
+from torch.testing._internal.common_device_type import (
+    skipCPUIfNoFFT,
+    tol,
+    toleranceOverride,
+)
+from torch.testing._internal.common_dtype import (
+    _dispatch_dtypes,
+    floating_and_complex_types,
+    floating_and_complex_types_and,
+    floating_types,
+    get_all_dtypes,
+)
+from torch.testing._internal.common_utils import (
+    extract_test_fn,
+    IS_FBCODE,
+    is_iterable_of_tensors,
+    noncontiguous_like,
+    OPINFO_SAMPLE_INPUT_INDEX,
+    TEST_WITH_ROCM,
+    torch_to_numpy_dtype_dict,
+    TrackedInputIter,
+    USE_PYTEST,
+)
+from torch.testing._internal.opinfo import utils
+from torchgen.utils import dataclass_repr
+
+
+# setup logging
+log = logging.getLogger(__name__)
+
+# Reasonable testing sizes for dimensions
+L = 20
+M = 10
+S = 5
+XS = 3
+
+# Unique value to distinguish default from anything else
+_NOTHING = object()
+
+
+# Extension of getattr to support qualified names
+# e.g. _getattr_qual(torch, 'linalg.norm') -> torch.linalg.norm
+def _getattr_qual(obj, name, default=_NOTHING):
+    try:
+        for path in name.split("."):
+            obj = getattr(obj, path)
+        return obj
+    except AttributeError:
+        if default is not _NOTHING:
+            return default
+        else:
+            raise
+
+
+class DecorateInfo:
+    """Describes which test, or type of tests, should be wrapped in the given
+    decorators when testing an operator. Any test that matches all provided
+    arguments will be decorated. The decorators will only be applied if the
+    active_if argument is True."""
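+    # For example (illustrative only), to skip one generated test on CUDA float32:
+    #   DecorateInfo(unittest.skip("Skipped!"), "TestCommon", "test_noncontiguous_samples",
+    #                device_type="cuda", dtypes=(torch.float32,))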
+
+    __slots__ = [
+        "decorators",
+        "cls_name",
+        "test_name",
+        "device_type",
+        "dtypes",
+        "active_if",
+    ]
+
+    def __init__(
+        self,
+        decorators,
+        cls_name=None,
+        test_name=None,
+        *,
+        device_type=None,
+        dtypes=None,
+        active_if=True,
+    ):
+        self.decorators = (
+            list(decorators)
+            if isinstance(decorators, collections.abc.Sequence)
+            else [decorators]
+        )
+        self.cls_name = cls_name
+        self.test_name = test_name
+        self.device_type = device_type
+        self.dtypes = dtypes
+        self.active_if = active_if
+
+        # Validate dtypes
+        if self.dtypes is not None:
+            for dtype in self.dtypes:
+                assert isinstance(dtype, torch.dtype)
+
+    def is_active(self, cls_name, test_name, device_type, dtype, param_kwargs):
+        return (
+            self.active_if
+            and (self.cls_name is None or self.cls_name == cls_name)
+            and (self.test_name is None or self.test_name == test_name)
+            and (self.device_type is None or self.device_type == device_type)
+            and (self.dtypes is None or dtype in self.dtypes)
+            # Support callables over kwargs to determine if the decorator is active.
+            and (
+                self.active_if(param_kwargs)
+                if isinstance(self.active_if, Callable)
+                else self.active_if
+            )
+        )
+
+
+# FIXME
+# Note: historically the 'input' kwarg had to be a Tensor or TensorList, but we are trying
+#   to support scalar inputs, too. Some tests still depend on 'input' being a Tensor
+#   or TensorList, however.
+class SampleInput:
+    """Represents sample inputs to a function."""
+
+    __slots__ = [
+        "input",
+        "args",
+        "kwargs",
+        "output_process_fn_grad",
+        "broadcasts_input",
+        "name",
+    ]
+
+    def __init__(
+        self,
+        input,
+        *var_args,
+        args=None,
+        kwargs=None,
+        output_process_fn_grad=None,
+        broadcasts_input=None,
+        name=None,
+        **var_kwargs,
+    ):
+        # input is the first input to the op and is typically either a Tensor or TensorList (Sequence[Tensor]).
+        # This follows the typical pattern where for Tensor inputs op(t, ...) = t.op(...).
+        self.input = input
+
+        # Allow calling either as SampleInput(input, args=args, kwargs=kwargs), or as
+        # SampleInput(input, *args, **kwargs) but not to mix the two forms
+        if args is not None or kwargs is not None:
+            assert (
+                not var_args and not var_kwargs
+            ), """
+A SampleInput can be constructed "naturally" with *args and **kwargs or by
+explicitly setting the "args" and "kwargs" parameters, but the two
+methods of construction cannot be mixed!"""
+        elif len(var_args) or len(var_kwargs):
+            assert (
+                output_process_fn_grad is None
+                and broadcasts_input is None
+                and name is None
+            ), """
+A SampleInput constructed "naturally" with *args and **kwargs
+cannot specify additional metadata in keyword arguments"""
+
+        self.args = args if args is not None else var_args
+        assert isinstance(self.args, tuple)
+        self.kwargs = kwargs if kwargs is not None else var_kwargs
+        assert isinstance(self.kwargs, dict)
+
+        self.output_process_fn_grad = (
+            output_process_fn_grad
+            if output_process_fn_grad is not None
+            else lambda x: x
+        )
+        self.name = name if name is not None else ""
+
+        # Specifies if `self.input` is broadcasted or not,
+        # given that the operator supports broadcasting.
+        # This field is used to verify the behavior for inplace variant.
+        #
+        # If a SampleInput is marked with `broadcasts_input=True`,
+        # it is verified that we get a `RuntimeError` with this sample,
+        # and inplace variant. Also inplace grad{grad} tests are skipped,
+        # for such inputs (as they will error out otherwise).
+        self.broadcasts_input = (
+            broadcasts_input if broadcasts_input is not None else False
+        )
+
+    def with_metadata(
+        self, *, output_process_fn_grad=None, broadcasts_input=None, name=None
+    ):
+        if output_process_fn_grad is not None:
+            self.output_process_fn_grad = output_process_fn_grad
+        if broadcasts_input is not None:
+            self.broadcasts_input = broadcasts_input
+        if name is not None:
+            self.name = name
+        return self
+
+    def _repr_helper(self, formatter):
+        # Helper function to return the details of the SampleInput as `str`
+        # It consolidates all the fields of SampleInput and allows,
+        # formatting the fields like `input`, `args`, etc with `formatter`
+        # callable to customize the representation.
+        # Look at `summary` method for example.
+        arguments = [
+            f"input={formatter(self.input)}",
+            f"args={formatter(self.args)}",
+            f"kwargs={formatter(self.kwargs)}",
+            f"broadcasts_input={self.broadcasts_input}",
+            f"name={repr(self.name)}",
+        ]
+
+        return f'SampleInput({", ".join(a for a in arguments if a is not None)})'
+
+    def __repr__(self):
+        return self._repr_helper(lambda x: x)
+
+    def summary(self):
+        # Returns the SampleInput details in a more
+        # friendly format.
+        # It formats `Tensor` and `TensorList`
+        # in a more condensed representation.
+        def formatter(arg):
+            # Format any instance of `Tensor` (standalone, in list, or in dict)
+            # by Tensor[TensorShape]
+            # Eg. Tensor with shape (3, 4) is formatted as Tensor[3, 4]
+            if isinstance(arg, torch.Tensor):
+                shape = str(tuple(arg.shape))
+                dtype = str(arg.dtype)
+                device = str(arg.device)
+                contiguity_suffix = ""
+                # NB: sparse CSR tensors annoyingly return is_sparse=False
+                is_sparse = arg.is_sparse or arg.layout == torch.sparse_csr
+                if not is_sparse and not arg.is_contiguous():
+                    contiguity_suffix = ", contiguous=False"
+                return f'Tensor[size={shape}, device="{device}", dtype={dtype}{contiguity_suffix}]'
+            elif isinstance(arg, dict):
+                return {k: formatter(v) for k, v in arg.items()}
+            elif is_iterable_of_tensors(arg):
+                return "TensorList[" + ", ".join(map(formatter, arg)) + "]"
+            elif isinstance(arg, (list, tuple)):  # Handle list, tuple
+                return "(" + ",".join(map(formatter, arg)) + ")"
+
+            return repr(arg)
+
+        return self._repr_helper(formatter)
+
+    # Applies the transform f(t) -> t to each tensor and dtype in the SampleInput
+    def transform(self, f):
+        def tt(t):
+            def _tt(t):
+                with torch.no_grad():
+                    return f(t)
+
+            if isinstance(t, torch.Tensor):
+                return _tt(t)
+            elif isinstance(t, torch.dtype):
+                return _tt(t)
+            elif isinstance(t, list):
+                return list(map(tt, t))
+            elif isinstance(t, tuple):
+                return tuple(map(tt, t))
+            elif isinstance(t, dict):
+                return {k: tt(v) for k, v in t.items()}
+            else:
+                return t
+
+        sample_tt_input, tt_args, tt_kwargs = (
+            tt(self.input),
+            tt(self.args),
+            tt(self.kwargs),
+        )
+
+        # Note the transformed SampleInput assumes metadata like output_process_fn_grad is still valid!
+        return SampleInput(
+            sample_tt_input,
+            args=tt_args,
+            kwargs=tt_kwargs,
+            output_process_fn_grad=self.output_process_fn_grad,
+            broadcasts_input=self.broadcasts_input,
+            name=self.name + "_transformed",
+        )
+
+    # Returns the NumPy version of the sample input object in the form of a tuple: (input, args, kwargs)
+    # Converts tensors to ndarrays by calling .detach().cpu().numpy() on them
+    # Converts dtypes by remapping them using torch_to_numpy_dtype_dict
+    def numpy(self):
+        def to_numpy(t):
+            if isinstance(t, torch.Tensor):
+                if t.dtype is torch.bfloat16:
+                    return t.detach().cpu().to(torch.float32).numpy()
+                if t.dtype is torch.chalf:
+                    return t.detach().cpu().to(torch.cfloat).numpy()
+                return t.detach().cpu().numpy()
+            elif isinstance(t, torch.dtype):
+                return torch_to_numpy_dtype_dict[t]
+
+            return t
+
+        return self.transform(to_numpy)
+
+    def noncontiguous(self):
+        def to_noncontiguous(t):
+            if isinstance(t, torch.Tensor):
+                return noncontiguous_like(t)
+            elif isinstance(t, torch.dtype):
+                return t
+
+            return t
+
+        return self.transform(to_noncontiguous)
+
+
+NumericsFilter = collections.namedtuple("NumericsFilter", ["condition", "safe_val"])
+
+
+class ErrorInput:
+    """
+    A SampleInput that will cause the operation to throw an error plus information
+    about the resulting error.
+    """
+
+    __slots__ = ["sample_input", "error_type", "error_regex"]
+
+    def __init__(self, sample_input, *, error_type=RuntimeError, error_regex):
+        self.sample_input = sample_input
+        self.error_type = error_type
+        self.error_regex = error_regex
+
+
+class AliasInfo:
+    """Class holds alias information. For example, torch.abs ->
+    torch.absolute, torch.Tensor.absolute, torch.Tensor.absolute_
+    """
+
+    def __init__(self, alias_name):
+        self.name = alias_name
+        self.op = _getattr_qual(torch, alias_name)
+        self.method_variant = getattr(torch.Tensor, alias_name, None)
+        self.inplace_variant = getattr(torch.Tensor, alias_name + "_", None)
+
+    def __call__(self, *args, **kwargs):
+        return self.op(*args, **kwargs)
+
+
+# Note [OpInfos]
+# ~~~~~~~~~~~~~~
+#
+# The majority of this note was written shortly after the PyTorch 1.9 release.
+# If you notice it's out-of-date or think it could be improved then please
+# file an issue.
+#
+# See also: the OpInfo tracker (https://github.com/pytorch/pytorch/issues/54261)
+# See also: "Writing Test Templates" in common_device_type.py to learn how to
+#   parametrize a test template using OpInfos.
+# See also: PyTorch's GitHub wiki on running and writing tests
+#   https://github.com/pytorch/pytorch/wiki/Running-and-writing-tests
+# See also: ModuleInfos, OpInfo's sister class, defined in common_modules.py
+#
+# An OpInfo is a collection of metadata related to a PyTorch operator. This
+#   metadata is used to generate tests that validate properties of the operator,
+#   like if it implements the correct gradient formula.
+#
+# WHY OPINFOS?
+# ~~~~~~~~~~~~
+#
+# OpInfos are principally intended to do three things:
+#
+#   1) to allow systematic testing over all PyTorch's operators
+#   2) to simplify operating testing by autogenerating many tests
+#   3) to allow systems (like autograd, torchscript, fx, nnc...) to test
+#        against every PyTorch operator
+#
+# All these goals are still a work in progress. Not every operator has an
+#   OpInfo, and some operator tests that could be automatically generated
+#   still have to be written manually.
+#
+# It's helpful to understand that OpInfos are both about test simplification and
+#   modularity. PyTorch is a complicated framework with many interrelated systems,
+#   too many for any one person to keep track of. An OpInfo can be thought of as the
+#   interface between an operator implementer and those other systems. Instead of
+#   requiring the implementer of torch.foo to understand how to test its forward
+#   mode AD or NNC support, that's typically handled automatically just by
+#   defining an OpInfo.
+#
+# It's often surprising to OpInfo writers that just implementing an OpInfo
+#   typically can't verify an operator is actually implemented correctly:
+#
+# "If an OpInfo doesn't validate my op works as expected, what's the point
+#     of it?"
+#
+# But the point is the above. OpInfos are intended to let you focus on testing
+#   the operator logic you're familiar with instead of having to write tests for
+#   how the operator interacts with each of PyTorch's many systems.
+#
+# And, OK, it turns out that SOMETIMES just writing an OpInfo DOES
+#   validate your op works as expected, but that's only in special
+#   cases. See below for details.
+#
+# WHAT'S AN OPINFO?
+# ~~~~~~~~~~~~~~~~~
+#
+# So what is an OpInfo? It's a Python class that describes an operator's properties,
+#   like which dtypes it supports on the CPU and whether it has any aliases.
+#   These properties can be divided into three categories:
+#
+#   1) Metadata describing the operator, like the operator's name and if it
+#     "supports" the out kwarg.
+#   2) Test directives, like "skips" that tell the test suite to skip some
+#     tests.
+#   3) A "sample inputs" function that generates valid inputs for the operator.
+#
+# OpInfo attributes are described in more detail below.
+#
+# THE SAMPLE INPUTS FUNCTION
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# The "sample inputs" function merits special elaboration. This function is
+#   crucial to testing with OpInfos. A typical OpInfo test has to treat the operator
+#   as a black box. There's no structure for the test to understand or exploit.
+#   Without "sample inputs" it wouldn't even know how to call the OpInfo's
+#   operator. The sample input function saves the day by providing different
+#   "SampleInputs" that can be used to call the operator. A sample input
+#   function should have the following signature:
+#
+#   def sample_inputs_foo(op_info, device, dtype, requires_grad, **kwargs):
+#
+#   And should return an iterable of SampleInputs (see the class description
+#   above). Each SampleInput defines an "input", "args", "kwargs", an
+#   "output_process_fn_grad" function, the "broadcasts_input" bool and a
+#   "name".
+#
+#   All the "sample_inputs" functions are invoked within a `torch.no_grad()`
+#   environment for efficiency and correctness. As such remember to set the
+#   "requires_grad" flag on the inputs **after** performing any transformations
+#   on them.
+#
+# The "input" is the first argument to the operator, or the tensor that
+#   the method or inplace variants of the operator should be called on, and
+#   should be on the requested device, of the requested dtype, and its
+#   requires_grad attribute should be set to the requires_grad argument.
+#
+# "args" should contain positional arguments, and "kwargs" keyword arguments.
+#
+# "output_process_fn_grad" has an interesting name. It's a function that maps
+#   the operator's output (when given the input, args, and kwargs) to the
+#   portion of the output to gradcheck. For example, consider an operator
+#   like torch.linalg.slogdet
+#   (https://pytorch.org/docs/main/generated/torch.linalg.slogdet.html).
+#   This operator returns a tuple of two tensors, but the first tensor
+#   cannot be backwarded through. Its "output_process_fn_grad" filters
+#   this output tuple to just the second argument, which we can call backward
+#   on. Functions that produce a single tensor can ignore this argument.
+#
+# "broadcasts_input" is a bool indicating whether the SampleInput causes the operator
+#   to broadcast the "input" argument. This is important for tests to understand
+#   because inplace variants of operations throw a runtime error if they
+#   would broadcast their input arguments, so tests that work with inplace
+#   variants filter SampleInputs that broadcast their input.
+#
+# "name" is a string that's just used for debugging. It appears when printing
+#   the SampleInput.
+#
+# Sample inputs are designed to be used with many tests, some
+#   that are very time consuming, so they should be a small
+#   set with small tensors. An elaborated set of sample inputs
+#   can be specified using the "reference_inputs_func" attribute.
+#   The "reference inputs" for an operation are an extended
+#   set of sample inputs that can more exhaustively test an
+#   operator. They are used by only a few tests that are careful
+#   not to take too long to run. Adding reference inputs
+#   is highly encouraged!
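+#
+# As an illustration only (a hedged sketch, not an OpInfo defined here; the op
+#   name is invented), a sample inputs function for an elementwise binary op
+#   could look like:
+#
+#     def sample_inputs_hypothetical_add(op_info, device, dtype, requires_grad, **kwargs):
+#         make_arg = partial(make_tensor, device=device, dtype=dtype,
+#                            requires_grad=requires_grad)
+#         yield SampleInput(make_arg((S, S)), args=(make_arg((S, S)),))
+#         # a broadcasting sample; inplace variants are expected to error on it
+#         yield SampleInput(make_arg((1, S)), args=(make_arg((S, S)),),
+#                           broadcasts_input=True)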
+#
+# THE (OPTIONAL) ERROR INPUTS FUNCTION
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# OpInfos may optionally specify "error inputs" through an error function. If
+#   specified, test_errors in test_ops.py will call the op with these inputs
+#   and validate that the desired error is thrown.
+#
+# Error inputs automate a common testing pattern where multiple inputs are
+#   passed to an operation and the errors they throw are reviewed. Tests
+#   written in this style should be ported to the new OpInfo pattern.
+#
+# Error inputs are specified using the ErrorInput class, which contains
+#   a SampleInput (see above) and data about the expected error.
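+#
+# As an illustration only (a hedged sketch with invented names and messages), an
+#   error inputs function could look like:
+#
+#     def error_inputs_hypothetical_op(op_info, device, **kwargs):
+#         t = make_tensor((S,), device=device, dtype=torch.float32)
+#         yield ErrorInput(SampleInput(t, args=(-1,)),
+#                          error_type=ValueError,
+#                          error_regex="must be non-negative")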
+#
+# OPINFO FILE ORGANIZATION
+# ~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# All OpInfos are currently defined in this file. Most OpInfo tests are defined
+#   in test_ops.py, but some system-specific tests are defined in those
+#   systems' test files, and subclass-specific tests are defined in the test
+#   file that corresponds to that subclass (see below).
+#   Expect a reorganization in the future.
+#
+# WHAT'S TESTED?
+# ~~~~~~~~~~~~~~
+#
+# Every OpInfo in the op_db sequence has the following properties validated in
+# test_ops.py:
+#
+#   - that its supported dtypes are specified correctly
+#   - that the operation produces the same results when called with noncontiguous inputs
+#   - that it supports the out= argument properly (if it allows out=),
+#       see https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-does-out-work-in-pytorch
+#   - that it works with the conjugate view bit properly
+#   - that its function, method, and inplace variants perform the same operation
+#       (that is, that torch.add, torch.Tensor.add, and torch.Tensor.add_ all
+#       do the same thing).
+#   - that its inplace variant preserves the input's storage
+#   - that its gradient formula is implemented correctly, and that it supports
+#       gradgrad and complex grad and gradgrad and forward mode AD properly for
+#       the op's function and inplace variants (method variants are skipped
+#       to reduce test time).
+#   - that the operation performs the same operation when traced or scripted
+#       using the jit
+#   - that the operation is autodifferentiated by the jit as expected
+#   - that the operator's aliases, if any, perform the same operation and that
+#       the jit understands the alias
+#   - that the operator throws the correct errors (if error_inputs is defined)
+#   - that the operator produces the same results as a NumPy reference (if ref is defined)
+#   - that the operator produces the same results as a NumPy reference on an extended
+#       set of "reference inputs" (if both ref and reference_inputs_func are defined)
+#       (NOTE: elementwise unary and elementwise binary OpInfos do this even if only
+#         ref is defined, because they effectively autogenerate reference inputs)
+#   - that the operator works on different CUDA devices
+#
+# Additional OpInfo tests are in test_jit_fuser_te.py, test_fx_experimental.py,
+#   and test_fx.py. These tests validate that operators work with NNC and FX
+#   as expected.
+#
+# For performance, some of the above tests may only run on the first
+#   SampleInput returned by an OpInfo's sample input function.
+#
+# In addition to these tests, some subclasses (discussed in the next section)
+#   define additional tests.
+#
+# Critically, as mentioned above, what's not necessarily tested is that the operator
+#   works as expected. When implementing an OpInfo an engineer must still
+#   typically write one or more tests validating the operator's behavior.
+#   The exception to this is if reference testing is sufficient, or if
+#   the operation belongs to an OpInfo subclass that has more exhaustive
+#   operator testing. Elementwise unary and elementwise binary operators,
+#   in particular, usually don't require additional testing beyond
+#   writing an OpInfo.
+#
+#
+# OPINFO (SUB)CLASSES
+# ~~~~~~~~~~~~~~~~~~~
+#
+# In addition to the OpInfo base class there are several specialized OpInfo
+#   subclasses. For example, the UnaryUfuncInfo subclass is used for
+#   unary elementwise operations. These operations have a common structure
+#   that test_unary_ufuncs.py exploits with additional automated testing.
+#   The automated testing in test_unary_ufuncs.py is so thorough, comparing
+#   the operator to a NumPy reference function on a plethora of values, that
+#   just implementing an OpInfo for a unary elementwise operation is often
+#   sufficient testing.
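+#
+#   For example, a unary elementwise entry can be as small as the following
+#   (hypothetical name and NumPy reference; the dtype helper is assumed):
+#
+#     UnaryUfuncInfo("foo", ref=np.foo, dtypes=all_types_and(torch.bool))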
+#
+# The ForeachFuncInfo is another OpInfo subclass that is hyper-specialized to a
+#   particular class of operations. These OpInfos aren't included in the
+#   op_db sequence and have their own tests.
+#
+# Other OpInfo subclasses, like SpectralFuncInfo, are just for convenience
+# when writing OpInfos.
+#
+# TESTING A NEW OPERATOR
+# ~~~~~~~~~~~~~~~~~~~~~~
+#
+# If you're adding a new operator to any of the following namespaces:
+#   - torch
+#   - torch.fft
+#   - torch.linalg
+#   - torch.special
+#   - torch.nn.functional
+# then you should typically add an OpInfo for it.
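+#
+# A minimal sketch of what such an entry involves (the operator name, shapes,
+#   and dtypes below are hypothetical, for illustration only): a sample inputs
+#   function plus an OpInfo placed in the op_db sequence:
+#
+#     def sample_inputs_foo(op_info, device, dtype, requires_grad, **kwargs):
+#         make_arg = partial(
+#             make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+#         )
+#         yield SampleInput(make_arg((3, 3)))
+#
+#     OpInfo(
+#         "foo",
+#         dtypes=floating_types(),  # assumed dtype helper; list what the op supports
+#         sample_inputs_func=sample_inputs_foo,
+#     ),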
+#
+# As mentioned a couple times above, implementing an OpInfo is not
+#   usually sufficient testing (unless the operator is a unary or binary elementwise
+#   operator). The OpInfo will only test the properties described in the
+#   "WHAT'S TESTED" section. It DOES NOT necessarily verify that the operator is
+#   implemented correctly.
+#
+# TIPS FOR WRITING AN OPINFO AND OPINFO TESTS
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# Writing an OpInfo can be a little daunting. Since the point of an OpInfo is to
+#   be consumed by a variety of systems it can be hard to understand how to
+#   deal with test failures or how to set the OpInfo metadata properly.
+#
+# Before adding an OpInfo it helps to look at other OpInfos. A sample inputs
+#   function must be defined, and the operator's dtypes must be specified.
+#   Once that's done you should run the operator's tests in test_ops.py
+#   (these can be filtered using the "-k" argument in pytest). Tests that
+#   fail should provide an error message that describes what to change about
+#   your OpInfo. You don't need to worry about changing an OpInfo's default
+#   values unless a test yells at you.
+#
+# Similarly, if you're writing a test that consumes OpInfos then it's critical
+#   your test provides a clear error message describing what to do when it
+#   fails. You should not assume the OpInfo implementer is familiar with your
+#   system.
+#
+# If you see a confusing error message while developing an OpInfo then please
+#   file an issue describing what happened.
+#
+# This trial-and-error approach to writing an OpInfo can be frustrating,
+#   but it's probably necessary as long as OpInfos don't require
+#   learning about all the systems that consume them. One thing that can help
+#   is the get_supported_dtypes() function defined in utils.py. This
+#   function can be used to programmatically specify the dtypes an operator
+#   supports, and is especially useful if writing an OpInfo on a machine
+#   without a CUDA device. See its documentation for more details.
+#
+# THE FUTURE OF OPINFOS AND OPINFO TESTING
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# In the future we expect OpInfo coverage to improve and cover
+#   the great majority of PyTorch's (public) operators.
+#
+
+
+# Classes and methods for the operator database
+@dataclass
+class OpInfo:
+    """Operator information and helper functions for acquiring it."""
+
+    # the string name of the function
+    name: str
+
+    # An optional reference function that accepts ndarrays (AKA "NumPy arrays").
+    # If given, the op will be compared with its reference on each of its sample inputs.
+    ref: Optional[Callable] = None
+
+    # the following metadata describes the operator, its variants, and its aliases, if any
+
+    # iterable of aliases, e.g. ("absolute",) for torch.abs
+    aliases: Iterable = None
+
+    # additional string to include in the test name
+    # this is useful when an op needs multiple OpInfos,
+    # like divide does, often because it's really several
+    # different ops behind the scenes
+    variant_test_name: str = ""
+
+    # the function variant of the operation, populated as torch.<name> if None
+    op: Callable = None
+
+    # allows the method variant of this operation to be specified as follows:
+    # - if _NOTHING (default), then the OpInfo attempts to discover the variant using its name
+    # - if None, then the OpInfo explicitly specifies it has no associated method
+    # - if a Callable, then that callable should be the method associated with this operation
+    method_variant: Callable = _NOTHING
+
+    # allows the inplace variant of this operation to be specified as follows:
+    # - if _NOTHING (default), then the OpInfo attempts to discover the variant using its name
+    # - if None, then the OpInfo explicitly specifies it has no associated inplace variant
+    # - if a Callable, then that callable should be the inplace variant associated with this operation
+    inplace_variant: Callable = _NOTHING
+
+    # allows the operator variant of this operation to be specified as follows:
+    # - if _NOTHING (default), then the OpInfo attempts to discover the variant using its name
+    # - if None, then the OpInfo explicitly specifies it has no associated operator
+    # - if a Callable, then that callable should be the operator associated with this operation
+    operator_variant: Callable = _NOTHING
+
+    # allows the inplace operator variant of this operation to be specified as follows:
+    # - if _NOTHING (default), then the OpInfo attempts to discover the variant using its name
+    # - if None, then the OpInfo explicitly specifies it has no associated inplace operator
+    # - if a Callable, then that callable should be the inplace operator associated with this operation
+    inplace_operator_variant: Callable = _NOTHING
+
+    # the following metadata are test directives for skipping or modifying tests
+
+    # information about which tests to skip
+    skips: tuple = ()
+
+    # decorators to apply to generated tests
+    decorators: tuple = ()
+
+    # the following are pointers to functions to generate certain classes of inputs
+
+    # function to generate sample inputs with strided layouts
+    sample_inputs_func: Callable = None
+
+    # function to generate a more thorough set of sample inputs with strided layouts
+    reference_inputs_func: Callable = None
+
+    # function to generate inputs that will throw errors
+    error_inputs_func: Callable = None
+
+    # function to generate sparse (coo, csr, csc, bsr, bsc) inputs that will throw errors
+    error_inputs_sparse_func: Callable = None
+
+    # function to generate sample inputs with sparse coo layouts
+    sample_inputs_sparse_coo_func: Callable = None
+
+    # function to generate sample inputs with sparse csr layouts
+    sample_inputs_sparse_csr_func: Callable = None
+
+    # function to generate sample inputs with sparse csc layouts
+    sample_inputs_sparse_csc_func: Callable = None
+
+    # function to generate sample inputs with sparse bsr layouts
+    sample_inputs_sparse_bsr_func: Callable = None
+
+    # function to generate sample inputs with sparse bsc layouts
+    sample_inputs_sparse_bsc_func: Callable = None
+
+    # the following metadata relates to dtype support and is tested for correctness in test_ops.py
+
+    # dtypes this function works with on the CPU,
+    # inherited by other device types that don't specify their own dtypes
+    dtypes: _dispatch_dtypes = None
+
+    # the following dtypesIf... options override the dtypes value on their respective device types
+    # i.e., instead of writing multiple `dtypesIfCUDA`, `dtypesIfROCM`, etc., one can simply define a dict
+    # dtypesIf = { 'cuda': (torch.float, torch.double), 'rocm': (torch.half, torch.bfloat16) }
+    dtypesIf: dict[str, _dispatch_dtypes] = field(default_factory=dict)
+
+    def __getattribute__(self, name: str) -> Any:
+        if name.startswith("dtypesIf") and name != "dtypesIf":
+            # TODO: Warn if used
+            dev_name = name.removeprefix("dtypesIf").lower()
+            return self.dtypesIf.get(dev_name)
+        return super().__getattribute__(name)
+
+    def __setattr__(self, name: str, value: Any) -> None:
+        # TODO: After migration, start adding warnings here
+        if name.startswith("dtypesIf") and name != "dtypesIf":
+            assert isinstance(value, (_dispatch_dtypes, type(None)))
+            dev_name = name.removeprefix("dtypesIf").lower()
+            self.dtypesIf[dev_name] = value
+            return
+        super().__setattr__(name, value)
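+
+    # For example, reading `dtypesIfCUDA` returns `dtypesIf.get("cuda")` via
+    # __getattribute__ above, and assigning to `dtypesIfCUDA` stores into
+    # `dtypesIf["cuda"]` via __setattr__; the per-device fields below remain the
+    # documented entry points for those overrides.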
+
+    # dtypes this function is expected to work with on CUDA
+    dtypesIfCUDA: _dispatch_dtypes = None
+
+    # dtypes this function is expected to work with on ROCM
+    dtypesIfROCM: _dispatch_dtypes = None
+
+    dtypesIfHpu: _dispatch_dtypes = None
+
+    # dtypes this function is expected to work with on XPU
+    dtypesIfXPU: _dispatch_dtypes = None
+
+    # backward dtypes this function is expected to work with
+    backward_dtypes: _dispatch_dtypes = None
+
+    # backward dtypes this function is expected to work with on CUDA
+    backward_dtypesIfCUDA: _dispatch_dtypes = None
+
+    # backward dtypes this function is expected to work with on ROCM
+    backward_dtypesIfROCM: _dispatch_dtypes = None
+
+    backward_dtypesIfHpu: _dispatch_dtypes = None
+
+    # the following metadata describes the operators out= support
+
+    # whether the op supports the out kwarg
+    # defaults to True; if the op does not allow the out kwarg or
+    # supports it incorrectly, then test_out in test_ops.py should fail
+    supports_out: bool = True
+
+    # the following metadata relates to autograd support
+    # whether the operation supports backward mode AD
+    # if true, gradient correctness is tested in test_ops.py
+    # using the op's sample inputs
+    supports_autograd: bool = True
+
+    # whether the op supports second order gradients
+    # if true, gradgrad correctness is tested in test_ops.py
+    # defaults to supports_autograd's value
+    # TODO: rename this to supports_bwgrad_bwgrad to be consistent with below
+    supports_gradgrad: bool = None
+
+    # whether the op supports second order gradients via
+    # forward-over-reverse. If True, forward-over-reverse gradgrad correctness
+    # is tested. If False, test that forward grad is not implemented.
+    # Defaults to False.
+    supports_fwgrad_bwgrad: bool = False
+
+    # whether the operation supports inplace autograd
+    # if true, tested in test_ops.py
+    # defaults to supports_autograd's value
+    supports_inplace_autograd: bool = None
+
+    # Whether the operation supports forward mode AD
+    # If the value is True, we check that the gradients are correct
+    # If the value is False, we test that forward grad is not implemented
+    supports_forward_ad: bool = False
+
+    # Whether the operation has a varargs variant
+    # (e.g. functions like ones, zeros, methods like view, permute)
+    supports_varargs: bool = False
+
+    # Whether the forward operation avoids materializing COW tensor inputs
+    supports_cow_input_no_materialize_forward: bool = True
+
+    # Whether the backward operation avoids materializing COW tensor inputs
+    supports_cow_input_no_materialize_backward: bool = True
+
+    # Whether to skip the backward part of the COW tensor input test
+    skip_cow_input_backward: bool = False
+
+    # If `supports_cow_input_no_materialize_forward == True`, this list contains
+    # the arg indices or kwarg names of inputs that are expected to materialize
+    allow_cow_input_materialize_forward: list[Union[int, str]] = None
+
+    # If `supports_cow_input_no_materialize_backward == True`, this list contains
+    # the arg indices or kwarg names of inputs that are expected to materialize
+    allow_cow_input_materialize_backward: list[Union[int, str]] = None
+
+    # wrapper function for gradcheck
+    gradcheck_wrapper: Callable = lambda op, *args, **kwargs: op(*args, **kwargs)
+
+    # whether to check batched grad when doing gradcheck
+    # defaults to supports_autograd's value
+    check_batched_grad: bool = None
+
+    # whether to check batched grad grad when doing gradgradcheck
+    # defaults to supports_gradgrad's value
+    check_batched_gradgrad: bool = None
+
+    # whether to check batched forward grad when doing gradcheck
+    # defaults to the value of `supports_forward_ad`
+    check_batched_forward_grad: bool = None
+
+    # whether to check batched forward grad when doing gradcheck
+    # defaults to the value of `check_batched_forward_grad`
+    check_inplace_batched_forward_grad: bool = None
+
+    # tolerance for nondeterminism while performing gradcheck
+    gradcheck_nondet_tol: float = 0.0
+
+    # Whether to use the fast implementation for gradcheck/gradgradcheck.
+    # When set to None, defers to the default value provided by the wrapper
+    # function around gradcheck (testing._internal.common_utils.gradcheck)
+    gradcheck_fast_mode: bool = None
+
+    # the following metadata relates to JIT support and is tested for correctness in test_ops.py
+
+    # name of the corresponding aten:: operator
+    aten_name: str = None
+
+    # if this is a composite implicit autograd op, the decomposed op
+    decomp_aten_name: Optional[str] = None
+
+    # name of the corresponding aten:: operator for backwards
+    aten_backward_name: Optional[str] = None
+
+    # if an op's aten::node is expected to be symbolically autodiffed
+    assert_autodiffed: bool = False
+
+    # a list of strings with node names that are expected to be in a
+    # DifferentiableGraph when autodiffed. Ex: ['aten::add', 'aten::mm'],
+    # default is populated to be ['aten::(name of Python operator)']
+    autodiff_nonfusible_nodes: list[str] = None
+
+    # a list of strings with node names that are expected to be in FusionGroups
+    # inside of DifferentiableGraphs when this operation is autodiffed.
+    # Ex: ['aten::add', 'aten::mm'], defaults to an empty list
+    # Note: currently no ops use fusible nodes
+    autodiff_fusible_nodes: list[str] = None
+
+    # the following metadata relates to sparse support and is used in test_sparse.py
+
+    # whether the op supports sparse coo inputs, defaults to False
+    # TODO: rename supports_sparse to supports_sparse_coo
+    supports_sparse: bool = None
+
+    # only run tracing tests
+    supports_scripting: bool = True
+
+    # if the operator can be traced
+    supports_tracing: bool = True
+
+    # the following metadata relates to sparse compressed support and
+    # is used in test_sparse_csr.py and test_sparse.py
+
+    # whether the op supports sparse csr inputs, defaults to False
+    supports_sparse_csr: bool = None
+    # whether the op supports sparse csc inputs, defaults to False
+    supports_sparse_csc: bool = None
+    # whether the op supports sparse bsr inputs, defaults to False
+    supports_sparse_bsr: bool = None
+    # whether the op supports sparse bsc inputs, defaults to False
+    supports_sparse_bsc: bool = None
+    # whether the op supports nested jagged inputs, defaults to False
+    supports_njt: bool = None
+
+    # whether the op promotes integer inputs to float
+    promotes_int_to_float: bool = False
+
+    # the following metadata relates to complex support and is checked in test_ops.py
+
+    test_conjugated_samples: bool = True
+
+    test_neg_view: bool = True
+
+    # assert that jit shape analysis fully propagates shape
+    assert_jit_shape_analysis: bool = False
+
+    # the following metadata relates to ExpandedWeights support and is checked in test_expanded_weights.py
+
+    supports_expanded_weight: bool = False
+
+    is_factory_function: bool = False
+
+    skip_correctness_check_compile_vs_eager: bool = False
+
+    def __post_init__(self):
+        self._original_opinfo_args = asdict(self).copy()
+
+        assert self.dtypes is not None, f"OpInfo for {self.name} has no dtypes!"
+
+        # Validates the dtypes are generated from the dispatch-related functions
+        for name, val in self.dtypesIf.items():
+            if val is not None:
+                assert isinstance(val, _dispatch_dtypes)
+                self.dtypesIf[name] = set(val)
+
+        if self.aten_name is None:
+            self.aten_name = self.name
+
+        # Attribute to verify dynamic_dtypes are used.
+        self.dynamic_dtypes = any(
+            isinstance(dtypes, utils._dynamic_dispatch_dtypes)
+            for dtypes in self.dtypesIf.values()
+        )
+
+        if self.dynamic_dtypes:
+            # Make sure `dtypesIfCUDA` is dynamic, if dynamic dispatch is used for CPU
+            # This is because, below we set dtypesIfCUDA to dtypes if they are None.
+            assert isinstance(self.dtypesIfCUDA, utils._dynamic_dispatch_dtypes), (
+                f"To use dynamic dtypes for operator {self.name}, "
+                "acquire the dtypes dynamically for argument `dtypesIfCUDA`."
+                "This is to ensure that CUDA dtypes are acquired correctly as they"
+                "differ from CPU dtypes occasionally"
+            )
+
+        self.dtypes = set(self.dtypes)
+
+        # NOTE: backward dtypes must be acquired before forward dtypes
+        #   since they fall back to explicit (not implicit!) specifications of
+        #   forward dtypes
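+        # For example, backward_dtypesIfCUDA falls back to backward_dtypes, then to
+        #   dtypesIfCUDA, then to dtypes; the ROCm variant additionally consults the
+        #   ROCm-specific specifications first.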
+        self.backward_dtypesIfROCM = (
+            set(self.backward_dtypesIfROCM)
+            if self.backward_dtypesIfROCM is not None
+            else (
+                self.backward_dtypesIfCUDA
+                if self.backward_dtypesIfCUDA is not None
+                else self.backward_dtypes
+                if self.backward_dtypes is not None
+                else self.dtypesIfROCM
+                if self.dtypesIfROCM is not None
+                else self.dtypesIfCUDA
+                if self.dtypesIfCUDA is not None
+                else self.dtypes
+            )
+        )
+        self.backward_dtypesIfCUDA = (
+            set(self.backward_dtypesIfCUDA)
+            if self.backward_dtypesIfCUDA is not None
+            else (
+                self.backward_dtypes
+                if self.backward_dtypes is not None
+                else self.dtypesIfCUDA
+                if self.dtypesIfCUDA is not None
+                else self.dtypes
+            )
+        )
+        self.backward_dtypesIfHpu = (
+            set(self.backward_dtypesIfHpu)
+            if self.backward_dtypesIfHpu is not None
+            else (
+                self.backward_dtypes
+                if self.backward_dtypes is not None
+                else self.dtypes
+            )
+        )
+
+        self.backward_dtypes = (
+            set(self.backward_dtypes)
+            if self.backward_dtypes is not None
+            else self.dtypes
+        )
+
+        # Inherit from CPU
+        for dev_type in ["cuda", "hpu"]:
+            if self.dtypesIf.get(dev_type) is None:
+                self.dtypesIf[dev_type] = self.dtypes
+
+        # Inherit from CUDA
+        for dev_type in ["rocm", "xpu"]:
+            if self.dtypesIf.get(dev_type) is None:
+                self.dtypesIf[dev_type] = self.dtypesIf["cuda"]
+
+        # NOTE: if the op is unspecified it is assumed to be under the torch namespace
+        if not self.op:
+            self.op = _getattr_qual(torch, self.name)
+
+        if self.method_variant is _NOTHING:
+            self.method_variant = getattr(torch.Tensor, self.name, None)
+
+        # attributes like real, imag are not callable
+        if not callable(self.method_variant):
+            self.method_variant = None
+
+        if self.inplace_variant is _NOTHING:
+            inplace_name = self.name + "_"
+            self.inplace_variant = getattr(torch.Tensor, inplace_name, None)
+
+        if self.operator_variant is _NOTHING:
+            self.operator_variant = getattr(operator, self.name, None)
+
+        if self.inplace_operator_variant is _NOTHING:
+            # Note: operator.i<op> will use operator.<op> and assign the result to the lhs when no
+            # __i<op>__ method is found. This results in the appearance of an inplace operator variant which
+            # does not have the correct inplace behavior. To avoid this, we guard automatic detection of the inplace
+            # operator with a check that an inplace variant exists.
+            if self.inplace_variant is not None:
+                inplace_operator_name = "i" + self.name
+                self.inplace_operator_variant = getattr(
+                    operator, inplace_operator_name, None
+                )
+            else:
+                self.inplace_operator_variant = None
+
+        self.decorators = (*self.decorators, *self.skips)
+
+        # Specifying sample inputs function without specifying the
+        # corresponding layout support implies the layout support:
+        if self.supports_sparse is None:
+            self.supports_sparse = self.sample_inputs_sparse_coo_func is not None
+        if self.sample_inputs_sparse_coo_func is None:
+            self.sample_inputs_sparse_coo_func = self._sample_inputs_unspecified
+
+        if self.supports_sparse_csr is None:
+            self.supports_sparse_csr = self.sample_inputs_sparse_csr_func is not None
+        if self.sample_inputs_sparse_csr_func is None:
+            self.sample_inputs_sparse_csr_func = self._sample_inputs_unspecified
+
+        if self.supports_sparse_csc is None:
+            self.supports_sparse_csc = self.sample_inputs_sparse_csc_func is not None
+        if self.sample_inputs_sparse_csc_func is None:
+            self.sample_inputs_sparse_csc_func = self._sample_inputs_unspecified
+
+        if self.supports_sparse_bsr is None:
+            self.supports_sparse_bsr = self.sample_inputs_sparse_bsr_func is not None
+        if self.sample_inputs_sparse_bsr_func is None:
+            self.sample_inputs_sparse_bsr_func = self._sample_inputs_unspecified
+
+        if self.supports_sparse_bsc is None:
+            self.supports_sparse_bsc = self.sample_inputs_sparse_bsc_func is not None
+        if self.sample_inputs_sparse_bsc_func is None:
+            self.sample_inputs_sparse_bsc_func = self._sample_inputs_unspecified
+
+        if self.supports_njt is None:
+            self.supports_njt = False
+
+        # We run the sampling functions without tracking the gradients of the creation of inputs
+        self.sample_inputs_func = torch.no_grad()(self.sample_inputs_func)
+        self.sample_inputs_sparse_coo_func = torch.no_grad()(
+            self.sample_inputs_sparse_coo_func
+        )
+        self.sample_inputs_sparse_csr_func = torch.no_grad()(
+            self.sample_inputs_sparse_csr_func
+        )
+        self.sample_inputs_sparse_csc_func = torch.no_grad()(
+            self.sample_inputs_sparse_csc_func
+        )
+        self.sample_inputs_sparse_bsr_func = torch.no_grad()(
+            self.sample_inputs_sparse_bsr_func
+        )
+        self.sample_inputs_sparse_bsc_func = torch.no_grad()(
+            self.sample_inputs_sparse_bsc_func
+        )
+        if self.reference_inputs_func is not None:
+            self.reference_inputs_func = torch.no_grad()(self.reference_inputs_func)
+
+        if not self.autodiff_fusible_nodes:
+            self.autodiff_fusible_nodes = []
+
+        if self.autodiff_nonfusible_nodes is None:
+            self.autodiff_nonfusible_nodes = ["aten::" + self.name]
+
+        # Autograd support
+
+        # Autograd flags that depend on backward AD only
+        # - If setting has been explicitly set, raise error if inconsistent
+        if self.supports_gradgrad is None:
+            self.supports_gradgrad = self.supports_autograd
+        else:
+            assert not (self.supports_gradgrad and not self.supports_autograd), (
+                "supports_gradgrad refines the part of autograd is supported, so it should "
+                "not be set if supports_autograd is False"
+            )
+        if self.check_batched_grad is None:
+            self.check_batched_grad = self.supports_autograd or self.supports_forward_ad
+        else:
+            assert not (
+                self.check_batched_grad
+                and not (self.supports_autograd or self.supports_forward_ad)
+            ), (
+                "check_batched_grad refines the part of autograd that will be checked (by gradcheck), so "
+                "it should not be set if supports_autograd is False"
+            )
+        if self.check_batched_gradgrad is None:
+            self.check_batched_gradgrad = self.supports_gradgrad
+        else:
+            assert not (self.check_batched_gradgrad and not self.supports_gradgrad), (
+                "check_batched_gradgrad refines the part of autograd that will be checked (by "
+                "gradgradcheck), so it should not be set if either supports_gradgrad or supports_autograd "
+                "is False."
+            )
+        if self.check_batched_forward_grad is None:
+            self.check_batched_forward_grad = self.supports_forward_ad
+        else:
+            assert not (
+                self.check_batched_forward_grad and not self.supports_forward_ad
+            ), (
+                "check_batched_forward_grad should only be used when supports_forward_ad "
+                "is True. It is used to disable the test in the specific cases "
+                "where the op supports forward ad but fails to compute "
+                "batched forward grad."
+            )
+
+        if self.check_inplace_batched_forward_grad is None:
+            self.check_inplace_batched_forward_grad = self.check_batched_forward_grad
+        else:
+            assert not (
+                self.check_inplace_batched_forward_grad
+                and not self.check_batched_forward_grad
+            ), (
+                "check_batched_forward_grad should only be used when check_batched_forward_grad "
+                "is True. It is used to disable the test in the specific cases "
+                "where the op supports batched forward grad but fails to compute batched forward "
+                "grad for the inplace variant of the op."
+            )
+
+        assert not (self.supports_fwgrad_bwgrad and not self.supports_autograd), (
+            "supports_fwgrad_bwgrad enables forward-over-backward gradgrad checks and should only be "
+            "True if backward ad is also checked, i.e., supports_forward_ad should be True.",
+            self.name,
+        )
+
+        # Autograd flags that depend on both forward AD and backward AD
+        if self.supports_inplace_autograd is None:
+            self.supports_inplace_autograd = (
+                self.supports_autograd or self.supports_forward_ad
+            )
+        else:
+            assert not (
+                self.supports_inplace_autograd
+                and not self.supports_autograd
+                and not self.supports_forward_ad
+            ), (
+                "supports_inplace_autograd refines the part of autograd that is supported, so "
+                "it should not be set if both supports_autograd and supports_forward_ad are False"
+            )
+
+        if self.aliases is not None:
+            self.aliases = tuple(AliasInfo(a) for a in self.aliases)  # type: ignore[assignment]
+        else:
+            self.aliases = ()
+
+    def __call__(self, *args, **kwargs):
+        """Calls the function variant of the operator."""
+        return self.op(*args, **kwargs)
+
+    def __str__(self):
+        return dataclass_repr(self)
+
+    def get_op(self):
+        """Returns the function variant of the operator, torch.."""
+        return self.op
+
+    def get_method(self):
+        """Returns the method variant of the operator, torch.Tensor..
+        Returns None if the operator has no method variant.
+        """
+        return self.method_variant
+
+    def get_inplace(self):
+        """Returns the inplace variant of the operator, torch.Tensor._.
+        Returns None if the operator has no inplace variant.
+        """
+        return self.inplace_variant
+
+    def get_operator(self):
+        """Returns operator variant of the operator, e.g. operator.neg
+        Returns None if the operator has no operator variant.
+        """
+        return self.operator_variant
+
+    def get_inplace_operator(self):
+        """Returns the inplace operator variant of the operator, e.g operator.iadd
+        Returns None if the operator has no inplace operator variant"""
+        return self.inplace_operator_variant
+
+    # Returns a tuple of callables:
+    # (TestCase -> subtest context, TestCase -> skip / xfail context)
+    # I'd love to combine these into one but I haven't figured out how to do it
+    # in a way that works like it should, and I tried a LOT of things.
+    def _maybe_skip_or_xfail(self, rules, device, sample, idx):
+        def _subtest_fn(test_case, sample=sample.name, idx=idx):
+            return test_case.subTest(sample=sample, idx=idx)
+
+        if rules is None or len(rules) == 0:
+            return (_subtest_fn, lambda _: contextlib.nullcontext())
+
+        # NB: match first rule only (order matters!)
+        for rule in rules:
+            if rule.sample_match_fn(device, sample):
+                log.debug(
+                    "matched %s rule '%s': %s %s %s",
+                    rule.type,
+                    rule.name,
+                    self.full_name,
+                    device,
+                    sample,
+                )
+
+                # Provide a context for the test case to run the sample input
+                # through as a subtest AND handle skip / xfail for it as needed.
+                return (
+                    _subtest_fn,
+                    lambda test_case, rule=rule: rule.get_context(test_case),
+                )
+
+        log.debug("matched no rules: %s %s %s", self.full_name, device, sample)
+        return (_subtest_fn, lambda _: contextlib.nullcontext())
+
+    def _sample_callback_fn(self, use_subtests, device):
+        # Get sample-specific skips / xfails.
+        sample_skips_and_xfails = getattr(
+            extract_test_fn(), "sample_skips_and_xfails", None
+        )
+
+        if sample_skips_and_xfails is not None and not use_subtests:
+            raise RuntimeError(
+                """Sample-specific skips / xfails require use_subtests=True.
+Please pass this to the sample generation function and run the test logic within the
+returned contexts (NB: order matters!). For example:
+
+def test_foo(self, device, dtype, op):
+    for sample, subtest_ctx, skip_xfail_ctx in op.sample_inputs(..., use_subtests=True):
+        # these contexts handle running within subtests and skips / xfails
+        with subtest_ctx(self), skip_xfail_ctx(self):
+            # test logic here
+            ..."""
+            )
+
+        if not use_subtests:
+            # use the default callback that returns the sample without a subtest context
+            return None
+
+        if USE_PYTEST:
+            try:
+                import pytest_subtests  # noqa: F401
+            except ModuleNotFoundError:
+                raise RuntimeError(
+                    "Encountered an OpInfo test with use_subtests=True and pytest-subtests is "
+                    "not installed. The feature will not work correctly within pytest without "
+                    "this package; please install it."
+                ) from None
+
+        def _f(
+            sample,
+            idx,
+            self=self,
+            device=device,
+            sample_skips_and_xfails=sample_skips_and_xfails,
+            use_subtests=use_subtests,
+        ):
+            # When subtests are enabled, also return a subtest context. This is required
+            # for xfails / skips to work properly.
+            return (
+                sample,
+                *self._maybe_skip_or_xfail(
+                    sample_skips_and_xfails, device, sample, idx
+                ),
+            )
+
+        return _f
+
+    def conjugate_sample_inputs(self, device, dtype, requires_grad=False, **kwargs):
+        """Returns an iterable of SampleInputs but with the tensor input or first
+        tensor in a sequence input conjugated.
+        """
+
+        set_seed = kwargs.pop("set_seed", True)
+        use_subtests = kwargs.pop("use_subtests", False)
+        samples = self.sample_inputs_func(self, device, dtype, requires_grad, **kwargs)
+        conj_samples = list(samples)
+
+        def conjugate(tensor):
+            _requires_grad = tensor.requires_grad
+            tensor = tensor.conj()
+            return tensor.requires_grad_(_requires_grad)
+
+        for i, sample in enumerate(samples):
+            sample = conj_samples[i]
+            # Note: it is assumed that the input here is either a tensor or tensorlist
+            if isinstance(sample.input, torch.Tensor):
+                sample.input = conjugate(sample.input)
+            else:
+                sample.input[0] = conjugate(sample.input[0])
+
+        return TrackedInputIter(
+            iter(conj_samples),
+            "conjugate sample input",
+            item_callback=self._sample_callback_fn(use_subtests, device),
+            set_seed=set_seed,
+            restrict_to_index=OPINFO_SAMPLE_INPUT_INDEX,
+        )
+
+    def sample_inputs(self, device, dtype, requires_grad=False, **kwargs):
+        """
+        Returns an iterable of SampleInputs.
+
+        These samples should be sufficient to test the function works correctly
+        with autograd, TorchScript, etc.
+        """
+        set_seed = kwargs.pop("set_seed", True)
+        use_subtests = kwargs.pop("use_subtests", False)
+        samples = self.sample_inputs_func(self, device, dtype, requires_grad, **kwargs)
+
+        if kwargs.get("include_conjugated_inputs", False):
+            conj_samples = self.conjugate_sample_inputs(
+                device, dtype, requires_grad, **kwargs
+            )
+            samples_list = list(samples)
+            samples_list.extend(conj_samples)
+            samples = tuple(samples_list)
+
+        return TrackedInputIter(
+            iter(samples),
+            "sample input",
+            item_callback=self._sample_callback_fn(use_subtests, device),
+            set_seed=set_seed,
+            restrict_to_index=OPINFO_SAMPLE_INPUT_INDEX,
+        )
+
+    def reference_inputs(self, device, dtype, requires_grad=False, **kwargs):
+        """
+        Returns an iterable of SampleInputs.
+
+        Distinct from sample_inputs() above because this returns an expanded set
+        of inputs when reference_inputs_func is defined. If undefined this returns
+        the sample inputs.
+        """
+        set_seed = kwargs.pop("set_seed", True)
+        use_subtests = kwargs.pop("use_subtests", False)
+        if self.reference_inputs_func is None:
+            samples = self.sample_inputs_func(
+                self, device, dtype, requires_grad, **kwargs
+            )
+            return TrackedInputIter(
+                iter(samples),
+                "reference input",
+                item_callback=self._sample_callback_fn(use_subtests, device),
+                set_seed=set_seed,
+                restrict_to_index=OPINFO_SAMPLE_INPUT_INDEX,
+            )
+
+        if kwargs.get("include_conjugated_inputs", False):
+            raise NotImplementedError
+
+        references = self.reference_inputs_func(
+            self, device, dtype, requires_grad, **kwargs
+        )
+        return TrackedInputIter(
+            iter(references),
+            "reference input",
+            item_callback=self._sample_callback_fn(use_subtests, device),
+            set_seed=set_seed,
+            restrict_to_index=OPINFO_SAMPLE_INPUT_INDEX,
+        )
+
+    def error_inputs(self, device, **kwargs):
+        """
+        Returns an iterable of ErrorInputs.
+        """
+        set_seed = kwargs.pop("set_seed", True)
+        use_subtests = kwargs.pop("use_subtests", False)
+        errs = self.error_inputs_func(self, device, **kwargs)
+
+        def _error_item_callback(e, i, use_subtests=use_subtests, device=device):
+            cb = self._sample_callback_fn(use_subtests, device)
+            # no subtest callback (use_subtests=False); just return the error input
+            if cb is None:
+                return e
+
+            # adapt the callback call since ErrorInputs contain SampleInputs
+            _, subtest_ctx = cb(e.sample_input, i)
+            return (e, subtest_ctx)
+
+        return TrackedInputIter(
+            iter(errs),
+            "error input",
+            track_callback=lambda e: e.sample_input,
+            item_callback=_error_item_callback,
+            set_seed=set_seed,
+            restrict_to_index=OPINFO_SAMPLE_INPUT_INDEX,
+        )
+
+    def error_inputs_sparse(self, device, layout, **kwargs):
+        """
+        Returns an iterable of ErrorInputs that contain sparse sample
+        inputs with a specified layout.
+        """
+        if not self.supports_sparse_layout(layout):
+            raise unittest.SkipTest("unsupported sparse layout")
+        return self.error_inputs_sparse_func(self, device, layout, **kwargs)
+
+    def supports_sparse_layout(self, layout):
+        """Return True if OpInfo supports the specified sparse layout."""
+        layout_name = str(layout).split(".")[-1]
+        # map torch.sparse_coo to OpInfo.supports_sparse:
+        layout_name = layout_name.replace("_coo", "")
+        return getattr(self, f"supports_{layout_name}")
+
+    def sample_inputs_sparse(
+        self, layout, device, dtype, requires_grad=False, **kwargs
+    ):
+        """Returns an iterable of SampleInputs that contain inputs with a
+        specified sparse layout.
+        """
+        layout_name = str(layout).split(".")[-1]
+        sample_inputs_mth = getattr(self, "sample_inputs_" + layout_name)
+
+        def non_empty_sampler(op, generator):
+            found_sample = False
+            for sample in generator:
+                found_sample = True
+                yield sample
+            if not found_sample:
+                raise unittest.SkipTest("NO SAMPLES!")
+
+        return non_empty_sampler(
+            self,
+            sample_inputs_mth(device, dtype, requires_grad=requires_grad, **kwargs),
+        )
+
+    def _sample_inputs_unspecified(self, *args, **kwargs):
+        """Raises an NotImplemented exception in a OpInfo instance creation
+        that specifies supports_sparse(|_csr|_csc|_bsr|_bsc)=True
+        without specifying the corresponding sample function as
+        sample_inputs_sparse_(coo|csr|csc|bsr|bsc)_func.
+
+        To avoid this, either define the corresponding sample function,
+        or re-map unsupported samples to error inputs in an appropriate
+
+          opinfo/definitions/sparse.py:_validate_sample_input_sparse_<op_name>
+
+        function.
+        """
+        raise NotImplementedError("no sample function specified")
+
+    def sample_inputs_sparse_coo(self, device, dtype, requires_grad=False, **kwargs):
+        """Returns an iterable of SampleInputs that contain inputs with sparse
+        coo layout.
+        """
+        return self.sample_inputs_sparse_coo_func(
+            self, device, dtype, requires_grad, **kwargs
+        )
+
+    def sample_inputs_sparse_csr(self, device, dtype, requires_grad=False, **kwargs):
+        """Returns an iterable of SampleInputs that contain inputs with sparse
+        csr layout.
+        """
+        return self.sample_inputs_sparse_csr_func(
+            self, device, dtype, requires_grad, **kwargs
+        )
+
+    def sample_inputs_sparse_csc(self, device, dtype, requires_grad=False, **kwargs):
+        """Returns an iterable of SampleInputs that contain inputs with sparse
+        csc layout.
+        """
+        return self.sample_inputs_sparse_csc_func(
+            self, device, dtype, requires_grad, **kwargs
+        )
+
+    def sample_inputs_sparse_bsr(self, device, dtype, requires_grad=False, **kwargs):
+        """Returns an iterable of SampleInputs that contain inputs with sparse
+        bsr layout.
+        """
+        return self.sample_inputs_sparse_bsr_func(
+            self, device, dtype, requires_grad, **kwargs
+        )
+
+    def sample_inputs_sparse_bsc(self, device, dtype, requires_grad=False, **kwargs):
+        """Returns an iterable of SampleInputs that contain inputs with sparse
+        bsc layout.
+        """
+        return self.sample_inputs_sparse_bsc_func(
+            self, device, dtype, requires_grad, **kwargs
+        )
+
+    def get_decorators(self, test_class, test_name, device, dtype, param_kwargs):
+        """Returns the decorators targeting the given test."""
+        result = []
+        for decorator in self.decorators:
+            if isinstance(decorator, DecorateInfo):
+                if decorator.is_active(
+                    test_class, test_name, device, dtype, param_kwargs
+                ):
+                    result.extend(decorator.decorators)
+            else:
+                result.append(decorator)
+        return result
+
+    def supported_dtypes(self, device_type):
+        if device_type == "privateuse1":
+            device_type = torch._C._get_privateuse1_backend_name()
+        device_type = torch.device(device_type).type
+        if device_type == "cuda" and TEST_WITH_ROCM:
+            device_type = "rocm"
+        return self.dtypesIf.get(device_type, self.dtypes)
+
+    def supported_backward_dtypes(self, device_type):
+        if not self.supports_autograd:
+            return set()
+
+        if device_type == "privateuse1":
+            device_type = torch._C._get_privateuse1_backend_name()
+        device_type = torch.device(device_type).type
+        backward_dtypes = None
+        if device_type == "cuda":
+            backward_dtypes = (
+                self.backward_dtypesIfROCM
+                if TEST_WITH_ROCM
+                else self.backward_dtypesIfCUDA
+            )
+        elif device_type == "hpu":
+            backward_dtypes = self.backward_dtypesIfHpu
+        else:
+            backward_dtypes = self.backward_dtypes
+
+        allowed_backward_dtypes = floating_and_complex_types_and(
+            torch.bfloat16, torch.float16, torch.complex32
+        )
+        return set(allowed_backward_dtypes).intersection(backward_dtypes)
+
+    def supports_dtype(self, dtype, device_type) -> bool:
+        return dtype in self.supported_dtypes(device_type)
+
+    @property
+    def full_name(self):
+        """Returns a full name that helps to uniquely identify this OpInfo."""
+        variant = "." + self.variant_test_name if self.variant_test_name else ""
+        # example: "normal.in_place" where "normal" is the name and "in_place" is the variant
+        return f"{self.name}{variant}"
+
+    @property
+    def formatted_name(self):
+        """Returns a formatted full name for this OpInfo that can be used in test names."""
+        return self.full_name.replace(".", "_")
+
+
+# Represents a skip / xfail rule matching a particular set of tests. It allows granularity
+# at the device, dtype, op, and individual sample levels. This flexibility allows entire
+# bugs to be represented by a single rule, even if this corresponds with multiple conceptual
+# test cases across multiple ops.
+@dataclass
+class SampleRule(ABC):
+    # function to indicate whether the rule applies to this op; return True if so
+    # NB: str arg of callable is device_type
+    op_match_fn: Callable[[str, OpInfo], bool] = None
+    # function to indicate whether the rule applies to this sample; return True if so
+    sample_match_fn: Callable[[torch.device, SampleInput], bool] = None
+    # optional name for identifying the rule
+    name: str = ""
+
+    def __post_init__(self):
+        if self.op_match_fn is None:
+            raise ValueError("must have op_match_fn set to be useful")
+        if self.sample_match_fn is None:
+            # by default, match for all samples
+            self.sample_match_fn = lambda device, sample: True
+
+    # returns a string identifier of the rule type
+    @abstractmethod
+    def type(self) -> str:
+        ...
+
+    # returns an appropriate context that handles the xfail, skips, etc.
+    @abstractmethod
+    def get_context(self, test_case):
+        ...
+
+
+# useful for specifying xfails
+@dataclass
+class XFailRule(SampleRule):
+    # expected error type
+    error_type: TypeVar = Exception
+    # expected error message
+    error_msg: str = ".*"
+
+    @property
+    def type(self) -> str:
+        return "xfail"
+
+    def get_context(self, test_case):
+        return test_case.assertRaisesRegex(
+            # failing within torch.compile wraps within a BackendCompilerFailed
+            (self.error_type, torch._dynamo.exc.BackendCompilerFailed),
+            self.error_msg,
+        )
+
+
+# useful for specifying skips
+@dataclass
+class SkipRule(SampleRule):
+    @property
+    def type(self):
+        return "skip"
+
+    def get_context(self, test_case):
+        @contextlib.contextmanager
+        def skipcontext(test_case=test_case):
+            test_case.skipTest("Skipped!")
+            yield
+
+        return skipcontext()
+
+
+# Decorator that defines skip / xfail rules for a given test function. If these are
+# present, the @ops decorator will apply these for each op and place them onto the
+# parametrized test functions for use by e.g. OpInfo.sample_inputs().
+class sample_skips_and_xfails:
+    def __init__(self, rules):
+        self.rules = rules
+
+    def __call__(self, fn):
+        rules = getattr(fn, "sample_skips_and_xfails", None)
+        if rules is not None:
+            raise RuntimeError("Multiple sets of sample_skips_and_xfails defined")
+
+        fn.sample_skips_and_xfails = self.rules
+        return fn
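+
+
+# A rough usage sketch (the test, op name, and error message below are
+# hypothetical). Rules are attached to the test function and consulted by
+# OpInfo.sample_inputs() when use_subtests=True:
+#
+#   @sample_skips_and_xfails([
+#       XFailRule(
+#           error_type=RuntimeError,
+#           error_msg="some expected failure message",
+#           op_match_fn=lambda device_type, op: op.name == "foo",
+#           sample_match_fn=lambda device, sample: sample.input.numel() == 0,
+#           name="hypothetical empty-input xfail",
+#       ),
+#   ])
+#   @ops(op_db)
+#   def test_bar(self, device, dtype, op):
+#       for sample, subtest_ctx, skip_xfail_ctx in op.sample_inputs(
+#           device, dtype, use_subtests=True
+#       ):
+#           with subtest_ctx(self), skip_xfail_ctx(self):
+#               ...  # test logic here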
+
+
+def _generate_reduction_inputs(device, dtype, requires_grad, **kwargs):
+    """Generates input tensors for testing reduction operators"""
+    yield make_tensor([], dtype=dtype, device=device, requires_grad=requires_grad)
+    yield make_tensor([2], dtype=dtype, device=device, requires_grad=requires_grad)
+    yield make_tensor([3, 5], dtype=dtype, device=device, requires_grad=requires_grad)
+    yield make_tensor(
+        [3, 2, 1, 2], dtype=dtype, device=device, requires_grad=requires_grad
+    )
+
+
+def _generate_reduction_kwargs(ndim, supports_multiple_dims=True):
+    """Generates a subset of all valid dim and keepdim kwargs given ndim that
+    is appropriate for testing reduction operators.
+    """
+
+    # Test default dim and keepdim
+    yield {}
+
+    # Test reducing inner and outer most dimensions
+    yield {"dim": 0, "keepdim": True}
+    yield {"dim": -1, "keepdim": False}
+
+    # Test reducing middle dimension
+    if ndim > 2:
+        yield {"dim": ndim // 2, "keepdim": True}
+
+    if supports_multiple_dims:
+        # Test reducing all dimensions
+        yield {"dim": tuple(range(ndim)), "keepdim": False}
+
+        # Test reducing both first and last dimensions
+        if ndim > 1:
+            yield {"dim": (0, -1), "keepdim": True}
+
+        # Test reducing every other dimension starting with the second
+        if ndim > 3:
+            yield {"dim": tuple(range(1, ndim, 2)), "keepdim": False}
+
+
+def sample_inputs_reduction(op_info, device, dtype, requires_grad, **kwargs):
+    """Sample inputs for reduction operators."""
+
+    # TODO(@heitorschueroff) Once all reduction operators are using
+    # ReductionOpInfo use op_info.supports_multiple_dims directly.
+    supports_multiple_dims: bool = kwargs.get("supports_multiple_dims", True)
+
+    # TODO(@heitorschueroff) Once all reduction operators are using ReductionOpInfo
+    # use op_info.generate_args_kwargs directly.
+    generate_args_kwargs = kwargs.get(
+        "generate_args_kwargs", lambda *args, **kwargs: (yield (), {})
+    )
+
+    for t in _generate_reduction_inputs(device, dtype, requires_grad):
+        for reduction_kwargs in _generate_reduction_kwargs(
+            t.ndim, supports_multiple_dims
+        ):
+            for args, kwargs in generate_args_kwargs(t, **reduction_kwargs):
+                kwargs.update(reduction_kwargs)
+                yield SampleInput(
+                    t.detach().requires_grad_(requires_grad), args=args, kwargs=kwargs
+                )
+
+
+# NOTE [Reductions]:
+#
+# For testing purposes, we relax the definition of a reduction operator
+# as defined in the docstring below. We do this to capture operators with
+# a similar API so they can be tested automatically. However...
+#
+# Strictly speaking a reduction operator is an operator that can reduce an
+# array to a single scalar value and that can be computed from the partial
+# result of reducing subarrays. This usually means that the reduction operation
+# should be commutative and associative. This definition is important when it
+# comes to implementation as it determines how a reduction can be parallelized.
+#
+# For example, many summary statistics such as median, mode and quantile cannot
+# be computed from partial results because these are sorting and counting based
+# algorithms that need information that would be lost in the reduced value.
+class ReductionOpInfo(OpInfo):
+    """Reduction operator information.
+
+    An operator is a reduction operator if it reduces one or more dimensions of
+    the input tensor to a single value. Reduction operators must implement the
+    following signature:
+
+    - `op(input, *args, *, dim=None, keepdim=False, **kwargs) -> Tensor`
+
+    ReductionOpInfo tests that reduction operators implement a consistent API.
+    Optional features such as reducing over multiple dimensions are captured in
+    the optional keyword parameters of the ReductionOpInfo constructor.
+
+    If a reduction operator does not yet implement the full required API of
+    reduction operators, this should be documented by xfailing the failing
+    tests rather than adding optional parameters to ReductionOpInfo.
+
+    NOTE
+    The API for reduction operators has not yet been finalized and some
+    requirements may change.
+
+    See tests in test/test_reductions.py
+    """
+
+    def __init__(
+        self,
+        name,
+        *,
+        # The identity value for the operator if it has one.
+        identity: Optional[Any] = None,
+        # The nan policy for the operator if it implements one.
+        # - propagate: NaN values are propagated to the output
+        # - omit: NaN values are discarded during the reduction
+        nan_policy: Optional[str] = None,
+        # Whether the operator supports reducing multiple dimensions.
+        supports_multiple_dims: bool = True,
+        # Whether the operator promotes integral to floating point dtypes.
+        promotes_int_to_float: bool = False,
+        # Whether the operator promotes all integral dtypes to int64.
+        promotes_int_to_int64: bool = False,
+        # If a specific dtype is given, then the operator always returns that
+        # dtype irrespective of the input dtype. If None, the operator returns
+        # the dtype according to the type promotion rules above.
+        result_dtype: Optional[torch.dtype] = None,
+        # Casts complex results to real (e.g. linalg.norm or torch.var)
+        complex_to_real: bool = False,
+        # ReductionOpInfo tests generate their own input, dim and keepdim
+        # arguments and call this function to generate tuples of extra args and
+        # kwargs to use when calling the op. This is required for operators that
+        # have other required parameters besides the input tensor.
+        generate_args_kwargs: Callable = lambda t, dim=None, keepdim=False: (
+            yield (),
+            {},
+        ),
+        # Options from the OpInfo base class
+        **kwargs,
+    ):
+        self._original_reduction_args = locals().copy()
+        assert nan_policy in (None, "propagate", "omit")
+
+        # These are mutually exclusive options
+        assert not (result_dtype and promotes_int_to_float)
+        assert not (result_dtype and promotes_int_to_int64)
+        assert not (result_dtype and complex_to_real)
+        assert not (promotes_int_to_float and promotes_int_to_int64)
+
+        # Default sample_inputs_func for ReductionOpInfo which augments sample
+        # inputs from sample_inputs_reduction with the args and kwargs from
+        # generate_args_kwargs. This is only used if sample_inputs_func is None.
+        def sample_inputs_func(*args, **kwargs):
+            kwargs["supports_multiple_dims"] = supports_multiple_dims
+            kwargs["generate_args_kwargs"] = generate_args_kwargs
+            yield from sample_inputs_reduction(*args, **kwargs)
+
+        # Override OpInfo defaults and call base class __init__
+        kwargs.setdefault("inplace_variant", None)
+        kwargs.setdefault("sample_inputs_func", sample_inputs_func)
+        super().__init__(name, promotes_int_to_float=promotes_int_to_float, **kwargs)
+
+        self.identity = identity
+        self.nan_policy = nan_policy
+        self.supports_multiple_dims = supports_multiple_dims
+        self.promotes_int_to_int64 = promotes_int_to_int64
+        self.complex_to_real = complex_to_real
+        self.result_dtype = result_dtype
+        self.generate_args_kwargs = generate_args_kwargs
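+
+
+# A rough sketch of a ReductionOpInfo entry (the operator name and dtypes are
+# hypothetical; real entries typically also set skips, decorators, and reference
+# functions):
+#
+#   ReductionOpInfo(
+#       "foo_reduction",
+#       identity=0,
+#       nan_policy="propagate",
+#       supports_multiple_dims=True,
+#       promotes_int_to_float=True,
+#       dtypes=floating_types(),  # assumed dtype helper
+#   )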
+
+
+# The base reference input generation for elementwise binary operations
+def _reference_inputs_elementwise_binary(
+    op, device, dtype, requires_grad, exclude_zero, **kwargs
+):
+    yield from op.sample_inputs_func(op, device, dtype, requires_grad, **kwargs)
+    yield from generate_elementwise_binary_tensors(
+        op,
+        device=device,
+        dtype=dtype,
+        requires_grad=requires_grad,
+        exclude_zero=exclude_zero,
+    )
+    if dtype is not torch.bool:
+        yield from generate_elementwise_binary_small_value_tensors(
+            op, device=device, dtype=dtype, requires_grad=requires_grad
+        )
+    if dtype not in (torch.bool, torch.uint8, torch.int8):
+        yield from generate_elementwise_binary_large_value_tensors(
+            op, device=device, dtype=dtype, requires_grad=requires_grad
+        )
+    yield from generate_elementwise_binary_broadcasting_tensors(
+        op,
+        device=device,
+        dtype=dtype,
+        requires_grad=requires_grad,
+        exclude_zero=exclude_zero,
+    )
+    yield from generate_elementwise_binary_with_scalar_samples(
+        op, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+
+    yield from generate_elementwise_binary_with_scalar_and_type_promotion_samples(
+        op, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+
+    if dtype.is_floating_point or dtype.is_complex:
+        yield from generate_elementwise_binary_extremal_value_tensors(
+            op, device=device, dtype=dtype, requires_grad=requires_grad
+        )
+
+
+# Note that these reference inputs use scalars for the SampleInput.input value,
+#   and many tests require SampleInput.input be a tensor or a list of tensors
+def reference_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs):
+    # default to not excluding zero values unless the op's rhs tensor kwargs say otherwise
+    exclude_zero = False
+    if hasattr(op, "rhs_make_tensor_kwargs"):
+        exclude_zero = op.rhs_make_tensor_kwargs.get("exclude_zero", False)
+
+    gen = partial(
+        _reference_inputs_elementwise_binary,
+        op,
+        device,
+        dtype,
+        requires_grad,
+        exclude_zero,
+        **kwargs,
+    )
+
+    # yields "normal" samples
+    yield from gen()
+
+    # yields noncontiguous samples
+    for sample in gen():
+        yield sample.noncontiguous()
+
+    yield from generate_elementwise_binary_noncontiguous_tensors(
+        op,
+        device=device,
+        dtype=dtype,
+        requires_grad=requires_grad,
+        exclude_zero=exclude_zero,
+    )
+
+    yield from generate_elementwise_binary_arbitrarily_strided_tensors(
+        op,
+        device=device,
+        dtype=dtype,
+        requires_grad=requires_grad,
+        exclude_zero=exclude_zero,
+    )
+
+
+# A higher-order function that extends an elementwise binary operator's bespoke error inputs
+#   with generic error inputs for the class of elementwise binary operations
+def make_error_inputs_elementwise_binary(error_inputs_func):
+    def error_inputs_func_wrapper(op, device, **kwargs):
+        if error_inputs_func is not None:
+            yield from error_inputs_func(op, device, **kwargs)
+
+        if not op.supports_rhs_python_scalar:
+            si = SampleInput(torch.tensor((1, 2, 3), device=device), args=(2,))
+            yield ErrorInput(si, error_type=Exception, error_regex="")
+
+        if not op.supports_one_python_scalar:
+            si = SampleInput(2, args=(torch.tensor((1, 2, 3), device=device),))
+            yield ErrorInput(si, error_type=Exception, error_regex="")
+
+        if (
+            not kwargs.get("skip_two_python_scalars", False)
+            and not op.supports_two_python_scalars
+        ):
+            si = SampleInput(2, args=(3,))
+            yield ErrorInput(si, error_type=Exception, error_regex="")
+
+    return error_inputs_func_wrapper
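+
+# Illustrative usage (a hypothetical sketch, not part of the original module): an op
+# with its own error-input generator, say `error_inputs_foo`, could be extended with
+# the generic scalar-handling checks above via
+#
+#   error_inputs_func = make_error_inputs_elementwise_binary(error_inputs_foo)
+#
+# BinaryUfuncInfo below applies this wrapper to its `error_inputs_func` automatically.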
+
+
+# The following functions and classes are for testing elementwise binary operators.
+
+
+# Returns a generator of pairs of contiguous tensors on the requested device
+#   and with the requested dtype.
+#
+# This function is intended to test the non-vectorized and vectorized code
+#   paths of elementwise binary functions, as well as their handling of odd tensor
+#   sizes (like zero-dim tensors and tensors with zero elements).
+#
+# The yielded samples include a tensor with no elements,
+#   zero dim (scalar) tensors, small 1D tensors, a medium 1D tensor, and
+#   a large 2D tensor.
+def generate_elementwise_binary_tensors(
+    op, *, device, dtype, requires_grad=False, exclude_zero=False
+):
+    shapes = (
+        # tensors with no elements
+        (0,),
+        (1, 0, 3),
+        # zero dim (scalar) tensor
+        (),
+        # small 1D tensor
+        (20,),
+        # medium 1D tensor
+        (812,),
+        # large 2D tensor
+        (1029, 917),
+    )
+
+    make_arg = partial(
+        make_tensor,
+        device=device,
+        dtype=dtype,
+        requires_grad=requires_grad,
+        exclude_zero=exclude_zero,
+    )
+    for shape in shapes:
+        lhs = make_arg(shape, **op.lhs_make_tensor_kwargs)
+        rhs = make_arg(shape, **op.rhs_make_tensor_kwargs)
+        yield SampleInput(
+            lhs, args=(rhs,), kwargs=op.sample_kwargs(device, dtype, lhs)[0]
+        )
+
+
+def generate_elementwise_binary_arbitrarily_strided_tensors(
+    op, *, device, dtype, requires_grad=False, exclude_zero=False
+):
+    # shape, strides, offset
+    strided_cases = (
+        ((5, 6, 2), (1, 1, 7), 2),
+        ((5, 5, 4), (1, 1, 7), 2),
+        ((5, 5, 2), (4, 5, 7), 3),
+        ((5, 5, 2), (5, 5, 7), 3),
+        ((5, 5, 2), (5, 5, 5), 3),
+        ((9, 5, 2), (0, 1, 7), 3),
+    )
+
+    make_arg = partial(
+        make_tensor,
+        device=device,
+        dtype=dtype,
+        requires_grad=requires_grad,
+        exclude_zero=exclude_zero,
+    )
+    for shape, strides, offset in strided_cases:
+        a = make_arg(
+            500,
+        ).as_strided(shape, strides, offset)
+        b = make_arg(shape)
+        yield SampleInput(a, args=(b,), kwargs=op.sample_kwargs(device, dtype, a)[0])
+
+
+# Returns a generator of pairs of contiguous tensors on the requested device and with
+#   the requested dtype.
+#
+# Unlike generate_elementwise_binary_tensors, the values in these tensors are specified manually.
+def generate_elementwise_binary_small_value_tensors(
+    op, *, device, dtype, requires_grad=False, exclude_zero=None
+):
+    if exclude_zero is None:
+        if hasattr(op, "rhs_make_tensor_kwargs"):
+            exclude_zero = op.rhs_make_tensor_kwargs.get("exclude_zero", False)
+
+    # defines interesting values
+    _unsigned_int_vals = (0, 1, 55, 127, 128, 190, 210, 220, 254)
+    _int_vals = (0, -1, 1, -55, 55, -127, 127, -128)
+    _float_vals = (
+        0.0,
+        -0.0,
+        -0.001,
+        0.001,
+        -0.25,
+        0.25,
+        -1.0,
+        1.0,
+        -math.pi / 2,
+        math.pi / 2,
+        -math.pi + 0.00001,
+        math.pi - 0.00001,
+        -math.pi,
+        math.pi,
+        -math.pi - 0.00001,
+        math.pi + 0.00001,
+    )
+
+    l_vals = []
+    r_vals = []
+
+    if dtype.is_floating_point:
+        prod = product(_float_vals, _float_vals)
+    elif dtype.is_complex:
+        complex_vals = product(_float_vals, _float_vals)
+        # Note the use of list is required here or the generator would be
+        #  emptied by the following product and it wouldn't produce the desired cross-product
+        complex_vals = [complex(*x) for x in complex_vals]
+        prod = product(complex_vals, complex_vals)
+    elif dtype in (torch.int8, torch.int16, torch.int32, torch.int64):
+        prod = product(_int_vals, _int_vals)
+    elif dtype is torch.uint8:
+        prod = product(_unsigned_int_vals, _unsigned_int_vals)
+    else:
+        raise ValueError("Unsupported dtype!")
+
+    for l, r in prod:
+        l_vals.append(l)
+        if r == 0 and exclude_zero:
+            r_vals.append(1)
+        else:
+            r_vals.append(r)
+
+    lhs = torch.tensor(l_vals, device=device, dtype=dtype, requires_grad=requires_grad)
+    rhs = torch.tensor(r_vals, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    yield SampleInput(lhs, args=(rhs,), kwargs=op.sample_kwargs(device, dtype, lhs)[0])
+
+
+def generate_elementwise_binary_large_value_tensors(
+    op, *, device, dtype, requires_grad=False
+):
+    _large_int_vals = (-1113, 1113, -10701, 10701)
+    _large_float16_vals = (-501, 501, -1001.2, 1001.2, -13437.7, 13437.7)
+    _large_float_vals = _large_float16_vals + (-4988429.2, 4988429.2, -1e20, 1e20)
+
+    l_vals = []
+    r_vals = []
+
+    if dtype == torch.float16:
+        prod = product(_large_float16_vals, _large_float16_vals)
+    elif dtype.is_floating_point:
+        prod = product(_large_float_vals, _large_float_vals)
+    elif dtype.is_complex:
+        complex_vals = product(_large_float_vals, _large_float_vals)
+        # Note the use of list is required here or the generator would be
+        #  emptied by the following product and it wouldn't produce the desired cross-product
+        complex_vals = [complex(*x) for x in complex_vals]
+        prod = product(complex_vals, complex_vals)
+    elif dtype in (torch.int16, torch.int32, torch.int64):
+        prod = product(_large_int_vals, _large_int_vals)
+    else:
+        raise ValueError("Unsupported dtype!")
+
+    for l, r in prod:
+        l_vals.append(l)
+        r_vals.append(r)
+
+    lhs = torch.tensor(l_vals, device=device, dtype=dtype, requires_grad=requires_grad)
+    rhs = torch.tensor(r_vals, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    yield SampleInput(lhs, args=(rhs,), kwargs=op.sample_kwargs(device, dtype, lhs)[0])
+
+
+def generate_elementwise_binary_extremal_value_tensors(
+    op, *, device, dtype, requires_grad=False
+):
+    _float_extremals = (float("inf"), float("-inf"), float("nan"))
+
+    l_vals = []
+    r_vals = []
+
+    if dtype.is_floating_point:
+        prod = product(_float_extremals, _float_extremals)
+    elif dtype.is_complex:
+        complex_vals = product(_float_extremals, _float_extremals)
+        # Note the use of list is required here or the generator would be
+        #  emptied by the following product and it wouldn't produce the desired cross-product
+        complex_vals = [complex(*x) for x in complex_vals]
+        prod = product(complex_vals, complex_vals)
+    else:
+        raise ValueError("Unsupported dtype!")
+
+    for l, r in prod:
+        l_vals.append(l)
+        r_vals.append(r)
+
+    lhs = torch.tensor(l_vals, device=device, dtype=dtype, requires_grad=requires_grad)
+    rhs = torch.tensor(r_vals, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    yield SampleInput(lhs, args=(rhs,), kwargs=op.sample_kwargs(device, dtype, lhs)[0])
+
+    # Test case for NaN propagation
+    nan = (
+        float("nan") if dtype.is_floating_point else complex(float("nan"), float("nan"))
+    )
+    lhs = make_tensor(
+        (128, 128), device=device, dtype=dtype, requires_grad=requires_grad
+    )
+    lhs.view(-1)[::3] = nan
+    rhs = make_tensor(
+        (128, 128), device=device, dtype=dtype, requires_grad=requires_grad
+    )
+    rhs.view(-1)[::3] = nan
+
+    yield SampleInput(lhs, args=(rhs,), kwargs=op.sample_kwargs(device, dtype, lhs)[0])
+
+
+# Returns a generator of pairs of contiguous and noncontiguous tensors that
+#   require broadcasting
+def generate_elementwise_binary_broadcasting_tensors(
+    op, *, device, dtype, requires_grad=False, exclude_zero=False
+):
+    shapes = (
+        ((1,), ()),
+        ((2,), ()),
+        ((1,), (2,)),
+        ((2, 1), (2,)),
+        ((1, 2), (2,)),
+        ((3, 2), (2,)),
+        ((1, 3, 2), (2,)),
+        ((1, 3, 2), (3, 2)),
+        ((3, 1, 2), (3, 2)),
+        ((2, 3, 2), ()),
+        ((3, 1, 2), (1, 3, 2)),
+    )
+
+    make_arg = partial(
+        make_tensor,
+        device=device,
+        dtype=dtype,
+        requires_grad=requires_grad,
+        exclude_zero=exclude_zero,
+    )
+    for shape, noncontiguous in product(shapes, [True, False]):
+        shape_lhs, shape_rhs = shape
+        lhs = make_arg(
+            shape_lhs, noncontiguous=noncontiguous, **op.lhs_make_tensor_kwargs
+        )
+        rhs = make_arg(
+            shape_rhs, noncontiguous=noncontiguous, **op.rhs_make_tensor_kwargs
+        )
+
+        yield SampleInput(
+            lhs,
+            args=(rhs,),
+            broadcasts_input=True,
+            kwargs=op.sample_kwargs(device, dtype, lhs)[0],
+        )
+
+
+# Returns a generator of pairs of contiguous tensors and scalars
+def generate_elementwise_binary_with_scalar_samples(
+    op, *, device, dtype, requires_grad=False
+):
+    make_arg = partial(
+        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+
+    shapes = ((), (3,), (5, 3), (0, 1, 3), (1, 5))
+    if op.supports_rhs_python_scalar:
+        for shape in shapes:
+            lhs = make_arg(shape, **op.lhs_make_tensor_kwargs)
+            rhs = make_arg(shape, **op.rhs_make_tensor_kwargs)
+            lhs_scalar = make_arg((), **op.lhs_make_tensor_kwargs).item()
+            rhs_scalar = make_arg((), **op.rhs_make_tensor_kwargs).item()
+
+            yield SampleInput(
+                lhs, args=(rhs_scalar,), kwargs=op.sample_kwargs(device, dtype, lhs)[0]
+            )
+
+        # Extends with scalar lhs
+        if op.supports_one_python_scalar:
+            yield SampleInput(
+                lhs_scalar,
+                args=(rhs,),
+                kwargs=op.sample_kwargs(device, dtype, lhs_scalar)[0],
+            )
+
+    if op.supports_two_python_scalars:
+        lhs_scalar = make_arg((), **op.lhs_make_tensor_kwargs).item()
+        rhs_scalar = make_arg((), **op.rhs_make_tensor_kwargs).item()
+
+        yield SampleInput(
+            lhs_scalar,
+            args=(rhs_scalar,),
+            kwargs=op.sample_kwargs(device, dtype, lhs_scalar)[0],
+        )
+
+
+# Returns a generator of samples pairing contiguous tensors with extremal scalars
+#   and 0d tensors to exercise type promotion
+def generate_elementwise_binary_with_scalar_and_type_promotion_samples(
+    op, *, device, dtype, requires_grad=False
+):
+    # add these samples only for logical and comparison ops; arithmetic ops do not handle extremal scalars well
+    if op.name in (
+        "eq",
+        "ne",
+        "gt",
+        "ge",
+        "lt",
+        "le",
+        "logical_and",
+        "logical_or",
+        "logical_xor",
+    ):
+        make_arg = partial(
+            make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+        )
+        shape = (
+            23,
+        )  # this shape is big enough to trigger vectorization and has a non-vectorized tail
+        values = (float("nan"), float("inf"), -float("inf"))
+        scalar_tensors = tuple(torch.tensor(val) for val in values)
+        if op.supports_rhs_python_scalar:
+            lhs = make_arg(shape, **op.lhs_make_tensor_kwargs)
+            rhs = make_arg(shape, **op.rhs_make_tensor_kwargs)
+            for scalar in values + scalar_tensors:
+                yield SampleInput(
+                    lhs, args=(scalar,), kwargs=op.sample_kwargs(device, dtype, lhs)[0]
+                )
+                # Extends with scalar lhs
+                if op.supports_one_python_scalar:
+                    yield SampleInput(
+                        scalar,
+                        args=(rhs,),
+                        kwargs=op.sample_kwargs(device, dtype, scalar)[0],
+                    )
+
+
+# Returns a generator of pairs of noncontiguous tensors
+def generate_elementwise_binary_noncontiguous_tensors(
+    op, *, device, dtype, requires_grad=False, exclude_zero=False
+):
+    make_arg = partial(
+        make_tensor,
+        device=device,
+        dtype=dtype,
+        requires_grad=requires_grad,
+        exclude_zero=exclude_zero,
+    )
+
+    # Generic noncontiguity
+    lhs = make_arg((1026,), noncontiguous=True, **op.lhs_make_tensor_kwargs)
+    rhs = make_arg((1026,), noncontiguous=True, **op.rhs_make_tensor_kwargs)
+
+    yield SampleInput(
+        lhs.clone(), args=(rhs.clone(),), kwargs=op.sample_kwargs(device, dtype, lhs)[0]
+    )
+    yield SampleInput(
+        lhs.contiguous(), args=(rhs,), kwargs=op.sample_kwargs(device, dtype, lhs)[0]
+    )
+
+    # Transposed
+    lhs = make_arg((789, 357), **op.lhs_make_tensor_kwargs)
+    rhs = make_arg((789, 357), **op.rhs_make_tensor_kwargs)
+
+    yield SampleInput(
+        lhs.T, args=(rhs.T,), kwargs=op.sample_kwargs(device, dtype, lhs)[0]
+    )
+
+    # More noncontiguity
+    shapes = ((5, 7), (1024,))
+
+    for shape in shapes:
+        lhs = make_arg(shape, **op.lhs_make_tensor_kwargs)
+        rhs = make_arg(shape, **op.rhs_make_tensor_kwargs)
+
+        lhs_non_contig = torch.empty(shape + (2,), device=device, dtype=dtype)[..., 0]
+        lhs_non_contig.copy_(lhs)
+
+        rhs_non_contig = torch.empty(shape + (2,), device=device, dtype=dtype)[..., 0]
+        rhs_non_contig.copy_(rhs)
+
+        yield SampleInput(
+            lhs_non_contig.clone(),
+            args=(rhs_non_contig.clone(),),
+            kwargs=op.sample_kwargs(device, dtype, lhs)[0],
+        )
+        yield SampleInput(
+            lhs_non_contig.contiguous(),
+            args=(rhs_non_contig,),
+            kwargs=op.sample_kwargs(device, dtype, lhs)[0],
+        )
+
+    # Noncontiguous indices
+    shape = (2, 2, 1, 2)
+    lhs = make_arg(shape, **op.lhs_make_tensor_kwargs)
+    rhs = make_arg(shape, **op.rhs_make_tensor_kwargs)
+
+    lhs_non_contig = lhs[:, 1, ...]
+    rhs_non_contig = rhs[:, 1, ...]
+
+    yield SampleInput(
+        lhs_non_contig.clone(),
+        args=(rhs_non_contig.clone(),),
+        kwargs=op.sample_kwargs(device, dtype, lhs)[0],
+    )
+    yield SampleInput(
+        lhs_non_contig.contiguous(),
+        args=(rhs_non_contig,),
+        kwargs=op.sample_kwargs(device, dtype, lhs)[0],
+    )
+
+    # Expanded tensors
+    shapes = ((1, 3), (1, 7), (5, 7))
+
+    for shape in shapes:
+        lhs = make_arg(shape, **op.lhs_make_tensor_kwargs)
+        rhs = make_arg(shape, **op.rhs_make_tensor_kwargs)
+
+        lhs_non_contig = lhs.expand(3, -1, -1)
+        rhs_non_contig = rhs.expand(3, -1, -1)
+
+        yield SampleInput(
+            lhs_non_contig,
+            args=(rhs_non_contig,),
+            kwargs=op.sample_kwargs(device, dtype, lhs)[0],
+        )
+
+
+# Sample inputs for elementwise binary operators, like add
+def sample_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs):
+    _M = S if kwargs.get("small_inputs_only", False) else M
+    _S = XS if kwargs.get("small_inputs_only", False) else S
+
+    exclude_zero = False
+    if hasattr(op, "rhs_make_tensor_kwargs"):
+        exclude_zero = op.rhs_make_tensor_kwargs.get("exclude_zero", False)
+
+    make_arg = partial(
+        make_tensor,
+        device=device,
+        dtype=dtype,
+        requires_grad=requires_grad,
+        exclude_zero=exclude_zero,
+    )
+
+    shapes = (
+        ((), ()),
+        ((_S,), ()),
+        ((_S, 1), (_S,)),
+        ((_M, _S), ()),
+        ((_S, _M, _S), (_M, _S)),
+        ((_S, _M, _S), (_S, _M, _S)),
+        ((_M, 1, _S), (_M, _S)),
+        ((_M, 1, _S), (1, _M, _S)),
+        ((0, 1, XS), (0, _M, XS)),
+    )
+
+    for shape_lhs, shape_rhs in shapes:
+        lhs = make_arg(shape_lhs, **op.lhs_make_tensor_kwargs)
+        rhs = make_arg(shape_rhs, **op.rhs_make_tensor_kwargs)
+        broadcasts_input = shape_lhs != torch.broadcast_shapes(shape_lhs, shape_rhs)
+
+        yield SampleInput(
+            lhs,
+            args=(rhs,),
+            kwargs=op.sample_kwargs(device, dtype, lhs)[0],
+            broadcasts_input=broadcasts_input,
+        )
+
+
+# Metadata class for binary "universal functions (ufuncs)" that accept two
+# tensors and have common properties
+class BinaryUfuncInfo(OpInfo):
+    """Operator information for 'universal binary functions (binary ufuncs).'
+    These are functions of two tensors with common properties like:
+      - they are elementwise functions
+      - the output shape is determined by the input shapes
+      - they typically have method and inplace variants
+      - they typically support the out kwarg
+      - they typically have NumPy or SciPy references
+    See NumPy's universal function documentation
+    (https://numpy.org/doc/stable/reference/ufuncs.html) for more details
+    about the concept of ufuncs.
+    """
+
+    def __init__(
+        self,
+        name,
+        *,
+        sample_inputs_func=sample_inputs_elementwise_binary,
+        reference_inputs_func=reference_inputs_elementwise_binary,
+        sample_kwargs=lambda device, dtype, input: ({}, {}),
+        error_inputs_func=None,
+        lhs_make_tensor_kwargs=None,
+        rhs_make_tensor_kwargs=None,
+        always_returns_bool=False,  # Set to true if the op always returns bool tensors
+        supports_rhs_python_scalar=True,  # Whether the operator allows Tensor x scalar inputs
+        supports_one_python_scalar=False,  # Whether the operator allows scalar x tensor and tensor x scalar inputs
+        supports_two_python_scalars=False,  # Whether the operator allows scalar x scalar inputs
+        **kwargs,
+    ):
+        self._original_binary_ufunc_args = locals().copy()
+
+        # Elementwise binary operations perform the equivalent of test_numpy_refs
+        #   in test_binary_ufuncs, but with additional test granularity. So the
+        #   generic test_ops.py test is skipped because it's redundant.
+        common_skips = (
+            DecorateInfo(
+                unittest.skip("Skipping redundant test."),
+                "TestCommon",
+                "test_numpy_refs",
+            ),
+        )
+        kwargs["skips"] = kwargs.get("skips", ()) + common_skips
+        super().__init__(
+            name,
+            sample_inputs_func=sample_inputs_func,
+            reference_inputs_func=reference_inputs_func,
+            error_inputs_func=make_error_inputs_elementwise_binary(error_inputs_func),
+            **kwargs,
+        )
+
+        self.sample_kwargs = sample_kwargs
+
+        # [lr]hs_make_tensor_kwargs are part of the OpInfo to be able to dynamically generate valid samples later on.
+        if lhs_make_tensor_kwargs is None:
+            lhs_make_tensor_kwargs = {}
+        self.lhs_make_tensor_kwargs = lhs_make_tensor_kwargs
+
+        if rhs_make_tensor_kwargs is None:
+            rhs_make_tensor_kwargs = {}
+        self.rhs_make_tensor_kwargs = rhs_make_tensor_kwargs
+
+        self.always_returns_bool = always_returns_bool
+        self.supports_rhs_python_scalar = supports_rhs_python_scalar
+        self.supports_one_python_scalar = supports_one_python_scalar
+        self.supports_two_python_scalars = supports_two_python_scalars
+
+        if self.supports_two_python_scalars:
+            self.supports_one_python_scalar = True
+
+        if self.supports_one_python_scalar:
+            assert (
+                supports_rhs_python_scalar
+            ), "Can't support lhs and rhs Python scalars but not rhs scalars!"
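+
+# Illustrative sketch (hypothetical, not an entry from this module): a minimal
+# BinaryUfuncInfo declaration might look roughly like
+#
+#   BinaryUfuncInfo(
+#       "add",
+#       ref=np.add,
+#       supports_two_python_scalars=True,
+#   )
+#
+# where `ref` is the optional NumPy reference accepted by OpInfo.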
+
+
+# The following functions and classes are for testing elementwise unary operators.
+def sample_inputs_elementwise_unary(
+    op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs
+):
+    if not op_kwargs:
+        op_kwargs = {}
+
+    _L = S if kwargs.get("small_inputs_only", False) else L
+
+    low, high = op_info.domain
+    is_floating = dtype.is_floating_point or dtype.is_complex
+    low = low if low is None or not is_floating else low + op_info._domain_eps
+    high = high if high is None or not is_floating else high - op_info._domain_eps
+    if (
+        op_info.supports_sparse_csr
+        or op_info.supports_sparse_csc
+        or op_info.supports_sparse_bsr
+        or op_info.supports_sparse_bsc
+    ):
+        # Tensors with dim=2 for sparse compressed testing
+        yield SampleInput(
+            make_tensor(
+                (_L, _L),
+                device=device,
+                dtype=dtype,
+                low=low,
+                high=high,
+                requires_grad=requires_grad,
+            ),
+            kwargs=op_kwargs,
+        )
+    else:
+        # Creates 1D, empty, and scalar tensors
+        for shape in ((_L,), (1, 0, 3), ()):
+            yield SampleInput(
+                make_tensor(
+                    shape,
+                    device=device,
+                    dtype=dtype,
+                    low=low,
+                    high=high,
+                    requires_grad=requires_grad,
+                ),
+                kwargs=op_kwargs,
+            )
+
+
+# Replace values satisfying condition with a safe value. This is used to block
+# out values that could cause a singularity, like tan(pi/2)
+def _replace_values_in_tensor(tensor, condition, safe_value):
+    mask = condition(tensor)
+    tensor.masked_fill_(mask, safe_value)
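+
+# For example (illustrative only), values near pi/2 could be replaced with 0.0 before
+# comparing tan against a reference implementation:
+#
+#   _replace_values_in_tensor(t, lambda x: (x - math.pi / 2).abs() < 1e-3, 0.0)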
+
+
+# Helper to create a unary elementwise tensor with valid inputs
+def _make_unary_elementwise_tensor(shape, *, op, dtype, **kwargs):
+    low, high = op.domain
+    is_floating = dtype.is_floating_point or dtype.is_complex
+    low = low if low is None or not is_floating else low + op._domain_eps
+    high = high if high is None or not is_floating else high - op._domain_eps
+
+    a = make_tensor(shape, low=low, high=high, dtype=dtype, **kwargs)
+
+    if op.reference_numerics_filter is not None and dtype is not torch.bool:
+        condition, safe_value = op.reference_numerics_filter
+        _replace_values_in_tensor(a, condition, safe_value)
+
+    return a
+
+
+# Restricts the values in the tensor to the domain of the
+# given elementwise unary operator
+def _filter_unary_elementwise_tensor(a, *, op):
+    # short-circuits for boolean tensors
+    if a.dtype is torch.bool:
+        return a
+
+    low, high = op.domain
+    is_floating = a.dtype.is_floating_point or a.dtype.is_complex
+    low = low if low is None or not is_floating else low + op._domain_eps
+    high = high if high is None or not is_floating else high - op._domain_eps
+
+    if a.dtype is torch.uint8 and low is not None:
+        low = max(low, 0)
+
+    if not a.dtype.is_floating_point and not a.dtype.is_complex:
+        low = math.ceil(low) if low is not None else None
+        high = math.floor(high) if high is not None else None
+
+    if op.reference_numerics_filter is not None:
+        condition, safe_value = op.reference_numerics_filter
+        _replace_values_in_tensor(a, condition, safe_value)
+
+    if low is not None or high is not None:
+        if a.dtype.is_complex:
+            a.real.clamp_(low, high)
+            a.imag.clamp_(low, high)
+        else:
+            a.clamp_(min=low, max=high)
+
+    return a
+
+
+def generate_elementwise_unary_tensors(op, *, device, dtype, requires_grad, **kwargs):
+    # Special-cases bool
+    if dtype is torch.bool:
+        tensors = (
+            torch.empty(0, device=device, dtype=torch.bool),
+            torch.tensor(True, device=device),
+            torch.tensor(False, device=device),
+            torch.tensor((True, False), device=device),
+            make_tensor((812,), device=device, dtype=dtype),
+            make_tensor((1029, 917), device=device, dtype=dtype),
+        )
+        for a in tensors:
+            yield SampleInput(a, kwargs=op.sample_kwargs(device, dtype, a)[0])
+
+    shapes = (
+        (1029, 917),
+        (812,),
+        # Empty sizes
+        (0,),
+        (0, 3, 3),
+        (1, 0, 5),
+        (6, 0, 0, 0),
+        (3, 0, 1, 0),
+    )
+
+    make_arg = partial(
+        _make_unary_elementwise_tensor,
+        op=op,
+        device=device,
+        dtype=dtype,
+        requires_grad=requires_grad,
+    )
+    for shape in shapes:
+        a = make_arg(shape)
+        yield SampleInput(a, kwargs=op.sample_kwargs(device, dtype, a)[0])
+
+
+def generate_elementwise_unary_small_value_tensors(
+    op, *, device, dtype, requires_grad=False
+):
+    for sample in generate_elementwise_binary_small_value_tensors(
+        op, device=device, dtype=dtype, requires_grad=requires_grad
+    ):
+        a = _filter_unary_elementwise_tensor(sample.input, op=op)
+        yield SampleInput(a, kwargs=op.sample_kwargs(device, dtype, a)[0])
+
+
+def generate_elementwise_unary_large_value_tensors(
+    op, *, device, dtype, requires_grad=False
+):
+    for sample in generate_elementwise_binary_large_value_tensors(
+        op, device=device, dtype=dtype, requires_grad=requires_grad
+    ):
+        a = _filter_unary_elementwise_tensor(sample.input, op=op)
+        yield SampleInput(sample.input, kwargs=op.sample_kwargs(device, dtype, a)[0])
+
+
+def generate_elementwise_unary_extremal_value_tensors(
+    op, *, device, dtype, requires_grad=False
+):
+    for sample in generate_elementwise_binary_extremal_value_tensors(
+        op, device=device, dtype=dtype, requires_grad=requires_grad
+    ):
+        yield SampleInput(
+            sample.input, kwargs=op.sample_kwargs(device, dtype, sample.input)[0]
+        )
+
+
+def generate_elementwise_unary_noncontiguous_tensors(
+    op, *, device, dtype, requires_grad=False
+):
+    make_arg = partial(
+        _make_unary_elementwise_tensor,
+        op=op,
+        device=device,
+        dtype=dtype,
+        requires_grad=requires_grad,
+    )
+
+    # Generic noncontiguity
+    t = make_arg((1026,), noncontiguous=True)
+    yield SampleInput(t, kwargs=op.sample_kwargs(device, dtype, t)[0])
+
+    # Transposed
+    t = make_arg((1024, 1024)).T
+    yield SampleInput(t, kwargs=op.sample_kwargs(device, dtype, t)[0])
+
+    # Expanded tensors
+    shapes = ((1, 3), (1, 7), (5, 7))
+
+    for shape in shapes:
+        t = make_arg(shape)
+        t_non_contig = t.expand(3, -1, -1)
+        yield SampleInput(
+            t_non_contig, kwargs=op.sample_kwargs(device, dtype, t_non_contig)[0]
+        )
+
+
+def generate_elementwise_unary_arbitrarily_strided_tensors(
+    op, *, device, dtype, requires_grad=False
+):
+    # shape, strides, offset
+    strided_cases = (
+        ((5, 6, 2), (1, 1, 7), 2),
+        ((5, 5, 4), (1, 1, 7), 2),
+        ((5, 5, 2), (4, 5, 7), 3),
+        ((5, 5, 2), (5, 5, 7), 3),
+        ((5, 5, 2), (5, 5, 5), 3),
+        ((9, 5, 2), (0, 1, 7), 3),
+    )
+
+    make_arg = partial(
+        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+    for shape, strides, offset in strided_cases:
+        a = make_arg(
+            500,
+        ).as_strided(shape, strides, offset)
+        yield SampleInput(a, kwargs=op.sample_kwargs(device, dtype, a)[0])
+
+
+# Reuses the elementwise binary generators for consistency
+# TODO: in the future generalize the reference generators to handle n-ary elementwise operations
+def _reference_inputs_elementwise_unary(op, device, dtype, requires_grad, **kwargs):
+    yield from op.sample_inputs_func(op, device, dtype, requires_grad, **kwargs)
+
+    yield from generate_elementwise_unary_tensors(
+        op, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs
+    )
+
+    if dtype is not torch.bool:
+        yield from generate_elementwise_unary_small_value_tensors(
+            op, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs
+        )
+    if dtype not in (torch.bool, torch.uint8, torch.int8) and (
+        op.handles_large_floats
+        or (not dtype.is_floating_point and not dtype.is_complex)
+    ):
+        yield from generate_elementwise_unary_large_value_tensors(
+            op, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs
+        )
+
+    if dtype.is_floating_point or (
+        op.handles_complex_extremal_values and dtype.is_complex
+    ):
+        yield from generate_elementwise_unary_extremal_value_tensors(
+            op, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs
+        )
+
+
+def reference_inputs_elementwise_unary(op, device, dtype, requires_grad, **kwargs):
+    gen = partial(
+        _reference_inputs_elementwise_unary, op, device, dtype, requires_grad, **kwargs
+    )
+
+    # yields "normal" samples
+    yield from gen()
+
+    # yields noncontiguous samples
+    for sample in gen():
+        yield sample.noncontiguous()
+
+    yield from generate_elementwise_unary_noncontiguous_tensors(
+        op, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs
+    )
+
+    yield from generate_elementwise_unary_arbitrarily_strided_tensors(
+        op, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs
+    )
+
+
+# Metadata class for unary "universal functions (ufuncs)" that accept a single
+# tensor and have common properties
+class UnaryUfuncInfo(OpInfo):
+    """Operator information for 'universal unary functions (unary ufuncs).'
+    These are functions of a single tensor with common properties like:
+      - they are elementwise functions
+      - the input shape is the output shape
+      - they typically have method and inplace variants
+      - they typically support the out kwarg
+      - they typically have NumPy or SciPy references
+    See NumPy's universal function documentation
+    (https://numpy.org/doc/1.18/reference/ufuncs.html) for more details
+    about the concept of ufuncs.
+    """
+
+    def __init__(
+        self,
+        name,  # the string name of the function
+        *,
+        dtypes=floating_types(),
+        domain=(None, None),  # the [low, high) domain of the function
+        handles_complex_extremal_values=True,  # whether the op correctly handles extremal values (like nan/inf)
+        handles_large_floats=True,  # whether the op correctly handles large float values (like 1e20)
+        supports_complex_to_float=False,  # op supports safely casting from complex input to real output, e.g. angle
+        sample_inputs_func=sample_inputs_elementwise_unary,
+        reference_inputs_func=reference_inputs_elementwise_unary,
+        sample_kwargs=lambda device, dtype, input: ({}, {}),
+        reference_numerics_filter=None,  # Filters out values that lie within the domain specified above but should not be tested
+        **kwargs,
+    ):
+        self._original_unary_ufunc_args = locals().copy()
+
+        super().__init__(
+            name,
+            dtypes=dtypes,
+            sample_inputs_func=sample_inputs_func,
+            reference_inputs_func=reference_inputs_func,
+            **kwargs,
+        )
+        self.domain = domain
+        self.handles_complex_extremal_values = handles_complex_extremal_values
+        self.handles_large_floats = handles_large_floats
+        self.supports_complex_to_float = supports_complex_to_float
+        self.reference_numerics_filter = reference_numerics_filter
+
+        # test_unary_ufuncs.py generates its own inputs to test the consistency
+        # of the operator on sliced tensors, non-contig tensors, etc.
+        # `sample_kwargs` is a utility function to provide kwargs
+        # along with those inputs if required (e.g. clamp).
+        # It should return two dictionaries, first holding kwarg for
+        # torch operator and second one for reference NumPy operator.
+        self.sample_kwargs = sample_kwargs
+
+        # Epsilon to ensure grad and gradgrad checks don't test values
+        #   outside a function's domain.
+        self._domain_eps = 1e-5
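+
+# Illustrative sketch (hypothetical, not from the original module): a clamp-like
+# unary ufunc whose samples need extra kwargs could pass
+#
+#   sample_kwargs=lambda device, dtype, input: (
+#       {"min": -0.5, "max": 0.5},      # kwargs for the torch operator
+#       {"a_min": -0.5, "a_max": 0.5},  # kwargs for the NumPy reference (np.clip)
+#   )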
+
+
+def sample_inputs_spectral_ops(self, device, dtype, requires_grad=False, **kwargs):
+    is_fp16_or_chalf = dtype == torch.complex32 or dtype == torch.half
+    if not is_fp16_or_chalf:
+        nd_tensor = partial(
+            make_tensor,
+            (S, S + 1, S + 2),
+            device=device,
+            dtype=dtype,
+            requires_grad=requires_grad,
+        )
+        oned_tensor = partial(
+            make_tensor, (31,), device=device, dtype=dtype, requires_grad=requires_grad
+        )
+    else:
+        # cuFFT only supports sizes that are powers of 2 for half and complex-half precision
+        # NOTE: For hfft, hfft2, hfftn, irfft, irfft2, irfftn with default args
+        # where output_size n=2*(input_size - 1), we make sure that the logical FFT size is a power of two
+        low = None
+        high = None
+        if self.name in ["fft.hfft", "fft.irfft", "_refs.fft.hfft", "_refs.fft.irfft"]:
+            shapes = ((2, 9, 9), (33,))
+        elif self.name in [
+            "fft.hfft2",
+            "fft.irfft2",
+            "_refs.fft.hfft2",
+            "_refs.fft.irfft2",
+        ]:
+            shapes = ((2, 8, 9), (33,))
+        elif self.name in [
+            "fft.hfftn",
+            "fft.irfftn",
+            "_refs.fft.hfftn",
+            "_refs.fft.irfftn",
+        ]:
+            shapes = ((2, 2, 33), (33,))
+            # Adjusting the limits because the test would be flaky due to over-saturation of float16
+            # See: https://github.com/pytorch/pytorch/pull/81416
+            low = -1.0
+            high = 1.0
+        else:
+            shapes = ((2, 8, 16), (32,))
+        nd_tensor = partial(
+            make_tensor,
+            shapes[0],
+            device=device,
+            low=low,
+            high=high,
+            dtype=dtype,
+            requires_grad=requires_grad,
+        )
+        oned_tensor = partial(
+            make_tensor,
+            shapes[1],
+            device=device,
+            low=low,
+            high=high,
+            dtype=dtype,
+            requires_grad=requires_grad,
+        )
+
+    if self.ndimensional == SpectralFuncType.ND:
+        yield SampleInput(
+            nd_tensor(),
+            s=(3, 10) if not is_fp16_or_chalf else (4, 8),
+            dim=(1, 2),
+            norm="ortho",
+        )
+        yield SampleInput(nd_tensor(), norm="ortho")
+        yield SampleInput(nd_tensor(), s=(8,))
+        yield SampleInput(oned_tensor())
+        yield from (SampleInput(nd_tensor(), dim=dim) for dim in [-1, -2, -3, (0, -1)])
+    elif self.ndimensional == SpectralFuncType.TwoD:
+        yield SampleInput(
+            nd_tensor(),
+            s=(3, 10) if not is_fp16_or_chalf else (4, 8),
+            dim=(1, 2),
+            norm="ortho",
+        )
+        yield SampleInput(nd_tensor(), norm="ortho")
+        yield SampleInput(nd_tensor(), s=(6, 8) if not is_fp16_or_chalf else (4, 8))
+        yield SampleInput(nd_tensor(), dim=0)
+        yield SampleInput(nd_tensor(), dim=(0, -1))
+        yield SampleInput(nd_tensor(), dim=(-3, -2, -1))
+    else:
+        yield SampleInput(
+            nd_tensor(),
+            n=10 if not is_fp16_or_chalf else 8,
+            dim=1,
+            norm="ortho",
+        )
+        yield SampleInput(nd_tensor(), norm="ortho")
+        yield SampleInput(nd_tensor(), n=7 if not is_fp16_or_chalf else 8)
+        yield SampleInput(oned_tensor())
+        yield from (SampleInput(nd_tensor(), dim=dim) for dim in [-1, -2, -3])
+
+
+SpectralFuncType = Enum("SpectralFuncType", ("OneD", "TwoD", "ND"))
+
+
+# Metadata class for Fast Fourier Transforms in torch.fft.
+class SpectralFuncInfo(OpInfo):
+    """Operator information for torch.fft transforms."""
+
+    def __init__(
+        self,
+        name,  # the string name of the function
+        *,
+        ref=None,  # Reference implementation (probably in np.fft namespace)
+        dtypes=floating_and_complex_types(),
+        ndimensional: SpectralFuncType,
+        sample_inputs_func=sample_inputs_spectral_ops,
+        decorators=None,
+        **kwargs,
+    ):
+        self._original_spectral_func_args = dict(locals()).copy()
+        self._original_spectral_func_args.update(kwargs)
+
+        decorators = list(decorators) if decorators is not None else []
+        decorators += [
+            skipCPUIfNoFFT,
+            DecorateInfo(
+                toleranceOverride({torch.chalf: tol(4e-2, 4e-2)}),
+                "TestCommon",
+                "test_complex_half_reference_testing",
+            ),
+        ]
+
+        super().__init__(
+            name=name,
+            dtypes=dtypes,
+            decorators=decorators,
+            sample_inputs_func=sample_inputs_func,
+            **kwargs,
+        )
+        self.ref = ref
+        self.ndimensional = ndimensional
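+
+# Illustrative sketch (hypothetical, not an entry from this module): a one-dimensional
+# transform could be registered roughly as
+#
+#   SpectralFuncInfo(
+#       "fft.fft",
+#       ref=np.fft.fft,
+#       ndimensional=SpectralFuncType.OneD,
+#   )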
+
+
+class ShapeFuncInfo(OpInfo):
+    """Early version of a specialized OpInfo for Shape manipulating operations like tile and roll"""
+
+    def __init__(
+        self,
+        name,  # the string name of the function
+        *,
+        ref,  # a reference function
+        dtypes=floating_types(),
+        dtypesIfCUDA=None,
+        dtypesIfROCM=None,
+        dtypesIfXPU=None,
+        sample_inputs_func=None,
+        **kwargs,
+    ):
+        super().__init__(
+            name,
+            dtypes=dtypes,
+            dtypesIfCUDA=dtypesIfCUDA,
+            dtypesIfROCM=dtypesIfROCM,
+            dtypesIfXPU=dtypesIfXPU,
+            sample_inputs_func=sample_inputs_func,
+            **kwargs,
+        )
+        self.ref = ref
+
+
+def sample_inputs_foreach(
+    self,
+    device,
+    dtype,
+    N,
+    *,
+    noncontiguous=False,
+    same_size=False,
+    low=None,
+    high=None,
+    # zero_size means EVERY input is empty
+    zero_size: bool,
+    requires_grad: bool,
+    # mutually exclusive with same_size and zero_size, which are all or nothing
+    intersperse_empty_tensors: bool = False,
+):
+    if zero_size:
+        return [torch.empty(0, dtype=dtype, device=device) for _ in range(N)]
+    if same_size:
+        return [
+            make_tensor(
+                (N, N),
+                dtype=dtype,
+                device=device,
+                noncontiguous=noncontiguous,
+                low=low,
+                high=high,
+                requires_grad=requires_grad,
+            )
+            for _ in range(N)
+        ]
+    else:
+        # intersperse some empty tensors and make the last 2 tensors empty (see #100701)
+        return [
+            torch.empty(0, dtype=dtype, device=device, requires_grad=requires_grad)
+            if (i % 3 == 0 or i >= N - 2) and intersperse_empty_tensors
+            else make_tensor(
+                (N - i, N - i),
+                dtype=dtype,
+                device=device,
+                noncontiguous=noncontiguous,
+                low=low,
+                high=high,
+                requires_grad=requires_grad,
+            )
+            for i in range(N)
+        ]
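+
+# For example (illustrative only), the call below would return four contiguous
+# 4x4 float tensors, while zero_size=True would instead return four empty tensors:
+#
+#   sample_inputs_foreach(
+#       None, "cpu", torch.float32, 4,
+#       zero_size=False, requires_grad=False, same_size=True,
+#   )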
+
+
+def get_foreach_method_names(name):
+    # gets the torch._foreach_<name> op, its in-place variant, and the torch reference functions
+    op_name = "_foreach_" + name
+    inplace_op_name = op_name + "_"
+
+    op = getattr(torch, op_name, None)
+    inplace_op = getattr(torch, inplace_op_name, None)
+
+    ref = getattr(torch, name, None)
+    ref_inplace = getattr(torch.Tensor, name + "_", None)
+    return op, inplace_op, ref, ref_inplace
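+
+# For example, get_foreach_method_names("add") resolves (assuming the corresponding
+# attributes exist in the installed build) to
+#   (torch._foreach_add, torch._foreach_add_, torch.add, torch.Tensor.add_)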
+
+
+@dataclass
+class ForeachFuncInfo(OpInfo):
+    """Early version of a specialized OpInfo for foreach functions
+
+    The main differences from the parent class are (a) `dtypes`, `dtypesIfCUDA`, and `dtypesIfROCM`
+    are set to `get_all_dtypes(include_qint=False)`, and (b) the following arguments.
+
+    ``supports_alpha_param=True`` means that the function supports a python scalar (``numbers.Number``)
+    as an ``alpha`` keyword argument, as `_foreach_add` does.
+    ``supports_scalar_self_arg=True`` means that the function can take a python scalar as its first argument.
+    Currently only `_foreach_pow` supports this.
+    ``backward_requires_result=True`` means that the function uses
+    the forward result for its backward computation.
+    """
+
+    supports_alpha_param: bool = False
+    supports_scalar_self_arg: bool = False
+    backward_requires_result: bool = False
+
+    def __post_init__(self):
+        (
+            foreach_method,
+            foreach_method_inplace,
+            torch_ref_method,
+            torch_ref_inplace,
+        ) = get_foreach_method_names(self.name)
+        if not self.supports_out:
+            # note(crcrpar): `foreach_method` for `"zero"` is `None` but `None` would call
+            # `_getattr_qual` in `OpInfo.__post_init__` which should fail since `_foreach_zero`
+            # is not defined at the moment. Thus to skip the qualification, set a similar torch
+            # function.
+            assert foreach_method is None
+            assert torch_ref_method is None
+            foreach_method = foreach_method_inplace
+            torch_ref_method = torch_ref_inplace
+
+        # We disable all complex128 tests internally for foreach due to reported flakiness
+        # tracked in #139648
+        supported_dtypes = get_all_dtypes(include_qint=False)
+        if IS_FBCODE:
+            supported_dtypes = [
+                x for x in supported_dtypes if x is not torch.complex128
+            ]
+        self.dtypes = _dispatch_dtypes(supported_dtypes)
+
+        self.op = foreach_method
+        self.method_variant = foreach_method
+        self.ref = torch_ref_method
+        self.inplace_variant = foreach_method_inplace
+        self.ref_inplace = torch_ref_inplace
+        self.has_no_in_place = self.inplace_variant is None
+
+        name = self.name
+        self.name = f"_foreach_{name}"
+        if name == "norm":
+            self.ref = torch.linalg.vector_norm
+        elif name == "minimum":
+            # because minimum ref does not support inplace or scalar
+            self.ref = torch.clamp_max
+            self.ref_inplace = torch.Tensor.clamp_max_
+        elif name == "maximum":
+            # because maximum ref does not support inplace or scalar
+            self.ref = torch.clamp_min
+            self.ref_inplace = torch.Tensor.clamp_min_
+
+        # The following sets `dtypesIfCUDA` and `dtypesIfROCM` accordingly.
+        super().__post_init__()
+
+    def sample_zero_size_inputs(self, device, dtype, requires_grad=False, **kwargs):
+        if not hasattr(self.sample_inputs_func, "sample_zero_size_tensor_inputs"):
+            return []
+        return self.sample_inputs_func.sample_zero_size_tensor_inputs(
+            self, device, dtype, requires_grad, **kwargs
+        )
+
+
+def gradcheck_wrapper_hermitian_input(op, input, *args, **kwargs):
+    """Gradcheck wrapper for functions that take Hermitian matrices as input.
+
+    They require a modified function because the finite-difference algorithm
+    for calculating derivatives does not preserve the Hermitian property of the input.
+    """
+    return op(input + input.mH, *args, **kwargs)
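+
+# Illustrative usage (hypothetical): an OpInfo for an operator that expects Hermitian
+# input, e.g. an eigendecomposition, could be declared with
+#   gradcheck_wrapper=gradcheck_wrapper_hermitian_input
+# so that the effective argument stays Hermitian under gradcheck's perturbations.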
+
+
+def gradcheck_wrapper_ctc_loss(op, input, *args, **kwargs):
+    """Gradcheck wrapper for ctc loss to project onto log-simplex space."""
+    # See https://github.com/pytorch/pytorch/issues/52241
+    return op(input.log_softmax(dim=2), *args, **kwargs)
+
+
+def gradcheck_wrapper_triangular_input(op, *args, upper=False, idx=0, **kwargs):
+    """Gradcheck wrapper for functions that take lower or upper triangular matrices as input.
+
+    They require a modified function because the finite-difference algorithm
+    for calculating derivatives does not preserve the triangular property of the input.
+    `idx` is used to specify which `args[idx]` is to be triangularized.
+    """
+    triangular_arg = args[idx].triu() if upper else args[idx].tril()
+    return op(*args[:idx], triangular_arg, *args[idx + 1 :], upper, **kwargs)
+
+
+def gradcheck_wrapper_triangular_input_real_positive_diagonal(
+    op, *args, upper=False, idx=0, **kwargs
+):
+    """Gradcheck wrapper for functions that take lower/upper triangular matrices
+    with real and positive diagonals, for example, cholesky-like operations.
+    """
+    arg = args[idx]
+    arg_diag = arg.diagonal(0, -2, -1)
+    arg_diag_embed = torch.diag_embed(arg_diag)
+    id_diag_tensor = torch.ones_like(arg_diag)
+    id_tensor = torch.diag_embed(id_diag_tensor)
+    # new_arg = arg - diag(arg) + I
+    new_arg = arg - arg_diag_embed + id_tensor
+    return gradcheck_wrapper_triangular_input(
+        op, *args[:idx], new_arg, *args[idx + 1 :], upper=upper, idx=idx, **kwargs
+    )
+
+
+def gradcheck_wrapper_masked_operation(op, input, *args, **kwargs):
+    """Gradcheck wrapper for masked operations.
+
+    When mask is specified, replaces masked-out elements with zeros.
+
+    Use for operations that produce non-finite masked-out elements,
+    for instance, for minimum and maximum reductions.
+    """
+    output = op(input, *args, **kwargs)
+    mask = kwargs.get("mask")
+    if mask is not None:
+        output_mask = torch.masked._output_mask(op, input, *args, **kwargs)
+        output = torch.where(output_mask, output, output.new_zeros([]))
+    return output
+
+
+def gradcheck_wrapper_masked_pointwise_operation(op, input, *args, **kwargs):
+    """Gradcheck wrapper for masked pointwise operations. Assumes that the result
+    will be masked iff both tensors are masked at a specific index
+
+    When mask is specified, replaces masked-out elements with zeros.
+
+    Use for operations that produce non-finite masked-out elements,
+    for instance, for minimum and maximum reductions.
+    """
+    output = op(input, *args, **kwargs)
+    input_mask = kwargs.get("input_mask")
+    other_mask = kwargs.get("other_mask")
+    if input_mask is not None and other_mask is not None:
+        combined_mask = torch.logical_and(input_mask, other_mask)
+        new_kwargs = dict(mask=combined_mask, **kwargs)
+        output_mask = torch.masked._input_mask(input, *args, **new_kwargs)
+        output = torch.where(output_mask, output, output.new_zeros([]))
+    return output
+
+
+def clone_sample(sample, **kwargs):
+    """
+    Given a SampleInput, this function analyzes its input, args and kwargs,
+    and produces a copy with each non-Tensor entry copied by reference
+    and each Tensor entry cloned with `t.detach().clone().requires_grad_(t.requires_grad)`.
+    """
+
+    def clone_tensor(t):
+        if isinstance(t, torch.Tensor):
+            return t.detach().clone().requires_grad_(t.requires_grad)
+        else:
+            return t
+
+    sample_kwargs = kwargs if kwargs else sample.kwargs
+
+    return SampleInput(
+        clone_tensor(sample.input),
+        args=tuple(map(clone_tensor, sample.args)),
+        kwargs={k: clone_tensor(v) for k, v in sample_kwargs.items()},
+    )
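+
+# Illustrative usage (hypothetical): given an existing SampleInput `sample`,
+#   fresh = clone_sample(sample)
+# produces an independent copy whose tensors do not share autograd history with
+# the originals, so it can be mutated or reused across test variants safely.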
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__init__.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f26d3f402e741a54f21a5fca48beded5b0a58aec
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/__init__.py
@@ -0,0 +1,26 @@
+# mypy: ignore-errors
+
+from torch.testing._internal.opinfo.core import OpInfo
+from torch.testing._internal.opinfo.definitions import (
+    _masked,
+    fft,
+    linalg,
+    signal,
+    special,
+)
+
+
+# Operator database
+op_db: list[OpInfo] = [
+    *fft.op_db,
+    *linalg.op_db,
+    *signal.op_db,
+    *special.op_db,
+    *_masked.op_db,
+]
+
+python_ref_db: list[OpInfo] = [
+    *fft.python_ref_db,
+    *linalg.python_ref_db,
+    *special.python_ref_db,
+]
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/_masked.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/_masked.py
new file mode 100644
index 0000000000000000000000000000000000000000..e05299632d04d51bb47ec6d35a51ea76b85b2fa6
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/_masked.py
@@ -0,0 +1,1208 @@
+# mypy: ignore-errors
+
+import unittest
+from collections.abc import Sequence
+from functools import partial
+
+import numpy as np
+
+import torch
+from torch.testing import make_tensor
+from torch.testing._internal.common_device_type import tol, toleranceOverride
+from torch.testing._internal.common_dtype import (
+    all_types_and,
+    all_types_and_complex_and,
+    complex_types,
+    floating_and_complex_types_and,
+    floating_types_and,
+    integral_types,
+)
+from torch.testing._internal.opinfo.core import (
+    DecorateInfo,
+    gradcheck_wrapper_masked_operation,
+    gradcheck_wrapper_masked_pointwise_operation,
+    M,
+    OpInfo,
+    ReductionOpInfo,
+    S,
+    sample_inputs_reduction,
+    SampleInput,
+)
+from torch.testing._internal.opinfo.utils import prod_numpy, reference_reduction_numpy
+
+
+# Used for log_softmax, softmax, softmin
+def sample_inputs_softmax_variant(
+    op_info,
+    device,
+    dtype,
+    requires_grad,
+    with_dtype=False,
+    use_zero_dimensions=True,
+    **kwargs,
+):
+    make_arg = partial(
+        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+    cases = [
+        ((S,), (0,)),
+        ((S, S), (0,)),
+        ((S, S), (1,)),
+        ((S, S), (-1,)),
+        ((S, M, S), (2,)),
+        *([((S, 0, 0), (-1,))] if use_zero_dimensions else []),
+    ]
+    kwargs = dict(dtype=torch.float64) if with_dtype else None
+
+    # PyTorch on XLA throws an error when passed a dim argument for a 0d tensor.
+    # See https://github.com/pytorch/xla/issues/3061 for more details.
+    if torch.device(device).type != "xla":
+        cases.append(((), (0,)))
+
+    return (
+        SampleInput(make_arg(shape), args=dim, kwargs=kwargs) for shape, dim in cases
+    )
+
+
+def _generate_masked_op_mask(input_shape, device, **kwargs):
+    make_arg = partial(
+        make_tensor, dtype=torch.bool, device=device, requires_grad=False
+    )
+    yield None
+    yield make_arg(input_shape)
+    if len(input_shape) > 2:
+        # broadcast last mask dimension:
+        yield make_arg(input_shape[:-1] + (1,))
+        # broadcast middle mask dimension:
+        yield make_arg(input_shape[:1] + (1,) + input_shape[2:])
+        # broadcast first mask dimension:
+        yield make_arg((1,) + input_shape[1:])
+        # mask.ndim < input.ndim
+        yield make_arg(input_shape[1:])
+        # mask.ndim == 1
+        yield make_arg(input_shape[-1:])
+        # masks that require broadcasting of inputs (mask.ndim >
+        # input.ndim) are not supported; however, we may
+        # reconsider this if there is demand for this kind of
+        # degenerate case.
+
+
+def sample_inputs_masked_reduction(op_info, device, dtype, requires_grad, **kwargs):
+    """Sample inputs for masked reduction operators.
+
+    A masked reduction operator is a reduction operator with a trailing
+    optional mask argument. A mask is a bool tensor with the same
+    shape as the input or a shape that is broadcastable to the input shape.
+    """
+    kwargs["supports_multiple_dims"] = op_info.supports_multiple_dims
+
+    for sample_input in sample_inputs_reduction(
+        op_info, device, dtype, requires_grad, **kwargs
+    ):
+        for mask in _generate_masked_op_mask(
+            sample_input.input.shape, device, **kwargs
+        ):
+            sample_input_args, sample_input_kwargs = sample_input.args, dict(
+                mask=mask, **sample_input.kwargs
+            )
+            yield SampleInput(
+                sample_input.input.detach().requires_grad_(requires_grad),
+                args=sample_input_args,
+                kwargs=sample_input_kwargs,
+            )
+            if (
+                not requires_grad
+                and dtype.is_floating_point
+                and sample_input.input.ndim == 2
+                and mask is not None
+                and mask.shape == sample_input.input.shape
+            ):
+                for v in [torch.inf, -torch.inf, torch.nan]:
+                    t = sample_input.input.detach()
+                    t.diagonal(0, -2, -1).fill_(v)
+                    yield SampleInput(
+                        t.requires_grad_(requires_grad),
+                        args=sample_input_args,
+                        kwargs=sample_input_kwargs,
+                    )
+
+
+def sample_inputs_sparse_coo_masked_reduction(
+    op_info, device, dtype, requires_grad, **kwargs
+):
+    """Sample inputs for masked reduction operators that support inputs
+    with sparse coo layouts.
+    """
+    if op_info.supports_sparse:
+        op_name = op_info.name.replace("masked.", "")
+        for sample_input in sample_inputs_masked_reduction(
+            op_info, device, dtype, requires_grad, **kwargs
+        ):
+            mask = sample_input.kwargs.get("mask")
+            if mask is not None:
+                sample_input_kwargs = sample_input.kwargs.copy()
+                sample_input_kwargs.update(mask=mask.to_sparse())
+                yield SampleInput(
+                    sample_input.input.to_sparse(),
+                    args=sample_input.args,
+                    kwargs=sample_input_kwargs,
+                )
+            else:
+                if op_name in {"prod", "amax", "amin"}:
+                    # FIXME: for now reductions with non-zero reduction identity and
+                    # unspecified mask are not supported for sparse COO
+                    # tensors, see torch.masked.prod implementation
+                    # for details.
+                    continue
+                yield SampleInput(
+                    sample_input.input.to_sparse(),
+                    args=sample_input.args,
+                    kwargs=sample_input.kwargs,
+                )
+
+
+def sample_inputs_sparse_csr_masked_reduction(
+    op_info, device, dtype, requires_grad, **kwargs
+):
+    """Sample inputs for masked reduction operators that support inputs
+    with sparse csr layouts.
+    """
+    if op_info.supports_sparse_csr:
+        op_name = op_info.name.replace("masked.", "")
+        for sample_input in sample_inputs_masked_reduction(
+            op_info, device, dtype, requires_grad, **kwargs
+        ):
+            if not (
+                sample_input.input.ndim == 2 and sample_input.kwargs.get("keepdim")
+            ):
+                # - sparse CSR tensors are always 2-D tensors
+                # - masked reduction on CSR tensors is defined only if keepdim is True.
+                continue
+            mask = sample_input.kwargs.get("mask")
+            if mask is not None:
+                sample_input_kwargs = sample_input.kwargs.copy()
+                sample_input_kwargs.update(mask=mask.to_sparse_csr())
+                new_sample = SampleInput(
+                    sample_input.input.to_sparse_csr(),
+                    args=sample_input.args,
+                    kwargs=sample_input_kwargs,
+                )
+            else:
+                if op_name in ["prod", "amax", "amin", "mean"]:
+                    # reductions with non-zero reduction identity and
+                    # unspecified mask are not supported for sparse CSR
+                    # tensors, see torch.masked.prod implementation
+                    # for details.
+                    continue
+                new_sample = SampleInput(
+                    sample_input.input.to_sparse_csr(),
+                    args=sample_input.args,
+                    kwargs=sample_input.kwargs,
+                )
+            yield new_sample
+            if sample_input.kwargs["dim"] == 0:
+                # Reductions of CSR tensors use different implementations for
+                # inner and/or outer dimensions. So, as a minimum of testing CSR
+                # implementations the following kwargs must be generated:
+                #   dict(dim=0, keepdim=True)
+                #   dict(dim=1, keepdim=True)
+                #   dict(dim=(0, 1), keepdim=True)
+                # Here we generate the dim=1 case from the dim=0 case.
+                sample_input_kwargs = new_sample.kwargs.copy()
+                sample_input_kwargs.update(dim=1)
+                yield SampleInput(
+                    new_sample.input.clone(),
+                    args=sample_input.args,
+                    kwargs=sample_input_kwargs,
+                )
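+# A minimal sketch of the CSR coverage generated above: inputs are 2-D, reductions
+# use keepdim=True, and the outer (dim=0), inner (dim=1), and full (dim=(0, 1))
+# reductions are all exercised (illustrative calls via torch.masked.sum):
+#
+#   x = torch.tensor([[1.0, 0.0], [2.0, 3.0]]).to_sparse_csr()
+#   torch.masked.sum(x, dim=0, keepdim=True)
+#   torch.masked.sum(x, dim=1, keepdim=True)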
+
+
+def sample_inputs_masked_norm(op_info, device, dtype, requires_grad, **kwargs):
+    """Sample inputs for masked norm."""
+    for ord in [2.0, 1, float("inf"), float("-inf"), 0]:
+        for sample_input in sample_inputs_masked_reduction(
+            op_info, device, dtype, requires_grad, **kwargs
+        ):
+            sample_input_args = (ord,) + sample_input.args
+            sample_input_kwargs = sample_input.kwargs.copy()
+            yield SampleInput(
+                sample_input.input.clone().requires_grad_(requires_grad),
+                args=sample_input_args,
+                kwargs=sample_input_kwargs,
+            )
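+# The samples above prepend `ord` as the first positional argument after the input,
+# i.e. they correspond to calls of the form (illustrative):
+#
+#   torch.masked.norm(x, 2.0, dim, mask=mask)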
+
+
+def reference_masked_std_var(
+    numpy_fn,
+):
+    ref = reference_reduction_numpy(numpy_fn)
+
+    # Translate unbiased or correction arguments into ddof
+    def func(
+        input,
+        dim=None,
+        unbiased=None,
+        *,
+        correction=None,
+        **kwargs,
+    ):
+        ddof = 1
+        if unbiased is not None:
+            ddof = 1 if unbiased else 0
+        if correction is not None:
+            ddof = correction
+
+        if isinstance(dim, Sequence):
+            dim = tuple(dim)
+
+        return ref(input, dim, ddof=ddof, **kwargs)
+
+    return func
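+# Worked examples of the unbiased/correction -> ddof translation implemented above
+# (np.var and np.std take ddof directly):
+#
+#   unbiased=True   -> ddof=1   (also the default when neither argument is given)
+#   unbiased=False  -> ddof=0
+#   correction=c    -> ddof=c   (correction wins because it is applied last)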
+
+
+def sample_inputs_masked_std_var(op_info, device, dtype, requires_grad, **kwargs):
+    """Sample inputs for masked std/var."""
+    kwargs["supports_multiple_dims"] = op_info.supports_multiple_dims
+    from torch.testing._internal.common_methods_invocations import sample_inputs_std_var
+
+    def masked_samples():
+        for sample_input in sample_inputs_std_var(
+            op_info, device, dtype, requires_grad, **kwargs
+        ):
+            if len(sample_input.args) and isinstance(sample_input.args[0], bool):
+                continue  # masked.{std, var} doesn't support `.var(unbiased)`
+
+            for mask in _generate_masked_op_mask(
+                sample_input.input.shape, device, **kwargs
+            ):
+                sample_input_args, sample_input_kwargs = sample_input.args, dict(
+                    mask=mask, **sample_input.kwargs
+                )
+                yield SampleInput(
+                    sample_input.input.detach().requires_grad_(requires_grad),
+                    args=sample_input_args,
+                    kwargs=sample_input_kwargs,
+                )
+                if (
+                    not requires_grad
+                    and dtype.is_floating_point
+                    and sample_input.input.ndim == 2
+                    and mask is not None
+                    and mask.shape == sample_input.input.shape
+                ):
+                    for v in [torch.inf, -torch.inf, torch.nan]:
+                        t = sample_input.input.detach()
+                        t.diagonal(0, -2, -1).fill_(v)
+                        yield SampleInput(
+                            t.requires_grad_(requires_grad),
+                            args=sample_input_args,
+                            kwargs=sample_input_kwargs,
+                        )
+
+    for sample_input in masked_samples():
+        correction = sample_input.kwargs.get("correction")
+        if correction is None:
+            correction = int(sample_input.kwargs.get("unbiased", True))
+
+        dim = sample_input.kwargs.get("dim", None)
+
+        if sample_input.kwargs.get("mask") is None:
+            orig_count = torch.masked.sum(
+                torch.ones(sample_input.input.shape, dtype=torch.int64),
+                dim,
+                keepdim=True,
+            )
+        else:
+            inmask = torch.masked._input_mask(
+                sample_input.input, *sample_input.args, **sample_input.kwargs
+            )
+            orig_count = torch.masked.sum(
+                inmask.new_ones(sample_input.input.shape, dtype=torch.int64),
+                dim,
+                keepdim=True,
+                mask=inmask,
+            )
+        if orig_count.min() <= correction + 1:
+            # Skip samples that lead to nans in var computation
+            continue
+
+        yield sample_input
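+# Rationale for the count check above: the masked var/std normalizer is based on
+# (count - correction), so slices whose unmasked-element count does not exceed the
+# correction by a safe margin would yield nan/inf reference values; such samples
+# are skipped rather than asserted on.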
+
+
+def sample_inputs_masked_softmax(
+    op_info, device, dtype, requires_grad, with_dtype=False, **kwargs
+):
+    """Sample inputs for masked softmax, log_softmax, and softmin.
+
+    A masked normalization operator is a reduction operator with a trailing
+    optional mask argument. The mask is a bool tensor with the same shape as
+    the input, or with a shape that is broadcastable to the input shape.
+    """
+    for sample_input in sample_inputs_softmax_variant(
+        op_info, device, dtype, requires_grad, with_dtype=with_dtype, **kwargs
+    ):
+        for mask in _generate_masked_op_mask(
+            sample_input.input.shape, device, **kwargs
+        ):
+            yield SampleInput(
+                sample_input.input.clone().requires_grad_(requires_grad),
+                *sample_input.args,
+                mask=mask,
+                **sample_input.kwargs,
+            )
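+# An illustrative call matching the mask semantics described in the docstring above
+# (assuming the public torch.masked.softmax entry point with a trailing mask=):
+#
+#   x = torch.randn(2, 3)
+#   m = torch.tensor([[True, False, True], [True, True, True]])
+#   torch.masked.softmax(x, 1, mask=m)  # masked-out entries do not participate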
+
+
+def sample_inputs_masked_cumops(op_info, device, dtype, requires_grad, **kwargs):
+    """Sample inputs for masked cumsum and cumprod."""
+    for sample_input in sample_inputs_softmax_variant(
+        op_info, device, dtype, requires_grad, **kwargs
+    ):
+        for mask in _generate_masked_op_mask(
+            sample_input.input.shape, device, **kwargs
+        ):
+            if type(mask) != torch.Tensor:
+                continue
+            sample_input_args, sample_input_kwargs = sample_input.args, dict(
+                mask=mask, **sample_input.kwargs
+            )
+            if "keepdim" in sample_input_kwargs:
+                sample_input_kwargs.pop("keepdim")
+            # dimension is required
+            if sample_input_args:
+                dim = sample_input.args[0]
+            else:
+                if "dim" not in sample_input_kwargs:
+                    continue
+                dim = sample_input_kwargs.pop("dim")
+                sample_input_args = (dim,)
+            yield SampleInput(
+                sample_input.input.clone().requires_grad_(requires_grad),
+                *sample_input_args,
+                **sample_input_kwargs,
+            )
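+# The samples above always pass `dim` positionally because the masked cumulative ops
+# require it; an illustrative call:
+#
+#   torch.masked.cumsum(x, 0, mask=m)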
+
+
+def sample_inputs_masked_logaddexp(op_info, device, dtype, requires_grad, **kwargs):
+    """Sample inputs for masked logaddexp."""
+    shapes = [(S,), (S, S), (S, M, S)]
+    input_mask_lists = [
+        list(_generate_masked_op_mask(shape, device, **kwargs)) for shape in shapes
+    ]
+    other_mask_lists = [
+        list(_generate_masked_op_mask(shape, device, **kwargs)) for shape in shapes
+    ]
+
+    make_arg = partial(
+        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
+    )
+    for shape, input_masks, other_masks in zip(
+        shapes, input_mask_lists, other_mask_lists
+    ):
+        for input_mask, other_mask in zip(input_masks, other_masks):
+            yield SampleInput(
+                make_arg(shape),
+                make_arg(shape),
+                input_mask=input_mask,
+                other_mask=other_mask,
+            )
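+# masked.logaddexp is a binary pointwise op, so each operand carries its own mask;
+# the samples above therefore supply both keywords (illustrative call):
+#
+#   torch.masked.logaddexp(a, b, input_mask=ma, other_mask=mb)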
+
+
+def sample_inputs_masked_normalize(op_info, device, dtype, requires_grad, **kwargs):
+    """Sample inputs for masked normalize."""
+    for ord in [2.0, 1, float("inf"), float("-inf"), 0]:
+        for sample_input in sample_inputs_softmax_variant(
+            op_info, device, dtype, requires_grad, use_zero_dimensions=False, **kwargs
+        ):
+            yield SampleInput(
+                sample_input.input.clone().requires_grad_(requires_grad),
+                ord,
+                *sample_input.args,
+                **sample_input.kwargs,
+            )
+
+
+op_db: list[OpInfo] = [
+    ReductionOpInfo(
+        "masked.sum",
+        ref=reference_reduction_numpy(np.sum),
+        method_variant=None,
+        identity=0,
+        nan_policy="propagate",
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        supports_sparse=True,
+        supports_sparse_csr=True,
+        promotes_int_to_int64=True,
+        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+        skips=(
+            DecorateInfo(
+                unittest.skip("Failing on some jobs"),
+                "TestReductions",
+                "test_reference_masked",
+                dtypes=(torch.bool, torch.int8, torch.int16, torch.int32),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            # FIXME: sum reduces all dimensions when dim=[]
+            DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
+            DecorateInfo(
+                unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
+            ),
+            # RuntimeError: undefined value tensor
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+        ),
+        decorators=[
+            DecorateInfo(
+                toleranceOverride(
+                    {
+                        torch.bfloat16: tol(atol=1e-03, rtol=5e-2),
+                        torch.float16: tol(atol=1e-03, rtol=5e-3),
+                    }
+                ),
+                "TestReductions",
+                "test_reference_masked",
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-03)}),
+                "TestReductions",
+                "test_ref_small_input",
+            ),
+            DecorateInfo(
+                toleranceOverride(
+                    {
+                        torch.bfloat16: tol(atol=0.1, rtol=0.1),
+                        torch.float16: tol(atol=5e-3, rtol=5e-3),
+                    }
+                ),
+                "TestMasked",
+                "test_mask_layout",
+            ),
+        ],
+        sample_inputs_func=sample_inputs_masked_reduction,
+        sample_inputs_sparse_coo_func=sample_inputs_sparse_coo_masked_reduction,
+        sample_inputs_sparse_csr_func=sample_inputs_sparse_csr_masked_reduction,
+    ),
+    ReductionOpInfo(
+        "masked.prod",
+        ref=prod_numpy,
+        method_variant=None,
+        identity=1,
+        nan_policy="propagate",
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        supports_sparse=True,
+        supports_sparse_csr=True,
+        promotes_int_to_int64=True,
+        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+            DecorateInfo(
+                unittest.skip("Failing on some jobs"),
+                "TestReductions",
+                "test_reference_masked",
+                dtypes=(torch.bool, torch.int8, torch.int16, torch.int32),
+            ),
+            DecorateInfo(
+                unittest.skip("Failing on some jobs"),
+                "TestReductions",
+                "test_ref_small_input",
+                dtypes=(torch.int8, torch.int16, torch.int32),
+            ),
+            # FIXME: "cuda_scatter_gather_base_kernel_func" not implemented for ... (used for sparse_coo inputs)
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestMasked",
+                "test_mask_layout",
+                device_type="cuda",
+                dtypes=(torch.bool, *integral_types(), *complex_types()),
+            ),
+        ),
+        decorators=[
+            DecorateInfo(
+                toleranceOverride({torch.float16: tol(atol=1e-03, rtol=1e-02)}),
+                "TestReductions",
+                "test_reference_masked",
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float16: tol(atol=1e-03, rtol=1e-03)}),
+                "TestReductions",
+                "test_ref_duplicate_values",
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float16: tol(atol=1e-03, rtol=1e-03)}),
+                "TestReductions",
+                "test_ref_small_input",
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1.5e-03)}),
+                "TestMasked",
+                "test_mask_layout",
+                device_type="cpu",
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float32: tol(atol=1e-05, rtol=1e-05)}),
+                "TestOperators",
+                "test_jvp",
+                device_type="cuda",
+            ),
+        ],
+        sample_inputs_func=sample_inputs_masked_reduction,
+        sample_inputs_sparse_coo_func=sample_inputs_sparse_coo_masked_reduction,
+        sample_inputs_sparse_csr_func=sample_inputs_sparse_csr_masked_reduction,
+    ),
+    OpInfo(
+        "masked.cumsum",
+        dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+        method_variant=None,
+        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+        gradcheck_fast_mode=True,
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
+            DecorateInfo(
+                unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit"
+            ),
+        ),
+        # Can reuse the same inputs; dim is required in both
+        sample_inputs_func=sample_inputs_masked_cumops,
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+    ),
+    OpInfo(
+        "masked.cumprod",
+        dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+        method_variant=None,
+        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+        gradcheck_fast_mode=True,
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
+            DecorateInfo(
+                unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit"
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float32: tol(atol=1e-5, rtol=1e-5)}),
+                "TestCompositeCompliance",
+                "test_backward",
+                device_type="cuda",
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float16: tol(atol=1e-2, rtol=2.6e-3)}),
+                "TestInductorOpInfo",
+                "test_comprehensive",
+                device_type="cuda",
+            ),
+        ),
+        # Can reuse the same inputs; dim is required in both
+        sample_inputs_func=sample_inputs_masked_cumops,
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+    ),
+    ReductionOpInfo(
+        "masked.amax",
+        nan_policy="propagate",
+        supports_out=False,
+        dtypes=all_types_and(torch.float16, torch.bfloat16),
+        supports_sparse=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        supports_sparse_csr=True,
+        ref=reference_reduction_numpy(np.amax),
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            # FIXME: amax reduces all dimensions when dim=[]
+            DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
+            DecorateInfo(
+                unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
+            ),
+            # RuntimeError: Unknown builtin op: aten::iinfo
+            DecorateInfo(
+                unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit"
+            ),
+            # FIXME: "cuda_scatter_gather_base_kernel_func" not implemented for ... (used for sparse_coo inputs)
+            # FIXME: "_segment_reduce_lengths_cpu/cuda" not implemented for ... (used for sparse_csr inputs)
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestMasked",
+                "test_mask_layout",
+                dtypes=(torch.bool, *integral_types(), *complex_types()),
+            ),
+        ),
+        sample_inputs_func=sample_inputs_masked_reduction,
+        sample_inputs_sparse_coo_func=sample_inputs_sparse_coo_masked_reduction,
+        sample_inputs_sparse_csr_func=sample_inputs_sparse_csr_masked_reduction,
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+    ),
+    ReductionOpInfo(
+        "masked.amin",
+        nan_policy="propagate",
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        dtypes=all_types_and(torch.float16, torch.bfloat16),
+        supports_sparse=True,
+        supports_sparse_csr=True,
+        ref=reference_reduction_numpy(np.amin),
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            # FIXME: amin reduces all dimensions when dim=[]
+            DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
+            DecorateInfo(
+                unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
+            ),
+            # RuntimeError: Unknown builtin op: aten::iinfo
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+            # FIXME: "cuda_scatter_gather_base_kernel_func" not implemented for ... (used for sparse_coo inputs)
+            # FIXME: "_segment_reduce_lengths_cpu/cuda" not implemented for ... (used for sparse_csr inputs)
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestMasked",
+                "test_mask_layout",
+                dtypes=(torch.bool, *integral_types(), *complex_types()),
+            ),
+        ),
+        sample_inputs_func=sample_inputs_masked_reduction,
+        sample_inputs_sparse_coo_func=sample_inputs_sparse_coo_masked_reduction,
+        sample_inputs_sparse_csr_func=sample_inputs_sparse_csr_masked_reduction,
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+    ),
+    ReductionOpInfo(
+        "masked.argmax",
+        supports_out=False,
+        supports_multiple_dims=False,
+        supports_autograd=False,
+        dtypes=all_types_and(torch.float16, torch.bfloat16),
+        ref=reference_reduction_numpy(np.argmax, supports_keepdims=False),
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            # initial is not a keyword for argmax
+            DecorateInfo(
+                unittest.expectedFailure, "TestReductions", "test_reference_masked"
+            ),
+            # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+        ),
+        sample_inputs_func=sample_inputs_masked_reduction,
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+    ),
+    ReductionOpInfo(
+        "masked.argmin",
+        supports_out=False,
+        supports_multiple_dims=False,
+        supports_autograd=False,
+        dtypes=all_types_and(torch.float16, torch.bfloat16),
+        ref=reference_reduction_numpy(np.argmin, supports_keepdims=False),
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            # initial is not a keyword for argmin
+            DecorateInfo(
+                unittest.expectedFailure, "TestReductions", "test_reference_masked"
+            ),
+            # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+        ),
+        sample_inputs_func=sample_inputs_masked_reduction,
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+    ),
+    ReductionOpInfo(
+        "masked.mean",
+        ref=reference_reduction_numpy(np.mean)
+        if np.lib.NumpyVersion(np.__version__) >= "1.20.2"
+        else None,
+        method_variant=None,
+        nan_policy="propagate",
+        supports_out=False,
+        supports_sparse_csr=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        promotes_int_to_float=True,
+        dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            # FIXME: sum reduces all dimensions when dim=[]
+            DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
+            DecorateInfo(
+                unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
+            ),
+            # RuntimeError: undefined value tensor
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+            # FIXME: "_segment_reduce_lengths_cpu/cuda" not implemented for ... (used for sparse_csr inputs)
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestMasked",
+                "test_mask_layout",
+                dtypes=(torch.bool, *integral_types(), *complex_types()),
+            ),
+        ),
+        decorators=[
+            DecorateInfo(
+                toleranceOverride(
+                    {
+                        torch.bfloat16: tol(atol=1e-03, rtol=0.05),
+                        torch.float16: tol(atol=1e-03, rtol=1e-03),
+                    }
+                ),
+                "TestReductions",
+                "test_reference_masked",
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float16: tol(atol=1e-03, rtol=1e-03)}),
+                "TestReductions",
+                "test_ref_small_input",
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float16: tol(atol=1e-03, rtol=2e-03)}),
+                "TestSparseCompressed",
+                "test_consistency",
+                device_type="cuda",
+            ),
+        ],
+        sample_inputs_func=sample_inputs_masked_reduction,
+        sample_inputs_sparse_csr_func=sample_inputs_sparse_csr_masked_reduction,
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+    ),
+    OpInfo(
+        "masked.median",
+        dtypes=floating_types_and(torch.bfloat16, torch.float16),
+        method_variant=None,
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
+            DecorateInfo(
+                unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit"
+            ),
+        ),
+        sample_inputs_func=partial(
+            sample_inputs_masked_softmax, use_zero_dimensions=False
+        ),
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+    ),
+    ReductionOpInfo(
+        "masked.norm",
+        identity=0,
+        method_variant=None,
+        nan_policy="propagate",
+        supports_out=False,
+        promotes_int_to_float=True,
+        dtypes=floating_types_and(torch.float16, torch.bfloat16),
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            # FIXME: sum reduces all dimensions when dim=[]
+            DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
+            DecorateInfo(
+                unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
+            ),
+            # torch.jit.frontend.NotSupportedError: Compiled functions
+            # can't take variable number of arguments or use
+            # keyword-only arguments with defaults
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+        ),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_masked_norm,
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+    ),
+    ReductionOpInfo(
+        "masked.var",
+        ref=reference_masked_std_var(np.var)
+        if np.lib.NumpyVersion(np.__version__) >= "1.20.2"
+        else None,
+        method_variant=None,
+        nan_policy="propagate",
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        promotes_int_to_float=True,
+        dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+        skips=(
+            # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestSchemaCheckModeOpInfo",
+                "test_schema_correctness",
+                dtypes=(torch.complex64, torch.complex128),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            # FIXME: sum reduces all dimensions when dim=[]
+            DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
+            DecorateInfo(
+                unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
+            ),
+            # RuntimeError: undefined value tensor
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+        ),
+        decorators=[
+            DecorateInfo(
+                toleranceOverride(
+                    {
+                        torch.float16: tol(atol=1e-02, rtol=1e-02),
+                        torch.bfloat16: tol(atol=1e-03, rtol=1e-03),
+                    }
+                ),
+                "TestReductions",
+                "test_reference_masked",
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
+                "TestReductions",
+                "test_ref_small_input",
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
+                "TestMasked",
+                "test_reference_masked",
+            ),
+            DecorateInfo(
+                toleranceOverride(
+                    {
+                        torch.float16: tol(atol=1e-02, rtol=1e-02),
+                        torch.bfloat16: tol(atol=1e-03, rtol=1e-03),
+                    }
+                ),
+                "TestMasked",
+                "test_reference_masked",
+            ),
+            DecorateInfo(
+                toleranceOverride(
+                    {
+                        torch.float16: tol(atol=4e-5, rtol=2e-2),
+                    }
+                ),
+                "TestInductorOpInfo",
+                "test_comprehensive",
+                device_type="cuda",
+            ),
+        ],
+        sample_inputs_func=sample_inputs_masked_std_var,
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+        check_batched_grad=True,
+    ),
+    ReductionOpInfo(
+        "masked.std",
+        ref=reference_masked_std_var(np.std)
+        if np.lib.NumpyVersion(np.__version__) >= "1.20.2"
+        else None,
+        method_variant=None,
+        nan_policy="propagate",
+        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+        gradcheck_fast_mode=True,
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        promotes_int_to_float=True,
+        dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
+        skips=(
+            # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestSchemaCheckModeOpInfo",
+                "test_schema_correctness",
+                dtypes=(torch.complex64, torch.complex128),
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            # FIXME: sum reduces all dimensions when dim=[]
+            DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
+            DecorateInfo(
+                unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
+            ),
+            # RuntimeError: undefined value tensor
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+        ),
+        decorators=[
+            DecorateInfo(
+                toleranceOverride(
+                    {
+                        torch.bfloat16: tol(atol=1e-02, rtol=1e-02),
+                        torch.float16: tol(atol=1e-02, rtol=1e-02),
+                    }
+                ),
+                "TestReductions",
+                "test_reference_masked",
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
+                "TestReductions",
+                "test_ref_small_input",
+            ),
+            DecorateInfo(
+                toleranceOverride(
+                    {
+                        torch.float16: tol(atol=1e-02, rtol=1e-02),
+                        torch.bfloat16: tol(atol=5e-03, rtol=5e-04),
+                    }
+                ),
+                "TestMasked",
+                "test_reference_masked",
+            ),
+        ],
+        sample_inputs_func=sample_inputs_masked_std_var,
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+        check_batched_grad=True,
+    ),
+    OpInfo(
+        "masked.softmax",
+        method_variant=None,
+        dtypes=floating_types_and(torch.half, torch.bfloat16),
+        sample_inputs_func=sample_inputs_masked_softmax,
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+        ),
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        supports_out=False,
+    ),
+    OpInfo(
+        "masked.log_softmax",
+        method_variant=None,
+        dtypes=floating_types_and(torch.half, torch.bfloat16),
+        sample_inputs_func=sample_inputs_masked_softmax,
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+        ),
+        decorators=[
+            DecorateInfo(
+                toleranceOverride({torch.bfloat16: tol(atol=1e-02, rtol=1e-02)}),
+                "TestMasked",
+                "test_reference_masked",
+            ),
+        ],
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        supports_out=False,
+    ),
+    OpInfo(
+        "masked.softmin",
+        method_variant=None,
+        dtypes=floating_types_and(torch.half, torch.bfloat16),
+        sample_inputs_func=sample_inputs_masked_softmax,
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+            # FIXME:
+            # Mismatched elements: 2 / 2 (100.0%)
+            # Greatest absolute difference: nan at index (0,) (up to 0.0001 allowed)
+            # Greatest relative difference: nan at index (0,) (up to 0.0001 allowed)
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestOperators",
+                "test_vmapvjpvjp",
+                device_type="cpu",
+            ),
+        ),
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        supports_out=False,
+    ),
+    OpInfo(
+        "masked.normalize",
+        method_variant=None,
+        dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
+        sample_inputs_func=sample_inputs_masked_normalize,
+        decorators=[
+            DecorateInfo(
+                toleranceOverride({torch.float16: tol(atol=2e-5, rtol=6e-3)}),
+                "TestInductorOpInfo",
+                "test_comprehensive",
+                device_type="cuda",
+            ),
+        ],
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+        ),
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        supports_out=False,
+    ),
+    OpInfo(
+        "masked.logaddexp",
+        dtypes=floating_types_and(torch.float16, torch.bfloat16),
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        check_batched_forward_grad=False,
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
+            DecorateInfo(
+                unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit"
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"), "TestFwdGradients", "test_fn_gradgrad"
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"), "TestBwdGradients", "test_fn_gradgrad"
+            ),
+        ),
+        sample_inputs_func=sample_inputs_masked_logaddexp,
+        gradcheck_wrapper=gradcheck_wrapper_masked_pointwise_operation,
+    ),
+    ReductionOpInfo(
+        "masked.logsumexp",
+        dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
+        method_variant=None,
+        nan_policy="propagate",
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+            # FIXME: reduces all dimensions when dim=[]
+            DecorateInfo(unittest.skip("Skipped!"), "TestReductions", "test_dim_empty"),
+            DecorateInfo(
+                unittest.skip("Skipped!"), "TestReductions", "test_dim_empty_keepdim"
+            ),
+            # Identity can't be -torch.inf without overflow
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestReductions",
+                "test_empty_tensor_empty_slice",
+            ),
+            # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
+            DecorateInfo(
+                unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit"
+            ),
+            # all the values are the same except for -inf vs nan
+            DecorateInfo(unittest.skip("Skipped!"), "TestDecomp", "test_comprehensive"),
+            # FIXME:
+            # Mismatched elements: 2 / 12 (16.7%)
+            # Greatest absolute difference: 9223372034707292160 at index (0, 0, 0, 0)
+            # Greatest relative difference: 0.0 at index (0, 0, 0, 1)
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestInductorOpInfo",
+                "test_comprehensive",
+                device_type="cpu",
+            ),
+        ),
+        sample_inputs_func=sample_inputs_masked_reduction,
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+    ),
+]
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/fft.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/fft.py
new file mode 100644
index 0000000000000000000000000000000000000000..8293fca978f262d7bf6eea6b546b2c3cd500f227
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/fft.py
@@ -0,0 +1,809 @@
+# mypy: ignore-errors
+
+import unittest
+from functools import partial
+
+import numpy as np
+
+import torch
+from torch.testing import make_tensor
+from torch.testing._internal.common_cuda import SM53OrLater
+from torch.testing._internal.common_device_type import precisionOverride
+from torch.testing._internal.common_dtype import (
+    all_types_and,
+    all_types_and_complex_and,
+)
+from torch.testing._internal.common_utils import TEST_SCIPY, TEST_WITH_ROCM
+from torch.testing._internal.opinfo.core import (
+    DecorateInfo,
+    ErrorInput,
+    OpInfo,
+    sample_inputs_spectral_ops,
+    SampleInput,
+    SpectralFuncInfo,
+    SpectralFuncType,
+)
+from torch.testing._internal.opinfo.refs import (
+    _find_referenced_opinfo,
+    _inherit_constructor_args,
+    PythonRefInfo,
+)
+
+
+has_scipy_fft = False
+if TEST_SCIPY:
+    try:
+        import scipy.fft
+
+        has_scipy_fft = True
+    except ModuleNotFoundError:
+        pass
+
+
+class SpectralFuncPythonRefInfo(SpectralFuncInfo):
+    """
+    An OpInfo for a Python reference of a spectral function.
+    """
+
+    def __init__(
+        self,
+        name,  # the string name of the callable Python reference
+        *,
+        op=None,  # the function variant of the operation, populated as torch. if None
+        torch_opinfo_name,  # the string name of the corresponding torch opinfo
+        torch_opinfo_variant="",
+        **kwargs,
+    ):  # additional kwargs override kwargs inherited from the torch opinfo
+        self.torch_opinfo_name = torch_opinfo_name
+        self.torch_opinfo = _find_referenced_opinfo(
+            torch_opinfo_name, torch_opinfo_variant, op_db=op_db
+        )
+        assert isinstance(self.torch_opinfo, SpectralFuncInfo)
+
+        inherited = self.torch_opinfo._original_spectral_func_args
+        ukwargs = _inherit_constructor_args(name, op, inherited, kwargs)
+
+        super().__init__(**ukwargs)
+
+
+def error_inputs_fft(op_info, device, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
+    # Zero-dimensional tensor has no dimension to take FFT of
+    yield ErrorInput(
+        SampleInput(make_arg()),
+        error_type=IndexError,
+        error_regex="Dimension specified as -1 but tensor has no dimensions",
+    )
+
+
+def error_inputs_fftn(op_info, device, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
+    # Specifying a dimension on a zero-dimensional tensor
+    yield ErrorInput(
+        SampleInput(make_arg(), dim=(0,)),
+        error_type=IndexError,
+        error_regex="Dimension specified as 0 but tensor has no dimensions",
+    )
+
+
+def sample_inputs_fft_with_min(
+    op_info, device, dtype, requires_grad=False, *, min_size, **kwargs
+):
+    yield from sample_inputs_spectral_ops(
+        op_info, device, dtype, requires_grad, **kwargs
+    )
+    if TEST_WITH_ROCM:
+        # FIXME: Causes floating point exception on ROCm
+        return
+
+    # Check the "Invalid number of data points" error isn't too strict
+    # https://github.com/pytorch/pytorch/pull/109083
+    a = make_tensor(min_size, dtype=dtype, device=device, requires_grad=requires_grad)
+    yield SampleInput(a)
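+# The extra sample above is the smallest legal input for the transform; e.g. for
+# min_size=1 it corresponds to a call such as torch.fft.fft(torch.randn(1)), which
+# should pass rather than trip the "Invalid number of data points" check.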
+
+
+def sample_inputs_fftshift(op_info, device, dtype, requires_grad, **kwargs):
+    def mt(shape, **kwargs):
+        return make_tensor(
+            shape, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs
+        )
+
+    yield SampleInput(mt((9, 10)))
+    yield SampleInput(mt((50,)), kwargs=dict(dim=0))
+    yield SampleInput(mt((5, 11)), kwargs=dict(dim=(1,)))
+    yield SampleInput(mt((5, 6)), kwargs=dict(dim=(0, 1)))
+    yield SampleInput(mt((5, 6, 2)), kwargs=dict(dim=(0, 2)))
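+# torch.fft.fftshift only reorders entries (rolling the zero-frequency term to the
+# center of the selected dimensions), so the mix of shapes and dim choices above
+# gives sufficient coverage.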
+
+
+# Operator database
+op_db: list[OpInfo] = [
+    SpectralFuncInfo(
+        "fft.fft",
+        aten_name="fft_fft",
+        decomp_aten_name="_fft_c2c",
+        ref=np.fft.fft,
+        ndimensional=SpectralFuncType.OneD,
+        dtypes=all_types_and_complex_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and_complex_and(
+            torch.bool,
+            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
+        ),
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=1),
+        error_inputs_func=error_inputs_fft,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+    ),
+    SpectralFuncInfo(
+        "fft.fft2",
+        aten_name="fft_fft2",
+        ref=np.fft.fft2,
+        decomp_aten_name="_fft_c2c",
+        ndimensional=SpectralFuncType.TwoD,
+        dtypes=all_types_and_complex_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and_complex_and(
+            torch.bool,
+            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
+        ),
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)),
+        error_inputs_func=error_inputs_fftn,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        decorators=[precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4})],
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_complex_half_reference_testing",
+                device_type="cuda",
+                dtypes=[torch.complex32],
+                active_if=TEST_WITH_ROCM,
+            ),
+        ),
+    ),
+    SpectralFuncInfo(
+        "fft.fftn",
+        aten_name="fft_fftn",
+        decomp_aten_name="_fft_c2c",
+        ref=np.fft.fftn,
+        ndimensional=SpectralFuncType.ND,
+        dtypes=all_types_and_complex_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and_complex_and(
+            torch.bool,
+            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
+        ),
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)),
+        error_inputs_func=error_inputs_fftn,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        decorators=[precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4})],
+    ),
+    SpectralFuncInfo(
+        "fft.hfft",
+        aten_name="fft_hfft",
+        decomp_aten_name="_fft_c2r",
+        ref=np.fft.hfft,
+        ndimensional=SpectralFuncType.OneD,
+        dtypes=all_types_and_complex_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and_complex_and(
+            torch.bool,
+            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
+        ),
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=2),
+        error_inputs_func=error_inputs_fft,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        check_batched_gradgrad=False,
+        skips=(
+            # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestSchemaCheckModeOpInfo",
+                "test_schema_correctness",
+                dtypes=(torch.complex64, torch.complex128),
+            ),
+        ),
+    ),
+    SpectralFuncInfo(
+        "fft.hfft2",
+        aten_name="fft_hfft2",
+        decomp_aten_name="_fft_c2r",
+        ref=scipy.fft.hfft2 if has_scipy_fft else None,
+        ndimensional=SpectralFuncType.TwoD,
+        dtypes=all_types_and_complex_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and_complex_and(
+            torch.bool,
+            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
+        ),
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(2, 2)),
+        error_inputs_func=error_inputs_fftn,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        check_batched_gradgrad=False,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 2e-4, torch.cfloat: 2e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            ),
+        ],
+        skips=(
+            # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestSchemaCheckModeOpInfo",
+                "test_schema_correctness",
+            ),
+            # FIXME: errors are too large; needs investigation
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_complex_half_reference_testing",
+                device_type="cuda",
+            ),
+        ),
+    ),
+    SpectralFuncInfo(
+        "fft.hfftn",
+        aten_name="fft_hfftn",
+        decomp_aten_name="_fft_c2r",
+        ref=scipy.fft.hfftn if has_scipy_fft else None,
+        ndimensional=SpectralFuncType.ND,
+        dtypes=all_types_and_complex_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and_complex_and(
+            torch.bool,
+            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
+        ),
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(2, 2)),
+        error_inputs_func=error_inputs_fftn,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        check_batched_gradgrad=False,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 2e-4, torch.cfloat: 2e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            ),
+        ],
+        skips=(
+            # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestSchemaCheckModeOpInfo",
+                "test_schema_correctness",
+            ),
+        ),
+    ),
+    SpectralFuncInfo(
+        "fft.rfft",
+        aten_name="fft_rfft",
+        decomp_aten_name="_fft_r2c",
+        ref=np.fft.rfft,
+        ndimensional=SpectralFuncType.OneD,
+        dtypes=all_types_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and(
+            torch.bool, *(() if (not SM53OrLater) else (torch.half,))
+        ),
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=1),
+        error_inputs_func=error_inputs_fft,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        check_batched_grad=False,
+        skips=(),
+        check_batched_gradgrad=False,
+    ),
+    SpectralFuncInfo(
+        "fft.rfft2",
+        aten_name="fft_rfft2",
+        decomp_aten_name="_fft_r2c",
+        ref=np.fft.rfft2,
+        ndimensional=SpectralFuncType.TwoD,
+        dtypes=all_types_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and(
+            torch.bool, *(() if (not SM53OrLater) else (torch.half,))
+        ),
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)),
+        error_inputs_func=error_inputs_fftn,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        decorators=[
+            precisionOverride({torch.float: 1e-4}),
+        ],
+    ),
+    SpectralFuncInfo(
+        "fft.rfftn",
+        aten_name="fft_rfftn",
+        decomp_aten_name="_fft_r2c",
+        ref=np.fft.rfftn,
+        ndimensional=SpectralFuncType.ND,
+        dtypes=all_types_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and(
+            torch.bool, *(() if (not SM53OrLater) else (torch.half,))
+        ),
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)),
+        error_inputs_func=error_inputs_fftn,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        decorators=[
+            precisionOverride({torch.float: 1e-4}),
+        ],
+    ),
+    SpectralFuncInfo(
+        "fft.ifft",
+        aten_name="fft_ifft",
+        decomp_aten_name="_fft_c2c",
+        ref=np.fft.ifft,
+        ndimensional=SpectralFuncType.OneD,
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=1),
+        error_inputs_func=error_inputs_fft,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        dtypes=all_types_and_complex_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and_complex_and(
+            torch.bool,
+            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
+        ),
+    ),
+    SpectralFuncInfo(
+        "fft.ifft2",
+        aten_name="fft_ifft2",
+        decomp_aten_name="_fft_c2c",
+        ref=np.fft.ifft2,
+        ndimensional=SpectralFuncType.TwoD,
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)),
+        error_inputs_func=error_inputs_fftn,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        dtypes=all_types_and_complex_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and_complex_and(
+            torch.bool,
+            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
+        ),
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            )
+        ],
+    ),
+    SpectralFuncInfo(
+        "fft.ifftn",
+        aten_name="fft_ifftn",
+        decomp_aten_name="_fft_c2c",
+        ref=np.fft.ifftn,
+        ndimensional=SpectralFuncType.ND,
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)),
+        error_inputs_func=error_inputs_fftn,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        dtypes=all_types_and_complex_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and_complex_and(
+            torch.bool,
+            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
+        ),
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            )
+        ],
+    ),
+    SpectralFuncInfo(
+        "fft.ihfft",
+        aten_name="fft_ihfft",
+        decomp_aten_name="_fft_r2c",
+        ref=np.fft.ihfft,
+        ndimensional=SpectralFuncType.OneD,
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)),
+        error_inputs_func=error_inputs_fft,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        dtypes=all_types_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and(
+            torch.bool, *(() if (not SM53OrLater) else (torch.half,))
+        ),
+        skips=(),
+        check_batched_grad=False,
+    ),
+    SpectralFuncInfo(
+        "fft.ihfft2",
+        aten_name="fft_ihfft2",
+        decomp_aten_name="_fft_r2c",
+        ref=scipy.fft.ihfftn if has_scipy_fft else None,
+        ndimensional=SpectralFuncType.TwoD,
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)),
+        error_inputs_func=error_inputs_fftn,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        dtypes=all_types_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and(
+            torch.bool, *(() if (not SM53OrLater) else (torch.half,))
+        ),
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        decorators=(
+            # The values for attribute 'shape' do not match: torch.Size([5, 6, 5]) != torch.Size([5, 6, 6]).
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out_warning"),
+            DecorateInfo(
+                precisionOverride({torch.float: 2e-4}), "TestFFT", "test_reference_nd"
+            ),
+            # Mismatched elements!
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out"),
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out_warning"),
+        ),
+    ),
+    SpectralFuncInfo(
+        "fft.ihfftn",
+        aten_name="fft_ihfftn",
+        decomp_aten_name="_fft_r2c",
+        ref=scipy.fft.ihfftn if has_scipy_fft else None,
+        ndimensional=SpectralFuncType.ND,
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 1)),
+        error_inputs_func=error_inputs_fftn,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        dtypes=all_types_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and(
+            torch.bool, *(() if (not SM53OrLater) else (torch.half,))
+        ),
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        decorators=[
+            # The values for attribute 'shape' do not match: torch.Size([5, 6, 5]) != torch.Size([5, 6, 6]).
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out_warning"),
+            # Mismatched elements!
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out"),
+            DecorateInfo(
+                precisionOverride({torch.float: 2e-4}), "TestFFT", "test_reference_nd"
+            ),
+        ],
+    ),
+    SpectralFuncInfo(
+        "fft.irfft",
+        aten_name="fft_irfft",
+        decomp_aten_name="_fft_c2r",
+        ref=np.fft.irfft,
+        ndimensional=SpectralFuncType.OneD,
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 2)),
+        error_inputs_func=error_inputs_fft,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        dtypes=all_types_and_complex_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and_complex_and(
+            torch.bool,
+            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
+        ),
+        check_batched_gradgrad=False,
+    ),
+    SpectralFuncInfo(
+        "fft.irfft2",
+        aten_name="fft_irfft2",
+        decomp_aten_name="_fft_c2r",
+        ref=np.fft.irfft2,
+        ndimensional=SpectralFuncType.TwoD,
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 2)),
+        error_inputs_func=error_inputs_fftn,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        dtypes=all_types_and_complex_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and_complex_and(
+            torch.bool,
+            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
+        ),
+        check_batched_gradgrad=False,
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            )
+        ],
+    ),
+    SpectralFuncInfo(
+        "fft.irfftn",
+        aten_name="fft_irfftn",
+        decomp_aten_name="_fft_c2r",
+        ref=np.fft.irfftn,
+        ndimensional=SpectralFuncType.ND,
+        sample_inputs_func=partial(sample_inputs_fft_with_min, min_size=(1, 2)),
+        error_inputs_func=error_inputs_fftn,
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        dtypes=all_types_and_complex_and(torch.bool),
+        # CUDA supports Half/ComplexHalf Precision FFT only on SM53 or later archs
+        dtypesIfCUDA=all_types_and_complex_and(
+            torch.bool,
+            *(() if (not SM53OrLater) else (torch.half, torch.complex32)),
+        ),
+        check_batched_gradgrad=False,
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            )
+        ],
+    ),
+    OpInfo(
+        "fft.fftshift",
+        dtypes=all_types_and_complex_and(
+            torch.bool, torch.bfloat16, torch.half, torch.chalf
+        ),
+        sample_inputs_func=sample_inputs_fftshift,
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+    ),
+    OpInfo(
+        "fft.ifftshift",
+        dtypes=all_types_and_complex_and(
+            torch.bool, torch.bfloat16, torch.half, torch.chalf
+        ),
+        sample_inputs_func=sample_inputs_fftshift,
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+    ),
+]
+
+python_ref_db: list[OpInfo] = [
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.fft",
+        torch_opinfo_name="fft.fft",
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.ifft",
+        torch_opinfo_name="fft.ifft",
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.rfft",
+        torch_opinfo_name="fft.rfft",
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.irfft",
+        torch_opinfo_name="fft.irfft",
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.hfft",
+        torch_opinfo_name="fft.hfft",
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.ihfft",
+        torch_opinfo_name="fft.ihfft",
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.fftn",
+        torch_opinfo_name="fft.fftn",
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            )
+        ],
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.ifftn",
+        torch_opinfo_name="fft.ifftn",
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            )
+        ],
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.rfftn",
+        torch_opinfo_name="fft.rfftn",
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.irfftn",
+        torch_opinfo_name="fft.irfftn",
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            )
+        ],
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.hfftn",
+        torch_opinfo_name="fft.hfftn",
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 2e-4, torch.cfloat: 2e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            )
+        ],
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.ihfftn",
+        torch_opinfo_name="fft.ihfftn",
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 2e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            ),
+            # AssertionError: Reference result was farther (0.09746177145360499) from the precise
+            # computation than the torch result was (0.09111555632069855)
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_python_ref_torch_fallback",
+                dtypes=(torch.float16,),
+                device_type="cuda",
+            ),
+            # AssertionError: Reference result was farther (0.0953431016138116) from the precise
+            # computation than the torch result was (0.09305490684430734)
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_python_ref_executor",
+                dtypes=(torch.float16,),
+                device_type="cuda",
+            ),
+        ],
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.fft2",
+        torch_opinfo_name="fft.fft2",
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.ifft2",
+        torch_opinfo_name="fft.ifft2",
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            )
+        ],
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.rfft2",
+        torch_opinfo_name="fft.rfft2",
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.irfft2",
+        torch_opinfo_name="fft.irfft2",
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            )
+        ],
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.hfft2",
+        torch_opinfo_name="fft.hfft2",
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 2e-4, torch.cfloat: 2e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            )
+        ],
+    ),
+    SpectralFuncPythonRefInfo(
+        "_refs.fft.ihfft2",
+        torch_opinfo_name="fft.ihfft2",
+        decorators=[
+            DecorateInfo(
+                precisionOverride({torch.float: 2e-4}),
+                "TestFFT",
+                "test_reference_nd",
+            ),
+            # FIXME:
+            # Reference result was farther (0.0953431016138116) from the precise computation
+            # than the torch result was (0.09305490684430734)!
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_python_ref_executor",
+                device_type="cuda",
+            ),
+        ],
+    ),
+    PythonRefInfo(
+        "_refs.fft.fftshift",
+        op_db=op_db,
+        torch_opinfo_name="fft.fftshift",
+    ),
+    PythonRefInfo(
+        "_refs.fft.ifftshift",
+        op_db=op_db,
+        torch_opinfo_name="fft.ifftshift",
+    ),
+]
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/linalg.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/linalg.py
new file mode 100644
index 0000000000000000000000000000000000000000..9eeacf887084be7ace0555ff580be9f6ce517146
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/linalg.py
@@ -0,0 +1,2369 @@
+# mypy: ignore-errors
+
+import itertools
+import random
+import unittest
+from collections.abc import Iterable
+from functools import partial
+from itertools import chain, product
+
+import numpy as np
+from numpy import inf
+
+import torch
+from torch.testing import make_tensor
+from torch.testing._internal.common_cuda import (
+    _get_magma_version,
+    _get_torch_cuda_version,
+    with_tf32_off,
+)
+from torch.testing._internal.common_device_type import (
+    has_cusolver,
+    skipCPUIfNoLapack,
+    skipCUDAIf,
+    skipCUDAIfNoCusolver,
+    skipCUDAIfNoMagma,
+    skipCUDAIfNoMagmaAndNoCusolver,
+    skipCUDAIfNoMagmaAndNoLinalgsolver,
+    skipCUDAIfRocm,
+    tol,
+    toleranceOverride,
+)
+from torch.testing._internal.common_dtype import (
+    all_types_and_complex,
+    all_types_and_complex_and,
+    floating_and_complex_types,
+    floating_and_complex_types_and,
+)
+from torch.testing._internal.common_utils import (
+    GRADCHECK_NONDET_TOL,
+    make_fullrank_matrices_with_distinct_singular_values,
+    skipIfSlowGradcheckEnv,
+    slowTest,
+    TEST_WITH_ROCM,
+)
+from torch.testing._internal.opinfo.core import (
+    clone_sample,
+    DecorateInfo,
+    ErrorInput,
+    gradcheck_wrapper_hermitian_input,
+    L,
+    M,
+    OpInfo,
+    ReductionOpInfo,
+    S,
+    SampleInput,
+)
+from torch.testing._internal.opinfo.refs import PythonRefInfo, ReductionPythonRefInfo
+
+
+def sample_kwargs_vector_norm(t, **kwargs):
+    # orders with / without identity
+    def ords():
+        has_id = (6, 4, 2, 1, 0, 0.9)
+        no_id = (inf, -2.1, -inf)
+        if t.numel() == 0:
+            dim = kwargs.get("dim")
+            if dim is None:
+                return has_id
+            if not isinstance(dim, Iterable):
+                dim = (dim,)
+            for d in dim:
+                if t.size(d) == 0:
+                    return has_id
+        return has_id + no_id
+
+    return (((), dict(ord=o)) for o in ords())
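+
+
+# Added illustrative sketch (assumption, not used by the generator above): orders
+# such as inf or a negative p have no identity element, so reducing over an empty
+# dimension raises, while e.g. ord=2 on an empty tensor simply returns 0.
+def _example_vector_norm_without_identity():
+    torch.linalg.vector_norm(torch.empty(0), ord=2)  # tensor(0.)
+    try:
+        torch.linalg.vector_norm(torch.empty(0), ord=float("inf"))
+    except RuntimeError:
+        # "cannot compute the inf norm on an empty tensor because the
+        # operation does not have an identity"
+        pass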
+
+
+def sample_inputs_svd(op_info, device, dtype, requires_grad=False, **kwargs):
+    make_fullrank = make_fullrank_matrices_with_distinct_singular_values
+    make_arg = partial(
+        make_fullrank, dtype=dtype, device=device, requires_grad=requires_grad
+    )
+
+    is_linalg_svd = "linalg.svd" in op_info.name
+    batches = [(), (0,), (3,)]
+    ns = [0, 3, 5]
+
+    def uniformize(usv):
+        S = usv[1]
+        k = S.shape[-1]
+        U = usv[0][..., :k]
+        Vh = usv[2] if is_linalg_svd else usv[2].mH
+        Vh = Vh[..., :k, :]
+        return U, S, Vh
+
+    def fn_U(usv):
+        U, _, _ = uniformize(usv)
+        return U.abs()
+
+    def fn_S(usv):
+        return uniformize(usv)[1]
+
+    def fn_Vh(usv):
+        # We also return S to test
+        _, S, Vh = uniformize(usv)
+        return S, Vh.abs()
+
+    def fn_UVh(usv):
+        U, S, Vh = uniformize(usv)
+        return U @ Vh, S
+
+    fns = (fn_U, fn_S, fn_Vh, fn_UVh)
+
+    fullmat = "full_matrices" if is_linalg_svd else "some"
+
+    for batch, n, k, fullmat_val, fn in product(batches, ns, ns, (True, False), fns):
+        shape = batch + (n, k)
+        yield SampleInput(
+            make_arg(*shape), kwargs={fullmat: fullmat_val}, output_process_fn_grad=fn
+        )
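+
+
+# Added illustrative note (assumption, not used above): U and Vh from an SVD are only
+# determined up to a per-column sign/phase, which is why the output-processing
+# functions above take abs() or the gauge-invariant product U @ Vh for gradcheck.
+def _example_svd_gauge_invariance():
+    A = torch.randn(4, 3, dtype=torch.double)
+    U, S, Vh = torch.linalg.svd(A, full_matrices=False)
+    D = torch.diag(torch.tensor([1.0, -1.0, 1.0], dtype=torch.double))
+    # Flipping matched column/row signs leaves the product unchanged.
+    return torch.allclose((U @ D) @ (D @ Vh), U @ Vh)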
+
+
+def sample_inputs_cross(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(
+        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
+    )
+    yield SampleInput(make_arg((S, 3)), args=(make_arg((S, 3)),))
+    yield SampleInput(
+        make_arg((S, 3, S)), args=(make_arg((S, 3, S)),), kwargs=dict(dim=1)
+    )
+    yield SampleInput(make_arg((1, 3)), args=(make_arg((S, 3)),), kwargs=dict(dim=-1))
+
+
+def error_inputs_cross(op_info, device, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
+
+    sample = SampleInput(input=make_arg((S, 3)), args=(make_arg((S, 1)),))
+    err = "inputs dimension -1 must have length 3"
+    yield ErrorInput(sample, error_regex=err, error_type=RuntimeError)
+
+    sample = SampleInput(input=make_arg((5, S, 3)), args=(make_arg((S, 3)),))
+    err = "inputs must have the same number of dimensions"
+    yield ErrorInput(sample, error_regex=err, error_type=RuntimeError)
+
+    sample = SampleInput(input=make_arg((S, 2)), args=(make_arg((S, 2)),))
+    err = "must have length 3"
+    yield ErrorInput(sample, error_regex=err, error_type=RuntimeError)
+
+    sample = SampleInput(
+        input=make_arg((S, 2)), args=(make_arg((S, 2)),), kwargs=dict(dim=2)
+    )
+    err = "Dimension out of range"
+    yield ErrorInput(sample, error_regex=err, error_type=IndexError)
+
+
+def sample_inputs_householder_product(op_info, device, dtype, requires_grad, **kwargs):
+    """
+    This function generates input for torch.linalg.householder_product (torch.orgqr).
+    The first argument should be a square matrix or batch of square matrices, the second argument is a vector or batch of vectors.
+    Empty, square, rectangular, batched square and batched rectangular input is generated.
+    """
+    make_arg = partial(
+        make_tensor,
+        device=device,
+        dtype=dtype,
+        requires_grad=requires_grad,
+        low=-2,
+        high=2,
+    )
+    # Each column of the matrix is getting multiplied many times leading to very large values for
+    # the Jacobian matrix entries and making the finite-difference result of grad check less accurate.
+    # That's why gradcheck with the default range [-9, 9] fails and [-2, 2] is used here.
+    yield SampleInput(make_arg((S, S)), make_arg((S,)))
+    yield SampleInput(make_arg((S + 1, S)), make_arg((S,)))
+    yield SampleInput(make_arg((2, 1, S, S)), make_arg((2, 1, S)))
+    yield SampleInput(make_arg((2, 1, S + 1, S)), make_arg((2, 1, S)))
+    yield SampleInput(
+        make_arg((0, 0), low=None, high=None),
+        make_arg((0,), low=None, high=None),
+    )
+    yield SampleInput(make_arg((S, S)), make_arg((0,), low=None, high=None))
+    # m = n = S, k = S - 2
+    yield SampleInput(make_arg((S, S)), make_arg((S - 2,), low=None, high=None))
+    # m = S, n = S -1, k = S - 2
+    yield SampleInput(make_arg((S, S - 1)), make_arg((S - 2,), low=None, high=None))
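+
+
+# Added illustrative sketch (assumption, not part of the generator): these samples
+# mirror the typical use of householder_product, reconstructing Q from the compact
+# representation returned by torch.geqrf.
+def _example_householder_product():
+    A = torch.randn(5, 5)
+    a, tau = torch.geqrf(A)
+    # Q is the orthogonal factor of the QR decomposition of A.
+    return torch.linalg.householder_product(a, tau)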
+
+
+def sample_inputs_linalg_matrix_power(op_info, device, dtype, requires_grad, **kwargs):
+    make_fullrank = make_fullrank_matrices_with_distinct_singular_values
+    make_arg = partial(
+        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
+    )
+    make_arg_fullrank = partial(
+        make_fullrank, dtype=dtype, device=device, requires_grad=requires_grad
+    )
+    # (<matrix_size>, (<batch_sizes, ...>))
+    test_sizes = [
+        (1, ()),
+        (2, (0,)),
+        (2, (2,)),
+    ]
+
+    for matrix_size, batch_sizes in test_sizes:
+        size = batch_sizes + (matrix_size, matrix_size)
+        for n in (0, 3, 5):
+            yield SampleInput(make_arg(size), args=(n,))
+        for n in [-4, -2, -1]:
+            yield SampleInput(make_arg_fullrank(*size), args=(n,))
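+
+
+# Added illustrative sketch (assumption): negative exponents require an invertible
+# matrix because matrix_power inverts first, which is why full-rank inputs are
+# generated above for n < 0.
+def _example_matrix_power_negative():
+    A = 2.0 * torch.eye(3)
+    # Equivalent to squaring the inverse of A.
+    return torch.linalg.matrix_power(A, -2)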
+
+
+def sample_inputs_linalg_det_logdet_slogdet(
+    op_info, device, dtype, requires_grad, **kwargs
+):
+    make_fullrank = make_fullrank_matrices_with_distinct_singular_values
+    make_arg = partial(
+        make_fullrank, dtype=dtype, device=device, requires_grad=requires_grad
+    )
+    batches = [(), (0,), (3,)]
+    ns = [0, 1, 5]
+
+    is_logdet = op_info.name == "logdet"
+
+    for (
+        batch,
+        n,
+    ) in product(batches, ns):
+        shape = batch + (n, n)
+        A = make_arg(*shape)
+        # Need to make the matrices in A have positive determinant for autograd
+        # To do so, we multiply A by the sign of its determinant to flip it when negative
+        if is_logdet and not A.is_complex() and A.numel() > 0:
+            s = torch.linalg.slogdet(A).sign
+            A = A * s.unsqueeze(-1).unsqueeze(-1)
+            A.requires_grad_(requires_grad)
+        yield SampleInput(A)
+
+
+def sample_inputs_lu_solve(op_info, device, dtype, requires_grad=False, **kwargs):
+    """Samples the inputs for both linalg.lu_solve and lu_solve"""
+    make_fn = make_fullrank_matrices_with_distinct_singular_values
+    make_a = partial(make_fn, dtype=dtype, device=device)
+    make_b = partial(make_tensor, dtype=dtype, device=device)
+
+    def clone(X, requires_grad):
+        Y = X.clone()
+        Y.requires_grad_(requires_grad)
+        return Y
+
+    is_linalg_lu_solve = op_info.name == "linalg.lu_solve"
+
+    batches = ((), (0,), (2,))
+    ns = (3, 1, 0)
+    nrhs = (4, 1, 0)
+
+    for n, batch, rhs in product(ns, batches, nrhs):
+        A = make_a(*(batch + (n, n)))
+        LU, pivots = torch.linalg.lu_factor(A)
+
+        B = make_b(batch + (n, rhs))
+
+        grads = (False,) if not requires_grad else (True, False)
+        # we try all possible combinations of requires_grad for each input
+        for LU_grad, B_grad in product(grads, grads):
+            # when requires_grad == True, at least one input has to have requires_grad enabled
+            if requires_grad and not LU_grad and not B_grad:
+                continue
+
+            if is_linalg_lu_solve:
+                for adjoint, left in product((True, False), repeat=2):
+                    yield SampleInput(
+                        clone(LU, LU_grad),
+                        args=(pivots, clone(B if left else B.mT, B_grad)),
+                        kwargs=dict(adjoint=adjoint, left=left),
+                    )
+            else:
+                yield SampleInput(clone(B, B_grad), args=(clone(LU, LU_grad), pivots))
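+
+
+# Added illustrative sketch (assumption): the factor/pivot pair produced above is the
+# same pair a caller would feed to torch.linalg.lu_solve.
+def _example_lu_solve():
+    A = torch.randn(3, 3)
+    B = torch.randn(3, 2)
+    LU, pivots = torch.linalg.lu_factor(A)
+    # Solves A @ X = B using the precomputed LU factorization.
+    return torch.linalg.lu_solve(LU, pivots, B)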
+
+
+def sample_inputs_linalg_multi_dot(op_info, device, dtype, requires_grad, **kwargs):
+    # Each test case consists of the sizes in the chain of multiplications
+    # e.g. [2, 3, 4, 5] generates matrices (2, 3) @ (3, 4) @ (4, 5)
+    test_cases = [
+        [1, 2, 1],
+        [2, 0, 2],
+        [0, 2, 2],
+        [2, 2, 2, 2],
+        [2, 3, 4, 5],
+        [5, 4, 0, 2],
+        [2, 4, 3, 5, 3, 2],
+    ]
+
+    for sizes in test_cases:
+        tensors = []
+        for size in zip(sizes[:-1], sizes[1:]):
+            t = make_tensor(
+                size, dtype=dtype, device=device, requires_grad=requires_grad
+            )
+            tensors.append(t)
+        yield SampleInput(tensors)
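+
+
+# Added illustrative sketch (assumption): the size chain [2, 3, 4, 5] above
+# corresponds to (2, 3) @ (3, 4) @ (4, 5), so multi_dot returns a (2, 5) result.
+def _example_multi_dot_chain():
+    ts = [torch.randn(2, 3), torch.randn(3, 4), torch.randn(4, 5)]
+    return torch.linalg.multi_dot(ts).shape  # torch.Size([2, 5])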
+
+
+def sample_inputs_linalg_matrix_norm(op_info, device, dtype, requires_grad, **kwargs):
+    low_precision_dtypes = (torch.float16, torch.bfloat16, torch.complex32)
+    make_arg = partial(
+        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+
+    sizes = ((2, 2), (2, 3, 2))
+    if dtype in low_precision_dtypes:
+        # svdvals not supported for low precision dtypes
+        ords = ("fro", inf, -inf, 1, -1)
+    else:
+        ords = ("fro", "nuc", inf, -inf, 1, -1, 2, -2)
+    dims = ((-2, -1), (-1, 0))
+
+    for size, ord, dim, keepdim in product(sizes, ords, dims, [True, False]):
+        yield SampleInput(make_arg(size), args=(ord, dim, keepdim))
+
+
+def sample_inputs_linalg_norm(
+    op_info, device, dtype, requires_grad, *, variant=None, **kwargs
+):
+    if variant is not None and variant not in ("subgradient_at_zero",):
+        raise ValueError(
+            f"Unsupported variant, expected variant to be 'subgradient_at_zero' but got: {variant}"
+        )
+
+    test_sizes = [
+        (S,),
+        (0,),
+        (S, S),
+        (0, 0),
+        (S, 0),
+        (0, S),
+        (S, S, S),
+        (0, S, S),
+        (S, 0, S),
+        (0, 0, 0),
+    ]
+
+    vector_ords = (None, 0, 0.5, 1, 2, 3.5, inf, -0.5, -1, -2, -3.5, -inf)
+    if dtype in {torch.float16, torch.bfloat16, torch.complex32}:
+        # svdvals not supported for low precision dtypes
+        matrix_ords = ("fro", inf, -inf, 1, -1)
+    else:
+        matrix_ords = (None, "fro", "nuc", inf, -inf, 1, -1, 2, -2)
+
+    make_arg = partial(
+        make_tensor,
+        dtype=dtype,
+        device=device,
+        requires_grad=requires_grad,
+        low=None,
+        high=None,
+    )
+
+    for test_size in test_sizes:
+        is_vector_norm = len(test_size) == 1
+        is_matrix_norm = len(test_size) == 2
+
+        # IndexError: amax(): Expected reduction dim 0 to have non-zero size.
+        is_valid_for_p2 = is_vector_norm or (test_size[-1] != 0 and test_size[-2] != 0)
+
+        for keepdim in [False, True]:
+            if variant != "subgradient_at_zero" and is_valid_for_p2:
+                yield SampleInput(make_arg(test_size), keepdim=keepdim)
+
+            if not (is_vector_norm or is_matrix_norm):
+                continue
+
+            ords = vector_ords if is_vector_norm else matrix_ords
+
+            for ord in ords:
+                if is_vector_norm and test_size[-1] == 0:
+                    if ord == np.inf or (ord is not None and ord < 0):
+                        # RuntimeError: linalg.vector_norm cannot compute the
+                        # {ord} norm on an empty tensor because the operation
+                        # does not have an identity
+                        continue
+                elif is_matrix_norm:
+                    dims_to_check = {
+                        None: (0,),
+                        np.inf: (0,),
+                        2: (0, 1),
+                        1: (1,),
+                        -1: (1,),
+                        -2: (0, 1),
+                        -np.inf: (0,),
+                    }.get(ord, ())
+
+                    if any(test_size[d] == 0 for d in dims_to_check):
+                        # IndexError: amax(): Expected reduction dim {dim} to
+                        # have non-zero size.
+                        continue
+
+                if variant == "subgradient_at_zero":
+                    yield SampleInput(
+                        torch.zeros(
+                            test_size,
+                            dtype=dtype,
+                            device=device,
+                            requires_grad=requires_grad,
+                        ),
+                        ord,
+                        keepdim=keepdim,
+                    )
+                else:
+                    yield SampleInput(make_arg(test_size), ord, keepdim=keepdim)
+
+                    if ord in ["nuc", "fro"]:
+                        yield SampleInput(
+                            make_arg(test_size), ord=ord, keepdim=keepdim, dim=(0, 1)
+                        )
+
+
+def sample_inputs_linalg_vecdot(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(
+        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+    batches = ((), (0,), (1,), (5,))
+    ns = (0, 1, 3, 5)
+    for b, n in product(batches, ns):
+        shape = b + (n,)
+        yield SampleInput(make_arg(shape), args=(make_arg(shape),))
+        for i in range(len(shape)):
+            yield SampleInput(
+                make_arg(shape), args=(make_arg(shape),), kwargs=dict(dim=i)
+            )
+
+
+def sample_inputs_linalg_invertible(
+    op_info, device, dtype, requires_grad=False, **kwargs
+):
+    """
+    This function generates invertible inputs for linear algebra ops
+    The input is generated as the itertools.product of 'batches' and 'ns'.
+    In total this function generates 8 SampleInputs
+    'batches' cases include:
+        () - single input,
+        (0,) - zero batched dimension,
+        (2,) - batch of two matrices,
+        (1, 1) - 1x1 batch of matrices
+    'ns' gives 0x0 and 5x5 matrices.
+    Zeros in dimensions are edge cases in the implementation and important to test for in order to avoid unexpected crashes.
+    """
+    make_fn = make_fullrank_matrices_with_distinct_singular_values
+    make_arg = partial(make_fn, dtype=dtype, device=device, requires_grad=requires_grad)
+
+    batches = [(), (0,), (2,), (1, 1)]
+    ns = [5, 0]
+
+    for batch, n in product(batches, ns):
+        yield SampleInput(make_arg(*batch, n, n))
+
+
+def sample_inputs_matrix_rank(op_info, device, dtype, requires_grad=False, **kwargs):
+    """
+    This function produces inputs for matrix rank that test
+    all possible combinations for atol and rtol
+    """
+
+    def make_tol_arg(kwarg_type, inp):
+        if kwarg_type == "none":
+            return None
+        if kwarg_type == "float":
+            return 1.0
+        assert kwarg_type == "tensor"
+        return torch.ones(inp.shape[:-2], device=device)
+
+    for tol_type in ["float", "tensor"]:
+        for atol_type, rtol_type in product(["none", tol_type], repeat=2):
+            if (
+                not atol_type and not rtol_type
+            ):  # default behavior; skipped here so it's not tested 2 extra times
+                continue
+            for sample in sample_inputs_linalg_invertible(
+                op_info, device, dtype, requires_grad
+            ):
+                assert sample.kwargs == {}
+                sample.kwargs = {
+                    "atol": make_tol_arg(atol_type, sample.input),
+                    "rtol": make_tol_arg(rtol_type, sample.input),
+                }
+                yield sample
+
+    # default kwargs
+    yield from sample_inputs_linalg_invertible(op_info, device, dtype, requires_grad)
+
+
+def sample_inputs_linalg_pinv_singular(
+    op_info, device, dtype, requires_grad=False, **kwargs
+):
+    """
+    This function produces factors `a` and `b` to generate inputs of the form `a @ b.t()` to
+    test the backward method of `linalg_pinv`. That way we always preserve the rank of the
+    input no matter the perturbations applied to it by the gradcheck.
+    Note that `pinv` is Frechet-differentiable in a rank-preserving neighborhood.
+    """
+    batches = [(), (0,), (2,), (1, 1)]
+    # A size of at least 30 is required to trigger failures in the previous implicit implementation
+    # of pinv's backward method, although it makes the test slow.
+    size = [0, 3, 50]
+
+    for batch, m, n in product(batches, size, size):
+        for k in range(min(3, m, n)):
+            # Note that by making the columns of `a` and `b` orthonormal we make sure that
+            # the product matrix `a @ b.t()` has condition number 1 when restricted to its image
+            a = (
+                torch.rand(*batch, m, k, device=device, dtype=dtype)
+                .qr()
+                .Q.requires_grad_(requires_grad)
+            )
+            b = (
+                torch.rand(*batch, n, k, device=device, dtype=dtype)
+                .qr()
+                .Q.requires_grad_(requires_grad)
+            )
+            yield SampleInput(a, args=(b,))
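+
+
+# Added illustrative sketch (assumption): with orthonormal columns, the nonzero
+# singular values of a @ b.mT are all 1, so the product is perfectly conditioned
+# when restricted to its image.
+def _example_orthonormal_product_conditioning():
+    a = torch.linalg.qr(torch.rand(5, 2)).Q
+    b = torch.linalg.qr(torch.rand(4, 2)).Q
+    return torch.linalg.svdvals(a @ b.mT)  # approximately [1, 1, 0, 0]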
+
+
+def sample_inputs_linalg_cond(op_info, device, dtype, requires_grad=False, **kwargs):
+    make_arg = partial(
+        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
+    )
+
+    # autograd is not supported for inputs with zero number of elements
+    shapes = (
+        (S, S),
+        (2, S, S),
+        (2, 1, S, S),
+    )
+
+    for shape in shapes:
+        yield SampleInput(make_arg(shape))
+
+
+def sample_inputs_linalg_vander(op_info, device, dtype, requires_grad=False, **kwargs):
+    make_arg = partial(
+        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
+    )
+
+    shapes = (
+        (),
+        (1,),
+        (S,),
+        (2, S),
+    )
+
+    for shape in shapes:
+        if len(shape) > 0 and shape[-1] > 1:
+            yield SampleInput(make_arg(shape))
+        n = shape[-1] if len(shape) > 0 else 1
+        for i in range(3):
+            # n-1, n, n+1
+            N = n + i - 1
+            if N < 2:
+                continue
+            yield SampleInput(make_arg(shape), kwargs=dict(N=N))
+
+
+def np_vander_batched(x, N=None):
+    # Wrapper around np.vander that supports batches of 1 dimension (enough for the tests)
+    if x.ndim == 0:
+        x = x[np.newaxis]
+    if x.ndim == 1:
+        y = np.vander(x, N=N, increasing=True)
+        return y
+    else:
+        if N is None:
+            N = x.shape[-1]
+        y = np.vander(x.ravel(), N=N, increasing=True).reshape((*x.shape, N))
+        return y
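+
+
+# Added illustrative sketch (assumption): for a batched input of shape (2, 3) the
+# wrapper returns shape (2, 3, N) with increasing powers along the last axis.
+def _example_np_vander_batched():
+    x = np.array([[1, 2, 3], [4, 5, 6]])
+    return np_vander_batched(x, N=2).shape  # (2, 3, 2)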
+
+
+def sample_inputs_linalg_cholesky_inverse(
+    op_info, device, dtype, requires_grad=False, **kwargs
+):
+    from torch.testing._internal.common_utils import random_well_conditioned_matrix
+
+    # Cholesky factorization is for positive-definite matrices
+    single_well_conditioned_matrix = random_well_conditioned_matrix(
+        S, S, dtype=dtype, device=device
+    )
+    batch_well_conditioned_matrices = random_well_conditioned_matrix(
+        2, S, S, dtype=dtype, device=device
+    )
+    single_pd = single_well_conditioned_matrix @ single_well_conditioned_matrix.mH
+    batch_pd = batch_well_conditioned_matrices @ batch_well_conditioned_matrices.mH
+
+    inputs = (
+        torch.zeros(0, 0, dtype=dtype, device=device),  # 0x0 matrix
+        torch.zeros(0, 2, 2, dtype=dtype, device=device),  # zero batch of matrices
+        single_pd,
+        batch_pd,
+    )
+    test_cases = (torch.linalg.cholesky(a, upper=False) for a in inputs)
+    for l in test_cases:
+        # generated lower-triangular samples
+        l.requires_grad = requires_grad
+        yield SampleInput(l)  # upper=False by default
+        yield SampleInput(
+            l.detach().clone().requires_grad_(requires_grad), kwargs=dict(upper=False)
+        )
+
+        # generate upper-triangular inputs
+        u = l.detach().clone().mT.contiguous().requires_grad_(requires_grad)
+        yield SampleInput(u, kwargs=dict(upper=True))
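+
+
+# Added illustrative sketch (assumption): the samples above are Cholesky factors,
+# which is the form expected by cholesky_inverse.
+def _example_cholesky_inverse():
+    A = torch.randn(4, 4, dtype=torch.double)
+    A = A @ A.mT + 4 * torch.eye(4, dtype=torch.double)  # make A positive definite
+    L = torch.linalg.cholesky(A)  # lower-triangular factor (upper=False)
+    return torch.cholesky_inverse(L)  # approximately torch.inverse(A)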
+
+
+def sample_inputs_linalg_ldl_factor(
+    op_info, device, dtype, requires_grad=False, **kwargs
+):
+    from torch.testing._internal.common_utils import (
+        random_hermitian_pd_matrix,
+        random_symmetric_pd_matrix,
+    )
+
+    device = torch.device(device)
+
+    # Symmetric inputs
+    yield SampleInput(
+        random_symmetric_pd_matrix(S, dtype=dtype, device=device),
+        kwargs=dict(hermitian=False),
+    )  # single matrix
+    yield SampleInput(
+        random_symmetric_pd_matrix(S, 2, dtype=dtype, device=device),
+        kwargs=dict(hermitian=False),
+    )  # batch of matrices
+    yield SampleInput(
+        torch.zeros(0, 0, dtype=dtype, device=device), kwargs=dict(hermitian=False)
+    )  # 0x0 matrix
+    yield SampleInput(
+        torch.zeros(0, 2, 2, dtype=dtype, device=device), kwargs=dict(hermitian=False)
+    )  # zero batch of matrices
+
+    # Hermitian inputs
+    # hermitian=True for complex inputs on CUDA is supported only with MAGMA 2.5.4+
+    magma_254_available = device.type == "cuda" and _get_magma_version() >= (2, 5, 4)
+    if dtype.is_complex and (device.type == "cpu" or magma_254_available):
+        yield SampleInput(
+            random_hermitian_pd_matrix(S, dtype=dtype, device=device),
+            kwargs=dict(hermitian=True),
+        )  # single matrix
+        yield SampleInput(
+            random_hermitian_pd_matrix(S, 2, dtype=dtype, device=device),
+            kwargs=dict(hermitian=True),
+        )  # batch of matrices
+
+
+def sample_inputs_linalg_ldl_solve(
+    op_info, device, dtype, requires_grad=False, **kwargs
+):
+    # Generate LDL factors of symmetric (and Hermitian on CPU) matrices
+    from torch.testing._internal.common_utils import (
+        random_hermitian_pd_matrix,
+        random_symmetric_pd_matrix,
+    )
+
+    device = torch.device(device)
+    symmetric_inputs = (
+        random_symmetric_pd_matrix(S, dtype=dtype, device=device),  # single matrix
+        random_symmetric_pd_matrix(
+            S, 2, dtype=dtype, device=device
+        ),  # batch of matrices
+        torch.zeros(0, 0, dtype=dtype, device=device),  # 0x0 matrix
+        torch.zeros(0, 2, 2, dtype=dtype, device=device),  # zero batch of matrices
+    )
+    hermitian_inputs = (
+        (
+            random_hermitian_pd_matrix(S, dtype=dtype, device=device),
+            random_hermitian_pd_matrix(S, 2, dtype=dtype, device=device),
+        )
+        if device.type == "cpu" and dtype.is_complex
+        else ()
+    )
+    test_cases1 = (
+        torch.linalg.ldl_factor_ex(a, hermitian=False) for a in symmetric_inputs
+    )
+    test_cases2 = (
+        torch.linalg.ldl_factor_ex(a, hermitian=True) for a in hermitian_inputs
+    )
+
+    # Symmetric case
+    make_arg = partial(
+        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+    for test_case in test_cases1:
+        factors, pivots, _ = test_case
+        factors.requires_grad = requires_grad
+        for B_batch_shape in ((), factors.shape[:-2]):
+            B = make_arg((*B_batch_shape, factors.shape[-1], S))
+            yield SampleInput(factors, args=(pivots, B), kwargs=dict(hermitian=False))
+            clone_factors = factors.detach().clone().requires_grad_(requires_grad)
+            yield SampleInput(
+                clone_factors, args=(pivots, B), kwargs=dict(hermitian=False)
+            )
+
+    # Hermitian case
+    for test_case in test_cases2:
+        factors, pivots, _ = test_case
+        factors.requires_grad = requires_grad
+        for B_batch_shape in ((), factors.shape[:-2]):
+            B = make_arg((*B_batch_shape, factors.shape[-1], S))
+            yield SampleInput(factors, args=(pivots, B), kwargs=dict(hermitian=True))
+            clone_factors = factors.detach().clone().requires_grad_(requires_grad)
+            yield SampleInput(
+                clone_factors, args=(pivots, B), kwargs=dict(hermitian=True)
+            )
+
+
+def sample_inputs_linalg_lstsq(op_info, device, dtype, requires_grad=False, **kwargs):
+    from torch.testing._internal.common_utils import random_well_conditioned_matrix
+
+    device = torch.device(device)
+
+    drivers: tuple[str, ...]
+    if device.type == "cuda":
+        drivers = ("gels",)
+    else:
+        drivers = ("gels", "gelsy", "gelss", "gelsd")
+
+    # we generate matrices of shape (..., n + delta, n)
+    deltas: tuple[int, ...]
+    if device.type == "cpu" or has_cusolver():
+        deltas = (-1, 0, +1)
+    # only square systems if Cusolver is not available
+    # because we solve a lstsq problem with a transposed matrix in the backward
+    else:
+        deltas = (0,)
+
+    for batch, driver, delta in product(((), (3,), (3, 3)), drivers, deltas):
+        shape = batch + (3 + delta, 3)
+        a = random_well_conditioned_matrix(*shape, dtype=dtype, device=device)
+        a.requires_grad_(requires_grad)
+        b = make_tensor(
+            shape,
+            dtype=dtype,
+            device=device,
+            low=None,
+            high=None,
+            requires_grad=requires_grad,
+        )
+        yield SampleInput(a, b, driver=driver)
+
+
+def error_inputs_lstsq(op_info, device, **kwargs):
+    zero_d = torch.randn((), device=device)
+    yield ErrorInput(
+        SampleInput(zero_d, args=(zero_d,)),
+        error_type=RuntimeError,
+        error_regex="at least 2 dimensions",
+    )
+
+
+def error_inputs_lstsq_grad_oriented(op_info, device, **kwargs):
+    zero_d = torch.randn((), device=device)
+    yield ErrorInput(
+        SampleInput(zero_d, args=(zero_d, None)),
+        error_type=RuntimeError,
+        error_regex="at least 2 dimensions",
+    )
+
+
+def sample_inputs_diagonal_diag_embed(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(
+        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
+    )
+
+    # Shapes for 2D Tensors
+    shapes_2d = ((S, S), (3, 5), (5, 3))
+
+    # Shapes for 3D Tensors
+    shapes_3d = ((S, S, S),)
+
+    kwargs_2d = ({}, dict(offset=2), dict(offset=2), dict(offset=1))
+    kwargs_3d = (
+        dict(offset=1, dim1=1, dim2=2),
+        dict(offset=2, dim1=0, dim2=1),
+        dict(offset=-2, dim1=0, dim2=1),
+    )
+
+    for shape, kwarg in chain(
+        product(shapes_2d, kwargs_2d), product(shapes_3d, kwargs_3d)
+    ):
+        yield SampleInput(make_arg(shape), kwargs=kwarg)
+
+
+def error_inputs_diagonal_diag_embed(op_info, device, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
+
+    shapes1d = (0, 1, (0,), (1,))
+    shapes2d = ((M, L),)
+    shapes3d = ((M, S, L),)
+
+    kwargs1d = {}
+
+    kwargs2d = (
+        # dim1 == dim2 is not allowed
+        dict(dim1=1, dim2=1),
+        # out of bounds dims are not allowed
+        dict(dim1=10000),
+        dict(dim2=10000),
+    )
+
+    kwargs3d = kwargs2d
+
+    samples1d = product(shapes1d, kwargs1d)
+    samples2d = product(shapes2d, kwargs2d)
+    samples3d = product(shapes3d, kwargs3d)
+
+    for shape, kwargs in chain(samples1d, samples2d, samples3d):
+        arg = make_arg(shape)
+        sample = SampleInput(input=arg, kwargs=kwargs)
+
+        dim1 = kwargs.get("dim1")
+        dim2 = kwargs.get("dim2")
+
+        if "diagonal" in op_info.name:
+            num_dim = arg.dim()
+        elif op_info.name in ("diag_embed", "_refs.diag_embed"):
+            # these are valid inputs for diag_embed
+            if shape in ((0,), (1,)):
+                continue
+            num_dim = arg.dim() + 1
+        else:
+            raise RuntimeError("should be unreachable")
+
+        bound1 = -num_dim
+        bound2 = num_dim - 1
+        dim_range = range(bound1, bound2 + 1)
+        dim1_cond = dim1 and dim1 not in dim_range
+        dim2_cond = dim2 and dim2 not in dim_range
+
+        if dim1 == dim2:
+            err = f"diagonal dimensions cannot be identical {dim1}, {dim2}"
+            yield ErrorInput(sample, error_regex=err, error_type=RuntimeError)
+        elif dim1_cond or dim2_cond:
+            err_dim = dim1 if dim1_cond else dim2
+            err = (
+                r"Dimension out of range \(expected to be in range of "
+                rf"\[{bound1}, {bound2}\], but got {err_dim}\)"
+            )
+            yield ErrorInput(sample, error_regex=err, error_type=IndexError)
+        else:
+            raise RuntimeError("should be unreachable")
+
+
+def sample_inputs_linalg_cholesky(
+    op_info, device, dtype, requires_grad=False, **kwargs
+):
+    """
+    This function always generates positive-definite input for torch.linalg.cholesky using
+    random_hermitian_pd_matrix.
+    The input is generated as the itertools.product of 'batches' and 'ns'.
+    In total this function generates 16 SampleInputs
+    'batches' cases include:
+        () - single input,
+        (0,) - zero batched dimension,
+        (2,) - batch of two matrices,
+        (1, 1) - 1x1 batch of matrices
+    'ns' gives 0x0 and 5x5 matrices.
+    Zeros in dimensions are edge cases in the implementation and important to test for in order to avoid unexpected crashes.
+    """
+    from torch.testing._internal.common_utils import random_hermitian_pd_matrix
+
+    batches = [(), (0,), (2,), (1, 1)]
+    ns = [5, 0]
+    for batch, n, upper in product(batches, ns, [True, False]):
+        a = random_hermitian_pd_matrix(n, *batch, dtype=dtype, device=device)
+        a.requires_grad = requires_grad
+        yield SampleInput(a, upper=upper)
+
+
+def sample_inputs_linalg_eig(op_info, device, dtype, requires_grad=False, **kwargs):
+    """
+    This function generates input for torch.linalg.eig
+    """
+
+    def out_fn(output):
+        return output[0], abs(output[1])
+
+    samples = sample_inputs_linalg_invertible(op_info, device, dtype, requires_grad)
+    for sample in samples:
+        sample.output_process_fn_grad = out_fn
+        yield sample
+
+
+def sample_inputs_linalg_eigh(op_info, device, dtype, requires_grad=False, **kwargs):
+    """
+    This function generates input for torch.linalg.eigh/eigvalsh with UPLO="U" or "L" keyword argument.
+    """
+
+    def out_fn(output):
+        if isinstance(output, tuple):
+            # eigh function
+            return output[0], abs(output[1])
+        else:
+            # eigvalsh function
+            return output
+
+    # Samples do not need to be Hermitian, as we're using gradcheck_wrapper_hermitian_input
+    samples = sample_inputs_linalg_invertible(op_info, device, dtype, requires_grad)
+    for sample in samples:
+        # Note: we cannot use np.random.choice here as TorchDynamo
+        # does not support tensors of strings.
+        sample.kwargs = {"UPLO": random.choice(["L", "U"])}
+        sample.output_process_fn_grad = out_fn
+        yield sample
+
+
+def sample_inputs_linalg_pinv(op_info, device, dtype, requires_grad=False, **kwargs):
+    """
+    This function generates input for torch.linalg.pinv with hermitian=False keyword argument.
+    """
+    for o in sample_inputs_linalg_invertible(
+        op_info, device, dtype, requires_grad, **kwargs
+    ):
+        real_dtype = o.input.real.dtype if dtype.is_complex else dtype
+        # requires_grad path for rtol tensor is not implemented
+        for rtol in (None, 1.0, torch.tensor(1.0, dtype=real_dtype, device=device)):
+            o = clone_sample(o)
+            o.kwargs = {"rtol": rtol}
+            yield o
+
+
+def sample_inputs_linalg_pinv_hermitian(
+    op_info, device, dtype, requires_grad=False, **kwargs
+):
+    """
+    This function generates input for torch.linalg.pinv with hermitian=True keyword argument.
+    """
+    for o in sample_inputs_linalg_invertible(
+        op_info, device, dtype, requires_grad, **kwargs
+    ):
+        o.kwargs = {"hermitian": True}
+        yield o
+
+
+def sample_inputs_linalg_solve(
+    op_info, device, dtype, requires_grad=False, vector_rhs_allowed=True, **kwargs
+):
+    """
+    This function always generates solvable input for torch.linalg.solve
+    We sample a fullrank square matrix (i.e. invertible) A
+    The first input to torch.linalg.solve is generated as the itertools.product of 'batches' and 'ns'.
+    The second input is generated as the product of 'batches', 'ns' and 'nrhs'.
+    In total this function generates 24 SampleInputs
+    'batches' cases include:
+        () - single input,
+        (0,) - zero batched dimension,
+        (2,) - batch of two matrices,
+        (2, 2) - 2x2 batch of matrices.
+    'ns' gives 0x0 and 5x5 matrices.
+    and 'nrhs' controls the number of vectors to solve for:
+        () - using 1 as the number of vectors implicitly
+        (1,) - same as () but explicit
+        (3,) - solve for 3 vectors.
+    Zeros in dimensions are edge cases in the implementation and important to test for in order to avoid unexpected crashes.
+    'vector_rhs_allowed' controls whether to include nrhs = () to the list of SampleInputs.
+    torch.solve / triangular_solve / cholesky_solve (as opposed to torch.linalg.solve) do not allow
+    1D tensors (vectors) as the right-hand-side.
+    Once torch.solve / triangular_solve / cholesky_solve and its testing are removed,
+    'vector_rhs_allowed' may be removed here as well.
+    """
+    make_fullrank = make_fullrank_matrices_with_distinct_singular_values
+    make_a = partial(
+        make_fullrank, dtype=dtype, device=device, requires_grad=requires_grad
+    )
+    make_b = partial(
+        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
+    )
+
+    batches = [(), (0,), (2,), (2, 2)]
+    ns = [5, 0]
+    if vector_rhs_allowed:
+        nrhs = [(), (1,), (3,)]
+    else:
+        nrhs = [(1,), (3,)]
+
+    for n, batch, rhs in product(ns, batches, nrhs):
+        yield SampleInput(make_a(*batch, n, n), args=(make_b(batch + (n,) + rhs),))
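+
+
+# Added illustrative sketch (assumption): torch.linalg.solve accepts a 1-D
+# right-hand side (the nrhs == () case above), unlike the legacy solvers.
+def _example_solve_vector_rhs():
+    A = torch.randn(5, 5)
+    b = torch.randn(5)  # vector right-hand side
+    return torch.linalg.solve(A, b)  # solution x with A @ x == b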
+
+
+def sample_inputs_linalg_solve_triangular(
+    op_info, device, dtype, requires_grad=False, **kwargs
+):
+    make_arg = partial(make_tensor, dtype=dtype, device=device)
+    bs = (1, 2, 0)
+    ns = (3, 0)
+    ks = (1, 3, 0)
+
+    for b, n, k, (left, upper, uni) in product(
+        bs, ns, ks, product((True, False), repeat=3)
+    ):
+        if b == 1:
+            A = make_arg((n, n)) if left else make_arg((k, k))
+            B = make_arg((n, k))
+        else:
+            A = make_arg((b, n, n)) if left else make_arg((b, k, k))
+            B = make_arg((b, n, k))
+        if uni:
+            # Not really necessary, but writing it for consistency
+            A.diagonal(0, -2, -1).fill_(1.0)
+        else:
+            d = A.diagonal(0, -2, -1)
+            d[d.abs() < 1e-6] = 1.0
+        if upper:
+            A.triu_()
+        else:
+            A.tril_()
+        kwargs = {"upper": upper, "left": left, "unitriangular": uni}
+        if requires_grad:
+            for grad_A, grad_B in product((True, False), repeat=2):
+                # Either A or B needs to have a gradient
+                if not grad_A and not grad_B:
+                    continue
+                yield SampleInput(
+                    A.clone().requires_grad_(grad_A),
+                    args=(B.clone().requires_grad_(grad_B),),
+                    kwargs=kwargs,
+                )
+        else:
+            yield SampleInput(A, args=(B,), kwargs=kwargs)
+
+
+def sample_inputs_legacy_solve(op_info, device, dtype, requires_grad=False, **kwargs):
+    """
+    This function always generates solvable input for legacy solve functions
+    (the ones that are not in torch.linalg module).
+    The difference from sample_inputs_linalg_solve is that here the right-hand-side of A x = b equation
+    should have b.ndim >= 2, vectors are not allowed.
+    Also the arguments order is swapped.
+    """
+    out = sample_inputs_linalg_solve(
+        op_info, device, dtype, requires_grad=requires_grad, vector_rhs_allowed=False
+    )
+
+    def out_fn(output):
+        return output[0]
+
+    # Reverses tensor order
+    for sample in out:
+        sample.input, sample.args = sample.args[0], (sample.input,)
+        if op_info.name == "solve":
+            sample.output_process_fn_grad = out_fn
+        yield sample
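+
+
+# Added illustrative sketch (assumption): for the legacy-style solvers the
+# right-hand side comes first and must be at least 2-D, e.g. cholesky_solve.
+def _example_cholesky_solve_argument_order():
+    A = torch.randn(3, 3, dtype=torch.double)
+    A = A @ A.mT + 3 * torch.eye(3, dtype=torch.double)  # positive definite
+    L = torch.linalg.cholesky(A)
+    B = torch.randn(3, 2, dtype=torch.double)  # RHS must have ndim >= 2
+    return torch.cholesky_solve(B, L)  # note: RHS is the first argument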
+
+
+def sample_inputs_linalg_lu(op_info, device, dtype, requires_grad=False, **kwargs):
+    full_rank = op_info.name == "linalg.lu_factor"
+    make_fn = (
+        make_tensor
+        if not full_rank
+        else make_fullrank_matrices_with_distinct_singular_values
+    )
+    make_arg = partial(make_fn, dtype=dtype, device=device, requires_grad=requires_grad)
+
+    def out_fn(output):
+        if op_info.name == "linalg.lu":
+            return output[1], output[2]
+        else:
+            return output
+
+    batch_shapes = ((), (3,), (3, 3))
+    # pivot=False only supported in CUDA
+    pivots = (True, False) if torch.device(device).type == "cuda" else (True,)
+    deltas = (-2, -1, 0, +1, +2)
+    for batch_shape, pivot, delta in product(batch_shapes, pivots, deltas):
+        shape = batch_shape + (S + delta, S)
+        # Insanely annoying that make_fullrank_blablabla accepts a *shape and not a tuple!
+        A = make_arg(shape) if not full_rank else make_arg(*shape)
+        yield SampleInput(A, kwargs={"pivot": pivot}, output_process_fn_grad=out_fn)
+
+
+def sample_inputs_linalg_svdvals(op_info, device, dtype, requires_grad=False, **kwargs):
+    make_arg = partial(
+        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
+    )
+
+    batches = [(), (0,), (2,), (1, 1)]
+    ns = [5, 2, 0]
+
+    for batch, m, n in product(batches, ns, ns):
+        yield SampleInput(make_arg(batch + (m, n)))
+
+
+def sample_inputs_linalg_qr_geqrf(
+    op_info, device, dtype, requires_grad=False, **kwargs
+):
+    # QR is just well defined when the matrix is full rank
+    make_fullrank = make_fullrank_matrices_with_distinct_singular_values
+    make_arg = partial(
+        make_fullrank, dtype=dtype, device=device, requires_grad=requires_grad
+    )
+
+    batches = [(), (0,), (2,), (1, 1)]
+    ns = [5, 2, 0]
+
+    for batch, (m, n) in product(batches, product(ns, ns)):
+        shape = batch + (m, n)
+        yield SampleInput(make_arg(*shape))
+
+
+def sample_inputs_tensorsolve(op_info, device, dtype, requires_grad, **kwargs):
+    a_shapes = [(2, 3, 6), (3, 4, 4, 3)]
+    # Zero-dim tensors are not supported in NumPy, so we skip them for now.
+    # NumPy is used in reference check tests.
+    # See https://github.com/numpy/numpy/pull/20482 for tracking NumPy bugfix.
+    # a_shapes += [(0, 0, 1, 2, 3, 0)]
+    dimss = [None, (0, 2)]
+
+    make_arg = partial(
+        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
+    )
+    for a_shape, dims in itertools.product(a_shapes, dimss):
+        a = make_arg(a_shape)
+        b = make_arg(a_shape[:2])
+        yield SampleInput(a, b, dims=dims)
+
+
+def sample_inputs_tensorinv(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = make_fullrank_matrices_with_distinct_singular_values
+
+    def make_input():
+        return make_arg(12, 12, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    # lhs / rhs shape can have any number of dimensions as long as their product equals 12
+    shapes = [
+        ((2, 2, 3), (12, 1)),
+        ((4, 3), (6, 1, 2)),
+    ]
+
+    for shape_lhs, shape_rhs in shapes:
+        inp = make_input().reshape(*shape_lhs, *shape_rhs).detach()
+        inp.requires_grad_(requires_grad)
+        yield SampleInput(inp, ind=len(shape_lhs))
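+
+
+# Added illustrative sketch (assumption): linalg.tensorinv requires
+# prod(shape[:ind]) == prod(shape[ind:]); here both products equal 12.
+def _example_tensorinv_shapes():
+    a = torch.randn(2, 2, 3, 12, 1)
+    # Result shape is shape[ind:] + shape[:ind] == (12, 1, 2, 2, 3).
+    return torch.linalg.tensorinv(a, ind=3).shape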
+
+
+op_db: list[OpInfo] = [
+    OpInfo(
+        "linalg.cross",
+        ref=lambda x, y, dim=-1: np.cross(x, y, axis=dim),
+        op=torch.linalg.cross,
+        dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
+        aten_name="linalg_cross",
+        sample_inputs_func=sample_inputs_cross,
+        error_inputs_func=error_inputs_cross,
+        supports_out=True,
+        supports_fwgrad_bwgrad=True,
+        supports_forward_ad=True,
+        skips=(
+            DecorateInfo(
+                unittest.skip("Unsupported on MPS for now"),
+                "TestCommon",
+                "test_numpy_ref_mps",
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.det",
+        aten_name="linalg_det",
+        op=torch.linalg.det,
+        aliases=("det",),
+        dtypes=floating_and_complex_types(),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_linalg_det_logdet_slogdet,
+        decorators=[skipCPUIfNoLapack, skipCUDAIfNoMagmaAndNoCusolver],
+        check_batched_gradgrad=False,
+    ),
+    OpInfo(
+        "linalg.diagonal",
+        aten_name="linalg_diagonal",
+        aten_backward_name="diagonal_backward",
+        dtypes=all_types_and_complex_and(
+            torch.bool, torch.bfloat16, torch.float16, torch.chalf
+        ),
+        supports_out=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_diagonal_diag_embed,
+        error_inputs_func=error_inputs_diagonal_diag_embed,
+    ),
+    OpInfo(
+        "linalg.cholesky",
+        aten_name="linalg_cholesky",
+        dtypes=floating_and_complex_types(),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        sample_inputs_func=sample_inputs_linalg_cholesky,
+        gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
+    ),
+    OpInfo(
+        "linalg.cholesky_ex",
+        aten_name="linalg_cholesky_ex",
+        dtypes=floating_and_complex_types(),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        sample_inputs_func=sample_inputs_linalg_cholesky,
+        gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
+    ),
+    OpInfo(
+        "linalg.vecdot",
+        aten_name="linalg_vecdot",
+        ref=lambda x, y, *, dim=-1: (x.conj() * y).sum(dim),
+        dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
+        sample_inputs_func=sample_inputs_linalg_vecdot,
+        check_batched_forward_grad=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        skips=(
+            # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestSchemaCheckModeOpInfo",
+                "test_schema_correctness",
+                dtypes=(torch.complex64, torch.complex128),
+            ),
+            DecorateInfo(
+                unittest.skip("Unsupported on MPS for now"),
+                "TestCommon",
+                "test_numpy_ref_mps",
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.half: tol(atol=1.2e-2, rtol=1.7e-2)}),
+                "TestInductorOpInfo",
+                "test_comprehensive",
+                device_type="cuda",
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.cond",
+        aten_name="linalg_cond",
+        dtypes=floating_and_complex_types(),
+        sample_inputs_func=sample_inputs_linalg_cond,
+        check_batched_gradgrad=False,
+        check_batched_forward_grad=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestFakeTensor",
+                "test_fake_crossref_backward_amp",
+                device_type="cuda",
+                dtypes=[torch.float32],
+                active_if=TEST_WITH_ROCM,
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestFakeTensor",
+                "test_fake_crossref_backward_no_amp",
+                device_type="cuda",
+                dtypes=[torch.float32],
+                active_if=TEST_WITH_ROCM,
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.eig",
+        aten_name="linalg_eig",
+        op=torch.linalg.eig,
+        dtypes=floating_and_complex_types(),
+        sample_inputs_func=sample_inputs_linalg_eig,
+        check_batched_forward_grad=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        skips=(
+            # AssertionError: Scalars are not equal!
+            DecorateInfo(
+                unittest.expectedFailure, "TestCommon", "test_out", device_type="cpu"
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_out",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_variant_consistency_eager",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+        ),
+        decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack, with_tf32_off],
+    ),
+    OpInfo(
+        "linalg.eigvals",
+        aten_name="linalg_eigvals",
+        op=torch.linalg.eigvals,
+        dtypes=floating_and_complex_types(),
+        sample_inputs_func=sample_inputs_linalg_invertible,
+        check_batched_forward_grad=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_out",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_variant_consistency_eager",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.eigh",
+        aten_name="linalg_eigh",
+        dtypes=floating_and_complex_types(),
+        sample_inputs_func=sample_inputs_linalg_eigh,
+        gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
+        check_batched_forward_grad=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack, with_tf32_off],
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_out",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_variant_consistency_eager",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.eigvalsh",
+        aten_name="linalg_eigvalsh",
+        dtypes=floating_and_complex_types(),
+        sample_inputs_func=sample_inputs_linalg_eigh,
+        gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
+        check_batched_forward_grad=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
+        skips=(
+            # Pre-existing condition; Needs to be fixed
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_out",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_variant_consistency_eager",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.householder_product",
+        aten_name="linalg_householder_product",
+        op=torch.linalg.householder_product,
+        aliases=("orgqr",),
+        dtypes=floating_and_complex_types(),
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        # TODO: backward uses in-place operations that vmap doesn't like
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        check_batched_forward_grad=False,
+        sample_inputs_func=sample_inputs_householder_product,
+        decorators=[
+            skipCUDAIfNoCusolver,
+            skipCPUIfNoLapack,
+            DecorateInfo(
+                toleranceOverride({torch.complex64: tol(atol=1e-3, rtol=1e-3)})
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped! Flaky"),
+                "TestFwdGradients",
+                "test_fn_fwgrad_bwgrad",
+                device_type="cpu",
+                dtypes=(torch.complex128,),
+            ),
+            skipCUDAIfRocm,  # regression in ROCm 6.4
+        ],
+    ),
+    OpInfo(
+        "linalg.ldl_factor",
+        aten_name="linalg_ldl_factor",
+        dtypes=floating_and_complex_types(),
+        supports_autograd=False,
+        sample_inputs_func=sample_inputs_linalg_ldl_factor,
+        decorators=[skipCUDAIfNoMagmaAndNoLinalgsolver, skipCPUIfNoLapack],
+    ),
+    OpInfo(
+        "linalg.ldl_factor_ex",
+        aten_name="linalg_ldl_factor_ex",
+        dtypes=floating_and_complex_types(),
+        supports_autograd=False,
+        sample_inputs_func=sample_inputs_linalg_ldl_factor,
+        decorators=[skipCUDAIfNoMagmaAndNoLinalgsolver, skipCPUIfNoLapack],
+    ),
+    OpInfo(
+        "linalg.ldl_solve",
+        aten_name="linalg_ldl_solve",
+        dtypes=floating_and_complex_types(),
+        supports_autograd=False,
+        sample_inputs_func=sample_inputs_linalg_ldl_solve,
+        decorators=[
+            skipCUDAIf(
+                _get_torch_cuda_version() < (11, 4), "not available before CUDA 11.3.1"
+            ),
+            skipCUDAIfNoCusolver,
+            skipCUDAIfRocm,
+            skipCPUIfNoLapack,
+        ],
+    ),
+    OpInfo(
+        "linalg.lstsq",
+        aten_name="linalg_lstsq",
+        dtypes=floating_and_complex_types(),
+        supports_out=True,
+        sample_inputs_func=sample_inputs_linalg_lstsq,
+        error_inputs_func=error_inputs_lstsq,
+        decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
+        skips=(
+            # we skip gradient checks for this suite as they are tested in
+            # variant_test_name='grad_oriented'
+            DecorateInfo(unittest.skip("Skipped!"), "TestFwdGradients"),
+            DecorateInfo(unittest.skip("Skipped!"), "TestBwdGradients"),
+            # The values for attribute 'shape' do not match
+            DecorateInfo(unittest.skip("Skipped!"), "TestCommon", "test_out"),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_out",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_variant_consistency_eager",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.lstsq",
+        aten_name="linalg_lstsq",
+        variant_test_name="grad_oriented",
+        # gradcheck for forward AD fails with the full output tuple;
+        # it works when taking [:2], which is (solution, residuals)
+        op=lambda a, b, driver: torch.linalg.lstsq(a, b, driver=driver)[:2],
+        supports_out=False,
+        dtypes=floating_and_complex_types(),
+        sample_inputs_func=sample_inputs_linalg_lstsq,
+        error_inputs_func=error_inputs_lstsq_grad_oriented,
+        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+        gradcheck_fast_mode=True,
+        supports_autograd=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
+        skips=(
+            # tests do not work with passing lambda for op
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestOperatorSignatures",
+                "test_get_torch_func_signature_exhaustive",
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.matrix_power",
+        aliases=("matrix_power",),
+        aten_name="linalg_matrix_power",
+        dtypes=floating_and_complex_types(),
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_inplace_autograd=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        check_batched_grad=False,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
+        skips=(
+            DecorateInfo(
+                toleranceOverride({torch.float32: tol(atol=8e-5, rtol=2e-6)}),
+                "TestConsistency",
+                "test_output_grad_match",
+                device_type="mps",
+            ),
+        ),
+        sample_inputs_func=sample_inputs_linalg_matrix_power,
+    ),
+    OpInfo(
+        "linalg.multi_dot",
+        # Need this lambda because gradcheck does not work with TensorList inputs
+        aten_name="linalg_multi_dot",
+        dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
+        dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
+        supports_inplace_autograd=False,
+        # Batched grad checks fail for empty input tensors (see https://github.com/pytorch/pytorch/issues/53407)
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # https://github.com/pytorch/pytorch/issues/66357
+        check_batched_forward_grad=False,
+        sample_inputs_func=sample_inputs_linalg_multi_dot,
+        gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
+        skips=(
+            # https://github.com/pytorch/pytorch/issues/67470
+            DecorateInfo(
+                unittest.skip("67470!"), "TestCommon", "test_noncontiguous_samples"
+            ),
+            # Fails on XLA.
+            # AssertionError: False is not true : Tensors failed to compare as equal!
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestOpInfo",
+                device_type="xla",
+                dtypes=(torch.long,),
+            ),
+            # https://github.com/pytorch/pytorch/issues/71774
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestNNCOpInfo",
+                "test_nnc_correctness",
+                device_type="cpu",
+                dtypes=(torch.long,),
+            ),
+        ),
+    ),
+    # NB: linalg.norm has two variants so that different skips can be used for different sample inputs
+    OpInfo(
+        "linalg.norm",
+        aten_name="linalg_norm",
+        op=torch.linalg.norm,
+        dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
+        sample_inputs_func=sample_inputs_linalg_norm,
+        supports_forward_ad=True,
+        check_batched_forward_grad=False,
+        supports_fwgrad_bwgrad=True,
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure, "TestBwdGradients", "test_fn_gradgrad"
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestFakeTensor",
+                "test_fake_crossref_backward_amp",
+                device_type="cuda",
+                dtypes=[torch.float32],
+                active_if=TEST_WITH_ROCM,
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestFakeTensor",
+                "test_fake_crossref_backward_no_amp",
+                device_type="cuda",
+                dtypes=[torch.float32],
+                active_if=TEST_WITH_ROCM,
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.norm",
+        op=torch.linalg.norm,
+        variant_test_name="subgradients_at_zero",
+        dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
+        sample_inputs_func=partial(
+            sample_inputs_linalg_norm, variant="subgradient_at_zero"
+        ),
+        aten_name="linalg_norm",
+        supports_forward_ad=True,
+        # torch.autograd.gradcheck.GradcheckError: While computing batched gradients, got:
+        # Could not allocate memory to change Tensor SizesAndStrides!
+        check_batched_forward_grad=False,
+        supports_fwgrad_bwgrad=True,
+        skips=(
+            # [NEW] Skips specifically for sample inputs at zero
+            # norm's vjp/jvp are not well-conditioned near zero
+            DecorateInfo(
+                unittest.expectedFailure, "TestBwdGradients", "test_fn_gradgrad"
+            ),
+            DecorateInfo(
+                unittest.expectedFailure, "TestFwdGradients", "test_fn_fwgrad_bwgrad"
+            ),
+            DecorateInfo(
+                unittest.expectedFailure, "TestFwdGradients", "test_forward_mode_AD"
+            ),
+            DecorateInfo(unittest.expectedFailure, "TestBwdGradients", "test_fn_grad"),
+        ),
+    ),
+    OpInfo(
+        "linalg.matrix_norm",
+        aten_name="linalg_matrix_norm",
+        dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+        supports_forward_ad=True,
+        check_batched_forward_grad=False,
+        check_batched_gradgrad=False,
+        supports_fwgrad_bwgrad=True,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
+        sample_inputs_func=sample_inputs_linalg_matrix_norm,
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestFakeTensor",
+                "test_fake_crossref_backward_amp",
+                device_type="cuda",
+                dtypes=[torch.float32],
+                active_if=TEST_WITH_ROCM,
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestFakeTensor",
+                "test_fake_crossref_backward_no_amp",
+                device_type="cuda",
+                dtypes=[torch.float32],
+                active_if=TEST_WITH_ROCM,
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.qr",
+        aten_name="linalg_qr",
+        op=torch.linalg.qr,
+        dtypes=floating_and_complex_types(),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # In-place ops
+        check_batched_gradgrad=False,
+        sample_inputs_func=sample_inputs_linalg_qr_geqrf,
+        decorators=[skipCUDAIfNoCusolver, skipCPUIfNoLapack],
+    ),
+    OpInfo(
+        "linalg.slogdet",
+        aten_name="linalg_slogdet",
+        op=torch.linalg.slogdet,
+        dtypes=floating_and_complex_types(),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_linalg_det_logdet_slogdet,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
+    ),
+    OpInfo(
+        "linalg.vander",
+        aten_name="linalg_vander",
+        ref=np_vander_batched,
+        op=torch.linalg.vander,
+        dtypes=all_types_and_complex(),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        supports_out=False,
+        sample_inputs_func=sample_inputs_linalg_vander,
+        skips=(
+            DecorateInfo(
+                unittest.skip("Unsupported on MPS for now"),
+                "TestCommon",
+                "test_numpy_ref_mps",
+            ),
+        ),
+    ),
+    ReductionOpInfo(
+        "linalg.vector_norm",
+        op=torch.linalg.vector_norm,
+        identity=0,
+        nan_policy="propagate",
+        supports_multiple_dims=True,
+        complex_to_real=True,
+        supports_forward_ad=True,
+        # torch.autograd.gradcheck.GradcheckError: While computing batched gradients
+        # got: Could not allocate memory to change Tensor SizesAndStrides!
+        check_batched_forward_grad=False,
+        supports_fwgrad_bwgrad=True,
+        dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+        generate_args_kwargs=sample_kwargs_vector_norm,
+        aten_name="linalg_vector_norm",
+    ),
+    OpInfo(
+        "linalg.lu_factor",
+        aten_name="linalg_lu_factor",
+        op=torch.linalg.lu_factor,
+        dtypes=floating_and_complex_types(),
+        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_linalg_lu,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
+        skips=(
+            # linalg.lu_factor: LU without pivoting is not implemented on the CPU
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
+        ),
+    ),
+    OpInfo(
+        "linalg.lu_factor_ex",
+        aten_name="linalg_lu_factor_ex",
+        op=torch.linalg.lu_factor_ex,
+        dtypes=floating_and_complex_types(),
+        # https://github.com/pytorch/pytorch/issues/80411
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_linalg_lu,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
+        skips=(
+            # linalg.lu_factor: LU without pivoting is not implemented on the CPU
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
+        ),
+    ),
+    OpInfo(
+        "linalg.lu",
+        aten_name="linalg_lu",
+        op=torch.linalg.lu,
+        dtypes=floating_and_complex_types(),
+        # https://github.com/pytorch/pytorch/issues/80411
+        # Runs very slowly on slow-gradcheck - alternatively reduce input sizes
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_linalg_lu,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
+        skips=(
+            # linalg.lu_factor: LU without pivoting is not implemented on the CPU
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
+        ),
+    ),
+    OpInfo(
+        "linalg.lu_solve",
+        op=torch.linalg.lu_solve,
+        aten_name="linalg_lu_solve",
+        dtypes=floating_and_complex_types(),
+        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        check_batched_forward_grad=False,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_lu_solve,
+        skips=(
+            DecorateInfo(
+                unittest.skip("Tests different backward paths"),
+                "TestCommon",
+                "test_floating_inputs_are_differentiable",
+            ),
+        ),
+        decorators=[skipCPUIfNoLapack, skipCUDAIfNoMagmaAndNoCusolver],
+    ),
+    OpInfo(
+        "linalg.inv",
+        aten_name="linalg_inv",
+        op=torch.linalg.inv,
+        aliases=("inverse",),
+        dtypes=floating_and_complex_types(),
+        sample_inputs_func=sample_inputs_linalg_invertible,
+        check_batched_gradgrad=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_out",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_variant_consistency_eager",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.inv_ex",
+        aten_name="linalg_inv_ex",
+        op=torch.linalg.inv_ex,
+        dtypes=floating_and_complex_types(),
+        sample_inputs_func=sample_inputs_linalg_invertible,
+        check_batched_gradgrad=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_out",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_variant_consistency_eager",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.solve",
+        aten_name="linalg_solve",
+        op=torch.linalg.solve,
+        dtypes=floating_and_complex_types(),
+        sample_inputs_func=sample_inputs_linalg_solve,
+        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+        gradcheck_fast_mode=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        decorators=[
+            skipCUDAIfNoMagmaAndNoCusolver,
+            skipCPUIfNoLapack,
+            DecorateInfo(
+                toleranceOverride({torch.float32: tol(atol=1.3e-05, rtol=6e-04)}),
+                "TestCommon",
+                "test_noncontiguous_samples",
+                device_type="cpu",
+            ),
+        ],
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_out",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_variant_consistency_eager",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.solve_ex",
+        aten_name="linalg_solve_ex",
+        op=torch.linalg.solve_ex,
+        dtypes=floating_and_complex_types(),
+        sample_inputs_func=sample_inputs_linalg_solve,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        decorators=[
+            skipCUDAIfNoMagmaAndNoCusolver,
+            skipCPUIfNoLapack,
+            DecorateInfo(
+                toleranceOverride({torch.float32: tol(atol=1.3e-05, rtol=6e-04)}),
+                "TestCommon",
+                "test_noncontiguous_samples",
+                device_type="cpu",
+            ),
+        ],
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_out",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_variant_consistency_eager",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.solve_triangular",
+        aten_name="linalg_solve_triangular",
+        op=torch.linalg.solve_triangular,
+        dtypes=floating_and_complex_types(),
+        sample_inputs_func=sample_inputs_linalg_solve_triangular,
+        supports_fwgrad_bwgrad=True,
+        skips=(skipCPUIfNoLapack,),
+        # linalg.solve_triangular cannot be batched over because of a call to out.copy_(result);
+        supports_forward_ad=True,
+    ),
+    OpInfo(
+        "linalg.matrix_rank",
+        aten_name="linalg_matrix_rank",
+        dtypes=floating_and_complex_types(),
+        supports_autograd=False,
+        sample_inputs_func=sample_inputs_matrix_rank,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_out",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_variant_consistency_eager",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            # jit doesn't accept tensor inputs for matrix rank
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                dtypes=[torch.complex64, torch.float32],
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.matrix_rank",
+        aten_name="linalg_matrix_rank",
+        variant_test_name="hermitian",
+        dtypes=floating_and_complex_types(),
+        supports_autograd=False,
+        sample_inputs_func=sample_inputs_linalg_pinv_hermitian,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_out",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.pinv",
+        aten_name="linalg_pinv",
+        op=torch.linalg.pinv,
+        dtypes=floating_and_complex_types(),
+        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
+        gradcheck_fast_mode=True,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_linalg_pinv,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
+        skips=(
+            # errors with "leaked XXXX bytes CUDA memory on device 0"
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                device_type="cuda",
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.pinv",
+        aten_name="linalg_pinv",
+        variant_test_name="singular",
+        # pinv is Frechet-differentiable in a rank-preserving neighborhood,
+        # so we feed inputs that are the products of two full-rank factors,
+        # to avoid any rank changes caused by the perturbations in the gradcheck
+        op=lambda a, b: torch.linalg.pinv(a @ b.mT),
+        dtypes=floating_and_complex_types(),
+        supports_out=False,
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_linalg_pinv_singular,
+        # Only large tensors show issues with the implicit backward that was
+        # used prior to the explicit backward implementation.
+        decorators=[slowTest, skipCUDAIfNoCusolver, skipCPUIfNoLapack],
+        skips=(
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+            # CUDA runs out of memory
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestFwdGradients",
+                "test_fn_fwgrad_bwgrad",
+                device_type="cuda",
+                dtypes=[torch.cdouble],
+            ),
+            # This test takes almost 2 hours to run!
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestBwdGradients",
+                "test_fn_gradgrad",
+                device_type="cuda",
+                dtypes=[torch.cdouble],
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.pinv",
+        aten_name="linalg_pinv",
+        variant_test_name="hermitian",
+        dtypes=floating_and_complex_types(),
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        sample_inputs_func=sample_inputs_linalg_pinv_hermitian,
+        gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
+        decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_out",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_variant_consistency_eager",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float32: tol(atol=1e-5, rtol=1e-5)}),
+                "TestCommon",
+                "test_noncontiguous_samples",
+                device_type="cuda",
+            ),
+            # This test is flaky under slow gradcheck, likely due to rounding issues
+            DecorateInfo(
+                skipIfSlowGradcheckEnv,
+                "TestFwdGradients",
+                "test_fn_fwgrad_bwgrad",
+                device_type="cuda",
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.svd",
+        op=torch.linalg.svd,
+        aten_name="linalg_svd",
+        decomp_aten_name="_linalg_svd",
+        dtypes=floating_and_complex_types(),
+        # Runs very slowly on slow-gradcheck - alternatively reduce input sizes
+        gradcheck_fast_mode=True,
+        supports_fwgrad_bwgrad=True,
+        supports_forward_ad=True,
+        check_batched_forward_grad=False,
+        # We're using at::allclose, which does not have a batching rule
+        check_batched_grad=False,
+        check_batched_gradgrad=False,
+        sample_inputs_func=sample_inputs_svd,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_out",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_variant_consistency_eager",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                device_type="mps",
+                dtypes=[torch.float32],
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestFakeTensor",
+                "test_fake_crossref_backward_amp",
+                device_type="cuda",
+                dtypes=[torch.float32],
+                active_if=TEST_WITH_ROCM,
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestFakeTensor",
+                "test_fake_crossref_backward_no_amp",
+                device_type="cuda",
+                dtypes=[torch.float32],
+                active_if=TEST_WITH_ROCM,
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.svdvals",
+        op=torch.linalg.svdvals,
+        aten_name="linalg_svdvals",
+        decomp_aten_name="_linalg_svd",
+        dtypes=floating_and_complex_types(),
+        check_batched_forward_grad=False,
+        supports_fwgrad_bwgrad=True,
+        supports_forward_ad=True,
+        # We're using at::allclose, which does not have a batching rule
+        check_batched_gradgrad=False,
+        sample_inputs_func=sample_inputs_linalg_svdvals,
+        decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestFakeTensor",
+                "test_fake_crossref_backward_amp",
+                device_type="cuda",
+                dtypes=[torch.float32],
+                active_if=TEST_WITH_ROCM,
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestFakeTensor",
+                "test_fake_crossref_backward_no_amp",
+                device_type="cuda",
+                dtypes=[torch.float32],
+                active_if=TEST_WITH_ROCM,
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.tensorinv",
+        ref=np.linalg.tensorinv,
+        dtypes=floating_and_complex_types(),
+        sample_inputs_func=sample_inputs_tensorinv,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        # See https://github.com/pytorch/pytorch/pull/78358
+        check_batched_forward_grad=False,
+        decorators=[skipCPUIfNoLapack, skipCUDAIfNoMagmaAndNoCusolver],
+        skips=(
+            DecorateInfo(
+                unittest.skip("Unsupported on MPS for now"),
+                "TestCommon",
+                "test_numpy_ref_mps",
+            ),
+        ),
+    ),
+    OpInfo(
+        "linalg.tensorsolve",
+        ref=lambda a, b, dims=None: np.linalg.tensorsolve(a, b, axes=dims),
+        dtypes=floating_and_complex_types(),
+        sample_inputs_func=sample_inputs_tensorsolve,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        decorators=[
+            skipCUDAIfNoMagmaAndNoCusolver,
+            skipCPUIfNoLapack,
+            DecorateInfo(
+                toleranceOverride({torch.float32: tol(atol=1e-03, rtol=1e-03)}),
+                "TestCommon",
+                "test_noncontiguous_samples",
+                device_type="cuda",
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float32: tol(atol=8e-04, rtol=7e-06)}),
+                "TestCommon",
+                "test_noncontiguous_samples",
+                device_type="cpu",
+            ),
+            DecorateInfo(
+                toleranceOverride({torch.float32: tol(atol=2e-04, rtol=3e-06)}),
+                "TestConsistency",
+                "test_output_match",
+                device_type="mps",
+            ),
+        ],
+        skips=(
+            DecorateInfo(
+                unittest.skip("Unsupported on MPS for now"),
+                "TestCommon",
+                "test_numpy_ref_mps",
+            ),
+        ),
+    ),
+]
+
+python_ref_db: list[OpInfo] = [
+    #
+    # torch.linalg
+    #
+    PythonRefInfo(
+        "_refs.linalg.cross",
+        torch_opinfo_name="linalg.cross",
+        supports_out=True,
+        op_db=op_db,
+        skips=(
+            # TODO: is this really needed?
+            DecorateInfo(
+                unittest.expectedFailure, "TestCommon", "test_python_ref_errors"
+            ),
+        ),
+    ),
+    PythonRefInfo(
+        "_refs.linalg.diagonal",
+        torch_opinfo_name="linalg.diagonal",
+        supports_out=False,
+        op_db=op_db,
+    ),
+    PythonRefInfo(
+        "_refs.linalg.vecdot",
+        torch_opinfo_name="linalg.vecdot",
+        op_db=op_db,
+    ),
+    ReductionPythonRefInfo(
+        "_refs.linalg.vector_norm",
+        torch_opinfo_name="linalg.vector_norm",
+        supports_out=True,
+        op_db=op_db,
+    ),
+    PythonRefInfo(
+        "_refs.linalg.matrix_norm",
+        torch_opinfo_name="linalg.matrix_norm",
+        supports_out=True,
+        # Uses vector_norm inside and vector_norm is affected by
+        # https://github.com/pytorch/pytorch/issues/77216
+        validate_view_consistency=False,
+        op_db=op_db,
+    ),
+    PythonRefInfo(
+        "_refs.linalg.norm",
+        torch_opinfo_name="linalg.norm",
+        supports_out=True,
+        # Uses vector_norm inside and vector_norm is affected by
+        # https://github.com/pytorch/pytorch/issues/77216
+        validate_view_consistency=False,
+        op_db=op_db,
+    ),
+    PythonRefInfo(
+        "_refs.linalg.svd",
+        torch_opinfo_name="linalg.svd",
+        supports_out=True,
+        op_db=op_db,
+    ),
+    PythonRefInfo(
+        "_refs.linalg.svdvals",
+        torch_opinfo_name="linalg.svdvals",
+        supports_out=True,
+        op_db=op_db,
+    ),
+]
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/nested.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/nested.py
new file mode 100644
index 0000000000000000000000000000000000000000..793cbbf89c1b4780f91dd4fcbc6e7a2f8005a4a0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/nested.py
@@ -0,0 +1,1593 @@
+# mypy: ignore-errors
+
+import math
+from copy import copy
+from dataclasses import dataclass
+from functools import partial
+from typing import Optional
+
+import torch
+from torch.fx.experimental.symbolic_shapes import is_nested_int
+from torch.testing._internal.common_methods_invocations import op_db
+from torch.testing._internal.opinfo.core import (
+    BinaryUfuncInfo,
+    ReductionOpInfo,
+    SampleInput,
+    UnaryUfuncInfo,
+)
+from torch.utils._pytree import tree_flatten, tree_map
+
+
+@dataclass
+class ExtraOpData:
+    """
+    Contains info on top of the typical OpInfo data that is useful for NJT test generation.
+
+    The process that converts the standard op_db -> an NJT-compatible op_db will attach this
+    data onto each associated OpInfo entry.
+    """
+
+    # Indicates whether the associated op is a view op
+    is_view: bool = False
+
+    # Specifies the names of any dim-related args that the op takes in. This is useful
+    # for NJT tests because there is often asymmetry across the supported set of dims for
+    # an op; it may make sense to operate over the batch dim but not the ragged dim, for
+    # example. The length of this list should match the number of relevant overloads.
+    # Each item of the outer list specifies the dim argnames for one overload. Ellipses
+    # are used to indicate multi-dim support for a given overload.
+    #
+    # For example, squeeze() has both a dim and multi-dim overload, where the argname for
+    # each is simply "dim". Its entry should be: [["dim"], ["dim..."]].
+    #
+    # If no overload of the op accepts dim-related args, this should be None.
+    dim_args: Optional[list[list[str]]] = None
+
+    # Helper function to extract names of dim-related args.
+    # Returns: tuple of (single dim argname if available, dim list argname if available)
+    # If the op doesn't support dim-related args at all OR this op only has overloads
+    # with multiple dim args (e.g. transpose()), then this returns (None, None).
+    def get_dim_argnames(self) -> tuple[Optional[str], Optional[str]]:
+        if self.dim_args is None:
+            return (None, None)
+
+        # name for the dim arg that supports a single dim
+        single_dim_argname = None
+        # name for the dim arg that supports a list of dims
+        dimlist_argname = None
+        for overload in self.dim_args:
+            # only consider overloads with a single dim-related arg
+            if len(overload) != 1:
+                continue
+            if overload[0].endswith("..."):
+                dimlist_argname = overload[0].replace("...", "")
+                if single_dim_argname is None:
+                    single_dim_argname = dimlist_argname
+            else:
+                single_dim_argname = overload[0]
+        return (single_dim_argname, dimlist_argname)
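+    # Rough usage sketch (values are illustrative; see extra_op_data below):
+    #   ExtraOpData(dim_args=[["dim"], ["dim..."]]).get_dim_argnames() -> ("dim", "dim")
+    #   ExtraOpData(dim_args=[["dim"]]).get_dim_argnames() -> ("dim", None)
+    #   ExtraOpData(dim_args=[["dim0", "dim1"]]).get_dim_argnames() -> (None, None)
+    #   ExtraOpData().get_dim_argnames() -> (None, None)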
+
+
+# Mapping of OpInfo full names -> extra data to tack onto the OpInfo entry for use
+# in test generation.
+extra_op_data = {
+    "_segment_reduce.lengths": ExtraOpData(dim_args=[["axis0"]]),
+    "_segment_reduce.offsets": ExtraOpData(dim_args=[["axis0"]]),
+    "all": ExtraOpData(dim_args=[["dim"], ["dim..."]]),
+    "argmax": ExtraOpData(dim_args=[["dim"]]),
+    "argmin": ExtraOpData(dim_args=[["dim"]]),
+    "amax": ExtraOpData(dim_args=[["dim..."]]),
+    "amin": ExtraOpData(dim_args=[["dim..."]]),
+    "any": ExtraOpData(dim_args=[["dim"], ["dim..."]]),
+    "argsort": ExtraOpData(dim_args=[["dim"]]),
+    "broadcast_to": ExtraOpData(is_view=True),
+    "cat": ExtraOpData(dim_args=[["dim"]]),
+    "chunk": ExtraOpData(is_view=True, dim_args=[["dim"]]),
+    "conj": ExtraOpData(is_view=True),
+    "contiguous": ExtraOpData(is_view=True),
+    "count_nonzero": ExtraOpData(dim_args=[["dim"], ["dim..."]]),
+    "cummax": ExtraOpData(dim_args=[["dim"]]),
+    "cummin": ExtraOpData(dim_args=[["dim"]]),
+    "cumprod": ExtraOpData(dim_args=[["dim"]]),
+    "cumsum": ExtraOpData(dim_args=[["dim"]]),
+    "cumulative_trapezoid": ExtraOpData(dim_args=[["dim"]]),
+    "diag_embed": ExtraOpData(dim_args=[["dim1", "dim2"]]),
+    "diagonal": ExtraOpData(is_view=True, dim_args=[["dim1", "dim2"]]),
+    "diagonal_copy": ExtraOpData(dim_args=[["dim1", "dim2"]]),
+    "diagonal_scatter": ExtraOpData(dim_args=[["dim1", "dim2"]]),
+    "diff": ExtraOpData(dim_args=[["dim"]]),
+    "expand": ExtraOpData(is_view=True),
+    "expand_as": ExtraOpData(is_view=True),
+    "fft.fft": ExtraOpData(dim_args=[["dim"]]),
+    "fft.hfft": ExtraOpData(dim_args=[["dim"]]),
+    "fft.ifft": ExtraOpData(dim_args=[["dim"]]),
+    "fft.ihfft": ExtraOpData(dim_args=[["dim"]]),
+    "fft.irfft": ExtraOpData(dim_args=[["dim"]]),
+    "fft.rfft": ExtraOpData(dim_args=[["dim"]]),
+    "flatten": ExtraOpData(is_view=True, dim_args=[["start_dim", "end_dim"]]),
+    "flip": ExtraOpData(dim_args=[["dims..."]]),
+    "gather": ExtraOpData(dim_args=[["dim"]]),
+    "imag": ExtraOpData(is_view=True),
+    "index_add": ExtraOpData(dim_args=[["dim"]]),
+    "index_copy": ExtraOpData(dim_args=[["dim"]]),
+    "index_fill": ExtraOpData(dim_args=[["dim"]]),
+    "index_reduce.amax": ExtraOpData(dim_args=[["dim"]]),
+    "index_reduce.amin": ExtraOpData(dim_args=[["dim"]]),
+    "index_reduce.mean": ExtraOpData(dim_args=[["dim"]]),
+    "index_reduce.prod": ExtraOpData(dim_args=[["dim"]]),
+    "index_select": ExtraOpData(dim_args=[["dim"]]),
+    "kthvalue": ExtraOpData(dim_args=[["dim"]]),
+    "linalg.cross": ExtraOpData(dim_args=[["dim"]]),
+    "linalg.diagonal": ExtraOpData(is_view=True, dim_args=[["dim1", "dim2"]]),
+    "linalg.tensorsolve": ExtraOpData(dim_args=[["dims..."]]),
+    "linalg.vecdot": ExtraOpData(dim_args=[["dim"]]),
+    "linalg.vector_norm": ExtraOpData(dim_args=[["dim..."]]),
+    "log_softmax": ExtraOpData(dim_args=[["dim"]]),
+    "logcumsumexp": ExtraOpData(dim_args=[["dim"]]),
+    "masked.amax": ExtraOpData(dim_args=[["dim"]]),
+    "masked.amin": ExtraOpData(dim_args=[["dim"]]),
+    "masked.argmax": ExtraOpData(dim_args=[["dim"]]),
+    "masked.argmin": ExtraOpData(dim_args=[["dim"]]),
+    "masked.logsumexp": ExtraOpData(dim_args=[["dim"]]),
+    "masked.mean": ExtraOpData(dim_args=[["dim"]]),
+    "masked.norm": ExtraOpData(dim_args=[["dim"]]),
+    "masked.prod": ExtraOpData(dim_args=[["dim"]]),
+    "masked.std": ExtraOpData(dim_args=[["dim"]]),
+    "masked.sum": ExtraOpData(dim_args=[["dim"]]),
+    "masked.var": ExtraOpData(dim_args=[["dim"]]),
+    "max.reduction_with_dim": ExtraOpData(dim_args=[["dim"]]),
+    "median": ExtraOpData(dim_args=[["dim"]]),
+    "mean": ExtraOpData(dim_args=[["dim..."]]),
+    "min.reduction_with_dim": ExtraOpData(dim_args=[["dim"]]),
+    "mode": ExtraOpData(dim_args=[["dim"]]),
+    "movedim": ExtraOpData(
+        dim_args=[["source", "destination"], ["source...", "destination..."]]
+    ),
+    "nanmean": ExtraOpData(dim_args=[["dim..."]]),
+    "nanmedian": ExtraOpData(dim_args=[["dim"]]),
+    "nansum": ExtraOpData(dim_args=[["dim..."]]),
+    "narrow": ExtraOpData(is_view=True, dim_args=[["dim"]]),
+    "narrow_copy": ExtraOpData(dim_args=[["dim"]]),
+    "nn.functional.cosine_similarity": ExtraOpData(dim_args=[["dim"]]),
+    "nn.functional.glu": ExtraOpData(dim_args=[["dim"]]),
+    "permute": ExtraOpData(is_view=True, dim_args=[["dims..."]]),
+    "positive": ExtraOpData(is_view=True),
+    "prod": ExtraOpData(dim_args=[["dim"]]),
+    "ravel": ExtraOpData(is_view=True),
+    "real": ExtraOpData(is_view=True),
+    "renorm": ExtraOpData(dim_args=[["dim"]]),
+    "reshape": ExtraOpData(is_view=True),
+    "reshape_as": ExtraOpData(is_view=True),
+    "roll": ExtraOpData(dim_args=[["dims..."]]),
+    "rot90": ExtraOpData(dim_args=[["dims..."]]),
+    "scatter": ExtraOpData(dim_args=[["dim"]]),
+    "scatter_add": ExtraOpData(dim_args=[["dim"]]),
+    "scatter_reduce.amax": ExtraOpData(dim_args=[["dim"]]),
+    "scatter_reduce.amin": ExtraOpData(dim_args=[["dim"]]),
+    "scatter_reduce.mean": ExtraOpData(dim_args=[["dim"]]),
+    "scatter_reduce.prod": ExtraOpData(dim_args=[["dim"]]),
+    "scatter_reduce.sum": ExtraOpData(dim_args=[["dim"]]),
+    "select": ExtraOpData(is_view=True, dim_args=[["dim"]]),
+    "select_scatter": ExtraOpData(dim_args=[["dim"]]),
+    "slice": ExtraOpData(is_view=True, dim_args=[["dim"]]),
+    "slice_scatter": ExtraOpData(dim_args=[["dim"]]),
+    "softmax": ExtraOpData(dim_args=[["dim"]]),
+    "sort": ExtraOpData(dim_args=[["dim"]]),
+    "split": ExtraOpData(is_view=True, dim_args=[["dim"]]),
+    "split_with_sizes": ExtraOpData(is_view=True, dim_args=[["dim"]]),
+    "split_with_sizes_copy": ExtraOpData(dim_args=[["dim"]]),
+    "squeeze": ExtraOpData(is_view=True, dim_args=[["dim"], ["dim..."]]),
+    "squeeze_copy": ExtraOpData(dim_args=[["dim"], ["dim..."]]),
+    "stack": ExtraOpData(dim_args=[["dim"]]),
+    "std": ExtraOpData(dim_args=[["dim..."]]),
+    "std.unbiased": ExtraOpData(dim_args=[["dim..."]]),
+    "sum": ExtraOpData(dim_args=[["dim..."]]),
+    "t": ExtraOpData(is_view=True),
+    "tensor_split": ExtraOpData(is_view=True, dim_args=[["dim"]]),
+    "tensordot": ExtraOpData(dim_args=[["dims..."]]),
+    "tile": ExtraOpData(dim_args=[["dims..."]]),
+    "topk": ExtraOpData(dim_args=[["dim"]]),
+    "transpose": ExtraOpData(is_view=True, dim_args=[["dim0", "dim1"]]),
+    "transpose_copy": ExtraOpData(dim_args=[["dim0", "dim1"]]),
+    "trapezoid": ExtraOpData(dim_args=[["dim"]]),
+    "trapz": ExtraOpData(dim_args=[["dim"]]),
+    "unbind": ExtraOpData(is_view=True, dim_args=[["dim"]]),
+    "unflatten": ExtraOpData(is_view=True, dim_args=[["dim"]]),
+    "unfold": ExtraOpData(is_view=True, dim_args=[["dimension"]]),
+    "unfold_copy": ExtraOpData(dim_args=[["dimension"]]),
+    "unsafe_chunk": ExtraOpData(dim_args=[["dim"]]),
+    "unsafe_split": ExtraOpData(dim_args=[["dim"]]),
+    "unsqueeze": ExtraOpData(is_view=True, dim_args=[["dim"]]),
+    "unsqueeze_copy": ExtraOpData(dim_args=[["dim"]]),
+    "var": ExtraOpData(dim_args=[["dim..."]]),
+    "var.unbiased": ExtraOpData(dim_args=[["dim..."]]),
+    "view": ExtraOpData(is_view=True),
+    "view_as": ExtraOpData(is_view=True),
+    "view_as_complex": ExtraOpData(is_view=True),
+    "view_as_real": ExtraOpData(is_view=True),
+}
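+
+# How this table is presumably consumed (a hedged sketch; the conversion of the
+# standard op_db into NJT-compatible OpInfo entries happens elsewhere in this
+# module): each entry is looked up by OpInfo full name and attached to the
+# entry as `_extra_op_data`, falling back to a default ExtraOpData() for ops
+# without special handling, roughly:
+#   op._extra_op_data = extra_op_data.get(op.full_name, ExtraOpData())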
+
+
+# random integer used for sizes
+def _rnd():
+    return torch.randint(3, 8, ()).item()
+
+
+def _raggedness_matches(nt1, nt2):
+    return (
+        nt1.is_nested
+        and nt2.is_nested
+        and nt1._ragged_idx == nt2._ragged_idx
+        and nt1.shape[nt1._ragged_idx] == nt2.shape[nt2._ragged_idx]
+    )
+
+
+# Helper function to avoid reusing the exact same tensor / NJT across SampleInputs,
+# as this causes autograd problems.
+def _clone(t):
+    requires_grad = t.requires_grad
+    return t.detach().clone().requires_grad_(requires_grad)
+
+
+# Helper function to update a sample with new kwargs / name
+def _update_sample(sample, new_kwargs):
+    all_kwargs = dict(sample.kwargs)
+    all_kwargs.update(new_kwargs)
+    full_name = ", ".join([sample.name, *(f"{k}={v}" for (k, v) in new_kwargs.items())])
+    return SampleInput(
+        _clone(sample.input),
+        args=sample.args,
+        kwargs=all_kwargs,
+        name=full_name,
+    )
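+    # For example (illustrative): a sample named "3D_contig" updated with
+    # {"dim": 0} keeps its input/args, merges the kwargs, and is renamed
+    # "3D_contig, dim=0", so the generated test names stay readable.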
+
+
+# Generates a random NT.
+# dims should be something like [5, None, 10], with None indicating that a
+# random ragged structure should be used
+def random_nt_from_dims(
+    dims, device=None, dtype=None, layout=torch.strided, requires_grad=False
+):
+    sizes = [[d if d is not None else _rnd() for d in dims[1:]] for _ in range(dims[0])]
+    return torch.nested.nested_tensor(
+        [torch.randn(*size) for size in sizes],
+        device=device,
+        dtype=dtype,
+        layout=layout,
+        requires_grad=requires_grad,
+    )
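+    # Rough usage sketch (component sizes are random, so shapes are illustrative):
+    #   nt = random_nt_from_dims([3, None, 8], device="cpu", dtype=torch.float32,
+    #                            layout=torch.jagged)
+    #   nt.shape -> (3, j1, 8), where j1 denotes the ragged (jagged) dim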
+
+
+# Helper function to get a reasonable string representation of an NJT for use in
+# SampleInput names.
+def _describe_njt(njt) -> str:
+    contig_type = "_contig" if njt.is_contiguous() else "_noncontig"
+    if njt._lengths is not None and njt._offsets is not None:
+        contig_type += "_holes"
+    elif njt._ragged_idx != 1:
+        contig_type += "_transposed"
+
+    cached_data = "_without_seqlen_cache"
+    if njt._max_seqlen_tensor is not None:
+        cached_data = "_with_seqlen_cache"
+
+    return f"{njt.dim()}D{contig_type}{cached_data}"
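+    # Example outputs (illustrative): "3D_contig_with_seqlen_cache",
+    # "4D_noncontig_transposed_without_seqlen_cache",
+    # "3D_noncontig_holes_without_seqlen_cache".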
+
+
+# Helper function to get a reasonable string representation of a given dim wrt an NJT.
+def _describe_dim(njt, dim):
+    if dim == 0:
+        return "batch_dim"
+    elif dim == njt._ragged_idx:
+        return "ragged_dim"
+    return "normal_dim"
+
+
+# Helper function for generating a comprehensive set of NJT sample inputs.
+def _sample_njts(device, dtype, requires_grad=False, dims=None):
+    if dims is None:
+        dims = [2, 3, 4]
+    if not isinstance(dims, (list, tuple)):
+        dims = [dims]
+
+    # contiguous NJTs
+    for dim in dims:
+        # with min / max seqlen cached
+        shape = (_rnd(), None, *[_rnd() for _ in range(dim - 2)])
+        nt = random_nt_from_dims(
+            shape,
+            device=device,
+            dtype=dtype,
+            requires_grad=requires_grad,
+            layout=torch.jagged,
+        )
+        yield nt
+
+        # without min / max seqlen cached
+        values = _clone(nt.values())
+        offsets = _clone(nt.offsets())
+        yield torch.nested.nested_tensor_from_jagged(values, offsets).requires_grad_(
+            requires_grad
+        )
+
+        # non-contiguous transposed NJT (not possible for 2D)
+        if dim > 2:
+            yield nt.transpose(-1, nt._ragged_idx)
+
+        # non-contiguous with holes NJT
+        values = _clone(nt.values())
+        offsets = _clone(nt.offsets())
+        # subtract 1 to cause holes
+        lengths = _clone(offsets.diff() - 1)
+        yield torch.nested.nested_tensor_from_jagged(
+            values=values,
+            offsets=offsets,
+            lengths=lengths,
+        ).requires_grad_(requires_grad)
+
+
+# Computes an unbind-based reference for a given OpInfo on a given SampleInput.
+# This reference unbinds the input NJT and invokes the op on each of the components,
+# optionally wrapping the result in an NJT.
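+# During slicing: NJT args that share the input's ragged structure are indexed per batch
+# element, dense tensors tagged with a _batch_dim attribute are selected along that dim
+# (or broadcast via index 0 when that dim has size 1), and all other args pass through
+# unchanged.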
+def unbind_reference(op, sample, wrap_output_as_njt=True):
+    # first NJT in the arglist determines expected ragged structure
+    nt_inp = (
+        sample.input
+        if sample.input.is_nested
+        # TODO: look in kwargs too?
+        else next(a for a in sample.args if a.is_nested)
+    )
+
+    out_ref_components = []
+    for i in range(nt_inp.shape[0]):
+
+        def _slice_input(t, i=i, inp=nt_inp):
+            # any NJT with the same ragged structure as the input should
+            # be sliced to pass to the reference
+            if isinstance(t, torch.Tensor) and _raggedness_matches(t, inp):
+                return t[i]
+            # allow the SampleInput to tell us how to slice it for ref calculation
+            elif isinstance(t, torch.Tensor) and hasattr(t, "_batch_dim"):
+                bdim = t._batch_dim  # type: ignore[attr-defined]
+                if t.shape[bdim] == 1:
+                    return t[0]
+                else:
+                    return t.select(bdim, i)
+            else:
+                return t
+
+        inp = _slice_input(sample.input)
+        args = tree_map(_slice_input, sample.args)
+        kwargs = tree_map(_slice_input, sample.kwargs)
+
+        # Handle indices in index_put
+        if "index_put" in op.full_name and "indices" in kwargs:
+            if len(kwargs["indices"]) > 1:
+                # If after unrolling we still have indices left, use them
+                kwargs["indices"] = [t[i] for t in kwargs["indices"][1:]]
+            else:
+                # If no indices are left, create them so they match the NJT implementation
+                sequence_put = kwargs["indices"][0].tolist()
+                if i in sequence_put:
+                    kwargs["indices"] = [
+                        torch.tensor(
+                            list(range(inp.shape[0])),
+                            dtype=torch.int32,
+                            device=kwargs["indices"][0].device,
+                        )
+                    ]
+                else:
+                    kwargs["indices"] = [
+                        torch.tensor(
+                            [], dtype=torch.int32, device=kwargs["indices"][0].device
+                        )
+                    ]
+
+        from torch.nested._internal.ops import _outer_to_inner_dim
+
+        # Need to adjust dims to apply on NJT component
+        if op._extra_op_data.dim_args is not None:
+            # get all possible dim-related argnames that could be encountered for this op
+            argnames = tree_map(
+                lambda a: a.replace("...", ""),
+                tree_flatten(op._extra_op_data.dim_args)[0],
+            )
+            # for all dim-related args present, convert from outer -> inner dim space
+            for argname in {a for a in argnames if a in kwargs}:
+                # allow the SampleInput to tell us how to canonicalize the dim kwargs
+                ndim = nt_inp._ndim if hasattr(nt_inp, "_ndim") else nt_inp.dim()
+                kwargs[argname] = _outer_to_inner_dim(
+                    ndim, kwargs[argname], nt_inp._ragged_idx, canonicalize=True
+                )
+
+        out_ref_component = op.op(inp, *args, **kwargs)
+        out_ref_components.append(out_ref_component)
+
+    if wrap_output_as_njt:
+        # handle list / tuple of outputs
+        if len(out_ref_components) > 0 and isinstance(
+            out_ref_components[0], (list, tuple)
+        ):
+            num_returns = len(out_ref_components[0])
+            # ensure we get the same number of returns for each invocation
+            assert all(len(o) == num_returns for o in out_ref_components)
+            # construct NJTs from same index returns from each invocation
+            njt_returns = [
+                torch.nested.as_nested_tensor(
+                    [o[r] for o in out_ref_components], layout=torch.jagged
+                )
+                for r in range(num_returns)
+            ]
+            return type(out_ref_components[0])(njt_returns)
+        return torch.nested.as_nested_tensor(out_ref_components, layout=torch.jagged)
+
+    return out_ref_components
+
+
+# Computes the reference value for a non-reduction unary op with dim-wise application.
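+# Note: unbinding consumes the batch dim, so dim=0 samples are delegated to an
+# op-specific batchwise_reference instead of unbind_reference().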
+def unary_dimwise_reference(op, sample, batchwise_reference=None):
+    # extract info about the dim args this op supports
+    assert op._extra_op_data.dim_args is not None
+    single_dim_argname, dimlist_argname = op._extra_op_data.get_dim_argnames()
+    # only support a single non-list dim arg for now
+    assert dimlist_argname is None
+    assert single_dim_argname is not None
+    if sample.kwargs[single_dim_argname] == 0:
+        # unbind reference won't work for batch-wise operation; handle this case here
+        assert batchwise_reference is not None
+        return batchwise_reference(op, sample)
+    return unbind_reference(op, sample)
+
+
+# Computes the reference value for a reduction op.
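+# Cases handled below: no dim (reduce over the values buffer), batch + ragged dims
+# (reduce over the values buffer with dims mapped to inner space), ragged dim only
+# (unbind, reduce each component, then stack), and everything else (plain unbind
+# reference).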
+def reduction_reference(op, sample):
+    assert sample.input.is_nested
+
+    # extract info about the dim args this op supports
+    assert op._extra_op_data.dim_args is not None
+    single_dim_argname, dimlist_argname = op._extra_op_data.get_dim_argnames()
+    assert single_dim_argname is not None
+
+    dim = sample.kwargs.get(
+        dimlist_argname, sample.kwargs.get(single_dim_argname, None)
+    )
+    keepdim = sample.kwargs.get("keepdim", False)
+    assert dim != 0, "reductions over just the batch dim are not supported"
+    if isinstance(dim, (tuple, list)):
+        reduce_on_ragged = sample.input._ragged_idx in dim
+        reduce_on_batch = 0 in dim
+    else:
+        reduce_on_ragged = sample.input._ragged_idx == dim
+        reduce_on_batch = dim == 0
+
+    if dim is None:
+        # calculate reference value by running reduction on values buffer
+        return op.op(sample.input.values(), *sample.args, **sample.kwargs)
+
+    if reduce_on_ragged and reduce_on_batch:
+        # run reference directly on buffer with dims converted to inner space
+        from torch.nested._internal.ops import _outer_to_inner_dim
+
+        ref_kwargs = dict(sample.kwargs)
+        assert dimlist_argname is not None
+        ref_kwargs[dimlist_argname] = _outer_to_inner_dim(
+            sample.input.dim(), dim, sample.input._ragged_idx, canonicalize=True
+        )
+        out = op.op(sample.input.values(), *sample.args, **ref_kwargs)
+        if keepdim:
+            if isinstance(out, (tuple, list)):
+                # some ops return multiple things; unsqueeze all of them
+                out = type(out)(o.unsqueeze(0) for o in out)
+            else:
+                out = out.unsqueeze(0)
+        return out
+
+    if reduce_on_ragged and not reduce_on_batch:
+        # calculate reference value by running an unbind reference and stacking
+        out_ref_components = unbind_reference(op, sample, wrap_output_as_njt=False)
+        if len(out_ref_components) > 0 and isinstance(
+            out_ref_components[0], (tuple, list)
+        ):
+            # some ops return multiple things; stack all of them
+            num_returns = len(out_ref_components[0])
+            # ensure we get the same number of returns for each invocation
+            assert all(len(o) == num_returns for o in out_ref_components)
+            # stack same index returns from each invocation
+            stacked_returns = [
+                torch.stack([o[r] for o in out_ref_components], dim=0)
+                for r in range(num_returns)
+            ]
+            return type(out_ref_components[0])(stacked_returns)
+        return torch.stack(out_ref_components, dim=0)
+
+    # unbind reference works for other reductions
+    return unbind_reference(op, sample)
+
+
+def sample_inputs_elementwise_njt_unary(
+    op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs
+):
+    if not op_kwargs:
+        op_kwargs = {}
+
+    for njt in _sample_njts(
+        device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4]
+    ):
+        yield SampleInput(njt, kwargs=dict(op_kwargs), name=_describe_njt(njt))
+
+
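+# Covers (NT, NT) plus NT / dense broadcasting: a dense tensor with size 1 at the ragged
+# dim, an all-1s shape, trailing non-ragged dims only, a scalar, and a mixed
+# (B, j0, 1) x (B, 1, D) case, each in both (NT, T) and (T, NT) argument orders.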
+def sample_inputs_elementwise_njt_binary(
+    op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs
+):
+    if not op_kwargs:
+        op_kwargs = {}
+
+    for njt1 in _sample_njts(
+        device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4]
+    ):
+        njt_desc = _describe_njt(njt1)
+        njt2 = torch.randn_like(njt1)
+        yield SampleInput(
+            _clone(njt1),
+            args=(njt2,),
+            kwargs=dict(op_kwargs),
+            name=f"{njt_desc}: (NT, NT)",
+        )
+
+        # broadcasting case: (B, j0, ...) with (B, 1, ...)
+        dense_shape = list(njt1.shape)
+        dense_shape[njt1._ragged_idx] = 1
+        t = torch.randn(
+            dense_shape,
+            device=device,
+            dtype=dtype,
+            requires_grad=requires_grad,
+        )
+        t2 = _clone(t)
+        # used for slicing in unbind_reference()
+        t._batch_dim = 0
+        t2._batch_dim = 0
+        # (NT, T)
+        yield SampleInput(
+            _clone(njt1),
+            args=(t,),
+            kwargs=dict(op_kwargs),
+            name=f"{njt_desc}: (NT, T) broadcasting 1 over ragged",
+        )
+        # (T, NT)
+        yield SampleInput(
+            t2,
+            args=(_clone(njt1),),
+            kwargs=dict(op_kwargs),
+            name=f"{njt_desc}: (T, NT) broadcasting 1 over ragged",
+        )
+
+        # broadcasting case: (B, j0, ...) with (1, 1...)
+        t = torch.randn(
+            [1 for _ in range(njt1.dim())],
+            device=device,
+            dtype=dtype,
+            requires_grad=requires_grad,
+        )
+        t2 = _clone(t)
+        # used for slicing in unbind_reference()
+        t._batch_dim = 0
+        t2._batch_dim = 0
+        # (NT, T)
+        yield SampleInput(
+            _clone(njt1),
+            args=(t,),
+            kwargs=dict(op_kwargs),
+            name=f"{njt_desc}: (NT, T) broadcasting all 1s",
+        )
+        # (T, NT)
+        yield SampleInput(
+            t2,
+            args=(_clone(njt1),),
+            kwargs=dict(op_kwargs),
+            name=f"{njt_desc}: (T, NT) broadcasting all 1s",
+        )
+
+        # broadcasting case: (B, j0, ...) with (...)
+        if njt1.dim() > njt1._ragged_idx + 1:
+            t = torch.randn(
+                njt1.shape[njt1._ragged_idx + 1 :],
+                device=device,
+                dtype=dtype,
+                requires_grad=requires_grad,
+            )
+            # (NT, T)
+            yield SampleInput(
+                _clone(njt1),
+                args=(_clone(t),),
+                kwargs=dict(op_kwargs),
+                name=f"{njt_desc}: (NT, T) broadcasting normal dims",
+            )
+            # (T, NT)
+            yield SampleInput(
+                _clone(t),
+                args=(_clone(njt1),),
+                kwargs=dict(op_kwargs),
+                name=f"{njt_desc}: (T, NT) broadcasting normal dims",
+            )
+
+        # broadcasting case: (B, j0, ...) with scalar
+        t = torch.randn((), device=device, dtype=dtype, requires_grad=requires_grad)
+        # (NT, T)
+        yield SampleInput(
+            _clone(njt1),
+            args=(_clone(t),),
+            kwargs=dict(op_kwargs),
+            name=f"{njt_desc}: (NT, T) broadcasting with scalar",
+        )
+        # (T, NT)
+        yield SampleInput(
+            _clone(t),
+            args=(_clone(njt1),),
+            kwargs=dict(op_kwargs),
+            name=f"{njt_desc}: (T, NT) broadcasting with scalar",
+        )
+
+    # mixed broadcasting case: (B, j0, 1) with (B, 1, D)
+    B = 4
+    D = 16
+    njt = random_nt_from_dims(
+        (B, None, 1),
+        device=device,
+        dtype=dtype,
+        requires_grad=requires_grad,
+        layout=torch.jagged,
+    )
+    njt_desc = _describe_njt(njt)
+    t = torch.randn(B, 1, D, device=device, dtype=dtype, requires_grad=requires_grad)
+    t2 = _clone(t)
+    # used for slicing in unbind_reference()
+    t._batch_dim = 0
+    t2._batch_dim = 0
+
+    # (NT, T)
+    yield SampleInput(
+        _clone(njt),
+        args=(t,),
+        kwargs=dict(op_kwargs),
+        name=f"{njt_desc}: (NT, T) mixed broadcasting",
+    )
+    # (T, NT)
+    yield SampleInput(
+        t2,
+        args=(_clone(njt),),
+        kwargs=dict(op_kwargs),
+        name=f"{njt_desc}: (T, NT) mixed broadcasting",
+    )
+
+
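+# Generates reduction samples over each single non-batch dim (including the ragged dim),
+# over dim lists (batch+ragged, batch+ragged+other, two normal dims, all dims) when the
+# op accepts a dim list, and a full reduction with no dim arg, crossed with keepdim
+# where supported.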
+def sample_inputs_njt_reduction(
+    op_info,
+    device,
+    dtype,
+    requires_grad,
+    supports_keepdim=True,
+    op_kwargs=None,
+    **kwargs,
+):
+    if not op_kwargs:
+        op_kwargs = {}
+
+    # extract info about the dim args this op supports
+    assert op_info._extra_op_data.dim_args is not None
+    (
+        single_dim_argname,
+        dimlist_argname,
+    ) = op_info._extra_op_data.get_dim_argnames()
+    assert single_dim_argname is not None
+    supports_dimlist = dimlist_argname is not None
+
+    for njt in _sample_njts(
+        device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4]
+    ):
+        njt_desc = _describe_njt(njt)
+        keepdim_values = [False, True] if supports_keepdim else [None]
+        for keepdim in keepdim_values:
+            keepdim_suffix = f" with keepdim={keepdim}" if supports_keepdim else ""
+            # single dim-wise reduction; includes reduction over the ragged dim
+            # NB: reduction over the batch dim is not supported!
+            # TODO: Cover this in the set of error inputs
+            for dim in range(1, njt.dim()):
+                dim_desc = "normal" if dim != njt._ragged_idx else "ragged"
+                yield SampleInput(
+                    _clone(njt),
+                    kwargs={
+                        **op_kwargs,
+                        single_dim_argname: dim,
+                        **({"keepdim": keepdim} if supports_keepdim else {}),
+                    },
+                    name=f"{njt_desc}: {dim_desc} dim reduction{keepdim_suffix}",
+                )
+
+            if supports_dimlist:
+                # reduce on both batch and ragged dims
+                yield SampleInput(
+                    _clone(njt),
+                    kwargs={
+                        **op_kwargs,
+                        dimlist_argname: [0, njt._ragged_idx],
+                        **({"keepdim": keepdim} if supports_keepdim else {}),
+                    },
+                    name=f"{njt_desc}: batch+ragged reduction{keepdim_suffix}",
+                )
+
+                # reduce on batch, ragged, and other dims
+                for other_dim in range(njt._ragged_idx + 1, njt.dim()):
+                    yield SampleInput(
+                        _clone(njt),
+                        kwargs={
+                            **op_kwargs,
+                            dimlist_argname: [0, njt._ragged_idx, other_dim],
+                            **({"keepdim": keepdim} if supports_keepdim else {}),
+                        },
+                        name=(
+                            f"{njt_desc}: batch+ragged+dim={other_dim} "
+                            f"reduction{keepdim_suffix}"
+                        ),
+                    )
+
+                # reduce on two non-ragged, non-batch dims
+                if njt.dim() > 3 and njt._ragged_idx == 1:
+                    yield SampleInput(
+                        _clone(njt),
+                        kwargs={
+                            **op_kwargs,
+                            dimlist_argname: [njt.dim() - 2, njt.dim() - 1],
+                            **({"keepdim": keepdim} if supports_keepdim else {}),
+                        },
+                        name=f"{njt_desc}: two normal dim reduction{keepdim_suffix}",
+                    )
+
+                # full reduction by specifying all dims
+                yield SampleInput(
+                    _clone(njt),
+                    kwargs={
+                        **op_kwargs,
+                        dimlist_argname: list(range(njt.dim())),
+                        **({"keepdim": keepdim} if supports_keepdim else {}),
+                    },
+                    name=f"{njt_desc}: all dim reduction{keepdim_suffix}",
+                )
+
+                # TODO: Reducing on ragged dim and non-batch dim is not supported;
+                # cover this in the set of error inputs.
+
+        # full reduction
+        yield SampleInput(
+            _clone(njt),
+            kwargs=dict(op_kwargs),
+            name=f"{njt_desc}: full reduction with keepdim={keepdim}",
+        )
+
+
+def unsupported_sample_inputs_func(op_name):
+    def _f(op_info, device, dtype, requires_grad, op_name=op_name, **kwargs):
+        raise RuntimeError(
+            f"OpInfo for {op_name} does not support NJT. Support can be added by modifying "
+            "torch/testing/_internal/opinfo/definitions/nested.py."
+        )
+
+    return _f
+
+
+def unsupported_reference(op_name):
+    def _f(op, sample):
+        raise RuntimeError(
+            f"OpInfo for {op_name} does not define a ref() function. Support can be added by "
+            "modifying torch/testing/_internal/opinfo/definitions/nested.py."
+        )
+
+    return _f
+
+
+# === BEGIN OP-SPECIFIC SAMPLE INPUTS FUNCS / REFERENCES ===
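+# Generic sample inputs for unary ops taking a single dim argument (e.g. chunk, narrow,
+# select); yields one sample per dim of each NJT, covering batch, ragged, and normal dims.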
+def sample_inputs_unary_dimwise(
+    op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs
+):
+    if op_kwargs is None:
+        op_kwargs = {}
+
+    # only support a single non-list dim arg for now
+    assert op_info._extra_op_data is not None
+    single_dim_argname, dimlist_argname = op_info._extra_op_data.get_dim_argnames()
+    assert single_dim_argname is not None
+    assert dimlist_argname is None
+
+    for njt in _sample_njts(
+        device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4]
+    ):
+        for dim in range(njt.dim()):
+            kwargs = {single_dim_argname: dim}
+            kwargs.update(op_kwargs)
+            yield SampleInput(
+                _clone(njt),
+                kwargs=kwargs,
+                name=f"{_describe_njt(njt)}: {_describe_dim(njt, dim)}",
+            )
+
+
+def batchwise_reference_chunk(op, sample):
+    # reference for chunk() over dim=0
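+    # e.g. B=5, chunks=3 -> chunk_size=ceil(5/3)=2 -> chunk_sizes=[2, 2, 1], matching
+    # torch.chunk() semantics along dim=0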
+    B = sample.input.size(0)
+    num_chunks = sample.kwargs["chunks"]
+    chunk_size = math.ceil(B / num_chunks)
+    num_full_chunks = B // chunk_size
+    chunk_sizes = [chunk_size for _ in range(num_full_chunks)]
+    if B % chunk_size != 0:
+        # final chunk contains the leftovers
+        chunk_sizes.append(B % chunk_size)
+
+    # split unbound components into chunks according to calculated sizes
+    components = list(sample.input.unbind())
+    start = 0
+    chunks = []
+    for chunk_size in chunk_sizes:
+        chunks.append(components[start : start + chunk_size])
+        start += chunk_size
+
+    # rejoin into NJT outputs
+    return [torch.nested.as_nested_tensor(lst, layout=torch.jagged) for lst in chunks]
+
+
+def batchwise_reference_narrow(op, sample):
+    # TODO: write this!
+    raise NotImplementedError
+
+
+def batchwise_reference_select(op, sample):
+    # reference for select() over dim=0
+    return sample.input.unbind()[sample.kwargs["index"]]
+
+
+def batchwise_reference_split(op, sample):
+    # TODO: write this!
+    raise NotImplementedError
+
+
+def batchwise_reference_split_with_sizes(op, sample):
+    # TODO: write this!
+    raise NotImplementedError
+
+
+def batchwise_reference_unflatten(op, sample):
+    # TODO: write this!
+    raise NotImplementedError
+
+
+def batchwise_reference_unsqueeze(op, sample):
+    raise ValueError("unsqueeze() is not intended to operate on the batch dim")
+
+
+def sample_inputs_clone(op_info, device, dtype, requires_grad, **kwargs):
+    # non-contiguous NJTs
+    for njt in _sample_njts(
+        device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4]
+    ):
+        yield SampleInput(njt, name=_describe_njt(njt))
+
+    for memory_format in (torch.contiguous_format, torch.preserve_format):
+        # construct a "non-contiguous with holes" NJT
+        values = torch.randn(
+            10, 5, device=device, dtype=dtype, requires_grad=requires_grad
+        )
+        offsets = torch.tensor([0, 2, 4, 10], device=device, dtype=torch.int64)
+        lengths = torch.tensor([2, 1, 3], device=device, dtype=torch.int64)
+        njt = torch.nested.nested_tensor_from_jagged(
+            values, offsets=offsets, lengths=lengths
+        )
+
+        njt_desc = _describe_njt(njt)
+        yield SampleInput(
+            njt,
+            kwargs={"memory_format": memory_format},
+            name=f"{njt_desc}: {memory_format})",
+        )
+
+
+def sample_inputs_fill(op_info, device, dtype, requires_grad, **kwargs):
+    # scalar case
+    unary_func = partial(sample_inputs_elementwise_njt_unary, op_kwargs={"value": 42.0})
+    yield from unary_func(op_info, device, dtype, requires_grad)
+
+    # TODO: add Tensor case
+
+
+def sample_inputs_mvl_gamma(p):
+    return partial(sample_inputs_elementwise_njt_unary, op_kwargs={"p": p})
+
+
+def sample_inputs_polygamma_n(n):
+    return partial(sample_inputs_elementwise_njt_unary, op_kwargs={"n": n})
+
+
+def sample_inputs_special_polygamma_n(n):
+    return partial(sample_inputs_elementwise_njt_unary, op_kwargs={"n": n})
+
+
+def sample_inputs_to(op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs):
+    for njt in _sample_njts(
+        device=device,
+        dtype=dtype,
+        requires_grad=requires_grad,
+        dims=[2, 3, 4],
+    ):
+        other_dtypes = (
+            d for d in (torch.float32, torch.half, torch.double) if d is not dtype
+        )
+        for other_dtype in other_dtypes:
+            sample_name = f"{njt.dim()}D: {dtype} -> {other_dtype}"
+            yield SampleInput(
+                _clone(njt), kwargs={"dtype": other_dtype}, name=sample_name
+            )
+
+        # only include device transfer for CUDA inputs
+        if "cuda" in device:
+            other_device = "cpu"
+            sample_name = f"{_describe_njt(njt)}: {device} -> {other_device}"
+            yield SampleInput(
+                _clone(njt), kwargs={"device": other_device}, name=sample_name
+            )
+
+
+def sample_inputs_bmm(op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs):
+    for njt_3d in _sample_njts(
+        device=device, dtype=dtype, requires_grad=requires_grad, dims=[3]
+    ):
+        # (B, j1, D) x (B, D, E) => (B, j1, E)
+        if njt_3d._ragged_idx == 1:
+            B, D = njt_3d.shape[0], njt_3d.shape[-1]
+            E = D + 2
+            other = torch.randn(B, D, E, device=device, dtype=dtype)
+            # used for slicing in unbind_reference()
+            other._batch_dim = 0
+            njt_desc = _describe_njt(njt_3d)
+            yield SampleInput(
+                _clone(njt_3d),
+                kwargs={"mat2": other},
+                name=f"{njt_desc}: (B, j, D) x (B, D, E)",
+            )
+
+        # TODO (need factory functions):
+        # (B, D, j1) x (B, j1, E) => (B, D, E)
+
+
+def reference_bmm(op, sample):
+    # unbind reduces a dim and bmm requires 3D, so use matmul as the reference
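+    # each unbound component is 2D (j_i, D), which matmul accepts but bmm would reject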
+    matmul_op = copy(op)
+    matmul_op.op = torch.matmul
+    # change arg name from mat2 -> other
+    modified_sample = copy(sample)
+    other = modified_sample.kwargs["mat2"]
+    del modified_sample.kwargs["mat2"]
+    modified_sample.kwargs["other"] = other
+    return unbind_reference(matmul_op, modified_sample)
+
+
+def sample_inputs_chunk(op_info, device, dtype, requires_grad, **kwargs):
+    for sample_input in sample_inputs_unary_dimwise(
+        op_info, device, dtype, requires_grad, **kwargs
+    ):
+        # ragged dim chunking: test a single chunks value
+        if sample_input.kwargs["dim"] == sample_input.input._ragged_idx:
+            yield _update_sample(sample_input, {"chunks": 3})
+        # other dim chunking: test different chunks values
+        else:
+            D = sample_input.input.size(sample_input.kwargs["dim"])
+            for chunks in [1, D // 2, D - 1, D]:
+                yield _update_sample(sample_input, {"chunks": chunks})
+
+
+def sample_inputs_matmul(
+    op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs
+):
+    # also run bmm samples through
+    for sample_input in sample_inputs_bmm(op_info, device, dtype, requires_grad):
+        # change arg name from mat2 -> other
+        other = sample_input.kwargs["mat2"]
+        del sample_input.kwargs["mat2"]
+        sample_input.kwargs["other"] = other
+        yield sample_input
+
+    # 3D cases not covered by bmm
+    for njt_3d in _sample_njts(
+        device=device, dtype=dtype, requires_grad=requires_grad, dims=[3]
+    ):
+        # (B, j1, D) x (D, E) => (B, j1, E)
+        if njt_3d._ragged_idx == 1:
+            D = njt_3d.shape[-1]
+            E = D + 2
+            njt_desc = _describe_njt(njt_3d)
+            yield SampleInput(
+                _clone(njt_3d),
+                kwargs={"other": torch.randn(D, E, device=device, dtype=dtype)},
+                name=f"{njt_desc}: (B, j, D) x (D, E)",
+            )
+
+    # 4D cases
+    for njt_4d in _sample_njts(
+        device=device, dtype=dtype, requires_grad=requires_grad, dims=[4]
+    ):
+        # (B, j1, D, E) x (E, F) => (B, j1, D, F)
+        if njt_4d._ragged_idx == 1:
+            E = njt_4d.shape[-1]
+            F = E + 2
+            njt_desc = _describe_njt(njt_4d)
+            yield SampleInput(
+                _clone(njt_4d),
+                kwargs={"other": torch.randn(E, F, device=device, dtype=dtype)},
+                name=f"{njt_desc}: (B, j, D, E) x (E, F)",
+            )
+
+    # Dense x NJT cases
+    for njt_3d in _sample_njts(
+        device=device,
+        dtype=dtype,
+        requires_grad=requires_grad,
+        dims=[3],
+    ):
+        # (B, F, E) x (B, E, j1) => (B, F, j1)
+        if njt_3d._ragged_idx == 2:
+            B = njt_3d.shape[0]
+            E = njt_3d.shape[1]
+            F = E + 2
+            njt_desc = _describe_njt(njt_3d)
+            dense_t = torch.randn(
+                B, F, E, device=device, dtype=dtype, requires_grad=requires_grad
+            )
+            dense_t._batch_dim = 0  # for unbind_reference()
+            yield SampleInput(
+                dense_t,
+                args=(_clone(njt_3d),),
+                name=f"{njt_desc}: (B, F, E) x (B, E, j1)",
+            )
+
+    # NJT x NJT => Dense case
+    for njt_3d in _sample_njts(
+        device=device,
+        dtype=dtype,
+        requires_grad=requires_grad,
+        dims=[3],
+    ):
+        # (B, E, j1) x (B, j1, F) => (B, E, F)
+        if njt_3d._ragged_idx == 2 and njt_3d.is_contiguous():
+            B, E, _ = njt_3d.shape
+            sum_j1 = len(njt_3d.values())
+            other_cont = torch.randn(
+                sum_j1, E + 2, device=device, dtype=dtype, requires_grad=requires_grad
+            )
+            other_njt = torch.nested.nested_tensor_from_jagged(
+                other_cont, njt_3d.offsets(), lengths=njt_3d._lengths
+            )
+            njt_desc = _describe_njt(njt_3d)
+            yield SampleInput(
+                _clone(njt_3d),
+                kwargs={"other": _clone(other_njt)},
+                name=f"{njt_desc}: (B, E, j1) x (B, j1, F)",
+            )
+
+        # TODO (need factory functions):
+        # (B, j1, D, E) x (B, j1, E, F) => (B, j1, D, F)
+
+
+def sample_inputs_masked_select(
+    op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs
+):
+    for njt in _sample_njts(
+        device=device, dtype=dtype, requires_grad=requires_grad, dims=[2]
+    ):
+        yield SampleInput(
+            njt,
+            kwargs={"mask": (torch.randn_like(njt, requires_grad=False) < 0.0)},
+            name=_describe_njt(njt),
+        )
+
+
+def sample_inputs_narrow(op_info, device, dtype, requires_grad, **kwargs):
+    for sample_input in sample_inputs_unary_dimwise(
+        op_info, device, dtype, requires_grad, **kwargs
+    ):
+        # ragged dim narrowing: test a single start, length value
+        if sample_input.kwargs["dim"] == sample_input.input._ragged_idx:
+            yield _update_sample(sample_input, {"start": 1, "length": 2})
+        # other dim narrowing: test different start, length values
+        else:
+            D = sample_input.input.size(sample_input.kwargs["dim"])
+            for start, length in [(0, D), (0, D - 1), (1, D - 1), (D - 1, 1)]:
+                yield _update_sample(sample_input, {"start": start, "length": length})
+
+
+def sample_inputs_nn_functional_embedding(
+    op_info, device, dtype, requires_grad, **kwargs
+):
+    indices = torch.nested.nested_tensor(
+        [
+            torch.tensor([0, 2, 1, 3]),
+            torch.tensor([4, 2, 1]),
+            torch.tensor([6, 7, 5, 2, 4]),
+        ],
+        layout=torch.jagged,
+        dtype=torch.int64,
+        device=device,
+    )
+
+    NUM_EMBEDDINGS = 20
+    EMBEDDING_DIM = 32
+    weight = torch.randn(NUM_EMBEDDINGS, EMBEDDING_DIM, device=device, dtype=dtype)
+
+    # NB: the OpInfo entry for embedding expects weight first so the gradients
+    # can be checked
+    yield SampleInput(
+        _clone(weight).requires_grad_(),
+        args=(indices,),
+    )
+
+    yield SampleInput(
+        _clone(weight).requires_grad_(),
+        args=(indices,),
+        kwargs={"padding_idx": 1},
+    )
+
+
+def sample_inputs_index_put(
+    op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs
+):
+    for njt in _sample_njts(
+        device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4]
+    ):
+        for dim in range(njt.dim()):
+            indices = [
+                torch.tensor(list(range(njt.size(0))), device=njt.device),
+                *[
+                    torch.tensor([0] * njt.size(0), device=njt.device)
+                    for _ in range(dim - 1)
+                ],
+            ]
+            njt_desc = _describe_njt(njt)
+            yield SampleInput(
+                _clone(njt),
+                kwargs={
+                    "indices": indices,
+                    "values": torch.tensor(1.0, device=njt.device),
+                },
+                name=f"{njt_desc}: up to dim {dim - 1}",
+            )
+
+    # Non-cont NJT for completeness
+    offsets = torch.tensor([0, 2, 5, 7], device=device)
+    lengths = torch.tensor([2, 2, 2], device=device)
+    indices = [
+        torch.tensor([0, 1, 2], device=device),
+        torch.tensor([0, 1, 1], device=device),
+        torch.tensor([0, 0, 0], device=device),
+    ]
+    a = torch.nested.nested_tensor_from_jagged(
+        torch.zeros(7, 3, device=device), offsets, lengths
+    ).requires_grad_(requires_grad)
+
+    njt_desc = _describe_njt(a)
+    yield SampleInput(
+        _clone(a),
+        kwargs={"indices": indices, "values": torch.tensor(1.0, device=a.device)},
+        name=f"{njt_desc}: all dims",
+    )
+
+
+def sample_inputs_nn_functional_embedding_bag(
+    op_info, device, dtype, requires_grad, **kwargs
+):
+    for generate_per_sample_weight in (True, False):
+        for mode in ("sum", "mean", "max"):
+            # per_sample_weights is only supported for mode='sum'
+            if mode != "sum" and generate_per_sample_weight:
+                continue
+
+            NUM_EMBEDDINGS = 10
+            EMBEDDING_DIM = 32
+            weight = torch.randn(
+                NUM_EMBEDDINGS, EMBEDDING_DIM, dtype=dtype, device=device
+            )
+
+            njt = torch.nested.nested_tensor(
+                [
+                    torch.randint(0, NUM_EMBEDDINGS, size=(2,)),
+                    torch.randint(0, NUM_EMBEDDINGS, size=(3,)),
+                    torch.randint(0, NUM_EMBEDDINGS, size=(4,)),
+                ],
+                layout=torch.jagged,
+                dtype=torch.int64,
+                device=device,
+            )
+
+            per_sample_weights = None
+            if generate_per_sample_weight:
+                per_sample_weights = torch.randn_like(njt, dtype=dtype)
+
+            # NB: the OpInfo entry for embedding_bag expects weight first so the gradients
+            # can be checked
+            yield SampleInput(
+                weight,
+                args=(njt,),
+                kwargs={
+                    "mode": mode,
+                    "per_sample_weights": per_sample_weights,
+                },
+            )
+
+
+def reference_nn_functional_embedding_bag(op, sample):
+    # run reference on a single bag at a time
+    new_kwargs = dict(sample.kwargs)
+    new_kwargs.update(
+        {"offsets": torch.tensor([0], dtype=torch.int64, device=sample.input.device)}
+    )
+    # flip input / weight back to what unbind_reference() expects
+    sample = SampleInput(sample.args[0], args=(sample.input,), kwargs=new_kwargs)
+    old_op = op.op
+    op.op = torch.nn.functional.embedding_bag
+    output = unbind_reference(op, sample, wrap_output_as_njt=False)
+    op.op = old_op
+    # concat bag outputs to get final output
+    return torch.cat(output, dim=0)
+
+
+def sample_inputs_nn_functional_linear(op_info, device, dtype, requires_grad, **kwargs):
+    for njt in _sample_njts(
+        device=device, dtype=dtype, requires_grad=requires_grad, dims=[3, 4, 5]
+    ):
+        # projection over a ragged dim is not currently supported
+        if is_nested_int(njt.size(-1)):
+            continue
+
+        # with bias
+        NUM_OUTPUT = 10
+        weight = torch.randn(
+            NUM_OUTPUT,
+            njt.size(-1),
+            device=device,
+            dtype=dtype,
+            requires_grad=requires_grad,
+        )
+        bias = torch.randn(
+            NUM_OUTPUT, device=device, dtype=dtype, requires_grad=requires_grad
+        )
+        yield SampleInput(
+            _clone(njt),
+            kwargs={
+                "weight": _clone(weight),
+                "bias": _clone(bias),
+            },
+            name=f"{_describe_njt(njt)}: with bias",
+        )
+
+        # without bias
+        yield SampleInput(
+            _clone(njt),
+            kwargs={
+                "weight": _clone(weight),
+            },
+            name=f"{_describe_njt(njt)}: without bias",
+        )
+
+
+def sample_inputs_nn_functional_prelu(op_info, device, dtype, requires_grad, **kwargs):
+    for njt in _sample_njts(
+        device=device, dtype=dtype, requires_grad=requires_grad, dims=[3, 4]
+    ):
+        # Second dim is interpreted as number of channels; this should be non-ragged for now
+        num_channels = njt.size(1)
+        if is_nested_int(num_channels):
+            continue
+
+        # 1D weight
+        weight = torch.randn(
+            num_channels,
+            device=device,
+            dtype=dtype,
+            requires_grad=requires_grad,
+        )
+
+        yield SampleInput(
+            _clone(njt),
+            kwargs={
+                "weight": _clone(weight),
+            },
+            name=f"{_describe_njt(njt)}: 1D weight",
+        )
+
+        # scalar tensor weight
+        yield SampleInput(
+            _clone(njt),
+            kwargs={
+                "weight": torch.tensor(4.2, device=device, dtype=dtype),
+            },
+            name=f"{_describe_njt(njt)}: scalar tensor weight",
+        )
+
+
+def sample_inputs_nn_functional_rms_norm(
+    op_info, device, dtype, requires_grad, **kwargs
+):
+    for njt in _sample_njts(
+        device=device, dtype=dtype, requires_grad=requires_grad, dims=[3, 4]
+    ):
+        # normalize over non-ragged dims
+        for start_dim in range(njt.dim()):
+            if start_dim <= njt._ragged_idx:
+                continue
+
+            normalized_shape = njt.shape[start_dim:]
+            weight = torch.randn(
+                normalized_shape,
+                device=device,
+                dtype=dtype,
+                requires_grad=requires_grad,
+            )
+
+            yield SampleInput(
+                _clone(njt),
+                kwargs={
+                    "normalized_shape": normalized_shape,
+                    "weight": weight,
+                },
+                name=f"{_describe_njt(njt)}",
+            )
+
+
+sample_inputs_nn_functional_threshold = partial(
+    sample_inputs_elementwise_njt_unary,
+    op_kwargs={"threshold": float.fromhex("0x1.3ap-3"), "value": -9},
+)
+
+
+def sample_inputs_select(op_info, device, dtype, requires_grad, **kwargs):
+    for sample_input in sample_inputs_unary_dimwise(
+        op_info, device, dtype, requires_grad, **kwargs
+    ):
+        # ragged dim chunking: test a single index
+        if sample_input.kwargs["dim"] == sample_input.input._ragged_idx:
+            yield _update_sample(sample_input, {"index": 0})
+        # other dim chunking: test different indices
+        else:
+            D = sample_input.input.size(sample_input.kwargs["dim"])
+            for index in [0, D // 2, D - 1]:
+                yield _update_sample(sample_input, {"index": index})
+
+
+def sample_inputs_split(op_info, device, dtype, requires_grad, **kwargs):
+    for sample_input in sample_inputs_unary_dimwise(
+        op_info, device, dtype, requires_grad, **kwargs
+    ):
+        # ragged dim chunking: test a single split size
+        if sample_input.kwargs["dim"] == sample_input.input._ragged_idx:
+            yield _update_sample(sample_input, {"split_size_or_sections": 3})
+        # other dim chunking: test different split sizes
+        else:
+            D = sample_input.input.size(sample_input.kwargs["dim"])
+            for split_size in [1, D // 2, D - 1, D]:
+                yield _update_sample(
+                    sample_input, {"split_size_or_sections": split_size}
+                )
+
+
+def sample_inputs_split_with_sizes(op_info, device, dtype, requires_grad, **kwargs):
+    for sample_input in sample_inputs_unary_dimwise(
+        op_info, device, dtype, requires_grad, **kwargs
+    ):
+        # It will never make sense to operate on the ragged dim.
+        # TODO: Handle this with error_inputs
+        if sample_input.kwargs["dim"] == sample_input.input._ragged_idx:
+            continue
+
+        D = sample_input.input.size(sample_input.kwargs["dim"])
+        # splits should add up to D
+        split1 = torch.randint(0, D - 1, size=()).item()
+        split2 = D - split1
+        yield _update_sample(sample_input, {"split_sizes": [split1, split2]})
+
+
+def sample_inputs_squeeze(op_info, device, dtype, requires_grad, **kwargs):
+    # squeeze-specific NJT generator (need to ensure there are some 1s in the shape)
+    def _get_njts():
+        njt = random_nt_from_dims(
+            (4, None, 1, 3, 1),
+            device=device,
+            dtype=dtype,
+            requires_grad=requires_grad,
+            layout=torch.jagged,
+        )
+        yield njt
+        # without min / max seqlen cached
+        values = njt.values().detach().clone()
+        offsets = njt.offsets().detach().clone()
+        yield torch.nested.nested_tensor_from_jagged(values, offsets)
+        # non-contiguous transposed
+        yield njt.transpose(1, 3)
+        # non-contiguous with holes
+        values = njt.values().detach().clone()
+        offsets = njt.offsets().detach().clone()
+        # subtract 1 to cause holes
+        lengths = (offsets.diff() - 1).detach().clone()
+        yield torch.nested.nested_tensor_from_jagged(
+            values=values,
+            offsets=offsets,
+            lengths=lengths,
+        )
+
+    for njt in _get_njts():
+        # single dim operation
+        for dim in range(njt.dim()):
+            # Operation on batch / ragged dim is never expected to work.
+            # TODO: Handle these via error_inputs.
+            if dim == 0 or dim == njt._ragged_idx:
+                continue
+
+            yield SampleInput(
+                _clone(njt),
+                kwargs={"dim": dim},
+                name=f"{_describe_njt(njt)}: {_describe_dim(njt, dim)}",
+            )
+
+        # multiple dim operation (pass no args)
+        yield SampleInput(
+            _clone(njt),
+            name=f"{_describe_njt(njt)}: multiple dims",
+        )
+
+
+def sample_inputs_unflatten(op_info, device, dtype, requires_grad, **kwargs):
+    for sample_input in sample_inputs_unary_dimwise(
+        op_info, device, dtype, requires_grad, **kwargs
+    ):
+        # It will never make sense to operate on the ragged dim.
+        # TODO: Handle this with error_inputs
+        if sample_input.kwargs["dim"] == sample_input.input._ragged_idx:
+            continue
+
+        D = sample_input.input.size(sample_input.kwargs["dim"])
+        # sizes should multiply to be D
+        yield _update_sample(sample_input, {"sizes": [D, 1]})
+        yield _update_sample(sample_input, {"sizes": [1, D]})
+        if D % 2 == 0:
+            yield _update_sample(sample_input, {"sizes": [D // 2, 2]})
+            yield _update_sample(sample_input, {"sizes": [2, D // 2]})
+
+
+def sample_inputs_unsqueeze(op_info, device, dtype, requires_grad, **kwargs):
+    for sample_input in sample_inputs_unary_dimwise(
+        op_info, device, dtype, requires_grad, **kwargs
+    ):
+        yield sample_input
+
+        last_dim_sample = _update_sample(sample_input, {"dim": -1})
+        last_dim_sample.name = (
+            f"{_describe_njt(last_dim_sample.input)}: add dim to the end"
+        )
+        # Tell the unbind reference how to canonicalize the dim kwargs
+        # This is necessary because unsqueeze() allows for a dim after
+        # the last dim to indicate an unsqueeze at the end.
+        last_dim_sample.input._ndim = last_dim_sample.input.dim() + 1
+        yield last_dim_sample
+
+
+def sample_inputs_where(op_info, device, dtype, requires_grad, **kwargs):
+    for sample in sample_inputs_elementwise_njt_binary(
+        op_info, device, dtype, requires_grad, **kwargs
+    ):
+        other = sample.args[0]
+        sample.args = ()
+        sample.kwargs["other"] = other
+        sample.kwargs["condition"] = sample.input > 0.0
+        sample.name = sample.name.replace("(", "(NT, ")
+        yield sample
+
+
+# === END OP-SPECIFIC SAMPLE INPUTS FUNCS / REFERENCES ===
+
+
+# Mapping of OpInfo full names -> sample_inputs_funcs, which define the set of sample inputs
+# (involving NJTs) to pass to the op. Full name consists of the OpInfo's name and variant name
+# separated by a period (e.g. special.polygamma.special_polygamma_n_0). These are necessary
+# to specify if they cannot be auto-generated for some reason. Try to keep these sorted
+# in alphabetical order!
+njt_sample_inputs = {
+    "bmm": sample_inputs_bmm,
+    "chunk": sample_inputs_chunk,
+    "clone": sample_inputs_clone,
+    "count_nonzero": partial(sample_inputs_njt_reduction, supports_keepdim=False),
+    "fill": sample_inputs_fill,
+    **{f"mvlgamma.mvlgamma_p_{p}": sample_inputs_mvl_gamma(p=1) for p in (1, 3, 5)},
+    "nn.functional.embedding": sample_inputs_nn_functional_embedding,
+    "nn.functional.embedding_bag": sample_inputs_nn_functional_embedding_bag,
+    "nn.functional.linear": sample_inputs_nn_functional_linear,
+    "nn.functional.prelu": sample_inputs_nn_functional_prelu,
+    "nn.functional.rms_norm": sample_inputs_nn_functional_rms_norm,
+    "nn.functional.threshold": sample_inputs_nn_functional_threshold,
+    **{f"polygamma.polygamma_n_{n}": sample_inputs_polygamma_n(n=n) for n in range(5)},
+    "special.polygamma.special_polygamma_n_0": sample_inputs_special_polygamma_n(n=0),
+    "to": sample_inputs_to,
+    "matmul": sample_inputs_matmul,
+    "masked_select": sample_inputs_masked_select,
+    "narrow": sample_inputs_narrow,
+    "index_put": sample_inputs_index_put,
+    # these two don't have ReductionOpInfo entries
+    "max.reduction_with_dim": sample_inputs_njt_reduction,
+    "min.reduction_with_dim": sample_inputs_njt_reduction,
+    "select": sample_inputs_select,
+    "split": sample_inputs_split,
+    "split_with_sizes": sample_inputs_split_with_sizes,
+    "squeeze": sample_inputs_squeeze,
+    "unflatten": sample_inputs_unflatten,
+    "unsqueeze": sample_inputs_unsqueeze,
+    "where": sample_inputs_where,
+}
+
+njt_references = {
+    "bmm": reference_bmm,
+    "chunk": partial(
+        unary_dimwise_reference, batchwise_reference=batchwise_reference_chunk
+    ),
+    "count_nonzero": reduction_reference,
+    # these two don't have ReductionOpInfo entries
+    "max.reduction_with_dim": reduction_reference,
+    "min.reduction_with_dim": reduction_reference,
+    "narrow": partial(
+        unary_dimwise_reference, batchwise_reference=batchwise_reference_narrow
+    ),
+    "select": partial(
+        unary_dimwise_reference, batchwise_reference=batchwise_reference_select
+    ),
+    "split": partial(
+        unary_dimwise_reference, batchwise_reference=batchwise_reference_split
+    ),
+    "split_with_sizes": partial(
+        unary_dimwise_reference,
+        batchwise_reference=batchwise_reference_split_with_sizes,
+    ),
+    "squeeze": unbind_reference,
+    "nn.functional.embedding_bag": reference_nn_functional_embedding_bag,
+    "unflatten": partial(
+        unary_dimwise_reference, batchwise_reference=batchwise_reference_unflatten
+    ),
+    "unsqueeze": partial(
+        unary_dimwise_reference, batchwise_reference=batchwise_reference_unsqueeze
+    ),
+}
+
+
+# Translates an OpInfo entry to one that operates on NJTs.
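+# Ops with an entry in njt_sample_inputs get their custom sample inputs func and
+# reference (defaulting to unbind_reference); remaining unary, binary, and reduction
+# OpInfos fall back to the generic elementwise / reduction helpers; everything else is
+# marked as not supporting NJT.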
+def translate_opinfo(op):
+    new_op = copy(op)
+    new_op.supports_njt = True
+    # add some extra info for use in generating tests on the right subset of ops
+    new_op._extra_op_data = extra_op_data.get(op.full_name, ExtraOpData())
+
+    if op.full_name in njt_sample_inputs:
+        new_op.sample_inputs_func = njt_sample_inputs[op.full_name]
+        new_op.ref = njt_references.get(op.full_name, unbind_reference)
+    elif isinstance(op, UnaryUfuncInfo):
+        new_op.sample_inputs_func = partial(
+            sample_inputs_elementwise_njt_unary, op_kwargs=None
+        )
+        new_op.ref = unbind_reference
+    elif isinstance(op, BinaryUfuncInfo):
+        new_op.sample_inputs_func = partial(
+            sample_inputs_elementwise_njt_binary, op_kwargs=None
+        )
+        new_op.ref = unbind_reference
+    elif isinstance(op, ReductionOpInfo):
+        new_op.sample_inputs_func = partial(sample_inputs_njt_reduction, op_kwargs=None)
+        new_op.ref = reduction_reference
+    # TODO: Translate the rest of the OpInfos
+    else:
+        new_op.sample_inputs_func = unsupported_sample_inputs_func(op.full_name)
+        new_op.ref = unsupported_reference(op.full_name)
+        new_op.supports_njt = False
+
+    return new_op
+
+
+njt_op_db = [translate_opinfo(op) for op in op_db]
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/signal.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/signal.py
new file mode 100644
index 0000000000000000000000000000000000000000..33e517b20838e1e257542acc997d6a06e786c3a5
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/signal.py
@@ -0,0 +1,459 @@
+# mypy: ignore-errors
+
+import unittest
+from functools import partial
+from itertools import product
+from typing import Callable
+
+import numpy
+
+import torch
+from torch.testing._internal.common_dtype import floating_types
+from torch.testing._internal.common_utils import TEST_SCIPY
+from torch.testing._internal.opinfo.core import (
+    DecorateInfo,
+    ErrorInput,
+    OpInfo,
+    SampleInput,
+)
+
+
+if TEST_SCIPY:
+    import scipy.signal
+
+
+def sample_inputs_window(op_info, device, dtype, requires_grad, *args, **kwargs):
+    r"""Base function used to create sample inputs for windows.
+
+    For additional required args you should use *args, as well as **kwargs for
+    additional keyword arguments.
+    """
+
+    # Remove include_conjugated_inputs from kwargs
+    kwargs.pop("include_conjugated_inputs", None)
+    # Tests window sizes up to 5 samples.
+    for size, sym in product(range(6), (True, False)):
+        yield SampleInput(
+            size,
+            *args,
+            sym=sym,
+            device=device,
+            dtype=dtype,
+            requires_grad=requires_grad,
+            **kwargs,
+        )
+
+
+def reference_inputs_window(op_info, device, dtype, requires_grad, *args, **kwargs):
+    r"""Reference inputs function to use for windows which have a common signature, i.e.,
+    window size and sym only.
+
+    Implement other special functions for windows that have a specific signature.
+    See exponential and gaussian windows for instance.
+    """
+    yield from sample_inputs_window(
+        op_info, device, dtype, requires_grad, *args, **kwargs
+    )
+
+    cases = (8, 16, 32, 64, 128, 256)
+
+    for size in cases:
+        yield SampleInput(size, sym=False)
+        yield SampleInput(size, sym=True)
+
+
+def reference_inputs_exponential_window(
+    op_info, device, dtype, requires_grad, **kwargs
+):
+    yield from sample_inputs_window(op_info, device, dtype, requires_grad, **kwargs)
+
+    cases = (
+        (8, {"center": 4, "tau": 0.5}),
+        (16, {"center": 8, "tau": 2.5}),
+        (32, {"center": 16, "tau": 43.5}),
+        (64, {"center": 20, "tau": 3.7}),
+        (128, {"center": 62, "tau": 99}),
+        (256, {"tau": 10}),
+    )
+
+    for size, kw in cases:
+        yield SampleInput(size, sym=False, **kw)
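+        # symmetric windows require center=None (see error_inputs_exponential_window)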
+        kw["center"] = None
+        yield SampleInput(size, sym=True, **kw)
+
+
+def reference_inputs_gaussian_window(op_info, device, dtype, requires_grad, **kwargs):
+    yield from sample_inputs_window(op_info, device, dtype, requires_grad, **kwargs)
+
+    cases = (
+        (8, {"std": 0.1}),
+        (16, {"std": 1.2}),
+        (32, {"std": 2.1}),
+        (64, {"std": 3.9}),
+        (128, {"std": 4.5}),
+        (256, {"std": 10}),
+    )
+
+    for size, kw in cases:
+        yield SampleInput(size, sym=False, **kw)
+        yield SampleInput(size, sym=True, **kw)
+
+
+def reference_inputs_kaiser_window(op_info, device, dtype, requires_grad, **kwargs):
+    yield from sample_inputs_window(op_info, device, dtype, requires_grad, **kwargs)
+
+    cases = (
+        (8, {"beta": 2}),
+        (16, {"beta": 12}),
+        (32, {"beta": 30}),
+        (64, {"beta": 35}),
+        (128, {"beta": 41.2}),
+        (256, {"beta": 100}),
+    )
+
+    for size, kw in cases:
+        yield SampleInput(size, sym=False, **kw)
+        yield SampleInput(size, sym=True, **kw)
+
+
+def reference_inputs_general_cosine_window(
+    op_info, device, dtype, requires_grad, **kwargs
+):
+    yield from sample_inputs_window(op_info, device, dtype, requires_grad, **kwargs)
+
+    cases = (
+        (8, {"a": [0.5, 0.5]}),
+        (16, {"a": [0.46, 0.54]}),
+        (32, {"a": [0.46, 0.23, 0.31]}),
+        (64, {"a": [0.5]}),
+        (128, {"a": [0.1, 0.8, 0.05, 0.05]}),
+        (256, {"a": [0.2, 0.2, 0.2, 0.2, 0.2]}),
+    )
+
+    for size, kw in cases:
+        yield SampleInput(size, sym=False, **kw)
+        yield SampleInput(size, sym=True, **kw)
+
+
+def reference_inputs_general_hamming_window(
+    op_info, device, dtype, requires_grad, **kwargs
+):
+    yield from sample_inputs_window(op_info, device, dtype, requires_grad, **kwargs)
+
+    cases = (
+        (8, {"alpha": 0.54}),
+        (16, {"alpha": 0.5}),
+        (32, {"alpha": 0.23}),
+        (64, {"alpha": 0.8}),
+        (128, {"alpha": 0.9}),
+        (256, {"alpha": 0.05}),
+    )
+
+    for size, kw in cases:
+        yield SampleInput(size, sym=False, **kw)
+        yield SampleInput(size, sym=True, **kw)
+
+
+def error_inputs_window(op_info, device, *args, **kwargs):
+    # Tests for windows that have a negative size
+    yield ErrorInput(
+        SampleInput(-1, *args, dtype=torch.float32, device=device, **kwargs),
+        error_type=ValueError,
+        error_regex="requires non-negative window length, got M=-1",
+    )
+
+    # Tests for window tensors that are not torch.strided, for instance, torch.sparse_coo.
+    yield ErrorInput(
+        SampleInput(
+            3,
+            *args,
+            layout=torch.sparse_coo,
+            device=device,
+            dtype=torch.float32,
+            **kwargs,
+        ),
+        error_type=ValueError,
+        error_regex="is implemented for strided tensors only, got: torch.sparse_coo",
+    )
+
+    # Tests for window tensors that are not floating point dtypes, for instance, torch.long.
+    yield ErrorInput(
+        SampleInput(3, *args, dtype=torch.long, device=device, **kwargs),
+        error_type=ValueError,
+        error_regex="expects float32 or float64 dtypes, got: torch.int64",
+    )
+
+    # Tests for window tensors that are bfloat16
+    yield ErrorInput(
+        SampleInput(3, *args, dtype=torch.bfloat16, device=device, **kwargs),
+        error_type=ValueError,
+        error_regex="expects float32 or float64 dtypes, got: torch.bfloat16",
+    )
+
+    # Tests for window tensors that are float16
+    yield ErrorInput(
+        SampleInput(3, *args, dtype=torch.float16, device=device, **kwargs),
+        error_type=ValueError,
+        error_regex="expects float32 or float64 dtypes, got: torch.float16",
+    )
+
+
+def error_inputs_exponential_window(op_info, device, **kwargs):
+    # Yield common error inputs
+    yield from error_inputs_window(op_info, device, **kwargs)
+
+    # Tests for negative decay values.
+    yield ErrorInput(
+        SampleInput(3, tau=-1, dtype=torch.float32, device=device, **kwargs),
+        error_type=ValueError,
+        error_regex="Tau must be positive, got: -1 instead.",
+    )
+
+    # Tests for symmetric windows and a given center value.
+    yield ErrorInput(
+        SampleInput(3, center=1, sym=True, dtype=torch.float32, device=device),
+        error_type=ValueError,
+        error_regex="Center must be None for symmetric windows",
+    )
+
+
+def error_inputs_gaussian_window(op_info, device, **kwargs):
+    # Yield common error inputs
+    yield from error_inputs_window(op_info, device, std=0.5, **kwargs)
+
+    # Tests for negative standard deviations
+    yield ErrorInput(
+        SampleInput(3, std=-1, dtype=torch.float32, device=device, **kwargs),
+        error_type=ValueError,
+        error_regex="Standard deviation must be positive, got: -1 instead.",
+    )
+
+
+def error_inputs_kaiser_window(op_info, device, **kwargs):
+    # Yield common error inputs
+    yield from error_inputs_window(op_info, device, beta=12, **kwargs)
+
+    # Tests for negative beta
+    yield ErrorInput(
+        SampleInput(3, beta=-1, dtype=torch.float32, device=device, **kwargs),
+        error_type=ValueError,
+        error_regex="beta must be non-negative, got: -1 instead.",
+    )
+
+
+def error_inputs_general_cosine_window(op_info, device, **kwargs):
+    # Yield common error inputs
+    yield from error_inputs_window(op_info, device, a=[0.54, 0.46], **kwargs)
+
+    # Tests for invalid coefficient arguments
+    yield ErrorInput(
+        SampleInput(3, a=None, dtype=torch.float32, device=device, **kwargs),
+        error_type=TypeError,
+        error_regex="Coefficients must be a list/tuple",
+    )
+
+    yield ErrorInput(
+        SampleInput(3, a=[], dtype=torch.float32, device=device, **kwargs),
+        error_type=ValueError,
+        error_regex="Coefficients cannot be empty",
+    )
+
+
+def reference_signal_window(fn: Callable):
+    r"""Wrapper for scipy signal window references.
+
+    Discards the keyword arguments that window reference functions such as the gaussian window do not
+    share with their torch counterparts.
+    """
+
+    def _fn(
+        *args,
+        dtype=numpy.float64,
+        device=None,
+        layout=torch.strided,
+        requires_grad=False,
+        **kwargs,
+    ):
+        r"""The unused arguments are defined to disregard those values"""
+        return fn(*args, **kwargs).astype(dtype)
+
+    return _fn
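+
+# Usage sketch (assumes TEST_SCIPY is True so that scipy.signal is importable):
+# the wrapper lets a scipy window function silently swallow the torch-only
+# keyword arguments that the OpInfo machinery forwards to reference functions.
+#
+#     ref = reference_signal_window(scipy.signal.windows.hann)
+#     ref(5, sym=False, dtype=numpy.float32, device="cpu", requires_grad=False)
+#     # -> numpy.float32 array of length 5; device/requires_grad are discarded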
+
+
+def make_signal_windows_opinfo(
+    name: str,
+    ref: Callable,
+    sample_inputs_func: Callable,
+    reference_inputs_func: Callable,
+    error_inputs_func: Callable,
+    *,
+    skips: tuple[DecorateInfo, ...] = (),
+):
+    r"""Helper function to create OpInfo objects related to different windows."""
+    return OpInfo(
+        name=name,
+        ref=ref if TEST_SCIPY else None,
+        dtypes=floating_types(),
+        sample_inputs_func=sample_inputs_func,
+        reference_inputs_func=reference_inputs_func,
+        error_inputs_func=error_inputs_func,
+        supports_out=False,
+        supports_autograd=False,
+        skips=(
+            # TODO: same as this?
+            # https://github.com/pytorch/pytorch/issues/81774
+            # also see: arange, new_full
+            # fails to match any schemas despite working in the interpreter
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestOperatorSignatures",
+                "test_get_torch_func_signature_exhaustive",
+            ),
+            # fails to match any schemas despite working in the interpreter
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+            # skip these tests since we have non tensor input
+            DecorateInfo(
+                unittest.skip("Skipped!"), "TestCommon", "test_noncontiguous_samples"
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestCommon",
+                "test_variant_consistency_eager",
+            ),
+            DecorateInfo(unittest.skip("Skipped!"), "TestMathBits", "test_conj_view"),
+            DecorateInfo(
+                unittest.skip("Skipped!"), "TestMathBits", "test_neg_conj_view"
+            ),
+            DecorateInfo(unittest.skip("Skipped!"), "TestMathBits", "test_neg_view"),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestVmapOperatorsOpInfo",
+                "test_vmap_exhaustive",
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestVmapOperatorsOpInfo",
+                "test_op_has_batch_rule",
+            ),
+            DecorateInfo(
+                unittest.skip("Buggy on MPS for now (mistakenly promotes to float64)"),
+                "TestCommon",
+                "test_numpy_ref_mps",
+            ),
+            *skips,
+        ),
+    )
+
+
+op_db: list[OpInfo] = [
+    make_signal_windows_opinfo(
+        name="signal.windows.hamming",
+        ref=reference_signal_window(scipy.signal.windows.hamming)
+        if TEST_SCIPY
+        else None,
+        sample_inputs_func=sample_inputs_window,
+        reference_inputs_func=reference_inputs_window,
+        error_inputs_func=error_inputs_window,
+    ),
+    make_signal_windows_opinfo(
+        name="signal.windows.hann",
+        ref=reference_signal_window(scipy.signal.windows.hann) if TEST_SCIPY else None,
+        sample_inputs_func=sample_inputs_window,
+        reference_inputs_func=reference_inputs_window,
+        error_inputs_func=error_inputs_window,
+    ),
+    make_signal_windows_opinfo(
+        name="signal.windows.bartlett",
+        ref=reference_signal_window(scipy.signal.windows.bartlett)
+        if TEST_SCIPY
+        else None,
+        sample_inputs_func=sample_inputs_window,
+        reference_inputs_func=reference_inputs_window,
+        error_inputs_func=error_inputs_window,
+    ),
+    make_signal_windows_opinfo(
+        name="signal.windows.blackman",
+        ref=reference_signal_window(scipy.signal.windows.blackman)
+        if TEST_SCIPY
+        else None,
+        sample_inputs_func=sample_inputs_window,
+        reference_inputs_func=reference_inputs_window,
+        error_inputs_func=error_inputs_window,
+    ),
+    make_signal_windows_opinfo(
+        name="signal.windows.cosine",
+        ref=reference_signal_window(scipy.signal.windows.cosine)
+        if TEST_SCIPY
+        else None,
+        sample_inputs_func=sample_inputs_window,
+        reference_inputs_func=reference_inputs_window,
+        error_inputs_func=error_inputs_window,
+    ),
+    make_signal_windows_opinfo(
+        name="signal.windows.exponential",
+        ref=reference_signal_window(scipy.signal.windows.exponential)
+        if TEST_SCIPY
+        else None,
+        sample_inputs_func=partial(sample_inputs_window, tau=2.78),
+        reference_inputs_func=partial(reference_inputs_exponential_window, tau=2.78),
+        error_inputs_func=error_inputs_exponential_window,
+    ),
+    make_signal_windows_opinfo(
+        name="signal.windows.gaussian",
+        ref=reference_signal_window(scipy.signal.windows.gaussian)
+        if TEST_SCIPY
+        else None,
+        sample_inputs_func=partial(sample_inputs_window, std=1.92),
+        reference_inputs_func=partial(reference_inputs_gaussian_window, std=1.92),
+        error_inputs_func=error_inputs_gaussian_window,
+        skips=(
+            DecorateInfo(
+                unittest.skip("Buggy on MPS for now (mistakenly promotes to float64)"),
+                "TestCommon",
+                "test_numpy_ref_mps",
+            ),
+        ),
+    ),
+    make_signal_windows_opinfo(
+        name="signal.windows.kaiser",
+        ref=reference_signal_window(scipy.signal.windows.kaiser)
+        if TEST_SCIPY
+        else None,
+        sample_inputs_func=partial(sample_inputs_window, beta=12.0),
+        reference_inputs_func=partial(reference_inputs_kaiser_window, beta=12.0),
+        error_inputs_func=error_inputs_kaiser_window,
+    ),
+    make_signal_windows_opinfo(
+        name="signal.windows.general_cosine",
+        ref=reference_signal_window(scipy.signal.windows.general_cosine)
+        if TEST_SCIPY
+        else None,
+        sample_inputs_func=partial(sample_inputs_window, a=[0.54, 0.46]),
+        reference_inputs_func=partial(
+            reference_inputs_general_cosine_window, a=[0.54, 0.46]
+        ),
+        error_inputs_func=error_inputs_general_cosine_window,
+    ),
+    make_signal_windows_opinfo(
+        name="signal.windows.general_hamming",
+        ref=reference_signal_window(scipy.signal.windows.general_hamming)
+        if TEST_SCIPY
+        else None,
+        sample_inputs_func=partial(sample_inputs_window, alpha=0.54),
+        reference_inputs_func=partial(
+            reference_inputs_general_hamming_window, alpha=0.54
+        ),
+        error_inputs_func=error_inputs_window,
+    ),
+    make_signal_windows_opinfo(
+        name="signal.windows.nuttall",
+        ref=reference_signal_window(scipy.signal.windows.nuttall)
+        if TEST_SCIPY
+        else None,
+        sample_inputs_func=sample_inputs_window,
+        reference_inputs_func=reference_inputs_window,
+        error_inputs_func=error_inputs_window,
+    ),
+]
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/sparse.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/sparse.py
new file mode 100644
index 0000000000000000000000000000000000000000..41c17471d9de28427ccf74ad4e26739635c22963
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/sparse.py
@@ -0,0 +1,924 @@
+# mypy: ignore-errors
+
+import os
+
+import torch
+from torch.testing import make_tensor  # noqa: F401
+from torch.testing._internal.opinfo.core import (  # noqa: F401
+    BinaryUfuncInfo,
+    ErrorInput,
+    generate_elementwise_binary_tensors,
+    ReductionOpInfo,
+    sample_inputs_reduction,
+    SampleInput,
+)
+
+
+def _check_validate(op_info, sample):
+    def _check_fail(sample):
+        try:
+            op_info(
+                sample.sample_input.input,
+                *sample.sample_input.args,
+                **sample.sample_input.kwargs,
+            )
+        except sample.error_type:
+            pass
+        except Exception as msg:
+            raise AssertionError(  # noqa: B904
+                f"{op_info.name} on {sample.sample_input=} expected exception "
+                f"{sample.error_type}: {sample.error_regex}, got {type(msg).__name__}: {msg}"
+            )
+        else:
+            raise AssertionError(
+                f"{op_info.name} on {sample.sample_input=} expected exception "
+                f"{sample.error_type}: {sample.error_regex}, got none."
+            )
+
+    def _check_success(sample):
+        try:
+            op_info(sample.input, *sample.args, **sample.kwargs)
+        except Exception as msg:
+            raise AssertionError(  # noqa: B904
+                f"{op_info.name} on {sample=} expected to succeed "
+                f", got {type(msg).__name__}: {msg}"
+            )
+
+    if isinstance(sample, ErrorInput):
+        _check_fail(sample)
+    else:
+        _check_success(sample)
+
+
+def _sample_inputs_sparse(
+    sample_inputs,
+    maybe_failing_sample_inputs,
+    validate_sample_input,
+    op_info,
+    *args,
+    **kwargs,
+):
+    check_validate = (
+        os.environ.get("PYTORCH_TEST_CHECK_VALIDATE_SPARSE_SAMPLES", "0") == "1"
+    )
+    for sample in sample_inputs(op_info, *args, **kwargs):
+        sample = validate_sample_input(op_info, sample, check_validate=check_validate)
+        if isinstance(sample, SampleInput):
+            yield sample
+        # Error inputs are handled in error_inputs_sparse
+
+    for sample in maybe_failing_sample_inputs(op_info, *args, **kwargs):
+        sample = validate_sample_input(op_info, sample, check_validate=check_validate)
+        if isinstance(sample, SampleInput):
+            yield sample
+
+
+def _error_inputs_sparse(
+    maybe_failing_sample_inputs, validate_sample_input, op_info, *args, **kwargs
+):
+    check_validate = (
+        os.environ.get("PYTORCH_TEST_CHECK_VALIDATE_SPARSE_SAMPLES", "0") == "1"
+    )
+    for sample in maybe_failing_sample_inputs(op_info, *args, **kwargs):
+        sample = validate_sample_input(op_info, sample, check_validate=check_validate)
+        if isinstance(sample, ErrorInput):
+            yield sample
+        # Sample inputs are handled in sample_inputs_sparse
+
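+# The two generator helpers above share an opt-in consistency check: when the
+# PYTORCH_TEST_CHECK_VALIDATE_SPARSE_SAMPLES environment variable is "1",
+# every classified sample is also run through _check_validate, which asserts
+# that ErrorInput samples really raise and that plain SampleInput samples
+# really succeed.  A sketch of enabling it from Python before samples are
+# generated (setting the variable in the shell works just as well):
+#
+#     import os
+#     os.environ["PYTORCH_TEST_CHECK_VALIDATE_SPARSE_SAMPLES"] = "1"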
+
+def _apply_requires_grad_to_samples(sample_inputs):
+    """Decorator to _maybe_failing_sample_inputs_... generator functions
+    that clones and sets requires_grad argument to tensors in sample
+    input arguments. This is needed when the generated samples share
+    tensor instances.
+    """
+
+    def wrapper(op_info, device, dtype, requires_grad, layout, **kwargs):
+        def apply_requires_grad(x):
+            if (
+                not isinstance(x, torch.Tensor)
+                or x.requires_grad
+                or not requires_grad
+                or not (x.is_floating_point() or x.is_complex())
+            ):
+                return x
+            return x.detach().clone().requires_grad_(requires_grad)
+
+        if requires_grad:
+            for sample_input in sample_inputs(
+                op_info, device, dtype, requires_grad, layout, **kwargs
+            ):
+                yield sample_input.transform(apply_requires_grad)
+        else:
+            yield from sample_inputs(
+                op_info, device, dtype, requires_grad, layout, **kwargs
+            )
+
+    return wrapper
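+
+# Usage sketch (the generator name below is hypothetical; the decorator is
+# applied to a real generator further down in this file): with requires_grad
+# set, every floating-point or complex tensor in a yielded sample is replaced
+# by an independent detached clone with requires_grad enabled, so samples that
+# reuse one tensor instance no longer share autograd state.
+#
+#     @_apply_requires_grad_to_samples
+#     def _maybe_failing_sample_inputs_demo(op_info, device, dtype, requires_grad, layout, **kwargs):
+#         t = torch.tensor([[1.0, 2.0]], device=device, dtype=dtype)
+#         yield SampleInput(t, args=(t,))  # input and arg become separate clones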
+
+
+def sample_inputs_sparse_reduction(
+    op_info, device, dtype, requires_grad, layout, blocksize=None, **kwargs
+):
+    """Sample inputs for reduction operations on sparse tensors."""
+    layout_name = str(layout).split(".", 1)[-1].rsplit("_coo", 1)[0]
+    op_supports_layout = getattr(op_info, "supports_" + layout_name)
+    if not op_supports_layout:
+        return
+
+    for sample_input in sample_inputs_reduction(
+        op_info, device, dtype, requires_grad, **kwargs
+    ):
+        if sample_input.input.ndim == 0:
+            # scalar sparse tensors are not supported
+            continue
+
+        if layout in {
+            torch.sparse_csr,
+            torch.sparse_csc,
+            torch.sparse_bsr,
+            torch.sparse_bsc,
+        }:
+            if sample_input.input.ndim < 2:
+                # conversion to sparse compressed tensors requires at
+                # least 2 dimensional tensors
+                continue
+            if sample_input.input.ndim > 2 and (sample_input.input == 0).any():
+                # Skip batched sparse compressed samples that contain
+                # explicit zeros because to_sparse(layout=..) will
+                # fail, see gh-98495.
+                # TODO: remove this if-block after gh-98495 is fixed.
+                continue
+
+        if layout in {torch.sparse_bsr, torch.sparse_bsc} and blocksize is None:
+            blocksize = (1, 1)
+
+        yield SampleInput(
+            sample_input.input.detach()
+            .to_sparse(layout=layout, blocksize=blocksize)
+            .requires_grad_(requires_grad),
+            args=sample_input.args,
+            kwargs=sample_input.kwargs,
+        )
+
+        if layout is torch.sparse_coo and (dtype.is_floating_point or dtype.is_complex):
+            # uncoalesced samples
+            inp = sample_input.input.detach().to_sparse(layout=layout)
+            inp = torch.sparse_coo_tensor(
+                inp.indices().repeat(1, 2),
+                inp.values().repeat(2),
+                inp.shape,
+                dtype=inp.dtype,
+                device=inp.device,
+            )
+            assert not inp.is_coalesced()
+            yield SampleInput(
+                inp.requires_grad_(requires_grad),
+                args=sample_input.args,
+                kwargs=sample_input.kwargs,
+            )
+
+        if sample_input.input.ndim > 2:
+            # hybrid samples
+            yield SampleInput(
+                sample_input.input.detach()
+                .to_sparse(
+                    layout=layout,
+                    blocksize=blocksize,
+                    dense_dim=sample_input.input.ndim - 2,
+                )
+                .requires_grad_(requires_grad),
+                args=sample_input.args,
+                kwargs=sample_input.kwargs,
+            )
+
+
+def _validate_sample_input_sparse_reduction(op_info, sample, check_validate=False):
+    """Return the specified sample when it is valid and supported by the
+    operation. Otherwise, return the sample as an ErrorInput instance.
+
+    When check_validate is True, the result is validated by actually
+    calling the op on the sample.
+    """
+    UNSPECIFIED = object()
+    if op_info.name == "sum":
+        sample = _validate_sample_input_sparse_reduction_sum(sample)
+
+    if op_info.name in {"masked.sum"}:
+        mask = sample.kwargs.get("mask", UNSPECIFIED)
+        if (
+            mask not in {None, UNSPECIFIED}
+            and mask.ndim > 2
+            and mask.layout is torch.strided
+            and (mask == 0).any()
+        ):
+            # TODO: remove this if-block after gh-98495 is fixed.
+            sample = ErrorInput(
+                sample,
+                error_regex="Expect the same number of specified elements per batch.",
+            )
+        elif not sample.kwargs.get("keepdim"):
+            sample = ErrorInput(
+                sample,
+                error_type=(AssertionError, RuntimeError),
+                error_regex="reduction operations on (CSR|CSC) tensors with keepdim=False is unsupported",
+            )
+        elif mask is UNSPECIFIED:
+            sample = ErrorInput(
+                sample,
+                error_type=ValueError,
+                error_regex="masked (.*) expects explicit mask for sparse_csr tensor input",
+            )
+        elif sample.input.ndim > 2:
+            sample = ErrorInput(
+                sample,
+                error_regex="crow_indices is supposed to be a vector, but got 3 dimensional tensor.",
+            )
+
+    if op_info.name in {"masked.amax", "masked.amin", "masked.mean", "masked.prod"}:
+        t_inp = sample.input
+        mask = sample.kwargs.get("mask")
+        if (
+            mask is not None
+            and mask.ndim > 2
+            and mask.layout is torch.strided
+            and (mask == 0).any()
+        ):
+            # TODO: remove this if-block after gh-98495 is fixed.
+            sample = ErrorInput(
+                sample,
+                error_regex="Expect the same number of specified elements per batch.",
+            )
+        elif mask is None:
+            sample = ErrorInput(
+                sample,
+                error_type=ValueError,
+                error_regex="masked (.*) expects explicit mask for sparse_csr tensor input",
+            )
+        elif (
+            mask.layout is sample.input.layout
+            and mask.ndim > 2
+            and op_info.name == "masked.mean"
+        ):
+            sample = ErrorInput(
+                sample,
+                error_type=TypeError,
+                error_regex=(
+                    "where[(][)] received an invalid combination of arguments"
+                    " - got [(]Tensor, Tensor, NoneType[)]"
+                ),
+            )
+        elif not sample.kwargs.get("keepdim"):
+            sample = ErrorInput(
+                sample,
+                error_type=(AssertionError, RuntimeError),
+                error_regex="reduction operations on (CSR|CSC) tensors with keepdim=False is unsupported",
+            )
+        elif (
+            sample.input.ndim > 2
+            and (sample.kwargs.get("dim") not in {0, 1})
+            and mask.ndim > 2
+            and mask.layout is not torch.strided
+        ):
+            if sample.kwargs.get("dim") == (0, -1):
+                sample = ErrorInput(
+                    sample,
+                    error_regex="tensor dimensionality must be sum of batch, base, and dense dimensionalities",
+                )
+            elif op_info.name == "masked.prod":
+                sample = ErrorInput(
+                    sample,
+                    error_regex="input_dim == 2 INTERNAL ASSERT FAILED at",
+                )
+            else:
+                sample = ErrorInput(
+                    sample,
+                    error_type=AssertionError,
+                    error_regex="Sparse CSR tensors are 2D and only support reduction along dim 0 or 1.",
+                )
+        elif sample.input.ndim > 2:
+            sample = ErrorInput(
+                sample,
+                error_regex="crow_indices is supposed to be a vector, but got 3 dimensional tensor.",
+            )
+        elif (
+            mask.layout is t_inp.layout
+            and mask._nnz() != t_inp._nnz()
+            and t_inp.dense_dim() > 0
+        ):
+            sample = ErrorInput(
+                sample,
+                error_regex="Index tensor must have the same number of dimensions as src tensor",
+            )
+
+    if check_validate:
+        _check_validate(op_info, sample)
+
+    return sample
+
+
+def _validate_sample_input_sparse_reduction_sum(sample, check_validate=False):
+    # NOTE: When fixing a failing sample case, remove the
+    #       corresponding if-block
+    t_inp, t_kwargs = sample.input, sample.kwargs
+    dim = t_kwargs.get("dim")
+    keepdim = t_kwargs.get("keepdim")
+    layout = t_inp.layout
+    if isinstance(dim, (int, list, tuple)):
+        if layout in {
+            torch.sparse_csr,
+            torch.sparse_csc,
+            torch.sparse_bsr,
+            torch.sparse_bsc,
+        }:
+            if layout in {torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}:
+                return ErrorInput(
+                    sample,
+                    error_regex=(
+                        "Currently the only compressed sparse format supported for sum.dim_IntList is CSR, but got layout"
+                    ),
+                )
+            if layout in {torch.sparse_csr, torch.sparse_csc} and not keepdim:
+                return ErrorInput(
+                    sample,
+                    error_regex=(
+                        "reduction operations on CSR tensors with keepdim=False is unsupported"
+                    ),
+                )
+            if t_inp.dim() != 2:
+                return ErrorInput(
+                    sample,
+                    error_regex=("input_dim == 2 INTERNAL ASSERT"),
+                )
+            if layout == torch.sparse_csr:
+                if t_inp.dtype == torch.bool:
+                    return ErrorInput(
+                        sample,
+                        error_regex=("_sparse_csr_sum_cpu not implemented for 'Bool'"),
+                    )
+                if t_inp.dtype == torch.complex32:
+                    return ErrorInput(
+                        sample,
+                        error_regex=(
+                            "_sparse_csr_sum_cuda not implemented for 'ComplexHalf'"
+                        ),
+                    )
+    return sample
+
+
+def _maybe_failing_sample_inputs_sparse_reduction_sum(
+    op_info, device, dtype, requires_grad, layout, **kwargs
+):
+    """Generator of samples that are known to fail or that were failing in past."""
+    # NOTE: When fixing a failing case, remove the Exception comment
+    #       but keep the `yield sample` statement.
+    if layout in [
+        torch.sparse_csr,
+        torch.sparse_csc,
+    ]:
+        # NotImplementedError: Could not run 'aten::sum.IntList_out' with arguments from the 'SparseCsrCPU' backend.
+        yield SampleInput(
+            torch.tensor([[0, 1], [2, 3]], dtype=dtype)
+            .to_sparse(layout=layout)
+            .requires_grad_(requires_grad),
+            kwargs=dict(dim=0, keepdim=True),
+        )
+        yield SampleInput(
+            torch.tensor([[[0, 1]], [[2, 3]]], dtype=dtype)
+            .to_sparse(layout=layout, dense_dim=1)
+            .requires_grad_(requires_grad),
+            kwargs=dict(dim=0),
+        )
+        yield SampleInput(
+            torch.tensor([[0, 1], [2, 3]], dtype=dtype)
+            .to_sparse(layout=layout)
+            .requires_grad_(requires_grad),
+            kwargs=dict(dim=(0,)),
+        )
+        yield SampleInput(
+            torch.tensor([[0, 1], [2, 3]], dtype=dtype)
+            .to_sparse(layout=layout)
+            .requires_grad_(requires_grad),
+            kwargs=dict(dim=(0,), keepdim=True),
+        )
+        yield SampleInput(
+            torch.tensor([[[0, 1]], [[2, 3]]], dtype=dtype)
+            .to_sparse(layout=layout, dense_dim=1)
+            .requires_grad_(requires_grad),
+            kwargs=dict(dim=(0,)),
+        )
+
+        # RuntimeError: torch.empty: Only batched sparse compressed (non-block) tensors are supported, but got size [2]
+        yield SampleInput(
+            torch.tensor([[0, 1], [2, 3]], dtype=dtype)
+            .to_sparse(layout=layout)
+            .requires_grad_(requires_grad),
+            kwargs=dict(dim=0),
+        )
+
+    if layout in [
+        torch.sparse_bsr,
+        torch.sparse_bsc,
+    ]:
+        # RuntimeError: empty_sparse_compressed expected sparse compressed (non-block) tensor layout but got SparseBsr
+        yield SampleInput(
+            torch.tensor([[0, 1], [2, 3]], dtype=dtype)
+            .to_sparse(layout=layout, blocksize=(2, 2))
+            .requires_grad_(requires_grad),
+            kwargs=dict(dim=0, keepdim=True),
+        )
+        yield SampleInput(
+            torch.tensor([[[0, 1]], [[2, 3]]], dtype=dtype)
+            .to_sparse(layout=layout, dense_dim=1, blocksize=(1, 1))
+            .requires_grad_(requires_grad),
+            kwargs=dict(dim=0),
+        )
+        yield SampleInput(
+            torch.tensor([[0, 1], [2, 3]], dtype=dtype)
+            .to_sparse(layout=layout, blocksize=(1, 1))
+            .requires_grad_(requires_grad),
+            kwargs=dict(dim=(0,)),
+        )
+        yield SampleInput(
+            torch.tensor([[0, 1], [2, 3]], dtype=dtype)
+            .to_sparse(layout=layout, blocksize=(1, 1))
+            .requires_grad_(requires_grad),
+            kwargs=dict(dim=(0,), keepdim=True),
+        )
+        yield SampleInput(
+            torch.tensor([[[0, 1]], [[2, 3]]], dtype=dtype)
+            .to_sparse(layout=layout, blocksize=(1, 1), dense_dim=1)
+            .requires_grad_(requires_grad),
+            kwargs=dict(dim=(0,)),
+        )
+
+        # RuntimeError: torch.empty: Only batched sparse compressed (non-block) tensors are supported, but got size [2]
+        yield SampleInput(
+            torch.tensor([[0, 1], [2, 3]], dtype=dtype)
+            .to_sparse(layout=layout, blocksize=(1, 1))
+            .requires_grad_(requires_grad),
+            kwargs=dict(dim=0),
+        )
+
+
+def sample_inputs_sparse_reduction_sum(
+    op_info, device, dtype, requires_grad, layout, **kwargs
+):
+    """Sample inputs for sum on sparse tensors."""
+    yield from _sample_inputs_sparse(
+        sample_inputs_sparse_reduction,
+        _maybe_failing_sample_inputs_sparse_reduction_sum,
+        _validate_sample_input_sparse_reduction,
+        op_info,
+        device,
+        dtype,
+        requires_grad,
+        layout,
+        **kwargs,
+    )
+
+
+def error_inputs_sparse_reduction_sum(op_info, device, layout, **kwargs):
+    """Error inputs for sum on sparse tensors."""
+    dtype = torch.float64
+    requires_grad = False
+    yield from _error_inputs_sparse(
+        _maybe_failing_sample_inputs_sparse_reduction_sum,
+        _validate_sample_input_sparse_reduction,
+        op_info,
+        device,
+        dtype,
+        requires_grad,
+        layout,
+        **kwargs,
+    )
+
+
+def sample_inputs_sparse_elementwise_binary_operation(
+    op_info, device, dtype, requires_grad, layout, **kwargs
+):
+    """Sample inputs for elementwise binary operations on sparse tensors.
+
+    The samples include regular, zero-sized, batched, and hybrid
+    sparse tensors as well as rhs scalars. All tensors are full tensors.
+    """
+
+    def _to_sparse(tensor, **kwargs):
+        return tensor.detach().to_sparse(**kwargs).requires_grad_(requires_grad)
+
+    for sample_input in generate_elementwise_binary_tensors(
+        op_info,
+        device=device,
+        dtype=dtype,
+        requires_grad=requires_grad,
+        exclude_zero=True,
+        **kwargs,
+    ):
+        lhs, rhs = sample_input.input, sample_input.args[0]
+        min_dense_dim = 0
+        max_dense_dim = lhs.ndim - 1
+        if layout in {
+            torch.sparse_csr,
+            torch.sparse_csc,
+            torch.sparse_bsr,
+            torch.sparse_bsc,
+        }:
+            if lhs.ndim < 2:
+                # sparse compressed tensors require sparse_dim == 2
+                continue
+            max_dense_dim = lhs.ndim - 2
+
+        for dense_dim in range(min_dense_dim, max_dense_dim + 1):
+            if layout in {torch.sparse_bsr, torch.sparse_bsc}:
+                blocksizes = [(1, 1)]
+                if lhs.numel() > 0:
+                    blocksizes.append(
+                        (
+                            lhs.shape[lhs.ndim - 2 - dense_dim],
+                            lhs.shape[lhs.ndim - 1 - dense_dim],
+                        )
+                    )
+            else:
+                blocksizes = [None]
+            for blocksize in blocksizes:
+                to_sparse_kwargs = dict(
+                    layout=layout, dense_dim=dense_dim, blocksize=blocksize
+                )
+                lhs_sparse = _to_sparse(lhs, **to_sparse_kwargs)
+                rhs_sparse = _to_sparse(rhs, **to_sparse_kwargs)
+                # op(sparse, sparse)
+                yield SampleInput(
+                    lhs_sparse,
+                    args=(rhs_sparse, *sample_input.args[1:]),
+                    kwargs=sample_input.kwargs,
+                )
+                # op(sparse, scalar)
+                yield SampleInput(
+                    lhs_sparse,
+                    args=(
+                        make_tensor(
+                            (), dtype=dtype, device=device, requires_grad=requires_grad
+                        ),
+                        *sample_input.args[1:],
+                    ),
+                    kwargs=sample_input.kwargs,
+                )
+
+
+def _validate_sample_input_elementwise_binary_sparse_mul(sample):
+    # NOTE: When fixing a failing sample case, remove the
+    #       corresponding if-block
+    t_inp, t_args = sample.input, sample.args
+    batch_dim = t_inp.dim() - t_inp.dense_dim() - t_inp.sparse_dim()
+    layout = t_inp.layout
+    dtype = t_inp.dtype
+    if layout is torch.sparse_csr and batch_dim > 0 and t_args[0].ndim > 0:
+        return ErrorInput(
+            sample,
+            error_regex=(
+                "coo_to_sparse_csr: conversion from Sparse to SparseCsr for input"
+                " tensors with sparse_dim[(][)]!=2 is not supported"
+            ),
+        )
+    elif layout is torch.sparse_csc and t_args[0].ndim > 0:
+        return ErrorInput(
+            sample, error_regex="Expected result Tensor to be of format CSR"
+        )
+    elif layout is torch.sparse_bsr and t_args[0].ndim > 0:
+        return ErrorInput(
+            sample,
+            error_regex="empty_sparse_compressed expected sparse compressed [(]non-block[)] tensor layout but got SparseBsr",
+        )
+    elif layout is torch.sparse_bsc and t_args[0].ndim > 0:
+        return ErrorInput(
+            sample,
+            error_regex="empty_sparse_compressed expected sparse compressed [(]non-block[)] tensor layout but got SparseBsc",
+        )
+    elif (
+        layout is torch.sparse_coo
+        and dtype is torch.bool
+        and t_args[0].ndim > 0
+        and t_inp.is_cpu
+        and t_inp.numel() > 0
+        and t_inp.dense_dim() > 0
+    ):
+        return ErrorInput(
+            sample, error_regex="\"addcmul_cpu_out\" not implemented for 'Bool'"
+        )
+    elif (
+        layout in {torch.sparse_coo, torch.sparse_csr}
+        and dtype is torch.bool
+        and t_inp._nnz() > 0
+        and t_args[0].ndim > 0
+        and t_inp.is_cpu
+        and t_inp.numel() > 0
+    ):
+        return ErrorInput(
+            sample, error_regex="\"mul_out_sparse\" not implemented for 'Bool'"
+        )
+    elif (
+        layout is torch.sparse_csr
+        and t_args[0].layout is torch.strided
+        and 0 < t_args[0].ndim
+        and t_args[0].ndim < t_inp.ndim
+    ):
+        return ErrorInput(
+            sample, error_regex="sparse_mask_sparse_csr expects self to be 2D"
+        )
+    elif layout is torch.sparse_csr and (
+        (t_args[0].layout is torch.strided and 0 < t_args[0].ndim)
+        or (t_args[0].layout is layout and t_inp.shape != t_args[0].shape)
+    ):
+        return ErrorInput(
+            sample,
+            error_regex=(
+                "expects sparse inputs with equal dimensionality, number of sparse dimensions,"
+                " and shape of sparse dimensions"
+            ),
+        )
+    elif (
+        layout is torch.sparse_csr
+        and t_inp.dense_dim() > 0
+        and t_inp._nnz() > 0
+        and t_inp.is_cpu
+        and dtype is torch.float16
+        and t_args[0].ndim > 0
+    ):
+        return ErrorInput(
+            sample, error_regex="\"addcmul_cpu_out\" not implemented for 'Half'"
+        )
+    return sample
+
+
+@_apply_requires_grad_to_samples
+def _maybe_failing_sample_inputs_sparse_elementwise_binary_mul(
+    op_info, device, dtype, requires_grad, layout, **kwargs
+):
+    """Generator of samples that are known to fail or that were failing in past."""
+    # NOTE: When fixing a failing case, remove the Exception comment
+    #       but keep the `yield sample` statement.
+
+    blocksize = (1, 1) if layout in {torch.sparse_bsr, torch.sparse_bsc} else None
+    regular = torch.tensor([[1, 2], [3, 4]], device=device, dtype=dtype).to_sparse(
+        layout=layout, dense_dim=0, blocksize=blocksize
+    )
+    batch = torch.tensor(
+        [[[1, 2], [3, 4]], [[4, 5], [6, 7]]], device=device, dtype=dtype
+    ).to_sparse(layout=layout, dense_dim=0, blocksize=blocksize)
+    hybrid = torch.tensor(
+        [[[1], [2]], [[3], [4]]], device=device, dtype=dtype
+    ).to_sparse(layout=layout, dense_dim=1, blocksize=blocksize)
+
+    if layout is torch.sparse_csr:
+        # RuntimeError: crow_indices is supposed to be a vector, but got 2 dimensional tensor
+        yield SampleInput(batch, args=(batch,))
+        # RuntimeError: Only tensors with two sparse dimensions can be
+        # converted to the SparseCsr layout, got self with 3 sparse
+        # dimensions.
+        yield SampleInput(
+            torch.zeros_like(hybrid).requires_grad_(requires_grad),
+            args=(torch.zeros_like(hybrid).requires_grad_(requires_grad),),
+        )
+        if dtype is torch.complex32:
+            # RuntimeError: "mul_out_sparse" not implemented for 'ComplexHalf'
+            yield SampleInput(regular, args=(regular,))
+        if dtype is torch.bool and regular.is_cpu:
+            # RuntimeError: "mul_out_sparse" not implemented for 'Bool'
+            yield SampleInput(regular, args=(regular,))
+    if layout is torch.sparse_csc:
+        # RuntimeError: Expected result Tensor to be of format CSR
+        yield SampleInput(regular, args=(regular,))
+    if layout is torch.sparse_bsr:
+        # RuntimeError: empty_sparse_compressed expected sparse compressed (non-block) tensor layout but got SparseBsr
+        yield SampleInput(regular, args=(regular,))
+    if layout is torch.sparse_bsc:
+        # RuntimeError: empty_sparse_compressed expected sparse compressed (non-block) tensor layout but got SparseBsc
+        yield SampleInput(regular, args=(regular,))
+    if layout is torch.sparse_coo:
+        if dtype is torch.complex32:
+            # RuntimeError: "mul_out_sparse" not implemented for 'ComplexHalf'
+            yield SampleInput(regular, args=(regular,))
+        if dtype is torch.bool and regular.is_cpu:
+            # RuntimeError: "mul_out_sparse" not implemented for 'Bool'
+            yield SampleInput(regular, args=(regular,))
+        if dtype in {torch.bool, torch.float16} and regular.is_cpu:
+            # RuntimeError: "addcmul_cpu_out" not implemented for '(Bool|Half)'
+            yield SampleInput(hybrid, args=(hybrid,))
+
+
+def _validate_sample_input_sparse_elementwise_binary_operation(
+    op_info, sample, check_validate=False
+):
+    if op_info.name == "mul":
+        sample = _validate_sample_input_elementwise_binary_sparse_mul(sample)
+
+    if check_validate:
+        _check_validate(op_info, sample)
+    return sample
+
+
+def sample_inputs_sparse_mul(op_info, device, dtype, requires_grad, layout, **kwargs):
+    """Sample inputs for mul operation on sparse tensors."""
+    yield from _sample_inputs_sparse(
+        sample_inputs_sparse_elementwise_binary_operation,
+        _maybe_failing_sample_inputs_sparse_elementwise_binary_mul,
+        _validate_sample_input_sparse_elementwise_binary_operation,
+        op_info,
+        device,
+        dtype,
+        requires_grad,
+        layout,
+        **kwargs,
+    )
+
+
+def error_inputs_sparse_mul(op_info, device, layout, **kwargs):
+    """Error inputs for mul operation on sparse tensors."""
+    dtype = torch.float64
+    requires_grad = False
+    yield from _error_inputs_sparse(
+        _maybe_failing_sample_inputs_sparse_elementwise_binary_mul,
+        _validate_sample_input_sparse_elementwise_binary_operation,
+        op_info,
+        device,
+        dtype,
+        requires_grad,
+        layout,
+        **kwargs,
+    )
+
+
+def _sample_inputs_sparse_like_fns(
+    op_info, device, dtype, requires_grad, layout, **kwargs
+):
+    from torch.testing._internal.common_utils import TestCase
+
+    for tensor in TestCase().generate_simple_inputs(
+        layout,
+        device=device,
+        dtype=dtype,
+        enable_batch=True,
+        enable_hybrid=True,
+        enable_zero_sized=True,
+        enable_non_contiguous_indices=False,
+        enable_non_contiguous_values=False,
+    ):
+        yield SampleInput(tensor, args=(), kwargs={})
+        yield SampleInput(
+            tensor, args=(), kwargs=dict(device=device, dtype=dtype, layout=layout)
+        )
+
+        if dtype is not torch.float64:
+            yield SampleInput(tensor, args=(), kwargs=dict(dtype=torch.float64))
+
+        if torch.cuda.is_available():
+            other_device = "cuda" if tensor.device.type == "cpu" else "cpu"
+            yield SampleInput(tensor, args=(), kwargs=dict(device=other_device))
+
+        if layout is torch.sparse_csr:
+            other_layout = torch.sparse_csc
+        elif layout is torch.sparse_csc:
+            other_layout = torch.sparse_csr
+        elif layout is torch.sparse_bsr:
+            other_layout = torch.sparse_bsc
+        elif layout is torch.sparse_bsc:
+            other_layout = torch.sparse_bsr
+        else:
+            other_layout = torch.strided
+        yield SampleInput(tensor, args=(), kwargs=dict(layout=other_layout))
+
+        if layout is not torch.sparse_coo:
+            yield SampleInput(tensor, args=(), kwargs=dict(layout=torch.sparse_coo))
+
+
+def _validate_sample_input_sparse_like_fns(op_info, sample, check_validate=False):
+    if sample.input.layout in {
+        torch.sparse_csr,
+        torch.sparse_csc,
+        torch.sparse_bsr,
+        torch.sparse_bsc,
+    } and op_info.name not in {"zeros_like"}:
+        if sample.kwargs.get("layout", sample.input.layout) != sample.input.layout:
+            return ErrorInput(
+                sample,
+                error_regex=(
+                    "empty_like with different sparse layout is not supported"
+                    " \\(self is Sparse(Csc|Csr|Bsc|Bsr) but you requested Sparse(Csr|Csc|Bsr|Bsc)\\)"
+                ),
+            )
+    if sample.input.layout is torch.sparse_coo:
+        return ErrorInput(
+            sample,
+            error_regex=(
+                "Could not run 'aten::normal_' with arguments from the 'Sparse(CPU|CUDA)' backend."
+            ),
+        )
+    if check_validate:
+        _check_validate(op_info, sample)
+    return sample
+
+
+def _maybe_failing_sample_inputs_sparse_like_fns(
+    op_info, device, dtype, requires_grad, layout, **kwargs
+):
+    if torch.cuda.is_available() and layout is not torch.sparse_coo:
+        other_device = "cuda" if torch.device(device).type == "cpu" else "cpu"
+        if layout is torch.sparse_csr:
+            other_layout = torch.sparse_csc
+        elif layout is torch.sparse_csc:
+            other_layout = torch.sparse_csr
+        elif layout is torch.sparse_bsr:
+            other_layout = torch.sparse_bsc
+        elif layout is torch.sparse_bsc:
+            other_layout = torch.sparse_bsr
+        else:
+            other_layout = torch.strided
+
+        blocksize = (1, 1) if layout in {torch.sparse_bsr, torch.sparse_bsc} else None
+
+        yield SampleInput(
+            torch.tensor([[0, 1], [2, 3]], dtype=dtype, device=device).to_sparse(
+                layout=layout, blocksize=blocksize
+            ),
+            kwargs=dict(device=other_device),
+        )
+
+        yield SampleInput(
+            torch.tensor([[0, 1], [2, 3]], dtype=dtype, device=device).to_sparse(
+                layout=layout, blocksize=blocksize
+            ),
+            kwargs=dict(layout=other_layout),
+        )
+
+
+def sample_inputs_sparse_like_fns(
+    op_info, device, dtype, requires_grad, layout, **kwargs
+):
+    """Sample inputs for like-functions on sparse tensors."""
+    yield from _sample_inputs_sparse(
+        _sample_inputs_sparse_like_fns,
+        _maybe_failing_sample_inputs_sparse_like_fns,
+        _validate_sample_input_sparse_like_fns,
+        op_info,
+        device,
+        dtype,
+        requires_grad,
+        layout,
+        **kwargs,
+    )
+
+
+def error_inputs_sparse_like_fns(op_info, device, layout, **kwargs):
+    """Error inputs for like-functions on sparse tensors."""
+    dtype = torch.float64
+    requires_grad = False
+    yield from _error_inputs_sparse(
+        _maybe_failing_sample_inputs_sparse_like_fns,
+        _validate_sample_input_sparse_like_fns,
+        op_info,
+        device,
+        dtype,
+        requires_grad,
+        layout,
+        **kwargs,
+    )
+
+
+def _validate_sample_input_sparse_default(op_info, sample, check_validate=False):
+    if op_info.name == "to_sparse":
+        if (
+            sample.input.layout
+            in {torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}
+            and len(sample.args) == 1
+            and isinstance(sample.args[0], int)
+            and sample.args[0] != 2
+        ):
+            sample = ErrorInput(
+                sample,
+                error_regex="sparse dim argument must be 2 for sparse_compressed_to_sparse",
+            )
+
+    if check_validate:
+        _check_validate(op_info, sample)
+    return sample
+
+
+def validate_sample_input_sparse(op_info, sample, check_validate=False):
+    """Return the specified sample when it is valid and supported by the
+    operation. Otherwise, return the sample as an ErrorInput instance.
+
+    When check_validate is True, the result is validated by actually
+    calling the op on the sample.
+    """
+    if isinstance(op_info, ReductionOpInfo):
+        return _validate_sample_input_sparse_reduction(
+            op_info, sample, check_validate=check_validate
+        )
+    elif isinstance(op_info, BinaryUfuncInfo):
+        return _validate_sample_input_sparse_elementwise_binary_operation(
+            op_info, sample, check_validate=check_validate
+        )
+    else:
+        return _validate_sample_input_sparse_default(
+            op_info, sample, check_validate=check_validate
+        )
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/special.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/special.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5b09f4c8dce7eb230e8ccdd9c74423bf26ca5ac
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/definitions/special.py
@@ -0,0 +1,835 @@
+# mypy: ignore-errors
+
+import unittest
+from functools import partial
+from itertools import product
+
+import numpy as np
+
+import torch
+from torch.testing import make_tensor
+from torch.testing._internal.common_device_type import (
+    precisionOverride,
+    tol,
+    toleranceOverride,
+)
+from torch.testing._internal.common_dtype import all_types_and, floating_types
+from torch.testing._internal.common_utils import TEST_SCIPY, torch_to_numpy_dtype_dict
+from torch.testing._internal.opinfo.core import (
+    BinaryUfuncInfo,
+    DecorateInfo,
+    L,
+    NumericsFilter,
+    OpInfo,
+    S,
+    SampleInput,
+    UnaryUfuncInfo,
+)
+from torch.testing._internal.opinfo.refs import (
+    ElementwiseBinaryPythonRefInfo,
+    ElementwiseUnaryPythonRefInfo,
+)
+from torch.testing._internal.opinfo.utils import (
+    np_unary_ufunc_integer_promotion_wrapper,
+)
+
+
+if TEST_SCIPY:
+    import scipy.special
+
+
+# TODO: Consolidate `i0e` with sample_inputs_unary when `make_tensor`
+#       supports the `exclude` argument.
+#       For more context: https://github.com/pytorch/pytorch/pull/56352#discussion_r633277617
+def sample_inputs_i0_i1(op_info, device, dtype, requires_grad, **kwargs):
+    exclude_zero = requires_grad and op_info.op == torch.special.i0e
+    make_arg = partial(
+        make_tensor,
+        dtype=dtype,
+        device=device,
+        requires_grad=requires_grad,
+        exclude_zero=exclude_zero,
+    )
+    yield SampleInput(make_arg((S,)))
+    yield SampleInput(make_arg(()))
+
+    if requires_grad and not exclude_zero:
+        # Special Case for gradient
+        # Sample with `0` in the input
+        t = make_arg((S,))
+        t[0] = 0
+
+        yield SampleInput(t)
+
+
+def sample_inputs_polygamma(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(
+        make_tensor,
+        device=device,
+        # TODO: eliminate low after gh-106692 is fixed:
+        low=(1 if dtype in {torch.int32, torch.int64} else None),
+        dtype=dtype,
+        requires_grad=requires_grad,
+    )
+    tensor_shapes = ((S, S), ())
+    ns = (1, 2, 3, 4, 5)
+
+    for shape, n in product(tensor_shapes, ns):
+        yield SampleInput(make_arg(shape), args=(n,))
+
+
+def reference_polygamma(x, n):
+    # WEIRD `scipy.special.polygamma` behavior
+    # >>> scipy.special.polygamma(0, np.array(501, dtype=np.float32)).dtype
+    # dtype('float64')
+    # >>> scipy.special.polygamma(0, np.array([501], dtype=np.float32)).dtype
+    # dtype('float32')
+    #
+    # Thus we cast output to the default torch dtype or preserve double
+    result_dtype = torch_to_numpy_dtype_dict[torch.get_default_dtype()]
+    if x.dtype == np.double:
+        result_dtype = np.double
+    return scipy.special.polygamma(n, x).astype(result_dtype)
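+
+# Small sketch of the effect (assuming the default torch dtype is float32):
+#
+#     reference_polygamma(np.array([501.0], dtype=np.float32), 0).dtype  # float32
+#     reference_polygamma(np.array(501.0, dtype=np.float64), 0).dtype    # float64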
+
+
+def sample_inputs_entr(op_info, device, dtype, requires_grad, **kwargs):
+    low, _ = op_info.domain
+
+    if requires_grad:
+        low = 0 + op_info._domain_eps
+
+    make_arg = partial(
+        make_tensor, dtype=dtype, device=device, low=low, requires_grad=requires_grad
+    )
+    yield SampleInput(make_arg((L,)))
+    yield SampleInput(make_arg(()))
+
+
+def sample_inputs_erfcx(op_info, device, dtype, requires_grad, **kwargs):
+    for shape in ((L,), (1, 0, 3), ()):
+        yield SampleInput(
+            make_tensor(
+                shape,
+                device=device,
+                dtype=dtype,
+                low=-5,
+                requires_grad=requires_grad,
+            ),
+        )
+
+
+op_db: list[OpInfo] = [
+    UnaryUfuncInfo(
+        "special.i0e",
+        aten_name="special_i0e",
+        ref=scipy.special.i0e if TEST_SCIPY else None,
+        decorators=(precisionOverride({torch.bfloat16: 3e-1, torch.float16: 3e-1}),),
+        dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
+        sample_inputs_func=sample_inputs_i0_i1,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+    ),
+    UnaryUfuncInfo(
+        "special.i1",
+        aten_name="special_i1",
+        ref=np_unary_ufunc_integer_promotion_wrapper(scipy.special.i1)
+        if TEST_SCIPY
+        else None,
+        dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
+        backward_dtypes=floating_types(),
+        sample_inputs_func=sample_inputs_i0_i1,
+        decorators=(
+            DecorateInfo(
+                toleranceOverride(
+                    {
+                        torch.float32: tol(atol=1e-4, rtol=0),
+                        torch.bool: tol(atol=1e-4, rtol=0),
+                    }
+                )
+            ),
+        ),
+        skips=(
+            DecorateInfo(
+                unittest.skip("Incorrect result!"),
+                "TestUnaryUfuncs",
+                "test_reference_numerics_large",
+                dtypes=(torch.int8,),
+            ),
+        ),
+        supports_fwgrad_bwgrad=True,
+        supports_forward_ad=True,
+    ),
+    UnaryUfuncInfo(
+        "special.i1e",
+        aten_name="special_i1e",
+        ref=scipy.special.i1e if TEST_SCIPY else None,
+        dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
+        backward_dtypes=floating_types(),
+        sample_inputs_func=sample_inputs_i0_i1,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+    ),
+    UnaryUfuncInfo(
+        "special.ndtr",
+        aten_name="special_ndtr",
+        decorators=(precisionOverride({torch.bfloat16: 5e-3, torch.float16: 5e-4}),),
+        ref=scipy.special.ndtr if TEST_SCIPY else None,
+        dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        skips=(
+            # Dispatch stub: unsupported device type: meta
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestFwdGradients",
+                "test_fn_fwgrad_bwgrad",
+                device_type="meta",
+            ),
+        ),
+    ),
+    # A separate OpInfo entry for special.polygamma is needed to reorder the arguments
+    # for the alias. See the discussion here: https://github.com/pytorch/pytorch/pull/59691#discussion_r650261939
+    UnaryUfuncInfo(
+        "special.polygamma",
+        op=lambda x, n, **kwargs: torch.special.polygamma(n, x, **kwargs),
+        variant_test_name="special_polygamma_n_0",
+        ref=reference_polygamma if TEST_SCIPY else None,
+        dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_polygamma,
+        skips=(
+            # lambda impl
+            DecorateInfo(
+                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
+            ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+        ),
+        sample_kwargs=lambda device, dtype, input: ({"n": 0}, {"n": 0}),
+        # polygamma functions have multiple singularities at non-positive integer values of x
+        reference_numerics_filter=NumericsFilter(
+            condition=lambda x: (x < 0.1) & ((x - x.round()).abs() < 1e-4), safe_val=1
+        ),
+    ),
+    BinaryUfuncInfo(
+        "special.xlog1py",
+        aten_name="special_xlog1py",
+        dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
+        promotes_int_to_float=True,
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        supports_one_python_scalar=True,
+        # We don't test -1 as the gradient will be NaN and it'll break
+        rhs_make_tensor_kwargs=dict(low=-0.99),
+    ),
+    BinaryUfuncInfo(
+        "special.zeta",
+        aten_name="special_zeta",
+        dtypes=all_types_and(torch.bool),
+        promotes_int_to_float=True,
+        supports_autograd=False,
+        supports_one_python_scalar=True,
+        skips=(
+            # The reference_inputs contain nans and infs on cuda, and nan, inf, 0., -inf on cpu
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
+        ),
+    ),
+    # TODO: FIXME
+    # OpInfo entry to verify the gradient formula of `other`/`q`
+    # BinaryUfuncInfo('special.zeta',
+    #                 op=lambda q, x, **kwargs: torch.special.zeta(x, q, **kwargs),
+    #                 aten_name='special_zeta',
+    #                 variant_test_name='grad',
+    #                 dtypes=all_types_and(torch.bool),
+    #                 promotes_int_to_float=True,
+    #                 supports_autograd=True,
+    #                 supports_rhs_python_scalar=False,
+    #                 decorators=[
+    #                     # Derivative wrt first tensor not implemented
+    #                     DecorateInfo(unittest.expectedFailure, "TestCommon",
+    #                                  "test_floating_inputs_are_differentiable")
+    #                 ],
+    #                 skips=(
+    #                     # Lambda doesn't work in JIT test
+    #                     # AssertionError: JIT Test does not execute any logic
+    #                     DecorateInfo(unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit"),
+    #                 )),
+    UnaryUfuncInfo(
+        "special.entr",
+        ref=scipy.special.entr if TEST_SCIPY else None,
+        aten_name="special_entr",
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        decorators=(precisionOverride({torch.float16: 1e-1, torch.bfloat16: 1e-1}),),
+        dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestUnaryUfuncs",
+                "test_reference_numerics_large",
+                dtypes=[torch.bfloat16, torch.float16],
+            ),
+        ),
+        supports_inplace_autograd=False,
+        sample_inputs_func=sample_inputs_entr,
+    ),
+    UnaryUfuncInfo(
+        "special.ndtri",
+        ref=scipy.special.ndtri if TEST_SCIPY else None,
+        domain=(0, 1),
+        aten_name="special_ndtri",
+        dtypes=all_types_and(torch.bool),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+    ),
+    UnaryUfuncInfo(
+        "special.log_ndtr",
+        aten_name="special_log_ndtr",
+        ref=scipy.special.log_ndtr if TEST_SCIPY else None,
+        dtypes=all_types_and(torch.bool),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+    ),
+    UnaryUfuncInfo(
+        "special.erfcx",
+        ref=scipy.special.erfcx if TEST_SCIPY else None,
+        aten_name="special_erfcx",
+        decorators=(
+            toleranceOverride(
+                {
+                    torch.float32: tol(atol=0, rtol=4e-6),
+                }
+            ),
+        ),
+        dtypes=all_types_and(torch.bool),
+        supports_forward_ad=True,
+        supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_erfcx,
+    ),
+    UnaryUfuncInfo(
+        "special.airy_ai",
+        decorators=(
+            precisionOverride(
+                {
+                    torch.float32: 1e-03,
+                    torch.float64: 1e-05,
+                },
+            ),
+        ),
+        dtypes=all_types_and(torch.bool),
+        ref=lambda x: scipy.special.airy(x)[0] if TEST_SCIPY else None,
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestUnaryUfuncs",
+                "test_reference_numerics_large",
+            ),
+        ),
+        supports_autograd=False,
+    ),
+    UnaryUfuncInfo(
+        "special.bessel_j0",
+        decorators=(
+            precisionOverride(
+                {
+                    torch.float32: 1e-04,
+                    torch.float64: 1e-05,
+                },
+            ),
+        ),
+        dtypes=all_types_and(torch.bool),
+        ref=scipy.special.j0 if TEST_SCIPY else None,
+        supports_autograd=False,
+    ),
+    UnaryUfuncInfo(
+        "special.bessel_j1",
+        decorators=(
+            precisionOverride(
+                {
+                    torch.float32: 1e-04,
+                    torch.float64: 1e-05,
+                },
+            ),
+        ),
+        dtypes=all_types_and(torch.bool),
+        ref=scipy.special.j1 if TEST_SCIPY else None,
+        supports_autograd=False,
+    ),
+    UnaryUfuncInfo(
+        "special.bessel_y0",
+        decorators=(
+            precisionOverride(
+                {
+                    torch.float32: 1e-04,
+                    torch.float64: 1e-05,
+                },
+            ),
+        ),
+        dtypes=all_types_and(torch.bool),
+        ref=scipy.special.y0 if TEST_SCIPY else None,
+        supports_autograd=False,
+    ),
+    UnaryUfuncInfo(
+        "special.bessel_y1",
+        decorators=(
+            precisionOverride(
+                {
+                    torch.float32: 1e-04,
+                    torch.float64: 1e-05,
+                },
+            ),
+        ),
+        dtypes=all_types_and(torch.bool),
+        ref=scipy.special.y1 if TEST_SCIPY else None,
+        supports_autograd=False,
+    ),
+    BinaryUfuncInfo(
+        "special.chebyshev_polynomial_t",
+        dtypes=all_types_and(torch.bool),
+        promotes_int_to_float=True,
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
+            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
+            DecorateInfo(
+                unittest.skip("testing takes an unreasonably long time, #79528"),
+                "TestCommon",
+                "test_compare_cpu",
+            ),
+        ),
+        supports_one_python_scalar=True,
+        supports_autograd=False,
+    ),
+    BinaryUfuncInfo(
+        "special.chebyshev_polynomial_u",
+        dtypes=all_types_and(torch.bool),
+        promotes_int_to_float=True,
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
+            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
+            DecorateInfo(
+                unittest.skip("testing takes an unreasonably long time, #79528"),
+                "TestCommon",
+                "test_compare_cpu",
+            ),
+        ),
+        supports_one_python_scalar=True,
+        supports_autograd=False,
+    ),
+    BinaryUfuncInfo(
+        "special.chebyshev_polynomial_v",
+        dtypes=all_types_and(torch.bool),
+        promotes_int_to_float=True,
+        skips=(
+            DecorateInfo(
+                unittest.skip(
+                    "Skipping - testing takes an unreasonably long time, #79528"
+                )
+            ),
+            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
+            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
+        ),
+        supports_one_python_scalar=True,
+        supports_autograd=False,
+    ),
+    BinaryUfuncInfo(
+        "special.chebyshev_polynomial_w",
+        dtypes=all_types_and(torch.bool),
+        promotes_int_to_float=True,
+        skips=(
+            DecorateInfo(
+                unittest.skip(
+                    "Skipping - testing takes an unreasonably long time, #79528"
+                )
+            ),
+            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
+            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
+        ),
+        supports_one_python_scalar=True,
+        supports_autograd=False,
+    ),
+    BinaryUfuncInfo(
+        "special.hermite_polynomial_h",
+        dtypes=all_types_and(torch.bool),
+        promotes_int_to_float=True,
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
+            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
+            # Greatest absolute difference: inf
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
+        ),
+        supports_one_python_scalar=True,
+        supports_autograd=False,
+    ),
+    BinaryUfuncInfo(
+        "special.hermite_polynomial_he",
+        dtypes=all_types_and(torch.bool),
+        promotes_int_to_float=True,
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
+            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
+            DecorateInfo(
+                unittest.skip("testing takes an unreasonably long time, #79528"),
+                "TestCommon",
+                "test_compare_cpu",
+            ),
+        ),
+        supports_one_python_scalar=True,
+        supports_autograd=False,
+    ),
+    BinaryUfuncInfo(
+        "special.laguerre_polynomial_l",
+        dtypes=all_types_and(torch.bool),
+        promotes_int_to_float=True,
+        skips=(
+            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
+            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
+            DecorateInfo(
+                unittest.skip("testing takes an unreasonably long time, #79528"),
+                "TestCommon",
+                "test_compare_cpu",
+            ),
+        ),
+        supports_one_python_scalar=True,
+        supports_autograd=False,
+    ),
+    BinaryUfuncInfo(
+        "special.legendre_polynomial_p",
+        dtypes=all_types_and(torch.bool),
+        promotes_int_to_float=True,
+        skips=(
+            DecorateInfo(
+                unittest.skip(
+                    "Skipping - testing takes an unreasonably long time, #79528"
+                )
+            ),
+            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
+            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
+            DecorateInfo(
+                unittest.skip("testing takes an unreasonably long time, #79528"),
+                "TestCommon",
+                "test_compare_cpu",
+            ),
+        ),
+        supports_one_python_scalar=True,
+        supports_autograd=False,
+    ),
+    UnaryUfuncInfo(
+        "special.modified_bessel_i0",
+        decorators=(
+            precisionOverride(
+                {
+                    torch.float32: 1e-03,
+                    torch.float64: 1e-05,
+                },
+            ),
+        ),
+        dtypes=all_types_and(torch.bool),
+        ref=scipy.special.i0 if TEST_SCIPY else None,
+        supports_autograd=False,
+    ),
+    UnaryUfuncInfo(
+        "special.modified_bessel_i1",
+        decorators=(
+            precisionOverride(
+                {
+                    torch.float32: 1e-03,
+                    torch.float64: 1e-05,
+                },
+            ),
+        ),
+        dtypes=all_types_and(torch.bool),
+        ref=scipy.special.i1 if TEST_SCIPY else None,
+        supports_autograd=False,
+    ),
+    UnaryUfuncInfo(
+        "special.modified_bessel_k0",
+        decorators=(
+            precisionOverride(
+                {
+                    torch.float32: 1e-03,
+                    torch.float64: 1e-05,
+                },
+            ),
+        ),
+        dtypes=all_types_and(torch.bool),
+        ref=scipy.special.k0 if TEST_SCIPY else None,
+        supports_autograd=False,
+    ),
+    UnaryUfuncInfo(
+        "special.modified_bessel_k1",
+        decorators=(
+            precisionOverride(
+                {
+                    torch.float32: 1e-03,
+                    torch.float64: 1e-05,
+                },
+            ),
+        ),
+        dtypes=all_types_and(torch.bool),
+        ref=scipy.special.k1 if TEST_SCIPY else None,
+        supports_autograd=False,
+    ),
+    UnaryUfuncInfo(
+        "special.scaled_modified_bessel_k0",
+        decorators=(
+            toleranceOverride(
+                {
+                    torch.float32: tol(atol=1e-03, rtol=1e-03),
+                    torch.float64: tol(atol=1e-05, rtol=1e-03),
+                }
+            ),
+        ),
+        dtypes=all_types_and(torch.bool),
+        ref=scipy.special.k0e if TEST_SCIPY else None,
+        supports_autograd=False,
+    ),
+    UnaryUfuncInfo(
+        "special.scaled_modified_bessel_k1",
+        decorators=(
+            toleranceOverride(
+                {
+                    torch.float32: tol(atol=1e-03, rtol=1e-03),
+                    torch.float64: tol(atol=1e-05, rtol=1e-03),
+                }
+            ),
+        ),
+        dtypes=all_types_and(torch.bool),
+        ref=scipy.special.k1e if TEST_SCIPY else None,
+        supports_autograd=False,
+    ),
+    BinaryUfuncInfo(
+        "special.shifted_chebyshev_polynomial_t",
+        dtypes=all_types_and(torch.bool),
+        promotes_int_to_float=True,
+        skips=(
+            DecorateInfo(
+                unittest.skip(
+                    "Skipping - testing takes an unreasonably long time, #79528"
+                )
+            ),
+            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
+            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
+            DecorateInfo(
+                unittest.skip("testing takes an unreasonably long time, #79528"),
+                "TestCommon",
+                "test_compare_cpu",
+            ),
+        ),
+        supports_one_python_scalar=True,
+        supports_autograd=False,
+    ),
+    BinaryUfuncInfo(
+        "special.shifted_chebyshev_polynomial_u",
+        dtypes=all_types_and(torch.bool),
+        promotes_int_to_float=True,
+        skips=(
+            DecorateInfo(
+                unittest.skip(
+                    "Skipping - testing takes an unreasonably long time, #79528"
+                )
+            ),
+            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
+            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
+            DecorateInfo(
+                unittest.skip("testing takes an unreasonably long time, #79528"),
+                "TestCommon",
+                "test_compare_cpu",
+            ),
+        ),
+        supports_one_python_scalar=True,
+        supports_autograd=False,
+    ),
+    BinaryUfuncInfo(
+        "special.shifted_chebyshev_polynomial_v",
+        dtypes=all_types_and(torch.bool),
+        promotes_int_to_float=True,
+        skips=(
+            DecorateInfo(
+                unittest.skip(
+                    "Skipping - testing takes an unreasonably long time, #79528"
+                )
+            ),
+            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
+            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
+            DecorateInfo(
+                unittest.skip("testing takes an unreasonably long time, #79528"),
+                "TestCommon",
+                "test_compare_cpu",
+            ),
+        ),
+        supports_one_python_scalar=True,
+        supports_autograd=False,
+    ),
+    BinaryUfuncInfo(
+        "special.shifted_chebyshev_polynomial_w",
+        dtypes=all_types_and(torch.bool),
+        promotes_int_to_float=True,
+        skips=(
+            DecorateInfo(
+                unittest.skip(
+                    "Skipping - testing takes an unreasonably long time, #79528"
+                )
+            ),
+            DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
+            DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
+            DecorateInfo(
+                unittest.skip("testing takes an unreasonably long time, #79528"),
+                "TestCommon",
+                "test_compare_cpu",
+            ),
+        ),
+        supports_one_python_scalar=True,
+        supports_autograd=False,
+    ),
+    UnaryUfuncInfo(
+        "special.spherical_bessel_j0",
+        decorators=(
+            toleranceOverride(
+                {
+                    torch.float32: tol(atol=1e-03, rtol=1e-03),
+                    torch.float64: tol(atol=1e-05, rtol=1e-03),
+                }
+            ),
+        ),
+        dtypes=all_types_and(torch.bool),
+        ref=(lambda x: scipy.special.spherical_jn(0, x)) if TEST_SCIPY else None,
+        supports_autograd=False,
+    ),
+]
+
+python_ref_db: list[OpInfo] = [
+    #
+    # Elementwise Unary Special OpInfos
+    #
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.special.bessel_j0",
+        torch_opinfo_name="special.bessel_j0",
+        op_db=op_db,
+        decorators=(
+            precisionOverride(
+                {
+                    torch.float32: 1e-04,
+                    torch.float64: 1e-05,
+                },
+            ),
+        ),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.special.bessel_j1",
+        torch_opinfo_name="special.bessel_j1",
+        op_db=op_db,
+        decorators=(
+            precisionOverride(
+                {
+                    torch.float32: 1e-04,
+                    torch.float64: 1e-05,
+                },
+            ),
+        ),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.special.entr",
+        torch_opinfo_name="special.entr",
+        op_db=op_db,
+        decorators=(precisionOverride({torch.float16: 1e-1, torch.bfloat16: 1e-1}),),
+        skips=(
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestUnaryUfuncs",
+                "test_reference_numerics_large",
+                dtypes=[torch.bfloat16, torch.float16],
+            ),
+        ),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.special.erfcx",
+        torch_opinfo_name="special.erfcx",
+        op_db=op_db,
+        decorators=(
+            toleranceOverride(
+                {
+                    torch.float32: tol(atol=0, rtol=4e-6),
+                }
+            ),
+        ),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.special.i0e",
+        torch_opinfo_name="special.i0e",
+        op_db=op_db,
+        decorators=(precisionOverride({torch.bfloat16: 3e-1, torch.float16: 3e-1}),),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.special.i1",
+        torch_opinfo_name="special.i1",
+        op_db=op_db,
+        decorators=(
+            DecorateInfo(
+                toleranceOverride(
+                    {
+                        torch.float32: tol(atol=1e-4, rtol=0),
+                        torch.bool: tol(atol=1e-4, rtol=0),
+                    }
+                )
+            ),
+        ),
+        skips=(
+            DecorateInfo(
+                unittest.skip("Incorrect result!"),
+                "TestUnaryUfuncs",
+                "test_reference_numerics_large",
+                dtypes=(torch.int8,),
+            ),
+        ),
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.special.i1e",
+        torch_opinfo_name="special.i1e",
+        op_db=op_db,
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.special.log_ndtr",
+        torch_opinfo_name="special.log_ndtr",
+        op_db=op_db,
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.special.ndtr",
+        torch_opinfo_name="special.ndtr",
+        op_db=op_db,
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.special.ndtri",
+        torch_opinfo_name="special.ndtri",
+        op_db=op_db,
+    ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.special.spherical_bessel_j0",
+        torch_opinfo_name="special.spherical_bessel_j0",
+        op_db=op_db,
+        decorators=(
+            toleranceOverride(
+                {
+                    torch.float32: tol(atol=1e-03, rtol=1e-03),
+                    torch.float64: tol(atol=1e-05, rtol=1e-03),
+                }
+            ),
+        ),
+    ),
+    #
+    # Elementwise Binary Special OpInfos
+    #
+    ElementwiseBinaryPythonRefInfo(
+        "_refs.special.zeta",
+        torch_opinfo_name="special.zeta",
+        supports_one_python_scalar=True,
+        op_db=op_db,
+        skips=(
+            # Reference inputs produce nans and infs on CUDA, and nan, inf, 0., -inf on CPU
+            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"),
+        ),
+    ),
+]
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/refs.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/refs.py
new file mode 100644
index 0000000000000000000000000000000000000000..435a9d113164b3652af4d246655f579d1b72d4dc
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/refs.py
@@ -0,0 +1,207 @@
+# mypy: ignore-errors
+
+from torch.testing._internal.opinfo.core import (
+    BinaryUfuncInfo,
+    OpInfo,
+    ReductionOpInfo,
+    UnaryUfuncInfo,
+)
+
+
+# NOTE [Python References]
+# Python References emulate existing PyTorch operations, but can ultimately
+#   be expressed in terms of "primitive" operations from torch._prims.
+#
+# These references are experimental.
+# See https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-0/577
+#   for additional context.
+#
+# Python Reference OpInfos should be added to the python_ref_db list below.
+#   Tests can opt-into running on these references by including
+#   that list in the Sequence they pass to the @ops decorator.
+#
+# When a Python Reference OpInfo is constructed a pointer to an
+#   existing OpInfo must be provided using the torch_opinfo_name kwarg.
+#   The existing OpInfo with that name and no variant will be found
+#   to inherit from.
+#
+# Instead of just inheriting the existing OpInfo's metadata, the
+#   Python Reference OpInfos inherit the existing OpInfo's
+#   construction arguments. These arguments can be overridden
+#   by adding kwargs to the constructor.
+
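+# A minimal, illustrative sketch (not part of this module) of how such an
+# entry is typically declared and consumed; "_refs.abs"/"abs" and the test
+# name below are stand-ins:
+#
+#   python_ref_db = [
+#       ElementwiseUnaryPythonRefInfo(
+#           "_refs.abs",               # callable Python reference
+#           torch_opinfo_name="abs",   # existing OpInfo to inherit from
+#       ),
+#   ]
+#
+#   @ops(python_ref_db)
+#   def test_python_ref(self, device, dtype, op):
+#       ...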
+
+def _find_referenced_opinfo(referenced_name, variant_name, *, op_db=None):
+    """
+    Finds the OpInfo with the given name and variant name.
+    """
+    # NOTE: searching the global op_db doesn't work when OpInfos are split into
+    # different modules, as otherwise the op_db will not be fully constructed
+    # yet. So, instead the local op_db must be passed in explicitly.
+    if op_db is None:
+        from torch.testing._internal.common_methods_invocations import op_db
+
+    for opinfo in op_db:
+        if opinfo.name == referenced_name and opinfo.variant_test_name == variant_name:
+            return opinfo
+
+
+def _inherit_constructor_args(name, op, inherited, overrides):
+    # inherits metadata
+    common_kwargs = {
+        "name": name,
+        "op": op,
+        "aliases": None,  # TODO add a check for alias coverage
+        "method_variant": None,
+        "inplace_variant": None,  # TODO: add a check for inplace coverage
+        "supports_scripting": False,
+    }
+
+    # Acquires inherited kwargs
+    kwargs = inherited.copy()
+
+    # Fixes metadata
+    if "kwargs" in kwargs:
+        kwargs.update(kwargs["kwargs"])
+        del kwargs["kwargs"]
+    if "self" in kwargs:
+        del kwargs["self"]
+    if "__class__" in kwargs:
+        del kwargs["__class__"]
+    if "skips" in kwargs:
+        del kwargs["skips"]
+    if "decorators" in kwargs:
+        del kwargs["decorators"]
+
+    # Overrides metadata
+    kwargs.update(common_kwargs)
+    kwargs.update(overrides)
+
+    # At the moment no prims support autograd, so we must not run autograd
+    # tests e.g. when testing dtype support.  Once we start writing autograd
+    # formulas for prims this can be removed.
+    kwargs["supports_autograd"] = False
+    kwargs["supports_gradgrad"] = False
+    kwargs["supports_fwgrad_bwgrad"] = False
+    kwargs["supports_inplace_autograd"] = False
+    kwargs["supports_forward_ad"] = False
+
+    return kwargs
+
+
+class PythonRefInfo(OpInfo):
+    """
+    An OpInfo for a Python reference of an OpInfo base class operation.
+    """
+
+    def __init__(
+        self,
+        name,  # the string name of the callable Python reference
+        *,
+        op=None,  # the function variant of the operation, populated as torch.<name> if None
+        op_db=None,  # The database of opinfos to search for the parent opinfo
+        torch_opinfo_name,  # the string name of the corresponding torch opinfo
+        torch_opinfo_variant_name="",  # the variant name for corresponding torch opinfo
+        validate_view_consistency=True,
+        **kwargs,
+    ):  # additional kwargs override kwargs inherited from the torch opinfo
+        self.torch_opinfo_name = torch_opinfo_name
+        self.torch_opinfo_variant_name = torch_opinfo_variant_name
+        self.torch_opinfo = _find_referenced_opinfo(
+            torch_opinfo_name, torch_opinfo_variant_name, op_db=op_db
+        )
+        self.validate_view_consistency = validate_view_consistency
+        assert isinstance(self.torch_opinfo, OpInfo)
+
+        inherited = self.torch_opinfo._original_opinfo_args
+        ukwargs = _inherit_constructor_args(name, op, inherited, kwargs)
+        super().__init__(**ukwargs)
+
+
+class ReductionPythonRefInfo(ReductionOpInfo):
+    """
+    An OpInfo for a Python reference of a reduction operation.
+    """
+
+    def __init__(
+        self,
+        name,  # the string name of the callable Python reference
+        *,
+        op=None,  # the function variant of the operation, populated as torch.<name> if None
+        op_db=None,  # The database of opinfos to search for the parent opinfo
+        torch_opinfo_name,  # the string name of the corresponding torch opinfo
+        torch_opinfo_variant_name="",  # the variant name for corresponding torch opinfo
+        **kwargs,
+    ):  # additional kwargs override kwargs inherited from the torch opinfo
+        self.torch_opinfo_name = torch_opinfo_name
+        self.torch_opinfo_variant_name = torch_opinfo_variant_name
+        self.torch_opinfo = _find_referenced_opinfo(
+            torch_opinfo_name, torch_opinfo_variant_name, op_db=op_db
+        )
+        assert isinstance(self.torch_opinfo, ReductionOpInfo)
+
+        inherited = self.torch_opinfo._original_reduction_args
+        ukwargs = _inherit_constructor_args(name, op, inherited, kwargs)
+
+        # See https://github.com/pytorch/pytorch/issues/77216
+        self.validate_view_consistency = False
+
+        super().__init__(**ukwargs)
+
+
+class ElementwiseUnaryPythonRefInfo(UnaryUfuncInfo):
+    """
+    An OpInfo for a Python reference of an elementwise unary operation.
+    """
+
+    def __init__(
+        self,
+        name,  # the string name of the callable Python reference
+        *,
+        op=None,  # the function variant of the operation, populated as torch.<name> if None
+        op_db=None,  # The database of opinfos to search for the parent opinfo
+        torch_opinfo_name,  # the string name of the corresponding torch opinfo
+        torch_opinfo_variant_name="",  # the variant name for corresponding torch opinfo
+        validate_view_consistency=True,
+        **kwargs,
+    ):  # additional kwargs override kwargs inherited from the torch opinfo
+        self.torch_opinfo_name = torch_opinfo_name
+        self.torch_opinfo_variant_name = torch_opinfo_variant_name
+        self.torch_opinfo = _find_referenced_opinfo(
+            torch_opinfo_name, torch_opinfo_variant_name, op_db=op_db
+        )
+        self.validate_view_consistency = validate_view_consistency
+        assert isinstance(self.torch_opinfo, UnaryUfuncInfo)
+
+        inherited = self.torch_opinfo._original_unary_ufunc_args
+        ukwargs = _inherit_constructor_args(name, op, inherited, kwargs)
+
+        super().__init__(**ukwargs)
+
+
+class ElementwiseBinaryPythonRefInfo(BinaryUfuncInfo):
+    """
+    An OpInfo for a Python reference of an elementwise binary operation.
+    """
+
+    def __init__(
+        self,
+        name,  # the string name of the callable Python reference
+        *,
+        op=None,  # the function variant of the operation, populated as torch.<name> if None
+        op_db=None,  # The database of opinfos to search for the parent opinfo
+        torch_opinfo_name,  # the string name of the corresponding torch opinfo
+        torch_opinfo_variant_name="",  # the variant name for corresponding torch opinfo
+        **kwargs,
+    ):  # additional kwargs override kwargs inherited from the torch opinfo
+        self.torch_opinfo_name = torch_opinfo_name
+        self.torch_opinfo_variant_name = torch_opinfo_variant_name
+        self.torch_opinfo = _find_referenced_opinfo(
+            torch_opinfo_name, torch_opinfo_variant_name, op_db=op_db
+        )
+        assert isinstance(self.torch_opinfo, BinaryUfuncInfo)
+
+        inherited = self.torch_opinfo._original_binary_ufunc_args
+        ukwargs = _inherit_constructor_args(name, op, inherited, kwargs)
+
+        super().__init__(**ukwargs)
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/utils.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4000ec6ca135512a2be791d88767450b75951ddb
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/opinfo/utils.py
@@ -0,0 +1,274 @@
+# mypy: ignore-errors
+
+import collections
+import warnings
+from collections.abc import Sequence
+from functools import partial, wraps
+
+import numpy as np
+import numpy.typing as npt
+
+import torch
+from torch.testing._internal.common_cuda import TEST_CUDA
+from torch.testing._internal.common_dtype import (
+    _dispatch_dtypes,
+    all_types,
+    all_types_and,
+    all_types_and_complex,
+    all_types_and_complex_and,
+    all_types_and_half,
+    complex_types,
+    floating_and_complex_types,
+    floating_and_complex_types_and,
+    floating_types,
+    floating_types_and,
+    floating_types_and_half,
+    integral_types,
+    integral_types_and,
+)
+from torch.testing._internal.common_utils import torch_to_numpy_dtype_dict
+
+
+COMPLETE_DTYPES_DISPATCH = (
+    all_types,
+    all_types_and_complex,
+    all_types_and_half,
+    floating_types,
+    floating_and_complex_types,
+    floating_types_and_half,
+    integral_types,
+    complex_types,
+)
+
+EXTENSIBLE_DTYPE_DISPATCH = (
+    all_types_and_complex_and,
+    floating_types_and,
+    floating_and_complex_types_and,
+    integral_types_and,
+    all_types_and,
+)
+
+# Better way to acquire devices?
+DEVICES = ["cpu"] + (["cuda"] if TEST_CUDA else [])
+
+
+class _dynamic_dispatch_dtypes(_dispatch_dtypes):
+    # Class to tag the dynamically generated types.
+    pass
+
+
+def get_supported_dtypes(op, sample_inputs_fn, device_type):
+    # Returns the supported dtypes for the given operator and device_type pair.
+    assert device_type in ["cpu", "cuda"]
+    if not TEST_CUDA and device_type == "cuda":
+        warnings.warn(
+            "WARNING: CUDA is not available, empty_dtypes dispatch will be returned!"
+        )
+        return _dynamic_dispatch_dtypes(())
+
+    supported_dtypes = set()
+    for dtype in all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half):
+        try:
+            samples = sample_inputs_fn(op, device_type, dtype, False)
+        except RuntimeError:
+            # If `sample_inputs_fn` doesn't support sampling for a given
+            # `dtype`, we assume that the `dtype` is not supported.
+            # We raise a warning so that the user knows this was the case
+            # and can investigate whether there was an issue with the `sample_inputs_fn`.
+            warnings.warn(
+                f"WARNING: Unable to generate sample for device:{device_type} and dtype:{dtype}"
+            )
+            continue
+
+        # We assume the dtype is supported
+        # only if all samples pass for the given dtype.
+        supported = True
+        for sample in samples:
+            try:
+                op(sample.input, *sample.args, **sample.kwargs)
+            except RuntimeError:
+                # dtype is not supported
+                supported = False
+                break
+
+        if supported:
+            supported_dtypes.add(dtype)
+
+    return _dynamic_dispatch_dtypes(supported_dtypes)
+
+
+def dtypes_dispatch_hint(dtypes):
+    # Function returns the appropriate dispatch function (from COMPLETE_DTYPES_DISPATCH and EXTENSIBLE_DTYPE_DISPATCH)
+    # and its string representation for the passed `dtypes`.
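+    # Illustrative sketch of the intended behavior (assuming the standard
+    # dispatch helpers imported above): an exact match against a complete
+    # dispatch set returns that helper and its printable form, e.g.
+    #   dtypes_dispatch_hint(floating_types()).dispatch_fn_str  ->  "floating_types()"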
+    return_type = collections.namedtuple("return_type", "dispatch_fn dispatch_fn_str")
+
+    # CUDA is not available, dtypes will be empty.
+    if len(dtypes) == 0:
+        return return_type((), "()")
+
+    set_dtypes = set(dtypes)
+    for dispatch in COMPLETE_DTYPES_DISPATCH:
+        # Short circuit if we get an exact match.
+        if set(dispatch()) == set_dtypes:
+            return return_type(dispatch, dispatch.__name__ + "()")
+
+    chosen_dispatch = None
+    chosen_dispatch_score = 0.0
+    for dispatch in EXTENSIBLE_DTYPE_DISPATCH:
+        dispatch_dtypes = set(dispatch())
+        if not dispatch_dtypes.issubset(set_dtypes):
+            continue
+
+        score = len(dispatch_dtypes)
+        if score > chosen_dispatch_score:
+            chosen_dispatch_score = score
+            chosen_dispatch = dispatch
+
+    # The passed dtypes do not cover any of the extensible dispatch sets
+    # (not likely, but possible); fall back to listing the dtypes verbatim.
+    if chosen_dispatch is None:
+        return return_type((), str(dtypes))
+
+    # Extend the chosen dispatch with the dtypes it does not already cover
+    return return_type(
+        partial(chosen_dispatch, *tuple(set(dtypes) - set(chosen_dispatch()))),
+        chosen_dispatch.__name__ + str(tuple(set(dtypes) - set(chosen_dispatch()))),
+    )
+
+
+def is_dynamic_dtype_set(op):
+    # Detect if the OpInfo entry acquired dtypes dynamically
+    # using `get_supported_dtypes`.
+    return op.dynamic_dtypes
+
+
+def str_format_dynamic_dtype(op):
+    fmt_str = f"""
+        OpInfo({op.name},
+               dtypes={dtypes_dispatch_hint(op.dtypes).dispatch_fn_str},
+               dtypesIfCUDA={dtypes_dispatch_hint(op.dtypesIfCUDA).dispatch_fn_str},
+        )
+        """
+
+    return fmt_str
+
+
+def np_unary_ufunc_integer_promotion_wrapper(fn):
+    # Wrapper that casts integer inputs to PyTorch's
+    #   default scalar type before calling the wrapped
+    #   NumPy unary ufunc.
+    #   This mimics PyTorch's integer->floating point
+    #   type promotion.
+    #
+    # This is necessary when NumPy promotes
+    #   integer types to double, since PyTorch promotes
+    #   integer types to the default scalar type.
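+    #
+    # Illustrative usage (a sketch, not part of this module):
+    #   ref_sin = np_unary_ufunc_integer_promotion_wrapper(np.sin)
+    #   ref_sin(np.array([0, 1], dtype=np.int32))  # input is cast to the NumPy dtype
+    #                                              # matching torch.get_default_dtype()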
+
+    # Helper to determine if promotion is needed
+    def is_integral(dtype):
+        return dtype in [
+            np.bool_,
+            bool,
+            np.uint8,
+            np.int8,
+            np.int16,
+            np.int32,
+            np.int64,
+        ]
+
+    @wraps(fn)
+    def wrapped_fn(x):
+        # As the default dtype can change, acquire it when function is called.
+        # NOTE: Promotion in PyTorch is from integer types to the default dtype
+        np_dtype = torch_to_numpy_dtype_dict[torch.get_default_dtype()]
+
+        if is_integral(x.dtype):
+            return fn(x.astype(np_dtype))
+        return fn(x)
+
+    return wrapped_fn
+
+
+def reference_reduction_numpy(f, supports_keepdims=True):
+    """Wraps a NumPy reduction operator.
+
+    The wrapper function will forward dim, keepdim, mask, and identity
+    kwargs to the wrapped function as the NumPy equivalent axis,
+    keepdims, where, and initial kwargs, respectively.
+
+    Args:
+        f: NumPy reduction operator to wrap
+        supports_keepdims (bool, optional): Whether the NumPy operator accepts
+            keepdims parameter. If it does not, the wrapper will manually unsqueeze
+            the reduced dimensions if it was called with keepdim=True. Defaults to True.
+
+    Returns:
+        Wrapped function
+
+    """
+
+    @wraps(f)
+    def wrapper(x: npt.NDArray, *args, **kwargs):
+        # Copy keys into a set
+        keys = set(kwargs.keys())
+
+        dim = kwargs.pop("dim", None)
+        keepdim = kwargs.pop("keepdim", False)
+
+        if "dim" in keys:
+            dim = tuple(dim) if isinstance(dim, Sequence) else dim
+
+            # NumPy reductions don't accept dim=0 for scalar inputs
+            # so we convert it to None if and only if dim is equivalent
+            if x.ndim == 0 and dim in {0, -1, (0,), (-1,)}:
+                kwargs["axis"] = None
+            else:
+                kwargs["axis"] = dim
+
+        if "keepdim" in keys and supports_keepdims:
+            kwargs["keepdims"] = keepdim
+
+        if "mask" in keys:
+            mask = kwargs.pop("mask")
+            if mask is not None:
+                assert mask.layout == torch.strided
+                kwargs["where"] = mask.cpu().numpy()
+
+        if "identity" in keys:
+            identity = kwargs.pop("identity")
+            if identity is not None:
+                if identity.dtype is torch.bfloat16:
+                    identity = identity.cpu().to(torch.float32)
+                else:
+                    identity = identity.cpu()
+                kwargs["initial"] = identity.numpy()
+
+        result = f(x, *args, **kwargs)
+
+        # Unsqueeze reduced dimensions if NumPy does not support keepdims
+        if keepdim and not supports_keepdims and x.ndim > 0:
+            dim = list(range(x.ndim)) if dim is None else dim
+            result = np.expand_dims(result, dim)
+
+        return result
+
+    return wrapper
+
+
+def prod_numpy(a, *args, **kwargs):
+    """
+    The function calls np.prod with dtype np.int64 if the input type is a signed
+    integer, or np.uint64 if it is unsigned. This is necessary because on Windows
+    np.prod uses int32 by default, while on Linux it uses int64.
+    This fixes the integer overflow reported in https://github.com/pytorch/pytorch/issues/77320
+
+    Returns:
+        np.prod of input
+    """
+    if "dtype" not in kwargs:
+        if np.issubdtype(a.dtype, np.signedinteger):
+            a = a.astype(np.int64)
+        elif np.issubdtype(a.dtype, np.unsignedinteger):
+            a = a.astype(np.uint64)
+
+    fn = reference_reduction_numpy(np.prod)
+    return fn(a, *args, **kwargs)
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/__init__.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9125ba0ebe7e0623a12ad1a1cd7eeb7d2749a3a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/__init__.py
@@ -0,0 +1,7 @@
+# mypy: ignore-errors
+
+from .make_fx import make_fx_check
+from .aot_autograd import aot_autograd_check, _test_aot_autograd_forwards_backwards_helper
+from .fake_tensor import fake_check
+from .autograd_registration import autograd_registration_check
+from .generate_tests import generate_opcheck_tests, opcheck, OpCheckError, dontGenerateOpCheckTests, is_inside_opcheck_mode
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/aot_autograd.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/aot_autograd.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4508a570a00f7c3cb1faf3bc74a3471719962fd
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/aot_autograd.py
@@ -0,0 +1,155 @@
+# mypy: ignore-errors
+
+import torch
+import torch.utils._pytree as pytree
+from torch.testing._utils import wrapper_set_seed
+from functorch.compile import compiled_function, min_cut_rematerialization_partition, nop
+from .make_fx import randomize
+import re
+
+
+class assert_raises_regex:
+    def __init__(self, exception_cls, regex):
+        self.exception_cls = exception_cls
+        self.regex = regex
+
+    def __enter__(self):
+        pass
+
+    def __exit__(self, exc_type, exc_val, traceback):
+        if exc_type == self.exception_cls:
+            msg = str(exc_val)
+            if not re.search(self.regex, msg):
+                raise AssertionError(
+                    f"Expected exception to match regex. regex: {self.regex}, exception: {msg}")
+            return True  # Squashes the exception
+        if exc_type is not None:
+            raise AssertionError(
+                f"Expected {self.exception_cls} to be raised, instead got exception {exc_type}")
+        raise AssertionError("Expected exception to be raised but none was")
+
+
+def aot_autograd_check(
+        func,
+        args,
+        kwargs,
+        dynamic,
+        assert_raises_regex_fn=assert_raises_regex,
+        assert_equals_fn=torch.testing.assert_close,
+        check_gradients=True,
+        try_check_data_specialization=False,
+        skip_correctness_check=False):
+    """Compares func(*args, **kwargs) in eager-mode to under AOTAutograd.
+
+    Compares outputs and (if check_gradients=True) gradients produced by
+    AOTAutograd against eager-mode PyTorch.
+
+    We assume that func(*args, **kwargs) succeeds in eager-mode PyTorch.
+
+    """
+    flat_args, args_spec = pytree.tree_flatten((args, kwargs))
+    args = [arg for arg in flat_args if isinstance(arg, torch.Tensor)]
+
+    # We construct a new function that only accepts Tensors as inputs
+    def func_no_tensors(args):
+        reconstructed_flat_args = []
+        args = iter(args)
+        for v in flat_args:
+            if isinstance(v, torch.Tensor):
+                reconstructed_flat_args.append(next(args))
+            else:
+                reconstructed_flat_args.append(v)
+
+        c_args, c_kwargs = pytree.tree_unflatten(reconstructed_flat_args, args_spec)
+        return func(*c_args, **c_kwargs)
+
+    compiled_f = compiled_function(
+        func_no_tensors, nop, nop, dynamic=dynamic, partition_fn=min_cut_rematerialization_partition)
+
+    out = wrapper_set_seed(func_no_tensors, args)
+    if check_gradients == "auto":
+        any_tensor_requires_grad = pytree.tree_any_only(torch.Tensor, lambda x: x.requires_grad, args)
+        any_output_requires_grad = pytree.tree_any_only(torch.Tensor, lambda x: x.requires_grad, out)
+        check_gradients = any_tensor_requires_grad and any_output_requires_grad
+    if not check_gradients:
+        compiled_out = wrapper_set_seed(compiled_f, args)
+        if not skip_correctness_check:
+            assert_equals_fn(compiled_out, out, msg=outputs_msg)
+        return
+    _test_aot_autograd_forwards_backwards_helper(
+        func_no_tensors, compiled_f, args, assert_raises_regex_fn, assert_equals_fn,
+        try_check_data_specialization, skip_correctness_check)
+
+outputs_msg = (
+    "Outputs of the operator are different in eager-mode PyTorch vs "
+    "AOTDispatcher tracing. This means the operator will have incorrect output "
+    "underneath torch.compile. This could be because the operator's "
+    "implementation not traceable."
+)
+
+
+def _test_aot_autograd_forwards_backwards_helper(
+        f, compiled_f, args, assert_raises_regex_fn, assert_equals_fn,
+        try_check_data_specialization, skip_correctness_check=False):
+    # Verify grads are equal between compiled and non-compiled versions of f.
+
+    def call_forwards_backwards(f, args):
+        flat_args = pytree.arg_tree_leaves(*args)
+        diff_args = [arg for arg in flat_args if isinstance(arg, torch.Tensor) and
+                     arg.requires_grad]
+        out = wrapper_set_seed(f, args)
+        flat_out = pytree.tree_leaves(out)
+
+        sm = 0
+        for i in flat_out:
+            if isinstance(i, torch.Tensor):
+                # We need to call .abs() because it is possible that the output of the
+                # operator is a complex Tensor and autograd will yell at autograd.grad
+                # on a complex Tensor unless we manually provide the grad_output flag.
+                sm += i.sum().abs()
+        assert isinstance(sm, torch.Tensor)
+        return out, torch.autograd.grad(sm, diff_args, allow_unused=True)
+
+    def check(args, ignore_failure=False):
+        try:
+            orig_out, orig_grad = call_forwards_backwards(f, args)
+        except Exception:
+            if ignore_failure:
+                return
+            raise
+
+        # See https://github.com/pytorch/pytorch/pull/98960#issuecomment-1505962215
+        tensor_args = [x for x in pytree.tree_flatten(args)[0] if isinstance(x, torch.Tensor)]
+        any_non_leaves = any(x.grad_fn is not None for x in tensor_args)
+        if all(x is None for x in orig_grad) and any_non_leaves:
+            with assert_raises_regex_fn(RuntimeError, 'does not require grad and does not have a grad_fn'):
+                call_forwards_backwards(compiled_f, args)
+            return
+
+        msg = (
+            "Gradients of the operator are different in eager-mode PyTorch vs "
+            "AOTDispatcher. This means the operator will have incorrect gradients "
+            "underneath torch.compile. This could be because the operator's "
+            "backward is incorrectly registered or not traceable."
+        )
+
+        compiled_out, compiled_grad = call_forwards_backwards(compiled_f, args)
+        if not skip_correctness_check:
+            try:
+                assert_equals_fn(compiled_out, orig_out)
+            except Exception as e:
+                raise type(e)(outputs_msg) from e
+            try:
+                assert_equals_fn(compiled_grad, orig_grad)
+            except Exception as e:
+                raise type(e)(msg) from e
+
+    check(args, ignore_failure=False)
+
+    # Randomize the data and run the traced graph with it, to catch bugs
+    # where we may have baked in Tensor data into the trace.
+    # This is not guaranteed to succeed, because `f` might have preconditions
+    # on the values of the inputs, so we just ignore if this test fails.
+    if try_check_data_specialization:
+        args = randomize(args)
+        check(args, ignore_failure=True)
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/autograd_registration.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/autograd_registration.py
new file mode 100644
index 0000000000000000000000000000000000000000..25df4f1d03fcf50fc5869d8af93cd128f98e9c72
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/autograd_registration.py
@@ -0,0 +1,132 @@
+# mypy: ignore-errors
+
+import contextlib
+
+import torch
+import torch.utils._pytree as pytree
+
+
+@contextlib.contextmanager
+def set_autograd_fallback_mode(mode):
+    prev = torch._C._get_autograd_fallback_mode()
+    try:
+        torch._C._set_autograd_fallback_mode(mode)
+        yield
+    finally:
+        torch._C._set_autograd_fallback_mode(prev)
+
+
+def autograd_registration_check(op, args, kwargs):
+    """Check if autograd was registered correctly (for the operator).
+
+    Operators should have "autograd support" registered directly to an
+    autograd dispatch key.
+    An incorrect registration may lead to unexpected silent incorrectness.
+    Note that this check won't catch all problems but will catch
+    the most common ones.
+
+    Example usage:
+        >>> x = torch.randn(3, requires_grad=True)
+        >>> autograd_registration_check(torch.ops.aten.sin.default, (x,), {})
+
+    Here are some best practices if you do find your autograd is
+    registered incorrectly:
+    - If the operator is composite (i.e. consists of other PyTorch ops)
+      and you wish the operator to decompose and get autograd support
+      that way, then please register the implementation to
+      DispatchKey::CompositeImplicitAutograd
+    - If you're adding an autograd formula for the operator, the correct
+      thing to do is to register an autograd.Function to
+      DispatchKey::Autograd (preferred) or one of the
+      DispatchKey::Autograd keys. It is NOT OK to register
+      an autograd.Function to a backend (e.g. CPU/CUDA) key.
+    - If your operator is non-differentiable, then you should register
+      an implementation to the Autograd key that uses
+      AutoDispatchBelowAutograd and re-invokes the operator.
+
+    """
+    assert isinstance(op, torch._ops.OpOverload)
+    # Implementation details
+    # -----------------------------------------------
+    # If an operator doesn't have an autograd kernel at an autograd key,
+    # and the operator does not return inputs as-is, then all of
+    # the outputs should have requires_grad=False before we apply
+    # special behaviors of our default autograd fallback.
+    # (The default autograd fallback may set requires_grad=True on output
+    # tensors in certain modes so that when they are backpropped through,
+    # they raise an error).
+    #
+    # Our strategy for detecting if an operator doesn't have an autograd
+    # kernel at the autograd key is:
+    # - set the autograd fallback mode to "nothing" (so it does not change
+    #   the required-gradness of outputs)
+    # - run the operator
+    # - Check if any outputs of the operator (that are not inputs) require
+    #   grad. This would only happen if the user calls regular PyTorch
+    #   operations in their backend key (this op should instead be
+    #   CompositeImplicitAutograd or not an op) or if the user invokes
+    #   an autograd.Function in the backend key.
+    #
+    # Note that it's already likely a bug if the operator directly returns
+    # an input as output (because custom ops don't have a good way of
+    # constructing true in-place or out variants), but we defer that
+    # responsibility to a different test (schema_check).
+
+    flat_args = pytree.arg_tree_leaves(*args, **kwargs)
+    all_tensors = [arg for arg in flat_args if isinstance(arg, torch.Tensor)]
+    if not any(t.requires_grad for t in all_tensors):
+        raise RuntimeError(
+            "autograd_registration_check: no inputs have requires_grad=True so "
+            "we are unable to actually perform this test. Please pass inputs "
+            "that do require grad."
+        )
+
+    # Determine which Autograd backend key to check (AutogradCPU or AutogradCUDA)
+    all_device_types = {arg.device.type for arg in all_tensors}
+    if not all_device_types.issubset(["cpu", "cuda"]):
+        # Don't want to support other keys yet
+        raise NotImplementedError(
+            f"autograd_registration_check: NYI devices other than CPU/CUDA, got {all_device_types}"
+        )
+    if "cuda" in all_device_types:
+        key = "AutogradCUDA"
+    elif "cpu" in all_device_types:
+        key = "AutogradCPU"
+
+    if torch._C._dispatch_has_kernel_for_dispatch_key(op.name(), key):
+        return
+    if torch._C._dispatch_has_kernel_for_dispatch_key(op.name(), "Autograd"):
+        return
+    if torch._C._dispatch_has_kernel_for_dispatch_key(
+        op.name(), "CompositeImplicitAutograd"
+    ):
+        return
+
+    # At this point, we know the operator doesn't have a kernel registered to an
+    # autograd key. Let's proceed with our test.
+    with set_autograd_fallback_mode("nothing"):
+        all_outs = op(*args, **kwargs)
+
+    inp_ids = {id(arg) for arg in flat_args}
+
+    def not_an_input_and_requires_grad(tensor):
+        if not tensor.requires_grad:
+            return False
+        if id(tensor) in inp_ids:
+            return False
+        return True
+
+    if not pytree.tree_any_only(torch.Tensor, not_an_input_and_requires_grad, all_outs):
+        return
+
+    raise AssertionError(
+        f"{op.name()}: at least one output of this operator has requires_grad=True "
+        f"but the operator does not have an autograd kernel defined at an autograd "
+        f"key (e.g. DispatchKey::Autograd). This could mean that you have "
+        f"incorrectly registered an autograd kernel to a non-Autograd DispatchKey, "
+        f"which may lead to silently incorrect results. If your operator consists "
+        f"of regular PyTorch operations, consider not using an operator at all "
+        f"or registering your operator as CompositeImplicitAutograd. If you have "
+        f"an autograd.Function registered to a backend (CPU/CUDA) key, the correct "
+        f"location for it is the Autograd key."
+    )
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/fake_tensor.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/fake_tensor.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e60f50189b5dc3ab43fdd97120d5fa23559a84e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/fake_tensor.py
@@ -0,0 +1,12 @@
+# mypy: ignore-errors
+
+import torch._subclasses
+
+
+def is_builtin(op):
+    return op.namespace in ('aten', 'prims', 'prim')
+
+
+def fake_check(op, args, kwargs):
+    with torch._subclasses.CrossRefFakeMode(ignore_op_fn=is_builtin):
+        op(*args, **kwargs)
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/generate_tests.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/generate_tests.py
new file mode 100644
index 0000000000000000000000000000000000000000..51fcadd8dee970910fe7d4b5aef2b17fb0cab71d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/generate_tests.py
@@ -0,0 +1,852 @@
+# mypy: ignore-errors
+
+import datetime
+import difflib
+import functools
+import inspect
+import json
+import os
+import re
+import tempfile
+import threading
+import unittest
+from collections.abc import Sequence
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch._dynamo
+import torch.utils._pytree as pytree
+from torch._dynamo.utils import clone_input
+from torch._library.custom_ops import CustomOpDef
+from torch._subclasses.schema_check_mode import SchemaCheckMode
+from torch._utils_internal import get_file_path_2
+from torch.overrides import TorchFunctionMode
+from torch.testing._internal.optests import (
+    aot_autograd_check,
+    autograd_registration_check,
+    fake_check,
+)
+
+
+def dontGenerateOpCheckTests(reason: str):
+    def inner(fun):
+        fun._torch_dont_generate_opcheck_tests = True
+        return fun
+
+    return inner
+
+
+def is_abstract(tensor: torch.Tensor) -> bool:
+    if tensor.is_meta:
+        return True
+    if torch._subclasses.fake_tensor.is_fake(tensor):
+        return True
+    return False
+
+
+def safe_schema_check(
+    op: torch._ops.OpOverload,
+    args: tuple[Any, ...],
+    kwargs: dict[str, Any],
+    *,
+    copy_inputs: bool = True,
+    rtol: Optional[float] = None,
+    atol: Optional[float] = None,
+) -> Any:
+    if copy_inputs:
+        args, kwargs = deepcopy_tensors((args, kwargs))
+    if pytree.tree_any_only(torch.Tensor, is_abstract, (args, kwargs)):
+        return None
+    with SchemaCheckMode():
+        result = op(*args, **kwargs)
+        return result
+
+
+def safe_autograd_registration_check(
+    op: torch._ops.OpOverload,
+    args: tuple[Any, ...],
+    kwargs: dict[str, Any],
+    *,
+    copy_inputs: bool = True,
+    rtol: Optional[float] = None,
+    atol: Optional[float] = None,
+) -> None:
+    if pytree.tree_any_only(torch.Tensor, is_abstract, (args, kwargs)):
+        return
+    if copy_inputs:
+        args, kwargs = deepcopy_tensors((args, kwargs))
+    # Don't perform autograd_registration_check if none of the inputs require grad.
+    if not pytree.tree_any_only(
+        torch.Tensor, lambda x: x.requires_grad, (args, kwargs)
+    ):
+        return
+    return autograd_registration_check(op, args, kwargs)
+
+
+def safe_fake_check(
+    op: torch._ops.OpOverload,
+    args: tuple[Any, ...],
+    kwargs: dict[str, Any],
+    *,
+    copy_inputs: bool = True,
+    rtol: Optional[float] = None,
+    atol: Optional[float] = None,
+) -> None:
+    if pytree.tree_any_only(torch.Tensor, is_abstract, (args, kwargs)):
+        return None
+    if copy_inputs:
+        args, kwargs = deepcopy_tensors((args, kwargs))
+    return fake_check(op, args, kwargs)
+
+
+def safe_aot_autograd_check(
+    op: torch._ops.OpOverload,
+    args: tuple[Any, ...],
+    kwargs: dict[str, Any],
+    dynamic: bool,
+    *,
+    copy_inputs: bool = True,
+    rtol: Optional[float] = None,
+    atol: Optional[float] = None,
+) -> Any:
+    # NB: copy_inputs does nothing for aot_autograd_check: it always needs to copy
+    # inputs.
+    if pytree.tree_any_only(torch.Tensor, is_abstract, (args, kwargs)):
+        return None
+
+    def func(*args, **kwargs):
+        args, kwargs = pytree.tree_map_only(torch.Tensor, torch.clone, (args, kwargs))
+        return op(*args, **kwargs)
+
+    # aot_autograd_check runs func(*args, **kwargs) multiple times
+    # and assumes `func` does not modify its inputs.
+    if rtol and atol:
+        assert_equals_fn = functools.partial(
+            torch.testing.assert_close, rtol=rtol, atol=atol
+        )
+    else:
+        assert_equals_fn = torch.testing.assert_close
+    return aot_autograd_check(
+        func,
+        args,
+        kwargs,
+        dynamic,
+        check_gradients="auto",
+        assert_equals_fn=assert_equals_fn,
+    )
+
+
+def deepcopy_tensors(inputs: Any) -> Any:
+    return pytree.tree_map_only(torch.Tensor, clone_input, inputs)
+
+
+# Test util requirements
+# - The test util must have signature (op: OpOverload, args, kwargs)
+# - The test util must NOT mutate args, kwargs.
+# - The test utils in this list must not be prefixes of each other. For example,
+#   having both "test_schema" and "test_schema_is_functional" is NOT OK.
+# - The order of items in this dict matters (for opcheck), we'll run them
+#   in order.
+ALL_TEST_UTILS = {
+    "test_schema": safe_schema_check,
+    "test_autograd_registration": safe_autograd_registration_check,
+    "test_faketensor": safe_fake_check,
+    "test_aot_dispatch_static": functools.partial(
+        safe_aot_autograd_check,
+        dynamic=False,
+    ),
+    "test_aot_dispatch_dynamic": functools.partial(
+        safe_aot_autograd_check,
+        dynamic=True,
+    ),
+}
+
+GDOC = "https://docs.google.com/document/d/1Pj5HRZvdOq3xpFpbEjUZp2hBovhy7Wnxw14m6lF2154/edit"
+
+DEFAULT_TEST_UTILS = [
+    "test_schema",
+    "test_autograd_registration",
+    "test_faketensor",
+    "test_aot_dispatch_dynamic",
+]
+
+DEPRECATED_DEFAULT_TEST_UTILS = DEFAULT_TEST_UTILS + [
+    "test_aot_dispatch_static",
+]
+
+
+def generate_opcheck_tests(
+    testcase: Any,
+    namespaces: list[str],
+    failures_dict_path: Optional[str] = None,
+    additional_decorators: Optional[dict[str, Callable]] = None,
+    test_utils: list[str] = DEFAULT_TEST_UTILS,
+) -> None:
+    """Given an existing TestCase, use the existing tests to generate
+    additional validation tests for custom operators.
+
+    For {all existing tests in the TestCase} x {all test utils},
+    we will generate one new test. The new test runs a TorchFunctionMode
+    that intercepts ``op(*args, **kwargs)`` calls and invokes
+    ``test_util(op, *args, **kwargs)``, where ``op`` is an operator.
+
+    The test_utils that we support are in ALL_TEST_UTILS. They are:
+    - test_schema: This runs SchemaCheckMode.
+    - test_autograd_registration: This runs autograd_registration_check.
+    - test_faketensor: This runs CrossRefFakeMode.
+    - test_aot_dispatch_static: This runs aot_autograd_check, which:
+        checks that the outputs (and gradients, if they are computable)
+        are the same under eager-mode PyTorch and using AOTAutograd.
+    - test_aot_dispatch_dynamic: Same as aot_dispatch_static, but
+        runs AOTAutograd using dynamic shapes instead of static shapes.
+
+    The generated test will have name ``{test_util}__{original_name}``.
+    For example, if there is a method named ``test_cumsum``, then
+    we will generate a ``test_schema__test_cumsum``,
+    ``test_faketensor__test_cumsum``, etc.
+
+    For more details, see https://docs.google.com/document/d/1Pj5HRZvdOq3xpFpbEjUZp2hBovhy7Wnxw14m6lF2154/edit
+
+    Args:
+        testcase: The testcase we will modify and generate additional tests for.
+        namespaces: We will only intercept calls to custom operators with these
+                    namespaces.
+        failures_dict_path: See ``validate_failures_dict_structure`` for more details
+        test_utils: a list of test_utils to generate. Example: ["test_schema", "test_faketensor"]
+    """
+    if additional_decorators is None:
+        additional_decorators = {}
+    test_methods = [
+        m
+        for m in dir(testcase)
+        if m.startswith("test_") and callable(getattr(testcase, m))
+    ]
+    if failures_dict_path is None:
+        # The default failures_dict_path is failures_dict.json in
+        # the same directory as the test file.
+        prev_frame = inspect.currentframe().f_back
+        filename = inspect.getframeinfo(prev_frame)[0]
+        failures_dict_path = get_file_path_2(
+            os.path.dirname(filename), "failures_dict.json"
+        )
+    failures_dict = FailuresDict.load(
+        failures_dict_path, create_file=should_update_failures_dict()
+    )
+    validate_failures_dict_structure(failures_dict, test_utils, testcase)
+    validate_failures_dict_formatting(failures_dict_path)
+
+    def construct_method(attr, prefix, tester):
+        method = getattr(testcase, attr)
+        if getattr(method, "_torch_dont_generate_opcheck_tests", False):
+            return
+        new_method_name = prefix + "__" + attr
+
+        @functools.wraps(method)
+        def new_method(*args, **kwargs):
+            with OpCheckMode(
+                namespaces,
+                prefix,
+                tester,
+                failures_dict,
+                f"{testcase.__name__}.{new_method_name}",
+                failures_dict_path,
+            ):
+                result = method(*args, **kwargs)
+            return result
+
+        if pytestmark := new_method.__dict__.get("pytestmark"):
+            import pytest
+
+            # check if we need to simplify the parametrize marks
+            # NB: you need to add this mark to your pytest.ini
+            opcheck_only_one = False
+            for mark in pytestmark:
+                if isinstance(mark, pytest.Mark) and mark.name == "opcheck_only_one":
+                    opcheck_only_one = True
+
+            if opcheck_only_one:
+                new_pytestmark = []
+                for mark in pytestmark:
+                    if isinstance(mark, pytest.Mark) and mark.name == "parametrize":
+                        argnames, argvalues = mark.args
+                        assert not mark.kwargs, "NYI"
+                        # Special case for device, we want to run on all
+                        # devices
+                        if argnames != "device":
+                            new_pytestmark.append(
+                                pytest.mark.parametrize(
+                                    argnames, (next(iter(argvalues)),)
+                                )
+                            )
+                            continue
+                    new_pytestmark.append(mark)
+                new_method.__dict__["pytestmark"] = new_pytestmark
+
+        if new_method_name in additional_decorators:
+            for dec in additional_decorators[new_method_name]:
+                new_method = dec(new_method)
+
+        if hasattr(testcase, new_method_name):
+            raise RuntimeError(
+                f"Tried to autogenerate {new_method_name} but {testcase} already "
+                f"has method named {new_method_name}. Please rename the original "
+                f"method on the TestCase."
+            )
+        setattr(testcase, new_method_name, new_method)
+
+    test_utils = {name: ALL_TEST_UTILS[name] for name in test_utils}
+    for attr in test_methods:
+        for prefix, tester in test_utils.items():
+            construct_method(attr, prefix, tester)
+
+    generate_tag_tests(testcase, failures_dict, additional_decorators)
+
+
+def generate_tag_tests(testcase, failures_dict, additional_decorators):
+    def generate_test(qualname, definitely_not_pt2_compliant, xfailed_tests):
+        def inner(self):
+            try:
+                op = torch._library.utils.lookup_op(qualname)
+            except AttributeError as e:
+                # Operator not importable in this test file
+                raise unittest.SkipTest(f"Can't import operator {qualname}") from e
+            op_marked_as_compliant = torch.Tag.pt2_compliant_tag in op.tags
+            if not op_marked_as_compliant:
+                return
+            if not definitely_not_pt2_compliant:
+                return
+            raise AssertionError(
+                f"op '{qualname}' was tagged with torch.Tag.pt2_compliant_tag "
+                f"but it failed some of the generated opcheck tests "
+                f"({xfailed_tests}). This may lead to silent correctness issues, "
+                f"please fix this."
+            )
+
+        return inner
+
+    for qualname, test_dict in failures_dict.data.items():
+        xfailed_tests = [
+            test
+            for test, status_dict in test_dict.items()
+            # We're about to delete the following test after Ed's PR
+            # to specialize on C++ .size() calls
+            if "test_aot_dispatch_static" not in test
+            and status_dict["status"] == "xfail"
+        ]
+        definitely_not_pt2_compliant = len(xfailed_tests) > 0
+        generated = generate_test(qualname, definitely_not_pt2_compliant, xfailed_tests)
+
+        # Could result in collisions, but unlikely. We'll raise if we see one below.
+        mangled_qualname = qualname.replace("::", "_").replace(".", "_")
+        test_name = "test_pt2_compliant_tag_" + mangled_qualname
+
+        # You can skip this test via the additional_decorators argument
+        # in generate_opcheck_tests
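+        # For example (hypothetical test name):
+        #   additional_decorators={"test_pt2_compliant_tag_mylib_foo": [unittest.skip("not ready")]}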
+        if test_name in additional_decorators:
+            for decorator in additional_decorators[test_name]:
+                generated = decorator(generated)
+
+        if hasattr(testcase, test_name):
+            raise RuntimeError(
+                f"Tried to generate a test named {test_name}, but it exists "
+                f"already. This could be because of a name collision (where "
+                f"we generated two tests with the same name), or where we "
+                f"generated a test with the same name as an existing test."
+            )
+        setattr(testcase, test_name, generated)
+
+
+TEST_OPTIONS = ("xfail", "skip", "xsuccess")
+
+
+def validate_failures_dict_formatting(failures_dict_path: str) -> None:
+    with open(failures_dict_path) as fp:
+        actual = fp.read()
+    failures_dict = FailuresDict.load(failures_dict_path)
+    expected = failures_dict._save(to_str=True)
+    if actual == expected:
+        return
+    if should_update_failures_dict():
+        failures_dict = FailuresDict.load(failures_dict_path)
+        failures_dict.save()
+        return
+    expected = expected.splitlines(keepends=True)
+    actual = actual.splitlines(keepends=True)
+    diff = difflib.unified_diff(actual, expected)
+    diff = "".join(diff)
+    raise RuntimeError(
+        f"\n{diff}\n\nExpected the failures dict to be formatted "
+        f"a certain way. Please see the above diff; you can correct "
+        f"this either manually or by re-running the test with "
+        f"PYTORCH_OPCHECK_ACCEPT=1"
+    )
+
+
+def validate_failures_dict_structure(
+    failure_dict: "FailuresDict", test_utils: list[str], testcase: Any
+) -> None:
+    """Validates the failures dict.
+
+    The failure dict looks something like the following.
+    It maps operator name (qualname) to a list of autogenerated tests.
+    Each autogenerated test may have a check for the operator (if the operator is
+    called by the test); the dictionary specifies if we should skip the check,
+    or if we expect some check to fail.
+
+    {
+        "fbgemm::split_lengths": {
+            "test_schema__test_split_lengths": {
+                "comment": "you can put whatever you want into the comment section",
+                "status": "xfail",
+            },
+            "test_schema__test_split_lengths_empty": {
+                "comment": "",
+                "status": "skip",
+            },
+        },
+        "fbgemm::gather_lengths": {
+            "test_schema__test_gather_lengths": {
+                "comment": "",
+                "status": "skip",
+            },
+        },
+    }
+
+    """
+    failure_dict = failure_dict.data
+    for test_to_option in failure_dict.values():
+        for test_name, test_dict in test_to_option.items():
+            if set(test_dict.keys()) != {"comment", "status"}:
+                raise RuntimeError(
+                    "in failures_dict, expected sub-dict to have keys 'comment' and 'status'"
+                )
+            test_option = test_dict["status"]
+            if test_option not in TEST_OPTIONS:
+                raise RuntimeError(
+                    f"In failures_dict, got status={test_option} but it needs to be in {TEST_OPTIONS}"
+                )
+            test_class, actual_test_name = test_name.split(".")
+            if not any(actual_test_name.startswith(test) for test in test_utils):
+                raise RuntimeError(
+                    f"In failures_dict, test name '{test_name}' should begin with one of {test_utils}"
+                )
+            for test in test_utils:
+                if not actual_test_name.startswith(test):
+                    continue
+                base_test_name = actual_test_name[len(test) + 2 :]
+                # remove potential pytest parametrization suffix
+                base_test_name = re.sub(r"\[.*\]", "", base_test_name)
+                if testcase.__name__ != test_class:
+                    continue
+                if hasattr(testcase, base_test_name):
+                    continue
+                raise RuntimeError(
+                    f"In failures dict, got test name '{test_name}'. We parsed this as "
+                    f"running test '{test}' on '{base_test_name}', but "
+                    f"{base_test_name} does not exist on the TestCase '{testcase.__name__}]. "
+                    f"Maybe you need to change the test name?"
+                )
+
+
+def should_update_failures_dict() -> bool:
+    key = "PYTORCH_OPCHECK_ACCEPT"
+    return key in os.environ and os.environ[key] == "1"
+
+
+_is_inside_opcheck_mode = threading.local()
+_is_inside_opcheck_mode.value = False
+
+
+def is_inside_opcheck_mode():
+    return _is_inside_opcheck_mode.value
+
+
+class OpCheckMode(TorchFunctionMode):
+    """
+    For a given test, OpCheckMode intercepts calls to operators and runs
+    test_util(op, args, kwargs) for each intercepted (op, args, kwargs).
+    """
+
+    def __init__(
+        self,
+        namespaces: list[str],
+        test_util_name: str,
+        test_util: Callable,
+        failures_dict: "FailuresDict",
+        test_name: str,
+        failures_dict_path: str,
+    ):
+        # We will intercept calls to ops with these namespaces
+        self.namespaces = namespaces
+        # The test utility function. Its signature should be (op, args, kwargs) -> None.
+        # Examples of test utilities are: schema_check, make_fx_check
+        self.test_util = test_util
+        self.test_util_name = test_util_name
+        # The name of the test that is running this OpCheckMode.
+        self.test_name = test_name
+        # Maps qualname -> test_name -> skip/xfail
+        # Tells us if we should skip a test or assert that there is a failure.
+        self.failures_dict = failures_dict
+        # Location of the failures dict. Makes it so that the error message is better.
+        self.failures_dict_path = failures_dict_path
+
+        # OpCheckMode suppresses errors, collects them here, and then raises them on exit.
+        # Maps qualname -> List[(Exception, func, maybe args, maybe kwargs)]
+        self.seen_ops_to_errors = {}
+
+    def maybe_raise_errors_on_exit(self) -> None:
+        # Check expected failures first
+        for qualname in self.seen_ops_to_errors.keys():
+            option = self.failures_dict.get_status(qualname, self.test_name)
+            if len(self.seen_ops_to_errors[qualname]) == 0:
+                if should_update_failures_dict():
+                    self.failures_dict.set_status(
+                        qualname, self.test_name, "xsuccess", comment=""
+                    )
+                else:
+                    if option == "xfail":
+                        raise OpCheckError(
+                            f"generate_opcheck_tests: Unexpected success for operator "
+                            f"{qualname} on test {self.test_name}. This may mean that "
+                            f"you have fixed this test failure. Please rerun the test with "
+                            f"PYTORCH_OPCHECK_ACCEPT=1 to automatically update the test runner "
+                            f"or manually remove the "
+                            f"expected failure in the failure dict at "
+                            f"{self.failures_dict_path}"
+                            f"For more details, see "
+                            f"{GDOC}"
+                        )
+                continue
+        failed_ops = []
+        for qualname in self.seen_ops_to_errors.keys():
+            option = self.failures_dict.get_status(qualname, self.test_name)
+            if option != "xsuccess":
+                continue
+            if len(self.seen_ops_to_errors[qualname]) == 0:
+                continue
+            failed_ops.append(qualname)
+        if not failed_ops:
+            return
+
+        if should_update_failures_dict():
+            for op in failed_ops:
+                self.failures_dict.set_status(op, self.test_name, "xfail")
+            return
+
+        # Raise from the first error, but also report all of them to make
+        # recording xfails easier.
+        ex, op, args, kwargs = self.seen_ops_to_errors[failed_ops[0]][0]
+        repro_command = generate_repro(
+            self.test_util_name, op, args, kwargs, save_data=should_print_better_repro()
+        )
+        raise OpCheckError(
+            f"Test generated by `generate_opcheck_tests`, {self.test_name}, "
+            f"failed on operators {failed_ops}. This usually means that the "
+            f"operators are not implemented correctly and may lead to silently "
+            f"incorrect behavior. Set PYTORCH_OPCHECK_PRINT_BETTER_REPRO=1 for a standalone repro, "
+            f"or please see "
+            f"{GDOC} "
+            f"for more recommendations. "
+            f"To reproduce this problem locally, try to run the following:\n{repro_command}"
+        ) from ex
+
+    def __enter__(self, *args, **kwargs):
+        self.prev_is_opcheck_mode = _is_inside_opcheck_mode.value
+        self.prev_dynamo_disable = os.environ.get("TORCHDYNAMO_DISABLE", "")
+        _is_inside_opcheck_mode.value = True
+        os.environ["TORCHDYNAMO_DISABLE"] = "1"
+        return super().__enter__(*args, **kwargs)
+
+    def __exit__(self, *args, **kwargs):
+        _is_inside_opcheck_mode.value = self.prev_is_opcheck_mode
+        os.environ["TORCHDYNAMO_DISABLE"] = self.prev_dynamo_disable
+        try:
+            self.maybe_raise_errors_on_exit()
+            if should_update_failures_dict():
+                self.failures_dict.save()
+        finally:
+            result = super().__exit__(*args, **kwargs)
+        return result
+
+    def run_test_util(self, op, args, kwargs):
+        try:
+            self.test_util(op, args, kwargs, copy_inputs=False)
+        except torch._subclasses.fake_tensor.UnsupportedFakeTensorException:
+            # We might get here if the input is already a FakeTensor
+            # or if we're in a torch.compile block. Just ignore these
+            # since we can't handle them and reporting them as failures
+            # is too noisy.
+            pass
+
+    def __torch_function__(self, func, types, args=(), kwargs=None):
+        kwargs = kwargs if kwargs else {}
+
+        # Only intercept calls to operators
+        if not isinstance(func, (torch._ops.OpOverloadPacket, torch._ops.OpOverload)):
+            return func(*args, **kwargs)
+        if (
+            torch.jit.is_tracing()
+            or torch.jit.is_scripting()
+            or torch._dynamo.is_compiling()
+        ):
+            return func(*args, **kwargs)
+        # Pre-existing code may not use the .default overload. If we see an
+        # OpOverloadPacket and we cannot resolve the overload, then we just throw
+        # and ask the user to clarify. Otherwise, we attempt to resolve the overload.
+        if isinstance(func, torch._ops.OpOverloadPacket):
+            func = resolve_unique_overload_or_throw(func)
+        qualname = func.name()
+        ns = qualname.split("::")[0]
+        if ns not in self.namespaces:
+            return func(*args, **kwargs)
+
+        args_c, kwargs_c = deepcopy_tensors((args, kwargs))
+        result = func(*args, **kwargs)
+
+        option = self.failures_dict.get_status(qualname, self.test_name)
+        if option == "xsuccess" or option == "xfail":
+            # Suppress all errors during execution. Raise them during __exit__.
+            try:
+                if qualname not in self.seen_ops_to_errors:
+                    self.seen_ops_to_errors[qualname] = []
+                self.run_test_util(func, args_c, kwargs_c)
+            except Exception as ex:
+                if should_print_better_repro():
+                    self.seen_ops_to_errors[qualname].append((ex, func, args, kwargs))
+                else:
+                    self.seen_ops_to_errors[qualname].append((ex, func, None, None))
+        elif option == "skip":
+            pass
+        return result
+
+
+def should_print_better_repro() -> bool:
+    """If set, the tests generated by `generate_opcheck_tests` will print a
+    repro command on failure.
+
+    In order to print the repro command, we need to save some tensors to disk.
+    These will be saved under the following directory:
+    {tempfile.gettempdir()}/pytorch_opcheck_safe_to_delete/.
+
+    Although this is a temp folder, it will usually not automatically get cleaned
+    up, so you'll need to manually delete it.
+    """
+    key = "PYTORCH_OPCHECK_PRINT_BETTER_REPRO"
+    if key not in os.environ:
+        return False
+    value = os.environ[key]
+    return value == "1" or value == 1
+
+
+def opcheck(
+    op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket, CustomOpDef],
+    args: tuple[Any, ...],
+    kwargs: Optional[dict[str, Any]] = None,
+    *,
+    test_utils: Union[str, Sequence[str]] = DEFAULT_TEST_UTILS,
+    raise_exception: bool = True,
+    rtol: Optional[float] = None,
+    atol: Optional[float] = None,
+) -> dict[str, str]:
+    """See torch.library.opcheck for docstring"""
+
+    if (rtol is None) ^ (atol is None):
+        raise ValueError(
+            "opcheck(op, ...): if you specify one of rtol/atol, you must specify both"
+        )
+
+    if kwargs is None:
+        kwargs = {}
+    if isinstance(op, CustomOpDef):
+        op = op._opoverload
+    if isinstance(op, torch._ops.OpOverloadPacket):
+        op = resolve_unique_overload_or_throw(op)
+    if not isinstance(op, torch._ops.OpOverload):
+        raise ValueError(
+            f"opcheck(op, ...): op must be instance of torch._ops.OpOverload, "
+            f"e.g. torch.ops.aten.sin.default, got {type(op)}"
+        )
+    if test_utils == "ALL":
+        test_utils = tuple(ALL_TEST_UTILS.keys())
+    if isinstance(test_utils, str):
+        test_utils = (test_utils,)
+    if not isinstance(test_utils, (tuple, list)) or not set(test_utils).issubset(
+        ALL_TEST_UTILS.keys()
+    ):
+        raise ValueError(
+            f"opcheck(op, ..., test_utils={test_utils}), expected test_utils "
+            f"to be subset of {tuple(ALL_TEST_UTILS.keys())} but it was not"
+        )
+
+    results_dict = {}
+    for test_util in test_utils:
+        tester = ALL_TEST_UTILS[test_util]
+        try:
+            tester(op, args, kwargs, rtol=rtol, atol=atol)
+            results_dict[test_util] = "SUCCESS"
+        except Exception as ex:
+            if raise_exception:
+                raise OpCheckError(
+                    f"opcheck(op, ...): {test_util} failed with {ex} "
+                    f"(scroll up for stack trace)"
+                ) from ex
+            results_dict[test_util] = ex
+    return results_dict
+
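+# A hedged usage sketch for opcheck (torch.ops.aten.sin.default is only a
+# stand-in; in practice you would pass the custom operator under test):
+#
+#     from torch.testing._internal.optests import opcheck
+#     x = torch.randn(3, requires_grad=True)
+#     results = opcheck(torch.ops.aten.sin.default, (x,), test_utils="test_schema")
+#     # results maps each requested test_util to "SUCCESS", or to the raised
+#     # exception when raise_exception=False.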
+
+class OpCheckError(Exception):
+    pass
+
+
+def generate_repro(
+    test: str,
+    op: torch._ops.OpOverload,
+    args: tuple[Any, ...],
+    kwargs: dict[str, Any],
+    *,
+    save_data: bool,
+    dry_run: bool = False,
+) -> str:
+    if save_data:
+        now = datetime.datetime.now()
+        path = os.path.join(tempfile.gettempdir(), "pytorch_opcheck_safe_to_delete")
+        unix_timestamp = datetime.datetime.timestamp(now) * 100000
+        filepath = os.path.join(path, f"repro_{unix_timestamp}.pt")
+        if not dry_run:
+            os.makedirs(path, exist_ok=True)
+            torch.save((args, kwargs), filepath)
+        args_kwargs = f'args, kwargs = torch.load("{filepath}")'
+    else:
+        args_kwargs = (
+            "# If you rerun your test with PYTORCH_OPCHECK_PRINT_BETTER_REPRO=1\n"
+            "# we will fill them in same (args, kwargs) as in your test\n"
+            "args = ()  # args to the operator\n"
+            "kwargs = {}  # kwargs to the operator"
+        )
+
+    ns, name = op._schema.name.split("::")
+    overload = op._overloadname
+
+    repro_command = (
+        f"# =========================================================\n"
+        f"# BEGIN REPRO SCRIPT\n"
+        f"# =========================================================\n"
+        f"import torch\n"
+        f"from torch.testing._internal.optests import opcheck\n"
+        f"\n"
+        f"# Make sure you have loaded the library that contains the op\n"
+        f"# via an import or torch.ops.load_library(...)\n"
+        f"op = torch.ops.{ns}.{name}.{overload}\n"
+        f"\n"
+        f"{args_kwargs}\n"
+        f'opcheck(op, args, kwargs, test_utils="{test}")\n'
+        f"# =========================================================\n"
+        f"# END REPRO SCRIPT\n"
+        f"# =========================================================\n"
+    )
+    return repro_command
+
+
+def resolve_unique_overload_or_throw(
+    op: torch._ops.OpOverloadPacket,
+) -> torch._ops.OpOverload:
+    all_schemas = torch._C._jit_get_schemas_for_operator(op._qualified_op_name)
+    if len(all_schemas) != 1:
+        raise RuntimeError(
+            f"opcheck can only test operators without overloads. "
+            f"Got the following overloads for {op._qualified_op_name}: "
+            f"{[schema.overload_name for schema in all_schemas]}"
+        )
+
+    overload_name = all_schemas[0].overload_name
+    if overload_name == "":
+        return op.default
+    return getattr(op, overload_name)
+
+
+DUMP_OPTIONS = {"indent": 2, "sort_keys": True}
+
+
+FailuresDictData = dict[str, dict[str, dict[str, str]]]
+
+
+VERSION = 1
+DESCRIPTION = (
+    f"This is a dict containing failures for tests autogenerated by "
+    f"generate_opcheck_tests. "
+    f"For more details, please see {GDOC}"
+)
+
+
+class FailuresDict:
+    def __init__(self, path: str, data: FailuresDictData):
+        self.path = path
+        self.data = data
+
+    @staticmethod
+    def load(path, *, create_file=False) -> "FailuresDict":
+        if create_file and not os.path.exists(path):
+            result = FailuresDict(path, {})
+            result.save()
+            return result
+        with open(path) as fp:
+            contents = fp.read()
+            if contents.strip() == "":
+                dct = {
+                    "_description": DESCRIPTION,
+                    "data": {},
+                    "_version": VERSION,
+                }
+            else:
+                dct = json.loads(contents)
+                assert "data" in dct
+                assert "_version" in dct and dct["_version"] == VERSION
+        return FailuresDict(path, dct["data"])
+
+    def _save(self, to_str=False) -> Optional[str]:
+        to_dump = {
+            "_description": DESCRIPTION,
+            "data": self.data,
+            "_version": VERSION,
+        }
+        # json.dumps doesn't end with a newline. Let's add one because files
+        # should end in newlines.
+        serialized = json.dumps(to_dump, **DUMP_OPTIONS) + "\n"
+        if to_str:
+            return serialized
+        with open(self.path, "w") as fp:
+            fp.write(serialized)
+        return None
+
+    def save(self) -> None:
+        return self._save()
+
+    def get_status(self, qualname: str, test_name: str) -> str:
+        if qualname not in self.data:
+            return "xsuccess"
+        dct = self.data[qualname]
+        if test_name not in dct:
+            return "xsuccess"
+        return dct[test_name]["status"]
+
+    def set_status(
+        self,
+        qualname: str,
+        test_name: str,
+        status: str,
+        *,
+        comment: Optional[str] = None,
+    ):
+        if qualname not in self.data:
+            self.data[qualname] = {}
+        dct = self.data[qualname]
+        if test_name not in dct:
+            dct[test_name] = {"status": None, "comment": ""}
+
+        if status == "xsuccess":
+            # The default status is "xsuccess".
+            del dct[test_name]
+        else:
+            dct[test_name]["status"] = status
+            if comment is not None:
+                dct[test_name]["comment"] = comment
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/make_fx.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/make_fx.py
new file mode 100644
index 0000000000000000000000000000000000000000..83cefd18bc059a9667c4d224fab360d6d41cfe34
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/optests/make_fx.py
@@ -0,0 +1,89 @@
+# mypy: ignore-errors
+
+import torch
+from torch.fx.experimental.proxy_tensor import make_fx
+from torch.testing._utils import wrapper_set_seed
+import torch.utils._pytree as pytree
+
+
+def make_fx_check(
+    func,
+    args,
+    kwargs,
+    tracing_mode,
+    assert_close=torch.testing.assert_close,
+    randomize_data=False,
+):
+    f, *new_args = handle_sizes_for_dynamic_shapes(func, args, kwargs)
+
+    def run(f, *args, **kwargs):
+        return wrapper_set_seed(f, *args, **kwargs)
+
+    traced_f = make_fx(f, tracing_mode=tracing_mode)(*new_args)
+
+    msg = (
+        "op(*args, **kwargs) and make_fx(op)(*args, **kwargs) produced different "
+        "values. This could mean that your abstract impls (meta/FakeTensor impls) "
+        "are incorrect, that your operator is not completely traceable (e.g., "
+        "it relies on some global state), or that there is a bug in make_fx. "
+        "Note that if you passed a python function (and not an operator) to "
+        "make_fx_check, it is still possible that the python function will still "
+        "work with torch.compile because it handles capturing pieces of "
+        "your python code to compile."
+    )
+
+    # Randomize the data and run the traced graph with it, to catch bugs
+    # where we may have baked in Tensor data into the trace.
+    # This is not guaranteed to succeed, because `f` might have preconditions
+    # on the values of the inputs, so we just ignore if we used
+    # random data and it fails.
+    if randomize_data:
+        new_args = randomize(new_args)
+    try:
+        expected = run(f, *new_args)
+    except Exception:
+        if randomize_data:
+            return
+        raise
+    result = run(traced_f, *new_args)
+    assert_close(result, expected, msg=msg)
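+
+# A hedged example of calling make_fx_check directly (torch.sin is only a
+# stand-in target; any traceable callable works the same way):
+#
+#     make_fx_check(torch.sin, (torch.randn(3),), {}, tracing_mode="symbolic")
+#
+# A mismatch between the eager and traced outputs raises via assert_close with
+# the explanatory message constructed above.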
+
+
+# Arguably we should make make_fx promote torch.Size() objects to symbolic shapes.
+# Absent that, here is our strategy:
+#
+# If any argument is a torch.Size(), maybe get dynamic shapes for it by:
+# - Create a temporary Tensor whose size is the torch.Size() we want. Note that
+#   we use an expanded Tensor as we cannot pass "meta" Tensors to make_fx.
+# - Pass it to make_fx such that it is converted to a proxy Tensor
+# - Unpack the size in the wrapper to get a torch.Size with dynamic shapes (in
+#   symbolic mode, a no-op otherwise)
+def handle_sizes_for_dynamic_shapes(func, args, kwargs):
+    def f(args, kwargs, extra_args, extra_kwargs):
+        if extra_args:
+            for i, t in extra_args:
+                args[i] = t.size()
+        if extra_kwargs:
+            for k, t in extra_kwargs.items():
+                kwargs[k] = t.size()
+
+        return func(*args, **kwargs)
+
+    extra_args = []
+    extra_kwargs = {}
+    for i, arg in enumerate(args):
+        if isinstance(arg, torch.Size):
+            extra_args.append((i, torch.empty(arg, device="cpu")))
+    for key, value in kwargs.items():
+        if isinstance(value, torch.Size):
+            extra_kwargs[key] = torch.empty(value, device="cpu")
+
+    return f, args, kwargs, extra_args, extra_kwargs
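+
+# A hedged, illustrative walk-through (hypothetical values): if the second
+# positional argument is torch.Size([2, 3]), then extra_args records
+# (1, torch.empty(2, 3, device="cpu")); make_fx traces the wrapper f with that
+# temporary tensor, and inside f the argument is restored via t.size(), which
+# yields a torch.Size of symbolic ints under tracing_mode="symbolic" and a
+# plain torch.Size otherwise.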
+
+
+def randomize(args):
+    def transform(x):
+        if not x.dtype.is_floating_point:
+            return x
+        return x.detach().clone().uniform_(0, 1).requires_grad_(x.requires_grad)
+    return pytree.tree_map_only(torch.Tensor, transform, args)
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/quantization_torch_package_models.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/quantization_torch_package_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..abc4ab6f7e4734361ec7ecea3d4755910f9cf2ab
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/quantization_torch_package_models.py
@@ -0,0 +1,33 @@
+# mypy: ignore-errors
+
+import math
+
+import torch
+import torch.nn as nn
+
+
+class LinearReluFunctionalChild(nn.Module):
+    def __init__(self, N):
+        super().__init__()
+        self.w1 = nn.Parameter(torch.empty(N, N))
+        self.b1 = nn.Parameter(torch.zeros(N))
+        torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5))
+
+    def forward(self, x):
+        x = torch.nn.functional.linear(x, self.w1, self.b1)
+        x = torch.nn.functional.relu(x)
+        return x
+
+class LinearReluFunctional(nn.Module):
+    def __init__(self, N):
+        super().__init__()
+        self.child = LinearReluFunctionalChild(N)
+        self.w1 = nn.Parameter(torch.empty(N, N))
+        self.b1 = nn.Parameter(torch.zeros(N))
+        torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5))
+
+    def forward(self, x):
+        x = self.child(x)
+        x = torch.nn.functional.linear(x, self.w1, self.b1)
+        x = torch.nn.functional.relu(x)
+        return x
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/static_module.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/static_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a031b0d8f6e685517b7ac51c236e23835501cd9
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/static_module.py
@@ -0,0 +1,27 @@
+# mypy: allow-untyped-defs
+# Owner(s): ["module: unknown"]
+
+import torch
+
+
+class StaticModule:
+    def __init__(self, scripted):
+        # this is an nn.Module
+        if hasattr(scripted, "_c"):
+            self.static_module = torch._C._jit_to_static_module(scripted._c)
+        else:
+            self.static_module = torch._C._jit_to_static_module(scripted.graph)
+
+    def __call__(self, *args, **kwargs):
+        return self.static_module(*args, **kwargs)
+
+    def benchmark(self, args, kwargs, warmup_runs, main_runs):
+        self.static_module.benchmark(args, kwargs, warmup_runs, main_runs)
+
+    def runAsync(self, args, kwargs):
+        return self.static_module.runAsync(args, kwargs)
+
+    def benchmark_individual_ops(self, args, kwargs, warmup_runs, main_runs):
+        return self.static_module.benchmark_individual_ops(
+            args, kwargs, warmup_runs, main_runs
+        )
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/subclasses.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/subclasses.py
new file mode 100644
index 0000000000000000000000000000000000000000..0898c288d926f8bf9fa7e4187f27bb3339183b5b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/subclasses.py
@@ -0,0 +1,78 @@
+# mypy: ignore-errors
+from typing import Any, Optional
+
+import torch
+import torch.utils._pytree as pytree
+from torch._subclasses.fake_tensor import is_fake
+from torch.testing._internal.two_tensor import TwoTensor
+from torch.utils._python_dispatch import return_and_correct_aliasing
+
+
+class WrapperSubclass(torch.Tensor):
+    @staticmethod
+    def __new__(cls, a, outer_size=None, outer_stride=None):
+        if outer_size is None:
+            outer_size = a.size()
+        if outer_stride is None:
+            outer_stride = a.stride()
+
+        kwargs = {}
+        kwargs["strides"] = outer_stride
+        kwargs["storage_offset"] = a.storage_offset()
+        kwargs["device"] = a.device
+        kwargs["layout"] = a.layout
+        kwargs["requires_grad"] = a.requires_grad
+        kwargs["dtype"] = a.dtype
+        out = torch.Tensor._make_wrapper_subclass(cls, outer_size, **kwargs)
+
+        return out
+
+    def __init__(self, a, outer_size=None, outer_stride=None):
+        self.a = a
+
+    def __repr__(self):
+        return f"WrapperSubclass({repr(self.a)})"
+
+    def __tensor_flatten__(self):
+        return ["a"], None
+
+    @staticmethod
+    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
+        assert meta is None
+        a = inner_tensors["a"]
+        if is_fake(a):
+            assert outer_size is not None
+            assert outer_stride is not None
+        return WrapperSubclass(a, outer_size, outer_stride)
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        if kwargs is None:
+            kwargs = {}
+        args_a = pytree.tree_map_only(WrapperSubclass, lambda x: x.a, args)
+
+        kwargs_a = pytree.tree_map_only(WrapperSubclass, lambda x: x.a, kwargs)
+
+        out_a = func(*args_a, **kwargs_a)
+        out_a_flat, spec = pytree.tree_flatten(out_a)
+        out_flat = [
+            WrapperSubclass(o_a) if isinstance(o_a, torch.Tensor) else o_a
+            for o_a in out_a_flat
+        ]
+        out = pytree.tree_unflatten(out_flat, spec)
+        from torch._higher_order_ops.cond import cond_op
+
+        if func is cond_op:
+            return out
+        else:
+            return return_and_correct_aliasing(func, args, kwargs, out)
+
+    def __coerce_same_metadata_as_tangent__(
+        self, expected_metadata: Any, expected_type: Optional[type] = None
+    ):
+        if expected_type == type(self.a):
+            return self.a
+        elif expected_type is TwoTensor:
+            return TwoTensor(self.a, self.a.clone())
+
+        return None
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/test_module/__init__.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/test_module/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/test_module/future_div.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/test_module/future_div.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a3494f945fad36d84cb8056dcf722d6911f0af2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/test_module/future_div.py
@@ -0,0 +1,10 @@
+# mypy: ignore-errors
+
+
+
+def div_int_future():
+    return 1 / 2
+
+
+def div_float_future():
+    return 3.14 / 0.125
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/test_module/no_future_div.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/test_module/no_future_div.py
new file mode 100644
index 0000000000000000000000000000000000000000..164e6d168414a11039f3b63885760ad08b81ae99
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/test_module/no_future_div.py
@@ -0,0 +1,11 @@
+# mypy: ignore-errors
+
+import torch  # noqa: F401
+
+
+def div_int_nofuture():
+    return 1 / 2
+
+
+def div_float_nofuture():
+    return 3.14 / 0.125
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/torchbind_impls.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/torchbind_impls.py
new file mode 100644
index 0000000000000000000000000000000000000000..8fae730df5cb9e21c18fd1a7bb69a3d899fccd80
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/torchbind_impls.py
@@ -0,0 +1,180 @@
+# mypy: allow-untyped-defs
+import contextlib
+from pathlib import Path
+from typing import Optional
+
+import torch
+
+
+_TORCHBIND_IMPLS_INITIALIZED = False
+
+_TENSOR_QUEUE_GLOBAL_TEST: Optional[torch.ScriptObject] = None
+
+
+def init_torchbind_implementations():
+    global _TORCHBIND_IMPLS_INITIALIZED
+    global _TENSOR_QUEUE_GLOBAL_TEST
+    if _TORCHBIND_IMPLS_INITIALIZED:
+        return
+
+    load_torchbind_test_lib()
+    register_fake_operators()
+    register_fake_classes()
+    _TENSOR_QUEUE_GLOBAL_TEST = _empty_tensor_queue()
+    _TORCHBIND_IMPLS_INITIALIZED = True
+
+
+def _empty_tensor_queue() -> torch.ScriptObject:
+    return torch.classes._TorchScriptTesting._TensorQueue(
+        torch.empty(
+            0,
+        ).fill_(-1)
+    )
+
+
+# put these under a function because the corresponding library might not be loaded yet.
+def register_fake_operators():
+    @torch.library.register_fake("_TorchScriptTesting::takes_foo_python_meta")
+    def fake_takes_foo(foo, z):
+        return foo.add_tensor(z)
+
+    @torch.library.register_fake("_TorchScriptTesting::queue_pop")
+    def fake_queue_pop(tq):
+        return tq.pop()
+
+    @torch.library.register_fake("_TorchScriptTesting::queue_push")
+    def fake_queue_push(tq, x):
+        return tq.push(x)
+
+    @torch.library.register_fake("_TorchScriptTesting::queue_size")
+    def fake_queue_size(tq):
+        return tq.size()
+
+    def meta_takes_foo_list_return(foo, x):
+        a = foo.add_tensor(x)
+        b = foo.add_tensor(a)
+        c = foo.add_tensor(b)
+        return [a, b, c]
+
+    def meta_takes_foo_tuple_return(foo, x):
+        a = foo.add_tensor(x)
+        b = foo.add_tensor(a)
+        return (a, b)
+
+    @torch.library.register_fake("_TorchScriptTesting::takes_foo_tensor_return")
+    def meta_takes_foo_tensor_return(foo, x):
+        # This implementation deliberately creates unbacked symint for testing
+        ctx = torch.library.get_ctx()
+        fake_shape = [ctx.new_dynamic_size() for _ in range(2)]
+        return torch.empty(fake_shape, dtype=torch.int, device="cpu")
+
+    torch.ops._TorchScriptTesting.takes_foo_list_return.default.py_impl(
+        torch._C.DispatchKey.Meta
+    )(meta_takes_foo_list_return)
+
+    torch.ops._TorchScriptTesting.takes_foo_tuple_return.default.py_impl(
+        torch._C.DispatchKey.Meta
+    )(meta_takes_foo_tuple_return)
+
+    torch.ops._TorchScriptTesting.takes_foo.default.py_impl(torch._C.DispatchKey.Meta)(
+        # make signature match original cpp implementation to support kwargs
+        lambda foo, x: foo.add_tensor(x)
+    )
+
+
+def register_fake_classes():
+    # noqa: F841
+    @torch._library.register_fake_class("_TorchScriptTesting::_Foo")
+    class FakeFoo:
+        def __init__(self, x: int, y: int):
+            self.x = x
+            self.y = y
+
+        @classmethod
+        def __obj_unflatten__(cls, flattened_foo):
+            return cls(**dict(flattened_foo))
+
+        def add_tensor(self, z):
+            return (self.x + self.y) * z
+
+    @torch._library.register_fake_class("_TorchScriptTesting::_ContainsTensor")
+    class FakeContainsTensor:
+        def __init__(self, t: torch.Tensor):
+            self.t = t
+
+        @classmethod
+        def __obj_unflatten__(cls, flattened_foo):
+            return cls(**dict(flattened_foo))
+
+        def get(self):
+            return self.t
+
+    @torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue")
+    class FakeTensorQueue:
+        def __init__(self, queue):
+            self.queue = queue
+
+        @classmethod
+        def __obj_unflatten__(cls, flattened_ctx):
+            return cls(**dict(flattened_ctx))
+
+        def push(self, x):
+            self.queue.append(x)
+
+        def pop(self):
+            if self.is_empty():
+                return torch.empty([])
+            return self.queue.pop(0)
+
+        def size(self):
+            return len(self.queue)
+
+        def is_empty(self):
+            return len(self.queue) == 0
+
+        def float_size(self):
+            return float(len(self.queue))
+
+    @torch._library.register_fake_class("_TorchScriptTesting::_FlattenWithTensorOp")
+    class FakeFlatten:
+        def __init__(self, t):
+            self.t = t
+
+        def get(self):
+            return self.t
+
+        @classmethod
+        def __obj_unflatten__(cls, flattened_ctx):
+            return cls(**dict(flattened_ctx))
+
+
+def load_torchbind_test_lib():
+    import unittest
+
+    from torch.testing._internal.common_utils import (  # type: ignore[attr-defined]
+        find_library_location,
+        IS_FBCODE,
+        IS_MACOS,
+        IS_SANDCASTLE,
+        IS_WINDOWS,
+    )
+
+    if IS_MACOS:
+        raise unittest.SkipTest("non-portable load_library call used in test")
+    elif IS_SANDCASTLE or IS_FBCODE:
+        lib_file_path = Path("//caffe2/test/cpp/jit:test_custom_class_registrations")
+    elif IS_WINDOWS:
+        lib_file_path = find_library_location("torchbind_test.dll")
+    else:
+        lib_file_path = find_library_location("libtorchbind_test.so")
+    torch.ops.load_library(str(lib_file_path))
+
+
+@contextlib.contextmanager
+def _register_py_impl_temporarily(op_overload, key, fn):
+    try:
+        op_overload.py_impl(key)(fn)
+        yield
+    finally:
+        del op_overload.py_kernels[key]
+        op_overload._dispatch_cache.clear()
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/triton_utils.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/triton_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..79aca02b63d40e790674e793f3e0aa55b84de035
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/triton_utils.py
@@ -0,0 +1,956 @@
+# mypy: ignore-errors
+
+import unittest
+
+from torch.testing._internal.inductor_utils import HAS_CUDA, HAS_GPU
+from torch.utils._triton import has_triton
+
+
+requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
+requires_gpu = unittest.skipUnless(HAS_GPU, "requires gpu")
+
+if has_triton():
+    import triton
+    from triton import language as tl
+
+    # Define here so that multiple tests can take advantage of it
+    @triton.jit
+    def add_kernel(
+        in_ptr0,
+        in_ptr1,
+        out_ptr,
+        n_elements,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        y = tl.load(in_ptr1 + offsets, mask=mask)
+        output = x + y
+        tl.store(out_ptr + offsets, output, mask=mask)
+
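+    # A hedged launch sketch (not part of the test suite): for 1-D CUDA tensors
+    # x, y, out of length n,
+    #     grid = (triton.cdiv(n, 1024),)
+    #     add_kernel[grid](x, y, out, n, BLOCK_SIZE=1024)
+    # computes out = x + y elementwise; the other kernels below follow the same
+    # launch pattern.
+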
+    @triton.jit
+    def sub_kernel(
+        in_ptr0,
+        in_ptr1,
+        out_ptr,
+        n_elements,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        y = tl.load(in_ptr1 + offsets, mask=mask)
+        output = x - y
+        tl.store(out_ptr + offsets, output, mask=mask)
+
+    @triton.jit
+    def add_kernel_with_optional_param(
+        in_ptr0,
+        in_ptr1,
+        out_ptr,
+        n_elements,
+        ARGS_PASSED: "tl.constexpr",
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        if ARGS_PASSED == "two":
+            y = tl.load(in_ptr1 + offsets, mask=mask)
+            output = x + y
+        else:
+            output = x
+        tl.store(out_ptr + offsets, output, mask=mask)
+
+    @triton.jit
+    def add_kernel_with_none_param_and_equal_to_1_arg(
+        in_ptr0,
+        in_ptr1,  # in_ptr1 could be None
+        out_ptr,
+        n_elements,
+        stride,
+        ARGS_PASSED: "tl.constexpr",
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets * stride, mask=mask)
+        if ARGS_PASSED == "two":
+            y = tl.load(in_ptr1 + offsets, mask=mask)
+            output = x + y
+        else:
+            output = x
+        tl.store(out_ptr + offsets * stride, output, mask=mask)
+
+    @triton.autotune(
+        configs=[
+            triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8),
+            triton.Config({"BLOCK_SIZE": 128}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_SIZE": 64}, num_stages=3, num_warps=8),
+            triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4),
+        ],
+        key=[],
+    )
+    @triton.jit
+    def add_kernel_autotuned(
+        in_ptr0,
+        in_ptr1,
+        out_ptr,
+        n_elements,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        y = tl.load(in_ptr1 + offsets, mask=mask)
+        output = x + y
+        tl.store(out_ptr + offsets, output, mask=mask)
+
+    @triton.autotune(
+        configs=[
+            triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8),
+            triton.Config({"BLOCK_SIZE": 128}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_SIZE": 64}, num_stages=3, num_warps=8),
+            triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4),
+        ],
+        key=[],
+    )
+    @triton.jit
+    def sub_kernel_autotuned(
+        in_ptr0,
+        in_ptr1,
+        out_ptr,
+        n_elements,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        y = tl.load(in_ptr1 + offsets, mask=mask)
+        output = x - y
+        tl.store(out_ptr + offsets, output, mask=mask)
+
+    @triton.autotune(
+        configs=[
+            triton.Config({"BLOCK_SIZE": 16}, num_stages=2, num_warps=2),
+        ],
+        key=[],
+    )
+    @triton.jit
+    def add_kernel_autotuned_weird_param_order(
+        in_ptr0,
+        in_ptr1,
+        n_elements,
+        BLOCK_SIZE: "tl.constexpr",
+        out_ptr,
+    ):
+        # out_ptr is after an autotuned param that's declared as tl.constexpr.
+        # This param ordering can create bugs if not handled correctly.
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        y = tl.load(in_ptr1 + offsets, mask=mask)
+        output = x + y
+        tl.store(out_ptr + offsets, output, mask=mask)
+
+    @triton.autotune(
+        configs=[
+            triton.Config(
+                {"BLOCK_SIZE_X": 128, "BLOCK_SIZE_Y": 128}, num_stages=3, num_warps=8
+            ),
+            triton.Config(
+                {"BLOCK_SIZE_X": 128, "BLOCK_SIZE_Y": 128}, num_stages=4, num_warps=4
+            ),
+            triton.Config(
+                {"BLOCK_SIZE_X": 64, "BLOCK_SIZE_Y": 64}, num_stages=3, num_warps=8
+            ),
+            triton.Config(
+                {"BLOCK_SIZE_X": 64, "BLOCK_SIZE_Y": 64}, num_stages=4, num_warps=4
+            ),
+        ],
+        key=[],
+    )
+    @triton.jit
+    def add_kernel_2d_autotuned(
+        in_ptr0,
+        in_ptr1,
+        out_ptr,
+        x_elements,
+        y_elements,
+        BLOCK_SIZE_X: "tl.constexpr",
+        BLOCK_SIZE_Y: "tl.constexpr",
+    ):
+        xoffset = tl.program_id(0) * BLOCK_SIZE_X
+        xindex = xoffset + tl.arange(0, BLOCK_SIZE_X)[:, None]
+        xmask = xindex < x_elements
+        yoffset = tl.program_id(1) * BLOCK_SIZE_Y
+        yindex = yoffset + tl.arange(0, BLOCK_SIZE_Y)[None, :]
+        ymask = yindex < y_elements
+        x1 = xindex
+        y0 = yindex
+        tmp0 = tl.load(in_ptr0 + (x1 + (x_elements * y0)), xmask & ymask)
+        tmp1 = tl.load(in_ptr0 + (y0 + (y_elements * x1)), xmask & ymask)
+        tmp2 = tmp0 + tmp1
+        tl.store(out_ptr + (x1 + (x_elements * y0)), tmp2, xmask & ymask)
+
+    def _dummy_early_config_prune(configs, *_, **__):
+        return configs
+
+    @triton.autotune(
+        configs=[
+            triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8),
+            triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4),
+        ],
+        key=[],
+        warmup=10,
+        rep=20,
+        prune_configs_by={"early_config_prune": _dummy_early_config_prune},
+    )
+    @triton.jit
+    def add_kernel_autotuned_with_unsupported_args(
+        in_ptr0,
+        in_ptr1,
+        out_ptr,
+        n_elements,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        y = tl.load(in_ptr1 + offsets, mask=mask)
+        output = x + y
+        tl.store(out_ptr + offsets, output, mask=mask)
+
+    @triton.jit
+    def add_kernel_with_scaling(
+        in_ptr0,
+        in_ptr1,
+        out_ptr,
+        n_elements,
+        scaling_factor,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        y = tl.load(in_ptr1 + offsets, mask=mask)
+        output = (x + y) * scaling_factor
+        tl.store(out_ptr + offsets, output, mask=mask)
+
+    @triton.jit
+    def add_kernel_with_tma_1d_old_api(
+        in_desc_ptr0,
+        in_desc_ptr1,
+        out_desc_ptr,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        offset = pid * BLOCK_SIZE
+
+        a = tl._experimental_descriptor_load(
+            in_desc_ptr0,
+            [offset],
+            [BLOCK_SIZE],
+            tl.float32,
+        )
+        b = tl._experimental_descriptor_load(
+            in_desc_ptr1,
+            [offset],
+            [BLOCK_SIZE],
+            tl.float32,
+        )
+
+        output = a + b
+
+        tl._experimental_descriptor_store(
+            out_desc_ptr,
+            output,
+            [offset],
+        )
+
+    @triton.jit
+    def add_kernel_with_tma_2d_old_api(
+        in_desc_ptr0,
+        in_desc_ptr1,
+        out_desc_ptr,
+        BLOCK_SIZE_X: "tl.constexpr",
+        BLOCK_SIZE_Y: "tl.constexpr",
+    ):
+        pid_x = tl.program_id(axis=0)
+        pid_y = tl.program_id(axis=1)
+        offset_x = pid_x * BLOCK_SIZE_X
+        offset_y = pid_y * BLOCK_SIZE_Y
+
+        x = tl._experimental_descriptor_load(
+            in_desc_ptr0,
+            [offset_x, offset_y],
+            [BLOCK_SIZE_X, BLOCK_SIZE_Y],
+            tl.float32,
+        )
+        y = tl._experimental_descriptor_load(
+            in_desc_ptr1,
+            [offset_x, offset_y],
+            [BLOCK_SIZE_X, BLOCK_SIZE_Y],
+            tl.float32,
+        )
+
+        output = x + y
+
+        tl._experimental_descriptor_store(
+            out_desc_ptr,
+            output,
+            [offset_x, offset_y],
+        )
+
+    @triton.jit
+    def add_kernel_with_tma_1d_new_api(
+        in_desc_ptr0,
+        in_desc_ptr1,
+        out_desc_ptr,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        offset = pid * BLOCK_SIZE
+
+        a = tl.load_tensor_descriptor(
+            in_desc_ptr0,
+            [offset],
+        )
+        b = tl.load_tensor_descriptor(
+            in_desc_ptr1,
+            [offset],
+        )
+
+        output = a + b
+
+        tl.store_tensor_descriptor(
+            out_desc_ptr,
+            [offset],
+            output,
+        )
+
+    @triton.jit
+    def add_kernel_with_tma_2d_new_api(
+        in_desc_ptr0,
+        in_desc_ptr1,
+        out_desc_ptr,
+        BLOCK_SIZE_X: "tl.constexpr",
+        BLOCK_SIZE_Y: "tl.constexpr",
+    ):
+        pid_x = tl.program_id(axis=0)
+        pid_y = tl.program_id(axis=1)
+        offset_x = pid_x * BLOCK_SIZE_X
+        offset_y = pid_y * BLOCK_SIZE_Y
+
+        x = tl.load_tensor_descriptor(
+            in_desc_ptr0,
+            [offset_x, offset_y],
+        )
+        y = tl.load_tensor_descriptor(
+            in_desc_ptr1,
+            [offset_x, offset_y],
+        )
+
+        output = x + y
+
+        tl.store_tensor_descriptor(
+            out_desc_ptr,
+            [offset_x, offset_y],
+            output,
+        )
+
+    @triton.jit
+    def add_kernel_on_device_tma_old_api(
+        a_ptr,
+        b_ptr,
+        c_ptr,
+        m,
+        n,
+        workspace,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        a_desc_ptr = workspace
+        b_desc_ptr = workspace + 128
+        c_desc_ptr = workspace + 256
+        tl.extra.cuda.experimental_device_tensormap_create2d(
+            desc_ptr=a_desc_ptr,
+            global_address=a_ptr,
+            load_size=[BLOCK_SIZE, BLOCK_SIZE],
+            global_size=[m, n],
+            element_ty=a_ptr.dtype.element_ty,
+        )
+        tl.extra.cuda.experimental_device_tensormap_create2d(
+            desc_ptr=b_desc_ptr,
+            global_address=b_ptr,
+            load_size=[BLOCK_SIZE, BLOCK_SIZE],
+            global_size=[m, n],
+            element_ty=b_ptr.dtype.element_ty,
+        )
+        tl.extra.cuda.experimental_device_tensormap_create2d(
+            desc_ptr=c_desc_ptr,
+            global_address=c_ptr,
+            load_size=[BLOCK_SIZE, BLOCK_SIZE],
+            global_size=[m, n],
+            element_ty=c_ptr.dtype.element_ty,
+        )
+
+        tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
+        tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
+        tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
+
+        pid_x = tl.program_id(axis=0)
+        pid_y = tl.program_id(axis=1)
+        offset_x = pid_x * BLOCK_SIZE
+        offset_y = pid_y * BLOCK_SIZE
+
+        # Load data using the tensor descriptors
+        a = tl._experimental_descriptor_load(
+            a_desc_ptr,
+            [offset_x, offset_y],
+            [BLOCK_SIZE, BLOCK_SIZE],
+            tl.float32,
+        )
+        b = tl._experimental_descriptor_load(
+            b_desc_ptr,
+            [offset_x, offset_y],
+            [BLOCK_SIZE, BLOCK_SIZE],
+            tl.float32,
+        )
+
+        # Perform addition
+        output = a + b
+
+        # Store the result
+        tl._experimental_descriptor_store(
+            c_desc_ptr,
+            output,
+            [offset_x, offset_y],
+        )
+
+    @triton.jit
+    def add_kernel_on_device_tma_new_api(
+        a_ptr,
+        b_ptr,
+        c_ptr,
+        m,
+        n,
+        workspace,  # unused but left here to match the old API kernel
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        # Create tensor descriptors using the new API
+        a_desc = tl.make_tensor_descriptor(
+            base=a_ptr,
+            shape=[m, n],
+            strides=[n, 1],
+            block_shape=[BLOCK_SIZE, BLOCK_SIZE],
+        )
+        b_desc = tl.make_tensor_descriptor(
+            base=b_ptr,
+            shape=[m, n],
+            strides=[n, 1],
+            block_shape=[BLOCK_SIZE, BLOCK_SIZE],
+        )
+        c_desc = tl.make_tensor_descriptor(
+            base=c_ptr,
+            shape=[m, n],
+            strides=[n, 1],
+            block_shape=[BLOCK_SIZE, BLOCK_SIZE],
+        )
+
+        pid_x = tl.program_id(axis=0)
+        pid_y = tl.program_id(axis=1)
+        offset_x = pid_x * BLOCK_SIZE
+        offset_y = pid_y * BLOCK_SIZE
+
+        # Load data using the tensor descriptors with the new API
+        a = tl.load_tensor_descriptor(
+            a_desc,
+            [offset_x, offset_y],
+        )
+        b = tl.load_tensor_descriptor(
+            b_desc,
+            [offset_x, offset_y],
+        )
+
+        # Perform addition
+        output = a + b
+
+        # Store the result with the new API
+        tl.store_tensor_descriptor(
+            c_desc,
+            [offset_x, offset_y],
+            output,
+        )
+
+    @triton.jit
+    def mul2_kernel(
+        in_ptr0,
+        out_ptr,
+        n_elements,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        output = 2 * x
+        tl.store(out_ptr + offsets, output, mask=mask)
+
+    @triton.jit
+    def mul2_inplace_kernel(
+        ptr,
+        n_elements,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(ptr + offsets, mask=mask)
+        output = 2 * x
+        tl.store(ptr + offsets, output, mask=mask)
+
+    @triton.jit
+    def zero_negs(x):
+        return tl.where(x >= 0, x, 0)
+
+    @triton.jit
+    def indirection_kernel(
+        in_ptr0,
+        out_ptr,
+        n_elements,
+        BLOCK_SIZE: "tl.constexpr",
+        ACTIVATION: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        if ACTIVATION == "mul2_inplace_kernel":
+            mul2_inplace_kernel(in_ptr0, n_elements, BLOCK_SIZE=BLOCK_SIZE)
+        elif ACTIVATION == "add_kernel":
+            add_kernel(in_ptr0, in_ptr0, out_ptr, n_elements, BLOCK_SIZE=BLOCK_SIZE)
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        tl.store(out_ptr + offsets, x, mask=mask)
+
+    @triton.jit
+    def double_strided_kernel(
+        in_ptr,
+        out_ptr,
+        in_y_stride,
+        out_y_stride,
+        X_BLOCK_SIZE: "tl.constexpr",
+        Y_BLOCK_SIZE: "tl.constexpr",
+    ):
+        xid = tl.program_id(axis=0)
+        yid = tl.program_id(axis=1)
+        x_start = xid * X_BLOCK_SIZE
+        y_start = yid * Y_BLOCK_SIZE
+        x_offsets = x_start + tl.arange(0, X_BLOCK_SIZE)
+        y_offsets = y_start + tl.arange(0, Y_BLOCK_SIZE)
+        src_offsets = y_offsets[:, None] * in_y_stride + x_offsets[None, :]
+        dst_offsets = y_offsets[:, None] * out_y_stride + x_offsets[None, :]
+        src = tl.load(in_ptr + src_offsets)
+        tl.store(out_ptr + dst_offsets, src * 2.0)
+
+    @triton.jit
+    def inline_asm_kernel_is_pure_true(
+        X, Y, Z, n: "tl.constexpr", BLOCK: "tl.constexpr"
+    ):
+        x = tl.load(X + tl.arange(0, BLOCK))
+        y = tl.load(Y + tl.arange(0, BLOCK))
+        s = tl.full([BLOCK], n, tl.int32)
+        z = tl.inline_asm_elementwise(
+            "shf.l.wrap.b32 $0, $1, $2, $3;",
+            "=r,r, r, r",
+            [x, y, s],
+            dtype=tl.int32,
+            is_pure=True,
+            pack=1,
+        )
+        tl.store(Z + tl.arange(0, BLOCK), z)
+
+    @triton.jit
+    def inline_asm_kernel_is_pure_false(
+        X, Y, Z, n: "tl.constexpr", BLOCK: "tl.constexpr"
+    ):
+        x = tl.load(X + tl.arange(0, BLOCK))
+        y = tl.load(Y + tl.arange(0, BLOCK))
+        s = tl.full([BLOCK], n, tl.int32)
+        z = tl.inline_asm_elementwise(
+            "shf.l.wrap.b32 $0, $1, $2, $3;",
+            "=r,r, r, r",
+            [x, y, s],
+            dtype=tl.int32,
+            is_pure=False,
+            pack=1,
+        )
+        tl.store(Z + tl.arange(0, BLOCK), z)
+
+    @triton.jit
+    def add_kernel_with_block_ptr(
+        x_ptr,
+        y_ptr,
+        output_ptr,
+        n_elements,
+        BLOCK_SIZE: tl.constexpr,
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        x = tl.load(
+            tl.make_block_ptr(
+                base=x_ptr,
+                shape=[n_elements],
+                strides=[1],
+                offsets=[block_start],
+                block_shape=[BLOCK_SIZE],
+                order=[0],
+            ),
+            boundary_check=[0],
+        )
+        y = tl.load(
+            tl.make_block_ptr(
+                base=y_ptr,
+                shape=[n_elements],
+                strides=[1],
+                offsets=[block_start],
+                block_shape=[BLOCK_SIZE],
+                order=[0],
+            ),
+            boundary_check=[0],
+        )
+        output = x + y
+        tl.store(
+            tl.make_block_ptr(
+                base=output_ptr,
+                shape=[n_elements],
+                strides=[1],
+                offsets=[block_start],
+                block_shape=[BLOCK_SIZE],
+                order=[0],
+            ),
+            output,
+            boundary_check=[0],
+        )
+
+    @triton.jit
+    def kernel_with_block_ptr_2d(
+        x_ptr,
+        output_ptr,
+        n_elements,
+        BLOCK_SIZE: tl.constexpr,
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        x = tl.load(
+            tl.make_block_ptr(
+                base=x_ptr,
+                shape=[n_elements, 1],
+                strides=[1, 1],
+                offsets=[block_start, 0],
+                block_shape=[BLOCK_SIZE, 1],
+                order=[1, 0],
+            ),
+            boundary_check=[0],
+        )
+        output = x
+        tl.store(
+            tl.make_block_ptr(
+                base=output_ptr,
+                shape=[n_elements, 1],
+                strides=[1, 1],
+                offsets=[block_start, 0],
+                block_shape=[BLOCK_SIZE, 1],
+                order=[1, 0],
+            ),
+            output,
+            boundary_check=[0],
+        )
+
+    from triton.language import load, store
+
+    @triton.jit
+    def add_kernel_with_import(
+        in_ptr0,
+        in_ptr1,
+        out_ptr,
+        n_elements,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = load(in_ptr0 + offsets, mask=mask)
+        y = load(in_ptr1 + offsets, mask=mask)
+        output = x + y
+        store(out_ptr + offsets, output, mask=mask)
+
+    @triton.jit
+    def cond_op_kernel(
+        in_ptr0,
+        in_ptr1,
+        out_ptr,
+        n_elements,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        y = tl.load(in_ptr1 + offsets, mask=mask)
+        if tl.program_id(0) == 0:
+            output = x + y
+        else:
+            output = x * y
+        tl.store(out_ptr + offsets, output, mask=mask)
+
+    @triton.jit
+    def atomic_add_kernel(
+        in_ptr0,
+        in_ptr1,
+        out_ptr,
+        n_elements,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        y = tl.load(in_ptr1 + offsets, mask=mask)
+        output = x + y
+        tl.atomic_add(out_ptr + offsets, output, mask=mask)
+
+    @triton.jit
+    def add_4_times_kernel(
+        in_ptr0,
+        in_ptr1,
+        out_ptr,
+        n_elements,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        y = tl.load(in_ptr1 + offsets, mask=mask)
+        for i in range(2):
+            output = x + y
+            tl.store(out_ptr + offsets, output, mask=mask)
+        i = 2
+        while i > 0:
+            i -= 1
+            output = x + y
+            tl.store(out_ptr + offsets, output, mask=mask)
+
+    @triton.jit
+    def add_kernel_out_of_order_fn2(
+        in_ptr0,
+        in_ptr1,
+        n_elements,
+        out_ptr,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        y = tl.load(in_ptr1 + offsets, mask=mask)
+        output = x + y
+        tl.store(out_ptr + offsets, output, mask=mask)
+
+    @triton.autotune(
+        configs=[
+            triton.Config(
+                {
+                    "BLOCK_SIZE_M": 16,
+                    "BLOCK_SIZE_N": 16,
+                    "BLOCK_SIZE_K": 16,
+                    "GROUP_SIZE_M": 4,
+                },
+                num_stages=4,
+                num_warps=4,
+            ),
+            triton.Config(
+                {
+                    "BLOCK_SIZE_M": 128,
+                    "BLOCK_SIZE_N": 64,
+                    "BLOCK_SIZE_K": 32,
+                    "GROUP_SIZE_M": 8,
+                },
+                num_stages=4,
+                num_warps=4,
+            ),
+        ],
+        key=["M_ptr", "N", "K"],
+    )
+    @triton.jit
+    def strange_config_matmul_kernel(
+        a_ptr,
+        b_ptr,
+        c_ptr,
+        M_ptr,
+        N,
+        K,
+        BLOCK_SIZE_M: tl.constexpr,
+        BLOCK_SIZE_N: tl.constexpr,
+        BLOCK_SIZE_K: tl.constexpr,
+        GROUP_SIZE_M: tl.constexpr,
+    ):
+        # This is a simplified matmul from the Triton tutorial.
+        pid = tl.program_id(axis=0)
+        M = tl.load(M_ptr)
+        if M == 0 and BLOCK_SIZE_M > 32:
+            # This will run the full matmul if BLOCK_SIZE_M > 32
+            M = 4096
+        elif M == 0:
+            # This returns immediately, cutting short the bad config with a block size of 16.
+            return
+        num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+        num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+        num_pid_in_group = GROUP_SIZE_M * num_pid_n
+        group_id = pid // num_pid_in_group
+        first_pid_m = group_id * GROUP_SIZE_M
+        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+        pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
+        pid_n = (pid % num_pid_in_group) // group_size_m
+
+        offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+        offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+        offs_k = tl.arange(0, BLOCK_SIZE_K)
+        a_ptrs = a_ptr + (offs_am[:, None] + offs_k[None, :])
+        b_ptrs = b_ptr + (offs_k[:, None] + offs_bn[None, :])
+
+        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+        for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+            a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+            accumulator = tl.dot(a, b, accumulator)
+            a_ptrs += BLOCK_SIZE_K
+            b_ptrs += BLOCK_SIZE_K
+        c = accumulator.to(tl.float16)
+
+        offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+        offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+        c_ptrs = c_ptr + offs_cm[:, None] + offs_cn[None, :]
+        c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+        tl.store(c_ptrs, c, mask=c_mask)
+
+    @triton.jit
+    def kernel_with_docstring_double_quotes(out_ptr, numel, BLOCK_SIZE: tl.constexpr):
+        """
+        This kernel contains a triple-quote docstring w/ double quotes.
+        Make sure that codegen sanitizes the docstring.
+        """
+        pid = tl.program_id(axis=0)
+        offsets = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
+        ones = tl.full([BLOCK_SIZE], 1.0, dtype=tl.float32)
+        tl.store(out_ptr + offsets, ones, mask=offsets < numel)
+
+    @triton.jit
+    def kernel_with_docstring_single_quotes(out_ptr, numel, BLOCK_SIZE: tl.constexpr):
+        '''
+        This kernel contains a triple-quote docstring w/ single quotes
+        Make sure that codegen sanitizes the docstring.
+        To prevent it from being linted to double quotes: """!!!"""
+        '''
+        pid = tl.program_id(axis=0)
+        offsets = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
+        ones = tl.full([BLOCK_SIZE], 1.0, dtype=tl.float32)
+        tl.store(out_ptr + offsets, ones, mask=offsets < numel)
+
+    @triton.jit
+    def kernel_inline_asm_double_quotes(
+        in_ptr, out_ptr, numel, BLOCK_SIZE: tl.constexpr
+    ):
+        pid = tl.program_id(axis=0)
+        offsets = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
+        data = tl.load(in_ptr + offsets, mask=offsets < numel)
+        cos_pow = tl.inline_asm_elementwise(
+            asm="""
+            {
+                cos.approx.f32 $0, $1;
+                ex2.approx.f32 $0, $0;
+            }
+                """,
+            constraints=("=r, r"),
+            args=[data],
+            dtype=tl.float32,
+            is_pure=True,
+            pack=1,
+        )
+        tl.store(out_ptr + offsets, cos_pow, mask=offsets < numel)
+
+    @triton.jit
+    def kernel_inline_asm_single_quotes(
+        in_ptr, out_ptr, numel, BLOCK_SIZE: tl.constexpr
+    ):
+        pid = tl.program_id(axis=0)
+        offsets = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
+        data = tl.load(in_ptr + offsets, mask=offsets < numel)
+        cos_pow = tl.inline_asm_elementwise(
+            asm='''
+            {
+                // double quotes to pacify the linter """!!!"""
+                cos.approx.f32 $0, $1;
+                ex2.approx.f32 $0, $0;
+            }
+                ''',
+            constraints=("=r, r"),
+            args=[data],
+            dtype=tl.float32,
+            is_pure=True,
+            pack=1,
+        )
+        tl.store(out_ptr + offsets, cos_pow, mask=offsets < numel)
+
+    # support the old (experimental) and new (tensor_descriptor) APIs
+    def create_tensor_descriptor_shim(
+        tensor, block_sizes: list[int], new_api: bool = True
+    ):
+        if new_api:
+            return triton.tools.tensor_descriptor.TensorDescriptor.from_tensor(
+                tensor, block_sizes
+            )
+        else:
+            if len(block_sizes) == 1:
+                return triton.tools.experimental_descriptor.create_1d_tma_descriptor(
+                    tensor.data_ptr(),
+                    tensor.size(0),
+                    block_sizes[0],
+                    tensor.element_size(),
+                )
+            else:
+                assert len(block_sizes) == 2
+                return triton.tools.experimental_descriptor.create_2d_tma_descriptor(
+                    tensor.data_ptr(),
+                    tensor.size(0),
+                    tensor.size(1),
+                    block_sizes[0],
+                    block_sizes[1],
+                    tensor.element_size(),
+                )
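+
+    # Editor's note: illustrative usage sketch, not part of the original file.
+    # Given a hypothetical contiguous 2D CUDA tensor `t` and a 32x32 block, a
+    # caller picks the API version once and gets back whichever descriptor
+    # object that Triton build supports:
+    #
+    #     desc_new = create_tensor_descriptor_shim(t, [32, 32], new_api=True)
+    #     desc_old = create_tensor_descriptor_shim(t, [32, 32], new_api=False)
+    #
+    # The 1D path works the same way with a single block size, e.g.
+    # create_tensor_descriptor_shim(t1d, [1024], new_api=False).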
diff --git a/.venv/lib/python3.12/site-packages/torch/testing/_internal/two_tensor.py b/.venv/lib/python3.12/site-packages/torch/testing/_internal/two_tensor.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0bdbf2d4ef69cde420a50b35407cce51221ce7a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/testing/_internal/two_tensor.py
@@ -0,0 +1,100 @@
+# mypy: ignore-errors
+
+import torch
+import torch.utils._pytree as pytree
+from torch._export.wrappers import mark_subclass_constructor_exportable_experimental
+from torch.utils._python_dispatch import return_and_correct_aliasing
+
+
+# A simple tensor subclass that holds two tensors internally, and runs every op on both tensors.
+class TwoTensor(torch.Tensor):
+    @staticmethod
+    def __new__(cls, a, b, outer_size=None, outer_stride=None):
+        if outer_size is None:
+            outer_size = a.size()
+        if outer_stride is None:
+            outer_stride = a.stride()
+
+        assert (
+            a.device == b.device
+            and a.layout == b.layout
+            and a.requires_grad == b.requires_grad
+            and a.dtype == b.dtype
+        )
+        # I guess it would be more accurate to represent the shape as torch.cat(a, b).shape
+        shape = outer_size
+        kwargs = {}
+        kwargs["strides"] = outer_stride
+        kwargs["storage_offset"] = a.storage_offset()
+        kwargs["device"] = a.device
+        kwargs["layout"] = a.layout
+        kwargs["requires_grad"] = a.requires_grad
+        kwargs["dtype"] = a.dtype
+        out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
+
+        assert a.shape == b.shape
+        assert a.stride() == b.stride()
+        assert a.storage_offset() == b.storage_offset()
+        return out
+
+    @torch._disable_dynamo
+    @mark_subclass_constructor_exportable_experimental
+    def __init__(self, a, b, outer_size=None, outer_stride=None):
+        self.a = a
+        self.b = b
+
+    def __repr__(self):
+        a_repr = repr(self.a)
+        b_repr = repr(self.b)
+        return f"TwoTensor({a_repr}, {b_repr})"
+
+    def __tensor_flatten__(self):
+        return ["a", "b"], None
+
+    @staticmethod
+    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
+        assert meta is None
+        a, b = inner_tensors["a"], inner_tensors["b"]
+        if type(a) is torch.Tensor:
+            assert outer_size is not None
+            assert outer_stride is not None
+        return TwoTensor(a, b, outer_size, outer_stride)
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        if kwargs is None:
+            kwargs = {}
+        args_a = pytree.tree_map_only(TwoTensor, lambda x: x.a, args)
+        args_b = pytree.tree_map_only(TwoTensor, lambda x: x.b, args)
+
+        kwargs_a = pytree.tree_map_only(TwoTensor, lambda x: x.a, kwargs)
+        kwargs_b = pytree.tree_map_only(TwoTensor, lambda x: x.b, kwargs)
+
+        out_a = func(*args_a, **kwargs_a)
+        out_b = func(*args_b, **kwargs_b)
+        out_a_flat, spec = pytree.tree_flatten(out_a)
+        out_b_flat = pytree.tree_leaves(out_b)
+        # for aten ops that return non-tensors, just assume that
+        # our two inner tensors return the same value
+        out_flat = [
+            cls(o_a, o_b) if isinstance(o_a, torch.Tensor) else o_a
+            for o_a, o_b in zip(out_a_flat, out_b_flat)
+        ]
+        out = pytree.tree_unflatten(out_flat, spec)
+        from torch._higher_order_ops.cond import cond_op
+
+        if func is cond_op:
+            return out
+        else:
+            return return_and_correct_aliasing(func, args, kwargs, out)
+
+    def get_elem_a(self):
+        return self.a
+
+
+class TwoTensorMode(torch.utils._python_dispatch.TorchDispatchMode):
+    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+        out = func(*args, **kwargs)
+        if torch._subclasses.fake_tensor._is_tensor_constructor(func):
+            out = TwoTensor(out, out.clone())
+        return out
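+
+
+# Editor's note: illustrative sketch, not part of the original file. A TwoTensor
+# dispatches every op to both inner tensors, e.g.
+#
+#     t = TwoTensor(torch.ones(3), torch.zeros(3))
+#     y = t + 1        # y.a is all twos, y.b is all ones
+#
+# and under TwoTensorMode, factory functions such as torch.randn return a
+# TwoTensor wrapping the freshly constructed tensor and a clone of it.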
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..65d8fab4997220d057575087ae0f95fab1d4464f
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/__init__.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_appending_byte_serializer.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_appending_byte_serializer.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9fd3dfa76740e698314dc986e25b9dee11a8d554
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_appending_byte_serializer.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_backport_slots.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_backport_slots.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..540a96fd1bc0f985594f04e7e87d19d0a033b212
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_backport_slots.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_config_module.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_config_module.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d4731d9373214b7072f012701f18de10d23d5f95
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_config_module.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_content_store.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_content_store.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..15cbcf3a2d68282f4ec4dee2aacefd241a5b68a8
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_content_store.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_contextlib.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_contextlib.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5116ff1e612473d3fb960ca049668b53d0a4db26
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_contextlib.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_cxx_pytree.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_cxx_pytree.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3edbc1440437daf053b23e11b0a4810bd16b5854
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_cxx_pytree.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_device.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_device.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d02fe68d7b543db1218b84f7313b8c0fb3c65432
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_device.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_dtype_abbrs.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_dtype_abbrs.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..21a9225119179cdccb9232661a071c60e7eadabc
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_dtype_abbrs.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_exposed_in.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_exposed_in.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2bf4193282046457729492691e096e4096cf96c0
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_exposed_in.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_foreach_utils.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_foreach_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f1bb0ed5f78ff3bd7dd5d30c0afe9a340f062447
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_foreach_utils.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_functools.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_functools.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..65ea7657ecafe6ddafa98e349050c7180026325b
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_functools.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_mode_utils.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_mode_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..619fdd6a39456463e9a4542303b9028b60111214
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_mode_utils.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_ordered_set.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_ordered_set.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d5e10a2d2187564d7c10dd52f52be8da4730dc05
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_ordered_set.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_python_dispatch.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_python_dispatch.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..07f82e2f10c56a08f0b7ef2486137baa5d6e08f4
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_python_dispatch.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_pytree.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_pytree.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a8828f92dd469f8d9cf511eb665b9cf6fc78e078
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_pytree.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_stats.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_stats.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..170821697573cb748f4bbf260e850efff21769a9
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_stats.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_thunk.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_thunk.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8cee519d9886a98557bd75a0f5d2acb204e9bb8f
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_thunk.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_traceback.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_traceback.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e207f911f4066a68a9b66d0140bf11b8cc542f35
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_traceback.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_triton.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_triton.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..445b4239b2dc2e9e6d69395aa8bfec69994465ff
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_triton.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_typing_utils.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_typing_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7b8d058b0a9ffd198dfcb4dd5e42896c3e5a48a9
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_typing_utils.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/backend_registration.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/backend_registration.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e9f787e2cd3a8786f5570bbe436d5a8af748f709
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/backend_registration.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/checkpoint.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/checkpoint.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a5c366a9aa1b0ab7d198eaefd55e30c8c1967fd8
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/checkpoint.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/collect_env.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/collect_env.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..031e879df38bdc91b267466ed00d1686b776c11d
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/collect_env.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/cpp_backtrace.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/cpp_backtrace.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..93c11937f580e74625399c039ebf153a40db324e
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/cpp_backtrace.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/deterministic.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/deterministic.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4e096eb483e58aeefb0c0823a02e5276af1e6c9b
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/deterministic.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/dlpack.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/dlpack.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..16d9e4c636c842dad4badb729d4f5893399282bd
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/dlpack.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/hooks.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/hooks.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..54e1bf45824f4de098cce981613afdb55258d104
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/hooks.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/throughput_benchmark.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/throughput_benchmark.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1e6c3c9e8bb35174e5774f12f1851e1512783140
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/throughput_benchmark.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/weak.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/weak.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2b35319a5d3f29eac6731cd2ab9e70513eb57d46
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/weak.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__init__.py b/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9deab75cef21acb1512c19a2a9bb30ddf17f2ca1
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/__init__.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/functions.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/functions.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eed46e622ea49e2bfc0b618183e14e38b9020670
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/functions.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/interp.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/interp.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fd9e35a9c31c644463c143e2c1e8983f140d406d
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/interp.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/numbers.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/numbers.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ae7eb38cf1f4ee7e58727ebc61586302dfe4cce0
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/numbers.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/printers.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/printers.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2aa8a32d21b896e730718e164e0636ba1b97ec67
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/printers.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/singleton_int.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/singleton_int.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8d68a39be73458a8e114ed8e515d380a6c949881
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/singleton_int.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/solve.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/solve.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b90007bb3e27c8bff5c34569f996feefc907c50b
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/solve.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/symbol.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/symbol.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f404a45b43289c7aa4d733f8f7313a91d4a823a9
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/symbol.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/value_ranges.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/value_ranges.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3c39e7e2d050b2d4230d43642033db303f1bb3db
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/_sympy/__pycache__/value_ranges.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/_sympy/functions.py b/.venv/lib/python3.12/site-packages/torch/utils/_sympy/functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c35cd7eb5732228140f418bd3aa862458f68cc0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/_sympy/functions.py
@@ -0,0 +1,1407 @@
+# mypy: allow-untyped-defs
+import functools
+import math
+import operator
+import sys
+from typing import Callable, Optional, SupportsFloat, TYPE_CHECKING, TypeVar, Union
+from typing_extensions import TypeVarTuple, Unpack
+
+import sympy
+from sympy import S
+from sympy.core import sympify
+from sympy.core.expr import Expr
+from sympy.core.function import Application
+from sympy.core.logic import _torf, fuzzy_and, fuzzy_or
+from sympy.core.numbers import equal_valued
+from sympy.core.operations import LatticeOp, ShortCircuit
+from sympy.core.sorting import ordered
+from sympy.core.traversal import walk
+from sympy.printing.precedence import PRECEDENCE
+from sympy.utilities.iterables import sift
+
+from .numbers import int_oo
+
+
+if TYPE_CHECKING:
+    from collections.abc import Iterable
+
+
+_T = TypeVar("_T", bound=SupportsFloat)
+_Ts = TypeVarTuple("_Ts")
+
+# Portions of this file are adapted from the Sympy codebase, which was
+# licensed as follows:
+#
+#   Copyright (c) 2006-2023 SymPy Development Team
+#
+#   All rights reserved.
+#
+#   Redistribution and use in source and binary forms, with or without
+#   modification, are permitted provided that the following conditions are met:
+#
+#     a. Redistributions of source code must retain the above copyright notice,
+#        this list of conditions and the following disclaimer.
+#     b. Redistributions in binary form must reproduce the above copyright
+#        notice, this list of conditions and the following disclaimer in the
+#        documentation and/or other materials provided with the distribution.
+#     c. Neither the name of SymPy nor the names of its contributors
+#        may be used to endorse or promote products derived from this software
+#        without specific prior written permission.
+#
+#   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+#   AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+#   IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+#   ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR
+#   ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+#   DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+#   SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+#   CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+#   LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
+#   OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
+#   DAMAGE.
+
+__all__ = [
+    "FloorDiv",
+    "ModularIndexing",
+    "Where",
+    "PythonMod",
+    "Mod",
+    "CleanDiv",
+    "CeilToInt",
+    "FloorToInt",
+    "CeilDiv",
+    "IntTrueDiv",
+    "FloatTrueDiv",
+    "LShift",
+    "RShift",
+    "IsNonOverlappingAndDenseIndicator",
+    "TruncToFloat",
+    "TruncToInt",
+    "RoundToInt",
+    "RoundDecimal",
+    "ToFloat",
+    "FloatPow",
+    "PowByNatural",
+    "Identity",
+]
+
+
+def _is_symbols_binary_summation(expr: sympy.Expr) -> bool:
+    # No need to check that two args are not the same, since expr is pre-optimized, but we do it anyway.
+    return (
+        expr.is_Add
+        and len(expr._args) == 2
+        and expr._args[0].is_symbol
+        and expr._args[1].is_symbol
+        and expr._args[0] is not expr._args[1]
+    )
+
+
+def _keep_float(
+    f: Callable[[Unpack[_Ts]], _T]
+) -> Callable[[Unpack[_Ts]], Union[_T, sympy.Float]]:
+    @functools.wraps(f)
+    def inner(*args: Unpack[_Ts]) -> Union[_T, sympy.Float]:
+        r: Union[_T, sympy.Float] = f(*args)
+        if any(isinstance(a, sympy.Float) for a in args) and not isinstance(
+            r, sympy.Float
+        ):
+            r = sympy.Float(float(r))
+        return r
+
+    return inner
+
+
+def fuzzy_eq(x: Optional[bool], y: Optional[bool]) -> Optional[bool]:
+    if None in (x, y):
+        return None
+    return x == y
+
+
+def simple_floordiv_gcd(p: sympy.Basic, q: sympy.Basic) -> sympy.Basic:
+    """
+    Fast path for sympy.gcd, using a simple factoring strategy.
+
+    We try to rewrite p and q in the form n*e*p1 + n*e*p2 and n*e*q0,
+    where n is the greatest common integer factor and e is the largest
+    syntactic common factor (i.e., common sub-expression) in p and q.
+    Then the gcd returned is n*e; cancelling it would leave us with
+    p1 + p2 and q0.
+
+    Note that further factoring of p1 + p2 and q0 might be possible with
+    sympy.factor (which uses domain-specific theories). E.g., we are unable
+    to find that x*y + x + y + 1 is divisible by x + 1. More generally,
+    when q is of the form q1 + q2 (instead of being already factored) it
+    might be necessary to fall back on sympy.gcd.
+    """
+
+    def integer_coefficient(x: sympy.Basic) -> int:
+        integer_coefficients: list[int] = [
+            abs(int(arg))
+            for arg in sympy.Mul.make_args(x)
+            if isinstance(arg, (int, sympy.Integer))
+        ]
+        return math.prod(integer_coefficients)
+
+    def integer_factor(expr: sympy.Basic) -> int:
+        integer_factors: Iterable[int] = map(
+            integer_coefficient, sympy.Add.make_args(expr)
+        )
+        return functools.reduce(math.gcd, integer_factors)
+
+    gcd: int = math.gcd(integer_factor(p), integer_factor(q))
+    p, q = p / gcd, q / gcd  # type: ignore[operator, assignment]  # remove in py3.12
+
+    base_splits: list[tuple[sympy.Basic, ...]] = list(
+        map(sympy.Mul.make_args, sympy.Add.make_args(p))
+    )
+    divisor_split: tuple[sympy.Basic, ...] = sympy.Mul.make_args(q)
+    for x in divisor_split:
+        if all(x in base_split for base_split in base_splits):
+            gcd = gcd * x  # type: ignore[operator]  # remove in py3.12
+    return gcd  # type: ignore[return-value]  # remove in py3.12
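+
+# Editor's note (illustrative, not in the original file): with p = 6*x*y + 9*x
+# and q = 3*x, integer_factor(p) == integer_factor(q) == 3, and after dividing
+# through by 3 the symbol x appears in every additive term of p as well as in q,
+# so simple_floordiv_gcd(p, q) returns 3*x; cancelling it leaves 2*y + 3 and 1.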
+
+
+# It would be nice to have assertions on whether or not inputs is_integer
+# However, with bugs like https://github.com/sympy/sympy/issues/26620 sympy
+# sometimes inconsistently reports floats as integers.
+#
+# What we can assume from sympy is that if something is an int, it
+# definitely is is_integer, but if it is a float it may or may not
+# be is_integer.  So we are unable to do strong asserts that things
+# are NOT integers.
+
+
+# TODO: In Triton, // rounds to zero, but in Python, it is floor division.
+# When we can prove both arguments are non-negative, we should just have a
+# GenericFloorDiv (name pending) which can codegen efficiently in Python/C,
+# and then PythonFloorDiv and CIntDiv which have the appropriate rounding
+# semantics.
+#
+# Right now, FloorDiv de facto changes behavior if arguments are negative or
+# not, this can potentially cause correctness issues.
+class FloorDiv(sympy.Function):
+    """
+    We maintain this so that:
+    1. We can use divisibility guards to simplify FloorDiv(a, b) to a / b.
+    2. Printing out the expression is nicer (compared to say, representing a//b as (a - a % b) / b)
+
+    NB: This is Python-style floor division, round to -Inf
+    """
+
+    nargs: tuple[int, ...] = (2,)
+    precedence: int = 35  # lower precedence than add
+    is_integer: bool = True
+
+    @property
+    def base(self) -> sympy.Basic:
+        return self.args[0]
+
+    @property
+    def divisor(self) -> sympy.Basic:
+        return self.args[1]
+
+    def _sympystr(self, printer: sympy.printing.StrPrinter) -> str:
+        base = printer.parenthesize(self.base, PRECEDENCE["Atom"] - 0.5)
+        divisor = printer.parenthesize(self.divisor, PRECEDENCE["Atom"] - 0.5)
+        return f"({base}//{divisor})"
+
+    # Automatic evaluation.
+    # https://docs.sympy.org/latest/guides/custom-functions.html#best-practices-for-eval
+    @classmethod
+    def eval(
+        cls, base: sympy.Integer, divisor: sympy.Integer
+    ) -> Union[sympy.Basic, None]:
+        # python test/test_dynamic_shapes.py -k TestDimConstraints.test_dim_constraints_solve_full
+        # Assert triggered by inequality solver
+        # assert base.is_integer, base
+        # assert divisor.is_integer, divisor
+
+        # We don't provide the same error message as in Python because SymPy
+        # makes it difficult to check the types.
+        if divisor.is_zero:
+            raise ZeroDivisionError("division by zero")
+        if base in (int_oo, -int_oo, sympy.oo, -sympy.oo) and divisor in (
+            int_oo,
+            -int_oo,
+            sympy.oo,
+            -sympy.oo,
+        ):
+            return sympy.nan
+        if base is sympy.nan or divisor is sympy.nan:
+            return sympy.nan
+
+        if base.is_zero:
+            return sympy.S.Zero
+        if base.is_integer and equal_valued(divisor, 1):
+            return base
+        if base.is_integer and equal_valued(divisor, -1):
+            return sympy.Mul(base, -1)
+        if (
+            isinstance(base, sympy.Number)
+            and isinstance(divisor, sympy.Number)
+            and (
+                base in (int_oo, -int_oo, sympy.oo, -sympy.oo)
+                or divisor in (int_oo, -int_oo, sympy.oo, -sympy.oo)
+            )
+        ):
+            r = float(base) / float(divisor)
+            if r == math.inf:
+                return int_oo
+            elif r == -math.inf:
+                return -int_oo
+            elif math.isnan(r):
+                return sympy.nan
+            else:
+                return sympy.Integer(math.floor(r))
+        if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer):
+            return sympy.Integer(int(base) // int(divisor))
+        if isinstance(base, FloorDiv):
+            return FloorDiv(base.args[0], base.args[1] * divisor)
+
+        # Expands (x + y) // b into x // b + y // b.
+        # This only works if floor is an identity, i.e. x / b is an integer.
+        if isinstance(divisor, sympy.Integer):
+            quotients = 0
+            terms = []
+            for term in sympy.Add.make_args(base):
+                quotient = term / divisor
+
+                if quotient.is_integer:
+                    terms.append(term)
+                    quotients += quotient
+
+            if len(terms) != 0:
+                # Passing evaluate=False since the expression will be optimized during the subtraction after its construction.
+                return (
+                    FloorDiv(base - sympy.Add(*terms, evaluate=False), divisor)
+                    + quotients
+                )
+
+        try:
+            gcd = simple_floordiv_gcd(base, divisor)
+            if equal_valued(gcd, 1) and isinstance(divisor, sympy.Add):
+                gcd = sympy.gcd(base, divisor)
+            if not equal_valued(gcd, 1):
+                return FloorDiv(
+                    sympy.simplify(base / gcd), sympy.simplify(divisor / gcd)
+                )
+        except sympy.PolynomialError:
+            pass  # https://github.com/pytorch/pytorch/issues/108276
+
+        return None
+
+    def _ccode(self, printer):
+        base = printer.parenthesize(self.base, PRECEDENCE["Atom"] - 0.5)
+        divisor = printer.parenthesize(self.divisor, PRECEDENCE["Atom"] - 0.5)
+        return f"floor({base}/{divisor})"
+
+
+class ModularIndexing(sympy.Function):
+    """
+    ModularIndexing(a, b, c) => (a // b) % c where % is the C modulus
+    """
+
+    nargs: tuple[int, ...] = (3,)
+    is_integer: bool = True
+    precedence: int = 35  # lower precedence than add
+
+    @classmethod
+    def eval(
+        cls, base: sympy.Integer, divisor: sympy.Integer, modulus: sympy.Integer
+    ) -> Optional[sympy.Basic]:
+        if base == 0 or modulus == 1:
+            return sympy.S.Zero
+
+        if (
+            isinstance(base, sympy.Integer)
+            and isinstance(divisor, sympy.Integer)
+            and isinstance(modulus, sympy.Integer)
+        ):
+            return (base // divisor) % modulus
+
+        try:
+            if divisor != 1:
+                gcd = sympy.gcd(base, divisor)
+                if gcd != 1:
+                    return ModularIndexing(
+                        sympy.simplify(base / gcd),
+                        sympy.simplify(divisor / gcd),
+                        modulus,
+                    )
+        except sympy.PolynomialError:
+            pass  # https://github.com/pytorch/pytorch/issues/108276
+
+        if isinstance(base, sympy.Add):
+            new_terms: list[sympy.Integer] = []
+            all_positive: bool = True
+            for term in base.args:
+                if sympy.gcd(term, modulus * divisor) != modulus * divisor:
+                    if (isinstance(term, sympy.Integer) and term < 0) or (
+                        isinstance(term, sympy.Mul)
+                        and isinstance(term.args[0], sympy.Integer)
+                        and term.args[0] < 0
+                    ):
+                        # workaround for https://github.com/triton-lang/triton/issues/619,
+                        # if there are negative terms, // produces wrong result
+                        # TODO if https://github.com/triton-lang/triton/issues/619 is fixed
+                        # this optimization would become valid
+                        all_positive = False
+                        break
+                    else:
+                        new_terms.append(term)
+
+            if len(new_terms) != len(base.args) and all_positive:
+                return ModularIndexing(sum(new_terms), divisor, modulus)
+
+        if isinstance(base, FloorDiv):
+            return ModularIndexing(base.args[0], base.args[1] * divisor, modulus)
+
+        return None
+
+    def _eval_is_nonnegative(self) -> Optional[bool]:
+        p, q = self.args[:2]
+        return fuzzy_eq(p.is_nonnegative, q.is_nonnegative)  # type: ignore[attr-defined]
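+
+# Editor's note (illustrative, not in the original file): ModularIndexing(10, 2, 3)
+# evaluates to (10 // 2) % 3 == 2, and a symbolic call such as
+# ModularIndexing(6*x, 2, 3) is reduced via the gcd path to ModularIndexing(3*x, 1, 3).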
+
+
+class Where(sympy.Function):
+    """
+    Good ol' ternary operator
+    """
+
+    nargs: tuple[int, ...] = (3,)
+    precedence: int = 35  # lower precedence than add
+
+    def _eval_is_integer(self) -> Optional[bool]:
+        return True if self.args[1].is_integer and self.args[2].is_integer else None  # type: ignore[attr-defined]
+
+    def _eval_is_nonnegative(self) -> Optional[bool]:
+        return (
+            True
+            if self.args[1].is_nonnegative and self.args[2].is_nonnegative  # type: ignore[attr-defined]
+            else None
+        )
+
+    def _eval_is_positive(self) -> Optional[bool]:
+        return True if self.args[1].is_positive and self.args[2].is_positive else None  # type: ignore[attr-defined]
+
+    @classmethod
+    def eval(
+        cls, c: sympy.Basic, p: sympy.Basic, q: sympy.Basic
+    ) -> Optional[sympy.Basic]:
+        if c == sympy.true:
+            return p
+        elif c == sympy.false:
+            return q
+        return None
+
+
+# Python-style modulus: take sign from RHS
+class PythonMod(sympy.Function):
+    nargs: tuple[int, ...] = (2,)
+
+    precedence: int = 35  # lower precedence than add
+    is_integer: bool = True
+
+    @classmethod
+    def eval(cls, p: sympy.Expr, q: sympy.Expr) -> Optional[sympy.Expr]:
+        # python test/dynamo/test_export.py -k ExportTests.test_trivial_constraint
+        # Triggered by sympy.solvers.inequalities.reduce_inequalities
+        # assert p.is_integer, p
+        # assert q.is_integer, q
+
+        if q.is_zero:
+            raise ZeroDivisionError("Modulo by zero")
+
+        # Three cases:
+        #   1. p == 0
+        #   2. p is either q or -q
+        #   3. p is integer and q == 1
+        if p is S.Zero or p in (q, -q) or q == 1:
+            return S.Zero
+
+        # Evaluate if they are both literals.
+        if q.is_Number and p.is_Number:
+            return p % q
+
+        # If q == 2, it's a matter of whether p is odd or even.
+        if q.is_Number and q == 2:
+            if p.is_even:
+                return S.Zero
+            if p.is_odd:
+                return S.One
+
+        # If p is a multiple of q.
+        r = p / q
+        if r.is_integer:
+            return S.Zero
+
+        # If p < q and its ratio is positive, then:
+        #   - floor(p / q) = 0
+        #   - p % q = p - floor(p / q) * q = p
+        less = p < q
+        if less.is_Boolean and bool(less) and r.is_positive:
+            return p
+
+        if sympy.Mod(p, q) == 0:
+            return S.Zero
+
+        return None
+
+    # NB: args[1] for PythonMod
+    def _eval_is_nonnegative(self) -> Optional[bool]:
+        return True if self.args[1].is_positive else None  # type: ignore[attr-defined]
+
+    def _eval_is_nonpositive(self) -> Optional[bool]:
+        return True if self.args[1].is_negative else None  # type: ignore[attr-defined]
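+
+# Editor's note (illustrative, not in the original file): like Python's %, the
+# result of PythonMod takes the sign of the divisor, e.g. PythonMod(-7, 3)
+# evaluates to 2, while the Mod class below is only intended for non-negative
+# operands (its eval asserts p >= 0 and q >= 1 on literals).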
+
+
+# Generic modulus: only defined on non-negative arguments
+class Mod(sympy.Function):
+    nargs = (2,)
+    precedence: int = 35  # lower precedence than add
+
+    is_integer = True
+    is_nonnegative = True
+
+    @classmethod
+    def eval(cls, p, q):
+        # This was adapted from: sympy/core/mod.py
+
+        # Triggered by
+        # python test/test_dynamic_shapes.py -k TestDimConstraints.test_dim_constraints_solve_full
+        # assert p.is_integer, p
+        # assert q.is_integer, q
+
+        if q.is_zero:
+            raise ZeroDivisionError("Modulo by zero")
+
+        # Three cases:
+        #   1. p == 0
+        #   2. p is either q or -q
+        #   3. p is integer and q == 1
+        if p is S.Zero or p in (q, -q) or q == 1:
+            return S.Zero
+
+        # Evaluate if they are both literals.
+        if q.is_Number and p.is_Number:
+            assert p >= 0, p
+            assert q >= 1, q
+            return p % q
+
+        # If q == 2, it's a matter of whether p is odd or even.
+        if q.is_Number and q == 2:
+            if p.is_even:
+                return S.Zero
+            if p.is_odd:
+                return S.One
+
+        # If p is a multiple of q.
+        r = p / q
+        if r.is_integer:
+            return S.Zero
+
+        # If p < q and its ratio is positive, then:
+        #   - floor(p / q) = 0
+        #   - p % q = p - floor(p / q) * q = p
+        less = p < q
+        if less.is_Boolean and bool(less) and r.is_positive:
+            return p
+
+
+class CleanDiv(FloorDiv):
+    """
+    Div where we can assume no rounding.
+    This is to enable future optimizations.
+    """
+
+
+# Don't use sympy ceiling/floor as they will attempt simplifications involving
+# frac
+class CeilToInt(sympy.Function):
+    is_integer = True
+
+    @classmethod
+    def eval(cls, number):
+        # assert number.is_integer is not True, number
+        if number in (sympy.oo, int_oo):
+            return int_oo
+        if number in (-sympy.oo, -int_oo):
+            return -int_oo
+        if isinstance(number, sympy.Number):
+            return sympy.Integer(math.ceil(float(number)))
+
+    def _ccode(self, printer):
+        number = printer.parenthesize(self.args[0], self.args[0].precedence - 0.5)
+        return f"ceil({number})"
+
+
+class FloorToInt(sympy.Function):
+    is_integer = True
+
+    @classmethod
+    def eval(cls, number):
+        if number in (sympy.oo, int_oo):
+            return int_oo
+        if number in (-sympy.oo, -int_oo):
+            return -int_oo
+        if isinstance(number, sympy.Integer):
+            return number
+        if isinstance(number, sympy.Number):
+            return sympy.Integer(math.floor(float(number)))
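+
+# Editor's note (illustrative, not in the original file): these wrappers round
+# numeric literals eagerly, e.g. CeilToInt(sympy.Rational(5, 2)) evaluates to 3
+# and FloorToInt(sympy.Rational(5, 2)) to 2, while +/-oo are mapped to the
+# extended integer infinities int_oo / -int_oo.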
+
+
+class CeilDiv(sympy.Function):
+    """
+    Div used in indexing that rounds up.
+    """
+
+    is_integer = True
+
+    def __new__(cls, base, divisor):
+        base = sympy.sympify(base)
+        divisor = sympy.sympify(divisor)
+        if sympy.gcd(base, divisor) == divisor:
+            return CleanDiv(base, divisor)
+        else:
+            return FloorDiv(base + (divisor - 1), divisor)
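+
+# Editor's note (illustrative, not in the original file): CeilDiv(10, 4) lowers
+# to FloorDiv(10 + 3, 4), which evaluates to 3, while an exactly divisible pair
+# such as CeilDiv(12, 4) takes the fast path and becomes CleanDiv(12, 4) == 3.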
+
+
+class LShift(sympy.Function):
+    is_integer = True
+
+    @classmethod
+    def eval(cls, base, shift):
+        if shift < 0:
+            raise ValueError("negative shift count")
+        return base * 2**shift
+
+
+class RShift(sympy.Function):
+    is_integer = True
+
+    @classmethod
+    def eval(cls, base, shift):
+        if shift < 0:
+            raise ValueError("negative shift count")
+        return FloorDiv(base, 2**shift)
+
+
+class MinMaxBase(Expr, LatticeOp):  # type: ignore[misc]
+    def __new__(cls, *original_args, **assumptions):
+        from sympy.core.parameters import global_parameters
+
+        evaluate = assumptions.pop("evaluate", global_parameters.evaluate)
+        args = (sympify(arg) for arg in original_args)
+
+        # See the comment in _satisfy_unique_summations_symbols.
+        unique_summations_symbols = (
+            None
+            if not evaluate
+            else cls._satisfy_unique_summations_symbols(original_args)
+        )
+
+        if evaluate:
+            try:
+                # first standard filter, for cls.zero and cls.identity
+                # also reshape Max(a, Max(b, c)) to Max(a, b, c)
+                args = frozenset(cls._new_args_filter(args))  # type: ignore[assignment]
+            except ShortCircuit:
+                return cls.zero  # type: ignore[attr-defined]
+
+            # No need to run _collapse_arguments and _find_localzeros, see the comment
+            # in _satisfy_unique_summations_symbols.
+            if unique_summations_symbols is None:
+                # remove redundant args that are easily identified
+                args = cls._collapse_arguments(args, **assumptions)
+
+                # find local zeros
+                args = cls._find_localzeros(args, **assumptions)
+
+        args = frozenset(args)
+
+        if not args:
+            return cls.identity  # type: ignore[attr-defined]
+
+        if len(args) == 1:
+            return list(args).pop()
+
+        # base creation
+        obj = Expr.__new__(cls, *ordered(args), **assumptions)
+        obj._argset = args
+
+        obj.unique_summations_symbols = unique_summations_symbols
+        return obj
+
+    @classmethod
+    def _satisfy_unique_summations_symbols(
+        cls, args
+    ) -> Optional[set[sympy.core.symbol.Symbol]]:
+        """
+        One common case in some models is building expressions of the form
+        max(max(max(a+b...), c+d), e+f) which is simplified to max(a+b, c+d, e+f, ...).
+        For such expressions, we call the Max constructor X times (once for each nested
+        max) and the expression gets flattened.
+
+        A significant cost in constructing those expressions is running _collapse_arguments
+        and _find_localzeros. However, those two optimizations are unnecessary when the args
+        to max are all of the form a+b, c+d, etc., where each term uses a unique set of symbols.
+
+        This function is used to detect such properties of the expressions we are building
+        and, if so, to signal that we do not need to run those optimizations. To detect this,
+        we store a property on the expression, "unique_summations_symbols", indicating that the
+        expression is a min/max operation over terms that use unique symbols. This property
+        also memoizes the set of symbols used in all the terms, which makes it faster to detect
+        the property inductively.
+
+        When we apply max to add a new term, all we need to do is check if the new term uses
+        unique symbols (with respect to existing terms and itself).
+        Example:
+        t = Max(a+b, c+d) ==> satisfies the property
+        Max(t, h+j)       ==> h,j not in [a,b,c,d] => satisfies the property.
+
+        The function returns None if the new expression does not satisfy the unique_summations_symbols
+        property. Otherwise, it returns a new set of unique symbols.
+        """
+        if len(args) != 2:
+            return None
+
+        (lhs, rhs) = (
+            (args[1], args[0])
+            if isinstance(args[1], MinMaxBase)
+            else (args[0], args[1])
+        )
+
+        if not _is_symbols_binary_summation(rhs):
+            return None
+
+        # base case max(a+b, c+d) ==> satisfies the property if a+b and c+d use unique symbols.
+        if _is_symbols_binary_summation(lhs):
+            return cls._unique_symbols(args)
+
+        # inductive case max(t, h+j) ==> satisfies the property if h, j not in t.unique_summations_symbols
+        if isinstance(lhs, MinMaxBase):
+            lhs_unique_summations_symbols = getattr(
+                lhs, "unique_summations_symbols", None
+            )
+            if lhs_unique_summations_symbols is not None:
+                return cls._unique_symbols([rhs], lhs_unique_summations_symbols)
+
+        return None
+
+    @classmethod
+    def _unique_symbols(
+        cls, args, initial_set: Optional[set[sympy.core.symbol.Symbol]] = None
+    ) -> Optional[set[sympy.core.symbol.Symbol]]:
+        """
+        Return seen_symbols if the atoms in all args are all unique symbols;
+        otherwise return None. initial_set can be used to provide an initial value for seen_symbols.
+        """
+        seen_symbols = set() if initial_set is None else initial_set
+        for arg in args:
+            for element in arg.atoms():
+                if not isinstance(element, sympy.core.symbol.Symbol):
+                    return None
+                elif element in seen_symbols:
+                    return None
+                else:
+                    seen_symbols.add(element)
+        return seen_symbols
+
+    @classmethod
+    def _collapse_arguments(cls, args, **assumptions):
+        """Remove redundant args.
+
+        Examples
+        ========
+
+        >>> from sympy import Min, Max
+        >>> from sympy.abc import a, b, c, d, e
+
+        Any arg in parent that appears in any
+        parent-like function in any of the flat args
+        of parent can be removed from that sub-arg:
+
+        >>> Min(a, Max(b, Min(a, c, d)))
+        Min(a, Max(b, Min(c, d)))
+
+        If the arg of parent appears in an opposite-than-parent
+        function in any of the flat args of parent, that function
+        can be replaced with the arg:
+
+        >>> Min(a, Max(b, Min(c, d, Max(a, e))))
+        Min(a, Max(b, Min(a, c, d)))
+        """
+        if not args:
+            return args
+        args = list(ordered(args))
+        if cls is Min:
+            other = Max
+        else:
+            other = Min  # type: ignore[assignment]
+
+        # find global comparable max of Max and min of Min if a new
+        # value is being introduced in these args at position 0 of
+        # the ordered args
+        if args[0].is_number:
+            sifted = mins, maxs = [], []  # type: ignore[var-annotated]
+            for i in args:
+                for v in walk(i, Min, Max):
+                    if v.args[0].is_comparable:
+                        sifted[isinstance(v, Max)].append(v)
+            small = Min.identity
+            for i in mins:
+                v = i.args[0]
+                if v.is_number and (v < small) == True:  # noqa: E712
+                    small = v
+            big = Max.identity
+            for i in maxs:
+                v = i.args[0]
+                if v.is_number and (v > big) == True:  # noqa: E712
+                    big = v
+            # at the point when this function is called from __new__,
+            # there may be more than one numeric arg present since
+            # local zeros have not been handled yet, so look through
+            # more than the first arg
+            if cls is Min:
+                for arg in args:
+                    if not arg.is_number:
+                        break
+                    if (arg < small) == True:  # noqa: E712
+                        small = arg
+            elif cls == Max:
+                for arg in args:
+                    if not arg.is_number:
+                        break
+                    if (arg > big) == True:  # noqa: E712
+                        big = arg
+            T = None
+            if cls is Min:
+                if small != Min.identity:
+                    other = Max
+                    T = small
+            elif big != Max.identity:
+                other = Min  # type: ignore[assignment]
+                T = big
+            if T is not None:
+                # remove numerical redundancy
+                for i in range(len(args)):
+                    a = args[i]
+                    if isinstance(a, other):
+                        a0 = a.args[0]
+                        if (  # noqa: E712
+                            (a0 > T) if other == Max else (a0 < T)  # noqa: E712
+                        ) == True:  # noqa: E712
+                            args[i] = cls.identity  # type: ignore[attr-defined]
+
+        # remove redundant symbolic args
+        def do(ai, a):
+            if not isinstance(ai, (Min, Max)):
+                return ai
+            cond = a in ai.args
+            if not cond:
+                return ai.func(*[do(i, a) for i in ai.args], evaluate=False)
+            if isinstance(ai, cls):
+                return ai.func(*[do(i, a) for i in ai.args if i != a], evaluate=False)
+            return a
+
+        for i, a in enumerate(args):
+            args[i + 1 :] = [do(ai, a) for ai in args[i + 1 :]]
+
+        # factor out common elements as for
+        # Min(Max(x, y), Max(x, z)) -> Max(x, Min(y, z))
+        # and vice versa when swapping Min/Max -- do this only for the
+        # easy case where all functions contain something in common;
+        # trying to find some optimal subset of args to modify takes
+        # too long
+
+        def factor_minmax(args):
+            is_other = lambda arg: isinstance(arg, other)  # noqa: E731
+            other_args, remaining_args = sift(args, is_other, binary=True)
+            if not other_args:
+                return args
+
+            # Min(Max(x, y, z), Max(x, y, u, v)) -> {x,y}, ({z}, {u,v})
+            arg_sets = [set(arg.args) for arg in other_args]
+            common = set.intersection(*arg_sets)
+            if not common:
+                return args
+
+            new_other_args = list(common)
+            arg_sets_diff = [arg_set - common for arg_set in arg_sets]
+
+            # If any set is empty after removing common then all can be
+            # discarded e.g. Min(Max(a, b, c), Max(a, b)) -> Max(a, b)
+            if all(arg_sets_diff):
+                other_args_diff = [other(*s, evaluate=False) for s in arg_sets_diff]
+                new_other_args.append(cls(*other_args_diff, evaluate=False))
+
+            other_args_factored = other(*new_other_args, evaluate=False)
+            return remaining_args + [other_args_factored]
+
+        if len(args) > 1:
+            args = factor_minmax(args)
+
+        return args
+
+    @classmethod
+    def _new_args_filter(cls, arg_sequence):
+        """
+        Generator filtering args.
+
+        first standard filter, for cls.zero and cls.identity.
+        Also reshape ``Max(a, Max(b, c))`` to ``Max(a, b, c)``,
+        and check arguments for comparability
+        """
+        for arg in arg_sequence:
+            # pre-filter, checking comparability of arguments
+            if (
+                not isinstance(arg, Expr)
+                or arg.is_extended_real is False
+                or (arg.is_number and not arg.is_comparable)
+            ):
+                raise ValueError(f"The argument '{arg}' is not comparable.")
+
+            if arg == cls.zero:  # type: ignore[attr-defined]
+                raise ShortCircuit(arg)
+            elif arg == cls.identity:  # type: ignore[attr-defined]
+                continue
+            elif arg.func == cls:
+                yield from arg.args
+            else:
+                yield arg
+
+    @classmethod
+    def _find_localzeros(cls, values, **options):
+        """
+        Sequentially allocate values to localzeros.
+
+        When a value is identified as being more extreme than another member it
+        replaces that member; if this is never true, then the value is simply
+        appended to the localzeros.
+
+        Unlike the sympy implementation, we only look for zero and one; we
+        don't do the generic pairwise is-connected test, which is slow.
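+
+        For example, with these special cases Max(0, x) simplifies to x and
+        Min(0, x) simplifies to 0 when x is known to be nonnegative.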
+        """
+
+        # First, collapse all numeric arguments
+        other_values = set()
+        num_value = None
+        for arg in values:
+            if arg.is_Number:
+                if num_value is None:
+                    num_value = arg
+                else:
+                    if cls is Max:
+                        num_value = max(num_value, arg)
+                    elif cls is Min:
+                        num_value = min(num_value, arg)
+                    else:
+                        raise AssertionError(f"impossible {cls}")
+            else:
+                other_values.add(arg)
+
+        # Special cases when there is only one symbolic value
+        if num_value is None:
+            return other_values
+
+        if len(other_values) == 0:
+            return {num_value}
+
+        if len(other_values) == 1:
+            other_value = next(iter(other_values))
+            if num_value in (0.0, 0) and other_value.is_nonnegative:
+                return other_values if cls is Max else {num_value}
+            if num_value == 1 and other_value.is_positive:
+                return other_values if cls is Max else {num_value}
+
+        other_values.add(num_value)
+        return other_values
+
+    _eval_is_algebraic = lambda s: _torf(i.is_algebraic for i in s.args)  # noqa: E731
+    _eval_is_antihermitian = lambda s: _torf(  # noqa: E731
+        i.is_antihermitian for i in s.args  # noqa: E731
+    )  # noqa: E731
+    _eval_is_commutative = lambda s: _torf(  # noqa: E731
+        i.is_commutative for i in s.args  # noqa: E731
+    )  # noqa: E731
+    _eval_is_complex = lambda s: _torf(i.is_complex for i in s.args)  # noqa: E731
+    _eval_is_composite = lambda s: _torf(i.is_composite for i in s.args)  # noqa: E731
+    _eval_is_even = lambda s: _torf(i.is_even for i in s.args)  # noqa: E731
+    _eval_is_finite = lambda s: _torf(i.is_finite for i in s.args)  # noqa: E731
+    _eval_is_hermitian = lambda s: _torf(i.is_hermitian for i in s.args)  # noqa: E731
+    _eval_is_imaginary = lambda s: _torf(i.is_imaginary for i in s.args)  # noqa: E731
+    _eval_is_infinite = lambda s: _torf(i.is_infinite for i in s.args)  # noqa: E731
+    _eval_is_integer = lambda s: _torf(i.is_integer for i in s.args)  # noqa: E731
+    _eval_is_irrational = lambda s: _torf(i.is_irrational for i in s.args)  # noqa: E731
+    _eval_is_negative = lambda s: _torf(i.is_negative for i in s.args)  # noqa: E731
+    _eval_is_noninteger = lambda s: _torf(i.is_noninteger for i in s.args)  # noqa: E731
+    _eval_is_nonnegative = lambda s: _torf(  # noqa: E731
+        i.is_nonnegative for i in s.args  # noqa: E731
+    )  # noqa: E731
+    _eval_is_nonpositive = lambda s: _torf(  # noqa: E731
+        i.is_nonpositive for i in s.args  # noqa: E731
+    )  # noqa: E731
+    _eval_is_nonzero = lambda s: _torf(i.is_nonzero for i in s.args)  # noqa: E731
+    _eval_is_odd = lambda s: _torf(i.is_odd for i in s.args)  # noqa: E731
+    _eval_is_polar = lambda s: _torf(i.is_polar for i in s.args)  # noqa: E731
+    _eval_is_positive = lambda s: _torf(i.is_positive for i in s.args)  # noqa: E731
+    _eval_is_prime = lambda s: _torf(i.is_prime for i in s.args)  # noqa: E731
+    _eval_is_rational = lambda s: _torf(i.is_rational for i in s.args)  # noqa: E731
+    _eval_is_real = lambda s: _torf(i.is_real for i in s.args)  # noqa: E731
+    _eval_is_extended_real = lambda s: _torf(  # noqa: E731
+        i.is_extended_real for i in s.args  # noqa: E731
+    )  # noqa: E731
+    _eval_is_transcendental = lambda s: _torf(  # noqa: E731
+        i.is_transcendental for i in s.args  # noqa: E731
+    )  # noqa: E731
+    _eval_is_zero = lambda s: _torf(i.is_zero for i in s.args)  # noqa: E731
+
+
+class Max(MinMaxBase, Application):  # type: ignore[misc]
+    r"""
+    Return, if possible, the maximum value of the list.
+    """
+
+    zero = S.Infinity
+    identity = S.NegativeInfinity
+
+    def _eval_is_positive(self):  # type:ignore[override]
+        return fuzzy_or(a.is_positive for a in self.args)  # type: ignore[attr-defined]
+
+    def _eval_is_nonnegative(self):  # type:ignore[override]
+        return fuzzy_or(a.is_nonnegative for a in self.args)  # type: ignore[attr-defined]
+
+    def _eval_is_negative(self):  # type:ignore[override]
+        return fuzzy_and(a.is_negative for a in self.args)
+
+
+class Min(MinMaxBase, Application):  # type: ignore[misc]
+    """
+    Return, if possible, the minimum value of the list.
+    """
+
+    zero = S.NegativeInfinity
+    identity = S.Infinity
+
+    def _eval_is_positive(self):  # type:ignore[override]
+        return fuzzy_and(a.is_positive for a in self.args)  # type: ignore[attr-defined]
+
+    def _eval_is_nonnegative(self):  # type:ignore[override]
+        return fuzzy_and(a.is_nonnegative for a in self.args)  # type: ignore[attr-defined]
+
+    def _eval_is_negative(self):  # type:ignore[override]
+        return fuzzy_or(a.is_negative for a in self.args)
+
+
+def safe_pow(base, exp):
+    sign = 1
+    if base < 0:
+        base = -base
+        sign = 1 if exp % 2 == 0 else -1
+    return sign * _safe_pow(base, exp)
+
+
+# Prevent people from overflowing pow
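+# e.g. (illustrative) safe_pow(-2, 3) == -8, while safe_pow(2, 70) saturates to
+# int_oo instead of returning 2**70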
+def _safe_pow(base, exponent):
+    if exponent < 0:
+        raise ValueError("Exponent must be non-negative.")
+
+    if exponent == 0:
+        return 1
+
+    half_exp = safe_pow(base, exponent // 2)
+    if half_exp is int_oo:
+        return int_oo
+
+    # TODO: microoptimization is to avoid overflowing into arbitrary precision
+    # and detect overflow prior to doing operations
+
+    result = half_exp * half_exp
+    if result > sys.maxsize:
+        return int_oo
+
+    if exponent % 2 == 1:
+        result *= base
+        if result > sys.maxsize:
+            return int_oo
+
+    return result
+
+
+class PowByNatural(sympy.Function):
+    is_integer = True
+
+    precedence: int = 50  # precedence of mul
+
+    @classmethod
+    def eval(cls, base, exp):
+        if isinstance(base, sympy.Integer) and isinstance(exp, sympy.Integer):
+            r = safe_pow(base, exp)
+            if r in (-int_oo, int_oo):
+                return r
+            return sympy.Integer(r)
+        if isinstance(exp, sympy.Integer):
+            # Rely on regular sympy Pow for this (note that iterated
+            # multiplication turns into a Pow anyway, you can't escape!!)
+            return sympy.Pow(base, exp)
+        if exp in (int_oo, sympy.oo):
+            if base.is_nonnegative:
+                return int_oo
+            elif base.is_negative:
+                return sympy.zoo  # this is apparently what (-2)**sympy.oo does
+        # NB: do NOT translate into sympy.Pow, we will lose knowledge that exp
+        # is a natural number if we do
+
+
+# base is assumed to be nonnegative, thereby preventing complex numbers from
+# occurring
+class FloatPow(sympy.Function):
+    is_real = True
+
+    precedence: int = 60  # precedence of pow
+
+    @classmethod
+    def eval(cls, base, exp):
+        # NB: These test sympy.Number, not sympy.Float, because:
+        #   - Sometimes we may have sympy.oo or int_oo, and that's not a Float
+        #     (but coerces to math.Inf)
+        #   - Sometimes Float(0.0) will unpredictably decay to Integer(0),
+        #     but we should still accept it in floatey contexts
+        if isinstance(base, sympy.Number) and isinstance(exp, sympy.Number):
+            return sympy.Float(float(base) ** float(exp))
+        # NB: do not do any nontrivial reasoning
+
+
+# Overloaded to be compatible with regular Python.
+# https://github.com/pytorch/pytorch/issues/90900
+#
+# In particular, sympy division is willing to simplify x/x == 1
+# where 1 is an integer, but this must be a float if x was float.
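+# e.g. (illustrative) FloatTrueDiv(x, x) is left unevaluated here rather than
+# being simplified to the integer 1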
+class FloatTrueDiv(sympy.Function):
+    is_real = True
+
+    precedence: int = 35  # lower precedence than add
+
+    @classmethod
+    def eval(cls, base, divisor):
+        # assert base.is_integer is not True, base
+        # assert divisor.is_integer is not True, divisor
+
+        if divisor.is_zero:
+            raise ZeroDivisionError("division by zero")
+
+        if isinstance(base, sympy.Number) and isinstance(divisor, sympy.Number):
+            return sympy.Float(float(base) / float(divisor))
+
+
+# Overloaded to be compatible with regular Python.  We distinguish this from
+# FloatTrueDiv, because the code generation has to be different for this case:
+# Python has a fancy algorithm for integer true division that isn't just
+# "promote both arguments to float and use float division", so you need to
+# codegen it differently.  While technically you can work it out from the
+# types of the input, this is often inconvenient to do in Inductor codegen,
+# so just have a different operator
+# NB: Right now, Inductor codegen doesn't implement this correctly lol
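+# e.g. (illustrative) IntTrueDiv(1, 2) constant-folds to sympy.Float(0.5)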
+class IntTrueDiv(sympy.Function):
+    is_real = True
+
+    precedence: int = 35  # lower precedence than add
+
+    @classmethod
+    def eval(cls, base, divisor):
+        if divisor.is_zero:
+            raise ZeroDivisionError("division by zero")
+
+        if (
+            isinstance(base, sympy.Number)
+            and isinstance(divisor, sympy.Number)
+            and (
+                base in (int_oo, -int_oo, sympy.oo, -sympy.oo)
+                or divisor in (int_oo, -int_oo, sympy.oo, -sympy.oo)
+            )
+        ):
+            # Don't have to worry about precision here, you're getting zero or
+            # inf from the division
+            return sympy.Float(float(base) / float(divisor))
+        if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer):
+            return sympy.Float(int(base) / int(divisor))
+
+    def _ccode(self, printer):
+        base = printer.parenthesize(self.args[0], PRECEDENCE["Atom"] - 0.5)
+        divisor = printer.parenthesize(self.args[1], PRECEDENCE["Atom"] - 0.5)
+        return f"((int){base}/(int){divisor})"
+
+
+# TODO: As an indicator, this != 0 implies == 1 (and vice versa).
+# Because we do not have the ability to guard on the stride permutation
+# at the moment, it is hard to make further inferences when this is true,
+# as although we know the tensor is contiguous in *some* layout, we don't
+# know which one (however, you could, for example, make the inference that
+# reshaping this to a 1D tensor can be guard-free.)
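+# Example values (illustrative): for sizes (2, 3) the indicator is 1 for
+# contiguous strides (3, 1) and 0 for overlapping strides such as (1, 1).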
+class IsNonOverlappingAndDenseIndicator(sympy.Function):
+    is_integer = True
+
+    @classmethod
+    def eval(cls, *args):
+        assert len(args) % 2 == 0
+        dim = len(args) // 2
+        sizes = args[0:dim]
+        strides = args[dim:]
+
+        # sym_node imported in torch.__init__. Local import to avoid an import cycle
+        from torch.fx.experimental.symbolic_shapes import (
+            eval_is_non_overlapping_and_dense,
+        )
+
+        if all(isinstance(a, sympy.Integer) for a in args):
+            return eval_is_non_overlapping_and_dense(
+                [int(a) for a in sizes], [int(a) for a in strides]
+            )
+
+        if dim == 1:
+            # Manually implement the rank one short circuit
+            if strides[0].is_Number and strides[0] == 1:
+                return 1
+
+            if sizes[0].is_Number and sizes[0] < 2:
+                return 1
+
+            # return 0 case covered by case above
+
+            # TODO: Inability to access size-obliviousness sucks: if we have a
+            # size oblivious test on a size-like unbacked SymInt, we could
+            # confidently return zero when we have a size-like u0 stride
+            # and a size-like u1 size.  Maybe a fancy ValueRanges analysis for
+            # this function could help figure this out.
+
+        if all(isinstance(a, sympy.Integer) for a in strides):
+            assert dim != 0
+            # When all strides are integral, we can sort, and the size for the
+            # largest stride doesn't matter and can be arbitrarily symbolic
+            s_sizes, s_strides = zip(
+                *sorted(zip(sizes, strides), key=operator.itemgetter(1))
+            )
+            # Put something arbitrary in the max size spot, it'll be ignored
+            if all(isinstance(a, sympy.Integer) for a in s_sizes[:-1]):
+                s_sizes = s_sizes[:-1] + (42,)
+                # We can reuse the regular eval, because it is invariant to
+                # permutation of dimensions
+                return eval_is_non_overlapping_and_dense(
+                    [int(a) for a in s_sizes], [int(a) for a in s_strides]
+                )
+
+        return None
+
+
+# NB: this is inconsistent with math.trunc in Python
+class TruncToFloat(sympy.Function):
+    is_real = True
+
+    @classmethod
+    def eval(cls, number):
+        # assert number.is_integer is not True, number
+        if isinstance(number, sympy.Number):
+            # NB: It is safe to use truncation to integer, which is what
+            # math.trunc does, as Python integers are arbitrary precision and
+            # so we are guaranteed not to lose precision when we do this
+            return sympy.Float(math.trunc(float(number)))
+
+
+class TruncToInt(sympy.Function):
+    is_integer = True
+
+    @classmethod
+    def eval(cls, number):
+        # assert number.is_integer is not True, number
+        if number in (sympy.oo, int_oo):
+            return int_oo
+        if number in (-sympy.oo, -int_oo):
+            return -int_oo
+        if isinstance(number, sympy.Number):
+            return sympy.Integer(math.trunc(float(number)))
+
+
+# This is float -> int
+class RoundToInt(sympy.Function):
+    is_integer = True
+
+    @classmethod
+    def eval(cls, number):
+        # assert number.is_integer is not True, number
+
+        if number is sympy.oo:
+            return int_oo
+        if number is -sympy.oo:
+            return -int_oo
+        if isinstance(number, sympy.Number):
+            return sympy.Integer(round(float(number), 0))
+
+
+# To get float -> int, Python style round semantics.
+#
+#   x = PyFloat_AsDouble(self);
+#   if (o_ndigits == Py_None) {
+#       /* single-argument round or with None ndigits:
+#        * round to nearest integer */
+#       rounded = round(x);
+#       if (fabs(x-rounded) == 0.5)
+#           /* halfway case: round to even */
+#           rounded = 2.0*round(x/2.0);
+#       return PyLong_FromDouble(rounded);
+#   }
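+#
+# e.g. under these semantics round(2.5) == 2 and round(3.5) == 4 (half-way
+# cases round to the nearest even integer).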
+
+
+# NB: Like Round, this only ever returns floats.  ndigits cannot be None
+class RoundDecimal(sympy.Function):
+    is_real = True
+
+    @classmethod
+    def eval(cls, number, ndigits):
+        # assert number.is_integer is not True, number
+
+        if isinstance(number, sympy.Number) and isinstance(ndigits, sympy.Integer):
+            return sympy.Float(round(float(number), int(ndigits)))
+
+
+class ToFloat(sympy.Function):
+    is_real = True
+
+    @classmethod
+    def eval(cls, number):
+        if number in [sympy.oo, -sympy.oo]:
+            return number
+
+        if isinstance(number, sympy.Integer):
+            return sympy.Float(int(number))
+        if number is int_oo:
+            return sympy.oo
+        if number is -int_oo:
+            return -sympy.oo
+
+
+class Identity(sympy.Function):
+    """
+    Prevents expansion and other optimizations
+    """
+
+    precedence = 10
+
+    def __repr__(self):  # type: ignore[override]
+        return f"Identity({self.args[0]})"
+
+    def _eval_is_real(self):
+        return self.args[0].is_real
+
+    def _eval_is_integer(self):
+        return self.args[0].is_integer  # type: ignore[attr-defined]
+
+    def _eval_expand_identity(self, **hints):
+        # Removes the identity op.
+        return self.args[0]
+
+    def __int__(self) -> int:
+        return int(self.args[0])
+
+    def __float__(self) -> float:
+        return float(self.args[0])
+
+
+def make_opaque_unary_fn(name):
+    class OpaqueUnaryFn(sympy.Function):
+        """
+        Unlike the builtin sympy functions on real numbers like sympy.sqrt,
+        these equivalents do not do any nontrivial reasoning besides
+        constant propagation.  This helps avoid performing transformations
+        that are valid for real numbers but are invalid for floating point;
+        in particular, while we are willing to make optimizations that change
+        numerics for Tensor compute, we are NOT willing to make optimizations
+        that change numerics for size compute.
+        """
+
+        _torch_handler_name = name
+        _torch_unpickler = make_opaque_unary_fn
+
+        @classmethod
+        def eval(cls, a):
+            if isinstance(a, (sympy.Integer, sympy.Float)):
+                # Python converts to float64 before computing, c.f.
+                # >>> math.sin(2**53+1)
+                # -0.848925964814655
+                # >>> math.sin(float(2**53+1))
+                # -0.848925964814655
+                try:
+                    return sympy.Float(getattr(math, name)(float(a)))
+                # Just use sympy semantics for infinity/overflow, you might get some
+                # weird objects but ask silly questions, get silly answers
+                except OverflowError:
+                    return getattr(sympy, name)(a)
+            elif a in [sympy.oo, -sympy.oo, sympy.zoo, -sympy.zoo, int_oo, -int_oo]:
+                if a is int_oo:
+                    a = sympy.oo
+                if a is -int_oo:
+                    a = -sympy.oo
+                if name == "log2":
+                    return sympy.log(a, 2)
+                return getattr(sympy, name)(a)
+            return None
+
+    nm = "OpaqueUnaryFn_" + name
+    OpaqueUnaryFn.__name__ = nm
+    OpaqueUnaryFn.__qualname__ = nm
+
+    return OpaqueUnaryFn
+
+
+# Keep in sync with math_op_names in torch/fx/experimental/sym_node.py
+OpaqueUnaryFn_sqrt = make_opaque_unary_fn("sqrt")
+OpaqueUnaryFn_cos = make_opaque_unary_fn("cos")
+OpaqueUnaryFn_cosh = make_opaque_unary_fn("cosh")
+OpaqueUnaryFn_sin = make_opaque_unary_fn("sin")
+OpaqueUnaryFn_sinh = make_opaque_unary_fn("sinh")
+OpaqueUnaryFn_tan = make_opaque_unary_fn("tan")
+OpaqueUnaryFn_tanh = make_opaque_unary_fn("tanh")
+OpaqueUnaryFn_asin = make_opaque_unary_fn("asin")
+OpaqueUnaryFn_acos = make_opaque_unary_fn("acos")
+OpaqueUnaryFn_atan = make_opaque_unary_fn("atan")
+OpaqueUnaryFn_exp = make_opaque_unary_fn("exp")
+OpaqueUnaryFn_log = make_opaque_unary_fn("log")
+OpaqueUnaryFn_asinh = make_opaque_unary_fn("asinh")
+OpaqueUnaryFn_log2 = make_opaque_unary_fn("log2")
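+
+# For example (illustrative), OpaqueUnaryFn_sqrt(sympy.Integer(4)) constant-folds
+# to sympy.Float(2.0), while OpaqueUnaryFn_sqrt(x) for a plain symbol x is left
+# unevaluated instead of being rewritten to x**(1/2).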
+
+
+def make_opaque_bitwise_fn(name, real_op_name):
+    if name == "bitwise_and":
+        prec = PRECEDENCE["BitwiseAnd"]
+    elif name == "bitwise_or":
+        prec = PRECEDENCE["BitwiseOr"]
+    else:
+        raise AssertionError(f"unrecognized {name}")
+
+    class BitwiseFn(sympy.Function):
+        _torch_handler_name = name
+        precedence: int = prec
+        _torch_unpickler = functools.partial(
+            make_opaque_bitwise_fn, real_op_name=real_op_name
+        )
+
+        @classmethod
+        def eval(cls, a, b):
+            if a.is_Boolean and b.is_Boolean:
+                return getattr(operator, real_op_name)(a, b)
+            if a.is_Boolean:
+                a = sympy.Integer(1 if a else 0)
+            if b.is_Boolean:
+                b = sympy.Integer(1 if b else 0)
+            if isinstance(a, (sympy.Integer, int)) and isinstance(
+                b, (sympy.Integer, int)
+            ):
+                return sympy.Integer(getattr(operator, real_op_name)(int(a), int(b)))
+            return None
+
+    BitwiseFn.__name__ = "BitwiseFn_" + name
+    return BitwiseFn
+
+
+BitwiseFn_bitwise_and = make_opaque_bitwise_fn("bitwise_and", "and_")
+BitwiseFn_bitwise_or = make_opaque_bitwise_fn("bitwise_or", "or_")
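+
+# e.g. (illustrative) BitwiseFn_bitwise_and(sympy.Integer(6), sympy.Integer(3))
+# evaluates to sympy.Integer(2); boolean args are coerced to 0/1 first.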
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/_sympy/interp.py b/.venv/lib/python3.12/site-packages/torch/utils/_sympy/interp.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b020b5fabbc72ef64aa5a06afacc361602cf1f7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/_sympy/interp.py
@@ -0,0 +1,225 @@
+# mypy: allow-untyped-defs
+"""
+This is a simple interpreter for Sympy expressions that dispatches to
+classes following the torch._inductor.virtualized calling convention.
+For directness, the interpreter takes the handler directly rather than
+consulting the TLS.  It does not use most of the methods on the full
+handler; only those with corresponding Sympy expressions.  To see an example
+of a full handler, see torch.utils._sympy.value_ranges.ValueRangeAnalysis.
+"""
+
+import functools
+import logging
+from typing import Any, Union
+
+import sympy
+from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom
+
+import torch
+
+from .functions import (
+    BitwiseFn_bitwise_and,
+    BitwiseFn_bitwise_or,
+    CeilToInt,
+    CleanDiv,
+    FloatPow,
+    FloatTrueDiv,
+    FloorDiv,
+    FloorToInt,
+    Identity,
+    IntTrueDiv,
+    IsNonOverlappingAndDenseIndicator,
+    Max,
+    Min,
+    Mod,
+    ModularIndexing,
+    OpaqueUnaryFn_log2,
+    PowByNatural,
+    PythonMod,
+    RoundDecimal,
+    RoundToInt,
+    ToFloat,
+    TruncToFloat,
+    TruncToInt,
+    Where,
+)
+
+
+log = logging.getLogger(__name__)
+
+
+# TODO: Dedupe this with SYMPY_INTERP
+
+
+@functools.cache
+def handlers():
+    # TODO add CeilDiv (it doesn't appear in the index_expr)
+
+    # TODO default to some decompositions if the interpreter doesn't have them
+    # like decomposing ModularIndexing or implementing Le(a,b) as Ge(b, a)
+
+    HANDLERS = {
+        sympy.Or: "or_",
+        sympy.And: "and_",
+        sympy.Eq: "eq",
+        sympy.Ne: "ne",
+        sympy.Lt: "lt",
+        sympy.Gt: "gt",
+        sympy.Le: "le",
+        sympy.Ge: "ge",
+        sympy.Not: "not_",
+        IntTrueDiv: "int_truediv",
+        FloatTrueDiv: "truediv",
+        FloorDiv: "floordiv",
+        CleanDiv: "floordiv",  # TODO: hmm?
+        TruncToFloat: "trunc",
+        Where: "where",
+        sympy.Add: "add",
+        sympy.Mul: "mul",
+        FloatPow: "pow",
+        PowByNatural: "pow_by_natural",
+        # sympy simplifies x * x into Pow(x, 2), so we need to handle this.
+        # Do NOT use builtin Pow for floats
+        # TODO: There is a hazard here, if we have float * float it will
+        # also get turned into Pow(float, 2) but we don't want this because
+        # pow_by_natural is assumed to apply only to integers.  Probably the fix is
+        # to add a FloatMul to impede this optimization
+        sympy.Pow: "pow_by_natural",
+        Mod: "mod",
+        PythonMod: "mod",  # TODO: this is wrong
+        # TODO: Inductor can generate these, but it's ill-specified which
+        # semantics were intended here.  Needs to be cleaned up along with
+        # FloorDiv in a bigger cleanup
+        sympy.Mod: "mod",
+        sympy.Abs: "abs",
+        sympy.log: "log",
+        sympy.exp: "exp",
+        sympy.Min: "minimum",
+        sympy.Max: "maximum",
+        Min: "minimum",
+        Max: "maximum",
+        ModularIndexing: "modular_indexing",
+        sympy.functions.elementary.piecewise.ExprCondPair: "expr_cond_pair",
+        sympy.Piecewise: "piecewise",
+        Identity: "identity",
+        IsNonOverlappingAndDenseIndicator: "is_non_overlapping_and_dense_indicator",
+        RoundDecimal: "round_decimal",
+        # TODO: do the rest of the opaque unary functions...
+        OpaqueUnaryFn_log2: "log2",
+        BitwiseFn_bitwise_and: "bitwise_and",
+        BitwiseFn_bitwise_or: "bitwise_or",
+    }
+    # TODO: This is kind of pointless, we shouldn't be generating sympy.sin
+    # for these functions, they should be Opaque instead
+    for name in ["cos", "sin", "tan", "sinh", "cosh", "tanh", "asin", "acos", "atan"]:
+        HANDLERS[getattr(sympy, name)] = name
+
+    return HANDLERS
+
+
+ASSOCIATIVE_OPS = {"minimum", "maximum", "mul", "add", "and_", "or_"}
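+# Handlers named here are folded pairwise, left to right, in _run_sympy_handler
+# below; e.g. an n-ary sympy.Add(a, b, c) is dispatched as add(add(a, b), c).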
+
+
+def _run_sympy_handler(analysis, args, expr, index_dtype=torch.int64):
+    # Special cases
+    if isinstance(expr, sympy.Pow) and isinstance(
+        expr.args[1], sympy.core.numbers.Half
+    ):
+        return analysis.sqrt(args[0])
+    if isinstance(expr, ToFloat):
+        return analysis.to_dtype(args[0], torch.float64)
+
+    # These handlers are special because they take an extra dtype argument
+    # specifying what they should convert to, and we need to appropriately set
+    # this up when we convert from Sympy.  A reasonable default when you
+    # are translating is to conservatively do int64, and then narrow these
+    # arguments later when you discover you can narrow the index range.  But
+    # if you already know that 32-bit indexing is OK, you can directly do the
+    # sympy translation with index_dtype=torch.int32
+    INDEX_DTYPE_HANDLERS = {
+        TruncToInt: "trunc_to_int",
+        sympy.floor: "floor_to_int",
+        sympy.ceiling: "ceil_to_int",
+        FloorToInt: "floor_to_int",
+        CeilToInt: "ceil_to_int",
+        RoundToInt: "round_to_int",
+    }
+    if (handler_name := INDEX_DTYPE_HANDLERS.get(expr.func)) is not None:
+        return getattr(analysis, handler_name)(*args, index_dtype)
+
+    # Fastpath for n-ary integral addition
+    if expr.func is sympy.Add and expr.is_integer and hasattr(analysis, "sym_sum"):
+        r = analysis.sym_sum(args)
+        log.debug("sym_sum(%s) -> %s", args, r)
+        return r
+
+    if hasattr(expr.func, "_torch_handler_name"):
+        handler_name = expr.func._torch_handler_name
+    else:
+        handler_name = handlers()[expr.func]
+    handler = getattr(analysis, handler_name)
+    try:
+        if handler_name in ASSOCIATIVE_OPS:
+            assert len(args) > 1
+            acc = handler(args[0], args[1])
+            for i in range(2, len(args)):
+                acc = handler(acc, args[i])
+            log.debug("%s(%s) -> %s", handler_name, args, acc)
+            return acc
+        else:
+            r = handler(*args)
+            log.debug("%s(%s) -> %s", handler_name, args, r)
+            return r
+    except NotImplementedError:
+        raise
+    except Exception:
+        log.warning("failed while executing %s(%s)", handler_name, args)
+        raise
+
+
+_nil = object()
+
+
+def sympy_interp(
+    analysis,
+    env: dict[sympy.Symbol, Any],
+    expr: Union[sympy.Expr, SympyBoolean],
+    *,
+    index_dtype=torch.int64,
+    missing_handler=None,
+):
+    # Handle base cases
+    dtype = None
+    if isinstance(expr, BooleanAtom):
+        dtype = torch.bool
+    elif isinstance(expr, sympy.Integer):
+        dtype = torch.int64
+    elif isinstance(expr, sympy.Number):
+        dtype = torch.double
+
+    if dtype is not None:
+        return analysis.constant(expr, dtype)
+    elif isinstance(expr, sympy.Symbol):
+        if (r := env.get(expr, _nil)) is not _nil:
+            return r
+        elif missing_handler:
+            return missing_handler(expr)
+        else:
+            raise KeyError(expr)
+
+    # Recursive case
+    return _run_sympy_handler(
+        analysis,
+        [
+            sympy_interp(
+                analysis,
+                env,
+                arg,
+                index_dtype=index_dtype,
+                missing_handler=missing_handler,
+            )
+            for arg in expr.args
+        ],  # type: ignore[arg-type]
+        expr,
+        index_dtype=index_dtype,
+    )  # type: ignore[arg-type]
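+
+
+# Minimal usage sketch (illustrative only; the toy analysis below is hypothetical,
+# not part of this module).  Any object implementing the handler methods named in
+# handlers() can be passed as `analysis`, e.g.:
+#
+#   class ToyAnalysis:
+#       @staticmethod
+#       def constant(c, dtype):
+#           return c
+#
+#       @staticmethod
+#       def add(a, b):
+#           return a + b
+#
+#   x, y = sympy.symbols("x y")
+#   sympy_interp(ToyAnalysis, {x: sympy.Integer(3), y: sympy.Integer(4)}, x + y)
+#   # -> Integer(7)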
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/_sympy/numbers.py b/.venv/lib/python3.12/site-packages/torch/utils/_sympy/numbers.py
new file mode 100644
index 0000000000000000000000000000000000000000..d02b9879cad2373d46c9aa61e0476907a44e6be3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/_sympy/numbers.py
@@ -0,0 +1,397 @@
+# mypy: allow-untyped-defs
+import mpmath.libmp as mlib  # type: ignore[import-untyped]
+import sympy
+from sympy import Expr
+from sympy.core.decorators import _sympifyit
+from sympy.core.expr import AtomicExpr
+from sympy.core.numbers import Number
+from sympy.core.parameters import global_parameters
+from sympy.core.singleton import S, Singleton
+
+
+class IntInfinity(Number, metaclass=Singleton):
+    r"""Positive integer infinite quantity.
+
+    Integer infinity is a value in the extended integers which
+    is greater than all other integers.  We distinguish it from
+    sympy's existing notion of infinity in that it reports that
+    it is_integer.
+
+    IntInfinity is a singleton, and can be accessed by ``S.IntInfinity``,
+    or can be imported as ``int_oo``.
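+
+    By convention ``int_oo`` compares strictly below ``sympy.oo`` (see the
+    comparison methods below), and indeterminate forms such as
+    ``int_oo - int_oo`` give ``S.NaN``.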
+    """
+
+    # NB: We can't actually mark this as infinite, as integer and infinite are
+    # inconsistent assumptions in sympy.  We also report that we are complex,
+    # different from sympy.oo
+
+    is_integer = True
+    is_commutative = True
+    is_number = True
+    is_extended_real = True
+    is_comparable = True
+    is_extended_positive = True
+    is_prime = False
+
+    # Ensure we get dispatched to before plain numbers
+    _op_priority = 100.0
+
+    __slots__ = ()
+
+    def __new__(cls):
+        return AtomicExpr.__new__(cls)
+
+    def _sympystr(self, printer):
+        return "int_oo"
+
+    def _eval_subs(self, old, new):
+        if self == old:
+            return new
+
+    # We could do these, not sure about it
+    """
+    def _eval_evalf(self, prec=None):
+        return Float('inf')
+
+    def evalf(self, prec=None, **options):
+        return self._eval_evalf(prec)
+    """
+
+    @_sympifyit("other", NotImplemented)
+    def __add__(self, other):
+        if isinstance(other, Number) and global_parameters.evaluate:
+            if other in (S.Infinity, S.NegativeInfinity):
+                return other
+            if other in (S.NegativeIntInfinity, S.NaN):
+                return S.NaN
+            return self
+        return Number.__add__(self, other)
+
+    __radd__ = __add__
+
+    @_sympifyit("other", NotImplemented)
+    def __sub__(self, other):
+        if isinstance(other, Number) and global_parameters.evaluate:
+            if other is S.Infinity:
+                return S.NegativeInfinity
+            if other is S.NegativeInfinity:
+                return S.Infinity
+            if other in (S.IntInfinity, S.NaN):
+                return S.NaN
+            return self
+        return Number.__sub__(self, other)
+
+    @_sympifyit("other", NotImplemented)
+    def __rsub__(self, other):
+        return (-self).__add__(other)
+
+    @_sympifyit("other", NotImplemented)
+    def __mul__(self, other):
+        if isinstance(other, Number) and global_parameters.evaluate:
+            if other.is_zero or other is S.NaN:
+                return S.NaN
+            if other.is_extended_positive:
+                return self
+            return S.NegativeIntInfinity
+        return Number.__mul__(self, other)
+
+    __rmul__ = __mul__
+
+    @_sympifyit("other", NotImplemented)
+    def __truediv__(self, other):
+        if isinstance(other, Number) and global_parameters.evaluate:
+            if other in (
+                S.Infinity,
+                S.IntInfinity,
+                S.NegativeInfinity,
+                S.NegativeIntInfinity,
+                S.NaN,
+            ):
+                return S.NaN
+            if other.is_extended_nonnegative:
+                return S.Infinity  # truediv produces float
+            return S.NegativeInfinity  # truediv produces float
+        return Number.__truediv__(self, other)
+
+    def __abs__(self):
+        return S.IntInfinity
+
+    def __neg__(self):
+        return S.NegativeIntInfinity
+
+    def _eval_power(self, expt):
+        if expt.is_extended_positive:
+            return S.IntInfinity
+        if expt.is_extended_negative:
+            return S.Zero
+        if expt is S.NaN:
+            return S.NaN
+        if expt is S.ComplexInfinity:
+            return S.NaN
+        if expt.is_extended_real is False and expt.is_number:
+            from sympy.functions.elementary.complexes import re
+
+            expt_real = re(expt)
+            if expt_real.is_positive:
+                return S.ComplexInfinity
+            if expt_real.is_negative:
+                return S.Zero
+            if expt_real.is_zero:
+                return S.NaN
+
+            return self ** expt.evalf()
+
+    def _as_mpf_val(self, prec):
+        return mlib.finf
+
+    def __hash__(self):
+        return super().__hash__()
+
+    def __eq__(self, other):
+        return other is S.IntInfinity
+
+    def __ne__(self, other):
+        return other is not S.IntInfinity
+
+    def __gt__(self, other):
+        if other is S.Infinity:
+            return sympy.false  # sympy.oo > int_oo
+        elif other is S.IntInfinity:
+            return sympy.false  # consistency with sympy.oo
+        else:
+            return sympy.true
+
+    def __ge__(self, other):
+        if other is S.Infinity:
+            return sympy.false  # sympy.oo > int_oo
+        elif other is S.IntInfinity:
+            return sympy.true  # consistency with sympy.oo
+        else:
+            return sympy.true
+
+    def __lt__(self, other):
+        if other is S.Infinity:
+            return sympy.true  # sympy.oo > int_oo
+        elif other is S.IntInfinity:
+            return sympy.false  # consistency with sympy.oo
+        else:
+            return sympy.false
+
+    def __le__(self, other):
+        if other is S.Infinity:
+            return sympy.true  # sympy.oo > int_oo
+        elif other is S.IntInfinity:
+            return sympy.true  # consistency with sympy.oo
+        else:
+            return sympy.false
+
+    @_sympifyit("other", NotImplemented)
+    def __mod__(self, other):
+        if not isinstance(other, Expr):
+            return NotImplemented
+        return S.NaN
+
+    __rmod__ = __mod__
+
+    def floor(self):
+        return self
+
+    def ceiling(self):
+        return self
+
+
+int_oo = S.IntInfinity
+
+
+class NegativeIntInfinity(Number, metaclass=Singleton):
+    """Negative integer infinite quantity.
+
+    NegativeIntInfinity is a singleton, and can be accessed
+    by ``S.NegativeIntInfinity``.
+
+    See Also
+    ========
+
+    IntInfinity
+    """
+
+    # Ensure we get dispatched to before plain numbers
+    _op_priority = 100.0
+
+    is_integer = True
+    is_extended_real = True
+    is_commutative = True
+    is_comparable = True
+    is_extended_negative = True
+    is_number = True
+    is_prime = False
+
+    __slots__ = ()
+
+    def __new__(cls):
+        return AtomicExpr.__new__(cls)
+
+    def _eval_subs(self, old, new):
+        if self == old:
+            return new
+
+    def _sympystr(self, printer):
+        return "-int_oo"
+
+    """
+    def _eval_evalf(self, prec=None):
+        return Float('-inf')
+
+    def evalf(self, prec=None, **options):
+        return self._eval_evalf(prec)
+    """
+
+    @_sympifyit("other", NotImplemented)
+    def __add__(self, other):
+        if isinstance(other, Number) and global_parameters.evaluate:
+            if other is S.Infinity:
+                return S.Infinity
+            if other in (S.IntInfinity, S.NaN):
+                return S.NaN
+            return self
+        return Number.__add__(self, other)
+
+    __radd__ = __add__
+
+    @_sympifyit("other", NotImplemented)
+    def __sub__(self, other):
+        if isinstance(other, Number) and global_parameters.evaluate:
+            if other is S.NegativeInfinity:
+                return S.Infinity
+            if other in (S.NegativeIntInfinity, S.NaN):
+                return S.NaN
+            return self
+        return Number.__sub__(self, other)
+
+    @_sympifyit("other", NotImplemented)
+    def __rsub__(self, other):
+        return (-self).__add__(other)
+
+    @_sympifyit("other", NotImplemented)
+    def __mul__(self, other):
+        if isinstance(other, Number) and global_parameters.evaluate:
+            if other.is_zero or other is S.NaN:
+                return S.NaN
+            if other.is_extended_positive:
+                return self
+            return S.IntInfinity
+        return Number.__mul__(self, other)
+
+    __rmul__ = __mul__
+
+    @_sympifyit("other", NotImplemented)
+    def __truediv__(self, other):
+        if isinstance(other, Number) and global_parameters.evaluate:
+            if other in (
+                S.Infinity,
+                S.IntInfinity,
+                S.NegativeInfinity,
+                S.NegativeIntInfinity,
+                S.NaN,
+            ):
+                return S.NaN
+            if other.is_extended_nonnegative:
+                return self
+            return S.Infinity  # truediv returns float
+        return Number.__truediv__(self, other)
+
+    def __abs__(self):
+        return S.IntInfinity
+
+    def __neg__(self):
+        return S.IntInfinity
+
+    def _eval_power(self, expt):
+        if expt.is_number:
+            if expt in (
+                S.NaN,
+                S.Infinity,
+                S.NegativeInfinity,
+                S.IntInfinity,
+                S.NegativeIntInfinity,
+            ):
+                return S.NaN
+
+            if isinstance(expt, sympy.Integer) and expt.is_extended_positive:
+                if expt.is_odd:
+                    return S.NegativeIntInfinity
+                else:
+                    return S.IntInfinity
+
+            inf_part = S.IntInfinity**expt
+            s_part = S.NegativeOne**expt
+            if inf_part == 0 and s_part.is_finite:
+                return inf_part
+            if (
+                inf_part is S.ComplexInfinity
+                and s_part.is_finite
+                and not s_part.is_zero
+            ):
+                return S.ComplexInfinity
+            return s_part * inf_part
+
+    def _as_mpf_val(self, prec):
+        return mlib.fninf
+
+    def __hash__(self):
+        return super().__hash__()
+
+    def __eq__(self, other):
+        return other is S.NegativeIntInfinity
+
+    def __ne__(self, other):
+        return other is not S.NegativeIntInfinity
+
+    def __gt__(self, other):
+        if other is S.NegativeInfinity:
+            return sympy.true  # -sympy.oo < -int_oo
+        elif other is S.NegativeIntInfinity:
+            return sympy.false  # consistency with sympy.oo
+        else:
+            return sympy.false
+
+    def __ge__(self, other):
+        if other is S.NegativeInfinity:
+            return sympy.true  # -sympy.oo < -int_oo
+        elif other is S.NegativeIntInfinity:
+            return sympy.true  # consistency with sympy.oo
+        else:
+            return sympy.false
+
+    def __lt__(self, other):
+        if other is S.NegativeInfinity:
+            return sympy.false  # -sympy.oo < -int_oo
+        elif other is S.NegativeIntInfinity:
+            return sympy.false  # consistency with sympy.oo
+        else:
+            return sympy.true
+
+    def __le__(self, other):
+        if other is S.NegativeInfinity:
+            return sympy.false  # -sympy.oo < -int_oo
+        elif other is S.NegativeIntInfinity:
+            return sympy.true  # consistency with sympy.oo
+        else:
+            return sympy.true
+
+    @_sympifyit("other", NotImplemented)
+    def __mod__(self, other):
+        if not isinstance(other, Expr):
+            return NotImplemented
+        return S.NaN
+
+    __rmod__ = __mod__
+
+    def floor(self):
+        return self
+
+    def ceiling(self):
+        return self
+
+    def as_powers_dict(self):
+        return {S.NegativeOne: 1, S.IntInfinity: 1}
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/_sympy/printers.py b/.venv/lib/python3.12/site-packages/torch/utils/_sympy/printers.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9c6b8b0e93393bb21fa891badbdbabeda5c7867
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/_sympy/printers.py
@@ -0,0 +1,506 @@
+import sys
+from typing import Optional
+
+import sympy
+from sympy.printing.precedence import PRECEDENCE, precedence
+from sympy.printing.str import StrPrinter
+
+
+INDEX_TYPE = "int64_t"
+INDEX_TYPE_MAX = (1 << 63) - 1
+INDEX_TYPE_MIN = -1 << 63
+
+
+# This printer contains rules that are supposed to be generic for both C/C++ and
+# Python
+class ExprPrinter(StrPrinter):
+    # override this so that _print_FloorDiv is used
+    printmethod = "_torch_sympystr"
+
+    def _print_Mul(self, expr: sympy.Expr) -> str:
+        return self.stringify(expr.args, "*", precedence(expr))
+
+    def _print_Add(self, expr: sympy.Expr, order: Optional[str] = None) -> str:
+        return self.stringify(expr.args, " + ", precedence(expr))
+
+    def _print_Relational(self, expr: sympy.Expr) -> str:
+        return self.stringify(expr.args, f" {expr.rel_op} ", precedence(expr))
+
+    def _print_BitwiseFn_bitwise_and(self, expr: sympy.Expr) -> str:
+        return self.stringify(expr.args, " & ", PRECEDENCE["BitwiseAnd"])
+
+    def _print_BitwiseFn_bitwise_or(self, expr: sympy.Expr) -> str:
+        return self.stringify(expr.args, " | ", PRECEDENCE["BitwiseOr"])
+
+    # NB: this is OK to put here, because Mod is only defined for positive
+    # numbers, and so across C/Python its behavior is consistent
+    def _print_Mod(self, expr: sympy.Expr) -> str:
+        return self.stringify(expr.args, " % ", PRECEDENCE["Atom"] - 0.5)
+
+    def _print_FloatTrueDiv(self, expr: sympy.Expr) -> str:
+        s = self.stringify(expr.args, " / ", PRECEDENCE["Atom"] - 0.5)
+        return f"({s})"
+
+    def _print_CleanDiv(self, expr: sympy.Expr) -> str:
+        return self._print_FloorDiv(expr)
+
+    def _print_Identity(self, expr: sympy.Expr) -> str:
+        return self._print(expr.args[0])
+
+    def _print_Float(self, expr: sympy.Expr) -> str:
+        if expr._prec == 53:
+            # IEEE-754 double precision has 53 bits. SymPy prints these with
+            # 15 digits, but we need 17 for round-trip correctness
+            return str(sympy.Float(expr, dps=17))
+        else:
+            # We don't use other precisions in pytorch
+            return str(expr)
+
+    # This must be implemented because sympy will collect x * x into Pow(x, 2), without
+    # any explicit intervention.  We print it just like x * x; notably, we
+    # never generate sympy.Pow with floats.
+    #
+    # NB: this is pow by natural; you should never have used builtin sympy.Pow
+    # for FloatPow, and a symbolic exponent should be PowByNatural.  This
+    # means exp is guaranteed to be an integer.
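+    #
+    # e.g. (illustrative) Pow(x, 3) is printed as "x*x*x" and Pow(x, 0) as "1".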
+    def _print_Pow(self, expr: sympy.Expr) -> str:
+        base, exp = expr.args
+        assert exp == int(exp), exp
+        exp = int(exp)
+        assert exp >= 0
+        if exp > 0:
+            return self.stringify([base] * exp, "*", PRECEDENCE["Mul"])
+        return "1"
+
+    # Explicit NotImplemented functions are to prevent default sympy printing
+    # behavior, which will just barf out ToFloat(...) to your IR.  The error
+    # message is better here because it tells you which printer class it needs
+    # to go in.
+
+    def _print_ToFloat(self, expr: sympy.Expr) -> str:
+        raise NotImplementedError(f"_print_ToFloat not implemented for {type(self)}")
+
+    def _print_Infinity(self, expr: sympy.Expr) -> str:
+        raise NotImplementedError(f"_print_Infinity not implemented for {type(self)}")
+
+    def _print_NegativeInfinity(self, expr: sympy.Expr) -> str:
+        raise NotImplementedError(
+            f"_print_NegativeInfinity not implemented for {type(self)}"
+        )
+
+    def _print_FloorDiv(self, expr: sympy.Expr) -> str:
+        raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}")
+
+    def _print_PythonMod(self, expr: sympy.Expr) -> str:
+        raise NotImplementedError(f"_print_PythonMod not implemented for {type(self)}")
+
+    def _print_IntTrueDiv(self, expr: sympy.Expr) -> str:
+        raise NotImplementedError(f"_print_IntTrueDiv not implemented for {type(self)}")
+
+    def _print_PowByNatural(self, expr: sympy.Expr) -> str:
+        raise NotImplementedError(
+            f"_print_PowByNatural not implemented for {type(self)}"
+        )
+
+    def _print_FloatPow(self, expr: sympy.Expr) -> str:
+        raise NotImplementedError(f"_print_FloatPow not implemented for {type(self)}")
+
+    def _print_TruncToInt(self, expr: sympy.Expr) -> str:
+        raise NotImplementedError(f"_print_TruncToInt not implemented for {type(self)}")
+
+    def _print_RoundToInt(self, expr: sympy.Expr) -> str:
+        raise NotImplementedError(f"_print_RoundToInt not implemented for {type(self)}")
+
+    def _print_RoundDecimal(self, expr: sympy.Expr) -> str:
+        raise NotImplementedError(
+            f"_print_RoundDecimal not implemented for {type(self)}"
+        )
+
+    # NB: Some float operations are INTENTIONALLY not implemented for
+    # printers.  You can implement them as a quick unblock, but it is better
+    # to ask yourself why we haven't done this computation in the Tensor
+    # universe instead
+
+    def _print_TruncToFloat(self, expr: sympy.Expr) -> str:
+        raise NotImplementedError(
+            f"_print_TruncToFloat not implemented for {type(self)}"
+        )
+
+
+class PythonPrinter(ExprPrinter):
+    def _print_ToFloat(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        # NB: We use sym_float here because the printer is used for cache
+        # serialization, and cache guards get evaluated with SymInt to
+        # propagate guards to the parent ShapeEnv.  However, this comes at a
+        # runtime cost for guards involving float.  If this is unacceptable
+        # overhead, what you want to do is have two separate printers for
+        # SymInt, one for when the inputs are guaranteed to be int, and
+        # another for when they could be SymInt.
+        #
+        # NB: sym_min/sym_max also have this problem, but I chose not to fix
+        # those.
+        #
+        # See https://github.com/pytorch/pytorch/issues/142507 for more
+        # context.
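+        #
+        # E.g. ToFloat(s0) is printed as "torch.sym_float(s0)" (assuming s0 is
+        # a plain sympy Symbol).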
+        return f"torch.sym_float({self._print(expr.args[0])})"
+
+    def _print_And(self, expr: sympy.Expr) -> str:
+        return self.stringify(expr.args, " and ", precedence(expr))
+
+    def _print_Or(self, expr: sympy.Expr) -> str:
+        return self.stringify(expr.args, " or ", precedence(expr))
+
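+    # ModularIndexing(x, div, mod) prints as floor-divide-then-mod; e.g.
+    # (illustratively) ModularIndexing(i0, 4, 8) -> "((i0 // 4) % 8)".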
+    def _print_ModularIndexing(self, expr: sympy.Expr) -> str:
+        x, div, mod = (
+            self.parenthesize(arg, PRECEDENCE["Atom"] - 0.5) for arg in expr.args
+        )
+        if div != "1":
+            x = f"({x} // {div})"
+        return f"({x} % {mod})"
+
+    def _print_Infinity(self, expr: sympy.Expr) -> str:
+        return "math.inf"
+
+    def _print_NegativeInfinity(self, expr: sympy.Expr) -> str:
+        return "-math.inf"
+
+    # WARNING: this is dangerous for Triton, which has C-style modulus
+    def _print_PythonMod(self, expr: sympy.Expr) -> str:
+        return self.stringify(expr.args, " % ", PRECEDENCE["Atom"] - 0.5)
+
+    # WARNING: this is dangerous for Triton, which has C-style modulus
+    def _print_FloorDiv(self, expr: sympy.Expr) -> str:
+        x, div = (self.parenthesize(arg, PRECEDENCE["Atom"] - 0.5) for arg in expr.args)
+        return f"{x} // {div}"
+
+    # WARNING: this is dangerous for Triton; when lhs, rhs > 2**53, Python
+    # uses a special division algorithm
+    def _print_IntTrueDiv(self, expr: sympy.Expr) -> str:
+        return self.stringify(expr.args, " / ", PRECEDENCE["Atom"] - 0.5)
+
+    def _helper_sqrt(self, expr: sympy.Expr) -> str:
+        return f"math.sqrt({self._print(expr)})"
+
+    def _print_OpaqueUnaryFn_sqrt(self, expr: sympy.Expr) -> str:
+        return self._helper_sqrt(expr.args[0])
+
+    def _print_FloatPow(self, expr: sympy.Expr) -> str:
+        return self.stringify(expr.args, " ** ", PRECEDENCE["Pow"])
+
+    # TODO: Not sure this works with Triton, even when base/exp are integral
+    def _print_PowByNatural(self, expr: sympy.Expr) -> str:
+        return self.stringify(expr.args, " ** ", PRECEDENCE["Pow"])
+
+    def _print_floor(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        return f"math.floor({self._print(expr.args[0])})"
+
+    def _print_FloorToInt(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        return f"math.floor({self._print(expr.args[0])})"
+
+    def _print_TruncToInt(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        # This also could have been int(), they'll do the same thing for float
+        return f"math.trunc({self._print(expr.args[0])})"
+
+    def _print_ceiling(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        return f"math.ceil({self._print(expr.args[0])})"
+
+    def _print_CeilToInt(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        return f"math.ceil({self._print(expr.args[0])})"
+
+    def _print_Abs(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        return f"abs({self._print(expr.args[0])})"
+
+    # NB: It's expected that we've made explicit any promotion in the sympy
+    # expression, so it doesn't matter that Python max/min doesn't perform
+    # promotion
+    def _print_Max(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) >= 2
+        return f"max({', '.join(map(self._print, expr.args))})"
+
+    def _print_Min(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) >= 2
+        return f"min({', '.join(map(self._print, expr.args))})"
+
+    def _print_OpaqueUnaryFn_cos(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        return f"math.cos({self._print(expr.args[0])})"
+
+    def _print_OpaqueUnaryFn_cosh(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        return f"math.cosh({self._print(expr.args[0])})"
+
+    def _print_OpaqueUnaryFn_acos(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        return f"math.acos({self._print(expr.args[0])})"
+
+    def _print_OpaqueUnaryFn_sin(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        return f"math.sin({self._print(expr.args[0])})"
+
+    def _print_OpaqueUnaryFn_sinh(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        return f"math.sinh({self._print(expr.args[0])})"
+
+    def _print_OpaqueUnaryFn_asin(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        return f"math.asin({self._print(expr.args[0])})"
+
+    def _print_OpaqueUnaryFn_tan(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        return f"math.tan({self._print(expr.args[0])})"
+
+    def _print_OpaqueUnaryFn_tanh(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        return f"math.tanh({self._print(expr.args[0])})"
+
+    def _print_OpaqueUnaryFn_atan(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        return f"math.atan({self._print(expr.args[0])})"
+
+    def _print_OpaqueUnaryFn_log2(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        return f"math.log2({self._print(expr.args[0])})"
+
+    def _print_RoundToInt(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        return f"round({self._print(expr.args[0])})"
+
+    def _print_RoundDecimal(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 2
+        number, ndigits = expr.args
+        assert isinstance(ndigits, sympy.Integer)
+        return f"round({self._print(number)}, {ndigits})"
+
+
+class CppPrinter(ExprPrinter):
+    def _print_Integer(self, expr: sympy.Expr) -> str:
+        suffix = "LL" if sys.platform in ["darwin", "win32"] else "L"
+        i = int(expr)
+        if i > INDEX_TYPE_MAX or i < INDEX_TYPE_MIN:
+            raise OverflowError(f"{i} too big to convert to {INDEX_TYPE}")
+        elif i == INDEX_TYPE_MIN:
+            assert i == (-1) << 63
+            # Writing -9223372036854775808L makes the value overflow
+            # as it is parsed as -(9223372036854775808L) by the C/C++ compiler
+            return f"(-1{suffix} << 63)"
+        return f"{i}{suffix}"
+
+    def _print_Where(self, expr: sympy.Expr) -> str:
+        c, p, q = (
+            self.parenthesize(arg, PRECEDENCE["Atom"] - 0.5) for arg in expr.args
+        )
+        return f"{c} ? {p} : {q}"
+
+    def _print_ModularIndexing(self, expr: sympy.Expr) -> str:
+        x, div, mod = expr.args
+        x = self.doprint(x)
+        if div != 1:
+            div = self.doprint(div)
+            if expr.is_integer:
+                x = f"c10::div_floor_integer(static_cast<int64_t>({x}), static_cast<int64_t>({div}))"
+            else:
+                x = f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))"
+        mod = self.doprint(mod)
+        return f"(static_cast<{INDEX_TYPE}>({x}) % static_cast<{INDEX_TYPE}>({mod}))"
+
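+    # FloorDiv lowers to the c10 floor-division helpers; e.g. an integer
+    # FloorDiv(a, b) becomes (roughly)
+    # "c10::div_floor_integer(static_cast<int64_t>(a), static_cast<int64_t>(b))".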
+    def _print_FloorDiv(self, expr: sympy.Expr) -> str:
+        x, div = expr.args
+        x = self.doprint(x)
+        div = self.doprint(div)
+        if expr.is_integer:
+            return f"c10::div_floor_integer(static_cast<int64_t>({x}), static_cast<int64_t>({div}))"
+        return f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))"
+
+    def _print_floor(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        r = f"std::floor({self._print(expr.args[0])})"
+        return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
+
+    def _print_FloorToInt(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        r = f"std::floor({self._print(expr.args[0])})"
+        return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
+
+    def _print_TruncToInt(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        r = f"std::trunc({self._print(expr.args[0])})"
+        return f"static_cast<{INDEX_TYPE}>({r})"
+
+    def _print_TruncToFloat(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        return f"std::trunc({self._print(expr.args[0])})"
+
+    def _print_ToFloat(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        return f"static_cast<double>({self._print(expr.args[0])})"
+
+    def _print_PythonMod(self, expr: sympy.Expr) -> str:
+        x, div = expr.args
+        x = self.doprint(x)
+        div = self.doprint(div)
+        return f"c10::div_mod({x}, {div})"
+
+    def _print_IntTrueDiv(self, expr: sympy.Expr) -> str:
+        lhs, rhs = expr.args
+        # TODO: This is only accurate up to 2**53
+        return f"static_cast<double>({self._print(lhs)}) / static_cast<double>({self._print(rhs)})"
+
+    # TODO: PowByNatural: we need to implement our own int-int pow.  Do NOT
+    # use std::pow, that operates on floats
+    def _print_PowByNatural(self, expr: sympy.Expr) -> str:
+        # Implement the special-case of 2**x for now
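+        # e.g. PowByNatural(2, s0) prints as "(1 << (s0))"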
+        base, exp = expr.args
+        if base == 2:
+            return f"(1 << ({self._print(exp)}))"
+        raise NotImplementedError(
+            f"_print_PowByNatural not implemented for {type(self)}"
+        )
+
+    def _print_FloatPow(self, expr: sympy.Expr) -> str:
+        base, exp = expr.args
+        return f"std::pow({self._print(base)}, {self._print(exp)})"
+
+    def _print_Pow(self, expr: sympy.Expr) -> str:
+        # Uses float constants to perform FP div
+        base, exp = expr.args
+
+        if exp == 0.5 or exp == -0.5:
+            base = self._print(base)
+            return f"std::sqrt({base})" if exp == 0.5 else f"1.0/std::sqrt({base})"
+        if exp.is_integer:
+            exp = int(exp)
+            if exp > 0:
+                r = self.stringify([base] * exp, "*", PRECEDENCE["Mul"])
+            elif exp < -1:
+                r = (
+                    "1.0/("
+                    + self.stringify([base] * abs(exp), "*", PRECEDENCE["Mul"])
+                    + ")"
+                )
+            elif exp == -1:
+                r = "1.0/" + self._print(base)
+            else:  # exp == 0
+                r = "1.0"
+
+            return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
+        else:
+            # TODO: float vs double
+            return f"std::pow({base}, {float(exp)})"
+
+    def _print_Rational(self, expr: sympy.Expr) -> str:
+        # Uses float constants to perform FP div
+        if expr.q == 1:
+            r = f"{expr.p}"
+        else:
+            r = f"{expr.p}.0/{expr.q}.0"
+        return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
+
+    def _print_ceiling(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        r = f"std::ceil({self._print(expr.args[0])})"
+        return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
+
+    def _print_CeilToInt(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        r = f"std::ceil({self._print(expr.args[0])})"
+        return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
+
+    def _print_Min(self, expr: sympy.Expr) -> str:
+        args = [self._print(a) for a in expr.args]
+        if len(args) == 2:
+            return f"std::min(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))"
+        else:
+            # Initializer list overload
+            il = "{" + ", ".join(args) + "}"
+            return f"std::min<{INDEX_TYPE}>({il})"
+
+    def _print_Max(self, expr: sympy.Expr) -> str:
+        args = [self._print(a) for a in expr.args]
+        if len(args) == 2:
+            return f"std::max(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))"
+        else:
+            # Initializer list overload
+            il = "{" + ", ".join(args) + "}"
+            return f"std::max<{INDEX_TYPE}>({il})"
+
+    def _print_Abs(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        return f"std::abs({self._print(expr.args[0])})"
+
+    def _print_OpaqueUnaryFn_cos(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        return f"std::cos({self._print(expr.args[0])})"
+
+    def _print_OpaqueUnaryFn_cosh(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        return f"std::cosh({self._print(expr.args[0])})"
+
+    def _print_OpaqueUnaryFn_acos(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        return f"std::acos({self._print(expr.args[0])})"
+
+    def _print_OpaqueUnaryFn_sin(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        return f"std::sin({self._print(expr.args[0])})"
+
+    def _print_OpaqueUnaryFn_sinh(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        return f"std::sinh({self._print(expr.args[0])})"
+
+    def _print_OpaqueUnaryFn_asin(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        return f"std::asin({self._print(expr.args[0])})"
+
+    def _print_OpaqueUnaryFn_tan(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        return f"std::tan({self._print(expr.args[0])})"
+
+    def _print_OpaqueUnaryFn_tanh(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        return f"std::tanh({self._print(expr.args[0])})"
+
+    def _print_OpaqueUnaryFn_atan(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        return f"std::atan({self._print(expr.args[0])})"
+
+    def _print_OpaqueUnaryFn_sqrt(self, expr: sympy.Expr) -> str:
+        return f"std::sqrt({self._print(expr.args[0])})"
+
+    def _print_OpaqueUnaryFn_log2(self, expr: sympy.Expr) -> str:
+        return f"std::log2({self._print(expr.args[0])})"
+
+    def _print_RoundToInt(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 1
+        # TODO: dispatch to llrint depending on index type
+        return f"std::lrint({self._print(expr.args[0])})"
+
+    def _print_RoundDecimal(self, expr: sympy.Expr) -> str:
+        assert len(expr.args) == 2
+        number, ndigits = expr.args
+        if number.is_integer:
+            # ndigits < 0 should have been filtered by the sympy function
+            assert ndigits < 0
+            raise ValueError(
+                f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}."
+            )
+        number_str = self.parenthesize(number, PRECEDENCE["Mul"])
+        return f"static_cast<double>(std::nearbyint(1e{ndigits} * {number_str}) * 1e{-ndigits})"
+
+    def _print_BooleanTrue(self, expr: sympy.Expr) -> str:
+        return "true"
+
+    def _print_BooleanFalse(self, expr: sympy.Expr) -> str:
+        return "false"
+
+    def _print_Infinity(self, expr: sympy.Expr) -> str:
+        return "std::numeric_limits<double>::infinity()"
+
+    def _print_NegativeInfinity(self, expr: sympy.Expr) -> str:
+        return f"-{self._print_Infinity(expr)}"
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/_sympy/reference.py b/.venv/lib/python3.12/site-packages/torch/utils/_sympy/reference.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c960e92f22310ecd8ebce22c6b6d3ec86095f50
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/_sympy/reference.py
@@ -0,0 +1,581 @@
+# mypy: allow-untyped-defs
+import math
+import operator
+from typing import Union
+
+import sympy
+
+import torch
+from torch.utils._sympy.functions import (
+    _keep_float,
+    BitwiseFn_bitwise_and,
+    BitwiseFn_bitwise_or,
+    FloatPow,
+    FloatTrueDiv,
+    FloorDiv,
+    IntTrueDiv,
+    Max,
+    Min,
+    Mod,
+    OpaqueUnaryFn_exp,
+    OpaqueUnaryFn_log,
+    OpaqueUnaryFn_log2,
+    OpaqueUnaryFn_sqrt,
+    PowByNatural,
+    RoundDecimal,
+    RoundToInt,
+    ToFloat,
+    TruncToInt,
+)
+
+
+# The sympy interpretation of operators.  It will also sometimes work with
+# plain int/float, but if you do certain operations you will get out a
+# sympy.Basic in the end.  If you want the Python/FX traceable interpretation,
+# check PythonReferenceAnalysis.
+# NB: For magic methods this needs to use normal magic methods
+# so that test_magic_methods works
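+#
+# Rough usage sketch: ReferenceAnalysis.add(sympy.Symbol("s0"), 1) yields the
+# sympy expression s0 + 1, and ReferenceAnalysis.eq(s0, 1) yields sympy.Eq(s0, 1).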
+class ReferenceAnalysis:
+    @staticmethod
+    def constant(c, dtype):
+        return sympy.sympify(c)
+
+    @staticmethod
+    def or_(a, b):
+        return a | b
+
+    @staticmethod
+    def and_(a, b):
+        return a & b
+
+    @staticmethod
+    def eq(a, b):
+        if isinstance(a, sympy.Expr) or isinstance(b, sympy.Expr):
+            return sympy.Eq(a, b)
+        return a == b
+
+    @classmethod
+    def ne(cls, a, b):
+        return cls.not_(cls.eq(a, b))
+
+    @staticmethod
+    def lt(a, b):
+        return a < b
+
+    @staticmethod
+    def gt(a, b):
+        return a > b
+
+    @staticmethod
+    def le(a, b):
+        return a <= b
+
+    @staticmethod
+    def ge(a, b):
+        return a >= b
+
+    @staticmethod
+    def not_(a):
+        assert not isinstance(a, bool)
+        return ~a
+
+    @staticmethod
+    def reciprocal(x):
+        return FloatTrueDiv(1.0, x)
+
+    @staticmethod
+    def square(x):
+        return PowByNatural(x, 2)
+
+    @staticmethod
+    def trunc_to_int(x, dtype):
+        return TruncToInt(x)
+
+    @staticmethod
+    def ceil_to_int(x, dtype):
+        return sympy.ceiling(x)
+
+    @staticmethod
+    def floor_to_int(x, dtype):
+        return sympy.floor(x)
+
+    @staticmethod
+    def floor(x):
+        return _keep_float(sympy.floor)(x)
+
+    @staticmethod
+    def ceil(x):
+        return _keep_float(sympy.ceiling)(x)
+
+    @staticmethod
+    def to_dtype(x, dtype):
+        if dtype == torch.float64:
+            return ToFloat(x)
+        raise NotImplementedError(f"to_dtype {dtype} NYI")
+
+    @staticmethod
+    def mod(x, y):
+        return Mod(x, y)
+
+    @staticmethod
+    def abs(x):
+        return abs(x)
+
+    @staticmethod
+    def neg(x):
+        return -x
+
+    @staticmethod
+    def truediv(a, b):
+        return FloatTrueDiv(a, b)
+
+    @staticmethod
+    def int_truediv(a, b):
+        return IntTrueDiv(a, b)
+
+    @staticmethod
+    def floordiv(a, b):
+        return FloorDiv(a, b)
+
+    @staticmethod
+    def truncdiv(a, b):
+        raise NotImplementedError("TODO: truncdiv")
+
+    @staticmethod
+    def add(a, b):
+        return _keep_float(operator.add)(a, b)
+
+    @classmethod
+    def sym_sum(cls, args):
+        return sympy.Add(*args)
+
+    @staticmethod
+    def mul(a, b):
+        return _keep_float(operator.mul)(a, b)
+
+    @staticmethod
+    def sub(a, b):
+        return _keep_float(operator.sub)(a, b)
+
+    @staticmethod
+    def exp(x):
+        return OpaqueUnaryFn_exp(x)
+
+    @staticmethod
+    def log(x):
+        return OpaqueUnaryFn_log(x)
+
+    @staticmethod
+    def log2(x):
+        return OpaqueUnaryFn_log2(x)
+
+    @staticmethod
+    def sqrt(x):
+        return OpaqueUnaryFn_sqrt(x)
+
+    @staticmethod
+    def pow(a, b):
+        return _keep_float(FloatPow)(a, b)
+
+    @staticmethod
+    def pow_by_natural(a, b):
+        return PowByNatural(a, b)
+
+    @staticmethod
+    def minimum(a, b):
+        return Min(a, b)
+
+    @staticmethod
+    def maximum(a, b):
+        return Max(a, b)
+
+    @staticmethod
+    def round_to_int(a, dtype):
+        return RoundToInt(a)
+
+    @staticmethod
+    def round_decimal(a, b):
+        return RoundDecimal(a, b)
+
+    @staticmethod
+    def bitwise_and(a, b):
+        return BitwiseFn_bitwise_and(a, b)
+
+    @staticmethod
+    def bitwise_or(a, b):
+        return BitwiseFn_bitwise_or(a, b)
+
+
+# Unlike ReferenceAnalysis, this does NOT sympify; instead, it works with
+# plain Python types and is FX traceable.  Inheritance here is purely for
+# code sharing (TODO: consider splitting out a BaseReferenceAnalysis).
+class PythonReferenceAnalysis(ReferenceAnalysis):
+    @staticmethod
+    def constant(c, dtype):
+        if dtype is torch.int64:
+            return int(c)
+        elif dtype is torch.double:
+            return float(c)
+        elif dtype is torch.bool:
+            return bool(c)
+        else:
+            raise AssertionError(f"unrecognized dtype {dtype}")
+
+    @staticmethod
+    def not_(a):
+        return torch.sym_not(a)
+
+    @classmethod
+    def sym_sum(cls, args):
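+        # Build a left-associated chain of binary adds, e.g.
+        # sym_sum([a, b, c]) == add(add(a, b), c).  This avoids torch.sym_sum,
+        # which OptimizedPythonReferenceAnalysis below opts into instead.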
+        if len(args) == 0:
+            return 0
+        if len(args) == 1:
+            return args[0]
+        acc = cls.add(args[0], args[1])
+        for i in range(2, len(args)):
+            acc = cls.add(acc, args[i])
+        return acc
+
+    @staticmethod
+    def floordiv(a, b):
+        return a // b
+
+    @staticmethod
+    def mod(x, y):
+        return x % y
+
+    @staticmethod
+    def truncdiv(a, b):
+        return a / b
+
+    @staticmethod
+    def to_dtype(x, dtype):
+        if dtype == torch.float64:
+            return torch.sym_float(x)
+        raise NotImplementedError(f"to_dtype {dtype} NYI")
+
+    @staticmethod
+    def exp(x):
+        raise AssertionError("exp is not valid shape sympy expr")
+
+    @staticmethod
+    def log(x):
+        raise AssertionError("log is not valid shape sympy expr")
+
+    @staticmethod
+    def log2(x):
+        return torch._sym_log2(x)  # type: ignore[attr-defined]
+
+    @staticmethod
+    def sqrt(x):
+        return torch._sym_sqrt(x)  # type: ignore[attr-defined]
+
+    @staticmethod
+    def minimum(a, b):
+        return torch.sym_min(a, b)
+
+    @staticmethod
+    def maximum(a, b):
+        return torch.sym_max(a, b)
+
+    @staticmethod
+    def floor_to_int(x, dtype):
+        return math.floor(x)
+
+    @staticmethod
+    def ceil_to_int(x, dtype):
+        return math.ceil(x)
+
+    @staticmethod
+    def floor(x):
+        return float(math.floor(x))
+
+    @staticmethod
+    def ceil(x):
+        return float(math.ceil(x))
+
+    @staticmethod
+    def truediv(a, b):
+        return a / b
+
+    @staticmethod
+    def pow(a, b):
+        return a**b
+
+    @staticmethod
+    def pow_by_natural(a, b):
+        # Pray that safe_pow is not needed here lol.  In particular, this
+        # never participates in VR low/high ranges, so overflow should be
+        # unlikely
+        return a**b
+
+    @staticmethod
+    def round_to_int(a, dtype):
+        return round(a)
+
+    @staticmethod
+    def round_decimal(a, b):
+        return round(a, ndigits=b)
+
+    @staticmethod
+    def bitwise_and(a, b):
+        return a & b
+
+    @staticmethod
+    def bitwise_or(a, b):
+        return a | b
+
+
+# Like PythonReferenceAnalysis, but makes some export-unfriendly choices of
+# operators to make things faster
+class OptimizedPythonReferenceAnalysis(PythonReferenceAnalysis):
+    @staticmethod
+    def sym_sum(args):
+        return torch.sym_sum(args)
+
+
+def _to_dtype(x: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+    return torch.ops.prims.convert_element_type.default(x, dtype)
+
+
+# Suppose we have some int/float arguments.  This diagram commutes:
+#
+#   int/float  -- PythonReferenceAnalysis.op -->  int/float
+#       |                                           |
+#       |                                           |
+#      torch.tensor(..., dtype=torch.int64/torch.float64)
+#       |                                           |
+#       V                                           V
+#    Tensor    -- TensorReferenceAnalysis.op -->  Tensor
+#
+# NB: int before and after must be representable in int64 (we will
+# insert guards accordingly.)
+#
+# This is guaranteed to be FX traceable with OpOverloads only.
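+#
+# For example, TensorReferenceAnalysis.floordiv(torch.tensor(7), torch.tensor(2))
+# produces tensor(3), matching PythonReferenceAnalysis.floordiv(7, 2) == 3.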
+class TensorReferenceAnalysis:
+    # NB: This is actually dead, because with Proxy tracing the factory
+    # function isn't traced correctly.  Here for completeness.
+    @staticmethod
+    def constant(c, dtype):
+        d: Union[int, float, bool]
+        if dtype is torch.int64:
+            d = int(c)
+        elif dtype is torch.double:
+            d = float(c)
+        elif dtype is torch.bool:
+            d = bool(c)
+        else:
+            raise AssertionError(f"unrecognized dtype {dtype}")
+        return torch.ops.aten.scalar_tensor.default(d, dtype=dtype)
+
+    @staticmethod
+    def or_(a, b):
+        return torch.ops.aten.logical_or.default(a, b)
+
+    @staticmethod
+    def and_(a, b):
+        return torch.ops.aten.logical_and.default(a, b)
+
+    @staticmethod
+    def bitwise_and(a, b):
+        return torch.ops.aten.bitwise_and(a, b)
+
+    @staticmethod
+    def bitwise_or(a, b):
+        return torch.ops.aten.bitwise_or(a, b)
+
+    @staticmethod
+    def eq(a, b):
+        return torch.ops.aten.eq.Tensor(a, b)
+
+    @classmethod
+    def ne(cls, a, b):
+        return torch.ops.aten.ne.Tensor(a, b)
+
+    @staticmethod
+    def lt(a, b):
+        return torch.ops.aten.lt.Tensor(a, b)
+
+    @staticmethod
+    def gt(a, b):
+        return torch.ops.aten.gt.Tensor(a, b)
+
+    @staticmethod
+    def le(a, b):
+        return torch.ops.aten.le.Tensor(a, b)
+
+    @staticmethod
+    def ge(a, b):
+        return torch.ops.aten.ge.Tensor(a, b)
+
+    @staticmethod
+    def not_(a):
+        return torch.ops.aten.logical_not.default(a)
+
+    @staticmethod
+    def reciprocal(x):
+        return torch.ops.aten.reciprocal.default(x)
+
+    @staticmethod
+    def square(x):
+        # TODO: maybe composite implicit autograd doesn't work here?
+        return torch.ops.aten.square.default(x)
+
+    @staticmethod
+    def trunc_to_int(x, dtype):
+        return _to_dtype(torch.ops.aten.trunc.default(x), dtype)
+
+    @staticmethod
+    def ceil_to_int(x, dtype):
+        return _to_dtype(torch.ops.aten.ceil.default(x), dtype)
+
+    @staticmethod
+    def floor_to_int(x, dtype):
+        return _to_dtype(torch.ops.aten.floor.default(x), dtype)
+
+    @staticmethod
+    def floor(x):
+        return torch.ops.aten.floor.default(x)
+
+    @staticmethod
+    def ceil(x):
+        return torch.ops.aten.ceil.default(x)
+
+    @staticmethod
+    def to_dtype(x, dtype):
+        return _to_dtype(x, dtype)
+
+    @staticmethod
+    def mod(x, y):
+        # TODO: https://github.com/pytorch/pytorch/pull/133654
+        raise NotImplementedError(
+            "no C-style modulus operation available from frontend atm"
+        )
+
+    @staticmethod
+    def abs(x):
+        return torch.ops.aten.abs.default(x)
+
+    @staticmethod
+    def neg(x):
+        return torch.ops.aten.neg.default(x)
+
+    @staticmethod
+    def truediv(a, b):
+        return torch.ops.aten.true_divide.Tensor(a, b)
+
+    @staticmethod
+    def int_truediv(a, b):
+        raise NotImplementedError(
+            "Python int truediv difficult to implement in PyTorch atm"
+        )
+
+        # TODO: This is wrong, CPython has a custom implementation of true
+        # division that results in higher precision when the floats are
+        # sufficiently large.  Short term fix: add a guard here
+        return torch.ops.aten.true_divide.default(
+            _to_dtype(a, torch.float64), _to_dtype(b, torch.float64)
+        )
+
+    @staticmethod
+    def floordiv(a, b):
+        return torch.ops.aten.div.Tensor_mode(a, b, rounding_mode="floor")
+
+    @staticmethod
+    def truncdiv(a, b):
+        raise NotImplementedError(
+            "no C-style truncdiv operation available from frontend atm"
+        )
+
+    @staticmethod
+    def add(a, b):
+        return torch.ops.aten.add.Tensor(a, b)
+
+    @staticmethod
+    def mul(a, b):
+        return torch.ops.aten.mul.Tensor(a, b)
+
+    @staticmethod
+    def sub(a, b):
+        return torch.ops.aten.sub.Tensor(a, b)
+
+    @staticmethod
+    def exp(x):
+        return torch.ops.aten.exp.default(x)
+
+    @staticmethod
+    def log(x):
+        return torch.ops.aten.log.default(x)
+
+    @staticmethod
+    def log2(x):
+        return torch.ops.aten.log2.default(x)
+
+    @staticmethod
+    def sqrt(x):
+        return torch.ops.aten.sqrt.default(x)
+
+    @staticmethod
+    def sin(x):
+        return torch.ops.aten.sin.default(x)
+
+    @staticmethod
+    def cos(x):
+        return torch.ops.aten.cos.default(x)
+
+    @staticmethod
+    def tanh(x):
+        return torch.ops.aten.tanh.default(x)
+
+    @staticmethod
+    def sinh(x):
+        return torch.ops.aten.sinh.default(x)
+
+    @staticmethod
+    def cosh(x):
+        return torch.ops.aten.cosh.default(x)
+
+    @staticmethod
+    def tan(x):
+        return torch.ops.aten.tan.default(x)
+
+    @staticmethod
+    def acos(x):
+        return torch.ops.aten.acos.default(x)
+
+    @staticmethod
+    def atan(x):
+        return torch.ops.aten.atan.default(x)
+
+    @staticmethod
+    def asin(x):
+        return torch.ops.aten.asin.default(x)
+
+    @staticmethod
+    def pow(a, b):
+        return torch.ops.aten.pow.Tensor_Tensor(a, b)
+
+    @staticmethod
+    def pow_by_natural(a, b):
+        # NB: pow handles int x int fine
+        return torch.ops.aten.pow.Tensor_Tensor(a, b)
+
+    @staticmethod
+    def minimum(a, b):
+        return torch.ops.aten.minimum.default(a, b)
+
+    @staticmethod
+    def maximum(a, b):
+        return torch.ops.aten.maximum.default(a, b)
+
+    @staticmethod
+    def round_to_int(a, dtype):
+        return torch.ops.aten.round.default(a)
+
+    @staticmethod
+    def round_decimal(a, b):
+        raise NotImplementedError(
+            "round decimal doesn't support Tensor second argument atm"
+        )
+
+        # return torch.ops.aten.round.decimals(a, b)
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/_sympy/singleton_int.py b/.venv/lib/python3.12/site-packages/torch/utils/_sympy/singleton_int.py
new file mode 100644
index 0000000000000000000000000000000000000000..0bac76121f8b65baf44d0a527f98e7300adf730c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/_sympy/singleton_int.py
@@ -0,0 +1,96 @@
+# mypy: allow-untyped-defs
+import sympy
+from sympy.multipledispatch import dispatch
+
+
+__all__ = ["SingletonInt"]
+
+
+class SingletonInt(sympy.AtomicExpr):
+    # This is probably not super important unless we are in multiple dispatch
+    # situations with other more exotic Expr types.
+    _op_priority = 99999
+
+    def __new__(cls, *args, coeff=None, **kwargs):
+        instance = super().__new__(cls, *args, **kwargs)
+        return instance
+
+    # The semantics of this class should match that of NestedIntSymNodeImpl in
+    # c10/core/NestedIntSymNodeImpl.h
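+    # E.g. 2 * SingletonInt(1) yields SingletonInt(1, coeff=2) (via __rmul__),
+    # while SingletonInt(1) + 1 raises NotImplementedError.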
+    def __init__(self, val, *, coeff=1):
+        self._val = val
+        self._coeff = coeff
+        super().__init__()
+
+    # See NOTE [ Inequalities with nested int ]
+    def _eval_Eq(self, other):
+        if (
+            isinstance(other, SingletonInt)
+            and other._val == self._val
+            and self._coeff == other._coeff
+        ):
+            return sympy.true
+        else:
+            return sympy.false
+
+    # This is necessary so that calling expr.free_symbols on exprs that contain
+    # this Singleton does not error
+    @property
+    def free_symbols(self):
+        return set()
+
+    def __mul__(self, other):
+        if isinstance(other, SingletonInt):
+            raise ValueError(
+                "SingletonInt cannot be multiplied by another SingletonInt"
+            )
+        return SingletonInt(self._val, coeff=self._coeff * other)
+
+    def __rmul__(self, other):
+        if isinstance(other, SingletonInt):
+            raise ValueError(
+                "SingletonInt cannot be multiplied by another SingletonInt"
+            )
+        return SingletonInt(self._val, coeff=self._coeff * other)
+
+    # Make sure we promptly raise an error instead of falling back to building
+    # an expression tree. There are probably more ops, how can we be exhaustive?
+    def __add__(self, other):
+        raise NotImplementedError("NYI")
+
+    def __sub__(self, other):
+        raise NotImplementedError("NYI")
+
+    def __truediv__(self, other):
+        raise NotImplementedError("NYI")
+
+    def __floordiv__(self, other):
+        raise NotImplementedError("NYI")
+
+    def __mod__(self, other):
+        raise NotImplementedError("NYI")
+
+
+# See NOTE [ Inequalities with nested int ]
+@dispatch(sympy.Integer, SingletonInt)
+def _eval_is_ge(a, b):
+    if a < 2:
+        return sympy.false
+    raise ValueError("Symbolic SingletonInt: Relation is indeterminate")
+
+
+@dispatch(SingletonInt, sympy.Integer)  # type: ignore[no-redef]
+def _eval_is_ge(a, b):  # noqa: F811
+    if b <= 2:
+        return sympy.true
+    raise ValueError("Symbolic SingletonInt: Relation is indeterminate")
+
+
+@dispatch(SingletonInt, SingletonInt)  # type: ignore[no-redef]
+def _eval_is_ge(a, b):  # noqa: F811
+    if a._val == b._val:
+        if a._coeff >= b._coeff:
+            return sympy.true
+        else:
+            return sympy.false
+    raise ValueError("Symbolic SingletonInt: Relation is indeterminate")
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/_sympy/solve.py b/.venv/lib/python3.12/site-packages/torch/utils/_sympy/solve.py
new file mode 100644
index 0000000000000000000000000000000000000000..334a023c0f36b883bcbd4ca71126b65e39916426
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/_sympy/solve.py
@@ -0,0 +1,178 @@
+import logging
+from typing import Optional
+
+import sympy
+
+from torch.utils._sympy.functions import FloorDiv
+
+
+log = logging.getLogger(__name__)
+
+_MIRROR_REL_OP: dict[type[sympy.Basic], type[sympy.Rel]] = {
+    sympy.Eq: sympy.Eq,
+    sympy.Ne: sympy.Ne,
+    sympy.Ge: sympy.Le,
+    sympy.Gt: sympy.Lt,
+    sympy.Le: sympy.Ge,
+    sympy.Lt: sympy.Gt,
+}
+
+INEQUALITY_TYPES = (sympy.Gt, sympy.Ge, sympy.Lt, sympy.Le)
+
+
+def mirror_rel_op(type: type) -> Optional[type[sympy.Rel]]:
+    return _MIRROR_REL_OP.get(type, None)
+
+
+# Tries to simplify 'expr', so as to leave only 'thing' in the left-hand side.
+#
+# Returns a tuple of:
+#   1. The simplified expression
+#   2. The expression on the right-hand side
+#
+# Returns 'None' if it can't reach a state where the only thing in the left
+# hand side is 'thing'.
+#
+# 'trials': number of times 'try_solve' will try to isolate 'thing' to the
+# left-hand side.
+#
+# 'floordiv_inequality': flag to enable conversion of 'FloorDiv' into
+# inequalities.
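+#
+# A small usage sketch (x a sympy.Symbol):
+#   try_solve(sympy.Eq(x + 3, 10), x)  ->  (Eq(x, 7), 7)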
+def try_solve(
+    expr: sympy.Basic,
+    thing: sympy.Basic,
+    trials: int = 5,
+    floordiv_inequality: bool = True,
+) -> Optional[tuple[sympy.Rel, sympy.Expr]]:
+    mirror = mirror_rel_op(type(expr))
+
+    # Ignore unsupported expressions:
+    #   - Those that are not relational operations
+    #   - Those that don't have a mirror (just avoiding unexpected classes)
+    if not isinstance(expr, sympy.Rel) or mirror is None:
+        log.debug("expression with unsupported type: %s", type(expr))
+        return None
+
+    lhs_has_thing = expr.lhs.has(thing)
+    rhs_has_thing = expr.rhs.has(thing)
+
+    # Give up when 'thing' appears on both sides of the relational expression.
+    # That is because, as is, we assume the thing we are trying to isolate is
+    # only on the right-hand side.
+    if lhs_has_thing and rhs_has_thing:
+        log.debug("thing (%s) found in both sides of expression: %s", thing, expr)
+        return None
+
+    # Try considering both LHS and RHS by mirroring the original expression:
+    # a < b ==> b > a
+    expressions = []
+
+    # Add each version of 'expr' if 'thing' is in its left-hand side.
+    if lhs_has_thing:
+        expressions.append(expr)
+    if rhs_has_thing:
+        expressions.append(mirror(expr.rhs, expr.lhs))
+
+    for e in expressions:
+        if e is None:
+            continue
+
+        assert isinstance(e, sympy.Rel)
+
+        for _ in range(trials):
+            trial = _try_isolate_lhs(e, thing, floordiv_inequality=floordiv_inequality)
+            # Stop if there was no change in this trial.
+            if trial == e:
+                break
+            e = trial  # type: ignore[assignment]
+
+        # Return if we were able to isolate 'thing' on the left-hand side.
+        if isinstance(e, sympy.Rel) and e.lhs == thing:
+            log.debug("solved: %s ---> %s", expr, e)
+            return e, e.rhs
+
+    return None
+
+
+def _try_isolate_lhs(
+    e: sympy.Basic, thing: sympy.Basic, floordiv_inequality: bool
+) -> sympy.Basic:
+    op = type(e)
+
+    if isinstance(e, sympy.Rel):
+        # Move any constants in the left-hand side to the right-hand side.
+        lhs_not_thing = (
+            sum(a for a in e.lhs.args if not a.has(thing))
+            if isinstance(e.lhs, sympy.Add)
+            else 0
+        )
+        e = op(e.lhs - lhs_not_thing, e.rhs - lhs_not_thing)  # type: ignore[attr-defined]
+
+    # Divide both sides by the factors that don't contain thing.
+    if isinstance(e, sympy.Rel) and isinstance(e.lhs, sympy.Mul):
+        lhs, rhs = e.args
+        other = sympy.Mul(*[a for a in lhs.args if not a.has(thing)])
+
+        # If we can't tell whether 'other' is negative or positive, we do nothing.
+        # That is because we don't know whether we have to mirror the operation or not.
+        # We also divide only when we know 'rhs' is not zero.
+        if not (isinstance(e, INEQUALITY_TYPES) and other.is_negative is None) and not (
+            not isinstance(e, INEQUALITY_TYPES) and rhs.is_zero
+        ):
+            # Divide both sides by 'other'.
+            lhs = lhs / other
+            rhs = rhs / other
+
+            # If 'e' is an inequality and 'other' is negative, we have to
+            # mirror the expression.
+            if isinstance(e, INEQUALITY_TYPES) and other.is_negative:
+                op = mirror_rel_op(op)  # type: ignore[assignment]
+
+            assert op is not None
+            e = op(lhs, rhs)
+
+    ################################################################################
+    # left-hand side is FloorDiv
+    ################################################################################
+    #
+    # Given the expression: a // b op c
+    # where 'op' is a relational operation, these rules only work if:
+    #   - b > 0
+    #   - c is an integer
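+    #
+    # E.g. with the Ge rule below, (a // 3) >= 2 is rewritten to a >= 3 * 2,
+    # i.e. a >= 6.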
+    if (
+        floordiv_inequality
+        and isinstance(e, sympy.Rel)
+        and isinstance(e.lhs, FloorDiv)
+        and e.lhs.divisor.is_positive
+        and e.rhs.is_integer
+    ):
+        # a // b == expr
+        # => a >= (b * expr) and a < (b * (expr + 1))
+        if isinstance(e, sympy.Eq):
+            numerator, denominator = e.lhs.args
+            return sympy.And(
+                sympy.Ge(numerator, (e.rhs * denominator)),  # type: ignore[arg-type]
+                sympy.Lt(numerator, ((e.rhs + 1) * denominator)),  # type: ignore[arg-type]
+            )
+        # a // b != expr
+        # => a < (b * expr) or a >= (b * (expr + 1))
+        if isinstance(e, sympy.Ne):
+            numerator, denominator = e.lhs.args
+            return sympy.Or(
+                sympy.Lt(numerator, (e.rhs * denominator)),  # type: ignore[arg-type]
+                sympy.Ge(numerator, ((e.rhs + 1) * denominator)),  # type: ignore[arg-type]
+            )
+        # The transformations below only work if b is positive.
+        # Note: we only have this information for constants.
+        # a // b > expr  => a >= b * (expr + 1)
+        # a // b >= expr => a >= b * expr
+        if isinstance(e, (sympy.Gt, sympy.Ge)):
+            quotient = e.rhs if isinstance(e, sympy.Ge) else (e.rhs + 1)  # type: ignore[arg-type]
+            return sympy.Ge(e.lhs.args[0], (quotient * e.lhs.args[1]))  # type: ignore[arg-type]
+        # a // b < expr  => a < b * expr
+        # a // b <= expr => a < b * (expr + 1)
+        if isinstance(e, (sympy.Lt, sympy.Le)):
+            quotient = e.rhs if isinstance(e, sympy.Lt) else (e.rhs + 1)  # type: ignore[arg-type]
+            return sympy.Lt(e.lhs.args[0], (quotient * e.lhs.args[1]))  # type: ignore[arg-type]
+
+    return e
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/_sympy/symbol.py b/.venv/lib/python3.12/site-packages/torch/utils/_sympy/symbol.py
new file mode 100644
index 0000000000000000000000000000000000000000..de810498bbab5f46fbac6e070de510b5b620dcfd
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/_sympy/symbol.py
@@ -0,0 +1,101 @@
+# mypy: allow-untyped-defs
+"""
+This file contains canonical definitions for our symbol naming conventions,
+across torch.fx.experimental.symbolic_shapes and torch._inductor.  The
+intention is:
+
+1. To make it easily greppable where all the sites we use a prefix are
+2. Make it possible to easily tell if we can introduce a new prefix without
+   introducing a conflict
+
+You can occasionally test if prefixes have been hardcoded by renaming prefixes
+in this file and seeing what breaks.
+"""
+
+from collections.abc import Iterable
+from enum import auto, Enum
+from typing import Union
+
+import sympy
+
+
+class SymT(Enum):
+    SIZE = auto()
+    FLOAT = auto()
+    UNBACKED_INT = auto()
+    UNBACKED_FLOAT = auto()
+    # Inductor: The intermediates in inner_fn (tmp0, tmp1, ...), one generated per ops call.
+    # If one of these shows up in an indexing expression, that means an
+    # indirect load is happening.
+    TMP = auto()
+    # Inductor: Placeholder variable that is later replaced with TMP
+    INDIRECT = auto()
+    # Inductor: Some size expressions are replaced with a precomputed size ps0
+    # which is computed host side, and then directly reused in the kernel, so
+    # we don't repeatedly recompute it on device.
+    PRECOMPUTED_SIZE = auto()
+    # Inductor: An indexing variable i0 in loops IR which ranges over non-reduced
+    # dim in the loop
+    INDEX = auto()
+    # Inductor: A reduction indexing (r0, r1) variables in loops IR which ranges over
+    # reduced dim(s) in the loop
+    R0_INDEX = auto()
+    R1_INDEX = auto()
+    # Inductor: In templated kernels torch._inductor.kernel, we have a hook to
+    # store the final output and append epilogue fusions.  To do this, we must
+    # know what the indexes the outputs range over.  NB: These will also
+    # advertise as INDEX, this is... probably OK?
+    TEMPLATE_INDEX = auto()
+    # Inductor: iteration domain for blockIdx.x/blockIdx.y
+    XBLOCK = auto()
+    YBLOCK = auto()
+    ZBLOCK = auto()
+    # Inductor: this is used solely for dynamic_reshape_indexer
+    VIEW = auto()
+    # Alternate (non-modular) indexing used in halide kernels
+    HALIDE = auto()
+
+
+# Invariant: there must not be a prefix which is a prefix of another string,
+# as this introduces ambiguity
+prefix_str = {
+    SymT.SIZE: "s",  # integer
+    SymT.UNBACKED_INT: "u",  # integer
+    # Prefix z here is chosen to avoid false aliasing in symbol_is_type test
+    # DO NOT add a "z" type.  You also need to avoid conflicts on these
+    # prefixes but this is somewhat easier to manage
+    SymT.FLOAT: "zf",
+    SymT.UNBACKED_FLOAT: "zuf",
+    SymT.TMP: "tmp",
+    SymT.PRECOMPUTED_SIZE: "ps",
+    SymT.INDEX: "i",
+    SymT.R0_INDEX: "r0_",
+    SymT.R1_INDEX: "r1_",
+    SymT.TEMPLATE_INDEX: "idx",
+    SymT.XBLOCK: "x",
+    SymT.YBLOCK: "y",
+    SymT.ZBLOCK: "z",
+    SymT.INDIRECT: "indirect",  # false aliasing?
+    SymT.VIEW: "view",
+    SymT.HALIDE: "h",
+}
+
+
+def make_symbol(prefix: SymT, idx: int, **kwargs) -> sympy.Symbol:
+    # TODO: maybe put the assumptions here directly
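+    # e.g. make_symbol(SymT.SIZE, 0) -> sympy.Symbol("s0")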
+    return sympy.Symbol(f"{prefix_str[prefix]}{idx}", **kwargs)
+
+
+# This type is a little wider than it should be, because free_symbols says
+# that it contains Basic, rather than Symbol
+def symbol_is_type(sym: sympy.Basic, prefix: Union[SymT, Iterable[SymT]]) -> bool:
+    assert isinstance(sym, sympy.Symbol)
+    name_str = sym.name.lower()  # Match capitalized names like XBLOCK, RBLOCK
+    if isinstance(prefix, SymT):
+        return name_str.startswith(prefix_str[prefix])
+    else:
+        return name_str.startswith(tuple(prefix_str[p] for p in prefix))
+
+
+def free_symbol_is_type(e: sympy.Expr, prefix: Union[SymT, Iterable[SymT]]) -> bool:
+    return any(symbol_is_type(v, prefix) for v in e.free_symbols)
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/_sympy/value_ranges.py b/.venv/lib/python3.12/site-packages/torch/utils/_sympy/value_ranges.py
new file mode 100644
index 0000000000000000000000000000000000000000..784f9e7ba05149f5aea812f4e3a856ff1cebbe30
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/_sympy/value_ranges.py
@@ -0,0 +1,1052 @@
+# mypy: allow-untyped-defs
+from __future__ import annotations
+
+import dataclasses
+import functools
+import itertools
+import logging
+import math
+import operator
+from typing import (
+    Callable,
+    Generic,
+    Optional,
+    overload,
+    SupportsFloat,
+    TYPE_CHECKING,
+    TypeVar,
+    Union,
+)
+from typing_extensions import TypeGuard
+
+import sympy
+from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom
+
+import torch
+from torch._logging import LazyString
+from torch._prims_common import dtype_to_type
+
+from .functions import (
+    _keep_float,
+    FloatTrueDiv,
+    FloorDiv,
+    IntTrueDiv,
+    OpaqueUnaryFn_exp,
+    OpaqueUnaryFn_log,
+    OpaqueUnaryFn_log2,
+    OpaqueUnaryFn_sqrt,
+    PowByNatural,
+    RoundDecimal,
+    RoundToInt,
+    safe_pow,
+    ToFloat,
+    TruncToFloat,
+    TruncToInt,
+)
+from .interp import sympy_interp
+from .numbers import int_oo, IntInfinity, NegativeIntInfinity
+
+
+log = logging.getLogger(__name__)
+
+__all__ = ["ValueRanges", "bound_sympy"]
+
+_T = TypeVar("_T", sympy.Expr, SympyBoolean)
+
+
+class ValueRangeError(RuntimeError):
+    pass
+
+
+# Like sympify, but supports less stuff, and also ensures that direct
+# sympy expressions don't have free variables
+def simple_sympify(e):
+    if isinstance(e, bool):
+        return sympy.true if e else sympy.false
+    elif isinstance(e, int):
+        return sympy.Integer(e)
+    elif isinstance(e, float):
+        # infinity is special; we use it to bracket integers as well
+        if math.isinf(e):
+            return sympy.oo if e > 0 else -sympy.oo
+        return sympy.Float(e)
+    elif isinstance(e, sympy.Expr):
+        assert e.is_number, e
+        # NaNs can occur when doing things like 0 * sympy.oo, but it is better
+        # if the operator notices this and takes care of it, because sometimes
+        # the NaN is inappropriate (for example, for ints, the [-oo, oo] range
+        # should go to zero when multiplied with [0, 0])
+        assert e != sympy.nan
+        return e
+    elif isinstance(e, BooleanAtom):
+        return e
+    else:
+        raise AssertionError(f"not simple sympy type {type(e)}: {e}")
+
+
+# Sympy atomics only. Unlike <=, it also works on Sympy bools.
+def sympy_generic_le(lower, upper):
+    if isinstance(lower, sympy.Expr):
+        assert isinstance(upper, sympy.Expr)
+        # instead of lower <= upper, we do upper >= lower since upper is mostly int_oo
+        # and we have better code paths there.
+        return upper >= lower
+    else:
+        # only negative condition is True > False
+        assert isinstance(lower, SympyBoolean) and isinstance(upper, SympyBoolean), (
+            lower,
+            upper,
+        )
+        return not (lower and not upper)
+
+
+def vr_is_bool(vr: ValueRanges[_T]) -> TypeGuard[ValueRanges[SympyBoolean]]:
+    return vr.is_bool
+
+
+def vr_is_expr(vr: ValueRanges[_T]) -> TypeGuard[ValueRanges[sympy.Expr]]:
+    return not vr.is_bool
+
+
+ExprIn = Union[int, float, sympy.Expr]
+BoolIn = Union[bool, SympyBoolean]
+AllIn = Union[ExprIn, BoolIn]
+ExprFn = Callable[[sympy.Expr], sympy.Expr]
+ExprFn2 = Callable[[sympy.Expr, sympy.Expr], sympy.Expr]
+BoolFn = Callable[[SympyBoolean], SympyBoolean]
+BoolFn2 = Callable[[SympyBoolean, SympyBoolean], SympyBoolean]
+AllFn = Union[ExprFn, BoolFn]
+AllFn2 = Union[ExprFn2, BoolFn2]
+
+
+@dataclasses.dataclass(frozen=True)
+class ValueRanges(Generic[_T]):
+    if TYPE_CHECKING:
+        # ruff doesn't understand circular references but mypy does
+        ExprVR = ValueRanges[sympy.Expr]  # noqa: F821
+        BoolVR = ValueRanges[SympyBoolean]  # noqa: F821
+        AllVR = Union[ExprVR, BoolVR]
+
+    # Although the type signature here suggests you can pass any
+    # sympy expression, in practice the analysis here only works
+    # with constant sympy expressions
+    lower: _T
+    upper: _T
+    is_bool: bool
+    is_int: bool
+    is_float: bool
+
+    def __repr__(self) -> str:
+        return f"VR[{self.lower}, {self.upper}]"
+
+    @overload
+    def __init__(
+        self: ValueRanges[sympy.Expr],
+        lower: ExprIn,
+        upper: ExprIn,
+    ) -> None:
+        ...
+
+    @overload
+    def __init__(  # type: ignore[misc]
+        self: ValueRanges[SympyBoolean],
+        lower: BoolIn,
+        upper: BoolIn,
+    ) -> None:
+        ...
+
+    def __init__(self, lower: AllIn, upper: AllIn) -> None:
+        lower = simple_sympify(lower)
+        upper = simple_sympify(upper)
+        # TODO: when the bounds have free variables, this may be
+        # nontrivial to actually verify
+        try:
+            if not sympy_generic_le(lower, upper):
+                raise ValueRangeError(f"Invalid ranges [{lower}:{upper}]")
+        except TypeError as e:
+            raise TypeError(f"Could not compare {lower} <= {upper}") from e
+
+        is_bool_lower = isinstance(lower, SympyBoolean)
+        is_bool_upper = isinstance(upper, SympyBoolean)
+        assert is_bool_lower == is_bool_upper, (lower, upper)
+
+        # Warning: is_int/is_float is best effort.  We do pretty well in
+        # Dynamo, but in Inductor these attributes are often wrong because we
+        # are not very rigorous in dtype analysis.  This is also why we need
+        # the flexible analysis for is_int: sometimes a sympy.oo pops in for
+        # an integer bound. I would /like/ for us not to do this, but it's
+        # too hard to push the invariant through right now.
+        if isinstance(lower, sympy.Integer) and upper == sympy.oo:
+            upper = int_oo
+        if isinstance(upper, sympy.Integer) and lower == -sympy.oo:
+            lower = -int_oo
+        # NB: [-int_oo, -int_oo] and [int_oo, int_oo] are allowed
+        integer_types = (sympy.Integer, NegativeIntInfinity, IntInfinity)
+        is_int_lower = isinstance(lower, integer_types)
+        is_int_upper = isinstance(upper, integer_types)
+
+        # Because this is a frozen class
+        object.__setattr__(self, "lower", lower)
+        object.__setattr__(self, "upper", upper)
+        # Unlike bool/int in Python, we don't report bools are ints
+        #
+        # NB: is_bool_lower == is_bool_upper, so we only need to check one
+        object.__setattr__(self, "is_bool", is_bool_lower)
+        object.__setattr__(
+            self,
+            "is_int",
+            not self.is_bool and is_int_lower and is_int_upper,
+        )
+        """
+        # This assert is just impossible right now, too many sympy bugs
+        if self.is_int:
+            # NB: sympy will sometimes randomly lose the float-ness of zero,
+            # so we also need to account for that in the assertion here.
+            # See also https://github.com/sympy/sympy/issues/26620
+            assert isinstance(lower, sympy.Integer) or lower in [-sympy.oo, 0], (
+                lower,
+                upper,
+            )
+            assert isinstance(upper, sympy.Integer) or upper in [sympy.oo, 0], (lower, upper)
+        """
+        # NB: [-oo, oo] always advertises as float!
+        object.__setattr__(self, "is_float", not self.is_bool and not self.is_int)
+        assert self.is_bool or self.is_int or self.is_float, (lower, upper)
+
+    def boolify(self) -> ValueRanges[SympyBoolean]:
+        if vr_is_bool(self):
+            return self
+        elif self == ValueRanges.unknown():
+            return ValueRanges.unknown_bool()
+        else:
+            raise AssertionError(f"not bool like {self}")
+
+    def __contains__(self, x: AllIn) -> bool:
+        return ValueRanges.wrap(x).issubset(self)
+
+    def issubset(self, other):
+        if other is self.unknown_int():
+            return True
+        return sympy_generic_le(other.lower, self.lower) and sympy_generic_le(
+            self.upper, other.upper
+        )
+
+    def tighten(self, other) -> ValueRanges:
+        """Given two ValueRanges, returns their intersection"""
+        return self & other
+
+    # Intersection
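+    # e.g. VR[0, 10] & VR[5, 20] == VR[5, 10]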
+    @overload
+    def __and__(
+        self: ValueRanges[sympy.Expr],
+        other: ValueRanges[sympy.Expr],
+    ) -> ValueRanges[sympy.Expr]:
+        ...
+
+    @overload
+    def __and__(  # type: ignore[misc]
+        self: ValueRanges[SympyBoolean],
+        other: ValueRanges[SympyBoolean],
+    ) -> ValueRanges[SympyBoolean]:
+        ...
+
+    def __and__(self: AllVR, other: AllVR) -> AllVR:
+        if other in (ValueRanges.unknown(), ValueRanges.unknown_int()):
+            return self
+        if self in (ValueRanges.unknown(), ValueRanges.unknown_int()):
+            return other
+        assert self.is_bool == other.is_bool, (self, other)
+        assert self.is_int == other.is_int, (self, other)
+        assert self.is_float == other.is_float, (self, other)
+        if self.is_bool:
+            return ValueRanges(
+                sympy.Or(self.lower, other.lower), sympy.And(self.upper, other.upper)
+            )
+        else:
+            return ValueRanges(
+                sympy.Max(self.lower, other.lower), sympy.Min(self.upper, other.upper)
+            )
+
+    # Union
+    @overload
+    def __or__(
+        self: ValueRanges[sympy.Expr],
+        other: ValueRanges[sympy.Expr],
+    ) -> ValueRanges[sympy.Expr]:
+        ...
+
+    @overload
+    def __or__(  # type: ignore[misc]
+        self: ValueRanges[SympyBoolean],
+        other: ValueRanges[SympyBoolean],
+    ) -> ValueRanges[SympyBoolean]:
+        ...
+
+    def __or__(self: AllVR, other: AllVR) -> AllVR:
+        if ValueRanges.unknown() in (self, other):
+            return ValueRanges.unknown()
+        assert self.is_bool == other.is_bool, (self, other)
+        assert self.is_int == other.is_int, (self, other)
+        assert self.is_float == other.is_float, (self, other)
+        if self.is_bool:
+            return ValueRanges(
+                sympy.And(self.lower, other.lower), sympy.Or(self.upper, other.upper)
+            )
+        else:
+            return ValueRanges(
+                sympy.Min(self.lower, other.lower), sympy.Max(self.upper, other.upper)
+            )
+
+    def is_singleton(self) -> bool:
+        return self.lower == self.upper
+
+    @staticmethod
+    @functools.cache
+    def unknown() -> ValueRanges[sympy.Expr]:
+        return ValueRanges(-sympy.oo, sympy.oo)
+
+    @staticmethod
+    @functools.cache
+    def unknown_int() -> ValueRanges[sympy.Expr]:
+        return ValueRanges(-int_oo, int_oo)
+
+    @staticmethod
+    @functools.cache
+    def unknown_bool() -> ValueRanges[SympyBoolean]:
+        return ValueRanges(sympy.false, sympy.true)
+
+    @overload
+    @staticmethod
+    # work around the fact that bool and int overlap
+    def wrap(arg: Union[ExprIn, ExprVR]) -> ExprVR:  # type: ignore[overload-overlap]
+        ...
+
+    @overload
+    @staticmethod
+    def wrap(arg: Union[BoolIn, BoolVR]) -> BoolVR:  # type: ignore[misc]
+        ...
+
+    @staticmethod
+    def wrap(arg: Union[AllIn, AllVR]) -> AllVR:
+        if isinstance(arg, ValueRanges):
+            return arg
+        if isinstance(arg, float) and math.isnan(arg):
+            return ValueRanges.unknown()
+        # arg is either ExprIn or BoolIn, but we don't know it here
+        return ValueRanges(arg, arg)  # type: ignore[arg-type]
+
+    @staticmethod
+    def increasing_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR:
+        """Increasing: x <= y => f(x) <= f(y)."""
+        x = ValueRanges.wrap(x)
+        return ValueRanges(fn(x.lower), fn(x.upper))
+
+    @overload
+    @staticmethod
+    def decreasing_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR:
+        ...
+
+    @overload
+    @staticmethod
+    def decreasing_map(x: Union[BoolIn, BoolVR], fn: BoolFn) -> BoolVR:  # type: ignore[misc]
+        ...
+
+    @staticmethod
+    def decreasing_map(x: Union[AllIn, AllVR], fn: AllFn) -> AllVR:
+        """Decreasing: x <= y => f(x) >= f(y)."""
+        x = ValueRanges.wrap(x)
+        # consistently either Expr or Bool, but we don't know it here
+        return ValueRanges(fn(x.upper), fn(x.lower))  # type: ignore[arg-type]
+
+    @staticmethod
+    def monotone_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR:
+        """It's increasing or decreasing."""
+        x = ValueRanges.wrap(x)
+        l = fn(x.lower)
+        u = fn(x.upper)
+        return ValueRanges(min(l, u), max(l, u))
+
+    @staticmethod
+    def convex_min_zero_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR:
+        """Fn is convex and has a minimum at 0."""
+        x = ValueRanges.wrap(x)
+        if 0 in x:
+            upper = max(fn(x.lower), fn(x.upper))
+            upper = simple_sympify(upper)
+            if isinstance(upper, sympy.Float) or upper == sympy.oo:
+                return ValueRanges(0.0, upper)
+            return ValueRanges(0, upper)
+        return ValueRanges.monotone_map(x, fn)
+
+    @overload
+    @staticmethod
+    def coordinatewise_increasing_map(
+        x: Union[ExprIn, ExprVR],
+        y: Union[ExprIn, ExprVR],
+        fn: ExprFn2,
+    ) -> ExprVR:
+        ...
+
+    @overload
+    @staticmethod
+    def coordinatewise_increasing_map(  # type: ignore[misc]
+        x: Union[BoolIn, BoolVR],
+        y: Union[BoolIn, BoolVR],
+        fn: BoolFn2,
+    ) -> BoolVR:
+        ...
+
+    @staticmethod
+    def coordinatewise_increasing_map(
+        x: Union[AllIn, AllVR],
+        y: Union[AllIn, AllVR],
+        fn: AllFn2,
+    ) -> AllVR:
+        """
+        It's increasing on each coordinate.
+
+        Mathematically:
+        For every 1 <= i <= n and xi <= yi we have that
+        f(x1, ..., xi, ..., xn) <= f(x1, ..., yi, ..., xn)
+        """
+        x, y = ValueRanges.wrap(x), ValueRanges.wrap(y)
+        return ValueRanges(
+            fn(x.lower, y.lower),  # type: ignore[arg-type]
+            fn(x.upper, y.upper),  # type: ignore[arg-type]
+        )
+
+    @classmethod
+    def coordinatewise_monotone_map(cls, x, y, fn):
+        """It's increasing or decreasing on each coordinate."""
+        x, y = cls.wrap(x), cls.wrap(y)
+        products = [
+            fn(a, b)
+            for a, b in itertools.product([x.lower, x.upper], [y.lower, y.upper])
+        ]
+        return ValueRanges(min(products), max(products))
+
+
+class SymPyValueRangeAnalysis:
+    """
+    Gives bounds on a SymPy operator given bounds on its arguments.
+    See `bound_sympy` for the function that applies this logic to a full SymPy expression.
+    """
+
+    @staticmethod
+    def constant(value, dtype):
+        if isinstance(value, ValueRanges):
+            assert value.is_singleton()
+            value = value.lower
+        # NB: value is NOT a sympy expression, it's a constant!
+        is_python = isinstance(value, (int, float, bool))
+        assert is_python or isinstance(
+            value, (BooleanAtom, sympy.Integer, sympy.Number)
+        )
+
+        # using nan makes subsequent computation throw, and for the purposes of optimization
+        # returning -math.inf - math.inf is equivalent to giving up
+        if isinstance(value, SupportsFloat) and math.isnan(value):
+            if dtype == torch.bool:
+                return ValueRanges.unknown_bool()
+            elif dtype.is_floating_point:
+                return ValueRanges.unknown()
+            else:
+                return ValueRanges.unknown_int()
+
+        if is_python:
+            type_ = dtype_to_type(dtype)
+            value = type_(value)
+        else:
+            # We do a type check on a best-effort basis
+            # We don't want to force a cast to sympy.Float if the value is Rational to avoid losing precision
+            if dtype == torch.bool:
+                assert isinstance(value, BooleanAtom)
+            elif dtype.is_floating_point:
+                assert not value.is_finite or value.is_real
+            else:
+                # dtype is intXX
+                assert value.is_integer
+
+        r = ValueRanges.wrap(value)
+        return r
+
+    @staticmethod
+    def to_dtype(a, dtype, src_dtype=None):
+        if dtype == torch.float64:
+            return ValueRanges.increasing_map(a, ToFloat)
+        elif dtype == torch.bool:
+            return ValueRanges.unknown_bool()
+        elif not dtype.is_floating_point:
+            return ValueRanges.unknown_int()
+        return ValueRanges.unknown()
+
+    @staticmethod
+    def trunc_to_int(a, dtype):
+        return ValueRanges.increasing_map(a, TruncToInt)
+
+    @staticmethod
+    def not_(a):
+        a = ValueRanges.wrap(a)
+        a = a.boolify()
+        assert a.is_bool
+        return ValueRanges.decreasing_map(a, sympy.Not)
+
+    @staticmethod
+    def or_(a, b):
+        return ValueRanges.coordinatewise_increasing_map(a, b, sympy.Or)
+
+    @staticmethod
+    def and_(a, b):
+        return ValueRanges.coordinatewise_increasing_map(a, b, sympy.And)
+
+    @staticmethod
+    def _bool_to_int(x):
+        if x.is_singleton():
+            return ValueRanges.wrap(sympy.Integer(1 if x.lower else 0))
+        else:
+            return ValueRanges(sympy.Integer(0), sympy.Integer(1))
+
+    @classmethod
+    def bitwise_and(cls, a, b):
+        a, b = ValueRanges.wrap(a), ValueRanges.wrap(b)
+        if a.is_bool and b.is_bool:
+            return cls.and_(a, b)
+        if a.is_bool:
+            a = cls._bool_to_int(a)
+        if b.is_bool:
+            b = cls._bool_to_int(b)
+        lower = min(a.lower, b.lower)
+        if lower < 0 and lower != -sympy.oo and lower != -int_oo:
+            # If both lower bounds are negative, then bits start like
+            # 1...10..., so the smallest possible value is 1...10...0.
+            # Thus, we need to find the next smallest power of 2 (inclusive).
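+            # For example, if lower == -5 then -lower - 1 == 4 has bit_length 3,
+            # so the bound becomes -(1 << 3) == -8.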
+            try:
+                lower = -(1 << int(-lower - 1).bit_length())
+            except Exception:
+                lower = -int_oo
+        else:
+            lower = 0
+        return ValueRanges(lower, max(a.upper, b.upper))
+
+    @classmethod
+    def bitwise_or(cls, a, b):
+        a, b = ValueRanges.wrap(a), ValueRanges.wrap(b)
+        if a.is_bool and b.is_bool:
+            return cls.or_(a, b)
+        if a.is_bool:
+            a = cls._bool_to_int(a)
+        if b.is_bool:
+            b = cls._bool_to_int(b)
+        upper = max(a.upper, b.upper)
+        if upper == 0:
+            upper = 0
+        elif upper > 0 and upper != sympy.oo and upper != int_oo:
+            # If both upper bounds are positive, then the largest
+            # possible value is 01...1, so we need to find
+            # next largest power of 2 (exclusive), minus 1
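+            # For example, if upper == 5 (bit_length 3), the bound becomes
+            # (1 << 3) - 1 == 7.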
+            try:
+                upper = (1 << int(upper).bit_length()) - 1
+            except Exception:
+                upper = int_oo
+        elif upper < 0:
+            upper = -1
+        return ValueRanges(min(a.lower, b.lower), upper)
+
+    @staticmethod
+    def eq(a, b):
+        a = ValueRanges.wrap(a)
+        b = ValueRanges.wrap(b)
+        if a.is_singleton() and b.is_singleton() and a.lower == b.lower:
+            return ValueRanges.wrap(sympy.true)
+        elif a.lower > b.upper or b.lower > a.upper:  # ranges disjoint
+            return ValueRanges.wrap(sympy.false)
+        return ValueRanges(sympy.false, sympy.true)
+
+    @classmethod
+    def ne(cls, a, b):
+        return cls.not_(cls.eq(a, b))
+
+    @classmethod
+    def identity(cls, a):
+        return ValueRanges.wrap(a)
+
+    @classmethod
+    def lt(cls, a, b):
+        a = ValueRanges.wrap(a)
+        b = ValueRanges.wrap(b)
+        assert a.is_bool == b.is_bool
+        if a.is_bool:
+            return cls.and_(cls.not_(a), b)
+        else:
+            if a.upper < b.lower:
+                return ValueRanges.wrap(sympy.true)
+            elif a.lower >= b.upper:
+                return ValueRanges.wrap(sympy.false)
+            return ValueRanges(sympy.false, sympy.true)
+
+    @classmethod
+    def gt(cls, a, b):
+        return cls.lt(b, a)
+
+    @classmethod
+    def le(cls, a, b):
+        return cls.not_(cls.gt(a, b))
+
+    @classmethod
+    def ge(cls, a, b):
+        return cls.not_(cls.lt(a, b))
+
+    @staticmethod
+    def add(a, b):
+        return ValueRanges.coordinatewise_increasing_map(
+            a, b, _keep_float(operator.add)
+        )
+
+    @classmethod
+    def mul(cls, a, b):
+        a = ValueRanges.wrap(a)
+        b = ValueRanges.wrap(b)
+
+        assert a.is_bool == b.is_bool
+        if a.is_bool:
+            return cls.and_(a, b)
+
+        def safe_mul(a, b):
+            # Make unknown() * wrap(0.0) == wrap(0.0)
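+            # (in sympy, oo * 0 evaluates to nan, which would poison the bound)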
+            if a == 0.0 or a == 0:
+                return a
+            elif b == 0.0 or b == 0:
+                return b
+            else:
+                return a * b
+
+        return ValueRanges.coordinatewise_monotone_map(a, b, _keep_float(safe_mul))
+
+    @staticmethod
+    def int_truediv(a, b):
+        a = ValueRanges.wrap(a)
+        b = ValueRanges.wrap(b)
+        if 0 in b or ((-int_oo in a or int_oo in a) and (-int_oo in b or int_oo in b)):
+            return ValueRanges.unknown()
+        else:
+            return ValueRanges.coordinatewise_monotone_map(
+                a, b, _keep_float(IntTrueDiv)
+            )
+
+    @staticmethod
+    def truediv(a, b):
+        a = ValueRanges.wrap(a)
+        b = ValueRanges.wrap(b)
+        if 0 in b or (
+            (-sympy.oo in a or sympy.oo in a) and (-sympy.oo in b or sympy.oo in b)
+        ):
+            return ValueRanges.unknown()
+        else:
+            return ValueRanges.coordinatewise_monotone_map(
+                a, b, _keep_float(FloatTrueDiv)
+            )
+
+    @staticmethod
+    def floordiv(a, b):
+        a = ValueRanges.wrap(a)
+        b = ValueRanges.wrap(b)
+        if 0 in b:
+            return ValueRanges.unknown_int()
+        products = []
+        for x, y in itertools.product([a.lower, a.upper], [b.lower, b.upper]):
+            r = FloorDiv(x, y)
+            if r is sympy.nan:
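+                # nan here typically means both operands were infinite; substitute
+                # the correspondingly signed infinity as a conservative bound.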
+                products.append((sympy.sign(x) * sympy.sign(y)) * int_oo)
+            else:
+                products.append(r)
+
+        return ValueRanges(min(products), max(products))
+
+    @classmethod
+    def mod(cls, x, y):
+        x = ValueRanges.wrap(x)
+        y = ValueRanges.wrap(y)
+        # nb. We implement C semantics
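+        # For example, c_mod(-7, 3) == -1 (the sign follows the dividend),
+        # whereas Python's -7 % 3 == 2.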
+
+        def c_mod(a, b):
+            ret = abs(a) % abs(b)
+            if a < 0:
+                ret *= -1
+            return ret
+
+        def c_div(a, b):
+            x = a / b
+            return sympy.Integer(x) if x.is_finite and x not in (int_oo, -int_oo) else x
+
+        if 0 in y:
+            return ValueRanges.unknown_int()
+        elif y.is_singleton():
+            y_val = abs(y.lower)
+            # If it wraps, we need to take the whole interval
+
+            # The function is locally linear if they are in the same class
+            if c_div(x.lower, y_val) == c_div(x.upper, y_val):
+                return ValueRanges.increasing_map(x, lambda u: c_mod(u, y_val))
+            if x.upper < 0:
+                # Negative case
+                return ValueRanges(-y_val + 1, 0)
+            elif x.lower > 0:
+                # Positive case
+                return ValueRanges(0, y_val - 1)
+            else:
+                # Mixed case
+                lower = max(-y_val + 1, x.lower)
+                upper = min(y_val - 1, x.upper)
+                return ValueRanges(lower, upper)
+        else:
+            # Too difficult, we bail out
+            upper = cls.abs(y).upper - 1
+            return ValueRanges(-upper, upper)
+
+    @classmethod
+    def modular_indexing(cls, a, b, c):
+        return cls.mod(cls.floordiv(a, b), c)
+
+    @classmethod
+    def is_non_overlapping_and_dense_indicator(cls, *args):
+        return ValueRanges.unknown_int()
+
+    @classmethod
+    def pow_by_natural(cls, a, b):
+        a = ValueRanges.wrap(a)
+        b = ValueRanges.wrap(b)
+        if a.is_singleton() and b.is_singleton():
+            return ValueRanges.wrap(safe_pow(a.lower, b.lower))
+        # NB: Exclude zero, because zero is special
+        elif a.lower >= 1:
+            # We should know that b >= 0 but we may have forgotten this fact due
+            # to replacements, so don't assert it, but DO clamp it to prevent
+            # degenerate problems
+            return ValueRanges.coordinatewise_increasing_map(
+                a, b & ValueRanges(0, int_oo), PowByNatural
+            )
+        elif b.is_singleton():
+            if b.lower % 2 == 0:
+                # x^n where n is even
+                return ValueRanges.convex_min_zero_map(
+                    a, lambda x: safe_pow(x, b.lower)
+                )
+            else:
+                # x^n where n is odd
+                return ValueRanges.increasing_map(a, lambda x: safe_pow(x, b.lower))
+        else:
+            # a is potentially negative, and we don't know if the exponent is
+            # even or odd.  So just conservatively set the upper and lower
+            # bound based on what the maximum absolute value could be, in both
+            # directions
+            max_base = max(a.upper, -a.lower)
+            return ValueRanges(
+                -(safe_pow(max_base, b.upper)), safe_pow(max_base, b.upper)
+            )
+
+    @classmethod
+    def pow(cls, a, b):
+        return ValueRanges.unknown()
+
+        # We could implement all this, but for floating point pow, is there
+        # really a point?
+        """
+        a = ValueRanges.wrap(a)
+        b = ValueRanges.wrap(b)
+
+        # Not implemented yet. It's a bit tricky
+        # If you want to implement it, compute the partial derivatives of a ** b
+        # and check the ranges where the function is increasing / decreasing
+        # Another non-tight way of doing this is noting that for a > 0, a ** b == exp(b * log(a))
+        # If this second option is implemented, be careful about the types and possible infinities here and there.
+        if not b.is_singleton():
+            return ValueRanges.unknown()
+
+        b = b.lower
+        if a.is_singleton():
+            a = a.lower
+            r = a**b
+            if not r.is_finite:
+                return ValueRanges.unknown()
+            return ValueRanges.wrap(r)
+
+        if b == 0:
+            if not a.lower.is_finite:
+                return ValueRanges.unknown()
+            return ValueRanges.wrap(1.0)
+
+        if b < 0:
+            a = cls.reciprocal(a)
+            b = -b
+
+        if a == ValueRanges.unknown():
+            return ValueRanges.unknown()
+
+        # If the base is positive, then we're good, otherwise nothing's defined
+        if a.lower >= 0:
+            return ValueRanges.increasing_map(a, lambda x: x**b)
+        else:
+            return ValueRanges.unknown()
+        """
+
+    @staticmethod
+    def reciprocal(x):
+        """Needed as it's used in pow, but it won't appear on a SymPy expression"""
+        x = ValueRanges.wrap(x)
+        if 0 in x:
+            return ValueRanges.unknown()
+        else:
+            return ValueRanges.decreasing_map(x, lambda y: FloatTrueDiv(1.0, y))  # type: ignore[operator]
+
+    @staticmethod
+    def abs(x):
+        return ValueRanges.convex_min_zero_map(x, abs)
+
+    @staticmethod
+    def exp(x):
+        return ValueRanges.increasing_map(x, OpaqueUnaryFn_exp)
+
+    @staticmethod
+    def log(x):
+        x = ValueRanges.wrap(x)
+        if x.lower <= 0:
+            return ValueRanges.unknown()
+        return ValueRanges.increasing_map(x, OpaqueUnaryFn_log)
+
+    @staticmethod
+    def log2(x):
+        x = ValueRanges.wrap(x)
+        if x.lower <= 0:
+            return ValueRanges.unknown()
+        return ValueRanges.increasing_map(x, OpaqueUnaryFn_log2)
+
+    @classmethod
+    def minimum(cls, a, b):
+        return cls.min_or_max(a, b, sympy.Min)
+
+    @classmethod
+    def maximum(cls, a, b):
+        return cls.min_or_max(a, b, sympy.Max)
+
+    @staticmethod
+    def min_or_max(a, b, fn):
+        a = ValueRanges.wrap(a)
+        b = ValueRanges.wrap(b)
+        return ValueRanges.coordinatewise_increasing_map(a, b, fn)
+
+    @classmethod
+    def floor_to_int(cls, x, dtype):
+        return ValueRanges.increasing_map(x, sympy.functions.elementary.integers.floor)
+
+    @classmethod
+    def ceil_to_int(cls, x, dtype):
+        return ValueRanges.increasing_map(
+            x, sympy.functions.elementary.integers.ceiling
+        )
+
+    # I think these implementations are sound.  The hazard here is that sympy
+    # will carry out the floor/ceil at too high precision and then something
+    # bad will happen when we convert it to float.
+    #
+    # For truncation, the implementation is clearly sound, because the desired
+    # target float is always exactly representable, since you're just chopping
+    # off bits of the mantissa.  But what about ceil/floor?
+    #
+    # The important constraint here is that we're not defining floor on
+    # arbitrary real numbers, only representable float numbers.  So we can
+    # take advantage of the fact that before we reach the first
+    # unrepresentable integer in floating point space, we have the range of
+    # numbers corresponding to exponent zero: all integers, with no fractional
+    # amounts.  floor/ceil is an identity operation in this case.  In the
+    # range below here, representable floating point numbers are spaced
+    # exactly 1/2 apart, and notably, both the floor/ceil are defined floating
+    # point numbers.  There is no "gap" as you step up to the next exponent.
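+    #
+    # Concretely, for float64 every integer of magnitude up to 2**53 is exactly
+    # representable, and below 2**52 adjacent floats are at most 1/2 apart, so
+    # the floor/ceil of a representable float is itself representable.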
+
+    @classmethod
+    def floor(cls, x):
+        return ValueRanges.increasing_map(
+            x, _keep_float(sympy.functions.elementary.integers.floor)
+        )
+
+    @classmethod
+    def ceil(cls, x):
+        return ValueRanges.increasing_map(
+            x, _keep_float(sympy.functions.elementary.integers.ceiling)
+        )
+
+    @classmethod
+    def round_decimal(cls, number, ndigits):
+        if not ndigits.is_singleton():
+            return ValueRanges.unknown()
+
+        ndigits = ndigits.lower
+        # We can't use functools.partial here since sympy doesn't support keyword arguments, but we have to bind
+        # the second parameter.
+        fn = lambda number: RoundDecimal(number, ndigits)  # type: ignore[misc, assignment]  # noqa: E731
+
+        return ValueRanges.increasing_map(number, fn)
+
+    @classmethod
+    def round_to_int(cls, number, dtype):
+        return ValueRanges.increasing_map(number, RoundToInt)
+
+    # It's used in some models on symints
+    @staticmethod
+    def sqrt(x):
+        x = ValueRanges.wrap(x)
+        if x.lower < 0:
+            return ValueRanges.unknown()
+        return ValueRanges.increasing_map(x, OpaqueUnaryFn_sqrt)
+
+    @staticmethod
+    def where(a, b, c):
+        b = ValueRanges.wrap(b)
+        c = ValueRanges.wrap(c)
+        a = a.boolify()
+        # We sometimes write unknown without specifying the type correctly
+        # In particular, we do that when initialising the bounds for loads in bounds.py
+        assert b.is_bool == c.is_bool or ValueRanges.unknown() in (b, c)
+        if b.is_bool:
+            return ValueRanges(sympy.And(b.lower, c.lower), sympy.Or(b.upper, c.upper))
+        else:
+            return ValueRanges(sympy.Min(b.lower, c.lower), sympy.Max(b.upper, c.upper))
+
+    # expr_cond_pair is used to represent a single (expr, condition) pair in piecewise.
+    # We just return the value range of the expression and its corresponding condition as a tuple
+    # and defer the analysis to piecewise
+    @staticmethod
+    def expr_cond_pair(a, b):
+        b = b.boolify()
+        return (a, b)
+
+    # piecewise function can be used to convert a SymBool to SymInt:
+    # int_expr = Piecewise((1, bool_expr), (0, True)); it evaluates to 1 when bool_expr is True and 0 otherwise.
+    #
+    # ranges is a sequence of (expr_range, condition_range) pairs. The range pair is constructed in expr_cond_pair.
+    # The ValueRange of Piecewise is just the union of all expr ranges whose condition expr can be True.
+    @staticmethod
+    def piecewise(*ranges):
+        init_range = None
+        for expr_range, cond_range in ranges:
+            if sympy.true in cond_range:
+                if init_range is None:
+                    init_range = expr_range
+                else:
+                    init_range = init_range | expr_range
+        return init_range
+
+    @staticmethod
+    def cos(x):
+        # TODO: We should tighten value ranges
+        # If the input range spans at least a full period (2*pi), the output range is [-1, 1];
+        # otherwise the bounds can be computed from the values at the endpoints
+        # and at any interior extrema.
+        return ValueRanges(-1.0, 1.0)
+
+    @staticmethod
+    def cosh(x):
+        return ValueRanges(0.0, sympy.oo)
+        """
+        x = ValueRanges.wrap(x)
+        if x.lower > 0:
+            return ValueRanges.increasing_map(x, OpaqueUnaryFn_cosh)
+        elif x.upper < 0:
+            return ValueRanges.decreasing_map(x, OpaqueUnaryFn_cosh)
+        return ValueRanges(0.0, sympy.oo)
+        """
+
+    @staticmethod
+    def sin(x):
+        # TODO: We should tighten value ranges
+        # See details on cos
+        return ValueRanges(-1.0, 1.0)
+
+    @staticmethod
+    def sinh(x):
+        # return ValueRanges.increasing_map(x, OpaqueUnaryFn_sinh)
+        return ValueRanges(-sympy.oo, sympy.oo)
+
+    @staticmethod
+    def tan(x):
+        return ValueRanges(-sympy.oo, sympy.oo)
+
+    @staticmethod
+    def tanh(x):
+        # return ValueRanges.increasing_map(x, OpaqueUnaryFn_tanh)
+        return ValueRanges(-sympy.oo, sympy.oo)
+
+    @staticmethod
+    def asin(x):
+        return ValueRanges(-sympy.oo, sympy.oo)
+        """
+        x = ValueRanges.wrap(x)
+        if -1 <= x.lower and x.upper <= 1:
+            return ValueRanges.increasing_map(x, OpaqueUnaryFn_asinh)
+        return ValueRanges.unknown()
+        """
+
+    @staticmethod
+    def acos(x):
+        return ValueRanges(-sympy.oo, sympy.oo)
+        """
+        x = ValueRanges.wrap(x)
+        if -1 <= x.lower and x.upper <= 1:
+            return ValueRanges.decreasing_map(x, OpaqueUnaryFn_acos)
+        return ValueRanges.unknown()
+        """
+
+    @staticmethod
+    def atan(x):
+        return ValueRanges(-sympy.oo, sympy.oo)
+        # return ValueRanges.increasing_map(x, OpaqueUnaryFn_atan)
+
+    @staticmethod
+    def trunc(x):
+        return ValueRanges.increasing_map(x, TruncToFloat)
+
+
+def bound_sympy(
+    expr: sympy.Expr, ranges: Optional[dict[sympy.Symbol, ValueRanges]] = None
+) -> ValueRanges:
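+    # Illustrative example (hypothetical symbol name): with
+    #   ranges = {s: ValueRanges(0, 10)}
+    # bound_sympy(s + 2, ranges) interprets the expression with
+    # SymPyValueRangeAnalysis and yields ValueRanges(2, 12).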
+    log.debug(
+        "bound_sympy(%s)%s",
+        expr,
+        LazyString(
+            lambda: (
+                "\n"
+                + "\n".join(
+                    f"  {k}: {r}" for k, r in ranges.items() if k in expr.free_symbols
+                )
+                if ranges
+                else ""
+            )
+        ),
+    )
+    if isinstance(expr, sympy.Number):
+        return ValueRanges.wrap(expr)
+
+    ranges = ranges or {}
+
+    # If there's a tracing context, augment available constrained ranges.
+    context = torch._guards.TracingContext.try_get()
+    if context and context.fake_mode.shape_env:
+        if ranges:
+            ranges = {**context.fake_mode.shape_env.var_to_range, **ranges}
+        else:
+            ranges = context.fake_mode.shape_env.var_to_range
+
+    def missing_handler(s):
+        if s.is_integer:  # type: ignore[attr-defined]
+            if s.is_positive:  # type: ignore[attr-defined]
+                vr = ValueRanges(1, int_oo)
+            elif s.is_nonnegative:  # type: ignore[attr-defined]
+                vr = ValueRanges(0, int_oo)
+            else:
+                vr = ValueRanges.unknown_int()
+        else:
+            # Don't bother trying very hard here
+            vr = ValueRanges.unknown()
+        return vr
+
+    return sympy_interp(
+        SymPyValueRangeAnalysis, ranges, expr, missing_handler=missing_handler
+    )
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/backcompat/__init__.py b/.venv/lib/python3.12/site-packages/torch/utils/backcompat/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8413b656e906eff2040ab9805844a73d5c4f0cc
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/backcompat/__init__.py
@@ -0,0 +1,27 @@
+# mypy: allow-untyped-defs
+from torch._C import (
+    _get_backcompat_broadcast_warn,
+    _get_backcompat_keepdim_warn,
+    _set_backcompat_broadcast_warn,
+    _set_backcompat_keepdim_warn,
+)
+
+
+class Warning:
+    def __init__(self, setter, getter):
+        self.setter = setter
+        self.getter = getter
+
+    def set_enabled(self, value):
+        self.setter(value)
+
+    def get_enabled(self):
+        return self.getter()
+
+    enabled = property(get_enabled, set_enabled)
+
+
+broadcast_warning = Warning(
+    _set_backcompat_broadcast_warn, _get_backcompat_broadcast_warn
+)
+keepdim_warning = Warning(_set_backcompat_keepdim_warn, _get_backcompat_keepdim_warn)
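+
+# Example usage (see the PyTorch docs on backwards-compatibility warnings):
+#   torch.utils.backcompat.broadcast_warning.enabled = True
+# enables warnings for code that relied on the old, pre-broadcasting behaviour.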
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/backcompat/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/backcompat/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0f6dc9ebde48f5c39b21f21daf15a4743e43888d
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/backcompat/__pycache__/__init__.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/__init__.py b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/compare.py b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/compare.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d797a5b0a2bfc7be3cecd13d4d1ad2ac4e52686
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/compare.py
@@ -0,0 +1,99 @@
+# mypy: allow-untyped-defs
+"""Example of Timer and Compare APIs:
+
+$ python -m examples.compare
+"""
+
+import pickle
+import sys
+import time
+
+import torch
+
+import torch.utils.benchmark as benchmark_utils
+
+
+class FauxTorch:
+    """Emulate different versions of pytorch.
+
+    In normal circumstances this would be done with multiple processes
+    writing serialized measurements, but this simplifies that model to
+    make the example clearer.
+    """
+    def __init__(self, real_torch, extra_ns_per_element):
+        self._real_torch = real_torch
+        self._extra_ns_per_element = extra_ns_per_element
+
+    def extra_overhead(self, result):
+        # time.sleep has a ~65 us overhead, so only fake a
+        # per-element overhead if numel is large enough.
+        numel = int(result.numel())
+        if numel > 5000:
+            time.sleep(numel * self._extra_ns_per_element * 1e-9)
+        return result
+
+    def add(self, *args, **kwargs):
+        return self.extra_overhead(self._real_torch.add(*args, **kwargs))
+
+    def mul(self, *args, **kwargs):
+        return self.extra_overhead(self._real_torch.mul(*args, **kwargs))
+
+    def cat(self, *args, **kwargs):
+        return self.extra_overhead(self._real_torch.cat(*args, **kwargs))
+
+    def matmul(self, *args, **kwargs):
+        return self.extra_overhead(self._real_torch.matmul(*args, **kwargs))
+
+
+def main():
+    tasks = [
+        ("add", "add", "torch.add(x, y)"),
+        ("add", "add (extra +0)", "torch.add(x, y + zero)"),
+    ]
+
+    serialized_results = []
+    repeats = 2
+    timers = [
+        benchmark_utils.Timer(
+            stmt=stmt,
+            globals={
+                "torch": torch if branch == "master" else FauxTorch(torch, overhead_ns),
+                "x": torch.ones((size, 4)),
+                "y": torch.ones((1, 4)),
+                "zero": torch.zeros(()),
+            },
+            label=label,
+            sub_label=sub_label,
+            description=f"size: {size}",
+            env=branch,
+            num_threads=num_threads,
+        )
+        for branch, overhead_ns in [("master", None), ("my_branch", 1), ("severe_regression", 5)]
+        for label, sub_label, stmt in tasks
+        for size in [1, 10, 100, 1000, 10000, 50000]
+        for num_threads in [1, 4]
+    ]
+
+    for i, timer in enumerate(timers * repeats):
+        serialized_results.append(pickle.dumps(
+            timer.blocked_autorange(min_run_time=0.05)
+        ))
+        print(f"\r{i + 1} / {len(timers) * repeats}", end="")
+        sys.stdout.flush()
+    print()
+
+    comparison = benchmark_utils.Compare([
+        pickle.loads(i) for i in serialized_results
+    ])
+
+    print("== Unformatted " + "=" * 80 + "\n" + "/" * 95 + "\n")
+    comparison.print()
+
+    print("== Formatted " + "=" * 80 + "\n" + "/" * 93 + "\n")
+    comparison.trim_significant_figures()
+    comparison.colorize()
+    comparison.print()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/fuzzer.py b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/fuzzer.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee2c9f9c04ed10b3dac5aa89d76b9d060421b664
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/fuzzer.py
@@ -0,0 +1,86 @@
+# mypy: allow-untyped-defs
+"""Example of the Timer and Fuzzer APIs:
+
+$ python -m examples.fuzzer
+"""
+
+import sys
+
+import torch.utils.benchmark as benchmark_utils
+
+
+def main():
+    add_fuzzer = benchmark_utils.Fuzzer(
+        parameters=[
+            [
+                benchmark_utils.FuzzedParameter(
+                    name=f"k{i}",
+                    minval=16,
+                    maxval=16 * 1024,
+                    distribution="loguniform",
+                ) for i in range(3)
+            ],
+            benchmark_utils.FuzzedParameter(
+                name="d",
+                distribution={2: 0.6, 3: 0.4},
+            ),
+        ],
+        tensors=[
+            [
+                benchmark_utils.FuzzedTensor(
+                    name=name,
+                    size=("k0", "k1", "k2"),
+                    dim_parameter="d",
+                    probability_contiguous=0.75,
+                    min_elements=64 * 1024,
+                    max_elements=128 * 1024,
+                ) for name in ("x", "y")
+            ],
+        ],
+        seed=0,
+    )
+
+    n = 250
+    measurements = []
+    for i, (tensors, tensor_properties, _) in enumerate(add_fuzzer.take(n=n)):
+        x, x_order = tensors["x"], str(tensor_properties["x"]["order"])
+        y, y_order = tensors["y"], str(tensor_properties["y"]["order"])
+        shape = ", ".join(tuple(f'{i:>4}' for i in x.shape))
+
+        description = "".join([
+            f"{x.numel():>7} | {shape:<16} | ",
+            f"{'contiguous' if x.is_contiguous() else x_order:<12} | ",
+            f"{'contiguous' if y.is_contiguous() else y_order:<12} | ",
+        ])
+
+        timer = benchmark_utils.Timer(
+            stmt="x + y",
+            globals=tensors,
+            description=description,
+        )
+
+        measurements.append(timer.blocked_autorange(min_run_time=0.1))
+        measurements[-1].metadata = {"numel": x.numel()}
+        print(f"\r{i + 1} / {n}", end="")
+        sys.stdout.flush()
+    print()
+
+    # More string munging to make pretty output.
+    print(f"Average attempts per valid config: {1. / (1. - add_fuzzer.rejection_rate):.1f}")
+
+    def time_fn(m):
+        return m.median / m.metadata["numel"]
+    measurements.sort(key=time_fn)
+
+    template = f"{{:>6}}{' ' * 19}Size    Shape{' ' * 13}X order        Y order\n{'-' * 80}"
+    print(template.format("Best:"))
+    for m in measurements[:15]:
+        print(f"{time_fn(m) * 1e9:>4.1f} ns / element     {m.description}")
+
+    print("\n" + template.format("Worst:"))
+    for m in measurements[-15:]:
+        print(f"{time_fn(m) * 1e9:>4.1f} ns / element     {m.description}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/op_benchmark.py b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/op_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdf3a7853d73783b65f3bfe9885a919e83f91c8f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/op_benchmark.py
@@ -0,0 +1,105 @@
+# mypy: allow-untyped-defs
+"""Example use of Timer and op fuzzers to measure kernel performance.
+
+$ python -m examples.op_benchmark
+"""
+
+import numpy as np
+import torch
+
+from torch.utils.benchmark import Timer
+from torch.utils.benchmark.op_fuzzers.binary import BinaryOpFuzzer
+from torch.utils.benchmark.op_fuzzers.unary import UnaryOpFuzzer
+import operator
+
+
+_MEASURE_TIME = 1.0
+
+
+def assert_dicts_equal(dict_0, dict_1):
+    """Builtin dict comparison will not compare numpy arrays.
+    e.g.
+        x = {"a": np.ones((2, 1))}
+        x == x  # Raises ValueError
+    """
+    assert set(dict_0.keys()) == set(dict_1.keys())
+    assert all(np.all(v == dict_1[k]) for k, v in dict_0.items() if k != "dtype")
+
+
+def run(n, stmt, fuzzer_cls):
+    float_iter = fuzzer_cls(seed=0, dtype=torch.float32).take(n)
+    int_iter = fuzzer_cls(seed=0, dtype=torch.int32).take(n)
+    raw_results = []
+    for i, (float_values, int_values) in enumerate(zip(float_iter, int_iter)):
+        float_tensors, float_tensor_params, float_params = float_values
+        int_tensors, int_tensor_params, int_params = int_values
+
+        # This benchmark assumes that the two fuzzers generate identically
+        # sized and strided Tensors, since the same seed is used.
+        assert_dicts_equal(float_params, int_params)
+        assert_dicts_equal(float_tensor_params["x"], int_tensor_params["x"])
+
+        float_measurement, int_measurement = (
+            Timer(
+                stmt,
+                globals=tensors,
+            ).blocked_autorange(min_run_time=_MEASURE_TIME)
+            for tensors in (float_tensors, int_tensors)
+        )
+
+        descriptions = []
+        for name in float_tensors:
+            shape_str = "(" + ", ".join([
+                f"2 ** {int(np.log2(i))}"
+                if 2 ** int(np.log2(i)) == i and i > 1
+                else str(i)
+                for i in float_tensors[name].shape
+            ]) + ")"
+            order = float_tensor_params[name]["order"]
+            order_str = ("" if all(order == np.arange(len(order))) else str(tuple(order)))
+            steps = float_tensor_params[name]["steps"]
+            steps_str = str(steps) if sum(steps) > len(steps) else ""
+            descriptions.append((name, shape_str, order_str, steps_str))
+        raw_results.append((float_measurement, int_measurement, descriptions))
+
+        print(f"\r{i + 1} / {n}", end="")
+    print()
+
+    parsed_results, name_len, shape_len, order_len, steps_len = [], 0, 0, 0, 0
+    for float_measurement, int_measurement, descriptions in raw_results:
+        t_float = float_measurement.median * 1e6
+        t_int = int_measurement.median * 1e6
+        rel_diff = abs(t_float - t_int) / (t_float + t_int) * 2
+        parsed_results.append((t_float, t_int, rel_diff, descriptions))
+        for name, shape, order, steps in descriptions:
+            name_len = max(name_len, len(name))
+            shape_len = max(shape_len, len(shape))
+            order_len = max(order_len, len(order))
+            steps_len = max(steps_len, len(steps))
+
+    parsed_results.sort(key=operator.itemgetter(2))
+
+    print(f"stmt: {stmt}")
+    print(f" diff    faster{'':>17}{' ' * name_len} ", end="")
+    print(f"{'shape'.ljust(shape_len)}{'':>16}{'order'.ljust(order_len)}", end="")
+    print(f"          steps\n{'-' * 100}")
+    for results, spacer in [(parsed_results[:10], "..."), (parsed_results[-10:], "")]:
+        for t_float, t_int, rel_diff, descriptions in results:
+            time_str = [f"{rel_diff * 100:>4.1f}%    {'int' if t_int < t_float else 'float':<20}"]
+            time_str.extend(["".ljust(len(time_str[0])) for _ in descriptions[:-1]])
+            for t_str, (name, shape, order, steps) in zip(time_str, descriptions):
+                name = f"{name}:".ljust(name_len + 1)
+                shape = shape.ljust(shape_len + 10)
+                order = order.ljust(order_len)
+                print(f"{t_str} {name}  {shape}|     {order}      |   {steps}")
+        print(spacer)
+
+
+def main():
+    run(n=100, stmt="torch.median(x, dim=0)", fuzzer_cls=UnaryOpFuzzer)
+    run(n=100, stmt="torch.square(x)", fuzzer_cls=UnaryOpFuzzer)
+    run(n=100, stmt="x + y", fuzzer_cls=BinaryOpFuzzer)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/simple_timeit.py b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/simple_timeit.py
new file mode 100644
index 0000000000000000000000000000000000000000..8137d4d8791975b46b1314c2f3a05ed048dbdcd3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/simple_timeit.py
@@ -0,0 +1,25 @@
+"""Trivial use of Timer API:
+
+$ python -m examples.simple_timeit
+"""
+
+import torch
+
+import torch.utils.benchmark as benchmark_utils
+
+
+def main() -> None:
+    timer = benchmark_utils.Timer(
+        stmt="x + y",
+        globals={"x": torch.ones((4, 8)), "y": torch.ones((1, 8))},
+        label="Broadcasting add (4x8)",
+    )
+
+    for i in range(3):
+        print(f"Run: {i}\n{'-' * 40}")
+        print(f"timeit:\n{timer.timeit(10000)}\n")
+        print(f"autorange:\n{timer.blocked_autorange()}\n\n")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/spectral_ops_fuzz_test.py b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/spectral_ops_fuzz_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3c8cbe5b12c2d238e6ca580228e94045b11ae48
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/spectral_ops_fuzz_test.py
@@ -0,0 +1,114 @@
+# mypy: allow-untyped-defs
+"""Microbenchmarks for the torch.fft module"""
+from argparse import ArgumentParser
+from collections import namedtuple
+from collections.abc import Iterable
+
+import torch
+import torch.fft
+from torch.utils import benchmark
+from torch.utils.benchmark.op_fuzzers.spectral import SpectralOpFuzzer
+
+
+def _dim_options(ndim):
+    if ndim == 1:
+        return [None]
+    elif ndim == 2:
+        return [0, 1, None]
+    elif ndim == 3:
+        return [0, 1, 2, (0, 1), (0, 2), None]
+    raise ValueError(f"Expected ndim in range 1-3, got {ndim}")
+
+
+def run_benchmark(name: str, function: object, dtype: torch.dtype, seed: int, device: str, samples: int,
+                  probability_regular: float):
+    cuda = device == 'cuda'
+    spectral_fuzzer = SpectralOpFuzzer(seed=seed, dtype=dtype, cuda=cuda,
+                                       probability_regular=probability_regular)
+    results = []
+    for tensors, tensor_params, params in spectral_fuzzer.take(samples):
+        shape = [params['k0'], params['k1'], params['k2']][:params['ndim']]
+        str_shape = ' x '.join([f"{s:<4}" for s in shape])
+        sub_label = f"{str_shape} {'' if tensor_params['x']['is_contiguous'] else '(discontiguous)'}"
+        for dim in _dim_options(params['ndim']):
+            for nthreads in (1, 4, 16) if not cuda else (1,):
+                measurement = benchmark.Timer(
+                    stmt='func(x, dim=dim)',
+                    globals={'func': function, 'x': tensors['x'], 'dim': dim},
+                    label=f"{name}_{device}",
+                    sub_label=sub_label,
+                    description=f"dim={dim}",
+                    num_threads=nthreads,
+                ).blocked_autorange(min_run_time=1)
+                measurement.metadata = {
+                    'name': name,
+                    'device': device,
+                    'dim': dim,
+                    'shape': shape,
+                }
+                measurement.metadata.update(tensor_params['x'])
+                results.append(measurement)
+    return results
+
+
+Benchmark = namedtuple('Benchmark', ['name', 'function', 'dtype'])
+BENCHMARKS = [
+    Benchmark('fft_real', torch.fft.fftn, torch.float32),
+    Benchmark('fft_complex', torch.fft.fftn, torch.complex64),
+    Benchmark('ifft', torch.fft.ifftn, torch.complex64),
+    Benchmark('rfft', torch.fft.rfftn, torch.float32),
+    Benchmark('irfft', torch.fft.irfftn, torch.complex64),
+]
+BENCHMARK_MAP = {b.name: b for b in BENCHMARKS}
+BENCHMARK_NAMES = [b.name for b in BENCHMARKS]
+DEVICE_NAMES = ['cpu', 'cuda']
+
+def _output_csv(file, results):
+    file.write('benchmark,device,num_threads,numel,shape,contiguous,dim,mean (us),median (us),iqr (us)\n')
+    for measurement in results:
+        metadata = measurement.metadata
+        device, dim, shape, name, numel, contiguous = (
+            metadata['device'], metadata['dim'], metadata['shape'],
+            metadata['name'], metadata['numel'], metadata['is_contiguous'])
+
+        if isinstance(dim, Iterable):
+            dim_str = '-'.join(str(d) for d in dim)
+        else:
+            dim_str = str(dim)
+        shape_str = 'x'.join(str(s) for s in shape)
+
+        print(name, device, measurement.task_spec.num_threads, numel, shape_str, contiguous, dim_str,
+              measurement.mean * 1e6, measurement.median * 1e6, measurement.iqr * 1e6,
+              sep=',', file=file)
+
+
+if __name__ == '__main__':
+    parser = ArgumentParser(description=__doc__)
+    parser.add_argument('--device', type=str, choices=DEVICE_NAMES, nargs='+', default=DEVICE_NAMES)
+    parser.add_argument('--bench', type=str, choices=BENCHMARK_NAMES, nargs='+', default=BENCHMARK_NAMES)
+    parser.add_argument('--seed', type=int, default=0)
+    parser.add_argument('--samples', type=int, default=10)
+    parser.add_argument('--probability-regular', '--probability_regular', type=float, default=1.0)
+    parser.add_argument('-o', '--output', type=str)
+    args = parser.parse_args()
+
+    num_benchmarks = len(args.device) * len(args.bench)
+    i = 0
+    results = []
+    for device in args.device:
+        for bench in (BENCHMARK_MAP[b] for b in args.bench):
+            results += run_benchmark(
+                name=bench.name, function=bench.function, dtype=bench.dtype,
+                seed=args.seed, device=device, samples=args.samples,
+                probability_regular=args.probability_regular)
+            i += 1
+            print(f'Completed {bench.name} benchmark on {device} ({i} of {num_benchmarks})')
+
+    if args.output is not None:
+        with open(args.output, 'w') as f:
+            _output_csv(f, results)
+
+    compare = benchmark.Compare(results)
+    compare.trim_significant_figures()
+    compare.colorize()
+    compare.print()
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/__init__.py b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/binary.py b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/binary.py
new file mode 100644
index 0000000000000000000000000000000000000000..75f394179b3e09a90882057346ed1737e3b84367
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/binary.py
@@ -0,0 +1,107 @@
+# mypy: allow-untyped-defs
+import numpy as np
+import torch
+
+from torch.utils.benchmark import Fuzzer, FuzzedParameter, ParameterAlias, FuzzedTensor
+
+
+_MIN_DIM_SIZE = 16
+_MAX_DIM_SIZE = 16 * 1024 ** 2
+_POW_TWO_SIZES = tuple(2 ** i for i in range(
+    int(np.log2(_MIN_DIM_SIZE)),
+    int(np.log2(_MAX_DIM_SIZE)) + 1,
+))
+
+
+class BinaryOpFuzzer(Fuzzer):
+    def __init__(self, seed, dtype=torch.float32, cuda=False):
+        super().__init__(
+            parameters=[
+                # Dimensionality of x and y. (e.g. 1D, 2D, or 3D.)
+                FuzzedParameter("dim", distribution={1: 0.3, 2: 0.4, 3: 0.3}, strict=True),
+
+                # Shapes for `x` and `y`.
+                #       It is important to test all shapes, however
+                #   powers of two are especially important and therefore
+                #   warrant special attention. This is done by generating
+                #   both a value drawn from all integers between the min and
+                #   max allowed values, and another from only the powers of two
+                #   (both distributions are loguniform) and then randomly
+                #   selecting between the two.
+                #       Moreover, `y` will occasionally have singleton
+                #   dimensions in order to test broadcasting.
+                [
+                    FuzzedParameter(
+                        name=f"k_any_{i}",
+                        minval=_MIN_DIM_SIZE,
+                        maxval=_MAX_DIM_SIZE,
+                        distribution="loguniform",
+                    ) for i in range(3)
+                ],
+                [
+                    FuzzedParameter(
+                        name=f"k_pow2_{i}",
+                        distribution={size: 1. / len(_POW_TWO_SIZES) for size in _POW_TWO_SIZES}
+                    ) for i in range(3)
+                ],
+                [
+                    FuzzedParameter(
+                        name=f"k{i}",
+                        distribution={
+                            ParameterAlias(f"k_any_{i}"): 0.8,
+                            ParameterAlias(f"k_pow2_{i}"): 0.2,
+                        },
+                        strict=True,
+                    ) for i in range(3)
+                ],
+
+                [
+                    FuzzedParameter(
+                        name=f"y_k{i}",
+                        distribution={
+                            ParameterAlias(f"k{i}"): 0.8,
+                            1: 0.2,
+                        },
+                        strict=True,
+                    ) for i in range(3)
+                ],
+
+                # Steps for `x` and `y`. (Benchmarks strided memory access.)
+                [
+                    FuzzedParameter(
+                        name=f"{name}_step_{i}",
+                        distribution={1: 0.8, 2: 0.06, 4: 0.06, 8: 0.04, 16: 0.04},
+                    )
+                    for i in range(3)
+                    for name in ("x", "y")
+                ],
+
+                # Repeatable entropy for downstream applications.
+                FuzzedParameter(name="random_value", minval=0, maxval=2 ** 32 - 1, distribution="uniform"),
+            ],
+            tensors=[
+                FuzzedTensor(
+                    name="x",
+                    size=("k0", "k1", "k2"),
+                    steps=("x_step_0", "x_step_1", "x_step_2"),
+                    probability_contiguous=0.75,
+                    min_elements=4 * 1024,
+                    max_elements=32 * 1024 ** 2,
+                    max_allocation_bytes=2 * 1024**3,  # 2 GB
+                    dim_parameter="dim",
+                    dtype=dtype,
+                    cuda=cuda,
+                ),
+                FuzzedTensor(
+                    name="y",
+                    size=("y_k0", "y_k1", "y_k2"),
+                    steps=("x_step_0", "x_step_1", "x_step_2"),
+                    probability_contiguous=0.75,
+                    max_allocation_bytes=2 * 1024**3,  # 2 GB
+                    dim_parameter="dim",
+                    dtype=dtype,
+                    cuda=cuda,
+                ),
+            ],
+            seed=seed,
+        )
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/sparse_binary.py b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/sparse_binary.py
new file mode 100644
index 0000000000000000000000000000000000000000..014361877dea148064fb71f1b504988d8eebbb17
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/sparse_binary.py
@@ -0,0 +1,107 @@
+# mypy: allow-untyped-defs
+import numpy as np
+import torch
+
+from torch.utils.benchmark import Fuzzer, FuzzedParameter, ParameterAlias, FuzzedSparseTensor
+
+
+_MIN_DIM_SIZE = 16
+_MAX_DIM_SIZE = 16 * 1024 ** 2
+_POW_TWO_SIZES = tuple(2 ** i for i in range(
+    int(np.log2(_MIN_DIM_SIZE)),
+    int(np.log2(_MAX_DIM_SIZE)) + 1,
+))
+
+
+class BinaryOpSparseFuzzer(Fuzzer):
+    def __init__(self, seed, dtype=torch.float32, cuda=False):
+        super().__init__(
+            parameters=[
+                # Dimensionality of x and y. (e.g. 1D, 2D, or 3D.)
+                FuzzedParameter("dim_parameter", distribution={1: 0.3, 2: 0.4, 3: 0.3}, strict=True),
+                FuzzedParameter(
+                    name="sparse_dim",
+                    distribution={1: 0.4, 2: 0.4, 3: 0.2},
+                    strict=True
+                ),
+                # Shapes for `x` and `y`.
+                #       It is important to test all shapes, however
+                #   powers of two are especially important and therefore
+                #   warrant special attention. This is done by generating
+                #   both a value drawn from all integers between the min and
+                #   max allowed values, and another from only the powers of two
+                #   (both distributions are loguniform) and then randomly
+                #   selecting between the two.
+                #       Moreover, `y` will occasionally have singleton
+                #   dimensions in order to test broadcasting.
+                [
+                    FuzzedParameter(
+                        name=f"k_any_{i}",
+                        minval=_MIN_DIM_SIZE,
+                        maxval=_MAX_DIM_SIZE,
+                        distribution="loguniform",
+                    ) for i in range(3)
+                ],
+                [
+                    FuzzedParameter(
+                        name=f"k_pow2_{i}",
+                        distribution={size: 1. / len(_POW_TWO_SIZES) for size in _POW_TWO_SIZES}
+                    ) for i in range(3)
+                ],
+                [
+                    FuzzedParameter(
+                        name=f"k{i}",
+                        distribution={
+                            ParameterAlias(f"k_any_{i}"): 0.8,
+                            ParameterAlias(f"k_pow2_{i}"): 0.2,
+                        },
+                        strict=True,
+                    ) for i in range(3)
+                ],
+                [
+                    FuzzedParameter(
+                        name=f"y_k{i}",
+                        distribution={
+                            ParameterAlias(f"k{i}"): 1.0},
+                        strict=True,
+                    ) for i in range(3)
+                ],
+                FuzzedParameter(
+                    name="density",
+                    distribution={0.1: 0.4, 0.05: 0.3, 0.01: 0.3},
+                ),
+                FuzzedParameter(
+                    name="coalesced",
+                    distribution={True: 0.5, False: 0.5},
+                ),
+                # Repeatable entropy for downstream applications.
+                FuzzedParameter(name="random_value", minval=0, maxval=2 ** 32 - 1, distribution="uniform"),
+            ],
+            tensors=[
+                FuzzedSparseTensor(
+                    name="x",
+                    size=("k0", "k1", "k2"),
+                    dim_parameter="dim_parameter",
+                    sparse_dim="sparse_dim",
+                    density="density",
+                    coalesced="coalesced",
+                    min_elements=4 * 1024,
+                    max_elements=32 * 1024 ** 2,
+                    dtype=dtype,
+                    cuda=cuda,
+                ),
+                FuzzedSparseTensor(
+                    name="y",
+                    size=("y_k0", "y_k1", "y_k2"),
+                    dim_parameter="dim_parameter",
+                    sparse_dim="sparse_dim",
+                    density="density",
+                    coalesced="coalesced",
+                    min_elements=4 * 1024,
+                    max_elements=32 * 1024 ** 2,
+                    dtype=dtype,
+                    cuda=cuda,
+                ),
+            ],
+            seed=seed,
+        )
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/sparse_unary.py b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/sparse_unary.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6fe622183f68f02770507250d41f280e88a1d92
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/sparse_unary.py
@@ -0,0 +1,83 @@
+# mypy: allow-untyped-defs
+
+import numpy as np
+import torch
+from torch.utils.benchmark import Fuzzer, FuzzedParameter, ParameterAlias, FuzzedSparseTensor
+
+
+_MIN_DIM_SIZE = 16
+_MAX_DIM_SIZE = 16 * 1024 ** 2
+_POW_TWO_SIZES = tuple(2 ** i for i in range(
+    int(np.log2(_MIN_DIM_SIZE)),
+    int(np.log2(_MAX_DIM_SIZE)) + 1,
+))
+
+class UnaryOpSparseFuzzer(Fuzzer):
+    def __init__(self, seed, dtype=torch.float32, cuda=False):
+        super().__init__(
+            parameters=[
+                # Sparse dim parameter of x. (e.g. 1D, 2D, or 3D.)
+                FuzzedParameter("dim_parameter", distribution={1: 0.3, 2: 0.4, 3: 0.3}, strict=True),
+                FuzzedParameter(
+                    name="sparse_dim",
+                    distribution={1: 0.4, 2: 0.4, 3: 0.2},
+                    strict=True
+                ),
+                # Shapes for `x`.
+                #   It is important to test all shapes, however
+                #   powers of two are especially important and therefore
+                #   warrant special attention. This is done by generating
+                #   both a value drawn from all integers between the min and
+                #   max allowed values, and another from only the powers of two
+                #   (both distributions are loguniform) and then randomly
+                #   selecting between the two.
+                [
+                    FuzzedParameter(
+                        name=f"k_any_{i}",
+                        minval=_MIN_DIM_SIZE,
+                        maxval=_MAX_DIM_SIZE,
+                        distribution="loguniform",
+                    ) for i in range(3)
+                ],
+                [
+                    FuzzedParameter(
+                        name=f"k_pow2_{i}",
+                        distribution={size: 1. / len(_POW_TWO_SIZES) for size in _POW_TWO_SIZES}
+                    ) for i in range(3)
+                ],
+                [
+                    FuzzedParameter(
+                        name=f"k{i}",
+                        distribution={
+                            ParameterAlias(f"k_any_{i}"): 0.8,
+                            ParameterAlias(f"k_pow2_{i}"): 0.2,
+                        },
+                        strict=True,
+                    ) for i in range(3)
+                ],
+                FuzzedParameter(
+                    name="density",
+                    distribution={0.1: 0.4, 0.05: 0.3, 0.01: 0.3},
+                ),
+                FuzzedParameter(
+                    name="coalesced",
+                    distribution={True: 0.5, False: 0.5},
+                ),
+                FuzzedParameter(name="random_value", minval=0, maxval=2 ** 32 - 1, distribution="uniform"),
+            ],
+            tensors=[
+                FuzzedSparseTensor(
+                    name="x",
+                    size=("k0", "k1", "k2"),
+                    dim_parameter="dim_parameter",
+                    sparse_dim="sparse_dim",
+                    min_elements=4 * 1024,
+                    max_elements=32 * 1024 ** 2,
+                    density="density",
+                    coalesced="coalesced",
+                    dtype=dtype,
+                    cuda=cuda,
+                ),
+            ],
+            seed=seed,
+        )
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/spectral.py b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/spectral.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b9e92d7a2c7b13a955d2dd9ef2a26ec8e574903
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/spectral.py
@@ -0,0 +1,94 @@
+# mypy: allow-untyped-defs
+import math
+
+import torch
+from torch.utils import benchmark
+from torch.utils.benchmark import FuzzedParameter, FuzzedTensor, ParameterAlias
+
+
+__all__ = ['SpectralOpFuzzer']
+
+MIN_DIM_SIZE = 16
+MAX_DIM_SIZE = 16 * 1024
+
+def power_range(upper_bound, base):
+    return (base ** i for i in range(int(math.log(upper_bound, base)) + 1))
+
+# List of regular numbers from MIN_DIM_SIZE to MAX_DIM_SIZE
+# These numbers factorize into multiples of prime factors 2, 3, and 5 only
+# and are usually the fastest in FFT implementations.
+REGULAR_SIZES = []
+for i in power_range(MAX_DIM_SIZE, 2):
+    for j in power_range(MAX_DIM_SIZE // i, 3):
+        ij = i * j
+        for k in power_range(MAX_DIM_SIZE // ij, 5):
+            ijk = ij * k
+            if ijk > MIN_DIM_SIZE:
+                REGULAR_SIZES.append(ijk)
+REGULAR_SIZES.sort()
+
+class SpectralOpFuzzer(benchmark.Fuzzer):
+    def __init__(self, *, seed: int, dtype=torch.float64,
+                 cuda: bool = False, probability_regular: float = 1.0):
+        super().__init__(
+            parameters=[
+                # Dimensionality of x. (e.g. 1D, 2D, or 3D.)
+                FuzzedParameter("ndim", distribution={1: 0.3, 2: 0.4, 3: 0.3}, strict=True),
+
+                # Shapes for `x`.
+                #   It is important to test all shapes; however,
+                #   regular sizes are especially important to the FFT and therefore
+                #   warrant special attention. This is done by generating
+                #   both a value drawn from all integers between the min and
+                #   max allowed values, and another from only the regular numbers
+                #   (both distributions are loguniform) and then randomly
+                #   selecting between the two.
+                [
+                    FuzzedParameter(
+                        name=f"k_any_{i}",
+                        minval=MIN_DIM_SIZE,
+                        maxval=MAX_DIM_SIZE,
+                        distribution="loguniform",
+                    ) for i in range(3)
+                ],
+                [
+                    FuzzedParameter(
+                        name=f"k_regular_{i}",
+                        distribution={size: 1. / len(REGULAR_SIZES) for size in REGULAR_SIZES}
+                    ) for i in range(3)
+                ],
+                [
+                    FuzzedParameter(
+                        name=f"k{i}",
+                        distribution={
+                            ParameterAlias(f"k_regular_{i}"): probability_regular,
+                            ParameterAlias(f"k_any_{i}"): 1 - probability_regular,
+                        },
+                        strict=True,
+                    ) for i in range(3)
+                ],
+
+                # Steps for `x`. (Benchmarks strided memory access.)
+                [
+                    FuzzedParameter(
+                        name=f"step_{i}",
+                        distribution={1: 0.8, 2: 0.06, 4: 0.06, 8: 0.04, 16: 0.04},
+                    ) for i in range(3)
+                ],
+            ],
+            tensors=[
+                FuzzedTensor(
+                    name="x",
+                    size=("k0", "k1", "k2"),
+                    steps=("step_0", "step_1", "step_2"),
+                    probability_contiguous=0.75,
+                    min_elements=4 * 1024,
+                    max_elements=32 * 1024 ** 2,
+                    max_allocation_bytes=2 * 1024**3,  # 2 GB
+                    dim_parameter="ndim",
+                    dtype=dtype,
+                    cuda=cuda,
+                ),
+            ],
+            seed=seed,
+        )
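
As a hedged illustration of how SpectralOpFuzzer and probability_regular are meant to be driven (an assumed driver, not part of this patch), the sketch below times torch.fft.fftn on a few fuzzed inputs; with probability_regular=0.75 each dimension is a 2/3/5-smooth "regular" size three quarters of the time.

# Usage sketch; the driver loop is an assumption, the fuzzer API is as defined above.
import torch
from torch.utils.benchmark import Timer
from torch.utils.benchmark.op_fuzzers.spectral import SpectralOpFuzzer

fuzzer = SpectralOpFuzzer(seed=0, probability_regular=0.75)
for tensors, _, params in fuzzer.take(3):
    shape = [params[f"k{i}"] for i in range(params["ndim"])]
    timer = Timer(
        stmt="torch.fft.fftn(x)",
        globals={"x": tensors["x"], "torch": torch},
        sub_label=f"shape={shape}",
    )
    print(timer.blocked_autorange(min_run_time=0.1))
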
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/unary.py b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/unary.py
new file mode 100644
index 0000000000000000000000000000000000000000..e780b421f24c8c9a68d9196036dc925fa634eccb
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/op_fuzzers/unary.py
@@ -0,0 +1,82 @@
+# mypy: allow-untyped-defs
+import numpy as np
+import torch
+
+from torch.utils.benchmark import Fuzzer, FuzzedParameter, ParameterAlias, FuzzedTensor
+
+
+_MIN_DIM_SIZE = 16
+_MAX_DIM_SIZE = 16 * 1024 ** 2
+_POW_TWO_SIZES = tuple(2 ** i for i in range(
+    int(np.log2(_MIN_DIM_SIZE)),
+    int(np.log2(_MAX_DIM_SIZE)) + 1,
+))
+
+
+class UnaryOpFuzzer(Fuzzer):
+    def __init__(self, seed, dtype=torch.float32, cuda=False):
+        super().__init__(
+            parameters=[
+                # Dimensionality of x. (e.g. 1D, 2D, or 3D.)
+                FuzzedParameter("dim", distribution={1: 0.3, 2: 0.4, 3: 0.3}, strict=True),
+
+                # Shapes for `x`.
+                #   It is important to test all shapes; however,
+                #   powers of two are especially important and therefore
+                #   warrant special attention. This is done by generating
+                #   both a value drawn from all integers between the min and
+                #   max allowed values, and another from only the powers of two
+                #   (both distributions are loguniform) and then randomly
+                #   selecting between the two.
+                [
+                    FuzzedParameter(
+                        name=f"k_any_{i}",
+                        minval=_MIN_DIM_SIZE,
+                        maxval=_MAX_DIM_SIZE,
+                        distribution="loguniform",
+                    ) for i in range(3)
+                ],
+                [
+                    FuzzedParameter(
+                        name=f"k_pow2_{i}",
+                        distribution={size: 1. / len(_POW_TWO_SIZES) for size in _POW_TWO_SIZES}
+                    ) for i in range(3)
+                ],
+                [
+                    FuzzedParameter(
+                        name=f"k{i}",
+                        distribution={
+                            ParameterAlias(f"k_any_{i}"): 0.8,
+                            ParameterAlias(f"k_pow2_{i}"): 0.2,
+                        },
+                        strict=True,
+                    ) for i in range(3)
+                ],
+
+                # Steps for `x`. (Benchmarks strided memory access.)
+                [
+                    FuzzedParameter(
+                        name=f"x_step_{i}",
+                        distribution={1: 0.8, 2: 0.06, 4: 0.06, 8: 0.04, 16: 0.04},
+                    ) for i in range(3)
+                ],
+
+                # Repeatable entropy for downstream applications.
+                FuzzedParameter(name="random_value", minval=0, maxval=2 ** 32 - 1, distribution="uniform"),
+            ],
+            tensors=[
+                FuzzedTensor(
+                    name="x",
+                    size=("k0", "k1", "k2"),
+                    steps=("x_step_0", "x_step_1", "x_step_2"),
+                    probability_contiguous=0.75,
+                    min_elements=4 * 1024,
+                    max_elements=32 * 1024 ** 2,
+                    max_allocation_bytes=2 * 1024**3,  # 2 GB
+                    dim_parameter="dim",
+                    dtype=dtype,
+                    cuda=cuda,
+                ),
+            ],
+            seed=seed,
+        )
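
The k_any/k_pow2/k{i} construction above is a general pattern: two candidate parameters are sampled independently, and a third strict parameter aliases one of them with fixed probabilities. A stripped-down sketch of the same idea follows; the names and bounds are illustrative only, not part of the patch.

# Minimal sketch of mixing two distributions with ParameterAlias; names and
# bounds here are hypothetical.
from torch.utils.benchmark import Fuzzer, FuzzedParameter, FuzzedTensor, ParameterAlias

fuzzer = Fuzzer(
    parameters=[
        # Any integer in [16, 1024], log-uniform.
        FuzzedParameter("n_any", minval=16, maxval=1024, distribution="loguniform"),
        # Powers of two in the same range, uniform over the choices.
        FuzzedParameter("n_pow2", distribution={2 ** i: 1.0 / 7 for i in range(4, 11)}),
        # `n` resolves to `n_any` 80% of the time and to `n_pow2` otherwise.
        FuzzedParameter("n", distribution={
            ParameterAlias("n_any"): 0.8,
            ParameterAlias("n_pow2"): 0.2,
        }, strict=True),
    ],
    tensors=[FuzzedTensor("x", size=("n",), min_elements=16, max_elements=1024)],
    seed=0,
)
for tensors, _, params in fuzzer.take(5):
    print(params["n"], tuple(tensors["x"].shape))
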
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/__init__.py b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/_stubs.py b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/_stubs.py
new file mode 100644
index 0000000000000000000000000000000000000000..068e62ec87a3deb60260a9a9b7b7205a49d7ddfe
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/_stubs.py
@@ -0,0 +1,41 @@
+from typing import Any, Callable
+from typing_extensions import Protocol, runtime_checkable
+
+
+class TimerClass(Protocol):
+    """This is the portion of the `timeit.Timer` API used by benchmark utils."""
+    def __init__(
+        self,
+        stmt: str,
+        setup: str,
+        timer: Callable[[], float],
+        globals: dict[str, Any],
+        **kwargs: Any,
+    ) -> None:
+        ...
+
+    def timeit(self, number: int) -> float:
+        ...
+
+
+@runtime_checkable
+class TimeitModuleType(Protocol):
+    """Modules generated from `timeit_template.cpp`."""
+    def timeit(self, number: int) -> float:
+        ...
+
+
+class CallgrindModuleType(Protocol):
+    """Replicates the valgrind endpoints in `torch._C`.
+
+    These bindings are used to collect Callgrind profiles on earlier versions
+    of PyTorch and will eventually be removed.
+    """
+    __file__: str
+    __name__: str
+
+    def _valgrind_supported_platform(self) -> bool:
+        ...
+
+    def _valgrind_toggle(self) -> None:
+        ...
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/common.py b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..e25909f6c85ebcec237127aa0f5497cbd0c97908
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/common.py
@@ -0,0 +1,356 @@
+"""Base shared classes and utilities."""
+
+import collections
+import contextlib
+import dataclasses
+import os
+import shutil
+import tempfile
+import textwrap
+import time
+from typing import cast, Any, Optional
+from collections.abc import Iterable, Iterator
+import uuid
+
+import torch
+
+
+__all__ = ["TaskSpec", "Measurement", "select_unit", "unit_to_english", "trim_sigfig", "ordered_unique", "set_torch_threads"]
+
+
+_MAX_SIGNIFICANT_FIGURES = 4
+_MIN_CONFIDENCE_INTERVAL = 25e-9  # 25 ns
+
+# Measurement will include a warning if the distribution is suspect. All
+# runs are expected to have some variation; these parameters set the
+# thresholds.
+_IQR_WARN_THRESHOLD = 0.1
+_IQR_GROSS_WARN_THRESHOLD = 0.25
+
+
+@dataclasses.dataclass(init=True, repr=False, eq=True, frozen=True)
+class TaskSpec:
+    """Container for information used to define a Timer. (except globals)"""
+    stmt: str
+    setup: str
+    global_setup: str = ""
+    label: Optional[str] = None
+    sub_label: Optional[str] = None
+    description: Optional[str] = None
+    env: Optional[str] = None
+    num_threads: int = 1
+
+    @property
+    def title(self) -> str:
+        """Best effort attempt at a string label for the measurement."""
+        if self.label is not None:
+            return self.label + (f": {self.sub_label}" if self.sub_label else "")
+        elif "\n" not in self.stmt:
+            return self.stmt + (f": {self.sub_label}" if self.sub_label else "")
+        return (
+            f"stmt:{f' ({self.sub_label})' if self.sub_label else ''}\n"
+            f"{textwrap.indent(self.stmt, '  ')}"
+        )
+
+    def setup_str(self) -> str:
+        return (
+            "" if (self.setup == "pass" or not self.setup)
+            else f"setup:\n{textwrap.indent(self.setup, '  ')}" if "\n" in self.setup
+            else f"setup: {self.setup}"
+        )
+
+    def summarize(self) -> str:
+        """Build TaskSpec portion of repr string for other containers."""
+        sections = [
+            self.title,
+            self.description or "",
+            self.setup_str(),
+        ]
+        return "\n".join([f"{i}\n" if "\n" in i else i for i in sections if i])
+
+_TASKSPEC_FIELDS = tuple(i.name for i in dataclasses.fields(TaskSpec))
+
+
+@dataclasses.dataclass(init=True, repr=False)
+class Measurement:
+    """The result of a Timer measurement.
+
+    This class stores one or more measurements of a given statement. It is
+    serializable and provides several convenience methods
+    (including a detailed __repr__) for downstream consumers.
+    """
+    number_per_run: int
+    raw_times: list[float]
+    task_spec: TaskSpec
+    metadata: Optional[dict[Any, Any]] = None  # Reserved for user payloads.
+
+    def __post_init__(self) -> None:
+        self._sorted_times: tuple[float, ...] = ()
+        self._warnings: tuple[str, ...] = ()
+        self._median: float = -1.0
+        self._mean: float = -1.0
+        self._p25: float = -1.0
+        self._p75: float = -1.0
+
+    def __getattr__(self, name: str) -> Any:
+        # Forward TaskSpec fields for convenience.
+        if name in _TASKSPEC_FIELDS:
+            return getattr(self.task_spec, name)
+        return super().__getattribute__(name)
+
+    # =========================================================================
+    # == Convenience methods for statistics ===================================
+    # =========================================================================
+    #
+    # These methods use raw time divided by number_per_run; this is an
+    # extrapolation and hides the fact that different number_per_run values
+    # amortize overheads differently. However, if Timer has selected an
+    # appropriate number_per_run then this is a non-issue, and forcing users
+    # to handle that division would result in a poor experience.
+    @property
+    def times(self) -> list[float]:
+        return [t / self.number_per_run for t in self.raw_times]
+
+    @property
+    def median(self) -> float:
+        self._lazy_init()
+        return self._median
+
+    @property
+    def mean(self) -> float:
+        self._lazy_init()
+        return self._mean
+
+    @property
+    def iqr(self) -> float:
+        self._lazy_init()
+        return self._p75 - self._p25
+
+    @property
+    def significant_figures(self) -> int:
+        """Approximate significant figure estimate.
+
+        This property is intended to give a convenient way to estimate the
+        precision of a measurement. It only uses the interquartile region to
+        estimate statistics to try to mitigate skew from the tails, and
+        uses a static z value of 1.645 since it is not expected to be used
+        for small values of `n`, so z can approximate `t`.
+
+        The significant figure estimation is used in conjunction with the
+        `trim_sigfig` method to provide a more human-interpretable data
+        summary. __repr__ does not use this method; it simply displays raw
+        values. Significant figure estimation is intended for `Compare`.
+        """
+        self._lazy_init()
+        n_total = len(self._sorted_times)
+        lower_bound = int(n_total // 4)
+        upper_bound = int(torch.tensor(3 * n_total / 4).ceil())
+        interquartile_points: tuple[float, ...] = self._sorted_times[lower_bound:upper_bound]
+        std = torch.tensor(interquartile_points).std(unbiased=False).item()
+        sqrt_n = torch.tensor(len(interquartile_points)).sqrt().item()
+
+        # Rough estimates. These are by no means statistically rigorous.
+        confidence_interval = max(1.645 * std / sqrt_n, _MIN_CONFIDENCE_INTERVAL)
+        relative_ci = torch.tensor(self._median / confidence_interval).log10().item()
+        num_significant_figures = int(torch.tensor(relative_ci).floor())
+        return min(max(num_significant_figures, 1), _MAX_SIGNIFICANT_FIGURES)
+
+    @property
+    def has_warnings(self) -> bool:
+        self._lazy_init()
+        return bool(self._warnings)
+
+    def _lazy_init(self) -> None:
+        if self.raw_times and not self._sorted_times:
+            self._sorted_times = tuple(sorted(self.times))
+            _sorted_times = torch.tensor(self._sorted_times, dtype=torch.float64)
+            self._median = _sorted_times.quantile(.5).item()
+            self._mean = _sorted_times.mean().item()
+            self._p25 = _sorted_times.quantile(.25).item()
+            self._p75 = _sorted_times.quantile(.75).item()
+
+            def add_warning(msg: str) -> None:
+                rel_iqr = self.iqr / self.median * 100
+                self._warnings += (
+                    f"  WARNING: Interquartile range is {rel_iqr:.1f}% "
+                    f"of the median measurement.\n           {msg}",
+                )
+
+            if not self.meets_confidence(_IQR_GROSS_WARN_THRESHOLD):
+                add_warning("This suggests significant environmental influence.")
+            elif not self.meets_confidence(_IQR_WARN_THRESHOLD):
+                add_warning("This could indicate system fluctuation.")
+
+
+    def meets_confidence(self, threshold: float = _IQR_WARN_THRESHOLD) -> bool:
+        return self.iqr / self.median < threshold
+
+    @property
+    def title(self) -> str:
+        return self.task_spec.title
+
+    @property
+    def env(self) -> str:
+        return (
+            "Unspecified env" if self.taskspec.env is None
+            else cast(str, self.taskspec.env)
+        )
+
+    @property
+    def as_row_name(self) -> str:
+        return self.sub_label or self.stmt or "[Unknown]"
+
+    def __repr__(self) -> str:
+        """
+        Example repr:
+            
+              Broadcasting add (4x8)
+              Median: 5.73 us
+              IQR:    2.25 us (4.01 to 6.26)
+              372 measurements, 100 runs per measurement, 1 thread
+              WARNING: Interquartile range is 39.4% of the median measurement.
+                       This suggests significant environmental influence.
+        """
+        self._lazy_init()
+        skip_line, newline = "MEASUREMENT_REPR_SKIP_LINE", "\n"
+        n = len(self._sorted_times)
+        time_unit, time_scale = select_unit(self._median)
+        iqr_filter = '' if n >= 4 else skip_line
+
+        repr_str = f"""
+{super().__repr__()}
+{self.task_spec.summarize()}
+  {'Median: ' if n > 1 else ''}{self._median / time_scale:.2f} {time_unit}
+  {iqr_filter}IQR:    {self.iqr / time_scale:.2f} {time_unit} ({self._p25 / time_scale:.2f} to {self._p75 / time_scale:.2f})
+  {n} measurement{'s' if n > 1 else ''}, {self.number_per_run} runs {'per measurement,' if n > 1 else ','} {self.num_threads} thread{'s' if self.num_threads > 1 else ''}
+{newline.join(self._warnings)}""".strip()  # noqa: B950
+
+        return "\n".join(l for l in repr_str.splitlines(keepends=False) if skip_line not in l)
+
+    @staticmethod
+    def merge(measurements: Iterable["Measurement"]) -> list["Measurement"]:
+        """Convenience method for merging replicates.
+
+        Merge will extrapolate times to `number_per_run=1` and will not
+        transfer any metadata. (Since it might differ between replicates)
+        """
+        grouped_measurements: collections.defaultdict[TaskSpec, list[Measurement]] = collections.defaultdict(list)
+        for m in measurements:
+            grouped_measurements[m.task_spec].append(m)
+
+        def merge_group(task_spec: TaskSpec, group: list["Measurement"]) -> "Measurement":
+            times: list[float] = []
+            for m in group:
+                # Different measurements could have different `number_per_run`,
+                # so we call `.times` which normalizes the results.
+                times.extend(m.times)
+
+            return Measurement(
+                number_per_run=1,
+                raw_times=times,
+                task_spec=task_spec,
+                metadata=None,
+            )
+
+        return [merge_group(t, g) for t, g in grouped_measurements.items()]
+
+
+def select_unit(t: float) -> tuple[str, float]:
+    """Determine how to scale times for O(1) magnitude.
+
+    This utility is used to format numbers for human consumption.
+    """
+    time_unit = {-3: "ns", -2: "us", -1: "ms"}.get(int(torch.tensor(t).log10().item() // 3), "s")
+    time_scale = {"ns": 1e-9, "us": 1e-6, "ms": 1e-3, "s": 1}[time_unit]
+    return time_unit, time_scale
+
+
+def unit_to_english(u: str) -> str:
+    return {
+        "ns": "nanosecond",
+        "us": "microsecond",
+        "ms": "millisecond",
+        "s": "second",
+    }[u]
+
+
+def trim_sigfig(x: float, n: int) -> float:
+    """Trim `x` to `n` significant figures. (e.g. 3.14159, 2 -> 3.10000)"""
+    assert n == int(n)
+    magnitude = int(torch.tensor(x).abs().log10().ceil().item())
+    scale = 10 ** (magnitude - n)
+    return float(torch.tensor(x / scale).round() * scale)
+
+
+def ordered_unique(elements: Iterable[Any]) -> list[Any]:
+    return list(collections.OrderedDict(dict.fromkeys(elements)).keys())
+
+
+@contextlib.contextmanager
+def set_torch_threads(n: int) -> Iterator[None]:
+    prior_num_threads = torch.get_num_threads()
+    try:
+        torch.set_num_threads(n)
+        yield
+    finally:
+        torch.set_num_threads(prior_num_threads)
+
+
+def _make_temp_dir(prefix: Optional[str] = None, gc_dev_shm: bool = False) -> str:
+    """Create a temporary directory. The caller is responsible for cleanup.
+
+    This function is conceptually similar to `tempfile.mkdtemp`, but with
+    the key additional feature that it will use shared memory if the
+    `BENCHMARK_USE_DEV_SHM` environment variable is set. This is an
+    implementation detail, but an important one for cases where many Callgrind
+    measurements are collected at once. (Such as when collecting
+    microbenchmarks.)
+
+    This is an internal utility, and is exported solely so that microbenchmarks
+    can reuse the util.
+    """
+    use_dev_shm: bool = (os.getenv("BENCHMARK_USE_DEV_SHM") or "").lower() in ("1", "true")
+    if use_dev_shm:
+        root = "/dev/shm/pytorch_benchmark_utils"
+        assert os.name == "posix", f"tmpfs (/dev/shm) is POSIX only, current platform is {os.name}"
+        assert os.path.exists("/dev/shm"), "This system does not appear to support tmpfs (/dev/shm)."
+        os.makedirs(root, exist_ok=True)
+
+        # Because we're working in shared memory, it is more important than
+        # usual to clean up ALL intermediate files. However we don't want every
+        # worker to walk over all outstanding directories, so instead we only
+        # check when we are sure that it won't lead to contention.
+        if gc_dev_shm:
+            for i in os.listdir(root):
+                owner_file = os.path.join(root, i, "owner.pid")
+                if not os.path.exists(owner_file):
+                    continue
+
+                with open(owner_file) as f:
+                    owner_pid = int(f.read())
+
+                if owner_pid == os.getpid():
+                    continue
+
+                try:
+                    # https://stackoverflow.com/questions/568271/how-to-check-if-there-exists-a-process-with-a-given-pid-in-python
+                    os.kill(owner_pid, 0)
+
+                except OSError:
+                    print(f"Detected that {os.path.join(root, i)} was orphaned in shared memory. Cleaning up.")
+                    shutil.rmtree(os.path.join(root, i))
+
+    else:
+        root = tempfile.gettempdir()
+
+    # We include the time so names sort by creation time, and add a UUID
+    # to ensure we don't collide.
+    name = f"{prefix or tempfile.gettempprefix()}__{int(time.time())}__{uuid.uuid4()}"
+    path = os.path.join(root, name)
+    os.makedirs(path, exist_ok=False)
+
+    if use_dev_shm:
+        with open(os.path.join(path, "owner.pid"), "w") as f:
+            f.write(str(os.getpid()))
+
+    return path
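
A small, hedged illustration (not part of the patch) of the formatting helpers defined above: select_unit picks a time unit so the scaled value prints with O(1) magnitude, and trim_sigfig rounds a value to an estimated number of significant figures.

# Illustration only; the printed values are worked examples of the helpers above.
from torch.utils.benchmark.utils.common import select_unit, trim_sigfig

unit, scale = select_unit(3.2e-6)          # -> ("us", 1e-6)
print(f"{3.2e-6 / scale:.2f} {unit}")      # "3.20 us"
print(trim_sigfig(3.14159, 2))             # ~3.1 (two significant figures)
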
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/compare.py b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/compare.py
new file mode 100644
index 0000000000000000000000000000000000000000..592ec66b4c04a2e7133ffe64ba7eedb23a0d5aed
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/compare.py
@@ -0,0 +1,345 @@
+# mypy: allow-untyped-defs
+"""Display class to aggregate and print the results of many measurements."""
+import collections
+import enum
+import itertools as it
+from typing import Optional
+
+from torch.utils.benchmark.utils import common
+from torch import tensor as _tensor
+import operator
+
+__all__ = ["Colorize", "Compare"]
+
+BEST = "\033[92m"
+GOOD = "\033[34m"
+BAD = "\033[2m\033[91m"
+VERY_BAD = "\033[31m"
+BOLD = "\033[1m"
+TERMINATE = "\033[0m"
+
+
+class Colorize(enum.Enum):
+    NONE = "none"
+    COLUMNWISE = "columnwise"
+    ROWWISE = "rowwise"
+
+
+# Classes to separate internal bookkeeping from what is rendered.
+class _Column:
+    def __init__(
+        self,
+        grouped_results: list[tuple[Optional[common.Measurement], ...]],
+        time_scale: float,
+        time_unit: str,
+        trim_significant_figures: bool,
+        highlight_warnings: bool,
+    ):
+        self._grouped_results = grouped_results
+        self._flat_results = [*it.chain.from_iterable(grouped_results)]
+        self._time_scale = time_scale
+        self._time_unit = time_unit
+        self._trim_significant_figures = trim_significant_figures
+        self._highlight_warnings = (
+            highlight_warnings
+            and any(r.has_warnings for r in self._flat_results if r)
+        )
+        leading_digits = [
+            int(_tensor(r.median / self._time_scale).log10().ceil()) if r else None
+            for r in self._flat_results
+        ]
+        unit_digits = max(d for d in leading_digits if d is not None)
+        decimal_digits = min(
+            max(m.significant_figures - digits, 0)
+            for digits, m in zip(leading_digits, self._flat_results)
+            if (m is not None) and (digits is not None)
+        ) if self._trim_significant_figures else 1
+        length = unit_digits + decimal_digits + (1 if decimal_digits else 0)
+        self._template = f"{{:>{length}.{decimal_digits}f}}{{:>{7 if self._highlight_warnings else 0}}}"
+
+    def get_results_for(self, group):
+        return self._grouped_results[group]
+
+    def num_to_str(self, value: Optional[float], estimated_sigfigs: int, spread: Optional[float]):
+        if value is None:
+            return " " * len(self.num_to_str(1, estimated_sigfigs, None))
+
+        if self._trim_significant_figures:
+            value = common.trim_sigfig(value, estimated_sigfigs)
+
+        return self._template.format(
+            value,
+            f" (! {spread * 100:.0f}%)" if self._highlight_warnings and spread is not None else "")
+
+
+def optional_min(seq):
+    l = list(seq)
+    return None if len(l) == 0 else min(l)
+
+
+class _Row:
+    def __init__(self, results, row_group, render_env, env_str_len,
+                 row_name_str_len, time_scale, colorize, num_threads=None):
+        super().__init__()
+        self._results = results
+        self._row_group = row_group
+        self._render_env = render_env
+        self._env_str_len = env_str_len
+        self._row_name_str_len = row_name_str_len
+        self._time_scale = time_scale
+        self._colorize = colorize
+        self._columns: tuple[_Column, ...] = ()
+        self._num_threads = num_threads
+
+    def register_columns(self, columns: tuple[_Column, ...]):
+        self._columns = columns
+
+    def as_column_strings(self):
+        concrete_results = [r for r in self._results if r is not None]
+        env = f"({concrete_results[0].env})" if self._render_env else ""
+        env = env.ljust(self._env_str_len + 4)
+        output = ["  " + env + concrete_results[0].as_row_name]
+        for m, col in zip(self._results, self._columns or ()):
+            if m is None:
+                output.append(col.num_to_str(None, 1, None))
+            else:
+                output.append(col.num_to_str(
+                    m.median / self._time_scale,
+                    m.significant_figures,
+                    m.iqr / m.median if m.has_warnings else None
+                ))
+        return output
+
+    @staticmethod
+    def color_segment(segment, value, best_value):
+        if value <= best_value * 1.01 or value <= best_value + 100e-9:
+            return BEST + BOLD + segment + TERMINATE * 2
+        if value <= best_value * 1.1:
+            return GOOD + BOLD + segment + TERMINATE * 2
+        if value >= best_value * 5:
+            return VERY_BAD + BOLD + segment + TERMINATE * 2
+        if value >= best_value * 2:
+            return BAD + segment + TERMINATE * 2
+
+        return segment
+
+    def row_separator(self, overall_width):
+        return (
+            [f"{self._num_threads} threads: ".ljust(overall_width, "-")]
+            if self._num_threads is not None else []
+        )
+
+    def finalize_column_strings(self, column_strings, col_widths):
+        best_values = [-1 for _ in column_strings]
+        if self._colorize == Colorize.ROWWISE:
+            row_min = min(r.median for r in self._results if r is not None)
+            best_values = [row_min for _ in column_strings]
+        elif self._colorize == Colorize.COLUMNWISE:
+            best_values = [
+                optional_min(r.median for r in column.get_results_for(self._row_group) if r is not None)
+                for column in (self._columns or ())
+            ]
+
+        row_contents = [column_strings[0].ljust(col_widths[0])]
+        for col_str, width, result, best_value in zip(column_strings[1:], col_widths[1:], self._results, best_values):
+            col_str = col_str.center(width)
+            if self._colorize != Colorize.NONE and result is not None and best_value is not None:
+                col_str = self.color_segment(col_str, result.median, best_value)
+            row_contents.append(col_str)
+        return row_contents
+
+
+class Table:
+    def __init__(
+            self,
+            results: list[common.Measurement],
+            colorize: Colorize,
+            trim_significant_figures: bool,
+            highlight_warnings: bool
+    ):
+        assert len({r.label for r in results}) == 1
+
+        self.results = results
+        self._colorize = colorize
+        self._trim_significant_figures = trim_significant_figures
+        self._highlight_warnings = highlight_warnings
+        self.label = results[0].label
+        self.time_unit, self.time_scale = common.select_unit(
+            min(r.median for r in results)
+        )
+
+        self.row_keys = common.ordered_unique([self.row_fn(i) for i in results])
+        self.row_keys.sort(key=operator.itemgetter(slice(2)))  # preserve stmt order
+        self.column_keys = common.ordered_unique([self.col_fn(i) for i in results])
+        self.rows, self.columns = self.populate_rows_and_columns()
+
+    @staticmethod
+    def row_fn(m: common.Measurement) -> tuple[int, Optional[str], str]:
+        return m.num_threads, m.env, m.as_row_name
+
+    @staticmethod
+    def col_fn(m: common.Measurement) -> Optional[str]:
+        return m.description
+
+    def populate_rows_and_columns(self) -> tuple[tuple[_Row, ...], tuple[_Column, ...]]:
+        rows: list[_Row] = []
+        columns: list[_Column] = []
+        ordered_results: list[list[Optional[common.Measurement]]] = [
+            [None for _ in self.column_keys]
+            for _ in self.row_keys
+        ]
+        row_position = {key: i for i, key in enumerate(self.row_keys)}
+        col_position = {key: i for i, key in enumerate(self.column_keys)}
+        for r in self.results:
+            i = row_position[self.row_fn(r)]
+            j = col_position[self.col_fn(r)]
+            ordered_results[i][j] = r
+
+        unique_envs = {r.env for r in self.results}
+        render_env = len(unique_envs) > 1
+        env_str_len = max(len(i) for i in unique_envs) if render_env else 0
+
+        row_name_str_len = max(len(r.as_row_name) for r in self.results)
+
+        prior_num_threads = -1
+        prior_env = ""
+        row_group = -1
+        rows_by_group: list[list[list[Optional[common.Measurement]]]] = []
+        for (num_threads, env, _), row in zip(self.row_keys, ordered_results):
+            thread_transition = (num_threads != prior_num_threads)
+            if thread_transition:
+                prior_num_threads = num_threads
+                prior_env = ""
+                row_group += 1
+                rows_by_group.append([])
+            rows.append(
+                _Row(
+                    results=row,
+                    row_group=row_group,
+                    render_env=(render_env and env != prior_env),
+                    env_str_len=env_str_len,
+                    row_name_str_len=row_name_str_len,
+                    time_scale=self.time_scale,
+                    colorize=self._colorize,
+                    num_threads=num_threads if thread_transition else None,
+                )
+            )
+            rows_by_group[-1].append(row)
+            prior_env = env
+
+        for i in range(len(self.column_keys)):
+            grouped_results = [tuple(row[i] for row in g) for g in rows_by_group]
+            column = _Column(
+                grouped_results=grouped_results,
+                time_scale=self.time_scale,
+                time_unit=self.time_unit,
+                trim_significant_figures=self._trim_significant_figures,
+                highlight_warnings=self._highlight_warnings,)
+            columns.append(column)
+
+        rows_tuple, columns_tuple = tuple(rows), tuple(columns)
+        for ri in rows_tuple:
+            ri.register_columns(columns_tuple)
+        return rows_tuple, columns_tuple
+
+    def render(self) -> str:
+        string_rows = [[""] + self.column_keys]
+        string_rows.extend(r.as_column_strings() for r in self.rows)
+        num_cols = max(len(i) for i in string_rows)
+        for sr in string_rows:
+            sr.extend(["" for _ in range(num_cols - len(sr))])
+
+        col_widths = [max(len(j) for j in i) for i in zip(*string_rows)]
+        finalized_columns = ["  |  ".join(i.center(w) for i, w in zip(string_rows[0], col_widths))]
+        overall_width = len(finalized_columns[0])
+        for string_row, row in zip(string_rows[1:], self.rows):
+            finalized_columns.extend(row.row_separator(overall_width))
+            finalized_columns.append("  |  ".join(row.finalize_column_strings(string_row, col_widths)))
+
+        newline = "\n"
+        has_warnings = self._highlight_warnings and any(ri.has_warnings for ri in self.results)
+        return f"""
+[{(' ' + (self.label or '') + ' ').center(overall_width - 2, '-')}]
+{newline.join(finalized_columns)}
+
+Times are in {common.unit_to_english(self.time_unit)}s ({self.time_unit}).
+{'(! XX%) Measurement has high variance, where XX is the IQR / median * 100.' + newline if has_warnings else ""}"""[1:]
+
+
+class Compare:
+    """Helper class for displaying the results of many measurements in a
+    formatted table.
+
+    The table format is based on the information fields provided in
+    :class:`torch.utils.benchmark.Timer` (`description`, `label`, `sub_label`,
+    `num_threads`, etc).
+
+    The table can be printed directly using :meth:`print` or cast as a `str`.
+
+    For a full tutorial on how to use this class, see:
+    https://pytorch.org/tutorials/recipes/recipes/benchmark.html
+
+    Args:
+        results: List of Measurements to display.
+    """
+    def __init__(self, results: list[common.Measurement]):
+        self._results: list[common.Measurement] = []
+        self.extend_results(results)
+        self._trim_significant_figures = False
+        self._colorize = Colorize.NONE
+        self._highlight_warnings = False
+
+    def __str__(self):
+        return "\n".join(self._render())
+
+    def extend_results(self, results):
+        """Append results to already stored ones.
+
+        All added results must be instances of ``Measurement``.
+        """
+        for r in results:
+            if not isinstance(r, common.Measurement):
+                raise ValueError(
+                    "Expected an instance of `Measurement`, " f"got {type(r)} instead."
+                )
+        self._results.extend(results)
+
+    def trim_significant_figures(self):
+        """Enables trimming of significant figures when building the formatted table."""
+        self._trim_significant_figures = True
+
+    def colorize(self, rowwise=False):
+        """Colorize formatted table.
+
+        Colorize columnwise by default.
+        """
+        self._colorize = Colorize.ROWWISE if rowwise else Colorize.COLUMNWISE
+
+    def highlight_warnings(self):
+        """Enables warning highlighting when building formatted table."""
+        self._highlight_warnings = True
+
+    def print(self):
+        """Print formatted table"""
+        print(str(self))
+
+    def _render(self):
+        results = common.Measurement.merge(self._results)
+        grouped_results = self._group_by_label(results)
+        output = [self._layout(group) for group in grouped_results.values()]
+        return output
+
+    def _group_by_label(self, results: list[common.Measurement]):
+        grouped_results: collections.defaultdict[str, list[common.Measurement]] = collections.defaultdict(list)
+        for r in results:
+            grouped_results[r.label].append(r)
+        return grouped_results
+
+    def _layout(self, results: list[common.Measurement]):
+        table = Table(
+            results,
+            self._colorize,
+            self._trim_significant_figures,
+            self._highlight_warnings
+        )
+        return table.render()
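
The sketch below (an assumed driver, not part of the patch) shows how Compare is typically fed: label selects the table, sub_label names rows, description names columns, and num_threads introduces the row-group separators rendered by _Row.row_separator.

# Assumed usage sketch: collect Timer measurements and render them with Compare.
import torch
from torch.utils.benchmark import Timer, Compare

results = []
for n in (64, 256):
    for num_threads in (1, 4):
        results.append(
            Timer(
                stmt="torch.mm(x, x)",
                globals={"x": torch.ones((n, n))},
                label="matmul",
                sub_label=f"{n}x{n}",
                description="float32",
                num_threads=num_threads,
            ).blocked_autorange(min_run_time=0.05)
        )

compare = Compare(results)
compare.trim_significant_figures()
compare.colorize()  # columnwise by default; colorize(rowwise=True) colors per row
compare.highlight_warnings()
compare.print()
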
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/compile.py b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/compile.py
new file mode 100644
index 0000000000000000000000000000000000000000..23b74f946c000763e1e76b23631dc349448ba71b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/compile.py
@@ -0,0 +1,191 @@
+# mypy: allow-untyped-defs
+from typing import Any, Callable, cast, Optional, Union
+
+import torch
+import torch._dynamo
+from torch._dynamo.testing import CompileCounterWithBackend
+from torch.utils.benchmark import Timer
+
+
+__all__ = ["bench_all", "benchmark_compile"]
+
+
+_warned_tensor_cores = False
+_default_float_32_precision = torch.get_float32_matmul_precision()
+
+try:
+    from tabulate import tabulate
+
+    HAS_TABULATE = True
+except ModuleNotFoundError:
+    HAS_TABULATE = False
+    tabulate = None  # type: ignore[assignment]
+    print("tabulate is not installed, please pip install tabulate to use this utility")
+
+if HAS_TABULATE:
+    def _enable_tensor_cores():
+        global _warned_tensor_cores
+
+        if torch.cuda.is_available():
+            if torch.backends.cuda.matmul.allow_tf32 is False and torch.cuda.get_device_capability() >= (8, 0):
+                torch.set_float32_matmul_precision("high")
+                if not _warned_tensor_cores:
+                    print("Your GPU supports tensor cores")
+                    print("we will enable it automatically by setting `torch.set_float32_matmul_precision('high')`")
+                    _warned_tensor_cores = True
+
+    def _disable_tensor_cores():
+        torch.set_float32_matmul_precision(_default_float_32_precision)
+
+    def bench_loop(
+        model: Union[torch.nn.Module, Callable],
+        sample_input: Union[torch.Tensor, Any],
+        num_iters: int = 5,
+        optimizer: Optional[torch.optim.Optimizer] = None,
+        loss_fn: Optional[Callable] = None,
+    ):
+        # Define the statement and setup for the benchmark
+        if optimizer and loss_fn:
+            # Training mode
+            stmt = """
+    output = model(sample_input)
+    loss = loss_fn(output) if loss_fn else output.sum()
+    loss.backward()
+    optimizer.step()
+    optimizer.zero_grad()
+            """
+        else:
+            # Inference mode
+            stmt = "model(sample_input)"
+
+        # Create the Timer object
+        timer = Timer(
+            stmt=stmt,
+            globals={"model": model, "sample_input": sample_input, "optimizer": optimizer, "loss_fn": loss_fn},
+        )
+
+
+        result = timer.timeit(number=num_iters)
+
+        # Get the average time per iteration in milliseconds
+        avg_time = result.mean * 1000
+        return round(avg_time, 2)
+
+    def benchmark_compile(
+        model: Union[torch.nn.Module, Callable],
+        sample_input: Union[torch.Tensor, Any],
+        num_iters: int = 5,
+        backend: Optional[str] = None,
+        mode: Optional[str] = "default",
+        optimizer: Optional[torch.optim.Optimizer] = None,
+        loss_fn : Union[torch.nn.Module, Callable, None] = None,
+    ):
+        """
+        Use this utility to benchmark torch.compile
+        """
+        if backend:
+            try:
+                torch._dynamo.reset()
+                compile_counter_with_backend = CompileCounterWithBackend(backend)
+                opt_model = torch.compile(model, backend=compile_counter_with_backend, mode=mode)
+
+                # Compilation happens on the first call, so time one iteration separately to capture it.
+                compilation_time = bench_loop(opt_model, sample_input, 1, optimizer, loss_fn)
+
+                running_time = bench_loop(opt_model, sample_input, num_iters, optimizer, loss_fn)
+
+                if compile_counter_with_backend.frame_count == 0:
+                    raise RuntimeError("No compilation occurred during benchmarking.")
+
+                if compile_counter_with_backend.frame_count > 1:
+                    raise RuntimeError("Recompilation occurred during benchmarking.")
+
+            except Exception as e:
+                print(e)
+                print(f"Failed to compile {backend} with mode {mode}")
+                return None, None
+        else:
+            opt_model = model
+            compilation_time = None
+            running_time = bench_loop(opt_model, sample_input, num_iters, optimizer, loss_fn)
+
+        compilation_time = round(compilation_time, 2) if compilation_time else None
+        running_time = round(running_time, 2) if running_time else None
+
+
+        return compilation_time, running_time
+
+
+    def bench_all(
+        model : Union[torch.nn.Module, Callable],
+        sample_input: Union[torch.Tensor, Any],
+        num_iters : int = 5,
+        optimizer: Optional[torch.optim.Optimizer] = None,
+        loss_fn : Union[torch.nn.Module, Callable, None] = None,
+    ):
+        """
+        This is a simple utility that can be used to benchmark torch.compile
+        In particular it ensures that your GPU is setup to use tensor cores if it supports its
+        It also tries out all the main backends and prints a table of results so you can easily compare them all
+        Many of the backendds have their own optional dependencies so please pip install them seperately
+
+        You will get one table for inference and another for training
+        If you'd like to leverage this utility for training make sure to pass in a torch.optim.Optimizer
+
+        The important warnings are
+        Your GPU supports tensor cores
+        we will enable it automatically by setting `torch.set_float32_matmul_precision('high')`
+
+        If a compilation fails for any reason including the dependency not being included
+        then we will print Failed to compile {backend} with mode {mode}
+        """
+        field_names = ["Train/Inference", "Backend", "Mode", "Compilation Time", "Average Running Time"]
+        table = []
+
+
+        torch._dynamo.reset()
+        _, eager_time = benchmark_compile(model, sample_input, num_iters, None, None, optimizer, loss_fn)
+        table.append(
+            [("Training" if optimizer else "Inference"), "Eager", "-", "-", f"{eager_time} ms"]
+        )
+
+        for backend in torch._dynamo.list_backends():
+
+            if backend == "inductor":
+                mode_options = cast(list[Optional[str]], list(torch._inductor.list_mode_options().keys())) + [None]
+                for mode in mode_options:
+                    if mode == "default":
+                        continue
+                    torch._dynamo.reset()
+                    try:
+                        if torch.cuda.is_available():
+                            _enable_tensor_cores()
+                        compilation_time, running_time = benchmark_compile(
+                            model, sample_input, num_iters, backend, mode, optimizer, loss_fn)
+                    finally:
+                        if torch.cuda.is_available():
+                            _disable_tensor_cores()
+                        # Record the row regardless of whether CUDA is available.
+                        table.append([
+                            ("Training" if optimizer else "Inference"),
+                            backend if backend else "-",
+                            mode if mode is not None else "-",
+                            f"{compilation_time} ms " if compilation_time else "-",
+                            f"{running_time} ms " if running_time else "-",
+                        ])
+
+            else:
+                torch._dynamo.reset()
+                compilation_time, running_time = benchmark_compile(
+                    model, sample_input, num_iters, backend, None, optimizer, loss_fn)
+
+                if running_time is not None:
+                    table.append([
+                        ("Training" if optimizer else "Inference"),
+                        backend, "-",
+                        f"{compilation_time} ms " or "-",
+                        f"{running_time} ms ",
+                    ])
+
+
+        return tabulate(table, headers=field_names, tablefmt="github")
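
A hedged example of driving the utility above; the model and input are placeholders, not from the patch. bench_all requires tabulate and returns a GitHub-flavored table comparing eager execution against the available torch.compile backends; passing an optimizer switches it to the training path.

# Placeholder model and input; bench_all itself is defined in the file above.
import torch
from torch.utils.benchmark.utils.compile import bench_all

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU())
sample_input = torch.randn(8, 64)

# Inference table. For a training table, also pass optimizer= and loss_fn=.
print(bench_all(model, sample_input, num_iters=5))
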
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/cpp_jit.py b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/cpp_jit.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7aec25f6a76ef1dfa55a991a3edda42a11a34cf
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/cpp_jit.py
@@ -0,0 +1,172 @@
+"""JIT C++ strings into executables."""
+import atexit
+import os
+import re
+import shutil
+import textwrap
+import threading
+from typing import Any, Optional
+
+import torch
+from torch.utils.benchmark.utils._stubs import CallgrindModuleType, TimeitModuleType
+from torch.utils.benchmark.utils.common import _make_temp_dir
+from torch.utils import cpp_extension
+
+
+LOCK = threading.Lock()
+SOURCE_ROOT = os.path.split(os.path.abspath(__file__))[0]
+
+# We calculate uuid once at import time so that separate processes will have
+# separate build roots, but threads will share the same build root.
+# `cpp_extension` uses build root as part of the cache key, so per-invocation
+# uuid's (e.g. different build root per _compile_template call) would lead to
+# a 0% cache hit rate and spurious recompilation. Consider the following:
+#   ```
+#   setup = "auto x = torch::ones({1024, 1024});"
+#   stmt = "torch::mm(x, x);"
+#   for num_threads in [1, 2, 4, 8]:
+#     print(Timer(stmt, setup, num_threads=num_threads, language="c++").blocked_autorange())
+#   ```
+# `setup` and `stmt` do not change, so we can reuse the executable from the
+# first pass through the loop.
+_BUILD_ROOT: Optional[str] = None
+
+def _get_build_root() -> str:
+    global _BUILD_ROOT
+    if _BUILD_ROOT is None:
+        _BUILD_ROOT = _make_temp_dir(prefix="benchmark_utils_jit_build")
+        atexit.register(shutil.rmtree, _BUILD_ROOT)
+    return _BUILD_ROOT
+
+
+# BACK_TESTING_NOTE:
+#   There are two workflows where this code could be used. One is the obvious
+#   case where someone simply builds or installs PyTorch and uses Timer.
+#   The other is that the entire `torch/utils/benchmark` folder from a CURRENT
+#   PyTorch checkout is copy-pasted into a much OLDER version of the PyTorch
+#   source code. This is what we refer to here as "back testing". The rationale
+#   is that we might want to use current tooling to study some aspect of an
+#   earlier version of PyTorch. (e.g. a regression.)
+#
+#   The problem is that Timer relies on several aspects of core PyTorch, namely
+#   some binding functions for Valgrind symbols in `torch._C` and the
+#   `torch.__config__._cxx_flags()` method. If we were to naively copy code
+#   around this wouldn't work as the symbols of interest aren't present in
+#   earlier versions of PyTorch. In order to work around this, we must add back
+#   testing shims. These shims will never activate during normal use, but will
+#   allow Timer to function outside of the "correct" version of PyTorch by
+#   emulating functionality that was added later.
+#
+#   These shims are temporary, and as Timer becomes more integrated with
+#   PyTorch the cost and complexity of such shims will increase. Once back
+#   testing is no longer required (which is to say we have done enough historic
+#   analysis and the shims no longer justify their maintenance and code
+#   complexity costs) back testing paths will be removed.
+
+CXX_FLAGS: Optional[list[str]]
+if hasattr(torch.__config__, "_cxx_flags"):
+    try:
+        CXX_FLAGS = torch.__config__._cxx_flags().strip().split()
+        if CXX_FLAGS is not None and "-g" not in CXX_FLAGS:
+            CXX_FLAGS.append("-g")
+        # remove "-W" flags to allow build benchmarks
+        # with a relaxed constraint of compiler versions
+        if CXX_FLAGS is not None:
+            CXX_FLAGS = list(filter(lambda x: not x.startswith("-W"), CXX_FLAGS))
+
+    except RuntimeError:
+        # We are in FBCode.
+        CXX_FLAGS = None
+else:
+    # FIXME: Remove when back testing is no longer required.
+    CXX_FLAGS = ["-O2", "-fPIC", "-g"]
+
+EXTRA_INCLUDE_PATHS: list[str] = [os.path.join(SOURCE_ROOT, "valgrind_wrapper")]
+CONDA_PREFIX = os.getenv("CONDA_PREFIX")
+if CONDA_PREFIX is not None:
+    # Load will automatically search /usr/include, but not conda include.
+    EXTRA_INCLUDE_PATHS.append(os.path.join(CONDA_PREFIX, "include"))
+
+
+COMPAT_CALLGRIND_BINDINGS: Optional[CallgrindModuleType] = None
+def get_compat_bindings() -> CallgrindModuleType:
+    with LOCK:
+        global COMPAT_CALLGRIND_BINDINGS
+        if COMPAT_CALLGRIND_BINDINGS is None:
+            COMPAT_CALLGRIND_BINDINGS = cpp_extension.load(
+                name="callgrind_bindings",
+                sources=[os.path.join(
+                    SOURCE_ROOT,
+                    "valgrind_wrapper",
+                    "compat_bindings.cpp"
+                )],
+                extra_cflags=CXX_FLAGS,
+                extra_include_paths=EXTRA_INCLUDE_PATHS,
+            )
+    return COMPAT_CALLGRIND_BINDINGS
+
+
+def _compile_template(
+    *,
+    stmt: str,
+    setup: str,
+    global_setup: str,
+    src: str,
+    is_standalone: bool
+) -> Any:
+    for before, after, indentation in (
+        ("// GLOBAL_SETUP_TEMPLATE_LOCATION", global_setup, 0),
+        ("// SETUP_TEMPLATE_LOCATION", setup, 4),
+        ("// STMT_TEMPLATE_LOCATION", stmt, 8)
+    ):
+        # C++ doesn't care about indentation so this code isn't load
+        # bearing the way it is with Python, but this makes the source
+        # look nicer if a human has to look at it.
+        src = re.sub(
+            before,
+            textwrap.indent(after, " " * indentation)[indentation:],
+            src
+        )
+
+    # We want to isolate different Timers. However `cpp_extension` will
+    # cache builds which will significantly reduce the cost of repeated
+    # invocations.
+    with LOCK:
+        name = f"timer_cpp_{abs(hash(src))}"
+        build_dir = os.path.join(_get_build_root(), name)
+        os.makedirs(build_dir, exist_ok=True)
+
+        src_path = os.path.join(build_dir, "timer_src.cpp")
+        with open(src_path, "w") as f:
+            f.write(src)
+
+    # `cpp_extension` has its own locking scheme, so we don't need our lock.
+    return cpp_extension.load(
+        name=name,
+        sources=[src_path],
+        build_directory=build_dir,
+        extra_cflags=CXX_FLAGS,
+        extra_include_paths=EXTRA_INCLUDE_PATHS,
+        is_python_module=not is_standalone,
+        is_standalone=is_standalone,
+    )
+
+
+def compile_timeit_template(*, stmt: str, setup: str, global_setup: str) -> TimeitModuleType:
+    template_path: str = os.path.join(SOURCE_ROOT, "timeit_template.cpp")
+    with open(template_path) as f:
+        src: str = f.read()
+
+    module = _compile_template(stmt=stmt, setup=setup, global_setup=global_setup, src=src, is_standalone=False)
+    assert isinstance(module, TimeitModuleType)
+    return module
+
+
+def compile_callgrind_template(*, stmt: str, setup: str, global_setup: str) -> str:
+    template_path: str = os.path.join(SOURCE_ROOT, "valgrind_wrapper", "timer_callgrind_template.cpp")
+    with open(template_path) as f:
+        src: str = f.read()
+
+    target = _compile_template(stmt=stmt, setup=setup, global_setup=global_setup, src=src, is_standalone=True)
+    assert isinstance(target, str)
+    return target
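
The build-root comment above is easiest to see from the caller's side. The sketch below (assumed usage, mirroring the loop quoted in that comment) reuses the JIT-compiled C++ timing module across iterations because stmt and setup, and therefore the cpp_extension cache key under the per-process build root, never change. A working C++ toolchain is required.

# Assumed usage: the same C++ stmt/setup is compiled once and timed at several thread counts.
from torch.utils.benchmark import Timer

stmt = "torch::mm(x, x);"
setup = "auto x = torch::ones({1024, 1024});"
for num_threads in (1, 2, 4):
    timer = Timer(stmt, setup, num_threads=num_threads, language="c++")
    print(timer.blocked_autorange(min_run_time=0.1))
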
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/fuzzer.py b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/fuzzer.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fd52a7aecd39fe6e1f98ca3534fe352ec387dc3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/fuzzer.py
@@ -0,0 +1,462 @@
+# mypy: allow-untyped-defs
+import functools
+import itertools as it
+from typing import Any, Callable, Optional, Union
+
+import torch
+
+
+__all__ = [
+    "Fuzzer",
+    "FuzzedParameter", "ParameterAlias",
+    "FuzzedTensor",
+]
+
+
+_DISTRIBUTIONS = (
+    "loguniform",
+    "uniform",
+)
+
+
+class FuzzedParameter:
+    """Specification for a parameter to be generated during fuzzing."""
+    def __init__(
+        self,
+        name: str,
+        minval: Optional[Union[int, float]] = None,
+        maxval: Optional[Union[int, float]] = None,
+        distribution: Optional[Union[str, dict[Any, float]]] = None,
+        strict: bool = False,
+    ):
+        """
+        Args:
+            name:
+                A string name with which to identify the parameter.
+                FuzzedTensors can reference this string in their
+                specifications.
+            minval:
+                The lower bound for the generated value. See the description
+                of `distribution` for type behavior.
+            maxval:
+                The upper bound for the generated value. Type behavior is
+                identical to `minval`.
+            distribution:
+                Specifies the distribution from which this parameter should
+                be drawn. There are three possibilities:
+                    - "loguniform"
+                        Samples between `minval` and `maxval` (inclusive) such
+                        that the probabilities are uniform in log space. As a
+                        concrete example, if minval=1 and maxval=100, a sample
+                        is as likely to fall in [1, 10) as it is [10, 100].
+                    - "uniform"
+                        Samples are chosen with uniform probability between
+                        `minval` and `maxval` (inclusive). If either `minval`
+                        or `maxval` is a float then the distribution is the
+                        continuous uniform distribution; otherwise samples
+                        are constrained to the integers.
+                    - dict:
+                        If a dict is passed, the keys are taken to be choices
+                        for the variables and the values are interpreted as
+                        probabilities. (And must sum to one.)
+                If a dict is passed, `minval` and `maxval` must not be set.
+                Otherwise, they must be set.
+            strict:
+                If a parameter is strict, it will not be included in the
+                iterative resampling process which Fuzzer uses to find a
+                valid parameter configuration. This allows an author to
+                prevent skew from resampling for a given parameter (for
+                instance, a low size limit could inadvertently bias towards
+                Tensors with fewer dimensions) at the cost of more iterations
+                when generating parameters.
+        """
+        self._name = name
+        self._minval = minval
+        self._maxval = maxval
+        self._distribution = self._check_distribution(distribution)
+        self.strict = strict
+
+    @property
+    def name(self):
+        return self._name
+
+    def sample(self, state):
+        if self._distribution == "loguniform":
+            return self._loguniform(state)
+
+        if self._distribution == "uniform":
+            return self._uniform(state)
+
+        if isinstance(self._distribution, dict):
+            return self._custom_distribution(state)
+
+    def _check_distribution(self, distribution):
+        if not isinstance(distribution, dict):
+            assert distribution in _DISTRIBUTIONS
+        else:
+            assert not any(i < 0 for i in distribution.values()), "Probabilities cannot be negative"
+            assert abs(sum(distribution.values()) - 1) <= 1e-5, "Distribution is not normalized"
+            assert self._minval is None
+            assert self._maxval is None
+
+        return distribution
+
+    def _loguniform(self, state):
+        import numpy as np
+        output = int(2 ** state.uniform(
+            low=np.log2(self._minval) if self._minval is not None else None,
+            high=np.log2(self._maxval) if self._maxval is not None else None,
+        ))
+        if self._minval is not None and output < self._minval:
+            return self._minval
+        if self._maxval is not None and output > self._maxval:
+            return self._maxval
+        return output
+
+    def _uniform(self, state):
+        if isinstance(self._minval, int) and isinstance(self._maxval, int):
+            return int(state.randint(low=self._minval, high=self._maxval + 1))
+        return state.uniform(low=self._minval, high=self._maxval)
+
+    def _custom_distribution(self, state):
+        import numpy as np
+        # If we directly pass the keys to `choice`, numpy will convert
+        # them to numpy dtypes.
+        index = state.choice(
+            np.arange(len(self._distribution)),
+            p=tuple(self._distribution.values()))
+        return list(self._distribution.keys())[index]
+
+
+class ParameterAlias:
+    """Indicates that a parameter should alias the value of another parameter.
+
+    When used in conjunction with a custom distribution, this allows fuzzed
+    tensors to represent a broader range of behaviors. For example, the
+    following sometimes produces Tensors which broadcast:
+
+    Fuzzer(
+        parameters=[
+            FuzzedParameter("x_len", 4, 1024, distribution="uniform"),
+
+            # `y` will either be size one, or match the size of `x`.
+            FuzzedParameter("y_len", distribution={
+                1: 0.5,
+                ParameterAlias("x_len"): 0.5
+            }),
+        ],
+        tensors=[
+            FuzzedTensor("x", size=("x_len",)),
+            FuzzedTensor("y", size=("y_len",)),
+        ],
+    )
+
+    Chains of aliases are allowed, but may not contain cycles.
+    """
+    def __init__(self, alias_to):
+        self.alias_to = alias_to
+
+    def __repr__(self):
+        return f"ParameterAlias[alias_to: {self.alias_to}]"
+
+
+def dtype_size(dtype):
+    if dtype == torch.bool:
+        return 1
+    if dtype.is_floating_point or dtype.is_complex:
+        return int(torch.finfo(dtype).bits / 8)
+    return int(torch.iinfo(dtype).bits / 8)
+
+
+def prod(values, base=1):
+    """np.prod can overflow, so for sizes the product should be done in Python.
+
+    Even though np.prod type promotes to int64, it can still overflow in which
+    case the negative value will pass the size check and OOM when attempting to
+    actually allocate the Tensor.
+    """
+    return functools.reduce(lambda x, y: int(x) * int(y), values, base)
+
+
+class FuzzedTensor:
+    def __init__(
+        self,
+        name: str,
+        size: tuple[Union[str, int], ...],
+        steps: Optional[tuple[Union[str, int], ...]] = None,
+        probability_contiguous: float = 0.5,
+        min_elements: Optional[int] = None,
+        max_elements: Optional[int] = None,
+        max_allocation_bytes: Optional[int] = None,
+        dim_parameter: Optional[str] = None,
+        roll_parameter: Optional[str] = None,
+        dtype=torch.float32,
+        cuda=False,
+        tensor_constructor: Optional[Callable] = None
+    ):
+        """
+        Args:
+            name:
+                A string identifier for the generated Tensor.
+            size:
+                A tuple of integers or strings specifying the size of the generated
+                Tensor. String values will be replaced with a concrete int during the
+                generation process, while ints are simply passed as literals.
+            steps:
+                An optional tuple with the same length as `size`. This indicates
+                that a larger Tensor should be allocated, and then sliced to
+                produce the generated Tensor. For instance, if size is (4, 8)
+                and steps is (1, 4), then a tensor `t` of size (4, 32) will be
+                created and then `t[:, ::4]` will be used. (Allowing one to test
+                Tensors with strided memory.)
+            probability_contiguous:
+                A number between zero and one representing the chance that the
+                generated Tensor has a contiguous memory layout. This is achieved by
+                randomly permuting the shape of a Tensor, calling `.contiguous()`,
+                and then permuting back. This is applied before `steps`, which can
+                also cause a Tensor to be non-contiguous.
+            min_elements:
+                The minimum number of elements that this Tensor must have for a
+                set of parameters to be valid. (Otherwise they are resampled.)
+            max_elements:
+                Like `min_elements`, but setting an upper bound.
+            max_allocation_bytes:
+                Like `max_elements`, but for the size of Tensor that must be
+                allocated prior to slicing for `steps` (if applicable). For
+                example, a FloatTensor with size (1024, 1024) and steps (4, 4)
+                would have 1M elements, but would require a 64 MB allocation.
+            dim_parameter:
+                The length of `size` and `steps` will be truncated to this value.
+                This allows Tensors of varying dimensions to be generated by the
+                Fuzzer.
+            dtype:
+                The PyTorch dtype of the generated Tensor.
+            cuda:
+                Whether to place the Tensor on a GPU.
+            tensor_constructor:
+                Callable which will be used instead of the default Tensor
+                construction method. This allows the author to enforce properties
+                of the Tensor (e.g. it can only have certain values). The dtype and
+                concrete shape of the Tensor to be created will be passed, and
+                concrete values of all parameters will be passed as kwargs. Note
+                that transformations to the result (permuting, slicing) will be
+                performed by the Fuzzer; the tensor_constructor is only responsible
+                for creating an appropriately sized Tensor.
+        """
+        self._name = name
+        self._size = size
+        self._steps = steps
+        self._probability_contiguous = probability_contiguous
+        self._min_elements = min_elements
+        self._max_elements = max_elements
+        self._max_allocation_bytes = max_allocation_bytes
+        self._dim_parameter = dim_parameter
+        self._dtype = dtype
+        self._cuda = cuda
+        self._tensor_constructor = tensor_constructor
+
+    @property
+    def name(self):
+        return self._name
+
+    @staticmethod
+    def default_tensor_constructor(size, dtype, **kwargs):
+        if dtype.is_floating_point or dtype.is_complex:
+            return torch.rand(size=size, dtype=dtype, device="cpu")
+        else:
+            return torch.randint(1, 127, size=size, dtype=dtype, device="cpu")
+
+    def _make_tensor(self, params, state):
+        import numpy as np
+        size, steps, allocation_size = self._get_size_and_steps(params)
+        constructor = (
+            self._tensor_constructor or
+            self.default_tensor_constructor
+        )
+
+        raw_tensor = constructor(size=allocation_size, dtype=self._dtype, **params)
+        if self._cuda:
+            raw_tensor = raw_tensor.cuda()
+
+        # Randomly permute the Tensor and call `.contiguous()` to force re-ordering
+        # of the memory, and then permute it back to the original shape.
+        dim = len(size)
+        order = np.arange(dim)
+        if state.rand() > self._probability_contiguous:
+            while dim > 1 and np.all(order == np.arange(dim)):
+                order = state.permutation(raw_tensor.dim())
+
+            raw_tensor = raw_tensor.permute(tuple(order)).contiguous()
+            raw_tensor = raw_tensor.permute(tuple(np.argsort(order)))
+
+        slices = [slice(0, size * step, step) for size, step in zip(size, steps)]
+        tensor = raw_tensor[tuple(slices)]
+
+        properties = {
+            "numel": int(tensor.numel()),
+            "order": order,
+            "steps": steps,
+            "is_contiguous": tensor.is_contiguous(),
+            "dtype": str(self._dtype),
+        }
+
+        return tensor, properties
+
+    def _get_size_and_steps(self, params):
+        dim = (
+            params[self._dim_parameter]
+            if self._dim_parameter is not None
+            else len(self._size)
+        )
+
+        def resolve(values, dim):
+            """Resolve values into concrete integers."""
+            values = tuple(params.get(i, i) for i in values)
+            if len(values) > dim:
+                values = values[:dim]
+            if len(values) < dim:
+                values = values + tuple(1 for _ in range(dim - len(values)))
+            return values
+
+        size = resolve(self._size, dim)
+        steps = resolve(self._steps or (), dim)
+        allocation_size = tuple(size_i * step_i for size_i, step_i in zip(size, steps))
+        return size, steps, allocation_size
+
+    def satisfies_constraints(self, params):
+        size, _, allocation_size = self._get_size_and_steps(params)
+        # Product is computed in Python to avoid integer overflow.
+        num_elements = prod(size)
+        assert num_elements >= 0
+
+        allocation_bytes = prod(allocation_size, base=dtype_size(self._dtype))
+
+        def nullable_greater(left, right):
+            if left is None or right is None:
+                return False
+            return left > right
+
+        return not any((
+            nullable_greater(num_elements, self._max_elements),
+            nullable_greater(self._min_elements, num_elements),
+            nullable_greater(allocation_bytes, self._max_allocation_bytes),
+        ))
+
+
+class Fuzzer:
+    def __init__(
+        self,
+        parameters: list[Union[FuzzedParameter, list[FuzzedParameter]]],
+        tensors: list[Union[FuzzedTensor, list[FuzzedTensor]]],
+        constraints: Optional[list[Callable]] = None,
+        seed: Optional[int] = None
+    ):
+        """
+        Args:
+            parameters:
+                List of FuzzedParameters which provide specifications
+                for generated parameters. Iterable elements will be
+                unpacked, though arbitrary nested structures will not.
+            tensors:
+                List of FuzzedTensors which define the Tensors which
+                will be created each step based on the parameters for
+                that step. Iterable elements will be unpacked, though
+                arbitrary nested structures will not.
+            constraints:
+                List of callables. They will be called with params
+                as kwargs, and if any of them return False the current
+                set of parameters will be rejected.
+            seed:
+                Seed for the RandomState used by the Fuzzer. This will
+                also be used to set the PyTorch random seed so that random
+                ops will create reproducible Tensors.
+        """
+        import numpy as np
+        if seed is None:
+            seed = int(np.random.RandomState().randint(0, 2 ** 32 - 1, dtype=np.int64))
+        self._seed = seed
+        self._parameters = Fuzzer._unpack(parameters, FuzzedParameter)
+        self._tensors = Fuzzer._unpack(tensors, FuzzedTensor)
+        self._constraints = constraints or ()
+
+        p_names = {p.name for p in self._parameters}
+        t_names = {t.name for t in self._tensors}
+        name_overlap = p_names.intersection(t_names)
+        if name_overlap:
+            raise ValueError(f"Duplicate names in parameters and tensors: {name_overlap}")
+
+        self._rejections = 0
+        self._total_generated = 0
+
+    @staticmethod
+    def _unpack(values, cls):
+        return tuple(it.chain.from_iterable(
+            [[i] if isinstance(i, cls) else i for i in values]
+        ))
+
+    def take(self, n):
+        import numpy as np
+        state = np.random.RandomState(self._seed)
+        torch.manual_seed(state.randint(low=0, high=2 ** 63, dtype=np.int64))
+        for _ in range(n):
+            params = self._generate(state)
+            tensors = {}
+            tensor_properties = {}
+            for t in self._tensors:
+                tensor, properties = t._make_tensor(params, state)
+                tensors[t.name] = tensor
+                tensor_properties[t.name] = properties
+            yield tensors, tensor_properties, params
+
+    @property
+    def rejection_rate(self):
+        if not self._total_generated:
+            return 0.
+        return self._rejections / self._total_generated
+
+    def _generate(self, state):
+        strict_params: dict[str, Union[float, int, ParameterAlias]] = {}
+        for _ in range(1000):
+            candidate_params: dict[str, Union[float, int, ParameterAlias]] = {}
+            for p in self._parameters:
+                if p.strict:
+                    if p.name in strict_params:
+                        candidate_params[p.name] = strict_params[p.name]
+                    else:
+                        candidate_params[p.name] = p.sample(state)
+                        strict_params[p.name] = candidate_params[p.name]
+                else:
+                    candidate_params[p.name] = p.sample(state)
+
+            candidate_params = self._resolve_aliases(candidate_params)
+
+            self._total_generated += 1
+            if not all(f(candidate_params) for f in self._constraints):
+                self._rejections += 1
+                continue
+
+            if not all(t.satisfies_constraints(candidate_params) for t in self._tensors):
+                self._rejections += 1
+                continue
+
+            return candidate_params
+        raise ValueError("Failed to generate a set of valid parameters.")
+
+    @staticmethod
+    def _resolve_aliases(params):
+        params = dict(params)
+        alias_count = sum(isinstance(v, ParameterAlias) for v in params.values())
+
+        keys = list(params.keys())
+        while alias_count:
+            for k in keys:
+                v = params[k]
+                if isinstance(v, ParameterAlias):
+                    params[k] = params[v.alias_to]
+            alias_count_new = sum(isinstance(v, ParameterAlias) for v in params.values())
+            if alias_count == alias_count_new:
+                raise ValueError(f"ParameterAlias cycle detected\n{params}")
+
+            alias_count = alias_count_new
+
+        return params
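
The fuzzing API above (Fuzzer, FuzzedParameter, ParameterAlias, FuzzedTensor) is re-exported from `torch.utils.benchmark`. A minimal usage sketch under that assumption; the parameter names `k` and `n`, the bounds, and the probabilities are illustrative, not taken from the file:

    from torch.utils.benchmark import Fuzzer, FuzzedParameter, FuzzedTensor, ParameterAlias

    fuzzer = Fuzzer(
        parameters=[
            # Log-uniform sizes cover several orders of magnitude.
            FuzzedParameter("k", minval=1, maxval=1024, distribution="loguniform"),
            # `n` is either 1 or aliases `k`, so broadcasting cases are generated.
            FuzzedParameter("n", distribution={1: 0.5, ParameterAlias("k"): 0.5}),
        ],
        tensors=[
            FuzzedTensor("x", size=("k", "n"), probability_contiguous=0.75, max_elements=2 ** 20),
        ],
        seed=0,
    )

    for tensors, tensor_properties, params in fuzzer.take(3):
        print(params["k"], params["n"], tensors["x"].shape, tensor_properties["x"]["is_contiguous"])
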
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/sparse_fuzzer.py b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/sparse_fuzzer.py
new file mode 100644
index 0000000000000000000000000000000000000000..498f94ca26f11118c620214534f3bc6e91ed2f5d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/sparse_fuzzer.py
@@ -0,0 +1,121 @@
+# mypy: allow-untyped-defs
+from typing import Optional, Union
+from numbers import Number
+import torch
+from torch.utils.benchmark import FuzzedTensor
+import math
+
+class FuzzedSparseTensor(FuzzedTensor):
+    def __init__(
+        self,
+        name: str,
+        size: tuple[Union[str, int], ...],
+        min_elements: Optional[int] = None,
+        max_elements: Optional[int] = None,
+        dim_parameter: Optional[str] = None,
+        sparse_dim: Optional[str] = None,
+        nnz: Optional[str] = None,
+        density: Optional[str] = None,
+        coalesced: Optional[str] = None,
+        dtype=torch.float32,
+        cuda=False
+    ):
+        """
+        Args:
+            name:
+                A string identifier for the generated Tensor.
+            size:
+                A tuple of integers or strings specifying the size of the generated
+                Tensor. String values will be replaced with a concrete int during the
+                generation process, while ints are simply passed as literals.
+            min_elements:
+                The minimum number of elements that this Tensor must have for a
+                set of parameters to be valid. (Otherwise they are resampled.)
+            max_elements:
+                Like `min_elements`, but setting an upper bound.
+            dim_parameter:
+                The length of `size` will be truncated to this value.
+                This allows Tensors of varying dimensions to be generated by the
+                Fuzzer.
+            sparse_dim:
+                The number of sparse dimensions in a sparse tensor.
+            density:
+                This value allows tensors of varying sparsities to be generated by the Fuzzer.
+            coalesced:
+                The sparse tensor format permits uncoalesced sparse tensors,
+                where there may be duplicate coordinates in the indices.
+            dtype:
+                The PyTorch dtype of the generated Tensor.
+            cuda:
+                Whether to place the Tensor on a GPU.
+        """
+        super().__init__(name=name, size=size, min_elements=min_elements,
+                         max_elements=max_elements, dim_parameter=dim_parameter, dtype=dtype, cuda=cuda)
+        self._density = density
+        self._coalesced = coalesced
+        self._sparse_dim = sparse_dim
+
+    @staticmethod
+    def sparse_tensor_constructor(size, dtype, sparse_dim, nnz, is_coalesced):
+        """sparse_tensor_constructor creates a sparse tensor with coo format.
+
+        Note that when `is_coalesced` is False, the number of elements is doubled but the number of indices
+        represents the same amount of number of non zeros `nnz`, i.e, this is virtually the same tensor
+        with the same sparsity pattern. Moreover, most of the sparse operation will use coalesce() method
+        and what we want here is to get a sparse tensor with the same `nnz` even if this is coalesced or not.
+
+        In the other hand when `is_coalesced` is True the number of elements is reduced in the coalescing process
+        by an unclear amount however the probability to generate duplicates indices are low for most of the cases.
+        This decision was taken on purpose to maintain the construction cost as low as possible.
+        """
+        if isinstance(size, Number):
+            size = [size] * sparse_dim
+        assert all(size[d] > 0 for d in range(sparse_dim)) or nnz == 0, 'invalid arguments'
+        v_size = [nnz] + list(size[sparse_dim:])
+        if dtype.is_floating_point:
+            v = torch.rand(size=v_size, dtype=dtype, device="cpu")
+        else:
+            v = torch.randint(1, 127, size=v_size, dtype=dtype, device="cpu")
+
+        i = torch.rand(sparse_dim, nnz, device="cpu")
+        i.mul_(torch.tensor(size[:sparse_dim]).unsqueeze(1).to(i))
+        i = i.to(torch.long)
+
+        if not is_coalesced:
+            v = torch.cat([v, torch.randn_like(v)], 0)
+            i = torch.cat([i, i], 1)
+
+        x = torch.sparse_coo_tensor(i, v, torch.Size(size))
+        if is_coalesced:
+            x = x.coalesce()
+        return x
+
+    def _make_tensor(self, params, state):
+        size, _, _ = self._get_size_and_steps(params)
+        density = params['density']
+        nnz = math.ceil(sum(size) * density)
+        assert nnz <= sum(size)
+
+        is_coalesced = params['coalesced']
+        sparse_dim = params['sparse_dim'] if self._sparse_dim else len(size)
+        sparse_dim = min(sparse_dim, len(size))
+        tensor = self.sparse_tensor_constructor(size, self._dtype, sparse_dim, nnz, is_coalesced)
+
+        if self._cuda:
+            tensor = tensor.cuda()
+        sparse_dim = tensor.sparse_dim()
+        dense_dim = tensor.dense_dim()
+        is_hybrid = len(size[sparse_dim:]) > 0
+
+        properties = {
+            "numel": int(tensor.numel()),
+            "shape": tensor.size(),
+            "is_coalesced": tensor.is_coalesced(),
+            "density": density,
+            "sparsity": 1.0 - density,
+            "sparse_dim": sparse_dim,
+            "dense_dim": dense_dim,
+            "is_hybrid": is_hybrid,
+            "dtype": str(self._dtype),
+        }
+        return tensor, properties
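
FuzzedSparseTensor._make_tensor above reads the `density`, `coalesced`, and (optionally) `sparse_dim` entries from the sampled parameters, so the driving Fuzzer must define parameters with those names. A minimal sketch under that assumption; the size bounds and probabilities are illustrative:

    from torch.utils.benchmark import Fuzzer, FuzzedParameter
    from torch.utils.benchmark.utils.sparse_fuzzer import FuzzedSparseTensor

    fuzzer = Fuzzer(
        parameters=[
            FuzzedParameter("k", minval=8, maxval=4096, distribution="loguniform"),
            # Dict keys are the sampled choices, values their probabilities.
            FuzzedParameter("density", distribution={0.1: 0.5, 0.5: 0.5}),
            FuzzedParameter("coalesced", distribution={True: 0.5, False: 0.5}),
        ],
        tensors=[
            FuzzedSparseTensor("x", size=("k", "k"), min_elements=64),
        ],
        seed=0,
    )

    for tensors, tensor_properties, _ in fuzzer.take(2):
        props = tensor_properties["x"]
        print(tensors["x"].shape, props["sparsity"], props["is_coalesced"])
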
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/timeit_template.cpp b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/timeit_template.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..30b6f79c0b5aebca676035ac0b7c08cfce639b23
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/timeit_template.cpp
@@ -0,0 +1,43 @@
+/* C++ template for Timer.timeit
+
+This template will be consumed by `cpp_jit.py`, and will replace:
+    `GLOBAL_SETUP_TEMPLATE_LOCATION`,
+    `SETUP_TEMPLATE_LOCATION`
+      and
+    `STMT_TEMPLATE_LOCATION`
+sections with user provided statements.
+*/
+#include <chrono>
+
+#include <c10/util/irange.h>
+#include <pybind11/pybind11.h>
+#include <torch/extension.h>
+
+// Global setup. (e.g. #includes)
+// GLOBAL_SETUP_TEMPLATE_LOCATION
+
+double timeit(int n) {
+  pybind11::gil_scoped_release no_gil;
+
+  // Setup
+  // SETUP_TEMPLATE_LOCATION
+
+  {
+    // Warmup
+    // STMT_TEMPLATE_LOCATION
+  }
+
+  // Main loop
+  auto start_time = std::chrono::high_resolution_clock::now();
+  for (const auto loop_idx : c10::irange(n)) {
+    (void)loop_idx;
+    // STMT_TEMPLATE_LOCATION
+  }
+  auto end_time = std::chrono::high_resolution_clock::now();
+  return std::chrono::duration<double>(end_time - start_time).count();
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("timeit", &timeit);
+}
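
timeit_template.cpp is compiled on demand by `cpp_jit.py` when a Timer is constructed with `language="cpp"`; the user's `stmt`/`setup` strings are spliced into the `*_TEMPLATE_LOCATION` markers. A hedged sketch of that path (the C++ snippet is illustrative, a working C++ toolchain is required, and on a machine with a visible GPU `timer=timeit.default_timer` must be passed, as the CPPTimer check in timer.py below explains):

    from torch.utils.benchmark import Timer

    t = Timer(
        stmt="torch::Tensor y = x + 1;",
        setup="torch::Tensor x = torch::ones({64, 64});",
        language="cpp",
    )
    # Returns a Measurement built from one timed run of 100 iterations.
    print(t.timeit(100))
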
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/timer.py b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/timer.py
new file mode 100644
index 0000000000000000000000000000000000000000..731ac21359a47a224bb944199bf53f81488dba88
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/timer.py
@@ -0,0 +1,541 @@
+"""Timer class based on the timeit.Timer class, but torch aware."""
+import enum
+import timeit
+import textwrap
+from typing import overload, Any, Callable, NoReturn, Optional, Union
+
+import torch
+from torch.utils.benchmark.utils import common, cpp_jit
+from torch.utils.benchmark.utils._stubs import TimerClass, TimeitModuleType
+from torch.utils.benchmark.utils.valgrind_wrapper import timer_interface as valgrind_timer_interface
+
+
+__all__ = ["Timer", "timer", "Language"]
+
+
+if torch.backends.cuda.is_built() and torch.cuda.is_available():  # type: ignore[no-untyped-call]
+    def timer() -> float:
+        torch.cuda.synchronize()
+        return timeit.default_timer()
+elif torch.xpu.is_available():
+    def timer() -> float:
+        torch.xpu.synchronize()
+        return timeit.default_timer()
+elif torch._C._get_privateuse1_backend_name() != "privateuseone":
+    privateuse1_device_handler = getattr(torch, torch._C._get_privateuse1_backend_name(), None) \
+        if torch._C._get_privateuse1_backend_name() != "cpu" else None
+
+    def timer() -> float:
+        if privateuse1_device_handler:
+            privateuse1_device_handler.synchronize()
+        return timeit.default_timer()
+else:
+    timer = timeit.default_timer
+
+
+class Language(enum.Enum):
+    PYTHON = 0
+    CPP = 1
+
+
+class CPPTimer:
+    def __init__(
+        self,
+        stmt: str,
+        setup: str,
+        global_setup: str,
+        timer: Callable[[], float],
+        globals: dict[str, Any],
+    ) -> None:
+        if timer is not timeit.default_timer:
+            raise NotImplementedError(
+                "PyTorch was built with CUDA and a GPU is present; however "
+                "Timer does not yet support GPU measurements. If your "
+                "code is CPU only, pass `timer=timeit.default_timer` to the "
+                "Timer's constructor to indicate this. (Note that this will "
+                "produce incorrect results if the GPU is in fact used, as "
+                "Timer will not synchronize CUDA.)"
+            )
+
+        if globals:
+            raise ValueError("C++ timing does not support globals.")
+
+        self._stmt: str = textwrap.dedent(stmt)
+        self._setup: str = textwrap.dedent(setup)
+        self._global_setup: str = textwrap.dedent(global_setup)
+        self._timeit_module: Optional[TimeitModuleType] = None
+
+    def timeit(self, number: int) -> float:
+        if self._timeit_module is None:
+            self._timeit_module = cpp_jit.compile_timeit_template(
+                stmt=self._stmt,
+                setup=self._setup,
+                global_setup=self._global_setup,
+            )
+
+        return self._timeit_module.timeit(number)
+
+
+class Timer:
+    """Helper class for measuring execution time of PyTorch statements.
+
+    For a full tutorial on how to use this class, see:
+    https://pytorch.org/tutorials/recipes/recipes/benchmark.html
+
+    The PyTorch Timer is based on `timeit.Timer` (and in fact uses
+    `timeit.Timer` internally), but with several key differences:
+
+    1) Runtime aware:
+        Timer will perform warmups (important as some elements of PyTorch are
+        lazily initialized), set threadpool size so that comparisons are
+        apples-to-apples, and synchronize asynchronous CUDA functions when
+        necessary.
+
+    2) Focus on replicates:
+        When measuring code, and particularly complex kernels / models,
+        run-to-run variation is a significant confounding factor. It is
+        expected that all measurements should include replicates to quantify
+        noise and allow median computation, which is more robust than mean.
+        To that effect, this class deviates from the `timeit` API by
+        conceptually merging `timeit.Timer.repeat` and `timeit.Timer.autorange`.
+        (Exact algorithms are discussed in method docstrings.) The `timeit`
+        method is replicated for cases where an adaptive strategy is not
+        desired.
+
+    3) Optional metadata:
+        When defining a Timer, one can optionally specify `label`, `sub_label`,
+        `description`, and `env`. (Defined later.) These fields are included in
+        the representation of the result object and used by the `Compare` class
+        to group and display results for comparison.
+
+    4) Instruction counts
+        In addition to wall times, Timer can run a statement under Callgrind
+        and report instructions executed.
+
+    Directly analogous to `timeit.Timer` constructor arguments:
+
+        `stmt`, `setup`, `timer`, `globals`
+
+    PyTorch Timer specific constructor arguments:
+
+        `label`, `sub_label`, `description`, `env`, `num_threads`
+
+    Args:
+        stmt: Code snippet to be run in a loop and timed.
+
+        setup: Optional setup code. Used to define variables used in `stmt`
+
+        global_setup: (C++ only)
+            Code which is placed at the top level of the file for things like
+            `#include` statements.
+
+        timer:
+            Callable which returns the current time. If PyTorch was built
+            without CUDA or there is no GPU present, this defaults to
+            `timeit.default_timer`; otherwise it will synchronize CUDA before
+            measuring the time.
+
+        globals:
+            A dict which defines the global variables when `stmt` is being
+            executed. This is the other method for providing variables which
+            `stmt` needs.
+
+        label:
+            String which summarizes `stmt`. For instance, if `stmt` is
+            "torch.nn.functional.relu(torch.add(x, 1, out=out))"
+            one might set label to "ReLU(x + 1)" to improve readability.
+
+        sub_label:
+            Provide supplemental information to disambiguate measurements
+            with identical stmt or label. For instance, in our example
+            above sub_label might be "float" or "int", so that it is easy
+            to differentiate:
+            "ReLU(x + 1): (float)"
+
+            "ReLU(x + 1): (int)"
+            when printing Measurements or summarizing using `Compare`.
+
+        description:
+            String to distinguish measurements with identical label and
+            sub_label. The principal use of `description` is to signal to
+            `Compare` the columns of data. For instance one might set it
+            based on the input size to create a table of the form: ::
+
+                                        | n=1 | n=4 | ...
+                                        ------------- ...
+                ReLU(x + 1): (float)    | ... | ... | ...
+                ReLU(x + 1): (int)      | ... | ... | ...
+
+
+            using `Compare`. It is also included when printing a Measurement.
+
+        env:
+            This tag indicates that otherwise identical tasks were run in
+            different environments, and are therefore not equivalent, for
+            instance when A/B testing a change to a kernel. `Compare` will
+            treat Measurements with different `env` specification as distinct
+            when merging replicate runs.
+
+        num_threads:
+            The size of the PyTorch threadpool when executing `stmt`. Single
+            threaded performance is important as both a key inference workload
+            and a good indicator of intrinsic algorithmic efficiency, so the
+            default is set to one. This is in contrast to the default PyTorch
+            threadpool size which tries to utilize all cores.
+    """
+
+    _timer_cls: type[TimerClass] = timeit.Timer
+
+    def __init__(
+        self,
+        stmt: str = "pass",
+        setup: str = "pass",
+        global_setup: str = "",
+        timer: Callable[[], float] = timer,
+        globals: Optional[dict[str, Any]] = None,
+        label: Optional[str] = None,
+        sub_label: Optional[str] = None,
+        description: Optional[str] = None,
+        env: Optional[str] = None,
+        num_threads: int = 1,
+        language: Union[Language, str] = Language.PYTHON,
+    ):
+        if not isinstance(stmt, str):
+            raise ValueError("Currently only a `str` stmt is supported.")
+
+        # We copy `globals` to prevent mutations from leaking.
+        # (For instance, `eval` adds the `__builtins__` key)
+        self._globals = dict(globals or {})
+
+        timer_kwargs = {}
+        if language in (Language.PYTHON, "py", "python"):
+            # Include `torch` if not specified as a convenience feature.
+            self._globals.setdefault("torch", torch)
+            self._language: Language = Language.PYTHON
+            if global_setup:
+                raise ValueError(
+                    f"global_setup is C++ only, got `{global_setup}`. Most "
+                    "likely this code can simply be moved to `setup`."
+                )
+
+        elif language in (Language.CPP, "cpp", "c++"):
+            assert self._timer_cls is timeit.Timer, "_timer_cls has already been swapped."
+            self._timer_cls = CPPTimer
+            setup = ("" if setup == "pass" else setup)
+            self._language = Language.CPP
+            timer_kwargs["global_setup"] = global_setup
+
+        else:
+            raise ValueError(f"Invalid language `{language}`.")
+
+        # Convenience adjustment so that multi-line code snippets defined in
+        # functions do not raise an IndentationError (Python) or look odd (C++). The
+        # leading newline removal is for the initial newline that appears when
+        # defining block strings. For instance:
+        #   textwrap.dedent("""
+        #     print("This is a stmt")
+        #   """)
+        # produces '\nprint("This is a stmt")\n'.
+        #
+        # Stripping this down to 'print("This is a stmt")' doesn't change
+        # what gets executed, but it makes __repr__'s nicer.
+        stmt = textwrap.dedent(stmt)
+        stmt = (stmt[1:] if stmt and stmt[0] == "\n" else stmt).rstrip()
+        setup = textwrap.dedent(setup)
+        setup = (setup[1:] if setup and setup[0] == "\n" else setup).rstrip()
+
+        self._timer = self._timer_cls(
+            stmt=stmt,
+            setup=setup,
+            timer=timer,
+            globals=valgrind_timer_interface.CopyIfCallgrind.unwrap_all(self._globals),
+            **timer_kwargs,
+        )
+        self._task_spec = common.TaskSpec(
+            stmt=stmt,
+            setup=setup,
+            global_setup=global_setup,
+            label=label,
+            sub_label=sub_label,
+            description=description,
+            env=env,
+            num_threads=num_threads,
+        )
+
+    def _timeit(self, number: int) -> float:
+        # Even calling a timer in C++ takes ~50 ns, so no real operation should
+        # take less than 1 ns. (And this prevents divide by zero errors.)
+        return max(self._timer.timeit(number), 1e-9)
+
+    def timeit(self, number: int = 1000000) -> common.Measurement:
+        """Mirrors the semantics of timeit.Timer.timeit().
+
+        Execute the main statement (`stmt`) `number` times.
+        https://docs.python.org/3/library/timeit.html#timeit.Timer.timeit
+        """
+        with common.set_torch_threads(self._task_spec.num_threads):
+            # Warmup
+            self._timeit(number=max(int(number // 100), 2))
+
+            return common.Measurement(
+                number_per_run=number,
+                raw_times=[self._timeit(number=number)],
+                task_spec=self._task_spec
+            )
+
+    def repeat(self, repeat: int = -1, number: int = -1) -> None:
+        raise NotImplementedError("See `Timer.blocked_autorange.`")
+
+    def autorange(self, callback: Optional[Callable[[int, float], NoReturn]] = None) -> None:
+        raise NotImplementedError("See `Timer.blocked_autorange.`")
+
+    def _threaded_measurement_loop(
+        self,
+        number: int,
+        time_hook: Callable[[], float],
+        stop_hook: Callable[[list[float]], bool],
+        min_run_time: float,
+        max_run_time: Optional[float] = None,
+        callback: Optional[Callable[[int, float], NoReturn]] = None
+    ) -> list[float]:
+        total_time = 0.0
+        can_stop = False
+        times: list[float] = []
+        with common.set_torch_threads(self._task_spec.num_threads):
+            while (total_time < min_run_time) or (not can_stop):
+                time_spent = time_hook()
+                times.append(time_spent)
+                total_time += time_spent
+                if callback:
+                    callback(number, time_spent)
+                can_stop = stop_hook(times)
+                if max_run_time and total_time > max_run_time:
+                    break
+        return times
+
+    def _estimate_block_size(self, min_run_time: float) -> int:
+        with common.set_torch_threads(self._task_spec.num_threads):
+            # Estimate the block size needed for measurement to be negligible
+            # compared to the inner loop. This also serves as a warmup.
+            overhead = torch.tensor([self._timeit(0) for _ in range(5)]).median().item()
+            number = 1
+            while True:
+                time_taken = self._timeit(number)
+                relative_overhead = overhead / time_taken
+                if relative_overhead <= 1e-4 and time_taken >= min_run_time / 1000:
+                    break
+                if time_taken > min_run_time:
+                    break
+                # Avoid overflow in C++ pybind11 interface
+                if number * 10 > 2147483647:
+                    break
+                number *= 10
+        return number
+
+    def blocked_autorange(
+        self,
+        callback: Optional[Callable[[int, float], NoReturn]] = None,
+        min_run_time: float = 0.2,
+    ) -> common.Measurement:
+        """Measure many replicates while keeping timer overhead to a minimum.
+
+        At a high level, blocked_autorange executes the following pseudo-code::
+
+            `setup`
+
+            total_time = 0
+            while total_time < min_run_time
+                start = timer()
+                for _ in range(block_size):
+                    `stmt`
+                total_time += (timer() - start)
+
+        Note the variable `block_size` in the inner loop. The choice of block
+        size is important to measurement quality, and must balance two
+        competing objectives:
+
+            1) A small block size results in more replicates and generally
+               better statistics.
+
+            2) A large block size better amortizes the cost of `timer`
+               invocation, and results in a less biased measurement. This is
+               important because CUDA synchronization time is non-trivial
+               (order single to low double digit microseconds) and would
+               otherwise bias the measurement.
+
+        blocked_autorange sets block_size by running a warmup period,
+        increasing block size until timer overhead is less than 0.1% of
+        the overall computation. This value is then used for the main
+        measurement loop.
+
+        Returns:
+            A `Measurement` object that contains measured runtimes and
+            repetition counts, and can be used to compute statistics.
+            (mean, median, etc.)
+        """
+        number = self._estimate_block_size(min_run_time)
+
+        def time_hook() -> float:
+            return self._timeit(number)
+
+        def stop_hook(times: list[float]) -> bool:
+            return True
+
+        times = self._threaded_measurement_loop(
+            number, time_hook, stop_hook,
+            min_run_time=min_run_time,
+            callback=callback)
+
+        return common.Measurement(
+            number_per_run=number,
+            raw_times=times,
+            task_spec=self._task_spec
+        )
+
+    def adaptive_autorange(
+            self,
+            threshold: float = 0.1,
+            *,
+            min_run_time: float = 0.01,
+            max_run_time: float = 10.0,
+            callback: Optional[Callable[[int, float], NoReturn]] = None,
+    ) -> common.Measurement:
+        """Similar to `blocked_autorange` but also checks for variablility in measurements
+        and repeats until iqr/median is smaller than `threshold` or `max_run_time` is reached.
+
+
+        At a high level, adaptive_autorange executes the following pseudo-code::
+
+            `setup`
+
+            times = []
+            while times.sum < max_run_time
+                start = timer()
+                for _ in range(block_size):
+                    `stmt`
+                times.append(timer() - start)
+
+                enough_data = len(times)>3 and times.sum > min_run_time
+                small_iqr=times.iqr/times.mean<threshold
+                if enough_data and small_iqr:
+                    break
+
+        Args:
+            threshold: value of max IQR/median threshold for stopping
+
+            min_run_time: total required time in seconds for all measurements before checking the threshold
+
+            max_run_time: total allowed time in seconds for all measurements regardless of threshold
+
+        Returns:
+            A `Measurement` object that contains measured runtimes and
+            repetition counts, and can be used to compute statistics.
+            (mean, median, etc.)
+        """
+        number = self._estimate_block_size(min_run_time=0.05)
+
+        def time_hook() -> float:
+            return self._timeit(number)
+
+        def stop_hook(times: list[float]) -> bool:
+            if len(times) > 3:
+                return common.Measurement(
+                    number_per_run=number,
+                    raw_times=times,
+                    task_spec=self._task_spec
+                ).meets_confidence(threshold=threshold)
+            return False
+        times = self._threaded_measurement_loop(
+            number, time_hook, stop_hook, min_run_time, max_run_time, callback=callback)
+
+        return common.Measurement(
+            number_per_run=number,
+            raw_times=times,
+            task_spec=self._task_spec
+        )
+
+    @overload
+    def collect_callgrind(
+        self,
+        number: int,
+        *,
+        repeats: None,
+        collect_baseline: bool,
+        retain_out_file: bool,
+    ) -> valgrind_timer_interface.CallgrindStats:
+        ...
+
+    @overload
+    def collect_callgrind(
+        self,
+        number: int,
+        *,
+        repeats: int,
+        collect_baseline: bool,
+        retain_out_file: bool,
+    ) -> tuple[valgrind_timer_interface.CallgrindStats, ...]:
+        ...
+
+    def collect_callgrind(
+        self,
+        number: int = 100,
+        *,
+        repeats: Optional[int] = None,
+        collect_baseline: bool = True,
+        retain_out_file: bool = False,
+    ) -> Any:
+        """Collect instruction counts using Callgrind.
+
+        Unlike wall times, instruction counts are deterministic
+        (modulo non-determinism in the program itself and small amounts of
+        jitter from the Python interpreter.) This makes them ideal for detailed
+        performance analysis. This method runs `stmt` in a separate process
+        so that Valgrind can instrument the program. Performance is severely
+        degraded due to the instrumentation, however this is ameliorated by
+        the fact that a small number of iterations is generally sufficient to
+        obtain good measurements.
+
+        In order to use this method, `valgrind`, `callgrind_control`, and
+        `callgrind_annotate` must be installed.
+
+        Because there is a process boundary between the caller (this process)
+        and the `stmt` execution, `globals` cannot contain arbitrary in-memory
+        data structures. (Unlike timing methods) Instead, globals are
+        restricted to builtins, `nn.Modules`'s, and TorchScripted functions/modules
+        to reduce the surprise factor from serialization and subsequent
+        deserialization. The `GlobalsBridge` class provides more detail on this
+        subject. Take particular care with nn.Modules: they rely on pickle and
+        you may need to add an import to `setup` for them to transfer properly.
+
+        By default, a profile for an empty statement will be collected and
+        cached to indicate how many instructions are from the Python loop which
+        drives `stmt`.
+
+        Returns:
+            A `CallgrindStats` object which provides instruction counts and
+            some basic facilities for analyzing and manipulating results.
+        """
+        if not isinstance(self._task_spec.stmt, str):
+            raise ValueError("`collect_callgrind` currently only supports string `stmt`")
+
+        if repeats is not None and repeats < 1:
+            raise ValueError("If specified, `repeats` must be >= 1")
+
+        # Check that the statement is valid. It doesn't guarantee success, but it's much
+        # simpler and quicker to raise an exception for a faulty `stmt` or `setup` in
+        # the parent process rather than the valgrind subprocess.
+        self._timeit(1)
+        is_python = (self._language == Language.PYTHON)
+        assert is_python or not self._globals
+        result = valgrind_timer_interface.wrapper_singleton().collect_callgrind(
+            task_spec=self._task_spec,
+            globals=self._globals,
+            number=number,
+            repeats=repeats or 1,
+            collect_baseline=collect_baseline and is_python,
+            is_python=is_python,
+            retain_out_file=retain_out_file,
+        )
+
+        return (result[0] if repeats is None else result)
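
A short, hedged sketch of the Python-side workflow described in the Timer docstrings above; the statement, sizes, and labels are illustrative:

    from torch.utils.benchmark import Timer

    t = Timer(
        stmt="y = x @ x",
        setup="x = torch.ones((128, 128))",  # `torch` is injected into globals automatically
        label="matmul",
        sub_label="128x128",
        num_threads=1,
    )
    m = t.blocked_autorange(min_run_time=0.5)  # many replicates with low timer overhead
    print(m.median, m.mean)

    m2 = t.adaptive_autorange(threshold=0.1)   # repeats until IQR/median falls below the threshold
    print(m2)
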
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/__init__.py b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h
new file mode 100644
index 0000000000000000000000000000000000000000..f078cc82b95daf94d2bea51f1e1b1a8c12daea23
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h
@@ -0,0 +1,129 @@
+
+/*
+   ----------------------------------------------------------------
+
+   Notice that the following BSD-style license applies to this one
+   file (callgrind.h) only.  The rest of Valgrind is licensed under the
+   terms of the GNU General Public License, version 2, unless
+   otherwise indicated.  See the COPYING file in the source
+   distribution for details.
+
+   ----------------------------------------------------------------
+
+   This file is part of callgrind, a valgrind tool for cache simulation
+   and call tree tracing.
+
+   Copyright (C) 2003-2017 Josef Weidendorfer.  All rights reserved.
+
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   1. Redistributions of source code must retain the above copyright
+      notice, this list of conditions and the following disclaimer.
+
+   2. The origin of this software must not be misrepresented; you must
+      not claim that you wrote the original software.  If you use this
+      software in a product, an acknowledgment in the product
+      documentation would be appreciated but is not required.
+
+   3. Altered source versions must be plainly marked as such, and must
+      not be misrepresented as being the original software.
+
+   4. The name of the author may not be used to endorse or promote
+      products derived from this software without specific prior written
+      permission.
+
+   THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS
+   OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+   WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+   ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
+   DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+   DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
+   GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+   INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
+   WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+   ----------------------------------------------------------------
+
+   Notice that the above BSD-style license applies to this one file
+   (callgrind.h) only.  The entire rest of Valgrind is licensed under
+   the terms of the GNU General Public License, version 2.  See the
+   COPYING file in the source distribution for details.
+
+   ----------------------------------------------------------------
+*/
+
+#ifndef __CALLGRIND_H
+#define __CALLGRIND_H
+
+#include "valgrind.h"
+
+/* !! ABIWARNING !! ABIWARNING !! ABIWARNING !! ABIWARNING !!
+   This enum comprises an ABI exported by Valgrind to programs
+   which use client requests.  DO NOT CHANGE THE ORDER OF THESE
+   ENTRIES, NOR DELETE ANY -- add new ones at the end.
+
+   The identification ('C','T') for Callgrind has historical
+   reasons: it was called "Calltree" before. Besides, ('C','G') would
+   clash with cachegrind.
+ */
+
+typedef
+   enum {
+      VG_USERREQ__DUMP_STATS = VG_USERREQ_TOOL_BASE('C','T'),
+      VG_USERREQ__ZERO_STATS,
+      VG_USERREQ__TOGGLE_COLLECT,
+      VG_USERREQ__DUMP_STATS_AT,
+      VG_USERREQ__START_INSTRUMENTATION,
+      VG_USERREQ__STOP_INSTRUMENTATION
+   } Vg_CallgrindClientRequest;
+
+/* Dump current state of cost centers, and zero them afterwards */
+#define CALLGRIND_DUMP_STATS                                    \
+  VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__DUMP_STATS,       \
+                                  0, 0, 0, 0, 0)
+
+/* Dump current state of cost centers, and zero them afterwards.
+   The argument is appended to a string stating the reason which triggered
+   the dump. This string is written as a description field into the
+   profile data dump. */
+#define CALLGRIND_DUMP_STATS_AT(pos_str)                        \
+  VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__DUMP_STATS_AT,    \
+                                  pos_str, 0, 0, 0, 0)
+
+/* Zero cost centers */
+#define CALLGRIND_ZERO_STATS                                    \
+  VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__ZERO_STATS,       \
+                                  0, 0, 0, 0, 0)
+
+/* Toggles collection state.
+   The collection state specifies whether the happening of events
+   should be noted or if they are to be ignored. Events are noted
+   by increment of counters in a cost center */
+#define CALLGRIND_TOGGLE_COLLECT                                \
+  VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__TOGGLE_COLLECT,   \
+                                  0, 0, 0, 0, 0)
+
+/* Start full callgrind instrumentation if not already switched on.
+   When cache simulation is done, it will flush the simulated cache;
+   this will lead to an artificial cache warmup phase afterwards with
+   cache misses which would not have happened in reality. */
+#define CALLGRIND_START_INSTRUMENTATION                              \
+  VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__START_INSTRUMENTATION, \
+                                  0, 0, 0, 0, 0)
+
+/* Stop full callgrind instrumentation if not already switched off.
+   This flushes Valgrinds translation cache, and does no additional
+   instrumentation afterwards, which effectivly will run at the same
+   speed as the "none" tool (ie. at minimal slowdown).
+   Use this to bypass Callgrind aggregation for uninteresting code parts.
+   To start Callgrind in this mode to ignore the setup phase, use
+   the option "--instr-atstart=no". */
+#define CALLGRIND_STOP_INSTRUMENTATION                               \
+  VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__STOP_INSTRUMENTATION,  \
+                                  0, 0, 0, 0, 0)
+
+#endif /* __CALLGRIND_H */
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/compat_bindings.cpp b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/compat_bindings.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..cd41f0de092f0b1488c8945edf2af80c6f9b596c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/compat_bindings.cpp
@@ -0,0 +1,35 @@
+/* Used to collect profiles of old versions of PyTorch. */
+#include <callgrind.h>
+#include <torch/extension.h>
+
+bool _valgrind_supported_platform() {
+#if defined(NVALGRIND)
+  return false;
+#else
+  return true;
+#endif
+}
+
+void _valgrind_toggle() {
+#if defined(NVALGRIND)
+  TORCH_CHECK(false, "Valgrind is not supported.");
+#else
+  CALLGRIND_TOGGLE_COLLECT;
+#endif
+}
+
+void _valgrind_toggle_and_dump_stats() {
+#if defined(NVALGRIND)
+  TORCH_CHECK(false, "Valgrind is not supported.");
+#else
+  // NB: See note in Module.cpp
+  CALLGRIND_TOGGLE_COLLECT;
+  CALLGRIND_DUMP_STATS;
+#endif
+}
+
+PYBIND11_MODULE(callgrind_bindings, m) {
+  m.def("_valgrind_supported_platform", &_valgrind_supported_platform);
+  m.def("_valgrind_toggle", &_valgrind_toggle);
+  m.def("_valgrind_toggle_and_dump_stats", &_valgrind_dump_stats);
+}
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/timer_callgrind_template.cpp b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/timer_callgrind_template.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..587685c7df7445b299c35462307f47cf6012a00d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/timer_callgrind_template.cpp
@@ -0,0 +1,68 @@
+/* C++ template for Timer.collect_callgrind
+
+This template will be consumed by `cpp_jit.py`, and will replace:
+    `GLOBAL_SETUP_TEMPLATE_LOCATION`,
+    `SETUP_TEMPLATE_LOCATION`
+      and
+    `STMT_TEMPLATE_LOCATION`
+sections with user provided statements.
+*/
+
+#include <callgrind.h>
+#include <string>
+#include <torch/torch.h>
+
+#include <c10/util/irange.h>
+
+// Global setup. (e.g. #includes)
+// GLOBAL_SETUP_TEMPLATE_LOCATION
+
+#if defined(NVALGRIND)
+static_assert(false);
+#endif
+
+int main(int argc, char* argv[]) {
+  // This file should only be called inside of `Timer`, so we can adopt a
+  // very simple and rigid argument parsing scheme.
+  TORCH_CHECK(argc == 9);
+  TORCH_CHECK(std::string(argv[1]) == "--number");
+  auto number = std::stoi(argv[2]);
+
+  TORCH_CHECK(
+      std::string(argv[3]) == "--number-warmup" ||
+      std::string(argv[3]) == "--number_warmup");
+  auto number_warmup = std::stoi(argv[4]);
+
+  TORCH_CHECK(std::string(argv[5]) == "--repeats");
+  auto repeats = std::stoi(argv[6]);
+
+  TORCH_CHECK(
+      std::string(argv[7]) == "--number-threads" ||
+      std::string(argv[7]) == "--number_threads");
+  auto number_threads = std::stoi(argv[8]);
+  torch::set_num_threads(number_threads);
+
+  // Setup
+  // SETUP_TEMPLATE_LOCATION
+
+  // Warmup
+  for (const auto i : c10::irange(number_warmup)) {
+    (void)i;
+    // STMT_TEMPLATE_LOCATION
+  }
+
+  // Main loop
+  for (const auto repeat : c10::irange(repeats)) {
+    (void)repeat;
+    CALLGRIND_TOGGLE_COLLECT;
+
+    for (const auto i : c10::irange(number)) {
+      (void)i;
+      // STMT_TEMPLATE_LOCATION
+    }
+
+    // NB: See note in Module.cpp
+    CALLGRIND_TOGGLE_COLLECT;
+    CALLGRIND_DUMP_STATS;
+  }
+}
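
timer_callgrind_template.cpp is the standalone program that `Timer.collect_callgrind` builds and runs under Valgrind for C++ statements; Python statements instead go through the wrapper in timer_interface.py below. A hedged usage sketch for the common Python case (requires `valgrind`, `callgrind_control`, and `callgrind_annotate` to be installed; the statement is illustrative):

    from torch.utils.benchmark import Timer

    t = Timer(
        stmt="y = x + 1",
        setup="x = torch.ones((8, 8))",
    )
    stats = t.collect_callgrind(number=10)
    print(stats.counts(denoise=True))         # total instruction count, CPython noise stripped
    print(stats.stats(inclusive=False)[:10])  # top FunctionCounts entries
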
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..900d8c3722a8a5b9c4890d49a35e59dbb19d7bb8
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py
@@ -0,0 +1,910 @@
+"""Intermediate layer between `Timer` and `valgrind`."""
+import collections
+import enum
+import dataclasses
+import itertools as it
+import os
+import pickle
+import re
+import shutil
+import subprocess
+import sys
+import textwrap
+from typing import (
+    cast, Any, Callable, NamedTuple,
+    Optional, Union, TYPE_CHECKING)
+from collections.abc import Iterator
+
+import torch
+from torch.utils.benchmark.utils import common, cpp_jit
+from torch.utils.benchmark.utils._stubs import CallgrindModuleType
+import operator
+
+
+__all__ = ["FunctionCount", "FunctionCounts", "CallgrindStats", "CopyIfCallgrind"]
+
+
+if TYPE_CHECKING:
+    CompletedProcessType = subprocess.CompletedProcess[str]
+else:
+    CompletedProcessType = subprocess.CompletedProcess
+
+
+class FunctionCount(NamedTuple):
+    # TODO(#105471): Rename the count field
+    count: int  # type: ignore[assignment]
+    function: str
+
+
+@dataclasses.dataclass(repr=False, eq=False, frozen=True)
+class FunctionCounts:
+    """Container for manipulating Callgrind results.
+
+    It supports:
+        1) Addition and subtraction to combine or diff results.
+        2) Tuple-like indexing.
+        3) A `denoise` function which strips CPython calls which are known to
+           be non-deterministic and quite noisy.
+        4) Two higher order methods (`filter` and `transform`) for custom
+           manipulation.
+    """
+    _data: tuple[FunctionCount, ...]
+    inclusive: bool
+    truncate_rows: bool = True
+
+    # For normal use, torch._tensor_str.PRINT_OPTS.linewidth determines
+    # the print settings. This is simply to allow hermetic unit tests.
+    _linewidth: Optional[int] = None
+
+    def __iter__(self) -> Iterator[FunctionCount]:
+        yield from self._data
+
+    def __len__(self) -> int:
+        return len(self._data)
+
+    def __getitem__(self, item: Any) -> Union[FunctionCount, "FunctionCounts"]:
+        data: Union[FunctionCount, tuple[FunctionCount, ...]] = self._data[item]
+        return (
+            FunctionCounts(cast(tuple[FunctionCount, ...], data), self.inclusive, truncate_rows=False)
+            if isinstance(data, tuple) else data
+        )
+
+    def __repr__(self) -> str:
+        count_len = 0
+        for c, _ in self:
+            # Account for sign in string length.
+            count_len = max(count_len, len(str(c)) + int(c < 0))
+
+        lines = []
+        linewidth = self._linewidth or torch._tensor_str.PRINT_OPTS.linewidth
+        fn_str_len = max(linewidth - count_len - 4, 40)
+        for c, fn in self:
+            if len(fn) > fn_str_len:
+                left_len = int((fn_str_len - 5) // 2)
+                fn = fn[:left_len] + " ... " + fn[-(fn_str_len - left_len - 5):]
+            lines.append(f"  {c:>{count_len}}  {fn}")
+
+        if self.truncate_rows and len(lines) > 18:
+            lines = lines[:9] + ["...".rjust(count_len + 2)] + lines[-9:]
+
+        if not self.inclusive:
+            lines.extend(["", f"Total: {self.sum()}"])
+
+        return "\n".join([super().__repr__()] + lines)
+
+    def __add__(
+        self,
+        other: "FunctionCounts",
+    ) -> "FunctionCounts":
+        return self._merge(other, lambda c: c)
+
+    def __sub__(
+        self,
+        other: "FunctionCounts",
+    ) -> "FunctionCounts":
+        return self._merge(other, operator.neg)
+
+    def __mul__(self, other: Union[int, float]) -> "FunctionCounts":
+        return self._from_dict({
+            fn: int(c * other) for c, fn in self._data
+        }, self.inclusive)
+
+    def transform(self, map_fn: Callable[[str], str]) -> "FunctionCounts":
+        """Apply `map_fn` to all of the function names.
+
+        This can be used to regularize function names (e.g. stripping irrelevant
+        parts of the file path), coalesce entries by mapping multiple functions
+        to the same name (in which case the counts are added together), etc.
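+
+        A small, hedged sketch (it assumes a populated `counts: FunctionCounts`)::
+
+            import re
+            stripped = counts.transform(
+                lambda fn: re.sub(r"^.+/site-packages/", "", fn))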
+        """
+        counts: collections.defaultdict[str, int] = collections.defaultdict(int)
+        for c, fn in self._data:
+            counts[map_fn(fn)] += c
+
+        return self._from_dict(counts, self.inclusive)
+
+    def filter(self, filter_fn: Callable[[str], bool]) -> "FunctionCounts":
+        """Keep only the elements where `filter_fn` applied to function name returns True."""
+        return FunctionCounts(tuple(i for i in self if filter_fn(i.function)), self.inclusive)
+
+    def sum(self) -> int:
+        return sum(c for c, _ in self)
+
+    def denoise(self) -> "FunctionCounts":
+        """Remove known noisy instructions.
+
+        Several instructions in the CPython interpreter are rather noisy. These
+        instructions involve the unicode dictionary lookups which Python uses to
+        resolve variable names. FunctionCounts is generally a content-agnostic
+        container; however, this case is sufficiently important for obtaining
+        reliable results to warrant an exception."""
+        return self.filter(lambda fn: "dictobject.c:lookdict_unicode" not in fn)
+
+    def _merge(
+        self,
+        second: "FunctionCounts",
+        merge_fn: Callable[[int], int]
+    ) -> "FunctionCounts":
+        assert self.inclusive == second.inclusive, "Cannot merge inclusive and exclusive counts."
+        counts: collections.defaultdict[str, int] = collections.defaultdict(int)
+        for c, fn in self:
+            counts[fn] += c
+
+        for c, fn in second:
+            counts[fn] += merge_fn(c)
+
+        return self._from_dict(counts, self.inclusive)
+
+    @staticmethod
+    def _from_dict(counts: dict[str, int], inclusive: bool) -> "FunctionCounts":
+        flat_counts = (FunctionCount(c, fn) for fn, c in counts.items() if c)
+        return FunctionCounts(tuple(sorted(flat_counts, reverse=True)), inclusive)
+
+
+@dataclasses.dataclass(repr=False, eq=False, frozen=True)
+class CallgrindStats:
+    """Top level container for Callgrind results collected by Timer.
+
+    Manipulation is generally done using the FunctionCounts class, which is
+    obtained by calling `CallgrindStats.stats(...)`. Several convenience
+    methods are provided as well; the most significant is
+    `CallgrindStats.as_standardized()`.
+    """
+    task_spec: common.TaskSpec
+    number_per_run: int
+    built_with_debug_symbols: bool
+    baseline_inclusive_stats: FunctionCounts
+    baseline_exclusive_stats: FunctionCounts
+    stmt_inclusive_stats: FunctionCounts
+    stmt_exclusive_stats: FunctionCounts
+    stmt_callgrind_out: Optional[str]
+
+    def __repr__(self) -> str:
+        base_stats = self.baseline_exclusive_stats
+        output = f"""
+{super().__repr__()}
+{self.task_spec.summarize()}
+  {'':>25}All{'':>10}Noisy symbols removed
+    Instructions: {self.counts(denoise=False):>12}{'':>15}{self.counts(denoise=True):>12}
+    Baseline:     {base_stats.sum():>12}{'':>15}{base_stats.denoise().sum():>12}
+{self.number_per_run} runs per measurement, {self.task_spec.num_threads} thread{'s' if self.task_spec.num_threads > 1 else ''}
+""".strip()
+        if not self.built_with_debug_symbols:
+            output += textwrap.dedent("""
+            Warning: PyTorch was not built with debug symbols.
+                     Source information may be limited. Rebuild with
+                     REL_WITH_DEB_INFO=1 for more detailed results.""")
+        return output
+
+    def stats(self, inclusive: bool = False) -> FunctionCounts:
+        """Returns detailed function counts.
+
+        Conceptually, the FunctionCounts returned can be thought of as a tuple
+        of (count, path_and_function_name) tuples.
+
+        `inclusive` matches the semantics of callgrind. If True, the counts
+        include instructions executed by children. `inclusive=True` is useful
+        for identifying hot spots in code; `inclusive=False` is useful for
+        reducing noise when diffing counts from two different runs. (See
+        CallgrindStats.delta(...) for more details)
+        """
+        return self.stmt_inclusive_stats if inclusive else self.stmt_exclusive_stats
+
+    def counts(self, *, denoise: bool = False) -> int:
+        """Returns the total number of instructions executed.
+
+        See `FunctionCounts.denoise()` for an explanation of the `denoise` arg.
+        """
+        stats = self.stmt_exclusive_stats
+        return (stats.denoise() if denoise else stats).sum()
+
+    # FIXME: Once 3.7 is the minimum version, type annotate `other` per PEP 563
+    def delta(
+        self,
+        other: "CallgrindStats",
+        inclusive: bool = False,
+    ) -> FunctionCounts:
+        """Diff two sets of counts.
+
+        One common reason to collect instruction counts is to determine the
+        effect that a particular change will have on the number of instructions
+        needed to perform some unit of work. If a change increases that number, the
+        next logical question is "why". This generally involves looking at what part
+        of the code increased in instruction count. This function automates that
+        process so that one can easily diff counts on both an inclusive and
+        exclusive basis.
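+
+        A hedged sketch (it assumes `stats_before` and `stats_after` are
+        `CallgrindStats` collected from two separate runs)::
+
+            increase = stats_after.as_standardized().delta(
+                stats_before.as_standardized()).denoise()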
+        """
+        return self.stats(inclusive=inclusive) - other.stats(inclusive=inclusive)
+
+    def as_standardized(self) -> "CallgrindStats":
+        """Strip library names and some prefixes from function strings.
+
+        When comparing two different sets of instruction counts, one stumbling
+        block can be path prefixes. Callgrind includes the full filepath
+        when reporting a function (as it should). However, this can cause
+        issues when diffing profiles. If a key component such as Python
+        or PyTorch was built in separate locations in the two profiles, this
+        can result in something resembling::
+
+            23234231 /tmp/first_build_dir/thing.c:foo(...)
+             9823794 /tmp/first_build_dir/thing.c:bar(...)
+              ...
+               53453 .../aten/src/Aten/...:function_that_actually_changed(...)
+              ...
+             -9823794 /tmp/second_build_dir/thing.c:bar(...)
+            -23234231 /tmp/second_build_dir/thing.c:foo(...)
+
+        Stripping prefixes can ameliorate this issue by regularizing the
+        strings and causing better cancellation of equivalent call sites
+        when diffing.
+        """
+        def strip(stats: FunctionCounts) -> FunctionCounts:
+            transforms = (
+                # PyTorch may have been built in different locations.
+                (r"^.+build/\.\./", "build/../"),
+                (r"^.+/" + re.escape("build/aten/"), "build/aten/"),
+
+                # "Python" and "Objects" come from CPython.
+                (r"^.+/" + re.escape("Python/"), "Python/"),
+                (r"^.+/" + re.escape("Objects/"), "Objects/"),
+
+                # Strip library name. e.g. `libtorch.so`
+                (r"\s\[.+\]$", ""),
+            )
+
+            for before, after in transforms:
+                stats = stats.transform(lambda fn: re.sub(before, after, fn))
+
+            return stats
+
+        return CallgrindStats(
+            task_spec=self.task_spec,
+            number_per_run=self.number_per_run,
+            built_with_debug_symbols=self.built_with_debug_symbols,
+            baseline_inclusive_stats=strip(self.baseline_inclusive_stats),
+            baseline_exclusive_stats=strip(self.baseline_exclusive_stats),
+            stmt_inclusive_stats=strip(self.stmt_inclusive_stats),
+            stmt_exclusive_stats=strip(self.stmt_exclusive_stats),
+
+            # `as_standardized` will change symbol names, so the contents will
+            # no longer map directly to `callgrind.out`
+            stmt_callgrind_out=None,
+        )
+
+
+class Serialization(enum.Enum):
+    PICKLE = 0
+    TORCH = 1
+    TORCH_JIT = 2
+
+
+_GLOBALS_ALLOWED_TYPES: dict[Serialization, tuple[Any, ...]] = {
+    Serialization.PICKLE: (str, bytes, bool, int, float, complex),
+    Serialization.TORCH_JIT: (torch.jit.ScriptFunction, torch.jit.ScriptModule),
+    Serialization.TORCH: (torch.nn.Module,),
+}
+
+
+class CopyIfCallgrind:
+    """Signal that a global may be replaced with a deserialized copy.
+
+    See `GlobalsBridge` for why this matters.
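+
+    A hedged sketch of typical use (it assumes a TorchScripted function
+    `my_fn` defined elsewhere; `Timer` is `torch.utils.benchmark.Timer`)::
+
+        stats = Timer(
+            "y = my_fn(x)",
+            setup="x = torch.ones((4, 4))",
+            globals={"my_fn": CopyIfCallgrind(my_fn)},
+        ).collect_callgrind()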
+    """
+    def __init__(self, value: Any, *, setup: Optional[str] = None):
+        for method, supported_types in _GLOBALS_ALLOWED_TYPES.items():
+            if any(isinstance(value, t) for t in supported_types):
+                self._value: Any = value
+                self._setup: Optional[str] = setup
+                self._serialization: Serialization = method
+                break
+        else:
+            supported_str = "\n".join([
+                getattr(t, "__name__", repr(t))
+                for t in it.chain(*_GLOBALS_ALLOWED_TYPES.values())])
+
+            raise ValueError(
+                f"Unsupported type: {type(value)}\n"
+                f"`collect_callgrind` restricts globals to the following types:\n"
+                f"{textwrap.indent(supported_str, '  ')}"
+            )
+
+    @property
+    def value(self) -> Any:
+        return self._value
+
+    @property
+    def setup(self) -> Optional[str]:
+        return self._setup
+
+    @property
+    def serialization(self) -> Serialization:
+        return self._serialization
+
+    @staticmethod
+    def unwrap_all(globals: dict[str, Any]) -> dict[str, Any]:
+        return {
+            k: (v.value if isinstance(v, CopyIfCallgrind) else v)
+            for k, v in globals.items()
+        }
+
+
+class GlobalsBridge:
+    """Handle the transfer of (certain) globals when collecting Callgrind statistics.
+
+    Key takeaway: Any globals passed must be wrapped in `CopyIfCallgrind` to
+                  work with `Timer.collect_callgrind`.
+
+    Consider the following code snippet:
+    ```
+        import pickle
+        import timeit
+
+        class Counter:
+            value = 0
+
+            def __call__(self):
+                self.value += 1
+
+        counter = Counter()
+        timeit.Timer("counter()", globals={"counter": counter}).timeit(10)
+        print(counter.value)  # 10
+
+        timeit.Timer(
+            "counter()",
+            globals={"counter": pickle.loads(pickle.dumps(counter))}
+        ).timeit(20)
+        print(counter.value)  # Still 10
+    ```
+
+    In the first case, `stmt` is executed using the objects in `globals`;
+    however, the addition of serialization and deserialization changes the
+    semantics and may meaningfully change behavior.
+
+    This is a practical consideration when collecting Callgrind statistics.
+    Unlike `exec` based execution (which `timeit` uses under the hood), which
+    can share in-memory data structures with the caller, Callgrind collection
+    requires an entirely new process in order to run under Valgrind. This means
+    that any data structures used for statement execution will have to be
+    serialized and deserialized in the subprocess.
+
+    In order to avoid surprising semantics from (user invisible) process
+    boundaries, what can be passed through `globals` is severely restricted
+    for `Timer.collect_callgrind`. It is expected that most setup should be
+    achievable (albeit perhaps less ergonomically) by passing a `setup`
+    string.
+
+    There are, however, exceptions. One such class is TorchScripted functions.
+    Because they require a concrete file with source code, it is not possible
+    to define them using a `setup` string. Another group is torch.nn.Modules,
+    whose construction can be complex and prohibitively cumbersome to coerce
+    into a `setup` string. Finally, most builtin types are sufficiently well
+    behaved and sufficiently common to warrant allowing as well. (e.g.
+    `globals={"n": 1}` is very convenient.)
+
+    Fortunately, all have well defined serialization semantics. This class
+    is responsible for enabling the Valgrind subprocess to use elements in
+    `globals` so long as they are an allowed type.
+
+    Caveats:
+        The user is required to acknowledge this serialization by wrapping
+        elements in `globals` with `CopyIfCallgrind`.
+
+        While ScriptFunction and ScriptModule are expected to save and load
+        quite robustly, it is up to the user to ensure that an nn.Module can
+        un-pickle successfully.
+
+        `torch.Tensor` and `np.ndarray` are deliberately excluded. The
+        serialization/deserialization process perturbs the representation of a
+        tensor in ways that could result in incorrect measurements. For example,
+        if a tensor lives in pinned CPU memory, this fact would not be preserved
+        by a dump, and that will in turn change the performance of certain CUDA
+        operations.
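+
+    A hedged sketch of wrapping an nn.Module (the `setup` string here is an
+    assumption about what the deserialized copy might need)::
+
+        wrapped = CopyIfCallgrind(
+            torch.nn.Linear(4, 4),
+            setup="import torch\n")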
+    """
+
+    def __init__(self, globals: dict[str, Any], data_dir: str) -> None:
+        self._globals: dict[str, CopyIfCallgrind] = {}
+        self._data_dir = data_dir
+        if not os.path.exists(data_dir):
+            os.mkdir(data_dir)
+
+        if globals.get("torch", torch) is not torch:
+            raise ValueError("`collect_callgrind` does not support mocking out `torch`.")
+
+        for name, value in globals.items():
+            if name in ("torch", "__builtins__"):
+                # Torch will be imported by the collection script, and
+                # __builtins__ is added by Timer.
+                continue
+
+            if not isinstance(value, CopyIfCallgrind):
+                raise ValueError(
+                    "`collect_callgrind` requires that globals be wrapped in "
+                    "`CopyIfCallgrind` so that serialization is explicit."
+                )
+
+            self._globals[name] = value
+
+    def construct(self) -> str:
+        load_lines = []
+        for name, wrapped_value in self._globals.items():
+            if wrapped_value.setup is not None:
+                load_lines.append(textwrap.dedent(wrapped_value.setup))
+
+            if wrapped_value.serialization == Serialization.PICKLE:
+                path = os.path.join(self._data_dir, f"{name}.pkl")
+                load_lines.append(
+                    f"with open({repr(path)}, 'rb') as f:\n    {name} = pickle.load(f)")
+                with open(path, "wb") as f:
+                    pickle.dump(wrapped_value.value, f)
+
+            elif wrapped_value.serialization == Serialization.TORCH:
+                path = os.path.join(self._data_dir, f"{name}.pt")
+                # TODO: Figure out if we can use torch.serialization.add_safe_globals here
+                # Using weights_only=False after the change in
+                # https://dev-discuss.pytorch.org/t/bc-breaking-change-torch-load-is-being-flipped-to-use-weights-only-true-by-default-in-the-nightlies-after-137602/2573
+                load_lines.append(f"{name} = torch.load({repr(path)}, weights_only=False)")
+                torch.save(wrapped_value.value, path)
+
+            elif wrapped_value.serialization == Serialization.TORCH_JIT:
+                path = os.path.join(self._data_dir, f"{name}.pt")
+                load_lines.append(f"{name} = torch.jit.load({repr(path)})")
+                with open(path, "wb") as f:
+                    torch.jit.save(wrapped_value.value, f)  # type: ignore[no-untyped-call]
+
+            else:
+                raise NotImplementedError(
+                    f"Unknown serialization method: {wrapped_value.serialization}")
+
+        return "\n".join(load_lines)
+
+
+class _ValgrindWrapper:
+    def __init__(self) -> None:
+        self._bindings_module: Optional[CallgrindModuleType] = None
+        valgrind_symbols = (
+            "_valgrind_supported_platform",
+            "_valgrind_toggle",
+            "_valgrind_toggle_and_dump_stats",
+        )
+        if all(hasattr(torch._C, symbol) for symbol in valgrind_symbols):
+            self._supported_platform: bool = torch._C._valgrind_supported_platform()
+
+        else:
+            print("Callgrind bindings are not present in `torch._C`. JIT-ing bindings.")
+            self._bindings_module = cpp_jit.get_compat_bindings()
+            assert all(hasattr(self._bindings_module, symbol) for symbol in valgrind_symbols)
+            self._supported_platform = self._bindings_module._valgrind_supported_platform()
+
+        self._commands_available: dict[str, bool] = {}
+        if self._supported_platform:
+            # Only bother checking on supported platforms.
+            for cmd in ("valgrind", "callgrind_control", "callgrind_annotate"):
+                self._commands_available[cmd] = not subprocess.run(
+                    ["which", cmd],
+                    capture_output=True,
+                    check=False,
+                ).returncode
+
+        self._build_type: Optional[str] = None
+        build_search = re.search("BUILD_TYPE=(.+),", torch.__config__.show())  # type: ignore[no-untyped-call]
+        if build_search is not None:
+            self._build_type = build_search.groups()[0].split(",")[0]
+
+    def _validate(self) -> None:
+        if not self._supported_platform:
+            raise OSError("Valgrind is not supported on this platform.")
+
+        missing_cmds = [cmd for cmd, available in self._commands_available.items() if not available]
+        if missing_cmds:
+            raise OSError("Missing: " + ", ".join(missing_cmds))
+
+    def collect_callgrind(
+        self,
+        task_spec: common.TaskSpec,
+        globals: dict[str, Any],
+        *,
+        number: int,
+        repeats: int,
+        collect_baseline: bool,
+        is_python: bool,
+        retain_out_file: bool,
+    ) -> tuple[CallgrindStats, ...]:
+        """Collect stats, and attach a reference run which can be used to filter interpreter overhead."""
+        self._validate()
+        assert is_python or not collect_baseline
+
+        *task_stats, baseline_stats = self._invoke(
+            task_spec=task_spec,
+            globals=globals,
+            number=number,
+            repeats=repeats,
+            collect_baseline=collect_baseline,
+            is_python=is_python,
+            retain_out_file=retain_out_file,
+        )
+        assert len(task_stats) == repeats
+
+        return tuple(
+            CallgrindStats(
+                task_spec=task_spec,
+                number_per_run=number,
+                built_with_debug_symbols=self._build_type == "RelWithDebInfo",
+                baseline_inclusive_stats=baseline_stats[0],
+                baseline_exclusive_stats=baseline_stats[1],
+                stmt_inclusive_stats=stmt_inclusive_stats,
+                stmt_exclusive_stats=stmt_exclusive_stats,
+                stmt_callgrind_out=out_contents,
+            )
+            for stmt_inclusive_stats, stmt_exclusive_stats, out_contents in task_stats
+        )
+
+    def _invoke(
+        self,
+        *,
+        task_spec: common.TaskSpec,
+        globals: dict[str, Any],
+        number: int,
+        repeats: int,
+        collect_baseline: bool,
+        is_python: bool,
+        retain_out_file: bool,
+    ) -> tuple[tuple[FunctionCounts, FunctionCounts, Optional[str]], ...]:
+        """Core invocation method for Callgrind collection.
+
+        Valgrind operates by effectively replacing the CPU with an emulated
+        version which allows it to instrument any code at the cost of severe
+        performance degradation. This has the practical effect that in order
+        to collect Callgrind statistics, a new process has to be created
+        running under `valgrind`. The steps for this process are:
+
+        1) Create a scratch directory.
+        2) Codegen a run script. (_ValgrindWrapper._construct_script)
+            Inside the run script:
+                * Validate that Python and torch match the parent process
+                * Validate that it is indeed running under valgrind
+                * Execute `setup` and warm up `stmt`
+                * Begin collecting stats
+                * Run the `stmt` loop
+                * Stop collecting stats
+        3) Parse the run results.
+        4) Cleanup the scratch directory.
+        """
+        working_dir = common._make_temp_dir(prefix="callgrind")
+        data_dir = os.path.join(working_dir, "data")
+        script_file = os.path.join(working_dir, "timer_callgrind.py")
+        callgrind_out = os.path.join(working_dir, "callgrind.out")
+        error_log = os.path.join(working_dir, "error.txt")
+        stat_log = os.path.join(working_dir, "callgrind_stat.txt")
+        stdout_stderr_log = os.path.join(working_dir, "stdout_stderr.log")
+
+        def run(args: list[str], **kwargs: Any) -> tuple[CompletedProcessType, str]:
+            # https://thraxil.org/users/anders/posts/2008/03/13/Subprocess-Hanging-PIPE-is-your-enemy/
+            f_stdout_stderr = open(stdout_stderr_log, "wb")
+            try:
+                invocation = subprocess.run(
+                    args,
+                    stdout=f_stdout_stderr,
+                    stderr=subprocess.STDOUT,
+                    **kwargs,
+                )
+                with open(stdout_stderr_log) as f:
+                    return invocation, f.read()
+            finally:
+                f_stdout_stderr.close()
+
+        try:
+            if is_python:
+                if self._bindings_module is not None:
+                    shutil.copy(
+                        self._bindings_module.__file__,
+                        os.path.join(working_dir, os.path.split(self._bindings_module.__file__)[1])
+                    )
+
+                script_file = os.path.join(working_dir, "timer_callgrind.py")
+                with open(script_file, "w") as f:
+                    f.write(self._construct_script(
+                        task_spec,
+                        globals=GlobalsBridge(globals, data_dir),
+                        number=number,
+                        repeats=repeats,
+                        collect_baseline=collect_baseline,
+                        error_log=error_log,
+                        stat_log=stat_log,
+                        bindings=self._bindings_module))
+
+                run_loop_cmd = ["python", script_file]
+            else:
+                assert not collect_baseline
+                run_loop_exec = cpp_jit.compile_callgrind_template(
+                    stmt=task_spec.stmt,
+                    setup=task_spec.setup,
+                    global_setup=task_spec.global_setup,
+                )
+                run_loop_cmd = [
+                    run_loop_exec,
+                    "--number", str(number),
+                    "--number-warmup", str(min(number, 10)),
+                    "--repeats", str(repeats),
+                    "--number-threads", str(task_spec.num_threads),
+                ]
+
+            valgrind_invocation, valgrind_invocation_output = run([
+                "valgrind",
+                "--tool=callgrind",
+                f"--callgrind-out-file={callgrind_out}",
+                "--dump-line=yes",
+                "--dump-instr=yes",
+                "--instr-atstart=yes",
+                "--collect-atstart=no",
+            ] + run_loop_cmd)
+
+            if valgrind_invocation.returncode:
+                error_report = ""
+                if os.path.exists(error_log):
+                    with open(error_log) as f:
+                        error_report = f.read()
+                if not error_report:
+                    error_report = "Unknown error.\n" + valgrind_invocation_output
+
+                raise OSError(f"Failed to collect callgrind profile:\n{error_report}")
+
+            def parse_output(fpath: str, inclusive: bool) -> FunctionCounts:
+                _annotate_invocation, annotate_invocation_output = run([
+                    "callgrind_annotate",
+                    f"--inclusive={'yes' if inclusive else 'no'}",
+                    "--threshold=100",
+                    "--show-percs=no",
+                    fpath
+                ], check=True)
+
+                total_pattern = re.compile(r"^([0-9,]+)\s+PROGRAM TOTALS")
+                begin_pattern = re.compile(r"Ir\s+file:function")
+                function_pattern = re.compile(r"^\s*([0-9,]+)\s+(.+:.+)$")
+
+                class ScanState(enum.Enum):
+                    SCANNING_FOR_TOTAL = 0
+                    SCANNING_FOR_START = 1
+                    PARSING = 2
+
+                scan_state = ScanState.SCANNING_FOR_TOTAL
+                fn_counts = []
+                for l in annotate_invocation_output.splitlines(keepends=False):
+                    if scan_state == ScanState.SCANNING_FOR_TOTAL:
+                        total_match = total_pattern.match(l)
+                        if total_match:
+                            program_totals = int(total_match.groups()[0].replace(",", ""))
+                            scan_state = ScanState.SCANNING_FOR_START
+
+                    elif scan_state == ScanState.SCANNING_FOR_START:
+                        if begin_pattern.match(l):
+                            scan_state = ScanState.PARSING
+
+                    else:
+                        assert scan_state == ScanState.PARSING
+                        fn_match = function_pattern.match(l)
+                        if fn_match:
+                            ir_str, file_function = fn_match.groups()
+                            ir = int(ir_str.replace(",", ""))
+                            if ir == program_totals:  # type: ignore[possibly-undefined]
+                                # Callgrind includes some top level red herring symbols when
+                                # a program dumps multiple profiles.
+                                continue
+                            fn_counts.append(FunctionCount(ir, file_function))
+
+                        elif re.match(r"-+", l):
+                            # Ignore heading separator lines.
+                            continue
+
+                        else:
+                            break
+
+                assert scan_state == ScanState.PARSING, f"Failed to parse {fpath}"
+                return FunctionCounts(tuple(sorted(fn_counts, reverse=True)), inclusive=inclusive)
+
+            def read_results(i: int) -> tuple[FunctionCounts, FunctionCounts, Optional[str]]:
+                if i == repeats and not collect_baseline:
+                    # Null baseline.
+                    return (
+                        FunctionCounts((), inclusive=True),
+                        FunctionCounts((), inclusive=False),
+                        None,
+                    )
+
+                fpath = f"{callgrind_out}.{i + 1}"  # Callgrind one-indexes files.
+                callgrind_out_contents: Optional[str] = None
+                if retain_out_file:
+                    with open(fpath) as f:
+                        callgrind_out_contents = f.read()
+
+                return (
+                    parse_output(fpath, inclusive=True),
+                    parse_output(fpath, inclusive=False),
+                    callgrind_out_contents
+                )
+
+            return tuple(read_results(i) for i in range(repeats + 1))
+        finally:
+            shutil.rmtree(working_dir)
+
+    @staticmethod
+    def _construct_script(
+        task_spec: common.TaskSpec,
+        globals: GlobalsBridge,
+        *,
+        number: int,
+        repeats: int,
+        collect_baseline: bool,
+        error_log: str,
+        stat_log: str,
+        bindings: Optional[CallgrindModuleType],
+    ) -> str:
+        def block_stmt(stmt: str, indent: int = 0) -> str:
+            """Partially unroll benchmark loop.
+
+            The naive template looks something like:
+                "for _ in range({number}): {stmt}"
+
+            However a loop in Python is surprisingly expensive, and significantly
+            increases the number of background Python instructions. So instead we
+            partially unroll the loops, with a block size of 100 chosen to keep
+            the instruction overhead from `range` low while also not ballooning
+            the size of the generated file.
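+
+            For example (a sketch of the intent), `number=230` would emit
+            roughly::
+
+                for _ in range(2):
+                    <100 copies of stmt>
+                <30 copies of stmt>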
+            """
+            block_size = 100
+            loop_count = number // block_size
+            if loop_count == 1:
+                # There is no point in having `for _ in range(1): ...` rather
+                # than just `...`, and this lets us shave off a few background
+                # instructions.
+                loop_count = 0
+            remainder = number - block_size * loop_count
+            blocked_stmt = ""
+
+            if loop_count:
+                unrolled_stmts = textwrap.indent("\n".join([stmt] * block_size), " " * 4)
+                blocked_stmt += f"for _ in range({loop_count}):\n{unrolled_stmts}\n"
+
+            if remainder:
+                blocked_stmt += "\n".join([stmt] * remainder)
+
+            return textwrap.indent(blocked_stmt, " " * indent)
+
+        pass_baseline = (
+            "callgrind_bindings._valgrind_toggle()\n"
+            f"{block_stmt('pass')}\n"
+            "callgrind_bindings._valgrind_toggle_and_dump_stats()"
+        )
+
+        return textwrap.dedent(r"""
+            import gc
+            import os
+            import pickle
+            import subprocess
+            import sys
+            import time
+
+            # Mitigate https://github.com/pytorch/pytorch/issues/37377
+            # which can sometimes cause the subprocess call to fail.
+            import numpy as np
+
+            import torch
+            torch.set_num_threads({num_threads})
+
+            {bindings_import}
+
+            PID = os.getpid()
+
+            def log_failure(msg):
+                with open({error_log_repr}, "wt") as f:
+                    f.write(msg)
+                sys.exit(1)
+
+            def check_result(completed_process):
+                if completed_process.returncode:
+                    log_failure(f"Command failed: {{' '.join(completed_process.args)}}")
+                return completed_process
+
+            # =============================================================================
+            # == Check that subprocess matches parent =====================================
+            # =============================================================================
+            if os.path.realpath(sys.executable) != "{parent_interpreter}":
+                log_failure(
+                    "Interpreter mismatch:\n"
+                    f"  {{os.path.realpath(sys.executable)}}\n    vs.\n  {parent_interpreter}"
+                )
+
+            if torch.__file__ != "{torch_file}":
+                log_failure(
+                    "PyTorch does not match expected file:\n"
+                    f"  {{torch.__file__}}\n    vs.\n  {torch_file}"
+                )
+
+            # =============================================================================
+            # == User specified setup =====================================================
+            # =============================================================================
+            # Load serialized globals
+            {load_globals}
+
+            # User setup str
+            {setup}
+
+            for _ in range({warmup_number}):
+            {indented_stmt}
+
+            # =============================================================================
+            # == Callgrind management =====================================================
+            # =============================================================================
+            with open("{stat_log}", "wb") as stat_file:
+                # If many instances of callgrind are running at once, the output of
+                # `callgrind_control` may exceed 16kb which would cause `subprocess.PIPE`
+                # to deadlock. So instead we use a file.
+                callgrind_stat = check_result(subprocess.run(
+                    ["callgrind_control", "--stat"],
+                    stdout=stat_file,
+                    stderr=subprocess.STDOUT,
+                ))
+
+            with open("{stat_log}", "rt") as stat_file:
+                stat_lines = stat_file.read().splitlines()
+
+            if f"PID {{PID}}: python {{__file__}}" not in stat_lines:
+                log_failure("Process does not appear to be running callgrind.")
+
+            gc.collect()
+            time.sleep(0.01)
+
+            # =============================================================================
+            # == User code block ==========================================================
+            # =============================================================================
+            for _ in range({repeats}):
+                callgrind_bindings._valgrind_toggle()
+            {blocked_stmt}
+                callgrind_bindings._valgrind_toggle_and_dump_stats()
+                gc.collect()
+
+            {baseline}
+        """).strip().format(
+            indented_stmt=textwrap.indent(task_spec.stmt, " " * 4),
+            blocked_stmt=block_stmt(task_spec.stmt, indent=4),
+            baseline=(pass_baseline if collect_baseline else ""),
+            number=number,
+            repeats=repeats,
+            load_globals=globals.construct(),
+            setup=task_spec.setup,
+            warmup_number=min(number, 10),
+            num_threads=task_spec.num_threads,
+            error_log_repr=repr(error_log),
+            stat_log=stat_log,
+            parent_interpreter=os.path.realpath(sys.executable),
+            torch_file=torch.__file__,
+            bindings_import=(
+                "import torch._C as callgrind_bindings" if bindings is None
+                else f"import {bindings.__name__} as callgrind_bindings"),
+        )
+
+
+CALLGRIND_SINGLETON: Optional[_ValgrindWrapper] = None
+def wrapper_singleton() -> _ValgrindWrapper:
+    global CALLGRIND_SINGLETON
+    if CALLGRIND_SINGLETON is None:
+        CALLGRIND_SINGLETON = _ValgrindWrapper()
+    return CALLGRIND_SINGLETON
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h
new file mode 100644
index 0000000000000000000000000000000000000000..d33dd30932aa86b8284cb93d0e29ec646e820197
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h
@@ -0,0 +1,7157 @@
+/* -*- c -*-
+   ----------------------------------------------------------------
+
+   Notice that the following BSD-style license applies to this one
+   file (valgrind.h) only.  The rest of Valgrind is licensed under the
+   terms of the GNU General Public License, version 2, unless
+   otherwise indicated.  See the COPYING file in the source
+   distribution for details.
+
+   ----------------------------------------------------------------
+
+   This file is part of Valgrind, a dynamic binary instrumentation
+   framework.
+
+   Copyright (C) 2000-2017 Julian Seward.  All rights reserved.
+
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   1. Redistributions of source code must retain the above copyright
+      notice, this list of conditions and the following disclaimer.
+
+   2. The origin of this software must not be misrepresented; you must 
+      not claim that you wrote the original software.  If you use this 
+      software in a product, an acknowledgment in the product 
+      documentation would be appreciated but is not required.
+
+   3. Altered source versions must be plainly marked as such, and must
+      not be misrepresented as being the original software.
+
+   4. The name of the author may not be used to endorse or promote 
+      products derived from this software without specific prior written 
+      permission.
+
+   THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS
+   OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+   WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+   ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
+   DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+   DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
+   GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+   INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
+   WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+   ----------------------------------------------------------------
+
+   Notice that the above BSD-style license applies to this one file
+   (valgrind.h) only.  The entire rest of Valgrind is licensed under
+   the terms of the GNU General Public License, version 2.  See the
+   COPYING file in the source distribution for details.
+
+   ---------------------------------------------------------------- 
+*/
+
+
+/* This file is for inclusion into client (your!) code.
+
+   You can use these macros to manipulate and query Valgrind's 
+   execution inside your own programs.
+
+   The resulting executables will still run without Valgrind, just a
+   little bit more slowly than they otherwise would, but otherwise
+   unchanged.  When not running on valgrind, each client request
+   consumes very few (eg. 7) instructions, so the resulting performance
+   loss is negligible unless you plan to execute client requests
+   millions of times per second.  Nevertheless, if that is still a
+   problem, you can compile with the NVALGRIND symbol defined (gcc
+   -DNVALGRIND) so that client requests are not even compiled in.  */
+
+#ifndef __VALGRIND_H
+#define __VALGRIND_H
+
+
+/* ------------------------------------------------------------------ */
+/* VERSION NUMBER OF VALGRIND                                         */
+/* ------------------------------------------------------------------ */
+
+/* Specify Valgrind's version number, so that user code can
+   conditionally compile based on our version number.  Note that these
+   were introduced at version 3.6 and so do not exist in version 3.5
+   or earlier.  The recommended way to use them to check for "version
+   X.Y or later" is (eg)
+
+#if defined(__VALGRIND_MAJOR__) && defined(__VALGRIND_MINOR__)   \
+    && (__VALGRIND_MAJOR__ > 3                                   \
+        || (__VALGRIND_MAJOR__ == 3 && __VALGRIND_MINOR__ >= 6))
+*/
+#define __VALGRIND_MAJOR__    3
+#define __VALGRIND_MINOR__    17
+
+
+#include <stdarg.h>
+
+/* Nb: this file might be included in a file compiled with -ansi.  So
+   we can't use C++ style "//" comments nor the "asm" keyword (instead
+   use "__asm__"). */
+
+/* Derive some tags indicating what the target platform is.  Note
+   that in this file we're using the compiler's CPP symbols for
+   identifying architectures, which are different to the ones we use
+   within the rest of Valgrind.  Note, __powerpc__ is active for both
+   32 and 64-bit PPC, whereas __powerpc64__ is only active for the
+   latter (on Linux, that is).
+
+   Misc note: how to find out what's predefined in gcc by default:
+   gcc -Wp,-dM somefile.c
+*/
+#undef PLAT_x86_darwin
+#undef PLAT_amd64_darwin
+#undef PLAT_x86_win32
+#undef PLAT_amd64_win64
+#undef PLAT_x86_linux
+#undef PLAT_amd64_linux
+#undef PLAT_ppc32_linux
+#undef PLAT_ppc64be_linux
+#undef PLAT_ppc64le_linux
+#undef PLAT_arm_linux
+#undef PLAT_arm64_linux
+#undef PLAT_s390x_linux
+#undef PLAT_mips32_linux
+#undef PLAT_mips64_linux
+#undef PLAT_nanomips_linux
+#undef PLAT_x86_solaris
+#undef PLAT_amd64_solaris
+
+
+#if defined(__APPLE__) && defined(__i386__)
+#  define PLAT_x86_darwin 1
+#elif defined(__APPLE__) && defined(__x86_64__)
+#  define PLAT_amd64_darwin 1
+#elif (defined(__MINGW32__) && defined(__i386__)) \
+      || defined(__CYGWIN32__) \
+      || (defined(_WIN32) && defined(_M_IX86))
+#  define PLAT_x86_win32 1
+#elif (defined(__MINGW32__) && defined(__x86_64__)) \
+      || (defined(_WIN32) && defined(_M_X64))
+/* __MINGW32__ and _WIN32 are defined in 64 bit mode as well. */
+#  define PLAT_amd64_win64 1
+#elif defined(__linux__) && defined(__i386__)
+#  define PLAT_x86_linux 1
+#elif defined(__linux__) && defined(__x86_64__) && !defined(__ILP32__)
+#  define PLAT_amd64_linux 1
+#elif defined(__linux__) && defined(__powerpc__) && !defined(__powerpc64__)
+#  define PLAT_ppc32_linux 1
+#elif defined(__linux__) && defined(__powerpc__) && defined(__powerpc64__) && _CALL_ELF != 2
+/* Big Endian uses ELF version 1 */
+#  define PLAT_ppc64be_linux 1
+#elif defined(__linux__) && defined(__powerpc__) && defined(__powerpc64__) && _CALL_ELF == 2
+/* Little Endian uses ELF version 2 */
+#  define PLAT_ppc64le_linux 1
+#elif defined(__linux__) && defined(__arm__) && !defined(__aarch64__)
+#  define PLAT_arm_linux 1
+#elif defined(__linux__) && defined(__aarch64__) && !defined(__arm__)
+#  define PLAT_arm64_linux 1
+#elif defined(__linux__) && defined(__s390__) && defined(__s390x__)
+#  define PLAT_s390x_linux 1
+#elif defined(__linux__) && defined(__mips__) && (__mips==64)
+#  define PLAT_mips64_linux 1
+#elif defined(__linux__) && defined(__mips__) && (__mips==32)
+#  define PLAT_mips32_linux 1
+#elif defined(__linux__) && defined(__nanomips__)
+#  define PLAT_nanomips_linux 1
+#elif defined(__sun) && defined(__i386__)
+#  define PLAT_x86_solaris 1
+#elif defined(__sun) && defined(__x86_64__)
+#  define PLAT_amd64_solaris 1
+#else
+/* If we're not compiling for our target platform, don't generate
+   any inline asms.  */
+#  if !defined(NVALGRIND)
+#    define NVALGRIND 1
+#  endif
+#endif
+
+
+/* ------------------------------------------------------------------ */
+/* ARCHITECTURE SPECIFICS for SPECIAL INSTRUCTIONS.  There is nothing */
+/* in here of use to end-users -- skip to the next section.           */
+/* ------------------------------------------------------------------ */
+
+/*
+ * VALGRIND_DO_CLIENT_REQUEST(): a statement that invokes a Valgrind client
+ * request. Accepts both pointers and integers as arguments.
+ *
+ * VALGRIND_DO_CLIENT_REQUEST_STMT(): a statement that invokes a Valgrind
+ * client request that does not return a value.
+
+ * VALGRIND_DO_CLIENT_REQUEST_EXPR(): a C expression that invokes a Valgrind
+ * client request and whose value equals the client request result.  Accepts
+ * both pointers and integers as arguments.  Note that such calls are not
+ * necessarily pure functions -- they may have side effects.
+ */
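+
+/* As a hedged illustration (not part of the original header comment): the
+   expression form can be used to ask whether the code is currently running
+   under Valgrind, falling back to a default of 0 on a real CPU.  The request
+   code VG_USERREQ__RUNNING_ON_VALGRIND is defined later in this header.
+
+     int under_valgrind = (int) VALGRIND_DO_CLIENT_REQUEST_EXPR(
+         0, VG_USERREQ__RUNNING_ON_VALGRIND, 0, 0, 0, 0, 0);
+*/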
+
+#define VALGRIND_DO_CLIENT_REQUEST(_zzq_rlval, _zzq_default,            \
+                                   _zzq_request, _zzq_arg1, _zzq_arg2,  \
+                                   _zzq_arg3, _zzq_arg4, _zzq_arg5)     \
+  do { (_zzq_rlval) = VALGRIND_DO_CLIENT_REQUEST_EXPR((_zzq_default),   \
+                        (_zzq_request), (_zzq_arg1), (_zzq_arg2),       \
+                        (_zzq_arg3), (_zzq_arg4), (_zzq_arg5)); } while (0)
+
+#define VALGRIND_DO_CLIENT_REQUEST_STMT(_zzq_request, _zzq_arg1,        \
+                           _zzq_arg2,  _zzq_arg3, _zzq_arg4, _zzq_arg5) \
+  do { (void) VALGRIND_DO_CLIENT_REQUEST_EXPR(0,                        \
+                    (_zzq_request), (_zzq_arg1), (_zzq_arg2),           \
+                    (_zzq_arg3), (_zzq_arg4), (_zzq_arg5)); } while (0)
+
+#if defined(NVALGRIND)
+
+/* Define NVALGRIND to completely remove the Valgrind magic sequence
+   from the compiled code (analogous to NDEBUG's effects on
+   assert()) */
+#define VALGRIND_DO_CLIENT_REQUEST_EXPR(                          \
+        _zzq_default, _zzq_request,                               \
+        _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5)    \
+      (_zzq_default)
+
+#else  /* ! NVALGRIND */
+
+/* The following defines the magic code sequences which the JITter
+   spots and handles magically.  Don't look too closely at them as
+   they will rot your brain.
+
+   The assembly code sequences for all architectures is in this one
+   file.  This is because this file must be stand-alone, and we don't
+   want to have multiple files.
+
+   For VALGRIND_DO_CLIENT_REQUEST, we must ensure that the default
+   value gets put in the return slot, so that everything works when
+   this is executed not under Valgrind.  Args are passed in a memory
+   block, and so there's no intrinsic limit to the number that could
+   be passed, but it's currently five.
+   
+   The macro args are: 
+      _zzq_rlval    result lvalue
+      _zzq_default  default value (result returned when running on real CPU)
+      _zzq_request  request code
+      _zzq_arg1..5  request params
+
+   The other two macros are used to support function wrapping, and are
+   a lot simpler.  VALGRIND_GET_NR_CONTEXT returns the value of the
+   guest's NRADDR pseudo-register and whatever other information is
+   needed to safely run the call original from the wrapper: on
+   ppc64-linux, the R2 value at the divert point is also needed.  This
+   information is abstracted into a user-visible type, OrigFn.
+
+   VALGRIND_CALL_NOREDIR_* behaves the same as the following on the
+   guest, but guarantees that the branch instruction will not be
+   redirected: x86: call *%eax, amd64: call *%rax, ppc32/ppc64:
+   branch-and-link-to-r11.  VALGRIND_CALL_NOREDIR is just text, not a
+   complete inline asm, since it needs to be combined with more magic
+   inline asm stuff to be useful.
+*/
+
+/* ----------------- x86-{linux,darwin,solaris} ---------------- */
+
+#if defined(PLAT_x86_linux)  ||  defined(PLAT_x86_darwin)  \
+    ||  (defined(PLAT_x86_win32) && defined(__GNUC__)) \
+    ||  defined(PLAT_x86_solaris)
+
+typedef
+   struct { 
+      unsigned int nraddr; /* where's the code? */
+   }
+   OrigFn;
+
+#define __SPECIAL_INSTRUCTION_PREAMBLE                            \
+                     "roll $3,  %%edi ; roll $13, %%edi\n\t"      \
+                     "roll $29, %%edi ; roll $19, %%edi\n\t"
+
+#define VALGRIND_DO_CLIENT_REQUEST_EXPR(                          \
+        _zzq_default, _zzq_request,                               \
+        _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5)    \
+  __extension__                                                   \
+  ({volatile unsigned int _zzq_args[6];                           \
+    volatile unsigned int _zzq_result;                            \
+    _zzq_args[0] = (unsigned int)(_zzq_request);                  \
+    _zzq_args[1] = (unsigned int)(_zzq_arg1);                     \
+    _zzq_args[2] = (unsigned int)(_zzq_arg2);                     \
+    _zzq_args[3] = (unsigned int)(_zzq_arg3);                     \
+    _zzq_args[4] = (unsigned int)(_zzq_arg4);                     \
+    _zzq_args[5] = (unsigned int)(_zzq_arg5);                     \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* %EDX = client_request ( %EAX ) */         \
+                     "xchgl %%ebx,%%ebx"                          \
+                     : "=d" (_zzq_result)                         \
+                     : "a" (&_zzq_args[0]), "0" (_zzq_default)    \
+                     : "cc", "memory"                             \
+                    );                                            \
+    _zzq_result;                                                  \
+  })
+
+#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval)                       \
+  { volatile OrigFn* _zzq_orig = &(_zzq_rlval);                   \
+    volatile unsigned int __addr;                                 \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* %EAX = guest_NRADDR */                    \
+                     "xchgl %%ecx,%%ecx"                          \
+                     : "=a" (__addr)                              \
+                     :                                            \
+                     : "cc", "memory"                             \
+                    );                                            \
+    _zzq_orig->nraddr = __addr;                                   \
+  }
+
+#define VALGRIND_CALL_NOREDIR_EAX                                 \
+                     __SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* call-noredir *%EAX */                     \
+                     "xchgl %%edx,%%edx\n\t"
+
+#define VALGRIND_VEX_INJECT_IR()                                 \
+ do {                                                            \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE              \
+                     "xchgl %%edi,%%edi\n\t"                     \
+                     : : : "cc", "memory"                        \
+                    );                                           \
+ } while (0)
+
+#endif /* PLAT_x86_linux || PLAT_x86_darwin || (PLAT_x86_win32 && __GNUC__)
+          || PLAT_x86_solaris */
+
+/* ------------------------- x86-Win32 ------------------------- */
+
+#if defined(PLAT_x86_win32) && !defined(__GNUC__)
+
+typedef
+   struct { 
+      unsigned int nraddr; /* where's the code? */
+   }
+   OrigFn;
+
+#if defined(_MSC_VER)
+
+#define __SPECIAL_INSTRUCTION_PREAMBLE                            \
+                     __asm rol edi, 3  __asm rol edi, 13          \
+                     __asm rol edi, 29 __asm rol edi, 19
+
+#define VALGRIND_DO_CLIENT_REQUEST_EXPR(                          \
+        _zzq_default, _zzq_request,                               \
+        _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5)    \
+    valgrind_do_client_request_expr((uintptr_t)(_zzq_default),    \
+        (uintptr_t)(_zzq_request), (uintptr_t)(_zzq_arg1),        \
+        (uintptr_t)(_zzq_arg2), (uintptr_t)(_zzq_arg3),           \
+        (uintptr_t)(_zzq_arg4), (uintptr_t)(_zzq_arg5))
+
+static __inline uintptr_t
+valgrind_do_client_request_expr(uintptr_t _zzq_default, uintptr_t _zzq_request,
+                                uintptr_t _zzq_arg1, uintptr_t _zzq_arg2,
+                                uintptr_t _zzq_arg3, uintptr_t _zzq_arg4,
+                                uintptr_t _zzq_arg5)
+{
+    volatile uintptr_t _zzq_args[6];
+    volatile unsigned int _zzq_result;
+    _zzq_args[0] = (uintptr_t)(_zzq_request);
+    _zzq_args[1] = (uintptr_t)(_zzq_arg1);
+    _zzq_args[2] = (uintptr_t)(_zzq_arg2);
+    _zzq_args[3] = (uintptr_t)(_zzq_arg3);
+    _zzq_args[4] = (uintptr_t)(_zzq_arg4);
+    _zzq_args[5] = (uintptr_t)(_zzq_arg5);
+    __asm { __asm lea eax, _zzq_args __asm mov edx, _zzq_default
+            __SPECIAL_INSTRUCTION_PREAMBLE
+            /* %EDX = client_request ( %EAX ) */
+            __asm xchg ebx,ebx
+            __asm mov _zzq_result, edx
+    }
+    return _zzq_result;
+}
+
+#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval)                       \
+  { volatile OrigFn* _zzq_orig = &(_zzq_rlval);                   \
+    volatile unsigned int __addr;                                 \
+    __asm { __SPECIAL_INSTRUCTION_PREAMBLE                        \
+            /* %EAX = guest_NRADDR */                             \
+            __asm xchg ecx,ecx                                    \
+            __asm mov __addr, eax                                 \
+    }                                                             \
+    _zzq_orig->nraddr = __addr;                                   \
+  }
+
+#define VALGRIND_CALL_NOREDIR_EAX ERROR
+
+#define VALGRIND_VEX_INJECT_IR()                                 \
+ do {                                                            \
+    __asm { __SPECIAL_INSTRUCTION_PREAMBLE                       \
+            __asm xchg edi,edi                                   \
+    }                                                            \
+ } while (0)
+
+#else
+#error Unsupported compiler.
+#endif
+
+#endif /* PLAT_x86_win32 */
+
+/* ----------------- amd64-{linux,darwin,solaris} --------------- */
+
+#if defined(PLAT_amd64_linux)  ||  defined(PLAT_amd64_darwin) \
+    ||  defined(PLAT_amd64_solaris) \
+    ||  (defined(PLAT_amd64_win64) && defined(__GNUC__))
+
+typedef
+   struct { 
+      unsigned long int nraddr; /* where's the code? */
+   }
+   OrigFn;
+
+#define __SPECIAL_INSTRUCTION_PREAMBLE                            \
+                     "rolq $3,  %%rdi ; rolq $13, %%rdi\n\t"      \
+                     "rolq $61, %%rdi ; rolq $51, %%rdi\n\t"
+
+#define VALGRIND_DO_CLIENT_REQUEST_EXPR(                          \
+        _zzq_default, _zzq_request,                               \
+        _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5)    \
+    __extension__                                                 \
+    ({ volatile unsigned long int _zzq_args[6];                   \
+    volatile unsigned long int _zzq_result;                       \
+    _zzq_args[0] = (unsigned long int)(_zzq_request);             \
+    _zzq_args[1] = (unsigned long int)(_zzq_arg1);                \
+    _zzq_args[2] = (unsigned long int)(_zzq_arg2);                \
+    _zzq_args[3] = (unsigned long int)(_zzq_arg3);                \
+    _zzq_args[4] = (unsigned long int)(_zzq_arg4);                \
+    _zzq_args[5] = (unsigned long int)(_zzq_arg5);                \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* %RDX = client_request ( %RAX ) */         \
+                     "xchgq %%rbx,%%rbx"                          \
+                     : "=d" (_zzq_result)                         \
+                     : "a" (&_zzq_args[0]), "0" (_zzq_default)    \
+                     : "cc", "memory"                             \
+                    );                                            \
+    _zzq_result;                                                  \
+    })
+
+#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval)                       \
+  { volatile OrigFn* _zzq_orig = &(_zzq_rlval);                   \
+    volatile unsigned long int __addr;                            \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* %RAX = guest_NRADDR */                    \
+                     "xchgq %%rcx,%%rcx"                          \
+                     : "=a" (__addr)                              \
+                     :                                            \
+                     : "cc", "memory"                             \
+                    );                                            \
+    _zzq_orig->nraddr = __addr;                                   \
+  }
+
+#define VALGRIND_CALL_NOREDIR_RAX                                 \
+                     __SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* call-noredir *%RAX */                     \
+                     "xchgq %%rdx,%%rdx\n\t"
+
+#define VALGRIND_VEX_INJECT_IR()                                 \
+ do {                                                            \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE              \
+                     "xchgq %%rdi,%%rdi\n\t"                     \
+                     : : : "cc", "memory"                        \
+                    );                                           \
+ } while (0)
+
+#endif /* PLAT_amd64_linux || PLAT_amd64_darwin || PLAT_amd64_solaris */
+
+/* ------------------------- amd64-Win64 ------------------------- */
+
+#if defined(PLAT_amd64_win64) && !defined(__GNUC__)
+
+#error Unsupported compiler.
+
+#endif /* PLAT_amd64_win64 */
+
+/* ------------------------ ppc32-linux ------------------------ */
+
+#if defined(PLAT_ppc32_linux)
+
+typedef
+   struct { 
+      unsigned int nraddr; /* where's the code? */
+   }
+   OrigFn;
+
+#define __SPECIAL_INSTRUCTION_PREAMBLE                            \
+                    "rlwinm 0,0,3,0,31  ; rlwinm 0,0,13,0,31\n\t" \
+                    "rlwinm 0,0,29,0,31 ; rlwinm 0,0,19,0,31\n\t"
+
+#define VALGRIND_DO_CLIENT_REQUEST_EXPR(                          \
+        _zzq_default, _zzq_request,                               \
+        _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5)    \
+                                                                  \
+    __extension__                                                 \
+  ({         unsigned int  _zzq_args[6];                          \
+             unsigned int  _zzq_result;                           \
+             unsigned int* _zzq_ptr;                              \
+    _zzq_args[0] = (unsigned int)(_zzq_request);                  \
+    _zzq_args[1] = (unsigned int)(_zzq_arg1);                     \
+    _zzq_args[2] = (unsigned int)(_zzq_arg2);                     \
+    _zzq_args[3] = (unsigned int)(_zzq_arg3);                     \
+    _zzq_args[4] = (unsigned int)(_zzq_arg4);                     \
+    _zzq_args[5] = (unsigned int)(_zzq_arg5);                     \
+    _zzq_ptr = _zzq_args;                                         \
+    __asm__ volatile("mr 3,%1\n\t" /*default*/                    \
+                     "mr 4,%2\n\t" /*ptr*/                        \
+                     __SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* %R3 = client_request ( %R4 ) */           \
+                     "or 1,1,1\n\t"                               \
+                     "mr %0,3"     /*result*/                     \
+                     : "=b" (_zzq_result)                         \
+                     : "b" (_zzq_default), "b" (_zzq_ptr)         \
+                     : "cc", "memory", "r3", "r4");               \
+    _zzq_result;                                                  \
+    })
+
+#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval)                       \
+  { volatile OrigFn* _zzq_orig = &(_zzq_rlval);                   \
+    unsigned int __addr;                                          \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* %R3 = guest_NRADDR */                     \
+                     "or 2,2,2\n\t"                               \
+                     "mr %0,3"                                    \
+                     : "=b" (__addr)                              \
+                     :                                            \
+                     : "cc", "memory", "r3"                       \
+                    );                                            \
+    _zzq_orig->nraddr = __addr;                                   \
+  }
+
+#define VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                   \
+                     __SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* branch-and-link-to-noredir *%R11 */       \
+                     "or 3,3,3\n\t"
+
+#define VALGRIND_VEX_INJECT_IR()                                 \
+ do {                                                            \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE              \
+                     "or 5,5,5\n\t"                              \
+                    );                                           \
+ } while (0)
+
+#endif /* PLAT_ppc32_linux */
+
+/* ------------------------ ppc64-linux ------------------------ */
+
+#if defined(PLAT_ppc64be_linux)
+
+typedef
+   struct { 
+      unsigned long int nraddr; /* where's the code? */
+      unsigned long int r2;  /* what tocptr do we need? */
+   }
+   OrigFn;
+
+#define __SPECIAL_INSTRUCTION_PREAMBLE                            \
+                     "rotldi 0,0,3  ; rotldi 0,0,13\n\t"          \
+                     "rotldi 0,0,61 ; rotldi 0,0,51\n\t"
+
+#define VALGRIND_DO_CLIENT_REQUEST_EXPR(                          \
+        _zzq_default, _zzq_request,                               \
+        _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5)    \
+                                                                  \
+  __extension__                                                   \
+  ({         unsigned long int  _zzq_args[6];                     \
+             unsigned long int  _zzq_result;                      \
+             unsigned long int* _zzq_ptr;                         \
+    _zzq_args[0] = (unsigned long int)(_zzq_request);             \
+    _zzq_args[1] = (unsigned long int)(_zzq_arg1);                \
+    _zzq_args[2] = (unsigned long int)(_zzq_arg2);                \
+    _zzq_args[3] = (unsigned long int)(_zzq_arg3);                \
+    _zzq_args[4] = (unsigned long int)(_zzq_arg4);                \
+    _zzq_args[5] = (unsigned long int)(_zzq_arg5);                \
+    _zzq_ptr = _zzq_args;                                         \
+    __asm__ volatile("mr 3,%1\n\t" /*default*/                    \
+                     "mr 4,%2\n\t" /*ptr*/                        \
+                     __SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* %R3 = client_request ( %R4 ) */           \
+                     "or 1,1,1\n\t"                               \
+                     "mr %0,3"     /*result*/                     \
+                     : "=b" (_zzq_result)                         \
+                     : "b" (_zzq_default), "b" (_zzq_ptr)         \
+                     : "cc", "memory", "r3", "r4");               \
+    _zzq_result;                                                  \
+  })
+
+#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval)                       \
+  { volatile OrigFn* _zzq_orig = &(_zzq_rlval);                   \
+    unsigned long int __addr;                                     \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* %R3 = guest_NRADDR */                     \
+                     "or 2,2,2\n\t"                               \
+                     "mr %0,3"                                    \
+                     : "=b" (__addr)                              \
+                     :                                            \
+                     : "cc", "memory", "r3"                       \
+                    );                                            \
+    _zzq_orig->nraddr = __addr;                                   \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* %R3 = guest_NRADDR_GPR2 */                \
+                     "or 4,4,4\n\t"                               \
+                     "mr %0,3"                                    \
+                     : "=b" (__addr)                              \
+                     :                                            \
+                     : "cc", "memory", "r3"                       \
+                    );                                            \
+    _zzq_orig->r2 = __addr;                                       \
+  }
+
+#define VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                   \
+                     __SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* branch-and-link-to-noredir *%R11 */       \
+                     "or 3,3,3\n\t"
+
+#define VALGRIND_VEX_INJECT_IR()                                 \
+ do {                                                            \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE              \
+                     "or 5,5,5\n\t"                              \
+                    );                                           \
+ } while (0)
+
+#endif /* PLAT_ppc64be_linux */
+
+#if defined(PLAT_ppc64le_linux)
+
+typedef
+   struct {
+      unsigned long int nraddr; /* where's the code? */
+      unsigned long int r2;     /* what tocptr do we need? */
+   }
+   OrigFn;
+
+#define __SPECIAL_INSTRUCTION_PREAMBLE                            \
+                     "rotldi 0,0,3  ; rotldi 0,0,13\n\t"          \
+                     "rotldi 0,0,61 ; rotldi 0,0,51\n\t"
+
+#define VALGRIND_DO_CLIENT_REQUEST_EXPR(                          \
+        _zzq_default, _zzq_request,                               \
+        _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5)    \
+                                                                  \
+  __extension__                                                   \
+  ({         unsigned long int  _zzq_args[6];                     \
+             unsigned long int  _zzq_result;                      \
+             unsigned long int* _zzq_ptr;                         \
+    _zzq_args[0] = (unsigned long int)(_zzq_request);             \
+    _zzq_args[1] = (unsigned long int)(_zzq_arg1);                \
+    _zzq_args[2] = (unsigned long int)(_zzq_arg2);                \
+    _zzq_args[3] = (unsigned long int)(_zzq_arg3);                \
+    _zzq_args[4] = (unsigned long int)(_zzq_arg4);                \
+    _zzq_args[5] = (unsigned long int)(_zzq_arg5);                \
+    _zzq_ptr = _zzq_args;                                         \
+    __asm__ volatile("mr 3,%1\n\t" /*default*/                    \
+                     "mr 4,%2\n\t" /*ptr*/                        \
+                     __SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* %R3 = client_request ( %R4 ) */           \
+                     "or 1,1,1\n\t"                               \
+                     "mr %0,3"     /*result*/                     \
+                     : "=b" (_zzq_result)                         \
+                     : "b" (_zzq_default), "b" (_zzq_ptr)         \
+                     : "cc", "memory", "r3", "r4");               \
+    _zzq_result;                                                  \
+  })
+
+#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval)                       \
+  { volatile OrigFn* _zzq_orig = &(_zzq_rlval);                   \
+    unsigned long int __addr;                                     \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* %R3 = guest_NRADDR */                     \
+                     "or 2,2,2\n\t"                               \
+                     "mr %0,3"                                    \
+                     : "=b" (__addr)                              \
+                     :                                            \
+                     : "cc", "memory", "r3"                       \
+                    );                                            \
+    _zzq_orig->nraddr = __addr;                                   \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* %R3 = guest_NRADDR_GPR2 */                \
+                     "or 4,4,4\n\t"                               \
+                     "mr %0,3"                                    \
+                     : "=b" (__addr)                              \
+                     :                                            \
+                     : "cc", "memory", "r3"                       \
+                    );                                            \
+    _zzq_orig->r2 = __addr;                                       \
+  }
+
+#define VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12                   \
+                     __SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* branch-and-link-to-noredir *%R12 */       \
+                     "or 3,3,3\n\t"
+
+#define VALGRIND_VEX_INJECT_IR()                                 \
+ do {                                                            \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE              \
+                     "or 5,5,5\n\t"                              \
+                    );                                           \
+ } while (0)
+
+#endif /* PLAT_ppc64le_linux */
+
+/* ------------------------- arm-linux ------------------------- */
+
+#if defined(PLAT_arm_linux)
+
+typedef
+   struct { 
+      unsigned int nraddr; /* where's the code? */
+   }
+   OrigFn;
+
+#define __SPECIAL_INSTRUCTION_PREAMBLE                            \
+            "mov r12, r12, ror #3  ; mov r12, r12, ror #13 \n\t"  \
+            "mov r12, r12, ror #29 ; mov r12, r12, ror #19 \n\t"
+
+#define VALGRIND_DO_CLIENT_REQUEST_EXPR(                          \
+        _zzq_default, _zzq_request,                               \
+        _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5)    \
+                                                                  \
+  __extension__                                                   \
+  ({volatile unsigned int  _zzq_args[6];                          \
+    volatile unsigned int  _zzq_result;                           \
+    _zzq_args[0] = (unsigned int)(_zzq_request);                  \
+    _zzq_args[1] = (unsigned int)(_zzq_arg1);                     \
+    _zzq_args[2] = (unsigned int)(_zzq_arg2);                     \
+    _zzq_args[3] = (unsigned int)(_zzq_arg3);                     \
+    _zzq_args[4] = (unsigned int)(_zzq_arg4);                     \
+    _zzq_args[5] = (unsigned int)(_zzq_arg5);                     \
+    __asm__ volatile("mov r3, %1\n\t" /*default*/                 \
+                     "mov r4, %2\n\t" /*ptr*/                     \
+                     __SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* R3 = client_request ( R4 ) */             \
+                     "orr r10, r10, r10\n\t"                      \
+                     "mov %0, r3"     /*result*/                  \
+                     : "=r" (_zzq_result)                         \
+                     : "r" (_zzq_default), "r" (&_zzq_args[0])    \
+                     : "cc","memory", "r3", "r4");                \
+    _zzq_result;                                                  \
+  })
+
+#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval)                       \
+  { volatile OrigFn* _zzq_orig = &(_zzq_rlval);                   \
+    unsigned int __addr;                                          \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* R3 = guest_NRADDR */                      \
+                     "orr r11, r11, r11\n\t"                      \
+                     "mov %0, r3"                                 \
+                     : "=r" (__addr)                              \
+                     :                                            \
+                     : "cc", "memory", "r3"                       \
+                    );                                            \
+    _zzq_orig->nraddr = __addr;                                   \
+  }
+
+#define VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4                    \
+                     __SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* branch-and-link-to-noredir *%R4 */        \
+                     "orr r12, r12, r12\n\t"
+
+#define VALGRIND_VEX_INJECT_IR()                                 \
+ do {                                                            \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE              \
+                     "orr r9, r9, r9\n\t"                        \
+                     : : : "cc", "memory"                        \
+                    );                                           \
+ } while (0)
+
+#endif /* PLAT_arm_linux */
+
+/* ------------------------ arm64-linux ------------------------- */
+
+#if defined(PLAT_arm64_linux)
+
+typedef
+   struct { 
+      unsigned long int nraddr; /* where's the code? */
+   }
+   OrigFn;
+
+#define __SPECIAL_INSTRUCTION_PREAMBLE                            \
+            "ror x12, x12, #3  ;  ror x12, x12, #13 \n\t"         \
+            "ror x12, x12, #51 ;  ror x12, x12, #61 \n\t"
+
+#define VALGRIND_DO_CLIENT_REQUEST_EXPR(                          \
+        _zzq_default, _zzq_request,                               \
+        _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5)    \
+                                                                  \
+  __extension__                                                   \
+  ({volatile unsigned long int  _zzq_args[6];                     \
+    volatile unsigned long int  _zzq_result;                      \
+    _zzq_args[0] = (unsigned long int)(_zzq_request);             \
+    _zzq_args[1] = (unsigned long int)(_zzq_arg1);                \
+    _zzq_args[2] = (unsigned long int)(_zzq_arg2);                \
+    _zzq_args[3] = (unsigned long int)(_zzq_arg3);                \
+    _zzq_args[4] = (unsigned long int)(_zzq_arg4);                \
+    _zzq_args[5] = (unsigned long int)(_zzq_arg5);                \
+    __asm__ volatile("mov x3, %1\n\t" /*default*/                 \
+                     "mov x4, %2\n\t" /*ptr*/                     \
+                     __SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* X3 = client_request ( X4 ) */             \
+                     "orr x10, x10, x10\n\t"                      \
+                     "mov %0, x3"     /*result*/                  \
+                     : "=r" (_zzq_result)                         \
+                     : "r" ((unsigned long int)(_zzq_default)),   \
+                       "r" (&_zzq_args[0])                        \
+                     : "cc","memory", "x3", "x4");                \
+    _zzq_result;                                                  \
+  })
+
+#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval)                       \
+  { volatile OrigFn* _zzq_orig = &(_zzq_rlval);                   \
+    unsigned long int __addr;                                     \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* X3 = guest_NRADDR */                      \
+                     "orr x11, x11, x11\n\t"                      \
+                     "mov %0, x3"                                 \
+                     : "=r" (__addr)                              \
+                     :                                            \
+                     : "cc", "memory", "x3"                       \
+                    );                                            \
+    _zzq_orig->nraddr = __addr;                                   \
+  }
+
+#define VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8                    \
+                     __SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* branch-and-link-to-noredir X8 */          \
+                     "orr x12, x12, x12\n\t"
+
+#define VALGRIND_VEX_INJECT_IR()                                 \
+ do {                                                            \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE              \
+                     "orr x9, x9, x9\n\t"                        \
+                     : : : "cc", "memory"                        \
+                    );                                           \
+ } while (0)
+
+#endif /* PLAT_arm64_linux */
+
+/* ------------------------ s390x-linux ------------------------ */
+
+#if defined(PLAT_s390x_linux)
+
+typedef
+  struct {
+     unsigned long int nraddr; /* where's the code? */
+  }
+  OrigFn;
+
+/* __SPECIAL_INSTRUCTION_PREAMBLE will be used to identify Valgrind specific
+ * code. This detection is implemented in platform specific toIR.c
+ * (e.g. VEX/priv/guest_s390_decoder.c).
+ */
+#define __SPECIAL_INSTRUCTION_PREAMBLE                           \
+                     "lr 15,15\n\t"                              \
+                     "lr 1,1\n\t"                                \
+                     "lr 2,2\n\t"                                \
+                     "lr 3,3\n\t"
+
+#define __CLIENT_REQUEST_CODE "lr 2,2\n\t"
+#define __GET_NR_CONTEXT_CODE "lr 3,3\n\t"
+#define __CALL_NO_REDIR_CODE  "lr 4,4\n\t"
+#define __VEX_INJECT_IR_CODE  "lr 5,5\n\t"
+
+#define VALGRIND_DO_CLIENT_REQUEST_EXPR(                         \
+       _zzq_default, _zzq_request,                               \
+       _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5)    \
+  __extension__                                                  \
+ ({volatile unsigned long int _zzq_args[6];                      \
+   volatile unsigned long int _zzq_result;                       \
+   _zzq_args[0] = (unsigned long int)(_zzq_request);             \
+   _zzq_args[1] = (unsigned long int)(_zzq_arg1);                \
+   _zzq_args[2] = (unsigned long int)(_zzq_arg2);                \
+   _zzq_args[3] = (unsigned long int)(_zzq_arg3);                \
+   _zzq_args[4] = (unsigned long int)(_zzq_arg4);                \
+   _zzq_args[5] = (unsigned long int)(_zzq_arg5);                \
+   __asm__ volatile(/* r2 = args */                              \
+                    "lgr 2,%1\n\t"                               \
+                    /* r3 = default */                           \
+                    "lgr 3,%2\n\t"                               \
+                    __SPECIAL_INSTRUCTION_PREAMBLE               \
+                    __CLIENT_REQUEST_CODE                        \
+                    /* results = r3 */                           \
+                    "lgr %0, 3\n\t"                              \
+                    : "=d" (_zzq_result)                         \
+                    : "a" (&_zzq_args[0]), "0" (_zzq_default)    \
+                    : "cc", "2", "3", "memory"                   \
+                   );                                            \
+   _zzq_result;                                                  \
+ })
+
+#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval)                      \
+ { volatile OrigFn* _zzq_orig = &(_zzq_rlval);                   \
+   volatile unsigned long int __addr;                            \
+   __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE               \
+                    __GET_NR_CONTEXT_CODE                        \
+                    "lgr %0, 3\n\t"                              \
+                    : "=a" (__addr)                              \
+                    :                                            \
+                    : "cc", "3", "memory"                        \
+                   );                                            \
+   _zzq_orig->nraddr = __addr;                                   \
+ }
+
+#define VALGRIND_CALL_NOREDIR_R1                                 \
+                    __SPECIAL_INSTRUCTION_PREAMBLE               \
+                    __CALL_NO_REDIR_CODE
+
+#define VALGRIND_VEX_INJECT_IR()                                 \
+ do {                                                            \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE              \
+                     __VEX_INJECT_IR_CODE);                      \
+ } while (0)
+
+#endif /* PLAT_s390x_linux */
+
+/* ------------------------- mips32-linux ---------------- */
+
+#if defined(PLAT_mips32_linux)
+
+typedef
+   struct { 
+      unsigned int nraddr; /* where's the code? */
+   }
+   OrigFn;
+
+/* .word  0x342
+ * .word  0x742
+ * .word  0xC2
+ * .word  0x4C2*/
+#define __SPECIAL_INSTRUCTION_PREAMBLE          \
+                     "srl $0, $0, 13\n\t"       \
+                     "srl $0, $0, 29\n\t"       \
+                     "srl $0, $0, 3\n\t"        \
+                     "srl $0, $0, 19\n\t"
+                    
+#define VALGRIND_DO_CLIENT_REQUEST_EXPR(                          \
+       _zzq_default, _zzq_request,                                \
+       _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5)     \
+  __extension__                                                   \
+  ({ volatile unsigned int _zzq_args[6];                          \
+    volatile unsigned int _zzq_result;                            \
+    _zzq_args[0] = (unsigned int)(_zzq_request);                  \
+    _zzq_args[1] = (unsigned int)(_zzq_arg1);                     \
+    _zzq_args[2] = (unsigned int)(_zzq_arg2);                     \
+    _zzq_args[3] = (unsigned int)(_zzq_arg3);                     \
+    _zzq_args[4] = (unsigned int)(_zzq_arg4);                     \
+    _zzq_args[5] = (unsigned int)(_zzq_arg5);                     \
+        __asm__ volatile("move $11, %1\n\t" /*default*/           \
+                     "move $12, %2\n\t" /*ptr*/                   \
+                     __SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* T3 = client_request ( T4 ) */             \
+                     "or $13, $13, $13\n\t"                       \
+                     "move %0, $11\n\t"     /*result*/            \
+                     : "=r" (_zzq_result)                         \
+                     : "r" (_zzq_default), "r" (&_zzq_args[0])    \
+                     : "$11", "$12", "memory");                   \
+    _zzq_result;                                                  \
+  })
+
+#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval)                       \
+  { volatile OrigFn* _zzq_orig = &(_zzq_rlval);                   \
+    volatile unsigned int __addr;                                 \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* %t9 = guest_NRADDR */                     \
+                     "or $14, $14, $14\n\t"                       \
+                     "move %0, $11"     /*result*/                \
+                     : "=r" (__addr)                              \
+                     :                                            \
+                     : "$11"                                      \
+                    );                                            \
+    _zzq_orig->nraddr = __addr;                                   \
+  }
+
+#define VALGRIND_CALL_NOREDIR_T9                                 \
+                     __SPECIAL_INSTRUCTION_PREAMBLE              \
+                     /* call-noredir *%t9 */                     \
+                     "or $15, $15, $15\n\t"
+
+#define VALGRIND_VEX_INJECT_IR()                                 \
+ do {                                                            \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE              \
+                     "or $11, $11, $11\n\t"                      \
+                    );                                           \
+ } while (0)
+
+
+#endif /* PLAT_mips32_linux */
+
+/* ------------------------- mips64-linux ---------------- */
+
+#if defined(PLAT_mips64_linux)
+
+typedef
+   struct {
+      unsigned long nraddr; /* where's the code? */
+   }
+   OrigFn;
+
+/* dsll $0,$0, 3
+ * dsll $0,$0, 13
+ * dsll $0,$0, 29
+ * dsll $0,$0, 19*/
+#define __SPECIAL_INSTRUCTION_PREAMBLE                              \
+                     "dsll $0,$0, 3 ; dsll $0,$0,13\n\t"            \
+                     "dsll $0,$0,29 ; dsll $0,$0,19\n\t"
+
+#define VALGRIND_DO_CLIENT_REQUEST_EXPR(                            \
+       _zzq_default, _zzq_request,                                  \
+       _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5)       \
+  __extension__                                                     \
+  ({ volatile unsigned long int _zzq_args[6];                       \
+    volatile unsigned long int _zzq_result;                         \
+    _zzq_args[0] = (unsigned long int)(_zzq_request);               \
+    _zzq_args[1] = (unsigned long int)(_zzq_arg1);                  \
+    _zzq_args[2] = (unsigned long int)(_zzq_arg2);                  \
+    _zzq_args[3] = (unsigned long int)(_zzq_arg3);                  \
+    _zzq_args[4] = (unsigned long int)(_zzq_arg4);                  \
+    _zzq_args[5] = (unsigned long int)(_zzq_arg5);                  \
+        __asm__ volatile("move $11, %1\n\t" /*default*/             \
+                         "move $12, %2\n\t" /*ptr*/                 \
+                         __SPECIAL_INSTRUCTION_PREAMBLE             \
+                         /* $11 = client_request ( $12 ) */         \
+                         "or $13, $13, $13\n\t"                     \
+                         "move %0, $11\n\t"     /*result*/          \
+                         : "=r" (_zzq_result)                       \
+                         : "r" (_zzq_default), "r" (&_zzq_args[0])  \
+                         : "$11", "$12", "memory");                 \
+    _zzq_result;                                                    \
+  })
+
+#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval)                         \
+  { volatile OrigFn* _zzq_orig = &(_zzq_rlval);                     \
+    volatile unsigned long int __addr;                              \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE                 \
+                     /* $11 = guest_NRADDR */                       \
+                     "or $14, $14, $14\n\t"                         \
+                     "move %0, $11"     /*result*/                  \
+                     : "=r" (__addr)                                \
+                     :                                              \
+                     : "$11");                                      \
+    _zzq_orig->nraddr = __addr;                                     \
+  }
+
+#define VALGRIND_CALL_NOREDIR_T9                                    \
+                     __SPECIAL_INSTRUCTION_PREAMBLE                 \
+                     /* call-noredir $25 */                         \
+                     "or $15, $15, $15\n\t"
+
+#define VALGRIND_VEX_INJECT_IR()                                    \
+ do {                                                               \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE                 \
+                     "or $11, $11, $11\n\t"                         \
+                    );                                              \
+ } while (0)
+
+#endif /* PLAT_mips64_linux */
+
+#if defined(PLAT_nanomips_linux)
+
+typedef
+   struct {
+      unsigned int nraddr; /* where's the code? */
+   }
+   OrigFn;
+/*
+   8000 c04d  srl  zero, zero, 13
+   8000 c05d  srl  zero, zero, 29
+   8000 c043  srl  zero, zero,  3
+   8000 c053  srl  zero, zero, 19
+*/
+
+#define __SPECIAL_INSTRUCTION_PREAMBLE "srl[32] $zero, $zero, 13 \n\t" \
+                                       "srl[32] $zero, $zero, 29 \n\t" \
+                                       "srl[32] $zero, $zero, 3  \n\t" \
+                                       "srl[32] $zero, $zero, 19 \n\t"
+
+#define VALGRIND_DO_CLIENT_REQUEST_EXPR(                          \
+       _zzq_default, _zzq_request,                                \
+       _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5)     \
+  __extension__                                                   \
+  ({ volatile unsigned int _zzq_args[6];                          \
+    volatile unsigned int _zzq_result;                            \
+    _zzq_args[0] = (unsigned int)(_zzq_request);                  \
+    _zzq_args[1] = (unsigned int)(_zzq_arg1);                     \
+    _zzq_args[2] = (unsigned int)(_zzq_arg2);                     \
+    _zzq_args[3] = (unsigned int)(_zzq_arg3);                     \
+    _zzq_args[4] = (unsigned int)(_zzq_arg4);                     \
+    _zzq_args[5] = (unsigned int)(_zzq_arg5);                     \
+    __asm__ volatile("move $a7, %1\n\t" /* default */             \
+                     "move $t0, %2\n\t" /* ptr */                 \
+                     __SPECIAL_INSTRUCTION_PREAMBLE               \
+                     /* $a7 = client_request( $t0 ) */            \
+                     "or[32] $t0, $t0, $t0\n\t"                   \
+                     "move %0, $a7\n\t"     /* result */          \
+                     : "=r" (_zzq_result)                         \
+                     : "r" (_zzq_default), "r" (&_zzq_args[0])    \
+                     : "$a7", "$t0", "memory");                   \
+    _zzq_result;                                                  \
+  })
+
+#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval)                         \
+  { volatile OrigFn* _zzq_orig = &(_zzq_rlval);                     \
+    volatile unsigned long int __addr;                              \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE                 \
+                     /* $a7 = guest_NRADDR */                       \
+                     "or[32] $t1, $t1, $t1\n\t"                     \
+                     "move %0, $a7"     /*result*/                  \
+                     : "=r" (__addr)                                \
+                     :                                              \
+                     : "$a7");                                      \
+    _zzq_orig->nraddr = __addr;                                     \
+  }
+
+#define VALGRIND_CALL_NOREDIR_T9                                    \
+                     __SPECIAL_INSTRUCTION_PREAMBLE                 \
+                     /* call-noredir $25 */                         \
+                     "or[32] $t2, $t2, $t2\n\t"
+
+#define VALGRIND_VEX_INJECT_IR()                                    \
+ do {                                                               \
+    __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE                 \
+                     "or[32] $t3, $t3, $t3\n\t"                     \
+                    );                                              \
+ } while (0)
+
+#endif
+/* Insert assembly code for other platforms here... */
+
+#endif /* NVALGRIND */
+
+
+/* ------------------------------------------------------------------ */
+/* PLATFORM SPECIFICS for FUNCTION WRAPPING.  This is all very        */
+/* ugly.  It's the least-worst tradeoff I can think of.               */
+/* ------------------------------------------------------------------ */
+
+/* This section defines magic (a.k.a. appalling-hack) macros for doing
+   guaranteed-no-redirection calls, so as to get from function
+   wrappers to the functions they are wrapping.  The whole point is to
+   construct standard call sequences, but to do the call itself with a
+   special no-redirect call pseudo-instruction that the JIT
+   understands and handles specially.  This section is long and
+   repetitious, and I can't see a way to make it shorter.
+
+   The naming scheme is as follows:
+
+      CALL_FN_{W,v}_{v,W,WW,WWW,WWWW,5W,6W,7W,etc}
+
+   'W' stands for "word" and 'v' for "void".  Hence there are
+   different macros for calling arity 0, 1, 2, 3, 4, etc, functions,
+   and for each, the possibility of returning a word-typed result, or
+   no result.
+*/
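+
+/* For instance, following this scheme, CALL_FN_W_WW(result, fn, a, b)
+   calls the 2-argument function described by 'fn' and places its
+   word-sized return value in 'result', while CALL_FN_v_W(fn, a) calls
+   a 1-argument function returning nothing.  A fuller wrapper sketch
+   is given below, after VALGRIND_GET_ORIG_FN. */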
+
+/* Use these to write the name of your wrapper.  NOTE: duplicates
+   VG_WRAP_FUNCTION_Z{U,Z} in pub_tool_redir.h.  NOTE also: inserts
+   the default behaviour equivalence class tag "0000" into the name.
+   See pub_tool_redir.h for details -- normally you don't need to
+   think about this, though. */
+
+/* Use an extra level of macroisation so as to ensure the soname/fnname
+   args are fully macro-expanded before pasting them together. */
+#define VG_CONCAT4(_aa,_bb,_cc,_dd) _aa##_bb##_cc##_dd
+
+#define I_WRAP_SONAME_FNNAME_ZU(soname,fnname)                    \
+   VG_CONCAT4(_vgw00000ZU_,soname,_,fnname)
+
+#define I_WRAP_SONAME_FNNAME_ZZ(soname,fnname)                    \
+   VG_CONCAT4(_vgw00000ZZ_,soname,_,fnname)
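+
+/* By way of illustration, I_WRAP_SONAME_FNNAME_ZU(NONE, foo) expands,
+   via VG_CONCAT4, to the wrapper name _vgw00000ZU_NONE_foo; a
+   Z-encoded soname such as libcZdsoZa (libc.so.*) is pasted in the
+   same way.  See pub_tool_redir.h for the Z-encoding rules. */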
+
+/* Use this macro from within a wrapper function to collect the
+   context (address and possibly other info) of the original function.
+   Once you have that you can then use it in one of the CALL_FN_
+   macros.  The type of the argument _lval is OrigFn. */
+#define VALGRIND_GET_ORIG_FN(_lval)  VALGRIND_GET_NR_CONTEXT(_lval)
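+
+/* A minimal wrapper sketch, along the lines of the example in the
+   Valgrind user manual ('foo', its int/int signature and the NONE
+   soname tag are purely illustrative):
+
+      int I_WRAP_SONAME_FNNAME_ZU(NONE, foo)( int x, int y )
+      {
+         int    result;
+         OrigFn fn;
+         VALGRIND_GET_ORIG_FN(fn);
+         CALL_FN_W_WW(result, fn, x, y);
+         return result;
+      }
+
+   The wrapper collects the original function's context with
+   VALGRIND_GET_ORIG_FN and then calls onward to it, without
+   re-triggering redirection, via CALL_FN_W_WW. */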
+
+/* Also provide end-user facilities for function replacement, rather
+   than wrapping.  A replacement function differs from a wrapper in
+   that it has no way to get hold of the original function being
+   called, and hence no way to call onwards to it.  In a replacement
+   function, VALGRIND_GET_ORIG_FN always returns zero. */
+
+#define I_REPLACE_SONAME_FNNAME_ZU(soname,fnname)                 \
+   VG_CONCAT4(_vgr00000ZU_,soname,_,fnname)
+
+#define I_REPLACE_SONAME_FNNAME_ZZ(soname,fnname)                 \
+   VG_CONCAT4(_vgr00000ZZ_,soname,_,fnname)
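+
+/* A replacement is declared the same way as a wrapper, e.g.
+   int I_REPLACE_SONAME_FNNAME_ZU(NONE, foo)( int x, int y ), but its
+   body cannot call onward: VALGRIND_GET_ORIG_FN yields zero here, so
+   the CALL_FN_ macros have nothing usable to call. */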
+
+/* Derivatives of the main macros below, for calling functions
+   returning void. */
+
+#define CALL_FN_v_v(fnptr)                                        \
+   do { volatile unsigned long _junk;                             \
+        CALL_FN_W_v(_junk,fnptr); } while (0)
+
+#define CALL_FN_v_W(fnptr, arg1)                                  \
+   do { volatile unsigned long _junk;                             \
+        CALL_FN_W_W(_junk,fnptr,arg1); } while (0)
+
+#define CALL_FN_v_WW(fnptr, arg1,arg2)                            \
+   do { volatile unsigned long _junk;                             \
+        CALL_FN_W_WW(_junk,fnptr,arg1,arg2); } while (0)
+
+#define CALL_FN_v_WWW(fnptr, arg1,arg2,arg3)                      \
+   do { volatile unsigned long _junk;                             \
+        CALL_FN_W_WWW(_junk,fnptr,arg1,arg2,arg3); } while (0)
+
+#define CALL_FN_v_WWWW(fnptr, arg1,arg2,arg3,arg4)                \
+   do { volatile unsigned long _junk;                             \
+        CALL_FN_W_WWWW(_junk,fnptr,arg1,arg2,arg3,arg4); } while (0)
+
+#define CALL_FN_v_5W(fnptr, arg1,arg2,arg3,arg4,arg5)             \
+   do { volatile unsigned long _junk;                             \
+        CALL_FN_W_5W(_junk,fnptr,arg1,arg2,arg3,arg4,arg5); } while (0)
+
+#define CALL_FN_v_6W(fnptr, arg1,arg2,arg3,arg4,arg5,arg6)        \
+   do { volatile unsigned long _junk;                             \
+        CALL_FN_W_6W(_junk,fnptr,arg1,arg2,arg3,arg4,arg5,arg6); } while (0)
+
+#define CALL_FN_v_7W(fnptr, arg1,arg2,arg3,arg4,arg5,arg6,arg7)   \
+   do { volatile unsigned long _junk;                             \
+        CALL_FN_W_7W(_junk,fnptr,arg1,arg2,arg3,arg4,arg5,arg6,arg7); } while (0)
+
+/* ----------------- x86-{linux,darwin,solaris} ---------------- */
+
+#if defined(PLAT_x86_linux)  ||  defined(PLAT_x86_darwin) \
+    ||  defined(PLAT_x86_solaris)
+
+/* These regs are trashed by the hidden call.  There is no need to
+   mention eax, as gcc can already see that; mentioning it also causes
+   gcc to bomb. */
+#define __CALLER_SAVED_REGS /*"eax"*/ "ecx", "edx"
+
+/* Macros to save and align the stack before making a function
+   call and restore it afterwards as gcc may not keep the stack
+   pointer aligned if it doesn't realise calls are being made
+   to other functions. */
+
+#define VALGRIND_ALIGN_STACK               \
+      "movl %%esp,%%edi\n\t"               \
+      "andl $0xfffffff0,%%esp\n\t"
+#define VALGRIND_RESTORE_STACK             \
+      "movl %%edi,%%esp\n\t"
+
+/* These CALL_FN_ macros assume that on x86-linux, sizeof(unsigned
+   long) == 4. */
+
+#define CALL_FN_W_v(lval, orig)                                   \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[1];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "movl (%%eax), %%eax\n\t"  /* target->%eax */            \
+         VALGRIND_CALL_NOREDIR_EAX                                \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=a" (_res)                                  \
+         : /*in*/    "a" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_W(lval, orig, arg1)                             \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[2];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "subl $12, %%esp\n\t"                                    \
+         "pushl 4(%%eax)\n\t"                                     \
+         "movl (%%eax), %%eax\n\t"  /* target->%eax */            \
+         VALGRIND_CALL_NOREDIR_EAX                                \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=a" (_res)                                  \
+         : /*in*/    "a" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_WW(lval, orig, arg1,arg2)                       \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "subl $8, %%esp\n\t"                                     \
+         "pushl 8(%%eax)\n\t"                                     \
+         "pushl 4(%%eax)\n\t"                                     \
+         "movl (%%eax), %%eax\n\t"  /* target->%eax */            \
+         VALGRIND_CALL_NOREDIR_EAX                                \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=a" (_res)                                  \
+         : /*in*/    "a" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3)                 \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[4];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "subl $4, %%esp\n\t"                                     \
+         "pushl 12(%%eax)\n\t"                                    \
+         "pushl 8(%%eax)\n\t"                                     \
+         "pushl 4(%%eax)\n\t"                                     \
+         "movl (%%eax), %%eax\n\t"  /* target->%eax */            \
+         VALGRIND_CALL_NOREDIR_EAX                                \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=a" (_res)                                  \
+         : /*in*/    "a" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4)           \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[5];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "pushl 16(%%eax)\n\t"                                    \
+         "pushl 12(%%eax)\n\t"                                    \
+         "pushl 8(%%eax)\n\t"                                     \
+         "pushl 4(%%eax)\n\t"                                     \
+         "movl (%%eax), %%eax\n\t"  /* target->%eax */            \
+         VALGRIND_CALL_NOREDIR_EAX                                \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=a" (_res)                                  \
+         : /*in*/    "a" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5)        \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[6];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "subl $12, %%esp\n\t"                                    \
+         "pushl 20(%%eax)\n\t"                                    \
+         "pushl 16(%%eax)\n\t"                                    \
+         "pushl 12(%%eax)\n\t"                                    \
+         "pushl 8(%%eax)\n\t"                                     \
+         "pushl 4(%%eax)\n\t"                                     \
+         "movl (%%eax), %%eax\n\t"  /* target->%eax */            \
+         VALGRIND_CALL_NOREDIR_EAX                                \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=a" (_res)                                  \
+         : /*in*/    "a" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6)   \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[7];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "subl $8, %%esp\n\t"                                     \
+         "pushl 24(%%eax)\n\t"                                    \
+         "pushl 20(%%eax)\n\t"                                    \
+         "pushl 16(%%eax)\n\t"                                    \
+         "pushl 12(%%eax)\n\t"                                    \
+         "pushl 8(%%eax)\n\t"                                     \
+         "pushl 4(%%eax)\n\t"                                     \
+         "movl (%%eax), %%eax\n\t"  /* target->%eax */            \
+         VALGRIND_CALL_NOREDIR_EAX                                \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=a" (_res)                                  \
+         : /*in*/    "a" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7)                            \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[8];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "subl $4, %%esp\n\t"                                     \
+         "pushl 28(%%eax)\n\t"                                    \
+         "pushl 24(%%eax)\n\t"                                    \
+         "pushl 20(%%eax)\n\t"                                    \
+         "pushl 16(%%eax)\n\t"                                    \
+         "pushl 12(%%eax)\n\t"                                    \
+         "pushl 8(%%eax)\n\t"                                     \
+         "pushl 4(%%eax)\n\t"                                     \
+         "movl (%%eax), %%eax\n\t"  /* target->%eax */            \
+         VALGRIND_CALL_NOREDIR_EAX                                \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=a" (_res)                                  \
+         : /*in*/    "a" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7,arg8)                       \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[9];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "pushl 32(%%eax)\n\t"                                    \
+         "pushl 28(%%eax)\n\t"                                    \
+         "pushl 24(%%eax)\n\t"                                    \
+         "pushl 20(%%eax)\n\t"                                    \
+         "pushl 16(%%eax)\n\t"                                    \
+         "pushl 12(%%eax)\n\t"                                    \
+         "pushl 8(%%eax)\n\t"                                     \
+         "pushl 4(%%eax)\n\t"                                     \
+         "movl (%%eax), %%eax\n\t"  /* target->%eax */            \
+         VALGRIND_CALL_NOREDIR_EAX                                \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=a" (_res)                                  \
+         : /*in*/    "a" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7,arg8,arg9)                  \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[10];                         \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      _argvec[9] = (unsigned long)(arg9);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "subl $12, %%esp\n\t"                                    \
+         "pushl 36(%%eax)\n\t"                                    \
+         "pushl 32(%%eax)\n\t"                                    \
+         "pushl 28(%%eax)\n\t"                                    \
+         "pushl 24(%%eax)\n\t"                                    \
+         "pushl 20(%%eax)\n\t"                                    \
+         "pushl 16(%%eax)\n\t"                                    \
+         "pushl 12(%%eax)\n\t"                                    \
+         "pushl 8(%%eax)\n\t"                                     \
+         "pushl 4(%%eax)\n\t"                                     \
+         "movl (%%eax), %%eax\n\t"  /* target->%eax */            \
+         VALGRIND_CALL_NOREDIR_EAX                                \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=a" (_res)                                  \
+         : /*in*/    "a" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,  \
+                                  arg7,arg8,arg9,arg10)           \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[11];                         \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      _argvec[9] = (unsigned long)(arg9);                         \
+      _argvec[10] = (unsigned long)(arg10);                       \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "subl $8, %%esp\n\t"                                     \
+         "pushl 40(%%eax)\n\t"                                    \
+         "pushl 36(%%eax)\n\t"                                    \
+         "pushl 32(%%eax)\n\t"                                    \
+         "pushl 28(%%eax)\n\t"                                    \
+         "pushl 24(%%eax)\n\t"                                    \
+         "pushl 20(%%eax)\n\t"                                    \
+         "pushl 16(%%eax)\n\t"                                    \
+         "pushl 12(%%eax)\n\t"                                    \
+         "pushl 8(%%eax)\n\t"                                     \
+         "pushl 4(%%eax)\n\t"                                     \
+         "movl (%%eax), %%eax\n\t"  /* target->%eax */            \
+         VALGRIND_CALL_NOREDIR_EAX                                \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=a" (_res)                                  \
+         : /*in*/    "a" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5,       \
+                                  arg6,arg7,arg8,arg9,arg10,      \
+                                  arg11)                          \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[12];                         \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      _argvec[9] = (unsigned long)(arg9);                         \
+      _argvec[10] = (unsigned long)(arg10);                       \
+      _argvec[11] = (unsigned long)(arg11);                       \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "subl $4, %%esp\n\t"                                     \
+         "pushl 44(%%eax)\n\t"                                    \
+         "pushl 40(%%eax)\n\t"                                    \
+         "pushl 36(%%eax)\n\t"                                    \
+         "pushl 32(%%eax)\n\t"                                    \
+         "pushl 28(%%eax)\n\t"                                    \
+         "pushl 24(%%eax)\n\t"                                    \
+         "pushl 20(%%eax)\n\t"                                    \
+         "pushl 16(%%eax)\n\t"                                    \
+         "pushl 12(%%eax)\n\t"                                    \
+         "pushl 8(%%eax)\n\t"                                     \
+         "pushl 4(%%eax)\n\t"                                     \
+         "movl (%%eax), %%eax\n\t"  /* target->%eax */            \
+         VALGRIND_CALL_NOREDIR_EAX                                \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=a" (_res)                                  \
+         : /*in*/    "a" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5,       \
+                                  arg6,arg7,arg8,arg9,arg10,      \
+                                  arg11,arg12)                    \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[13];                         \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      _argvec[9] = (unsigned long)(arg9);                         \
+      _argvec[10] = (unsigned long)(arg10);                       \
+      _argvec[11] = (unsigned long)(arg11);                       \
+      _argvec[12] = (unsigned long)(arg12);                       \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "pushl 48(%%eax)\n\t"                                    \
+         "pushl 44(%%eax)\n\t"                                    \
+         "pushl 40(%%eax)\n\t"                                    \
+         "pushl 36(%%eax)\n\t"                                    \
+         "pushl 32(%%eax)\n\t"                                    \
+         "pushl 28(%%eax)\n\t"                                    \
+         "pushl 24(%%eax)\n\t"                                    \
+         "pushl 20(%%eax)\n\t"                                    \
+         "pushl 16(%%eax)\n\t"                                    \
+         "pushl 12(%%eax)\n\t"                                    \
+         "pushl 8(%%eax)\n\t"                                     \
+         "pushl 4(%%eax)\n\t"                                     \
+         "movl (%%eax), %%eax\n\t"  /* target->%eax */            \
+         VALGRIND_CALL_NOREDIR_EAX                                \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=a" (_res)                                  \
+         : /*in*/    "a" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#endif /* PLAT_x86_linux || PLAT_x86_darwin || PLAT_x86_solaris */
+
+/* ---------------- amd64-{linux,darwin,solaris} --------------- */
+
+#if defined(PLAT_amd64_linux)  ||  defined(PLAT_amd64_darwin) \
+    ||  defined(PLAT_amd64_solaris)
+
+/* ARGREGS: rdi rsi rdx rcx r8 r9 (the rest on stack in R-to-L order) */
+
+/* These regs are trashed by the hidden call. */
+#define __CALLER_SAVED_REGS /*"rax",*/ "rcx", "rdx", "rsi",       \
+                            "rdi", "r8", "r9", "r10", "r11"
+
+/* This is all pretty complex.  It's done this way so that stack
+   unwinding works reliably.  See bug 243270.  The basic problem is
+   the sub and add of 128 to %rsp in all of the following macros.  If
+   gcc believes the CFA is in %rsp, then unwinding may fail, because
+   what's at the CFA is not what gcc "expected" when it constructed
+   the CFIs for the places where the macros are instantiated.
+
+   But we can't just add a CFI annotation to increase the CFA offset
+   by 128, to match the sub of 128 from %rsp, because we don't know
+   whether gcc has chosen %rsp as the CFA at that point, or whether it
+   has chosen some other register (eg, %rbp).  In the latter case,
+   adding a CFI annotation to change the CFA offset is simply wrong.
+
+   So the solution is to get hold of the CFA using
+   __builtin_dwarf_cfa(), put it in a known register, and add a
+   CFI annotation to say what the register is.  We choose %rbp for
+   this (perhaps perversely), because:
+
+   (1) %rbp is already subject to unwinding.  If a new register was
+       chosen then the unwinder would have to unwind it in all stack
+       traces, which is expensive, and
+
+   (2) %rbp is already subject to precise exception updates in the
+       JIT.  If a new register was chosen, we'd have to have precise
+       exceptions for it too, which reduces performance of the
+       generated code.
+
+   However .. one extra complication.  We can't just whack the result
+   of __builtin_dwarf_cfa() into %rbp and then add %rbp to the
+   list of trashed registers at the end of the inline assembly
+   fragments; gcc won't allow %rbp to appear in that list.  Hence
+   instead we need to stash %rbp in %r15 for the duration of the asm,
+   and say that %r15 is trashed instead.  gcc seems happy to go with
+   that.
+
+   Oh .. and this all needs to be conditionalised so that it is
+   unchanged from before this commit, when compiled with older gccs
+   that don't support __builtin_dwarf_cfa.  Furthermore, since
+   this header file is freestanding, it has to be independent of
+   config.h, and so the following conditionalisation cannot depend on
+   configure time checks.
+
+   Although it's not clear from
+   'defined(__GNUC__) && defined(__GCC_HAVE_DWARF2_CFI_ASM)',
+   this expression excludes Darwin.
+   .cfi directives in Darwin assembly appear to be completely
+   different and I haven't investigated how they work.
+
+   For even more entertainment value, note we have to use the
+   completely undocumented __builtin_dwarf_cfa(), which appears to
+   really compute the CFA, whereas __builtin_frame_address(0) claims
+   to but actually doesn't.  See
+   https://bugs.kde.org/show_bug.cgi?id=243270#c47
+*/
+#if defined(__GNUC__) && defined(__GCC_HAVE_DWARF2_CFI_ASM)
+#  define __FRAME_POINTER                                         \
+      ,"r"(__builtin_dwarf_cfa())
+#  define VALGRIND_CFI_PROLOGUE                                   \
+      "movq %%rbp, %%r15\n\t"                                     \
+      "movq %2, %%rbp\n\t"                                        \
+      ".cfi_remember_state\n\t"                                   \
+      ".cfi_def_cfa rbp, 0\n\t"
+#  define VALGRIND_CFI_EPILOGUE                                   \
+      "movq %%r15, %%rbp\n\t"                                     \
+      ".cfi_restore_state\n\t"
+#else
+#  define __FRAME_POINTER
+#  define VALGRIND_CFI_PROLOGUE
+#  define VALGRIND_CFI_EPILOGUE
+#endif
+
+/* Macros to save and align the stack before making a function
+   call, and to restore it afterwards, as gcc may not keep the stack
+   pointer aligned if it doesn't realise calls are being made
+   to other functions. */
+
+#define VALGRIND_ALIGN_STACK               \
+      "movq %%rsp,%%r14\n\t"               \
+      "andq $0xfffffffffffffff0,%%rsp\n\t"
+#define VALGRIND_RESTORE_STACK             \
+      "movq %%r14,%%rsp\n\t"
+
+/* These CALL_FN_ macros assume that on amd64-linux, sizeof(unsigned
+   long) == 8. */
+
+/* NB 9 Sept 07.  There is a nasty kludge here in all these CALL_FN_
+   macros.  In order not to trash the stack redzone, we need to drop
+   %rsp by 128 before the hidden call, and restore afterwards.  The
+   nastiness is that it is only by luck that the stack still appears
+   to be unwindable during the hidden call, since at that point the
+   behaviour of any routine using this macro does not match what the
+   CFI data says.  Sigh.
+
+   Why is this important?  Imagine that a wrapper has a
+   stack-allocated local and passes a pointer to it to the hidden
+   call.  Because gcc does not know about the hidden call, it may
+   allocate that local in the redzone.  Unfortunately the hidden call
+   may then trash it before the wrapper comes to use it.  So we must
+   step clear of the redzone, for the duration of the hidden call, to
+   make it safe.
+
+   Probably the same problem afflicts the other redzone-style ABIs too
+   (ppc64-linux); but for those, the stack is self-describing (none of
+   this CFI nonsense), so at least messing with the stack pointer does
+   not create a danger of a non-unwindable stack. */
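+
+/* As a rough sketch of the hazard (the library soname and function
+   name here are made up; the wrapping macros are the usual ones from
+   this header):
+
+      int I_WRAP_SONAME_FNNAME_ZU(libfooZdsoZa, foo) ( int x )
+      {
+         OrigFn fn;
+         long   local = x;    // gcc may place this in the redzone
+         long   r;
+         VALGRIND_GET_ORIG_FN(fn);
+         // Without the drop of %rsp by 128, the hidden call could
+         // scribble on the redzone and hence on 'local'.
+         CALL_FN_W_W(r, fn, (unsigned long)&local);
+         return (int)r;
+      }
+*/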
+
+#define CALL_FN_W_v(lval, orig)                                        \
+   do {                                                                \
+      volatile OrigFn        _orig = (orig);                           \
+      volatile unsigned long _argvec[1];                               \
+      volatile unsigned long _res;                                     \
+      _argvec[0] = (unsigned long)_orig.nraddr;                        \
+      __asm__ volatile(                                                \
+         VALGRIND_CFI_PROLOGUE                                         \
+         VALGRIND_ALIGN_STACK                                          \
+         "subq $128,%%rsp\n\t"                                         \
+         "movq (%%rax), %%rax\n\t"  /* target->%rax */                 \
+         VALGRIND_CALL_NOREDIR_RAX                                     \
+         VALGRIND_RESTORE_STACK                                        \
+         VALGRIND_CFI_EPILOGUE                                         \
+         : /*out*/   "=a" (_res)                                       \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER                 \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \
+      );                                                               \
+      lval = (__typeof__(lval)) _res;                                  \
+   } while (0)
+
+#define CALL_FN_W_W(lval, orig, arg1)                                  \
+   do {                                                                \
+      volatile OrigFn        _orig = (orig);                           \
+      volatile unsigned long _argvec[2];                               \
+      volatile unsigned long _res;                                     \
+      _argvec[0] = (unsigned long)_orig.nraddr;                        \
+      _argvec[1] = (unsigned long)(arg1);                              \
+      __asm__ volatile(                                                \
+         VALGRIND_CFI_PROLOGUE                                         \
+         VALGRIND_ALIGN_STACK                                          \
+         "subq $128,%%rsp\n\t"                                         \
+         "movq 8(%%rax), %%rdi\n\t"                                    \
+         "movq (%%rax), %%rax\n\t"  /* target->%rax */                 \
+         VALGRIND_CALL_NOREDIR_RAX                                     \
+         VALGRIND_RESTORE_STACK                                        \
+         VALGRIND_CFI_EPILOGUE                                         \
+         : /*out*/   "=a" (_res)                                       \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER                 \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \
+      );                                                               \
+      lval = (__typeof__(lval)) _res;                                  \
+   } while (0)
+
+#define CALL_FN_W_WW(lval, orig, arg1,arg2)                            \
+   do {                                                                \
+      volatile OrigFn        _orig = (orig);                           \
+      volatile unsigned long _argvec[3];                               \
+      volatile unsigned long _res;                                     \
+      _argvec[0] = (unsigned long)_orig.nraddr;                        \
+      _argvec[1] = (unsigned long)(arg1);                              \
+      _argvec[2] = (unsigned long)(arg2);                              \
+      __asm__ volatile(                                                \
+         VALGRIND_CFI_PROLOGUE                                         \
+         VALGRIND_ALIGN_STACK                                          \
+         "subq $128,%%rsp\n\t"                                         \
+         "movq 16(%%rax), %%rsi\n\t"                                   \
+         "movq 8(%%rax), %%rdi\n\t"                                    \
+         "movq (%%rax), %%rax\n\t"  /* target->%rax */                 \
+         VALGRIND_CALL_NOREDIR_RAX                                     \
+         VALGRIND_RESTORE_STACK                                        \
+         VALGRIND_CFI_EPILOGUE                                         \
+         : /*out*/   "=a" (_res)                                       \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER                 \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \
+      );                                                               \
+      lval = (__typeof__(lval)) _res;                                  \
+   } while (0)
+
+#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3)                      \
+   do {                                                                \
+      volatile OrigFn        _orig = (orig);                           \
+      volatile unsigned long _argvec[4];                               \
+      volatile unsigned long _res;                                     \
+      _argvec[0] = (unsigned long)_orig.nraddr;                        \
+      _argvec[1] = (unsigned long)(arg1);                              \
+      _argvec[2] = (unsigned long)(arg2);                              \
+      _argvec[3] = (unsigned long)(arg3);                              \
+      __asm__ volatile(                                                \
+         VALGRIND_CFI_PROLOGUE                                         \
+         VALGRIND_ALIGN_STACK                                          \
+         "subq $128,%%rsp\n\t"                                         \
+         "movq 24(%%rax), %%rdx\n\t"                                   \
+         "movq 16(%%rax), %%rsi\n\t"                                   \
+         "movq 8(%%rax), %%rdi\n\t"                                    \
+         "movq (%%rax), %%rax\n\t"  /* target->%rax */                 \
+         VALGRIND_CALL_NOREDIR_RAX                                     \
+         VALGRIND_RESTORE_STACK                                        \
+         VALGRIND_CFI_EPILOGUE                                         \
+         : /*out*/   "=a" (_res)                                       \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER                 \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \
+      );                                                               \
+      lval = (__typeof__(lval)) _res;                                  \
+   } while (0)
+
+#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4)                \
+   do {                                                                \
+      volatile OrigFn        _orig = (orig);                           \
+      volatile unsigned long _argvec[5];                               \
+      volatile unsigned long _res;                                     \
+      _argvec[0] = (unsigned long)_orig.nraddr;                        \
+      _argvec[1] = (unsigned long)(arg1);                              \
+      _argvec[2] = (unsigned long)(arg2);                              \
+      _argvec[3] = (unsigned long)(arg3);                              \
+      _argvec[4] = (unsigned long)(arg4);                              \
+      __asm__ volatile(                                                \
+         VALGRIND_CFI_PROLOGUE                                         \
+         VALGRIND_ALIGN_STACK                                          \
+         "subq $128,%%rsp\n\t"                                         \
+         "movq 32(%%rax), %%rcx\n\t"                                   \
+         "movq 24(%%rax), %%rdx\n\t"                                   \
+         "movq 16(%%rax), %%rsi\n\t"                                   \
+         "movq 8(%%rax), %%rdi\n\t"                                    \
+         "movq (%%rax), %%rax\n\t"  /* target->%rax */                 \
+         VALGRIND_CALL_NOREDIR_RAX                                     \
+         VALGRIND_RESTORE_STACK                                        \
+         VALGRIND_CFI_EPILOGUE                                         \
+         : /*out*/   "=a" (_res)                                       \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER                 \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \
+      );                                                               \
+      lval = (__typeof__(lval)) _res;                                  \
+   } while (0)
+
+#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5)             \
+   do {                                                                \
+      volatile OrigFn        _orig = (orig);                           \
+      volatile unsigned long _argvec[6];                               \
+      volatile unsigned long _res;                                     \
+      _argvec[0] = (unsigned long)_orig.nraddr;                        \
+      _argvec[1] = (unsigned long)(arg1);                              \
+      _argvec[2] = (unsigned long)(arg2);                              \
+      _argvec[3] = (unsigned long)(arg3);                              \
+      _argvec[4] = (unsigned long)(arg4);                              \
+      _argvec[5] = (unsigned long)(arg5);                              \
+      __asm__ volatile(                                                \
+         VALGRIND_CFI_PROLOGUE                                         \
+         VALGRIND_ALIGN_STACK                                          \
+         "subq $128,%%rsp\n\t"                                         \
+         "movq 40(%%rax), %%r8\n\t"                                    \
+         "movq 32(%%rax), %%rcx\n\t"                                   \
+         "movq 24(%%rax), %%rdx\n\t"                                   \
+         "movq 16(%%rax), %%rsi\n\t"                                   \
+         "movq 8(%%rax), %%rdi\n\t"                                    \
+         "movq (%%rax), %%rax\n\t"  /* target->%rax */                 \
+         VALGRIND_CALL_NOREDIR_RAX                                     \
+         VALGRIND_RESTORE_STACK                                        \
+         VALGRIND_CFI_EPILOGUE                                         \
+         : /*out*/   "=a" (_res)                                       \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER                 \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \
+      );                                                               \
+      lval = (__typeof__(lval)) _res;                                  \
+   } while (0)
+
+#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6)        \
+   do {                                                                \
+      volatile OrigFn        _orig = (orig);                           \
+      volatile unsigned long _argvec[7];                               \
+      volatile unsigned long _res;                                     \
+      _argvec[0] = (unsigned long)_orig.nraddr;                        \
+      _argvec[1] = (unsigned long)(arg1);                              \
+      _argvec[2] = (unsigned long)(arg2);                              \
+      _argvec[3] = (unsigned long)(arg3);                              \
+      _argvec[4] = (unsigned long)(arg4);                              \
+      _argvec[5] = (unsigned long)(arg5);                              \
+      _argvec[6] = (unsigned long)(arg6);                              \
+      __asm__ volatile(                                                \
+         VALGRIND_CFI_PROLOGUE                                         \
+         VALGRIND_ALIGN_STACK                                          \
+         "subq $128,%%rsp\n\t"                                         \
+         "movq 48(%%rax), %%r9\n\t"                                    \
+         "movq 40(%%rax), %%r8\n\t"                                    \
+         "movq 32(%%rax), %%rcx\n\t"                                   \
+         "movq 24(%%rax), %%rdx\n\t"                                   \
+         "movq 16(%%rax), %%rsi\n\t"                                   \
+         "movq 8(%%rax), %%rdi\n\t"                                    \
+         "movq (%%rax), %%rax\n\t"  /* target->%rax */                 \
+         VALGRIND_CALL_NOREDIR_RAX                                     \
+         VALGRIND_RESTORE_STACK                                        \
+         VALGRIND_CFI_EPILOGUE                                         \
+         : /*out*/   "=a" (_res)                                       \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER                 \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \
+      );                                                               \
+      lval = (__typeof__(lval)) _res;                                  \
+   } while (0)
+
+#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,        \
+                                 arg7)                                 \
+   do {                                                                \
+      volatile OrigFn        _orig = (orig);                           \
+      volatile unsigned long _argvec[8];                               \
+      volatile unsigned long _res;                                     \
+      _argvec[0] = (unsigned long)_orig.nraddr;                        \
+      _argvec[1] = (unsigned long)(arg1);                              \
+      _argvec[2] = (unsigned long)(arg2);                              \
+      _argvec[3] = (unsigned long)(arg3);                              \
+      _argvec[4] = (unsigned long)(arg4);                              \
+      _argvec[5] = (unsigned long)(arg5);                              \
+      _argvec[6] = (unsigned long)(arg6);                              \
+      _argvec[7] = (unsigned long)(arg7);                              \
+      __asm__ volatile(                                                \
+         VALGRIND_CFI_PROLOGUE                                         \
+         VALGRIND_ALIGN_STACK                                          \
+         "subq $136,%%rsp\n\t"                                         \
+         "pushq 56(%%rax)\n\t"                                         \
+         "movq 48(%%rax), %%r9\n\t"                                    \
+         "movq 40(%%rax), %%r8\n\t"                                    \
+         "movq 32(%%rax), %%rcx\n\t"                                   \
+         "movq 24(%%rax), %%rdx\n\t"                                   \
+         "movq 16(%%rax), %%rsi\n\t"                                   \
+         "movq 8(%%rax), %%rdi\n\t"                                    \
+         "movq (%%rax), %%rax\n\t"  /* target->%rax */                 \
+         VALGRIND_CALL_NOREDIR_RAX                                     \
+         VALGRIND_RESTORE_STACK                                        \
+         VALGRIND_CFI_EPILOGUE                                         \
+         : /*out*/   "=a" (_res)                                       \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER                 \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \
+      );                                                               \
+      lval = (__typeof__(lval)) _res;                                  \
+   } while (0)
+
+#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,        \
+                                 arg7,arg8)                            \
+   do {                                                                \
+      volatile OrigFn        _orig = (orig);                           \
+      volatile unsigned long _argvec[9];                               \
+      volatile unsigned long _res;                                     \
+      _argvec[0] = (unsigned long)_orig.nraddr;                        \
+      _argvec[1] = (unsigned long)(arg1);                              \
+      _argvec[2] = (unsigned long)(arg2);                              \
+      _argvec[3] = (unsigned long)(arg3);                              \
+      _argvec[4] = (unsigned long)(arg4);                              \
+      _argvec[5] = (unsigned long)(arg5);                              \
+      _argvec[6] = (unsigned long)(arg6);                              \
+      _argvec[7] = (unsigned long)(arg7);                              \
+      _argvec[8] = (unsigned long)(arg8);                              \
+      __asm__ volatile(                                                \
+         VALGRIND_CFI_PROLOGUE                                         \
+         VALGRIND_ALIGN_STACK                                          \
+         "subq $128,%%rsp\n\t"                                         \
+         "pushq 64(%%rax)\n\t"                                         \
+         "pushq 56(%%rax)\n\t"                                         \
+         "movq 48(%%rax), %%r9\n\t"                                    \
+         "movq 40(%%rax), %%r8\n\t"                                    \
+         "movq 32(%%rax), %%rcx\n\t"                                   \
+         "movq 24(%%rax), %%rdx\n\t"                                   \
+         "movq 16(%%rax), %%rsi\n\t"                                   \
+         "movq 8(%%rax), %%rdi\n\t"                                    \
+         "movq (%%rax), %%rax\n\t"  /* target->%rax */                 \
+         VALGRIND_CALL_NOREDIR_RAX                                     \
+         VALGRIND_RESTORE_STACK                                        \
+         VALGRIND_CFI_EPILOGUE                                         \
+         : /*out*/   "=a" (_res)                                       \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER                 \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \
+      );                                                               \
+      lval = (__typeof__(lval)) _res;                                  \
+   } while (0)
+
+#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,        \
+                                 arg7,arg8,arg9)                       \
+   do {                                                                \
+      volatile OrigFn        _orig = (orig);                           \
+      volatile unsigned long _argvec[10];                              \
+      volatile unsigned long _res;                                     \
+      _argvec[0] = (unsigned long)_orig.nraddr;                        \
+      _argvec[1] = (unsigned long)(arg1);                              \
+      _argvec[2] = (unsigned long)(arg2);                              \
+      _argvec[3] = (unsigned long)(arg3);                              \
+      _argvec[4] = (unsigned long)(arg4);                              \
+      _argvec[5] = (unsigned long)(arg5);                              \
+      _argvec[6] = (unsigned long)(arg6);                              \
+      _argvec[7] = (unsigned long)(arg7);                              \
+      _argvec[8] = (unsigned long)(arg8);                              \
+      _argvec[9] = (unsigned long)(arg9);                              \
+      __asm__ volatile(                                                \
+         VALGRIND_CFI_PROLOGUE                                         \
+         VALGRIND_ALIGN_STACK                                          \
+         "subq $136,%%rsp\n\t"                                         \
+         "pushq 72(%%rax)\n\t"                                         \
+         "pushq 64(%%rax)\n\t"                                         \
+         "pushq 56(%%rax)\n\t"                                         \
+         "movq 48(%%rax), %%r9\n\t"                                    \
+         "movq 40(%%rax), %%r8\n\t"                                    \
+         "movq 32(%%rax), %%rcx\n\t"                                   \
+         "movq 24(%%rax), %%rdx\n\t"                                   \
+         "movq 16(%%rax), %%rsi\n\t"                                   \
+         "movq 8(%%rax), %%rdi\n\t"                                    \
+         "movq (%%rax), %%rax\n\t"  /* target->%rax */                 \
+         VALGRIND_CALL_NOREDIR_RAX                                     \
+         VALGRIND_RESTORE_STACK                                        \
+         VALGRIND_CFI_EPILOGUE                                         \
+         : /*out*/   "=a" (_res)                                       \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER                 \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \
+      );                                                               \
+      lval = (__typeof__(lval)) _res;                                  \
+   } while (0)
+
+#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,       \
+                                  arg7,arg8,arg9,arg10)                \
+   do {                                                                \
+      volatile OrigFn        _orig = (orig);                           \
+      volatile unsigned long _argvec[11];                              \
+      volatile unsigned long _res;                                     \
+      _argvec[0] = (unsigned long)_orig.nraddr;                        \
+      _argvec[1] = (unsigned long)(arg1);                              \
+      _argvec[2] = (unsigned long)(arg2);                              \
+      _argvec[3] = (unsigned long)(arg3);                              \
+      _argvec[4] = (unsigned long)(arg4);                              \
+      _argvec[5] = (unsigned long)(arg5);                              \
+      _argvec[6] = (unsigned long)(arg6);                              \
+      _argvec[7] = (unsigned long)(arg7);                              \
+      _argvec[8] = (unsigned long)(arg8);                              \
+      _argvec[9] = (unsigned long)(arg9);                              \
+      _argvec[10] = (unsigned long)(arg10);                            \
+      __asm__ volatile(                                                \
+         VALGRIND_CFI_PROLOGUE                                         \
+         VALGRIND_ALIGN_STACK                                          \
+         "subq $128,%%rsp\n\t"                                         \
+         "pushq 80(%%rax)\n\t"                                         \
+         "pushq 72(%%rax)\n\t"                                         \
+         "pushq 64(%%rax)\n\t"                                         \
+         "pushq 56(%%rax)\n\t"                                         \
+         "movq 48(%%rax), %%r9\n\t"                                    \
+         "movq 40(%%rax), %%r8\n\t"                                    \
+         "movq 32(%%rax), %%rcx\n\t"                                   \
+         "movq 24(%%rax), %%rdx\n\t"                                   \
+         "movq 16(%%rax), %%rsi\n\t"                                   \
+         "movq 8(%%rax), %%rdi\n\t"                                    \
+         "movq (%%rax), %%rax\n\t"  /* target->%rax */                 \
+         VALGRIND_CALL_NOREDIR_RAX                                     \
+         VALGRIND_RESTORE_STACK                                        \
+         VALGRIND_CFI_EPILOGUE                                         \
+         : /*out*/   "=a" (_res)                                       \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER                 \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \
+      );                                                               \
+      lval = (__typeof__(lval)) _res;                                  \
+   } while (0)
+
+#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,       \
+                                  arg7,arg8,arg9,arg10,arg11)          \
+   do {                                                                \
+      volatile OrigFn        _orig = (orig);                           \
+      volatile unsigned long _argvec[12];                              \
+      volatile unsigned long _res;                                     \
+      _argvec[0] = (unsigned long)_orig.nraddr;                        \
+      _argvec[1] = (unsigned long)(arg1);                              \
+      _argvec[2] = (unsigned long)(arg2);                              \
+      _argvec[3] = (unsigned long)(arg3);                              \
+      _argvec[4] = (unsigned long)(arg4);                              \
+      _argvec[5] = (unsigned long)(arg5);                              \
+      _argvec[6] = (unsigned long)(arg6);                              \
+      _argvec[7] = (unsigned long)(arg7);                              \
+      _argvec[8] = (unsigned long)(arg8);                              \
+      _argvec[9] = (unsigned long)(arg9);                              \
+      _argvec[10] = (unsigned long)(arg10);                            \
+      _argvec[11] = (unsigned long)(arg11);                            \
+      __asm__ volatile(                                                \
+         VALGRIND_CFI_PROLOGUE                                         \
+         VALGRIND_ALIGN_STACK                                          \
+         "subq $136,%%rsp\n\t"                                         \
+         "pushq 88(%%rax)\n\t"                                         \
+         "pushq 80(%%rax)\n\t"                                         \
+         "pushq 72(%%rax)\n\t"                                         \
+         "pushq 64(%%rax)\n\t"                                         \
+         "pushq 56(%%rax)\n\t"                                         \
+         "movq 48(%%rax), %%r9\n\t"                                    \
+         "movq 40(%%rax), %%r8\n\t"                                    \
+         "movq 32(%%rax), %%rcx\n\t"                                   \
+         "movq 24(%%rax), %%rdx\n\t"                                   \
+         "movq 16(%%rax), %%rsi\n\t"                                   \
+         "movq 8(%%rax), %%rdi\n\t"                                    \
+         "movq (%%rax), %%rax\n\t"  /* target->%rax */                 \
+         VALGRIND_CALL_NOREDIR_RAX                                     \
+         VALGRIND_RESTORE_STACK                                        \
+         VALGRIND_CFI_EPILOGUE                                         \
+         : /*out*/   "=a" (_res)                                       \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER                 \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \
+      );                                                               \
+      lval = (__typeof__(lval)) _res;                                  \
+   } while (0)
+
+#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,       \
+                                arg7,arg8,arg9,arg10,arg11,arg12)      \
+   do {                                                                \
+      volatile OrigFn        _orig = (orig);                           \
+      volatile unsigned long _argvec[13];                              \
+      volatile unsigned long _res;                                     \
+      _argvec[0] = (unsigned long)_orig.nraddr;                        \
+      _argvec[1] = (unsigned long)(arg1);                              \
+      _argvec[2] = (unsigned long)(arg2);                              \
+      _argvec[3] = (unsigned long)(arg3);                              \
+      _argvec[4] = (unsigned long)(arg4);                              \
+      _argvec[5] = (unsigned long)(arg5);                              \
+      _argvec[6] = (unsigned long)(arg6);                              \
+      _argvec[7] = (unsigned long)(arg7);                              \
+      _argvec[8] = (unsigned long)(arg8);                              \
+      _argvec[9] = (unsigned long)(arg9);                              \
+      _argvec[10] = (unsigned long)(arg10);                            \
+      _argvec[11] = (unsigned long)(arg11);                            \
+      _argvec[12] = (unsigned long)(arg12);                            \
+      __asm__ volatile(                                                \
+         VALGRIND_CFI_PROLOGUE                                         \
+         VALGRIND_ALIGN_STACK                                          \
+         "subq $128,%%rsp\n\t"                                         \
+         "pushq 96(%%rax)\n\t"                                         \
+         "pushq 88(%%rax)\n\t"                                         \
+         "pushq 80(%%rax)\n\t"                                         \
+         "pushq 72(%%rax)\n\t"                                         \
+         "pushq 64(%%rax)\n\t"                                         \
+         "pushq 56(%%rax)\n\t"                                         \
+         "movq 48(%%rax), %%r9\n\t"                                    \
+         "movq 40(%%rax), %%r8\n\t"                                    \
+         "movq 32(%%rax), %%rcx\n\t"                                   \
+         "movq 24(%%rax), %%rdx\n\t"                                   \
+         "movq 16(%%rax), %%rsi\n\t"                                   \
+         "movq 8(%%rax), %%rdi\n\t"                                    \
+         "movq (%%rax), %%rax\n\t"  /* target->%rax */                 \
+         VALGRIND_CALL_NOREDIR_RAX                                     \
+         VALGRIND_RESTORE_STACK                                        \
+         VALGRIND_CFI_EPILOGUE                                         \
+         : /*out*/   "=a" (_res)                                       \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER                 \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \
+      );                                                               \
+      lval = (__typeof__(lval)) _res;                                  \
+   } while (0)
+
+#endif /* PLAT_amd64_linux || PLAT_amd64_darwin || PLAT_amd64_solaris */
+
+/* ------------------------ ppc32-linux ------------------------ */
+
+#if defined(PLAT_ppc32_linux)
+
+/* This is useful for finding out about the on-stack stuff:
+
+   extern int f9  ( int,int,int,int,int,int,int,int,int );
+   extern int f10 ( int,int,int,int,int,int,int,int,int,int );
+   extern int f11 ( int,int,int,int,int,int,int,int,int,int,int );
+   extern int f12 ( int,int,int,int,int,int,int,int,int,int,int,int );
+
+   int g9 ( void ) {
+      return f9(11,22,33,44,55,66,77,88,99);
+   }
+   int g10 ( void ) {
+      return f10(11,22,33,44,55,66,77,88,99,110);
+   }
+   int g11 ( void ) {
+      return f11(11,22,33,44,55,66,77,88,99,110,121);
+   }
+   int g12 ( void ) {
+      return f12(11,22,33,44,55,66,77,88,99,110,121,132);
+   }
+*/
+
+/* ARGREGS: r3 r4 r5 r6 r7 r8 r9 r10 (the rest on stack somewhere) */
+
+/* These regs are trashed by the hidden call. */
+#define __CALLER_SAVED_REGS                                       \
+   "lr", "ctr", "xer",                                            \
+   "cr0", "cr1", "cr2", "cr3", "cr4", "cr5", "cr6", "cr7",        \
+   "r0", "r2", "r3", "r4", "r5", "r6", "r7", "r8", "r9", "r10",   \
+   "r11", "r12", "r13"
+
+/* Macros to save and align the stack before making a function
+   call, and to restore it afterwards, as gcc may not keep the stack
+   pointer aligned if it doesn't realise calls are being made
+   to other functions. */
+
+#define VALGRIND_ALIGN_STACK               \
+      "mr 28,1\n\t"                        \
+      "rlwinm 1,1,0,0,27\n\t"
+#define VALGRIND_RESTORE_STACK             \
+      "mr 1,28\n\t"
+
+/* These CALL_FN_ macros assume that on ppc32-linux, 
+   sizeof(unsigned long) == 4. */
+
+#define CALL_FN_W_v(lval, orig)                                   \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[1];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"                                           \
+         "lwz 11,0(11)\n\t"  /* target->r11 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         VALGRIND_RESTORE_STACK                                   \
+         "mr %0,3"                                                \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_W(lval, orig, arg1)                             \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[2];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)arg1;                           \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"                                           \
+         "lwz 3,4(11)\n\t"   /* arg1->r3 */                       \
+         "lwz 11,0(11)\n\t"  /* target->r11 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         VALGRIND_RESTORE_STACK                                   \
+         "mr %0,3"                                                \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_WW(lval, orig, arg1,arg2)                       \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)arg1;                           \
+      _argvec[2] = (unsigned long)arg2;                           \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"                                           \
+         "lwz 3,4(11)\n\t"   /* arg1->r3 */                       \
+         "lwz 4,8(11)\n\t"                                        \
+         "lwz 11,0(11)\n\t"  /* target->r11 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         VALGRIND_RESTORE_STACK                                   \
+         "mr %0,3"                                                \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3)                 \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[4];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)arg1;                           \
+      _argvec[2] = (unsigned long)arg2;                           \
+      _argvec[3] = (unsigned long)arg3;                           \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"                                           \
+         "lwz 3,4(11)\n\t"   /* arg1->r3 */                       \
+         "lwz 4,8(11)\n\t"                                        \
+         "lwz 5,12(11)\n\t"                                       \
+         "lwz 11,0(11)\n\t"  /* target->r11 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         VALGRIND_RESTORE_STACK                                   \
+         "mr %0,3"                                                \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4)           \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[5];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)arg1;                           \
+      _argvec[2] = (unsigned long)arg2;                           \
+      _argvec[3] = (unsigned long)arg3;                           \
+      _argvec[4] = (unsigned long)arg4;                           \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"                                           \
+         "lwz 3,4(11)\n\t"   /* arg1->r3 */                       \
+         "lwz 4,8(11)\n\t"                                        \
+         "lwz 5,12(11)\n\t"                                       \
+         "lwz 6,16(11)\n\t"  /* arg4->r6 */                       \
+         "lwz 11,0(11)\n\t"  /* target->r11 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         VALGRIND_RESTORE_STACK                                   \
+         "mr %0,3"                                                \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5)        \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[6];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)arg1;                           \
+      _argvec[2] = (unsigned long)arg2;                           \
+      _argvec[3] = (unsigned long)arg3;                           \
+      _argvec[4] = (unsigned long)arg4;                           \
+      _argvec[5] = (unsigned long)arg5;                           \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"                                           \
+         "lwz 3,4(11)\n\t"   /* arg1->r3 */                       \
+         "lwz 4,8(11)\n\t"                                        \
+         "lwz 5,12(11)\n\t"                                       \
+         "lwz 6,16(11)\n\t"  /* arg4->r6 */                       \
+         "lwz 7,20(11)\n\t"                                       \
+         "lwz 11,0(11)\n\t"  /* target->r11 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         VALGRIND_RESTORE_STACK                                   \
+         "mr %0,3"                                                \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6)   \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[7];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)arg1;                           \
+      _argvec[2] = (unsigned long)arg2;                           \
+      _argvec[3] = (unsigned long)arg3;                           \
+      _argvec[4] = (unsigned long)arg4;                           \
+      _argvec[5] = (unsigned long)arg5;                           \
+      _argvec[6] = (unsigned long)arg6;                           \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"                                           \
+         "lwz 3,4(11)\n\t"   /* arg1->r3 */                       \
+         "lwz 4,8(11)\n\t"                                        \
+         "lwz 5,12(11)\n\t"                                       \
+         "lwz 6,16(11)\n\t"  /* arg4->r6 */                       \
+         "lwz 7,20(11)\n\t"                                       \
+         "lwz 8,24(11)\n\t"                                       \
+         "lwz 11,0(11)\n\t"  /* target->r11 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         VALGRIND_RESTORE_STACK                                   \
+         "mr %0,3"                                                \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7)                            \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[8];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)arg1;                           \
+      _argvec[2] = (unsigned long)arg2;                           \
+      _argvec[3] = (unsigned long)arg3;                           \
+      _argvec[4] = (unsigned long)arg4;                           \
+      _argvec[5] = (unsigned long)arg5;                           \
+      _argvec[6] = (unsigned long)arg6;                           \
+      _argvec[7] = (unsigned long)arg7;                           \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"                                           \
+         "lwz 3,4(11)\n\t"   /* arg1->r3 */                       \
+         "lwz 4,8(11)\n\t"                                        \
+         "lwz 5,12(11)\n\t"                                       \
+         "lwz 6,16(11)\n\t"  /* arg4->r6 */                       \
+         "lwz 7,20(11)\n\t"                                       \
+         "lwz 8,24(11)\n\t"                                       \
+         "lwz 9,28(11)\n\t"                                       \
+         "lwz 11,0(11)\n\t"  /* target->r11 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         VALGRIND_RESTORE_STACK                                   \
+         "mr %0,3"                                                \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7,arg8)                       \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[9];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)arg1;                           \
+      _argvec[2] = (unsigned long)arg2;                           \
+      _argvec[3] = (unsigned long)arg3;                           \
+      _argvec[4] = (unsigned long)arg4;                           \
+      _argvec[5] = (unsigned long)arg5;                           \
+      _argvec[6] = (unsigned long)arg6;                           \
+      _argvec[7] = (unsigned long)arg7;                           \
+      _argvec[8] = (unsigned long)arg8;                           \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"                                           \
+         "lwz 3,4(11)\n\t"   /* arg1->r3 */                       \
+         "lwz 4,8(11)\n\t"                                        \
+         "lwz 5,12(11)\n\t"                                       \
+         "lwz 6,16(11)\n\t"  /* arg4->r6 */                       \
+         "lwz 7,20(11)\n\t"                                       \
+         "lwz 8,24(11)\n\t"                                       \
+         "lwz 9,28(11)\n\t"                                       \
+         "lwz 10,32(11)\n\t" /* arg8->r10 */                      \
+         "lwz 11,0(11)\n\t"  /* target->r11 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         VALGRIND_RESTORE_STACK                                   \
+         "mr %0,3"                                                \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7,arg8,arg9)                  \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[10];                         \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)arg1;                           \
+      _argvec[2] = (unsigned long)arg2;                           \
+      _argvec[3] = (unsigned long)arg3;                           \
+      _argvec[4] = (unsigned long)arg4;                           \
+      _argvec[5] = (unsigned long)arg5;                           \
+      _argvec[6] = (unsigned long)arg6;                           \
+      _argvec[7] = (unsigned long)arg7;                           \
+      _argvec[8] = (unsigned long)arg8;                           \
+      _argvec[9] = (unsigned long)arg9;                           \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"                                           \
+         "addi 1,1,-16\n\t"                                       \
+         /* arg9 */                                               \
+         "lwz 3,36(11)\n\t"                                       \
+         "stw 3,8(1)\n\t"                                         \
+         /* args1-8 */                                            \
+         "lwz 3,4(11)\n\t"   /* arg1->r3 */                       \
+         "lwz 4,8(11)\n\t"                                        \
+         "lwz 5,12(11)\n\t"                                       \
+         "lwz 6,16(11)\n\t"  /* arg4->r6 */                       \
+         "lwz 7,20(11)\n\t"                                       \
+         "lwz 8,24(11)\n\t"                                       \
+         "lwz 9,28(11)\n\t"                                       \
+         "lwz 10,32(11)\n\t" /* arg8->r10 */                      \
+         "lwz 11,0(11)\n\t"  /* target->r11 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         VALGRIND_RESTORE_STACK                                   \
+         "mr %0,3"                                                \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,  \
+                                  arg7,arg8,arg9,arg10)           \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[11];                         \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)arg1;                           \
+      _argvec[2] = (unsigned long)arg2;                           \
+      _argvec[3] = (unsigned long)arg3;                           \
+      _argvec[4] = (unsigned long)arg4;                           \
+      _argvec[5] = (unsigned long)arg5;                           \
+      _argvec[6] = (unsigned long)arg6;                           \
+      _argvec[7] = (unsigned long)arg7;                           \
+      _argvec[8] = (unsigned long)arg8;                           \
+      _argvec[9] = (unsigned long)arg9;                           \
+      _argvec[10] = (unsigned long)arg10;                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"                                           \
+         "addi 1,1,-16\n\t"                                       \
+         /* arg10 */                                              \
+         "lwz 3,40(11)\n\t"                                       \
+         "stw 3,12(1)\n\t"                                        \
+         /* arg9 */                                               \
+         "lwz 3,36(11)\n\t"                                       \
+         "stw 3,8(1)\n\t"                                         \
+         /* args1-8 */                                            \
+         "lwz 3,4(11)\n\t"   /* arg1->r3 */                       \
+         "lwz 4,8(11)\n\t"                                        \
+         "lwz 5,12(11)\n\t"                                       \
+         "lwz 6,16(11)\n\t"  /* arg4->r6 */                       \
+         "lwz 7,20(11)\n\t"                                       \
+         "lwz 8,24(11)\n\t"                                       \
+         "lwz 9,28(11)\n\t"                                       \
+         "lwz 10,32(11)\n\t" /* arg8->r10 */                      \
+         "lwz 11,0(11)\n\t"  /* target->r11 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         VALGRIND_RESTORE_STACK                                   \
+         "mr %0,3"                                                \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,  \
+                                  arg7,arg8,arg9,arg10,arg11)     \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[12];                         \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)arg1;                           \
+      _argvec[2] = (unsigned long)arg2;                           \
+      _argvec[3] = (unsigned long)arg3;                           \
+      _argvec[4] = (unsigned long)arg4;                           \
+      _argvec[5] = (unsigned long)arg5;                           \
+      _argvec[6] = (unsigned long)arg6;                           \
+      _argvec[7] = (unsigned long)arg7;                           \
+      _argvec[8] = (unsigned long)arg8;                           \
+      _argvec[9] = (unsigned long)arg9;                           \
+      _argvec[10] = (unsigned long)arg10;                         \
+      _argvec[11] = (unsigned long)arg11;                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"                                           \
+         "addi 1,1,-32\n\t"                                       \
+         /* arg11 */                                              \
+         "lwz 3,44(11)\n\t"                                       \
+         "stw 3,16(1)\n\t"                                        \
+         /* arg10 */                                              \
+         "lwz 3,40(11)\n\t"                                       \
+         "stw 3,12(1)\n\t"                                        \
+         /* arg9 */                                               \
+         "lwz 3,36(11)\n\t"                                       \
+         "stw 3,8(1)\n\t"                                         \
+         /* args1-8 */                                            \
+         "lwz 3,4(11)\n\t"   /* arg1->r3 */                       \
+         "lwz 4,8(11)\n\t"                                        \
+         "lwz 5,12(11)\n\t"                                       \
+         "lwz 6,16(11)\n\t"  /* arg4->r6 */                       \
+         "lwz 7,20(11)\n\t"                                       \
+         "lwz 8,24(11)\n\t"                                       \
+         "lwz 9,28(11)\n\t"                                       \
+         "lwz 10,32(11)\n\t" /* arg8->r10 */                      \
+         "lwz 11,0(11)\n\t"  /* target->r11 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         VALGRIND_RESTORE_STACK                                   \
+         "mr %0,3"                                                \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,  \
+                                arg7,arg8,arg9,arg10,arg11,arg12) \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[13];                         \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)arg1;                           \
+      _argvec[2] = (unsigned long)arg2;                           \
+      _argvec[3] = (unsigned long)arg3;                           \
+      _argvec[4] = (unsigned long)arg4;                           \
+      _argvec[5] = (unsigned long)arg5;                           \
+      _argvec[6] = (unsigned long)arg6;                           \
+      _argvec[7] = (unsigned long)arg7;                           \
+      _argvec[8] = (unsigned long)arg8;                           \
+      _argvec[9] = (unsigned long)arg9;                           \
+      _argvec[10] = (unsigned long)arg10;                         \
+      _argvec[11] = (unsigned long)arg11;                         \
+      _argvec[12] = (unsigned long)arg12;                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"                                           \
+         "addi 1,1,-32\n\t"                                       \
+         /* arg12 */                                              \
+         "lwz 3,48(11)\n\t"                                       \
+         "stw 3,20(1)\n\t"                                        \
+         /* arg11 */                                              \
+         "lwz 3,44(11)\n\t"                                       \
+         "stw 3,16(1)\n\t"                                        \
+         /* arg10 */                                              \
+         "lwz 3,40(11)\n\t"                                       \
+         "stw 3,12(1)\n\t"                                        \
+         /* arg9 */                                               \
+         "lwz 3,36(11)\n\t"                                       \
+         "stw 3,8(1)\n\t"                                         \
+         /* args1-8 */                                            \
+         "lwz 3,4(11)\n\t"   /* arg1->r3 */                       \
+         "lwz 4,8(11)\n\t"                                        \
+         "lwz 5,12(11)\n\t"                                       \
+         "lwz 6,16(11)\n\t"  /* arg4->r6 */                       \
+         "lwz 7,20(11)\n\t"                                       \
+         "lwz 8,24(11)\n\t"                                       \
+         "lwz 9,28(11)\n\t"                                       \
+         "lwz 10,32(11)\n\t" /* arg8->r10 */                      \
+         "lwz 11,0(11)\n\t"  /* target->r11 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         VALGRIND_RESTORE_STACK                                   \
+         "mr %0,3"                                                \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#endif /* PLAT_ppc32_linux */
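+
+/* Editor's note (a usage sketch, not part of the macro definitions
+   themselves): in a client source file that #includes this header,
+   the CALL_FN_W_* macros defined here (for whichever platform is
+   selected) are normally used from inside a Valgrind function
+   wrapper, along the lines of the standard wrapping idiom below.
+   The wrapped function 'foo', its two int arguments and the NONE
+   soname tag are illustrative only.
+
+      int I_WRAP_SONAME_FNNAME_ZU(NONE, foo)(int x, int y)
+      {
+         int    r;
+         OrigFn fn;
+         VALGRIND_GET_ORIG_FN(fn);   // look up the real foo
+         CALL_FN_W_WW(r, fn, x, y);  // call it without re-redirection
+         return r;
+      }
+*/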
+
+/* ------------------------ ppc64-linux ------------------------ */
+
+#if defined(PLAT_ppc64be_linux)
+
+/* ARGREGS: r3 r4 r5 r6 r7 r8 r9 r10 (the rest on stack somewhere) */
+
+/* These regs are trashed by the hidden call. */
+#define __CALLER_SAVED_REGS                                       \
+   "lr", "ctr", "xer",                                            \
+   "cr0", "cr1", "cr2", "cr3", "cr4", "cr5", "cr6", "cr7",        \
+   "r0", "r3", "r4", "r5", "r6", "r7", "r8", "r9", "r10",         \
+   "r11", "r12", "r13"
+
+/* Macros to save and align the stack before making a function
+   call and restore it afterwards as gcc may not keep the stack
+   pointer aligned if it doesn't realise calls are being made
+   to other functions. */
+
+#define VALGRIND_ALIGN_STACK               \
+      "mr 28,1\n\t"                        \
+      "rldicr 1,1,0,59\n\t"
+#define VALGRIND_RESTORE_STACK             \
+      "mr 1,28\n\t"
+
+/* These CALL_FN_ macros assume that on ppc64-linux, sizeof(unsigned
+   long) == 8. */
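+
+/* Editor's note (a sketch of the argument-block layout assumed by the
+   ppc64be macros below, derived from the code itself): each asm block
+   is handed r11 = &_argvec[2], so with 8-byte slots
+
+      _argvec[0]  at -16(11)   caller's r2 (TOC), saved across the call
+      _argvec[1]  at  -8(11)   wrappee's TOC pointer (_orig.r2)
+      _argvec[2]  at   0(11)   target address (_orig.nraddr)
+      _argvec[3]  at   8(11)   arg1, loaded into r3
+      _argvec[4]  at  16(11)   arg2, loaded into r4, and so on.
+*/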
+
+#define CALL_FN_W_v(lval, orig)                                   \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+0];                        \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1] = (unsigned long)_orig.r2;                       \
+      _argvec[2] = (unsigned long)_orig.nraddr;                   \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"                                           \
+         "std 2,-16(11)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(11)\n\t"  /* use nraddr's tocptr */           \
+         "ld  11, 0(11)\n\t"  /* target->r11 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         "mr 11,%1\n\t"                                           \
+         "mr %0,3\n\t"                                            \
+         "ld 2,-16(11)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_W(lval, orig, arg1)                             \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+1];                        \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2;                     \
+      _argvec[2]   = (unsigned long)_orig.nraddr;                 \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"                                           \
+         "std 2,-16(11)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(11)\n\t"  /* use nraddr's tocptr */           \
+         "ld   3, 8(11)\n\t"  /* arg1->r3 */                      \
+         "ld  11, 0(11)\n\t"  /* target->r11 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         "mr 11,%1\n\t"                                           \
+         "mr %0,3\n\t"                                            \
+         "ld 2,-16(11)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_WW(lval, orig, arg1,arg2)                       \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+2];                        \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2;                     \
+      _argvec[2]   = (unsigned long)_orig.nraddr;                 \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      _argvec[2+2] = (unsigned long)arg2;                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"                                           \
+         "std 2,-16(11)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(11)\n\t"  /* use nraddr's tocptr */           \
+         "ld   3, 8(11)\n\t"  /* arg1->r3 */                      \
+         "ld   4, 16(11)\n\t" /* arg2->r4 */                      \
+         "ld  11, 0(11)\n\t"  /* target->r11 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         "mr 11,%1\n\t"                                           \
+         "mr %0,3\n\t"                                            \
+         "ld 2,-16(11)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3)                 \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+3];                        \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2;                     \
+      _argvec[2]   = (unsigned long)_orig.nraddr;                 \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      _argvec[2+2] = (unsigned long)arg2;                         \
+      _argvec[2+3] = (unsigned long)arg3;                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"                                           \
+         "std 2,-16(11)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(11)\n\t"  /* use nraddr's tocptr */           \
+         "ld   3, 8(11)\n\t"  /* arg1->r3 */                      \
+         "ld   4, 16(11)\n\t" /* arg2->r4 */                      \
+         "ld   5, 24(11)\n\t" /* arg3->r5 */                      \
+         "ld  11, 0(11)\n\t"  /* target->r11 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         "mr 11,%1\n\t"                                           \
+         "mr %0,3\n\t"                                            \
+         "ld 2,-16(11)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4)           \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+4];                        \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2;                     \
+      _argvec[2]   = (unsigned long)_orig.nraddr;                 \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      _argvec[2+2] = (unsigned long)arg2;                         \
+      _argvec[2+3] = (unsigned long)arg3;                         \
+      _argvec[2+4] = (unsigned long)arg4;                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"                                           \
+         "std 2,-16(11)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(11)\n\t"  /* use nraddr's tocptr */           \
+         "ld   3, 8(11)\n\t"  /* arg1->r3 */                      \
+         "ld   4, 16(11)\n\t" /* arg2->r4 */                      \
+         "ld   5, 24(11)\n\t" /* arg3->r5 */                      \
+         "ld   6, 32(11)\n\t" /* arg4->r6 */                      \
+         "ld  11, 0(11)\n\t"  /* target->r11 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         "mr 11,%1\n\t"                                           \
+         "mr %0,3\n\t"                                            \
+         "ld 2,-16(11)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5)        \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+5];                        \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2;                     \
+      _argvec[2]   = (unsigned long)_orig.nraddr;                 \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      _argvec[2+2] = (unsigned long)arg2;                         \
+      _argvec[2+3] = (unsigned long)arg3;                         \
+      _argvec[2+4] = (unsigned long)arg4;                         \
+      _argvec[2+5] = (unsigned long)arg5;                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"                                           \
+         "std 2,-16(11)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(11)\n\t"  /* use nraddr's tocptr */           \
+         "ld   3, 8(11)\n\t"  /* arg1->r3 */                      \
+         "ld   4, 16(11)\n\t" /* arg2->r4 */                      \
+         "ld   5, 24(11)\n\t" /* arg3->r5 */                      \
+         "ld   6, 32(11)\n\t" /* arg4->r6 */                      \
+         "ld   7, 40(11)\n\t" /* arg5->r7 */                      \
+         "ld  11, 0(11)\n\t"  /* target->r11 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         "mr 11,%1\n\t"                                           \
+         "mr %0,3\n\t"                                            \
+         "ld 2,-16(11)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6)   \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+6];                        \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2;                     \
+      _argvec[2]   = (unsigned long)_orig.nraddr;                 \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      _argvec[2+2] = (unsigned long)arg2;                         \
+      _argvec[2+3] = (unsigned long)arg3;                         \
+      _argvec[2+4] = (unsigned long)arg4;                         \
+      _argvec[2+5] = (unsigned long)arg5;                         \
+      _argvec[2+6] = (unsigned long)arg6;                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"                                           \
+         "std 2,-16(11)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(11)\n\t"  /* use nraddr's tocptr */           \
+         "ld   3, 8(11)\n\t"  /* arg1->r3 */                      \
+         "ld   4, 16(11)\n\t" /* arg2->r4 */                      \
+         "ld   5, 24(11)\n\t" /* arg3->r5 */                      \
+         "ld   6, 32(11)\n\t" /* arg4->r6 */                      \
+         "ld   7, 40(11)\n\t" /* arg5->r7 */                      \
+         "ld   8, 48(11)\n\t" /* arg6->r8 */                      \
+         "ld  11, 0(11)\n\t"  /* target->r11 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         "mr 11,%1\n\t"                                           \
+         "mr %0,3\n\t"                                            \
+         "ld 2,-16(11)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7)                            \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+7];                        \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2;                     \
+      _argvec[2]   = (unsigned long)_orig.nraddr;                 \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      _argvec[2+2] = (unsigned long)arg2;                         \
+      _argvec[2+3] = (unsigned long)arg3;                         \
+      _argvec[2+4] = (unsigned long)arg4;                         \
+      _argvec[2+5] = (unsigned long)arg5;                         \
+      _argvec[2+6] = (unsigned long)arg6;                         \
+      _argvec[2+7] = (unsigned long)arg7;                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"                                           \
+         "std 2,-16(11)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(11)\n\t"  /* use nraddr's tocptr */           \
+         "ld   3, 8(11)\n\t"  /* arg1->r3 */                      \
+         "ld   4, 16(11)\n\t" /* arg2->r4 */                      \
+         "ld   5, 24(11)\n\t" /* arg3->r5 */                      \
+         "ld   6, 32(11)\n\t" /* arg4->r6 */                      \
+         "ld   7, 40(11)\n\t" /* arg5->r7 */                      \
+         "ld   8, 48(11)\n\t" /* arg6->r8 */                      \
+         "ld   9, 56(11)\n\t" /* arg7->r9 */                      \
+         "ld  11, 0(11)\n\t"  /* target->r11 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         "mr 11,%1\n\t"                                           \
+         "mr %0,3\n\t"                                            \
+         "ld 2,-16(11)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7,arg8)                       \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+8];                        \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2;                     \
+      _argvec[2]   = (unsigned long)_orig.nraddr;                 \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      _argvec[2+2] = (unsigned long)arg2;                         \
+      _argvec[2+3] = (unsigned long)arg3;                         \
+      _argvec[2+4] = (unsigned long)arg4;                         \
+      _argvec[2+5] = (unsigned long)arg5;                         \
+      _argvec[2+6] = (unsigned long)arg6;                         \
+      _argvec[2+7] = (unsigned long)arg7;                         \
+      _argvec[2+8] = (unsigned long)arg8;                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"                                           \
+         "std 2,-16(11)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(11)\n\t"  /* use nraddr's tocptr */           \
+         "ld   3, 8(11)\n\t"  /* arg1->r3 */                      \
+         "ld   4, 16(11)\n\t" /* arg2->r4 */                      \
+         "ld   5, 24(11)\n\t" /* arg3->r5 */                      \
+         "ld   6, 32(11)\n\t" /* arg4->r6 */                      \
+         "ld   7, 40(11)\n\t" /* arg5->r7 */                      \
+         "ld   8, 48(11)\n\t" /* arg6->r8 */                      \
+         "ld   9, 56(11)\n\t" /* arg7->r9 */                      \
+         "ld  10, 64(11)\n\t" /* arg8->r10 */                     \
+         "ld  11, 0(11)\n\t"  /* target->r11 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         "mr 11,%1\n\t"                                           \
+         "mr %0,3\n\t"                                            \
+         "ld 2,-16(11)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7,arg8,arg9)                  \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+9];                        \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2;                     \
+      _argvec[2]   = (unsigned long)_orig.nraddr;                 \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      _argvec[2+2] = (unsigned long)arg2;                         \
+      _argvec[2+3] = (unsigned long)arg3;                         \
+      _argvec[2+4] = (unsigned long)arg4;                         \
+      _argvec[2+5] = (unsigned long)arg5;                         \
+      _argvec[2+6] = (unsigned long)arg6;                         \
+      _argvec[2+7] = (unsigned long)arg7;                         \
+      _argvec[2+8] = (unsigned long)arg8;                         \
+      _argvec[2+9] = (unsigned long)arg9;                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"                                           \
+         "std 2,-16(11)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(11)\n\t"  /* use nraddr's tocptr */           \
+         "addi 1,1,-128\n\t"  /* expand stack frame */            \
+         /* arg9 */                                               \
+         "ld  3,72(11)\n\t"                                       \
+         "std 3,112(1)\n\t"                                       \
+         /* args1-8 */                                            \
+         "ld   3, 8(11)\n\t"  /* arg1->r3 */                      \
+         "ld   4, 16(11)\n\t" /* arg2->r4 */                      \
+         "ld   5, 24(11)\n\t" /* arg3->r5 */                      \
+         "ld   6, 32(11)\n\t" /* arg4->r6 */                      \
+         "ld   7, 40(11)\n\t" /* arg5->r7 */                      \
+         "ld   8, 48(11)\n\t" /* arg6->r8 */                      \
+         "ld   9, 56(11)\n\t" /* arg7->r9 */                      \
+         "ld  10, 64(11)\n\t" /* arg8->r10 */                     \
+         "ld  11, 0(11)\n\t"  /* target->r11 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         "mr 11,%1\n\t"                                           \
+         "mr %0,3\n\t"                                            \
+         "ld 2,-16(11)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,  \
+                                  arg7,arg8,arg9,arg10)           \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+10];                       \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2;                     \
+      _argvec[2]   = (unsigned long)_orig.nraddr;                 \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      _argvec[2+2] = (unsigned long)arg2;                         \
+      _argvec[2+3] = (unsigned long)arg3;                         \
+      _argvec[2+4] = (unsigned long)arg4;                         \
+      _argvec[2+5] = (unsigned long)arg5;                         \
+      _argvec[2+6] = (unsigned long)arg6;                         \
+      _argvec[2+7] = (unsigned long)arg7;                         \
+      _argvec[2+8] = (unsigned long)arg8;                         \
+      _argvec[2+9] = (unsigned long)arg9;                         \
+      _argvec[2+10] = (unsigned long)arg10;                       \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"                                           \
+         "std 2,-16(11)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(11)\n\t"  /* use nraddr's tocptr */           \
+         "addi 1,1,-128\n\t"  /* expand stack frame */            \
+         /* arg10 */                                              \
+         "ld  3,80(11)\n\t"                                       \
+         "std 3,120(1)\n\t"                                       \
+         /* arg9 */                                               \
+         "ld  3,72(11)\n\t"                                       \
+         "std 3,112(1)\n\t"                                       \
+         /* args1-8 */                                            \
+         "ld   3, 8(11)\n\t"  /* arg1->r3 */                      \
+         "ld   4, 16(11)\n\t" /* arg2->r4 */                      \
+         "ld   5, 24(11)\n\t" /* arg3->r5 */                      \
+         "ld   6, 32(11)\n\t" /* arg4->r6 */                      \
+         "ld   7, 40(11)\n\t" /* arg5->r7 */                      \
+         "ld   8, 48(11)\n\t" /* arg6->r8 */                      \
+         "ld   9, 56(11)\n\t" /* arg7->r9 */                      \
+         "ld  10, 64(11)\n\t" /* arg8->r10 */                     \
+         "ld  11, 0(11)\n\t"  /* target->r11 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         "mr 11,%1\n\t"                                           \
+         "mr %0,3\n\t"                                            \
+         "ld 2,-16(11)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,  \
+                                  arg7,arg8,arg9,arg10,arg11)     \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+11];                       \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2;                     \
+      _argvec[2]   = (unsigned long)_orig.nraddr;                 \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      _argvec[2+2] = (unsigned long)arg2;                         \
+      _argvec[2+3] = (unsigned long)arg3;                         \
+      _argvec[2+4] = (unsigned long)arg4;                         \
+      _argvec[2+5] = (unsigned long)arg5;                         \
+      _argvec[2+6] = (unsigned long)arg6;                         \
+      _argvec[2+7] = (unsigned long)arg7;                         \
+      _argvec[2+8] = (unsigned long)arg8;                         \
+      _argvec[2+9] = (unsigned long)arg9;                         \
+      _argvec[2+10] = (unsigned long)arg10;                       \
+      _argvec[2+11] = (unsigned long)arg11;                       \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"                                           \
+         "std 2,-16(11)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(11)\n\t"  /* use nraddr's tocptr */           \
+         "addi 1,1,-144\n\t"  /* expand stack frame */            \
+         /* arg11 */                                              \
+         "ld  3,88(11)\n\t"                                       \
+         "std 3,128(1)\n\t"                                       \
+         /* arg10 */                                              \
+         "ld  3,80(11)\n\t"                                       \
+         "std 3,120(1)\n\t"                                       \
+         /* arg9 */                                               \
+         "ld  3,72(11)\n\t"                                       \
+         "std 3,112(1)\n\t"                                       \
+         /* args1-8 */                                            \
+         "ld   3, 8(11)\n\t"  /* arg1->r3 */                      \
+         "ld   4, 16(11)\n\t" /* arg2->r4 */                      \
+         "ld   5, 24(11)\n\t" /* arg3->r5 */                      \
+         "ld   6, 32(11)\n\t" /* arg4->r6 */                      \
+         "ld   7, 40(11)\n\t" /* arg5->r7 */                      \
+         "ld   8, 48(11)\n\t" /* arg6->r8 */                      \
+         "ld   9, 56(11)\n\t" /* arg7->r9 */                      \
+         "ld  10, 64(11)\n\t" /* arg8->r10 */                     \
+         "ld  11, 0(11)\n\t"  /* target->r11 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         "mr 11,%1\n\t"                                           \
+         "mr %0,3\n\t"                                            \
+         "ld 2,-16(11)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,  \
+                                arg7,arg8,arg9,arg10,arg11,arg12) \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+12];                       \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2;                     \
+      _argvec[2]   = (unsigned long)_orig.nraddr;                 \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      _argvec[2+2] = (unsigned long)arg2;                         \
+      _argvec[2+3] = (unsigned long)arg3;                         \
+      _argvec[2+4] = (unsigned long)arg4;                         \
+      _argvec[2+5] = (unsigned long)arg5;                         \
+      _argvec[2+6] = (unsigned long)arg6;                         \
+      _argvec[2+7] = (unsigned long)arg7;                         \
+      _argvec[2+8] = (unsigned long)arg8;                         \
+      _argvec[2+9] = (unsigned long)arg9;                         \
+      _argvec[2+10] = (unsigned long)arg10;                       \
+      _argvec[2+11] = (unsigned long)arg11;                       \
+      _argvec[2+12] = (unsigned long)arg12;                       \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 11,%1\n\t"                                           \
+         "std 2,-16(11)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(11)\n\t"  /* use nraddr's tocptr */           \
+         "addi 1,1,-144\n\t"  /* expand stack frame */            \
+         /* arg12 */                                              \
+         "ld  3,96(11)\n\t"                                       \
+         "std 3,136(1)\n\t"                                       \
+         /* arg11 */                                              \
+         "ld  3,88(11)\n\t"                                       \
+         "std 3,128(1)\n\t"                                       \
+         /* arg10 */                                              \
+         "ld  3,80(11)\n\t"                                       \
+         "std 3,120(1)\n\t"                                       \
+         /* arg9 */                                               \
+         "ld  3,72(11)\n\t"                                       \
+         "std 3,112(1)\n\t"                                       \
+         /* args1-8 */                                            \
+         "ld   3, 8(11)\n\t"  /* arg1->r3 */                      \
+         "ld   4, 16(11)\n\t" /* arg2->r4 */                      \
+         "ld   5, 24(11)\n\t" /* arg3->r5 */                      \
+         "ld   6, 32(11)\n\t" /* arg4->r6 */                      \
+         "ld   7, 40(11)\n\t" /* arg5->r7 */                      \
+         "ld   8, 48(11)\n\t" /* arg6->r8 */                      \
+         "ld   9, 56(11)\n\t" /* arg7->r9 */                      \
+         "ld  10, 64(11)\n\t" /* arg8->r10 */                     \
+         "ld  11, 0(11)\n\t"  /* target->r11 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11                  \
+         "mr 11,%1\n\t"                                           \
+         "mr %0,3\n\t"                                            \
+         "ld 2,-16(11)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#endif /* PLAT_ppc64be_linux */
+
+/* ------------------------- ppc64le-linux ----------------------- */
+#if defined(PLAT_ppc64le_linux)
+
+/* ARGREGS: r3 r4 r5 r6 r7 r8 r9 r10 (the rest on stack somewhere) */
+
+/* These regs are trashed by the hidden call. */
+#define __CALLER_SAVED_REGS                                       \
+   "lr", "ctr", "xer",                                            \
+   "cr0", "cr1", "cr2", "cr3", "cr4", "cr5", "cr6", "cr7",        \
+   "r0", "r3", "r4", "r5", "r6", "r7", "r8", "r9", "r10",         \
+   "r11", "r12", "r13"
+
+/* Macros to save and align the stack before making a function
+   call and restore it afterwards as gcc may not keep the stack
+   pointer aligned if it doesn't realise calls are being made
+   to other functions. */
+
+#define VALGRIND_ALIGN_STACK               \
+      "mr 28,1\n\t"                        \
+      "rldicr 1,1,0,59\n\t"
+#define VALGRIND_RESTORE_STACK             \
+      "mr 1,28\n\t"
+
+/* These CALL_FN_ macros assume that on ppc64-linux, sizeof(unsigned
+   long) == 8. */
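+
+/* Editor's note: these little-endian (ELFv2) variants mirror the
+   ppc64be ones above, except that the target address travels in r12
+   rather than r11; the ELFv2 ABI expects r12 to hold the callee's
+   entry point so the callee can materialise its own TOC pointer. */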
+
+#define CALL_FN_W_v(lval, orig)                                   \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+0];                        \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1] = (unsigned long)_orig.r2;                       \
+      _argvec[2] = (unsigned long)_orig.nraddr;                   \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 12,%1\n\t"                                           \
+         "std 2,-16(12)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(12)\n\t"  /* use nraddr's tocptr */           \
+         "ld  12, 0(12)\n\t"  /* target->r12 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12                  \
+         "mr 12,%1\n\t"                                           \
+         "mr %0,3\n\t"                                            \
+         "ld 2,-16(12)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
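+
+/* Layout sketch (illustration only): the asm above is handed
+   &_argvec[2] in %1/r12, so the offsets it uses map onto the array as
+   shown below.  The struct and its name are hypothetical and merely
+   mirror the indexing already visible in the macros. */
+#if 0
+struct vg_sketch_ppc64le_argblock {
+   unsigned long saved_r2;   /* _argvec[0]: caller's TOC, written by "std 2,-16(12)" */
+   unsigned long target_r2;  /* _argvec[1]: callee's TOC, read by "ld 2,-8(12)"      */
+   unsigned long nraddr;     /* _argvec[2]: target address, read by "ld 12,0(12)"    */
+   unsigned long args[12];   /* _argvec[3..]: arg1 at 8(12), arg2 at 16(12), ...     */
+};
+#endif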
+
+#define CALL_FN_W_W(lval, orig, arg1)                             \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+1];                        \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2;                     \
+      _argvec[2]   = (unsigned long)_orig.nraddr;                 \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 12,%1\n\t"                                           \
+         "std 2,-16(12)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(12)\n\t"  /* use nraddr's tocptr */           \
+         "ld   3, 8(12)\n\t"  /* arg1->r3 */                      \
+         "ld  12, 0(12)\n\t"  /* target->r12 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12                  \
+         "mr 12,%1\n\t"                                           \
+         "mr %0,3\n\t"                                            \
+         "ld 2,-16(12)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_WW(lval, orig, arg1,arg2)                       \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+2];                        \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2;                     \
+      _argvec[2]   = (unsigned long)_orig.nraddr;                 \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      _argvec[2+2] = (unsigned long)arg2;                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 12,%1\n\t"                                           \
+         "std 2,-16(12)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(12)\n\t"  /* use nraddr's tocptr */           \
+         "ld   3, 8(12)\n\t"  /* arg1->r3 */                      \
+         "ld   4, 16(12)\n\t" /* arg2->r4 */                      \
+         "ld  12, 0(12)\n\t"  /* target->r12 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12                  \
+         "mr 12,%1\n\t"                                           \
+         "mr %0,3\n\t"                                            \
+         "ld 2,-16(12)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3)                 \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+3];                        \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2;                     \
+      _argvec[2]   = (unsigned long)_orig.nraddr;                 \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      _argvec[2+2] = (unsigned long)arg2;                         \
+      _argvec[2+3] = (unsigned long)arg3;                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 12,%1\n\t"                                           \
+         "std 2,-16(12)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(12)\n\t"  /* use nraddr's tocptr */           \
+         "ld   3, 8(12)\n\t"  /* arg1->r3 */                      \
+         "ld   4, 16(12)\n\t" /* arg2->r4 */                      \
+         "ld   5, 24(12)\n\t" /* arg3->r5 */                      \
+         "ld  12, 0(12)\n\t"  /* target->r12 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12                  \
+         "mr 12,%1\n\t"                                           \
+         "mr %0,3\n\t"                                            \
+         "ld 2,-16(12)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4)           \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+4];                        \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2;                     \
+      _argvec[2]   = (unsigned long)_orig.nraddr;                 \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      _argvec[2+2] = (unsigned long)arg2;                         \
+      _argvec[2+3] = (unsigned long)arg3;                         \
+      _argvec[2+4] = (unsigned long)arg4;                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 12,%1\n\t"                                           \
+         "std 2,-16(12)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(12)\n\t"  /* use nraddr's tocptr */           \
+         "ld   3, 8(12)\n\t"  /* arg1->r3 */                      \
+         "ld   4, 16(12)\n\t" /* arg2->r4 */                      \
+         "ld   5, 24(12)\n\t" /* arg3->r5 */                      \
+         "ld   6, 32(12)\n\t" /* arg4->r6 */                      \
+         "ld  12, 0(12)\n\t"  /* target->r12 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12                  \
+         "mr 12,%1\n\t"                                           \
+         "mr %0,3\n\t"                                            \
+         "ld 2,-16(12)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5)        \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+5];                        \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2;                     \
+      _argvec[2]   = (unsigned long)_orig.nraddr;                 \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      _argvec[2+2] = (unsigned long)arg2;                         \
+      _argvec[2+3] = (unsigned long)arg3;                         \
+      _argvec[2+4] = (unsigned long)arg4;                         \
+      _argvec[2+5] = (unsigned long)arg5;                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 12,%1\n\t"                                           \
+         "std 2,-16(12)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(12)\n\t"  /* use nraddr's tocptr */           \
+         "ld   3, 8(12)\n\t"  /* arg1->r3 */                      \
+         "ld   4, 16(12)\n\t" /* arg2->r4 */                      \
+         "ld   5, 24(12)\n\t" /* arg3->r5 */                      \
+         "ld   6, 32(12)\n\t" /* arg4->r6 */                      \
+         "ld   7, 40(12)\n\t" /* arg5->r7 */                      \
+         "ld  12, 0(12)\n\t"  /* target->r12 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12                  \
+         "mr 12,%1\n\t"                                           \
+         "mr %0,3\n\t"                                            \
+         "ld 2,-16(12)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6)   \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+6];                        \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2;                     \
+      _argvec[2]   = (unsigned long)_orig.nraddr;                 \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      _argvec[2+2] = (unsigned long)arg2;                         \
+      _argvec[2+3] = (unsigned long)arg3;                         \
+      _argvec[2+4] = (unsigned long)arg4;                         \
+      _argvec[2+5] = (unsigned long)arg5;                         \
+      _argvec[2+6] = (unsigned long)arg6;                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 12,%1\n\t"                                           \
+         "std 2,-16(12)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(12)\n\t"  /* use nraddr's tocptr */           \
+         "ld   3, 8(12)\n\t"  /* arg1->r3 */                      \
+         "ld   4, 16(12)\n\t" /* arg2->r4 */                      \
+         "ld   5, 24(12)\n\t" /* arg3->r5 */                      \
+         "ld   6, 32(12)\n\t" /* arg4->r6 */                      \
+         "ld   7, 40(12)\n\t" /* arg5->r7 */                      \
+         "ld   8, 48(12)\n\t" /* arg6->r8 */                      \
+         "ld  12, 0(12)\n\t"  /* target->r12 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12                  \
+         "mr 12,%1\n\t"                                           \
+         "mr %0,3\n\t"                                            \
+         "ld 2,-16(12)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7)                            \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+7];                        \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2;                     \
+      _argvec[2]   = (unsigned long)_orig.nraddr;                 \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      _argvec[2+2] = (unsigned long)arg2;                         \
+      _argvec[2+3] = (unsigned long)arg3;                         \
+      _argvec[2+4] = (unsigned long)arg4;                         \
+      _argvec[2+5] = (unsigned long)arg5;                         \
+      _argvec[2+6] = (unsigned long)arg6;                         \
+      _argvec[2+7] = (unsigned long)arg7;                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 12,%1\n\t"                                           \
+         "std 2,-16(12)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(12)\n\t"  /* use nraddr's tocptr */           \
+         "ld   3, 8(12)\n\t"  /* arg1->r3 */                      \
+         "ld   4, 16(12)\n\t" /* arg2->r4 */                      \
+         "ld   5, 24(12)\n\t" /* arg3->r5 */                      \
+         "ld   6, 32(12)\n\t" /* arg4->r6 */                      \
+         "ld   7, 40(12)\n\t" /* arg5->r7 */                      \
+         "ld   8, 48(12)\n\t" /* arg6->r8 */                      \
+         "ld   9, 56(12)\n\t" /* arg7->r9 */                      \
+         "ld  12, 0(12)\n\t"  /* target->r12 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12                  \
+         "mr 12,%1\n\t"                                           \
+         "mr %0,3\n\t"                                            \
+         "ld 2,-16(12)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7,arg8)                       \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+8];                        \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2;                     \
+      _argvec[2]   = (unsigned long)_orig.nraddr;                 \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      _argvec[2+2] = (unsigned long)arg2;                         \
+      _argvec[2+3] = (unsigned long)arg3;                         \
+      _argvec[2+4] = (unsigned long)arg4;                         \
+      _argvec[2+5] = (unsigned long)arg5;                         \
+      _argvec[2+6] = (unsigned long)arg6;                         \
+      _argvec[2+7] = (unsigned long)arg7;                         \
+      _argvec[2+8] = (unsigned long)arg8;                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 12,%1\n\t"                                           \
+         "std 2,-16(12)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(12)\n\t"  /* use nraddr's tocptr */           \
+         "ld   3, 8(12)\n\t"  /* arg1->r3 */                      \
+         "ld   4, 16(12)\n\t" /* arg2->r4 */                      \
+         "ld   5, 24(12)\n\t" /* arg3->r5 */                      \
+         "ld   6, 32(12)\n\t" /* arg4->r6 */                      \
+         "ld   7, 40(12)\n\t" /* arg5->r7 */                      \
+         "ld   8, 48(12)\n\t" /* arg6->r8 */                      \
+         "ld   9, 56(12)\n\t" /* arg7->r9 */                      \
+         "ld  10, 64(12)\n\t" /* arg8->r10 */                     \
+         "ld  12, 0(12)\n\t"  /* target->r12 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12                  \
+         "mr 12,%1\n\t"                                           \
+         "mr %0,3\n\t"                                            \
+         "ld 2,-16(12)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7,arg8,arg9)                  \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+9];                        \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2;                     \
+      _argvec[2]   = (unsigned long)_orig.nraddr;                 \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      _argvec[2+2] = (unsigned long)arg2;                         \
+      _argvec[2+3] = (unsigned long)arg3;                         \
+      _argvec[2+4] = (unsigned long)arg4;                         \
+      _argvec[2+5] = (unsigned long)arg5;                         \
+      _argvec[2+6] = (unsigned long)arg6;                         \
+      _argvec[2+7] = (unsigned long)arg7;                         \
+      _argvec[2+8] = (unsigned long)arg8;                         \
+      _argvec[2+9] = (unsigned long)arg9;                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 12,%1\n\t"                                           \
+         "std 2,-16(12)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(12)\n\t"  /* use nraddr's tocptr */           \
+         "addi 1,1,-128\n\t"  /* expand stack frame */            \
+         /* arg9 */                                               \
+         "ld  3,72(12)\n\t"                                       \
+         "std 3,96(1)\n\t"                                        \
+         /* args1-8 */                                            \
+         "ld   3, 8(12)\n\t"  /* arg1->r3 */                      \
+         "ld   4, 16(12)\n\t" /* arg2->r4 */                      \
+         "ld   5, 24(12)\n\t" /* arg3->r5 */                      \
+         "ld   6, 32(12)\n\t" /* arg4->r6 */                      \
+         "ld   7, 40(12)\n\t" /* arg5->r7 */                      \
+         "ld   8, 48(12)\n\t" /* arg6->r8 */                      \
+         "ld   9, 56(12)\n\t" /* arg7->r9 */                      \
+         "ld  10, 64(12)\n\t" /* arg8->r10 */                     \
+         "ld  12, 0(12)\n\t"  /* target->r12 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12                  \
+         "mr 12,%1\n\t"                                           \
+         "mr %0,3\n\t"                                            \
+         "ld 2,-16(12)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
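+
+/* Sketch (illustration only) of where the stack-passed arguments land
+   in the 9-argument-and-above macros, assuming the ELFv2 convention
+   the asm follows: a 32-byte frame header followed by one 8-byte
+   parameter slot per argument, so argN (N >= 9) is stored at
+   32 + 8*(N-1) bytes from the lowered stack pointer -- 96(1) for arg9,
+   104(1) for arg10, and so on.  The helper name is hypothetical. */
+#if 0
+static unsigned long vg_sketch_ppc64le_stack_arg_offset(unsigned argno)
+{
+   return 32UL + 8UL * (argno - 1);   /* argno is 1-based; valid for argno >= 9 */
+}
+#endif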
+
+#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,  \
+                                  arg7,arg8,arg9,arg10)           \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+10];                       \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2;                     \
+      _argvec[2]   = (unsigned long)_orig.nraddr;                 \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      _argvec[2+2] = (unsigned long)arg2;                         \
+      _argvec[2+3] = (unsigned long)arg3;                         \
+      _argvec[2+4] = (unsigned long)arg4;                         \
+      _argvec[2+5] = (unsigned long)arg5;                         \
+      _argvec[2+6] = (unsigned long)arg6;                         \
+      _argvec[2+7] = (unsigned long)arg7;                         \
+      _argvec[2+8] = (unsigned long)arg8;                         \
+      _argvec[2+9] = (unsigned long)arg9;                         \
+      _argvec[2+10] = (unsigned long)arg10;                       \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 12,%1\n\t"                                           \
+         "std 2,-16(12)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(12)\n\t"  /* use nraddr's tocptr */           \
+         "addi 1,1,-128\n\t"  /* expand stack frame */            \
+         /* arg10 */                                              \
+         "ld  3,80(12)\n\t"                                       \
+         "std 3,104(1)\n\t"                                       \
+         /* arg9 */                                               \
+         "ld  3,72(12)\n\t"                                       \
+         "std 3,96(1)\n\t"                                        \
+         /* args1-8 */                                            \
+         "ld   3, 8(12)\n\t"  /* arg1->r3 */                      \
+         "ld   4, 16(12)\n\t" /* arg2->r4 */                      \
+         "ld   5, 24(12)\n\t" /* arg3->r5 */                      \
+         "ld   6, 32(12)\n\t" /* arg4->r6 */                      \
+         "ld   7, 40(12)\n\t" /* arg5->r7 */                      \
+         "ld   8, 48(12)\n\t" /* arg6->r8 */                      \
+         "ld   9, 56(12)\n\t" /* arg7->r9 */                      \
+         "ld  10, 64(12)\n\t" /* arg8->r10 */                     \
+         "ld  12, 0(12)\n\t"  /* target->r12 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12                  \
+         "mr 12,%1\n\t"                                           \
+         "mr %0,3\n\t"                                            \
+         "ld 2,-16(12)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,  \
+                                  arg7,arg8,arg9,arg10,arg11)     \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+11];                       \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2;                     \
+      _argvec[2]   = (unsigned long)_orig.nraddr;                 \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      _argvec[2+2] = (unsigned long)arg2;                         \
+      _argvec[2+3] = (unsigned long)arg3;                         \
+      _argvec[2+4] = (unsigned long)arg4;                         \
+      _argvec[2+5] = (unsigned long)arg5;                         \
+      _argvec[2+6] = (unsigned long)arg6;                         \
+      _argvec[2+7] = (unsigned long)arg7;                         \
+      _argvec[2+8] = (unsigned long)arg8;                         \
+      _argvec[2+9] = (unsigned long)arg9;                         \
+      _argvec[2+10] = (unsigned long)arg10;                       \
+      _argvec[2+11] = (unsigned long)arg11;                       \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 12,%1\n\t"                                           \
+         "std 2,-16(12)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(12)\n\t"  /* use nraddr's tocptr */           \
+         "addi 1,1,-144\n\t"  /* expand stack frame */            \
+         /* arg11 */                                              \
+         "ld  3,88(12)\n\t"                                       \
+         "std 3,112(1)\n\t"                                       \
+         /* arg10 */                                              \
+         "ld  3,80(12)\n\t"                                       \
+         "std 3,104(1)\n\t"                                       \
+         /* arg9 */                                               \
+         "ld  3,72(12)\n\t"                                       \
+         "std 3,96(1)\n\t"                                        \
+         /* args1-8 */                                            \
+         "ld   3, 8(12)\n\t"  /* arg1->r3 */                      \
+         "ld   4, 16(12)\n\t" /* arg2->r4 */                      \
+         "ld   5, 24(12)\n\t" /* arg3->r5 */                      \
+         "ld   6, 32(12)\n\t" /* arg4->r6 */                      \
+         "ld   7, 40(12)\n\t" /* arg5->r7 */                      \
+         "ld   8, 48(12)\n\t" /* arg6->r8 */                      \
+         "ld   9, 56(12)\n\t" /* arg7->r9 */                      \
+         "ld  10, 64(12)\n\t" /* arg8->r10 */                     \
+         "ld  12, 0(12)\n\t"  /* target->r12 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12                  \
+         "mr 12,%1\n\t"                                           \
+         "mr %0,3\n\t"                                            \
+         "ld 2,-16(12)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,  \
+                                arg7,arg8,arg9,arg10,arg11,arg12) \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3+12];                       \
+      volatile unsigned long _res;                                \
+      /* _argvec[0] holds current r2 across the call */           \
+      _argvec[1]   = (unsigned long)_orig.r2;                     \
+      _argvec[2]   = (unsigned long)_orig.nraddr;                 \
+      _argvec[2+1] = (unsigned long)arg1;                         \
+      _argvec[2+2] = (unsigned long)arg2;                         \
+      _argvec[2+3] = (unsigned long)arg3;                         \
+      _argvec[2+4] = (unsigned long)arg4;                         \
+      _argvec[2+5] = (unsigned long)arg5;                         \
+      _argvec[2+6] = (unsigned long)arg6;                         \
+      _argvec[2+7] = (unsigned long)arg7;                         \
+      _argvec[2+8] = (unsigned long)arg8;                         \
+      _argvec[2+9] = (unsigned long)arg9;                         \
+      _argvec[2+10] = (unsigned long)arg10;                       \
+      _argvec[2+11] = (unsigned long)arg11;                       \
+      _argvec[2+12] = (unsigned long)arg12;                       \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "mr 12,%1\n\t"                                           \
+         "std 2,-16(12)\n\t"  /* save tocptr */                   \
+         "ld   2,-8(12)\n\t"  /* use nraddr's tocptr */           \
+         "addi 1,1,-144\n\t"  /* expand stack frame */            \
+         /* arg12 */                                              \
+         "ld  3,96(12)\n\t"                                       \
+         "std 3,120(1)\n\t"                                       \
+         /* arg11 */                                              \
+         "ld  3,88(12)\n\t"                                       \
+         "std 3,112(1)\n\t"                                       \
+         /* arg10 */                                              \
+         "ld  3,80(12)\n\t"                                       \
+         "std 3,104(1)\n\t"                                       \
+         /* arg9 */                                               \
+         "ld  3,72(12)\n\t"                                       \
+         "std 3,96(1)\n\t"                                        \
+         /* args1-8 */                                            \
+         "ld   3, 8(12)\n\t"  /* arg1->r3 */                      \
+         "ld   4, 16(12)\n\t" /* arg2->r4 */                      \
+         "ld   5, 24(12)\n\t" /* arg3->r5 */                      \
+         "ld   6, 32(12)\n\t" /* arg4->r6 */                      \
+         "ld   7, 40(12)\n\t" /* arg5->r7 */                      \
+         "ld   8, 48(12)\n\t" /* arg6->r8 */                      \
+         "ld   9, 56(12)\n\t" /* arg7->r9 */                      \
+         "ld  10, 64(12)\n\t" /* arg8->r10 */                     \
+         "ld  12, 0(12)\n\t"  /* target->r12 */                   \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12                  \
+         "mr 12,%1\n\t"                                           \
+         "mr %0,3\n\t"                                            \
+         "ld 2,-16(12)\n\t" /* restore tocptr */                  \
+         VALGRIND_RESTORE_STACK                                   \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[2])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#endif /* PLAT_ppc64le_linux */
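+
+/* Usage sketch (illustration only, modelled on the function-wrapping
+   example in the Valgrind manual): how the CALL_FN_W_* macros defined
+   above are normally reached.  "NONE" and "foo" are placeholder
+   soname/function names; any function taking word-sized arguments is
+   handled the same way. */
+#if 0
+int I_WRAP_SONAME_FNNAME_ZU(NONE, foo)(int a, int b)
+{
+   int    result;
+   OrigFn fn;
+   VALGRIND_GET_ORIG_FN(fn);          /* fills fn.nraddr (and fn.r2 on ppc64) */
+   CALL_FN_W_WW(result, fn, a, b);    /* forwards both args to the real foo */
+   return result;
+}
+#endif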
+
+/* ------------------------- arm-linux ------------------------- */
+
+#if defined(PLAT_arm_linux)
+
+/* These regs are trashed by the hidden call. */
+#define __CALLER_SAVED_REGS "r0", "r1", "r2", "r3","r4", "r12", "r14"
+
+/* Macros to save and align the stack before making a function
+   call and restore it afterwards as gcc may not keep the stack
+   pointer aligned if it doesn't realise calls are being made
+   to other functions. */
+
+/* This is a bit tricky.  We store the original stack pointer in r10
+   as it is callee-saved.  gcc doesn't allow the use of r11 for some
+   reason.  Also, we can't directly "bic" the stack pointer in thumb
+   mode since r13 isn't an allowed register number in that context.
+   So use r4 as a temporary, since that is about to get trashed
+   anyway, just after each use of this macro.  A side effect is that
+   we need to be very careful about any future changes, since
+   VALGRIND_ALIGN_STACK simply assumes r4 is usable. */
+#define VALGRIND_ALIGN_STACK               \
+      "mov r10, sp\n\t"                    \
+      "mov r4,  sp\n\t"                    \
+      "bic r4,  r4, #7\n\t"                \
+      "mov sp,  r4\n\t"
+#define VALGRIND_RESTORE_STACK             \
+      "mov sp,  r10\n\t"
+
+/* These CALL_FN_ macros assume that on arm-linux, sizeof(unsigned
+   long) == 4. */
+
+#define CALL_FN_W_v(lval, orig)                                   \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[1];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "ldr r4, [%1] \n\t"  /* target->r4 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, r0\n"                                           \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_W(lval, orig, arg1)                             \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[2];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "ldr r0, [%1, #4] \n\t"                                  \
+         "ldr r4, [%1] \n\t"  /* target->r4 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, r0\n"                                           \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_WW(lval, orig, arg1,arg2)                       \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "ldr r0, [%1, #4] \n\t"                                  \
+         "ldr r1, [%1, #8] \n\t"                                  \
+         "ldr r4, [%1] \n\t"  /* target->r4 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, r0\n"                                           \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3)                 \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[4];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "ldr r0, [%1, #4] \n\t"                                  \
+         "ldr r1, [%1, #8] \n\t"                                  \
+         "ldr r2, [%1, #12] \n\t"                                 \
+         "ldr r4, [%1] \n\t"  /* target->r4 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, r0\n"                                           \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4)           \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[5];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "ldr r0, [%1, #4] \n\t"                                  \
+         "ldr r1, [%1, #8] \n\t"                                  \
+         "ldr r2, [%1, #12] \n\t"                                 \
+         "ldr r3, [%1, #16] \n\t"                                 \
+         "ldr r4, [%1] \n\t"  /* target->r4 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, r0"                                             \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5)        \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[6];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "sub sp, sp, #4 \n\t"                                    \
+         "ldr r0, [%1, #20] \n\t"                                 \
+         "push {r0} \n\t"                                         \
+         "ldr r0, [%1, #4] \n\t"                                  \
+         "ldr r1, [%1, #8] \n\t"                                  \
+         "ldr r2, [%1, #12] \n\t"                                 \
+         "ldr r3, [%1, #16] \n\t"                                 \
+         "ldr r4, [%1] \n\t"  /* target->r4 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, r0"                                             \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
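+
+/* Sketch (illustration only): the "sub sp, sp, #4" above, which also
+   appears in the 7-, 9- and 11-argument variants below, pads the stack
+   when an odd number of 4-byte arguments is about to be pushed, so sp
+   stays 8-byte aligned (as VALGRIND_ALIGN_STACK arranged) at the call.
+   The helper name is hypothetical. */
+#if 0
+static unsigned arm_sketch_stack_pad(unsigned n_stack_args)
+{
+   return (n_stack_args & 1u) ? 4u : 0u;   /* 4 bytes of padding when odd */
+}
+#endif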
+
+#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6)   \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[7];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "ldr r0, [%1, #20] \n\t"                                 \
+         "ldr r1, [%1, #24] \n\t"                                 \
+         "push {r0, r1} \n\t"                                     \
+         "ldr r0, [%1, #4] \n\t"                                  \
+         "ldr r1, [%1, #8] \n\t"                                  \
+         "ldr r2, [%1, #12] \n\t"                                 \
+         "ldr r3, [%1, #16] \n\t"                                 \
+         "ldr r4, [%1] \n\t"  /* target->r4 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, r0"                                             \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7)                            \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[8];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "sub sp, sp, #4 \n\t"                                    \
+         "ldr r0, [%1, #20] \n\t"                                 \
+         "ldr r1, [%1, #24] \n\t"                                 \
+         "ldr r2, [%1, #28] \n\t"                                 \
+         "push {r0, r1, r2} \n\t"                                 \
+         "ldr r0, [%1, #4] \n\t"                                  \
+         "ldr r1, [%1, #8] \n\t"                                  \
+         "ldr r2, [%1, #12] \n\t"                                 \
+         "ldr r3, [%1, #16] \n\t"                                 \
+         "ldr r4, [%1] \n\t"  /* target->r4 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, r0"                                             \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7,arg8)                       \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[9];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "ldr r0, [%1, #20] \n\t"                                 \
+         "ldr r1, [%1, #24] \n\t"                                 \
+         "ldr r2, [%1, #28] \n\t"                                 \
+         "ldr r3, [%1, #32] \n\t"                                 \
+         "push {r0, r1, r2, r3} \n\t"                             \
+         "ldr r0, [%1, #4] \n\t"                                  \
+         "ldr r1, [%1, #8] \n\t"                                  \
+         "ldr r2, [%1, #12] \n\t"                                 \
+         "ldr r3, [%1, #16] \n\t"                                 \
+         "ldr r4, [%1] \n\t"  /* target->r4 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, r0"                                             \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7,arg8,arg9)                  \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[10];                         \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      _argvec[9] = (unsigned long)(arg9);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "sub sp, sp, #4 \n\t"                                    \
+         "ldr r0, [%1, #20] \n\t"                                 \
+         "ldr r1, [%1, #24] \n\t"                                 \
+         "ldr r2, [%1, #28] \n\t"                                 \
+         "ldr r3, [%1, #32] \n\t"                                 \
+         "ldr r4, [%1, #36] \n\t"                                 \
+         "push {r0, r1, r2, r3, r4} \n\t"                         \
+         "ldr r0, [%1, #4] \n\t"                                  \
+         "ldr r1, [%1, #8] \n\t"                                  \
+         "ldr r2, [%1, #12] \n\t"                                 \
+         "ldr r3, [%1, #16] \n\t"                                 \
+         "ldr r4, [%1] \n\t"  /* target->r4 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, r0"                                             \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,  \
+                                  arg7,arg8,arg9,arg10)           \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[11];                         \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      _argvec[9] = (unsigned long)(arg9);                         \
+      _argvec[10] = (unsigned long)(arg10);                       \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "ldr r0, [%1, #40] \n\t"                                 \
+         "push {r0} \n\t"                                         \
+         "ldr r0, [%1, #20] \n\t"                                 \
+         "ldr r1, [%1, #24] \n\t"                                 \
+         "ldr r2, [%1, #28] \n\t"                                 \
+         "ldr r3, [%1, #32] \n\t"                                 \
+         "ldr r4, [%1, #36] \n\t"                                 \
+         "push {r0, r1, r2, r3, r4} \n\t"                         \
+         "ldr r0, [%1, #4] \n\t"                                  \
+         "ldr r1, [%1, #8] \n\t"                                  \
+         "ldr r2, [%1, #12] \n\t"                                 \
+         "ldr r3, [%1, #16] \n\t"                                 \
+         "ldr r4, [%1] \n\t"  /* target->r4 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, r0"                                             \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5,       \
+                                  arg6,arg7,arg8,arg9,arg10,      \
+                                  arg11)                          \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[12];                         \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      _argvec[9] = (unsigned long)(arg9);                         \
+      _argvec[10] = (unsigned long)(arg10);                       \
+      _argvec[11] = (unsigned long)(arg11);                       \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "sub sp, sp, #4 \n\t"                                    \
+         "ldr r0, [%1, #40] \n\t"                                 \
+         "ldr r1, [%1, #44] \n\t"                                 \
+         "push {r0, r1} \n\t"                                     \
+         "ldr r0, [%1, #20] \n\t"                                 \
+         "ldr r1, [%1, #24] \n\t"                                 \
+         "ldr r2, [%1, #28] \n\t"                                 \
+         "ldr r3, [%1, #32] \n\t"                                 \
+         "ldr r4, [%1, #36] \n\t"                                 \
+         "push {r0, r1, r2, r3, r4} \n\t"                         \
+         "ldr r0, [%1, #4] \n\t"                                  \
+         "ldr r1, [%1, #8] \n\t"                                  \
+         "ldr r2, [%1, #12] \n\t"                                 \
+         "ldr r3, [%1, #16] \n\t"                                 \
+         "ldr r4, [%1] \n\t"  /* target->r4 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, r0"                                             \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5,       \
+                                  arg6,arg7,arg8,arg9,arg10,      \
+                                  arg11,arg12)                    \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[13];                         \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      _argvec[9] = (unsigned long)(arg9);                         \
+      _argvec[10] = (unsigned long)(arg10);                       \
+      _argvec[11] = (unsigned long)(arg11);                       \
+      _argvec[12] = (unsigned long)(arg12);                       \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "ldr r0, [%1, #40] \n\t"                                 \
+         "ldr r1, [%1, #44] \n\t"                                 \
+         "ldr r2, [%1, #48] \n\t"                                 \
+         "push {r0, r1, r2} \n\t"                                 \
+         "ldr r0, [%1, #20] \n\t"                                 \
+         "ldr r1, [%1, #24] \n\t"                                 \
+         "ldr r2, [%1, #28] \n\t"                                 \
+         "ldr r3, [%1, #32] \n\t"                                 \
+         "ldr r4, [%1, #36] \n\t"                                 \
+         "push {r0, r1, r2, r3, r4} \n\t"                         \
+         "ldr r0, [%1, #4] \n\t"                                  \
+         "ldr r1, [%1, #8] \n\t"                                  \
+         "ldr r2, [%1, #12] \n\t"                                 \
+         "ldr r3, [%1, #16] \n\t"                                 \
+         "ldr r4, [%1] \n\t"  /* target->r4 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, r0"                                             \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#endif /* PLAT_arm_linux */
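+
+/* Illustrative sketch (not part of the upstream macro definitions): the
+   CALL_FN_W_* macros above are normally used from inside a function
+   wrapper built with the wrapping macros defined elsewhere in this
+   header.  A minimal example for a hypothetical two-argument function
+   foo in the main executable (soname NONE) might look like this:
+
+      #include "valgrind.h"
+
+      int I_WRAP_SONAME_FNNAME_ZU(NONE, foo)(int x, int y)
+      {
+         int    result;
+         OrigFn fn;
+         VALGRIND_GET_ORIG_FN(fn);        // fetch the original foo
+         CALL_FN_W_WW(result, fn, x, y);  // call it without redirection
+         return result;
+      }
+
+   foo and its signature are purely illustrative; a real wrapper must
+   match the soname and argument count of the function it wraps. */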
+
+/* ------------------------ arm64-linux ------------------------ */
+
+#if defined(PLAT_arm64_linux)
+
+/* These regs are trashed by the hidden call. */
+#define __CALLER_SAVED_REGS \
+     "x0", "x1", "x2", "x3","x4", "x5", "x6", "x7", "x8", "x9",   \
+     "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17",      \
+     "x18", "x19", "x20", "x30",                                  \
+     "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",  \
+     "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17",      \
+     "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",      \
+     "v26", "v27", "v28", "v29", "v30", "v31"
+
+/* x21 is callee-saved, so we can use it to save and restore SP around
+   the hidden call. */
+#define VALGRIND_ALIGN_STACK               \
+      "mov x21, sp\n\t"                    \
+      "bic sp, x21, #15\n\t"
+#define VALGRIND_RESTORE_STACK             \
+      "mov sp,  x21\n\t"
+
+/* These CALL_FN_ macros assume that on arm64-linux,
+   sizeof(unsigned long) == 8. */
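+
+/* A worked example of the resulting argvec layout, assuming 8-byte
+   longs, so _argvec[i] sits at byte offset 8*i from the base in %1:
+
+      _argvec[0]  ->  [%1, #0]    target address (loaded into x8)
+      _argvec[1]  ->  [%1, #8]    arg1 (loaded into x0)
+      _argvec[2]  ->  [%1, #16]   arg2 (loaded into x1)
+      ...
+      _argvec[9]  ->  [%1, #72]   arg9 (spilled to the stack by the
+                                  nine-or-more-argument variants)    */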
+
+#define CALL_FN_W_v(lval, orig)                                   \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[1];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "ldr x8, [%1] \n\t"  /* target->x8 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, x0\n"                                           \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_W(lval, orig, arg1)                             \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[2];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "ldr x0, [%1, #8] \n\t"                                  \
+         "ldr x8, [%1] \n\t"  /* target->x8 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, x0\n"                                           \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_WW(lval, orig, arg1,arg2)                       \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "ldr x0, [%1, #8] \n\t"                                  \
+         "ldr x1, [%1, #16] \n\t"                                 \
+         "ldr x8, [%1] \n\t"  /* target->x8 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, x0\n"                                           \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3)                 \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[4];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "ldr x0, [%1, #8] \n\t"                                  \
+         "ldr x1, [%1, #16] \n\t"                                 \
+         "ldr x2, [%1, #24] \n\t"                                 \
+         "ldr x8, [%1] \n\t"  /* target->x8 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, x0\n"                                           \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4)           \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[5];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "ldr x0, [%1, #8] \n\t"                                  \
+         "ldr x1, [%1, #16] \n\t"                                 \
+         "ldr x2, [%1, #24] \n\t"                                 \
+         "ldr x3, [%1, #32] \n\t"                                 \
+         "ldr x8, [%1] \n\t"  /* target->x8 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, x0"                                             \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5)        \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[6];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "ldr x0, [%1, #8] \n\t"                                  \
+         "ldr x1, [%1, #16] \n\t"                                 \
+         "ldr x2, [%1, #24] \n\t"                                 \
+         "ldr x3, [%1, #32] \n\t"                                 \
+         "ldr x4, [%1, #40] \n\t"                                 \
+         "ldr x8, [%1] \n\t"  /* target->x8 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, x0"                                             \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6)   \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[7];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "ldr x0, [%1, #8] \n\t"                                  \
+         "ldr x1, [%1, #16] \n\t"                                 \
+         "ldr x2, [%1, #24] \n\t"                                 \
+         "ldr x3, [%1, #32] \n\t"                                 \
+         "ldr x4, [%1, #40] \n\t"                                 \
+         "ldr x5, [%1, #48] \n\t"                                 \
+         "ldr x8, [%1] \n\t"  /* target->x8 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, x0"                                             \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7)                            \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[8];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "ldr x0, [%1, #8] \n\t"                                  \
+         "ldr x1, [%1, #16] \n\t"                                 \
+         "ldr x2, [%1, #24] \n\t"                                 \
+         "ldr x3, [%1, #32] \n\t"                                 \
+         "ldr x4, [%1, #40] \n\t"                                 \
+         "ldr x5, [%1, #48] \n\t"                                 \
+         "ldr x6, [%1, #56] \n\t"                                 \
+         "ldr x8, [%1] \n\t"  /* target->x8 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, x0"                                             \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7,arg8)                       \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[9];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "ldr x0, [%1, #8] \n\t"                                  \
+         "ldr x1, [%1, #16] \n\t"                                 \
+         "ldr x2, [%1, #24] \n\t"                                 \
+         "ldr x3, [%1, #32] \n\t"                                 \
+         "ldr x4, [%1, #40] \n\t"                                 \
+         "ldr x5, [%1, #48] \n\t"                                 \
+         "ldr x6, [%1, #56] \n\t"                                 \
+         "ldr x7, [%1, #64] \n\t"                                 \
+         "ldr x8, [%1] \n\t"  /* target->x8 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, x0"                                             \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7,arg8,arg9)                  \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[10];                         \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      _argvec[9] = (unsigned long)(arg9);                         \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "sub sp, sp, #0x20 \n\t"                                 \
+         "ldr x0, [%1, #8] \n\t"                                  \
+         "ldr x1, [%1, #16] \n\t"                                 \
+         "ldr x2, [%1, #24] \n\t"                                 \
+         "ldr x3, [%1, #32] \n\t"                                 \
+         "ldr x4, [%1, #40] \n\t"                                 \
+         "ldr x5, [%1, #48] \n\t"                                 \
+         "ldr x6, [%1, #56] \n\t"                                 \
+         "ldr x7, [%1, #64] \n\t"                                 \
+         "ldr x8, [%1, #72] \n\t"                                 \
+         "str x8, [sp, #0]  \n\t"                                 \
+         "ldr x8, [%1] \n\t"  /* target->x8 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, x0"                                             \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,  \
+                                  arg7,arg8,arg9,arg10)           \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[11];                         \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      _argvec[9] = (unsigned long)(arg9);                         \
+      _argvec[10] = (unsigned long)(arg10);                       \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "sub sp, sp, #0x20 \n\t"                                 \
+         "ldr x0, [%1, #8] \n\t"                                  \
+         "ldr x1, [%1, #16] \n\t"                                 \
+         "ldr x2, [%1, #24] \n\t"                                 \
+         "ldr x3, [%1, #32] \n\t"                                 \
+         "ldr x4, [%1, #40] \n\t"                                 \
+         "ldr x5, [%1, #48] \n\t"                                 \
+         "ldr x6, [%1, #56] \n\t"                                 \
+         "ldr x7, [%1, #64] \n\t"                                 \
+         "ldr x8, [%1, #72] \n\t"                                 \
+         "str x8, [sp, #0]  \n\t"                                 \
+         "ldr x8, [%1, #80] \n\t"                                 \
+         "str x8, [sp, #8]  \n\t"                                 \
+         "ldr x8, [%1] \n\t"  /* target->x8 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, x0"                                             \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,  \
+                                  arg7,arg8,arg9,arg10,arg11)     \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[12];                         \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      _argvec[9] = (unsigned long)(arg9);                         \
+      _argvec[10] = (unsigned long)(arg10);                       \
+      _argvec[11] = (unsigned long)(arg11);                       \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "sub sp, sp, #0x30 \n\t"                                 \
+         "ldr x0, [%1, #8] \n\t"                                  \
+         "ldr x1, [%1, #16] \n\t"                                 \
+         "ldr x2, [%1, #24] \n\t"                                 \
+         "ldr x3, [%1, #32] \n\t"                                 \
+         "ldr x4, [%1, #40] \n\t"                                 \
+         "ldr x5, [%1, #48] \n\t"                                 \
+         "ldr x6, [%1, #56] \n\t"                                 \
+         "ldr x7, [%1, #64] \n\t"                                 \
+         "ldr x8, [%1, #72] \n\t"                                 \
+         "str x8, [sp, #0]  \n\t"                                 \
+         "ldr x8, [%1, #80] \n\t"                                 \
+         "str x8, [sp, #8]  \n\t"                                 \
+         "ldr x8, [%1, #88] \n\t"                                 \
+         "str x8, [sp, #16] \n\t"                                 \
+         "ldr x8, [%1] \n\t"  /* target->x8 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, x0"                                             \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,  \
+                                  arg7,arg8,arg9,arg10,arg11,     \
+                                  arg12)                          \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[13];                         \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      _argvec[9] = (unsigned long)(arg9);                         \
+      _argvec[10] = (unsigned long)(arg10);                       \
+      _argvec[11] = (unsigned long)(arg11);                       \
+      _argvec[12] = (unsigned long)(arg12);                       \
+      __asm__ volatile(                                           \
+         VALGRIND_ALIGN_STACK                                     \
+         "sub sp, sp, #0x30 \n\t"                                 \
+         "ldr x0, [%1, #8] \n\t"                                  \
+         "ldr x1, [%1, #16] \n\t"                                 \
+         "ldr x2, [%1, #24] \n\t"                                 \
+         "ldr x3, [%1, #32] \n\t"                                 \
+         "ldr x4, [%1, #40] \n\t"                                 \
+         "ldr x5, [%1, #48] \n\t"                                 \
+         "ldr x6, [%1, #56] \n\t"                                 \
+         "ldr x7, [%1, #64] \n\t"                                 \
+         "ldr x8, [%1, #72] \n\t"                                 \
+         "str x8, [sp, #0]  \n\t"                                 \
+         "ldr x8, [%1, #80] \n\t"                                 \
+         "str x8, [sp, #8]  \n\t"                                 \
+         "ldr x8, [%1, #88] \n\t"                                 \
+         "str x8, [sp, #16] \n\t"                                 \
+         "ldr x8, [%1, #96] \n\t"                                 \
+         "str x8, [sp, #24] \n\t"                                 \
+         "ldr x8, [%1] \n\t"  /* target->x8 */                    \
+         VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8                   \
+         VALGRIND_RESTORE_STACK                                   \
+         "mov %0, x0"                                             \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21"   \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#endif /* PLAT_arm64_linux */
+
+/* ------------------------- s390x-linux ------------------------- */
+
+#if defined(PLAT_s390x_linux)
+
+/* Similar workaround to the one on amd64 (see above), but we use r11
+   as the frame pointer and save the old r11 in r7. r11 might be used
+   for argvec, so we copy the argvec pointer into r1, since r1 is
+   clobbered after the call anyway.  */
+#if defined(__GNUC__) && defined(__GCC_HAVE_DWARF2_CFI_ASM)
+#  define __FRAME_POINTER                                         \
+      ,"d"(__builtin_dwarf_cfa())
+#  define VALGRIND_CFI_PROLOGUE                                   \
+      ".cfi_remember_state\n\t"                                   \
+      "lgr 1,%1\n\t" /* copy the argvec pointer in r1 */          \
+      "lgr 7,11\n\t"                                              \
+      "lgr 11,%2\n\t"                                             \
+      ".cfi_def_cfa r11, 0\n\t"
+#  define VALGRIND_CFI_EPILOGUE                                   \
+      "lgr 11, 7\n\t"                                             \
+      ".cfi_restore_state\n\t"
+#else
+#  define __FRAME_POINTER
+#  define VALGRIND_CFI_PROLOGUE                                   \
+      "lgr 1,%1\n\t"
+#  define VALGRIND_CFI_EPILOGUE
+#endif
+
+/* Nb: On s390 the stack pointer is properly aligned *at all times*
+   according to the s390 GCC maintainer. (The ABI specification is not
+   precise in this regard.) Therefore, VALGRIND_ALIGN_STACK and
+   VALGRIND_RESTORE_STACK are not defined here. */
+
+/* These regs are trashed by the hidden call. Note that we overwrite
+   r14 in s390_irgen_noredir (VEX/priv/guest_s390_irgen.c) to give the
+   function a proper return address. All others are ABI defined call
+   clobbers. */
+#if defined(__VX__) || defined(__S390_VX__)
+#define __CALLER_SAVED_REGS "0", "1", "2", "3", "4", "5", "14",   \
+      "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",             \
+      "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15",       \
+      "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23",     \
+      "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+#else
+#define __CALLER_SAVED_REGS "0", "1", "2", "3", "4", "5", "14",   \
+      "f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7"
+#endif
+
+/* Nb: Although r11 is modified in the asm snippets below (inside
+   VALGRIND_CFI_PROLOGUE), it is not listed in the clobber section, for
+   two reasons:
+   (1) r11 is restored in VALGRIND_CFI_EPILOGUE, so effectively it is
+       not modified.
+   (2) GCC complains that r11 cannot appear in a clobber section when
+       compiling with -O -fno-omit-frame-pointer.
+ */
+
+#define CALL_FN_W_v(lval, orig)                                  \
+   do {                                                          \
+      volatile OrigFn        _orig = (orig);                     \
+      volatile unsigned long  _argvec[1];                        \
+      volatile unsigned long _res;                               \
+      _argvec[0] = (unsigned long)_orig.nraddr;                  \
+      __asm__ volatile(                                          \
+         VALGRIND_CFI_PROLOGUE                                   \
+         "aghi 15,-160\n\t"                                      \
+         "lg 1, 0(1)\n\t"  /* target->r1 */                      \
+         VALGRIND_CALL_NOREDIR_R1                                \
+         "aghi 15,160\n\t"                                       \
+         VALGRIND_CFI_EPILOGUE                                   \
+         "lgr %0, 2\n\t"                                         \
+         : /*out*/   "=d" (_res)                                 \
+         : /*in*/    "d" (&_argvec[0]) __FRAME_POINTER           \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"7"     \
+      );                                                         \
+      lval = (__typeof__(lval)) _res;                            \
+   } while (0)
+
+/* The calling ABI passes the arguments in r2-r6 and on the stack. */
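+/* Concretely, for the macros below, assuming 8-byte longs:
+
+      arg1..arg5   ->  r2..r6                   (register arguments)
+      arg6 onward  ->  160(r15), 168(r15), ...  (stack, above the
+                                                 160-byte register
+                                                 save area)
+
+   which is why the six-or-more-argument variants enlarge the frame
+   ("aghi 15,-168" and larger) and copy the extra words with "mvc". */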
+#define CALL_FN_W_W(lval, orig, arg1)                            \
+   do {                                                          \
+      volatile OrigFn        _orig = (orig);                     \
+      volatile unsigned long _argvec[2];                         \
+      volatile unsigned long _res;                               \
+      _argvec[0] = (unsigned long)_orig.nraddr;                  \
+      _argvec[1] = (unsigned long)arg1;                          \
+      __asm__ volatile(                                          \
+         VALGRIND_CFI_PROLOGUE                                   \
+         "aghi 15,-160\n\t"                                      \
+         "lg 2, 8(1)\n\t"                                        \
+         "lg 1, 0(1)\n\t"                                        \
+         VALGRIND_CALL_NOREDIR_R1                                \
+         "aghi 15,160\n\t"                                       \
+         VALGRIND_CFI_EPILOGUE                                   \
+         "lgr %0, 2\n\t"                                         \
+         : /*out*/   "=d" (_res)                                 \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER           \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"7"     \
+      );                                                         \
+      lval = (__typeof__(lval)) _res;                            \
+   } while (0)
+
+#define CALL_FN_W_WW(lval, orig, arg1, arg2)                     \
+   do {                                                          \
+      volatile OrigFn        _orig = (orig);                     \
+      volatile unsigned long _argvec[3];                         \
+      volatile unsigned long _res;                               \
+      _argvec[0] = (unsigned long)_orig.nraddr;                  \
+      _argvec[1] = (unsigned long)arg1;                          \
+      _argvec[2] = (unsigned long)arg2;                          \
+      __asm__ volatile(                                          \
+         VALGRIND_CFI_PROLOGUE                                   \
+         "aghi 15,-160\n\t"                                      \
+         "lg 2, 8(1)\n\t"                                        \
+         "lg 3,16(1)\n\t"                                        \
+         "lg 1, 0(1)\n\t"                                        \
+         VALGRIND_CALL_NOREDIR_R1                                \
+         "aghi 15,160\n\t"                                       \
+         VALGRIND_CFI_EPILOGUE                                   \
+         "lgr %0, 2\n\t"                                         \
+         : /*out*/   "=d" (_res)                                 \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER           \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"7"     \
+      );                                                         \
+      lval = (__typeof__(lval)) _res;                            \
+   } while (0)
+
+#define CALL_FN_W_WWW(lval, orig, arg1, arg2, arg3)              \
+   do {                                                          \
+      volatile OrigFn        _orig = (orig);                     \
+      volatile unsigned long _argvec[4];                         \
+      volatile unsigned long _res;                               \
+      _argvec[0] = (unsigned long)_orig.nraddr;                  \
+      _argvec[1] = (unsigned long)arg1;                          \
+      _argvec[2] = (unsigned long)arg2;                          \
+      _argvec[3] = (unsigned long)arg3;                          \
+      __asm__ volatile(                                          \
+         VALGRIND_CFI_PROLOGUE                                   \
+         "aghi 15,-160\n\t"                                      \
+         "lg 2, 8(1)\n\t"                                        \
+         "lg 3,16(1)\n\t"                                        \
+         "lg 4,24(1)\n\t"                                        \
+         "lg 1, 0(1)\n\t"                                        \
+         VALGRIND_CALL_NOREDIR_R1                                \
+         "aghi 15,160\n\t"                                       \
+         VALGRIND_CFI_EPILOGUE                                   \
+         "lgr %0, 2\n\t"                                         \
+         : /*out*/   "=d" (_res)                                 \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER           \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"7"     \
+      );                                                         \
+      lval = (__typeof__(lval)) _res;                            \
+   } while (0)
+
+#define CALL_FN_W_WWWW(lval, orig, arg1, arg2, arg3, arg4)       \
+   do {                                                          \
+      volatile OrigFn        _orig = (orig);                     \
+      volatile unsigned long _argvec[5];                         \
+      volatile unsigned long _res;                               \
+      _argvec[0] = (unsigned long)_orig.nraddr;                  \
+      _argvec[1] = (unsigned long)arg1;                          \
+      _argvec[2] = (unsigned long)arg2;                          \
+      _argvec[3] = (unsigned long)arg3;                          \
+      _argvec[4] = (unsigned long)arg4;                          \
+      __asm__ volatile(                                          \
+         VALGRIND_CFI_PROLOGUE                                   \
+         "aghi 15,-160\n\t"                                      \
+         "lg 2, 8(1)\n\t"                                        \
+         "lg 3,16(1)\n\t"                                        \
+         "lg 4,24(1)\n\t"                                        \
+         "lg 5,32(1)\n\t"                                        \
+         "lg 1, 0(1)\n\t"                                        \
+         VALGRIND_CALL_NOREDIR_R1                                \
+         "aghi 15,160\n\t"                                       \
+         VALGRIND_CFI_EPILOGUE                                   \
+         "lgr %0, 2\n\t"                                         \
+         : /*out*/   "=d" (_res)                                 \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER           \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"7"     \
+      );                                                         \
+      lval = (__typeof__(lval)) _res;                            \
+   } while (0)
+
+#define CALL_FN_W_5W(lval, orig, arg1, arg2, arg3, arg4, arg5)   \
+   do {                                                          \
+      volatile OrigFn        _orig = (orig);                     \
+      volatile unsigned long _argvec[6];                         \
+      volatile unsigned long _res;                               \
+      _argvec[0] = (unsigned long)_orig.nraddr;                  \
+      _argvec[1] = (unsigned long)arg1;                          \
+      _argvec[2] = (unsigned long)arg2;                          \
+      _argvec[3] = (unsigned long)arg3;                          \
+      _argvec[4] = (unsigned long)arg4;                          \
+      _argvec[5] = (unsigned long)arg5;                          \
+      __asm__ volatile(                                          \
+         VALGRIND_CFI_PROLOGUE                                   \
+         "aghi 15,-160\n\t"                                      \
+         "lg 2, 8(1)\n\t"                                        \
+         "lg 3,16(1)\n\t"                                        \
+         "lg 4,24(1)\n\t"                                        \
+         "lg 5,32(1)\n\t"                                        \
+         "lg 6,40(1)\n\t"                                        \
+         "lg 1, 0(1)\n\t"                                        \
+         VALGRIND_CALL_NOREDIR_R1                                \
+         "aghi 15,160\n\t"                                       \
+         VALGRIND_CFI_EPILOGUE                                   \
+         "lgr %0, 2\n\t"                                         \
+         : /*out*/   "=d" (_res)                                 \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER           \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"6","7" \
+      );                                                         \
+      lval = (__typeof__(lval)) _res;                            \
+   } while (0)
+
+#define CALL_FN_W_6W(lval, orig, arg1, arg2, arg3, arg4, arg5,   \
+                     arg6)                                       \
+   do {                                                          \
+      volatile OrigFn        _orig = (orig);                     \
+      volatile unsigned long _argvec[7];                         \
+      volatile unsigned long _res;                               \
+      _argvec[0] = (unsigned long)_orig.nraddr;                  \
+      _argvec[1] = (unsigned long)arg1;                          \
+      _argvec[2] = (unsigned long)arg2;                          \
+      _argvec[3] = (unsigned long)arg3;                          \
+      _argvec[4] = (unsigned long)arg4;                          \
+      _argvec[5] = (unsigned long)arg5;                          \
+      _argvec[6] = (unsigned long)arg6;                          \
+      __asm__ volatile(                                          \
+         VALGRIND_CFI_PROLOGUE                                   \
+         "aghi 15,-168\n\t"                                      \
+         "lg 2, 8(1)\n\t"                                        \
+         "lg 3,16(1)\n\t"                                        \
+         "lg 4,24(1)\n\t"                                        \
+         "lg 5,32(1)\n\t"                                        \
+         "lg 6,40(1)\n\t"                                        \
+         "mvc 160(8,15), 48(1)\n\t"                              \
+         "lg 1, 0(1)\n\t"                                        \
+         VALGRIND_CALL_NOREDIR_R1                                \
+         "aghi 15,168\n\t"                                       \
+         VALGRIND_CFI_EPILOGUE                                   \
+         "lgr %0, 2\n\t"                                         \
+         : /*out*/   "=d" (_res)                                 \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER           \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"6","7" \
+      );                                                         \
+      lval = (__typeof__(lval)) _res;                            \
+   } while (0)
+
+#define CALL_FN_W_7W(lval, orig, arg1, arg2, arg3, arg4, arg5,   \
+                     arg6, arg7)                                 \
+   do {                                                          \
+      volatile OrigFn        _orig = (orig);                     \
+      volatile unsigned long _argvec[8];                         \
+      volatile unsigned long _res;                               \
+      _argvec[0] = (unsigned long)_orig.nraddr;                  \
+      _argvec[1] = (unsigned long)arg1;                          \
+      _argvec[2] = (unsigned long)arg2;                          \
+      _argvec[3] = (unsigned long)arg3;                          \
+      _argvec[4] = (unsigned long)arg4;                          \
+      _argvec[5] = (unsigned long)arg5;                          \
+      _argvec[6] = (unsigned long)arg6;                          \
+      _argvec[7] = (unsigned long)arg7;                          \
+      __asm__ volatile(                                          \
+         VALGRIND_CFI_PROLOGUE                                   \
+         "aghi 15,-176\n\t"                                      \
+         "lg 2, 8(1)\n\t"                                        \
+         "lg 3,16(1)\n\t"                                        \
+         "lg 4,24(1)\n\t"                                        \
+         "lg 5,32(1)\n\t"                                        \
+         "lg 6,40(1)\n\t"                                        \
+         "mvc 160(8,15), 48(1)\n\t"                              \
+         "mvc 168(8,15), 56(1)\n\t"                              \
+         "lg 1, 0(1)\n\t"                                        \
+         VALGRIND_CALL_NOREDIR_R1                                \
+         "aghi 15,176\n\t"                                       \
+         VALGRIND_CFI_EPILOGUE                                   \
+         "lgr %0, 2\n\t"                                         \
+         : /*out*/   "=d" (_res)                                 \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER           \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"6","7" \
+      );                                                         \
+      lval = (__typeof__(lval)) _res;                            \
+   } while (0)
+
+#define CALL_FN_W_8W(lval, orig, arg1, arg2, arg3, arg4, arg5,   \
+                     arg6, arg7 ,arg8)                           \
+   do {                                                          \
+      volatile OrigFn        _orig = (orig);                     \
+      volatile unsigned long _argvec[9];                         \
+      volatile unsigned long _res;                               \
+      _argvec[0] = (unsigned long)_orig.nraddr;                  \
+      _argvec[1] = (unsigned long)arg1;                          \
+      _argvec[2] = (unsigned long)arg2;                          \
+      _argvec[3] = (unsigned long)arg3;                          \
+      _argvec[4] = (unsigned long)arg4;                          \
+      _argvec[5] = (unsigned long)arg5;                          \
+      _argvec[6] = (unsigned long)arg6;                          \
+      _argvec[7] = (unsigned long)arg7;                          \
+      _argvec[8] = (unsigned long)arg8;                          \
+      __asm__ volatile(                                          \
+         VALGRIND_CFI_PROLOGUE                                   \
+         "aghi 15,-184\n\t"                                      \
+         "lg 2, 8(1)\n\t"                                        \
+         "lg 3,16(1)\n\t"                                        \
+         "lg 4,24(1)\n\t"                                        \
+         "lg 5,32(1)\n\t"                                        \
+         "lg 6,40(1)\n\t"                                        \
+         "mvc 160(8,15), 48(1)\n\t"                              \
+         "mvc 168(8,15), 56(1)\n\t"                              \
+         "mvc 176(8,15), 64(1)\n\t"                              \
+         "lg 1, 0(1)\n\t"                                        \
+         VALGRIND_CALL_NOREDIR_R1                                \
+         "aghi 15,184\n\t"                                       \
+         VALGRIND_CFI_EPILOGUE                                   \
+         "lgr %0, 2\n\t"                                         \
+         : /*out*/   "=d" (_res)                                 \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER           \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"6","7" \
+      );                                                         \
+      lval = (__typeof__(lval)) _res;                            \
+   } while (0)
+
+#define CALL_FN_W_9W(lval, orig, arg1, arg2, arg3, arg4, arg5,   \
+                     arg6, arg7 ,arg8, arg9)                     \
+   do {                                                          \
+      volatile OrigFn        _orig = (orig);                     \
+      volatile unsigned long _argvec[10];                        \
+      volatile unsigned long _res;                               \
+      _argvec[0] = (unsigned long)_orig.nraddr;                  \
+      _argvec[1] = (unsigned long)arg1;                          \
+      _argvec[2] = (unsigned long)arg2;                          \
+      _argvec[3] = (unsigned long)arg3;                          \
+      _argvec[4] = (unsigned long)arg4;                          \
+      _argvec[5] = (unsigned long)arg5;                          \
+      _argvec[6] = (unsigned long)arg6;                          \
+      _argvec[7] = (unsigned long)arg7;                          \
+      _argvec[8] = (unsigned long)arg8;                          \
+      _argvec[9] = (unsigned long)arg9;                          \
+      __asm__ volatile(                                          \
+         VALGRIND_CFI_PROLOGUE                                   \
+         "aghi 15,-192\n\t"                                      \
+         "lg 2, 8(1)\n\t"                                        \
+         "lg 3,16(1)\n\t"                                        \
+         "lg 4,24(1)\n\t"                                        \
+         "lg 5,32(1)\n\t"                                        \
+         "lg 6,40(1)\n\t"                                        \
+         "mvc 160(8,15), 48(1)\n\t"                              \
+         "mvc 168(8,15), 56(1)\n\t"                              \
+         "mvc 176(8,15), 64(1)\n\t"                              \
+         "mvc 184(8,15), 72(1)\n\t"                              \
+         "lg 1, 0(1)\n\t"                                        \
+         VALGRIND_CALL_NOREDIR_R1                                \
+         "aghi 15,192\n\t"                                       \
+         VALGRIND_CFI_EPILOGUE                                   \
+         "lgr %0, 2\n\t"                                         \
+         : /*out*/   "=d" (_res)                                 \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER           \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"6","7" \
+      );                                                         \
+      lval = (__typeof__(lval)) _res;                            \
+   } while (0)
+
+#define CALL_FN_W_10W(lval, orig, arg1, arg2, arg3, arg4, arg5,  \
+                     arg6, arg7 ,arg8, arg9, arg10)              \
+   do {                                                          \
+      volatile OrigFn        _orig = (orig);                     \
+      volatile unsigned long _argvec[11];                        \
+      volatile unsigned long _res;                               \
+      _argvec[0] = (unsigned long)_orig.nraddr;                  \
+      _argvec[1] = (unsigned long)arg1;                          \
+      _argvec[2] = (unsigned long)arg2;                          \
+      _argvec[3] = (unsigned long)arg3;                          \
+      _argvec[4] = (unsigned long)arg4;                          \
+      _argvec[5] = (unsigned long)arg5;                          \
+      _argvec[6] = (unsigned long)arg6;                          \
+      _argvec[7] = (unsigned long)arg7;                          \
+      _argvec[8] = (unsigned long)arg8;                          \
+      _argvec[9] = (unsigned long)arg9;                          \
+      _argvec[10] = (unsigned long)arg10;                        \
+      __asm__ volatile(                                          \
+         VALGRIND_CFI_PROLOGUE                                   \
+         "aghi 15,-200\n\t"                                      \
+         "lg 2, 8(1)\n\t"                                        \
+         "lg 3,16(1)\n\t"                                        \
+         "lg 4,24(1)\n\t"                                        \
+         "lg 5,32(1)\n\t"                                        \
+         "lg 6,40(1)\n\t"                                        \
+         "mvc 160(8,15), 48(1)\n\t"                              \
+         "mvc 168(8,15), 56(1)\n\t"                              \
+         "mvc 176(8,15), 64(1)\n\t"                              \
+         "mvc 184(8,15), 72(1)\n\t"                              \
+         "mvc 192(8,15), 80(1)\n\t"                              \
+         "lg 1, 0(1)\n\t"                                        \
+         VALGRIND_CALL_NOREDIR_R1                                \
+         "aghi 15,200\n\t"                                       \
+         VALGRIND_CFI_EPILOGUE                                   \
+         "lgr %0, 2\n\t"                                         \
+         : /*out*/   "=d" (_res)                                 \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER           \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"6","7" \
+      );                                                         \
+      lval = (__typeof__(lval)) _res;                            \
+   } while (0)
+
+#define CALL_FN_W_11W(lval, orig, arg1, arg2, arg3, arg4, arg5,  \
+                     arg6, arg7 ,arg8, arg9, arg10, arg11)       \
+   do {                                                          \
+      volatile OrigFn        _orig = (orig);                     \
+      volatile unsigned long _argvec[12];                        \
+      volatile unsigned long _res;                               \
+      _argvec[0] = (unsigned long)_orig.nraddr;                  \
+      _argvec[1] = (unsigned long)arg1;                          \
+      _argvec[2] = (unsigned long)arg2;                          \
+      _argvec[3] = (unsigned long)arg3;                          \
+      _argvec[4] = (unsigned long)arg4;                          \
+      _argvec[5] = (unsigned long)arg5;                          \
+      _argvec[6] = (unsigned long)arg6;                          \
+      _argvec[7] = (unsigned long)arg7;                          \
+      _argvec[8] = (unsigned long)arg8;                          \
+      _argvec[9] = (unsigned long)arg9;                          \
+      _argvec[10] = (unsigned long)arg10;                        \
+      _argvec[11] = (unsigned long)arg11;                        \
+      __asm__ volatile(                                          \
+         VALGRIND_CFI_PROLOGUE                                   \
+         "aghi 15,-208\n\t"                                      \
+         "lg 2, 8(1)\n\t"                                        \
+         "lg 3,16(1)\n\t"                                        \
+         "lg 4,24(1)\n\t"                                        \
+         "lg 5,32(1)\n\t"                                        \
+         "lg 6,40(1)\n\t"                                        \
+         "mvc 160(8,15), 48(1)\n\t"                              \
+         "mvc 168(8,15), 56(1)\n\t"                              \
+         "mvc 176(8,15), 64(1)\n\t"                              \
+         "mvc 184(8,15), 72(1)\n\t"                              \
+         "mvc 192(8,15), 80(1)\n\t"                              \
+         "mvc 200(8,15), 88(1)\n\t"                              \
+         "lg 1, 0(1)\n\t"                                        \
+         VALGRIND_CALL_NOREDIR_R1                                \
+         "aghi 15,208\n\t"                                       \
+         VALGRIND_CFI_EPILOGUE                                   \
+         "lgr %0, 2\n\t"                                         \
+         : /*out*/   "=d" (_res)                                 \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER           \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"6","7" \
+      );                                                         \
+      lval = (__typeof__(lval)) _res;                            \
+   } while (0)
+
+#define CALL_FN_W_12W(lval, orig, arg1, arg2, arg3, arg4, arg5,  \
+                     arg6, arg7 ,arg8, arg9, arg10, arg11, arg12)\
+   do {                                                          \
+      volatile OrigFn        _orig = (orig);                     \
+      volatile unsigned long _argvec[13];                        \
+      volatile unsigned long _res;                               \
+      _argvec[0] = (unsigned long)_orig.nraddr;                  \
+      _argvec[1] = (unsigned long)arg1;                          \
+      _argvec[2] = (unsigned long)arg2;                          \
+      _argvec[3] = (unsigned long)arg3;                          \
+      _argvec[4] = (unsigned long)arg4;                          \
+      _argvec[5] = (unsigned long)arg5;                          \
+      _argvec[6] = (unsigned long)arg6;                          \
+      _argvec[7] = (unsigned long)arg7;                          \
+      _argvec[8] = (unsigned long)arg8;                          \
+      _argvec[9] = (unsigned long)arg9;                          \
+      _argvec[10] = (unsigned long)arg10;                        \
+      _argvec[11] = (unsigned long)arg11;                        \
+      _argvec[12] = (unsigned long)arg12;                        \
+      __asm__ volatile(                                          \
+         VALGRIND_CFI_PROLOGUE                                   \
+         "aghi 15,-216\n\t"                                      \
+         "lg 2, 8(1)\n\t"                                        \
+         "lg 3,16(1)\n\t"                                        \
+         "lg 4,24(1)\n\t"                                        \
+         "lg 5,32(1)\n\t"                                        \
+         "lg 6,40(1)\n\t"                                        \
+         "mvc 160(8,15), 48(1)\n\t"                              \
+         "mvc 168(8,15), 56(1)\n\t"                              \
+         "mvc 176(8,15), 64(1)\n\t"                              \
+         "mvc 184(8,15), 72(1)\n\t"                              \
+         "mvc 192(8,15), 80(1)\n\t"                              \
+         "mvc 200(8,15), 88(1)\n\t"                              \
+         "mvc 208(8,15), 96(1)\n\t"                              \
+         "lg 1, 0(1)\n\t"                                        \
+         VALGRIND_CALL_NOREDIR_R1                                \
+         "aghi 15,216\n\t"                                       \
+         VALGRIND_CFI_EPILOGUE                                   \
+         "lgr %0, 2\n\t"                                         \
+         : /*out*/   "=d" (_res)                                 \
+         : /*in*/    "a" (&_argvec[0]) __FRAME_POINTER           \
+         : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"6","7" \
+      );                                                         \
+      lval = (__typeof__(lval)) _res;                            \
+   } while (0)
+
+
+#endif /* PLAT_s390x_linux */
+
+/* ------------------------- mips32-linux ----------------------- */
+
+#if defined(PLAT_mips32_linux)
+
+/* These regs are trashed by the hidden call. */
+#define __CALLER_SAVED_REGS "$2", "$3", "$4", "$5", "$6",       \
+"$7", "$8", "$9", "$10", "$11", "$12", "$13", "$14", "$15", "$24", \
+"$25", "$31"
+
+/* These CALL_FN_ macros assume that on mips-linux, sizeof(unsigned
+   long) == 4. */
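+
+/* Illustrative usage (a hedged sketch, not part of the original
+   header): a typical Valgrind function wrapper obtains the original
+   function with VALGRIND_GET_ORIG_FN (defined earlier in this file)
+   and dispatches through the CALL_FN_ macro matching the arity.  The
+   wrapped name "foo" and soname "NONE" below are placeholders:
+
+      int I_WRAP_SONAME_FN_ZU(NONE, foo)(int x, int y)
+      {
+         int    result;
+         OrigFn fn;
+         VALGRIND_GET_ORIG_FN(fn);        // recover the real foo
+         CALL_FN_W_WW(result, fn, x, y);  // call it, bypassing redirection
+         return result;
+      }
+*/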
+
+#define CALL_FN_W_v(lval, orig)                                   \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[1];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      __asm__ volatile(                                           \
+         "subu $29, $29, 8 \n\t"                                  \
+         "sw $28, 0($29) \n\t"                                    \
+         "sw $31, 4($29) \n\t"                                    \
+         "subu $29, $29, 16 \n\t"                                 \
+         "lw $25, 0(%1) \n\t"  /* target->t9 */                   \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "addu $29, $29, 16\n\t"                                  \
+         "lw $28, 0($29) \n\t"                                    \
+         "lw $31, 4($29) \n\t"                                    \
+         "addu $29, $29, 8 \n\t"                                  \
+         "move %0, $2\n"                                          \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_W(lval, orig, arg1)                             \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[2];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      __asm__ volatile(                                           \
+         "subu $29, $29, 8 \n\t"                                  \
+         "sw $28, 0($29) \n\t"                                    \
+         "sw $31, 4($29) \n\t"                                    \
+         "subu $29, $29, 16 \n\t"                                 \
+         "lw $4, 4(%1) \n\t"   /* arg1*/                          \
+         "lw $25, 0(%1) \n\t"  /* target->t9 */                   \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "addu $29, $29, 16 \n\t"                                 \
+         "lw $28, 0($29) \n\t"                                    \
+         "lw $31, 4($29) \n\t"                                    \
+         "addu $29, $29, 8 \n\t"                                  \
+         "move %0, $2\n"                                          \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "memory",  __CALLER_SAVED_REGS               \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_WW(lval, orig, arg1,arg2)                       \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      __asm__ volatile(                                           \
+         "subu $29, $29, 8 \n\t"                                  \
+         "sw $28, 0($29) \n\t"                                    \
+         "sw $31, 4($29) \n\t"                                    \
+         "subu $29, $29, 16 \n\t"                                 \
+         "lw $4, 4(%1) \n\t"                                      \
+         "lw $5, 8(%1) \n\t"                                      \
+         "lw $25, 0(%1) \n\t"  /* target->t9 */                   \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "addu $29, $29, 16 \n\t"                                 \
+         "lw $28, 0($29) \n\t"                                    \
+         "lw $31, 4($29) \n\t"                                    \
+         "addu $29, $29, 8 \n\t"                                  \
+         "move %0, $2\n"                                          \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3)                 \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[4];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      __asm__ volatile(                                           \
+         "subu $29, $29, 8 \n\t"                                  \
+         "sw $28, 0($29) \n\t"                                    \
+         "sw $31, 4($29) \n\t"                                    \
+         "subu $29, $29, 16 \n\t"                                 \
+         "lw $4, 4(%1) \n\t"                                      \
+         "lw $5, 8(%1) \n\t"                                      \
+         "lw $6, 12(%1) \n\t"                                     \
+         "lw $25, 0(%1) \n\t"  /* target->t9 */                   \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "addu $29, $29, 16 \n\t"                                 \
+         "lw $28, 0($29) \n\t"                                    \
+         "lw $31, 4($29) \n\t"                                    \
+         "addu $29, $29, 8 \n\t"                                  \
+         "move %0, $2\n"                                          \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4)           \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[5];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      __asm__ volatile(                                           \
+         "subu $29, $29, 8 \n\t"                                  \
+         "sw $28, 0($29) \n\t"                                    \
+         "sw $31, 4($29) \n\t"                                    \
+         "subu $29, $29, 16 \n\t"                                 \
+         "lw $4, 4(%1) \n\t"                                      \
+         "lw $5, 8(%1) \n\t"                                      \
+         "lw $6, 12(%1) \n\t"                                     \
+         "lw $7, 16(%1) \n\t"                                     \
+         "lw $25, 0(%1) \n\t"  /* target->t9 */                   \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "addu $29, $29, 16 \n\t"                                 \
+         "lw $28, 0($29) \n\t"                                    \
+         "lw $31, 4($29) \n\t"                                    \
+         "addu $29, $29, 8 \n\t"                                  \
+         "move %0, $2\n"                                          \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5)        \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[6];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      __asm__ volatile(                                           \
+         "subu $29, $29, 8 \n\t"                                  \
+         "sw $28, 0($29) \n\t"                                    \
+         "sw $31, 4($29) \n\t"                                    \
+         "lw $4, 20(%1) \n\t"                                     \
+         "subu $29, $29, 24\n\t"                                  \
+         "sw $4, 16($29) \n\t"                                    \
+         "lw $4, 4(%1) \n\t"                                      \
+         "lw $5, 8(%1) \n\t"                                      \
+         "lw $6, 12(%1) \n\t"                                     \
+         "lw $7, 16(%1) \n\t"                                     \
+         "lw $25, 0(%1) \n\t"  /* target->t9 */                   \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "addu $29, $29, 24 \n\t"                                 \
+         "lw $28, 0($29) \n\t"                                    \
+         "lw $31, 4($29) \n\t"                                    \
+         "addu $29, $29, 8 \n\t"                                  \
+         "move %0, $2\n"                                          \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6)   \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[7];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      __asm__ volatile(                                           \
+         "subu $29, $29, 8 \n\t"                                  \
+         "sw $28, 0($29) \n\t"                                    \
+         "sw $31, 4($29) \n\t"                                    \
+         "lw $4, 20(%1) \n\t"                                     \
+         "subu $29, $29, 32\n\t"                                  \
+         "sw $4, 16($29) \n\t"                                    \
+         "lw $4, 24(%1) \n\t"                                     \
+         "nop\n\t"                                                \
+         "sw $4, 20($29) \n\t"                                    \
+         "lw $4, 4(%1) \n\t"                                      \
+         "lw $5, 8(%1) \n\t"                                      \
+         "lw $6, 12(%1) \n\t"                                     \
+         "lw $7, 16(%1) \n\t"                                     \
+         "lw $25, 0(%1) \n\t"  /* target->t9 */                   \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "addu $29, $29, 32 \n\t"                                 \
+         "lw $28, 0($29) \n\t"                                    \
+         "lw $31, 4($29) \n\t"                                    \
+         "addu $29, $29, 8 \n\t"                                  \
+         "move %0, $2\n"                                          \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7)                            \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[8];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      __asm__ volatile(                                           \
+         "subu $29, $29, 8 \n\t"                                  \
+         "sw $28, 0($29) \n\t"                                    \
+         "sw $31, 4($29) \n\t"                                    \
+         "lw $4, 20(%1) \n\t"                                     \
+         "subu $29, $29, 32\n\t"                                  \
+         "sw $4, 16($29) \n\t"                                    \
+         "lw $4, 24(%1) \n\t"                                     \
+         "sw $4, 20($29) \n\t"                                    \
+         "lw $4, 28(%1) \n\t"                                     \
+         "sw $4, 24($29) \n\t"                                    \
+         "lw $4, 4(%1) \n\t"                                      \
+         "lw $5, 8(%1) \n\t"                                      \
+         "lw $6, 12(%1) \n\t"                                     \
+         "lw $7, 16(%1) \n\t"                                     \
+         "lw $25, 0(%1) \n\t"  /* target->t9 */                   \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "addu $29, $29, 32 \n\t"                                 \
+         "lw $28, 0($29) \n\t"                                    \
+         "lw $31, 4($29) \n\t"                                    \
+         "addu $29, $29, 8 \n\t"                                  \
+         "move %0, $2\n"                                          \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7,arg8)                       \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[9];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      __asm__ volatile(                                           \
+         "subu $29, $29, 8 \n\t"                                  \
+         "sw $28, 0($29) \n\t"                                    \
+         "sw $31, 4($29) \n\t"                                    \
+         "lw $4, 20(%1) \n\t"                                     \
+         "subu $29, $29, 40\n\t"                                  \
+         "sw $4, 16($29) \n\t"                                    \
+         "lw $4, 24(%1) \n\t"                                     \
+         "sw $4, 20($29) \n\t"                                    \
+         "lw $4, 28(%1) \n\t"                                     \
+         "sw $4, 24($29) \n\t"                                    \
+         "lw $4, 32(%1) \n\t"                                     \
+         "sw $4, 28($29) \n\t"                                    \
+         "lw $4, 4(%1) \n\t"                                      \
+         "lw $5, 8(%1) \n\t"                                      \
+         "lw $6, 12(%1) \n\t"                                     \
+         "lw $7, 16(%1) \n\t"                                     \
+         "lw $25, 0(%1) \n\t"  /* target->t9 */                   \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "addu $29, $29, 40 \n\t"                                 \
+         "lw $28, 0($29) \n\t"                                    \
+         "lw $31, 4($29) \n\t"                                    \
+         "addu $29, $29, 8 \n\t"                                  \
+         "move %0, $2\n"                                          \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7,arg8,arg9)                  \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[10];                         \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      _argvec[9] = (unsigned long)(arg9);                         \
+      __asm__ volatile(                                           \
+         "subu $29, $29, 8 \n\t"                                  \
+         "sw $28, 0($29) \n\t"                                    \
+         "sw $31, 4($29) \n\t"                                    \
+         "lw $4, 20(%1) \n\t"                                     \
+         "subu $29, $29, 40\n\t"                                  \
+         "sw $4, 16($29) \n\t"                                    \
+         "lw $4, 24(%1) \n\t"                                     \
+         "sw $4, 20($29) \n\t"                                    \
+         "lw $4, 28(%1) \n\t"                                     \
+         "sw $4, 24($29) \n\t"                                    \
+         "lw $4, 32(%1) \n\t"                                     \
+         "sw $4, 28($29) \n\t"                                    \
+         "lw $4, 36(%1) \n\t"                                     \
+         "sw $4, 32($29) \n\t"                                    \
+         "lw $4, 4(%1) \n\t"                                      \
+         "lw $5, 8(%1) \n\t"                                      \
+         "lw $6, 12(%1) \n\t"                                     \
+         "lw $7, 16(%1) \n\t"                                     \
+         "lw $25, 0(%1) \n\t"  /* target->t9 */                   \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "addu $29, $29, 40 \n\t"                                 \
+         "lw $28, 0($29) \n\t"                                    \
+         "lw $31, 4($29) \n\t"                                    \
+         "addu $29, $29, 8 \n\t"                                  \
+         "move %0, $2\n"                                          \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,  \
+                                  arg7,arg8,arg9,arg10)           \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[11];                         \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      _argvec[9] = (unsigned long)(arg9);                         \
+      _argvec[10] = (unsigned long)(arg10);                       \
+      __asm__ volatile(                                           \
+         "subu $29, $29, 8 \n\t"                                  \
+         "sw $28, 0($29) \n\t"                                    \
+         "sw $31, 4($29) \n\t"                                    \
+         "lw $4, 20(%1) \n\t"                                     \
+         "subu $29, $29, 48\n\t"                                  \
+         "sw $4, 16($29) \n\t"                                    \
+         "lw $4, 24(%1) \n\t"                                     \
+         "sw $4, 20($29) \n\t"                                    \
+         "lw $4, 28(%1) \n\t"                                     \
+         "sw $4, 24($29) \n\t"                                    \
+         "lw $4, 32(%1) \n\t"                                     \
+         "sw $4, 28($29) \n\t"                                    \
+         "lw $4, 36(%1) \n\t"                                     \
+         "sw $4, 32($29) \n\t"                                    \
+         "lw $4, 40(%1) \n\t"                                     \
+         "sw $4, 36($29) \n\t"                                    \
+         "lw $4, 4(%1) \n\t"                                      \
+         "lw $5, 8(%1) \n\t"                                      \
+         "lw $6, 12(%1) \n\t"                                     \
+         "lw $7, 16(%1) \n\t"                                     \
+         "lw $25, 0(%1) \n\t"  /* target->t9 */                   \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "addu $29, $29, 48 \n\t"                                 \
+         "lw $28, 0($29) \n\t"                                    \
+         "lw $31, 4($29) \n\t"                                    \
+         "addu $29, $29, 8 \n\t"                                  \
+         "move %0, $2\n"                                          \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5,       \
+                                  arg6,arg7,arg8,arg9,arg10,      \
+                                  arg11)                          \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[12];                         \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      _argvec[9] = (unsigned long)(arg9);                         \
+      _argvec[10] = (unsigned long)(arg10);                       \
+      _argvec[11] = (unsigned long)(arg11);                       \
+      __asm__ volatile(                                           \
+         "subu $29, $29, 8 \n\t"                                  \
+         "sw $28, 0($29) \n\t"                                    \
+         "sw $31, 4($29) \n\t"                                    \
+         "lw $4, 20(%1) \n\t"                                     \
+         "subu $29, $29, 48\n\t"                                  \
+         "sw $4, 16($29) \n\t"                                    \
+         "lw $4, 24(%1) \n\t"                                     \
+         "sw $4, 20($29) \n\t"                                    \
+         "lw $4, 28(%1) \n\t"                                     \
+         "sw $4, 24($29) \n\t"                                    \
+         "lw $4, 32(%1) \n\t"                                     \
+         "sw $4, 28($29) \n\t"                                    \
+         "lw $4, 36(%1) \n\t"                                     \
+         "sw $4, 32($29) \n\t"                                    \
+         "lw $4, 40(%1) \n\t"                                     \
+         "sw $4, 36($29) \n\t"                                    \
+         "lw $4, 44(%1) \n\t"                                     \
+         "sw $4, 40($29) \n\t"                                    \
+         "lw $4, 4(%1) \n\t"                                      \
+         "lw $5, 8(%1) \n\t"                                      \
+         "lw $6, 12(%1) \n\t"                                     \
+         "lw $7, 16(%1) \n\t"                                     \
+         "lw $25, 0(%1) \n\t"  /* target->t9 */                   \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "addu $29, $29, 48 \n\t"                                 \
+         "lw $28, 0($29) \n\t"                                    \
+         "lw $31, 4($29) \n\t"                                    \
+         "addu $29, $29, 8 \n\t"                                  \
+         "move %0, $2\n"                                          \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5,       \
+                                  arg6,arg7,arg8,arg9,arg10,      \
+                                  arg11,arg12)                    \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[13];                         \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      _argvec[9] = (unsigned long)(arg9);                         \
+      _argvec[10] = (unsigned long)(arg10);                       \
+      _argvec[11] = (unsigned long)(arg11);                       \
+      _argvec[12] = (unsigned long)(arg12);                       \
+      __asm__ volatile(                                           \
+         "subu $29, $29, 8 \n\t"                                  \
+         "sw $28, 0($29) \n\t"                                    \
+         "sw $31, 4($29) \n\t"                                    \
+         "lw $4, 20(%1) \n\t"                                     \
+         "subu $29, $29, 56\n\t"                                  \
+         "sw $4, 16($29) \n\t"                                    \
+         "lw $4, 24(%1) \n\t"                                     \
+         "sw $4, 20($29) \n\t"                                    \
+         "lw $4, 28(%1) \n\t"                                     \
+         "sw $4, 24($29) \n\t"                                    \
+         "lw $4, 32(%1) \n\t"                                     \
+         "sw $4, 28($29) \n\t"                                    \
+         "lw $4, 36(%1) \n\t"                                     \
+         "sw $4, 32($29) \n\t"                                    \
+         "lw $4, 40(%1) \n\t"                                     \
+         "sw $4, 36($29) \n\t"                                    \
+         "lw $4, 44(%1) \n\t"                                     \
+         "sw $4, 40($29) \n\t"                                    \
+         "lw $4, 48(%1) \n\t"                                     \
+         "sw $4, 44($29) \n\t"                                    \
+         "lw $4, 4(%1) \n\t"                                      \
+         "lw $5, 8(%1) \n\t"                                      \
+         "lw $6, 12(%1) \n\t"                                     \
+         "lw $7, 16(%1) \n\t"                                     \
+         "lw $25, 0(%1) \n\t"  /* target->t9 */                   \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "addu $29, $29, 56 \n\t"                                 \
+         "lw $28, 0($29) \n\t"                                    \
+         "lw $31, 4($29) \n\t"                                    \
+         "addu $29, $29, 8 \n\t"                                  \
+         "move %0, $2\n"                                          \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#endif /* PLAT_mips32_linux */
+
+/* ------------------------- nanomips-linux -------------------- */
+
+#if defined(PLAT_nanomips_linux)
+
+/* These regs are trashed by the hidden call. */
+#define __CALLER_SAVED_REGS "$t4", "$t5", "$a0", "$a1", "$a2",     \
+"$a3", "$a4", "$a5", "$a6", "$a7", "$t0", "$t1", "$t2", "$t3",     \
+"$t8","$t9", "$at"
+
+/* These CALL_FN_ macros assume that on nanomips-linux, sizeof(unsigned
+   long) == 4. */
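+
+/* A hedged usage sketch (not part of the original header): on
+   nanomips the first eight word-sized arguments travel in $a0..$a7,
+   so the macros up to CALL_FN_W_8W below need only register loads.
+   A five-argument wrapper, with placeholder name "bar" and soname
+   "NONE", would look like:
+
+      long I_WRAP_SONAME_FN_ZU(NONE, bar)(long a, long b, long c,
+                                          long d, long e)
+      {
+         long   result;
+         OrigFn fn;
+         VALGRIND_GET_ORIG_FN(fn);
+         CALL_FN_W_5W(result, fn, a, b, c, d, e);
+         return result;
+      }
+*/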
+
+#define CALL_FN_W_v(lval, orig)                                   \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[1];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      __asm__ volatile(                                           \
+         "lw $t9, 0(%1)\n\t"                                      \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "move %0, $a0\n"                                         \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_W(lval, orig, arg1)                             \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[2];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      __asm__ volatile(                                           \
+         "lw $t9, 0(%1)\n\t"                                      \
+         "lw $a0, 4(%1)\n\t"                                      \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "move %0, $a0\n"                                         \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_WW(lval, orig, arg1,arg2)                       \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[3];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      __asm__ volatile(                                           \
+         "lw $t9, 0(%1)\n\t"                                      \
+         "lw $a0, 4(%1)\n\t"                                      \
+         "lw $a1, 8(%1)\n\t"                                      \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "move %0, $a0\n"                                         \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3)                 \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[4];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      __asm__ volatile(                                           \
+         "lw $t9, 0(%1)\n\t"                                      \
+         "lw $a0, 4(%1)\n\t"                                      \
+         "lw $a1, 8(%1)\n\t"                                      \
+         "lw $a2,12(%1)\n\t"                                      \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "move %0, $a0\n"                                         \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4)           \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[5];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      __asm__ volatile(                                           \
+         "lw $t9, 0(%1)\n\t"                                      \
+         "lw $a0, 4(%1)\n\t"                                      \
+         "lw $a1, 8(%1)\n\t"                                      \
+         "lw $a2,12(%1)\n\t"                                      \
+         "lw $a3,16(%1)\n\t"                                      \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "move %0, $a0\n"                                         \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5)        \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[6];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      __asm__ volatile(                                           \
+         "lw $t9, 0(%1)\n\t"                                      \
+         "lw $a0, 4(%1)\n\t"                                      \
+         "lw $a1, 8(%1)\n\t"                                      \
+         "lw $a2,12(%1)\n\t"                                      \
+         "lw $a3,16(%1)\n\t"                                      \
+         "lw $a4,20(%1)\n\t"                                      \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "move %0, $a0\n"                                         \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6)   \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[7];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      __asm__ volatile(                                           \
+         "lw $t9, 0(%1)\n\t"                                      \
+         "lw $a0, 4(%1)\n\t"                                      \
+         "lw $a1, 8(%1)\n\t"                                      \
+         "lw $a2,12(%1)\n\t"                                      \
+         "lw $a3,16(%1)\n\t"                                      \
+         "lw $a4,20(%1)\n\t"                                      \
+         "lw $a5,24(%1)\n\t"                                      \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "move %0, $a0\n"                                         \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7)                            \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[8];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      __asm__ volatile(                                           \
+         "lw $t9, 0(%1)\n\t"                                      \
+         "lw $a0, 4(%1)\n\t"                                      \
+         "lw $a1, 8(%1)\n\t"                                      \
+         "lw $a2,12(%1)\n\t"                                      \
+         "lw $a3,16(%1)\n\t"                                      \
+         "lw $a4,20(%1)\n\t"                                      \
+         "lw $a5,24(%1)\n\t"                                      \
+         "lw $a6,28(%1)\n\t"                                      \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "move %0, $a0\n"                                         \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7,arg8)                       \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[9];                          \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      __asm__ volatile(                                           \
+         "lw $t9, 0(%1)\n\t"                                      \
+         "lw $a0, 4(%1)\n\t"                                      \
+         "lw $a1, 8(%1)\n\t"                                      \
+         "lw $a2,12(%1)\n\t"                                      \
+         "lw $a3,16(%1)\n\t"                                      \
+         "lw $a4,20(%1)\n\t"                                      \
+         "lw $a5,24(%1)\n\t"                                      \
+         "lw $a6,28(%1)\n\t"                                      \
+         "lw $a7,32(%1)\n\t"                                      \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "move %0, $a0\n"                                         \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7,arg8,arg9)                  \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[10];                         \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      _argvec[9] = (unsigned long)(arg9);                         \
+      __asm__ volatile(                                           \
+         "addiu $sp, $sp, -16  \n\t"                              \
+         "lw $t9,36(%1)        \n\t"                              \
+         "sw $t9, 0($sp)       \n\t"                              \
+         "lw $t9, 0(%1)        \n\t"                              \
+         "lw $a0, 4(%1)        \n\t"                              \
+         "lw $a1, 8(%1)        \n\t"                              \
+         "lw $a2,12(%1)        \n\t"                              \
+         "lw $a3,16(%1)        \n\t"                              \
+         "lw $a4,20(%1)        \n\t"                              \
+         "lw $a5,24(%1)        \n\t"                              \
+         "lw $a6,28(%1)        \n\t"                              \
+         "lw $a7,32(%1)        \n\t"                              \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "move %0, $a0         \n\t"                              \
+         "addiu $sp, $sp, 16   \n\t"                              \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,  \
+                                  arg7,arg8,arg9,arg10)           \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[11];                         \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      _argvec[9] = (unsigned long)(arg9);                         \
+      _argvec[10] = (unsigned long)(arg10);                       \
+      __asm__ volatile(                                           \
+         "addiu $sp, $sp, -16  \n\t"                              \
+         "lw $t9,36(%1)        \n\t"                              \
+         "sw $t9, 0($sp)       \n\t"                              \
+         "lw $t9,40(%1)        \n\t"                              \
+         "sw $t9, 4($sp)       \n\t"                              \
+         "lw $t9, 0(%1)        \n\t"                              \
+         "lw $a0, 4(%1)        \n\t"                              \
+         "lw $a1, 8(%1)        \n\t"                              \
+         "lw $a2,12(%1)        \n\t"                              \
+         "lw $a3,16(%1)        \n\t"                              \
+         "lw $a4,20(%1)        \n\t"                              \
+         "lw $a5,24(%1)        \n\t"                              \
+         "lw $a6,28(%1)        \n\t"                              \
+         "lw $a7,32(%1)        \n\t"                              \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "move %0, $a0         \n\t"                              \
+         "addiu $sp, $sp, 16   \n\t"                              \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5,       \
+                                  arg6,arg7,arg8,arg9,arg10,      \
+                                  arg11)                          \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[12];                         \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      _argvec[9] = (unsigned long)(arg9);                         \
+      _argvec[10] = (unsigned long)(arg10);                       \
+      _argvec[11] = (unsigned long)(arg11);                       \
+      __asm__ volatile(                                           \
+         "addiu $sp, $sp, -16  \n\t"                              \
+         "lw $t9,36(%1)        \n\t"                              \
+         "sw $t9, 0($sp)       \n\t"                              \
+         "lw $t9,40(%1)        \n\t"                              \
+         "sw $t9, 4($sp)       \n\t"                              \
+         "lw $t9,44(%1)        \n\t"                              \
+         "sw $t9, 8($sp)       \n\t"                              \
+         "lw $t9, 0(%1)        \n\t"                              \
+         "lw $a0, 4(%1)        \n\t"                              \
+         "lw $a1, 8(%1)        \n\t"                              \
+         "lw $a2,12(%1)        \n\t"                              \
+         "lw $a3,16(%1)        \n\t"                              \
+         "lw $a4,20(%1)        \n\t"                              \
+         "lw $a5,24(%1)        \n\t"                              \
+         "lw $a6,28(%1)        \n\t"                              \
+         "lw $a7,32(%1)        \n\t"                              \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "move %0, $a0         \n\t"                              \
+         "addiu $sp, $sp, 16   \n\t"                              \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5,       \
+                                  arg6,arg7,arg8,arg9,arg10,      \
+                                  arg11,arg12)                    \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long _argvec[13];                         \
+      volatile unsigned long _res;                                \
+      _argvec[0] = (unsigned long)_orig.nraddr;                   \
+      _argvec[1] = (unsigned long)(arg1);                         \
+      _argvec[2] = (unsigned long)(arg2);                         \
+      _argvec[3] = (unsigned long)(arg3);                         \
+      _argvec[4] = (unsigned long)(arg4);                         \
+      _argvec[5] = (unsigned long)(arg5);                         \
+      _argvec[6] = (unsigned long)(arg6);                         \
+      _argvec[7] = (unsigned long)(arg7);                         \
+      _argvec[8] = (unsigned long)(arg8);                         \
+      _argvec[9] = (unsigned long)(arg9);                         \
+      _argvec[10] = (unsigned long)(arg10);                       \
+      _argvec[11] = (unsigned long)(arg11);                       \
+      _argvec[12] = (unsigned long)(arg12);                       \
+      __asm__ volatile(                                           \
+         "addiu $sp, $sp, -16  \n\t"                              \
+         "lw $t9,36(%1)        \n\t"                              \
+         "sw $t9, 0($sp)       \n\t"                              \
+         "lw $t9,40(%1)        \n\t"                              \
+         "sw $t9, 4($sp)       \n\t"                              \
+         "lw $t9,44(%1)        \n\t"                              \
+         "sw $t9, 8($sp)       \n\t"                              \
+         "lw $t9,48(%1)        \n\t"                              \
+         "sw $t9,12($sp)       \n\t"                              \
+         "lw $t9, 0(%1)        \n\t"                              \
+         "lw $a0, 4(%1)        \n\t"                              \
+         "lw $a1, 8(%1)        \n\t"                              \
+         "lw $a2,12(%1)        \n\t"                              \
+         "lw $a3,16(%1)        \n\t"                              \
+         "lw $a4,20(%1)        \n\t"                              \
+         "lw $a5,24(%1)        \n\t"                              \
+         "lw $a6,28(%1)        \n\t"                              \
+         "lw $a7,32(%1)        \n\t"                              \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "move %0, $a0         \n\t"                              \
+         "addiu $sp, $sp, 16   \n\t"                              \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) _res;                             \
+   } while (0)
+
+#endif /* PLAT_nanomips_linux */
+
+/* ------------------------- mips64-linux ------------------------- */
+
+#if defined(PLAT_mips64_linux)
+
+/* These regs are trashed by the hidden call. */
+#define __CALLER_SAVED_REGS "$2", "$3", "$4", "$5", "$6",       \
+"$7", "$8", "$9", "$10", "$11", "$12", "$13", "$14", "$15", "$24", \
+"$25", "$31"
+
+/* These CALL_FN_ macros assume that on mips64-linux,
+   sizeof(long long) == 8. */
+
+#define MIPS64_LONG2REG_CAST(x) ((long long)(long)x)
+
+#define CALL_FN_W_v(lval, orig)                                   \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long long _argvec[1];                     \
+      volatile unsigned long long _res;                           \
+      _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr);            \
+      __asm__ volatile(                                           \
+         "ld $25, 0(%1)\n\t"  /* target->t9 */                    \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "move %0, $2\n"                                          \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "0" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) (long)_res;                       \
+   } while (0)
+
+#define CALL_FN_W_W(lval, orig, arg1)                             \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long long _argvec[2];                     \
+      volatile unsigned long long  _res;                          \
+      _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr);            \
+      _argvec[1] = MIPS64_LONG2REG_CAST(arg1);                    \
+      __asm__ volatile(                                           \
+         "ld $4, 8(%1)\n\t"   /* arg1*/                           \
+         "ld $25, 0(%1)\n\t"  /* target->t9 */                    \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "move %0, $2\n"                                          \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) (long)_res;                       \
+   } while (0)
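+
+/* Editor's note: an illustrative sketch (comment only, not part of the
+   upstream header) of how a function wrapper might drive these CALL_FN_
+   macros, assuming the wrapping helpers (I_WRAP_SONAME_FNNAME_ZU, OrigFn,
+   VALGRIND_GET_ORIG_FN) defined earlier in this header.  The wrapped
+   function 'foo' is hypothetical:
+
+      // wrap  long foo(long)  in the main executable (soname NONE)
+      long I_WRAP_SONAME_FNNAME_ZU(NONE, foo)(long x)
+      {
+         OrigFn fn;
+         long   r;
+         VALGRIND_GET_ORIG_FN(fn);      // fetch the real foo
+         CALL_FN_W_W(r, fn, x);         // call it via the macro above
+         return r;
+      }
+*/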
+
+#define CALL_FN_W_WW(lval, orig, arg1,arg2)                       \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long long _argvec[3];                     \
+      volatile unsigned long long _res;                           \
+      _argvec[0] = _orig.nraddr;                                  \
+      _argvec[1] = MIPS64_LONG2REG_CAST(arg1);                    \
+      _argvec[2] = MIPS64_LONG2REG_CAST(arg2);                    \
+      __asm__ volatile(                                           \
+         "ld $4, 8(%1)\n\t"                                       \
+         "ld $5, 16(%1)\n\t"                                      \
+         "ld $25, 0(%1)\n\t"  /* target->t9 */                    \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "move %0, $2\n"                                          \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) (long)_res;                       \
+   } while (0)
+
+
+#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3)                 \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long long _argvec[4];                     \
+      volatile unsigned long long _res;                           \
+      _argvec[0] = _orig.nraddr;                                  \
+      _argvec[1] = MIPS64_LONG2REG_CAST(arg1);                    \
+      _argvec[2] = MIPS64_LONG2REG_CAST(arg2);                    \
+      _argvec[3] = MIPS64_LONG2REG_CAST(arg3);                    \
+      __asm__ volatile(                                           \
+         "ld $4, 8(%1)\n\t"                                       \
+         "ld $5, 16(%1)\n\t"                                      \
+         "ld $6, 24(%1)\n\t"                                      \
+         "ld $25, 0(%1)\n\t"  /* target->t9 */                    \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "move %0, $2\n"                                          \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) (long)_res;                       \
+   } while (0)
+
+#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4)           \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long long _argvec[5];                     \
+      volatile unsigned long long _res;                           \
+      _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr);            \
+      _argvec[1] = MIPS64_LONG2REG_CAST(arg1);                    \
+      _argvec[2] = MIPS64_LONG2REG_CAST(arg2);                    \
+      _argvec[3] = MIPS64_LONG2REG_CAST(arg3);                    \
+      _argvec[4] = MIPS64_LONG2REG_CAST(arg4);                    \
+      __asm__ volatile(                                           \
+         "ld $4, 8(%1)\n\t"                                       \
+         "ld $5, 16(%1)\n\t"                                      \
+         "ld $6, 24(%1)\n\t"                                      \
+         "ld $7, 32(%1)\n\t"                                      \
+         "ld $25, 0(%1)\n\t"  /* target->t9 */                    \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "move %0, $2\n"                                          \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) (long)_res;                       \
+   } while (0)
+
+#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5)        \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long long _argvec[6];                     \
+      volatile unsigned long long _res;                           \
+      _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr);            \
+      _argvec[1] = MIPS64_LONG2REG_CAST(arg1);                    \
+      _argvec[2] = MIPS64_LONG2REG_CAST(arg2);                    \
+      _argvec[3] = MIPS64_LONG2REG_CAST(arg3);                    \
+      _argvec[4] = MIPS64_LONG2REG_CAST(arg4);                    \
+      _argvec[5] = MIPS64_LONG2REG_CAST(arg5);                    \
+      __asm__ volatile(                                           \
+         "ld $4, 8(%1)\n\t"                                       \
+         "ld $5, 16(%1)\n\t"                                      \
+         "ld $6, 24(%1)\n\t"                                      \
+         "ld $7, 32(%1)\n\t"                                      \
+         "ld $8, 40(%1)\n\t"                                      \
+         "ld $25, 0(%1)\n\t"  /* target->t9 */                    \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "move %0, $2\n"                                          \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) (long)_res;                       \
+   } while (0)
+
+#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6)   \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long long _argvec[7];                     \
+      volatile unsigned long long _res;                           \
+      _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr);            \
+      _argvec[1] = MIPS64_LONG2REG_CAST(arg1);                    \
+      _argvec[2] = MIPS64_LONG2REG_CAST(arg2);                    \
+      _argvec[3] = MIPS64_LONG2REG_CAST(arg3);                    \
+      _argvec[4] = MIPS64_LONG2REG_CAST(arg4);                    \
+      _argvec[5] = MIPS64_LONG2REG_CAST(arg5);                    \
+      _argvec[6] = MIPS64_LONG2REG_CAST(arg6);                    \
+      __asm__ volatile(                                           \
+         "ld $4, 8(%1)\n\t"                                       \
+         "ld $5, 16(%1)\n\t"                                      \
+         "ld $6, 24(%1)\n\t"                                      \
+         "ld $7, 32(%1)\n\t"                                      \
+         "ld $8, 40(%1)\n\t"                                      \
+         "ld $9, 48(%1)\n\t"                                      \
+         "ld $25, 0(%1)\n\t"  /* target->t9 */                    \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "move %0, $2\n"                                          \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) (long)_res;                       \
+   } while (0)
+
+#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7)                            \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long long _argvec[8];                     \
+      volatile unsigned long long _res;                           \
+      _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr);            \
+      _argvec[1] = MIPS64_LONG2REG_CAST(arg1);                    \
+      _argvec[2] = MIPS64_LONG2REG_CAST(arg2);                    \
+      _argvec[3] = MIPS64_LONG2REG_CAST(arg3);                    \
+      _argvec[4] = MIPS64_LONG2REG_CAST(arg4);                    \
+      _argvec[5] = MIPS64_LONG2REG_CAST(arg5);                    \
+      _argvec[6] = MIPS64_LONG2REG_CAST(arg6);                    \
+      _argvec[7] = MIPS64_LONG2REG_CAST(arg7);                    \
+      __asm__ volatile(                                           \
+         "ld $4, 8(%1)\n\t"                                       \
+         "ld $5, 16(%1)\n\t"                                      \
+         "ld $6, 24(%1)\n\t"                                      \
+         "ld $7, 32(%1)\n\t"                                      \
+         "ld $8, 40(%1)\n\t"                                      \
+         "ld $9, 48(%1)\n\t"                                      \
+         "ld $10, 56(%1)\n\t"                                     \
+         "ld $25, 0(%1) \n\t"  /* target->t9 */                   \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "move %0, $2\n"                                          \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) (long)_res;                       \
+   } while (0)
+
+#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7,arg8)                       \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long long _argvec[9];                     \
+      volatile unsigned long long _res;                           \
+      _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr);            \
+      _argvec[1] = MIPS64_LONG2REG_CAST(arg1);                    \
+      _argvec[2] = MIPS64_LONG2REG_CAST(arg2);                    \
+      _argvec[3] = MIPS64_LONG2REG_CAST(arg3);                    \
+      _argvec[4] = MIPS64_LONG2REG_CAST(arg4);                    \
+      _argvec[5] = MIPS64_LONG2REG_CAST(arg5);                    \
+      _argvec[6] = MIPS64_LONG2REG_CAST(arg6);                    \
+      _argvec[7] = MIPS64_LONG2REG_CAST(arg7);                    \
+      _argvec[8] = MIPS64_LONG2REG_CAST(arg8);                    \
+      __asm__ volatile(                                           \
+         "ld $4, 8(%1)\n\t"                                       \
+         "ld $5, 16(%1)\n\t"                                      \
+         "ld $6, 24(%1)\n\t"                                      \
+         "ld $7, 32(%1)\n\t"                                      \
+         "ld $8, 40(%1)\n\t"                                      \
+         "ld $9, 48(%1)\n\t"                                      \
+         "ld $10, 56(%1)\n\t"                                     \
+         "ld $11, 64(%1)\n\t"                                     \
+         "ld $25, 0(%1) \n\t"  /* target->t9 */                   \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "move %0, $2\n"                                          \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) (long)_res;                       \
+   } while (0)
+
+#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,   \
+                                 arg7,arg8,arg9)                  \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long long _argvec[10];                    \
+      volatile unsigned long long _res;                           \
+      _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr);            \
+      _argvec[1] = MIPS64_LONG2REG_CAST(arg1);                    \
+      _argvec[2] = MIPS64_LONG2REG_CAST(arg2);                    \
+      _argvec[3] = MIPS64_LONG2REG_CAST(arg3);                    \
+      _argvec[4] = MIPS64_LONG2REG_CAST(arg4);                    \
+      _argvec[5] = MIPS64_LONG2REG_CAST(arg5);                    \
+      _argvec[6] = MIPS64_LONG2REG_CAST(arg6);                    \
+      _argvec[7] = MIPS64_LONG2REG_CAST(arg7);                    \
+      _argvec[8] = MIPS64_LONG2REG_CAST(arg8);                    \
+      _argvec[9] = MIPS64_LONG2REG_CAST(arg9);                    \
+      __asm__ volatile(                                           \
+         "dsubu $29, $29, 8\n\t"                                  \
+         "ld $4, 72(%1)\n\t"                                      \
+         "sd $4, 0($29)\n\t"                                      \
+         "ld $4, 8(%1)\n\t"                                       \
+         "ld $5, 16(%1)\n\t"                                      \
+         "ld $6, 24(%1)\n\t"                                      \
+         "ld $7, 32(%1)\n\t"                                      \
+         "ld $8, 40(%1)\n\t"                                      \
+         "ld $9, 48(%1)\n\t"                                      \
+         "ld $10, 56(%1)\n\t"                                     \
+         "ld $11, 64(%1)\n\t"                                     \
+         "ld $25, 0(%1)\n\t"  /* target->t9 */                    \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "daddu $29, $29, 8\n\t"                                  \
+         "move %0, $2\n"                                          \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) (long)_res;                       \
+   } while (0)
+
+#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6,  \
+                                  arg7,arg8,arg9,arg10)           \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long long _argvec[11];                    \
+      volatile unsigned long long _res;                           \
+      _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr);            \
+      _argvec[1] = MIPS64_LONG2REG_CAST(arg1);                    \
+      _argvec[2] = MIPS64_LONG2REG_CAST(arg2);                    \
+      _argvec[3] = MIPS64_LONG2REG_CAST(arg3);                    \
+      _argvec[4] = MIPS64_LONG2REG_CAST(arg4);                    \
+      _argvec[5] = MIPS64_LONG2REG_CAST(arg5);                    \
+      _argvec[6] = MIPS64_LONG2REG_CAST(arg6);                    \
+      _argvec[7] = MIPS64_LONG2REG_CAST(arg7);                    \
+      _argvec[8] = MIPS64_LONG2REG_CAST(arg8);                    \
+      _argvec[9] = MIPS64_LONG2REG_CAST(arg9);                    \
+      _argvec[10] = MIPS64_LONG2REG_CAST(arg10);                  \
+      __asm__ volatile(                                           \
+         "dsubu $29, $29, 16\n\t"                                 \
+         "ld $4, 72(%1)\n\t"                                      \
+         "sd $4, 0($29)\n\t"                                      \
+         "ld $4, 80(%1)\n\t"                                      \
+         "sd $4, 8($29)\n\t"                                      \
+         "ld $4, 8(%1)\n\t"                                       \
+         "ld $5, 16(%1)\n\t"                                      \
+         "ld $6, 24(%1)\n\t"                                      \
+         "ld $7, 32(%1)\n\t"                                      \
+         "ld $8, 40(%1)\n\t"                                      \
+         "ld $9, 48(%1)\n\t"                                      \
+         "ld $10, 56(%1)\n\t"                                     \
+         "ld $11, 64(%1)\n\t"                                     \
+         "ld $25, 0(%1)\n\t"  /* target->t9 */                    \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "daddu $29, $29, 16\n\t"                                 \
+         "move %0, $2\n"                                          \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) (long)_res;                       \
+   } while (0)
+
+#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5,       \
+                                  arg6,arg7,arg8,arg9,arg10,      \
+                                  arg11)                          \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long long _argvec[12];                    \
+      volatile unsigned long long _res;                           \
+      _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr);            \
+      _argvec[1] = MIPS64_LONG2REG_CAST(arg1);                    \
+      _argvec[2] = MIPS64_LONG2REG_CAST(arg2);                    \
+      _argvec[3] = MIPS64_LONG2REG_CAST(arg3);                    \
+      _argvec[4] = MIPS64_LONG2REG_CAST(arg4);                    \
+      _argvec[5] = MIPS64_LONG2REG_CAST(arg5);                    \
+      _argvec[6] = MIPS64_LONG2REG_CAST(arg6);                    \
+      _argvec[7] = MIPS64_LONG2REG_CAST(arg7);                    \
+      _argvec[8] = MIPS64_LONG2REG_CAST(arg8);                    \
+      _argvec[9] = MIPS64_LONG2REG_CAST(arg9);                    \
+      _argvec[10] = MIPS64_LONG2REG_CAST(arg10);                  \
+      _argvec[11] = MIPS64_LONG2REG_CAST(arg11);                  \
+      __asm__ volatile(                                           \
+         "dsubu $29, $29, 24\n\t"                                 \
+         "ld $4, 72(%1)\n\t"                                      \
+         "sd $4, 0($29)\n\t"                                      \
+         "ld $4, 80(%1)\n\t"                                      \
+         "sd $4, 8($29)\n\t"                                      \
+         "ld $4, 88(%1)\n\t"                                      \
+         "sd $4, 16($29)\n\t"                                     \
+         "ld $4, 8(%1)\n\t"                                       \
+         "ld $5, 16(%1)\n\t"                                      \
+         "ld $6, 24(%1)\n\t"                                      \
+         "ld $7, 32(%1)\n\t"                                      \
+         "ld $8, 40(%1)\n\t"                                      \
+         "ld $9, 48(%1)\n\t"                                      \
+         "ld $10, 56(%1)\n\t"                                     \
+         "ld $11, 64(%1)\n\t"                                     \
+         "ld $25, 0(%1)\n\t"  /* target->t9 */                    \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "daddu $29, $29, 24\n\t"                                 \
+         "move %0, $2\n"                                          \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) (long)_res;                       \
+   } while (0)
+
+#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5,       \
+                                  arg6,arg7,arg8,arg9,arg10,      \
+                                  arg11,arg12)                    \
+   do {                                                           \
+      volatile OrigFn        _orig = (orig);                      \
+      volatile unsigned long long _argvec[13];                    \
+      volatile unsigned long long _res;                           \
+      _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr);            \
+      _argvec[1] = MIPS64_LONG2REG_CAST(arg1);                    \
+      _argvec[2] = MIPS64_LONG2REG_CAST(arg2);                    \
+      _argvec[3] = MIPS64_LONG2REG_CAST(arg3);                    \
+      _argvec[4] = MIPS64_LONG2REG_CAST(arg4);                    \
+      _argvec[5] = MIPS64_LONG2REG_CAST(arg5);                    \
+      _argvec[6] = MIPS64_LONG2REG_CAST(arg6);                    \
+      _argvec[7] = MIPS64_LONG2REG_CAST(arg7);                    \
+      _argvec[8] = MIPS64_LONG2REG_CAST(arg8);                    \
+      _argvec[9] = MIPS64_LONG2REG_CAST(arg9);                    \
+      _argvec[10] = MIPS64_LONG2REG_CAST(arg10);                  \
+      _argvec[11] = MIPS64_LONG2REG_CAST(arg11);                  \
+      _argvec[12] = MIPS64_LONG2REG_CAST(arg12);                  \
+      __asm__ volatile(                                           \
+         "dsubu $29, $29, 32\n\t"                                 \
+         "ld $4, 72(%1)\n\t"                                      \
+         "sd $4, 0($29)\n\t"                                      \
+         "ld $4, 80(%1)\n\t"                                      \
+         "sd $4, 8($29)\n\t"                                      \
+         "ld $4, 88(%1)\n\t"                                      \
+         "sd $4, 16($29)\n\t"                                     \
+         "ld $4, 96(%1)\n\t"                                      \
+         "sd $4, 24($29)\n\t"                                     \
+         "ld $4, 8(%1)\n\t"                                       \
+         "ld $5, 16(%1)\n\t"                                      \
+         "ld $6, 24(%1)\n\t"                                      \
+         "ld $7, 32(%1)\n\t"                                      \
+         "ld $8, 40(%1)\n\t"                                      \
+         "ld $9, 48(%1)\n\t"                                      \
+         "ld $10, 56(%1)\n\t"                                     \
+         "ld $11, 64(%1)\n\t"                                     \
+         "ld $25, 0(%1)\n\t"  /* target->t9 */                    \
+         VALGRIND_CALL_NOREDIR_T9                                 \
+         "daddu $29, $29, 32\n\t"                                 \
+         "move %0, $2\n"                                          \
+         : /*out*/   "=r" (_res)                                  \
+         : /*in*/    "r" (&_argvec[0])                            \
+         : /*trash*/ "memory", __CALLER_SAVED_REGS                \
+      );                                                          \
+      lval = (__typeof__(lval)) (long)_res;                       \
+   } while (0)
+
+#endif /* PLAT_mips64_linux */
+
+/* ------------------------------------------------------------------ */
+/* ARCHITECTURE INDEPENDENT MACROS for CLIENT REQUESTS.               */
+/*                                                                    */
+/* ------------------------------------------------------------------ */
+
+/* Some request codes.  There are many more of these, but most are not
+   exposed to end-user view.  These are the public ones, all of the
+   form 0x1000 + small_number.
+
+   Core ones are in the range 0x00000000--0x0000ffff.  The non-public
+   ones start at 0x2000.
+*/
+
+/* These macros are used by tools -- they must be public, but don't
+   embed them into other programs. */
+#define VG_USERREQ_TOOL_BASE(a,b) \
+   ((unsigned int)(((a)&0xff) << 24 | ((b)&0xff) << 16))
+#define VG_IS_TOOL_USERREQ(a, b, v) \
+   (VG_USERREQ_TOOL_BASE(a,b) == ((v) & 0xffff0000))
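+
+/* Editor's note: a worked example of the encoding above (illustrative
+   comment only).  A tool picks a two-character prefix; the base is that
+   prefix packed into the top 16 bits, and individual requests add a small
+   number on top.  For a hypothetical prefix ('M','C'):
+
+      VG_USERREQ_TOOL_BASE('M','C')
+         = ((0x4D & 0xff) << 24) | ((0x43 & 0xff) << 16)
+         = 0x4D430000
+
+   so a request value such as 0x4D430001 satisfies
+   VG_IS_TOOL_USERREQ('M','C', 0x4D430001), since masking it with
+   0xffff0000 gives back 0x4D430000. */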
+
+/* !! ABIWARNING !! ABIWARNING !! ABIWARNING !! ABIWARNING !! 
+   This enum comprises an ABI exported by Valgrind to programs
+   which use client requests.  DO NOT CHANGE THE NUMERIC VALUES OF THESE
+   ENTRIES, NOR DELETE ANY -- add new ones at the end of the most
+   relevant group. */
+typedef
+   enum { VG_USERREQ__RUNNING_ON_VALGRIND  = 0x1001,
+          VG_USERREQ__DISCARD_TRANSLATIONS = 0x1002,
+
+          /* These allow any function to be called from the simulated
+             CPU but run on the real CPU.  Nb: the first arg passed to
+             the function is always the ThreadId of the running
+             thread!  So CLIENT_CALL0 actually requires a 1 arg
+             function, etc. */
+          VG_USERREQ__CLIENT_CALL0 = 0x1101,
+          VG_USERREQ__CLIENT_CALL1 = 0x1102,
+          VG_USERREQ__CLIENT_CALL2 = 0x1103,
+          VG_USERREQ__CLIENT_CALL3 = 0x1104,
+
+          /* Can be useful in regression testing suites -- eg. can
+             send Valgrind's output to /dev/null and still count
+             errors. */
+          VG_USERREQ__COUNT_ERRORS = 0x1201,
+
+          /* Allows the client program and/or gdbserver to execute a monitor
+             command. */
+          VG_USERREQ__GDB_MONITOR_COMMAND = 0x1202,
+
+          /* Allows the client program to change a dynamic command line
+             option.  */
+          VG_USERREQ__CLO_CHANGE = 0x1203,
+
+          /* These are useful and can be interpreted by any tool that
+             tracks malloc() et al, by using vg_replace_malloc.c. */
+          VG_USERREQ__MALLOCLIKE_BLOCK = 0x1301,
+          VG_USERREQ__RESIZEINPLACE_BLOCK = 0x130b,
+          VG_USERREQ__FREELIKE_BLOCK   = 0x1302,
+          /* Memory pool support. */
+          VG_USERREQ__CREATE_MEMPOOL   = 0x1303,
+          VG_USERREQ__DESTROY_MEMPOOL  = 0x1304,
+          VG_USERREQ__MEMPOOL_ALLOC    = 0x1305,
+          VG_USERREQ__MEMPOOL_FREE     = 0x1306,
+          VG_USERREQ__MEMPOOL_TRIM     = 0x1307,
+          VG_USERREQ__MOVE_MEMPOOL     = 0x1308,
+          VG_USERREQ__MEMPOOL_CHANGE   = 0x1309,
+          VG_USERREQ__MEMPOOL_EXISTS   = 0x130a,
+
+          /* Allow printfs to valgrind log. */
+          /* The first two pass the va_list argument by value, which
+             assumes it is the same size as or smaller than a UWord,
+             which generally isn't the case.  Hence they are deprecated.
+             The second two pass the vargs by reference and so are
+             immune to this problem. */
+          /* both :: char* fmt, va_list vargs (DEPRECATED) */
+          VG_USERREQ__PRINTF           = 0x1401,
+          VG_USERREQ__PRINTF_BACKTRACE = 0x1402,
+          /* both :: char* fmt, va_list* vargs */
+          VG_USERREQ__PRINTF_VALIST_BY_REF = 0x1403,
+          VG_USERREQ__PRINTF_BACKTRACE_VALIST_BY_REF = 0x1404,
+
+          /* Stack support. */
+          VG_USERREQ__STACK_REGISTER   = 0x1501,
+          VG_USERREQ__STACK_DEREGISTER = 0x1502,
+          VG_USERREQ__STACK_CHANGE     = 0x1503,
+
+          /* Wine support */
+          VG_USERREQ__LOAD_PDB_DEBUGINFO = 0x1601,
+
+          /* Querying of debug info. */
+          VG_USERREQ__MAP_IP_TO_SRCLOC = 0x1701,
+
+          /* Disable/enable error reporting level.  Takes a single
+             Word arg which is the delta to this thread's error
+             disablement indicator.  Hence 1 disables or further
+             disables errors, and -1 moves back towards enablement.
+             Other values are not allowed. */
+          VG_USERREQ__CHANGE_ERR_DISABLEMENT = 0x1801,
+
+          /* Some requests are used for Valgrind-internal purposes, such
+             as self-testing or self-hosting. */
+          /* Initialise IR injection */
+          VG_USERREQ__VEX_INIT_FOR_IRI = 0x1901,
+          /* Used by Inner Valgrind to inform Outer Valgrind where to
+             find the list of inner guest threads */
+          VG_USERREQ__INNER_THREADS    = 0x1902
+   } Vg_ClientRequest;
+
+#if !defined(__GNUC__)
+#  define __extension__ /* */
+#endif
+
+
+/* Returns the number of Valgrinds this code is running under.  That
+   is, 0 if running natively, 1 if running under Valgrind, 2 if
+   running under Valgrind which is running under another Valgrind,
+   etc. */
+#define RUNNING_ON_VALGRIND                                           \
+    (unsigned)VALGRIND_DO_CLIENT_REQUEST_EXPR(0 /* if not */,         \
+                                    VG_USERREQ__RUNNING_ON_VALGRIND,  \
+                                    0, 0, 0, 0, 0)                    \
+
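+/* Editor's note: an illustrative usage sketch (comment only, not part of
+   the upstream header).  Code that must behave differently under Valgrind
+   can branch on this macro; it costs only a few instructions natively:
+
+      if (RUNNING_ON_VALGRIND) {
+         // running under (at least one) Valgrind: e.g. shrink iteration
+         // counts or skip timing-sensitive paths
+      }
+*/
+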
+
+/* Discard translation of code in the range [_qzz_addr .. _qzz_addr +
+   _qzz_len - 1].  Useful if you are debugging a JITter or some such,
+   since it provides a way to make sure valgrind will retranslate the
+   invalidated area.  Returns no value. */
+#define VALGRIND_DISCARD_TRANSLATIONS(_qzz_addr,_qzz_len)              \
+    VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__DISCARD_TRANSLATIONS,  \
+                                    _qzz_addr, _qzz_len, 0, 0, 0)
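+
+/* Editor's note: illustrative sketch (comment only).  A JIT that rewrites
+   code in place would tell Valgrind to drop any cached translations of the
+   patched range; 'patch_jit_buffer', 'code_buf' and 'code_len' are
+   hypothetical names:
+
+      patch_jit_buffer(code_buf, code_len);
+      VALGRIND_DISCARD_TRANSLATIONS(code_buf, code_len);
+*/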
+
+#define VALGRIND_INNER_THREADS(_qzz_addr)                               \
+   VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__INNER_THREADS,           \
+                                   _qzz_addr, 0, 0, 0, 0)
+
+
+/* These requests are for getting Valgrind itself to print something.
+   Possibly with a backtrace.  This is a really ugly hack.  The return value
+   is the number of characters printed, excluding the "**** " part at the
+   start and the backtrace (if present). */
+
+#if defined(__GNUC__) || defined(__INTEL_COMPILER) && !defined(_MSC_VER)
+/* Modern GCC will optimize the static routine out if unused,
+   and the unused attribute will suppress warnings about it.  */
+static int VALGRIND_PRINTF(const char *format, ...)
+   __attribute__((format(__printf__, 1, 2), __unused__));
+#endif
+static int
+#if defined(_MSC_VER)
+__inline
+#endif
+VALGRIND_PRINTF(const char *format, ...)
+{
+#if defined(NVALGRIND)
+   (void)format;
+   return 0;
+#else /* NVALGRIND */
+#if defined(_MSC_VER) || defined(__MINGW64__)
+   uintptr_t _qzz_res;
+#else
+   unsigned long _qzz_res;
+#endif
+   va_list vargs;
+   va_start(vargs, format);
+#if defined(_MSC_VER) || defined(__MINGW64__)
+   _qzz_res = VALGRIND_DO_CLIENT_REQUEST_EXPR(0,
+                              VG_USERREQ__PRINTF_VALIST_BY_REF,
+                              (uintptr_t)format,
+                              (uintptr_t)&vargs,
+                              0, 0, 0);
+#else
+   _qzz_res = VALGRIND_DO_CLIENT_REQUEST_EXPR(0,
+                              VG_USERREQ__PRINTF_VALIST_BY_REF,
+                              (unsigned long)format,
+                              (unsigned long)&vargs, 
+                              0, 0, 0);
+#endif
+   va_end(vargs);
+   return (int)_qzz_res;
+#endif /* NVALGRIND */
+}
+
+#if defined(__GNUC__) || defined(__INTEL_COMPILER) && !defined(_MSC_VER)
+static int VALGRIND_PRINTF_BACKTRACE(const char *format, ...)
+   __attribute__((format(__printf__, 1, 2), __unused__));
+#endif
+static int
+#if defined(_MSC_VER)
+__inline
+#endif
+VALGRIND_PRINTF_BACKTRACE(const char *format, ...)
+{
+#if defined(NVALGRIND)
+   (void)format;
+   return 0;
+#else /* NVALGRIND */
+#if defined(_MSC_VER) || defined(__MINGW64__)
+   uintptr_t _qzz_res;
+#else
+   unsigned long _qzz_res;
+#endif
+   va_list vargs;
+   va_start(vargs, format);
+#if defined(_MSC_VER) || defined(__MINGW64__)
+   _qzz_res = VALGRIND_DO_CLIENT_REQUEST_EXPR(0,
+                              VG_USERREQ__PRINTF_BACKTRACE_VALIST_BY_REF,
+                              (uintptr_t)format,
+                              (uintptr_t)&vargs,
+                              0, 0, 0);
+#else
+   _qzz_res = VALGRIND_DO_CLIENT_REQUEST_EXPR(0,
+                              VG_USERREQ__PRINTF_BACKTRACE_VALIST_BY_REF,
+                              (unsigned long)format,
+                              (unsigned long)&vargs, 
+                              0, 0, 0);
+#endif
+   va_end(vargs);
+   return (int)_qzz_res;
+#endif /* NVALGRIND */
+}
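+
+/* Editor's note: illustrative usage (comment only).  Both helpers accept a
+   printf-style format; the return value is the number of characters
+   printed (0 when not running under Valgrind or when built with
+   NVALGRIND):
+
+      VALGRIND_PRINTF("checkpoint %d reached\n", n);
+      VALGRIND_PRINTF_BACKTRACE("unexpected state %d\n", state);
+*/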
+
+
+/* These requests allow control to move from the simulated CPU to the
+   real CPU, calling an arbitrary function.
+   
+   Note that the current ThreadId is inserted as the first argument.
+   So this call:
+
+     VALGRIND_NON_SIMD_CALL2(f, arg1, arg2)
+
+   requires f to have this signature:
+
+     Word f(Word tid, Word arg1, Word arg2)
+
+   where "Word" is a word-sized type.
+
+   Note that these client requests are not entirely reliable.  For example,
+   if you call a function with them that subsequently calls printf(),
+   there's a high chance Valgrind will crash.  Generally, your prospects of
+   these working are made higher if the called function does not refer to
+   any global variables, and does not refer to any libc or other functions
+   (printf et al).  Any kind of entanglement with libc or dynamic linking is
+   likely to have a bad outcome, for tricky reasons which we've grappled
+   with a lot in the past.
+*/
+#define VALGRIND_NON_SIMD_CALL0(_qyy_fn)                          \
+    VALGRIND_DO_CLIENT_REQUEST_EXPR(0 /* default return */,       \
+                                    VG_USERREQ__CLIENT_CALL0,     \
+                                    _qyy_fn,                      \
+                                    0, 0, 0, 0)
+
+#define VALGRIND_NON_SIMD_CALL1(_qyy_fn, _qyy_arg1)                    \
+    VALGRIND_DO_CLIENT_REQUEST_EXPR(0 /* default return */,            \
+                                    VG_USERREQ__CLIENT_CALL1,          \
+                                    _qyy_fn,                           \
+                                    _qyy_arg1, 0, 0, 0)
+
+#define VALGRIND_NON_SIMD_CALL2(_qyy_fn, _qyy_arg1, _qyy_arg2)         \
+    VALGRIND_DO_CLIENT_REQUEST_EXPR(0 /* default return */,            \
+                                    VG_USERREQ__CLIENT_CALL2,          \
+                                    _qyy_fn,                           \
+                                    _qyy_arg1, _qyy_arg2, 0, 0)
+
+#define VALGRIND_NON_SIMD_CALL3(_qyy_fn, _qyy_arg1, _qyy_arg2, _qyy_arg3) \
+    VALGRIND_DO_CLIENT_REQUEST_EXPR(0 /* default return */,             \
+                                    VG_USERREQ__CLIENT_CALL3,           \
+                                    _qyy_fn,                            \
+                                    _qyy_arg1, _qyy_arg2,               \
+                                    _qyy_arg3, 0)
+
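+/* Editor's note: illustrative sketch (comment only), matching the
+   signature convention described above: the called function receives the
+   ThreadId as its first argument.  'do_stats' is a hypothetical helper
+   that avoids libc entirely:
+
+      static long do_stats(long tid, long a, long b) { return a + b; }
+      ...
+      long sum = VALGRIND_NON_SIMD_CALL2(do_stats, 3, 4); // runs on the real CPU
+*/
+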
+
+/* Counts the number of errors that have been recorded by a tool.  Nb:
+   the tool must record the errors with VG_(maybe_record_error)() or
+   VG_(unique_error)() for them to be counted. */
+#define VALGRIND_COUNT_ERRORS                                     \
+    (unsigned)VALGRIND_DO_CLIENT_REQUEST_EXPR(                    \
+                               0 /* default return */,            \
+                               VG_USERREQ__COUNT_ERRORS,          \
+                               0, 0, 0, 0, 0)
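+
+/* Editor's note: illustrative one-liner (comment only), e.g. in a test
+   harness that discards Valgrind's output but still wants a pass/fail:
+
+      unsigned n_errs = VALGRIND_COUNT_ERRORS;
+      if (n_errs > 0) exit(1);
+*/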
+
+/* Several Valgrind tools (Memcheck, Massif, Helgrind, DRD) rely on knowing
+   when heap blocks are allocated in order to give accurate results.  This
+   happens automatically for the standard allocator functions such as
+   malloc(), calloc(), realloc(), memalign(), new, new[], free(), delete,
+   delete[], etc.
+
+   But if your program uses a custom allocator, this doesn't automatically
+   happen, and Valgrind will not do as well.  For example, if you allocate
+   superblocks with mmap() and then allocate chunks of the superblocks, all
+   Valgrind's observations will be at the mmap() level and it won't know that
+   the chunks should be considered separate entities.  In Memcheck's case,
+   that means you probably won't get heap block overrun detection (because
+   there won't be redzones marked as unaddressable) and you definitely won't
+   get any leak detection.
+
+   The following client requests allow a custom allocator to be annotated so
+   that it can be handled accurately by Valgrind.
+
+   VALGRIND_MALLOCLIKE_BLOCK marks a region of memory as having been allocated
+   by a malloc()-like function.  For Memcheck (an illustrative case), this
+   does two things:
+
+   - It records that the block has been allocated.  This means any addresses
+     within the block mentioned in error messages will be
+     identified as belonging to the block.  It also means that if the block
+     isn't freed it will be detected by the leak checker.
+
+   - It marks the block as being addressable and undefined (if 'is_zeroed' is
+     not set), or addressable and defined (if 'is_zeroed' is set).  This
+     controls how accesses to the block by the program are handled.
+   
+   'addr' is the start of the usable block (ie. after any
+   redzone), 'sizeB' is its size.  'rzB' is the redzone size if the allocator
+   can apply redzones -- these are blocks of padding at the start and end of
+   each block.  Adding redzones is recommended as it makes it much more likely
+   Valgrind will spot block overruns.  `is_zeroed' indicates if the memory is
+   zeroed (or filled with another predictable value), as is the case for
+   calloc().
+   
+   VALGRIND_MALLOCLIKE_BLOCK should be put immediately after the point where a
+   heap block -- that will be used by the client program -- is allocated.
+   It's best to put it at the outermost level of the allocator if possible;
+   for example, if you have a function my_alloc() which calls
+   internal_alloc(), and the client request is put inside internal_alloc(),
+   stack traces relating to the heap block will contain entries for both
+   my_alloc() and internal_alloc(), which is probably not what you want.
+
+   For Memcheck users: if you use VALGRIND_MALLOCLIKE_BLOCK to carve out
+   custom blocks from within a heap block, B, that has been allocated with
+   malloc/calloc/new/etc, then block B will be *ignored* during leak-checking
+   -- the custom blocks will take precedence.
+
+   VALGRIND_FREELIKE_BLOCK is the partner to VALGRIND_MALLOCLIKE_BLOCK.  For
+   Memcheck, it does two things:
+
+   - It records that the block has been deallocated.  This assumes that the
+     block was annotated as having been allocated via
+     VALGRIND_MALLOCLIKE_BLOCK.  Otherwise, an error will be issued.
+
+   - It marks the block as being unaddressable.
+
+   VALGRIND_FREELIKE_BLOCK should be put immediately after the point where a
+   heap block is deallocated.
+
+   VALGRIND_RESIZEINPLACE_BLOCK informs a tool about reallocation. For
+   Memcheck, it does four things:
+
+   - It records that the size of a block has been changed.  This assumes that
+     the block was annotated as having been allocated via
+     VALGRIND_MALLOCLIKE_BLOCK.  Otherwise, an error will be issued.
+
+   - If the block shrunk, it marks the freed memory as being unaddressable.
+
+   - If the block grew, it marks the new area as undefined and defines a red
+     zone past the end of the new block.
+
+   - The V-bits of the overlap between the old and the new block are preserved.
+
+   VALGRIND_RESIZEINPLACE_BLOCK should be put after allocation of the new block
+   and before deallocation of the old block.
+
+   In many cases, these three client requests will not be enough to get your
+   allocator working well with Memcheck.  More specifically, if your allocator
+   writes to freed blocks in any way then a VALGRIND_MAKE_MEM_UNDEFINED call
+   will be necessary to mark the memory as addressable just before the zeroing
+   occurs, otherwise you'll get a lot of invalid write errors.  For example,
+   you'll need to do this if your allocator recycles freed blocks, but it
+   zeroes them before handing them back out (via VALGRIND_MALLOCLIKE_BLOCK).
+   Alternatively, if your allocator reuses freed blocks for allocator-internal
+   data structures, VALGRIND_MAKE_MEM_UNDEFINED calls will also be necessary.
+
+   Really, what's happening is a blurring of the lines between the client
+   program and the allocator... after VALGRIND_FREELIKE_BLOCK is called, the
+   memory should be considered unaddressable to the client program, but the
+   allocator knows more than the rest of the client program and so may be able
+   to safely access it.  Extra client requests are necessary for Valgrind to
+   understand the distinction between the allocator and the rest of the
+   program.
+
+   Ignored if addr == 0.
+*/
+#define VALGRIND_MALLOCLIKE_BLOCK(addr, sizeB, rzB, is_zeroed)          \
+    VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__MALLOCLIKE_BLOCK,       \
+                                    addr, sizeB, rzB, is_zeroed, 0)
+
+/* See the comment for VALGRIND_MALLOCLIKE_BLOCK for details.
+   Ignored if addr == 0.
+*/
+#define VALGRIND_RESIZEINPLACE_BLOCK(addr, oldSizeB, newSizeB, rzB)     \
+    VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__RESIZEINPLACE_BLOCK,    \
+                                    addr, oldSizeB, newSizeB, rzB, 0)
+
+/* See the comment for VALGRIND_MALLOCLIKE_BLOCK for details.
+   Ignored if addr == 0.
+*/
+#define VALGRIND_FREELIKE_BLOCK(addr, rzB)                              \
+    VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__FREELIKE_BLOCK,         \
+                                    addr, rzB, 0, 0, 0)
+
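+/* A minimal sketch (illustration only, not part of the Valgrind API) of how
+   a pool-style allocator might be annotated with the requests above.  The
+   names my_alloc(), my_free(), pool_carve(), pool_release() and RZ are
+   hypothetical:
+
+      #define RZ 16   // redzone size chosen by the allocator
+
+      void* my_alloc(size_t n)
+      {
+         // carve the usable bytes plus a redzone on each side out of a
+         // superblock; Valgrind itself only sees the mmap() of the superblock
+         char* p = pool_carve(n + 2*RZ);
+         VALGRIND_MALLOCLIKE_BLOCK(p + RZ, n, RZ, 0);   // is_zeroed = 0
+         return p + RZ;
+      }
+
+      void my_free(void* q, size_t n)
+      {
+         VALGRIND_FREELIKE_BLOCK(q, RZ);
+         // if the allocator now writes to the freed block (free lists,
+         // poisoning), a VALGRIND_MAKE_MEM_UNDEFINED call is needed first
+         pool_release((char*)q - RZ, n + 2*RZ);
+      }
+*/
+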
+/* Create a memory pool. */
+#define VALGRIND_CREATE_MEMPOOL(pool, rzB, is_zeroed)             \
+    VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__CREATE_MEMPOOL,   \
+                                    pool, rzB, is_zeroed, 0, 0)
+
+/* Create a memory pool with some flags specifying extended behaviour.
+   When flags is zero, the behaviour is identical to VALGRIND_CREATE_MEMPOOL.
+   
+   The flag VALGRIND_MEMPOOL_METAPOOL specifies that the pieces of memory
+   associated with the pool using VALGRIND_MEMPOOL_ALLOC will be used by the
+   application as superblocks from which MALLOC_LIKE blocks are doled out via
+   VALGRIND_MALLOCLIKE_BLOCK.  In other words, a metapool is a two-level
+   pool: the first-level blocks are described by VALGRIND_MEMPOOL_ALLOC, and
+   the second-level blocks are described using VALGRIND_MALLOCLIKE_BLOCK.
+   Note that the association between the pool and the second-level blocks is
+   implicit: second-level blocks are located inside first-level blocks.  The
+   VALGRIND_MEMPOOL_METAPOOL flag is necessary for such two-level pools, as
+   otherwise Valgrind will detect overlapping memory blocks and abort
+   execution (e.g. during leak search).
+
+   Such a metapool can also be marked as an 'auto free' pool using the flag
+   VALGRIND_MEMPOOL_AUTO_FREE, which must be OR-ed together with
+   VALGRIND_MEMPOOL_METAPOOL.  For an 'auto free' pool, VALGRIND_MEMPOOL_FREE
+   automatically frees the second-level blocks contained inside the
+   first-level block freed with VALGRIND_MEMPOOL_FREE.  In other words,
+   calling VALGRIND_MEMPOOL_FREE causes implicit calls to
+   VALGRIND_FREELIKE_BLOCK for all the second-level blocks included in the
+   first-level block.
+   Note: it is an error to use the VALGRIND_MEMPOOL_AUTO_FREE flag without
+   the VALGRIND_MEMPOOL_METAPOOL flag.
+*/
+#define VALGRIND_MEMPOOL_AUTO_FREE  1
+#define VALGRIND_MEMPOOL_METAPOOL   2
+#define VALGRIND_CREATE_MEMPOOL_EXT(pool, rzB, is_zeroed, flags)        \
+   VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__CREATE_MEMPOOL,          \
+                                   pool, rzB, is_zeroed, flags, 0)
+
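+/* As a rough sketch (the helpers get_superblock() and carve_from() are
+   hypothetical), an application using such a two-level pool might annotate
+   it like this:
+
+      VALGRIND_CREATE_MEMPOOL_EXT(pool, 0, 0,
+            VALGRIND_MEMPOOL_METAPOOL | VALGRIND_MEMPOOL_AUTO_FREE);
+
+      char* super = get_superblock(SUPER_SIZE);          // first level
+      VALGRIND_MEMPOOL_ALLOC(pool, super, SUPER_SIZE);
+
+      char* chunk = carve_from(super, CHUNK_SIZE);       // second level
+      VALGRIND_MALLOCLIKE_BLOCK(chunk, CHUNK_SIZE, 0, 0);
+
+      // because the pool is an 'auto free' metapool, freeing the superblock
+      // implicitly frees the second-level blocks contained in it
+      VALGRIND_MEMPOOL_FREE(pool, super);
+*/
+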
+/* Destroy a memory pool. */
+#define VALGRIND_DESTROY_MEMPOOL(pool)                            \
+    VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__DESTROY_MEMPOOL,  \
+                                    pool, 0, 0, 0, 0)
+
+/* Associate a piece of memory with a memory pool. */
+#define VALGRIND_MEMPOOL_ALLOC(pool, addr, size)                  \
+    VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__MEMPOOL_ALLOC,    \
+                                    pool, addr, size, 0, 0)
+
+/* Disassociate a piece of memory from a memory pool. */
+#define VALGRIND_MEMPOOL_FREE(pool, addr)                         \
+    VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__MEMPOOL_FREE,     \
+                                    pool, addr, 0, 0, 0)
+
+/* Disassociate any pieces outside a particular range. */
+#define VALGRIND_MEMPOOL_TRIM(pool, addr, size)                   \
+    VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__MEMPOOL_TRIM,     \
+                                    pool, addr, size, 0, 0)
+
+/* Inform the tool that the pool anchored at poolA has moved to poolB. */
+#define VALGRIND_MOVE_MEMPOOL(poolA, poolB)                       \
+    VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__MOVE_MEMPOOL,     \
+                                    poolA, poolB, 0, 0, 0)
+
+/* Resize and/or move a piece associated with a memory pool. */
+#define VALGRIND_MEMPOOL_CHANGE(pool, addrA, addrB, size)         \
+    VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__MEMPOOL_CHANGE,   \
+                                    pool, addrA, addrB, size, 0)
+
+/* Return 1 if a mempool exists, else 0. */
+#define VALGRIND_MEMPOOL_EXISTS(pool)                             \
+    (unsigned)VALGRIND_DO_CLIENT_REQUEST_EXPR(0,                  \
+                               VG_USERREQ__MEMPOOL_EXISTS,        \
+                               pool, 0, 0, 0, 0)
+
+/* Mark a piece of memory as being a stack. Returns a stack id.
+   start is the lowest addressable stack byte, end is the highest
+   addressable stack byte. */
+#define VALGRIND_STACK_REGISTER(start, end)                       \
+    (unsigned)VALGRIND_DO_CLIENT_REQUEST_EXPR(0,                  \
+                               VG_USERREQ__STACK_REGISTER,        \
+                               start, end, 0, 0, 0)
+
+/* Unmark the piece of memory associated with a stack id as being a
+   stack. */
+#define VALGRIND_STACK_DEREGISTER(id)                             \
+    VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__STACK_DEREGISTER, \
+                                    id, 0, 0, 0, 0)
+
+/* Change the start and end address of the stack id.
+   start is the new lowest addressable stack byte, end is the new highest
+   addressable stack byte. */
+#define VALGRIND_STACK_CHANGE(id, start, end)                     \
+    VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__STACK_CHANGE,     \
+                                    id, start, end, 0, 0)
+
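+/* For instance, a program switching to a separately allocated fiber or
+   coroutine stack (the size FIBER_STACK_SIZE is hypothetical) might register
+   it as follows:
+
+      char* stk = malloc(FIBER_STACK_SIZE);
+      unsigned id = VALGRIND_STACK_REGISTER(stk, stk + FIBER_STACK_SIZE - 1);
+      // ... run code on the fiber stack ...
+      VALGRIND_STACK_DEREGISTER(id);
+      free(stk);
+*/
+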
+/* Load PDB debug info for Wine PE image_map. */
+#define VALGRIND_LOAD_PDB_DEBUGINFO(fd, ptr, total_size, delta)     \
+    VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__LOAD_PDB_DEBUGINFO, \
+                                    fd, ptr, total_size, delta, 0)
+
+/* Map a code address to a source file name and line number.  buf64
+   must point to a 64-byte buffer in the caller's address space.  The
+   result will be dumped in there and is guaranteed to be zero
+   terminated.  If no info is found, the first byte is set to zero. */
+#define VALGRIND_MAP_IP_TO_SRCLOC(addr, buf64)                    \
+    (unsigned)VALGRIND_DO_CLIENT_REQUEST_EXPR(0,                  \
+                               VG_USERREQ__MAP_IP_TO_SRCLOC,      \
+                               addr, buf64, 0, 0, 0)
+
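+/* For example, assuming 'ip' is a code address obtained by the caller:
+
+      char buf[64];
+      VALGRIND_MAP_IP_TO_SRCLOC(ip, buf);
+      if (buf[0] != 0)
+         printf("%s\n", buf);   // source file and line, if known
+*/
+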
+/* Disable error reporting for this thread.  Behaves in a stack like
+   way, so you can safely call this multiple times provided that
+   VALGRIND_ENABLE_ERROR_REPORTING is called the same number of times
+   to re-enable reporting.  The first call of this macro disables
+   reporting.  Subsequent calls have no effect except to increase the
+   number of VALGRIND_ENABLE_ERROR_REPORTING calls needed to re-enable
+   reporting.  Child threads do not inherit this setting from their
+   parents -- they are always created with reporting enabled. */
+#define VALGRIND_DISABLE_ERROR_REPORTING                                \
+    VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__CHANGE_ERR_DISABLEMENT, \
+                                    1, 0, 0, 0, 0)
+
+/* Re-enable error reporting, as per comments on
+   VALGRIND_DISABLE_ERROR_REPORTING. */
+#define VALGRIND_ENABLE_ERROR_REPORTING                                 \
+    VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__CHANGE_ERR_DISABLEMENT, \
+                                    -1, 0, 0, 0, 0)
+
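+/* Because the requests nest, a sequence like the following (noisy_helper()
+   is a hypothetical function) keeps reporting disabled across the whole
+   region and only re-enables it at the matching outermost call:
+
+      VALGRIND_DISABLE_ERROR_REPORTING;    // reporting off
+      VALGRIND_DISABLE_ERROR_REPORTING;    // still off, depth 2
+      noisy_helper();                      // errors here are not reported
+      VALGRIND_ENABLE_ERROR_REPORTING;     // still off, depth 1
+      VALGRIND_ENABLE_ERROR_REPORTING;     // reporting back on
+*/
+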
+/* Execute a monitor command from the client program.
+   If a connection is opened with GDB, the output will be sent
+   according to the output mode set for vgdb.
+   If no connection is opened, output will go to the log output.
+   Returns 1 if command not recognised, 0 otherwise. */
+#define VALGRIND_MONITOR_COMMAND(command)                               \
+   VALGRIND_DO_CLIENT_REQUEST_EXPR(0, VG_USERREQ__GDB_MONITOR_COMMAND, \
+                                   command, 0, 0, 0, 0)
+
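+/* For example, a program running under Memcheck could trigger a leak search
+   at a point of interest (the command string is tool-specific):
+
+      if (VALGRIND_MONITOR_COMMAND("leak_check full"))
+         fprintf(stderr, "monitor command not recognised\n");
+*/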
+
+/* Change the value of a dynamic command line option.
+   Note that unknown options, or options that are not dynamically
+   changeable, will cause a warning message to be output.  */
+#define VALGRIND_CLO_CHANGE(option)                           \
+   VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__CLO_CHANGE, \
+                                   option, 0, 0, 0, 0)
+
+
+#undef PLAT_x86_darwin
+#undef PLAT_amd64_darwin
+#undef PLAT_x86_win32
+#undef PLAT_amd64_win64
+#undef PLAT_x86_linux
+#undef PLAT_amd64_linux
+#undef PLAT_ppc32_linux
+#undef PLAT_ppc64be_linux
+#undef PLAT_ppc64le_linux
+#undef PLAT_arm_linux
+#undef PLAT_s390x_linux
+#undef PLAT_mips32_linux
+#undef PLAT_mips64_linux
+#undef PLAT_nanomips_linux
+#undef PLAT_x86_solaris
+#undef PLAT_amd64_solaris
+
+#endif   /* __VALGRIND_H */
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/data/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4379771053523968bfa20c2fa17365944af07d6c
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/data/__pycache__/__init__.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/__pycache__/dataloader.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/data/__pycache__/dataloader.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6213efa82fb0f1a3f81f71f867d399713f36e455
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/data/__pycache__/dataloader.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/__pycache__/dataset.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/data/__pycache__/dataset.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bd4dc24d8ae3b17de19c4efd095ee4c91c5ded7a
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/data/__pycache__/dataset.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/__pycache__/distributed.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/data/__pycache__/distributed.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dc36917e7d6b05c2599350abfc6c85280c42c8a2
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/data/__pycache__/distributed.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/__pycache__/graph.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/data/__pycache__/graph.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f2c5290a7c7a08ecb92f36f9a0117b9b4ed67d3e
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/data/__pycache__/graph.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/__pycache__/graph_settings.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/data/__pycache__/graph_settings.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..15ba423f34555ecb60e9262926c847b002fa7e7c
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/data/__pycache__/graph_settings.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/__pycache__/sampler.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/data/__pycache__/sampler.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e206ee2bce8e90082d8193decaf0639e7262e27d
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/data/__pycache__/sampler.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/__init__.py b/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef03e053211b67a0ed109bc69f4a6b058783d4cf
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/__init__.py
@@ -0,0 +1,54 @@
+# mypy: allow-untyped-defs
+r"""Utility classes & functions for data loading. Code in this folder is mostly used by ../dataloder.py.
+
+A lot of multiprocessing is used in data loading, which only supports running
+functions defined in global environment (py2 can't serialize static methods).
+Therefore, for code tidiness we put these functions into different files in this
+folder.
+"""
+
+import atexit
+import sys
+
+# old private location of the ExceptionWrapper that some users rely on:
+from torch._utils import ExceptionWrapper
+
+
+IS_WINDOWS = sys.platform == "win32"
+
+
+MP_STATUS_CHECK_INTERVAL = 5.0
+r"""Interval (in seconds) to check status of processes to avoid hanging in
+    multiprocessing data loading. This is mainly used in getting data from
+    another process, in which case we need to periodically check whether the
+    sender is alive to prevent hanging."""
+
+
+python_exit_status = False
+r"""Whether Python is shutting down. This flag is guaranteed to be set before
+the Python core library resources are freed, but Python may already have been
+exiting for some time when it is set.
+
+The hook that sets this flag is `_set_python_exit_flag`; it is inspired by a
+similar hook in the Python 3.7 multiprocessing library:
+https://github.com/python/cpython/blob/d4d60134b29290049e28df54f23493de4f1824b6/Lib/multiprocessing/util.py#L277-L327
+"""
+
+
+try:
+    import numpy
+
+    HAS_NUMPY = True
+except ModuleNotFoundError:
+    HAS_NUMPY = False
+
+
+def _set_python_exit_flag():
+    global python_exit_status
+    python_exit_status = True
+
+
+atexit.register(_set_python_exit_flag)
+
+
+from . import collate, fetch, pin_memory, signal_handling, worker
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..02f16a6a7e0b26653aba8b0e9eb9554191bb8d05
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/__init__.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/collate.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/collate.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..12a3972c204b37881415e658de8c420a003f4f9b
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/collate.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/fetch.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/fetch.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d9642cc754675432ec3068d4186831c93b2e04c1
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/fetch.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/pin_memory.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/pin_memory.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8119bf22fc72ec95b4d59cf0bdcb79eb2053ddea
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/pin_memory.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/signal_handling.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/signal_handling.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0c325a6be2da9ac8b43cbdf135294d88c4209677
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/signal_handling.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/worker.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/worker.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ce9d20a4ebe323af190f40235b3f3c0b61030ab8
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/__pycache__/worker.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/collate.py b/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/collate.py
new file mode 100644
index 0000000000000000000000000000000000000000..68a4da0731c0e29234bf4461c065398da36f9b00
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/collate.py
@@ -0,0 +1,398 @@
+# mypy: allow-untyped-defs
+r"""Contains definitions of the methods used by the _BaseDataLoaderIter workers.
+
+These methods are used to collate samples fetched from dataset into Tensor(s).
+These **need** to be in global scope since Py2 doesn't support serializing
+static methods.
+
+`default_collate` and `default_convert` are exposed to users via 'dataloader.py'.
+"""
+
+import collections
+import contextlib
+import copy
+import re
+from typing import Callable, Optional, Union
+
+import torch
+
+
+np_str_obj_array_pattern = re.compile(r"[SaUO]")
+
+
+def default_convert(data):
+    r"""
+    Convert each NumPy array element into a :class:`torch.Tensor`.
+
+    If the input is a `Sequence`, `Collection`, or `Mapping`, it tries to convert each element inside to a :class:`torch.Tensor`.
+    If the input is not a NumPy array, it is left unchanged.
+    This is used as the default function for collation when both `batch_sampler` and `batch_size`
+    are NOT defined in :class:`~torch.utils.data.DataLoader`.
+
+    The general input type to output type mapping is similar to that
+    of :func:`~torch.utils.data.default_collate`. See the description there for more details.
+
+    Args:
+        data: a single data point to be converted
+
+    Examples:
+        >>> # xdoctest: +SKIP
+        >>> # Example with `int`
+        >>> default_convert(0)
+        0
+        >>> # Example with NumPy array
+        >>> default_convert(np.array([0, 1]))
+        tensor([0, 1])
+        >>> # Example with NamedTuple
+        >>> Point = namedtuple('Point', ['x', 'y'])
+        >>> default_convert(Point(0, 0))
+        Point(x=0, y=0)
+        >>> default_convert(Point(np.array(0), np.array(0)))
+        Point(x=tensor(0), y=tensor(0))
+        >>> # Example with List
+        >>> default_convert([np.array([0, 1]), np.array([2, 3])])
+        [tensor([0, 1]), tensor([2, 3])]
+    """
+    elem_type = type(data)
+    if isinstance(data, torch.Tensor):
+        return data
+    elif (
+        elem_type.__module__ == "numpy"
+        and elem_type.__name__ != "str_"
+        and elem_type.__name__ != "string_"
+    ):
+        # array of string classes and object
+        if (
+            elem_type.__name__ == "ndarray"
+            and np_str_obj_array_pattern.search(data.dtype.str) is not None
+        ):
+            return data
+        return torch.as_tensor(data)
+    elif isinstance(data, collections.abc.Mapping):
+        try:
+            if isinstance(data, collections.abc.MutableMapping):
+                # The mapping type may have extra properties, so we can't just
+                # use `type(data)(...)` to create the new mapping.
+                # Create a clone and update it if the mapping type is mutable.
+                clone = copy.copy(data)
+                clone.update({key: default_convert(data[key]) for key in data})
+                return clone
+            else:
+                return elem_type({key: default_convert(data[key]) for key in data})
+        except TypeError:
+            # The mapping type may not support `copy()` / `update(mapping)`
+            # or `__init__(iterable)`.
+            return {key: default_convert(data[key]) for key in data}
+    elif isinstance(data, tuple) and hasattr(data, "_fields"):  # namedtuple
+        return elem_type(*(default_convert(d) for d in data))
+    elif isinstance(data, tuple):
+        return [default_convert(d) for d in data]  # Backwards compatibility.
+    elif isinstance(data, collections.abc.Sequence) and not isinstance(
+        data, (str, bytes)
+    ):
+        try:
+            if isinstance(data, collections.abc.MutableSequence):
+                # The sequence type may have extra properties, so we can't just
+                # use `type(data)(...)` to create the new sequence.
+                # Create a clone and update it if the sequence type is mutable.
+                clone = copy.copy(data)  # type: ignore[arg-type]
+                for i, d in enumerate(data):
+                    clone[i] = default_convert(d)
+                return clone
+            else:
+                return elem_type([default_convert(d) for d in data])
+        except TypeError:
+            # The sequence type may not support `copy()` / `__setitem__(index, item)`
+            # or `__init__(iterable)` (e.g., `range`).
+            return [default_convert(d) for d in data]
+    else:
+        return data
+
+
+default_collate_err_msg_format = (
+    "default_collate: batch must contain tensors, numpy arrays, numbers, "
+    "dicts or lists; found {}"
+)
+
+
+def collate(
+    batch,
+    *,
+    collate_fn_map: Optional[dict[Union[type, tuple[type, ...]], Callable]] = None,
+):
+    r"""
+    General collate function that handles collection type of element within each batch.
+
+    The function also exposes a function registry to deal with specific element types. `default_collate_fn_map`
+    provides default collate functions for tensors, numpy arrays, numbers and strings.
+
+    Args:
+        batch: a single batch to be collated
+        collate_fn_map: Optional dictionary mapping from element type to the corresponding collate function.
+            If the element type isn't present in this dictionary,
+            this function will go through each key of the dictionary in the insertion order to
+            invoke the corresponding collate function if the element type is a subclass of the key.
+
+    Examples:
+        >>> def collate_tensor_fn(batch, *, collate_fn_map):
+        ...     # Extend this function to handle batch of tensors
+        ...     return torch.stack(batch, 0)
+        >>> def custom_collate(batch):
+        ...     collate_map = {torch.Tensor: collate_tensor_fn}
+        ...     return collate(batch, collate_fn_map=collate_map)
+        >>> # Extend `default_collate` by in-place modifying `default_collate_fn_map`
+        >>> default_collate_fn_map.update({torch.Tensor: collate_tensor_fn})
+
+    Note:
+        Each collate function requires a positional argument for batch and a keyword argument
+        for the dictionary of collate functions as `collate_fn_map`.
+    """
+    elem = batch[0]
+    elem_type = type(elem)
+
+    if collate_fn_map is not None:
+        if elem_type in collate_fn_map:
+            return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
+
+        for collate_type in collate_fn_map:
+            if isinstance(elem, collate_type):
+                return collate_fn_map[collate_type](
+                    batch, collate_fn_map=collate_fn_map
+                )
+
+    if isinstance(elem, collections.abc.Mapping):
+        try:
+            if isinstance(elem, collections.abc.MutableMapping):
+                # The mapping type may have extra properties, so we can't just
+                # use `type(data)(...)` to create the new mapping.
+                # Create a clone and update it if the mapping type is mutable.
+                clone = copy.copy(elem)
+                clone.update(
+                    {
+                        key: collate(
+                            [d[key] for d in batch], collate_fn_map=collate_fn_map
+                        )
+                        for key in elem
+                    }
+                )
+                return clone
+            else:
+                return elem_type(
+                    {
+                        key: collate(
+                            [d[key] for d in batch], collate_fn_map=collate_fn_map
+                        )
+                        for key in elem
+                    }
+                )
+        except TypeError:
+            # The mapping type may not support `copy()` / `update(mapping)`
+            # or `__init__(iterable)`.
+            return {
+                key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map)
+                for key in elem
+            }
+    elif isinstance(elem, tuple) and hasattr(elem, "_fields"):  # namedtuple
+        return elem_type(
+            *(
+                collate(samples, collate_fn_map=collate_fn_map)
+                for samples in zip(*batch)
+            )
+        )
+    elif isinstance(elem, collections.abc.Sequence):
+        # check to make sure that the elements in batch have consistent size
+        it = iter(batch)
+        elem_size = len(next(it))
+        if not all(len(elem) == elem_size for elem in it):
+            raise RuntimeError("each element in list of batch should be of equal size")
+        transposed = list(zip(*batch))  # It may be accessed twice, so we use a list.
+
+        if isinstance(elem, tuple):
+            return [
+                collate(samples, collate_fn_map=collate_fn_map)
+                for samples in transposed
+            ]  # Backwards compatibility.
+        else:
+            try:
+                if isinstance(elem, collections.abc.MutableSequence):
+                    # The sequence type may have extra properties, so we can't just
+                    # use `type(data)(...)` to create the new sequence.
+                    # Create a clone and update it if the sequence type is mutable.
+                    clone = copy.copy(elem)  # type: ignore[arg-type]
+                    for i, samples in enumerate(transposed):
+                        clone[i] = collate(samples, collate_fn_map=collate_fn_map)
+                    return clone
+                else:
+                    return elem_type(
+                        [
+                            collate(samples, collate_fn_map=collate_fn_map)
+                            for samples in transposed
+                        ]
+                    )
+            except TypeError:
+                # The sequence type may not support `copy()` / `__setitem__(index, item)`
+                # or `__init__(iterable)` (e.g., `range`).
+                return [
+                    collate(samples, collate_fn_map=collate_fn_map)
+                    for samples in transposed
+                ]
+
+    raise TypeError(default_collate_err_msg_format.format(elem_type))
+
+
+def collate_tensor_fn(
+    batch,
+    *,
+    collate_fn_map: Optional[dict[Union[type, tuple[type, ...]], Callable]] = None,
+):
+    elem = batch[0]
+    out = None
+    if elem.is_nested:
+        raise RuntimeError(
+            "Batches of nested tensors are not currently supported by the default collate_fn; "
+            "please provide a custom collate_fn to handle them appropriately."
+        )
+    if elem.layout in {
+        torch.sparse_coo,
+        torch.sparse_csr,
+        torch.sparse_bsr,
+        torch.sparse_csc,
+        torch.sparse_bsc,
+    }:
+        raise RuntimeError(
+            "Batches of sparse tensors are not currently supported by the default collate_fn; "
+            "please provide a custom collate_fn to handle them appropriately."
+        )
+    if torch.utils.data.get_worker_info() is not None:
+        # If we're in a background process, concatenate directly into a
+        # shared memory tensor to avoid an extra copy
+        numel = sum(x.numel() for x in batch)
+        storage = elem._typed_storage()._new_shared(numel, device=elem.device)
+        out = elem.new(storage).resize_(len(batch), *list(elem.size()))
+    return torch.stack(batch, 0, out=out)
+
+
+def collate_numpy_array_fn(
+    batch,
+    *,
+    collate_fn_map: Optional[dict[Union[type, tuple[type, ...]], Callable]] = None,
+):
+    elem = batch[0]
+    # array of string classes and object
+    if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
+        raise TypeError(default_collate_err_msg_format.format(elem.dtype))
+
+    return collate([torch.as_tensor(b) for b in batch], collate_fn_map=collate_fn_map)
+
+
+def collate_numpy_scalar_fn(
+    batch,
+    *,
+    collate_fn_map: Optional[dict[Union[type, tuple[type, ...]], Callable]] = None,
+):
+    return torch.as_tensor(batch)
+
+
+def collate_float_fn(
+    batch,
+    *,
+    collate_fn_map: Optional[dict[Union[type, tuple[type, ...]], Callable]] = None,
+):
+    return torch.tensor(batch, dtype=torch.float64)
+
+
+def collate_int_fn(
+    batch,
+    *,
+    collate_fn_map: Optional[dict[Union[type, tuple[type, ...]], Callable]] = None,
+):
+    return torch.tensor(batch)
+
+
+def collate_str_fn(
+    batch,
+    *,
+    collate_fn_map: Optional[dict[Union[type, tuple[type, ...]], Callable]] = None,
+):
+    return batch
+
+
+default_collate_fn_map: dict[Union[type, tuple[type, ...]], Callable] = {
+    torch.Tensor: collate_tensor_fn
+}
+with contextlib.suppress(ImportError):
+    import numpy as np
+
+    # For both ndarray and memmap (subclass of ndarray)
+    default_collate_fn_map[np.ndarray] = collate_numpy_array_fn
+    # See scalars hierarchy: https://numpy.org/doc/stable/reference/arrays.scalars.html
+    # Skip string scalars
+    default_collate_fn_map[(np.bool_, np.number, np.object_)] = collate_numpy_scalar_fn
+default_collate_fn_map[float] = collate_float_fn
+default_collate_fn_map[int] = collate_int_fn
+default_collate_fn_map[str] = collate_str_fn
+default_collate_fn_map[bytes] = collate_str_fn
+
+
+def default_collate(batch):
+    r"""
+    Take in a batch of data and put the elements within the batch into a tensor with an additional outer dimension - batch size.
+
+    The exact output type can be a :class:`torch.Tensor`, a `Sequence` of :class:`torch.Tensor`, a
+    Collection of :class:`torch.Tensor`, or left unchanged, depending on the input type.
+    This is used as the default function for collation when
+    `batch_size` or `batch_sampler` is defined in :class:`~torch.utils.data.DataLoader`.
+
+    Here is the general input type (based on the type of the element within the batch) to output type mapping:
+
+        * :class:`torch.Tensor` -> :class:`torch.Tensor` (with an added outer dimension batch size)
+        * NumPy Arrays -> :class:`torch.Tensor`
+        * `float` -> :class:`torch.Tensor`
+        * `int` -> :class:`torch.Tensor`
+        * `str` -> `str` (unchanged)
+        * `bytes` -> `bytes` (unchanged)
+        * `Mapping[K, V_i]` -> `Mapping[K, default_collate([V_1, V_2, ...])]`
+        * `NamedTuple[V1_i, V2_i, ...]` -> `NamedTuple[default_collate([V1_1, V1_2, ...]),
+          default_collate([V2_1, V2_2, ...]), ...]`
+        * `Sequence[V1_i, V2_i, ...]` -> `Sequence[default_collate([V1_1, V1_2, ...]),
+          default_collate([V2_1, V2_2, ...]), ...]`
+
+    Args:
+        batch: a single batch to be collated
+
+    Examples:
+        >>> # xdoctest: +SKIP
+        >>> # Example with a batch of `int`s:
+        >>> default_collate([0, 1, 2, 3])
+        tensor([0, 1, 2, 3])
+        >>> # Example with a batch of `str`s:
+        >>> default_collate(['a', 'b', 'c'])
+        ['a', 'b', 'c']
+        >>> # Example with `Map` inside the batch:
+        >>> default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}])
+        {'A': tensor([  0, 100]), 'B': tensor([  1, 100])}
+        >>> # Example with `NamedTuple` inside the batch:
+        >>> Point = namedtuple('Point', ['x', 'y'])
+        >>> default_collate([Point(0, 0), Point(1, 1)])
+        Point(x=tensor([0, 1]), y=tensor([0, 1]))
+        >>> # Example with `Tuple` inside the batch:
+        >>> default_collate([(0, 1), (2, 3)])
+        [tensor([0, 2]), tensor([1, 3])]
+        >>> # Example with `List` inside the batch:
+        >>> default_collate([[0, 1], [2, 3]])
+        [tensor([0, 2]), tensor([1, 3])]
+        >>> # Two options to extend `default_collate` to handle specific type
+        >>> # Option 1: Write custom collate function and invoke `default_collate`
+        >>> def custom_collate(batch):
+        ...     elem = batch[0]
+        ...     if isinstance(elem, CustomType):  # Some custom condition
+        ...         return ...
+        ...     else:  # Fall back to `default_collate`
+        ...         return default_collate(batch)
+        >>> # Option 2: In-place modify `default_collate_fn_map`
+        >>> def collate_customtype_fn(batch, *, collate_fn_map=None):
+        ...     return ...
+        >>> default_collate_fn_map.update({CustomType: collate_customtype_fn})
+        >>> default_collate(batch)  # Handle `CustomType` automatically
+    """
+    return collate(batch, collate_fn_map=default_collate_fn_map)
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/fetch.py b/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/fetch.py
new file mode 100644
index 0000000000000000000000000000000000000000..3fa6c49404f676ad3811080ff9631e49fb275513
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/fetch.py
@@ -0,0 +1,55 @@
+# mypy: allow-untyped-defs
+r"""Contains definitions of the methods used by the _BaseDataLoaderIter to fetch data from an iterable-style or map-style dataset.
+
+This logic is shared in both single- and multi-processing data loading.
+"""
+
+
+class _BaseDatasetFetcher:
+    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
+        self.dataset = dataset
+        self.auto_collation = auto_collation
+        self.collate_fn = collate_fn
+        self.drop_last = drop_last
+
+    def fetch(self, possibly_batched_index):
+        raise NotImplementedError
+
+
+class _IterableDatasetFetcher(_BaseDatasetFetcher):
+    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
+        super().__init__(dataset, auto_collation, collate_fn, drop_last)
+        self.dataset_iter = iter(dataset)
+        self.ended = False
+
+    def fetch(self, possibly_batched_index):
+        if self.ended:
+            raise StopIteration
+
+        if self.auto_collation:
+            data = []
+            for _ in possibly_batched_index:
+                try:
+                    data.append(next(self.dataset_iter))
+                except StopIteration:
+                    self.ended = True
+                    break
+            if len(data) == 0 or (
+                self.drop_last and len(data) < len(possibly_batched_index)
+            ):
+                raise StopIteration
+        else:
+            data = next(self.dataset_iter)
+        return self.collate_fn(data)
+
+
+class _MapDatasetFetcher(_BaseDatasetFetcher):
+    def fetch(self, possibly_batched_index):
+        if self.auto_collation:
+            if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
+                data = self.dataset.__getitems__(possibly_batched_index)
+            else:
+                data = [self.dataset[idx] for idx in possibly_batched_index]
+        else:
+            data = self.dataset[possibly_batched_index]
+        return self.collate_fn(data)
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/pin_memory.py b/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/pin_memory.py
new file mode 100644
index 0000000000000000000000000000000000000000..0efe39854fb0af6b10c76a264db329e647b0e801
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/pin_memory.py
@@ -0,0 +1,110 @@
+# mypy: allow-untyped-defs
+r"""Contains definitions of the methods used by the _BaseDataLoaderIter to put fetched tensors into pinned memory.
+
+These **need** to be in global scope since Py2 doesn't support serializing
+static methods.
+"""
+
+import collections
+import copy
+import queue
+
+import torch
+from torch._utils import ExceptionWrapper
+
+from . import MP_STATUS_CHECK_INTERVAL
+
+
+def _pin_memory_loop(in_queue, out_queue, device_id, done_event, device):
+    # This setting is thread local, and prevents the copy in pin_memory from
+    # consuming all CPU cores.
+    torch.set_num_threads(1)
+
+    torch.multiprocessing._set_thread_name("pt_data_pin")
+
+    if device == "cuda":
+        torch.cuda.set_device(device_id)
+    elif device == "xpu":
+        torch.xpu.set_device(device_id)  # type: ignore[attr-defined]
+    elif device == torch._C._get_privateuse1_backend_name():
+        custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name())
+        custom_device_mod.set_device(device_id)
+    elif device is None:
+        torch.accelerator.set_device_index(device_id)
+
+    def do_one_step():
+        try:
+            r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
+        except queue.Empty:
+            return
+        idx, data = r
+        if not done_event.is_set() and not isinstance(data, ExceptionWrapper):
+            try:
+                data = pin_memory(data, device)
+            except Exception:
+                data = ExceptionWrapper(
+                    where=f"in pin memory thread for device {device_id}"
+                )
+            r = (idx, data)
+        while not done_event.is_set():
+            try:
+                out_queue.put(r, timeout=MP_STATUS_CHECK_INTERVAL)
+                break
+            except queue.Full:
+                continue
+
+    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
+    # logic of this function.
+    while not done_event.is_set():
+        # Make sure that we don't preserve any object from one iteration
+        # to the next
+        do_one_step()
+
+
+def pin_memory(data, device=None):
+    if isinstance(data, torch.Tensor):
+        return data.pin_memory(device)
+    elif isinstance(data, (str, bytes)):
+        return data
+    elif isinstance(data, collections.abc.Mapping):
+        try:
+            if isinstance(data, collections.abc.MutableMapping):
+                # The mapping type may have extra properties, so we can't just
+                # use `type(data)(...)` to create the new mapping.
+                # Create a clone and update it if the mapping type is mutable.
+                clone = copy.copy(data)
+                clone.update(
+                    {k: pin_memory(sample, device) for k, sample in data.items()}
+                )
+                return clone
+            else:
+                return type(data)({k: pin_memory(sample, device) for k, sample in data.items()})  # type: ignore[call-arg]
+        except TypeError:
+            # The mapping type may not support `copy()` / `update(mapping)`
+            # or `__init__(iterable)`.
+            return {k: pin_memory(sample, device) for k, sample in data.items()}
+    elif isinstance(data, tuple) and hasattr(data, "_fields"):  # namedtuple
+        return type(data)(*(pin_memory(sample, device) for sample in data))
+    elif isinstance(data, tuple):
+        return [
+            pin_memory(sample, device) for sample in data
+        ]  # Backwards compatibility.
+    elif isinstance(data, collections.abc.Sequence):
+        try:
+            if isinstance(data, collections.abc.MutableSequence):
+                # The sequence type may have extra properties, so we can't just
+                # use `type(data)(...)` to create the new sequence.
+                # Create a clone and update it if the sequence type is mutable.
+                clone = copy.copy(data)  # type: ignore[arg-type]
+                for i, item in enumerate(data):
+                    clone[i] = pin_memory(item, device)
+                return clone
+            return type(data)([pin_memory(sample, device) for sample in data])  # type: ignore[call-arg]
+        except TypeError:
+            # The sequence type may not support `copy()` / `__setitem__(index, item)`
+            # or `__init__(iterable)` (e.g., `range`).
+            return [pin_memory(sample, device) for sample in data]
+    elif hasattr(data, "pin_memory"):
+        return data.pin_memory()
+    else:
+        return data
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/signal_handling.py b/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/signal_handling.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1d54f05e360e11e679345334988c8d416e58104
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/signal_handling.py
@@ -0,0 +1,79 @@
+# mypy: allow-untyped-defs
+r"""Signal handling for multiprocessing data loading.
+
+NOTE [ Signal handling in multiprocessing data loading ]
+
+In cases like DataLoader, if a worker process dies due to bus error/segfault
+or just hangs, the main process will hang waiting for data. This is difficult
+to avoid on PyTorch side as it can be caused by limited shm, or other
+libraries users call in the workers. In this file and `DataLoader.cpp`, we make
+our best effort to provide some error message to users when such unfortunate
+events happen.
+
+When a _BaseDataLoaderIter starts worker processes, their pids are registered
+in a map defined in `DataLoader.cpp`, id(_BaseDataLoaderIter) => Collection[ Worker pids ],
+via `_set_worker_pids`.
+
+When an error happens in a worker process, the main process receives a SIGCHLD,
+and Python will eventually call the handler registered below
+(in `_set_SIGCHLD_handler`). In the handler, the `_error_if_any_worker_fails`
+call checks all registered worker pids and raises a proper error message to
+prevent the main process from hanging while waiting for data from a worker.
+
+Additionally, at the beginning of each worker's `_utils.worker._worker_loop`,
+`_set_worker_signal_handlers` is called to register critical signal handlers
+(e.g., for SIGSEGV, SIGBUS, SIGFPE, SIGTERM) in C, which just prints an error
+message to stderr before triggering the default handler. So a message will also
+be printed from the worker process when it is killed by such signals.
+
+See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for the reasoning behind
+this signal handling design and the other mechanisms we implement to make our
+multiprocessing data loading robust to errors.
+"""
+
+import signal
+import threading
+
+# Some of the following imported functions are not used in this file, but are to
+# be used as `_utils.signal_handling.XXXXX`.
+from torch._C import (  # noqa: F401
+    _error_if_any_worker_fails,
+    _remove_worker_pids,
+    _set_worker_pids,
+    _set_worker_signal_handlers,
+)
+
+from . import IS_WINDOWS
+
+
+_SIGCHLD_handler_set = False
+r"""Whether SIGCHLD handler is set for DataLoader worker failures. Only one
+handler needs to be set for all DataLoaders in a process."""
+
+
+def _set_SIGCHLD_handler():
+    # Windows doesn't support SIGCHLD handler
+    if IS_WINDOWS:
+        return
+    # can't set signal in child threads
+    if not isinstance(threading.current_thread(), threading._MainThread):  # type: ignore[attr-defined]
+        return
+    global _SIGCHLD_handler_set
+    if _SIGCHLD_handler_set:
+        return
+    previous_handler = signal.getsignal(signal.SIGCHLD)
+    if not callable(previous_handler):
+        # This doesn't catch default handler, but SIGCHLD default handler is a
+        # no-op.
+        previous_handler = None
+
+    def handler(signum, frame):
+        # This following call uses `waitid` with WNOHANG from C side. Therefore,
+        # Python can still get and update the process status successfully.
+        _error_if_any_worker_fails()
+        if previous_handler is not None:
+            assert callable(previous_handler)
+            previous_handler(signum, frame)
+
+    signal.signal(signal.SIGCHLD, handler)
+    _SIGCHLD_handler_set = True
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/worker.py b/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/worker.py
new file mode 100644
index 0000000000000000000000000000000000000000..a275e2e86b6fff601bc5799f5f301faccf933c8f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/worker.py
@@ -0,0 +1,374 @@
+# mypy: allow-untyped-defs
+r""""Contains definitions of the methods used by the _BaseDataLoaderIter workers.
+
+These **needs** to be in global scope since Py2 doesn't support serializing
+static methods.
+"""
+
+import os
+import queue
+import random
+from dataclasses import dataclass
+from typing import Optional, TYPE_CHECKING, Union
+
+import torch
+from torch._utils import ExceptionWrapper
+
+from . import HAS_NUMPY, IS_WINDOWS, MP_STATUS_CHECK_INTERVAL, signal_handling
+
+
+if TYPE_CHECKING:
+    from torch.utils.data import Dataset
+
+if IS_WINDOWS:
+    import ctypes
+    from ctypes.wintypes import BOOL, DWORD, HANDLE
+
+    # On Windows, the parent ID of the worker process remains unchanged when the manager process
+    # is gone, and the only way to check it through the OS is to let the worker have a process handle
+    # of the manager and ask if the process status has changed.
+    class ManagerWatchdog:
+        def __init__(self) -> None:
+            self.manager_pid = os.getppid()
+
+            # mypy cannot detect this code is windows only
+            self.kernel32 = ctypes.WinDLL("kernel32", use_last_error=True)  # type: ignore[attr-defined]
+            self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD)
+            self.kernel32.OpenProcess.restype = HANDLE
+            self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD)
+            self.kernel32.WaitForSingleObject.restype = DWORD
+
+            # Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx
+            SYNCHRONIZE = 0x00100000
+            self.manager_handle = self.kernel32.OpenProcess(
+                SYNCHRONIZE, 0, self.manager_pid
+            )
+
+            if not self.manager_handle:
+                raise ctypes.WinError(ctypes.get_last_error())  # type: ignore[attr-defined]
+
+            self.manager_dead = False
+
+        def is_alive(self):
+            if not self.manager_dead:
+                # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx
+                self.manager_dead = (
+                    self.kernel32.WaitForSingleObject(self.manager_handle, 0) == 0
+                )
+            return not self.manager_dead
+
+else:
+
+    class ManagerWatchdog:  # type: ignore[no-redef]
+        def __init__(self) -> None:
+            self.manager_pid = os.getppid()
+            self.manager_dead = False
+
+        def is_alive(self):
+            if not self.manager_dead:
+                self.manager_dead = os.getppid() != self.manager_pid
+            return not self.manager_dead
+
+
+_worker_info: Optional["WorkerInfo"] = None
+
+
+class WorkerInfo:
+    id: int
+    num_workers: int
+    seed: int
+    dataset: "Dataset"
+    __initialized = False
+
+    def __init__(self, **kwargs):
+        for k, v in kwargs.items():
+            setattr(self, k, v)
+        self.__keys = tuple(kwargs.keys())
+        self.__initialized = True
+
+    def __setattr__(self, key, val):
+        if self.__initialized:
+            raise RuntimeError(
+                f"Cannot assign attributes to {self.__class__.__name__} objects"
+            )
+        return super().__setattr__(key, val)
+
+    def __repr__(self):
+        items = [f"{k}={getattr(self, k)}" for k in self.__keys]
+        return f"{self.__class__.__name__}({', '.join(items)})"
+
+
+def get_worker_info() -> Optional[WorkerInfo]:
+    r"""Returns the information about the current
+    :class:`~torch.utils.data.DataLoader` iterator worker process.
+
+    When called in a worker, this returns an object guaranteed to have the
+    following attributes:
+
+    * :attr:`id`: the current worker id.
+    * :attr:`num_workers`: the total number of workers.
+    * :attr:`seed`: the random seed set for the current worker. This value is
+      determined by main process RNG and the worker id. See
+      :class:`~torch.utils.data.DataLoader`'s documentation for more details.
+    * :attr:`dataset`: the copy of the dataset object in **this** process. Note
+      that this will be a different object in a different process than the one
+      in the main process.
+
+    When called in the main process, this returns ``None``.
+
+    .. note::
+       When used in a :attr:`worker_init_fn` passed over to
+       :class:`~torch.utils.data.DataLoader`, this method can be useful to
+       set up each worker process differently, for instance, using ``worker_id``
+       to configure the ``dataset`` object to only read a specific fraction of a
+       sharded dataset, or use ``seed`` to seed other libraries used in dataset
+       code.
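+
+    As a minimal sketch (assuming a hypothetical map-style dataset that
+    exposes ``start`` and ``end`` attributes describing its index range), a
+    ``worker_init_fn`` might shard the work across workers like this::
+
+        def worker_init_fn(worker_id):
+            info = torch.utils.data.get_worker_info()
+            ds = info.dataset  # the per-worker copy of the dataset
+            per_worker = (ds.end - ds.start) // info.num_workers
+            ds.start = ds.start + worker_id * per_worker
+            if worker_id != info.num_workers - 1:
+                # every worker except the last gets an equal share;
+                # the last worker keeps any remainder
+                ds.end = ds.start + per_worker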
+    """
+    return _worker_info
+
+
+r"""Dummy class used to signal the end of an IterableDataset"""
+
+
+@dataclass(frozen=True)
+class _IterableDatasetStopIteration:
+    worker_id: int
+
+
+r"""Dummy class used to resume the fetching when worker reuse is enabled"""
+
+
+@dataclass(frozen=True)
+class _ResumeIteration:
+    seed: Optional[int] = None
+
+
+# The function `_generate_state` is adapted from `numpy.random.SeedSequence`
+# from https://github.com/numpy/numpy/blob/main/numpy/random/bit_generator.pyx
+# It's MIT licensed, here is the copyright:
+
+# Copyright (c) 2015 Melissa E. O'Neill
+# Copyright (c) 2019 NumPy Developers
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+
+# This function generates an array of int32 as the seed for
+# `numpy.random`, in order to prevent state collision due to same
+# seed and algorithm for `numpy.random` and `random` modules.
+# TODO: Implement `SeedSequence` like object for `torch.random`
+def _generate_state(base_seed, worker_id):
+    INIT_A = 0x43B0D7E5
+    MULT_A = 0x931E8875
+    INIT_B = 0x8B51F9DD
+    MULT_B = 0x58F38DED
+    MIX_MULT_L = 0xCA01F9DD
+    MIX_MULT_R = 0x4973F715
+    XSHIFT = 4 * 8 // 2
+    MASK32 = 0xFFFFFFFF
+
+    entropy = [worker_id, base_seed & MASK32, base_seed >> 32, 0]
+    pool = [0] * 4
+
+    hash_const_A = INIT_A
+
+    def hash(value):
+        nonlocal hash_const_A
+        value = (value ^ hash_const_A) & MASK32
+        hash_const_A = (hash_const_A * MULT_A) & MASK32
+        value = (value * hash_const_A) & MASK32
+        value = (value ^ (value >> XSHIFT)) & MASK32
+        return value
+
+    def mix(x, y):
+        result_x = (MIX_MULT_L * x) & MASK32
+        result_y = (MIX_MULT_R * y) & MASK32
+        result = (result_x - result_y) & MASK32
+        result = (result ^ (result >> XSHIFT)) & MASK32
+        return result
+
+    # Add in the entropy to the pool.
+    for i in range(len(pool)):
+        pool[i] = hash(entropy[i])
+
+    # Mix all bits together so late bits can affect earlier bits.
+    for i_src in range(len(pool)):
+        for i_dst in range(len(pool)):
+            if i_src != i_dst:
+                pool[i_dst] = mix(pool[i_dst], hash(pool[i_src]))
+
+    hash_const_B = INIT_B
+    state = []
+    for i_dst in range(4):
+        data_val = pool[i_dst]
+        data_val = (data_val ^ hash_const_B) & MASK32
+        hash_const_B = (hash_const_B * MULT_B) & MASK32
+        data_val = (data_val * hash_const_B) & MASK32
+        data_val = (data_val ^ (data_val >> XSHIFT)) & MASK32
+        state.append(data_val)
+    return state
+
+
+def _worker_loop(
+    dataset_kind,
+    dataset,
+    index_queue,
+    data_queue,
+    done_event,
+    auto_collation,
+    collate_fn,
+    drop_last,
+    base_seed,
+    init_fn,
+    worker_id,
+    num_workers,
+    persistent_workers,
+    shared_seed,
+):
+    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
+    # logic of this function.
+
+    try:
+        # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
+        # module's handlers are executed after Python returns from C low-level
+        # handlers, likely when the same fatal signal had already happened
+        # again.
+        # https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers
+        signal_handling._set_worker_signal_handlers()
+
+        torch.multiprocessing._set_thread_name("pt_data_worker")
+
+        torch.set_num_threads(1)
+        seed = base_seed + worker_id
+        random.seed(seed)
+        torch.manual_seed(seed)
+        if HAS_NUMPY:
+            np_seed = _generate_state(base_seed, worker_id)
+            import numpy as np
+
+            np.random.seed(np_seed)
+
+        from torch.utils.data import IterDataPipe
+        from torch.utils.data.graph_settings import apply_random_seed
+
+        shared_rng = torch.Generator()
+        if isinstance(dataset, IterDataPipe):
+            assert shared_seed is not None
+            shared_rng.manual_seed(shared_seed)
+            dataset = apply_random_seed(dataset, shared_rng)
+
+        global _worker_info
+        _worker_info = WorkerInfo(
+            id=worker_id, num_workers=num_workers, seed=seed, dataset=dataset
+        )
+
+        from torch.utils.data import _DatasetKind
+
+        init_exception = None
+
+        try:
+            if init_fn is not None:
+                init_fn(worker_id)
+
+            fetcher = _DatasetKind.create_fetcher(
+                dataset_kind, dataset, auto_collation, collate_fn, drop_last
+            )
+        except Exception:
+            init_exception = ExceptionWrapper(
+                where=f"in DataLoader worker process {worker_id}"
+            )
+
+        # When using Iterable mode, some worker can exit earlier than others due
+        # to the IterableDataset behaving differently for different workers.
+        # When such things happen, an `_IterableDatasetStopIteration` object is
+        # sent over to the main process with the ID of this worker, so that the
+        # main process won't send more tasks to this worker, and will send
+        # `None` to this worker to properly exit it.
+        #
+        # Note that we cannot set `done_event` from a worker as it is shared
+        # among all processes. Instead, we set the `iteration_end` flag to
+        # signify that the iterator is exhausted. When either `done_event` or
+        # `iteration_end` is set, we skip all processing steps and just wait for
+        # `None`.
+        iteration_end = False
+
+        watchdog = ManagerWatchdog()
+
+        while watchdog.is_alive():
+            try:
+                r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
+            except queue.Empty:
+                continue
+            if isinstance(r, _ResumeIteration):
+                # Acknowledge the main process
+                data_queue.put((r, None))
+                iteration_end = False
+
+                if isinstance(dataset, IterDataPipe):
+                    assert r.seed is not None
+                    shared_rng.manual_seed(r.seed)
+                    dataset = apply_random_seed(dataset, shared_rng)
+
+                # Recreate the fetcher for worker-reuse policy
+                fetcher = _DatasetKind.create_fetcher(
+                    dataset_kind, dataset, auto_collation, collate_fn, drop_last
+                )
+                continue
+            elif r is None:
+                # Received the final signal
+                assert done_event.is_set() or iteration_end
+                break
+            elif done_event.is_set() or iteration_end:
+                # `done_event` is set. But I haven't received the final signal
+                # (None) yet. I will keep looping until I get it, and skip the
+                # processing steps.
+                continue
+            idx, index = r
+            data: Union[_IterableDatasetStopIteration, ExceptionWrapper]
+            if init_exception is not None:
+                data = init_exception
+                init_exception = None
+            else:
+                try:
+                    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
+                except Exception as e:
+                    if (
+                        isinstance(e, StopIteration)
+                        and dataset_kind == _DatasetKind.Iterable
+                    ):
+                        data = _IterableDatasetStopIteration(worker_id)
+                        # Set `iteration_end`
+                        #   (1) to save future `next(...)` calls, and
+                        #   (2) to avoid sending multiple `_IterableDatasetStopIteration`s.
+                        iteration_end = True
+                    else:
+                        # It is important that we don't store exc_info in a variable.
+                        # `ExceptionWrapper` does the correct thing.
+                        # See NOTE [ Python Traceback Reference Cycle Problem ]
+                        data = ExceptionWrapper(
+                            where=f"in DataLoader worker process {worker_id}"
+                        )
+            data_queue.put((idx, data))
+            del data, idx, index, r  # save memory
+    except KeyboardInterrupt:
+        # Main process will raise KeyboardInterrupt anyways.
+        pass
+    if done_event.is_set():
+        data_queue.cancel_join_thread()
+        data_queue.close()
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/__init__.py b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac93de335b2d7379246de9cee658dd9eafe1d303
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/__init__.py
@@ -0,0 +1 @@
+from torch.utils.data.datapipes import dataframe as dataframe, iter as iter, map as map
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..db176a99052ff1819d8c99496f4d4baf1e94a067
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/__pycache__/__init__.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/__pycache__/_decorator.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/__pycache__/_decorator.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b9b614d98cae13b174dd6dd815f38420dc94a4d9
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/__pycache__/_decorator.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/__pycache__/_hook_iterator.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/__pycache__/_hook_iterator.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..de248abaa98f9a0a2aa9efad89e307c6b22dfacc
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/__pycache__/_hook_iterator.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/__pycache__/_typing.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/__pycache__/_typing.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6eb4637b55e07e599f2932781583d45dab11e51c
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/__pycache__/_typing.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/__pycache__/datapipe.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/__pycache__/datapipe.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f2808e46e0cac5183caa375b4297f6bfd55c3073
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/__pycache__/datapipe.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/_decorator.py b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/_decorator.py
new file mode 100644
index 0000000000000000000000000000000000000000..13e28a19d626643f62ae7988e15f5cce20720ce9
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/_decorator.py
@@ -0,0 +1,213 @@
+# mypy: allow-untyped-defs
+import inspect
+from functools import wraps
+from typing import Any, Callable, get_type_hints, Optional, Union
+
+from torch.utils.data.datapipes._typing import _DataPipeMeta
+from torch.utils.data.datapipes.datapipe import IterDataPipe, MapDataPipe
+
+
+######################################################
+# Functional API
+######################################################
+class functional_datapipe:
+    name: str
+
+    def __init__(self, name: str, enable_df_api_tracing=False) -> None:
+        """
+        Define a functional datapipe.
+
+        Args:
+            enable_df_api_tracing - if set, any returned DataPipe will accept the
+            DataFrames API in tracing mode.
+        """
+        self.name = name
+        self.enable_df_api_tracing = enable_df_api_tracing
+
+    def __call__(self, cls):
+        if issubclass(cls, IterDataPipe):
+            if isinstance(cls, type):  # type: ignore[arg-type]
+                if not isinstance(cls, _DataPipeMeta):
+                    raise TypeError(
+                        "`functional_datapipe` can only decorate IterDataPipe"
+                    )
+            # with non_deterministic decorator
+            else:
+                if not isinstance(cls, non_deterministic) and not (
+                    hasattr(cls, "__self__")
+                    and isinstance(cls.__self__, non_deterministic)
+                ):
+                    raise TypeError(
+                        "`functional_datapipe` can only decorate IterDataPipe"
+                    )
+            IterDataPipe.register_datapipe_as_function(
+                self.name, cls, enable_df_api_tracing=self.enable_df_api_tracing
+            )
+        elif issubclass(cls, MapDataPipe):
+            MapDataPipe.register_datapipe_as_function(self.name, cls)
+
+        return cls
+
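+# Editorial note: an illustrative sketch (not part of this module) of registering a
+# custom IterDataPipe under a functional name; the class and the name "noop" are
+# hypothetical, and the example is kept as a comment so module behavior is unchanged.
+#
+#   @functional_datapipe("noop")
+#   class NoopIterDataPipe(IterDataPipe):
+#       def __init__(self, source_datapipe):
+#           self.source_datapipe = source_datapipe
+#
+#       def __iter__(self):
+#           yield from self.source_datapipe
+#
+#   # Afterwards any IterDataPipe instance gains a chainable `.noop()` method.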
+
+######################################################
+# Determinism
+######################################################
+_determinism: bool = False
+
+
+class guaranteed_datapipes_determinism:
+    prev: bool
+
+    def __init__(self) -> None:
+        global _determinism
+        self.prev = _determinism
+        _determinism = True
+
+    def __enter__(self) -> None:
+        pass
+
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+        global _determinism
+        _determinism = self.prev
+
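+# Editorial note: an illustrative sketch (not part of this module); inside the block,
+# constructing a DataPipe decorated with `non_deterministic` (defined below) is
+# expected to raise a TypeError. The pipe class is hypothetical; kept as a comment so
+# module behavior is unchanged.
+#
+#   with guaranteed_datapipes_determinism():
+#       dp = SomeNonDeterministicPipe(source)   # expected to raise TypeError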
+
+class non_deterministic:
+    cls: Optional[type[IterDataPipe]] = None
+    # TODO: Lambda for picking
+    deterministic_fn: Callable[[], bool]
+
+    def __init__(self, arg: Union[type[IterDataPipe], Callable[[], bool]]) -> None:
+        # 1. Decorator doesn't have any argument
+        if isinstance(arg, type):  # type: ignore[arg-type]
+            if not issubclass(arg, IterDataPipe):  # type: ignore[arg-type]
+                raise TypeError(
+                    "Only `IterDataPipe` can be decorated with `non_deterministic`"
+                    f", but {arg.__name__} is found"
+                )
+            self.cls = arg  # type: ignore[assignment]
+        # 2. Decorator has an argument of a function
+        #    This class should behave differently given different inputs. Use this
+        #    function to verify the determinism for each instance.
+        #    When the function returns True, the instance is non-deterministic. Otherwise,
+        #    the instance is a deterministic DataPipe.
+        elif isinstance(arg, Callable):  # type:ignore[arg-type]
+            self.deterministic_fn = arg  # type: ignore[assignment, misc]
+        else:
+            raise TypeError(f"{arg} can not be decorated by non_deterministic")
+
+    def __call__(self, *args, **kwargs):
+        global _determinism
+        #  Decorate IterDataPipe
+        if self.cls is not None:
+            if _determinism:
+                raise TypeError(
+                    f"{self.cls.__name__} is non-deterministic, but you set 'guaranteed_datapipes_determinism'. "
+                    "You can turn off determinism for this DataPipe if that is acceptable "
+                    "for your application"
+                )
+            return self.cls(*args, **kwargs)  # type: ignore[call-arg]
+
+        # Decorate with a functional argument
+        if not (
+            isinstance(args[0], type)
+            and issubclass(args[0], IterDataPipe)  # type: ignore[arg-type]
+        ):
+            raise TypeError(
+                f"Only `IterDataPipe` can be decorated, but {args[0].__name__} is found"
+            )
+        self.cls = args[0]
+        return self.deterministic_wrapper_fn
+
+    def deterministic_wrapper_fn(self, *args, **kwargs) -> IterDataPipe:
+        res = self.deterministic_fn(*args, **kwargs)  # type: ignore[call-arg, misc]
+        if not isinstance(res, bool):
+            raise TypeError(
+                "deterministic_fn of `non_deterministic` decorator is required "
+                f"to return a boolean value, but {type(res)} is found"
+            )
+        global _determinism
+        if _determinism and res:
+            raise TypeError(
+                f"{self.cls.__name__} is non-deterministic with the inputs, but you set "  # type: ignore[union-attr]
+                "'guaranteed_datapipes_determinism'. You can turn off determinism "
+                "for this DataPipe if that is acceptable for your application"
+            )
+        return self.cls(*args, **kwargs)  # type: ignore[call-arg, misc]
+
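+# Editorial note: an illustrative sketch (not part of this module) of both decorator
+# forms accepted by `non_deterministic`; the pipe classes are hypothetical, and the
+# example is kept as a comment so module behavior is unchanged.
+#
+#   @non_deterministic                                  # always non-deterministic
+#   class RandomOrderPipe(IterDataPipe):
+#       ...
+#
+#   @non_deterministic(lambda shuffle=False: shuffle)   # non-deterministic only when shuffle=True
+#   class MaybeShuffledPipe(IterDataPipe):
+#       def __init__(self, shuffle=False):
+#           ...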
+
+######################################################
+# Type validation
+######################################################
+# Validate that each DataPipe argument matches its type hint (i.e. is a subtype of the hint).
+def argument_validation(f):
+    signature = inspect.signature(f)
+    hints = get_type_hints(f)
+
+    @wraps(f)
+    def wrapper(*args, **kwargs):
+        bound = signature.bind(*args, **kwargs)
+        for argument_name, value in bound.arguments.items():
+            if argument_name in hints and isinstance(
+                hints[argument_name], _DataPipeMeta
+            ):
+                hint = hints[argument_name]
+                if not isinstance(value, IterDataPipe):
+                    raise TypeError(
+                        f"Expected argument '{argument_name}' as a IterDataPipe, but found {type(value)}"
+                    )
+                if not value.type.issubtype(hint.type):
+                    raise TypeError(
+                        f"Expected type of argument '{argument_name}' as a subtype of "
+                        f"hint {hint.type}, but found {value.type}"
+                    )
+
+        return f(*args, **kwargs)
+
+    return wrapper
+
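+# Editorial note: an illustrative sketch (not part of this module); decorating an
+# `__init__` whose parameter is annotated with a DataPipe type makes passing a
+# non-DataPipe value raise a TypeError. The class is hypothetical; kept as a comment
+# so module behavior is unchanged.
+#
+#   class PassThroughPipe(IterDataPipe):
+#       @argument_validation
+#       def __init__(self, source_datapipe: IterDataPipe):
+#           self.source_datapipe = source_datapipe
+#
+#   PassThroughPipe(source_datapipe=[1, 2, 3])   # expected to raise TypeError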
+
+# Default value is True
+_runtime_validation_enabled: bool = True
+
+
+class runtime_validation_disabled:
+    prev: bool
+
+    def __init__(self) -> None:
+        global _runtime_validation_enabled
+        self.prev = _runtime_validation_enabled
+        _runtime_validation_enabled = False
+
+    def __enter__(self) -> None:
+        pass
+
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+        global _runtime_validation_enabled
+        _runtime_validation_enabled = self.prev
+
+
+# Runtime checking
+# Validate output data is subtype of return hint
+def runtime_validation(f):
+    # TODO:
+    # Can be extended to validate '__getitem__' and nonblocking
+    if f.__name__ != "__iter__":
+        raise TypeError(
+            f"Can not decorate function {f.__name__} with 'runtime_validation'"
+        )
+
+    @wraps(f)
+    def wrapper(self):
+        global _runtime_validation_enabled
+        if not _runtime_validation_enabled:
+            yield from f(self)
+        else:
+            it = f(self)
+            for d in it:
+                if not self.type.issubtype_of_instance(d):
+                    raise RuntimeError(
+                        f"Expected an instance as subtype of {self.type}, but found {d}({type(d)})"
+                    )
+                yield d
+
+    return wrapper
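+
+
+# Editorial note: an illustrative sketch (not part of this module) of applying
+# `runtime_validation` to `__iter__` and temporarily disabling the check. It assumes
+# the DataPipe's `type` attribute reflects the `IterDataPipe[int]` annotation (the
+# typed-DataPipe machinery lives in `_typing.py`). The class is hypothetical; kept as
+# a comment so module behavior is unchanged.
+#
+#   class NumbersPipe(IterDataPipe[int]):
+#       @runtime_validation
+#       def __iter__(self):
+#           yield from [1, 2, "three"]           # "three" violates the int hint
+#
+#   list(NumbersPipe())                          # expected to raise RuntimeError
+#   with runtime_validation_disabled():
+#       list(NumbersPipe())                      # validation skipped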
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/_hook_iterator.py b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/_hook_iterator.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae42f75885c1da320e28fdfc87767a87d61e4271
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/_hook_iterator.py
@@ -0,0 +1,279 @@
+# mypy: allow-untyped-defs
+import functools
+import inspect
+from enum import Enum
+
+import torch
+
+
+class _SnapshotState(Enum):
+    r"""
+    These are the snapshotting-related states that IterDataPipes can be in.
+
+    `NotStarted` - allows you to restore a snapshot and create an iterator with reset
+    `Restored` - cannot restore again, allows you to create an iterator without resetting the DataPipe
+    `Iterating` - can restore, will reset if you create a new iterator
+    """
+
+    NotStarted = 0
+    Restored = 1
+    Iterating = 2
+
+
+def _simplify_obj_name(obj) -> str:
+    """Simplify the display strings of objects for the purpose of rendering within DataPipe error messages."""
+    if inspect.isfunction(obj):
+        return obj.__name__
+    else:
+        return repr(obj)
+
+
+def _strip_datapipe_from_name(name: str) -> str:
+    return name.replace("IterDataPipe", "").replace("MapDataPipe", "")
+
+
+def _generate_input_args_string(obj):
+    """Generate a string for the input arguments of an object."""
+    signature = inspect.signature(obj.__class__)
+    input_param_names = set(signature.parameters.keys())
+    result = []
+    for name, value in inspect.getmembers(obj):
+        if name in input_param_names:
+            result.append((name, _simplify_obj_name(value)))
+    return ", ".join([f"{name}={value}" for name, value in result])
+
+
+def _generate_iterdatapipe_msg(datapipe, simplify_dp_name: bool = False):
+    output_string = (
+        f"{datapipe.__class__.__name__}({_generate_input_args_string(datapipe)})"
+    )
+    if simplify_dp_name:
+        output_string = _strip_datapipe_from_name(output_string)
+    return output_string
+
+
+def _gen_invalid_iterdatapipe_msg(datapipe):
+    return (
+        "This iterator has been invalidated because another iterator has been created "
+        f"from the same IterDataPipe: {_generate_iterdatapipe_msg(datapipe)}\n"
+        "This may be caused multiple references to the same IterDataPipe. We recommend "
+        "using `.fork()` if that is necessary."
+    )
+
+
+_feedback_msg = (
+    "\nFor feedback regarding this single iterator per IterDataPipe constraint, feel free "
+    "to comment on this issue: https://github.com/pytorch/data/issues/45."
+)
+
+
+def _check_iterator_valid(datapipe, iterator_id, next_method_exists=False) -> None:
+    r"""
+    Given an instance of a DataPipe and an iterator ID, check if the IDs match, and if not, raise an exception.
+
+    In the case of ChildDataPipe, the ID gets compared to the one stored in `main_datapipe` as well.
+    """
+    if next_method_exists:
+        # This is the case where `IterDataPipe` has both `__iter__` and `__next__`.
+        # The `_valid_iterator_id` should either be never set (`None`), or set by at most one
+        # iterator (`0`). Otherwise, it means there are multiple iterators.
+        if datapipe._valid_iterator_id is not None and datapipe._valid_iterator_id != 0:
+            extra_msg = "\nNote that this exception is raised inside your IterDataPipe's a `__next__` method"
+            raise RuntimeError(
+                _gen_invalid_iterdatapipe_msg(datapipe) + extra_msg + _feedback_msg
+            )
+    elif (
+        hasattr(datapipe, "_is_child_datapipe") and datapipe._is_child_datapipe is True
+    ):
+        if hasattr(datapipe, "_check_valid_iterator_id"):
+            if not datapipe._check_valid_iterator_id(iterator_id):
+                raise RuntimeError(
+                    "This iterator has been invalidated, because a new iterator has been created "
+                    f"from one of the ChildDataPipes of "
+                    f"{_generate_iterdatapipe_msg(datapipe.main_datapipe)}."
+                    + _feedback_msg
+                )
+        else:
+            raise RuntimeError(
+                "ChildDataPipe must have method `_check_valid_iterator_id`."
+            )
+    elif datapipe._valid_iterator_id != iterator_id:
+        raise RuntimeError(_gen_invalid_iterdatapipe_msg(datapipe) + _feedback_msg)
+
+
+def _set_datapipe_valid_iterator_id(datapipe):
+    """Given a DataPipe, updates its valid iterator ID and reset the DataPipe."""
+    if hasattr(datapipe, "_is_child_datapipe") and datapipe._is_child_datapipe is True:
+        if hasattr(datapipe, "_set_main_datapipe_valid_iterator_id"):
+            datapipe._set_main_datapipe_valid_iterator_id()  # reset() is called within this method when appropriate
+        else:
+            raise RuntimeError(
+                "ChildDataPipe must have method `_set_main_datapipe_valid_iterator_id`."
+            )
+    else:
+        if datapipe._valid_iterator_id is None:
+            datapipe._valid_iterator_id = 0
+        else:
+            datapipe._valid_iterator_id += 1
+        datapipe.reset()
+    return datapipe._valid_iterator_id
+
+
+def hook_iterator(namespace):
+    r"""
+    Define a hook that is applied to the `__iter__` of every class with metaclass `_DataPipeMeta`.
+
+    This is done for the purpose of profiling and checking if an iterator is still valid.
+    """
+
+    def profiler_record_fn_context(datapipe):
+        if not hasattr(datapipe, "_profile_name"):
+            datapipe._profile_name = _generate_iterdatapipe_msg(
+                datapipe, simplify_dp_name=True
+            )
+        return torch.autograd.profiler.record_function(datapipe._profile_name)
+
+    class IteratorDecorator:
+        r"""
+        Wrap the iterator and modify its `__next__` method.
+
+        This decorator is applied to DataPipes whose `__iter__` method is NOT a generator function.
+        Such `__iter__` methods commonly return `self`, but not necessarily.
+        """
+
+        def __init__(self, iterator, datapipe, iterator_id, has_next_method):
+            self.iterator = iterator
+            self.datapipe = datapipe
+            self.iterator_id = iterator_id
+            self._profiler_enabled = torch.autograd._profiler_enabled()
+            # Check if `__iter__` returns `self` and `DataPipe` has `__next__`
+            self.self_and_has_next_method = (
+                self.iterator is self.datapipe and has_next_method
+            )
+
+        def __iter__(self):
+            return self
+
+        def _get_next(self):
+            """Return next with logic related to iterator validity, profiler, and incrementation of samples yielded."""
+            _check_iterator_valid(self.datapipe, self.iterator_id)
+            result = next(self.iterator)
+            if not self.self_and_has_next_method:
+                self.datapipe._number_of_samples_yielded += 1
+            return result
+
+        def __next__(self):
+            # TODO: Add try-except to in-place reduce traceback from the Exception
+            # See: https://github.com/pytorch/data/issues/284
+            if self._profiler_enabled:
+                with profiler_record_fn_context(self.datapipe):
+                    return self._get_next()
+            else:  # Decided against using `contextlib.nullcontext` for performance reasons
+                return self._get_next()
+
+        def __getattr__(self, name):
+            return getattr(self.iterator, name)
+
+    func = namespace["__iter__"]
+
+    # ``__iter__`` of IterDataPipe is a generator function
+    if inspect.isgeneratorfunction(func):
+
+        @functools.wraps(func)
+        def wrap_generator(*args, **kwargs):
+            gen = func(*args, **kwargs)
+            datapipe = args[0]
+            if datapipe._fast_forward_iterator:
+                it = datapipe._fast_forward_iterator
+                datapipe._fast_forward_iterator = None
+                datapipe._snapshot_state = _SnapshotState.Iterating
+                while True:
+                    try:
+                        yield next(it)
+                    except StopIteration:
+                        return
+            iterator_id = _set_datapipe_valid_iterator_id(
+                datapipe
+            )  # This ID is tied to each created iterator
+            _profiler_enabled = torch.autograd._profiler_enabled()
+            try:
+                if _profiler_enabled:
+                    with profiler_record_fn_context(datapipe):
+                        response = gen.send(None)
+                else:
+                    response = gen.send(None)
+
+                while True:
+                    datapipe._number_of_samples_yielded += 1
+                    request = yield response
+                    # Pass through here every time `__next__` is called
+                    if _profiler_enabled:
+                        with profiler_record_fn_context(datapipe):
+                            _check_iterator_valid(datapipe, iterator_id)
+                            response = gen.send(request)
+                    else:  # Decided against using `contextlib.nullcontext` for performance reasons
+                        _check_iterator_valid(datapipe, iterator_id)
+                        response = gen.send(request)
+            except StopIteration:
+                return
+            except Exception as e:
+                # TODO: Simplify the traceback message to skip over `response = gen.send(None)`
+                #       Part of https://github.com/pytorch/data/issues/284
+                datapipe = args[0]
+                msg = "thrown by __iter__ of"
+                single_iterator_msg = "single iterator per IterDataPipe constraint"
+                if hasattr(e.args, "__len__"):
+                    full_msg = f"{msg} {datapipe.__class__.__name__}({_generate_input_args_string(datapipe)})"
+                    if len(e.args) == 0 or not isinstance(
+                        e.args[0], str
+                    ):  # If an exception message doesn't exist
+                        e.args = (f"\nThis exception is {full_msg}",)
+                    elif msg not in e.args[0] and single_iterator_msg not in e.args[0]:
+                        e.args = (
+                            e.args[0] + f"\nThis exception is {full_msg}",
+                        ) + e.args[1:]
+                raise
+
+        namespace["__iter__"] = wrap_generator
+    else:  # ``__iter__`` of IterDataPipe is NOT a generator function
+        # IterDataPipe is an iterator with both ``__iter__`` and ``__next__``
+        # And ``__iter__`` may or may not return `self`
+        if "__next__" in namespace:  # If `__next__` exists, put a wrapper around it
+            next_func = namespace["__next__"]
+
+            @functools.wraps(next_func)
+            def wrap_next(*args, **kwargs):
+                datapipe = args[0]
+                if torch.autograd._profiler_enabled():
+                    with profiler_record_fn_context(datapipe):
+                        result = next_func(*args, **kwargs)
+                else:
+                    result = next_func(*args, **kwargs)
+                datapipe._number_of_samples_yielded += 1
+                return result
+
+            namespace["__next__"] = wrap_next
+
+            # Note that if `__next__` and `__iter__` do something completely unrelated, it may cause issues, but
+            # the user would be violating the iterator protocol. Potential issues:
+            # 1. The valid iterator ID may not be updated or checked properly
+            # 2. The number of samples yielded will be miscounted
+
+        # Regardless if `__next__` exists or not, `__iter__` needs a wrapper to track the number of valid iterators
+        @functools.wraps(func)
+        def wrap_iter(*args, **kwargs):
+            iter_ret = func(*args, **kwargs)
+            datapipe = args[0]
+            datapipe._snapshot_state = _SnapshotState.Iterating
+            if datapipe._fast_forward_iterator:
+                iter_ret = datapipe._fast_forward_iterator
+                datapipe._fast_forward_iterator = None
+                return iter_ret
+            iterator_id = _set_datapipe_valid_iterator_id(
+                datapipe
+            )  # This ID is tied to each created iterator
+            return IteratorDecorator(
+                iter_ret, datapipe, iterator_id, "__next__" in namespace
+            )
+
+        namespace["__iter__"] = wrap_iter
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/_typing.py b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/_typing.py
new file mode 100644
index 0000000000000000000000000000000000000000..0df815358bd0e402ead283cba54157ca5d0aa383
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/_typing.py
@@ -0,0 +1,482 @@
+# mypy: allow-untyped-defs
+# Taking reference from official Python typing
+# https://github.com/python/cpython/blob/master/Lib/typing.py
+
+import collections
+import functools
+import numbers
+import sys
+
+# Please check [Note: TypeMeta and TypeAlias]
+# In case of metaclass conflict due to ABCMeta or _ProtocolMeta
+# For Python 3.9, only Protocol in typing uses metaclass
+from abc import ABCMeta
+from collections.abc import Iterator
+
+# TODO: Use TypeAlias when Python 3.6 is deprecated
+from typing import (  # type: ignore[attr-defined]
+    _eval_type,
+    _GenericAlias,
+    _tp_cache,
+    _type_check,
+    _type_repr,
+    Any,
+    ForwardRef,
+    Generic,
+    get_type_hints,
+    TypeVar,
+    Union,
+)
+
+from torch.utils.data.datapipes._hook_iterator import _SnapshotState, hook_iterator
+
+
+class GenericMeta(ABCMeta):  # type: ignore[no-redef]
+    pass
+
+
+class Integer(numbers.Integral):
+    pass
+
+
+class Boolean(numbers.Integral):
+    pass
+
+
+# Python 'type' object is not subscriptable
+# Tuple[int, List, dict] -> valid
+# tuple[int, list, dict] -> invalid
+# Map Python 'type' to abstract base class
+TYPE2ABC = {
+    bool: Boolean,
+    int: Integer,
+    float: numbers.Real,
+    complex: numbers.Complex,
+    dict: dict,
+    list: list,
+    set: set,
+    tuple: tuple,
+    None: type(None),
+}
+
+
+def issubtype(left, right, recursive=True):
+    r"""
+    Check if the left-side type is a subtype of the right-side type.
+
+    If either type is a composite type like `Union` or a `TypeVar` with
+    bounds, it is expanded into a list of types, and the check passes when
+    every left-side type is a subtype of at least one right-side type.
+    """
+    left = TYPE2ABC.get(left, left)
+    right = TYPE2ABC.get(right, right)
+
+    if right is Any or left == right:
+        return True
+
+    if isinstance(right, _GenericAlias):
+        if getattr(right, "__origin__", None) is Generic:
+            return True
+
+    if right == type(None):
+        return False
+
+    # Right-side type
+    constraints = _decompose_type(right)
+
+    if len(constraints) == 0 or Any in constraints:
+        return True
+
+    if left is Any:
+        return False
+
+    # Left-side type
+    variants = _decompose_type(left)
+
+    # all() will return True for empty variants
+    if len(variants) == 0:
+        return False
+
+    return all(
+        _issubtype_with_constraints(variant, constraints, recursive)
+        for variant in variants
+    )
+
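+# Editorial note: illustrative expectations (not part of this module) for `issubtype`,
+# kept as a comment so module behavior is unchanged; `List` and `Union` here are from
+# the `typing` module.
+#
+#   issubtype(int, Union[int, str])    # True: int matches one of the constraints
+#   issubtype(List[int], List)         # True: an untyped generic accepts any args
+#   issubtype(Union[int, str], int)    # False: str is not a subtype of int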
+
+def _decompose_type(t, to_list=True):
+    if isinstance(t, TypeVar):
+        if t.__bound__ is not None:
+            ts = [t.__bound__]
+        else:
+            # For T_co, __constraints__ is ()
+            ts = list(t.__constraints__)
+    elif hasattr(t, "__origin__") and t.__origin__ == Union:
+        ts = t.__args__
+    else:
+        if not to_list:
+            return None
+        ts = [t]
+    # Ignored: Generator has incompatible item type "object"; expected "Type[Any]"
+    ts = [TYPE2ABC.get(_t, _t) for _t in ts]  # type: ignore[misc]
+    return ts
+
+
+def _issubtype_with_constraints(variant, constraints, recursive=True):
+    r"""
+    Check if the variant is a subtype of either one from constraints.
+
+    Composite types like `Union` and `TypeVar` with bounds are expanded
+    for testing.
+    """
+    if variant in constraints:
+        return True
+
+    # [Note: Subtype for Union and TypeVar]
+    # Python typing is able to flatten Union[Union[...]] or Union[TypeVar].
+    # But it couldn't flatten the following scenarios:
+    #   - Union[int, TypeVar[Union[...]]]
+    #   - TypeVar[TypeVar[...]]
+    # So, variant and each constraint may be a TypeVar or a Union.
+    # In these cases, all inner types of the variant must be extracted and
+    # verified as subtypes of some constraint. Likewise, the inner types of any
+    # constraint that is a TypeVar or a Union must also be extracted and
+    # checked to see whether the variant belongs to any of them.
+
+    # Variant
+    vs = _decompose_type(variant, to_list=False)
+
+    # Variant is TypeVar or Union
+    if vs is not None:
+        return all(_issubtype_with_constraints(v, constraints, recursive) for v in vs)
+
+    # Variant is not TypeVar or Union
+    if hasattr(variant, "__origin__") and variant.__origin__ is not None:
+        v_origin = variant.__origin__
+        # In Python-3.9 typing library untyped generics do not have args
+        v_args = getattr(variant, "__args__", None)
+    else:
+        v_origin = variant
+        v_args = None
+
+    # Constraints
+    for constraint in constraints:
+        cs = _decompose_type(constraint, to_list=False)
+
+        # Constraint is TypeVar or Union
+        if cs is not None:
+            if _issubtype_with_constraints(variant, cs, recursive):
+                return True
+        # Constraint is not TypeVar or Union
+        else:
+            # __origin__ can be None for plain list, tuple, ... in Python 3.6
+            if hasattr(constraint, "__origin__") and constraint.__origin__ is not None:
+                c_origin = constraint.__origin__
+                if v_origin == c_origin:
+                    if not recursive:
+                        return True
+                    # In Python-3.9 typing library untyped generics do not have args
+                    c_args = getattr(constraint, "__args__", None)
+                    if c_args is None or len(c_args) == 0:
+                        return True
+                    if (
+                        v_args is not None
+                        and len(v_args) == len(c_args)
+                        and all(
+                            issubtype(v_arg, c_arg)
+                            for v_arg, c_arg in zip(v_args, c_args)
+                        )
+                    ):
+                        return True
+            # Tuple[int] -> Tuple
+            else:
+                if v_origin == constraint:
+                    return True
+
+    return False
+
+
+def issubinstance(data, data_type):
+    if not issubtype(type(data), data_type, recursive=False):
+        return False
+
+    # In Python-3.9 typing library __args__ attribute is not defined for untyped generics
+    dt_args = getattr(data_type, "__args__", None)
+    if isinstance(data, tuple):
+        if dt_args is None or len(dt_args) == 0:
+            return True
+        if len(dt_args) != len(data):
+            return False
+        return all(issubinstance(d, t) for d, t in zip(data, dt_args))
+    elif isinstance(data, (list, set)):
+        if dt_args is None or len(dt_args) == 0:
+            return True
+        t = dt_args[0]
+        return all(issubinstance(d, t) for d in data)
+    elif isinstance(data, dict):
+        if dt_args is None or len(dt_args) == 0:
+            return True
+        kt, vt = dt_args
+        return all(
+            issubinstance(k, kt) and issubinstance(v, vt) for k, v in data.items()
+        )
+
+    return True
+
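+# Editorial note: illustrative expectations (not part of this module) for
+# `issubinstance`, kept as a comment so module behavior is unchanged; `List` and
+# `Dict` here are from the `typing` module.
+#
+#   issubinstance([1, 2, 3], List[int])        # True
+#   issubinstance([1, "a"], List[int])         # False: "a" is not an int
+#   issubinstance({"k": 1}, Dict[str, int])    # True: keys and values both match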
+
+# [Note: TypeMeta and TypeAlias]
+# In order to keep compatibility for Python 3.6, use Meta for the typing.
+# TODO: When PyTorch drops the support for Python 3.6, it can be converted
+# into the Alias system and using `__class_getitem__` for DataPipe. The
+# typing system will gain benefit of performance and resolving metaclass
+# conflicts as elaborated in https://www.python.org/dev/peps/pep-0560/
+
+
+class _DataPipeType:
+    r"""Save type annotation in `param`."""
+
+    def __init__(self, param):
+        self.param = param
+
+    def __repr__(self):
+        return _type_repr(self.param)
+
+    def __eq__(self, other):
+        if isinstance(other, _DataPipeType):
+            return self.param == other.param
+        return NotImplemented
+
+    def __hash__(self):
+        return hash(self.param)
+
+    def issubtype(self, other):
+        if isinstance(other.param, _GenericAlias):
+            if getattr(other.param, "__origin__", None) is Generic:
+                return True
+        if isinstance(other, _DataPipeType):
+            return issubtype(self.param, other.param)
+        if isinstance(other, type):
+            return issubtype(self.param, other)
+        raise TypeError(f"Expected '_DataPipeType' or 'type', but found {type(other)}")
+
+    def issubtype_of_instance(self, other):
+        return issubinstance(other, self.param)
+
+
+# Default type for DataPipe without annotation
+_T_co = TypeVar("_T_co", covariant=True)
+_DEFAULT_TYPE = _DataPipeType(Generic[_T_co])
+
+
+class _DataPipeMeta(GenericMeta):
+    r"""
+    Metaclass for `DataPipe`.
+
+    Add `type` attribute and `__init_subclass__` based on the type, and validate the return hint of `__iter__`.
+
+    Note that there is subclass `_IterDataPipeMeta` specifically for `IterDataPipe`.
+    """
+
+    type: _DataPipeType
+
+    def __new__(cls, name, bases, namespace, **kwargs):
+        return super().__new__(cls, name, bases, namespace, **kwargs)  # type: ignore[call-overload]
+
+        # TODO: the statements below are not reachable by design as there is a bug and typing is low priority for now.
+        cls.__origin__ = None
+        if "type" in namespace:
+            return super().__new__(cls, name, bases, namespace, **kwargs)  # type: ignore[call-overload]
+
+        namespace["__type_class__"] = False
+        #  For plain derived class without annotation
+        for base in bases:
+            if isinstance(base, _DataPipeMeta):
+                return super().__new__(cls, name, bases, namespace, **kwargs)  # type: ignore[call-overload]
+
+        namespace.update(
+            {"type": _DEFAULT_TYPE, "__init_subclass__": _dp_init_subclass}
+        )
+        return super().__new__(cls, name, bases, namespace, **kwargs)  # type: ignore[call-overload]
+
+    def __init__(self, name, bases, namespace, **kwargs):
+        super().__init__(name, bases, namespace, **kwargs)  # type: ignore[call-overload]
+
+    # TODO: Fix isinstance bug
+    @_tp_cache
+    def _getitem_(self, params):
+        if params is None:
+            raise TypeError(f"{self.__name__}[t]: t can not be None")
+        if isinstance(params, str):
+            params = ForwardRef(params)
+        if not isinstance(params, tuple):
+            params = (params,)
+
+        msg = f"{self.__name__}[t]: t must be a type"
+        params = tuple(_type_check(p, msg) for p in params)
+
+        if isinstance(self.type.param, _GenericAlias):
+            orig = getattr(self.type.param, "__origin__", None)
+            if isinstance(orig, type) and orig is not Generic:
+                p = self.type.param[params]  # type: ignore[index]
+                t = _DataPipeType(p)
+                l = len(str(self.type)) + 2
+                name = self.__name__[:-l]
+                name = name + "[" + str(t) + "]"
+                bases = (self,) + self.__bases__
+                return self.__class__(
+                    name,
+                    bases,
+                    {
+                        "__init_subclass__": _dp_init_subclass,
+                        "type": t,
+                        "__type_class__": True,
+                    },
+                )
+
+        if len(params) > 1:
+            raise TypeError(
+                f"Too many parameters for {self} actual {len(params)}, expected 1"
+            )
+
+        t = _DataPipeType(params[0])
+
+        if not t.issubtype(self.type):
+            raise TypeError(
+                f"Can not subclass a DataPipe[{t}] from DataPipe[{self.type}]"
+            )
+
+        # Types are equal, fast path for inheritance
+        if self.type == t:
+            return self
+
+        name = self.__name__ + "[" + str(t) + "]"
+        bases = (self,) + self.__bases__
+
+        return self.__class__(
+            name,
+            bases,
+            {"__init_subclass__": _dp_init_subclass, "__type_class__": True, "type": t},
+        )
+
+    # TODO: Fix isinstance bug
+    def _eq_(self, other):
+        if not isinstance(other, _DataPipeMeta):
+            return NotImplemented
+        if self.__origin__ is None or other.__origin__ is None:  # type: ignore[has-type]
+            return self is other
+        return (
+            self.__origin__ == other.__origin__  # type: ignore[has-type]
+            and self.type == other.type
+        )
+
+    # TODO: Fix isinstance bug
+    def _hash_(self):
+        return hash((self.__name__, self.type))
+
+
+class _IterDataPipeMeta(_DataPipeMeta):
+    r"""
+    Metaclass for `IterDataPipe` and inherits from `_DataPipeMeta`.
+
+    Add various functions for behaviors specific to `IterDataPipe`.
+    """
+
+    def __new__(cls, name, bases, namespace, **kwargs):
+        if "reset" in namespace:
+            reset_func = namespace["reset"]
+
+            @functools.wraps(reset_func)
+            def conditional_reset(*args, **kwargs):
+                r"""
+                Only execute DataPipe's `reset()` method if `_SnapshotState` is `Iterating` or `NotStarted`.
+
+                This allows a recently restored DataPipe to preserve its restored state during the initial `__iter__` call.
+                """
+                datapipe = args[0]
+                if datapipe._snapshot_state in (
+                    _SnapshotState.Iterating,
+                    _SnapshotState.NotStarted,
+                ):
+                    # Resetting in the `NotStarted` state is necessary because the `source_datapipe` of a DataPipe might have
+                    # already begun iterating.
+                    datapipe._number_of_samples_yielded = 0
+                    datapipe._fast_forward_iterator = None
+                    reset_func(*args, **kwargs)
+                datapipe._snapshot_state = _SnapshotState.Iterating
+
+            namespace["reset"] = conditional_reset
+
+        if "__iter__" in namespace:
+            hook_iterator(namespace)
+        return super().__new__(cls, name, bases, namespace, **kwargs)  # type: ignore[call-overload]
+
+
+def _dp_init_subclass(sub_cls, *args, **kwargs):
+    # Add function for datapipe instance to reinforce the type
+    sub_cls.reinforce_type = reinforce_type
+
+    # TODO:
+    # - add global switch for type checking at compile-time
+
+    # Ignore internal type class
+    if getattr(sub_cls, "__type_class__", False):
+        return
+
+    # Check if the string type is valid
+    if isinstance(sub_cls.type.param, ForwardRef):
+        base_globals = sys.modules[sub_cls.__module__].__dict__
+        try:
+            param = _eval_type(sub_cls.type.param, base_globals, locals())
+            sub_cls.type.param = param
+        except TypeError as e:
+            raise TypeError(
+                f"{sub_cls.type.param.__forward_arg__} is not supported by Python typing"
+            ) from e
+
+    if "__iter__" in sub_cls.__dict__:
+        iter_fn = sub_cls.__dict__["__iter__"]
+        hints = get_type_hints(iter_fn)
+        if "return" in hints:
+            return_hint = hints["return"]
+            # Plain Return Hint for Python 3.6
+            if return_hint == Iterator:
+                return
+            if not (
+                hasattr(return_hint, "__origin__")
+                and (
+                    return_hint.__origin__ == Iterator
+                    or return_hint.__origin__ == collections.abc.Iterator
+                )
+            ):
+                raise TypeError(
+                    "Expected 'Iterator' as the return annotation for `__iter__` of {}"
+                    ", but found {}".format(
+                        sub_cls.__name__, _type_repr(hints["return"])
+                    )
+                )
+            data_type = return_hint.__args__[0]
+            if not issubtype(data_type, sub_cls.type.param):
+                raise TypeError(
+                    f"Expected return type of '__iter__' as a subtype of {sub_cls.type},"
+                    f" but found {_type_repr(data_type)} for {sub_cls.__name__}"
+                )
+
+
+def reinforce_type(self, expected_type):
+    r"""
+    Reinforce the type for a DataPipe instance.
+
+    The 'expected_type' is required to be a subtype of the original type
+    hint, restricting the type requirement of the DataPipe instance.
+    """
+    if isinstance(expected_type, tuple):
+        expected_type = tuple[expected_type]  # type: ignore[valid-type]
+    _type_check(expected_type, msg="'expected_type' must be a type")
+
+    if not issubtype(expected_type, self.type.param):
+        raise TypeError(
+            f"Expected 'expected_type' as subtype of {self.type}, but found {_type_repr(expected_type)}"
+        )
+
+    self.type = _DataPipeType(expected_type)
+    return self
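+
+
+# Editorial note: an illustrative sketch (not part of this module) of narrowing a
+# DataPipe's declared type after construction; the pipe is hypothetical, and the
+# example is kept as a comment so module behavior is unchanged.
+#
+#   dp = SomeIterDataPipe(...)        # declared type defaults to Generic[_T_co]
+#   dp = dp.reinforce_type(int)       # now declared to yield ints; dp.type wraps int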
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__init__.py b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9feb5f113c0f37f820da9f5abeea14e6bed96c44
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__init__.py
@@ -0,0 +1,11 @@
+from torch.utils.data.datapipes.dataframe.dataframes import (
+    CaptureDataFrame,
+    DFIterDataPipe,
+)
+from torch.utils.data.datapipes.dataframe.datapipes import DataFramesAsTuplesPipe
+
+
+__all__ = ["CaptureDataFrame", "DFIterDataPipe", "DataFramesAsTuplesPipe"]
+
+# Please keep this list sorted
+assert __all__ == sorted(__all__)
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aaa8cd6565235abcb927f31d5fe17c8da885c7fb
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/__init__.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/dataframe_wrapper.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/dataframe_wrapper.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ccc026ee946092d62415f84bbb8c3bf33dcd4191
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/dataframe_wrapper.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/dataframes.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/dataframes.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2ac01dd0e841ca748e8b673b51635ae4e53f89ac
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/dataframes.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/datapipes.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/datapipes.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ef08560d78484483b4c425d55ddcbf94099b4a20
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/datapipes.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/structures.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/structures.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d0c15cc1e2a6dbdab10efb1a3fee7675eee5cb88
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/structures.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/dataframe_wrapper.py b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/dataframe_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..4bbd2505b4b5fe3720d8a5eec9c0da65aa5a5cab
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/dataframe_wrapper.py
@@ -0,0 +1,128 @@
+# mypy: allow-untyped-defs
+from typing import Any, Optional
+
+
+_pandas: Any = None
+_WITH_PANDAS: Optional[bool] = None
+
+
+def _try_import_pandas() -> bool:
+    try:
+        import pandas  # type: ignore[import]
+
+        global _pandas
+        _pandas = pandas
+        return True
+    except ImportError:
+        return False
+
+
+# pandas used only for prototyping, will be shortly replaced with TorchArrow
+def _with_pandas() -> bool:
+    global _WITH_PANDAS
+    if _WITH_PANDAS is None:
+        _WITH_PANDAS = _try_import_pandas()
+    return _WITH_PANDAS
+
+
+class PandasWrapper:
+    @classmethod
+    def create_dataframe(cls, data, columns):
+        if not _with_pandas():
+            raise RuntimeError("DataFrames prototype requires pandas to function")
+        return _pandas.DataFrame(data, columns=columns)  # type: ignore[union-attr]
+
+    @classmethod
+    def is_dataframe(cls, data):
+        if not _with_pandas():
+            return False
+        return isinstance(data, _pandas.core.frame.DataFrame)  # type: ignore[union-attr]
+
+    @classmethod
+    def is_column(cls, data):
+        if not _with_pandas():
+            return False
+        return isinstance(data, _pandas.core.series.Series)  # type: ignore[union-attr]
+
+    @classmethod
+    def iterate(cls, data):
+        if not _with_pandas():
+            raise RuntimeError("DataFrames prototype requires pandas to function")
+        yield from data.itertuples(index=False)
+
+    @classmethod
+    def concat(cls, buffer):
+        if not _with_pandas():
+            raise RuntimeError("DataFrames prototype requires pandas to function")
+        return _pandas.concat(buffer)  # type: ignore[union-attr]
+
+    @classmethod
+    def get_item(cls, data, idx):
+        if not _with_pandas():
+            raise RuntimeError("DataFrames prototype requires pandas to function")
+        return data[idx : idx + 1]
+
+    @classmethod
+    def get_len(cls, df):
+        if not _with_pandas():
+            raise RuntimeError("DataFrames prototype requires pandas to function")
+        return len(df.index)
+
+    @classmethod
+    def get_columns(cls, df):
+        if not _with_pandas():
+            raise RuntimeError("DataFrames prototype requires pandas to function")
+        return list(df.columns.values.tolist())
+
+
+# When you build your own implementation, just override it with dataframe_wrapper.set_df_wrapper(new_wrapper_class)
+default_wrapper = PandasWrapper
+
+
+def get_df_wrapper():
+    return default_wrapper
+
+
+def set_df_wrapper(wrapper):
+    global default_wrapper
+    default_wrapper = wrapper
+
+
+def create_dataframe(data, columns=None):
+    wrapper = get_df_wrapper()
+    return wrapper.create_dataframe(data, columns)
+
+
+def is_dataframe(data):
+    wrapper = get_df_wrapper()
+    return wrapper.is_dataframe(data)
+
+
+def get_columns(data):
+    wrapper = get_df_wrapper()
+    return wrapper.get_columns(data)
+
+
+def is_column(data):
+    wrapper = get_df_wrapper()
+    return wrapper.is_column(data)
+
+
+def concat(buffer):
+    wrapper = get_df_wrapper()
+    return wrapper.concat(buffer)
+
+
+def iterate(data):
+    wrapper = get_df_wrapper()
+    return wrapper.iterate(data)
+
+
+def get_item(data, idx):
+    wrapper = get_df_wrapper()
+    return wrapper.get_item(data, idx)
+
+
+def get_len(df):
+    wrapper = get_df_wrapper()
+    return wrapper.get_len(df)
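+
+
+# Editorial note: an illustrative sketch (not part of this module) of swapping in a
+# custom backend in place of the default PandasWrapper; the wrapper class and its
+# `num_rows` attribute are hypothetical, and the example is kept as a comment so
+# module behavior is unchanged.
+#
+#   class ArrowWrapper(PandasWrapper):
+#       @classmethod
+#       def get_len(cls, df):
+#           return df.num_rows          # assumes an Arrow-like table object
+#
+#   set_df_wrapper(ArrowWrapper)
+#   assert get_df_wrapper() is ArrowWrapper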
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/dataframes.py b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/dataframes.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3958ec0c793f64a32c7144248e89b792f70a7c4
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/dataframes.py
@@ -0,0 +1,457 @@
+# mypy: allow-untyped-defs
+from typing import Any, Optional
+
+from torch.utils.data.datapipes._decorator import functional_datapipe
+from torch.utils.data.datapipes.dataframe.structures import DataChunkDF
+from torch.utils.data.datapipes.datapipe import DFIterDataPipe, IterDataPipe
+
+
+# TODO(VitalyFedyunin): Add error when two different traces get combined
+
+__all__ = [
+    "Capture",
+    "CaptureA",
+    "CaptureAdd",
+    "CaptureCall",
+    "CaptureControl",
+    "CaptureDataFrame",
+    "CaptureDataFrameWithDataPipeOps",
+    "CaptureF",
+    "CaptureGetAttr",
+    "CaptureGetItem",
+    "CaptureInitial",
+    "CaptureLikeMock",
+    "CaptureMul",
+    "CaptureSetItem",
+    "CaptureSub",
+    "CaptureVariable",
+    "CaptureVariableAssign",
+    "DataFrameTracer",
+    "DataFrameTracedOps",
+    "disable_capture",
+    "get_val",
+]
+
+
+def disable_capture():
+    CaptureControl.disabled = True
+
+
+class CaptureControl:
+    disabled = False
+
+
+class DataFrameTracedOps(DFIterDataPipe):
+    def __init__(self, source_datapipe, output_var):
+        self.source_datapipe = source_datapipe
+        self.output_var = output_var
+
+    def __iter__(self):
+        for item in self.source_datapipe:
+            yield self.output_var.apply_ops(item)
+
+
+#  TODO(VitalyFedyunin): Extract this list from the DFIterDataPipe registered functions
+DATAPIPES_OPS = [
+    "_dataframes_as_tuples",
+    "groupby",
+    "_dataframes_filter",
+    "map",
+    "to_datapipe",
+    "shuffle",
+    "concat",
+    "batch",
+    "_dataframes_per_row",
+    "_dataframes_concat",
+    "_dataframes_shuffle",
+]
+
+UNIMPLEMENTED_ATTR = ["__deepcopy__", "__setstate__", "is_shardable", "apply_sharding"]
+
+
+class Capture:
+    # TODO: All operations are shared across the entire InitialCapture; need to figure out what happens if we join two captures
+
+    def __init__(self, schema_df=None):
+        self.ctx = {"operations": [], "variables": [], "schema_df": schema_df}
+
+    def __str__(self):
+        return self._ops_str()
+
+    def _ops_str(self):
+        res = ""
+        for op in self.ctx["operations"]:
+            if len(res) > 0:
+                res += "\n"
+            res += str(op)
+        return res
+
+    def __getstate__(self):
+        # TODO(VitalyFedyunin): Currently can't pickle (why?)
+        self.ctx["schema_df"] = None
+        for var in self.ctx["variables"]:
+            var.calculated_value = None
+        state = {}
+        for item in self.__dict__:
+            state[item] = getattr(self, item)
+        return state
+
+    def __setstate__(self, state):
+        for k, v in state.items():
+            setattr(self, k, v)
+
+    def __getattr__(self, attrname):
+        if attrname == "kwarg" or attrname == "kwargs":
+            raise RuntimeError("no kwargs!")
+        if attrname in ["__deepcopy__"]:
+            raise AttributeError
+        result = CaptureGetAttr(self, attrname, ctx=self.ctx)
+        return result
+
+    def __getitem__(self, key):
+        return CaptureGetItem(self, key, ctx=self.ctx)
+
+    def __setitem__(self, key, value):
+        self.ctx["operations"].append(CaptureSetItem(self, key, value, ctx=self.ctx))
+
+    def __add__(self, add_val):
+        res = CaptureAdd(self, add_val, ctx=self.ctx)
+        var = CaptureVariable(res, ctx=self.ctx)
+        self.ctx["operations"].append(
+            CaptureVariableAssign(variable=var, value=res, ctx=self.ctx)
+        )
+        return var
+
+    def __sub__(self, add_val):
+        res = CaptureSub(self, add_val, ctx=self.ctx)
+        var = CaptureVariable(res, ctx=self.ctx)
+        self.ctx["operations"].append(
+            CaptureVariableAssign(variable=var, value=res, ctx=self.ctx)
+        )
+        return var
+
+    def __mul__(self, add_val):
+        res = CaptureMul(self, add_val, ctx=self.ctx)
+        var = CaptureVariable(res, ctx=self.ctx)
+        t = CaptureVariableAssign(variable=var, value=res, ctx=self.ctx)
+        self.ctx["operations"].append(t)
+        return var
+
+    def _is_context_empty(self):
+        return len(self.ctx["operations"]) == 0 and len(self.ctx["variables"]) == 0
+
+    def apply_ops_2(self, dataframe):
+        # TODO(VitalyFedyunin): Make this calculation thread safe (as currently it updates pointer)
+        self.ctx["variables"][0].calculated_value = dataframe
+        for op in self.ctx["operations"]:
+            op.execute()
+
+    @property
+    def columns(self):
+        self.apply_ops_2(self.ctx["schema_df"])
+        value = self.execute()
+        return value.columns
+
+    # TODO(VitalyFedyunin): Add tests
+    # TODO(VitalyFedyunin): Need to join contexts if one of them is empty because we used capture
+
+    def __call__(self, *args, **kwargs):
+        # TODO: Check if args or kwargs have more than one different context
+        if self._is_context_empty():
+            # TODO: Allow CaptureA to take context from mock
+            for arg in args:
+                if isinstance(arg, Capture) and not arg._is_context_empty():
+                    self.ctx = arg.ctx
+                    break
+            if self._is_context_empty():
+                for k, v in kwargs.items():
+                    if isinstance(k, Capture) and not k._is_context_empty():
+                        self.ctx = k.ctx
+                        break
+                    if isinstance(v, Capture) and not v._is_context_empty():
+                        self.ctx = v.ctx
+                        break
+
+        res = CaptureCall(self, ctx=self.ctx, args=args, kwargs=kwargs)
+        var = CaptureVariable(None, ctx=self.ctx)
+        t = CaptureVariableAssign(ctx=self.ctx, variable=var, value=res)
+        self.ctx["operations"].append(t)
+        return var
+
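+# Example (editorial sketch, not part of the upstream file; assumes capturing is enabled
+# and that `some_df` / `real_df` are hypothetical pandas-like DataFrames): arithmetic and
+# item access on Capture objects are recorded into ctx["operations"] and only executed
+# later via apply_ops/execute, roughly:
+#
+#     inp = CaptureInitial(schema_df=some_df)
+#     out = inp["price"] * 2 + inp["tax"]      # nothing is computed yet
+#     print(inp._ops_str())                     # e.g. 'var_1 = input_var_0["price"] * 2' ...
+#     out.apply_ops(real_df)                    # replays the recorded ops on `real_df`
+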
+
+class CaptureF(Capture):
+    def __init__(self, ctx=None, **kwargs):
+        if ctx is None:
+            self.ctx = {"operations": [], "variables": []}
+        else:
+            self.ctx = ctx
+        self.kwargs = kwargs
+
+
+class CaptureA(CaptureF):
+    def __str__(self):
+        return f"{self.kwargs['name']}"
+
+    def execute(self):
+        value = self.kwargs["real_attribute"]
+        return value
+
+
+class CaptureLikeMock:
+    def __init__(self, name):
+        import unittest.mock as mock
+
+        # TODO(VitalyFedyunin): Do not use private function here; copy our own implementation instead.
+        get_target, attribute = mock._get_target(name)  # type: ignore[attr-defined]
+        self.get_target = get_target
+        self.attribute = attribute
+        self.name = name
+
+    def __enter__(self):
+        self.save = getattr(self.get_target(), self.attribute)
+        capt = CaptureA(name=self.name, real_attribute=self.save)
+        setattr(self.get_target(), self.attribute, capt)
+
+    def __exit__(self, *exc_info):
+        setattr(self.get_target(), self.attribute, self.save)
+
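+# Usage sketch (editorial; the patched target below is hypothetical): CaptureLikeMock
+# temporarily swaps an attribute for a CaptureA, so calls made through it while the
+# context is active are recorded as CaptureCall operations instead of being executed:
+#
+#     with CaptureLikeMock("some_module.some_function"):
+#         ...  # code here sees a CaptureA in place of some_module.some_function
+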
+
+class CaptureCall(Capture):
+    def __init__(self, callable, ctx=None, **kwargs):
+        if ctx is None:
+            self.ctx = {"operations": [], "variables": []}
+        else:
+            self.ctx = ctx
+        self.kwargs = kwargs
+        self.callable = callable
+
+    def __str__(self):
+        return "{callable}({args},{kwargs})".format(
+            callable=self.callable, **self.kwargs
+        )
+
+    def execute(self):
+        # TODO: VitalyFedyunin execute kwargs and maybe nested structures
+        executed_args = []
+        for arg in self.kwargs["args"]:
+            if isinstance(arg, Capture):
+                executed_args.append(arg.execute())
+            else:
+                executed_args.append(arg)
+        left = get_val(self.callable)
+        return left(*executed_args, **self.kwargs["kwargs"])
+
+
+class CaptureVariableAssign(CaptureF):
+    def __str__(self):
+        variable = self.kwargs["variable"]
+        value = self.kwargs["value"]
+        return f"{variable} = {value}"
+
+    def execute(self):
+        self.kwargs["variable"].calculated_value = self.kwargs["value"].execute()
+
+
+class CaptureVariable(Capture):
+    # TODO(VitalyFedyunin): This should be atomic and thread safe
+    names_idx = 0
+
+    def __init__(self, value, ctx):
+        if CaptureControl.disabled:
+            raise RuntimeError("Attempting to create capture variable with capture off")
+        self.ctx = ctx
+        self.value = value
+        self.name = f"var_{CaptureVariable.names_idx}"
+        CaptureVariable.names_idx += 1
+        self.ctx["variables"].append(self)
+
+    def __str__(self):
+        return self.name
+
+    def execute(self):
+        return self.calculated_value
+
+    def apply_ops(self, dataframe):
+        # TODO(VitalyFedyunin): Make this calculation thread safe (as currently it updates pointer)
+        self.ctx["variables"][0].calculated_value = dataframe
+        for op in self.ctx["operations"]:
+            op.execute()
+        return self.calculated_value
+
+
+class CaptureGetItem(Capture):
+    def __init__(self, left, key, ctx):
+        self.ctx = ctx
+        self.left = left
+        self.key = key
+
+    def __str__(self):
+        return f"{self.left}[{get_val(self.key)}]"
+
+    def execute(self):
+        left = self.left.execute()
+        return left[self.key]
+
+
+class CaptureSetItem(Capture):
+    def __init__(self, left, key, value, ctx):
+        self.ctx = ctx
+        self.left = left
+        self.key = key
+        self.value = value
+
+    def __str__(self):
+        return f"{self.left}[{get_val(self.key)}] = {self.value}"
+
+    def execute(self):
+        left = self.left.execute()
+        value = self.value.execute()
+        left[self.key] = value
+
+
+class CaptureAdd(Capture):
+    def __init__(self, left, right, ctx):
+        self.ctx = ctx
+        self.left = left
+        self.right = right
+
+    def __str__(self):
+        return f"{self.left} + {self.right}"
+
+    def execute(self):
+        return get_val(self.left) + get_val(self.right)
+
+
+class CaptureMul(Capture):
+    def __init__(self, left, right, ctx):
+        self.ctx = ctx
+        self.left = left
+        self.right = right
+
+    def __str__(self):
+        return f"{self.left} * {self.right}"
+
+    def execute(self):
+        return get_val(self.left) * get_val(self.right)
+
+
+class CaptureSub(Capture):
+    def __init__(self, left, right, ctx):
+        self.ctx = ctx
+        self.left = left
+        self.right = right
+
+    def __str__(self):
+        return f"{self.left} - {self.right}"
+
+    def execute(self):
+        return get_val(self.left) - get_val(self.right)
+
+
+class CaptureGetAttr(Capture):
+    def __init__(self, src, name, ctx):
+        self.ctx = ctx
+        self.src = src
+        self.name = name
+
+    def __str__(self):
+        return f"{self.src}.{self.name}"
+
+    def execute(self):
+        val = get_val(self.src)
+        return getattr(val, self.name)
+
+
+def get_val(capture):
+    if isinstance(capture, Capture):
+        return capture.execute()
+    elif isinstance(capture, str):
+        return f'"{capture}"'
+    else:
+        return capture
+
+
+class CaptureInitial(CaptureVariable):
+    def __init__(self, schema_df=None):
+        new_ctx: dict[str, list[Any]] = {
+            "operations": [],
+            "variables": [],
+            "schema_df": schema_df,
+        }
+        super().__init__(None, new_ctx)
+        self.name = f"input_{self.name}"
+
+
+class CaptureDataFrame(CaptureInitial):
+    pass
+
+
+class CaptureDataFrameWithDataPipeOps(CaptureDataFrame):
+    def as_datapipe(self):
+        return DataFrameTracedOps(self.ctx["variables"][0].source_datapipe, self)
+
+    def raw_iterator(self):
+        return self.as_datapipe().__iter__()
+
+    def __iter__(self):
+        return iter(self._dataframes_as_tuples())
+
+    def batch(self, batch_size=10, drop_last: bool = False, wrapper_class=DataChunkDF):
+        dp = self._dataframes_per_row()._dataframes_concat(batch_size)
+        dp = dp.as_datapipe().batch(1, drop_last=drop_last, wrapper_class=wrapper_class)
+        dp._dp_contains_dataframe = True
+        return dp
+
+    def groupby(
+        self,
+        group_key_fn,
+        *,
+        buffer_size=10000,
+        group_size=None,
+        guaranteed_group_size=None,
+        drop_remaining=False,
+    ):
+        dp = self._dataframes_per_row()
+        dp = dp.as_datapipe().groupby(
+            group_key_fn,
+            buffer_size=buffer_size,
+            group_size=group_size,
+            guaranteed_group_size=guaranteed_group_size,
+            drop_remaining=drop_remaining,
+        )
+        return dp
+
+    def shuffle(self, *args, **kwargs):
+        return self._dataframes_shuffle(*args, **kwargs)
+
+    def filter(self, *args, **kwargs):
+        return self._dataframes_filter(*args, **kwargs)
+
+    def collate(self, *args, **kwargs):
+        raise RuntimeError("Can't collate unbatched DataFrames stream")
+
+    def __getattr__(self, attrname):  # ?
+        if attrname in UNIMPLEMENTED_ATTR:
+            raise AttributeError("Attempting to get ", attrname)
+        if attrname in DATAPIPES_OPS:
+            return (self.as_datapipe()).__getattr__(attrname)
+        return super().__getattr__(attrname)
+
+
+@functional_datapipe("trace_as_dataframe")
+class DataFrameTracer(CaptureDataFrameWithDataPipeOps, IterDataPipe):  # type: ignore[misc]
+    source_datapipe: Optional[Any] = None
+
+    # TODO(VitalyFedyunin): Must implement all special functions of datapipes
+
+    def set_shuffle_settings(self, *args, **kwargs):
+        pass
+
+    def is_shardable(self):
+        return False
+
+    def __init__(self, source_datapipe, schema_df=None):
+        self.source_datapipe = source_datapipe
+        if schema_df is None:
+            schema_df = next(iter(self.source_datapipe))
+        super().__init__(schema_df=schema_df)
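+
+
+# Usage sketch (editorial, not part of the upstream file; `df_source` is a hypothetical
+# datapipe yielding pandas DataFrames): `trace_as_dataframe` wraps the stream so that
+# DataFrame-style operations are captured once and replayed per incoming DataFrame:
+#
+#     traced = df_source.trace_as_dataframe()
+#     traced["total"] = traced["price"] * traced["qty"]   # captured via CaptureSetItem/CaptureMul
+#     for row in traced:                                   # executed lazily, row by row
+#         ...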
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/datapipes.py b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/datapipes.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9b89d6437aab6aff00998ce2af88b308f7f78bc
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/datapipes.py
@@ -0,0 +1,136 @@
+# mypy: allow-untyped-defs
+import random
+from typing import Any
+
+from torch.utils.data.datapipes._decorator import functional_datapipe
+from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper
+from torch.utils.data.datapipes.datapipe import DFIterDataPipe, IterDataPipe
+
+
+__all__ = [
+    "ConcatDataFramesPipe",
+    "DataFramesAsTuplesPipe",
+    "ExampleAggregateAsDataFrames",
+    "FilterDataFramesPipe",
+    "PerRowDataFramesPipe",
+    "ShuffleDataFramesPipe",
+]
+
+
+@functional_datapipe("_dataframes_as_tuples")
+class DataFramesAsTuplesPipe(IterDataPipe):
+    def __init__(self, source_datapipe):
+        self.source_datapipe = source_datapipe
+
+    def __iter__(self):
+        for df in self.source_datapipe:
+            # for record in df.to_records(index=False):
+            yield from df_wrapper.iterate(df)
+
+
+@functional_datapipe("_dataframes_per_row", enable_df_api_tracing=True)
+class PerRowDataFramesPipe(DFIterDataPipe):
+    def __init__(self, source_datapipe):
+        self.source_datapipe = source_datapipe
+
+    def __iter__(self):
+        for df in self.source_datapipe:
+            # TODO(VitalyFedyunin): Replace with TorchArrow-only API, as we are dropping pandas as a follow-up
+            for i in range(len(df)):
+                yield df[i : i + 1]
+
+
+@functional_datapipe("_dataframes_concat", enable_df_api_tracing=True)
+class ConcatDataFramesPipe(DFIterDataPipe):
+    def __init__(self, source_datapipe, batch=3):
+        self.source_datapipe = source_datapipe
+        self.n_batch = batch
+
+    def __iter__(self):
+        buffer = []
+        for df in self.source_datapipe:
+            buffer.append(df)
+            if len(buffer) == self.n_batch:
+                yield df_wrapper.concat(buffer)
+                buffer = []
+        if len(buffer):
+            yield df_wrapper.concat(buffer)
+
+
+@functional_datapipe("_dataframes_shuffle", enable_df_api_tracing=True)
+class ShuffleDataFramesPipe(DFIterDataPipe):
+    def __init__(self, source_datapipe):
+        self.source_datapipe = source_datapipe
+
+    def __iter__(self):
+        size = None
+        all_buffer: list[Any] = []
+        for df in self.source_datapipe:
+            if size is None:
+                size = df_wrapper.get_len(df)
+            all_buffer.extend(
+                df_wrapper.get_item(df, i) for i in range(df_wrapper.get_len(df))
+            )
+        random.shuffle(all_buffer)
+        buffer = []
+        for df in all_buffer:
+            buffer.append(df)
+            if len(buffer) == size:
+                yield df_wrapper.concat(buffer)
+                buffer = []
+        if len(buffer):
+            yield df_wrapper.concat(buffer)
+
+
+@functional_datapipe("_dataframes_filter", enable_df_api_tracing=True)
+class FilterDataFramesPipe(DFIterDataPipe):
+    def __init__(self, source_datapipe, filter_fn):
+        self.source_datapipe = source_datapipe
+        self.filter_fn = filter_fn
+
+    def __iter__(self):
+        size = None
+        all_buffer = []
+        filter_res = []
+        for df in self.source_datapipe:
+            if size is None:
+                size = len(df.index)
+            for i in range(len(df.index)):
+                all_buffer.append(df[i : i + 1])
+                filter_res.append(self.filter_fn(df.iloc[i]))
+
+        buffer = []
+        for df, res in zip(all_buffer, filter_res):
+            if res:
+                buffer.append(df)
+                if len(buffer) == size:
+                    yield df_wrapper.concat(buffer)
+                    buffer = []
+        if len(buffer):
+            yield df_wrapper.concat(buffer)
+
+
+@functional_datapipe("_to_dataframes_pipe", enable_df_api_tracing=True)
+class ExampleAggregateAsDataFrames(DFIterDataPipe):
+    def __init__(self, source_datapipe, dataframe_size=10, columns=None):
+        self.source_datapipe = source_datapipe
+        self.columns = columns
+        self.dataframe_size = dataframe_size
+
+    def _as_list(self, item):
+        try:
+            return list(item)
+        except (
+            Exception
+        ):  # TODO(VitalyFedyunin): Replace with better iterable exception
+            return [item]
+
+    def __iter__(self):
+        aggregate = []
+        for item in self.source_datapipe:
+            aggregate.append(self._as_list(item))
+            if len(aggregate) == self.dataframe_size:
+                yield df_wrapper.create_dataframe(aggregate, columns=self.columns)
+                aggregate = []
+        if len(aggregate) > 0:
+            yield df_wrapper.create_dataframe(aggregate, columns=self.columns)
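+
+
+# Usage sketch (editorial; `IterableWrapper` from torchdata is assumed): aggregate plain
+# items into DataFrames of at most `dataframe_size` rows via the functional form:
+#
+#     dp = IterableWrapper([(i, i * 2) for i in range(25)])
+#     df_dp = dp._to_dataframes_pipe(dataframe_size=10, columns=["x", "y"])
+#     # yields three DataFrames with 10, 10 and 5 rows respectively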
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/structures.py b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/structures.py
new file mode 100644
index 0000000000000000000000000000000000000000..26b4c33db03cc584f223444c07730ef67f4495e7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/dataframe/structures.py
@@ -0,0 +1,22 @@
+from collections.abc import Iterator
+from typing import Any
+
+from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper
+from torch.utils.data.datapipes.datapipe import DataChunk
+
+
+__all__ = ["DataChunkDF"]
+
+
+class DataChunkDF(DataChunk):
+    """DataChunkDF iterating over individual items inside of DataFrame containers, to access DataFrames user `raw_iterator`."""
+
+    def __iter__(self) -> Iterator[Any]:
+        for df in self.items:
+            yield from df_wrapper.iterate(df)
+
+    def __len__(self) -> int:
+        total_len = 0
+        for df in self.items:
+            total_len += df_wrapper.get_len(df)
+        return total_len
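+
+
+# Sketch (editorial; `df_a` and `df_b` are hypothetical DataFrames): a DataChunkDF wraps
+# a list of DataFrames, its length is the total number of rows, iteration yields
+# individual rows, and `raw_iterator()` yields the DataFrames themselves:
+#
+#     chunk = DataChunkDF([df_a, df_b])
+#     assert len(chunk) == df_wrapper.get_len(df_a) + df_wrapper.get_len(df_b)
+#     rows = list(chunk)                   # flattened rows across both frames
+#     frames = list(chunk.raw_iterator())  # [df_a, df_b]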
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/datapipe.py b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/datapipe.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3eeee0ebfdd539ea4351e57711034fc25084046
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/datapipe.py
@@ -0,0 +1,416 @@
+import functools
+import pickle
+from collections.abc import Iterable, Iterator
+from typing import Callable, Optional, TypeVar
+
+from torch.utils._import_utils import import_dill
+from torch.utils.data.datapipes._hook_iterator import _SnapshotState
+from torch.utils.data.datapipes._typing import _DataPipeMeta, _IterDataPipeMeta
+from torch.utils.data.datapipes.utils.common import (
+    _deprecation_warning,
+    _iter_deprecated_functional_names,
+    _map_deprecated_functional_names,
+)
+from torch.utils.data.dataset import Dataset, IterableDataset
+
+
+dill = import_dill()
+HAS_DILL = dill is not None
+
+__all__ = [
+    "DataChunk",
+    "DFIterDataPipe",
+    "IterDataPipe",
+    "MapDataPipe",
+]
+
+
+_T = TypeVar("_T")
+_T_co = TypeVar("_T_co", covariant=True)
+
+UNTRACABLE_DATAFRAME_PIPES = [
+    "batch",  # As it returns DataChunks
+    "groupby",  # As it returns DataChunks
+    "_dataframes_as_tuples",  # As it unpacks DF
+    "trace_as_dataframe",  # As it used to mark DF for tracing
+]
+
+
+class DataChunk(list[_T]):
+    def __init__(self, items: Iterable[_T]) -> None:
+        items = list(items)
+        super().__init__(items)
+        self.items = items
+
+    def as_str(self, indent: str = "") -> str:
+        return indent + "[" + ", ".join(str(i) for i in iter(self)) + "]"
+
+    def __iter__(self) -> Iterator[_T]:
+        yield from super().__iter__()
+
+    def raw_iterator(self) -> Iterator[_T]:
+        yield from self.items
+
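+# NOTE (added commentary): DataChunk is a thin `list` subclass used as the default batch
+# container; it keeps the wrapped items accessible both by iteration and via
+# `raw_iterator()`, e.g. DataChunk([1, 2, 3]).as_str() returns "[1, 2, 3]".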
+
+class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta):
+    r"""
+    Iterable-style DataPipe.
+
+    All DataPipes that represent an iterable of data samples should subclass this.
+    This style of DataPipes is particularly useful when data come from a stream, or
+    when the number of samples is too large to fit them all in memory. ``IterDataPipe`` is lazily initialized and its
+    elements are computed only when ``next()`` is called on the iterator of an ``IterDataPipe``.
+
+    All subclasses should overwrite :meth:`__iter__`, which would return an
+    iterator of samples in this DataPipe. Calling ``__iter__`` of an ``IterDataPipe`` automatically invokes its
+    method ``reset()``, which by default performs no operation. When writing a custom ``IterDataPipe``, users should
+    override ``reset()`` if necessary. The common usages include resetting buffers, pointers,
+    and various state variables within the custom ``IterDataPipe``.
+
+    Note:
+        Only `one` iterator can be valid for each ``IterDataPipe`` at a time,
+        and the creation of a second iterator will invalidate the first one. This constraint is necessary because
+        some ``IterDataPipe`` have internal buffers, whose states can become invalid if there are multiple iterators.
+        The code example below presents details on how this constraint looks in practice.
+        If you have any feedback related to this constraint, please see `GitHub IterDataPipe Single Iterator Issue`_.
+
+    These DataPipes can be invoked in two ways, using the class constructor or applying their
+    functional form onto an existing ``IterDataPipe`` (recommended, available to most but not all DataPipes).
+    You can chain multiple `IterDataPipe` together to form a pipeline that will perform multiple
+    operations in succession.
+
+    .. _GitHub IterDataPipe Single Iterator Issue:
+        https://github.com/pytorch/data/issues/45
+
+    Note:
+        When a subclass is used with :class:`~torch.utils.data.DataLoader`, each
+        item in the DataPipe will be yielded from the :class:`~torch.utils.data.DataLoader`
+        iterator. When :attr:`num_workers > 0`, each worker process will have a
+        different copy of the DataPipe object, so it is often desired to configure
+        each copy independently to avoid having duplicate data returned from the
+        workers. :func:`~torch.utils.data.get_worker_info`, when called in a worker
+        process, returns information about the worker. It can be used in either the
+        dataset's :meth:`__iter__` method or the :class:`~torch.utils.data.DataLoader` 's
+        :attr:`worker_init_fn` option to modify each copy's behavior.
+
+    Examples:
+        General Usage:
+            >>> # xdoctest: +SKIP
+            >>> from torchdata.datapipes.iter import IterableWrapper, Mapper
+            >>> dp = IterableWrapper(range(10))
+            >>> map_dp_1 = Mapper(dp, lambda x: x + 1)  # Using class constructor
+            >>> map_dp_2 = dp.map(lambda x: x + 1)  # Using functional form (recommended)
+            >>> list(map_dp_1)
+            [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+            >>> list(map_dp_2)
+            [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+            >>> filter_dp = map_dp_1.filter(lambda x: x % 2 == 0)
+            >>> list(filter_dp)
+            [2, 4, 6, 8, 10]
+        Single Iterator Constraint Example:
+            >>> from torchdata.datapipes.iter import IterableWrapper, Mapper
+            >>> source_dp = IterableWrapper(range(10))
+            >>> it1 = iter(source_dp)
+            >>> list(it1)
+            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
+            >>> it1 = iter(source_dp)
+            >>> it2 = iter(source_dp)  # The creation of a new iterator invalidates `it1`
+            >>> next(it2)
+            0
+            >>> next(it1)  # Further usage of `it1` will raise a `RuntimeError`
+    """
+
+    functions: dict[str, Callable] = {}
+    reduce_ex_hook: Optional[Callable] = None
+    getstate_hook: Optional[Callable] = None
+    str_hook: Optional[Callable] = None
+    repr_hook: Optional[Callable] = None
+    _valid_iterator_id: Optional[int] = None
+    _number_of_samples_yielded: int = 0
+    _snapshot_state: _SnapshotState = _SnapshotState.NotStarted
+    _fast_forward_iterator: Optional[Iterator] = None
+
+    def __iter__(self) -> Iterator[_T_co]:
+        return self
+
+    def __getattr__(self, attribute_name):
+        if attribute_name in IterDataPipe.functions:
+            if attribute_name in _iter_deprecated_functional_names:
+                kwargs = _iter_deprecated_functional_names[attribute_name]
+                _deprecation_warning(**kwargs)
+            f = IterDataPipe.functions[attribute_name]
+            function = functools.partial(f, self)
+            functools.update_wrapper(wrapper=function, wrapped=f, assigned=("__doc__",))
+            return function
+        else:
+            raise AttributeError(
+                f"'{self.__class__.__name__}' object has no attribute '{attribute_name}"
+            )
+
+    @classmethod
+    def register_function(cls, function_name, function):
+        cls.functions[function_name] = function
+
+    @classmethod
+    def register_datapipe_as_function(
+        cls, function_name, cls_to_register, enable_df_api_tracing=False
+    ):
+        if function_name in cls.functions:
+            raise Exception(  # noqa: TRY002
+                f"Unable to add DataPipe function name {function_name} as it is already taken"
+            )
+
+        def class_function(cls, enable_df_api_tracing, source_dp, *args, **kwargs):
+            result_pipe = cls(source_dp, *args, **kwargs)
+            if isinstance(result_pipe, IterDataPipe):
+                if enable_df_api_tracing or isinstance(source_dp, DFIterDataPipe):
+                    if function_name not in UNTRACABLE_DATAFRAME_PIPES:
+                        result_pipe = result_pipe.trace_as_dataframe()
+
+            return result_pipe
+
+        function = functools.partial(
+            class_function, cls_to_register, enable_df_api_tracing
+        )
+        functools.update_wrapper(
+            wrapper=function, wrapped=cls_to_register, assigned=("__doc__",)
+        )
+        cls.functions[function_name] = function
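+
+    # How registration is used (added commentary): the `functional_datapipe` decorator calls
+    # `register_datapipe_as_function`, after which e.g. `dp.map(fn)` constructs
+    # `MapperIterDataPipe(dp, fn)`. If DataFrame API tracing applies and the op is not in
+    # UNTRACABLE_DATAFRAME_PIPES, the result is additionally wrapped via `trace_as_dataframe()`.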
+
+    def __getstate__(self):
+        """
+        Serialize `lambda` functions when `dill` is available.
+
+        If this doesn't cover your custom DataPipe's use case, consider writing custom methods for
+        `__getstate__` and `__setstate__`, or use `pickle.dumps` for serialization.
+        """
+        state = self.__dict__
+        if IterDataPipe.getstate_hook is not None:
+            return IterDataPipe.getstate_hook(state)
+        return state
+
+    def __reduce_ex__(self, *args, **kwargs):
+        if IterDataPipe.reduce_ex_hook is not None:
+            try:
+                return IterDataPipe.reduce_ex_hook(self)
+            except NotImplementedError:
+                pass
+        return super().__reduce_ex__(*args, **kwargs)
+
+    @classmethod
+    def set_getstate_hook(cls, hook_fn):
+        if IterDataPipe.getstate_hook is not None and hook_fn is not None:
+            raise RuntimeError("Attempt to override existing getstate_hook")
+        IterDataPipe.getstate_hook = hook_fn
+
+    @classmethod
+    def set_reduce_ex_hook(cls, hook_fn):
+        if IterDataPipe.reduce_ex_hook is not None and hook_fn is not None:
+            raise RuntimeError("Attempt to override existing reduce_ex_hook")
+        IterDataPipe.reduce_ex_hook = hook_fn
+
+    def __repr__(self):
+        if self.repr_hook is not None:
+            return self.repr_hook(self)
+        # Instead of showing the default `<... object at 0x...>` repr, return the class name
+        return str(self.__class__.__qualname__)
+
+    def __str__(self):
+        if self.str_hook is not None:
+            return self.str_hook(self)
+        # Instead of showing the default `<... object at 0x...>` repr, return the class name
+        return str(self.__class__.__qualname__)
+
+    def __dir__(self):
+        # for auto-completion in a REPL (e.g. Jupyter notebook)
+        return list(super().__dir__()) + list(self.functions.keys())
+
+    def reset(self) -> None:
+        r"""
+        Reset the `IterDataPipe` to the initial state.
+
+        By default, no-op. For subclasses of `IterDataPipe`, depending on their functionalities,
+        they may want to override this method with implementations that
+        may clear the buffers and reset pointers of the DataPipe.
+        The `reset` method is always called when `__iter__` is called as part of `hook_iterator`.
+        """
+
+
+class DFIterDataPipe(IterDataPipe):
+    def _is_dfpipe(self):
+        return True
+
+
+class MapDataPipe(Dataset[_T_co], metaclass=_DataPipeMeta):
+    r"""
+    Map-style DataPipe.
+
+    All datasets that represent a map from keys to data samples should subclass this.
+    Subclasses should overwrite :meth:`__getitem__`, supporting fetching a
+    data sample for a given, unique key. Subclasses can also optionally overwrite
+    :meth:`__len__`, which is expected to return the size of the dataset by many
+    :class:`~torch.utils.data.Sampler` implementations and the default options
+    of :class:`~torch.utils.data.DataLoader`.
+
+    These DataPipes can be invoked in two ways, using the class constructor or applying their
+    functional form onto an existing `MapDataPipe` (recommended, available to most but not all DataPipes).
+
+    Note:
+        :class:`~torch.utils.data.DataLoader` by default constructs an index
+        sampler that yields integral indices. To make it work with a map-style
+        DataPipe with non-integral indices/keys, a custom sampler must be provided.
+
+    Example:
+        >>> # xdoctest: +SKIP
+        >>> from torchdata.datapipes.map import SequenceWrapper, Mapper
+        >>> dp = SequenceWrapper(range(10))
+        >>> map_dp_1 = dp.map(lambda x: x + 1)  # Using functional form (recommended)
+        >>> list(map_dp_1)
+        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+        >>> map_dp_2 = Mapper(dp, lambda x: x + 1)  # Using class constructor
+        >>> list(map_dp_2)
+        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+        >>> batch_dp = map_dp_1.batch(batch_size=2)
+        >>> list(batch_dp)
+        [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
+    """
+
+    functions: dict[str, Callable] = {}
+    reduce_ex_hook: Optional[Callable] = None
+    getstate_hook: Optional[Callable] = None
+    str_hook: Optional[Callable] = None
+    repr_hook: Optional[Callable] = None
+
+    def __getattr__(self, attribute_name):
+        if attribute_name in MapDataPipe.functions:
+            if attribute_name in _map_deprecated_functional_names:
+                kwargs = _map_deprecated_functional_names[attribute_name]
+                _deprecation_warning(**kwargs)
+            f = MapDataPipe.functions[attribute_name]
+            function = functools.partial(f, self)
+            functools.update_wrapper(wrapper=function, wrapped=f, assigned=("__doc__",))
+            return function
+        else:
+            raise AttributeError(
+                f"'{self.__class__.__name__}' object has no attribute '{attribute_name}"
+            )
+
+    @classmethod
+    def register_function(cls, function_name, function):
+        cls.functions[function_name] = function
+
+    @classmethod
+    def register_datapipe_as_function(cls, function_name, cls_to_register):
+        if function_name in cls.functions:
+            raise Exception(  # noqa: TRY002
+                f"Unable to add DataPipe function name {function_name} as it is already taken"
+            )
+
+        def class_function(cls, source_dp, *args, **kwargs):
+            result_pipe = cls(source_dp, *args, **kwargs)
+            return result_pipe
+
+        function = functools.partial(class_function, cls_to_register)
+        functools.update_wrapper(
+            wrapper=function, wrapped=cls_to_register, assigned=("__doc__",)
+        )
+        cls.functions[function_name] = function
+
+    def __getstate__(self):
+        """
+        Serialize `lambda` functions when `dill` is available.
+
+        If this doesn't cover your custom DataPipe's use case, consider writing custom methods for
+        `__getstate__` and `__setstate__`, or use `pickle.dumps` for serialization.
+        """
+        state = self.__dict__
+        if MapDataPipe.getstate_hook is not None:
+            return MapDataPipe.getstate_hook(state)
+        return state
+
+    def __reduce_ex__(self, *args, **kwargs):
+        if MapDataPipe.reduce_ex_hook is not None:
+            try:
+                return MapDataPipe.reduce_ex_hook(self)
+            except NotImplementedError:
+                pass
+        return super().__reduce_ex__(*args, **kwargs)
+
+    @classmethod
+    def set_getstate_hook(cls, hook_fn):
+        if MapDataPipe.getstate_hook is not None and hook_fn is not None:
+            raise RuntimeError("Attempt to override existing getstate_hook")
+        MapDataPipe.getstate_hook = hook_fn
+
+    @classmethod
+    def set_reduce_ex_hook(cls, hook_fn):
+        if MapDataPipe.reduce_ex_hook is not None and hook_fn is not None:
+            raise RuntimeError("Attempt to override existing reduce_ex_hook")
+        MapDataPipe.reduce_ex_hook = hook_fn
+
+    def __repr__(self):
+        if self.repr_hook is not None:
+            return self.repr_hook(self)
+        # Instead of showing the default `<... object at 0x...>` repr, return the class name
+        return str(self.__class__.__qualname__)
+
+    def __str__(self):
+        if self.str_hook is not None:
+            return self.str_hook(self)
+        # Instead of showing the default `<... object at 0x...>` repr, return the class name
+        return str(self.__class__.__qualname__)
+
+    def __dir__(self):
+        # for auto-completion in a REPL (e.g. Jupyter notebook)
+        return list(super().__dir__()) + list(self.functions.keys())
+
+
+class _DataPipeSerializationWrapper:
+    def __init__(self, datapipe):
+        self._datapipe = datapipe
+
+    def __getstate__(self):
+        use_dill = False
+        try:
+            value = pickle.dumps(self._datapipe)
+        except Exception:
+            if HAS_DILL:
+                value = dill.dumps(self._datapipe)
+                use_dill = True
+            else:
+                raise
+        return (value, use_dill)
+
+    def __setstate__(self, state):
+        value, use_dill = state
+        if use_dill:
+            self._datapipe = dill.loads(value)
+        else:
+            self._datapipe = pickle.loads(value)
+
+    def __len__(self):
+        try:
+            return len(self._datapipe)
+        except Exception as e:
+            raise TypeError(
+                f"{type(self).__name__} instance doesn't have valid length"
+            ) from e
+
+
+class _IterDataPipeSerializationWrapper(_DataPipeSerializationWrapper, IterDataPipe):
+    def __init__(self, datapipe: IterDataPipe[_T_co]):
+        super().__init__(datapipe)
+        self._datapipe_iter: Optional[Iterator[_T_co]] = None
+
+    def __iter__(self) -> "_IterDataPipeSerializationWrapper":
+        self._datapipe_iter = iter(self._datapipe)
+        return self
+
+    def __next__(self) -> _T_co:  # type: ignore[type-var]
+        assert self._datapipe_iter is not None
+        return next(self._datapipe_iter)
+
+
+class _MapDataPipeSerializationWrapper(_DataPipeSerializationWrapper, MapDataPipe):
+    def __getitem__(self, idx):
+        return self._datapipe[idx]
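+
+
+# Sketch (editorial; `dp` is a hypothetical datapipe that closes over a lambda): the
+# wrappers above try `pickle` first and fall back to `dill` (when installed) for objects
+# that `pickle` rejects:
+#
+#     wrapped = _IterDataPipeSerializationWrapper(dp)
+#     payload = pickle.dumps(wrapped)     # __getstate__ may internally use dill
+#     restored = pickle.loads(payload)
+#     for item in restored:
+#         ...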
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/datapipe.pyi b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/datapipe.pyi
new file mode 100644
index 0000000000000000000000000000000000000000..23f2402a701c7b2d604a0a0788992f51ea75b4fa
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/datapipe.pyi
@@ -0,0 +1,726 @@
+# @generated by torch/utils/data/datapipes/gen_pyi.py from datapipe.pyi.in
+# mypy: allow-untyped-defs
+# This base template ("datapipe.pyi.in") is generated from mypy stubgen with minimal editing for code injection
+# The output file will be "datapipe.pyi". This is executed as part of torch/CMakeLists.txt
+# Note that, for mypy, .pyi file takes precedent over .py file, such that we must define the interface for other
+# classes/objects here, even though we are not injecting extra code into them at the moment.
+
+from collections.abc import Iterable, Iterator
+from typing import Any, Callable, Literal, Optional, TypeVar, Union
+
+from torch.utils.data import Dataset, default_collate, IterableDataset
+from torch.utils.data.datapipes._hook_iterator import _SnapshotState
+from torch.utils.data.datapipes._typing import _DataPipeMeta, _IterDataPipeMeta
+
+_T = TypeVar("_T")
+_T_co = TypeVar("_T_co", covariant=True)
+UNTRACABLE_DATAFRAME_PIPES: Any
+
+class DataChunk(list[_T]):
+    items: list[_T]
+    def __init__(self, items: Iterable[_T]) -> None: ...
+    def as_str(self, indent: str = "") -> str: ...
+    def __iter__(self) -> Iterator[_T]: ...
+    def raw_iterator(self) -> Iterator[_T]: ...
+
+class MapDataPipe(Dataset[_T_co], metaclass=_DataPipeMeta):
+    functions: dict[str, Callable] = ...
+    reduce_ex_hook: Callable | None = ...
+    getstate_hook: Callable | None = ...
+    str_hook: Callable | None = ...
+    repr_hook: Callable | None = ...
+    def __getattr__(self, attribute_name: Any): ...
+    @classmethod
+    def register_function(cls, function_name: Any, function: Any) -> None: ...
+    @classmethod
+    def register_datapipe_as_function(
+        cls,
+        function_name: Any,
+        cls_to_register: Any,
+    ): ...
+    def __getstate__(self): ...
+    def __reduce_ex__(self, *args: Any, **kwargs: Any): ...
+    @classmethod
+    def set_getstate_hook(cls, hook_fn: Any) -> None: ...
+    @classmethod
+    def set_reduce_ex_hook(cls, hook_fn: Any) -> None: ...
+    # Functional form of 'BatcherMapDataPipe'
+    def batch(
+        self,
+        batch_size: int,
+        drop_last: bool = False,
+        wrapper_class: type[DataChunk] = DataChunk,
+    ) -> MapDataPipe:
+        r"""
+        Create mini-batches of data (functional name: ``batch``).
+
+        An outer dimension will be added as ``batch_size`` if ``drop_last`` is set to ``True``,
+        or ``length % batch_size`` for the last batch if ``drop_last`` is set to ``False``.
+
+        Args:
+            datapipe: Iterable DataPipe being batched
+            batch_size: The size of each batch
+            drop_last: Option to drop the last batch if it's not full
+
+        Example:
+            >>> # xdoctest: +SKIP
+            >>> from torchdata.datapipes.map import SequenceWrapper
+            >>> dp = SequenceWrapper(range(10))
+            >>> batch_dp = dp.batch(batch_size=2)
+            >>> list(batch_dp)
+            [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
+        """
+    # Functional form of 'ConcaterMapDataPipe'
+    def concat(self, *datapipes: MapDataPipe) -> MapDataPipe:
+        r"""
+        Concatenate multiple Map DataPipes (functional name: ``concat``).
+
+        The indices of the new DataPipe run across the cumulative lengths of the source DataPipes.
+        For example, if there are 2 source DataPipes both with length 5,
+        index 0 to 4 of the resulting `ConcatMapDataPipe` would refer to
+        elements of the first DataPipe, and 5 to 9 would refer to elements
+        of the second DataPipe.
+
+        Args:
+            datapipes: Map DataPipes being concatenated
+
+        Example:
+            >>> # xdoctest: +SKIP
+            >>> from torchdata.datapipes.map import SequenceWrapper
+            >>> dp1 = SequenceWrapper(range(3))
+            >>> dp2 = SequenceWrapper(range(3))
+            >>> concat_dp = dp1.concat(dp2)
+            >>> list(concat_dp)
+            [0, 1, 2, 0, 1, 2]
+        """
+    # Functional form of 'MapperMapDataPipe'
+    def map(self, fn: Callable = ...) -> MapDataPipe:
+        r"""
+        Apply the input function over each item from the source DataPipe (functional name: ``map``).
+
+        The function can be any regular Python function or partial object. Lambda
+        function is not recommended as it is not supported by pickle.
+
+        Args:
+            datapipe: Source MapDataPipe
+            fn: Function being applied to each item
+
+        Example:
+            >>> # xdoctest: +SKIP
+            >>> from torchdata.datapipes.map import SequenceWrapper, Mapper
+            >>> def add_one(x):
+            ...     return x + 1
+            >>> dp = SequenceWrapper(range(10))
+            >>> map_dp_1 = dp.map(add_one)
+            >>> list(map_dp_1)
+            [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+            >>> map_dp_2 = Mapper(dp, lambda x: x + 1)
+            >>> list(map_dp_2)
+            [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+        """
+    # Functional form of 'ShufflerIterDataPipe'
+    def shuffle(self, *, indices: Optional[list] = None) -> IterDataPipe:
+        r"""
+        Shuffle the input MapDataPipe via its indices (functional name: ``shuffle``).
+
+        When it is used with :class:`~torch.utils.data.DataLoader`, the methods to
+        set up random seed are different based on :attr:`num_workers`.
+
+        For single-process mode (:attr:`num_workers == 0`), the random seed is set before
+        the :class:`~torch.utils.data.DataLoader` in the main process. For multi-process
+        mode (:attr:`num_worker > 0`), ``worker_init_fn`` is used to set up a random seed
+        for each worker process.
+
+        Args:
+            datapipe: MapDataPipe being shuffled
+            indices: a list of indices of the MapDataPipe. If not provided, we assume it uses 0-based indexing
+
+        Example:
+            >>> # xdoctest: +SKIP
+            >>> from torchdata.datapipes.map import SequenceWrapper
+            >>> dp = SequenceWrapper(range(10))
+            >>> shuffle_dp = dp.shuffle().set_seed(0)
+            >>> list(shuffle_dp)
+            [7, 8, 1, 5, 3, 4, 2, 0, 9, 6]
+            >>> list(shuffle_dp)
+            [6, 1, 9, 5, 2, 4, 7, 3, 8, 0]
+            >>> # Reset seed for Shuffler
+            >>> shuffle_dp = shuffle_dp.set_seed(0)
+            >>> list(shuffle_dp)
+            [7, 8, 1, 5, 3, 4, 2, 0, 9, 6]
+
+        Note:
+            Even though this ``shuffle`` operation takes a ``MapDataPipe`` as the input, it returns an
+            ``IterDataPipe`` rather than a ``MapDataPipe``, because a ``MapDataPipe`` should be insensitive to
+            the order of data for the sake of random reads, whereas an ``IterDataPipe`` depends on the order
+            of data during data processing.
+        """
+    # Functional form of 'ZipperMapDataPipe'
+    def zip(self, *datapipes: MapDataPipe[_T_co]) -> MapDataPipe:
+        r"""
+        Aggregates elements into a tuple from each of the input DataPipes (functional name: ``zip``).
+
+        This DataPipe stops yielding as soon as the shortest input DataPipe is exhausted.
+
+        Args:
+            *datapipes: Map DataPipes being aggregated
+
+        Example:
+            >>> # xdoctest: +SKIP
+            >>> from torchdata.datapipes.map import SequenceWrapper
+            >>> dp1 = SequenceWrapper(range(3))
+            >>> dp2 = SequenceWrapper(range(10, 13))
+            >>> zip_dp = dp1.zip(dp2)
+            >>> list(zip_dp)
+            [(0, 10), (1, 11), (2, 12)]
+        """
+
+class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta):
+    functions: dict[str, Callable] = ...
+    reduce_ex_hook: Optional[Callable] = ...
+    getstate_hook: Optional[Callable] = ...
+    str_hook: Optional[Callable] = ...
+    repr_hook: Optional[Callable] = ...
+    _number_of_samples_yielded: int = ...
+    _snapshot_state: _SnapshotState = _SnapshotState.Iterating  # noqa: PYI015
+    _fast_forward_iterator: Optional[Iterator] = ...
+    def __getattr__(self, attribute_name: Any): ...
+    @classmethod
+    def register_function(cls, function_name: Any, function: Any) -> None: ...
+    @classmethod
+    def register_datapipe_as_function(
+        cls,
+        function_name: Any,
+        cls_to_register: Any,
+        enable_df_api_tracing: bool = ...,
+    ): ...
+    def __getstate__(self): ...
+    def __reduce_ex__(self, *args: Any, **kwargs: Any): ...
+    @classmethod
+    def set_getstate_hook(cls, hook_fn: Any) -> None: ...
+    @classmethod
+    def set_reduce_ex_hook(cls, hook_fn: Any) -> None: ...
+    # Functional form of 'BatcherIterDataPipe'
+    def batch(
+        self,
+        batch_size: int,
+        drop_last: bool = False,
+        wrapper_class: type[DataChunk] = DataChunk,
+    ) -> IterDataPipe:
+        r"""
+        Creates mini-batches of data (functional name: ``batch``).
+
+        An outer dimension will be added as ``batch_size`` if ``drop_last`` is set to ``True``, or ``length % batch_size`` for the
+        last batch if ``drop_last`` is set to ``False``.
+
+        Args:
+            datapipe: Iterable DataPipe being batched
+            batch_size: The size of each batch
+            drop_last: Option to drop the last batch if it's not full
+            wrapper_class: wrapper to apply onto each batch (type ``List``) before yielding,
+                defaults to ``DataChunk``
+
+        Example:
+            >>> # xdoctest: +SKIP
+            >>> from torchdata.datapipes.iter import IterableWrapper
+            >>> dp = IterableWrapper(range(10))
+            >>> dp = dp.batch(batch_size=3, drop_last=True)
+            >>> list(dp)
+            [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
+        """
+    # Functional form of 'CollatorIterDataPipe'
+    def collate(
+        self,
+        conversion: Union[Callable[..., Any], dict[Union[str, Any], Union[Callable, Any]], None] = default_collate,
+        collate_fn: Optional[Callable] = None,
+    ) -> IterDataPipe:  # fmt: skip
+        r"""
+        Collates samples from DataPipe to Tensor(s) by a custom collate function (functional name: ``collate``).
+
+        By default, it uses :func:`torch.utils.data.default_collate`.
+
+        .. note::
+            While writing a custom collate function, you can import :func:`torch.utils.data.default_collate` for the
+            default behavior and `functools.partial` to specify any additional arguments.
+
+        Args:
+            datapipe: Iterable DataPipe being collated
+            collate_fn: Customized collate function to collect and combine data or a batch of data.
+                Default function collates to Tensor(s) based on data type.
+
+        Example:
+            >>> # xdoctest: +SKIP
+            >>> # Convert integer data to float Tensor
+            >>> class MyIterDataPipe(torch.utils.data.IterDataPipe):
+            ...     def __init__(self, start, end):
+            ...         super().__init__()
+            ...         assert end > start, "this example code only works with end > start"
+            ...         self.start = start
+            ...         self.end = end
+            ...
+            ...     def __iter__(self):
+            ...         return iter(range(self.start, self.end))
+            ...
+            ...     def __len__(self):
+            ...         return self.end - self.start
+            ...
+            >>> ds = MyIterDataPipe(start=3, end=7)
+            >>> print(list(ds))
+            [3, 4, 5, 6]
+            >>> def collate_fn(batch):
+            ...     return torch.tensor(batch, dtype=torch.float)
+            ...
+            >>> collated_ds = CollateIterDataPipe(ds, collate_fn=collate_fn)
+            >>> print(list(collated_ds))
+            [tensor(3.), tensor(4.), tensor(5.), tensor(6.)]
+        """
+    # Functional form of 'ConcaterIterDataPipe'
+    def concat(self, *datapipes: IterDataPipe) -> IterDataPipe:
+        r"""
+        Concatenates multiple Iterable DataPipes (functional name: ``concat``).
+
+        The resulting DataPipe will yield all the elements from the first input DataPipe, before yielding from the subsequent ones.
+
+        Args:
+            datapipes: Iterable DataPipes being concatenated
+
+        Example:
+            >>> # xdoctest: +REQUIRES(module:torchdata)
+            >>> import random
+            >>> from torchdata.datapipes.iter import IterableWrapper
+            >>> dp1 = IterableWrapper(range(3))
+            >>> dp2 = IterableWrapper(range(5))
+            >>> list(dp1.concat(dp2))
+            [0, 1, 2, 0, 1, 2, 3, 4]
+        """
+    # Functional form of 'DemultiplexerIterDataPipe'
+    def demux(
+        self,
+        num_instances: int,
+        classifier_fn: Callable[[_T_co], Optional[int]],
+        drop_none: bool = False,
+        buffer_size: int = 1000,
+    ) -> list[IterDataPipe]:
+        r"""
+        Splits the input DataPipe into multiple child DataPipes, using the given classification function (functional name: ``demux``).
+
+        A list of the child DataPipes is returned from this operation.
+
+        Args:
+            datapipe: Iterable DataPipe being filtered
+            num_instances: number of instances of the DataPipe to create
+            classifier_fn: a function that maps values to an integer within the range ``[0, num_instances - 1]`` or ``None``
+            drop_none: defaults to ``False``, if ``True``, the function will skip over elements classified as ``None``
+            buffer_size: this defines the maximum number of inputs that the buffer can hold across all child
+                DataPipes while waiting for their values to be yielded.
+                Defaults to ``1000``. Use ``-1`` for the unlimited buffer.
+
+        Examples:
+            >>> # xdoctest: +REQUIRES(module:torchdata)
+            >>> from torchdata.datapipes.iter import IterableWrapper
+            >>> def odd_or_even(n):
+            ...     return n % 2
+            >>> source_dp = IterableWrapper(range(5))
+            >>> dp1, dp2 = source_dp.demux(num_instances=2, classifier_fn=odd_or_even)
+            >>> list(dp1)
+            [0, 2, 4]
+            >>> list(dp2)
+            [1, 3]
+            >>> # It can also filter out any element that gets `None` from the `classifier_fn`
+            >>> def odd_or_even_no_zero(n):
+            ...     return n % 2 if n != 0 else None
+            >>> dp1, dp2 = source_dp.demux(num_instances=2, classifier_fn=odd_or_even_no_zero, drop_none=True)
+            >>> list(dp1)
+            [2, 4]
+            >>> list(dp2)
+            [1, 3]
+        """
+    # Functional form of 'FilterIterDataPipe'
+    def filter(self, filter_fn: Callable, input_col=None) -> IterDataPipe:
+        r"""
+        Filters out elements from the source datapipe according to input ``filter_fn`` (functional name: ``filter``).
+
+        Args:
+            datapipe: Iterable DataPipe being filtered
+            filter_fn: Customized function mapping an element to a boolean.
+            input_col: Index or indices of data which ``filter_fn`` is applied, such as:
+
+                - ``None`` as default to apply ``filter_fn`` to the data directly.
+                - Integer(s) is used for list/tuple.
+                - Key(s) is used for dict.
+
+        Example:
+            >>> # xdoctest: +SKIP
+            >>> from torchdata.datapipes.iter import IterableWrapper
+            >>> def is_even(n):
+            ...     return n % 2 == 0
+            >>> dp = IterableWrapper(range(5))
+            >>> filter_dp = dp.filter(filter_fn=is_even)
+            >>> list(filter_dp)
+            [0, 2, 4]
+        """
+    # Functional form of 'ForkerIterDataPipe'
+    def fork(
+        self,
+        num_instances: int,
+        buffer_size: int = 1000,
+        copy: Optional[Literal["shallow", "deep"]] = None,
+    ) -> list[IterDataPipe]:
+        r"""
+        Creates multiple instances of the same Iterable DataPipe (functional name: ``fork``).
+
+        Args:
+            datapipe: Iterable DataPipe being copied
+            num_instances: number of instances of the datapipe to create
+            buffer_size: this restricts how far ahead the leading child DataPipe
+               can read relative to the slowest child DataPipe.
+               Defaults to ``1000``. Use ``-1`` for the unlimited buffer.
+            copy: copy strategy to use for items yielded by each branch. Supported
+                options are ``None`` for no copying, ``"shallow"`` for shallow object
+                copies, and ``"deep"`` for deep object copies. Defaults to ``None``.
+
+        Note:
+            All branches of the forked pipeline return the identical object unless
+            the copy parameter is supplied. If the object is mutable or contains
+            mutable objects, changing them in one branch will affect all others.
+
+        Example:
+            >>> # xdoctest: +REQUIRES(module:torchdata)
+            >>> from torchdata.datapipes.iter import IterableWrapper
+            >>> source_dp = IterableWrapper(range(5))
+            >>> dp1, dp2 = source_dp.fork(num_instances=2)
+            >>> list(dp1)
+            [0, 1, 2, 3, 4]
+            >>> list(dp2)
+            [0, 1, 2, 3, 4]
+        """
+    # Functional form of 'GrouperIterDataPipe'
+    def groupby(
+        self,
+        group_key_fn: Callable[[_T_co], Any],
+        *,
+        keep_key: bool = False,
+        buffer_size: int = 10000,
+        group_size: Optional[int] = None,
+        guaranteed_group_size: Optional[int] = None,
+        drop_remaining: bool = False,
+    ) -> IterDataPipe:
+        r"""
+        Groups data from IterDataPipe by keys from ``group_key_fn``, yielding a ``DataChunk`` with batch size up to ``group_size``.
+
+        (functional name: ``groupby``).
+
+        The samples are read sequentially from the source ``datapipe``, and a batch of samples belonging to the same group
+        will be yielded as soon as the size of the batch reaches ``group_size``. When the buffer is full,
+        the DataPipe will yield the largest batch with the same key, provided that its size is larger
+        than ``guaranteed_group_size``. If its size is smaller, it will be dropped if ``drop_remaining=True``.
+
+        After iterating through the entirety of source ``datapipe``, everything not dropped due to the buffer capacity
+        will be yielded from the buffer, even if the group sizes are smaller than ``guaranteed_group_size``.
+
+        Args:
+            datapipe: Iterable datapipe to be grouped
+            group_key_fn: Function used to generate group key from the data of the source datapipe
+            keep_key: Option to yield the matching key along with the items in a tuple,
+                resulting in `(key, [items])`; otherwise, returning `[items]`
+            buffer_size: The size of buffer for ungrouped data
+            group_size: The max size of each group, a batch is yielded as soon as it reaches this size
+            guaranteed_group_size: The guaranteed minimum group size to be yielded in case the buffer is full
+            drop_remaining: Specifies if the group smaller than ``guaranteed_group_size`` will be dropped from buffer
+                when the buffer is full
+
+        Example:
+            >>> import os
+            >>> # xdoctest: +SKIP
+            >>> from torchdata.datapipes.iter import IterableWrapper
+            >>> def group_fn(file):
+            ...     return os.path.basename(file).split(".")[0]
+            >>> source_dp = IterableWrapper(["a.png", "b.png", "a.json", "b.json", "a.jpg", "c.json"])
+            >>> dp0 = source_dp.groupby(group_key_fn=group_fn)
+            >>> list(dp0)
+            [['a.png', 'a.json', 'a.jpg'], ['b.png', 'b.json'], ['c.json']]
+            >>> # A group is yielded as soon as its size equals to `group_size`
+            >>> dp1 = source_dp.groupby(group_key_fn=group_fn, group_size=2)
+            >>> list(dp1)
+            [['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']]
+            >>> # Scenario where `buffer` is full, and group 'a' needs to be yielded since its size > `guaranteed_group_size`
+            >>> dp2 = source_dp.groupby(group_key_fn=group_fn, buffer_size=3, group_size=3, guaranteed_group_size=2)
+            >>> list(dp2)
+            [['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']]
+        """
+    # Functional form of 'FileListerIterDataPipe'
+    def list_files(
+        self,
+        masks: Union[str, list[str]] = "",
+        *,
+        recursive: bool = False,
+        abspath: bool = False,
+        non_deterministic: bool = False,
+        length: int = -1,
+    ) -> IterDataPipe:
+        r"""
+        Given path(s) to the root directory, yields file pathname(s) (path + filename) of files within the root directory.
+
+        Multiple root directories can be provided (functional name: ``list_files``).
+
+        Args:
+            root: Root directory or a sequence of root directories
+            masks: Unix style filter string or string list for filtering file name(s)
+            recursive: Whether to return pathname from nested directories or not
+            abspath: Whether to return relative pathname or absolute pathname
+            non_deterministic: Whether to return pathname in sorted order or not.
+                If ``False``, the results yielded from each root directory will be sorted
+            length: Nominal length of the datapipe
+
+        Example:
+            >>> # xdoctest: +SKIP
+            >>> from torchdata.datapipes.iter import FileLister
+            >>> dp = FileLister(root=".", recursive=True)
+            >>> list(dp)
+            ['example.py', './data/data.tar']
+        """
+    # Functional form of 'MapperIterDataPipe'
+    def map(
+        self,
+        fn: Callable,
+        input_col=None,
+        output_col=None,
+    ) -> IterDataPipe:
+        r"""
+        Applies a function over each item from the source DataPipe (functional name: ``map``).
+
+        The function can be any regular Python function or partial object. Lambda
+        function is not recommended as it is not supported by pickle.
+
+        Args:
+            datapipe: Source Iterable DataPipe
+            fn: Function being applied over each item
+            input_col: Index or indices of data which ``fn`` is applied, such as:
+
+                - ``None`` as default to apply ``fn`` to the data directly.
+                - Integer(s) is used for list/tuple.
+                - Key(s) is used for dict.
+
+            output_col: Index of data where result of ``fn`` is placed. ``output_col`` can be specified
+                only when ``input_col`` is not ``None``
+
+                - ``None`` as default to replace the index that ``input_col`` specified; For ``input_col`` with
+                  multiple indices, the left-most one is used, and other indices will be removed.
+                - Integer is used for list/tuple. ``-1`` represents to append result at the end.
+                - Key is used for dict. New key is acceptable.
+
+        Example:
+            >>> # xdoctest: +SKIP
+            >>> from torchdata.datapipes.iter import IterableWrapper, Mapper
+            >>> def add_one(x):
+            ...     return x + 1
+            >>> dp = IterableWrapper(range(10))
+            >>> map_dp_1 = dp.map(add_one)  # Invocation via functional form is preferred
+            >>> list(map_dp_1)
+            [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+            >>> # We discourage the usage of `lambda` functions as they are not serializable with `pickle`
+            >>> # Use `functools.partial` or explicitly define the function instead
+            >>> map_dp_2 = Mapper(dp, lambda x: x + 1)
+            >>> list(map_dp_2)
+            [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+        """
+    # Functional form of 'MultiplexerIterDataPipe'
+    def mux(self, *datapipes) -> IterDataPipe:
+        r"""
+        Yields one element at a time from each of the input Iterable DataPipes (functional name: ``mux``).
+
+        As in, one element from the 1st input DataPipe, then one element from the 2nd DataPipe in the next iteration,
+        and so on. It ends when the shortest input DataPipe is exhausted.
+
+        Args:
+            datapipes: Iterable DataPipes that will take turn to yield their elements, until the shortest DataPipe is exhausted
+
+        Example:
+            >>> # xdoctest: +REQUIRES(module:torchdata)
+            >>> from torchdata.datapipes.iter import IterableWrapper
+            >>> dp1, dp2, dp3 = IterableWrapper(range(3)), IterableWrapper(range(10, 15)), IterableWrapper(range(20, 25))
+            >>> list(dp1.mux(dp2, dp3))
+            [0, 10, 20, 1, 11, 21, 2, 12, 22]
+        """
+    # Functional form of 'FileOpenerIterDataPipe'
+    def open_files(
+        self,
+        mode: str = "r",
+        encoding: Optional[str] = None,
+        length: int = -1,
+    ) -> IterDataPipe:
+        r"""
+        Given pathnames, opens files and yields pathname and file stream in a tuple (functional name: ``open_files``).
+
+        Args:
+            datapipe: Iterable datapipe that provides pathnames
+            mode: An optional string that specifies the mode in which
+                the file is opened by ``open()``. It defaults to ``r``; other options are
+                ``b`` for reading in binary mode and ``t`` for text mode.
+            encoding: An optional string that specifies the encoding of the
+                underlying file. It defaults to ``None`` to match the default encoding of ``open``.
+            length: Nominal length of the datapipe
+
+        Note:
+            The opened file handles will be closed by Python's GC periodically. Users can choose
+            to close them explicitly.
+
+        Example:
+            >>> # xdoctest: +SKIP
+            >>> from torchdata.datapipes.iter import FileLister, FileOpener, StreamReader
+            >>> dp = FileLister(root=".").filter(lambda fname: fname.endswith('.txt'))
+            >>> dp = FileOpener(dp)
+            >>> dp = StreamReader(dp)
+            >>> list(dp)
+            [('./abc.txt', 'abc')]
+        """
+    # Functional form of 'StreamReaderIterDataPipe'
+    def read_from_stream(self, chunk: Optional[int] = None) -> IterDataPipe:
+        r"""
+        Given IO streams and their label names, yields bytes with the label name in a tuple.
+
+        (functional name: ``read_from_stream``).
+
+        Args:
+            datapipe: Iterable DataPipe provides label/URL and byte stream
+            chunk: Number of bytes to be read from stream per iteration.
+                If ``None``, all bytes will be read until the EOF.
+
+        Example:
+            >>> # xdoctest: +SKIP
+            >>> from torchdata.datapipes.iter import IterableWrapper, StreamReader
+            >>> from io import StringIO
+            >>> dp = IterableWrapper([("alphabet", StringIO("abcde"))])
+            >>> list(StreamReader(dp, chunk=1))
+            [('alphabet', 'a'), ('alphabet', 'b'), ('alphabet', 'c'), ('alphabet', 'd'), ('alphabet', 'e')]
+        """
+    # Functional form of 'RoutedDecoderIterDataPipe'
+    def routed_decode(
+        self,
+        *handlers: Callable,
+        key_fn: Callable = ...,
+    ) -> IterDataPipe:
+        r"""
+        Decodes binary streams from input DataPipe, yields pathname and decoded data in a tuple.
+
+        (functional name: ``routed_decode``)
+
+        Args:
+            datapipe: Iterable datapipe that provides pathname and binary stream in tuples
+            handlers: Optional user defined decoder handlers. If ``None``, basic and image decoder
+                handlers will be set as default. If multiple handlers are provided, the priority
+                order follows the order of handlers (the first handler has the highest priority)
+            key_fn: Function for decoder to extract key from pathname to dispatch handlers.
+                Default is set to extract file extension from pathname
+
+        Note:
+            When ``key_fn`` is specified to return anything other than the file extension, the default
+            handlers will not work and users need to specify a custom handler. A custom handler
+            could use a regex to determine whether it is eligible to handle the data.
+        """
+    # Functional form of 'ShardingFilterIterDataPipe'
+    def sharding_filter(self, sharding_group_filter=None) -> IterDataPipe:
+        r"""
+        Wrapper that allows DataPipe to be sharded (functional name: ``sharding_filter``).
+
+        After ``apply_sharding`` is called, each instance of the DataPipe (on different workers) will have every `n`-th element of the
+        original DataPipe, where `n` equals the number of instances.
+
+        Args:
+            source_datapipe: Iterable DataPipe that will be sharded
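+
+        Example:
+            >>> # xdoctest: +SKIP
+            >>> # Illustrative sketch: with 3 shard instances, instance 1 keeps every 3rd
+            >>> # element starting at index 1 once ``apply_sharding`` has been called.
+            >>> from torch.utils.data.datapipes.iter import IterableWrapper
+            >>> dp = IterableWrapper(range(10)).sharding_filter()
+            >>> dp.apply_sharding(3, 1)
+            >>> list(dp)
+            [1, 4, 7]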
+        """
+    # Functional form of 'ShufflerIterDataPipe'
+    def shuffle(
+        self,
+        *,
+        buffer_size: int = 10000,
+        unbatch_level: int = 0,
+    ) -> IterDataPipe:
+        r"""
+        Shuffle the input DataPipe with a buffer (functional name: ``shuffle``).
+
+        The buffer with ``buffer_size`` is filled with elements from the datapipe first. Then,
+        each item will be yielded from the buffer by reservoir sampling via iterator.
+
+        ``buffer_size`` is required to be larger than ``0``. For ``buffer_size == 1``, the
+        datapipe is not shuffled. In order to fully shuffle all elements from datapipe,
+        ``buffer_size`` is required to be greater than or equal to the size of datapipe.
+
+        When it is used with :class:`torch.utils.data.DataLoader`, the methods to
+        set up random seed are different based on :attr:`num_workers`.
+
+        For single-process mode (:attr:`num_workers == 0`), the random seed is set before
+        the :class:`~torch.utils.data.DataLoader` in the main process. For multi-process
+        mode (:attr:`num_workers > 0`), `worker_init_fn` is used to set up a random seed
+        for each worker process.
+
+        Args:
+            datapipe: The IterDataPipe being shuffled
+            buffer_size: The buffer size for shuffling (default to ``10000``)
+            unbatch_level: Specifies if it is necessary to unbatch source data before
+                applying the shuffle
+
+        Example:
+            >>> # xdoctest: +SKIP
+            >>> from torchdata.datapipes.iter import IterableWrapper
+            >>> dp = IterableWrapper(range(10))
+            >>> shuffle_dp = dp.shuffle()
+            >>> list(shuffle_dp)
+            [0, 4, 1, 6, 3, 2, 9, 5, 7, 8]
+        """
+    # Functional form of 'UnBatcherIterDataPipe'
+    def unbatch(self, unbatch_level: int = 1) -> IterDataPipe:
+        r"""
+        Undoes batching of data (functional name: ``unbatch``).
+
+        In other words, it flattens the data up to the specified level within a batched DataPipe.
+
+        Args:
+            datapipe: Iterable DataPipe being un-batched
+            unbatch_level: Defaults to ``1`` (only flattening the top level). If set to ``2``,
+                it will flatten the top two levels, and ``-1`` will flatten the entire DataPipe.
+
+        Example:
+            >>> # xdoctest: +SKIP
+            >>> from torchdata.datapipes.iter import IterableWrapper
+            >>> source_dp = IterableWrapper([[[0, 1], [2]], [[3, 4], [5]], [[6]]])
+            >>> dp1 = source_dp.unbatch()
+            >>> list(dp1)
+            [[0, 1], [2], [3, 4], [5], [6]]
+            >>> dp2 = source_dp.unbatch(unbatch_level=2)
+            >>> list(dp2)
+            [0, 1, 2, 3, 4, 5, 6]
+        """
+    # Functional form of 'ZipperIterDataPipe'
+    def zip(self, *datapipes: IterDataPipe) -> IterDataPipe:
+        r"""
+        Aggregates elements into a tuple from each of the input DataPipes (functional name: ``zip``).
+
+        The output is stopped as soon as the shortest input DataPipe is exhausted.
+
+        Args:
+            *datapipes: Iterable DataPipes being aggregated
+
+        Example:
+            >>> # xdoctest: +REQUIRES(module:torchdata)
+            >>> from torchdata.datapipes.iter import IterableWrapper
+            >>> dp1, dp2, dp3 = IterableWrapper(range(5)), IterableWrapper(range(10, 15)), IterableWrapper(range(20, 25))
+            >>> list(dp1.zip(dp2, dp3))
+            [(0, 10, 20), (1, 11, 21), (2, 12, 22), (3, 13, 23), (4, 14, 24)]
+        """
+
+class DFIterDataPipe(IterDataPipe):
+    def _is_dfpipe(self): ...
+    def __iter__(self): ...
+
+class _DataPipeSerializationWrapper:
+    def __init__(self, datapipe): ...
+    def __getstate__(self): ...
+    def __setstate__(self, state): ...
+    def __len__(self): ...
+
+class _IterDataPipeSerializationWrapper(_DataPipeSerializationWrapper, IterDataPipe):
+    def __iter__(self): ...
+
+class _MapDataPipeSerializationWrapper(_DataPipeSerializationWrapper, MapDataPipe):
+    def __getitem__(self, idx): ...
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/gen_pyi.py b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/gen_pyi.py
new file mode 100644
index 0000000000000000000000000000000000000000..bce38547986b68cef78e7b9547bf22298a70214a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/gen_pyi.py
@@ -0,0 +1,336 @@
+# mypy: allow-untyped-defs
+import os
+from collections import defaultdict
+from pathlib import Path
+from typing import Any, Union
+from typing_extensions import deprecated
+
+
+try:
+    from torchgen.api.python import format_function_signature
+    from torchgen.utils import FileManager as FileManager
+except ImportError:
+    import sys
+
+    REPO_ROOT = Path(__file__).absolute().parents[4]
+    sys.path.insert(0, str(REPO_ROOT))
+
+    from torchgen.api.python import format_function_signature
+    from torchgen.utils import FileManager
+
+    if len(sys.path) > 0 and sys.path[0] == str(REPO_ROOT):
+        del sys.path[0]
+
+
+__all__: list[str] = []  # not intended to expose any symbols
+
+
+def __dir__() -> list[str]:
+    return []  # appease public API test
+
+
+@deprecated(
+    "`torch.utils.data.datapipes.gen_pyi.materialize_lines` is deprecated and will be removed in the future.",
+    category=FutureWarning,
+)
+def materialize_lines(lines: list[str], indentation: int) -> str:
+    output = ""
+    new_line_with_indent = "\n" + " " * indentation
+    for i, line in enumerate(lines):
+        if i != 0:
+            output += new_line_with_indent
+        output += line.replace("\n", new_line_with_indent)
+    return output
+
+
+@deprecated(
+    "`torch.utils.data.datapipes.gen_pyi.gen_from_template` is deprecated and will be removed in the future.",
+    category=FutureWarning,
+)
+def gen_from_template(
+    dir: str,
+    template_name: str,
+    output_name: str,
+    replacements: list[tuple[str, Any, int]],
+):
+    template_path = os.path.join(dir, template_name)
+    output_path = os.path.join(dir, output_name)
+
+    with open(template_path, encoding="utf-8") as f:
+        content = f.read()
+    for placeholder, lines, indentation in replacements:
+        with open(output_path, "w", encoding="utf-8") as f:
+            content = content.replace(
+                placeholder, materialize_lines(lines, indentation)
+            )
+            f.write(content)
+
+
+def find_file_paths(dir_paths: list[str], files_to_exclude: set[str]) -> set[str]:
+    """
+    When given a path to a directory, returns the paths to the relevant files within it.
+
+    This function does NOT recursively traverse subdirectories.
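+
+    Example (illustrative; the exact result depends on the directory contents):
+
+        >>> # xdoctest: +SKIP
+        >>> find_file_paths(["iter"], files_to_exclude={"__init__.py", "utils.py"})
+        {'iter/callable.py', 'iter/combining.py', ...}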
+    """
+    paths: set[str] = set()
+    for dir_path in dir_paths:
+        all_files = os.listdir(dir_path)
+        python_files = {fname for fname in all_files if ".py" == fname[-3:]}
+        filter_files = {
+            fname for fname in python_files if fname not in files_to_exclude
+        }
+        paths.update({os.path.join(dir_path, fname) for fname in filter_files})
+    return paths
+
+
+def extract_method_name(line: str) -> str:
+    """Extract method name from decorator in the form of "@functional_datapipe({method_name})"."""
+    if '("' in line:
+        start_token, end_token = '("', '")'
+    elif "('" in line:
+        start_token, end_token = "('", "')"
+    else:
+        raise RuntimeError(
+            f"Unable to find appropriate method name within line:\n{line}"
+        )
+    start, end = line.find(start_token) + len(start_token), line.find(end_token)
+    return line[start:end]
+
+
+def extract_class_name(line: str) -> str:
+    """Extract class name from class definition in the form of "class {CLASS_NAME}({Type}):"."""
+    start_token = "class "
+    end_token = "("
+    start, end = line.find(start_token) + len(start_token), line.find(end_token)
+    return line[start:end]
+
+
+def parse_datapipe_file(
+    file_path: str,
+) -> tuple[dict[str, list[str]], dict[str, str], set[str], dict[str, list[str]]]:
+    """Given a path to file, parses the file and returns a dictionary of method names to function signatures."""
+    method_to_signature, method_to_class_name, special_output_type = {}, {}, set()
+    doc_string_dict = defaultdict(list)
+    with open(file_path, encoding="utf-8") as f:
+        open_paren_count = 0
+        method_name, class_name, signature = "", "", ""
+        skip = False
+        for line in f:
+            if line.count('"""') % 2 == 1:
+                skip = not skip
+            if skip or '"""' in line:  # Saving docstrings
+                doc_string_dict[method_name].append(line)
+                continue
+            if "@functional_datapipe" in line:
+                method_name = extract_method_name(line)
+                doc_string_dict[method_name] = []
+                continue
+            if method_name and "class " in line:
+                class_name = extract_class_name(line)
+                continue
+            if method_name and ("def __init__(" in line or "def __new__(" in line):
+                if "def __new__(" in line:
+                    special_output_type.add(method_name)
+                open_paren_count += 1
+                start = line.find("(") + len("(")
+                line = line[start:]
+            if open_paren_count > 0:
+                open_paren_count += line.count("(")
+                open_paren_count -= line.count(")")
+                if open_paren_count == 0:
+                    end = line.rfind(")")
+                    signature += line[:end]
+                    method_to_signature[method_name] = process_signature(signature)
+                    method_to_class_name[method_name] = class_name
+                    method_name, class_name, signature = "", "", ""
+                elif open_paren_count < 0:
+                    raise RuntimeError(
+                        "open parenthesis count < 0. This shouldn't be possible."
+                    )
+                else:
+                    signature += line.strip()
+    return (
+        method_to_signature,
+        method_to_class_name,
+        special_output_type,
+        doc_string_dict,
+    )
+
+
+def parse_datapipe_files(
+    file_paths: set[str],
+) -> tuple[dict[str, list[str]], dict[str, str], set[str], dict[str, list[str]]]:
+    methods_and_signatures = {}
+    methods_and_class_names = {}
+    methods_with_special_output_types = set()
+    methods_and_doc_strings = {}
+    for path in file_paths:
+        (
+            method_to_signature,
+            method_to_class_name,
+            methods_needing_special_output_types,
+            doc_string_dict,
+        ) = parse_datapipe_file(path)
+        methods_and_signatures.update(method_to_signature)
+        methods_and_class_names.update(method_to_class_name)
+        methods_with_special_output_types.update(methods_needing_special_output_types)
+        methods_and_doc_strings.update(doc_string_dict)
+    return (
+        methods_and_signatures,
+        methods_and_class_names,
+        methods_with_special_output_types,
+        methods_and_doc_strings,
+    )
+
+
+def split_outside_bracket(line: str, delimiter: str = ",") -> list[str]:
+    """Given a line of text, split it on comma unless the comma is within a bracket '[]'."""
+    bracket_count = 0
+    curr_token = ""
+    res = []
+    for char in line:
+        if char == "[":
+            bracket_count += 1
+        elif char == "]":
+            bracket_count -= 1
+        elif char == delimiter and bracket_count == 0:
+            res.append(curr_token)
+            curr_token = ""
+            continue
+        curr_token += char
+    res.append(curr_token)
+    return res
+
+
+def process_signature(line: str) -> list[str]:
+    """
+    Clean up a given raw function signature.
+
+    This includes removing the self-referential datapipe argument, default
+    arguments of input functions, newlines, and spaces.
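+
+    For example, given a raw signature string captured from an ``__init__`` definition
+    (a minimal illustration, assuming a hypothetical ``default_fn`` default):
+
+        >>> process_signature("self, datapipe, fn: Callable = default_fn")
+        ['self', 'fn: Callable = ...']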
+    """
+    tokens: list[str] = split_outside_bracket(line)
+    for i, token in enumerate(tokens):
+        tokens[i] = token.strip(" ")
+        if token == "cls":
+            tokens[i] = "self"
+        elif i > 0 and ("self" == tokens[i - 1]) and (tokens[i][0] != "*"):
+            # Remove the datapipe after 'self' or 'cls' unless it has '*'
+            tokens[i] = ""
+        elif "Callable =" in token:  # Remove default argument if it is a function
+            head = token.rpartition("=")[0]
+            tokens[i] = head.strip(" ") + " = ..."
+    tokens = [t for t in tokens if t != ""]
+    return tokens
+
+
+def get_method_definitions(
+    file_path: Union[str, list[str]],
+    files_to_exclude: set[str],
+    deprecated_files: set[str],
+    default_output_type: str,
+    method_to_special_output_type: dict[str, str],
+    root: str = "",
+) -> list[str]:
+    """
+    # .pyi generation process for functional DataPipes.
+
+    # 1. Find the files we want to process (excluding the ones we don't)
+    # 2. Parse method name and signature
+    # 3. Remove first argument after self (unless it is "*datapipes"), default args, and spaces
+    """
+    if root == "":
+        root = str(Path(__file__).parent.resolve())
+    file_path = [file_path] if isinstance(file_path, str) else file_path
+    file_path = [os.path.join(root, path) for path in file_path]
+    file_paths = find_file_paths(
+        file_path, files_to_exclude=files_to_exclude.union(deprecated_files)
+    )
+    (
+        methods_and_signatures,
+        methods_and_class_names,
+        methods_w_special_output_types,
+        methods_and_doc_strings,
+    ) = parse_datapipe_files(file_paths)
+
+    for fn_name in method_to_special_output_type:
+        if fn_name not in methods_w_special_output_types:
+            methods_w_special_output_types.add(fn_name)
+
+    method_definitions = []
+    for method_name, arguments in methods_and_signatures.items():
+        class_name = methods_and_class_names[method_name]
+        if method_name in methods_w_special_output_types:
+            output_type = method_to_special_output_type[method_name]
+        else:
+            output_type = default_output_type
+        doc_string = "".join(methods_and_doc_strings[method_name])
+        if doc_string == "":
+            doc_string = " ..."
+        else:
+            doc_string = "\n" + doc_string
+        definition = format_function_signature(method_name, arguments, output_type)
+        method_definitions.append(
+            f"# Functional form of '{class_name}'\n"
+            + definition.removesuffix("...").rstrip()  # remove "..."
+            + doc_string,
+        )
+    method_definitions.sort(
+        key=lambda s: s.split("\n")[1]
+    )  # sorting based on method_name
+
+    return method_definitions
+
+
+# Defined outside of main() so they can be imported by TorchData
+iterDP_file_path: str = "iter"
+iterDP_files_to_exclude: set[str] = {"__init__.py", "utils.py"}
+iterDP_deprecated_files: set[str] = set()
+iterDP_method_to_special_output_type: dict[str, str] = {
+    "demux": "list[IterDataPipe]",
+    "fork": "list[IterDataPipe]",
+}
+
+mapDP_file_path: str = "map"
+mapDP_files_to_exclude: set[str] = {"__init__.py", "utils.py"}
+mapDP_deprecated_files: set[str] = set()
+mapDP_method_to_special_output_type: dict[str, str] = {"shuffle": "IterDataPipe"}
+
+
+def main() -> None:
+    """
+    # Inject the generated method definitions into the template ``datapipe.pyi.in``.
+
+    TODO: The current implementation of this script only generates interfaces for built-in methods. To generate
+          interface for user-defined DataPipes, consider changing `IterDataPipe.register_datapipe_as_function`.
+    """
+    iter_method_definitions = get_method_definitions(
+        iterDP_file_path,
+        iterDP_files_to_exclude,
+        iterDP_deprecated_files,
+        "IterDataPipe",
+        iterDP_method_to_special_output_type,
+    )
+
+    map_method_definitions = get_method_definitions(
+        mapDP_file_path,
+        mapDP_files_to_exclude,
+        mapDP_deprecated_files,
+        "MapDataPipe",
+        mapDP_method_to_special_output_type,
+    )
+
+    path = Path(__file__).absolute().parent
+    fm = FileManager(install_dir=path, template_dir=path, dry_run=False)
+    fm.write_with_template(
+        "datapipe.pyi",
+        "datapipe.pyi.in",
+        lambda: {
+            "IterDataPipeMethods": iter_method_definitions,
+            "MapDataPipeMethods": map_method_definitions,
+        },
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__init__.py b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..37d1664753b151f34d6d0461bb597803f3ffd40e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__init__.py
@@ -0,0 +1,65 @@
+from torch.utils.data.datapipes.iter.callable import (
+    CollatorIterDataPipe as Collator,
+    MapperIterDataPipe as Mapper,
+)
+from torch.utils.data.datapipes.iter.combinatorics import (
+    SamplerIterDataPipe as Sampler,
+    ShufflerIterDataPipe as Shuffler,
+)
+from torch.utils.data.datapipes.iter.combining import (
+    ConcaterIterDataPipe as Concater,
+    DemultiplexerIterDataPipe as Demultiplexer,
+    ForkerIterDataPipe as Forker,
+    MultiplexerIterDataPipe as Multiplexer,
+    ZipperIterDataPipe as Zipper,
+)
+from torch.utils.data.datapipes.iter.filelister import (
+    FileListerIterDataPipe as FileLister,
+)
+from torch.utils.data.datapipes.iter.fileopener import (
+    FileOpenerIterDataPipe as FileOpener,
+)
+from torch.utils.data.datapipes.iter.grouping import (
+    BatcherIterDataPipe as Batcher,
+    GrouperIterDataPipe as Grouper,
+    UnBatcherIterDataPipe as UnBatcher,
+)
+from torch.utils.data.datapipes.iter.routeddecoder import (
+    RoutedDecoderIterDataPipe as RoutedDecoder,
+)
+from torch.utils.data.datapipes.iter.selecting import FilterIterDataPipe as Filter
+from torch.utils.data.datapipes.iter.sharding import (
+    ShardingFilterIterDataPipe as ShardingFilter,
+)
+from torch.utils.data.datapipes.iter.streamreader import (
+    StreamReaderIterDataPipe as StreamReader,
+)
+from torch.utils.data.datapipes.iter.utils import (
+    IterableWrapperIterDataPipe as IterableWrapper,
+)
+
+
+__all__ = [
+    "Batcher",
+    "Collator",
+    "Concater",
+    "Demultiplexer",
+    "FileLister",
+    "FileOpener",
+    "Filter",
+    "Forker",
+    "Grouper",
+    "IterableWrapper",
+    "Mapper",
+    "Multiplexer",
+    "RoutedDecoder",
+    "Sampler",
+    "ShardingFilter",
+    "Shuffler",
+    "StreamReader",
+    "UnBatcher",
+    "Zipper",
+]
+
+# Please keep this list sorted
+assert __all__ == sorted(__all__)
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ac5fc86733d069a23dfd4640a139cc8c2b80159d
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/__init__.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/callable.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/callable.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b31fae20bef54b391b01341217745b9b097c761e
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/callable.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/combinatorics.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/combinatorics.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ca0e057d2272450b8027252a62819a04095fa5ae
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/combinatorics.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/combining.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/combining.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cecc7147f8770d79293a23ccc9b68664e6e87feb
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/combining.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/filelister.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/filelister.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8e0e2accae6e49f8bdbcb80f66b53acc7339296d
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/filelister.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/fileopener.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/fileopener.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a8328ccf63cbb4675a673069b0ad89dd97f23b60
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/fileopener.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/grouping.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/grouping.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ef21fa8feead0cca74f5a97e77653183d2937698
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/grouping.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/routeddecoder.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/routeddecoder.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9dcbc9ee10267d6c447af553bfb699a120637570
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/routeddecoder.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/selecting.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/selecting.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bc3d4ad99baef2f5bfaa38c932f91007a22a6f95
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/selecting.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/sharding.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/sharding.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..26457d65af185d2a9d7454acb7c0ed8d6fd763b6
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/sharding.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/streamreader.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/streamreader.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7ace919b602f283790680cdecaf810eb3ee9e2a2
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/streamreader.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/utils.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0e31ffd9046d01205676602b89b948968b936b76
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/__pycache__/utils.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/callable.py b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/callable.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0e4191fd20a2d0e67903f92517daeb5c56f568b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/callable.py
@@ -0,0 +1,242 @@
+# mypy: allow-untyped-defs
+import functools
+from collections import namedtuple
+from collections.abc import Iterator, Sized
+from typing import Any, Callable, Optional, TypeVar, Union
+
+from torch.utils.data._utils.collate import default_collate
+from torch.utils.data.datapipes._decorator import functional_datapipe
+from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper
+from torch.utils.data.datapipes.datapipe import IterDataPipe
+from torch.utils.data.datapipes.utils.common import (
+    _check_unpickable_fn,
+    validate_input_col,
+)
+
+
+__all__ = [
+    "CollatorIterDataPipe",
+    "MapperIterDataPipe",
+]
+
+
+_T_co = TypeVar("_T_co", covariant=True)
+
+
+@functional_datapipe("map")
+class MapperIterDataPipe(IterDataPipe[_T_co]):
+    r"""
+    Applies a function over each item from the source DataPipe (functional name: ``map``).
+
+    The function can be any regular Python function or partial object. Lambda
+    functions are not recommended as they are not supported by pickle.
+
+    Args:
+        datapipe: Source Iterable DataPipe
+        fn: Function being applied over each item
+        input_col: Index or indices of data to which ``fn`` is applied, such as:
+
+            - ``None`` as default to apply ``fn`` to the data directly.
+            - Integer(s) is used for list/tuple.
+            - Key(s) is used for dict.
+
+        output_col: Index of data where result of ``fn`` is placed. ``output_col`` can be specified
+            only when ``input_col`` is not ``None``
+
+            - ``None`` as default to replace the index that ``input_col`` specified; For ``input_col`` with
+              multiple indices, the left-most one is used, and other indices will be removed.
+            - Integer is used for list/tuple. ``-1`` represents to append result at the end.
+            - Key is used for dict. New key is acceptable.
+
+    Example:
+        >>> # xdoctest: +SKIP
+        >>> from torchdata.datapipes.iter import IterableWrapper, Mapper
+        >>> def add_one(x):
+        ...     return x + 1
+        >>> dp = IterableWrapper(range(10))
+        >>> map_dp_1 = dp.map(add_one)  # Invocation via functional form is preferred
+        >>> list(map_dp_1)
+        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+        >>> # We discourage the usage of `lambda` functions as they are not serializable with `pickle`
+        >>> # Use `functools.partial` or explicitly define the function instead
+        >>> map_dp_2 = Mapper(dp, lambda x: x + 1)
+        >>> list(map_dp_2)
+        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+    """
+
+    datapipe: IterDataPipe
+    fn: Callable
+
+    def __init__(
+        self,
+        datapipe: IterDataPipe,
+        fn: Callable,
+        input_col=None,
+        output_col=None,
+    ) -> None:
+        super().__init__()
+        self.datapipe = datapipe
+
+        _check_unpickable_fn(fn)
+        self.fn = fn  # type: ignore[assignment]
+
+        self.input_col = input_col
+        if input_col is None and output_col is not None:
+            raise ValueError("`output_col` must be None when `input_col` is None.")
+        if isinstance(output_col, (list, tuple)):
+            if len(output_col) > 1:
+                raise ValueError("`output_col` must be a single-element list or tuple")
+            output_col = output_col[0]
+        self.output_col = output_col
+        validate_input_col(fn, input_col)
+
+    def _apply_fn(self, data):
+        if self.input_col is None and self.output_col is None:
+            return self.fn(data)
+
+        if self.input_col is None:
+            res = self.fn(data)
+        elif isinstance(self.input_col, (list, tuple)):
+            args = tuple(data[col] for col in self.input_col)
+            res = self.fn(*args)
+        else:
+            res = self.fn(data[self.input_col])
+
+        # Copy tuple to list and run in-place modification because tuple is immutable.
+        if isinstance(data, tuple):
+            t_flag = True
+            data = list(data)
+        else:
+            t_flag = False
+
+        if self.output_col is None:
+            if isinstance(self.input_col, (list, tuple)):
+                data[self.input_col[0]] = res
+                for idx in sorted(self.input_col[1:], reverse=True):
+                    del data[idx]
+            else:
+                data[self.input_col] = res
+        else:
+            if self.output_col == -1:
+                data.append(res)
+            else:
+                data[self.output_col] = res
+
+        # Convert list back to tuple
+        return tuple(data) if t_flag else data
+
+    def __iter__(self) -> Iterator[_T_co]:
+        for data in self.datapipe:
+            yield self._apply_fn(data)
+
+    def __len__(self) -> int:
+        if isinstance(self.datapipe, Sized):
+            return len(self.datapipe)
+        raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
+
+
+def _collate_helper(conversion, item):
+    # TODO(VitalyFedyunin): Verify that item is any sort of batch
+    if len(item.items) > 1:
+        # TODO(VitalyFedyunin): Compact all batch dataframes into one
+        raise RuntimeError("Only supports one DataFrame per batch")
+    df = item[0]
+    columns_name = df_wrapper.get_columns(df)
+    tuple_names: list = []
+    tuple_values: list = []
+
+    for name in conversion.keys():
+        if name not in columns_name:
+            raise RuntimeError("Conversion keys missmatch")
+
+    for name in columns_name:
+        if name in conversion:
+            if not callable(conversion[name]):
+                raise RuntimeError(
+                    "Collate (DF)DataPipe requires callable as dict values"
+                )
+            collation_fn = conversion[name]
+        else:
+            # TODO(VitalyFedyunin): Add default collation into df_wrapper
+            try:
+                import torcharrow.pytorch as tap  # type: ignore[import]
+
+                collation_fn = tap.rec.Default()
+            except Exception as e:
+                raise RuntimeError(
+                    "unable to import default collation function from the TorchArrow"
+                ) from e
+
+        tuple_names.append(str(name))
+        value = collation_fn(df[name])
+        tuple_values.append(value)
+
+    # TODO(VitalyFedyunin): We can dynamically extract types from the tuple_values here
+    # TODO(VitalyFedyunin): Instead of ignoring mypy error, make sure tuple_names is not empty
+    tpl_cls = namedtuple("CollateResult", tuple_names)  # type: ignore[misc]
+    return tpl_cls(*tuple_values)
+
+
+@functional_datapipe("collate")
+class CollatorIterDataPipe(MapperIterDataPipe):
+    r"""
+    Collates samples from DataPipe to Tensor(s) by a custom collate function (functional name: ``collate``).
+
+    By default, it uses :func:`torch.utils.data.default_collate`.
+
+    .. note::
+        While writing a custom collate function, you can import :func:`torch.utils.data.default_collate` for the
+        default behavior and `functools.partial` to specify any additional arguments.
+
+    Args:
+        datapipe: Iterable DataPipe being collated
+        collate_fn: Customized collate function to collect and combine data or a batch of data.
+            Default function collates to Tensor(s) based on data type.
+
+    Example:
+        >>> # xdoctest: +SKIP
+        >>> # Convert integer data to float Tensor
+        >>> class MyIterDataPipe(torch.utils.data.IterDataPipe):
+        ...     def __init__(self, start, end):
+        ...         super().__init__()
+        ...         assert end > start, "this example code only works with end > start"
+        ...         self.start = start
+        ...         self.end = end
+        ...
+        ...     def __iter__(self):
+        ...         return iter(range(self.start, self.end))
+        ...
+        ...     def __len__(self):
+        ...         return self.end - self.start
+        ...
+        >>> ds = MyIterDataPipe(start=3, end=7)
+        >>> print(list(ds))
+        [3, 4, 5, 6]
+        >>> def collate_fn(batch):
+        ...     return torch.tensor(batch, dtype=torch.float)
+        ...
+        >>> collated_ds = CollatorIterDataPipe(ds, collate_fn=collate_fn)
+        >>> print(list(collated_ds))
+        [tensor(3.), tensor(4.), tensor(5.), tensor(6.)]
+    """
+
+    def __init__(
+        self,
+        datapipe: IterDataPipe,
+        conversion: Union[
+            Callable[..., Any], dict[Union[str, Any], Union[Callable, Any]], None
+        ] = default_collate,
+        collate_fn: Optional[Callable] = None,
+    ) -> None:
+        # TODO(VitalyFedyunin): Replace `Callable[..., Any]` with `Callable[[IColumn], Any]`
+        # TODO(VitalyFedyunin): Replace with `Dict[Union[str, IColumn], Union[Callable, Enum]]`
+        if collate_fn is not None:
+            super().__init__(datapipe, fn=collate_fn)
+        else:
+            if callable(conversion):
+                super().__init__(datapipe, fn=conversion)
+            else:
+                # TODO(VitalyFedyunin): Validate passed dictionary
+                collate_fn = functools.partial(_collate_helper, conversion)
+                super().__init__(datapipe, fn=collate_fn)
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/combinatorics.py b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/combinatorics.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c602ce4eeda0cf54643556f798c9ec9cd03a14d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/combinatorics.py
@@ -0,0 +1,190 @@
+# mypy: allow-untyped-defs
+import random
+from collections.abc import Iterator, Sized
+from typing import Optional, TypeVar
+
+import torch
+from torch.utils.data.datapipes._decorator import functional_datapipe
+from torch.utils.data.datapipes.datapipe import IterDataPipe
+from torch.utils.data.sampler import Sampler, SequentialSampler
+
+
+__all__ = [
+    "SamplerIterDataPipe",
+    "ShufflerIterDataPipe",
+]
+
+
+_T_co = TypeVar("_T_co", covariant=True)
+
+
+class SamplerIterDataPipe(IterDataPipe[_T_co]):
+    r"""
+    Generate sample elements using the provided ``Sampler`` (defaults to :class:`SequentialSampler`).
+
+    Args:
+        datapipe: IterDataPipe to sample from
+        sampler: Sampler class to generate sample elements from input DataPipe.
+            Default is :class:`SequentialSampler` for IterDataPipe
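+
+    Example:
+        >>> # xdoctest: +SKIP
+        >>> # Minimal sketch: with the default ``SequentialSampler``, the indices of the
+        >>> # wrapped datapipe are yielded in order.
+        >>> from torch.utils.data.datapipes.iter import IterableWrapper
+        >>> dp = SamplerIterDataPipe(IterableWrapper(range(5)))
+        >>> list(dp)
+        [0, 1, 2, 3, 4]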
+    """
+
+    datapipe: IterDataPipe
+    sampler: Sampler
+
+    def __init__(
+        self,
+        datapipe: IterDataPipe,
+        sampler: type[Sampler] = SequentialSampler,
+        sampler_args: Optional[tuple] = None,
+        sampler_kwargs: Optional[dict] = None,
+    ) -> None:
+        assert isinstance(
+            datapipe, Sized
+        ), "Sampler class requires input datapipe implemented `__len__`"
+        super().__init__()
+        self.datapipe = datapipe
+        self.sampler_args = () if sampler_args is None else sampler_args
+        self.sampler_kwargs = {} if sampler_kwargs is None else sampler_kwargs
+        # https://github.com/python/mypy/pull/9629 will solve
+        self.sampler = sampler(*self.sampler_args, data_source=self.datapipe, **self.sampler_kwargs)  # type: ignore[misc]
+
+    def __iter__(self) -> Iterator[_T_co]:
+        return iter(self.sampler)
+
+    def __len__(self) -> int:
+        # Dataset has been tested as `Sized`
+        if isinstance(self.sampler, Sized):
+            return len(self.sampler)
+        raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
+
+
+@functional_datapipe("shuffle")
+class ShufflerIterDataPipe(IterDataPipe[_T_co]):
+    r"""
+    Shuffle the input DataPipe with a buffer (functional name: ``shuffle``).
+
+    The buffer with ``buffer_size`` is filled with elements from the datapipe first. Then,
+    each item will be yielded from the buffer by reservoir sampling via iterator.
+
+    ``buffer_size`` is required to be larger than ``0``. For ``buffer_size == 1``, the
+    datapipe is not shuffled. In order to fully shuffle all elements from datapipe,
+    ``buffer_size`` is required to be greater than or equal to the size of datapipe.
+
+    When it is used with :class:`torch.utils.data.DataLoader`, the methods to
+    set up random seed are different based on :attr:`num_workers`.
+
+    For single-process mode (:attr:`num_workers == 0`), the random seed is set before
+    the :class:`~torch.utils.data.DataLoader` in the main process. For multi-process
+    mode (:attr:`num_workers > 0`), `worker_init_fn` is used to set up a random seed
+    for each worker process.
+
+    Args:
+        datapipe: The IterDataPipe being shuffled
+        buffer_size: The buffer size for shuffling (default to ``10000``)
+        unbatch_level: Specifies if it is necessary to unbatch source data before
+            applying the shuffle
+
+    Example:
+        >>> # xdoctest: +SKIP
+        >>> from torchdata.datapipes.iter import IterableWrapper
+        >>> dp = IterableWrapper(range(10))
+        >>> shuffle_dp = dp.shuffle()
+        >>> list(shuffle_dp)
+        [0, 4, 1, 6, 3, 2, 9, 5, 7, 8]
+    """
+
+    datapipe: IterDataPipe[_T_co]
+    buffer_size: int
+    _buffer: list[_T_co]
+    _enabled: bool
+    _seed: Optional[int]
+    _rng: random.Random
+
+    def __init__(
+        self,
+        datapipe: IterDataPipe[_T_co],
+        *,
+        buffer_size: int = 10000,
+        unbatch_level: int = 0,
+    ) -> None:
+        super().__init__()
+        # TODO: Performance optimization
+        #       buffer can be a fixed size and remove expensive `append()` and `len()` operations
+        self._buffer: list[_T_co] = []
+        assert buffer_size > 0, "buffer_size should be larger than 0"
+        if unbatch_level == 0:
+            self.datapipe = datapipe
+        else:
+            self.datapipe = datapipe.unbatch(unbatch_level=unbatch_level)
+        self.buffer_size = buffer_size
+        self._enabled = True
+        self._seed = None
+        self._rng = random.Random()
+
+    def set_shuffle(self, shuffle=True):
+        self._enabled = shuffle
+        return self
+
+    def set_seed(self, seed: int):
+        self._seed = seed
+        return self
+
+    def __iter__(self) -> Iterator[_T_co]:
+        if not self._enabled:
+            yield from self.datapipe
+        else:
+            for x in self.datapipe:
+                if len(self._buffer) == self.buffer_size:
+                    idx = self._rng.randint(0, len(self._buffer) - 1)
+                    val, self._buffer[idx] = self._buffer[idx], x
+                    yield val
+                else:
+                    self._buffer.append(x)
+            while self._buffer:
+                idx = self._rng.randint(0, len(self._buffer) - 1)
+                yield self._buffer.pop(idx)
+
+    def __len__(self) -> int:
+        if isinstance(self.datapipe, Sized):
+            return len(self.datapipe)
+        raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
+
+    def reset(self) -> None:
+        self._buffer = []
+        if self._enabled:
+            if self._seed is None:
+                self._seed = int(torch.empty((), dtype=torch.int64).random_().item())
+            self._rng.seed(self._seed)
+            self._seed = None
+
+    def __getstate__(self):
+        state = (
+            self.datapipe,
+            self.buffer_size,
+            self._enabled,
+            self._seed,
+            self._buffer,
+            self._rng.getstate(),
+            self._valid_iterator_id,
+            self._number_of_samples_yielded,
+        )
+        if IterDataPipe.getstate_hook is not None:
+            return IterDataPipe.getstate_hook(state)
+        return state
+
+    def __setstate__(self, state):
+        (
+            self.datapipe,
+            self.buffer_size,
+            self._enabled,
+            self._seed,
+            self._buffer,
+            rng_state,
+            self._valid_iterator_id,
+            self._number_of_samples_yielded,
+        ) = state
+        self._rng = random.Random()
+        self._rng.setstate(rng_state)
+
+    def __del__(self):
+        self._buffer.clear()
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/combining.py b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/combining.py
new file mode 100644
index 0000000000000000000000000000000000000000..deaca079c68c0fc138505e839e7cff2c78c38dcc
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/combining.py
@@ -0,0 +1,696 @@
+# mypy: allow-untyped-defs
+import copy as copymodule
+import warnings
+from abc import ABC, abstractmethod
+from collections import deque
+from collections.abc import Iterator, Sized
+from typing import Any, Callable, Literal, Optional, TypeVar
+
+from torch.utils.data.datapipes._decorator import functional_datapipe
+from torch.utils.data.datapipes._hook_iterator import _SnapshotState
+from torch.utils.data.datapipes.datapipe import IterDataPipe
+from torch.utils.data.datapipes.utils.common import _check_unpickable_fn, StreamWrapper
+
+
+__all__ = [
+    "ConcaterIterDataPipe",
+    "DemultiplexerIterDataPipe",
+    "ForkerIterDataPipe",
+    "MultiplexerIterDataPipe",
+    "ZipperIterDataPipe",
+]
+
+
+_T_co = TypeVar("_T_co", covariant=True)
+
+
+@functional_datapipe("concat")
+class ConcaterIterDataPipe(IterDataPipe):
+    r"""
+    Concatenates multiple Iterable DataPipes (functional name: ``concat``).
+
+    The resulting DataPipe will yield all the elements from the first input DataPipe, before yielding from the subsequent ones.
+
+    Args:
+        datapipes: Iterable DataPipes being concatenated
+
+    Example:
+        >>> # xdoctest: +REQUIRES(module:torchdata)
+        >>> import random
+        >>> from torchdata.datapipes.iter import IterableWrapper
+        >>> dp1 = IterableWrapper(range(3))
+        >>> dp2 = IterableWrapper(range(5))
+        >>> list(dp1.concat(dp2))
+        [0, 1, 2, 0, 1, 2, 3, 4]
+    """
+
+    datapipes: tuple[IterDataPipe]
+
+    def __init__(self, *datapipes: IterDataPipe):
+        if len(datapipes) == 0:
+            raise ValueError("Expected at least one DataPipe, but got nothing")
+        if not all(isinstance(dp, IterDataPipe) for dp in datapipes):
+            raise TypeError("Expected all inputs to be `IterDataPipe`")
+        self.datapipes = datapipes  # type: ignore[assignment]
+
+    def __iter__(self) -> Iterator:
+        for dp in self.datapipes:
+            yield from dp
+
+    def __len__(self) -> int:
+        if all(isinstance(dp, Sized) for dp in self.datapipes):
+            return sum(len(dp) for dp in self.datapipes)
+        else:
+            raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
+
+
+@functional_datapipe("fork")
+class ForkerIterDataPipe(IterDataPipe):
+    r"""
+    Creates multiple instances of the same Iterable DataPipe (functional name: ``fork``).
+
+    Args:
+        datapipe: Iterable DataPipe being copied
+        num_instances: number of instances of the datapipe to create
+        buffer_size: this restricts how far ahead the leading child DataPipe
+           can read relative to the slowest child DataPipe.
+           Defaults to ``1000``. Use ``-1`` for the unlimited buffer.
+        copy: copy strategy to use for items yielded by each branch. Supported
+            options are ``None`` for no copying, ``"shallow"`` for shallow object
+            copies, and ``"deep"`` for deep object copies. Defaults to ``None``.
+
+    Note:
+        All branches of the forked pipeline return the identical object unless
+        the copy parameter is supplied. If the object is mutable or contains
+        mutable objects, changing them in one branch will affect all others.
+
+    Example:
+        >>> # xdoctest: +REQUIRES(module:torchdata)
+        >>> from torchdata.datapipes.iter import IterableWrapper
+        >>> source_dp = IterableWrapper(range(5))
+        >>> dp1, dp2 = source_dp.fork(num_instances=2)
+        >>> list(dp1)
+        [0, 1, 2, 3, 4]
+        >>> list(dp2)
+        [0, 1, 2, 3, 4]
+    """
+
+    def __new__(
+        cls,
+        datapipe: IterDataPipe,
+        num_instances: int,
+        buffer_size: int = 1000,
+        copy: Optional[Literal["shallow", "deep"]] = None,
+    ):
+        if num_instances < 1:
+            raise ValueError(
+                f"Expected `num_instances` larger than 0, but {num_instances} is found"
+            )
+        if num_instances == 1:
+            return datapipe
+        container = _ForkerIterDataPipe(datapipe, num_instances, buffer_size, copy)  # type: ignore[abstract]
+        return [_ChildDataPipe(container, i) for i in range(num_instances)]
+
+
+class _ContainerTemplate(ABC):
+    r"""Abstract class for container ``DataPipes``. The followings are three required methods."""
+
+    @abstractmethod
+    def get_next_element_by_instance(self, instance_id: int):
+        ...
+
+    @abstractmethod
+    def is_every_instance_exhausted(self) -> bool:
+        ...
+
+    @abstractmethod
+    def reset(self) -> None:
+        ...
+
+    @abstractmethod
+    def get_length_by_instance(self, instance_id: int):
+        r"""Raise TypeError if it's not supposed to be implemented to support `list(datapipe)`."""
+
+
+def _no_op(x):
+    return x
+
+
+class _ForkerIterDataPipe(IterDataPipe, _ContainerTemplate):
+    r"""
+    Container to hold instance-specific information on behalf of ForkerIterDataPipe.
+
+    It tracks the state of its child DataPipes, maintains the buffer, and yields the next value
+    as requested by the child DataPipes.
+    """
+
+    def __init__(
+        self,
+        datapipe: IterDataPipe,
+        num_instances: int,
+        buffer_size: int = 1000,
+        copy: Optional[Literal["shallow", "deep"]] = None,
+    ):
+        self.main_datapipe = datapipe
+        self._datapipe_iterator: Optional[Iterator[Any]] = None
+        self.num_instances = num_instances
+        self.buffer: deque = deque()
+        self.buffer_size = buffer_size
+        if self.buffer_size < 0:
+            warnings.warn(
+                "Unlimited buffer size is set for `fork`, "
+                "please be aware of OOM at random places",
+                UserWarning,
+            )
+        if copy is None:
+            self.copy_fn = _no_op
+        elif copy == "shallow":
+            self.copy_fn = copymodule.copy
+        elif copy == "deep":
+            self.copy_fn = copymodule.deepcopy
+        else:
+            raise ValueError(
+                f"Unknown copy method `{copy}` requested, choose one of None, `shallow` or `deep`."
+            )
+
+        self.child_pointers: list[int] = [
+            0
+        ] * num_instances  # Indicate the indices of the next element to get
+        self.slowest_ptr = 0  # The index to read by the slowest child
+        self.leading_ptr = 0  # The index to read by the fastest child
+        self.end_ptr: Optional[int] = None  # The index to stop child
+        self._child_stop: list[bool] = [True for _ in range(num_instances)]
+
+    def __len__(self):
+        return len(self.main_datapipe)
+
+    def get_next_element_by_instance(self, instance_id: int):
+        if self._datapipe_iterator is None and self._child_stop[instance_id]:
+            self._datapipe_iterator = iter(self.main_datapipe)
+            self._snapshot_state = _SnapshotState.Iterating
+            for i in range(self.num_instances):
+                self._child_stop[i] = False
+        try:
+            while not self._child_stop[instance_id]:
+                self.child_pointers[instance_id] += 1
+                if (
+                    self.end_ptr is not None
+                    and self.child_pointers[instance_id] == self.end_ptr
+                ):
+                    self._child_stop[instance_id] = True
+                    break
+                # Use buffer
+                if self.buffer and self.child_pointers[instance_id] <= self.leading_ptr:
+                    idx = self.child_pointers[instance_id] - self.slowest_ptr - 1
+                    return_val = self.buffer[idx]
+                else:  # Retrieve one element from main datapipe
+                    self.leading_ptr = self.child_pointers[instance_id]
+                    try:
+                        return_val = next(self._datapipe_iterator)  # type: ignore[arg-type]
+                        self.buffer.append(return_val)
+                    except StopIteration:
+                        self._child_stop[instance_id] = True
+                        self._datapipe_iterator = None
+                        self.end_ptr = self.leading_ptr
+                        continue
+                if self.child_pointers[instance_id] == self.slowest_ptr + 1:
+                    new_min = min(
+                        self.child_pointers
+                    )  # Can optimize by avoiding the call to min()
+                    if self.slowest_ptr < new_min:
+                        self.slowest_ptr = new_min
+                        self.buffer.popleft()
+                if (
+                    self.buffer_size >= 0
+                    and self.leading_ptr > self.buffer_size + self.slowest_ptr
+                ):
+                    raise BufferError(
+                        "ForkerIterDataPipe buffer overflow,"
+                        + f"buffer size {self.buffer_size} is insufficient."
+                    )
+
+                yield self.copy_fn(return_val)  # type: ignore[possibly-undefined]
+        finally:
+            self._child_stop[instance_id] = True
+            # Cleanup _datapipe_iterator for the case that fork exits earlier
+            if all(self._child_stop):
+                self._datapipe_iterator = None
+                self._cleanup()
+
+    def is_every_instance_exhausted(self) -> bool:
+        return self.end_ptr is not None and all(self._child_stop)
+
+    def get_length_by_instance(self, instance_id: int) -> int:
+        return len(self.main_datapipe)
+
+    def reset(self) -> None:
+        self._datapipe_iterator = None
+        self.buffer = deque()
+        self.child_pointers = [0] * self.num_instances
+        self.slowest_ptr = 0
+        self.leading_ptr = 0
+        self.end_ptr = None
+        self._child_stop = [True for _ in range(self.num_instances)]
+
+    def __getstate__(self):
+        state = (
+            self.main_datapipe,
+            self.num_instances,
+            self.buffer_size,
+            self.copy_fn,
+            self._valid_iterator_id,
+            self._number_of_samples_yielded,
+        )
+        if IterDataPipe.getstate_hook is not None:
+            return IterDataPipe.getstate_hook(state)
+        return state
+
+    def __setstate__(self, state):
+        (
+            self.main_datapipe,
+            self.num_instances,
+            self.buffer_size,
+            self.copy_fn,
+            self._valid_iterator_id,
+            self._number_of_samples_yielded,
+        ) = state
+        self._datapipe_iterator = None
+        self.buffer = deque()
+        self.child_pointers = [0] * self.num_instances
+        self.slowest_ptr = 0
+        self.leading_ptr = 0
+        self.end_ptr = None
+        self._child_stop = [True for _ in range(self.num_instances)]
+
+    def _cleanup(self):
+        while self.buffer:
+            d = self.buffer.popleft()
+            StreamWrapper.close_streams(d)
+
+    def __del__(self):
+        self._cleanup()
+
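+# A hedged illustration of the shared-buffer bookkeeping above, assuming torchdata's
+# `IterableWrapper` is available: when one fork child runs ahead of its siblings, the
+# elements it passes through are retained in `buffer` (bounded by `buffer_size`) until
+# the slowest child catches up.
+#
+#   >>> from torchdata.datapipes.iter import IterableWrapper
+#   >>> dp1, dp2 = IterableWrapper(range(5)).fork(num_instances=2)
+#   >>> list(dp1)  # dp1 drains the source; all 5 elements stay buffered for dp2
+#   [0, 1, 2, 3, 4]
+#   >>> list(dp2)  # dp2 is then served entirely from the buffer
+#   [0, 1, 2, 3, 4]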
+
+class _ChildDataPipe(IterDataPipe):
+    r"""
+    Iterable Datapipe that is a child of a main DataPipe.
+
+    The instance of this class will pass its instance_id to get the next value from its main DataPipe.
+
+    Note:
+        ChildDataPipe, like all other IterDataPipe, follows the single iterator per IterDataPipe constraint.
+        Since ChildDataPipes share a common buffer, when an iterator is created for one of the ChildDataPipes,
+        the previous iterators for all ChildDataPipes must be invalidated, except when a ChildDataPipe
+        hasn't had an iterator created from it since the last invalidation. See the example below.
+
+    Example:
+        >>> # xdoctest: +REQUIRES(module:torchdata)
+        >>> # Single Iterator per IterDataPipe Invalidation
+        >>> from torchdata.datapipes.iter import IterableWrapper
+        >>> source_dp = IterableWrapper(range(10))
+        >>> cdp1, cdp2 = source_dp.fork(num_instances=2)
+        >>> it1, it2 = iter(cdp1), iter(cdp2)
+        >>> it3 = iter(cdp1)
+        >>> # The line above invalidates `it1` and `it2`, and resets `ForkerIterDataPipe`.
+        >>> it4 = iter(cdp2)
+        >>> # The line above doesn't invalidate `it3`, because an iterator for `cdp2` hasn't been created since
+        >>> # the last invalidation.
+
+    Args:
+        main_datapipe: Main DataPipe with a method 'get_next_element_by_instance(instance_id)'
+        instance_id: integer identifier of this instance
+    """
+
+    _is_child_datapipe: bool = True
+
+    def __init__(self, main_datapipe: IterDataPipe, instance_id: int):
+        assert isinstance(main_datapipe, _ContainerTemplate)
+
+        self.main_datapipe: IterDataPipe = main_datapipe
+        self.instance_id = instance_id
+
+    def __iter__(self):
+        # Note that the logic behind setting iterator ID and `reset` are handled within `hook_iterator`
+        # We want to separate the code for reset and yield, so that 'reset' executes before __next__ is called
+        return self.main_datapipe.get_next_element_by_instance(self.instance_id)
+
+    def __len__(self):
+        return self.main_datapipe.get_length_by_instance(self.instance_id)
+
+    # This method is called by `hook_iterator` in `_typing.py`.
+    def _set_main_datapipe_valid_iterator_id(self) -> int:
+        r"""
+        Update the valid iterator ID for both this DataPipe object and `main_datapipe`.
+
+        `main_datapipe.reset()` is called when the ID is incremented to a new generation.
+        """
+        # 1. First time any child iterator is created
+        if self.main_datapipe._valid_iterator_id is None:
+            self.main_datapipe._valid_iterator_id = 0  # type: ignore[attr-defined]
+        # 2. This instance was already in the same generation as `main_datapipe`,
+        #    we need to increment the ID further by 1
+        elif self.main_datapipe._valid_iterator_id == self._valid_iterator_id:  # type: ignore[has-type]
+            self.main_datapipe._valid_iterator_id += 1  # type: ignore[attr-defined]
+            # Whenever a new generation of iterator is created, the `main_datapipe` must reset
+            if not self.main_datapipe.is_every_instance_exhausted():
+                warnings.warn(
+                    "Some child DataPipes are not exhausted when __iter__ is called. We are resetting "
+                    "the buffer and each child DataPipe will read from the start again.",
+                    UserWarning,
+                )
+            self.main_datapipe.reset()
+        # 3. Otherwise, the iterator is behind the others, so it will just need to catch up by setting
+        #    the instance's iterator to match that of `main_datapipe`
+        self._valid_iterator_id = self.main_datapipe._valid_iterator_id
+        return self._valid_iterator_id
+
+    # This method is called by `hook_iterator` in `_typing.py`.
+    def _check_valid_iterator_id(self, iterator_id) -> bool:
+        r"""Check the valid iterator ID against that of DataPipe object and that of `main_datapipe`."""
+        return (
+            iterator_id == self._valid_iterator_id
+            and iterator_id == self.main_datapipe._valid_iterator_id
+        )
+
+
+@functional_datapipe("demux")
+class DemultiplexerIterDataPipe(IterDataPipe):
+    r"""
+    Splits the input DataPipe into multiple child DataPipes, using the given classification function (functional name: ``demux``).
+
+    A list of the child DataPipes is returned from this operation.
+
+    Args:
+        datapipe: Iterable DataPipe being filtered
+        num_instances: number of instances of the DataPipe to create
+        classifier_fn: a function that maps values to an integer within the range ``[0, num_instances - 1]`` or ``None``
+        drop_none: defaults to ``False``, if ``True``, the function will skip over elements classified as ``None``
+        buffer_size: this defines the maximum number of inputs that the buffer can hold across all child
+            DataPipes while waiting for their values to be yielded.
+            Defaults to ``1000``. Use ``-1`` for the unlimited buffer.
+
+    Examples:
+        >>> # xdoctest: +REQUIRES(module:torchdata)
+        >>> from torchdata.datapipes.iter import IterableWrapper
+        >>> def odd_or_even(n):
+        ...     return n % 2
+        >>> source_dp = IterableWrapper(range(5))
+        >>> dp1, dp2 = source_dp.demux(num_instances=2, classifier_fn=odd_or_even)
+        >>> list(dp1)
+        [0, 2, 4]
+        >>> list(dp2)
+        [1, 3]
+        >>> # It can also filter out any element that gets `None` from the `classifier_fn`
+        >>> def odd_or_even_no_zero(n):
+        ...     return n % 2 if n != 0 else None
+        >>> dp1, dp2 = source_dp.demux(num_instances=2, classifier_fn=odd_or_even_no_zero, drop_none=True)
+        >>> list(dp1)
+        [2, 4]
+        >>> list(dp2)
+        [1, 3]
+    """
+
+    def __new__(
+        cls,
+        datapipe: IterDataPipe,
+        num_instances: int,
+        classifier_fn: Callable[[_T_co], Optional[int]],
+        drop_none: bool = False,
+        buffer_size: int = 1000,
+    ):
+        if num_instances < 1:
+            raise ValueError(
+                f"Expected `num_instances` to be larger than 0, but {num_instances} was found"
+            )
+
+        _check_unpickable_fn(classifier_fn)
+
+        # When num_instances == 1, demux can be replaced by filter,
+        # but keep it as Demultiplexer for the sake of consistency,
+        # e.g. raising an error when the classification result is out of range
+        container = _DemultiplexerIterDataPipe(datapipe, num_instances, classifier_fn, drop_none, buffer_size)  # type: ignore[abstract]
+        return [_ChildDataPipe(container, i) for i in range(num_instances)]
+
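+# A hedged sketch of how the container below serves its children, assuming torchdata's
+# `IterableWrapper` is available: values classified for another child are buffered until
+# that child asks for them, and a classification outside `[0, num_instances - 1]` raises
+# a ValueError at iteration time.
+#
+#   >>> from torchdata.datapipes.iter import IterableWrapper
+#   >>> def odd_or_even(n):
+#   ...     return n % 2
+#   >>> dp1, dp2 = IterableWrapper(range(6)).demux(num_instances=2, classifier_fn=odd_or_even)
+#   >>> next(iter(dp2))  # pulls 0 and 1 from the source; 0 is buffered for dp1
+#   1
+#   >>> list(dp1)        # served from the buffer first, then from the source
+#   [0, 2, 4]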
+
+class _DemultiplexerIterDataPipe(IterDataPipe, _ContainerTemplate):
+    r"""
+    Container to hold instance-specific information on behalf of DemultiplexerIterDataPipe.
+
+    It tracks the state of its child DataPipes, maintains the buffer, classifies and yields the next correct value
+    as requested by the child DataPipes.
+    """
+
+    def __init__(
+        self,
+        datapipe: IterDataPipe[_T_co],
+        num_instances: int,
+        classifier_fn: Callable[[_T_co], Optional[int]],
+        drop_none: bool,
+        buffer_size: int,
+    ):
+        self.main_datapipe = datapipe
+        self._datapipe_iterator: Optional[Iterator[Any]] = None
+        self.num_instances = num_instances
+        self.buffer_size = buffer_size
+        if self.buffer_size < 0:
+            warnings.warn(
+                "Unlimited buffer size is set for `demux`, "
+                "please be aware of OOM at random places",
+                UserWarning,
+            )
+        self.current_buffer_usage = 0
+        self.child_buffers: list[deque[_T_co]] = [deque() for _ in range(num_instances)]
+        self.classifier_fn = classifier_fn
+        self.drop_none = drop_none
+        self.main_datapipe_exhausted = False
+        self._child_stop: list[bool] = [True for _ in range(num_instances)]
+
+    def _find_next(self, instance_id: int) -> _T_co:  # type: ignore[type-var]
+        while True:
+            if self.main_datapipe_exhausted or self._child_stop[instance_id]:
+                raise StopIteration
+            if self._datapipe_iterator is None:
+                raise ValueError(
+                    "_datapipe_iterator has not been set, likely because this private method is called directly "
+                    "without invoking get_next_element_by_instance() first."
+                )
+            value = next(self._datapipe_iterator)
+            classification = self.classifier_fn(value)
+            if classification is None and self.drop_none:
+                StreamWrapper.close_streams(value)
+                continue
+            if (
+                classification is None
+                or classification >= self.num_instances
+                or classification < 0
+            ):
+                raise ValueError(
+                    f"Output of the classification fn should be between 0 and {self.num_instances - 1}. "
+                    + f"{classification} is returned."
+                )
+            if classification == instance_id:
+                return value
+            self.child_buffers[classification].append(value)
+            self.current_buffer_usage += 1
+            if self.buffer_size >= 0 and self.current_buffer_usage > self.buffer_size:
+                raise BufferError(
+                    f"DemultiplexerIterDataPipe buffer overflow, buffer size {self.buffer_size} is insufficient."
+                )
+
+    def get_next_element_by_instance(self, instance_id: int):
+        if self._datapipe_iterator is None and self._child_stop[instance_id]:
+            self._datapipe_iterator = iter(self.main_datapipe)
+            self._snapshot_state = (
+                _SnapshotState.Iterating
+            )  # This is necessary for the DataPipe to reset properly.
+            self.main_datapipe_exhausted = False
+            for i in range(self.num_instances):
+                self._child_stop[i] = False
+
+        try:
+            while not self._child_stop[instance_id]:
+                if self.child_buffers[instance_id]:
+                    self.current_buffer_usage -= 1
+                    yield self.child_buffers[instance_id].popleft()
+                else:
+                    try:
+                        yield self._find_next(instance_id)
+                    except StopIteration:
+                        self._child_stop[instance_id] = True
+                        self.main_datapipe_exhausted = True
+                        self._datapipe_iterator = None
+        finally:
+            self._child_stop[instance_id] = True
+            # Cleanup _datapipe_iterator for the case that demux exits earlier
+            if all(self._child_stop):
+                self._datapipe_iterator = None
+            if self.child_buffers[instance_id]:
+                self._cleanup(instance_id)
+
+    def is_every_instance_exhausted(self) -> bool:
+        return self.main_datapipe_exhausted and all(self._child_stop)
+
+    def get_length_by_instance(self, instance_id: int) -> int:
+        raise TypeError
+
+    def reset(self) -> None:
+        self._datapipe_iterator = None
+        self.current_buffer_usage = 0
+        self.child_buffers = [deque() for _ in range(self.num_instances)]
+        self._child_stop = [True for _ in range(self.num_instances)]
+        self.main_datapipe_exhausted = False
+
+    def __getstate__(self):
+        state = (
+            self.main_datapipe,
+            self.num_instances,
+            self.buffer_size,
+            self.classifier_fn,
+            self.drop_none,
+            self._valid_iterator_id,
+            self._number_of_samples_yielded,
+        )
+        if IterDataPipe.getstate_hook is not None:
+            return IterDataPipe.getstate_hook(state)
+        return state
+
+    def __setstate__(self, state):
+        (
+            self.main_datapipe,
+            self.num_instances,
+            self.buffer_size,
+            self.classifier_fn,
+            self.drop_none,
+            self._valid_iterator_id,
+            self._number_of_samples_yielded,
+        ) = state
+        self._datapipe_iterator = None
+        self.current_buffer_usage = 0
+        self.child_buffers = [deque() for _ in range(self.num_instances)]
+        self._child_stop = [True for _ in range(self.num_instances)]
+        self.main_datapipe_exhausted = False
+
+    def _cleanup(self, instance_id: Optional[int] = None):
+        ids = (
+            range(self.num_instances)
+            if instance_id is None
+            else [
+                instance_id,
+            ]
+        )
+        for i in ids:
+            q = self.child_buffers[i]
+            while q:
+                d = q.popleft()
+                StreamWrapper.close_streams(d)
+
+    def __del__(self):
+        self._cleanup()
+
+
+@functional_datapipe("mux")
+class MultiplexerIterDataPipe(IterDataPipe):
+    r"""
+    Yields one element at a time from each of the input Iterable DataPipes (functional name: ``mux``).
+
+    As in, one element from the 1st input DataPipe, then one element from the 2nd DataPipe in the next iteration,
+    and so on. It ends when the shortest input DataPipe is exhausted.
+
+    Args:
+        datapipes: Iterable DataPipes that will take turns to yield their elements, until the shortest DataPipe is exhausted
+
+    Example:
+        >>> # xdoctest: +REQUIRES(module:torchdata)
+        >>> from torchdata.datapipes.iter import IterableWrapper
+        >>> dp1, dp2, dp3 = IterableWrapper(range(3)), IterableWrapper(range(10, 15)), IterableWrapper(range(20, 25))
+        >>> list(dp1.mux(dp2, dp3))
+        [0, 10, 20, 1, 11, 21, 2, 12, 22]
+    """
+
+    def __init__(self, *datapipes):
+        self.datapipes = datapipes
+        self.buffer: list = (
+            []
+        )  # Store values to be yielded only when every iterator provides one
+
+    def __iter__(self):
+        iterators = [iter(x) for x in self.datapipes]
+        while len(iterators):
+            for it in iterators:
+                try:
+                    value = next(it)
+                    self.buffer.append(value)
+                except StopIteration:
+                    self.buffer.clear()
+                    return
+            yield from self.buffer
+            self.buffer.clear()
+
+    def __len__(self):
+        if all(isinstance(dp, Sized) for dp in self.datapipes):
+            return min(len(dp) for dp in self.datapipes) * len(self.datapipes)
+        else:
+            raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
+
+    def reset(self) -> None:
+        self.buffer = []
+
+    def __getstate__(self):
+        state = (
+            self.datapipes,
+            self._valid_iterator_id,
+            self._number_of_samples_yielded,
+        )
+        if IterDataPipe.getstate_hook is not None:
+            return IterDataPipe.getstate_hook(state)
+        return state
+
+    def __setstate__(self, state):
+        (
+            self.datapipes,
+            self._valid_iterator_id,
+            self._number_of_samples_yielded,
+        ) = state
+        self.buffer = []
+
+    def __del__(self):
+        self.buffer.clear()
+
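+# A small worked example of the length contract above, assuming all inputs are Sized:
+# `mux` yields rounds of one element per input and stops at the shortest, so its length
+# is min(len(dp)) * number_of_datapipes.
+#
+#   >>> from torchdata.datapipes.iter import IterableWrapper
+#   >>> dp1, dp2, dp3 = IterableWrapper(range(3)), IterableWrapper(range(10, 15)), IterableWrapper(range(20, 25))
+#   >>> len(dp1.mux(dp2, dp3))  # min(3, 5, 5) * 3
+#   9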
+
+@functional_datapipe("zip")
+class ZipperIterDataPipe(IterDataPipe[tuple[_T_co]]):
+    r"""
+    Aggregates elements into a tuple from each of the input DataPipes (functional name: ``zip``).
+
+    The output is stopped as soon as the shortest input DataPipe is exhausted.
+
+    Args:
+        *datapipes: Iterable DataPipes being aggregated
+
+    Example:
+        >>> # xdoctest: +REQUIRES(module:torchdata)
+        >>> from torchdata.datapipes.iter import IterableWrapper
+        >>> dp1, dp2, dp3 = IterableWrapper(range(5)), IterableWrapper(range(10, 15)), IterableWrapper(range(20, 25))
+        >>> list(dp1.zip(dp2, dp3))
+        [(0, 10, 20), (1, 11, 21), (2, 12, 22), (3, 13, 23), (4, 14, 24)]
+    """
+
+    datapipes: tuple[IterDataPipe]
+
+    def __init__(self, *datapipes: IterDataPipe):
+        if not all(isinstance(dp, IterDataPipe) for dp in datapipes):
+            raise TypeError(
+                "All inputs are required to be `IterDataPipe` for `ZipIterDataPipe`."
+            )
+        super().__init__()
+        self.datapipes = datapipes  # type: ignore[assignment]
+
+    def __iter__(self) -> Iterator[tuple[_T_co]]:
+        iterators = [iter(datapipe) for datapipe in self.datapipes]
+        yield from zip(*iterators)
+
+    def __len__(self) -> int:
+        if all(isinstance(dp, Sized) for dp in self.datapipes):
+            return min(len(dp) for dp in self.datapipes)
+        else:
+            raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/filelister.py b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/filelister.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b3d16bed2a66736b9874d14a0696c1dee32b23f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/filelister.py
@@ -0,0 +1,68 @@
+from collections.abc import Iterator, Sequence
+from typing import Union
+
+from torch.utils.data.datapipes._decorator import functional_datapipe
+from torch.utils.data.datapipes.datapipe import IterDataPipe
+from torch.utils.data.datapipes.iter.utils import IterableWrapperIterDataPipe
+from torch.utils.data.datapipes.utils.common import get_file_pathnames_from_root
+
+
+__all__ = ["FileListerIterDataPipe"]
+
+
+@functional_datapipe("list_files")
+class FileListerIterDataPipe(IterDataPipe[str]):
+    r"""
+    Given path(s) to the root directory, yields file pathname(s) (path + filename) of files within the root directory.
+
+    Multiple root directories can be provided (functional name: ``list_files``).
+
+    Args:
+        root: Root directory or a sequence of root directories
+        masks: Unix style filter string or string list for filtering file name(s)
+        recursive: Whether to return pathname from nested directories or not
+        abspath: Whether to return relative pathname or absolute pathname
+        non_deterministic: Whether to return pathname in sorted order or not.
+            If ``False``, the results yielded from each root directory will be sorted
+        length: Nominal length of the datapipe
+
+    Example:
+        >>> # xdoctest: +SKIP
+        >>> from torchdata.datapipes.iter import FileLister
+        >>> dp = FileLister(root=".", recursive=True)
+        >>> list(dp)
+        ['example.py', './data/data.tar']
+    """
+
+    def __init__(
+        self,
+        root: Union[str, Sequence[str], IterDataPipe] = ".",
+        masks: Union[str, list[str]] = "",
+        *,
+        recursive: bool = False,
+        abspath: bool = False,
+        non_deterministic: bool = False,
+        length: int = -1,
+    ) -> None:
+        super().__init__()
+        if isinstance(root, str):
+            root = [root]
+        if not isinstance(root, IterDataPipe):
+            root = IterableWrapperIterDataPipe(root)
+        self.datapipe: IterDataPipe = root
+        self.masks: Union[str, list[str]] = masks
+        self.recursive: bool = recursive
+        self.abspath: bool = abspath
+        self.non_deterministic: bool = non_deterministic
+        self.length: int = length
+
+    def __iter__(self) -> Iterator[str]:
+        for path in self.datapipe:
+            yield from get_file_pathnames_from_root(
+                path, self.masks, self.recursive, self.abspath, self.non_deterministic
+            )
+
+    def __len__(self) -> int:
+        if self.length == -1:
+            raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
+        return self.length
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/fileopener.py b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/fileopener.py
new file mode 100644
index 0000000000000000000000000000000000000000..2542c89773bdd956f32af325a297cf62a88cce6b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/fileopener.py
@@ -0,0 +1,77 @@
+# mypy: allow-untyped-defs
+from collections.abc import Iterable
+from io import IOBase
+from typing import Optional
+
+from torch.utils.data.datapipes._decorator import functional_datapipe
+from torch.utils.data.datapipes.datapipe import IterDataPipe
+from torch.utils.data.datapipes.utils.common import get_file_binaries_from_pathnames
+
+
+__all__ = [
+    "FileOpenerIterDataPipe",
+]
+
+
+@functional_datapipe("open_files")
+class FileOpenerIterDataPipe(IterDataPipe[tuple[str, IOBase]]):
+    r"""
+    Given pathnames, opens files and yield pathname and file stream in a tuple (functional name: ``open_files``).
+
+    Args:
+        datapipe: Iterable datapipe that provides pathnames
+        mode: An optional string that specifies the mode in which
+            the file is opened by ``open()``. It defaults to ``r``, other options are
+            ``b`` for reading in binary mode and ``t`` for text mode.
+        encoding: An optional string that specifies the encoding of the
+            underlying file. It defaults to ``None`` to match the default encoding of ``open``.
+        length: Nominal length of the datapipe
+
+    Note:
+        The opened file handles will be closed by Python's GC periodically. Users can choose
+        to close them explicitly.
+
+    Example:
+        >>> # xdoctest: +SKIP
+        >>> from torchdata.datapipes.iter import FileLister, FileOpener, StreamReader
+        >>> dp = FileLister(root=".").filter(lambda fname: fname.endswith('.txt'))
+        >>> dp = FileOpener(dp)
+        >>> dp = StreamReader(dp)
+        >>> list(dp)
+        [('./abc.txt', 'abc')]
+    """
+
+    def __init__(
+        self,
+        datapipe: Iterable[str],
+        mode: str = "r",
+        encoding: Optional[str] = None,
+        length: int = -1,
+    ):
+        super().__init__()
+        self.datapipe: Iterable = datapipe
+        self.mode: str = mode
+        self.encoding: Optional[str] = encoding
+
+        if self.mode not in ("b", "t", "rb", "rt", "r"):
+            raise ValueError(f"Invalid mode {mode}")
+        # TODO: enforce typing for each instance based on mode, otherwise
+        #       `argument_validation` with this DataPipe may be potentially broken
+
+        if "b" in mode and encoding is not None:
+            raise ValueError("binary mode doesn't take an encoding argument")
+
+        self.length: int = length
+
+    # The return annotation is omitted because 'IOBase' is a general type and the true type
+    # is determined at runtime based on mode. Some `DataPipe` requiring
+    # a subtype would cause a mypy error.
+    def __iter__(self):
+        yield from get_file_binaries_from_pathnames(
+            self.datapipe, self.mode, self.encoding
+        )
+
+    def __len__(self):
+        if self.length == -1:
+            raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
+        return self.length
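+
+# A hedged sketch of the argument validation above; the pathname is hypothetical and the
+# constructor does not touch the filesystem, so only the mode/encoding checks run here.
+#
+#   >>> dp = FileOpenerIterDataPipe(["data.bin"], mode="rb")  # valid binary mode
+#   >>> FileOpenerIterDataPipe(["data.bin"], mode="rb", encoding="utf-8")
+#   Traceback (most recent call last):
+#       ...
+#   ValueError: binary mode doesn't take an encoding argument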
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/grouping.py b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/grouping.py
new file mode 100644
index 0000000000000000000000000000000000000000..08d124fdc6087e277a274578527ad65d8142da0a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/grouping.py
@@ -0,0 +1,322 @@
+# mypy: allow-untyped-defs
+import warnings
+from collections import defaultdict
+from collections.abc import Iterator, Sized
+from typing import Any, Callable, Optional, TypeVar
+
+import torch.utils.data.datapipes.iter.sharding
+from torch.utils.data.datapipes._decorator import functional_datapipe
+from torch.utils.data.datapipes.datapipe import DataChunk, IterDataPipe
+from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
+
+
+__all__ = [
+    "BatcherIterDataPipe",
+    "GrouperIterDataPipe",
+    "UnBatcherIterDataPipe",
+]
+
+
+_T_co = TypeVar("_T_co", covariant=True)
+
+
+def __getattr__(name: str):
+    if name in ["SHARDING_PRIORITIES", "ShardingFilterIterDataPipe"]:
+        warnings.warn(
+            f"`{name}` from `torch.utils.data.datapipes.iter.grouping` is going to be removed in PyTorch 2.1. "
+            f"Please use `{name}` from `torch.utils.data.datapipes.iter.sharding` instead.",
+            category=FutureWarning,
+            stacklevel=2,
+        )
+
+        return getattr(torch.utils.data.datapipes.iter.sharding, name)
+
+    raise AttributeError(f"module {__name__} has no attribute {name}")
+
+
+@functional_datapipe("batch")
+class BatcherIterDataPipe(IterDataPipe[DataChunk]):
+    r"""
+    Creates mini-batches of data (functional name: ``batch``).
+
+    An outer dimension of size ``batch_size`` will be added; when ``drop_last`` is set to ``False``, the
+    last batch will instead have size ``length % batch_size`` if the length is not evenly divisible.
+
+    Args:
+        datapipe: Iterable DataPipe being batched
+        batch_size: The size of each batch
+        drop_last: Option to drop the last batch if it's not full
+        wrapper_class: wrapper to apply onto each batch (type ``List``) before yielding,
+            defaults to ``DataChunk``
+
+    Example:
+        >>> # xdoctest: +SKIP
+        >>> from torchdata.datapipes.iter import IterableWrapper
+        >>> dp = IterableWrapper(range(10))
+        >>> dp = dp.batch(batch_size=3, drop_last=True)
+        >>> list(dp)
+        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
+    """
+
+    datapipe: IterDataPipe
+    batch_size: int
+    drop_last: bool
+
+    def __init__(
+        self,
+        datapipe: IterDataPipe,
+        batch_size: int,
+        drop_last: bool = False,
+        wrapper_class: type[DataChunk] = DataChunk,
+    ) -> None:
+        assert batch_size > 0, "Batch size is required to be larger than 0!"
+        super().__init__()
+        self.datapipe = datapipe
+        self.batch_size = batch_size
+        self.drop_last = drop_last
+        self.wrapper_class = wrapper_class
+
+    def __iter__(self) -> Iterator[DataChunk]:
+        batch: list = []
+        for x in self.datapipe:
+            batch.append(x)
+            if len(batch) == self.batch_size:
+                yield self.wrapper_class(batch)
+                batch = []
+        if len(batch) > 0:
+            if not self.drop_last:
+                yield self.wrapper_class(batch)
+
+    def __len__(self) -> int:
+        if isinstance(self.datapipe, Sized):
+            if self.drop_last:
+                return len(self.datapipe) // self.batch_size
+            else:
+                return (len(self.datapipe) + self.batch_size - 1) // self.batch_size
+        else:
+            raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
+
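+# A worked example of the length formula above, assuming the source DataPipe is Sized:
+# ceil division keeps the short final batch unless `drop_last=True`.
+#
+#   >>> from torchdata.datapipes.iter import IterableWrapper
+#   >>> len(IterableWrapper(range(10)).batch(batch_size=3))                  # (10 + 3 - 1) // 3
+#   4
+#   >>> len(IterableWrapper(range(10)).batch(batch_size=3, drop_last=True))  # 10 // 3
+#   3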
+
+@functional_datapipe("unbatch")
+class UnBatcherIterDataPipe(IterDataPipe):
+    r"""
+    Undoes batching of data (functional name: ``unbatch``).
+
+    In other words, it flattens the data up to the specified level within a batched DataPipe.
+
+    Args:
+        datapipe: Iterable DataPipe being un-batched
+        unbatch_level: Defaults to ``1`` (only flattening the top level). If set to ``2``,
+            it will flatten the top two levels, and ``-1`` will flatten the entire DataPipe.
+
+    Example:
+        >>> # xdoctest: +SKIP
+        >>> from torchdata.datapipes.iter import IterableWrapper
+        >>> source_dp = IterableWrapper([[[0, 1], [2]], [[3, 4], [5]], [[6]]])
+        >>> dp1 = source_dp.unbatch()
+        >>> list(dp1)
+        [[0, 1], [2], [3, 4], [5], [6]]
+        >>> dp2 = source_dp.unbatch(unbatch_level=2)
+        >>> list(dp2)
+        [0, 1, 2, 3, 4, 5, 6]
+    """
+
+    def __init__(self, datapipe: IterDataPipe, unbatch_level: int = 1):
+        self.datapipe = datapipe
+        self.unbatch_level = unbatch_level
+
+    def __iter__(self):
+        for element in self.datapipe:
+            yield from self._dive(element, unbatch_level=self.unbatch_level)
+
+    def _dive(self, element, unbatch_level):
+        if unbatch_level < -1:
+            raise ValueError("unbatch_level must be -1 or >= 0")
+        if unbatch_level == -1:
+            if isinstance(element, (list, DataChunk)):
+                for item in element:
+                    yield from self._dive(item, unbatch_level=-1)
+            else:
+                yield element
+        elif unbatch_level == 0:
+            yield element
+        else:
+            if isinstance(element, (list, DataChunk)):
+                for item in element:
+                    yield from self._dive(item, unbatch_level=unbatch_level - 1)
+            else:
+                raise IndexError(
+                    f"unbatch_level {self.unbatch_level} exceeds the depth of the DataPipe"
+                )
+
+
+@functional_datapipe("groupby")
+class GrouperIterDataPipe(IterDataPipe[DataChunk]):
+    r"""
+    Groups data from IterDataPipe by keys from ``group_key_fn``, yielding a ``DataChunk`` with batch size up to ``group_size``.
+
+    (functional name: ``groupby``).
+
+    The samples are read sequentially from the source ``datapipe``, and a batch of samples belonging to the same group
+    will be yielded as soon as the size of the batch reaches ``group_size``. When the buffer is full,
+    the DataPipe will yield the largest batch with the same key, provided that its size is larger
+    than ``guaranteed_group_size``. If its size is smaller, it will be dropped if ``drop_remaining=True``.
+
+    After iterating through the entirety of source ``datapipe``, everything not dropped due to the buffer capacity
+    will be yielded from the buffer, even if the group sizes are smaller than ``guaranteed_group_size``.
+
+    Args:
+        datapipe: Iterable datapipe to be grouped
+        group_key_fn: Function used to generate group key from the data of the source datapipe
+        keep_key: Option to yield the matching key along with the items in a tuple,
+            resulting in ``(key, [items])``; otherwise only ``[items]`` is returned
+        buffer_size: The size of buffer for ungrouped data
+        group_size: The max size of each group; a batch is yielded as soon as it reaches this size
+        guaranteed_group_size: The guaranteed minimum group size to be yielded in case the buffer is full
+        drop_remaining: Specifies if the group smaller than ``guaranteed_group_size`` will be dropped from buffer
+            when the buffer is full
+
+    Example:
+        >>> import os
+        >>> # xdoctest: +SKIP
+        >>> from torchdata.datapipes.iter import IterableWrapper
+        >>> def group_fn(file):
+        ...     return os.path.basename(file).split(".")[0]
+        >>> source_dp = IterableWrapper(["a.png", "b.png", "a.json", "b.json", "a.jpg", "c.json"])
+        >>> dp0 = source_dp.groupby(group_key_fn=group_fn)
+        >>> list(dp0)
+        [['a.png', 'a.json', 'a.jpg'], ['b.png', 'b.json'], ['c.json']]
+        >>> # A group is yielded as soon as its size equals `group_size`
+        >>> dp1 = source_dp.groupby(group_key_fn=group_fn, group_size=2)
+        >>> list(dp1)
+        [['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']]
+        >>> # Scenario where `buffer` is full, and group 'a' needs to be yielded since its size >= `guaranteed_group_size`
+        >>> dp2 = source_dp.groupby(group_key_fn=group_fn, buffer_size=3, group_size=3, guaranteed_group_size=2)
+        >>> list(dp2)
+        [['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']]
+    """
+
+    def __init__(
+        self,
+        datapipe: IterDataPipe[_T_co],
+        group_key_fn: Callable[[_T_co], Any],
+        *,
+        keep_key: bool = False,
+        buffer_size: int = 10000,
+        group_size: Optional[int] = None,
+        guaranteed_group_size: Optional[int] = None,
+        drop_remaining: bool = False,
+    ):
+        _check_unpickable_fn(group_key_fn)
+        self.datapipe = datapipe
+        self.group_key_fn = group_key_fn
+
+        self.keep_key = keep_key
+        self.max_buffer_size = buffer_size
+        self.buffer_elements: defaultdict[Any, list] = defaultdict(list)
+        self.curr_buffer_size = 0
+        self.group_size = group_size
+        self.guaranteed_group_size = None
+        if group_size is not None and buffer_size is not None:
+            assert 0 < group_size <= buffer_size
+            self.guaranteed_group_size = group_size
+        if guaranteed_group_size is not None:
+            assert group_size is not None and 0 < guaranteed_group_size <= group_size
+            self.guaranteed_group_size = guaranteed_group_size
+        self.drop_remaining = drop_remaining
+        self.wrapper_class = DataChunk
+
+    def _remove_biggest_key(self):
+        biggest_key = None
+        biggest_size = 0
+        result_to_yield = None
+        for findkey in self.buffer_elements.keys():
+            if len(self.buffer_elements[findkey]) > biggest_size:
+                biggest_size = len(self.buffer_elements[findkey])
+                biggest_key = findkey
+
+        if (
+            self.guaranteed_group_size is not None
+            and biggest_size < self.guaranteed_group_size
+            and not self.drop_remaining
+        ):
+            raise RuntimeError(
+                "Failed to group items", str(self.buffer_elements[biggest_key])
+            )
+
+        if (
+            self.guaranteed_group_size is None
+            or biggest_size >= self.guaranteed_group_size
+        ):
+            result_to_yield = self.buffer_elements[biggest_key]
+
+        self.curr_buffer_size -= biggest_size
+        del self.buffer_elements[biggest_key]
+
+        return result_to_yield
+
+    def __iter__(self):
+        for x in self.datapipe:
+            key = self.group_key_fn(x)
+
+            self.buffer_elements[key].append(x)
+            self.curr_buffer_size += 1
+
+            if self.group_size is not None and self.group_size == len(
+                self.buffer_elements[key]
+            ):
+                result: DataChunk[Any] = self.wrapper_class(self.buffer_elements[key])
+                yield (key, result) if self.keep_key else result
+                self.curr_buffer_size -= len(self.buffer_elements[key])
+                del self.buffer_elements[key]
+
+            if self.curr_buffer_size == self.max_buffer_size:
+                result_to_yield = self._remove_biggest_key()
+                if result_to_yield is not None:
+                    result = self.wrapper_class(result_to_yield)
+                    yield (key, result) if self.keep_key else result
+
+        for key in tuple(self.buffer_elements.keys()):
+            result = self.wrapper_class(self.buffer_elements.pop(key))
+            self.curr_buffer_size -= len(result)
+            yield (key, result) if self.keep_key else result
+
+    def reset(self) -> None:
+        self.curr_buffer_size = 0
+        self.buffer_elements = defaultdict(list)
+
+    def __getstate__(self):
+        state = (
+            self.datapipe,
+            self.group_key_fn,
+            self.keep_key,
+            self.max_buffer_size,
+            self.group_size,
+            self.guaranteed_group_size,
+            self.drop_remaining,
+            self.wrapper_class,
+            self._valid_iterator_id,
+            self._number_of_samples_yielded,
+        )
+        if IterDataPipe.getstate_hook is not None:
+            return IterDataPipe.getstate_hook(state)
+        return state
+
+    def __setstate__(self, state):
+        (
+            self.datapipe,
+            self.group_key_fn,
+            self.keep_key,
+            self.max_buffer_size,
+            self.group_size,
+            self.guaranteed_group_size,
+            self.drop_remaining,
+            self.wrapper_class,
+            self._valid_iterator_id,
+            self._number_of_samples_yielded,
+        ) = state
+        self.curr_buffer_size = 0
+        self.buffer_elements = defaultdict(list)
+
+    def __del__(self):
+        self.buffer_elements.clear()
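+
+# A hedged sketch of the `keep_key` option described above, reusing the docstring's
+# grouping function and assuming torchdata's `IterableWrapper` is available.
+#
+#   >>> import os
+#   >>> from torchdata.datapipes.iter import IterableWrapper
+#   >>> def group_fn(file):
+#   ...     return os.path.basename(file).split(".")[0]
+#   >>> source_dp = IterableWrapper(["a.png", "b.png", "a.json", "b.json", "a.jpg", "c.json"])
+#   >>> list(source_dp.groupby(group_key_fn=group_fn, keep_key=True))
+#   [('a', ['a.png', 'a.json', 'a.jpg']), ('b', ['b.png', 'b.json']), ('c', ['c.json'])]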
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/routeddecoder.py b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/routeddecoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..611b4870a493a2bde5ba4233a8e5d78a1c21ba79
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/routeddecoder.py
@@ -0,0 +1,70 @@
+from collections.abc import Iterable, Iterator, Sized
+from io import BufferedIOBase
+from typing import Any, Callable
+
+from torch.utils.data.datapipes._decorator import functional_datapipe
+from torch.utils.data.datapipes.datapipe import IterDataPipe
+from torch.utils.data.datapipes.utils.common import _deprecation_warning
+from torch.utils.data.datapipes.utils.decoder import (
+    basichandlers as decoder_basichandlers,
+    Decoder,
+    extension_extract_fn,
+    imagehandler as decoder_imagehandler,
+)
+
+
+__all__ = ["RoutedDecoderIterDataPipe"]
+
+
+@functional_datapipe("routed_decode")
+class RoutedDecoderIterDataPipe(IterDataPipe[tuple[str, Any]]):
+    r"""
+    Decodes binary streams from input DataPipe, yields pathname and decoded data in a tuple.
+
+    (functional name: ``routed_decode``)
+
+    Args:
+        datapipe: Iterable datapipe that provides pathname and binary stream in tuples
+        handlers: Optional user defined decoder handlers. If ``None``, basic and image decoder
+            handlers will be set as default. If multiple handlers are provided, the priority
+            order follows the order of handlers (the first handler has the top priority)
+        key_fn: Function for decoder to extract key from pathname to dispatch handlers.
+            Default is set to extract file extension from pathname
+
+    Note:
+        When ``key_fn`` is specified to return anything other than the file extension, the default
+        handlers will not work and users need to specify a custom handler. A custom handler
+        could use a regex to determine whether it is eligible to handle the data.
+    """
+
+    def __init__(
+        self,
+        datapipe: Iterable[tuple[str, BufferedIOBase]],
+        *handlers: Callable,
+        key_fn: Callable = extension_extract_fn,
+    ) -> None:
+        super().__init__()
+        self.datapipe: Iterable[tuple[str, BufferedIOBase]] = datapipe
+        if not handlers:
+            handlers = (decoder_basichandlers, decoder_imagehandler("torch"))
+        self.decoder = Decoder(*handlers, key_fn=key_fn)
+        _deprecation_warning(
+            type(self).__name__,
+            deprecation_version="1.12",
+            removal_version="1.13",
+            old_functional_name="routed_decode",
+        )
+
+    def add_handler(self, *handler: Callable) -> None:
+        self.decoder.add_handler(*handler)
+
+    def __iter__(self) -> Iterator[tuple[str, Any]]:
+        for data in self.datapipe:
+            pathname = data[0]
+            result = self.decoder(data)
+            yield (pathname, result[pathname])
+
+    def __len__(self) -> int:
+        if isinstance(self.datapipe, Sized):
+            return len(self.datapipe)
+        raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/selecting.py b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/selecting.py
new file mode 100644
index 0000000000000000000000000000000000000000..97dcb2d6c49105205ab00159538b3a0c4f1ed01b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/selecting.py
@@ -0,0 +1,102 @@
+# mypy: allow-untyped-defs
+from collections.abc import Iterator
+from typing import Callable, TypeVar
+
+from torch.utils.data.datapipes._decorator import functional_datapipe
+from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper
+from torch.utils.data.datapipes.datapipe import IterDataPipe
+from torch.utils.data.datapipes.utils.common import (
+    _check_unpickable_fn,
+    StreamWrapper,
+    validate_input_col,
+)
+
+
+__all__ = ["FilterIterDataPipe"]
+
+
+_T = TypeVar("_T")
+_T_co = TypeVar("_T_co", covariant=True)
+
+
+@functional_datapipe("filter")
+class FilterIterDataPipe(IterDataPipe[_T_co]):
+    r"""
+    Filters out elements from the source datapipe according to input ``filter_fn`` (functional name: ``filter``).
+
+    Args:
+        datapipe: Iterable DataPipe being filtered
+        filter_fn: Customized function mapping an element to a boolean.
+        input_col: Index or indices of data which ``filter_fn`` is applied, such as:
+
+            - ``None`` as default to apply ``filter_fn`` to the data directly.
+            - Integer(s) is used for list/tuple.
+            - Key(s) is used for dict.
+
+    Example:
+        >>> # xdoctest: +SKIP
+        >>> from torchdata.datapipes.iter import IterableWrapper
+        >>> def is_even(n):
+        ...     return n % 2 == 0
+        >>> dp = IterableWrapper(range(5))
+        >>> filter_dp = dp.filter(filter_fn=is_even)
+        >>> list(filter_dp)
+        [0, 2, 4]
+    """
+
+    datapipe: IterDataPipe[_T_co]
+    filter_fn: Callable
+
+    def __init__(
+        self,
+        datapipe: IterDataPipe[_T_co],
+        filter_fn: Callable,
+        input_col=None,
+    ) -> None:
+        super().__init__()
+        self.datapipe = datapipe
+
+        _check_unpickable_fn(filter_fn)
+        self.filter_fn = filter_fn  # type: ignore[assignment]
+
+        self.input_col = input_col
+        validate_input_col(filter_fn, input_col)
+
+    def _apply_filter_fn(self, data) -> bool:
+        if self.input_col is None:
+            return self.filter_fn(data)
+        elif isinstance(self.input_col, (list, tuple)):
+            args = tuple(data[col] for col in self.input_col)
+            return self.filter_fn(*args)
+        else:
+            return self.filter_fn(data[self.input_col])
+
+    def __iter__(self) -> Iterator[_T_co]:
+        for data in self.datapipe:
+            condition, filtered = self._returnIfTrue(data)
+            if condition:
+                yield filtered
+            else:
+                StreamWrapper.close_streams(data)
+
+    def _returnIfTrue(self, data: _T) -> tuple[bool, _T]:
+        condition = self._apply_filter_fn(data)
+
+        if df_wrapper.is_column(condition):
+            # We are operating on DataFrames filter here
+            result = []
+            for idx, mask in enumerate(df_wrapper.iterate(condition)):
+                if mask:
+                    result.append(df_wrapper.get_item(data, idx))
+            if len(result):
+                return True, df_wrapper.concat(result)
+            else:
+                return False, None  # type: ignore[return-value]
+
+        if not isinstance(condition, bool):
+            raise ValueError(
+                "Boolean output is required for `filter_fn` of FilterIterDataPipe, got",
+                type(condition),
+            )
+
+        return condition, data
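+
+# A hedged sketch of the `input_col` option described above, assuming torchdata's
+# `IterableWrapper` is available: an integer `input_col` selects which field of each
+# element is passed to `filter_fn`, while the whole element is yielded.
+#
+#   >>> from torchdata.datapipes.iter import IterableWrapper
+#   >>> def is_even(n):
+#   ...     return n % 2 == 0
+#   >>> dp = IterableWrapper([(0, "a"), (1, "b"), (2, "c")])
+#   >>> list(dp.filter(filter_fn=is_even, input_col=0))
+#   [(0, 'a'), (2, 'c')]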
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/sharding.py b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/sharding.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e381c87a4a5826cd0e26a8c3054e4bb5c0b5798
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/sharding.py
@@ -0,0 +1,101 @@
+# mypy: allow-untyped-defs
+from collections.abc import Sized
+from enum import IntEnum
+
+from torch.utils.data.datapipes._decorator import functional_datapipe
+from torch.utils.data.datapipes.datapipe import IterDataPipe
+
+
+__all__ = [
+    "SHARDING_PRIORITIES",
+    "ShardingFilterIterDataPipe",
+]
+
+
+class SHARDING_PRIORITIES(IntEnum):
+    DEFAULT = 1
+    DISTRIBUTED = 2
+    MULTIPROCESSING = 3
+
+
+class _ShardingIterDataPipe(IterDataPipe):
+    def apply_sharding(
+        self,
+        num_of_instances: int,
+        instance_id: int,
+        sharding_group: SHARDING_PRIORITIES,
+    ):
+        raise NotImplementedError
+
+
+@functional_datapipe("sharding_filter")
+class ShardingFilterIterDataPipe(_ShardingIterDataPipe):
+    r"""
+    Wrapper that allows DataPipe to be sharded (functional name: ``sharding_filter``).
+
+    After ``apply_sharding`` is called, each instance of the DataPipe (on different workers) will have every `n`-th element of the
+    original DataPipe, where `n` equals the number of instances.
+
+    Args:
+        source_datapipe: Iterable DataPipe that will be sharded
+    """
+
+    def __init__(self, source_datapipe: IterDataPipe, sharding_group_filter=None):
+        self.source_datapipe = source_datapipe
+        self.sharding_group_filter = sharding_group_filter
+        self.groups: dict[int, tuple[int, int]] = {}
+        self.num_of_instances = 1
+        self.instance_id = 0
+        self._update_num_of_instances()
+
+    def apply_sharding(
+        self, num_of_instances, instance_id, sharding_group=SHARDING_PRIORITIES.DEFAULT
+    ):
+        if instance_id >= num_of_instances:
+            raise ValueError(
+                f"instance_id({instance_id}) should be smaller than num_of_instances({num_of_instances})"
+            )
+        if sharding_group == SHARDING_PRIORITIES.DEFAULT:
+            if len(self.groups) and SHARDING_PRIORITIES.DEFAULT not in self.groups:
+                raise RuntimeError(
+                    "ShardingFilter cannot mix DEFAULT and non DEFAULT groups"
+                )
+        else:
+            if SHARDING_PRIORITIES.DEFAULT in self.groups:
+                raise RuntimeError(
+                    "ShardingFilter cannot mix DEFAULT and non DEFAULT groups"
+                )
+        self.groups[sharding_group] = (num_of_instances, instance_id)
+        self._update_num_of_instances()
+
+    def _update_num_of_instances(self):
+        sorted_sharding_groups = [
+            self.groups[key]
+            for key in sorted(self.groups.keys())
+            if self.sharding_group_filter is None or key == self.sharding_group_filter
+        ]
+
+        sorted_sharding_groups.reverse()
+
+        self.num_of_instances = 1
+        self.instance_id = 0
+
+        for group_num_of_instances, group_instance_id in sorted_sharding_groups:
+            self.instance_id += self.num_of_instances * group_instance_id
+            self.num_of_instances *= group_num_of_instances
+
+    def __iter__(self):
+        for i, item in enumerate(self.source_datapipe):
+            if i % self.num_of_instances == self.instance_id:
+                yield item
+
+    def __len__(self):
+        if isinstance(self.source_datapipe, Sized):
+            return len(self.source_datapipe) // self.num_of_instances + (
+                1
+                if (
+                    self.instance_id < len(self.source_datapipe) % self.num_of_instances
+                )
+                else 0
+            )
+        raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
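+
+# A hedged sketch of the sharding arithmetic above, assuming torchdata's
+# `IterableWrapper` is available: after `apply_sharding(3, 0)` this instance keeps
+# every 3rd element, and `__len__` rounds up for instances that receive the remainder.
+#
+#   >>> from torchdata.datapipes.iter import IterableWrapper
+#   >>> dp = IterableWrapper(range(10)).sharding_filter()
+#   >>> dp.apply_sharding(3, 0)
+#   >>> list(dp)
+#   [0, 3, 6, 9]
+#   >>> len(dp)  # 10 // 3 + 1, because instance 0 < 10 % 3
+#   4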
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/streamreader.py b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/streamreader.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c3af4f12a81f87eee10738ca63c01ead8b06c72
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/streamreader.py
@@ -0,0 +1,46 @@
+from collections.abc import Iterator
+from io import IOBase
+from typing import Optional
+
+from torch.utils.data.datapipes._decorator import functional_datapipe
+from torch.utils.data.datapipes.datapipe import IterDataPipe
+
+
+__all__ = ["StreamReaderIterDataPipe"]
+
+
+@functional_datapipe("read_from_stream")
+class StreamReaderIterDataPipe(IterDataPipe[tuple[str, bytes]]):
+    r"""
+    Given IO streams and their label names, yield bytes with label name as tuple.
+
+    (functional name: ``read_from_stream``).
+
+    Args:
+        datapipe: Iterable DataPipe that provides label/URL and byte stream
+        chunk: Number of bytes to be read from stream per iteration.
+            If ``None``, all bytes will be read until the EOF.
+
+    Example:
+        >>> # xdoctest: +SKIP
+        >>> from torchdata.datapipes.iter import IterableWrapper, StreamReader
+        >>> from io import StringIO
+        >>> dp = IterableWrapper([("alphabet", StringIO("abcde"))])
+        >>> list(StreamReader(dp, chunk=1))
+        [('alphabet', 'a'), ('alphabet', 'b'), ('alphabet', 'c'), ('alphabet', 'd'), ('alphabet', 'e')]
+    """
+
+    def __init__(
+        self, datapipe: IterDataPipe[tuple[str, IOBase]], chunk: Optional[int] = None
+    ):
+        self.datapipe = datapipe
+        self.chunk = chunk
+
+    def __iter__(self) -> Iterator[tuple[str, bytes]]:
+        for furl, stream in self.datapipe:
+            while True:
+                d = stream.read(self.chunk)
+                if not d:
+                    stream.close()
+                    break
+                yield (furl, d)
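+
+# A hedged sketch of the ``chunk=None`` case above, mirroring the docstring example:
+# the whole stream is read in a single piece before the stream is closed.
+#
+#   >>> from torchdata.datapipes.iter import IterableWrapper, StreamReader
+#   >>> from io import StringIO
+#   >>> dp = IterableWrapper([("alphabet", StringIO("abcde"))])
+#   >>> list(StreamReader(dp))
+#   [('alphabet', 'abcde')]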
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/utils.py b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0728746af23ddd7e628dbcdd0218af44aceb1ec9
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/iter/utils.py
@@ -0,0 +1,54 @@
+# mypy: allow-untyped-defs
+import copy
+import warnings
+
+from torch.utils.data.datapipes.datapipe import IterDataPipe
+
+
+__all__ = ["IterableWrapperIterDataPipe"]
+
+
+class IterableWrapperIterDataPipe(IterDataPipe):
+    r"""
+    Wraps an iterable object to create an IterDataPipe.
+
+    Args:
+        iterable: Iterable object to be wrapped into an IterDataPipe
+        deepcopy: Option to deepcopy input iterable object for each
+            iterator. The copy is made when the first element is read in ``iter()``.
+
+    .. note::
+        If ``deepcopy`` is explicitly set to ``False``, users should ensure
+        that the data pipeline doesn't contain any in-place operations over
+        the iterable instance to prevent data inconsistency across iterations.
+
+    Example:
+        >>> # xdoctest: +SKIP
+        >>> from torchdata.datapipes.iter import IterableWrapper
+        >>> dp = IterableWrapper(range(10))
+        >>> list(dp)
+        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
+    """
+
+    def __init__(self, iterable, deepcopy=True):
+        self.iterable = iterable
+        self.deepcopy = deepcopy
+
+    def __iter__(self):
+        source_data = self.iterable
+        if self.deepcopy:
+            try:
+                source_data = copy.deepcopy(self.iterable)
+            # For the case that data cannot be deep-copied,
+            # all in-place operations will affect iterable variable.
+            # When this DataPipe is iterated second time, it will
+            # yield modified items.
+            except TypeError:
+                warnings.warn(
+                    "The input iterable cannot be deepcopied, "
+                    "please be aware that in-place modifications would affect the source data."
+                )
+        yield from source_data
+
+    def __len__(self):
+        return len(self.iterable)
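+
+# A hedged sketch of the ``deepcopy`` note above: with ``deepcopy=False`` the wrapper
+# reads from the live object, so in-place edits are visible on the next iteration.
+#
+#   >>> data = [1, 2, 3]
+#   >>> dp = IterableWrapperIterDataPipe(data, deepcopy=False)
+#   >>> data.append(4)
+#   >>> list(dp)
+#   [1, 2, 3, 4]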
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/map/__init__.py b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/map/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7fa8932dd6fcafa2d807c591755852e74d7fdc51
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/map/__init__.py
@@ -0,0 +1,19 @@
+# Functional DataPipe
+from torch.utils.data.datapipes.map.callable import MapperMapDataPipe as Mapper
+from torch.utils.data.datapipes.map.combinatorics import (
+    ShufflerIterDataPipe as Shuffler,
+)
+from torch.utils.data.datapipes.map.combining import (
+    ConcaterMapDataPipe as Concater,
+    ZipperMapDataPipe as Zipper,
+)
+from torch.utils.data.datapipes.map.grouping import BatcherMapDataPipe as Batcher
+from torch.utils.data.datapipes.map.utils import (
+    SequenceWrapperMapDataPipe as SequenceWrapper,
+)
+
+
+__all__ = ["Batcher", "Concater", "Mapper", "SequenceWrapper", "Shuffler", "Zipper"]
+
+# Please keep this list sorted
+assert __all__ == sorted(__all__)
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/map/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/map/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..35414bc369edcb8a96e02dc034ae62201cbd313d
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/map/__pycache__/__init__.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/map/__pycache__/callable.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/map/__pycache__/callable.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a0279ef8aeb67553a10e0842bb66866ce8f33e9d
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/map/__pycache__/callable.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/map/__pycache__/combinatorics.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/map/__pycache__/combinatorics.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f61cb973ed22e87b9ae9a0b09435eca148b1a1ee
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/map/__pycache__/combinatorics.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/map/__pycache__/combining.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/map/__pycache__/combining.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1d56f7a2fabdd51f47fd0acd51b97bb8c82b39e5
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/map/__pycache__/combining.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/map/__pycache__/grouping.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/map/__pycache__/grouping.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0061dd86a1c6388f5fd8d318cef1d42b33982625
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/map/__pycache__/grouping.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/map/__pycache__/utils.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/map/__pycache__/utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..222655f19b283b0ae5b5ea36d4a2336d0427aa0f
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/map/__pycache__/utils.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/map/callable.py b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/map/callable.py
new file mode 100644
index 0000000000000000000000000000000000000000..cee08b7a8c8d1578befd1fb5c75647293a9c3db6
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/map/callable.py
@@ -0,0 +1,65 @@
+# mypy: allow-untyped-defs
+from typing import Callable, TypeVar
+
+from torch.utils.data.datapipes._decorator import functional_datapipe
+from torch.utils.data.datapipes.datapipe import MapDataPipe
+from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
+
+
+__all__ = ["MapperMapDataPipe", "default_fn"]
+
+
+_T_co = TypeVar("_T_co", covariant=True)
+
+
+# Default function to return each item directly
+# In order to keep datapipe picklable, eliminates the usage
+# of python lambda function
+def default_fn(data):
+    return data
+
+
+@functional_datapipe("map")
+class MapperMapDataPipe(MapDataPipe[_T_co]):
+    r"""
+    Apply the input function over each item from the source DataPipe (functional name: ``map``).
+
+    The function can be any regular Python function or partial object. Lambda
+    function is not recommended as it is not supported by pickle.
+
+    Args:
+        datapipe: Source MapDataPipe
+        fn: Function being applied to each item
+
+    Example:
+        >>> # xdoctest: +SKIP
+        >>> from torchdata.datapipes.map import SequenceWrapper, Mapper
+        >>> def add_one(x):
+        ...     return x + 1
+        >>> dp = SequenceWrapper(range(10))
+        >>> map_dp_1 = dp.map(add_one)
+        >>> list(map_dp_1)
+        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+        >>> map_dp_2 = Mapper(dp, lambda x: x + 1)
+        >>> list(map_dp_2)
+        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+    """
+
+    datapipe: MapDataPipe
+    fn: Callable
+
+    def __init__(
+        self,
+        datapipe: MapDataPipe,
+        fn: Callable = default_fn,
+    ) -> None:
+        super().__init__()
+        self.datapipe = datapipe
+        _check_unpickable_fn(fn)
+        self.fn = fn  # type: ignore[assignment]
+
+    def __len__(self) -> int:
+        return len(self.datapipe)
+
+    def __getitem__(self, index) -> _T_co:
+        return self.fn(self.datapipe[index])
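Since `_check_unpickable_fn` (called in the constructor above) warns on lambdas and local functions, a `functools.partial` over a module-level function is the picklable alternative; a minimal sketch, assuming `scale` is defined at module level:

import functools
import pickle

from torch.utils.data.datapipes.map import Mapper, SequenceWrapper

def scale(x, factor):
    return x * factor

# partial keeps the pipe picklable, unlike `lambda x: x * 10`
dp = Mapper(SequenceWrapper(range(4)), functools.partial(scale, factor=10))
assert list(dp) == [0, 10, 20, 30]
assert list(pickle.loads(pickle.dumps(dp))) == [0, 10, 20, 30]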
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/map/combinatorics.py b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/map/combinatorics.py
new file mode 100644
index 0000000000000000000000000000000000000000..619d0e5c7a0e8a13f40f16f3a219e56ce7adfd31
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/map/combinatorics.py
@@ -0,0 +1,130 @@
+# mypy: allow-untyped-defs
+import random
+from collections.abc import Iterator
+from typing import Optional, TypeVar
+
+import torch
+from torch.utils.data.datapipes.datapipe import IterDataPipe, MapDataPipe
+
+
+__all__ = ["ShufflerIterDataPipe"]
+
+
+_T_co = TypeVar("_T_co", covariant=True)
+
+
+# @functional_datapipe('shuffle')
+class ShufflerIterDataPipe(IterDataPipe[_T_co]):
+    r"""
+    Shuffle the input MapDataPipe via its indices (functional name: ``shuffle``).
+
+    When it is used with :class:`~torch.utils.data.DataLoader`, the methods to
+    set up random seed are different based on :attr:`num_workers`.
+
+    For single-process mode (:attr:`num_workers == 0`), the random seed is set before
+    the :class:`~torch.utils.data.DataLoader` in the main process. For multi-process
+    mode (:attr:`num_workers > 0`), ``worker_init_fn`` is used to set up a random seed
+    for each worker process.
+
+    Args:
+        datapipe: MapDataPipe being shuffled
+        indices: a list of indices of the MapDataPipe. If not provided, we assume it uses 0-based indexing
+
+    Example:
+        >>> # xdoctest: +SKIP
+        >>> from torchdata.datapipes.map import SequenceWrapper
+        >>> dp = SequenceWrapper(range(10))
+        >>> shuffle_dp = dp.shuffle().set_seed(0)
+        >>> list(shuffle_dp)
+        [7, 8, 1, 5, 3, 4, 2, 0, 9, 6]
+        >>> list(shuffle_dp)
+        [6, 1, 9, 5, 2, 4, 7, 3, 8, 0]
+        >>> # Reset seed for Shuffler
+        >>> shuffle_dp = shuffle_dp.set_seed(0)
+        >>> list(shuffle_dp)
+        [7, 8, 1, 5, 3, 4, 2, 0, 9, 6]
+
+    Note:
+        Even though this ``shuffle`` operation takes a ``MapDataPipe`` as input, it returns an
+        ``IterDataPipe`` rather than a ``MapDataPipe``: a ``MapDataPipe`` should be insensitive to
+        the order of its data for the sake of random reads, whereas an ``IterDataPipe`` depends on
+        the order of data during processing.
+    """
+
+    datapipe: MapDataPipe[_T_co]
+    _enabled: bool
+    _seed: Optional[int]
+    _rng: random.Random
+
+    def __init__(
+        self,
+        datapipe: MapDataPipe[_T_co],
+        *,
+        indices: Optional[list] = None,
+    ) -> None:
+        super().__init__()
+        self.datapipe = datapipe
+        self.indices = list(range(len(datapipe))) if indices is None else indices
+        self._enabled = True
+        self._seed = None
+        self._rng = random.Random()
+        self._shuffled_indices: list = self.indices
+
+    def set_shuffle(self, shuffle=True):
+        self._enabled = shuffle
+        return self
+
+    def set_seed(self, seed: int):
+        self._seed = seed
+        return self
+
+    def __iter__(self) -> Iterator[_T_co]:
+        if not self._enabled:
+            for idx in self.indices:
+                yield self.datapipe[idx]
+        else:
+            while self._shuffled_indices:
+                idx = self._shuffled_indices.pop()
+                yield self.datapipe[idx]
+
+    def reset(self) -> None:
+        if self._enabled and self._seed is None:
+            self._seed = int(torch.empty((), dtype=torch.int64).random_().item())
+        self._rng.seed(self._seed)
+        self._seed = None
+        self._shuffled_indices = self._rng.sample(self.indices, len(self.indices))
+
+    def __len__(self) -> int:
+        return len(self.datapipe)
+
+    def __getstate__(self):
+        state = (
+            self.datapipe,
+            self.indices,
+            self._enabled,
+            self._seed,
+            self._rng.getstate(),
+            self._shuffled_indices,
+            self._valid_iterator_id,
+            self._number_of_samples_yielded,
+        )
+        if IterDataPipe.getstate_hook is not None:
+            return IterDataPipe.getstate_hook(state)
+        return state
+
+    def __setstate__(self, state):
+        (
+            self.datapipe,
+            self.indices,
+            self._enabled,
+            self._seed,
+            rng_state,
+            self._shuffled_indices,
+            self._valid_iterator_id,
+            self._number_of_samples_yielded,
+        ) = state
+        self._rng = random.Random()
+        self._rng.setstate(rng_state)
+
+
+MapDataPipe.register_datapipe_as_function("shuffle", ShufflerIterDataPipe)
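A short sketch of the seeding behavior documented above: `shuffle` on a map-style pipe yields an `IterDataPipe`, and re-applying `set_seed` reproduces the permutation:

from torch.utils.data.datapipes.map import SequenceWrapper

shuffle_dp = SequenceWrapper(range(10)).shuffle().set_seed(0)
first_pass = list(shuffle_dp)                 # a seeded permutation of 0..9
assert sorted(first_pass) == list(range(10))
shuffle_dp = shuffle_dp.set_seed(0)           # reset the seed before re-iterating
assert list(shuffle_dp) == first_pass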
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/map/combining.py b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/map/combining.py
new file mode 100644
index 0000000000000000000000000000000000000000..97f9ef142a7c2822d24145cc93fbb75f7ef5a71d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/map/combining.py
@@ -0,0 +1,105 @@
+# mypy: allow-untyped-defs
+from collections.abc import Sized
+from typing import TypeVar
+
+from torch.utils.data.datapipes._decorator import functional_datapipe
+from torch.utils.data.datapipes.datapipe import MapDataPipe
+
+
+__all__ = ["ConcaterMapDataPipe", "ZipperMapDataPipe"]
+
+_T_co = TypeVar("_T_co", covariant=True)
+
+
+@functional_datapipe("concat")
+class ConcaterMapDataPipe(MapDataPipe):
+    r"""
+    Concatenate multiple Map DataPipes (functional name: ``concat``).
+
+    The new index is determined by the cumulative sum of the lengths of the
+    source DataPipes. For example, if there are 2 source DataPipes both with
+    length 5, indices 0 to 4 of the resulting ``ConcaterMapDataPipe`` refer to
+    elements of the first DataPipe, and indices 5 to 9 refer to elements of
+    the second DataPipe.
+
+    Args:
+        datapipes: Map DataPipes being concatenated
+
+    Example:
+        >>> # xdoctest: +SKIP
+        >>> from torchdata.datapipes.map import SequenceWrapper
+        >>> dp1 = SequenceWrapper(range(3))
+        >>> dp2 = SequenceWrapper(range(3))
+        >>> concat_dp = dp1.concat(dp2)
+        >>> list(concat_dp)
+        [0, 1, 2, 0, 1, 2]
+    """
+
+    datapipes: tuple[MapDataPipe]
+
+    def __init__(self, *datapipes: MapDataPipe):
+        if len(datapipes) == 0:
+            raise ValueError("Expected at least one DataPipe, but got nothing")
+        if not all(isinstance(dp, MapDataPipe) for dp in datapipes):
+            raise TypeError("Expected all inputs to be `MapDataPipe`")
+        if not all(isinstance(dp, Sized) for dp in datapipes):
+            raise TypeError("Expected all inputs to be `Sized`")
+        self.datapipes = datapipes  # type: ignore[assignment]
+
+    def __getitem__(self, index) -> _T_co:  # type: ignore[type-var]
+        offset = 0
+        for dp in self.datapipes:
+            if index - offset < len(dp):
+                return dp[index - offset]
+            else:
+                offset += len(dp)
+        raise IndexError(f"Index {index} is out of range.")
+
+    def __len__(self) -> int:
+        return sum(len(dp) for dp in self.datapipes)
+
+
+@functional_datapipe("zip")
+class ZipperMapDataPipe(MapDataPipe[tuple[_T_co, ...]]):
+    r"""
+    Aggregates elements into a tuple from each of the input DataPipes (functional name: ``zip``).
+
+    This MapDataPipe is out of bounds as soon as the shortest input DataPipe is exhausted.
+
+    Args:
+        *datapipes: Map DataPipes being aggregated
+
+    Example:
+        >>> # xdoctest: +SKIP
+        >>> from torchdata.datapipes.map import SequenceWrapper
+        >>> dp1 = SequenceWrapper(range(3))
+        >>> dp2 = SequenceWrapper(range(10, 13))
+        >>> zip_dp = dp1.zip(dp2)
+        >>> list(zip_dp)
+        [(0, 10), (1, 11), (2, 12)]
+    """
+
+    datapipes: tuple[MapDataPipe[_T_co], ...]
+
+    def __init__(self, *datapipes: MapDataPipe[_T_co]) -> None:
+        if len(datapipes) == 0:
+            raise ValueError("Expected at least one DataPipe, but got nothing")
+        if not all(isinstance(dp, MapDataPipe) for dp in datapipes):
+            raise TypeError("Expected all inputs to be `MapDataPipe`")
+        if not all(isinstance(dp, Sized) for dp in datapipes):
+            raise TypeError("Expected all inputs to be `Sized`")
+        self.datapipes = datapipes
+
+    def __getitem__(self, index) -> tuple[_T_co, ...]:
+        res = []
+        for dp in self.datapipes:
+            try:
+                res.append(dp[index])
+            except IndexError as e:
+                raise IndexError(
+                    f"Index {index} is out of range for one of the input MapDataPipes {dp}."
+                ) from e
+        return tuple(res)
+
+    def __len__(self) -> int:
+        return min(len(dp) for dp in self.datapipes)
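A sketch of the index arithmetic in `ConcaterMapDataPipe.__getitem__`: indices below `len(dp1)` resolve into the first pipe, and the remainder are offset into the second:

from torch.utils.data.datapipes.map import SequenceWrapper

dp1 = SequenceWrapper([10, 11, 12])
dp2 = SequenceWrapper([20, 21])
cat = dp1.concat(dp2)

assert len(cat) == 5
assert cat[2] == 12  # still inside dp1
assert cat[3] == 20  # 3 - len(dp1) == 0, the first element of dp2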
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/map/grouping.py b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/map/grouping.py
new file mode 100644
index 0000000000000000000000000000000000000000..e77f96730e5adb521466c782ce49deba05316c59
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/map/grouping.py
@@ -0,0 +1,74 @@
+# mypy: allow-untyped-defs
+from collections.abc import Sized
+from typing import TypeVar
+
+from torch.utils.data.datapipes._decorator import functional_datapipe
+from torch.utils.data.datapipes.datapipe import DataChunk, MapDataPipe
+
+
+__all__ = ["BatcherMapDataPipe"]
+
+
+_T = TypeVar("_T")
+
+
+@functional_datapipe("batch")
+class BatcherMapDataPipe(MapDataPipe[DataChunk]):
+    r"""
+    Create mini-batches of data (functional name: ``batch``).
+
+    An outer dimension of size ``batch_size`` will be added to each batch; if ``drop_last`` is
+    set to ``False``, the last batch may instead have size ``length % batch_size``.
+
+    Args:
+        datapipe: Map DataPipe being batched
+        batch_size: The size of each batch
+        drop_last: Option to drop the last batch if it's not full
+
+    Example:
+        >>> # xdoctest: +SKIP
+        >>> from torchdata.datapipes.map import SequenceWrapper
+        >>> dp = SequenceWrapper(range(10))
+        >>> batch_dp = dp.batch(batch_size=2)
+        >>> list(batch_dp)
+        [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
+    """
+
+    datapipe: MapDataPipe
+    batch_size: int
+    drop_last: bool
+
+    def __init__(
+        self,
+        datapipe: MapDataPipe[_T],
+        batch_size: int,
+        drop_last: bool = False,
+        wrapper_class: type[DataChunk] = DataChunk,
+    ) -> None:
+        assert batch_size > 0, "Batch size is required to be larger than 0!"
+        super().__init__()
+        self.datapipe = datapipe
+        self.batch_size = batch_size
+        self.drop_last = drop_last
+        self.wrapper_class = wrapper_class
+
+    def __getitem__(self, index) -> DataChunk:
+        batch: list = []
+        indices = range(index * self.batch_size, (index + 1) * self.batch_size)
+        try:
+            batch.extend(self.datapipe[i] for i in indices)
+            return self.wrapper_class(batch)
+        except IndexError as e:
+            if not self.drop_last and len(batch) > 0:
+                return self.wrapper_class(batch)
+            else:
+                raise IndexError(f"Index {index} is out of bound.") from e
+
+    def __len__(self) -> int:
+        if isinstance(self.datapipe, Sized):
+            if self.drop_last:
+                return len(self.datapipe) // self.batch_size
+            else:
+                return (len(self.datapipe) + self.batch_size - 1) // self.batch_size
+        else:
+            raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/map/utils.py b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/map/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..38e8b8ff56fe440f341f1b1d402f5aaf5446c445
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/map/utils.py
@@ -0,0 +1,53 @@
+# mypy: allow-untyped-defs
+import copy
+import warnings
+
+from torch.utils.data.datapipes.datapipe import MapDataPipe
+
+
+__all__ = ["SequenceWrapperMapDataPipe"]
+
+
+class SequenceWrapperMapDataPipe(MapDataPipe):
+    r"""
+    Wraps a sequence object into a MapDataPipe.
+
+    Args:
+        sequence: Sequence object to be wrapped into a MapDataPipe
+        deepcopy: Option to deepcopy input sequence object
+
+    .. note::
+      If ``deepcopy`` is explicitly set to ``False``, users should ensure that
+      the data pipeline doesn't contain any in-place operations on the wrapped
+      sequence, in order to prevent data inconsistency across iterations.
+
+    Example:
+        >>> # xdoctest: +SKIP
+        >>> from torchdata.datapipes.map import SequenceWrapper
+        >>> dp = SequenceWrapper(range(10))
+        >>> list(dp)
+        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
+        >>> dp = SequenceWrapper({'a': 100, 'b': 200, 'c': 300, 'd': 400})
+        >>> dp['a']
+        100
+    """
+
+    def __init__(self, sequence, deepcopy=True):
+        if deepcopy:
+            try:
+                self.sequence = copy.deepcopy(sequence)
+            except TypeError:
+                warnings.warn(
+                    "The input sequence cannot be deepcopied; "
+                    "be aware that in-place modifications will affect the source data"
+                )
+                self.sequence = sequence
+        else:
+            self.sequence = sequence
+
+    def __getitem__(self, index):
+        return self.sequence[index]
+
+    def __len__(self):
+        return len(self.sequence)
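A sketch of the `deepcopy` note above: with the default `deepcopy=True`, later in-place edits to the source sequence do not leak into the pipe:

from torch.utils.data.datapipes.map import SequenceWrapper

src = [0, 1, 2]
dp_copy = SequenceWrapper(src)                   # deepcopied snapshot
dp_alias = SequenceWrapper(src, deepcopy=False)  # shares the underlying list
src[0] = 99

assert dp_copy[0] == 0     # isolated from the mutation
assert dp_alias[0] == 99   # sees the mutation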
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/utils/__init__.py b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/utils/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/utils/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..06ac54b0eb7da2e5843cd17c5cbecaa6177217d8
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/utils/__pycache__/__init__.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/utils/__pycache__/common.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/utils/__pycache__/common.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dcd639d83b99e1ff396b5fca947a3e692b8876ad
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/utils/__pycache__/common.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/utils/__pycache__/decoder.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/utils/__pycache__/decoder.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..32c62ac3234018c1abb9d7cfc15f7afc7e1d1c23
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/utils/__pycache__/decoder.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/utils/common.py b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/utils/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..ddf3eecdd949fac30eb9f2cecb4a22f39ce005df
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/utils/common.py
@@ -0,0 +1,412 @@
+# mypy: allow-untyped-defs
+import fnmatch
+import functools
+import inspect
+import os
+import warnings
+from collections.abc import Iterable
+from io import IOBase
+from typing import Any, Callable, Optional, Union
+
+from torch.utils._import_utils import dill_available
+
+
+__all__ = [
+    "validate_input_col",
+    "StreamWrapper",
+    "get_file_binaries_from_pathnames",
+    "get_file_pathnames_from_root",
+    "match_masks",
+    "validate_pathname_binary_tuple",
+]
+
+
+# BC for torchdata
+DILL_AVAILABLE = dill_available()
+
+
+def validate_input_col(fn: Callable, input_col: Optional[Union[int, tuple, list]]):
+    """
+    Check that function used in a callable datapipe works with the input column.
+
+    This simply ensures that the number of positional arguments matches the size
+    of the input column. The function must not contain any non-default
+    keyword-only arguments.
+
+    Examples:
+        >>> # xdoctest: +SKIP("Failing on some CI machines")
+        >>> def f(a, b, *, c=1):
+        ...     return a + b + c
+        >>> def f_def(a, b=1, *, c=1):
+        ...     return a + b + c
+        >>> validate_input_col(f, [1, 2])  # no exception: compatible
+        >>> validate_input_col(f_def, 1)
+        >>> validate_input_col(f_def, [1, 2])
+
+    Notes:
+        If the function contains variable positional (`inspect.VAR_POSITIONAL`) arguments,
+        for example, f(a, *args), the validator will accept any size of input column
+        greater than or equal to the number of positional arguments.
+        (in this case, 1).
+
+    Args:
+        fn: The function to check.
+        input_col: The input column to check.
+
+    Raises:
+        ValueError: If the function is not compatible with the input column.
+    """
+    try:
+        sig = inspect.signature(fn)
+    except (
+        ValueError
+    ):  # Signature cannot be inspected, likely it is a built-in fn or written in C
+        return
+    if isinstance(input_col, (list, tuple)):
+        input_col_size = len(input_col)
+    else:
+        input_col_size = 1
+
+    pos = []
+    var_positional = False
+    non_default_kw_only = []
+
+    for p in sig.parameters.values():
+        if p.kind in (
+            inspect.Parameter.POSITIONAL_ONLY,
+            inspect.Parameter.POSITIONAL_OR_KEYWORD,
+        ):
+            pos.append(p)
+        elif p.kind is inspect.Parameter.VAR_POSITIONAL:
+            var_positional = True
+        elif p.kind is inspect.Parameter.KEYWORD_ONLY:
+            if p.default is p.empty:
+                non_default_kw_only.append(p)
+        else:
+            continue
+
+    if isinstance(fn, functools.partial):
+        fn_name = getattr(fn.func, "__name__", repr(fn.func))
+    else:
+        fn_name = getattr(fn, "__name__", repr(fn))
+
+    if len(non_default_kw_only) > 0:
+        raise ValueError(
+            f"The function {fn_name} takes {len(non_default_kw_only)} "
+            f"non-default keyword-only parameters, which is not allowed."
+        )
+
+    if len(sig.parameters) < input_col_size:
+        if not var_positional:
+            raise ValueError(
+                f"The function {fn_name} takes {len(sig.parameters)} "
+                f"parameters, but {input_col_size} are required."
+            )
+    else:
+        if len(pos) > input_col_size:
+            if any(p.default is p.empty for p in pos[input_col_size:]):
+                raise ValueError(
+                    f"The function {fn_name} takes {len(pos)} "
+                    f"positional parameters, but {input_col_size} are required."
+                )
+        elif len(pos) < input_col_size:
+            if not var_positional:
+                raise ValueError(
+                    f"The function {fn_name} takes {len(pos)} "
+                    f"positional parameters, but {input_col_size} are required."
+                )
+
+
+def _is_local_fn(fn):
+    # Functions or Methods
+    if hasattr(fn, "__code__"):
+        return fn.__code__.co_flags & inspect.CO_NESTED
+    # Callable Objects
+    else:
+        if hasattr(fn, "__qualname__"):
+            return "" in fn.__qualname__
+        fn_type = type(fn)
+        if hasattr(fn_type, "__qualname__"):
+            return "" in fn_type.__qualname__
+    return False
+
+
+def _check_unpickable_fn(fn: Callable):
+    """
+    Check whether a function is picklable.
+
+    If it is a lambda or local function, a UserWarning is raised. If it is not callable at all, a TypeError is raised.
+    """
+    if not callable(fn):
+        raise TypeError(f"A callable function is expected, but {type(fn)} is provided.")
+
+    # Extract function from partial object
+    # Nested partial function is automatically expanded as a single partial object
+    if isinstance(fn, functools.partial):
+        fn = fn.func
+
+    # Local function
+    if _is_local_fn(fn) and not dill_available():
+        warnings.warn(
+            "Local function is not supported by pickle, please use "
+            "regular python function or functools.partial instead."
+        )
+        return
+
+    # Lambda function
+    if hasattr(fn, "__name__") and fn.__name__ == "" and not dill_available():
+        warnings.warn(
+            "Lambda function is not supported by pickle, please use "
+            "regular python function or functools.partial instead."
+        )
+        return
+
+
+def match_masks(name: str, masks: Union[str, list[str]]) -> bool:
+    # empty mask matches any input name
+    if not masks:
+        return True
+
+    if isinstance(masks, str):
+        return fnmatch.fnmatch(name, masks)
+
+    for mask in masks:
+        if fnmatch.fnmatch(name, mask):
+            return True
+    return False
+
+
+def get_file_pathnames_from_root(
+    root: str,
+    masks: Union[str, list[str]],
+    recursive: bool = False,
+    abspath: bool = False,
+    non_deterministic: bool = False,
+) -> Iterable[str]:
+    # print out an error message and raise the error out
+    def onerror(err: OSError):
+        warnings.warn(err.filename + " : " + err.strerror)
+        raise err
+
+    if os.path.isfile(root):
+        path = root
+        if abspath:
+            path = os.path.abspath(path)
+        fname = os.path.basename(path)
+        if match_masks(fname, masks):
+            yield path
+    else:
+        for path, dirs, files in os.walk(root, onerror=onerror):
+            if abspath:
+                path = os.path.abspath(path)
+            if not non_deterministic:
+                files.sort()
+            for f in files:
+                if match_masks(f, masks):
+                    yield os.path.join(path, f)
+            if not recursive:
+                break
+            if not non_deterministic:
+                # Note that this is in-place modifying the internal list from `os.walk`
+                # This only works because `os.walk` doesn't shallow-copy the list before returning
+                # https://github.com/python/cpython/blob/f4c03484da59049eb62a9bf7777b963e2267d187/Lib/os.py#L407
+                dirs.sort()
+
+
+def get_file_binaries_from_pathnames(
+    pathnames: Iterable, mode: str, encoding: Optional[str] = None
+):
+    if not isinstance(pathnames, Iterable):
+        pathnames = [
+            pathnames,
+        ]
+
+    if mode in ("b", "t"):
+        mode = "r" + mode
+
+    for pathname in pathnames:
+        if not isinstance(pathname, str):
+            raise TypeError(
+                f"Expected string type for pathname, but got {type(pathname)}"
+            )
+        yield pathname, StreamWrapper(open(pathname, mode, encoding=encoding))
+
+
+def validate_pathname_binary_tuple(data: tuple[str, IOBase]):
+    if not isinstance(data, tuple):
+        raise TypeError(
+            f"pathname binary data should be tuple type, but it is type {type(data)}"
+        )
+    if len(data) != 2:
+        raise TypeError(
+            f"pathname binary stream tuple length should be 2, but got {len(data)}"
+        )
+    if not isinstance(data[0], str):
+        raise TypeError(
+            f"pathname within the tuple should have string type pathname, but it is type {type(data[0])}"
+        )
+    if not isinstance(data[1], IOBase) and not isinstance(data[1], StreamWrapper):
+        raise TypeError(
+            f"binary stream within the tuple should have IOBase or"
+            f"its subclasses as type, but it is type {type(data[1])}"
+        )
+
+
+# Deprecated function names and its corresponding DataPipe type and kwargs for the `_deprecation_warning` function
+_iter_deprecated_functional_names: dict[str, dict] = {}
+_map_deprecated_functional_names: dict[str, dict] = {}
+
+
+def _deprecation_warning(
+    old_class_name: str,
+    *,
+    deprecation_version: str,
+    removal_version: str,
+    old_functional_name: str = "",
+    old_argument_name: str = "",
+    new_class_name: str = "",
+    new_functional_name: str = "",
+    new_argument_name: str = "",
+    deprecate_functional_name_only: bool = False,
+) -> None:
+    if new_functional_name and not old_functional_name:
+        raise ValueError(
+            "Old functional API needs to be specified for the deprecation warning."
+        )
+    if new_argument_name and not old_argument_name:
+        raise ValueError(
+            "Old argument name needs to be specified for the deprecation warning."
+        )
+
+    if old_functional_name and old_argument_name:
+        raise ValueError(
+            "Deprecating warning for functional API and argument should be separated."
+        )
+
+    msg = f"`{old_class_name}()`"
+    if deprecate_functional_name_only and old_functional_name:
+        msg = f"{msg}'s functional API `.{old_functional_name}()` is"
+    elif old_functional_name:
+        msg = f"{msg} and its functional API `.{old_functional_name}()` are"
+    elif old_argument_name:
+        msg = f"The argument `{old_argument_name}` of {msg} is"
+    else:
+        msg = f"{msg} is"
+    msg = (
+        f"{msg} deprecated since {deprecation_version} and will be removed in {removal_version}."
+        f"\nSee https://github.com/pytorch/data/issues/163 for details."
+    )
+
+    if new_class_name or new_functional_name:
+        msg = f"{msg}\nPlease use"
+        if new_class_name:
+            msg = f"{msg} `{new_class_name}()`"
+        if new_class_name and new_functional_name:
+            msg = f"{msg} or"
+        if new_functional_name:
+            msg = f"{msg} `.{new_functional_name}()`"
+        msg = f"{msg} instead."
+
+    if new_argument_name:
+        msg = f"{msg}\nPlease use `{old_class_name}({new_argument_name}=)` instead."
+
+    warnings.warn(msg, FutureWarning)
+
+
+class StreamWrapper:
+    """
+    StreamWrapper is introduced to wrap the file handle generated by a DataPipe operation like `FileOpener`.
+
+    StreamWrapper guarantees that the wrapped file handle is closed when it goes out of scope.
+    """
+
+    session_streams: dict[Any, int] = {}
+    debug_unclosed_streams: bool = False
+
+    def __init__(self, file_obj, parent_stream=None, name=None):
+        self.file_obj = file_obj
+        self.child_counter = 0
+        self.parent_stream = parent_stream
+        self.close_on_last_child = False
+        self.name = name
+        self.closed = False
+        if parent_stream is not None:
+            if not isinstance(parent_stream, StreamWrapper):
+                raise RuntimeError(
+                    f"Parent stream should be StreamWrapper, {type(parent_stream)} was given"
+                )
+            parent_stream.child_counter += 1
+            self.parent_stream = parent_stream
+        if StreamWrapper.debug_unclosed_streams:
+            StreamWrapper.session_streams[self] = 1
+
+    @classmethod
+    def close_streams(cls, v, depth=0):
+        """Traverse structure and attempts to close all found StreamWrappers on best effort basis."""
+        if depth > 10:
+            return
+        if isinstance(v, StreamWrapper):
+            v.close()
+        else:
+            # Traverse only simple structures
+            if isinstance(v, dict):
+                for vv in v.values():
+                    cls.close_streams(vv, depth=depth + 1)
+            elif isinstance(v, (list, tuple)):
+                for vv in v:
+                    cls.close_streams(vv, depth=depth + 1)
+
+    def __getattr__(self, name):
+        file_obj = self.__dict__["file_obj"]
+        return getattr(file_obj, name)
+
+    def close(self, *args, **kwargs):
+        if self.closed:
+            return
+        if StreamWrapper.debug_unclosed_streams:
+            del StreamWrapper.session_streams[self]
+        if hasattr(self, "parent_stream") and self.parent_stream is not None:
+            self.parent_stream.child_counter -= 1
+            if (
+                not self.parent_stream.child_counter
+                and self.parent_stream.close_on_last_child
+            ):
+                self.parent_stream.close()
+        try:
+            self.file_obj.close(*args, **kwargs)
+        except AttributeError:
+            pass
+        self.closed = True
+
+    def autoclose(self):
+        """Automatically close stream when all child streams are closed or if there are none."""
+        self.close_on_last_child = True
+        if self.child_counter == 0:
+            self.close()
+
+    def __dir__(self):
+        attrs = list(self.__dict__.keys()) + list(StreamWrapper.__dict__.keys())
+        attrs += dir(self.file_obj)
+        return list(set(attrs))
+
+    def __del__(self):
+        if not self.closed:
+            self.close()
+
+    def __iter__(self):
+        yield from self.file_obj
+
+    def __next__(self):
+        return next(self.file_obj)
+
+    def __repr__(self):
+        if self.name is None:
+            return f"StreamWrapper<{self.file_obj!r}>"
+        else:
+            return f"StreamWrapper<{self.name},{self.file_obj!r}>"
+
+    def __getstate__(self):
+        return self.file_obj
+
+    def __setstate__(self, obj):
+        self.file_obj = obj
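A minimal sketch of `StreamWrapper` delegation: attribute access is proxied to the wrapped file object via `__getattr__`, and `close` is idempotent:

import io

from torch.utils.data.datapipes.utils.common import StreamWrapper

stream = StreamWrapper(io.BytesIO(b"hello\nworld\n"), name="demo")
assert stream.read() == b"hello\nworld\n"  # `read` is delegated to the BytesIO
stream.close()
stream.close()                             # second close is a no-op
assert stream.closed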
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/utils/decoder.py b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/utils/decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee5bee8f15280542e245ff91c297d15282a1942b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/utils/decoder.py
@@ -0,0 +1,378 @@
+# mypy: allow-untyped-defs
+# This file takes partial of the implementation from NVIDIA's webdataset at here:
+# https://github.com/tmbdev/webdataset/blob/master/webdataset/autodecode.py
+
+import io
+import json
+import os.path
+import pickle
+import tempfile
+
+import torch
+from torch.utils.data.datapipes.utils.common import StreamWrapper
+
+
+__all__ = [
+    "Decoder",
+    "ImageHandler",
+    "MatHandler",
+    "audiohandler",
+    "basichandlers",
+    "extension_extract_fn",
+    "handle_extension",
+    "imagehandler",
+    "mathandler",
+    "videohandler",
+]
+
+
+################################################################
+# handle basic datatypes
+################################################################
+def basichandlers(extension: str, data):
+    """Transforms raw data (byte stream) into python objects.
+
+    Looks at the extension and loads the data into a python object supporting
+    the corresponding extension.
+
+    Args:
+        extension (str): The file extension
+        data (byte stream): Data to load into a python object.
+
+    Returns:
+        object: The data loaded into a corresponding python object
+            supporting the extension.
+
+    Example:
+        >>> import pickle
+        >>> data = pickle.dumps('some data')
+        >>> new_data = basichandlers('pickle', data)
+        >>> new_data
+        'some data'
+
+    The transformation of data for extensions are:
+        - txt, text, transcript: utf-8 decoded data of str format
+        - cls, cls2, class, count, index, inx, id: int
+        - json, jsn: json loaded data
+        - pickle, pyd: pickle loaded data
+        - pt: torch loaded data
+    """
+
+    if extension in "txt text transcript":
+        return data.decode("utf-8")
+
+    if extension in "cls cls2 class count index inx id".split():
+        try:
+            return int(data)
+        except ValueError:
+            return None
+
+    if extension in "json jsn":
+        return json.loads(data)
+
+    if extension in "pyd pickle".split():
+        return pickle.loads(data)
+
+    if extension in "pt".split():
+        stream = io.BytesIO(data)
+        return torch.load(stream)
+
+    # if extension in "ten tb".split():
+    #     from . import tenbin
+    #     return tenbin.decode_buffer(data)
+
+    # if extension in "mp msgpack msg".split():
+    #     import msgpack
+    #     return msgpack.unpackb(data)
+
+    return None
+
+
+################################################################
+# handle images
+################################################################
+imagespecs = {
+    "l8": ("numpy", "uint8", "l"),
+    "rgb8": ("numpy", "uint8", "rgb"),
+    "rgba8": ("numpy", "uint8", "rgba"),
+    "l": ("numpy", "float", "l"),
+    "rgb": ("numpy", "float", "rgb"),
+    "rgba": ("numpy", "float", "rgba"),
+    "torchl8": ("torch", "uint8", "l"),
+    "torchrgb8": ("torch", "uint8", "rgb"),
+    "torchrgba8": ("torch", "uint8", "rgba"),
+    "torchl": ("torch", "float", "l"),
+    "torchrgb": ("torch", "float", "rgb"),
+    "torch": ("torch", "float", "rgb"),
+    "torchrgba": ("torch", "float", "rgba"),
+    "pill": ("pil", None, "l"),
+    "pil": ("pil", None, "rgb"),
+    "pilrgb": ("pil", None, "rgb"),
+    "pilrgba": ("pil", None, "rgba"),
+}
+
+
+def handle_extension(extensions, f):
+    """
+    Return a decoder handler function for the list of extensions.
+
+    Extensions can be a space separated list of extensions.
+    Extensions can contain dots, in which case the corresponding number
+    of extension components must be present in the key given to f.
+    Comparisons are case insensitive.
+    Examples:
+    handle_extension("jpg jpeg", my_decode_jpg)  # invoked for any file.jpg
+    handle_extension("seg.jpg", special_case_jpg)  # invoked only for file.seg.jpg
+    """
+    extensions = extensions.lower().split()
+
+    def g(key, data):
+        extension = key.lower().split(".")
+
+        for target in extensions:
+            target = target.split(".")
+            if len(target) > len(extension):
+                continue
+
+            if extension[-len(target) :] == target:
+                return f(data)
+        return None
+
+    return g
+
+
+class ImageHandler:
+    """
+    Decode image data using the given `imagespec`.
+
+    The `imagespec` specifies whether the image is decoded
+    to numpy/torch/pil, decoded to uint8/float, and decoded
+    to l/rgb/rgba:
+
+    - l8: numpy uint8 l
+    - rgb8: numpy uint8 rgb
+    - rgba8: numpy uint8 rgba
+    - l: numpy float l
+    - rgb: numpy float rgb
+    - rgba: numpy float rgba
+    - torchl8: torch uint8 l
+    - torchrgb8: torch uint8 rgb
+    - torchrgba8: torch uint8 rgba
+    - torchl: torch float l
+    - torchrgb: torch float rgb
+    - torch: torch float rgb
+    - torchrgba: torch float rgba
+    - pill: pil None l
+    - pil: pil None rgb
+    - pilrgb: pil None rgb
+    - pilrgba: pil None rgba
+    """
+
+    def __init__(self, imagespec):
+        assert imagespec in list(
+            imagespecs.keys()
+        ), f"unknown image specification: {imagespec}"
+        self.imagespec = imagespec.lower()
+
+    def __call__(self, extension, data):
+        if extension.lower() not in "jpg jpeg png ppm pgm pbm pnm".split():
+            return None
+
+        try:
+            import numpy as np
+        except ModuleNotFoundError as e:
+            raise ModuleNotFoundError(
+                "Package `numpy` is required to be installed for default image decoder."
+                "Please use `pip install numpy` to install the package"
+            ) from e
+
+        try:
+            import PIL.Image
+        except ModuleNotFoundError as e:
+            raise ModuleNotFoundError(
+                "Package `PIL` is required to be installed for default image decoder."
+                "Please use `pip install Pillow` to install the package"
+            ) from e
+
+        imagespec = self.imagespec
+        atype, etype, mode = imagespecs[imagespec]
+
+        with io.BytesIO(data) as stream:
+            img = PIL.Image.open(stream)
+            img.load()
+            img = img.convert(mode.upper())
+            if atype == "pil":
+                return img
+            elif atype == "numpy":
+                result = np.asarray(img)
+                assert (
+                    result.dtype == np.uint8
+                ), f"numpy image array should be type uint8, but got {result.dtype}"
+                if etype == "uint8":
+                    return result
+                else:
+                    return result.astype("f") / 255.0
+            elif atype == "torch":
+                result = np.asarray(img)
+                assert (
+                    result.dtype == np.uint8
+                ), f"numpy image array should be type uint8, but got {result.dtype}"
+
+                if etype == "uint8":
+                    result = np.array(result.transpose(2, 0, 1))
+                    return torch.tensor(result)
+                else:
+                    result = np.array(result.transpose(2, 0, 1))
+                    return torch.tensor(result) / 255.0
+            return None
+
+
+def imagehandler(imagespec):
+    return ImageHandler(imagespec)
+
+
+################################################################
+# torch video
+################################################################
+def videohandler(extension, data):
+    if extension not in "mp4 ogv mjpeg avi mov h264 mpg webm wmv".split():
+        return None
+
+    try:
+        import torchvision.io
+    except ImportError as e:
+        raise ModuleNotFoundError(
+            "Package `torchvision` is required to be installed for default video file loader."
+            "Please use `pip install torchvision`"
+            "to install the package"
+        ) from e
+
+    with tempfile.TemporaryDirectory() as dirname:
+        fname = os.path.join(dirname, f"file.{extension}")
+        with open(fname, "wb") as stream:
+            stream.write(data)
+            return torchvision.io.read_video(fname)
+
+
+################################################################
+# torchaudio
+################################################################
+def audiohandler(extension, data):
+    if extension not in ["flac", "mp3", "sox", "wav", "m4a", "ogg", "wma"]:
+        return None
+
+    try:
+        import torchaudio  # type: ignore[import]
+    except ImportError as e:
+        raise ModuleNotFoundError(
+            "Package `torchaudio` is required to be installed for default audio file loader."
+            "Please use `pip install torchaudio`"
+            "to install the package"
+        ) from e
+
+    with tempfile.TemporaryDirectory() as dirname:
+        fname = os.path.join(dirname, f"file.{extension}")
+        with open(fname, "wb") as stream:
+            stream.write(data)
+            return torchaudio.load(fname)
+
+
+################################################################
+# mat
+################################################################
+class MatHandler:
+    def __init__(self, **loadmat_kwargs) -> None:
+        try:
+            import scipy.io as sio
+        except ImportError as e:
+            raise ModuleNotFoundError(
+                "Package `scipy` is required to be installed for mat file."
+                "Please use `pip install scipy`"
+                "to install the package"
+            ) from e
+        self.sio = sio
+        self.loadmat_kwargs = loadmat_kwargs
+
+    def __call__(self, extension, data):
+        if extension != "mat":
+            return None
+        with io.BytesIO(data) as stream:
+            return self.sio.loadmat(stream, **self.loadmat_kwargs)
+
+
+def mathandler(**loadmat_kwargs):
+    return MatHandler(**loadmat_kwargs)
+
+
+################################################################
+# a sample decoder
+################################################################
+# Extract extension from pathname
+def extension_extract_fn(pathname):
+    ext = os.path.splitext(pathname)[1]
+    # Remove dot
+    if ext:
+        ext = ext[1:]
+    return ext
+
+
+class Decoder:
+    """
+    Decode key/data sets using a list of handlers.
+
+    For each key/data item, this iterates through the list of
+    handlers until some handler returns something other than None.
+    """
+
+    def __init__(self, *handler, key_fn=extension_extract_fn):
+        self.handlers = list(handler) if handler else []
+        self.key_fn = key_fn
+
+    # Insert new handlers at the beginning of the handlers list so that they
+    # take the highest priority
+    def add_handler(self, *handler):
+        if not handler:
+            return
+        self.handlers = list(handler) + self.handlers
+
+    @staticmethod
+    def _is_stream_handle(data):
+        obj_to_check = data.file_obj if isinstance(data, StreamWrapper) else data
+        return isinstance(obj_to_check, (io.BufferedIOBase, io.RawIOBase))
+
+    def decode1(self, key, data):
+        if not data:
+            return data
+
+        # if data is a stream handle, we need to read all the content before decoding
+        if Decoder._is_stream_handle(data):
+            ds = data
+            # The behavior of .read can differ between streams (e.g. HTTPResponse), hence this is used instead
+            data = b"".join(data)
+            ds.close()
+
+        for f in self.handlers:
+            result = f(key, data)
+            if result is not None:
+                return result
+        return data
+
+    def decode(self, data):
+        result = {}
+        # single data tuple(pathname, data stream)
+        if isinstance(data, tuple):
+            data = [data]
+
+        if data is not None:
+            for k, v in data:
+                # TODO: xinyu, figure out why Nvidia does this
+                if k[0] == "_":
+                    if isinstance(v, bytes):
+                        v = v.decode("utf-8")
+                        result[k] = v
+                        continue
+                result[k] = self.decode1(self.key_fn(k), v)
+        return result
+
+    def __call__(self, data):
+        return self.decode(data)
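A sketch of the `Decoder` dispatch loop above: handlers run in order, the first non-`None` result wins, and keys are mapped to extensions through `key_fn`:

from torch.utils.data.datapipes.utils.decoder import Decoder, basichandlers

dec = Decoder(basichandlers)
sample = [("item.cls", b"7"), ("item.json", b'{"a": 1}')]
assert dec(sample) == {"item.cls": 7, "item.json": {"a": 1}}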
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/utils/snapshot.py b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/utils/snapshot.py
new file mode 100644
index 0000000000000000000000000000000000000000..d120025a934ec2456306c59dd70a6f5d348e16b4
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/data/datapipes/utils/snapshot.py
@@ -0,0 +1,64 @@
+# mypy: allow-untyped-defs
+from torch.utils.data.datapipes._hook_iterator import _SnapshotState
+from torch.utils.data.datapipes.datapipe import IterDataPipe
+from torch.utils.data.graph_settings import apply_random_seed
+
+
+# TODO: Caveats
+#   1. Caller (either the ReadingService or DataLoader) must pass in the initial RNG
+#   2. `in_batch_shuffle` and `bucketbatch` are not compatible with this because they currently
+#      lack the option to `set_seed`.
+def _simple_graph_snapshot_restoration(
+    datapipe: IterDataPipe, n_iterations: int, rng=None
+) -> None:
+    r"""
+    Fast-forward the given DataPipe and its parents by ``n_iterations``, re-doing computations to restore a snapshot.
+
+    For instance, applying this function to the final DataPipe of a graph will restore the snapshot
+    (via fast-forward) for every DataPipe within the graph.
+
+    After you deserialize a DataPipe, you can use its `_number_of_samples_yielded` attribute as the input
+    to this function to fast-forward the DataPipe.
+
+    A DataPipe cannot be restored twice in a row unless there is an iteration started between the restoration
+    attempts.
+
+    Note:
+        This is the simplest but least efficient way to fast-forward a DataPipe. Usage of other fast-forwarding
+        methods (custom ones if necessary) are recommended.
+
+    Args:
+        datapipe: IterDataPipe to be fast-forwarded
+        n_iterations: number of iterations to fast-forward
+        rng: ``Optional[torch.Generator]``. If not ``None``, this RNG will be used for shuffling. The generator
+            should be in its `initial` state as it was first passed into ``DataLoader`` or ``ReadingService``.
+    """
+    if datapipe._snapshot_state == _SnapshotState.Restored:
+        raise RuntimeError(
+            "Snapshot restoration cannot be applied. You can only restore simple snapshot to the graph "
+            "if your graph has not been restored."
+        )
+
+    # For this snapshot restoration function, we want the DataPipe to be at its initial state prior to
+    # simple fast-forwarding. Therefore, we need to call `reset` twice, because if `SnapshotState` is `Restored`,
+    # the first reset will not actually reset.
+    datapipe.reset()  # This ensures `SnapshotState` is `Iterating` by this point, even if it was `Restored`.
+    apply_random_seed(datapipe, rng)
+
+    remainder = n_iterations
+    it = iter(datapipe)  # This always resets the DataPipe if it hasn't been reset already.
+    while remainder > 0:
+        try:
+            next(it)
+            remainder -= 1
+        except StopIteration as e:
+            raise RuntimeError(
+                f"Fast-forward {datapipe} by {n_iterations} iterations "
+                "exceeds the number of samples available."
+            ) from e
+    datapipe._fast_forward_iterator = it
+    # While the DataPipe has `_fast_forward_iterator`, `next()` will get result from there instead of elsewhere.
+
+    # This will prevent the DataPipe from resetting in the `iter()` call
+    # If another DataPipe is consuming it, it won't have to start over again
+    datapipe._snapshot_state = _SnapshotState.Restored
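The restoration loop above is ordinary fast-forwarding; a hedged, self-contained illustration of the same idea without the private snapshot state:

from torch.utils.data.datapipes.iter import IterableWrapper

dp = IterableWrapper(range(5))
it = iter(dp)
for _ in range(3):    # replay (and discard) the first 3 samples
    next(it)
assert next(it) == 3  # iteration resumes where the "snapshot" left off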
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/jit/__init__.py b/.venv/lib/python3.12/site-packages/torch/utils/jit/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/jit/log_extract.py b/.venv/lib/python3.12/site-packages/torch/utils/jit/log_extract.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5804e710bae53b3a1e7d0dd94785d00de41ab78
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/jit/log_extract.py
@@ -0,0 +1,113 @@
+# mypy: allow-untyped-defs
+from contextlib import contextmanager
+from typing import Any, cast
+import random
+import torch
+import time
+from torch.utils.benchmark import Timer
+
+def extract_ir(filename: str) -> list[str]:
+    BEGIN = ""
+    END = ""
+    pfx = None
+    graphs = []
+    with open(filename) as f:
+        split_strs = f.read().split(BEGIN)
+        for i, split_str in enumerate(split_strs):
+            if i == 0:
+                continue
+            end_loc = split_str.find(END)
+            if end_loc == -1:
+                continue
+            s = split_str[:end_loc]
+            pfx = split_strs[i - 1].splitlines()[-1]
+            lines = [x[len(pfx):] for x in s.splitlines(keepends=True)]
+            graphs.append(''.join(lines))
+
+    return graphs
+
+
+def make_tensor_from_type(inp_type: torch._C.TensorType):
+    size = inp_type.sizes()
+    stride = inp_type.strides()
+    device = inp_type.device()
+    dtype = inp_type.dtype()
+    assert size is not None
+    assert stride is not None
+    assert device is not None
+    assert dtype is not None
+    return torch.empty_strided(size=size, stride=stride, device=device, dtype=dtype)
+
+def load_graph_and_inputs(ir: str) -> tuple[Any, list[Any]]:
+    graph = torch._C.parse_ir(ir, parse_tensor_constants=True)
+    graph.makeMultiOutputIntoTuple()
+    inputs = []
+    for inp in graph.inputs():
+        if isinstance(inp.type(), torch._C.FloatType):
+            inputs.append(random.uniform(.1, 100))
+        elif isinstance(inp.type(), torch._C.IntType):
+            inputs.append(random.randint(1, 100))
+        elif isinstance(inp.type(), torch._C.TensorType):
+            tensorType = cast(torch._C.TensorType, inp.type())
+            inputs.append(make_tensor_from_type(tensorType))
+        elif isinstance(inp.type(), torch._C.BoolType):
+            inputs.append(random.randint(0, 1) == 1)
+        else:
+            raise NotImplementedError(f"A default value is not implemented for type {inp.type()}")
+
+    func = torch._C._create_function_from_graph("forward", graph)
+    torch._C._jit_pass_erase_shape_information(func.graph)
+    return (func, inputs)
+
+def time_cuda(fn, inputs, test_runs):
+    t = Timer(stmt="fn(*inputs)", globals={"fn": fn, "inputs": inputs})
+    times = t.blocked_autorange()
+    return times.median * 1000  # time in ms
+
+def time_cpu(fn, inputs, test_runs):
+    s = time.perf_counter()
+    for _ in range(test_runs):
+        fn(*inputs)
+    e = time.perf_counter()
+    return (e - s) / test_runs * 1000  # time in ms
+
+def run_test(ir, inputs, *, warmup_runs=10, test_runs=20) -> float:
+    graph, _ = load_graph_and_inputs(ir)
+    for _ in range(warmup_runs):
+        graph(*inputs)
+
+    is_cpu = None
+    for input in inputs:
+        if isinstance(input, torch.Tensor):
+            is_cpu = input.device.type == "cpu"
+            break
+    assert is_cpu is not None
+
+    out = time_cpu(graph, inputs, test_runs) if is_cpu else time_cuda(graph, inputs, test_runs)
+    return out
+
+@contextmanager
+def no_fuser(*args, **kwargs):
+    old_optimize = torch._C._get_graph_executor_optimize(False)
+    try:
+        yield
+    finally:
+        torch._C._get_graph_executor_optimize(old_optimize)
+
+def run_baseline_no_fusion(ir, inputs) -> float:
+    with no_fuser():
+        return run_test(ir, inputs)
+
+
+def run_nnc(ir, inputs, dynamic) -> float:
+    try:
+        strat = [("DYNAMIC", 10)] if dynamic else [("STATIC", 10)]
+        old_strat = torch.jit.set_fusion_strategy(strat)
+        with torch.jit.fuser("fuser1"):
+            return run_test(ir, inputs)
+    finally:
+        torch.jit.set_fusion_strategy(old_strat)
+
+def run_nvfuser(ir, inputs) -> float:
+    with torch.jit.fuser("fuser2"):
+        return run_test(ir, inputs)
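A minimal sketch of the CPU timing helper above: `time_cpu` averages wall-clock time over `test_runs` calls and reports milliseconds (the function and input are illustrative):

import torch

from torch.utils.jit.log_extract import time_cpu

def fn(x):
    return x * 2 + 1

ms = time_cpu(fn, [torch.randn(1000)], test_runs=50)
print(f"{ms:.4f} ms per call")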
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/serialization/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/serialization/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b65acbca6204896a928aec5d33c054e6e645502f
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/serialization/__pycache__/__init__.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/serialization/__pycache__/config.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/serialization/__pycache__/config.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1351f4b7aedabb8716c2ab8ec01d7b55b9706be8
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/serialization/__pycache__/config.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/__init__.py b/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9b2ac5edd05e16ef51e75f2ca68864b65da5d58
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/__init__.py
@@ -0,0 +1,19 @@
+import tensorboard
+from torch._vendor.packaging.version import Version
+
+if not hasattr(tensorboard, "__version__") or Version(
+    tensorboard.__version__
+) < Version("1.15"):
+    raise ImportError("TensorBoard logging requires TensorBoard version 1.15 or above")
+
+del Version
+del tensorboard
+
+from .writer import FileWriter, SummaryWriter
+from tensorboard.summary.writer.record_writer import RecordWriter
+
+__all__ = [
+    "FileWriter",
+    "RecordWriter",
+    "SummaryWriter",
+]
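A minimal `SummaryWriter` sketch, assuming `tensorboard >= 1.15` is installed per the version gate above (the log directory name is illustrative):

from torch.utils.tensorboard import SummaryWriter

with SummaryWriter(log_dir="runs/demo") as writer:
    for step in range(3):
        writer.add_scalar("loss", 1.0 / (step + 1), global_step=step)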
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/_convert_np.py b/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/_convert_np.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e20ec6337c3662fdd74836d280cedae2678b66a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/_convert_np.py
@@ -0,0 +1,33 @@
+"""This module converts objects into numpy array."""
+
+import numpy as np
+
+import torch
+
+
+def make_np(x: torch.Tensor) -> np.ndarray:
+    """
+    Convert an object into a numpy array.
+
+    Args:
+      x: An instance of torch tensor
+
+    Returns:
+        numpy.array: Numpy array
+    """
+    if isinstance(x, np.ndarray):
+        return x
+    if np.isscalar(x):
+        return np.array([x])
+    if isinstance(x, torch.Tensor):
+        return _prepare_pytorch(x)
+    raise NotImplementedError(
+        f"Got {type(x)}, but numpy array or torch tensor are expected."
+    )
+
+
+def _prepare_pytorch(x: torch.Tensor) -> np.ndarray:
+    if x.dtype == torch.bfloat16:
+        x = x.to(torch.float16)
+    x = x.detach().cpu().numpy()
+    return x
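A sketch of `make_np` above: scalars become 1-element arrays, and `bfloat16` tensors are routed through `float16` before the numpy conversion, since numpy has no bfloat16 dtype:

import torch

from torch.utils.tensorboard._convert_np import make_np

assert make_np(3.5).shape == (1,)  # scalar -> 1-element array
assert make_np(torch.tensor([1.0], dtype=torch.bfloat16)).dtype.name == "float16"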
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/_embedding.py b/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/_embedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..44cb6c41b017f170c432f307907949e37de497c2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/_embedding.py
@@ -0,0 +1,86 @@
+# mypy: allow-untyped-defs
+import math
+import numpy as np
+from ._convert_np import make_np
+from ._utils import make_grid
+from tensorboard.compat import tf
+from tensorboard.plugins.projector.projector_config_pb2 import EmbeddingInfo
+
+
+_HAS_GFILE_JOIN = hasattr(tf.io.gfile, "join")
+
+
+def _gfile_join(a, b):
+    # The join API is different between tensorboard's TF stub and TF:
+    # https://github.com/tensorflow/tensorboard/issues/6080
+    # We need to try both because `tf` may point to either the stub or the real TF.
+    if _HAS_GFILE_JOIN:
+        return tf.io.gfile.join(a, b)
+    else:
+        fs = tf.io.gfile.get_filesystem(a)
+        return fs.join(a, b)
+
+
+def make_tsv(metadata, save_path, metadata_header=None):
+    if not metadata_header:
+        metadata = [str(x) for x in metadata]
+    else:
+        assert len(metadata_header) == len(
+            metadata[0]
+        ), "len of header must be equal to the number of columns in metadata"
+        metadata = ["\t".join(str(e) for e in l) for l in [metadata_header] + metadata]
+
+    metadata_bytes = tf.compat.as_bytes("\n".join(metadata) + "\n")
+    with tf.io.gfile.GFile(_gfile_join(save_path, "metadata.tsv"), "wb") as f:
+        f.write(metadata_bytes)
+
+
+# https://github.com/tensorflow/tensorboard/issues/44 image label will be squared
+def make_sprite(label_img, save_path):
+    from PIL import Image
+    from io import BytesIO
+
+    # this ensures the sprite image has correct dimension as described in
+    # https://www.tensorflow.org/get_started/embedding_viz
+    nrow = int(math.ceil((label_img.size(0)) ** 0.5))
+    arranged_img_CHW = make_grid(make_np(label_img), ncols=nrow)
+
+    # augment images so that #images equals nrow*nrow
+    arranged_augment_square_HWC = np.zeros(
+        (arranged_img_CHW.shape[2], arranged_img_CHW.shape[2], 3)
+    )
+    arranged_img_HWC = arranged_img_CHW.transpose(1, 2, 0)  # chw -> hwc
+    arranged_augment_square_HWC[: arranged_img_HWC.shape[0], :, :] = arranged_img_HWC
+    im = Image.fromarray(np.uint8((arranged_augment_square_HWC * 255).clip(0, 255)))
+
+    with BytesIO() as buf:
+        im.save(buf, format="PNG")
+        im_bytes = buf.getvalue()
+
+    with tf.io.gfile.GFile(_gfile_join(save_path, "sprite.png"), "wb") as f:
+        f.write(im_bytes)
+
+
+def get_embedding_info(metadata, label_img, subdir, global_step, tag):
+    info = EmbeddingInfo()
+    info.tensor_name = f"{tag}:{str(global_step).zfill(5)}"
+    info.tensor_path = _gfile_join(subdir, "tensors.tsv")
+    if metadata is not None:
+        info.metadata_path = _gfile_join(subdir, "metadata.tsv")
+    if label_img is not None:
+        info.sprite.image_path = _gfile_join(subdir, "sprite.png")
+        info.sprite.single_image_dim.extend([label_img.size(3), label_img.size(2)])
+    return info
+
+
+def write_pbtxt(save_path, contents):
+    config_path = _gfile_join(save_path, "projector_config.pbtxt")
+    with tf.io.gfile.GFile(config_path, "wb") as f:
+        f.write(tf.compat.as_bytes(contents))
+
+
+def make_mat(matlist, save_path):
+    with tf.io.gfile.GFile(_gfile_join(save_path, "tensors.tsv"), "wb") as f:
+        for x in matlist:
+            x = [str(i.item()) for i in x]
+            f.write(tf.compat.as_bytes("\t".join(x) + "\n"))
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/_onnx_graph.py b/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/_onnx_graph.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b7381737b3e77a34b32f28a48833a1fedd61104
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/_onnx_graph.py
@@ -0,0 +1,61 @@
+# mypy: allow-untyped-defs
+from tensorboard.compat.proto.graph_pb2 import GraphDef
+from tensorboard.compat.proto.node_def_pb2 import NodeDef
+from tensorboard.compat.proto.versions_pb2 import VersionDef
+from tensorboard.compat.proto.attr_value_pb2 import AttrValue
+from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
+
+
+def load_onnx_graph(fname):
+    import onnx
+
+    m = onnx.load(fname)  # type: ignore[attr-defined]
+    g = m.graph
+    return parse(g)
+
+
+def parse(graph):
+    nodes = []
+    import itertools
+
+    nodes_proto = list(itertools.chain(graph.input, graph.output))
+
+    for node in nodes_proto:
+        print(node.name)
+        shapeproto = TensorShapeProto(
+            dim=[
+                TensorShapeProto.Dim(size=d.dim_value)
+                for d in node.type.tensor_type.shape.dim
+            ]
+        )
+        nodes.append(
+            NodeDef(
+                name=node.name.encode(encoding="utf_8"),
+                op="Variable",
+                input=[],
+                attr={
+                    "dtype": AttrValue(type=node.type.tensor_type.elem_type),
+                    "shape": AttrValue(shape=shapeproto),
+                },
+            )
+        )
+
+    for node in graph.node:
+        _attr = [" = ".join([str(f[1]) for f in s.ListFields()]) for s in node.attribute]
+        attr = ", ".join(_attr).encode(encoding="utf_8")
+        print(node.output[0])
+        nodes.append(
+            NodeDef(
+                name=node.output[0].encode(encoding="utf_8"),
+                op=node.op_type,
+                input=node.input,
+                attr={"parameters": AttrValue(s=attr)},
+            )
+        )
+
+    # two pass token replacement, appends opname to object id
+    mapping = {}
+    for node in nodes:
+        mapping[node.name] = node.op + "_" + node.name
+
+    return GraphDef(node=nodes, versions=VersionDef(producer=22))
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/_proto_graph.py b/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/_proto_graph.py
new file mode 100644
index 0000000000000000000000000000000000000000..30140a22cff673458b0e2b2fb5c43416d318a5ca
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/_proto_graph.py
@@ -0,0 +1,54 @@
+# mypy: allow-untyped-defs
+from typing import Optional
+from tensorboard.compat.proto.node_def_pb2 import NodeDef
+from tensorboard.compat.proto.attr_value_pb2 import AttrValue
+from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
+
+
+def attr_value_proto(dtype, shape, s):
+    """Create a dict of objects matching a NodeDef's attr field.
+
+    Follows https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/attr_value.proto
+    specifically designed for a NodeDef. The values have been reverse engineered from
+    standard TensorBoard logged data.
+    """
+    attr = {}
+    if s is not None:
+        attr["attr"] = AttrValue(s=s.encode(encoding="utf_8"))
+    if shape is not None:
+        shapeproto = tensor_shape_proto(shape)
+        attr["_output_shapes"] = AttrValue(list=AttrValue.ListValue(shape=[shapeproto]))
+    return attr
+
+
+def tensor_shape_proto(outputsize):
+    """Create an object matching a tensor_shape field.
+
+    Follows https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/tensor_shape.proto .
+    """
+    return TensorShapeProto(dim=[TensorShapeProto.Dim(size=d) for d in outputsize])
+
+
+def node_proto(
+    name,
+    op="UnSpecified",
+    input=None,
+    dtype=None,
+    shape: Optional[tuple] = None,
+    outputsize=None,
+    attributes="",
+):
+    """Create an object matching a NodeDef.
+
+    Follows https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/node_def.proto .
+    """
+    if input is None:
+        input = []
+    if not isinstance(input, list):
+        input = [input]
+    return NodeDef(
+        name=name.encode(encoding="utf_8"),
+        op=op,
+        input=input,
+        attr=attr_value_proto(dtype, outputsize, attributes),
+    )
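+
+# Illustrative sketch (editor's addition, not upstream code): building a NodeDef
+# for a hypothetical ReLU node with one input and a known output shape.
+# >>> nd = node_proto("relu_1", op="aten::relu", input="x", outputsize=[1, 64])
+# >>> nd.op, list(nd.input)
+# ('aten::relu', ['x'])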
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/_pytorch_graph.py b/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/_pytorch_graph.py
new file mode 100644
index 0000000000000000000000000000000000000000..85427162fc770acc6f63bfd0ded3a354ab2eb3da
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/_pytorch_graph.py
@@ -0,0 +1,376 @@
+# mypy: allow-untyped-defs
+from collections import OrderedDict
+import contextlib
+from typing import Any
+
+from tensorboard.compat.proto.config_pb2 import RunMetadata
+from tensorboard.compat.proto.graph_pb2 import GraphDef
+from tensorboard.compat.proto.step_stats_pb2 import StepStats, DeviceStepStats
+from tensorboard.compat.proto.versions_pb2 import VersionDef
+
+import torch
+from ._proto_graph import node_proto
+
+methods_OP = [
+    "attributeNames",
+    "hasMultipleOutputs",
+    "hasUses",
+    "inputs",
+    "kind",
+    "outputs",
+    "outputsSize",
+    "scopeName",
+]
+# Some additional methods to explore for methods_IO are
+#
+#   'unique' (type int)
+#   'type' (type <Tensor<class 'torch._C.Type'>>)
+#
+# But the below are sufficient for now.
+methods_IO = ["node", "offset", "debugName"]
+
+GETATTR_KIND = "prim::GetAttr"
+CLASSTYPE_KIND = "ClassType"
+
+
+class NodeBase:
+    def __init__(
+        self,
+        debugName=None,
+        inputs=None,
+        scope=None,
+        tensor_size=None,
+        op_type="UnSpecified",
+        attributes="",
+    ):
+        # TODO: Specify a __slots__ for this class or potentially
+        # use a namedtuple instead
+        self.debugName = debugName
+        self.inputs = inputs
+        self.tensor_size = tensor_size
+        self.kind = op_type
+        self.attributes = attributes
+        self.scope = scope
+
+    def __repr__(self):
+        repr = []
+        repr.append(str(type(self)))
+        repr.extend(
+            m + ": " + str(getattr(self, m)) + str(type(getattr(self, m)))
+            for m in dir(self)
+            if "__" not in m
+        )
+        return "\n".join(repr) + "\n\n"
+
+
+class NodePy(NodeBase):
+    def __init__(self, node_cpp, valid_methods):
+        super().__init__(node_cpp)
+        valid_methods = valid_methods[:]
+        self.inputs = []
+
+        for m in valid_methods:
+            if m == "inputs" or m == "outputs":
+                list_of_node = list(getattr(node_cpp, m)())
+                io_unique_names = []
+                io_tensor_sizes = []
+                for n in list_of_node:
+                    io_unique_names.append(n.debugName())
+                    if n.isCompleteTensor():
+                        io_tensor_sizes.append(n.type().sizes())
+                    else:
+                        io_tensor_sizes.append(None)
+
+                setattr(self, m, io_unique_names)
+                setattr(self, m + "tensor_size", io_tensor_sizes)
+
+            else:
+                setattr(self, m, getattr(node_cpp, m)())
+
+
+class NodePyIO(NodePy):
+    def __init__(self, node_cpp, input_or_output=None):
+        super().__init__(node_cpp, methods_IO)
+        try:
+            tensor_size = node_cpp.type().sizes()
+        except RuntimeError:
+            tensor_size = [
+                1,
+            ]  # fallback: sizes() fails when a constant model is used.
+        self.tensor_size = tensor_size
+        # Kind attribute string is purely descriptive and will be shown
+        # in detailed information for the node in TensorBoard's graph plugin.
+        #
+        # NodePyOP nodes get this from their kind() method.
+        self.kind = "Parameter"
+        if input_or_output:
+            self.input_or_output = input_or_output
+            self.kind = "IO Node"
+
+
+class NodePyOP(NodePy):
+    def __init__(self, node_cpp):
+        super().__init__(node_cpp, methods_OP)
+        # Replace single quote which causes strange behavior in TensorBoard
+        # TODO: See if we can remove this in the future
+        self.attributes = str(
+            {k: _node_get(node_cpp, k) for k in node_cpp.attributeNames()}
+        ).replace("'", " ")
+        self.kind = node_cpp.kind()
+
+
+class GraphPy:
+    """Helper class to convert torch.nn.Module to GraphDef proto and visualization with TensorBoard.
+
+    GraphDef generation operates in two passes:
+
+    In the first pass, all nodes are read and saved to two lists.
+    One list is for input/output nodes (nodes_io), which only have inbound
+    or outbound connections, but not both. Another list is for internal
+    operator nodes (nodes_op). The first pass also saves every scope name that
+    appears in the nodes to the scope_name_appeared list for later processing.
+
+    In the second pass, scope names are fully applied to all nodes.
+    debugNameToScopedName is a mapping from a node's ID to its fully qualified
+    scope name. e.g. Net1/Linear[0]/1. Unfortunately torch.jit doesn't have
+    totally correct scope output, so this is nontrivial. The function
+    populate_namespace_from_OP_to_IO and find_common_root are used to
+    assign scope name to a node based on the connection between nodes
+    in a heuristic kind of way. Bookkeeping is done with shallowest_scope_name
+    and scope_name_appeared.
+    """
+
+    def __init__(self):
+        self.nodes_op = []
+        self.nodes_io = OrderedDict()
+        self.unique_name_to_scoped_name = {}
+        self.shallowest_scope_name = "default"
+        self.scope_name_appeared = []
+
+    def append(self, x):
+        if isinstance(x, NodePyIO):
+            self.nodes_io[x.debugName] = x
+        if isinstance(x, NodePyOP):
+            self.nodes_op.append(x)
+
+    def printall(self):
+        print("all nodes")
+        for node in self.nodes_op:
+            print(node)
+        for key in self.nodes_io:
+            print(self.nodes_io[key])
+
+    def find_common_root(self):
+        for fullscope in self.scope_name_appeared:
+            if fullscope:
+                self.shallowest_scope_name = fullscope.split("/")[0]
+
+    def populate_namespace_from_OP_to_IO(self):
+        for node in self.nodes_op:
+            for node_output, outputSize in zip(node.outputs, node.outputstensor_size):
+                self.scope_name_appeared.append(node.scopeName)
+                self.nodes_io[node_output] = NodeBase(
+                    node_output,
+                    node.inputs,
+                    node.scopeName,
+                    outputSize,
+                    op_type=node.kind,
+                    attributes=node.attributes,
+                )
+
+        self.find_common_root()
+
+        for node in self.nodes_op:
+            for input_node_id in node.inputs:
+                self.unique_name_to_scoped_name[input_node_id] = (
+                    node.scopeName + "/" + input_node_id
+                )
+
+        for key, node in self.nodes_io.items():
+            if type(node) == NodeBase:
+                self.unique_name_to_scoped_name[key] = node.scope + "/" + node.debugName
+            if hasattr(node, "input_or_output"):
+                self.unique_name_to_scoped_name[key] = (
+                    node.input_or_output + "/" + node.debugName
+                )
+
+            if hasattr(node, "scope") and node.scope is not None:
+                self.unique_name_to_scoped_name[key] = node.scope + "/" + node.debugName
+                if node.scope == "" and self.shallowest_scope_name:
+                    self.unique_name_to_scoped_name[node.debugName] = (
+                        self.shallowest_scope_name + "/" + node.debugName
+                    )
+
+        # replace name
+        for key, node in self.nodes_io.items():
+            self.nodes_io[key].inputs = [
+                self.unique_name_to_scoped_name[node_input_id]
+                for node_input_id in node.inputs
+            ]
+            if node.debugName in self.unique_name_to_scoped_name:
+                self.nodes_io[key].debugName = self.unique_name_to_scoped_name[
+                    node.debugName
+                ]
+
+    def to_proto(self):
+        """Convert graph representation of GraphPy object to TensorBoard required format."""
+        # TODO: compute correct memory usage and CPU time once
+        # PyTorch supports it
+        nodes = [
+            node_proto(
+                v.debugName,
+                input=v.inputs,
+                outputsize=v.tensor_size,
+                op=v.kind,
+                attributes=v.attributes,
+            )
+            for v in self.nodes_io.values()
+        ]
+        return nodes
+
+
+def parse(graph, trace, args=None, omit_useless_nodes=True):
+    """Parse an optimized PyTorch model graph and produces a list of nodes and node stats.
+
+    Useful for eventual conversion to TensorBoard protobuf format.
+
+    Args:
+      graph (PyTorch module): The model graph to be parsed.
+      trace (PyTorch JIT TracedModule): The model trace to be parsed.
+      args (tuple): input tensor[s] for the model.
+      omit_useless_nodes (boolean): Whether to drop nodes that have no uses
+        (for example, unused inputs) from the graph.
+    """
+    nodes_py = GraphPy()
+    for node in graph.inputs():
+        if omit_useless_nodes:
+            if (
+                len(node.uses()) == 0
+            ):  # number of users of the node (i.e. its fanout)
+                continue
+
+        if node.type().kind() != CLASSTYPE_KIND:
+            nodes_py.append(NodePyIO(node, "input"))
+
+    attr_to_scope: dict[Any, str] = {}
+    for node in graph.nodes():
+        if node.kind() == GETATTR_KIND:
+            attr_name = node.s("name")
+            attr_key = node.output().debugName()
+            parent = node.input().node()
+            if (
+                parent.kind() == GETATTR_KIND
+            ):  # If the parent node is not the top-level "self" node
+                parent_attr_key = parent.output().debugName()
+                parent_scope = attr_to_scope[parent_attr_key]
+                attr_scope = parent_scope.split("/")[-1]
+                attr_to_scope[attr_key] = f"{parent_scope}/{attr_scope}.{attr_name}"
+            else:
+                attr_to_scope[attr_key] = f"__module.{attr_name}"
+            # We don't need classtype nodes; scope will provide this information
+            if node.output().type().kind() != CLASSTYPE_KIND:
+                node_py = NodePyOP(node)
+                node_py.scopeName = attr_to_scope[attr_key]  # type: ignore[attr-defined]
+                nodes_py.append(node_py)
+        else:
+            nodes_py.append(NodePyOP(node))
+
+    for i, node in enumerate(graph.outputs()):  # Create sink nodes for output ops
+        node_pyio = NodePyIO(node, "output")
+        node_pyio.debugName = f"output.{i + 1}"
+        node_pyio.inputs = [node.debugName()]
+        nodes_py.append(node_pyio)
+
+    def parse_traced_name(module):
+        if isinstance(module, torch.jit.TracedModule):
+            module_name = module._name
+        else:
+            module_name = getattr(module, "original_name", "Module")
+        return module_name
+
+    alias_to_name = {}
+    base_name = parse_traced_name(trace)
+    for name, module in trace.named_modules(prefix="__module"):
+        mod_name = parse_traced_name(module)
+        attr_name = name.split(".")[-1]
+        alias_to_name[name] = f"{mod_name}[{attr_name}]"
+
+    for node in nodes_py.nodes_op:
+        module_aliases = node.scopeName.split("/")
+        replacements = [
+            alias_to_name[alias] if alias in alias_to_name else alias.split(".")[-1]
+            for alias in module_aliases
+        ]
+        node.scopeName = base_name
+        if any(replacements):
+            node.scopeName += "/" + "/".join(replacements)
+
+    nodes_py.populate_namespace_from_OP_to_IO()
+    return nodes_py.to_proto()
+
+
+def graph(model, args, verbose=False, use_strict_trace=True):
+    """
+    Process a PyTorch model and produces a `GraphDef` proto that can be logged to TensorBoard.
+
+    Args:
+      model (PyTorch module): The model to be parsed.
+      args (tuple): input tensor[s] for the model.
+      verbose (bool): Whether to print out verbose information while
+        processing.
+      use_strict_trace (bool): Whether to pass keyword argument `strict` to
+        `torch.jit.trace`. Pass False when you want the tracer to
+        record your mutable container types (list, dict)
+    """
+    with _set_model_to_eval(model):
+        try:
+            trace = torch.jit.trace(model, args, strict=use_strict_trace)
+            graph = trace.graph
+            torch._C._jit_pass_inline(graph)
+        except RuntimeError as e:
+            print(e)
+            print("Error occurs, No graph saved")
+            raise e
+
+    if verbose:
+        print(graph)
+    list_of_nodes = parse(graph, trace, args)
+    # We are hardcoding that this was run on CPU even though it might have actually
+    # run on GPU. Note this is what is shown in TensorBoard and has no bearing
+    # on actual execution.
+    # TODO: See if we can extract GPU vs CPU information from the PyTorch model
+    # and pass it correctly to TensorBoard.
+    #
+    # Definition of StepStats and DeviceStepStats can be found at
+    # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/graph/tf_graph_common/proto.ts
+    # and
+    # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/step_stats.proto
+    stepstats = RunMetadata(
+        step_stats=StepStats(dev_stats=[DeviceStepStats(device="/device:CPU:0")])
+    )
+    return GraphDef(node=list_of_nodes, versions=VersionDef(producer=22)), stepstats
+    # The producer version has been reverse engineered from standard
+    # TensorBoard logged data.
+
+
+@contextlib.contextmanager
+def _set_model_to_eval(model):
+    """Context manager to temporarily set the training mode of ``model`` to eval."""
+    if not isinstance(model, torch.jit.ScriptFunction):
+        originally_training = model.training
+        model.train(False)
+        try:
+            yield
+        finally:
+            model.train(originally_training)
+    else:
+        # Do nothing for ScriptFunction
+        try:
+            yield
+        finally:
+            pass
+
+
+def _node_get(node: torch._C.Node, key: str):
+    """Get attributes of a node which is polymorphic over return type."""
+    sel = node.kindOf(key)
+    return getattr(node, sel)(key)
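+
+# Editor's note (illustrative, not upstream): kindOf() returns the one-letter
+# selector for the attribute's type ("i", "f", "s", "t", ...), so for a string
+# attribute this dispatches to node.s(key), e.g.
+# _node_get(getattr_node, "name") is equivalent to getattr_node.s("name").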
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/_utils.py b/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8acaf1696cb1f253fb4307e91e1fa088278334f0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/_utils.py
@@ -0,0 +1,127 @@
+# mypy: allow-untyped-defs
+import numpy as np
+import numpy.typing as npt
+
+
+# Functions for converting
+def figure_to_image(figures, close=True):
+    """Render matplotlib figure to numpy format.
+
+    Note that this requires the ``matplotlib`` package.
+
+    Args:
+        figures (matplotlib.pyplot.figure or list of figures): figure or a list of figures
+        close (bool): Flag to automatically close the figure
+
+    Returns:
+        numpy.array: image in [CHW] order
+    """
+    import matplotlib.pyplot as plt
+    import matplotlib.backends.backend_agg as plt_backend_agg
+
+    def render_to_rgb(figure):
+        canvas = plt_backend_agg.FigureCanvasAgg(figure)
+        canvas.draw()
+        data: npt.NDArray = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8)
+        w, h = figure.canvas.get_width_height()
+        image_hwc = data.reshape([h, w, 4])[:, :, 0:3]
+        image_chw = np.moveaxis(image_hwc, source=2, destination=0)
+        if close:
+            plt.close(figure)
+        return image_chw
+
+    if isinstance(figures, list):
+        images = [render_to_rgb(figure) for figure in figures]
+        return np.stack(images)
+    else:
+        image = render_to_rgb(figures)
+        return image
+
+
+def _prepare_video(V):
+    """
+    Convert a 5D tensor into 4D tensor.
+
+    Conversion is done from [batchsize, time(frame), channel(color), height, width] (5D tensor)
+    to [time(frame), new_height, new_width, channel] (4D tensor).
+
+    A batch of images is spread out onto a grid, which forms one frame.
+    e.g. a video with batch size 16 is laid out on a 4x4 grid.
+    """
+    b, t, c, h, w = V.shape
+
+    if V.dtype == np.uint8:
+        V = np.float32(V) / 255.0
+
+    def is_power2(num):
+        return num != 0 and ((num & (num - 1)) == 0)
+
+    # pad to nearest power of 2, all at once
+    if not is_power2(V.shape[0]):
+        len_addition = int(2 ** V.shape[0].bit_length() - V.shape[0])
+        V = np.concatenate((V, np.zeros(shape=(len_addition, t, c, h, w))), axis=0)
+
+    n_rows = 2 ** ((b.bit_length() - 1) // 2)
+    n_cols = V.shape[0] // n_rows
+
+    V = np.reshape(V, newshape=(n_rows, n_cols, t, c, h, w))
+    V = np.transpose(V, axes=(2, 0, 4, 1, 5, 3))
+    V = np.reshape(V, newshape=(t, n_rows * h, n_cols * w, c))
+
+    return V
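+
+# Shape example (editor's illustration, not upstream code): 16 clips of 10
+# frames at 3x32x32 are tiled into a 4x4 grid per frame.
+# >>> _prepare_video(np.zeros((16, 10, 3, 32, 32), dtype=np.float32)).shape
+# (10, 128, 128, 3)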
+
+
+def make_grid(I, ncols=8):
+    # I: N1HW or N3HW
+    assert isinstance(I, np.ndarray), "plugin error, should pass numpy array here"
+    if I.shape[1] == 1:
+        I = np.concatenate([I, I, I], 1)
+    assert I.ndim == 4 and I.shape[1] == 3
+    nimg = I.shape[0]
+    H = I.shape[2]
+    W = I.shape[3]
+    ncols = min(nimg, ncols)
+    nrows = int(np.ceil(float(nimg) / ncols))
+    canvas = np.zeros((3, H * nrows, W * ncols), dtype=I.dtype)
+    i = 0
+    for y in range(nrows):
+        for x in range(ncols):
+            if i >= nimg:
+                break
+            canvas[:, y * H : (y + 1) * H, x * W : (x + 1) * W] = I[i]
+            i = i + 1
+    return canvas
+
+    # if modality == 'IMG':
+    #     if x.dtype == np.uint8:
+    #         x = x.astype(np.float32) / 255.0
+
+
+def convert_to_HWC(tensor, input_format):  # tensor: numpy array
+    assert len(set(input_format)) == len(
+        input_format
+    ), f"You cannot use the same dimension shorthand twice. input_format: {input_format}"
+    assert len(tensor.shape) == len(
+        input_format
+    ), f"size of input tensor and input format are different. tensor shape: {tensor.shape}, input_format: {input_format}"
+    input_format = input_format.upper()
+
+    if len(input_format) == 4:
+        index = [input_format.find(c) for c in "NCHW"]
+        tensor_NCHW = tensor.transpose(index)
+        tensor_CHW = make_grid(tensor_NCHW)
+        return tensor_CHW.transpose(1, 2, 0)
+
+    if len(input_format) == 3:
+        index = [input_format.find(c) for c in "HWC"]
+        tensor_HWC = tensor.transpose(index)
+        if tensor_HWC.shape[2] == 1:
+            tensor_HWC = np.concatenate([tensor_HWC, tensor_HWC, tensor_HWC], 2)
+        return tensor_HWC
+
+    if len(input_format) == 2:
+        index = [input_format.find(c) for c in "HW"]
+        tensor = tensor.transpose(index)
+        tensor = np.stack([tensor, tensor, tensor], 2)
+        return tensor
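+
+# Example (editor's illustration, not upstream code): a single-channel "chw"
+# image is transposed to HWC and the channel is replicated to 3 (RGB).
+# >>> convert_to_HWC(np.zeros((1, 8, 6)), "chw").shape
+# (8, 6, 3)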
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/summary.py b/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/summary.py
new file mode 100644
index 0000000000000000000000000000000000000000..3fca4d9b7e66c7f6a3802e79a1980ecc2e77097c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/summary.py
@@ -0,0 +1,982 @@
+# mypy: allow-untyped-defs
+import json
+import logging
+import os
+import struct
+
+from typing import Any, Optional
+
+import torch
+import numpy as np
+
+from google.protobuf import struct_pb2
+
+from tensorboard.compat.proto.summary_pb2 import (
+    HistogramProto,
+    Summary,
+    SummaryMetadata,
+)
+from tensorboard.compat.proto.tensor_pb2 import TensorProto
+from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
+from tensorboard.plugins.custom_scalar import layout_pb2
+from tensorboard.plugins.pr_curve.plugin_data_pb2 import PrCurvePluginData
+from tensorboard.plugins.text.plugin_data_pb2 import TextPluginData
+
+from ._convert_np import make_np
+from ._utils import _prepare_video, convert_to_HWC
+
+__all__ = [
+    "half_to_int",
+    "int_to_half",
+    "hparams",
+    "scalar",
+    "histogram_raw",
+    "histogram",
+    "make_histogram",
+    "image",
+    "image_boxes",
+    "draw_boxes",
+    "make_image",
+    "video",
+    "make_video",
+    "audio",
+    "custom_scalars",
+    "text",
+    "tensor_proto",
+    "pr_curve_raw",
+    "pr_curve",
+    "compute_curve",
+    "mesh",
+]
+
+logger = logging.getLogger(__name__)
+
+def half_to_int(f: float) -> int:
+    """Casts a half-precision float value into an integer.
+
+    Converts a half precision floating point value, such as `torch.half` or
+    `torch.bfloat16`, into an integer value which can be written into the
+    half_val field of a TensorProto for storage.
+
+    To undo the effects of this conversion, use int_to_half().
+
+    """
+    buf = struct.pack("f", f)
+    return struct.unpack("i", buf)[0]
+
+def int_to_half(i: int) -> float:
+    """Casts an integer value to a half-precision float.
+
+    Converts an integer value obtained from half_to_int back into a floating
+    point value.
+
+    """
+    buf = struct.pack("i", i)
+    return struct.unpack("f", buf)[0]
+
+def _tensor_to_half_val(t: torch.Tensor) -> list[int]:
+    return [half_to_int(x) for x in t.flatten().tolist()]
+
+def _tensor_to_complex_val(t: torch.Tensor) -> list[float]:
+    return torch.view_as_real(t).flatten().tolist()
+
+def _tensor_to_list(t: torch.Tensor) -> list[Any]:
+    return t.flatten().tolist()
+
+# type map: torch.Tensor dtype -> (protobuf dtype name, proto value field, conversion fn)
+_TENSOR_TYPE_MAP = {
+    torch.half: ("DT_HALF", "half_val", _tensor_to_half_val),
+    torch.float16: ("DT_HALF", "half_val", _tensor_to_half_val),
+    torch.bfloat16: ("DT_BFLOAT16", "half_val", _tensor_to_half_val),
+    torch.float32: ("DT_FLOAT", "float_val", _tensor_to_list),
+    torch.float: ("DT_FLOAT", "float_val", _tensor_to_list),
+    torch.float64: ("DT_DOUBLE", "double_val", _tensor_to_list),
+    torch.double: ("DT_DOUBLE", "double_val", _tensor_to_list),
+    torch.int8: ("DT_INT8", "int_val", _tensor_to_list),
+    torch.uint8: ("DT_UINT8", "int_val", _tensor_to_list),
+    torch.qint8: ("DT_UINT8", "int_val", _tensor_to_list),
+    torch.int16: ("DT_INT16", "int_val", _tensor_to_list),
+    torch.short: ("DT_INT16", "int_val", _tensor_to_list),
+    torch.int: ("DT_INT32", "int_val", _tensor_to_list),
+    torch.int32: ("DT_INT32", "int_val", _tensor_to_list),
+    torch.qint32: ("DT_INT32", "int_val", _tensor_to_list),
+    torch.int64: ("DT_INT64", "int64_val", _tensor_to_list),
+    torch.complex32: ("DT_COMPLEX32", "scomplex_val", _tensor_to_complex_val),
+    torch.chalf: ("DT_COMPLEX32", "scomplex_val", _tensor_to_complex_val),
+    torch.complex64: ("DT_COMPLEX64", "scomplex_val", _tensor_to_complex_val),
+    torch.cfloat: ("DT_COMPLEX64", "scomplex_val", _tensor_to_complex_val),
+    torch.bool: ("DT_BOOL", "bool_val", _tensor_to_list),
+    torch.complex128: ("DT_COMPLEX128", "dcomplex_val", _tensor_to_complex_val),
+    torch.cdouble: ("DT_COMPLEX128", "dcomplex_val", _tensor_to_complex_val),
+    torch.uint8: ("DT_UINT8", "uint32_val", _tensor_to_list),
+    torch.quint8: ("DT_UINT8", "uint32_val", _tensor_to_list),
+    torch.quint4x2: ("DT_UINT8", "uint32_val", _tensor_to_list),
+}
+
+
+def _calc_scale_factor(tensor):
+    converted = tensor.numpy() if not isinstance(tensor, np.ndarray) else tensor
+    return 1 if converted.dtype == np.uint8 else 255
+
+
+def _draw_single_box(
+    image,
+    xmin,
+    ymin,
+    xmax,
+    ymax,
+    display_str,
+    color="black",
+    color_text="black",
+    thickness=2,
+):
+    from PIL import ImageDraw, ImageFont
+
+    font = ImageFont.load_default()
+    draw = ImageDraw.Draw(image)
+    (left, right, top, bottom) = (xmin, xmax, ymin, ymax)
+    draw.line(
+        [(left, top), (left, bottom), (right, bottom), (right, top), (left, top)],
+        width=thickness,
+        fill=color,
+    )
+    if display_str:
+        text_bottom = bottom
+        # Reverse list and print from bottom to top.
+        _left, _top, _right, _bottom = font.getbbox(display_str)
+        text_width, text_height = _right - _left, _bottom - _top
+        margin = np.ceil(0.05 * text_height)
+        draw.rectangle(
+            [
+                (left, text_bottom - text_height - 2 * margin),
+                (left + text_width, text_bottom),
+            ],
+            fill=color,
+        )
+        draw.text(
+            (left + margin, text_bottom - text_height - margin),
+            display_str,
+            fill=color_text,
+            font=font,
+        )
+    return image
+
+
+def hparams(hparam_dict=None, metric_dict=None, hparam_domain_discrete=None):
+    """Output three `Summary` protocol buffers needed by hparams plugin.
+
+    `Experiment` keeps the metadata of an experiment, such as the names of the
+      hyperparameters and the names of the metrics.
+    `SessionStartInfo` keeps key-value pairs of the hyperparameters.
+    `SessionEndInfo` describes the status of the experiment, e.g. STATUS_SUCCESS.
+
+    Args:
+      hparam_dict: A dictionary that contains names of the hyperparameters
+        and their values.
+      metric_dict: A dictionary that contains names of the metrics
+        and their values.
+      hparam_domain_discrete: (Optional[Dict[str, List[Any]]]) A dictionary that
+        contains names of the hyperparameters and all discrete values they can hold
+
+    Returns:
+      The `Summary` protobufs for Experiment, SessionStartInfo and
+        SessionEndInfo
+    """
+    import torch
+    from tensorboard.plugins.hparams.api_pb2 import (
+        DataType,
+        Experiment,
+        HParamInfo,
+        MetricInfo,
+        MetricName,
+        Status,
+    )
+    from tensorboard.plugins.hparams.metadata import (
+        EXPERIMENT_TAG,
+        PLUGIN_DATA_VERSION,
+        PLUGIN_NAME,
+        SESSION_END_INFO_TAG,
+        SESSION_START_INFO_TAG,
+    )
+    from tensorboard.plugins.hparams.plugin_data_pb2 import (
+        HParamsPluginData,
+        SessionEndInfo,
+        SessionStartInfo,
+    )
+
+    # TODO: expose other parameters in the future.
+    # hp = HParamInfo(name='lr',display_name='learning rate',
+    # type=DataType.DATA_TYPE_FLOAT64, domain_interval=Interval(min_value=10,
+    # max_value=100))
+    # mt = MetricInfo(name=MetricName(tag='accuracy'), display_name='accuracy',
+    # description='', dataset_type=DatasetType.DATASET_VALIDATION)
+    # exp = Experiment(name='123', description='456', time_created_secs=100.0,
+    # hparam_infos=[hp], metric_infos=[mt], user='tw')
+
+    if not isinstance(hparam_dict, dict):
+        logger.warning("parameter: hparam_dict should be a dictionary, nothing logged.")
+        raise TypeError(
+            "parameter: hparam_dict should be a dictionary, nothing logged."
+        )
+    if not isinstance(metric_dict, dict):
+        logger.warning("parameter: metric_dict should be a dictionary, nothing logged.")
+        raise TypeError(
+            "parameter: metric_dict should be a dictionary, nothing logged."
+        )
+
+    hparam_domain_discrete = hparam_domain_discrete or {}
+    if not isinstance(hparam_domain_discrete, dict):
+        raise TypeError(
+            "parameter: hparam_domain_discrete should be a dictionary, nothing logged."
+        )
+    for k, v in hparam_domain_discrete.items():
+        if (
+            k not in hparam_dict
+            or not isinstance(v, list)
+            or not all(isinstance(d, type(hparam_dict[k])) for d in v)
+        ):
+            raise TypeError(
+                f"parameter: hparam_domain_discrete[{k}] should be a list of same type as hparam_dict[{k}]."
+            )
+    hps = []
+
+    ssi = SessionStartInfo()
+    for k, v in hparam_dict.items():
+        if v is None:
+            continue
+        if isinstance(v, (int, float)):
+            ssi.hparams[k].number_value = v
+
+            if k in hparam_domain_discrete:
+                domain_discrete: Optional[struct_pb2.ListValue] = struct_pb2.ListValue(
+                    values=[
+                        struct_pb2.Value(number_value=d)
+                        for d in hparam_domain_discrete[k]
+                    ]
+                )
+            else:
+                domain_discrete = None
+
+            hps.append(
+                HParamInfo(
+                    name=k,
+                    type=DataType.Value("DATA_TYPE_FLOAT64"),
+                    domain_discrete=domain_discrete,
+                )
+            )
+            continue
+
+        if isinstance(v, str):
+            ssi.hparams[k].string_value = v
+
+            if k in hparam_domain_discrete:
+                domain_discrete = struct_pb2.ListValue(
+                    values=[
+                        struct_pb2.Value(string_value=d)
+                        for d in hparam_domain_discrete[k]
+                    ]
+                )
+            else:
+                domain_discrete = None
+
+            hps.append(
+                HParamInfo(
+                    name=k,
+                    type=DataType.Value("DATA_TYPE_STRING"),
+                    domain_discrete=domain_discrete,
+                )
+            )
+            continue
+
+        if isinstance(v, bool):
+            ssi.hparams[k].bool_value = v
+
+            if k in hparam_domain_discrete:
+                domain_discrete = struct_pb2.ListValue(
+                    values=[
+                        struct_pb2.Value(bool_value=d)
+                        for d in hparam_domain_discrete[k]
+                    ]
+                )
+            else:
+                domain_discrete = None
+
+            hps.append(
+                HParamInfo(
+                    name=k,
+                    type=DataType.Value("DATA_TYPE_BOOL"),
+                    domain_discrete=domain_discrete,
+                )
+            )
+            continue
+
+        if isinstance(v, torch.Tensor):
+            v = make_np(v)[0]
+            ssi.hparams[k].number_value = v
+            hps.append(HParamInfo(name=k, type=DataType.Value("DATA_TYPE_FLOAT64")))
+            continue
+        raise ValueError(
+            "value should be one of int, float, str, bool, or torch.Tensor"
+        )
+
+    content = HParamsPluginData(session_start_info=ssi, version=PLUGIN_DATA_VERSION)
+    smd = SummaryMetadata(
+        plugin_data=SummaryMetadata.PluginData(
+            plugin_name=PLUGIN_NAME, content=content.SerializeToString()
+        )
+    )
+    ssi = Summary(value=[Summary.Value(tag=SESSION_START_INFO_TAG, metadata=smd)])
+
+    mts = [MetricInfo(name=MetricName(tag=k)) for k in metric_dict.keys()]
+
+    exp = Experiment(hparam_infos=hps, metric_infos=mts)
+
+    content = HParamsPluginData(experiment=exp, version=PLUGIN_DATA_VERSION)
+    smd = SummaryMetadata(
+        plugin_data=SummaryMetadata.PluginData(
+            plugin_name=PLUGIN_NAME, content=content.SerializeToString()
+        )
+    )
+    exp = Summary(value=[Summary.Value(tag=EXPERIMENT_TAG, metadata=smd)])
+
+    sei = SessionEndInfo(status=Status.Value("STATUS_SUCCESS"))
+    content = HParamsPluginData(session_end_info=sei, version=PLUGIN_DATA_VERSION)
+    smd = SummaryMetadata(
+        plugin_data=SummaryMetadata.PluginData(
+            plugin_name=PLUGIN_NAME, content=content.SerializeToString()
+        )
+    )
+    sei = Summary(value=[Summary.Value(tag=SESSION_END_INFO_TAG, metadata=smd)])
+
+    return exp, ssi, sei
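+
+# Usage sketch (editor's addition, not upstream code): the three returned
+# summaries are what SummaryWriter.add_hparams() writes for a single run.
+# >>> exp, ssi, sei = hparams({"lr": 0.01, "optimizer": "adam"},
+# ...                         {"hparam/accuracy": 0.0})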
+
+
+def scalar(name, tensor, collections=None, new_style=False, double_precision=False):
+    """Output a `Summary` protocol buffer containing a single scalar value.
+
+    The generated Summary has a Tensor.proto containing the input Tensor.
+    Args:
+      name: A name for the generated node. Will also serve as the series name in
+        TensorBoard.
+      tensor: A real numeric Tensor containing a single value.
+      collections: Optional list of graph collections keys. The new summary op is
+        added to these collections. Defaults to `[GraphKeys.SUMMARIES]`.
+      new_style: Whether to use the new style (tensor field) or the old style
+        (simple_value field). The new style can lead to faster data loading.
+      double_precision: When `new_style` is enabled, store the value as
+        DT_DOUBLE instead of DT_FLOAT.
+    Returns:
+      A scalar `Tensor` of type `string`, which contains a `Summary` protobuf.
+    Raises:
+      ValueError: If tensor has the wrong shape or type.
+    """
+    tensor = make_np(tensor).squeeze()
+    assert (
+        tensor.ndim == 0
+    ), f"Tensor should contain one element (0 dimensions). Was given size: {tensor.size} and {tensor.ndim} dimensions."
+    # python float is double precision in numpy
+    scalar = float(tensor)
+    if new_style:
+        tensor_proto = TensorProto(float_val=[scalar], dtype="DT_FLOAT")
+        if double_precision:
+            tensor_proto = TensorProto(double_val=[scalar], dtype="DT_DOUBLE")
+
+        plugin_data = SummaryMetadata.PluginData(plugin_name="scalars")
+        smd = SummaryMetadata(plugin_data=plugin_data)
+        return Summary(
+            value=[
+                Summary.Value(
+                    tag=name,
+                    tensor=tensor_proto,
+                    metadata=smd,
+                )
+            ]
+        )
+    else:
+        return Summary(value=[Summary.Value(tag=name, simple_value=scalar)])
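+
+# Usage sketch (editor's addition, not upstream code): both styles carry the
+# same value, either in simple_value or in a TensorProto.
+# >>> s_old = scalar("loss", torch.tensor(0.25))
+# >>> s_new = scalar("loss", torch.tensor(0.25), new_style=True)
+# >>> s_old.value[0].simple_value, s_new.value[0].tensor.float_val[0]
+# (0.25, 0.25)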
+
+
+def tensor_proto(tag, tensor):
+    """Outputs a `Summary` protocol buffer containing the full tensor.
+    The generated Summary has a Tensor.proto containing the input Tensor.
+    Args:
+      tag: A name for the generated node. Will also serve as the series name in
+        TensorBoard.
+      tensor: Tensor to be converted to protobuf
+    Returns:
+      A tensor protobuf in a `Summary` protobuf.
+    Raises:
+      ValueError: If tensor is too big to be converted to protobuf, or
+                     tensor data type is not supported
+    """
+    if tensor.numel() * tensor.itemsize >= (1 << 31):
+        raise ValueError(
+            "tensor is bigger than protocol buffer's hard limit of 2GB in size"
+        )
+
+    if tensor.dtype in _TENSOR_TYPE_MAP:
+        dtype, field_name, conversion_fn = _TENSOR_TYPE_MAP[tensor.dtype]
+        tensor_proto = TensorProto(
+            **{
+                "dtype": dtype,
+                "tensor_shape": TensorShapeProto(
+                    dim=[TensorShapeProto.Dim(size=x) for x in tensor.shape]
+                ),
+                field_name: conversion_fn(tensor),
+            },
+        )
+    else:
+        raise ValueError(f"{tag} has unsupported tensor dtype {tensor.dtype}")
+
+    plugin_data = SummaryMetadata.PluginData(plugin_name="tensor")
+    smd = SummaryMetadata(plugin_data=plugin_data)
+    return Summary(value=[Summary.Value(tag=tag, metadata=smd, tensor=tensor_proto)])
+
+
+def histogram_raw(name, min, max, num, sum, sum_squares, bucket_limits, bucket_counts):
+    # pylint: disable=line-too-long
+    """Output a `Summary` protocol buffer with a histogram.
+
+    The generated
+    [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto)
+    has one summary value containing a histogram for `values`.
+    Args:
+      name: A name for the generated node. Will also serve as a series name in
+        TensorBoard.
+      min: A float or int min value
+      max: A float or int max value
+      num: Int number of values
+      sum: Float or int sum of all values
+      sum_squares: Float or int sum of squares for all values
+      bucket_limits: A numeric `Tensor` with upper value per bucket
+      bucket_counts: A numeric `Tensor` with number of values per bucket
+    Returns:
+      A scalar `Tensor` of type `string`. The serialized `Summary` protocol
+      buffer.
+    """
+    hist = HistogramProto(
+        min=min,
+        max=max,
+        num=num,
+        sum=sum,
+        sum_squares=sum_squares,
+        bucket_limit=bucket_limits,
+        bucket=bucket_counts,
+    )
+    return Summary(value=[Summary.Value(tag=name, histo=hist)])
+
+
+def histogram(name, values, bins, max_bins=None):
+    # pylint: disable=line-too-long
+    """Output a `Summary` protocol buffer with a histogram.
+
+    The generated
+    [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto)
+    has one summary value containing a histogram for `values`.
+    This op reports an `InvalidArgument` error if any value is not finite.
+    Args:
+      name: A name for the generated node. Will also serve as a series name in
+        TensorBoard.
+      values: A real numeric `Tensor`. Any shape. Values to use to
+        build the histogram.
+      bins: Binning specification forwarded to `numpy.histogram` (a bin count,
+        a sequence of bin edges, or a strategy name such as 'auto').
+      max_bins: Optional upper bound on the number of bins; if exceeded, the
+        histogram is subsampled to respect it.
+    Returns:
+      A scalar `Tensor` of type `string`. The serialized `Summary` protocol
+      buffer.
+    """
+    values = make_np(values)
+    hist = make_histogram(values.astype(float), bins, max_bins)
+    return Summary(value=[Summary.Value(tag=name, histo=hist)])
+
+
+def make_histogram(values, bins, max_bins=None):
+    """Convert values into a histogram proto using logic from histogram.cc."""
+    if values.size == 0:
+        raise ValueError("The input has no element.")
+    values = values.reshape(-1)
+    counts, limits = np.histogram(values, bins=bins)
+    num_bins = len(counts)
+    if max_bins is not None and num_bins > max_bins:
+        subsampling = num_bins // max_bins
+        subsampling_remainder = num_bins % subsampling
+        if subsampling_remainder != 0:
+            counts = np.pad(
+                counts,
+                pad_width=[[0, subsampling - subsampling_remainder]],
+                mode="constant",
+                constant_values=0,
+            )
+        counts = counts.reshape(-1, subsampling).sum(axis=-1)
+        new_limits = np.empty((counts.size + 1,), limits.dtype)
+        new_limits[:-1] = limits[:-1:subsampling]
+        new_limits[-1] = limits[-1]
+        limits = new_limits
+
+    # Find the first and the last bin defining the support of the histogram:
+
+    cum_counts = np.cumsum(np.greater(counts, 0))
+    start, end = np.searchsorted(cum_counts, [0, cum_counts[-1] - 1], side="right")
+    start = int(start)
+    end = int(end) + 1
+    del cum_counts
+
+    # TensorBoard only includes the right bin limits. To still have the leftmost limit
+    # included, we include an empty bin on the left.
+    # If start == 0, we need to add an empty bin on the left; otherwise we can just include
+    # the bin to the left of the first nonzero-count bin:
+    counts = (
+        counts[start - 1 : end] if start > 0 else np.concatenate([[0], counts[:end]])
+    )
+    limits = limits[start : end + 1]
+
+    if counts.size == 0 or limits.size == 0:
+        raise ValueError("The histogram is empty, please file a bug report.")
+
+    sum_sq = values.dot(values)
+    return HistogramProto(
+        min=values.min(),
+        max=values.max(),
+        num=len(values),
+        sum=values.sum(),
+        sum_squares=sum_sq,
+        bucket_limit=limits.tolist(),
+        bucket=counts.tolist(),
+    )
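+
+# Example (editor's illustration, not upstream code): four values in two bins;
+# an empty bucket is prepended so TensorBoard keeps the leftmost edge visible.
+# >>> h = make_histogram(np.array([0.0, 0.25, 0.75, 1.0]), bins=2)
+# >>> list(h.bucket), list(h.bucket_limit)
+# ([0.0, 2.0, 2.0], [0.0, 0.5, 1.0])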
+
+
+def image(tag, tensor, rescale=1, dataformats="NCHW"):
+    """Output a `Summary` protocol buffer with images.
+
+    The summary has up to `max_images` summary values containing images. The
+    images are built from `tensor` which must be 3-D with shape `[height, width,
+    channels]` and where `channels` can be:
+    *  1: `tensor` is interpreted as Grayscale.
+    *  3: `tensor` is interpreted as RGB.
+    *  4: `tensor` is interpreted as RGBA.
+    The `name` in the outputted Summary.Value protobufs is generated based on the
+    name, with a suffix depending on the max_outputs setting:
+    *  If `max_outputs` is 1, the summary value tag is '*name*/image'.
+    *  If `max_outputs` is greater than 1, the summary value tags are
+       generated sequentially as '*name*/image/0', '*name*/image/1', etc.
+    Args:
+      tag: A name for the generated node. Will also serve as a series name in
+        TensorBoard.
+      tensor: A 3-D `uint8` or `float32` `Tensor` of shape `[height, width,
+        channels]` where `channels` is 1, 3, or 4.
+        'tensor' can either have values in [0, 1] (float32) or [0, 255] (uint8).
+        The image() function will scale the image values to [0, 255] by applying
+        a scale factor of either 1 (uint8) or 255 (float32). Out-of-range values
+        will be clipped.
+    Returns:
+      A scalar `Tensor` of type `string`. The serialized `Summary` protocol
+      buffer.
+    """
+    tensor = make_np(tensor)
+    tensor = convert_to_HWC(tensor, dataformats)
+    # Do not assume that user passes in values in [0, 255], use data type to detect
+    scale_factor = _calc_scale_factor(tensor)
+    tensor = tensor.astype(np.float32)
+    tensor = (tensor * scale_factor).clip(0, 255).astype(np.uint8)
+    image = make_image(tensor, rescale=rescale)
+    return Summary(value=[Summary.Value(tag=tag, image=image)])
+
+
+def image_boxes(
+    tag, tensor_image, tensor_boxes, rescale=1, dataformats="CHW", labels=None
+):
+    """Output a `Summary` protocol buffer with images."""
+    tensor_image = make_np(tensor_image)
+    tensor_image = convert_to_HWC(tensor_image, dataformats)
+    tensor_boxes = make_np(tensor_boxes)
+    tensor_image = tensor_image.astype(np.float32) * _calc_scale_factor(tensor_image)
+    image = make_image(
+        tensor_image.clip(0, 255).astype(np.uint8),
+        rescale=rescale,
+        rois=tensor_boxes,
+        labels=labels,
+    )
+    return Summary(value=[Summary.Value(tag=tag, image=image)])
+
+
+def draw_boxes(disp_image, boxes, labels=None):
+    # xyxy format
+    num_boxes = boxes.shape[0]
+    list_gt = range(num_boxes)
+    for i in list_gt:
+        disp_image = _draw_single_box(
+            disp_image,
+            boxes[i, 0],
+            boxes[i, 1],
+            boxes[i, 2],
+            boxes[i, 3],
+            display_str=None if labels is None else labels[i],
+            color="Red",
+        )
+    return disp_image
+
+
+def make_image(tensor, rescale=1, rois=None, labels=None):
+    """Convert a numpy representation of an image to Image protobuf."""
+    from PIL import Image
+
+    height, width, channel = tensor.shape
+    scaled_height = int(height * rescale)
+    scaled_width = int(width * rescale)
+    image = Image.fromarray(tensor)
+    if rois is not None:
+        image = draw_boxes(image, rois, labels=labels)
+    ANTIALIAS = Image.Resampling.LANCZOS
+    image = image.resize((scaled_width, scaled_height), ANTIALIAS)
+    import io
+
+    output = io.BytesIO()
+    image.save(output, format="PNG")
+    image_string = output.getvalue()
+    output.close()
+    return Summary.Image(
+        height=height,
+        width=width,
+        colorspace=channel,
+        encoded_image_string=image_string,
+    )
+
+
+def video(tag, tensor, fps=4):
+    tensor = make_np(tensor)
+    tensor = _prepare_video(tensor)
+    # If user passes in uint8, then we don't need to rescale by 255
+    scale_factor = _calc_scale_factor(tensor)
+    tensor = tensor.astype(np.float32)
+    tensor = (tensor * scale_factor).clip(0, 255).astype(np.uint8)
+    video = make_video(tensor, fps)
+    return Summary(value=[Summary.Value(tag=tag, image=video)])
+
+
+def make_video(tensor, fps):
+    try:
+        import moviepy  # noqa: F401
+    except ImportError:
+        print("add_video needs package moviepy")
+        return
+    try:
+        from moviepy import editor as mpy
+    except ImportError:
+        print(
+            "moviepy is installed, but can't import moviepy.editor.",
+            "Some packages could be missing [imageio, requests]",
+        )
+        return
+    import tempfile
+
+    _t, h, w, c = tensor.shape
+
+    # encode sequence of images into gif string
+    clip = mpy.ImageSequenceClip(list(tensor), fps=fps)
+
+    filename = tempfile.NamedTemporaryFile(suffix=".gif", delete=False).name
+    try:  # newer versions of moviepy use a logger instead of the progress_bar argument.
+        clip.write_gif(filename, verbose=False, logger=None)
+    except TypeError:
+        try:  # older versions of moviepy do not support the progress_bar argument.
+            clip.write_gif(filename, verbose=False, progress_bar=False)
+        except TypeError:
+            clip.write_gif(filename, verbose=False)
+
+    with open(filename, "rb") as f:
+        tensor_string = f.read()
+
+    try:
+        os.remove(filename)
+    except OSError:
+        logger.warning("The temporary file used by moviepy cannot be deleted.")
+
+    return Summary.Image(
+        height=h, width=w, colorspace=c, encoded_image_string=tensor_string
+    )
+
+
+def audio(tag, tensor, sample_rate=44100):
+    array = make_np(tensor)
+    array = array.squeeze()
+    if abs(array).max() > 1:
+        print("warning: audio amplitude out of range, auto clipped.")
+        array = array.clip(-1, 1)
+    assert array.ndim == 1, "input tensor should be 1 dimensional."
+    array = (array * np.iinfo(np.int16).max).astype("<i2")
+    # [custom_scalars(), text(), and the remainder of audio() are elided from this diff]
+
+
+def pr_curve_raw(
+    tag, tp, fp, tn, fn, precision, recall, num_thresholds=127, weights=None
+):
+    if num_thresholds > 127:  # weird, value > 127 breaks protobuf
+        num_thresholds = 127
+    data = np.stack((tp, fp, tn, fn, precision, recall))
+    pr_curve_plugin_data = PrCurvePluginData(
+        version=0, num_thresholds=num_thresholds
+    ).SerializeToString()
+    plugin_data = SummaryMetadata.PluginData(
+        plugin_name="pr_curves", content=pr_curve_plugin_data
+    )
+    smd = SummaryMetadata(plugin_data=plugin_data)
+    tensor = TensorProto(
+        dtype="DT_FLOAT",
+        float_val=data.reshape(-1).tolist(),
+        tensor_shape=TensorShapeProto(
+            dim=[
+                TensorShapeProto.Dim(size=data.shape[0]),
+                TensorShapeProto.Dim(size=data.shape[1]),
+            ]
+        ),
+    )
+    return Summary(value=[Summary.Value(tag=tag, metadata=smd, tensor=tensor)])
+
+
+def pr_curve(tag, labels, predictions, num_thresholds=127, weights=None):
+    # weird, value > 127 breaks protobuf
+    num_thresholds = min(num_thresholds, 127)
+    data = compute_curve(
+        labels, predictions, num_thresholds=num_thresholds, weights=weights
+    )
+    pr_curve_plugin_data = PrCurvePluginData(
+        version=0, num_thresholds=num_thresholds
+    ).SerializeToString()
+    plugin_data = SummaryMetadata.PluginData(
+        plugin_name="pr_curves", content=pr_curve_plugin_data
+    )
+    smd = SummaryMetadata(plugin_data=plugin_data)
+    tensor = TensorProto(
+        dtype="DT_FLOAT",
+        float_val=data.reshape(-1).tolist(),
+        tensor_shape=TensorShapeProto(
+            dim=[
+                TensorShapeProto.Dim(size=data.shape[0]),
+                TensorShapeProto.Dim(size=data.shape[1]),
+            ]
+        ),
+    )
+    return Summary(value=[Summary.Value(tag=tag, metadata=smd, tensor=tensor)])
+
+
+# https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/pr_curve/summary.py
+def compute_curve(labels, predictions, num_thresholds=None, weights=None):
+    _MINIMUM_COUNT = 1e-7
+
+    if weights is None:
+        weights = 1.0
+
+    # Compute bins of true positives and false positives.
+    bucket_indices = np.int32(np.floor(predictions * (num_thresholds - 1)))
+    float_labels = labels.astype(np.float64)
+    histogram_range = (0, num_thresholds - 1)
+    tp_buckets, _ = np.histogram(
+        bucket_indices,
+        bins=num_thresholds,
+        range=histogram_range,
+        weights=float_labels * weights,
+    )
+    fp_buckets, _ = np.histogram(
+        bucket_indices,
+        bins=num_thresholds,
+        range=histogram_range,
+        weights=(1.0 - float_labels) * weights,
+    )
+
+    # Obtain the reverse cumulative sum.
+    tp = np.cumsum(tp_buckets[::-1])[::-1]
+    fp = np.cumsum(fp_buckets[::-1])[::-1]
+    tn = fp[0] - fp
+    fn = tp[0] - tp
+    precision = tp / np.maximum(_MINIMUM_COUNT, tp + fp)
+    recall = tp / np.maximum(_MINIMUM_COUNT, tp + fn)
+    return np.stack((tp, fp, tn, fn, precision, recall))
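+
+# Worked example (editor's illustration, not upstream code): two predictions,
+# one per class, with num_thresholds=3.  Index 0 corresponds to threshold 0
+# (everything predicted positive), so recall starts at 1.0.
+# >>> tp, fp, tn, fn, prec, rec = compute_curve(
+# ...     np.array([0, 1]), np.array([0.2, 0.8]), num_thresholds=3)
+# >>> prec.tolist(), rec.tolist()
+# ([0.5, 1.0, 0.0], [1.0, 1.0, 0.0])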
+
+
+def _get_tensor_summary(
+    name, display_name, description, tensor, content_type, components, json_config
+):
+    """Create a tensor summary with summary metadata.
+
+    Args:
+      name: Uniquely identifiable name of the summary op. Could be replaced by
+        combination of name and type to make it unique even outside of this
+        summary.
+      display_name: Will be used as the display name in TensorBoard.
+        Defaults to `name`.
+      description: A longform readable description of the summary data. Markdown
+        is supported.
+      tensor: Tensor to display in summary.
+      content_type: Type of content inside the Tensor.
+      components: Bitmask representing present parts (vertices, colors, etc.) that
+        belong to the summary.
+      json_config: A string, JSON-serialized dictionary of ThreeJS classes
+        configuration.
+
+    Returns:
+      Tensor summary with metadata.
+    """
+    import torch
+    from tensorboard.plugins.mesh import metadata
+
+    tensor = torch.as_tensor(tensor)
+
+    tensor_metadata = metadata.create_summary_metadata(
+        name,
+        display_name,
+        content_type,
+        components,
+        tensor.shape,
+        description,
+        json_config=json_config,
+    )
+
+    tensor = TensorProto(
+        dtype="DT_FLOAT",
+        float_val=tensor.reshape(-1).tolist(),
+        tensor_shape=TensorShapeProto(
+            dim=[
+                TensorShapeProto.Dim(size=tensor.shape[0]),
+                TensorShapeProto.Dim(size=tensor.shape[1]),
+                TensorShapeProto.Dim(size=tensor.shape[2]),
+            ]
+        ),
+    )
+
+    tensor_summary = Summary.Value(
+        tag=metadata.get_instance_name(name, content_type),
+        tensor=tensor,
+        metadata=tensor_metadata,
+    )
+
+    return tensor_summary
+
+
+def _get_json_config(config_dict):
+    """Parse and returns JSON string from python dictionary."""
+    json_config = "{}"
+    if config_dict is not None:
+        json_config = json.dumps(config_dict, sort_keys=True)
+    return json_config
+
+
+# https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/mesh/summary.py
+def mesh(
+    tag, vertices, colors, faces, config_dict, display_name=None, description=None
+):
+    """Output a merged `Summary` protocol buffer with a mesh/point cloud.
+
+    Args:
+      tag: A name for this summary operation.
+      vertices: Tensor of shape `[dim_1, ..., dim_n, 3]` representing the 3D
+        coordinates of vertices.
+      colors: Tensor of shape `[dim_1, ..., dim_n, 3]` containing colors for each
+        vertex.
+      faces: Tensor of shape `[dim_1, ..., dim_n, 3]` containing indices of
+        vertices within each triangle.
+      display_name: If set, will be used as the display name in TensorBoard.
+        Defaults to `name`.
+      description: A longform readable description of the summary data. Markdown
+        is supported.
+      config_dict: Dictionary with ThreeJS classes names and configuration.
+
+    Returns:
+      Merged summary for mesh/point cloud representation.
+    """
+    from tensorboard.plugins.mesh import metadata
+    from tensorboard.plugins.mesh.plugin_data_pb2 import MeshPluginData
+
+    json_config = _get_json_config(config_dict)
+
+    summaries = []
+    tensors = [
+        (vertices, MeshPluginData.VERTEX),
+        (faces, MeshPluginData.FACE),
+        (colors, MeshPluginData.COLOR),
+    ]
+    tensors = [tensor for tensor in tensors if tensor[0] is not None]
+    components = metadata.get_components_bitmask(
+        [content_type for (tensor, content_type) in tensors]
+    )
+
+    for tensor, content_type in tensors:
+        summaries.append(
+            _get_tensor_summary(
+                tag,
+                display_name,
+                description,
+                tensor,
+                content_type,
+                components,
+                json_config,
+            )
+        )
+
+    return Summary(value=summaries)
diff --git a/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/writer.py b/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/writer.py
new file mode 100644
index 0000000000000000000000000000000000000000..129281cb8ac3eccf6491567c77113c88ccf76b92
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/torch/utils/tensorboard/writer.py
@@ -0,0 +1,1208 @@
+# mypy: allow-untyped-defs
+"""Provide an API for writing protocol buffers to event files to be consumed by TensorBoard for visualization."""
+
+import os
+import time
+from typing import Optional, TYPE_CHECKING, Union
+
+import torch
+
+if TYPE_CHECKING:
+    from matplotlib.figure import Figure
+from tensorboard.compat import tf
+from tensorboard.compat.proto import event_pb2
+from tensorboard.compat.proto.event_pb2 import Event, SessionLog
+from tensorboard.plugins.projector.projector_config_pb2 import ProjectorConfig
+from tensorboard.summary.writer.event_file_writer import EventFileWriter
+
+from ._convert_np import make_np
+from ._embedding import get_embedding_info, make_mat, make_sprite, make_tsv, write_pbtxt
+from ._onnx_graph import load_onnx_graph
+from ._pytorch_graph import graph
+from ._utils import figure_to_image
+from .summary import (
+    audio,
+    custom_scalars,
+    histogram,
+    histogram_raw,
+    hparams,
+    image,
+    image_boxes,
+    mesh,
+    pr_curve,
+    pr_curve_raw,
+    scalar,
+    tensor_proto,
+    text,
+    video,
+)
+
+__all__ = ["FileWriter", "SummaryWriter"]
+
+
+class FileWriter:
+    """Writes protocol buffers to event files to be consumed by TensorBoard.
+
+    The `FileWriter` class provides a mechanism to create an event file in a
+    given directory and add summaries and events to it. The class updates the
+    file contents asynchronously. This allows a training program to call methods
+    to add data to the file directly from the training loop, without slowing down
+    training.
+    """
+
+    def __init__(self, log_dir, max_queue=10, flush_secs=120, filename_suffix=""):
+        """Create a `FileWriter` and an event file.
+
+        On construction the writer creates a new event file in `log_dir`.
+        The other arguments to the constructor control the asynchronous writes to
+        the event file.
+
+        Args:
+          log_dir: A string. Directory where event file will be written.
+          max_queue: Integer. Size of the queue for pending events and
+            summaries before one of the 'add' calls forces a flush to disk.
+            Default is ten items.
+          flush_secs: Number. How often, in seconds, to flush the
+            pending events and summaries to disk. Default is every two minutes.
+          filename_suffix: A string. Suffix added to all event filenames
+            in the log_dir directory. More details on filename construction in
+            tensorboard.summary.writer.event_file_writer.EventFileWriter.
+        """
+        # Sometimes PosixPath is passed in and we need to coerce it to
+        # a string in all cases
+        # TODO: See if we can remove this in the future if we are
+        # actually the ones passing in a PosixPath
+        log_dir = str(log_dir)
+        self.event_writer = EventFileWriter(
+            log_dir, max_queue, flush_secs, filename_suffix
+        )
+
+    def get_logdir(self):
+        """Return the directory where event file will be written."""
+        return self.event_writer.get_logdir()
+
+    def add_event(self, event, step=None, walltime=None):
+        """Add an event to the event file.
+
+        Args:
+          event: An `Event` protocol buffer.
+          step: Number. Optional global step value for training process
+            to record with the event.
+          walltime: float. Optional walltime to override the default (current)
+            walltime (from time.time()) seconds after epoch
+        """
+        event.wall_time = time.time() if walltime is None else walltime
+        if step is not None:
+            # Make sure step is converted from numpy or other formats
+            # since protobuf might not convert depending on version
+            event.step = int(step)
+        self.event_writer.add_event(event)
+
+    def add_summary(self, summary, global_step=None, walltime=None):
+        """Add a `Summary` protocol buffer to the event file.
+
+        This method wraps the provided summary in an `Event` protocol buffer
+        and adds it to the event file.
+
+        Args:
+          summary: A `Summary` protocol buffer.
+          global_step: Number. Optional global step value for training process
+            to record with the summary.
+          walltime: float. Optional walltime to override the default (current)
+            walltime (from time.time()) seconds after epoch
+        """
+        event = event_pb2.Event(summary=summary)
+        self.add_event(event, global_step, walltime)
+
+    def add_graph(self, graph_profile, walltime=None):
+        """Add a `Graph` and step stats protocol buffer to the event file.
+
+        Args:
+          graph_profile: A `Graph` and step stats protocol buffer.
+          walltime: float. Optional walltime to override the default (current)
+            walltime (from time.time()) seconds after epoch
+        """
+        graph = graph_profile[0]
+        stepstats = graph_profile[1]
+        event = event_pb2.Event(graph_def=graph.SerializeToString())
+        self.add_event(event, None, walltime)
+
+        trm = event_pb2.TaggedRunMetadata(
+            tag="step1", run_metadata=stepstats.SerializeToString()
+        )
+        event = event_pb2.Event(tagged_run_metadata=trm)
+        self.add_event(event, None, walltime)
+
+    def add_onnx_graph(self, graph, walltime=None):
+        """Add a `Graph` protocol buffer to the event file.
+
+        Args:
+          graph: A `Graph` protocol buffer.
+          walltime: float. Optional walltime to override the default (current)
+            walltime (from time.time()) seconds after epoch
+        """
+        event = event_pb2.Event(graph_def=graph.SerializeToString())
+        self.add_event(event, None, walltime)
+
+    def flush(self):
+        """Flushes the event file to disk.
+
+        Call this method to make sure that all pending events have been written to
+        disk.
+        """
+        self.event_writer.flush()
+
+    def close(self):
+        """Flushes the event file to disk and close the file.
+
+        Call this method when you do not need the summary writer anymore.
+        """
+        self.event_writer.close()
+
+    def reopen(self):
+        """Reopens the EventFileWriter.
+
+        Can be called after `close()` to add more events in the same directory.
+        The events will go into a new events file.
+        Does nothing if the EventFileWriter was not closed.
+        """
+        self.event_writer.reopen()
+
+
+class SummaryWriter:
+    """Writes entries directly to event files in the log_dir to be consumed by TensorBoard.
+
+    The `SummaryWriter` class provides a high-level API to create an event file
+    in a given directory and add summaries and events to it. The class updates the
+    file contents asynchronously. This allows a training program to call methods
+    to add data to the file directly from the training loop, without slowing down
+    training.
+    """
+
+    def __init__(
+        self,
+        log_dir=None,
+        comment="",
+        purge_step=None,
+        max_queue=10,
+        flush_secs=120,
+        filename_suffix="",
+    ):
+        """Create a `SummaryWriter` that will write out events and summaries to the event file.
+
+        Args:
+            log_dir (str): Save directory location. Default is
+              runs/**CURRENT_DATETIME_HOSTNAME**, which changes after each run.
+              Use a hierarchical folder structure to compare
+              between runs easily, e.g. pass in 'runs/exp1', 'runs/exp2', etc.
+              for each new experiment to compare across them.
+            comment (str): Comment log_dir suffix appended to the default
+              ``log_dir``. If ``log_dir`` is assigned, this argument has no effect.
+            purge_step (int):
+              When logging crashes at step :math:`T+X` and restarts at step :math:`T`,
+              any events whose global_step is larger than or equal to :math:`T` will be
+              purged and hidden from TensorBoard.
+              Note that crashed and resumed experiments should have the same ``log_dir``.
+            max_queue (int): Size of the queue for pending events and
+              summaries before one of the 'add' calls forces a flush to disk.
+              Default is ten items.
+            flush_secs (int): How often, in seconds, to flush the
+              pending events and summaries to disk. Default is every two minutes.
+            filename_suffix (str): Suffix added to all event filenames in
+              the log_dir directory. More details on filename construction in
+              tensorboard.summary.writer.event_file_writer.EventFileWriter.
+
+        Examples::
+
+            from torch.utils.tensorboard import SummaryWriter
+
+            # create a summary writer with automatically generated folder name.
+            writer = SummaryWriter()
+            # folder location: runs/May04_22-14-54_s-MacBook-Pro.local/
+
+            # create a summary writer using the specified folder name.
+            writer = SummaryWriter("my_experiment")
+            # folder location: my_experiment
+
+            # create a summary writer with comment appended.
+            writer = SummaryWriter(comment="LR_0.1_BATCH_16")
+            # folder location: runs/May04_22-14-54_s-MacBook-Pro.localLR_0.1_BATCH_16/
+
+        """
+        torch._C._log_api_usage_once("tensorboard.create.summarywriter")
+        if not log_dir:
+            import socket
+            from datetime import datetime
+
+            current_time = datetime.now().strftime("%b%d_%H-%M-%S")
+            log_dir = os.path.join(
+                "runs", current_time + "_" + socket.gethostname() + comment
+            )
+        self.log_dir = log_dir
+        self.purge_step = purge_step
+        self.max_queue = max_queue
+        self.flush_secs = flush_secs
+        self.filename_suffix = filename_suffix
+
+        # Initialize the file writers, but they can be cleared out on close
+        # and recreated later as needed.
+        self.file_writer = self.all_writers = None
+        self._get_file_writer()
+
+        # Create default bins for histograms, see generate_testdata.py in tensorflow/tensorboard
+        v = 1e-12
+        buckets = []
+        neg_buckets = []
+        while v < 1e20:
+            buckets.append(v)
+            neg_buckets.append(-v)
+            v *= 1.1
+        self.default_bins = neg_buckets[::-1] + [0] + buckets
+
+    def _get_file_writer(self):
+        """Return the default FileWriter instance. Recreates it if closed."""
+        if self.all_writers is None or self.file_writer is None:
+            self.file_writer = FileWriter(
+                self.log_dir, self.max_queue, self.flush_secs, self.filename_suffix
+            )
+            self.all_writers = {self.file_writer.get_logdir(): self.file_writer}
+            if self.purge_step is not None:
+                most_recent_step = self.purge_step
+                self.file_writer.add_event(
+                    Event(step=most_recent_step, file_version="brain.Event:2")
+                )
+                self.file_writer.add_event(
+                    Event(
+                        step=most_recent_step,
+                        session_log=SessionLog(status=SessionLog.START),
+                    )
+                )
+                self.purge_step = None
+        return self.file_writer
+
+    def get_logdir(self):
+        """Return the directory where event files will be written."""
+        return self.log_dir
+
+    def add_hparams(
+        self,
+        hparam_dict,
+        metric_dict,
+        hparam_domain_discrete=None,
+        run_name=None,
+        global_step=None,
+    ):
+        """Add a set of hyperparameters to be compared in TensorBoard.
+
+        Args:
+            hparam_dict (dict): Each key-value pair in the dictionary is the
+              name of the hyperparameter and its corresponding value.
+              The type of the value can be one of `bool`, `string`, `float`,
+              `int`, or `None`.
+            metric_dict (dict): Each key-value pair in the dictionary is the
+              name of the metric and its corresponding value. Note that the key used
+              here should be unique in the tensorboard record. Otherwise the value
+              you added by ``add_scalar`` will be displayed in the hparam plugin. In
+              most cases, this is unwanted.
+            hparam_domain_discrete: (Optional[Dict[str, List[Any]]]) A dictionary that
+              contains names of the hyperparameters and all discrete values they can hold
+            run_name (str): Name of the run, to be included as part of the logdir.
+              If unspecified, will use current timestamp.
+            global_step (int): Global step value to record
+
+        Examples::
+
+            from torch.utils.tensorboard import SummaryWriter
+            with SummaryWriter() as w:
+                for i in range(5):
+                    w.add_hparams({'lr': 0.1*i, 'bsize': i},
+                                  {'hparam/accuracy': 10*i, 'hparam/loss': 10*i})
+
+        Expected result:
+
+        .. image:: _static/img/tensorboard/add_hparam.png
+           :scale: 50 %
+
+        """
+        torch._C._log_api_usage_once("tensorboard.logging.add_hparams")
+        if type(hparam_dict) is not dict or type(metric_dict) is not dict:
+            raise TypeError("hparam_dict and metric_dict should be dictionary.")
+        exp, ssi, sei = hparams(hparam_dict, metric_dict, hparam_domain_discrete)
+
+        if not run_name:
+            run_name = str(time.time())
+        logdir = os.path.join(self._get_file_writer().get_logdir(), run_name)
+        with SummaryWriter(log_dir=logdir) as w_hp:
+            w_hp.file_writer.add_summary(exp, global_step)
+            w_hp.file_writer.add_summary(ssi, global_step)
+            w_hp.file_writer.add_summary(sei, global_step)
+            for k, v in metric_dict.items():
+                w_hp.add_scalar(k, v, global_step)
+
+    def add_scalar(
+        self,
+        tag,
+        scalar_value,
+        global_step=None,
+        walltime=None,
+        new_style=False,
+        double_precision=False,
+    ):
+        """Add scalar data to summary.
+
+        Args:
+            tag (str): Data identifier
+            scalar_value (float or string/blobname): Value to save
+            global_step (int): Global step value to record
+            walltime (float): Optional override default walltime (time.time())
+              with seconds after epoch of event
+            new_style (boolean): Whether to use new style (tensor field) or old
+              style (simple_value field). New style could lead to faster data loading.
+            double_precision (boolean): Whether to store the scalar with 64-bit
+              precision when ``new_style`` is used.
+        Examples::
+
+            from torch.utils.tensorboard import SummaryWriter
+            writer = SummaryWriter()
+            x = range(100)
+            for i in x:
+                writer.add_scalar('y=2x', i * 2, i)
+            writer.close()
+
+        Expected result:
+
+        .. image:: _static/img/tensorboard/add_scalar.png
+           :scale: 50 %
+
+        """
+        torch._C._log_api_usage_once("tensorboard.logging.add_scalar")
+
+        summary = scalar(
+            tag, scalar_value, new_style=new_style, double_precision=double_precision
+        )
+        self._get_file_writer().add_summary(summary, global_step, walltime)
+
+    def add_scalars(self, main_tag, tag_scalar_dict, global_step=None, walltime=None):
+        """Add many scalar data to summary.
+
+        Args:
+            main_tag (str): The parent name for the tags
+            tag_scalar_dict (dict): Key-value pair storing the tag and corresponding values
+            global_step (int): Global step value to record
+            walltime (float): Optional override default walltime (time.time())
+              seconds after epoch of event
+
+        Examples::
+
+            from torch.utils.tensorboard import SummaryWriter
+            writer = SummaryWriter()
+            r = 5
+            for i in range(100):
+                writer.add_scalars('run_14h', {'xsinx':i*np.sin(i/r),
+                                                'xcosx':i*np.cos(i/r),
+                                                'tanx': np.tan(i/r)}, i)
+            writer.close()
+            # This call adds three values to the same scalar plot with the tag
+            # 'run_14h' in TensorBoard's scalar section.
+
+        Expected result:
+
+        .. image:: _static/img/tensorboard/add_scalars.png
+           :scale: 50 %
+
+        """
+        torch._C._log_api_usage_once("tensorboard.logging.add_scalars")
+        walltime = time.time() if walltime is None else walltime
+        fw_logdir = self._get_file_writer().get_logdir()
+        for tag, scalar_value in tag_scalar_dict.items():
+            fw_tag = fw_logdir + "/" + main_tag.replace("/", "_") + "_" + tag
+            assert self.all_writers is not None
+            if fw_tag in self.all_writers.keys():
+                fw = self.all_writers[fw_tag]
+            else:
+                fw = FileWriter(
+                    fw_tag, self.max_queue, self.flush_secs, self.filename_suffix
+                )
+                self.all_writers[fw_tag] = fw
+            fw.add_summary(scalar(main_tag, scalar_value), global_step, walltime)
+
+    def add_tensor(
+        self,
+        tag,
+        tensor,
+        global_step=None,
+        walltime=None,
+    ):
+        """Add tensor data to summary.
+
+        Args:
+            tag (str): Data identifier
+            tensor (torch.Tensor): tensor to save
+            global_step (int): Global step value to record
+        Examples::
+
+            from torch.utils.tensorboard import SummaryWriter
+            writer = SummaryWriter()
+            x = torch.tensor([1,2,3])
+            writer.add_tensor('x', x)
+            writer.close()
+
+        Expected result:
+            Summary::tensor::float_val [1,2,3]
+                   ::tensor::shape [3]
+                   ::tag 'x'
+
+        """
+        torch._C._log_api_usage_once("tensorboard.logging.add_tensor")
+
+        summary = tensor_proto(tag, tensor)
+        self._get_file_writer().add_summary(summary, global_step, walltime)
+
+    def add_histogram(
+        self,
+        tag,
+        values,
+        global_step=None,
+        bins="tensorflow",
+        walltime=None,
+        max_bins=None,
+    ):
+        """Add histogram to summary.
+
+        Args:
+            tag (str): Data identifier
+            values (torch.Tensor, numpy.ndarray, or string/blobname): Values to build histogram
+            global_step (int): Global step value to record
+            bins (str): One of {'tensorflow','auto', 'fd', ...}. This determines how the bins are made. You can find
+              other options in: https://numpy.org/doc/stable/reference/generated/numpy.histogram.html
+            walltime (float): Optional override default walltime (time.time())
+              seconds after epoch of event
+
+        Examples::
+
+            from torch.utils.tensorboard import SummaryWriter
+            import numpy as np
+            writer = SummaryWriter()
+            for i in range(10):
+                x = np.random.random(1000)
+                writer.add_histogram('distribution centers', x + i, i)
+            writer.close()
+
+        Expected result:
+
+        .. image:: _static/img/tensorboard/add_histogram.png
+           :scale: 50 %
+
+        """
+        torch._C._log_api_usage_once("tensorboard.logging.add_histogram")
+        if isinstance(bins, str) and bins == "tensorflow":
+            bins = self.default_bins
+        self._get_file_writer().add_summary(
+            histogram(tag, values, bins, max_bins=max_bins), global_step, walltime
+        )
+
+    def add_histogram_raw(
+        self,
+        tag,
+        min,
+        max,
+        num,
+        sum,
+        sum_squares,
+        bucket_limits,
+        bucket_counts,
+        global_step=None,
+        walltime=None,
+    ):
+        """Add histogram with raw data.
+
+        Args:
+            tag (str): Data identifier
+            min (float or int): Min value
+            max (float or int): Max value
+            num (int): Number of values
+            sum (float or int): Sum of all values
+            sum_squares (float or int): Sum of squares for all values
+            bucket_limits (torch.Tensor, numpy.ndarray): Upper value per bucket.
+              The number of elements should be the same as in `bucket_counts`.
+            bucket_counts (torch.Tensor, numpy.ndarray): Number of values per bucket
+            global_step (int): Global step value to record
+            walltime (float): Optional override default walltime (time.time())
+              seconds after epoch of event
+            see: https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/histogram/README.md
+
+        Examples::
+
+            from torch.utils.tensorboard import SummaryWriter
+            import numpy as np
+            writer = SummaryWriter()
+            dummy_data = []
+            for idx, value in enumerate(range(50)):
+                dummy_data += [idx + 0.001] * value
+
+            bins = list(range(50+2))
+            bins = np.array(bins)
+            values = np.array(dummy_data).astype(float).reshape(-1)
+            counts, limits = np.histogram(values, bins=bins)
+            sum_sq = values.dot(values)
+            writer.add_histogram_raw(
+                tag='histogram_with_raw_data',
+                min=values.min(),
+                max=values.max(),
+                num=len(values),
+                sum=values.sum(),
+                sum_squares=sum_sq,
+                bucket_limits=limits[1:].tolist(),
+                bucket_counts=counts.tolist(),
+                global_step=0)
+            writer.close()
+
+        Expected result:
+
+        .. image:: _static/img/tensorboard/add_histogram_raw.png
+           :scale: 50 %
+
+        """
+        torch._C._log_api_usage_once("tensorboard.logging.add_histogram_raw")
+        if len(bucket_limits) != len(bucket_counts):
+            raise ValueError(
+                "len(bucket_limits) != len(bucket_counts), see the document."
+            )
+        self._get_file_writer().add_summary(
+            histogram_raw(
+                tag, min, max, num, sum, sum_squares, bucket_limits, bucket_counts
+            ),
+            global_step,
+            walltime,
+        )
+
+    def add_image(
+        self, tag, img_tensor, global_step=None, walltime=None, dataformats="CHW"
+    ):
+        """Add image data to summary.
+
+        Note that this requires the ``pillow`` package.
+
+        Args:
+            tag (str): Data identifier
+            img_tensor (torch.Tensor, numpy.ndarray, or string/blobname): Image data
+            global_step (int): Global step value to record
+            walltime (float): Optional override default walltime (time.time())
+              seconds after epoch of event
+            dataformats (str): Image data format specification of the form
+              CHW, HWC, HW, WH, etc.
+        Shape:
+            img_tensor: Default is :math:`(3, H, W)`. You can use ``torchvision.utils.make_grid()`` to
+            convert a batch of tensors into 3xHxW format or call ``add_images`` and let us do the job.
+            Tensors with :math:`(1, H, W)`, :math:`(H, W)`, or :math:`(H, W, 3)` are also suitable as long as
+            the corresponding ``dataformats`` argument is passed, e.g. ``CHW``, ``HWC``, ``HW``.
+
+        Examples::
+
+            from torch.utils.tensorboard import SummaryWriter
+            import numpy as np
+            img = np.zeros((3, 100, 100))
+            img[0] = np.arange(0, 10000).reshape(100, 100) / 10000
+            img[1] = 1 - np.arange(0, 10000).reshape(100, 100) / 10000
+
+            img_HWC = np.zeros((100, 100, 3))
+            img_HWC[:, :, 0] = np.arange(0, 10000).reshape(100, 100) / 10000
+            img_HWC[:, :, 1] = 1 - np.arange(0, 10000).reshape(100, 100) / 10000
+
+            writer = SummaryWriter()
+            writer.add_image('my_image', img, 0)
+
+            # If you have non-default dimension setting, set the dataformats argument.
+            writer.add_image('my_image_HWC', img_HWC, 0, dataformats='HWC')
+            writer.close()
+
+        Expected result:
+
+        .. image:: _static/img/tensorboard/add_image.png
+           :scale: 50 %
+
+        """
+        torch._C._log_api_usage_once("tensorboard.logging.add_image")
+        self._get_file_writer().add_summary(
+            image(tag, img_tensor, dataformats=dataformats), global_step, walltime
+        )
+
+    def add_images(
+        self, tag, img_tensor, global_step=None, walltime=None, dataformats="NCHW"
+    ):
+        """Add batched image data to summary.
+
+        Note that this requires the ``pillow`` package.
+
+        Args:
+            tag (str): Data identifier
+            img_tensor (torch.Tensor, numpy.ndarray, or string/blobname): Image data
+            global_step (int): Global step value to record
+            walltime (float): Optional override default walltime (time.time())
+              seconds after epoch of event
+            dataformats (str): Image data format specification of the form
+              NCHW, NHWC, CHW, HWC, HW, WH, etc.
+        Shape:
+            img_tensor: Default is :math:`(N, 3, H, W)`. If ``dataformats`` is specified, other shapes will be
+            accepted, e.g. NCHW or NHWC.
+
+        Examples::
+
+            from torch.utils.tensorboard import SummaryWriter
+            import numpy as np
+
+            img_batch = np.zeros((16, 3, 100, 100))
+            for i in range(16):
+                img_batch[i, 0] = np.arange(0, 10000).reshape(100, 100) / 10000 / 16 * i
+                img_batch[i, 1] = (1 - np.arange(0, 10000).reshape(100, 100) / 10000) / 16 * i
+
+            writer = SummaryWriter()
+            writer.add_images('my_image_batch', img_batch, 0)
+            writer.close()
+
+        Expected result:
+
+        .. image:: _static/img/tensorboard/add_images.png
+           :scale: 30 %
+
+        """
+        torch._C._log_api_usage_once("tensorboard.logging.add_images")
+        self._get_file_writer().add_summary(
+            image(tag, img_tensor, dataformats=dataformats), global_step, walltime
+        )
+
+    def add_image_with_boxes(
+        self,
+        tag,
+        img_tensor,
+        box_tensor,
+        global_step=None,
+        walltime=None,
+        rescale=1,
+        dataformats="CHW",
+        labels=None,
+    ):
+        """Add image and draw bounding boxes on the image.
+
+        Args:
+            tag (str): Data identifier
+            img_tensor (torch.Tensor, numpy.ndarray, or string/blobname): Image data
+            box_tensor (torch.Tensor, numpy.ndarray, or string/blobname): Box data (for detected objects)
+              box should be represented as [x1, y1, x2, y2].
+            global_step (int): Global step value to record
+            walltime (float): Optional override default walltime (time.time())
+              seconds after epoch of event
+            rescale (float): Optional scale override
+            dataformats (str): Image data format specification of the form
+              NCHW, NHWC, CHW, HWC, HW, WH, etc.
+            labels (list of string): The label to be shown for each bounding box.
+        Shape:
+            img_tensor: Default is :math:`(3, H, W)`. It can be specified with ``dataformats`` argument.
+            e.g. CHW or HWC
+
+            box_tensor: (torch.Tensor, numpy.ndarray, or string/blobname): :math:`(N, 4)`, where N is the number of
+            boxes and the 4 elements in each row represent (xmin, ymin, xmax, ymax).
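+
+        Examples::
+
+            from torch.utils.tensorboard import SummaryWriter
+            import numpy as np
+            # illustrative sketch: one blank 100x100 image with two labeled boxes
+            img = np.zeros((3, 100, 100))
+            boxes = np.array([[10, 10, 40, 40], [50, 50, 90, 90]])
+            writer = SummaryWriter()
+            writer.add_image_with_boxes('image_with_boxes', img, boxes,
+                                        global_step=0, labels=['cat', 'dog'])
+            writer.close()
+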
+        """
+        torch._C._log_api_usage_once("tensorboard.logging.add_image_with_boxes")
+        if labels is not None:
+            if isinstance(labels, str):
+                labels = [labels]
+            if len(labels) != box_tensor.shape[0]:
+                labels = None
+        self._get_file_writer().add_summary(
+            image_boxes(
+                tag,
+                img_tensor,
+                box_tensor,
+                rescale=rescale,
+                dataformats=dataformats,
+                labels=labels,
+            ),
+            global_step,
+            walltime,
+        )
+
+    def add_figure(
+        self,
+        tag: str,
+        figure: Union["Figure", list["Figure"]],
+        global_step: Optional[int] = None,
+        close: bool = True,
+        walltime: Optional[float] = None,
+    ) -> None:
+        """Render matplotlib figure into an image and add it to summary.
+
+        Note that this requires the ``matplotlib`` package.
+
+        Args:
+            tag: Data identifier
+            figure: Figure or a list of figures
+            global_step: Global step value to record
+            close: Flag to automatically close the figure
+            walltime: Optional override default walltime (time.time())
+              seconds after epoch of event
+        """
+        torch._C._log_api_usage_once("tensorboard.logging.add_figure")
+        if isinstance(figure, list):
+            self.add_image(
+                tag,
+                figure_to_image(figure, close),
+                global_step,
+                walltime,
+                dataformats="NCHW",
+            )
+        else:
+            self.add_image(
+                tag,
+                figure_to_image(figure, close),
+                global_step,
+                walltime,
+                dataformats="CHW",
+            )
+
+    def add_video(self, tag, vid_tensor, global_step=None, fps=4, walltime=None):
+        """Add video data to summary.
+
+        Note that this requires the ``moviepy`` package.
+
+        Args:
+            tag (str): Data identifier
+            vid_tensor (torch.Tensor): Video data
+            global_step (int): Global step value to record
+            fps (float or int): Frames per second
+            walltime (float): Optional override default walltime (time.time())
+              seconds after epoch of event
+        Shape:
+            vid_tensor: :math:`(N, T, C, H, W)`. The values should lie in [0, 255] for type `uint8` or [0, 1] for type `float`.
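+
+        Examples::
+
+            from torch.utils.tensorboard import SummaryWriter
+            import torch
+            # illustrative sketch: one clip of 16 random 3x64x64 frames
+            vid = torch.rand(1, 16, 3, 64, 64)  # float values in [0, 1]
+            writer = SummaryWriter()
+            writer.add_video('random_video', vid, global_step=0, fps=4)
+            writer.close()
+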
+        """
+        torch._C._log_api_usage_once("tensorboard.logging.add_video")
+        self._get_file_writer().add_summary(
+            video(tag, vid_tensor, fps), global_step, walltime
+        )
+
+    def add_audio(
+        self, tag, snd_tensor, global_step=None, sample_rate=44100, walltime=None
+    ):
+        """Add audio data to summary.
+
+        Args:
+            tag (str): Data identifier
+            snd_tensor (torch.Tensor): Sound data
+            global_step (int): Global step value to record
+            sample_rate (int): sample rate in Hz
+            walltime (float): Optional override default walltime (time.time())
+              seconds after epoch of event
+        Shape:
+            snd_tensor: :math:`(1, L)`. The values should lie between [-1, 1].
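+
+        Examples::
+
+            from torch.utils.tensorboard import SummaryWriter
+            import math
+            import torch
+            # illustrative sketch: one second of a 440 Hz sine tone at 44.1 kHz
+            t = torch.arange(44100, dtype=torch.float32) / 44100
+            snd = torch.sin(2 * math.pi * 440 * t).unsqueeze(0)  # shape (1, L), values in [-1, 1]
+            writer = SummaryWriter()
+            writer.add_audio('tone_440hz', snd, global_step=0, sample_rate=44100)
+            writer.close()
+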
+        """
+        torch._C._log_api_usage_once("tensorboard.logging.add_audio")
+        self._get_file_writer().add_summary(
+            audio(tag, snd_tensor, sample_rate=sample_rate), global_step, walltime
+        )
+
+    def add_text(self, tag, text_string, global_step=None, walltime=None):
+        """Add text data to summary.
+
+        Args:
+            tag (str): Data identifier
+            text_string (str): String to save
+            global_step (int): Global step value to record
+            walltime (float): Optional override default walltime (time.time())
+              seconds after epoch of event
+        Examples::
+
+            writer.add_text('lstm', 'This is an lstm', 0)
+            writer.add_text('rnn', 'This is an rnn', 10)
+        """
+        torch._C._log_api_usage_once("tensorboard.logging.add_text")
+        self._get_file_writer().add_summary(
+            text(tag, text_string), global_step, walltime
+        )
+
+    def add_onnx_graph(self, prototxt):
+        torch._C._log_api_usage_once("tensorboard.logging.add_onnx_graph")
+        self._get_file_writer().add_onnx_graph(load_onnx_graph(prototxt))
+
+    def add_graph(
+        self, model, input_to_model=None, verbose=False, use_strict_trace=True
+    ):
+        """Add graph data to summary.
+
+        Args:
+            model (torch.nn.Module): Model to draw.
+            input_to_model (torch.Tensor or list of torch.Tensor): A variable or a tuple of
+                variables to be fed.
+            verbose (bool): Whether to print graph structure in console.
+            use_strict_trace (bool): Whether to pass keyword argument `strict` to
+                `torch.jit.trace`. Pass False when you want the tracer to
+                record your mutable container types (list, dict)
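+
+        Examples::
+
+            from torch.utils.tensorboard import SummaryWriter
+            import torch
+            import torch.nn as nn
+            # illustrative sketch: trace a tiny model with a dummy input
+            model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
+            writer = SummaryWriter()
+            writer.add_graph(model, torch.rand(1, 4))
+            writer.close()
+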
+        """
+        torch._C._log_api_usage_once("tensorboard.logging.add_graph")
+        # A valid PyTorch model should have a 'forward' method
+        self._get_file_writer().add_graph(
+            graph(model, input_to_model, verbose, use_strict_trace)
+        )
+
+    @staticmethod
+    def _encode(rawstr):
+        # I'd use urllib, but I'm unsure about the differences between Python 2 and 3, etc.
+        retval = rawstr
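+        # Escape characters that would otherwise create extra path components in
+        # the embedding subdirectory: "%" -> "%25", "/" -> "%2f", "\" -> "%5c".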
+        retval = retval.replace("%", f"%{ord('%'):02x}")
+        retval = retval.replace("/", f"%{ord('/'):02x}")
+        retval = retval.replace("\\", "%%%02x" % (ord("\\")))  # noqa: UP031
+        return retval
+
+    def add_embedding(
+        self,
+        mat,
+        metadata=None,
+        label_img=None,
+        global_step=None,
+        tag="default",
+        metadata_header=None,
+    ):
+        """Add embedding projector data to summary.
+
+        Args:
+            mat (torch.Tensor or numpy.ndarray): A matrix in which each row is the feature vector of a data point
+            metadata (list): A list of labels; each element will be converted to a string
+            label_img (torch.Tensor): Images corresponding to each data point
+            global_step (int): Global step value to record
+            tag (str): Name for the embedding
+            metadata_header (list): A list of headers for multi-column metadata. If given, each metadata must be
+                a list with values corresponding to headers.
+        Shape:
+            mat: :math:`(N, D)`, where N is number of data and D is feature dimension
+
+            label_img: :math:`(N, C, H, W)`
+
+        Examples::
+
+            import keyword
+            import torch
+            meta = []
+            while len(meta)<100:
+                meta = meta+keyword.kwlist # get some strings
+            meta = meta[:100]
+
+            for i, v in enumerate(meta):
+                meta[i] = v+str(i)
+
+            label_img = torch.rand(100, 3, 10, 32)
+            for i in range(100):
+                label_img[i]*=i/100.0
+
+            writer.add_embedding(torch.randn(100, 5), metadata=meta, label_img=label_img)
+            writer.add_embedding(torch.randn(100, 5), label_img=label_img)
+            writer.add_embedding(torch.randn(100, 5), metadata=meta)
+
+        .. note::
+            Categorical (i.e. non-numeric) metadata cannot have more than 50 unique values if they are to be used for
+            coloring in the embedding projector.
+
+        """
+        torch._C._log_api_usage_once("tensorboard.logging.add_embedding")
+        mat = make_np(mat)
+        if global_step is None:
+            global_step = 0
+            # clear pbtxt?
+
+        # Maybe we should encode the tag so slashes don't trip us up?
+        # I don't think this will mess us up, but better safe than sorry.
+        subdir = f"{str(global_step).zfill(5)}/{self._encode(tag)}"
+        save_path = os.path.join(self._get_file_writer().get_logdir(), subdir)
+
+        fs = tf.io.gfile
+        if fs.exists(save_path):
+            if fs.isdir(save_path):
+                print(
+                    "warning: Embedding dir exists, did you set global_step for add_embedding()?"
+                )
+            else:
+                raise NotADirectoryError(
+                    f"Path: `{save_path}` exists, but is a file. Cannot proceed."
+                )
+        else:
+            fs.makedirs(save_path)
+
+        if metadata is not None:
+            assert mat.shape[0] == len(
+                metadata
+            ), "#labels should equal with #data points"
+            make_tsv(metadata, save_path, metadata_header=metadata_header)
+
+        if label_img is not None:
+            assert (
+                mat.shape[0] == label_img.shape[0]
+            ), "#images should equal with #data points"
+            make_sprite(label_img, save_path)
+
+        assert (
+            mat.ndim == 2
+        ), "mat should be 2D, where mat.size(0) is the number of data points"
+        make_mat(mat, save_path)
+
+        # Filesystem doesn't necessarily have append semantics, so we store an
+        # internal buffer to append to and re-write whole file after each
+        # embedding is added
+        if not hasattr(self, "_projector_config"):
+            self._projector_config = ProjectorConfig()
+        embedding_info = get_embedding_info(
+            metadata, label_img, subdir, global_step, tag
+        )
+        self._projector_config.embeddings.extend([embedding_info])
+
+        from google.protobuf import text_format
+
+        config_pbtxt = text_format.MessageToString(self._projector_config)
+        write_pbtxt(self._get_file_writer().get_logdir(), config_pbtxt)
+
+    def add_pr_curve(
+        self,
+        tag,
+        labels,
+        predictions,
+        global_step=None,
+        num_thresholds=127,
+        weights=None,
+        walltime=None,
+    ):
+        """Add precision recall curve.
+
+        Plotting a precision-recall curve lets you understand your model's
+        performance under different threshold settings. With this function,
+        you provide the ground truth labeling (T/F) and prediction confidence
+        (usually the output of your model) for each target. The TensorBoard UI
+        will let you choose the threshold interactively.
+
+        Args:
+            tag (str): Data identifier
+            labels (torch.Tensor, numpy.ndarray, or string/blobname):
+              Ground truth data. Binary label for each element.
+            predictions (torch.Tensor, numpy.ndarray, or string/blobname):
+              The probability that an element is classified as true.
+              Values should be in [0, 1].
+            global_step (int): Global step value to record
+            num_thresholds (int): Number of thresholds used to draw the curve.
+            walltime (float): Optional override default walltime (time.time())
+              seconds after epoch of event
+
+        Examples::
+
+            from torch.utils.tensorboard import SummaryWriter
+            import numpy as np
+            labels = np.random.randint(2, size=100)  # binary label
+            predictions = np.random.rand(100)
+            writer = SummaryWriter()
+            writer.add_pr_curve('pr_curve', labels, predictions, 0)
+            writer.close()
+
+        """
+        torch._C._log_api_usage_once("tensorboard.logging.add_pr_curve")
+        labels, predictions = make_np(labels), make_np(predictions)
+        self._get_file_writer().add_summary(
+            pr_curve(tag, labels, predictions, num_thresholds, weights),
+            global_step,
+            walltime,
+        )
+
+    def add_pr_curve_raw(
+        self,
+        tag,
+        true_positive_counts,
+        false_positive_counts,
+        true_negative_counts,
+        false_negative_counts,
+        precision,
+        recall,
+        global_step=None,
+        num_thresholds=127,
+        weights=None,
+        walltime=None,
+    ):
+        """Add precision recall curve with raw data.
+
+        Args:
+            tag (str): Data identifier
+            true_positive_counts (torch.Tensor, numpy.ndarray, or string/blobname): true positive counts
+            false_positive_counts (torch.Tensor, numpy.ndarray, or string/blobname): false positive counts
+            true_negative_counts (torch.Tensor, numpy.ndarray, or string/blobname): true negative counts
+            false_negative_counts (torch.Tensor, numpy.ndarray, or string/blobname): false negative counts
+            precision (torch.Tensor, numpy.ndarray, or string/blobname): precision
+            recall (torch.Tensor, numpy.ndarray, or string/blobname): recall
+            global_step (int): Global step value to record
+            num_thresholds (int): Number of thresholds used to draw the curve.
+            walltime (float): Optional override default walltime (time.time())
+              seconds after epoch of event
+            see: https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/pr_curve/README.md
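+
+        Examples::
+
+            from torch.utils.tensorboard import SummaryWriter
+            import numpy as np
+            # illustrative sketch: precomputed counts for 5 thresholds
+            tp = np.array([75, 64, 21, 5, 0])
+            fp = np.array([150, 105, 18, 0, 0])
+            tn = fp[0] - fp
+            fn = tp[0] - tp
+            precision = tp / np.maximum(1e-7, tp + fp)
+            recall = tp / np.maximum(1e-7, tp + fn)
+            writer = SummaryWriter()
+            writer.add_pr_curve_raw('pr_curve_raw', tp, fp, tn, fn, precision,
+                                    recall, global_step=0, num_thresholds=5)
+            writer.close()
+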
+        """
+        torch._C._log_api_usage_once("tensorboard.logging.add_pr_curve_raw")
+        self._get_file_writer().add_summary(
+            pr_curve_raw(
+                tag,
+                true_positive_counts,
+                false_positive_counts,
+                true_negative_counts,
+                false_negative_counts,
+                precision,
+                recall,
+                num_thresholds,
+                weights,
+            ),
+            global_step,
+            walltime,
+        )
+
+    def add_custom_scalars_multilinechart(
+        self, tags, category="default", title="untitled"
+    ):
+        """Shorthand for creating multilinechart. Similar to ``add_custom_scalars()``, but the only necessary argument is *tags*.
+
+        Args:
+            tags (list): list of tags that have been used in ``add_scalar()``
+
+        Examples::
+
+            writer.add_custom_scalars_multilinechart(['twse/0050', 'twse/2330'])
+        """
+        torch._C._log_api_usage_once(
+            "tensorboard.logging.add_custom_scalars_multilinechart"
+        )
+        layout = {category: {title: ["Multiline", tags]}}
+        self._get_file_writer().add_summary(custom_scalars(layout))
+
+    def add_custom_scalars_marginchart(
+        self, tags, category="default", title="untitled"
+    ):
+        """Shorthand for creating marginchart.
+
+        Similar to ``add_custom_scalars()``, but the only necessary argument is *tags*,
+        which should have exactly 3 elements.
+
+        Args:
+            tags (list): list of tags that have been used in ``add_scalar()``
+
+        Examples::
+
+            writer.add_custom_scalars_marginchart(['twse/0050', 'twse/2330', 'twse/2006'])
+        """
+        torch._C._log_api_usage_once(
+            "tensorboard.logging.add_custom_scalars_marginchart"
+        )
+        assert len(tags) == 3
+        layout = {category: {title: ["Margin", tags]}}
+        self._get_file_writer().add_summary(custom_scalars(layout))
+
+    def add_custom_scalars(self, layout):
+        """Create special chart by collecting charts tags in 'scalars'.
+
+        NOTE: This function can only be called once for each SummaryWriter() object.
+
+        Because it only provides metadata to tensorboard, the function can be called before or after the training loop.
+
+        Args:
+            layout (dict): {categoryName: *charts*}, where *charts* is also a dictionary
+              {chartName: *ListOfProperties*}. The first element in *ListOfProperties* is the chart's type
+              (one of **Multiline** or **Margin**) and the second element should be a list containing the tags
+              you have used in add_scalar function, which will be collected into the new chart.
+
+        Examples::
+
+            layout = {'Taiwan':{'twse':['Multiline',['twse/0050', 'twse/2330']]},
+                         'USA':{ 'dow':['Margin',   ['dow/aaa', 'dow/bbb', 'dow/ccc']],
+                              'nasdaq':['Margin',   ['nasdaq/aaa', 'nasdaq/bbb', 'nasdaq/ccc']]}}
+
+            writer.add_custom_scalars(layout)
+        """
+        torch._C._log_api_usage_once("tensorboard.logging.add_custom_scalars")
+        self._get_file_writer().add_summary(custom_scalars(layout))
+
+    def add_mesh(
+        self,
+        tag,
+        vertices,
+        colors=None,
+        faces=None,
+        config_dict=None,
+        global_step=None,
+        walltime=None,
+    ):
+        """Add meshes or 3D point clouds to TensorBoard.
+
+        The visualization is based on Three.js,
+        so it allows users to interact with the rendered object. Besides the basic definitions
+        such as vertices and faces, users can further provide camera parameters, lighting conditions, etc.
+        Please check https://threejs.org/docs/index.html#manual/en/introduction/Creating-a-scene for
+        advanced usage.
+
+        Args:
+            tag (str): Data identifier
+            vertices (torch.Tensor): List of the 3D coordinates of vertices.
+            colors (torch.Tensor): Colors for each vertex
+            faces (torch.Tensor): Indices of vertices within each triangle. (Optional)
+            config_dict: Dictionary with ThreeJS classes names and configuration.
+            global_step (int): Global step value to record
+            walltime (float): Optional override default walltime (time.time())
+              seconds after epoch of event
+
+        Shape:
+            vertices: :math:`(B, N, 3)`. (batch, number_of_vertices, channels)
+
+            colors: :math:`(B, N, 3)`. The values should lie in [0, 255] for type `uint8` or [0, 1] for type `float`.
+
+            faces: :math:`(B, N, 3)`. The values should lie in [0, number_of_vertices] for type `uint8`.
+
+        Examples::
+
+            from torch.utils.tensorboard import SummaryWriter
+            vertices_tensor = torch.as_tensor([
+                [1, 1, 1],
+                [-1, -1, 1],
+                [1, -1, -1],
+                [-1, 1, -1],
+            ], dtype=torch.float).unsqueeze(0)
+            colors_tensor = torch.as_tensor([
+                [255, 0, 0],
+                [0, 255, 0],
+                [0, 0, 255],
+                [255, 0, 255],
+            ], dtype=torch.int).unsqueeze(0)
+            faces_tensor = torch.as_tensor([
+                [0, 2, 3],
+                [0, 3, 1],
+                [0, 1, 2],
+                [1, 3, 2],
+            ], dtype=torch.int).unsqueeze(0)
+
+            writer = SummaryWriter()
+            writer.add_mesh('my_mesh', vertices=vertices_tensor, colors=colors_tensor, faces=faces_tensor)
+
+            writer.close()
+        """
+        torch._C._log_api_usage_once("tensorboard.logging.add_mesh")
+        self._get_file_writer().add_summary(
+            mesh(tag, vertices, colors, faces, config_dict), global_step, walltime
+        )
+
+    def flush(self):
+        """Flushes the event file to disk.
+
+        Call this method to make sure that all pending events have been written to
+        disk.
+        """
+        if self.all_writers is None:
+            return
+        for writer in self.all_writers.values():
+            writer.flush()
+
+    def close(self):
+        if self.all_writers is None:
+            return  # ignore double close
+        for writer in self.all_writers.values():
+            writer.flush()
+            writer.close()
+        self.file_writer = self.all_writers = None
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.close()
diff --git a/.venv/lib/python3.12/site-packages/torch/xpu/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/xpu/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b3ab5449f06136db152786a2c236febd1483d537
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/xpu/__pycache__/__init__.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/xpu/__pycache__/_utils.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/xpu/__pycache__/_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8740535f8991ce0d2bc94c8cfb9383f817417c71
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/xpu/__pycache__/_utils.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/xpu/__pycache__/memory.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/xpu/__pycache__/memory.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b83d7b8592f4a6130ca4862ac67cc5d486cf6176
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/xpu/__pycache__/memory.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/xpu/__pycache__/random.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/xpu/__pycache__/random.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..021a54f2e15c90b6671565841afb8ef791b469f7
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/xpu/__pycache__/random.cpython-312.pyc differ
diff --git a/.venv/lib/python3.12/site-packages/torch/xpu/__pycache__/streams.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/xpu/__pycache__/streams.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..583787231a8ddfa212d5b067bf379a7c46c227b1
Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/xpu/__pycache__/streams.cpython-312.pyc differ
diff --git a/arch/__pycache__/__init__.cpython-312.pyc b/arch/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f0ae1d1c19f70fbb8998605fe264f229ef519802
Binary files /dev/null and b/arch/__pycache__/__init__.cpython-312.pyc differ
diff --git a/arch/__pycache__/adapter.cpython-312.pyc b/arch/__pycache__/adapter.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2b61f334a1a5fad8c53210ae4f36e37beac86897
Binary files /dev/null and b/arch/__pycache__/adapter.cpython-312.pyc differ
diff --git a/arch/__pycache__/data_loader.cpython-312.pyc b/arch/__pycache__/data_loader.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3955e1642520ec5bfb5c31150ac834c04abea369
Binary files /dev/null and b/arch/__pycache__/data_loader.cpython-312.pyc differ
diff --git a/arch/__pycache__/model_loader.cpython-312.pyc b/arch/__pycache__/model_loader.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d2eadf9aa701f35aaf5c30be09903aaf8a6f2d5d
Binary files /dev/null and b/arch/__pycache__/model_loader.cpython-312.pyc differ
diff --git a/arch/__pycache__/pipeline.cpython-312.pyc b/arch/__pycache__/pipeline.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4e2ef34b5dabd750bb9433ce07f4c662c5d59814
Binary files /dev/null and b/arch/__pycache__/pipeline.cpython-312.pyc differ
diff --git a/arch/__pycache__/text_encoder.cpython-312.pyc b/arch/__pycache__/text_encoder.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6c20ad7009cbc19d34b9b387737f4eee881f199d
Binary files /dev/null and b/arch/__pycache__/text_encoder.cpython-312.pyc differ
diff --git a/arch/__pycache__/training.cpython-312.pyc b/arch/__pycache__/training.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5d077441b3c2cc1f5c7d46a2ae0ea3ee4fd8f528
Binary files /dev/null and b/arch/__pycache__/training.cpython-312.pyc differ
diff --git a/illustrious_generated/07f25b42b8d7.png b/illustrious_generated/07f25b42b8d7.png
new file mode 100644
index 0000000000000000000000000000000000000000..35ff9a179b1efc8b0507af9013079ec19dbca5f1
--- /dev/null
+++ b/illustrious_generated/07f25b42b8d7.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:929c4479354ebfad7c228a014f6cc6af6adc54b3dbc9e061a81a3ca122ee994e
+size 5163673
diff --git a/illustrious_generated/0a4d2753b7c7.png b/illustrious_generated/0a4d2753b7c7.png
new file mode 100644
index 0000000000000000000000000000000000000000..217feda08ddc07e48c1bf8f08208dea035f4c15f
--- /dev/null
+++ b/illustrious_generated/0a4d2753b7c7.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:57a7bce4374b978fbed7614ba941c0b28f142bb25d810957089e4f06b7a358ac
+size 1678400
diff --git a/illustrious_generated/0d1561551252.png b/illustrious_generated/0d1561551252.png
new file mode 100644
index 0000000000000000000000000000000000000000..02283726954e317060ba665aafd6cb1b73185786
--- /dev/null
+++ b/illustrious_generated/0d1561551252.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5a3ea266ab00cf4e40ea635ae80c09299bdab5ba0880a587ff05ab039e463808
+size 2144741
diff --git a/illustrious_generated/0e42a3a7d296.png b/illustrious_generated/0e42a3a7d296.png
new file mode 100644
index 0000000000000000000000000000000000000000..c2fa22e75a70ab31dff05c00217b7b08984fc72f
--- /dev/null
+++ b/illustrious_generated/0e42a3a7d296.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8f1b19b9922c41a91ab21f21878ad7be91113896b50edbce600d3900481721c8
+size 913930
diff --git a/illustrious_generated/1371e80e6f16.png b/illustrious_generated/1371e80e6f16.png
new file mode 100644
index 0000000000000000000000000000000000000000..5af6028f5f03e7aed26591f8884d6167897c4568
--- /dev/null
+++ b/illustrious_generated/1371e80e6f16.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d55e8dd34bbbd21d100d7e5218304819d971b713ce9b8a1b63a9aaa86c5aaa18
+size 567333
diff --git a/illustrious_generated/13d5cc66d420.png b/illustrious_generated/13d5cc66d420.png
new file mode 100644
index 0000000000000000000000000000000000000000..05972c6c8e42ba477a2ab72dc466829dcf7fa65e
--- /dev/null
+++ b/illustrious_generated/13d5cc66d420.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:42b2779264186ed856d7ca2b588288895722e1e2e69e794b256e5cdec138f921
+size 1702828
diff --git a/illustrious_generated/149a82673fe9.png b/illustrious_generated/149a82673fe9.png
new file mode 100644
index 0000000000000000000000000000000000000000..2c277e9f7ae36c43edd8ad7c442c6a8959cc9f05
--- /dev/null
+++ b/illustrious_generated/149a82673fe9.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a712abd49ecdab24c3d7defa43aadc8bee67c4c5c0e6c0f4df7ee474460fa8a1
+size 1104972
diff --git a/illustrious_generated/14cde88e8c02.png b/illustrious_generated/14cde88e8c02.png
new file mode 100644
index 0000000000000000000000000000000000000000..9051f5aa3173d739a516964bfae8d2916a0902e2
--- /dev/null
+++ b/illustrious_generated/14cde88e8c02.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0eceb87a24dd6a02e5ea34e93496636cec3793249a0f10ea5edd8744d9670b5e
+size 1988528
diff --git a/illustrious_generated/163359d65694.png b/illustrious_generated/163359d65694.png
new file mode 100644
index 0000000000000000000000000000000000000000..5c26ed2ccac6d891e31298dbaf5d64f71395a29d
--- /dev/null
+++ b/illustrious_generated/163359d65694.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:007b360efef28a411cde89398f63617ba19555560814d58fbe994f774c0f87e5
+size 493033
diff --git a/illustrious_generated/1755674c7bc1.png b/illustrious_generated/1755674c7bc1.png
new file mode 100644
index 0000000000000000000000000000000000000000..827bc977b5d03669799fb80dbe3a7a0d4f900d17
--- /dev/null
+++ b/illustrious_generated/1755674c7bc1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:48231b0e0452e3db55fbd9698fc371aeaeb3a20fb0f6410b10704c462f9fb8d7
+size 1028767
diff --git a/illustrious_generated/176f37a6cd7f.png b/illustrious_generated/176f37a6cd7f.png
new file mode 100644
index 0000000000000000000000000000000000000000..18d536dee098e1a19db139d9ab0626e581bc3d3b
--- /dev/null
+++ b/illustrious_generated/176f37a6cd7f.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:794f6d93fb5ca8cba1d9844619d772683efc3c1ab36c833f5e31761756f2e37c
+size 741162
diff --git a/illustrious_generated/18f1f90221cc.png b/illustrious_generated/18f1f90221cc.png
new file mode 100644
index 0000000000000000000000000000000000000000..b8301fc62b18d8abc19ba6268cf6efca542e36fd
--- /dev/null
+++ b/illustrious_generated/18f1f90221cc.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:16f7532e948798059eafc88854e0340d0d546f35f03ee847d0cb8ea19558c70c
+size 4006372
diff --git a/illustrious_generated/1970ccc17731.png b/illustrious_generated/1970ccc17731.png
new file mode 100644
index 0000000000000000000000000000000000000000..3579a46280a845e5a9a9049e2cc8379365674957
--- /dev/null
+++ b/illustrious_generated/1970ccc17731.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3ba975bbc5fb9e6af6b28a030d321890b0e6735d2317d4ed46c63577284595f5
+size 715917
diff --git a/illustrious_generated/1bb92008067a.png b/illustrious_generated/1bb92008067a.png
new file mode 100644
index 0000000000000000000000000000000000000000..69f2cd51c8e3e1c091b0715509ef636e35fbf405
--- /dev/null
+++ b/illustrious_generated/1bb92008067a.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:31d7f411392c030ce340ac1cdbbd4bc21438addf11e920e67ab437d05e91d063
+size 3623453
diff --git a/illustrious_generated/1ca2c247cfef.png b/illustrious_generated/1ca2c247cfef.png
new file mode 100644
index 0000000000000000000000000000000000000000..3e0a458bd1766e59feabf7a9e72b001efb5b8526
--- /dev/null
+++ b/illustrious_generated/1ca2c247cfef.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:45f9042d7f6d9470b732e162238244bcf08a526cf664bc27805d08d073703464
+size 2014643
diff --git a/illustrious_generated/1dfed5dbd442.png b/illustrious_generated/1dfed5dbd442.png
new file mode 100644
index 0000000000000000000000000000000000000000..96e56f90e8ed7e4a65f0b5c267ca715e928a9c72
--- /dev/null
+++ b/illustrious_generated/1dfed5dbd442.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b15839116f43dbccd86a94fc5eb4830ed3c16e8f65971e88820bf9d2b34e93aa
+size 967162
diff --git a/illustrious_generated/2008ee2e6f66.png b/illustrious_generated/2008ee2e6f66.png
new file mode 100644
index 0000000000000000000000000000000000000000..b8396a940a99560da51121595ad9ef83d4b04965
--- /dev/null
+++ b/illustrious_generated/2008ee2e6f66.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d3751de66c70ad71710270b742f89f788d7c28cc9a858590d0cc49d4fa53c08a
+size 402641
diff --git a/illustrious_generated/2433395d6d43.png b/illustrious_generated/2433395d6d43.png
new file mode 100644
index 0000000000000000000000000000000000000000..5fa6a9c7b116eded78396e9a55fde718e8d47ccf
--- /dev/null
+++ b/illustrious_generated/2433395d6d43.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:83a895e76b095edbe0924cddffd46e73f45ae06aed8f3b65d5d694f41fd18a66
+size 3898935
diff --git a/illustrious_generated/243fd5bd7679.png b/illustrious_generated/243fd5bd7679.png
new file mode 100644
index 0000000000000000000000000000000000000000..384cee8f6526b987af8b5ab236037f4d10e7ca8c
--- /dev/null
+++ b/illustrious_generated/243fd5bd7679.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bc3f51f6c203aaa455afdb87db761616f19006e1a14905a108c423d4706f85ec
+size 1270276
diff --git a/illustrious_generated/259e79d12cba.png b/illustrious_generated/259e79d12cba.png
new file mode 100644
index 0000000000000000000000000000000000000000..edf98955cabe1366503ae2a720e8abea56b1cb9f
--- /dev/null
+++ b/illustrious_generated/259e79d12cba.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d2c4c27b357af7ed8b85592958082c3ed8621ec9220decc94c7c1f9eb6af3d09
+size 2377446
diff --git a/illustrious_generated/25df61f584ba.png b/illustrious_generated/25df61f584ba.png
new file mode 100644
index 0000000000000000000000000000000000000000..e86ba95ca2709c55ad66a9b0790c2e444a18f8f7
--- /dev/null
+++ b/illustrious_generated/25df61f584ba.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:66c4cd8abb5cd49d44a9aba7ebfd70d235eddbad4f583d19560485f1dcddd0ac
+size 1366271
diff --git a/illustrious_generated/29259007fbbf.png b/illustrious_generated/29259007fbbf.png
new file mode 100644
index 0000000000000000000000000000000000000000..641fa5c0821cd95d289ad92300b88e25761c8341
--- /dev/null
+++ b/illustrious_generated/29259007fbbf.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1fde3ae780d4218ea7371074683b0924df1ba0759a037f3c38bd78b63c739255
+size 2149474
diff --git a/illustrious_generated/29c45b6176ff.png b/illustrious_generated/29c45b6176ff.png
new file mode 100644
index 0000000000000000000000000000000000000000..e8f677e39d41555f35f7e0c239df0b67350fdd07
--- /dev/null
+++ b/illustrious_generated/29c45b6176ff.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:72979de05d455ec3f64e3963b03cc686e186c3f8e0e276337d933d432f992ba8
+size 613217
diff --git a/illustrious_generated/2c3ca96d94bc.png b/illustrious_generated/2c3ca96d94bc.png
new file mode 100644
index 0000000000000000000000000000000000000000..b9c3ce4e8888ddc8ed9a671dcd63b47a80dc4cca
--- /dev/null
+++ b/illustrious_generated/2c3ca96d94bc.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a710bb8b49e020fa61af8c6581e1f08e8910d4ab5604b26000888083ee3ae7a0
+size 1204390
diff --git a/illustrious_generated/2f054f26c97c.png b/illustrious_generated/2f054f26c97c.png
new file mode 100644
index 0000000000000000000000000000000000000000..7acf7c91c8d056d8f5631d0da10478eec2e71829
--- /dev/null
+++ b/illustrious_generated/2f054f26c97c.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bf16f96bfe6b779d25cdd5c9d9c1a096bfda4de98e33565c4a2fcb0563fde50a
+size 2989378
diff --git a/illustrious_generated/30f0d4a3b6b6.png b/illustrious_generated/30f0d4a3b6b6.png
new file mode 100644
index 0000000000000000000000000000000000000000..632e3bf69dd0d9d60155532a43aba3376a1ec9d8
--- /dev/null
+++ b/illustrious_generated/30f0d4a3b6b6.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a04c46aff712b07cc3a007b291fa683b981acea2f9e51fafbb31e30e6b992b36
+size 531600
diff --git a/illustrious_generated/312f15890343.png b/illustrious_generated/312f15890343.png
new file mode 100644
index 0000000000000000000000000000000000000000..e7e9263b6d821b828ff5141e10d9d4e74205fa9b
--- /dev/null
+++ b/illustrious_generated/312f15890343.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:434ada965cb8d7c3e2ae4c096442e5090532fbbf0f2675e729ec2e94c169bcce
+size 1728788
diff --git a/illustrious_generated/33aeb710e683.png b/illustrious_generated/33aeb710e683.png
new file mode 100644
index 0000000000000000000000000000000000000000..4597fcd101076a51c751a65a7c738bd2e6a93db0
--- /dev/null
+++ b/illustrious_generated/33aeb710e683.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:96f91612efc6f4fccb40d547f96f7ac25378312b62232401470b9b77502d2ea7
+size 1592227
diff --git a/illustrious_generated/3503bc2ba8ab.png b/illustrious_generated/3503bc2ba8ab.png
new file mode 100644
index 0000000000000000000000000000000000000000..18e70dc0657a1c9570891498837228448d3fffb1
--- /dev/null
+++ b/illustrious_generated/3503bc2ba8ab.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:20aea12e30a75625196f481d1d4b380e9b0dda38f9594eec0359be8546d5bf48
+size 1182453
diff --git a/illustrious_generated/37c414606b7d.png b/illustrious_generated/37c414606b7d.png
new file mode 100644
index 0000000000000000000000000000000000000000..d31962dff9045a3957813677ad087b71488934be
--- /dev/null
+++ b/illustrious_generated/37c414606b7d.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e5c6fb153684b194bab8f13d33acbc72bb75afbfc5f9cc43069281783a9bf536
+size 2632019
diff --git a/illustrious_generated/38cdd5aae9bb.png b/illustrious_generated/38cdd5aae9bb.png
new file mode 100644
index 0000000000000000000000000000000000000000..7b4b5ae7dd854921549f4ed2b430c4a0d802e3aa
--- /dev/null
+++ b/illustrious_generated/38cdd5aae9bb.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b08be916a5731d051021dde4714fce68e885b715fdbedb5d5eb7e82998f8dfa0
+size 370904
diff --git a/illustrious_generated/3cbd0d9a36a7.png b/illustrious_generated/3cbd0d9a36a7.png
new file mode 100644
index 0000000000000000000000000000000000000000..7c2c7223a26d3a9ffc450678831bf7b4402d75d5
--- /dev/null
+++ b/illustrious_generated/3cbd0d9a36a7.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b2d4dc924eec87e172d19238b78a17dff36d1f6c51b33d17e4dffc0325e63ed8
+size 2208280
diff --git a/illustrious_generated/3dbcf74b53f1.png b/illustrious_generated/3dbcf74b53f1.png
new file mode 100644
index 0000000000000000000000000000000000000000..be49e8d438e1e801bc72239a9050c6178fe389eb
--- /dev/null
+++ b/illustrious_generated/3dbcf74b53f1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:756dc7ee4d0202b4535bc192a3be60e4b7337a11f123a1e52b4d020ad632dfd7
+size 1799185
diff --git a/illustrious_generated/4014bca6d153.png b/illustrious_generated/4014bca6d153.png
new file mode 100644
index 0000000000000000000000000000000000000000..2eb210b2979ec7ddbc77953080b4abdfc385c04e
--- /dev/null
+++ b/illustrious_generated/4014bca6d153.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:86c97c8b6c8ce720d4dc6b91b4d6b1f1fa64253d6aa3afff06143278f7a37890
+size 379356
diff --git a/illustrious_generated/4536122bec42.png b/illustrious_generated/4536122bec42.png
new file mode 100644
index 0000000000000000000000000000000000000000..2a6f262e7cd10eee30da30bf698634073060e792
--- /dev/null
+++ b/illustrious_generated/4536122bec42.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7bbb182dcdb0ad9333de62158a4a239d1436b277e9740b7c94053e76d3c3b606
+size 4849046
diff --git a/illustrious_generated/4632a958366a.png b/illustrious_generated/4632a958366a.png
new file mode 100644
index 0000000000000000000000000000000000000000..e17587f4882d733422600f715b1be86299cdf15e
--- /dev/null
+++ b/illustrious_generated/4632a958366a.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:760ae9b6954ea7713ddfac23d253b4a1b46aa74e222895d21c52c2f0cc92e44b
+size 767323
diff --git a/illustrious_generated/46d7a6324536.png b/illustrious_generated/46d7a6324536.png
new file mode 100644
index 0000000000000000000000000000000000000000..fbbe63197b12a13e0694953f625ffd78b0732c9b
--- /dev/null
+++ b/illustrious_generated/46d7a6324536.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:68f3320d13bc6b09e1d6b55269d5a8456cc28da3e8416d67cc5dfc6c369bf2ed
+size 1137442
diff --git a/illustrious_generated/476617799f8d.png b/illustrious_generated/476617799f8d.png
new file mode 100644
index 0000000000000000000000000000000000000000..a8f124ef26f9ca6d6773ac57075e2ec8fd46dc0e
--- /dev/null
+++ b/illustrious_generated/476617799f8d.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a4643c01149a8559815fe5886120bdc53a44b7877ffabdde499d6ced6ec1608f
+size 3640292
diff --git a/illustrious_generated/4769b078890d.png b/illustrious_generated/4769b078890d.png
new file mode 100644
index 0000000000000000000000000000000000000000..756e98c5a2edadbacabc03e976ee5c02c9e9bd71
--- /dev/null
+++ b/illustrious_generated/4769b078890d.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:13778f98fc98a1b6d95ee36788b735713b1edd1bd669f0d166cb315eb77f68da
+size 2794444
diff --git a/illustrious_generated/47aa86a18b5a.png b/illustrious_generated/47aa86a18b5a.png
new file mode 100644
index 0000000000000000000000000000000000000000..782690692b868bb15dbb413a4e6e22035671c7ee
--- /dev/null
+++ b/illustrious_generated/47aa86a18b5a.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:19ba63f411eca8f3b4ed41882bff7c354ce99ff643a1c236ba4232ef97324907
+size 470905
diff --git a/illustrious_generated/484afd596fd0.png b/illustrious_generated/484afd596fd0.png
new file mode 100644
index 0000000000000000000000000000000000000000..b87716512ef8966416890e56e31623e67b0fb04b
--- /dev/null
+++ b/illustrious_generated/484afd596fd0.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:91d682d41f34fb7942c0689c89aaa49fcd8372064da3d47da3f6a40c72c132bd
+size 3360059
diff --git a/illustrious_generated/48fcc54ce758.png b/illustrious_generated/48fcc54ce758.png
new file mode 100644
index 0000000000000000000000000000000000000000..ffed118c8e499814ca14323a9e4a75407f44c5dc
--- /dev/null
+++ b/illustrious_generated/48fcc54ce758.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5f27b607fc1838292bfa0bee66bcdb3e0d26e6a88859f782569f7784300306e7
+size 3292623
diff --git a/illustrious_generated/49222c07a03f.png b/illustrious_generated/49222c07a03f.png
new file mode 100644
index 0000000000000000000000000000000000000000..41adc68ab55a61755b526c43707c9903eced2ad0
--- /dev/null
+++ b/illustrious_generated/49222c07a03f.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d1b85312287d0406ff083d07361c203ce6d2302de56307c8d94974e4eb03d686
+size 1875077
diff --git a/illustrious_generated/49c90b238c77.png b/illustrious_generated/49c90b238c77.png
new file mode 100644
index 0000000000000000000000000000000000000000..90664f0c079d25b9deda7892328c583352f1a4a2
--- /dev/null
+++ b/illustrious_generated/49c90b238c77.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f6f4f6da080b69be0ad12a24725f64754483a375393d50a8c7416ebf32e5477d
+size 2897500
diff --git a/illustrious_generated/4a574003bbc9.png b/illustrious_generated/4a574003bbc9.png
new file mode 100644
index 0000000000000000000000000000000000000000..06b6122c75cac79efe3339065262ee636d803faf
--- /dev/null
+++ b/illustrious_generated/4a574003bbc9.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d4aef06542711ed175bb26139a78a32bd651c3736c995d9499cbcdba96e46f35
+size 1116905
diff --git a/illustrious_generated/4af1ebb4488c.png b/illustrious_generated/4af1ebb4488c.png
new file mode 100644
index 0000000000000000000000000000000000000000..1d445e43766ccea66d1d0d6f4f1ee7eb056bc24f
--- /dev/null
+++ b/illustrious_generated/4af1ebb4488c.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:80119e608588c708de988526b3e10b03e5b1893aa0f35a99f15c795ac303e284
+size 600329
diff --git a/illustrious_generated/4af582043c66.png b/illustrious_generated/4af582043c66.png
new file mode 100644
index 0000000000000000000000000000000000000000..5cb01c1fd7997cf63bba0df9b726726b2914b965
--- /dev/null
+++ b/illustrious_generated/4af582043c66.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c096cab69a703ae59d3cfb79316cda96e3ef3297bddec57f550677a34760f387
+size 2806125
diff --git a/illustrious_generated/4c0f1f03dcd2.png b/illustrious_generated/4c0f1f03dcd2.png
new file mode 100644
index 0000000000000000000000000000000000000000..6ba1607708b41033bd78fdf91ad2085225c44a8a
--- /dev/null
+++ b/illustrious_generated/4c0f1f03dcd2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:09a1cdc3ca1f2da5d970ca27d4eea193754e5e0cb5481589bde7bdb7e34ce938
+size 1919048
diff --git a/illustrious_generated/4c39593e352f.png b/illustrious_generated/4c39593e352f.png
new file mode 100644
index 0000000000000000000000000000000000000000..f60c7d623f3b5d8d0e981c3ebf64e376367d0b27
--- /dev/null
+++ b/illustrious_generated/4c39593e352f.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:eabbbf6bf2a3d7d6879fe06a197bedfddb54cf95566cb78318f4d99a020d93bb
+size 1033221
diff --git a/illustrious_generated/4c7a209b8887.png b/illustrious_generated/4c7a209b8887.png
new file mode 100644
index 0000000000000000000000000000000000000000..c933d8aac134fd479168183ba82d469baf168c7f
--- /dev/null
+++ b/illustrious_generated/4c7a209b8887.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:22787878fb0448e39c98bcb9d0e23dd7694df84230f9fa7711b4bdcd5f836383
+size 2264642
diff --git a/illustrious_generated/4d7d6abe842c.png b/illustrious_generated/4d7d6abe842c.png
new file mode 100644
index 0000000000000000000000000000000000000000..bdada797176b2dc498bace124b5c824f136aa94c
--- /dev/null
+++ b/illustrious_generated/4d7d6abe842c.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a2764f791008304abd553fcbc9116448692eb51dc0826df45fc4709004b31726
+size 1026868
diff --git a/illustrious_generated/4f7953e3c61d.png b/illustrious_generated/4f7953e3c61d.png
new file mode 100644
index 0000000000000000000000000000000000000000..323317097780fb82dfd50c8b373e7861ac54fe02
--- /dev/null
+++ b/illustrious_generated/4f7953e3c61d.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:eafe98e12b1d6d40a409c576cf196640f13de37c907fe41b78fd786bc01047b4
+size 3940252
diff --git a/illustrious_generated/5028ea7283a1.png b/illustrious_generated/5028ea7283a1.png
new file mode 100644
index 0000000000000000000000000000000000000000..268b0635df7827ab59af2889833df96ae0a5683d
--- /dev/null
+++ b/illustrious_generated/5028ea7283a1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f2c01c7f3ed9cf9b5561d2b4822ad75eaf880a139029ba1494d187bc0f7d6c50
+size 673587
diff --git a/illustrious_generated/532b5ad4f3b4.png b/illustrious_generated/532b5ad4f3b4.png
new file mode 100644
index 0000000000000000000000000000000000000000..2dbfe3335d5eb83b161ecf68c1e0e42c0fbb7cda
--- /dev/null
+++ b/illustrious_generated/532b5ad4f3b4.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:189e667da2293106fd180c8ed40e44919b8c6e31fecd2e6610226fbd55ef470a
+size 3185718
diff --git a/illustrious_generated/53d7fbefadd0.png b/illustrious_generated/53d7fbefadd0.png
new file mode 100644
index 0000000000000000000000000000000000000000..3dd314783b870945438788fbf5b5d94c5cbee878
--- /dev/null
+++ b/illustrious_generated/53d7fbefadd0.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fd71ebfbed9184650866d820548f268c951a758a1f860ae8c25c099e13e5bfaf
+size 1168543
diff --git a/illustrious_generated/548d5f09851f.png b/illustrious_generated/548d5f09851f.png
new file mode 100644
index 0000000000000000000000000000000000000000..e24f08a1ac794c3166ca92669d3c57b4e7435804
--- /dev/null
+++ b/illustrious_generated/548d5f09851f.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:92c545640b022ca759715d4430652420af79e1101371934686b3cae5f28af90b
+size 3743336
diff --git a/illustrious_generated/56b23fd84d9e.png b/illustrious_generated/56b23fd84d9e.png
new file mode 100644
index 0000000000000000000000000000000000000000..db20eea179512b13ec16c14d32538b4747db53cc
--- /dev/null
+++ b/illustrious_generated/56b23fd84d9e.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b706d1ea798caf11ec07cf16733d75ad6d4194ab8a8e5efb5857946c4c04b80f
+size 1611390
diff --git a/illustrious_generated/56e66190e7ce.png b/illustrious_generated/56e66190e7ce.png
new file mode 100644
index 0000000000000000000000000000000000000000..9894fc9adbe7d7c6cd297f948e56e02aef6325d1
--- /dev/null
+++ b/illustrious_generated/56e66190e7ce.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b12bb2d62d567a58517e15aaee5db7c2014ba296c39708ed11b4d3591d16ff10
+size 579650
diff --git a/illustrious_generated/57264915c81a.png b/illustrious_generated/57264915c81a.png
new file mode 100644
index 0000000000000000000000000000000000000000..e3cab9b1f1e212ffa097f10f3f56610d31ff4acd
--- /dev/null
+++ b/illustrious_generated/57264915c81a.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:487c90b325374e7901b3b6e22e52e1368936158255f6c7a6b78b412ce3e71d36
+size 3412149
diff --git a/illustrious_generated/5803e88c45f3.png b/illustrious_generated/5803e88c45f3.png
new file mode 100644
index 0000000000000000000000000000000000000000..61dc659bf4735d0336a1abf085bde1d59d24be6b
--- /dev/null
+++ b/illustrious_generated/5803e88c45f3.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:16ba5c300eb8ee5c4f0e17f141af2191ae4408f72eeda4b22091e61639f2ade1
+size 2771883
diff --git a/illustrious_generated/58b9009a8433.png b/illustrious_generated/58b9009a8433.png
new file mode 100644
index 0000000000000000000000000000000000000000..e8bd39deb435400d413db563aacac1bfb0435dc8
--- /dev/null
+++ b/illustrious_generated/58b9009a8433.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:daecc5026f308b49c7823fbb3c6b030c0f001935b57ae2a150b914913ae3357b
+size 2449214
diff --git a/illustrious_generated/593a5b377fb6.png b/illustrious_generated/593a5b377fb6.png
new file mode 100644
index 0000000000000000000000000000000000000000..54a4e4e7662d24d18be70fa43224992ba508a520
--- /dev/null
+++ b/illustrious_generated/593a5b377fb6.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a7b8f047f157778d92bf9f5fe49cc6d3c8b3b7de6d3771037bb5c749dd61f586
+size 491353
diff --git a/illustrious_generated/5ab3056565d7.png b/illustrious_generated/5ab3056565d7.png
new file mode 100644
index 0000000000000000000000000000000000000000..9091bf4523a140be97c4cb4b94fd65016e5e8706
--- /dev/null
+++ b/illustrious_generated/5ab3056565d7.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d6fdd124c7cc0bb1adaf4aa8c96d3e515f6dee54fb0d64e840f96eba375e3188
+size 1032199
diff --git a/illustrious_generated/5e8555d59ee2.png b/illustrious_generated/5e8555d59ee2.png
new file mode 100644
index 0000000000000000000000000000000000000000..3aea217675546a0650bec797d81fe340b305aab9
--- /dev/null
+++ b/illustrious_generated/5e8555d59ee2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d8c79890da330e4419d6d3916d272e0aa78f9f79ecade0def8a634a4ae574d82
+size 1137210
diff --git a/illustrious_generated/5fa4fb280d76.png b/illustrious_generated/5fa4fb280d76.png
new file mode 100644
index 0000000000000000000000000000000000000000..7b8dd1182a0e97edc7e435b3054fba5b385f3f62
--- /dev/null
+++ b/illustrious_generated/5fa4fb280d76.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dd08bb459cf1c7e13ba3f94096b0b9598d262c5679e25d35408c9e6a4abb7980
+size 1554242
diff --git a/illustrious_generated/6035404a25b1.png b/illustrious_generated/6035404a25b1.png
new file mode 100644
index 0000000000000000000000000000000000000000..b7133888f033c6849e8fcca2439e60eb53cefce8
--- /dev/null
+++ b/illustrious_generated/6035404a25b1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2084e21707798a50898ef5656f31a89de0497c631b48364ef0c66b7a73bc0d6c
+size 1653872
diff --git a/illustrious_generated/6061d4cce84a.png b/illustrious_generated/6061d4cce84a.png
new file mode 100644
index 0000000000000000000000000000000000000000..3bd711e3d7d3a7973d56f706bf33237d334f92dd
--- /dev/null
+++ b/illustrious_generated/6061d4cce84a.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b2e01ca162239cf2c9bad73ae58b133980e78eb868aa784647f8802e7581b065
+size 2824721
diff --git a/illustrious_generated/60f785243e60.png b/illustrious_generated/60f785243e60.png
new file mode 100644
index 0000000000000000000000000000000000000000..4013ebcabdaede753a1457d29aacec689237086a
--- /dev/null
+++ b/illustrious_generated/60f785243e60.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4369eb8701791557d6eb1722b2273db8402e7f6b7e63b1de5d6f87103f518fc5
+size 1664574
diff --git a/illustrious_generated/617ce74ffafa.png b/illustrious_generated/617ce74ffafa.png
new file mode 100644
index 0000000000000000000000000000000000000000..250034bfa693f1cd60bb2da89dbc57ec19f9dfa1
--- /dev/null
+++ b/illustrious_generated/617ce74ffafa.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fb0e23ad342f8acac2a39d451c8dfc22c19737208a2d1a2b6396c6bafe579a77
+size 2579644
diff --git a/illustrious_generated/62fb4f26c8d0.png b/illustrious_generated/62fb4f26c8d0.png
new file mode 100644
index 0000000000000000000000000000000000000000..31561c51a09a8e68e5137f49de894187b6df1c01
--- /dev/null
+++ b/illustrious_generated/62fb4f26c8d0.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:feee14a7b1bac6dca9d2ba7e5c65db8e282cf263349173b6d06a929d9b692c8a
+size 482859
diff --git a/illustrious_generated/639cf9531d64.png b/illustrious_generated/639cf9531d64.png
new file mode 100644
index 0000000000000000000000000000000000000000..48c411be28e402186dc63c960d6c407b4431ec96
--- /dev/null
+++ b/illustrious_generated/639cf9531d64.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:125630b9052094c530456fb3a814a5ce70898ded323b1bba581c7f301f80ee60
+size 2227210
diff --git a/illustrious_generated/64d5e838ff32.png b/illustrious_generated/64d5e838ff32.png
new file mode 100644
index 0000000000000000000000000000000000000000..f70d7e1401390db944c7b42e24b44b6fd461bf3b
--- /dev/null
+++ b/illustrious_generated/64d5e838ff32.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a9d350db3bd5ce8e3eb93877ef839170a02929754d06f5c57f154d75260a1861
+size 539209
diff --git a/illustrious_generated/666f1a84cb80.png b/illustrious_generated/666f1a84cb80.png
new file mode 100644
index 0000000000000000000000000000000000000000..45af884ceda1e5f7c26ea37e8220c547b57f2889
--- /dev/null
+++ b/illustrious_generated/666f1a84cb80.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bb2a8633d190c2520ca9914240f87b5f49fa001a8b30f1041b743a2160c42613
+size 3319739
diff --git a/illustrious_generated/690e7719d4bf.png b/illustrious_generated/690e7719d4bf.png
new file mode 100644
index 0000000000000000000000000000000000000000..853828a8779afe2d347b7a8062d3e1d6d1fd3962
--- /dev/null
+++ b/illustrious_generated/690e7719d4bf.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e317e3a7ee8e3d9f2b0240a6d15a7211d8d0a2a24bd81ddfd3450c82648867b3
+size 668069
diff --git a/illustrious_generated/6994acf3db4f.png b/illustrious_generated/6994acf3db4f.png
new file mode 100644
index 0000000000000000000000000000000000000000..84727706016b94b0c53dc5792b301feba5553dd4
--- /dev/null
+++ b/illustrious_generated/6994acf3db4f.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:51bcae5f6c9da8b5069fc2cfba8259838b1e1e8b86c7952f86fb12d91aa229b0
+size 1861147
diff --git a/illustrious_generated/699b6be34d1c.png b/illustrious_generated/699b6be34d1c.png
new file mode 100644
index 0000000000000000000000000000000000000000..0709b5fc2c633c6b53a817ebb393289ef459bffe
--- /dev/null
+++ b/illustrious_generated/699b6be34d1c.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:92101c1efb94475f7f4a8095b746f4c5b9cd6325945ef3dd86d9b2481900507d
+size 3340642
diff --git a/illustrious_generated/6a51d82ae8bc.png b/illustrious_generated/6a51d82ae8bc.png
new file mode 100644
index 0000000000000000000000000000000000000000..b1bf177803c9bc01a4a46bebd640b778a7abe982
--- /dev/null
+++ b/illustrious_generated/6a51d82ae8bc.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dd53f767e7165c5b4a638d00d6f97f6950df78594f4ecfa89daf5b110cb28566
+size 649098
diff --git a/illustrious_generated/6b23dba7be7d.png b/illustrious_generated/6b23dba7be7d.png
new file mode 100644
index 0000000000000000000000000000000000000000..82a96db37fabb93c456b55f5c9913841d9bdfc17
--- /dev/null
+++ b/illustrious_generated/6b23dba7be7d.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6cdb465a40a20a09d98cd034a63135474e95dd6436c5ffceb5aa8b4e7f6b69ec
+size 925628
diff --git a/illustrious_generated/6c2bec98397a.png b/illustrious_generated/6c2bec98397a.png
new file mode 100644
index 0000000000000000000000000000000000000000..8f62a078783da4578a09a664dc451f8ff66cc59f
--- /dev/null
+++ b/illustrious_generated/6c2bec98397a.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4c36ee18045730e61eca0ff27430d092c4c16f381aa4e4e648c5e24428b65940
+size 926957
diff --git a/illustrious_generated/71485a12c097.png b/illustrious_generated/71485a12c097.png
new file mode 100644
index 0000000000000000000000000000000000000000..aa20b33fd97b00df1c9f03941e00c5eff5a14ad7
--- /dev/null
+++ b/illustrious_generated/71485a12c097.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:78e5b42619644c975bfb475a11aeff4c2ca04a693bd414cd148f4e9fa6bfce72
+size 2382288
diff --git a/illustrious_generated/71e9866c5494.png b/illustrious_generated/71e9866c5494.png
new file mode 100644
index 0000000000000000000000000000000000000000..40ac1de6fc2526d3bc7338ebebe63580fad5fb7d
--- /dev/null
+++ b/illustrious_generated/71e9866c5494.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b9352c115aea7cb2cc2b08469ba09b418df0ce6de8190d874136387f61e23a92
+size 1942045
diff --git a/illustrious_generated/73e5f02dca91.png b/illustrious_generated/73e5f02dca91.png
new file mode 100644
index 0000000000000000000000000000000000000000..08dbf30bad5f5716af83930ebc54890fc5a45dad
--- /dev/null
+++ b/illustrious_generated/73e5f02dca91.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:98d58973b85086b631089a9a7477d248b69771df0bf9e6f0a0e497b9d49841be
+size 1511983
diff --git a/illustrious_generated/7429db9846f3.png b/illustrious_generated/7429db9846f3.png
new file mode 100644
index 0000000000000000000000000000000000000000..57ef2741bbdb3072cef5100bd26b178b307deb63
--- /dev/null
+++ b/illustrious_generated/7429db9846f3.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:34d542800a6f4681df08f786d5aa0a6b8acfc6e86fc537ba61faeb8f83da507c
+size 3216045
diff --git a/illustrious_generated/754db85fe85f.png b/illustrious_generated/754db85fe85f.png
new file mode 100644
index 0000000000000000000000000000000000000000..e34a803b491340ce1aa32528ca570e1632fc8cfd
--- /dev/null
+++ b/illustrious_generated/754db85fe85f.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:480782aac893bcbae661191dbc01b114a6ea46bf069bf506bab682bb6e73fb9d
+size 723943
diff --git a/illustrious_generated/75bc8bf6da61.png b/illustrious_generated/75bc8bf6da61.png
new file mode 100644
index 0000000000000000000000000000000000000000..b488cf99476b86945432e6a36c5985b1c46b78a6
--- /dev/null
+++ b/illustrious_generated/75bc8bf6da61.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bfa2137cb7b1c277cca9d8d38827112fd1b2c1f99485926f1bdc750921ae3457
+size 3185603
diff --git a/illustrious_generated/776bf5e2d2a1.png b/illustrious_generated/776bf5e2d2a1.png
new file mode 100644
index 0000000000000000000000000000000000000000..47ba66eda34165bd24ea3e5eae33722c4232f96c
--- /dev/null
+++ b/illustrious_generated/776bf5e2d2a1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dfe178b87e332982a6005127c33205ec4ccc6a0c503a383e9c248ed618e61fe2
+size 1528095
diff --git a/illustrious_generated/798df3adbd88.png b/illustrious_generated/798df3adbd88.png
new file mode 100644
index 0000000000000000000000000000000000000000..58fefe2af3a5ba673866bcc9a92d4e1f6e342da3
--- /dev/null
+++ b/illustrious_generated/798df3adbd88.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c183d7fe6366b7c69ac3be01339dbd067c48decd1dff6a9f2f834b968f9cad15
+size 2836838
diff --git a/illustrious_generated/79a4d58eda5c.png b/illustrious_generated/79a4d58eda5c.png
new file mode 100644
index 0000000000000000000000000000000000000000..028297d1f0def948da50d0de2a66d9dc2ea0d43e
--- /dev/null
+++ b/illustrious_generated/79a4d58eda5c.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e34ab6fafc6ef69d8a6248298b446328599ce5533ab42ca65a8db53421675428
+size 1106817
diff --git a/illustrious_generated/7ed913ac6954.png b/illustrious_generated/7ed913ac6954.png
new file mode 100644
index 0000000000000000000000000000000000000000..230c6b9ac696a5af298069effbed845a0ab6e769
--- /dev/null
+++ b/illustrious_generated/7ed913ac6954.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:59cdef9ed46bc9658d3cbc17be82834ca85d7bbe22460b60b4a21aa6b324a844
+size 1285651
diff --git a/illustrious_generated/831ac5f05b4e.png b/illustrious_generated/831ac5f05b4e.png
new file mode 100644
index 0000000000000000000000000000000000000000..ef583061354198250871c94f3d4c51613ff4d231
--- /dev/null
+++ b/illustrious_generated/831ac5f05b4e.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:63e19da9eb0fd583b627b303c7263833f8957489434d43acad949b6aaa74501e
+size 1585973
diff --git a/illustrious_generated/83da1d18e129.png b/illustrious_generated/83da1d18e129.png
new file mode 100644
index 0000000000000000000000000000000000000000..e19efb626dd0815fe5bf530a3482f339be0b9ea6
--- /dev/null
+++ b/illustrious_generated/83da1d18e129.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5d9714e97060497bba54287555bc5724d79bf9a19440aaf26a68d4d295099869
+size 1248661
diff --git a/illustrious_generated/88e259bbe761.png b/illustrious_generated/88e259bbe761.png
new file mode 100644
index 0000000000000000000000000000000000000000..9b7d49635ca39ac517f08a725bcd84a4726df80a
--- /dev/null
+++ b/illustrious_generated/88e259bbe761.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3f015cc08f7c823b9a80e86aa1a8d9531bfa8da0aab0e251893b8bcb7e50f768
+size 1175120
diff --git a/illustrious_generated/8a594acaada9.png b/illustrious_generated/8a594acaada9.png
new file mode 100644
index 0000000000000000000000000000000000000000..63f9390d55e6d5ee47c6d41754670cef0d4e4e38
--- /dev/null
+++ b/illustrious_generated/8a594acaada9.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ec566100c161dcd172d267d02fb0b137befced35f99176f6da88f92fed192c69
+size 1601584
diff --git a/illustrious_generated/8bc741c00933.png b/illustrious_generated/8bc741c00933.png
new file mode 100644
index 0000000000000000000000000000000000000000..fd075d22a51b059240b2e2f126f1e8f4926524c5
--- /dev/null
+++ b/illustrious_generated/8bc741c00933.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:567732f5518730a95f9595afa81a0e8b885805b9f08e4372aad9cfbc1a7d9f3e
+size 1165423
diff --git a/illustrious_generated/8d3200b9e787.png b/illustrious_generated/8d3200b9e787.png
new file mode 100644
index 0000000000000000000000000000000000000000..972f96ffec8add9c5b8ebcc4a9ea5577c0f07b1d
--- /dev/null
+++ b/illustrious_generated/8d3200b9e787.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6d3cf7bf0820ba3f1483e3714f6d6db27a0e1848915dbccdad52e7138484731e
+size 3059222
diff --git a/illustrious_generated/8d3ab41288cb.png b/illustrious_generated/8d3ab41288cb.png
new file mode 100644
index 0000000000000000000000000000000000000000..02ed1e85937e49651d88dbfd1d73a6f8f802c9d1
--- /dev/null
+++ b/illustrious_generated/8d3ab41288cb.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e3660a9417c19f51930cdd3ccadb96a739092831148fd32bd8190f3df43f4a9e
+size 914786
diff --git a/illustrious_generated/8e375573fe30.png b/illustrious_generated/8e375573fe30.png
new file mode 100644
index 0000000000000000000000000000000000000000..e963578393058d2c98c201071ffa259b90a44d74
--- /dev/null
+++ b/illustrious_generated/8e375573fe30.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d620dbe473e2e92445b0900f3b63649c6033763f5ca2931f70333711662901ca
+size 564016
diff --git a/illustrious_generated/8e740eadaf18.png b/illustrious_generated/8e740eadaf18.png
new file mode 100644
index 0000000000000000000000000000000000000000..91f111e6befc47c0d2e6d8c7cc243bca6949193f
--- /dev/null
+++ b/illustrious_generated/8e740eadaf18.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:76556ec41f9bb7075c59215319c56a82de46133dd106be31a912890a4145c01d
+size 810781
diff --git a/illustrious_generated/8ef3560e5a42.png b/illustrious_generated/8ef3560e5a42.png
new file mode 100644
index 0000000000000000000000000000000000000000..e739dcadf71b832c263aad57406589d6559b026f
--- /dev/null
+++ b/illustrious_generated/8ef3560e5a42.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cde78ab5ec95787c46ec5f106cd96137e603621da6dca02a1a8f8fa67db00126
+size 2356490
diff --git a/illustrious_generated/92e013a48101.png b/illustrious_generated/92e013a48101.png
new file mode 100644
index 0000000000000000000000000000000000000000..20e0dbc1d89ba8e49d266c76c269e8292443588e
--- /dev/null
+++ b/illustrious_generated/92e013a48101.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2e8adc302570d2c58a698c6b7df2c149586bdc62c5e6a9e5ebbb4c2020d8a637
+size 4170948
diff --git a/illustrious_generated/96a62fa83b20.png b/illustrious_generated/96a62fa83b20.png
new file mode 100644
index 0000000000000000000000000000000000000000..73f30056dc30d31d24e677a627b7bf8668ed25ca
--- /dev/null
+++ b/illustrious_generated/96a62fa83b20.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:45d3d9b1dec18da616d3f7c382e29b6bb1361e2fc5f4f63701682f69533c4ede
+size 2347799
diff --git a/illustrious_generated/96ab8498c7a6.png b/illustrious_generated/96ab8498c7a6.png
new file mode 100644
index 0000000000000000000000000000000000000000..345e423bbb5059d0a88fcda0b617dbba2f38105f
--- /dev/null
+++ b/illustrious_generated/96ab8498c7a6.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e7a2a4ae4251de94b2ddc80532d225ef8f9694036dda493a154e7e8170284b9b
+size 1974033
diff --git a/illustrious_generated/9a7e94351f39.png b/illustrious_generated/9a7e94351f39.png
new file mode 100644
index 0000000000000000000000000000000000000000..6026caaf23c367c19a2aa37d3e206acbd4f24c14
--- /dev/null
+++ b/illustrious_generated/9a7e94351f39.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:620a1ea544dda757b7c5f7544edf4fbd281e5eb20ea65167d227ce44316957cf
+size 2970330
diff --git a/illustrious_generated/9b1bb21a46e8.png b/illustrious_generated/9b1bb21a46e8.png
new file mode 100644
index 0000000000000000000000000000000000000000..f2a8bcf0753a541444fb625f798d4048a6b4819c
--- /dev/null
+++ b/illustrious_generated/9b1bb21a46e8.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:86a0c7b131e2399a4ba77531ff8339166369d26ee65d891e15df90f77c948bbf
+size 697612
diff --git a/illustrious_generated/9e588e17280b.png b/illustrious_generated/9e588e17280b.png
new file mode 100644
index 0000000000000000000000000000000000000000..70b94042c2028f9793de3d484f22206ee09749a3
--- /dev/null
+++ b/illustrious_generated/9e588e17280b.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ced2c04221e8e02396ad7d7bd2f484268295c8938a469d908f925160f4fe8c4a
+size 1474743
diff --git a/illustrious_generated/a00e6e3ca42d.png b/illustrious_generated/a00e6e3ca42d.png
new file mode 100644
index 0000000000000000000000000000000000000000..adb9b8b2334b651e6b53f7184a62d07dbb2503aa
--- /dev/null
+++ b/illustrious_generated/a00e6e3ca42d.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f470b8b62ecb7ce6caa901e67ba9532ac31c5a606d05435477d5a2f315f9b6b0
+size 3644967
diff --git a/illustrious_generated/a048f190d616.png b/illustrious_generated/a048f190d616.png
new file mode 100644
index 0000000000000000000000000000000000000000..b17b15ab073677e992ee77be4d3f5beedeed0d67
--- /dev/null
+++ b/illustrious_generated/a048f190d616.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:df103b6279631344d1cdd085f9c54316680efd999a82e2e16628d6a9a7e032e7
+size 661392
diff --git a/illustrious_generated/a0e55ddbf74c.png b/illustrious_generated/a0e55ddbf74c.png
new file mode 100644
index 0000000000000000000000000000000000000000..f65a42a18c20f4b49b858d77674347cb620ac788
--- /dev/null
+++ b/illustrious_generated/a0e55ddbf74c.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3c24d2557dc921d76562810d0f8ce854afdfe438199ab5f67da12ee93b4113bd
+size 3973092
diff --git a/illustrious_generated/a317dff0fe2e.png b/illustrious_generated/a317dff0fe2e.png
new file mode 100644
index 0000000000000000000000000000000000000000..9f8fc2bed4183b6c6f280ae2a1e1b14ed9a174b8
--- /dev/null
+++ b/illustrious_generated/a317dff0fe2e.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:46a8ff1ec4022319f38404f510d47f30ea18026396dc03466ad2a7ef93b7dd55
+size 3215952
diff --git a/illustrious_generated/a3a89fac4fcf.png b/illustrious_generated/a3a89fac4fcf.png
new file mode 100644
index 0000000000000000000000000000000000000000..7126521cbbdc578f735e66d8566752b30d164eee
--- /dev/null
+++ b/illustrious_generated/a3a89fac4fcf.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b8f4f66e2ab517b5a44cc1616d83d3e9eefbf3e58b65bb3e68ef14df0d8fc2d0
+size 3575799
diff --git a/illustrious_generated/a3f618970fb7.png b/illustrious_generated/a3f618970fb7.png
new file mode 100644
index 0000000000000000000000000000000000000000..d9deaa8ef13291a0b744b9081415911e0f0c7cf2
--- /dev/null
+++ b/illustrious_generated/a3f618970fb7.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d6a4567e2003f0136829cc15f55c3780bc86126ebe3ed761efd6cf41336b39a1
+size 3074809
diff --git a/illustrious_generated/a4eb52675c75.png b/illustrious_generated/a4eb52675c75.png
new file mode 100644
index 0000000000000000000000000000000000000000..adfef412df9185a6914cbbe625f6c77683176681
--- /dev/null
+++ b/illustrious_generated/a4eb52675c75.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2c99564a5d20df5a6ab3cee6817b1c705940ca6142681ee732202c9d7021560d
+size 1716471
diff --git a/illustrious_generated/a630cf5674f3.png b/illustrious_generated/a630cf5674f3.png
new file mode 100644
index 0000000000000000000000000000000000000000..d81891e53957f78fddd2cc88cb643286b3eedada
--- /dev/null
+++ b/illustrious_generated/a630cf5674f3.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:faafba93a2af32e0ecffe9c23cc938dc14198f052456b2c5d5fd3d177e466a18
+size 2482194
diff --git a/illustrious_generated/a7acf95bd41f.png b/illustrious_generated/a7acf95bd41f.png
new file mode 100644
index 0000000000000000000000000000000000000000..091cf762ba37141f52defe8f8ce1ff1d6d902883
--- /dev/null
+++ b/illustrious_generated/a7acf95bd41f.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dbc3ac10ae34c4232f954cdac972b419c54a4c54830b99f4b4bf093f7351682b
+size 1111934
diff --git a/illustrious_generated/a9ab451bd4c9.png b/illustrious_generated/a9ab451bd4c9.png
new file mode 100644
index 0000000000000000000000000000000000000000..4a7dcc10f04321dd20c2eccb82b446031034d902
--- /dev/null
+++ b/illustrious_generated/a9ab451bd4c9.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:78509da648f076c43186ff658704aef2166041b3a18b906818437eb906bb4a85
+size 2215849
diff --git a/illustrious_generated/a9d16660f9be.png b/illustrious_generated/a9d16660f9be.png
new file mode 100644
index 0000000000000000000000000000000000000000..eb22ca2d114e8ed052fb27b88426a900ade391d3
--- /dev/null
+++ b/illustrious_generated/a9d16660f9be.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9da16b5de850d49e3ff798a7c6f75ff0470680852a5710abcbad5f52abda1b68
+size 1251039
diff --git a/illustrious_generated/aa08fc4ab5f7.png b/illustrious_generated/aa08fc4ab5f7.png
new file mode 100644
index 0000000000000000000000000000000000000000..262e85c15ef65576a5cb2d7386ef7289253d9e35
--- /dev/null
+++ b/illustrious_generated/aa08fc4ab5f7.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:db8cce62380d01a5c659747b076df3b6f2ef3769e16f330aadf254095b511957
+size 1516100
diff --git a/illustrious_generated/aa2403d0dbde.png b/illustrious_generated/aa2403d0dbde.png
new file mode 100644
index 0000000000000000000000000000000000000000..141834025826efd3d39318afa8328db6925dee66
--- /dev/null
+++ b/illustrious_generated/aa2403d0dbde.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cf922c203094f947adc4154076b53e5acc87564581e1de27b4b0796e7288f977
+size 2517603
diff --git a/illustrious_generated/abf601407cb0.png b/illustrious_generated/abf601407cb0.png
new file mode 100644
index 0000000000000000000000000000000000000000..8b052ac88244517187c20a8b81832ddf49ca6a42
--- /dev/null
+++ b/illustrious_generated/abf601407cb0.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:022e7f584d3216125690d860a619682a6251f5c17db724191f5cd7ec66a54376
+size 1760059
diff --git a/illustrious_generated/acf002111806.png b/illustrious_generated/acf002111806.png
new file mode 100644
index 0000000000000000000000000000000000000000..484660edd12e91f887c872264c482e9180a11ff8
--- /dev/null
+++ b/illustrious_generated/acf002111806.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:59ece1a7c578d708caaf7bc03df0732eb366f5db65c04dfcdb972243aede47b2
+size 2472798
diff --git a/illustrious_generated/ad90d731a0fa.png b/illustrious_generated/ad90d731a0fa.png
new file mode 100644
index 0000000000000000000000000000000000000000..1c21260b71d4d9529d5d0306b5703ce148840803
--- /dev/null
+++ b/illustrious_generated/ad90d731a0fa.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:13c6532254591ac4eda8eee1ff50b8153f14712643dd8a0432619f077d7b9153
+size 1156831
diff --git a/illustrious_generated/af2a10d42166.png b/illustrious_generated/af2a10d42166.png
new file mode 100644
index 0000000000000000000000000000000000000000..e92e7be2c01e44b21e68bed36e1075eeba402f79
--- /dev/null
+++ b/illustrious_generated/af2a10d42166.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a6bff15deb1c900604c4b5729c6d46523cb6782a92827bd6009541ee5ed54bf9
+size 2627556
diff --git a/illustrious_generated/b01a98afee9a.png b/illustrious_generated/b01a98afee9a.png
new file mode 100644
index 0000000000000000000000000000000000000000..878adcdba1cbe0544d311755407cbdb60f1bdc51
--- /dev/null
+++ b/illustrious_generated/b01a98afee9a.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e98b895a8968be71df3ac76518ab5006813098ca0c8ee745d07f6ebfa58f3658
+size 3029776
diff --git a/illustrious_generated/b0202f93d233.png b/illustrious_generated/b0202f93d233.png
new file mode 100644
index 0000000000000000000000000000000000000000..0dc8106992adf6cbde524227040148ac88b073ec
--- /dev/null
+++ b/illustrious_generated/b0202f93d233.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7a5a6428b789975784ca53f383c3f7bdd0dba26044bf987acb7639be8beb203a
+size 2524542
diff --git a/illustrious_generated/b102a265b15d.png b/illustrious_generated/b102a265b15d.png
new file mode 100644
index 0000000000000000000000000000000000000000..0f15e0b2bc352c127dcbd21fda0893ff4d0f092b
--- /dev/null
+++ b/illustrious_generated/b102a265b15d.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b675cff0494b1b04f2cc03b3551b25846eff2e4b76e8ded726d60fc085161331
+size 3441934
diff --git a/illustrious_generated/b2a469095584.png b/illustrious_generated/b2a469095584.png
new file mode 100644
index 0000000000000000000000000000000000000000..c25a9d76273f5bc3623cf8e7d2e29d54f1d5996b
--- /dev/null
+++ b/illustrious_generated/b2a469095584.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9c4efd9daa10824caf28f7c781698da67266810debad1d1cf23329898afc91fa
+size 1076772
diff --git a/illustrious_generated/b3a626cfe029.png b/illustrious_generated/b3a626cfe029.png
new file mode 100644
index 0000000000000000000000000000000000000000..60155829aadb9b066e934a842c9a8f68a42a23ca
--- /dev/null
+++ b/illustrious_generated/b3a626cfe029.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ce548ef734cf9f874cd352457862dcda17c51d717158fd7cd0124f0499887841
+size 1152771
diff --git a/illustrious_generated/b52834e10fd4.png b/illustrious_generated/b52834e10fd4.png
new file mode 100644
index 0000000000000000000000000000000000000000..b8764c8568518eeeb6905d5fe30510768088d632
--- /dev/null
+++ b/illustrious_generated/b52834e10fd4.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b902dbe1b8c8b11c284af76a2ca8ec876cad01c679d20b433efd541113c7fb35
+size 2014855
diff --git a/illustrious_generated/b5a6eaf580ce.png b/illustrious_generated/b5a6eaf580ce.png
new file mode 100644
index 0000000000000000000000000000000000000000..3da7e9db36be5faa652ec1c06f96311f05fe71d6
--- /dev/null
+++ b/illustrious_generated/b5a6eaf580ce.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ae199832b501ef62ee7f4d0cdae84100c24df31b5310419c8bf20a7b39026364
+size 1160303
diff --git a/illustrious_generated/b5ee5bce0095.png b/illustrious_generated/b5ee5bce0095.png
new file mode 100644
index 0000000000000000000000000000000000000000..31e3b87cc664e1637a2f3fe40f2599e546284de1
--- /dev/null
+++ b/illustrious_generated/b5ee5bce0095.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:96038526d83f227646c64849e5d7fabead3c872b4f82860ff927f91c65a0789a
+size 2772710
diff --git a/illustrious_generated/b66afd053661.png b/illustrious_generated/b66afd053661.png
new file mode 100644
index 0000000000000000000000000000000000000000..e0331fcbf823b01402167158f0c21e34478f834c
--- /dev/null
+++ b/illustrious_generated/b66afd053661.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5cd2255b7dfe412695306d95dffbe2193f7e1cafd2644291c4376cb7f01cc49e
+size 638885
diff --git a/illustrious_generated/b716f3a23f11.png b/illustrious_generated/b716f3a23f11.png
new file mode 100644
index 0000000000000000000000000000000000000000..2e02cb99827e23ff061ffa35ec509f337f5211b4
--- /dev/null
+++ b/illustrious_generated/b716f3a23f11.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b63292e744473052d7bb16400a24826ac3e61e8756ddc58d8f397451ba30535b
+size 748395
diff --git a/illustrious_generated/b770edd1abe0.png b/illustrious_generated/b770edd1abe0.png
new file mode 100644
index 0000000000000000000000000000000000000000..fc46b64d1abc6a0fe60f08cd953f0f8396a4ede0
--- /dev/null
+++ b/illustrious_generated/b770edd1abe0.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e511479dd1f5e8ca5ebb3fb997b0f5bbb1b3ccc6a0ea6f862ab99c8a54b8943c
+size 2085467
diff --git a/illustrious_generated/b8fe03330562.png b/illustrious_generated/b8fe03330562.png
new file mode 100644
index 0000000000000000000000000000000000000000..878f8127fb0d97a25470911ddeba01153aea9278
--- /dev/null
+++ b/illustrious_generated/b8fe03330562.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b7b69e4d4709459894afa86aa91f934e9a07eb8ac5d1e010580ea5ccd93809af
+size 1172454
diff --git a/illustrious_generated/bc6707dcf1f2.png b/illustrious_generated/bc6707dcf1f2.png
new file mode 100644
index 0000000000000000000000000000000000000000..0307960b5570a7b89fde06249e01e34bd5867d6e
--- /dev/null
+++ b/illustrious_generated/bc6707dcf1f2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2535bf0aee5045694aeab7310fd5a1a4c9db71fbaddd5f68d01eea9d4430ed40
+size 2730088
diff --git a/illustrious_generated/bd545030c487.png b/illustrious_generated/bd545030c487.png
new file mode 100644
index 0000000000000000000000000000000000000000..febda8ae3651b6ed6bffa82dd1f5d640aa43b899
--- /dev/null
+++ b/illustrious_generated/bd545030c487.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:42771102ecabc51cd8580a2f4b6f34fa1a141a2b80d8e38128b0cc665a350bd7
+size 1193192
diff --git a/illustrious_generated/bde193a32ad5.png b/illustrious_generated/bde193a32ad5.png
new file mode 100644
index 0000000000000000000000000000000000000000..ae675361ce26fc7b515ee6072732ad9823f4ebfc
--- /dev/null
+++ b/illustrious_generated/bde193a32ad5.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b68ccb6784ffd0c4e94cb1f0037c32eb5c55767317ceda29b5f7699eea8c0f7a
+size 1493706
diff --git a/illustrious_generated/bf19c6825d9d.png b/illustrious_generated/bf19c6825d9d.png
new file mode 100644
index 0000000000000000000000000000000000000000..dc379e91d53a4d6a7e341849bb0058e3ded5a2dc
--- /dev/null
+++ b/illustrious_generated/bf19c6825d9d.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e9f10e326dd26a7603a219847d546b0192440e98a42ac9e648d289282d58736a
+size 1190756
diff --git a/illustrious_generated/bf8fcd7e3ebb.png b/illustrious_generated/bf8fcd7e3ebb.png
new file mode 100644
index 0000000000000000000000000000000000000000..e7e99dc5f15fc4214a74681952902423221ca355
--- /dev/null
+++ b/illustrious_generated/bf8fcd7e3ebb.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:60665935135d0a415cfeaf3c723ea66c25e562b4280ea0ca7b198f8d7d3c2b30
+size 2584001
diff --git a/illustrious_generated/c4033b9e6115.png b/illustrious_generated/c4033b9e6115.png
new file mode 100644
index 0000000000000000000000000000000000000000..467e7f50a8254fefaefe72b778f20dfca8743114
--- /dev/null
+++ b/illustrious_generated/c4033b9e6115.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ffbbc01b6d6ecc2172d2402d9778eb1a96792292bb8af7c8c3cc59e3b47fae1c
+size 766365
diff --git a/illustrious_generated/c472bc9db7e5.png b/illustrious_generated/c472bc9db7e5.png
new file mode 100644
index 0000000000000000000000000000000000000000..c049495b847c5ea833c56fdcf0c844e06af443f2
--- /dev/null
+++ b/illustrious_generated/c472bc9db7e5.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b96c8fe243297c25be39bf6d0d72c8ad363274eaaeb3921fed5d9a694d4ee051
+size 1375572
diff --git a/illustrious_generated/c588cc386611.png b/illustrious_generated/c588cc386611.png
new file mode 100644
index 0000000000000000000000000000000000000000..4034f0faffe1355eefd4782fbe9632612f8e04c7
--- /dev/null
+++ b/illustrious_generated/c588cc386611.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3ad69c9ba787d9d16acccb22a0021e0921ec17a3ae9c1e72a6645ae186357937
+size 2998673
diff --git a/illustrious_generated/c636dce88f49.png b/illustrious_generated/c636dce88f49.png
new file mode 100644
index 0000000000000000000000000000000000000000..de1b9ce071e86ac058a817f7c0fe087f06eda879
--- /dev/null
+++ b/illustrious_generated/c636dce88f49.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:04b129baa72f11ad62358bb180c07e7ee6096b4c5c66b03e3534d99aa1631008
+size 1991650
diff --git a/illustrious_generated/c65bd94bf971.png b/illustrious_generated/c65bd94bf971.png
new file mode 100644
index 0000000000000000000000000000000000000000..ce6f6f61365a82dc25307cb5269617e22c7a2f81
--- /dev/null
+++ b/illustrious_generated/c65bd94bf971.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1f1393a2fb248daeae873782acde2aba6b8555e5adee419d4b74f2c2f59e2bf1
+size 1037629
diff --git a/illustrious_generated/c87663be02a7.png b/illustrious_generated/c87663be02a7.png
new file mode 100644
index 0000000000000000000000000000000000000000..b0c8eee6003c635d9cc07f03bf9dfe79daa8be03
--- /dev/null
+++ b/illustrious_generated/c87663be02a7.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:24f4432c18d99076bba744b901713dc3d24dd0e40d97c32f7b6f593340f428b6
+size 1010892
diff --git a/illustrious_generated/ca3147d568a3.png b/illustrious_generated/ca3147d568a3.png
new file mode 100644
index 0000000000000000000000000000000000000000..7bac2261db39cbf57904e52fe286c68c619aeb12
--- /dev/null
+++ b/illustrious_generated/ca3147d568a3.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a16bef107b7d1fc9b5b2f372e781092c460947d85693d6fcd0b4661aacdb35ff
+size 2812666
diff --git a/illustrious_generated/ce38aa6d180d.png b/illustrious_generated/ce38aa6d180d.png
new file mode 100644
index 0000000000000000000000000000000000000000..432debe2f8b3ea21f1304c291fc0ec042334b807
--- /dev/null
+++ b/illustrious_generated/ce38aa6d180d.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9b1110c74ef9a76ac725eec779f8ed5b4978a5c269a5b134fe83e740c7140065
+size 1208529
diff --git a/illustrious_generated/ceb937e0801e.png b/illustrious_generated/ceb937e0801e.png
new file mode 100644
index 0000000000000000000000000000000000000000..bbf343d410b7a5252d163519a002183494124957
--- /dev/null
+++ b/illustrious_generated/ceb937e0801e.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4c1719b2c2b938124f0f95b9fd594e0057a296caa0f807f096530368a4c1f1f2
+size 1787398
diff --git a/illustrious_generated/cef6d8bbd520.png b/illustrious_generated/cef6d8bbd520.png
new file mode 100644
index 0000000000000000000000000000000000000000..43d6f810e956a617b16186b060d5a46adbc7ac17
--- /dev/null
+++ b/illustrious_generated/cef6d8bbd520.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f2081e76a8144aaeda1e24108c69b0cadc99f2c4a273a70a79f6eb8724f5c59c
+size 510421
diff --git a/illustrious_generated/d0d5e2ddb402.png b/illustrious_generated/d0d5e2ddb402.png
new file mode 100644
index 0000000000000000000000000000000000000000..645c096239aa98fd5ae58a70870fe12dcb907e70
--- /dev/null
+++ b/illustrious_generated/d0d5e2ddb402.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5062ee5cf57e8d386e1fe13ff6039ea1de13d2ed32a5966dd1d237a23d7e6012
+size 4740321
diff --git a/illustrious_generated/d1995ff196e3.png b/illustrious_generated/d1995ff196e3.png
new file mode 100644
index 0000000000000000000000000000000000000000..28be54edc0efe14ca525eb26f0f9f2001bef41f3
--- /dev/null
+++ b/illustrious_generated/d1995ff196e3.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8ac07baf5cb9d4b7dd6bfd458c242550a08780fe678c06dbd8ed8af1f3d29779
+size 5144634
diff --git a/illustrious_generated/d209e317052f.png b/illustrious_generated/d209e317052f.png
new file mode 100644
index 0000000000000000000000000000000000000000..22b704b41869a9dbff1ae46a37e5659063d8fa49
--- /dev/null
+++ b/illustrious_generated/d209e317052f.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:51195f637f4899f89f58f5fb2046a45d3496a0b5f370f6336ee23097a66a80f4
+size 4028055
diff --git a/illustrious_generated/d21efb5f15d1.png b/illustrious_generated/d21efb5f15d1.png
new file mode 100644
index 0000000000000000000000000000000000000000..0bb48c952d63b3b27e6e097f1d7ff0944c094f1c
--- /dev/null
+++ b/illustrious_generated/d21efb5f15d1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:beadfab590fdc2620951a6151cf841a03001d16945ccc2e27c7f542faacf9846
+size 1594496
diff --git a/illustrious_generated/d5f24e4f2554.png b/illustrious_generated/d5f24e4f2554.png
new file mode 100644
index 0000000000000000000000000000000000000000..29f5bae26abcc3e5ce06e9f0e689893ecc8927bc
--- /dev/null
+++ b/illustrious_generated/d5f24e4f2554.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:62c0104ef7c24e8ac40d2deb71c4241d9080b79a030712fc08e97f0da1e6127b
+size 1721106
diff --git a/illustrious_generated/daf79f3a5a47.png b/illustrious_generated/daf79f3a5a47.png
new file mode 100644
index 0000000000000000000000000000000000000000..15b2741e97eb39366145c55ef4e721a97f7ced42
--- /dev/null
+++ b/illustrious_generated/daf79f3a5a47.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d58901df2faa07fe66e773d2cd2a58b903a507533e751549cdef5981dbec650c
+size 2017719
diff --git a/illustrious_generated/db68dd8cb2cf.png b/illustrious_generated/db68dd8cb2cf.png
new file mode 100644
index 0000000000000000000000000000000000000000..7fe59087634ee5e9361a944de59516e1cec38d70
--- /dev/null
+++ b/illustrious_generated/db68dd8cb2cf.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:06945e00b9ae966e32e29258dfddb0479b8df197c5c8054af2017d53cf1037d6
+size 2053862
diff --git a/illustrious_generated/e0fa2a6e6b42.png b/illustrious_generated/e0fa2a6e6b42.png
new file mode 100644
index 0000000000000000000000000000000000000000..eaecc31977ec5a26df06c75adfc1dd60bd3ce3a1
--- /dev/null
+++ b/illustrious_generated/e0fa2a6e6b42.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1d318afda5b34d65e35fde8b41fbe12c30d37d0b9b3cc7b2db23fce886fe3e58
+size 1539201
diff --git a/illustrious_generated/e150c196c1de.png b/illustrious_generated/e150c196c1de.png
new file mode 100644
index 0000000000000000000000000000000000000000..48cd371ccfbae70002709aeb13dab13b02ec012c
--- /dev/null
+++ b/illustrious_generated/e150c196c1de.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4921929b78f452900782ace3b4da62f378756c05220eb01c26a40ac621047392
+size 1136753
diff --git a/illustrious_generated/e189b3abf6d2.png b/illustrious_generated/e189b3abf6d2.png
new file mode 100644
index 0000000000000000000000000000000000000000..0d81a442520a80667dfe69dbcc514503ee886b46
--- /dev/null
+++ b/illustrious_generated/e189b3abf6d2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7b7410a650455474b90030d73502920b7278cc0a1b369d679ddd568a4560c53b
+size 1938925
diff --git a/illustrious_generated/e31935899f08.png b/illustrious_generated/e31935899f08.png
new file mode 100644
index 0000000000000000000000000000000000000000..e25a669fdcbf83be276cd69807dd70324efe2738
--- /dev/null
+++ b/illustrious_generated/e31935899f08.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d9a722368db4ee8f976605bfea517dd061c1daceaa6172b624889ade63ef29c2
+size 2124659
diff --git a/illustrious_generated/e38a1e2e958f.png b/illustrious_generated/e38a1e2e958f.png
new file mode 100644
index 0000000000000000000000000000000000000000..7f78cf5c9d176be93d87e434f2c5888d0924772d
--- /dev/null
+++ b/illustrious_generated/e38a1e2e958f.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:eca36b324db10d8492150ca57bba8ba1af7caa8b29397d48d6dab0d2f91c488d
+size 2466549
diff --git a/illustrious_generated/e412e3f1af16.png b/illustrious_generated/e412e3f1af16.png
new file mode 100644
index 0000000000000000000000000000000000000000..e5149a207a8753ae4538c946847b2d01f12bc866
--- /dev/null
+++ b/illustrious_generated/e412e3f1af16.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:32e0b3ad17090071d6679546e75aa811d2fffb51c13480e27dd566658e980b9b
+size 2320488
diff --git a/illustrious_generated/e4bf1a6be7c6.png b/illustrious_generated/e4bf1a6be7c6.png
new file mode 100644
index 0000000000000000000000000000000000000000..5608e4a36d84f20707a535332ec4002ab345d943
--- /dev/null
+++ b/illustrious_generated/e4bf1a6be7c6.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:43224b4418138a0ebfe81ca127922167ad6f39a6f7d0bb758f808440b3c4254c
+size 733849
diff --git a/illustrious_generated/e526540fdbb8.png b/illustrious_generated/e526540fdbb8.png
new file mode 100644
index 0000000000000000000000000000000000000000..c5089a62235f2c241e9e4476670cc2dc932b9075
--- /dev/null
+++ b/illustrious_generated/e526540fdbb8.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7190096fbcd11aed4acc9224934a713c5938b407c234ef447fd09176fc1898df
+size 1193138
diff --git a/illustrious_generated/e5b4630753fb.png b/illustrious_generated/e5b4630753fb.png
new file mode 100644
index 0000000000000000000000000000000000000000..e7bd5e2cd884eb53910b92964c03fa526403dc61
--- /dev/null
+++ b/illustrious_generated/e5b4630753fb.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ab8a428f425ed7bf4c0fbe05e9b180d20df5bcdaeb6d35e9baa111b4b9e15459
+size 1559204
diff --git a/illustrious_generated/e5ef86680557.png b/illustrious_generated/e5ef86680557.png
new file mode 100644
index 0000000000000000000000000000000000000000..5845c364074df179d1d22a729ee255478392d403
--- /dev/null
+++ b/illustrious_generated/e5ef86680557.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4a7c8385e3d42cab2262241ffc16fa1b3d2054a882cf0884fc13b57c956160ea
+size 484226
diff --git a/illustrious_generated/e729db57f47c.png b/illustrious_generated/e729db57f47c.png
new file mode 100644
index 0000000000000000000000000000000000000000..066b72835189abb5a260b2b77e104dbaed9638cc
--- /dev/null
+++ b/illustrious_generated/e729db57f47c.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8c9bf01648cd7195cae8c1f26d47eab8e4c1ce59639110e95c7b73589a5a5410
+size 4762295
diff --git a/illustrious_generated/ece3750ad6a6.png b/illustrious_generated/ece3750ad6a6.png
new file mode 100644
index 0000000000000000000000000000000000000000..ccb1ec791536799dc71f983e9956c9f11304145e
--- /dev/null
+++ b/illustrious_generated/ece3750ad6a6.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ee9c12613b89803a7ed3ff54daf6401bbe53cfb3604c689c2536114e9ca973d7
+size 948034
diff --git a/illustrious_generated/ed0e6ab9592e.png b/illustrious_generated/ed0e6ab9592e.png
new file mode 100644
index 0000000000000000000000000000000000000000..7a96625b24401b7358fd572a86d426540d7d24e7
--- /dev/null
+++ b/illustrious_generated/ed0e6ab9592e.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:73b13855504e09623668bed0e1b6522edaf38557d8f08e7b5f7bb2ff18511f04
+size 1717680
diff --git a/illustrious_generated/efb6fe0239e2.png b/illustrious_generated/efb6fe0239e2.png
new file mode 100644
index 0000000000000000000000000000000000000000..1308725e1ce6581ecc06b131358d4ba61d8deede
--- /dev/null
+++ b/illustrious_generated/efb6fe0239e2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6d1a697df66f53c3aa25f05ac53cf3ba41de3be02cda475d48e6d260afe90dae
+size 3348143
diff --git a/illustrious_generated/f086d40d76b7.png b/illustrious_generated/f086d40d76b7.png
new file mode 100644
index 0000000000000000000000000000000000000000..d9c89a7d74a89ce5828f1197637238180070c861
--- /dev/null
+++ b/illustrious_generated/f086d40d76b7.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:141e9a1385d0eec4ea41eaf47911f55223d2b27a240bb45ff2152c3517014d33
+size 2812127
diff --git a/illustrious_generated/f284fb62c9bf.png b/illustrious_generated/f284fb62c9bf.png
new file mode 100644
index 0000000000000000000000000000000000000000..709b05454dd9ec2e4679f139a5c0169a56ae0f8e
--- /dev/null
+++ b/illustrious_generated/f284fb62c9bf.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:82d7476d555aff73b09c4184d66c0b87574356eb1f2fa8579a6620d65306ee9c
+size 594622
diff --git a/illustrious_generated/f3021039e6df.png b/illustrious_generated/f3021039e6df.png
new file mode 100644
index 0000000000000000000000000000000000000000..f99afe30e813a57c7db9e250cef378295d4c56e6
--- /dev/null
+++ b/illustrious_generated/f3021039e6df.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7025ad3726a5f4ae414775d5632301a2bdd80c03d89ac6b5a92ab9128682ab36
+size 3306193
diff --git a/illustrious_generated/f3f41fd3689e.png b/illustrious_generated/f3f41fd3689e.png
new file mode 100644
index 0000000000000000000000000000000000000000..da26e805ce10c1b12ecbf3fdaff6c2c5410f97d5
--- /dev/null
+++ b/illustrious_generated/f3f41fd3689e.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fdba709bbc56a85c3c50e5bd79b700c4fa30cc7efe3c1bab6bf468abab42b10d
+size 681103
diff --git a/illustrious_generated/f8b258d2d426.png b/illustrious_generated/f8b258d2d426.png
new file mode 100644
index 0000000000000000000000000000000000000000..9023eebf45837509561921bb72a6d0d20369200a
--- /dev/null
+++ b/illustrious_generated/f8b258d2d426.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b40711d35ccb9cd22c4eb195696d82a4991afcc72fc89f3d88979048387d130e
+size 1526526
diff --git a/illustrious_generated/fb9ed29c63d7.png b/illustrious_generated/fb9ed29c63d7.png
new file mode 100644
index 0000000000000000000000000000000000000000..2b1745c780dc7b87de3c1794a9561572d2fd41e1
--- /dev/null
+++ b/illustrious_generated/fb9ed29c63d7.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6e8ba24b08bc7b564fe61a0c576b5a191b74b23e41b03734d7de92ab0faaa891
+size 3688591
diff --git a/illustrious_generated/ffa981380c46.png b/illustrious_generated/ffa981380c46.png
new file mode 100644
index 0000000000000000000000000000000000000000..049b11c8a5e88eefce8fdea3ad6048bdb3263482
--- /dev/null
+++ b/illustrious_generated/ffa981380c46.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:08bc9d236ca34b19a3075542d1557cc0b86302c5dded4e18b98c42831f92e7dc
+size 2710486